├── LICENSE.txt
├── README.md
├── models
│   ├── Densenet.py
│   ├── __init__.py
│   ├── contractables.py
│   ├── lotenet.py
│   └── mps.py
├── train.py
└── utils
    ├── __init__.py
    ├── lidc_dataset.py
    ├── model.png
    ├── tools.py
    └── utils.py

--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Raghavendra Selvan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# README #

This is the official PyTorch implementation of the LoTeNet model in
"[Tensor Networks for Medical Image Classification](https://openreview.net/forum?id=jjk6bxk07G)", Raghavendra Selvan & Erik Dam, MIDL 2020. Runner-up for the best paper award at [MIDL 2020](https://2020.midl.io/papers/selvan20.html).

![lotenet](utils/model.png)

### What is this repository for? ###

* Running and reproducing the results in the paper on the LIDC dataset
* v1.0

### How do I get set up? ###

* Basic PyTorch dependency; tested with PyTorch 1.3, Python 3.6
* Download the data from [here](https://bitbucket.org/raghavian/lotenet_pytorch/src/master/data/lidc.zip)
* Unzip the data and pass its location to --data_path
* How to run tests: `python train.py --data_path data_location`

### Usage guidelines ###

* Kindly cite our publication if you use any part of the code

```
@inproceedings{
raghav2020tensor,
title={Tensor Networks for Medical Image Classification},
author={Raghavendra Selvan and Erik B Dam},
booktitle={International Conference on Medical Imaging with Deep Learning -- Full Paper Track},
year={2020},
month={July},
url={https://openreview.net/forum?id=jjk6bxk07G}}
```

### Who do I talk to? ###

* raghav@di.ku.dk

### Thanks to the following repositories on which we base our project: ###

* [TorchMPS](https://github.com/jemisjoky/TorchMPS/) for the excellent MPS implementations in PyTorch
* [Prob. U-Net](https://github.com/stefanknegt/Probabilistic-Unet-Pytorch) for preprocessing the LIDC data
* [DenseNet](https://github.com/bamos/densenet.pytorch/) implementation
--------------------------------------------------------------------------------
/models/Densenet.py:
--------------------------------------------------------------------------------
### Adapted from https://github.com/bamos/densenet.pytorch

import torch

import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F
from torch.autograd import Variable

import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import torchvision.models as models

import sys
import math
import pdb

class Bottleneck(nn.Module):
    def __init__(self, nChannels, growthRate):
        super(Bottleneck, self).__init__()
        interChannels = 4*growthRate
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(interChannels)
        self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3,
                               padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        out = torch.cat((x, out), 1)
        return out

class SingleLayer(nn.Module):
    def __init__(self, nChannels, growthRate):
        super(SingleLayer, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3,
                               padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = torch.cat((x, out), 1)
        return out

class Transition(nn.Module):
    def __init__(self, nChannels, nOutChannels):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1,
                               bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = F.avg_pool2d(out, 2)
        return out


class DenseNet(nn.Module):
    def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
        super(DenseNet, self).__init__()

        nDenseBlocks = (depth-4) // 3
        if bottleneck:
            nDenseBlocks //= 2

        nChannels = 2*growthRate
        self.conv1 = nn.Conv2d(1, nChannels, kernel_size=3, padding=1,
                               bias=False)
        self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks*growthRate
        nOutChannels = int(math.floor(nChannels*reduction))
        self.trans1 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks*growthRate
        nOutChannels = int(math.floor(nChannels*reduction))
        self.trans2 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks*growthRate

        self.bn1 = nn.BatchNorm2d(nChannels)
        self.fc = nn.Linear(nChannels, nClasses)

        self.nChannels = nChannels
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
        layers = []
        for i in range(int(nDenseBlocks)):
            if bottleneck:
                layers.append(Bottleneck(nChannels, growthRate))
            else:
                layers.append(SingleLayer(nChannels, growthRate))
            nChannels += growthRate
        return nn.Sequential(*layers)

    def forward(self, x):
        # pdb.set_trace()
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.dense3(out)
        out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8))
        out = out.view(x.shape[0], self.nChannels, -1).mean(2)
        out = torch.sigmoid(self.fc(out))
        return out.squeeze()
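A minimal smoke test for the DenseNet above (an illustrative sketch, not part of the repository; the hyperparameters are arbitrary). Note that conv1 is hard-coded to single-channel input, and the final 8x8 average pooling assumes a 32x32 input here:

```python
import torch
from models.Densenet import DenseNet

# depth=13 with bottleneck=True gives (13-4)//3//2 = 1 Bottleneck per dense block
net = DenseNet(growthRate=12, depth=13, reduction=0.5, nClasses=1,
               bottleneck=True)
x = torch.rand(2, 1, 32, 32)   # conv1 expects single-channel input (as in LIDC)
y = net(x)                     # sigmoid output, squeezed to [batch]
print(y.shape)                 # torch.Size([2])
```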
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raghavian/loTeNet_pytorch/9b9814b3dea92c80be7fb0f26616cca93155e035/models/__init__.py
--------------------------------------------------------------------------------
/models/contractables.py:
--------------------------------------------------------------------------------
import torch

class Contractable:
    """
    Container for tensors with labeled indices and a global batch size

    The labels for our indices give some high-level knowledge of the tensor
    layout, and permit the contraction of pairs of indices in a more
    systematic manner. However, much of the actual heavy lifting is done
    through specific contraction routines in different subclasses

    Attributes:
        tensor (Tensor):    A Pytorch tensor whose first index is a batch
                            index. Sub-classes of Contractable may put other
                            restrictions on tensor
        bond_str (str):     A string whose letters each label a separate mode
                            of our tensor, and whose length equals the order
                            (number of modes) of our tensor
        global_bs (int):    The batch size associated with all Contractables.
                            This is shared between all Contractable instances
                            and allows for automatic expanding of tensors
    """
    # The global batch size
    global_bs = None

    def __init__(self, tensor, bond_str):
        shape = list(tensor.shape)
        num_dim = len(shape)
        str_len = len(bond_str)

        global_bs = Contractable.global_bs
        batch_dim = tensor.size(0)

        # Expand along a new batch dimension if needed
        if ('b' not in bond_str and str_len == num_dim) or \
           ('b' == bond_str[0] and str_len == num_dim + 1):
            if global_bs is not None:
                tensor = tensor.unsqueeze(0).expand([global_bs] + shape)
            else:
                raise RuntimeError("No batch size given and no previous "
                                   "batch size set")
            if bond_str[0] != 'b':
                bond_str = 'b' + bond_str

        # Check for correct formatting in bond_str
        elif bond_str[0] != 'b' or str_len != num_dim:
            raise ValueError(f"Length of bond string '{bond_str}' "
                             f"({len(bond_str)}) must match order of "
                             f"tensor ({len(shape)})")

        # Set the global batch size if it is unset or needs to be updated
        elif global_bs is None or global_bs != batch_dim:
            Contractable.global_bs = batch_dim

        # Check that global batch size agrees with input tensor's first dim
        # (note: unreachable, since the branch above already updates the
        # global batch size on a mismatch)
        elif global_bs != batch_dim:
            raise RuntimeError(f"Batch size previously set to {global_bs}"
                               ", but input tensor has batch size "
                               f"{batch_dim}")

        # Set the defining attributes of our Contractable
        self.tensor = tensor
        self.bond_str = bond_str
    def __mul__(self, contractable, rmul=False):
        """
        Multiply with another contractable along a linear index

        The default behavior is to multiply the 'r' index of this instance
        with the 'l' index of contractable, matching the batch ('b')
        index of both, and take the outer product of other indices.
        If rmul is True, contractable is instead multiplied on the left.
        """
        # This method works for general Core subclasses besides Scalar (no 'l'
        # and 'r' indices), composite contractables (no tensor attribute), and
        # MatRegion (multiplication isn't just simple index contraction)
        if isinstance(contractable, Scalar) or \
           not hasattr(contractable, 'tensor') or \
           type(contractable) is MatRegion:
            return NotImplemented

        tensors = [self.tensor, contractable.tensor]
        bond_strs = [list(self.bond_str), list(contractable.bond_str)]
        lowercases = [chr(c) for c in range(ord('a'), ord('z')+1)]

        # Reverse the order of tensors if needed
        if rmul:
            tensors = tensors[::-1]
            bond_strs = bond_strs[::-1]

        # Check that bond strings are in proper format
        for i, bs in enumerate(bond_strs):
            assert bs[0] == 'b'
            assert len(set(bs)) == len(bs)
            assert all([c in lowercases for c in bs])
            assert (i == 0 and 'r' in bs) or (i == 1 and 'l' in bs)

        # Get used and free characters
        used_chars = set(bond_strs[0]).union(bond_strs[1])
        free_chars = [c for c in lowercases if c not in used_chars]

        # Rename overlapping indices in the bond strings (except 'b', 'l', 'r')
        specials = ['b', 'l', 'r']
        for i, c in enumerate(bond_strs[1]):
            if c in bond_strs[0] and c not in specials:
                bond_strs[1][i] = free_chars.pop()

        # Combine right bond of left tensor and left bond of right tensor
        sum_char = free_chars.pop()
        bond_strs[0][bond_strs[0].index('r')] = sum_char
        bond_strs[1][bond_strs[1].index('l')] = sum_char
        specials.append(sum_char)

        # Build bond string of output tensor
        out_str = ['b']
        for bs in bond_strs:
            out_str.extend([c for c in bs if c not in specials])
        out_str.append('l' if 'l' in bond_strs[0] else '')
        out_str.append('r' if 'r' in bond_strs[1] else '')

        # Build the einsum string for this operation
        bond_strs = [''.join(bs) for bs in bond_strs]
        out_str = ''.join(out_str)
        ein_str = f"{bond_strs[0]},{bond_strs[1]}->{out_str}"

        # Contract along the linear dimension to get an output tensor
        out_tensor = torch.einsum(ein_str, [tensors[0], tensors[1]])

        # Return our output tensor wrapped in an appropriate class
        if out_str == 'br':
            return EdgeVec(out_tensor, is_left_vec=True)
        elif out_str == 'bl':
            return EdgeVec(out_tensor, is_left_vec=False)
        elif out_str == 'blr':
            return SingleMat(out_tensor)
        elif out_str == 'bolr':
            return OutputCore(out_tensor)
        else:
            return Contractable(out_tensor, out_str)

    def __rmul__(self, contractable):
        """
        Multiply with another contractable along a linear index
        """
        return self.__mul__(contractable, rmul=True)
    def reduce(self):
        """
        Return the contractable without any modification

        reduce() can be any method which returns a contractable. This is
        trivially possible for any contractable by returning itself
        """
        return self
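An illustrative sketch of the index bookkeeping above (not part of the source file): multiplying two batches of matrices contracts the 'r' bond of the left operand with the 'l' bond of the right one, and the result is re-wrapped in the matching subclass:

```python
import torch
from models.contractables import Contractable, SingleMat

Contractable.global_bs = None          # reset the shared batch size
a = SingleMat(torch.randn(8, 5, 5))    # bond_str 'blr'
b = SingleMat(torch.randn(8, 5, 5))
prod = a * b                           # einsum 'blz,bzr->blr' under the hood
print(type(prod).__name__, prod.tensor.shape)   # SingleMat torch.Size([8, 5, 5])
```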
class ContractableList(Contractable):
    """
    A list of contractables which can all be multiplied together in order

    Calling reduce on a ContractableList instance will first reduce every item
    to a linear contractable, and then contract everything together
    """
    def __init__(self, contractable_list):
        # Check that input list is nonempty and has contractables as entries
        if not isinstance(contractable_list, list) or contractable_list == []:
            raise ValueError("Input to ContractableList must be nonempty list")
        for i, item in enumerate(contractable_list):
            if not isinstance(item, Contractable):
                raise ValueError("Input items to ContractableList must be "
                                 f"Contractable instances, but item {i} is not")

        self.contractable_list = contractable_list

    def __mul__(self, contractable, rmul=False):
        """
        Multiply a contractable by everything in ContractableList in order
        """
        # The input cannot be a composite contractable
        assert hasattr(contractable, 'tensor')
        output = contractable.tensor

        # Multiply by everything in ContractableList, in the correct order
        if rmul:
            for item in self.contractable_list:
                output = item * output
        else:
            for item in self.contractable_list[::-1]:
                output = output * item

        return output

    def __rmul__(self, contractable):
        """
        Multiply another contractable by everything in ContractableList
        """
        return self.__mul__(contractable, rmul=True)

    def reduce(self, parallel_eval=False):
        """
        Reduce all the contractables in list before multiplying them together
        """
        c_list = self.contractable_list
        # For parallel_eval, reduce all contractables in c_list
        if parallel_eval:
            c_list = [item.reduce() for item in c_list]

        # Multiply together all the contractables. This multiplies in right-to-
        # left order, but certain inefficient contractions are unsupported.
        # If we encounter an unsupported operation, then try multiplying from
        # the left end of the list instead
        while len(c_list) > 1:
            try:
                c_list[-2] = c_list[-2] * c_list[-1]
                del c_list[-1]
            except TypeError:
                c_list[1] = c_list[0] * c_list[1]
                del c_list[0]

        return c_list[0]

class MatRegion(Contractable):
    """
    A contiguous collection of matrices which are multiplied together

    The input tensor defining our MatRegion must have shape
    [batch_size, num_mats, D, D], or [num_mats, D, D] when the global batch
    size is already known
    """
    def __init__(self, mats):
        shape = list(mats.shape)
        if len(shape) not in [3, 4] or shape[-2] != shape[-1]:
            raise ValueError("MatRegion tensors must have shape "
                             "[batch_size, num_mats, D, D], or [num_mats,"
                             " D, D] if batch size has already been set")

        super().__init__(mats, bond_str='bslr')

    def __mul__(self, edge_vec, rmul=False):
        """
        Iteratively multiply an input vector with all matrices in MatRegion
        """
        # The input must be an instance of EdgeVec
        if not isinstance(edge_vec, EdgeVec):
            return NotImplemented

        mats = self.tensor
        num_mats = mats.size(1)
        batch_size = mats.size(0)

        # Load our vector and matrix batches
        dummy_ind = 1 if rmul else 2
        vec = edge_vec.tensor.unsqueeze(dummy_ind)
        mat_list = [mat.squeeze(1) for mat in torch.chunk(mats, num_mats, 1)]

        # Do the repeated matrix-vector multiplications in the proper order
        for mat in mat_list[::(1 if rmul else -1)]:
            if rmul:
                vec = torch.bmm(vec, mat)
            else:
                vec = torch.bmm(mat, vec)

        # Since we only have a single vector, wrap it as an EdgeVec
        return EdgeVec(vec.squeeze(dummy_ind), is_left_vec=rmul)

    def __rmul__(self, edge_vec):
        return self.__mul__(edge_vec, rmul=True)

    def reduce(self):
        """
        Multiplies together all matrices and returns resultant SingleMat

        This method uses iterated batch multiplication to evaluate the full
        matrix product in depth O(log(num_mats))
        """
        mats = self.tensor
        shape = list(mats.shape)
        batch_size = mats.size(0)
        size, D = shape[1:3]

        # Iteratively multiply pairs of matrices until there is only one
        while size > 1:
            odd_size = (size % 2 == 1)
            half_size = size // 2
            nice_size = 2 * half_size

            even_mats = mats[:, 0:nice_size:2]
            odd_mats = mats[:, 1:nice_size:2]
            # For odd sizes, set aside one batch of matrices for the next round
            leftover = mats[:, nice_size:]

            # Multiply together all pairs of matrices (except leftovers)
            mats = torch.einsum('bslu,bsur->bslr', [even_mats, odd_mats])
            mats = torch.cat([mats, leftover], 1)

            size = half_size + int(odd_size)

        # Since we only have a single matrix, wrap it as a SingleMat
        return SingleMat(mats.squeeze(1))
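A sketch checking the log-depth pairwise reduction against a plain sequential product (illustrative, not from the repository):

```python
import torch
from models.contractables import Contractable, MatRegion

Contractable.global_bs = None
mats = torch.randn(4, 7, 3, 3) / 3**0.5       # [batch, num_mats, D, D]
single = MatRegion(mats).reduce()             # pairwise product, O(log 7) depth

ref = mats[:, 0]                              # sequential left-to-right product
for i in range(1, mats.size(1)):
    ref = torch.bmm(ref, mats[:, i])
print(torch.allclose(single.tensor, ref, atol=1e-4))   # True
```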
class OutputCore(Contractable):
    """
    A single MPS core with a single output index
    """
    def __init__(self, tensor):
        # Check the input shape
        if len(tensor.shape) not in [3, 4]:
            raise ValueError("OutputCore tensors must have shape [batch_size, "
                             "output_dim, D_l, D_r], or else [output_dim, D_l,"
                             " D_r] if batch size has already been set")

        super().__init__(tensor, bond_str='bolr')

class SingleMat(Contractable):
    """
    A batch of matrices associated with a single location in our MPS
    """
    def __init__(self, mat):
        # Check the input shape
        if len(mat.shape) not in [2, 3]:
            raise ValueError("SingleMat tensors must have shape [batch_size, "
                             "D_l, D_r], or else [D_l, D_r] if batch size "
                             "has already been set")

        super().__init__(mat, bond_str='blr')

class OutputMat(Contractable):
    """
    An output core associated with an edge of our MPS
    """
    def __init__(self, mat, is_left_mat):
        # Check the input shape
        if len(mat.shape) not in [2, 3]:
            raise ValueError("OutputMat tensors must have shape [batch_size, "
                             "D, output_dim], or else [D, output_dim] if "
                             "batch size has already been set")

        # OutputMats on left edge will have a right-facing bond, and vice versa
        bond_str = 'b' + ('r' if is_left_mat else 'l') + 'o'
        super().__init__(mat, bond_str=bond_str)

    def __mul__(self, edge_vec, rmul=False):
        """
        Multiply with an edge vector along the shared linear index
        """
        if not isinstance(edge_vec, EdgeVec):
            return NotImplemented
        else:
            return super().__mul__(edge_vec, rmul)

    def __rmul__(self, edge_vec):
        return self.__mul__(edge_vec, rmul=True)

class EdgeVec(Contractable):
    """
    A batch of vectors associated with an edge of our MPS

    EdgeVec instances are always associated with an edge of an MPS, which
    requires the is_left_vec flag to be set to True (vector on left edge) or
    False (vector on right edge)
    """
    def __init__(self, vec, is_left_vec):
        # Check the input shape
        if len(vec.shape) not in [1, 2]:
            raise ValueError("EdgeVec tensors must have shape "
                             "[batch_size, D], or else [D] if batch size "
                             "has already been set")

        # EdgeVecs on left edge will have a right-facing bond, and vice versa
        bond_str = 'b' + ('r' if is_left_vec else 'l')
        super().__init__(vec, bond_str=bond_str)

    def __mul__(self, right_vec):
        """
        Take the inner product of our vector with another vector
        """
        # The input must be an instance of EdgeVec
        if not isinstance(right_vec, EdgeVec):
            return NotImplemented

        left_vec = self.tensor.unsqueeze(1)
        right_vec = right_vec.tensor.unsqueeze(2)
        batch_size = left_vec.size(0)

        # Do the batch inner product
        scalar = torch.bmm(left_vec, right_vec).view([batch_size])

        # Since we only have a single scalar, wrap it as a Scalar
        return Scalar(scalar)

class Scalar(Contractable):
    """
    A batch of scalars
    """
    def __init__(self, scalar):
        # Add a dummy dimension if we have a torch scalar
        shape = list(scalar.shape)
        if shape == []:
            scalar = scalar.view([1])
            shape = [1]

        # Check the input shape
        if len(shape) != 1:
            raise ValueError("input scalar must be a torch tensor with shape "
                             "[batch_size], or [] or [1] if batch size has "
                             "been set")

        super().__init__(scalar, bond_str='b')

    def __mul__(self, contractable):
        """
        Multiply a contractable by our scalar and return the result
        """
        scalar = self.tensor
        tensor = contractable.tensor
        bond_str = contractable.bond_str

        ein_string = f"{bond_str},b->{bond_str}"
        out_tensor = torch.einsum(ein_string, [tensor, scalar])

        # Wrap the result in the same class contractable belongs to
        contract_class = type(contractable)
        if contract_class is not Contractable:
            return contract_class(out_tensor)
        else:
            return Contractable(out_tensor, bond_str)

    def __rmul__(self, contractable):
        # Scalar multiplication is commutative
        return self.__mul__(contractable)
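Putting the pieces together (an illustrative sketch, not part of the source): dummy edge vectors absorbed through a MatRegion collapse a full chain into a batch of scalars, which is how LinearRegion in models/mps.py evaluates an MPS with open boundaries:

```python
import torch
from models.contractables import (Contractable, ContractableList,
                                  MatRegion, EdgeVec)

Contractable.global_bs = None
batch, n, D = 2, 5, 4
chain = ContractableList([
    EdgeVec(torch.randn(batch, D), is_left_vec=True),    # 'br'
    MatRegion(torch.randn(batch, n, D, D)),              # 'bslr'
    EdgeVec(torch.randn(batch, D), is_left_vec=False),   # 'bl'
])
out = chain.reduce()
print(type(out).__name__, out.tensor.shape)   # Scalar torch.Size([2])
```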
--------------------------------------------------------------------------------
/models/lotenet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.mps import MPS
import pdb

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

EPS = 1e-6

class loTeNet(nn.Module):
    def __init__(self, input_dim, output_dim, bond_dim, feature_dim=2, nCh=3,
                 kernel=2, virtual_dim=1,
                 adaptive_mode=False, periodic_bc=False, parallel_eval=False,
                 label_site=None, path=None, init_std=1e-9, use_bias=True,
                 fixed_bias=True, cutoff=1e-10, merge_threshold=2000):
        super().__init__()
        self.input_dim = input_dim
        self.virtual_dim = bond_dim

        ### Squeezing of spatial dimensions in first step
        self.kScale = 4
        nCh = self.kScale**2 * nCh
        self.input_dim = self.input_dim // self.kScale

        self.nCh = nCh
        self.ker = kernel
        iDim = (self.input_dim // (self.ker))

        feature_dim = 2*nCh

        ### First level MPS blocks
        self.module1 = nn.ModuleList([MPS(input_dim=(self.ker)**2,
                                          output_dim=self.virtual_dim,
                                          nCh=nCh, bond_dim=bond_dim,
                                          feature_dim=feature_dim,
                                          parallel_eval=parallel_eval,
                                          adaptive_mode=adaptive_mode,
                                          periodic_bc=periodic_bc)
                                      for i in range(torch.prod(iDim))])

        self.BN1 = nn.BatchNorm1d(self.virtual_dim, affine=True)

        iDim = iDim // self.ker
        feature_dim = 2*self.virtual_dim

        ### Second level MPS blocks
        self.module2 = nn.ModuleList([MPS(input_dim=self.ker**2,
                                          output_dim=self.virtual_dim,
                                          nCh=self.virtual_dim, bond_dim=bond_dim,
                                          feature_dim=feature_dim,
                                          parallel_eval=parallel_eval,
                                          adaptive_mode=adaptive_mode,
                                          periodic_bc=periodic_bc)
                                      for i in range(torch.prod(iDim))])

        self.BN2 = nn.BatchNorm1d(self.virtual_dim, affine=True)

        iDim = iDim // self.ker

        ### Third level MPS blocks
        self.module3 = nn.ModuleList([MPS(input_dim=self.ker**2,
                                          output_dim=self.virtual_dim,
                                          nCh=self.virtual_dim, bond_dim=bond_dim,
                                          feature_dim=feature_dim,
                                          parallel_eval=parallel_eval,
                                          adaptive_mode=adaptive_mode,
                                          periodic_bc=periodic_bc)
                                      for i in range(torch.prod(iDim))])

        self.BN3 = nn.BatchNorm1d(self.virtual_dim, affine=True)

        ### Final MPS block
        self.mpsFinal = MPS(input_dim=len(self.module3),
                            output_dim=output_dim, nCh=1,
                            bond_dim=bond_dim, feature_dim=feature_dim,
                            adaptive_mode=adaptive_mode, periodic_bc=periodic_bc,
                            parallel_eval=parallel_eval)

    def forward(self, x):
        b = x.shape[0]  # Batch size

        # Increase input feature channels by squeezing spatial dimensions
        iDim = self.input_dim
        if self.kScale > 1:
            x = x.unfold(2, iDim[0], iDim[0]).unfold(3, iDim[1], iDim[1])
            x = x.reshape(b, iDim[0], iDim[1], -1)

        # Level 1 contraction
        iDim = self.input_dim // (self.ker)
        x = x.unfold(2, iDim[0], iDim[0]).unfold(3, iDim[1], iDim[1]).reshape(
                b, self.nCh, (self.ker)**2, -1)
        y = [self.module1[i](x[:, :, :, i]) for i in range(len(self.module1))]
        y = torch.stack(y, dim=2)
        y = self.BN1(y)

        # Level 2 contraction
        y = y.view(b, self.virtual_dim, iDim[0], iDim[1])
        iDim = (iDim // self.ker)
        y = y.unfold(2, iDim[0], iDim[0]).unfold(3, iDim[1], iDim[1]).reshape(
                b, self.virtual_dim, self.ker**2, -1)
        x = [self.module2[i](y[:, :, :, i]) for i in range(len(self.module2))]
        x = torch.stack(x, dim=2)
        x = self.BN2(x)

        # Level 3 contraction
        x = x.view(b, self.virtual_dim, iDim[0], iDim[1])
        iDim = (iDim // self.ker)
        x = x.unfold(2, iDim[0], iDim[0]).unfold(3, iDim[1], iDim[1]).reshape(
                b, self.virtual_dim, self.ker**2, -1)
        y = [self.module3[i](x[:, :, :, i]) for i in range(len(self.module3))]

        y = torch.stack(y, dim=2)
        y = self.BN3(y)

        if y.shape[1] > 1:
            # Final layer
            y = self.mpsFinal(y)

        return y.view(b)
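An illustrative forward pass (not part of the repository; the sizes are chosen only so the three levels divide evenly). input_dim is passed as a torch tensor of spatial dimensions, since the constructor indexes it and applies torch.prod; nCh=1 matches single-channel LIDC patches:

```python
import torch
from models.lotenet import loTeNet

# 64x64 input: the 4x4 squeeze gives a 16x16 grid, and the three kernel-2
# levels use 8x8 -> 4x4 -> 2x2 grids of MPS blocks before the final MPS
model = loTeNet(input_dim=torch.tensor([64, 64]), output_dim=1,
                bond_dim=3, nCh=1)
model.eval()                                  # use BatchNorm running stats
x = torch.rand(2, 1, 64, 64)
x = x.to(next(model.parameters()).device)     # modules self-assign to CUDA if present
with torch.no_grad():
    y = model(x)
print(y.shape)                                # torch.Size([2])
```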
--------------------------------------------------------------------------------
/models/mps.py:
--------------------------------------------------------------------------------
### Adapted from https://github.com/jemisjoky/TorchMPS/
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.utils import init_tensor, svd_flex
from models.contractables import SingleMat, MatRegion, OutputCore, ContractableList, \
                                 EdgeVec, OutputMat
import pdb
from numpy import pi as PI
from numpy import sqrt
from scipy.special import comb

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class MPS(nn.Module):
    """
    Matrix product state which converts input into a single output vector
    """
    def __init__(self, input_dim, output_dim, bond_dim, feature_dim=2, nCh=3,
                 adaptive_mode=False, periodic_bc=False, parallel_eval=False,
                 label_site=None, path=None, init_std=1e-9, use_bias=True,
                 fixed_bias=True, cutoff=1e-10, merge_threshold=2000):
        super().__init__()

        org_input_dim = input_dim
        if label_site is None:
            label_site = input_dim // 2
        assert label_site >= 0 and label_site <= input_dim

        # Using bias matrices in adaptive_mode is too complicated, so I'm
        # disabling it here
        if adaptive_mode:
            use_bias = False

        # Our MPS is made of two InputRegions separated by an OutputSite.
        module_list = []
        init_args = {'bond_str': 'slri',
                     'shape': [label_site, bond_dim, bond_dim, feature_dim],
                     'init_method': ('min_random_eye' if adaptive_mode else
                                     'random_zero', init_std, output_dim)}

        # The first input region
        if label_site > 0:
            tensor = init_tensor(**init_args)

            module_list.append(InputRegion(tensor, use_bias=use_bias,
                                           fixed_bias=fixed_bias))

        # The output site
        tensor = init_tensor(shape=[output_dim, bond_dim, bond_dim],
                             bond_str='olr',
                             init_method=('min_random_eye' if adaptive_mode else
                                          'random_eye', init_std, output_dim))
        module_list.append(OutputSite(tensor))

        # The other input region
        if label_site < input_dim:
            init_args['shape'] = [input_dim-label_site, bond_dim, bond_dim,
                                  feature_dim]
            tensor = init_tensor(**init_args)
            module_list.append(InputRegion(tensor, use_bias=use_bias,
                                           fixed_bias=fixed_bias))

        # Initialize linear_region according to our adaptive_mode specification
        if adaptive_mode:
            self.linear_region = MergedLinearRegion(module_list=module_list,
                                                    periodic_bc=periodic_bc,
                                                    parallel_eval=parallel_eval,
                                                    cutoff=cutoff,
                                                    merge_threshold=merge_threshold)

            # Initialize the list of bond dimensions, which starts out constant
            self.bond_list = bond_dim * torch.ones(input_dim + 2,
                                                   dtype=torch.long)
            if not periodic_bc:
                self.bond_list[0], self.bond_list[-1] = 1, 1

            # Initialize the list of singular values, which start out at -1
            self.sv_list = -1. * torch.ones([input_dim + 2, bond_dim])

        else:
            self.linear_region = LinearRegion(module_list=module_list,
                                              periodic_bc=periodic_bc,
                                              parallel_eval=parallel_eval)
        assert len(self.linear_region) == input_dim

        if path:
            assert isinstance(path, (list, torch.Tensor))
            assert len(path) == input_dim

        # Set the rest of our MPS attributes
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bond_dim = bond_dim
        self.feature_dim = feature_dim
        self.periodic_bc = periodic_bc
        self.adaptive_mode = adaptive_mode
        self.label_site = label_site
        self.path = path
        self.use_bias = use_bias
        self.fixed_bias = fixed_bias
        self.cutoff = cutoff
        self.merge_threshold = merge_threshold
        self.feature_map = None
        self.linear_region = self.linear_region.to(device)
        module_list = [m.to(device) for m in module_list]

    def forward(self, input_data):
        """
        Embed our data and pass it to an MPS with a single output site

        Args:
            input_data (Tensor): Input with shape [batch_size, nCh, input_dim].
                                 The channel values of each site are embedded
                                 as (cos, sin) feature pairs, so the effective
                                 feature_dim is 2*nCh.
        """
        input_data = input_data.permute(0, 2, 1)

        x1 = torch.cos(input_data * PI/2)
        x2 = torch.sin(input_data * PI/2)
        x = torch.cat((x1, x2), dim=2)

        output = self.linear_region(x)

        return output.squeeze()

    def core_len(self):
        """
        Returns the number of cores, which is at least the required input size
        """
        return self.linear_region.core_len()

    def __len__(self):
        """
        Returns the number of input sites, which equals the input size
        """
        return self.input_dim
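A minimal sketch of the class above in isolation (illustrative; assumes the repo's utils package is importable). Each of the 4 sites receives one channel value, which the forward pass embeds as a (cos, sin) pair, so feature_dim = 2*nCh:

```python
import torch
from models.mps import MPS

mps = MPS(input_dim=4, output_dim=5, bond_dim=3, feature_dim=2, nCh=1)
x = torch.rand(2, 1, 4)                      # [batch, nCh, input_dim] in [0, 1]
x = x.to(next(mps.parameters()).device)      # module self-assigns to CUDA if present
y = mps(x)
print(y.shape)                               # torch.Size([2, 5])
```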
class LinearRegion(nn.Module):
    """
    List of modules which feeds input to each module and returns reduced output
    """
    def __init__(self, module_list, periodic_bc=False, parallel_eval=False,
                 module_states=None):
        # Check that module_list is a list whose entries are Pytorch modules
        if not isinstance(module_list, list) or module_list == []:
            raise ValueError("Input to LinearRegion must be nonempty list")
        for i, item in enumerate(module_list):
            if not isinstance(item, nn.Module):
                raise ValueError("Input items to LinearRegion must be PyTorch "
                                 f"Module instances, but item {i} is not")
        super().__init__()

        # Wrap as a ModuleList for proper parameter registration
        self.module_list = nn.ModuleList(module_list)
        self.periodic_bc = periodic_bc
        self.parallel_eval = parallel_eval

    def forward(self, input_data):
        """
        Contract input with list of MPS cores and return result as contractable

        Args:
            input_data (Tensor): Input with shape [batch_size, input_dim,
                                 feature_dim]
        """
        # Check that input_data has the correct shape
        assert len(input_data.shape) == 3
        assert input_data.size(1) == len(self)
        periodic_bc = self.periodic_bc
        parallel_eval = self.parallel_eval
        lin_bonds = ['l', 'r']

        # For each module, pull out the number of pixels needed and call that
        # module's forward() method, putting the result in contractable_list
        ind = 0
        contractable_list = []
        for module in self.module_list:
            mod_len = len(module)
            if mod_len == 1:
                mod_input = input_data[:, ind]
            else:
                mod_input = input_data[:, ind:(ind+mod_len)]
            ind += mod_len

            contractable_list.append(module(mod_input))

        # For periodic boundary conditions, reduce contractable_list and
        # trace over the left and right indices to get our output
        if periodic_bc:
            contractable_list = ContractableList(contractable_list)
            contractable = contractable_list.reduce(parallel_eval=True)

            # Unpack the output (atomic) contractable
            tensor, bond_str = contractable.tensor, contractable.bond_str
            assert all(c in bond_str for c in lin_bonds)

            # Build einsum string for the trace of tensor
            in_str, out_str = "", ""
            for c in bond_str:
                if c in lin_bonds:
                    in_str += 'l'
                else:
                    in_str += c
                    out_str += c
            ein_str = in_str + "->" + out_str

            # Return the trace over left and right indices
            return torch.einsum(ein_str, [tensor])

        # For open boundary conditions, add dummy edge vectors to
        # contractable_list and reduce everything to get our output
        else:
            # Get the dimension of left and right bond indices
            end_items = [contractable_list[i] for i in [0, -1]]
            bond_strs = [item.bond_str for item in end_items]
            bond_inds = [bs.index(c) for (bs, c) in zip(bond_strs, lin_bonds)]
            bond_dims = [item.tensor.size(ind) for (item, ind) in
                         zip(end_items, bond_inds)]

            # Build dummy end vectors and insert them at the ends of our list
            end_vecs = [torch.zeros(dim) for dim in bond_dims]
            end_vecs = [e.to(device) for e in end_vecs]

            for vec in end_vecs:
                vec[0] = 1

            contractable_list.insert(0, EdgeVec(end_vecs[0], is_left_vec=True))
            contractable_list.append(EdgeVec(end_vecs[1], is_left_vec=False))

            # Multiply together everything in contractable_list
            contractable_list = ContractableList(contractable_list)
            output = contractable_list.reduce(parallel_eval=parallel_eval)

            return output.tensor

    def core_len(self):
        """
        Returns the number of cores, which is at least the required input size
        """
        return sum([module.core_len() for module in self.module_list])

    def __len__(self):
        """
        Returns the number of input sites, which is the required input size
        """
        return sum([len(module) for module in self.module_list])
class MergedLinearRegion(LinearRegion):
    """
    Dynamic variant of LinearRegion that periodically rearranges its submodules
    """
    def __init__(self, module_list, periodic_bc=False, parallel_eval=False,
                 cutoff=1e-10, merge_threshold=2000):
        # Initialize a LinearRegion with our given module_list
        super().__init__(module_list, periodic_bc, parallel_eval)

        # Initialize attributes self.module_list_0 and self.module_list_1
        # using the unmerged self.module_list, then redefine the latter in
        # terms of one of the former lists
        self.offset = 0
        self.merge(offset=self.offset)
        self.merge(offset=(self.offset+1) % 2)
        self.module_list = getattr(self, f"module_list_{self.offset}")

        # Initialize variables used during switching
        self.input_counter = 0
        self.merge_threshold = merge_threshold
        self.cutoff = cutoff

    def forward(self, input_data):
        """
        Contract input with list of MPS cores and return result as contractable

        MergedLinearRegion keeps an input counter of the number of inputs, and
        when this exceeds its merge threshold, triggers an unmerging and
        remerging of its parameter tensors.

        Args:
            input_data (Tensor): Input with shape [batch_size, input_dim,
                                 feature_dim]
        """
        # If we've hit our threshold, flip the merge state of our tensors
        if self.input_counter >= self.merge_threshold:
            bond_list, sv_list = self.unmerge(cutoff=self.cutoff)
            self.offset = (self.offset + 1) % 2
            self.merge(offset=self.offset)
            self.input_counter -= self.merge_threshold

            # Point self.module_list to the appropriate merged module
            self.module_list = getattr(self, f"module_list_{self.offset}")
        else:
            bond_list, sv_list = None, None

        # Increment our counter and call the LinearRegion's forward method
        self.input_counter += input_data.size(0)
        output = super().forward(input_data)

        # If we flipped our merge state, then return the bond_list and output
        if bond_list:
            return output, bond_list, sv_list
        else:
            return output

    @torch.no_grad()
    def merge(self, offset):
        """
        Convert unmerged modules in self.module_list to merged counterparts

        This proceeds by first merging all unmerged cores internally, then
        merging lone cores when possible during a second sweep
        """
        assert offset in [0, 1]

        unmerged_list = self.module_list

        # Merge each core internally and add the results to merged_list
        site_num = offset
        merged_list = []
        for core in unmerged_list:
            assert not isinstance(core, MergedInput)
            assert not isinstance(core, MergedOutput)

            # Apply internal merging routine if our core supports it
            if hasattr(core, 'merge'):
                merged_list.extend(core.merge(offset=site_num % 2))
            else:
                merged_list.append(core)

            site_num += core.core_len()

        # Merge pairs of cores when possible (currently only with
        # InputSites), making sure to respect the offset for merging.
        while True:
            mod_num, site_num = 0, 0
            combined_list = []

            while mod_num < len(merged_list) - 1:
                left_core, right_core = merged_list[mod_num: mod_num+2]
                new_core = self.combine(left_core, right_core,
                                        merging=True)

                # If cores aren't combinable, move our sliding window by 1
                if new_core is None or offset != site_num % 2:
                    combined_list.append(left_core)
                    mod_num += 1
                    site_num += left_core.core_len()

                # If we get something new, move to the next distinct pair
                else:
                    assert new_core.core_len() == left_core.core_len() + \
                                                  right_core.core_len()
                    combined_list.append(new_core)
                    mod_num += 2
                    site_num += new_core.core_len()

                # Add the last core if there's nothing to merge it with
                if mod_num == len(merged_list)-1:
                    combined_list.append(merged_list[mod_num])
                    mod_num += 1

            # We're finished when merged_list remains unchanged
            if len(combined_list) == len(merged_list):
                break
            else:
                merged_list = combined_list

        # Finally, update the appropriate merged module list
        list_name = f"module_list_{offset}"
        # If the merged module list hasn't been set yet, initialize it
        if not hasattr(self, list_name):
            setattr(self, list_name, nn.ModuleList(merged_list))

        # Otherwise, do an in-place update so that all tensors remain
        # properly registered with whatever optimizer we use
        else:
            module_list = getattr(self, list_name)
            assert len(module_list) == len(merged_list)
            for i in range(len(module_list)):
                assert module_list[i].tensor.shape == \
                       merged_list[i].tensor.shape
                module_list[i].tensor[:] = merged_list[i].tensor

    @torch.no_grad()
    def unmerge(self, cutoff=1e-10):
        """
        Convert merged modules to unmerged counterparts

        This proceeds by first unmerging all merged cores internally, then
        combining lone cores where possible
        """
        list_name = f"module_list_{self.offset}"
        merged_list = getattr(self, list_name)

        # Unmerge each core internally and add results to unmerged_list
        unmerged_list, bond_list, sv_list = [], [-1], [-1]
        for core in merged_list:

            # Apply internal unmerging routine if our core supports it
            if hasattr(core, 'unmerge'):
                new_cores, new_bonds, new_svs = core.unmerge(cutoff)
                unmerged_list.extend(new_cores)
                bond_list.extend(new_bonds[1:])
                sv_list.extend(new_svs[1:])
            else:
                assert not isinstance(core, InputRegion)
                unmerged_list.append(core)
                bond_list.append(-1)
                sv_list.append(-1)
        # Combine all combinable pairs of cores. This occurs in several
        # passes, and for now acts nontrivially only on InputSite instances
        while True:
            mod_num = 0
            combined_list = []

            while mod_num < len(unmerged_list) - 1:
                left_core, right_core = unmerged_list[mod_num: mod_num+2]
                new_core = self.combine(left_core, right_core,
                                        merging=False)

                # If cores aren't combinable, move our sliding window by 1
                if new_core is None:
                    combined_list.append(left_core)
                    mod_num += 1

                # If we get something new, move to the next distinct pair
                else:
                    combined_list.append(new_core)
                    mod_num += 2

                # Add the last core if there's nothing to combine it with
                if mod_num == len(unmerged_list)-1:
                    combined_list.append(unmerged_list[mod_num])
                    mod_num += 1

            # We're finished when unmerged_list remains unchanged
            if len(combined_list) == len(unmerged_list):
                break
            else:
                unmerged_list = combined_list

        # Find the average (log) norm of all of our cores
        log_norms = []
        for core in unmerged_list:
            log_norms.append([torch.log(norm) for norm in core.get_norm()])
        log_scale = sum([sum(ns) for ns in log_norms])
        log_scale /= sum([len(ns) for ns in log_norms])

        # Now rescale all cores so that their norms are roughly equal
        scales = [[torch.exp(log_scale-n) for n in ns] for ns in log_norms]
        for core, these_scales in zip(unmerged_list, scales):
            core.rescale_norm(these_scales)

        # Add our unmerged module list as a new attribute and return
        # the updated bond dimensions
        self.module_list = nn.ModuleList(unmerged_list)
        return bond_list, sv_list

    def combine(self, left_core, right_core, merging):
        """
        Combine a pair of cores into a new core using context-dependent rules

        Depending on the types of left_core and right_core, along with whether
        we're currently merging (merging=True) or unmerging (merging=False),
        either return a new core, or None if no rule exists for this context
        """

        # Combine an OutputSite with a stray InputSite, return a MergedOutput
        if merging and ((isinstance(left_core, OutputSite) and
                         isinstance(right_core, InputSite)) or
                        (isinstance(left_core, InputSite) and
                         isinstance(right_core, OutputSite))):

            left_site = isinstance(left_core, InputSite)
            if left_site:
                new_tensor = torch.einsum('lui,our->olri', [left_core.tensor,
                                                            right_core.tensor])
            else:
                new_tensor = torch.einsum('olu,uri->olri', [left_core.tensor,
                                                            right_core.tensor])
            return MergedOutput(new_tensor, left_output=(not left_site))

        # Combine an InputRegion with a stray InputSite, return an InputRegion
        elif not merging and ((isinstance(left_core, InputRegion) and
                               isinstance(right_core, InputSite)) or
                              (isinstance(left_core, InputSite) and
                               isinstance(right_core, InputRegion))):

            left_site = isinstance(left_core, InputSite)
            if left_site:
                left_tensor = left_core.tensor.unsqueeze(0)
                right_tensor = right_core.tensor
            else:
                left_tensor = left_core.tensor
                right_tensor = right_core.tensor.unsqueeze(0)

            assert left_tensor.shape[1:] == right_tensor.shape[1:]
            new_tensor = torch.cat([left_tensor, right_tensor])

            return InputRegion(new_tensor)

        # If this situation doesn't belong to the above cases, return None
        else:
            return None

    def core_len(self):
        """
        Returns the number of cores, which is at least the required input size
        """
        return sum([module.core_len() for module in self.module_list])

    def __len__(self):
        """
        Returns the number of input sites, which is the required input size
        """
        return sum([len(module) for module in self.module_list])
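The first merging rule above, written out for raw cores (illustrative, not from the repository): an InputSite core ('lri', read here as 'lui') to the left of an OutputSite core ('olr', read as 'our') contracts over the shared bond 'u' into a MergedOutput core with bond_str 'olri':

```python
import torch

D, f, O = 3, 2, 4
input_core = torch.randn(D, D, f)    # 'lui': left bond, shared bond, input
output_core = torch.randn(O, D, D)   # 'our': output, shared bond, right bond
merged = torch.einsum('lui,our->olri', input_core, output_core)
print(merged.shape)                  # torch.Size([4, 3, 3, 2])
```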
class InputRegion(nn.Module):
    """
    Contiguous region of MPS input cores, associated with bond_str = 'slri'
    """
    def __init__(self, tensor, use_bias=True, fixed_bias=True, bias_mat=None,
                 ephemeral=False):
        super().__init__()

        # Make sure tensor has correct size and the component mats are square
        assert len(tensor.shape) == 4
        assert tensor.size(1) == tensor.size(2)
        bond_dim = tensor.size(1)

        # If we are using bias matrices, set those up here
        if use_bias:
            assert bias_mat is None or isinstance(bias_mat, torch.Tensor)
            bias_mat = torch.eye(bond_dim).unsqueeze(0) if bias_mat is None \
                       else bias_mat

            bias_modes = len(list(bias_mat.shape))
            assert bias_modes in [2, 3]
            if bias_modes == 2:
                bias_mat = bias_mat.unsqueeze(0)

        # Register our tensors as a Pytorch Parameter or Tensor
        if ephemeral:
            self.register_buffer(name='tensor', tensor=tensor.contiguous())
            self.register_buffer(name='bias_mat', tensor=bias_mat)
        else:
            self.register_parameter(name='tensor',
                                    param=nn.Parameter(tensor.contiguous()))
            if fixed_bias:
                self.register_buffer(name='bias_mat', tensor=bias_mat)
            else:
                self.register_parameter(name='bias_mat',
                                        param=nn.Parameter(bias_mat))

        self.use_bias = use_bias
        self.fixed_bias = fixed_bias

    def forward(self, input_data):
        """
        Contract input with MPS cores and return result as a MatRegion

        Args:
            input_data (Tensor): Input with shape [batch_size, input_dim,
                                 feature_dim]
        """
        # Check that input_data has the correct shape
        tensor = self.tensor
        assert len(input_data.shape) == 3
        assert input_data.size(1) == len(self)
        assert input_data.size(2) == tensor.size(3)

        # Contract the input with our core tensor
        mats = torch.einsum('slri,bsi->bslr', [tensor, input_data])

        # If we're using bias matrices, add those here
        if self.use_bias:
            bond_dim = tensor.size(1)
            bias_mat = self.bias_mat.unsqueeze(0)
            mats = mats + bias_mat.expand_as(mats)

        return MatRegion(mats)
    def merge(self, offset):
        """
        Merge all pairs of neighboring cores and return a new list of cores

        offset is either 0 or 1, which gives the first core at which we start
        our merging. Depending on the length of our InputRegion, the output of
        merge may have 1, 2, or 3 entries, with the majority of sites ending in
        a MergedInput instance
        """
        assert offset in [0, 1]
        num_sites = self.core_len()
        parity = num_sites % 2

        # Cases with empty tensors might arise in recursion below
        if num_sites == 0:
            return [None]

        # Simplify the problem into one where offset=0 and num_sites is even
        if (offset, parity) == (1, 1):
            out_list = [self[0], self[1:].merge(offset=0)[0]]
        elif (offset, parity) == (1, 0):
            out_list = [self[0], self[1:-1].merge(offset=0)[0], self[-1]]
        elif (offset, parity) == (0, 1):
            out_list = [self[:-1].merge(offset=0)[0], self[-1]]

        # The main case of interest, with no offset and an even number of sites
        else:
            tensor = self.tensor
            even_cores, odd_cores = tensor[0::2], tensor[1::2]
            assert len(even_cores) == len(odd_cores)

            # Multiply all pairs of cores, keeping inputs separate
            merged_cores = torch.einsum('slui,surj->slrij', [even_cores,
                                                             odd_cores])
            out_list = [MergedInput(merged_cores)]

        # Remove empty MergedInputs, which appear in very small InputRegions
        return [x for x in out_list if x is not None]

    def __getitem__(self, key):
        """
        Returns an InputRegion instance sliced along the site index
        """
        assert isinstance(key, int) or isinstance(key, slice)

        if isinstance(key, slice):
            return InputRegion(self.tensor[key])
        else:
            return InputSite(self.tensor[key])

    def get_norm(self):
        """
        Returns list of the norms of each core in InputRegion
        """
        return [torch.norm(core) for core in self.tensor]

    @torch.no_grad()
    def rescale_norm(self, scale_list):
        """
        Rescales the norm of each core by an amount specified in scale_list

        For the i'th tensor defining a core in InputRegion, we rescale as
        tensor_i <- scale_i * tensor_i, where scale_i = scale_list[i]
        """
        assert len(scale_list) == len(self.tensor)

        for core, scale in zip(self.tensor, scale_list):
            core *= scale

    def core_len(self):
        return len(self)

    def __len__(self):
        return self.tensor.size(0)
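An illustrative call of merge() (not from the repository): four cores with even length and offset 0 collapse into a single MergedInput whose cores each carry two feature indices:

```python
import torch
from models.mps import InputRegion

region = InputRegion(torch.randn(4, 3, 3, 2))   # 4 cores: bond dim 3, feature dim 2
merged, = region.merge(offset=0)                # einsum 'slui,surj->slrij'
print(type(merged).__name__, merged.tensor.shape)
# MergedInput torch.Size([2, 3, 3, 2, 2]) -- 2 double cores over 4 input sites
```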
class MergedInput(nn.Module):
    """
    Contiguous region of merged MPS cores, each taking in a pair of input data

    Since MergedInput arises after contracting together existing input cores,
    a merged input tensor is required for initialization
    """
    def __init__(self, tensor):
        # Check that our input tensor has the correct shape
        bond_str = 'slrij'
        shape = tensor.shape
        assert len(shape) == 5
        assert shape[1] == shape[2]
        assert shape[3] == shape[4]

        super().__init__()

        # Register our tensor as a Pytorch Parameter
        self.register_parameter(name='tensor',
                                param=nn.Parameter(tensor.contiguous()))

    def forward(self, input_data):
        """
        Contract input with merged MPS cores and return result as a MatRegion

        Args:
            input_data (Tensor): Input with shape [batch_size, input_dim,
                                 feature_dim], where input_dim must be even
                                 (each merged core takes 2 inputs)
        """
        # Check that input_data has the correct shape
        tensor = self.tensor
        assert len(input_data.shape) == 3
        assert input_data.size(1) == len(self)
        assert input_data.size(2) == tensor.size(3)
        assert input_data.size(1) % 2 == 0

        # Divide input_data into inputs living on even and on odd sites
        inputs = [input_data[:, 0::2], input_data[:, 1::2]]

        # Contract the odd (right-most) and even inputs with merged cores
        tensor = torch.einsum('slrij,bsj->bslri', [tensor, inputs[1]])
        mats = torch.einsum('bslri,bsi->bslr', [tensor, inputs[0]])

        return MatRegion(mats)

    def unmerge(self, cutoff=1e-10):
        """
        Separate the cores in our MergedInput and return an InputRegion

        The length of the resultant InputRegion will be identical to our
        original MergedInput (same number of inputs), but its core_len will
        be doubled (twice as many individual cores)
        """
        bond_str = 'slrij'
        tensor = self.tensor
        svd_string = 'lrij->lui,urj'
        max_D = tensor.size(1)

        # Split every one of the cores into two and add them both to core_list
        core_list, bond_list, sv_list = [], [-1], [-1]
        for merged_core in tensor:
            sv_vec = torch.empty(max_D)
            left_core, right_core, bond_dim = svd_flex(merged_core, svd_string,
                                                       max_D, cutoff,
                                                       sv_vec=sv_vec)

            core_list += [left_core, right_core]
            bond_list += [bond_dim, -1]
            sv_list += [sv_vec, -1]

        # Collate the split cores into one tensor and return as an InputRegion
        tensor = torch.stack(core_list)
        return [InputRegion(tensor)], bond_list, sv_list

    def get_norm(self):
        """
        Returns list of the norm of each core in MergedInput
        """
        return [torch.norm(core) for core in self.tensor]

    @torch.no_grad()
    def rescale_norm(self, scale_list):
        """
        Rescales the norm of each core by an amount specified in scale_list

        For the i'th tensor defining a core in MergedInput, we rescale as
        tensor_i <- scale_i * tensor_i, where scale_i = scale_list[i]
        """
        assert len(scale_list) == len(self.tensor)

        for core, scale in zip(self.tensor, scale_list):
            core *= scale

    def core_len(self):
        return len(self)

    def __len__(self):
        """
        Returns the number of input sites, which is twice the number of cores
        """
        return 2 * self.tensor.size(0)

class InputSite(nn.Module):
    """
    A single MPS core which takes in a single input datum, bond_str = 'lri'
    """
    def __init__(self, tensor):
        super().__init__()
        # Register our tensor as a Pytorch Parameter
        self.register_parameter(name='tensor',
                                param=nn.Parameter(tensor.contiguous()))

    def forward(self, input_data):
        """
        Contract input with MPS core and return result as a SingleMat

        Args:
            input_data (Tensor): Input with shape [batch_size, feature_dim]
        """
        # Check that input_data has the correct shape
        tensor = self.tensor
        assert len(input_data.shape) == 2
        assert input_data.size(1) == tensor.size(2)

        # Contract the input with our core tensor
        mat = torch.einsum('lri,bi->blr', [tensor, input_data])

        return SingleMat(mat)

    def get_norm(self):
        """
        Returns the norm of our core tensor, wrapped as a singleton list
        """
        return [torch.norm(self.tensor)]

    @torch.no_grad()
    def rescale_norm(self, scale):
        """
        Rescales the norm of our core by a factor of input `scale`
        """
        if isinstance(scale, list):
            assert len(scale) == 1
            scale = scale[0]

        self.tensor *= scale

    def core_len(self):
        return 1

    def __len__(self):
        return 1
class OutputSite(nn.Module):
    """
    A single MPS core with no input and a single output index, bond_str = 'olr'
    """
    def __init__(self, tensor):
        super().__init__()
        # Register our tensor as a Pytorch Parameter
        self.register_parameter(name='tensor',
                                param=nn.Parameter(tensor.contiguous()))

    def forward(self, input_data):
        """
        Return the OutputSite wrapped as an OutputCore contractable
        """
        return OutputCore(self.tensor)

    def get_norm(self):
        """
        Returns the norm of our core tensor, wrapped as a singleton list
        """
        return [torch.norm(self.tensor)]

    @torch.no_grad()
    def rescale_norm(self, scale):
        """
        Rescales the norm of our core by a factor of input `scale`
        """
        if isinstance(scale, list):
            assert len(scale) == 1
            scale = scale[0]

        self.tensor *= scale

    def core_len(self):
        return 1

    def __len__(self):
        return 0

class MergedOutput(nn.Module):
    """
    Merged MPS core taking in one input datum and returning an output vector

    Since MergedOutput arises after contracting together an existing input and
    output core, an already-merged tensor is required for initialization

    Args:
        tensor (Tensor):    Value that our merged core is initialized to
        left_output (bool): Specifies if the output core is on the left side of
                            the input core (True), or on the right (False)
    """
    def __init__(self, tensor, left_output):
        # Check that our input tensor has the correct shape
        bond_str = 'olri'
        assert len(tensor.shape) == 4
        super().__init__()

        # Register our tensor as a Pytorch Parameter
        self.register_parameter(name='tensor',
                                param=nn.Parameter(tensor.contiguous()))
        self.left_output = left_output

    def forward(self, input_data):
        """
        Contract input with input index of core and return an OutputCore

        Args:
            input_data (Tensor): Input with shape [batch_size, feature_dim]
        """
        # Check that input_data has the correct shape
        tensor = self.tensor
        assert len(input_data.shape) == 2
        assert input_data.size(1) == tensor.size(3)

        # Contract the input with our core tensor
        tensor = torch.einsum('olri,bi->bolr', [tensor, input_data])

        return OutputCore(tensor)

    def unmerge(self, cutoff=1e-10):
        """
        Split our MergedOutput into an OutputSite and an InputSite

        The non-zero entries of our tensors are dynamically sized according to
        the SVD cutoff, but will generally be padded with zeros to give the
        new index a regular size.
894 | """ 895 | bond_str = 'olri' 896 | tensor = self.tensor 897 | left_output = self.left_output 898 | if left_output: 899 | svd_string = 'olri->olu,uri' 900 | max_D = tensor.size(2) 901 | sv_vec = torch.empty(max_D) 902 | 903 | output_core, input_core, bond_dim = svd_flex(tensor, svd_string, 904 | max_D, cutoff, sv_vec=sv_vec) 905 | return ([OutputSite(output_core), InputSite(input_core)], 906 | [-1, bond_dim, -1], [-1, sv_vec, -1]) 907 | 908 | else: 909 | svd_string = 'olri->our,lui' 910 | max_D = tensor.size(1) 911 | sv_vec = torch.empty(max_D) 912 | 913 | output_core, input_core, bond_dim = svd_flex(tensor, svd_string, 914 | max_D, cutoff, sv_vec=sv_vec) 915 | return ([InputSite(input_core), OutputSite(output_core)], 916 | [-1, bond_dim, -1], [-1, sv_vec, -1]) 917 | 918 | def get_norm(self): 919 | """ 920 | Returns the norm of our core tensor, wrapped as a singleton list 921 | """ 922 | return [torch.norm(self.tensor)] 923 | 924 | @torch.no_grad() 925 | def rescale_norm(self, scale): 926 | """ 927 | Rescales the norm of our core by a factor of input `scale` 928 | """ 929 | if isinstance(scale, list): 930 | assert len(scale) == 1 931 | scale = scale[0] 932 | 933 | self.tensor *= scale 934 | 935 | def core_len(self): 936 | return 2 937 | 938 | def __len__(self): 939 | return 1 940 | 941 | class InitialVector(nn.Module): 942 | """ 943 | Vector of ones and zeros to act as initial vector within the MPS 944 | 945 | By default the initial vector is chosen to be all ones, but if fill_dim is 946 | specified then only the first fill_dim entries are set to one, with the 947 | rest zero. 948 | 949 | If fixed_vec is False, then the initial vector will be registered as a 950 | trainable model parameter. 951 | """ 952 | def __init__(self, bond_dim, fill_dim=None, fixed_vec=True, 953 | is_left_vec=True): 954 | super().__init__() 955 | 956 | vec = torch.ones(bond_dim) 957 | if fill_dim is not None: 958 | assert fill_dim >= 0 and fill_dim <= bond_dim 959 | vec[fill_dim:] = 0 960 | 961 | if fixed_vec: 962 | vec.requires_grad = False 963 | self.register_buffer(name='vec', tensor=vec) 964 | else: 965 | vec.requires_grad = True 966 | self.register_parameter(name='vec', param=nn.Parameter(vec)) 967 | 968 | assert isinstance(is_left_vec, bool) 969 | self.is_left_vec = is_left_vec 970 | 971 | def forward(self): 972 | """ 973 | Return our initial vector wrapped as an EdgeVec contractable 974 | """ 975 | return EdgeVec(self.vec, self.is_left_vec) 976 | 977 | def core_len(self): 978 | return 1 979 | 980 | def __len__(self): 981 | return 0 982 | 983 | class TerminalOutput(nn.Module): 984 | """ 985 | Output matrix at end of chain to transmute virtual state into output vector 986 | 987 | By default, a fixed rectangular identity matrix with shape 988 | [bond_dim, output_dim] will be used as a state transducer. If fixed_mat is 989 | False, then the matrix will be registered as a trainable model parameter. 
990 | """ 991 | def __init__(self, bond_dim, output_dim, fixed_mat=False, 992 | is_left_mat=False): 993 | super().__init__() 994 | 995 | # I don't have a nice initialization scheme for a non-injective fixed 996 | # state transducer, so just throw an error if that's needed 997 | if fixed_mat and output_dim > bond_dim: 998 | raise ValueError("With fixed_mat=True, TerminalOutput currently " 999 | "only supports initialization for bond_dim >= " 1000 | "output_dim, but here bond_dim=" 1001 | f"{bond_dim} and output_dim={output_dim}") 1002 | 1003 | # Initialize the matrix and register it appropriately 1004 | mat = torch.eye(bond_dim, output_dim) 1005 | if fixed_mat: 1006 | mat.requires_grad = False 1007 | self.register_buffer(name='mat', tensor=mat) 1008 | else: 1009 | # Add some noise to help with training 1010 | mat = mat + torch.randn_like(mat) / bond_dim 1011 | 1012 | mat.requires_grad = True 1013 | self.register_parameter(name='mat', param=nn.Parameter(mat)) 1014 | 1015 | assert isinstance(is_left_mat, bool) 1016 | self.is_left_mat = is_left_mat 1017 | 1018 | def forward(self): 1019 | """ 1020 | Return our terminal matrix wrapped as an OutputMat contractable 1021 | """ 1022 | return OutputMat(self.mat, self.is_left_mat) 1023 | 1024 | def core_len(self): 1025 | return 1 1026 | 1027 | def __len__(self): 1028 | return 0 1029 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import time 3 | import torch 4 | from models.lotenet import loTeNet 5 | from torchvision import transforms, datasets 6 | import pdb 7 | from utils.lidc_dataset import LIDC 8 | from utils.tools import * 9 | from models.Densenet import * 10 | import argparse 11 | 12 | # Globally load device identifier 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | 15 | def evaluate(loader): 16 | ### Evaluation funcntion for validation/testing 17 | 18 | with torch.no_grad(): 19 | vl_acc = 0. 20 | vl_loss = 0. 
21 | labelsNp = np.zeros(1) 22 | predsNp = np.zeros(1) 23 | model.eval() 24 | 25 | for i, (inputs, labels) in enumerate(loader): 26 | 27 | inputs = inputs.to(device) 28 | labels = labels.to(device) 29 | labelsNp = np.concatenate((labelsNp, labels.cpu().numpy())) 30 | 31 | # Inference 32 | scores = torch.sigmoid(model(inputs)) 33 | 34 | preds = scores 35 | loss = loss_fun(scores, labels) 36 | predsNp = np.concatenate((predsNp, preds.cpu().numpy())) 37 | vl_loss += loss.item() 38 | 39 | # Compute AUC over the full (valid/test) set 40 | vl_acc = computeAuc(labelsNp[1:],predsNp[1:]) 41 | vl_loss = vl_loss/len(loader) 42 | 43 | return vl_acc, vl_loss 44 | 45 | # Miscellaneous initialization 46 | torch.manual_seed(1) 47 | start_time = time.time() 48 | 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('--num_epochs', type=int, default=100, help='Number of training epochs') 51 | parser.add_argument('--batch_size', type=int, default=512, help='Batch size') 52 | parser.add_argument('--lr', type=float, default=5e-4, help='Learning rate') 53 | parser.add_argument('--l2', type=float, default=0, help='L2 regularisation') 54 | parser.add_argument('--aug', action='store_true', default=False, help='Use data augmentation') 55 | parser.add_argument('--data_path', type=str, default='lidc/',help='Path to data.') 56 | parser.add_argument('--bond_dim', type=int, default=5, help='MPS Bond dimension') 57 | parser.add_argument('--nChannel', type=int, default=1, help='Number of input channels') 58 | parser.add_argument('--dense_net', action='store_true', 59 | default=False, help='Using Dense Net model') 60 | args = parser.parse_args() 61 | 62 | 63 | batch_size = args.batch_size 64 | 65 | # LoTeNet parameters 66 | adaptive_mode = False 67 | periodic_bc = False 68 | 69 | kernel = 2 # Stride along spatial dimensions 70 | output_dim = 1 # output dimension 71 | 72 | feature_dim = 2 73 | 74 | logFile = time.strftime("%Y%m%d_%H_%M")+'.txt' 75 | makeLogFile(logFile) 76 | 77 | normTensor = 0.5*torch.ones(args.nChannel) 78 | ### Data processing and loading.... 79 | trans_valid = transforms.Compose([transforms.Normalize(mean=normTensor,std=normTensor)]) 80 | 81 | if args.aug: 82 | trans_train = transforms.Compose([transforms.ToPILImage(), 83 | transforms.RandomHorizontalFlip(), 84 | transforms.RandomVerticalFlip(), 85 | transforms.RandomRotation(20), 86 | transforms.ToTensor(), 87 | transforms.Normalize(mean=normTensor,std=normTensor)]) 88 | print("Using Augmentation....") 89 | else: 90 | trans_train = trans_valid 91 | print("No augmentation....") 92 | 93 | # Load processed LIDC data 94 | dataset_train = LIDC(split='Train', data_dir=args.data_path, 95 | transform=trans_train,rater=4) 96 | dataset_valid = LIDC(split='Valid', data_dir=args.data_path, 97 | transform=trans_valid,rater=4) 98 | dataset_test = LIDC(split='Test', data_dir=args.data_path, 99 | transform=trans_valid,rater=4) 100 | 101 | num_train = len(dataset_train) 102 | num_valid = len(dataset_valid) 103 | num_test = len(dataset_test) 104 | print("Num. train = %d, Num. 
val = %d"%(num_train,num_valid)) 105 | 106 | loader_train = DataLoader(dataset = dataset_train, drop_last=True, 107 | batch_size=batch_size, shuffle=True) 108 | loader_valid = DataLoader(dataset = dataset_valid, drop_last=True, 109 | batch_size=batch_size, shuffle=False) 110 | loader_test = DataLoader(dataset = dataset_test, drop_last=True, 111 | batch_size=batch_size, shuffle=False) 112 | 113 | # Initiliaze input dimensions 114 | dim = torch.ShortTensor(list(dataset_train[0][0].shape[1:])) 115 | nCh = int(dataset_train[0][0].shape[0]) 116 | 117 | # Initialize the models 118 | if not args.dense_net: 119 | print("Using LoTeNet") 120 | model = loTeNet(input_dim=dim, output_dim=output_dim, 121 | nCh=nCh, kernel=kernel, 122 | bond_dim=args.bond_dim, feature_dim=feature_dim, 123 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, virtual_dim=1) 124 | else: 125 | print("Densenet Baseline!") 126 | model = DenseNet(depth=40, growthRate=12, 127 | reduction=0.5,bottleneck=True,nClasses=output_dim) 128 | 129 | # Choose loss function and optimizer 130 | loss_fun = torch.nn.BCELoss() 131 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, 132 | weight_decay=args.l2) 133 | 134 | nParam = sum(p.numel() for p in model.parameters() if p.requires_grad) 135 | print("Number of parameters:%d"%(nParam)) 136 | print(f"Maximum MPS bond dimension = {args.bond_dim}") 137 | with open(logFile,"a") as f: 138 | print("Bond dim: %d"%(args.bond_dim),file=f) 139 | print("Number of parameters:%d"%(nParam),file=f) 140 | 141 | print(f"Using Adam w/ learning rate = {args.lr:.1e}") 142 | print("Feature_dim: %d, nCh: %d, B:%d"%(feature_dim,nCh,batch_size)) 143 | 144 | model = model.to(device) 145 | nValid = len(loader_valid) 146 | nTrain = len(loader_train) 147 | nTest = len(loader_test) 148 | 149 | maxAuc = 0 150 | minLoss = 1e3 151 | convCheck = 5 152 | convIter = 0 153 | 154 | # Let's start training! 155 | for epoch in range(args.num_epochs): 156 | running_loss = 0. 157 | running_acc = 0. 
158 | t = time.time() 159 | model.train() 160 | predsNp = np.zeros(1) 161 | labelsNp = np.zeros(1) 162 | 163 | for i, (inputs, labels) in enumerate(loader_train): 164 | 165 | inputs = inputs.to(device) 166 | labels = labels.to(device) 167 | labelsNp = np.concatenate((labelsNp, labels.cpu().numpy())) 168 | 169 | scores = torch.sigmoid(model(inputs)) 170 | 171 | preds = scores 172 | loss = loss_fun(scores, labels) 173 | 174 | with torch.no_grad(): 175 | predsNp = np.concatenate((predsNp, preds.detach().cpu().numpy())) 176 | running_loss += loss.item() 177 | 178 | # Backpropagate and update parameters 179 | optimizer.zero_grad() 180 | loss.backward() 181 | optimizer.step() 182 | 183 | if (i+1) % 5 == 0: 184 | print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 185 | .format(epoch+1, args.num_epochs, i+1, nTrain, loss.item())) 186 | 187 | accuracy = computeAuc(labelsNp[1:],predsNp[1:]) # Skip the dummy seed entries, as in evaluate() 188 | 189 | # Evaluate on Validation set 190 | with torch.no_grad(): 191 | 192 | vl_acc, vl_loss = evaluate(loader_valid) 193 | if vl_acc > maxAuc or vl_loss < minLoss: 194 | if vl_loss < minLoss: 195 | minLoss = vl_loss 196 | if vl_acc > maxAuc: 197 | ### Predict on test set 198 | ts_acc, ts_loss = evaluate(loader_test) 199 | maxAuc = vl_acc 200 | print('New Max: %.4f'%maxAuc) 201 | print('Test Set Loss:%.4f Auc:%.4f'%(ts_loss, ts_acc)) 202 | with open(logFile,"a") as f: 203 | print('Test Set Loss:%.4f Auc:%.4f'%(ts_loss, ts_acc),file=f) 204 | convEpoch = epoch 205 | convIter = 0 206 | else: 207 | convIter += 1 208 | if convIter == convCheck: 209 | if not args.dense_net: 210 | print("MPS") 211 | else: 212 | print("DenseNet") 213 | print("Converged at epoch:%d with AUC:%.4f"%(convEpoch+1,maxAuc)) 214 | 215 | break 216 | writeLog(logFile, epoch, running_loss/nTrain, accuracy, 217 | vl_loss, vl_acc, time.time()-t) 218 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raghavian/loTeNet_pytorch/9b9814b3dea92c80be7fb0f26616cca93155e035/utils/__init__.py -------------------------------------------------------------------------------- /utils/lidc_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | from torch.utils.data import TensorDataset, DataLoader, Dataset 8 | import pdb 9 | 10 | class LIDC(Dataset): 11 | def __init__(self, rater=4, split='Train', data_dir = './', transform=None): 12 | super().__init__() 13 | 14 | self.data_dir = data_dir 15 | self.rater = rater 16 | self.transform = transform 17 | self.data, self.targets = torch.load(data_dir+split+'.pt') 18 | self.targets = self.targets.type(torch.FloatTensor) 19 | def __len__(self): 20 | return len(self.targets) 21 | 22 | def __getitem__(self, index): 23 | 24 | image, label = self.data[index], self.targets[index] 25 | if self.rater == 4: 26 | label = (label.sum() > 2).type_as(self.targets) # Consensus label: positive when a majority (3+) of the 4 raters agree 27 | else: 28 | label = label[self.rater] 29 | image = image.type(torch.FloatTensor)/255.0 30 | if self.transform is not None: 31 | image = self.transform(image) 32 | return image, label 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /utils/model.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/raghavian/loTeNet_pytorch/9b9814b3dea92c80be7fb0f26616cca93155e035/utils/model.png -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import torch 4 | from os.path import isfile 5 | from os import rename 6 | SMOOTH=1 7 | import pdb 8 | from sklearn.metrics import auc, roc_curve 9 | import torch.nn.functional as F 10 | import torch.nn as nn 11 | from PIL.ImageFilter import GaussianBlur 12 | from scipy.ndimage import gaussian_filter # Needed by GaussianLayer.weights_init 13 | import matplotlib.pyplot as plt # Needed by plotLearningCurve 14 | def wCELoss(prediction, target): 15 | w1 = 1.33 # False negative penalty 16 | w2 = .66 # False positive penalty 17 | return -torch.mean(w1 * target * torch.log(prediction.clamp_min(1e-3)) 18 | + w2 * (1. - target) * torch.log(1. - prediction.clamp_max(.999))) 19 | 20 | class GaussianFilter(object): 21 | """Apply Gaussian blur to the PIL image 22 | Args: 23 | sigma (float): Sigma of Gaussian kernel. Default value 1.0 24 | """ 25 | def __init__(self, sigma=1): 26 | self.sigma = sigma 27 | self.filter = GaussianBlur(radius=sigma) 28 | 29 | def __call__(self, img): 30 | """ 31 | Args: 32 | img (PIL Image): Image to be blurred. 33 | 34 | Returns: 35 | PIL Image: Blurred image. 36 | """ 37 | return img.filter(self.filter) 38 | 39 | def __repr__(self): 40 | return self.__class__.__name__ + '(sigma={})'.format(self.sigma) 41 | 42 | class GaussianLayer(nn.Module): 43 | def __init__(self, sigma=1, size=10): 44 | super(GaussianLayer, self).__init__() 45 | self.sigma = sigma 46 | self.size = size 47 | self.seq = nn.Sequential( 48 | nn.ReflectionPad2d(size), 49 | nn.Conv2d(3, 3, 2*size+1, stride=1, padding=0, bias=None, groups=3) # Kernel size matches weights_init and the reflection padding 50 | ) 51 | self.weights_init() 52 | 53 | def forward(self, x): 54 | return self.seq(x) 55 | 56 | def weights_init(self): 57 | s = self.size * 2 + 1 58 | k = np.zeros((s,s)) 59 | k[self.size, self.size] = 1 # Unit impulse at the kernel centre 60 | kernel = gaussian_filter(k,sigma=self.sigma) 61 | for name, f in self.named_parameters(): 62 | f.data.copy_(torch.from_numpy(kernel)) 63 | 64 | class focalLoss(nn.Module): 65 | def __init__(self, alpha=1, gamma=2, logits=False, reduce=True): 66 | super(focalLoss, self).__init__() 67 | self.alpha = alpha 68 | self.gamma = gamma 69 | self.logits = logits 70 | self.reduce = reduce 71 | 72 | def forward(self, inputs, targets): 73 | if self.logits: 74 | BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False) 75 | else: 76 | BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False) 77 | pt = torch.exp(-BCE_loss) 78 | F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss 79 | 80 | if self.reduce: 81 | return torch.mean(F_loss) 82 | else: 83 | return F_loss 84 | 85 | def computeAuc(target,preds): 86 | fpr, tpr, thresholds = roc_curve(target,preds) 87 | aucVal = auc(fpr,tpr) 88 | return aucVal 89 | 90 | class hingeLoss(torch.nn.Module): 91 | 92 | def __init__(self): 93 | super(hingeLoss, self).__init__() 94 | 95 | def forward(self, output, target): 96 | # pdb.set_trace() 97 | target = 2*target-1 98 | output = 2*output-1 99 | hinge_loss = 1 - torch.mul(output, target) 100 | hinge_loss[hinge_loss < 0] = 0 101 | return hinge_loss.mean() 102 | 103 | 104 | def makeBatchAdj(adj,bSize): 105 | 106 | E = adj._nnz() 107 | N = adj.shape[0] 108 | batch_idx = torch.zeros(2,bSize*E).type(torch.LongTensor) 109 | batch_val = torch.zeros(bSize*E) 110 | 111 | idx = adj._indices() 112 | vals = adj._values() 113 | 114 | for i in range(bSize): 115 | batch_idx[:,i*E:(i+1)*E] = idx + 
i*N 116 | batch_val[i*E:(i+1)*E] = vals 117 | 118 | return torch.sparse.FloatTensor(batch_idx,batch_val,(bSize*N,bSize*N)) 119 | 120 | 121 | def makeAdj(ngbrs, normalize=True): 122 | """ Create an adjacency matrix, given the neighbour indices 123 | Input: Nxd neighbourhood, where N is number of nodes 124 | Output: NxN sparse torch adjacency matrix 125 | """ 126 | # pdb.set_trace() 127 | N, d = ngbrs.shape 128 | validNgbrs = (ngbrs >= 0) # Mask for valid neighbours amongst the d-neighbours 129 | row = np.repeat(np.arange(N),d) # Row indices like in sparse matrix formats 130 | row = row[validNgbrs.reshape(-1)] # Remove non-neighbour row indices 131 | col = (ngbrs*validNgbrs).reshape(-1) # Obtain neighbour col indices 132 | col = col[validNgbrs.reshape(-1)] # Remove non-neighbour col indices 133 | data = np.ones(col.size) 134 | adj = sp.csr_matrix((np.ones(col.size, dtype=bool),(row, col)), shape=(N, N)).toarray() # Make adj matrix 135 | adj = adj + np.eye(N) # Self connections 136 | adj = sp.csr_matrix(adj, dtype=np.float32)#/(d+1) 137 | if normalize: 138 | adj = row_normalize(adj) 139 | adj = sparse_mx_to_torch_sparse_tensor(adj) 140 | 141 | return adj 142 | 143 | def makeRegAdj(dims, numNgbrs=26): 144 | """ Make regular pixel neighbourhoods for a volume of shape dims=(xdim,ydim,zdim) """ 145 | xdim, ydim, zdim = dims # Volume dimensions, passed in explicitly (these were undefined globals in the original) 146 | numEl = xdim*ydim*zdim; idxVol = np.arange(numEl).reshape(dims) # Voxel count and linear voxel indices 147 | idx = 0 148 | ngbrOffset = np.zeros((3,numNgbrs),dtype=int) 149 | for i in range(-1,2): 150 | for j in range(-1,2): 151 | for k in range(-1,2): 152 | if(i | j | k): 153 | ngbrOffset[:,idx] = [i,j,k] 154 | idx+=1 155 | idx = 0 156 | ngbrs = np.zeros((numEl, numNgbrs), dtype=int) 157 | for i in range(xdim): 158 | for j in range(ydim): 159 | for k in range(zdim): 160 | xIdx = np.mod(ngbrOffset[0,:]+i,xdim) 161 | yIdx = np.mod(ngbrOffset[1,:]+j,ydim) 162 | zIdx = np.mod(ngbrOffset[2,:]+k,zdim) 163 | ngbrs[idx,:] = idxVol[xIdx, yIdx, zIdx] 164 | idx += 1 165 | return makeAdj(ngbrs) # Assumed intent: wrap the neighbourhoods as a sparse adjacency (the original returned nothing) 166 | def makeAdjWithInvNgbrs(ngbrs, normalize=False): 167 | """ Create an adjacency matrix, given the neighbour indices including invalid indices where self connections are added. 
168 | Input: Nxd neighbourhood, where N is number of nodes 169 | Output: NxN sparse torch adjacency matrix 170 | """ 171 | np.random.seed(2) 172 | # pdb.set_trace() 173 | N, d = ngbrs.shape 174 | row = np.arange(N).reshape(-1,1) 175 | random = np.random.randint(0,N-1,(N,d)) 176 | valIdx = np.array((ngbrs < 0),dtype=int) 177 | ngbrs = random*valIdx + ngbrs*(1-valIdx) # Replace invalid neighbours with random nodes 178 | row = np.repeat(row,d).reshape(-1) # Row indices like in sparse matrix formats 179 | col = ngbrs.reshape(-1) # Obtain neighbour col indices 180 | data = np.ones(col.size) 181 | adj = sp.csr_matrix((np.ones(col.size, dtype=bool),(row, col)), shape=(N, N)).toarray() # Make adj matrix 182 | adj = adj + np.eye(N) # Self connections 183 | adj = sp.csr_matrix(adj, dtype=np.float32)#/(d+1) 184 | if normalize: 185 | adj = row_normalize(adj) 186 | adj = sparse_mx_to_torch_sparse_tensor(adj) 187 | adj = adj.coalesce() 188 | adj._values = adj.values() 189 | return adj 190 | 191 | 192 | def transformers(adj): 193 | """ Obtain source and sink node transformer matrices""" 194 | edges = adj._indices() 195 | N = adj.shape[0] 196 | nnz = adj._nnz() 197 | val = torch.ones(nnz) 198 | idx0 = torch.arange(nnz) 199 | 200 | idx = torch.stack((idx0,edges[1,:])) 201 | n2e_in = torch.sparse.FloatTensor(idx,val,(nnz,N)) 202 | 203 | idx = torch.stack((idx0,edges[0,:])) 204 | n2e_out = torch.sparse.FloatTensor(idx,val,(nnz,N)) 205 | 206 | return n2e_in, n2e_out 207 | 208 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 209 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 210 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 211 | indices = torch.from_numpy(np.vstack((sparse_mx.row, 212 | sparse_mx.col))).long() 213 | values = torch.from_numpy(sparse_mx.data) 214 | shape = torch.Size(sparse_mx.shape) 215 | return torch.sparse.FloatTensor(indices, values, shape) 216 | 217 | def to_linear_idx(x_idx, y_idx, num_cols): 218 | assert num_cols > np.max(x_idx) 219 | x_idx = np.array(x_idx, dtype=np.int32) 220 | y_idx = np.array(y_idx, dtype=np.int32) 221 | return y_idx * num_cols + x_idx 222 | 223 | 224 | def row_normalize(mx): 225 | """Row-normalize sparse matrix""" 226 | rowsum = np.array(mx.sum(1), dtype=np.float32) 227 | r_inv = np.power(rowsum, -1).flatten() 228 | r_inv[np.isinf(r_inv)] = 0. 229 | r_mat_inv = sp.diags(r_inv) 230 | mx = r_mat_inv.dot(mx) 231 | return mx 232 | 233 | def to_2d_idx(idx, num_cols): 234 | idx = np.array(idx, dtype=np.int64) 235 | y_idx = np.array(np.floor(idx / float(num_cols)), dtype=np.int64) 236 | x_idx = idx % num_cols 237 | return x_idx, y_idx 238 | 239 | def dice_loss(preds, labels): 240 | "Return smoothed dice loss. " 241 | preds_sq = preds**2 242 | return 1 - (2. 
* (torch.sum(preds * labels)) + SMOOTH) / \ 243 | (preds_sq.sum() + labels.sum() + SMOOTH) 244 | 245 | def binary_accuracy(output, labels): 246 | preds = output > 0.5 247 | correct = preds.type_as(labels).eq(labels).double() 248 | correct = correct.sum() 249 | return correct / len(labels) 250 | 251 | def multiClassAccuracy(output, labels): 252 | # pdb.set_trace() 253 | preds = output.argmax(1) 254 | # preds = (output > (1.0/labels.shape[1])).type_as(labels) 255 | correct = (preds == labels.view(-1)) 256 | correct = correct.sum().float() 257 | return correct / len(labels) 258 | 259 | def regrAcc(output, labels): 260 | # pdb.set_trace() 261 | preds = output.round().type(torch.long).type_as(labels) 262 | # preds = (output > (1.0/labels.shape[1])).type_as(labels) 263 | correct = (preds == labels.view(-1)) 264 | correct = correct.sum().float() 265 | return correct / len(labels) 266 | 267 | 268 | def rescaledRegAcc(output,labels,lRange=37,lMin=-20): 269 | # pdb.set_trace() 270 | preds = (output+1)*(lRange)/2 + lMin 271 | preds = preds.round().type(torch.long).type_as(labels) 272 | # preds = (output > (1.0/labels.shape[1])).type_as(labels) 273 | correct = (preds == labels.view(-1)) 274 | correct = correct.sum().float() 275 | return correct / len(labels) 276 | 277 | def focalCE(preds, labels, gamma=1): 278 | "Return focal cross entropy" 279 | loss = -torch.mean( ( ((1-preds)**gamma) * labels * torch.log(preds) ) \ 280 | + ( ((preds)**gamma) * (1-labels) * torch.log(1-preds) ) ) 281 | return loss 282 | 283 | def dice(preds, labels): 284 | # pdb.set_trace() 285 | "Return dice score" 286 | preds_bin = (preds > 0.5).type_as(labels) 287 | return 2. * torch.sum(preds_bin * labels) / (preds_bin.sum() + labels.sum()) 288 | 289 | def wBCE(preds, labels, w): 290 | "Return weighted CE loss." 
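    # Weighted BCE below: loss = -mean( w*y*log(p) + (1-w)*(1-y)*log(1-p) ),
    # where w in (0,1) up-weights the positive class.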
291 | return -torch.mean( w*labels*torch.log(preds) + (1-w)*(1-labels)*torch.log(1-preds) ) 292 | 293 | def makeLogFile(filename="lossHistory.txt"): 294 | if isfile(filename): 295 | rename(filename,"lossHistoryOld.txt") 296 | 297 | with open(filename,"w") as text_file: 298 | print('Epoch\tlossTr\taccTr\tlossVl\taccVl\ttime(s)',file=text_file) 299 | print("Log file created...") 300 | return 301 | 302 | def writeLog(logFile, epoch, lossTr, accTr, lossVl, accVl,eTime): 303 | print('Epoch:{:04d}\t'.format(epoch + 1), 304 | 'lossTr:{:.4f}\t'.format(lossTr), 305 | 'accTr:{:.4f}\t'.format(accTr), 306 | 'lossVl:{:.4f}\t'.format(lossVl), 307 | 'accVl:{:.4f}\t'.format(accVl), 308 | 'time:{:.4f}'.format(eTime)) 309 | 310 | with open(logFile,"a") as text_file: 311 | print('{:04d}\t'.format(epoch + 1), 312 | '{:.4f}\t'.format(lossTr), 313 | '{:.4f}\t'.format(accTr), 314 | '{:.4f}\t'.format(lossVl), 315 | '{:.4f}\t'.format(accVl), 316 | '{:.4f}'.format(eTime),file=text_file) 317 | return 318 | 319 | def plotLearningCurve(): 320 | plt.clf() 321 | tmp = np.load('loss_tr.npz')['arr_0'] 322 | plt.plot(tmp,label='Tr.Loss') 323 | tmp = np.load('loss_vl.npz')['arr_0'] 324 | plt.plot(tmp,label='Vl.Loss') 325 | tmp = np.load('dice_tr.npz')['arr_0'] 326 | plt.plot(tmp,label='Tr.Dice') 327 | tmp = np.load('dice_vl.npz')['arr_0'] 328 | plt.plot(tmp,label='Vl.Dice') 329 | plt.legend() 330 | plt.grid() 331 | plt.show() 332 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def svd_flex(tensor, svd_string, max_D=None, cutoff=1e-10, sv_right=True, 5 | sv_vec=None): 6 | """ 7 | Split an input tensor into two pieces using an SVD across some partition 8 | 9 | Args: 10 | tensor (Tensor): Pytorch tensor with at least two indices 11 | 12 | svd_string (str): String of the form 'init_str->left_str,right_str', 13 | where init_str describes the indices of tensor, and 14 | left_str/right_str describe those of the left and 15 | right output tensors. The characters of left_str 16 | and right_str form a partition of the characters in 17 | init_str, but each contain one additional character 18 | representing the new bond which comes from the SVD 19 | 20 | Reversing the terms in svd_string to the left and 21 | right of '->' gives an ein_string which can be used 22 | to multiply both output tensors to give a (low-rank) 23 | approximation of the input tensor 24 | 25 | cutoff (float): A truncation threshold which eliminates any 26 | singular values which are strictly less than cutoff 27 | 28 | max_D (int): A maximum allowed value for the new bond. If max_D 29 | is specified, the returned tensors are zero-padded 30 | so that the new bond index always has size max_D 31 | sv_right (bool): The SVD gives two orthogonal matrices and a matrix 32 | of singular values. sv_right=True merges the SV 33 | matrix with the right output, while sv_right=False 34 | merges it with the left output 35 | 36 | sv_vec (Tensor): Pytorch vector with length max_D, which is modified 37 | in place to return the vector of singular values 38 | 39 | Returns: 40 | left_tensor (Tensor), 41 | right_tensor (Tensor): Tensors whose indices are described by the 42 | left_str and right_str parts of svd_string 43 | 44 | bond_dim: The dimension of the new bond appearing from 45 | the cutoff in our SVD. 
Note that this generally 46 | won't match the dimension of left_/right_tensor 47 | at this mode, which is padded with zeros 48 | whenever max_D is specified 49 | """ 50 | def prod(int_list): 51 | output = 1 52 | for num in int_list: 53 | output *= num 54 | return output 55 | 56 | with torch.no_grad(): 57 | # Parse svd_string into init_str, left_str, and right_str 58 | svd_string = svd_string.replace(' ', '') 59 | init_str, post_str = svd_string.split('->') 60 | left_str, right_str = post_str.split(',') 61 | 62 | # Check formatting of init_str, left_str, and right_str 63 | assert all([c.islower() for c in init_str+left_str+right_str]) 64 | assert len(set(init_str+left_str+right_str)) == len(init_str) + 1 65 | assert len(set(init_str))+len(set(left_str))+len(set(right_str)) == \ 66 | len(init_str)+len(left_str)+len(right_str) 67 | 68 | # Get the special character representing our SVD-truncated bond 69 | bond_char = set(left_str).intersection(set(right_str)).pop() 70 | left_part = left_str.replace(bond_char, '') 71 | right_part = right_str.replace(bond_char, '') 72 | 73 | # Permute our tensor into something that can be viewed as a matrix 74 | ein_str = f"{init_str}->{left_part+right_part}" 75 | tensor = torch.einsum(ein_str, [tensor]).contiguous() 76 | 77 | left_shape = list(tensor.shape[:len(left_part)]) 78 | right_shape = list(tensor.shape[len(left_part):]) 79 | left_dim, right_dim = prod(left_shape), prod(right_shape) 80 | 81 | tensor = tensor.view([left_dim, right_dim]) 82 | 83 | # Get SVD and format so that left_mat * diag(svs) * right_mat = tensor 84 | left_mat, svs, right_mat = torch.svd(tensor) 85 | svs, _ = torch.sort(svs, descending=True) 86 | right_mat = torch.t(right_mat) 87 | 88 | # Decrease or increase our tensor sizes in the presence of max_D 89 | if max_D and len(svs) > max_D: 90 | svs = svs[:max_D] 91 | left_mat = left_mat[:, :max_D] 92 | right_mat = right_mat[:max_D] 93 | elif max_D and len(svs) < max_D: 94 | copy_svs = torch.zeros([max_D]) 95 | copy_svs[:len(svs)] = svs 96 | copy_left = torch.zeros([left_mat.size(0), max_D]) 97 | copy_left[:, :left_mat.size(1)] = left_mat 98 | copy_right = torch.zeros([max_D, right_mat.size(1)]) 99 | copy_right[:right_mat.size(0)] = right_mat 100 | svs, left_mat, right_mat = copy_svs, copy_left, copy_right 101 | 102 | # If given as input, copy singular values into sv_vec 103 | if sv_vec is not None and svs.shape == sv_vec.shape: 104 | sv_vec[:] = svs 105 | elif sv_vec is not None and svs.shape != sv_vec.shape: 106 | raise TypeError(f"sv_vec.shape must be {list(svs.shape)}, but is " 107 | f"currently {list(sv_vec.shape)}") 108 | 109 | # Find the truncation point relative to our singular value cutoff 110 | truncation = 0 111 | for s in svs: 112 | if s < cutoff: 113 | break 114 | truncation += 1 115 | if truncation == 0: 116 | raise RuntimeError("SVD cutoff too large, attempted to truncate " 117 | "tensor to bond dimension 0") 118 | 119 | # Perform the actual truncation 120 | if max_D: 121 | svs[truncation:] = 0 122 | left_mat[:, truncation:] = 0 123 | right_mat[truncation:] = 0 124 | else: 125 | # If max_D wasn't given, set it to the truncation index 126 | max_D = truncation 127 | svs = svs[:truncation] 128 | left_mat = left_mat[:, :truncation] 129 | right_mat = right_mat[:truncation] 130 | 131 | # Merge the singular values into the appropriate matrix 132 | if sv_right: 133 | right_mat = torch.einsum('l,lr->lr', [svs, right_mat]) 134 | else: 135 | left_mat = torch.einsum('lr,r->lr', [left_mat, svs]) 136 | 137 | # Reshape the matrices to make 
them proper tensors 138 | left_tensor = left_mat.view(left_shape+[max_D]) 139 | right_tensor = right_mat.view([max_D]+right_shape) 140 | 141 | # Finally, permute the indices into the desired order 142 | if left_str != left_part + bond_char: 143 | left_tensor = torch.einsum(f"{left_part+bond_char}->{left_str}", 144 | [left_tensor]) 145 | if right_str != bond_char + right_part: 146 | right_tensor = torch.einsum(f"{bond_char+right_part}->{right_str}", 147 | [right_tensor]) 148 | 149 | return left_tensor, right_tensor, truncation 150 | 151 | def init_tensor(shape, bond_str, init_method): 152 | """ 153 | Initialize a tensor with a given shape 154 | 155 | Args: 156 | shape: The shape of our output parameter tensor. 157 | 158 | bond_str: The bond string describing our output parameter tensor, 159 | which is used in 'random_eye' initialization method. 160 | The characters 'l' and 'r' are used to refer to the 161 | left or right virtual indices of our tensor, and are 162 | both required to be present for the random_eye and 163 | min_random_eye initialization methods. 164 | 165 | init_method: The method used to initialize the entries of our tensor. 166 | This can be either a string, or else a tuple whose first 167 | entry is an initialization method and whose remaining 168 | entries are specific to that method. In each case, std 169 | will always refer to a standard deviation for a random 170 | normal random component of each entry of the tensor. 171 | 172 | Allowed options are: 173 | * ('random_eye', std): Initialize each tensor input 174 | slice close to the identity 175 | * ('random_zero', std): Initialize each tensor input 176 | slice close to the zero matrix 177 | * ('min_random_eye', std, init_dim): Initialize each 178 | tensor input slice close to a truncated identity 179 | matrix, whose truncation leaves init_dim unit 180 | entries on the diagonal. If init_dim is larger 181 | than either of the bond dimensions, then init_dim 182 | is capped at the smaller bond dimension. 183 | """ 184 | # Unpack init_method if it is a tuple 185 | if not isinstance(init_method, str): 186 | init_str = init_method[0] 187 | std = init_method[1] 188 | if init_str == 'min_random_eye': 189 | init_dim = init_method[2] 190 | 191 | init_method = init_str 192 | else: 193 | std = 1e-9 194 | 195 | # Check that bond_str is properly sized and doesn't have repeat indices 196 | assert len(shape) == len(bond_str) 197 | assert len(set(bond_str)) == len(bond_str) 198 | 199 | if init_method not in ['random_eye', 'min_random_eye', 'random_zero']: 200 | raise ValueError(f"Unknown initialization method: {init_method}") 201 | 202 | if init_method in ['random_eye', 'min_random_eye']: 203 | bond_chars = ['l', 'r'] 204 | assert all([c in bond_str for c in bond_chars]) 205 | 206 | # Initialize our tensor slices as identity matrices which each fill 207 | # some or all of the initially allocated bond space 208 | if init_method == 'min_random_eye': 209 | 210 | # The dimensions for our initial identity matrix. 
These will each 211 | # be init_dim, unless init_dim exceeds one of the bond dimensions 212 | bond_dims = [shape[bond_str.index(c)] for c in bond_chars] 213 | if all([init_dim <= full_dim for full_dim in bond_dims]): 214 | bond_dims = [init_dim, init_dim] 215 | else: 216 | init_dim = min(bond_dims) 217 | 218 | eye_shape = [init_dim if c in bond_chars else 1 for c in bond_str] 219 | expand_shape = [init_dim if c in bond_chars else shape[i] 220 | for i, c in enumerate(bond_str)] 221 | 222 | elif init_method == 'random_eye': 223 | eye_shape = [shape[i] if c in bond_chars else 1 224 | for i, c in enumerate(bond_str)] 225 | expand_shape = shape 226 | bond_dims = [shape[bond_str.index(c)] for c in bond_chars] 227 | 228 | eye_tensor = torch.eye(bond_dims[0], bond_dims[1]).view(eye_shape) 229 | eye_tensor = eye_tensor.expand(expand_shape) 230 | 231 | tensor = torch.zeros(shape) 232 | tensor[tuple(slice(dim) for dim in expand_shape)] = eye_tensor # Tuple-index; indexing with a list of slices is deprecated 233 | 234 | # Add on a bit of random noise 235 | tensor += std * torch.randn(shape) 236 | 237 | elif init_method == 'random_zero': 238 | tensor = std * torch.randn(shape) 239 | 240 | return tensor 241 | 242 | 243 | ### OLDER MISCELLANEOUS FUNCTIONS ### 244 | 245 | def onehot(labels, max_value): 246 | """ 247 | Convert a batch of labels from the set {0, 1,..., max_value-1} into their 248 | onehot encoded counterparts 249 | """ 250 | label_vecs = torch.zeros([len(labels), max_value]) 251 | 252 | for i, label in enumerate(labels): 253 | label_vecs[i, label] = 1. 254 | 255 | return label_vecs 256 | 257 | def joint_shuffle(input_data, input_labels): 258 | """ 259 | Shuffle input data and labels in a joint manner, so each label points to 260 | its corresponding datum. Works for both regular and CUDA tensors 261 | """ 262 | assert input_data.is_cuda == input_labels.is_cuda 263 | use_gpu = input_data.is_cuda 264 | if use_gpu: 265 | input_data, input_labels = input_data.cpu(), input_labels.cpu() 266 | 267 | data, labels = input_data.numpy(), input_labels.numpy() 268 | 269 | # Shuffle relative to the same seed 270 | np.random.seed(0) 271 | np.random.shuffle(data) 272 | np.random.seed(0) 273 | np.random.shuffle(labels) 274 | 275 | data, labels = torch.from_numpy(data), torch.from_numpy(labels) 276 | if use_gpu: 277 | data, labels = data.cuda(), labels.cuda() 278 | 279 | return data, labels 280 | 281 | def load_HV_data(length): 282 | """ 283 | Output a toy "horizontal/vertical" data set of black and white 284 | images with size length x length. Each image contains a single 285 | horizontal or vertical stripe, set against a background 286 | of the opposite color. The labels associated with these images 287 | are either 0 (horizontal stripe) or 1 (vertical stripe). 288 | 289 | In its current version, this returns two data sets, a training 290 | set with 75% of the images and a test set with 25% of the 291 | images. 
292 | """ 293 | num_images = 4 * (2**(length-1) - 1) 294 | num_patterns = num_images // 2 295 | split = num_images // 4 296 | 297 | if length > 14: 298 | print("load_HV_data will generate {} images, " 299 | "this could take a while...".format(num_images)) 300 | 301 | images = np.empty([num_images,length,length], dtype=np.float32) 302 | labels = np.empty(num_images, dtype=np.int) 303 | 304 | # Used to generate the stripe pattern from integer i below 305 | template = "{:0" + str(length) + "b}" 306 | 307 | for i in range(1, num_patterns+1): 308 | pattern = template.format(i) 309 | pattern = [int(s) for s in pattern] 310 | 311 | for j, val in enumerate(pattern): 312 | # Horizontal stripe pattern 313 | images[2*i-2, j, :] = val 314 | # Vertical stripe pattern 315 | images[2*i-1, :, j] = val 316 | 317 | labels[2*i-2] = 0 318 | labels[2*i-1] = 1 319 | 320 | # Shuffle and partition into training and test sets 321 | np.random.seed(0) 322 | np.random.shuffle(images) 323 | np.random.seed(0) 324 | np.random.shuffle(labels) 325 | 326 | train_images, train_labels = images[split:], labels[split:] 327 | test_images, test_labels = images[:split], labels[:split] 328 | 329 | return torch.from_numpy(train_images), \ 330 | torch.from_numpy(train_labels), \ 331 | torch.from_numpy(test_images), \ 332 | torch.from_numpy(test_labels) 333 | --------------------------------------------------------------------------------