├── Datasets ├── gas_data │ └── read_gas_data.py ├── merl │ ├── frame_01.png │ ├── frame_02.png │ ├── frame_03.png │ ├── frame_04.png │ ├── frame_05.png │ ├── frame_06.png │ ├── frame_07.png │ ├── frame_08.png │ ├── frame_09.png │ ├── frame_10.png │ ├── frame_11.png │ ├── frame_12.png │ ├── frame_13.png │ ├── frame_14.png │ ├── frame_15.png │ ├── frame_16.png │ ├── frame_17.png │ ├── frame_18.png │ ├── frame_19.png │ ├── frame_20.png │ ├── frame_21.png │ ├── frame_22.png │ ├── frame_23.png │ ├── frame_24.png │ ├── frame_25.png │ ├── frame_26.png │ ├── frame_27.png │ ├── frame_28.png │ ├── frame_29.png │ ├── frame_30.png │ ├── frame_31.png │ ├── frame_32.png │ ├── frame_33.png │ ├── frame_34.png │ ├── frame_35.png │ ├── frame_36.png │ ├── frame_37.png │ └── frame_38.png └── mri_data │ ├── aperiodic_pincat.mat │ ├── imag_mri.mat │ ├── invivo_perfusion.mat │ └── real_mri.mat ├── README.md └── jupyter_notebooks ├── .ipynb_checkpoints ├── gas_sensor_array_experiments-checkpoint.ipynb ├── toucan_synthetic_tests-checkpoint.ipynb └── toucan_video_completion-checkpoint.ipynb ├── __pycache__ ├── olstec.cpython-36.pyc ├── tecpsgd.cpython-36.pyc └── tsvd.cpython-36.pyc ├── gas_sensor_array_experiments.ipynb ├── olstec.py ├── run_all_mri.py ├── stc.py ├── tecpsgd.py ├── toucan_synthetic_tests.ipynb ├── toucan_video_completion.ipynb └── tsvd.py /Datasets/gas_data/read_gas_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | path = '/Users/kgilman/Desktop/WTD_upload/Toluene_200/' 5 | mypath = '/Users/kgilman/Desktop/WTD_upload/Toluene_200/L4/' 6 | strfile = '/Users/kgilman/Desktop/WTD_upload/Toluene_200/L4/201106060617_board_setPoint_500V_fan_setPoint_060_mfc_setPoint_Toluene_200ppm_p7' 7 | 8 | from os import listdir 9 | from os.path import isfile, join 10 | myfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))] 11 | 12 | def find_nearest(array, values): 13 | indices = np.abs(np.subtract.outer(array, values)).argmin(0) 14 | return indices 15 | 16 | datalist = [] 17 | with open(strfile,'r') as file: 18 | for line in file: 19 | datalist.append(line.split('\t')) 20 | file.close() 21 | 22 | 23 | i = 0 24 | 25 | # values = np.arange(0,260000,100) #10 Hz 26 | # data = np.zeros((len(myfiles),2600,71)) 27 | # outfile = path + 'L4array_10hz' 28 | 29 | values = np.arange(0,260000,200) #5 Hz 30 | data = np.zeros((len(myfiles),1300,71)) 31 | outfile = path + 'L4array_5hz' 32 | 33 | # values = np.arange(0,260000,1000) #1 Hz 34 | # data = np.zeros((len(myfiles),260,71)) 35 | # outfile = path + 'L4array_1hz' 36 | 37 | data_means = np.zeros((len(myfiles),71)) 38 | # data_norms = np.zeros((len(myfiles),71)) 39 | data_norms = np.zeros((len(myfiles))) 40 | for file in myfiles: 41 | datalist = [] 42 | with open(mypath + file,'r') as f: 43 | for line in f: 44 | datalist.append(line.split('\t')) 45 | f.close() 46 | 47 | #Sample at desired frequency 48 | data_array = np.array(datalist) 49 | 50 | time_idx = data_array[:,0].astype(int) 51 | sample_idx = find_nearest(time_idx,values) 52 | data_array = data_array[sample_idx,:] 53 | 54 | #Filter the columns 55 | data_array = data_array[:,-1-80:-2] 56 | data_array = data_array.astype(int) 57 | 58 | data_idx = np.where(data_array[1,:]!=1)[0] 59 | data_array = data_array[:,data_idx] 60 | 61 | #Convert to float, subtract time-series mean, normalize by Frob norm 62 | data_array = data_array.astype(float) 63 | means = np.mean(data_array,axis=0) 64 | data_array -= means 65 | 
norms = np.linalg.norm(data_array,'fro') 66 | # norms = np.linalg.norm(data_array, axis=0) 67 | # norms = np.max(data_array,axis=0) 68 | 69 | data_array /= norms 70 | 71 | data[i,:,:] = data_array 72 | data_means[i,:] = means 73 | # data_norms[i,:] = norms 74 | data_norms[i] = norms 75 | i += 1 76 | 77 | np.savez(outfile,data,data_means,data_norms) 78 | 79 | print("I'm done") -------------------------------------------------------------------------------- /Datasets/merl/frame_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_01.png -------------------------------------------------------------------------------- /Datasets/merl/frame_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_02.png -------------------------------------------------------------------------------- /Datasets/merl/frame_03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_03.png -------------------------------------------------------------------------------- /Datasets/merl/frame_04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_04.png -------------------------------------------------------------------------------- /Datasets/merl/frame_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_05.png -------------------------------------------------------------------------------- /Datasets/merl/frame_06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_06.png -------------------------------------------------------------------------------- /Datasets/merl/frame_07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_07.png -------------------------------------------------------------------------------- /Datasets/merl/frame_08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_08.png -------------------------------------------------------------------------------- /Datasets/merl/frame_09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_09.png -------------------------------------------------------------------------------- /Datasets/merl/frame_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_10.png -------------------------------------------------------------------------------- 
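Editor's note on Datasets/gas_data/read_gas_data.py above: the script resamples each recording to the selected rate, subtracts the per-channel time-series mean, scales the whole recording by a single Frobenius norm, and writes the result with np.savez(outfile, data, data_means, data_norms). Because the arrays are passed positionally, NumPy stores them under the default keys arr_0, arr_1, arr_2 and appends a .npz extension. A minimal loading sketch (the filename is illustrative, not a file shipped in this repository):

import numpy as np

# Arrays written by read_gas_data.py via np.savez with positional arguments:
# arr_0 = data (recordings x samples x channels), arr_1 = per-channel means,
# arr_2 = per-recording Frobenius norms.
npz = np.load('L4array_5hz.npz')  # hypothetical output path of the script
data, data_means, data_norms = npz['arr_0'], npz['arr_1'], npz['arr_2']

# Undo the normalization for one recording if the raw sensor scale is needed.
raw0 = data[0] * data_norms[0] + data_means[0]
print(data.shape, raw0.shape)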
/Datasets/merl/frame_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_11.png -------------------------------------------------------------------------------- /Datasets/merl/frame_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_12.png -------------------------------------------------------------------------------- /Datasets/merl/frame_13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_13.png -------------------------------------------------------------------------------- /Datasets/merl/frame_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_14.png -------------------------------------------------------------------------------- /Datasets/merl/frame_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_15.png -------------------------------------------------------------------------------- /Datasets/merl/frame_16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_16.png -------------------------------------------------------------------------------- /Datasets/merl/frame_17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_17.png -------------------------------------------------------------------------------- /Datasets/merl/frame_18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_18.png -------------------------------------------------------------------------------- /Datasets/merl/frame_19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_19.png -------------------------------------------------------------------------------- /Datasets/merl/frame_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_20.png -------------------------------------------------------------------------------- /Datasets/merl/frame_21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_21.png -------------------------------------------------------------------------------- /Datasets/merl/frame_22.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_22.png -------------------------------------------------------------------------------- /Datasets/merl/frame_23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_23.png -------------------------------------------------------------------------------- /Datasets/merl/frame_24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_24.png -------------------------------------------------------------------------------- /Datasets/merl/frame_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_25.png -------------------------------------------------------------------------------- /Datasets/merl/frame_26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_26.png -------------------------------------------------------------------------------- /Datasets/merl/frame_27.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_27.png -------------------------------------------------------------------------------- /Datasets/merl/frame_28.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_28.png -------------------------------------------------------------------------------- /Datasets/merl/frame_29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_29.png -------------------------------------------------------------------------------- /Datasets/merl/frame_30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_30.png -------------------------------------------------------------------------------- /Datasets/merl/frame_31.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_31.png -------------------------------------------------------------------------------- /Datasets/merl/frame_32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_32.png -------------------------------------------------------------------------------- /Datasets/merl/frame_33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_33.png 
-------------------------------------------------------------------------------- /Datasets/merl/frame_34.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_34.png -------------------------------------------------------------------------------- /Datasets/merl/frame_35.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_35.png -------------------------------------------------------------------------------- /Datasets/merl/frame_36.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_36.png -------------------------------------------------------------------------------- /Datasets/merl/frame_37.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_37.png -------------------------------------------------------------------------------- /Datasets/merl/frame_38.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/merl/frame_38.png -------------------------------------------------------------------------------- /Datasets/mri_data/aperiodic_pincat.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/mri_data/aperiodic_pincat.mat -------------------------------------------------------------------------------- /Datasets/mri_data/imag_mri.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/mri_data/imag_mri.mat -------------------------------------------------------------------------------- /Datasets/mri_data/invivo_perfusion.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/mri_data/invivo_perfusion.mat -------------------------------------------------------------------------------- /Datasets/mri_data/real_mri.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/Datasets/mri_data/real_mri.mat -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository contains the code necessary to generate all figures from our papers: 2 | 3 | Gilman, Kyle and Laura Balzano. "Grassmannian Optimization for Tensor Completion and Tracking in the t-SVD Algebra," In review, 2020. http://arxiv.org/abs/2001.11419 4 | 5 | Gilman, Kyle and Laura Balzano. "Online Tensor Completion and Free Submodule Tracking with the t-SVD," To appear in International Conference on Acoustics, Speech, and Signal Processing, 2020. 6 | 7 | Please cite one of these two. 
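If it helps, the two references above can be cited with BibTeX entries along the following lines; the entry keys and field layout are the editor's, while the authors, titles, venue, status, and arXiv link are copied from the README:

@article{gilman2020grassmannian,
  author = {Gilman, Kyle and Balzano, Laura},
  title  = {Grassmannian Optimization for Tensor Completion and Tracking in the t-SVD Algebra},
  note   = {In review},
  year   = {2020},
  url    = {http://arxiv.org/abs/2001.11419}
}

@inproceedings{gilman2020online,
  author    = {Gilman, Kyle and Balzano, Laura},
  title     = {Online Tensor Completion and Free Submodule Tracking with the t-SVD},
  booktitle = {International Conference on Acoustics, Speech, and Signal Processing (ICASSP)},
  year      = {2020}
}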
8 | 9 | 10 | All code is in the folder "jupyter_notebooks." 11 | -TOUCAN code is in the file tsvd.py. 12 | -Run each Jupyter notebook to generate results for synthetic data, chemo-sensing, time-lapse video experiments. 13 | -Run the Python script run_all_mri.py to generate MRI results. 14 | 15 | All necessary datasets are in the folder "Datasets." -------------------------------------------------------------------------------- /jupyter_notebooks/.ipynb_checkpoints/toucan_video_completion-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from tsvd import *\n", 10 | "import numpy as np\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import matplotlib.image as mpimg\n", 13 | "import time\n", 14 | "import scipy.misc\n", 15 | "from tecpsgd import *\n", 16 | "from olstec import *\n", 17 | "from stc import *\n", 18 | "from skimage.measure import compare_ssim\n", 19 | "\n", 20 | "\n", 21 | "## Read in video data and form tensor\n", 22 | "import os\n", 23 | "path = '/Users/kgilman/Desktop/datasets/dataMERL/'\n", 24 | "files = []\n", 25 | "for i in sorted(os.listdir(path)):\n", 26 | " if os.path.isfile(os.path.join(path,i)) and 'frame' in i:\n", 27 | " files.append(i)\n", 28 | "\n", 29 | "num_frames = len(files)\n", 30 | "nx,ny = mpimg.imread(path + files[0]).shape\n", 31 | "L = np.zeros((int(nx), num_frames, int(ny)))\n", 32 | "\n", 33 | "for i in range(0,num_frames):\n", 34 | " im = mpimg.imread(path + files[i])\n", 35 | " im = im / np.max(im)\n", 36 | " L[:,i,:] = im\n", 37 | "\n", 38 | " \n", 39 | "# Lmean = np.mean(L)\n", 40 | "Lmean = 0\n", 41 | "L -= Lmean\n", 42 | "L = Tensor(L)\n", 43 | "n1,n2,n3 = L.shape()\n", 44 | "\n", 45 | "np.random.seed(0)\n", 46 | "rho = 0.8 #% missing entries\n", 47 | "mask = np.random.rand(n1, n2, n3)\n", 48 | "mask[mask > rho] = 1\n", 49 | "mask[mask <= rho] = 0\n", 50 | "mask = mask.astype(int)\n", 51 | "\n", 52 | "Y = Tensor(L.array() * mask)\n", 53 | "\n", 54 | "Lfrob = tfrobnorm(L)\n", 55 | "\n", 56 | "print(Y.shape())\n", 57 | "\n", 58 | "plt.imshow(L.array()[:,35,:],cmap='gray')\n", 59 | "name = 'original-' + str(int(rho*100)) + '%'+'.eps'\n", 60 | "plt.savefig(name)\n", 61 | "plt.show()\n", 62 | "\n", 63 | "plt.imshow(Y.array()[:,35,:],cmap='gray')\n", 64 | "name = 'observed-' + str(int(rho*100)) + '%'+'.eps'\n", 65 | "plt.savefig(name)\n", 66 | "plt.show()\n" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "frame = L.array()[:,35,:]\n", 76 | "print(frame.shape)\n", 77 | "blk_size = 25\n", 78 | "ranks = []\n", 79 | "for k in range(0,int(np.floor(n3/blk_size))):\n", 80 | " blk = frame[:,blk_size * k: blk_size * k + blk_size]\n", 81 | " U,S,V = np.linalg.svd(blk,full_matrices=False)\n", 82 | " pwr_perc = np.cumsum(S / np.sum(S))\n", 83 | "# print(pwr_perc)\n", 84 | " idx = np.where(pwr_perc >= 0.8)[0]\n", 85 | "# print(idx)\n", 86 | " ranks.append(idx[0] + 1)\n", 87 | " \n", 88 | "plt.plot(ranks)\n", 89 | "plt.show()\n" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "U,S,V = tsvd(L,full=False)\n", 99 | "\n", 100 | "plt.figure(figsize=(10,5))\n", 101 | "plt.plot(np.diag(S.array()[:,:,0]))\n", 102 | "plt.xticks(np.arange(0, min(n1,n2), step=5))\n", 103 | 
"plt.rcParams.update({'font.size': 15})\n", 104 | "plt.show()\n", 105 | "\n", 106 | "s = np.diag(S.array()[:,:,0])\n", 107 | "power80 = 0.75*np.sum(s)\n", 108 | "\n", 109 | "cum_sum = 0\n", 110 | "k = 0\n", 111 | "for i in range(0,len(s)):\n", 112 | " cum_sum += s[i]\n", 113 | " k += 1\n", 114 | " if(cum_sum > power80):\n", 115 | " break\n", 116 | "\n", 117 | "print('Number of LR t-SVD Approx Components: {:.3f}'.format(k))" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "### STC" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "## Sequential Tensor Completion\n", 134 | "\n", 135 | "Tensor_Y = np.transpose(L.array(),[0,2,1])\n", 136 | "Mask_Y = np.transpose(mask,[0,2,1])\n", 137 | "numcycles = 1\n", 138 | "outer = 3\n", 139 | "r1 = 50\n", 140 | "r2 = 50\n", 141 | "r3 = 2\n", 142 | "fun = lambda Lhat,idx: [0, tfrobnorm_array(Lhat[:,:,idx] - Tensor_Y[:,:,idx]) / tfrobnorm_array(Tensor_Y[:,:,idx])]\n", 143 | "Lhat, stats, tElapsed = stc(Tensor_Y,Mask_Y,r1,r2,r3,outer,numcycles,fun=fun,verbose=True)\n", 144 | "\n", 145 | "Lhat = np.transpose(Lhat,[0,2,1])\n", 146 | "\n", 147 | "nrmse_stc = tfrobnorm(Tensor(Lhat) - L) / Lfrob\n", 148 | "\n", 149 | "print('STC Time: {:4f}'.format(tElapsed))\n", 150 | "print('NRMSE STC: {:6f}'.format(nrmse_stc))" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "plt.imshow(Lhat[:,35,:],cmap='gray')\n", 160 | "name = 'merl_stc-' + str(int(rho*100)) + '%'+'.eps'\n", 161 | "plt.savefig(name)\n", 162 | "plt.show()" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "# from skimage.metrics import structural_similarity as ssim\n", 172 | "# ssim_noise = ssim(L[:,35,:], Lhat[:,35,:],\n", 173 | "# data_range=np.max(Lhat[:,35,:]) - np.min(Lhat[:,35,:]))\n", 174 | "\n", 175 | "(score, diff) = compare_ssim(L.array()[:,35,:], Lhat[:,35,:], full=True)\n", 176 | "print('STC SSIM: {:.5f}'.format(score))" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "### OLSTEC" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "## OLSTEC\n", 193 | "\n", 194 | "Tensor_Y_Noiseless = np.transpose(L.array(),[0,2,1])\n", 195 | "rank = 100\n", 196 | "OmegaTensor = np.transpose(mask,[0,2,1])\n", 197 | "tensor_dims = [n1,n3,n2]\n", 198 | "maxepochs = 2\n", 199 | "tolcost = 1e-14\n", 200 | "permute_on = False\n", 201 | "\n", 202 | "options = {\n", 203 | " 'maxepochs': maxepochs,\n", 204 | " 'tolcost': tolcost,\n", 205 | " 'lam': 0.7,\n", 206 | " 'mu': 0.1,\n", 207 | " 'permute_on': permute_on,\n", 208 | " 'store_subinfo': True,\n", 209 | " 'store_matrix': False,\n", 210 | " 'verbose': False,\n", 211 | " 'tw_flag': None,\n", 212 | " 'tw_len': None\n", 213 | "}\n", 214 | "\n", 215 | "Xinit = {\n", 216 | " 'A': np.random.randn(tensor_dims[0], rank),\n", 217 | " 'B': np.random.randn(tensor_dims[1], rank),\n", 218 | " 'C': np.random.randn(tensor_dims[2], rank)\n", 219 | "}\n", 220 | "\n", 221 | "Xsol_olstec, Y_hat_olstec, info_olstec, sub_infos_olstec = OLSTEC(Tensor_Y_Noiseless, OmegaTensor, None, tensor_dims, rank,\n", 222 | " Xinit, options)\n", 223 | "\n", 224 | "A_t0 = Xsol_olstec['A']\n", 225 | "B_t0 
= Xsol_olstec['B']\n", 226 | "C_t0 = Xsol_olstec['C']\n", 227 | "\n", 228 | "# Y_hat_olstec = np.zeros(Tensor_Y_Noiseless.shape)\n", 229 | "# for f in range(0,n2):\n", 230 | "# gamma = C_t0[f,:].T\n", 231 | "# Y_hat_olstec[:,:,f] = A_t0 @ np.diag(gamma) @ B_t0.T\n", 232 | " \n", 233 | "Y_hat_olstec = np.transpose(Y_hat_olstec,[0,2,1])\n", 234 | "\n", 235 | "nrmse_olstec = tfrobnorm(Tensor(Y_hat_olstec) - L) / Lfrob\n", 236 | "tElapsed_olstec = np.sum(sub_infos_olstec['times'])\n", 237 | "\n", 238 | "print('OLSTEC Time: {:4f}'.format(tElapsed_olstec))\n", 239 | "print('NRMSE OLSTEC: {:6f}'.format(nrmse_olstec))\n", 240 | "\n", 241 | "plt.imshow(Y_hat_olstec[:,35,:],cmap='gray')\n", 242 | "name = 'merl_olstec-' + str(int(rho*100)) + '%'+'.eps'\n", 243 | "plt.savefig(name)\n", 244 | "plt.show()\n", 245 | "\n", 246 | "(score, diff) = compare_ssim(L.array()[:,35,:], Y_hat_olstec[:,35,:], full=True)\n", 247 | "print('OLSTEC SSIM: {:.5f}'.format(score))" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "### TeCPSGD" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "## TeCPSGD\n", 264 | "from tecpsgd import *\n", 265 | "Tensor_Y_Noiseless = np.transpose(L.array(),[0,2,1])\n", 266 | "rank = 100\n", 267 | "OmegaTensor = np.transpose(mask,[0,2,1])\n", 268 | "tensor_dims = [n1,n3,n2]\n", 269 | "maxepochs = 2\n", 270 | "tolcost = 1e-14\n", 271 | "permute_on = False\n", 272 | "\n", 273 | "options = {\n", 274 | " 'maxepochs': maxepochs,\n", 275 | " 'tolcost': tolcost,\n", 276 | " 'lam': 0.05,\n", 277 | " 'stepsize': 10,\n", 278 | "# 'mu': 0.1,\n", 279 | " 'permute_on': permute_on,\n", 280 | " 'store_subinfo': True,\n", 281 | " 'store_matrix': False,\n", 282 | " 'verbose': False\n", 283 | "}\n", 284 | "\n", 285 | "Xinit = {\n", 286 | " 'A': np.random.randn(tensor_dims[0], rank),\n", 287 | " 'B': np.random.randn(tensor_dims[1], rank),\n", 288 | " 'C': np.random.randn(tensor_dims[2], rank)\n", 289 | "}\n", 290 | "\n", 291 | "Xsol_TeCPSGD, Y_hat_tecpsgd, info_TeCPSGD, sub_infos_TeCPSGD = TeCPSGD(Tensor_Y_Noiseless, OmegaTensor, None, tensor_dims, rank,\n", 292 | " Xinit, options)\n", 293 | "\n", 294 | "A_t0 = Xsol_TeCPSGD['A']\n", 295 | "B_t0 = Xsol_TeCPSGD['B']\n", 296 | "C_t0 = Xsol_TeCPSGD['C']\n", 297 | "\n", 298 | "# Y_hat_tecpsgd = np.zeros(Tensor_Y_Noiseless.shape)\n", 299 | "# for f in range(0,n2):\n", 300 | "# gamma = C_t0[f,:].T\n", 301 | "# Y_hat_tecpsgd[:,:, f] = A_t0 @ np.diag(gamma) @ B_t0.T\n", 302 | " \n", 303 | "Y_hat_tecpsgd = np.transpose(Y_hat_tecpsgd,[0,2,1])\n", 304 | "\n", 305 | "nrmse_tecpsgd = tfrobnorm(Tensor(Y_hat_tecpsgd) - L) / Lfrob\n", 306 | "tElapsed_tecpsgd = np.sum(sub_infos_TeCPSGD['times'])\n", 307 | "\n", 308 | "print('TeCPSGD Time: {:4f}'.format(tElapsed_tecpsgd))\n", 309 | "print('NRMSE TeCPSGD: {:6f}'.format(nrmse_tecpsgd))\n", 310 | "plt.imshow(Y_hat_tecpsgd[:,35,:],cmap='gray')\n", 311 | "name = 'merl_tecpsgd-' + str(int(rho*100)) + '%'+'.eps'\n", 312 | "plt.savefig(name)\n", 313 | "plt.show()\n", 314 | "\n", 315 | "(score, diff) = compare_ssim(L.array()[:,35,:], Y_hat_tecpsgd[:,35,:], full=True)\n", 316 | "print('TeCPSGD SSIM: {:.5f}'.format(score))" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "### TNN-ADMM" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "# 
## TNN-ADMM\n", 333 | "fun = lambda X: [0,tfrobnorm(X - L) / Lfrob]\n", 334 | "# fun = lambda X: [0,0]\n", 335 | "Y_hat_tnn,stats_tnn,tElapsed_tnn = lrtc(Y,mask,niter = 75,fun=fun,verbose=False)\n", 336 | "# Y_hat_tnn,stats_tnn,tElapsed_tnn = lrtc(Tensor(np.transpose(Y.array(),[0,2,1])),np.transpose(mask,[0,2,1]),niter = 75,fun=fun,verbose=False)\n", 337 | "\n", 338 | "tnn_nrmse = tfrobnorm(Y_hat_tnn - L) / Lfrob\n", 339 | "print('TNN Time: {:4f}'.format(tElapsed_tnn))\n", 340 | "print('TNN NRMSE: {:4f}'.format(tnn_nrmse))\n", 341 | "\n", 342 | "plt.imshow(Y_hat_tnn.array()[:,35,:],cmap='gray')\n", 343 | "name = 'merl_tnn-' + str(int(rho*100)) + '%'+'.eps'\n", 344 | "plt.savefig(name)\n", 345 | "plt.show()\n", 346 | "\n", 347 | "(score, diff) = compare_ssim(L.array()[:,35,:], Y_hat_tnn.array()[:,35,:], full=True)\n", 348 | "print('TNN-ADMM SSIM: {:.5f}'.format(score))" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": {}, 354 | "source": [ 355 | "### TCTF" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": null, 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "## TCTF\n", 365 | "rank = 2\n", 366 | "fun = lambda X,Z: [0, tfrobnorm(X*Z - L) / Lfrob]\n", 367 | "Xtctf,Ztctf, stats_tctf, tElapsed_tctf = tctf(Y,mask,rank,niter = 75,fun=fun,verbose=False)\n", 368 | "Y_hat_tctf = Xtctf * Ztctf\n", 369 | "\n", 370 | "tctf_nrmse = tfrobnorm(Y_hat_tctf - L) / Lfrob\n", 371 | "print('TCTF Time: {:4f}'.format(tElapsed_tctf))\n", 372 | "print('TCTF NRMSE: {:4f}'.format(tctf_nrmse))\n", 373 | "\n", 374 | "plt.imshow(Y_hat_tctf.array()[:,35,:],cmap='gray')\n", 375 | "name = 'merl_tctf-' + str(int(rho*100)) + '%'+'.eps'\n", 376 | "plt.savefig(name)\n", 377 | "plt.show()\n", 378 | "\n", 379 | "(score, diff) = compare_ssim(L.array()[:,35,:], Y_hat_tctf.array()[:,35,:], full=True)\n", 380 | "print('TCTF SSIM: {:.5f}'.format(score))" 381 | ] 382 | }, 383 | { 384 | "cell_type": "markdown", 385 | "metadata": {}, 386 | "source": [ 387 | "### TOUCAN" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [ 396 | "## TOUCAN\n", 397 | "rank = 1\n", 398 | "# fun = lambda X: [0, tfrobnorm(X - L) / Lfrob]\n", 399 | "fun = lambda X,k: [0, tfrobnorm_array(X - L.array()[:,k,:]) / tfrobnorm_array(L.array()[:,k,:])]\n", 400 | "Y_hat_toucan, U, w, stats_toucan, tElapsed_toucan = toucan(Y,mask,rank,tube=False,outer=2,mode='online',fun=fun,cgtol=1e-6,\n", 401 | " randomOrder=False,verbose=False)\n", 402 | "\n", 403 | "toucan_nrmse = tfrobnorm(Y_hat_toucan - L) / Lfrob\n", 404 | "print('TOUCAN Time: {:4f}'.format(tElapsed_toucan))\n", 405 | "print('TOUCAN NRMSE: {:4f}'.format(toucan_nrmse))\n", 406 | "\n", 407 | "plt.imshow(Y_hat_toucan.array()[:,35,:],cmap='gray')\n", 408 | "name = 'merl_toucan-' + str(int(rho*100)) + '%'+'.eps'\n", 409 | "plt.savefig(name)\n", 410 | "plt.show()\n", 411 | "\n", 412 | "(score, diff) = compare_ssim(L.array()[:,35,:], Y_hat_toucan.array()[:,35,:], full=True)\n", 413 | "print('TOUCAN SSIM: {:.5f}'.format(score))" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "metadata": {}, 420 | "outputs": [], 421 | "source": [ 422 | "plt.imshow(U.array().squeeze())\n", 423 | "plt.show()\n", 424 | "\n", 425 | "U_row = U.array()[100,0,:]\n", 426 | "plt.plot(U_row)\n", 427 | "plt.show()\n", 428 | "\n", 429 | "plt.plot(w.array()[0,0,:])\n", 430 | "plt.show()\n", 431 | "\n", 432 | "rec = (U * w).array().squeeze()\n", 433 
| "\n", 434 | "plt.imshow(rec)\n", 435 | "plt.show()\n", 436 | "\n", 437 | "w_fft = np.fft.fft(w.array()[0,0,:])\n", 438 | "plt.plot(np.abs(w_fft))\n", 439 | "plt.show()\n", 440 | "\n", 441 | "U_row_fft = np.fft.fft(U.array()[100,0,:])\n", 442 | "plt.plot(np.abs(U_row_fft))\n", 443 | "plt.show()\n", 444 | "\n", 445 | "mult = U_row_fft * w_fft\n", 446 | "plt.plot(np.abs(mult))\n", 447 | "plt.show()\n", 448 | "\n", 449 | "im_row_reconst = np.real(np.fft.ifft(mult))\n", 450 | "plt.plot(im_row_reconst)\n", 451 | "plt.show()\n", 452 | "\n" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": null, 458 | "metadata": {}, 459 | "outputs": [], 460 | "source": [ 461 | "## Matrix Completion Algorithms\n", 462 | "\n", 463 | "Lm = lr_flatten(L.array())\n", 464 | "Ym = lr_flatten(Y.array())\n", 465 | "mask_m = lr_flatten(mask)" 466 | ] 467 | }, 468 | { 469 | "cell_type": "markdown", 470 | "metadata": {}, 471 | "source": [ 472 | "### MatComp" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": null, 478 | "metadata": {}, 479 | "outputs": [], 480 | "source": [ 481 | "## MatComp\n", 482 | "\n", 483 | "fun = lambda X: [0, np.linalg.norm(X - Lm,'fro') / np.linalg.norm(Lm,'fro')]\n", 484 | "Ym_hat, stats_matcomp, tElapsed_matcomp = lrmc(Ym,mask_m,niter=100,fun=fun,verbose=False)\n", 485 | "\n", 486 | "matcomp_nrmse = np.linalg.norm(Ym_hat - Lm,'fro') / np.linalg.norm(Lm,'fro')\n", 487 | "print('TOUCAN Time: {:4f}'.format(tElapsed_matcomp))\n", 488 | "print('TOUCAN NRMSE: {:4f}'.format(matcomp_nrmse))\n", 489 | "\n", 490 | "plt.imshow(np.reshape(Ym_hat[:,35],(int(ny),int(nx))).T,cmap='gray')\n", 491 | "name = 'merl_matcomp-' + str(int(rho*100)) + '%'+'.eps'\n", 492 | "plt.savefig(name)\n", 493 | "plt.show()\n", 494 | "\n", 495 | "(score, diff) = compare_ssim(Lm[:,35], Ym_hat[:,35], full=True)\n", 496 | "print('MatComp SSIM: {:.5f}'.format(score))" 497 | ] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "metadata": {}, 502 | "source": [ 503 | "### GROUSE" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": null, 509 | "metadata": {}, 510 | "outputs": [], 511 | "source": [ 512 | "## GROUSE\n", 513 | "rank = 1\n", 514 | "# fun = lambda X: [0, np.linalg.norm(X - Lm, 'fro') / np.linalg.norm(Lm, 'fro')]\n", 515 | "fun = lambda X,idx: [0, np.linalg.norm(X - Lm[:,idx]) / np.linalg.norm(Lm[:,idx])]\n", 516 | "Ym_hat_grouse, stats_grouse, tElapsed_grouse = grouse(Ym, mask_m, rank,outer=3,mode=\"online\",fun=fun,randomOrder=False,\n", 517 | " verbose=False)\n", 518 | "grouse_nrmse = np.linalg.norm(Ym_hat_grouse - Lm,'fro') / np.linalg.norm(Lm,'fro')\n", 519 | "print('GROUSE Time: {:4f}'.format(tElapsed_grouse))\n", 520 | "print('GROUSE NRMSE: {:4f}'.format(grouse_nrmse))\n", 521 | "\n", 522 | "plt.imshow(np.reshape(Ym_hat_grouse[:,35],(int(ny),int(nx))).T,cmap='gray')\n", 523 | "name = 'merl_grouse-' + str(int(rho*100)) + '%'+'.eps'\n", 524 | "plt.savefig(name)\n", 525 | "plt.show()\n", 526 | "\n", 527 | "(score, diff) = compare_ssim(Lm[:,35], Ym_hat_grouse[:,35], full=True)\n", 528 | "print('GROUSE SSIM: {:.5f}'.format(score))" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": null, 534 | "metadata": {}, 535 | "outputs": [], 536 | "source": [ 537 | "# ## Write out results\n", 538 | "\n", 539 | "# for i in range(0,num_frames):\n", 540 | "# name = path + 'observed/frame_' + '%02d' % i + '.png'\n", 541 | "# scipy.misc.imsave(name, np.reshape(Ym[:, i], (int(ny), int(nx))).T)\n", 542 | "\n", 543 | "# # name = path + 
'/tnn_results/frame_' + '%02d' % i + '.png'\n", 544 | "# # scipy.misc.imsave(name, Y_hat_tnn.array()[:, i, :])\n", 545 | "\n", 546 | "# # name = path + '/tctf_results/frame_' + '%02d' % i + '.png'\n", 547 | "# # scipy.misc.imsave(name, Y_hat_tctf.array()[:, i, :])\n", 548 | "\n", 549 | "# name = path + '/toucan_results/frame_' + '%02d' % i + '.png'\n", 550 | "# scipy.misc.imsave(name, Y_hat_toucan.array()[:, i, :])\n", 551 | "\n", 552 | "# # name = path + 'matcomp_results/frame_' + '%02d' % i + '.png'\n", 553 | "# # scipy.misc.imsave(name, np.reshape(Ym_hat[:, i],(int(nx),int(ny))))\n", 554 | "\n", 555 | "# name = path + 'grouse_results/frame_' + '%02d' % i + '.png'\n", 556 | "# scipy.misc.imsave(name, np.reshape(Ym_hat_grouse[:, i], (int(ny), int(nx))).T)\n", 557 | " \n", 558 | "# name = path + 'tecpsgd_results/frame_' + '%02d' % i + '.png'\n", 559 | "# scipy.misc.imsave(name, np.reshape(Y_hat_tecpsgd[:, i], (int(nx), int(ny))))\n", 560 | " \n", 561 | "# name = path + 'olstec_results/frame_' + '%02d' % i + '.png'\n", 562 | "# scipy.misc.imsave(name, np.reshape(Y_hat_olstec[:, i], (int(nx), int(ny))))\n", 563 | " " 564 | ] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": null, 569 | "metadata": {}, 570 | "outputs": [], 571 | "source": [ 572 | "# plt.figure(figsize=(10,5), tight_layout=True)\n", 573 | "# plt.semilogy(np.cumsum(stats_toucan[:,-1]),stats_toucan[:,1], '#ff7f0e',\n", 574 | "# np.cumsum(stats_grouse[:, -1]), stats_grouse[:, 1], 'b',\n", 575 | "# np.cumsum(stats_tnn[:,-1]),stats_tnn[:,1], 'r',\n", 576 | "# np.cumsum(stats_tctf[:,-1]),stats_tctf[:,1], 'k',\n", 577 | "# np.cumsum(stats_matcomp[:,-1]),stats_matcomp[:,1],'g',\n", 578 | "# )\n", 579 | "\n", 580 | "# plt.legend(('TOUCAN', 'GROUSE', 'TNN-ADMM', 'TCTF', 'MatComp'),bbox_to_anchor=(1.1, 1 ))\n", 581 | "# plt.xlabel('Time (s)')\n", 582 | "# plt.ylabel('NRMSE')\n", 583 | "# plt.grid()\n", 584 | "# plt.rcParams.update({'font.size': 22})\n", 585 | "# plt.savefig('merl_video_completion_nrmse_' + str(rho*100) + 'p.eps')\n", 586 | "# plt.show()" 587 | ] 588 | }, 589 | { 590 | "cell_type": "code", 591 | "execution_count": null, 592 | "metadata": {}, 593 | "outputs": [], 594 | "source": [ 595 | "plt.figure(figsize=(10,5), tight_layout=True)\n", 596 | "plt.semilogy(np.cumsum(stats_toucan[:,-1]),stats_toucan[:,1],'#ff7f0e',\n", 597 | " np.cumsum(stats_grouse[:, -1]), stats_grouse[:, 1],'b',\n", 598 | " np.cumsum(sub_infos_TeCPSGD['times'][1:]),sub_infos_TeCPSGD['err_residual'][1:],'k',\n", 599 | " np.cumsum(sub_infos_olstec['times'][1:]),sub_infos_olstec['err_residual'][1:],'r'\n", 600 | ")\n", 601 | "\n", 602 | "plt.legend(('TOUCAN', 'GROUSE', 'TeCPSGD', 'OLSTEC'),bbox_to_anchor=(1.1, 1 ))\n", 603 | "plt.xlabel('Time (s)')\n", 604 | "plt.ylabel('NRMSE')\n", 605 | "plt.grid()\n", 606 | "plt.rcParams.update({'font.size': 22})\n", 607 | "# plt.savefig('merl_video_completion_nrmse_' + str(rho*100) + 'p.eps')\n", 608 | "plt.show()" 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": null, 614 | "metadata": {}, 615 | "outputs": [], 616 | "source": [ 617 | "plt.figure(figsize=(12,5), tight_layout=True)\n", 618 | "plt.semilogy(np.arange(0,len(stats_toucan)),stats_toucan[:,1],'#ff7f0e',\n", 619 | " np.arange(0,len(stats_grouse)), stats_grouse[:, 1],'b',\n", 620 | " np.arange(0,len(sub_infos_TeCPSGD['times'][1:])),sub_infos_TeCPSGD['err_residual'][1:],'k',\n", 621 | " np.arange(0,len(sub_infos_olstec['times'][1:])),sub_infos_olstec['err_residual'][1:],'r'\n", 622 | ")\n", 623 | "\n", 624 | 
"plt.legend(('TOUCAN', 'GROUSE', 'TeCPSGD', 'OLSTEC'),bbox_to_anchor=(1.1, 1 ))\n", 625 | "plt.xlabel('Slice index')\n", 626 | "plt.ylabel('NRMSE')\n", 627 | "plt.grid()\n", 628 | "plt.rcParams.update({'font.size': 22})\n", 629 | "# plt.savefig('merl_video_completion_nrmse_' + str(rho*100) + 'p.eps')\n", 630 | "plt.show()" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": null, 636 | "metadata": {}, 637 | "outputs": [], 638 | "source": [ 639 | "olstec_frame_err = []\n", 640 | "toucan_frame_err = []\n", 641 | "\n", 642 | "for i in range(0,n2):\n", 643 | " olstec_frame_err.append(np.linalg.norm(Y_hat_olstec[:,i,:].squeeze() - L.array()[:,i,:]) / np.linalg.norm(L.array()[:,i,:]))\n", 644 | " toucan_frame_err.append(np.linalg.norm(Y_hat_toucan.array()[:,i,:].squeeze() - L.array()[:,i,:]) / np.linalg.norm(L.array()[:,i,:]))\n", 645 | "\n", 646 | "plt.semilogy(np.arange(0,len(olstec_frame_err)),olstec_frame_err,np.arange(0,len(toucan_frame_err)),toucan_frame_err)\n", 647 | "plt.legend(['OLSTEC','TOUCAN'])\n", 648 | "plt.show()" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": null, 654 | "metadata": {}, 655 | "outputs": [], 656 | "source": [] 657 | } 658 | ], 659 | "metadata": { 660 | "@webio": { 661 | "lastCommId": null, 662 | "lastKernelId": null 663 | }, 664 | "kernelspec": { 665 | "display_name": "Python 3", 666 | "language": "python", 667 | "name": "python3" 668 | }, 669 | "language_info": { 670 | "codemirror_mode": { 671 | "name": "ipython", 672 | "version": 3 673 | }, 674 | "file_extension": ".py", 675 | "mimetype": "text/x-python", 676 | "name": "python", 677 | "nbconvert_exporter": "python", 678 | "pygments_lexer": "ipython3", 679 | "version": "3.6.6" 680 | } 681 | }, 682 | "nbformat": 4, 683 | "nbformat_minor": 2 684 | } 685 | -------------------------------------------------------------------------------- /jupyter_notebooks/__pycache__/olstec.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/jupyter_notebooks/__pycache__/olstec.cpython-36.pyc -------------------------------------------------------------------------------- /jupyter_notebooks/__pycache__/tecpsgd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/jupyter_notebooks/__pycache__/tecpsgd.cpython-36.pyc -------------------------------------------------------------------------------- /jupyter_notebooks/__pycache__/tsvd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kgilman/TOUCAN/059b81f3e3f39afedc3e5ad51c37600d6056f9fd/jupyter_notebooks/__pycache__/tsvd.cpython-36.pyc -------------------------------------------------------------------------------- /jupyter_notebooks/olstec.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | def compute_cost_tensor(X, P, PA, tensor_dims): 5 | n1 = tensor_dims[0] 6 | n2 = tensor_dims[1] 7 | n3 = tensor_dims[2] 8 | 9 | Diff = P * X - PA 10 | Diff_flat = np.reshape(Diff, (n1 * n2, n3)) 11 | 12 | return 0.5 * np.linalg.norm(Diff_flat, 'fro') ** 2 13 | 14 | def OLSTEC(A_in, Omega_in, Gamma_in, tensor_dims, rank, xinit, options,fun = lambda X: 0): 15 | A = A_in # Full entries 16 | Omega = Omega_in # Trainingset 
'Omega' 17 | Gamma = Gamma_in # Test set 'Gamma' 18 | 19 | A_Omega = Omega_in * A_in #Training entries i.e., Omega_in. * A_in 20 | if Gamma_in is not None: 21 | A_Gamma = Gamma_in * A_in # Test entries i.e., Gamma_in. * A_in 22 | else: 23 | A_Gamma = [] 24 | 25 | if xinit is None: 26 | A_t0 = np.random.randn(tensor_dims[0], rank) 27 | B_t0 = np.random.randn(tensor_dims[1], rank) 28 | C_t0 = np.random.randn(tensor_dims[2], rank) 29 | else: 30 | A_t0 = xinit['A'] 31 | B_t0 = xinit['B'] 32 | C_t0 = xinit['C'] 33 | 34 | # set tensor size 35 | rows = tensor_dims[0] 36 | cols = tensor_dims[1] 37 | slice_length = tensor_dims[2] 38 | 39 | # set options 40 | lam = options['lam'] 41 | mu = options['mu'] 42 | maxepochs = options['maxepochs'] 43 | tolcost = options['tolcost'] 44 | store_subinfo = options['store_subinfo'] 45 | store_matrix = options['store_matrix'] 46 | verbose = options['verbose'] 47 | 48 | if options['permute_on'] is None: 49 | permute_on = 1 50 | else: 51 | permute_on = options['permute_on'] 52 | 53 | if options['tw_flag'] is None: 54 | TW_Flag = False 55 | else: 56 | TW_Flag = options['tw_flag'] 57 | 58 | if options['tw_len'] is None: 59 | TW_LEN = 10 60 | else: 61 | TW_LEN = options['tw_len'] 62 | 63 | # prepare Rinv history buffers 64 | RAinv = np.tile(100 * np.eye(rank), (rows, 1)) 65 | RBinv = np.tile(100 * np.eye(rank), (cols, 1)) 66 | 67 | # prepare 68 | N_AlphaAlphaT = np.zeros((rank * rows, rank * (TW_LEN + 1))) 69 | N_BetaBetaT = np.zeros((rank * cols, rank * (TW_LEN + 1))) 70 | 71 | # prepare 72 | N_AlphaResi = np.zeros((rank * rows, TW_LEN + 1)) 73 | N_BetaResi = np.zeros((rank * cols, TW_LEN + 1)) 74 | 75 | # calculate initial cost 76 | Rec = np.zeros((rows, cols, slice_length)) 77 | for k in range(0,slice_length): 78 | gamma = C_t0[k,:].T 79 | Rec[:,:,k] = A_t0 @ np.diag(gamma) @ B_t0.T 80 | 81 | train_cost = compute_cost_tensor(Rec, Omega, A_Omega, tensor_dims) 82 | 83 | if Gamma is None and A_Gamma is None: 84 | test_cost = compute_cost_tensor(Rec, Gamma, A_Gamma, tensor_dims) 85 | else: 86 | test_cost = 0 87 | 88 | # initialize infos 89 | infos = { 90 | 'iter': 0, 91 | 'train_cost': train_cost, 92 | 'test_cost': test_cost, 93 | 'time': [0] 94 | } 95 | 96 | # initialize sub_infos 97 | sub_infos = { 98 | 'inner_iter': 0, 99 | 'err_residual':0, 100 | 'error': 0, 101 | 'err_run_ave': [0], 102 | 'global_train_cost':0, 103 | 'global_test_cost':0, 104 | 'times':[0] 105 | } 106 | 107 | if store_matrix: 108 | sub_infos['I'] = np.zeros((rows * cols, slice_length)) 109 | sub_infos['L'] = np.zeros((rows * cols, slice_length)) 110 | sub_infos['E'] = np.zeros((rows * cols, slice_length)) 111 | 112 | if verbose > 0: 113 | print('TeCPSGD Epoch 000, Cost {:.5f}, Cost(test) {:.5f}'.format(train_cost,test_cost)) 114 | # main loop 115 | 116 | A_t1 = A_t0.copy() 117 | B_t1 = B_t0.copy() 118 | 119 | Yhat = np.zeros(A_in.shape) 120 | 121 | for outiter in range(0,maxepochs): 122 | #permute samples 123 | if permute_on: 124 | col_order = np.random.permutation(slice_length) 125 | else: 126 | col_order = np.arange(0,slice_length) 127 | 128 | 129 | # Begin the time counter for the epoch 130 | 131 | for k in range(0,slice_length): 132 | 133 | tStart = time.time() 134 | 135 | #sampled original image 136 | I_mat = A[:,:, col_order[k]] 137 | Omega_mat = Omega[:,:, col_order[k]] 138 | I_mat_Omega = Omega_mat * I_mat 139 | 140 | # Calculate gamma 141 | temp3 = 0 142 | temp4 = 0 143 | for m in range(0,rows): 144 | alpha_remat = np.tile(A_t0[m,:].T, (cols,1)).T 145 | alpha_beta = alpha_remat * B_t0.T 
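# --- Editor's note: explanatory comments only, not part of the original file ---
# This loop over the rows m builds the normal equations for the per-slice
# weight vector gamma with the factors A_t0 and B_t0 held fixed. Writing
# a_m for row m of A_t0 and b_j for row j of B_t0, column j of alpha_beta is
# the elementwise product a_m * b_j, so the model entry is
#   X[m, j] = gamma^T (a_m * b_j).
# temp3 accumulates the sum over observed (m, j) of y[m, j] * (a_m * b_j), and
# temp4 accumulates the sum over observed (m, j) of (a_m * b_j)(a_m * b_j)^T,
# so the lstsq solve below, gamma = (lam*I + temp4)^{-1} temp3, is the
# ridge-regularized least-squares fit of the observed slice to A diag(gamma) B^T.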
146 | I_row = I_mat_Omega[m,:] 147 | temp3 = temp3 + alpha_beta @ I_row.T 148 | 149 | Omega_mat_ind = np.where(Omega_mat[m,:]>0)[0] 150 | alpha_beta_Omega = alpha_beta[:, Omega_mat_ind] 151 | temp4 = temp4 + alpha_beta_Omega @ alpha_beta_Omega.T 152 | 153 | temp4 = lam * np.eye(rank) + temp4 154 | gamma = np.linalg.lstsq(temp4,temp3,rcond=-1)[0] # gamma = temp4 \ temp3; 155 | 156 | ## update A 157 | for m in range(0,rows): 158 | Omega_mat_ind = np.where(Omega_mat[m,:]>0)[0] 159 | I_row = I_mat_Omega[m,:] 160 | I_row_Omega = I_row[Omega_mat_ind] 161 | C_t0_Omega = B_t0[Omega_mat_ind,:] 162 | N_alpha_Omega = np.diag(gamma) @ C_t0_Omega.T 163 | N_alpha_alpha_t_Omega = N_alpha_Omega @ N_alpha_Omega.T 164 | 165 | # Calc TAinv(i.e.RAinv) 166 | TAinv = lam**(-1) * RAinv[m * rank: (m+1) * rank,:] 167 | if TW_Flag: 168 | Oldest_alpha_alpha_t = N_AlphaAlphaT[m * rank :(m+1) * rank, 0:rank] 169 | TAinv = np.linalg.inv(np.linalg.inv(TAinv) + N_alpha_alpha_t_Omega + (mu - lam * mu) * np.eye(rank) - lam ** TW_LEN * Oldest_alpha_alpha_t) 170 | else: 171 | TAinv = np.linalg.inv(np.linalg.inv(TAinv) + N_alpha_alpha_t_Omega + (mu - lam * mu) * np.eye(rank)) 172 | 173 | # Calc delta A_t0(m,:) 174 | recX_col_Omega = N_alpha_Omega.T @ A_t0[m,:].T 175 | resi_col_Omega = I_row_Omega.T - recX_col_Omega 176 | N_alpha_Resi_Omega = N_alpha_Omega @ np.diag(resi_col_Omega) 177 | 178 | N_resi_Rt_alpha = TAinv @ N_alpha_Resi_Omega 179 | delta_A_t0_m = np.sum(N_resi_Rt_alpha, 1) 180 | 181 | # Update A 182 | if TW_Flag: 183 | # update A 184 | Oldest_alpha_resi = N_AlphaResi[m * rank: (m+1) * rank, 1].T 185 | A_t1[m,:] = A_t0[m,:] - (mu - lam * mu) * A_t0[m,:] @ TAinv.T + delta_A_t0_m.T - lam ** TW_LEN @ Oldest_alpha_resi 186 | 187 | # Store data 188 | N_AlphaAlphaT[m * rank: (m+1) * rank, TW_LEN * rank + 1: (TW_LEN + 1) * rank] = \ 189 | N_alpha_alpha_t_Omega 190 | N_AlphaResi[m * rank : (m+1) * rank, TW_LEN + 1] = np.sum(N_alpha_Resi_Omega, 2) 191 | else: 192 | A_t1[m,:] = A_t0[m,:] - (mu - lam * mu) * A_t0[m,:] @ TAinv.T + delta_A_t0_m.T 193 | 194 | # Store RAinv 195 | RAinv[m * rank : (m+1) * rank,:] = TAinv 196 | 197 | # Final update of A 198 | A_t0 = A_t1.copy() 199 | 200 | ## update B 201 | for n in range(0,cols): 202 | Omega_mat_ind = np.where(Omega_mat[:,n] > 0)[0] 203 | I_col = I_mat_Omega[:, n] 204 | I_col_Omega = I_col[Omega_mat_ind] 205 | A_t0_Omega = A_t0[Omega_mat_ind,:] 206 | N_beta_Omega = A_t0_Omega @ np.diag(gamma) 207 | N_beta_beta_t_Omega = N_beta_Omega.T @ N_beta_Omega 208 | 209 | # Calc TBinv(i.e.RBinv) 210 | TBinv = lam**(-1) * RBinv[n*rank: (n+1) * rank,:] 211 | if TW_Flag: 212 | Oldest_beta_beta_t = N_BetaBetaT[n*rank:(n+1) * rank, 1: rank] 213 | TBinv = np.linalg.inv(np.linalg.inv(TBinv) + N_beta_beta_t_Omega + (mu - lam * mu) * np.eye(rank) 214 | - lam**TW_LEN * Oldest_beta_beta_t) 215 | else: 216 | TBinv = np.linalg.inv(np.linalg.inv(TBinv) + N_beta_beta_t_Omega + (mu - lam * mu) * np.eye(rank)) 217 | 218 | # Calc delta B_t0(n,:) 219 | recX_col_Omega = B_t0[n,:] @ N_beta_Omega.T 220 | resi_col_Omega = I_col_Omega.T - recX_col_Omega 221 | N_beta_Resi_Omega = N_beta_Omega.T @ np.diag(resi_col_Omega) 222 | N_resi_Rt_beta = TBinv @ N_beta_Resi_Omega 223 | delta_C_t0_n = np.sum(N_resi_Rt_beta, 1) 224 | 225 | if TW_Flag: 226 | # Upddate B 227 | Oldest_beta_resi = N_BetaResi[n*rank:(n+1) * rank, 1].T 228 | B_t1[n,:] = B_t0[n,:] - (mu - lam * mu) * B_t0[n,:] @ TBinv.T + delta_C_t0_n.T -lam ** TW_LEN \ 229 | * Oldest_beta_resi 230 | 231 | # Store data 232 | N_BetaBetaT[n*rank: (n+1) * rank, TW_LEN * 
rank + 1: (TW_LEN + 1) * rank] = N_beta_beta_t_Omega 233 | N_BetaResi[n*rank: (n+1) * rank, TW_LEN + 1] = np.sum(N_beta_Resi_Omega, 2) 234 | else: 235 | B_t1[n,:] = B_t0[n,:] - (mu - lam * mu) * B_t0[n,:] @ TBinv.T + delta_C_t0_n.T 236 | 237 | # Store RBinv 238 | RBinv[n*rank: (n+1) * rank,:] = TBinv 239 | 240 | # Final update of B 241 | B_t0 = B_t1.copy() 242 | 243 | # # Calculate gamma 244 | # temp3 = 0 245 | # temp4 = 0 246 | # for m in range(0, rows): 247 | # alpha_remat = np.tile(A_t0[m, :].T, (cols, 1)).T 248 | # alpha_beta = alpha_remat * B_t0.T 249 | # I_row = I_mat_Omega[m, :] 250 | # temp3 = temp3 + alpha_beta @ I_row.T 251 | 252 | # Omega_mat_ind = np.where(Omega_mat[m, :] > 0)[0] 253 | # alpha_beta_Omega = alpha_beta[:, Omega_mat_ind] 254 | # temp4 = temp4 + alpha_beta_Omega @ alpha_beta_Omega.T 255 | 256 | # temp4 = lam * np.eye(rank) + temp4 257 | # gamma = np.linalg.lstsq(temp4, temp3, rcond=-1)[0] # gamma = temp4 \ temp3; 258 | 259 | tElapsed = time.time() - tStart 260 | 261 | #Store gamma into C_t0 262 | C_t0[col_order[k],:] = gamma.T 263 | 264 | #Reconstruct Low - rank Matrix 265 | L_rec = A_t0 @ np.diag(gamma) @ B_t0.T 266 | Yhat[:,:,k] = L_rec 267 | 268 | ## Diagnostics 269 | 270 | if store_matrix: 271 | E_rec = I_mat - L_rec 272 | sub_infos['E'] = np.append(sub_infos['E'], np.vectorize(E_rec)) 273 | I = sub_infos['I'] 274 | I[:,k] = np.vectorize(I_mat_Omega) 275 | sub_infos['I'] = I 276 | 277 | L = sub_infos['L'] 278 | L[:, k] = np.vectorize(L_rec) 279 | sub_infos['L'] = L 280 | 281 | E = sub_infos['E'] 282 | E[:, k] = np.vectorize(E_rec) 283 | sub_infos['E'] = E 284 | 285 | # sub_infos.I[:, k] = np.vectorize(I_mat_Omega) 286 | # sub_infos.L[:, k] = np.vectorize(L_rec) 287 | # sub_infos.E[:, k] = np.vectorize(E_rec) 288 | 289 | if store_subinfo: 290 | # Residual Error 291 | norm_residual = np.linalg.norm(I_mat - L_rec,'fro') 292 | norm_I = np.linalg.norm(I_mat,'fro') 293 | error = norm_residual / norm_I 294 | sub_infos['inner_iter'] = np.append(sub_infos['inner_iter'], (outiter + 1 - 1) * slice_length + k + 1) 295 | sub_infos['err_residual'] = np.append(sub_infos['err_residual'], error) 296 | sub_infos['times'] = np.append(sub_infos['times'],tElapsed) 297 | sub_infos['error'] = np.append(sub_infos['error'],fun(Yhat)) 298 | 299 | #Running - average Estimation Error 300 | if k == 0: 301 | run_error = error 302 | else: 303 | run_error = (sub_infos['err_run_ave'][-1] * (k + 1 - 1) + error) / (k+1) 304 | 305 | sub_infos['err_run_ave'] = np.append(sub_infos['err_run_ave'], run_error) 306 | 307 | # Store reconstruction Error 308 | if store_matrix: 309 | E_rec = I_mat - L_rec 310 | sub_infos['E'] = np.append(sub_infos['E'],np.vectorize(E_rec)) 311 | 312 | # for f in range(0,slice_length): 313 | # gamma = C_t0[f,:].T 314 | # Rec[:,:, f] = A_t0 @ np.diag(gamma) @ B_t0.T 315 | 316 | # Global train_cost computation 317 | # train_cost = compute_cost_tensor(Rec, Omega, A_Omega, tensor_dims) 318 | # if Gamma is None and A_Gamma is None: 319 | # test_cost = compute_cost_tensor(Rec, Gamma, A_Gamma, tensor_dims) 320 | # else: 321 | # test_cost = 0 322 | 323 | # sub_infos['global_train_cost'] = np.append(sub_infos['global_train_cost'],train_cost) 324 | # sub_infos['global_test_cost'] = np.append(sub_infos['global_test_cost'],test_cost) 325 | 326 | if verbose: 327 | train_cost = 0 328 | fnum = (outiter + 1 - 1) * slice_length + k + 1 329 | print('OLSTEC: fnum = {:3d}, cost = {:.3f}, error = {:.8f}'.format(fnum, train_cost, error)) 330 | 331 | # store infos 332 | infos['iter'] = 
np.append(infos['iter'],outiter) 333 | # infos['time'] = np.append(infos['time'], infos['time'][-1] + time.time() - t_begin) 334 | 335 | if store_subinfo is False: 336 | # for f in range(0,slice_length): 337 | # gamma = C_t0[f,:].T 338 | # Rec[:,:, f] = A_t0 @ np.diag(gamma) @ B_t0.T 339 | pass 340 | 341 | # train_cost = compute_cost_tensor(Rec, Omega, A_Omega, tensor_dims) 342 | if Gamma is None and A_Gamma is None: 343 | test_cost = compute_cost_tensor(Rec, Gamma, A_Gamma, tensor_dims) 344 | else: 345 | test_cost = 0 346 | 347 | infos['train_cost'] = [infos['train_cost'], train_cost] 348 | infos['test_cost'] = [infos['test_cost'], test_cost] 349 | 350 | if verbose > 0: 351 | print('OLSTEC Epoch {:3d}, Cost {:.7f}, Cost(test) {:.7f}'.format(outiter, train_cost, test_cost)) 352 | 353 | #stopping criteria: cost tolerance reached 354 | # if train_cost < tolcost: 355 | # print('train_cost sufficiently decreased.') 356 | # break 357 | 358 | Xsol = { 359 | 'A': A_t0, 360 | 'B': B_t0, 361 | 'C': C_t0 362 | } 363 | 364 | return Xsol, Yhat, infos, sub_infos -------------------------------------------------------------------------------- /jupyter_notebooks/run_all_mri.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | from tsvd import * 6 | import scipy.io as sio 7 | from stc import * 8 | from olstec import * 9 | from tecpsgd import * 10 | from skimage.measure import compare_ssim 11 | 12 | 13 | 14 | def im2kspace(images): 15 | ### images are nt x nx x ny 16 | c1 = np.sqrt(n2) 17 | c2 = np.sqrt(n3) 18 | 19 | # c1 = 1 20 | # c2 = 1 21 | z = np.fft.fftshift(np.fft.fft(np.fft.fftshift(images, axes=1), axis=1), axes=1) * c1 22 | kdata = np.fft.fftshift(np.fft.fft(np.fft.fftshift(z, axes=2), axis=2), axes=2) * c2 23 | return kdata 24 | 25 | 26 | def kspace2im(kdata): 27 | ### kdata is nt x nx x ny 28 | 29 | # data_rec = tensor.array() + 1j * C.array() 30 | # data_rec = np.transpose(data_rec, (1, 0, 2)) 31 | n1, n2, n3 = kdata.shape 32 | c = np.fft.fftshift(np.fft.ifft(np.fft.fftshift(kdata, axes=1), axis=1), axes=1) * 1 / np.sqrt(n2) 33 | im_rec = np.fft.fftshift(np.fft.ifft(np.fft.fftshift(c, axes=2), axis=2), axes=2) * 1 / np.sqrt(n3) 34 | # c = np.fft.fftshift(np.fft.ifft(np.fft.fftshift(kdata, axes=1), axis=1), axes=1) 35 | # im_rec = np.fft.fftshift(np.fft.ifft(np.fft.fftshift(c, axes=2), axis=2), axes=2) 36 | return np.abs(im_rec) 37 | 38 | 39 | def compute_stats(Rhat, Chat, R, C, orig_im, kdata_orig): 40 | slice_err = [] 41 | for i in range(0, n2): 42 | slice_err.append(tfrobnorm_array(Rhat.array()[:, i, :] - R.array()[:, i, :]) / tfrobnorm_array(R.array()[:, i, 43 | :])) 44 | 45 | kdata_rec = Rhat.array() + 1j * Chat.array() 46 | kdata_rec = np.transpose(kdata_rec, (1, 0, 2)) 47 | im_rec = kspace2im(kdata_rec) 48 | rec_im = np.abs(im_rec) 49 | 50 | plt.figure(figsize=(10, 10)) 51 | plt.subplot(2, 2, 1) 52 | plt.imshow(orig_im[vis_idx, :, :], cmap='gray') 53 | plt.title('Original magnitude image') 54 | 55 | plt.subplot(2, 2, 2) 56 | test = np.log(np.abs(kdata_orig[vis_idx, :, :]) + 1e-13) 57 | plt.imshow(test, cmap='gray') 58 | plt.title('Original Magnitude kspace') 59 | 60 | plt.subplot(2, 2, 3) 61 | plt.imshow(rec_im[vis_idx, :, :], cmap='gray') 62 | plt.title('Reconstructed') 63 | plt.subplot(2, 2, 4) 64 | plt.imshow(np.log(np.abs((kdata_rec[vis_idx, :, :])) + 1e-13), cmap='gray') 65 | plt.title('Reconstructed Magnitude kspace') 66 | plt.show() 67 | 
plt.close() 68 | 69 | return slice_err, rec_im, kdata_rec 70 | # plt.ioff() 71 | datasets = ['brain','invivo_cardiac'] 72 | # datasets = ['brain'] 73 | # datasets = ['invivo_cardiac'] 74 | # dataset = 'brain' 75 | # dataset = 'invivo_cardiac' 76 | # datasets = ['aperiodic_pincat'] 77 | vis_idx = 30 78 | for dataset in datasets: 79 | 80 | if(dataset is 'brain'): 81 | real_filepath = '/Users/kgilman/Desktop/t-SVD/real_mri.mat' 82 | imag_filepath = '/Users/kgilman/Desktop/t-SVD/imag_mri.mat' 83 | 84 | real_data = sio.loadmat(real_filepath)['real_mri'] 85 | imag_data = sio.loadmat(imag_filepath)['imag_mri'] 86 | 87 | kdata = real_data + 1j*(imag_data) 88 | kdata = np.transpose(kdata,(2,0,1)) 89 | orig_im = kspace2im(kdata) 90 | fig = plt.figure() 91 | plt.imshow(orig_im[10,:,:],cmap='gray') 92 | plt.show() 93 | plt.close(fig) 94 | else: 95 | if(dataset is 'invivo_cardiac'): 96 | data = sio.loadmat('/Users/kgilman/Desktop/t-SVD/invivo_perfusion.mat')['x'] 97 | else: 98 | data = sio.loadmat('/Users/kgilman/Desktop/t-SVD/aperiodic_pincat.mat')['new'] 99 | data = np.transpose(data,(2,0,1)) 100 | n1,n2,n3 = data.shape 101 | # orig_im = np.abs(data) 102 | orig_im = (data) 103 | ## convert to kspace 104 | kdata = im2kspace(data) 105 | 106 | n1,n2,n3 = kdata.shape 107 | ### Form the real and imag tensors 108 | kdata_orig = kdata.copy() 109 | kdata = np.transpose(kdata, (1, 0, 2)) 110 | 111 | R = np.real(kdata) 112 | Rmean = np.mean(R) 113 | 114 | C = np.imag(kdata) 115 | Cmean = np.mean(C) 116 | 117 | # R -= Rmean 118 | # C -= Cmean 119 | 120 | R = Tensor(R) 121 | C = Tensor(C) 122 | 123 | n1, n2, n3 = R.shape() 124 | 125 | 126 | def computeStats(im): 127 | plt.subplot(1,2,1) 128 | plt.imshow(im[30,:,:],cmap='gray') 129 | plt.subplot(1,2,2) 130 | plt.imshow(orig_im[30,:,:],cmap='gray') 131 | plt.show() 132 | nrmse = tfrobnorm_array(im - orig_im) / tfrobnorm_array(orig_im) 133 | scores = [] 134 | for i in range(im.shape[0]): 135 | (score, diff) = compare_ssim(orig_im[i, :, :] / np.max(orig_im[i, :, :]), 136 | im[i, :, :] / np.max(im[i, :, :]), full=True) 137 | scores.append(score) 138 | return nrmse, np.mean(scores) 139 | 140 | ### Generate k-space sampling mask 141 | # tubes = [True, False] 142 | tubes = [True, False] 143 | rhos = [0.8, 0.6, 0.5] 144 | # rhos = [0.8] 145 | # rhos = [0.4] 146 | for tube in tubes: 147 | for rho in rhos: 148 | np.random.seed(0) 149 | if (tube is False): 150 | mask = np.random.rand(n1, n2, n3) 151 | mask[mask > rho] = 1 152 | mask[mask <= rho] = 0 153 | mask = mask.astype(int) 154 | else: 155 | mask = np.random.rand(n1, n2) 156 | mask[mask > rho] = 1 157 | mask[mask <= rho] = 0 158 | mask = mask.astype(int) 159 | mask = np.repeat(mask[:, :, np.newaxis], n3, axis=2) 160 | 161 | Rfrob = tfrobnorm(R) 162 | Cfrob = tfrobnorm(C) 163 | R_sample = Tensor(R.array() * mask) 164 | C_sample = Tensor(C.array() * mask) 165 | 166 | kdata_rec = R_sample.array() + 1j * C_sample.array() 167 | kdata_rec = np.transpose(kdata_rec, (1, 0, 2)) 168 | im_rec = kspace2im(kdata_rec) 169 | # subsampled_im = np.abs(im_rec[vis_idx,:,:]) 170 | subsampled_im = np.abs(im_rec) 171 | 172 | # plt.figure(figsize=(15,12)) 173 | fig = plt.figure() 174 | plt.subplot(1, 2, 1) 175 | plt.imshow(subsampled_im[vis_idx, :, :], cmap='gray') 176 | plt.title('Reconstruction from subsampled') 177 | plt.subplot(1, 2, 2) 178 | plt.imshow(np.log(np.abs((kdata_rec[vis_idx, :, :])) + 1e-8), cmap='gray') 179 | plt.title('Magnitude of subsampled kspace') 180 | plt.rcParams.update({'font.size': 22}) 181 | name = 
'/Users/kgilman/Desktop/t-SVD/mri_reconstruct/mri_results/subsampled_and_kspace_tube_' + str(tube) + '_' + str(int(rho * 100)) + '.eps' 182 | plt.savefig(name) 183 | plt.show() 184 | plt.close(fig) 185 | 186 | ############################# TNN-ADMM ################################# 187 | 188 | niter = 300 189 | # niter = 100 190 | 191 | fun = lambda X: [0, tfrobnorm(X - R) / Rfrob] 192 | Rhat_tnn, stats, tElapsed_tensor = lrtc(R_sample, mask, rho=1.1, niter=niter, min_iter = 100, fun=fun, verbose=True) 193 | 194 | fun = lambda X: [0, tfrobnorm(X - C) / Cfrob] 195 | Chat_tnn, stats, tElapsed_tensor = lrtc(C_sample, mask, rho=1.1, niter=niter, min_iter = 100, fun=fun, verbose=True) 196 | 197 | cost_tensor = stats[:, 0] 198 | nrmse_tensor = stats[:, 1] 199 | times_tensor = stats[:, 2] 200 | 201 | print('Time elapsed: Tensor: {:.3f} '.format(tElapsed_tensor)) 202 | print('Final R NRMSE: Tensor: {:.8f} '.format(tfrobnorm(Rhat_tnn - R) / Rfrob)) 203 | print('Final C NRMSE: Tensor: {:.8f} '.format(tfrobnorm(Chat_tnn - C) / Cfrob)) 204 | 205 | slice_err_tnn, rec_im_tnn, kdata_rec_tnn = compute_stats(Rhat_tnn, Chat_tnn, R, C, orig_im, kdata_orig) 206 | 207 | #############################################################3 208 | ## TCTF 209 | # niter = 200 210 | niter = 300 211 | 212 | if(dataset is 'brain'): 213 | rank = 1 214 | else: 215 | rank = 5 216 | 217 | fun = lambda U, V: [0, tfrobnorm(U * V - R) / Rfrob] 218 | Xtctf, Ztctf, stats_tctf, tElapsed_tctf = tctf(R_sample, mask, rank=rank, niter=niter, min_iter = 50, fun=fun, 219 | verbose=False) 220 | Rhat_tctf = Xtctf * Ztctf 221 | 222 | fun = lambda U, V: [0, tfrobnorm(U * V - C) / Cfrob] 223 | Xtctf, Ztctf, stats_tctf, tElapsed_tctf = tctf(C_sample, mask, rank=rank, niter=niter, min_iter = 50, fun=fun, 224 | verbose=False) 225 | Chat_tctf = Xtctf * Ztctf 226 | 227 | nrmse_tctf = stats_tctf[:, 1] 228 | times_tctf = stats_tctf[:, -1] 229 | print('TCTF Time: {:4f}'.format(tElapsed_tctf)) 230 | print('Final R NRMSE: Tensor: {:.8f} '.format(tfrobnorm(Rhat_tctf - R) / Rfrob)) 231 | print('Final C NRMSE: Tensor: {:.8f} '.format(tfrobnorm(Chat_tctf - C) / Cfrob)) 232 | 233 | slice_err_tctf, rec_im_tctf, kdata_rec_tctf = compute_stats(Rhat_tctf, Chat_tctf, R, C, orig_im, kdata_orig) 234 | # 235 | # #############################################################3 236 | ## TOUCAN 237 | if (dataset is 'brain'): 238 | rank = 1 239 | else: 240 | rank = 5 241 | outer = 1 242 | # fun = lambda X: [0, tfrobnorm(X - R) / Rfrob] 243 | 244 | fun = lambda X, k: [0, tfrobnorm_array(X.array()[:, k, :] - R.array()[:, k, :]) / tfrobnorm_array( 245 | R.array()[:, k, :])] 246 | 247 | Rhat_toucan, U, stats_toucan, tElapsed_toucan = toucan(R_sample, mask, rank, tube=tube, mode='online', 248 | outer=outer, 249 | fun=fun, cgtol=1e-7, randomOrder=False, 250 | verbose=False) 251 | 252 | fun = lambda X, k: [0, tfrobnorm_array(X.array()[:, k, :] - C.array()[:, k, :]) / tfrobnorm_array( 253 | C.array()[:, k, :])] 254 | Chat_toucan, U, stats_toucan, tElapsed_toucan = toucan(C_sample, mask, rank, tube=tube, mode='online', 255 | outer=outer, 256 | fun=fun, cgtol=1e-7, randomOrder=False, 257 | verbose=False) 258 | 259 | print('Initial R NRMSE: Tensor: {:.8f} '.format(tfrobnorm(R_sample - R) / Rfrob)) 260 | print('Initial C NRMSE: Tensor: {:.8f} '.format(tfrobnorm(C_sample - C) / Cfrob)) 261 | print('Final R NRMSE: Tensor: {:.8f} '.format(tfrobnorm(Rhat_toucan - R) / Rfrob)) 262 | print('Final C NRMSE: Tensor: {:.8f} '.format(tfrobnorm(Chat_toucan - C) / Cfrob)) 263 | 264 | 
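# TOUCAN note: the real and imaginary k-space tensors are completed independently, with a
# single online pass over each (outer=1); `fun` records per-lateral-slice NRMSE during the
# pass, while the call below recomputes slice errors from the final estimates and rebuilds
# the reconstructed image sequence.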
slice_err_toucan, rec_im_toucan, kdata_rec_toucan = compute_stats(Rhat_toucan, Chat_toucan, R, C, orig_im, 265 | kdata_orig) 266 | 267 | # 268 | # #############################################################3 269 | #### STC 270 | 271 | if(tube is False): 272 | 273 | Tensor_R_sample = np.transpose(R_sample.array(), [0, 2, 1]) 274 | Tensor_C_sample = np.transpose(C_sample.array(), [0, 2, 1]) 275 | Mask_Y = np.transpose(mask, [0, 2, 1]) 276 | numcycles = 1 277 | outer = 1 278 | r1 = 25 279 | r2 = 25 280 | # r3 = 1 281 | if (dataset is 'brain'): 282 | r3 = 1 283 | else: 284 | r3 = 5 285 | fun = lambda Lhat, idx: [0, 1] 286 | Rhat_stc, stats, tElapsed_stc = stc(Tensor_R_sample, Mask_Y, r1, r2, r3, outer, numcycles, fun=fun, 287 | verbose=False) 288 | Chat_stc, stats, tElapsed_stc = stc(Tensor_C_sample, Mask_Y, r1, r2, r3, outer, numcycles, fun=fun, 289 | verbose=False) 290 | 291 | Rhat_stc = Tensor(np.transpose(Rhat_stc, [0, 2, 1])) 292 | Chat_stc = Tensor(np.transpose(Chat_stc, [0, 2, 1])) 293 | 294 | Rhat_nrmse_stc = tfrobnorm((Rhat_stc) - R) / Rfrob 295 | Chat_nrmse_stc = tfrobnorm((Chat_stc) - C) / Cfrob 296 | 297 | print('STC Time: {:4f}'.format(tElapsed_stc)) 298 | print('Rhat NRMSE STC: {:6f}'.format(Rhat_nrmse_stc)) 299 | print('Chat NRMSE STC: {:6f}'.format(Chat_nrmse_stc)) 300 | 301 | slice_err_stc, rec_im_stc, kdata_rec_stc = compute_stats(Rhat_stc, Chat_stc, R, C, orig_im, kdata_orig) 302 | 303 | ###################################################### 304 | ### OLSTEC 305 | 306 | rank = 50 307 | 308 | Tensor_R_sample = np.transpose(R.array(), [0, 2, 1]) 309 | Tensor_C_sample = np.transpose(C.array(), [0, 2, 1]) 310 | Mask_Y = np.transpose(mask, [0, 2, 1]) 311 | 312 | tensor_dims = [n1, n3, n2] 313 | maxepochs = 1 314 | tolcost = 1e-14 315 | permute_on = False 316 | 317 | 318 | if (dataset is 'brain'): 319 | options = { 320 | 'maxepochs': maxepochs, 321 | 'tolcost': tolcost, 322 | 'lam': 0.8, 323 | 'mu': 0.001, 324 | 'permute_on': permute_on, 325 | 'store_subinfo': True, 326 | 'store_matrix': False, 327 | 'verbose': False, 328 | 'tw_flag': None, 329 | 'tw_len': None 330 | } 331 | else: 332 | options = { 333 | 'maxepochs': maxepochs, 334 | 'tolcost': tolcost, 335 | 'lam': 0.5, 336 | 'mu': 0.0001, 337 | 'permute_on': permute_on, 338 | 'store_subinfo': True, 339 | 'store_matrix': False, 340 | 'verbose': False, 341 | 'tw_flag': None, 342 | 'tw_len': None 343 | } 344 | 345 | Xinit = { 346 | 'A': np.random.randn(tensor_dims[0], rank), 347 | 'B': np.random.randn(tensor_dims[1], rank), 348 | 'C': np.random.randn(tensor_dims[2], rank) 349 | } 350 | 351 | Xsol_olstec, Rhat_olstec, info_olstec, sub_infos_olstec = OLSTEC(Tensor_R_sample, Mask_Y, None, tensor_dims, 352 | rank, 353 | Xinit, options) 354 | Xsol_olstec, Chat_olstec, info_olstec, sub_infos_olstec = OLSTEC(Tensor_C_sample, Mask_Y, None, tensor_dims, 355 | rank, 356 | Xinit, options) 357 | 358 | Rhat_olstec = Tensor(np.transpose(Rhat_olstec, [0, 2, 1])) 359 | Chat_olstec = Tensor(np.transpose(Chat_olstec, [0, 2, 1])) 360 | 361 | Rhat_nrmse_olstec = tfrobnorm(Rhat_olstec - R) / Rfrob 362 | Chat_nrmse_olstec = tfrobnorm(Chat_olstec - C) / Cfrob 363 | tElapsed_olstec = np.sum(sub_infos_olstec['times']) 364 | 365 | print('OLSTEC Time: {:4f}'.format(tElapsed_olstec)) 366 | print('Rhat NRMSE OLSTEC: {:6f}'.format(Rhat_nrmse_olstec)) 367 | print('Chat NRMSE OLSTEC: {:6f}'.format(Chat_nrmse_olstec)) 368 | 369 | slice_err_olstec, rec_im_olstec, kdata_rec_olstec = compute_stats(Rhat_olstec, Chat_olstec, R, C, orig_im, 370 | kdata_orig) 371 | 
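# OLSTEC note: unlike the STC block above, Tensor_R_sample / Tensor_C_sample here are built
# from the full tensors R and C rather than from R_sample / C_sample; the sampling mask is
# passed separately as Mask_Y, and (as TeCPSGD does via A_Omega = Omega * A) the solver
# presumably applies it internally, so this is equivalent to passing the zero-filled samples.
# The same random Xinit is reused for both the real and the imaginary runs.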
#######################################################################3 372 | ##### TeCPSGD 373 | rank = 50 374 | tensor_dims = [n1, n3, n2] 375 | maxepochs = 1 376 | tolcost = 1e-14 377 | permute_on = False 378 | 379 | if (dataset is 'brain'): 380 | options = { 381 | 'maxepochs': maxepochs, 382 | 'tolcost': tolcost, 383 | 'lam': 0.001, 384 | 'stepsize': 0.01, 385 | 'permute_on': permute_on, 386 | 'store_subinfo': True, 387 | 'store_matrix': False, 388 | 'verbose': False 389 | } 390 | 391 | else: 392 | options = { 393 | 'maxepochs': maxepochs, 394 | 'tolcost': tolcost, 395 | 'lam': 0.0001, 396 | 'stepsize': 100000, 397 | 'permute_on': permute_on, 398 | 'store_subinfo': True, 399 | 'store_matrix': False, 400 | 'verbose': False 401 | } 402 | 403 | # Xinit = { 404 | # 'A': np.random.randn(tensor_dims[0], rank), 405 | # 'B': np.random.randn(tensor_dims[1], rank), 406 | # 'C': np.random.randn(tensor_dims[2], rank) 407 | # } 408 | 409 | Xsol_TeCPSGD, Rhat_tecpsgd, info_TeCPSGD, sub_infos_TeCPSGD = TeCPSGD(Tensor_R_sample, Mask_Y, None, 410 | tensor_dims, rank, 411 | Xinit, options) 412 | 413 | Xsol_TeCPSGD, Chat_tecpsgd, info_TeCPSGD, sub_infos_TeCPSGD = TeCPSGD(Tensor_C_sample, Mask_Y, None, 414 | tensor_dims, rank, 415 | Xinit, options) 416 | 417 | Rhat_tecpsgd = Tensor(np.transpose(Rhat_tecpsgd, [0, 2, 1])) 418 | Chat_tecpsgd = Tensor(np.transpose(Chat_tecpsgd, [0, 2, 1])) 419 | 420 | Rhat_nrmse_tecpsgd = tfrobnorm(Rhat_tecpsgd - R) / Rfrob 421 | Chat_nrmse_tecpsgd = tfrobnorm(Chat_tecpsgd - C) / Cfrob 422 | 423 | print('Rhat NRMSE TeCPSGD: {:6f}'.format(Rhat_nrmse_tecpsgd)) 424 | print('Chat NRMSE TeCPSGD: {:6f}'.format(Chat_nrmse_tecpsgd)) 425 | tElapsed_tecpsgd = np.sum(sub_infos_TeCPSGD['times']) 426 | 427 | slice_err_tecpsgd, rec_im_tecpsgd, kdata_rec_tecpsgd = compute_stats(Rhat_tecpsgd, Chat_tecpsgd, R, C, 428 | orig_im, kdata_orig) 429 | 430 | ############################################## 431 | 432 | 433 | # print(name + ' NRMSE: {:.5f}'.format(nrmse)) 434 | # print(name + ' SSIM: {:.5f} \n'.format(score)) 435 | 436 | vis_idx = 39 437 | fig = plt.figure() 438 | plt.figure(figsize=(20, 20)) 439 | plt.subplot(2, 4, 1) 440 | plt.imshow(orig_im[vis_idx, :, :], cmap='gray') 441 | # plt.title('Original: NRMSE: {:5f}, SSIM {:5f}'.format(*computeStats(orig_im))) 442 | plt.title('Original') 443 | 444 | plt.subplot(2, 4, 2) 445 | plt.imshow(subsampled_im[vis_idx, :, :], cmap='gray') 446 | # plt.title('Subsampled: NRMSE: {:5f}, SSIM {:5f}'.format(*computeStats(subsampled_im))) 447 | plt.title('Subsampled') 448 | 449 | plt.subplot(2, 4, 3) 450 | plt.imshow(rec_im_toucan[vis_idx, :, :], cmap='gray') 451 | # plt.title('TOUCAN: NRMSE: {:5f}, SSIM {:5f}'.format(*computeStats(toucan_rec_im))) 452 | plt.title('TOUCAN') 453 | 454 | plt.subplot(2, 4, 4) 455 | plt.imshow(rec_im_tnn[vis_idx, :, :], cmap='gray') 456 | # plt.title('TNN-ADMM: NRMSE: {:5f}, SSIM {:5f}'.format(*computeStats(tnn_rec_im))) 457 | plt.title('TNN-ADMM') 458 | 459 | plt.subplot(2, 4, 5) 460 | plt.imshow(rec_im_tctf[vis_idx, :, :], cmap='gray') 461 | # plt.title('TCTF NRMSE: {:5f}, SSIM {:5f}'.format(*computeStats(tctf_rec_im))) 462 | plt.title('TCTF') 463 | 464 | if(tube is False): 465 | plt.subplot(2,4,6) 466 | plt.imshow(rec_im_stc[vis_idx,:,:],cmap='gray') 467 | # plt.title('STC NRMSE: {:5f}, SSIM {:5f}'.format(*computeStats(stc_rec_im))) 468 | plt.title('STC') 469 | 470 | plt.subplot(2, 4, 7) 471 | plt.imshow(rec_im_olstec[vis_idx, :, :], cmap='gray') 472 | # plt.title('STC NRMSE: {:5f}, SSIM 
{:5f}'.format(*computeStats(stc_rec_im))) 473 | plt.title('OLSTEC') 474 | 475 | plt.subplot(2, 4, 8) 476 | plt.imshow(rec_im_tecpsgd[vis_idx, :, :], cmap='gray') 477 | # plt.title('STC NRMSE: {:5f}, SSIM {:5f}'.format(*computeStats(stc_rec_im))) 478 | plt.title('TeCPSGD') 479 | 480 | plt.rcParams.update({'font.size': 22}) 481 | name = '/Users/kgilman/Desktop/t-SVD/mri_reconstruct/mri_results/tensor_mri_reconstruct_' + dataset + '_tube' + str(tube) + '_' + str(int(rho * 100)) + \ 482 | '.eps' 483 | plt.savefig(name) 484 | plt.show() 485 | plt.close() 486 | 487 | fig = plt.figure(figsize=(12, 5), tight_layout=True) 488 | # plt.semilogy(np.arange(0, len(slice_err_tnn)), slice_err_tnn, 'r', label='TNN-ADMM') 489 | plt.semilogy(np.arange(0, len(slice_err_toucan)), slice_err_toucan, '#ff7f0e', label='TOUCAN') 490 | # plt.semilogy(np.arange(0, len(slice_err_tctf)), slice_err_tctf, 'k', label='TCTF') 491 | if(tube is False): 492 | plt.semilogy(np.arange(0, len(slice_err_stc)), slice_err_stc, '#FF007F', label='STC') 493 | plt.semilogy(np.arange(0, len(slice_err_olstec)), slice_err_olstec, '#8B008B', label='OLSTEC') 494 | plt.semilogy(np.arange(0, len(slice_err_tecpsgd)), slice_err_tecpsgd, '#00FFFF', label='TeCPSGD') 495 | plt.legend(bbox_to_anchor=(1.5, 1)) 496 | plt.title('NRMSE by frame') 497 | plt.xlabel('Frame idx') 498 | name = '/Users/kgilman/Desktop/t-SVD/mri_reconstruct/mri_results/tensor_mri_reconstruct_' + dataset + '_tube_' + str(tube) + \ 499 | '_' + str(int(rho * 100)) + '_frameNRMSE.eps' 500 | plt.savefig(name) 501 | plt.show() 502 | plt.close() 503 | 504 | ### Print computation times 505 | filename = '/Users/kgilman/Desktop/t-SVD/mri_reconstruct/mri_results/times_' + dataset + '_tube_' + str(tube) + '_' + str(int(rho * 100)) + '.text' 506 | print('TNN ADMM: {:.3f} '.format(tElapsed_tensor), file=open(filename, "a")) 507 | print('TCTF: {:4f}'.format(tElapsed_tctf), file=open(filename, "a")) 508 | print('TOUCAN: {:4f}'.format(tElapsed_toucan), file=open(filename, "a")) 509 | if(tube is False): 510 | print('STC: {:4f}'.format(tElapsed_stc),file=open(filename, "a")) 511 | print('OLSTEC: {:4f}'.format(tElapsed_olstec), file=open(filename, "a")) 512 | print('TeCPSGD: {:4f}'.format(tElapsed_tecpsgd), file=open(filename, "a")) 513 | 514 | filename = '/Users/kgilman/Desktop/t-SVD/mri_reconstruct/mri_results/stats_' + dataset + '_tube_' + str(tube) + '_' + str(int(rho * 100)) + '.text' 515 | print('Original: NRMSE: {:5f}, SSIM {:5f}'.format(*computeStats(orig_im)), file=open(filename, "a")) 516 | print('Subsampled: NRMSE: {:5f}, SSIM {:5f}'.format(*computeStats(subsampled_im)), file=open(filename, "a")) 517 | print('TOUCAN: NRMSE: {:5f}, SSIM {:5f}'.format(*computeStats(rec_im_toucan)), file=open(filename, "a")) 518 | print('TNN-ADMM: NRMSE: {:5f}, SSIM {:5f}'.format(*computeStats(rec_im_tnn)), file=open(filename, "a")) 519 | print('TCTF NRMSE: {:5f}, SSIM {:5f}'.format(*computeStats(rec_im_tctf)), file=open(filename, "a")) 520 | if(tube is False): 521 | print('STC NRMSE: {:5f}, SSIM {:5f}'.format(*computeStats(rec_im_stc)),file=open(filename,"a")) 522 | print('OLSTEC NRMSE: {:5f}, SSIM {:5f}'.format(*computeStats(rec_im_olstec)), file=open(filename, "a")) 523 | print('TeCPSGD NRMSE: {:5f}, SSIM {:5f}'.format(*computeStats(rec_im_tecpsgd)), file=open(filename, "a")) 524 | 525 | 526 | 527 | 528 | -------------------------------------------------------------------------------- /jupyter_notebooks/stc.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | import time 3 | from numpy import ndarray 4 | 5 | def grouse_stream(v, xIdx, U, step=None): 6 | 7 | ### Main GROUSE update 8 | 9 | 10 | U_Omega = U[xIdx,:] 11 | v_Omega = v[xIdx] 12 | w_hat = np.linalg.pinv(U_Omega)@v_Omega 13 | 14 | r = v_Omega - U_Omega @ w_hat 15 | 16 | rnorm = np.linalg.norm(r) 17 | wnorm = np.linalg.norm(w_hat) 18 | sigma = rnorm * np.linalg.norm(w_hat) 19 | 20 | if(step is None): 21 | t = np.arctan(rnorm / wnorm) 22 | else: 23 | t = step * sigma 24 | 25 | alpha = (np.cos(t) - 1) / wnorm**2 26 | beta = np.sin(t) / sigma 27 | Ustep = U @ (alpha * w_hat) 28 | Ustep[xIdx] = Ustep[xIdx] + beta * r 29 | Uhat = U + np.outer(Ustep, w_hat) 30 | 31 | 32 | return Uhat, w_hat 33 | 34 | 35 | def stc_stream(Yvec, M, U1, U2, U3, NumCycles): 36 | 37 | V = np.multiply(M, Yvec) 38 | 39 | t,o = V.shape 40 | 41 | ## Mode 1 42 | for k in range(0,NumCycles): 43 | for i in range(0,t): 44 | idx = np.where(M[i,:] > 0)[0] 45 | U1, w1 = grouse_stream(V[i,:],idx,U1) 46 | 47 | W1 = np.linalg.pinv(U1) @ Yvec.T 48 | 49 | ## Mode 2 50 | for k in range(0, NumCycles): 51 | for j in range(0, o): 52 | idx = np.where(M[:, j] > 0)[0] 53 | U2, w2 = grouse_stream(V[:, j], idx, U2) 54 | 55 | W2 = np.linalg.pinv(U2) @ Yvec 56 | 57 | ## Mode 3 58 | for k in range(0,NumCycles): 59 | idx = np.where(np.reshape(M,-1) > 0)[0] 60 | U3,W3 = grouse_stream(np.reshape(Yvec,-1),idx,U3) 61 | 62 | 63 | return U1,U2,U3,W1,W2,W3 64 | 65 | 66 | def stc(Y,mask,r1,r2,r3,outercycles,numcycles,tol=1e-9,fun=lambda Lhat: [0,0],randomOrder=False,verbose=False): 67 | 68 | t,o,d = Y.shape 69 | 70 | Lhat = np.zeros(Y.shape) 71 | 72 | ### Initialize U 73 | U1 = np.linalg.svd(np.random.randn(o,r1),full_matrices=False)[0] 74 | U2 = np.linalg.svd(np.random.randn(t,r2),full_matrices=False)[0] 75 | U3 = np.linalg.svd(np.random.randn(o*t,r3),full_matrices=False)[0] 76 | 77 | stats = np.zeros((outercycles * d + 1,3)) 78 | # cost,nrmse = fun(Y) 79 | cost = 0 80 | nrmse = 1 81 | stats[0,:] = [cost,nrmse,0] 82 | 83 | iter = 1 84 | for outer in range(0, outercycles): 85 | if (randomOrder): 86 | frame_order = np.random.permutation(d) 87 | else: 88 | frame_order = np.arange(0,d) 89 | 90 | for inner in range(0, d): 91 | frame_idx = frame_order[inner] 92 | Yvec = Y[:, :, frame_idx] 93 | 94 | tStart = time.time() 95 | U1,U2,U3,W1,W2,W3 = stc_stream(Yvec, mask[:,:,frame_idx], U1, U2, U3, numcycles) 96 | tEnd = time.time() 97 | 98 | tElapsed = tEnd - tStart 99 | 100 | rec1 = U1 @ W1 101 | rec2 = U2 @ W2 102 | rec3 = U3 @ W3 103 | 104 | frec1 = np.transpose(rec1) 105 | frec2 = rec2 106 | frec3 = np.reshape(rec3,Yvec.shape) 107 | 108 | Lvec_hat = 1/3 * (frec1 + frec2 + frec3) 109 | Lhat[:,:,frame_idx] = Lvec_hat 110 | cost, nrmse = fun(Lhat,frame_idx) 111 | if(nrmse < tol): 112 | break; 113 | 114 | stats[iter, :] = [cost, nrmse, tElapsed] 115 | iter += 1 116 | 117 | if(verbose): 118 | if(inner % 10 == 0): 119 | print('Outer[{:d}], Inner[{:d}]: NRMSE: {:.8f} '.format(outer, inner, nrmse)) 120 | 121 | tElapsed = np.sum(stats[:,-1]) 122 | return Lhat, stats, tElapsed 123 | -------------------------------------------------------------------------------- /jupyter_notebooks/tecpsgd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | def compute_cost_tensor(X, P, PA, tensor_dims): 5 | n1 = tensor_dims[0] 6 | n2 = tensor_dims[1] 7 | n3 = tensor_dims[2] 8 | 9 | Diff = P * X - PA 10 | Diff_flat = np.reshape(Diff, (n1 * n2, n3)) 11 | 12 | return 0.5 * 
np.linalg.norm(Diff_flat, 'fro') ** 2 13 | 14 | def TeCPSGD(A_in, Omega_in, Gamma_in, tensor_dims, rank, xinit, options, fun = lambda X: 0): 15 | A = A_in # Full entries 16 | Omega = Omega_in # Trainingset 'Omega' 17 | Gamma = Gamma_in # Test set 'Gamma' 18 | 19 | Yhat = np.zeros(A.shape) 20 | 21 | A_Omega = Omega_in * A_in #Training entries i.e., Omega_in. * A_in 22 | if Gamma_in is not None: 23 | A_Gamma = Gamma_in * A_in # Test entries i.e., Gamma_in. * A_in 24 | else: 25 | A_Gamma = [] 26 | 27 | if xinit is None: 28 | A_t0 = np.random.randn(tensor_dims[0], rank) 29 | B_t0 = np.random.randn(tensor_dims[1], rank) 30 | C_t0 = np.random.randn(tensor_dims[2], rank) 31 | else: 32 | A_t0 = xinit['A'] 33 | B_t0 = xinit['B'] 34 | C_t0 = xinit['C'] 35 | 36 | # set tensor size 37 | rows = tensor_dims[0] 38 | cols = tensor_dims[1] 39 | slice_length = tensor_dims[2] 40 | 41 | # set options 42 | lam = options['lam'] 43 | # mu = options['mu'] 44 | stepsize_init = options['stepsize'] 45 | maxepochs = options['maxepochs'] 46 | tolcost = options['tolcost'] 47 | store_subinfo = options['store_subinfo'] 48 | store_matrix = options['store_matrix'] 49 | verbose = options['verbose'] 50 | 51 | if options['permute_on'] is None: 52 | permute_on = 1 53 | else: 54 | permute_on = options['permute_on'] 55 | 56 | # calculate initial cost 57 | Rec = np.zeros((rows, cols, slice_length)) 58 | for k in range(0,slice_length): 59 | gamma = C_t0[k,:].T 60 | Rec[:,:,k] = A_t0 @ np.diag(gamma) @ B_t0.T 61 | 62 | train_cost = compute_cost_tensor(Rec, Omega, A_Omega, tensor_dims) 63 | 64 | if Gamma is None and A_Gamma is None: 65 | test_cost = compute_cost_tensor(Rec, Gamma, A_Gamma, tensor_dims) 66 | else: 67 | test_cost = 0 68 | 69 | # initialize infos 70 | infos = { 71 | 'iter': 0, 72 | 'train_cost': train_cost, 73 | 'test_cost': test_cost, 74 | 'time': [0] 75 | } 76 | 77 | # initialize sub_infos 78 | sub_infos = { 79 | 'inner_iter': 0, 80 | 'err_residual':0, 81 | 'error': 0, 82 | 'err_run_ave': [0], 83 | 'global_train_cost':0, 84 | 'global_test_cost':0, 85 | 'times':[0] 86 | } 87 | 88 | if store_matrix: 89 | sub_infos['I'] = np.zeros((rows * cols, slice_length)) 90 | sub_infos['L'] = np.zeros((rows * cols, slice_length)) 91 | sub_infos['E'] = np.zeros((rows * cols, slice_length)) 92 | 93 | # set parameters 94 | eta = 0 95 | 96 | if verbose > 0: 97 | print('TeCPSGD [{:3f}] Epoch 000, Cost {:.5f}, Cost(test) {:.5f}, Stepsize {:.5f}'.format(stepsize_init, 98 | train_cost,test_cost, eta)) 99 | # main loop 100 | for outiter in range(0,maxepochs): 101 | #permute samples 102 | if permute_on: 103 | col_order = np.random.permutation(slice_length) 104 | else: 105 | col_order = np.arange(0,slice_length) 106 | 107 | 108 | # Begin the time counter for the epoch 109 | 110 | for k in range(0,slice_length): 111 | 112 | tStart = time.time() 113 | 114 | fnum = (outiter + 1 - 1) * slice_length + k + 1 115 | 116 | #sampled original image 117 | I_mat = A[:,:, col_order[k]] 118 | Omega_mat = Omega[:,:, col_order[k]] 119 | I_mat_Omega = Omega_mat * I_mat 120 | 121 | # Recalculate gamma(C) 122 | temp3 = 0 123 | temp4 = 0 124 | for m in range(0,rows): 125 | alpha_remat = np.tile(A_t0[m,:].T, (cols,1)).T 126 | alpha_beta = alpha_remat * B_t0.T 127 | I_row = I_mat_Omega[m,:] 128 | temp3 = temp3 + alpha_beta @ I_row.T 129 | 130 | Omega_mat_ind = np.where(Omega_mat[m,:]>0)[0] 131 | alpha_beta_Omega = alpha_beta[:, Omega_mat_ind] 132 | temp4 = temp4 + alpha_beta_Omega @ alpha_beta_Omega.T 133 | 134 | temp4 = lam * np.eye(rank) + temp4 135 | gamma = 
np.linalg.lstsq(temp4,temp3,rcond=-1)[0] # gamma = temp4 \ temp3; 136 | 137 | L_rec = A_t0 @ np.diag(gamma) @ B_t0.T 138 | diff = Omega_mat * (I_mat - L_rec) 139 | 140 | eta = stepsize_init / (1 +lam * stepsize_init * fnum) 141 | A_t1 = (1 - lam * eta) * A_t0 + eta * diff @ B_t0 @ np.diag(gamma) # equation(20) & (21) 142 | B_t1 = (1 - lam * eta) * B_t0 + eta * diff.T @ A_t0 @ np.diag(gamma) # equation (20) & (22) 143 | 144 | A_t0 = A_t1 145 | B_t0 = B_t1 146 | 147 | # Recalculate gamma(C) 148 | temp3 = 0 149 | temp4 = 0 150 | for m in range(0, rows): 151 | alpha_remat = np.tile(A_t0[m, :].T, (cols, 1)).T 152 | alpha_beta = alpha_remat * B_t0.T 153 | I_row = I_mat_Omega[m, :] 154 | temp3 = temp3 + alpha_beta @ I_row.T 155 | 156 | Omega_mat_ind = np.where(Omega_mat[m, :] > 0)[0] 157 | alpha_beta_Omega = alpha_beta[:, Omega_mat_ind] 158 | temp4 = temp4 + alpha_beta_Omega @ alpha_beta_Omega.T 159 | 160 | temp4 = lam * np.eye(rank) + temp4 161 | # gamma = np.linalg.lstsq(temp4, temp3, rcond=-1)[0] # gamma = temp4 \ temp3; 162 | gamma = np.linalg.inv(temp4)@temp3 163 | 164 | tElapsed = time.time() - tStart 165 | 166 | #Store gamma into C_t0 167 | C_t0[col_order[k],:] = gamma.T 168 | 169 | #Reconstruct Low - rank Matrix 170 | L_rec = A_t0 @ np.diag(gamma) @ B_t0.T 171 | Yhat[:,:,k] = L_rec 172 | 173 | ## Diagnostics 174 | 175 | if store_matrix: 176 | E_rec = I_mat - L_rec 177 | sub_infos['E'] = np.append(sub_infos['E'], np.vectorize(E_rec)) 178 | I = sub_infos['I'] 179 | I[:,k] = np.vectorize(I_mat_Omega) 180 | sub_infos['I'] = I 181 | 182 | L = sub_infos['L'] 183 | L[:, k] = np.vectorize(L_rec) 184 | sub_infos['L'] = L 185 | 186 | E = sub_infos['E'] 187 | E[:, k] = np.vectorize(E_rec) 188 | sub_infos['E'] = E 189 | 190 | # sub_infos.I[:, k] = np.vectorize(I_mat_Omega) 191 | # sub_infos.L[:, k] = np.vectorize(L_rec) 192 | # sub_infos.E[:, k] = np.vectorize(E_rec) 193 | 194 | if store_subinfo: 195 | # Residual Error 196 | norm_residual = np.linalg.norm(I_mat - L_rec,'fro') 197 | norm_I = np.linalg.norm(I_mat,'fro') 198 | error = norm_residual / norm_I 199 | sub_infos['inner_iter'] = np.append(sub_infos['inner_iter'], (outiter + 1 - 1) * slice_length + k + 1) 200 | sub_infos['err_residual'] = np.append(sub_infos['err_residual'], error) 201 | sub_infos['times'] = np.append(sub_infos['times'],tElapsed) 202 | sub_infos['error'] = np.append(sub_infos['error'],fun(Yhat)) 203 | 204 | #Running - average Estimation Error 205 | if k == 0: 206 | run_error = error 207 | else: 208 | run_error = (sub_infos['err_run_ave'][-1] * (k + 1 - 1) + error) / (k+1) 209 | 210 | sub_infos['err_run_ave'] = np.append(sub_infos['err_run_ave'], run_error) 211 | 212 | # Store reconstruction Error 213 | if store_matrix: 214 | E_rec = I_mat - L_rec 215 | sub_infos['E'] = np.append(sub_infos['E'],np.vectorize(E_rec)) 216 | 217 | # for f in range(0,slice_length): 218 | # gamma = C_t0[f,:].T 219 | # Rec[:,:, f] = A_t0 @ np.diag(gamma) @ B_t0.T 220 | 221 | # Global train_cost computation 222 | # train_cost = compute_cost_tensor(Rec, Omega, A_Omega, tensor_dims) 223 | # if Gamma is None and A_Gamma is None: 224 | # test_cost = compute_cost_tensor(Rec, Gamma, A_Gamma, tensor_dims) 225 | # else: 226 | # test_cost = 0 227 | 228 | # sub_infos['global_train_cost'] = np.append(sub_infos['global_train_cost'],train_cost) 229 | # sub_infos['global_test_cost'] = np.append(sub_infos['global_test_cost'],test_cost) 230 | 231 | if verbose: 232 | train_cost = 0 233 | fnum = (outiter + 1 - 1) * slice_length + k + 1 234 | print('TeCPSGD: fnum = {:3d}, 
cost = {:.3f}, error = {:.8f}'.format(fnum, train_cost, error)) 235 | 236 | # store infos 237 | infos['iter'] = np.append(infos['iter'],outiter) 238 | # infos['time'] = np.append(infos['time'], infos['time'][-1] + time.time() - t_begin) 239 | 240 | if store_subinfo is False: 241 | # for f in range(0,slice_length): 242 | # gamma = C_t0[f,:].T 243 | # Rec[:,:, f] = A_t0 @ np.diag(gamma) @ B_t0.T 244 | pass 245 | 246 | # train_cost = compute_cost_tensor(Rec, Omega, A_Omega, tensor_dims) 247 | if Gamma is None and A_Gamma is None: 248 | test_cost = compute_cost_tensor(Rec, Gamma, A_Gamma, tensor_dims) 249 | else: 250 | test_cost = 0 251 | 252 | infos['train_cost'] = [infos['train_cost'], train_cost] 253 | infos['test_cost'] = [infos['test_cost'], test_cost] 254 | 255 | if verbose > 0: 256 | print('TeCPSGD [{:3f}] Epoch {:3d}, Cost {:.7f}, Cost(test) {:.7f}, Stepsize {:.7f}'.format( 257 | stepsize_init, outiter, train_cost, test_cost, eta)) 258 | 259 | #stopping criteria: cost tolerance reached 260 | # if train_cost < tolcost: 261 | # print('train_cost sufficiently decreased.') 262 | # break 263 | 264 | Xsol = { 265 | 'A': A_t0, 266 | 'B': B_t0, 267 | 'C': C_t0 268 | } 269 | 270 | return Xsol, Yhat, infos, sub_infos -------------------------------------------------------------------------------- /jupyter_notebooks/toucan_video_completion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from tsvd import *\n", 10 | "import numpy as np\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import matplotlib.image as mpimg\n", 13 | "import time\n", 14 | "import scipy.misc\n", 15 | "from tecpsgd import *\n", 16 | "from olstec import *\n", 17 | "from stc import *\n", 18 | "from skimage.measure import compare_ssim\n", 19 | "\n", 20 | "\n", 21 | "## Read in video data and form tensor\n", 22 | "import os\n", 23 | "path = '/Users/kgilman/Desktop/datasets/dataMERL/'\n", 24 | "files = []\n", 25 | "for i in sorted(os.listdir(path)):\n", 26 | " if os.path.isfile(os.path.join(path,i)) and 'frame' in i:\n", 27 | " files.append(i)\n", 28 | "\n", 29 | "num_frames = len(files)\n", 30 | "nx,ny = mpimg.imread(path + files[0]).shape\n", 31 | "L = np.zeros((int(nx), num_frames, int(ny)))\n", 32 | "\n", 33 | "for i in range(0,num_frames):\n", 34 | " im = mpimg.imread(path + files[i])\n", 35 | " im = im / np.max(im)\n", 36 | " L[:,i,:] = im\n", 37 | "\n", 38 | " \n", 39 | "# Lmean = np.mean(L)\n", 40 | "Lmean = 0\n", 41 | "L -= Lmean\n", 42 | "L = Tensor(L)\n", 43 | "n1,n2,n3 = L.shape()\n", 44 | "\n", 45 | "np.random.seed(0)\n", 46 | "rho = 0.8 #% missing entries\n", 47 | "mask = np.random.rand(n1, n2, n3)\n", 48 | "mask[mask > rho] = 1\n", 49 | "mask[mask <= rho] = 0\n", 50 | "mask = mask.astype(int)\n", 51 | "\n", 52 | "Y = Tensor(L.array() * mask)\n", 53 | "\n", 54 | "Lfrob = tfrobnorm(L)\n", 55 | "\n", 56 | "print(Y.shape())\n", 57 | "\n", 58 | "plt.imshow(L.array()[:,35,:],cmap='gray')\n", 59 | "name = 'original-' + str(int(rho*100)) + '%'+'.eps'\n", 60 | "plt.savefig(name)\n", 61 | "plt.show()\n", 62 | "\n", 63 | "plt.imshow(Y.array()[:,35,:],cmap='gray')\n", 64 | "name = 'observed-' + str(int(rho*100)) + '%'+'.eps'\n", 65 | "plt.savefig(name)\n", 66 | "plt.show()\n" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "frame = L.array()[:,35,:]\n", 76 
| "print(frame.shape)\n", 77 | "blk_size = 25\n", 78 | "ranks = []\n", 79 | "for k in range(0,int(np.floor(n3/blk_size))):\n", 80 | " blk = frame[:,blk_size * k: blk_size * k + blk_size]\n", 81 | " U,S,V = np.linalg.svd(blk,full_matrices=False)\n", 82 | " pwr_perc = np.cumsum(S / np.sum(S))\n", 83 | "# print(pwr_perc)\n", 84 | " idx = np.where(pwr_perc >= 0.8)[0]\n", 85 | "# print(idx)\n", 86 | " ranks.append(idx[0] + 1)\n", 87 | " \n", 88 | "plt.plot(ranks)\n", 89 | "plt.show()\n" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "U,S,V = tsvd(L,full=False)\n", 99 | "\n", 100 | "plt.figure(figsize=(10,5))\n", 101 | "plt.plot(np.diag(S.array()[:,:,0]))\n", 102 | "plt.xticks(np.arange(0, min(n1,n2), step=5))\n", 103 | "plt.rcParams.update({'font.size': 15})\n", 104 | "plt.show()\n", 105 | "\n", 106 | "s = np.diag(S.array()[:,:,0])\n", 107 | "power80 = 0.75*np.sum(s)\n", 108 | "\n", 109 | "cum_sum = 0\n", 110 | "k = 0\n", 111 | "for i in range(0,len(s)):\n", 112 | " cum_sum += s[i]\n", 113 | " k += 1\n", 114 | " if(cum_sum > power80):\n", 115 | " break\n", 116 | "\n", 117 | "print('Number of LR t-SVD Approx Components: {:.3f}'.format(k))" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "### STC" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "## Sequential Tensor Completion\n", 134 | "\n", 135 | "Tensor_Y = np.transpose(L.array(),[0,2,1])\n", 136 | "Mask_Y = np.transpose(mask,[0,2,1])\n", 137 | "numcycles = 1\n", 138 | "outer = 3\n", 139 | "r1 = 50\n", 140 | "r2 = 50\n", 141 | "r3 = 2\n", 142 | "fun = lambda Lhat,idx: [0, tfrobnorm_array(Lhat[:,:,idx] - Tensor_Y[:,:,idx]) / tfrobnorm_array(Tensor_Y[:,:,idx])]\n", 143 | "Lhat, stats, tElapsed = stc(Tensor_Y,Mask_Y,r1,r2,r3,outer,numcycles,fun=fun,verbose=True)\n", 144 | "\n", 145 | "Lhat = np.transpose(Lhat,[0,2,1])\n", 146 | "\n", 147 | "nrmse_stc = tfrobnorm(Tensor(Lhat) - L) / Lfrob\n", 148 | "\n", 149 | "print('STC Time: {:4f}'.format(tElapsed))\n", 150 | "print('NRMSE STC: {:6f}'.format(nrmse_stc))" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "plt.imshow(Lhat[:,35,:],cmap='gray')\n", 160 | "name = 'merl_stc-' + str(int(rho*100)) + '%'+'.eps'\n", 161 | "plt.savefig(name)\n", 162 | "plt.show()" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "# from skimage.metrics import structural_similarity as ssim\n", 172 | "# ssim_noise = ssim(L[:,35,:], Lhat[:,35,:],\n", 173 | "# data_range=np.max(Lhat[:,35,:]) - np.min(Lhat[:,35,:]))\n", 174 | "\n", 175 | "(score, diff) = compare_ssim(L.array()[:,35,:], Lhat[:,35,:], full=True)\n", 176 | "print('STC SSIM: {:.5f}'.format(score))" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "### OLSTEC" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "## OLSTEC\n", 193 | "\n", 194 | "Tensor_Y_Noiseless = np.transpose(L.array(),[0,2,1])\n", 195 | "rank = 100\n", 196 | "OmegaTensor = np.transpose(mask,[0,2,1])\n", 197 | "tensor_dims = [n1,n3,n2]\n", 198 | "maxepochs = 2\n", 199 | "tolcost = 1e-14\n", 200 | 
"permute_on = False\n", 201 | "\n", 202 | "options = {\n", 203 | " 'maxepochs': maxepochs,\n", 204 | " 'tolcost': tolcost,\n", 205 | " 'lam': 0.7,\n", 206 | " 'mu': 0.1,\n", 207 | " 'permute_on': permute_on,\n", 208 | " 'store_subinfo': True,\n", 209 | " 'store_matrix': False,\n", 210 | " 'verbose': False,\n", 211 | " 'tw_flag': None,\n", 212 | " 'tw_len': None\n", 213 | "}\n", 214 | "\n", 215 | "Xinit = {\n", 216 | " 'A': np.random.randn(tensor_dims[0], rank),\n", 217 | " 'B': np.random.randn(tensor_dims[1], rank),\n", 218 | " 'C': np.random.randn(tensor_dims[2], rank)\n", 219 | "}\n", 220 | "\n", 221 | "Xsol_olstec, Y_hat_olstec, info_olstec, sub_infos_olstec = OLSTEC(Tensor_Y_Noiseless, OmegaTensor, None, tensor_dims, rank,\n", 222 | " Xinit, options)\n", 223 | "\n", 224 | "A_t0 = Xsol_olstec['A']\n", 225 | "B_t0 = Xsol_olstec['B']\n", 226 | "C_t0 = Xsol_olstec['C']\n", 227 | "\n", 228 | "# Y_hat_olstec = np.zeros(Tensor_Y_Noiseless.shape)\n", 229 | "# for f in range(0,n2):\n", 230 | "# gamma = C_t0[f,:].T\n", 231 | "# Y_hat_olstec[:,:,f] = A_t0 @ np.diag(gamma) @ B_t0.T\n", 232 | " \n", 233 | "Y_hat_olstec = np.transpose(Y_hat_olstec,[0,2,1])\n", 234 | "\n", 235 | "nrmse_olstec = tfrobnorm(Tensor(Y_hat_olstec) - L) / Lfrob\n", 236 | "tElapsed_olstec = np.sum(sub_infos_olstec['times'])\n", 237 | "\n", 238 | "print('OLSTEC Time: {:4f}'.format(tElapsed_olstec))\n", 239 | "print('NRMSE OLSTEC: {:6f}'.format(nrmse_olstec))\n", 240 | "\n", 241 | "plt.imshow(Y_hat_olstec[:,35,:],cmap='gray')\n", 242 | "name = 'merl_olstec-' + str(int(rho*100)) + '%'+'.eps'\n", 243 | "plt.savefig(name)\n", 244 | "plt.show()\n", 245 | "\n", 246 | "(score, diff) = compare_ssim(L.array()[:,35,:], Y_hat_olstec[:,35,:], full=True)\n", 247 | "print('OLSTEC SSIM: {:.5f}'.format(score))" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "### TeCPSGD" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "## TeCPSGD\n", 264 | "from tecpsgd import *\n", 265 | "Tensor_Y_Noiseless = np.transpose(L.array(),[0,2,1])\n", 266 | "rank = 100\n", 267 | "OmegaTensor = np.transpose(mask,[0,2,1])\n", 268 | "tensor_dims = [n1,n3,n2]\n", 269 | "maxepochs = 2\n", 270 | "tolcost = 1e-14\n", 271 | "permute_on = False\n", 272 | "\n", 273 | "options = {\n", 274 | " 'maxepochs': maxepochs,\n", 275 | " 'tolcost': tolcost,\n", 276 | " 'lam': 0.05,\n", 277 | " 'stepsize': 10,\n", 278 | "# 'mu': 0.1,\n", 279 | " 'permute_on': permute_on,\n", 280 | " 'store_subinfo': True,\n", 281 | " 'store_matrix': False,\n", 282 | " 'verbose': False\n", 283 | "}\n", 284 | "\n", 285 | "Xinit = {\n", 286 | " 'A': np.random.randn(tensor_dims[0], rank),\n", 287 | " 'B': np.random.randn(tensor_dims[1], rank),\n", 288 | " 'C': np.random.randn(tensor_dims[2], rank)\n", 289 | "}\n", 290 | "\n", 291 | "Xsol_TeCPSGD, Y_hat_tecpsgd, info_TeCPSGD, sub_infos_TeCPSGD = TeCPSGD(Tensor_Y_Noiseless, OmegaTensor, None, tensor_dims, rank,\n", 292 | " Xinit, options)\n", 293 | "\n", 294 | "A_t0 = Xsol_TeCPSGD['A']\n", 295 | "B_t0 = Xsol_TeCPSGD['B']\n", 296 | "C_t0 = Xsol_TeCPSGD['C']\n", 297 | "\n", 298 | "# Y_hat_tecpsgd = np.zeros(Tensor_Y_Noiseless.shape)\n", 299 | "# for f in range(0,n2):\n", 300 | "# gamma = C_t0[f,:].T\n", 301 | "# Y_hat_tecpsgd[:,:, f] = A_t0 @ np.diag(gamma) @ B_t0.T\n", 302 | " \n", 303 | "Y_hat_tecpsgd = np.transpose(Y_hat_tecpsgd,[0,2,1])\n", 304 | "\n", 305 | "nrmse_tecpsgd = 
tfrobnorm(Tensor(Y_hat_tecpsgd) - L) / Lfrob\n", 306 | "tElapsed_tecpsgd = np.sum(sub_infos_TeCPSGD['times'])\n", 307 | "\n", 308 | "print('TeCPSGD Time: {:4f}'.format(tElapsed_tecpsgd))\n", 309 | "print('NRMSE TeCPSGD: {:6f}'.format(nrmse_tecpsgd))\n", 310 | "plt.imshow(Y_hat_tecpsgd[:,35,:],cmap='gray')\n", 311 | "name = 'merl_tecpsgd-' + str(int(rho*100)) + '%'+'.eps'\n", 312 | "plt.savefig(name)\n", 313 | "plt.show()\n", 314 | "\n", 315 | "(score, diff) = compare_ssim(L.array()[:,35,:], Y_hat_tecpsgd[:,35,:], full=True)\n", 316 | "print('TeCPSGD SSIM: {:.5f}'.format(score))" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "### TNN-ADMM" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "# ## TNN-ADMM\n", 333 | "fun = lambda X: [0,tfrobnorm(X - L) / Lfrob]\n", 334 | "# fun = lambda X: [0,0]\n", 335 | "Y_hat_tnn,stats_tnn,tElapsed_tnn = lrtc(Y,mask,niter = 75,fun=fun,verbose=False)\n", 336 | "# Y_hat_tnn,stats_tnn,tElapsed_tnn = lrtc(Tensor(np.transpose(Y.array(),[0,2,1])),np.transpose(mask,[0,2,1]),niter = 75,fun=fun,verbose=False)\n", 337 | "\n", 338 | "tnn_nrmse = tfrobnorm(Y_hat_tnn - L) / Lfrob\n", 339 | "print('TNN Time: {:4f}'.format(tElapsed_tnn))\n", 340 | "print('TNN NRMSE: {:4f}'.format(tnn_nrmse))\n", 341 | "\n", 342 | "plt.imshow(Y_hat_tnn.array()[:,35,:],cmap='gray')\n", 343 | "name = 'merl_tnn-' + str(int(rho*100)) + '%'+'.eps'\n", 344 | "plt.savefig(name)\n", 345 | "plt.show()\n", 346 | "\n", 347 | "(score, diff) = compare_ssim(L.array()[:,35,:], Y_hat_tnn.array()[:,35,:], full=True)\n", 348 | "print('TNN-ADMM SSIM: {:.5f}'.format(score))" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": {}, 354 | "source": [ 355 | "### TCTF" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": null, 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "## TCTF\n", 365 | "rank = 2\n", 366 | "fun = lambda X,Z: [0, tfrobnorm(X*Z - L) / Lfrob]\n", 367 | "Xtctf,Ztctf, stats_tctf, tElapsed_tctf = tctf(Y,mask,rank,niter = 75,fun=fun,verbose=False)\n", 368 | "Y_hat_tctf = Xtctf * Ztctf\n", 369 | "\n", 370 | "tctf_nrmse = tfrobnorm(Y_hat_tctf - L) / Lfrob\n", 371 | "print('TCTF Time: {:4f}'.format(tElapsed_tctf))\n", 372 | "print('TCTF NRMSE: {:4f}'.format(tctf_nrmse))\n", 373 | "\n", 374 | "plt.imshow(Y_hat_tctf.array()[:,35,:],cmap='gray')\n", 375 | "name = 'merl_tctf-' + str(int(rho*100)) + '%'+'.eps'\n", 376 | "plt.savefig(name)\n", 377 | "plt.show()\n", 378 | "\n", 379 | "(score, diff) = compare_ssim(L.array()[:,35,:], Y_hat_tctf.array()[:,35,:], full=True)\n", 380 | "print('TCTF SSIM: {:.5f}'.format(score))" 381 | ] 382 | }, 383 | { 384 | "cell_type": "markdown", 385 | "metadata": {}, 386 | "source": [ 387 | "### TOUCAN" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [ 396 | "## TOUCAN\n", 397 | "rank = 1\n", 398 | "# fun = lambda X: [0, tfrobnorm(X - L) / Lfrob]\n", 399 | "fun = lambda X,k: [0, tfrobnorm_array(X - L.array()[:,k,:]) / tfrobnorm_array(L.array()[:,k,:])]\n", 400 | "Y_hat_toucan, U, w, stats_toucan, tElapsed_toucan = toucan(Y,mask,rank,tube=False,outer=2,mode='online',fun=fun,cgtol=1e-6,\n", 401 | " randomOrder=False,verbose=False)\n", 402 | "\n", 403 | "toucan_nrmse = tfrobnorm(Y_hat_toucan - L) / Lfrob\n", 404 | "print('TOUCAN Time: {:4f}'.format(tElapsed_toucan))\n", 
405 | "print('TOUCAN NRMSE: {:4f}'.format(toucan_nrmse))\n", 406 | "\n", 407 | "plt.imshow(Y_hat_toucan.array()[:,35,:],cmap='gray')\n", 408 | "name = 'merl_toucan-' + str(int(rho*100)) + '%'+'.eps'\n", 409 | "plt.savefig(name)\n", 410 | "plt.show()\n", 411 | "\n", 412 | "(score, diff) = compare_ssim(L.array()[:,35,:], Y_hat_toucan.array()[:,35,:], full=True)\n", 413 | "print('TOUCAN SSIM: {:.5f}'.format(score))" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "metadata": {}, 420 | "outputs": [], 421 | "source": [ 422 | "plt.imshow(U.array().squeeze())\n", 423 | "plt.show()\n", 424 | "\n", 425 | "U_row = U.array()[100,0,:]\n", 426 | "plt.plot(U_row)\n", 427 | "plt.show()\n", 428 | "\n", 429 | "plt.plot(w.array()[0,0,:])\n", 430 | "plt.show()\n", 431 | "\n", 432 | "rec = (U * w).array().squeeze()\n", 433 | "\n", 434 | "plt.imshow(rec)\n", 435 | "plt.show()\n", 436 | "\n", 437 | "w_fft = np.fft.fft(w.array()[0,0,:])\n", 438 | "plt.plot(np.abs(w_fft))\n", 439 | "plt.show()\n", 440 | "\n", 441 | "U_row_fft = np.fft.fft(U.array()[100,0,:])\n", 442 | "plt.plot(np.abs(U_row_fft))\n", 443 | "plt.show()\n", 444 | "\n", 445 | "mult = U_row_fft * w_fft\n", 446 | "plt.plot(np.abs(mult))\n", 447 | "plt.show()\n", 448 | "\n", 449 | "im_row_reconst = np.real(np.fft.ifft(mult))\n", 450 | "plt.plot(im_row_reconst)\n", 451 | "plt.show()\n", 452 | "\n" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": null, 458 | "metadata": {}, 459 | "outputs": [], 460 | "source": [ 461 | "## Matrix Completion Algorithms\n", 462 | "\n", 463 | "Lm = lr_flatten(L.array())\n", 464 | "Ym = lr_flatten(Y.array())\n", 465 | "mask_m = lr_flatten(mask)" 466 | ] 467 | }, 468 | { 469 | "cell_type": "markdown", 470 | "metadata": {}, 471 | "source": [ 472 | "### MatComp" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": null, 478 | "metadata": {}, 479 | "outputs": [], 480 | "source": [ 481 | "## MatComp\n", 482 | "\n", 483 | "fun = lambda X: [0, np.linalg.norm(X - Lm,'fro') / np.linalg.norm(Lm,'fro')]\n", 484 | "Ym_hat, stats_matcomp, tElapsed_matcomp = lrmc(Ym,mask_m,niter=100,fun=fun,verbose=False)\n", 485 | "\n", 486 | "matcomp_nrmse = np.linalg.norm(Ym_hat - Lm,'fro') / np.linalg.norm(Lm,'fro')\n", 487 | "print('TOUCAN Time: {:4f}'.format(tElapsed_matcomp))\n", 488 | "print('TOUCAN NRMSE: {:4f}'.format(matcomp_nrmse))\n", 489 | "\n", 490 | "plt.imshow(np.reshape(Ym_hat[:,35],(int(ny),int(nx))).T,cmap='gray')\n", 491 | "name = 'merl_matcomp-' + str(int(rho*100)) + '%'+'.eps'\n", 492 | "plt.savefig(name)\n", 493 | "plt.show()\n", 494 | "\n", 495 | "(score, diff) = compare_ssim(Lm[:,35], Ym_hat[:,35], full=True)\n", 496 | "print('MatComp SSIM: {:.5f}'.format(score))" 497 | ] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "metadata": {}, 502 | "source": [ 503 | "### GROUSE" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": null, 509 | "metadata": {}, 510 | "outputs": [], 511 | "source": [ 512 | "## GROUSE\n", 513 | "rank = 1\n", 514 | "# fun = lambda X: [0, np.linalg.norm(X - Lm, 'fro') / np.linalg.norm(Lm, 'fro')]\n", 515 | "fun = lambda X,idx: [0, np.linalg.norm(X - Lm[:,idx]) / np.linalg.norm(Lm[:,idx])]\n", 516 | "Ym_hat_grouse, stats_grouse, tElapsed_grouse = grouse(Ym, mask_m, rank,outer=3,mode=\"online\",fun=fun,randomOrder=False,\n", 517 | " verbose=False)\n", 518 | "grouse_nrmse = np.linalg.norm(Ym_hat_grouse - Lm,'fro') / np.linalg.norm(Lm,'fro')\n", 519 | "print('GROUSE Time: 
{:4f}'.format(tElapsed_grouse))\n", 520 | "print('GROUSE NRMSE: {:4f}'.format(grouse_nrmse))\n", 521 | "\n", 522 | "plt.imshow(np.reshape(Ym_hat_grouse[:,35],(int(ny),int(nx))).T,cmap='gray')\n", 523 | "name = 'merl_grouse-' + str(int(rho*100)) + '%'+'.eps'\n", 524 | "plt.savefig(name)\n", 525 | "plt.show()\n", 526 | "\n", 527 | "(score, diff) = compare_ssim(Lm[:,35], Ym_hat_grouse[:,35], full=True)\n", 528 | "print('GROUSE SSIM: {:.5f}'.format(score))" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": null, 534 | "metadata": {}, 535 | "outputs": [], 536 | "source": [ 537 | "# ## Write out results\n", 538 | "\n", 539 | "# for i in range(0,num_frames):\n", 540 | "# name = path + 'observed/frame_' + '%02d' % i + '.png'\n", 541 | "# scipy.misc.imsave(name, np.reshape(Ym[:, i], (int(ny), int(nx))).T)\n", 542 | "\n", 543 | "# # name = path + '/tnn_results/frame_' + '%02d' % i + '.png'\n", 544 | "# # scipy.misc.imsave(name, Y_hat_tnn.array()[:, i, :])\n", 545 | "\n", 546 | "# # name = path + '/tctf_results/frame_' + '%02d' % i + '.png'\n", 547 | "# # scipy.misc.imsave(name, Y_hat_tctf.array()[:, i, :])\n", 548 | "\n", 549 | "# name = path + '/toucan_results/frame_' + '%02d' % i + '.png'\n", 550 | "# scipy.misc.imsave(name, Y_hat_toucan.array()[:, i, :])\n", 551 | "\n", 552 | "# # name = path + 'matcomp_results/frame_' + '%02d' % i + '.png'\n", 553 | "# # scipy.misc.imsave(name, np.reshape(Ym_hat[:, i],(int(nx),int(ny))))\n", 554 | "\n", 555 | "# name = path + 'grouse_results/frame_' + '%02d' % i + '.png'\n", 556 | "# scipy.misc.imsave(name, np.reshape(Ym_hat_grouse[:, i], (int(ny), int(nx))).T)\n", 557 | " \n", 558 | "# name = path + 'tecpsgd_results/frame_' + '%02d' % i + '.png'\n", 559 | "# scipy.misc.imsave(name, np.reshape(Y_hat_tecpsgd[:, i], (int(nx), int(ny))))\n", 560 | " \n", 561 | "# name = path + 'olstec_results/frame_' + '%02d' % i + '.png'\n", 562 | "# scipy.misc.imsave(name, np.reshape(Y_hat_olstec[:, i], (int(nx), int(ny))))\n", 563 | " " 564 | ] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": null, 569 | "metadata": {}, 570 | "outputs": [], 571 | "source": [ 572 | "# plt.figure(figsize=(10,5), tight_layout=True)\n", 573 | "# plt.semilogy(np.cumsum(stats_toucan[:,-1]),stats_toucan[:,1], '#ff7f0e',\n", 574 | "# np.cumsum(stats_grouse[:, -1]), stats_grouse[:, 1], 'b',\n", 575 | "# np.cumsum(stats_tnn[:,-1]),stats_tnn[:,1], 'r',\n", 576 | "# np.cumsum(stats_tctf[:,-1]),stats_tctf[:,1], 'k',\n", 577 | "# np.cumsum(stats_matcomp[:,-1]),stats_matcomp[:,1],'g',\n", 578 | "# )\n", 579 | "\n", 580 | "# plt.legend(('TOUCAN', 'GROUSE', 'TNN-ADMM', 'TCTF', 'MatComp'),bbox_to_anchor=(1.1, 1 ))\n", 581 | "# plt.xlabel('Time (s)')\n", 582 | "# plt.ylabel('NRMSE')\n", 583 | "# plt.grid()\n", 584 | "# plt.rcParams.update({'font.size': 22})\n", 585 | "# plt.savefig('merl_video_completion_nrmse_' + str(rho*100) + 'p.eps')\n", 586 | "# plt.show()" 587 | ] 588 | }, 589 | { 590 | "cell_type": "code", 591 | "execution_count": null, 592 | "metadata": {}, 593 | "outputs": [], 594 | "source": [ 595 | "plt.figure(figsize=(10,5), tight_layout=True)\n", 596 | "plt.semilogy(np.cumsum(stats_toucan[:,-1]),stats_toucan[:,1],'#ff7f0e',\n", 597 | " np.cumsum(stats_grouse[:, -1]), stats_grouse[:, 1],'b',\n", 598 | " np.cumsum(sub_infos_TeCPSGD['times'][1:]),sub_infos_TeCPSGD['err_residual'][1:],'k',\n", 599 | " np.cumsum(sub_infos_olstec['times'][1:]),sub_infos_olstec['err_residual'][1:],'r'\n", 600 | ")\n", 601 | "\n", 602 | "plt.legend(('TOUCAN', 'GROUSE', 'TeCPSGD', 
'OLSTEC'),bbox_to_anchor=(1.1, 1 ))\n", 603 | "plt.xlabel('Time (s)')\n", 604 | "plt.ylabel('NRMSE')\n", 605 | "plt.grid()\n", 606 | "plt.rcParams.update({'font.size': 22})\n", 607 | "# plt.savefig('merl_video_completion_nrmse_' + str(rho*100) + 'p.eps')\n", 608 | "plt.show()" 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": null, 614 | "metadata": {}, 615 | "outputs": [], 616 | "source": [ 617 | "plt.figure(figsize=(12,5), tight_layout=True)\n", 618 | "plt.semilogy(np.arange(0,len(stats_toucan)),stats_toucan[:,1],'#ff7f0e',\n", 619 | " np.arange(0,len(stats_grouse)), stats_grouse[:, 1],'b',\n", 620 | " np.arange(0,len(sub_infos_TeCPSGD['times'][1:])),sub_infos_TeCPSGD['err_residual'][1:],'k',\n", 621 | " np.arange(0,len(sub_infos_olstec['times'][1:])),sub_infos_olstec['err_residual'][1:],'r'\n", 622 | ")\n", 623 | "\n", 624 | "plt.legend(('TOUCAN', 'GROUSE', 'TeCPSGD', 'OLSTEC'),bbox_to_anchor=(1.1, 1 ))\n", 625 | "plt.xlabel('Slice index')\n", 626 | "plt.ylabel('NRMSE')\n", 627 | "plt.grid()\n", 628 | "plt.rcParams.update({'font.size': 22})\n", 629 | "# plt.savefig('merl_video_completion_nrmse_' + str(rho*100) + 'p.eps')\n", 630 | "plt.show()" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": null, 636 | "metadata": {}, 637 | "outputs": [], 638 | "source": [ 639 | "olstec_frame_err = []\n", 640 | "toucan_frame_err = []\n", 641 | "\n", 642 | "for i in range(0,n2):\n", 643 | " olstec_frame_err.append(np.linalg.norm(Y_hat_olstec[:,i,:].squeeze() - L.array()[:,i,:]) / np.linalg.norm(L.array()[:,i,:]))\n", 644 | " toucan_frame_err.append(np.linalg.norm(Y_hat_toucan.array()[:,i,:].squeeze() - L.array()[:,i,:]) / np.linalg.norm(L.array()[:,i,:]))\n", 645 | "\n", 646 | "plt.semilogy(np.arange(0,len(olstec_frame_err)),olstec_frame_err,np.arange(0,len(toucan_frame_err)),toucan_frame_err)\n", 647 | "plt.legend(['OLSTEC','TOUCAN'])\n", 648 | "plt.show()" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": null, 654 | "metadata": {}, 655 | "outputs": [], 656 | "source": [] 657 | } 658 | ], 659 | "metadata": { 660 | "@webio": { 661 | "lastCommId": null, 662 | "lastKernelId": null 663 | }, 664 | "kernelspec": { 665 | "display_name": "Python 3", 666 | "language": "python", 667 | "name": "python3" 668 | }, 669 | "language_info": { 670 | "codemirror_mode": { 671 | "name": "ipython", 672 | "version": 3 673 | }, 674 | "file_extension": ".py", 675 | "mimetype": "text/x-python", 676 | "name": "python", 677 | "nbconvert_exporter": "python", 678 | "pygments_lexer": "ipython3", 679 | "version": "3.6.6" 680 | } 681 | }, 682 | "nbformat": 4, 683 | "nbformat_minor": 2 684 | } 685 | -------------------------------------------------------------------------------- /jupyter_notebooks/tsvd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from numpy import ndarray 4 | 5 | class myarray(ndarray): 6 | @property 7 | def H(self): 8 | return self.conj().T 9 | 10 | class Tensor: 11 | def __init__(self, array): 12 | ## input is n1 x n2 x n3 multidimensional array ### 13 | self.__n1 = array.shape[0] 14 | self.__n2 = array.shape[1] 15 | self.__n3 = array.shape[2] 16 | self.__array = array 17 | 18 | def T(self): # transpose method 19 | At = np.zeros((self.__n2, self.__n1, self.__n3)) 20 | At[:, :, 0] = self.__array[:, :, 0].T 21 | for i in range(1, self.__n3): At[:, :, self.__n3 - i] = self.__array[:, :, i].T 22 | 23 | return Tensor(At) 24 | def shape(self): #return 
dimensions of tensors 25 | return self.__n1, self.__n2, self.__n3 26 | 27 | def array(self): #return 4d numpy array of tensor 28 | return self.__array 29 | 30 | def __add__(self, B): 31 | return Tensor(self.__array + B.__array) 32 | 33 | def __sub__(self, B): 34 | return Tensor(self.__array - B.__array) 35 | 36 | def __mul__(self, B): # tensor-tensor product method 37 | 38 | assert self.__n2 == B.__n1, "Dimensions of tensors must match" 39 | 40 | n3 = self.__n3 41 | 42 | Abar = np.fft.fft(self.__array, axis=2) 43 | Bbar = np.fft.fft(B.__array, axis=2) 44 | Cbar = np.zeros((self.__n1, B.__n2, n3), dtype=complex) 45 | 46 | for i in range(0, int(np.ceil((n3 + 1) / 2))): Cbar[:, :, i] = Abar[:, :, i] @ Bbar[:, :, i] 47 | 48 | for i in range(int(np.ceil((n3 + 1) / 2)), n3): Cbar[:, :, i] = np.conj(Cbar[:, :, n3 - i]) 49 | 50 | C = np.real(np.fft.ifft(Cbar, axis=2)) 51 | 52 | return Tensor(C) 53 | 54 | 55 | def tfrobnorm(A): 56 | # 57 | # Computes \sum_{i,j,k} A[i,j,k]^2 58 | # 59 | # sum = 0 60 | # for i in range(0, A._Tensor__n3): sum += np.linalg.norm(A._Tensor__array[:, :, i], 'fro') ** 2 61 | # return np.sqrt(sum) 62 | return np.sqrt(np.sum(A.array()**2)) 63 | 64 | 65 | def tfrobnorm_array(A): 66 | # 67 | # Computes \sum_{i,j,k} A[i,j,k]^2 68 | # 69 | return np.sqrt(np.sum(A**2)) 70 | 71 | 72 | def tsvd(A, full=True): 73 | # 74 | # Compute the tensor-SVD from "Tensor Robust Principal Component Analysis with a New Tensor Nuclear Norm," avail: https://ieeexplore.ieee.org/abstract/document/8606166 75 | # 76 | # Input 77 | # Object of class type Tensor, size n1 x n2 x n3 78 | # 79 | # Output 80 | # Orthonormal U tensor of size n1 x n1 x n3 81 | # F-diagonal S tensor of tubular singular values n1 x n2 x n3 82 | # Orthonoraml V.T tensor of size n2 x n2 x n3 83 | 84 | n1, n2, n3 = A.array().shape 85 | Abar = np.fft.fft(A.array(), axis=2) 86 | 87 | K = min(n1, n2) 88 | if (full): 89 | Ubar = np.zeros((n1, n1, n3), dtype=complex) 90 | Sbar = np.zeros((min(n1, n2), n3), dtype=complex) 91 | Vbar = np.zeros((n2, n2, n3), dtype=complex) 92 | else: 93 | Ubar = np.zeros((n1, K, n3), dtype=complex) 94 | Sbar = np.zeros((K, n3), dtype=complex) 95 | Vbar = np.zeros((K, n2, n3), dtype=complex) 96 | 97 | for i in range(0, int(np.ceil((n3 + 1) / 2))): 98 | U, S, V = np.linalg.svd(Abar[:, :, i], full_matrices=full) 99 | Ubar[:, :, i] = U 100 | Sbar[:, i] = S 101 | Vbar[:, :, i] = V 102 | 103 | for i in range(int(np.ceil((n3 + 1) / 2)), n3): 104 | Ubar[:, :, i] = np.conj(Ubar[:, :, n3 - i]) 105 | Sbar[:, i] = np.conj(Sbar[:, n3 - i]) 106 | Vbar[:, :, i] = np.conj(Vbar[:, :, n3 - i]) 107 | 108 | tU = Tensor(np.real(np.fft.ifft(Ubar, axis=2))) 109 | tV = Tensor(np.real(np.fft.ifft(Vbar, axis=2))).T() 110 | 111 | S = np.real(np.fft.ifft(Sbar, axis=1)) 112 | if (full): 113 | tS = np.zeros((n1, n2, n3)) 114 | for i in range(0, K): tS[i, i, :] = S[i, :] 115 | tS = Tensor(tS) 116 | else: 117 | tS = np.zeros((K, K, n3)) 118 | for i in range(0, K): tS[i, i, :] = S[i, :] 119 | tS = Tensor(tS) 120 | 121 | return tU, tS, tV 122 | 123 | 124 | def teye(n, n3): 125 | # 126 | # Function that returns Identity Tensor of size n x n x n3 127 | # 128 | I = np.expand_dims(np.eye(n), axis=2) 129 | I2 = np.zeros((n, n, n3 - 1)) 130 | return Tensor(np.concatenate((I, I2), axis=2)) 131 | 132 | 133 | def tpinv(A): 134 | n1, n2, n3 = A.shape() 135 | Abar = np.fft.fft(A.array(),axis=2) 136 | Apinv_bar = np.zeros((n2,n1,n3),dtype=complex) 137 | 138 | for i in range(0, int(np.ceil((n3 + 1) / 2))): 139 | Apinv_bar[:,:,i] = 
np.linalg.pinv(Abar[:,:,i]) 140 | for i in range(int(np.ceil((n3 + 1) / 2)), n3): 141 | Apinv_bar[:, :, i] = np.conj(Apinv_bar[:, :, n3 - i]) 142 | 143 | 144 | return Tensor(np.fft.ifft(Apinv_bar,axis=2)) 145 | 146 | def normalizeTensorVec(v): 147 | v_F = np.fft.fft(v._Tensor__array, axis=2) 148 | vnorms_F = np.expand_dims(np.linalg.norm(v_F, axis=0), axis=0) 149 | vnormal_F = v_F / vnorms_F 150 | return vnormal_F, vnorms_F 151 | 152 | 153 | def orthoTest(U): 154 | test = U.T() * U 155 | K = test._Tensor__n1 156 | n3 = test._Tensor__n3 157 | assert tfrobnorm(test - teye(K, n3)) < 1e-10 158 | 159 | def tnn(A): 160 | # 161 | # Compute the tensor nuclear norm described in "Tensor Robust Principal Component Analysis with a New Tensor Nuclear Norm," avail: https://ieeexplore.ieee.org/abstract/document/8606166 162 | # 163 | U, S, V = tsvd(A, False) 164 | S = S._Tensor__array[:, :, 0] 165 | return np.sum(np.diag(S)) 166 | 167 | def tSVST(X, beta, full=False): 168 | # 169 | # Input: multidimensional array of size n1 x n2 x n3 170 | # 171 | # Output: multidimensional array of size n1 x n2 x n3 172 | # 173 | n1, n2, n3 = X.shape 174 | Xbar = np.fft.fft(X, axis=2) 175 | 176 | Wbar = np.zeros((n1, n2, n3), dtype=complex) 177 | 178 | for i in range(0, int(np.ceil((n3 + 1) / 2))): 179 | U, S, V = np.linalg.svd(Xbar[:, :, i], full_matrices=full) 180 | Wbar[:, :, i] = U @ np.diag(np.maximum(S - beta, 0.0)) @ V 181 | 182 | for i in range(int(np.ceil((n3 + 1) / 2)), n3): 183 | Wbar[:, :, i] = np.conj(Wbar[:, :, n3 - i]) 184 | 185 | return np.real(np.fft.ifft(Wbar, axis=2)) 186 | 187 | 188 | def lrtc(Y, Mask, niter=400, min_iter = 10, rho = 1.1, mu = 1e-4, mu_max = 1e10, tol=1e-8, it_tol = 1e-4, fun=lambda X: [0, 0], verbose=False): 189 | # 190 | # Description: Iterative singular value soft thresholding algorithm for tensors 191 | # 192 | # Inputs 193 | # Y: Tensor object of size n1 x n2 x n3 194 | # mask: multidimensional binary array with 1's indicating observed entries 195 | # beta: regularizatio parameter 196 | # niter: number of ISTA iterations to perform 197 | # fun: will evaluate cost function at current iterate X (but default is function -> 0) 198 | # 199 | # Outputs 200 | # X: Tensor object object of size n1 x n2 x n3 with completed entries 201 | # cost_ista: Cost function array of size niter + 1 202 | # 203 | Y = Y._Tensor__array 204 | X = Y.copy() 205 | X0 = X.copy() 206 | LAM = np.zeros(X.shape) 207 | stats = np.zeros((niter + 1,3)) 208 | 209 | cost,nrmse = fun(Tensor(X)) 210 | stats[0,:] = [cost, nrmse, 0] 211 | nrmse_0 = nrmse.copy() 212 | for k in range(0, niter): 213 | start = time.time() 214 | X = tSVST((Mask < 1).astype(float)*X + Mask*(Y - LAM / mu), 1 / mu) 215 | LAM = LAM + mu * (X - Y) 216 | 217 | mu = min(rho * mu, mu_max) 218 | 219 | tElapsed = time.time() - start 220 | 221 | cost, nrmse = fun(Tensor(X)) 222 | stats[k + 1, :] = [cost, nrmse, tElapsed] 223 | it_diff = np.linalg.norm(X - X0) 224 | X0 = X.copy() 225 | 226 | if(nrmse < tol or (abs(nrmse - nrmse_0) < it_tol and k > min_iter)): 227 | stats = stats[:k+1,:] 228 | break 229 | 230 | nrmse_0 = nrmse.copy() 231 | 232 | if (verbose and k%10 == 0): 233 | print('Iter[{:d}]: Cost fxn: {:.3f}, NRMSE: {:.6f} '.format(k, cost, nrmse)) 234 | 235 | tElapsed = np.sum(stats[:,2]) 236 | return Tensor(X), stats, tElapsed 237 | 238 | def tctf(Y,mask,rank,niter=100,min_iter = 10,tol=1e-8,it_tol = 1e-4,fun=lambda U,V: [0, 0],verbose=False): 239 | 240 | n1,n2,n3 = Y.shape() 241 | r = rank 242 | 243 | # X = Tensor(np.random.randn(n1, r, n3)) 244 | 
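    # The random initialization kept above/below for reference is replaced by a warm start:
    # X and Z are taken from a truncated t-SVD of the observed tensor Y a few lines down
    # (X = U[:, :rank, :], Z = (S * V^T)[:rank, :, :]), which should give TCTF a better
    # starting point than an i.i.d. Gaussian draw.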
# Z = Tensor(np.random.randn(r, n2, n3)) 245 | 246 | U,S,V = tsvd(Y,full=False) 247 | X = Tensor(U.array()[:,:rank,:]) 248 | Z = Tensor((S*V.T()).array()[:rank,:,:]) 249 | 250 | stats = np.zeros((niter + 1, 3)) 251 | 252 | # 0th iteration 253 | cost,nrmse = fun(X,Z) 254 | stats[0,:] = [cost,nrmse,0] 255 | C0 = X * Z 256 | nrmse_0 = nrmse.copy() 257 | for iter in range(0, niter): 258 | 259 | tStart = time.time() 260 | # C update 261 | C = X * Z + Tensor(np.multiply(mask, (Y - X * Z).array())) 262 | 263 | # Fourier Transforms 264 | Chat = np.fft.fft(C._Tensor__array, axis=2).view(myarray) 265 | Xhat = np.fft.fft(X._Tensor__array, axis=2).view(myarray) 266 | Zhat = np.fft.fft(Z._Tensor__array, axis=2).view(myarray) 267 | 268 | # X and Z updates in Fourier Domain 269 | # for i in range(0, n3): 270 | for i in range(0, int(np.ceil((n3 + 1) / 2))): 271 | Ci = Chat[:, :, i] 272 | Zi = Zhat[:, :, i] 273 | ZiH = Zi.H 274 | 275 | # X update 276 | Xi = Ci @ ZiH @ np.linalg.pinv(Zi @ ZiH) 277 | Xhat[:, :, i] = Xi 278 | 279 | # Z update 280 | XiH = Xi.H 281 | Zhat[:, :, i] = np.linalg.pinv(XiH @ Xi) @ XiH @ Ci 282 | 283 | for i in range(int(np.ceil((n3 + 1) / 2)), n3): 284 | Xhat[:, :, i] = np.conj(Xhat[:, :, n3 - i]) 285 | Zhat[:, :, i] = np.conj(Zhat[:, :, n3 - i]) 286 | 287 | # Inverse Fourier Transforms 288 | X = Tensor(np.real(np.fft.ifft(Xhat, axis=2))) 289 | Z = Tensor(np.real(np.fft.ifft(Zhat, axis=2))) 290 | 291 | tElapsed = time.time() - tStart 292 | 293 | cost, nrmse = fun(X,Z) 294 | stats[iter + 1,:] = [cost,nrmse,tElapsed] 295 | 296 | if(nrmse < tol or (abs(nrmse - nrmse_0) < it_tol and iter > min_iter)): 297 | stats = stats[:iter+1,:] 298 | break 299 | 300 | nrmse_0 = nrmse.copy() 301 | 302 | if(verbose): 303 | if(iter % 10 == 0): 304 | print('Iter[{:d}]: NRMSE: {:.6f} '.format(iter+1, nrmse)) 305 | 306 | 307 | tElapsed = np.sum(stats[:,-1]) 308 | return X,Z,stats,tElapsed 309 | 310 | 311 | def lr_flatten(X): 312 | n1,n2,n3 = X.shape 313 | B = np.transpose(X, [2, 0, 1]) 314 | X_m = np.reshape(B,(n1*n3,n2)) 315 | 316 | return X_m 317 | 318 | 319 | def lrmc(Y, Mask, niter=400, rho=1.1, mu=1e-4, mu_max=1e10, tol=1e-8, fun=lambda X: [0, 0], verbose=False): 320 | X = Y.copy() 321 | X0 = X.copy() 322 | LAM = np.zeros(X.shape) 323 | 324 | stats = np.zeros((niter + 1, 3)) 325 | stats[0, :] = np.append(fun(X), 0) 326 | 327 | for k in range(0, niter): 328 | start = time.time() 329 | X = SVST((Mask < 1).astype(float) * X + Mask * (Y - LAM / mu), 1 / mu) 330 | LAM = LAM + mu * (X - Y) 331 | 332 | mu = min(rho * mu, mu_max) 333 | 334 | end = time.time() 335 | tElapsed = end - start 336 | 337 | cost, nrmse = fun(X) 338 | stats[k + 1, :] = [cost, nrmse, tElapsed] 339 | 340 | if(nrmse 0)[0] 509 | tStart = time.time() 510 | U, w = grouse_stream(Yvec, idx, U,step) 511 | tEnd = time.time() 512 | 513 | tElapsed = tEnd - tStart 514 | 515 | rec = U @ w 516 | 517 | if(mode == "online"): 518 | Lhat[:,inner] = rec 519 | cost, nrmse = fun(rec,frame_idx) 520 | else: 521 | # Lhat[:,inner] = rec 522 | What[:,frame_idx] = w 523 | Lhat = U @ What 524 | cost, nrmse = fun(Lhat) 525 | if(nrmse < tol): 526 | stats = stats[:iter,:] 527 | break 528 | 529 | stats[iter, :] = [cost, nrmse, tElapsed] 530 | iter += 1 531 | 532 | if(verbose): 533 | print('Outer[{:d}], Inner[{:d}]: NRMSE: {:.3f} '.format(outer, inner, nrmse)) 534 | 535 | tElapsed = np.sum(stats[:,2]) 536 | return Lhat, stats, tElapsed 537 | 538 | def toucan_stream_tube(Yvec,idx,U,step=None): 539 | # 540 | # Description: performs one step of TOUCAN in the regime of 
538 | def toucan_stream_tube(Yvec,idx,U,step=None):
539 |     #
540 |     # Description: performs one step of TOUCAN in the regime of missing TUBES to estimate n1 x K x n3 orthonormal
541 |     # tensor U from input tensor column of size n1 x 1 x n3
542 |     #
543 |     # Inputs:
544 |     # Yvec: n1 x 1 x n3 tensor column (of type Tensor)
545 |     # idx: list of indices where Yvec is observed on the first dimension
546 |     # U: n1 x K x n3 orthonormal tensor (of type Tensor) initial estimate
547 |     # step: step size (real-valued constant)
548 |     #
549 |     # Outputs:
550 |     # U: updated estimate of orthonormal n1 x K x n3 tensor (of type Tensor)
551 |     # w: estimated principal components (weights) of size K x 1 x n3 (of type Tensor)
552 |     #
553 |     n1,K,n3 = U.shape()
554 | 
555 |     v_Omega = Tensor(Yvec.array()[idx, :, :])
556 |     U_Omega = Tensor(U.array()[idx, :, :])
557 | 
558 |     w = tpinv(U_Omega) * v_Omega
559 | 
560 |     p = U * w
561 |     r = v_Omega - U_Omega * w
562 | 
563 |     wnormal_F,wnorms_F = normalizeTensorVec(w)
564 |     pnormal_F,pnorms_F = normalizeTensorVec(p)
565 |     rnormal_F,rnorms_F = normalizeTensorVec(r)
566 | 
567 |     Ustep_F = np.zeros((n1, K, n3), dtype=complex)
568 | 
569 | 
570 | 
571 |     for i in range(0, int(np.ceil((n3 + 1) / 2))):
572 |         if (step is None):
573 |             t = np.arctan(rnorms_F[0, 0, i] / wnorms_F[0, 0, i])
574 |             alpha = (np.cos(t) - 1)
575 |             beta = np.sin(t)
576 |         else:
577 |             sG = rnorms_F[0, 0, i] * wnorms_F[0, 0, i]
578 |             alpha = (np.cos(step * sG) - 1)
579 |             beta = np.sin(step * sG)
580 | 
581 |         gamma = alpha * pnormal_F[:, :, i]
582 |         gamma[idx, :] = gamma[idx, :] + beta * rnormal_F[:, :, i]
583 |         Ustep_F[:, :, i] = gamma @ wnormal_F[:, :, i].conj().T
584 | 
585 |     for i in range(int(np.ceil((n3 + 1) / 2)), n3):
586 |         Ustep_F[:, :, i] = np.conj(Ustep_F[:, :, n3 - i])
587 | 
588 |     Ustep = Tensor(np.real(np.fft.ifft(Ustep_F, axis=2)))
589 | 
590 |     U = U + Ustep
591 | 
592 |     orthoTest(U)
593 | 
594 |     return U,w
595 | 
596 | 
597 | def normalizeTensorVec2(v_F):
598 |     vnorms_F = np.expand_dims(np.linalg.norm(v_F, axis=0), axis=0)
599 |     vnormal_F = v_F / vnorms_F
600 | 
601 |     return vnormal_F, vnorms_F
602 | 
603 | 
604 | def fourier_dot(abar, bbar):
605 |     n1, _, n3 = abar.shape
606 |     _, n2, _ = bbar.shape
607 |     cbar = np.zeros(n3, dtype=complex)
608 | 
609 |     for i in range(0, int(np.ceil((n3 + 1) / 2))): cbar[i] = abar[:, :, i].conj().T @ bbar[:, :, i]
610 | 
611 |     for i in range(int(np.ceil((n3 + 1) / 2)), n3): cbar[i] = np.conj(cbar[n3 - i])
612 | 
613 |     return np.sum(cbar)
614 | 
615 | 
616 | def fourier_mult(Abar, Bbar):
617 |     n1, _, n3 = Abar.shape
618 |     _, n2, _ = Bbar.shape
619 |     Cbar = np.zeros((n1, n2, n3), dtype=complex)
620 | 
621 |     for i in range(0, int(np.ceil((n3 + 1) / 2))): Cbar[:, :, i] = Abar[:, :, i] @ Bbar[:, :, i]
622 | 
623 |     for i in range(int(np.ceil((n3 + 1) / 2)), n3): Cbar[:, :, i] = np.conj(Cbar[:, :, n3 - i])
624 | 
625 |     return Cbar
626 | 
627 | 
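# cg() below solves, coupled across slices through the observation mask, the
# least-squares problem min_w || mask * (U * w) - v ||_F^2 by conjugate gradient on the
# normal equations in the Fourier domain, using fourier_mult/fourier_dot above.  The
# helper below is an illustrative sanity check only (not used elsewhere in this
# repository): with an orthonormal U and an all-ones mask, the recovered weights should
# match U.T() * v.  Sizes and rank are arbitrary choices.
def _cg_fully_observed_check(n1=12, K=3, n3=5):
    U = tsvd(Tensor(np.random.randn(n1, K, n3)), full=False)[0]
    U = Tensor(U.array()[:, :K, :])
    v = Tensor(np.random.randn(n1, 1, n3))
    Ubar = np.fft.fft(U.array(), axis=2)
    vbar = np.fft.fft(v.array(), axis=2)
    wbar, iters = cg(Ubar, vbar, np.ones((n1, 1, n3)))    # fully observed mask
    w = np.real(np.fft.ifft(wbar, axis=2))
    assert np.allclose(w, (U.T() * v).array(), atol=1e-6)
    return iters
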
628 | def cg(Ubar, vbar, mask, cg_iter=None, tol=1e-12):
629 | 
630 |     n1,K,n3 = Ubar.shape
631 | 
632 |     if (cg_iter is None):
633 |         cg_iter = np.prod(vbar.shape)
634 | 
635 |     UbarT = np.transpose(Ubar, [1, 0, 2]).conj()
636 |     wbar = np.zeros((K, 1, n3), dtype=complex)
637 |     bbar = fourier_mult(UbarT, vbar)
638 | 
639 |     F_inv = lambda qbar: np.real(np.fft.ifft(qbar, axis=2))
640 |     F = lambda q: np.fft.fft(q, axis=2)
641 |     A = lambda x: fourier_mult(UbarT, F(np.multiply(mask, F_inv(fourier_mult(Ubar, x)))))
642 | 
643 |     Aw = A(wbar)
644 |     r = bbar - Aw
645 |     p = r.copy()
646 |     rsold = fourier_dot(r, r)
647 | 
648 |     num_iters = 0
649 |     for i in range(0, cg_iter):
650 | 
651 |         Ap = A(p)
652 | 
653 |         # test = fourier_dot(p, Ap)
654 |         alpha = rsold / fourier_dot(p, Ap)
655 |         wbar = wbar + alpha * p
656 |         r = r - alpha * Ap
657 |         rsnew = fourier_dot(r, r)
658 | 
659 |         p = r + (rsnew / rsold) * p
660 |         rsold = rsnew
661 | 
662 |         num_iters += 1
663 | 
664 |         if (np.sqrt(np.real(rsnew)) < tol):
665 |             break
666 | 
667 | 
668 |     return wbar, num_iters
669 | 
670 | 
671 | def toucan_stream(Yvec, M, U, step=None, cgiter = None, cgtol = None):
672 |     #
673 |     # Description: performs one step of TOUCAN in the regime of missing random ENTRIES to estimate the n1 x K x n3
674 |     # orthonormal tensor U from an input tensor column of size n1 x 1 x n3
675 |     #
676 |     # Inputs:
677 |     # Yvec: n1 x 1 x n3 tensor column (of type Tensor)
678 |     # M: n1 x n3 binary array with 1 marking the entries of Yvec that are observed (expanded internally to n1 x 1 x n3)
679 |     # U: n1 x K x n3 orthonormal tensor (of type Tensor) initial estimate
680 |     #
681 |     # Outputs:
682 |     # U: updated estimate of orthonormal n1 x K x n3 tensor (of type Tensor)
683 |     # w: estimated weights of size K x 1 x n3 (of type Tensor)
684 |     #
685 | 
686 |     M = np.expand_dims(M.astype(float), axis=1)
687 | 
688 |     n1,K,n3 = U.shape()
689 | 
690 |     v = Tensor(np.multiply(M, Yvec.array()))
691 | 
692 |     Ubar = np.fft.fft(U.array(), axis=2)
693 |     UbarT = np.transpose(Ubar, [1, 0, 2]).conj()
694 |     vbar = np.fft.fft(v.array(), axis=2)
695 | 
696 |     ## Update w
697 |     if(cgtol is None):
698 |         cgtol = 1e-14
699 | 
700 |     wbar,cg_iters = cg(Ubar, vbar, M, cg_iter = cgiter, tol=cgtol)
701 | 
702 |     ## Update U
703 |     pbar = fourier_mult(Ubar, wbar)
704 | 
705 |     p = np.real(np.fft.ifft(pbar, axis=2))
706 |     r = np.multiply(M, v.array() - p)
707 |     rbar = np.fft.fft(r, axis=2)
708 | 
709 |     rbar2 = np.zeros(rbar.shape, dtype=complex)
710 |     for i in range(0, int(np.ceil((n3 + 1) / 2))):
711 |         rbar2[:, :, i] = Ubar[:, :, i] @ (UbarT[:, :, i] @ rbar[:, :, i])
712 |     for i in range(int(np.ceil((n3 + 1) / 2)), n3):
713 |         rbar2[:, :, i] = np.conj(rbar2[:, :, n3 - i])
714 | 
715 |     rbar -= rbar2
716 | 
717 |     wnormal_F, wnorms_F = normalizeTensorVec2(wbar)
718 |     pnormal_F, pnorms_F = normalizeTensorVec2(pbar)
719 |     rnormal_F, rnorms_F = normalizeTensorVec2(rbar)
720 | 
721 |     Ustep_F = np.zeros((n1, K, n3), dtype=complex)
722 | 
723 |     for i in range(0, int(np.ceil((n3 + 1) / 2))):
724 |         if (step is None):
725 |             t = np.arctan(rnorms_F[0, 0, i] / wnorms_F[0, 0, i])
726 |             alpha = (np.cos(t) - 1)
727 |             beta = np.sin(t)
728 |         else:
729 |             sG = rnorms_F[0, 0, i] * wnorms_F[0, 0, i]
730 |             alpha = (np.cos(step * sG) - 1)
731 |             beta = np.sin(step * sG)
732 | 
733 |         gamma = alpha * pnormal_F[:, :, i]
734 |         gamma = gamma + beta * rnormal_F[:, :, i]
735 |         Ustep_F[:, :, i] = gamma @ wnormal_F[:, :, i].conj().T
736 | 
737 |     for i in range(int(np.ceil((n3 + 1) / 2)), n3):
738 |         Ustep_F[:, :, i] = np.conj(Ustep_F[:, :, n3 - i])
739 | 
740 |     Ubar = Ubar + Ustep_F
741 | 
742 |     U = Tensor(np.real(np.fft.ifft(Ubar, axis=2)))
743 |     w = Tensor(np.real(np.fft.ifft(wbar, axis=2)))
744 |     # orthoTest(U)
745 | 
746 |     return U, w, cg_iters
747 | 
748 | 
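# An illustrative single streaming step with toucan_stream() above (this helper is not
# used elsewhere in the repository): draw a random orthonormal U, observe roughly half
# the entries of one tensor column, and take one update.  The sizes, rank K, and
# sampling rate are arbitrary choices.
def _toucan_stream_single_step_sketch(n1=20, K=3, n3=6, p=0.5):
    U = tsvd(Tensor(np.random.randn(n1, K, n3)), full=False)[0]
    U = Tensor(U.array()[:, :K, :])
    Yvec = Tensor(np.random.randn(n1, 1, n3))
    M = np.random.rand(n1, n3) < p        # n1 x n3 entry mask, as sliced by toucan()
    U, w, cg_iters = toucan_stream(Yvec, M, U, step=None)
    return U, w, cg_iters
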
749 | def toucan(Y,mask,rank,tube,outer,mode,tol=1e-9,step=None,cgiter=None,cgtol=1e-9,fun=lambda Lhat, idx: [0,0],
750 |            randomOrder=False,verbose=False, U0=None):
751 | 
752 |     n1,n2,n3 = Y.shape()
753 | 
754 |     ### Initialize U
755 |     K = rank
756 | 
757 |     if(U0 is None):
758 |         U = tsvd(Tensor(np.random.randn(n1, K, n3)),full=False)[0]
759 |         U = Tensor(U.array()[:, :K, :])
760 |     else:
761 |         U = U0
762 | 
763 |     Lhat = np.zeros((n1, n2, n3))
764 |     What = np.zeros((K,n2,n3))
765 | 
766 |     stats_toucan = np.zeros((outer * n2 + 1,4))
767 |     # cost,nrmse = fun(Tensor(Lhat))
768 |     cost = 0
769 |     nrmse = 1
770 |     stats_toucan[0,:] = [cost,nrmse, 0, 0]
771 | 
772 |     iter = 1
773 |     for outer in range(0, outer):   # note: the loop variable reuses the name of the 'outer' argument
774 |         if(nrmse < tol):
775 |             stats_toucan = stats_toucan[:iter,:]
776 |             break
777 | 
778 |         if (randomOrder):
779 |             frame_order = np.random.permutation(n2)
780 |         else:
781 |             frame_order = np.arange(0,n2)
782 | 
783 |         for inner in range(0, n2):
784 |             frame_idx = frame_order[inner]
785 |             Yvec = Tensor(np.expand_dims(Y.array()[:, frame_idx, :], axis=1))
786 | 
787 |             if (tube is True):
788 |                 idx = np.where(mask[:, frame_idx, 0] > 0)[0]   # index by frame_idx so the mask matches Yvec under randomOrder
789 |                 tStart = time.time()
790 |                 U, w = toucan_stream_tube(Yvec, idx, U, step)
791 |                 tEnd = time.time()
792 |                 cg_iters = 0
793 |             else:
794 |                 if(cgiter is not None):
795 |                     num_cgiter = cgiter**(outer + 1)
796 |                 else:
797 |                     num_cgiter = None
798 | 
799 |                 tStart = time.time()
800 |                 U, w, cg_iters = toucan_stream(Yvec, mask[:, frame_idx, :], U, step,cgiter = num_cgiter, cgtol=cgtol)
801 |                 tEnd = time.time()
802 | 
803 |             tElapsed = tEnd - tStart
804 | 
805 |             if(mode == 'online'): ## online mode
806 |                 rec = U * w
807 |                 # cost, nrmse = fun(rec.array().squeeze(),frame_idx)
808 |                 Lhat[:, frame_idx, :] = rec.array().squeeze()
809 | 
810 |             else:
811 |                 What[:,frame_idx,:] = w.array().squeeze()
812 |                 Lhat = (U * Tensor(What)).array() ## batch mode
813 | 
814 |             cost, nrmse = fun(Tensor(Lhat),frame_idx)
815 |             if(nrmse < tol):
816 |                 break
817 | 
818 |             stats_toucan[iter, :] = [cost, nrmse, cg_iters, tElapsed]
819 |             iter += 1
820 | 
821 |             if(verbose):
822 |                 if(inner % 10 ==0):
823 |                     print('Outer[{:d}], Inner[{:d}]: NRMSE: {:.8f} '.format(outer, inner, nrmse))
824 | 
825 |     tElapsed = np.sum(stats_toucan[:,-1])
826 | 
827 | 
828 |     return Tensor(Lhat), U, stats_toucan, tElapsed
829 | 
--------------------------------------------------------------------------------
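For reference, a minimal end-to-end sketch of running toucan() on synthetic, entry-wise-missing data in batch mode. It assumes the functions above live in jupyter_notebooks/tsvd.py, as the repository layout suggests; the tensor sizes, tubal rank, 50% sampling rate, number of outer passes, and the helper name run_toucan_demo are illustrative choices only, and the NRMSE callback follows the fun(Lhat, idx) signature expected by toucan().

import numpy as np
from tsvd import Tensor, tfrobnorm, toucan   # assumes this module is tsvd.py

def run_toucan_demo(n1=40, n2=60, n3=8, r=3, p=0.5):
    # synthetic low-tubal-rank tensor: t-product of two random factors
    L = Tensor(np.random.randn(n1, r, n3)) * Tensor(np.random.randn(r, n2, n3))
    mask = (np.random.rand(n1, n2, n3) < p).astype(float)   # 1 = observed entry
    Y = Tensor(np.multiply(mask, L.array()))

    # NRMSE of the current completion against the ground truth (cost slot unused here)
    fun = lambda Lhat, idx: [0, tfrobnorm(Lhat - L) / tfrobnorm(L)]

    Lhat, U, stats, t = toucan(Y, mask, rank=r, tube=False, outer=5, mode='batch',
                               fun=fun, randomOrder=False, verbose=True)
    print('final NRMSE: {:.2e}, total time: {:.2f}s'.format(stats[-1, 1], t))
    return Lhat, stats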