├── .gitignore
├── CITATION.cff
├── README.md
├── assets
│   ├── explicit_mse_10.png
│   ├── explicit_mse_30.png
│   ├── implicit_ebm_10.png
│   └── implicit_ebm_30.png
├── ibc
│   ├── __init__.py
│   ├── dataset.py
│   ├── experiment.py
│   ├── models.py
│   ├── modules.py
│   ├── optimizers.py
│   ├── trainer.py
│   └── utils.py
├── mypy.ini
├── plot.py
├── requirements.txt
├── run_explicit.sh
├── run_implicit.sh
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | experiments/
3 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: Zakka
5 | given-names: Kevin
6 | title: "A PyTorch Implementation of Implicit Behavioral Cloning"
7 | version: 0.0.1
8 | date-released: 2021-10-27
9 | url: "https://github.com/kevinzakka/ibc"
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Implicit Behavioral Cloning - PyTorch
2 |
3 | A PyTorch implementation of Implicit Behavioral Cloning.
4 |
5 | ## Install
6 |
7 | ```bash
8 | conda create -n ibc python=3.8
9 | conda activate ibc
10 | pip install -r requirements.txt
10 | ```
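
Training runs are launched through `train.py` (pass `--help` for the full list of options). As a minimal example, a single implicit-policy run using a subset of the flags from `run_implicit.sh` (remaining options fall back to the `TrainConfig` defaults):

```bash
python train.py \
    --experiment-name implicit_ebm_10 \
    --train-dataset-size 10 \
    --policy-type IMPLICIT \
    --spatial-reduction SPATIAL_SOFTMAX \
    --coord-conv
```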
11 |
12 | ## Results
13 |
14 | To reproduce the results from the Coordinate Regression task (Section 3), execute the `run_explicit.sh` and `run_implicit.sh` scripts (see the example commands below the table). Note that the implicit policy does slightly worse with 30 examples than with 10; I'm not entirely sure why that is the case and it needs further investigation.
15 |
16 | | | Explicit Policy | Implicit Policy |
17 | |-------------|-----------------|-----------------|
18 | | 10 examples | ![](assets/explicit_mse_10.png) | ![](assets/implicit_ebm_10.png) |
19 | | 30 examples | ![](assets/explicit_mse_30.png) | ![](assets/implicit_ebm_30.png) |
20 |
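The commands below regenerate both columns; each script trains for 10 and 30 examples and then calls `plot.py`, which writes the figures to `assets/` (assuming the environment from the Install section is active):

```bash
bash run_explicit.sh   # explicit MSE policy, 10 and 30 train examples
bash run_implicit.sh   # implicit EBM policy, 10 and 30 train examples
```
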
21 | ## Citation
22 |
23 | If you find this code useful, consider citing it along with the paper:
24 |
25 | ```bibtex
26 | @software{zakka2021ibc,
27 | author = {Zakka, Kevin},
28 | month = {10},
29 | title = {{A PyTorch Implementation of Implicit Behavioral Cloning}},
30 | url = {https://github.com/kevinzakka/ibc},
31 | version = {0.0.1},
32 | year = {2021}
33 | }
34 | ```
35 |
36 | ```bibtex
37 | @misc{florence2021implicit,
38 | title = {Implicit Behavioral Cloning},
39 | author = {Pete Florence and Corey Lynch and Andy Zeng and Oscar Ramirez and Ayzaan Wahid and Laura Downs and Adrian Wong and Johnny Lee and Igor Mordatch and Jonathan Tompson},
40 | year = {2021},
41 | eprint = {2109.00137},
42 | archivePrefix = {arXiv},
43 | primaryClass = {cs.RO}
44 | }
45 | ```
46 |
--------------------------------------------------------------------------------
/assets/explicit_mse_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevinzakka/ibc/a30c4f527ccf23b7b893824e1c64b7385b7e85e7/assets/explicit_mse_10.png
--------------------------------------------------------------------------------
/assets/explicit_mse_30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevinzakka/ibc/a30c4f527ccf23b7b893824e1c64b7385b7e85e7/assets/explicit_mse_30.png
--------------------------------------------------------------------------------
/assets/implicit_ebm_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevinzakka/ibc/a30c4f527ccf23b7b893824e1c64b7385b7e85e7/assets/implicit_ebm_10.png
--------------------------------------------------------------------------------
/assets/implicit_ebm_30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevinzakka/ibc/a30c4f527ccf23b7b893824e1c64b7385b7e85e7/assets/implicit_ebm_30.png
--------------------------------------------------------------------------------
/ibc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevinzakka/ibc/a30c4f527ccf23b7b893824e1c64b7385b7e85e7/ibc/__init__.py
--------------------------------------------------------------------------------
/ibc/dataset.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from typing import Optional, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 | from torchvision.transforms import ToTensor
8 |
9 |
10 | @dataclasses.dataclass
11 | class DatasetConfig:
12 | dataset_size: int = 30
13 | """The size of the dataset. Useful for sample efficiency experiments."""
14 |
15 | resolution: Tuple[int, int] = (96, 96)
16 | """The resolution of the image."""
17 |
18 | pixel_size: int = 7
19 | """The size of the pixel whose coordinates we'd like to regress. Must be odd."""
20 |
21 | pixel_color: Tuple[int, int, int] = (0, 255, 0)
22 | """The color of the pixel whose coordinates we'd like to regress."""
23 |
24 | seed: Optional[int] = None
25 | """Whether to seed the dataset. Disabled if None."""
26 |
27 |
28 | class CoordinateRegression(Dataset):
29 | """Regress the coordinates of a colored pixel block on a white canvas."""
30 |
31 | def __init__(self, config: DatasetConfig) -> None:
32 | if not config.pixel_size % 2:
33 | raise ValueError("'pixel_size' must be odd.")
34 |
35 | self.dataset_size = config.dataset_size
36 | self.resolution = config.resolution
37 | self.pixel_size = config.pixel_size
38 | self.pixel_color = config.pixel_color
39 | self.seed = config.seed
40 |
41 | self.reset()
42 |
43 | def reset(self) -> None:
44 | if self.seed is not None:
45 | np.random.seed(self.seed)
46 |
47 | self._coordinates = self._sample_coordinates(self.dataset_size)
48 | self._coordinates_scaled = self._scale_coordinates(self._coordinates)
49 |
50 | def exclude(self, coordinates: np.ndarray) -> None:
51 | """Exclude the given coordinates, if present, from the previously sampled ones.
52 |
53 | This is useful for ensuring the train set does not accidentally leak into the
54 | test set.
55 | """
56 | mask = (self.coordinates[:, None] == coordinates).all(-1).any(1)
57 | num_matches = mask.sum()
58 | while mask.sum() > 0:
59 | self._coordinates[mask] = self._sample_coordinates(mask.sum())
60 | mask = (self.coordinates[:, None] == coordinates).all(-1).any(1)
61 | self._coordinates_scaled = self._scale_coordinates(self._coordinates)
62 | print(f"Resampled {num_matches} data points.")
63 |
64 | def get_target_bounds(self) -> np.ndarray:
65 | """Return per-dimension target min/max."""
66 | return np.array([[-1.0, -1.0], [1.0, 1.0]])
67 |
68 | def _sample_coordinates(self, size: int) -> np.ndarray:
69 | """Helper method for generating pixel coordinates."""
70 | # Randomly generate pixel coordinates.
71 | u = np.random.randint(0, self.resolution[0], size=size)
72 | v = np.random.randint(0, self.resolution[1], size=size)
73 |
74 | # Ensure we remain within bounds when we take the pixel size into account.
75 | slack = self.pixel_size // 2
76 | u = np.clip(u, a_min=slack, a_max=self.resolution[0] - 1 - slack)
77 | v = np.clip(v, a_min=slack, a_max=self.resolution[1] - 1 - slack)
78 |
79 | return np.vstack([u, v]).astype(np.int16).T
80 |
81 | def _scale_coordinates(self, coords: np.ndarray) -> np.ndarray:
82 | """Helper method for scaling coordinates to the [-1, 1] range."""
83 | coords_scaled = np.array(coords, dtype=np.float32)
84 | coords_scaled[:, 0] /= self.resolution[0] - 1
85 | coords_scaled[:, 1] /= self.resolution[1] - 1
86 | coords_scaled *= 2
87 | coords_scaled -= 1
88 | return coords_scaled
89 |
90 | @property
91 | def image_shape(self) -> Tuple[int, int, int]:
92 | return self.resolution + (3,)
93 |
94 | @property
95 | def coordinates(self) -> np.ndarray:
96 | return self._coordinates
97 |
98 | @property
99 | def coordinates_scaled(self) -> np.ndarray:
100 | return self._coordinates_scaled
101 |
102 | def __len__(self) -> int:
103 | return self.dataset_size
104 |
105 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
106 | uv = self._coordinates[index]
107 | uv_scaled = self._coordinates_scaled[index]
108 |
109 | image = np.full(self.image_shape, fill_value=255, dtype=np.uint8)
110 | image[
111 | uv[0] - self.pixel_size // 2 : uv[0] + self.pixel_size // 2 + 1,
112 | uv[1] - self.pixel_size // 2 : uv[1] + self.pixel_size // 2 + 1,
113 | ] = self.pixel_color
114 |
115 | image_tensor = ToTensor()(image)
116 | target_tensor = torch.as_tensor(uv_scaled, dtype=torch.float32)
117 |
118 | return image_tensor, target_tensor
119 |
120 |
121 | if __name__ == "__main__":
122 | import matplotlib.pyplot as plt
123 | from scipy.spatial import ConvexHull
124 |
125 | dataset = CoordinateRegression(DatasetConfig(dataset_size=30, seed=0))
126 |
127 | # Visualize one instance.
128 | image, target = dataset[np.random.randint(len(dataset))]
129 | print(target)
130 | plt.imshow(image.permute(1, 2, 0).numpy())
131 | plt.show()
132 |
133 | # Plot target distribution and convex hull.
134 | targets = dataset.coordinates
135 | plt.scatter(targets[:, 0], targets[:, 1], marker="x", c="black")
136 | for simplex in ConvexHull(targets).simplices:
137 | plt.plot(
138 | targets[simplex, 0],
139 | targets[simplex, 1],
140 | "--",
141 | zorder=2,
142 | alpha=0.5,
143 | c="black",
144 | )
145 | plt.xlim(0, dataset.resolution[1])
146 | plt.ylim(0, dataset.resolution[0])
147 | plt.show()
148 |
149 | # Plot target distribution and convex hull.
150 | targets = dataset.coordinates_scaled
151 | plt.scatter(targets[:, 0], targets[:, 1], marker="x", c="black")
152 | for simplex in ConvexHull(targets).simplices:
153 | plt.plot(
154 | targets[simplex, 0],
155 | targets[simplex, 1],
156 | "--",
157 | zorder=2,
158 | alpha=0.5,
159 | c="black",
160 | )
161 | plt.xlim(-1, 1)
162 | plt.ylim(-1, 1)
163 | plt.show()
164 |
165 | print("Target bounds:")
166 | print(dataset.get_target_bounds())
167 |
--------------------------------------------------------------------------------
/ibc/experiment.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import dataclasses
4 | import os
5 | import pathlib
6 | import signal
7 | import tempfile
8 | from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union
9 |
10 | if TYPE_CHECKING:
11 | from .trainer import TrainStateProtocol
12 |
13 | import numpy as np
14 | import torch
15 | import yaml
16 | from torch.utils.tensorboard import SummaryWriter
17 |
18 | T = TypeVar("T")
19 | TensorOrFloat = Union[np.ndarray, torch.Tensor, float]
20 |
21 |
22 | @dataclasses.dataclass
23 | class TensorboardLogData:
24 | scalars: Dict[str, TensorOrFloat] = dataclasses.field(default_factory=dict)
25 |
26 | @staticmethod
27 | def merge(a: TensorboardLogData, b: TensorboardLogData) -> TensorboardLogData:
28 | return TensorboardLogData(scalars=dict(**a.scalars, **b.scalars))
29 |
30 | def extend(self, scalars: Dict[str, TensorOrFloat] = {}) -> TensorboardLogData:
31 | return TensorboardLogData.merge(self, TensorboardLogData(scalars=scalars))
32 |
33 |
34 | @dataclasses.dataclass(frozen=True)
35 | class Experiment:
36 | identifier: str
37 | """The name of the experiment."""
38 |
39 | data_dir: pathlib.Path = dataclasses.field(init=False)
40 | log_dir: pathlib.Path = dataclasses.field(init=False)
41 | checkpoint_dir: pathlib.Path = dataclasses.field(init=False)
42 |
43 | def __post_init__(self) -> None:
44 | root = pathlib.Path("./experiments")
45 |
46 | super().__setattr__("data_dir", root / self.identifier)
47 | super().__setattr__("log_dir", self.data_dir / "tb")
48 | super().__setattr__("checkpoint_dir", self.data_dir / "checkpoints")
49 |
50 | def assert_new(self) -> Experiment:
51 | """Makes sure that there are no existing checkpoints, logs, or metadata."""
52 | assert not self.data_dir.exists() or tuple(self.data_dir.iterdir()) == ()
53 | return self
54 |
55 | def assert_exists(self) -> Experiment:
56 | """Makes sure that there are existing checkpoints, logs, or metadata."""
57 | assert self.data_dir.exists() and tuple(self.data_dir.iterdir()) != ()
58 | return self
59 |
60 | # =================================================================== #
61 | # Properties.
62 | # =================================================================== #
63 |
64 | # Note: This hack is necessary because the SummaryWriter object instantly creates
65 | # a directory upon construction and thus would break the `assert_*` functionality.
66 | @property
67 | def summary_writer(self) -> SummaryWriter:
68 | if not hasattr(self, "__summary_writer__"):
69 | object.__setattr__(
70 | self,
71 | "__summary_writer__",
72 | SummaryWriter(self.log_dir),
73 | )
74 | return object.__getattribute__(self, "__summary_writer__")
75 |
76 | # =================================================================== #
77 | # Checkpointing.
78 | # =================================================================== #
79 |
80 | def save_checkpoint(
81 | self,
82 | target: TrainStateProtocol,
83 | step: int,
84 | prefix: str = "ckpt_",
85 | keep: int = 10,
86 | ) -> None:
87 | """Save a snapshot of the train state to disk."""
88 | self._ensure_directory_exists(self.checkpoint_dir)
89 |
90 | # Create a snapshot of the state.
91 | snapshot_state = {}
92 | for k, v in target.__dict__.items():
93 | if hasattr(v, "state_dict"):
94 | snapshot_state[k] = v.state_dict()
95 | snapshot_state["steps"] = target.steps
96 |
97 | # Save to disk.
98 | checkpoint_path = self.checkpoint_dir / f"{prefix}{step}.ckpt"
99 | self._atomic_save(checkpoint_path, snapshot_state)
100 |
101 | # Trim extraneous checkpoints.
102 | self._trim_checkpoints(keep)
103 |
104 | def restore_checkpoint(
105 | self,
106 | target: TrainStateProtocol,
107 | step: Optional[int] = None,
108 | prefix: str = "ckpt_",
109 | ):
110 | """Restore a snapshot of the train state from disk."""
111 | # Get latest checkpoint if no step has been provided.
112 | if step is None:
113 | step = self._get_latest_checkpoint_step()
114 |
115 | checkpoint_path = self.checkpoint_dir / f"{prefix}{step}.ckpt"
116 | snapshot_state = torch.load(checkpoint_path, map_location="cpu")
117 |
118 | delattr(target, "steps")
119 | setattr(target, "steps", snapshot_state["steps"])
120 |
121 | for k, v in target.__dict__.items():
122 | if hasattr(v, "state_dict"):
123 | getattr(target, k).load_state_dict(snapshot_state[k])
124 |
125 | # =================================================================== #
126 | # Logging.
127 | # =================================================================== #
128 |
129 | def log(self, log_data: TensorboardLogData, step: int) -> None:
130 | """Logging helper for TensorBoard."""
131 | for k, v in log_data.scalars.items():
132 | self.summary_writer.add_scalar(tag=k, scalar_value=v, global_step=step)
133 | self.summary_writer.flush()
134 |
135 | def write_metadata(self, name: str, object: Any) -> None:
136 | """Serialize an object to disk as a yaml file."""
137 | self._ensure_directory_exists(self.data_dir)
138 | assert not name.endswith(".yaml")
139 |
140 | path = self.data_dir / (name + ".yaml")
141 | print(f"Writing metadata to {path}")
142 | with open(path, "w") as fp:
143 | yaml.dump(object, fp)
144 |
145 | def read_metadata(self, name: str, expected_type: Type[T]) -> T:
146 | """Load an object from the experiment's metadata directory."""
147 | path = self.data_dir / (name + ".yaml")
148 |
149 | with open(path, "r") as fp:
150 | output = yaml.load(fp, Loader=yaml.Loader)
151 |
152 | assert isinstance(output, expected_type)
153 | return output
154 |
155 | # =================================================================== #
156 | # Helper functions.
157 | # =================================================================== #
158 |
159 | def _ensure_directory_exists(self, path: pathlib.Path) -> None:
160 | """Helper for ensuring that a directory exists."""
161 | if not path.exists():
162 | path.mkdir(parents=True)
163 |
164 | def _trim_checkpoints(self, keep: int) -> None:
165 | """Helper for deleting older checkpoints."""
166 | # Get a list of checkpoints.
167 | ckpts = list(self.checkpoint_dir.glob(pattern="*.ckpt"))
168 |
169 | # Sort in reverse `step` order.
170 | ckpts.sort(key=lambda f: -int(f.stem.split("_")[-1]))
171 |
172 | # Remove until `keep` remain.
173 | while len(ckpts) - keep > 0:
174 | ckpts.pop().unlink()
175 |
176 | def _get_latest_checkpoint_step(self) -> int:
177 | """Helper for returning the step of the latest checkpoint."""
178 | ckpts = list(self.checkpoint_dir.glob(pattern="*.ckpt"))
179 | ckpts.sort(key=lambda f: -int(f.stem.split("_")[-1]))
180 | return int(ckpts[0].stem.split("_")[-1])
181 |
182 | def _atomic_save(self, save_path: pathlib.Path, snapshot: Dict[str, Any]) -> None:
183 | """Helper for safely saving to disk."""
184 | # Ignore Ctrl-C while saving.
185 | try:
186 | orig_handler = signal.getsignal(signal.SIGINT)
187 | signal.signal(signal.SIGINT, lambda _sig, _frame: None)
188 | except ValueError:
189 | # Signal throws a ValueError if we're not in the main thread.
190 | orig_handler = None
191 |
192 | with tempfile.TemporaryDirectory() as tmp_dir:
193 | tmp_path = pathlib.Path(tmp_dir) / "tmp.ckpt"
194 | torch.save(snapshot, tmp_path)
195 | # `rename` is POSIX-compliant and thus, is an atomic operation.
196 | # Ref: https://docs.python.org/3/library/os.html#os.rename
197 | os.rename(tmp_path, save_path)
198 |
199 | # Restore SIGINT handler.
200 | if orig_handler is not None:
201 | signal.signal(signal.SIGINT, orig_handler)
202 |
--------------------------------------------------------------------------------
/ibc/models.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import enum
3 | from functools import partial
4 | from typing import Callable, Optional, Sequence
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | from .modules import CoordConv, GlobalAvgPool2d, GlobalMaxPool2d, SpatialSoftArgmax
11 |
12 |
13 | class ActivationType(enum.Enum):
14 | RELU = nn.ReLU
15 | SILU = nn.SiLU
16 |
17 |
18 | @dataclasses.dataclass(frozen=True)
19 | class MLPConfig:
20 | input_dim: int
21 | hidden_dim: int
22 | output_dim: int
23 | hidden_depth: int
24 | dropout_prob: Optional[float] = None
25 | activation_fn: ActivationType = ActivationType.RELU
26 |
27 |
28 | class MLP(nn.Module):
29 | """A feedforward multi-layer perceptron."""
30 |
31 | def __init__(self, config: MLPConfig) -> None:
32 | super().__init__()
33 |
34 | dropout_layer: Callable
35 | if config.dropout_prob is not None:
36 | dropout_layer = partial(nn.Dropout, p=config.dropout_prob)
37 | else:
38 | dropout_layer = nn.Identity
39 |
40 | layers: Sequence[nn.Module]
41 | if config.hidden_depth == 0:
42 | layers = [nn.Linear(config.input_dim, config.output_dim)]
43 | else:
44 | layers = [
45 | nn.Linear(config.input_dim, config.hidden_dim),
46 | config.activation_fn.value(),
47 | dropout_layer(),
48 | ]
49 | for _ in range(config.hidden_depth - 1):
50 | layers += [
51 | nn.Linear(config.hidden_dim, config.hidden_dim),
52 | config.activation_fn.value(),
53 | dropout_layer(),
54 | ]
55 | layers += [nn.Linear(config.hidden_dim, config.output_dim)]
56 | layers = [layer for layer in layers if not isinstance(layer, nn.Identity)]
57 |
58 | self.net = nn.Sequential(*layers)
59 |
60 | def forward(self, x: torch.Tensor) -> torch.Tensor:
61 | return self.net(x)
62 |
63 |
64 | class ResidualBlock(nn.Module):
65 | def __init__(
66 | self,
67 | depth: int,
68 | activation_fn: ActivationType = ActivationType.RELU,
69 | ) -> None:
70 | super().__init__()
71 |
72 | self.conv1 = nn.Conv2d(depth, depth, 3, padding=1, bias=True)
73 | self.conv2 = nn.Conv2d(depth, depth, 3, padding=1, bias=True)
74 | self.activation = activation_fn.value()
75 |
76 | def forward(self, x: torch.Tensor) -> torch.Tensor:
77 | out = self.activation(x)
78 | out = self.conv1(out)
79 | out = self.activation(out)
80 | out = self.conv2(out)
81 | return out + x
82 |
83 |
84 | @dataclasses.dataclass(frozen=True)
85 | class CNNConfig:
86 | in_channels: int
87 | blocks: Sequence[int] = dataclasses.field(default=(16, 32, 32))
88 | activation_fn: ActivationType = ActivationType.RELU
89 |
90 |
91 | class CNN(nn.Module):
92 | """A residual convolutional network."""
93 |
94 | def __init__(self, config: CNNConfig) -> None:
95 | super().__init__()
96 |
97 | depth_in = config.in_channels
98 |
99 | layers = []
100 | for depth_out in config.blocks:
101 | layers.extend(
102 | [
103 | nn.Conv2d(depth_in, depth_out, 3, padding=1),
104 | ResidualBlock(depth_out, config.activation_fn),
105 | ]
106 | )
107 | depth_in = depth_out
108 |
109 | self.net = nn.Sequential(*layers)
110 | self.activation = config.activation_fn.value()
111 |
112 | def forward(self, x: torch.Tensor, activate: bool = False) -> torch.Tensor:
113 | out = self.net(x)
114 | if activate:
115 | return self.activation(out)
116 | return out
117 |
118 |
119 | class SpatialReduction(enum.Enum):
120 | SPATIAL_SOFTMAX = SpatialSoftArgmax
121 | AVERAGE_POOL = GlobalAvgPool2d
122 | MAX_POOL = GlobalMaxPool2d
123 |
124 |
125 | @dataclasses.dataclass(frozen=True)
126 | class ConvMLPConfig:
127 | cnn_config: CNNConfig
128 | mlp_config: MLPConfig
129 | spatial_reduction: SpatialReduction = SpatialReduction.AVERAGE_POOL
130 | coord_conv: bool = False
131 |
132 |
133 | class ConvMLP(nn.Module):
134 | def __init__(self, config: ConvMLPConfig) -> None:
135 | super().__init__()
136 |
137 | self.coord_conv = config.coord_conv
138 |
139 | self.cnn = CNN(config.cnn_config)
140 | self.conv = nn.Conv2d(config.cnn_config.blocks[-1], 16, 1)
141 | self.reducer = config.spatial_reduction.value()
142 | self.mlp = MLP(config.mlp_config)
143 |
144 | def forward(self, x: torch.Tensor) -> torch.Tensor:
145 | if self.coord_conv:
146 | x = CoordConv()(x)
147 | out = self.cnn(x, activate=True)
148 | out = F.relu(self.conv(out))
149 | out = self.reducer(out)
150 | out = self.mlp(out)
151 | return out
152 |
153 |
154 | class EBMConvMLP(nn.Module):
155 | def __init__(self, config: ConvMLPConfig) -> None:
156 | super().__init__()
157 |
158 | self.coord_conv = config.coord_conv
159 |
160 | self.cnn = CNN(config.cnn_config)
161 | self.conv = nn.Conv2d(config.cnn_config.blocks[-1], 16, 1)
162 | self.reducer = config.spatial_reduction.value()
163 | self.mlp = MLP(config.mlp_config)
164 |
165 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
166 | if self.coord_conv:
167 | x = CoordConv()(x)
168 | out = self.cnn(x, activate=True)
169 | out = F.relu(self.conv(out))
170 | out = self.reducer(out)
171 | fused = torch.cat([out.unsqueeze(1).expand(-1, y.size(1), -1), y], dim=-1)
172 | B, N, D = fused.size()
173 | fused = fused.reshape(B * N, D)
174 | out = self.mlp(fused)
175 | return out.view(B, N)
176 |
177 |
178 | if __name__ == "__main__":
179 | config = ConvMLPConfig(
180 | cnn_config=CNNConfig(5),
181 | mlp_config=MLPConfig(32, 128, 2, 2),
182 | spatial_reduction=SpatialReduction.AVERAGE_POOL,
183 | coord_conv=True,
184 | )
185 |
186 | net = ConvMLP(config)
187 | print(net)
188 |
189 | x = torch.randn(2, 3, 96, 96)
190 | with torch.no_grad():
191 | out = net(x)
192 | print(out.shape)
193 |
--------------------------------------------------------------------------------
/ibc/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | # Adapted from: https://github.com/Wizaron/coord-conv-pytorch
7 | class CoordConv(nn.Module):
8 | def __init__(self) -> None:
9 | super().__init__()
10 |
11 | def forward(self, x: torch.Tensor) -> torch.Tensor:
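# Build two extra channels of normalized (y, x) pixel coordinates in [-1, 1] and concatenate them with the input along the channel dimension.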
12 | batch_size, _, image_height, image_width = x.size()
13 | y_coords = (
14 | 2.0
15 | * torch.arange(image_height).unsqueeze(1).expand(image_height, image_width)
16 | / (image_height - 1.0)
17 | - 1.0
18 | )
19 | x_coords = (
20 | 2.0
21 | * torch.arange(image_width).unsqueeze(0).expand(image_height, image_width)
22 | / (image_width - 1.0)
23 | - 1.0
24 | )
25 | coords = torch.stack((y_coords, x_coords), dim=0)
26 | coords = torch.unsqueeze(coords, dim=0).repeat(batch_size, 1, 1, 1)
27 | x = torch.cat((coords.to(x.device), x), dim=1)
28 | return x
29 |
30 |
31 | class SpatialSoftArgmax(nn.Module):
32 | """Spatial softmax as defined in https://arxiv.org/abs/1504.00702.
33 |
34 | Concretely, the spatial softmax of each feature map is used to compute a weighted
35 | mean of the pixel locations, effectively performing a soft arg-max over the feature
36 | dimension.
37 | """
38 |
39 | def __init__(self, normalize: bool = True) -> None:
40 | super().__init__()
41 |
42 | self.normalize = normalize
43 |
44 | def _coord_grid(
45 | self,
46 | h: int,
47 | w: int,
48 | device: torch.device,
49 | ) -> torch.Tensor:
50 | if self.normalize:
51 | return torch.stack(
52 | torch.meshgrid(
53 | torch.linspace(-1, 1, w, device=device),
54 | torch.linspace(-1, 1, h, device=device),
55 | indexing="ij",
56 | )
57 | )
58 | return torch.stack(
59 | torch.meshgrid(
60 | torch.arange(0, w, device=device),
61 | torch.arange(0, h, device=device),
62 | indexing="ij",
63 | )
64 | )
65 |
66 | def forward(self, x: torch.Tensor) -> torch.Tensor:
67 | assert x.ndim == 4, "Expecting a tensor of shape (B, C, H, W)."
68 |
69 | # Compute a spatial softmax over the input:
70 | # Given an input of shape (B, C, H, W), reshape it to (B*C, H*W) then apply the
71 | # softmax operator over the last dimension.
72 | _, c, h, w = x.shape
73 | softmax = F.softmax(x.view(-1, h * w), dim=-1)
74 |
75 | # Create a meshgrid of normalized pixel coordinates.
76 | xc, yc = self._coord_grid(h, w, x.device)
77 |
78 | # Element-wise multiply the x and y coordinates with the softmax, then sum over
79 | # the h*w dimension. This effectively computes the weighted mean x and y
80 | # locations.
81 | x_mean = (softmax * xc.flatten()).sum(dim=1, keepdim=True)
82 | y_mean = (softmax * yc.flatten()).sum(dim=1, keepdim=True)
83 |
84 | # Concatenate and reshape the result to (B, C*2) where for every feature we have
85 | # the expected x and y pixel locations.
86 | return torch.cat([x_mean, y_mean], dim=1).view(-1, c * 2)
87 |
88 |
89 | class GlobalMaxPool2d(nn.Module):
90 | """Global spatial max pooling layer."""
91 |
92 | def __init__(self) -> None:
93 | super().__init__()
94 |
95 | self._pool = F.max_pool2d
96 |
97 | def forward(self, x: torch.Tensor) -> torch.Tensor:
98 | out = self._pool(x, kernel_size=x.size()[2:])
99 | for _ in range(len(out.shape[2:])):
100 | out.squeeze_(dim=-1)
101 | return out
102 |
103 |
104 | class GlobalAvgPool2d(nn.Module):
105 | """Global spatial average pooling layer."""
106 |
107 | def __init__(self) -> None:
108 | super().__init__()
109 |
110 | self._pool = F.avg_pool2d
111 |
112 | def forward(self, x: torch.Tensor) -> torch.Tensor:
113 | out = self._pool(x, kernel_size=x.size()[2:])
114 | for _ in range(len(out.shape[2:])):
115 | out.squeeze_(dim=-1)
116 | return out
117 |
--------------------------------------------------------------------------------
/ibc/optimizers.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import dataclasses
4 | import enum
5 | from typing import Protocol
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | # =================================================================== #
13 | # Model optimization.
14 | # =================================================================== #
15 |
16 |
17 | @dataclasses.dataclass
18 | class OptimizerConfig:
19 | learning_rate: float = 1e-3
20 | weight_decay: float = 0.0
21 | beta1: float = 0.9
22 | beta2: float = 0.999
23 | lr_scheduler_step: int = 100
24 | lr_scheduler_gamma: float = 0.99
25 |
26 |
27 | # =================================================================== #
28 | # Stochastic optimization for EBM training and inference.
29 | # =================================================================== #
30 |
31 |
32 | @dataclasses.dataclass
33 | class StochasticOptimizerConfig:
34 | bounds: np.ndarray
35 | """Bounds on the samples, min/max for each dimension."""
36 |
37 | iters: int
38 | """The total number of inference iters."""
39 |
40 | train_samples: int
41 | """The number of counter-examples to sample per iter during training."""
42 |
43 | inference_samples: int
44 | """The number of candidates to sample per iter during inference."""
45 |
46 |
47 | class StochasticOptimizer(Protocol):
48 | """Functionality that needs to be implemented by all stochastic optimizers."""
49 |
50 | device: torch.device
51 |
52 | def sample(self, batch_size: int, ebm: nn.Module) -> torch.Tensor:
53 | """Sample counter-negatives for feeding to the InfoNCE objective."""
54 |
55 | def infer(self, x: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
56 | """Optimize for the best action conditioned on the current observation."""
57 |
58 |
59 | @dataclasses.dataclass
60 | class DerivativeFreeConfig(StochasticOptimizerConfig):
61 | noise_scale: float = 0.33
62 | noise_shrink: float = 0.5
63 | iters: int = 3
64 | train_samples: int = 256
65 | inference_samples: int = 2 ** 14
66 |
67 |
68 | @dataclasses.dataclass
69 | class DerivativeFreeOptimizer:
70 | """A simple derivative-free optimizer. Great for up to 5 dimensions."""
71 |
72 | device: torch.device
73 | noise_scale: float
74 | noise_shrink: float
75 | iters: int
76 | train_samples: int
77 | inference_samples: int
78 | bounds: np.ndarray
79 |
80 | @staticmethod
81 | def initialize(
82 | config: DerivativeFreeConfig, device_type: str
83 | ) -> DerivativeFreeOptimizer:
84 | return DerivativeFreeOptimizer(
85 | device=torch.device(device_type if torch.cuda.is_available() else "cpu"),
86 | noise_scale=config.noise_scale,
87 | noise_shrink=config.noise_shrink,
88 | iters=config.iters,
89 | train_samples=config.train_samples,
90 | inference_samples=config.inference_samples,
91 | bounds=config.bounds,
92 | )
93 |
94 | def _sample(self, num_samples: int) -> torch.Tensor:
95 | """Helper method for drawing samples from the uniform random distribution."""
96 | size = (num_samples, self.bounds.shape[1])
97 | samples = np.random.uniform(self.bounds[0, :], self.bounds[1, :], size=size)
98 | return torch.as_tensor(samples, dtype=torch.float32, device=self.device)
99 |
100 | def sample(self, batch_size: int, ebm: nn.Module) -> torch.Tensor:
101 | del ebm # The derivative-free optimizer does not use the ebm for sampling.
102 | samples = self._sample(batch_size * self.train_samples)
103 | return samples.reshape(batch_size, self.train_samples, -1)
104 |
105 | @torch.no_grad()
106 | def infer(self, x: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
107 | """Optimize for the best action given a trained EBM."""
108 | noise_scale = self.noise_scale
109 | bounds = torch.as_tensor(self.bounds).to(self.device)
110 |
111 | samples = self._sample(x.size(0) * self.inference_samples)
112 | samples = samples.reshape(x.size(0), self.inference_samples, -1)
113 |
114 | for i in range(self.iters):
115 | # Compute energies.
116 | energies = ebm(x, samples)
117 | probs = F.softmax(-1.0 * energies, dim=-1)
118 |
119 | # Resample with replacement.
120 | idxs = torch.multinomial(probs, self.inference_samples, replacement=True)
121 | samples = samples[torch.arange(samples.size(0)).unsqueeze(-1), idxs]
122 |
123 | # Add noise and clip to target bounds.
124 | samples = samples + torch.randn_like(samples) * noise_scale
125 | samples = samples.clamp(min=bounds[0, :], max=bounds[1, :])
126 |
127 | noise_scale *= self.noise_shrink
128 |
129 | # Return target with highest probability.
130 | energies = ebm(x, samples)
131 | probs = F.softmax(-1.0 * energies, dim=-1)
132 | best_idxs = probs.argmax(dim=-1)
133 | return samples[torch.arange(samples.size(0)), best_idxs, :]
134 |
135 |
136 | class StochasticOptimizerType(enum.Enum):
137 | # Note(kevin): The paper describes three types of samplers. Right now, we just have
138 | # the derivative free sampler implemented.
139 | DERIVATIVE_FREE = enum.auto()
140 |
141 |
142 | if __name__ == "__main__":
143 | from dataset import CoordinateRegression, DatasetConfig
144 |
145 | dataset = CoordinateRegression(DatasetConfig(dataset_size=10))
146 | bounds = dataset.get_target_bounds()
147 |
148 | config = DerivativeFreeConfig(bounds=bounds, train_samples=256)
149 | so = DerivativeFreeOptimizer.initialize(config, "cuda")
150 |
151 | negatives = so.sample(64, nn.Identity())
152 | assert negatives.shape == (64, config.train_samples, bounds.shape[1])
153 |
--------------------------------------------------------------------------------
/ibc/trainer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import dataclasses
4 | import enum
5 | from typing import Protocol
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from tqdm.auto import tqdm
11 |
12 | from . import experiment, models, optimizers
13 |
14 |
15 | class TrainStateProtocol(Protocol):
16 | """Functionality that needs to be implemented by all training states."""
17 |
18 | model: nn.Module
19 | device: torch.device
20 | steps: int
21 |
22 | def training_step(
23 | self, input: torch.Tensor, target: torch.Tensor
24 | ) -> experiment.TensorboardLogData:
25 | """Performs a single training step on a mini-batch of data."""
26 |
27 | def evaluate(
28 | self, dataloader: torch.utils.data.DataLoader
29 | ) -> experiment.TensorboardLogData:
30 | """Performs a full evaluation of the model on one epoch."""
31 |
32 | def predict(self, input: torch.Tensor) -> torch.Tensor:
33 | """Performs a single inference step on a mini-batch of data."""
34 |
35 |
36 | @dataclasses.dataclass
37 | class ExplicitTrainState:
38 | """An explicit feedforward policy trained with a MSE objective."""
39 |
40 | model: nn.Module
41 | optimizer: torch.optim.Optimizer
42 | scheduler: torch.optim.lr_scheduler._LRScheduler
43 | device: torch.device
44 | steps: int
45 |
46 | @staticmethod
47 | def initialize(
48 | model_config: models.ConvMLPConfig,
49 | optim_config: optimizers.OptimizerConfig,
50 | device_type: str,
51 | ) -> ExplicitTrainState:
52 | device = torch.device(device_type if torch.cuda.is_available() else "cpu")
53 | print(f"Using device: {device}")
54 |
55 | model = models.ConvMLP(config=model_config)
56 | model.to(device)
57 |
58 | optimizer = torch.optim.Adam(
59 | model.parameters(),
60 | lr=optim_config.learning_rate,
61 | weight_decay=optim_config.weight_decay,
62 | betas=(optim_config.beta1, optim_config.beta2),
63 | )
64 |
65 | scheduler = torch.optim.lr_scheduler.StepLR(
66 | optimizer,
67 | step_size=optim_config.lr_scheduler_step,
68 | gamma=optim_config.lr_scheduler_gamma,
69 | )
70 |
71 | return ExplicitTrainState(
72 | model=model,
73 | optimizer=optimizer,
74 | scheduler=scheduler,
75 | device=device,
76 | steps=0,
77 | )
78 |
79 | def training_step(
80 | self, input: torch.Tensor, target: torch.Tensor
81 | ) -> experiment.TensorboardLogData:
82 | self.model.train()
83 |
84 | input = input.to(self.device)
85 | target = target.to(self.device)
86 |
87 | out = self.model(input)
88 | loss = F.mse_loss(out, target)
89 |
90 | self.optimizer.zero_grad(set_to_none=True)
91 | loss.backward()
92 | self.optimizer.step()
93 | self.scheduler.step()
94 |
95 | self.steps += 1
96 |
97 | return experiment.TensorboardLogData(
98 | scalars={
99 | "train/loss": loss.item(),
100 | "train/learning_rate": self.scheduler.get_last_lr()[0],
101 | }
102 | )
103 |
104 | @torch.no_grad()
105 | def evaluate(
106 | self, dataloader: torch.utils.data.DataLoader
107 | ) -> experiment.TensorboardLogData:
108 | self.model.eval()
109 |
110 | total_mse = 0.0
111 | for input, target in tqdm(dataloader, leave=False):
112 | input = input.to(self.device)
113 | target = target.to(self.device)
114 |
115 | out = self.model(input)
116 | mse = F.mse_loss(out, target, reduction="none")
117 | total_mse += mse.mean(dim=-1).sum().item()
118 |
119 | mean_mse = total_mse / len(dataloader.dataset)
120 | return experiment.TensorboardLogData(scalars={"test/mse": mean_mse})
121 |
122 | @torch.no_grad()
123 | def predict(self, input: torch.Tensor) -> torch.Tensor:
124 | self.model.eval()
125 | return self.model(input.to(self.device))
126 |
127 |
128 | @dataclasses.dataclass
129 | class ImplicitTrainState:
130 | """An implicit conditional EBM trained with an InfoNCE objective."""
131 |
132 | model: nn.Module
133 | optimizer: torch.optim.Optimizer
134 | scheduler: torch.optim.lr_scheduler._LRScheduler
135 | stochastic_optimizer: optimizers.StochasticOptimizer
136 | device: torch.device
137 | steps: int
138 |
139 | @staticmethod
140 | def initialize(
141 | model_config: models.ConvMLPConfig,
142 | optim_config: optimizers.OptimizerConfig,
143 | stochastic_optim_config: optimizers.DerivativeFreeConfig,
144 | device_type: str,
145 | ) -> ImplicitTrainState:
146 | device = torch.device(device_type if torch.cuda.is_available() else "cpu")
147 | print(f"Using device: {device}")
148 |
149 | model = models.EBMConvMLP(config=model_config)
150 | model.to(device)
151 |
152 | optimizer = torch.optim.Adam(
153 | model.parameters(),
154 | lr=optim_config.learning_rate,
155 | weight_decay=optim_config.weight_decay,
156 | betas=(optim_config.beta1, optim_config.beta2),
157 | )
158 |
159 | scheduler = torch.optim.lr_scheduler.StepLR(
160 | optimizer,
161 | step_size=optim_config.lr_scheduler_step,
162 | gamma=optim_config.lr_scheduler_gamma,
163 | )
164 |
165 | stochastic_optimizer = optimizers.DerivativeFreeOptimizer.initialize(
166 | stochastic_optim_config,
167 | device_type,
168 | )
169 |
170 | return ImplicitTrainState(
171 | model=model,
172 | optimizer=optimizer,
173 | scheduler=scheduler,
174 | stochastic_optimizer=stochastic_optimizer,
175 | device=device,
176 | steps=0,
177 | )
178 |
179 | def training_step(
180 | self, input: torch.Tensor, target: torch.Tensor
181 | ) -> experiment.TensorboardLogData:
182 | self.model.train()
183 |
184 | input = input.to(self.device)
185 | target = target.to(self.device)
186 |
187 | # Generate N negatives, one for each element in the batch: (B, N, D).
188 | negatives = self.stochastic_optimizer.sample(input.size(0), self.model)
189 |
190 | # Merge target and negatives: (B, N+1, D).
191 | targets = torch.cat([target.unsqueeze(dim=1), negatives], dim=1)
192 |
193 | # Generate a random permutation of the positives and negatives.
194 | permutation = torch.rand(targets.size(0), targets.size(1)).argsort(dim=1)
195 | targets = targets[torch.arange(targets.size(0)).unsqueeze(-1), permutation]
196 |
197 | # Get the original index of the positive. This will serve as the class label
198 | # for the loss.
199 | ground_truth = (permutation == 0).nonzero()[:, 1].to(self.device)
200 |
201 | # For every element in the mini-batch, there is 1 positive for which the EBM
202 | # should output a low energy value, and N negatives for which the EBM should
203 | # output high energy values.
204 | energy = self.model(input, targets)
205 |
206 | # Interpreting the energy as a negative logit, we can apply a cross entropy loss
207 | # to train the EBM.
208 | logits = -1.0 * energy
209 | loss = F.cross_entropy(logits, ground_truth)
210 |
211 | self.optimizer.zero_grad(set_to_none=True)
212 | loss.backward()
213 | self.optimizer.step()
214 | self.scheduler.step()
215 |
216 | self.steps += 1
217 |
218 | return experiment.TensorboardLogData(
219 | scalars={
220 | "train/loss": loss.item(),
221 | "train/learning_rate": self.scheduler.get_last_lr()[0],
222 | }
223 | )
224 |
225 | @torch.no_grad()
226 | def evaluate(
227 | self, dataloader: torch.utils.data.DataLoader
228 | ) -> experiment.TensorboardLogData:
229 | self.model.eval()
230 |
231 | total_mse = 0.0
232 | for input, target in tqdm(dataloader, leave=False):
233 | input = input.to(self.device)
234 | target = target.to(self.device)
235 |
236 | out = self.stochastic_optimizer.infer(input, self.model)
237 |
238 | mse = F.mse_loss(out, target, reduction="none")
239 | total_mse += mse.mean(dim=-1).sum().item()
240 |
241 | mean_mse = total_mse / len(dataloader.dataset)
242 | return experiment.TensorboardLogData(scalars={"test/mse": mean_mse})
243 |
244 | @torch.no_grad()
245 | def predict(self, input: torch.Tensor) -> torch.Tensor:
246 | self.model.eval()
247 | return self.stochastic_optimizer.infer(input.to(self.device), self.model)
248 |
249 |
250 | class PolicyType(enum.Enum):
251 | EXPLICIT = ExplicitTrainState
252 | """An explicit policy is a feedforward structure trained with a MSE objective."""
253 |
254 | IMPLICIT = ImplicitTrainState
255 | """An implicit policy is a conditional EBM trained with an InfoNCE objective."""
256 |
--------------------------------------------------------------------------------
/ibc/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 |
7 |
8 | def seed_rngs(seed: int, pytorch: bool = True) -> None:
9 | os.environ["PYTHONHASHSEED"] = str(seed)
10 | random.seed(seed)
11 | np.random.seed(seed)
12 | if pytorch:
13 | torch.manual_seed(seed)
14 |
15 |
16 | def set_cudnn(deterministic: bool = False, benchmark: bool = True) -> None:
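# cudnn.deterministic forces deterministic convolution algorithms (reproducible but slower); cudnn.benchmark lets cuDNN autotune the fastest algorithm for fixed input shapes.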
17 | torch.backends.cudnn.deterministic = deterministic
18 | torch.backends.cudnn.benchmark = benchmark
19 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | ignore_missing_imports = True
3 |
--------------------------------------------------------------------------------
/plot.py:
--------------------------------------------------------------------------------
1 | """Generate Figure 4 plot. Pass in --help flag for options."""
2 |
3 | import dataclasses
4 | import pathlib
5 | from typing import Dict, Tuple
6 |
7 | import dcargs
8 | import matplotlib
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import torch
12 | from scipy.spatial import ConvexHull
13 | from tqdm.auto import tqdm
14 |
15 | from ibc.dataset import CoordinateRegression
16 | from ibc.experiment import Experiment
17 | from ibc.trainer import TrainStateProtocol
18 | from train import TrainConfig, make_dataloaders, make_train_state
19 |
20 |
21 | @dataclasses.dataclass
22 | class Args:
23 | experiment_name: str
24 | plot_dir: str = "assets"
25 | dpi: int = 200
26 | threshold: float = 140
27 |
28 |
29 | def eval(
30 | train_state: TrainStateProtocol,
31 | dataloaders: Dict[str, torch.utils.data.DataLoader],
32 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
33 | dataset_test = dataloaders["test"].dataset
34 | dataset_train = dataloaders["train"].dataset
35 | assert isinstance(dataset_test, CoordinateRegression)
36 | assert isinstance(dataset_train, CoordinateRegression)
37 |
38 | total_mse = 0.0
39 | num_small_err = 0
40 | pixel_error = []
41 | for batch in tqdm(dataloaders["test"]):
42 | input, target = batch
43 | prediction = train_state.predict(input).cpu().numpy()
44 | target = target.cpu().numpy()
45 |
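# Map predictions and targets from the model's [-1, 1] output range back to pixel coordinates.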
46 | pred_unscaled = np.array(prediction)
47 | pred_unscaled += 1
48 | pred_unscaled /= 2
49 | pred_unscaled[:, 0] *= dataset_test.resolution[0] - 1
50 | pred_unscaled[:, 1] *= dataset_test.resolution[1] - 1
51 |
52 | target_unscaled = np.array(target)
53 | target_unscaled += 1
54 | target_unscaled /= 2
55 | target_unscaled[:, 0] *= dataset_test.resolution[0] - 1
56 | target_unscaled[:, 1] *= dataset_test.resolution[1] - 1
57 |
58 | diff = pred_unscaled - target_unscaled
59 | error = np.asarray(np.linalg.norm(diff, axis=1))
60 | num_small_err += len(error[error < 1.0])
61 | pixel_error.extend(error.tolist())
62 | total_mse += (diff ** 2).mean(axis=1).sum()
63 |
64 | total_test = len(dataset_test)
65 | average_mse = total_mse / total_test
66 | print(f"Test set MSE: {average_mse} ({num_small_err}/{total_test})")
67 |
68 | test_coords = dataset_test.coordinates
69 | train_coords = dataset_train.coordinates
70 | return train_coords, test_coords, np.asarray(pixel_error)
71 |
72 |
73 | def plot(
74 | train_coords: np.ndarray,
75 | test_coords: np.ndarray,
76 | errors: np.ndarray,
77 | resolution: Tuple[int, int],
78 | plot_path: pathlib.Path,
79 | dpi: int,
80 | threshold: float,
81 | ) -> None:
82 | # Threshold the errors so that all generated plot colors cover the same range.
83 | errors[errors >= threshold] = threshold
84 | colormap = plt.cm.Reds
85 | normalize = matplotlib.colors.Normalize(vmin=0, vmax=threshold)
86 |
87 | plt.scatter(
88 | train_coords[:, 0],
89 | train_coords[:, 1],
90 | marker="x",
91 | c="black",
92 | zorder=2,
93 | alpha=0.5,
94 | )
95 | plt.scatter(
96 | test_coords[:, 0],
97 | test_coords[:, 1],
98 | c=errors,
99 | cmap=colormap,
100 | norm=normalize,
101 | zorder=1,
102 | )
103 | plt.colorbar()
104 |
105 | # Find index of predictions with less than 1 pixel error and color them in blue.
106 | idxs = errors < 1.0
107 | plt.scatter(
108 | test_coords[idxs, 0],
109 | test_coords[idxs, 1],
110 | marker="o",
111 | c="blue",
112 | zorder=1,
113 | alpha=1.0,
114 | )
115 |
116 | # Add convex hull of train set.
117 | if train_coords.shape[0] > 2:
118 | for simplex in ConvexHull(train_coords).simplices:
119 | plt.plot(
120 | train_coords[simplex, 0],
121 | train_coords[simplex, 1],
122 | "--",
123 | zorder=2,
124 | alpha=0.5,
125 | c="black",
126 | )
127 |
128 | plt.xlim(0 - 2, resolution[1] + 2)
129 | plt.ylim(0 - 2, resolution[0] + 2)
130 |
131 | plt.savefig(plot_path, format="png", dpi=dpi)
132 | plt.close()
133 |
134 |
135 | def main(args: Args):
136 | plot_dir = pathlib.Path(args.plot_dir)
137 | plot_dir.mkdir(parents=True, exist_ok=True)
138 |
139 | experiment = Experiment(
140 | identifier=args.experiment_name,
141 | ).assert_exists()
142 |
143 | # Read saved config file.
144 | train_config = experiment.read_metadata("config", TrainConfig)
145 |
146 | # Restore training state.
147 | dataloaders = make_dataloaders(train_config)
148 | train_state = make_train_state(train_config, dataloaders["train"])
149 | experiment.restore_checkpoint(train_state)
150 | print(f"Loaded checkpoint at step: {train_state.steps}.")
151 |
152 | # Compute MSE for every test set data point.
153 | train_coords, test_coords, errors = eval(train_state, dataloaders)
154 |
155 | # Plot and dump to disk.
156 | plot(
157 | train_coords,
158 | test_coords,
159 | errors,
160 | dataloaders["test"].dataset.resolution,
161 | plot_dir / f"{args.experiment_name}.png",
162 | args.dpi,
163 | args.threshold,
164 | )
165 |
166 |
167 | if __name__ == "__main__":
168 | main(dcargs.parse(Args))
169 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | dcargs
2 | torch
3 | torchvision
4 | matplotlib
5 | tqdm
6 | tensorboard
7 | scipy
8 | numpy
9 | pyyaml
10 |
--------------------------------------------------------------------------------
/run_explicit.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -e
4 | set -x
5 |
6 | for train_size in 10 30
7 | do
8 | EXPERIMENT_NAME=explicit_mse_${train_size}
9 |
10 | python train.py \
11 | --experiment-name $EXPERIMENT_NAME \
12 | --train-dataset-size $train_size \
13 | --policy-type EXPLICIT \
14 | --dropout-prob 0.1 \
15 | --weight-decay 1e-4 \
16 | --max-epochs 2000 \
17 | --learning-rate 1e-3 \
18 | --spatial-reduction SPATIAL_SOFTMAX \
19 | --coord-conv
20 |
21 | python plot.py --experiment-name $EXPERIMENT_NAME
22 | done
23 |
--------------------------------------------------------------------------------
/run_implicit.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -e
4 | set -x
5 |
6 | for train_size in 10 30
7 | do
8 | EXPERIMENT_NAME=implicit_ebm_${train_size}
9 |
10 | python train.py \
11 | --experiment-name $EXPERIMENT_NAME \
12 | --train-dataset-size $train_size \
13 | --policy-type IMPLICIT \
14 | --dropout-prob 0.0 \
15 | --weight-decay 0.0 \
16 | --max-epochs 2000 \
17 | --learning-rate 1e-3 \
18 | --spatial-reduction SPATIAL_SOFTMAX \
19 | --coord-conv \
20 | --stochastic-optimizer-train-samples 128
21 |
22 | python plot.py --experiment-name $EXPERIMENT_NAME
23 | done
24 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """Script for training. Pass in --help flag for options."""
2 |
3 | import dataclasses
4 | from typing import Dict, Optional
5 |
6 | import dcargs
7 | import torch
8 | from tqdm.auto import tqdm
9 |
10 | from ibc import dataset, models, optimizers, trainer, utils
11 | from ibc.experiment import Experiment
12 |
13 |
14 | @dataclasses.dataclass
15 | class TrainConfig:
16 | experiment_name: str
17 | seed: int = 0
18 | device_type: str = "cuda"
19 | train_dataset_size: int = 10
20 | test_dataset_size: int = 500
21 | max_epochs: int = 200
22 | learning_rate: float = 1e-3
23 | weight_decay: float = 0.0
24 | train_batch_size: int = 8
25 | test_batch_size: int = 64
26 | spatial_reduction: models.SpatialReduction = models.SpatialReduction.SPATIAL_SOFTMAX
27 | coord_conv: bool = False
28 | dropout_prob: Optional[float] = None
29 | num_workers: int = 1
30 | cudnn_deterministic: bool = True
31 | cudnn_benchmark: bool = False
32 | log_every_n_steps: int = 10
33 | checkpoint_every_n_steps: int = 100
34 | eval_every_n_steps: int = 1000
35 | policy_type: trainer.PolicyType = trainer.PolicyType.EXPLICIT
36 | stochastic_optimizer_train_samples: int = 64
37 |
38 |
39 | def make_dataloaders(
40 | train_config: TrainConfig,
41 | ) -> Dict[str, torch.utils.data.DataLoader]:
42 | """Initialize train/test dataloaders based on config values."""
43 | # Train split.
44 | train_dataset_config = dataset.DatasetConfig(
45 | dataset_size=train_config.train_dataset_size,
46 | seed=train_config.seed,
47 | )
48 | train_dataset = dataset.CoordinateRegression(train_dataset_config)
49 | train_dataloader = torch.utils.data.DataLoader(
50 | train_dataset,
51 | batch_size=train_config.train_batch_size,
52 | shuffle=True,
53 | num_workers=train_config.num_workers,
54 | pin_memory=torch.cuda.is_available(),
55 | )
56 |
57 | # Test split.
58 | test_dataset_config = dataset.DatasetConfig(
59 | dataset_size=train_config.test_dataset_size,
60 | seed=train_config.seed,
61 | )
62 | test_dataset = dataset.CoordinateRegression(test_dataset_config)
63 | test_dataset.exclude(train_dataset.coordinates)
64 | test_dataloader = torch.utils.data.DataLoader(
65 | test_dataset,
66 | batch_size=train_config.test_batch_size,
67 | shuffle=False,
68 | num_workers=train_config.num_workers,
69 | pin_memory=torch.cuda.is_available(),
70 | )
71 |
72 | return {
73 | "train": train_dataloader,
74 | "test": test_dataloader,
75 | }
76 |
77 |
78 | def make_train_state(
79 | train_config: TrainConfig,
80 | train_dataloader: torch.utils.data.DataLoader,
81 | ) -> trainer.TrainStateProtocol:
82 | """Initialize train state based on config values."""
83 | in_channels = 3
84 | if train_config.coord_conv:
85 | in_channels += 2
86 | residual_blocks = [16, 32, 32]
87 | cnn_config = models.CNNConfig(in_channels, residual_blocks)
88 |
89 | input_dim = 16 # We have a 1x1 conv that reduces to 16 channels.
90 | output_dim = 2
91 | if train_config.spatial_reduction == models.SpatialReduction.SPATIAL_SOFTMAX:
92 | input_dim *= 2
93 | if train_config.policy_type == trainer.PolicyType.IMPLICIT:
94 | input_dim += 2 # Dimension of the targets.
95 | output_dim = 1
96 | mlp_config = models.MLPConfig(
97 | input_dim=input_dim,
98 | hidden_dim=256,
99 | output_dim=output_dim,
100 | hidden_depth=1,
101 | dropout_prob=train_config.dropout_prob,
102 | )
103 |
104 | model_config = models.ConvMLPConfig(
105 | cnn_config=cnn_config,
106 | mlp_config=mlp_config,
107 | spatial_reduction=train_config.spatial_reduction,
108 | coord_conv=train_config.coord_conv,
109 | )
110 |
111 | optim_config = optimizers.OptimizerConfig(
112 | learning_rate=train_config.learning_rate,
113 | weight_decay=train_config.weight_decay,
114 | )
115 |
116 | train_state: trainer.TrainStateProtocol
117 | if train_config.policy_type == trainer.PolicyType.EXPLICIT:
118 | train_state = trainer.ExplicitTrainState.initialize(
119 | model_config=model_config,
120 | optim_config=optim_config,
121 | device_type=train_config.device_type,
122 | )
123 | else:
124 | target_bounds = train_dataloader.dataset.get_target_bounds()
125 | stochastic_optim_config = optimizers.DerivativeFreeConfig(
126 | bounds=target_bounds,
127 | train_samples=train_config.stochastic_optimizer_train_samples,
128 | )
129 |
130 | train_state = trainer.ImplicitTrainState.initialize(
131 | model_config=model_config,
132 | optim_config=optim_config,
133 | stochastic_optim_config=stochastic_optim_config,
134 | device_type=train_config.device_type,
135 | )
136 |
137 | return train_state
138 |
139 |
140 | def main(train_config: TrainConfig) -> None:
141 | # Seed RNGs.
142 | utils.seed_rngs(train_config.seed)
143 |
144 | # CUDA/CUDNN-related shenanigans.
145 | utils.set_cudnn(train_config.cudnn_deterministic, train_config.cudnn_benchmark)
146 |
147 | experiment = Experiment(
148 | identifier=train_config.experiment_name,
149 | ).assert_new()
150 |
151 | # Write some metadata.
152 | experiment.write_metadata("config", train_config)
153 |
154 | # Initialize train and test dataloaders.
155 | dataloaders = make_dataloaders(train_config)
156 |
157 | train_state = make_train_state(train_config, dataloaders["train"])
158 |
159 | for epoch in tqdm(range(train_config.max_epochs)):
160 | if not train_state.steps % train_config.checkpoint_every_n_steps:
161 | experiment.save_checkpoint(train_state, step=train_state.steps)
162 |
163 | if not train_state.steps % train_config.eval_every_n_steps:
164 | test_log_data = train_state.evaluate(dataloaders["test"])
165 | experiment.log(test_log_data, step=train_state.steps)
166 |
167 | for batch in dataloaders["train"]:
168 | train_log_data = train_state.training_step(*batch)
169 |
170 | # Log to tensorboard.
171 | if not train_state.steps % train_config.log_every_n_steps:
172 | experiment.log(train_log_data, step=train_state.steps)
173 |
174 | # Save one final checkpoint.
175 | experiment.save_checkpoint(train_state, step=train_state.steps)
176 |
177 |
178 | if __name__ == "__main__":
179 | main(dcargs.parse(TrainConfig, description=__doc__))
180 |
--------------------------------------------------------------------------------