├── .gitignore ├── .runconfigs ├── Dockerfile ├── FrEIA ├── __init__.py ├── __pycache__ │ └── __init__.cpython-38.pyc ├── framework │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── graph_inn.cpython-38.pyc │ │ ├── reversible_graph_net.cpython-38.pyc │ │ ├── reversible_sequential_net.cpython-38.pyc │ │ └── sequence_inn.cpython-38.pyc │ ├── graph_inn.py │ ├── reversible_graph_net.py │ ├── reversible_sequential_net.py │ └── sequence_inn.py └── modules │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── all_in_one_block.cpython-38.pyc │ ├── base.cpython-38.pyc │ ├── coupling_layers.cpython-38.pyc │ ├── fixed_transforms.cpython-38.pyc │ ├── gaussian_mixture.cpython-38.pyc │ ├── graph_topology.cpython-38.pyc │ ├── inv_auto_layers.cpython-38.pyc │ ├── invertible_resnet.cpython-38.pyc │ ├── orthogonal.cpython-38.pyc │ └── reshapes.cpython-38.pyc │ ├── all_in_one_block.py │ ├── base.py │ ├── coupling_layers.py │ ├── fixed_transforms.py │ ├── gaussian_mixture.py │ ├── graph_topology.py │ ├── inv_auto_layers.py │ ├── invertible_resnet.py │ ├── orthogonal.py │ └── reshapes.py ├── LICENSE ├── README.md ├── config.py ├── dummy_dataset └── dummy_class │ ├── test │ ├── anomaly │ │ ├── .directory │ │ ├── 1.png │ │ ├── 2.png │ │ ├── 3png.png │ │ ├── 4.png │ │ ├── 5.png │ │ ├── 6.png │ │ ├── 7.png │ │ └── 8png.png │ └── good │ │ ├── .directory │ │ ├── 1.png │ │ ├── 2.png │ │ ├── 3.png │ │ ├── 4.png │ │ ├── 5.png │ │ ├── 6.png │ │ ├── 7.png │ │ └── 8.png │ └── train │ └── good │ ├── 1.png │ ├── 2.png │ ├── 3.png │ └── 4.png ├── evaluate.py ├── handledata.py ├── localization.py ├── main.py ├── model.py ├── multi_transform_loader.py ├── requirements.txt ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | /data/ 2 | /dataset/ 3 | /models/alexnet/hub/ 4 | /models/ 5 | /weights/ 6 | /pretrained/ 7 | /mtd_train_filenames.txt 8 | /FrEIA_github/ 9 | /neptuneparams.py 10 | -------------------------------------------------------------------------------- /.runconfigs: -------------------------------------------------------------------------------- 1 | # NVIDIA-Docker run options 2 | docker_args: -v /home/galluccio/datasets/private:/exp/data 3 | 4 | # Run script options 5 | # To run the container multiple times 6 | num_iters: 1 7 | # The name of containers (remember you have to prefix it with '{user}_') 8 | container_name: {user}_Differnet_test_GPU{args.gpu}_{date} 9 | # The docker image to use 10 | image_name: galluccio/fast-flow-1 11 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | #FROM airlab404/dl:cuda10_pytorch_py36 2 | FROM python:3.8 3 | 4 | WORKDIR /exp 5 | 6 | # Install extras 7 | #COPY requirements.txt requirements.txt 8 | COPY . . 9 | RUN pip install -r requirements.txt 10 | #COPY . . 11 | CMD ["python", "main.py"] 12 | 13 | 14 | -------------------------------------------------------------------------------- /FrEIA/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Framework for Easily Invertible Architectures. 3 | Module to construct invertible networks with pytorch, based on a graph 4 | structure of operations. 5 | """ 6 | from . import framework 7 | from . 
import modules 8 | 9 | __all__ = ["framework", "modules"] 10 | -------------------------------------------------------------------------------- /FrEIA/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/framework/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The framework module contains the logic used in building the graph and 3 | inferring the order that the nodes have to be executed in forward and backward 4 | direction. 5 | """ 6 | 7 | from .graph_inn import * 8 | from .sequence_inn import * 9 | from .reversible_graph_net import * 10 | from .reversible_sequential_net import * 11 | 12 | __all__ = [ 13 | 'SequenceINN', 14 | 'ReversibleSequential', 15 | 'GraphINN', 16 | 'ReversibleGraphNet', 17 | 'Node', 18 | 'InputNode', 19 | 'ConditionNode', 20 | 'OutputNode' 21 | ] 22 | -------------------------------------------------------------------------------- /FrEIA/framework/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/framework/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/framework/__pycache__/graph_inn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/framework/__pycache__/graph_inn.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/framework/__pycache__/reversible_graph_net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/framework/__pycache__/reversible_graph_net.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/framework/__pycache__/reversible_sequential_net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/framework/__pycache__/reversible_sequential_net.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/framework/__pycache__/sequence_inn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/framework/__pycache__/sequence_inn.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/framework/graph_inn.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import deque, defaultdict 3 | from typing import List, Tuple, Iterable, Union, Optional 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from torch import Tensor 9 | 10 | from ..modules.base import InvertibleModule 11 | 12 | 13 | class 
Node: 14 | """ 15 | The Node class represents one transformation in the graph, with an 16 | arbitrary number of in- and outputs. 17 | 18 | The user specifies the input, and the underlying module computes the 19 | number of outputs. 20 | """ 21 | 22 | def __init__(self, inputs: Union["Node", Tuple["Node", int], 23 | Iterable[Tuple["Node", int]]], 24 | module_type, module_args: dict, conditions=None, name=None): 25 | if conditions is None: 26 | conditions = [] 27 | 28 | if name: 29 | self.name = name 30 | else: 31 | self.name = hex(id(self))[-6:] 32 | self.inputs = self.parse_inputs(inputs) 33 | if isinstance(conditions, (list, tuple)): 34 | self.conditions = conditions 35 | else: 36 | self.conditions = [conditions, ] 37 | 38 | self.outputs: List[Tuple[Node, int]] = [] 39 | self.module_type = module_type 40 | self.module_args = module_args 41 | 42 | input_shapes = [input_node.output_dims[node_out_idx] 43 | for input_node, node_out_idx in self.inputs] 44 | condition_shapes = [cond_node.output_dims[0] 45 | for cond_node in self.conditions] 46 | 47 | self.input_dims = input_shapes 48 | self.condition_dims = condition_shapes 49 | self.module, self.output_dims = self.build_module(condition_shapes, 50 | input_shapes) 51 | 52 | # Notify preceding nodes that their output ends up here 53 | # Entry at position co -> (n, ci) means: 54 | # My output co goes to input channel ci of n. 55 | for in_idx, (in_node, out_idx) in enumerate(self.inputs): 56 | in_node.outputs[out_idx] = (self, in_idx) 57 | 58 | # Enable .outX access 59 | for i in range(len(self.output_dims)): 60 | self.__dict__[f"out{i}"] = self, i 61 | self.outputs.append(None) 62 | 63 | def build_module(self, condition_shapes, input_shapes) \ 64 | -> Tuple[InvertibleModule, List[Tuple[int]]]: 65 | """ 66 | Instantiates the module and determines the output dimension by 67 | calling InvertibleModule#output_dims. 68 | """ 69 | if len(self.conditions) > 0: 70 | module = self.module_type(input_shapes, dims_c=condition_shapes, 71 | **self.module_args) 72 | else: 73 | module = self.module_type(input_shapes, **self.module_args) 74 | return module, module.output_dims(input_shapes) 75 | 76 | def parse_inputs(self, inputs: Union["Node", Tuple["Node", int], 77 | Iterable[Tuple["Node", int]]]) \ 78 | -> List[Tuple["Node", int]]: 79 | """ 80 | Converts specified inputs to a node to a canonical format. 81 | Inputs can be specified in three forms: 82 | 83 | - a single node, then this nodes first output is taken as input 84 | - a single tuple (node, idx), specifying output idx of node 85 | - a list of tuples [(node, idx)], each specifying output idx of node 86 | 87 | All such formats are converted to the last format. 
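For example (an illustrative sketch only, where ``node_a`` stands for some previously constructed node): ``node_a`` alone is parsed as ``[(node_a, 0)]``, ``(node_a, 1)`` as ``[(node_a, 1)]``, and ``[(node_a, 0), (node_a, 1)]`` is already in canonical form.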
88 | """ 89 | if isinstance(inputs, (list, tuple)): 90 | if len(inputs) == 0: 91 | return inputs 92 | elif isinstance(inputs[0], (list, tuple)): 93 | return inputs 94 | elif len(inputs) == 2: 95 | return [inputs, ] 96 | else: 97 | raise ValueError( 98 | f"Cannot parse inputs provided to node '{self.name}'.") 99 | else: 100 | if not isinstance(inputs, Node): 101 | raise TypeError(f"Received object of invalid type " 102 | f"({type(inputs)}) as input for node " 103 | f"'{self.name}'.") 104 | return [(inputs, 0), ] 105 | 106 | def __str__(self): 107 | module_hint = (self.module_type.__name__ if self.module_type is not None 108 | else "") 109 | name_hint = f" {self.name!r}" if self.name is not None else "" 110 | return f"{self.__class__.__name__}{name_hint}: {self.input_dims} -> " \ 111 | f"{module_hint} -> {self.output_dims}" 112 | 113 | def __repr__(self): 114 | name_hint = f" {self.name!r}" if self.name is not None else "" 115 | return f"{self.__class__.__name__}{name_hint}" 116 | 117 | 118 | class InputNode(Node): 119 | """ 120 | Special type of node that represents the input data of the whole net (or the 121 | output when running in reverse). 122 | """ 123 | 124 | def __init__(self, *dims: int, name=None): 125 | self.dims = dims 126 | super().__init__([], None, {}, name=name) 127 | 128 | def build_module(self, condition_shapes, input_shapes) \ 129 | -> Tuple[None, List[Tuple[int]]]: 130 | if len(condition_shapes) > 0: 131 | raise ValueError( 132 | f"{self.__class__.__name__} does not accept conditions") 133 | assert len(input_shapes) == 0, "Forbidden by constructor" 134 | return None, [self.dims] 135 | 136 | 137 | class ConditionNode(Node): 138 | """ 139 | Special type of node that represents conditional input to the internal 140 | networks inside coupling layers. 141 | """ 142 | 143 | def __init__(self, *dims: int, name=None): 144 | self.dims = dims 145 | super().__init__([], None, {}, name=name) 146 | self.outputs: List[Tuple[Node, int]] = [] 147 | 148 | def build_module(self, condition_shapes, input_shapes) \ 149 | -> Tuple[None, List[Tuple[int]]]: 150 | if len(condition_shapes) > 0: 151 | raise ValueError( 152 | f"{self.__class__.__name__} does not accept conditions") 153 | assert len(input_shapes) == 0, "Forbidden by constructor" 154 | return None, [self.dims] 155 | 156 | 157 | class OutputNode(Node): 158 | """ 159 | Special type of node that represents the output of the whole net (or the 160 | input when running in reverse). 161 | """ 162 | 163 | def __init__(self, in_node: Union[Node, Tuple[Node, int]], name=None): 164 | super().__init__(in_node, None, {}, name=name) 165 | 166 | def build_module(self, condition_shapes, input_shapes) \ 167 | -> Tuple[None, List[Tuple[int]]]: 168 | if len(condition_shapes) > 0: 169 | raise ValueError( 170 | f"{self.__class__.__name__} does not accept conditions") 171 | if len(input_shapes) != 1: 172 | raise ValueError(f"Output node received {len(input_shapes)} inputs, " 173 | f"but only a single input is allowed.") 174 | return None, [] 175 | 176 | 177 | class GraphINN(InvertibleModule): 178 | """ 179 | This class represents the invertible net itself. It is a subclass of 180 | InvertibleModule and supports the same methods. 181 | 182 | The forward method has an additional option 'rev', with which the net can be 183 | computed in reverse. Passing `jac` to the forward method additionally 184 | computes the log determinant of the (inverse) Jacobian of the forward 185 | (backward) pass.
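A minimal construction sketch (illustrative only; ``subnet_fc`` is a hypothetical subnetwork constructor of the kind described in the coupling block docstrings): ``in_node = InputNode(3, 32, 32, name='in'); cb = Node(in_node, GLOWCouplingBlock, {'subnet_constructor': subnet_fc}, name='cb'); out_node = OutputNode(cb, name='out'); inn = GraphINN([in_node, cb, out_node])``. Then ``z, log_jac_det = inn(x)`` computes the forward pass and ``x_rec, _ = inn(z, rev=True)`` recovers the input.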
186 | """ 187 | 188 | def __init__(self, node_list, force_tuple_output=False, verbose=False): 189 | # Gather lists of input, output and condition nodes 190 | in_nodes = [node_list[i] for i in range(len(node_list)) 191 | if isinstance(node_list[i], InputNode)] 192 | out_nodes = [node_list[i] for i in range(len(node_list)) 193 | if isinstance(node_list[i], OutputNode)] 194 | condition_nodes = [node_list[i] for i in range(len(node_list)) if 195 | isinstance(node_list[i], ConditionNode)] 196 | 197 | # Check that all nodes are in the list 198 | for node in node_list: 199 | for in_node, idx in node.inputs: 200 | if in_node not in node_list: 201 | raise ValueError(f"{node} gets input from {in_node}, " 202 | f"but the latter is not in the node_list " 203 | f"passed to GraphINN.") 204 | for out_node, idx in node.outputs: 205 | if out_node not in node_list: 206 | raise ValueError(f"{out_node} gets input from {node}, " 207 | f"but it's not in the node_list " 208 | f"passed to GraphINN.") 209 | 210 | # Build the graph and tell nodes about their dimensions so that they can 211 | # build the modules 212 | node_list = topological_order(node_list, in_nodes, out_nodes) 213 | global_in_shapes = [node.output_dims[0] for node in in_nodes] 214 | global_out_shapes = [node.input_dims[0] for node in out_nodes] 215 | global_cond_shapes = [node.output_dims[0] for node in condition_nodes] 216 | 217 | # Only now can we set the output shapes 218 | super().__init__(global_in_shapes, global_cond_shapes) 219 | self.node_list = node_list 220 | 221 | # Now we can store everything -- before calling super constructor, 222 | # nn.Module doesn't allow assigning anything 223 | self.in_nodes = in_nodes 224 | self.condition_nodes = condition_nodes 225 | self.out_nodes = out_nodes 226 | 227 | self.global_out_shapes = global_out_shapes 228 | self.force_tuple_output = force_tuple_output 229 | self.module_list = nn.ModuleList([n.module for n in node_list 230 | if n.module is not None]) 231 | 232 | if verbose: 233 | print(self) 234 | 235 | def output_dims(self, input_dims: List[Tuple[int]]) -> List[Tuple[int]]: 236 | if len(self.global_out_shapes) == 1 and not self.force_tuple_output: 237 | raise ValueError("You can only call output_dims on a " 238 | "GraphINN with more than one output " 239 | "or when setting force_tuple_output=True.") 240 | return self.global_out_shapes 241 | 242 | def forward(self, x_or_z: Union[Tensor, Iterable[Tensor]], 243 | c: Iterable[Tensor] = None, rev: bool = False, jac: bool = True, 244 | intermediate_outputs: bool = False, x: None = None) \ 245 | -> Tuple[Tuple[Tensor], Tensor]: 246 | """ 247 | Forward or backward computation of the whole net. 248 | """ 249 | if x is not None: 250 | x_or_z = x 251 | warnings.warn("You called GraphINN(x=...).
x is now called x_or_z, " 252 | "please pass input as positional argument.") 253 | 254 | if torch.is_tensor(x_or_z): 255 | x_or_z = x_or_z, 256 | if torch.is_tensor(c): 257 | c = c, 258 | 259 | jacobian = torch.zeros(x_or_z[0].shape[0]).to(x_or_z[0]) 260 | outs = {} 261 | jacobian_dict = {} if jac else None 262 | 263 | # Explicitly set conditions and starts 264 | start_nodes = self.out_nodes if rev else self.in_nodes 265 | if len(x_or_z) != len(start_nodes): 266 | raise ValueError(f"Got {len(x_or_z)} inputs, but expected " 267 | f"{len(start_nodes)}.") 268 | for tensor, start_node in zip(x_or_z, start_nodes): 269 | outs[start_node, 0] = tensor 270 | 271 | if c is None: 272 | c = [] 273 | if len(c) != len(self.condition_nodes): 274 | raise ValueError(f"Got {len(c)} conditions, but expected " 275 | f"{len(self.condition_nodes)}.") 276 | for tensor, condition_node in zip(c, self.condition_nodes): 277 | outs[condition_node, 0] = tensor 278 | 279 | # Go backwards through nodes if rev=True 280 | for node in self.node_list[::-1 if rev else 1]: 281 | # Skip all special nodes 282 | if node in self.in_nodes + self.out_nodes + self.condition_nodes: 283 | continue 284 | 285 | has_condition = len(node.conditions) > 0 286 | 287 | mod_in = [] 288 | mod_c = [] 289 | for prev_node, channel in (node.outputs if rev else node.inputs): 290 | mod_in.append(outs[prev_node, channel]) 291 | for cond_node in node.conditions: 292 | mod_c.append(outs[cond_node, 0]) 293 | mod_in = tuple(mod_in) 294 | mod_c = tuple(mod_c) 295 | 296 | try: 297 | if has_condition: 298 | mod_out = node.module(mod_in, c=mod_c, rev=rev, jac=jac) 299 | else: 300 | mod_out = node.module(mod_in, rev=rev, jac=jac) 301 | except Exception as e: 302 | raise RuntimeError(f"{node} encountered an error.") from e 303 | 304 | out, mod_jac = self._check_output(node, mod_out, jac, rev) 305 | 306 | for out_idx, out_value in enumerate(out): 307 | outs[node, out_idx] = out_value 308 | 309 | if jac: 310 | jacobian = jacobian + mod_jac 311 | jacobian_dict[node] = mod_jac 312 | 313 | for out_node in (self.in_nodes if rev else self.out_nodes): 314 | # This copies the one input of the out node 315 | outs[out_node, 0] = outs[(out_node.outputs if rev 316 | else out_node.inputs)[0]] 317 | 318 | if intermediate_outputs: 319 | return outs, jacobian_dict 320 | else: 321 | out_list = [outs[out_node, 0] for out_node 322 | in (self.in_nodes if rev else self.out_nodes)] 323 | if len(out_list) == 1 and not self.force_tuple_output: 324 | return out_list[0], jacobian 325 | else: 326 | return tuple(out_list), jacobian 327 | 328 | def _check_output(self, node, mod_out, jac, rev): 329 | if torch.is_tensor(mod_out): 330 | raise ValueError( 331 | f"The node {node}'s module returned a tensor only. This " 332 | f"is deprecated without fallback. Please follow the " 333 | f"signature of InvertibleOperator#forward in your module " 334 | f"if you want to use it in a GraphINN.") 335 | 336 | if len(mod_out) != 2: 337 | raise ValueError( 338 | f"The node {node}'s module returned a tuple of length " 339 | f"{len(mod_out)}, but should return a tuple `z_or_x, jac`.") 340 | 341 | out, mod_jac = mod_out 342 | 343 | if torch.is_tensor(out): 344 | raise ValueError(f"The node {node}'s module returns a tensor. 
" 345 | f"This is deprecated.") 346 | 347 | if len(out) != len(node.inputs if rev else node.outputs): 348 | raise ValueError( 349 | f"The node {node}'s module returned {len(out)} output " 350 | f"variables, but should return " 351 | f"{len(node.inputs if rev else node.outputs)}.") 352 | 353 | if not torch.is_tensor(mod_jac): 354 | if isinstance(mod_jac, (float, int)): 355 | mod_jac = torch.zeros(out[0].shape[0]).to(out[0].device) \ 356 | + mod_jac 357 | elif jac: 358 | raise ValueError( 359 | f"The node {node}'s module returned a non-tensor as " 360 | f"Jacobian: {mod_jac}") 361 | elif not jac and mod_jac is not None: 362 | raise ValueError( 363 | f"The node {node}'s module returned neither None nor a " 364 | f"Jacobian: {mod_jac}") 365 | return out, mod_jac 366 | 367 | def log_jacobian_numerical(self, x, c=None, rev=False, h=1e-04): 368 | """ 369 | Approximate log Jacobian determinant via finite differences. 370 | """ 371 | if isinstance(x, (list, tuple)): 372 | batch_size = x[0].shape[0] 373 | ndim_x_separate = [np.prod(x_i.shape[1:]) for x_i in x] 374 | ndim_x_total = sum(ndim_x_separate) 375 | x_flat = torch.cat([x_i.view(batch_size, -1) for x_i in x], dim=1) 376 | else: 377 | batch_size = x.shape[0] 378 | ndim_x_total = np.prod(x.shape[1:]) 379 | x_flat = x.reshape(batch_size, -1) 380 | 381 | J_num = torch.zeros(batch_size, ndim_x_total, ndim_x_total) 382 | for i in range(ndim_x_total): 383 | offset = x[0].new_zeros(batch_size, ndim_x_total) 384 | offset[:, i] = h 385 | if isinstance(x, (list, tuple)): 386 | x_upper = torch.split(x_flat + offset, ndim_x_separate, dim=1) 387 | x_upper = [x_upper[i].view(*x[i].shape) for i in range(len(x))] 388 | x_lower = torch.split(x_flat - offset, ndim_x_separate, dim=1) 389 | x_lower = [x_lower[i].view(*x[i].shape) for i in range(len(x))] 390 | else: 391 | x_upper = (x_flat + offset).view(*x.shape) 392 | x_lower = (x_flat - offset).view(*x.shape) 393 | y_upper, _ = self.forward(x_upper, c=c, rev=rev, jac=False) 394 | y_lower, _ = self.forward(x_lower, c=c, rev=rev, jac=False) 395 | if isinstance(y_upper, (list, tuple)): 396 | y_upper = torch.cat( 397 | [y_i.view(batch_size, -1) for y_i in y_upper], dim=1) 398 | y_lower = torch.cat( 399 | [y_i.view(batch_size, -1) for y_i in y_lower], dim=1) 400 | J_num[:, :, i] = (y_upper - y_lower).view(batch_size, -1) / (2 * h) 401 | logdet_num = x[0].new_zeros(batch_size) 402 | for i in range(batch_size): 403 | logdet_num[i] = torch.slogdet(J_num[i])[1] 404 | 405 | return logdet_num 406 | 407 | def get_node_by_name(self, name) -> Optional[Node]: 408 | """ 409 | Return the first node in the graph with the provided name. 410 | """ 411 | for node in self.node_list: 412 | if node.name == name: 413 | return node 414 | return None 415 | 416 | def get_module_by_name(self, name) -> Optional[nn.Module]: 417 | """ 418 | Return module of the first node in the graph with the provided name. 419 | """ 420 | node = self.get_node_by_name(name) 421 | try: 422 | return node.module 423 | except AttributeError: 424 | return None 425 | 426 | 427 | def topological_order(all_nodes: List[Node], in_nodes: List[InputNode], 428 | out_nodes: List[OutputNode]) -> List[Node]: 429 | """ 430 | Computes the topological order of nodes. 431 | 432 | Parameters: 433 | all_nodes: All nodes in the computation graph. 
434 | in_nodes: Input nodes (must also be present in `all_nodes`) 435 | out_nodes: Output nodes (must also be present in `all_nodes`) 436 | 437 | Returns: 438 | A sorted list of nodes, where the inputs to some node in the list 439 | are available when all previous nodes in the list have been executed. 440 | """ 441 | # Edge dicts in both directions 442 | edges_out_to_in = {node_b: {node_a for node_a, out_idx in node_b.inputs} for 443 | node_b in all_nodes + out_nodes} 444 | edges_in_to_out = defaultdict(set) 445 | for node_out, node_ins in edges_out_to_in.items(): 446 | for node_in in node_ins: 447 | edges_in_to_out[node_in].add(node_out) 448 | 449 | # Kahn's algorithm starting from the output nodes 450 | sorted_nodes = [] 451 | no_pending_edges = deque(out_nodes) 452 | 453 | while len(no_pending_edges) > 0: 454 | node = no_pending_edges.popleft() 455 | sorted_nodes.append(node) 456 | for in_node in list(edges_out_to_in[node]): 457 | edges_out_to_in[node].remove(in_node) 458 | edges_in_to_out[in_node].remove(node) 459 | 460 | if len(edges_in_to_out[in_node]) == 0: 461 | no_pending_edges.append(in_node) 462 | 463 | for in_node in in_nodes: 464 | if in_node not in sorted_nodes: 465 | raise ValueError(f"Error in graph: {in_node} is not connected " 466 | f"to any output.") 467 | 468 | if sum(map(len, edges_in_to_out.values())) == 0: 469 | return sorted_nodes[::-1] 470 | else: 471 | raise ValueError("Graph is cyclic.") 472 | -------------------------------------------------------------------------------- /FrEIA/framework/reversible_graph_net.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Union, Iterable, Tuple 3 | 4 | from torch import Tensor 5 | 6 | from .graph_inn import GraphINN 7 | 8 | 9 | class ReversibleGraphNet(GraphINN): 10 | def __init__(self, node_list, ind_in=None, ind_out=None, verbose=True, 11 | force_tuple_output=False): 12 | warnings.warn("ReversibleGraphNet is deprecated in favour of GraphINN. " 13 | "It will be removed in the next version of FrEIA_github.", 14 | DeprecationWarning) 15 | if ind_in is not None: 16 | raise ValueError( 17 | "ReversibleGraphNet's ind_in was removed in FrEIA_github v0.3.0. " 18 | "Please use InputNodes and switch to GraphINN." 19 | ) 20 | if ind_out is not None: 21 | raise ValueError( 22 | "ReversibleGraphNet's ind_out was removed in FrEIA_github v0.3.0. " 23 | "Please use OutputNodes and switch to GraphINN." 24 | ) 25 | super().__init__(node_list, verbose=verbose, 26 | force_tuple_output=force_tuple_output) 27 | 28 | def forward(self, x_or_z: Union[Tensor, Iterable[Tensor]], 29 | c: Iterable[Tensor] = None, rev: bool = False, jac: bool = True, 30 | intermediate_outputs: bool = False)\ 31 | -> Tuple[Tuple[Tensor], Tensor]: 32 | warnings.warn("ReversibleGraphNet's forward() now " 33 | "returns a tuple (output, jacobian). " 34 | "It will be removed in the next version of FrEIA_github.", 35 | DeprecationWarning) 36 | return super().forward(x_or_z, c, rev, jac, intermediate_outputs) 37 | -------------------------------------------------------------------------------- /FrEIA/framework/reversible_sequential_net.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from FrEIA.framework.sequence_inn import SequenceINN 4 | 5 | 6 | class ReversibleSequential(SequenceINN): 7 | def __init__(self, *dims: int): 8 | warnings.warn("ReversibleSequential is deprecated in favour of " 9 | "SequenceINN. 
It will be removed in a future version " 10 | "of FrEIA_github.", 11 | DeprecationWarning) 12 | super().__init__(*dims) 13 | -------------------------------------------------------------------------------- /FrEIA/framework/sequence_inn.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Tuple, List 2 | 3 | import torch.nn as nn 4 | import torch 5 | from torch import Tensor 6 | 7 | from FrEIA.modules import InvertibleModule 8 | 9 | 10 | class SequenceINN(InvertibleModule): 11 | """ 12 | Simpler than FrEIA_github.framework.GraphINN: 13 | Only supports a sequential series of modules (no splitting, merging, 14 | branching off). 15 | Has an append() method, to add new blocks in a simpler way than the 16 | computation-graph based approach of GraphINN. For example: 17 | 18 | ``` 19 | inn = SequenceINN(channels, dims_H, dims_W) 20 | 21 | for i in range(n_blocks): 22 | inn.append(FrEIA_github.modules.AllInOneBlock, clamp=2.0, permute_soft=True) 23 | inn.append(FrEIA_github.modules.HaarDownsampling) 24 | # and so on 25 | ``` 26 | """ 27 | 28 | def __init__(self, *dims: int, force_tuple_output=False): 29 | super().__init__([dims]) 30 | 31 | self.shapes = [tuple(dims)] 32 | self.conditions = [] 33 | self.module_list = nn.ModuleList() 34 | 35 | self.force_tuple_output = force_tuple_output 36 | 37 | def append(self, module_class, cond=None, cond_shape=None, **kwargs): 38 | """ 39 | Append a reversible block from FrEIA_github.modules to the network. 40 | module_class: Class from FrEIA_github.modules. 41 | cond (int): index of which condition to use (conditions will be passed as list to forward()). 42 | Conditioning nodes are not needed for SequenceINN. 43 | cond_shape (tuple[int]): the shape of the condition tensor. 44 | **kwargs: Further keyword arguments that are passed to the constructor of module_class (see example). 45 | """ 46 | 47 | dims_in = [self.shapes[-1]] 48 | self.conditions.append(cond) 49 | 50 | if cond is not None: 51 | kwargs['dims_c'] = [cond_shape] 52 | 53 | module = module_class(dims_in, **kwargs) 54 | self.module_list.append(module) 55 | output_dims = module.output_dims(dims_in) 56 | assert len(output_dims) == 1, "Module has more than one output" 57 | self.shapes.append(output_dims[0]) 58 | 59 | def __getitem__(self, item): 60 | return self.module_list.__getitem__(item) 61 | 62 | def __len__(self): 63 | return self.module_list.__len__() 64 | 65 | def __iter__(self): 66 | return self.module_list.__iter__() 67 | 68 | def output_dims(self, input_dims: List[Tuple[int]]) -> List[Tuple[int]]: 69 | if not self.force_tuple_output: 70 | raise ValueError("You can only call output_dims on a SequenceINN " 71 | "when setting force_tuple_output=True.") 72 | return input_dims 73 | 74 | def forward(self, x_or_z: Tensor, c: Iterable[Tensor] = None, 75 | rev: bool = False, jac: bool = True) -> Tuple[Tensor, Tensor]: 76 | """ 77 | Executes the sequential INN in forward or inverse (rev=True) direction. 78 | 79 | Arguments: 80 | x_or_z: input tensor (in contrast to GraphINN, a list of 81 | tensors is not supported, as SequenceINN only has 82 | one input). 83 | c: list of conditions. 84 | rev: whether to compute the network forward or reversed. 85 | jac: whether to compute the log Jacobian. 86 | 87 | Returns: 88 | z_or_x (Tensor): network output. 89 | jac (Tensor): log-jacobian-determinant.
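For example (an illustrative sketch, assuming ``inn`` was built as in the class docstring above): ``z, log_jac_det = inn(x)`` runs the forward pass, and ``x_rec, log_jac_det_rev = inn(z, rev=True)`` inverts it, with ``x_rec`` equal to ``x`` up to numerical precision and ``log_jac_det_rev`` equal to ``-log_jac_det``.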
90 | """ 91 | 92 | iterator = range(len(self.module_list)) 93 | log_det_jac = 0 94 | 95 | if rev: 96 | iterator = reversed(iterator) 97 | 98 | if torch.is_tensor(x_or_z): 99 | x_or_z = (x_or_z,) 100 | for i in iterator: 101 | if self.conditions[i] is None: 102 | x_or_z, j = self.module_list[i](x_or_z, jac=jac, rev=rev) 103 | else: 104 | x_or_z, j = self.module_list[i](x_or_z, c=[c[self.conditions[i]]], 105 | jac=jac, rev=rev) 106 | log_det_jac = j + log_det_jac 107 | 108 | return x_or_z if self.force_tuple_output else x_or_z[0], log_det_jac 109 | 110 | 111 | -------------------------------------------------------------------------------- /FrEIA/modules/__init__.py: -------------------------------------------------------------------------------- 1 | '''Subclasses of torch.nn.Module, that are reversible and can be used in the 2 | nodes of the GraphINN class. The only additional things that are 3 | needed compared to the base class are an @staticmethod output_dims, and the 4 | 'rev'-argument of the forward-method. 5 | 6 | Abstract template: 7 | 8 | * InvertibleModule 9 | 10 | Coupling blocks: 11 | 12 | * AllInOneBlock 13 | * NICECouplingBlock 14 | * RNVPCouplingBlock 15 | * GLOWCouplingBlock 16 | * GINCouplingBlock 17 | * AffineCouplingOneSided 18 | * ConditionalAffineTransform 19 | 20 | Reshaping: 21 | 22 | * IRevNetDownsampling 23 | * IRevNetUpsampling 24 | * HaarDownsampling 25 | * HaarUpsampling 26 | * Flatten 27 | * Reshape 28 | 29 | Graph topology: 30 | 31 | * Split 32 | * Concat 33 | 34 | Other learned transforms: 35 | 36 | * ActNorm 37 | * IResNetLayer 38 | * InvAutoAct 39 | * InvAutoActFixed 40 | * InvAutoActTwoSided 41 | * InvAutoConv2D 42 | * InvAutoFC 43 | * LearnedElementwiseScaling 44 | * OrthogonalTransform 45 | * HouseholderPerm 46 | 47 | Fixed (non-learned) transforms: 48 | 49 | * PermuteRandom 50 | * FixedLinearTransform 51 | * Fixed1x1Conv 52 | * InvertibleSigmoid 53 | 54 | ''' 55 | 56 | # Import the base class first 57 | from .base import * 58 | 59 | # Then all inheriting modules 60 | from .all_in_one_block import * 61 | from .fixed_transforms import * 62 | from .reshapes import * 63 | from .coupling_layers import * 64 | from .graph_topology import * 65 | from .orthogonal import * 66 | from .inv_auto_layers import * 67 | from .invertible_resnet import * 68 | from .gaussian_mixture import * 69 | 70 | __all__ = [ 71 | 'InvertibleModule', 72 | 'AllInOneBlock', 73 | 'ActNorm', 74 | 'HouseholderPerm', 75 | 'IResNetLayer', 76 | 'InvAutoAct', 77 | 'InvAutoActFixed', 78 | 'InvAutoActTwoSided', 79 | 'InvAutoConv2D', 80 | 'InvAutoFC', 81 | 'LearnedElementwiseScaling', 82 | 'NICECouplingBlock', 83 | 'RNVPCouplingBlock', 84 | 'GLOWCouplingBlock', 85 | 'GINCouplingBlock', 86 | 'AffineCouplingOneSided', 87 | 'ConditionalAffineTransform', 88 | 'PermuteRandom', 89 | 'FixedLinearTransform', 90 | 'Fixed1x1Conv', 91 | 'InvertibleSigmoid', 92 | 'SplitChannel', 93 | 'ConcatChannel', 94 | 'Split', 95 | 'Concat', 96 | 'OrthogonalTransform', 97 | 'IRevNetDownsampling', 98 | 'IRevNetUpsampling', 99 | 'HaarDownsampling', 100 | 'HaarUpsampling', 101 | 'Flatten', 102 | 'Reshape', 103 | 'GaussianMixtureModel', 104 | ] 105 | -------------------------------------------------------------------------------- /FrEIA/modules/__pycache__/__init__.cpython-38.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/modules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/modules/__pycache__/all_in_one_block.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/modules/__pycache__/all_in_one_block.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/modules/__pycache__/base.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/modules/__pycache__/base.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/modules/__pycache__/coupling_layers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/modules/__pycache__/coupling_layers.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/modules/__pycache__/fixed_transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/modules/__pycache__/fixed_transforms.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/modules/__pycache__/gaussian_mixture.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/modules/__pycache__/gaussian_mixture.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/modules/__pycache__/graph_topology.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/modules/__pycache__/graph_topology.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/modules/__pycache__/inv_auto_layers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/modules/__pycache__/inv_auto_layers.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/modules/__pycache__/invertible_resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/modules/__pycache__/invertible_resnet.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/modules/__pycache__/orthogonal.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/modules/__pycache__/orthogonal.cpython-38.pyc 
-------------------------------------------------------------------------------- /FrEIA/modules/__pycache__/reshapes.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/FrEIA/modules/__pycache__/reshapes.cpython-38.pyc -------------------------------------------------------------------------------- /FrEIA/modules/all_in_one_block.py: -------------------------------------------------------------------------------- 1 | from . import InvertibleModule 2 | 3 | import warnings 4 | from typing import Callable 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from scipy.stats import special_ortho_group 11 | 12 | 13 | class AllInOneBlock(InvertibleModule): 14 | '''Module combining the most common operations in a normalizing flow or similar model. 15 | 16 | It combines affine coupling, permutation, and global affine transformation 17 | ('ActNorm'). It can also be used as a GIN coupling block, perform learned 18 | householder permutations, and use an inverted pre-permutation. The affine 19 | transformation includes a soft clamping mechanism, first used in Real-NVP. 20 | The block as a whole performs the following computation: 21 | 22 | .. math:: 23 | 24 | y = V\\,R \\; \\Psi(s_\\mathrm{global}) \\odot \\mathrm{Coupling}\\Big(R^{-1} V^{-1} x\\Big) + t_\\mathrm{global} 25 | 26 | - The inverse pre-permutation of x (i.e. :math:`R^{-1} V^{-1}`) is optional (see 27 | ``reverse_permutation`` below). 28 | - The learned householder reflection matrix 29 | :math:`V` is also optional altogether (see ``learned_householder_permutation`` 30 | below). 31 | - For the coupling, the input is split into :math:`x_1, x_2` along 32 | the channel dimension. Then the output of the coupling operation is the 33 | two halves :math:`u = \\mathrm{concat}(u_1, u_2)`. 34 | 35 | .. math:: 36 | 37 | u_1 &= x_1 \\odot \\exp \\Big( \\alpha \\; \\mathrm{tanh}\\big( s(x_2) \\big)\\Big) + t(x_2) \\\\ 38 | u_2 &= x_2 39 | 40 | Because :math:`\\mathrm{tanh}(s) \\in [-1, 1]`, this clamping mechanism prevents 41 | exploding values in the exponential. The hyperparameter :math:`\\alpha` can be adjusted. 42 | 43 | ''' 44 | 45 | def __init__(self, dims_in, dims_c=[], 46 | subnet_constructor: Callable = None, 47 | affine_clamping: float = 2., 48 | gin_block: bool = False, 49 | global_affine_init: float = 1., 50 | global_affine_type: str = 'SOFTPLUS', 51 | permute_soft: bool = False, 52 | learned_householder_permutation: int = 0, 53 | reverse_permutation: bool = False): 54 | ''' 55 | Args: 56 | subnet_constructor: 57 | class or callable ``f``, called as ``f(channels_in, channels_out)`` and 58 | should return a torch.nn.Module. Predicts coupling coefficients :math:`s, t`. 59 | affine_clamping: 60 | clamp the output of the multiplicative coefficients before 61 | exponentiation to +/- ``affine_clamping`` (see :math:`\\alpha` above). 62 | gin_block: 63 | Turn the block into a GIN block from Sorrenson et al, 2019. 64 | Makes it so that the coupling operation as a whole is volume preserving. 65 | global_affine_init: 66 | Initial value for the global affine scaling :math:`s_\\mathrm{global}`. 67 | global_affine_type: 68 | ``'SIGMOID'``, ``'SOFTPLUS'``, or ``'EXP'``. Defines the activation to be used 69 | on the beta for the global affine scaling (:math:`\\Psi` above).
70 | permute_soft: 71 | bool, whether to sample the permutation matrix :math:`R` from :math:`SO(N)`, 72 | or to use hard permutations instead. Note, ``permute_soft=True`` is very slow 73 | when working with >512 dimensions. 74 | learned_householder_permutation: 75 | Int, if >0, turn on the matrix :math:`V` above, which represents 76 | multiple learned householder reflections. Slow for a large number of 77 | reflections. Dubious whether it actually helps network performance. 78 | reverse_permutation: 79 | Reverse the permutation before the block, as introduced by Putzky 80 | et al, 2019. Turns on the :math:`R^{-1} V^{-1}` pre-multiplication above. 81 | ''' 82 | 83 | super().__init__(dims_in, dims_c) 84 | 85 | channels = dims_in[0][0] 86 | # rank of the tensors means 1d, 2d, 3d tensor etc. 87 | self.input_rank = len(dims_in[0]) - 1 88 | # tuple containing all dims except for batch-dim (used at various points) 89 | self.sum_dims = tuple(range(1, 2 + self.input_rank)) 90 | 91 | if len(dims_c) == 0: 92 | self.conditional = False 93 | self.condition_channels = 0 94 | else: 95 | assert tuple(dims_c[0][1:]) == tuple(dims_in[0][1:]), \ 96 | F"Dimensions of input and condition don't agree: {dims_c} vs {dims_in}." 97 | self.conditional = True 98 | self.condition_channels = sum(dc[0] for dc in dims_c) 99 | 100 | split_len1 = channels - channels // 2 101 | split_len2 = channels // 2 102 | self.splits = [split_len1, split_len2] 103 | 104 | try: 105 | self.permute_function = {0: F.linear, 106 | 1: F.conv1d, 107 | 2: F.conv2d, 108 | 3: F.conv3d}[self.input_rank] 109 | except KeyError: 110 | raise ValueError(f"Data is {1 + self.input_rank}D. Must be 1D-4D.") 111 | 112 | self.in_channels = channels 113 | self.clamp = affine_clamping 114 | self.GIN = gin_block 115 | self.reverse_pre_permute = reverse_permutation 116 | self.householder = learned_householder_permutation 117 | 118 | if permute_soft and channels > 512: 119 | warnings.warn(("Soft permutation will take a very long time to initialize " 120 | f"with {channels} feature channels. Consider using hard permutation instead.")) 121 | 122 | # global_scale is used as the initial value for the global affine scale 123 | # (pre-activation). It is computed such that 124 | # global_scale_activation(global_scale) = global_affine_init 125 | # the 'magic numbers' (specifically for sigmoid) scale the activation to 126 | # a sensible range. 127 | if global_affine_type == 'SIGMOID': 128 | global_scale = 2. - np.log(10. / global_affine_init - 1.) 129 | self.global_scale_activation = (lambda a: 10 * torch.sigmoid(a - 2.)) 130 | elif global_affine_type == 'SOFTPLUS': 131 | global_scale = 2. * np.log(np.exp(0.5 * 10. * global_affine_init) - 1) 132 | self.softplus = nn.Softplus(beta=0.5) 133 | self.global_scale_activation = (lambda a: 0.1 * self.softplus(a)) 134 | elif global_affine_type == 'EXP': 135 | global_scale = np.log(global_affine_init) 136 | self.global_scale_activation = (lambda a: torch.exp(a)) 137 | else: 138 | raise ValueError('Global affine activation must be "SIGMOID", "SOFTPLUS" or "EXP"') 139 | 140 | self.global_scale = nn.Parameter(torch.ones(1, self.in_channels, *([1] * self.input_rank)) * float(global_scale)) 141 | self.global_offset = nn.Parameter(torch.zeros(1, self.in_channels, *([1] * self.input_rank))) 142 | 143 | if permute_soft: 144 | w = special_ortho_group.rvs(channels) 145 | else: 146 | w = np.zeros((channels, channels)) 147 | for i, j in enumerate(np.random.permutation(channels)): 148 | w[i, j] = 1.
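# At this point w is either a random rotation matrix drawn from SO(channels) (permute_soft=True) or a one-hot hard permutation matrix; both are orthogonal, so the transpose w.T below serves as the inverse permutation.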
149 | 150 | if self.householder: 151 | # instead of just the permutation matrix w, the learned householder 152 | # permutation keeps track of reflection vectors vk, in addition to a 153 | # random initial permutation w_0. 154 | self.vk_householder = nn.Parameter(0.2 * torch.randn(self.householder, channels), requires_grad=True) 155 | self.w_perm = None 156 | self.w_perm_inv = None 157 | self.w_0 = nn.Parameter(torch.FloatTensor(w), requires_grad=False) 158 | else: 159 | self.w_perm = nn.Parameter(torch.FloatTensor(w).view(channels, channels, *([1] * self.input_rank)), 160 | requires_grad=False) 161 | self.w_perm_inv = nn.Parameter(torch.FloatTensor(w.T).view(channels, channels, *([1] * self.input_rank)), 162 | requires_grad=False) 163 | 164 | if subnet_constructor is None: 165 | raise ValueError("Please supply a callable subnet_constructor " 166 | "function or object (see docstring)") 167 | self.subnet = subnet_constructor(self.splits[0] + self.condition_channels, 2 * self.splits[1]) 168 | self.last_jac = None 169 | 170 | def _construct_householder_permutation(self): 171 | '''Computes a permutation matrix from the reflection vectors that are 172 | learned internally as nn.Parameters.''' 173 | w = self.w_0 174 | for vk in self.vk_householder: 175 | w = torch.mm(w, torch.eye(self.in_channels).to(w.device) - 2 * torch.ger(vk, vk) / torch.dot(vk, vk)) 176 | 177 | for i in range(self.input_rank): 178 | w = w.unsqueeze(-1) 179 | return w 180 | 181 | def _permute(self, x, rev=False): 182 | '''Performs the permutation and scaling after the coupling operation. 183 | Returns transformed outputs and the LogJacDet of the scaling operation.''' 184 | if self.GIN: 185 | scale = 1. 186 | perm_log_jac = 0. 187 | else: 188 | scale = self.global_scale_activation(self.global_scale) 189 | perm_log_jac = torch.sum(torch.log(scale)) 190 | 191 | if rev: 192 | return ((self.permute_function(x, self.w_perm_inv) - self.global_offset) / scale, 193 | perm_log_jac) 194 | else: 195 | return (self.permute_function(x * scale + self.global_offset, self.w_perm), 196 | perm_log_jac) 197 | 198 | def _pre_permute(self, x, rev=False): 199 | '''Permutes before the coupling block, only used if 200 | reverse_permutation is set''' 201 | if rev: 202 | return self.permute_function(x, self.w_perm) 203 | else: 204 | return self.permute_function(x, self.w_perm_inv) 205 | 206 | def _affine(self, x, a, rev=False): 207 | '''Given the passive half, and the pre-activation outputs of the 208 | coupling subnetwork, perform the affine coupling operation. 209 | Returns both the transformed inputs and the LogJacDet.''' 210 | 211 | # the entire coupling coefficient tensor is scaled down by a 212 | # factor of ten for stability and easier initialization.
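# a[:, :ch] holds the multiplicative (s) pre-activations and a[:, ch:] the additive (t) components; the s part is soft-clamped via clamp * tanh(.) below before exponentiation.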
213 | a *= 0.1 214 | ch = x.shape[1] 215 | 216 | sub_jac = self.clamp * torch.tanh(a[:, :ch]) 217 | if self.GIN: 218 | sub_jac -= torch.mean(sub_jac, dim=self.sum_dims, keepdim=True) 219 | 220 | if not rev: 221 | return (x * torch.exp(sub_jac) + a[:, ch:], 222 | torch.sum(sub_jac, dim=self.sum_dims)) 223 | else: 224 | return ((x - a[:, ch:]) * torch.exp(-sub_jac), 225 | -torch.sum(sub_jac, dim=self.sum_dims)) 226 | 227 | def forward(self, x, c=[], rev=False, jac=True): 228 | '''See base class docstring''' 229 | if self.householder: 230 | self.w_perm = self._construct_householder_permutation() 231 | if rev or self.reverse_pre_permute: 232 | self.w_perm_inv = self.w_perm.transpose(0, 1).contiguous() 233 | 234 | if rev: 235 | x, global_scaling_jac = self._permute(x[0], rev=True) 236 | x = (x,) 237 | elif self.reverse_pre_permute: 238 | x = (self._pre_permute(x[0], rev=False),) 239 | 240 | x1, x2 = torch.split(x[0], self.splits, dim=1) 241 | 242 | if self.conditional: 243 | x1c = torch.cat([x1, *c], 1) 244 | else: 245 | x1c = x1 246 | 247 | if not rev: 248 | a1 = self.subnet(x1c) 249 | x2, j2 = self._affine(x2, a1) 250 | else: 251 | a1 = self.subnet(x1c) 252 | x2, j2 = self._affine(x2, a1, rev=True) 253 | 254 | log_jac_det = j2 255 | x_out = torch.cat((x1, x2), 1) 256 | 257 | if not rev: 258 | x_out, global_scaling_jac = self._permute(x_out, rev=False) 259 | elif self.reverse_pre_permute: 260 | x_out = self._pre_permute(x_out, rev=True) 261 | 262 | # add the global scaling Jacobian to the total. 263 | # trick to get the total number of non-channel dimensions: 264 | # number of elements of the first channel of the first batch member 265 | n_pixels = x_out[0, :1].numel() 266 | log_jac_det += (-1)**rev * n_pixels * global_scaling_jac 267 | 268 | return (x_out,), log_jac_det 269 | 270 | def output_dims(self, input_dims): 271 | return input_dims 272 | -------------------------------------------------------------------------------- /FrEIA/modules/base.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Iterable, List 2 | 3 | import torch.nn as nn 4 | from torch import Tensor 5 | 6 | 7 | class InvertibleModule(nn.Module): 8 | """ 9 | Base class for all invertible modules in FrEIA_github. 10 | 11 | Given ``module``, an instance of some InvertibleModule. 12 | This ``module`` shall be invertible in its input dimensions, 13 | so that the input can be recovered by applying the module 14 | in backwards mode (``rev=True``), not to be confused with 15 | ``pytorch.backward()`` which computes the gradient of an operation:: 16 | 17 | x = torch.randn(BATCH_SIZE, DIM_COUNT) 18 | c = torch.randn(BATCH_SIZE, CONDITION_DIM) 19 | 20 | # Forward mode 21 | z, jac = module([x], [c], jac=True) 22 | 23 | # Backward mode 24 | x_rev, jac_rev = module(z, [c], rev=True) 25 | 26 | The ``module`` returns :math:`\\log \\det J = \\log \\left| \\det \\frac{\\partial f}{\\partial x} \\right|` 27 | of the operation in forward mode, and 28 | :math:`-\\log | \\det J | = \\log \\left| \\det \\frac{\\partial f^{-1}}{\\partial z} \\right| = -\\log \\left| \\det \\frac{\\partial f}{\\partial x} \\right|` 29 | in backward mode (``rev=True``). 30 | 31 | Then, ``torch.allclose(x, x_rev) == True`` and ``torch.allclose(jac, -jac_rev) == True``. 
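For unconditional modules, the condition argument can typically be omitted altogether, e.g. ``z, jac = module([x])`` (an illustrative sketch; concrete subclasses may define further constructor arguments).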
32 | """ 33 | 34 | def __init__(self, dims_in: Iterable[Tuple[int]], 35 | dims_c: Iterable[Tuple[int]] = None): 36 | """ 37 | Args: 38 | dims_in: list of tuples specifying the shape of the inputs to this 39 | operator: ``dims_in = [shape_x_0, shape_x_1, ...]`` 40 | dims_c: list of tuples specifying the shape of the conditions to 41 | this operator. 42 | """ 43 | super().__init__() 44 | if dims_c is None: 45 | dims_c = [] 46 | self.dims_in = list(dims_in) 47 | self.dims_c = list(dims_c) 48 | 49 | def forward(self, x_or_z: Iterable[Tensor], c: Iterable[Tensor] = None, 50 | rev: bool = False, jac: bool = True) \ 51 | -> Tuple[Tuple[Tensor], Tensor]: 52 | """ 53 | Perform a forward (default, ``rev=False``) or backward pass (``rev=True``) 54 | through this module/operator. 55 | 56 | **Note to implementers:** 57 | 58 | - Subclasses MUST return a Jacobian when ``jac=True``, but CAN return a 59 | valid Jacobian when ``jac=False`` (not punished). The latter is only recommended 60 | if the computation of the Jacobian is trivial. 61 | - Subclasses MUST follow the convention that the returned Jacobian be 62 | consistent with the evaluation direction. Let's make this more precise: 63 | Let :math:`f` be the function that the subclass represents. Then: 64 | 65 | .. math:: 66 | 67 | J &= \\log \\det \\frac{\\partial f}{\\partial x} \\\\ 68 | -J &= \\log \\det \\frac{\\partial f^{-1}}{\\partial z}. 69 | 70 | Any subclass MUST return :math:`J` for forward evaluation (``rev=False``), 71 | and :math:`-J` for backward evaluation (``rev=True``). 72 | 73 | Args: 74 | x_or_z: input data (array-like of one or more tensors) 75 | c: conditioning data (array-like of none or more tensors) 76 | rev: perform backward pass 77 | jac: return Jacobian associated to the direction 78 | """ 79 | raise NotImplementedError( 80 | f"{self.__class__.__name__} does not provide forward(...) method") 81 | 82 | def log_jacobian(self, *args, **kwargs): 83 | '''This method is deprecated, and does nothing except raise a warning.''' 84 | raise DeprecationWarning("module.log_jacobian(...) is deprecated. " 85 | "module.forward(..., jac=True) returns a " 86 | "tuple (out, jacobian) now.") 87 | 88 | def output_dims(self, input_dims: List[Tuple[int]]) -> List[Tuple[int]]: 89 | ''' 90 | Used for shape inference during construction of the graph. MUST be 91 | implemented for each subclass of ``InvertibleModule``. 92 | 93 | Args: 94 | input_dims: A list with one entry for each input to the module. 95 | Even if the module only has one input, must be a list with one 96 | entry. Each entry is a tuple giving the shape of that input, 97 | excluding the batch dimension. For example for a module with one 98 | input, which receives a 32x32 pixel RGB image, ``input_dims`` would 99 | be ``[(3, 32, 32)]`` 100 | 101 | Returns: 102 | A list structured in the same way as ``input_dims``. Each entry 103 | represents one output of the module, and the entry is a tuple giving 104 | the shape of that output. For example if the module splits the image 105 | into a right and a left half, the return value should be 106 | ``[(3, 16, 32), (3, 16, 32)]``. It is up to the implementor of the 107 | subclass to ensure that the total number of elements in all inputs 108 | and all outputs is consistent. 
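For instance, a module that flattens its input (such as ``Flatten``) would return ``[(3 * 32 * 32,)]`` for ``input_dims = [(3, 32, 32)]`` (an illustrative example, not an exhaustive specification).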
109 | 110 | ''' 111 | raise NotImplementedError( 112 | f"{self.__class__.__name__} does not provide output_dims(...)") 113 | -------------------------------------------------------------------------------- /FrEIA/modules/coupling_layers.py: -------------------------------------------------------------------------------- 1 | from . import InvertibleModule 2 | 3 | from typing import Callable, Union 4 | 5 | import torch 6 | 7 | 8 | class _BaseCouplingBlock(InvertibleModule): 9 | '''Base class to implement various coupling schemes. It takes care of 10 | checking the dimensions, conditions, clamping mechanism, etc. 11 | Each child class only has to implement the _coupling1 and _coupling2 methods 12 | for the left and right coupling operations. 13 | (In some cases below, forward() is also overridden) 14 | ''' 15 | 16 | def __init__(self, dims_in, dims_c=[], 17 | clamp: float = 2., 18 | clamp_activation: Union[str, Callable] = "ATAN", 19 | split_len: Union[float, int] = 0.5): 20 | ''' 21 | Additional args in docstring of base class. 22 | 23 | Args: 24 | clamp: Soft clamping for the multiplicative component. The 25 | amplification or attenuation of each input dimension can be at most 26 | exp(±clamp). 27 | clamp_activation: Function to perform the clamping. String values 28 | "ATAN", "TANH", and "SIGMOID" are recognized, or a function of 29 | object can be passed. TANH behaves like the original realNVP paper. 30 | A custom function should take tensors and map -inf to -1 and +inf to +1. 31 | split_len: Specify the dimension where the data should be split. 32 | If given as int, directly indicates the split dimension. 33 | If given as float, must fulfil 0 <= split_len <= 1 and number of 34 | unchanged dimensions is set to `round(split_len * dims_in[0, 0])`. 35 | ''' 36 | 37 | super().__init__(dims_in, dims_c) 38 | 39 | self.channels = dims_in[0][0] 40 | 41 | # ndims means the rank of tensor strictly speaking. 42 | # i.e. 1D, 2D, 3D tensor, etc. 43 | self.ndims = len(dims_in[0]) 44 | 45 | if isinstance(split_len, float): 46 | if not (0 <= split_len <= 1): 47 | raise ValueError(f"Float split_len must be in range [0, 1], " 48 | f"but is: {split_len}") 49 | split_len = round(self.channels * split_len) 50 | else: 51 | if not (0 <= split_len <= self.channels): 52 | raise ValueError(f"Integer split_len must be in range " 53 | f"0 <= split_len <= {self.channels}, " 54 | f"but is: {split_len}") 55 | self.split_len1 = split_len 56 | self.split_len2 = self.channels - split_len 57 | 58 | self.clamp = clamp 59 | 60 | assert all([tuple(dims_c[i][1:]) == tuple(dims_in[0][1:]) for i in range(len(dims_c))]), \ 61 | F"Dimensions of input {dims_in} and one or more conditions {dims_c} don't agree." 62 | self.conditional = (len(dims_c) > 0) 63 | self.condition_length = sum([dims_c[i][0] for i in range(len(dims_c))]) 64 | 65 | if isinstance(clamp_activation, str): 66 | if clamp_activation == "ATAN": 67 | self.f_clamp = (lambda u: 0.636 * torch.atan(u)) 68 | elif clamp_activation == "TANH": 69 | self.f_clamp = torch.tanh 70 | elif clamp_activation == "SIGMOID": 71 | self.f_clamp = (lambda u: 2. 
* (torch.sigmoid(u) - 0.5)) 72 | else: 73 | raise ValueError(f'Unknown clamp activation "{clamp_activation}"') 74 | else: 75 | self.f_clamp = clamp_activation 76 | 77 | def forward(self, x, c=[], rev=False, jac=True): 78 | '''See base class docstring''' 79 | 80 | # notation: 81 | # x1, x2: two halves of the input 82 | # y1, y2: two halves of the output 83 | # *_c: variable with condition concatenated 84 | # j1, j2: Jacobians of the two coupling operations 85 | 86 | x1, x2 = torch.split(x[0], [self.split_len1, self.split_len2], dim=1) 87 | 88 | if not rev: 89 | x2_c = torch.cat([x2, *c], 1) if self.conditional else x2 90 | y1, j1 = self._coupling1(x1, x2_c) 91 | 92 | y1_c = torch.cat([y1, *c], 1) if self.conditional else y1 93 | y2, j2 = self._coupling2(x2, y1_c) 94 | else: 95 | # names of x and y are swapped for the reverse computation 96 | x1_c = torch.cat([x1, *c], 1) if self.conditional else x1 97 | y2, j2 = self._coupling2(x2, x1_c, rev=True) 98 | 99 | y2_c = torch.cat([y2, *c], 1) if self.conditional else y2 100 | y1, j1 = self._coupling1(x1, y2_c, rev=True) 101 | 102 | return (torch.cat((y1, y2), 1),), j1 + j2 103 | 104 | def _coupling1(self, x1, u2, rev=False): 105 | '''The first/left coupling operation in a two-sided coupling block. 106 | 107 | Args: 108 | x1 (Tensor): the 'active' half being transformed. 109 | u2 (Tensor): the 'passive' half, including the conditions, from 110 | which the transformation is computed. 111 | Returns: 112 | y1 (Tensor): same shape as x1, the transformed 'active' half. 113 | j1 (float or Tensor): the Jacobian, only has batch dimension. 114 | If the Jacobian is zero or fixed, may also return float. 115 | ''' 116 | raise NotImplementedError() 117 | 118 | def _coupling2(self, x2, u1, rev=False): 119 | '''The second/right coupling operation in a two-sided coupling block. 120 | 121 | Args: 122 | x2 (Tensor): the 'active' half being transformed. 123 | u1 (Tensor): the 'passive' half, including the conditions, from 124 | which the transformation is computed. 125 | Returns: 126 | y2 (Tensor): same shape as x2, the transformed 'active' half. 127 | j2 (float or Tensor): the Jacobian, only has batch dimension. 128 | If the Jacobian is zero or fixed, may also return float. 129 | ''' 130 | raise NotImplementedError() 131 | 132 | def output_dims(self, input_dims): 133 | '''See base class for docstring''' 134 | if len(input_dims) != 1: 135 | raise ValueError("Can only use 1 input") 136 | return input_dims 137 | 138 | 139 | class NICECouplingBlock(_BaseCouplingBlock): 140 | '''Coupling Block following the NICE (Dinh et al, 2015) design. 141 | The inputs are split into two halves. For 2D, 3D, 4D inputs, the split is 142 | performed along the channel dimension. Then, residual coefficients are 143 | predicted by two subnetworks that are added to each half in turn. 144 | ''' 145 | 146 | def __init__(self, dims_in, dims_c=[], 147 | subnet_constructor: callable = None, 148 | split_len: Union[float, int] = 0.5): 149 | ''' 150 | Additional args in docstring of base class. 151 | 152 | Args: 153 | subnet_constructor: 154 | Callable function, class, or factory object, with signature 155 | constructor(dims_in, dims_out). The result should be a torch 156 | nn.Module, that takes dims_in input channels, and dims_out output 157 | channels. See tutorial for examples. 158 | Two of these subnetworks will be initialized inside the block.
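A typical constructor might look like the following sketch (illustrative only, assuming ``import torch.nn as nn``; it is not part of this repository): ``def subnet_fc(dims_in, dims_out): return nn.Sequential(nn.Linear(dims_in, 512), nn.ReLU(), nn.Linear(512, dims_out))``.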
159 |         '''
160 |         super().__init__(dims_in, dims_c,
161 |                          clamp=0., clamp_activation=(lambda u: u),
162 |                          split_len=split_len)
163 | 
164 |         self.F = subnet_constructor(self.split_len2 + self.condition_length, self.split_len1)
165 |         self.G = subnet_constructor(self.split_len1 + self.condition_length, self.split_len2)
166 | 
167 |     def _coupling1(self, x1, u2, rev=False):
168 |         if rev:
169 |             return x1 - self.F(u2), 0.
170 |         return x1 + self.F(u2), 0.
171 | 
172 |     def _coupling2(self, x2, u1, rev=False):
173 |         if rev:
174 |             return x2 - self.G(u1), 0.
175 |         return x2 + self.G(u1), 0.
176 | 
177 | 
178 | class RNVPCouplingBlock(_BaseCouplingBlock):
179 |     '''Coupling Block following the RealNVP design (Dinh et al, 2017) with some
180 |     minor differences. The inputs are split into two halves. For 2D, 3D, 4D
181 |     inputs, the split is performed along the channel dimension. For
182 |     checkerboard-splitting, prepend an i_RevNet_downsampling module. Two affine
183 |     coupling operations are performed in turn on both halves of the input.
184 |     '''
185 | 
186 |     def __init__(self, dims_in, dims_c=[],
187 |                  subnet_constructor: Callable = None,
188 |                  clamp: float = 2.,
189 |                  clamp_activation: Union[str, Callable] = "ATAN",
190 |                  split_len: Union[float, int] = 0.5):
191 |         '''
192 |         Additional args in docstring of base class.
193 | 
194 |         Args:
195 |           subnet_constructor: function or class, with signature
196 |             constructor(dims_in, dims_out). The result should be a torch
197 |             nn.Module that takes dims_in input channels and dims_out output
198 |             channels. See tutorial for examples. Four of these subnetworks will be
199 |             initialized in the block.
200 |           clamp: Soft clamping for the multiplicative component. The
201 |             amplification or attenuation of each input dimension can be at most
202 |             exp(±clamp).
203 |           clamp_activation: Function to perform the clamping. String values
204 |             "ATAN", "TANH", and "SIGMOID" are recognized, or a function
205 |             object can be passed. TANH behaves like the original realNVP paper.
206 |             A custom function should take tensors and map -inf to -1 and +inf to +1.
207 |         '''
208 | 
209 |         super().__init__(dims_in, dims_c, clamp, clamp_activation,
210 |                          split_len=split_len)
211 | 
212 |         self.subnet_s1 = subnet_constructor(self.split_len1 + self.condition_length, self.split_len2)
213 |         self.subnet_t1 = subnet_constructor(self.split_len1 + self.condition_length, self.split_len2)
214 |         self.subnet_s2 = subnet_constructor(self.split_len2 + self.condition_length, self.split_len1)
215 |         self.subnet_t2 = subnet_constructor(self.split_len2 + self.condition_length, self.split_len1)
216 | 
217 |     def _coupling1(self, x1, u2, rev=False):
218 | 
219 |         # notation (same for _coupling2):
220 |         # x: inputs (i.e. 'x-side' when rev is False, 'z-side' when rev is True)
221 |         # y: outputs (same scheme)
222 |         # *_c: variables with condition appended
223 |         # *1, *2: left half, right half
224 |         # a: all affine coefficients
225 |         # s, t: multiplicative and additive coefficients
226 |         # j: log det Jacobian
227 | 
228 |         s2, t2 = self.subnet_s2(u2), self.subnet_t2(u2)
229 |         s2 = self.clamp * self.f_clamp(s2)
230 |         j1 = torch.sum(s2, dim=tuple(range(1, self.ndims + 1)))
231 | 
232 |         if rev:
233 |             y1 = (x1 - t2) * torch.exp(-s2)
234 |             return y1, -j1
235 |         else:
236 |             y1 = torch.exp(s2) * x1 + t2
237 |             return y1, j1
238 | 
239 |     def _coupling2(self, x2, u1, rev=False):
240 |         s1, t1 = self.subnet_s1(u1), self.subnet_t1(u1)
241 |         s1 = self.clamp * self.f_clamp(s1)
242 |         j2 = torch.sum(s1, dim=tuple(range(1, self.ndims + 1)))
243 | 
244 |         if rev:
245 |             y2 = (x2 - t1) * torch.exp(-s1)
246 |             return y2, -j2
247 |         else:
248 |             y2 = torch.exp(s1) * x2 + t1
249 |             return y2, j2
250 | 
251 | 
252 | class GLOWCouplingBlock(_BaseCouplingBlock):
253 |     '''Coupling Block following the GLOW design. Note, this is only the coupling
254 |     part itself, and does not include ActNorm, invertible 1x1 convolutions, etc.
255 |     See AllInOneBlock for a block combining these functions at once.
256 |     The only difference from the RNVPCouplingBlock is that it uses a single
257 |     subnetwork to jointly predict [s_i, t_i], instead of two separate
258 |     subnetworks. This reduces computational cost and speeds up learning.
259 |     '''
260 | 
261 |     def __init__(self, dims_in, dims_c=[],
262 |                  subnet_constructor: Callable = None,
263 |                  clamp: float = 2.,
264 |                  clamp_activation: Union[str, Callable] = "ATAN",
265 |                  split_len: Union[float, int] = 0.5):
266 |         '''
267 |         Additional args in docstring of base class.
268 | 
269 |         Args:
270 |           subnet_constructor: function or class, with signature
271 |             constructor(dims_in, dims_out). The result should be a torch
272 |             nn.Module that takes dims_in input channels and dims_out output
273 |             channels. See tutorial for examples. Two of these subnetworks will be
274 |             initialized in the block.
275 |           clamp: Soft clamping for the multiplicative component. The
276 |             amplification or attenuation of each input dimension can be at most
277 |             exp(±clamp).
278 |           clamp_activation: Function to perform the clamping. String values
279 |             "ATAN", "TANH", and "SIGMOID" are recognized, or a function
280 |             object can be passed. TANH behaves like the original realNVP paper.
281 |             A custom function should take tensors and map -inf to -1 and +inf to +1.
282 |         '''
283 | 
284 |         super().__init__(dims_in, dims_c, clamp, clamp_activation,
285 |                          split_len=split_len)
286 | 
287 |         self.subnet1 = subnet_constructor(self.split_len1 + self.condition_length, self.split_len2 * 2)
288 |         self.subnet2 = subnet_constructor(self.split_len2 + self.condition_length, self.split_len1 * 2)
289 | 
290 |     def _coupling1(self, x1, u2, rev=False):
291 | 
292 |         # notation (same for _coupling2):
293 |         # x: inputs (i.e. 'x-side' when rev is False, 'z-side' when rev is True)
294 |         # y: outputs (same scheme)
295 |         # *_c: variables with condition appended
296 |         # *1, *2: left half, right half
297 |         # a: all affine coefficients
298 |         # s, t: multiplicative and additive coefficients
299 |         # j: log det Jacobian
300 | 
301 |         a2 = self.subnet2(u2)
302 |         s2, t2 = a2[:, :self.split_len1], a2[:, self.split_len1:]
303 |         s2 = self.clamp * self.f_clamp(s2)
304 |         j1 = torch.sum(s2, dim=tuple(range(1, self.ndims + 1)))
305 | 
306 |         if rev:
307 |             y1 = (x1 - t2) * torch.exp(-s2)
308 |             return y1, -j1
309 |         else:
310 |             y1 = torch.exp(s2) * x1 + t2
311 |             return y1, j1
312 | 
313 |     def _coupling2(self, x2, u1, rev=False):
314 |         a1 = self.subnet1(u1)
315 |         s1, t1 = a1[:, :self.split_len2], a1[:, self.split_len2:]
316 |         s1 = self.clamp * self.f_clamp(s1)
317 |         j2 = torch.sum(s1, dim=tuple(range(1, self.ndims + 1)))
318 | 
319 |         if rev:
320 |             y2 = (x2 - t1) * torch.exp(-s1)
321 |             return y2, -j2
322 |         else:
323 |             y2 = torch.exp(s1) * x2 + t1
324 |             return y2, j2
325 | 
326 | 
327 | class GINCouplingBlock(_BaseCouplingBlock):
328 |     '''Coupling Block following the GIN design. The difference from
329 |     GLOWCouplingBlock (and other affine coupling blocks) is that the Jacobian
330 |     determinant is constrained to be 1. This constrains the block to be
331 |     volume-preserving. Volume preservation is achieved by subtracting the mean
332 |     of the output of the s subnetwork from itself. While volume preserving, GIN
333 |     is still more powerful than NICE, as GIN is not volume preserving within
334 |     each dimension.
335 |     Note: this implementation differs slightly from the originally published
336 |     implementation, which scales the final component of the s subnetwork so the
337 |     sum of the outputs of s is zero. There was no difference found between the
338 |     implementations in practice, but subtracting the mean guarantees that all
339 |     outputs of s are at most ±exp(clamp), which might be more stable in certain
340 |     cases.
341 |     '''
342 |     def __init__(self, dims_in, dims_c=[],
343 |                  subnet_constructor: Callable = None,
344 |                  clamp: float = 2.,
345 |                  clamp_activation: Union[str, Callable] = "ATAN",
346 |                  split_len: Union[float, int] = 0.5):
347 |         '''
348 |         Additional args in docstring of base class.
349 | 
350 |         Args:
351 |           subnet_constructor: function or class, with signature
352 |             constructor(dims_in, dims_out). The result should be a torch
353 |             nn.Module that takes dims_in input channels and dims_out output
354 |             channels. See tutorial for examples. Two of these subnetworks will be
355 |             initialized in the block.
356 |           clamp: Soft clamping for the multiplicative component. The
357 |             amplification or attenuation of each input dimension can be at most
358 |             exp(±clamp).
359 |           clamp_activation: Function to perform the clamping. String values
360 |             "ATAN", "TANH", and "SIGMOID" are recognized, or a function
361 |             object can be passed. TANH behaves like the original realNVP paper.
362 |             A custom function should take tensors and map -inf to -1 and +inf to +1.
363 |         '''
364 | 
365 |         super().__init__(dims_in, dims_c, clamp, clamp_activation,
366 |                          split_len=split_len)
367 | 
368 |         self.subnet1 = subnet_constructor(self.split_len1 + self.condition_length, self.split_len2 * 2)
369 |         self.subnet2 = subnet_constructor(self.split_len2 + self.condition_length, self.split_len1 * 2)
370 | 
371 |     def _coupling1(self, x1, u2, rev=False):
372 | 
373 |         # notation (same for _coupling2):
374 |         # x: inputs (i.e. 'x-side' when rev is False, 'z-side' when rev is True)
375 |         # y: outputs (same scheme)
376 |         # *_c: variables with condition appended
377 |         # *1, *2: left half, right half
378 |         # a: all affine coefficients
379 |         # s, t: multiplicative and additive coefficients
380 |         # j: log det Jacobian
381 | 
382 |         a2 = self.subnet2(u2)
383 |         s2, t2 = a2[:, :self.split_len1], a2[:, self.split_len1:]
384 |         s2 = self.clamp * self.f_clamp(s2)
385 |         s2 -= s2.mean(1, keepdim=True)
386 | 
387 |         if rev:
388 |             y1 = (x1 - t2) * torch.exp(-s2)
389 |             return y1, 0.
390 |         else:
391 |             y1 = torch.exp(s2) * x1 + t2
392 |             return y1, 0.
393 | 
394 |     def _coupling2(self, x2, u1, rev=False):
395 |         a1 = self.subnet1(u1)
396 |         s1, t1 = a1[:, :self.split_len2], a1[:, self.split_len2:]
397 |         s1 = self.clamp * self.f_clamp(s1)
398 |         s1 -= s1.mean(1, keepdim=True)
399 | 
400 |         if rev:
401 |             y2 = (x2 - t1) * torch.exp(-s1)
402 |             return y2, 0.
403 |         else:
404 |             y2 = torch.exp(s1) * x2 + t1
405 |             return y2, 0.
406 | 
407 | 
408 | class AffineCouplingOneSided(_BaseCouplingBlock):
409 |     '''Half of a coupling block following the GLOWCouplingBlock design. This
410 |     means only one affine transformation on half the inputs. In the case where
411 |     random permutations or orthogonal transforms are used after every block,
412 |     this is not a restriction and simplifies the design.'''
413 | 
414 |     def __init__(self, dims_in, dims_c=[],
415 |                  subnet_constructor: Callable = None,
416 |                  clamp: float = 2.,
417 |                  clamp_activation: Union[str, Callable] = "ATAN",
418 |                  split_len: Union[float, int] = 0.5):
419 |         '''
420 |         Additional args in docstring of base class.
421 | 
422 |         Args:
423 |           subnet_constructor: function or class, with signature
424 |             constructor(dims_in, dims_out). The result should be a torch
425 |             nn.Module that takes dims_in input channels and dims_out output
426 |             channels. See tutorial for examples. One subnetwork will be
427 |             initialized in the block.
428 |           clamp: Soft clamping for the multiplicative component. The
429 |             amplification or attenuation of each input dimension can be at most
430 |             exp(±clamp).
431 |           clamp_activation: Function to perform the clamping. String values
432 |             "ATAN", "TANH", and "SIGMOID" are recognized, or a function
433 |             object can be passed. TANH behaves like the original realNVP paper.
434 |             A custom function should take tensors and map -inf to -1 and +inf to +1.
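            A minimal usage sketch (``subnet_fc`` as in the NICE example above;
            the input size is an illustrative assumption):

                block = AffineCouplingOneSided([(8,)], subnet_constructor=subnet_fc)
                (y,), log_jac = block([x])             # x: tensor of shape (batch, 8)
                (x_rec,), _ = block([y], rev=True)     # recovers x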
435 |         '''
436 | 
437 |         super().__init__(dims_in, dims_c, clamp, clamp_activation,
438 |                          split_len=split_len)
439 |         self.subnet = subnet_constructor(self.split_len1 + self.condition_length, 2 * self.split_len2)
440 | 
441 |     def forward(self, x, c=[], rev=False, jac=True):
442 |         x1, x2 = torch.split(x[0], [self.split_len1, self.split_len2], dim=1)
443 |         x1_c = torch.cat([x1, *c], 1) if self.conditional else x1
444 | 
445 |         # notation:
446 |         # x1, x2: two halves of the input
447 |         # y1, y2: two halves of the output
448 |         # a: all affine coefficients
449 |         # s, t: multiplicative and additive coefficients
450 |         # j: log det Jacobian
451 | 
452 |         a = self.subnet(x1_c)
453 |         s, t = a[:, :self.split_len2], a[:, self.split_len2:]
454 |         s = self.clamp * self.f_clamp(s)
455 |         j = torch.sum(s, dim=tuple(range(1, self.ndims + 1)))
456 | 
457 |         if rev:
458 |             y2 = (x2 - t) * torch.exp(-s)
459 |             j *= -1
460 |         else:
461 |             y2 = x2 * torch.exp(s) + t
462 | 
463 |         return (torch.cat((x1, y2), 1),), j
464 | 
465 | 
466 | class ConditionalAffineTransform(_BaseCouplingBlock):
467 |     '''Similar to the conditioning layers from SPADE (Park et al, 2019): Perform
468 |     an affine transformation on the whole input, where the affine coefficients
469 |     are predicted from only the condition.
470 |     '''
471 | 
472 |     def __init__(self, dims_in, dims_c=[],
473 |                  subnet_constructor: Callable = None,
474 |                  clamp: float = 2.,
475 |                  clamp_activation: Union[str, Callable] = "ATAN",
476 |                  split_len: Union[float, int] = 0.5):
477 |         '''
478 |         Additional args in docstring of base class.
479 | 
480 |         Args:
481 |           subnet_constructor: function or class, with signature
482 |             constructor(dims_in, dims_out). The result should be a torch
483 |             nn.Module that takes dims_in input channels and dims_out output
484 |             channels. See tutorial for examples. One subnetwork will be
485 |             initialized in the block.
486 |           clamp: Soft clamping for the multiplicative component. The
487 |             amplification or attenuation of each input dimension can be at most
488 |             exp(±clamp).
489 |           clamp_activation: Function to perform the clamping. String values
490 |             "ATAN", "TANH", and "SIGMOID" are recognized, or a function
491 |             object can be passed. TANH behaves like the original realNVP paper.
492 |             A custom function should take tensors and map -inf to -1 and +inf to +1.
493 |         '''
494 | 
495 |         super().__init__(dims_in, dims_c, clamp, clamp_activation,
496 |                          split_len=split_len)
497 | 
498 |         if not self.conditional:
499 |             raise ValueError("ConditionalAffineTransform must have a condition")
500 | 
501 |         self.subnet = subnet_constructor(self.condition_length, 2 * self.channels)
502 | 
503 |     def forward(self, x, c=[], rev=False, jac=True):
504 |         if len(c) > 1:
505 |             cond = torch.cat(c, 1)
506 |         else:
507 |             cond = c[0]
508 | 
509 |         # notation:
510 |         # x: inputs (i.e. 'x-side' when rev is False, 'z-side' when rev is True)
511 |         # y: outputs (same scheme)
512 |         # *_c: variables with condition appended
513 |         # *1, *2: left half, right half
514 |         # a: all affine coefficients
515 |         # s, t: multiplicative and additive coefficients
516 |         # j: log det Jacobian
517 | 
518 |         a = self.subnet(cond)
519 |         s, t = a[:, :self.channels], a[:, self.channels:]
520 |         s = self.clamp * self.f_clamp(s)
521 |         j = torch.sum(s, dim=tuple(range(1, self.ndims + 1)))
522 | 
523 |         if rev:
524 |             y = (x[0] - t) * torch.exp(-s)
525 |             return (y,), -j
526 |         else:
527 |             y = torch.exp(s) * x[0] + t
528 |             return (y,), j
529 | 
--------------------------------------------------------------------------------
/FrEIA/modules/fixed_transforms.py:
--------------------------------------------------------------------------------
1 | from . import InvertibleModule
2 | 
3 | from typing import Union, Iterable, Tuple
4 | 
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | 
10 | 
11 | class PermuteRandom(InvertibleModule):
12 |     '''Constructs a random permutation that stays fixed during training.
13 |     Permutes along the first (channel-) dimension for multi-dimensional tensors.'''
14 | 
15 |     def __init__(self, dims_in, dims_c=None, seed: Union[int, None] = None):
16 |         '''Additional args in docstring of base class FrEIA_github.modules.InvertibleModule.
17 | 
18 |         Args:
19 |           seed: Int seed for the permutation (numpy is used for RNG). If seed is
20 |             None, do not reseed RNG.
21 |         '''
22 |         super().__init__(dims_in, dims_c)
23 | 
24 |         self.in_channels = dims_in[0][0]
25 | 
26 |         if seed is not None:
27 |             np.random.seed(seed)
28 |         self.perm = np.random.permutation(self.in_channels)
29 | 
30 |         self.perm_inv = np.zeros_like(self.perm)
31 |         for i, p in enumerate(self.perm):
32 |             self.perm_inv[p] = i
33 | 
34 |         self.perm = nn.Parameter(torch.LongTensor(self.perm), requires_grad=False)
35 |         self.perm_inv = nn.Parameter(torch.LongTensor(self.perm_inv), requires_grad=False)
36 | 
37 |     def forward(self, x, rev=False, jac=True):
38 |         if not rev:
39 |             return [x[0][:, self.perm]], 0.
40 |         else:
41 |             return [x[0][:, self.perm_inv]], 0.
42 | 
43 |     def output_dims(self, input_dims):
44 |         if len(input_dims) != 1:
45 |             raise ValueError(f"{self.__class__.__name__} can only use 1 input")
46 |         return input_dims
47 | 
48 | 
49 | class FixedLinearTransform(InvertibleModule):
50 |     '''Fixed linear transformation for 1D input tensors. The transformation is
51 |     :math:`y = Mx + b`. With *d* input dimensions, *M* must be an invertible *d x d* tensor,
52 |     and *b* is an optional offset vector of length *d*.'''
53 | 
54 |     def __init__(self, dims_in, dims_c=None, M: torch.Tensor = None,
55 |                  b: Union[None, torch.Tensor] = None):
56 |         '''Additional args in docstring of base class FrEIA_github.modules.InvertibleModule.
57 | 
58 |         Args:
59 |           M: Square, invertible matrix, with which each input is multiplied. Shape ``(d, d)``.
60 |           b: Optional vector which is added element-wise. Shape ``(d,)``.
61 |         '''
62 |         super().__init__(dims_in, dims_c)
63 | 
64 |         # TODO: it should be possible to give conditioning instead of M, so that the condition
65 |         # provides M and b on each forward pass.
66 | 
67 |         if M is None:
68 |             raise ValueError("Need to specify the M argument, the matrix to be multiplied.")
69 | 
70 |         self.M = nn.Parameter(M.t(), requires_grad=False)
71 |         self.M_inv = nn.Parameter(M.t().inverse(), requires_grad=False)
72 | 
73 |         if b is None:
74 |             self.b = 0.
75 |         else:
76 |             self.b = nn.Parameter(b.unsqueeze(0), requires_grad=False)
77 | 
78 |         self.logDetM = nn.Parameter(torch.slogdet(M)[1], requires_grad=False)
79 | 
80 |     def forward(self, x, rev=False, jac=True):
81 |         j = self.logDetM.expand(x[0].shape[0])
82 |         if not rev:
83 |             out = x[0].mm(self.M) + self.b
84 |             return (out,), j
85 |         else:
86 |             out = (x[0] - self.b).mm(self.M_inv)
87 |             return (out,), -j
88 | 
89 |     def output_dims(self, input_dims):
90 |         if len(input_dims) != 1:
91 |             raise ValueError(f"{self.__class__.__name__} can only use 1 input")
92 |         return input_dims
93 | 
94 | 
95 | class Fixed1x1Conv(InvertibleModule):
96 |     '''Given an invertible matrix M, a 1x1 convolution is performed using M as
97 |     the convolution kernel. Effectively, a matrix multiplication along the
98 |     channel dimension is performed in each pixel.'''
99 | 
100 |     def __init__(self, dims_in, dims_c=None, M: torch.Tensor = None):
101 |         '''Additional args in docstring of base class FrEIA_github.modules.InvertibleModule.
102 | 
103 |         Args:
104 |           M: Square, invertible matrix, with which each input is multiplied. Shape ``(d, d)``.
105 |         '''
106 |         super().__init__(dims_in, dims_c)
107 | 
108 |         # TODO: it should be possible to give conditioning instead of M, so that the condition
109 |         # provides M and b on each forward pass.
110 | 
111 |         if M is None:
112 |             raise ValueError("Need to specify the M argument, the matrix to be multiplied.")
113 | 
114 |         self.M = nn.Parameter(M.t().view(*M.shape, 1, 1), requires_grad=False)
115 |         self.M_inv = nn.Parameter(M.t().inverse().view(*M.shape, 1, 1), requires_grad=False)
116 |         self.logDetM = nn.Parameter(torch.slogdet(M)[1], requires_grad=False)
117 | 
118 |     def forward(self, x, rev=False, jac=True):
119 |         n_pixels = x[0][0, 0].numel()
120 |         j = self.logDetM * n_pixels
121 | 
122 |         if not rev:
123 |             return (F.conv2d(x[0], self.M),), j
124 |         else:
125 |             return (F.conv2d(x[0], self.M_inv),), -j
126 | 
127 |     def output_dims(self, input_dims):
128 |         '''See base class for docstring'''
129 |         if len(input_dims) != 1:
130 |             raise ValueError(f"{self.__class__.__name__} can only use 1 input")
131 |         if len(input_dims[0]) != 3:
132 |             raise ValueError(f"{self.__class__.__name__} requires 3D input (channels, height, width)")
133 |         return input_dims
134 | 
135 | 
136 | class InvertibleSigmoid(InvertibleModule):
137 |     '''Applies the sigmoid function element-wise across all batches, and the associated
138 |     inverse function in the reverse pass. Contains no trainable parameters.
139 |     The sigmoid function S(x) and its inverse are given by
140 | 
141 |     .. math::
142 | 
143 |         S(x) &= \\frac{1}{1 + \\exp(-x)} \\\\
144 |         S^{-1}(x) &= \\log{\\frac{x}{1-x}}.
145 | 
146 |     The returned Jacobian is computed as
147 | 
148 |     .. math::
149 | 
150 |         J = \\log \\det \\frac{1}{(1+\\exp{x})(1+\\exp{-x})}.
151 | 
152 |     '''
153 |     def __init__(self, dims_in, **kwargs):
154 |         super().__init__(dims_in, **kwargs)
155 | 
156 |     def output_dims(self, dims_in):
157 |         return dims_in
158 | 
159 |     def forward(self, x_or_z: Iterable[torch.Tensor], c: Iterable[torch.Tensor] = None,
160 |                 rev: bool = False, jac: bool = True) \
161 |             -> Tuple[Tuple[torch.Tensor], torch.Tensor]:
162 |         x_or_z = x_or_z[0]
163 |         if not rev:
164 |             # S(x)
165 |             result = 1 / (1 + torch.exp(-x_or_z))
166 |         else:
167 |             # S^-1(z)
168 |             # only defined within range 0-1, non-inclusive; else, it will return nan.
169 | result = torch.log(x_or_z / (1 - x_or_z)) 170 | if not jac: 171 | return (result, ) 172 | 173 | # always compute jacobian using the forward direction, but with different inputs 174 | _input = result if rev else x_or_z 175 | # the following is the diagonal Jacobian as sigmoid is an element-wise op 176 | logJ = torch.log(1 / ((1 + torch.exp(_input)) * (1 + torch.exp(-_input)))) 177 | # determinant of a log diagonal Jacobian is simply the sum of its diagonals 178 | detLogJ = logJ.sum(1) 179 | if not rev: 180 | return ((result, ), detLogJ) 181 | else: 182 | return ((result, ), -detLogJ) 183 | -------------------------------------------------------------------------------- /FrEIA/modules/gaussian_mixture.py: -------------------------------------------------------------------------------- 1 | from . import InvertibleModule 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class GaussianMixtureModel(InvertibleModule): 8 | '''An invertible Gaussian mixture model. The weights, means, covariance 9 | parameterization and component index must be supplied as conditional inputs 10 | to the module and can come from an external feed-forward network, which may 11 | be trained by backpropagating through the GMM. Weights should first be 12 | normalized via GaussianMixtureModel.normalize_weights(w) and component 13 | indices can be sampled via GaussianMixtureModel.pick_mixture_component(w). 14 | If component indices are specified, the model reduces to that Gaussian 15 | mixture component and maps between data x and standard normal latent 16 | variable z. Components can also be chosen consistently at random, by 17 | supplying an integer random seed instead of indices. If a None value is 18 | supplied instead of indices, the model maps between K data points x and K 19 | latent codes z simultaneously, where K is the number of mixture components. 20 | Mathematical derivations are found in the technical report "Training Mixture 21 | Density Networks with full covariance matrices" on arXiv.''' 22 | 23 | def __init__(self, dims_in, dims_c): 24 | super().__init__(dims_in, dims_c) 25 | 26 | self.x_dims = dims_in[0][0] 27 | # Prepare masks for filling the (triangular) Cholesky factors of the precision matrices 28 | self.mask_upper = (torch.triu(torch.ones(self.x_dims, self.x_dims), diagonal=1) == 1) 29 | self.mask_diagonal = torch.eye(self.x_dims, self.x_dims).bool() 30 | 31 | 32 | @staticmethod 33 | def pick_mixture_component(w, seed=None): 34 | '''Randomly choose mixture component indices with probability given by 35 | the component weights w. Works on batches of component weights. 36 | 37 | w: Weights of the mixture components, must be positive and sum to one 38 | seed: Optional RNG seed for consistent decisions''' 39 | 40 | w_thresholds = torch.cumsum(w, dim=1) 41 | # Prepare local random number generator 42 | rng = torch.Generator(device=w.device) 43 | if isinstance(seed, int): 44 | rng = rng.manual_seed(seed) 45 | else: 46 | rng.seed() 47 | # Draw one uniform random number per batch row and compare against thresholds 48 | u = torch.rand(w.shape[0], 1, device=w.device, generator=rng) 49 | indices = torch.sum(u > w_thresholds, dim=1).int() 50 | # Return mixture component indices 51 | return indices 52 | 53 | 54 | @staticmethod 55 | def normalize_weights(w): 56 | '''Apply softmax to ensure component weights are positive and sum to 57 | one. Works on batches of component weights. 
58 | 59 | w: Unnormalized weights for Gaussian mixture components, must be of 60 | size [batch_size, n_components]''' 61 | 62 | return F.softmax(w - w.max(), dim=-1) 63 | 64 | 65 | @staticmethod 66 | def nll_loss(w, z, log_jacobian): 67 | '''Negative log-likelihood loss for training a Mixture Density Network. 68 | 69 | w: Mixture component weights, must be positive and sum to 70 | one. Tensor must be of size [batch_size, n_components]. 71 | z: Latent codes for all mixture components. Tensor must be 72 | of size [batch, n_components, n_dims]. 73 | log_jacobian: Jacobian log-determinants for each precision matrix. 74 | Tensor size must be [batch_size, n_components].''' 75 | 76 | return -((-0.5 * (z**2).sum(dim=-1) + log_jacobian).exp() * w).sum(dim=-1).log() 77 | 78 | 79 | @staticmethod 80 | def nll_upper_bound(w, z, log_jacobian): 81 | '''Numerically more stable upper bound of the negative log-likelihood 82 | loss for training a Mixture Density Network. 83 | 84 | w: Mixture component weights, must be positive and sum to 85 | one. Tensor must be of size [batch_size, n_components]. 86 | z: Latent codes for all mixture components. Tensor must be 87 | of size [batch, n_components, n_dims]. 88 | log_jacobian: Jacobian log-determinants for each precision matrix. 89 | Tensor size must be [batch_size, n_components].''' 90 | 91 | return -(w.log() - 0.5 * (z**2).sum(dim=-1) + log_jacobian).sum(dim=-1) 92 | 93 | 94 | def forward(self, x, c, rev=False, jac=True): 95 | '''Map between data distribution and standard normal latent distribution 96 | of mixture components or entire mixture, in an invertible way. 97 | 98 | x: Data during forward pass or latent codes during backward pass. Size 99 | must be [batch_size, n_dims] if component indices i are specified 100 | and should be [batch_size, n_components, n_dims] if not. 101 | 102 | The conditional input c must be a list [w, mu, U, i] of parameters for 103 | the Gaussian mixture model with the following properties: 104 | 105 | w: Weights of the mixture components, must be positive and sum to one 106 | and have size [batch_size, n_components]. 107 | mu: Means of the mixture components, must have size [batch_size, 108 | n_components, n_dims]. 109 | U: Entries for the (upper triangular) Cholesky factors for the 110 | precision matrices of the mixture components. These are needed to 111 | parameterize the covariance of the mixture components and must have 112 | size [batch_size, n_components, n_dims * (n_dims + 1) / 2]. 113 | i: Tensor of component indices (size [batch_size]), or a single integer 114 | to be used as random number generator seed for component selection, 115 | or None to indicate that all mixture components are modelled.''' 116 | assert len(x) == 1, f"GaussianMixtureModel got {len(x)} inputs, but " \ 117 | f"only one is allowed." 
118 | x = x[0] 119 | 120 | # Get GMM parameters 121 | w, mu, U_entries, i = c 122 | batch_size, n_components = w.shape 123 | 124 | # Construct upper triangular Cholesky factors U of all precision matrices 125 | U = torch.zeros(batch_size, n_components, self.x_dims, self.x_dims, device=x.device) 126 | # Fill everything above the diagonal as is 127 | U[self.mask_upper.expand(batch_size,n_components,-1,-1)] = U_entries[:,:,self.x_dims:].reshape(-1) 128 | # Diagonal entries must be positive 129 | U[self.mask_diagonal.expand(batch_size,n_components,-1,-1)] = U_entries[:,:,:self.x_dims].exp().reshape(-1) 130 | 131 | # Indices of chosen mixture components, if provided 132 | if i is None: 133 | fixed_components = False 134 | else: 135 | fixed_components = True 136 | if not isinstance(i, torch.Tensor): 137 | i = self.pick_mixture_component(w, seed=i) 138 | 139 | if jac: 140 | # Compute Jacobian log-determinants 141 | # Note: we avoid a log operation by taking diagonal entries directly from U_entries, where they are in log space 142 | if fixed_components: 143 | # Keep Jacobian log-determinants for chosen components only 144 | j = torch.stack([U_entries[b, i[b], :self.x_dims].sum(dim=-1) for b in range(batch_size)]) 145 | else: 146 | # Keep Jacobian log-determinants for all components simultaneously 147 | j = U_entries[:, :, :self.x_dims].sum(dim=-1) 148 | 149 | if rev: 150 | j *= -1 151 | else: 152 | j = None 153 | 154 | # Actual forward and inverse pass 155 | if not rev: 156 | if fixed_components: 157 | # Return latent codes of x according to chosen component distributions only 158 | return [torch.stack([torch.matmul(U[b,i[b],:,:], x[b,:] - mu[b,i[b],:]) for b in range(batch_size)])], j 159 | else: 160 | # Return latent codes of x according to all component distributions simultaneously 161 | if len(x.shape) < 3: 162 | x = x[:,None,:] 163 | return [torch.matmul(U, (x - mu)[...,None])[...,0]], j 164 | else: 165 | if fixed_components: 166 | # Transform latent samples to samples from chosen mixture distributions 167 | return [torch.stack([mu[b,i[b],:] + torch.matmul(torch.inverse(U[b,i[b],:,:]), x[b,:]) for b in range(batch_size)])], j 168 | else: 169 | # Transform latent samples to samples from all mixture distributions simultaneously 170 | return [torch.matmul(torch.inverse(U), x[...,None])[...,0] + mu], j 171 | 172 | def output_dims(self, input_dims): 173 | assert len(input_dims) == 1, "Can only use 1 input" 174 | return input_dims 175 | -------------------------------------------------------------------------------- /FrEIA/modules/graph_topology.py: -------------------------------------------------------------------------------- 1 | from . import InvertibleModule 2 | 3 | import warnings 4 | from copy import deepcopy 5 | from typing import Sequence, Union 6 | 7 | import torch 8 | 9 | 10 | class Split(InvertibleModule): 11 | """Invertible split operation. 12 | 13 | Splits the incoming tensor along the given dimension, and returns a list of 14 | separate output tensors. The inverse is the corresponding merge operation. 15 | 16 | """ 17 | 18 | def __init__(self, 19 | dims_in: Sequence[Sequence[int]], 20 | section_sizes: Union[int, Sequence[int]] = None, 21 | n_sections: int = 2, 22 | dim: int = 0, 23 | ): 24 | """Inits the Split module with the attributes described above and 25 | checks that split sizes and dimensionality are compatible. 26 | 27 | Args: 28 | dims_in: 29 | A list of tuples containing the non-batch dimensionality of all 30 | incoming tensors. 
Handled automatically during compute graph setup.
31 |                 Split only takes one input tensor.
32 |             section_sizes:
33 |                 If set, takes precedence over ``n_sections`` and behaves like the
34 |                 argument in torch.split(), except when a list of section sizes is given
35 |                 that doesn't add up to the size of ``dim``, an additional split section is
36 |                 created to take the slack. Defaults to None.
37 |             n_sections:
38 |                 If ``section_sizes`` is None, the tensor is split into ``n_sections``
39 |                 parts of equal size or close to it. This mode behaves like
40 |                 ``numpy.array_split()``. Defaults to 2, i.e. splitting the data into two
41 |                 equal halves.
42 |             dim:
43 |                 Index of the dimension along which to split, not counting the batch
44 |                 dimension. Defaults to 0, i.e. the channel dimension in structured data.
45 |         """
46 |         super().__init__(dims_in)
47 | 
48 |         # Size and dimensionality checks
49 |         assert len(dims_in) == 1, "Split layer takes exactly one input tensor"
50 |         assert len(dims_in[0]) >= dim, "Split dimension index out of range"
51 |         self.dim = dim
52 |         l_dim = dims_in[0][dim]
53 | 
54 |         if section_sizes is None:
55 |             assert 2 <= n_sections, "'n_sections' must be at least 2"
56 |             if l_dim % n_sections != 0:
57 |                 warnings.warn('Split will create sections of unequal size')
58 |             self.split_size_or_sections = (
59 |                 [l_dim//n_sections + 1] * (l_dim%n_sections) +
60 |                 [l_dim//n_sections] * (n_sections - l_dim%n_sections))
61 |         else:
62 |             if isinstance(section_sizes, int):
63 |                 assert section_sizes < l_dim, "'section_sizes' too large"
64 |             else:
65 |                 assert isinstance(section_sizes, (list, tuple)), \
66 |                     "'section_sizes' must be either int or list/tuple of int"
67 |                 assert sum(section_sizes) <= l_dim, "'section_sizes' too large"
68 |                 if sum(section_sizes) < l_dim:
69 |                     warnings.warn("'section_sizes' too small, adding additional section")
70 |                     section_sizes = list(section_sizes) + [l_dim - sum(section_sizes)]
71 |             self.split_size_or_sections = section_sizes
72 | 
73 |     def forward(self, x, rev=False, jac=True):
74 |         """See super class InvertibleModule.
75 |         Jacobian log-det of splitting is always zero."""
76 |         if rev:
77 |             return [torch.cat(x, dim=self.dim+1)], 0
78 |         else:
79 |             return torch.split(x[0], self.split_size_or_sections,
80 |                                dim=self.dim+1), 0
81 | 
82 |     def output_dims(self, input_dims):
83 |         """See super class InvertibleModule."""
84 |         assert len(input_dims) == 1, "Split layer takes exactly one input tensor"
85 |         # Assemble dims of all resulting outputs
86 |         return [tuple(input_dims[0][j] if (j != self.dim) else section_size
87 |                       for j in range(len(input_dims[0])))
88 |                 for section_size in self.split_size_or_sections]
89 | 
90 | 
91 | 
92 | class Concat(InvertibleModule):
93 |     """Invertible merge operation.
94 | 
95 |     Concatenates a list of incoming tensors along a given dimension and passes
96 |     on the result. Inverse is the corresponding split operation.
97 |     """
98 | 
99 |     def __init__(self,
100 |                  dims_in: Sequence[Sequence[int]],
101 |                  dim: int = 0,
102 |                  ):
103 |         """Inits the Concat module with the attributes described above and
104 |         checks that all dimensions are compatible.
105 | 
106 |         Args:
107 |             dims_in:
108 |                 A list of tuples containing the non-batch dimensionality of all
109 |                 incoming tensors. Handled automatically during compute graph setup.
110 |                 Dimensionality of incoming tensors must be identical, except in the
111 |                 merge dimension ``dim``. Concat only makes sense with multiple input
112 |                 tensors.
113 |             dim:
114 |                 Index of the dimension along which to concatenate, not counting the
115 |                 batch dimension. Defaults to 0, i.e. the channel dimension in structured
116 |                 data.
117 |         """
118 |         super().__init__(dims_in)
119 |         assert len(dims_in) > 1, ("Concatenation only makes sense for "
120 |                                   "multiple inputs")
121 |         assert len(dims_in[0]) >= dim, "Merge dimension index out of range"
122 |         assert all(len(dims_in[i]) == len(dims_in[0])
123 |                    for i in range(len(dims_in))), (
124 |             "All input tensors must have same number of "
125 |             "dimensions"
126 |         )
127 |         assert all(dims_in[i][j] == dims_in[0][j] for i in range(len(dims_in))
128 |                    for j in range(len(dims_in[i])) if j != dim), (
129 |             "All input tensor dimensions except merge "
130 |             "dimension must be identical"
131 |         )
132 |         self.dim = dim
133 |         self.split_size_or_sections = [dims_in[i][dim]
134 |                                        for i in range(len(dims_in))]
135 | 
136 |     def forward(self, x, rev=False, jac=True):
137 |         """See super class InvertibleModule.
138 |         Jacobian log-det of concatenation is always zero."""
139 |         if rev:
140 |             return torch.split(x[0], self.split_size_or_sections,
141 |                                dim=self.dim+1), 0
142 |         else:
143 |             return [torch.cat(x, dim=self.dim+1)], 0
144 | 
145 |     def output_dims(self, input_dims):
146 |         """See super class InvertibleModule."""
147 |         assert len(input_dims) > 1, ("Concatenation only makes sense for "
148 |                                      "multiple inputs")
149 |         output_dims = deepcopy(list(input_dims[0]))
150 |         output_dims[self.dim] = sum(input_dim[self.dim]
151 |                                     for input_dim in input_dims)
152 |         return [tuple(output_dims)]
153 | 
154 | 
155 | def _deprecated_by(orig_class):
156 |     class deprecated_class(orig_class):
157 |         def __init__(self, *args, **kwargs):
158 |             warnings.warn(F"{self.__class__.__name__} is deprecated and will be removed in the public release. "
159 |                           F"Use {orig_class.__name__} instead.",
160 |                           DeprecationWarning)
161 |             super().__init__(*args, **kwargs)
162 | 
163 |     return deprecated_class
164 | 
165 | _depr_docstring = "This class is deprecated and replaced by ``{}``"
166 | 
167 | Split1D = _deprecated_by(Split)
168 | Split1D.__doc__ = _depr_docstring.format(Split.__name__)
169 | 
170 | SplitChannel = _deprecated_by(Split)
171 | SplitChannel.__doc__ = _depr_docstring.format(Split.__name__)
172 | 
173 | Concat1d = _deprecated_by(Concat)
174 | Concat1d.__doc__ = _depr_docstring.format(Concat.__name__)
175 | 
176 | ConcatChannel = _deprecated_by(Concat)
177 | ConcatChannel.__doc__ = _depr_docstring.format(Concat.__name__)
178 | 
--------------------------------------------------------------------------------
/FrEIA/modules/inv_auto_layers.py:
--------------------------------------------------------------------------------
1 | from . import InvertibleModule
2 | 
3 | import warnings
4 | from copy import deepcopy
5 | 
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as f
10 | 
11 | 
12 | class InvAutoActTwoSided(InvertibleModule):
13 |     '''A nonlinear invertible activation analogous to Leaky ReLU, with
14 |     learned slopes.
15 | 
16 |     The slopes are learned separately for each entry along the first
17 |     input dimension (after the batch dimension), i.e. element-wise for
18 |     flattened inputs, channel-wise for image inputs, etc.
19 |     Internally, the slopes are learned in log-space, to ensure they stay
20 |     strictly > 0:
21 | 
22 |     .. math::
23 | 
24 |         x \\geq 0 &\\implies g(x) = x \\odot \\exp(\\alpha_+)
25 | 
26 |         x < 0 &\\implies g(x) = x \\odot \\exp(\\alpha_-)
27 |     '''
28 | 
29 |     def __init__(self, dims_in, dims_c=None, init_pos: float = 2.0, init_neg: float = 0.5, learnable: bool = True):
30 |         '''
31 |         Args:
32 |           init_pos: The initial slope for the positive half of the activation. Must be > 0.
33 |             Note that the initial value accounts for the exp-activation, meaning
34 |             :math:`\\exp(\\alpha_+) =` ``init_pos``.
35 |           init_neg: The initial slope for the negative half of the activation. Must be > 0.
36 |             The initial value accounts for the exp-activation the same as init_pos.
37 |           learnable: If False, the slopes are fixed at their initial value, and not learned.
38 |         '''
39 |         super().__init__(dims_in, dims_c)
40 |         self.tensor_rank = len(dims_in[0])
41 | 
42 |         self.alpha_pos = np.log(init_pos) * torch.ones(dims_in[0][0])
43 |         self.alpha_pos = self.alpha_pos.view(1, -1, *([1] * (self.tensor_rank - 1)))
44 |         self.alpha_pos = nn.Parameter(self.alpha_pos)
45 | 
46 |         self.alpha_neg = np.log(init_neg) * torch.ones(dims_in[0][0])
47 |         self.alpha_neg = self.alpha_neg.view(1, -1, *([1] * (self.tensor_rank - 1)))
48 |         self.alpha_neg = nn.Parameter(self.alpha_neg)
49 | 
50 |         if not learnable:
51 |             self.alpha_pos.requires_grad = False
52 |             self.alpha_neg.requires_grad = False
53 | 
54 |     def forward(self, x, rev=False, jac=True):
55 | 
56 |         log_slope = self.alpha_pos + 0.5 * (self.alpha_neg - self.alpha_pos) * (1 - x[0].sign())
57 |         if rev:
58 |             log_slope *= -1
59 | 
60 |         if jac:
61 |             j = torch.sum(log_slope, dim=tuple(range(1, self.tensor_rank + 1)))
62 |         else:
63 |             j = None
64 | 
65 |         return [x[0] * torch.exp(log_slope)], j
66 | 
67 |     def output_dims(self, input_dims):
68 |         if len(input_dims) != 1:
69 |             raise ValueError(f"{self.__class__.__name__} can only use 1 input")
70 |         return input_dims
71 | 
72 | 
73 | class InvAutoAct(InvertibleModule):
74 |     '''A nonlinear invertible activation analogous to Leaky ReLU, with
75 |     learned slopes.
76 | 
77 |     The slope is symmetric between the positive and negative side, i.e.
78 | 
79 |     .. math::
80 | 
81 |         x \\geq 0 &\\implies g(x) = x \\odot \\exp(\\alpha)
82 | 
83 |         x < 0 &\\implies g(x) = x \\oslash \\exp(\\alpha)
84 | 
85 |     A separate slope is learned for each entry along the first
86 |     input dimension (after the batch dimension), i.e. element-wise for
87 |     flattened inputs, channel-wise for image inputs, etc.
88 |     '''
89 | 
90 |     def __init__(self, dims_in, dims_c=None, slope_init=2.0, learnable=True):
91 |         '''
92 |         Args:
93 |           slope_init: The initial value of the slope on the positive side.
94 |             Accounts for the exp-activation, i.e. :math:`\\exp(\\alpha) =` ``slope_init``.
95 |           learnable: If False, the slopes are fixed at their initial value, and not learned.
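          A minimal usage sketch (the input width 16 is an illustrative
          assumption):

              act = InvAutoAct([(16,)], slope_init=2.0)
              (y,), log_jac = act([x])            # x: tensor of shape (batch, 16)
              (x_rec,), _ = act([y], rev=True)    # recovers x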
96 |         '''
97 |         super().__init__(dims_in, dims_c)
98 | 
99 |         self.tensor_rank = len(dims_in[0])
100 |         self.alpha = np.log(slope_init) * torch.ones(1, dims_in[0][0], *([1] * (len(dims_in[0]) - 1)))
101 |         self.alpha = nn.Parameter(self.alpha)
102 | 
103 |         if not learnable:
104 |             self.alpha.requires_grad = False
105 | 
106 |     def forward(self, x, rev=False, jac=True):
107 |         log_slope = self.alpha * x[0].sign()
108 |         if rev:
109 |             log_slope *= -1
110 | 
111 |         if jac:
112 |             j = torch.sum(log_slope, dim=tuple(range(1, self.tensor_rank + 1)))
113 |         else:
114 |             j = None
115 | 
116 |         return [x[0] * torch.exp(log_slope)], j
117 | 
118 |     def output_dims(self, input_dims):
119 |         if len(input_dims) != 1:
120 |             raise ValueError(f"{self.__class__.__name__} can only use 1 input")
121 |         return input_dims
122 | 
123 | 
124 | class InvAutoActFixed(InvAutoAct):
125 |     def __init__(self, *args, **kwargs):
126 |         super().__init__(*args, **kwargs)
127 |         warnings.warn("Deprecated: please use InvAutoAct with the learnable=False argument.")
128 | 
129 | 
130 | class LearnedElementwiseScaling(InvertibleModule):
131 |     '''Scale each element of the input by a learned, non-negative factor.
132 |     Unlike most other FrEIA_github modules, the scaling is not e.g. channel-wise for images,
133 |     but really scales each individual element.
134 |     To ensure positivity, the scaling is learned in log-space:
135 | 
136 |     .. math::
137 | 
138 |         g(x) = x \\odot \\exp(s)
139 |     '''
140 | 
141 |     def __init__(self, dims_in, dims_c=None, init_scale=1.0):
142 |         '''
143 |         Args:
144 |           init_scale: The initial scaling value. It accounts for the exp-activation,
145 |             i.e. :math:`\\exp(s) =` ``init_scale``.
146 |         '''
147 |         super().__init__(dims_in, dims_c)
148 |         self.s = nn.Parameter(np.log(init_scale) * torch.ones(1, *dims_in[0]))
149 | 
150 |     def forward(self, x, rev=False, jac=True):
151 | 
152 |         if rev:
153 |             scale = -self.s
154 |         else:
155 |             scale = self.s
156 | 
157 |         if jac:
158 |             jac = torch.sum(scale).unsqueeze(0)
159 |         else:
160 |             jac = None
161 | 
162 |         return [x[0] * torch.exp(scale)], jac
163 | 
164 |     def output_dims(self, input_dims):
165 |         if len(input_dims) != 1:
166 |             raise ValueError(f"{self.__class__.__name__} can only use 1 input")
167 |         return input_dims
168 | 
169 | 
170 | class InvAutoFC(InvertibleModule):
171 |     '''Fully connected 'Invertible Autoencoder'-layer (see arxiv.org/pdf/1802.06869.pdf).
172 |     The weight matrix of the inverse is the transposed weight matrix of the forward pass.
173 |     If a reconstruction loss between forward and inverse is used, the layer converges
174 |     to an invertible, orthogonal, linear transformation.
175 |     '''
176 | 
177 |     def __init__(self, dims_in, dims_c=None, dims_out=None):
178 |         '''
179 |         Args:
180 |           dims_out: If None, the output dimension equals the input dimension.
181 |             However, because InvAuto is only asymptotically invertible, there is
182 |             no strict limitation to have the same number of input- and
183 |             output-dimensions. If dims_out is an integer instead of None,
184 |             that number of output dimensions is used.
185 |         '''
186 |         super().__init__(dims_in, dims_c)
187 |         self.dims_in = dims_in
188 |         if dims_out is None:
189 |             self.dims_out = dims_in[0][0]
190 |         else:
191 |             self.dims_out = dims_out
192 | 
193 |         self.weights = nn.Parameter(np.sqrt(1. / self.dims_out) * torch.randn(self.dims_out, self.dims_in[0][0]))
194 |         self.bias = nn.Parameter(torch.randn(1, self.dims_out))
195 | 
196 | 
197 | 
198 |     def forward(self, x, rev=False, jac=True):
199 |         if jac:
200 |             warnings.warn('Invertible Autoencoder layers do not have a tractable log-det-Jacobian. '
201 |                           'It approaches 0 at convergence, but the value may be incorrect during training.')
202 | 
203 |         if not rev:
204 |             return [f.linear(x[0], self.weights) + self.bias], 0.
205 |         else:
206 |             return [f.linear(x[0] - self.bias, self.weights.t())], 0.
207 | 
208 |     def output_dims(self, input_dims):
209 |         if len(input_dims) != 1:
210 |             raise ValueError(f"{self.__class__.__name__} can only use 1 input")
211 |         if len(input_dims[0]) != 1:
212 |             raise ValueError(f"{self.__class__.__name__} can only use flattened (1D) input")
213 |         return [(self.dims_out,)]
214 | 
215 | 
216 | class InvAutoConv2D(InvertibleModule):
217 |     '''Convolutional variant of the 'Invertible Autoencoder'-layer
218 |     (see arxiv.org/pdf/1802.06869.pdf). The inverse is a transposed
219 |     convolution with the same kernel as the forward pass. If a reconstruction
220 |     loss between forward and inverse is used, the layer converges to an
221 |     invertible, orthogonal, linear transformation.
222 |     '''
223 | 
224 |     def __init__(self, dims_in, dims_c=None, dims_out=None, kernel_size=3, padding=1):
225 |         '''
226 |         Args:
227 |           kernel_size: Spatial size of the convolution kernel.
228 |           padding: Padding of the input. Choosing ``padding = kernel_size // 2`` retains
229 |             the image shape between in- and output.
230 |           dims_out: If None, the output dimension equals the input dimension.
231 |             However, because InvAuto is only asymptotically invertible, there is
232 |             no strict limitation to have the same number of input- and
233 |             output-dimensions. Therefore dims_out can also be a tuple of length 3:
234 |             (channels, width, height). The channels are the output channels of the
235 |             convolution. The user is responsible for making the width and height match
236 |             with the actual output, depending on kernel_size and padding.
237 |         '''
238 | 
239 |         super().__init__(dims_in, dims_c)
240 |         self.dims_in = dims_in
241 | 
242 |         if dims_out is None:
243 |             self.dims_out = dims_in[0]
244 |         else:
245 |             self.dims_out = dims_out
246 | 
247 |         self.kernel_size = kernel_size
248 |         self.padding = padding
249 | 
250 |         self.conv2d = nn.Conv2d(dims_in[0][0], self.dims_out[0], kernel_size=kernel_size, padding=padding, bias=False)
251 |         self.bias = nn.Parameter(torch.randn(1, self.dims_out[0], 1, 1))
252 | 
253 |     def forward(self, x, rev=False, jac=True):
254 |         if jac:
255 |             warnings.warn('Invertible Autoencoder layers do not have a tractable log-det-Jacobian. '
256 |                           'It approaches 0 at convergence, but the value may be incorrect during training.')
257 | 
258 |         if not rev:
259 |             out = self.conv2d(x[0])
260 |             out += self.bias
261 |         else:
262 |             out = x[0] - self.bias
263 |             out = f.conv_transpose2d(out, self.conv2d.weight, bias=None, padding=self.padding)
264 | 
265 |         return [out], 0.
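    # A usage sketch (shapes are illustrative assumptions): with
    # padding = kernel_size // 2 the spatial size is preserved, so dims_out
    # can simply mirror the input shape.
    #
    #     layer = InvAutoConv2D([(3, 32, 32)], dims_out=(3, 32, 32))
    #     (y,), _ = layer([x])                # x: tensor of shape (batch, 3, 32, 32)
    #     (x_rec,), _ = layer([y], rev=True)  # approximate inverse (see warning above)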
266 | 
267 |     def output_dims(self, input_dims):
268 |         if len(input_dims) != 1:
269 |             raise ValueError(f"{self.__class__.__name__} can only use 1 input")
270 |         if len(input_dims[0]) != 3:
271 |             raise ValueError(f"{self.__class__.__name__} can only use image input (3D tensors)")
272 |         return [self.dims_out]
273 | 
--------------------------------------------------------------------------------
/FrEIA/modules/invertible_resnet.py:
--------------------------------------------------------------------------------
1 | from . import InvertibleModule
2 | 
3 | from typing import Union
4 | 
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn.functional import conv2d, conv_transpose2d
9 | 
10 | 
11 | class ActNorm(InvertibleModule):
12 |     '''A technique to achieve a stable initialization.
13 | 
14 |     First introduced in Kingma et al 2018: https://arxiv.org/abs/1807.03039
15 |     The module is similar to a traditional batch normalization layer, but the
16 |     data mean and standard deviation are only computed for the first batch of
17 |     data. To ensure invertibility, the mean and standard deviation are kept
18 |     fixed from that point on.
19 |     Using ActNorm layers interspersed throughout an INN ensures that
20 |     intermediate outputs of the INN have standard deviation 1 and mean 0, so
21 |     that the training is stable at the start, avoiding exploding or zeroed
22 |     outputs.
23 |     Just as with standard batch normalization layers, ActNorm contains
24 |     additional channel-wise scaling and bias parameters.
25 |     '''
26 | 
27 |     def __init__(self, dims_in, dims_c=None, init_data: Union[torch.Tensor, None] = None):
28 |         '''
29 |         Args:
30 |           init_data: If ``None``, use the first batch of data passed through this
31 |             module to initialize the mean and standard deviation.
32 |             If ``torch.Tensor``, use this as data to initialize instead of the
33 |             first real batch.
34 |         '''
35 | 
36 |         super().__init__(dims_in, dims_c)
37 |         self.dims_in = dims_in[0]
38 |         param_dims = [1, self.dims_in[0]] + [1 for i in range(len(self.dims_in) - 1)]
39 |         self.scale = nn.Parameter(torch.zeros(*param_dims))
40 |         self.bias = nn.Parameter(torch.zeros(*param_dims))
41 | 
42 |         if init_data is not None:  # explicit None check: bool() of a multi-element tensor is ambiguous
43 |             self._initialize_with_data(init_data)
44 |         else:
45 |             self.init_on_next_batch = True
46 | 
47 |         def on_load_state_dict(*args):
48 |             # when this module is loading its state dict, we SHOULDN'T init with data,
49 |             # because that would reset the trained parameters. Register a hook
50 |             # that disables this initialisation.
51 |             self.init_on_next_batch = False
52 |         self._register_load_state_dict_pre_hook(on_load_state_dict)
53 | 
54 |     def _initialize_with_data(self, data):
55 |         # Initialize to mean 0 and std 1 with sample batch
56 |         # 'data' expected to be of shape (batch, channels[, ...])
57 |         assert all([data.shape[i+1] == self.dims_in[i] for i in range(len(self.dims_in))]),\
58 |             "Can't initialize ActNorm layer, provided data don't match input dimensions."
59 | self.scale.data.view(-1)[:] \ 60 | = torch.log(1 / data.transpose(0,1).contiguous().view(self.dims_in[0], -1).std(dim=-1)) 61 | data = data * self.scale.exp() 62 | self.bias.data.view(-1)[:] \ 63 | = -data.transpose(0,1).contiguous().view(self.dims_in[0], -1).mean(dim=-1) 64 | self.init_on_next_batch = False 65 | 66 | def forward(self, x, rev=False, jac=True): 67 | if self.init_on_next_batch: 68 | self._initialize_with_data(x[0]) 69 | 70 | jac = (self.scale.sum() * np.prod(self.dims_in[1:])).repeat(x[0].shape[0]) 71 | if rev: 72 | jac = -jac 73 | 74 | if not rev: 75 | return [x[0] * self.scale.exp() + self.bias], jac 76 | else: 77 | return [(x[0] - self.bias) / self.scale.exp()], jac 78 | 79 | def output_dims(self, input_dims): 80 | assert len(input_dims) == 1, "Can only use 1 input" 81 | return input_dims 82 | 83 | 84 | 85 | class IResNetLayer(InvertibleModule): 86 | """ 87 | Implementation of the i-ResNet architecture as proposed in 88 | https://arxiv.org/pdf/1811.00995.pdf 89 | """ 90 | 91 | def __init__(self, dims_in, dims_c=[], 92 | internal_size=None, 93 | n_internal_layers=1, 94 | jacobian_iterations=20, 95 | hutchinson_samples=1, 96 | fixed_point_iterations=50, 97 | lipschitz_iterations=10, 98 | lipschitz_batchsize=10, 99 | spectral_norm_max=0.8): 100 | 101 | super().__init__(dims_in, dims_c) 102 | 103 | if internal_size: 104 | self.internal_size = internal_size 105 | else: 106 | self.internal_size = 2 * dims_in[0][0] 107 | self.n_internal_layers = n_internal_layers 108 | self.jacobian_iterations = jacobian_iterations 109 | self.hutchinson_samples = hutchinson_samples 110 | self.fixed_point_iterations = fixed_point_iterations 111 | self.lipschitz_iterations = lipschitz_iterations 112 | self.lipschitz_batchsize = lipschitz_batchsize 113 | self.spectral_norm_max = spectral_norm_max 114 | assert 0 < spectral_norm_max <= 1, "spectral_norm_max must be in (0,1]." 
115 | 116 | self.dims_in = dims_in[0] 117 | if len(self.dims_in) == 1: 118 | # Linear case 119 | self.layers = [nn.Linear(self.dims_in[0], self.internal_size),] 120 | for i in range(self.n_internal_layers): 121 | self.layers.append(nn.Linear(self.internal_size, self.internal_size)) 122 | self.layers.append(nn.Linear(self.internal_size, self.dims_in[0])) 123 | else: 124 | # Convolutional case 125 | self.layers = [nn.Conv2d(self.dims_in[0], self.internal_size, 3, padding=1),] 126 | for i in range(self.n_internal_layers): 127 | self.layers.append(nn.Conv2d(self.internal_size, self.internal_size, 3, padding=1)) 128 | self.layers.append(nn.Conv2d(self.internal_size, self.dims_in[0], 3, padding=1)) 129 | elus = [nn.ELU() for i in range(len(self.layers))] 130 | module_list = sum(zip(self.layers, elus), ())[:-1] # interleaves the lists 131 | self.residual = nn.Sequential(*module_list) 132 | 133 | 134 | def lipschitz_correction(self): 135 | with torch.no_grad(): 136 | # Power method to approximate spectral norm 137 | # Following https://arxiv.org/pdf/1804.04368.pdf 138 | for i in range(len(self.layers)): 139 | W = self.layers[i].weight 140 | x = torch.randn(self.lipschitz_batchsize, W.shape[1], *self.dims_in[1:], device=W.device) 141 | 142 | if len(self.dims_in) == 1: 143 | # Linear case 144 | for j in range(self.lipschitz_iterations): 145 | x = W.t().matmul(W.matmul(x.unsqueeze(-1))).squeeze(-1) 146 | spectral_norm = (torch.norm(W.matmul(x.unsqueeze(-1)).squeeze(-1), dim=1) /\ 147 | torch.norm(x, dim=1)).max() 148 | else: 149 | # Convolutional case 150 | for j in range(self.lipschitz_iterations): 151 | x = conv2d(x, W) 152 | x = conv_transpose2d(x, W) 153 | spectral_norm = (torch.norm(conv2d(x, W).view(self.lipschitz_batchsize, -1), dim=1) /\ 154 | torch.norm(x.view(self.lipschitz_batchsize, -1), dim=1)).max() 155 | 156 | if spectral_norm > self.spectral_norm_max: 157 | self.layers[i].weight.data *= self.spectral_norm_max / spectral_norm 158 | 159 | 160 | def forward(self, x, c=[], rev=False, jac=True): 161 | if jac: 162 | jac = self._jacobian(x, c, rev=rev) 163 | else: 164 | jac = None 165 | 166 | if not rev: 167 | return [x[0] + self.residual(x[0])], jac 168 | else: 169 | # Fixed-point iteration (works if residual has Lipschitz constant < 1) 170 | y = x[0] 171 | with torch.no_grad(): 172 | x_hat = x[0] 173 | for i in range(self.fixed_point_iterations): 174 | x_hat = y - self.residual(x_hat) 175 | return [y - self.residual(x_hat.detach())], jac 176 | 177 | 178 | def _jacobian(self, x, c=[], rev=False): 179 | if rev: 180 | return -self._jacobian(x, c=c) 181 | 182 | # Initialize log determinant of Jacobian to zero 183 | batch_size = x[0].shape[0] 184 | logdet_J = x[0].new_zeros(batch_size) 185 | # Make sure we can get vector-Jacobian product w.r.t. 
x even if x is the network input
186 |         if x[0].is_leaf:
187 |             x[0].requires_grad = True
188 | 
189 |         # Sample random vectors for Hutchinson trace estimate
190 |         v_right = [torch.randn_like(x[0]).sign() for i in range(self.hutchinson_samples)]
191 |         v_left = [v.clone() for v in v_right]
192 | 
193 |         # Compute terms of power series
194 |         for k in range(1, self.jacobian_iterations+1):
195 |             # Estimate trace of Jacobian of residual branch
196 |             trace_est = []
197 |             for i in range(self.hutchinson_samples):
198 |                 # Compute vector-Jacobian product v.t() * J
199 |                 residual = self.residual(x[0])
200 |                 v_left[i] = torch.autograd.grad(outputs=[residual],
201 |                                                 inputs=x,
202 |                                                 grad_outputs=[v_left[i]])[0]
203 |                 trace_est.append(v_left[i].view(batch_size, 1, -1).matmul(v_right[i].view(batch_size, -1, 1)).squeeze(-1).squeeze(-1))
204 |             if len(trace_est) > 1:
205 |                 trace_est = torch.stack(trace_est).mean(dim=0)
206 |             else:
207 |                 trace_est = trace_est[0]
208 |             # Update power series approximation of log determinant for the whole block
209 |             logdet_J = logdet_J + (-1)**(k+1) * trace_est / k
210 | 
211 |         # # Shorter version when self.hutchinson_samples is fixed to one
212 |         # v_right = torch.randn_like(x[0])
213 |         # v_left = v_right.clone()
214 |         # residual = self.residual(x[0])
215 |         # for k in range(1, self.jacobian_iterations+1):
216 |         #     # Compute vector-Jacobian product v.t() * J
217 |         #     v_left = torch.autograd.grad(outputs=[residual],
218 |         #                                  inputs=x,
219 |         #                                  grad_outputs=[v_left],
220 |         #                                  retain_graph=(k < self.jacobian_iterations))[0]
221 |         #     # Iterate power series approximation of log determinant
222 |         #     trace_est = v_left.view(batch_size, 1, -1).matmul(v_right.view(batch_size, -1, 1)).squeeze(-1).squeeze(-1)
223 |         #     logdet_J = logdet_J + (-1)**(k+1) * trace_est / k
224 | 
225 |         return logdet_J
226 | 
227 | 
228 |     def output_dims(self, input_dims):
229 |         assert len(input_dims) == 1, "Can only use 1 input"
230 |         return input_dims
231 | 
--------------------------------------------------------------------------------
/FrEIA/modules/orthogonal.py:
--------------------------------------------------------------------------------
1 | from . import InvertibleModule
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | 
7 | def _fast_h(v, stride=2):
8 |     """
9 |     Fast product of a series of Householder matrices. This implementation is oriented towards the one introduced in:
10 |     https://invertibleworkshop.github.io/accepted_papers/pdfs/10.pdf
11 |     This makes use of method 2 in: https://ecommons.cornell.edu/bitstream/handle/1813/6521/85-681.pdf?sequence=1&isAllowed=y
12 | 
13 |     :param v: Batched series of Householder matrices. The last dim is the dim of one vector and the second last is the
14 |         number of elements in one product. This is the min amount of dims that need to be present.
15 |         All further ones are considered batch dimensions.
16 |     :param stride: Controls the number of parallel operations by the WY representation (see paper);
17 |         should not be larger than half the number of matrices in one product.
18 | :return: The batched product of Householder matrices defined by v 19 | """ 20 | assert v.ndim > 1 21 | assert stride <= v.shape[-2] 22 | 23 | d, m = v.shape[-2], v.shape[-1] 24 | k = d // stride 25 | last = k * stride 26 | v = v / torch.norm(v, dim=-1, p=2, keepdim=True) 27 | v = v.unsqueeze(-1) 28 | u = 2 * v 29 | ID = torch.eye(m, device=u.device) 30 | for dim in range(v.ndim-3): 31 | ID = ID.unsqueeze(0) 32 | 33 | # step 1 (compute intermediate groupings P_i) 34 | W = u[..., 0:last:stride, :, :] 35 | Y = v[..., 0:last:stride, :, :] 36 | 37 | for idx in range(1, stride): 38 | Pt = ID - torch.matmul(u[..., idx:last:stride, :, :], v[..., idx:last:stride, :, :].transpose(-1, -2)) 39 | W = torch.cat([W, u[..., idx:last:stride, :, :]], dim=-1) 40 | Y = torch.cat([torch.matmul(Pt, Y), v[..., idx:last:stride, :, :]], dim=-1) 41 | 42 | # step 2 (multiply the WY reps) 43 | P = ID - torch.matmul(W[..., k-1, :, :], Y[..., k-1, :, :].transpose(-1, -2)) 44 | for idx in reversed(range(0, k-1)): 45 | P = P - torch.matmul(W[..., idx, :, :], torch.matmul(Y[..., idx, :, :].transpose(-1, -2), P)) 46 | 47 | # deal with the residual, using a stride of 2 here maxes the amount of parallel ops 48 | if d > last: 49 | even_end = d if (d-last) % 2 == 0 else d - 1 50 | W_resi = u[..., last:even_end:2, :, :] 51 | Y_resi = v[..., last:even_end:2, :, :] 52 | for idx in range(last+1, d if d == last+1 else last+2): 53 | Pt = ID - torch.matmul(u[..., idx:even_end:2, :, :], v[..., idx:even_end:2, :, :].transpose(-1, -2)) 54 | W_resi = torch.cat([W_resi, u[..., idx:even_end:2, :, :]], dim=-1) 55 | Y_resi = torch.cat([torch.matmul(Pt, Y_resi), v[..., idx:even_end:2, :, :]], dim=-1) 56 | 57 | for idx in range(0, W_resi.shape[-3]): 58 | P = P - torch.matmul(P, torch.matmul(W_resi[..., idx, :, :], Y_resi[..., idx, :, :].transpose(-1, -2))) 59 | 60 | if even_end != d: 61 | P = P - torch.matmul(P, torch.matmul(u[..., -1, :, :], v[..., -1, :, :].transpose(-1, -2))) 62 | 63 | return P 64 | 65 | def orth_correction(R): 66 | R[0] /= torch.norm(R[0]) 67 | for i in range(1, R.shape[0]): 68 | 69 | R[i] -= torch.sum( R[:i].t() * torch.matmul(R[:i], R[i]), dim=1) 70 | R[i] /= torch.norm(R[i]) 71 | 72 | def correct_weights(module, grad_in, grad_out): 73 | 74 | module.back_counter += 1 75 | 76 | if module.back_counter > module.correction_interval: 77 | module.back_counter = np.random.randint(0, module.correction_interval) // 4 78 | orth_correction(module.weights.data) 79 | 80 | class OrthogonalTransform(InvertibleModule): 81 | '''Learnable orthogonal matrix, with additional scaling and bias term. 82 | 83 | The matrix is learned as a completely free weight matrix, and projected back 84 | to the Stiefel manifold (set of all orthogonal matrices) in regular intervals. 85 | With input x, the output z is computed as 86 | 87 | .. math:: 88 | 89 | z = \\Psi(s) \\odot Rx + b 90 | 91 | R is the orthogonal matrix, b the bias, s the scaling, and :math:`\\Psi` 92 | is a clamped scaling activation 93 | :math:`\\Psi(\\cdot) = \\exp(\\frac{2 \\alpha}{\\pi} \\mathrm{atan}(\\cdot))`. 94 | ''' 95 | 96 | def __init__(self, dims_in, dims_c=None, 97 | correction_interval: int = 256, 98 | clamp: float = 5.): 99 | ''' 100 | Args: 101 | 102 | correction_interval: After this many gradient steps, the matrix is 103 | projected back to the Stiefel manifold to make it perfectly orthogonal. 104 | clamp: clamps the log scaling for stability. Corresponds to 105 | :math:`alpha` above. 
106 |         '''
107 |         super().__init__(dims_in, dims_c)
108 |         self.width = dims_in[0][0]
109 |         self.clamp = clamp
110 | 
111 |         self.correction_interval = correction_interval
112 |         self.back_counter = np.random.randint(0, correction_interval) // 2
113 | 
114 |         self.weights = torch.randn(self.width, self.width)
115 |         self.weights = self.weights + self.weights.t()
116 |         self.weights, S, V = torch.svd(self.weights)
117 | 
118 |         self.weights = nn.Parameter(self.weights)
119 | 
120 |         self.bias = nn.Parameter(0.05 * torch.randn(1, self.width))
121 |         self.scaling = nn.Parameter(0.02 * torch.randn(1, self.width))
122 | 
123 |         self.register_backward_hook(correct_weights)
124 | 
125 |     def _log_e(self, s):
126 |         '''log of the nonlinear function e'''
127 |         return self.clamp * 0.636 * torch.atan(s/self.clamp)
128 | 
129 |     def forward(self, x, rev=False, jac=True):
130 |         log_scaling = self._log_e(self.scaling)
131 |         j = torch.sum(log_scaling, dim=1).expand(x[0].shape[0])
132 | 
133 |         if rev:
134 |             return [(x[0] * torch.exp(-log_scaling) - self.bias).mm(self.weights.t())], -j
135 |         return [(x[0].mm(self.weights) + self.bias) * torch.exp(log_scaling)], j
136 | 
137 |     def output_dims(self, input_dims):
138 |         if len(input_dims) != 1:
139 |             raise ValueError(f"{self.__class__.__name__} can only use 1 input")
140 |         if len(input_dims[0]) != 1:
141 |             raise ValueError(f"{self.__class__.__name__} input tensor must be 1D")
142 |         return input_dims
143 | 
144 | 
145 | class HouseholderPerm(InvertibleModule):
146 |     '''
147 |     Fast product of a series of learned Householder matrices.
148 |     This implementation is based on work by Mathiesen et al., 2020:
149 |     https://invertibleworkshop.github.io/accepted_papers/pdfs/10.pdf
150 |     Only works for flattened 1D input tensors.
151 | 
152 |     The module can be used in one of two ways:
153 | 
154 |     * Without a condition, the reflection vectors that form the householder
155 |       matrices are learned as free parameters
156 |     * Used as a conditional module, the condition contains the reflection vectors.
157 |       The module does not have any learnable parameters in that case, but the
158 |       condition can be backpropagated (e.g. to predict the reflection vectors by
159 |       some other network). The condition must have the shape
160 |       ``(input size, n_reflections)``.
161 |     '''
162 | 
163 |     def __init__(self, dims_in, dims_c=None,
164 |                  n_reflections: int = 2,
165 |                  fixed: bool = False):
166 |         '''
167 |         Args:
168 | 
169 |           n_reflections: How many subsequent householder reflections to perform.
170 |             Each householder reflection is learned independently.
171 |             Must be ``>= 2`` due to implementation reasons.
172 |           fixed: If true, the householder matrices are initialized randomly and
173 |             only computed once, and then kept fixed from there on.
174 | ''' 175 | super().__init__(dims_in, dims_c) 176 | self.width = dims_in[0][0] 177 | self.n_reflections = n_reflections 178 | self.fixed = fixed 179 | self.conditional = (not dims_c is None) and (len(dims_c) > 0) 180 | 181 | if self.n_reflections < 2: 182 | raise ValueError("Need at least 2 householder reflections.") 183 | 184 | if self.conditional: 185 | if len(dims_c) != 1: 186 | raise ValueError("No more than one conditional input supported.") 187 | if self.fixed: 188 | raise ValueError("Permutation can't be fixed and conditional simultaneously.") 189 | if np.prod(dims_c[0]) != self.width * self.n_reflections: 190 | raise ValueError("Dimensions of input, n_reflections and condition don't agree.") 191 | else: 192 | if self.fixed: 193 | # init randomly 194 | init = torch.randn(self.width, self.n_reflections) 195 | else: 196 | # init close to identity 197 | init = torch.eye(self.width, self.n_reflections) 198 | init += torch.randn_like(init) * 0.1 199 | Vs = init.transpose(-1, -2) 200 | self.Vs = nn.Parameter(Vs) 201 | 202 | Vs.requires_grad = not self.fixed 203 | self.register_parameter('Vs', self.Vs) 204 | 205 | if self.fixed: 206 | self.W = _fast_h(self.Vs) 207 | self.W = nn.Parameter(self.W, requires_grad=False) 208 | self.register_parameter('weight', self.W) 209 | 210 | def forward(self, x, c=[], rev=False, jac=True): 211 | 212 | if self.conditional: 213 | Vs = c[0].reshape(-1, self.width, self.n_reflections).transpose(-1, -2) 214 | W = _fast_h(Vs) 215 | else: 216 | if self.fixed: 217 | W = self.W 218 | else: 219 | W = _fast_h(self.Vs) 220 | 221 | if not rev: 222 | return [x[0].mm(W)], 0. 223 | else: 224 | return [x[0].mm(W.transpose(-1, -2))], 0. 225 | 226 | def output_dims(self, input_dims): 227 | if len(input_dims) != 1: 228 | raise ValueError(f"{self.__class__.__name__} can only use 1 input") 229 | if len(input_dims[0]) != 1: 230 | raise ValueError(f"{self.__class__.__name__} input tensor must be 1D") 231 | return input_dims 232 | -------------------------------------------------------------------------------- /FrEIA/modules/reshapes.py: -------------------------------------------------------------------------------- 1 | from . import InvertibleModule 2 | 3 | from warnings import warn 4 | from typing import Iterable 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class IRevNetDownsampling(InvertibleModule): 13 | '''The invertible spatial downsampling used in i-RevNet. 14 | Each group of four neighboring pixels is reordered into one pixel with four times 15 | the channels in a checkerboard-like pattern. See i-RevNet, Jacobsen 2018 et al. 16 | ''' 17 | 18 | def __init__(self, dims_in, dims_c=None, legacy_backend: bool = False): 19 | '''See docstring of base class (FrEIA_github.modules.InvertibleModule) for more. 20 | 21 | Args: 22 | legacy_backend: If True, uses the splitting and concatenating method, 23 | adapted from 24 | github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/model_utils.py 25 | for the use in FrEIA_github. Is usually slower on GPU. 26 | If False, uses a 2d strided convolution with a kernel representing 27 | the downsampling. Note that the ordering of the output channels 28 | will be different. 
If pixels in each patch in channel 1 29 | are ``a1, b1,...``, and in channel 2 are ``a2, b2,...`` 30 | Then the output channels will be the following: 31 | 32 | ``legacy_backend=True: a1, a2, ..., b1, b2, ..., c1, c2, ...`` 33 | 34 | ``legacy_backend=False: a1, b1, ..., a2, b2, ..., a3, b3, ...`` 35 | 36 | (see also order_by_wavelet in module HaarDownsampling) 37 | Generally this difference is completely irrelevant, 38 | unless a certaint subset of pixels or channels is supposed to be 39 | split off or extracted. 40 | ''' 41 | super().__init__(dims_in, dims_c) 42 | 43 | self.channels = dims_in[0][0] 44 | self.block_size = 2 45 | self.block_size_sq = self.block_size**2 46 | self.legacy_backend = legacy_backend 47 | 48 | if not self.legacy_backend: 49 | # this kernel represents the reshape: 50 | # it applies to 2x2 patches (stride 2), and transforms each 51 | # input channel to 4 channels. 52 | # The input value is transferred wherever the kernel is 1. 53 | # (hence the indexing pattern 00, 01, 10, 11 represents the 54 | # checkerboard. 55 | # For the upsampling, a transposed convolution is used for the 56 | # opposite effect. 57 | 58 | self.downsample_kernel = torch.zeros(4, 1, 2, 2) 59 | 60 | self.downsample_kernel[0, 0, 0, 0] = 1 61 | self.downsample_kernel[1, 0, 0, 1] = 1 62 | self.downsample_kernel[2, 0, 1, 0] = 1 63 | self.downsample_kernel[3, 0, 1, 1] = 1 64 | 65 | self.downsample_kernel = torch.cat([self.downsample_kernel] * self.channels, 0) 66 | self.downsample_kernel = nn.Parameter(self.downsample_kernel) 67 | self.downsample_kernel.requires_grad = False 68 | 69 | def forward(self, x, c=None, jac=True, rev=False): 70 | '''See docstring of base class (FrEIA_github.modules.InvertibleModule).''' 71 | input = x[0] 72 | if not rev: 73 | if self.legacy_backend: 74 | # only j.h. jacobsen understands how this works, 75 | # https://github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/model_utils.py 76 | output = input.permute(0, 2, 3, 1) 77 | 78 | (batch_size, s_height, s_width, s_depth) = output.size() 79 | d_depth = s_depth * self.block_size_sq 80 | d_height = s_height // self.block_size 81 | 82 | t_1 = output.split(self.block_size, dim=2) 83 | stack = [t_t.contiguous().view(batch_size, d_height, d_depth) 84 | for t_t in t_1] 85 | output = torch.stack(stack, 1) 86 | output = output.permute(0, 2, 1, 3) 87 | output = output.permute(0, 3, 1, 2) 88 | return (output.contiguous(),), 0. 89 | else: 90 | output = F.conv2d(input, self.downsample_kernel, stride=2, groups=self.channels) 91 | return (output,), 0. 92 | 93 | else: 94 | if self.legacy_backend: 95 | # only j.h. jacobsen understands how this works, 96 | # https://github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/model_utils.py 97 | output = input.permute(0, 2, 3, 1) 98 | (batch_size, d_height, d_width, d_depth) = output.size() 99 | s_depth = int(d_depth / self.block_size_sq) 100 | s_width = int(d_width * self.block_size) 101 | s_height = int(d_height * self.block_size) 102 | t_1 = output.contiguous().view(batch_size, d_height, d_width, 103 | self.block_size_sq, s_depth) 104 | spl = t_1.split(self.block_size, 3) 105 | stack = [t_t.contiguous().view(batch_size, d_height, s_width, 106 | s_depth) for t_t in spl] 107 | output = torch.stack(stack, 0).transpose(0, 1) 108 | output = output.permute(0, 2, 1, 3, 4).contiguous() 109 | output = output.view(batch_size, s_height, s_width, s_depth) 110 | output = output.permute(0, 3, 1, 2) 111 | return (output.contiguous(),), 0. 
112 |         else:
113 |             output = F.conv_transpose2d(input, self.downsample_kernel,
114 |                                         stride=2, groups=self.channels)
115 |             return (output,), 0.
116 | 
117 |     def output_dims(self, input_dims):
118 |         '''See docstring of base class (FrEIA_github.modules.InvertibleModule).'''
119 | 
120 |         if len(input_dims) != 1:
121 |             raise ValueError("i-RevNet downsampling must have exactly 1 input")
122 |         if len(input_dims[0]) != 3:
123 |             raise ValueError("i-RevNet downsampling can only transform 2D images "
124 |                              "of the shape CxWxH (channels, width, height)")
125 | 
126 |         c, w, h = input_dims[0]
127 |         c2, w2, h2 = c * 4, w // 2, h // 2
128 | 
129 |         if c * h * w != c2 * h2 * w2:
130 |             raise ValueError("Input cannot be cleanly reshaped, most likely because "
131 |                              "the input height or width is an odd number")
132 | 
133 |         return ((c2, w2, h2),)
134 | 
135 | 
136 | class IRevNetUpsampling(IRevNetDownsampling):
137 |     '''The inverted operation of IRevNetDownsampling (see that docstring for details).'''
138 | 
139 |     def __init__(self, dims_in, dims_c=None, legacy_backend: bool = False):
140 |         '''See docstring of base class (FrEIA_github.modules.InvertibleModule) for more.
141 | 
142 |         Args:
143 |           legacy_backend: If True, uses the splitting and concatenating method,
144 |             adapted from
145 |             github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/model_utils.py
146 |             for the use in FrEIA_github. Is usually slower on GPU.
147 |             If False, uses a 2d strided transposed convolution with a kernel
148 |             representing the downsampling. Note that the expected ordering of the
149 |             input channels will be different. If pixels in each output patch in
150 |             channel 1 are ``a1, b1,...``, and in channel 2 are ``a2, b2,...``,
151 |             then the expected input channels are the following:
152 | 
153 |             ``legacy_backend=True: a1, a2, ..., b1, b2, ..., c1, c2, ...``
154 | 
155 |             ``legacy_backend=False: a1, b1, ..., a2, b2, ..., a3, b3, ...``
156 | 
157 |             (see also order_by_wavelet in module HaarDownsampling)
158 |             Generally this difference is completely irrelevant,
159 |             unless a certain subset of pixels or channels is supposed to be
160 |             split off or extracted.
161 | ''' 162 | 163 | # have to initialize with the OUTPUT shape, because everything is 164 | # inherited from IRevNetDownsampling: 165 | inv_shape = self.output_dims(dims_in) 166 | super().__init__(inv_shape, dims_c, legacy_backend=legacy_backend) 167 | 168 | def forward(self, x, c=None, jac=True, rev=False): 169 | '''See docstring of base class (FrEIA_github.modules.InvertibleModule).''' 170 | return super().forward(x, c=None, rev=not rev) 171 | 172 | def output_dims(self, input_dims): 173 | '''See docstring of base class (FrEIA_github.modules.InvertibleModule).''' 174 | 175 | if len(input_dims) != 1: 176 | raise ValueError("i-RevNet downsampling must have exactly 1 input") 177 | if len(input_dims[0]) != 3: 178 | raise ValueError("i-RevNet downsampling can only transform 2D images" 179 | "of the shape cxwxh (channels, width, height)") 180 | 181 | c, w, h = input_dims[0] 182 | c2, w2, h2 = c // 4, w * 2, h * 2 183 | 184 | if c * h * w != c2 * h2 * w2: 185 | raise ValueError("input cannot be cleanly reshaped, most likely because" 186 | "the input height or width are an odd number") 187 | 188 | return ((c2, w2, h2),) 189 | 190 | 191 | class HaarDownsampling(InvertibleModule): 192 | '''Uses Haar wavelets to split each channel into 4 channels, with half the 193 | width and height dimensions.''' 194 | 195 | def __init__(self, dims_in, dims_c = None, 196 | order_by_wavelet: bool = False, 197 | rebalance: float = 1.): 198 | '''See docstring of base class (FrEIA_github.modules.InvertibleModule) for more. 199 | 200 | Args: 201 | order_by_wavelet: Whether to group the output by original channels or 202 | by wavelet. I.e. if the average, vertical, horizontal and diagonal 203 | wavelets for channel 1 are ``a1, v1, h1, d1``, those for channel 2 are 204 | ``a2, v2, h2, d2``, etc, then the output channels will be structured as 205 | follows: 206 | 207 | set to ``True: a1, a2, ..., v1, v2, ..., h1, h2, ..., d1, d2, ...`` 208 | 209 | set to ``False: a1, v1, h1, d1, a2, v2, h2, d2, ...`` 210 | 211 | The ``True`` option is slightly slower to compute than the ``False`` option. 212 | The option is useful if e.g. the average channels should be split 213 | off by a FrEIA_github.modules.Split. Then, setting ``order_by_wavelet=True`` 214 | allows to split off the first quarter of channels to isolate the 215 | average wavelets only. 216 | rebalance: Must be !=0. There exist different conventions how to define 217 | the Haar wavelets. The wavelet components in the forward direction 218 | are multiplied with this factor, and those in the inverse direction 219 | are adjusted accordingly, so that the module as a whole is 220 | invertible. Stability of the network may be increased for rebalance 221 | < 1 (e.g. 0.5). 222 | ''' 223 | super().__init__(dims_in, dims_c) 224 | 225 | if rebalance == 0: 226 | raise ValueError("'rebalance' argument must be != 0.") 227 | 228 | self.in_channels = dims_in[0][0] 229 | 230 | # self.jac_{fwd,rev} is the log Jacobian determinant for a single pixel 231 | # in a single channel computed explicitly from the matrix below. 232 | 233 | self.fac_fwd = 0.5 * rebalance 234 | self.jac_fwd = (np.log(16.) + 4 * np.log(self.fac_fwd)) / 4. 235 | 236 | self.fac_rev = 0.5 / rebalance 237 | self.jac_rev = (np.log(16.) + 4 * np.log(self.fac_rev)) / 4. 
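        # Explanatory note (a sketch of the arithmetic, not used by the module):
        # the unnormalized 4x4 Haar matrix built below has orthogonal rows of
        # norm 2, so |det| = 2**4 = 16. Scaling its four outputs by a factor fac
        # multiplies the determinant by fac**4, giving, per 2x2 patch and channel,
        # log|det| = log(16) + 4*log(fac); dividing by 4 yields the per-pixel
        # contribution that forward() multiplies by the number of entries.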
238 | 239 | # See https://en.wikipedia.org/wiki/Haar_wavelet#Haar_matrix 240 | # for an explanation of how this weight matrix comes about 241 | self.haar_weights = torch.ones(4, 1, 2, 2) 242 | 243 | self.haar_weights[1, 0, 0, 1] = -1 244 | self.haar_weights[1, 0, 1, 1] = -1 245 | 246 | self.haar_weights[2, 0, 1, 0] = -1 247 | self.haar_weights[2, 0, 1, 1] = -1 248 | 249 | self.haar_weights[3, 0, 1, 0] = -1 250 | self.haar_weights[3, 0, 0, 1] = -1 251 | 252 | self.haar_weights = torch.cat([self.haar_weights] * self.in_channels, 0) 253 | self.haar_weights = nn.Parameter(self.haar_weights) 254 | self.haar_weights.requires_grad = False 255 | 256 | # for 'order_by_wavelet', we just perform the channel-wise wavelet 257 | # transform as usual, and then permute the channels into the correct 258 | # order afterward (hence 'self.permute') 259 | self.permute = order_by_wavelet 260 | 261 | if self.permute: 262 | permutation = [] 263 | for i in range(4): 264 | permutation += [i + 4 * j for j in range(self.in_channels)] 265 | 266 | self.perm = torch.LongTensor(permutation) 267 | self.perm_inv = torch.LongTensor(permutation) 268 | 269 | # clever trick to invert a permutation 270 | for i, p in enumerate(self.perm): 271 | self.perm_inv[p] = i 272 | 273 | def forward(self, x, c=None, jac=True, rev=False): 274 | '''See docstring of base class (FrEIA_github.modules.InvertibleModule).''' 275 | 276 | inp = x[0] 277 | #number total entries except for batch dimension: 278 | ndims = inp[0].numel() 279 | 280 | if not rev: 281 | jac = ndims * self.jac_fwd 282 | out = F.conv2d(inp, self.haar_weights, 283 | bias=None, stride=2, groups=self.in_channels) 284 | 285 | if self.permute: 286 | return (out[:, self.perm] * self.fac_fwd,), jac 287 | else: 288 | return (out * self.fac_fwd,), jac 289 | 290 | else: 291 | jac = ndims * self.jac_rev 292 | if self.permute: 293 | x_perm = inp[:, self.perm_inv] 294 | else: 295 | x_perm = inp 296 | 297 | x_perm *= self.fac_rev 298 | out = F.conv_transpose2d(x_perm, self.haar_weights, stride=2, groups=self.in_channels) 299 | 300 | return (out,), jac 301 | 302 | def output_dims(self, input_dims): 303 | '''See docstring of base class (FrEIA_github.modules.InvertibleModule).''' 304 | 305 | if len(input_dims) != 1: 306 | raise ValueError("HaarDownsampling must have exactly 1 input") 307 | if len(input_dims[0]) != 3: 308 | raise ValueError("HaarDownsampling can only transform 2D images" 309 | "of the shape CxWxH (channels, width, height)") 310 | 311 | c, w, h = input_dims[0] 312 | c2, w2, h2 = c * 4, w // 2, h // 2 313 | 314 | if c * h * w != c2 * h2 * w2: 315 | raise ValueError("Input cannot be cleanly reshaped, most likely because" 316 | "the input height or width are an odd number") 317 | 318 | return ((c2, w2, h2),) 319 | 320 | 321 | class HaarUpsampling(HaarDownsampling): 322 | '''The inverted operation of HaarDownsampling (see that docstring for details).''' 323 | 324 | def __init__(self, dims_in, dims_c = None, 325 | order_by_wavelet: bool = False, 326 | rebalance: float = 1.): 327 | '''See docstring of base class (FrEIA_github.modules.InvertibleModule) for more. 328 | 329 | Args: 330 | order_by_wavelet: Expected grouping of the input channels by wavelet or 331 | by output channel. I.e. 
if the average, vertical, horizontal and diagonal
332 |             wavelets for channel 1 are ``a1, v1, h1, d1``, those for channel 2 are
333 |             ``a2, v2, h2, d2``, etc, then the input channels are taken as follows:
334 | 
335 |             set to ``True: a1, a2, ..., v1, v2, ..., h1, h2, ..., d1, d2, ...``
336 | 
337 |             set to ``False: a1, v1, h1, d1, a2, v2, h2, d2, ...``
338 | 
339 |             The ``True`` option is slightly slower to compute than the ``False`` option.
340 |             The option is useful if e.g. the input has been concatenated from average
341 |             channels and the higher-frequency channels. Then, setting
342 |             ``order_by_wavelet=True`` allows splitting off the first quarter of
343 |             channels to isolate the average wavelets only.
344 |           rebalance: Must be !=0. There exist different conventions for how to define
345 |             the Haar wavelets. The wavelet components in the forward direction
346 |             are multiplied with this factor, and those in the inverse direction
347 |             are adjusted accordingly, so that the module as a whole is
348 |             invertible. Stability of the network may be increased for rebalance
349 |             < 1 (e.g. 0.5).
350 |         '''
351 |         inv_shape = self.output_dims(dims_in)
352 |         super().__init__(inv_shape, dims_c, order_by_wavelet, rebalance)
353 | 
354 |     def forward(self, x, c=None, jac=True, rev=False):
355 |         '''See docstring of base class (FrEIA_github.modules.InvertibleModule).'''
356 |         return super().forward(x, c=None, rev=not rev)
357 | 
358 |     def output_dims(self, input_dims):
359 |         '''See docstring of base class (FrEIA_github.modules.InvertibleModule).'''
360 | 
361 |         if len(input_dims) != 1:
362 |             raise ValueError("HaarUpsampling must have exactly 1 input")
363 |         if len(input_dims[0]) != 3:
364 |             raise ValueError("HaarUpsampling can only transform 2D images "
365 |                              "of the shape CxWxH (channels, width, height)")
366 | 
367 |         c, w, h = input_dims[0]
368 |         c2, w2, h2 = c // 4, w * 2, h * 2
369 | 
370 |         if c * h * w != c2 * h2 * w2:
371 |             raise ValueError("Input cannot be cleanly reshaped, most likely because "
372 |                              "the input height or width is an odd number")
373 | 
374 |         return ((c2, w2, h2),)
375 | 
376 | 
377 | class Flatten(InvertibleModule):
378 |     '''Flattens N-D tensors into 1-D tensors.'''
379 | 
380 |     def __init__(self, dims_in, dims_c=None):
381 |         '''See docstring of base class (FrEIA_github.modules.InvertibleModule).'''
382 |         super().__init__(dims_in, dims_c)
383 | 
384 |         if len(dims_in) != 1:
385 |             raise ValueError("Flattening must have exactly 1 input")
386 | 
387 |         self.input_shape = dims_in[0]
388 |         self.output_shape = (int(np.prod(dims_in[0])),)
389 | 
390 |     def forward(self, x, c=None, jac=True, rev=False):
391 |         '''See docstring of base class (FrEIA_github.modules.InvertibleModule).'''
392 |         if not rev:
393 |             return (x[0].view(x[0].shape[0], -1),), 0.
394 |         else:
395 |             return (x[0].view(x[0].shape[0], *self.input_shape),), 0.
396 | 
397 |     def output_dims(self, input_dims):
398 |         '''See docstring of base class (FrEIA_github.modules.InvertibleModule).'''
399 |         return (self.output_shape,)
400 | 
401 | 
402 | class Reshape(InvertibleModule):
403 |     '''Reshapes N-D tensors into target dim tensors. Note that the reshape resulting from
404 |     e.g. (3, 32, 32) -> (12, 16, 16) will not necessarily be spatially sensible.
405 | See ``IRevNetDownsampling``, ``IRevNetUpsampling``, ``HaarDownsampling``, 406 | ``HaarUpsampling`` for spatially meaningful reshaping operations.''' 407 | 408 | def __init__(self, dims_in, dims_c=None, output_dims: Iterable[int] = None, target_dim = None): 409 | '''See docstring of base class (FrEIA_github.modules.InvertibleModule) for more. 410 | 411 | Args: 412 | output_dims: The shape the reshaped output is supposed to have (not 413 | including batch dimension) 414 | target_dim: Deprecated name for output_dims 415 | ''' 416 | super().__init__(dims_in, dims_c) 417 | 418 | if target_dim is not None: 419 | warn("Use the new name for the 'target_dim' argument: 'output_dims'" 420 | "the 'target_dim' argument will be removed in the next version") 421 | output_dims = target_dim 422 | 423 | if output_dims is None: 424 | raise ValueError("Please specify the desired output shape") 425 | 426 | self.size = dims_in[0] 427 | self.target_dim = output_dims 428 | 429 | if len(dims_in) != 1: 430 | raise ValueError("Reshape must have exactly 1 input") 431 | if int(np.prod(dims_in[0])) != int(np.prod(self.target_dim)): 432 | raise ValueError(f"Incoming dimensions {dims_in[0]} and target_dim" 433 | f"{self.target_dim} don't match." 434 | "Must have same number of elements for invertibility") 435 | 436 | def forward(self, x, c=None, jac=True, rev=False): 437 | '''See docstring of base class (FrEIA_github.modules.InvertibleModule).''' 438 | 439 | if not rev: 440 | return (x[0].reshape(x[0].shape[0], *self.target_dim),), 0. 441 | else: 442 | return (x[0].reshape(x[0].shape[0], *self.size),), 0. 443 | 444 | def output_dims(self, dim): 445 | '''See docstring of base class (FrEIA_github.modules.InvertibleModule).''' 446 | return (self.target_dim,) 447 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 AlessioGalluccio 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Attention!: This repo is not ready! Work in progress🚧 2 | # FastFlow 3 | 4 | An unofficial implementation of the architecture of FastFlow [(Jiawei Yu et al.)](https://arxiv.org/pdf/2111.07677v2.pdf). 
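At its core, the model maps extracted backbone features to a latent `z` whose mean squared value serves as the anomaly score. A minimal scoring sketch (assuming the defaults in `config.py` and a checkpoint saved with `save_weights` from `model.py`; see `evaluate.py` for the full version with test-time transforms):
```
import torch
import config as c
from model import FastFlow, load_weights

model = load_weights(FastFlow(), c.modelname)  # downloads the pretrained extractor on first use
model.eval()
x = torch.randn(1, *c.img_dims)                # stand-in for a normalized image batch
with torch.no_grad():
    z, log_jac_det = model(x)
print('anomaly score:', torch.mean(z ** 2).item())
```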
5 | Starting from [this](https://github.com/marco-rudolph/differnet) implementation of DifferNet by Marco Rudolph, I'm trying to create an easy-to-use implementation of FastFlow.
6 | 
7 | Python version >= 3.8
8 | 
9 | If you use neptune, create a file named `neptuneparams.py` and insert this code:
10 | ```
11 | project="insert_name_of_neptune_project_here"
12 | api_token="insert_token_here"
13 | ```
14 | These parameters are generated when you create a project on neptune, and you can find them there.
15 | If you are not interested in using neptune, you can comment out the neptune code and the import statements in `train.py`.
16 | 
17 | This project assumes that you have the MVTec dataset in the following path structure:
18 | ```
19 | - data:
20 |     - mvtec:
21 |         - hazelnut
22 |         - toothbrush
23 |         - ...
24 | ```
25 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | '''This file configures the training procedure because handling arguments in every single function is so exhausting for
2 | research purposes. Don't try this code if you are a software engineer.'''
3 | import torch
4 | 
5 | # device settings
6 | # 'cuda' or 'cpu'
7 | device = 'cpu'
8 | 
9 | if device == 'cuda':
10 |     torch.cuda.set_device(0)
11 | 
12 | # neptune
13 | neptune_activate = False
14 | 
15 | # data settings
16 | dataset_path = "dummy_dataset"
17 | class_name = "dummy_class"
18 | modelname = "dummy_test"
19 | 
20 | 
21 | # transformation settings
22 | transf_rotations = True
23 | transf_brightness = 0.0
24 | transf_contrast = 0.0
25 | transf_saturation = 0.0
26 | norm_mean, norm_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
27 | 
28 | # feature extractor
29 | # select "resnet18", "deit", or "cait"
30 | extractor_name = "deit"
31 | 
32 | # network hyperparameters
33 | n_scales = 1 # number of scales at which features are extracted, img_size is the highest - others are //2, //4,...
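# For reference: the n_feat values set below follow from the backbone geometry.
# E.g. for DeiT ('deit_base_distilled_patch16_384' in model.py), a 384x384 input
# with 16x16 patches gives 24x24 = 576 tokens of width 768, so n_feat = 24*24*768.
# (Explanatory sketch of the arithmetic only; not an additional setting.)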
34 | clamp_alpha = 3 # see paper (differnet) equation 2 for explanation
35 | clamp = 1.2 # clamp in convolutional layers
36 | n_coupling_blocks = 4
37 | #fc_internal = 2048 # number of neurons in hidden layers of s-t-networks
38 | dropout = 0.0 # dropout in s-t-networks
39 | lr_init = 2e-4
40 | subnet_conv_dim = 128 # internal dimension of the convolutional layers
41 | only_3x3_convolution = False # set all convolutional layers to have 3x3 convolutions
42 | 
43 | if extractor_name == "resnet18":
44 |     n_feat = 64*64*64*n_scales
45 |     img_size = (256, 256)
46 | elif extractor_name == "deit":
47 |     n_feat = 24*24*768*n_scales
48 |     img_size = (384, 384)
49 | elif extractor_name == "cait":
50 |     n_feat = 28*28*768*n_scales
51 |     img_size = (448, 448)
52 | else:
53 |     n_feat = 256 * n_scales # do not change unless you change the feature extractor
54 |     img_size = (448, 448)
55 | 
56 | img_dims = [3] + list(img_size)
57 | 
58 | # dataloader parameters
59 | n_transforms = 4 # number of transformations per sample in training
60 | n_transforms_test = 64 # number of transformations per sample in testing
61 | batch_size = 24 # actual batch size is this value multiplied by n_transforms(_test)
62 | batch_size_test = batch_size * n_transforms // n_transforms_test
63 | 
64 | # total epochs = meta_epochs * sub_epochs
65 | # evaluation after epochs
66 | meta_epochs = 24
67 | sub_epochs = 8
68 | 
69 | # output settings
70 | verbose = True
71 | grad_map_viz = False
72 | hide_tqdm_bar = False
73 | save_model = True
74 | 
75 | 
--------------------------------------------------------------------------------
/dummy_dataset/dummy_class/test/anomaly/.directory:
--------------------------------------------------------------------------------
1 | [Dolphin]
2 | Timestamp=2020,8,11,12,21,46
3 | Version=4
4 | ViewMode=1
5 | 
--------------------------------------------------------------------------------
/dummy_dataset/dummy_class/test/anomaly/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/anomaly/1.png
--------------------------------------------------------------------------------
/dummy_dataset/dummy_class/test/anomaly/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/anomaly/2.png
--------------------------------------------------------------------------------
/dummy_dataset/dummy_class/test/anomaly/3png.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/anomaly/3png.png
--------------------------------------------------------------------------------
/dummy_dataset/dummy_class/test/anomaly/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/anomaly/4.png
--------------------------------------------------------------------------------
/dummy_dataset/dummy_class/test/anomaly/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/anomaly/5.png -------------------------------------------------------------------------------- /dummy_dataset/dummy_class/test/anomaly/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/anomaly/6.png -------------------------------------------------------------------------------- /dummy_dataset/dummy_class/test/anomaly/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/anomaly/7.png -------------------------------------------------------------------------------- /dummy_dataset/dummy_class/test/anomaly/8png.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/anomaly/8png.png -------------------------------------------------------------------------------- /dummy_dataset/dummy_class/test/good/.directory: -------------------------------------------------------------------------------- 1 | [Dolphin] 2 | Timestamp=2020,8,11,12,21,40 3 | Version=4 4 | ViewMode=1 5 | -------------------------------------------------------------------------------- /dummy_dataset/dummy_class/test/good/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/good/1.png -------------------------------------------------------------------------------- /dummy_dataset/dummy_class/test/good/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/good/2.png -------------------------------------------------------------------------------- /dummy_dataset/dummy_class/test/good/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/good/3.png -------------------------------------------------------------------------------- /dummy_dataset/dummy_class/test/good/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/good/4.png -------------------------------------------------------------------------------- /dummy_dataset/dummy_class/test/good/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/good/5.png -------------------------------------------------------------------------------- /dummy_dataset/dummy_class/test/good/6.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/good/6.png
--------------------------------------------------------------------------------
/dummy_dataset/dummy_class/test/good/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/good/7.png
--------------------------------------------------------------------------------
/dummy_dataset/dummy_class/test/good/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/test/good/8.png
--------------------------------------------------------------------------------
/dummy_dataset/dummy_class/train/good/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/train/good/1.png
--------------------------------------------------------------------------------
/dummy_dataset/dummy_class/train/good/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/train/good/2.png
--------------------------------------------------------------------------------
/dummy_dataset/dummy_class/train/good/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/train/good/3.png
--------------------------------------------------------------------------------
/dummy_dataset/dummy_class/train/good/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlessioGalluccio/FastFlow/57bd6c02b347daaebd7dc29e47865b3ff64aeaa4/dummy_dataset/dummy_class/train/good/4.png
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | from model import load_model
2 | from os import listdir
3 | from os.path import join
4 | from utils import get_random_transforms, get_fixed_transforms
5 | from PIL import Image
6 | import config as c
7 | import torch
8 | 
9 | def get_anomaly_score(model, image_path, transforms):
10 |     img = Image.open(image_path).convert('RGB')
11 |     transformed_imgs = torch.stack([tf(img) for tf in transforms])
12 |     z, _ = model(transformed_imgs)  # the model returns (z, log_jac_det)
13 |     anomaly_score = torch.mean(z ** 2)
14 |     print("image: %s, score: %.2f" % (image_path, anomaly_score))
15 |     return anomaly_score
16 | 
17 | def evaluate(model_name, image_folder, fixed_transforms=True):
18 |     model = load_model(model_name)
19 |     files = listdir(image_folder)
20 | 
21 |     if fixed_transforms:
22 |         fixed_degrees = [i * 360.0 / c.n_transforms_test for i in range(c.n_transforms_test)]
23 |         transforms = [get_fixed_transforms(fd) for fd in fixed_degrees]
24 |     else:
25 |         transforms = [get_random_transforms()] * c.n_transforms_test
26 | 
27 |     for f in files:
28 |         get_anomaly_score(model, join(image_folder, f), transforms)
29 | 
30 | image_folder = 'dummy_dataset/dummy_class/train/good'
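# A hypothetical variation: point at the held-out anomalies from the dummy dataset
# instead and compare the printed scores; a higher mean z**2 indicates a more
# anomalous image.
# image_folder = 'dummy_dataset/dummy_class/test/anomaly'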
31 | evaluate(c.modelname, image_folder, fixed_transforms=True) -------------------------------------------------------------------------------- /handledata.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from distutils.dir_util import copy_tree 3 | import config as c 4 | import os 5 | import shutil 6 | 7 | def handledata(): 8 | Path("./dataset/"+ c.class_name).mkdir(parents=True, exist_ok=True) 9 | Path("./dataset/"+ c.class_name + "/train").mkdir(parents=True, exist_ok=True) 10 | Path("./dataset/"+ c.class_name + "/test").mkdir(parents=True, exist_ok=True) 11 | Path("./dataset/"+ c.class_name + "/train/good").mkdir(parents=True, exist_ok=True) 12 | Path("./dataset/"+ c.class_name + "/test/anomaly").mkdir(parents=True, exist_ok=True) 13 | Path("./dataset/"+ c.class_name + "/test/good").mkdir(parents=True, exist_ok=True) 14 | 15 | copy_directory("./data/mvtec/"+ c.class_name + "/train/good", "./dataset/"+ c.class_name + "/train/good") 16 | copy_directory("./data/mvtec/"+ c.class_name + "/test/good", "./dataset/"+ c.class_name + "/test/good") 17 | 18 | # TOOTHBRUSH 19 | #copy_directory("./data/mvtec/"+ c.class_name + "/test/defective", "./dataset/"+ c.class_name + "/test/anomaly") 20 | 21 | # CAPSULE 22 | #copy_directory("./data/mvtec/"+ c.class_name + "/test/crack", "./dataset/"+ c.class_name + "/test/anomaly") 23 | #copy_directory("./data/mvtec/"+ c.class_name + "/test/faulty_imprint", "./dataset/"+ c.class_name + "/test/anomaly") 24 | #copy_directory("./data/mvtec/"+ c.class_name + "/test/poke", "./dataset/"+ c.class_name + "/test/anomaly") 25 | #copy_directory("./data/mvtec/"+ c.class_name + "/test/scratch", "./dataset/"+ c.class_name + "/test/anomaly") 26 | #copy_directory("./data/mvtec/"+ c.class_name + "/test/squeeze", "./dataset/"+ c.class_name + "/test/anomaly") 27 | 28 | # GRID 29 | #copy_directory("./data/mvtec/"+ c.class_name + "/test/bent", "./dataset/"+ c.class_name + "/test/anomaly") 30 | #copy_directory("./data/mvtec/"+ c.class_name + "/test/broken", "./dataset/"+ c.class_name + "/test/anomaly") 31 | #copy_directory("./data/mvtec/"+ c.class_name + "/test/glue", "./dataset/"+ c.class_name + "/test/anomaly") 32 | #copy_directory("./data/mvtec/"+ c.class_name + "/test/metal_contamination", "./dataset/"+ c.class_name + "/test/anomaly") 33 | #copy_directory("./data/mvtec/"+ c.class_name + "/test/thread", "./dataset/"+ c.class_name + "/test/anomaly") 34 | 35 | # HAZELNUT 36 | copy_directory("./data/mvtec/"+ c.class_name + "/test/crack", "./dataset/"+ c.class_name + "/test/anomaly", "crack") 37 | copy_directory("./data/mvtec/"+ c.class_name + "/test/cut", "./dataset/"+ c.class_name + "/test/anomaly", "cut") 38 | copy_directory("./data/mvtec/"+ c.class_name + "/test/hole", "./dataset/"+ c.class_name + "/test/anomaly", "hole") 39 | copy_directory("./data/mvtec/"+ c.class_name + "/test/print", "./dataset/"+ c.class_name + "/test/anomaly", "print") 40 | 41 | 42 | 43 | def copy_directory(fromDirectory, toDirectory, tag = None): 44 | if tag == None: 45 | copy_tree(fromDirectory, toDirectory) 46 | else: 47 | for root, dirs, files in os.walk(fromDirectory): 48 | for filename in files: 49 | # I use absolute path, case you want to move several dirs. 
50 | source = os.path.join( os.path.abspath(root), filename ) 51 | 52 | # Separate base from extension 53 | base, extension = os.path.splitext(filename) 54 | 55 | # Initial new name 56 | new_name = os.path.join(toDirectory, base + "_" + tag + extension) 57 | 58 | shutil.copy(source, new_name) -------------------------------------------------------------------------------- /localization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.autograd import Variable 3 | import config as c 4 | from utils import * 5 | import matplotlib.pyplot as plt 6 | from tqdm import tqdm 7 | import os 8 | from scipy.ndimage import rotate, gaussian_filter 9 | 10 | GRADIENT_MAP_DIR = './gradient_maps/' 11 | 12 | 13 | def save_imgs(inputs, grad, cnt): 14 | export_dir = os.path.join(GRADIENT_MAP_DIR, c.modelname) 15 | if not os.path.exists(export_dir): 16 | os.makedirs(export_dir) 17 | 18 | for g in range(grad.shape[0]): 19 | normed_grad = (grad[g] - np.min(grad[g])) / ( 20 | np.max(grad[g]) - np.min(grad[g])) 21 | orig_image = inputs[g] 22 | for image, file_suffix in [(normed_grad, '_gradient_map.png'), (orig_image, '_orig.png')]: 23 | plt.clf() 24 | plt.imshow(image) 25 | plt.axis('off') 26 | plt.savefig(os.path.join(export_dir, str(cnt) + file_suffix), bbox_inches='tight', pad_inches=0) 27 | cnt += 1 28 | return cnt 29 | 30 | 31 | def export_gradient_maps(model, testloader, optimizer, n_batches=1): 32 | plt.figure(figsize=(10, 10)) 33 | testloader.dataset.get_fixed = True 34 | cnt = 0 35 | degrees = -1 * np.arange(c.n_transforms_test) * 360.0 / c.n_transforms_test 36 | 37 | # TODO n batches 38 | for i, data in enumerate(tqdm(testloader, disable=c.hide_tqdm_bar)): 39 | optimizer.zero_grad() 40 | inputs, labels = preprocess_batch(data) 41 | inputs = Variable(inputs, requires_grad=True) 42 | 43 | emb, log_jac_det = model(inputs) 44 | loss = get_loss(emb, log_jac_det) 45 | loss.backward() 46 | 47 | grad = inputs.grad.view(-1, c.n_transforms_test, *inputs.shape[-3:]) 48 | grad = grad[labels > 0] 49 | if grad.shape[0] == 0: 50 | continue 51 | grad = t2np(grad) 52 | 53 | inputs = inputs.view(-1, c.n_transforms_test, *inputs.shape[-3:])[:, 0] 54 | inputs = np.transpose(t2np(inputs[labels > 0]), [0, 2, 3, 1]) 55 | inputs_unnormed = np.clip(inputs * c.norm_std + c.norm_mean, 0, 1) 56 | 57 | for i_item in range(c.n_transforms_test): 58 | old_shape = grad[:, i_item].shape 59 | img = np.reshape(grad[:, i_item], [-1, *grad.shape[-2:]]) 60 | img = np.transpose(img, [1, 2, 0]) 61 | img = np.transpose(rotate(img, degrees[i_item], reshape=False), [2, 0, 1]) 62 | img = gaussian_filter(img, (0, 3, 3)) 63 | grad[:, i_item] = np.reshape(img, old_shape) 64 | 65 | grad = np.reshape(grad, [grad.shape[0], -1, *grad.shape[-2:]]) 66 | grad_img = np.mean(np.abs(grad), axis=1) 67 | grad_img_sq = grad_img ** 2 68 | 69 | cnt = save_imgs(inputs_unnormed, grad_img_sq, cnt) 70 | 71 | if i == n_batches: 72 | break 73 | 74 | plt.close() 75 | testloader.dataset.get_fixed = False 76 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | '''This is the repo which contains the original code to the WACV 2021 paper 2 | "Same Same But DifferNet: Semi-Supervised Defect Detection with Normalizing Flows" 3 | by Marco Rudolph, Bastian Wandt and Bodo Rosenhahn. 
4 | For further information contact Marco Rudolph (rudolph@tnt.uni-hannover.de)''' 5 | 6 | import config as c 7 | from train import train 8 | from utils import load_datasets, make_dataloaders 9 | import os 10 | 11 | import sys 12 | print("Python version") 13 | print (sys.version) 14 | print(os.listdir()) 15 | print(os.listdir("./data")) 16 | 17 | #I change the location where pytorch saves pretrained models 18 | os.environ['TORCH_HOME'] = 'models\\alexnet' #setting the environment variable 19 | 20 | #import torch 21 | #xcv = torch.hub.load('facebookresearch/deit:main', 'deit_base_distilled_patch16_224', pretrained=True) 22 | 23 | #manage dataset 24 | from handledata import handledata 25 | #handledata() 26 | 27 | train_set, test_set = load_datasets(c.dataset_path, c.class_name) 28 | train_loader, test_loader = make_dataloaders(train_set, test_set) 29 | model = train(train_loader, test_loader) 30 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from torchsummary import summary 5 | 6 | import config as c 7 | import FrEIA.modules as Fm 8 | import FrEIA.framework as Ff 9 | 10 | import torchvision.models as models 11 | 12 | import numpy as np 13 | 14 | WEIGHT_DIR = './weights' 15 | MODEL_DIR = './models' 16 | 17 | def subnet_conv_1(c_in, c_out): 18 | return nn.Sequential(nn.Conv2d(c_in, c.subnet_conv_dim, kernel_size=(1,1), padding='same'), 19 | nn.ReLU(), 20 | nn.Conv2d(c.subnet_conv_dim, c_out, kernel_size=(1,1), padding='same')) 21 | 22 | def subnet_conv_3(c_in, c_out): 23 | return nn.Sequential(nn.Conv2d(c_in, c.subnet_conv_dim, kernel_size=(3,3), padding='same'), 24 | nn.ReLU(), 25 | nn.Conv2d(c.subnet_conv_dim, c_out, kernel_size=(3,3), padding='same')) 26 | 27 | 28 | def nf_fast_flow(input_dim): 29 | nodes = list() 30 | 31 | nodes.append(Ff.InputNode(input_dim[0],input_dim[1], input_dim[2], name='input')) 32 | # I add blocks with 3x3 and 1x1 convolutions alternatively. 
Before them, I add a fixed permutation of the channels 33 | for k in range(c.n_coupling_blocks): 34 | # It permutes the first dimension, the channels 35 | ''' 36 | nodes.append(Ff.Node(nodes[-1], 37 | Fm.PermuteRandom, 38 | {'seed':k}, 39 | name=F'permute_high_res_{k}')) 40 | ''' 41 | if k % 2 == 0 or c.only_3x3_convolution: 42 | nodes.append(Ff.Node(nodes[-1], 43 | Fm.AllInOneBlock, 44 | {'subnet_constructor':subnet_conv_3, 'affine_clamping':c.clamp}, 45 | name=F'conv_high_res_{k}')) 46 | else: 47 | nodes.append(Ff.Node(nodes[-1], 48 | Fm.AllInOneBlock, 49 | {'subnet_constructor':subnet_conv_1, 'affine_clamping':c.clamp}, 50 | name=F'conv_high_res_{k}')) 51 | 52 | nodes.append(Ff.OutputNode(nodes[-1], name='output')) 53 | #print(nodes) 54 | coder = Ff.GraphINN(nodes) 55 | #print(coder) 56 | return coder 57 | 58 | 59 | class FastFlow(nn.Module): 60 | def __init__(self): 61 | super(FastFlow, self).__init__() 62 | 63 | if c.extractor_name == "resnet18": 64 | self.feature_extractor = models.resnet18(pretrained=True) 65 | # I take only the first blocks of the net, which has 64x64x64 as output 66 | self.feature_extractor = torch.nn.Sequential(*(list(self.feature_extractor.children())[:5])) 67 | 68 | # freeze the layers 69 | for param in self.feature_extractor.parameters(): 70 | param.requires_grad = False 71 | 72 | self.feature_extractor.to(c.device) 73 | print(summary(self.feature_extractor, (3,256,256), device=c.device)) 74 | #self.feature_extractor = torch.load('./pretrained/M48_448.pth') #sbagliato, carica solo i pesi, non il modello 75 | #self.feature_extractor.eval() # to deactivate the dropout layers 76 | 77 | # This input is unfortunately hardcoded. See the output dimensions of the feature extractor. 78 | # Don't add the batch size (first number) 79 | self.nf = nf_fast_flow((64,64,64)) 80 | 81 | elif c.extractor_name == "deit": 82 | self.feature_extractor = torch.hub.load('facebookresearch/deit:main', 'deit_base_distilled_patch16_384', pretrained=True) 83 | # I select the input layers and the first 7 blocks 84 | self.feature_extractor = torch.nn.Sequential(*list(self.feature_extractor.children())[:2], 85 | *list(list(self.feature_extractor.children())[2].children())[:7]) 86 | self.feature_extractor.to(c.device) 87 | # freeze the layers 88 | for param in self.feature_extractor.parameters(): 89 | param.requires_grad = False 90 | print(summary(self.feature_extractor, (3,384,384), device=c.device)) 91 | self.nf = nf_fast_flow((24,24,768)) 92 | 93 | elif c.extractor_name == "cait": 94 | self.feature_extractor = torch.hub.load('facebookresearch/deit:main', 'cait_M48', pretrained=True) 95 | self.feature_extractor.to(c.device) 96 | 97 | # how to print the first 5 Layerscale blocks (input layers are not included 98 | print(list(list(self.feature_extractor.children())[2].children())[:5]) 99 | 100 | # this network has a gigantic children called ModuleList, that's why we can't use only children() method to split the network 101 | # ModuleList contains many Layerscale blocks. 
We want to select only the first 20 ones 102 | # ModuleList content can be viewed in list(self.feature_extractor.children())[2] 103 | self.feature_extractor = torch.nn.Sequential(*list(self.feature_extractor.children())[:2], 104 | *list(list(self.feature_extractor.children())[2].children())[:20]) 105 | 106 | 107 | # freeze the layers 108 | for param in self.feature_extractor.parameters(): 109 | param.requires_grad = False 110 | 111 | 112 | 113 | print(summary(self.feature_extractor, (3,448,448), device=c.device)) 114 | self.nf = nf_fast_flow((28,28,768)) 115 | 116 | def forward(self, x): 117 | feat_s = self.feature_extractor(x) 118 | 119 | # I have to reshape the linearized output of deit back to a 2D image 120 | # From (576,768) to (24,24,768). The first number is the batch size 121 | if c.extractor_name == "deit": 122 | dim_batch = feat_s.size(dim=0) 123 | feat_s = feat_s.reshape(dim_batch,24,24,768) 124 | #print(feat_s.size()) 125 | 126 | # I have to reshape the linearized output of cait back to a 2D image 127 | if c.extractor_name == "cait": 128 | dim_batch = feat_s.size(dim=0) 129 | feat_s = feat_s.reshape(dim_batch,28,28,768) 130 | 131 | # Resnet doesn't need reshape 132 | 133 | z, log_jac_det = self.nf(feat_s) 134 | return z, log_jac_det 135 | 136 | 137 | 138 | def save_model(model, filename): 139 | if not os.path.exists(MODEL_DIR): 140 | os.makedirs(MODEL_DIR) 141 | torch.save(model, os.path.join(MODEL_DIR, filename)) 142 | 143 | 144 | def load_model(filename): 145 | path = os.path.join(MODEL_DIR, filename) 146 | model = torch.load(path) 147 | return model 148 | 149 | 150 | def save_weights(model, filename): 151 | if not os.path.exists(WEIGHT_DIR): 152 | os.makedirs(WEIGHT_DIR) 153 | torch.save(model.state_dict(), os.path.join(WEIGHT_DIR, filename)) 154 | 155 | 156 | def load_weights(model, filename): 157 | path = os.path.join(WEIGHT_DIR, filename) 158 | model.load_state_dict(torch.load(path)) 159 | return model 160 | -------------------------------------------------------------------------------- /multi_transform_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | from torchvision.datasets import DatasetFolder 4 | from torchvision.datasets.folder import make_dataset, pil_loader, default_loader, IMG_EXTENSIONS 5 | from torchvision.transforms.functional import rotate 6 | 7 | import config as c 8 | 9 | 10 | def fixed_rotation(self, sample, degrees): 11 | cust_rot = lambda x: rotate(x, degrees, False, False, None) 12 | augmentative_transforms = [cust_rot] 13 | if c.transf_brightness > 0.0 or c.transf_contrast > 0.0 or c.transf_saturation > 0.0: 14 | augmentative_transforms += [ 15 | transforms.ColorJitter(brightness=c.transf_brightness, contrast=c.transf_contrast, 16 | saturation=c.transf_saturation)] 17 | tfs = [transforms.Resize(c.img_size)] + augmentative_transforms + [transforms.ToTensor(), 18 | transforms.Normalize(c.norm_mean, 19 | c.norm_std)] 20 | return transforms.Compose(tfs)(sample) 21 | 22 | 23 | class DatasetFolderMultiTransform(DatasetFolder): 24 | """Adapts class DatasetFolder of PyTorch in a way that one sample is transformed several times. 
25 | Args: 26 | n_transforms (int): number of transformations per sample 27 | all others: see torchvision.datasets.DatasetFolder 28 | """ 29 | 30 | def __init__(self, root, loader, extensions=None, transform=None, 31 | target_transform=None, is_valid_file=None, n_transforms=1): 32 | super(DatasetFolderMultiTransform, self).__init__(root, loader, extensions=extensions, transform=transform, 33 | target_transform=target_transform) 34 | try: 35 | classes, class_to_idx = self.find_classes(self.root) 36 | except: 37 | classes, class_to_idx = self._find_classes(self.root) 38 | if is_valid_file is not None: 39 | extensions = None 40 | self.samples = make_dataset(self.root, class_to_idx, extensions) 41 | self.n_transforms = n_transforms 42 | self.get_fixed = False # set to true if the rotations should be fixed and regularly over 360 degrees 43 | self.fixed_degrees = [i * 360.0 / n_transforms for i in range(n_transforms)] 44 | 45 | def __getitem__(self, index): 46 | path, target = self.samples[index] 47 | sample = self.loader(path) 48 | if self.transform is not None: 49 | samples = list() 50 | for i in range(self.n_transforms): 51 | if self.get_fixed: 52 | samples.append(fixed_rotation(self, sample, self.fixed_degrees[i])) 53 | else: 54 | samples.append(self.transform(sample)) 55 | samples = torch.stack(samples, dim=0) 56 | if self.target_transform is not None: 57 | target = self.target_transform(target) 58 | return samples, target 59 | 60 | 61 | class ImageFolderMultiTransform(DatasetFolderMultiTransform): 62 | """Adapts class ImageFolder of PyTorch in a way that one sample can be transformed several times. 63 | Args: 64 | n_transforms (int): number of transformations per sample 65 | all others: see ImageFolder 66 | """ 67 | 68 | def __init__(self, root, transform=None, target_transform=None, 69 | loader=default_loader, is_valid_file=None, n_transforms=c.n_transforms): 70 | super(ImageFolderMultiTransform, self).__init__(root, loader, IMG_EXTENSIONS, 71 | transform=transform, 72 | target_transform=target_transform, 73 | is_valid_file=is_valid_file, n_transforms=n_transforms) 74 | self.imgs = self.samples 75 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn>=0.22 2 | scipy>=1.3.2 3 | numpy>=1.17.4 4 | torch>=1.10.0 5 | torchvision>=0.2.2 6 | matplotlib>=3.0.3 7 | tqdm>=4.40.2 8 | neptune-client 9 | timm==0.4.12 10 | torchsummary 11 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.metrics import roc_auc_score, average_precision_score 4 | from tqdm import tqdm 5 | 6 | from localization import export_gradient_maps 7 | from model import FastFlow, save_model, save_weights 8 | from utils import * 9 | 10 | 11 | import neptune.new as neptune # comment this statement if you don't use neptune 12 | import config as c 13 | import neptuneparams as nep_params # comment this statement if you don't use neptune 14 | 15 | # Neptune.ai set up, in order to keep track of your experiments 16 | if c.neptune_activate: 17 | run = neptune.init( 18 | project = nep_params.project, 19 | api_token = nep_params.api_token, 20 | ) # your credentials 21 | 22 | run["name_dataset"] = [c.dataset_path] 23 | run["img_dims"] = [c.img_dims] 24 | run["device"] = c.device 25 | run["n_scales"] = c.n_scales 26 | 
run["class_name"] = [c.class_name] 27 | run["meta_epochs"] = c.meta_epochs 28 | run["sub_epochs"] = c.sub_epochs 29 | run["batch_size"]= c.batch_size 30 | run["n_coupling_blocks"] = c.n_coupling_blocks 31 | run["n_transforms"] = c.n_transforms 32 | run["n_transforms_test"] = c.n_transforms_test 33 | run["dropout"] =c.dropout 34 | run["learning_rate"] = c.lr_init 35 | run["subnet_conv_dim"]= c.subnet_conv_dim 36 | 37 | 38 | 39 | 40 | 41 | class Score_Observer: 42 | '''Keeps an eye on the current and highest score so far''' 43 | 44 | def __init__(self, name): 45 | self.name = name 46 | self.max_epoch = 0 47 | self.max_score = None 48 | self.last = None 49 | 50 | def update(self, score, epoch, print_score=False): 51 | self.last = score 52 | if epoch == 0 or score > self.max_score: 53 | self.max_score = score 54 | self.max_epoch = epoch 55 | if print_score: 56 | self.print_score() 57 | 58 | def print_score(self): 59 | print('{:s}: \t last: {:.4f} \t max: {:.4f} \t epoch_max: {:d}'.format(self.name, self.last, self.max_score, 60 | self.max_epoch)) 61 | 62 | 63 | def train(train_loader, test_loader): 64 | model = FastFlow() 65 | optimizer = torch.optim.Adam(model.nf.parameters(), lr=c.lr_init, betas=(0.8, 0.8), eps=1e-04, weight_decay=1e-5) 66 | model.to(c.device) 67 | 68 | score_obs_auroc = Score_Observer('AUROC') 69 | score_obs_aucpr = Score_Observer('AUCPR') 70 | 71 | for epoch in range(c.meta_epochs): 72 | 73 | # train some epochs 74 | model.train() 75 | if c.verbose: 76 | print(F'\nTrain epoch {epoch}') 77 | for sub_epoch in range(c.sub_epochs): 78 | train_loss = list() 79 | for i, data in enumerate(tqdm(train_loader, disable=c.hide_tqdm_bar)): 80 | optimizer.zero_grad() 81 | inputs, labels = preprocess_batch(data) # move to device and reshape 82 | # TODO inspect 83 | # inputs += torch.randn(*inputs.shape).cuda() * c.add_img_noise 84 | 85 | z, log_jac_det = model(inputs) 86 | loss = get_loss(z, log_jac_det) 87 | train_loss.append(t2np(loss)) 88 | loss.backward() 89 | optimizer.step() 90 | 91 | mean_train_loss = np.mean(train_loss) 92 | if c.verbose: 93 | print('Epoch: {:d}.{:d} \t train loss: {:.4f}'.format(epoch, sub_epoch, mean_train_loss)) 94 | if c.neptune_activate: 95 | run["train/train_loss"].log(mean_train_loss) 96 | 97 | # evaluate 98 | model.eval() 99 | if c.verbose: 100 | print('\nCompute loss and scores on test set:') 101 | test_loss = list() 102 | test_z = list() 103 | test_labels = list() 104 | anomaly_score = list() 105 | with torch.no_grad(): 106 | for i, data in enumerate(tqdm(test_loader, disable=c.hide_tqdm_bar)): 107 | inputs, labels = preprocess_batch(data) 108 | z, log_jac_det = model(inputs) 109 | # Why do I compute the loss also for defective images, which will have great loss values? 
63 | def train(train_loader, test_loader):
64 |     model = FastFlow()
65 |     optimizer = torch.optim.Adam(model.nf.parameters(), lr=c.lr_init, betas=(0.8, 0.8), eps=1e-04, weight_decay=1e-5)
66 |     model.to(c.device)
67 | 
68 |     score_obs_auroc = Score_Observer('AUROC')
69 |     score_obs_aucpr = Score_Observer('AUCPR')
70 | 
71 |     for epoch in range(c.meta_epochs):
72 | 
73 |         # train some epochs
74 |         model.train()
75 |         if c.verbose:
76 |             print(f'\nTrain epoch {epoch}')
77 |         for sub_epoch in range(c.sub_epochs):
78 |             train_loss = list()
79 |             for i, data in enumerate(tqdm(train_loader, disable=c.hide_tqdm_bar)):
80 |                 optimizer.zero_grad()
81 |                 inputs, labels = preprocess_batch(data)  # move to device and reshape
82 |                 # TODO inspect
83 |                 # inputs += torch.randn(*inputs.shape).cuda() * c.add_img_noise
84 | 
85 |                 z, log_jac_det = model(inputs)
86 |                 loss = get_loss(z, log_jac_det)
87 |                 train_loss.append(t2np(loss))
88 |                 loss.backward()
89 |                 optimizer.step()
90 | 
91 |             mean_train_loss = np.mean(train_loss)
92 |             if c.verbose:
93 |                 print('Epoch: {:d}.{:d} \t train loss: {:.4f}'.format(epoch, sub_epoch, mean_train_loss))
94 |             if c.neptune_activate:
95 |                 run["train/train_loss"].log(mean_train_loss)
96 | 
97 |         # evaluate
98 |         model.eval()
99 |         if c.verbose:
100 |             print('\nCompute loss and scores on test set:')
101 |         test_loss = list()
102 |         test_z = list()
103 |         test_labels = list()
104 |         anomaly_score = list()
105 |         with torch.no_grad():
106 |             for i, data in enumerate(tqdm(test_loader, disable=c.hide_tqdm_bar)):
107 |                 inputs, labels = preprocess_batch(data)
108 |                 z, log_jac_det = model(inputs)
109 |                 # Why do I compute the loss also for defective images, which will have large loss values?
110 |                 loss = get_loss(z, log_jac_det)
111 |                 test_z.append(z)
112 |                 test_loss.append(t2np(loss))
113 |                 test_labels.append(t2np(labels))
114 | 
115 |                 # the anomaly scores are computed per batch here in order to use less GPU memory
116 |                 z_grouped_temp = z.view(-1, c.n_transforms_test, c.n_feat)
117 |                 anomaly_score.append(t2np(torch.mean(z_grouped_temp ** 2, dim=(-2, -1))))
118 | 
119 | 
120 | 
121 |         test_loss_good = list()
122 |         test_loss_defective = list()
123 |         for i in range(len(test_labels)):
124 |             if np.all(test_labels[i] == 0):  # batch contains only good samples (label 0) TODO eliminate magic numbers
125 |                 test_loss_good.append(test_loss[i])
126 |             else:
127 |                 test_loss_defective.append(-test_loss[i])  # note: defective batches are accumulated with a flipped sign
128 |         test_loss_good = np.mean(np.array(test_loss_good))
129 |         test_loss_defective = np.mean(np.array(test_loss_defective))
130 | 
131 |         test_loss = np.mean(np.array(test_loss))
132 |         if c.verbose:
133 |             print('Epoch: {:d} \t test_loss: {:.4f} \t test_loss_good: {:.4f} \t test_loss_defective: {:.4f}'.format(epoch, test_loss, test_loss_good, test_loss_defective))
134 | 
135 |         test_labels = np.concatenate(test_labels)
136 |         is_anomaly = np.array([0 if l == 0 else 1 for l in test_labels])
137 |         # concatenate the per-batch anomaly scores computed in the loop above
138 |         anomaly_score = np.concatenate(anomaly_score)
139 |         score_obs_auroc.update(roc_auc_score(is_anomaly, anomaly_score), epoch,
140 |                                print_score=c.verbose or epoch == c.meta_epochs - 1)
141 |         score_obs_aucpr.update(average_precision_score(is_anomaly, anomaly_score), epoch,
142 |                                print_score=c.verbose or epoch == c.meta_epochs - 1)
143 | 
144 |         if c.neptune_activate:
145 |             run["train/auroc"].log(score_obs_auroc.last)
146 |             run["train/aucpr"].log(score_obs_aucpr.last)
147 |             run["train/test_loss"].log(test_loss)
148 |             run["train/test_loss_good"].log(test_loss_good)
149 |             run["train/test_loss_defective"].log(test_loss_defective)
150 | 
151 | 
152 | 
153 | 
154 | 
155 |     if c.grad_map_viz:
156 |         export_gradient_maps(model, test_loader, optimizer, -1)
157 | 
158 |     if c.save_model:
159 |         model.to('cpu')
160 |         save_model(model, c.modelname)
161 |         save_weights(model, c.modelname)
162 |     return model
163 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import DataLoader
4 | from torchvision import datasets, transforms
5 | from torchvision.transforms.functional import rotate
6 | import config as c
7 | from multi_transform_loader import ImageFolderMultiTransform
8 | 
9 | 
10 | def get_random_transforms():
11 |     augmentative_transforms = []
12 |     if c.transf_rotations:
13 |         augmentative_transforms += [transforms.RandomRotation(180)]
14 |     if c.transf_brightness > 0.0 or c.transf_contrast > 0.0 or c.transf_saturation > 0.0:
15 |         augmentative_transforms += [transforms.ColorJitter(brightness=c.transf_brightness, contrast=c.transf_contrast,
16 |                                                            saturation=c.transf_saturation)]
17 | 
18 |     tfs = [transforms.Resize(c.img_size)] + augmentative_transforms + [transforms.ToTensor(),
19 |                                                                        transforms.Normalize(c.norm_mean, c.norm_std)]
20 | 
21 |     transform_train = transforms.Compose(tfs)
22 |     return transform_train
23 | 
24 | 
25 | def get_fixed_transforms(degrees):
26 |     cust_rot = lambda x: rotate(x, degrees, False, False, None)
27 |     augmentative_transforms = [cust_rot]
28 |     if c.transf_brightness > 0.0 or c.transf_contrast > 0.0 or c.transf_saturation > 0.0:
29 |         augmentative_transforms += [
30 |             transforms.ColorJitter(brightness=c.transf_brightness, contrast=c.transf_contrast,
31 |                                    saturation=c.transf_saturation)]
32 |     tfs = [transforms.Resize(c.img_size)] + augmentative_transforms + [transforms.ToTensor(),
33 |                                                                        transforms.Normalize(c.norm_mean,
34 |                                                                                             c.norm_std)]
35 |     return transforms.Compose(tfs)
36 | 
37 | # tensor to numpy
38 | def t2np(tensor):
39 |     '''pytorch tensor -> numpy array'''
40 |     return tensor.cpu().data.numpy() if tensor is not None else None
41 | 
42 | 
43 | def get_loss(z, jac):
44 |     '''negative log-likelihood loss; see Eq. 4 of the paper for why this makes sense (constant terms are dropped and the scaling is simplified here)'''
45 |     return torch.mean(0.5 * torch.sum(z ** 2, dim=(1, 2, 3)) - jac) / z.shape[1]
46 | 
47 | 
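# In math terms, get_loss above is the mean negative log-likelihood of a
# normalizing flow under a standard-normal latent prior (Eq. 4 of the paper),
# up to additive constants and a division by the channel count z.shape[1]:
#
#     -log p_X(x) = 0.5 * ||z||^2 - log|det(dz/dx)| + const
#
# Quick sanity check: z = 0 with a zero log-determinant gives zero loss.
#
#   z, jac = torch.zeros(2, 8, 7, 7), torch.zeros(2)
#   get_loss(z, jac)   # -> tensor(0.)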
48 | def load_datasets(dataset_path, class_name):
49 |     '''
50 |     Expected folder/file format for finding the anomalies of class 'class_name' at the dataset location:
51 | 
52 |     train data:
53 | 
54 |         dataset_path/class_name/train/good/any_filename.png
55 |         dataset_path/class_name/train/good/another_filename.tif
56 |         dataset_path/class_name/train/good/xyz.png
57 |         [...]
58 | 
59 |     test data:
60 | 
61 |         'normal data' = non-anomalies
62 | 
63 |         dataset_path/class_name/test/good/name_the_file_as_you_like_as_long_as_there_is_an_image_extension.webp
64 |         dataset_path/class_name/test/good/did_you_know_the_image_extension_webp?.png
65 |         dataset_path/class_name/test/good/did_you_know_that_filenames_may_contain_question_marks????.png
66 |         dataset_path/class_name/test/good/dont_know_how_it_is_with_windows.png
67 |         dataset_path/class_name/test/good/just_dont_use_windows_for_this.png
68 |         [...]
69 | 
70 |         anomalies - assume there are anomaly classes 'crack' and 'curved':
71 | 
72 |         dataset_path/class_name/test/crack/dat_crack_damn.png
73 |         dataset_path/class_name/test/crack/let_it_crack.png
74 |         dataset_path/class_name/test/crack/writing_docs_is_fun.png
75 |         [...]
76 | 
77 |         dataset_path/class_name/test/curved/wont_make_a_difference_if_you_put_all_anomalies_in_one_class.png
78 |         dataset_path/class_name/test/curved/but_this_code_is_practicable_for_the_mvtec_dataset.png
79 |         [...]
80 |     '''
81 | 
82 |     def target_transform(target):
83 |         return class_perm[target]
84 | 
85 |     data_dir_train = os.path.join(dataset_path, class_name, 'train')
86 |     data_dir_test = os.path.join(dataset_path, class_name, 'test')
87 | 
88 |     classes = os.listdir(data_dir_test)
89 |     if 'good' not in classes:
90 |         print('There should exist a subdirectory "good". Read the docstring of this function for further information.')
91 |         exit()
92 |     classes.sort()
93 |     class_perm = list()
94 |     class_idx = 1
95 |     for cl in classes:
96 |         if cl == 'good':
97 |             class_perm.append(0)
98 |         else:
99 |             class_perm.append(class_idx)
100 |             class_idx += 1
101 | 
102 |     transform_train = get_random_transforms()
103 | 
104 |     trainset = ImageFolderMultiTransform(data_dir_train, transform=transform_train, n_transforms=c.n_transforms)
105 |     testset = ImageFolderMultiTransform(data_dir_test, transform=transform_train, target_transform=target_transform,
106 |                                         n_transforms=c.n_transforms_test)
107 |     return trainset, testset
108 | 
109 | 
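# Putting the helpers together (illustrative paths, not part of the original
# file; the folder layout is the one documented in load_datasets above, and
# make_dataloaders / preprocess_batch are defined below):
#
#   trainset, testset = load_datasets('./dataset', 'some_class')
#   trainloader, testloader = make_dataloaders(trainset, testset)
#   inputs, labels = preprocess_batch(next(iter(trainloader)))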
110 | def make_dataloaders(trainset, testset):
111 |     trainloader = torch.utils.data.DataLoader(trainset, pin_memory=True, batch_size=c.batch_size, shuffle=True,
112 |                                               drop_last=False)
113 |     testloader = torch.utils.data.DataLoader(testset, pin_memory=True, batch_size=c.batch_size_test, shuffle=True,
114 |                                              drop_last=False)
115 |     return trainloader, testloader
116 | 
117 | 
118 | def preprocess_batch(data):
119 |     '''move data to device and reshape the image tensor'''
120 |     inputs, labels = data
121 |     inputs, labels = inputs.to(c.device), labels.to(c.device)
122 |     inputs = inputs.view(-1, *inputs.shape[-3:])  # (B, n_transforms, C, H, W) -> (B * n_transforms, C, H, W)
123 |     return inputs, labels
124 | 
--------------------------------------------------------------------------------