├── .flake8 ├── .gitignore ├── .pre-commit-config.yaml ├── README.md ├── pyproject.toml ├── setup.cfg ├── setup.py ├── src └── gcm │ ├── __init__.py │ ├── edge_selectors │ ├── .dense.py.swp │ ├── __init__.py │ ├── dense.py │ ├── distance.py │ ├── learned.py │ ├── self_edge.py │ └── temporal.py │ ├── gcm.py │ ├── nav_gcm.py │ ├── ray_gcm.py │ ├── ray_sparse_gcm.py │ ├── sparse_edge_selectors │ ├── learned.py │ ├── spatial.py │ └── temporal.py │ ├── sparse_gcm.py │ └── util.py └── tests ├── profile_sparse.py ├── test_gcm.py ├── test_nav_gcm.py ├── test_ray_gcm.py ├── test_sparse_gcm.py └── test_speed.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203, E266, E402, E501, W503, F403, F401 3 | max-line-length = 79 4 | max-complexity = 18 5 | select = B,C,E,F,W,T4,B9 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.swo 3 | *__pycache__ 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 20.8b1 4 | hooks: 5 | - id: black 6 | language_version: python3 7 | - repo: https://github.com/pycqa/flake8 #https://github.com/pre-commit/pre-commit-hooks 8 | rev: 3.9.0 9 | hooks: 10 | - id: flake8 11 | - repo: https://github.com/pre-commit/mirrors-mypy 12 | rev: 'v0.812' # Use the sha / tag you want to point at 13 | hooks: 14 | - id: mypy 15 | args: [--no-strict-optional, --ignore-missing-imports] 16 | 17 | 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph Convolutional Memory using Topological Priors 2 | 3 | ## Description 4 | Graph 
convolutional memory (GCM) is graph-structured memory that may be applied to reinforcement learning to solve POMDPs, replacing LSTMs or attention mechanisms. GCM allows you to embed your domain knowledge in the form of connections in a graph. See the [full paper](https://arxiv.org/pdf/2106.14117.pdf) for further details. This repo contains the GCM library implementation for use in your projects. To replicate the experiments from the paper, please see [this repository instead](https://github.com/smorad/graph-conv-memory-paper). 5 | 6 | If you use GCM, please cite the paper! 7 | ``` 8 | @article{morad2021graph, 9 | title={Graph Convolutional Memory using Topological Priors}, 10 | author={Morad, Steven D and Liwicki, Stephan and Kortvelesy, Ryan and Mecca, Roberto and Prorok, Amanda}, 11 | journal={arXiv preprint arXiv:2106.14117}, 12 | year={2021} 13 | } 14 | ``` 15 | 16 | 17 | ## Installation 18 | GCM is installed using `pip`. The dependencies must be installed manually, as they target your specific architecture (with or without CUDA). 19 | 20 | ### Conda install 21 | First install `torch >= 1.8.0` and `torch-geometric` dependencies, then `gcm` (note that the conda package for torch is named `pytorch`): 22 | ``` 23 | conda install pytorch -c pytorch 24 | conda install pytorch-geometric -c rusty1s -c conda-forge 25 | pip install graph-conv-memory 26 | ``` 27 | 28 | ### Pip install 29 | Please follow the [torch-geometric install guide](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html), then 30 | ``` 31 | pip install graph-conv-memory 32 | ``` 33 | 34 | 35 | ## Quickstart 36 | Below is a quick example of how to use GCM in a basic RL problem: 37 | 38 | ```python 39 | import torch 40 | import torch_geometric 41 | from gcm.gcm import DenseGCM 42 | from gcm.edge_selectors.temporal import TemporalBackedge 43 | 44 | 45 | # graph_size denotes the maximum number of observations in the graph, after which 46 | # the oldest observations will be overwritten with newer observations.
Reduce this number to 47 | # reduce memory usage. 48 | graph_size = 128 49 | 50 | # Define the GNN used in GCM. The following is the one used in the paper 51 | # Make sure you define the first layer to match your observation space 52 | class GNN(torch.nn.Module): 53 | """A simple two-layer graph neural network""" 54 | def __init__(self, obs_size, hidden_size=32): 55 | super().__init__() 56 | self.gc0 = torch_geometric.nn.DenseGraphConv(obs_size, hidden_size) 57 | self.gc1 = torch_geometric.nn.DenseGraphConv(hidden_size, hidden_size) 58 | self.act = torch.nn.Tanh() 59 | 60 | def forward(self, x, adj, weights, B, N): 61 | x = self.act(self.gc0(x, adj)) 62 | return self.act(self.gc1(x, adj)) 63 | 64 | # Build GNN that GCM uses internally 65 | obs_size = 8 66 | gnn = GNN(obs_size) 67 | # Create the GCM using our GNN and edge selection criteria. TemporalBackedge([1]) will link observation o_t to o_{t-1}. 68 | # See `gcm.edge_selectors` for different kinds of priors suitable for your specific problem. Do not be afraid to implement your own! 
69 | gcm = DenseGCM(gnn, edge_selectors=TemporalBackedge([1]), graph_size=graph_size) 70 | 71 | # If the hidden state m_t is None, GCM will initialize one for you 72 | # only do this at the beginning, as GCM must track and update the hidden 73 | # state to function correctly 74 | # 75 | # You can inspect m_t, as it is just a graph of observations 76 | # the first element is the node feature matrix and the second is the adjacency matrix 77 | m_t = None 78 | 79 | for t in train_timestep: 80 | # Obs at timestep t should be a tensor of shape (batch_size, obs_size) 81 | # obs = my_env.step() 82 | belief, m_t = gcm(obs, m_t) 83 | # GCM provides a belief state -- a combination of all past observational data relevant to the problem 84 | # What you likely want to do is put this state through actor and critic networks to obtain 85 | # action and value estimates 86 | action_logits = logits_nn(belief) 87 | state_value = vf_nn(belief) 88 | ``` 89 | 90 | We provide a few edge selectors, which we briefly detail here: 91 | ```python 92 | gcm.edge_selectors.temporal.TemporalBackedge 93 | # Connections to the past. Give it [1,2,4] to connect each 94 | # observation t to t-1, t-2, and t-4. 95 | 96 | gcm.edge_selectors.dense.DenseEdge 97 | # Connections to all past observations 98 | # observation t is connected to t-1, t-2, ... 0 99 | 100 | gcm.edge_selectors.distance.EuclideanEdge 101 | # Connections to observations within some max_distance 102 | # e.g. 
if l2_norm(o_t, o_k) < max_distance, create an edge 103 | 104 | gcm.edge_selectors.distance.CosineEdge 105 | # Like euclidean edge, but using cosine similarity instead 106 | 107 | gcm.edge_selectors.distance.SpatialEdge 108 | # Euclidean distance, but only compares slices from the observation 109 | # this is useful if you have an 'x' and 'y' dimension in your observation 110 | # and only want to connect nearby entries 111 | # 112 | # You can also implement the identity priors using this by setting 113 | # max_distance to something like 1e-6 114 | 115 | gcm.edge_selectors.learned.LearnedEdge 116 | # Learn an edge function from the data 117 | # Will randomly sample edges and train thru gradient descent 118 | # call the constructor with the output size of your GNN 119 | ``` 120 | 121 | ## Ray Quickstart (WIP) 122 | We provide a ray rllib wrapper around GCM, see the example below for how to use it 123 | 124 | ```python 125 | import unittest 126 | import torch 127 | import torch_geometric 128 | import ray 129 | from ray import tune 130 | 131 | from gcm.ray_gcm import RayDenseGCM 132 | from gcm.edge_selectors.temporal import TemporalBackedge 133 | 134 | class GNN(torch.nn.Module): 135 | """A simple two-layer graph neural network""" 136 | def __init__(self, obs_size, hidden_size=32): 137 | super().__init__() 138 | self.gc0 = torch_geometric.nn.DenseGraphConv(obs_size, hidden_size) 139 | self.gc1 = torch_geometric.nn.DenseGraphConv(hidden_size, hidden_size) 140 | self.act = torch.nn.Tanh() 141 | 142 | def forward(self, x, adj, weights, B, N): 143 | x = self.act(self.gc0(x, adj)) 144 | return self.act(self.gc1(x, adj)) 145 | 146 | 147 | ray.init( 148 | local_mode=True, 149 | object_store_memory=3e10, 150 | ) 151 | input_size = 16 152 | hidden_size = 32 153 | cfg = { 154 | "framework": "torch", 155 | "num_gpus": 0, 156 | "env": "CartPole-v0", 157 | "num_workers": 0, 158 | "model": { 159 | "custom_model": RayDenseGCM, 160 | "custom_model_config": { 161 | "graph_size": 20, 162 
| # GCM Ray wrapper will automatically convert observation 163 | # to gnn_input_size using a linear layer 164 | "gnn_input_size": input_size, 165 | "gnn_output_size": hidden_size, 166 | "gnn": GNN(input_size), 167 | "edge_selectors": TemporalBackedge([1]), 168 | "edge_weights": False, 169 | } 170 | } 171 | } 172 | tune.run( 173 | "A2C", 174 | config=cfg, 175 | stop={"info/num_steps_trained": 100} 176 | ) 177 | ``` 178 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = graph-conv-memory 3 | version = 0.0.7 4 | author = Steven Morad 5 | author_email = stevenmorad@gmail.com 6 | description = Graph convolutional memory for reinforcement learning 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | url = https://github.com/smorad/graph-conv-memory 10 | project_urls = 11 | Bug Tracker = https://github.com/smorad/graph-conv-memory/issues 12 | classifiers = 13 | Programming Language :: Python :: 3 14 | License :: OSI Approved :: MIT License 15 | Operating System :: OS Independent 16 | 17 | [options] 18 | package_dir = 19 | = src 20 | packages = find: 21 | python_requires = >=3.6 22 | install_requires = 23 | torch 24 | torch_geometric >= 1.7.0 25 | sparsemax 26 | torchtyping 27 | 28 | [options.packages.find] 29 | where = src 30 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | if __name__ == "__main__": 4 | setup() 5 | 
# --------------------------------------------------------------------------------
# /src/gcm/__init__.py:
# https://raw.githubusercontent.com/proroklab/graph-conv-memory/995b6ffd97e74e4f6add656e849eff10245585f2/src/gcm/__init__.py
# --------------------------------------------------------------------------------
# /src/gcm/edge_selectors/.dense.py.swp:
# https://raw.githubusercontent.com/proroklab/graph-conv-memory/995b6ffd97e74e4f6add656e849eff10245585f2/src/gcm/edge_selectors/.dense.py.swp
# --------------------------------------------------------------------------------
# /src/gcm/edge_selectors/__init__.py:
# https://raw.githubusercontent.com/proroklab/graph-conv-memory/995b6ffd97e74e4f6add656e849eff10245585f2/src/gcm/edge_selectors/__init__.py
# --------------------------------------------------------------------------------
# /src/gcm/edge_selectors/dense.py:
# --------------------------------------------------------------------------------
import torch


class DenseEdge(torch.nn.Module):
    """Connect the newest node to every existing node, in both directions.

    For each batch entry b with current node index i = num_nodes[b], sets
    adj[b, i, j] = adj[b, j, i] = 1 for every j in [0, i]. This includes the
    self edge adj[b, i, i]. Entries outside that row/column are untouched,
    so edges added by other selectors are preserved.
    """

    def __init__(self):
        super().__init__()

    def forward(self, nodes, adj_mats, edge_weights, num_nodes, B):
        """Add dense edges for the current node of every batch entry.

        Since this is called once per observation, only the row and column
        belonging to the newest node need to be written.

        nodes: unused, kept for a uniform edge-selector interface
        adj_mats: adjacency tensor indexed [batch, node, node] (assumed
            [B, N, N]); modified in place and returned
        edge_weights: passed through unchanged
        num_nodes: per-batch index of the current node; only the first B
            entries are used
        B: batch size

        Returns (adj_mats, edge_weights).
        """
        N = adj_mats.shape[-1]
        curr = num_nodes.reshape(-1)[:B].to(adj_mats.device)
        # valid[b, j] is True for j <= num_nodes[b]: every node that already
        # exists, plus the current node itself (self edge)
        node_range = torch.arange(N, device=adj_mats.device)
        valid = node_range.unsqueeze(0) <= curr.unsqueeze(1)
        b_idxs, j_idxs = torch.where(valid)
        i_idxs = curr[b_idxs]
        # current -> past and past -> current
        adj_mats[b_idxs, i_idxs, j_idxs] = 1
        adj_mats[b_idxs, j_idxs, i_idxs] = 1
        return adj_mats, edge_weights


# --------------------------------------------------------------------------------
# /src/gcm/edge_selectors/distance.py:
# --------------------------------------------------------------------------------
import torch


class Distance(torch.nn.Module):
    """Base class for edges based on the similarity between
    latent representations.

    Subclasses implement ``dist_fn(curr_nodes, nodes)``. Given the current
    node of each batch entry ``[B, feat]`` and all nodes ``[B, N, feat]``, it
    returns a ``[B, N]`` tensor of distances, where smaller means closer. An
    edge current <- past is created wherever the distance is below
    ``max_distance``.
    """

    def __init__(self, max_distance, bidirectional=False, learned=False):
        """
        max_distance: Create an edge when the distance is below this value
        bidirectional: Also add the reverse edge (current -> past)
        learned: Learn a scale for the node features instead of using a
            fixed threshold; the threshold then becomes 1.0
        """
        super().__init__()
        self.max_distance = max_distance
        self.bidirectional = bidirectional
        self.learned = learned
        if learned:
            # Easier to scale node matrix than do comparison w/ grad
            self.dist_param = torch.nn.Parameter(torch.Tensor([max_distance]))
            self.max_distance = 1.0

    def forward(self, nodes, adj_mats, edge_weights, num_nodes, B):
        """Connect current obs to past obs based on distance of the node features"""

        if self.learned:
            nodes = nodes / self.dist_param

        # Accept num_nodes shaped either [B] or [B, 1]
        num_nodes = num_nodes.reshape(-1)
        B_idx = torch.arange(B, device=num_nodes.device)
        curr_nodes = nodes[B_idx, num_nodes[B_idx]]
        dists = self.dist_fn(curr_nodes, nodes)
        batch_idxs, node_idxs = torch.where(dists < self.max_distance)
        # Remove entries beyond num_nodes (empty future slots) as well as
        # num_nodes itself, because we don't want the self edge
        num_nodes_mask = node_idxs < num_nodes[batch_idxs]
        batch_idxs = batch_idxs[num_nodes_mask]
        node_idxs = node_idxs[num_nodes_mask]
        curr_idxs = num_nodes[batch_idxs]

        adj_mats[batch_idxs, curr_idxs, node_idxs] = 1
        if self.bidirectional:
            adj_mats[batch_idxs, node_idxs, curr_idxs] = 1

        return adj_mats, edge_weights


class EuclideanEdge(Distance):
    """L2 distance between the current obs vector and each past obs vector"""

    def __init__(self, max_distance, learned=False, bidirectional=False):
        super().__init__(max_distance, bidirectional=bidirectional, learned=learned)

    def dist_fn(self, a, b):
        # a: [B, feat] current nodes, b: [B, N, feat] -> [B, N]
        # unsqueeze so each batch entry is only compared against its own graph
        return torch.cdist(a.unsqueeze(1), b).squeeze(1)


class CosineEdge(Distance):
    """Cosine distance (1 - cosine similarity) between the current obs
    vector and each past obs vector"""

    def __init__(self, max_distance, learned=False, bidirectional=False):
        super().__init__(max_distance, bidirectional=bidirectional, learned=learned)
        self.cs = torch.nn.modules.distance.CosineSimilarity(dim=2)

    def dist_fn(self, a, b):
        # NOTE: cosine similarity is scale invariant, so learned=True (which
        # rescales the node features) does not change which edges are made
        a = torch.cat([a.unsqueeze(1)] * b.shape[1], dim=1)
        # Convert similarity to a distance so that `dist < max_distance`
        # selects similar nodes
        return 1 - self.cs(a, b)


class SpatialEdge(Distance):
    """Euclidean distance representing the physical distance between two observations.
    Uses the slices a_pose_slice and b_pose_slice to extract the respective
    poses from the latent vectors"""

    def __init__(
        self,
        max_distance,
        a_pose_slice,
        b_pose_slice=None,
        learned=False,
        bidirectional=False,
    ):
        super().__init__(max_distance, bidirectional=bidirectional, learned=learned)
        self.a_pose_slice = a_pose_slice
        self.b_pose_slice = a_pose_slice if b_pose_slice is None else b_pose_slice

    def dist_fn(self, a, b):
        # a: [B, feat] current nodes, b: [B, N, feat] -> [B, N]
        ra = a[:, self.a_pose_slice].unsqueeze(1)
        rb = b[:, :, self.b_pose_slice]
        return torch.cdist(ra, rb).squeeze(1)


# --------------------------------------------------------------------------------
# /src/gcm/edge_selectors/learned.py:
# --------------------------------------------------------------------------------
import torch
import itertools
from typing import Dict, Tuple, List
import gcm.util


class LearnedEdge(torch.nn.Module):
    """An edge selector where the prior is learned from data. An MLP
    computes logits which create edges via either sampling or sparsemax."""

    def __init__(
        self,
        input_size: int = 0,
        model: torch.nn.Sequential = None,
        num_edge_samples: int = 5,
        deterministic: bool = False,
    ):
        """
        input_size: Feature dim size of GNN, not required if model is specified
        model: Model for logits network, if not specified one is provided
        num_edge_samples: If not deterministic, sets the cutoff
            1 / (1 + num_edge_samples) applied to the gumbel-softmax output
        deterministic: Whether edges come from spardmax (True) or from
            gumbel-softmax sampling (False)
        """
        super().__init__()
        self.deterministic = deterministic
        self.num_edge_samples = num_edge_samples
        assert input_size or model, "Must specify either input_size or model"
        if model:
            self.edge_network = model
        else:
            # This MUST be done here
            # if initialized in forward model does not learn...
            self.edge_network = self.build_edge_network(input_size)
        if deterministic:
            self.sm = gcm.util.Spardmax()
        self.ste = gcm.util.StraightThroughEstimator()

    def build_edge_network(self, input_size: int) -> torch.nn.Sequential:
        """Builds a network to predict edges.
        Network input: (i || j)
        Network output: logits(edge(i,j))
        """
        return torch.nn.Sequential(
            torch.nn.Linear(2 * input_size, input_size),
            torch.nn.ReLU(),
            torch.nn.LayerNorm(input_size),
            torch.nn.Linear(input_size, input_size),
            torch.nn.ReLU(),
            torch.nn.LayerNorm(input_size),
            torch.nn.Linear(input_size, 1),
        )

    def compute_new_adj(
        self,
        nodes: torch.Tensor,
        num_nodes: torch.Tensor,
        adj: torch.Tensor,
        B: int,
    ):
        """Computes a new adjacency matrix using the edge network.
        The edge network outputs logits for all possible edges,
        which are spardmaxed or sampled to produce edges. Edges are then
        placed into a copy of the adjacency matrix, which is returned."""
        # No edges for a single node
        if torch.max(num_nodes) < 1:
            return adj

        b_idxs, past_idxs, curr_idx = gcm.util.idxs_up_to_num_nodes(adj, num_nodes)
        # curr_idx > past_idxs
        # edges flow from past_idxs into the current node,
        # so they are stored at [curr_idx, past_idxs]
        curr_nodes = nodes[b_idxs, curr_idx]
        past_nodes = nodes[b_idxs, past_idxs]

        net_in = torch.cat((curr_nodes, past_nodes), dim=-1)
        logits = self.edge_network(net_in).squeeze()

        # Load logits into [B, nodes] matrix and set unfilled entries to large
        # negative value so unfilled entries don't affect spardmax
        # then spardmax per-batch (dim=-1)
        shaped_logits = torch.empty(
            (B, int(torch.max(num_nodes))), device=nodes.device
        ).fill_(-1e10)
        shaped_logits[b_idxs, past_idxs] = logits
        if self.deterministic:
            edges = self.sm(shaped_logits)
        else:
            cutoff = 1 / (1 + self.num_edge_samples)
            soft = torch.nn.functional.gumbel_softmax(shaped_logits)
            edges = self.ste(soft - cutoff)

        new_adj = adj.clone()
        # Reindexing edges in this manner ensures that even if the edge network
        # pushed a filler entry (initialized to -1e10) far enough to set an
        # invalid edge, it will not be used; at most it affects the scaling
        # of the valid edges
        #
        # Ensure we don't overwrite 1's in adj in case we have more than one
        # edge selector
        # We don't want to add the old adj to the new adj,
        # because grads from previous rows will accumulate
        # and grad will explode
        new_adj[b_idxs, curr_idx, past_idxs] = self.ste(
            edges[b_idxs, past_idxs] + adj[b_idxs, curr_idx, past_idxs]
        )

        return new_adj

    def forward(self, nodes, adj, weights, num_nodes, B):
        # First run: move the edge network to the node device. Look at any
        # parameter rather than model[0].weight so custom models whose first
        # layer has no weight (e.g. Dropout) also work.
        if next(self.edge_network.parameters()).device != nodes.device:
            self.edge_network = self.edge_network.to(nodes.device)

        # No self edges allowed
        if torch.max(num_nodes) < 1:
            return adj, weights

        new_adj = self.compute_new_adj(nodes, num_nodes, adj, B)
        return new_adj, weights


# --------------------------------------------------------------------------------
# /src/gcm/edge_selectors/self_edge.py:
# --------------------------------------------------------------------------------
import torch


class TemporalBackedge(torch.nn.Module):
    """Unfinished placeholder edge selector.

    No edge logic is implemented here yet; calling forward raises
    NotImplementedError. For temporal edges use
    gcm.edge_selectors.temporal.TemporalBackedge.
    """

    def __init__(self, parent):
        # nn.Module.__init__ must run before attribute assignment; without it,
        # assigning an nn.Module as `parent` raises AttributeError.
        # NOTE(review): if `parent` is a Module that contains this selector,
        # this registers a module reference cycle -- confirm intent.
        super().__init__()
        self.parent = parent

    def forward(self, nodes, adj_mats, num_nodes, B):
        # Fail loudly instead of stopping in a debugger mid-training
        raise NotImplementedError()


# --------------------------------------------------------------------------------
# /src/gcm/edge_selectors/temporal.py:
# --------------------------------------------------------------------------------
import torch
from typing import List, Optional
from gcm.util import diff_or, Spardmax

# Author's note: VISDOM TOP HALF PROVIDES BEST PERF
# (in visdom plots of adj, the top half should be populated)
#
# Adjacency convention: with dense GNN layers, neighbor features are
# aggregated as matmul(adj, nodes), so adj[i, j] = 1 makes node i receive
# information from node j. For example, adj[0, 3] = 1 gives node 0 the
# features of node 3. "forward" edges below therefore set
# adj[t, t - hop] = 1, letting information flow from the past to the present.


class TemporalBackedge(torch.nn.Module):
    """Add temporal directional back edge, e.g., node_{t} -> node_{t-1}"""

    def __init__(
        self,
        hops: Optional[List[int]] = None,
        direction="forward",
        learned=False,
        learning_window=10,
        deterministic=False,
        num_samples=3,
    ):
        """
        Hops: number of hops in the past to connect to, defaults to [1]
            E.g. [1] is t <- t-1, [2] is t <- t-2,
            [5,8] is t <- t-5 AND t <- t-8

        Direction: Directionality of graph edges. You likely want
            'forward', which indicates information flowing from past
            to future. Backward is information from future to past,
            and both is both.

        learned: Instead of fixed hops, learn a window of learning_window
            logits selecting which past nodes to connect to
        deterministic: (learned only) use spardmax on the window instead of
            gumbel-softmax sampling
        num_samples: (learned, non-deterministic only) number of hard
            gumbel-softmax samples combined with diff_or
        """
        super().__init__()
        # None default avoids sharing one mutable list between instances
        self.hops = [1] if hops is None else list(hops)
        assert direction in ["forward", "backward", "both"]
        self.direction = direction
        self.learned = learned
        if learned:
            self.window = torch.nn.Parameter(torch.ones(learning_window))
            self.num_samples = num_samples
            self.deterministic = deterministic
            if deterministic:
                self.spardmax = Spardmax()

    def learned_forward(self, nodes, adj_mats, edge_weights, num_nodes, B):
        # Use a device-local tensor instead of reassigning self.window:
        # assigning a plain (non-Parameter) tensor to a registered Parameter
        # attribute raises TypeError. Gradients still flow back to the
        # parameter through .to().
        window = self.window.to(nodes.device)
        for b in range(B):
            n = num_nodes[b]
            if n == 0:
                continue
            # NOTE(review): window has learning_window entries; if n exceeds
            # that, mask is shorter than the adj row slice and the add fails
            relative_window = window[:n]
            if self.deterministic:
                mask = self.spardmax(relative_window.reshape(1, -1)).reshape(-1)
            else:
                masks = [
                    torch.nn.functional.gumbel_softmax(relative_window, hard=True)
                    for _ in range(self.num_samples)
                ]
                mask = diff_or(masks)
            adj_mats[b][n][:n] = adj_mats[b][n][:n] + mask
        return adj_mats, edge_weights

    def deterministic_forward(self, nodes, adj_mats, edge_weights, num_nodes, B):
        for hop in self.hops:
            # Only batch entries that have at least `hop` past nodes
            [valid_batches] = torch.where(num_nodes >= hop)
            if self.direction in ["forward", "both"]:
                adj_mats[
                    valid_batches,
                    num_nodes[valid_batches],
                    num_nodes[valid_batches] - hop,
                ] = 1
            if self.direction in ["backward", "both"]:
                adj_mats[
                    valid_batches,
                    num_nodes[valid_batches] - hop,
                    num_nodes[valid_batches],
                ] = 1

        return adj_mats, edge_weights

    def forward(self, nodes, adj_mats, edge_weights, num_nodes, B):
        if self.learned:
            return self.learned_forward(nodes, adj_mats, edge_weights, num_nodes, B)

        return self.deterministic_forward(nodes, adj_mats, edge_weights, num_nodes, B)


# --------------------------------------------------------------------------------
# /src/gcm/gcm.py:
# --------------------------------------------------------------------------------
import torch
import torch_geometric
from torch_geometric.data import Data, Batch
from typing import List, Tuple, Union, Any, Dict, Callable
import time
import math
import gcm.util


class SparseToDense(torch.nn.Module):
    """Convert from edge_list to adj."""

    def forward(self, x, edge_index, batch_idx, B, N):
        # TODO: Should handle weights
        # to_dense_batch returns (x, mask); keep only x
        x = torch_geometric.utils.to_dense_batch(x=x, batch=batch_idx, max_num_nodes=N)[
            0
        ]
        # to_dense_adj returns a single [B, N, N] tensor rather than a tuple;
        # indexing it with [0] would keep only the first graph of the batch
        adj = torch_geometric.utils.to_dense_adj(
            edge_index, batch=batch_idx, max_num_nodes=N
        )
        return x, adj


class DenseToSparse(torch.nn.Module):
    """Convert a batch of dense graphs into one disjoint sparse graph.

    x: shape[B, N+k, feat]
    adj: shape[B, N+k, N+k]
    mask: shape[B, N+k] (not supported yet; must be None)

    Returns x flattened to [B * (N+k), feat], edge_index [2, E] whose node
    indices are offset by b * (N+k), and batch_idx [B * (N+k)]. edge_index is
    built from the nonzero entries of `adj > 0`, so it is integer-valued and
    carries no gradient; x keeps its gradient."""

    def forward(self, x, adj, mask=None):
        assert x.dim() == adj.dim() == 3
        # Compare against None: truthiness of a multi-element tensor raises
        if mask is not None:
            raise NotImplementedError()
        B = x.shape[0]
        N = x.shape[1]
        offset, row, col = torch.nonzero(adj > 0).t()
        row = row + offset * N
        col = col + offset * N
        edge_index = torch.stack([row, col], dim=0).long()
        # reshape rather than view: x may be non-contiguous
        x = x.reshape(B * N, x.shape[-1])
        batch_idx = (
            torch.arange(0, B, device=x.device).view(-1, 1).repeat(1, N).view(-1)
        )

        return x, edge_index, batch_idx


class RelativePositionalEncoding(torch.nn.Module):
    """Add sin/cos positional encodings relative to the newest node.

    For batch entry b, node k (k <= num_nodes[b]) receives row
    (k - num_nodes[b]) mod max_len of a sinusoid table, so the current node
    always gets row 0. The table is built lazily on the first forward call.
    """

    def __init__(self, max_len: int = 5000):
        super().__init__()
        self.max_len = max_len

    def run_once(self, nodes: torch.Tensor) -> None:
        # Dim must be even
        d_model = math.ceil(nodes.shape[-1] / 2) * 2
        position = torch.arange(self.max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(self.max_len, d_model, device=nodes.device)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, nodes: torch.Tensor, num_nodes: torch.Tensor) -> torch.Tensor:
        """
        Args:
            nodes: Tensor, shape [batch_size, seq_len, embedding_dim];
                modified in place and returned
            num_nodes: Tensor, per-batch index of the current node
        """
        if not hasattr(self, "pe"):
            self.run_once(nodes)

        B = nodes.shape[0]
        for b in range(B):
            center = num_nodes[b]
            # Shift the table so that row 0 lands on the current node
            pe = self.pe.roll(int(center), 0)
            nodes[b, : center + 1] = (
                nodes[b, : center + 1] + pe[: center + 1, : nodes.shape[-1]]
            )
        return nodes
class PositionalEncoding(torch.nn.Module): 93 | """Embed positional encoding into the graph. Ensures we do not 94 | encode future nodes (node_idx > num_nodes)""" 95 | 96 | def __init__(self, max_len: int = 5000, mode="add", cat_dim: int = 8): 97 | super().__init__() 98 | self.max_len = max_len 99 | self.mode = mode 100 | self.cat_dim = cat_dim 101 | assert mode in ["add", "cat"] 102 | 103 | def run_once(self, x: torch.Tensor) -> None: 104 | # Dim must be even 105 | d_model = math.ceil(x.shape[-1] / 2) * 2 106 | position = torch.arange(self.max_len).unsqueeze(1) 107 | div_term = torch.exp( 108 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) 109 | ) 110 | pe = torch.zeros(self.max_len, d_model, device=x.device) 111 | pe[:, 0::2] = torch.sin(position * div_term) 112 | pe[:, 1::2] = torch.cos(position * div_term) 113 | self.register_buffer("pe", pe) 114 | 115 | if self.mode == "cat": 116 | self.reproject = torch.nn.Linear( 117 | x.shape[-1], x.shape[-1] - self.cat_dim, device=x.device 118 | ) 119 | 120 | def forward(self, x: torch.Tensor, num_nodes: torch.Tensor) -> torch.Tensor: 121 | """ 122 | Args: 123 | x: Tensor, shape [batch_size, seq_len, embedding_dim] 124 | num_nodes: Tensor 125 | """ 126 | if not hasattr(self, "pe"): 127 | self.run_once(x) 128 | 129 | b_idxs, n_idxs = gcm.util.idxs_up_to_including_num_nodes(x, num_nodes) 130 | if self.mode == "add": 131 | x[b_idxs, n_idxs] = x[b_idxs, n_idxs] + self.pe[n_idxs, : x.shape[-1]] 132 | elif self.mode == "cat": 133 | x_reproj = self.reproject(x[b_idxs, n_idxs]).reshape( 134 | len(b_idxs), x.shape[-1] - self.cat_dim 135 | ) 136 | # positional encoding 137 | x = x.clone() 138 | x[b_idxs, n_idxs, : self.cat_dim] = self.pe[n_idxs, : self.cat_dim] 139 | # Reprojected feature assignment 140 | x[b_idxs, n_idxs, self.cat_dim :] = x_reproj 141 | else: 142 | raise NotImplementedError("Invalid mode") 143 | return x 144 | 145 | 146 | @torch.jit.script 147 | def overflow(num_nodes: torch.Tensor, N: int): 148 | 
return torch.any(num_nodes + 1 > N) 149 | 150 | 151 | class DenseGCM(torch.nn.Module): 152 | """Graph Associative Memory""" 153 | 154 | did_warn = False 155 | 156 | def __init__( 157 | self, 158 | # Graph neural network, see torch_geometric.nn.Sequential 159 | # for some examples 160 | gnn: torch.nn.Module, 161 | # Preprocessor for each feat vec before it's placed in graph 162 | preprocessor: torch.nn.Module = None, 163 | # an edge selector from gcm.edge_selectors 164 | # you can chain multiple selectors together using 165 | # torch_geometric.nn.Sequential 166 | edge_selectors: torch.nn.Module = None, 167 | # Auxiliary edge selectors are called 168 | # after the positional encoding and reprojection 169 | # this should only be used for non-human (learned) priors 170 | aux_edge_selectors: torch.nn.Module = None, 171 | # Maximum number of nodes in the graph 172 | graph_size: int = 128, 173 | # Whether the gnn outputs graph_size nodes or uses global pooling 174 | pooled: bool = False, 175 | # Whether to add sin/cos positional encoding like in transformer 176 | # to the nodes 177 | # Creates an ordering in the graph 178 | positional_encoder: torch.nn.Module = None, 179 | # Whether to use edge_weights 180 | # only required if using learned edges 181 | edge_weights: bool = False, 182 | ): 183 | super().__init__() 184 | 185 | self.preprocessor = preprocessor 186 | self.gnn = gnn 187 | self.graph_size = graph_size 188 | self.edge_selectors = edge_selectors 189 | self.aux_edge_selectors = aux_edge_selectors 190 | self.pooled = pooled 191 | self.edge_weights = edge_weights 192 | self.positional_encoder = positional_encoder 193 | 194 | def get_initial_hidden_state(self, x): 195 | """Given a dummy x of shape [B, feats], construct 196 | the hidden state for the base case (adj matrix, weights, etc)""" 197 | """Returns the initial hidden state h (e.g. h, output = gcm(input, h)), 198 | for a given batch size (B). 
Feats denotes the feature size (# dims of each 199 | node in the graph).""" 200 | 201 | assert x.dim() == 2 202 | B, feats = x.shape 203 | edges = torch.zeros(B, self.graph_size, self.graph_size, device=x.device) 204 | nodes = torch.zeros(B, self.graph_size, feats, device=x.device) 205 | if self.edge_weights: 206 | weights = torch.zeros(B, self.graph_size, self.graph_size, device=x.device) 207 | else: 208 | weights = torch.zeros(0, device=x.device) 209 | num_nodes = torch.zeros(B, dtype=torch.long, device=x.device) 210 | 211 | return nodes, edges, weights, num_nodes 212 | 213 | def forward( 214 | self, 215 | x, 216 | hidden: Union[ 217 | None, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] 218 | ], 219 | ) -> Tuple[torch.Tensor, torch.Tensor]: 220 | """Add a memory x to the graph, and query the memory for it. 221 | B = batch size 222 | N = maximum graph size 223 | Inputs: 224 | x: [B,feat] 225 | hidden: ( 226 | nodes: [B,N,feats] 227 | adj: [B,N,N] 228 | weights: [B,N,N] or None 229 | number_of_nodes_in_graph: [B] 230 | ) 231 | Outputs: 232 | m(x): [B,feat] 233 | hidden: ( 234 | nodes: [B,N,feats] 235 | adj: [B,N,N] 236 | weights: [B,N,N] or None 237 | number_of_nodes_in_graph: [B] 238 | ) 239 | """ 240 | # Base case 241 | if hidden is None: 242 | hidden = self.get_initial_hidden_state(x) 243 | 244 | nodes, adj, weights, num_nodes = hidden 245 | 246 | assert x.dtype == torch.float32 247 | assert nodes.dtype == torch.float 248 | # if self.gnn.sparse: 249 | # assert adj.dtype == torch.long 250 | assert weights.dtype == torch.float 251 | assert num_nodes.dtype == torch.long 252 | assert num_nodes.dim() == 1 253 | 254 | N = nodes.shape[1] 255 | B = x.shape[0] 256 | B_idx = torch.arange(B) 257 | 258 | assert ( 259 | N == adj.shape[1] == adj.shape[2] 260 | ), "N must be equal for adj mat and node mat" 261 | 262 | nodes = nodes.clone() 263 | if overflow(num_nodes, N): 264 | if not self.did_warn: 265 | print("Overflow detected, wrapping around. 
Will not warn again") 266 | self.did_warn = True 267 | adj = adj.clone() 268 | weights = weights.clone() 269 | nodes, adj, weights, num_nodes = self.wrap_overflow( 270 | nodes, adj, weights, num_nodes 271 | ) 272 | # Add new nodes to the current graph 273 | # starting at num_nodes 274 | nodes[B_idx, num_nodes[B_idx]] = x[B_idx] 275 | # We do not want to modify graph nodes in the GCM 276 | # Do all mutation operations on dirty_nodes, 277 | # then use clean nodes in the graph state 278 | dirty_nodes = nodes.clone() 279 | # Do NOT add self edges or they will be counted twice using 280 | # GraphConv 281 | 282 | # Adj and weights must be cloned as 283 | # edge selectors will modify them in-place 284 | if self.edge_selectors: 285 | adj, weights = self.edge_selectors( 286 | dirty_nodes, adj.clone(), weights.clone(), num_nodes, B 287 | ) 288 | 289 | # Thru network 290 | if self.preprocessor: 291 | dirty_nodes = self.preprocessor(dirty_nodes) 292 | # if self.positional_encoder: 293 | # dirty_nodes = self.positional_encoder(dirty_nodes, num_nodes) 294 | if self.aux_edge_selectors: 295 | if self.positional_encoder: 296 | adj, weights = self.aux_edge_selectors( 297 | self.positional_encoder(dirty_nodes, num_nodes), 298 | adj.clone(), 299 | weights.clone(), 300 | num_nodes, 301 | B, 302 | ) 303 | else: 304 | adj, weights = self.aux_edge_selectors( 305 | dirty_nodes, adj.clone(), weights.clone(), num_nodes, B 306 | ) 307 | 308 | node_feats = self.gnn(dirty_nodes, adj, weights, B, N) 309 | if self.pooled: 310 | # If pooled, we expect only a single output node 311 | mx = node_feats 312 | else: 313 | # Otherwise extract the hidden repr at the current node 314 | mx = node_feats[B_idx, num_nodes[B_idx]] 315 | 316 | assert torch.all( 317 | torch.isfinite(mx) 318 | ), "Got NaN in returned memory, try using tanh activation" 319 | 320 | num_nodes = num_nodes + 1 321 | return mx, (nodes, adj, weights, num_nodes) 322 | 323 | def wrap_overflow(self, nodes, adj, weights, num_nodes): 324 | 
"""Call this when the node/adj matrices are full. Deletes the zeroth element 325 | of the matrices and shifts all the elements up by one, producing a free row 326 | at the end. You will likely want to call .clone() on the arguments that require 327 | gradient computation. 328 | 329 | Returns new nodes, adj, weights, and num_nodes matrices""" 330 | N = nodes.shape[1] 331 | overflow_mask = num_nodes + 1 > N 332 | # Shift node matrix into the past 333 | # by one and forget the zeroth node 334 | overflowing_batches = overflow_mask.nonzero().squeeze() 335 | # nodes = nodes.clone() 336 | # adj = adj.clone() 337 | # Zero entries before shifting 338 | nodes[overflowing_batches, 0] = 0 339 | adj[overflowing_batches, 0, :] = 0 340 | adj[overflowing_batches, :, 0] = 0 341 | # Roll newly zeroed zeroth entry to final entry 342 | nodes[overflowing_batches] = torch.roll(nodes[overflowing_batches], -1, -2) 343 | adj[overflowing_batches] = torch.roll( 344 | adj[overflowing_batches], (-1, -1), (-1, -2) 345 | ) 346 | if weights.numel() != 0: 347 | # weights = weights.clone() 348 | weights[overflowing_batches, 0, :] = 0 349 | weights[overflowing_batches, :, 0] = 0 350 | weights[overflowing_batches] = torch.roll( 351 | weights[overflowing_batches], (-1, -1), (-1, -2) 352 | ) 353 | 354 | num_nodes[overflow_mask] = num_nodes[overflow_mask] - 1 355 | return nodes, adj, weights, num_nodes 356 | -------------------------------------------------------------------------------- /src/gcm/nav_gcm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric 3 | 4 | 5 | # TODO: We need to make this a for loop over temporal dim 6 | # causal edges do not allow loop closures 7 | # memory should be okay because we discard edges 8 | # grads will stack and take up lots of memory, but we can set small seq_lens 9 | # to reduce memory pressure 10 | 11 | 12 | class NavGCM(torch.nn.Module): 13 | """GCM tailored specifically to the navigation 
domain. This allows us 14 | to use priors to speed up and simplify GCM""" 15 | def __init__( 16 | self, 17 | gnn, 18 | pool=False, 19 | max_verts=128, 20 | edge_method="radius", 21 | k=16, 22 | r=1.0, 23 | causal=True, 24 | disjoint_edges=False, 25 | ): 26 | super().__init__() 27 | self.k = k 28 | self.r = r 29 | self.gnn = gnn 30 | self.max_verts = max_verts 31 | self.pool = pool 32 | assert edge_method in ["knn", "radius"] 33 | assert edge_method != "knn", "KNN does not train/infer correctly" 34 | self.edge_method = edge_method 35 | self.causal = causal 36 | self.disjoint_edges = disjoint_edges 37 | 38 | def make_idx(self, T, taus): 39 | """Returns batch and time idxs marking 40 | all valid (non-padded) elements in the vert matrix""" 41 | batch = torch.repeat_interleave( 42 | torch.arange(T.shape[0], device=T.device), 43 | T + taus, 44 | ) 45 | time = torch.cat([ 46 | torch.arange(T[b] + taus[b], device=T.device) 47 | for b in range(T.shape[0]) 48 | ], dim=-1) 49 | return batch, time 50 | 51 | def make_new_idx( 52 | self, T, taus 53 | ): 54 | """Returns batch and time idxs marking 55 | all NEW elements in the vert matrix""" 56 | batch = torch.repeat_interleave( 57 | torch.arange(T.shape[0], device=T.device), 58 | taus, 59 | ) 60 | time = torch.cat([ 61 | torch.arange(T[b], T[b] + taus[b], device=T.device) 62 | for b in range(T.shape[0]) 63 | ], dim=-1) 64 | return batch, time 65 | 66 | def make_output_idx( 67 | self, taus 68 | ): 69 | """Like make_new_idx, returns the batch and 70 | time idxs marking all NEW elements. 
Unlike make_new_idx 71 | the indices correspond to locations in the padded OUTPUT 72 | matrix rather than the full vertex matrix""" 73 | batch = torch.repeat_interleave( 74 | torch.arange(taus.shape[0], device=taus.device), 75 | taus, 76 | ) 77 | time = torch.cat([ 78 | torch.arange(taus[b], device=taus.device) 79 | for b in range(taus.shape[0]) 80 | ], dim=-1) 81 | return batch, time 82 | 83 | def make_flat_new_idx( 84 | self, T, taus 85 | ): 86 | """Return index of all new elements, like make_new_idx. 87 | However, rather than [B, T] indexing, this returns [B * T] 88 | indicies to extract new nodes from the GNN output""" 89 | cs = (T + taus).cumsum(0) 90 | return torch.cat([ 91 | torch.arange(cs[b] - taus[b], cs[b], device=T.device) 92 | for b in range(T.shape[0]) 93 | ], dim=-1) 94 | 95 | def knn_edges( 96 | self, x, pos, rot 97 | ): 98 | # TODO: Getting future edges here... 99 | edges = torch_geometric.nn.knn_graph( 100 | x=pos, 101 | k=self.k, 102 | batch=self.idx[0] 103 | ) 104 | return edges 105 | 106 | def radius_edges( 107 | self, x, pos, rot 108 | ): 109 | # TODO: Getting future edges here... 
110 | edges = torch_geometric.nn.radius_graph( 111 | x=pos, 112 | r=self.r, 113 | batch=self.idx[0], 114 | loop=True, 115 | max_num_neighbors=self.k, 116 | ) 117 | return edges 118 | 119 | def remove_noncausal_edges(self, edges, T, taus): 120 | # Remove edges where sink > source where 121 | # sink and source are both in tau 122 | # TODO: This is not yet correct 123 | keep = edges[0] < edges[1] 124 | return edges[:, keep] 125 | 126 | def update( 127 | self, 128 | x, 129 | pos, 130 | rot, 131 | old_x, 132 | old_pos, 133 | old_rot, 134 | T, 135 | taus 136 | ): 137 | """Add new observations to the respective state mats""" 138 | old_x[self.new_idx] = x[self.out_idx] 139 | old_pos[self.new_idx] = pos[self.out_idx] 140 | old_rot[self.new_idx] = rot[self.out_idx] 141 | return old_x, old_pos, old_rot 142 | 143 | def compute_idx(self, T, taus): 144 | """Precompute useful indices for forward pass""" 145 | # Compute idxs once for efficiency 146 | # Nearly all idxs are pairs of [batch, time] indices 147 | # Idx pointing to non-padded elements 148 | # Shape: [sum(B[i]*T[i])] * 2 149 | self.idx = self.make_idx(T, taus) 150 | # Idx pointing to new elements in the x, pos, vert mats 151 | # Shape: [sum(taus[i])] * 2 152 | self.new_idx = self.make_new_idx(T, taus) 153 | # Idx pointing to new elements in the flattened gnn output 154 | # Shape: [sum(taus[i])], note there is only one idx here 155 | self.flat_new_idx = self.make_flat_new_idx(T, taus) 156 | # Idx pointing to padded output elements (this is where the outputs go) 157 | # Shape: [sum(taus[i])] 158 | self.out_idx = self.make_output_idx(taus) 159 | # Pointer to the last node in each graph 160 | # Shape: [B] 161 | self.back_ptr = (T + taus).cumsum(0) - 1 162 | # Pointer to the zeroth node in each graph 163 | # Shape: [B] 164 | self.front_ptr = torch.cat( 165 | [torch.tensor([0], device=T.device), self.back_ptr[:-1] + 1] 166 | ) 167 | #self.front_ptr = self.back_ptr.roll(1) 168 | #self.front_ptr[0] = 0 169 | 170 | def 
causal_forward(self, x, pos, rot, T, taus, out_batch, out_time): 171 | """Causal forward restricts edges to being fully-causal. This means 172 | incoming edges of node t cannot be updated except at time t. This prevents 173 | loop closures, but uses significantly less memory and runs faster. Note 174 | that this produces differen graph topologies than full_forward.""" 175 | # Remove padding and flatten 176 | # Unpadded shapes are [B*(T+taus), *] 177 | x = x[self.idx] 178 | pos = pos[self.idx] 179 | rot = rot[self.idx] 180 | 181 | if self.edge_method == "knn": 182 | edges = self.knn_edges(x, pos, rot) 183 | else: 184 | edges = self.radius_edges(x, pos, rot) 185 | # TODO: This changes eval -- we can rewire old things to be noncausal 186 | #if self.training: 187 | edges = self.remove_noncausal_edges(edges, T, taus) 188 | # [B, T] ordering is [0, 0], [0, 1], ... [0, t], [1, 0] 189 | # TODO: Pooling can be done in output using new_idx 190 | # as well as max_hops graph reduction and sampling 191 | output = self.gnn( 192 | x, edges, pos, rot, self.idx[0], self.front_ptr, self.back_ptr, self.flat_new_idx 193 | ) 194 | 195 | # return output 196 | # Offset from 0 instead of T 197 | return output[self.flat_new_idx] 198 | 199 | def full_forward(self, x, pos, rot, T, taus, out_batch, out_time): 200 | """Unlike causal_forward, full_forward allows graph rewiring for 201 | old vertices (loop closures), but needs to construct a separate graph 202 | for each timestep, greatly increasing memory usage and reducing 203 | throughput""" 204 | # TODO dont hardcode hidden 205 | # TODO: flat_new_idx indexing is probably wrong, 206 | # as the total number of nodes is different here than in the causal case 207 | graphs = [] 208 | for b in range(out_batch): 209 | for t in range(out_time): 210 | if t < taus[b]: 211 | curr_slice = slice(0, T[b] + t + 1) 212 | graphs.append( 213 | torch_geometric.data.Data( 214 | x=x[b, curr_slice], 215 | pos=pos[b, curr_slice], 216 | rot=rot[b, curr_slice] 217 | 
) 218 | ) 219 | 220 | batch = torch_geometric.data.Batch.from_data_list(graphs) 221 | if self.edge_method == "knn": 222 | batch.edge_index = torch_geometric.nn.knn_graph( 223 | batch.pos, k=self.k, batch=batch.batch 224 | ) 225 | else: 226 | batch.edge_index = torch_geometric.nn.radius_graph( 227 | x=batch.pos, 228 | r=self.r, 229 | batch=batch.batch, 230 | loop=True, 231 | max_num_neighbors=self.k, 232 | ) 233 | 234 | output = self.gnn(batch.x, batch.edge_index, batch.pos, batch.rot, batch.batch) 235 | return output 236 | 237 | def forward( 238 | self, 239 | x, # [B, taus.max(), F] 240 | pos, # [B, taus.max(), 2] 241 | rot, # [B, taus.max(), 1] 242 | taus, # [B] 243 | state, # [old_x, old_pos, old_rots, each of size [B, T.max(), *]] 244 | ): 245 | old_x, old_pos, old_rot, T = state 246 | out_batch, out_time = x.shape[0], taus.max() 247 | # Update hidden state 248 | self.compute_idx(T, taus) 249 | x, pos, rot = self.update(x, pos, rot, old_x, old_pos, old_rot, T, taus) 250 | state = [x, pos, rot, T + taus] 251 | if self.causal: 252 | output_at_target = self.causal_forward(x, pos, rot, T, taus, out_batch, out_time) 253 | else: 254 | output_at_target = self.full_forward(x, pos, rot, T, taus, out_batch, out_time) 255 | 256 | # Compute padded output at the inputted vert idxs 257 | padded_output = torch.zeros( 258 | (out_batch, out_time, output.shape[-1]), device=x.device 259 | ) 260 | # Offset from 0 instead of T 261 | padded_output[self.out_idx] = output[self.flat_new_idx] 262 | 263 | return padded_output, state 264 | -------------------------------------------------------------------------------- /src/gcm/ray_gcm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import gym 4 | from torch import nn 5 | from typing import Union, Dict, List, Tuple, Any 6 | import ray 7 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 8 | from ray.rllib.models.torch.fcnet import 
FullyConnectedNetwork as TorchFC 9 | from ray.rllib.models.torch.misc import SlimFC, normc_initializer 10 | from ray.rllib.models.torch.recurrent_net import RecurrentNetwork 11 | from ray.rllib.utils.typing import ModelConfigDict, TensorType 12 | from ray.rllib.policy.sample_batch import SampleBatch 13 | from ray.rllib.policy.view_requirement import ViewRequirement 14 | from ray.rllib.utils.torch_utils import one_hot 15 | 16 | import torch_geometric 17 | from torch_geometric.data import Data, Batch 18 | from gcm.gcm import DenseGCM, PositionalEncoding, RelativePositionalEncoding 19 | 20 | 21 | class RayDenseGCM(TorchModelV2, nn.Module): 22 | DEFAULT_CONFIG = { 23 | # Maximum number of nodes in a graph 24 | "graph_size": 32, 25 | # Input size to the GNN. Make sure your first gnn layer 26 | # has this many input channels 27 | "gnn_input_size": 64, 28 | # Number of output channels of the GNN. This feeds into the logits 29 | # and value function layers 30 | "gnn_output_size": 64, 31 | # GNN model that takes x, edge_index, weights 32 | # Note that input will be reshaped by a linear layer 33 | # to gnn_input_size 34 | "gnn": torch_geometric.nn.Sequential( 35 | "x, edge_index, weights, B, N", 36 | [ 37 | (torch_geometric.nn.GraphConv(64, 64), "x, edge_index -> x"), 38 | torch.nn.Tanh(), 39 | (torch_geometric.nn.GraphConv(64, 64), "x, edge_index -> x"), 40 | torch.nn.Tanh(), 41 | ], 42 | ), 43 | # Torch.nn.module used for determining edges between nodes. 44 | # You can chain multiple modules together use 45 | # torch_geometric.nn.Sequential 46 | "edge_selectors": None, 47 | # Same as edge selectors, but called after reprojection 48 | # and positional encoding. Only use non-human (learned) edges here 49 | # as they are no longer in a human readable form 50 | "aux_edge_selectors": None, 51 | # Whether the final output is pooled (a single node). 
If pooled, 52 | # we will not try to extract the hidden representation of the current node 53 | # and instead just use the raw pooled representation 54 | "pooled": False, 55 | # Whether edge weights are used. This should be false unless using 56 | # bernoulli edges 57 | "edge_weights": False, 58 | # Optional network that processes observations before 59 | # the GNN. May allow for learning representations that 60 | # aggregate better. Note the input to the preprocessor will 61 | # already be of shape "gnn_input_size" 62 | # Note that the node preprocessor will run after observations are 63 | # inserted in the graph. This means the observations can be 64 | # reconstructed at the expense of greater memory usage compared 65 | # to the preprocessor 66 | "preprocessor": None, 67 | # Whether to train the preprocessor or freeze the weights 68 | "preprocessor_frozen": False, 69 | "pre_preprocessor": None, 70 | # Whether the prev action should be placed in the observation nodes 71 | "use_prev_action": False, 72 | # Whether to use positional encoding (ala transformer) in the GNN 73 | # False for none, 'cat' for concatenate encoding to feature vector, 74 | # and 'add' for sum encoding with feature vector 75 | "positional_encoding": None, 76 | # if 'cat', how many dimensions should be reserved for positional 77 | # encoding 78 | "positional_encoding_dim": 4, 79 | } 80 | 81 | def __init__( 82 | self, 83 | obs_space: gym.spaces.Space, 84 | action_space: gym.spaces.Space, 85 | num_outputs: int, 86 | model_config: ModelConfigDict, 87 | name: str, 88 | **custom_model_kwargs, 89 | ): 90 | TorchModelV2.__init__( 91 | self, obs_space, action_space, num_outputs, model_config, name 92 | ) 93 | nn.Module.__init__(self) 94 | self.num_outputs = num_outputs 95 | self.obs_dim = gym.spaces.utils.flatdim(obs_space) 96 | self.act_space = action_space 97 | self.act_dim = gym.spaces.utils.flatdim(action_space) 98 | # edge selectors must be attrib of torch.nn.module 99 | # so edge_selectors get 
self.training, etc. 100 | 101 | for k in custom_model_kwargs: 102 | assert k in self.DEFAULT_CONFIG, f"Invalid config key {k}" 103 | self.cfg = dict(self.DEFAULT_CONFIG, **custom_model_kwargs) 104 | self.input_dim = self.obs_dim 105 | if self.cfg["use_prev_action"]: 106 | self.input_dim += self.act_dim 107 | self.view_requirements["prev_actions"] = ViewRequirement( 108 | "actions", space=self.action_space, shift=-1 109 | ) 110 | 111 | self.build_network(self.cfg) 112 | print("Full GCM network is:", self) 113 | 114 | self.cur_val = None 115 | 116 | def build_network(self, cfg): 117 | """Builds the GNN and MLPs based on config""" 118 | pp = torch.nn.Linear(self.input_dim, cfg["gnn_input_size"]) 119 | pe = None 120 | if cfg["positional_encoding"]: 121 | pe = PositionalEncoding( 122 | max_len=self.cfg["graph_size"], 123 | mode=cfg["positional_encoding"], 124 | cat_dim=cfg["positional_encoding_dim"], 125 | ) 126 | 127 | if cfg["preprocessor"]: 128 | if cfg["preprocessor_frozen"]: 129 | for param in cfg["preprocessor"].parameters(): 130 | param.requires_grad = False 131 | pp = torch.nn.Sequential(pp, cfg["preprocessor"]) 132 | 133 | self.gcm = DenseGCM( 134 | gnn=cfg["gnn"], 135 | preprocessor=pp, 136 | edge_selectors=self.cfg["edge_selectors"], 137 | aux_edge_selectors=self.cfg["aux_edge_selectors"], 138 | pooled=self.cfg["pooled"], 139 | positional_encoder=pe, 140 | ) 141 | 142 | self.logit_branch = SlimFC( 143 | in_size=cfg["gnn_output_size"], 144 | out_size=self.num_outputs, 145 | activation_fn=None, 146 | initializer=normc_initializer(0.01), 147 | ) 148 | 149 | self.value_branch = SlimFC( 150 | in_size=cfg["gnn_output_size"], 151 | out_size=1, 152 | activation_fn=None, 153 | initializer=normc_initializer(0.01), 154 | ) 155 | 156 | def get_initial_state(self): 157 | edges = torch.zeros((self.cfg["graph_size"], self.cfg["graph_size"])) 158 | nodes = torch.zeros((self.cfg["graph_size"], self.input_dim)) 159 | if self.cfg["edge_weights"]: 160 | weights = 
torch.zeros((self.cfg["graph_size"], self.cfg["graph_size"])) 161 | else: 162 | weights = torch.zeros(0) 163 | 164 | num_nodes = torch.tensor(0, dtype=torch.long) 165 | state = [nodes, edges, weights, num_nodes] 166 | 167 | return state 168 | 169 | def value_function(self): 170 | assert self.cur_val is not None, "must call forward() first" 171 | return self.cur_val 172 | 173 | def forward( 174 | self, 175 | input_dict: Dict[str, TensorType], 176 | state: List[TensorType], 177 | seq_lens: TensorType, 178 | ) -> Tuple[TensorType, List[TensorType]]: 179 | 180 | if self.cfg["use_prev_action"]: 181 | prev_acts = one_hot(input_dict["prev_actions"].float(), self.act_space) 182 | prev_acts = prev_acts.reshape(-1, self.act_dim) 183 | flat = torch.cat((input_dict["obs_flat"], prev_acts), dim=-1) 184 | else: 185 | flat = input_dict["obs_flat"] 186 | # Batch and Time 187 | # Forward expects outputs as [B, T, logits] 188 | B = len(seq_lens) 189 | T = flat.shape[0] // B 190 | 191 | outs = torch.zeros(B, T, self.cfg["gnn_output_size"], device=flat.device) 192 | # Deconstruct batch into batch and time dimensions: [B, T, feats] 193 | flat = torch.reshape(flat, [-1, T] + list(flat.shape[1:])) 194 | nodes, adj_mats, weights, num_nodes = state 195 | 196 | num_nodes = num_nodes.long() 197 | 198 | # Push thru pre-gcm layers 199 | hidden = (nodes, adj_mats, weights, num_nodes) 200 | for t in range(T): 201 | out, hidden = self.gcm(flat[:, t, :], hidden) 202 | outs[:, t, :] = out 203 | 204 | # Collapse batch and time for more efficient forward 205 | out_view = outs.view(B * T, self.cfg["gnn_output_size"]) 206 | logits = self.logit_branch(out_view) 207 | values = self.value_branch(out_view) 208 | 209 | self.cur_val = values.squeeze(1) 210 | 211 | state = list(hidden) 212 | return logits, state 213 | -------------------------------------------------------------------------------- /src/gcm/ray_sparse_gcm.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | import numpy as np 3 | import gym 4 | from torch import nn 5 | from typing import Dict, Tuple, List 6 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 7 | from ray.rllib.models.torch.misc import SlimFC, normc_initializer 8 | from ray.rllib.utils.typing import ModelConfigDict, TensorType 9 | from ray.rllib.policy.view_requirement import ViewRequirement 10 | from ray.rllib.utils.torch_utils import one_hot 11 | from ray.rllib.policy.rnn_sequencing import add_time_dimension 12 | 13 | import torch_geometric 14 | from gcm.gcm import PositionalEncoding 15 | from gcm.sparse_gcm import SparseGCM 16 | from gcm import util 17 | 18 | 19 | class RaySparseGCM(TorchModelV2, nn.Module): 20 | DEFAULT_CONFIG = { 21 | # Maximum number of nodes in a graph 22 | "graph_size": 32, 23 | # Maximum number of edges in each batch 24 | "max_edges": 256, 25 | # Input size to the GNN. Make sure your first gnn layer 26 | # has this many input channels 27 | "gnn_input_size": 64, 28 | # Number of output channels of the GNN. This feeds into the logits 29 | # and value function layers 30 | "gnn_output_size": 64, 31 | # GNN model that takes x, edge_index, weights 32 | # Note that input will be reshaped by a linear layer 33 | # to gnn_input_size 34 | "gnn": torch_geometric.nn.Sequential( 35 | "x, edge_index, weights", 36 | [ 37 | (torch_geometric.nn.GraphConv(64, 64), "x, edge_index, weights -> x"), 38 | torch.nn.Tanh(), 39 | (torch_geometric.nn.GraphConv(64, 64), "x, edge_index, weights -> x"), 40 | torch.nn.Tanh(), 41 | ], 42 | ), 43 | # If set, max_hops will extract the valid k-hop subgraph before 44 | # the forward pass. This will improve runtime if set. It should 45 | # be set to the number of graph layers in the GNN 46 | "max_hops": None, 47 | # Torch.nn.module used for determining edges between nodes. 
48 | # You can chain multiple modules together use 49 | # torch_geometric.nn.Sequential 50 | "edge_selectors": None, 51 | # Same as edge selectors, but called after reprojection 52 | # and positional encoding. Only use non-human (learned) edges here 53 | # as they are no longer in a human readable form 54 | "aux_edge_selectors": None, 55 | # Optional network that processes observations before 56 | # the GNN. May allow for learning representations that 57 | # aggregate better. Note the input to the preprocessor will 58 | # already be of shape "gnn_input_size" 59 | # Note that the node preprocessor will run after observations are 60 | # inserted in the graph. This means the observations can be 61 | # reconstructed at the expense of greater memory usage compared 62 | # to the preprocessor 63 | "preprocessor": None, 64 | # Whether to train the preprocessor or freeze the weights 65 | "preprocessor_frozen": False, 66 | "pre_preprocessor": None, 67 | # Whether the prev action should be placed in the observation nodes 68 | "use_prev_action": False, 69 | # Whether to use positional encoding (ala transformer) in the GNN 70 | # False for none, 'cat' for concatenate encoding to feature vector, 71 | # and 'add' for sum encoding with feature vector 72 | "positional_encoding": None, 73 | # if 'cat', how many dimensions should be reserved for positional 74 | # encoding 75 | "positional_encoding_dim": 4, 76 | } 77 | 78 | def __init__( 79 | self, 80 | obs_space: gym.spaces.Space, 81 | action_space: gym.spaces.Space, 82 | num_outputs: int, 83 | model_config: ModelConfigDict, 84 | name: str, 85 | **custom_model_kwargs, 86 | ): 87 | TorchModelV2.__init__( 88 | self, obs_space, action_space, num_outputs, model_config, name 89 | ) 90 | nn.Module.__init__(self) 91 | self.num_outputs = num_outputs 92 | self.obs_dim = gym.spaces.utils.flatdim(obs_space) 93 | self.act_space = action_space 94 | self.act_dim = gym.spaces.utils.flatdim(action_space) 95 | # edge selectors must be attrib of 
torch.nn.module 96 | # so edge_selectors get self.training, etc. 97 | 98 | for k in custom_model_kwargs: 99 | assert k in self.DEFAULT_CONFIG, f"Invalid config key {k}" 100 | self.cfg = dict(self.DEFAULT_CONFIG, **custom_model_kwargs) 101 | self.input_dim = self.obs_dim 102 | if self.cfg["use_prev_action"]: 103 | self.input_dim += self.act_dim 104 | self.view_requirements["prev_actions"] = ViewRequirement( 105 | "actions", space=self.action_space, shift=-1 106 | ) 107 | 108 | self.build_network(self.cfg) 109 | print("Full GCM network is:", self) 110 | 111 | self.cur_val = None 112 | 113 | def build_network(self, cfg): 114 | """Builds the GNN and MLPs based on config""" 115 | pp = torch.nn.Linear(self.input_dim, cfg["gnn_input_size"]) 116 | pe = None 117 | if cfg["positional_encoding"]: 118 | pe = PositionalEncoding( 119 | max_len=self.cfg["graph_size"], 120 | mode=cfg["positional_encoding"], 121 | cat_dim=cfg["positional_encoding_dim"], 122 | ) 123 | 124 | if cfg["preprocessor"]: 125 | if cfg["preprocessor_frozen"]: 126 | for param in cfg["preprocessor"].parameters(): 127 | param.requires_grad = False 128 | pp = torch.nn.Sequential(pp, cfg["preprocessor"]) 129 | 130 | self.gcm = SparseGCM( 131 | gnn=cfg["gnn"], 132 | graph_size=cfg["graph_size"], 133 | preprocessor=pp, 134 | edge_selectors=self.cfg["edge_selectors"], 135 | aux_edge_selectors=self.cfg["aux_edge_selectors"], 136 | positional_encoder=pe, 137 | max_hops=cfg["max_hops"], 138 | ) 139 | 140 | self.logit_branch = SlimFC( 141 | in_size=cfg["gnn_output_size"], 142 | out_size=self.num_outputs, 143 | activation_fn=None, 144 | initializer=normc_initializer(0.01), 145 | ) 146 | 147 | self.value_branch = SlimFC( 148 | in_size=cfg["gnn_output_size"], 149 | out_size=1, 150 | activation_fn=None, 151 | initializer=normc_initializer(0.01), 152 | ) 153 | 154 | def get_initial_state(self): 155 | nodes = torch.zeros((self.cfg["graph_size"], self.input_dim)) 156 | # If these are type==long, they become np arrays instead 
of torch... 157 | edges = torch.zeros((2, self.cfg["max_edges"])) 158 | # If we set weights to one ray makes it into np array of objs... 159 | weights = torch.zeros((1, self.cfg["max_edges"])) 160 | 161 | T = torch.tensor(0, dtype=torch.long) 162 | state = [nodes, edges, weights, T] 163 | 164 | return state 165 | 166 | def value_function(self): 167 | assert self.cur_val is not None, "must call forward() first" 168 | return self.cur_val 169 | 170 | def forward( 171 | self, 172 | input_dict: Dict[str, TensorType], 173 | state: List[TensorType], 174 | seq_lens: TensorType, 175 | ) -> Tuple[TensorType, List[TensorType]]: 176 | 177 | if self.cfg["use_prev_action"]: 178 | prev_acts = one_hot(input_dict["prev_actions"].float(), self.act_space) 179 | prev_acts = prev_acts.reshape(-1, self.act_dim) 180 | flat = torch.cat((input_dict["obs_flat"], prev_acts), dim=-1) 181 | else: 182 | flat = input_dict["obs_flat"] 183 | 184 | dense = add_time_dimension(flat, max_seq_len=seq_lens.max(), framework="torch") 185 | # TODO: ppo sequencing is broken (rllib bug not ours) 186 | # Batch and Time 187 | B = dense.shape[0] 188 | t = dense.shape[1] 189 | # Sometimes numpy sometimes tensor... 
190 | if type(seq_lens) == np.ndarray: 191 | taus = torch.from_numpy(seq_lens).to(dense.device).long() 192 | else: 193 | taus = seq_lens.long() 194 | 195 | # We cannot set non-zero values in get_initial_state 196 | # so do it here instead, fill -1 and 1 for edges and weights respectively 197 | # where T (graph timesteps) is zero 198 | init_batch_idx = (state[-1] == 0).nonzero().squeeze() 199 | state[1][init_batch_idx] = -1 200 | state[2][init_batch_idx] = 1.0 201 | nodes, adj, T = util.unpack_hidden(state, B) 202 | 203 | T = T.long() 204 | hidden = (nodes, adj, T) 205 | 206 | # Push thru pre-gcm layers 207 | out, hidden = self.gcm(dense, taus, hidden) 208 | 209 | logits = self.logit_branch(out).reshape(B * t, self.num_outputs) 210 | self.cur_val = self.value_branch(out).reshape(B * t) 211 | 212 | packed_state = util.pack_hidden(hidden, B, self.cfg["max_edges"]) 213 | return logits, list(packed_state) 214 | -------------------------------------------------------------------------------- /src/gcm/sparse_edge_selectors/learned.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import functools 3 | from typing import List, Any, Tuple, Union 4 | 5 | from torchtyping import TensorType, patch_typeguard # type: ignore 6 | from typeguard import typechecked # type: ignore 7 | from gcm import util 8 | 9 | #patch_typeguard() 10 | 11 | 12 | class LearnedEdge(torch.nn.Module): 13 | """Add temporal edges to the edge list""" 14 | 15 | def __init__( 16 | self, 17 | # Feature size of a graph node 18 | input_size: int = 0, 19 | # Custom model, if None, one will be created for you 20 | model: Union[None, torch.nn.Module] = None, 21 | # Number of edges to sample per node (upper bounds the 22 | # number of edges for each node) 23 | num_edge_samples: int = 5, 24 | # Whether to randomly sample using gumbel softmax 25 | # or use sparsemax 26 | deterministic: bool = False, 27 | # Only consider edges to vertices in a fixed-size window 28 
| # this reduces memory usage but prohibits edges to nodes outside 29 | # the window. Use None for no window (all possible edges) 30 | window: Union[int, None] = None, 31 | # Stores useful information in instance variables 32 | log_stats: bool = True, 33 | # Default (initial) temperature for gumbel-softmax 34 | softmax_temp: float = 1.0, 35 | # Whether the temperature parameter for 36 | # softmax/gumbel softmax should be learned 37 | # or fixed 38 | learn_softmax_temp: bool = True, 39 | # If learning the softmax temp, 40 | # the lower and upper bounds for the temperature 41 | # variable. Note that softmax is undefined for temp <= 0 42 | temp_bounds: Tuple[float, float] = (0.001, 5), 43 | # Whether or not to store gradients for logging 44 | store_grads: bool = True, 45 | ): 46 | super().__init__() 47 | assert model or input_size, "Must specify either input_size or model" 48 | self.deterministic = deterministic 49 | self.num_edge_samples = num_edge_samples 50 | self.store_grads = store_grads 51 | # This MUST be done here 52 | # if initialized in forward model does not learn... 53 | self.edge_network = self.build_edge_network(input_size) if model is None else model 54 | self.ste = util.StraightThroughEstimator() 55 | self.window = window 56 | self.log_stats = log_stats 57 | self.stats = {} 58 | self.tau_param = torch.tensor([softmax_temp]) 59 | self.temp_bounds = temp_bounds 60 | if learn_softmax_temp: 61 | self.tau_param = torch.nn.Parameter(self.tau_param) 62 | 63 | def init_weights(self, m): 64 | if isinstance(m, torch.nn.Linear): 65 | torch.nn.init.orthogonal_(m.weight) 66 | 67 | def grad_hook(self, p_name, grad): 68 | self.stats[f"gnorm_{p_name}"] = grad.norm().detach().item() 69 | 70 | def build_edge_network(self, input_size: int) -> torch.nn.Sequential: 71 | """Builds a network to predict edges. 
72 | Network input: (i || j) 73 | Network output: logits(edge(i,j)) 74 | """ 75 | m = torch.nn.Sequential( 76 | torch.nn.Linear(2 * input_size, input_size), 77 | torch.nn.ReLU(), 78 | torch.nn.LayerNorm(input_size), 79 | torch.nn.Linear(input_size, input_size), 80 | torch.nn.ReLU(), 81 | torch.nn.LayerNorm(input_size), 82 | torch.nn.Linear(input_size, 1), 83 | ) 84 | m.apply(self.init_weights) 85 | if self.store_grads: 86 | for n, p in m.named_parameters(): 87 | p.register_hook(functools.partial(self.grad_hook, n)) 88 | return m 89 | 90 | @typechecked 91 | def forward( 92 | self, 93 | nodes: TensorType["B", "N", "feat", float], # type: ignore # noqa: F821 94 | T: TensorType["B", int], # type: ignore # noqa: F821 95 | taus: TensorType["B", int], # type: ignore # noqa: F821 96 | B: int, 97 | ) -> TensorType["B", "N", "N", float, torch.sparse_coo]: # type: ignore # noqa: F821 98 | 99 | # No edges to create 100 | if (T + taus).max() <= 1: 101 | return torch.sparse_coo_tensor( 102 | indices=torch.zeros(3,0, dtype=torch.long, device=nodes.device), 103 | values=torch.zeros(0, device=nodes.device), 104 | size=(B, nodes.shape[1], nodes.shape[1]) 105 | ) 106 | 107 | if list(self.parameters())[0].device != nodes.device: 108 | self = self.to(nodes.device) 109 | 110 | # Do for all batches at once 111 | # 112 | # Construct indices denoting all edges, which we sample from 113 | # Note that we only want to sample incoming edges from nodes T to T + tau 114 | # these indices denote nodes pairs being fed to network 115 | edge_idx = util.get_causal_edges(T, taus, self.window) 116 | 117 | batch_idx, sink_idx, source_idx = edge_idx.unbind() 118 | # Feed node pairs to network 119 | sink_nodes = nodes[batch_idx, sink_idx] 120 | source_nodes = nodes[batch_idx, source_idx] 121 | network_input = torch.cat((sink_nodes, source_nodes), dim=-1) 122 | # Logits is of shape [N] 123 | logits = self.edge_network(network_input).squeeze() 124 | # TODO rather than sparse to dense conversion, implement 
125 | # a sparse gumbel softmax 126 | cutoff = 1 / (1 + self.num_edge_samples) 127 | gs_input = torch.sparse_coo_tensor( 128 | indices=edge_idx, 129 | values=logits, 130 | size=(B, nodes.shape[1], nodes.shape[1]) 131 | ) 132 | self.tau_param.data.clamp_(*self.temp_bounds) 133 | if not self.deterministic: 134 | soft = util.sparse_gumbel_softmax( 135 | gs_input, dim=2, hard=False, tau=self.tau_param 136 | ) 137 | else: 138 | soft = util.sparse_tempered_softmax( 139 | gs_input, dim=2, hard=False, tau=self.tau_param 140 | ) 141 | 142 | activation_mask = soft.values() > cutoff 143 | adj = torch.sparse_coo_tensor( 144 | indices=soft.indices()[:,activation_mask], 145 | values=( 146 | soft.values()[activation_mask] 147 | / soft.values()[activation_mask].detach() 148 | ), 149 | size=(B, nodes.shape[1], nodes.shape[1]) 150 | ) 151 | 152 | # CAREFUL _values() detaches from autograd graph and breaks grads 153 | self.stats["edges_per_node"] = ( 154 | adj._values().numel() / taus.sum().detach() 155 | ).item() 156 | self.stats["edge_density"] = adj._values().numel() / edge_idx[0].numel() 157 | self.stats["logits_mean"] = logits.detach().mean().item() 158 | self.stats["logits_var"] = logits.detach().var().item() 159 | self.stats["temperature"] = self.tau_param.detach().item() 160 | return adj 161 | -------------------------------------------------------------------------------- /src/gcm/sparse_edge_selectors/spatial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric 3 | from typing import List, Any, Tuple, Union 4 | 5 | from torchtyping import TensorType, patch_typeguard # type: ignore 6 | from typeguard import typechecked # type: ignore 7 | from gcm import util 8 | from torch_geometric.transforms.delaunay import Delaunay 9 | 10 | #patch_typeguard() 11 | 12 | class SpatialKNNEdge(torch.nn.Module): 13 | def __init__(self, position_slice, k, causal=True): 14 | # In meters 15 | super().__init__() 16 | 
self.k = k 17 | self.position_slice = position_slice 18 | self.causal = causal 19 | 20 | @typechecked 21 | def forward( 22 | self, 23 | nodes: TensorType["B", "N", "feat", float], # type: ignore # noqa: F821 24 | T: TensorType["B", int], # type: ignore # noqa: F821 25 | taus: TensorType["B", int], # type: ignore # noqa: F821 26 | B: int, 27 | ) -> TensorType["B", "MAX_EDGES", "MAX_EDGES", float, torch.sparse_coo]: # type: ignore # noqa: F821 28 | 29 | # No edges to create 30 | if (T + taus).max() <= 1: 31 | return torch.sparse_coo_tensor( 32 | indices=torch.zeros(3,0, dtype=torch.long, device=nodes.device), 33 | values=torch.zeros(0, device=nodes.device), 34 | size=(B, nodes.shape[1], nodes.shape[1]) 35 | ) 36 | 37 | pos = nodes[:, :, self.position_slice] 38 | if not self.causal: 39 | raise NotImplementedError() 40 | 41 | edges = [] 42 | for b in range(B): 43 | sink_idx = torch.arange(T[b], T[b] + taus[b]) 44 | source_idx = torch.arange(0, T[b] + taus[b]) 45 | sink = pos[b, sink_idx] 46 | source = pos[b, source_idx] 47 | edge = torch_geometric.nn.knn(source, sink, self.k) 48 | # Filter out non-causal edges 49 | # TODO: This will break during backprop 50 | # we won't actually compute knn as we will prune 51 | # probably all edges 52 | # try topk: https://discuss.pytorch.org/t/k-nearest-neighbor-in-pytorch/59695/2 53 | mask = edge[0] > edge[1] 54 | edge = edge[:, mask] 55 | batch = b * torch.ones(edge.shape[-1], device=pos.device, dtype=torch.long) 56 | edges.append(torch.stack([batch, edge[0], edge[1]])) 57 | 58 | edges = torch.cat(edges, dim=-1) 59 | weights = torch.ones(edges.shape[-1], device=edges.device) 60 | adj = torch.sparse_coo_tensor( 61 | indices=edges, values=weights, size=(B, nodes.shape[1], nodes.shape[1]) 62 | ) 63 | return adj 64 | 65 | class SpatialRadiusEdge(torch.nn.Module): 66 | def __init__(self, position_slice, radius=0.25, causal=True): 67 | # In meters 68 | super().__init__() 69 | self.radius = radius 70 | self.position_slice = position_slice 
71 | self.causal = causal 72 | 73 | @typechecked 74 | def forward( 75 | self, 76 | nodes: TensorType["B", "N", "feat", float], # type: ignore # noqa: F821 77 | T: TensorType["B", int], # type: ignore # noqa: F821 78 | taus: TensorType["B", int], # type: ignore # noqa: F821 79 | B: int, 80 | ) -> TensorType["B", "MAX_EDGES", "MAX_EDGES", float, torch.sparse_coo]: # type: ignore # noqa: F821 81 | # No edges to create 82 | if (T + taus).max() <= 1: 83 | return torch.sparse_coo_tensor( 84 | indices=torch.zeros(3,0, dtype=torch.long, device=nodes.device), 85 | values=torch.zeros(0, device=nodes.device), 86 | size=(B, nodes.shape[1], nodes.shape[1]) 87 | ) 88 | 89 | pos = nodes[:, :, self.position_slice] 90 | edges = [] 91 | for b in range(B): 92 | if self.causal: 93 | sink_idx, source_idx = util.get_causal_edges_one_batch(T[b], taus[b]) 94 | else: 95 | new_idx = torch.arange(T[b], T[b] + taus[b], device=pos.device) 96 | old_idx = torch.arange(T[b] + taus[b], device=pos.device) 97 | sink_idx, source_idx = torch.cartesian_prod(old_idx, new_idx).unbind(dim=1) 98 | sink_nodes = pos[b, sink_idx] 99 | source_nodes = pos[b, source_idx] 100 | # For some reason, torch.cdist is really slow... 
101 | dist = ((sink_nodes - source_nodes) ** 2).sum(dim=-1).sqrt() 102 | idx_idx = torch.where(dist < self.radius) 103 | sink_edges = sink_idx[idx_idx] 104 | source_edges = source_idx[idx_idx] 105 | batch = b * torch.ones(sink_edges.numel(), device=pos.device, dtype=torch.long) 106 | edges.append( 107 | torch.stack([batch, sink_edges, source_edges]) 108 | ) 109 | 110 | edges = torch.cat(edges, dim=-1) 111 | weights = torch.ones(edges.shape[-1], device=edges.device) 112 | adj = torch.sparse_coo_tensor( 113 | indices=edges, values=weights, size=(B, nodes.shape[1], nodes.shape[1]) 114 | ) 115 | return adj 116 | 117 | ''' 118 | class SpatialDelaunayEdge(torch.nn.Module): 119 | """Add temporal edges to the edge list""" 120 | 121 | def __init__(self, position_slice): 122 | super().__init__() 123 | self.position_slice = position_slice 124 | 125 | @typechecked 126 | def forward( 127 | self, 128 | nodes: TensorType["B", "N", "feat", float], # type: ignore # noqa: F821 129 | T: TensorType["B", int], # type: ignore # noqa: F821 130 | taus: TensorType["B", int], # type: ignore # noqa: F821 131 | B: int, 132 | ) -> TensorType["B", "MAX_EDGES", "MAX_EDGES", float, torch.sparse_coo]: # type: ignore # noqa: F821 133 | # Connect each [t in T to T + tau] to [t - h for h in hops] 134 | 135 | batch_starts, batch_ends = util.get_batch_offsets(T + taus) 136 | edge_base: Union[List, torch.Tensor] = [] 137 | 138 | pos = nodes[:, :, position_slice] 139 | for b in range(B): 140 | d = 141 | 142 | 143 | # Build a base of edges (x - hop for all hops) 144 | # then we add the batch offsets to them 145 | # Add the -1 filler so we can stack the batches later 146 | edge_base = torch.empty( 147 | (B, taus.max()), device=nodes.device, dtype=torch.long 148 | ).fill_(-1) 149 | for b in range(B): 150 | edge_base[b, : taus[b]] = torch.arange( 151 | T[b], T[b] + taus[b], device=nodes.device 152 | ) 153 | 154 | # No edges to add 155 | if len(edge_base) < 1: 156 | # TODO don't hardcode max edges 157 | 
return torch.zeros( 158 | (B, int(1e5), int(1e5)), 159 | device=nodes.device, 160 | layout=torch.sparse_coo, 161 | dtype=torch.float, 162 | ) 163 | empty_edges = torch.empty((2, 0), device=nodes.device, dtype=torch.long) 164 | empty_weights = torch.empty((0), device=nodes.device, dtype=torch.float) 165 | return empty_edges, empty_weights 166 | 167 | sink_edges = edge_base.unsqueeze(-1).repeat(1, 1, len(self.hops)) 168 | source_edges = sink_edges - self.hops.to(nodes.device) 169 | batch_idx = ( 170 | torch.arange(B, device=nodes.device) 171 | .unsqueeze(-1) 172 | .unsqueeze(-1) 173 | .expand(source_edges.shape) 174 | ) 175 | 176 | sink_edges = sink_edges.flatten() 177 | source_edges = source_edges.flatten() 178 | edge_batch = batch_idx.flatten() 179 | edge_idx = torch.stack([edge_batch, sink_edges, source_edges]) 180 | weights = torch.ones(source_edges.shape, device=nodes.device) 181 | 182 | # Need to filter negative (invalid) indices 183 | mask = (source_edges >= 0) * (sink_edges > 0) 184 | filtered_edge_idx = edge_idx[:, mask] 185 | weights = weights[mask] 186 | adj = torch.sparse_coo_tensor( 187 | indices=filtered_edge_idx, 188 | values=weights, 189 | size=(B, int(1e5), int(1e5)), 190 | device=nodes.device, 191 | ) 192 | return adj 193 | ''' 194 | -------------------------------------------------------------------------------- /src/gcm/sparse_edge_selectors/temporal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Any, Tuple, Union 3 | 4 | from torchtyping import TensorType, patch_typeguard # type: ignore 5 | from typeguard import typechecked # type: ignore 6 | from gcm import util 7 | 8 | patch_typeguard() 9 | 10 | 11 | class TemporalEdge(torch.nn.Module): 12 | """Add temporal edges to the edge list""" 13 | 14 | def __init__(self, hops: List[int] = [1]): 15 | super().__init__() 16 | self.hops = torch.tensor(hops) 17 | 18 | @typechecked 19 | def forward( 20 | self, 21 | nodes: 
TensorType["B", "N", "feat", float], # type: ignore # noqa: F821 22 | T: TensorType["B", int], # type: ignore # noqa: F821 23 | taus: TensorType["B", int], # type: ignore # noqa: F821 24 | B: int, 25 | ) -> TensorType["B", "MAX_EDGES", "MAX_EDGES", float, torch.sparse_coo]: # type: ignore # noqa: F821 26 | # Connect each [t in T to T + tau] to [t - h for h in hops] 27 | 28 | batch_starts, batch_ends = util.get_batch_offsets(T + taus) 29 | edge_base: Union[List, torch.Tensor] = [] 30 | 31 | # Build a base of edges (x - hop for all hops) 32 | # then we add the batch offsets to them 33 | # Add the -1 filler so we can stack the batches later 34 | edge_base = torch.empty((B, taus.max()), device=nodes.device, dtype=torch.long).fill_(-1) 35 | for b in range(B): 36 | edge_base[b, :taus[b]] = torch.arange(T[b], T[b] + taus[b], device=nodes.device) 37 | 38 | 39 | # No edges to add 40 | if len(edge_base) < 1: 41 | # TODO don't hardcode max edges 42 | return torch.zeros((B, int(1e5), int(1e5)), device=nodes.device, layout=torch.sparse_coo, dtype=torch.float) 43 | empty_edges = torch.empty((2, 0), device=nodes.device, dtype=torch.long) 44 | empty_weights = torch.empty((0), device=nodes.device, dtype=torch.float) 45 | return empty_edges, empty_weights 46 | 47 | 48 | sink_edges = edge_base.unsqueeze(-1).repeat(1, 1, len(self.hops)) 49 | source_edges = sink_edges - self.hops.to(nodes.device) 50 | batch_idx = torch.arange(B, device=nodes.device).unsqueeze(-1).unsqueeze(-1).expand(source_edges.shape) 51 | 52 | sink_edges = sink_edges.flatten() 53 | source_edges = source_edges.flatten() 54 | edge_batch = batch_idx.flatten() 55 | edge_idx = torch.stack([edge_batch, sink_edges, source_edges]) 56 | weights = torch.ones(source_edges.shape, device=nodes.device) 57 | 58 | # Need to filter negative (invalid) indices 59 | mask = (source_edges >= 0) * (sink_edges > 0) 60 | filtered_edge_idx = edge_idx[:, mask] 61 | weights = weights[mask] 62 | adj = 
torch.sparse_coo_tensor(indices=filtered_edge_idx, values=weights, size=(B, int(1e5), int(1e5)), device=nodes.device) 63 | return adj 64 | -------------------------------------------------------------------------------- /src/gcm/sparse_gcm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric 3 | from gcm import util 4 | from typing import Union, Tuple, List 5 | 6 | from torchtyping import TensorType, patch_typeguard # type: ignore 7 | from typeguard import typechecked 8 | 9 | #patch_typeguard() 10 | 11 | 12 | class SparseGCM(torch.nn.Module): 13 | """Graph Associative Memory using sparse-graph representations""" 14 | 15 | did_warn = False 16 | 17 | def __init__( 18 | self, 19 | # Graph neural network, see torch_geometric.nn.Sequential 20 | # for some examples 21 | gnn: torch.nn.Module, 22 | # Preprocessor for each feat vec before it's placed in graph 23 | preprocessor: torch.nn.Module = None, 24 | # an edge selector from gcm.edge_selectors 25 | # you can chain multiple selectors together using 26 | # torch_geometric.nn.Sequential 27 | edge_selectors: torch.nn.Module = None, 28 | # Auxiliary edge selectors are called 29 | # after the positional encoding and reprojection 30 | # this should only be used for non-human (learned) priors 31 | aux_edge_selectors: torch.nn.Module = None, 32 | # Maximum number of nodes in the graph 33 | graph_size: int = 128, 34 | # Optional maximum hops in the graph. If set, 35 | # we will extract the k-hop subgraph for better efficiency. 
36 | # If set, this should be equal to the number of convolution 37 | # layers in the GNN 38 | max_hops: Union[int, None] = None, 39 | # Whether to add sin/cos positional encoding like in transformer 40 | # to the nodes 41 | # Creates an ordering in the graph 42 | positional_encoder: torch.nn.Module = None, 43 | ): 44 | super().__init__() 45 | 46 | self.preprocessor = preprocessor 47 | self.gnn = gnn 48 | self.graph_size = graph_size 49 | self.edge_selectors = edge_selectors 50 | self.aux_edge_selectors = aux_edge_selectors 51 | self.positional_encoder = positional_encoder 52 | self.max_hops = max_hops 53 | self.ste = util.StraightThroughEstimator() 54 | 55 | def get_initial_hidden_state(self, x): 56 | """Given a dummy x of shape [B, feats], construct 57 | the hidden state for the base case (adj matrix, weights, etc)""" 58 | """Returns the initial hidden state h (e.g. h, output = gcm(input, h)), 59 | for a given batch size (B). Feats denotes the feature size (# dims of each 60 | node in the graph).""" 61 | 62 | assert x.dim() == 3 63 | B, _, feats = x.shape 64 | #edges = torch.zeros(2, 0, device=x.device, dtype=torch.long) 65 | nodes = torch.zeros(B, self.graph_size, feats, device=x.device) 66 | #weights = torch.zeros(0, device=x.device) 67 | adj = torch.zeros((B, self.graph_size, self.graph_size), device=x.device, layout=torch.sparse_coo) 68 | T = torch.zeros(B, dtype=torch.long, device=x.device) 69 | 70 | return nodes, adj, T 71 | 72 | @typechecked 73 | def forward( 74 | self, 75 | x: TensorType["B", "t", "feat", float], # type: ignore # noqa: F821 #input observations 76 | taus: TensorType["B", int], # type: ignore # noqa: F821 #sequence_lengths 77 | hidden: Union[ # type: ignore # noqa: F821 78 | None, 79 | Tuple[ 80 | TensorType["B", "N", "feats", float], # noqa: F821 # Nodes 81 | TensorType["B", "MAX_E", "MAX_E", float, torch.sparse_coo], # noqa: F821 # Sparse adj 82 | TensorType["B", int], # noqa: F821 # T 83 | ], 84 | ], 85 | ) -> Tuple[ # type: ignore 86 | 
torch.Tensor, 87 | Tuple[ 88 | TensorType["B", "N", "feats", float], # noqa: F821 # Nodes 89 | TensorType["B", "MAX_E", "MAX_E", float, torch.sparse_coo], # noqa: F821 # Sparse adj 90 | TensorType["B", int], # noqa: F821 # T 91 | ], 92 | ]: 93 | """Add a memory x with temporal size tau to the graph, and query the memory for it. 94 | B = batch size 95 | N = maximum graph size 96 | T = number of timesteps in graph before input 97 | taus = number of timesteps in each input batch 98 | t = the zero-padded time dimension (i.e. max(taus)) 99 | E = number of edge pairs 100 | """ 101 | # Base case 102 | if hidden is None: 103 | hidden = self.get_initial_hidden_state(x) 104 | 105 | nodes, adj, T = hidden 106 | # TODO remove coalesce when bug is fixed 107 | adj = adj.coalesce() 108 | 109 | N = nodes.shape[1] 110 | B = x.shape[0] 111 | 112 | # Batch and time idxs for nodes we intend to add 113 | B_idxs, tau_idxs = util.get_new_node_idxs(T, taus, B) 114 | dense_B_idxs, dense_tau_idxs = util.get_nonpadded_idxs(T, taus, B) 115 | 116 | nodes = nodes.clone() 117 | # Add new nodes to the current graph 118 | #assert torch.all(adj._indices()[1] < adj._indices()[2]) 119 | # TODO: Wrap around instead of terminating 120 | if tau_idxs.max() >= N: 121 | raise Exception("Overflow") 122 | 123 | nodes[B_idxs, tau_idxs] = x[dense_B_idxs, dense_tau_idxs] 124 | 125 | # We do not want to modify graph nodes in the GCM 126 | # Do all mutation operations on dirty_nodes, 127 | # then use clean nodes in the graph state 128 | dirty_nodes = nodes.clone() 129 | 130 | if self.edge_selectors: 131 | # TODO remove coalesce when bug is fixed 132 | new_adj = self.edge_selectors(dirty_nodes, T, taus, B).coalesce() 133 | new_idx = torch.cat([adj.indices(), new_adj.indices()], dim=-1) 134 | new_val = torch.cat([adj.values(), new_adj.values()], dim=-1) 135 | adj = torch.sparse_coo_tensor( 136 | indices=new_idx, 137 | values=new_val, 138 | size=adj.shape 139 | ).coalesce() 140 | 141 | 142 | # Thru network 143 | if 
self.preprocessor: 144 | dirty_nodes = self.preprocessor(dirty_nodes) 145 | if self.positional_encoder: 146 | dirty_nodes = self.positional_encoder(dirty_nodes, T + taus) 147 | if self.aux_edge_selectors: 148 | # TODO remove coalesce when bug is fixed 149 | new_adj = self.aux_edge_selectors(dirty_nodes, T, taus, B).coalesce() 150 | new_idx = torch.cat([adj.indices(), new_adj.indices()], dim=-1) 151 | new_val = torch.cat([adj.values(), new_adj.values()], dim=-1) 152 | adj = torch.sparse_coo_tensor(indices=new_idx, values=new_val, size=adj.shape).coalesce() 153 | 154 | # Remove duplicates from edge selectors 155 | # and set all weights to 1.0 without cancelling out gradients 156 | # from logits. 157 | # For some reason, torch_geometric coalesces incorrectly here 158 | # TODO: remove comment when coalesce bug is fixed 159 | # adj = adj.coalesce() 160 | adj = torch.sparse_coo_tensor( 161 | indices=adj.indices(), 162 | values=adj.values() / adj.values().detach(), 163 | size=adj.shape 164 | ) 165 | # Convert to GNN input format 166 | flat_nodes, output_node_idxs = util.flatten_nodes(dirty_nodes, T, taus, B) 167 | edges, weights, edge_batch = util.flatten_adj(adj, T, taus, B) 168 | # Our adj matrix is sink -> source, but torch_geometric 169 | # expects edgelist as source -> sink, so flip 170 | edges = torch.flip(edges, (0,)) 171 | assert torch.all(edges[0] < edges[1]), "Causality violated" 172 | if edges.numel() > 0: 173 | edges, weights = torch_geometric.utils.coalesce( 174 | edges, weights, reduce="mean" 175 | ) 176 | if self.max_hops is None: 177 | # Convolve over entire graph 178 | node_feats = self.gnn(flat_nodes, edges, weights) 179 | # Extract the hidden repr at the new nodes 180 | # Each mx is variable in temporal dim, so return 2D tensor of [B*tau, feat] 181 | mx = node_feats[output_node_idxs] 182 | else: 183 | # Convolve over subgraph (more efficient) induced by the 184 | # target nodes (taus) 185 | # i.e. 
ignore nodes/edges that are not connected 186 | # to the tau (input) nodes 187 | ( 188 | subnodes, 189 | subedges, 190 | node_map, 191 | edge_mask, 192 | ) = torch_geometric.utils.k_hop_subgraph( 193 | output_node_idxs, 194 | self.max_hops, 195 | edges, 196 | relabel_nodes=True, 197 | num_nodes=(T + taus).sum() 198 | ) 199 | mx = self.gnn(flat_nodes[subnodes], subedges, weights[edge_mask])[node_map] 200 | 201 | assert torch.all( 202 | torch.isfinite(mx) 203 | ), "Got NaN in returned memory, try using tanh activation" 204 | 205 | # Input obs were dense and padded, so output should be dense and padded 206 | dense_B_idxs, dense_tau_idxs = util.get_nonpadded_idxs(T, taus, B) 207 | mx_dense = torch.zeros((*x.shape[:-1], mx.shape[-1]), device=x.device) 208 | mx_dense[dense_B_idxs, dense_tau_idxs] = mx 209 | 210 | T = T + taus 211 | 212 | return mx_dense, (nodes, adj, T) 213 | -------------------------------------------------------------------------------- /src/gcm/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch_geometric 4 | from torch_scatter import scatter_max, scatter 5 | #import sparsemax 6 | from typing import Tuple, List 7 | 8 | 9 | class STEFunction(torch.autograd.Function): 10 | @staticmethod 11 | def forward(ctx, input): 12 | return (input > 0).float() 13 | 14 | @staticmethod 15 | def backward(ctx, grad_output): 16 | return grad_output 17 | # return torch.nn.functional.hardtanh(grad_output) 18 | 19 | 20 | class StraightThroughEstimator(torch.nn.Module): 21 | def __init__(self): 22 | super(StraightThroughEstimator, self).__init__() 23 | 24 | def forward(self, x): 25 | x = STEFunction.apply(x) 26 | return x 27 | 28 | 29 | class Spardmax(torch.nn.Module): 30 | """A hard version of sparsemax""" 31 | 32 | def __init__(self, dim=-1, cutoff=0): 33 | super().__init__() 34 | self.dim = dim 35 | self.cutoff = cutoff 36 | self.sm = sparsemax.Sparsemax(dim) 37 | 38 | def 
forward(self, x): 39 | # Straight through. 40 | y_soft = self.sm(x) 41 | y_hard = (y_soft > self.cutoff).float() 42 | return y_hard - y_soft.detach() + y_soft 43 | 44 | 45 | class Hardmax(torch.nn.Module): 46 | def __init__(self, dim=-1, cutoff=0.2): 47 | super().__init__() 48 | self.dim = dim 49 | self.cutoff = cutoff 50 | self.sm = torch.nn.Softmax(dim) 51 | 52 | def forward(self, x): 53 | # Straight through. 54 | y_soft = self.sm(x) 55 | y_hard = (y_soft > self.cutoff).float() 56 | return y_hard - y_soft.detach() + y_soft 57 | 58 | def sparse_max(x: torch.sparse_coo, dim: int=-1, keepdim=True): 59 | vals, counts = torch.unique(x._indices(), return_counts=True) 60 | max_size = counts.max() 61 | dense = torch.empty(max_size).fill_(1e-20) 62 | 63 | def flatten_idx(idx): 64 | return idx[0] * idx.shape[1] + idx[1] 65 | 66 | def unflatten_idx(idx, b): 67 | b_idx = idx // b 68 | f_idx = idx % b 69 | return torch.stack((b_idx, f_idx)) 70 | 71 | 72 | def flatten_idx_n_dim(idx): 73 | assert idx.ndim == 2 74 | strides = idx.max(dim=1).values + 1 75 | #offsets = strides.cumprod(0).flip(0) 76 | new_idx = torch.zeros(idx.shape[-1], dtype=torch.long, device=idx.device) 77 | offsets = [] 78 | 79 | for i in range(len(strides) - 1): 80 | offset = strides[i + 1:].prod() 81 | offsets.append(offset) 82 | new_idx += offset * idx[i] 83 | 84 | new_idx += idx[-1] 85 | 86 | return new_idx, offsets 87 | 88 | 89 | def sparse_gumbel_softmax( 90 | logits: torch.sparse_coo, 91 | dim: int, 92 | tau: float=1, 93 | hard: bool=False, 94 | ) -> torch.sparse_coo: 95 | # TODO remove coalesce when bug is fixed 96 | logits = logits.coalesce() 97 | gumbels = -torch.empty_like(logits.values()).exponential_().log() 98 | gumbels = (logits.values() + gumbels) / tau 99 | gumbels = torch.sparse_coo_tensor( 100 | indices=logits.indices(), 101 | values=gumbels, 102 | size=logits.shape 103 | ) 104 | # TODO remove coalesce when bug is fixed 105 | y_soft = torch.sparse.softmax(gumbels, dim=dim).coalesce() 106 | 
107 | if not hard: 108 | return y_soft 109 | 110 | index = [] 111 | # Want to max across dim, so exclude it during scatter 112 | scat_dims = list(range(dim)) + list(range(dim+1, logits._indices().shape[0])) 113 | scat_idx = y_soft._indices()[scat_dims] 114 | flat_scat_idx, offsets = flatten_idx_n_dim(scat_idx) 115 | maxes, argmax = scatter_max(y_soft._values(), flat_scat_idx) 116 | # TODO: Sometimes argmax will give us out of bound indices 117 | # because dim_size < numel 118 | # we would use the dim_size arg to scatter, but it crashes :( 119 | # so instead just mask out invalid entries 120 | argmax_mask = argmax < y_soft._indices().shape[-1] 121 | maxes = maxes[argmax_mask] 122 | argmax = argmax[argmax_mask] 123 | index = y_soft._indices()[:, argmax] 124 | 125 | return torch.sparse_coo_tensor( 126 | indices=index, 127 | values=maxes, 128 | size=logits.shape, 129 | device=logits.device 130 | ) 131 | 132 | def sparse_tempered_softmax( 133 | logits: torch.sparse_coo, 134 | dim: int, 135 | tau: float=1, 136 | hard: bool=False, 137 | ) -> torch.sparse_coo: 138 | # TODO remove coalesce when bug is fixed 139 | logits = logits.coalesce() 140 | sm_val = logits.values() / tau 141 | sm_in = torch.sparse_coo_tensor( 142 | indices=logits.indices(), 143 | values=sm_val, 144 | size=logits.shape 145 | ) 146 | # TODO remove coalesce when bug is fixed 147 | y_soft = torch.sparse.softmax(sm_in, dim=dim).coalesce() 148 | 149 | if not hard: 150 | return y_soft 151 | 152 | index = [] 153 | # Want to max across dim, so exclude it during scatter 154 | scat_dims = list(range(dim)) + list(range(dim+1, logits._indices().shape[0])) 155 | scat_idx = y_soft._indices()[scat_dims] 156 | flat_scat_idx, offsets = flatten_idx_n_dim(scat_idx) 157 | maxes, argmax = scatter_max(y_soft._values(), flat_scat_idx) 158 | # TODO: Sometimes argmax will give us out of bound indices 159 | # because dim_size < numel 160 | # we would use the dim_size arg to scatter, but it crashes :( 161 | # so instead just mask 
out invalid entries 162 | argmax_mask = argmax < y_soft._indices().shape[-1] 163 | maxes = maxes[argmax_mask] 164 | argmax = argmax[argmax_mask] 165 | index = y_soft._indices()[:, argmax] 166 | 167 | return torch.sparse_coo_tensor( 168 | indices=index, 169 | values=maxes, 170 | size=logits.shape, 171 | device=logits.device 172 | ) 173 | 174 | 175 | 176 | @torch.jit.script 177 | def get_nonpadded_idxs(T: torch.Tensor, taus: torch.Tensor, B: int): 178 | """Get the non-padded indices of a zero-padded 179 | batch of observations. In other words, get only valid elements and discard 180 | the meaningless zeros.""" 181 | dense_B_idxs = torch.cat( 182 | [torch.ones(taus[b], device=T.device, dtype=torch.long) * b for b in range(B)] 183 | ) 184 | # These must not be offset by T like get_new_node_idxs 185 | dense_tau_idxs = torch.cat( 186 | [torch.arange(taus[b], device=T.device) for b in range(B)] 187 | ) 188 | return dense_B_idxs, dense_tau_idxs 189 | 190 | 191 | @torch.jit.script 192 | def get_new_node_idxs(T: torch.Tensor, taus: torch.Tensor, B: int): 193 | """Given T and tau tensors, return indices matching batches to taus. 194 | These tell us which elements in the node matrix we have just added 195 | during this iteration, and organize them by batch. 196 | 197 | E.g. 198 | g_idxs = torch.where(B_idxs == 0) 199 | zeroth_graph_new_nodes = nodes[B_idxs[g_idxs], tau_idxs[g_idxs]] 200 | """ 201 | # TODO: batch this using b_idx and cumsum 202 | B_idxs = torch.cat( 203 | [torch.ones(taus[b], device=T.device, dtype=torch.long) * b for b in range(B)] 204 | ) 205 | tau_idxs = torch.cat( 206 | [torch.arange(T[b], T[b] + taus[b], device=T.device) for b in range(B)] 207 | ) 208 | return B_idxs, tau_idxs 209 | 210 | 211 | @torch.jit.script 212 | def get_valid_node_idxs(T: torch.Tensor, taus: torch.Tensor, B: int): 213 | """Given T and tau tensors, return indices matching batches to taus. 
214 | These tell us which elements in the node matrix are valid for convolution, 215 | and organize them by batch. 216 | 217 | E.g. 218 | g_idxs = torch.where(B_idxs == 0) 219 | zeroth_graph_all_nodes = nodes[B_idxs[g_idxs], tau_idxs[g_idxs]] 220 | """ 221 | # TODO: batch this using b_idx and cumsum 222 | B_idxs = torch.cat( 223 | [ 224 | torch.ones(T[b] + taus[b], device=T.device, dtype=torch.long) * b 225 | for b in range(B) 226 | ] 227 | ) 228 | tau_idxs = torch.cat( 229 | [torch.arange(0, T[b] + taus[b], device=T.device) for b in range(B)] 230 | ) 231 | return B_idxs, tau_idxs 232 | 233 | 234 | def get_batch_offsets(T: torch.Tensor): 235 | """Get edge offsets into the flattened edge tensor.""" 236 | batch_ends = T.cumsum(dim=0) 237 | batch_starts = batch_ends.roll(1) 238 | batch_starts[0] = 0 239 | 240 | return batch_starts, batch_ends 241 | 242 | def get_causal_edges_one_batch(t, tau, window=None): 243 | """A (potentially) more memory-efficient version of 244 | get_causal_idxs that operates on a single batch. This can 245 | be called in a for loop to reduce memory usage.""" 246 | tril_input = t + tau 247 | edge = torch.tril_indices( 248 | t + tau, t + tau, offset=-1, 249 | dtype=torch.long, 250 | device=t.device, 251 | ) 252 | # Use windows to reduce size, in case the graph is too big. 
253 | # Remove indices outside of the window 254 | if window is not None: 255 | window_min_idx = max(0, t - window) 256 | window_mask = edge[1] >= window_min_idx 257 | # Remove edges outside of window 258 | edge = edge[:, window_mask] 259 | 260 | # Filter edges -- we only want incoming edges to tau nodes 261 | # we should have no sinks < T 262 | edge = edge[:, edge[0] >= t] 263 | return edge 264 | 265 | batch = b * torch.ones(1, device=t.device, dtype=torch.long) 266 | batch = batch.expand(edge[-1].shape[-1]) 267 | 268 | return torch.cat((batch.unsqueeze(0), edge), dim=0) 269 | 270 | def get_causal_edges(T, taus, window=None): 271 | """Given T and taus, select all the causal indices. In other words, 272 | return all edges going from past to future (not future to past)""" 273 | edge_idx = [] 274 | #tril_inputs = T + taus 275 | B = T.numel() 276 | for b in range(B): 277 | edge = get_causal_edges_one_batch(T[b], taus[b], window=window) 278 | batch = b * torch.ones(1, device=T.device, dtype=torch.long) 279 | batch = batch.expand(edge[-1].shape[-1]) 280 | edge_idx.append(torch.cat((batch.unsqueeze(0), edge), dim=0)) 281 | edge_idx = torch.cat(edge_idx, dim=-1) 282 | return edge_idx 283 | 284 | 285 | 286 | 287 | def flatten_adj(adj, T, taus, B): 288 | """Flatten a torch.coo_sparse [B, MAX_NODES, MAX_NODES] to [2, NE] and 289 | adds offsets to avoid collisions. 290 | This readies a sparse tensor for torch_geometric GNN 291 | ingestion. 
292 | 293 | Returns edges, weights, and corresponding batch ids 294 | """ 295 | batch_starts, batch_ends = get_batch_offsets(T + taus) 296 | 297 | # TODO remove coalesce when bug is fixed 298 | adj = adj.coalesce() 299 | batch_idx = adj.indices()[0] 300 | edge_offsets = batch_starts[batch_idx] 301 | flat_edges = adj.indices()[1:] + edge_offsets 302 | flat_weights = adj.values() 303 | 304 | return flat_edges, flat_weights, batch_idx 305 | 306 | 307 | def unflatten_adj(edges, weights, batch_idx, T, taus, B, max_edges): 308 | """Unflatten edges [2,NE], weights: [NE], and batch_idx [NE] 309 | into a torch.coo_sparse adjacency matrix of [B, NE, NE]""" 310 | batch_starts, batch_ends = get_batch_offsets(T + taus) 311 | 312 | edge_offsets = batch_starts[batch_idx] 313 | adj_edge_idx = edges - edge_offsets 314 | adj_idx = torch.stack([batch_idx, adj_edge_idx[0], adj_edge_idx[1]]) 315 | adj_val = weights 316 | 317 | return torch.sparse_coo_tensor( 318 | indices=adj_idx, values=adj_val, size=(B, max_edges, max_edges) 319 | ) 320 | 321 | 322 | 323 | def pack_hidden(hidden, B, max_edges: int, edge_fill: int=-1, weight_fill: float=1.0): 324 | return _pack_hidden(*hidden, B, max_edges, edge_fill, weight_fill) 325 | 326 | def _pack_hidden( 327 | nodes: torch.Tensor, 328 | adj: torch.Tensor, 329 | T: torch.Tensor, 330 | B: int, 331 | max_edges: int, 332 | edge_fill: int = -1, 333 | weight_fill: float = 1.0, 334 | ): 335 | """Converts a torch.coo_sparse adj to a ray dense edgelist.""" 336 | batch_idx, source_idx, sink_idx = adj._indices().unbind() 337 | dense_edges = torch.empty((B, 2, max_edges), device=adj.device, dtype=torch.long).fill_(edge_fill) 338 | dense_weights = torch.empty((B, 1, max_edges), device=adj.device, dtype=torch.float).fill_(weight_fill) 339 | 340 | # TODO remove coalesce when bug is fixed 341 | adj = adj.coalesce() 342 | # TODO can we vectorize this without a BxNE matrix? 
343 | for b in range(B): 344 | sparse_b_idx = torch.nonzero(batch_idx == b).reshape(-1) 345 | assert sparse_b_idx.shape[-1] < max_edges, ( 346 | f"Cannot pack {sparse_b_idx.shape[-1]} edges into {max_edges}, increase" 347 | " max edges" 348 | ) 349 | dense_b_idx = torch.arange(sparse_b_idx.shape[0]) 350 | dense_edges[b, :, dense_b_idx] = adj.indices()[1:, sparse_b_idx] 351 | dense_weights[b, 0, dense_b_idx] = adj.values()[sparse_b_idx] 352 | 353 | return nodes, dense_edges, dense_weights, T 354 | 355 | def unpack_hidden(hidden, B): 356 | return _unpack_hidden(*hidden, B) 357 | 358 | def _unpack_hidden( 359 | nodes: torch.Tensor, 360 | edges: torch.Tensor, 361 | weights: torch.Tensor, 362 | T: torch.Tensor, 363 | B: torch.Tensor 364 | ): 365 | """Convert a ray dense edgelist into a torch.coo_sparse tensor""" 366 | # Get indices of valid edge pairs 367 | batch_idx, edge_idx = (edges[:,0] >= 0).nonzero().T.unbind() 368 | # Get values of valid edge pairs 369 | sources = edges[batch_idx, 0, edge_idx] 370 | sinks = edges[batch_idx, 1, edge_idx] 371 | 372 | adj_idx = torch.stack([batch_idx, sources, sinks]) 373 | weights_filtered = weights[batch_idx, 0, edge_idx] 374 | #sink_idx = edges[batch_idx, 1, source_idx] 375 | #adj_idx = torch.stack([batch_idx, source_idx, sink_idx]) 376 | 377 | 378 | adj = torch.sparse_coo_tensor( 379 | indices=adj_idx, values=weights_filtered, size=(B, nodes.shape[1], nodes.shape[1]) 380 | ) 381 | 382 | return nodes, adj, T 383 | 384 | 385 | 386 | def flatten_edges_and_weights(edges, weights, T, taus, B): 387 | """Flatten edges from [B, 2, NE] to [2, k * NE], coalescing 388 | and removing invalid edges (-1). In other words, prep 389 | edges and weights for GNN ingestion. 
390 | 391 | Returns flattened edges, weights, and corresponding 392 | batch indices""" 393 | batch_offsets = get_batch_offsets(T, taus) 394 | edge_offsets = ( 395 | batch_offsets.unsqueeze(-1).unsqueeze(-1).expand(-1, 2, edges.shape[-1]) 396 | ) 397 | offset_edges = edges + edge_offsets 398 | offset_edges_B_idx = torch.cat( 399 | [ 400 | b * torch.ones(edges.shape[-1], device=edges.device, dtype=torch.long) 401 | for b in range(B) 402 | ] 403 | ) 404 | # Filter invalid edges (those that were < 0 originally) 405 | # Swap dims (B,2,NE) => (2,B,NE) 406 | mask = (offset_edges >= edge_offsets).permute(1, 0, 2) 407 | stacked_mask = (mask[0] & mask[1]).unsqueeze(0).expand(2, -1, -1) 408 | # Now filter edges, weights, and indices using masks 409 | # Careful, mask select will automatically flatten 410 | # so do it last, this squeezes from from (2,B,NE) => (2,B*NE) 411 | flat_edges = edges.permute(1, 0, 2).masked_select(stacked_mask).reshape(2, -1) 412 | flat_weights = weights.permute(1, 0, 2).masked_select(stacked_mask[0]).flatten() 413 | flat_B_idx = offset_edges_B_idx.masked_select(stacked_mask[0].flatten()) 414 | 415 | # Finally, remove duplicate edges and weights 416 | # but only if we have edges 417 | if flat_edges.numel() > 0: 418 | # Make sure idxs are removed alongside edges and weights 419 | flat_edges, [flat_weights, flat_B_idx] = torch_geometric.utils.coalesce( 420 | flat_edges, [flat_weights, flat_B_idx], reduce="min" 421 | ) 422 | 423 | return flat_edges, flat_weights, flat_B_idx 424 | 425 | 426 | def flatten_nodes(nodes: torch.Tensor, T: torch.Tensor, taus: torch.Tensor, B: int): 427 | """Flatten nodes from [B, N, feat] to [B * N, feat] for ingestion 428 | by the GNN. 
429 | 430 | Returns flattened nodes and corresponding batch indices""" 431 | batch_offsets, _ = get_batch_offsets(T + taus) 432 | #batch_offsets, end_offset = get_edge_offsets(T + taus) 433 | B_idxs, tau_idxs = get_valid_node_idxs(T, taus, B) 434 | flat_nodes = nodes[B_idxs, tau_idxs] 435 | # Extracting belief requires batch-tau indices (newly inserted nodes) 436 | # return these too 437 | # Flat nodes are ordered B,:T+tau (all valid nodes) 438 | # We want B,T:T+tau (new nodes), which is batch_offsets:batch_offsets + tau 439 | offset_starts = batch_offsets + T 440 | offset_ends = offset_starts + taus 441 | 442 | output_node_idxs = torch.cat( 443 | [ 444 | torch.arange( 445 | offset_starts[b], 446 | offset_ends[b], 447 | device=nodes.device, 448 | ) 449 | for b in range(B) 450 | ] 451 | ) 452 | return flat_nodes, output_node_idxs 453 | 454 | 455 | @torch.jit.script 456 | def diff_or(tensors: List[torch.Tensor]): 457 | """Differentiable OR operation bewteen n-tuple of tensors 458 | Input: List[tensors in {0,1}] 459 | Output: tensor in {0,1}""" 460 | print("This seems to dilute gradients, dont use it") 461 | res = torch.zeros_like(tensors[0]) 462 | for t in tensors: 463 | tmp = res.clone() 464 | res = tmp + t - tmp * t 465 | return res 466 | 467 | 468 | @torch.jit.script 469 | def diff_or2(tensors: List[torch.Tensor]): 470 | """Differentiable OR operation bewteen n-tuple of tensors 471 | Input: List[tensors in {0,1}] 472 | Output: tensor in {0,1}""" 473 | print("This seems to dilute gradients, dont use it") 474 | # This nice form is actually slower than the matrix mult form 475 | return 1 - (1 - torch.stack(tensors, dim=0)).prod(dim=0) 476 | 477 | 478 | @torch.jit.script 479 | def idxs_up_to_including_num_nodes( 480 | nodes: torch.Tensor, num_nodes: torch.Tensor 481 | ) -> Tuple[torch.Tensor, torch.Tensor]: 482 | """Given nodes and num_nodes, returns idxs from nodes 483 | up to and including num_nodes. I.e. 484 | [batches, 0:num_nodes + 1]. 
Note the order is 485 | sorted by (batches, num_nodes + 1) in ascending order. 486 | 487 | Useful for getting all active nodes in the graph""" 488 | seq_lens = num_nodes.unsqueeze(-1) 489 | N = nodes.shape[1] 490 | N_idx = torch.arange(N, device=nodes.device).unsqueeze(0) 491 | N_idx = N_idx.expand(seq_lens.shape[0], N_idx.shape[1]) 492 | # include the current node 493 | N_idx = torch.nonzero(N_idx <= num_nodes.unsqueeze(1)) 494 | assert N_idx.shape[-1] == 2 495 | batch_idxs = N_idx[:, 0] 496 | node_idxs = N_idx[:, 1] 497 | 498 | return batch_idxs, node_idxs 499 | 500 | 501 | @torch.jit.script 502 | def idxs_up_to_num_nodes( 503 | adj: torch.Tensor, num_nodes: torch.Tensor 504 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 505 | """Given num_nodes, returns idxs from adj 506 | up to but not including num_nodes. I.e. 507 | [batches, 0:num_nodes, num_nodes]. Note the order is 508 | sorted by (batches, num_nodes, 0:num_nodes) in ascending order. 509 | 510 | Useful for getting all actives adj entries in the graph""" 511 | seq_lens = num_nodes.unsqueeze(-1) 512 | N = adj.shape[-1] 513 | N_idx = torch.arange(N, device=adj.device).unsqueeze(0) 514 | N_idx = N_idx.expand(seq_lens.shape[0], N_idx.shape[1]) 515 | # Do not include the current node 516 | N_idx = torch.nonzero(N_idx < num_nodes.unsqueeze(1)) 517 | assert N_idx.shape[-1] == 2 518 | batch_idxs = N_idx[:, 0] 519 | past_idxs = N_idx[:, 1] 520 | curr_idx = num_nodes[batch_idxs] 521 | 522 | return batch_idxs, past_idxs, curr_idx 523 | -------------------------------------------------------------------------------- /tests/profile_sparse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric 3 | from gcm.sparse_gcm import SparseGCM 4 | from gcm.sparse_edge_selectors.learned import LearnedEdge 5 | from gcm import util 6 | import cProfile, pstats, io 7 | from pstats import SortKey 8 | 9 | 10 | sparse_g = torch_geometric.nn.Sequential( 11 | 
"x, edges, weights", 12 | [ 13 | (torch_geometric.nn.GraphConv(32, 32), "x, edges, weights -> x"), 14 | (torch.nn.Tanh()), 15 | (torch_geometric.nn.GraphConv(32, 32), "x, edges, weights -> x"), 16 | (torch.nn.Tanh()), 17 | ], 18 | ) 19 | 20 | 21 | 22 | def fn(): 23 | B = 8 24 | num_obs = 256 25 | obs_size = 32 26 | sparse_gcm = SparseGCM( 27 | sparse_g, graph_size=num_obs, edge_selectors=LearnedEdge(obs_size), 28 | max_hops=2 29 | ) 30 | obs = torch.rand(B, num_obs, obs_size) 31 | taus = torch.ones(B, dtype=torch.long) 32 | hidden = None 33 | with cProfile.Profile() as pr: 34 | #with torch.profiler.profile(with_stack=True, profile_memory=True) as p: 35 | # inference 36 | with torch.no_grad(): 37 | for i in range(num_obs): 38 | out, hidden = sparse_gcm(obs[:,i,None], taus, hidden) 39 | tmp = util.pack_hidden(hidden, B, max_edges = 5 * num_obs) 40 | tmp = util.unpack_hidden(tmp, B) 41 | # train 42 | out, hidden = sparse_gcm(obs, taus * num_obs, None) 43 | tmp = util.pack_hidden(hidden, B, max_edges = 5 * num_obs) 44 | tmp = util.unpack_hidden(tmp, B) 45 | out.mean().backward() 46 | 47 | #print(p.key_averages(group_by_stack_n=100).table(sort_by="self_cpu_time_total", row_limit=20)) 48 | pr.print_stats(sort="cumtime") 49 | 50 | fn() 51 | 52 | -------------------------------------------------------------------------------- /tests/test_gcm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import math 3 | import torch 4 | import torch_geometric 5 | import torchviz 6 | 7 | from gcm.gcm import DenseGCM, DenseToSparse, SparseToDense, PositionalEncoding 8 | from gcm.edge_selectors.temporal import TemporalBackedge 9 | from gcm.edge_selectors.distance import EuclideanEdge, CosineEdge, SpatialEdge 10 | from gcm.edge_selectors.dense import DenseEdge 11 | from gcm.edge_selectors.learned import LearnedEdge 12 | from gcm.util import diff_or, diff_or2 13 | 14 | 15 | class TestPositionalEncoding(unittest.TestCase): 16 | def 
setUp(self): 17 | torch.autograd.set_detect_anomaly(True) 18 | feats = 5 19 | batches = 2 20 | N = 7 21 | self.g = torch_geometric.nn.Sequential( 22 | "x, adj, weights, B, N", 23 | [ 24 | (lambda x: x, "x -> x"), 25 | ], 26 | ) 27 | self.pe = PositionalEncoding( 28 | max_len=N, 29 | mode="add", 30 | ) 31 | # self.s = DenseGCM(self.g, positional_encoder=self.pe) 32 | 33 | self.nodes = torch.zeros(batches, N, feats) 34 | self.obs = torch.ones(batches, feats) 35 | self.adj = torch.zeros(batches, N, N) 36 | self.weights = torch.ones(batches, N, N) 37 | self.num_nodes = torch.tensor([0, 7]) 38 | 39 | def test_pos_enc_add(self): 40 | enc = self.pe(self.nodes.clone(), self.num_nodes) 41 | if not torch.all(enc[0, 1, :] == 0): 42 | self.fail("Off by one error in encoder (overflow)") 43 | 44 | if torch.all(enc[0, 0, :] == 0): 45 | self.fail("Off by one error in encoder (underflow)") 46 | 47 | # PE(x,2i) = sin(x/10000^(2i/D)) 48 | # PE(x,2i+1) = cos(x/10000^(2i/D)) 49 | # 50 | # Zeroth row: 51 | # PE(x=0,2i=0) = sin(0/1) = 0 52 | # PE(x=0,2i+1=1) = cos(0/1) = pi/2 53 | # PE(x=0,2i=2) = sin(0/10000^(2/4)) = 0 54 | # PE(x=0,2i+1=3) = cos(0/10000^(2/4) = pi/2 55 | # 56 | # First row: 57 | # Note we must have 2i so its /6 instead of /5 58 | # PE(x=1,2i=0) = sin 1/10000^(0/6) 59 | # PE(x=1,2i+1=1) = cos 1/10000^(0/6) 60 | # PE(x=1,2i=2) = sin 1/10000^(2/6) 61 | # PE(x=1,2i+1=3) = cos 1/10000^(2/6) 62 | # PE(x=1,2i=2) = sin 1/10000^(4/6) 63 | 64 | sin_actual = enc[0, 0, ::2] 65 | sin_desired = torch.zeros_like(enc[0, 0, ::2]) 66 | cos_actual = enc[0, 0, 1::2] 67 | cos_desired = torch.ones_like(enc[0, 0, 1::2]) 68 | 69 | if torch.abs(sin_actual - sin_desired).sum() > 0.01: 70 | self.fail(f"Sine zeroth row desired {sin_desired} actual {sin_actual}") 71 | if torch.abs(cos_actual - cos_desired).sum() > 0.01: 72 | self.fail(f"Cosine zeroth row desired {cos_desired} actual {cos_actual}") 73 | 74 | enc = self.pe(self.nodes.clone(), self.num_nodes + 1) 75 | desired = torch.tensor( 76 | [ 
77 | math.sin((1 / 10000) ** (0 / 6)), 78 | math.cos((1 / 10000) ** (0 / 6)), 79 | math.sin((1 / 10000) ** (2 / 6)), 80 | math.cos((1 / 10000) ** (2 / 6)), 81 | math.sin((1 / 10000) ** (4 / 6)), 82 | ] 83 | ) 84 | 85 | if torch.abs(enc[0, 1] - desired).sum() > 0.01: 86 | self.fail(f"Desired {desired} actual {enc[0,1]}") 87 | 88 | 89 | class TestWrapOverflow(unittest.TestCase): 90 | def setUp(self): 91 | torch.autograd.set_detect_anomaly(True) 92 | feats = 5 93 | batches = 2 94 | N = 7 95 | conv_type = torch_geometric.nn.DenseGraphConv 96 | self.g = torch_geometric.nn.Sequential( 97 | "x, adj, weights, B, N", 98 | [ 99 | (conv_type(feats, feats), "x, adj -> x"), 100 | (torch.nn.ReLU()), 101 | ], 102 | ) 103 | self.s = DenseGCM(self.g) 104 | 105 | self.nodes = torch.arange((batches * N * feats), dtype=torch.float).reshape( 106 | batches, N, feats 107 | ) 108 | self.obs = torch.ones(batches, feats) * 5 109 | self.adj = torch.zeros(batches, N, N) 110 | self.weights = torch.ones(batches, N, N) 111 | self.num_nodes = torch.tensor([1, 7]) 112 | 113 | def test_wrap_overflow(self): 114 | self.adj[:, 0, :] = 1 115 | self.adj[:, :, 0] = 1 116 | self.weights[:, 0, :] = 5 117 | self.weights[:, :, 0] = 5 118 | self.nodes[:, 0] = 0 119 | 120 | desired = torch.zeros_like(self.adj) 121 | desired[0, 0, :] = 1 122 | desired[0, :, 0] = 1 123 | desired_weights = torch.ones_like(self.weights) 124 | desired_weights[0, 0, :] = 5 125 | desired_weights[0, :, 0] = 5 126 | desired_weights[1, -1, :] = 0 127 | desired_weights[1, :, -1] = 0 128 | desired_nodes = self.nodes.clone() 129 | desired_nodes[0, 1] = 5 130 | desired_nodes[1, -1] = 5 131 | desired_nodes[1, 0] = torch.arange(8 * 5, 9 * 5) 132 | _, (nodes, adj, weights, num_nodes) = self.s( 133 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 134 | ) 135 | if not torch.all(adj == desired): 136 | self.fail(f"{adj} != {desired}") 137 | 138 | if not torch.all(weights == desired_weights): 139 | self.fail(f"{weights} != 
{desired_weights}") 140 | 141 | if not torch.all(nodes[0] == desired_nodes[0]): 142 | self.fail(f"{nodes[0]} != {desired_nodes[0]}") 143 | 144 | # It's shifted by one 145 | if not torch.all(nodes[1, 1] == desired_nodes[1, 2]): 146 | self.fail(f"{nodes[1,2]} != {desired_nodes[1,1]}") 147 | 148 | if not torch.all(nodes[1, 0] == desired_nodes[1, 0]): 149 | self.fail(f"{nodes[1,0]} != {desired_nodes[1,0]}") 150 | 151 | if not torch.all(nodes[1, -1] == desired_nodes[1, -1]): 152 | self.fail(f"{nodes[0]} != {desired_nodes[0]}") 153 | 154 | def test_wrap_overflow_no_weights(self): 155 | self.weights = torch.ones(0) 156 | self.adj[:, 0, :] = 1 157 | self.adj[:, :, 0] = 1 158 | self.nodes[:, 0] = 0 159 | 160 | desired = torch.zeros_like(self.adj) 161 | desired[0, 0, :] = 1 162 | desired[0, :, 0] = 1 163 | desired_nodes = self.nodes.clone() 164 | desired_nodes[0, 1] = 5 165 | desired_nodes[1, -1] = 5 166 | desired_nodes[1, 0] = torch.arange(8 * 5, 9 * 5) 167 | _, (nodes, adj, weights, num_nodes) = self.s( 168 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 169 | ) 170 | if not torch.all(adj == desired): 171 | self.fail(f"{adj} != {desired}") 172 | 173 | if not torch.all(nodes[0] == desired_nodes[0]): 174 | self.fail(f"{nodes[0]} != {desired_nodes[0]}") 175 | 176 | # It's shifted by one 177 | if not torch.all(nodes[1, 1] == desired_nodes[1, 2]): 178 | self.fail(f"{nodes[1,2]} != {desired_nodes[1,1]}") 179 | 180 | if not torch.all(nodes[1, 0] == desired_nodes[1, 0]): 181 | self.fail(f"{nodes[1,0]} != {desired_nodes[1,0]}") 182 | 183 | if not torch.all(nodes[1, -1] == desired_nodes[1, -1]): 184 | self.fail(f"{nodes[0]} != {desired_nodes[0]}") 185 | 186 | 187 | class TestGCMDirection(unittest.TestCase): 188 | def setUp(self): 189 | torch.autograd.set_detect_anomaly(True) 190 | feats = 11 191 | batches = 1 192 | N = 10 193 | conv_type = torch_geometric.nn.DenseGraphConv 194 | self.g = torch_geometric.nn.Sequential( 195 | "x, adj, weights, B, N", 196 | [ 197 | 
(conv_type(feats, feats), "x, adj -> x"), 198 | (torch.nn.ReLU()), 199 | ], 200 | ) 201 | for layer in list(self.g.modules())[1]: 202 | if layer.__class__ == conv_type: 203 | layer.lin_root.weight = torch.nn.Parameter( 204 | torch.diag(torch.zeros(layer.lin_root.weight.shape[-1])) 205 | ) 206 | layer.lin_root.bias = torch.nn.Parameter( 207 | torch.zeros_like(layer.lin_root.bias) 208 | ) 209 | layer.lin_rel.weight = torch.nn.Parameter( 210 | torch.diag(torch.ones(layer.lin_root.weight.shape[-1])) 211 | ) 212 | self.s = DenseGCM(self.g) 213 | 214 | self.nodes = torch.arange((batches * N * feats), dtype=torch.float).reshape( 215 | batches, N, feats 216 | ) 217 | self.all_obs = [ 218 | 1 * torch.ones(batches, feats), 219 | 2 * torch.ones(batches, feats), 220 | 3 * torch.ones(batches, feats), 221 | ] 222 | self.adj = torch.zeros(batches, N, N) 223 | self.weights = torch.ones(batches, N, N) 224 | self.num_nodes = torch.zeros(batches, dtype=torch.long) 225 | 226 | def test_gcm_direction(self): 227 | # Only get neighbor 228 | self.adj[:, 0, 3] = 1 229 | # list(self.g.modules())[1][0].lin_rel(torch.matmul(self.adj, self.nodes)) 230 | 231 | out, (nodes, adj, weights, num_nodes) = self.s( 232 | self.all_obs[0], (self.nodes, self.adj, self.weights, self.num_nodes) 233 | ) 234 | # neighbor 235 | # flows from 3 => 0, neighbor => root 236 | # root = i, neighbor = j 237 | # i should be > j in adj[i, j] 238 | desired = torch.arange(3 * 11, 4 * 11, dtype=torch.float) 239 | if not torch.all(self.nodes[0, 3] == desired): 240 | self.fail(f"{self.nodes[0,3]} != {desired}") 241 | 242 | 243 | class TestDenseGCME2E(unittest.TestCase): 244 | def setUp(self): 245 | torch.autograd.set_detect_anomaly(True) 246 | feats = 11 247 | batches = 5 248 | N = 10 249 | conv_type = torch_geometric.nn.DenseGraphConv 250 | self.g = torch_geometric.nn.Sequential( 251 | "x, adj, weights, B, N", 252 | [ 253 | (conv_type(feats, feats), "x, adj -> x"), 254 | (torch.nn.ReLU()), 255 | (conv_type(feats, feats), 
"x, adj -> x"), 256 | (torch.nn.ReLU()), 257 | ], 258 | ) 259 | for layer in list(self.g.modules())[1]: 260 | if layer.__class__ == conv_type: 261 | layer.lin_root.weight = torch.nn.Parameter( 262 | torch.diag(torch.ones(layer.lin_root.weight.shape[-1])) 263 | ) 264 | layer.lin_root.bias = torch.nn.Parameter( 265 | torch.zeros_like(layer.lin_root.bias) 266 | ) 267 | layer.lin_rel.weight = torch.nn.Parameter( 268 | torch.diag(torch.ones(layer.lin_root.weight.shape[-1])) 269 | ) 270 | self.s = DenseGCM(self.g) 271 | 272 | self.nodes = torch.zeros((batches, N, feats), dtype=torch.float) 273 | self.all_obs = [ 274 | 1 * torch.ones(batches, feats), 275 | 2 * torch.ones(batches, feats), 276 | 3 * torch.ones(batches, feats), 277 | ] 278 | self.adj = torch.zeros(batches, N, N) 279 | self.weights = torch.ones(batches, N, N) 280 | self.num_nodes = torch.zeros(batches, dtype=torch.long) 281 | 282 | def test_e2e_self_edge(self): 283 | (nodes, adj, weights, num_nodes) = ( 284 | self.nodes, 285 | self.adj, 286 | self.weights, 287 | self.num_nodes, 288 | ) 289 | # First iter 290 | # Zeroth row of graph should be 1111... 291 | # Output should be 1111... 292 | obs = self.all_obs[0].clone() 293 | out, (nodes, adj, weights, num_nodes) = self.s( 294 | obs, (nodes, adj, weights, num_nodes) 295 | ) 296 | if torch.any(out != self.all_obs[0]): 297 | self.fail(f"out: {out} != {self.all_obs[0]}") 298 | 299 | desired_nodes = torch.cat( 300 | (self.all_obs[0].unsqueeze(1), torch.zeros(5, 9, 11)), dim=1 301 | ) 302 | if torch.any(nodes != desired_nodes): 303 | self.fail(f"out: {nodes} != {desired_nodes}") 304 | 305 | # Second iter 306 | # Rows 1111 and 2222... 307 | # Output should be 2222 308 | obs = self.all_obs[1].clone() 309 | out, (nodes, adj, weights, num_nodes) = self.s( 310 | obs, (nodes, adj, weights, num_nodes) 311 | ) 312 | if torch.any(out != self.all_obs[1]): 313 | self.fail(f"out: {out} != {self.all_obs[1]}") 314 | 315 | # Third iter 316 | # Rows 1111 and 2222 and 3333... 
317 | # Output should be 3333 318 | obs = self.all_obs[2].clone() 319 | out, (nodes, adj, weights, num_nodes) = self.s( 320 | obs, (nodes, adj, weights, num_nodes) 321 | ) 322 | if torch.any(out != self.all_obs[2]): 323 | self.fail(f"out: {out} != {self.all_obs[1]}") 324 | 325 | 326 | class TestDenseGCM(unittest.TestCase): 327 | def setUp(self): 328 | torch.autograd.set_detect_anomaly(True) 329 | feats = 11 330 | batches = 5 331 | N = 10 332 | conv_type = torch_geometric.nn.DenseGCNConv 333 | self.g = torch_geometric.nn.Sequential( 334 | "x, adj, weights, B, N", 335 | [ 336 | (conv_type(feats, feats), "x, adj -> x"), 337 | (torch.nn.ReLU()), 338 | (conv_type(feats, feats), "x, adj -> x"), 339 | (torch.nn.ReLU()), 340 | ], 341 | ) 342 | self.s = DenseGCM(self.g) 343 | 344 | # Now do it in a loop to make sure grads propagate 345 | self.optimizer = torch.optim.Adam(self.s.parameters(), lr=0.005) 346 | 347 | self.nodes = torch.arange(batches * N * feats, dtype=torch.float).reshape( 348 | batches, N, feats 349 | ) 350 | self.obs = torch.ones(batches, feats) 351 | self.adj = torch.zeros(batches, N, N) 352 | self.weights = torch.ones(batches, N, N) 353 | self.num_nodes = torch.zeros(batches, dtype=torch.long) 354 | 355 | def test_grad_prop(self): 356 | self.g.grad_test_var = torch.nn.Parameter(torch.tensor([1.0])) 357 | self.nodes = self.nodes * self.g.grad_test_var 358 | self.assertTrue(self.nodes.requires_grad) 359 | out, (nodes, adj, weights, num_nodes) = self.s( 360 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 361 | ) 362 | loss = torch.norm(out) 363 | dot = torchviz.make_dot(loss, params=dict(self.s.named_parameters())) 364 | # Make sure gradients make it all the way thru node_feats 365 | self.assertTrue("grad_test_var" in dot.source) 366 | 367 | def test_no_weights(self): 368 | _, (nodes, adj, weights, num_nodes) = self.s( 369 | self.obs, (self.nodes, self.adj, torch.ones(0), self.num_nodes) 370 | ) 371 | 372 | if len(weights) != 0: 373 | 
self.fail(f"Weights should be none, is {weights}") 374 | 375 | def test_zeroth_entry(self): 376 | # Ensure first obs ends up in nodes matrix 377 | _, (nodes, adj, weights, num_nodes) = self.s( 378 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 379 | ) 380 | if torch.any(nodes[:, 0] != self.obs): 381 | self.fail(f"{nodes[:,0]} != {self.obs}") 382 | # Ensure only self edges 383 | adj = torch.zeros_like(self.adj, dtype=torch.long) 384 | if torch.any(self.adj != adj): 385 | self.fail(f"adj: {adj} != {self.adj}") 386 | 387 | def test_first_entry(self): 388 | _, (nodes, adj, weights, num_nodes) = self.s( 389 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 390 | ) 391 | _, (nodes, adj, weights, num_nodes) = self.s( 392 | self.obs, (nodes, adj, weights, num_nodes) 393 | ) 394 | if torch.any(nodes[:, 1] != self.obs): 395 | self.fail(f"{nodes[:,0]} != {self.obs}") 396 | 397 | def test_num_nodes_entry(self): 398 | _, (nodes, adj, weights, num_nodes) = self.s( 399 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 400 | ) 401 | B_idx = torch.arange(num_nodes.shape[0]) 402 | if torch.any(nodes[B_idx, num_nodes - 1] != self.obs): 403 | self.fail(f"{nodes[:,num_nodes]} != {self.obs}") 404 | 405 | def test_propagation(self): 406 | out, (nodes, adj, weights, num_nodes) = self.s( 407 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 408 | ) 409 | if torch.all(out == self.obs): 410 | self.fail(f"{out} == {self.obs}") 411 | 412 | def test_dense_learn(self): 413 | feats = 11 414 | batches = 5 415 | T = 4 416 | N = 10 417 | losses = [] 418 | for i in range(20): 419 | nodes = torch.arange(batches * N * feats, dtype=torch.float).reshape( 420 | batches, N, feats 421 | ) 422 | obs = torch.ones(batches, feats) 423 | adj = torch.zeros(batches, N, N, dtype=torch.long) 424 | weights = torch.ones(batches, N, N) 425 | num_nodes = torch.zeros(batches, dtype=torch.long) 426 | 427 | self.s.zero_grad() 428 | hidden = (nodes, adj, 
weights, num_nodes) 429 | for t in range(T): 430 | obs, hidden = self.s(obs, hidden) 431 | 432 | loss = torch.norm(obs) 433 | loss.backward() 434 | losses.append(loss) 435 | 436 | self.optimizer.step() 437 | 438 | if not losses[-1] < losses[0]: 439 | self.fail(f"Final loss {losses[-1]} not better than init loss {losses[0]}") 440 | 441 | 442 | class TestSparseGCM(unittest.TestCase): 443 | def setUp(self): 444 | feats = 11 445 | batches = 5 446 | N = 10 447 | conv_type = torch_geometric.nn.GCNConv 448 | self.g = torch_geometric.nn.Sequential( 449 | "x, adj, weights, B, N", 450 | [ 451 | (DenseToSparse(), "x, adj, -> x_sp, edge_index, batch_idx"), 452 | (conv_type(feats, feats), "x_sp, edge_index -> x_sp"), 453 | (torch.nn.ReLU()), 454 | (conv_type(feats, feats), "x_sp, edge_index -> x_sp"), 455 | (torch.nn.ReLU()), 456 | (SparseToDense(), "x_sp, edge_index, batch_idx, B, N -> x, adj"), 457 | # Return only x not adj 458 | (lambda x: x, "x -> x"), 459 | ], 460 | ) 461 | self.s = DenseGCM(self.g) 462 | 463 | # Now do it in a loop to make sure grads propagate 464 | self.optimizer = torch.optim.Adam(self.s.parameters(), lr=0.005) 465 | 466 | self.nodes = torch.arange(batches * N * feats, dtype=torch.float).reshape( 467 | batches, N, feats 468 | ) 469 | self.obs = torch.ones(batches, feats) 470 | self.adj = torch.zeros(batches, N, N, dtype=torch.long) 471 | self.adj[:, 0:2, 3:4] = 1 472 | self.weights = torch.ones(batches, N, N) 473 | self.num_nodes = torch.zeros(batches, dtype=torch.long) 474 | 475 | def test_zeroth_entry(self): 476 | # Ensure first obs ends up in nodes matrix 477 | _, (nodes, adj, weights, num_nodes) = self.s( 478 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 479 | ) 480 | if torch.any(nodes[:, 0] != self.obs): 481 | self.fail(f"{nodes[:,0]} != {self.obs}") 482 | 483 | def test_first_entry(self): 484 | _, (nodes, adj, weights, num_nodes) = self.s( 485 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 486 | ) 487 | _, 
(nodes, adj, weights, num_nodes) = self.s( 488 | self.obs, (nodes, adj, weights, num_nodes) 489 | ) 490 | if torch.any(nodes[:, 1] != self.obs): 491 | self.fail(f"{nodes[:,0]} != {self.obs}") 492 | 493 | def test_DenseToSparse_SparseToDense(self): 494 | B = self.nodes.shape[0] 495 | N = self.nodes.shape[1] 496 | 497 | x = self.nodes.clone() 498 | adj = self.adj.clone() 499 | weight = self.weights.clone() 500 | 501 | seq = torch_geometric.nn.Sequential( 502 | "x, adj, weights, B, N", 503 | [ 504 | (DenseToSparse(), "x, adj -> x_sp, edge_index, batch_idx"), 505 | (SparseToDense(), "x_sp, edge_index, batch_idx, B, N -> x_d, adj_d"), 506 | ], 507 | ) 508 | 509 | x_d, adj_d = seq(x, adj, weight, B, N) 510 | 511 | if torch.any(x_d != self.nodes): 512 | self.fail(f"x: {x_d} != {self.nodes}") 513 | 514 | if torch.any(adj_d != self.adj): 515 | self.fail(f"adj: {adj_d} != {self.adj}") 516 | 517 | def test_propagation(self): 518 | out, (nodes, adj, weights, num_nodes) = self.s( 519 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 520 | ) 521 | if torch.all(out == self.obs): 522 | self.fail(f"{out} == {self.obs}") 523 | 524 | def test_sparse_learn(self): 525 | feats = 11 526 | batches = 5 527 | T = 4 528 | N = 10 529 | losses = [] 530 | for i in range(20): 531 | nodes = torch.arange(batches * N * feats, dtype=torch.float).reshape( 532 | batches, N, feats 533 | ) 534 | obs = torch.ones(batches, feats) 535 | adj = torch.zeros(batches, N, N, dtype=torch.long) 536 | weights = torch.ones(batches, N, N) 537 | num_nodes = torch.zeros(batches, dtype=torch.long) 538 | 539 | self.s.zero_grad() 540 | hidden = (nodes, adj, weights, num_nodes) 541 | for t in range(T): 542 | obs, hidden = self.s(obs, hidden) 543 | 544 | loss = torch.norm(obs) 545 | loss.backward() 546 | losses.append(loss) 547 | 548 | self.optimizer.step() 549 | 550 | self.assertTrue(losses[-1] < losses[0]) 551 | 552 | 553 | class TestTemporalEdge(unittest.TestCase): 554 | def setUp(self): 555 | feats = 3 
556 | batches = 2 557 | N = 10 558 | conv_type = torch_geometric.nn.DenseGCNConv 559 | self.g = torch_geometric.nn.Sequential( 560 | "x, adj, weights, B, N", 561 | [ 562 | (conv_type(feats, feats), "x, adj -> x"), 563 | (torch.nn.ReLU()), 564 | (conv_type(feats, feats), "x, adj -> x"), 565 | (torch.nn.ReLU()), 566 | ], 567 | ) 568 | self.s = DenseGCM(self.g, edge_selectors=TemporalBackedge(hops=[1])) 569 | 570 | # Now do it in a loop to make sure grads propagate 571 | self.optimizer = torch.optim.Adam(self.s.parameters(), lr=0.005) 572 | 573 | self.nodes = torch.arange(batches * N * feats, dtype=torch.float).reshape( 574 | batches, N, feats 575 | ) 576 | self.obs = torch.ones(batches, feats) 577 | self.adj = torch.zeros(batches, N, N) 578 | self.weights = torch.ones(batches, N, N) 579 | self.num_nodes = torch.zeros(batches, dtype=torch.long) 580 | 581 | def test_two_nodes(self): 582 | _, (nodes, adj, weights, num_nodes) = self.s( 583 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 584 | ) 585 | _, (nodes, adj, weights, num_nodes) = self.s( 586 | self.obs, (nodes, adj, weights, num_nodes) 587 | ) 588 | tgt_adj = torch.zeros_like(adj, dtype=torch.long) 589 | # tgt_adj[:, 0, 1] = 1 590 | tgt_adj[:, 1, 0] = 1 591 | # Also add self edges 592 | if torch.any(tgt_adj != adj): 593 | self.fail(f"{tgt_adj} != {adj}") 594 | 595 | def test_far_hops(self): 596 | (nodes, adj, weights, num_nodes) = ( 597 | self.nodes, 598 | self.adj, 599 | self.weights, 600 | self.num_nodes, 601 | ) 602 | self.s = DenseGCM(self.g, edge_selectors=TemporalBackedge(hops=[4])) 603 | for i in range(10): 604 | _, (nodes, adj, weights, num_nodes) = self.s( 605 | self.obs, (nodes, adj, weights, num_nodes) 606 | ) 607 | # hop 1 should start at t-1 608 | # hop 5 should start at t-5: 5=>0, 6=>1, etc 609 | tgt_adj = torch.zeros_like(adj) 610 | tgt_adj[:, 4, 0] = 1 611 | tgt_adj[:, 5, 1] = 1 612 | tgt_adj[:, 6, 2] = 1 613 | tgt_adj[:, 7, 3] = 1 614 | tgt_adj[:, 8, 4] = 1 615 | tgt_adj[:, 9, 5] 
= 1 616 | if torch.any(tgt_adj != adj): 617 | self.fail(f"{tgt_adj} != {adj}") 618 | 619 | def test_learned_edge_grad(self): 620 | self.s = DenseGCM(self.g, edge_selectors=TemporalBackedge(learned=True)) 621 | self.num_nodes += 1 622 | _, (nodes, adj, weights, num_nodes) = self.s( 623 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 624 | ) 625 | if not adj.requires_grad: 626 | self.fail("Adj has no grad") 627 | 628 | adj.sum().backward() 629 | 630 | 631 | class TestDoubleEdge(unittest.TestCase): 632 | def setUp(self): 633 | feats = 11 634 | batches = 5 635 | N = 10 636 | conv_type = torch_geometric.nn.DenseGCNConv 637 | self.g = torch_geometric.nn.Sequential( 638 | "x, adj, weights, B, N", 639 | [ 640 | (conv_type(feats, feats), "x, adj -> x"), 641 | (torch.nn.ReLU()), 642 | (conv_type(feats, feats), "x, adj -> x"), 643 | (torch.nn.ReLU()), 644 | ], 645 | ) 646 | es = torch_geometric.nn.Sequential( 647 | "x, adj, weights, num_nodes, B", 648 | [ 649 | ( 650 | TemporalBackedge([1]), 651 | "x, adj, weights, num_nodes, B -> adj, weights", 652 | ), 653 | ( 654 | TemporalBackedge([2]), 655 | "x, adj, weights, num_nodes, B -> adj, weights", 656 | ), 657 | ], 658 | ) 659 | self.s = DenseGCM(self.g, edge_selectors=es) 660 | 661 | self.nodes = torch.zeros(batches, N, feats, dtype=torch.float) 662 | self.obs = torch.zeros(batches, feats) 663 | self.adj = torch.zeros(batches, N, N, dtype=torch.long) 664 | self.weights = torch.ones(batches, N, N) 665 | self.num_nodes = torch.ones(batches, dtype=torch.long) 666 | self.optimizer = torch.optim.Adam(self.s.parameters(), lr=0.005) 667 | 668 | def test_backwards(self): 669 | nodes, adj, weights, num_nodes = ( 670 | self.nodes, 671 | self.adj, 672 | self.weights, 673 | self.num_nodes + 1, 674 | ) 675 | out, (nodes, adj, weights, num_nodes) = self.s( 676 | self.obs, (nodes, adj, weights, num_nodes) 677 | ) 678 | self.s.edge_selectors.zero_grad() 679 | nodes = torch.rand_like(self.nodes) * 
torch.nn.Parameter(torch.tensor([0.01])) 680 | adj, weights = self.s.edge_selectors(nodes, adj, weights, num_nodes, 5) 681 | nodes.mean().backward() 682 | self.optimizer.step() 683 | 684 | 685 | class TestDistanceEdge(unittest.TestCase): 686 | def setUp(self): 687 | feats = 11 688 | batches = 5 689 | N = 10 690 | conv_type = torch_geometric.nn.DenseGCNConv 691 | self.g = torch_geometric.nn.Sequential( 692 | "x, adj, weights, B, N", 693 | [ 694 | (conv_type(feats, feats), "x, adj -> x"), 695 | (torch.nn.ReLU()), 696 | (conv_type(feats, feats), "x, adj -> x"), 697 | (torch.nn.ReLU()), 698 | ], 699 | ) 700 | self.s = DenseGCM(self.g, edge_selectors=EuclideanEdge(max_distance=1)) 701 | 702 | self.nodes = torch.zeros(batches, N, feats, dtype=torch.float) 703 | self.obs = torch.zeros(batches, feats) 704 | self.adj = torch.zeros(batches, N, N, dtype=torch.long) 705 | self.weights = torch.ones(batches, N, N) 706 | self.num_nodes = torch.ones(batches, dtype=torch.long) 707 | 708 | def test_zero_dist(self): 709 | # Start num_nodes = 1 710 | _, (nodes, adj, weights, num_nodes) = self.s( 711 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 712 | ) 713 | tgt_adj = torch.zeros_like(adj, dtype=torch.long) 714 | tgt_adj[:, 1, 0] = 1 715 | 716 | # TODO: Ensure not off by one 717 | if torch.any(tgt_adj != adj): 718 | self.fail(f"{tgt_adj} != {self.adj}") 719 | 720 | def test_one_dist(self): 721 | # Start num_nodes = 1 722 | self.obs = torch.ones_like(self.obs) 723 | _, (nodes, adj, weights, num_nodes) = self.s( 724 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 725 | ) 726 | tgt_adj = torch.zeros_like(adj, dtype=torch.long) 727 | # Adds self edge 728 | if torch.any(tgt_adj != adj): 729 | self.fail(f"{tgt_adj} != {self.adj}") 730 | 731 | def test_cosine(self): 732 | self.s = DenseGCM(self.g, edge_selectors=CosineEdge(max_distance=1)) 733 | _, (nodes, adj, weights, num_nodes) = self.s( 734 | self.obs, (self.nodes, self.adj, self.weights, 
self.num_nodes) 735 | ) 736 | 737 | def test_learned_edge(self): 738 | self.s = DenseGCM( 739 | self.g, edge_selectors=EuclideanEdge(max_distance=1, learned=True) 740 | ) 741 | self.obs = torch.ones_like(self.obs) 742 | 743 | _, (nodes, adj, weights, num_nodes) = self.s( 744 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 745 | ) 746 | tgt_adj = torch.zeros_like(adj, dtype=torch.long) 747 | # Adds self edge 748 | if torch.any(tgt_adj != adj): 749 | self.fail(f"{tgt_adj} != {self.adj}") 750 | 751 | def test_learned_edge_grad(self): 752 | e = EuclideanEdge(max_distance=1, learned=True) 753 | self.num_nodes += 2 754 | 755 | adj, weights = e(self.nodes, self.adj, self.weights, self.num_nodes, B=5) 756 | 757 | if not self.nodes.requires_grad: 758 | self.fail("Nodes has no grad") 759 | 760 | 761 | class TestDenseEdge(unittest.TestCase): 762 | def setUp(self): 763 | feats = 11 764 | batches = 5 765 | N = 10 766 | conv_type = torch_geometric.nn.DenseGCNConv 767 | self.g = torch_geometric.nn.Sequential( 768 | "x, adj, weights, B, N", 769 | [ 770 | (conv_type(feats, feats), "x, adj -> x"), 771 | (torch.nn.ReLU()), 772 | (conv_type(feats, feats), "x, adj -> x"), 773 | (torch.nn.ReLU()), 774 | ], 775 | ) 776 | self.s = DenseGCM(self.g, edge_selectors=DenseEdge()) 777 | 778 | self.nodes = torch.zeros(batches, N, feats, dtype=torch.float) 779 | self.obs = torch.zeros(batches, feats) 780 | self.adj = torch.zeros(batches, N, N, dtype=torch.long) 781 | self.weights = torch.ones(batches, N, N) 782 | self.num_nodes = torch.zeros(batches, dtype=torch.long) 783 | 784 | def test_two_obs(self): 785 | # Start num_nodes = 1 786 | _, (nodes, adj, weights, num_nodes) = self.s( 787 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 788 | ) 789 | _, (nodes, adj, weights, num_nodes) = self.s( 790 | self.obs, (nodes, adj, weights, num_nodes) 791 | ) 792 | tgt_adj = torch.zeros_like(adj, dtype=torch.long) 793 | # It adds self edge 794 | tgt_adj[:, 1, 1] = 1 795 | 
tgt_adj[:, 0, 0] = 1 796 | tgt_adj[:, 1, 0] = 1 797 | tgt_adj[:, 0, 1] = 1 798 | 799 | # TODO: Ensure not off by one 800 | if torch.any(tgt_adj != adj): 801 | self.fail(f"{tgt_adj} != {self.adj}") 802 | 803 | 804 | class Sum(torch.nn.Module): 805 | def __init__(self): 806 | super().__init__() 807 | self.weight = torch.nn.Parameter(torch.tensor([1.0])) 808 | 809 | def forward(self, x): 810 | # Returns sum of the first feature of base and neighbor nodes 811 | return x[:, 0] + x[:, 5] 812 | 813 | 814 | class TestLearnedEdge(unittest.TestCase): 815 | def setUp(self): 816 | torch.autograd.set_detect_anomaly(True) 817 | feats = 5 818 | batches = 2 819 | N = 4 820 | conv_type = torch_geometric.nn.DenseGCNConv 821 | self.g = torch_geometric.nn.Sequential( 822 | "x, adj, weights, B, N", 823 | [ 824 | (conv_type(feats, feats), "x, adj -> x"), 825 | (torch.nn.ReLU()), 826 | (conv_type(feats, feats), "x, adj -> x"), 827 | (torch.nn.ReLU()), 828 | ], 829 | ) 830 | self.s = DenseGCM(self.g, edge_selectors=LearnedEdge(feats)) 831 | 832 | # Now do it in a loop to make sure grads propagate 833 | self.optimizer = torch.optim.Adam(self.s.parameters(), lr=0.005) 834 | 835 | self.nodes = torch.arange(batches * N * feats, dtype=torch.float).reshape( 836 | batches, N, feats 837 | ) 838 | self.obs = torch.ones(batches, feats) 839 | self.adj = torch.zeros(batches, N, N) 840 | self.weights = torch.ones(batches, N, N) 841 | self.num_nodes = torch.zeros(batches, dtype=torch.long) 842 | 843 | def test_compute_new_adj_deterministic_grad(self): 844 | self.num_nodes = torch.tensor((2, 3), dtype=torch.long) 845 | es = LearnedEdge(5, deterministic=True) 846 | self.nodes.requires_grad = True 847 | adj, weights = es(self.nodes, self.adj, self.weights, self.num_nodes, 2) 848 | 849 | self.assertTrue(adj.requires_grad) 850 | 851 | def test_compute_new_adj_grad(self): 852 | p = torch.nn.Parameter(torch.tensor([1.0])) 853 | self.nodes = self.nodes * p 854 | self.num_nodes = torch.tensor((2, 3), 
dtype=torch.long) 855 | es = LearnedEdge(5, deterministic=False) 856 | adj, weights = es(self.nodes, self.adj, self.weights, self.num_nodes, 2) 857 | 858 | self.assertTrue(adj.requires_grad) 859 | optimizer = torch.optim.Adam(es.parameters(), lr=0.005) 860 | self.nodes.sum().backward() 861 | optimizer.step() 862 | # grads = [p.grad for p in es.parameters()] 863 | 864 | """ 865 | def test_diff_or(self): 866 | test = [(torch.rand(32,128,128) > 0.5).float() for i in range(3)] 867 | desired = (sum(test) >= 1).float() 868 | actual = diff_or(test) 869 | if not torch.all(actual == desired): 870 | self.fail(f"Desired {desired} actual {actual}") 871 | 872 | actual = diff_or2(test) 873 | if not torch.all(actual == desired): 874 | self.fail(f"Desired {desired} actual {actual}") 875 | 876 | 877 | 878 | def test_update_density(self): 879 | self.b = LearnedEdge(5, torch.nn.Sequential(Sum())) 880 | a = torch.tensor([1.5, 2, 0, 0, 2, 0, 3, 1, 0.2]) 881 | b = torch.tensor([1, 4, 2, 0.1]) 882 | self.b.update_density(a) 883 | self.b.update_density(b) 884 | rm = self.b.detach_loss() 885 | self.assertTrue(torch.isclose(rm, torch.cat((a, b)).mean())) 886 | 887 | def test_sample_hard(self): 888 | self.b = LearnedEdge(5, torch.nn.Sequential()) 889 | self.num_nodes = torch.tensor([0,1,2]) 890 | self.weights = torch.zeros_like(self.weights) 891 | self.weights[0, 0, 0] = 1e6 892 | 893 | self.weights[1, 1, 1] = 1e6 894 | self.weights[1, 0, 1] = 1e6 895 | 896 | # Test valid weights correspond to adj 897 | desired = self.adj.clone() 898 | # Ensure diagonal (0,0) is set to 0 899 | desired[0, 0, 0] = 0.0 900 | desired[1, 1, 1] = 0.0 901 | # Ensure large, nondiagonal is nonzero 902 | desired[1, 0, 1] = 1.0 903 | 904 | # b_idxs, curr_idx, past_idxs = ([0, 0, 1, 1], [2, 2, 2, 2], [0, 1, 0, 1]) 905 | adj = sample_hard(self.adj, self.weights, self.num_nodes) 906 | if torch.any(desired != adj): 907 | self.fail(f"{desired} != {adj}") 908 | """ 909 | 910 | ''' 911 | def test_weight_to_adj(self): 912 | 
self.b = LearnedEdge(5, torch.nn.Sequential(Sum())) 913 | self.weights = torch.zeros_like(self.weights) 914 | self.weights[0, 2] = 1.0 915 | self.weights[1, 2] = 1.0 916 | self.weights = self.weights.clamp(*self.b.clamp_range) 917 | 918 | desired = self.adj.clone() 919 | desired[0, 2] = 1.0 920 | desired[1, 2] = 1.0 921 | 922 | # b_idxs, curr_idx, past_idxs = ([0, 0, 1, 1], [2, 2, 2, 2], [0, 1, 0, 1]) 923 | self.adj = sample_hard(self.weights, self.b.clamp_range) 924 | """ 925 | self.adj[b_idxs, curr_idx, past_idxs] = sample_hard( 926 | self.weights[b_idxs, curr_idx, past_idxs], self.b.clamp_range 927 | ) 928 | """ 929 | if torch.any(desired != self.adj): 930 | self.fail(f"{desired} != {self.adj}") 931 | ''' 932 | 933 | def test_indexing(self): 934 | self.b = LearnedEdge(5, torch.nn.Sequential(Sum())) 935 | 936 | self.s = DenseGCM(self.g, edge_selectors=self.b) 937 | self.all_obs = [ 938 | torch.ones_like(self.obs) * 0.1, 939 | torch.ones_like(self.obs) * 0.2, 940 | torch.ones_like(self.obs) * 0.3, 941 | ] 942 | self.weights = torch.zeros_like(self.weights) 943 | 944 | (nodes, adj, weights, num_nodes) = ( 945 | self.nodes, 946 | self.adj, 947 | self.weights.clone(), 948 | self.num_nodes, 949 | ) 950 | for i in range(3): 951 | out, (nodes, adj, weights, num_nodes) = self.s( 952 | self.all_obs[i], (nodes, adj, weights, num_nodes) 953 | ) 954 | 955 | # 0: skip 956 | # 1: cat(0.2, 0.1) = 0.3 957 | # 2: cat(0.3, 0.1) = 0.4, cat(0.3, 0.2) = 0.5 958 | 959 | # 0 -> [] 960 | # 1 -> [1,0] 961 | # 2 -> [2,0], [2,1] # does this order matter? 962 | # Network input is __ascending__ e.g. 
[3,0], [3,1], [3,2] 963 | # so 964 | desired = self.weights.clone() 965 | desired[:, 1, 0] = 0.3 966 | desired[:, 2, 0] = 0.4 967 | desired[:, 2, 1] = 0.5 968 | if torch.any(desired != weights): 969 | self.fail(f"{desired} != {weights}") 970 | # b <- a should sum to 6 971 | 972 | def test_grad_prop(self): 973 | self.g.grad_test_var = torch.nn.Parameter(torch.tensor([1.0])) 974 | self.nodes = self.nodes * self.g.grad_test_var 975 | self.assertTrue(self.nodes.requires_grad) 976 | 977 | out, (nodes, adj, weights, num_nodes) = self.s( 978 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 979 | ) 980 | # First run has no gradient as no edges to be made 981 | out, (nodes, adj, weights, num_nodes) = self.s( 982 | self.obs, (nodes, adj, weights, num_nodes) 983 | ) 984 | 985 | loss = torch.norm(out) 986 | dot = torchviz.make_dot(loss, params=dict(self.s.named_parameters())) 987 | # Make sure gradients make it all the way thru node_feats 988 | self.assertTrue("grad_test_var" in dot.source) 989 | 990 | def test_grad_prop2(self): 991 | out, (nodes, adj, weights, num_nodes) = self.s( 992 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 993 | ) 994 | # First run has no gradient as no edges to be made 995 | out, (nodes, adj, weights, num_nodes) = self.s( 996 | self.obs, (nodes, adj, weights, num_nodes) 997 | ) 998 | adj, weights = self.s.edge_selectors(nodes, adj, weights, num_nodes, 5) 999 | self.assertTrue(adj.grad_fn, "Adj has no gradient") 1000 | self.assertTrue(weights.grad_fn, "Weight has no gradient") 1001 | 1002 | def test_backward_multiple_selectors(self): 1003 | selector = torch_geometric.nn.Sequential( 1004 | "nodes, adj, weights, num_nodes, B", 1005 | [ 1006 | ( 1007 | TemporalBackedge(), 1008 | "nodes, adj, weights, num_nodes, B -> adj, weights", 1009 | ), 1010 | (LearnedEdge(5), "nodes, adj, weights, num_nodes, B -> adj, weights"), 1011 | ], 1012 | ) 1013 | self.s = DenseGCM(self.g, edge_selectors=selector) 1014 | out, (nodes, adj, 
weights, num_nodes) = self.s( 1015 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 1016 | ) 1017 | # First run has no gradient as no edges to be made 1018 | out, (nodes, adj, weights, num_nodes) = self.s( 1019 | self.obs, (nodes, adj, weights, num_nodes) 1020 | ) 1021 | adj, weights = self.s.edge_selectors(nodes, adj, weights, num_nodes, 5) 1022 | self.assertTrue(adj.grad_fn, "Adj has no gradient") 1023 | self.assertTrue(weights.grad_fn, "Weight has no gradient") 1024 | adj.mean().backward() 1025 | self.optimizer.step() 1026 | 1027 | def test_backwards(self): 1028 | nodes, adj, weights, num_nodes = ( 1029 | self.nodes, 1030 | self.adj, 1031 | self.weights, 1032 | self.num_nodes + 1, 1033 | ) 1034 | out, (nodes, adj, weights, num_nodes) = self.s( 1035 | self.obs, (nodes, adj, weights, num_nodes) 1036 | ) 1037 | self.s.edge_selectors.zero_grad() 1038 | nodes = torch.rand_like(self.nodes) * 0.00001 1039 | adj, weights = self.s.edge_selectors(nodes, adj, weights, num_nodes, 5) 1040 | adj.mean().backward() 1041 | self.optimizer.step() 1042 | 1043 | """ 1044 | def test_reg_loss(self): 1045 | feats = 5 1046 | batches = 2 1047 | T = 4 1048 | N = 10 1049 | losses = [] 1050 | for i in range(20): 1051 | nodes = torch.arange(batches * N * feats, dtype=torch.float).reshape( 1052 | batches, N, feats 1053 | ) 1054 | obs = torch.ones(batches, feats) 1055 | adj = torch.zeros(batches, N, N) 1056 | weights = torch.ones(batches, N, N) 1057 | num_nodes = torch.zeros(batches, dtype=torch.long) 1058 | 1059 | if i == 0: 1060 | continue 1061 | 1062 | hidden = (nodes, adj, weights, num_nodes) 1063 | for t in range(T): 1064 | obs, hidden = self.s(obs, hidden) 1065 | 1066 | loss = self.s.edge_selectors.compute_full_loss(nodes, nodes.shape[0]) 1067 | loss.backward() 1068 | losses.append(loss) 1069 | 1070 | self.optimizer.step() 1071 | # Must zero grad AND reset density 1072 | self.optimizer.zero_grad() 1073 | 1074 | if not losses[-1] < losses[0]: 1075 | self.fail(f"Final loss 
{losses[-1]} not better than init loss {losses[0]}") 1076 | """ 1077 | 1078 | def test_logit_index(self): 1079 | # Given 3 nodes, make sure we compare node 3 to nodes 1,2 1080 | out, (nodes, adj, weights, num_nodes) = self.s( 1081 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 1082 | ) 1083 | # First run has no gradient as no edges to be made 1084 | out, (nodes, adj, weights, num_nodes) = self.s( 1085 | self.obs, (nodes, adj, weights, num_nodes) 1086 | ) 1087 | adj, weights = self.s.edge_selectors(nodes, adj, weights, num_nodes, 5) 1088 | self.assertTrue(adj.grad_fn, "Adj has no gradient") 1089 | self.assertTrue(weights.grad_fn, "Weight has no gradient") 1090 | 1091 | def test_validate_logits(self): 1092 | # TODO: Fix test 1093 | return 1094 | nodes = self.nodes 1095 | adj = self.adj 1096 | weights = self.weights 1097 | num_nodes = self.num_nodes 1098 | adj, weights = self.s.edge_selectors.compute_logits( 1099 | nodes, weights.clone(), num_nodes, 5 1100 | ) 1101 | adj2, weights2 = self.s.edge_selectors.compute_logits2( 1102 | nodes, weights.clone(), num_nodes, 5 1103 | ) 1104 | if torch.any(adj != adj2): 1105 | self.fail(f"{adj} != {adj2}") 1106 | 1107 | if torch.any(weights != weights): 1108 | self.fail(f"{weights} != {weights2}") 1109 | 1110 | 1111 | class TestSpatialEdge(unittest.TestCase): 1112 | def setUp(self): 1113 | feats = 11 1114 | batches = 5 1115 | N = 10 1116 | conv_type = torch_geometric.nn.DenseGCNConv 1117 | self.g = torch_geometric.nn.Sequential( 1118 | "x, adj, weights, B, N", 1119 | [ 1120 | (conv_type(feats, feats), "x, adj -> x"), 1121 | (torch.nn.ReLU()), 1122 | (conv_type(feats, feats), "x, adj -> x"), 1123 | (torch.nn.ReLU()), 1124 | ], 1125 | ) 1126 | self.slice = slice(0, 2) 1127 | self.s = DenseGCM(self.g, edge_selectors=SpatialEdge(1, self.slice)) 1128 | 1129 | self.nodes = torch.zeros(batches, N, feats, dtype=torch.float) 1130 | self.obs = torch.zeros(batches, feats) 1131 | self.adj = torch.zeros(batches, N, N, 
dtype=torch.float) 1132 | self.weights = torch.ones(batches, N, N) 1133 | self.num_nodes = torch.ones(batches, dtype=torch.long) 1134 | 1135 | def test_zero_dist(self): 1136 | # Start num_nodes = 1 1137 | self.nodes[:] = 1 1138 | self.nodes[:, 0:2, self.slice] = 0 1139 | _, (nodes, adj, weights, num_nodes) = self.s( 1140 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 1141 | ) 1142 | tgt_adj = torch.zeros_like(adj, dtype=torch.long) 1143 | B_idx = torch.arange(num_nodes.shape[0]) 1144 | # It adds self edge 1145 | # tgt_adj[:, 0, 1] = 1 1146 | # tgt_adj[:, 1, 0] = 1 1147 | tgt_adj[B_idx, num_nodes[B_idx] - 1, 0] = 1 1148 | 1149 | B_idx = torch.arange(num_nodes.shape[0]) 1150 | 1151 | # TODO: Ensure not off by one 1152 | if torch.any(tgt_adj != adj): 1153 | self.fail(f"{tgt_adj} != {adj}") 1154 | 1155 | def test_one_dist(self): 1156 | # Start num_nodes = 1 1157 | self.nodes[:, 0, self.slice] = 1 1158 | _, (nodes, adj, weights, num_nodes) = self.s( 1159 | self.obs, (self.nodes, self.adj, self.weights, self.num_nodes) 1160 | ) 1161 | tgt_adj = torch.zeros_like(adj, dtype=torch.long) 1162 | # It adds self edge 1163 | 1164 | # TODO: Ensure not off by one 1165 | if torch.any(tgt_adj != adj): 1166 | self.fail(f"{tgt_adj} != {adj}") 1167 | 1168 | 1169 | if __name__ == "__main__": 1170 | unittest.main() 1171 | -------------------------------------------------------------------------------- /tests/test_nav_gcm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torch_geometric 4 | from collections import OrderedDict 5 | import gym 6 | 7 | from gcm.nav_gcm import NavGCM 8 | 9 | class IdentGNN(torch.nn.Module): 10 | def forward(self, x, edges, rot, pos, batch, front_ptr, back_ptr, flat_new_idx): 11 | return x 12 | 13 | class GNN(torch.nn.Module): 14 | def __init__(self, size): 15 | super().__init__() 16 | self.gc = torch_geometric.nn.GraphConv(size, 1) 17 | 18 | def forward(self, 
x, edges, rot, pos, batch, front_ptr, back_ptr, flat_new_idx): 19 | self.edges = edges 20 | self.x = x 21 | self.rot = rot 22 | self.pos = pos 23 | self.batch = batch 24 | self.front_ptr = front_ptr 25 | self.back_ptr = back_ptr 26 | #return torch.cat([x, pos, rot], dim=-1) 27 | self.out = self.gc(torch.cat([x, pos, rot], dim=-1), edges) 28 | return self.out 29 | 30 | class TestComputeIdx(unittest.TestCase): 31 | def setUp(self): 32 | self.gcm = NavGCM(gnn=IdentGNN()) 33 | 34 | def test_ragged(self): 35 | taus = torch.tensor([2,3], dtype=torch.long) 36 | T = torch.tensor([1,2], dtype=torch.long) 37 | self.gcm.compute_idx(T, taus) 38 | 39 | t_idx = torch.tensor([0, 1, 2, 0, 1, 2, 3, 4]) 40 | b_idx = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1]) 41 | self.assertTrue(torch.all(self.gcm.idx[0] == b_idx)) 42 | self.assertTrue(torch.all(self.gcm.idx[1] == t_idx)) 43 | 44 | b_new_idx = torch.tensor([0, 0, 1, 1, 1]) 45 | t_new_idx = torch.tensor([1, 2, 2, 3, 4]) 46 | self.assertTrue(torch.all(self.gcm.new_idx[0] == b_new_idx)) 47 | self.assertTrue(torch.all(self.gcm.new_idx[1] == t_new_idx)) 48 | 49 | # t_idx = torch.tensor([0, 1, 2, 0, 1, 2, 3, 4]) 50 | # b_idx = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1]) 51 | # f_idx = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) 52 | # new N N N N N 53 | flat_new_idx = torch.tensor([1, 2, 5, 6, 7]) 54 | self.assertTrue(torch.all(self.gcm.flat_new_idx == flat_new_idx)) 55 | 56 | b_out_idx = torch.tensor([0, 0, 1, 1, 1]) 57 | t_out_idx = torch.tensor([0, 1, 0, 1, 2]) 58 | self.assertTrue(torch.all(self.gcm.out_idx[0] == b_out_idx)) 59 | self.assertTrue(torch.all(self.gcm.out_idx[1] == t_out_idx)) 60 | 61 | back_ptr = torch.tensor([2, 7]) 62 | self.assertTrue(torch.all(self.gcm.back_ptr == back_ptr)) 63 | 64 | front_ptr = torch.tensor([0, 3]) 65 | self.assertTrue(torch.all(self.gcm.front_ptr == front_ptr)) 66 | 67 | def test_base_case(self): 68 | taus = torch.tensor([1,1,1], dtype=torch.long) 69 | T = torch.tensor([0,0,0], dtype=torch.long) 70 | 
self.gcm.compute_idx(T, taus) 71 | 72 | t_idx = torch.tensor([0, 0, 0]) 73 | b_idx = torch.tensor([0, 1, 2]) 74 | self.assertTrue(torch.all(self.gcm.idx[0] == b_idx)) 75 | self.assertTrue(torch.all(self.gcm.idx[1] == t_idx)) 76 | 77 | b_new_idx = torch.tensor([0, 1, 2]) 78 | t_new_idx = torch.tensor([0, 0, 0]) 79 | self.assertTrue(torch.all(self.gcm.new_idx[0] == b_new_idx)) 80 | self.assertTrue(torch.all(self.gcm.new_idx[1] == t_new_idx)) 81 | 82 | # t_idx = torch.tensor([0, 0]) 83 | # b_idx = torch.tensor([0, 1]) 84 | # f_idx = torch.tensor([0, 1]) 85 | # new N N 86 | flat_new_idx = torch.tensor([0, 1, 2]) 87 | self.assertTrue(torch.all(self.gcm.flat_new_idx == flat_new_idx)) 88 | 89 | b_out_idx = torch.tensor([0, 1, 2]) 90 | t_out_idx = torch.tensor([0, 0, 0]) 91 | self.assertTrue(torch.all(self.gcm.out_idx[0] == b_out_idx)) 92 | self.assertTrue(torch.all(self.gcm.out_idx[1] == t_out_idx)) 93 | 94 | back_ptr = torch.tensor([0, 1, 2]) 95 | self.assertTrue(torch.all(self.gcm.back_ptr == back_ptr)) 96 | 97 | front_ptr = torch.tensor([0, 1, 2]) 98 | self.assertTrue(torch.all(self.gcm.front_ptr == front_ptr)) 99 | 100 | def test_inference(self): 101 | taus = torch.tensor([1], dtype=torch.long) 102 | T = torch.tensor([2], dtype=torch.long) 103 | self.gcm.compute_idx(T, taus) 104 | 105 | t_idx = torch.tensor([0, 1, 2]) 106 | b_idx = torch.tensor([0, 0, 0]) 107 | self.assertTrue(torch.all(self.gcm.idx[0] == b_idx)) 108 | self.assertTrue(torch.all(self.gcm.idx[1] == t_idx)) 109 | 110 | b_new_idx = torch.tensor([0]) 111 | t_new_idx = torch.tensor([2]) 112 | self.assertTrue(torch.all(self.gcm.new_idx[0] == b_new_idx)) 113 | self.assertTrue(torch.all(self.gcm.new_idx[1] == t_new_idx)) 114 | 115 | # t_idx = torch.tensor([0, 1, 2]) 116 | # b_idx = torch.tensor([0, 0, 0]) 117 | # f_idx = torch.tensor([0, 1, 2]) 118 | # new N 119 | flat_new_idx = torch.tensor([2]) 120 | self.assertTrue(torch.all(self.gcm.flat_new_idx == flat_new_idx)) 121 | 122 | b_out_idx = torch.tensor([0]) 
123 | t_out_idx = torch.tensor([0]) 124 | self.assertTrue(torch.all(self.gcm.out_idx[0] == b_out_idx)) 125 | self.assertTrue(torch.all(self.gcm.out_idx[1] == t_out_idx)) 126 | 127 | back_ptr = torch.tensor([2]) 128 | self.assertTrue(torch.all(self.gcm.back_ptr == back_ptr)) 129 | 130 | front_ptr = torch.tensor([0]) 131 | self.assertTrue(torch.all(self.gcm.front_ptr == front_ptr)) 132 | 133 | class TestUpdate(unittest.TestCase): 134 | def setUp(self): 135 | self.gcm = NavGCM(gnn=IdentGNN(), causal=True) 136 | 137 | def test_ragged(self): 138 | taus = torch.tensor([2,3], dtype=torch.long) 139 | T = torch.tensor([1,2], dtype=torch.long) 140 | 141 | x = torch.zeros((2, 10, 1)) 142 | pos = torch.zeros((2, 10, 2)) 143 | rot = torch.zeros((2, 10, 1)) 144 | x_in = torch.ones((2, 3, 1)) 145 | pos_in = torch.ones((2, 3, 2)) 146 | rot_in = torch.ones((2, 3, 1)) 147 | 148 | tgt_x = x.clone() 149 | tgt_pos = pos.clone() 150 | tgt_rot = rot.clone() 151 | 152 | tgt_x[0, 1:3] = 1 153 | tgt_x[1, 2:5] = 1 154 | 155 | tgt_pos[0, 1:3] = 1 156 | tgt_pos[1, 2:5] = 1 157 | 158 | tgt_rot[0, 1:3] = 1 159 | tgt_rot[1, 2:5] = 1 160 | 161 | self.gcm.compute_idx(T, taus) 162 | new_x, new_pos, new_rot = self.gcm.update( 163 | x_in, pos_in, rot_in, 164 | x, pos, rot, 165 | T, taus 166 | ) 167 | 168 | self.assertTrue(torch.all(tgt_x == new_x)) 169 | self.assertTrue(torch.all(tgt_pos == new_pos)) 170 | self.assertTrue(torch.all(tgt_rot == new_rot)) 171 | 172 | class TestE2E(unittest.TestCase): 173 | def setUp(self): 174 | self.gcm = NavGCM(gnn=GNN(4), causal=True, max_verts=8, r=3, edge_method="radius") 175 | 176 | def test_e2e_one_batch(self): 177 | taus = torch.tensor([8], dtype=torch.long) 178 | T = torch.tensor([0], dtype=torch.long) 179 | obs = torch.arange( 8* 1).reshape(1, 8, 1).float() 180 | pos = torch.arange( 8* 2).reshape(1, 8, 2).float() 181 | rot = torch.arange( 8* 1).reshape(1, 8, 1).float() 182 | state = [ 183 | torch.zeros(1, 8, 1), 184 | torch.zeros(1, 8, 2), 185 | torch.zeros(1, 
8, 1), 186 | T 187 | ] 188 | inf_state = [s.clone() for s in state] 189 | train_output, train_state = self.gcm(obs, pos, rot, taus, state) 190 | train_edges = self.gcm.gnn.edges.clone() 191 | train_x = self.gcm.gnn.x.clone() 192 | train_pos = self.gcm.gnn.pos.clone() 193 | train_rot = self.gcm.gnn.rot.clone() 194 | train_out = self.gcm.gnn.out.clone() 195 | train_batch = self.gcm.gnn.batch.clone() 196 | 197 | inf_output = [] 198 | taus = torch.tensor([1], dtype=torch.long) 199 | for i in range(8): 200 | output, inf_state = self.gcm( 201 | obs[:,i,None], pos[:,i,None], rot[:,i,None], taus, inf_state 202 | ) 203 | if not torch.allclose(output, train_output[:,i,None]): 204 | self.fail(f"{i}: {output} != {train_output[:,i,None]}") 205 | inf_output.append(output) 206 | inf_output = torch.cat(inf_output, dim=1) 207 | for i in range(len(train_state)): 208 | if not torch.all(train_state[i] == inf_state[i]): 209 | self.fail(f"{i}: {train_state[i]} != {inf_state[i]}") 210 | inf_edges = self.gcm.gnn.edges.clone() 211 | inf_x = self.gcm.gnn.x.clone() 212 | inf_pos = self.gcm.gnn.pos.clone() 213 | inf_rot = self.gcm.gnn.rot.clone() 214 | inf_out = self.gcm.gnn.out.clone() 215 | inf_batch = self.gcm.gnn.batch.clone() 216 | self.assertTrue(torch.all(train_edges == inf_edges)) 217 | self.assertTrue(torch.all(train_x == inf_x)) 218 | self.assertTrue(torch.all(train_pos == inf_pos)) 219 | self.assertTrue(torch.all(train_rot == inf_rot)) 220 | self.assertTrue(torch.all(train_out == inf_out)) 221 | self.assertTrue(torch.all(train_batch == inf_batch)) 222 | self.assertTrue(torch.allclose(inf_output, train_output)) 223 | 224 | def test_e2e_multi_batch(self): 225 | taus = torch.tensor([8, 8], dtype=torch.long) 226 | T = torch.tensor([0, 0], dtype=torch.long) 227 | obs = torch.arange(2* 8* 1).reshape(2, 8, 1).float() 228 | pos = torch.arange(2* 8* 2).reshape(2, 8, 2).float() 229 | rot = torch.arange(2* 8* 1).reshape(2, 8, 1).float() 230 | state = [ 231 | torch.zeros(2, 8, 1), 232 | 
torch.zeros(2, 8, 2), 233 | torch.zeros(2, 8, 1), 234 | T 235 | ] 236 | inf_state = [s.clone() for s in state] 237 | train_output, train_state = self.gcm(obs, pos, rot, taus, state) 238 | train_edges = self.gcm.gnn.edges.clone() 239 | train_x = self.gcm.gnn.x.clone() 240 | train_pos = self.gcm.gnn.pos.clone() 241 | train_rot = self.gcm.gnn.rot.clone() 242 | train_out = self.gcm.gnn.out.clone() 243 | train_batch = self.gcm.gnn.batch.clone() 244 | 245 | inf_output = [] 246 | taus = torch.tensor([1, 1], dtype=torch.long) 247 | for i in range(8): 248 | output, inf_state = self.gcm( 249 | obs[:,i,None], pos[:,i,None], rot[:,i,None], taus, inf_state 250 | ) 251 | if not torch.allclose(output, train_output[:,i,None]): 252 | self.fail(f"{i}: {output} != {train_output[:,i,None]}") 253 | inf_output.append(output) 254 | inf_output = torch.cat(inf_output, dim=1) 255 | for i in range(len(train_state)): 256 | if not torch.all(train_state[i] == inf_state[i]): 257 | self.fail(f"{i}: {train_state[i]} != {inf_state[i]}") 258 | inf_edges = self.gcm.gnn.edges.clone() 259 | inf_x = self.gcm.gnn.x.clone() 260 | inf_pos = self.gcm.gnn.pos.clone() 261 | inf_rot = self.gcm.gnn.rot.clone() 262 | inf_out = self.gcm.gnn.out.clone() 263 | inf_batch = self.gcm.gnn.batch.clone() 264 | self.assertTrue(torch.all(train_edges == inf_edges)) 265 | self.assertTrue(torch.all(train_x == inf_x)) 266 | self.assertTrue(torch.all(train_pos == inf_pos)) 267 | self.assertTrue(torch.all(train_rot == inf_rot)) 268 | self.assertTrue(torch.all(train_out == inf_out)) 269 | self.assertTrue(torch.all(train_batch == inf_batch)) 270 | self.assertTrue(torch.allclose(inf_output, train_output)) 271 | -------------------------------------------------------------------------------- /tests/test_ray_gcm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torch_geometric 4 | import ray 5 | from ray import tune 6 | 7 | from gcm.ray_gcm import 
RayDenseGCM 8 | from gcm.edge_selectors.temporal import TemporalBackedge 9 | 10 | 11 | class TestRaySanity(unittest.TestCase): 12 | def test_one_iter(self): 13 | hidden = 32 14 | ray.init( 15 | local_mode=True, 16 | object_store_memory=3e10, 17 | ) 18 | dgc = torch_geometric.nn.Sequential( 19 | "x, adj, weights, B, N", 20 | [ 21 | # Mean and sum aggregation perform roughly the same 22 | # Preprocessor with 1 layer did not help 23 | (torch_geometric.nn.DenseGraphConv(hidden, hidden), "x, adj -> x"), 24 | (torch.nn.Tanh()), 25 | (torch_geometric.nn.DenseGraphConv(hidden, hidden), "x, adj -> x"), 26 | (torch.nn.Tanh()), 27 | ], 28 | ) 29 | cfg = { 30 | "framework": "torch", 31 | "num_gpus": 0, 32 | "env": "CartPole-v0", 33 | "num_workers": 0, 34 | "model": { 35 | "custom_model": RayDenseGCM, 36 | "custom_model_config": { 37 | "graph_size": 32, 38 | "gnn_input_size": hidden, 39 | "gnn_output_size": hidden, 40 | "gnn": dgc, 41 | "edge_selectors": TemporalBackedge([1]), 42 | "edge_weights": False, 43 | }, 44 | }, 45 | } 46 | tune.run("A2C", config=cfg, stop={"info/num_steps_trained": 100}) 47 | -------------------------------------------------------------------------------- /tests/test_sparse_gcm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torch_geometric 4 | from collections import OrderedDict 5 | import gym 6 | 7 | from gcm.gcm import DenseGCM 8 | from gcm.edge_selectors.temporal import TemporalBackedge 9 | from gcm.sparse_edge_selectors.temporal import TemporalEdge 10 | from gcm.edge_selectors.learned import LearnedEdge as DLearnedEdge 11 | from gcm.sparse_edge_selectors.learned import LearnedEdge as SLearnedEdge 12 | from gcm import util 13 | from gcm.sparse_gcm import SparseGCM 14 | from gcm import ray_sparse_gcm 15 | 16 | 17 | class TestFlattenAdj(unittest.TestCase): 18 | def setUp(self): 19 | self.F = 4 20 | self.B = 4 21 | self.T = torch.tensor([1, 2, 0, 0]) 22 | self.taus = 
torch.zeros(4, dtype=torch.long) 23 | self.graph_size = 5 24 | self.max_edges = 6 25 | 26 | def test_flatten_unflatten(self): 27 | adj = torch.zeros(self.B, 2, self.max_edges) 28 | adj[1, 0, 1] = 1.0 29 | adj[1, 1, 2] = 2.0 30 | sparse_adj = adj.to_sparse() 31 | e, w, b = util.flatten_adj(sparse_adj, self.T, self.taus, self.B) 32 | new_sparse_adj = util.unflatten_adj(e, w, b, self.T, self.taus, self.B, self.max_edges) 33 | 34 | if not torch.all( 35 | sparse_adj._indices() == new_sparse_adj._indices() 36 | ): 37 | self.fail( 38 | f"\n{sparse_adj._indices()} != \n{new_sparse_adj._indices()}" 39 | ) 40 | if not torch.all( 41 | sparse_adj.coalesce().values() == new_sparse_adj.coalesce().values() 42 | ): 43 | self.fail( 44 | f"{sparse_adj.coalesce().values()} != {sparse_adj.coalesce().values()}" 45 | ) 46 | 47 | def test_flatten_unflatten2(self): 48 | self.T = torch.tensor([1, 5, 2, 1]) 49 | self.taus = torch.zeros(4, dtype=torch.long) 50 | adj = torch.zeros(self.B, 2, self.max_edges) 51 | adj[1, 0, 1] = 1.0 52 | adj[1, 1, 2] = 2.0 53 | adj[1, 1, 3] = 3.0 54 | adj[2, 0, 1] = 4.0 55 | sparse_adj = adj.to_sparse() 56 | e, w, b = util.flatten_adj(sparse_adj, self.T, self.taus, self.B) 57 | new_sparse_adj = util.unflatten_adj(e, w, b, self.T, self.taus, self.B, self.max_edges) 58 | 59 | if not torch.all( 60 | sparse_adj._indices() == new_sparse_adj._indices() 61 | ): 62 | self.fail( 63 | f"\n{sparse_adj._indices()} != \n{new_sparse_adj._indices()}" 64 | ) 65 | if not torch.all( 66 | sparse_adj.coalesce().values() == new_sparse_adj.coalesce().values() 67 | ): 68 | self.fail( 69 | f"{sparse_adj.coalesce().values()} != {sparse_adj.coalesce().values()}" 70 | ) 71 | 72 | 73 | 74 | class TestPack(unittest.TestCase): 75 | def setUp(self): 76 | self.F = 4 77 | self.B = 3 78 | self.T = torch.tensor([2, 3, 0]) 79 | self.graph_size = 5 80 | self.max_edges = 10 81 | 82 | def test_unpack_pack_empty(self): 83 | self.T = torch.tensor([0, 0, 0]) 84 | nodes = torch.zeros(self.B, 
self.graph_size, self.F) 85 | dense_edge = torch.empty(self.B, 2, self.max_edges, dtype=torch.long).fill_(-1) 86 | dense_weight = torch.empty(self.B, 1, self.max_edges).fill_(1.0) 87 | 88 | initial_packed_hidden = ( 89 | nodes.clone(), 90 | dense_edge.clone(), 91 | dense_weight.clone(), 92 | self.T.clone(), 93 | ) 94 | 95 | packed_hidden = (nodes, dense_edge, dense_weight, self.T) 96 | unpacked_hidden = util.unpack_hidden(packed_hidden, self.B) 97 | repacked_hidden = util.pack_hidden(unpacked_hidden, self.B, self.max_edges) 98 | for i in range(len(initial_packed_hidden)): 99 | if not (initial_packed_hidden[i] == repacked_hidden[i]).all(): 100 | self.fail( 101 | f"packed hidden tensor {i} != repacked hidden tensor" 102 | "{initial_packed_hidden[i]) != {repacked_hidden[i]}" 103 | ) 104 | 105 | def test_unpack_pack(self): 106 | nodes = torch.zeros(self.B, self.graph_size, self.F) 107 | 108 | dense_edge = torch.empty(self.B, 2, self.max_edges, dtype=torch.long).fill_(-1) 109 | dense_edge[0, :, 0] = torch.tensor([0, 1]) 110 | dense_edge[1, :, 0] = torch.tensor([0, 1]) 111 | dense_edge[1, :, 1] = torch.tensor([1, 2]) 112 | 113 | dense_weight = torch.empty(self.B, 1, self.max_edges).fill_(1.0) 114 | dense_weight[0, 0, 0] = 0.5 115 | dense_weight[1, 0, 0] = 0.33 116 | dense_weight[1, 0, 1] = 0.25 117 | initial_packed_hidden = ( 118 | nodes.clone(), 119 | dense_edge.clone(), 120 | dense_weight.clone(), 121 | self.T.clone(), 122 | ) 123 | 124 | packed_hidden = (nodes, dense_edge, dense_weight, self.T) 125 | unpacked_hidden = util.unpack_hidden(packed_hidden, self.B) 126 | repacked_hidden = util.pack_hidden(unpacked_hidden, self.B, self.max_edges) 127 | 128 | for i in range(len(initial_packed_hidden)): 129 | if not (initial_packed_hidden[i] == repacked_hidden[i]).all(): 130 | self.fail( 131 | f"packed hidden tensor {i} != repacked hidden tensor" 132 | "{initial_packed_hidden[i]) != {repacked_hidden[i]}" 133 | ) 134 | 135 | def test_unpack_pack_one_batch(self): 136 | self.B = 
1 137 | self.T = torch.tensor([3]) 138 | nodes = torch.zeros(self.B, self.graph_size, self.F) 139 | 140 | dense_edge = torch.empty(self.B, 2, self.max_edges, dtype=torch.long).fill_(-1) 141 | dense_edge[0, :, 0] = torch.tensor([0, 1]) 142 | dense_edge[0, :, 1] = torch.tensor([0, 2]) 143 | dense_edge[0, :, 2] = torch.tensor([1, 2]) 144 | 145 | dense_weight = torch.empty(self.B, 1, self.max_edges).fill_(1.0) 146 | dense_weight[0, 0, 0] = 0.5 147 | dense_weight[0, 0, 1] = 0.25 148 | initial_packed_hidden = ( 149 | nodes.clone(), 150 | dense_edge.clone(), 151 | dense_weight.clone(), 152 | self.T.clone(), 153 | ) 154 | 155 | packed_hidden = (nodes, dense_edge, dense_weight, self.T) 156 | unpacked_hidden = util.unpack_hidden(packed_hidden, self.B) 157 | repacked_hidden = util.pack_hidden(unpacked_hidden, self.B, self.max_edges) 158 | 159 | for i in range(len(initial_packed_hidden)): 160 | if not (initial_packed_hidden[i] == repacked_hidden[i]).all(): 161 | self.fail( 162 | f"packed hidden tensor {i} != repacked hidden tensor" 163 | "{initial_packed_hidden[i]) != {repacked_hidden[i]}" 164 | ) 165 | 166 | def test_unpack_pack_full_empty(self): 167 | self.B = 3 168 | self.T = torch.tensor([5, 0, 0]) 169 | nodes = torch.zeros(self.B, self.graph_size, self.F) 170 | 171 | dense_edge = torch.empty(self.B, 2, self.max_edges, dtype=torch.long).fill_(-1) 172 | dense_edge[0, :, 0] = torch.tensor([0, 1]) 173 | dense_edge[0, :, 1] = torch.tensor([0, 2]) 174 | dense_edge[0, :, 2] = torch.tensor([0, 3]) 175 | dense_edge[0, :, 3] = torch.tensor([0, 4]) 176 | 177 | dense_edge[0, :, 3] = torch.tensor([1, 2]) 178 | dense_edge[0, :, 4] = torch.tensor([1, 3]) 179 | dense_edge[0, :, 5] = torch.tensor([1, 4]) 180 | 181 | dense_edge[0, :, 6] = torch.tensor([2, 3]) 182 | dense_edge[0, :, 7] = torch.tensor([2, 4]) 183 | 184 | dense_edge[0, :, 8] = torch.tensor([3, 4]) 185 | 186 | dense_weight = torch.empty(self.B, 1, self.max_edges).fill_(1.0) 187 | dense_weight[0, 0, 0] = 0.5 188 | 
dense_weight[0, 0, 5] = 0.25 189 | initial_packed_hidden = ( 190 | nodes.clone(), 191 | dense_edge.clone(), 192 | dense_weight.clone(), 193 | self.T.clone(), 194 | ) 195 | 196 | packed_hidden = (nodes, dense_edge, dense_weight, self.T) 197 | unpacked_hidden = util.unpack_hidden(packed_hidden, self.B) 198 | repacked_hidden = util.pack_hidden(unpacked_hidden, self.B, self.max_edges) 199 | 200 | for i in range(len(initial_packed_hidden)): 201 | if not (initial_packed_hidden[i] == repacked_hidden[i]).all(): 202 | self.fail( 203 | f"packed hidden tensor {i} != repacked hidden tensor" 204 | "{initial_packed_hidden[i]) != {repacked_hidden[i]}" 205 | ) 206 | 207 | def test_unpack_pack_ragged(self): 208 | self.B = 3 209 | self.T = torch.tensor([5, 4, 3]) 210 | nodes = torch.zeros(self.B, self.graph_size, self.F) 211 | 212 | dense_edge = torch.empty(self.B, 2, self.max_edges, dtype=torch.long).fill_(-1) 213 | dense_edge[0, :, 0] = torch.tensor([0, 1]) 214 | dense_edge[0, :, 1] = torch.tensor([0, 2]) 215 | dense_edge[0, :, 2] = torch.tensor([1, 2]) 216 | dense_edge[0, :, 3] = torch.tensor([2, 3]) 217 | dense_edge[0, :, 4] = torch.tensor([3, 4]) 218 | 219 | dense_edge[1, :, 0] = torch.tensor([0, 1]) 220 | dense_edge[1, :, 1] = torch.tensor([0, 2]) 221 | dense_edge[1, :, 2] = torch.tensor([1, 2]) 222 | dense_edge[1, :, 3] = torch.tensor([1, 3]) 223 | 224 | dense_edge[2, :, 0] = torch.tensor([0, 1]) 225 | dense_edge[2, :, 1] = torch.tensor([0, 2]) 226 | 227 | dense_weight = torch.empty(self.B, 1, self.max_edges).fill_(1.0) 228 | dense_weight[0, 0, 0] = 0.5 229 | dense_weight[0, 0, 1] = 0.25 230 | 231 | dense_weight[1, 0, 0] = 0.1 232 | dense_weight[1, 0, 1] = 0.2 233 | 234 | initial_packed_hidden = ( 235 | nodes.clone(), 236 | dense_edge.clone(), 237 | dense_weight.clone(), 238 | self.T.clone(), 239 | ) 240 | 241 | packed_hidden = (nodes, dense_edge, dense_weight, self.T) 242 | unpacked_hidden = util.unpack_hidden(packed_hidden, self.B) 243 | repacked_hidden = 
util.pack_hidden(unpacked_hidden, self.B, self.max_edges) 244 | 245 | for i in range(len(initial_packed_hidden)): 246 | if not (initial_packed_hidden[i] == repacked_hidden[i]).all(): 247 | self.fail( 248 | f"packed hidden tensor {i} != repacked hidden tensor" 249 | f"\n{initial_packed_hidden[i]} != \n\n{repacked_hidden[i]}" 250 | ) 251 | 252 | def test_unpack_pack_many(self): 253 | nodes = torch.zeros(self.B, self.graph_size, self.F) 254 | 255 | dense_edge = torch.empty(self.B, 2, self.max_edges, dtype=torch.long).fill_(-1) 256 | dense_edge[0, :, 0] = torch.tensor([0, 1]) 257 | dense_edge[1, :, 0] = torch.tensor([0, 1]) 258 | dense_edge[1, :, 1] = torch.tensor([1, 2]) 259 | 260 | dense_weight = torch.empty(self.B, 1, self.max_edges).fill_(1.0) 261 | dense_weight[0, 0, 0] = 0.5 262 | dense_weight[1, 0, 1] = 0.25 263 | initial_packed_hidden = ( 264 | nodes.clone(), 265 | dense_edge.clone(), 266 | dense_weight.clone(), 267 | self.T.clone(), 268 | ) 269 | 270 | packed_hidden = (nodes, dense_edge, dense_weight, self.T) 271 | for i in range(10): 272 | unpacked_hidden = util.unpack_hidden(packed_hidden, self.B) 273 | packed_hidden = util.pack_hidden(unpacked_hidden, self.B, self.max_edges) 274 | 275 | for i in range(len(initial_packed_hidden)): 276 | if not (initial_packed_hidden[i] == packed_hidden[i]).all(): 277 | self.fail( 278 | f"packed hidden tensor {i} != repacked hidden tensor" 279 | "{initial_packed_hidden[i]) != {repacked_hidden[i]}" 280 | ) 281 | 282 | def test_unpack_empty(self): 283 | nodes = torch.zeros(self.B, self.graph_size, self.F) 284 | 285 | dense_edge = torch.empty(self.B, 2, self.max_edges, dtype=torch.long).fill_(-1) 286 | 287 | dense_weight = torch.empty(self.B, 1, self.max_edges).fill_(1.0) 288 | initial_packed_hidden = ( 289 | nodes.clone(), 290 | dense_edge.clone(), 291 | dense_weight.clone(), 292 | self.T.clone(), 293 | ) 294 | 295 | packed_hidden = (nodes, dense_edge, dense_weight, self.T) 296 | unpacked_hidden = 
util.unpack_hidden(packed_hidden, self.B) 297 | repacked_hidden = util.pack_hidden(unpacked_hidden, self.B, self.max_edges) 298 | 299 | for i in range(len(initial_packed_hidden)): 300 | if not (initial_packed_hidden[i] == repacked_hidden[i]).all(): 301 | self.fail( 302 | f"packed hidden tensor {i} != repacked hidden tensor" 303 | "{initial_packed_hidden[i]) != {repacked_hidden[i]}" 304 | ) 305 | 306 | 307 | class TestDenseVsSparse(unittest.TestCase): 308 | def setUp(self): 309 | self.F = 3 310 | dense_conv_type = torch_geometric.nn.DenseGraphConv 311 | sparse_conv_type = torch_geometric.nn.GraphConv 312 | self.dense_g = torch_geometric.nn.Sequential( 313 | "x, adj, weights, B, N", 314 | [ 315 | (dense_conv_type(self.F, self.F), "x, adj -> x"), 316 | (dense_conv_type(self.F, self.F), "x, adj -> x"), 317 | ], 318 | ) 319 | self.sparse_g = torch_geometric.nn.Sequential( 320 | "x, edges, weights", 321 | [ 322 | (sparse_conv_type(self.F, self.F), "x, edges, weights -> x"), 323 | (sparse_conv_type(self.F, self.F), "x, edges, weights -> x"), 324 | ], 325 | ) 326 | dense_params = self.dense_g.state_dict() 327 | self.sparse_g.load_state_dict(dense_params) 328 | # sanity check 329 | for k, v in self.sparse_g.state_dict().items(): 330 | self.assertTrue((v == self.dense_g.state_dict()[k]).all()) 331 | 332 | self.dense_gcm = DenseGCM(self.dense_g, graph_size=8) 333 | self.sparse_gcm = SparseGCM(self.sparse_g, graph_size=8) 334 | 335 | def test_gnn_no_weights(self): 336 | sparse_out = self.sparse_g( 337 | torch.ones(2, 3, self.F), torch.ones(2, 0, dtype=torch.long), torch.ones(0) 338 | ) 339 | dense_out = self.dense_g( 340 | torch.ones(2, 3, self.F), 341 | torch.zeros(2, 3, 3), 342 | torch.ones(2, 3, 3), 343 | None, 344 | None, 345 | ) 346 | 347 | self.assertTrue((sparse_out == dense_out).all()) 348 | 349 | def test_no_edges(self): 350 | F = self.F 351 | B = 3 352 | ts = 4 353 | self.obs = torch.arange(B * ts * F, dtype=torch.float32).reshape(B, ts, F) 354 | 355 | dense_outs = 
[] 356 | 357 | hidden = None 358 | for i in range(ts): 359 | dense_out, hidden = self.dense_gcm(self.obs[:, i], hidden) 360 | dense_outs.append(dense_out) 361 | dense_outs = torch.stack(dense_outs, dim=1) 362 | 363 | # One step at at ime 364 | self.sparse_step = SparseGCM(self.sparse_g, graph_size=8) 365 | sparse_step_hidden = None 366 | sparse_step_outs = [] 367 | for i in range(ts): 368 | taus = torch.ones(B, dtype=torch.long) 369 | sparse_step_out, sparse_step_hidden = self.sparse_gcm( 370 | self.obs[:, i].unsqueeze(1), taus, sparse_step_hidden 371 | ) 372 | sparse_step_outs.append(sparse_step_out) 373 | sparse_step_outs = torch.cat(sparse_step_outs, dim=1) 374 | 375 | # All at once 376 | taus = torch.ones(B, dtype=torch.long) * ts 377 | sparse_outs, sparse_hidden = self.sparse_gcm(self.obs, taus, None) 378 | 379 | if dense_outs.numel() != sparse_outs.numel(): 380 | self.fail(f"sizes {dense_outs.numel()} != {sparse_outs.numel()}") 381 | 382 | # Check hiddens 383 | if not torch.all(hidden[0] == sparse_hidden[0]): 384 | self.fail(f"{hidden[0]} != {sparse_hidden[0]}") 385 | 386 | if not torch.all(hidden[0] == sparse_step_hidden[0]): 387 | self.fail(f"{hidden[0]} != {sparse_step_hidden[0]}") 388 | 389 | if not torch.all(dense_outs == sparse_outs): 390 | self.fail(f"{dense_outs} != {sparse_outs}") 391 | 392 | if not torch.all(dense_outs == sparse_step_outs): 393 | self.fail(f"{dense_outs} != {sparse_step_outs}") 394 | 395 | def test_temporal_edges(self): 396 | self.dense_gcm = DenseGCM( 397 | self.dense_g, edge_selectors=TemporalBackedge([1, 2]), graph_size=8 398 | ) 399 | self.sparse_gcm = SparseGCM( 400 | self.sparse_g, edge_selectors=TemporalEdge([1, 2]), graph_size=8 401 | ) 402 | F = self.F 403 | B = 3 404 | ts = 8 405 | self.obs = torch.arange(B * ts * F, dtype=torch.float32).reshape(B, ts, F) 406 | 407 | dense_outs = [] 408 | 409 | dense_hidden = None 410 | for i in range(ts): 411 | dense_out, dense_hidden = self.dense_gcm(self.obs[:, i], dense_hidden) 412 | 
dense_outs.append(dense_out) 413 | dense_outs = torch.stack(dense_outs, dim=1) 414 | 415 | taus = torch.ones(B, dtype=torch.long) * ts 416 | sparse_outs, sparse_hidden = self.sparse_gcm(self.obs, taus, None) 417 | 418 | if dense_outs.numel() != sparse_outs.numel(): 419 | self.fail(f"sizes {dense_outs.numel()} != {sparse_outs.numel()}") 420 | 421 | # Check hiddens 422 | if not torch.all(dense_hidden[0] == sparse_hidden[0]): 423 | self.fail(f"{dense_hidden[0]} != {sparse_hidden[0]}") 424 | 425 | if not torch.all(dense_hidden[1].nonzero().T == sparse_hidden[1].coalesce().indices()): 426 | self.fail(f"dense and sparse edges inequal: \n{dense_hidden[1].nonzero().T} != \n{sparse_hidden[1]._indices()}") 427 | 428 | if not torch.all(dense_outs == sparse_outs): 429 | self.fail(f"{dense_outs} != {sparse_outs}") 430 | 431 | 432 | def test_temporal_edges_2_hop(self): 433 | self.dense_gcm = DenseGCM( 434 | self.dense_g, edge_selectors=TemporalBackedge([1, 2]), graph_size=8 435 | ) 436 | self.sparse_gcm = SparseGCM( 437 | self.sparse_g, edge_selectors=TemporalEdge([1, 2]), graph_size=8, max_hops=2 438 | ) 439 | F = self.F 440 | B = 3 441 | ts = 8 442 | self.obs = torch.arange(B * ts * F, dtype=torch.float32).reshape(B, ts, F) 443 | 444 | dense_outs = [] 445 | 446 | dense_hidden = None 447 | for i in range(ts): 448 | dense_out, dense_hidden = self.dense_gcm(self.obs[:, i], dense_hidden) 449 | dense_outs.append(dense_out) 450 | dense_outs = torch.stack(dense_outs, dim=1) 451 | 452 | taus = torch.ones(B, dtype=torch.long) * ts 453 | sparse_outs, sparse_hidden = self.sparse_gcm(self.obs, taus, None) 454 | 455 | if dense_outs.numel() != sparse_outs.numel(): 456 | self.fail(f"sizes {dense_outs.numel()} != {sparse_outs.numel()}") 457 | 458 | # Check hiddens 459 | if not torch.all(dense_hidden[0] == sparse_hidden[0]): 460 | self.fail(f"{dense_hidden[0]} != {sparse_hidden[0]}") 461 | 462 | if not torch.all(dense_hidden[1].nonzero().T == sparse_hidden[1].coalesce().indices()): 463 | 
self.fail(f"dense and sparse edges inequal: \n{dense_hidden[1].nonzero().T} != \n{sparse_hidden[1]._indices()}") 464 | 465 | if not torch.all(dense_outs == sparse_outs): 466 | self.fail(f"{dense_outs} != {sparse_outs}") 467 | 468 | 469 | def test_temporal_edges_many_iter_2_hop(self): 470 | self.dense_gcm = DenseGCM( 471 | self.dense_g, edge_selectors=TemporalBackedge([1, 2]), graph_size=8 472 | ) 473 | self.sparse_gcm = SparseGCM( 474 | self.sparse_g, edge_selectors=TemporalEdge([1, 2]), graph_size=8, max_hops=2 475 | ) 476 | F = self.F 477 | B = 3 478 | ts = 8 479 | self.obs = torch.arange(B * ts * F, dtype=torch.float32).reshape(B, ts, F) 480 | 481 | dense_outs = [] 482 | 483 | dense_hidden = None 484 | for i in range(ts): 485 | dense_out, dense_hidden = self.dense_gcm(self.obs[:, i], dense_hidden) 486 | dense_outs.append(dense_out) 487 | dense_outs = torch.stack(dense_outs, dim=1) 488 | 489 | sparse_hidden = None 490 | taus = torch.ones(B, dtype=torch.long) 491 | sparse_outs = [] 492 | for i in range(ts): 493 | sparse_out, sparse_hidden = self.sparse_gcm(self.obs[:,i].unsqueeze(1), taus, sparse_hidden) 494 | sparse_outs.append(sparse_out) 495 | sparse_outs = torch.cat(sparse_outs, dim=1) 496 | 497 | if dense_outs.numel() != sparse_outs.numel(): 498 | self.fail(f"sizes {dense_outs.numel()} != {sparse_outs.numel()}") 499 | 500 | if not torch.all(dense_hidden[1].nonzero().T == sparse_hidden[1].coalesce().indices()): 501 | self.fail(f"sparse and dense edges inequal: \n{dense_hidden[1].nonzero().T} != \n{sparse_hidden[1].coalesce().indices()}") 502 | 503 | if not torch.all(dense_outs == sparse_outs): 504 | self.fail(f"{dense_outs} != {sparse_outs}") 505 | def test_temporal_edges_many_iter(self): 506 | self.dense_gcm = DenseGCM( 507 | self.dense_g, edge_selectors=TemporalBackedge([1, 2]), graph_size=8 508 | ) 509 | self.sparse_gcm = SparseGCM( 510 | self.sparse_g, edge_selectors=TemporalEdge([1, 2]), graph_size=8 511 | ) 512 | F = self.F 513 | B = 3 514 | ts = 8 515 | 
self.obs = torch.arange(B * ts * F, dtype=torch.float32).reshape(B, ts, F) 516 | 517 | dense_outs = [] 518 | 519 | dense_hidden = None 520 | for i in range(ts): 521 | dense_out, dense_hidden = self.dense_gcm(self.obs[:, i], dense_hidden) 522 | dense_outs.append(dense_out) 523 | dense_outs = torch.stack(dense_outs, dim=1) 524 | 525 | sparse_hidden = None 526 | taus = torch.ones(B, dtype=torch.long) 527 | sparse_outs = [] 528 | for i in range(ts): 529 | sparse_out, sparse_hidden = self.sparse_gcm(self.obs[:,i].unsqueeze(1), taus, sparse_hidden) 530 | sparse_outs.append(sparse_out) 531 | sparse_outs = torch.cat(sparse_outs, dim=1) 532 | 533 | if dense_outs.numel() != sparse_outs.numel(): 534 | self.fail(f"sizes {dense_outs.numel()} != {sparse_outs.numel()}") 535 | 536 | if not torch.all(dense_hidden[1].nonzero().T == sparse_hidden[1].coalesce().indices()): 537 | self.fail(f"sparse and dense edges inequal: \n{dense_hidden[1].nonzero().T} != \n{sparse_hidden[1].coalesce().indices()}") 538 | 539 | if not torch.all(dense_outs == sparse_outs): 540 | self.fail(f"{dense_outs} != {sparse_outs}") 541 | 542 | def test_learning_temporal_edges(self): 543 | self.dense_gcm = DenseGCM( 544 | self.dense_g, edge_selectors=TemporalBackedge([1, 2]), graph_size=8 545 | ) 546 | self.sparse_gcm = SparseGCM( 547 | self.sparse_g, edge_selectors=TemporalEdge([1, 2]), graph_size=8 548 | ) 549 | self.sparse_step = SparseGCM( 550 | self.sparse_g, edge_selectors=TemporalEdge([1, 2]), graph_size=8 551 | ) 552 | d_opt = torch.optim.Adam(self.dense_gcm.parameters()) 553 | s_opt = torch.optim.Adam(self.sparse_gcm.parameters()) 554 | ss_opt = torch.optim.Adam(self.sparse_step.parameters()) 555 | F = self.F 556 | B = 3 557 | ts = 8 558 | num_iters = 3 559 | 560 | for i in range(num_iters): 561 | d_opt.zero_grad() 562 | s_opt.zero_grad() 563 | ss_opt.zero_grad() 564 | self.obs = torch.rand((B, ts, F), dtype=torch.float32) 565 | 566 | dense_outs = [] 567 | 568 | dense_hidden = None 569 | for i in 
range(ts): 570 | dense_out, dense_hidden = self.dense_gcm(self.obs[:, i], dense_hidden) 571 | dense_outs.append(dense_out) 572 | dense_outs = torch.stack(dense_outs, dim=1) 573 | 574 | # One step sparse 575 | sparse_step_hidden = None 576 | sparse_step_outs = [] 577 | for i in range(ts): 578 | taus = torch.ones(B, dtype=torch.long) 579 | sparse_step_out, sparse_step_hidden = self.sparse_gcm( 580 | self.obs[:, i].unsqueeze(1), taus, sparse_step_hidden 581 | ) 582 | sparse_step_outs.append(sparse_step_out) 583 | sparse_step_outs = torch.cat(sparse_step_outs, dim=1) 584 | 585 | # Time batched sparse 586 | taus = torch.ones(B, dtype=torch.long) * ts 587 | sparse_outs, sparse_hidden = self.sparse_gcm(self.obs, taus, None) 588 | 589 | if dense_outs.numel() != sparse_outs.numel(): 590 | self.fail(f"sizes {dense_outs.numel()} != {sparse_outs.numel()}") 591 | 592 | if dense_outs.numel() != sparse_step_outs.numel(): 593 | self.fail(f"sizes {dense_outs.numel()} != {sparse_step_outs.numel()}") 594 | 595 | # Check hiddens 596 | if not torch.all(dense_hidden[0] == sparse_hidden[0]): 597 | self.fail(f"{dense_hidden[0]} != {sparse_hidden[0]}") 598 | 599 | if not torch.allclose(dense_outs, sparse_outs, atol=0.01): 600 | self.fail(f"{dense_outs} != {sparse_outs}") 601 | 602 | if not torch.all(dense_outs == sparse_step_outs): 603 | self.fail(f"{dense_outs} != {sparse_step_outs}") 604 | sparse_outs.mean().backward() 605 | dense_outs.mean().backward() 606 | d_opt.step() 607 | s_opt.step() 608 | 609 | for k, v in self.sparse_g.state_dict().items(): 610 | if not torch.allclose(v, self.dense_g.state_dict()[k], atol=0.01): 611 | self.fail( 612 | f"Parameters diverged: {v}, {self.dense_g.state_dict()[k]}" 613 | ) 614 | 615 | 616 | class DummyEdgenet(torch.nn.Module): 617 | def __init__(self): 618 | super().__init__() 619 | self.ghost = torch.nn.Linear(1, 1) 620 | 621 | def forward(self, x): 622 | return 1e15 * torch.all(x > 0, dim=1).float() 623 | 624 | 625 | class 
TestLearnedEdge(unittest.TestCase): 626 | def setUp(self): 627 | self.F = 3 628 | sparse_conv_type = torch_geometric.nn.GraphConv 629 | self.sparse_g = torch_geometric.nn.Sequential( 630 | "x, edges, weights", 631 | [ 632 | (sparse_conv_type(self.F, self.F), "x, edges, weights -> x"), 633 | (sparse_conv_type(self.F, self.F), "x, edges, weights -> x"), 634 | ], 635 | ) 636 | 637 | 638 | 639 | def test_first_pass(self): 640 | B = 2 641 | gsize = 5 642 | taus = torch.ones(B, dtype=torch.long) * 5 643 | T = torch.zeros(B) 644 | obs = torch.zeros(B, taus.max().int(), self.F) 645 | obs[0,0] = 1 646 | obs[0,4] = 1 647 | obs[1,1] = 1 648 | obs[1,4] = 1 649 | sel = SLearnedEdge(input_size=0, model=DummyEdgenet(), num_edge_samples=1) 650 | gcm = SparseGCM( 651 | self.sparse_g, graph_size=gsize, edge_selectors=sel 652 | ) 653 | 654 | out, hidden = gcm(obs, taus, hidden=None) 655 | # Should result in edges: 656 | #b0: [1,0], [2,?], [3,?], [4,0] 657 | #b1: [1,0], [2,?], [3,?], [4,1] 658 | self.assertTrue(torch.tensor([0, 1, 0]) in hidden[1]._indices().T) 659 | self.assertTrue(torch.tensor([0, 4, 0]) in hidden[1]._indices().T) 660 | self.assertTrue(torch.tensor([1, 1, 0]) in hidden[1]._indices().T) 661 | self.assertTrue(torch.tensor([1, 4, 1]) in hidden[1]._indices().T) 662 | 663 | def test_second_pass(self): 664 | B = 2 665 | gsize = 10 666 | taus = torch.ones(B, dtype=torch.long) * 5 667 | T = torch.zeros(B) 668 | obs = torch.zeros(B, taus.max().int(), self.F) 669 | obs[0,0] = 1 670 | obs[0,4] = 1 671 | obs[1,1] = 1 672 | obs[1,4] = 1 673 | sel = SLearnedEdge(input_size=0, model=DummyEdgenet(), num_edge_samples=1) 674 | gcm = SparseGCM( 675 | self.sparse_g, graph_size=gsize, edge_selectors=sel 676 | ) 677 | 678 | out, hidden = gcm(obs, taus, hidden=None) 679 | # Should result in edges: 680 | #b0: [1,0], [2,?], [3,?], [4,0] 681 | #b1: [1,0], [2,?], [3,?], [4,1] 682 | self.assertTrue(torch.tensor([0, 1, 0]) in hidden[1]._indices().T) 683 | self.assertTrue(torch.tensor([0, 4, 0]) 
in hidden[1]._indices().T) 684 | self.assertTrue(torch.tensor([1, 1, 0]) in hidden[1]._indices().T) 685 | self.assertTrue(torch.tensor([1, 4, 1]) in hidden[1]._indices().T) 686 | 687 | # Second pass 688 | # should have first pass edges 689 | out, hidden = gcm(obs, taus, hidden) 690 | self.assertTrue(torch.tensor([0, 1, 0]) in hidden[1]._indices().T) 691 | self.assertTrue(torch.tensor([0, 4, 0]) in hidden[1]._indices().T) 692 | self.assertTrue(torch.tensor([1, 1, 0]) in hidden[1]._indices().T) 693 | self.assertTrue(torch.tensor([1, 4, 1]) in hidden[1]._indices().T) 694 | 695 | 696 | def test_multi_pass(self): 697 | B = 2 698 | gsize = 10 699 | taus = torch.ones(B, dtype=torch.long) 700 | T = torch.zeros(B) 701 | obs = torch.zeros(B, gsize, self.F) 702 | obs[0,0] = 1 703 | obs[0,4] = 1 704 | obs[1,1] = 1 705 | obs[1,4] = 1 706 | sel = SLearnedEdge(input_size=0, model=DummyEdgenet(), num_edge_samples=1) 707 | gcm = SparseGCM( 708 | self.sparse_g, graph_size=gsize, edge_selectors=sel 709 | ) 710 | hidden = None 711 | for i in range(gsize): 712 | out, hidden = gcm(obs[:,i].unsqueeze(1), taus, hidden) 713 | 714 | # Should result in edges: 715 | #b0: [1,0], [2,?], [3,?], [4,0] 716 | #b1: [1,0], [2,?], [3,?], [4,1] 717 | self.assertTrue(torch.tensor([0, 1, 0]) in hidden[1]._indices().T) 718 | self.assertTrue(torch.tensor([0, 4, 0]) in hidden[1]._indices().T) 719 | self.assertTrue(torch.tensor([1, 1, 0]) in hidden[1]._indices().T) 720 | self.assertTrue(torch.tensor([1, 4, 1]) in hidden[1]._indices().T) 721 | 722 | # Ensure same as batched case 723 | bout, bhidden = gcm(obs, taus * gsize, None) 724 | self.assertTrue(bhidden[1].shape == hidden[1].shape) 725 | 726 | def test_window_multi_pass(self): 727 | B = 2 728 | gsize = 4 729 | taus = torch.ones(B, dtype=torch.long) 730 | T = torch.zeros(B) 731 | obs = torch.zeros(B, gsize, self.F) 732 | sel = SLearnedEdge(input_size=0, model=DummyEdgenet(), num_edge_samples=1, window=1) 733 | gcm = SparseGCM( 734 | self.sparse_g, 
graph_size=gsize, edge_selectors=sel 735 | ) 736 | hidden = None 737 | for i in range(gsize): 738 | out, hidden = gcm(obs[:,i].unsqueeze(1), taus, hidden) 739 | 740 | # Should result in edges: 741 | #b0: [1,0], [2,?], [3,?], [4,0] 742 | #b1: [1,0], [2,?], [3,?], [4,1] 743 | desired = torch.tensor( 744 | [ 745 | [0, 1, 0], 746 | [0, 2, 1], 747 | [0, 3, 2], 748 | # second batch 749 | [1, 1, 0], 750 | [1, 2, 1], 751 | [1, 3, 2], 752 | ] 753 | ).T 754 | if not torch.all(hidden[1].coalesce().indices() == desired): 755 | self.fail(f"{hidden[1].coalesce().indices()} != {desired}") 756 | 757 | def test_grad(self): 758 | B = 2 759 | gsize = 4 760 | taus = gsize * torch.ones(B, dtype=torch.long) 761 | T = torch.zeros(B, dtype=torch.long) 762 | obs = torch.zeros(B, gsize, self.F) 763 | canary = torch.tensor([1.0], requires_grad=True) 764 | obs = obs * canary 765 | 766 | 767 | sel = SLearnedEdge(input_size=self.F, num_edge_samples=10, window=16) 768 | adj = sel(obs, T, taus, B) 769 | adj.coalesce().values().sum().backward() 770 | self.assertTrue(canary.grad is not None) 771 | 772 | 773 | 774 | class TestUtil(unittest.TestCase): 775 | def test_flatten_idx(self): 776 | idx = torch.tensor([ 777 | [0, 0, 0, 0, 1], 778 | [0, 0, 0, 1, 0], 779 | [0, 0, 1, 0, 0], 780 | [0, 1, 0, 0, 0] 781 | ]) 782 | flat, offsets = util.flatten_idx_n_dim(idx) 783 | 784 | def test_flatten_many(self): 785 | idx = torch.tensor([ 786 | [0, 0, 0, 0, 0, 0, 0, 0], 787 | [0, 0, 0, 0, 1, 1, 1, 1], 788 | [0, 1, 2, 3, 0, 1, 2, 3] 789 | ]) 790 | flat, offsets = util.flatten_idx_n_dim(idx) 791 | if flat.unique().shape != flat.shape: 792 | self.fail(f"Repeated elems {flat}") 793 | 794 | def test_sparse_gumbel_softmax(self): 795 | idx = torch.tensor([ 796 | [0, 0, 0, 0, 0, 0, 1, 1], 797 | [0, 0, 0, 0, 1, 1, 1, 1], 798 | [0, 1, 2, 2, 0, 5, 4, 4], 799 | [0, 0, 1, 0, 0, 3, 0, 3] 800 | ]) 801 | values = torch.ones(8) * 1e15 802 | values[3] = 0 803 | values[-1] = 0 804 | a = torch.sparse_coo_tensor(idx, values, size=(2, 
2, 100, 100)) 805 | res = util.sparse_gumbel_softmax(a, 3, hard=True) 806 | desired_idx = torch.tensor([ 807 | [0, 0, 0, 0, 0, 1], 808 | [0, 0, 0, 1, 1, 1], 809 | [0, 1, 2, 0, 5, 4], 810 | [0, 0, 1, 0, 3, 0] 811 | ]) 812 | desired_values = torch.ones(6) 813 | desired = torch.sparse_coo_tensor( 814 | desired_idx, desired_values, size=(2, 2, 100, 100)) 815 | 816 | if torch.any(res.coalesce().indices() != desired.coalesce().indices()): 817 | self.fail(f"{res} != {desired}") 818 | if torch.any(res.coalesce().values() != desired.coalesce().values()): 819 | self.fail(f"{res} != {desired}") 820 | 821 | 822 | class TestE2E(unittest.TestCase): 823 | def test_e2e_learned_edge(self): 824 | sparse_g = torch_geometric.nn.Sequential( 825 | "x, edges, weights", 826 | [ 827 | (torch_geometric.nn.GraphConv(32, 32), "x, edges, weights -> x"), 828 | (torch.nn.Tanh()), 829 | (torch_geometric.nn.GraphConv(32, 32), "x, edges, weights -> x"), 830 | (torch.nn.Tanh()), 831 | ], 832 | ) 833 | B = 8 834 | num_obs = 256 835 | obs_size = 32 836 | sparse_gcm = SparseGCM( 837 | sparse_g, graph_size=num_obs, edge_selectors=SLearnedEdge(obs_size), 838 | max_hops=2 839 | ) 840 | obs = torch.rand(B, num_obs, obs_size) 841 | taus = torch.ones(B, dtype=torch.long) 842 | hidden = None 843 | with torch.no_grad(): 844 | for i in range(num_obs): 845 | out, hidden = sparse_gcm(obs[:,i,None], taus, hidden) 846 | tmp = util.pack_hidden(hidden, B, max_edges = 5 * num_obs) 847 | tmp = util.unpack_hidden(tmp, B) 848 | # train 849 | out, hidden = sparse_gcm(obs, taus * num_obs, None) 850 | tmp = util.pack_hidden(hidden, B, max_edges = 5 * num_obs) 851 | tmp = util.unpack_hidden(tmp, B) 852 | out.mean().backward() 853 | 854 | def test_e2e_learned_edge_grad(self): 855 | sparse_g = torch_geometric.nn.Sequential( 856 | "x, edges, weights", 857 | [ 858 | (torch_geometric.nn.GraphConv(32, 32), "x, edges, weights -> x"), 859 | (torch.nn.Tanh()), 860 | (torch_geometric.nn.GraphConv(32, 32), "x, edges, weights -> x"), 
861 | (torch.nn.Tanh()), 862 | ], 863 | ) 864 | B = 8 865 | num_obs = 4 866 | obs_size = 32 867 | sparse_gcm = SparseGCM( 868 | sparse_g, graph_size=num_obs, edge_selectors=SLearnedEdge(obs_size), 869 | max_hops=2 870 | ) 871 | canary = torch.tensor([1.0], requires_grad=True) 872 | obs = torch.rand(B, num_obs, obs_size) * canary 873 | taus = torch.ones(B, dtype=torch.long) 874 | hidden = None 875 | with torch.no_grad(): 876 | for i in range(num_obs): 877 | out, hidden = sparse_gcm(obs[:,i,None], taus, hidden) 878 | print(hidden[1]._values().numel()) 879 | tmp = util.pack_hidden(hidden, B, max_edges = 5 * num_obs) 880 | tmp = util.unpack_hidden(tmp, B) 881 | # train 882 | out, hidden = sparse_gcm(obs, taus * num_obs, None) 883 | tmp = util.pack_hidden(hidden, B, max_edges = 5 * num_obs) 884 | tmp = util.unpack_hidden(tmp, B) 885 | out.mean().backward() 886 | self.assertTrue(canary.grad is not None) 887 | 888 | def test_ray_sparse_edge_grad(self): 889 | B = 1 890 | F = 64 891 | num_obs = 32 892 | graph_size = 32 893 | taus = num_obs * torch.ones(B) 894 | act_space = gym.spaces.Discrete(1) 895 | obs_space = gym.spaces.Box(high=1000, low=-1000, shape=(F,)) 896 | cfg = ray_sparse_gcm.RaySparseGCM.DEFAULT_CONFIG 897 | cfg["aux_edge_selectors"] = SLearnedEdge(F) 898 | ray_gcm = ray_sparse_gcm.RaySparseGCM( 899 | obs_space, 900 | act_space, 901 | 1, 902 | cfg, 903 | 'my_model', 904 | ) 905 | 906 | canary = torch.tensor([1.0], requires_grad=True) 907 | input_dict = { 908 | "obs_flat": torch.ones(B*num_obs, F) * canary 909 | } 910 | state = [ 911 | torch.zeros(B, graph_size, F), 912 | torch.ones(B, 2, 50).long(), 913 | torch.ones(B, 1, 50), 914 | torch.zeros(B).long() 915 | ] 916 | seq_lens = taus.int().numpy() 917 | output, hidden = ray_gcm.forward(input_dict, state, seq_lens) 918 | # Check grads for adj 919 | hidden[2].sum().backward() 920 | self.assertTrue(hidden[2].requires_grad) 921 | self.assertTrue(canary.grad is not None) 922 | 923 | 924 | def 
test_ray_sparse_node_grad(self): 925 | B = 1 926 | F = 64 927 | num_obs = 32 928 | graph_size = 32 929 | taus = num_obs * torch.ones(B) 930 | act_space = gym.spaces.Discrete(1) 931 | obs_space = gym.spaces.Box(high=1000, low=-1000, shape=(F,)) 932 | cfg = ray_sparse_gcm.RaySparseGCM.DEFAULT_CONFIG 933 | cfg["aux_edge_selectors"] = SLearnedEdge(F) 934 | ray_gcm = ray_sparse_gcm.RaySparseGCM( 935 | obs_space, 936 | act_space, 937 | 1, 938 | cfg, 939 | 'my_model', 940 | ) 941 | 942 | canary = torch.tensor([1.0], requires_grad=True) 943 | input_dict = { 944 | "obs_flat": torch.ones(B*num_obs, F) * canary 945 | } 946 | state = [ 947 | torch.zeros(B, graph_size, F), 948 | torch.ones(B, 2, 50).long(), 949 | torch.ones(B, 1, 50), 950 | torch.zeros(B).long() 951 | ] 952 | seq_lens = taus.int().numpy() 953 | output, hidden = ray_gcm.forward(input_dict, state, seq_lens) 954 | # Check grads for nodes 955 | output.sum().backward() 956 | self.assertTrue(output.requires_grad) 957 | self.assertTrue(canary.grad is not None) 958 | 959 | 960 | if __name__ == "__main__": 961 | unittest.main() 962 | -------------------------------------------------------------------------------- /tests/test_speed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric 3 | from gcm.gcm import DenseGCM 4 | from gcm.sparse_gcm import SparseGCM 5 | from gcm.edge_selectors.dense import DenseEdge 6 | from gcm.sparse_edge_selectors.temporal import TemporalEdge 7 | import time 8 | 9 | 10 | 11 | lstm = torch.nn.LSTMCell(32,32) 12 | g = torch_geometric.nn.Sequential( 13 | "x, adj, weights, B, N", 14 | [ 15 | (torch_geometric.nn.DenseGraphConv(32, 32), "x, adj -> x"), 16 | (torch.nn.Tanh()), 17 | (torch_geometric.nn.DenseGraphConv(32, 32), "x, adj -> x"), 18 | (torch.nn.Tanh()), 19 | ], 20 | ) 21 | 22 | num_obs = 16 23 | gcm = DenseGCM( 24 | g, 25 | edge_selectors=DenseEdge(), 26 | graph_size=num_obs 27 | ) 28 | 29 | sparse_g = 
torch_geometric.nn.Sequential( 30 | "x, edges, weights", 31 | [ 32 | (torch_geometric.nn.GraphConv(32, 32), "x, edges, weights -> x"), 33 | (torch.nn.Tanh()), 34 | (torch_geometric.nn.GraphConv(32, 32), "x, edges, weights -> x"), 35 | (torch.nn.Tanh()), 36 | ], 37 | ) 38 | sparse_gcm = SparseGCM(sparse_g, graph_size=num_obs, edge_selectors=TemporalEdge([1,2])) 39 | 40 | 41 | 42 | obs = torch.rand(num_obs, 32) 43 | 44 | lstm_s = time.time() 45 | hidden = None 46 | for i in range(num_obs): 47 | hidden = lstm(obs[i].unsqueeze(0), hidden) 48 | hidden[0].mean().backward() 49 | print("lstm took", time.time() - lstm_s) 50 | 51 | gcm_s = time.time() 52 | hidden = None 53 | for i in range(num_obs): 54 | out, hidden = gcm(obs[i].unsqueeze(0), hidden) 55 | out.mean().backward() 56 | print("gcm took", time.time() - gcm_s) 57 | 58 | gcms_s = time.time() 59 | hidden = None 60 | taus = torch.tensor([num_obs]) 61 | out, hidden = sparse_gcm(obs.unsqueeze(0), taus, hidden) 62 | out.mean().backward() 63 | print("sparse gcm took", time.time() - gcms_s) 64 | 65 | --------------------------------------------------------------------------------