├── MERA_MNIST.ipynb
├── README.md
├── models
│   ├── Densenet.py
│   ├── MERA.py
│   ├── contractables.py
│   ├── contractables.pyc
│   ├── lotenet.py
│   ├── lotenet.pyc
│   ├── mps.py
│   └── mps.pyc
├── train_MERA.py
└── utils
    ├── MNIST_reader.py
    ├── lidc_dataset.py
    ├── needle_dataset.py
    ├── tools.py
    ├── utils.py
    └── utils.pyc
/MERA_MNIST.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import tensorflow as tf\n", 12 | "tf.compat.v1.enable_v2_behavior()\n", 13 | "# Import tensornetwork\n", 14 | "import tensornetwork as tn\n", 15 | "# Set the backend to tensorflow\n", 16 | "# (default is numpy)\n", 17 | "tn.set_default_backend(\"tensorflow\")\n", 18 | "\n", 19 | "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data('')\n", 20 | "from tensorflow.keras.utils import to_categorical\n", 21 | "\n", 22 | "x_train = x_train.reshape((60000, 28, 28, 1)).astype(np.float32)\n", 23 | "y_train = to_categorical(y_train, 10).astype(np.float32)\n", 24 | "x_test = x_test.reshape((10000, 28, 28, 1))\n", 25 | "y_test = to_categorical(y_test, 10)\n", 26 | "\n", 27 | "xxx_train = (x_train-128)/255\n", 28 | "xxx_test = (x_test-128)/255\n", 29 | "\n", 30 | "xx_train = (tf.image.resize(x_train, [16,16]).numpy()-128)/255\n", 31 | "xx_test = (tf.image.resize(x_test, [16,16]).numpy()-128)/255" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "class Grid4DMERA(tf.keras.layers.Layer):\n", 41 | " \n", 42 | " def __init__(self, input_dim, bond_dims, output_dims, n_layers=None):\n", 43 | " super(Grid4DMERA, self).__init__()\n", 44 | " # Create the variables for the layer.\n", 45 | " # In this case, each input image is (16, 16, 1); we factorize it into 16 patch vectors of length 16\n", 46 | " # first_dim: output shape?\n", 47 | " # second_dim: connect with data tensor\n", 48 | " # third_dim: inter-connect\n", 49 | " if n_layers is None:\n", 50 | " n_layers = np.floor(np.log2(input_dim))\n", 51 | " self.n_layers = n_layers\n", 52 | " in_dims = 16\n", 53 | " dims = input_dim\n", 54 | " self.entanglers = []\n", 55 | " self.isometries = []\n", 56 | " \n", 57 | " #entanglers\n", 58 | " self.entanglers1 = tf.Variable(tf.random.normal\n", 59 | " (shape=(in_dims, in_dims, \n", 60 | " in_dims, in_dims, bond_dims, bond_dims, bond_dims, bond_dims),\n", 61 | " stddev=1.0/10000), \n", 62 | " trainable=True)\n", 63 | " self.entanglers2 = tf.Variable(tf.random.normal\n", 64 | " (shape=(bond_dims, bond_dims, \n", 65 | " bond_dims, bond_dims, bond_dims, bond_dims, bond_dims, bond_dims),\n", 66 | " stddev=1.0/10000), \n", 67 | " trainable=True)\n", 68 | " # isometries\n", 69 | " self.isometries1 = [tf.Variable(tf.random.normal(shape=(in_dims, in_dims, in_dims, \n", 70 | " bond_dims, bond_dims)\n", 71 | " , stddev=1.0/10*10000),\n", 72 | " trainable=True), \n", 73 | " tf.Variable(tf.random.normal(shape=(in_dims, in_dims, bond_dims, \n", 74 | " in_dims, bond_dims)\n", 75 | " , stddev=1.0/10*10000),\n", 76 | " trainable=True),\n", 77 | " tf.Variable(tf.random.normal(shape=(in_dims, bond_dims, in_dims, \n", 78 | " in_dims, bond_dims)\n", 79 | " , stddev=1.0/10*10000),\n", 80 | " trainable=True),\n", 81 | " tf.Variable(tf.random.normal(shape=(bond_dims, in_dims, in_dims, \n", 82 | " in_dims, bond_dims)\n", 83 | " , 
stddev=1.0/10*10000),\n", 84 | " trainable=True)]\n", 85 | " \n", 86 | " self.isometries2 = tf.Variable(tf.random.normal(shape=(bond_dims, bond_dims, bond_dims, \n", 87 | " bond_dims, output_dims)\n", 88 | " , stddev=1.0/10*10000),\n", 89 | " trainable=True)\n", 90 | "\n", 91 | " #print(self.final_mps.shape)\n", 92 | " self.bias = tf.Variable(tf.zeros(shape=(output_dims,)), name=\"bias\", trainable=True)\n", 93 | "\n", 94 | "\n", 95 | " def call(self, inputs):\n", 96 | " # Define the contraction.\n", 97 | " # We break it out so we can parallelize a batch using tf.vectorized_map.\n", 98 | " def f(input_vec, entanglers1, entanglers2, isometries1, isometries2, bias_var, n_layers):\n", 99 | " input_vv = []\n", 100 | " for i in range(4):\n", 101 | " for ii in range(4):\n", 102 | " input_vv.append(tf.reshape(input_vec[i*4:i*4+4, ii*4:ii*4+4, 0], (1, 16)))\n", 103 | " input_vec = tf.concat(input_vv, axis=0)\n", 104 | " input_vec = tf.reshape(input_vec, (16, 16))\n", 105 | " input_vec = tf.unstack(input_vec)\n", 106 | " input_nodes = []\n", 107 | " for e_iv in input_vec:\n", 108 | " input_nodes.append(tn.Node(e_iv))\n", 109 | " \n", 110 | " e_nodes1 = tn.Node(entanglers1)\n", 111 | " e_nodes2 = tn.Node(entanglers2)\n", 112 | " \n", 113 | " \n", 114 | " isometries_nodes1 = []\n", 115 | " for eiso in isometries1:\n", 116 | " isometries_nodes1.append(tn.Node(eiso))\n", 117 | " isometries_nodes2 = tn.Node(isometries2)\n", 118 | " \n", 119 | " \n", 120 | " e_nodes1[0] ^ input_nodes[5][0]\n", 121 | " e_nodes1[1] ^ input_nodes[6][0]\n", 122 | " e_nodes1[2] ^ input_nodes[9][0]\n", 123 | " e_nodes1[3] ^ input_nodes[10][0]\n", 124 | "\n", 125 | " e_nodes1[4] ^ isometries_nodes1[0][3]\n", 126 | " e_nodes1[5] ^ isometries_nodes1[1][2]\n", 127 | " e_nodes1[6] ^ isometries_nodes1[2][1]\n", 128 | " e_nodes1[7] ^ isometries_nodes1[3][0] \n", 129 | " \n", 130 | " input_nodes[0][0] ^ isometries_nodes1[0][0]\n", 131 | " input_nodes[1][0] ^ isometries_nodes1[0][1]\n", 132 | " input_nodes[4][0] ^ isometries_nodes1[0][2]\n", 133 | " \n", 134 | " input_nodes[2][0] ^ isometries_nodes1[1][0]\n", 135 | " input_nodes[3][0] ^ isometries_nodes1[1][1]\n", 136 | " input_nodes[7][0] ^ isometries_nodes1[1][3]\n", 137 | " \n", 138 | " input_nodes[8][0] ^ isometries_nodes1[2][0]\n", 139 | " input_nodes[12][0] ^ isometries_nodes1[2][2]\n", 140 | " input_nodes[13][0] ^ isometries_nodes1[2][3]\n", 141 | " \n", 142 | " input_nodes[11][0] ^ isometries_nodes1[3][1]\n", 143 | " input_nodes[14][0] ^ isometries_nodes1[3][2]\n", 144 | " input_nodes[15][0] ^ isometries_nodes1[3][3]\n", 145 | " \n", 146 | " \n", 147 | " isometries_nodes1[0][4] ^ e_nodes2[0]\n", 148 | " isometries_nodes1[1][4] ^ e_nodes2[1]\n", 149 | " isometries_nodes1[2][4] ^ e_nodes2[2]\n", 150 | " isometries_nodes1[3][4] ^ e_nodes2[3]\n", 151 | "\n", 152 | " e_nodes2[4] ^ isometries_nodes2[0]\n", 153 | " e_nodes2[5] ^ isometries_nodes2[1]\n", 154 | " e_nodes2[6] ^ isometries_nodes2[2]\n", 155 | " e_nodes2[7] ^ isometries_nodes2[3]\n", 156 | "\n", 157 | " \n", 158 | " nodes = tn.reachable(isometries_nodes2)\n", 159 | " result = tn.contractors.greedy(nodes)\n", 160 | " result = result.tensor\n", 161 | " #print(result)\n", 162 | " #result = (c @ b).tensor\n", 163 | " # Finally, add bias.\n", 164 | " return result + bias_var\n", 165 | "\n", 166 | " # To deal with a batch of items, we can use the tf.vectorized_map function.\n", 167 | " # https://www.tensorflow.org/api_docs/python/tf/vectorized_map\n", 168 | " output = tf.vectorized_map(lambda vec: f(vec, self.entanglers1, 
self.entanglers2,\n", 169 | " self.isometries1, self.isometries2, self.bias, self.n_layers), inputs)\n", 170 | " return tf.reshape(output, (-1, 10))" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [] 179 | } 180 | ], 181 | "metadata": {}, 182 | "nbformat": 4, 183 | "nbformat_minor": 5 184 | } 185 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MERA_Image_Classification 2 | ## Code Contributor: Fanjie Kong 3 | #### Finished Work: 4 | 1. Implemented 2D MERA model using PyTorch and TensorFlow. TensorFlow version is more time-efficient. 5 | 2. Tested our 2D MERA model on MNIST, NeedleMNIST(64x64, 128x128) and LIDC dataset. 6 | 7 | | | MNIST | NeedleMNIST(64x64) | NeedleMNIST(128x128) | LIDC | 8 | |----------- |------- |-------------------- |---------------------- |------- | 9 | | CNN | 0.983 | 0.760 | 0.739 | 0.780 | 10 | | Tensor-NN | 0.985 | 0.740 | 0.727 | 0.860 | 11 | | 2D MERA | 0.903 | 0.784 | 0.714 | 0.760 | 12 | 13 | 3. Summarized our work into a paper submitted to QTNML 2020 14 | 15 | #### Description: 16 | ##### PyTorch codes: 17 | * Basic Pytorch dependency 18 | * Tested on Pytorch 1.3, Python 3.6 19 | * Unzip the data and point the path to --data_path 20 | * How to run tests: python train.py --data_path data_location 21 | ##### TensorFlow code: 22 | * TensorFlow 2.1.0 and TensorNetwork 23 | * Experiments are performed on Jupyter Notebook MERA_MNIST.ipynb 24 | 25 | ##### Thanks to the following repositories: 26 | - https://github.com/raghavian/loTeNet_pytorch 27 | -------------------------------------------------------------------------------- /models/Densenet.py: -------------------------------------------------------------------------------- 1 | ### Adapted from https://github.com/bamos/densenet.pytorch 2 | 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | 11 | import torchvision.datasets as dset 12 | import torchvision.transforms as transforms 13 | from torch.utils.data import DataLoader 14 | from scipy.stats import truncnorm 15 | 16 | import torchvision.models as models 17 | 18 | import sys 19 | import math 20 | import pdb 21 | 22 | class Bottleneck(nn.Module): 23 | def __init__(self, nChannels, growthRate): 24 | super(Bottleneck, self).__init__() 25 | interChannels = 4*growthRate 26 | self.bn1 = nn.BatchNorm2d(nChannels) 27 | self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, 28 | bias=False) 29 | self.bn2 = nn.BatchNorm2d(interChannels) 30 | self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, 31 | padding=1, bias=False) 32 | 33 | def forward(self, x): 34 | out = self.conv1(F.relu(self.bn1(x))) 35 | out = self.conv2(F.relu(self.bn2(out))) 36 | out = torch.cat((x, out), 1) 37 | return out 38 | 39 | class FrBottleneck(nn.Module): 40 | def __init__(self, nChannels, growthRate, dup=2): 41 | super(FrBottleneck, self).__init__() 42 | self.bn1 = nn.BatchNorm2d(nChannels) 43 | self.Fr1 = frTNLayer2(nChannels, dup=dup) 44 | interChannels = nChannels // dup 45 | self.bn2 = nn.BatchNorm2d(interChannels) 46 | self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, 47 | padding=1, bias=False) 48 | 49 | def forward(self, x): 50 | out = self.Fr1(F.relu(self.bn1(x))) 51 | out = 
self.conv2(F.relu(self.bn2(out))) 52 | out = torch.cat((x, out), 1) 53 | return out 54 | 55 | 56 | def get_truncated_normal(mean=0, sd=1, low=-2, upp=2): 57 | return truncnorm( 58 | (low - mean) / sd, (upp - mean) / sd, loc=mean, scale=sd) 59 | 60 | class frTNLayer(nn.Module): 61 | def __init__(self, nChannels, V=False, P=True, H=True, S=True,W=True, dup=2, ratio=6, z_bias=1.0): 62 | super(frTNLayer, self).__init__() 63 | self.flag_V = V 64 | self.flag_P = P 65 | self.flag_H = H 66 | self.flag_S = S 67 | self.flag_W = W 68 | self.nChannels = nChannels 69 | self.dup = dup 70 | self.ratio = 6 71 | 72 | self.z_bias = z_bias 73 | 74 | def forward(self, x): 75 | #print('x_shape', x.shape) 76 | if self.flag_V: 77 | mean, variance = x.view(x.shape[0], x.shape[1], -1).mean(2), x.view(x.shape[0], x.shape[1], -1).std(2) 78 | mv = torch.cat([mean, variance], 1) # B x 2C 79 | coff = 2 80 | else: 81 | mv = x.view(x.shape[0], x.shape[1], -1).mean(2) 82 | coff = 1 83 | α = Variable(torch.zeros(1), requires_grad=True) 84 | α = torch.clamp(α, -1.0, 5.0) 85 | α = α.cuda() 86 | if self.flag_P: 87 | x = torch.nn.functional.relu(x) + self.z_bias 88 | z = torch.pow(x, α + 1) 89 | else: 90 | y = torch.abs(x) 91 | z = torch.pow(y, α + 1) * torch.sign(x) 92 | if self.flag_W: 93 | size_wa = self.nChannels * coff * self.nChannels * coff // self.ratio 94 | wx = get_truncated_normal() 95 | fcwa = wx.rvs(size_wa) 96 | fcwa = np.reshape(fcwa, [self.nChannels * coff, self.nChannels * coff // self.ratio]) 97 | fc_weights_a = Variable(torch.from_numpy(fcwa), requires_grad=True) 98 | 99 | size_wb = self.nChannels * coff * self.nChannels * coff // self.ratio 100 | wx = get_truncated_normal() 101 | fcwb = wx.rvs(size_wb) 102 | fcwb = np.reshape(fcwb, [self.nChannels * coff, self.nChannels * coff // self.ratio]) 103 | fc_weights_b = Variable(torch.from_numpy(fcwb), requires_grad=True) 104 | 105 | fc_weights_a = fc_weights_a.cuda().float() 106 | fc_weights_b = fc_weights_b.cuda().float() 107 | mv = mv.cuda().float() 108 | # print(fc_weights_a.shape) 109 | # print(fc_weights_b.shape) 110 | # print(mv.shape) 111 | 112 | ω = torch.nn.functional.sigmoid(torch.matmul(torch.nn.functional.relu(torch.matmul(mv, fc_weights_a)), fc_weights_b.transpose(0,1))) 113 | ω = ω.view(-1, self.nChannels, 1, 1) 114 | #print(z.shape) 115 | #print(ω.shape) 116 | z = z * ω 117 | z = z.view(-1, self.dup, self.nChannels // self.dup, x.shape[2], x.shape[3]) 118 | z = z.sum(1) 119 | 120 | return z 121 | 122 | class frTNLayer2(nn.Module): 123 | def __init__(self, nChannels, V=True, P=True, H=True, S=True,W=True, dup=2, ratio=6, z_bias=1.0): 124 | super(frTNLayer2, self).__init__() 125 | self.flag_V = V 126 | self.flag_P = P 127 | self.flag_H = H 128 | self.flag_S = S 129 | self.flag_W = W 130 | self.nChannels = nChannels 131 | self.dup = dup 132 | self.ratio = 6 133 | 134 | self.z_bias = z_bias 135 | 136 | def forward(self, x): 137 | #print('x_shape', x.shape) 138 | if self.flag_V: 139 | mean, variance = x.view(x.shape[0], x.shape[1], -1).mean(2), x.view(x.shape[0], x.shape[1], -1).std(2) 140 | mv = torch.cat([mean, variance], 1) # B x 2C 141 | coff = 2 142 | else: 143 | mv = x.view(x.shape[0], x.shape[1], -1).mean(2) 144 | coff = 1 145 | 146 | if self.flag_W: 147 | size_wa = self.nChannels * coff * self.nChannels * coff // self.ratio 148 | wx = get_truncated_normal(mean=0, sd=0.05) 149 | fcwa = wx.rvs(size_wa) 150 | fcwa = np.reshape(fcwa, [self.nChannels * coff, self.nChannels * coff // self.ratio]) 151 | fc_weights_a = Variable(torch.from_numpy(fcwa), 
requires_grad=True) 152 | 153 | size_wb = self.nChannels * coff * self.nChannels * coff // self.ratio 154 | wx = get_truncated_normal(sd=0.002) 155 | fcwb = wx.rvs(size_wb) 156 | fcwb = np.reshape(fcwb, [self.nChannels * coff, self.nChannels * coff // self.ratio]) 157 | fc_weights_b = Variable(torch.from_numpy(fcwb), requires_grad=True) 158 | fc_weights_a = fc_weights_a.cuda().float() 159 | fc_weights_b = fc_weights_b.cuda().float() 160 | mv = mv.cuda().float() 161 | if self.flag_S: 162 | # print(mv.shape) 163 | # print(fc_weights_a.shape) 164 | # print(fc_weights_b.shape) 165 | η =torch.nn.functional.sigmoid( 166 | torch.matmul(torch.nn.functional.relu(torch.matmul(mv, fc_weights_a)), fc_weights_b.transpose(0, 1))) 167 | #print(η.shape) 168 | α, ω = torch.split(η, η.shape[1]//2, dim=1) 169 | α, ω = α, torch.nn.functional.sigmoid(ω) 170 | # print(self.nChannels) 171 | # print(α.shape) 172 | # print(ω.shape) 173 | α = α.view(-1, self.nChannels, 1, 1) 174 | ω = ω.view(-1, self.nChannels, 1, 1) 175 | α, ω = α.view(-1, self.nChannels, 1, 1), ω.view(-1, self.nChannels, 1, 1) 176 | α = torch.clamp(α, -0.5, 2.0) 177 | α = α.cuda().float() 178 | ω = ω.cuda().float() 179 | 180 | 181 | 182 | if self.flag_P: 183 | x = torch.nn.functional.relu(x) + self.z_bias 184 | z = torch.pow(x, α + 1) 185 | else: 186 | y = torch.abs(x) 187 | z = torch.pow(y, α + 1) * torch.sign(x) 188 | z = z * ω 189 | z = z.view(-1, self.dup, self.nChannels // self.dup, x.shape[2], x.shape[3]) 190 | z = z.sum(1) 191 | 192 | return z 193 | 194 | class SingleLayer(nn.Module): 195 | def __init__(self, nChannels, growthRate): 196 | super(SingleLayer, self).__init__() 197 | self.bn1 = nn.BatchNorm2d(nChannels) 198 | self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, 199 | padding=1, bias=False) 200 | 201 | def forward(self, x): 202 | out = self.conv1(F.relu(self.bn1(x))) 203 | out = torch.cat((x, out), 1) 204 | return out 205 | 206 | class Transition(nn.Module): 207 | def __init__(self, nChannels, nOutChannels): 208 | super(Transition, self).__init__() 209 | self.bn1 = nn.BatchNorm2d(nChannels) 210 | self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, 211 | bias=False) 212 | 213 | def forward(self, x): 214 | out = self.conv1(F.relu(self.bn1(x))) 215 | out = F.avg_pool2d(out, 2) 216 | return out 217 | 218 | 219 | class DenseNet(nn.Module): 220 | def __init__(self, growthRate, depth, reduction, nClasses, bottleneck): 221 | super(DenseNet, self).__init__() 222 | 223 | nDenseBlocks = (depth-4) // 3 224 | if bottleneck: 225 | nDenseBlocks //= 2 226 | 227 | nChannels = 2*growthRate 228 | self.conv1 = nn.Conv2d(1, nChannels, kernel_size=3, padding=1, 229 | bias=False) 230 | self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 231 | nChannels += nDenseBlocks*growthRate 232 | nOutChannels = int(math.floor(nChannels*reduction)) 233 | self.trans1 = Transition(nChannels, nOutChannels) 234 | 235 | nChannels = nOutChannels 236 | self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 237 | nChannels += nDenseBlocks*growthRate 238 | nOutChannels = int(math.floor(nChannels*reduction)) 239 | self.trans2 = Transition(nChannels, nOutChannels) 240 | 241 | nChannels = nOutChannels 242 | self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 243 | nChannels += nDenseBlocks*growthRate 244 | 245 | self.bn1 = nn.BatchNorm2d(nChannels) 246 | self.fc = nn.Linear(nChannels, nClasses) 247 | 248 | self.nChannels = nChannels 249 | for m in self.modules(): 250 | if 
isinstance(m, nn.Conv2d): 251 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 252 | m.weight.data.normal_(0, math.sqrt(2. / n)) 253 | elif isinstance(m, nn.BatchNorm2d): 254 | m.weight.data.fill_(1) 255 | m.bias.data.zero_() 256 | elif isinstance(m, nn.Linear): 257 | m.bias.data.zero_() 258 | 259 | def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): 260 | layers = [] 261 | for i in range(int(nDenseBlocks)): 262 | if bottleneck: 263 | layers.append(Bottleneck(nChannels, growthRate)) 264 | else: 265 | layers.append(SingleLayer(nChannels, growthRate)) 266 | nChannels += growthRate 267 | return nn.Sequential(*layers) 268 | 269 | def forward(self, x): 270 | # pdb.set_trace() 271 | out = self.conv1(x) 272 | out = self.trans1(self.dense1(out)) 273 | out = self.trans2(self.dense2(out)) 274 | out = self.dense3(out) 275 | out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8)) 276 | out = out.view(x.shape[0],self.nChannels,-1).mean(2) 277 | out = torch.sigmoid(self.fc(out)) 278 | return out.squeeze() 279 | 280 | class FrDenseNet(nn.Module): 281 | def __init__(self, growthRate, depth, reduction, nClasses, bottleneck): 282 | super(FrDenseNet, self).__init__() 283 | 284 | nDenseBlocks = (depth-4) // 3 285 | if bottleneck: 286 | nDenseBlocks //= 2 287 | 288 | nChannels = 2*growthRate 289 | self.conv1 = nn.Conv2d(1, nChannels, kernel_size=3, padding=1, 290 | bias=False) 291 | self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 292 | nChannels += nDenseBlocks*growthRate 293 | nOutChannels = int(math.floor(nChannels*reduction)) 294 | self.trans1 = Transition(nChannels, nOutChannels) 295 | 296 | nChannels = nOutChannels 297 | self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 298 | nChannels += nDenseBlocks*growthRate 299 | nOutChannels = int(math.floor(nChannels*reduction)) 300 | self.trans2 = Transition(nChannels, nOutChannels) 301 | 302 | nChannels = nOutChannels 303 | self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 304 | nChannels += nDenseBlocks*growthRate 305 | 306 | self.bn1 = nn.BatchNorm2d(nChannels) 307 | self.fc = nn.Linear(nChannels, nClasses) 308 | 309 | self.nChannels = nChannels 310 | for m in self.modules(): 311 | if isinstance(m, nn.Conv2d): 312 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 313 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 314 | elif isinstance(m, nn.BatchNorm2d): 315 | m.weight.data.fill_(1) 316 | m.bias.data.zero_() 317 | elif isinstance(m, nn.Linear): 318 | m.bias.data.zero_() 319 | 320 | def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): 321 | layers = [] 322 | for i in range(int(nDenseBlocks)): 323 | if bottleneck: 324 | layers.append(FrBottleneck(nChannels, growthRate)) 325 | else: 326 | layers.append(SingleLayer(nChannels, growthRate)) 327 | nChannels += growthRate 328 | return nn.Sequential(*layers) 329 | 330 | def forward(self, x): 331 | # pdb.set_trace() 332 | out = self.conv1(x) 333 | out = self.trans1(self.dense1(out)) 334 | out = self.trans2(self.dense2(out)) 335 | out = self.dense3(out) 336 | out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8)) 337 | out = out.view(x.shape[0],self.nChannels,-1).mean(2) 338 | out = torch.sigmoid(self.fc(out)) 339 | return out.squeeze() 340 | -------------------------------------------------------------------------------- /models/MERA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from models.mps import MPS 6 | import pdb 7 | 8 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 9 | 10 | EPS = 1e-6 11 | 12 | 13 | class MERAnet_clean(nn.Module): 14 | def __init__(self, input_dim, output_dim, bond_dim, feature_dim=2, nCh=3, 15 | kernel=[2, 2, 2], virtual_dim=1, 16 | adaptive_mode=False, periodic_bc=False, parallel_eval=False, 17 | label_site=None, path=None, init_std=1e-9, use_bias=True, 18 | fixed_bias=True, cutoff=1e-10, merge_threshold=2000): 19 | # bond_dim parameter is non-sense 20 | super().__init__() 21 | self.input_dim = input_dim 22 | self.virtual_dim = bond_dim 23 | 24 | ### Squeezing of spatial dimension in first step 25 | # self.kScale = 4 # what is this? 
26 | # nCh = self.kScale ** 2 * nCh 27 | # self.input_dim = self.input_dim / self.kScale 28 | 29 | # print(nCh) 30 | self.nCh = nCh 31 | if isinstance(kernel, int): 32 | kernel = 3 * [kernel] 33 | self.ker = kernel 34 | 35 | num_layers = int(np.log2(input_dim[0]) - 2) # np.int was removed from NumPy; use the builtin 36 | self.num_layers = num_layers 37 | self.disentangler_list = nn.ModuleList() # nn.ModuleList (not a plain Python list) so the MPS parameters are registered and trained 38 | self.isometry_list = nn.ModuleList() 39 | iDim = (self.input_dim - 2) // 2 # integer division keeps iDim usable as a count/index 40 | 41 | for ii in range(num_layers): 42 | # feature_dim = 2 * nCh 43 | feature_dim = 2 * nCh 44 | # print(feature_dim) 45 | # First level disentanglers 46 | # First level isometries 47 | 48 | ### First level MERA blocks 49 | self.disentangler_list.append(nn.ModuleList([MPS(input_dim=4, 50 | output_dim=4, 51 | nCh=nCh, bond_dim=4, 52 | feature_dim=feature_dim, parallel_eval=parallel_eval, 53 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 54 | for i in range(torch.prod(iDim))])) 55 | 56 | iDim = iDim + 1 57 | 58 | self.isometry_list.append(nn.ModuleList([MPS(input_dim=4, 59 | output_dim=1, 60 | nCh=nCh, bond_dim=bond_dim, 61 | feature_dim=feature_dim, parallel_eval=parallel_eval, 62 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 63 | for i in range(torch.prod(iDim))])) 64 | iDim = (iDim - 2) // 2 65 | 66 | ### Final MPS block 67 | self.mpsFinal = MPS(input_dim=49, 68 | output_dim=output_dim, nCh=1, 69 | bond_dim=bond_dim, feature_dim=feature_dim, 70 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, 71 | parallel_eval=parallel_eval) 72 | 73 | def forward(self, x): 74 | iDim = self.input_dim 75 | b = x.shape[0] # Batch size 76 | # Disentangler layer 77 | x_in = x 78 | 79 | for jj in range(self.num_layers): 80 | iDim = iDim // 2 81 | x_ent = x_in[:, :, 1:-1, 1:-1] 82 | # print('---------------') 83 | # print(x_ent.shape) 84 | # print('x_ent unfold shape: ', x_ent.unfold(2, 2, 2).unfold(3, 2, 2).shape) 85 | x_ent = x_ent.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 86 | # print('x_ent unfold->reshape shape: ', x_ent.shape) 87 | # print('single x_ent unfold->reshape shape: ', x_ent[:, :, 0].shape) 88 | # print(x_ent.shape) 89 | # print(len(self.disentangler_list[jj])) 90 | y_ent = [self.disentangler_list[jj][i](x_ent[:, :, i]) for i in range(len(self.disentangler_list[jj]))] 91 | y_ent = torch.stack(y_ent, dim=1) # 512, 3969, 4 92 | y_ent = y_ent.view(y_ent.shape[0], y_ent.shape[1], 2, 2) 93 | y_ent = y_ent.view(y_ent.shape[0], iDim[0] - 1, iDim[1] - 1, 2, 2) 94 | y_ent_list = [] 95 | # print('y_ent shape: ', y_ent.shape) 96 | # print('torch cat col shape: ', torch.cat([y_ent[:, 0, i, :, :] for i in range(y_ent.shape[2])], dim=2).shape) 97 | for j in range(y_ent.shape[1]): 98 | y_ent_list.append(torch.cat([y_ent[:, j, i, :, :] for i in range(y_ent.shape[2])], dim=2)) 99 | y_ent = torch.cat(y_ent_list, dim=1) 100 | 101 | # print('y_ent shape: ', y_ent.shape) 102 | x_iso = x_in 103 | # print(x_iso.shape) 104 | x_iso[:, :, 1:-1, 1:-1] = y_ent.view(b, self.nCh, y_ent.shape[1], y_ent.shape[2]) 105 | # print('x_iso shape: ', x_iso.shape) 106 | x_iso = x_iso.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 107 | y_iso = [self.isometry_list[jj][i](x_iso[:, :, i]) for i in range(len(self.isometry_list[jj]))] 108 | y_iso = torch.stack(y_iso, dim=1) # 512, 4096, 1 109 | 110 | x_in = y_iso.view(y_iso.shape[0], self.nCh, iDim[0], iDim[1]) 111 | 112 | # print('x6 shape: ', x6.shape) # 512, 1, 2, 2 113 | 114 | y = x_in.view(b, self.nCh, iDim[0] * iDim[1]) 115 | # print('LoTe y shape before mpsfinal ', y.shape) 116 | y = self.mpsFinal(y) 117 | return y.squeeze()
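# --- Editor's illustration (added; not part of the original file) ---------
# The forward pass above repeatedly uses
#     x.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, nCh, -1, 4)
# to turn a feature map into a grid of flattened 2x2 patches, one patch per
# disentangler/isometry block. A minimal sketch of that bookkeeping, with
# toy shapes chosen purely for illustration:
#
#     import torch
#     b, nCh, H, W = 2, 3, 8, 8                    # hypothetical 8x8 inputs
#     x = torch.randn(b, nCh, H, W)
#     patches = x.unfold(2, 2, 2).unfold(3, 2, 2)  # (b, nCh, H//2, W//2, 2, 2)
#     patches = patches.reshape(b, nCh, -1, 4)     # (b, nCh, 16, 4)
#     assert patches.shape == (b, nCh, (H // 2) * (W // 2), 4)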
118 | 119 | 120 | -------------------------------------------------------------------------------- /models/contractables.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Contractable: 4 | """ 5 | Container for tensors with labeled indices and a global batch size 6 | 7 | The labels for our indices give some high-level knowledge of the tensor 8 | layout, and permit the contraction of pairs of indices in a more 9 | systematic manner. However, much of the actual heavy lifting is done 10 | through specific contraction routines in different subclasses 11 | 12 | Attributes: 13 | tensor (Tensor): A Pytorch tensor whose first index is a batch 14 | index. Sub-classes of Contractable may put other 15 | restrictions on tensor 16 | bond_str (str): A string whose letters each label a separate mode 17 | of our tensor, and whose length equals the order 18 | (number of modes) of our tensor 19 | global_bs (int): The batch size associated with all Contractables. 20 | This is shared between all Contractable instances 21 | and allows for automatic expanding of tensors 22 | """ 23 | # The global batch size 24 | global_bs = None 25 | 26 | def __init__(self, tensor, bond_str): 27 | shape = list(tensor.shape) 28 | num_dim = len(shape) 29 | str_len = len(bond_str) 30 | # Read the working batch size from the class attribute. (The original 31 | # "global global_bs" declaration referred to a module-level name that 32 | # never exists, so its if/else branch could only raise NameError.) 33 | global_bs = Contractable.global_bs 34 | 35 | batch_dim = tensor.size(0) 36 | 37 | # Expand along a new batch dimension if needed 38 | if ('b' not in bond_str and str_len == num_dim) or \ 39 | ('b' == bond_str[0] and str_len == num_dim + 1): 40 | if global_bs is not None: 41 | tensor = tensor.unsqueeze(0).expand([global_bs] + shape) 42 | else: 43 | raise RuntimeError("No batch size given and no previous " 44 | "batch size set") 45 | if bond_str[0] != 'b': 46 | bond_str = 'b' + bond_str 47 | 48 | # Check for correct formatting in bond_str 49 | elif bond_str[0] != 'b' or str_len != num_dim: 50 | raise ValueError(f"Length of bond string '{bond_str}' ({len(bond_str)}) must match order of tensor ({len(shape)})") 51 | 52 | # Set the global batch size if it is unset or needs to be updated 53 | elif global_bs is None or global_bs != batch_dim: 54 | Contractable.global_bs = batch_dim 55 | 56 | # Check that global batch size agrees with input tensor's first dim 57 | elif global_bs != batch_dim: 58 | raise RuntimeError(f"Batch size previously set to {global_bs}" 59 | f", but input tensor has batch size {batch_dim}") 60 | 61 | # Set the defining attributes of our Contractable 62 | self.tensor = tensor 63 | self.bond_str = bond_str 64 | 65 | def __mul__(self, contractable, rmul=False): 66 | """ 67 | Multiply with another contractable along a linear index 68 | 69 | The default behavior is to multiply the 'r' index of this instance 70 | with the 'l' index of contractable, matching the batch ('b') 71 | index of both, and take the outer product of other indices. 72 | If rmul is True, contractable is instead multiplied on the right.
73 | """ 74 | # This method works for general Core subclasses besides Scalar (no 'l' 75 | # and 'r' indices), composite contractables (no tensor attribute), and 76 | # MatRegion (multiplication isn't just simple index contraction) 77 | if isinstance(contractable, Scalar) or \ 78 | not hasattr(contractable, 'tensor') or \ 79 | type(contractable) is MatRegion: 80 | return NotImplemented 81 | 82 | tensors = [self.tensor, contractable.tensor] 83 | bond_strs = [list(self.bond_str), list(contractable.bond_str)] 84 | lowercases = [chr(c) for c in range(ord('a'), ord('z')+1)] 85 | 86 | # Reverse the order of tensors if needed 87 | if rmul: 88 | tensors = tensors[::-1] 89 | bond_strs = bond_strs[::-1] 90 | 91 | # Check that bond strings are in proper format 92 | for i, bs in enumerate(bond_strs): 93 | assert bs[0] == 'b' 94 | assert len(set(bs)) == len(bs) 95 | assert all([c in lowercases for c in bs]) 96 | assert (i == 0 and 'r' in bs) or (i == 1 and 'l' in bs) 97 | 98 | # Get used and free characters 99 | used_chars = set(bond_strs[0]).union(bond_strs[1]) 100 | free_chars = [c for c in lowercases if c not in used_chars] 101 | 102 | # Rename overlapping indices in the bond strings (except 'b', 'l', 'r') 103 | specials = ['b', 'l', 'r'] 104 | for i, c in enumerate(bond_strs[1]): 105 | if c in bond_strs[0] and c not in specials: 106 | bond_strs[1][i] = free_chars.pop() 107 | 108 | # Combine right bond of left tensor and left bond of right tensor 109 | sum_char = free_chars.pop() 110 | bond_strs[0][bond_strs[0].index('r')] = sum_char 111 | bond_strs[1][bond_strs[1].index('l')] = sum_char 112 | specials.append(sum_char) 113 | 114 | # Build bond string of ouput tensor 115 | out_str = ['b'] 116 | for bs in bond_strs: 117 | out_str.extend([c for c in bs if c not in specials]) 118 | out_str.append('l' if 'l' in bond_strs[0] else '') 119 | out_str.append('r' if 'r' in bond_strs[1] else '') 120 | 121 | # Build the einsum string for this operation 122 | bond_strs = [''.join(bs) for bs in bond_strs] 123 | out_str = ''.join(out_str) 124 | ein_str = f"{bond_strs[0]},{bond_strs[1]}->{out_str}" 125 | 126 | # Contract along the linear dimension to get an output tensor 127 | out_tensor = torch.einsum(ein_str, [tensors[0], tensors[1]]) 128 | 129 | # Return our output tensor wrapped in an appropriate class 130 | if out_str == 'br': 131 | return EdgeVec(out_tensor, is_left_vec=True) 132 | elif out_str == 'bl': 133 | return EdgeVec(out_tensor, is_left_vec=False) 134 | elif out_str == 'blr': 135 | return SingleMat(out_tensor) 136 | elif out_str == 'bolr': 137 | return OutputCore(out_tensor) 138 | else: 139 | return Contractable(out_tensor, out_str) 140 | 141 | def __rmul__(self, contractable): 142 | """ 143 | Multiply with another contractable along a linear index 144 | """ 145 | return self.__mul__(contractable, rmul=True) 146 | 147 | def reduce(self): 148 | """ 149 | Return the contractable without any modification 150 | 151 | reduce() can be any method which returns a contractable. 
152 | trivially possible for any contractable by returning itself 153 | """ 154 | return self 155 | 156 | class ContractableList(Contractable): 157 | """ 158 | A list of contractables which can all be multiplied together in order 159 | 160 | Calling reduce on a ContractableList instance will first reduce every item 161 | to a linear contractable, and then contract everything together 162 | """ 163 | def __init__(self, contractable_list): 164 | # Check that input list is nonempty and has contractables as entries 165 | if not isinstance(contractable_list, list) or contractable_list == []: 166 | raise ValueError("Input to ContractableList must be a nonempty list") 167 | for i, item in enumerate(contractable_list): 168 | if not isinstance(item, Contractable): 169 | raise ValueError(f"Input items to ContractableList must be Contractable instances, but item {i} is not") 170 | 171 | self.contractable_list = contractable_list 172 | 173 | def __mul__(self, contractable, rmul=False): 174 | """ 175 | Multiply a contractable by everything in ContractableList in order 176 | """ 177 | # The input cannot be a composite contractable 178 | assert hasattr(contractable, 'tensor') 179 | output = contractable.tensor 180 | 181 | # Multiply by everything in ContractableList, in the correct order 182 | if rmul: 183 | for item in self.contractable_list: 184 | output = item * output 185 | else: 186 | for item in self.contractable_list[::-1]: 187 | output = output * item 188 | 189 | return output 190 | 191 | def __rmul__(self, contractable): 192 | """ 193 | Multiply another contractable by everything in ContractableList 194 | """ 195 | return self.__mul__(contractable, rmul=True) 196 | 197 | def reduce(self, parallel_eval=False): 198 | """ 199 | Reduce all the contractables in list before multiplying them together 200 | """ 201 | c_list = self.contractable_list 202 | # For parallel_eval, reduce all contractables in c_list 203 | if parallel_eval: 204 | c_list = [item.reduce() for item in c_list] 205 | 206 | # Multiply together all the contractables. This multiplies in right to 207 | # left order, but certain inefficient contractions are unsupported.
208 | # If we encounter an unsupported operation, then try multiplying from 209 | # the left end of the list instead 210 | while len(c_list) > 1: 211 | try: 212 | c_list[-2] = c_list[-2] * c_list[-1] 213 | del c_list[-1] 214 | except TypeError: 215 | c_list[1] = c_list[0] * c_list[1] 216 | del c_list[0] 217 | 218 | return c_list[0] 219 | 220 | class MatRegion(Contractable): 221 | """ 222 | A contiguous collection of matrices which are multiplied together 223 | 224 | The input tensor defining our MatRegion must have shape 225 | [batch_size, num_mats, D, D], or [num_mats, D, D] when the global batch 226 | size is already known 227 | """ 228 | def __init__(self, mats): 229 | shape = list(mats.shape) 230 | if len(shape) not in [3, 4] or shape[-2] != shape[-1]: 231 | raise ValueError("MatRegion tensors must have shape " 232 | "[batch_size, num_mats, D, D], or [num_mats," 233 | " D, D] if batch size has already been set") 234 | 235 | super().__init__(mats, bond_str='bslr') 236 | 237 | def __mul__(self, edge_vec, rmul=False): 238 | """ 239 | Iteratively multiply an input vector with all matrices in MatRegion 240 | """ 241 | # The input must be an instance of EdgeVec 242 | if not isinstance(edge_vec, EdgeVec): 243 | return NotImplemented 244 | 245 | mats = self.tensor 246 | num_mats = mats.size(1) 247 | batch_size = mats.size(0) 248 | 249 | # Load our vector and matrix batches 250 | dummy_ind = 1 if rmul else 2 251 | vec = edge_vec.tensor.unsqueeze(dummy_ind) 252 | mat_list = [mat.squeeze(1) for mat in torch.chunk(mats, num_mats, 1)] 253 | 254 | # Do the repeated matrix-vector multiplications in the proper order 255 | log_norm = 0 256 | for i, mat in enumerate(mat_list[::(1 if rmul else -1)], 1): 257 | if rmul: 258 | vec = torch.bmm(vec, mat) 259 | else: 260 | vec = torch.bmm(mat, vec) 261 | 262 | # Since we only have a single vector, wrap it as a EdgeVec 263 | return EdgeVec(vec.squeeze(dummy_ind), is_left_vec=rmul) 264 | 265 | def __rmul__(self, edge_vec): 266 | return self.__mul__(edge_vec, rmul=True) 267 | 268 | def reduce(self): 269 | """ 270 | Multiplies together all matrices and returns resultant SingleMat 271 | 272 | This method uses iterated batch multiplication to evaluate the full 273 | matrix product in depth O( log(num_mats) ) 274 | """ 275 | mats = self.tensor 276 | shape = list(mats.shape) 277 | batch_size = mats.size(0) 278 | size, D = shape[1:3] 279 | 280 | # Iteratively multiply pairs of matrices until there is only one 281 | while size > 1: 282 | odd_size = (size % 2 == 1) 283 | half_size = size // 2 284 | nice_size = 2 * half_size 285 | 286 | even_mats = mats[:, 0:nice_size:2] 287 | odd_mats = mats[:, 1:nice_size:2] 288 | # For odd sizes, set aside one batch of matrices for the next round 289 | leftover = mats[:, nice_size:] 290 | 291 | # Multiply together all pairs of matrices (except leftovers) 292 | mats = torch.einsum('bslu,bsur->bslr', [even_mats, odd_mats]) 293 | mats = torch.cat([mats, leftover], 1) 294 | 295 | size = half_size + int(odd_size) 296 | 297 | # Since we only have a single matrix, wrap it as a SingleMat 298 | return SingleMat(mats.squeeze(1)) 299 | 300 | class OutputCore(Contractable): 301 | """ 302 | A single MPS core with a single output index 303 | """ 304 | def __init__(self, tensor, global_bs=512): 305 | # Check the input shape 306 | #self.global_bs = global_bs 307 | if len(tensor.shape) not in [3, 4]: 308 | raise ValueError("OutputCore tensors must have shape [batch_size, " 309 | "output_dim, D_l, D_r], or else [output_dim, D_l," 310 | " D_r] if batch 
size has already been set") 311 | 312 | super().__init__(tensor, bond_str='bolr') 313 | 314 | class SingleMat(Contractable): 315 | """ 316 | A batch of matrices associated with a single location in our MPS 317 | """ 318 | def __init__(self, mat): 319 | # Check the input shape 320 | if len(mat.shape) not in [2, 3]: 321 | raise ValueError("SingleMat tensors must have shape [batch_size, " 322 | "D_l, D_r], or else [D_l, D_r] if batch size " 323 | "has already been set") 324 | 325 | super().__init__(mat, bond_str='blr') 326 | 327 | class OutputMat(Contractable): 328 | """ 329 | An output core associated with an edge of our MPS 330 | """ 331 | def __init__(self, mat, is_left_mat): 332 | # Check the input shape 333 | if len(mat.shape) not in [2, 3]: 334 | raise ValueError("OutputMat tensors must have shape [batch_size, " 335 | "D, output_dim], or else [D, output_dim] if " 336 | "batch size has already been set") 337 | 338 | # OutputMats on left edge will have a right-facing bond, and vice versa 339 | bond_str = 'b' + ('r' if is_left_mat else 'l') + 'o' 340 | super().__init__(mat, bond_str=bond_str) 341 | 342 | def __mul__(self, edge_vec, rmul=False): 343 | """ 344 | Multiply with an edge vector along the shared linear index 345 | """ 346 | if not isinstance(edge_vec, EdgeVec): 347 | return NotImplemented 348 | else: 349 | return super().__mul__(edge_vec, rmul) 350 | 351 | def __rmul__(self, edge_vec): 352 | return self.__mul__(edge_vec, rmul=True) 353 | 354 | class EdgeVec(Contractable): 355 | """ 356 | A batch of vectors associated with an edge of our MPS 357 | 358 | EdgeVec instances are always associated with an edge of an MPS, which 359 | requires the is_left_vec flag to be set to True (vector on left edge) or 360 | False (vector on right edge) 361 | """ 362 | def __init__(self, vec, is_left_vec): 363 | # Check the input shape 364 | if len(vec.shape) not in [1, 2]: 365 | raise ValueError("EdgeVec tensors must have shape " 366 | "[batch_size, D], or else [D] if batch size " 367 | "has already been set") 368 | 369 | # EdgeVecs on left edge will have a right-facing bond, and vice versa 370 | bond_str = 'b' + ('r' if is_left_vec else 'l') 371 | super().__init__(vec, bond_str=bond_str) 372 | 373 | def __mul__(self, right_vec): 374 | """ 375 | Take the inner product of our vector with another vector 376 | """ 377 | # The input must be an instance of EdgeVec 378 | if not isinstance(right_vec, EdgeVec): 379 | return NotImplemented 380 | 381 | left_vec = self.tensor.unsqueeze(1) 382 | right_vec = right_vec.tensor.unsqueeze(2) 383 | batch_size = left_vec.size(0) 384 | 385 | # Do the batch inner product 386 | scalar = torch.bmm(left_vec, right_vec).view([batch_size]) 387 | 388 | # Since we only have a single scalar, wrap it as a Scalar 389 | return Scalar(scalar) 390 | 391 | class Scalar(Contractable): 392 | """ 393 | A batch of scalars 394 | """ 395 | def __init__(self, scalar): 396 | # Add dummy dimension if we have a torch scalar 397 | shape = list(scalar.shape) 398 | if shape == []: 399 | scalar = scalar.view([1]) 400 | shape = [1] 401 | 402 | # Check the input shape 403 | if len(shape) != 1: 404 | raise ValueError("input scalar must be a torch tensor with shape " 405 | "[batch_size], or [] or [1] if batch size has " 406 | "been set") 407 | 408 | super().__init__(scalar, bond_str='b') 409 | 410 | def __mul__(self, contractable): 411 | """ 412 | Multiply a contractable by our scalar and return the result 413 | """ 414 | scalar = self.tensor 415 | tensor = contractable.tensor 416 | bond_str = 
contractable.bond_str 417 | 418 | ein_string = f"{bond_str},b->{bond_str}" 419 | out_tensor = torch.einsum(ein_string, [tensor, scalar]) 420 | 421 | # Wrap the result in the same class contractable belongs to 422 | contract_class = type(contractable) 423 | if contract_class is not Contractable: 424 | return contract_class(out_tensor) 425 | else: 426 | return Contractable(out_tensor, bond_str) 427 | 428 | def __rmul__(self, contractable): 429 | # Scalar multiplication is commutative 430 | return self.__mul__(contractable) 431 | -------------------------------------------------------------------------------- /models/contractables.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timqqt/MERA_Image_Classification/e96211f45ade86f031a0d99ad0670231844ef3a1/models/contractables.pyc -------------------------------------------------------------------------------- /models/lotenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from models.mps import MPS, ReLUMPS 6 | import pdb 7 | 8 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 9 | 10 | EPS = 1e-6 11 | 12 | class MERAnet_clean(nn.Module): 13 | def __init__(self, input_dim, output_dim, bond_dim, feature_dim=2, nCh=3, 14 | kernel=[2, 2, 2], virtual_dim=1, 15 | adaptive_mode=False, periodic_bc=False, parallel_eval=False, 16 | label_site=None, path=None, init_std=1e-9, use_bias=True, 17 | fixed_bias=True, cutoff=1e-10, merge_threshold=2000): 18 | #bond_dim parameter is non-sense 19 | super().__init__() 20 | self.input_dim = input_dim 21 | self.virtual_dim = bond_dim 22 | 23 | ### Squeezing of spatial dimension in first step 24 | #self.kScale = 4 # what is this?
25 | #nCh = self.kScale ** 2 * nCh 26 | # self.input_dim = self.input_dim / self.kScale 27 | 28 | # print(nCh) 29 | self.nCh = nCh 30 | if isinstance(kernel, int): 31 | kernel = 3 * [kernel] 32 | self.ker = kernel 33 | 34 | num_layers = int(np.log2(input_dim[0]) - 2) # np.int was removed from NumPy; use the builtin 35 | self.num_layers = num_layers 36 | self.disentangler_list = nn.ModuleList() # nn.ModuleList (not a plain Python list) so the ReLUMPS parameters are registered and trained 37 | self.isometry_list = nn.ModuleList() 38 | iDim = (self.input_dim - 2) // 2 # integer division keeps iDim usable as a count/index 39 | 40 | for ii in range(num_layers): 41 | #feature_dim = 2 * nCh 42 | feature_dim = 2 * nCh 43 | # print(feature_dim) 44 | #First level disentanglers 45 | # First level isometries 46 | 47 | ### First level MERA blocks 48 | self.disentangler_list.append(nn.ModuleList([ReLUMPS(input_dim=4, 49 | output_dim=4, 50 | nCh=nCh, bond_dim=4, 51 | feature_dim=feature_dim, parallel_eval=parallel_eval, 52 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 53 | for i in range(torch.prod(iDim))])) 54 | 55 | iDim = iDim + 1 56 | 57 | self.isometry_list.append(nn.ModuleList([ReLUMPS(input_dim=4, 58 | output_dim=1, 59 | nCh=nCh, bond_dim=bond_dim, 60 | feature_dim=feature_dim, parallel_eval=parallel_eval, 61 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 62 | for i in range(torch.prod(iDim))])) 63 | iDim = (iDim-2) // 2 64 | 65 | 66 | ### Final MPS block 67 | self.mpsFinal = ReLUMPS(input_dim=49, 68 | output_dim=output_dim, nCh=1, 69 | bond_dim=bond_dim, feature_dim=feature_dim, 70 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, 71 | parallel_eval=parallel_eval) 72 | 73 | def forward(self, x): 74 | iDim = self.input_dim 75 | b = x.shape[0] # Batch size 76 | # Disentangler layer 77 | x_in = x 78 | 79 | for jj in range(self.num_layers): 80 | iDim = iDim // 2 81 | x_ent = x_in[:, :, 1:-1, 1:-1] 82 | # print('---------------') 83 | # print(x_ent.shape) 84 | #print('x_ent unfold shape: ', x_ent.unfold(2, 2, 2).unfold(3, 2, 2).shape) 85 | x_ent = x_ent.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 86 | #print('x_ent unfold->reshape shape: ', x_ent.shape) 87 | #print('single x_ent unfold->reshape shape: ', x_ent[:, :, 0].shape) 88 | # print(x_ent.shape) 89 | # print(len(self.disentangler_list[jj])) 90 | y_ent = [self.disentangler_list[jj][i](x_ent[:, :, i]) for i in range(len(self.disentangler_list[jj]))] 91 | y_ent = torch.stack(y_ent, dim=1) # 512, 3969, 4 92 | y_ent = y_ent.view(y_ent.shape[0], y_ent.shape[1], 2, 2) 93 | y_ent = y_ent.view(y_ent.shape[0], iDim[0]-1, iDim[1]-1, 2, 2) 94 | y_ent_list = [] 95 | #print('y_ent shape: ', y_ent.shape) 96 | #print('torch cat col shape: ', torch.cat([y_ent[:, 0, i, :, :] for i in range(y_ent.shape[2])], dim=2).shape) 97 | for j in range(y_ent.shape[1]): 98 | y_ent_list.append(torch.cat([y_ent[:, j, i, :, :] for i in range(y_ent.shape[2])], dim=2)) 99 | y_ent = torch.cat(y_ent_list, dim=1) 100 | 101 | #print('y_ent shape: ', y_ent.shape) 102 | x_iso = x_in 103 | #print(x_iso.shape) 104 | x_iso[:, :, 1:-1, 1:-1] = y_ent.view(b, self.nCh, y_ent.shape[1], y_ent.shape[2]) 105 | #print('x_iso shape: ', x_iso.shape) 106 | x_iso = x_iso.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 107 | y_iso = [self.isometry_list[jj][i](x_iso[:, :, i]) for i in range(len(self.isometry_list[jj]))] 108 | y_iso = torch.stack(y_iso, dim=1) # 512, 4096, 1 109 | 110 | x_in = y_iso.view(y_iso.shape[0], self.nCh, iDim[0], iDim[1]) 111 | 112 | #print('x6 shape: ', x6.shape) # 512, 1, 2, 2 113 | 114 | y = x_in.view(b, self.nCh, iDim[0] * iDim[1]) 115 | #print('LoTe y shape before mpsfinal ', y.shape) 116 | y = self.mpsFinal(y) 117 | return y.squeeze() 118
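# --- Editor's note (added) -------------------------------------------------
# MERAnet below is a manually unrolled six-level variant of MERAnet_clean
# above: each Disentangler_k / Isometry_k pair repeats one iteration of the
# num_layers loop. The two heads differ: MERAnet_clean's final MPS takes
# input_dim=49 (a 7x7 map), while the unrolled MERAnet's takes input_dim=4
# (a 2x2 map after six halvings), so the usable input sizes differ
# accordingly. The unrolled class also keeps the plain float division
# "iDim = (self.input_dim-2) / 2" that was switched to integer division in
# MERAnet_clean above.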
| 119 | class MERAnet(nn.Module): 120 | def __init__(self, input_dim, output_dim, bond_dim, feature_dim=2, nCh=3, 121 | kernel=[2, 2, 2], virtual_dim=1, 122 | adaptive_mode=False, periodic_bc=False, parallel_eval=False, 123 | label_site=None, path=None, init_std=1e-9, use_bias=True, 124 | fixed_bias=True, cutoff=1e-10, merge_threshold=2000): 125 | #bond_dim parameter is non-sense 126 | super().__init__() 127 | self.input_dim = input_dim 128 | self.virtual_dim = bond_dim 129 | 130 | ### Squeezing of spatial dimension in first step 131 | #self.kScale = 4 # what is this? 132 | #nCh = self.kScale ** 2 * nCh 133 | # self.input_dim = self.input_dim / self.kScale 134 | 135 | # print(nCh) 136 | self.nCh = nCh 137 | if isinstance(kernel, int): 138 | kernel = 3 * [kernel] 139 | self.ker = kernel 140 | 141 | 142 | iDim = (self.input_dim-2) / 2 143 | 144 | #feature_dim = 2 * nCh 145 | feature_dim = 2 * nCh 146 | # print(feature_dim) 147 | #First level disentanglers 148 | # First level isometries 149 | 150 | ### First level MERA blocks 151 | self.Disentangler_1 = nn.ModuleList([ReLUMPS(input_dim=4, 152 | output_dim=4, 153 | nCh=nCh, bond_dim=4, 154 | feature_dim=feature_dim, parallel_eval=parallel_eval, 155 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 156 | for i in range(torch.prod(iDim))]) 157 | 158 | iDim = iDim + 1 159 | 160 | self.Isometry_1 = nn.ModuleList([ReLUMPS(input_dim=4, 161 | output_dim=1, 162 | nCh=nCh, bond_dim=bond_dim, 163 | feature_dim=feature_dim, parallel_eval=parallel_eval, 164 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 165 | for i in range(torch.prod(iDim))]) 166 | 167 | 168 | iDim = (iDim-2) / 2 169 | 170 | ### Second level MERA blocks 171 | self.Disentangler_2 = nn.ModuleList([ReLUMPS(input_dim=4, 172 | output_dim=4, 173 | nCh=nCh, bond_dim=4, 174 | feature_dim=feature_dim, parallel_eval=parallel_eval, 175 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 176 | for i in range(torch.prod(iDim))]) 177 | 178 | iDim = iDim + 1 179 | 180 | self.Isometry_2 = nn.ModuleList([ReLUMPS(input_dim=4, 181 | output_dim=1, 182 | nCh=nCh, bond_dim=bond_dim, 183 | feature_dim=feature_dim, parallel_eval=parallel_eval, 184 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 185 | for i in range(torch.prod(iDim))]) 186 | 187 | iDim = (iDim - 2) / 2 188 | 189 | ### 3rd level MERA blocks 190 | self.Disentangler_3 = nn.ModuleList([ReLUMPS(input_dim=4, 191 | output_dim=4, 192 | nCh=nCh, bond_dim=4, 193 | feature_dim=feature_dim, parallel_eval=parallel_eval, 194 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 195 | for i in range(torch.prod(iDim))]) 196 | 197 | iDim = iDim + 1 198 | 199 | self.Isometry_3 = nn.ModuleList([ReLUMPS(input_dim=4, 200 | output_dim=1, 201 | nCh=nCh, bond_dim=bond_dim, 202 | feature_dim=feature_dim, parallel_eval=parallel_eval, 203 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 204 | for i in range(torch.prod(iDim))]) 205 | iDim = (iDim - 2) / 2 206 | 207 | ### 4th level MERA blocks 208 | self.Disentangler_4 = nn.ModuleList([ReLUMPS(input_dim=4, 209 | output_dim=4, 210 | nCh=nCh, bond_dim=4, 211 | feature_dim=feature_dim, parallel_eval=parallel_eval, 212 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 213 | for i in range(torch.prod(iDim))]) 214 | 215 | iDim = iDim + 1 216 | 217 | self.Isometry_4 = nn.ModuleList([ReLUMPS(input_dim=4, 218 | output_dim=1, 219 | nCh=nCh, bond_dim=bond_dim, 220 | feature_dim=feature_dim, parallel_eval=parallel_eval, 221 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 222 | for i in 
range(torch.prod(iDim))]) 223 | iDim = (iDim - 2) / 2 224 | 225 | ### 5th level MERA blocks 226 | self.Disentangler_5 = nn.ModuleList([ReLUMPS(input_dim=4, 227 | output_dim=4, 228 | nCh=nCh, bond_dim=4, 229 | feature_dim=feature_dim, parallel_eval=parallel_eval, 230 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 231 | for i in range(torch.prod(iDim))]) 232 | 233 | iDim = iDim + 1 234 | 235 | self.Isometry_5 = nn.ModuleList([ReLUMPS(input_dim=4, 236 | output_dim=1, 237 | nCh=nCh, bond_dim=bond_dim, 238 | feature_dim=feature_dim, parallel_eval=parallel_eval, 239 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 240 | for i in range(torch.prod(iDim))]) 241 | iDim = (iDim - 2) / 2 242 | ### 6th level MERA blocks 243 | self.Disentangler_6 = nn.ModuleList([ReLUMPS(input_dim=4, 244 | output_dim=4, 245 | nCh=nCh, bond_dim=4, 246 | feature_dim=feature_dim, parallel_eval=parallel_eval, 247 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 248 | for i in range(torch.prod(iDim))]) 249 | 250 | iDim = iDim + 1 251 | 252 | self.Isometry_6 = nn.ModuleList([ReLUMPS(input_dim=4, 253 | output_dim=1, 254 | nCh=nCh, bond_dim=bond_dim, 255 | feature_dim=feature_dim, parallel_eval=parallel_eval, 256 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 257 | for i in range(torch.prod(iDim))]) 258 | 259 | 260 | iDim = (iDim - 2) / 2 261 | 262 | ### Final MPS block 263 | self.mpsFinal = ReLUMPS(input_dim=4, 264 | output_dim=output_dim, nCh=1, 265 | bond_dim=bond_dim, feature_dim=feature_dim, 266 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, 267 | parallel_eval=parallel_eval) 268 | 269 | def forward(self, x): 270 | iDim = self.input_dim // 2 271 | b = x.shape[0] # Batch size 272 | # Disentangler layer 273 | x_ent = x[:, :, 1:-1, 1:-1] 274 | #print('x_ent unfold shape: ', x_ent.unfold(2, 2, 2).unfold(3, 2, 2).shape) 275 | x_ent = x_ent.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 276 | #print('x_ent unfold->reshape shape: ', x_ent.shape) 277 | #print('single x_ent unfold->reshape shape: ', x_ent[:, :, 0].shape) 278 | 279 | y_ent = [self.Disentangler_1[i](x_ent[:, :, i]) for i in range(len(self.Disentangler_1))] 280 | y_ent = torch.stack(y_ent, dim=1) # 512, 3969, 4 281 | y_ent = y_ent.view(y_ent.shape[0], y_ent.shape[1], 2, 2) 282 | y_ent = y_ent.view(y_ent.shape[0], iDim[0]-1, iDim[1]-1, 2, 2) 283 | y_ent_list = [] 284 | #print('y_ent shape: ', y_ent.shape) 285 | #print('torch cat col shape: ', torch.cat([y_ent[:, 0, i, :, :] for i in range(y_ent.shape[2])], dim=2).shape) 286 | for j in range(y_ent.shape[1]): 287 | y_ent_list.append(torch.cat([y_ent[:, j, i, :, :] for i in range(y_ent.shape[2])], dim=2)) 288 | y_ent = torch.cat(y_ent_list, dim=1) 289 | 290 | #print('y_ent shape: ', y_ent.shape) 291 | x_iso = x 292 | x_iso[:, :, 1:-1, 1:-1] = y_ent.view(b, self.nCh, y_ent.shape[1], y_ent.shape[2]) 293 | #print('x_iso shape: ', x_iso.shape) 294 | x_iso = x_iso.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 295 | y_iso = [self.Isometry_1[i](x_iso[:, :, i]) for i in range(len(self.Isometry_1))] 296 | y_iso = torch.stack(y_iso, dim=1) # 512, 4096, 1 297 | 298 | x1 = y_iso.view(y_iso.shape[0], self.nCh, iDim[0], iDim[1]) 299 | 300 | #print('x1 shape: ', x1.shape) 301 | 302 | iDim = iDim // 2 303 | x_ent = x1[:, :, 1:-1, 1:-1] 304 | x_ent = x_ent.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 305 | y_ent = [self.Disentangler_2[i](x_ent[:, :, i]) for i in range(len(self.Disentangler_2))] 306 | y_ent = torch.stack(y_ent, dim=1) # 512, 3969, 4 307 | 
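# Editor's note (added): the view/cat sequence below rebuilds an image from
# the per-patch disentangler outputs. Each length-4 output is viewed as a
# 2x2 tile; the tiles in one row are concatenated along width (dim=2) and
# the rows are then concatenated along height (dim=1), giving a map of size
# (2*(iDim[0]-1), 2*(iDim[1]-1)) that is written back into the interior
# slice x_iso[:, :, 1:-1, 1:-1].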
y_ent = y_ent.view(y_ent.shape[0], y_ent.shape[1], 2, 2) 308 | y_ent = y_ent.view(y_ent.shape[0], iDim[0] - 1, iDim[1] - 1, 2, 2) 309 | y_ent_list = [] 310 | for j in range(iDim[0] - 1): 311 | y_ent_list.append(torch.cat([y_ent[:, j, i, :, :] for i in range(iDim[1] - 1)], dim=2)) 312 | y_ent = torch.cat(y_ent_list, dim=1) 313 | 314 | x_iso = x1 315 | x_iso[:, :, 1:-1, 1:-1] = y_ent.view(b, self.nCh, y_ent.shape[1], y_ent.shape[2]) 316 | x_iso = x_iso.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 317 | y_iso = [self.Isometry_2[i](x_iso[:, :, i]) for i in range(len(self.Isometry_2))] 318 | y_iso = torch.stack(y_iso, dim=1) # 512, 4096, 1 319 | 320 | x2 = y_iso.view(y_iso.shape[0], self.nCh, iDim[0], iDim[1]) 321 | 322 | #print('x2 shape: ', x2.shape) 323 | 324 | iDim = iDim // 2 325 | x_ent = x2[:, :, 1:-1, 1:-1] 326 | x_ent = x_ent.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 327 | y_ent = [self.Disentangler_3[i](x_ent[:, :, i]) for i in range(len(self.Disentangler_3))] 328 | y_ent = torch.stack(y_ent, dim=1) # 512, 3969, 4 329 | y_ent = y_ent.view(y_ent.shape[0], y_ent.shape[1], 2, 2) 330 | y_ent = y_ent.view(y_ent.shape[0], iDim[0] - 1, iDim[1] - 1, 2, 2) 331 | y_ent_list = [] 332 | for j in range(iDim[0] - 1): 333 | y_ent_list.append(torch.cat([y_ent[:, j, i, :, :] for i in range(iDim[1] - 1)], dim=2)) 334 | y_ent = torch.cat(y_ent_list, dim=1) 335 | 336 | x_iso = x2 337 | x_iso[:, :, 1:-1, 1:-1] = y_ent.view(b, self.nCh, y_ent.shape[1], y_ent.shape[2]) 338 | x_iso = x_iso.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 339 | y_iso = [self.Isometry_3[i](x_iso[:, :, i]) for i in range(len(self.Isometry_3))] 340 | y_iso = torch.stack(y_iso, dim=1) # 512, 4096, 1 341 | 342 | x3 = y_iso.view(y_iso.shape[0], self.nCh, iDim[0], iDim[1]) 343 | 344 | #print('x3 shape: ', x3.shape) 345 | 346 | iDim = iDim // 2 347 | x_ent = x3[:, :, 1:-1, 1:-1] 348 | x_ent = x_ent.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 349 | y_ent = [self.Disentangler_4[i](x_ent[:, :, i]) for i in range(len(self.Disentangler_4))] 350 | y_ent = torch.stack(y_ent, dim=1) # 512, 3969, 4 351 | y_ent = y_ent.view(y_ent.shape[0], y_ent.shape[1], 2, 2) 352 | y_ent = y_ent.view(y_ent.shape[0], iDim[0] - 1, iDim[1] - 1, 2, 2) 353 | y_ent_list = [] 354 | for j in range(iDim[0] - 1): 355 | y_ent_list.append(torch.cat([y_ent[:, j, i, :, :] for i in range(iDim[1] - 1)], dim=2)) 356 | y_ent = torch.cat(y_ent_list, dim=1) 357 | 358 | x_iso = x3 359 | x_iso[:, :, 1:-1, 1:-1] = y_ent.view(b, self.nCh, y_ent.shape[1], y_ent.shape[2]) 360 | x_iso = x_iso.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 361 | y_iso = [self.Isometry_4[i](x_iso[:, :, i]) for i in range(len(self.Isometry_4))] 362 | y_iso = torch.stack(y_iso, dim=1) # 512, 4096, 1 363 | 364 | x4 = y_iso.view(y_iso.shape[0], self.nCh, iDim[0], iDim[1]) 365 | 366 | #print('x4 shape: ', x4.shape) 367 | 368 | iDim = iDim // 2 369 | x_ent = x4[:, :, 1:-1, 1:-1] 370 | x_ent = x_ent.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 371 | y_ent = [self.Disentangler_5[i](x_ent[:, :, i]) for i in range(len(self.Disentangler_5))] 372 | y_ent = torch.stack(y_ent, dim=1) # 512, 3969, 4 373 | y_ent = y_ent.view(y_ent.shape[0], y_ent.shape[1], 2, 2) 374 | y_ent = y_ent.view(y_ent.shape[0], iDim[0] - 1, iDim[1] - 1, 2, 2) 375 | y_ent_list = [] 376 | for j in range(iDim[0] - 1): 377 | y_ent_list.append(torch.cat([y_ent[:, j, i, :, :] for i in range(iDim[1] - 1)], dim=2)) 378 | y_ent = torch.cat(y_ent_list, dim=1) 379 | 
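# Editor's note (added): "x_iso = x4" below does not copy the tensor, so the
# slice assignment into x_iso[:, :, 1:-1, 1:-1] mutates x4 in place; the
# same aliasing pattern appears at every level of this forward pass.
# In-place writes to tensors that autograd still needs can raise runtime
# errors or corrupt gradients; a defensive variant, if that becomes a
# problem, is "x_iso = x4.clone()".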
380 | x_iso = x4 381 | x_iso[:, :, 1:-1, 1:-1] = y_ent.view(b, self.nCh, y_ent.shape[1], y_ent.shape[2]) 382 | x_iso = x_iso.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 383 | y_iso = [self.Isometry_5[i](x_iso[:, :, i]) for i in range(len(self.Isometry_5))] 384 | y_iso = torch.stack(y_iso, dim=1) # 512, 4096, 1 385 | 386 | x5 = y_iso.view(y_iso.shape[0], self.nCh, iDim[0], iDim[1]) 387 | 388 | #print('x5 shape: ', x5.shape) 389 | 390 | iDim = iDim // 2 391 | x_ent = x5[:, :, 1:-1, 1:-1] 392 | x_ent = x_ent.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 393 | #print(x_ent.shape) 394 | #print(len(self.Disentangler_6)) 395 | y_ent = [self.Disentangler_6[i](x_ent[:, :, i]) for i in range(len(self.Disentangler_6))] 396 | y_ent = torch.stack(y_ent, dim=1) # 512, 3969, 4 397 | y_ent = y_ent.view(y_ent.shape[0], y_ent.shape[1], 2, 2) 398 | y_ent = y_ent.view(y_ent.shape[0], iDim[0] - 1, iDim[1] - 1, 2, 2) 399 | y_ent_list = [] 400 | for j in range(iDim[0] - 1): 401 | y_ent_list.append(torch.cat([y_ent[:, j, i, :, :] for i in range(iDim[1] - 1)], dim=2)) 402 | y_ent = torch.cat(y_ent_list, dim=1) 403 | 404 | x_iso = x5 405 | x_iso[:, :, 1:-1, 1:-1] = y_ent.view(b, self.nCh, y_ent.shape[1], y_ent.shape[2]) 406 | x_iso = x_iso.unfold(2, 2, 2).unfold(3, 2, 2).reshape(b, self.nCh, -1, 4) 407 | y_iso = [self.Isometry_6[i](x_iso[:, :, i]) for i in range(len(self.Isometry_6))] 408 | y_iso = torch.stack(y_iso, dim=1) # 512, 4096, 1 409 | 410 | x6 = y_iso.view(y_iso.shape[0], self.nCh, iDim[0], iDim[1]) 411 | 412 | #print('x6 shape: ', x6.shape) # 512, 1, 2, 2 413 | 414 | y = x6.view(b, self.nCh, iDim[0] * iDim[1]) 415 | # print('LoTe y shape before mpsfinal ', y.shape) 416 | y = self.mpsFinal(y) 417 | return y.squeeze() 418 | 419 | 420 | class loTeNet(nn.Module): 421 | def __init__(self, input_dim, output_dim, bond_dim, feature_dim=2, nCh=3, 422 | kernel=[2, 2, 2], virtual_dim=1, 423 | adaptive_mode=False, periodic_bc=False, parallel_eval=False, 424 | label_site=None, path=None, init_std=1e-9, use_bias=True, 425 | fixed_bias=True, cutoff=1e-10, merge_threshold=2000): 426 | super().__init__() 427 | self.input_dim = input_dim 428 | self.virtual_dim = bond_dim 429 | 430 | ### Squeezing of spatial dimension in first step 431 | self.kScale = 4 # what is this? 
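        # kScale is the spatial squeeze factor used by LoTeNet: every
        # kScale x kScale block of pixels is folded into the channel axis,
        # so nCh grows by kScale**2 while each spatial side shrinks by
        # kScale. The squeeze itself is presumably performed by the
        # unfold/reshape in forward() (or upstream in the data pipeline);
        # the equivalent stand-alone op would be a pixel unshuffle:
        #
        #   x = torch.nn.functional.pixel_unshuffle(x, downscale_factor=4)
        #   # (b, c, H, W) -> (b, 16 * c, H // 4, W // 4)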
432 | nCh = self.kScale**2 * nCh 433 | self.input_dim = self.input_dim/self.kScale 434 | 435 | #print(nCh) 436 | self.nCh = nCh 437 | if isinstance(kernel, int): 438 | kernel = 3 * [kernel] 439 | self.ker = kernel 440 | iDim = (self.input_dim/(self.ker[0])) 441 | 442 | feature_dim = 2*nCh 443 | #print(feature_dim) 444 | ### First level MPS blocks 445 | self.module1 = nn.ModuleList([ MPS(input_dim=(self.ker[0])**2, 446 | output_dim=self.virtual_dim, 447 | nCh=nCh, bond_dim=bond_dim, 448 | feature_dim=feature_dim, parallel_eval=parallel_eval, 449 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 450 | for i in range(torch.prod(iDim))]) 451 | 452 | self.BN1 = nn.BatchNorm1d(torch.prod(iDim).numpy(),affine=True) 453 | 454 | 455 | iDim = iDim/self.ker[1] 456 | feature_dim = 2*self.virtual_dim 457 | 458 | ### Second level MPS blocks 459 | self.module2 = nn.ModuleList([ MPS(input_dim=self.ker[1]**2, 460 | output_dim=self.virtual_dim, 461 | nCh=self.virtual_dim, bond_dim=bond_dim, 462 | feature_dim=feature_dim, parallel_eval=parallel_eval, 463 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 464 | for i in range(torch.prod(iDim))]) 465 | 466 | self.BN2 = nn.BatchNorm1d(torch.prod(iDim).numpy(),affine=True) 467 | 468 | iDim = iDim/self.ker[2] 469 | 470 | ### Third level MPS blocks 471 | self.module3 = nn.ModuleList([ MPS(input_dim=self.ker[2]**2, 472 | output_dim=self.virtual_dim, 473 | nCh=self.virtual_dim, bond_dim=bond_dim, 474 | feature_dim=feature_dim, parallel_eval=parallel_eval, 475 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 476 | for i in range(torch.prod(iDim))]) 477 | 478 | self.BN3 = nn.BatchNorm1d(torch.prod(iDim).numpy(),affine=True) 479 | 480 | ### Final MPS block 481 | self.mpsFinal = MPS(input_dim=len(self.module3), 482 | output_dim=output_dim, nCh=1, 483 | bond_dim=bond_dim, feature_dim=feature_dim, 484 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, 485 | parallel_eval=parallel_eval) 486 | 487 | def forward(self,x): 488 | 489 | b = x.shape[0] #Batch size 490 | iDim = self.input_dim/(self.ker[0]) 491 | #print(self.input_dim) 492 | #print(x.shape) 493 | #print(torch.prod(iDim)) 494 | #print(self.nCh) 495 | # Level 1 contraction 496 | #print(x.unfold(2,iDim[0],iDim[0]).unfold(3,iDim[1],iDim[1]).shape) 497 | x = x.unfold(2,iDim[0],iDim[0]).unfold(3,iDim[1],iDim[1]).reshape(b, 498 | self.nCh,-1,(self.ker[0])**2) 499 | # print(x.shape) 500 | #print('x[:, :, 0].shape', x[:, :, 0].shape) 501 | #assert False 502 | #print('LoTe x total shape layer 1: ', x.shape) 503 | #print('LoTe x shape layer 1: ', x[:, :, 0].shape) 504 | y = [ self.module1[i](x[:,:,i]) for i in range(len(self.module1))] 505 | y = torch.stack(y,dim=1) 506 | #print(y.shape) 507 | #assert False 508 | y = self.BN1(y).unsqueeze(1) 509 | 510 | # Level 2 contraction 511 | 512 | y = y.view(b,self.virtual_dim,iDim[0],iDim[1]) 513 | iDim = (iDim/self.ker[1]) 514 | y = y.unfold(2,iDim[0],iDim[0]).unfold(3,iDim[1], 515 | iDim[1]).reshape(b,self.virtual_dim,-1,self.ker[1]**2) 516 | #print('LoTe x total shape layer 2: ', y.shape) 517 | #print('LoTe x shape layer 2: ', y[:, :, 0].shape) 518 | x = [ self.module2[i](y[:,:,i]) for i in range(len(self.module2))] 519 | x = torch.stack(x,dim=1) 520 | x = self.BN2(x).unsqueeze(1) 521 | 522 | 523 | # Level 3 contraction 524 | x = x.view(b,self.virtual_dim,iDim[0],iDim[1]) 525 | iDim = (iDim/self.ker[2]) 526 | x = x.unfold(2,iDim[0],iDim[0]).unfold(3,iDim[1], 527 | iDim[1]).reshape(b,self.virtual_dim,-1,self.ker[2]**2) 528 | #print('LoTe x total shape layer 3: 
', x.shape) 529 | #print('LoTe x shape layer 3: ', x[:, :, 0].shape) 530 | y = [self.module3[i](x[:,:,i]) for i in range(len(self.module3))] 531 | 532 | y = torch.stack(y,dim=1) 533 | y = self.BN3(y) 534 | 535 | if self.virtual_dim == 1: 536 | y = y.unsqueeze(2) 537 | if y.shape[1] > 1: 538 | # Final layer 539 | y = y.permute(0,2,1) 540 | #print('LoTe y shape before mpsfinal ', y.shape) 541 | y = self.mpsFinal(y) 542 | return y.squeeze() 543 | 544 | 545 | class ConvTeNet(nn.Module): 546 | def __init__(self, input_dim, output_dim, bond_dim, feature_dim=2, nCh=3, 547 | kernel=[2, 2, 2], virtual_dim=1, 548 | adaptive_mode=False, periodic_bc=False, parallel_eval=False, 549 | label_site=None, path=None, init_std=1e-9, use_bias=True, 550 | fixed_bias=True, cutoff=1e-10, merge_threshold=2000): 551 | super().__init__() 552 | self.input_dim = input_dim 553 | self.virtual_dim = bond_dim 554 | 555 | ### Squeezing of spatial dimension in first step 556 | #self.kScale = 4 557 | #nCh = self.kScale ** 2 * nCh 558 | self.input_dim = self.input_dim 559 | 560 | self.nCh = nCh 561 | if isinstance(kernel, int): 562 | kernel = 3 * [kernel] 563 | self.ker = kernel 564 | iDim = (self.input_dim / (self.ker[0])) 565 | #feature_dim = 2 * nCh 566 | feature_dim = 2 * self.ker[0] ** 2 567 | #print(feature_dim) 568 | ### First level MPS blocks 569 | #(self.ker[0]) ** 2, 570 | self.module1 = nn.ModuleList([MPS(input_dim=nCh, 571 | output_dim=self.virtual_dim, 572 | nCh=nCh, bond_dim=bond_dim, 573 | feature_dim=feature_dim, parallel_eval=parallel_eval, 574 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 575 | for i in range(torch.prod(iDim))]) 576 | 577 | self.BN1 = nn.BatchNorm1d(torch.prod(iDim).numpy(), affine=True) 578 | 579 | iDim = iDim / self.ker[1] 580 | feature_dim = 2 * self.ker[1] ** 2 581 | 582 | ### Second level MPS blocks 583 | self.module2 = nn.ModuleList([MPS(input_dim=self.virtual_dim, 584 | output_dim=self.virtual_dim, 585 | nCh=self.virtual_dim, bond_dim=bond_dim, 586 | feature_dim=feature_dim, parallel_eval=parallel_eval, 587 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 588 | for i in range(torch.prod(iDim))]) 589 | 590 | self.BN2 = nn.BatchNorm1d(torch.prod(iDim).numpy(), affine=True) 591 | 592 | iDim = iDim / self.ker[2] 593 | feature_dim = 2 * self.ker[2] ** 2 594 | ### Third level MPS blocks 595 | self.module3 = nn.ModuleList([MPS(input_dim=self.virtual_dim, 596 | output_dim=self.virtual_dim, 597 | nCh=self.virtual_dim, bond_dim=bond_dim, 598 | feature_dim=feature_dim, parallel_eval=parallel_eval, 599 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 600 | for i in range(torch.prod(iDim))]) 601 | 602 | self.BN3 = nn.BatchNorm1d(torch.prod(iDim).numpy(), affine=True) 603 | feature_dim = 2 * self.virtual_dim 604 | ### Final MPS block 605 | self.mpsFinal = MPS(input_dim=len(self.module3), 606 | output_dim=output_dim, nCh=1, 607 | bond_dim=bond_dim, feature_dim=feature_dim, 608 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, 609 | parallel_eval=parallel_eval) 610 | 611 | def forward(self, x): 612 | 613 | b = x.shape[0] # Batch size 614 | H, W = x.shape[2], x.shape[3] 615 | iDim = self.input_dim / (self.ker[0]) 616 | #print('iDim: ', iDim) 617 | #print('x.shape: ',x.shape) 618 | # Level 1 contraction 619 | x_org = x 620 | x = x.unfold(2, self.ker[0], self.ker[0]).unfold(3, self.ker[0], self.ker[0]).reshape(b, (self.ker[0]) ** 2, -1, self.nCh) 621 | ### 622 | x_unfold = x_org.unfold(2, self.ker[0], self.ker[0]).unfold(3, self.ker[0], self.ker[0]).reshape(b, (self.ker[0]) 
** 2, -1, self.nCh)
623 |         #x_unfold = x_unfold.view(x_unfold.shape[0], x_unfold.shape[2], x_unfold.shape[1])
624 |         #print('unfolded x shape: ', x_unfold.shape)
625 |         #x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=(self.ker[0], self.ker[0]))
626 |         #print('x_fold shape: ', x_fold)
627 |         #assert False
628 |         #print('x.shape: ',x.shape)
629 |         #print(x.)
630 |         #assert False
631 |         # print(iDim)
632 |         #x = torch.cat([x, x], dim=3)
633 |         #print('x[:, :, 0].shape', x[:, :, 0].shape)
634 |         #print('Conv x total shape layer 1: ', x.shape)
635 |         #print('Conv x shape layer 1: ', x[:, :, 0].shape)
636 |         #print('After contraction x_i: ', self.module1[0](x[:, :, 0]).shape)
637 |         y = [self.module1[i](x[:, :, i]) for i in range(len(self.module1))]
638 |         y = torch.stack(y, dim=1)
639 |         #print(y.shape)
640 |         #assert False
641 |         y = self.BN1(y).unsqueeze(1)
642 | 
643 |         # Level 2 contraction
644 |         #print(y.shape)
645 |         #iDim = (iDim / self.ker[1])
646 |         y = y.view(b, self.virtual_dim, iDim[0], iDim[1])
647 |         iDim = (iDim / self.ker[1])
648 |         y = y.unfold(2, self.ker[1], self.ker[1]).unfold(3, self.ker[1], self.ker[1]).reshape(b, self.ker[1] ** 2, -1, self.virtual_dim)
649 |         #print(y.shape)
650 |         #print(y[:, :, 0].shape)
651 |         #print('Conv x total shape layer 2: ', y.shape)
652 |         #print('Conv x shape layer 2: ', y[:, :, 0].shape)
653 | 
654 |         x = [self.module2[i](y[:, :, i]) for i in range(len(self.module2))]
655 |         #assert False
656 | 
657 |         x = torch.stack(x, dim=1)
658 |         #print(x.shape)
659 | 
660 |         x = self.BN2(x).unsqueeze(1)
661 | 
662 |         # Level 3 contraction
663 |         x = x.view(b, self.virtual_dim, iDim[0], iDim[1])
664 |         iDim = (iDim / self.ker[2])
665 |         x = x.unfold(2, self.ker[2], self.ker[2]).unfold(3, self.ker[2], self.ker[2]).reshape(b, self.ker[2] ** 2, -1, self.virtual_dim)
666 |         #print('x[:, :, 0].shape before module3: ', x[:, :, 0].shape)
667 |         #print('Conv x total shape layer 3: ', x.shape)
668 |         #print('Conv x shape layer 3: ', x[:, :, 0].shape)
669 |         y = [self.module3[i](x[:, :, i]) for i in range(len(self.module3))]
670 | 
671 |         y = torch.stack(y, dim=1)
672 |         #print(y.shape)
673 |         y = self.BN3(y)
674 | 
675 |         if self.virtual_dim == 1:
676 |             y = y.unsqueeze(2)
677 |         if y.shape[1] > 1:
678 |             # Final layer
679 |             y = y.permute(0, 2, 1)
680 |         #print('Conv y shape before mpsfinal ', y.shape)
681 |         y = self.mpsFinal(y)
682 |         return y.squeeze()
683 | 
684 | 
685 | class Combined_LoTeConv(nn.Module):
686 |     def __init__(self, input_dim, output_dim, bond_dim, feature_dim=2, nCh=3,
687 |                  kernel1=[2, 2, 2], kernel2=[2, 2, 2], virtual_dim=1,
688 |                  adaptive_mode=False, periodic_bc=False, parallel_eval=False,
689 |                  label_site=None, path=None, init_std=1e-9, use_bias=True,
690 |                  fixed_bias=True, cutoff=1e-10, merge_threshold=2000):
691 |         super().__init__()
692 |         self.input_dim = input_dim
693 |         self.virtual_dim = bond_dim
694 | 
695 |         ### Squeezing of spatial dimension in first step
696 |         self.LoTe_kScale = 4  # spatial squeeze factor (see note below)
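        # As in loTeNet above, LoTe_kScale is the spatial squeeze factor for
        # the LoTeNet branch: nCh is scaled by LoTe_kScale**2 while the
        # spatial side length is divided by LoTe_kScale, trading resolution
        # for channel depth. The Conv branch below keeps the original
        # resolution instead.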
697 | LoTe_nCh = self.LoTe_kScale ** 2 * nCh 698 | self.LoTe_input_dim = self.input_dim / self.LoTe_kScale 699 | 700 | # print(nCh) 701 | self.LoTe_nCh = LoTe_nCh 702 | if isinstance(kernel1, int): 703 | LoTe_kernel = 3 * [kernel1] 704 | else: 705 | LoTe_kernel = kernel1 706 | self.LoTe_ker = LoTe_kernel 707 | LoTe_iDim = (self.LoTe_input_dim / (self.LoTe_ker[0])) 708 | 709 | LoTe_feature_dim = 2 * LoTe_nCh 710 | 711 | ## Parameters for Conv 712 | self.Conv_input_dim = self.input_dim 713 | 714 | self.Conv_nCh = nCh 715 | if isinstance(kernel2, int): 716 | Conv_kernel = 3 * [kernel2] 717 | else: 718 | Conv_kernel = kernel2 719 | self.Conv_ker = Conv_kernel 720 | Conv_iDim = (self.Conv_input_dim / (self.Conv_ker[0])) 721 | Conv_feature_dim = 2 * self.Conv_ker[0] ** 2 722 | 723 | ### First level MPS blocks 724 | self.LoTe_module1 = nn.ModuleList([MPS(input_dim=(self.LoTe_ker[0]) ** 2, 725 | output_dim=self.virtual_dim, 726 | nCh=nCh, bond_dim=bond_dim, 727 | feature_dim=LoTe_feature_dim, parallel_eval=parallel_eval, 728 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 729 | for i in range(torch.prod(LoTe_iDim))]) 730 | 731 | self.LoTe_BN1 = nn.BatchNorm1d(torch.prod(LoTe_iDim).numpy(), affine=True) 732 | 733 | LoTe_iDim = LoTe_iDim / self.LoTe_ker[1] 734 | LoTe_feature_dim = 2 * self.virtual_dim 735 | 736 | ### Second level MPS blocks 737 | self.LoTe_module2 = nn.ModuleList([MPS(input_dim=self.LoTe_ker[1] ** 2 + self.Conv_ker[1] ** 2, 738 | output_dim=self.virtual_dim, 739 | nCh=self.virtual_dim, bond_dim=bond_dim, 740 | feature_dim=LoTe_feature_dim, parallel_eval=parallel_eval, 741 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 742 | for i in range(torch.prod(LoTe_iDim))]) 743 | 744 | self.LoTe_BN2 = nn.BatchNorm1d(torch.prod(LoTe_iDim).numpy(), affine=True) 745 | LoTe_iDim = LoTe_iDim / self.LoTe_ker[2] 746 | 747 | ### Third level MPS blocks 748 | self.LoTe_module3 = nn.ModuleList([MPS(input_dim=self.LoTe_ker[2] ** 2, 749 | output_dim=self.virtual_dim, 750 | nCh=self.virtual_dim, bond_dim=bond_dim, 751 | feature_dim=2 * LoTe_feature_dim, parallel_eval=parallel_eval, 752 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 753 | for i in range(torch.prod(LoTe_iDim))]) 754 | 755 | self.LoTe_BN3 = nn.BatchNorm1d(torch.prod(LoTe_iDim).numpy(), affine=True) 756 | 757 | ### Final MPS block 758 | # self.LoTe_mpsFinal = MPS(input_dim=len(self.LoTe_module3), 759 | # output_dim=output_dim, nCh=1, 760 | # bond_dim=bond_dim, feature_dim=LoTe_feature_dim, 761 | # adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, 762 | # parallel_eval=parallel_eval) 763 | 764 | ############################################################################## 765 | ############################################################################## 766 | ############################ # ####### ##### ###### ####################### 767 | ########################### ########## ### ##### #### ######################## 768 | ########################### ########## ### ###### ## ######################### 769 | ############################# # ####### ######## ########################### 770 | ############################################################################# 771 | 772 | ## Parameters for Conv TN 773 | ### First level MPS blocks 774 | self.Conv_module1 = nn.ModuleList([MPS(input_dim=self.Conv_nCh, 775 | output_dim=self.virtual_dim, 776 | nCh=nCh, bond_dim=bond_dim, 777 | feature_dim=Conv_feature_dim, parallel_eval=parallel_eval, 778 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 779 | for i in 
range(torch.prod(Conv_iDim))]) 780 | 781 | self.Conv_BN1 = nn.BatchNorm1d(torch.prod(Conv_iDim).numpy(), affine=True) 782 | 783 | Conv_iDim = Conv_iDim / self.Conv_ker[1] 784 | Conv_feature_dim = 2 * self.Conv_ker[1] ** 2 785 | 786 | ### Second level MPS blocks 787 | self.Conv_module2 = nn.ModuleList([MPS(input_dim=self.virtual_dim, 788 | output_dim=self.virtual_dim, 789 | nCh=self.virtual_dim, bond_dim=bond_dim, 790 | feature_dim=2 * self.LoTe_kScale + Conv_feature_dim, parallel_eval=parallel_eval, 791 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 792 | for i in range(torch.prod(Conv_iDim))]) 793 | 794 | self.Conv_BN2 = nn.BatchNorm1d(torch.prod(Conv_iDim).numpy(), affine=True) 795 | self.BN2 = nn.BatchNorm1d(torch.prod(Conv_iDim).numpy(), affine=True) 796 | 797 | Conv_iDim = Conv_iDim / self.Conv_ker[2] 798 | Conv_feature_dim = 2 * self.Conv_ker[2] ** 2 799 | ### Third level MPS blocks 800 | self.Conv_module3 = nn.ModuleList([MPS(input_dim=2 * self.virtual_dim, 801 | output_dim=self.virtual_dim, 802 | nCh=self.virtual_dim, bond_dim=bond_dim, 803 | feature_dim=Conv_feature_dim, parallel_eval=parallel_eval, 804 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc) 805 | for i in range(torch.prod(Conv_iDim))]) 806 | 807 | self.Conv_BN3 = nn.BatchNorm1d(torch.prod(Conv_iDim).numpy(), affine=True) 808 | self.BN3 = nn.BatchNorm1d(torch.prod(Conv_iDim).numpy(), affine=True) 809 | 810 | Conv_feature_dim = 2 * self.virtual_dim 811 | ### Final MPS block 812 | # self.Conv_mpsFinal = MPS(input_dim=len(self.Conv_module3), 813 | # output_dim=output_dim, nCh=1, 814 | # bond_dim=bond_dim, feature_dim=Conv_feature_dim, 815 | # adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, 816 | # parallel_eval=parallel_eval) 817 | 818 | self.mpsFinal = MPS(input_dim=len(self.Conv_module3), 819 | output_dim=output_dim, nCh=1, 820 | bond_dim=bond_dim, feature_dim=2 * Conv_feature_dim, 821 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, 822 | parallel_eval=parallel_eval) 823 | def forward(self, x): 824 | 825 | b = x.shape[0] # Batch size 826 | LoTe_iDim = self.LoTe_input_dim / (self.LoTe_ker[0]) 827 | Conv_iDim = self.Conv_input_dim / (self.Conv_ker[0]) 828 | 829 | LoTe_x = x.unfold(2, LoTe_iDim[0], LoTe_iDim[0]).unfold(3, LoTe_iDim[1], LoTe_iDim[1]).reshape(b, 830 | self.LoTe_nCh, -1, (self.LoTe_ker[0]) ** 2) 831 | Conv_x = x.unfold(2, self.Conv_ker[0], self.Conv_ker[0]).unfold(3, self.Conv_ker[0], 832 | self.Conv_ker[0]).reshape(b, (self.Conv_ker[0]) ** 2, -1, 833 | self.Conv_nCh) 834 | LoTe_y = [self.LoTe_module1[i](LoTe_x[:, :, i]) for i in range(len(self.LoTe_module1))] 835 | Conv_y = [self.Conv_module1[i](Conv_x[:, :, i]) for i in range(len(self.Conv_module1))] 836 | 837 | LoTe_y= torch.stack(LoTe_y, dim=1) 838 | Conv_y= torch.stack(Conv_y, dim=1) 839 | 840 | LoTe_y = self.LoTe_BN1(LoTe_y).unsqueeze(1) 841 | Conv_y = self.Conv_BN1(Conv_y).unsqueeze(1) 842 | 843 | # Level 2 contraction 844 | 845 | LoTe_y = LoTe_y.view(b, self.virtual_dim, LoTe_iDim[0], LoTe_iDim[1]) 846 | LoTe_iDim = (LoTe_iDim / self.LoTe_ker[1]) 847 | LoTe_y = LoTe_y.unfold(2, LoTe_iDim[0], LoTe_iDim[0]).unfold(3, LoTe_iDim[1], 848 | LoTe_iDim[1]).reshape(b, self.virtual_dim, -1, self.LoTe_ker[1] ** 2) 849 | 850 | Conv_y = Conv_y.view(b, self.virtual_dim, Conv_iDim[0], Conv_iDim[1]) 851 | Conv_iDim = (Conv_iDim / self.Conv_ker[1]) 852 | Conv_y = Conv_y.unfold(2, self.Conv_ker[1], self.Conv_ker[1]).unfold(3, self.Conv_ker[1], 853 | self.Conv_ker[1]).reshape(b, self.Conv_ker[1] ** 2, -1, 854 | self.virtual_dim) 855 | 856 | 
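        # At this point the two branches hold matching grids of patch
        # features in transposed layouts: LoTe_y is
        # (b, virtual_dim, n_patches, ker**2) with the virtual dimension
        # second, while Conv_y is (b, ker**2, n_patches, virtual_dim). The
        # permute(0, 3, 2, 1) below aligns Conv_y with LoTe_y so the two
        # feature sets can be concatenated per patch along the last axis
        # and fed to both level-2 module lists.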
#print(LoTe_y.shape) 857 | #print(Conv_y.permute(0, 3, 2, 1).shape) 858 | Combined_Feature_1 = torch.cat([LoTe_y, Conv_y.permute(0, 3, 2, 1)], dim=3) 859 | Lote_x2 = Combined_Feature_1 860 | Conv_x2 = Combined_Feature_1.permute(0, 3, 2, 1) 861 | 862 | # print('Lote_x2', Lote_x2.shape) 863 | # print(len(self.LoTe_module2)) 864 | # for i in range(len(self.LoTe_module2)): 865 | # print('Now loop: ', i) 866 | # print(Lote_x2[:, :, i].shape) 867 | # self.LoTe_module2[i](Lote_x2[:, :, i]) 868 | # assert False, 'stop here' 869 | LoTe_x2 = [self.LoTe_module2[i](Lote_x2[:, :, i]) for i in range(len(self.LoTe_module2))] 870 | Conv_x2 = [self.Conv_module2[i](Conv_x2[:, :, i]) for i in range(len(self.Conv_module2))] 871 | 872 | LoTe_x2 = torch.stack(LoTe_x2, dim=1) 873 | Conv_x2 = torch.stack(Conv_x2, dim=1) 874 | 875 | cat_x2 = torch.cat([LoTe_x2, Conv_x2], dim=2) 876 | bn_x2 = self.BN2(cat_x2).unsqueeze(1) 877 | #print("LoTe_x2", LoTe_x2.shape) 878 | #print("Conv_x2", Conv_x2.shape) 879 | 880 | # 881 | # # Level 3 contraction 882 | Lote_x3 = bn_x2.view(b, 2 * self.virtual_dim, LoTe_iDim[0], LoTe_iDim[1]) 883 | LoTe_iDim = (LoTe_iDim / self.LoTe_ker[2]) 884 | Lote_x3 = Lote_x3.unfold(2, LoTe_iDim[0], LoTe_iDim[0]).unfold(3, LoTe_iDim[1], 885 | LoTe_iDim[1]).reshape(b, 2 * self.virtual_dim, -1, self.LoTe_ker[2] ** 2) 886 | 887 | 888 | #print('LoTe x total shape layer 3: ', Lote_x3.shape) 889 | #print('LoTe x shape layer 3: ', Lote_x3[:, :, 0].shape) 890 | 891 | Conv_x3 = bn_x2.view(b, 2 * self.virtual_dim, Conv_iDim[0], Conv_iDim[1]) 892 | Conv_iDim = (Conv_iDim / self.Conv_ker[2]) 893 | Conv_x3 = Conv_x3.unfold(2, self.Conv_ker[2], self.Conv_ker[2]).unfold(3, self.Conv_ker[2], 894 | self.Conv_ker[2]).reshape(b, self.Conv_ker[2] ** 2, -1, 895 | 2 * self.virtual_dim) 896 | # print('x[:, :, 0].shape before module3: ', x[:, :, 0].shape) 897 | #print('Conv x total shape layer 3: ', x.shape) 898 | #print('Conv x shape layer 3: ', x[:, :, 0].shape) 899 | 900 | Lote_x3 = [self.LoTe_module3[i](Lote_x3[:, :, i]) for i in range(len(self.LoTe_module3))] 901 | Conv_x3 = [self.Conv_module3[i](Conv_x3[:, :, i]) for i in range(len(self.Conv_module3))] 902 | 903 | Lote_y3 = torch.stack(Lote_x3, dim=1) 904 | Conv_y3 = torch.stack(Conv_x3, dim=1) 905 | 906 | #print("Lote_y3", Lote_y3.shape) 907 | #print("Conv_y3", Conv_y3.shape) 908 | 909 | cat_x3 = torch.cat([Lote_y3, Conv_y3], dim=2) 910 | 911 | bn_x3 = self.BN3(cat_x3) 912 | 913 | # Final layer 914 | y3 = bn_x3.permute(0, 2, 1) 915 | #print('LoTe y shape before mpsfinal ', y3.shape) 916 | y3 = self.mpsFinal(y3) 917 | return y3.squeeze() -------------------------------------------------------------------------------- /models/lotenet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timqqt/MERA_Image_Classification/e96211f45ade86f031a0d99ad0670231844ef3a1/models/lotenet.pyc -------------------------------------------------------------------------------- /models/mps.py: -------------------------------------------------------------------------------- 1 | ### Adapted from https://github.com/jemisjoky/TorchMPS/ 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from utils.utils import init_tensor, svd_flex 7 | from models.contractables import SingleMat, MatRegion, OutputCore, ContractableList, \ 8 | EdgeVec, OutputMat 9 | import pdb 10 | from numpy import pi as PI 11 | from numpy import sqrt 12 | from scipy.special import comb 13 | device = 
torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | 15 | from models import contractables 16 | 17 | class MPS(nn.Module): 18 | """ 19 | Matrix product state which converts input into a single output vector 20 | """ 21 | def __init__(self, input_dim, output_dim, bond_dim, feature_dim=2, nCh=3, 22 | adaptive_mode=False, periodic_bc=False, parallel_eval=False, 23 | label_site=None, path=None, init_std=1e-9, use_bias=True, 24 | fixed_bias=True, cutoff=1e-10, merge_threshold=2000): 25 | super().__init__() 26 | 27 | org_input_dim = input_dim 28 | if label_site is None: 29 | label_site = input_dim // 2 30 | assert label_site >= 0 and label_site <= input_dim 31 | 32 | # Using bias matrices in adaptive_mode is too complicated, so I'm 33 | # disabling it here 34 | if adaptive_mode: 35 | use_bias = False 36 | # Our MPS is made of two InputRegions separated by an OutputSite. 37 | module_list = [] 38 | init_args = {'bond_str': 'slri', 39 | 'shape': [label_site, bond_dim, bond_dim, feature_dim], 40 | 'init_method': ('min_random_eye' if adaptive_mode else 41 | 'random_zero', init_std, output_dim)} 42 | 43 | # The first input region 44 | if label_site > 0: 45 | tensor = init_tensor(**init_args) 46 | 47 | module_list.append(InputRegion(tensor, use_bias=use_bias, 48 | fixed_bias=fixed_bias)) 49 | 50 | # The output site 51 | tensor = init_tensor(shape=[output_dim, bond_dim, bond_dim], 52 | bond_str='olr', init_method=('min_random_eye' if adaptive_mode else 53 | 'random_eye', init_std, output_dim)) 54 | module_list.append(OutputSite(tensor)) 55 | 56 | # The other input region 57 | if label_site < input_dim: 58 | init_args['shape'] = [input_dim-label_site, bond_dim, bond_dim, 59 | feature_dim] 60 | #print('label_site ', label_site) 61 | #print('init_args[shape] ', init_args['shape']) 62 | tensor = init_tensor(**init_args) 63 | module_list.append(InputRegion(tensor, use_bias=use_bias, 64 | fixed_bias=fixed_bias)) 65 | 66 | # Initialize linear_region according to our adaptive_mode specification 67 | if adaptive_mode: 68 | self.linear_region = MergedLinearRegion(module_list=module_list, 69 | periodic_bc=periodic_bc, 70 | parallel_eval=parallel_eval, cutoff=cutoff, 71 | merge_threshold=merge_threshold) 72 | 73 | # Initialize the list of bond dimensions, which starts out constant 74 | self.bond_list = bond_dim * torch.ones(input_dim + 2, 75 | dtype=torch.long) 76 | if not periodic_bc: 77 | self.bond_list[0], self.bond_list[-1] = 1, 1 78 | 79 | # Initialize the list of singular values, which start out at -1 80 | self.sv_list = -1. 
* torch.ones([input_dim + 2, bond_dim]) 81 | 82 | else: 83 | self.linear_region = LinearRegion(module_list=module_list, 84 | periodic_bc=periodic_bc, 85 | parallel_eval=parallel_eval) 86 | assert len(self.linear_region) == input_dim 87 | 88 | if path: 89 | assert isinstance(path, (list, torch.Tensor)) 90 | assert len(path) == input_dim 91 | 92 | # Set the rest of our MPS attributes 93 | self.input_dim = input_dim 94 | self.output_dim = output_dim 95 | self.bond_dim = bond_dim 96 | self.feature_dim = feature_dim 97 | self.periodic_bc = periodic_bc 98 | self.adaptive_mode = adaptive_mode 99 | self.label_site = label_site 100 | self.path = path 101 | self.use_bias = use_bias 102 | self.fixed_bias = fixed_bias 103 | self.cutoff = cutoff 104 | self.merge_threshold = merge_threshold 105 | self.feature_map = None 106 | self.linear_region = self.linear_region.to(device) 107 | module_list = [m.to(device) for m in module_list] 108 | 109 | def forward(self, input_data): 110 | """ 111 | Embed our data and pass it to an MPS with a single output site 112 | 113 | Args: 114 | input_data (Tensor): Input with shape [batch_size, input_dim] or 115 | [batch_size, input_dim, feature_dim]. In the 116 | former case, the data points are turned into 117 | 2D vectors using a default linear feature map. 118 | """ 119 | input_data = input_data.permute(0,2,1) 120 | 121 | x1 = torch.cos(input_data * PI/2) 122 | x2 = torch.sin(input_data* PI/2) 123 | x = torch.cat((x1,x2),dim=2) 124 | output = self.linear_region(x) 125 | 126 | return output.squeeze() 127 | 128 | def core_len(self): 129 | """ 130 | Returns the number of cores, which is at least the required input size 131 | """ 132 | return self.linear_region.core_len() 133 | 134 | def __len__(self): 135 | """ 136 | Returns the number of input sites, which equals the input size 137 | """ 138 | return self.input_dim 139 | 140 | class ReLUMPS(nn.Module): 141 | """ 142 | Matrix product state which converts input into a single output vector 143 | """ 144 | def __init__(self, input_dim, output_dim, bond_dim, feature_dim=2, nCh=3, 145 | adaptive_mode=False, periodic_bc=False, parallel_eval=False, 146 | label_site=None, path=None, init_std=1e-9, use_bias=True, 147 | fixed_bias=True, cutoff=1e-10, merge_threshold=2000): 148 | super().__init__() 149 | 150 | org_input_dim = input_dim 151 | if label_site is None: 152 | label_site = input_dim // 2 153 | assert label_site >= 0 and label_site <= input_dim 154 | 155 | # Using bias matrices in adaptive_mode is too complicated, so I'm 156 | # disabling it here 157 | if adaptive_mode: 158 | use_bias = False 159 | # Our MPS is made of two InputRegions separated by an OutputSite. 
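        # Schematically, with input_dim = N and the default
        # label_site = N // 2:
        #
        #   [InputRegion: sites 0 .. N//2-1] - [OutputSite] - [InputRegion: sites N//2 .. N-1]
        #
        # Each input site holds a (bond_dim, bond_dim, feature_dim) core
        # (bond_str 'lri'), and the output site holds an
        # (output_dim, bond_dim, bond_dim) core (bond_str 'olr').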
160 | module_list = [] 161 | init_args = {'bond_str': 'slri', 162 | 'shape': [label_site, bond_dim, bond_dim, feature_dim], 163 | 'init_method': ('min_random_eye' if adaptive_mode else 164 | 'random_zero', init_std, output_dim)} 165 | 166 | # The first input region 167 | if label_site > 0: 168 | tensor = init_tensor(**init_args) 169 | 170 | module_list.append(InputRegion(tensor, use_bias=use_bias, 171 | fixed_bias=fixed_bias)) 172 | 173 | # The output site 174 | tensor = init_tensor(shape=[output_dim, bond_dim, bond_dim], 175 | bond_str='olr', init_method=('min_random_eye' if adaptive_mode else 176 | 'random_eye', init_std, output_dim)) 177 | module_list.append(OutputSite(tensor)) 178 | 179 | # The other input region 180 | if label_site < input_dim: 181 | init_args['shape'] = [input_dim-label_site, bond_dim, bond_dim, 182 | feature_dim] 183 | #print('label_site ', label_site) 184 | #print('init_args[shape] ', init_args['shape']) 185 | tensor = init_tensor(**init_args) 186 | module_list.append(InputRegion(tensor, use_bias=use_bias, 187 | fixed_bias=fixed_bias)) 188 | 189 | # Initialize linear_region according to our adaptive_mode specification 190 | if adaptive_mode: 191 | self.linear_region = MergedLinearRegion(module_list=module_list, 192 | periodic_bc=periodic_bc, 193 | parallel_eval=parallel_eval, cutoff=cutoff, 194 | merge_threshold=merge_threshold) 195 | 196 | # Initialize the list of bond dimensions, which starts out constant 197 | self.bond_list = bond_dim * torch.ones(input_dim + 2, 198 | dtype=torch.long) 199 | if not periodic_bc: 200 | self.bond_list[0], self.bond_list[-1] = 1, 1 201 | 202 | # Initialize the list of singular values, which start out at -1 203 | self.sv_list = -1. * torch.ones([input_dim + 2, bond_dim]) 204 | 205 | else: 206 | self.linear_region = LinearRegion(module_list=module_list, 207 | periodic_bc=periodic_bc, 208 | parallel_eval=parallel_eval) 209 | assert len(self.linear_region) == input_dim 210 | 211 | if path: 212 | assert isinstance(path, (list, torch.Tensor)) 213 | assert len(path) == input_dim 214 | 215 | # Set the rest of our MPS attributes 216 | self.input_dim = input_dim 217 | self.output_dim = output_dim 218 | self.bond_dim = bond_dim 219 | self.feature_dim = feature_dim 220 | self.periodic_bc = periodic_bc 221 | self.adaptive_mode = adaptive_mode 222 | self.label_site = label_site 223 | self.path = path 224 | self.use_bias = use_bias 225 | self.fixed_bias = fixed_bias 226 | self.cutoff = cutoff 227 | self.merge_threshold = merge_threshold 228 | self.feature_map = None 229 | self.linear_region = self.linear_region.to(device) 230 | module_list = [m.to(device) for m in module_list] 231 | 232 | def forward(self, input_data): 233 | """ 234 | Embed our data and pass it to an MPS with a single output site 235 | 236 | Args: 237 | input_data (Tensor): Input with shape [batch_size, input_dim] or 238 | [batch_size, input_dim, feature_dim]. In the 239 | former case, the data points are turned into 240 | 2D vectors using a default linear feature map. 
241 |         """
242 |         input_data = input_data.permute(0,2,1)
243 | 
244 |         x1 = torch.nn.functional.relu(input_data)
245 |         x2 = torch.sigmoid(input_data)
246 |         x = torch.cat((x1,x2),dim=2)
247 |         output = self.linear_region(x)
248 | 
249 |         return output.squeeze()
250 | 
251 |     def core_len(self):
252 |         """
253 |         Returns the number of cores, which is at least the required input size
254 |         """
255 |         return self.linear_region.core_len()
256 | 
257 |     def __len__(self):
258 |         """
259 |         Returns the number of input sites, which equals the input size
260 |         """
261 |         return self.input_dim
262 | 
263 | class LinearRegion(nn.Module):
264 |     """
265 |     List of modules which feeds input to each module and returns reduced output
266 |     """
267 |     def __init__(self, module_list, periodic_bc=False, parallel_eval=False,
268 |                  module_states=None):
269 |         # Check that module_list is a list whose entries are Pytorch modules
270 |         if not isinstance(module_list, list) or module_list == []:
271 |             raise ValueError("Input to LinearRegion must be nonempty list")
272 |         for i, item in enumerate(module_list):
273 |             if not isinstance(item, nn.Module):
274 |                 raise ValueError(f"Input items to LinearRegion must be PyTorch Module instances, but item {i} is not")
275 |         super().__init__()
276 | 
277 |         # Wrap as a ModuleList for proper parameter registration
278 |         self.module_list = nn.ModuleList(module_list)
279 |         self.periodic_bc = periodic_bc
280 |         self.parallel_eval = parallel_eval
281 | 
282 |     def forward(self, input_data):
283 |         """
284 |         Contract input with list of MPS cores and return result as contractable
285 | 
286 |         Args:
287 |             input_data (Tensor): Input with shape [batch_size, input_dim,
288 |                                  feature_dim]
289 |         """
290 |         # Check that input_data has the correct shape
291 |         #print('input_data.shape', input_data.shape)
292 |         #print('len(self)', len(self))
293 |         # print('Input_data and Len(self)')
294 |         # print(input_data.shape)
295 |         # print(len(self))
296 |         assert len(input_data.shape) == 3
297 |         assert input_data.size(1) == len(self)
298 |         #print('input_data.shape', input_data.shape)
299 |         periodic_bc = self.periodic_bc
300 |         parallel_eval = self.parallel_eval
301 |         lin_bonds = ['l', 'r']
302 | 
303 |         # For each module, pull out the number of pixels needed and call that
304 |         # module's forward() method, putting the result in contractable_list
305 |         ind = 0
306 |         contractable_list = []
307 |         contractables.global_bs = input_data.shape[0]
308 |         for module in self.module_list:
309 |             mod_len = len(module)
310 |             #print('mod_len', mod_len)
311 |             if mod_len == 1:
312 |                 mod_input = input_data[:, ind]
313 |             else:
314 |                 mod_input = input_data[:, ind:(ind+mod_len)]
315 |             ind += mod_len
316 | 
317 |             contractable_list.append(module(mod_input))
318 | 
319 |         # For periodic boundary conditions, reduce contractable_list and
320 |         # trace over the left and right indices to get our output
321 |         if periodic_bc:
322 |             contractable_list = ContractableList(contractable_list)
323 |             contractable = contractable_list.reduce(parallel_eval=True)
324 | 
325 |             # Unpack the output (atomic) contractable
326 |             tensor, bond_str = contractable.tensor, contractable.bond_str
327 |             assert all(c in bond_str for c in lin_bonds)
328 | 
329 |             # Build einsum string for the trace of tensor
330 |             in_str, out_str = "", ""
331 |             for c in bond_str:
332 |                 if c in lin_bonds:
333 |                     in_str += 'l'
334 |                 else:
335 |                     in_str += c
336 |                     out_str += c
337 |             ein_str = in_str + "->" + out_str
338 | 
339 |             # Return the trace over left and right indices
340 |             return torch.einsum(ein_str, [tensor])
341 | 
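            # For example, if the reduced contractable carries bond_str
            # 'blr', the loop above builds the einsum string 'bll->b', and
            #
            #   torch.einsum('bll->b', [mats])  # mats: (B, D, D)
            #
            # traces out the matched left/right bond indices of each batch
            # element, which is exactly the trace expected for a periodic MPS.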
342 |         # For open boundary conditions, add dummy edge vectors to
343 |         # contractable_list and reduce everything to get our output
344 |         else:
345 |             # Get the dimension of left and right bond indices
346 |             end_items = [contractable_list[i] for i in [0, -1]]
347 |             bond_strs = [item.bond_str for item in end_items]
348 |             bond_inds = [bs.index(c) for (bs, c) in zip(bond_strs, lin_bonds)]
349 |             bond_dims = [item.tensor.size(ind) for (item, ind) in
350 |                          zip(end_items, bond_inds)]
351 | 
352 |             # Build dummy end vectors and insert them at the ends of our list
353 |             end_vecs = [torch.zeros(dim) for dim in bond_dims]
354 |             end_vecs = [e.to(device) for e in end_vecs]
355 | 
356 |             for vec in end_vecs:
357 |                 vec[0] = 1
358 | 
359 |             contractable_list.insert(0, EdgeVec(end_vecs[0], is_left_vec=True))
360 |             contractable_list.append(EdgeVec(end_vecs[1], is_left_vec=False))
361 | 
362 |             # Multiply together everything in contractable_list
363 |             contractable_list = ContractableList(contractable_list)
364 |             output = contractable_list.reduce(parallel_eval=parallel_eval)
365 | 
366 |             return output.tensor
367 | 
368 |     def core_len(self):
369 |         """
370 |         Returns the number of cores, which is at least the required input size
371 |         """
372 |         return sum([module.core_len() for module in self.module_list])
373 | 
374 |     def __len__(self):
375 |         """
376 |         Returns the number of input sites, which is the required input size
377 |         """
378 |         return sum([len(module) for module in self.module_list])
379 | 
380 | class MergedLinearRegion(LinearRegion):
381 |     """
382 |     Dynamic variant of LinearRegion that periodically rearranges its submodules
383 |     """
384 |     def __init__(self, module_list, periodic_bc=False, parallel_eval=False, cutoff=1e-10, merge_threshold=2000):
385 |         # Initialize a LinearRegion with our given module_list
386 |         super().__init__(module_list, periodic_bc, parallel_eval)
387 | 
388 |         # Initialize attributes self.module_list_0 and self.module_list_1
389 |         # using the unmerged self.module_list, then redefine the latter in
390 |         # terms of one of the former lists
391 |         self.offset = 0
392 |         self.merge(offset=self.offset)
393 |         self.merge(offset=(self.offset+1)%2)
394 |         self.module_list = getattr(self, f"module_list_{self.offset}")
395 | 
396 |         # Initialize variables used during switching
397 |         self.input_counter = 0
398 |         self.merge_threshold = merge_threshold
399 |         self.cutoff = cutoff
400 | 
401 |     def forward(self, input_data):
402 |         """
403 |         Contract input with list of MPS cores and return result as contractable
404 | 
405 |         MergedLinearRegion keeps an input counter of the number of inputs, and
406 |         when this exceeds its merge threshold, triggers an unmerging and
407 |         remerging of its parameter tensors.
408 | 
409 |         Args:
410 |             input_data (Tensor): Input with shape [batch_size, input_dim,
411 |                                  feature_dim]
412 |         """
413 |         # If we've hit our threshold, flip the merge state of our tensors
414 |         if self.input_counter >= self.merge_threshold:
415 |             bond_list, sv_list = self.unmerge(cutoff=self.cutoff)
416 |             self.offset = (self.offset + 1) % 2
417 |             self.merge(offset=self.offset)
418 |             self.input_counter -= self.merge_threshold
419 | 
420 |             # Point self.module_list to the appropriate merged module
421 |             self.module_list = getattr(self, f"module_list_{self.offset}")
422 |         else:
423 |             bond_list, sv_list = None, None
424 | 
425 |         # Increment our counter and call the LinearRegion's forward method
426 |         self.input_counter += input_data.size(0)
427 |         output = super().forward(input_data)
428 | 
429 |         # If we flipped our merge state, then return the bond_list and output
430 |         if bond_list:
431 |             return output, bond_list, sv_list
432 |         else:
433 |             return output
434 | 
435 |     @torch.no_grad()
436 |     def merge(self, offset):
437 |         """
438 |         Convert unmerged modules in self.module_list to merged counterparts
439 | 
440 |         This proceeds by first merging all unmerged cores internally, then
441 |         merging lone cores when possible during a second sweep
442 |         """
443 |         assert offset in [0, 1]
444 | 
445 |         unmerged_list = self.module_list
446 | 
447 |         # Merge each core internally and add the results to midway_list
448 |         site_num = offset
449 |         merged_list = []
450 |         for core in unmerged_list:
451 |             assert not isinstance(core, MergedInput)
452 |             assert not isinstance(core, MergedOutput)
453 | 
454 |             # Apply internal merging routine if our core supports it
455 |             if hasattr(core, 'merge'):
456 |                 merged_list.extend(core.merge(offset=site_num%2))
457 |             else:
458 |                 merged_list.append(core)
459 | 
460 |             site_num += core.core_len()
461 | 
462 |         # Merge pairs of cores when possible (currently only with
463 |         # InputSites), making sure to respect the offset for merging.
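        # For example, offset=0 pairs the sites as (0,1)(2,3)(4,5)...,
        # while offset=1 gives 0 (1,2)(3,4)...; alternating the offset
        # between merge/unmerge cycles lets every bond be truncated by the
        # SVD in unmerge(), in the spirit of two-site DMRG sweeps.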
464 |         while True:
465 |             mod_num, site_num = 0, 0
466 |             combined_list = []
467 | 
468 |             while mod_num < len(merged_list) - 1:
469 |                 left_core, right_core = merged_list[mod_num: mod_num+2]
470 |                 new_core = self.combine(left_core, right_core,
471 |                                         merging=True)
472 | 
473 |                 # If cores aren't combinable, move our sliding window by 1
474 |                 if new_core is None or offset != site_num % 2:
475 |                     combined_list.append(left_core)
476 |                     mod_num += 1
477 |                     site_num += left_core.core_len()
478 | 
479 |                 # If we get something new, move to the next distinct pair
480 |                 else:
481 |                     assert new_core.core_len() == left_core.core_len() + \
482 |                                                   right_core.core_len()
483 |                     combined_list.append(new_core)
484 |                     mod_num += 2
485 |                     site_num += new_core.core_len()
486 | 
487 |             # Add the last core if there's nothing to merge it with
488 |             if mod_num == len(merged_list)-1:
489 |                 combined_list.append(merged_list[mod_num])
490 |                 mod_num += 1
491 | 
492 |             # We're finished when unmerged_list remains unchanged
493 |             if len(combined_list) == len(merged_list):
494 |                 break
495 |             else:
496 |                 merged_list = combined_list
497 | 
498 |         # Finally, update the appropriate merged module list
499 |         list_name = f"module_list_{offset}"
500 |         # If the merged module list hasn't been set yet, initialize it
501 |         if not hasattr(self, list_name):
502 |             setattr(self, list_name, nn.ModuleList(merged_list))
503 | 
504 |         # Otherwise, do an in-place update so that all tensors remain
505 |         # properly registered with whatever optimizer we use
506 |         else:
507 |             module_list = getattr(self, list_name)
508 |             assert len(module_list) == len(merged_list)
509 |             for i in range(len(module_list)):
510 |                 assert module_list[i].tensor.shape == \
511 |                        merged_list[i].tensor.shape
512 |                 module_list[i].tensor[:] = merged_list[i].tensor
513 | 
514 |     @torch.no_grad()
515 |     def unmerge(self, cutoff=1e-10):
516 |         """
517 |         Convert merged modules to unmerged counterparts
518 | 
519 |         This proceeds by first unmerging all merged cores internally, then
520 |         combining lone cores where possible
521 |         """
522 |         list_name = f"module_list_{self.offset}"
523 |         merged_list = getattr(self, list_name)
524 | 
525 |         # Unmerge each core internally and add results to unmerged_list
526 |         unmerged_list, bond_list, sv_list = [], [-1], [-1]
527 |         for core in merged_list:
528 | 
529 |             # Apply internal unmerging routine if our core supports it
530 |             if hasattr(core, 'unmerge'):
531 |                 new_cores, new_bonds, new_svs = core.unmerge(cutoff)
532 |                 unmerged_list.extend(new_cores)
533 |                 bond_list.extend(new_bonds[1:])
534 |                 sv_list.extend(new_svs[1:])
535 |             else:
536 |                 assert not isinstance(core, InputRegion)
537 |                 unmerged_list.append(core)
538 |                 bond_list.append(-1)
539 |                 sv_list.append(-1)
540 | 
541 |         # Combine all combinable pairs of cores.
This occurs in several 542 | # passes, and for now acts nontrivially only on InputSite instances 543 | while True: 544 | mod_num = 0 545 | combined_list = [] 546 | 547 | while mod_num < len(unmerged_list) - 1: 548 | left_core, right_core = unmerged_list[mod_num: mod_num+2] 549 | new_core = self.combine(left_core, right_core, 550 | merging=False) 551 | 552 | # If cores aren't combinable, move our sliding window by 1 553 | if new_core is None: 554 | combined_list.append(left_core) 555 | mod_num += 1 556 | 557 | # If we get something new, move to the next distinct pair 558 | else: 559 | combined_list.append(new_core) 560 | mod_num += 2 561 | 562 | # Add the last core if there's nothing to combine it with 563 | if mod_num == len(unmerged_list)-1: 564 | combined_list.append(unmerged_list[mod_num]) 565 | mod_num += 1 566 | 567 | # We're finished when unmerged_list remains unchanged 568 | if len(combined_list) == len(unmerged_list): 569 | break 570 | else: 571 | unmerged_list = combined_list 572 | 573 | # Find the average (log) norm of all of our cores 574 | log_norms = [] 575 | for core in unmerged_list: 576 | log_norms.append([torch.log(norm) for norm in core.get_norm()]) 577 | log_scale = sum([sum(ns) for ns in log_norms]) 578 | log_scale /= sum([len(ns) for ns in log_norms]) 579 | 580 | # Now rescale all cores so that their norms are roughly equal 581 | scales = [[torch.exp(log_scale-n) for n in ns] for ns in log_norms] 582 | for core, these_scales in zip(unmerged_list, scales): 583 | core.rescale_norm(these_scales) 584 | 585 | # Add our unmerged module list as a new attribute and return 586 | # the updated bond dimensions 587 | self.module_list = nn.ModuleList(unmerged_list) 588 | return bond_list, sv_list 589 | 590 | def combine(self, left_core, right_core, merging): 591 | """ 592 | Combine a pair of cores into a new core using context-dependent rules 593 | 594 | Depending on the types of left_core and right_core, along with whether 595 | we're currently merging (merging=True) or unmerging (merging=False), 596 | either return a new core, or None if no rule exists for this context 597 | """ 598 | 599 | # Combine an OutputSite with a stray InputSite, return a MergedOutput 600 | if merging and ((isinstance(left_core, OutputSite) and 601 | isinstance(right_core, InputSite)) or 602 | (isinstance(left_core, InputSite) and 603 | isinstance(right_core, OutputSite))): 604 | 605 | left_site = isinstance(left_core, InputSite) 606 | if left_site: 607 | new_tensor = torch.einsum('lui,our->olri', [left_core.tensor, 608 | right_core.tensor]) 609 | else: 610 | new_tensor = torch.einsum('olu,uri->olri', [left_core.tensor, 611 | right_core.tensor]) 612 | return MergedOutput(new_tensor, left_output=(not left_site)) 613 | 614 | # Combine an InputRegion with a stray InputSite, return an InputRegion 615 | elif not merging and ((isinstance(left_core, InputRegion) and 616 | isinstance(right_core, InputSite)) or 617 | (isinstance(left_core, InputSite) and 618 | isinstance(right_core, InputRegion))): 619 | 620 | left_site = isinstance(left_core, InputSite) 621 | if left_site: 622 | left_tensor = left_core.tensor.unsqueeze(0) 623 | right_tensor = right_core.tensor 624 | else: 625 | left_tensor = left_core.tensor 626 | right_tensor = right_core.tensor.unsqueeze(0) 627 | 628 | assert left_tensor.shape[1:] == right_tensor.shape[1:] 629 | new_tensor = torch.cat([left_tensor, right_tensor]) 630 | 631 | return InputRegion(new_tensor) 632 | 633 | # If this situation doesn't belong to the above cases, return None 634 | else: 
635 | return None 636 | 637 | def core_len(self): 638 | """ 639 | Returns the number of cores, which is at least the required input size 640 | """ 641 | return sum([module.core_len() for module in self.module_list]) 642 | 643 | def __len__(self): 644 | """ 645 | Returns the number of input sites, which is the required input size 646 | """ 647 | return sum([len(module) for module in self.module_list]) 648 | 649 | class InputRegion(nn.Module): 650 | """ 651 | Contiguous region of MPS input cores, associated with bond_str = 'slri' 652 | """ 653 | def __init__(self, tensor, use_bias=True, fixed_bias=True, bias_mat=None, 654 | ephemeral=False): 655 | super().__init__() 656 | 657 | # Make sure tensor has correct size and the component mats are square 658 | assert len(tensor.shape) == 4 659 | assert tensor.size(1) == tensor.size(2) 660 | bond_dim = tensor.size(1) 661 | #print('Tensor in input region: ', tensor.shape) 662 | # If we are using bias matrices, set those up here 663 | if use_bias: 664 | assert bias_mat is None or isinstance(bias_mat, torch.Tensor) 665 | bias_mat = torch.eye(bond_dim).unsqueeze(0) if bias_mat is None \ 666 | else bias_mat 667 | 668 | bias_modes = len(list(bias_mat.shape)) 669 | assert bias_modes in [2, 3] 670 | if bias_modes == 2: 671 | bias_mat = bias_mat.unsqueeze(0) 672 | 673 | # Register our tensors as a Pytorch Parameter or Tensor 674 | if ephemeral: 675 | self.register_buffer(name='tensor', tensor=tensor.contiguous()) 676 | self.register_buffer(name='bias_mat', tensor=bias_mat) 677 | else: 678 | self.register_parameter(name='tensor', 679 | param=nn.Parameter(tensor.contiguous())) 680 | if fixed_bias: 681 | self.register_buffer(name='bias_mat', tensor=bias_mat) 682 | else: 683 | self.register_parameter(name='bias_mat', 684 | param=nn.Parameter(bias_mat)) 685 | 686 | self.use_bias = use_bias 687 | self.fixed_bias = fixed_bias 688 | 689 | def forward(self, input_data): 690 | """ 691 | Contract input with MPS cores and return result as a MatRegion 692 | 693 | Args: 694 | input_data (Tensor): Input with shape [batch_size, input_dim, 695 | feature_dim] 696 | """ 697 | # Check that input_data has the correct shape 698 | tensor = self.tensor 699 | #print(tensor.shape) 700 | #print(tensor.size(3)) 701 | #print(input_data.shape) 702 | #print(input_data.size(2)) 703 | #assert False 704 | if len(input_data.shape) == 2: 705 | input_data = input_data.unsqueeze(1) 706 | assert len(input_data.shape) == 3 707 | assert input_data.size(1) == len(self) 708 | assert input_data.size(2) == tensor.size(3) 709 | 710 | # Contract the input with our core tensor 711 | mats = torch.einsum('slri,bsi->bslr', [tensor, input_data]) 712 | 713 | # If we're using bias matrices, add those here 714 | if self.use_bias: 715 | bond_dim = tensor.size(1) 716 | bias_mat = self.bias_mat.unsqueeze(0) 717 | mats = mats + bias_mat.expand_as(mats) 718 | 719 | return MatRegion(mats) 720 | 721 | def merge(self, offset): 722 | """ 723 | Merge all pairs of neighboring cores and return a new list of cores 724 | 725 | offset is either 0 or 1, which gives the first core at which we start 726 | our merging. 
Depending on the length of our InputRegion, the output of 727 | merge may have 1, 2, or 3 entries, with the majority of sites ending in 728 | a MergedInput instance 729 | """ 730 | assert offset in [0, 1] 731 | num_sites = self.core_len() 732 | parity = num_sites % 2 733 | 734 | # Cases with empty tensors might arise in recursion below 735 | if num_sites == 0: 736 | return [None] 737 | 738 | # Simplify the problem into one where offset=0 and num_sites is even 739 | if (offset, parity) == (1, 1): 740 | out_list = [self[0], self[1:].merge(offset=0)[0]] 741 | elif (offset, parity) == (1, 0): 742 | out_list = [self[0], self[1:-1].merge(offset=0)[0], self[-1]] 743 | elif (offset, parity) == (0, 1): 744 | out_list = [self[:-1].merge(offset=0)[0], self[-1]] 745 | 746 | # The main case of interest, with no offset and an even number of sites 747 | else: 748 | tensor = self.tensor 749 | even_cores, odd_cores = tensor[0::2], tensor[1::2] 750 | assert len(even_cores) == len(odd_cores) 751 | 752 | # Multiply all pairs of cores, keeping inputs separate 753 | merged_cores = torch.einsum('slui,surj->slrij', [even_cores, 754 | odd_cores]) 755 | out_list = [MergedInput(merged_cores)] 756 | 757 | # Remove empty MergedInputs, which appear in very small InputRegions 758 | return [x for x in out_list if x is not None] 759 | 760 | def __getitem__(self, key): 761 | """ 762 | Returns an InputRegion instance sliced along the site index 763 | """ 764 | assert isinstance(key, int) or isinstance(key, slice) 765 | 766 | if isinstance(key, slice): 767 | return InputRegion(self.tensor[key]) 768 | else: 769 | return InputSite(self.tensor[key]) 770 | 771 | def get_norm(self): 772 | """ 773 | Returns list of the norms of each core in InputRegion 774 | """ 775 | return [torch.norm(core) for core in self.tensor] 776 | 777 | @torch.no_grad() 778 | def rescale_norm(self, scale_list): 779 | """ 780 | Rescales the norm of each core by an amount specified in scale_list 781 | 782 | For the i'th tensor defining a core in InputRegion, we rescale as 783 | tensor_i <- scale_i * tensor_i, where scale_i = scale_list[i] 784 | """ 785 | assert len(scale_list) == len(self.tensor) 786 | 787 | for core, scale in zip(self.tensor, scale_list): 788 | core *= scale 789 | 790 | def core_len(self): 791 | return len(self) 792 | 793 | def __len__(self): 794 | return self.tensor.size(0) 795 | 796 | class MergedInput(nn.Module): 797 | """ 798 | Contiguous region of merged MPS cores, each taking in a pair of input data 799 | 800 | Since MergedInput arises after contracting together existing input cores, 801 | a merged input tensor is required for initialization 802 | """ 803 | def __init__(self, tensor): 804 | # Check that our input tensor has the correct shape 805 | bond_str = 'slrij' 806 | shape = tensor.shape 807 | assert len(shape) == 5 808 | assert shape[1] == shape[2] 809 | assert shape[3] == shape[4] 810 | 811 | super().__init__() 812 | 813 | # Register our tensor as a Pytorch Parameter 814 | self.register_parameter(name='tensor', 815 | param=nn.Parameter(tensor.contiguous())) 816 | 817 | def forward(self, input_data): 818 | """ 819 | Contract input with merged MPS cores and return result as a MatRegion 820 | 821 | Args: 822 | input_data (Tensor): Input with shape [batch_size, input_dim, 823 | feature_dim], where input_dim must be even 824 | (each merged core takes 2 inputs) 825 | """ 826 | # Check that input_data has the correct shape 827 | tensor = self.tensor 828 | assert len(input_data.shape) == 3 829 | assert input_data.size(1) == len(self) 
830 | assert input_data.size(2) == tensor.size(3) 831 | assert input_data.size(1) % 2 == 0 832 | 833 | # Divide input_data into inputs living on even and on odd sites 834 | inputs = [input_data[:, 0::2], input_data[:, 1::2]] 835 | 836 | # Contract the odd (right-most) and even inputs with merged cores 837 | tensor = torch.einsum('slrij,bsj->bslri', [tensor, inputs[1]]) 838 | mats = torch.einsum('bslri,bsi->bslr', [tensor, inputs[0]]) 839 | 840 | return MatRegion(mats) 841 | 842 | def unmerge(self, cutoff=1e-10): 843 | """ 844 | Separate the cores in our MergedInput and return an InputRegion 845 | 846 | The length of the resultant InputRegion will be identical to our 847 | original MergedInput (same number of inputs), but its core_len will 848 | be doubled (twice as many individual cores) 849 | """ 850 | bond_str = 'slrij' 851 | tensor = self.tensor 852 | svd_string = 'lrij->lui,urj' 853 | max_D = tensor.size(1) 854 | 855 | # Split every one of the cores into two and add them both to core_list 856 | core_list, bond_list, sv_list = [], [-1], [-1] 857 | for merged_core in tensor: 858 | sv_vec = torch.empty(max_D) 859 | left_core, right_core, bond_dim = svd_flex(merged_core, svd_string, 860 | max_D, cutoff, sv_vec=sv_vec) 861 | 862 | core_list += [left_core, right_core] 863 | bond_list += [bond_dim, -1] 864 | sv_list += [sv_vec, -1] 865 | 866 | # Collate the split cores into one tensor and return as an InputRegion 867 | tensor = torch.stack(core_list) 868 | return [InputRegion(tensor)], bond_list, sv_list 869 | 870 | def get_norm(self): 871 | """ 872 | Returns list of the norm of each core in MergedInput 873 | """ 874 | return [torch.norm(core) for core in self.tensor] 875 | 876 | @torch.no_grad() 877 | def rescale_norm(self, scale_list): 878 | """ 879 | Rescales the norm of each core by an amount specified in scale_list 880 | 881 | For the i'th tensor defining a core in MergedInput, we rescale as 882 | tensor_i <- scale_i * tensor_i, where scale_i = scale_list[i] 883 | """ 884 | assert len(scale_list) == len(self.tensor) 885 | 886 | for core, scale in zip(self.tensor, scale_list): 887 | core *= scale 888 | 889 | def core_len(self): 890 | return len(self) 891 | 892 | def __len__(self): 893 | """ 894 | Returns the number of input sites, which is twice the number of cores 895 | """ 896 | return 2 * self.tensor.size(0) 897 | 898 | class InputSite(nn.Module): 899 | """ 900 | A single MPS core which takes in a single input datum, bond_str = 'lri' 901 | """ 902 | def __init__(self, tensor): 903 | super().__init__() 904 | # Register our tensor as a Pytorch Parameter 905 | self.register_parameter(name='tensor', 906 | param=nn.Parameter(tensor.contiguous())) 907 | 908 | def forward(self, input_data): 909 | """ 910 | Contract input with MPS core and return result as a SingleMat 911 | 912 | Args: 913 | input_data (Tensor): Input with shape [batch_size, feature_dim] 914 | """ 915 | # Check that input_data has the correct shape 916 | tensor = self.tensor 917 | assert len(input_data.shape) == 2 918 | assert input_data.size(1) == tensor.size(2) 919 | 920 | # Contract the input with our core tensor 921 | mat = torch.einsum('lri,bi->blr', [tensor, input_data]) 922 | 923 | return SingleMat(mat) 924 | 925 | def get_norm(self): 926 | """ 927 | Returns the norm of our core tensor, wrapped as a singleton list 928 | """ 929 | return [torch.norm(self.tensor)] 930 | 931 | @torch.no_grad() 932 | def rescale_norm(self, scale): 933 | """ 934 | Rescales the norm of our core by a factor of input `scale` 935 | """ 936 | if 
isinstance(scale, list): 937 | assert len(scale) == 1 938 | scale = scale[0] 939 | 940 | self.tensor *= scale 941 | 942 | def core_len(self): 943 | return 1 944 | 945 | def __len__(self): 946 | return 1 947 | 948 | class OutputSite(nn.Module): 949 | """ 950 | A single MPS core with no input and a single output index, bond_str = 'olr' 951 | """ 952 | def __init__(self, tensor): 953 | super().__init__() 954 | # Register our tensor as a Pytorch Parameter 955 | self.register_parameter(name='tensor', 956 | param=nn.Parameter(tensor.contiguous())) 957 | 958 | def forward(self, input_data): 959 | """ 960 | Return the OutputSite wrapped as an OutputCore contractable 961 | """ 962 | return OutputCore(self.tensor) 963 | 964 | def get_norm(self): 965 | """ 966 | Returns the norm of our core tensor, wrapped as a singleton list 967 | """ 968 | return [torch.norm(self.tensor)] 969 | 970 | @torch.no_grad() 971 | def rescale_norm(self, scale): 972 | """ 973 | Rescales the norm of our core by a factor of input `scale` 974 | """ 975 | if isinstance(scale, list): 976 | assert len(scale) == 1 977 | scale = scale[0] 978 | 979 | self.tensor *= scale 980 | 981 | def core_len(self): 982 | return 1 983 | 984 | def __len__(self): 985 | return 0 986 | 987 | class MergedOutput(nn.Module): 988 | """ 989 | Merged MPS core taking in one input datum and returning an output vector 990 | 991 | Since MergedOutput arises after contracting together an existing input and 992 | output core, an already-merged tensor is required for initialization 993 | 994 | Args: 995 | tensor (Tensor): Value that our merged core is initialized to 996 | left_output (bool): Specifies if the output core is on the left side of 997 | the input core (True), or on the right (False) 998 | """ 999 | def __init__(self, tensor, left_output): 1000 | # Check that our input tensor has the correct shape 1001 | bond_str = 'olri' 1002 | assert len(tensor.shape) == 4 1003 | super().__init__() 1004 | 1005 | # Register our tensor as a Pytorch Parameter 1006 | self.register_parameter(name='tensor', 1007 | param=nn.Parameter(tensor.contiguous())) 1008 | self.left_output = left_output 1009 | 1010 | def forward(self, input_data): 1011 | """ 1012 | Contract input with input index of core and return an OutputCore 1013 | 1014 | Args: 1015 | input_data (Tensor): Input with shape [batch_size, feature_dim] 1016 | """ 1017 | # Check that input_data has the correct shape 1018 | tensor = self.tensor 1019 | assert len(input_data.shape) == 2 1020 | assert input_data.size(1) == tensor.size(3) 1021 | 1022 | # Contract the input with our core tensor 1023 | tensor = torch.einsum('olri,bi->bolr', [tensor, input_data]) 1024 | 1025 | return OutputCore(tensor) 1026 | 1027 | def unmerge(self, cutoff=1e-10): 1028 | """ 1029 | Split our MergedOutput into an OutputSite and an InputSite 1030 | 1031 | The non-zero entries of our tensors are dynamically sized according to 1032 | the SVD cutoff, but will generally be padded with zeros to give the 1033 | new index a regular size. 
1034 | """ 1035 | bond_str = 'olri' 1036 | tensor = self.tensor 1037 | left_output = self.left_output 1038 | if left_output: 1039 | svd_string = 'olri->olu,uri' 1040 | max_D = tensor.size(2) 1041 | sv_vec = torch.empty(max_D) 1042 | 1043 | output_core, input_core, bond_dim = svd_flex(tensor, svd_string, 1044 | max_D, cutoff, sv_vec=sv_vec) 1045 | return ([OutputSite(output_core), InputSite(input_core)], 1046 | [-1, bond_dim, -1], [-1, sv_vec, -1]) 1047 | 1048 | else: 1049 | svd_string = 'olri->our,lui' 1050 | max_D = tensor.size(1) 1051 | sv_vec = torch.empty(max_D) 1052 | 1053 | output_core, input_core, bond_dim = svd_flex(tensor, svd_string, 1054 | max_D, cutoff, sv_vec=sv_vec) 1055 | return ([InputSite(input_core), OutputSite(output_core)], 1056 | [-1, bond_dim, -1], [-1, sv_vec, -1]) 1057 | 1058 | def get_norm(self): 1059 | """ 1060 | Returns the norm of our core tensor, wrapped as a singleton list 1061 | """ 1062 | return [torch.norm(self.tensor)] 1063 | 1064 | @torch.no_grad() 1065 | def rescale_norm(self, scale): 1066 | """ 1067 | Rescales the norm of our core by a factor of input `scale` 1068 | """ 1069 | if isinstance(scale, list): 1070 | assert len(scale) == 1 1071 | scale = scale[0] 1072 | 1073 | self.tensor *= scale 1074 | 1075 | def core_len(self): 1076 | return 2 1077 | 1078 | def __len__(self): 1079 | return 1 1080 | 1081 | class InitialVector(nn.Module): 1082 | """ 1083 | Vector of ones and zeros to act as initial vector within the MPS 1084 | 1085 | By default the initial vector is chosen to be all ones, but if fill_dim is 1086 | specified then only the first fill_dim entries are set to one, with the 1087 | rest zero. 1088 | 1089 | If fixed_vec is False, then the initial vector will be registered as a 1090 | trainable model parameter. 1091 | """ 1092 | def __init__(self, bond_dim, fill_dim=None, fixed_vec=True, 1093 | is_left_vec=True): 1094 | super().__init__() 1095 | 1096 | vec = torch.ones(bond_dim) 1097 | if fill_dim is not None: 1098 | assert fill_dim >= 0 and fill_dim <= bond_dim 1099 | vec[fill_dim:] = 0 1100 | 1101 | if fixed_vec: 1102 | vec.requires_grad = False 1103 | self.register_buffer(name='vec', tensor=vec) 1104 | else: 1105 | vec.requires_grad = True 1106 | self.register_parameter(name='vec', param=nn.Parameter(vec)) 1107 | 1108 | assert isinstance(is_left_vec, bool) 1109 | self.is_left_vec = is_left_vec 1110 | 1111 | def forward(self): 1112 | """ 1113 | Return our initial vector wrapped as an EdgeVec contractable 1114 | """ 1115 | return EdgeVec(self.vec, self.is_left_vec) 1116 | 1117 | def core_len(self): 1118 | return 1 1119 | 1120 | def __len__(self): 1121 | return 0 1122 | 1123 | class TerminalOutput(nn.Module): 1124 | """ 1125 | Output matrix at end of chain to transmute virtual state into output vector 1126 | 1127 | By default, a fixed rectangular identity matrix with shape 1128 | [bond_dim, output_dim] will be used as a state transducer. If fixed_mat is 1129 | False, then the matrix will be registered as a trainable model parameter. 
1130 | """ 1131 | def __init__(self, bond_dim, output_dim, fixed_mat=False, 1132 | is_left_mat=False): 1133 | super().__init__() 1134 | 1135 | # I don't have a nice initialization scheme for a non-injective fixed 1136 | # state transducer, so just throw an error if that's needed 1137 | if fixed_mat and output_dim > bond_dim: 1138 | raise ValueError("With fixed_mat=True, TerminalOutput currently " 1139 | "only supports initialization for bond_dim >= " 1140 | "output_dim, but here bond_dim={bond_dim} and output_dim={output_dim}") 1141 | 1142 | # Initialize the matrix and register it appropriately 1143 | mat = torch.eye(bond_dim, output_dim) 1144 | if fixed_mat: 1145 | mat.requires_grad = False 1146 | self.register_buffer(name='mat', tensor=mat) 1147 | else: 1148 | # Add some noise to help with training 1149 | mat = mat + torch.randn_like(mat) / bond_dim 1150 | 1151 | mat.requires_grad = True 1152 | self.register_parameter(name='mat', param=nn.Parameter(mat)) 1153 | 1154 | assert isinstance(is_left_mat, bool) 1155 | self.is_left_mat = is_left_mat 1156 | 1157 | def forward(self): 1158 | """ 1159 | Return our terminal matrix wrapped as an OutputMat contractable 1160 | """ 1161 | return OutputMat(self.mat, self.is_left_mat) 1162 | 1163 | def core_len(self): 1164 | return 1 1165 | 1166 | def __len__(self): 1167 | return 0 1168 | -------------------------------------------------------------------------------- /models/mps.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timqqt/MERA_Image_Classification/e96211f45ade86f031a0d99ad0670231844ef3a1/models/mps.pyc -------------------------------------------------------------------------------- /train_MERA.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import time 3 | import torch 4 | from models.lotenet import loTeNet, ConvTeNet 5 | from models.MERA import MERAnet_clean as MERAnet 6 | 7 | from torchvision import transforms, datasets 8 | import pdb 9 | from utils.lidc_dataset import LIDC 10 | from utils.tools import * 11 | from models.Densenet import * 12 | from models.Densenet import DenseNet, FrDenseNet 13 | import argparse 14 | 15 | # Globally load device identifier 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | global global_bs 18 | def evaluate(loader): 19 | ### Evaluation funcntion for validation/testing 20 | 21 | with torch.no_grad(): 22 | vl_acc = 0. 23 | vl_loss = 0. 
24 | #labelsNp = np.zeros(1) 25 | #predsNp = np.zeros(1) 26 | model.eval() 27 | 28 | for i, (inputs, labels) in enumerate(loader): 29 | 30 | inputs = inputs.to(device) 31 | labels = labels.to(device) 32 | #labelsNp = np.concatenate((labelsNp, labels.cpu().numpy())) 33 | 34 | # Inference 35 | #scores = torch.sigmoid(model(inputs)) 36 | sm = torch.nn.Softmax(dim=1) 37 | scores = sm(model(inputs)) 38 | 39 | scores = scores.float() 40 | labels = labels.long() 41 | preds = scores 42 | 43 | loss = loss_fun(scores, labels) 44 | #predsNp = np.concatenate((predsNp, preds.cpu().numpy())) 45 | vl_loss += loss.item() 46 | 47 | # Compute AUC on the scores of the final batch (full-set accumulation is commented out above) 48 | vl_acc = computeAuc(labels.detach().cpu().numpy(), preds.detach().cpu().numpy()) 49 | vl_loss = vl_loss/len(loader) 50 | 51 | return vl_acc, vl_loss 52 | 53 | # Miscellaneous initialization 54 | torch.manual_seed(1) 55 | start_time = time.time() 56 | 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('--num_epochs', type=int, default=100, help='Number of training epochs') 59 | parser.add_argument('--batch_size', type=int, default=512, help='Batch size') 60 | parser.add_argument('--lr', type=float, default=5e-4, help='Learning rate') 61 | parser.add_argument('--l2', type=float, default=0, help='L2 regularisation') 62 | parser.add_argument('--aug', action='store_true', default=False, help='Use data augmentation') 63 | parser.add_argument('--data_path', type=str, default='./data/',help='Path to data.') 64 | parser.add_argument('--bond_dim', type=int, default=5, help='MPS Bond dimension') 65 | parser.add_argument('--nChannel', type=int, default=1, help='Number of input channels') 66 | parser.add_argument('--dense_net', action='store_true', 67 | default=False, help='Using Dense Net model') 68 | parser.add_argument('--MERA', action='store_true', 69 | default=False, help='Using MERA Tensor Net model') 70 | parser.add_argument('--kernel', nargs='+', type=int, help='Kernel (stride) size along spatial dimensions') 71 | 72 | parser.add_argument( 73 | "--gpu", 74 | default=0, 75 | help="index of the GPU to use" 76 | ) 77 | args = parser.parse_args() 78 | global_bs = args.batch_size 79 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 80 | device = torch.device('cuda:'+str(args.gpu)) 81 | print(device) 82 | with torch.cuda.device('cuda:'+str(args.gpu)): 83 | batch_size = args.batch_size 84 | 85 | # LoTeNet parameters 86 | adaptive_mode = False 87 | periodic_bc = False 88 | 89 | kernel = args.kernel # Stride along spatial dimensions 90 | output_dim = 2 # output dimension 91 | 92 | feature_dim = 2 93 | 94 | logFile = time.strftime("%Y%m%d_%H_%M")+'.txt' 95 | makeLogFile(logFile) 96 | 97 | normTensor = 0.5*torch.ones(args.nChannel) 98 | ### Data processing and loading....
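# (with mean = std = 0.5 per channel, the Normalize transforms below map inputs from [0, 1] to [-1, 1])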
99 | trans_valid = transforms.Compose([transforms.Normalize(mean=normTensor,std=normTensor)]) 100 | 101 | if args.aug: 102 | trans_train = transforms.Compose([transforms.ToPILImage(), 103 | transforms.RandomHorizontalFlip(), 104 | transforms.RandomVerticalFlip(), 105 | transforms.RandomRotation(20), 106 | transforms.ToTensor(), 107 | transforms.Normalize(mean=normTensor,std=normTensor)]) 108 | print("Using Augmentation....") 109 | else: 110 | trans_train = trans_valid 111 | print("No augmentation....") 112 | 113 | # Load processed LIDC data 114 | dataset_train = LIDC(split='Train', data_dir=args.data_path, 115 | transform=trans_train,rater=4) 116 | dataset_valid = LIDC(split='Valid', data_dir=args.data_path, 117 | transform=trans_valid,rater=4) 118 | dataset_test = LIDC(split='Test', data_dir=args.data_path, 119 | transform=trans_valid,rater=4) 120 | 121 | num_train = len(dataset_train) 122 | num_valid = len(dataset_valid) 123 | num_test = len(dataset_test) 124 | print("Num. train = %d, Num. val = %d"%(num_train,num_valid)) 125 | 126 | loader_train = DataLoader(dataset = dataset_train, drop_last=True, 127 | batch_size=batch_size, shuffle=True) 128 | loader_valid = DataLoader(dataset = dataset_valid, drop_last=True, 129 | batch_size=batch_size, shuffle=False) 130 | loader_test = DataLoader(dataset = dataset_test, drop_last=True, 131 | batch_size=batch_size, shuffle=False) 132 | 133 | # Initialize input dimensions 134 | dim = torch.ShortTensor(list(dataset_train[0][0].shape[1:])) 135 | nCh = int(dataset_train[0][0].shape[0]) 136 | 137 | # Initialize the models 138 | if not args.dense_net: 139 | #print(args.convTN) 140 | if args.MERA: 141 | print("Using MERA") 142 | model = MERAnet(input_dim=dim, output_dim=output_dim, 143 | nCh=nCh, kernel=kernel, 144 | bond_dim=args.bond_dim, feature_dim=feature_dim, 145 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, virtual_dim=1) 146 | else: 147 | print("Using LoTeNet") 148 | model = loTeNet(input_dim=dim, output_dim=output_dim, 149 | nCh=nCh, kernel=kernel, 150 | bond_dim=args.bond_dim, feature_dim=feature_dim, 151 | adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, virtual_dim=1) 152 | else: 153 | print("Densenet Baseline!") 154 | model = FrDenseNet(depth=40, growthRate=12, 155 | reduction=0.5,bottleneck=True,nClasses=output_dim) 156 | 157 | # Choose loss function and optimizer 158 | #loss_fun = torch.nn.BCELoss() 159 | 160 | loss_fun = torch.nn.CrossEntropyLoss() 161 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, 162 | weight_decay=args.l2) 163 | 164 | nParam = sum(p.numel() for p in model.parameters() if p.requires_grad) 165 | print("Number of parameters:%d"%(nParam)) 166 | print(f"Maximum MPS bond dimension = {args.bond_dim}") 167 | with open(logFile, "a") as f: 168 | print("Bond dim: %d"%(args.bond_dim), file=f) 169 | print("Number of parameters:%d"%(nParam), file=f) 170 | 171 | print(f"Using Adam w/ learning rate = {args.lr:.1e}") 172 | print("Feature_dim: %d, nCh: %d, B:%d"%(feature_dim,nCh,batch_size)) 173 | 174 | model = model.to(device) 175 | nValid = len(loader_valid) 176 | nTrain = len(loader_train) 177 | nTest = len(loader_test) 178 | 179 | maxAuc = 0 180 | minLoss = 1e3 181 | convCheck = 5 182 | convIter = 0 183 | 184 | # Let's start training! 185 | for epoch in range(args.num_epochs): 186 | running_loss = 0. 187 | running_acc = 0.
188 | t = time.time() 189 | model.train() 190 | #predsNp = np.zeros(1) 191 | #labelsNp = np.zeros(1) 192 | 193 | for i, (inputs, labels) in enumerate(loader_train): 194 | 195 | inputs = inputs.to(device) 196 | labels = labels.to(device) 197 | #labelsNp = np.concatenate((labelsNp, labels.cpu().numpy())) 198 | sm = torch.nn.Softmax(dim=1) 199 | scores = sm(model(inputs)) 200 | #scores = torch.sigmoid(model(inputs)) 201 | #labels = torch.nn.functional.one_hot(labels.to(torch.int64), num_classes=output_dim) 202 | preds = scores 203 | preds = preds.float() 204 | #print(scores.shape) 205 | # print(labels.shape) 206 | # print(labels.detach().cpu().numpy()) 207 | # assert False 208 | loss = loss_fun(scores, labels.long()) 209 | 210 | with torch.no_grad(): 211 | #predsNp = np.concatenate((predsNp, preds.detach().cpu().numpy())) 212 | running_loss += loss.item() 213 | 214 | # Backpropagate and update parameters 215 | optimizer.zero_grad() 216 | loss.backward() 217 | optimizer.step() 218 | 219 | if (i+1) % 5 == 0: 220 | print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 221 | .format(epoch+1, args.num_epochs, i+1, nTrain, loss.item())) 222 | 223 | accuracy = computeAuc(labels.detach().cpu().numpy(), preds.detach().cpu().numpy()) # AUC on the final training batch 224 | 225 | # Evaluate on Validation set 226 | with torch.no_grad(): 227 | 228 | vl_acc, vl_loss = evaluate(loader_valid) 229 | if vl_acc > maxAuc or vl_loss < minLoss: 230 | if vl_loss < minLoss: 231 | minLoss = vl_loss 232 | if vl_acc > maxAuc: 233 | ### Predict on test set 234 | ts_acc, ts_loss = evaluate(loader_test) 235 | maxAuc = vl_acc 236 | print('New Max: %.4f'%maxAuc) 237 | print('Test Set Loss:%.4f Auc:%.4f'%(ts_loss, ts_acc)) 238 | with open(logFile,"a") as f: 239 | print('Test Set Loss:%.4f Auc:%.4f'%(ts_loss, ts_acc), file=f) 240 | convEpoch = epoch 241 | convIter = 0 242 | else: 243 | convIter += 1 244 | if convIter == convCheck: 245 | if not args.dense_net: 246 | print("MPS") 247 | else: 248 | print("DenseNet") 249 | print("Converged at epoch:%d with AUC:%.4f"%(convEpoch+1,maxAuc)) 250 | 251 | #break 252 | writeLog(logFile, epoch, running_loss/nTrain, accuracy, 253 | vl_loss, vl_acc, time.time()-t) 254 | -------------------------------------------------------------------------------- /utils/MNIST_reader.py: -------------------------------------------------------------------------------- 1 | # https://gist.github.com/kevinzakka/d33bf8d6c7f06a9d8c76d97a7879f5cb#file-data_loader-py 2 | # This is an example for the MNIST dataset (formerly CIFAR-10). 3 | # There's a function for creating a train and validation iterator. 4 | # There's also a function for creating a test iterator. 5 | # Inspired by https://discuss.pytorch.org/t/feedback-on-pytorch-for-kaggle-competitions/2252/4 6 | 7 | # Adapted for MNIST by github.com/MatthewKleinsmith 8 | 9 | import numpy as np 10 | import torch 11 | from torchvision import datasets, transforms 12 | from torch.utils.data.sampler import SubsetRandomSampler 13 | 14 | 15 | def get_train_valid_loader(data_dir, 16 | batch_size, 17 | augment=False, 18 | valid_size=0.2, 19 | shuffle=True, 20 | show_sample=False, 21 | num_workers=1, 22 | pin_memory=True): 23 | """ 24 | Utility function for loading and returning train and valid 25 | multi-process iterators over the MNIST dataset. A sample 26 | 9x9 grid of the images can be optionally displayed. 27 | If using CUDA, num_workers should be set to 1 and pin_memory to True. 28 | Params 29 | ------ 30 | - data_dir: path directory to the dataset. 31 | - batch_size: how many samples per batch to load.
32 | - augment: whether to apply the data augmentation scheme 33 | mentioned in the paper. Only applied on the train split. 34 | - (note: a fixed numpy seed of 1 is used when shuffling, for reproducibility.) 35 | - valid_size: percentage split of the training set used for 36 | the validation set. Should be a float in the range [0, 1]. 37 | - shuffle: whether to shuffle the train/validation indices. 38 | - show_sample: plot 9x9 sample grid of the dataset. 39 | - num_workers: number of subprocesses to use when loading the dataset. 40 | - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to 41 | True if using GPU. 42 | Returns 43 | ------- 44 | - train_loader: training set iterator. 45 | - valid_loader: validation set iterator. 46 | """ 47 | error_msg = "[!] valid_size should be in the range [0, 1]." 48 | assert ((valid_size >= 0) and (valid_size <= 1)), error_msg 49 | 50 | normalize = transforms.Normalize((0.1307,), (0.3081,)) # MNIST 51 | 52 | # define transforms 53 | valid_transform = transforms.Compose([ 54 | transforms.ToTensor(), 55 | normalize 56 | ]) 57 | if augment: 58 | train_transform = transforms.Compose([ 59 | transforms.RandomCrop(28, padding=4), # 28x28 for MNIST (32 was a CIFAR-10 leftover) 60 | transforms.RandomHorizontalFlip(), 61 | transforms.ToTensor(), 62 | normalize 63 | ]) 64 | else: 65 | train_transform = transforms.Compose([ 66 | transforms.ToTensor(), 67 | normalize 68 | ]) 69 | 70 | # load the dataset 71 | train_dataset = datasets.MNIST(root=data_dir, train=True, 72 | download=True, transform=train_transform) 73 | 74 | valid_dataset = datasets.MNIST(root=data_dir, train=True, 75 | download=True, transform=valid_transform) 76 | train_dataset.targets[train_dataset.targets < 1] = 0 # binarize labels: digit 0 vs. digits 1-9 77 | train_dataset.targets[train_dataset.targets >= 1] = 1 78 | 79 | valid_dataset.targets[valid_dataset.targets < 1] = 0 80 | valid_dataset.targets[valid_dataset.targets >= 1] = 1 81 | 82 | num_train = len(train_dataset) 83 | indices = list(range(num_train)) 84 | split = int(np.floor(valid_size * num_train)) 85 | 86 | if shuffle: 87 | np.random.seed(1) 88 | np.random.shuffle(indices) 89 | 90 | train_idx, valid_idx = indices[split:], indices[:split] 91 | 92 | train_sampler = SubsetRandomSampler(train_idx) 93 | valid_sampler = SubsetRandomSampler(valid_idx) 94 | 95 | train_loader = torch.utils.data.DataLoader(train_dataset, 96 | batch_size=batch_size, sampler=train_sampler, 97 | num_workers=num_workers, pin_memory=pin_memory) 98 | 99 | valid_loader = torch.utils.data.DataLoader(valid_dataset, 100 | batch_size=batch_size, sampler=valid_sampler, 101 | num_workers=num_workers, pin_memory=pin_memory) 102 | 103 | # (show_sample visualization of a 9x9 grid is not implemented) 104 | 105 | return (train_loader, valid_loader) 106 | 107 | 108 | def get_test_loader(data_dir, 109 | batch_size, 110 | shuffle=True, 111 | num_workers=1, 112 | pin_memory=True): 113 | """ 114 | Utility function for loading and returning a multi-process 115 | test iterator over the MNIST dataset. 116 | If using CUDA, num_workers should be set to 1 and pin_memory to True. 117 | Params 118 | ------ 119 | - data_dir: path directory to the dataset. 120 | - batch_size: how many samples per batch to load. 121 | - shuffle: whether to shuffle the dataset after every epoch. 122 | - num_workers: number of subprocesses to use when loading the dataset. 123 | - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to 124 | True if using GPU. 125 | Returns 126 | ------- 127 | - data_loader: test set iterator.
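Example (hypothetical paths):
    test_loader = get_test_loader('./data/', batch_size=64, shuffle=False)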
128 | """ 129 | normalize = transforms.Normalize((0.1307,), (0.3081,)) # MNIST 130 | 131 | # define transform 132 | transform = transforms.Compose([ 133 | transforms.ToTensor(), 134 | normalize 135 | ]) 136 | 137 | dataset = datasets.MNIST(root=data_dir, 138 | train=False, 139 | download=True, 140 | transform=transform) 141 | 142 | dataset.targets[dataset.targets < 1] = 0 143 | dataset.targets[dataset.targets >= 1] = 1 144 | 145 | data_loader = torch.utils.data.DataLoader(dataset, 146 | batch_size=batch_size, 147 | shuffle=shuffle, 148 | num_workers=num_workers, 149 | pin_memory=pin_memory) 150 | 151 | return data_loader -------------------------------------------------------------------------------- /utils/lidc_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | from torch.utils.data import TensorDataset, DataLoader, Dataset 8 | import pdb 9 | 10 | class LIDC(Dataset): 11 | def __init__(self, rater=4, split='Train', data_dir = './', transform=None): 12 | super().__init__() 13 | 14 | self.data_dir = data_dir 15 | self.rater = rater 16 | self.transform = transform 17 | self.data, self.targets = torch.load(data_dir+split+'.pt') 18 | self.targets = self.targets.type(torch.FloatTensor) 19 | def __len__(self): 20 | return len(self.targets) 21 | 22 | def __getitem__(self, index): 23 | 24 | image, label = self.data[index], self.targets[index] 25 | if self.rater == 4: 26 | label = (label.sum() > 2).type_as(self.targets) 27 | else: 28 | label = label[self.rater] 29 | image = image.type(torch.FloatTensor)/255.0 30 | if self.transform is not None: 31 | image = self.transform(image) 32 | return image, label 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /utils/needle_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | 9 | import os 10 | import torch 11 | import argparse 12 | import numpy as np 13 | from scipy.ndimage import zoom 14 | import pandas as pd 15 | import numpy as np 16 | import os 17 | from torch.utils.data.dataset import Dataset 18 | # Copyright (c) Facebook, Inc. and its affiliates. 19 | # All rights reserved. 20 | # 21 | # This source code is licensed under the license found in the 22 | # LICENSE file in the root directory of this source tree. 
23 | # 24 | from scipy.ndimage import zoom 25 | import pandas as pd 26 | import numpy as np 27 | import os 28 | from torch.utils.data.dataset import Dataset 29 | 30 | class ClutteredMNISTDataset(Dataset): 31 | reg_dataset_size = 11276 32 | 33 | def __init__(self, base_path, csv_path, data_scaling=1., num_examples=None, balance=None): 34 | self.base_path = base_path 35 | self.csv_path = csv_path 36 | self.csv = pd.read_csv(csv_path) 37 | self.data_scaling = data_scaling 38 | self.num_examples = num_examples 39 | self.balance = balance 40 | 41 | self.img_paths = self.csv['img_path'].values 42 | self.lbls = self.csv['label'].values.astype(np.int32) 43 | self.weights = np.ones([len(self.img_paths), ]) 44 | 45 | if self.num_examples is not None and self.balance is not None: 46 | # only rebalance if examples and balance is given 47 | assert self.num_examples <= len(self.img_paths), \ 48 | 'not enough examples in dataset {} - {}'.format(self.num_examples, len(self.img_paths)) 49 | 50 | pos_num = int(self.balance * self.num_examples) 51 | neg_num = self.num_examples - pos_num 52 | 53 | pos_mask = (self.lbls == 1) 54 | pos_paths = self.img_paths[pos_mask][:pos_num] 55 | 56 | neg_mask = (self.lbls == 0) 57 | neg_paths = self.img_paths[neg_mask][:neg_num] 58 | 59 | self.img_paths = np.concatenate([pos_paths, neg_paths], 0) 60 | self.lbls = np.concatenate([np.ones([pos_num, ]), np.zeros([neg_num, ])], 0) 61 | self.weights = np.ones([self.num_examples, ]) 62 | self.weights[:pos_num] /= pos_num 63 | self.weights[pos_num:] /= neg_num 64 | 65 | self.shrinkage = self.reg_dataset_size // len(self.img_paths) 66 | 67 | def __len__(self): 68 | return len(self.img_paths) 69 | 70 | def __getitem__(self, index): 71 | path = os.path.join(self.base_path, self.img_paths[index]) 72 | lbl = self.lbls[index].astype(np.int64) 73 | 74 | img = np.load(path) 75 | if isinstance(img, np.lib.npyio.NpzFile): 76 | img = img['arr_0'] 77 | 78 | if self.data_scaling != 1.: 79 | img = zoom(img, self.data_scaling) 80 | img = img.clip(0., 1.) 81 | 82 | img = img[np.newaxis].astype(np.float32) 83 | 84 | return img, lbl 85 | 86 | def NeedleMNIST(data_path, data_scaling, data_balance, batch_size, num_examples): 87 | dataset = ClutteredMNISTDataset 88 | 89 | train_dataset = dataset(data_path, os.path.join(data_path, 'train.csv'), 90 | data_scaling, num_examples, data_balance) 91 | 92 | weights = torch.Tensor(train_dataset.weights) 93 | # over and undersamples.. 
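# (when num_examples and balance are given above, each positive example gets weight 1/pos_num and each
# negative 1/neg_num, so the WeightedRandomSampler below draws the two classes evenly; otherwise weights are uniform)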
94 | sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(train_dataset)) 95 | 96 | train_loader = torch.utils.data.DataLoader( 97 | train_dataset, batch_size=batch_size, num_workers=8, sampler=sampler, pin_memory=True) 98 | 99 | val_dataset = dataset(data_path, os.path.join(data_path, 'val.csv'), data_scaling) 100 | val_loader = torch.utils.data.DataLoader( 101 | val_dataset, batch_size=batch_size, shuffle=False, num_workers=8) 102 | 103 | test_dataset = dataset(data_path, os.path.join(data_path, 'test.csv'), data_scaling) 104 | test_loader = torch.utils.data.DataLoader( 105 | test_dataset, batch_size=batch_size, num_workers=8, shuffle=False) 106 | 107 | return train_loader, val_loader, test_loader 108 | 109 | 110 | 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import torch 4 | from os.path import isfile 5 | from os import rename 6 | SMOOTH=1 7 | import pdb 8 | from sklearn.metrics import auc, roc_curve 9 | import torch.nn.functional as F 10 | import torch.nn as nn 11 | from PIL.ImageFilter import GaussianBlur 12 | from scipy.ndimage import gaussian_filter # used by GaussianLayer.weights_init 13 | import matplotlib.pyplot as plt # used by plotLearningCurve 14 | def wCELoss(prediction, target): 15 | w1 = 1.33 # False negative penalty 16 | w2 = .66 # False positive penalty 17 | return -torch.mean(w1 * target * torch.log(prediction.clamp_min(1e-3)) 18 | + w2 * (1. - target) * torch.log(1. - prediction.clamp_max(.999))) 19 | 20 | class GaussianFilter(object): 21 | """Apply Gaussian blur to the PIL image 22 | Args: 23 | sigma (float): Sigma of Gaussian kernel. Default value 1.0 24 | """ 25 | def __init__(self, sigma=1): 26 | self.sigma = sigma 27 | self.filter = GaussianBlur(radius=sigma) 28 | 29 | def __call__(self, img): 30 | """ 31 | Args: 32 | img (PIL Image): Image to be blurred. 33 | 34 | Returns: 35 | PIL Image: Blurred image.
36 | """ 37 | return img.filter(self.filter) 38 | 39 | def __repr__(self): 40 | return self.__class__.__name__ + '(sigma={})'.format(self.sigma) 41 | 42 | class GaussianLayer(nn.Module): 43 | def __init__(self): 44 | super(GaussianLayer, self, sigma=1, size=10).__init__() 45 | self.sigma = sigma 46 | self.size = size 47 | self.seq = nn.Sequential( 48 | nn.ReflectionPad2d(size), 49 | nn.Conv2d(3, 3, size, stride=1, padding=0, bias=None, groups=3) 50 | ) 51 | self.weights_init() 52 | 53 | def forward(self, x): 54 | return self.seq(x) 55 | 56 | def weights_init(self): 57 | s = self.size * 2 + 1 58 | k = np.zeros((s,s)) 59 | k[s,s] = 1 60 | kernel = gaussian_filter(k,sigma=self.sigma) 61 | for name, f in self.named_parameters(): 62 | f.data.copy_(torch.from_numpy(kernel)) 63 | 64 | class focalLoss(nn.Module): 65 | def __init__(self, alpha=1, gamma=2, logits=False, reduce=True): 66 | super(focalLoss, self).__init__() 67 | self.alpha = alpha 68 | self.gamma = gamma 69 | self.logits = logits 70 | self.reduce = reduce 71 | 72 | def forward(self, inputs, targets): 73 | if self.logits: 74 | BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False) 75 | else: 76 | BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False) 77 | pt = torch.exp(-BCE_loss) 78 | F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss 79 | 80 | if self.reduce: 81 | return torch.mean(F_loss) 82 | else: 83 | return F_loss 84 | 85 | def computeAuc(target,preds): 86 | print(preds.shape) 87 | if preds.shape[1] == 2: 88 | print("Two ouput dim mode") 89 | preds = np.argmax(preds, axis=1) 90 | print(preds.shape) 91 | else: 92 | preds = np.where(preds >= 0.5, 1, 0) 93 | #assert False 94 | aucVal = np.sum(preds == target)/ np.sum(np.ones_like(target)) 95 | #fpr, tpr, thresholds = roc_curve(target,preds) 96 | #aucVal = auc(fpr,tpr) 97 | return aucVal 98 | 99 | class hingeLoss(torch.nn.Module): 100 | 101 | def __init__(self): 102 | super(hingeLoss, self).__init__() 103 | 104 | def forward(self, output, target): 105 | # pdb.set_trace() 106 | target = 2*target-1 107 | output = 2*output-1 108 | hinge_loss = 1 - torch.mul(output, target) 109 | hinge_loss[hinge_loss < 0] = 0 110 | return hinge_loss.mean() 111 | 112 | 113 | def makeBatchAdj(adj,bSize): 114 | 115 | E = adj._nnz() 116 | N = adj.shape[0] 117 | batch_idx = torch.zeros(2,bSize*E).type(torch.LongTensor) 118 | batch_val = torch.zeros(bSize*E) 119 | 120 | idx = adj._indices() 121 | vals = adj._values() 122 | 123 | for i in range(bSize): 124 | batch_idx[:,i*E:(i+1)*E] = idx + i*N 125 | batch_val[i*E:(i+1)*E] = vals 126 | 127 | return torch.sparse.FloatTensor(batch_idx,batch_val,(bSize*N,bSize*N)) 128 | 129 | 130 | def makeAdj(ngbrs, normalize=True): 131 | """ Create an adjacency matrix, given the neighbour indices 132 | Input: Nxd neighbourhood, where N is number of nodes 133 | Output: NxN sparse torch adjacency matrix 134 | """ 135 | # pdb.set_trace() 136 | N, d = ngbrs.shape 137 | validNgbrs = (ngbrs >= 0) # Mask for valid neighbours amongst the d-neighbours 138 | row = np.repeat(np.arange(N),d) # Row indices like in sparse matrix formats 139 | row = row[validNgbrs.reshape(-1)] #Remove non-neighbour row indices 140 | col = (ngbrs*validNgbrs).reshape(-1) # Obtain nieghbour col indices 141 | col = col[validNgbrs.reshape(-1)] # Remove non-neighbour col indices 142 | data = np.ones(col.size) 143 | adj = sp.csr_matrix((np.ones(col.size, dtype=bool),(row, col)), shape=(N, N)).toarray() # Make adj matrix 144 | adj = adj + np.eye(N) # Self connections 145 | adj = 
sp.csr_matrix(adj, dtype=np.float32)#/(d+1) 146 | if normalize: 147 | adj = row_normalize(adj) 148 | adj = sparse_mx_to_torch_sparse_tensor(adj) 149 | 150 | return adj 151 | 152 | def makeRegAdj(xdim, ydim, zdim, numNgbrs=26): 153 | """ Make regular pixel neighbourhoods for an (xdim, ydim, zdim) volume""" 154 | idx = 0 155 | ngbrOffset = np.zeros((3,numNgbrs),dtype=int) 156 | for i in range(-1,2): 157 | for j in range(-1,2): 158 | for k in range(-1,2): 159 | if(i | j | k): 160 | ngbrOffset[:,idx] = [i,j,k] 161 | idx+=1 162 | idx = 0 163 | ngbrs = np.zeros((xdim*ydim*zdim, numNgbrs), dtype=int) 164 | idxVol = np.arange(xdim*ydim*zdim).reshape(xdim, ydim, zdim) # linear voxel indices 165 | for i in range(xdim): 166 | for j in range(ydim): 167 | for k in range(zdim): 168 | xIdx = np.mod(ngbrOffset[0,:]+i,xdim) 169 | yIdx = np.mod(ngbrOffset[1,:]+j,ydim) 170 | zIdx = np.mod(ngbrOffset[2,:]+k,zdim) 171 | ngbrs[idx,:] = idxVol[xIdx, yIdx, zIdx] 172 | idx += 1 173 | return ngbrs 174 | 175 | def makeAdjWithInvNgbrs(ngbrs, normalize=False): 176 | """ Create an adjacency matrix, given the neighbour indices including invalid indices where self connections are added. 177 | Input: Nxd neighbourhood, where N is number of nodes 178 | Output: NxN sparse torch adjacency matrix 179 | """ 180 | np.random.seed(2) 181 | # pdb.set_trace() 182 | N, d = ngbrs.shape 183 | row = np.arange(N).reshape(-1,1) 184 | random = np.random.randint(0,N-1,(N,d)) 185 | valIdx = np.array((ngbrs < 0),dtype=int) # Mask for invalid entries amongst the d-neighbours 186 | ngbrs = random*valIdx + ngbrs*(1-valIdx) # Replace invalid neighbours by random nodes 187 | row = np.repeat(row,d).reshape(-1) # Row indices like in sparse matrix formats 188 | col = ngbrs.reshape(-1) # Obtain neighbour col indices 189 | data = np.ones(col.size) 190 | adj = sp.csr_matrix((np.ones(col.size, dtype=bool),(row, col)), shape=(N, N)).toarray() # Make adj matrix 191 | adj = adj + np.eye(N) # Self connections 192 | adj = sp.csr_matrix(adj, dtype=np.float32)#/(d+1) 193 | if normalize: 194 | adj = row_normalize(adj) 195 | adj = sparse_mx_to_torch_sparse_tensor(adj) 196 | adj = adj.coalesce() 197 | adj._values = adj.values() 198 | return adj 199 | 200 | 201 | def transformers(adj): 202 | """ Obtain source and sink node transformer matrices""" 203 | edges = adj._indices() 204 | N = adj.shape[0] 205 | nnz = adj._nnz() 206 | val = torch.ones(nnz) 207 | idx0 = torch.arange(nnz) 208 | 209 | idx = torch.stack((idx0,edges[1,:])) 210 | n2e_in = torch.sparse.FloatTensor(idx,val,(nnz,N)) 211 | 212 | idx = torch.stack((idx0,edges[0,:])) 213 | n2e_out = torch.sparse.FloatTensor(idx,val,(nnz,N)) 214 | 215 | return n2e_in, n2e_out 216 | 217 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 218 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 219 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 220 | indices = torch.from_numpy(np.vstack((sparse_mx.row, 221 | sparse_mx.col))).long() 222 | values = torch.from_numpy(sparse_mx.data) 223 | shape = torch.Size(sparse_mx.shape) 224 | return torch.sparse.FloatTensor(indices, values, shape) 225 | 226 | def to_linear_idx(x_idx, y_idx, num_cols): 227 | assert num_cols > np.max(x_idx) 228 | x_idx = np.array(x_idx, dtype=np.int32) 229 | y_idx = np.array(y_idx, dtype=np.int32) 230 | return y_idx * num_cols + x_idx 231 | 232 | 233 | def row_normalize(mx): 234 | """Row-normalize sparse matrix""" 235 | rowsum = np.array(mx.sum(1), dtype=np.float32) 236 | r_inv = np.power(rowsum, -1).flatten() 237 | r_inv[np.isinf(r_inv)] = 0.
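# (zero-sum rows produce inf in the line above and are reset to 0, so empty rows stay all-zero after normalization)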
238 | r_mat_inv = sp.diags(r_inv) 239 | mx = r_mat_inv.dot(mx) 240 | return mx 241 | 242 | def to_2d_idx(idx, num_cols): 243 | idx = np.array(idx, dtype=np.int64) 244 | y_idx = np.array(np.floor(idx / float(num_cols)), dtype=np.int64) 245 | x_idx = idx % num_cols 246 | return x_idx, y_idx 247 | 248 | def dice_loss(preds, labels): 249 | "Return soft Dice loss." 250 | preds_sq = preds**2 251 | return 1 - (2. * (torch.sum(preds * labels)) + SMOOTH) / \ 252 | (preds_sq.sum() + labels.sum() + SMOOTH) 253 | 254 | def binary_accuracy(output, labels): 255 | preds = output > 0.5 256 | correct = preds.type_as(labels).eq(labels).double() 257 | correct = correct.sum() 258 | return correct / len(labels) 259 | 260 | def multiClassAccuracy(output, labels): 261 | # pdb.set_trace() 262 | preds = output.argmax(1) 263 | # preds = (output > (1.0/labels.shape[1])).type_as(labels) 264 | correct = (preds == labels.view(-1)) 265 | correct = correct.sum().float() 266 | return correct / len(labels) 267 | 268 | def regrAcc(output, labels): 269 | # pdb.set_trace() 270 | preds = output.round().type(torch.long).type_as(labels) 271 | # preds = (output > (1.0/labels.shape[1])).type_as(labels) 272 | correct = (preds == labels.view(-1)) 273 | correct = correct.sum().float() 274 | return correct / len(labels) 275 | 276 | 277 | def rescaledRegAcc(output,labels,lRange=37,lMin=-20): 278 | # pdb.set_trace() 279 | preds = (output+1)*(lRange)/2 + lMin 280 | preds = preds.round().type(torch.long).type_as(labels) 281 | # preds = (output > (1.0/labels.shape[1])).type_as(labels) 282 | correct = (preds == labels.view(-1)) 283 | correct = correct.sum().float() 284 | return correct / len(labels) 285 | 286 | def focalCE(preds, labels, gamma=1): 287 | "Return focal cross entropy" 288 | loss = -torch.mean( ( ((1-preds)**gamma) * labels * torch.log(preds) ) \ 289 | + ( ((preds)**gamma) * (1-labels) * torch.log(1-preds) ) ) 290 | return loss 291 | 292 | def dice(preds, labels): 293 | # pdb.set_trace() 294 | "Return dice score" 295 | preds_bin = (preds > 0.5).type_as(labels) 296 | return 2. * torch.sum(preds_bin * labels) / (preds_bin.sum() + labels.sum()) 297 | 298 | def wBCE(preds, labels, w): 299 | "Return weighted CE loss."
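# i.e. loss = -mean( w*y*log(p) + (1-w)*(1-y)*log(1-p) ), with targets y in {0, 1} and weight w on the positive term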
300 | return -torch.mean( w*labels*torch.log(preds) + (1-w)*(1-labels)*torch.log(1-preds) ) 301 | 302 | def makeLogFile(filename="lossHistory.txt"): 303 | if isfile(filename): 304 | rename(filename,"lossHistoryOld.txt") 305 | 306 | with open(filename,"w") as text_file: 307 | print('Epoch\tlossTr\taccTr\tlossVl\taccVl\ttime(s)',file=text_file) 308 | print("Log file created...") 309 | return 310 | 311 | def writeLog(logFile, epoch, lossTr, accTr, lossVl, accVl,eTime): 312 | print('Epoch:{:04d}\t'.format(epoch + 1), 313 | 'lossTr:{:.4f}\t'.format(lossTr), 314 | 'accTr:{:.4f}\t'.format(accTr), 315 | 'lossVl:{:.4f}\t'.format(lossVl), 316 | 'accVl:{:.4f}\t'.format(accVl), 317 | 'time:{:.4f}'.format(eTime)) 318 | 319 | with open(logFile,"a") as text_file: 320 | print('{:04d}\t'.format(epoch + 1), 321 | '{:.4f}\t'.format(lossTr), 322 | '{:.4f}\t'.format(accTr), 323 | '{:.4f}\t'.format(lossVl), 324 | '{:.4f}\t'.format(accVl), 325 | '{:.4f}'.format(eTime),file=text_file) 326 | return 327 | 328 | def plotLearningCurve(): 329 | plt.clf() 330 | tmp = np.load('loss_tr.npz')['arr_0'] 331 | plt.plot(tmp,label='Tr.Loss') 332 | tmp = np.load('loss_vl.npz')['arr_0'] 333 | plt.plot(tmp,label='Vl.Loss') 334 | tmp = np.load('dice_tr.npz')['arr_0'] 335 | plt.plot(tmp,label='Tr.Dice') 336 | tmp = np.load('dice_vl.npz')['arr_0'] 337 | plt.plot(tmp,label='Vl.Dice') 338 | plt.legend() 339 | plt.grid() 340 | plt.show() 341 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def svd_flex(tensor, svd_string, max_D=None, cutoff=1e-10, sv_right=True, 5 | sv_vec=None): 6 | """ 7 | Split an input tensor into two pieces using a SVD across some partition 8 | 9 | Args: 10 | tensor (Tensor): Pytorch tensor with at least two indices 11 | 12 | svd_string (str): String of the form 'init_str->left_str,right_str', 13 | where init_str describes the indices of tensor, and 14 | left_str/right_str describe those of the left and 15 | right output tensors. The characters of left_str 16 | and right_str form a partition of the characters in 17 | init_str, but each contains one additional character 18 | representing the new bond which comes from the SVD 19 | 20 | Reversing the terms in svd_string to the left and 21 | right of '->' gives an ein_string which can be used 22 | to multiply both output tensors to give a (low rank 23 | approximation) of the input tensor 24 | 25 | cutoff (float): A truncation threshold which eliminates any 26 | singular values which are strictly less than cutoff 27 | 28 | max_D (int): A maximum allowed value for the new bond. If max_D 29 | is specified, the returned tensors are zero-padded up to size max_D along the new bond 30 | 31 | sv_right (bool): The SVD gives two orthogonal matrices and a matrix 32 | of singular values. sv_right=True merges the SV 33 | matrix with the right output, while sv_right=False 34 | merges it with the left output 35 | 36 | sv_vec (Tensor): Pytorch vector with length max_D, which is modified 37 | in place to return the vector of singular values 38 | 39 | Returns: 40 | left_tensor (Tensor), 41 | right_tensor (Tensor): Tensors whose indices are described by the 42 | left_str and right_str parts of svd_string 43 | 44 | bond_dim: The dimension of the new bond appearing from 45 | the cutoff in our SVD.
Note that this generally 46 | won't match the dimension of left_/right_tensor 47 | at this mode, which is padded with zeros 48 | whenever max_D is specified 49 | """ 50 | def prod(int_list): 51 | output = 1 52 | for num in int_list: 53 | output *= num 54 | return output 55 | 56 | with torch.no_grad(): 57 | # Parse svd_string into init_str, left_str, and right_str 58 | svd_string = svd_string.replace(' ', '') 59 | init_str, post_str = svd_string.split('->') 60 | left_str, right_str = post_str.split(',') 61 | 62 | # Check formatting of init_str, left_str, and right_str 63 | assert all([c.islower() for c in init_str+left_str+right_str]) 64 | assert len(set(init_str+left_str+right_str)) == len(init_str) + 1 65 | assert len(set(init_str))+len(set(left_str))+len(set(right_str)) == \ 66 | len(init_str)+len(left_str)+len(right_str) 67 | 68 | # Get the special character representing our SVD-truncated bond 69 | bond_char = set(left_str).intersection(set(right_str)).pop() 70 | left_part = left_str.replace(bond_char, '') 71 | right_part = right_str.replace(bond_char, '') 72 | 73 | # Permute our tensor into something that can be viewed as a matrix 74 | ein_str = f"{init_str}->{left_part + right_part}" 75 | tensor = torch.einsum(ein_str, [tensor]).contiguous() 76 | 77 | left_shape = list(tensor.shape[:len(left_part)]) 78 | right_shape = list(tensor.shape[len(left_part):]) 79 | left_dim, right_dim = prod(left_shape), prod(right_shape) 80 | 81 | tensor = tensor.view([left_dim, right_dim]) 82 | 83 | # Get SVD and format so that left_mat * diag(svs) * right_mat = tensor 84 | left_mat, svs, right_mat = torch.svd(tensor) 85 | svs, _ = torch.sort(svs, descending=True) 86 | right_mat = torch.t(right_mat) 87 | 88 | # Decrease or increase our tensor sizes in the presence of max_D 89 | if max_D and len(svs) > max_D: 90 | svs = svs[:max_D] 91 | left_mat = left_mat[:, :max_D] 92 | right_mat = right_mat[:max_D] 93 | elif max_D and len(svs) < max_D: 94 | copy_svs = torch.zeros([max_D]) 95 | copy_svs[:len(svs)] = svs 96 | copy_left = torch.zeros([left_mat.size(0), max_D]) 97 | copy_left[:, :left_mat.size(1)] = left_mat 98 | copy_right = torch.zeros([max_D, right_mat.size(1)]) 99 | copy_right[:right_mat.size(0)] = right_mat 100 | svs, left_mat, right_mat = copy_svs, copy_left, copy_right 101 | 102 | # If given as input, copy singular values into sv_vec 103 | if sv_vec is not None and svs.shape == sv_vec.shape: 104 | sv_vec[:] = svs 105 | elif sv_vec is not None and svs.shape != sv_vec.shape: 106 | raise TypeError(f"sv_vec.shape must be {list(svs.shape)}, but is currently {list(sv_vec.shape)}") 107 | 108 | # Find the truncation point relative to our singular value cutoff 109 | truncation = 0 110 | for s in svs: 111 | if s < cutoff: 112 | break 113 | truncation += 1 114 | if truncation == 0: 115 | raise RuntimeError("SVD cutoff too large, attempted to truncate " 116 | "tensor to bond dimension 0") 117 | 118 | # Perform the actual truncation 119 | if max_D: 120 | svs[truncation:] = 0 121 | left_mat[:, truncation:] = 0 122 | right_mat[truncation:] = 0 123 | else: 124 | # If max_D wasn't given, set it to the truncation index 125 | max_D = truncation 126 | svs = svs[:truncation] 127 | left_mat = left_mat[:, :truncation] 128 | right_mat = right_mat[:truncation] 129 | 130 | # Merge the singular values into the appropriate matrix 131 | if sv_right: 132 | right_mat = torch.einsum('l,lr->lr', [svs, right_mat]) 133 | else: 134 | left_mat = torch.einsum('lr,r->lr', [left_mat, svs]) 135 | 136 | # Reshape the matrices to make them proper
tensors 137 | left_tensor = left_mat.view(left_shape+[max_D]) 138 | right_tensor = right_mat.view([max_D]+right_shape) 139 | 140 | # Finally, permute the indices into the desired order 141 | if left_str != left_part + bond_char: 142 | left_tensor = torch.einsum(f"{left_part + bond_char}->{left_str}", 143 | [left_tensor]) 144 | if right_str != bond_char + right_part: 145 | right_tensor = torch.einsum(f"{bond_char + right_part}->{right_str}", 146 | [right_tensor]) 147 | 148 | return left_tensor, right_tensor, truncation 149 | 150 | def init_tensor(shape, bond_str, init_method): 151 | """ 152 | Initialize a tensor with a given shape 153 | 154 | Args: 155 | shape: The shape of our output parameter tensor. 156 | 157 | bond_str: The bond string describing our output parameter tensor, 158 | which is used in 'random_eye' initialization method. 159 | The characters 'l' and 'r' are used to refer to the 160 | left or right virtual indices of our tensor, and are 161 | both required to be present for the random_eye and 162 | min_random_eye initialization methods. 163 | 164 | init_method: The method used to initialize the entries of our tensor. 165 | This can be either a string, or else a tuple whose first 166 | entry is an initialization method and whose remaining 167 | entries are specific to that method. In each case, std 168 | will always refer to a standard deviation for a random 169 | normal component of each entry of the tensor. 170 | 171 | Allowed options are: 172 | * ('random_eye', std): Initialize each tensor input 173 | slice close to the identity 174 | * ('random_zero', std): Initialize each tensor input 175 | slice close to the zero matrix 176 | * ('min_random_eye', std, init_dim): Initialize each 177 | tensor input slice close to a truncated identity 178 | matrix, whose truncation leaves init_dim unit 179 | entries on the diagonal. If init_dim is larger 180 | than either of the bond dimensions, then init_dim 181 | is capped at the smaller bond dimension. 182 | """ 183 | # Unpack init_method if it is a tuple 184 | if not isinstance(init_method, str): 185 | init_str = init_method[0] 186 | std = init_method[1] 187 | if init_str == 'min_random_eye': 188 | init_dim = init_method[2] 189 | 190 | init_method = init_str 191 | else: 192 | std = 1e-9 193 | 194 | # Check that bond_str is properly sized and doesn't have repeat indices 195 | assert len(shape) == len(bond_str) 196 | assert len(set(bond_str)) == len(bond_str) 197 | 198 | if init_method not in ['random_eye', 'min_random_eye', 'random_zero']: 199 | raise ValueError(f"Unknown initialization method: {init_method}") 200 | 201 | if init_method in ['random_eye', 'min_random_eye']: 202 | bond_chars = ['l', 'r'] 203 | assert all([c in bond_str for c in bond_chars]) 204 | 205 | # Initialize our tensor slices as identity matrices which each fill 206 | # some or all of the initially allocated bond space 207 | if init_method == 'min_random_eye': 208 | 209 | # The dimensions for our initial identity matrix.
These will each 210 | # be init_dim, unless init_dim exceeds one of the bond dimensions 211 | bond_dims = [shape[bond_str.index(c)] for c in bond_chars] 212 | if all([init_dim <= full_dim for full_dim in bond_dims]): 213 | bond_dims = [init_dim, init_dim] 214 | else: 215 | init_dim = min(bond_dims) 216 | bond_dims = [init_dim, init_dim] 217 | eye_shape = [init_dim if c in bond_chars else 1 for c in bond_str] 218 | expand_shape = [init_dim if c in bond_chars else shape[i] 219 | for i, c in enumerate(bond_str)] 220 | 221 | elif init_method == 'random_eye': 222 | eye_shape = [shape[i] if c in bond_chars else 1 223 | for i, c in enumerate(bond_str)] 224 | expand_shape = shape 225 | bond_dims = [shape[bond_str.index(c)] for c in bond_chars] 226 | 227 | eye_tensor = torch.eye(bond_dims[0], bond_dims[1]).view(eye_shape) 228 | eye_tensor = eye_tensor.expand(expand_shape) 229 | 230 | tensor = torch.zeros(shape) 231 | tensor[tuple(slice(dim) for dim in expand_shape)] = eye_tensor 232 | 233 | # Add on a bit of random noise 234 | tensor += std * torch.randn(shape) 235 | 236 | elif init_method == 'random_zero': 237 | tensor = std * torch.randn(shape) 238 | 239 | return tensor 240 | 241 | 242 | ### OLDER MISCELLANEOUS FUNCTIONS ### 243 | 244 | def onehot(labels, max_value): 245 | """ 246 | Convert a batch of labels from the set {0, 1,..., max_value-1} into their 247 | onehot encoded counterparts 248 | """ 249 | label_vecs = torch.zeros([len(labels), max_value]) 250 | 251 | for i, label in enumerate(labels): 252 | label_vecs[i, label] = 1. 253 | 254 | return label_vecs 255 | 256 | def joint_shuffle(input_data, input_labels): 257 | """ 258 | Shuffle input data and labels in a joint manner, so each label points to 259 | its corresponding datum. Works for both regular and CUDA tensors 260 | """ 261 | assert input_data.is_cuda == input_labels.is_cuda 262 | use_gpu = input_data.is_cuda 263 | if use_gpu: 264 | input_data, input_labels = input_data.cpu(), input_labels.cpu() 265 | 266 | data, labels = input_data.numpy(), input_labels.numpy() 267 | 268 | # Shuffle relative to the same seed 269 | np.random.seed(0) 270 | np.random.shuffle(data) 271 | np.random.seed(0) 272 | np.random.shuffle(labels) 273 | 274 | data, labels = torch.from_numpy(data), torch.from_numpy(labels) 275 | if use_gpu: 276 | data, labels = data.cuda(), labels.cuda() 277 | 278 | return data, labels 279 | 280 | def load_HV_data(length): 281 | """ 282 | Output a toy "horizontal/vertical" data set of black and white 283 | images with size length x length. Each image contains a single 284 | horizontal or vertical stripe, set against a background 285 | of the opposite color. The labels associated with these images 286 | are either 0 (horizontal stripe) or 1 (vertical stripe). 287 | 288 | In its current version, this returns two data sets, a training 289 | set with 75% of the images and a test set with 25% of the 290 | images.
291 | """ 292 | num_images = 4 * (2**(length-1) - 1) 293 | num_patterns = num_images // 2 294 | split = num_images // 4 295 | 296 | if length > 14: 297 | print("load_HV_data will generate {} images, " 298 | "this could take a while...".format(num_images)) 299 | 300 | images = np.empty([num_images,length,length], dtype=np.float32) 301 | labels = np.empty(num_images, dtype=np.int) 302 | 303 | # Used to generate the stripe pattern from integer i below 304 | template = "{:0" + str(length) + "b}" 305 | 306 | for i in range(1, num_patterns+1): 307 | pattern = template.format(i) 308 | pattern = [int(s) for s in pattern] 309 | 310 | for j, val in enumerate(pattern): 311 | # Horizontal stripe pattern 312 | images[2*i-2, j, :] = val 313 | # Vertical stripe pattern 314 | images[2*i-1, :, j] = val 315 | 316 | labels[2*i-2] = 0 317 | labels[2*i-1] = 1 318 | 319 | # Shuffle and partition into training and test sets 320 | np.random.seed(0) 321 | np.random.shuffle(images) 322 | np.random.seed(0) 323 | np.random.shuffle(labels) 324 | 325 | train_images, train_labels = images[split:], labels[split:] 326 | test_images, test_labels = images[:split], labels[:split] 327 | 328 | return torch.from_numpy(train_images), \ 329 | torch.from_numpy(train_labels), \ 330 | torch.from_numpy(test_images), \ 331 | torch.from_numpy(test_labels) 332 | -------------------------------------------------------------------------------- /utils/utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timqqt/MERA_Image_Classification/e96211f45ade86f031a0d99ad0670231844ef3a1/utils/utils.pyc --------------------------------------------------------------------------------