├── .gitignore ├── LICENSE ├── README.md ├── data ├── data │ └── video │ │ └── README.md └── datasets │ ├── CATER │ └── dataset.py │ ├── CLEVRER │ └── dataset.py │ ├── moving_mnist │ └── dataset.py │ └── video │ └── dataset.py ├── environment.yml ├── model ├── cater-stage1.json ├── cater-stage2.json ├── cater-stage3.json ├── cater-stage4.json ├── cater.json ├── loci.py ├── main.py └── scripts │ ├── evaluation.py │ ├── playground.py │ └── training.py ├── nn ├── background.py ├── decoder.py ├── encoder.py ├── eprop_gate_l0rd.py ├── eprop_lstm.py ├── latent_classifier.py ├── predictor.py ├── residual.py ├── tracker.py └── vae.py └── utils ├── configuration.py ├── data.py ├── io.py ├── loss.py ├── optimizers.py ├── parallel.py ├── scheduled_sampling.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.sw? 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Cognitive Modeling 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Loci 2 | 3 | *Loci is an unsupervised, disentangled LOCation and Identity tracking system that excels on the CATER challenge and related object-tracking tasks, featuring emergent object permanence and stable entity disentanglement via fully unsupervised learning.* 4 | 5 | Paper: "Learning What and Where - Unsupervised Disentangling Location and Identity Tracking" | [arXiv](https://arxiv.org/abs/2205.13349) 6 | 7 | 8 | https://user-images.githubusercontent.com/28415607/191936991-f4f7dccf-75cc-439f-818d-9f25c0482ca4.mp4 9 | 10 | ## Requirements 11 | A suitable [conda](https://conda.io/) environment named `loci` can be created 12 | and activated with: 13 | 14 | ``` 15 | conda env create -f environment.yml 16 | conda activate loci 17 | ``` 18 | 19 | ## Dataset and trained models 20 | 21 | A preprocessed CATER dataset together with the 5 trained networks from the paper can be found [here](https://unitc-my.sharepoint.com/:f:/g/personal/iiimt01_cloud_uni-tuebingen_de/Et0PVeCi7OhMuaz60a5RtcMBgS4Sq-fLAZkjNJsDVFgyOw?e=WzMItA) 22 | 23 | The dataset folder (CATER) needs to be copied to ```data/data/video/``` 24 | 25 | ## Interactive GUI 26 | 27 | 28 | https://user-images.githubusercontent.com/28415607/191945713-fb43df65-0247-459e-9944-f8c3c7331c93.mp4 29 | 30 | 31 | We provide an interactive GUI to explore the learned representations of the model. The GUI loads the extracted latent state of a single slot. In the top-left grid the bits of the gestalt code can be flipped, while in the top-right image the position can be changed (by clicking or scrolling). The bottom half of the GUI shows the composition of the background with the reconstructed slot content, as well as the entity's RGB representation and mask. 
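For orientation, here is a minimal sketch of the composition the GUI's output panel performs: a per-pixel softmax over {slot, background} masks followed by blending the sigmoid-squashed slot RGB with the background, as in `Loci.run_decoder`. All tensors below are stand-ins rather than the repository's API; the 240x320 resolution follows `model/cater.json`.

```
import torch as th

# Illustrative stand-ins for the decoder's outputs and the background model.
b, h, w = 1, 240, 320
slot_mask_logits = th.randn(b, 1, h, w)   # mask logits for a single slot
slot_rgb_logits  = th.randn(b, 3, h, w)   # RGB logits for that slot
bg_mask_logits   = th.zeros(b, 1, h, w)   # background mask logits
background_rgb   = th.rand(b, 3, h, w)    # background image in [0, 1]

# Per-pixel softmax over {slot, background}, then blend the two RGB layers.
mask = th.softmax(th.cat((slot_mask_logits, bg_mask_logits), dim=1), dim=1)  # (b, 2, h, w)
rgb  = th.cat((th.sigmoid(slot_rgb_logits - 2.5), background_rgb), dim=1)    # (b, 2*3, h, w)
out  = th.sum(mask.unsqueeze(2) * rgb.view(b, 2, 3, h, w), dim=1)            # (b, 3, h, w)
```

Flipping gestalt bits or moving the position marker in the GUI only changes the decoder inputs that produce `slot_mask_logits` and `slot_rgb_logits` above; the composition itself stays fixed.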
32 | 33 | Run the GUI (extracted latent states can be found [here](https://unitc-my.sharepoint.com/:f:/g/personal/iiimt01_cloud_uni-tuebingen_de/Et0PVeCi7OhMuaz60a5RtcMBgS4Sq-fLAZkjNJsDVFgyOw?e=fLh7xN)): 34 | 35 | ``` 36 | python -m model.scripts.playground -cfg model/cater.json \ 37 | -background data/data/video/CATER/background.jpg -load net2.pt \ 38 | -latent latent-states/net2/latent-0000-07.pickle 39 | ``` 40 | 41 | ## Training 42 | 43 | Training can be started with: 44 | 45 | ``` 46 | python -m model.main -train -cfg model/cater-stage1.json 47 | ``` 48 | 49 | ## Evaluation 50 | 51 | A trained model can be evaluated with: 52 | 53 | ``` 54 | python -m model.main -eval -testset -cfg model/cater.json -load net1.pt 55 | ``` 56 | 57 | Images and latent states can be generated using: 58 | 59 | 60 | ``` 61 | python -m model.main -save -testset -cfg model/cater.json -load net1.pt 62 | ``` 63 | -------------------------------------------------------------------------------- /data/data/video/README.md: -------------------------------------------------------------------------------- 1 | place for datasets 2 | -------------------------------------------------------------------------------- /data/datasets/CATER/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | from typing import Tuple, Union, List 3 | import numpy as np 4 | import json 5 | import math 6 | import cv2 7 | import h5py 8 | import os 9 | import pickle 10 | 11 | __author__ = "Manuel Traub" 12 | 13 | class RamImage(): 14 | def __init__(self, path): 15 | 16 | fd = open(path, 'rb') 17 | img_str = fd.read() 18 | fd.close() 19 | 20 | self.img_raw = np.frombuffer(img_str, np.uint8) 21 | 22 | def to_numpy(self): 23 | return cv2.imdecode(self.img_raw, cv2.IMREAD_COLOR) 24 | 25 | class CaterSample(data.Dataset): 26 | def __init__(self, root_path: str, data_path: str, size: Tuple[int, int]): 27 | 28 | data_path = os.path.join(root_path, data_path, "train", f'{size[0]}x{size[1]}') 29 | 30 | frames = [] 31 | self.size = size 32 | 33 | for file in os.listdir(data_path): 34 | if file.startswith("frame") and (file.endswith(".jpg") or file.endswith(".png")): 35 | frames.append(os.path.join(data_path, file)) 36 | 37 | frames.sort() 38 | self.imgs = [] 39 | for path in frames: 40 | self.imgs.append(RamImage(path)) 41 | 42 | def get_data(self): 43 | 44 | frames = np.zeros((301,3,self.size[1], self.size[0]),dtype=np.float32) 45 | for i in range(len(self.imgs)): 46 | img = self.imgs[i].to_numpy() 47 | frames[i] = img.transpose(2, 0, 1).astype(np.float32) / 255.0 48 | 49 | return frames 50 | 51 | 52 | class CaterDataset(data.Dataset): 53 | 54 | def save(self): 55 | state = { 'samples': self.samples, 'labels': self.labels } 56 | with open(self.file, "wb") as outfile: 57 | pickle.dump(state, outfile) 58 | 59 | def load(self): 60 | with open(self.file, "rb") as infile: 61 | state = pickle.load(infile) 62 | self.samples = state['samples'] 63 | self.labels = state['labels'] 64 | 65 | def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int]): 66 | 67 | data_path = f'data/data/video/{dataset_name}' 68 | data_path = os.path.join(root_path, data_path) 69 | self.file = os.path.join(data_path, f'dataset-{size[0]}x{size[1]}-{type}.pickle') 70 | self.train = (type == "train") 71 | self.val = (type == "val") 72 | self.test = (type == "test") 73 | 74 | self.samples = [] 75 | self.labels = [] 76 | 77 | if os.path.exists(self.file): 78 | self.load() 79 | else: 80 | 
81 | samples = list(filter(lambda x: x.startswith("0"), next(os.walk(data_path))[1])) 82 | num_all_samples = len(samples) 83 | num_samples = 0 84 | sample_start = 0 85 | 86 | if type == "train": 87 | num_samples = int(num_all_samples * 0.7 * 0.8) 88 | if type == "val": 89 | num_samples = int(num_all_samples * 0.7 * 0.2) 90 | if type == "test": 91 | num_samples = int(num_all_samples * 0.3) 92 | 93 | if type == "val": 94 | sample_start = int(num_all_samples * 0.7 * 0.8) 95 | if type == "test": 96 | sample_start = int(num_all_samples * 0.7) 97 | 98 | for i, dir in enumerate(samples[sample_start:sample_start+num_samples]): 99 | self.samples.append(CaterSample(data_path, dir, size)) 100 | self.labels.append(json.load(open(os.path.join(data_path, "labels", f"{dir}.json")))) 101 | 102 | print(f"Loading CATER {type} [{i * 100 / num_samples:.2f}]", flush=True) 103 | 104 | self.save() 105 | 106 | self.length = len(self.samples) 107 | self.background = None 108 | if "background.jpg" in os.listdir(data_path): 109 | self.background = cv2.imread(os.path.join(data_path, "background.jpg")) 110 | self.background = cv2.resize(self.background, dsize=size, interpolation=cv2.INTER_CUBIC) 111 | self.background = self.background.transpose(2, 0, 1).astype(np.float32) / 255.0 112 | self.background = self.background.reshape(1, self.background.shape[0], self.background.shape[1], self.background.shape[2]) 113 | 114 | print(f"CaterDataset[{type}]: {self.length}") 115 | 116 | if len(self) == 0: 117 | raise FileNotFoundError(f'Found no dataset at {self.data_path}') 118 | 119 | 120 | self.cam = np.array([ 121 | (1.4503, 1.6376, 0.0000, -0.0251), 122 | (-1.0346, 0.9163, 2.5685, 0.0095), 123 | (-0.6606, 0.5850, -0.4748, 10.5666), 124 | (-0.6592, 0.5839, -0.4738, 10.7452) 125 | ]) 126 | 127 | self.z = 0.3421497941017151 128 | 129 | self.object_actions = { 130 | 'sphere_slide': 0, 131 | 'sphere_pick_place': 1, 132 | 'spl_slide': 2, 133 | 'spl_pick_place': 3, 134 | 'spl_rotate': 4, 135 | 'cylinder_slide': 5, 136 | 'cylinder_pick_place': 6, 137 | 'cylinder_rotate': 7, 138 | 'cube_slide': 8, 139 | 'cube_pick_place': 9, 140 | 'cube_rotate': 10, 141 | 'cone_slide': 11, 142 | 'cone_pick_place': 12, 143 | 'cone_contain': 13, 144 | 'sphere_no_op': 14, 145 | 'spl_no_op': 14, 146 | 'cylinder_no_op': 14, 147 | 'cube_no_op': 14, 148 | 'cone_no_op': 14, 149 | } 150 | 151 | self.object_materials = { 152 | 'sphere_rubber': 0, 153 | 'sphere_metal': 1, 154 | 'cylinder_rubber': 2, 155 | 'cylinder_metal': 3, 156 | 'cube_rubber': 4, 157 | 'cube_metal': 5, 158 | 'cone_rubber': 6, 159 | 'cone_metal': 7, 160 | } 161 | 162 | self.object_sizes = { 163 | 'sphere_small': 0, 164 | 'sphere_medium': 1, 165 | 'sphere_large': 2, 166 | 'cylinder_small': 3, 167 | 'cylinder_medium': 4, 168 | 'cylinder_large': 5, 169 | 'cube_small': 6, 170 | 'cube_medium': 7, 171 | 'cube_large': 8, 172 | 'cone_small': 9, 173 | 'cone_medium': 10, 174 | 'cone_large': 11, 175 | } 176 | 177 | self.object_colors = { 178 | 'sphere_red': 0, 179 | 'sphere_purple': 1, 180 | 'sphere_yellow': 2, 181 | 'sphere_brown': 3, 182 | 'sphere_gray': 4, 183 | 'sphere_blue': 5, 184 | 'sphere_cyan': 6, 185 | 'sphere_green': 7, 186 | 'cylinder_red': 8, 187 | 'cylinder_purple': 9, 188 | 'cylinder_yellow': 10, 189 | 'cylinder_brown': 11, 190 | 'cylinder_gray': 12, 191 | 'cylinder_blue': 13, 192 | 'cylinder_cyan': 14, 193 | 'cylinder_green': 15, 194 | 'cube_red': 16, 195 | 'cube_purple': 17, 196 | 'cube_yellow': 18, 197 | 'cube_brown': 19, 198 | 'cube_gray': 20, 199 | 'cube_blue': 21, 200 | 'cube_cyan': 
22, 201 | 'cube_green': 23, 202 | 'cone_red': 24, 203 | 'cone_purple': 25, 204 | 'cone_yellow': 26, 205 | 'cone_brown': 27, 206 | 'cone_gray': 28, 207 | 'cone_blue': 29, 208 | 'cone_cyan': 30, 209 | 'cone_green': 31, 210 | } 211 | 212 | def project_3d_point(self, pts): 213 | """ 214 | Args: pts: Nx3 matrix, with the 3D coordinates of the points to convert 215 | Returns: Nx2 matrix, with the coordinates of the point in 2D 216 | """ 217 | p = np.matmul( 218 | self.cam, 219 | np.hstack((pts, np.ones((pts.shape[0], 1)))).transpose()).transpose() 220 | # The predictions are -1 to 1, Negating the 2nd to put low Y axis on top 221 | p[:, 0] /= p[:, -1] 222 | p[:, 1] /= -p[:, -1] 223 | return np.concatenate((p[:,1:2],p[:,0:1]), axis=1) 224 | 225 | def snitch_position(self, metadata): 226 | 227 | objects = metadata['objects'] 228 | object = [el for el in objects if el['shape'] == 'spl'][0] 229 | pts = np.zeros((len(object['locations']), 3)) 230 | for i in range(len(object['locations'])): 231 | pts[i] = object['locations'][str(i)] 232 | 233 | return pts #self.project_3d_point(pts) 234 | 235 | def localize_label(self, metadata, num_rows=3, num_cols=3): 236 | 237 | objects = metadata['objects'] 238 | object = [el for el in objects if el['shape'] == 'spl'][0] 239 | pos = object['locations'][str(len(object['locations']) - 1)] 240 | if num_rows != 3 or num_cols != 3: 241 | # In this case, need to scale the pos values to scale to the new num_rows etc 242 | pos[0] *= num_cols * 1.0 / 3 243 | pos[1] *= num_rows * 1.0 / 3 244 | # Without math.floor it would screw up on negative axis 245 | x, y = (int(math.floor(pos[0])) + num_cols, 246 | int(math.floor(pos[1])) + num_rows) 247 | cls_id = y * (2 * num_cols) + x 248 | 249 | return cls_id 250 | #return np.eye(num_rows * num_cols * 4)[cls_id] 251 | 252 | 253 | def visibility_mask(self, metadata): 254 | 255 | movements = metadata['movements'] 256 | objects = metadata['objects'] 257 | visible = {el['instance']: np.ones((301)) for el in objects} 258 | 259 | for name, motions in movements.items(): 260 | if name.startswith('Cone_'): 261 | start = -1 262 | end = 301 263 | object = "" 264 | for motion in motions: 265 | if motion[0] == '_contain': 266 | start = motion[3] 267 | object = motion[1] 268 | if motion[0] == '_pick_place' and start > 0: 269 | end = motion[2] 270 | visible[object][start:end] = 0 271 | start = -1 272 | end = 301 273 | 274 | if start > 0: 275 | visible[object][start:end] = 0 276 | 277 | return visible 278 | 279 | def actions_over_time(self, metadata): 280 | 281 | movements = metadata['movements'] 282 | objects = metadata['objects'] 283 | to_type = {el['instance']: el['shape'] for el in objects} 284 | 285 | actions_visible = np.zeros((301, 15)) 286 | actions_hidden = np.zeros((301, 15)) 287 | 288 | visible = self.visibility_mask(metadata) 289 | 290 | for name, motions in movements.items(): 291 | for motion in motions: 292 | action_index = self.object_actions[to_type[name] + motion[0]] 293 | actions_visible[motion[2]:motion[3],action_index] += visible[name][motion[2]:motion[3]] 294 | actions_hidden[motion[2]:motion[3],action_index] += 1 - visible[name][motion[2]:motion[3]] 295 | 296 | # remove no op 297 | return actions_visible[:,:-1], actions_hidden[:,:-1] 298 | 299 | def snitch_contained(self, metadata): 300 | visible = self.visibility_mask(metadata) 301 | return 1 - visible['Spl_0'] 302 | 303 | def materials_over_time(self, metadata): 304 | 305 | movements = metadata['movements'] 306 | objects = metadata['objects'] 307 | objects = {el['instance']: 
el['shape'] + "_" + el['material'] for el in objects} 308 | 309 | materials_visible = np.zeros((301, 8)) 310 | materials_hidden = np.zeros((301, 8)) 311 | 312 | visible = self.visibility_mask(metadata) 313 | 314 | for instance, class_name in objects.items(): 315 | if instance != 'Spl_0': 316 | index = self.object_materials[class_name] 317 | materials_visible[:,index] += visible[instance] 318 | materials_hidden[:,index] += 1 - visible[instance] 319 | 320 | return materials_visible, materials_hidden 321 | 322 | def sizes_over_time(self, metadata): 323 | 324 | movements = metadata['movements'] 325 | objects = metadata['objects'] 326 | objects = {el['instance']: el['shape'] + "_" + el['size'] for el in objects} 327 | 328 | size_visible = np.zeros((301, 12)) 329 | size_hidden = np.zeros((301, 12)) 330 | 331 | visible = self.visibility_mask(metadata) 332 | 333 | for instance, class_name in objects.items(): 334 | if instance != 'Spl_0': 335 | index = self.object_sizes[class_name] 336 | size_visible[:,index] += visible[instance] 337 | size_hidden[:,index] += 1 - visible[instance] 338 | 339 | return size_visible, size_hidden 340 | 341 | def colors_over_time(self, metadata): 342 | 343 | movements = metadata['movements'] 344 | objects = metadata['objects'] 345 | objects = {el['instance']: el['shape'] + "_" + el['color'] for el in objects} 346 | 347 | color_visible = np.zeros((301, 32)) 348 | color_hidden = np.zeros((301, 32)) 349 | 350 | visible = self.visibility_mask(metadata) 351 | 352 | for instance, class_name in objects.items(): 353 | if instance != 'Spl_0': 354 | index = self.object_colors[class_name] 355 | color_visible[:,index] += visible[instance] 356 | color_hidden[:,index] += 1 - visible[instance] 357 | 358 | return color_visible, color_hidden 359 | 360 | def __len__(self): 361 | return self.length 362 | 363 | def __getitem__(self, index: int): 364 | 365 | label = self.labels[index] 366 | 367 | snitch_positions = self.snitch_position(label) 368 | snitch_label = self.localize_label(label) 369 | snitch_contained = self.snitch_contained(label) 370 | 371 | actions_visible, actions_hidden = self.actions_over_time(label) 372 | materials_visible, materials_hidden = self.materials_over_time(label) 373 | sizes_visible, sizes_hidden = self.sizes_over_time(label) 374 | colors_visible, colors_hidden = self.colors_over_time(label) 375 | 376 | if self.background is not None: 377 | return ( 378 | self.samples[index].get_data(), 379 | self.background, 380 | snitch_positions, 381 | snitch_label, 382 | snitch_contained, 383 | actions_visible, 384 | actions_hidden, 385 | materials_visible, 386 | materials_hidden, 387 | sizes_visible, 388 | sizes_hidden, 389 | colors_visible, 390 | colors_hidden 391 | ) 392 | 393 | return ( 394 | self.samples[index].get_data(), 395 | self.background, 396 | snitch_positions, 397 | snitch_label, 398 | snitch_contained, 399 | actions_visible, 400 | actions_hidden, 401 | materials_visible, 402 | materials_hidden, 403 | sizes_visible, 404 | sizes_hidden, 405 | colors_visible, 406 | colors_hidden 407 | ) 408 | 409 | 410 | class CaterLatentDataset(data.Dataset): 411 | def __init__(self, root_path: str, filename: str, type: str): 412 | self.type = type 413 | self.data_path = os.path.join(root_path, filename) 414 | 415 | self.dataset = h5py.File(self.data_path, 'r') 416 | 417 | self.length = len(self.dataset['train']["snitch_positions"]) + len(self.dataset['test']["snitch_positions"]) 418 | 419 | if len(self) == 0: 420 | raise FileNotFoundError(f'Found no dataset at {self.data_path}') 
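# Note: self.length counts the train + test sequences stored in the HDF5 file; __len__ and
# __getitem__ below split it as 70% * 80% train, 70% * 20% val and 30% test (e.g. 560 / 140 / 300
# for a file holding 1000 sequences), mirroring the split used by CaterDataset above.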
421 | 422 | self.dataset.close() 423 | self.dataset = None 424 | 425 | self.cam = np.array([ 426 | (1.4503, 1.6376, 0.0000, -0.0251), 427 | (-1.0346, 0.9163, 2.5685, 0.0095), 428 | (-0.6606, 0.5850, -0.4748, 10.5666), 429 | (-0.6592, 0.5839, -0.4738, 10.7452) 430 | ]) 431 | 432 | self.z = 0.3421497941017151 433 | 434 | def __len__(self): 435 | if self.type == "train": 436 | return int(self.length * 0.7 * 0.8) 437 | 438 | if self.type == "val": 439 | return int(self.length * 0.7 * 0.2) 440 | 441 | return int(self.length * 0.3) 442 | 443 | def __getitem__(self, index: int): 444 | if self.dataset is None: 445 | self.dataset = h5py.File(self.data_path, 'r') 446 | 447 | if self.type == "val": 448 | index = index + int(self.length * 0.7 * 0.8) 449 | 450 | if self.type == "test": 451 | index = index + int(self.length * 0.7) 452 | 453 | 454 | if index >= len(self.dataset['train']["snitch_positions"]): 455 | index = index - len(self.dataset['train']["snitch_positions"]) 456 | latent_states = self.dataset['test']["object_states"][index].astype(np.float32) 457 | snitch_positions = self.dataset['test']["snitch_positions"][index] 458 | snitch_label = self.dataset['test']["snitch_label"][index] 459 | snitch_contained = self.dataset['test']["snitch_contained"][index] 460 | actions_visible = self.dataset['test']["actions_visible"][index] 461 | actions_hidden = self.dataset['test']["actions_hidden"][index] 462 | materials_visible = self.dataset['test']["materials_visible"][index] 463 | materials_hidden = self.dataset['test']["materials_hidden"][index] 464 | sizes_visible = self.dataset['test']["sizes_visible"][index] 465 | sizes_hidden = self.dataset['test']["sizes_hidden"][index] 466 | colors_visible = self.dataset['test']["colors_visible"][index] 467 | colors_hidden = self.dataset['test']["colors_hidden"][index] 468 | 469 | latent_states = self.dataset['train']["object_states"][index].astype(np.float32) 470 | snitch_positions = self.dataset['train']["snitch_positions"][index] 471 | snitch_label = self.dataset['train']["snitch_label"][index] 472 | snitch_contained = self.dataset['train']["snitch_contained"][index] 473 | actions_visible = self.dataset['train']["actions_visible"][index] 474 | actions_hidden = self.dataset['train']["actions_hidden"][index] 475 | materials_visible = self.dataset['train']["materials_visible"][index] 476 | materials_hidden = self.dataset['train']["materials_hidden"][index] 477 | sizes_visible = self.dataset['train']["sizes_visible"][index] 478 | sizes_hidden = self.dataset['train']["sizes_hidden"][index] 479 | colors_visible = self.dataset['train']["colors_visible"][index] 480 | colors_hidden = self.dataset['train']["colors_hidden"][index] 481 | return ( 482 | latent_states, 483 | snitch_positions, 484 | snitch_label, 485 | snitch_contained, 486 | actions_visible, 487 | actions_hidden, 488 | materials_visible, 489 | materials_hidden, 490 | sizes_visible, 491 | sizes_hidden, 492 | colors_visible, 493 | colors_hidden 494 | ) 495 | -------------------------------------------------------------------------------- /data/datasets/CLEVRER/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | from typing import Tuple, Union, List 3 | import numpy as np 4 | import json 5 | import math 6 | import cv2 7 | import h5py 8 | import os 9 | import pickle 10 | 11 | __author__ = "Manuel Traub" 12 | 13 | class RamImage(): 14 | def __init__(self, path): 15 | 16 | fd = open(path, 'rb') 17 | img_str = fd.read() 18 | fd.close() 
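# Only the still-encoded JPEG/PNG bytes are kept in RAM; to_numpy() decodes them on demand
# via cv2.imdecode, so the cached dataset stays much smaller than storing decoded frames.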
19 | 20 | self.img_raw = np.frombuffer(img_str, np.uint8) 21 | 22 | def to_numpy(self): 23 | return cv2.imdecode(self.img_raw, cv2.IMREAD_COLOR) 24 | 25 | class ClevrerSample(data.Dataset): 26 | def __init__(self, root_path: str, data_path: str, size: Tuple[int, int]): 27 | 28 | data_path = os.path.join(root_path, data_path, "train", f'{size[0]}x{size[1]}') 29 | 30 | frames = [] 31 | self.size = size 32 | 33 | for file in os.listdir(data_path): 34 | if file.startswith("frame") and (file.endswith(".jpg") or file.endswith(".png")): 35 | frames.append(os.path.join(data_path, file)) 36 | 37 | frames.sort() 38 | self.imgs = [] 39 | for path in frames: 40 | self.imgs.append(RamImage(path)) 41 | 42 | def get_data(self): 43 | 44 | frames = np.zeros((128,3,self.size[1], self.size[0]),dtype=np.float32) 45 | for i in range(len(self.imgs)): 46 | img = self.imgs[i].to_numpy() 47 | frames[i] = img.transpose(2, 0, 1).astype(np.float32) / 255.0 48 | 49 | return frames 50 | 51 | 52 | class ClevrerDataset(data.Dataset): 53 | 54 | def save(self): 55 | with open(self.file, "wb") as outfile: 56 | pickle.dump(self.samples, outfile) 57 | 58 | def load(self): 59 | with open(self.file, "rb") as infile: 60 | self.samples = pickle.load(infile) 61 | 62 | def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int]): 63 | 64 | data_path = f'data/data/video/{dataset_name}' 65 | data_path = os.path.join(root_path, data_path) 66 | self.file = os.path.join(data_path, f'dataset-{size[0]}x{size[1]}.pickle') 67 | self.train = (type == "train") 68 | 69 | self.samples = [] 70 | 71 | if os.path.exists(self.file): 72 | self.load() 73 | else: 74 | 75 | samples = list(filter(lambda x: x.startswith("0"), next(os.walk(data_path))[1])) 76 | num_samples = len(samples) 77 | 78 | for i, dir in enumerate(samples): 79 | self.samples.append(ClevrerSample(data_path, dir, size)) 80 | 81 | print(f"Loading CLEVRER [{i * 100 / num_samples:.2f}]", flush=True) 82 | 83 | self.save() 84 | 85 | self.length = len(self.samples) 86 | self.background = None 87 | if "background.jpg" in os.listdir(data_path): 88 | self.background = cv2.imread(os.path.join(data_path, "background.jpg")) 89 | self.background = cv2.resize(self.background, dsize=size, interpolation=cv2.INTER_CUBIC) 90 | self.background = self.background.transpose(2, 0, 1).astype(np.float32) / 255.0 91 | self.background = self.background.reshape(1, self.background.shape[0], self.background.shape[1], self.background.shape[2]) 92 | 93 | print(f"ClevrerDataset: {self.length}") 94 | 95 | if len(self) == 0: 96 | raise FileNotFoundError(f'Found no dataset at {self.data_path}') 97 | 98 | def __len__(self): 99 | if self.train: 100 | return int(self.length * 0.9) 101 | 102 | return int(self.length * 0.1) 103 | 104 | def __getitem__(self, index: int): 105 | 106 | if not self.train: 107 | index += int(self.length * 0.9) 108 | 109 | return ( 110 | self.samples[index].get_data(), 111 | self.background, 112 | ) 113 | -------------------------------------------------------------------------------- /data/datasets/moving_mnist/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the organization of the custom dataset such that it can be 3 | read efficiently in combination with the DataLoader from PyTorch to prevent that 4 | data reading and preparing becomes the bottleneck. 
5 | 6 | This script was inspired by 7 | https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel 8 | """ 9 | 10 | from torch.utils import data 11 | import numpy as np 12 | import h5py 13 | import os 14 | import mnist 15 | import itertools 16 | import math 17 | from PIL import Image 18 | 19 | 20 | __author__ = "Manuel Traub" 21 | 22 | class MovingMNISTDataset(data.Dataset): 23 | def __init__(self, numbers_per_image, mode, width, height, sequence_length): 24 | 25 | self.numbers_per_image = numbers_per_image 26 | self.sequence_length = sequence_length 27 | self.width = width 28 | self.height = height 29 | 30 | # Get the data and compute the possible combinations of numbers in one 31 | # sample 32 | self.mnist_dict, self.mnist_sizes = self.load_dataset(mode=mode) 33 | self.combinations = list( 34 | itertools.combinations(list(self.mnist_dict.keys()), 35 | numbers_per_image) 36 | ) 37 | 38 | def __len__(self): 39 | return 100000000 40 | 41 | def __getitem__(self, index: int): 42 | """ 43 | Single moving MNIST sample generation using the parameters of this 44 | simulator. 45 | :param combination_index: The index of the self.combination sample 46 | :return: The generated sample as numpy array(t, x, y) 47 | """ 48 | 49 | # Determine the number-combination for the current sample 50 | combination = self.combinations[np.random.randint(len(self.combinations))] 51 | mnist_images = self.get_random_images(self.mnist_dict, self.mnist_sizes, combination) 52 | 53 | # randomly generate direc/speed/position, calculate velocity vector 54 | dx = [np.random.randint(1, 5) * np.random.choice([-1,1]) for i in range(len(mnist_images))] 55 | dy = [np.random.randint(1, 5) * np.random.choice([-1,1]) for i in range(len(mnist_images))] 56 | 57 | lx = [mnist_images[i].shape[0] for i in range(len(mnist_images))] 58 | ly = [mnist_images[i].shape[1] for i in range(len(mnist_images))] 59 | 60 | x = [np.random.randint(0, self.width - lx[i]) for i in range(len(mnist_images))] 61 | y = [np.random.randint(0, self.width - ly[i]) for i in range(len(mnist_images))] 62 | 63 | video = np.zeros((self.sequence_length, 3, self.width, self.height), dtype=np.float32) 64 | 65 | for t in range(self.sequence_length): 66 | 67 | for i in range(self.numbers_per_image): 68 | x[i] = x[i] + dx[i] 69 | y[i] = y[i] + dy[i] 70 | 71 | if x[i] < 0: 72 | x[i] = -x[i] 73 | dx[i] = -dx[i] 74 | 75 | if x[i] > self.width - lx[i] - 1: 76 | x[i] = (self.width - lx[i] - 1) - (x[i] - (self.width - lx[i] - 1)) 77 | dx[i] = -dx[i] 78 | 79 | if y[i] < 0: 80 | y[i] = -y[i] 81 | dy[i] = -dy[i] 82 | 83 | if y[i] > self.height - ly[i] - 1: 84 | y[i] = (self.height - ly[i] - 1) - (y[i] - (self.height - ly[i] - 1)) 85 | dy[i] = -dy[i] 86 | 87 | video[t,:,x[i]:x[i]+lx[i],y[i]:y[i]+ly[i]] += mnist_images[i] 88 | 89 | return (np.clip(video / 255.0, 0, 1).astype(np.float32), np.zeros((1, 3, self.width, self.height), dtype=np.float32)) 90 | 91 | 92 | def load_dataset(self, mode="train"): 93 | """ 94 | Loads the dataset using the python mnist package. 
95 | :param mode: Any of "train" or "test" 96 | :return: Dictionary of MNIST numbers and a list of the pixel sizes 97 | """ 98 | print("Loading Mnist {}".format(mode)) 99 | if mode == "train" or mode == "val": 100 | mnist_images = mnist.train_images() 101 | mnist_labels = mnist.train_labels() 102 | elif mode == "test": 103 | mnist_images = mnist.test_images() 104 | mnist_labels = mnist.test_labels() 105 | 106 | n_labels = np.unique(mnist_labels) 107 | mnist_dict = {} 108 | mnist_sizes = [] 109 | for i in n_labels: 110 | idxs = np.where(mnist_labels == i) 111 | mnist_dict[i] = mnist_images[idxs] 112 | mnist_sizes.append(mnist_dict[i].shape[0]) 113 | 114 | return mnist_dict, mnist_sizes 115 | 116 | def get_random_images(self, dataset, size_list, id_list): 117 | """ 118 | Returns a list of randomly chosen images from the given dataset. 119 | :param dataset: dictionary of images 120 | :param size_list: Corresponding sizes of the dataset images 121 | :param id_list: Numbers to put into the sample (e.g. [2, 7]) 122 | :return: A list of randomly chosen images of the specified numbers 123 | """ 124 | images = [] 125 | for id in id_list: 126 | idx = np.random.randint(0, size_list[id]) 127 | images.append(self.crop_image(dataset[id][idx])) 128 | 129 | return images 130 | 131 | def crop_image(self, img): 132 | sx = 0 133 | sy = 0 134 | ex = 28 135 | ey = 28 136 | 137 | while np.sum(img[sx,:]) == 0: 138 | sx += 1 139 | 140 | while np.sum(img[ex-1,:]) == 0: 141 | ex -= 1 142 | 143 | while np.sum(img[:,sy]) == 0: 144 | sy += 1 145 | 146 | while np.sum(img[:,ey-1]) == 0: 147 | ey -= 1 148 | 149 | return img[sx:ex,sy:ey] 150 | -------------------------------------------------------------------------------- /data/datasets/video/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the organization of the custom dataset such that it can be 3 | read efficiently in combination with the DataLoader from PyTorch to prevent that 4 | data reading and preparing becomes the bottleneck. 
5 | 6 | This script was inspired by 7 | https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel 8 | """ 9 | 10 | from torch.utils import data 11 | from typing import Tuple, Union, List 12 | import numpy as np 13 | import cv2 14 | import os 15 | import pickle 16 | 17 | __author__ = "Manuel Traub" 18 | 19 | class RamImage(): 20 | def __init__(self, path): 21 | 22 | fd = open(path, 'rb') 23 | img_str = fd.read() 24 | fd.close() 25 | 26 | self.img_raw = np.frombuffer(img_str, np.uint8) 27 | 28 | def to_numpy(self): 29 | return cv2.imdecode(self.img_raw, cv2.IMREAD_COLOR) 30 | 31 | 32 | class VideoDataset(data.Dataset): 33 | def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int], time: int, subset: bool = False): 34 | 35 | data_path = dataset_name if subset else f'data/data/video/{dataset_name}' 36 | data_path = os.path.join(root_path, data_path) 37 | self.file = os.path.join(data_path, f'dataset-{size[0]}x{size[1]}-{type}.pickle') 38 | 39 | if os.path.exists(self.file): 40 | self.load() 41 | else: 42 | data_path = os.path.join(data_path, "train") 43 | data_path = os.path.join(data_path, f'{size[0]}x{size[1]}') 44 | 45 | frames = [] 46 | 47 | for file in os.listdir(data_path): 48 | if file.startswith("frame") and (file.endswith(".jpg") or file.endswith(".png")): 49 | frames.append(os.path.join(data_path, file)) 50 | 51 | frames.sort() 52 | 53 | if not subset: 54 | if type == "train": 55 | frames = frames[:int(len(frames) * 0.9)] 56 | else: 57 | frames = frames[int(len(frames) * 0.9):] 58 | 59 | num_samples = len(frames) 60 | 61 | self.imgs = [] 62 | for i, path in enumerate(frames): 63 | self.imgs.append(RamImage(path)) 64 | 65 | if not subset and i % 1000 == 0: 66 | print(f"Loading Video {type} [{i * 100 / num_samples:.2f}]", flush=True) 67 | 68 | 69 | self.save() 70 | 71 | self.length = len(self.imgs) - time + 1 72 | self.time = time 73 | self.size = size 74 | 75 | if not subset: 76 | print(f'loaded {type} Video Dataset {dataset_name} [{self.length}]') 77 | 78 | if len(self) == 0: 79 | print(subset, self.length, len(self.frames), self.time) 80 | raise FileNotFoundError(f'Found no dataset at {self.data_path}') 81 | 82 | def save(self): 83 | with open(self.file, "wb") as outfile: 84 | pickle.dump(self.imgs, outfile) 85 | 86 | def load(self): 87 | with open(self.file, "rb") as infile: 88 | self.imgs = pickle.load(infile) 89 | 90 | def __len__(self): 91 | return max(self.length, 0) 92 | 93 | def __getitem__(self, index: int): 94 | 95 | frames = [] 96 | frames = np.zeros((self.time, 3, self.size[1], self.size[0]), dtype=np.float32) 97 | for i in range(self.time): 98 | img = self.imgs[index + i].to_numpy() 99 | frames[i] = img.transpose(2, 0, 1).astype(np.float32) / 255.0 100 | 101 | return frames, np.zeros((1, 3, self.size[1], self.size[0]), dtype=np.float32) 102 | 103 | class MultipleVideosDataset(data.Dataset): 104 | def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int], time: int): 105 | 106 | data_path = f'data/data/video/{dataset_name}' 107 | data_path = os.path.join(root_path, data_path) 108 | self.train = (type == "train") 109 | 110 | self.datasets = [] 111 | self.length = 0 112 | for dir in next(os.walk(data_path))[1]: 113 | if dir.startswith("00"): 114 | self.datasets.append(VideoDataset(data_path, dir, type, size, time, True)) 115 | self.length += self.datasets[-1].length 116 | 117 | self.background = None 118 | if "background.jpg" in os.listdir(data_path): 119 | self.background = 
cv2.imread(os.path.join(data_path, "background.jpg")) 120 | self.background = cv2.resize(self.background, dsize=size, interpolation=cv2.INTER_CUBIC) 121 | self.background = self.background.transpose(2, 0, 1).astype(float) / 255.0 122 | self.background = self.background.reshape(1, self.background.shape[0], self.background.shape[1], self.background.shape[2]) 123 | 124 | print(f"MultipleVideosDataset: {self.length}") 125 | 126 | if len(self) == 0: 127 | raise FileNotFoundError(f'Found no dataset at {self.data_path}') 128 | 129 | def __len__(self): 130 | if self.train: 131 | return int(self.length * 0.9) 132 | 133 | return int(self.length * 0.1) 134 | 135 | def __getitem__(self, index: int): 136 | 137 | if not self.train: 138 | index += int((self.length / 300) * 0.9) * 300 139 | 140 | length = 0 141 | for dataset in self.datasets: 142 | length += len(dataset) 143 | 144 | if index < length: 145 | index = index - (length - len(dataset)) 146 | if self.background is not None: 147 | return dataset.__getitem__(int(index)), self.background 148 | 149 | return dataset.__getitem__(int(index)) 150 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: loci 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=11.3 8 | - h5py=3.6 9 | - libpng=1.6 10 | - matplotlib=3.5 11 | - numpy=1.21 12 | - python=3.9 13 | - pytorch=1.10 14 | - torchaudio=0.10 15 | - torchvision=0.11 16 | - pip==22.2.2 17 | - pip: 18 | - einops 19 | - jsmin 20 | - opencv-python 21 | - pandas 22 | - pytorch-msssim 23 | - mnist 24 | -------------------------------------------------------------------------------- /model/cater-stage1.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_path": "cater_stage1", 3 | "datatype": "cater", 4 | "dataset": "CATER", 5 | "latent_type": "snitch_tracker", 6 | "num_workers": 2, 7 | "prefetch_factor": 2, 8 | "learning_rate": 0.0001, 9 | "sequence_len": 300, 10 | "backprop_steps": 1, 11 | "epochs": 1000, 12 | "updates": 50000000, 13 | "closed_loop": false, 14 | "teacher_forcing": 10, 15 | "statistics_offset": 10, 16 | "msssim": false, 17 | "load_optimizers": false, 18 | "scheduled_sampling": false, 19 | "entity_pretraining_steps": 30000, 20 | "background_pretraining_steps": 0, 21 | "model": { 22 | "level": 1, 23 | "batch_size": 16, 24 | "num_objects": 10, 25 | "img_channels": 3, 26 | "input_size": [240, 320], 27 | "latent_size": [15, 20], 28 | "gestalt_size": 96, 29 | "object_regularizer": 1, 30 | "position_regularizer": 1, 31 | "supervision_factor": 0.01, 32 | "time_regularizer": 1, 33 | "encoder": { 34 | "channels": 48, 35 | "level1_channels": 24, 36 | "num_layers": 3, 37 | "reg_lambda": 1e-10 38 | }, 39 | "predictor": { 40 | "heads": 2, 41 | "layers": 2, 42 | "channels_multiplier": 2, 43 | "reg_lambda": 1e-10 44 | }, 45 | "decoder": { 46 | "channels": 48, 47 | "level1_channels": 3, 48 | "num_layers": 5 49 | }, 50 | "background": { 51 | "learning_rate": 0.0001, 52 | "num_layers": 1, 53 | "reg_lambda": 1e-10, 54 | "latent_channels": 48, 55 | "level1_channels": 24, 56 | "gestalt_size": 8, 57 | "use": false 58 | } 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /model/cater-stage2.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_path": "cater_stage2", 3 | "datatype": "cater", 
4 | "dataset": "CATER", 5 | "latent_type": "snitch_tracker", 6 | "num_workers": 2, 7 | "prefetch_factor": 2, 8 | "learning_rate": 0.0000333333, 9 | "sequence_len": 300, 10 | "backprop_steps": 1, 11 | "epochs": 1000, 12 | "updates": 50000000, 13 | "closed_loop": false, 14 | "teacher_forcing": 10, 15 | "statistics_offset": 10, 16 | "msssim": false, 17 | "load_optimizers": false, 18 | "scheduled_sampling": false, 19 | "entity_pretraining_steps": 30000, 20 | "background_pretraining_steps": 0, 21 | "model": { 22 | "level": 1, 23 | "batch_size": 16, 24 | "num_objects": 10, 25 | "img_channels": 3, 26 | "input_size": [240, 320], 27 | "latent_size": [15, 20], 28 | "gestalt_size": 96, 29 | "object_regularizer": 1, 30 | "position_regularizer": 1, 31 | "supervision_factor": 0.1, 32 | "time_regularizer": 1, 33 | "encoder": { 34 | "channels": 48, 35 | "level1_channels": 24, 36 | "num_layers": 3, 37 | "reg_lambda": 1e-10 38 | }, 39 | "predictor": { 40 | "heads": 2, 41 | "layers": 2, 42 | "channels_multiplier": 2, 43 | "reg_lambda": 1e-10 44 | }, 45 | "decoder": { 46 | "channels": 48, 47 | "level1_channels": 3, 48 | "num_layers": 5 49 | }, 50 | "background": { 51 | "learning_rate": 0.0000333333, 52 | "num_layers": 1, 53 | "reg_lambda": 1e-10, 54 | "latent_channels": 48, 55 | "level1_channels": 24, 56 | "gestalt_size": 8, 57 | "use": false 58 | } 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /model/cater-stage3.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_path": "cater_stage3", 3 | "datatype": "cater", 4 | "dataset": "CATER", 5 | "latent_type": "snitch_tracker", 6 | "num_workers": 2, 7 | "prefetch_factor": 2, 8 | "learning_rate": 0.0000333333, 9 | "sequence_len": 300, 10 | "backprop_steps": 1, 11 | "epochs": 1000, 12 | "updates": 50000000, 13 | "closed_loop": false, 14 | "teacher_forcing": 10, 15 | "statistics_offset": 10, 16 | "msssim": false, 17 | "load_optimizers": false, 18 | "scheduled_sampling": false, 19 | "entity_pretraining_steps": 30000, 20 | "background_pretraining_steps": 0, 21 | "model": { 22 | "level": 2, 23 | "batch_size": 16, 24 | "num_objects": 10, 25 | "img_channels": 3, 26 | "input_size": [240, 320], 27 | "latent_size": [15, 20], 28 | "gestalt_size": 96, 29 | "object_regularizer": 1, 30 | "position_regularizer": 1, 31 | "supervision_factor": 0.1, 32 | "time_regularizer": 1, 33 | "encoder": { 34 | "channels": 48, 35 | "level1_channels": 24, 36 | "num_layers": 3, 37 | "reg_lambda": 1e-10 38 | }, 39 | "predictor": { 40 | "heads": 2, 41 | "layers": 2, 42 | "channels_multiplier": 2, 43 | "reg_lambda": 1e-10 44 | }, 45 | "decoder": { 46 | "channels": 48, 47 | "level1_channels": 3, 48 | "num_layers": 5 49 | }, 50 | "background": { 51 | "learning_rate": 0.0000333333, 52 | "num_layers": 1, 53 | "reg_lambda": 1e-10, 54 | "latent_channels": 48, 55 | "level1_channels": 24, 56 | "gestalt_size": 8, 57 | "use": false 58 | } 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /model/cater-stage4.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_path": "cater_stage4", 3 | "datatype": "cater", 4 | "dataset": "CATER", 5 | "latent_type": "snitch_tracker", 6 | "num_workers": 2, 7 | "prefetch_factor": 2, 8 | "learning_rate": 0.00001, 9 | "sequence_len": 300, 10 | "backprop_steps": 1, 11 | "epochs": 1000, 12 | "updates": 50000000, 13 | "closed_loop": false, 14 | "teacher_forcing": 10, 15 | 
"statistics_offset": 10, 16 | "msssim": false, 17 | "load_optimizers": false, 18 | "scheduled_sampling": false, 19 | "entity_pretraining_steps": 30000, 20 | "background_pretraining_steps": 0, 21 | "model": { 22 | "level": 2, 23 | "batch_size": 16, 24 | "num_objects": 10, 25 | "img_channels": 3, 26 | "input_size": [240, 320], 27 | "latent_size": [15, 20], 28 | "gestalt_size": 96, 29 | "object_regularizer": 1, 30 | "position_regularizer": 1, 31 | "supervision_factor": 0.3333333, 32 | "time_regularizer": 1, 33 | "encoder": { 34 | "channels": 48, 35 | "level1_channels": 24, 36 | "num_layers": 3, 37 | "reg_lambda": 1e-10 38 | }, 39 | "predictor": { 40 | "heads": 2, 41 | "layers": 2, 42 | "channels_multiplier": 2, 43 | "reg_lambda": 1e-10 44 | }, 45 | "decoder": { 46 | "channels": 48, 47 | "level1_channels": 3, 48 | "num_layers": 5 49 | }, 50 | "background": { 51 | "learning_rate": 0.00001, 52 | "num_layers": 1, 53 | "reg_lambda": 1e-10, 54 | "latent_channels": 48, 55 | "level1_channels": 24, 56 | "gestalt_size": 8, 57 | "use": false 58 | } 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /model/cater.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_path": "cater", 3 | "datatype": "cater", 4 | "dataset": "CATER", 5 | "latent_type": "snitch_tracker", 6 | "num_workers": 2, 7 | "prefetch_factor": 2, 8 | "learning_rate": 0.0001, 9 | "sequence_len": 300, 10 | "backprop_steps": 1, 11 | "epochs": 1000, 12 | "updates": 50000000, 13 | "closed_loop": false, 14 | "teacher_forcing": 10, 15 | "statistics_offset": 10, 16 | "msssim": false, 17 | "load_optimizers": false, 18 | "scheduled_sampling": false, 19 | "entity_pretraining_steps": 30000, 20 | "background_pretraining_steps": 0, 21 | "model": { 22 | "level": 2, 23 | "batch_size": 16, 24 | "num_objects": 10, 25 | "img_channels": 3, 26 | "input_size": [240, 320], 27 | "latent_size": [15, 20], 28 | "gestalt_size": 96, 29 | "object_regularizer": 1, 30 | "position_regularizer": 1, 31 | "supervision_factor": 0.01, 32 | "time_regularizer": 1, 33 | "encoder": { 34 | "channels": 48, 35 | "level1_channels": 24, 36 | "num_layers": 3, 37 | "reg_lambda": 1e-10 38 | }, 39 | "predictor": { 40 | "heads": 2, 41 | "layers": 2, 42 | "channels_multiplier": 2, 43 | "reg_lambda": 1e-10 44 | }, 45 | "decoder": { 46 | "channels": 48, 47 | "level1_channels": 3, 48 | "num_layers": 5 49 | }, 50 | "background": { 51 | "learning_rate": 0.0001, 52 | "num_layers": 1, 53 | "reg_lambda": 1e-10, 54 | "latent_channels": 48, 55 | "level1_channels": 24, 56 | "gestalt_size": 8, 57 | "use": false 58 | } 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /model/loci.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import numpy as np 4 | from typing import Tuple 5 | from einops import rearrange, repeat, reduce 6 | from utils.optimizers import SDAMSGrad 7 | from nn.decoder import LociDecoder 8 | from nn.encoder import LociEncoder 9 | from nn.predictor import LociPredictor 10 | from utils.utils import PrintGradient, InitialLatentStates 11 | from utils.loss import MaskModulatedObjectLoss, ObjectModulator, TranslationInvariantObjectLoss, PositionLoss 12 | from nn.background import BackgroundEnhancer, PrecalculatedBackground 13 | 14 | class Loci(nn.Module): 15 | def __init__( 16 | self, 17 | cfg, 18 | camera_view_matrix = None, 19 | zero_elevation = None, 20 | 
teacher_forcing=1 21 | ): 22 | super(Loci, self).__init__() 23 | 24 | self.teacher_forcing = teacher_forcing 25 | self.cfg = cfg 26 | 27 | self.encoder = LociEncoder( 28 | input_size = cfg.input_size, 29 | latent_size = cfg.latent_size, 30 | num_objects = cfg.num_objects, 31 | img_channels = cfg.img_channels * 2 + 6, 32 | hidden_channels = cfg.encoder.channels, 33 | level1_channels = cfg.encoder.level1_channels, 34 | num_layers = cfg.encoder.num_layers, 35 | gestalt_size = cfg.gestalt_size, 36 | batch_size = cfg.batch_size, 37 | ) 38 | 39 | self.predictor = LociPredictor( 40 | num_objects = cfg.num_objects, 41 | gestalt_size = cfg.gestalt_size, 42 | heads = cfg.predictor.heads, 43 | layers = cfg.predictor.layers, 44 | reg_lambda = cfg.predictor.reg_lambda, 45 | batch_size = cfg.batch_size, 46 | camera_view_matrix = camera_view_matrix, 47 | zero_elevation = zero_elevation 48 | ) 49 | 50 | self.decoder = LociDecoder( 51 | latent_size = cfg.latent_size, 52 | num_objects = cfg.num_objects, 53 | gestalt_size = cfg.gestalt_size, 54 | img_channels = cfg.img_channels, 55 | hidden_channels = cfg.decoder.channels, 56 | level1_channels = cfg.decoder.level1_channels, 57 | num_layers = cfg.decoder.num_layers 58 | ) 59 | 60 | if cfg.background.use: 61 | self.background = BackgroundEnhancer( 62 | input_size = cfg.input_size, 63 | gestalt_size = cfg.background.gestalt_size, 64 | img_channels = cfg.img_channels, 65 | depth = cfg.background.num_layers, 66 | latent_channels = cfg.background.latent_channels, 67 | level1_channels = cfg.background.level1_channels, 68 | batch_size = cfg.batch_size, 69 | ) 70 | else: 71 | self.background = PrecalculatedBackground( 72 | input_size = cfg.input_size, 73 | img_channels = cfg.img_channels, 74 | ) 75 | 76 | self.initial_states = InitialLatentStates( 77 | gestalt_size = cfg.gestalt_size, 78 | num_objects = cfg.num_objects, 79 | size = cfg.input_size, 80 | ) 81 | 82 | 83 | self.translation_invariant_object_loss = TranslationInvariantObjectLoss(cfg.num_objects, teacher_forcing) 84 | self.mask_modulated_object_loss = MaskModulatedObjectLoss(cfg.num_objects, teacher_forcing) 85 | self.position_loss = PositionLoss(cfg.num_objects, teacher_forcing) 86 | self.modulator = ObjectModulator(cfg.num_objects) 87 | 88 | self.background.set_level(cfg.level) 89 | self.encoder.set_level(cfg.level) 90 | self.decoder.set_level(cfg.level) 91 | self.initial_states.set_level(cfg.level) 92 | 93 | def get_init_status(self): 94 | init = [] 95 | for module in self.modules(): 96 | if callable(getattr(module, "get_init", None)): 97 | init.append(module.get_init()) 98 | 99 | assert len(set(init)) == 1 100 | return init[0] 101 | 102 | def get_openings(self): 103 | return self.predictor.get_openings() 104 | 105 | def detach(self): 106 | for module in self.modules(): 107 | if module != self and callable(getattr(module, "detach", None)): 108 | module.detach() 109 | 110 | def reset_state(self): 111 | for module in self.modules(): 112 | if module != self and callable(getattr(module, "reset_state", None)): 113 | module.reset_state() 114 | 115 | def forward(self, *input, reset=True, detach=True, mode='end2end', evaluate=False, train_background=False, test=False): 116 | 117 | if detach: 118 | self.detach() 119 | 120 | if reset: 121 | self.reset_state() 122 | 123 | if train_background: 124 | return self.background(*input)[1] 125 | 126 | return self.run_end2end(*input, evaluate=evaluate, test=test) 127 | 128 | def run_decoder( 129 | self, 130 | position: th.Tensor, 131 | gestalt: th.Tensor, 132 | priority: 
th.Tensor, 133 | bg_mask: th.Tensor, 134 | background: th.Tensor 135 | ): 136 | mask, object = self.decoder(position, gestalt, priority) 137 | 138 | mask = th.softmax(th.cat((mask, bg_mask), dim=1), dim=1) 139 | object = th.cat((th.sigmoid(object - 2.5), background), dim=1) 140 | 141 | _mask = mask.unsqueeze(dim=2) 142 | _object = object.view( 143 | mask.shape[0], 144 | self.cfg.num_objects + 1, 145 | self.cfg.img_channels, 146 | *mask.shape[2:] 147 | ) 148 | 149 | output = th.sum(_mask * _object, dim=1) 150 | return output, mask, object 151 | 152 | def run_end2end( 153 | self, 154 | input: th.Tensor, 155 | error: th.Tensor = None, 156 | mask: th.Tensor = None, 157 | position: th.Tensor = None, 158 | gestalt: th.Tensor = None, 159 | priority: th.Tensor = None, 160 | evaluate = False, 161 | test = False 162 | ): 163 | output_sequence = list() 164 | position_sequence = list() 165 | gestalt_sequence = list() 166 | priority_sequence = list() 167 | mask_sequence = list() 168 | object_sequence = list() 169 | background_sequence = list() 170 | error_sequence = list() 171 | 172 | position_loss = th.tensor(0, device=input.device) 173 | object_loss = th.tensor(0, device=input.device) 174 | time_loss = th.tensor(0, device=input.device) 175 | bg_mask = None 176 | 177 | if error is None or mask is None: 178 | bg_mask, background, _ = self.background(input) 179 | error = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() 180 | 181 | position, gestalt, priority = self.initial_states(error, mask, position, gestalt, priority) 182 | 183 | if mask is None: 184 | mask = self.decoder(position, gestalt, priority)[0] 185 | mask = th.softmax(th.cat((mask, bg_mask), dim=1), dim=1) 186 | 187 | 188 | output = None 189 | bg_mask = None 190 | object_last = None 191 | 192 | position_last = position 193 | object_last = self.decoder(position_last, gestalt)[-1] 194 | 195 | # background and corresponding mask for the next time point 196 | bg_mask, background, raw_background = self.background(input, error, mask[:,-1:]) 197 | 198 | # position and gestalt for the current time point 199 | position, gestalt, priority = self.encoder(input, error, mask, object_last, position, priority) 200 | 201 | # position and gestalt for the next time point 202 | position, gestalt, priority, snitch_position = self.predictor(position, gestalt, priority) 203 | 204 | # combined background and objects (masks) for the next time point 205 | output, mask, object = self.run_decoder(position, gestalt, priority, bg_mask, background) 206 | 207 | if not evaluate and not test: 208 | 209 | # regularize to small position changes over time 210 | position_loss = position_loss + self.position_loss(position, position_last.detach(), mask[:,:-1].detach()) 211 | 212 | # regularize to encode last visible object 213 | object_cur = self.decoder(position, gestalt)[-1] 214 | object_modulated = self.decoder(*self.modulator(position, gestalt, mask[:,:-1]))[-1] 215 | object_loss = object_loss + self.mask_modulated_object_loss( 216 | object_cur, 217 | object_modulated.detach(), 218 | mask[:,:-1].detach() 219 | ) 220 | 221 | # regularize to produce consistent object codes over time 222 | time_loss = time_loss + 0.1 * self.translation_invariant_object_loss( 223 | mask[:,:-1].detach(), 224 | object_last.detach(), 225 | position_last.detach(), 226 | object_cur, 227 | position.detach(), 228 | ) 229 | 230 | if evaluate and not test: 231 | object = self.run_decoder(position, gestalt, None, bg_mask, background)[-1] 232 | 233 | return ( 234 | output, 235 | 
position, 236 | gestalt, 237 | priority, 238 | mask, 239 | object, 240 | raw_background, 241 | position_loss, 242 | object_loss, 243 | time_loss, 244 | snitch_position 245 | ) 246 | -------------------------------------------------------------------------------- /model/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import torch as th 4 | from copy import copy 5 | 6 | 7 | from utils.configuration import Configuration 8 | from model.scripts import training, evaluation 9 | from data.datasets.moving_mnist.dataset import MovingMNISTDataset 10 | from data.datasets.video.dataset import VideoDataset, MultipleVideosDataset 11 | from data.datasets.CATER.dataset import CaterDataset, CaterLatentDataset 12 | from data.datasets.CLEVRER.dataset import ClevrerDataset 13 | 14 | CFG_PATH = "cfg.json" 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("-cfg", default=CFG_PATH) 19 | parser.add_argument("-num-gpus", default=1, type=int) 20 | parser.add_argument("-n", default=-1, type=int) 21 | parser.add_argument("-load", default="", type=str) 22 | parser.add_argument("-dataset-file", default="", type=str) 23 | parser.add_argument("-device", default=0, type=int) 24 | parser.add_argument("-testset", action="store_true") 25 | mode_group = parser.add_mutually_exclusive_group(required=True) 26 | mode_group.add_argument("-train", action="store_true") 27 | mode_group.add_argument("-eval", action="store_true") 28 | mode_group.add_argument("-save", action="store_true") 29 | mode_group.add_argument("-export", action="store_true") 30 | parser.add_argument("-objects", action="store_true") 31 | parser.add_argument("-nice", action="store_true") 32 | parser.add_argument("-individual", action="store_true") 33 | 34 | args = parser.parse_args(sys.argv[1:]) 35 | 36 | if not args.objects and not args.nice and not args.individual: 37 | args.objects = True 38 | 39 | cfg = Configuration(args.cfg) 40 | 41 | if args.device >= 0: 42 | cfg.device = args.device 43 | cfg.model_path = f"{cfg.model_path}.device{cfg.device}" 44 | 45 | if args.n >= 0: 46 | cfg.device = args.device 47 | cfg.model_path = f"{cfg.model_path}.run{args.n}" 48 | 49 | num_gpus = th.cuda.device_count() 50 | 51 | if cfg.device >= num_gpus: 52 | cfg.device = num_gpus - 1 53 | 54 | if args.num_gpus > 0: 55 | num_gpus = args.num_gpus 56 | 57 | print(f'Using {num_gpus} GPU{"s" if num_gpus > 1 else ""}') 58 | print(f'{"Training" if args.train else "Evaluating"} model {cfg.model_path}') 59 | 60 | trainset = None 61 | valset = None 62 | testset = None 63 | if cfg.datatype == "moving_mnist": 64 | trainset = MovingMNISTDataset(2, 'train', 64, 64, cfg.sequence_len) 65 | valset = MovingMNISTDataset(2, 'train', 64, 64, cfg.sequence_len) 66 | testset = MovingMNISTDataset(2, 'test', 64, 64, cfg.sequence_len) 67 | 68 | if cfg.datatype == "video" or cfg.datatype == "multiple-videos": 69 | if args.run_testset: 70 | cfg.sequence_len = 3 71 | 72 | if args.save_patch: 73 | cfg.sequence_len = 0 74 | 75 | if args.save: 76 | cfg.sequence_len = 1 77 | cfg.model.batch_size = 1 78 | 79 | if args.save_patch or args.eval_patch: 80 | cfg.model.latent_size[0] = cfg.model.patch_grid_size[0] * 2 81 | cfg.model.latent_size[1] = cfg.model.patch_grid_size[1] * 2 82 | 83 | if cfg.datatype == "video": 84 | trainset = None if args.save and args.testset else VideoDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 
2**(cfg.model.level*2)), cfg.sequence_len + 1) 85 | valset = None if args.save else VideoDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), cfg.sequence_len + 1) 86 | testset = VideoDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), cfg.sequence_len + 1) 87 | cfg.sequence_len += 1 88 | 89 | if cfg.datatype == "multiple-videos": 90 | trainset = None if args.save and args.testset else MultipleVideosDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), cfg.sequence_len + 1) 91 | valset = None if args.save else MultipleVideosDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), cfg.sequence_len + 1) 92 | testset = MultipleVideosDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), cfg.sequence_len + 1) 93 | cfg.sequence_len += 1 94 | 95 | if cfg.datatype == "clevrer": 96 | trainset = None if args.save and args.testset else ClevrerDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) 97 | valset = None if args.save else ClevrerDataset("./", cfg.dataset, "val", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) 98 | testset = ClevrerDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) 99 | 100 | valset.train = False 101 | testset.train = False 102 | 103 | cfg.sequence_len += 1 104 | 105 | if cfg.datatype == "cater": 106 | trainset = None if args.save and args.testset else CaterDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) 107 | valset = None if args.save else CaterDataset("./", cfg.dataset, "val", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) 108 | testset = CaterDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) 109 | 110 | 111 | cfg.sequence_len += 1 112 | 113 | if cfg.datatype == "latent-cater": 114 | if args.dataset_file != "": 115 | cfg.dataset = args.dataset_file 116 | trainset = CaterLatentDataset("./", cfg.dataset, "train") 117 | valset = CaterLatentDataset("./", cfg.dataset, "val") 118 | testset = CaterLatentDataset("./", cfg.dataset, "test") 119 | 120 | if cfg.datatype == "latent-cater": 121 | if cfg.latent_type == "snitch_tracker": 122 | training.train_latent_tracker(cfg, trainset, valset, testset, args.load) 123 | elif cfg.latent_type == "object_behavior": 124 | training.train_latent_action_classifier(cfg, trainset, testset, args.load) 125 | elif args.train: 126 | training.run(cfg, num_gpus, trainset, valset, testset, args.load, (cfg.model.level*2)) 127 | elif args.eval: 128 | evaluation.evaluate(cfg, num_gpus, testset if args.testset else valset, args.load, (cfg.model.level*2)) 129 | elif args.save: 130 | evaluation.save(cfg, testset if args.testset else trainset, args.load, (cfg.model.level*2), cfg.model.input_size, args.objects, args.nice, args.individual) 131 | elif args.export: 132 | 
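# Write the per-sequence latent states to "{args.load}.latent-states"; presumably this is the
# file that CaterLatentDataset (datatype "latent-cater") later loads via the -dataset-file argument.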
evaluation.export_dataset(cfg, trainset, testset, args.load, f"{args.load}.latent-states") 133 | -------------------------------------------------------------------------------- /model/scripts/playground.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import matplotlib.pyplot as plt 4 | from matplotlib.widgets import Slider, Button, RadioButtons 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | import cv2 9 | import pickle 10 | from utils.utils import Gaus2D 11 | from einops import rearrange, repeat, reduce 12 | from utils.configuration import Configuration 13 | from nn.background import BackgroundEnhancer 14 | from nn.decoder import GPDecoder 15 | 16 | class LociPlayground: 17 | 18 | def __init__(self, cfg, device, file, background, gestalt = None, position = None): 19 | 20 | device = th.device(device) 21 | 22 | self.cfg = cfg 23 | self.gestalt = th.zeros((1, cfg.gestalt_size)).to(device) 24 | self.position = th.tensor([[0,0,0.05]]).to(device) 25 | self.size = cfg.input_size 26 | self.gaus2d = Gaus2D(cfg.input_size).to(device) 27 | self.gestalt_gridcell_size = 25 28 | self.gestalt_gridcell_margin = 5 29 | self.gestalt_grid_width = 16 30 | self.gestalt_grid_height = 6 31 | 32 | if gestalt is not None: 33 | self.gestalt = gestalt.to(device) 34 | 35 | if position is not None: 36 | self.position = position.to(device) 37 | 38 | self.decoder = GPDecoder( 39 | latent_size = cfg.latent_size, 40 | num_objects = 1, 41 | gestalt_size = cfg.gestalt_size, 42 | img_channels = cfg.img_channels, 43 | hidden_channels = cfg.decoder.channels, 44 | level1_channels = cfg.decoder.level1_channels, 45 | num_layers = cfg.decoder.num_layers, 46 | ).to(device) 47 | self.decoder.set_level(2) 48 | 49 | print(f'loading to device {self.decoder.mask_alpha.device}...', end='', flush=True) 50 | state = th.load(file, map_location=device) 51 | 52 | # backward compatibility 53 | model = {} 54 | for key, value in state["model"].items(): 55 | model[key.replace(".module.", ".")] = value 56 | 57 | decoder_state = {} 58 | for k, v in model.items(): 59 | if k.startswith('decoder.'): 60 | decoder_state[k.replace('decoder.', '')] = v 61 | 62 | self.decoder.load_state_dict(decoder_state) 63 | 64 | self.bg_mask = model['background.mask'] 65 | self.background = th.from_numpy(cv2.imread(background)).to(device) / 255 66 | self.background = rearrange(self.background, 'h w c -> 1 c h w') 67 | print('done', flush=True) 68 | 69 | self.fig = plt.figure(figsize=(6,6)) 70 | 71 | self.ax_gestalt = plt.subplot2grid((3, 3), (0, 0), colspan=2) 72 | self.ax_position = plt.subplot2grid((3, 3), (0, 2)) 73 | self.ax_output1 = plt.subplot2grid((3, 3), (1, 0), colspan=2, rowspan=2) 74 | self.ax_output2 = plt.subplot2grid((3, 3), (1, 2)) 75 | self.ax_output3 = plt.subplot2grid((3, 3), (2, 2)) 76 | 77 | self.outputs = [self.ax_output1, self.ax_output2, self.ax_output3] 78 | self.indices = [0, 1, 2] 79 | 80 | self.connections = () 81 | 82 | 83 | self.add_image(self.ax_gestalt, self.create_gestalt_image()) 84 | self.add_image(self.ax_position, self.create_position_image()) 85 | 86 | self.update_outputs() 87 | 88 | plt.tight_layout() 89 | 90 | def update_outputs(self): 91 | mask, object, output = None, None, None 92 | with th.no_grad(): 93 | mask, object = self.decoder(self.position, self.gestalt) 94 | bg_mask, background = self.bg_mask, self.background 95 | 96 | # we have to somehow correct for not having 30 objects 97 | mask = th.softmax(th.cat((mask*100, 
bg_mask), dim=1) * 0.01, dim=1) 98 | object = th.cat((th.sigmoid(object - 2.5), background), dim=1) 99 | 100 | mask = rearrange(mask, 'b n h w -> b n 1 h w') 101 | object = rearrange(object, 'b (n c) h w -> b n c h w', n = 2) 102 | 103 | output = th.sum(mask * object, dim = 1) 104 | output = rearrange(output[0], 'c h w -> h w c').cpu().numpy() 105 | 106 | object = rearrange(object[0,0], 'c h w -> h w c').cpu().numpy() 107 | mask = rearrange(mask[0,0], 'c h w -> h w c') 108 | mask = th.cat((mask, mask, mask * 0.6 + 0.4), dim=2).cpu().numpy() 109 | 110 | self.add_image(self.outputs[self.indices[0]], output[:,:,::-1]) 111 | self.add_image(self.outputs[self.indices[1]], object[:,:,::-1]) 112 | self.add_image(self.outputs[self.indices[2]], mask) 113 | 114 | def __enter__(self): 115 | self.connections = ( 116 | self.fig.canvas.mpl_connect('button_press_event', self.onclick), 117 | self.fig.canvas.mpl_connect('scroll_event', self.onscroll) 118 | ) 119 | return self 120 | 121 | def __exit__(self, *args, **kwargs): 122 | for connection in self.connections: 123 | self.fig.canvas.mpl_disconnect(connection) 124 | 125 | def create_gestalt_image(self): 126 | 127 | gestalt = self.gestalt[0].cpu().numpy() 128 | size = self.gestalt_gridcell_size 129 | margin = self.gestalt_gridcell_margin 130 | 131 | width = self.gestalt_grid_width * (margin + size) + margin 132 | height = self.gestalt_grid_height * (margin + size) + margin 133 | img = np.zeros((height, width, 3)) + 0.3 134 | 135 | for i in range(gestalt.shape[0]): 136 | h = i // self.gestalt_grid_width 137 | w = i % self.gestalt_grid_width 138 | 139 | img[h*size+(h+1)*margin:(h+1)*(margin+size),w*size+(w+1)*margin:(w+1)*(margin+size),0] = (1 - gestalt[i]) * 0.8 + gestalt[i] * 0.2 140 | img[h*size+(h+1)*margin:(h+1)*(margin+size),w*size+(w+1)*margin:(w+1)*(margin+size),1] = gestalt[i] * 0.8 + (1 - gestalt[i]) * 0.2 141 | img[h*size+(h+1)*margin:(h+1)*(margin+size),w*size+(w+1)*margin:(w+1)*(margin+size),2] = 0.2 142 | 143 | 144 | return img 145 | 146 | def create_position_image(self): 147 | 148 | img = self.gaus2d(self.position) 149 | img = rearrange(img[0], 'c h w -> h w c') 150 | 151 | return th.cat((img, img, img * 0.6 + 0.4), dim=2).cpu().numpy() 152 | 153 | 154 | def add_image(self, ax, img): 155 | ax.clear() 156 | ax.imshow(img) 157 | ax.axis('off') 158 | 159 | 160 | 161 | def onclick(self, event): 162 | x, y = event.xdata, event.ydata 163 | 164 | if self.ax_gestalt == event.inaxes: 165 | 166 | size = self.gestalt_gridcell_size 167 | margin = self.gestalt_gridcell_margin 168 | 169 | w = int(x / (margin + size)) 170 | h = int(y / (margin + size)) 171 | 172 | i = h * self.gestalt_grid_width + w 173 | self.gestalt[0,i] = 1 - self.gestalt[0,i] 174 | 175 | self.add_image(self.ax_gestalt, self.create_gestalt_image()) 176 | self.update_outputs() 177 | self.fig.canvas.draw() 178 | 179 | if self.ax_position == event.inaxes: 180 | 181 | x = (x / self.size[1]) * 2 - 1 182 | y = (y / self.size[0]) * 2 - 1 183 | 184 | self.position[0,0] = y 185 | self.position[0,1] = x 186 | 187 | self.add_image(self.ax_position, self.create_position_image()) 188 | self.update_outputs() 189 | self.fig.canvas.draw() 190 | 191 | if self.ax_output2 == event.inaxes: 192 | ax_tmp = self.indices[0] 193 | self.indices[0] = self.indices[1] 194 | self.indices[1] = ax_tmp 195 | self.update_outputs() 196 | self.fig.canvas.draw() 197 | 198 | if self.ax_output3 == event.inaxes: 199 | ax_tmp = self.indices[0] 200 | self.indices[0] = self.indices[2] 201 | self.indices[2] = ax_tmp 202 | 
self.update_outputs() 203 | self.fig.canvas.draw() 204 | 205 | def onscroll(self, event): 206 | if self.ax_position == event.inaxes: 207 | std = max(self.position[0,2], 0.0) 208 | if event.button == 'down': 209 | self.position[0,2] = max(std - std * (1 - std) * 0.1, 0.0) 210 | 211 | elif event.button == 'up': 212 | self.position[0,2] = std + max(std * (1 - std) * 0.1, 0.001) 213 | 214 | self.add_image(self.ax_position, self.create_position_image()) 215 | self.update_outputs() 216 | self.fig.canvas.draw() 217 | 218 | if __name__=="__main__": 219 | 220 | parser = argparse.ArgumentParser() 221 | parser.add_argument("-cfg", required=True, type=str) 222 | parser.add_argument("-load", required=True, type=str) 223 | parser.add_argument("-background", required=True, type=str) 224 | parser.add_argument("-latent", default="", type=str) 225 | parser.add_argument("-device", default=0, type=int) 226 | 227 | args = parser.parse_args(sys.argv[1:]) 228 | cfg = Configuration(args.cfg) 229 | 230 | gestalt = None 231 | position = None 232 | if args.latent != "": 233 | with open(args.latent, 'rb') as infile: 234 | state = pickle.load(infile) 235 | gestalt = state["gestalt"] 236 | position = state["position"] 237 | 238 | with LociPlayground(cfg.model, args.device, args.load, args.background, gestalt, position): 239 | plt.show() 240 | -------------------------------------------------------------------------------- /nn/background.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as th 3 | from torch.autograd import Function 4 | import nn as nn_modules 5 | import utils 6 | from nn.residual import ResidualBlock, SkipConnection 7 | from nn.encoder import AggressiveDownConv 8 | from nn.encoder import AggressiveConvTo1x1 9 | from nn.decoder import AggressiveUpConv 10 | from utils.utils import LambdaModule, ForcedAlpha, PrintShape, PushToInf 11 | from nn.predictor import EpropAlphaGateL0rd 12 | from nn.vae import VariationalFunction 13 | from einops import rearrange, repeat, reduce 14 | 15 | from typing import Union, Tuple 16 | 17 | __author__ = "Manuel Traub" 18 | 19 | class BackgroundEnhancer(nn.Module): 20 | def __init__( 21 | self, 22 | input_size: Tuple[int, int], 23 | img_channels: int, 24 | level1_channels, 25 | latent_channels, 26 | gestalt_size, 27 | batch_size, 28 | depth 29 | ): 30 | super(BackgroundEnhancer, self).__init__() 31 | 32 | latent_size = [input_size[0] // 16, input_size[1] // 16] 33 | self.input_size = input_size 34 | 35 | self.register_buffer('init', th.zeros(1).long()) 36 | self.alpha = nn.Parameter(th.zeros(1)+1e-16) 37 | 38 | self.level = 1 39 | self.down_level2 = nn.Sequential( 40 | AggressiveDownConv(img_channels*2+2, level1_channels), 41 | *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)] 42 | ) 43 | 44 | self.down_level1 = nn.Sequential( 45 | AggressiveDownConv(level1_channels, latent_channels), 46 | *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)] 47 | ) 48 | 49 | self.down_level0 = nn.Sequential( 50 | *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)], 51 | AggressiveConvTo1x1(latent_channels, latent_size), 52 | LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')), 53 | nn.Sigmoid(), 54 | LambdaModule(lambda x: x + x * (1 - x) * th.randn_like(x)), 55 | ) 56 | 57 | self.bias = nn.Parameter(th.zeros((1, gestalt_size, *latent_size))) 58 | 59 | self.to_grid = nn.Sequential( 60 | LinearResidual(gestalt_size, gestalt_size, 
input_relu = False), 61 | LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')), 62 | LambdaModule(lambda x: x + self.bias), 63 | *[ResidualBlock(gestalt_size, gestalt_size) for i in range(depth)], 64 | ) 65 | 66 | 67 | self.up_level0 = nn.Sequential( 68 | ResidualBlock(gestalt_size, latent_channels), 69 | *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)], 70 | ) 71 | 72 | self.up_level1 = nn.Sequential( 73 | *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)], 74 | AggressiveUpConv(latent_channels, level1_channels), 75 | ) 76 | 77 | self.up_level2 = nn.Sequential( 78 | *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)], 79 | AggressiveUpConv(level1_channels, img_channels), 80 | ) 81 | 82 | self.to_channels = nn.ModuleList([ 83 | SkipConnection(img_channels*2+2, latent_channels), 84 | SkipConnection(img_channels*2+2, level1_channels), 85 | SkipConnection(img_channels*2+2, img_channels*2+2), 86 | ]) 87 | 88 | self.to_img = nn.ModuleList([ 89 | SkipConnection(latent_channels, img_channels), 90 | SkipConnection(level1_channels, img_channels), 91 | SkipConnection(img_channels, img_channels), 92 | ]) 93 | 94 | self.mask = nn.Parameter(th.ones(1, 1, *input_size) * 10) 95 | self.object = nn.Parameter(th.ones(1, img_channels, *input_size)) 96 | 97 | self.register_buffer('latent', th.zeros((batch_size, gestalt_size)), persistent=False) 98 | 99 | def get_init(self): 100 | return self.init.item() 101 | 102 | def step_init(self): 103 | self.init = self.init + 1 104 | 105 | def detach(self): 106 | self.latent = self.latent.detach() 107 | 108 | def reset_state(self): 109 | self.latent = th.zeros_like(self.latent) 110 | 111 | def set_level(self, level): 112 | self.level = level 113 | 114 | def encoder(self, input): 115 | latent = self.to_channels[self.level](input) 116 | 117 | if self.level >= 2: 118 | latent = self.down_level2(latent) 119 | 120 | if self.level >= 1: 121 | latent = self.down_level1(latent) 122 | 123 | return self.down_level0(latent) 124 | 125 | def get_last_latent_gird(self): 126 | return self.to_grid(self.latent) * self.alpha 127 | 128 | def decoder(self, latent, input): 129 | grid = self.to_grid(latent) 130 | latent = self.up_level0(grid) 131 | 132 | if self.level >= 1: 133 | latent = self.up_level1(latent) 134 | 135 | if self.level >= 2: 136 | latent = self.up_level2(latent) 137 | 138 | object = reduce(self.object, '1 c (h h2) (w w2) -> 1 c h w', 'mean', h = input.shape[2], w = input.shape[3]) 139 | object = repeat(object, '1 c h w -> b c h w', b = input.shape[0]) 140 | 141 | return th.sigmoid(object + self.to_img[self.level](latent)), grid 142 | 143 | def forward(self, input: th.Tensor, error: th.Tensor = None, mask: th.Tensor = None): 144 | 145 | last_bg = self.decoder(self.latent, input)[0] 146 | 147 | bg_error = th.sqrt(reduce((input - last_bg)**2, 'b c h w -> b 1 h w', 'mean')).detach() 148 | bg_mask = (bg_error < th.mean(bg_error) + th.std(bg_error)).float().detach() 149 | 150 | if error is None or self.get_init() < 2: 151 | error = bg_error 152 | 153 | if mask is None or self.get_init() < 2: 154 | mask = bg_mask 155 | 156 | self.latent = self.encoder(th.cat((input, last_bg, error, mask), dim=1)) 157 | 158 | mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3]) 159 | mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1 160 | 161 | background, grid = self.decoder(self.latent, input) 162 | 163 | if self.get_init() < 1: 164 | 
return mask, background, background 165 | 166 | if self.get_init() < 2: 167 | return mask, th.zeros_like(background), background 168 | 169 | return mask, background, background 170 | 171 | class PrecalculatedBackground(nn.Module): 172 | def __init__( 173 | self, 174 | input_size: Tuple[int, int], 175 | img_channels: int, 176 | ): 177 | super(PrecalculatedBackground, self).__init__() 178 | 179 | self.mask = nn.Parameter(th.ones(1, 1, input_size[0], input_size[1])) 180 | self.init = 0 181 | self.alpha = nn.Parameter(th.zeros(1)) 182 | self.to_inf = PushToInf() 183 | self.local_bias = nn.Parameter(th.zeros(1, img_channels, input_size[0], input_size[1])) 184 | self.register_buffer("background", th.zeros(1, img_channels, input_size[0], input_size[1]), persistent = False) 185 | 186 | def set_background(self, background): 187 | self.background = background 188 | 189 | def get_init(self): 190 | return th.tanh(self.to_inf(self.alpha * 0.5)).item() 191 | 192 | def set_level(self, level): 193 | self.level = level 194 | 195 | def forward(self, input: th.Tensor, error: th.Tensor = None, mask: th.Tensor = None): 196 | 197 | H, W = input.shape[2:] 198 | 199 | background = self.background 200 | if W < background.shape[-1]: 201 | background = reduce(background, 'b c (h h2) (w w2) -> b c h w', 'mean', h = H, w = W) 202 | 203 | mask = self.mask 204 | if W < mask.shape[-1]: 205 | mask = reduce(mask, 'b c (h h2) (w w2) -> b c h w', 'mean', h = H, w = W) 206 | 207 | mask = mask.expand(input.shape[0], 1, *input.shape[2:]) 208 | return mask, background * th.tanh(self.to_inf(self.alpha * 0.5)), background 209 | -------------------------------------------------------------------------------- /nn/decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as th 3 | import numpy as np 4 | from nn.residual import ResidualBlock, SkipConnection 5 | from utils.utils import Gaus2D, SharedObjectsToBatch, BatchToSharedObjects, Prioritize, LambdaModule 6 | from torch.autograd import Function 7 | from einops import rearrange, repeat, reduce 8 | 9 | from typing import Tuple, Union, List 10 | import utils 11 | 12 | __author__ = "Manuel Traub" 13 | 14 | class GestaltPositionMerge(nn.Module): 15 | def __init__( 16 | self, 17 | latent_size: Union[int, Tuple[int, int]], 18 | num_objects: int 19 | ): 20 | 21 | super(GestaltPositionMerge, self).__init__() 22 | self.num_objects = num_objects 23 | 24 | self.gaus2d = Gaus2D(size=latent_size) 25 | 26 | self.to_batch = SharedObjectsToBatch(num_objects) 27 | self.to_shared = BatchToSharedObjects(num_objects) 28 | 29 | self.prioritize = Prioritize(num_objects) 30 | 31 | def forward(self, position, gestalt, priority): 32 | 33 | position = rearrange(position, 'b (o c) -> (b o) c', o = self.num_objects) 34 | gestalt = rearrange(gestalt, 'b (o c) -> (b o) c 1 1', o = self.num_objects) 35 | 36 | position = self.gaus2d(position) 37 | position = self.to_batch(self.prioritize(self.to_shared(position), priority)) 38 | 39 | return position * gestalt 40 | 41 | class AggressiveUpConv(nn.Module): 42 | def __init__(self, in_channels, img_channels, alpha = 1e-16): 43 | super(AggressiveUpConv, self).__init__() 44 | 45 | self.layers = nn.Sequential( 46 | nn.ReLU(), 47 | nn.Conv2d( 48 | in_channels = in_channels, 49 | out_channels = in_channels, 50 | kernel_size = 3, 51 | stride = 1, 52 | padding = 1, 53 | ), 54 | nn.ReLU(), 55 | nn.ConvTranspose2d( 56 | in_channels = in_channels, 57 | out_channels = img_channels, 58 | kernel_size = 
12, 59 | stride = 4, 60 | padding = 4 61 | ) 62 | ) 63 | self.alpha = nn.Parameter(th.zeros(1) + alpha) 64 | self.size_factor = 4 65 | self.channels_factor = in_channels // img_channels 66 | 67 | 68 | def forward(self, input: th.Tensor): 69 | s = self.size_factor 70 | c = self.channels_factor 71 | skip = reduce(input, 'b (c n) h w -> b c h w', 'mean', n=c) 72 | skip = repeat(skip, 'b c h w -> b c (h h2) (w w2)', h2=s, w2=s) 73 | return skip + self.alpha * self.layers(input) 74 | 75 | class LociDecoder(nn.Module): 76 | def __init__( 77 | self, 78 | latent_size: Union[int, Tuple[int, int]], 79 | gestalt_size: int, 80 | num_objects: int, 81 | img_channels: int, 82 | hidden_channels: int, 83 | level1_channels: int, 84 | num_layers: int, 85 | ): 86 | 87 | super(LociDecoder, self).__init__() 88 | self.to_batch = SharedObjectsToBatch(num_objects) 89 | self.to_shared = BatchToSharedObjects(num_objects) 90 | self.level = 1 91 | 92 | assert(level1_channels % img_channels == 0) 93 | level1_factor = level1_channels // img_channels 94 | print(f"Level1 channels: {level1_channels}") 95 | 96 | self.merge = GestaltPositionMerge( 97 | latent_size = latent_size, 98 | num_objects = num_objects, 99 | ) 100 | 101 | self.layer0 = nn.Sequential( 102 | ResidualBlock(gestalt_size, hidden_channels, input_nonlinearity = False), 103 | *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers-1)], 104 | ) 105 | 106 | self.to_mask_level0 = ResidualBlock(hidden_channels, hidden_channels) 107 | self.to_mask_level1 = AggressiveUpConv(hidden_channels, level1_factor) 108 | 109 | self.to_mask_level2 = nn.Sequential( 110 | ResidualBlock(hidden_channels, hidden_channels), 111 | ResidualBlock(hidden_channels, hidden_channels), 112 | AggressiveUpConv(hidden_channels, 4, alpha = 1), 113 | AggressiveUpConv(4, 1, alpha = 1), 114 | ) 115 | 116 | self.to_object_level0 = ResidualBlock(hidden_channels, hidden_channels) 117 | self.to_object_level1 = AggressiveUpConv(hidden_channels, level1_channels) 118 | 119 | self.to_object_level2 = nn.Sequential( 120 | ResidualBlock(hidden_channels, hidden_channels), 121 | ResidualBlock(hidden_channels, hidden_channels), 122 | AggressiveUpConv(hidden_channels, 12, alpha = 1), 123 | AggressiveUpConv(12, img_channels, alpha = 1), 124 | ) 125 | 126 | self.mask_to_pixel = nn.ModuleList([ 127 | SkipConnection(hidden_channels, 1), 128 | SkipConnection(level1_factor, 1), 129 | SkipConnection(1, 1), 130 | ]) 131 | self.object_to_pixel = nn.ModuleList([ 132 | SkipConnection(hidden_channels, img_channels), 133 | SkipConnection(level1_channels, img_channels), 134 | SkipConnection(img_channels, img_channels), 135 | ]) 136 | 137 | self.mask_alpha = nn.Parameter(th.zeros(1)+1e-16) 138 | self.object_alpha = nn.Parameter(th.zeros(1)+1e-16) 139 | 140 | def set_level(self, level): 141 | self.level = level 142 | 143 | def forward(self, position, gestalt, priority = None): 144 | 145 | maps = self.layer0(self.merge(position, gestalt, priority)) 146 | 147 | mask0 = self.to_mask_level0(maps) 148 | object0 = self.to_object_level0(maps) 149 | 150 | if self.level > 0: 151 | mask = self.to_mask_level1(mask0) 152 | object = self.to_object_level1(object0) 153 | 154 | if self.level > 1: 155 | mask = repeat(mask, 'b c h w -> b c (h h2) (w w2)', h2 = 4, w2 = 4) + self.to_mask_level2(mask0) * self.mask_alpha 156 | object = repeat(object, 'b c h w -> b c (h h2) (w w2)', h2 = 4, w2 = 4) + self.to_object_level2(object0) * self.object_alpha 157 | 158 | mask = self.mask_to_pixel[self.level](mask) 159 | object = 
self.object_to_pixel[self.level](object) 160 | 161 | return self.to_shared(mask), self.to_shared(object) 162 | -------------------------------------------------------------------------------- /nn/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as th 3 | import numpy as np 4 | from utils.utils import Gaus2D, SharedObjectsToBatch, BatchToSharedObjects, Prioritize, LambdaModule, ForcedAlpha 5 | from nn.eprop_lstm import EpropLSTM 6 | from nn.residual import ResidualBlock, SkipConnection 7 | from nn.eprop_gate_l0rd import EpropGateL0rd 8 | from torch.autograd import Function 9 | from einops import rearrange, repeat, reduce 10 | 11 | from typing import Tuple, Union, List 12 | import utils 13 | import cv2 14 | 15 | __author__ = "Manuel Traub" 16 | 17 | class NeighbourChannels(nn.Module): 18 | def __init__(self, channels): 19 | super(NeighbourChannels, self).__init__() 20 | 21 | self.register_buffer("weights", th.ones(channels, channels, 1, 1), persistent=False) 22 | 23 | for i in range(channels): 24 | self.weights[i,i,0,0] = 0 25 | 26 | def forward(self, input: th.Tensor): 27 | return nn.functional.conv2d(input, self.weights) 28 | 29 | class InputPreprocessing(nn.Module): 30 | def __init__(self, num_objects: int, size: Union[int, Tuple[int, int]]): 31 | super(InputPreprocessing, self).__init__() 32 | self.num_objects = num_objects 33 | self.neighbours = NeighbourChannels(num_objects) 34 | self.prioritize = Prioritize(num_objects) 35 | self.gaus2d = Gaus2D(size) 36 | self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects)) 37 | self.to_shared = BatchToSharedObjects(num_objects) 38 | 39 | def forward( 40 | self, 41 | input: th.Tensor, 42 | error: th.Tensor, 43 | mask: th.Tensor, 44 | object: th.Tensor, 45 | position: th.Tensor, 46 | priority: th.Tensor 47 | ): 48 | batch_size = input.shape[0] 49 | size = input.shape[2:] 50 | 51 | mask = mask.detach() 52 | position = position.detach() 53 | priority = priority.detach() 54 | 55 | bg_mask = repeat(mask[:,-1:], 'b 1 h w -> b c h w', c = self.num_objects) 56 | mask = mask[:,:-1] 57 | mask_others = self.neighbours(mask) 58 | 59 | own_gaus2d = self.to_shared(self.gaus2d(self.to_batch(position))) 60 | others_gaus2d = self.neighbours(self.prioritize(own_gaus2d, priority)) 61 | 62 | input = repeat(input, 'b c h w -> b o c h w', o = self.num_objects) 63 | error = repeat(error, 'b 1 h w -> b o 1 h w', o = self.num_objects) 64 | bg_mask = rearrange(bg_mask, 'b o h w -> b o 1 h w') 65 | mask_others = rearrange(mask_others, 'b o h w -> b o 1 h w') 66 | mask = rearrange(mask, 'b o h w -> b o 1 h w') 67 | object = rearrange(object, 'b (o c) h w -> b o c h w', o = self.num_objects) 68 | own_gaus2d = rearrange(own_gaus2d, 'b o h w -> b o 1 h w') 69 | others_gaus2d = rearrange(others_gaus2d, 'b o h w -> b o 1 h w') 70 | 71 | output = th.cat((input, error, mask, mask_others, bg_mask, object, own_gaus2d, others_gaus2d), dim=2) 72 | output = rearrange(output, 'b o c h w -> (b o) c h w') 73 | 74 | return output 75 | 76 | class AggressiveDownConv(nn.Module): 77 | def __init__(self, in_channels, out_channels): 78 | super(AggressiveDownConv, self).__init__() 79 | assert out_channels % in_channels == 0 80 | 81 | self.layers = nn.Sequential( 82 | nn.Conv2d( 83 | in_channels = in_channels, 84 | out_channels = out_channels, 85 | kernel_size = 11, 86 | stride = 4, 87 | padding = 5 88 | ), 89 | nn.ReLU(), 90 | nn.Conv2d( 91 | in_channels = out_channels, 92 | 
out_channels = out_channels, 93 | kernel_size = 3, 94 | padding = 1 95 | ) 96 | ) 97 | self.alpha = nn.Parameter(th.zeros(1) + 1e-12) 98 | self.size_factor = 4 99 | self.channels_factor = out_channels // in_channels 100 | 101 | 102 | def forward(self, input: th.Tensor): 103 | s = self.size_factor 104 | c = self.channels_factor 105 | skip = reduce(input, 'b c (h h2) (w w2) -> b c h w', 'mean', h2=s, w2=s) 106 | skip = repeat(skip, 'b c h w -> b (c n) h w', n=c) 107 | return skip + self.alpha * self.layers(input) 108 | 109 | class AggressiveConvTo1x1(nn.Module): 110 | def __init__(self, in_channels, out_channels, size: Union[int, Tuple[int, int]]): 111 | super(AggressiveConvTo1x1, self).__init__() 112 | 113 | self.layers = nn.Sequential( 114 | nn.Conv2d( 115 | in_channels = in_channels, 116 | out_channels = in_channels, 117 | kernel_size = 5, 118 | stride = 3, 119 | padding = 3 120 | ), 121 | nn.ReLU(), 122 | nn.Conv2d( 123 | in_channels = in_channels, 124 | out_channels = out_channels, 125 | kernel_size = ((size[0] + 1)//3 + 1, (size[1] + 1)//3 + 1) 126 | ) 127 | ) 128 | self.alpha = nn.Parameter(th.zeros(1) + 1e-12) 129 | self.size = size 130 | self.factor = out_channels // in_channels 131 | 132 | 133 | def forward(self, input: th.Tensor): 134 | skip = reduce(input, 'b c h w -> b c 1 1', 'mean') 135 | skip = repeat(skip, 'b c 1 1 -> b (c n) 1 1', n = self.factor) 136 | return skip + self.alpha * self.layers(input) 137 | 138 | class PixelToPosition(nn.Module): 139 | def __init__(self, size: Union[int, Tuple[int, int]]): 140 | super(PixelToPosition, self).__init__() 141 | 142 | self.register_buffer("grid_x", th.arange(size[0]), persistent=False) 143 | self.register_buffer("grid_y", th.arange(size[1]), persistent=False) 144 | 145 | self.grid_x = (self.grid_x / (size[0]-1)) * 2 - 1 146 | self.grid_y = (self.grid_y / (size[1]-1)) * 2 - 1 147 | 148 | self.grid_x = self.grid_x.view(1, 1, -1, 1).expand(1, 1, *size).clone() 149 | self.grid_y = self.grid_y.view(1, 1, 1, -1).expand(1, 1, *size).clone() 150 | 151 | self.size = size 152 | 153 | def forward(self, input: th.Tensor): 154 | assert input.shape[1] == 1 155 | 156 | input = rearrange(input, 'b c h w -> b c (h w)') 157 | input = th.softmax(input, dim=2) 158 | input = rearrange(input, 'b c (h w) -> b c h w', h = self.size[0], w = self.size[1]) 159 | 160 | x = th.sum(input * self.grid_x, dim=(2,3)) 161 | y = th.sum(input * self.grid_y, dim=(2,3)) 162 | 163 | return th.cat((x,y),dim=1) 164 | 165 | class PixelToSTD(nn.Module): 166 | def __init__(self): 167 | super(PixelToSTD, self).__init__() 168 | self.alpha = ForcedAlpha() 169 | 170 | def forward(self, input: th.Tensor): 171 | assert input.shape[1] == 1 172 | return self.alpha(reduce(th.sigmoid(input - 10), 'b c h w -> b c', 'mean')) 173 | 174 | class PixelToPriority(nn.Module): 175 | def __init__(self): 176 | super(PixelToPriority, self).__init__() 177 | 178 | def forward(self, input: th.Tensor): 179 | assert input.shape[1] == 1 180 | return reduce(th.tanh(input), 'b c h w -> b c', 'mean') 181 | 182 | class LociEncoder(nn.Module): 183 | def __init__( 184 | self, 185 | input_size: Union[int, Tuple[int, int]], 186 | latent_size: Union[int, Tuple[int, int]], 187 | num_objects: int, 188 | img_channels: int, 189 | hidden_channels: int, 190 | level1_channels: int, 191 | num_layers: int, 192 | gestalt_size: int, 193 | batch_size: int, 194 | ): 195 | super(LociEncoder, self).__init__() 196 | 197 | self.num_objects = num_objects 198 | self.latent_size = latent_size 199 | self.level = 1 200 | 201 | 
self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = self.num_objects)) 202 | 203 | print(f"Level1 channels: {level1_channels}") 204 | 205 | self.preprocess = nn.ModuleList([ 206 | InputPreprocessing(num_objects, (input_size[0] // 16, input_size[1] // 16)), 207 | InputPreprocessing(num_objects, (input_size[0] // 4, input_size[1] // 4)), 208 | InputPreprocessing(num_objects, (input_size[0], input_size[1])) 209 | ]) 210 | 211 | self.to_channels = nn.ModuleList([ 212 | SkipConnection(img_channels, hidden_channels), 213 | SkipConnection(img_channels, level1_channels), 214 | SkipConnection(img_channels, img_channels) 215 | ]) 216 | 217 | self.layers2 = nn.Sequential( 218 | AggressiveDownConv(img_channels, level1_channels), 219 | *[ResidualBlock(level1_channels, level1_channels, alpha_residual=True) for _ in range(num_layers)] 220 | ) 221 | 222 | self.layers1 = nn.Sequential(AggressiveDownConv(level1_channels, hidden_channels)) 223 | 224 | self.layers0 = nn.Sequential( 225 | *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)] 226 | ) 227 | 228 | self.position_encoder = nn.Sequential( 229 | *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)], 230 | ResidualBlock(hidden_channels, 3), 231 | ) 232 | 233 | self.xy_encoder = PixelToPosition(latent_size) 234 | self.std_encoder = PixelToSTD() 235 | self.priority_encoder = PixelToPriority() 236 | 237 | gestalt_channels = max(hidden_channels, gestalt_size) 238 | self.gestalt_encoder = nn.Sequential( 239 | AggressiveConvTo1x1( 240 | in_channels = hidden_channels, 241 | out_channels = gestalt_channels, 242 | size = latent_size 243 | ), 244 | *[ResidualBlock(gestalt_channels, gestalt_channels, kernel_size = 1) for _ in range(num_layers)], 245 | ResidualBlock(gestalt_channels, gestalt_size, kernel_size = 1), 246 | LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')), 247 | ) 248 | 249 | def set_level(self, level): 250 | self.level = level 251 | 252 | def forward( 253 | self, 254 | input: th.Tensor, 255 | error: th.Tensor, 256 | mask: th.Tensor, 257 | object: th.Tensor, 258 | position: th.Tensor, 259 | priority: th.Tensor 260 | ): 261 | 262 | latent = self.preprocess[self.level](input, error, mask, object, position, priority) 263 | latent = self.to_channels[self.level](latent) 264 | 265 | if self.level >= 2: 266 | latent = self.layers2(latent) 267 | 268 | if self.level >= 1: 269 | latent = self.layers1(latent) 270 | 271 | latent = self.layers0(latent) 272 | gestalt = self.gestalt_encoder(latent) 273 | 274 | latent = self.position_encoder(latent) 275 | std = self.std_encoder(latent[:,0:1]) 276 | xy = self.xy_encoder(latent[:,1:2]) 277 | priority = self.priority_encoder(latent[:,2:3]) 278 | 279 | position = self.to_shared(th.cat((xy, std), dim=1)) 280 | gestalt = self.to_shared(gestalt) 281 | priority = self.to_shared(priority) 282 | 283 | return position, gestalt, priority 284 | 285 | -------------------------------------------------------------------------------- /nn/eprop_gate_l0rd.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as th 3 | import numpy as np 4 | from torch.autograd import Function 5 | from einops import rearrange, repeat, reduce 6 | 7 | from typing import Tuple, Union, List 8 | import cv2 9 | 10 | __author__ = "Manuel Traub" 11 | 12 | class EpropGateL0rdFunction(Function): 13 | @staticmethod 14 | def forward(ctx, x, h_last, w_gx, w_gh, b_g, w_rx, w_rh, b_r, args): 15 | 16 | e_w_gx, 
e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r, reg, noise_level = args 17 | 18 | noise = th.normal(mean=0, std=noise_level, size=b_g.shape, device=b_g.device) 19 | g = th.relu(th.tanh(x.mm(w_gx.t()) + h_last.mm(w_gh.t()) + b_g + noise)) 20 | r = th.tanh(x.mm(w_rx.t()) + h_last.mm(w_rh.t()) + b_r) 21 | 22 | h = g * r + (1 - g) * h_last 23 | 24 | # Haevisite step function 25 | H_g = th.ceil(g).clamp(0, 1) 26 | 27 | dg = (1 - g**2) * H_g 28 | dr = (1 - r**2) 29 | 30 | delta_h = r - h_last 31 | 32 | g_j = g.unsqueeze(dim=2) 33 | dg_j = dg.unsqueeze(dim=2) 34 | dr_j = dr.unsqueeze(dim=2) 35 | 36 | x_i = x.unsqueeze(dim=1) 37 | h_last_i = h_last.unsqueeze(dim=1) 38 | delta_h_j = delta_h.unsqueeze(dim=2) 39 | 40 | e_w_gh.copy_(e_w_gh * (1 - g_j) + dg_j * h_last_i * delta_h_j) 41 | e_w_gx.copy_(e_w_gx * (1 - g_j) + dg_j * x_i * delta_h_j) 42 | e_b_g.copy_( e_b_g * (1 - g) + dg * delta_h ) 43 | 44 | e_w_rh.copy_(e_w_rh * (1 - g_j) + dr_j * h_last_i * g_j) 45 | e_w_rx.copy_(e_w_rx * (1 - g_j) + dr_j * x_i * g_j) 46 | e_b_r.copy_( e_b_r * (1 - g) + dr * g ) 47 | 48 | ctx.save_for_backward( 49 | g.clone(), dg.clone(), dg_j.clone(), dr.clone(), x_i.clone(), h_last_i.clone(), 50 | reg.clone(), H_g.clone(), delta_h.clone(), w_gx.clone(), w_gh.clone(), w_rx.clone(), w_rh.clone(), 51 | e_w_gx.clone(), e_w_gh.clone(), e_b_g.clone(), 52 | e_w_rx.clone(), e_w_rh.clone(), e_b_r.clone(), 53 | ) 54 | 55 | return h, th.mean(H_g) 56 | 57 | @staticmethod 58 | def backward(ctx, dh, _): 59 | 60 | g, dg, dg_j, dr, x_i, h_last_i, reg, H_g, delta_h, w_gx, w_gh, w_rx, w_rh, \ 61 | e_w_gx, e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r = ctx.saved_tensors 62 | 63 | dh_j = dh.unsqueeze(dim=2) 64 | H_g_reg = reg * H_g 65 | H_g_reg_j = H_g_reg.unsqueeze(dim=2) 66 | 67 | dw_gx = th.sum(dh_j * e_w_gx + H_g_reg_j * dg_j * x_i, dim=0) 68 | dw_gh = th.sum(dh_j * e_w_gh + H_g_reg_j * dg_j * h_last_i, dim=0) 69 | db_g = th.sum(dh * e_b_g + H_g_reg * dg, dim=0) 70 | 71 | dw_rx = th.sum(dh_j * e_w_rx, dim=0) 72 | dw_rh = th.sum(dh_j * e_w_rh, dim=0) 73 | db_r = th.sum(dh * e_b_r , dim=0) 74 | 75 | dh_dg = (dh * delta_h + H_g_reg) * dg 76 | dh_dr = dh * g * dr 77 | 78 | dx = dh_dg.mm(w_gx) + dh_dr.mm(w_rx) 79 | dh = dh * (1 - g) + dh_dg.mm(w_gh) + dh_dr.mm(w_rh) 80 | 81 | return dx, dh, dw_gx, dw_gh, db_g, dw_rx, dw_rh, db_r, None 82 | 83 | class ReTanhFunction(Function): 84 | @staticmethod 85 | def forward(ctx, x, reg): 86 | 87 | g = th.relu(th.tanh(x)) 88 | 89 | # Haevisite step function 90 | H_g = th.ceil(g).clamp(0, 1) 91 | 92 | dg = (1 - g**2) * H_g 93 | 94 | ctx.save_for_backward(g, dg, H_g, reg) 95 | return g, th.mean(H_g) 96 | 97 | @staticmethod 98 | def backward(ctx, dh, _): 99 | 100 | g, dg, H_g, reg = ctx.saved_tensors 101 | 102 | dx = (dh + reg * H_g) * dg 103 | 104 | return dx, None 105 | 106 | class ReTanh(nn.Module): 107 | def __init__(self, reg_lambda): 108 | super(ReTanh, self).__init__() 109 | 110 | self.re_tanh = ReTanhFunction().apply 111 | self.register_buffer("reg_lambda", th.tensor(reg_lambda), persistent=False) 112 | 113 | def forward(self, input): 114 | h, openings = self.re_tanh(input, self.reg_lambda) 115 | self.openings = openings.item() 116 | 117 | return h 118 | 119 | 120 | class EpropGateL0rd(nn.Module): 121 | def __init__( 122 | self, 123 | num_inputs, 124 | num_hidden, 125 | num_outputs, 126 | batch_size, 127 | reg_lambda = 0, 128 | gate_noise_level = 0, 129 | ): 130 | super(EpropGateL0rd, self).__init__() 131 | 132 | self.register_buffer("reg", th.tensor(reg_lambda).view(1,1), persistent=False) 133 | 
self.register_buffer("noise", th.tensor(gate_noise_level), persistent=False) 134 | self.num_inputs = num_inputs 135 | self.num_hidden = num_hidden 136 | self.num_outputs = num_outputs 137 | 138 | self.fcn = EpropGateL0rdFunction().apply 139 | self.retanh = ReTanh(reg_lambda) 140 | 141 | # gate weights and biases 142 | self.w_gx = nn.Parameter(th.empty(num_hidden, num_inputs)) 143 | self.w_gh = nn.Parameter(th.empty(num_hidden, num_hidden)) 144 | self.b_g = nn.Parameter(th.zeros(num_hidden)) 145 | 146 | # candidate weights and biases 147 | self.w_rx = nn.Parameter(th.empty(num_hidden, num_inputs)) 148 | self.w_rh = nn.Parameter(th.empty(num_hidden, num_hidden)) 149 | self.b_r = nn.Parameter(th.zeros(num_hidden)) 150 | 151 | # output projection weights and bias 152 | self.w_px = nn.Parameter(th.empty(num_outputs, num_inputs)) 153 | self.w_ph = nn.Parameter(th.empty(num_outputs, num_hidden)) 154 | self.b_p = nn.Parameter(th.zeros(num_outputs)) 155 | 156 | # output gate weights and bias 157 | self.w_ox = nn.Parameter(th.empty(num_outputs, num_inputs)) 158 | self.w_oh = nn.Parameter(th.empty(num_outputs, num_hidden)) 159 | self.b_o = nn.Parameter(th.zeros(num_outputs)) 160 | 161 | # input gate eligibilitiy traces 162 | self.register_buffer("e_w_gx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False) 163 | self.register_buffer("e_w_gh", th.zeros(batch_size, num_hidden, num_hidden), persistent=False) 164 | self.register_buffer("e_b_g", th.zeros(batch_size, num_hidden), persistent=False) 165 | 166 | # forget gate eligibilitiy traces 167 | self.register_buffer("e_w_rx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False) 168 | self.register_buffer("e_w_rh", th.zeros(batch_size, num_hidden, num_hidden), persistent=False) 169 | self.register_buffer("e_b_r", th.zeros(batch_size, num_hidden), persistent=False) 170 | 171 | # hidden state 172 | self.register_buffer("h_last", th.zeros(batch_size, num_hidden), persistent=False) 173 | 174 | self.register_buffer("openings", th.zeros(1), persistent=False) 175 | 176 | # initialize weights 177 | stdv_ih = np.sqrt(6/(self.num_inputs + self.num_hidden)) 178 | stdv_hh = np.sqrt(3/self.num_hidden) 179 | stdv_io = np.sqrt(6/(self.num_inputs + self.num_outputs)) 180 | stdv_ho = np.sqrt(6/(self.num_hidden + self.num_outputs)) 181 | 182 | nn.init.uniform_(self.w_gx, -stdv_ih, stdv_ih) 183 | nn.init.uniform_(self.w_gh, -stdv_hh, stdv_hh) 184 | 185 | nn.init.uniform_(self.w_rx, -stdv_ih, stdv_ih) 186 | nn.init.uniform_(self.w_rh, -stdv_hh, stdv_hh) 187 | 188 | nn.init.uniform_(self.w_px, -stdv_io, stdv_io) 189 | nn.init.uniform_(self.w_ph, -stdv_ho, stdv_ho) 190 | 191 | nn.init.uniform_(self.w_ox, -stdv_io, stdv_io) 192 | nn.init.uniform_(self.w_oh, -stdv_ho, stdv_ho) 193 | 194 | self.backprop = False 195 | 196 | def reset_state(self): 197 | self.h_last.zero_() 198 | self.e_w_gx.zero_() 199 | self.e_w_gh.zero_() 200 | self.e_b_g.zero_() 201 | self.e_w_rx.zero_() 202 | self.e_w_rh.zero_() 203 | self.e_b_r.zero_() 204 | self.openings.zero_() 205 | 206 | def backprop_forward(self, x: th.Tensor): 207 | 208 | noise = th.normal(mean=0, std=self.noise, size=self.b_g.shape, device=self.b_g.device) 209 | g = self.retanh(x.mm(self.w_gx.t()) + self.h_last.mm(self.w_gh.t()) + self.b_g + noise) 210 | r = th.tanh(x.mm(self.w_rx.t()) + self.h_last.mm(self.w_rh.t()) + self.b_r) 211 | 212 | self.h_last = g * r + (1 - g) * self.h_last 213 | 214 | # Haevisite step function 215 | H_g = th.ceil(g).clamp(0, 1) 216 | 217 | self.openings = th.mean(H_g) 218 | 219 | p = 
th.tanh(x.mm(self.w_px.t()) + self.h_last.mm(self.w_ph.t()) + self.b_p) 220 | o = th.sigmoid(x.mm(self.w_ox.t()) + self.h_last.mm(self.w_oh.t()) + self.b_o) 221 | return o * p 222 | 223 | def activate_backprop(self): 224 | self.backprop = True 225 | 226 | def deactivate_backprop(self): 227 | self.backprop = False 228 | 229 | def detach(self): 230 | self.h_last.detach_() 231 | 232 | def eprop_forward(self, x: th.Tensor): 233 | h, openings = self.fcn( 234 | x, self.h_last, 235 | self.w_gx, self.w_gh, self.b_g, 236 | self.w_rx, self.w_rh, self.b_r, 237 | ( 238 | self.e_w_gx, self.e_w_gh, self.e_b_g, 239 | self.e_w_rx, self.e_w_rh, self.e_b_r, 240 | self.reg, self.noise 241 | ) 242 | ) 243 | 244 | self.openings = openings 245 | self.h_last = h 246 | 247 | p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p) 248 | o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o) 249 | return o * p 250 | 251 | def save_hidden(self): 252 | self.h_last_saved = self.h_last.detach() 253 | 254 | def restore_hidden(self): 255 | self.h_last = self.h_last_saved 256 | 257 | def get_hidden(self): 258 | return self.h_last 259 | 260 | def set_hidden(self, h_last): 261 | self.h_last = h_last 262 | 263 | def forward(self, x: th.Tensor): 264 | if self.backprop: 265 | return self.backprop_forward(x) 266 | 267 | return self.eprop_forward(x) 268 | 269 | """ 270 | if __name__ == "__main__": 271 | 272 | device = "cpu" 273 | batch_size = 256 274 | time = 10 275 | 276 | l0rd = EpropGateL0rd(3, 8, 1, 256, reg_lambda=1e-5) 277 | opti = th.optim.Adam(l0rd.parameters(), lr=0.003) 278 | 279 | avg_openings = 0 280 | avg_openings_sum = 1e-100 281 | avg_loss = 0 282 | avg_sum = 0 283 | for i in range(10000): 284 | l0rd.reset_state() 285 | 286 | loss = 0 287 | 288 | values = th.rand(time, batch_size, 1).to(device) 289 | store = th.zeros(time, batch_size, 1).to(device) 290 | recall = th.zeros(time, batch_size, 1).to(device) 291 | 292 | s = th.floor(th.rand(batch_size) * (time//3)).long() 293 | r = th.floor(th.rand(batch_size) * (time//3) + (time//3)*2).long() 294 | 295 | store[s,th.arange(batch_size),0] = 1 296 | recall[r,th.arange(batch_size),0] = 1 297 | 298 | x = th.cat((values, store, recall), dim=2) 299 | 300 | for t in range(time): 301 | h = l0rd(x[t]) 302 | 303 | if t >= (time//3)*2: 304 | avg_openings = avg_openings * 0.99 + l0rd.openings.item() 305 | avg_openings_sum = avg_openings_sum * 0.99 + 1 306 | 307 | opti.zero_grad() 308 | v = values[t] 309 | loss = th.mean((h - v)**2 * recall[t]) 310 | 311 | loss.backward() 312 | 313 | avg_loss = avg_loss * 0.99 + loss.item() 314 | avg_sum = avg_sum * 0.99 + 1 315 | 316 | opti.step() 317 | print(f"{i}: {avg_loss / avg_sum:.3e} => {avg_openings / avg_openings_sum:.3e}") 318 | """ 319 | 320 | if __name__ == "__main__": 321 | 322 | device = "cpu" 323 | batch_size = 1 324 | time = 100 325 | intervall = 10 326 | 327 | l0rd = EpropGateL0rd(1, 1, 1, 1, reg_lambda=1e-5) 328 | opti = th.optim.Adam(l0rd.parameters(), lr=0.003) 329 | 330 | avg_openings = 0 331 | avg_openings_sum = 1e-100 332 | avg_loss = 0 333 | avg_sum = 0 334 | loss = 0 335 | for t in range(intervall*1): 336 | if t % intervall == 0: 337 | l0rd.reset_state() 338 | opti.zero_grad() 339 | loss = 0 340 | 341 | v = th.ones(batch_size, 1).to(device) 342 | h = l0rd(v) 343 | 344 | avg_openings = avg_openings * 0.99 + l0rd.openings.item() 345 | avg_openings_sum = avg_openings_sum * 0.99 + 1 346 | 347 | #loss = loss + th.mean((h - v)**2) 348 | loss = th.mean((h - v)**2) 349 | 350 | if t % intervall == intervall 
- 1: 351 | loss.backward() 352 | 353 | avg_loss = avg_loss * 0.99 + loss.item() 354 | avg_sum = avg_sum * 0.99 + 1 355 | 356 | opti.step() 357 | print(f"{t}: {avg_loss / avg_sum:.3e} => {avg_openings / avg_openings_sum:.3e}") 358 | 359 | -------------------------------------------------------------------------------- /nn/eprop_lstm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as th 3 | import numpy as np 4 | from torch.autograd import Function 5 | from einops import rearrange, repeat, reduce 6 | 7 | from typing import Tuple, Union, List 8 | import cv2 9 | 10 | __author__ = "Manuel Traub" 11 | 12 | class EpropLSTMFunction(Function): 13 | @staticmethod 14 | def forward(ctx, x, w_ix, w_ih, b_i, w_fx, w_fh, b_f, w_ox, w_oh, b_o, w_cx, w_ch, b_c, args): 15 | 16 | h_last, c_last, e_w_ix, e_w_ih, e_b_i, e_w_fx, e_w_fh, e_b_f, e_w_cx, e_w_ch, e_b_c = args 17 | 18 | i = th.sigmoid(x.mm(w_ix.t()) + h_last.mm(w_ih.t()) + b_i) 19 | f = th.sigmoid(x.mm(w_fx.t()) + h_last.mm(w_fh.t()) + b_f) 20 | o = th.sigmoid(x.mm(w_ox.t()) + h_last.mm(w_oh.t()) + b_o) 21 | c_hat = th.tanh( x.mm(w_cx.t()) + h_last.mm(w_ch.t()) + b_c) 22 | 23 | c = f * c_last + i * c_hat 24 | h = o * c 25 | 26 | di = i * (1 - i) 27 | df = f * (1 - f) 28 | do = o * (1 - o) 29 | dc_hat = (1 - c_hat**2) 30 | 31 | i_j = i.unsqueeze(dim=2) 32 | f_j = f.unsqueeze(dim=2) 33 | o_j = o.unsqueeze(dim=2) 34 | 35 | di_j = di.unsqueeze(dim=2) 36 | df_j = df.unsqueeze(dim=2) 37 | do_j = do.unsqueeze(dim=2) 38 | dc_hat_j = dc_hat.unsqueeze(dim=2) 39 | 40 | x_i = x.unsqueeze(dim=1) 41 | h_last_i = h_last.unsqueeze(dim=1) 42 | c_last_j = c_last.unsqueeze(dim=2) 43 | c_hat_j = c_hat.unsqueeze(dim=2) 44 | c_j = c.unsqueeze(dim=2) 45 | 46 | e_w_ih.copy_(e_w_ih * f_j + di_j * h_last_i * c_hat_j) 47 | e_w_ix.copy_(e_w_ix * f_j + di_j * x_i * c_hat_j) 48 | e_b_i.copy_( e_b_i * f + di * c_hat ) 49 | 50 | e_w_fh.copy_(e_w_fh * f_j + df_j * h_last_i * c_last_j) 51 | e_w_fx.copy_(e_w_fx * f_j + df_j * x_i * c_last_j) 52 | e_b_f.copy_( e_b_f * f + df * c_last ) 53 | 54 | e_w_ch.copy_(e_w_ch * f_j + dc_hat_j * h_last_i * i_j) 55 | e_w_cx.copy_(e_w_cx * f_j + dc_hat_j * x_i * i_j) 56 | e_b_c.copy_( e_b_c * f + dc_hat * i ) 57 | 58 | ctx.e_w_ix = e_w_ix 59 | ctx.e_w_ih = e_w_ih 60 | ctx.e_b_i = e_b_i 61 | ctx.e_w_fx = e_w_fx 62 | ctx.e_w_fh = e_w_fh 63 | ctx.e_b_f = e_b_f 64 | ctx.e_w_cx = e_w_cx 65 | ctx.e_w_ch = e_w_ch 66 | ctx.e_b_c = e_b_c 67 | 68 | ctx.i = i 69 | ctx.o = o 70 | ctx.c = c 71 | ctx.o_j = o_j 72 | ctx.c_j = c_j 73 | ctx.do = do 74 | ctx.df = df 75 | ctx.di = di 76 | ctx.dc_hat = dc_hat 77 | ctx.di_j = di_j 78 | ctx.do_j = do_j 79 | ctx.h_last_i = h_last_i 80 | ctx.c_last = c_last 81 | ctx.c_hat = c_hat 82 | ctx.x_i = x_i 83 | ctx.w_ox = w_ox 84 | ctx.w_fx = w_fx 85 | ctx.w_ix = w_ix 86 | ctx.w_cx = w_cx 87 | 88 | return h, c 89 | 90 | @staticmethod 91 | def backward(ctx, dh, _): 92 | 93 | e_w_ix = ctx.e_w_ix 94 | e_w_ih = ctx.e_w_ih 95 | e_b_i = ctx.e_b_i 96 | e_w_fx = ctx.e_w_fx 97 | e_w_fh = ctx.e_w_fh 98 | e_b_f = ctx.e_b_f 99 | e_w_cx = ctx.e_w_cx 100 | e_w_ch = ctx.e_w_ch 101 | e_b_c = ctx.e_b_c 102 | 103 | i = ctx.i 104 | o = ctx.o 105 | c = ctx.c 106 | o_j = ctx.o_j 107 | c_j = ctx.c_j 108 | do = ctx.do 109 | df = ctx.df 110 | di = ctx.di 111 | dc_hat = ctx.dc_hat 112 | di_j = ctx.di_j 113 | do_j = ctx.do_j 114 | h_last_i = ctx.h_last_i 115 | c_last = ctx.c_last 116 | c_hat = ctx.c_hat 117 | x_i = ctx.x_i 118 | w_ox = ctx.w_ox 119 | w_fx = ctx.w_fx 120 | w_ix 
= ctx.w_ix 121 | w_cx = ctx.w_cx 122 | 123 | dh_j = dh.unsqueeze(dim=2) 124 | 125 | dw_ix = th.sum(dh_j * o_j * e_w_ix, dim=0) 126 | dw_ih = th.sum(dh_j * o_j * e_w_ih, dim=0) 127 | db_i = th.sum(dh * o * e_b_i , dim=0) 128 | 129 | dw_fx = th.sum(dh_j * o_j * e_w_fx, dim=0) 130 | dw_fh = th.sum(dh_j * o_j * e_w_fh, dim=0) 131 | db_f = th.sum(dh * o * e_b_f , dim=0) 132 | 133 | dw_cx = th.sum(dh_j * o_j * e_w_cx, dim=0) 134 | dw_ch = th.sum(dh_j * o_j * e_w_ch, dim=0) 135 | db_c = th.sum(dh * o * e_b_c , dim=0) 136 | 137 | dw_oh = th.sum(dh_j * do_j * h_last_i * c_j, dim=0) 138 | dw_ox = th.sum(dh_j * do_j * x_i * c_j, dim=0) 139 | db_o = th.sum(dh * do * c , dim=0) 140 | 141 | dh_do = dh * do * c 142 | dh_df = dh * df * c_last * o 143 | dh_di = dh * di * c_hat * o 144 | dh_dc_hat = dh * dc_hat * i * o 145 | 146 | dx = dh_do.mm(w_ox) + dh_df.mm(w_fx) + dh_di.mm(w_ix) + dh_dc_hat.mm(w_cx) 147 | 148 | return dx, dw_ix, dw_ih, db_i, dw_fx, dw_fh, db_f, dw_ox, dw_oh, db_o, dw_cx, dw_ch, db_c, None 149 | 150 | class EpropLSTM(nn.Module): 151 | def __init__( 152 | self, 153 | num_inputs, 154 | num_hidden, 155 | batch_size 156 | ): 157 | super(EpropLSTM, self).__init__() 158 | 159 | self.num_inputs = num_inputs 160 | self.num_hidden = num_hidden 161 | 162 | self.fcn = EpropLSTMFunction().apply 163 | 164 | # input gate weights and biases 165 | self.w_ix = nn.Parameter(th.empty(num_hidden, num_inputs)) 166 | self.w_ih = nn.Parameter(th.empty(num_hidden, num_hidden)) 167 | self.b_i = nn.Parameter(th.ones(num_hidden)) 168 | 169 | # forget gate weights and biases 170 | self.w_fx = nn.Parameter(th.empty(num_hidden, num_inputs)) 171 | self.w_fh = nn.Parameter(th.empty(num_hidden, num_hidden)) 172 | self.b_f = nn.Parameter(th.ones(num_hidden)) 173 | 174 | # output gate weights and biases 175 | self.w_ox = nn.Parameter(th.empty(num_hidden, num_inputs)) 176 | self.w_oh = nn.Parameter(th.empty(num_hidden, num_hidden)) 177 | self.b_o = nn.Parameter(th.ones(num_hidden)) 178 | 179 | # cell weights and biases 180 | self.w_cx = nn.Parameter(th.empty(num_hidden, num_inputs)) 181 | self.w_ch = nn.Parameter(th.empty(num_hidden, num_hidden)) 182 | self.b_c = nn.Parameter(th.zeros(num_hidden)) 183 | 184 | # input gate eligibilitiy traces 185 | self.register_buffer("e_w_ix", th.zeros(batch_size, num_hidden, num_inputs), persistent=False) 186 | self.register_buffer("e_w_ih", th.zeros(batch_size, num_hidden, num_hidden), persistent=False) 187 | self.register_buffer("e_b_i", th.zeros(batch_size, num_hidden), persistent=False) 188 | 189 | # forget gate eligibilitiy traces 190 | self.register_buffer("e_w_fx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False) 191 | self.register_buffer("e_w_fh", th.zeros(batch_size, num_hidden, num_hidden), persistent=False) 192 | self.register_buffer("e_b_f", th.zeros(batch_size, num_hidden), persistent=False) 193 | 194 | # cell eligibilitiy traces 195 | self.register_buffer("e_w_cx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False) 196 | self.register_buffer("e_w_ch", th.zeros(batch_size, num_hidden, num_hidden), persistent=False) 197 | self.register_buffer("e_b_c", th.zeros(batch_size, num_hidden), persistent=False) 198 | 199 | # cell sate and hidden state 200 | self.register_buffer("h_last", th.zeros(batch_size, num_hidden), persistent=False) 201 | self.register_buffer("c_last", th.zeros(batch_size, num_hidden), persistent=False) 202 | 203 | # initialize weights 204 | stdv_i = np.sqrt(6/(self.num_inputs + self.num_hidden)) 205 | stdv_h = 
np.sqrt(3/self.num_hidden) 206 | 207 | nn.init.uniform_(self.w_ix, -stdv_i, stdv_i) 208 | nn.init.uniform_(self.w_ih, -stdv_h, stdv_h) 209 | 210 | nn.init.uniform_(self.w_fx, -stdv_i, stdv_i) 211 | nn.init.uniform_(self.w_fh, -stdv_h, stdv_h) 212 | 213 | nn.init.uniform_(self.w_ox, -stdv_i, stdv_i) 214 | nn.init.uniform_(self.w_oh, -stdv_h, stdv_h) 215 | 216 | nn.init.uniform_(self.w_cx, -stdv_i, stdv_i) 217 | nn.init.uniform_(self.w_ch, -stdv_h, stdv_h) 218 | 219 | def reset_state(self): 220 | self.h_last.zero_() 221 | self.c_last.zero_() 222 | self.e_w_ix.zero_() 223 | self.e_w_ih.zero_() 224 | self.e_b_i.zero_() 225 | self.e_w_fx.zero_() 226 | self.e_w_fh.zero_() 227 | self.e_b_f.zero_() 228 | self.e_w_cx.zero_() 229 | self.e_w_ch.zero_() 230 | self.e_b_c.zero_() 231 | 232 | def forward(self, x: th.Tensor): 233 | self.h_last.detach_() 234 | self.c_last.detach_() 235 | 236 | h, c = self.fcn( 237 | x, 238 | self.w_ix, self.w_ih, self.b_i, 239 | self.w_fx, self.w_fh, self.b_f, 240 | self.w_ox, self.w_oh, self.b_o, 241 | self.w_cx, self.w_ch, self.b_c, 242 | ( 243 | self.h_last, self.c_last, 244 | self.e_w_ix, self.e_w_ih, self.e_b_i, 245 | self.e_w_fx, self.e_w_fh, self.e_b_f, 246 | self.e_w_cx, self.e_w_ch, self.e_b_c 247 | ) 248 | ) 249 | 250 | self.h_last = h 251 | self.c_last = c 252 | 253 | return h 254 | -------------------------------------------------------------------------------- /nn/predictor.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as th 3 | import numpy as np 4 | from utils.utils import Gaus2D, SharedObjectsToBatch, BatchToSharedObjects, LambdaModule 5 | from nn.eprop_lstm import EpropLSTM 6 | from nn.eprop_gate_l0rd import EpropGateL0rd 7 | from nn.residual import ResidualBlock 8 | from nn.tracker import CaterSnitchTracker 9 | from nn.vae import VariationalFunction 10 | from torch.autograd import Function 11 | from einops import rearrange, repeat, reduce 12 | 13 | from typing import Tuple, Union, List 14 | import utils 15 | import cv2 16 | 17 | __author__ = "Manuel Traub" 18 | 19 | class AlphaAttention(nn.Module): 20 | def __init__( 21 | self, 22 | num_hidden, 23 | num_objects, 24 | heads, 25 | dropout = 0.0 26 | ): 27 | super(AlphaAttention, self).__init__() 28 | 29 | self.to_sequence = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o = num_objects)) 30 | self.to_batch = LambdaModule(lambda x: rearrange(x, 'b o c -> (b o) c', o = num_objects)) 31 | 32 | self.alpha = nn.Parameter(th.zeros(1)+1e-12) 33 | self.attention = nn.MultiheadAttention( 34 | num_hidden, 35 | heads, 36 | dropout = dropout, 37 | batch_first = True 38 | ) 39 | 40 | def forward(self, x: th.Tensor): 41 | x = self.to_sequence(x) 42 | x = x + self.alpha * self.attention(x, x, x, need_weights=False)[0] 43 | return self.to_batch(x) 44 | 45 | class EpropAlphaGateL0rd(nn.Module): 46 | def __init__(self, num_hidden, batch_size, reg_lambda): 47 | super(EpropAlphaGateL0rd, self).__init__() 48 | 49 | self.alpha = nn.Parameter(th.zeros(1)+1e-12) 50 | self.l0rd = EpropGateL0rd( 51 | num_inputs = num_hidden, 52 | num_hidden = num_hidden, 53 | num_outputs = num_hidden, 54 | reg_lambda = reg_lambda, 55 | batch_size = batch_size 56 | ) 57 | 58 | def forward(self, input): 59 | return input + self.alpha * self.l0rd(input) 60 | 61 | class InputEmbeding(nn.Module): 62 | def __init__(self, num_inputs, num_hidden): 63 | super(InputEmbeding, self).__init__() 64 | 65 | self.embeding = nn.Sequential( 66 | nn.ReLU(), 67 | 
nn.Linear(num_inputs, num_hidden), 68 | nn.ReLU(), 69 | nn.Linear(num_hidden, num_hidden), 70 | ) 71 | self.skip = LambdaModule( 72 | lambda x: repeat(x, 'b c -> b (n c)', n = num_hidden // num_inputs) 73 | ) 74 | self.alpha = nn.Parameter(th.zeros(1)+1e-12) 75 | 76 | def forward(self, input: th.Tensor): 77 | return self.skip(input) + self.alpha * self.embeding(input) 78 | 79 | class OutputEmbeding(nn.Module): 80 | def __init__(self, num_hidden, num_outputs): 81 | super(OutputEmbeding, self).__init__() 82 | 83 | self.embeding = nn.Sequential( 84 | nn.ReLU(), 85 | nn.Linear(num_hidden, num_outputs), 86 | nn.ReLU(), 87 | nn.Linear(num_outputs, num_outputs), 88 | ) 89 | self.skip = LambdaModule( 90 | lambda x: reduce(x, 'b (n c) -> b c', 'mean', n = num_hidden // num_outputs) 91 | ) 92 | self.alpha = nn.Parameter(th.zeros(1)+1e-12) 93 | 94 | def forward(self, input: th.Tensor): 95 | return self.skip(input) + self.alpha * self.embeding(input) 96 | 97 | class EpropGateL0rdTransformer(nn.Module): 98 | def __init__( 99 | self, 100 | channels, 101 | num_objects, 102 | batch_size, 103 | heads, 104 | deepth, 105 | reg_lambda, 106 | dropout=0.0 107 | ): 108 | super(EpropGateL0rdTransformer, self).__init__() 109 | 110 | num_inputs = channels 111 | num_outputs = channels 112 | num_hidden = channels * heads 113 | 114 | self.deepth = deepth 115 | _layers = [] 116 | _layers.append(InputEmbeding(num_inputs, num_hidden)) 117 | 118 | for i in range(deepth): 119 | _layers.append(AlphaAttention(num_hidden, num_objects, heads, dropout)) 120 | _layers.append(EpropAlphaGateL0rd(num_hidden, batch_size * num_objects, reg_lambda)) 121 | 122 | _layers.append(OutputEmbeding(num_hidden, num_outputs)) 123 | self.layers = nn.Sequential(*_layers) 124 | 125 | def get_openings(self): 126 | openings = 0 127 | for i in range(self.deepth): 128 | openings += self.layers[2 * (i + 1)].l0rd.openings.item() 129 | 130 | return openings / self.deepth 131 | 132 | def get_hidden(self): 133 | states = [] 134 | for i in range(self.deepth): 135 | states.append(self.layers[2 * (i + 1)].l0rd.get_hidden()) 136 | 137 | return th.cat(states, dim=1) 138 | 139 | def set_hidden(self, hidden): 140 | states = th.chunk(hidden, self.deepth, dim=1) 141 | for i in range(self.deepth): 142 | self.layers[2 * (i + 1)].l0rd.set_hidden(states[i]) 143 | 144 | def forward(self, input: th.Tensor) -> th.Tensor: 145 | return self.layers(input) 146 | 147 | class PriorityEncoder(nn.Module): 148 | def __init__(self, num_objects, batch_size): 149 | super(PriorityEncoder, self).__init__() 150 | 151 | self.num_objects = num_objects 152 | self.register_buffer("indices", repeat(th.arange(num_objects), 'a -> (b a) 1', b=batch_size), persistent=False) 153 | 154 | self.index_factor = nn.Parameter(th.ones(1)) 155 | self.priority_factor = nn.Parameter(th.ones(1)) 156 | 157 | def forward(self, priority: th.Tensor) -> th.Tensor: 158 | 159 | priority = priority * self.num_objects + th.randn_like(priority) * 0.1 160 | priority = priority * self.priority_factor 161 | priority = priority + self.indices * self.index_factor 162 | priority = rearrange(priority, '(b o) 1 -> b o', o=self.num_objects) 163 | 164 | return priority * 25 165 | 166 | class LociPredictor(nn.Module): 167 | def __init__( 168 | self, 169 | heads: int, 170 | layers: int, 171 | reg_lambda: float, 172 | num_objects: int, 173 | gestalt_size: int, 174 | batch_size: int, 175 | camera_view_matrix = None, 176 | zero_elevation = None 177 | ): 178 | super(LociPredictor, self).__init__() 179 | self.num_objects = num_objects 
= num_objects
180 | self.std_alpha = nn.Parameter(th.zeros(1)+1e-16) 181 | 182 | self.reg_lambda = reg_lambda 183 | self.predictor = EpropGateL0rdTransformer( 184 | channels = gestalt_size + 4, 185 | heads = heads, 186 | deepth = layers, 187 | num_objects = num_objects, 188 | reg_lambda = reg_lambda, 189 | batch_size = batch_size, 190 | ) 191 | 192 | self.tracker = None 193 | if camera_view_matrix is not None: 194 | self.tracker = CaterSnitchTracker( 195 | latent_size = gestalt_size + 2, 196 | num_objects = num_objects, 197 | camera_view_matrix = camera_view_matrix, 198 | zero_elevation = zero_elevation 199 | ) 200 | 201 | self.bottleneck = nn.Sequential( 202 | LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')), 203 | ResidualBlock(gestalt_size, gestalt_size, kernel_size=1), 204 | nn.Sigmoid(), 205 | LambdaModule(lambda x: x + x * (1 - x) * th.randn_like(x)), 206 | LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects)) 207 | ) 208 | 209 | self.priority_encoder = PriorityEncoder(num_objects, batch_size) 210 | 211 | self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects)) 212 | self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects)) 213 | 214 | def get_openings(self): 215 | return self.predictor.get_openings() 216 | 217 | def get_hidden(self): 218 | return self.predictor.get_hidden() 219 | 220 | def set_hidden(self, hidden): 221 | self.predictor.set_hidden(hidden) 222 | 223 | def forward( 224 | self, 225 | position: th.Tensor, 226 | gestalt: th.Tensor, 227 | priority: th.Tensor, 228 | ): 229 | 230 | position = self.to_batch(position) 231 | gestalt = self.to_batch(gestalt) 232 | priority = self.to_batch(priority) 233 | 234 | input = th.cat((position, gestalt, priority), dim=1) 235 | output = self.predictor(input) 236 | 237 | xy = output[:,:2] 238 | std = output[:,2:3] 239 | gestalt = output[:,3:-1] 240 | priority = output[:,-1:] 241 | 242 | snitch_position = None 243 | if self.tracker is not None: 244 | snitch_position = self.tracker(xy, output[:,2:]) 245 | 246 | position = th.cat((xy, std * self.std_alpha), dim=1) 247 | 248 | position = self.to_shared(position) 249 | gestalt = self.bottleneck(gestalt) 250 | priority = self.priority_encoder(priority) 251 | 252 | return position, gestalt, priority, snitch_position 253 | -------------------------------------------------------------------------------- /nn/residual.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as th 3 | import numpy as np 4 | import nn as nn_modules 5 | from einops import rearrange, repeat, reduce 6 | from utils.utils import LambdaModule 7 | 8 | from typing import Union, Tuple 9 | 10 | __author__ = "Manuel Traub" 11 | 12 | class DynamicLayerNorm(nn.Module): 13 | 14 | def __init__(self, eps: float = 1e-5): 15 | super(DynamicLayerNorm, self).__init__() 16 | self.eps = eps 17 | 18 | def forward(self, input: th.Tensor) -> th.Tensor: 19 | return nn.functional.layer_norm(input, input.shape[2:], None, None, self.eps) 20 | 21 | 22 | class SkipConnection(nn.Module): 23 | def __init__( 24 | self, 25 | in_channels: int, 26 | out_channels: int, 27 | scale_factor: float = 1.0 28 | ): 29 | super(SkipConnection, self).__init__() 30 | assert scale_factor == 1 or int(scale_factor) > 1 or int(1 / scale_factor) > 1, f'invalid scale factor in SpikeFunction: {scale_factor}' 31 | 32 | self.in_channels = in_channels 33 | self.out_channels = out_channels 34 | self.scale_factor = scale_factor 
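# Parameter-free skip path (descriptive note): channel_skip() below aligns channel counts by
# mean-reducing or repeating channel groups (falling back to the gcd of both counts when neither
# divides the other), while scale_skip() aligns spatial resolution by nearest-neighbour repetition
# (scale_factor > 1) or average pooling (scale_factor < 1).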
35 | 36 | def channel_skip(self, input: th.Tensor): 37 | in_channels = self.in_channels 38 | out_channels = self.out_channels 39 | 40 | if in_channels == out_channels: 41 | return input 42 | 43 | if in_channels % out_channels == 0 or out_channels % in_channels == 0: 44 | 45 | if in_channels > out_channels: 46 | return reduce(input, 'b (c n) h w -> b c h w', 'mean', n = in_channels // out_channels) 47 | 48 | if out_channels > in_channels: 49 | return repeat(input, 'b c h w -> b (c n) h w', n = out_channels // in_channels) 50 | 51 | mean_channels = np.gcd(in_channels, out_channels) 52 | input = reduce(input, 'b (c n) h w -> b c h w', 'mean', n = in_channels // mean_channels) 53 | return repeat(input, 'b c h w -> b (c n) h w', n = out_channels // mean_channels) 54 | 55 | def scale_skip(self, input: th.Tensor): 56 | scale_factor = self.scale_factor 57 | 58 | if scale_factor == 1: 59 | return input 60 | 61 | if scale_factor > 1: 62 | return repeat( 63 | input, 64 | 'b c h w -> b c (h h2) (w w2)', 65 | h2 = int(scale_factor), 66 | w2 = int(scale_factor) 67 | ) 68 | 69 | height = input.shape[2] 70 | width = input.shape[3] 71 | 72 | # scale factor < 1 73 | scale_factor = int(1 / scale_factor) 74 | 75 | if width % scale_factor == 0 and height % scale_factor == 0: 76 | return reduce( 77 | input, 78 | 'b c (h h2) (w w2) -> b c h w', 79 | 'mean', 80 | h2 = scale_factor, 81 | w2 = scale_factor 82 | ) 83 | 84 | if width >= scale_factor and height >= scale_factor: 85 | return nn.functional.avg_pool2d( 86 | input, 87 | kernel_size = scale_factor, 88 | stride = scale_factor 89 | ) 90 | 91 | assert width > 1 or height > 1 92 | return reduce(input, 'b c h w -> b c 1 1', 'mean') 93 | 94 | 95 | def forward(self, input: th.Tensor): 96 | 97 | if self.scale_factor > 1: 98 | return self.scale_skip(self.channel_skip(input)) 99 | 100 | return self.channel_skip(self.scale_skip(input)) 101 | 102 | class DownScale(nn.Module): 103 | def __init__( 104 | self, 105 | in_channels: int, 106 | out_channels: int, 107 | scale_factor: int, 108 | groups: int = 1, 109 | bias: bool = True 110 | ): 111 | 112 | super(DownScale, self).__init__() 113 | 114 | assert(in_channels % groups == 0) 115 | assert(out_channels % groups == 0) 116 | 117 | self.groups = groups 118 | self.scale_factor = scale_factor 119 | self.weight = nn.Parameter(th.empty((out_channels, in_channels // groups, scale_factor, scale_factor))) 120 | self.bias = nn.Parameter(th.empty((out_channels,))) if bias else None 121 | 122 | nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5)) 123 | 124 | if self.bias is not None: 125 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 126 | bound = 1 / np.sqrt(fan_in) 127 | nn.init.uniform_(self.bias, -bound, bound) 128 | 129 | def forward(self, input: th.Tensor): 130 | height = input.shape[2] 131 | width = input.shape[3] 132 | assert height > 1 or width > 1, "trying to dowscale 1x1" 133 | 134 | scale_factor = self.scale_factor 135 | padding = [0, 0] 136 | 137 | if height < scale_factor: 138 | padding[0] = scale_factor - height 139 | 140 | if width < scale_factor: 141 | padding[1] = scale_factor - width 142 | 143 | return nn.functional.conv2d( 144 | input, 145 | self.weight, 146 | bias=self.bias, 147 | stride=scale_factor, 148 | padding=padding, 149 | groups=self.groups 150 | ) 151 | 152 | 153 | class ResidualBlock(nn.Module): 154 | def __init__( 155 | self, 156 | in_channels: int, 157 | out_channels: int, 158 | kernel_size: Union[int, Tuple[int, int]] = (3, 3), 159 | scale_factor: int = 1, 160 | groups: Union[int, 
Tuple[int, int]] = (1, 1), 161 | bias: bool = True, 162 | layer_norm: bool = False, 163 | leaky_relu: bool = False, 164 | residual: bool = True, 165 | alpha_residual: bool = False, 166 | input_nonlinearity = True 167 | ): 168 | 169 | super(ResidualBlock, self).__init__() 170 | self.residual = residual 171 | self.alpha_residual = alpha_residual 172 | self.skip = False 173 | self.in_channels = in_channels 174 | self.out_channels = out_channels 175 | 176 | if isinstance(kernel_size, int): 177 | kernel_size = [kernel_size, kernel_size] 178 | 179 | if isinstance(groups, int): 180 | groups = [groups, groups] 181 | 182 | padding = (kernel_size[0] // 2, kernel_size[1] // 2) 183 | 184 | _layers = list() 185 | if layer_norm: 186 | _layers.append(DynamicLayerNorm()) 187 | 188 | if input_nonlinearity: 189 | if leaky_relu: 190 | _layers.append(nn.LeakyReLU()) 191 | else: 192 | _layers.append(nn.ReLU()) 193 | 194 | if scale_factor > 1: 195 | _layers.append( 196 | nn.ConvTranspose2d( 197 | in_channels=in_channels, 198 | out_channels=out_channels, 199 | kernel_size=scale_factor, 200 | stride=scale_factor, 201 | groups=groups[0], 202 | bias=bias 203 | ) 204 | ) 205 | elif scale_factor < 1: 206 | _layers.append( 207 | DownScale( 208 | in_channels=in_channels, 209 | out_channels=out_channels, 210 | scale_factor=int(1.0/scale_factor), 211 | groups=groups[0], 212 | bias=bias 213 | ) 214 | ) 215 | else: 216 | _layers.append( 217 | nn.Conv2d( 218 | in_channels=in_channels, 219 | out_channels=out_channels, 220 | kernel_size=kernel_size, 221 | padding=padding, 222 | groups=groups[0], 223 | bias=bias 224 | ) 225 | ) 226 | 227 | if layer_norm: 228 | _layers.append(DynamicLayerNorm()) 229 | if leaky_relu: 230 | _layers.append(nn.LeakyReLU()) 231 | else: 232 | _layers.append(nn.ReLU()) 233 | _layers.append( 234 | nn.Conv2d( 235 | in_channels=out_channels, 236 | out_channels=out_channels, 237 | kernel_size=kernel_size, 238 | padding=padding, 239 | groups=groups[1], 240 | bias=bias 241 | ) 242 | ) 243 | self.layers = nn.Sequential(*_layers) 244 | 245 | if self.residual: 246 | self.skip_connection = SkipConnection( 247 | in_channels=in_channels, 248 | out_channels=out_channels, 249 | scale_factor=scale_factor 250 | ) 251 | 252 | if self.alpha_residual: 253 | self.alpha = nn.Parameter(th.zeros(1) + 1e-12) 254 | 255 | def set_mode(self, **kwargs): 256 | if 'skip' in kwargs: 257 | self.skip = kwargs['skip'] 258 | 259 | if 'residual' in kwargs: 260 | self.residual = kwargs['residual'] 261 | 262 | def forward(self, input: th.Tensor) -> th.Tensor: 263 | if self.skip: 264 | return self.skip_connection(input) 265 | 266 | if not self.residual: 267 | return self.layers(input) 268 | 269 | if self.alpha_residual: 270 | return self.alpha * self.layers(input) + self.skip_connection(input) 271 | 272 | return self.layers(input) + self.skip_connection(input) 273 | 274 | class LinearSkip(nn.Module): 275 | def __init__(self, num_inputs: int, num_outputs: int): 276 | super(LinearSkip, self).__init__() 277 | 278 | self.num_inputs = num_inputs 279 | self.num_outputs = num_outputs 280 | 281 | if num_inputs % num_outputs != 0 and num_outputs % num_inputs != 0: 282 | mean_channels = np.gcd(num_inputs, num_outputs) 283 | print(f"[WW] gcd skip: {num_inputs} -> {mean_channels} -> {num_outputs}") 284 | assert(False) 285 | 286 | def forward(self, input: th.Tensor): 287 | num_inputs = self.num_inputs 288 | num_outputs = self.num_outputs 289 | 290 | if num_inputs == num_outputs: 291 | return input 292 | 293 | if num_inputs % num_outputs == 0 or 
num_outputs % num_inputs == 0: 294 | 295 | if num_inputs > num_outputs: 296 | return reduce(input, 'b (c n) -> b c', 'mean', n = num_inputs // num_outputs) 297 | 298 | if num_outputs > num_inputs: 299 | return repeat(input, 'b c -> b (c n)', n = num_outputs // num_inputs) 300 | 301 | mean_channels = np.gcd(num_inputs, num_outputs) 302 | input = reduce(input, 'b (c n) -> b c', 'mean', n = num_inputs // mean_channels) 303 | return repeat(input, 'b c -> b (c n)', n = num_outputs // mean_channels) 304 | 305 | class LinearResidual(nn.Module): 306 | def __init__( 307 | self, 308 | num_inputs: int, 309 | num_outputs: int, 310 | num_hidden: int = None, 311 | residual: bool = True, 312 | alpha_residual: bool = False, 313 | input_relu: bool = True 314 | ): 315 | super(LinearResidual, self).__init__() 316 | 317 | self.residual = residual 318 | self.alpha_residual = alpha_residual 319 | 320 | if num_hidden is None: 321 | num_hidden = num_outputs 322 | 323 | _layers = [] 324 | if input_relu: 325 | _layers.append(nn.ReLU()) 326 | _layers.append(nn.Linear(num_inputs, num_hidden)) 327 | _layers.append(nn.ReLU()) 328 | _layers.append(nn.Linear(num_hidden, num_outputs)) 329 | 330 | self.layers = nn.Sequential(*_layers) 331 | 332 | if residual: 333 | self.skip = LinearSkip(num_inputs, num_outputs) 334 | 335 | if alpha_residual: 336 | self.alpha = nn.Parameter(th.zeros(1)+1e-16) 337 | 338 | def forward(self, input: th.Tensor): 339 | if not self.residual: 340 | return self.layers(input) 341 | 342 | if not self.alpha_residual: 343 | return self.skip(input) + self.layers(input) 344 | 345 | return self.skip(input) + self.alpha * self.layers(input) 346 | 347 | class ResidualAttentionBlock(nn.Module): 348 | def __init__(self, channels: int, channel_factor = 4, space_factor = 5): 349 | super(ResidualAttentionBlock, self).__init__() 350 | 351 | self.layers = nn.Sequential( 352 | nn.ReLU(), 353 | nn.Conv2d( 354 | in_channels=channels, 355 | out_channels=channels, 356 | kernel_size=3, 357 | padding=1, 358 | ), 359 | nn.ReLU(), 360 | nn.Conv2d( 361 | in_channels=channels, 362 | out_channels=channels, 363 | kernel_size=3, 364 | padding=1, 365 | ) 366 | ) 367 | 368 | self.gate2d = nn.Sequential( 369 | nn.ReLU(), 370 | nn.Conv2d( 371 | in_channels=channels, 372 | out_channels=1, 373 | kernel_size=space_factor, 374 | padding=0, 375 | stride=space_factor 376 | ), 377 | nn.ReLU(), 378 | nn.Conv2d( 379 | in_channels=1, 380 | out_channels=channels, 381 | kernel_size=3, 382 | padding=1, 383 | ), 384 | nn.ReLU(), 385 | nn.ConvTranspose2d( 386 | in_channels=channels, 387 | out_channels=channels, 388 | kernel_size=space_factor, 389 | padding=0, 390 | stride=space_factor 391 | ), 392 | nn.ReLU(), 393 | nn.Conv2d( 394 | in_channels=channels, 395 | out_channels=1, 396 | kernel_size=3, 397 | padding=1, 398 | ), 399 | nn.Sigmoid() 400 | ) 401 | #self.gate2d[-2].bias.data = th.ones_like(self.gate2d[-2].bias.data) + 4 402 | 403 | self.channel_gate = nn.Sequential( 404 | nn.ReLU(), 405 | LambdaModule(lambda x: reduce(x, 'b c h w -> b c', 'mean')), 406 | nn.Linear(channels, channels // channel_factor), 407 | nn.ReLU(), 408 | nn.Linear(channels // channel_factor, channels), 409 | nn.Sigmoid(), 410 | LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')) 411 | ) 412 | #self.channel_gate[-3].bias.data = th.ones_like(self.channel_gate[-3].bias.data) + 4 413 | 414 | 415 | def forward(self, input: th.Tensor) -> th.Tensor: 416 | #print(f"channel_gate: {th.mean(self.channel_gate[-3].weight.data).item():.9e} +- 
{th.std(self.channel_gate[-3].weight.data).item():.2e}") 417 | #print(f"gate2d: {th.mean(self.gate2d[-2].weight.data).item():.9e} +- {th.std(self.gate2d[-2].weight.data).item():.9e}") 418 | return input + self.layers(input) * self.gate2d(input) * self.channel_gate(input) 419 | -------------------------------------------------------------------------------- /nn/tracker.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as th 3 | import numpy as np 4 | from torch.autograd import Function 5 | from einops import rearrange, repeat, reduce 6 | from typing import Tuple, Union, List 7 | from utils.utils import LambdaModule 8 | from nn.residual import LinearResidual 9 | 10 | __author__ = "Manuel Traub" 11 | 12 | class GridClassifier(nn.Module): 13 | def __init__(self): 14 | super(GridClassifier, self).__init__() 15 | 16 | x = (th.arange(36) % 6) - 3 + 0.5 17 | y = th.floor(th.arange(36) / 6) - 3 + 0.5 18 | self.register_buffer("grid_centers", th.stack((x,y), dim=1).unsqueeze(dim=0), persistent=False) 19 | 20 | def forward(self, position: th.Tensor): 21 | grid_distances = th.exp(-th.sum((self.grid_centers - position[:,:2].unsqueeze(dim=1))**2, dim=2)) 22 | return th.softmax(grid_distances, dim=1) 23 | 24 | class L1GridDistance(nn.Module): 25 | def __init__(self, camera_view_matrix, zero_elevation): 26 | super(L1GridDistance, self).__init__() 27 | 28 | self.to_world = ScreenToWorld(camera_view_matrix, zero_elevation) 29 | self.to_grid = GridClassifier() 30 | 31 | def forward(self, position, target_label): 32 | 33 | probabilities = self.to_grid(self.to_world(position)) 34 | label = th.argmax(probabilities, dim=1).long() 35 | target_label = target_label.long() 36 | 37 | x = label % 6 38 | y = label / 6 39 | 40 | target_x = target_label % 6 41 | target_y = target_label / 6 42 | 43 | l1 = th.mean((th.abs(x - target_x) + th.abs(y - target_y)).float()) 44 | top1 = th.mean((label == target_label).float()) 45 | top5 = th.mean(th.sum((th.topk(probabilities, 5, dim=1)[1] == target_label.unsqueeze(dim=1)).float(), dim=1)) 46 | 47 | return l1, top1, top5, label 48 | 49 | class L2TrackingDistance(nn.Module): 50 | def __init__(self, camera_view_matrix, zero_elevation): 51 | super(L2TrackingDistance, self).__init__() 52 | 53 | self.to_world = ScreenToWorld(camera_view_matrix, zero_elevation) 54 | 55 | def forward(self, position, target_position): 56 | 57 | position = self.to_world(position) 58 | return th.sqrt(th.sum((position - target_position)**2, dim=1)), position 59 | 60 | 61 | class TrackingLoss(nn.Module): 62 | def __init__(self, camera_view_matrix, zero_elevation): 63 | super(TrackingLoss, self).__init__() 64 | 65 | self.to_screen = WolrdToScreen(camera_view_matrix, zero_elevation) 66 | 67 | def forward(self, position, target_position): 68 | 69 | target_position = self.to_screen(target_position) 70 | return th.mean((position - target_position)**2) 71 | 72 | class WolrdToScreen(nn.Module): 73 | def __init__(self, camera_view_matrix, zero_elevation): 74 | super(WolrdToScreen, self).__init__() 75 | 76 | self.register_buffer('cam', th.tensor(camera_view_matrix).float()) 77 | 78 | def forward(self, world_xyz): 79 | world_xyzw = th.cat((world_xyz, th.ones((world_xyz.shape[0], 1), device=world_xyz.device)), dim=1) 80 | screen_yxzw = world_xyzw.mm(self.cam.t()) 81 | 82 | screen_yx = th.stack(( 83 | screen_yxzw[:,1] / -screen_yxzw[:,-1], 84 | screen_yxzw[:,0] / screen_yxzw[:,-1] 85 | ), dim=1) 86 | 87 | return screen_yx 88 | 89 | class 
ScreenToWorld(nn.Module): 90 | def __init__(self, camera_view_matrix, zero_elevation): 91 | super(ScreenToWorld, self).__init__() 92 | 93 | self.register_buffer('cam', th.tensor(camera_view_matrix).float()) 94 | self.z = zero_elevation 95 | 96 | def forward(self, xy, z = None): 97 | a,b,c,d = th.unbind(self.cam[0]) 98 | e,f,g,h = th.unbind(self.cam[1]) 99 | m,n,o,p = th.unbind(self.cam[3]) 100 | 101 | X = xy[:,1] 102 | Y = xy[:,0] 103 | 104 | if z is None: 105 | z = th.zeros_like(X) 106 | else: 107 | z = z[:,0] 108 | 109 | z = z + self.z 110 | 111 | x = (b*h - d*f + Y*b*p - Y*d*n + X*f*p - X*h*n + b*g*z - c*f*z + Y*b*o*z - Y*c*n*z + X*f*o*z - X*g*n*z)/(a*f - b*e + Y*a*n - Y*b*m + X*e*n - X*f*m) 112 | y = -(a*h - d*e + Y*a*p - Y*d*m + X*e*p - X*h*m + a*g*z - c*e*z + Y*a*o*z - Y*c*m*z + X*e*o*z - X*g*m*z)/(a*f - b*e + Y*a*n - Y*b*m + X*e*n - X*f*m) 113 | 114 | return th.stack((x,y,z), dim=1) 115 | 116 | 117 | class CaterSnitchTracker(nn.Module): 118 | def __init__( 119 | self, 120 | latent_size, 121 | num_objects, 122 | camera_view_matrix, 123 | zero_elevation 124 | ): 125 | super(CaterSnitchTracker, self).__init__() 126 | 127 | self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o = num_objects)) 128 | 129 | self.gate = nn.Sequential( 130 | LinearResidual(latent_size, latent_size), 131 | LinearResidual(latent_size, latent_size), 132 | LinearResidual(latent_size, 1), 133 | LambdaModule(lambda x: rearrange(x, '(b o) 1 -> b o 1', o = num_objects)), 134 | nn.Softmax(dim=1), 135 | LambdaModule(lambda x: x + x * (1 - x) * th.randn_like(x)), 136 | ) 137 | 138 | def forward(self, position, latent_state): 139 | return th.sum(self.to_shared(position) * self.gate(latent_state), dim=1) 140 | 141 | 142 | if __name__ == "__main__": 143 | cam = np.array([ 144 | (1.4503, 1.6376, 0.0000, -0.0251), 145 | (-1.0346, 0.9163, 2.5685, 0.0095), 146 | (-0.6606, 0.5850, -0.4748, 10.5666), 147 | (-0.6592, 0.5839, -0.4738, 10.7452) 148 | ]) 149 | 150 | z = 0.3421497941017151 151 | 152 | to_screen = WolrdToScreen(cam, z) 153 | to_world = ScreenToWorld(cam, z) 154 | 155 | xyz = th.rand(1000, 3) * 6 - 3 156 | 157 | _xyz = to_world(to_screen(to_world(to_screen(xyz), xyz[:,2:] - z)), xyz[:,2:] - z) 158 | 159 | print(th.mean((xyz - _xyz)**2).item()) 160 | -------------------------------------------------------------------------------- /nn/vae.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | from torch.autograd import Function 4 | from typing import List, Tuple 5 | 6 | 7 | class KLLossFunction(Function): 8 | @staticmethod 9 | def forward(ctx, mu: th.Tensor, logsigma: th.Tensor, lr=1.0): 10 | ctx.save_for_backward(mu, logsigma, th.tensor(lr, device=mu.device)) 11 | return mu, logsigma 12 | 13 | @staticmethod 14 | def backward(ctx, grad_mu: th.Tensor, grad_sigma: th.Tensor): 15 | mu, logsigma, lr = ctx.saved_tensors 16 | 17 | # KL Divergence for the grad 18 | #kl_div = 0.5 * (th.exp(logsigma) + mu**2 - 1 - logsigma).sum(dim=1).mean() 19 | 20 | # Undoes grad.sum(dim=1).mean() 21 | n_mean = mu.numel() / mu.shape[1] 22 | 23 | kl_mu = mu / n_mean 24 | kl_sigma = 0.5 * (th.exp(logsigma) - 1) / n_mean 25 | 26 | grad_mu = grad_mu + lr * kl_mu 27 | grad_sigma = grad_sigma + lr * kl_sigma 28 | return grad_mu, grad_sigma, None 29 | 30 | class VariationalFunction(nn.Module): 31 | """ 32 | Variational Layer with KLDiv Loss, samples from a given predicted mu and sigma distribution; 33 | Needs an additional Encoder and Decoder for an 
VAE 34 | """ 35 | def __init__(self, mean=0, factor=1, groups=1): 36 | """ 37 | Init 38 | :param factor: kl gradient scalling factor 39 | """ 40 | super(VariationalFunction, self).__init__() 41 | self.factor = factor 42 | self.mean = mean 43 | self.groups = groups 44 | 45 | # Kullback Leibler Divergence 46 | self.kl = KLLossFunction.apply 47 | 48 | def forward(self, input: th.Tensor): 49 | # Encodes latent state 50 | input = input.view(input.shape[0], self.groups, -1, *input.shape[2:]) 51 | mu, logsigma = th.chunk(input, chunks=2, dim=2) 52 | 53 | # Adds gradients from Kullback Leibler Divergence loss 54 | mu, logsigma = self.kl(th.clip(mu, -100, 100), th.clip(logsigma, -1000, 10), self.factor) 55 | 56 | # Sampled from the latent state 57 | noise = th.normal(mean=self.mean, std=1, size=logsigma.shape, device=logsigma.device) 58 | z = mu + (th.exp(0.5 * logsigma) * noise if self.training else 0) # TODO verifiy training / testing noise!!!!!! TODO 59 | 60 | return z.view(z.shape[0], -1, *z.shape[3:]) 61 | -------------------------------------------------------------------------------- /utils/configuration.py: -------------------------------------------------------------------------------- 1 | import json 2 | import jsmin 3 | import os 4 | 5 | 6 | class Dict(dict): 7 | """ 8 | Dictionary that allows to access per attributes and to except names from being loaded 9 | """ 10 | def __init__(self, dictionary: dict = None): 11 | super(Dict, self).__init__() 12 | 13 | if dictionary is not None: 14 | self.load(dictionary) 15 | 16 | def __getattr__(self, item): 17 | try: 18 | return self[item] if item in self else getattr(super(Dict, self), item) 19 | except AttributeError: 20 | raise AttributeError(f'This dictionary has no attribute "{item}"') 21 | 22 | def load(self, dictionary: dict, name_list: list = None): 23 | """ 24 | Loads a dictionary 25 | :param dictionary: Dictionary to be loaded 26 | :param name_list: List of names to be updated 27 | """ 28 | for name in dictionary: 29 | data = dictionary[name] 30 | if name_list is None or name in name_list: 31 | if isinstance(data, dict): 32 | if name in self: 33 | self[name].load(data) 34 | else: 35 | self[name] = Dict(data) 36 | elif isinstance(data, list): 37 | self[name] = list() 38 | for item in data: 39 | if isinstance(item, dict): 40 | self[name].append(Dict(item)) 41 | else: 42 | self[name].append(item) 43 | else: 44 | self[name] = data 45 | 46 | def save(self, path): 47 | """ 48 | Saves the dictionary into a json file 49 | :param path: Path of the json file 50 | """ 51 | os.makedirs(path, exist_ok=True) 52 | 53 | path = os.path.join(path, 'cfg.json') 54 | 55 | with open(path, 'w') as file: 56 | json.dump(self, file, indent=True) 57 | 58 | 59 | class Configuration(Dict): 60 | """ 61 | Configuration loaded from a json file 62 | """ 63 | def __init__(self, path: str, default_path=None): 64 | super(Configuration, self).__init__() 65 | 66 | if default_path is not None: 67 | self.load(default_path) 68 | 69 | self.load(path) 70 | 71 | def load_model(self, path: str): 72 | self.load(path, name_list=["model"]) 73 | 74 | def load(self, path: str, name_list: list = None): 75 | """ 76 | Loads attributes from a json file 77 | :param path: Path of the json file 78 | :param name_list: List of names to be updated 79 | :return: 80 | """ 81 | with open(path) as file: 82 | data = json.loads(jsmin.jsmin(file.read())) 83 | 84 | super(Configuration, self).load(data, name_list) 85 | -------------------------------------------------------------------------------- 
/utils/data.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as th 3 | import numpy as np 4 | from einops import rearrange, repeat, reduce 5 | 6 | from typing import Tuple, Union, List 7 | 8 | __author__ = "Manuel Traub" 9 | 10 | class DeviceSideDataset: 11 | 12 | def __init__(self, dataset, device, batch_size): 13 | 14 | self.batch_size = batch_size 15 | 16 | self.data = [[]] 17 | if isinstance(dataset[0], tuple) or isinstance(dataset[0], list): 18 | for i in range(1, len(dataset[0])): 19 | self.data.append([]) 20 | 21 | for i in range(len(dataset)): 22 | item = dataset[i] 23 | for n in range(len(item)): 24 | self.data[n].append(item[n]) 25 | 26 | if i % 100 == 0: 27 | print(f"loading data {i * 100 / len(dataset):.2f}%") 28 | 29 | else: 30 | for i in range(len(dataset)): 31 | self.data[0].append(dataset[i]) 32 | if i % 100 == 0: 33 | print(f"loading data {i * 100 / len(dataset):.2f}%") 34 | 35 | print(f"loading data {100:.2f}%") 36 | for i in range(len(self.data)): 37 | self.data[i] = th.tensor(np.stack(self.data[i])).float().to(device) 38 | print(f"pushed[{i}]: {self.data[i].element_size()} * {self.data[i].nelement()} = {self.data[i].element_size() * self.data[i].nelement()}") 39 | 40 | self.shuffle() 41 | 42 | def __len__(self): 43 | return self.data[0].shape[0] // self.batch_size 44 | 45 | def __iter__(self): 46 | self.shuffle() 47 | self.batch_counter = 0 48 | return self 49 | 50 | def next(self): 51 | return self.__next__() 52 | 53 | def __next__(self): 54 | batch_start = self.batch_counter * self.batch_size 55 | batch_end = (self.batch_counter + 1) * self.batch_size 56 | self.batch_counter += 1 57 | 58 | if batch_end >= self.data[0].shape[0]: 59 | raise StopIteration 60 | 61 | batch = [] 62 | for i in range(len(self.data)): 63 | batch.append(self.data[i][batch_start:batch_end]) 64 | 65 | return tuple(batch) 66 | 67 | def shuffle(self): 68 | with th.no_grad(): 69 | indices = th.randperm(self.data[0].shape[0]) 70 | 71 | for i in range(len(self.data)): 72 | self.data[i] = self.data[i][indices] 73 | 74 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils.configuration import Configuration 3 | import time 4 | import torch as th 5 | import numpy as np 6 | from einops import rearrange, repeat, reduce 7 | import cv2 8 | 9 | class Timer: 10 | 11 | def __init__(self): 12 | self.last = time.time() 13 | self.passed = 0 14 | self.sum = 0 15 | 16 | def __str__(self): 17 | self.passed = self.passed * 0.99 + time.time() - self.last 18 | self.sum = self.sum * 0.99 + 1 19 | passed = self.passed / self.sum 20 | self.last = time.time() 21 | 22 | if passed > 1: 23 | return f"{passed:.2f}s/it" 24 | 25 | return f"{1.0/passed:.2f}it/s" 26 | 27 | class UEMA: 28 | 29 | def __init__(self, memory = 100): 30 | self.value = 0 31 | self.sum = 1e-30 32 | self.decay = np.exp(-1 / memory) 33 | 34 | def update(self, value): 35 | self.value = self.value * self.decay + value 36 | self.sum = self.sum * self.decay + 1 37 | 38 | def __float__(self): 39 | return self.value / self.sum 40 | 41 | class BinaryStatistics: 42 | 43 | def __init__(self): 44 | self.true_positive = 0 45 | self.true_negative = 0 46 | self.false_positive = 0 47 | self.false_negative = 0 48 | 49 | def update(self, outputs, labels): 50 | outputs = th.round(outputs) 51 | self.true_positive += th.sum((outputs == 
labels).float() * (labels == th.ones_like(labels)).float()).item() 52 | self.true_negative += th.sum((outputs == labels).float() * (labels == th.zeros_like(labels)).float()).item() 53 | self.false_positive += th.sum((outputs != labels).float() * (labels == th.zeros_like(labels)).float()).item() 54 | self.false_negative += th.sum((outputs != labels).float() * (labels == th.ones_like(labels)).float()).item() 55 | 56 | def accuracy(self): 57 | return 100 * (self.true_positive + self.true_negative) / (self.true_positive + self.true_negative + self.false_positive + self.false_negative + 1e-30) 58 | 59 | def sensitivity(self): 60 | return 100 * self.true_positive / (self.true_positive + self.false_negative + 1e-30) 61 | 62 | def specificity(self): 63 | return 100 * self.true_negative / (self.true_negative + self.false_positive + 1e-30) 64 | 65 | 66 | def model_path(cfg: Configuration, overwrite=False, move_old=True): 67 | """ 68 | Makes the model path, option to not overwrite 69 | :param cfg: Configuration file with the model path 70 | :param overwrite: Overwrites the files in the directory, else makes a new directory 71 | :param move_old: Moves old folder with the same name to an old folder, if not overwrite 72 | :return: Model path 73 | """ 74 | _path = os.path.join('out') 75 | path = os.path.join(_path, cfg.model_path) 76 | 77 | if not os.path.exists(_path): 78 | os.makedirs(_path) 79 | 80 | if not overwrite: 81 | if move_old: 82 | # Moves existing directory to an old folder 83 | if os.path.exists(path): 84 | old_path = os.path.join(_path, f'{cfg.model_path}_old') 85 | if not os.path.exists(old_path): 86 | os.makedirs(old_path) 87 | _old_path = os.path.join(old_path, cfg.model_path) 88 | i = 0 89 | while os.path.exists(_old_path): 90 | i = i + 1 91 | _old_path = os.path.join(old_path, f'{cfg.model_path}_{i}') 92 | os.renames(path, _old_path) 93 | else: 94 | # Increases number after directory name for each new path 95 | i = 0 96 | while os.path.exists(path): 97 | i = i + 1 98 | path = os.path.join(_path, f'{cfg.model_path}_{i}') 99 | 100 | return path 101 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torchvision as tv 3 | from torch import nn 4 | from utils.utils import BatchToSharedObjects, SharedObjectsToBatch, LambdaModule 5 | from pytorch_msssim import ms_ssim as msssim, ssim 6 | from einops import rearrange, repeat, reduce 7 | 8 | __author__ = "Manuel Traub" 9 | 10 | class SSIMLoss(nn.Module): 11 | def __init__(self): 12 | super(SSIMLoss, self).__init__() 13 | 14 | def forward(self, output: th.Tensor, target: th.Tensor): 15 | return -ssim(output, target) 16 | 17 | class MSSSIMLoss(nn.Module): 18 | def __init__(self): 19 | super(MSSSIMLoss, self).__init__() 20 | 21 | def forward(self, output: th.Tensor, target: th.Tensor): 22 | return -msssim(output, target) #, normalize="relu") 23 | 24 | class PositionLoss(nn.Module): 25 | def __init__(self, num_objects: int, teacher_forcing: int): 26 | super(PositionLoss, self).__init__() 27 | 28 | self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects)) 29 | self.last_mask = None 30 | self.t = 0 31 | self.teacher_forcing = teacher_forcing 32 | 33 | def reset_state(self): 34 | self.last_mask = None 35 | self.t = 0 36 | 37 | def forward(self, position, position_last, mask): 38 | 39 | mask = th.max(th.max(mask, dim=3)[0], dim=2)[0] 40 | mask = 
self.to_batch(mask).detach() 41 | self.t = self.t + 1 42 | 43 | if self.last_mask is None or self.t <= self.teacher_forcing: 44 | self.last_mask = mask.detach() 45 | return th.zeros(1, device=mask.device) 46 | 47 | self.last_mask = th.maximum(self.last_mask, mask) 48 | 49 | position = self.to_batch(position) 50 | position_last = self.to_batch(position_last).detach() 51 | 52 | return 0.01 * th.mean(self.last_mask * (position - position_last)**2) 53 | 54 | 55 | class MaskModulatedObjectLoss(nn.Module): 56 | def __init__(self, num_objects: int, teacher_forcing: int): 57 | super(MaskModulatedObjectLoss, self).__init__() 58 | 59 | self.to_batch = SharedObjectsToBatch(num_objects) 60 | self.last_mask = None 61 | self.t = 0 62 | self.teacher_forcing = teacher_forcing 63 | 64 | def reset_state(self): 65 | self.last_mask = None 66 | self.t = 0 67 | 68 | def forward( 69 | self, 70 | object_output, 71 | object_target, 72 | mask: th.Tensor 73 | ): 74 | mask = self.to_batch(mask).detach() 75 | mask = th.max(th.max(mask, dim=3, keepdim=True)[0], dim=2, keepdim=True)[0] 76 | self.t = self.t + 1 77 | 78 | if self.last_mask is None or self.t <= self.teacher_forcing: 79 | self.last_mask = mask.detach() 80 | return th.zeros(1, device=mask.device) 81 | 82 | self.last_mask = th.maximum(self.last_mask, mask).detach() 83 | 84 | object_output = th.sigmoid(self.to_batch(object_output) - 2.5) 85 | object_target = th.sigmoid(self.to_batch(object_target) - 2.5) 86 | 87 | return th.mean((1 - mask) * self.last_mask * (object_output - object_target)**2) 88 | 89 | class ObjectModulator(nn.Module): 90 | def __init__(self, num_objects: int): 91 | super(ObjectModulator, self).__init__() 92 | self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects)) 93 | self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects)) 94 | self.position = None 95 | self.gestalt = None 96 | 97 | def reset_state(self): 98 | self.position = None 99 | self.gestalt = None 100 | 101 | def forward(self, position: th.Tensor, gestalt: th.Tensor, mask: th.Tensor): 102 | 103 | position = self.to_batch(position) 104 | gestalt = self.to_batch(gestalt) 105 | 106 | if self.position is None or self.gestalt is None: 107 | self.position = position.detach() 108 | self.gestalt = gestalt.detach() 109 | return self.to_shared(position), self.to_shared(gestalt) 110 | 111 | mask = th.max(th.max(mask, dim=3)[0], dim=2)[0] 112 | mask = self.to_batch(mask.detach()) 113 | 114 | _position = mask * position + (1 - mask) * self.position 115 | position = th.cat((position[:,:-1], _position[:,-1:]), dim=1) 116 | gestalt = mask * gestalt + (1 - mask) * self.gestalt 117 | 118 | self.gestalt = gestalt.detach() 119 | self.position = position.detach() 120 | return self.to_shared(position), self.to_shared(gestalt) 121 | 122 | class MoveToCenter(nn.Module): 123 | def __init__(self, num_objects: int): 124 | super(MoveToCenter, self).__init__() 125 | 126 | self.to_batch2d = SharedObjectsToBatch(num_objects) 127 | self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects)) 128 | 129 | def forward(self, input: th.Tensor, position: th.Tensor): 130 | 131 | input = self.to_batch2d(input) 132 | position = self.to_batch(position).detach() 133 | position = th.stack((position[:,1], position[:,0]), dim=1) 134 | 135 | theta = th.tensor([1, 0, 0, 1], dtype=th.float, device=input.device).view(1,2,2) 136 | theta = repeat(theta, '1 a b -> n a b', n=input.shape[0]) 137 | 138 | position = rearrange(position, 
'b c -> b c 1') 139 | theta = th.cat((theta, position), dim=2) 140 | 141 | grid = nn.functional.affine_grid(theta, input.shape, align_corners=False) 142 | output = nn.functional.grid_sample(input, grid, align_corners=False) 143 | 144 | return output 145 | 146 | class TranslationInvariantObjectLoss(nn.Module): 147 | def __init__(self, num_objects: int, teacher_forcing: int): 148 | super(TranslationInvariantObjectLoss, self).__init__() 149 | 150 | self.move_to_center = MoveToCenter(num_objects) 151 | self.to_batch = SharedObjectsToBatch(num_objects) 152 | self.last_mask = None 153 | self.t = 0 154 | self.teacher_forcing = teacher_forcing 155 | 156 | def reset_state(self): 157 | self.last_mask = None 158 | self.t = 0 159 | 160 | def forward( 161 | self, 162 | mask: th.Tensor, 163 | object1: th.Tensor, 164 | position1: th.Tensor, 165 | object2: th.Tensor, 166 | position2: th.Tensor, 167 | ): 168 | mask = self.to_batch(mask).detach() 169 | mask = th.max(th.max(mask, dim=3, keepdim=True)[0], dim=2, keepdim=True)[0] 170 | self.t = self.t + 1 171 | 172 | if self.last_mask is None or self.t <= self.teacher_forcing: 173 | self.last_mask = mask.detach() 174 | return th.zeros(1, device=mask.device) 175 | 176 | self.last_mask = th.maximum(self.last_mask, mask).detach() 177 | 178 | object1 = self.move_to_center(th.sigmoid(object1 - 2.5), position1) 179 | object2 = self.move_to_center(th.sigmoid(object2 - 2.5), position2) 180 | 181 | return th.mean(self.last_mask * (object1 - object2)**2) 182 | 183 | -------------------------------------------------------------------------------- /utils/optimizers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch as th 3 | import numpy as np 4 | from torch.optim.optimizer import Optimizer, required 5 | 6 | """ 7 | Liyuan Liu , Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, and Jiawei Han (2020). 8 | On the Variance of the Adaptive Learning Rate and Beyond. the Eighth International Conference on Learning 9 | Representations. 
10 | """ 11 | class RAdam(Optimizer): 12 | 13 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, degenerated_to_sgd=False): 14 | if not 0.0 <= lr: 15 | raise ValueError("Invalid learning rate: {}".format(lr)) 16 | if not 0.0 <= eps: 17 | raise ValueError("Invalid epsilon value: {}".format(eps)) 18 | if not 0.0 <= betas[0] < 1.0: 19 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 20 | if not 0.0 <= betas[1] < 1.0: 21 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 22 | 23 | self.degenerated_to_sgd = degenerated_to_sgd 24 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 25 | for param in params: 26 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 27 | param['buffer'] = [[None, None, None] for _ in range(10)] 28 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 29 | super(RAdam, self).__init__(params, defaults) 30 | 31 | def __setstate__(self, state): 32 | super(RAdam, self).__setstate__(state) 33 | 34 | def step(self, closure=None): 35 | 36 | loss = None 37 | if closure is not None: 38 | loss = closure() 39 | 40 | for group in self.param_groups: 41 | 42 | for p in group['params']: 43 | if p.grad is None: 44 | continue 45 | grad = p.grad.data.float() 46 | if grad.is_sparse: 47 | raise RuntimeError('RAdam does not support sparse gradients') 48 | 49 | p_data_fp32 = p.data.float() 50 | 51 | state = self.state[p] 52 | 53 | if len(state) == 0: 54 | state['step'] = 0 55 | state['exp_avg'] = th.zeros_like(p_data_fp32) 56 | state['exp_avg_sq'] = th.zeros_like(p_data_fp32) 57 | else: 58 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 59 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 60 | 61 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 62 | beta1, beta2 = group['betas'] 63 | 64 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value = 1 - beta2) 65 | exp_avg.mul_(beta1).add_(grad, alpha = 1 - beta1) 66 | 67 | state['step'] += 1 68 | buffered = group['buffer'][int(state['step'] % 10)] 69 | if state['step'] == buffered[0]: 70 | N_sma, step_size = buffered[1], buffered[2] 71 | else: 72 | buffered[0] = state['step'] 73 | beta2_t = beta2 ** state['step'] 74 | N_sma_max = 2 / (1 - beta2) - 1 75 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 76 | buffered[1] = N_sma 77 | 78 | # more conservative since it's an approximated value 79 | if N_sma >= 5: 80 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 81 | elif self.degenerated_to_sgd: 82 | step_size = 1.0 / (1 - beta1 ** state['step']) 83 | else: 84 | step_size = -1 85 | buffered[2] = step_size 86 | 87 | # more conservative since it's an approximated value 88 | if N_sma >= 5: 89 | if group['weight_decay'] != 0: 90 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 91 | denom = exp_avg_sq.sqrt().add_(group['eps']) 92 | p_data_fp32.addcdiv_(exp_avg, denom, value = -step_size * group['lr']) 93 | p.data.copy_(p_data_fp32) 94 | elif step_size > 0: 95 | if group['weight_decay'] != 0: 96 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 97 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 98 | p.data.copy_(p_data_fp32) 99 | 100 | return loss 101 | 102 | 103 | class SDRMSprop(Optimizer): 104 | 105 | def __init__(self, 
params, lr=1e-3, alpha=0.99, beta=0.9, eps=1e-8, weight_decay=0): 106 | 107 | defaults = dict(lr=lr, alpha=alpha, beta=beta, eps=eps, weight_decay=weight_decay) 108 | super(SDRMSprop, self).__init__(params, defaults) 109 | 110 | def reset_state(self): 111 | for group in self.param_groups: 112 | for p in group['params']: 113 | state = self.state[p] 114 | state['v'] = th.zeros_like(p, memory_format=th.preserve_format) 115 | state['s'] = th.zeros_like(p, memory_format=th.preserve_format) 116 | 117 | @th.no_grad() 118 | def step(self, debug=False): 119 | 120 | for group in self.param_groups: 121 | 122 | a = group['alpha'] 123 | b = group['beta'] 124 | lr = group['lr'] 125 | eps = group['eps'] 126 | 127 | for p in group['params']: 128 | if p.grad is not None: 129 | 130 | g = p.grad 131 | 132 | state = self.state[p] 133 | 134 | # Lazy state initialization 135 | if len(state) == 0: 136 | state['v'] = th.zeros_like(p, memory_format=th.preserve_format) 137 | state['s'] = th.zeros_like(p, memory_format=th.preserve_format) 138 | 139 | state['v'] = a * state['v'] + (1 - a) * g**2 140 | state['s'] = b * state['s'] + (1 - b) * th.sign(g) 141 | 142 | _v = state['v'] 143 | _s = state['s'] 144 | 145 | 146 | p.add_(-1 * lr * _s**2 * g / (th.sqrt(_v) + eps)) 147 | 148 | 149 | 150 | class SDAdam(Optimizer): 151 | 152 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.9), eps=1e-8, weight_decay=0): 153 | 154 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 155 | super(SDAdam, self).__init__(params, defaults) 156 | 157 | def reset_state(self): 158 | for group in self.param_groups: 159 | for p in group['params']: 160 | state = self.state[p] 161 | state['step'] = 0 162 | state['m'] = th.zeros_like(p, memory_format=th.preserve_format) 163 | state['v'] = th.zeros_like(p, memory_format=th.preserve_format) 164 | state['s'] = th.zeros_like(p, memory_format=th.preserve_format) 165 | 166 | @th.no_grad() 167 | def step(self, debug=False): 168 | 169 | for group in self.param_groups: 170 | 171 | b1, b2, b3 = group['betas'] 172 | lr = group['lr'] 173 | eps = group['eps'] 174 | 175 | for p in group['params']: 176 | if p.grad is not None: 177 | 178 | g = p.grad 179 | 180 | state = self.state[p] 181 | 182 | # Lazy state initialization 183 | if len(state) == 0: 184 | state['step'] = 0 185 | state['m'] = th.zeros_like(p, memory_format=th.preserve_format) 186 | state['v'] = th.zeros_like(p, memory_format=th.preserve_format) 187 | state['s'] = th.zeros_like(p, memory_format=th.preserve_format) 188 | 189 | state['step'] += 1 190 | 191 | state['m'] = b1 * state['m'] + (1 - b1) * g 192 | state['v'] = b2 * state['v'] + (1 - b2) * g**2 193 | state['s'] = b3 * state['s'] + (1 - b3) * th.sign(g) 194 | 195 | _m = state['m'] / (1 - b1**state['step']) 196 | _v = state['v'] / (1 - b2**state['step']) 197 | _s = state['s'] / (1 - b3**state['step']) 198 | 199 | 200 | p.add_(-1 * lr * _s**2 * _m / (th.sqrt(_v) + eps)) 201 | 202 | class SDAMSGrad(Optimizer): 203 | 204 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.9), eps=1e-8, weight_decay=0): 205 | 206 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 207 | super(SDAMSGrad, self).__init__(params, defaults) 208 | 209 | def reset_state(self): 210 | for group in self.param_groups: 211 | for p in group['params']: 212 | state = self.state[p] 213 | state['step'] = 0 214 | state['m'] = th.zeros_like(p, memory_format=th.preserve_format) 215 | state['v'] = th.zeros_like(p, memory_format=th.preserve_format) 216 | state['s'] = 
th.zeros_like(p, memory_format=th.preserve_format) 217 | state['v_max'] = th.zeros_like(p, memory_format=th.preserve_format) 218 | 219 | @th.no_grad() 220 | def step(self, debug=False): 221 | 222 | for group in self.param_groups: 223 | 224 | b1, b2, b3 = group['betas'] 225 | lr = group['lr'] 226 | eps = group['eps'] 227 | 228 | for p in group['params']: 229 | if p.grad is not None: 230 | 231 | g = p.grad 232 | 233 | state = self.state[p] 234 | 235 | # Lazy state initialization 236 | if len(state) == 0: 237 | state['step'] = 0 238 | state['m'] = th.zeros_like(p, memory_format=th.preserve_format) 239 | state['v'] = th.zeros_like(p, memory_format=th.preserve_format) 240 | state['s'] = th.zeros_like(p, memory_format=th.preserve_format) 241 | state['v_max'] = th.zeros_like(p, memory_format=th.preserve_format) 242 | 243 | state['step'] += 1 244 | 245 | state['m'] = b1 * state['m'] + (1 - b1) * g 246 | state['v'] = b2 * state['v'] + (1 - b2) * g**2 247 | state['s'] = b3 * state['s'] + (1 - b3) * th.sign(g) 248 | state['v_max'] = th.max(state['v_max'], state['v']) 249 | 250 | _m = state['m'] / (1 - b1**state['step']) 251 | _v = state['v_max'] / (1 - b2**state['step']) 252 | _s = state['s'] / (1 - b3**state['step']) 253 | 254 | 255 | p.add_(-1 * lr * _s**2 * _m / (th.sqrt(_v) + eps)) 256 | 257 | 258 | -------------------------------------------------------------------------------- /utils/parallel.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import distributed 3 | import torch.multiprocessing as multiprocessing 4 | 5 | from torch.utils import data 6 | import os 7 | import numpy as np 8 | 9 | 10 | def run_parallel(fn, num_gpus=1, *args): 11 | """ 12 | Initializes and starts the parallel training of a network 13 | :param fn: Function used for the training 14 | :param num_gpus: Number of GPUs 15 | :param args: Arguments for fn(rank, world_size, *args) 16 | """ 17 | # Processes epochs parallel on GPUs 18 | multiprocessing.spawn(fn=_init_process, args=tuple([num_gpus, fn, *args]), 19 | nprocs=num_gpus, join=True) 20 | 21 | 22 | def _init_process(rank: int, world_size: int, fn, *args, backend='gloo'): 23 | """ 24 | Initializes a process for one GPU 25 | :param rank: Number of the GPU (0..world_size-1) 26 | :param world_size: Number of GPUs available 27 | :param fn: Function used for the training 28 | :param args: Arguments for the function fn 29 | :param backend: Backend used by the framework 30 | """ 31 | 32 | th.cuda.set_device(rank) 33 | 34 | def open_port(port=60000): 35 | os.environ['MASTER_ADDR'] = '127.0.0.1' 36 | master_port = port 37 | while True: 38 | try: 39 | os.environ['MASTER_PORT'] = f'{master_port}' 40 | distributed.init_process_group(backend=backend, rank=rank, world_size=world_size) 41 | return master_port 42 | except RuntimeError as e: 43 | if master_port >= port + 100: 44 | raise e 45 | master_port = master_port + 1 46 | 47 | open_port() 48 | 49 | distributed.barrier() 50 | fn(rank, world_size, *args) 51 | 52 | 53 | class DatasetPartition(data.Dataset): 54 | def __init__(self, parent: data.Dataset, rank: int, world_size: int): 55 | """ 56 | Partition of Dataset for use in parallel training 57 | :param parent: Parent Dataset to be partitioned 58 | :param rank: Index [0..num-1] of the partition 59 | :param world_size: Number of all Dataset partitions 60 | """ 61 | 62 | self.rank = rank 63 | self.world_size = world_size 64 | 65 | self.parent = parent 66 | self.indices = np.arange(len(parent)) 67 | 
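# Fixed seed: every rank shuffles the full index list identically, so the contiguous per-rank slices computed below partition the dataset without overlap.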
np.random.seed(123456) 68 | np.random.shuffle(self.indices) 69 | 70 | size = len(parent) // self.world_size 71 | if size == 0: 72 | raise ValueError(f'The datasets {parent} is empty') 73 | 74 | self.idx_start = self.rank * size 75 | self.idx_end = (self.rank + 1) * size 76 | 77 | def __len__(self): 78 | return self.idx_end - self.idx_start 79 | 80 | def __getitem__(self, idx: int): 81 | return self.parent[self.indices[self.idx_start + idx]] 82 | -------------------------------------------------------------------------------- /utils/scheduled_sampling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.random import random 3 | 4 | 5 | class ScheduledSampler: 6 | def __init__(self, error_threshold=0.0, min_value=0.0, max_seq=-1): 7 | """ 8 | Scheduled Sampling, samples wetter the predicted output should be used as input 9 | From "Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks" 10 | by Samy Bengio and Oriol Vinyals and Navdeep Jaitly and Noam Shazeer 11 | :param error_threshold: Error threshold for when a self.step() increases the iteration 12 | :param min_value: Minimum probability of using the output as input 13 | :param max_seq: Maximum steps of a sequence of using the output as input 14 | """ 15 | super(ScheduledSampler, self).__init__() 16 | 17 | self.iteration = 0 18 | self.error_threshold = error_threshold 19 | self.min_value = min_value 20 | self.max_seq = max_seq if max_seq > 0 else np.inf 21 | 22 | self.seq_len = 0 23 | 24 | def __probability__(self, iteration: int = -1): 25 | """ 26 | Probability function of using the output as input 27 | :param iteration: Number of the iteration, uses current iteration if iteration < 0 28 | :return: Probability 0..1 29 | """ 30 | raise NotImplementedError('Implement this function') 31 | 32 | def probability(self, iteration: int = -1): 33 | """ 34 | Post-processed probability of using the output as input 35 | :param iteration: Number of the iteration, uses current iteration if iteration < 0 36 | :return: Probability 0..1 37 | """ 38 | iteration = iteration if iteration >= 0 else self.iteration 39 | return max(self.__probability__(iteration), self.min_value) 40 | 41 | def step(self, iteration=-1, error=0.0): 42 | """ 43 | Increments the iteration if error < self.error_threshold 44 | :param iteration: Optionally sets the iteration explicitly 45 | :param error: Error, where the iteration is only increased when error < self.error_threshold 46 | """ 47 | if error < self.error_threshold or self.error_threshold <= 0: 48 | self.iteration += 1 49 | 50 | self.iteration = iteration if iteration >= 0 else self.iteration 51 | 52 | def sample(self): 53 | """ 54 | Samples wetter the output should be used as input 55 | :return: True if output should be used as input 56 | """ 57 | sample = self.seq_len < self.max_seq and random() > self.probability() 58 | self.seq_len = self.seq_len + 1 if sample else 0 59 | return sample 60 | 61 | def sample_smooth(self): 62 | """ 63 | Samples how much the output vs the input is used as new input 64 | :return: Ratio p=0..1 with input = p * output + (1 - p) * input 65 | """ 66 | sample = self.probability() if self.seq_len < self.max_seq else 1.0 67 | self.seq_len = self.seq_len + 1 if sample < 0 else 0 68 | return sample 69 | 70 | 71 | class LinearSampler(ScheduledSampler): 72 | def __init__(self, slope: float, error_threshold=0.0, min_value=0.0, max_seq=-1): 73 | """ 74 | Samples with a linear slope wetter the predicted output should 
be used as input 75 | :param slope: Linear slope regarding the iterations 76 | :param error_threshold: Error threshold for increasing iteration with self.step() 77 | :param min_value: Minimum probability of using the output as input 78 | :param max_seq: Maximum steps of a sequence where the output is used as input 79 | """ 80 | super(LinearSampler, self).__init__(error_threshold, min_value, max_seq) 81 | 82 | self.slope = slope 83 | 84 | def __probability__(self, iteration: int = -1): 85 | return 1 - iteration * self.slope 86 | 87 | 88 | class ExponentialSampler(ScheduledSampler): 89 | def __init__(self, initial_probability: float, error_threshold=0.0, min_value=0.0, max_seq=-1): 90 | """ 91 | Samples an exponential decrease, with probability = initial_probability ^ iteration 92 | :param initial_probability: Initial base, must be (0, 1] 93 | :param error_threshold: Error threshold for increasing iteration with self.step() 94 | :param min_value: Minimum probability of using the output as input 95 | :param max_seq: Maximum iterations in a closed loop sequence 96 | """ 97 | super(ExponentialSampler, self).__init__(error_threshold, min_value, max_seq) 98 | 99 | assert(0 < initial_probability <= 1) 100 | self.initial_propability = initial_probability 101 | 102 | def __probability__(self, iteration: int = -1): 103 | return self.initial_propability**iteration 104 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as th 3 | import numpy as np 4 | import nn as nn_modules 5 | from torch.autograd import Function 6 | from einops import rearrange, repeat, reduce 7 | 8 | from typing import Tuple, Union, List 9 | import utils 10 | 11 | __author__ = "Manuel Traub" 12 | 13 | class PrintShape(nn.Module): 14 | def __init__(self, msg = ""): 15 | super(PrintShape, self).__init__() 16 | self.msg = msg 17 | 18 | def forward(self, input: th.Tensor): 19 | if self.msg != "": 20 | print(self.msg, input.shape) 21 | else: 22 | print(input.shape) 23 | return input 24 | 25 | class PrintStats(nn.Module): 26 | def __init__(self): 27 | super(PrintStats, self).__init__() 28 | 29 | def forward(self, input: th.Tensor): 30 | print( 31 | "min: ", th.min(input).detach().cpu().numpy(), 32 | ", mean: ", th.mean(input).detach().cpu().numpy(), 33 | ", max: ", th.max(input).detach().cpu().numpy() 34 | ) 35 | return input 36 | 37 | class PushToInfFunction(Function): 38 | @staticmethod 39 | def forward(ctx, tensor): 40 | ctx.save_for_backward(tensor) 41 | return tensor.clone() 42 | 43 | @staticmethod 44 | def backward(ctx, grad_output): 45 | tensor = ctx.saved_tensors[0] 46 | grad_input = -th.ones_like(grad_output) 47 | return grad_input 48 | 49 | class PushToInf(nn.Module): 50 | def __init__(self): 51 | super(PushToInf, self).__init__() 52 | 53 | self.fcn = PushToInfFunction.apply 54 | 55 | def forward(self, input: th.Tensor): 56 | return self.fcn(input) 57 | 58 | class ForcedAlpha(nn.Module): 59 | def __init__(self, speed = 1): 60 | super(ForcedAlpha, self).__init__() 61 | 62 | self.init = nn.Parameter(th.zeros(1)) 63 | self.speed = speed 64 | self.to_inf = PushToInf() 65 | 66 | def item(self): 67 | return th.tanh(self.to_inf(self.init * self.speed)).item() 68 | 69 | def forward(self, input: th.Tensor): 70 | return input * th.tanh(self.to_inf(self.init * self.speed)) 71 | 72 | class AlphaThreshold(nn.Module): 73 | def __init__(self, max_value = 1): 74 | 
super(AlphaThreshold, self).__init__() 75 | 76 | self.init = nn.Parameter(th.zeros(1)) 77 | self.to_inf = PushToInf() 78 | self.max_value = max_value 79 | 80 | def forward(self): 81 | return th.tanh(self.to_inf(self.init)) * self.max_value 82 | 83 | class InitialLatentStates(nn.Module): 84 | def __init__( 85 | self, 86 | gestalt_size: int, 87 | num_objects: int, 88 | size: Tuple[int, int] 89 | ): 90 | super(InitialLatentStates, self).__init__() 91 | 92 | self.num_objects = num_objects 93 | self.gestalt_size = gestalt_size 94 | self.gestalt_mean = nn.Parameter(th.zeros(1, gestalt_size)) 95 | self.gestalt_std = nn.Parameter(th.ones(1, gestalt_size)) 96 | self.std = nn.Parameter(th.zeros(1)) 97 | 98 | self.register_buffer('priority', th.arange(num_objects).float() * 25, persistent=False) 99 | self.register_buffer('threshold', th.ones(1) * 0.5) 100 | self.last_mask = None 101 | 102 | self.gaus2d = nn.Sequential( 103 | Gaus2D((size[0] // 16, size[1] // 16)), 104 | Gaus2D((size[0] // 4, size[1] // 4)), 105 | Gaus2D(size) 106 | ) 107 | 108 | self.level = 1 109 | 110 | self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects)) 111 | self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects)) 112 | 113 | def reset_state(self): 114 | self.last_mask = None 115 | 116 | def set_level(self, level): 117 | self.level = level 118 | 119 | def forward( 120 | self, 121 | error: th.Tensor, 122 | mask: th.Tensor = None, 123 | position: th.Tensor = None, 124 | gestalt: th.Tensor = None, 125 | priority: th.Tensor = None 126 | ): 127 | 128 | batch_size = error.shape[0] 129 | device = error.device 130 | 131 | if self.last_mask is None: 132 | self.last_mask = th.zeros((batch_size * self.num_objects, 1), device = device) 133 | 134 | if mask is not None: 135 | mask = reduce(mask[:,:-1], 'b c h w -> (b c) 1' , 'max').detach() 136 | self.last_mask = th.maximum(self.last_mask, mask) 137 | 138 | std = repeat(self.std, '1 -> (b o) 1', b = batch_size, o=self.num_objects) 139 | mask = (self.last_mask > self.threshold).float().detach() 140 | 141 | gestalt_rand = th.randn((batch_size * self.num_objects, self.gestalt_size), device = device) 142 | gestalt_new = th.sigmoid(gestalt_rand * self.gestalt_std + self.gestalt_mean) 143 | 144 | if gestalt is None: 145 | gestalt = gestalt_new 146 | else: 147 | gestalt = self.to_batch(gestalt) * mask + gestalt_new * (1 - mask) 148 | 149 | if priority is None: 150 | priority = repeat(self.priority, 'o -> (b o) 1', b = batch_size) 151 | else: 152 | priority = self.to_batch(priority) * mask + repeat(self.priority, 'o -> (b o) 1', b = batch_size) * (1 - mask) 153 | 154 | xy_rand_new = th.rand((batch_size * self.num_objects * 10, 2), device = device) * 2 - 1 155 | std_new = th.zeros((batch_size * self.num_objects * 10, 1), device = device) 156 | position_new = th.cat((xy_rand_new, std_new), dim=1) 157 | 158 | position2d = self.gaus2d[self.level](position_new) 159 | position2d = rearrange(position2d, '(b o) 1 h w -> b o h w', b = batch_size) 160 | 161 | rand_error = reduce(position2d * error, 'b o h w -> (b o) 1', 'sum') 162 | 163 | xy_rand_new = rearrange(xy_rand_new, '(b r) c -> r b c', r = 10) 164 | rand_error = rearrange(rand_error, '(b r) c -> r b c', r = 10) 165 | 166 | max_error = th.argmax(rand_error, dim=0, keepdim=True) 167 | x, y = th.chunk(xy_rand_new, 2, dim=2) 168 | x = th.gather(x, dim=0, index=max_error).detach().squeeze(dim=0) 169 | y = th.gather(y, dim=0, index=max_error).detach().squeeze(dim=0) 170 | 171 | if 
position is None: 172 | position = th.cat((x, y, std), dim=1) 173 | else: 174 | position = self.to_batch(position) * mask + th.cat((x, y, std), dim=1) * (1 - mask) 175 | 176 | return self.to_shared(position), self.to_shared(gestalt), self.to_shared(priority) 177 | 178 | class Gaus2D(nn.Module): 179 | def __init__(self, size: Tuple[int, int]): 180 | super(Gaus2D, self).__init__() 181 | 182 | self.size = size 183 | 184 | self.register_buffer("grid_x", th.arange(size[0]), persistent=False) 185 | self.register_buffer("grid_y", th.arange(size[1]), persistent=False) 186 | 187 | self.grid_x = (self.grid_x / (size[0]-1)) * 2 - 1 188 | self.grid_y = (self.grid_y / (size[1]-1)) * 2 - 1 189 | 190 | self.grid_x = self.grid_x.view(1, 1, -1, 1).expand(1, 1, *size).clone() 191 | self.grid_y = self.grid_y.view(1, 1, 1, -1).expand(1, 1, *size).clone() 192 | 193 | def forward(self, input: th.Tensor): 194 | 195 | x = rearrange(input[:,0:1], 'b c -> b c 1 1') 196 | y = rearrange(input[:,1:2], 'b c -> b c 1 1') 197 | std = rearrange(input[:,2:3], 'b c -> b c 1 1') 198 | 199 | x = th.clip(x, -1, 1) 200 | y = th.clip(y, -1, 1) 201 | std = th.clip(std, 0, 1) 202 | 203 | max_size = max(self.size) 204 | std_x = (1 + max_size * std) / self.size[0] 205 | std_y = (1 + max_size * std) / self.size[1] 206 | 207 | return th.exp(-1 * ((self.grid_x - x)**2/(2 * std_x**2) + (self.grid_y - y)**2/(2 * std_y**2))) 208 | 209 | class SharedObjectsToBatch(nn.Module): 210 | def __init__(self, num_objects): 211 | super(SharedObjectsToBatch, self).__init__() 212 | 213 | self.num_objects = num_objects 214 | 215 | def forward(self, input: th.Tensor): 216 | return rearrange(input, 'b (o c) h w -> (b o) c h w', o=self.num_objects) 217 | 218 | class BatchToSharedObjects(nn.Module): 219 | def __init__(self, num_objects): 220 | super(BatchToSharedObjects, self).__init__() 221 | 222 | self.num_objects = num_objects 223 | 224 | def forward(self, input: th.Tensor): 225 | return rearrange(input, '(b o) c h w -> b (o c) h w', o=self.num_objects) 226 | 227 | class LambdaModule(nn.Module): 228 | def __init__(self, lambd): 229 | super().__init__() 230 | import types 231 | assert type(lambd) is types.LambdaType 232 | self.lambd = lambd 233 | 234 | def forward(self, x): 235 | return self.lambd(x) 236 | 237 | class PrintGradientFunction(Function): 238 | @staticmethod 239 | def forward(ctx, tensor, msg): 240 | ctx.msg = msg 241 | return tensor 242 | 243 | @staticmethod 244 | def backward(ctx, grad_output): 245 | grad_input = grad_output.clone() 246 | print(f"{ctx.msg}: {th.mean(grad_output).item()} +- {th.std(grad_output).item()}") 247 | return grad_input, None 248 | 249 | class PrintGradient(nn.Module): 250 | def __init__(self, msg = "PrintGradient"): 251 | super(PrintGradient, self).__init__() 252 | 253 | self.fcn = PrintGradientFunction.apply 254 | self.msg = msg 255 | 256 | def forward(self, input: th.Tensor): 257 | return self.fcn(input, self.msg) 258 | 259 | class Prioritize(nn.Module): 260 | def __init__(self, num_objects): 261 | super(Prioritize, self).__init__() 262 | 263 | self.num_objects = num_objects 264 | self.to_batch = SharedObjectsToBatch(num_objects) 265 | 266 | def forward(self, input: th.Tensor, priority: th.Tensor): 267 | 268 | if priority is None: 269 | return input 270 | 271 | batch_size = input.shape[0] 272 | weights = th.zeros((batch_size, self.num_objects, self.num_objects, 1, 1), device=input.device) 273 | 274 | for o in range(self.num_objects): 275 | weights[:,o,:,0,0] = th.sigmoid(priority[:,:] - priority[:,o:o+1]) 276 | 
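# Zero the diagonal so an object never suppresses its own channel in the priority-weighted convolution below.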
weights[:,o,o,0,0] = weights[:,o,o,0,0] * 0 277 | 278 | input = rearrange(input, 'b c h w -> 1 (b c) h w') 279 | weights = rearrange(weights, 'b o i 1 1 -> (b o) i 1 1') 280 | 281 | output = th.relu(input - nn.functional.conv2d(input, weights, groups=batch_size)) 282 | output = rearrange(output, '1 (b c) h w -> b c h w ', b=batch_size) 283 | 284 | return output 285 | --------------------------------------------------------------------------------
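A minimal usage sketch for the scheduled-sampling helpers in utils/scheduled_sampling.py (not part of the repository): the import follows the repo's `utils.*` convention, and the slope, threshold and per-step error values are illustrative assumptions.

```python
# Illustrative sketch only: drives LinearSampler from utils/scheduled_sampling.py during a rollout.
from utils.scheduled_sampling import LinearSampler

sampler = LinearSampler(slope=1e-4, error_threshold=0.1, min_value=0.05, max_seq=8)

for t in range(100):
    prediction_error = 0.08                  # hypothetical per-step error of the model
    use_model_output = sampler.sample()      # True -> feed the model's own prediction back as input
    # ... run one teacher-forced or closed-loop step here ...
    sampler.step(error=prediction_error)     # advances the schedule when error < error_threshold (or threshold <= 0)

print(sampler.probability())                 # current value of the linear schedule, floored at min_value
```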