├── .gitignore
├── CITATION.cff
├── README.md
├── assets
│   ├── explicit_mse_10.png
│   ├── explicit_mse_30.png
│   ├── implicit_ebm_10.png
│   └── implicit_ebm_30.png
├── ibc
│   ├── __init__.py
│   ├── dataset.py
│   ├── experiment.py
│   ├── models.py
│   ├── modules.py
│   ├── optimizers.py
│   ├── trainer.py
│   └── utils.py
├── mypy.ini
├── plot.py
├── requirements.txt
├── run_explicit.sh
├── run_implicit.sh
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | experiments/
3 | 
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 0.0.1
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 |   - family-names: Zakka
5 |     given-names: Kevin
6 | title: "A PyTorch Implementation of Implicit Behavioral Cloning"
7 | version: 0.0.1
8 | date-released: 2021-10-27
9 | url: "https://github.com/kevinzakka/ibc"
10 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Implicit Behavioral Cloning - PyTorch
2 | 
3 | PyTorch implementation of Implicit Behavioral Cloning.
4 | 
5 | ## Install
6 | 
7 | ```bash
8 | conda create -n ibc python=3.8
9 | conda activate ibc
10 | pip install -r requirements.txt
11 | ```
12 | 
13 | ## Results
14 | 
15 | To reproduce the results for the Coordinate Regression Task (Section 3 of the paper), execute the `run_explicit.sh` and `run_implicit.sh` scripts. Note that the implicit policy does slightly worse with 30 examples than with 10; the cause is not yet clear and needs further investigation.
16 | 
17 | |             | Explicit Policy | Implicit Policy |
18 | |-------------|-----------------|-----------------|
19 | | 10 examples | ![](assets/explicit_mse_10.png) | ![](assets/implicit_ebm_10.png) |
20 | | 30 examples | ![](assets/explicit_mse_30.png) | ![](assets/implicit_ebm_30.png) |
21 | 
22 | ## Citation
23 | 
24 | If you find this code useful, consider citing it along with the paper:
25 | 
26 | ```bibtex
27 | @software{zakka2021ibc,
28 |   author = {Zakka, Kevin},
29 |   month = {10},
30 |   title = {{A PyTorch Implementation of Implicit Behavioral Cloning}},
31 |   url = {https://github.com/kevinzakka/ibc},
32 |   version = {0.0.1},
33 |   year = {2021}
34 | }
35 | ```
36 | 
37 | ```bibtex
38 | @misc{florence2021implicit,
39 |   title = {Implicit Behavioral Cloning},
40 |   author = {Pete Florence and Corey Lynch and Andy Zeng and Oscar Ramirez and Ayzaan Wahid and Laura Downs and Adrian Wong and Johnny Lee and Igor Mordatch and Jonathan Tompson},
41 |   year = {2021},
42 |   eprint = {2109.00137},
43 |   archivePrefix = {arXiv},
44 |   primaryClass = {cs.RO}
45 | }
46 | ```
47 | 
--------------------------------------------------------------------------------
/assets/explicit_mse_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevinzakka/ibc/a30c4f527ccf23b7b893824e1c64b7385b7e85e7/assets/explicit_mse_10.png
--------------------------------------------------------------------------------
/assets/explicit_mse_30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevinzakka/ibc/a30c4f527ccf23b7b893824e1c64b7385b7e85e7/assets/explicit_mse_30.png
--------------------------------------------------------------------------------
/assets/implicit_ebm_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevinzakka/ibc/a30c4f527ccf23b7b893824e1c64b7385b7e85e7/assets/implicit_ebm_10.png -------------------------------------------------------------------------------- /assets/implicit_ebm_30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevinzakka/ibc/a30c4f527ccf23b7b893824e1c64b7385b7e85e7/assets/implicit_ebm_30.png -------------------------------------------------------------------------------- /ibc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevinzakka/ibc/a30c4f527ccf23b7b893824e1c64b7385b7e85e7/ibc/__init__.py -------------------------------------------------------------------------------- /ibc/dataset.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Optional, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torchvision.transforms import ToTensor 8 | 9 | 10 | @dataclasses.dataclass 11 | class DatasetConfig: 12 | dataset_size: int = 30 13 | """The size of the dataset. Useful for sample efficiency experiments.""" 14 | 15 | resolution: Tuple[int, int] = (96, 96) 16 | """The resolution of the image.""" 17 | 18 | pixel_size: int = 7 19 | """The size of the pixel whose coordinates we'd like to regress. Must be odd.""" 20 | 21 | pixel_color: Tuple[int, int, int] = (0, 255, 0) 22 | """The color of the pixel whose coordinates we'd like to regress.""" 23 | 24 | seed: Optional[int] = None 25 | """Whether to seed the dataset. Disabled if None.""" 26 | 27 | 28 | class CoordinateRegression(Dataset): 29 | """Regress the coordinates of a colored pixel block on a white canvas.""" 30 | 31 | def __init__(self, config: DatasetConfig) -> None: 32 | if not config.pixel_size % 2: 33 | raise ValueError("'pixel_size' must be odd.") 34 | 35 | self.dataset_size = config.dataset_size 36 | self.resolution = config.resolution 37 | self.pixel_size = config.pixel_size 38 | self.pixel_color = config.pixel_color 39 | self.seed = config.seed 40 | 41 | self.reset() 42 | 43 | def reset(self) -> None: 44 | if self.seed is not None: 45 | np.random.seed(self.seed) 46 | 47 | self._coordinates = self._sample_coordinates(self.dataset_size) 48 | self._coordinates_scaled = self._scale_coordinates(self._coordinates) 49 | 50 | def exclude(self, coordinates: np.ndarray) -> None: 51 | """Exclude the given coordinates, if present, from the previously sampled ones. 52 | 53 | This is useful for ensuring the train set does not accidentally leak into the 54 | test set. 55 | """ 56 | mask = (self.coordinates[:, None] == coordinates).all(-1).any(1) 57 | num_matches = mask.sum() 58 | while mask.sum() > 0: 59 | self._coordinates[mask] = self._sample_coordinates(mask.sum()) 60 | mask = (self.coordinates[:, None] == coordinates).all(-1).any(1) 61 | self._coordinates_scaled = self._scale_coordinates(self._coordinates) 62 | print(f"Resampled {num_matches} data points.") 63 | 64 | def get_target_bounds(self) -> np.ndarray: 65 | """Return per-dimension target min/max.""" 66 | return np.array([[-1.0, -1.0], [1.0, 1.0]]) 67 | 68 | def _sample_coordinates(self, size: int) -> np.ndarray: 69 | """Helper method for generating pixel coordinates.""" 70 | # Randomly generate pixel coordinates. 
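# Note: (u, v) is (row, column) order; u is drawn from the height range and v from the
# width range, matching how the pixel block is painted in __getitem__ below.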
71 | u = np.random.randint(0, self.resolution[0], size=size) 72 | v = np.random.randint(0, self.resolution[1], size=size) 73 | 74 | # Ensure we remain within bounds when we take the pixel size into account. 75 | slack = self.pixel_size // 2 76 | u = np.clip(u, a_min=slack, a_max=self.resolution[0] - 1 - slack) 77 | v = np.clip(v, a_min=slack, a_max=self.resolution[1] - 1 - slack) 78 | 79 | return np.vstack([u, v]).astype(np.int16).T 80 | 81 | def _scale_coordinates(self, coords: np.ndarray) -> np.ndarray: 82 | """Helper method for scaling coordinates to the [-1, 1] range.""" 83 | coords_scaled = np.array(coords, dtype=np.float32) 84 | coords_scaled[:, 0] /= self.resolution[0] - 1 85 | coords_scaled[:, 1] /= self.resolution[1] - 1 86 | coords_scaled *= 2 87 | coords_scaled -= 1 88 | return coords_scaled 89 | 90 | @property 91 | def image_shape(self) -> Tuple[int, int, int]: 92 | return self.resolution + (3,) 93 | 94 | @property 95 | def coordinates(self) -> np.ndarray: 96 | return self._coordinates 97 | 98 | @property 99 | def coordinates_scaled(self) -> np.ndarray: 100 | return self._coordinates_scaled 101 | 102 | def __len__(self) -> int: 103 | return self.dataset_size 104 | 105 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: 106 | uv = self._coordinates[index] 107 | uv_scaled = self._coordinates_scaled[index] 108 | 109 | image = np.full(self.image_shape, fill_value=255, dtype=np.uint8) 110 | image[ 111 | uv[0] - self.pixel_size // 2 : uv[0] + self.pixel_size // 2 + 1, 112 | uv[1] - self.pixel_size // 2 : uv[1] + self.pixel_size // 2 + 1, 113 | ] = self.pixel_color 114 | 115 | image_tensor = ToTensor()(image) 116 | target_tensor = torch.as_tensor(uv_scaled, dtype=torch.float32) 117 | 118 | return image_tensor, target_tensor 119 | 120 | 121 | if __name__ == "__main__": 122 | import matplotlib.pyplot as plt 123 | from scipy.spatial import ConvexHull 124 | 125 | dataset = CoordinateRegression(DatasetConfig(dataset_size=30, seed=0)) 126 | 127 | # Visualize one instance. 128 | image, target = dataset[np.random.randint(len(dataset))] 129 | print(target) 130 | plt.imshow(image.permute(1, 2, 0).numpy()) 131 | plt.show() 132 | 133 | # Plot target distribution and convex hull. 134 | targets = dataset.coordinates 135 | plt.scatter(targets[:, 0], targets[:, 1], marker="x", c="black") 136 | for simplex in ConvexHull(targets).simplices: 137 | plt.plot( 138 | targets[simplex, 0], 139 | targets[simplex, 1], 140 | "--", 141 | zorder=2, 142 | alpha=0.5, 143 | c="black", 144 | ) 145 | plt.xlim(0, dataset.resolution[1]) 146 | plt.ylim(0, dataset.resolution[0]) 147 | plt.show() 148 | 149 | # Plot target distribution and convex hull. 
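# (Same plot as above, but in the [-1, 1] scaled frame actually used as the regression target.)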
150 | targets = dataset.coordinates_scaled 151 | plt.scatter(targets[:, 0], targets[:, 1], marker="x", c="black") 152 | for simplex in ConvexHull(targets).simplices: 153 | plt.plot( 154 | targets[simplex, 0], 155 | targets[simplex, 1], 156 | "--", 157 | zorder=2, 158 | alpha=0.5, 159 | c="black", 160 | ) 161 | plt.xlim(-1, 1) 162 | plt.ylim(-1, 1) 163 | plt.show() 164 | 165 | print(f"Target bounds:") 166 | print(dataset.get_target_bounds()) 167 | -------------------------------------------------------------------------------- /ibc/experiment.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | import os 5 | import pathlib 6 | import signal 7 | import tempfile 8 | from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union 9 | 10 | if TYPE_CHECKING: 11 | from .trainer import TrainStateProtocol 12 | 13 | import numpy as np 14 | import torch 15 | import yaml 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | T = TypeVar("T") 19 | TensorOrFloat = Union[np.ndarray, torch.Tensor, float] 20 | 21 | 22 | @dataclasses.dataclass 23 | class TensorboardLogData: 24 | scalars: Dict[str, TensorOrFloat] = dataclasses.field(default_factory=dict) 25 | 26 | @staticmethod 27 | def merge(a: TensorboardLogData, b: TensorboardLogData) -> TensorboardLogData: 28 | return TensorboardLogData(scalars=dict(**a.scalars, **b.scalars)) 29 | 30 | def extend(self, scalars: Dict[str, TensorOrFloat] = {}) -> TensorboardLogData: 31 | return TensorboardLogData.merge(self, TensorboardLogData(scalars=scalars)) 32 | 33 | 34 | @dataclasses.dataclass(frozen=True) 35 | class Experiment: 36 | identifier: str 37 | """The name of the experiment.""" 38 | 39 | data_dir: pathlib.Path = dataclasses.field(init=False) 40 | log_dir: pathlib.Path = dataclasses.field(init=False) 41 | checkpoint_dir: pathlib.Path = dataclasses.field(init=False) 42 | 43 | def __post_init__(self) -> None: 44 | root = pathlib.Path("./experiments") 45 | 46 | super().__setattr__("data_dir", root / self.identifier) 47 | super().__setattr__("log_dir", self.data_dir / "tb") 48 | super().__setattr__("checkpoint_dir", self.data_dir / "checkpoints") 49 | 50 | def assert_new(self) -> Experiment: 51 | """Makes sure that there are no existing checkpoints, logs, or metadata.""" 52 | assert not self.data_dir.exists() or tuple(self.data_dir.iterdir()) == () 53 | return self 54 | 55 | def assert_exists(self) -> Experiment: 56 | """Makes sure that there are existing checkpoints, logs, or metadata.""" 57 | assert self.data_dir.exists() and tuple(self.data_dir.iterdir()) != () 58 | return self 59 | 60 | # =================================================================== # 61 | # Properties. 62 | # =================================================================== # 63 | 64 | # Note: This hack is necessary because the SummaryWriter object instantly creates 65 | # a directory upon construction and thus would break the `assert_*` functionality. 66 | @property 67 | def summary_writer(self) -> SummaryWriter: 68 | if not hasattr(self, "__summary_writer__"): 69 | object.__setattr__( 70 | self, 71 | "__summary_writer__", 72 | SummaryWriter(self.log_dir), 73 | ) 74 | return object.__getattribute__(self, "__summary_writer__") 75 | 76 | # =================================================================== # 77 | # Checkpointing. 
78 | # =================================================================== # 79 | 80 | def save_checkpoint( 81 | self, 82 | target: TrainStateProtocol, 83 | step: int, 84 | prefix: str = "ckpt_", 85 | keep: int = 10, 86 | ) -> None: 87 | """Save a snapshot of the train state to disk.""" 88 | self._ensure_directory_exists(self.checkpoint_dir) 89 | 90 | # Create a snapshot of the state. 91 | snapshot_state = {} 92 | for k, v in target.__dict__.items(): 93 | if hasattr(v, "state_dict"): 94 | snapshot_state[k] = v.state_dict() 95 | snapshot_state["steps"] = target.steps 96 | 97 | # Save to disk. 98 | checkpoint_path = self.checkpoint_dir / f"{prefix}{step}.ckpt" 99 | self._atomic_save(checkpoint_path, snapshot_state) 100 | 101 | # Trim extraneous checkpoints. 102 | self._trim_checkpoints(keep) 103 | 104 | def restore_checkpoint( 105 | self, 106 | target: TrainStateProtocol, 107 | step: Optional[int] = None, 108 | prefix: str = "ckpt_", 109 | ): 110 | """Restore a snapshot of the train state from disk.""" 111 | # Get latest checkpoint if no step has been provided. 112 | if step is None: 113 | step = self._get_latest_checkpoint_step() 114 | 115 | checkpoint_path = self.checkpoint_dir / f"{prefix}{step}.ckpt" 116 | snapshot_state = torch.load(checkpoint_path, map_location="cpu") 117 | 118 | delattr(target, "steps") 119 | setattr(target, "steps", snapshot_state["steps"]) 120 | 121 | for k, v in target.__dict__.items(): 122 | if hasattr(v, "state_dict"): 123 | getattr(target, k).load_state_dict(snapshot_state[k]) 124 | 125 | # =================================================================== # 126 | # Logging. 127 | # =================================================================== # 128 | 129 | def log(self, log_data: TensorboardLogData, step: int) -> None: 130 | """Logging helper for TensorBoard.""" 131 | for k, v in log_data.scalars.items(): 132 | self.summary_writer.add_scalar(tag=k, scalar_value=v, global_step=step) 133 | self.summary_writer.flush() 134 | 135 | def write_metadata(self, name: str, object: Any) -> None: 136 | """Serialize an object to disk as a yaml file.""" 137 | self._ensure_directory_exists(self.data_dir) 138 | assert not name.endswith(".yaml") 139 | 140 | path = self.data_dir / (name + ".yaml") 141 | print(f"Writing metadata to {path}") 142 | with open(path, "w") as fp: 143 | yaml.dump(object, fp) 144 | 145 | def read_metadata(self, name: str, expected_type: Type[T]) -> T: 146 | """Load an object from the experiment's metadata directory.""" 147 | path = self.data_dir / (name + ".yaml") 148 | 149 | with open(path, "r") as fp: 150 | output = yaml.load(fp, Loader=yaml.Loader) 151 | 152 | assert isinstance(output, expected_type) 153 | return output 154 | 155 | # =================================================================== # 156 | # Helper functions. 157 | # =================================================================== # 158 | 159 | def _ensure_directory_exists(self, path: pathlib.Path) -> None: 160 | """Helper for ensuring that a directory exists.""" 161 | if not path.exists(): 162 | path.mkdir(parents=True) 163 | 164 | def _trim_checkpoints(self, keep: int) -> None: 165 | """Helper for deleting older checkpoints.""" 166 | # Get a list of checkpoints. 167 | ckpts = list(self.checkpoint_dir.glob(pattern="*.ckpt")) 168 | 169 | # Sort in reverse `step` order. 170 | ckpts.sort(key=lambda f: -int(f.stem.split("_")[-1])) 171 | 172 | # Remove until `keep` remain. 
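# `ckpts` is sorted newest-first, so pop() removes the oldest checkpoint each iteration.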
173 | while len(ckpts) - keep > 0: 174 | ckpts.pop().unlink() 175 | 176 | def _get_latest_checkpoint_step(self) -> int: 177 | """Helper for returning the step of the latest checkpoint.""" 178 | ckpts = list(self.checkpoint_dir.glob(pattern="*.ckpt")) 179 | ckpts.sort(key=lambda f: -int(f.stem.split("_")[-1])) 180 | return int(ckpts[0].stem.split("_")[-1]) 181 | 182 | def _atomic_save(self, save_path: pathlib.Path, snapshot: Dict[str, Any]) -> None: 183 | """Helper for safely saving to disk.""" 184 | # Ignore Ctrl-C while saving. 185 | try: 186 | orig_handler = signal.getsignal(signal.SIGINT) 187 | signal.signal(signal.SIGINT, lambda _sig, _frame: None) 188 | except ValueError: 189 | # Signal throws a ValueError if we're not in the main thread. 190 | orig_handler = None 191 | 192 | with tempfile.TemporaryDirectory() as tmp_dir: 193 | tmp_path = pathlib.Path(tmp_dir) / "tmp.ckpt" 194 | torch.save(snapshot, tmp_path) 195 | # `rename` is POSIX-compliant and thus, is an atomic operation. 196 | # Ref: https://docs.python.org/3/library/os.html#os.rename 197 | os.rename(tmp_path, save_path) 198 | 199 | # Restore SIGINT handler. 200 | if orig_handler is not None: 201 | signal.signal(signal.SIGINT, orig_handler) 202 | -------------------------------------------------------------------------------- /ibc/models.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import enum 3 | from functools import partial 4 | from typing import Callable, Optional, Sequence 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .modules import CoordConv, GlobalAvgPool2d, GlobalMaxPool2d, SpatialSoftArgmax 11 | 12 | 13 | class ActivationType(enum.Enum): 14 | RELU = nn.ReLU 15 | SELU = nn.SiLU 16 | 17 | 18 | @dataclasses.dataclass(frozen=True) 19 | class MLPConfig: 20 | input_dim: int 21 | hidden_dim: int 22 | output_dim: int 23 | hidden_depth: int 24 | dropout_prob: Optional[float] = None 25 | activation_fn: ActivationType = ActivationType.RELU 26 | 27 | 28 | class MLP(nn.Module): 29 | """A feedforward multi-layer perceptron.""" 30 | 31 | def __init__(self, config: MLPConfig) -> None: 32 | super().__init__() 33 | 34 | dropout_layer: Callable 35 | if config.dropout_prob is not None: 36 | dropout_layer = partial(nn.Dropout, p=config.dropout_prob) 37 | else: 38 | dropout_layer = nn.Identity 39 | 40 | layers: Sequence[nn.Module] 41 | if config.hidden_depth == 0: 42 | layers = [nn.Linear(config.input_dim, config.output_dim)] 43 | else: 44 | layers = [ 45 | nn.Linear(config.input_dim, config.hidden_dim), 46 | config.activation_fn.value(), 47 | dropout_layer(), 48 | ] 49 | for _ in range(config.hidden_depth - 1): 50 | layers += [ 51 | nn.Linear(config.hidden_dim, config.hidden_dim), 52 | config.activation_fn.value(), 53 | dropout_layer(), 54 | ] 55 | layers += [nn.Linear(config.hidden_dim, config.output_dim)] 56 | layers = [layer for layer in layers if not isinstance(layer, nn.Identity)] 57 | 58 | self.net = nn.Sequential(*layers) 59 | 60 | def forward(self, x: torch.Tensor) -> torch.Tensor: 61 | return self.net(x) 62 | 63 | 64 | class ResidualBlock(nn.Module): 65 | def __init__( 66 | self, 67 | depth: int, 68 | activation_fn: ActivationType = ActivationType.RELU, 69 | ) -> None: 70 | super().__init__() 71 | 72 | self.conv1 = nn.Conv2d(depth, depth, 3, padding=1, bias=True) 73 | self.conv2 = nn.Conv2d(depth, depth, 3, padding=1, bias=True) 74 | self.activation = activation_fn.value() 75 | 76 | def forward(self, x: 
torch.Tensor) -> torch.Tensor:
77 |         out = self.activation(x)
78 |         out = self.conv1(out)
79 |         out = self.activation(out)
80 |         out = self.conv2(out)
81 |         return out + x
82 | 
83 | 
84 | @dataclasses.dataclass(frozen=True)
85 | class CNNConfig:
86 |     in_channels: int
87 |     blocks: Sequence[int] = dataclasses.field(default=(16, 32, 32))
88 |     activation_fn: ActivationType = ActivationType.RELU
89 | 
90 | 
91 | class CNN(nn.Module):
92 |     """A residual convolutional network."""
93 | 
94 |     def __init__(self, config: CNNConfig) -> None:
95 |         super().__init__()
96 | 
97 |         depth_in = config.in_channels
98 | 
99 |         layers = []
100 |         for depth_out in config.blocks:
101 |             layers.extend(
102 |                 [
103 |                     nn.Conv2d(depth_in, depth_out, 3, padding=1),
104 |                     ResidualBlock(depth_out, config.activation_fn),
105 |                 ]
106 |             )
107 |             depth_in = depth_out
108 | 
109 |         self.net = nn.Sequential(*layers)
110 |         self.activation = config.activation_fn.value()
111 | 
112 |     def forward(self, x: torch.Tensor, activate: bool = False) -> torch.Tensor:
113 |         out = self.net(x)
114 |         if activate:
115 |             return self.activation(out)
116 |         return out
117 | 
118 | 
119 | class SpatialReduction(enum.Enum):
120 |     SPATIAL_SOFTMAX = SpatialSoftArgmax
121 |     AVERAGE_POOL = GlobalAvgPool2d
122 |     MAX_POOL = GlobalMaxPool2d
123 | 
124 | 
125 | @dataclasses.dataclass(frozen=True)
126 | class ConvMLPConfig:
127 |     cnn_config: CNNConfig
128 |     mlp_config: MLPConfig
129 |     spatial_reduction: SpatialReduction = SpatialReduction.AVERAGE_POOL
130 |     coord_conv: bool = False
131 | 
132 | 
133 | class ConvMLP(nn.Module):
134 |     def __init__(self, config: ConvMLPConfig) -> None:
135 |         super().__init__()
136 | 
137 |         self.coord_conv = config.coord_conv
138 | 
139 |         self.cnn = CNN(config.cnn_config)
140 |         self.conv = nn.Conv2d(config.cnn_config.blocks[-1], 16, 1)
141 |         self.reducer = config.spatial_reduction.value()
142 |         self.mlp = MLP(config.mlp_config)
143 | 
144 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
145 |         if self.coord_conv:
146 |             x = CoordConv()(x)
147 |         out = self.cnn(x, activate=True)
148 |         out = F.relu(self.conv(out))
149 |         out = self.reducer(out)
150 |         out = self.mlp(out)
151 |         return out
152 | 
153 | 
154 | class EBMConvMLP(nn.Module):
155 |     def __init__(self, config: ConvMLPConfig) -> None:
156 |         super().__init__()
157 | 
158 |         self.coord_conv = config.coord_conv
159 | 
160 |         self.cnn = CNN(config.cnn_config)
161 |         self.conv = nn.Conv2d(config.cnn_config.blocks[-1], 16, 1)
162 |         self.reducer = config.spatial_reduction.value()
163 |         self.mlp = MLP(config.mlp_config)
164 | 
165 |     def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
166 |         if self.coord_conv:
167 |             x = CoordConv()(x)
168 |         out = self.cnn(x, activate=True)
169 |         out = F.relu(self.conv(out))
170 |         out = self.reducer(out)
171 |         fused = torch.cat([out.unsqueeze(1).expand(-1, y.size(1), -1), y], dim=-1)
172 |         B, N, D = fused.size()
173 |         fused = fused.reshape(B * N, D)
174 |         out = self.mlp(fused)
175 |         return out.view(B, N)
176 | 
177 | 
178 | if __name__ == "__main__":
179 |     config = ConvMLPConfig(
180 |         cnn_config=CNNConfig(5),
181 |         mlp_config=MLPConfig(32, 128, 2, 2),  # 32 = 16 channels x 2 soft-argmax coordinates.
182 |         spatial_reduction=SpatialReduction.SPATIAL_SOFTMAX,
183 |         coord_conv=True,
184 |     )
185 | 
186 |     net = ConvMLP(config)
187 |     print(net)
188 | 
189 |     x = torch.randn(2, 3, 96, 96)
190 |     with torch.no_grad():
191 |         out = net(x)
192 |     print(out.shape)
193 | 
--------------------------------------------------------------------------------
/ibc/modules.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # Adapted from: https://github.com/Wizaron/coord-conv-pytorch 7 | class CoordConv(nn.Module): 8 | def __init__(self) -> None: 9 | super().__init__() 10 | 11 | def forward(self, x: torch.Tensor) -> torch.Tensor: 12 | batch_size, _, image_height, image_width = x.size() 13 | y_coords = ( 14 | 2.0 15 | * torch.arange(image_height).unsqueeze(1).expand(image_height, image_width) 16 | / (image_height - 1.0) 17 | - 1.0 18 | ) 19 | x_coords = ( 20 | 2.0 21 | * torch.arange(image_width).unsqueeze(0).expand(image_height, image_width) 22 | / (image_width - 1.0) 23 | - 1.0 24 | ) 25 | coords = torch.stack((y_coords, x_coords), dim=0) 26 | coords = torch.unsqueeze(coords, dim=0).repeat(batch_size, 1, 1, 1) 27 | x = torch.cat((coords.to(x.device), x), dim=1) 28 | return x 29 | 30 | 31 | class SpatialSoftArgmax(nn.Module): 32 | """Spatial softmax as defined in https://arxiv.org/abs/1504.00702. 33 | 34 | Concretely, the spatial softmax of each feature map is used to compute a weighted 35 | mean of the pixel locations, effectively performing a soft arg-max over the feature 36 | dimension. 37 | """ 38 | 39 | def __init__(self, normalize: bool = True) -> None: 40 | super().__init__() 41 | 42 | self.normalize = normalize 43 | 44 | def _coord_grid( 45 | self, 46 | h: int, 47 | w: int, 48 | device: torch.device, 49 | ) -> torch.Tensor: 50 | if self.normalize: 51 | return torch.stack( 52 | torch.meshgrid( 53 | torch.linspace(-1, 1, w, device=device), 54 | torch.linspace(-1, 1, h, device=device), 55 | indexing="ij", 56 | ) 57 | ) 58 | return torch.stack( 59 | torch.meshgrid( 60 | torch.arange(0, w, device=device), 61 | torch.arange(0, h, device=device), 62 | indexing="ij", 63 | ) 64 | ) 65 | 66 | def forward(self, x: torch.Tensor) -> torch.Tensor: 67 | assert x.ndim == 4, "Expecting a tensor of shape (B, C, H, W)." 68 | 69 | # Compute a spatial softmax over the input: 70 | # Given an input of shape (B, C, H, W), reshape it to (B*C, H*W) then apply the 71 | # softmax operator over the last dimension. 72 | _, c, h, w = x.shape 73 | softmax = F.softmax(x.view(-1, h * w), dim=-1) 74 | 75 | # Create a meshgrid of normalized pixel coordinates. 76 | xc, yc = self._coord_grid(h, w, x.device) 77 | 78 | # Element-wise multiply the x and y coordinates with the softmax, then sum over 79 | # the h*w dimension. This effectively computes the weighted mean x and y 80 | # locations. 81 | x_mean = (softmax * xc.flatten()).sum(dim=1, keepdims=True) 82 | y_mean = (softmax * yc.flatten()).sum(dim=1, keepdims=True) 83 | 84 | # Concatenate and reshape the result to (B, C*2) where for every feature we have 85 | # the expected x and y pixel locations. 
86 | return torch.cat([x_mean, y_mean], dim=1).view(-1, c * 2) 87 | 88 | 89 | class GlobalMaxPool2d(nn.Module): 90 | """Global spatial max pooling layer.""" 91 | 92 | def __init__(self) -> None: 93 | super().__init__() 94 | 95 | self._pool = F.max_pool2d 96 | 97 | def forward(self, x: torch.Tensor) -> torch.Tensor: 98 | out = self._pool(x, kernel_size=x.size()[2:]) 99 | for _ in range(len(out.shape[2:])): 100 | out.squeeze_(dim=-1) 101 | return out 102 | 103 | 104 | class GlobalAvgPool2d(nn.Module): 105 | """Global spatial average pooling layer.""" 106 | 107 | def __init__(self) -> None: 108 | super().__init__() 109 | 110 | self._pool = F.avg_pool2d 111 | 112 | def forward(self, x: torch.Tensor) -> torch.Tensor: 113 | out = self._pool(x, kernel_size=x.size()[2:]) 114 | for _ in range(len(out.shape[2:])): 115 | out.squeeze_(dim=-1) 116 | return out 117 | -------------------------------------------------------------------------------- /ibc/optimizers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | import enum 5 | from typing import Protocol 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | # =================================================================== # 13 | # Model optimization. 14 | # =================================================================== # 15 | 16 | 17 | @dataclasses.dataclass 18 | class OptimizerConfig: 19 | learning_rate: float = 1e-3 20 | weight_decay: float = 0.0 21 | beta1: float = 0.9 22 | beta2: float = 0.999 23 | lr_scheduler_step: int = 100 24 | lr_scheduler_gamma: float = 0.99 25 | 26 | 27 | # =================================================================== # 28 | # Stochastic optimization for EBM training and inference. 29 | # =================================================================== # 30 | 31 | 32 | @dataclasses.dataclass 33 | class StochasticOptimizerConfig: 34 | bounds: np.ndarray 35 | """Bounds on the samples, min/max for each dimension.""" 36 | 37 | iters: int 38 | """The total number of inference iters.""" 39 | 40 | train_samples: int 41 | """The number of counter-examples to sample per iter during training.""" 42 | 43 | inference_samples: int 44 | """The number of candidates to sample per iter during inference.""" 45 | 46 | 47 | class StochasticOptimizer(Protocol): 48 | """Functionality that needs to be implemented by all stochastic optimizers.""" 49 | 50 | device: torch.device 51 | 52 | def sample(self, batch_size: int, ebm: nn.Module) -> torch.Tensor: 53 | """Sample counter-negatives for feeding to the InfoNCE objective.""" 54 | 55 | def infer(self, x: torch.Tensor, ebm: nn.Module) -> torch.Tensor: 56 | """Optimize for the best action conditioned on the current observation.""" 57 | 58 | 59 | @dataclasses.dataclass 60 | class DerivativeFreeConfig(StochasticOptimizerConfig): 61 | noise_scale: float = 0.33 62 | noise_shrink: float = 0.5 63 | iters: int = 3 64 | train_samples: int = 256 65 | inference_samples: int = 2 ** 14 66 | 67 | 68 | @dataclasses.dataclass 69 | class DerivativeFreeOptimizer: 70 | """A simple derivative-free optimizer. 
Great for up to 5 dimensions.""" 71 | 72 | device: torch.device 73 | noise_scale: float 74 | noise_shrink: float 75 | iters: int 76 | train_samples: int 77 | inference_samples: int 78 | bounds: np.ndarray 79 | 80 | @staticmethod 81 | def initialize( 82 | config: DerivativeFreeConfig, device_type: str 83 | ) -> DerivativeFreeOptimizer: 84 | return DerivativeFreeOptimizer( 85 | device=torch.device(device_type if torch.cuda.is_available() else "cpu"), 86 | noise_scale=config.noise_scale, 87 | noise_shrink=config.noise_shrink, 88 | iters=config.iters, 89 | train_samples=config.train_samples, 90 | inference_samples=config.inference_samples, 91 | bounds=config.bounds, 92 | ) 93 | 94 | def _sample(self, num_samples: int) -> torch.Tensor: 95 | """Helper method for drawing samples from the uniform random distribution.""" 96 | size = (num_samples, self.bounds.shape[1]) 97 | samples = np.random.uniform(self.bounds[0, :], self.bounds[1, :], size=size) 98 | return torch.as_tensor(samples, dtype=torch.float32, device=self.device) 99 | 100 | def sample(self, batch_size: int, ebm: nn.Module) -> torch.Tensor: 101 | del ebm # The derivative-free optimizer does not use the ebm for sampling. 102 | samples = self._sample(batch_size * self.train_samples) 103 | return samples.reshape(batch_size, self.train_samples, -1) 104 | 105 | @torch.no_grad() 106 | def infer(self, x: torch.Tensor, ebm: nn.Module) -> torch.Tensor: 107 | """Optimize for the best action given a trained EBM.""" 108 | noise_scale = self.noise_scale 109 | bounds = torch.as_tensor(self.bounds).to(self.device) 110 | 111 | samples = self._sample(x.size(0) * self.inference_samples) 112 | samples = samples.reshape(x.size(0), self.inference_samples, -1) 113 | 114 | for i in range(self.iters): 115 | # Compute energies. 116 | energies = ebm(x, samples) 117 | probs = F.softmax(-1.0 * energies, dim=-1) 118 | 119 | # Resample with replacement. 120 | idxs = torch.multinomial(probs, self.inference_samples, replacement=True) 121 | samples = samples[torch.arange(samples.size(0)).unsqueeze(-1), idxs] 122 | 123 | # Add noise and clip to target bounds. 124 | samples = samples + torch.randn_like(samples) * noise_scale 125 | samples = samples.clamp(min=bounds[0, :], max=bounds[1, :]) 126 | 127 | noise_scale *= self.noise_shrink 128 | 129 | # Return target with highest probability. 130 | energies = ebm(x, samples) 131 | probs = F.softmax(-1.0 * energies, dim=-1) 132 | best_idxs = probs.argmax(dim=-1) 133 | return samples[torch.arange(samples.size(0)), best_idxs, :] 134 | 135 | 136 | class StochasticOptimizerType(enum.Enum): 137 | # Note(kevin): The paper describes three types of samplers. Right now, we just have 138 | # the derivative free sampler implemented. 
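# The other two are an autoregressive variant of derivative-free optimization and
# gradient-based Langevin MCMC sampling.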
139 | DERIVATIVE_FREE = enum.auto() 140 | 141 | 142 | if __name__ == "__main__": 143 | from dataset import CoordinateRegression, DatasetConfig 144 | 145 | dataset = CoordinateRegression(DatasetConfig(dataset_size=10)) 146 | bounds = dataset.get_target_bounds() 147 | 148 | config = DerivativeFreeConfig(bounds=bounds, train_samples=256) 149 | so = DerivativeFreeOptimizer.initialize(config, "cuda") 150 | 151 | negatives = so.sample(64, nn.Identity()) 152 | assert negatives.shape == (64, config.train_samples, bounds.shape[1]) 153 | -------------------------------------------------------------------------------- /ibc/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | import enum 5 | from typing import Protocol 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from tqdm.auto import tqdm 11 | 12 | from . import experiment, models, optimizers 13 | 14 | 15 | class TrainStateProtocol(Protocol): 16 | """Functionality that needs to be implemented by all training states.""" 17 | 18 | model: nn.Module 19 | device: torch.device 20 | steps: int 21 | 22 | def training_step( 23 | self, input: torch.Tensor, target: torch.Tensor 24 | ) -> experiment.TensorboardLogData: 25 | """Performs a single training step on a mini-batch of data.""" 26 | 27 | def evaluate( 28 | self, dataloader: torch.utils.data.DataLoader 29 | ) -> experiment.TensorboardLogData: 30 | """Performs a full evaluation of the model on one epoch.""" 31 | 32 | def predict(self, input: torch.Tensor) -> torch.Tensor: 33 | """Performs a single inference step on a mini-batch of data.""" 34 | 35 | 36 | @dataclasses.dataclass 37 | class ExplicitTrainState: 38 | """An explicit feedforward policy trained with a MSE objective.""" 39 | 40 | model: nn.Module 41 | optimizer: torch.optim.Optimizer 42 | scheduler: torch.optim.lr_scheduler._LRScheduler 43 | device: torch.device 44 | steps: int 45 | 46 | @staticmethod 47 | def initialize( 48 | model_config: models.ConvMLPConfig, 49 | optim_config: optimizers.OptimizerConfig, 50 | device_type: str, 51 | ) -> ExplicitTrainState: 52 | device = torch.device(device_type if torch.cuda.is_available() else "cpu") 53 | print(f"Using device: {device}") 54 | 55 | model = models.ConvMLP(config=model_config) 56 | model.to(device) 57 | 58 | optimizer = torch.optim.Adam( 59 | model.parameters(), 60 | lr=optim_config.learning_rate, 61 | weight_decay=optim_config.weight_decay, 62 | betas=(optim_config.beta1, optim_config.beta2), 63 | ) 64 | 65 | scheduler = torch.optim.lr_scheduler.StepLR( 66 | optimizer, 67 | step_size=optim_config.lr_scheduler_step, 68 | gamma=optim_config.lr_scheduler_gamma, 69 | ) 70 | 71 | return ExplicitTrainState( 72 | model=model, 73 | optimizer=optimizer, 74 | scheduler=scheduler, 75 | device=device, 76 | steps=0, 77 | ) 78 | 79 | def training_step( 80 | self, input: torch.Tensor, target: torch.Tensor 81 | ) -> experiment.TensorboardLogData: 82 | self.model.train() 83 | 84 | input = input.to(self.device) 85 | target = target.to(self.device) 86 | 87 | out = self.model(input) 88 | loss = F.mse_loss(out, target) 89 | 90 | self.optimizer.zero_grad(set_to_none=True) 91 | loss.backward() 92 | self.optimizer.step() 93 | self.scheduler.step() 94 | 95 | self.steps += 1 96 | 97 | return experiment.TensorboardLogData( 98 | scalars={ 99 | "train/loss": loss.item(), 100 | "train/learning_rate": self.scheduler.get_last_lr()[0], 101 | } 102 | ) 103 | 104 | 
@torch.no_grad() 105 | def evaluate( 106 | self, dataloader: torch.utils.data.DataLoader 107 | ) -> experiment.TensorboardLogData: 108 | self.model.eval() 109 | 110 | total_mse = 0.0 111 | for input, target in tqdm(dataloader, leave=False): 112 | input = input.to(self.device) 113 | target = target.to(self.device) 114 | 115 | out = self.model(input) 116 | mse = F.mse_loss(out, target, reduction="none") 117 | total_mse += mse.mean(dim=-1).sum().item() 118 | 119 | mean_mse = total_mse / len(dataloader.dataset) 120 | return experiment.TensorboardLogData(scalars={"test/mse": mean_mse}) 121 | 122 | @torch.no_grad() 123 | def predict(self, input: torch.Tensor) -> torch.Tensor: 124 | self.model.eval() 125 | return self.model(input.to(self.device)) 126 | 127 | 128 | @dataclasses.dataclass 129 | class ImplicitTrainState: 130 | """An implicit conditional EBM trained with an InfoNCE objective.""" 131 | 132 | model: nn.Module 133 | optimizer: torch.optim.Optimizer 134 | scheduler: torch.optim.lr_scheduler._LRScheduler 135 | stochastic_optimizer: optimizers.StochasticOptimizer 136 | device: torch.device 137 | steps: int 138 | 139 | @staticmethod 140 | def initialize( 141 | model_config: models.ConvMLPConfig, 142 | optim_config: optimizers.OptimizerConfig, 143 | stochastic_optim_config: optimizers.DerivativeFreeConfig, 144 | device_type: str, 145 | ) -> ImplicitTrainState: 146 | device = torch.device(device_type if torch.cuda.is_available() else "cpu") 147 | print(f"Using device: {device}") 148 | 149 | model = models.EBMConvMLP(config=model_config) 150 | model.to(device) 151 | 152 | optimizer = torch.optim.Adam( 153 | model.parameters(), 154 | lr=optim_config.learning_rate, 155 | weight_decay=optim_config.weight_decay, 156 | betas=(optim_config.beta1, optim_config.beta2), 157 | ) 158 | 159 | scheduler = torch.optim.lr_scheduler.StepLR( 160 | optimizer, 161 | step_size=optim_config.lr_scheduler_step, 162 | gamma=optim_config.lr_scheduler_gamma, 163 | ) 164 | 165 | stochastic_optimizer = optimizers.DerivativeFreeOptimizer.initialize( 166 | stochastic_optim_config, 167 | device_type, 168 | ) 169 | 170 | return ImplicitTrainState( 171 | model=model, 172 | optimizer=optimizer, 173 | scheduler=scheduler, 174 | stochastic_optimizer=stochastic_optimizer, 175 | device=device, 176 | steps=0, 177 | ) 178 | 179 | def training_step( 180 | self, input: torch.Tensor, target: torch.Tensor 181 | ) -> experiment.TensorboardLogData: 182 | self.model.train() 183 | 184 | input = input.to(self.device) 185 | target = target.to(self.device) 186 | 187 | # Generate N negatives, one for each element in the batch: (B, N, D). 188 | negatives = self.stochastic_optimizer.sample(input.size(0), self.model) 189 | 190 | # Merge target and negatives: (B, N+1, D). 191 | targets = torch.cat([target.unsqueeze(dim=1), negatives], dim=1) 192 | 193 | # Generate a random permutation of the positives and negatives. 194 | permutation = torch.rand(targets.size(0), targets.size(1)).argsort(dim=1) 195 | targets = targets[torch.arange(targets.size(0)).unsqueeze(-1), permutation] 196 | 197 | # Get the original index of the positive. This will serve as the class label 198 | # for the loss. 199 | ground_truth = (permutation == 0).nonzero()[:, 1].to(self.device) 200 | 201 | # For every element in the mini-batch, there is 1 positive for which the EBM 202 | # should output a low energy value, and N negatives for which the EBM should 203 | # output high energy values. 
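# Concretely, with energies E_i = E(x, y_i) over the N+1 shuffled candidates, the
# loss below is the InfoNCE objective
#   L = -log( exp(-E_+) / sum_i exp(-E_i) ),
# i.e. cross entropy on logits -E_i with the positive's index as the class label.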
204 | energy = self.model(input, targets) 205 | 206 | # Interpreting the energy as a negative logit, we can apply a cross entropy loss 207 | # to train the EBM. 208 | logits = -1.0 * energy 209 | loss = F.cross_entropy(logits, ground_truth) 210 | 211 | self.optimizer.zero_grad(set_to_none=True) 212 | loss.backward() 213 | self.optimizer.step() 214 | self.scheduler.step() 215 | 216 | self.steps += 1 217 | 218 | return experiment.TensorboardLogData( 219 | scalars={ 220 | "train/loss": loss.item(), 221 | "train/learning_rate": self.scheduler.get_last_lr()[0], 222 | } 223 | ) 224 | 225 | @torch.no_grad() 226 | def evaluate( 227 | self, dataloader: torch.utils.data.DataLoader 228 | ) -> experiment.TensorboardLogData: 229 | self.model.eval() 230 | 231 | total_mse = 0.0 232 | for input, target in tqdm(dataloader, leave=False): 233 | input = input.to(self.device) 234 | target = target.to(self.device) 235 | 236 | out = self.stochastic_optimizer.infer(input, self.model) 237 | 238 | mse = F.mse_loss(out, target, reduction="none") 239 | total_mse += mse.mean(dim=-1).sum().item() 240 | 241 | mean_mse = total_mse / len(dataloader.dataset) 242 | return experiment.TensorboardLogData(scalars={"test/mse": mean_mse}) 243 | 244 | @torch.no_grad() 245 | def predict(self, input: torch.Tensor) -> torch.Tensor: 246 | self.model.eval() 247 | return self.stochastic_optimizer.infer(input.to(self.device), self.model) 248 | 249 | 250 | class PolicyType(enum.Enum): 251 | EXPLICIT = ExplicitTrainState 252 | """An explicit policy is a feedforward structure trained with a MSE objective.""" 253 | 254 | IMPLICIT = ImplicitTrainState 255 | """An implicit policy is a conditional EBM trained with an InfoNCE objective.""" 256 | -------------------------------------------------------------------------------- /ibc/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def seed_rngs(seed: int, pytorch: bool = True) -> None: 9 | os.environ["PYTHONHASHSEED"] = str(seed) 10 | random.seed(seed) 11 | np.random.seed(seed) 12 | if pytorch: 13 | torch.manual_seed(seed) 14 | 15 | 16 | def set_cudnn(deterministic: bool = False, benchmark: bool = True) -> None: 17 | torch.backends.cudnn.deterministic = deterministic 18 | torch.backends.cudnn.benchmark = benchmark 19 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = True 3 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | """Generate Figure 4 plot. 
Pass in --help flag for options.""" 2 | 3 | import dataclasses 4 | import pathlib 5 | from typing import Dict, Tuple 6 | 7 | import dcargs 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | from scipy.spatial import ConvexHull 13 | from tqdm.auto import tqdm 14 | 15 | from ibc.dataset import CoordinateRegression 16 | from ibc.experiment import Experiment 17 | from ibc.trainer import TrainStateProtocol 18 | from train import TrainConfig, make_dataloaders, make_train_state 19 | 20 | 21 | @dataclasses.dataclass 22 | class Args: 23 | experiment_name: str 24 | plot_dir: str = "assets" 25 | dpi: int = 200 26 | threshold: float = 140 27 | 28 | 29 | def eval( 30 | train_state: TrainStateProtocol, 31 | dataloaders: Dict[str, torch.utils.data.DataLoader], 32 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 33 | dataset_test = dataloaders["test"].dataset 34 | dataset_train = dataloaders["train"].dataset 35 | assert isinstance(dataset_test, CoordinateRegression) 36 | assert isinstance(dataset_train, CoordinateRegression) 37 | 38 | total_mse = 0.0 39 | num_small_err = 0 40 | pixel_error = [] 41 | for batch in tqdm(dataloaders["test"]): 42 | input, target = batch 43 | prediction = train_state.predict(input).cpu().numpy() 44 | target = target.cpu().numpy() 45 | 46 | pred_unscaled = np.array(prediction) 47 | pred_unscaled += 1 48 | pred_unscaled /= 2 49 | pred_unscaled[:, 0] *= dataset_test.resolution[0] - 1 50 | pred_unscaled[:, 1] *= dataset_test.resolution[1] - 1 51 | 52 | target_unscaled = np.array(target) 53 | target_unscaled += 1 54 | target_unscaled /= 2 55 | target_unscaled[:, 0] *= dataset_test.resolution[0] - 1 56 | target_unscaled[:, 1] *= dataset_test.resolution[1] - 1 57 | 58 | diff = pred_unscaled - target_unscaled 59 | error = np.asarray(np.linalg.norm(diff, axis=1)) 60 | num_small_err += len(error[error < 1.0]) 61 | pixel_error.extend(error.tolist()) 62 | total_mse += (diff ** 2).mean(axis=1).sum() 63 | 64 | total_test = len(dataset_test) 65 | average_mse = total_mse / total_test 66 | print(f"Test set MSE: {average_mse} ({num_small_err}/{total_test})") 67 | 68 | test_coords = dataset_test.coordinates 69 | train_coords = dataset_train.coordinates 70 | return train_coords, test_coords, np.asarray(pixel_error) 71 | 72 | 73 | def plot( 74 | train_coords: np.ndarray, 75 | test_coords: np.ndarray, 76 | errors: np.ndarray, 77 | resolution: Tuple[int, int], 78 | plot_path: pathlib.Path, 79 | dpi: int, 80 | threshold: float, 81 | ) -> None: 82 | # Threshold the errors so that all generated plot colors cover the same range. 83 | errors[errors >= threshold] = threshold 84 | colormap = plt.cm.Reds 85 | normalize = matplotlib.colors.Normalize(vmin=0, vmax=threshold) 86 | 87 | plt.scatter( 88 | train_coords[:, 0], 89 | train_coords[:, 1], 90 | marker="x", 91 | c="black", 92 | zorder=2, 93 | alpha=0.5, 94 | ) 95 | plt.scatter( 96 | test_coords[:, 0], 97 | test_coords[:, 1], 98 | c=errors, 99 | cmap=colormap, 100 | norm=normalize, 101 | zorder=1, 102 | ) 103 | plt.colorbar() 104 | 105 | # Find index of predictions with less than 1 pixel error and color them in blue. 106 | idxs = errors < 1.0 107 | plt.scatter( 108 | test_coords[idxs, 0], 109 | test_coords[idxs, 1], 110 | marker="o", 111 | c="blue", 112 | zorder=1, 113 | alpha=1.0, 114 | ) 115 | 116 | # Add convex hull of train set. 
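# scipy's ConvexHull needs at least three points, hence the guard below.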
117 | if train_coords.shape[0] > 2: 118 | for simplex in ConvexHull(train_coords).simplices: 119 | plt.plot( 120 | train_coords[simplex, 0], 121 | train_coords[simplex, 1], 122 | "--", 123 | zorder=2, 124 | alpha=0.5, 125 | c="black", 126 | ) 127 | 128 | plt.xlim(0 - 2, resolution[1] + 2) 129 | plt.ylim(0 - 2, resolution[0] + 2) 130 | 131 | plt.savefig(plot_path, format="png", dpi=dpi) 132 | plt.close() 133 | 134 | 135 | def main(args: Args): 136 | plot_dir = pathlib.Path(args.plot_dir) 137 | plot_dir.mkdir(parents=True, exist_ok=True) 138 | 139 | experiment = Experiment( 140 | identifier=args.experiment_name, 141 | ).assert_exists() 142 | 143 | # Read saved config file. 144 | train_config = experiment.read_metadata("config", TrainConfig) 145 | 146 | # Restore training state. 147 | dataloaders = make_dataloaders(train_config) 148 | train_state = make_train_state(train_config, dataloaders["train"]) 149 | experiment.restore_checkpoint(train_state) 150 | print(f"Loaded checkpoint at step: {train_state.steps}.") 151 | 152 | # Compute MSE for every test set data point. 153 | train_coords, test_coords, errors = eval(train_state, dataloaders) 154 | 155 | # Plot and dump to disk. 156 | plot( 157 | train_coords, 158 | test_coords, 159 | errors, 160 | dataloaders["test"].dataset.resolution, 161 | plot_dir / f"{args.experiment_name}.png", 162 | args.dpi, 163 | args.threshold, 164 | ) 165 | 166 | 167 | if __name__ == "__main__": 168 | main(dcargs.parse(Args)) 169 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dcargs 2 | torch 3 | torchvision 4 | matplotlib 5 | tqdm 6 | tensorboard 7 | scipy 8 | -------------------------------------------------------------------------------- /run_explicit.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | set -x 5 | 6 | for train_size in 10 30 7 | do 8 | EXPERIMENT_NAME=explicit_mse_${train_size} 9 | 10 | python train.py \ 11 | --experiment-name $EXPERIMENT_NAME \ 12 | --train-dataset-size $train_size \ 13 | --policy-type EXPLICIT \ 14 | --dropout-prob 0.1 \ 15 | --weight-decay 1e-4 \ 16 | --max-epochs 2000 \ 17 | --learning-rate 1e-3 \ 18 | --spatial-reduction SPATIAL_SOFTMAX \ 19 | --coord-conv \ 20 | 21 | python plot.py --experiment-name $EXPERIMENT_NAME 22 | done 23 | -------------------------------------------------------------------------------- /run_implicit.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | set -x 5 | 6 | for train_size in 10 30 7 | do 8 | EXPERIMENT_NAME=implicit_ebm_${train_size} 9 | 10 | python train.py \ 11 | --experiment-name $EXPERIMENT_NAME \ 12 | --train-dataset-size $train_size \ 13 | --policy-type IMPLICIT \ 14 | --dropout-prob 0.0 \ 15 | --weight-decay 0.0 \ 16 | --max-epochs 2000 \ 17 | --learning-rate 1e-3 \ 18 | --spatial-reduction SPATIAL_SOFTMAX \ 19 | --coord-conv \ 20 | --stochastic-optimizer-train-samples 128 \ 21 | 22 | python plot.py --experiment-name $EXPERIMENT_NAME 23 | done 24 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """Script for training. 
Pass in --help flag for options.""" 2 | 3 | import dataclasses 4 | from typing import Dict, Optional 5 | 6 | import dcargs 7 | import torch 8 | from tqdm.auto import tqdm 9 | 10 | from ibc import dataset, models, optimizers, trainer, utils 11 | from ibc.experiment import Experiment 12 | 13 | 14 | @dataclasses.dataclass 15 | class TrainConfig: 16 | experiment_name: str 17 | seed: int = 0 18 | device_type: str = "cuda" 19 | train_dataset_size: int = 10 20 | test_dataset_size: int = 500 21 | max_epochs: int = 200 22 | learning_rate: float = 1e-3 23 | weight_decay: float = 0.0 24 | train_batch_size: int = 8 25 | test_batch_size: int = 64 26 | spatial_reduction: models.SpatialReduction = models.SpatialReduction.SPATIAL_SOFTMAX 27 | coord_conv: bool = False 28 | dropout_prob: Optional[float] = None 29 | num_workers: int = 1 30 | cudnn_deterministic: bool = True 31 | cudnn_benchmark: bool = False 32 | log_every_n_steps: int = 10 33 | checkpoint_every_n_steps: int = 100 34 | eval_every_n_steps: int = 1000 35 | policy_type: trainer.PolicyType = trainer.PolicyType.EXPLICIT 36 | stochastic_optimizer_train_samples: int = 64 37 | 38 | 39 | def make_dataloaders( 40 | train_config: TrainConfig, 41 | ) -> Dict[str, torch.utils.data.DataLoader]: 42 | """Initialize train/test dataloaders based on config values.""" 43 | # Train split. 44 | train_dataset_config = dataset.DatasetConfig( 45 | dataset_size=train_config.train_dataset_size, 46 | seed=train_config.seed, 47 | ) 48 | train_dataset = dataset.CoordinateRegression(train_dataset_config) 49 | train_dataloader = torch.utils.data.DataLoader( 50 | train_dataset, 51 | batch_size=train_config.train_batch_size, 52 | shuffle=True, 53 | num_workers=train_config.num_workers, 54 | pin_memory=torch.cuda.is_available(), 55 | ) 56 | 57 | # Test split. 58 | test_dataset_config = dataset.DatasetConfig( 59 | dataset_size=train_config.test_dataset_size, 60 | seed=train_config.seed, 61 | ) 62 | test_dataset = dataset.CoordinateRegression(test_dataset_config) 63 | test_dataset.exclude(train_dataset.coordinates) 64 | test_dataloader = torch.utils.data.DataLoader( 65 | test_dataset, 66 | batch_size=train_config.test_batch_size, 67 | shuffle=False, 68 | num_workers=train_config.num_workers, 69 | pin_memory=torch.cuda.is_available(), 70 | ) 71 | 72 | return { 73 | "train": train_dataloader, 74 | "test": test_dataloader, 75 | } 76 | 77 | 78 | def make_train_state( 79 | train_config: TrainConfig, 80 | train_dataloader: torch.utils.data.DataLoader, 81 | ) -> trainer.TrainStateProtocol: 82 | """Initialize train state based on config values.""" 83 | in_channels = 3 84 | if train_config.coord_conv: 85 | in_channels += 2 86 | residual_blocks = [16, 32, 32] 87 | cnn_config = models.CNNConfig(in_channels, residual_blocks) 88 | 89 | input_dim = 16 # We have a 1x1 conv that reduces to 16 channels. 90 | output_dim = 2 91 | if train_config.spatial_reduction == models.SpatialReduction.SPATIAL_SOFTMAX: 92 | input_dim *= 2 93 | if train_config.policy_type == trainer.PolicyType.IMPLICIT: 94 | input_dim += 2 # Dimension of the targets. 
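# The implicit policy's EBM scores (image features, candidate target) pairs and
# returns a single scalar energy per candidate, hence the 1-D output below.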
95 | output_dim = 1 96 | mlp_config = models.MLPConfig( 97 | input_dim=input_dim, 98 | hidden_dim=256, 99 | output_dim=output_dim, 100 | hidden_depth=1, 101 | dropout_prob=train_config.dropout_prob, 102 | ) 103 | 104 | model_config = models.ConvMLPConfig( 105 | cnn_config=cnn_config, 106 | mlp_config=mlp_config, 107 | spatial_reduction=train_config.spatial_reduction, 108 | coord_conv=train_config.coord_conv, 109 | ) 110 | 111 | optim_config = optimizers.OptimizerConfig( 112 | learning_rate=train_config.learning_rate, 113 | weight_decay=train_config.weight_decay, 114 | ) 115 | 116 | train_state: trainer.TrainStateProtocol 117 | if train_config.policy_type == trainer.PolicyType.EXPLICIT: 118 | train_state = trainer.ExplicitTrainState.initialize( 119 | model_config=model_config, 120 | optim_config=optim_config, 121 | device_type=train_config.device_type, 122 | ) 123 | else: 124 | target_bounds = train_dataloader.dataset.get_target_bounds() 125 | stochastic_optim_config = optimizers.DerivativeFreeConfig( 126 | bounds=target_bounds, 127 | train_samples=train_config.stochastic_optimizer_train_samples, 128 | ) 129 | 130 | train_state = trainer.ImplicitTrainState.initialize( 131 | model_config=model_config, 132 | optim_config=optim_config, 133 | stochastic_optim_config=stochastic_optim_config, 134 | device_type=train_config.device_type, 135 | ) 136 | 137 | return train_state 138 | 139 | 140 | def main(train_config: TrainConfig) -> None: 141 | # Seed RNGs. 142 | utils.seed_rngs(train_config.seed) 143 | 144 | # CUDA/CUDNN-related shenanigans. 145 | utils.set_cudnn(train_config.cudnn_deterministic, train_config.cudnn_benchmark) 146 | 147 | experiment = Experiment( 148 | identifier=train_config.experiment_name, 149 | ).assert_new() 150 | 151 | # Write some metadata. 152 | experiment.write_metadata("config", train_config) 153 | 154 | # Initialize train and test dataloaders. 155 | dataloaders = make_dataloaders(train_config) 156 | 157 | train_state = make_train_state(train_config, dataloaders["train"]) 158 | 159 | for epoch in tqdm(range(train_config.max_epochs)): 160 | if not train_state.steps % train_config.checkpoint_every_n_steps: 161 | experiment.save_checkpoint(train_state, step=train_state.steps) 162 | 163 | if not train_state.steps % train_config.eval_every_n_steps: 164 | test_log_data = train_state.evaluate(dataloaders["test"]) 165 | experiment.log(test_log_data, step=train_state.steps) 166 | 167 | for batch in dataloaders["train"]: 168 | train_log_data = train_state.training_step(*batch) 169 | 170 | # Log to tensorboard. 171 | if not train_state.steps % train_config.log_every_n_steps: 172 | experiment.log(train_log_data, step=train_state.steps) 173 | 174 | # Save one final checkpoint. 175 | experiment.save_checkpoint(train_state, step=train_state.steps) 176 | 177 | 178 | if __name__ == "__main__": 179 | main(dcargs.parse(TrainConfig, description=__doc__)) 180 | --------------------------------------------------------------------------------
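For reference, here is a minimal inference sketch showing how the pieces above fit together. It mirrors the restore-and-predict path in `plot.py` and assumes you are running from the repository root with an already-trained experiment on disk (e.g. `implicit_ebm_10` produced by `run_implicit.sh`); variable names are illustrative only.

```python
import numpy as np

from ibc.experiment import Experiment
from train import TrainConfig, make_dataloaders, make_train_state

# Restore a previously trained policy (explicit or implicit) from its experiment folder.
experiment = Experiment(identifier="implicit_ebm_10").assert_exists()
train_config = experiment.read_metadata("config", TrainConfig)

dataloaders = make_dataloaders(train_config)
train_state = make_train_state(train_config, dataloaders["train"])
experiment.restore_checkpoint(train_state)
print(f"Restored checkpoint at step {train_state.steps}.")

# Predict the coordinate for a single test image. Both policies return targets in the
# [-1, 1] range; the implicit one runs the derivative-free optimizer internally.
test_dataset = dataloaders["test"].dataset
image, target = test_dataset[0]
prediction = train_state.predict(image.unsqueeze(0)).cpu().numpy()[0]

# Undo the [-1, 1] scaling to recover (row, col) pixel coordinates.
resolution = np.array(test_dataset.resolution)
pred_px = (prediction + 1.0) / 2.0 * (resolution - 1)
target_px = (target.numpy() + 1.0) / 2.0 * (resolution - 1)
print(f"Predicted: {pred_px}, ground truth: {target_px}")
```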