├── .gitignore
├── requirements.txt
├── images
│   ├── samples.png
│   └── thumbnail.png
├── gflownet
│   ├── utils.py
│   ├── env.py
│   ├── log.py
│   └── gflownet.py
├── LICENSE
├── policy.py
├── grid.py
├── train.py
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
__pycache__/

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
matplotlib==3.5.3
torch==1.12.1
tqdm==4.50.0

--------------------------------------------------------------------------------
/images/samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/augustwester/gflownet/HEAD/images/samples.png

--------------------------------------------------------------------------------
/images/thumbnail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/augustwester/gflownet/HEAD/images/thumbnail.png

--------------------------------------------------------------------------------
/gflownet/utils.py:
--------------------------------------------------------------------------------
import torch

def trajectory_balance_loss(total_flow, rewards, fwd_probs, back_probs):
    """
    Computes the mean trajectory balance loss for a collection of samples. For
    more information, see Malkin et al. (2022): https://arxiv.org/abs/2201.13259

    Args:
        total_flow: The estimated total flow used by the GFlowNet when drawing
        the collection of samples for which the loss should be computed

        rewards: The rewards associated with the final state of each of the
        samples

        fwd_probs: The forward probabilities associated with the trajectory of
        each sample (i.e. the probabilities of the actions actually taken in
        each trajectory)

        back_probs: The backward probabilities associated with each trajectory
    """
    lhs = total_flow * torch.prod(fwd_probs, dim=1)
    rhs = rewards * torch.prod(back_probs, dim=1)
    loss = torch.log(lhs / rhs)**2
    return loss.mean()

--------------------------------------------------------------------------------
/gflownet/env.py:
--------------------------------------------------------------------------------
from abc import ABC, abstractmethod

class Env(ABC):
    """
    Abstract base class defining the signatures of the required functions to be
    implemented in a GFlowNet environment.
    """
    @abstractmethod
    def update(self, s, actions):
        """
        Takes as input state-action pairs and returns the resulting states.

        Args:
            s: An NxD matrix of state vectors

            actions: An Nx1 vector of actions
        """
        pass

    @abstractmethod
    def mask(self, s):
        """
        Defines a mask to disallow certain actions given certain states.

        Args:
            s: An NxD matrix of state vectors
        """
        pass

    @abstractmethod
    def reward(self, s):
        """
        Defines a reward function, mapping states to rewards.
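
        Implementations should return one strictly positive reward per state;
        the trajectory balance loss in gflownet/utils.py takes the logarithm
        of a ratio involving these rewards.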

        Args:
            s: An NxD matrix of state vectors
        """
        pass

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 August Wester

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/policy.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn.functional import relu
from torch.nn.functional import softmax

class ForwardPolicy(nn.Module):
    def __init__(self, state_dim, hidden_dim, num_actions):
        super().__init__()
        self.dense1 = nn.Linear(state_dim, hidden_dim)
        self.dense2 = nn.Linear(hidden_dim, num_actions)

    def forward(self, s):
        x = self.dense1(s)
        x = relu(x)
        x = self.dense2(x)
        return softmax(x, dim=1)

class BackwardPolicy:
    def __init__(self, state_dim, num_actions):
        super().__init__()
        self.num_actions = num_actions
        self.size = int(state_dim**0.5)

    def __call__(self, s):
        idx = s.argmax(-1)
        at_top_edge = idx < self.size
        at_left_edge = (idx > 0) & (idx % self.size == 0)

        probs = 0.5 * torch.ones(len(s), self.num_actions)
        probs[at_left_edge] = torch.Tensor([1, 0, 0])
        probs[at_top_edge] = torch.Tensor([0, 1, 0])
        probs[:, -1] = 0 # disregard termination

        return probs

--------------------------------------------------------------------------------
/grid.py:
--------------------------------------------------------------------------------
import torch
from torch.nn.functional import one_hot
from gflownet.env import Env

class Grid(Env):
    def __init__(self, size):
        self.size = size
        self.state_dim = size**2
        self.num_actions = 3 # down, right, terminate

    def update(self, s, actions):
        idx = s.argmax(1)
        down, right = actions == 0, actions == 1
        idx[down] = idx[down] + self.size
        idx[right] = idx[right] + 1
        return one_hot(idx, self.state_dim).float()

    def mask(self, s):
        mask = torch.ones(len(s), self.num_actions)
        idx = s.argmax(1) + 1
        right_edge = (idx > 0) & (idx % (self.size) == 0)
        bottom_edge = idx > self.size*(self.size-1)
        mask[right_edge, 1] = 0
        mask[bottom_edge, 0] = 0
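        # the right edge masks out "right" (action 1) and the bottom edge masks
        # out "down" (action 0); "terminate" (the last action) is never masked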
        return mask

    def reward(self, s):
        grid = s.view(len(s), self.size, self.size)
        coord = (grid == 1).nonzero()[:, 1:].view(len(s), 2)
        R0, R1, R2 = 1e-2, 0.5, 2
        norm = torch.abs(coord / (self.size-1) - 0.5)
        R1_term = torch.prod(0.25 < norm, dim=1)
        R2_term = torch.prod((0.3 < norm) & (norm < 0.4), dim=1)
        return (R0 + R1*R1_term + R2*R2_term)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
import matplotlib.pyplot as plt
import argparse
from tqdm import tqdm
from torch.nn.functional import one_hot
from gflownet.gflownet import GFlowNet
from policy import ForwardPolicy, BackwardPolicy
from gflownet.utils import trajectory_balance_loss
from torch.optim import Adam
from grid import Grid

size = 16

def plot(samples, env):
    _, ax = plt.subplots(1, 2)
    s = samples.sum(0).view(size, size)
    e = env.reward(torch.eye(env.state_dim)).view(size, size)

    ax[0].matshow(s.numpy())
    ax[0].set_title("Samples")
    ax[1].matshow(e.numpy())
    ax[1].set_title("Environment")

    plt.show()

def train(batch_size, num_epochs):
    env = Grid(size=size)
    forward_policy = ForwardPolicy(env.state_dim, hidden_dim=32, num_actions=env.num_actions)
    backward_policy = BackwardPolicy(env.state_dim, num_actions=env.num_actions)
    model = GFlowNet(forward_policy, backward_policy, env)
    opt = Adam(model.parameters(), lr=5e-3)

    for i in (p := tqdm(range(num_epochs))):
        s0 = one_hot(torch.zeros(batch_size).long(), env.state_dim).float()
        s, log = model.sample_states(s0, return_log=True)
        loss = trajectory_balance_loss(log.total_flow,
                                       log.rewards,
                                       log.fwd_probs,
                                       log.back_probs)
        loss.backward()
        opt.step()
        opt.zero_grad()
        if i % 10 == 0: p.set_description(f"{loss.item():.3f}")

    s0 = one_hot(torch.zeros(10**4).long(), env.state_dim).float()
    s = model.sample_states(s0, return_log=False)
    plot(s, env)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--num_epochs", type=int, default=1000)

    args = parser.parse_args()
    batch_size = args.batch_size
    num_epochs = args.num_epochs

    train(batch_size, num_epochs)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GFlowNet in PyTorch

![gflownet](images/thumbnail.png)

This repo is associated with the blog post ["Proportional Reward Sampling With GFlowNets"](https://sigmoidprime.com/post/gflownets/) over at [sigmoid prime](https://sigmoidprime.com). It contains an implementation of a Generative Flow Network (GFlowNet), proposed by Bengio et al. in the paper ["Flow Network based Generative Models for Non-Iterative Diverse Candidate Generation"](https://arxiv.org/abs/2106.04399) (2021).

The model is trained using online learning (i.e. by continually evaluating samples drawn from the model's own policy rather than a fixed set of samples drawn from another policy) and the [trajectory balance loss](https://arxiv.org/abs/2201.13259). We evaluate the model's performance using the grid domain of the original paper.
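
For a single trajectory $\tau$ ending in a state $x$, the loss computed in `gflownet/utils.py` is the trajectory balance objective (written here in the notation of the linked paper):

$$\mathcal{L}(\tau) = \left(\log \frac{Z \prod_t P_F(s_{t+1} \mid s_t)}{R(x) \prod_t P_B(s_t \mid s_{t+1})}\right)^2$$

where $Z$ is the learned total flow, $P_F$ and $P_B$ are the forward and backward policies, and $R(x)$ is the reward of the trajectory's final state. The mean of this quantity over a batch of sampled trajectories is what gets minimized.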

The samples drawn at the end of training are visualized below, next to the reward function of the environment:

![samples](images/samples.png)

The code for training the model is simple:

1. Initialize the grid environment using a grid size
2. Define a policy network taking a state vector as input and outputting a vector of probabilities over possible actions. (In the grid domain, the number of actions is three: **Down**, **Right**, and **Terminate**.)
3. Define a backward policy. In this case, the policy is not learned but fixed to 0.5 for each parent state (except when a state has only one parent).

With this, you initialize the GFlowNet along with the optimizer to use during training.

```python
env = Grid(size=16)
forward_policy = ForwardPolicy(env.state_dim, hidden_dim=32, num_actions=3)
backward_policy = BackwardPolicy(env.state_dim, num_actions=3)
model = GFlowNet(forward_policy, backward_policy, env)
opt = Adam(model.parameters(), lr=5e-3)
```

To train the model, construct an NxD matrix of initial states, where N is the desired number of samples and D is the dimensionality of the state vector (i.e. `state_dim`). Then, draw samples from the model using the `sample_states` method, giving it the initial states and setting `return_log=True`. The resulting `Log` object contains information about the trajectory of each sample, which is used to compute the trajectory balance loss.

```python
for i in range(num_epochs):
    s0 = one_hot(torch.zeros(batch_size).long(), env.state_dim).float()
    s, log = model.sample_states(s0, return_log=True)
    loss = trajectory_balance_loss(log.total_flow, log.rewards, log.fwd_probs, log.back_probs)
    loss.backward()
    opt.step()
    opt.zero_grad()
```

Finally, when the model has been trained, you can sample states using the same `sample_states(...)` method as before, this time without needing to supply the `return_log=True` argument.

```python
s0 = one_hot(torch.zeros(10**4).long(), env.state_dim).float()
s = model.sample_states(s0)
```
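
The full training loop lives in `train.py`, which also produces the plot shown above; it can be run from the command line as `python train.py --batch_size 256 --num_epochs 1000` (both arguments are optional and default to the values shown).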

--------------------------------------------------------------------------------
/gflownet/log.py:
--------------------------------------------------------------------------------
import torch

class Log:
    def __init__(self, s0, backward_policy, total_flow, env):
        """
        Initializes a Log object to record sampling statistics from a
        GFlowNet (e.g. trajectories, forward and backward probabilities,
        actions, etc.)

        Args:
            s0: The initial state of the collection of samples

            backward_policy: The backward policy used to estimate the backward
            probabilities associated with each sample's trajectory

            total_flow: The estimated total flow used by the GFlowNet during
            sampling

            env: The environment (i.e. state space and reward function) from
            which samples are drawn
        """
        self.backward_policy = backward_policy
        self.total_flow = total_flow
        self.env = env
        self._traj = [s0.view(len(s0), 1, -1)]
        self._fwd_probs = []
        self._back_probs = None
        self._actions = []
        self.rewards = torch.zeros(len(s0))
        self.num_samples = s0.shape[0]

    def log(self, s, probs, actions, done):
        """
        Logs relevant information about each sampling step

        Args:
            s: An NxD matrix containing the current state of complete and
            incomplete samples

            probs: An NxA matrix containing the forward probabilities output by the
            GFlowNet for the given states

            actions: An Nx1 vector containing the actions taken by the GFlowNet
            in the given states

            done: An Nx1 Boolean vector indicating which samples are complete
            (True) and which are incomplete (False)
        """
        had_terminating_action = actions == probs.shape[-1] - 1
        active, just_finished = ~done, ~done
        active[active == True] = ~had_terminating_action
        just_finished[just_finished == True] = had_terminating_action

        states = self._traj[-1].squeeze(1).clone()
        states[active] = s[active]
        self._traj.append(states.view(self.num_samples, 1, -1))

        fwd_probs = torch.ones(self.num_samples, 1)
        fwd_probs[~done] = probs.gather(1, actions.unsqueeze(1))
        self._fwd_probs.append(fwd_probs)

        _actions = -torch.ones(self.num_samples, 1).long()
        _actions[~done] = actions.unsqueeze(1)
        self._actions.append(_actions)

        self.rewards[just_finished] = self.env.reward(s[just_finished])

    @property
    def traj(self):
        if type(self._traj) is list:
            self._traj = torch.cat(self._traj, dim=1)[:, :-1, :]
        return self._traj

    @property
    def fwd_probs(self):
        if type(self._fwd_probs) is list:
            self._fwd_probs = torch.cat(self._fwd_probs, dim=1)
        return self._fwd_probs

    @property
    def actions(self):
        if type(self._actions) is list:
            self._actions = torch.cat(self._actions, dim=1)
        return self._actions

    @property
    def back_probs(self):
        if self._back_probs is not None:
            return self._back_probs

        s = self.traj[:, 1:, :].reshape(-1, self.env.state_dim)
        prev_s = self.traj[:, :-1, :].reshape(-1, self.env.state_dim)
        actions = self.actions[:, :-1].flatten()

        terminated = (actions == -1) | (actions == self.env.num_actions - 1)
        zero_to_n = torch.arange(len(terminated))
        back_probs = self.backward_policy(s) * self.env.mask(prev_s)
        back_probs = torch.where(terminated, 1, back_probs[zero_to_n, actions])
        self._back_probs = back_probs.reshape(self.num_samples, -1)

        return self._back_probs

--------------------------------------------------------------------------------
/gflownet/gflownet.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn.parameter import Parameter
from torch.distributions import Categorical
from .log import Log

class GFlowNet(nn.Module):
    def __init__(self, forward_policy, backward_policy, env):
        """
        Initializes a GFlowNet using the specified forward and backward policies
        acting over an environment, i.e. a state space and a reward function.
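
        The estimated total flow is stored as a learnable parameter
        (self.total_flow) and is trained jointly with the forward policy when
        the trajectory balance loss is minimized.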

        Args:
            forward_policy: A policy network taking as input a state and
            outputting a vector of probabilities over actions

            backward_policy: A policy network (or fixed function) taking as
            input a state and outputting a vector of probabilities over the
            actions which led to that state

            env: An environment defining a state space and an associated reward
            function
        """
        super().__init__()
        self.total_flow = Parameter(torch.ones(1))
        self.forward_policy = forward_policy
        self.backward_policy = backward_policy
        self.env = env

    def mask_and_normalize(self, s, probs):
        """
        Masks a vector of action probabilities to avoid illegal actions (i.e.
        actions that lead outside the state space).

        Args:
            s: An NxD matrix representing N states

            probs: An NxA matrix of action probabilities
        """
        probs = self.env.mask(s) * probs
        return probs / probs.sum(1).unsqueeze(1)

    def forward_probs(self, s):
        """
        Returns a vector of probabilities over actions in a given state.

        Args:
            s: An NxD matrix representing N states
        """
        probs = self.forward_policy(s)
        return self.mask_and_normalize(s, probs)

    def sample_states(self, s0, return_log=False):
        """
        Samples and returns a collection of final states from the GFlowNet.

        Args:
            s0: An NxD matrix of initial states

            return_log: Return an object containing information about the
            sampling process (e.g. the trajectory of each sample, the forward
            and backward probabilities, the actions taken, etc.)
        """
        s = s0.clone()
        done = torch.BoolTensor([False] * len(s))
        log = Log(s0, self.backward_policy, self.total_flow, self.env) if return_log else None

        while not done.all():
            probs = self.forward_probs(s[~done])
            actions = Categorical(probs).sample()
            s[~done] = self.env.update(s[~done], actions)

            if return_log:
                log.log(s, probs, actions, done)

            terminated = actions == probs.shape[-1] - 1
            done[~done] = terminated

        return (s, log) if return_log else s

    def evaluate_trajectories(self, traj, actions):
        """
        Returns the GFlowNet's estimated forward probabilities, backward
        probabilities, and rewards for a collection of trajectories. This is
        useful in an offline learning context where samples drawn according to
        another policy (e.g. a random one) are used to train the model.
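
        Trajectories may be padded: entries whose action is -1 (the padding
        value used by Log) are assigned probability 1, so they do not affect
        the products in the trajectory balance loss.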

        Args:
            traj: The trajectory of each sample

            actions: The actions that produced the trajectories in traj
        """
        num_samples = len(traj)
        traj = traj.reshape(-1, traj.shape[-1])
        actions = actions.flatten()
        finals = traj[actions == self.env.num_actions - 1]
        zero_to_n = torch.arange(len(actions))

        fwd_probs = self.forward_probs(traj)
        fwd_probs = torch.where(actions == -1, 1, fwd_probs[zero_to_n, actions])
        fwd_probs = fwd_probs.reshape(num_samples, -1)

        actions = actions.reshape(num_samples, -1)[:, :-1].flatten()

        back_probs = self.backward_policy(traj)
        back_probs = back_probs.reshape(num_samples, -1, back_probs.shape[1])
        back_probs = back_probs[:, 1:, :].reshape(-1, back_probs.shape[2])
        back_probs = torch.where((actions == -1) | (actions == 2), 1,
                                 back_probs[zero_to_n[:-num_samples], actions])
        back_probs = back_probs.reshape(num_samples, -1)

        rewards = self.env.reward(finals)

        return fwd_probs, back_probs, rewards

--------------------------------------------------------------------------------