├── src
│   ├── plot_electrodes.m
│   ├── symmetrize_colorbar.m
│   ├── standardize.m
│   ├── random_orthogonal.m
│   ├── prox_tv0.m
│   ├── reorder_components.m
│   ├── smooth_linear_comparison_setup.m
│   ├── compute_pdist.m
│   ├── smooth_linear_setup.m
│   ├── switching_linear_comparison_setup.m
│   ├── rebalance.m
│   ├── get_nino.m
│   ├── get_pdo.m
│   ├── init_dmd.m
│   ├── switching_linear_setup.m
│   ├── indep_models.m
│   ├── rebalance_2.m
│   ├── example_worms.m
│   ├── preprocess_neurotycho.py
│   ├── example_lorenz.m
│   ├── smooth_linear.m
│   ├── example_el_nino.m
│   ├── smooth_linear_comparison.m
│   ├── ssm_compare.py
│   ├── switching_linear_comparison.m
│   ├── switching_linear.m
│   ├── TVART_alt_min.m
│   ├── example_neurotycho.m
│   └── run_all_tests_ARHMM-SSM.ipynb
├── data
│   ├── K1_electrodes.mat
│   ├── lorenz_0_001.mat
│   └── PDO.txt
├── figures
│   └── monkeys
│       ├── monkey_K1.png
│       └── monkey_K1_overview.png
├── LICENSE.txt
└── README.md

/src/plot_electrodes.m:
--------------------------------------------------------------------------------
1 | function plot_electrodes(vec, x, y)
2 | plot(
--------------------------------------------------------------------------------
/data/K1_electrodes.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamdh/tvart/HEAD/data/K1_electrodes.mat
--------------------------------------------------------------------------------
/data/lorenz_0_001.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamdh/tvart/HEAD/data/lorenz_0_001.mat
--------------------------------------------------------------------------------
/figures/monkeys/monkey_K1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamdh/tvart/HEAD/figures/monkeys/monkey_K1.png
--------------------------------------------------------------------------------
/figures/monkeys/monkey_K1_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamdh/tvart/HEAD/figures/monkeys/monkey_K1_overview.png
--------------------------------------------------------------------------------
/src/symmetrize_colorbar.m:
--------------------------------------------------------------------------------
1 | function symmetrize_colorbar()
2 |     ax = gca();
3 |     cax = caxis();
4 |     lim = max(abs(cax));
5 |     caxis([-lim, lim]);
6 | end
--------------------------------------------------------------------------------
/src/standardize.m:
--------------------------------------------------------------------------------
1 | function Y = standardize(X);
2 |     % Y = X - repmat(mean(X, 1), size(X,1), 1);
3 |     Y = X - repmat(mean(X, 2), 1, size(X, 2));
4 |     Y = Y ./ repmat(std(Y, 0, 2), 1, size(X, 2));
5 | end
6 |
--------------------------------------------------------------------------------
/src/random_orthogonal.m:
--------------------------------------------------------------------------------
1 | function U = random_orthogonal(n,m)
2 |     % Z = (randn(n,n) + 1.j * randn(n,n))/sqrt(2*n);
3 |     Z = rand(n,m) / sqrt(n);
4 |     [U,S,V] = svd(Z,0);
5 |     % [Q,R] = qr(Z);
6 |     % D = sign(diag(R));
7 |     % U = Q * diag(D);
--------------------------------------------------------------------------------
/src/prox_tv0.m:
--------------------------------------------------------------------------------
1 | function [y, cost] = prox_tv0(x, beta)
2 |     y = zeros(size(x));
3 |     ipts = findchangepts(x, 'statistic', 'mean', 'minthreshold', beta);
4 |     npts = length(ipts);
5 |     if npts > 0
6 |
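% The loop below replaces x with its mean over each segment between
% consecutive detected changepoints (a piecewise-constant fit), and the
% returned cost is beta times the number of changepoints. This appears to
% serve as the proximal operator for the l0 changepoint penalty used by the
% 'TV0'/'TVL0' regularization options seen in the example scripts.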
for i=1:npts+1 7 | if i == 1 8 | idxi = 1:ipts(i)-1; 9 | elseif i == (npts+1) 10 | idxi = ipts(i-1):length(x); 11 | else 12 | idxi = ipts(i-1):ipts(i)-1; 13 | end 14 | y(idxi) = mean(x(idxi)); 15 | end 16 | else 17 | y(:) = mean(x); 18 | end 19 | cost = npts * beta; 20 | end -------------------------------------------------------------------------------- /src/reorder_components.m: -------------------------------------------------------------------------------- 1 | function [lambda, A, B, C, varargout] = ... 2 | reorder_components(lambda, A, B, C, varargin) 3 | [~,I] = sort(lambda, 'descend'); 4 | A = A(:,I); 5 | B = B(:,I); 6 | C = C(:,I); 7 | if nargin > 4 8 | W = varargin{1}; 9 | W = W(:,I); 10 | end 11 | lambda = lambda(I); 12 | for i=1:size(C,2) 13 | if mean(C(:,i)) < 0 14 | C(:,i) = C(:,i) * -1; 15 | A(:,i) = A(:,i) * -1; 16 | if nargin > 4 17 | W(:,i) = W(:,i) * -1; 18 | end 19 | end 20 | end 21 | varargout = {}; 22 | if nargin > 4 23 | varargout{1} = W; 24 | end 25 | end 26 | -------------------------------------------------------------------------------- /src/smooth_linear_comparison_setup.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | 4 | addpath('~/work/MATLAB/') 5 | 6 | %% Parameters 7 | %% test system 8 | rng(1) 9 | 10 | M = 1; 11 | r = 6; 12 | num_steps = M * 160 + 1; 13 | num_trans = 200; 14 | noise_std = 0.2; 15 | noise_process = 0.0; 16 | T = floor((num_steps - 1) / M); 17 | noise_compensate = 0; 18 | offset = 0; 19 | 20 | N_vec = [6,12,24,50,100,200,400,1000,2000,4000]; 21 | %N_vec = [8000]; 22 | 23 | for N = N_vec 24 | %% Setup the problem 25 | [X, thetas, U] = smooth_linear_setup(N, num_steps, noise_std); 26 | save(sprintf('../data/test_data_smooth_N_%d_M_%d_sigma_%f.mat', N, num_steps, ... 27 | noise_std), ... 28 | 'X', 'thetas', 'U', 'N', 'num_steps', 'noise_std'); 29 | end 30 | -------------------------------------------------------------------------------- /src/compute_pdist.m: -------------------------------------------------------------------------------- 1 | function D = compute_pdist(A, B, C) 2 | T = size(C, 1); 3 | D = zeros(T*(T-1)/2, 1); 4 | idx = 1; 5 | for w_i = 1:T 6 | for w_j = (w_i + 1):T 7 | Cdiff = diag(C(w_j, :)) - diag(C(w_i, :)); 8 | if size(B, 1) == size(A, 1) + 1 9 | % affine mode 10 | D(idx) = ... 11 | norm(A * Cdiff * B(1:end-1, :)', ... 12 | 'fro') + ... 13 | norm(Cdiff * B(end, :)'); 14 | elseif size(B, 1) == size(A, 1) 15 | D(idx) = ... 
16 | norm(A * Cdiff * B', 'fro'); 17 | else 18 | error('Dimension mismatch in A and B'); 19 | end 20 | idx = idx + 1; 21 | end 22 | end 23 | end -------------------------------------------------------------------------------- /src/smooth_linear_setup.m: -------------------------------------------------------------------------------- 1 | function [X, thetas, U] = smooth_linear_setup(N, num_steps, sigma) 2 | num_trans = 200; 3 | ts = 1:(num_steps + num_trans); 4 | tau = length(ts); 5 | K = zeros(tau, tau); 6 | for i = 1:tau 7 | for j = 1:tau 8 | K(i,j) = exp(-(ts(i) - ts(j))^2 / (30)^2); 9 | if i == j 10 | K(i,j) = K(i,j) + 0.001; 11 | end 12 | end 13 | end 14 | R = chol(K); 15 | thetas = R * randn(tau, 1); 16 | 17 | if N > 2 18 | U = random_orthogonal(N, 2); 19 | else 20 | U = eye(2); 21 | end 22 | 23 | x = ones(N,1); 24 | X = zeros(N, num_steps); 25 | % transient removal 26 | for i = 1:tau 27 | Ar = [cos(thetas(i)) -sin(thetas(i)); 28 | sin(thetas(i)) cos(thetas(i))]; 29 | A = U * Ar * U'; 30 | x = A * x; 31 | if i == 1 32 | x = x / norm(x) * sqrt(N); 33 | end 34 | if i > num_trans 35 | X(:, i - num_trans) = x; 36 | end 37 | end 38 | %% add noise 39 | X = X + randn(size(X)) * sigma; % / sqrt(N); 40 | thetas = thetas(1+num_trans:end); -------------------------------------------------------------------------------- /src/switching_linear_comparison_setup.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | 4 | addpath('~/work/MATLAB/') 5 | 6 | %% Parameters 7 | %% test system 8 | rng(1) 9 | 10 | M = 20; 11 | num_steps = M * 10 + 1; 12 | %num_steps = M * 100 + 1; 13 | num_trans = 200; 14 | noise_std = 0.5; 15 | noise_process = 0.0; 16 | T = floor((num_steps - 1) / M); 17 | noise_compensate = 0; 18 | offset = 0; 19 | num_reps = 3; 20 | 21 | %% for N sweeps 22 | %N_vec = [6,10,14,18,20,30,50,80,100,200,400,1000,2000,4000]; 23 | %noise_vec = [0.5]; 24 | 25 | %% for noise_std sweeps 26 | N_vec = [10,20,40,80,160,320]; %,2000,4000]; 27 | noise_vec = [0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1.]; 28 | 29 | %N_vec = [8000]; 30 | for noise_std = noise_vec 31 | for N = N_vec 32 | for rep = 1:num_reps 33 | %% Setup the problem 34 | [X, A1, A2] = switching_linear_setup(N, num_steps, noise_std, ... 35 | offset); 36 | save(sprintf('../data/test_data_N_%d_M_%d_sigma_%f_rep_%d.mat', N, num_steps, ... 37 | noise_std, rep), ... 
38 | 'X', 'A1', 'A2', 'N', 'num_steps', 'noise_std', 'offset'); 39 | end 40 | end 41 | end -------------------------------------------------------------------------------- /src/rebalance.m: -------------------------------------------------------------------------------- 1 | function [lambda, Anew, Bnew, Cnew] = rebalance(A, B, C, varargin) 2 | if nargin > 3 3 | old_style = varargin{1}; 4 | else 5 | old_style = 0; 6 | end 7 | if nargin > 4 8 | tol = varargin{2}; 9 | else 10 | tol = 1e-8; 11 | end 12 | r = size(A, 2); 13 | lambda = zeros(r,1); 14 | for l = 1:r 15 | anorm = norm(A(:, l)); 16 | Anew(:, l) = A(:, l) / anorm; 17 | bnorm = norm(B(:, l)); 18 | Bnew(:, l) = B(:, l) / bnorm; 19 | cnorm = norm(C(:, l)); 20 | Cnew(:, l) = C(:, l) / cnorm; 21 | if old_style 22 | lambda(l) = anorm*bnorm*cnorm; 23 | if lambda(l) < tol 24 | Anew(:, l) = 0; 25 | Bnew(:, l) = 0; 26 | Cnew(:, l) = 0; 27 | lambda(l) = 0; 28 | end 29 | else 30 | new_norm = (anorm*bnorm*cnorm)^(1/3); 31 | Anew(:, l) = Anew(:, l) * new_norm; 32 | Bnew(:, l) = Bnew(:, l) * new_norm; 33 | Cnew(:, l) = Cnew(:, l) * new_norm; 34 | lambda(l) = 1; 35 | end 36 | end 37 | if (size(A, 1) + 1) == size(B, 1) 38 | % affine 39 | for l = 1:r 40 | Anew(:, l) = Anew(:, l) * sign(Bnew(end, l)); 41 | Bnew(end, l) = Bnew(end, l) * sign(Bnew(end, l)); 42 | end 43 | end 44 | end 45 | -------------------------------------------------------------------------------- /src/get_nino.m: -------------------------------------------------------------------------------- 1 | function [soi_ind, soi_dates] = get_nino(start_date, end_date) 2 | % Get the El Niño/Southern Oscillation indices for a given date 3 | % range. 4 | % 5 | % Modified from get_soi.m by Chad Greene: 6 | % https://www.mathworks.com/matlabcentral/fileexchange/38629-get-el-nino-southern-oscillation-index-values 7 | % 8 | % Data: 9 | % https://www.cpc.ncep.noaa.gov/data/indices/ersst4.nino.mth.81-10.ascii 10 | 11 | soi_ascii = fopen('../data/ersst4.nino.mth.81-10.ascii'); 12 | soi_struct = textscan(soi_ascii,'%n %n %n %n %n %n %n %n %n %n', ... 13 | 'HeaderLines', 1); 14 | soi_mat = cell2mat(soi_struct); % converts cell structure to matrix 15 | 16 | % Regrid data from calendar format to time series: 17 | soi_dates = NaN(size(soi_mat,1), 1); % preallocate variable 18 | soi_ind = NaN(size(soi_dates,1), size(soi_mat,2)-2); % preallocate variable 19 | for m = 1:length(soi_mat); 20 | soi_dates(m) = datenum(soi_mat(m,1), soi_mat(m,2),15); 21 | soi_ind(m, :) = soi_mat(m, 3:end); 22 | end 23 | 24 | 25 | % If input arguments (start and/or end dates) are used, rewrite time series 26 | % data to include only dates of interest: 27 | if exist('start_date','var')==1 28 | if start_date < datenum(1950,1,1) 29 | warning('Historical data record begins in January 1950.') 30 | end 31 | if exist('end_date','var')~=1 32 | end_date = datenum(date); % uses today as end_date if not specified 33 | end 34 | soi_ind = soi_ind(soi_dates>=start_date&soi_dates<=end_date, :); 35 | soi_dates = soi_dates(soi_dates>=start_date&soi_dates<=end_date); 36 | end 37 | 38 | fclose(soi_ascii); -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The Clear BSD License 2 | 3 | Copyright (c) 2019 Kameron Decker Harris 4 | All rights reserved. 
5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted (subject to the limitations in the disclaimer 8 | below) provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright 14 | notice, this list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from this 19 | software without specific prior written permission. 20 | 21 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 22 | THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 23 | CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 24 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 25 | PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 26 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 27 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 28 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR 29 | BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER 30 | IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 31 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 32 | POSSIBILITY OF SUCH DAMAGE. 33 | -------------------------------------------------------------------------------- /src/get_pdo.m: -------------------------------------------------------------------------------- 1 | function [pdo_index, pdo_dates] = get_pdo(start_date, end_date) 2 | % Get the Pacific Decadal Oscillation indices for a given date 3 | % range. 4 | % 5 | % Modified from get_soi.m by Chad Greene: 6 | % https://www.mathworks.com/matlabcentral/fileexchange/38629-get-el-nino-southern-oscillation-index-values 7 | % 8 | % Data: 9 | % http://research.jisao.washington.edu/pdo/ 10 | 11 | pdo_ascii = fopen('../data/PDO.txt'); 12 | pdo_struct = textscan(pdo_ascii,'%n %n %n %n %n %n %n %n %n %n %n %n %n', ... 13 | 'HeaderLines', 0); 14 | pdo_mat = cell2mat(pdo_struct); % converts cell structure to matrix 15 | disp(size(pdo_mat)) 16 | % Regrid data from calendar format to time series: 17 | pdo_dates = NaN(size(pdo_mat, 1) * 12, 1); 18 | pdo_index = NaN(size(pdo_mat, 1) * 12, 1); 19 | counter = 1; 20 | for yid = 1:size(pdo_mat, 1); 21 | for month = 1:12 22 | pdo_dates(counter) = datenum(pdo_mat(yid, 1), month, 15); 23 | pdo_index(counter) = pdo_mat(yid, month + 1); 24 | counter = counter + 1; 25 | end 26 | end 27 | 28 | 29 | % If input arguments (start and/or end dates) are used, rewrite time series 30 | % data to include only dates of interest: 31 | if exist('start_date','var') == 1 32 | if start_date < datenum(1900, 1, 1) 33 | warning('Historical data record begins in January 1950.') 34 | end 35 | if end_date > datenum(2017, 12, 31) 36 | warning(['Historical data record ends in December ' ... 
37 | '2017.']); 38 | end 39 | pdo_index = pdo_index(pdo_dates >= start_date & pdo_dates <= end_date); 40 | pdo_dates = pdo_dates(pdo_dates >= start_date & pdo_dates <= end_date); 41 | end 42 | 43 | fclose(pdo_ascii); -------------------------------------------------------------------------------- /src/init_dmd.m: -------------------------------------------------------------------------------- 1 | function [A, B, C] = init_dmd(X, Y, R, T, varargin) 2 | % Rfull = 0; 3 | N = size(X, 1); 4 | Admd = Y / X; 5 | [U, S, V] = svd(Admd, 'econ'); 6 | S = diag(S); 7 | %V = U; 8 | % [U, D] = eig(Admd); 9 | % U = real(U); 10 | % V = U; 11 | % S = diag(ones(N, 1)); 12 | if R > N 13 | % warning(['Not implemented for rank > N. Using all ones ' ... 14 | % 'initialization']); 15 | % A = repmat(U * diag(sqrt(S)), 1, ceil(R/N)); 16 | % A = A(:, 1:R); 17 | % B = repmat(V * diag(sqrt(S)), 1, ceil(R/N)); 18 | % B = B(:, 1:R); 19 | %scale = sqrt(min(S)); 20 | scale = 1; 21 | A = [U,... % * diag(sqrt(S)), ... 22 | ones(N, R-size(U, 2)) / sqrt(N) * scale]; 23 | B = [V,... % * diag(sqrt(S)), ... 24 | ones(N, R-size(V, 2)) / sqrt(N) * scale]; 25 | C = ones(T, R); 26 | % A = ones(size(Y, 1), R); 27 | % B = ones(size(X, 1), R); 28 | % warning('Not implemented for rank > N. Using rand initialization'); 29 | % R = R; 30 | % R = size(X, 1); 31 | % A = ones(size(Y, 1), R) + rand(size(Y, 1), R); 32 | % B = ones(size(X, 1), R) + rand(size(X, 1), R); 33 | % C = ones(T, R) + rand(T, R); 34 | else 35 | % [U, S, V] = svd(X, 'econ'); 36 | % S = diag(S); 37 | % Ur = U(:, 1:R); 38 | % Sr = S(1:R); 39 | % Vr = V(:, 1:R); 40 | % Srinv = diag(1./Sr); 41 | % Admd = Y * Vr * Srinv * Ur'; 42 | A = U(:, 1:R); % * diag(sqrt(S(1:R))); 43 | B = V(:, 1:R); % * diag(sqrt(S(1:R))); 44 | C = ones(T, R); 45 | end 46 | % A = ones(N, R) / sqrt(N); 47 | % B = ones(N, R) / sqrt(N); 48 | % C = ones(T, R) / sqrt(T); 49 | % if nargin > 4 50 | % if varargin{1} == 1 51 | A = A + 0.5 * randn(N, R) / sqrt(N); 52 | B = B + 0.5 * randn(N, R) / sqrt(N); 53 | C = C + 0.5 * randn(T, R) / sqrt(T); 54 | % A = (2 * rand(N, R) - 1) / sqrt(N); 55 | % B = (2 * rand(N, R) - 1) / sqrt(N); 56 | % C = (2 * rand(T, R) - 1) / sqrt(T); 57 | %end 58 | %end 59 | end -------------------------------------------------------------------------------- /src/switching_linear_setup.m: -------------------------------------------------------------------------------- 1 | function [X, A1, A2] = switching_linear_setup(N, num_steps, sigma, offset) 2 | num_trans = 200; 3 | noise_process = 0.0; 4 | noise_compensate = 0; 5 | 6 | x0 = ones(N,1) / sqrt(N); 7 | t1 = 0.1 * pi; 8 | A1 = [cos(t1) -sin(t1); 9 | sin(t1) cos(t1)]; 10 | %t2 = rand()*pi; 11 | t2 = 0.37 * pi; 12 | A2 = [cos(t2) -sin(t2); 13 | sin(t2) cos(t2)]; 14 | if N > 2 15 | %A1 = randn(N,N)/sqrt(N); 16 | %A2 = randn(N,N)/sqrt(N); 17 | %A1 = (2*rand(N,r)-1)*(2*rand(r,N)-1) / sqrt(N); 18 | %A2 = (2*rand(N,r)-1)*(2*rand(r,N)-1) / sqrt(N); 19 | % A1 = randn(N,r)*randn(r,N) / sqrt(r*N); 20 | % A2 = randn(N,r)*randn(r,N) / sqrt(r*N); 21 | % A1 = A1 / (max(abs(eig(A1))) ); 22 | % A2 = A2 / (max(abs(eig(A2))) ); 23 | U = random_orthogonal(N, 2); 24 | A1 = U * A1 * U'; 25 | U = random_orthogonal(N, 2); 26 | A2 = U * A2 * U'; 27 | end 28 | 29 | %A2 = A1; 30 | 31 | % disp('eigs of A1') 32 | % eig(A1) 33 | % disp('eigs of A2') 34 | % eig(A2) 35 | 36 | if offset 37 | %% add offsets 38 | b1 = 4+rand(N,1); 39 | b2 = -1-rand(N,1); 40 | end 41 | % transient removal 42 | for i = 1:num_trans 43 | x0 = A1 * x0; 44 | if i == 1 45 | x0 = x0 / norm(x0) * 
sqrt(N); 46 | end 47 | end 48 | 49 | %% integrate linear system 50 | X = zeros(N, num_steps); 51 | X(:, 1) = x0; 52 | for i = 2:num_steps 53 | if i < (num_steps)/2 54 | X(:, i) = A1 * X(:, i-1) + randn(N,1) * noise_process; 55 | else 56 | X(:, i) = A2 * X(:, i-1) + randn(N,1) * noise_process; 57 | if i == ceil((num_steps)/2) 58 | X(:, i) = X(:, i) / norm(X(:, i)) * sqrt(N); 59 | end 60 | end 61 | end 62 | if offset 63 | for i = 2:num_steps 64 | if i < (num_steps)/2 65 | X(:, i) = X(:, i) + b1; 66 | else 67 | X(:, i) = X(:, i) + b2; 68 | end 69 | end 70 | end 71 | %% add noise 72 | X = X + randn(size(X)) * sigma; % / sqrt(N); 73 | -------------------------------------------------------------------------------- /src/indep_models.m: -------------------------------------------------------------------------------- 1 | function [A] = indep_models(data, window_size, varargin) 2 | % param defaults 3 | center = 0; 4 | verbosity = 1; 5 | % parse params 6 | p = inputParser; 7 | addParameter(p, 'center', center); 8 | addParameter(p, 'verbosity', verbosity); 9 | parse(p, varargin{:}); 10 | center = p.Results.center; 11 | verbosity = p.Results.verbosity; 12 | 13 | if verbosity > 0 14 | fprintf('Running independent modeling\n'); 15 | fprintf('\twindow size: %d\n', window_size); 16 | end 17 | fprintf('\n'); 18 | 19 | N = size(data, 1); 20 | M = window_size; 21 | t = size(data, 2); 22 | 23 | T = floor((t-1) / window_size); 24 | X = data(:, 1:end-1); 25 | X = X(:, 1:M*T); 26 | Y = data(:, 2:end); 27 | Y = Y(:, 1:M*T); 28 | 29 | if center == 0 30 | X = reshape(X, [N, M, T]); 31 | else 32 | % append ones to input data 33 | X = reshape([X; ones(1, M*T)], [N+1, M, T]); 34 | end 35 | Y = reshape(Y, [N, M, T]); 36 | 37 | if center 38 | A = zeros(N, N+1, T); 39 | else 40 | A = zeros(N, N, T); 41 | end 42 | 43 | for k = 1:T 44 | A(:, :, k) = Y(:, :, k) * pinv(X(:, :, k)); 45 | end 46 | 47 | var_tot = 0; 48 | var_int = 0; 49 | var_aff = 0; 50 | for k = 1:T 51 | Ypred_k = A(:, :, k) * X(:, :, k); 52 | var_aff = var_aff + norm(Y(:,:,k) - Ypred_k, 'fro')^2; 53 | if center 54 | Ypred_int = A(:, end, k) * ones(1, size(Ypred_k, 2)); 55 | var_int = var_int + norm(Y(:,:,k) - Ypred_int, 'fro')^2; 56 | else 57 | var_int = var_int + norm(Y(:,:,k), 'fro')^2; 58 | end 59 | var_tot = var_tot + norm(Y(:,:,k), 'fro')^2; 60 | end 61 | var_tot = var_tot / (N*M*T); 62 | var_int = var_int / (N*M*T); 63 | var_aff = var_aff / (N*M*T); 64 | fprintf('total var:\t\t%1.3f\n', var_tot); 65 | fprintf('intercept resid var:\t%1.3f, %1.3g%%\n', var_int, ... 66 | var_int / var_tot* 100); 67 | fprintf('affine model resid var:\t%1.3f, %1.3g%%\n', var_aff, ... 
68 | var_aff / var_tot* 100); 69 | 70 | if nargout >= 2 71 | varargout{1} = X; 72 | end 73 | if nargout >= 3 74 | varargout{2} = Y; 75 | end 76 | if nargout >= 4 77 | varargout{3} = var_aff; 78 | end 79 | end 80 | -------------------------------------------------------------------------------- /src/rebalance_2.m: -------------------------------------------------------------------------------- 1 | function [lambda, Anew, Bnew, Cnew, Wnew] = rebalance_2(A, B, C, W, varargin) 2 | if nargin > 4 3 | old_style = varargin{1}; 4 | else 5 | old_style = 0; 6 | end 7 | if nargin > 5 8 | tol = varargin{2}; 9 | else 10 | tol = 1e-8; 11 | end 12 | r = size(A, 2); 13 | lambda = zeros(r,1); 14 | for l = 1:r 15 | anorm = norm(A(:, l)); 16 | Anew(:, l) = A(:, l) / anorm; 17 | bnorm = norm(B(:, l)); 18 | Bnew(:, l) = B(:, l) / bnorm; 19 | cnorm = norm(C(:, l)); 20 | Cnew(:, l) = C(:, l) / cnorm; 21 | Wnew(:, l) = W(:, l) / cnorm; 22 | if old_style 23 | lambda(l) = anorm*bnorm*cnorm; 24 | if lambda(l) < tol 25 | Anew(:, l) = 0; 26 | Bnew(:, l) = 0; 27 | Cnew(:, l) = 0; 28 | Wnew(:, l) = 0; 29 | lambda(l) = 0; 30 | end 31 | else 32 | new_norm = (anorm*bnorm*cnorm)^(1/3); 33 | Anew(:, l) = Anew(:, l) * new_norm; 34 | Bnew(:, l) = Bnew(:, l) * new_norm; 35 | Cnew(:, l) = Cnew(:, l) * new_norm; 36 | Wnew(:, l) = Wnew(:, l) * new_norm; 37 | lambda(l) = 1; 38 | end 39 | end 40 | if (size(A, 1) + 1) == size(B, 1) 41 | % affine 42 | for l = 1:r 43 | Anew(:, l) = Anew(:, l) * sign(Bnew(end, l)); 44 | Bnew(end, l) = Bnew(end, l) * sign(Bnew(end, l)); 45 | end 46 | end 47 | end 48 | 49 | % function [lambda, Anew, Bnew, Cnew, Wnew] = rebalance_2(A, B, C, W) 50 | % r = size(A, 2); 51 | % lambda = ones(r,1); 52 | % for l = 1:r 53 | % anorm = norm(A(:, l)); 54 | % Anew(:, l) = A(:, l) / anorm; 55 | % bnorm = norm(B(:, l)); 56 | % Bnew(:, l) = B(:, l) / bnorm; 57 | % cnorm = norm(C(:, l)); 58 | % Cnew(:, l) = C(:, l) * anorm * bnorm; 59 | % Wnew(:, l) = W(:, l) * anorm * bnorm; 60 | % % if old_style 61 | % % lambda(l) = anorm*bnorm*cnorm; 62 | % % else 63 | % % new_norm = (anorm*bnorm*cnorm)^(1/3); 64 | % % Anew(:, l) = Anew(:, l) * new_norm; 65 | % % Bnew(:, l) = Bnew(:, l) * new_norm; 66 | % % Cnew(:, l) = Cnew(:, l) * new_norm; 67 | % % lambda(l) = 1; 68 | % % end 69 | % end 70 | % if (size(A, 1) + 1) == size(B, 1) 71 | % % affine 72 | % for l = 1:r 73 | % Anew(:, l) = Anew(:, l) * sign(Bnew(end, l)); 74 | % Bnew(end, l) = Bnew(end, l) * sign(Bnew(end, l)); 75 | % end 76 | % end 77 | 78 | % end 79 | 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Time-varying Autoregression with Low Rank Tensors (TVART) 2 | ### by Kameron Decker Harris 3 | 4 | This is the code repository for the TVART method, to accompany the 5 | paper "Time-varying Autoregression with Low Rank Tensors" by 6 | Kameron Decker Harris, Aleksandr Aravkin, Rajesh Rao, and Bing Brunton. 7 | [[arxiv link]](https://arxiv.org/abs/1905.08389) 8 | 9 | Dependencies: 10 | * MATLAB R2017b (not tested on earlier versions) with relevant toolboxes 11 | * UNLocBox https://epfl-lts2.github.io/unlocbox-html/ 12 | * vline/hline https://www.mathworks.com/matlabcentral/fileexchange/1039-hline-and-vline 13 | * Python 2.7 or 3.6 with numpy, scipy (optional, for preprocessing neural example) 14 | 15 | ## src/ 16 | 17 | The files to run the TVART algorithm and examples are included here. 
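A minimal usage sketch of the main routine (see the file list below for what each script does). This mirrors the calls in the example scripts; the UNLocBox path, the toy data, and the parameter values are placeholders to adapt to your own problem:

```matlab
addpath('path/to/unlocbox');   % UNLocBox dependency, as in the example scripts
init_unlocbox();
X = randn(10, 401);   % illustrative data matrix: N states x (M*T + 1) time steps
M = 20;               % window length (snapshots per window)
R = 4;                % rank of the TVART decomposition
[lambda, A, B, C, cost, Xten, Yten, rmse] = ...
    TVART_alt_min(X, M, R, ...
                  'center', 1, ...            % include an affine/intercept term
                  'eta', 0.01, ...            % Tikhonov/proximal parameter
                  'beta', 10, ...             % temporal regularization strength
                  'regularization', 'TV', ... % examples also use 'TVL0' and 'spline'
                  'max_iter', 500, ...
                  'verbosity', 1);
% A and B are the left/right spatial modes, C holds the temporal modes (one
% row per window), lambda the component weights; cost and rmse track convergence.
```

The synthetic test cases (switching_linear.m and smooth_linear.m) are the quickest way to see reasonable parameter ranges before moving to the data examples.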
18 | 19 | * TVART_alt_min.m - implementes the alternating minimization algorithm described in the text 20 | * switching_linear.m - switching linear test case 21 | * smooth_linear.m - smooth linear test case 22 | * example_worms.m - worm behavior example 23 | * example_el_nino.m - sea surface temperature example 24 | * preprocess_neurotycho.py - preprocessing script to remove line noise and compute band power for neural example 25 | * example_neurotycho.m - neural activity example 26 | * other files: helper functions, iPython notebooks used to compare with SLDS, 27 | switching_linear_comparison* and smooth_linear_comparison* run sweeps of test problems across N... 28 | these are provided as-is and will require some tweaking to run 29 | 30 | ## data/ 31 | 32 | The data for the examples is stored here. 33 | You will need to carry out some extra steps to run all examples: 34 | 35 | ### Worm behavior 36 | 37 | We obtained the code and data from Costa et al. from 38 | https://github.com/AntonioCCosta/local-linear-segmentation. 39 | To just run our example, all that is needed is "worm_tseries.h5". 40 | 41 | ### Sea surface temperature 42 | 43 | In order to run the "Sea surface temperature" example, you must download 44 | * sst.wkmean.1990-present.nc 45 | * lsmask.nc 46 | 47 | from https://www.esrl.noaa.gov/psd/repository/entry/show/PSD+Climate+Data+Repository/Public/PSD+Datasets/NOAA+OI+SST/Weekly+and+Monthly/. 48 | 49 | The files "ersst4.nino.mth.81-10.ascii" and "PDO.txt" are from https://www.cpc.ncep.noaa.gov/data/indices/ersst4.nino.mth.81-10.ascii and http://research.jisao.washington.edu/pdo/PDO.latest.txt. 50 | 51 | ### Neural activity 52 | 53 | These data are kindly provided by the [Neurotycho project](http://neurotycho.org): 54 | http://neurotycho.brain.riken.jp/download/base/20090525S1_Food-Tracking_K1_Zenas+Chao_mat_ECoG64-Motion8.zip. 55 | 56 | In order to prepare the data, you must run the preprocessing script. 57 | 58 | ## figures/ 59 | 60 | After running the code, figures will be saved in this directory. 61 | We include some figures modified from Neurotycho http://neurotycho.org/food-tracking-task. 62 | -------------------------------------------------------------------------------- /src/example_worms.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | 4 | tic 5 | 6 | addpath('~/work/MATLAB/unlocbox') 7 | init_unlocbox(); 8 | 9 | worm_id = 1; 10 | worm_data = h5read(['~/src/local-linear-segmentation/sample_data/' ... 11 | 'worm_tseries.h5'], ['/' num2str(worm_id) ... 12 | '/tseries']); 13 | %worm_data= standardize(worm_data); 14 | figdir = '../figures/'; 15 | 16 | rng(1); 17 | % %% Tensor decomp params 18 | R = 6; 19 | M = 6; 20 | %M = 10; 21 | eta = 0.05; 22 | eta2 = 0.1; 23 | beta = 6.0; 24 | center = 1; 25 | regularization = 'TV'; 26 | max_iter = 1000; 27 | numclust = 3; 28 | 29 | %% Good for TV-l0 30 | % R = 12; 31 | % M = 4; 32 | % eta = 10.0; 33 | % eta2 = 0.1; 34 | % beta = 0.05; 35 | % center = 0; 36 | % max_iter = 1000; 37 | % regularization = 'TVL0'; 38 | 39 | 40 | % [lambda, A, B, C, cost, Xten, Yten] = ... 41 | % tensor_DMD_ALS_aux(worm_data, M, R, ... 42 | % 'center', center, ... 43 | % 'eta', eta, ... 44 | % 'eta2', eta2, ... 45 | % 'beta', beta, ... 46 | % 'regularization', 'TV', ... 47 | % 'rtol', 1e-4, ... 48 | % 'max_iter', max_iter,... 49 | % 'verbosity', 1); 50 | 51 | [lambda, A, B, C, cost, Xten, Yten] = ... 52 | TVART_alt_min(worm_data, M, R, ... 53 | 'center', center, ... 54 | 'eta', eta, ... 
55 | 'beta', beta, ... 56 | 'regularization', regularization, ... 57 | 'verbosity', 2, ... 58 | 'max_iter', max_iter); 59 | 60 | 61 | % [lambda, A, B, C] = rebalance(A, B, C, 1); 62 | % [lambda, A, B, C] = reorder_components(lambda, A, B, C); 63 | 64 | 65 | 66 | 67 | %% Clustering 68 | %D = compute_pdist(A, B, C*diag(lambda)); 69 | D = pdist(C); 70 | Z = linkage(squareform(D), 'complete'); 71 | T = cluster(Z, 'maxclust', numclust); 72 | 73 | phase = atan2(worm_data(2,:), worm_data(1,:)); 74 | 75 | 76 | %% Load Costa model results for comparison 77 | Costa_data = load(['~/src/local-linear-segmentation/sample_data/' ... 78 | 'worm_1_results.mat']); 79 | C_segmentation = (Costa_data.segmentation(:, 1) + ... 80 | Costa_data.segmentation(:, 2)) / 2; % midpoints of windows 81 | C_labels = mod(3 - Costa_data.cluster_labels, 3) + 1 % reorder 82 | 83 | 84 | figure 85 | ax1 = subplot(4, 1, 1); 86 | plot(1:length(worm_data), worm_data, 'color', [0,0,0]+0.4); 87 | %plot(1:length(worm_data), worm_data); 88 | %legend({'a_1','a_2','a_3','a_4'}) 89 | title('Observations') 90 | ax2 = subplot(4, 1, 2); 91 | plot(1:(length(worm_data)), phase,... 92 | 'color', [0,0,0]+0.4) 93 | ylim([-pi, pi]) 94 | title('Phase') 95 | ax3 = subplot(4, 1, 3); 96 | plot(M/2 + (1:M:M*length(C)), C); 97 | title('Temporal modes') 98 | %ylim([0.06, 0.077]) 99 | ax4 = subplot(4, 1, 4); 100 | hold on 101 | plot(M/2 + (1:M:M*length(C)), T, 'ks-'); 102 | plot(C_segmentation, C_labels, 'ko--'); 103 | text(10, 1.5, 'forward', 'Color', 'red') 104 | text(65, 1.5, 'turn', 'Color', 'red') 105 | text(140, 1.5, 'backward', 'Color', 'red') 106 | legend({'TVART', 'Costa et al.'}) 107 | ylim([0.8, numclust+.2]) 108 | yticks([1,2,3]) 109 | box on 110 | xlabel('Time step') 111 | title('Clusters') 112 | linkaxes([ax1 ax2 ax3 ax4], 'x') 113 | set(gcf, 'Color', 'w', 'Position', [100 200 600 700]); 114 | set(gcf, 'PaperUnits', 'inches', ... 115 | 'PaperPosition', [0 0 6.5 7.5], 'PaperPositionMode', ... 116 | 'manual'); 117 | print('-depsc2', '-loose', '-r200', [figdir 'example_worms.eps']); 118 | 119 | elapsedtime = toc(); 120 | 121 | fprintf('%f seconds elapsed\n', elapsedtime); 122 | -------------------------------------------------------------------------------- /src/preprocess_neurotycho.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.signal 3 | from scipy.io import loadmat, savemat 4 | 5 | 6 | def filter_ecog(arr, f_sample, f_high=200., f_low=2., w_stop=3., order=2, line_freq=60.): 7 | """Filtering method for ECoG data 8 | 9 | Uses fourth-order Butterworth forward-backward filtering. 10 | 11 | Parameters 12 | ---------- 13 | arr : array-like, shape (n_time,) 14 | Data to filter 15 | f_sample : float 16 | Sample rate, in Hz 17 | f_high : float, optional (default = 200) 18 | High frequency cutoff, in Hz 19 | f_low : float, optional (default = 1.0) 20 | Low frequency cutoff, in Hz 21 | w_stop : float, optional (default = 2.5) 22 | Half-width of 60 Hz bandstop filter 23 | line_freq : float, optional (default = 60.) 24 | Line frequency, in Hz. Set to 50 for some countries. 25 | 26 | Returns 27 | ------- 28 | arr_filt : array-like, shape (n_time,) 29 | Filtered data 30 | """ 31 | from scipy import signal 32 | f_Ny = f_sample / 2. 
# Nyquist frequency 33 | # 1) bandpass filter 34 | b, a = signal.butter(order, [f_low / f_Ny, f_high / f_Ny], 'band') 35 | arr_filt = signal.filtfilt(b, a, arr) 36 | # 2) bandstop filter, 60 Hz line noise and harmonics 37 | for mult in [1, 2, 3, 4]: 38 | f_stop_mid = line_freq * mult 39 | f_stop = np.array([f_stop_mid - w_stop, f_stop_mid + w_stop]) 40 | b, a = signal.butter(order, f_stop / f_Ny, 'bandstop') 41 | arr_filt = signal.filtfilt(b, a, arr_filt) 42 | return arr_filt 43 | 44 | 45 | #nperseg = 96 46 | #noverlap = 64 47 | window_len = 0.096 48 | window_overlap = 0.064 49 | ## Miller et al. (2007) choices: 0.25 s window, 0.1 s overlap 50 | #window_len = 0.25 51 | #window_overlap = 0.1 52 | 53 | # data_dir = '../data_full/20090611S1_FTT_A_ZenasChao_mat_ECoG32-Motion12' 54 | # n_chan = 32 55 | 56 | # data_dir = '../data_full/20090527S1_FTT_K1_ZenasChao_mat_ECoG64-Motion8' 57 | # n_chan = 64 58 | 59 | data_dir = '../data_full/20090525S1_FTT_K1_ZenasChao_mat_ECoG64-Motion8' 60 | n_chan = 64 61 | 62 | #data_dir = '../data/20100802S1_Epidural-ECoG+Food-Tracking_B_Kentaro+' + \ 63 | # 'Shimoda_mat_ECoG64-Motion6' 64 | #n_chan = 64 65 | 66 | # out_fn = data_dir + '/Kam_Bands_param_Miller.mat' 67 | out_fn = data_dir + '/Kam_Bands.mat' 68 | 69 | 70 | # Load data 71 | ecog_time = loadmat(data_dir + '/ECoG_time.mat') 72 | ecog_time = ecog_time['ECoGTime'].flatten() 73 | ecog_data = np.zeros((n_chan, len(ecog_time))) 74 | for chan in range(n_chan): 75 | fn = "%s/ECoG_ch%d.mat" % (data_dir, chan + 1) 76 | var = "ECoGData_ch%d" % (chan + 1) 77 | tmp = loadmat(fn) 78 | ecog_data[chan, :] = tmp[var].astype(float) 79 | 80 | DT = ecog_time[1] - ecog_time[0] 81 | nperseg = int(window_len / DT) 82 | noverlap = int(window_overlap / DT) 83 | 84 | # Filter out line noise + highpass 85 | for chan in range(n_chan): 86 | ecog_data[chan, :] = filter_ecog(ecog_data[chan, :], 1./DT, line_freq = 50.) 
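# Note: line_freq=50 is passed above, overriding the 60 Hz default, because the
# Neurotycho recordings come from a 50 Hz mains region (Japan).
# The commented-out blocks below (common average referencing, per-channel
# standardization) are alternative preprocessing steps that were left disabled.
# The active pipeline instead computes spectrogram band power within the bands
# given by band_edges, takes log10, z-scores each channel/band time series with
# sklearn's StandardScaler, and saves the results to out_fn.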
87 | 88 | # Common average referencing 89 | #ecog_data = ecog_data - np.tile(np.median(ecog_data, axis=0), (n_chan, 1)) 90 | # Standardization 91 | #ecog_data = ecog_data - np.tile(np.median(ecog_data, axis=1)[:,np.newaxis], 92 | # (1, ecog_data.shape[1])) 93 | # ecog_data = ecog_data / np.tile(np.std(ecog_data, 1, ddof=1)[:,np.newaxis], 94 | # (1, ecog_data.shape[1])) 95 | #ecog_data = ecog_data / np.tile(np.median(ecog_data, 1)[:,np.newaxis], 96 | # (1, ecog_data.shape[1])) 97 | 98 | 99 | band_edges = [2, 32, 200] 100 | ## Miller choices: 101 | #band_edges = [8, 32, 76, 100] 102 | # add notch filters around 60, 120 Hz OR 50, 100, 150 103 | # compare fits to just low, high separate versus combined 104 | # regression decoder 105 | # look at lambdas 106 | # straw man comparisons: PCA, windowed DMD, ICA 107 | 108 | n_band = len(band_edges) - 1 109 | new_ecog = np.zeros((n_chan, n_band, int(len(ecog_time)/10))) 110 | 111 | for chan in range(n_chan): 112 | f, t, Sxx = scipy.signal.spectrogram(ecog_data[chan, :], fs=1./DT, nperseg=nperseg, 113 | noverlap=noverlap) 114 | if chan == 0: 115 | new_ecog = new_ecog[:, :, range(len(t))] 116 | for band in range(n_band): 117 | f_low = band_edges[band] 118 | f_high = band_edges[band + 1] 119 | idx = (f < f_high) & (f >= f_low) 120 | new_ecog[chan, band, :] = np.sum(Sxx[idx, :], axis=0) 121 | 122 | new_time = t 123 | n_time = len(new_time) 124 | 125 | ## rearrange data 126 | ecog_post = np.zeros((n_chan * n_band, n_time)) 127 | for chan in range(n_chan): 128 | for band in range(n_band): 129 | ecog_post[chan * n_band + band, :] = new_ecog[chan, band, :] 130 | 131 | ecog_post = np.log10(ecog_post) 132 | 133 | import sklearn.preprocessing as skpre 134 | # transform = skpre.QuantileTransformer(output_distribution='normal') 135 | transform = skpre.StandardScaler() 136 | ecog_post = transform.fit_transform(ecog_post.T).T 137 | 138 | savemat(out_fn, 139 | {'ecog_power': new_ecog, 140 | 'ecog_power_time' : new_time, 141 | 'ecog_power_post' : ecog_post, 142 | 'ecog_v_filtered' : ecog_data, 143 | 'ecog_v_time' : ecog_time 144 | } ) 145 | 146 | -------------------------------------------------------------------------------- /src/example_lorenz.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | addpath('~/work/MATLAB/unlocbox') 4 | init_unlocbox(); 5 | 6 | load('../data/lorenz_0_001.mat', 'state', 'step'); 7 | 8 | save_file = 'output_lorenz.mat'; 9 | figdir = '../figures/'; 10 | dt = step; 11 | clear('step'); 12 | 13 | rng(1) 14 | %% Tensor DMD algorithm 15 | 16 | % %% Good params w/o standardization, w/ TV-l0 17 | M = 10; 18 | R = 4; 19 | max_iter = 2000; % iterations 20 | eta1 = 0.01; % Tikhonov/proximal 21 | eta2 = 1e-3; % relaxed vars 22 | beta = 4e2; % regularization 23 | center = 1; 24 | noise = 1.0; 25 | % %% Good w/ standardization 26 | % M = 3; 27 | % R = 6; 28 | % max_iter = 140; % iterations 29 | % eta1 = 1e2; % Tikhonov/proximal 30 | % eta2 = 1e-4; % relaxed vars 31 | % beta = 0.001; % regularization 32 | % center = 0; 33 | % noise = 0.2; 34 | %% Good params w/o standardization or centering 35 | % M = 10; 36 | % R = 6; 37 | % max_iter = 2000; % iterations 38 | % eta1 = 1e-2; % Tikhonov/proximal 39 | % eta2 = 1e-1; % relaxed vars 40 | % beta = 20; % regularization 41 | % center = 1; 42 | % noise = 1.0; 43 | 44 | state = state(:, 1:4200); 45 | %X = state(:, 1:2000); 46 | X = state; 47 | %X = randn(30,3) * state; 48 | % X = standardize(X); 49 | X = X + noise*randn(size(X)); 50 | %X = X - 
repmat(mean(X,2), 1, length(X)); 51 | %X = X ./ repmat(std(X,0,2), 1, length(X)); 52 | 53 | % [lambda, A, B, C, cost, Xten, Yten, rmse] = ... 54 | % tensor_DMD_ALS_smooth(X, M, R, ... 55 | % 'center', center, ... 56 | % 'eta1', eta1, ... 57 | % 'eta2', eta2, ... 58 | % 'beta', beta, ... 59 | % 'max_iter', max_iter); 60 | 61 | % [lambda, A, B, C, cost, Xten, Yten, rmse, W] = ... 62 | % tensor_DMD_ALS_aux(X, M, R, ... 63 | % 'center', center, ... 64 | % 'eta1', eta1, ... 65 | % 'eta2', eta2, ... 66 | % 'beta', beta, ... 67 | % 'proximal', 0, ... 68 | % 'verbosity', 2,... 69 | % 'regularization', 'TV', ... 70 | % 'rtol', 1e-6, ... 71 | % 'atol', 1e-6, ... 72 | % 'max_iter', max_iter); 73 | 74 | [lambda, A, B, C, cost, Xten, Yten, rmse] = ... 75 | TVART_alt_min(X, M, R, ... 76 | 'center', center, ... 77 | 'eta', eta1, ... 78 | 'beta', beta, ... 79 | 'regularization', 'TV', ... 80 | 'verbosity', 2, ... 81 | 'max_iter', max_iter); 82 | 83 | % [lambda, A, B, C, cost, Xten, Yten, rmse] = ... 84 | % tensor_DMD_alt_min(X, M, R, ... 85 | % 'center', center, ... 86 | % 'eta', eta1, ... 87 | % 'beta', beta, ... 88 | % 'regularization', 'TV', ... 89 | % 'rtol', 1e-4, ... 90 | % 'atol', 1e-5, ... 91 | % 'verbosity', 2,... 92 | % 'max_iter', max_iter); 93 | 94 | % [lambda, A, B, C, cost, Xten, Yten, rmse] = ... 95 | % tensor_DMD_alt_min_l0(X, M, R, ... 96 | % 'center', center, ... 97 | % 'eta', eta1, ... 98 | % 'beta', beta, ... 99 | % 'regularization', 'TV0', ... 100 | % 'rtol', 1e-3, ... 101 | % 'atol', 1e-5, ... 102 | % 'verbosity', 2,... 103 | % 'max_iter', max_iter); 104 | 105 | % [lambda, A, B, C, cost, Xten, Yten, rmse, W] = ... 106 | % tensor_DMD_ALS_smooth(X, M, R, ... 107 | % 'center', center, ... 108 | % 'eta1', eta1, ... 109 | % 'eta2', eta2, ... 110 | % 'beta', beta, ... 111 | % 'rtol', 1e-4, ... 112 | % 'atol', 1e-4, ... 113 | % 'max_iter', max_iter); 114 | 115 | A_r = A; B_r = B; C_r = C; lambda_r = lambda; 116 | % %W_r = W; 117 | % lambda_r = lambda; 118 | W_r = C; 119 | %[lambda_r, A_r, B_r, C_r, W_r] = rebalance_2(A, B, C, W_r, 1, 1e-6); 120 | %[lambda_r, A_r, B_r, C_r, W_r] = reorder_components(lambda_r, A_r, B_r, ... 
121 | % C_r, W_r); 122 | 123 | 124 | figure(1); 125 | semilogy(cost); 126 | hold on 127 | semilogy(rmse, 'r-') 128 | legend({'cost', 'RMSE'}) 129 | axis('auto') 130 | 131 | % [~,I] = sort(lambda, 'descend'); 132 | % A = A(:,I); 133 | % B = B(:,I); 134 | % C = C(:,I); 135 | % lambda = lambda(I); 136 | % for i=1:size(C,2) 137 | % if all(C(:,i) < 0) 138 | % C(:,i) = C(:,i) * -1; 139 | % W(:,i) = W(:,i) * -1; 140 | % A(:,i) = A(:,i) * -1; 141 | % end 142 | % end 143 | %% Clustering 144 | %D = compute_pdist(A_r, B_r, C_r*diag(lambda_r)); 145 | D = pdist(C); 146 | Z = linkage(squareform(D), 'complete'); 147 | cluster_ids = cluster(Z, 'maxclust', 3); 148 | T = size(C_r, 1); 149 | N = size(A_r, 1); 150 | xA = zeros(T, N*(N+center)); 151 | for k = 1:T 152 | Akaug = A_r * diag(C_r(k,:)) * B_r'; 153 | xA(k, :) = Akaug(:); 154 | end 155 | %[clust_ids, clust_centroids] = kmedoids(xA, 3); 156 | [clust_ids, clust_centroids] = kmedoids(C, 3); 157 | 158 | 159 | figure(2); 160 | ax1 = subplot(3,1,1); 161 | plot(1:length(X), X', 'linewidth', 0.5, 'color', [0,0,0]+0.4) 162 | title('Observations') 163 | %legend({'x_1', 'x_2', 'x_3'}) 164 | % ax1 = subplot(3,1,2); 165 | % plot(1:length(state), state(1,:)) 166 | % title('Trajectory') 167 | %ylim([-30, 45]) 168 | ax2 = subplot(3,1,2); 169 | plot(M*(1:length(C_r)), C_r, 'linewidth', 0.5); 170 | title('Temporal modes') 171 | %legend({'1', '2', '3', '4'}) 172 | %ylim([0.0475, 0.0491]) 173 | %hold on 174 | %plot(M*(1:length(C_r)), W_r, 'ko'); 175 | % linkaxes([ax0 ax1 ax2], 'x') 176 | ax3 = subplot(3,1,3); 177 | plot(M/2 + (1:M:M*length(C)), cluster_ids, 'ks-'); 178 | ylim([0.8, 3.2]) 179 | %xlabel('Time') 180 | title('Cluster') 181 | xlabel('Time step') 182 | linkaxes([ax1 ax2 ax3], 'x') 183 | figure(2); 184 | xlim([0, 2000]) 185 | %axis tight 186 | %ylim([0.0475, 0.0491]) 187 | set(gcf, 'Color', 'w', 'Position', [100 200 600 700]); 188 | set(gcf, 'PaperUnits', 'inches', ... 
189 | 'PaperPosition', [0 0 6.5 7.5], 'PaperPositionMode', 'manual'); 190 | print('-depsc2', '-loose', [figdir 'example_lorenz.eps']); 191 | 192 | disp(lambda_r) 193 | % fprintf('A_r^T A_r = \n') 194 | % disp(A_r'*A_r) 195 | % fprintf('B^T B = \n') 196 | % disp(B'*B) 197 | % fprintf('C_r^T C_r = \n') 198 | % disp(C_r'*C_r) 199 | 200 | save(save_file) 201 | 202 | 203 | %% Compare scaled Jacobians 204 | s1 = 10; % sigma 205 | s2 = 28; % rho 206 | s3 = 8/3; % beta 207 | dt = 0.001; 208 | fps = [ 0, 0, 0; 209 | sqrt(s3*(s2-1)), sqrt(s3*(s2-1)), s2-1; 210 | -sqrt(s3*(s2-1)), -sqrt(s3*(s2-1)), s2-1]; 211 | Df = @(x) [-s1 s1 0; 212 | s2-x(3) -1 -x(1); 213 | x(2) x(1) -s3]; 214 | 215 | for fp = 1:3 216 | disp('Jacobian + fwd Euler:') 217 | disp([eye(N) + dt * Df(fps(fp,:)), fps(fp,:)'] ) 218 | end 219 | for fp = 1:3 220 | fprintf('\nTVART cluster centroid:\n') 221 | Aclust = A*diag(clust_centroids(fp,:))*B(1:end-1,:)'; 222 | bclust = A*diag(clust_centroids(fp,:))*B(end,:)'; 223 | c = (eye(N) - Aclust) \ bclust; 224 | disp([Aclust, c]) 225 | %disp(reshape(clust_centroids(fp,:), [N, N+center])) 226 | fprintf('\n') 227 | end 228 | -------------------------------------------------------------------------------- /src/smooth_linear.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | 4 | addpath('~/work/MATLAB/') 5 | addpath('~/work/MATLAB/unlocbox') 6 | init_unlocbox(); 7 | figdir = '../figures/'; 8 | 9 | %% Parameters 10 | %% test system 11 | %rng('shuffle'); 12 | %rng(1337) 13 | rng(2) 14 | 15 | N = 10; 16 | M = 1; 17 | r = 4; 18 | num_steps = M * 160 + 1; 19 | noise_std = 0.2; 20 | noise_process = 0.0; 21 | T = floor((num_steps - 1) / M); 22 | noise_compensate = 0; 23 | %% tensor DMD algorithm 24 | algorithm = 'alt_min'; % ALS or prox_grad 25 | regularization = 'spline'; 26 | max_iter = 300; % iterations 27 | offset = 0; 28 | center = 0; 29 | eta = 6. / N; % Tikhonov 30 | beta = 600 * log10(N)^2; % regularization strength 31 | proximal = 0; 32 | save_plots = 1; 33 | 34 | 35 | %% Setup the problem 36 | 37 | [X, thetas, U] = smooth_linear_setup(N, num_steps, noise_std); 38 | %load(sprintf('test_data_N_%d_M_%d_sigma_%f.mat', N, num_steps, ... 39 | % noise_std)); 40 | 41 | %% Preprocessing: 42 | %[U,S,V] = svd(X, 0); 43 | 44 | %% reinitialize RNG 45 | %rng('shuffle') 46 | 47 | %% run tensor DMD 48 | [lambda, A, B, C, cost_vec, Xten, Yten, rmse_vec] = ... 49 | TVART_alt_min(X, M, r, ... 50 | 'center', center, ... 51 | 'eta', eta, ... 52 | 'beta', beta, ... 53 | 'max_iter', max_iter, ... 54 | 'verbosity', 2, ... 55 | 'proximal', proximal, ... 
56 | 'regularization', regularization); 57 | figure(1) 58 | semilogy(cost_vec) 59 | hold on 60 | semilogy(rmse_vec, 'r-') 61 | legend({'cost', 'RMSE'}) 62 | ylabel('cost') 63 | xlabel('iteration') 64 | xlim([0, length(rmse_vec)]) 65 | 66 | %% Figure 2 67 | figure(2) 68 | subplot(2,1,1) 69 | plot(X', 'color', [0,0,0] + 0.4, 'linewidth', 0.5) 70 | axis tight 71 | xlim([1, num_steps]) 72 | xlabel('Time step') 73 | %ylabel('State variable') 74 | title('Observations') 75 | subplot(2,1,2) 76 | plot(C, '-'); 77 | %if strcmp(algorithm, 'ALS_aux') 78 | % hold on 79 | % plot(W, 'ko'); 80 | %end 81 | %set(gca, 'xtick', 1:T) 82 | axis tight 83 | xlim([0.5, T+.5]) 84 | arr = ylim(); 85 | ylim([arr(1) - 0.05, arr(2) + 0.05]) 86 | xlabel('Window') 87 | title('Temporal modes') 88 | figure(2) 89 | set(gcf, 'Color', 'w', 'Position', [100 500 800 600]); 90 | set(gcf, 'PaperUnits', 'inches', 'PaperPosition', [0 0 6.5 4], 'PaperPositionMode', 'manual'); 91 | set(gcf,'renderer','painters') 92 | %export_fig('-depsc2', '-r200', [figdir 'smooth_summary.eps']); 93 | if save_plots 94 | print('-depsc2', '-loose', '-r200', [figdir ... 95 | 'smooth_summary.eps']); 96 | end 97 | %print('-dpng', '-r200', [figdir 'smooth_summary.png']); 98 | 99 | % A = Aten.U{1}; 100 | % B = Aten.U{2}; 101 | % C = Aten.U{3}; 102 | 103 | 104 | %% test closeness 105 | figure(3); 106 | set(gcf, 'position', [246 66 431 1231]); 107 | for k = 1:T 108 | Dk = diag(C(k,:)); 109 | Ak = A * diag(lambda) * Dk * B'; 110 | %Ak = Aten(:, :, k) 111 | ten_MSE = norm(Yten(:,:,k) - Ak * Xten(:,:,k), 'fro')^2 / (N* ... 112 | M); 113 | Ar = [cos(thetas(k+1)) -sin(thetas(k+1)); 114 | sin(thetas(k+1)) cos(thetas(k+1))]; 115 | A_true = U * Ar * U'; 116 | 117 | [D] = eig(Ak(:,1:N)); 118 | [Dtrue] = eig(A_true); 119 | 120 | A_err = max(max(abs(Ak(:, 1:N) - A_true))); 121 | A_err_2 = norm(Ak(:, 1:N) - A_true, 2); 122 | true_MSE = norm(Yten(:,:,k) - A_true * Xten(1:N,:,k), 'fro')^2 / ... 123 | (N*M); 124 | 125 | fprintf('\nWindow %d\n', k) 126 | fprintf('prediction MSE using A_tensor: %1.3g (SNR: %1.2g)\n', ten_MSE, ... 127 | ten_MSE / noise_std^2) 128 | fprintf('prediction MSE using A_truth: %1.3g (SNR: %1.2g)\n', true_MSE,... 129 | true_MSE / noise_std^2) 130 | fprintf('tensor MSE relative to truth: %1.3g%%\n', ... 131 | ten_MSE / true_MSE * 100) 132 | fprintf('A infty error: %1.3g\n', A_err) 133 | fprintf('A 2-error: %1.3g\n', A_err_2) 134 | fprintf('radius(A_tensor) = %1.2g, radius(A_true) = %1.2g\n', ... 135 | max(abs(D)), max(abs(Dtrue))); 136 | 137 | set(0, 'CurrentFigure', 3); 138 | subplot(3,1,1); 139 | imagesc(Ak) 140 | colorbar 141 | colormap(brewermap([], 'PuOr')) 142 | caxis([-1,1]/sqrt(N)) 143 | %caxis auto 144 | title('TVAR system matrix', 'fontweight', 'normal') 145 | axis square 146 | 147 | subplot(3,1,2); 148 | imagesc(A_true); 149 | colorbar 150 | colormap(brewermap([], 'PuOr')) 151 | caxis([-1,1]/sqrt(N)) 152 | %caxis auto 153 | title('True matrix', 'fontweight', 'normal') 154 | axis square 155 | 156 | subplot(3,1,3); 157 | imagesc(Ak(:,1:N) - A_true); 158 | colorbar 159 | colormap(brewermap([], 'PuOr')) 160 | caxis([-1, 1]/sqrt(N)) 161 | %caxis auto 162 | title('Difference', 'fontweight', 'normal') 163 | axis square 164 | 165 | if k == 1 166 | set(gcf, 'Color', 'w', 'Position', [0 0 600 1000]); 167 | set(gcf, 'PaperUnits', 'inches', ... 168 | 'PaperPosition', [0 0 3 7.2], ... 
169 | 'PaperPositionMode', 'manual'); 170 | set(gcf,'renderer','painters') 171 | %print('-dpng', '-painters', '-r200', [figdir 172 | %'smooth_matrices.png']) 173 | if save_plots 174 | print('-depsc2', '-loose', '-r200', [figdir ... 175 | 'smooth_matrices.eps']); 176 | end 177 | %export_fig('-depsc2', '-r200', [figdir 'smooth_matrices.eps']); 178 | % print('-dpdf', '-painters', [figdir 'smooth_matrices.pdf']); 179 | end 180 | 181 | pause 182 | if k ~= T 183 | clf(3) 184 | end 185 | end 186 | 187 | 188 | 189 | %% Figure 4 190 | figure(4) 191 | 192 | set(gcf, 'position', [693 63 560 1240]) 193 | subplot(3,1,1); 194 | imagesc(A); 195 | title('Left spatial modes', 'fontweight', 'normal') 196 | colormap(brewermap([], 'PuOr')) 197 | c = caxis(); 198 | c = max(abs(c)); 199 | caxis([-c, c]) 200 | colorbar 201 | axis square 202 | 203 | subplot(3,1,2); 204 | imagesc(B(1:N,:)); 205 | title('Right spatial modes', 'fontweight', 'normal') 206 | colormap(brewermap([], 'PuOr')) 207 | %caxis([-1, 1]) 208 | c = caxis(); 209 | c = max(abs(c)); 210 | caxis([-c, c]) 211 | colorbar 212 | axis square 213 | 214 | subplot(3,1,3); 215 | imagesc(C); 216 | title('Temporal modes', 'fontweight', 'normal') 217 | colormap(brewermap([], 'PuOr')) 218 | %caxis([-1, 1]) 219 | c = caxis(); 220 | c = max(abs(c)); 221 | caxis([-c, c]) 222 | colorbar 223 | axis square 224 | 225 | set(gcf, 'Color', 'w', 'Position', [0 0 600 1000]); 226 | set(gcf, 'PaperUnits', 'inches', 'PaperPosition', [0 0 3 7.2], ... 227 | 'PaperPositionMode', 'manual'); 228 | set(gcf,'renderer','painters') 229 | if save_plots 230 | print('-depsc2', '-loose', '-r200', [figdir ... 231 | 'smooth_components.eps']); 232 | end 233 | %export_fig('-depsc2', '-r200', [figdir 'smooth_components.eps']); 234 | %print('-dpng', '-r200', [figdir 'smooth_components.png']); 235 | 236 | 237 | 238 | %% Figure 5 239 | figure(5); 240 | scatter(real(Dtrue), imag(Dtrue), 'ob'); 241 | hold on 242 | scatter(real(D), imag(D), '+k'); 243 | plot(cos(linspace(0,1,100)*2*pi), sin(linspace(0,1,100)*2*pi), 'k--') 244 | title('eigenvalues', 'fontweight', 'normal') 245 | xlim([-1.2, 1.2]) 246 | ylim([-1.2, 1.2]) 247 | axis square; 248 | set(gcf, 'Color', 'w', 'Position', [100 200 600 700]); 249 | set(gcf, 'PaperUnits', 'inches', 'PaperPosition', [0 0 4 4.5], ... 250 | 'PaperPositionMode', 'manual'); 251 | set(gcf,'renderer','painters') 252 | %export_fig('-depsc2', '-r200', [figdir 'smooth_eigs.eps']); 253 | if save_plots 254 | print('-depsc2', '-loose', '-r200', [figdir ... 
255 | 'smooth_eigs.eps']); 256 | end 257 | %print('-dpng', '-r200', [figdir 'smooth_eigs.png']); 258 | 259 | 260 | 261 | %% Clustering 262 | %D = compute_pdist(A, B, C*diag(lambda)); 263 | D = pdist(C); 264 | Z = linkage(squareform(D), 'complete'); 265 | cluster_ids = cluster(Z, 'maxclust', 3); 266 | -------------------------------------------------------------------------------- /src/example_el_nino.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | addpath('~/work/MATLAB') 4 | 5 | seed=1; 6 | rng(seed); 7 | 8 | figdir = ['../figures/test_rng_', num2str(seed), '/']; 9 | mkdir(figdir); 10 | 11 | %% base parameters 12 | M = 9; % 4.5 weeks/month * 2 months 13 | R = 6; 14 | downsample = 6; 15 | %% tensor DMD algorithm 16 | max_iter = 300; % iterations 17 | eta = 2e-3; % Tikhonov 18 | beta = 1e4; % regularizatio 19 | center = 0; 20 | %eta2 = 200/beta; % closeness to relaxation variable 21 | %% ALS 22 | nu = 0.0; % ALS noise level 23 | RALS = 1; 24 | %% prox grad 25 | step_size = 1e-5; 26 | save_file = 'test_el_nino_12_uncentered.mat'; 27 | 28 | sst = ncread('../data/sst.wkmean.1990-present.nc', 'sst'); 29 | ls_mask = ncread('../data/lsmask.nc', 'mask'); 30 | time = ncread('../data/sst.wkmean.1990-present.nc', 'time'); 31 | time = time + datenum('1800-1-1 00:00:00'); 32 | time = datetime(time, 'ConvertFrom', 'datenum'); 33 | 34 | %% Simple downsampling 35 | %sst = sst(1:downsample:end, 1:downsample:end, :); 36 | ls_mask = ls_mask(1:downsample:end, 1:downsample:end); 37 | %% Try imresize instead (in loop below) 38 | 39 | tau = size(sst, 3); 40 | N = sum(ls_mask(:) == 1); 41 | X = zeros(N, tau); 42 | 43 | for i = 1:tau 44 | sst_t = sst(:, :, i); 45 | sst_t = imresize(sst_t, 1 / downsample); 46 | sst_t = sst_t(ls_mask == 1); 47 | X(:, i) = sst_t; 48 | end 49 | %X = standardize(X); 50 | 51 | % [lambda, A, B, C, cost, Xten, Yten, rmse] = ... 52 | % tensor_DMD_alt_min_smooth(X, M, R, ... 53 | % 'center', center, ... 54 | % 'eta', eta, ... 55 | % 'beta', beta, ... 56 | % 'max_iter', max_iter,... 57 | % 'verbosity', 2); 58 | 59 | [lambda, A, B, C, cost, Xten, Yten, rmse] = ... 60 | TVART_alt_min(X, M, R, ... 61 | 'center', center, ... 62 | 'eta', eta, ... 63 | 'beta', beta, ... 64 | 'max_iter', 10, ... 65 | 'regularization', 'spline', ... 66 | 'verbosity', 2); 67 | % Set B = A and restart! 68 | if ~center 69 | s = struct('A', A, 'B', A, 'C', C); 70 | else 71 | s = struct('A', A, 'B', [A; B(end,:)], 'C', C); 72 | end 73 | [lambda, A, B, C, cost, Xten, Yten, rmse] = ... 74 | TVART_alt_min(X, M, R, ... 75 | 'center', center, ... 76 | 'eta', eta, ... 77 | 'beta', beta, ... 78 | 'max_iter', max_iter, ... 79 | 'regularization', 'spline', ... 80 | 'verbosity', 2,... 
81 | 'init', s); 82 | 83 | 84 | 85 | %plot((1:size(C,1))*M/52., C); 86 | %xlabel('years') 87 | %[lambda_r, A_r, B_r, C_r] = rebalance(A*diag(lambda), B, C, 1); 88 | A_r = A; B_r = B; C_r = C; %W_r = W; 89 | lambda_r = lambda; 90 | %W = C; 91 | %[lambda_r, A_r, B_r, C_r] = reorder_components(lambda_r, A_r, B_r, C_r); 92 | %[lambda_r, A_r, B_r, C_r, W_r] = rebalance_2(A, B, C, W, 1); 93 | 94 | 95 | [lambda_r, A_r, B_r, C_r] = rebalance(A, B, C, 1); 96 | [lambda_r, A_r, B_r, C_r] = reorder_components(lambda_r, A_r, B_r, C_r); 97 | 98 | 99 | 100 | plot_times = time(floor(M/2):M:end-M+1); 101 | %[soi, soi_t] = get_soi(datenum(plot_times(1)), datenum(plot_times(end))); 102 | [nino, nino_t] = get_nino(datenum(plot_times(1)), datenum(plot_times(end))); 103 | [pdo, pdo_t] = get_pdo(datenum(plot_times(1)), datenum(plot_times(end))); 104 | 105 | save(save_file); 106 | 107 | %% Plotting 108 | figure 109 | semilogy(cost) 110 | 111 | 112 | figure(2); 113 | for mode = 1:R 114 | ax1 = subplot(4,1,1); 115 | plot(plot_times, C_r(:,mode), 'k-'); 116 | %hold on 117 | %plot(plot_times, W_r(:,mode), 'ko'); 118 | title(sprintf('Mode %d, \\lambda = %1.2f\nTemporal mode', mode, ... 119 | lambda_r(mode))); 120 | ylim('auto') 121 | ax2 = subplot(4,1,2); 122 | plot(datetime(nino_t, 'ConvertFrom', 'datenum'), ... 123 | nino(:, 2:2:end), 'color', [0 0 0] + 0.4) 124 | %title(['El Ni' char(241) 'o indices']) 125 | title('ENSO indices'); 126 | xlabel('Year') 127 | ax3 = subplot(4,1,3); 128 | regrid_a = nan(size(ls_mask)); 129 | regrid_a(ls_mask == 1) = A_r(:, mode); 130 | regrid_b = nan(size(ls_mask)); 131 | if center 132 | regrid_b(ls_mask == 1) = B_r(1:end-1, mode); 133 | else 134 | regrid_b(ls_mask == 1) = B_r(1:end, mode); 135 | end 136 | mypcolor = @(x) pcolor([x nan(size(x,1),1); nan(1, size(x,2)+ ... 137 | 1)]); 138 | imAlpha=ones(size(regrid_a')); 139 | imAlpha(isnan(regrid_a'))=0; 140 | imagesc(regrid_a', 'alphadata', imAlpha); 141 | set(gca,'YTickLabel',[]); 142 | set(gca,'XTickLabel',[]); 143 | title('Left spatial mode') 144 | colormap(flipud(brewermap([], 'PuOr'))) 145 | c = max(abs(caxis)); 146 | caxis([-c, c]) 147 | colorbar 148 | ax4 = subplot(4,1,4); 149 | imAlpha=ones(size(regrid_b')); 150 | imAlpha(isnan(regrid_b'))=0; 151 | imagesc(regrid_b', 'alphadata', imAlpha); 152 | %imagesc(regrid_b'); 153 | set(gca,'YTickLabel',[]); 154 | set(gca,'XTickLabel',[]); 155 | title('Right spatial mode') 156 | colormap(flipud(brewermap([], 'PuOr'))) 157 | c = max(abs(caxis)); 158 | caxis([-c, c]) 159 | colorbar 160 | if center 161 | xlabel(sprintf('affine weight = %1.2g, %1.2g%%', ... 162 | B_r(end, mode), ... 163 | abs(B_r(end, mode)) / sum(abs(B_r(end,:))) * 100 ),... 164 | 'fontsize', 14); 165 | end 166 | drawnow 167 | set(gcf, 'Color', 'w', 'Position', [100 200 600 700]); 168 | set(gcf, 'PaperUnits', 'inches', ... 169 | 'PaperPosition', [0 0 6.5 7.5], 'PaperPositionMode', 'manual'); 170 | print('-depsc2', '-loose', '-r200', ... 171 | [figdir 'example_el_nino_mode_' num2str(mode) '.eps']); 172 | % pause 173 | clf 174 | end 175 | 176 | 177 | 178 | figure(3); 179 | for mode = 1:R 180 | ax1 = subplot(4,1,1); 181 | plot(plot_times, C_r(:,mode), 'k-'); 182 | %hold on 183 | %plot(plot_times, W_r(:,mode), 'ko'); 184 | title(sprintf('Mode %d, \\lambda = %1.2f\nTemporal mode', mode, ... 
185 | lambda_r(mode))); 186 | ylim('auto') 187 | ax2 = subplot(4,1,2); 188 | plot(datetime(pdo_t, 'ConvertFrom', 'datenum'), pdo, 'k-'); 189 | %title(['El Ni' char(241) 'o indices']) 190 | title('PDO index'); 191 | xlabel('Year') 192 | ax3 = subplot(4,1,3); 193 | regrid_a = nan(size(ls_mask)); 194 | regrid_a(ls_mask == 1) = A_r(:, mode); 195 | regrid_b = nan(size(ls_mask)); 196 | if center 197 | regrid_b(ls_mask == 1) = B_r(1:end-1, mode); 198 | else 199 | regrid_b(ls_mask == 1) = B_r(1:end, mode); 200 | end 201 | mypcolor = @(x) pcolor([x nan(size(x,1),1); nan(1, size(x,2)+ ... 202 | 1)]); 203 | imAlpha=ones(size(regrid_a')); 204 | imAlpha(isnan(regrid_a'))=0; 205 | imagesc(regrid_a', 'alphadata', imAlpha); 206 | set(gca,'YTickLabel',[]); 207 | set(gca,'XTickLabel',[]); 208 | title('Left spatial mode') 209 | colormap(flipud(brewermap([], 'PuOr'))) 210 | c = max(abs(caxis)); 211 | caxis([-c, c]) 212 | colorbar 213 | ax4 = subplot(4,1,4); 214 | imAlpha=ones(size(regrid_b')); 215 | imAlpha(isnan(regrid_b'))=0; 216 | imagesc(regrid_b', 'alphadata', imAlpha); 217 | %imagesc(regrid_b'); 218 | set(gca,'YTickLabel',[]); 219 | set(gca,'XTickLabel',[]); 220 | title('Right spatial mode') 221 | colormap(flipud(brewermap([], 'PuOr'))) 222 | c = max(abs(caxis)); 223 | caxis([-c, c]) 224 | colorbar 225 | if center 226 | xlabel(sprintf('affine weight = %1.2g, %1.2g%%', ... 227 | B_r(end, mode), ... 228 | abs(B_r(end, mode)) / sum(abs(B_r(end,:))) * 100 ),... 229 | 'fontsize', 14); 230 | end 231 | drawnow 232 | set(gcf, 'Color', 'w', 'Position', [100 200 600 700]); 233 | set(gcf, 'PaperUnits', 'inches', ... 234 | 'PaperPosition', [0 0 6.5 7.5], 'PaperPositionMode', 'manual'); 235 | print('-depsc2', '-loose', '-r200', ... 236 | [figdir 'example_pdo_mode_' num2str(mode) '.eps']); 237 | % pause 238 | clf 239 | end 240 | 241 | -------------------------------------------------------------------------------- /src/smooth_linear_comparison.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | 4 | addpath('~/work/MATLAB/') 5 | addpath('~/work/MATLAB/tensor_toolbox') 6 | addpath('~/work/MATLAB/export_fig') 7 | addpath('~/work/MATLAB/unlocbox') 8 | init_unlocbox(); 9 | figdir = '../figures/'; 10 | 11 | %% Parameters 12 | %% test system 13 | rng(1) 14 | rng(1) 15 | 16 | M = 1; 17 | r = 4; 18 | num_steps = M * 160 + 1; 19 | num_trans = 200; 20 | noise_std = 0.2; 21 | noise_process = 0.0; 22 | T = floor((num_steps - 1) / M); 23 | noise_compensate = 0; 24 | offset = 0; 25 | plotting = 0; 26 | num_rep = 1; 27 | 28 | N_vec = [6,12,24,50,100,200,400,1000,2000,4000]; 29 | %N_vec = [8000]; 30 | 31 | eta_vec = 6 ./ N_vec; 32 | %beta_vec = ones(length(N_vec), 1) * 600; 33 | beta_vec = log10(N_vec).^2 * 600; 34 | output_file = ['../data/smooth_comparison_output_rank' num2str(r) '.csv']; 35 | save_file = ['../data/smooth_comparison_output_rank' num2str(r) '.mat']; 36 | 37 | %% Now do the computation 38 | err_table = zeros(num_rep * length(N_vec), 3 + 5); 39 | row = 0; 40 | N_idx = 0; 41 | elapsed_times = zeros(length(N_vec),1); 42 | for N = N_vec 43 | N_idx = N_idx + 1; 44 | eta = eta_vec(N_idx); 45 | beta = beta_vec(N_idx); 46 | %% Setup the problem 47 | load(sprintf('../data/test_data_smooth_N_%d_M_%d_sigma_%f.mat', N, num_steps, ... 48 | noise_std)); 49 | 50 | %% TVART model 51 | tic 52 | [lambda, A, B, C, cost_vec, Xten, Yten, rmse_vec] = ... 53 | TVART_alt_min(X, M, r, ... 54 | 'eta', eta, ... 55 | 'beta', beta, ... 56 | 'max_iter', num_steps, ... 
57 | 'verbosity', 1, ... 58 | 'rtol', 1e-4,... 59 | 'regularization', 'spline'); 60 | t_elapsed = toc 61 | elapsed_times(N_idx) = t_elapsed; 62 | fprintf('Fitting N = %d took %f s\n', N, t_elapsed); 63 | 64 | %% Independent models 65 | [Aindep] = indep_models(X, M); 66 | 67 | %% Low-rank independent models 68 | [Usvd,Ssvd,Vsvd] = svd(X, 0); 69 | %Xr = U(:, 1:r) * S(1:r, 1:r) * V(:, 1:r)'; 70 | %[Aindep_r] = indep_models(Xr, M); 71 | [Aindep_r] = indep_models(Ssvd(1:r, 1:r) * Vsvd(:, 1:r)', M); 72 | 73 | for model = 1:3 74 | row = row + 1; 75 | err_mse = 0.; 76 | true_mse = 0.; 77 | err_inf = 0.; 78 | err_2 = 0.; 79 | err_fro = 0.; 80 | for t = 1:(num_steps - 1) 81 | k = floor((t - 1) / M) + 1; 82 | % true system matrix 83 | Ar = [cos(thetas(k+1)) -sin(thetas(k+1)); 84 | sin(thetas(k+1)) cos(thetas(k+1))]; 85 | A_true = U * Ar * U'; 86 | 87 | % predicted system matrix 88 | if model == 1 89 | % indep(n) 90 | A_pred = Aindep(:, :, k); 91 | elseif model == 2 92 | % indep(r) 93 | %Ak = Aindep_r(:, :, k); 94 | A_pred = Usvd(:, 1:r) * Aindep_r(:, :, k) * Usvd(:, 1:r)'; 95 | elseif model == 3 96 | % TVART(r) 97 | Dk = diag(C(k,:)); 98 | A_pred = A * diag(lambda) * Dk * B'; 99 | end 100 | % compute errors 101 | xpred = A_pred * X(:, t); 102 | err_mse = err_mse + norm(xpred - X(:, t+1), 2)^2; 103 | true_mse = true_mse + norm(A_true * X(:, t) - X(:, t+1), 2)^2; 104 | err_inf = err_inf + max(abs(A_pred(:) - A_true(:))); 105 | err_2 = err_2 + norm(A_pred - A_true, 2); 106 | err_fro = err_fro + norm(A_pred - A_true, 'fro'); 107 | end % for t 108 | err_mse = err_mse / (N * (num_steps - 1)) 109 | true_mse = true_mse / (N * (num_steps - 1)) 110 | err_inf = err_inf / (num_steps - 1) 111 | err_2 = err_2 / (num_steps - 1) 112 | err_fro = err_fro / (num_steps - 1) 113 | 114 | err_table(row, :) = [N, nan, model, err_inf, err_2, ... 115 | err_fro, err_mse, true_mse]; 116 | end 117 | 118 | % for model = 1:3 119 | % for k = 1:T 120 | % row = row + 1; 121 | % if model == 1 122 | % Ak = Aindep(:, :, k); 123 | % elseif model == 2 124 | % %Ak = Aindep_r(:, :, k); 125 | % Ak = Usvd(:, 1:r) * Aindep_r(:, :, k) * Usvd(:, 1:r)'; 126 | % elseif model == 3 127 | % Dk = diag(C(k,:)); 128 | % Ak = A * diag(lambda) * Dk * B'; 129 | % end 130 | 131 | % Ar = [cos(thetas(k+1)) -sin(thetas(k+1)); 132 | % sin(thetas(k+1)) cos(thetas(k+1))]; 133 | % A_true = U * Ar * U'; 134 | 135 | % [D] = eig(Ak(:, 1:N)); 136 | % [Dtrue] = eig(A_true); 137 | % % compute errors: 138 | % model_MSE = norm(Yten(:,:,k) - Ak * Xten(:,:,k), 'fro')^2 / (N*M); 139 | % true_MSE = norm(Yten(:,:,k) - A_true * Xten(1:N,:,k), 'fro')^2 / ... 140 | % (N*M); 141 | % A_err_inf = max(max(abs(Ak(:, 1:N) - A_true))); 142 | % A_err_2 = norm(Ak(:, 1:N) - A_true, 2); 143 | % A_err_fro = norm(Ak(:, 1:N) - A_true, 'fro'); 144 | 145 | % err_table(row, :) = [N, k, model, A_err_inf, A_err_2, ... 146 | % A_err_fro, model_MSE, true_MSE]; 147 | 148 | % fprintf('\n\tWindow %d\n', k) 149 | % fprintf('\tprediction MSE using A_tensor: %1.3g (SNR: %1.2g)\n', model_MSE, ... 150 | % model_MSE / noise_std^2) 151 | % fprintf('\tprediction MSE using A_truth: %1.3g (SNR: %1.2g)\n', true_MSE,... 152 | % true_MSE / noise_std^2) 153 | % fprintf('\tmodel MSE relative to truth: %1.3g%%\n', ... 154 | % model_MSE / true_MSE * 100) 155 | % fprintf('\tA infty error: %1.3g\n', A_err_inf) 156 | % fprintf('\tA 2-error: %1.3g\n', A_err_2) 157 | % fprintf('\tradius(A_tensor) = %1.2g, radius(A_true) = %1.2g\n', ... 
158 | % max(abs(D)), max(abs(Dtrue))); 159 | 160 | % if plotting 161 | % figure(1) 162 | % set(gcf, 'position', [246 66 431 1231]); 163 | % subplot(3,1,1); 164 | % imagesc(Ak) 165 | % colorbar 166 | % colormap(brewermap([], 'PuOr')) 167 | % caxis([-1,1]) 168 | % title(sprintf('Inferred system matrix, model %d', model), 'fontweight', 'normal') 169 | % axis square 170 | % subplot(3,1,2); 171 | % imagesc(A_true); 172 | % colorbar 173 | % colormap(brewermap([], 'PuOr')) 174 | % caxis([-1,1]) 175 | % title('True matrix', 'fontweight', 'normal') 176 | % axis square 177 | % subplot(3,1,3); 178 | % imagesc(Ak(:,1:N) - A_true); 179 | % colorbar 180 | % colormap(brewermap([], 'PuOr')) 181 | % caxis([-1, 1]) 182 | % title('Difference', 'fontweight', 'normal') 183 | % axis square 184 | % [~] = input('Press any key to continue'); 185 | % if k ~= T 186 | % close 187 | % end 188 | % end 189 | % end 190 | 191 | % fprintf('\n=====================\nSummaries for N = %d\n', N); 192 | % for model = 1:3 193 | % idx = err_table(:, 3) == model & err_table(:, 1) == N; 194 | % errs = mean(err_table(idx, 4:end), 1); 195 | % A_err_inf = errs(1); 196 | % A_err_2 = errs(2); 197 | % A_err_fro = errs(3); 198 | % model_MSE = errs(4); 199 | % true_MSE = errs(5); 200 | % fprintf(['Model %d :\n\tA infinity error = %1.2g\n\tA 2-error = %1.2g\n\t' ... 201 | % 'model MSE = %1.2g\n\ttrue MSE = %1.2g\n'], model, A_err_inf, A_err_2, ... 202 | % model_MSE, true_MSE); 203 | % end 204 | 205 | end % N loop 206 | 207 | %% summary of errors for plotting 208 | error_summary = zeros(3, length(N_vec), 5); 209 | for model = 1:3 210 | for nidx = 1:length(N_vec) 211 | N = N_vec(nidx); 212 | idx = err_table(:, 3) == model & err_table(:, 1) == N; 213 | errs = mean(err_table(idx, 4:end, 1), 1); 214 | error_summary(model, nidx, :) = errs; 215 | end 216 | end 217 | 218 | error_table = array2table(err_table, ... 219 | 'VariableNames', {'N', 'window' 'model', ... 220 | 'err_inf', 'err_2', 'err_fro', 'model_MSE', ... 221 | 'true_MSE'}); 222 | 223 | model_vec = err_table(:,3); 224 | 225 | writetable(error_table, output_file); 226 | 227 | % figure; 228 | % loglog(N_vec, error_summary(:,:,1)) 229 | % legend({'independent', 'rank r indep', 'TVART'}) 230 | % xlabel('N') 231 | % ylabel('||A - A_{true} ||_\infty') 232 | % print('-depsc', [figdir, 'compare_err_inf.eps']) 233 | 234 | % figure; 235 | % loglog(N_vec, error_summary(:,:,2)) 236 | % legend({'independent', 'rank r indep', 'TVART'}) 237 | % xlabel('N') 238 | % ylabel('||A - A_{true} ||_2') 239 | % print('-depsc', [figdir, 'compare_err_2.eps']) 240 | 241 | % figure; 242 | % loglog(N_vec, error_summary(:,:,3)) 243 | % legend({'independent', 'rank r indep', 'TVART'}) 244 | % xlabel('N') 245 | % ylabel('||A - A_{true} ||_{Fro}') 246 | % print('-depsc', [figdir, 'compare_err_fro.eps']) 247 | 248 | save(save_file) -------------------------------------------------------------------------------- /src/ssm_compare.py: -------------------------------------------------------------------------------- 1 | import ssm 2 | import numpy as np 3 | from numpy.linalg import norm, svd 4 | from ssm.util import find_permutation 5 | from scipy.optimize import curve_fit, fsolve 6 | 7 | def fit_arhmm_and_return_errors(X, A1, A2, Kmax=4, num_restarts=1, 8 | num_iters=100, rank=None): 9 | ''' 10 | Fit an ARHMM to test data and return errors. 
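Fits `num_restarts` ARHMMs with `Kmax` discrete states and keeps the run
with the highest final training log-likelihood. If `rank` is not None, the
data are first reduced to `rank` dimensions via the SVD of X, and the fitted
AR matrices are lifted back to the full N-dimensional space before the
errors against A1 and A2 are computed. The true segmentation is taken to
switch from A1 to A2 at time T/2.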
11 | 12 | Parameters 13 | ========== 14 | 15 | X : array, T x N 16 | A1 : array, N x N 17 | A2 : array, N x N 18 | ''' 19 | # hardcoded 20 | true_K = 2 21 | # params 22 | N = X.shape[1] 23 | T = X.shape[0] 24 | 25 | if rank is not None: 26 | # project data down 27 | u, s, vt = np.linalg.svd(X) 28 | Xp = u[:, 0:rank] * s[0:rank] # T x rank matrix 29 | else: 30 | Xp = X 31 | 32 | def _fit_once(): 33 | # fit a model 34 | if rank is not None: 35 | arhmm = ssm.HMM(Kmax, rank, observations="ar") 36 | else: 37 | arhmm = ssm.HMM(Kmax, N, observations="ar") 38 | lls = arhmm.fit(Xp, num_iters=num_iters) 39 | return arhmm, lls 40 | 41 | # Fit num_restarts many models 42 | results = [] 43 | for restart in range(num_restarts): 44 | print("restart ", restart + 1, " / ", num_restarts) 45 | results.append(_fit_once()) 46 | arhmms, llss = list(zip(*results)) 47 | 48 | # Take the ARHMM that achieved the highest training ELBO 49 | best = np.argmax([lls[-1] for lls in llss]) 50 | arhmm, lls = arhmms[best], llss[best] 51 | 52 | # xhat = arhmm.smooth(X) 53 | pred_states = arhmm.most_likely_states(Xp) 54 | 55 | # Align the labels between true and most likely 56 | true_states = np.array([0 if i < T/2 else 1 for i in range(T)]) 57 | arhmm.permute(find_permutation(true_states, pred_states, 58 | true_K, Kmax)) 59 | print("predicted states:") 60 | print(pred_states) 61 | # extract predicted A1, A2 matrices 62 | Ahats, bhats = arhmm.observations.As, arhmm.observations.bs 63 | if rank is not None: 64 | # project back up 65 | Ahats = [ vt[0:rank, :].T @ Ahat @ vt[0:rank, :] for Ahat in Ahats ] 66 | bhats = [ vt[0:rank, :].T @ bhat for bhat in bhats ] 67 | 68 | # A_r = slds.dynamics.As 69 | # b_r = slds.dynamics.bs 70 | # Cs = slds.emissions.Cs[0] 71 | # A1_pred = Cs @ A_r[0] @ np.linalg.pinv(Cs) 72 | # A2_pred = Cs @ A_r[1] @ np.linalg.pinv(Cs) 73 | # compare inferred and true 74 | #err_inf = 0.5 * (np.max(np.abs(A1_pred[:] - A1[:])) + \ 75 | # np.max(np.abs(A2_pred[:] - A2[:]))) 76 | #err_2 = 0.5 * (norm(A1_pred - A1, 2) + \ 77 | # norm(A2_pred - A2, 2)) 78 | #err_fro = 0.5 * (norm(A1_pred - A1, 'fro') + \ 79 | # norm(A2_pred - A2, 'fro')) 80 | err_mse, err_inf, err_2, err_fro = errors(Ahats, bhats, pred_states, true_states, A1, A2, X) 81 | return (err_inf, err_2, err_fro, err_mse, lls) 82 | 83 | def fit_slds_and_return_errors(X, A1, A2, Kmax=4, r=6, num_iters=200, 84 | num_restarts=1, 85 | laplace_em=True, 86 | single_subspace=True, 87 | use_ds=True): 88 | ''' 89 | Fit an SLDS to test data and return errors. 
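Fits `num_restarts` SLDS models with `Kmax` discrete states and `r`
continuous dimensions, using Laplace-EM with a structured mean-field
posterior when `laplace_em` is True and black-box variational inference
with a mean-field posterior otherwise. The fit with the highest final ELBO
is kept, and the latent dynamics are converted to effective data-space
operators via the pseudoinverse of the emission matrix (see
`convert_slds_to_tvart`) before comparison with A1 and A2.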
90 | 91 | Parameters 92 | ========== 93 | 94 | X : array, T x N 95 | A1 : array, N x N 96 | A2 : array, N x N 97 | ''' 98 | # hardcoded 99 | true_K = 2 100 | # params 101 | N = X.shape[1] 102 | T = X.shape[0] 103 | 104 | def _fit_once(): 105 | # fit a model 106 | slds = ssm.SLDS(N, Kmax, r, single_subspace=single_subspace, 107 | emissions='gaussian') 108 | #slds.initialize(X) 109 | #q_mf = SLDSMeanFieldVariationalPosterior(slds, X) 110 | if laplace_em: 111 | elbos, posterior = slds.fit(X, num_iters=num_iters, initialize=True, 112 | method="laplace_em", 113 | variational_posterior="structured_meanfield") 114 | posterior_x = posterior.mean_continuous_states[0] 115 | else: 116 | # Use blackbox + meanfield 117 | elbos, posterior = slds.fit(X, num_iters=num_iters, initialize=True, 118 | method="bbvi", 119 | variational_posterior="mf") 120 | # predict states 121 | return slds, elbos, posterior 122 | 123 | # Fit num_restarts many models 124 | results = [] 125 | for restart in range(num_restarts): 126 | print("restart ", restart + 1, " / ", num_restarts) 127 | results.append(_fit_once()) 128 | sldss, elboss, posteriors = list(zip(*results)) 129 | 130 | # Take the SLDS that achieved the highest training ELBO 131 | best = np.argmax([elbos[-1] for elbos in elboss]) 132 | slds, elbos, posterior = sldss[best], elboss[best], posteriors[best] 133 | 134 | if laplace_em: 135 | posterior_x = posterior.mean_continuous_states[0] 136 | else: 137 | posterior_x = posterior.mean[0] 138 | 139 | # Align the labels between true and most likely 140 | true_states = np.array([0 if i < T/2 else 1 for i in range(T)]) 141 | slds.permute(find_permutation(true_states, 142 | slds.most_likely_states(posterior_x, X), 143 | true_K, Kmax)) 144 | pred_states = slds.most_likely_states(posterior_x, X) 145 | print("predicted states:") 146 | print(pred_states) 147 | # extract predicted A1, A2 matrices 148 | Ahats, bhats = convert_slds_to_tvart(slds) 149 | # A_r = slds.dynamics.As 150 | # b_r = slds.dynamics.bs 151 | # Cs = slds.emissions.Cs[0] 152 | # A1_pred = Cs @ A_r[0] @ np.linalg.pinv(Cs) 153 | # A2_pred = Cs @ A_r[1] @ np.linalg.pinv(Cs) 154 | # compare inferred and true 155 | #err_inf = 0.5 * (np.max(np.abs(A1_pred[:] - A1[:])) + \ 156 | # np.max(np.abs(A2_pred[:] - A2[:]))) 157 | #err_2 = 0.5 * (norm(A1_pred - A1, 2) + \ 158 | # norm(A2_pred - A2, 2)) 159 | #err_fro = 0.5 * (norm(A1_pred - A1, 'fro') + \ 160 | # norm(A2_pred - A2, 'fro')) 161 | err_mse, err_inf, err_2, err_fro = errors(Ahats, bhats, pred_states, true_states, A1, A2, X) 162 | return (err_inf, err_2, err_fro, err_mse, elbos) 163 | 164 | def find_final_iterate(data, rtol): 165 | # Used for estimating convergence in older version 166 | # Runtime comparisons were performed by hand in final 167 | def sigmoid (x, A, x0, slope, C): 168 | return 1 / (1 + np.exp ((x0 - x) / slope)) * A + C 169 | 170 | x = np.arange(len(data)) 171 | y = data / np.std(data) 172 | 173 | pinit = [np.max(y), np.median(x), 1, np.min(y)] 174 | popt, pcov = curve_fit(sigmoid, x, y, pinit, maxfev=10000) 175 | 176 | fmax = popt[0] + popt[3] 177 | if fmax < 0: 178 | thresh = fmax * (1 + rtol) 179 | else: 180 | thresh = fmax * (1 - rtol) 181 | #thresh = popt[3] + 0.999 * popt[0] 182 | f = lambda x: sigmoid(x, *popt) - thresh 183 | maxit = int(fsolve(f, len(data)/2)[0]) 184 | return maxit 185 | 186 | def errors(Ahats, bhats, pred_states, true_states, A1, A2, X): 187 | # params 188 | N = X.shape[1] 189 | T = X.shape[0] 190 | assert len(pred_states) == T, "pred_states must be length T" 191 | assert 
len(true_states) == T, "true_states must be length T" 192 | N = X.shape[0] 193 | err_mse = 0. 194 | err_inf = 0. 195 | err_2 = 0. 196 | err_fro = 0. 197 | for t in range(T - 1): 198 | if true_states[t] == 0: 199 | A_true = A1 200 | else: 201 | A_true = A2 202 | A_pred = Ahats[pred_states[t]] 203 | b_pred = bhats[pred_states[t]] 204 | xpred = A_pred @ X[t, :].T + b_pred 205 | # A_r = slds.dynamics.As[pred_states[t]] 206 | # A_pred = Cs @ A_r @ np.linalg.pinv(Cs) 207 | # xpred = A_pred @ X[t, :].T + Cs @ b_r[pred_states[t]] 208 | err_mse += norm(xpred - X[t+1, :], 2)**2 209 | err_inf += np.max(np.abs(A_pred[:] - A_true[:])) 210 | err_2 += norm(A_pred - A_true, 2) 211 | err_fro += norm(A_pred - A_true, 'fro') 212 | err_mse /= float(N * (T - 1.)) 213 | err_inf /= float(T - 1.) 214 | err_2 /= float(T - 1.) 215 | err_fro /= float(T - 1.) 216 | return err_mse, err_inf, err_2, err_fro 217 | 218 | def convert_slds_to_tvart(slds, use_ds=True): 219 | # This code modified from that provided by Scott Linderman 220 | # Compare the true and inferred parameters 221 | Cs, ds = slds.emissions.Cs, slds.emissions.ds 222 | As, bs = slds.dynamics.As, slds.dynamics.bs 223 | single_subspace = slds.emissions.single_subspace 224 | 225 | # Use the pseudoinverse of C to project down to latent space 226 | Cinvs = np.linalg.pinv(Cs, rcond=1e-8) 227 | 228 | if single_subspace: 229 | Cs = np.repeat(Cs, slds.K, axis=0) 230 | Cinvs = np.repeat(Cinvs, slds.K, axis=0) 231 | ds = np.repeat(ds, slds.K, axis=0) 232 | 233 | # Compute the effective transition operator on the data 234 | Aeffs = np.matmul(Cs, np.matmul(As, Cinvs)) 235 | # Compute effective affine/intercept term 236 | if use_ds: 237 | beffs = ds[:, :, None] - np.matmul(Aeffs, ds[:, :, None]) \ 238 | + np.matmul(Cs, bs[:, :, None]) 239 | else: 240 | beffs = np.matmul(Cs, bs[:, :, None]) 241 | return Aeffs, beffs 242 | -------------------------------------------------------------------------------- /data/PDO.txt: -------------------------------------------------------------------------------- 1 | 1900 0.04 1.32 0.49 0.35 0.77 0.65 0.95 0.14 -0.24 0.23 -0.44 1.19 2 | 1901 0.79 -0.12 0.35 0.61 -0.42 -0.05 -0.60 -1.20 -0.33 0.16 -0.60 -0.14 3 | 1902 0.82 1.58 0.48 1.37 1.09 0.52 1.58 1.57 0.44 0.70 0.16 -1.10 4 | 1903 0.86 -0.24 -0.22 -0.50 0.43 0.23 0.40 1.01 -0.24 0.18 0.08 -0.03 5 | 1904 0.63 -0.91 -0.71 -0.07 -0.22 -1.53 -1.58 -0.64 0.06 0.43 1.45 0.06 6 | 1905 0.73 0.91 1.31 1.59 -0.07 0.69 0.85 1.26 -0.03 -0.15 1.11 -0.50 7 | 1906 0.92 1.18 0.83 0.74 0.44 1.24 0.09 -0.53 -0.31 0.08 1.69 -0.54 8 | 1907 -0.30 -0.32 -0.19 -0.16 0.16 0.57 0.63 -0.96 -0.23 0.84 0.66 0.72 9 | 1908 1.36 1.02 0.67 0.23 0.23 0.41 0.60 -1.04 -0.16 -0.41 0.47 1.16 10 | 1909 0.23 1.01 0.54 0.24 -0.39 -0.64 -0.39 -0.68 -0.89 -0.02 -0.40 -0.01 11 | 1910 -0.25 -0.70 0.18 -0.37 -0.06 -0.28 0.03 -0.06 0.40 -0.66 0.02 0.84 12 | 1911 -1.11 0.00 -0.78 -0.73 0.17 0.02 0.48 0.43 0.29 0.20 -0.86 0.01 13 | 1912 -1.72 -0.23 -0.04 -0.38 -0.02 0.77 1.07 -0.84 0.94 0.56 0.74 0.98 14 | 1913 -0.03 0.34 0.06 -0.92 0.66 1.43 1.06 1.29 0.73 0.62 0.75 0.90 15 | 1914 0.34 -0.29 0.08 1.20 0.11 0.11 -0.21 0.11 -0.34 -0.11 0.03 0.89 16 | 1915 -0.41 0.14 -1.22 1.40 0.32 0.99 1.07 0.27 -0.05 -0.43 -0.12 0.17 17 | 1916 -0.64 -0.19 -0.11 0.35 0.42 -0.82 -0.78 -0.73 -0.77 -0.22 -0.68 -1.94 18 | 1917 -0.79 -0.84 -0.71 -0.34 0.82 -0.03 0.10 -0.22 -0.40 -1.75 -0.34 -0.60 19 | 1918 -1.13 -0.66 -1.15 -0.32 -0.33 0.07 0.98 -0.31 -0.59 0.61 0.34 0.86 20 | 1919 -1.07 1.31 -0.50 0.08 0.17 -0.71 -0.47 0.38 0.06 -0.42 -0.80 0.76 21 | 
1920 -1.18 0.06 -0.78 -1.29 -0.97 -1.30 -0.90 -2.21 -1.28 -1.06 -0.26 0.29 22 | 1921 -0.66 -0.61 -0.01 -0.93 -0.42 0.40 -0.58 -0.69 -0.78 -0.23 1.92 1.42 23 | 1922 1.05 -0.85 0.08 0.43 -0.19 -1.04 -0.82 -0.93 -0.81 0.84 -0.60 0.48 24 | 1923 0.75 -0.04 0.49 0.99 -0.20 0.68 1.16 0.84 -0.24 1.10 0.62 -0.36 25 | 1924 1.29 0.73 1.13 -0.02 0.36 0.75 -0.55 -0.67 -0.48 -1.25 0.24 0.11 26 | 1925 -0.05 -0.14 0.20 0.86 0.79 -1.08 -0.06 -0.86 0.52 0.04 0.88 1.19 27 | 1926 0.30 0.98 -0.50 2.10 1.43 2.03 1.05 1.64 1.18 1.65 1.00 1.06 28 | 1927 1.07 1.73 0.15 -0.18 0.30 0.69 -0.31 -0.73 -0.41 -0.62 -0.07 0.07 29 | 1928 0.96 0.79 0.52 0.81 0.66 0.15 0.30 -0.72 -1.41 -1.31 0.14 0.98 30 | 1929 0.97 0.52 0.50 0.55 1.07 0.50 -0.06 -0.69 0.45 -0.21 1.24 -0.03 31 | 1930 0.97 -1.06 -0.43 -0.70 0.06 0.58 -0.45 -0.53 -0.20 -0.38 -0.31 1.20 32 | 1931 0.08 1.56 1.13 1.28 1.66 0.39 1.49 0.02 -0.01 -0.17 0.34 1.09 33 | 1932 -0.26 -0.58 0.51 1.15 0.64 0.10 -0.12 -0.14 -0.40 -0.29 -0.88 0.02 34 | 1933 0.29 0.02 0.15 -0.05 -0.50 -0.68 -1.81 -1.56 -2.28 -1.19 0.55 -1.10 35 | 1934 0.17 0.68 1.34 1.63 1.23 0.51 0.44 1.54 1.25 2.10 1.63 1.67 36 | 1935 1.01 0.79 -0.11 1.10 0.99 1.39 0.68 0.63 0.98 0.21 0.13 1.78 37 | 1936 1.79 1.75 1.36 1.32 1.83 2.37 2.57 1.71 0.04 2.10 2.65 1.28 38 | 1937 0.00 -0.49 0.38 0.20 0.53 1.75 0.11 -0.35 0.63 0.76 -0.18 0.55 39 | 1938 0.50 0.02 0.24 0.27 -0.25 -0.20 -0.21 -0.45 -0.01 0.07 0.48 1.40 40 | 1939 1.36 0.07 -0.39 0.45 0.98 1.04 -0.21 -0.74 -1.10 -1.31 -0.88 1.51 41 | 1940 2.03 1.74 1.89 2.37 2.32 2.43 2.12 1.40 1.10 1.19 0.68 1.96 42 | 1941 2.14 2.07 2.41 1.89 2.25 3.01 2.33 3.31 1.99 1.22 0.40 0.91 43 | 1942 1.01 0.79 0.29 0.79 0.84 1.19 0.12 0.44 0.68 0.54 -0.10 -1.00 44 | 1943 -0.18 0.02 0.26 1.08 0.43 0.68 -0.36 -0.90 -0.49 -0.04 0.29 0.58 45 | 1944 0.18 0.17 0.08 0.72 -0.35 -0.98 -0.40 -0.51 -0.56 -0.40 0.33 0.20 46 | 1945 -1.02 0.72 -0.42 -0.40 -0.07 0.56 1.02 0.18 -0.27 0.10 -1.94 -0.74 47 | 1946 -0.91 -0.32 -0.41 -0.78 0.50 -0.86 -0.84 -0.36 -0.22 -0.36 -1.48 -0.96 48 | 1947 -0.73 -0.29 1.17 0.70 0.37 1.36 0.16 0.30 0.58 0.85 -0.14 1.67 49 | 1948 -0.11 -0.74 -0.03 -1.33 -0.23 0.08 -0.92 -1.56 -1.74 -1.32 -0.89 -1.70 50 | 1949 -2.01 -3.60 -1.00 -0.53 -1.07 -0.70 -0.56 -1.30 -0.93 -1.41 -0.83 -0.80 51 | 1950 -2.13 -2.91 -1.13 -1.20 -2.23 -1.77 -2.93 -0.70 -2.14 -1.36 -2.46 -0.76 52 | 1951 -1.54 -1.06 -1.90 -0.36 -0.25 -1.09 0.70 -1.37 -0.08 -0.32 -0.28 -1.68 53 | 1952 -2.01 -0.46 -0.63 -1.05 -1.00 -1.43 -1.25 -0.60 -0.89 -0.35 -0.76 0.04 54 | 1953 -0.57 -0.07 -1.12 0.05 0.43 0.29 0.74 0.05 -0.63 -1.09 -0.03 0.07 55 | 1954 -1.32 -1.61 -0.52 -1.33 0.01 0.97 0.43 0.08 -0.94 0.52 0.72 -0.50 56 | 1955 0.20 -1.52 -1.26 -1.97 -1.21 -2.44 -2.35 -2.25 -1.95 -2.80 -3.08 -2.75 57 | 1956 -2.48 -2.74 -2.56 -2.17 -1.41 -1.70 -1.03 -1.16 -0.71 -2.30 -2.11 -1.28 58 | 1957 -1.82 -0.68 0.03 -0.58 0.57 1.76 0.72 0.51 1.59 1.50 -0.32 -0.55 59 | 1958 0.25 0.62 0.25 1.06 1.28 1.33 0.89 1.06 0.29 0.01 -0.18 0.86 60 | 1959 0.69 -0.43 -0.95 -0.02 0.23 0.44 -0.50 -0.62 -0.85 0.52 1.11 0.06 61 | 1960 0.30 0.52 -0.21 0.09 0.91 0.64 -0.27 -0.38 -0.94 0.09 -0.23 0.17 62 | 1961 1.18 0.43 0.09 0.34 -0.06 -0.61 -1.22 -1.13 -2.01 -2.28 -1.85 -2.69 63 | 1962 -1.29 -1.15 -1.42 -0.80 -1.22 -1.62 -1.46 -0.48 -1.58 -1.55 -0.37 -0.96 64 | 1963 -0.33 -0.16 -0.54 -0.41 -0.65 -0.88 -1.00 -1.03 0.45 -0.52 -2.08 -1.08 65 | 1964 0.01 -0.21 -0.87 -1.03 -1.91 -0.32 -0.51 -1.03 -0.68 -0.37 -0.80 -1.52 66 | 1965 -1.24 -1.16 0.04 0.62 -0.66 -0.80 -0.47 0.20 0.59 -0.36 -0.59 0.06 67 | 1966 -0.82 -0.03 -1.29 0.06 -0.53 0.16 0.26 
-0.35 -0.33 -1.17 -1.15 -0.32 68 | 1967 -0.20 -0.18 -1.20 -0.89 -1.24 -1.16 -0.89 -1.24 -0.72 -0.64 -0.05 -0.40 69 | 1968 -0.95 -0.40 -0.31 -1.03 -0.53 -0.35 0.53 0.19 0.06 -0.34 -0.44 -1.27 70 | 1969 -1.26 -0.95 -0.50 -0.44 -0.20 0.89 0.10 -0.81 -0.66 1.12 0.15 1.38 71 | 1970 0.61 0.43 1.33 0.43 -0.49 0.06 -0.68 -1.63 -1.67 -1.39 -0.80 -0.97 72 | 1971 -1.90 -1.74 -1.68 -1.59 -1.55 -1.55 -2.20 -0.15 0.21 -0.22 -1.25 -1.87 73 | 1972 -1.99 -1.83 -2.09 -1.65 -1.57 -1.87 -0.83 0.25 0.17 0.11 0.57 -0.33 74 | 1973 -0.46 -0.61 -0.50 -0.69 -0.76 -0.97 -0.57 -1.14 -0.51 -0.87 -1.81 -0.76 75 | 1974 -1.22 -1.65 -0.90 -0.52 -0.28 -0.31 -0.08 0.27 0.44 -0.10 0.43 -0.12 76 | 1975 -0.84 -0.71 -0.51 -1.30 -1.02 -1.16 -0.40 -1.07 -1.23 -1.29 -2.08 -1.61 77 | 1976 -1.14 -1.85 -0.96 -0.89 -0.68 -0.67 0.61 1.28 0.82 1.11 1.25 1.22 78 | 1977 1.65 1.11 0.72 0.30 0.31 0.42 0.19 0.64 -0.55 -0.61 -0.72 -0.69 79 | 1978 0.34 1.45 1.34 1.29 0.90 0.15 -1.24 -0.56 -0.44 0.10 -0.07 -0.43 80 | 1979 -0.58 -1.33 0.30 0.89 1.09 0.17 0.84 0.52 1.00 1.06 0.48 -0.42 81 | 1980 -0.11 1.32 1.09 1.49 1.20 -0.22 0.23 0.51 0.10 1.35 0.37 -0.10 82 | 1981 0.59 1.46 0.99 1.45 1.75 1.69 0.84 0.18 0.42 0.18 0.80 0.67 83 | 1982 0.34 0.20 0.19 -0.19 -0.58 -0.78 0.58 0.39 0.84 0.37 -0.25 0.26 84 | 1983 0.56 1.14 2.11 1.87 1.80 2.36 3.51 1.85 0.91 0.96 1.02 1.69 85 | 1984 1.50 1.21 1.77 1.52 1.30 0.18 -0.18 -0.03 0.67 0.58 0.71 0.82 86 | 1985 1.27 0.94 0.57 0.19 0.00 0.18 1.07 0.81 0.44 0.29 -0.75 0.38 87 | 1986 1.12 1.61 2.18 1.55 1.16 0.89 1.38 0.22 0.22 1.00 1.77 1.77 88 | 1987 1.88 1.75 2.10 2.16 1.85 0.73 2.01 2.83 2.44 1.36 1.47 1.27 89 | 1988 0.93 1.24 1.42 0.94 1.20 0.74 0.64 0.19 -0.37 -0.10 -0.02 -0.43 90 | 1989 -0.95 -1.02 -0.83 -0.32 0.47 0.36 0.83 0.09 0.05 -0.12 -0.50 -0.21 91 | 1990 -0.30 -0.65 -0.62 0.27 0.44 0.44 0.27 0.11 0.38 -0.69 -1.69 -2.23 92 | 1991 -2.02 -1.19 -0.74 -1.01 -0.51 -1.47 -0.10 0.36 0.65 0.49 0.42 0.09 93 | 1992 0.05 0.31 0.67 0.75 1.54 1.26 1.90 1.44 0.83 0.93 0.93 0.53 94 | 1993 0.05 0.19 0.76 1.21 2.13 2.34 2.35 2.69 1.56 1.41 1.24 1.07 95 | 1994 1.21 0.59 0.80 1.05 1.23 0.46 0.06 -0.79 -1.36 -1.32 -1.96 -1.79 96 | 1995 -0.49 0.46 0.75 0.83 1.46 1.27 1.71 0.21 1.16 0.47 -0.28 0.16 97 | 1996 0.59 0.75 1.01 1.46 2.18 1.10 0.77 -0.14 0.24 -0.33 0.09 -0.03 98 | 1997 0.23 0.28 0.65 1.05 1.83 2.76 2.35 2.79 2.19 1.61 1.12 0.67 99 | 1998 0.83 1.56 2.01 1.27 0.70 0.40 -0.04 -0.22 -1.21 -1.39 -0.52 -0.44 100 | 1999 -0.32 -0.66 -0.33 -0.41 -0.68 -1.30 -0.66 -0.96 -1.53 -2.23 -2.05 -1.63 101 | 2000 -2.00 -0.83 0.29 0.35 -0.05 -0.44 -0.66 -1.19 -1.24 -1.30 -0.53 0.52 102 | 2001 0.60 0.29 0.45 -0.31 -0.30 -0.47 -1.31 -0.77 -1.37 -1.37 -1.26 -0.93 103 | 2002 0.27 -0.64 -0.43 -0.32 -0.63 -0.35 -0.31 0.60 0.43 0.42 1.51 2.10 104 | 2003 2.09 1.75 1.51 1.18 0.89 0.68 0.96 0.88 0.01 0.83 0.52 0.33 105 | 2004 0.43 0.48 0.61 0.57 0.88 0.04 0.44 0.85 0.75 -0.11 -0.63 -0.17 106 | 2005 0.44 0.81 1.36 1.03 1.86 1.17 0.66 0.25 -0.46 -1.32 -1.50 0.20 107 | 2006 1.03 0.66 0.05 0.40 0.48 1.04 0.35 -0.65 -0.94 -0.05 -0.22 0.14 108 | 2007 0.01 0.04 -0.36 0.16 -0.10 0.09 0.78 0.50 -0.36 -1.45 -1.08 -0.58 109 | 2008 -1.00 -0.77 -0.71 -1.52 -1.37 -1.34 -1.67 -1.70 -1.55 -1.76 -1.25 -0.87 110 | 2009 -1.40 -1.55 -1.59 -1.65 -0.88 -0.31 -0.53 0.09 0.52 0.27 -0.40 0.08 111 | 2010 0.83 0.82 0.44 0.78 0.62 -0.22 -1.05 -1.27 -1.61 -1.06 -0.82 -1.21 112 | 2011 -0.92 -0.83 -0.69 -0.42 -0.37 -0.69 -1.86 -1.74 -1.79 -1.34 -2.33 -1.79 113 | 2012 -1.38 -0.85 -1.05 -0.27 -1.26 -0.87 -1.52 -1.93 -2.21 -0.79 -0.59 -0.48 114 | 2013 -0.13 -0.43 -0.63 
-0.16 0.08 -0.78 -1.25 -1.04 -0.48 -0.87 -0.11 -0.41 115 | 2014 0.30 0.38 0.97 1.13 1.80 0.82 0.70 0.67 1.08 1.49 1.72 2.51 116 | 2015 2.45 2.30 2.00 1.44 1.20 1.54 1.84 1.56 1.94 1.47 0.86 1.01 117 | 2016 1.53 1.75 2.40 2.62 2.35 2.03 1.25 0.52 0.45 0.56 1.88 1.17 118 | 2017 0.77 0.70 0.74 1.12 0.88 0.79 0.10 0.09 0.32 0.05 0.15 0.50 119 | -------------------------------------------------------------------------------- /src/switching_linear_comparison.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | 4 | addpath('~/work/MATLAB/') 5 | %addpath('~/work/MATLAB/tensor_toolbox') 6 | %addpath('~/work/MATLAB/export_fig') 7 | addpath('~/work/MATLAB/unlocbox') 8 | init_unlocbox(); 9 | figdir = '../figures/'; 10 | 11 | %% Parameters 12 | %% test system 13 | rng(1) 14 | 15 | M = 20; 16 | r = 4; 17 | num_rep = 3; 18 | num_steps = M * 100 + 1; 19 | %num_steps = M * 10 + 1; 20 | num_trans = 200; 21 | noise_std = 0.5; 22 | noise_process = 0.0; 23 | T = floor((num_steps - 1) / M); 24 | noise_compensate = 0; 25 | offset = 0; 26 | %% tensor DMD algorithm 27 | regularization = 'TV'; 28 | max_iter = 300; % iterations 29 | center = 0; 30 | eta = 1.; % Tikhonov 31 | beta = 1.; % regularization strength 32 | proximal = 0; 33 | plotting = 0; 34 | 35 | N_vec = [10,14,18,20,30,50,80,100,200,400]; %,2000,4000]; 36 | %eta_vec = linspace(1, 0.01, length(N_vec)); 37 | %eta_vec = logspace(0, -2, length(N_vec)); 38 | eta_vec = 1 ./ N_vec; 39 | beta_vec = ones(length(N_vec), 1); 40 | % output_file = '../data/comparison_output_rank6.csv'; 41 | % save_file = '../data/comparison_output_rank6.mat'; 42 | % output_file = '../data/comparison_output_rank4.csv'; 43 | % save_file = '../data/comparison_output_rank4.mat'; 44 | output_file = '../data/comparison_output_2000_rank6.csv'; 45 | save_file = '../data/comparison_output_2000_rank6.mat'; 46 | 47 | %% Now do the computation 48 | err_table = zeros(num_rep * length(N_vec), 3 + 5); 49 | row = 0; 50 | N_idx = 0; 51 | elapsed_times = zeros(length(N_vec), num_rep); 52 | for N = N_vec 53 | N_idx = N_idx + 1; 54 | for rep = 1:num_rep 55 | % if N > 40 56 | % eta = 0.; 57 | % %beta = 4; 58 | % end 59 | % if N > 500 60 | % eta = 0.05; 61 | % %beta = 6; 62 | % end 63 | eta = eta_vec(N_idx); 64 | beta = beta_vec(N_idx); 65 | %% Setup the problem 66 | %save(sprintf('test_data_N_%d_M_%d_sigma_%f.mat', N, num_steps, ... 67 | % noise_std), 'X', 'A1', 'A2'); 68 | load(sprintf('../data/test_data_N_%d_M_%d_sigma_%f_rep_%d.mat', N, num_steps, ... 69 | noise_std, rep)); 70 | 71 | %% TVART model 72 | tic 73 | [lambda, A, B, C, cost_vec, Xten, Yten, rmse_vec] = ... 74 | TVART_alt_min(X, M, r, ... 75 | 'center', center, ... 76 | 'eta', eta, ... 77 | 'beta', beta, ... 78 | 'max_iter', max_iter, ... 79 | 'verbosity', 1, ... 80 | 'rtol', 1e-4,... 81 | 'proximal', proximal, ... 
82 | 'regularization', regularization); 83 | t_elapsed = toc 84 | elapsed_times(N_idx, rep) = t_elapsed; 85 | fprintf('Fitting N = %d took %f s\n', N, t_elapsed); 86 | 87 | 88 | %% Independent models 89 | [Aindep] = indep_models(X, M); 90 | 91 | %% Low-rank independent models 92 | [U,S,V] = svd(X, 0); 93 | %Xr = U(:, 1:r) * S(1:r, 1:r) * V(:, 1:r)'; 94 | %[Aindep_r] = indep_models(Xr, M); 95 | [Aindep_r] = indep_models(S(1:r, 1:r) * V(:, 1:r)' , M); 96 | 97 | for model = 1:3 98 | row = row + 1; 99 | err_mse = 0.; 100 | true_mse = 0.; 101 | err_inf = 0.; 102 | err_2 = 0.; 103 | err_fro = 0.; 104 | for t = 1:(num_steps - 1) 105 | k = floor((t - 1) / M) + 1; 106 | % true system matrix 107 | if t < num_steps / 2 108 | A_true = A1; 109 | else 110 | A_true = A2; 111 | end 112 | % predicted system matrix 113 | if model == 1 114 | % indep(n) 115 | A_pred = Aindep(:, :, k); 116 | elseif model == 2 117 | % indep(r) 118 | %Ak = Aindep_r(:, :, k); 119 | A_pred = U(:, 1:r) * Aindep_r(:, :, k) * U(:, 1:r)'; 120 | elseif model == 3 121 | % TVART(r) 122 | Dk = diag(C(k,:)); 123 | A_pred = A * diag(lambda) * Dk * B'; 124 | end 125 | % compute errors 126 | xpred = A_pred * X(:, t); 127 | err_mse = err_mse + norm(xpred - X(:, t+1), 2)^2; 128 | true_mse = true_mse + norm(A_true * X(:, t) - X(:, t+1), 2)^2; 129 | err_inf = err_inf + max(abs(A_pred(:) - A_true(:))); 130 | err_2 = err_2 + norm(A_pred - A_true, 2); 131 | err_fro = err_fro + norm(A_pred - A_true, 'fro'); 132 | end % for t 133 | if model == 1 134 | disp('---full rank model---') 135 | elseif model == 2 136 | disp('---low rank model---') 137 | elseif model == 3 138 | disp('---TVART model---') 139 | end 140 | err_mse = err_mse / (N * (num_steps - 1)) 141 | true_mse = true_mse / (N * (num_steps - 1)) 142 | err_inf = err_inf / (num_steps - 1) 143 | err_2 = err_2 / (num_steps - 1) 144 | err_fro = err_fro / (num_steps - 1) 145 | 146 | err_table(row, :) = [N, nan, model, err_inf, err_2, ... 147 | err_fro, err_mse, true_mse]; 148 | % for k = 1:T 149 | % row = row + 1; 150 | % if model == 1 151 | % % indep(n) 152 | % Ak = Aindep(:, :, k); 153 | % elseif model == 2 154 | % % indep(r) 155 | % %Ak = Aindep_r(:, :, k); 156 | % Ak = U(:, 1:r) * Aindep_r(:, :, k) * U(:, 1:r)'; 157 | % elseif model == 3 158 | % % TVART(r) 159 | % Dk = diag(C(k,:)); 160 | % Ak = A * diag(lambda) * Dk * B'; 161 | % end 162 | % if k < T/2 163 | % % system 1 164 | % A_true = A1; 165 | % system = 1; 166 | % elseif k == T/2 && ~mod(T,2) 167 | % A_true = (A1 + A2)/2; 168 | % system = 1.5; 169 | % else 170 | % % system 2 171 | % A_true = A2; 172 | % system = 2; 173 | % end 174 | 175 | % [D] = eig(Ak(:, 1:N)); 176 | % [Dtrue] = eig(A_true); 177 | % % compute errors: 178 | % model_MSE = norm(Yten(:,:,k) - Ak * Xten(:,:,k), 'fro')^2 / (N*M); 179 | % true_MSE = norm(Yten(:,:,k) - A_true * Xten(1:N,:,k), 'fro')^2 / ... 180 | % (N*M); 181 | % A_err_inf = max(max(abs(Ak(:, 1:N) - A_true))); 182 | % A_err_2 = norm(Ak(:, 1:N) - A_true, 2); 183 | % A_err_fro = norm(Ak(:, 1:N) - A_true, 'fro'); 184 | 185 | % err_table(row, :) = [N, k, model, A_err_inf, A_err_2, ... 186 | % A_err_fro, model_MSE, true_MSE]; 187 | 188 | % fprintf('\n\tWindow %d, system %d\n', k, system) 189 | % fprintf('\tprediction MSE using A_tensor: %1.3g (SNR: %1.2g)\n', model_MSE, ... 190 | % model_MSE / noise_std^2) 191 | % fprintf('\tprediction MSE using A_truth: %1.3g (SNR: %1.2g)\n', true_MSE,... 192 | % true_MSE / noise_std^2) 193 | % fprintf('\tmodel MSE relative to truth: %1.3g%%\n', ... 
194 | % model_MSE / true_MSE * 100) 195 | % fprintf('\tA infty error: %1.3g\n', A_err_inf) 196 | % fprintf('\tA 2-error: %1.3g\n', A_err_2) 197 | % fprintf('\tradius(A_tensor) = %1.2g, radius(A_true) = %1.2g\n', ... 198 | % max(abs(D)), max(abs(Dtrue))); 199 | 200 | % if plotting 201 | % figure(1) 202 | % set(gcf, 'position', [246 66 431 1231]); 203 | % subplot(3,1,1); 204 | % imagesc(Ak) 205 | % colorbar 206 | % colormap(brewermap([], 'PuOr')) 207 | % caxis([-1,1]) 208 | % title(sprintf('Inferred system matrix, model %d', model), 'fontweight', 'normal') 209 | % axis square 210 | % subplot(3,1,2); 211 | % imagesc(A_true); 212 | % colorbar 213 | % colormap(brewermap([], 'PuOr')) 214 | % caxis([-1,1]) 215 | % title(['True matrix (system ' num2str(system) ')'], 'fontweight', 'normal') 216 | % axis square 217 | % subplot(3,1,3); 218 | % imagesc(Ak(:,1:N) - A_true); 219 | % colorbar 220 | % colormap(brewermap([], 'PuOr')) 221 | % caxis([-1, 1]) 222 | % title('Difference', 'fontweight', 'normal') 223 | % axis square 224 | % [~] = input('Press any key to continue'); 225 | % if k ~= T 226 | % close 227 | % end 228 | % end 229 | end % model 230 | 231 | % for model = 1:3 232 | % idx = err_table(:, 3) == model & err_table(:, 1) == N; 233 | % errs = mean(err_table(idx, 4:end), 1); 234 | % A_err_inf = errs(1); 235 | % A_err_2 = errs(2); 236 | % A_err_fro = errs(3); 237 | % model_MSE = errs(4); 238 | % true_MSE = errs(5); 239 | % fprintf(['Model %d :\n\tA infinity error = %1.2g\n\tA 2-error = %1.2g\n\t' ... 240 | % 'model MSE = %1.2g\n\ttrue MSE = %1.2g\n'], model, A_err_inf, A_err_2, ... 241 | % model_MSE, true_MSE); 242 | % end 243 | end % rep loop 244 | end % N loop 245 | 246 | % %% summary of errors for plotting 247 | % error_summary = zeros(3, length(N_vec), 5); 248 | % for model = 1:3 249 | % for nidx = 1:length(N_vec) 250 | % N = N_vec(nidx); 251 | % idx = err_table(:, 3) == model & err_table(:, 1) == N; 252 | % errs = mean(err_table(idx, 4:end, 1), 1); 253 | % error_summary(model, nidx, :) = errs; 254 | % end 255 | % end 256 | 257 | error_table = array2table(err_table, ... 258 | 'VariableNames', {'N', 'window' 'model', ... 259 | 'err_inf', 'err_2', 'err_fro', 'model_MSE', ... 
260 | 'true_MSE'}); 261 | 262 | model_vec = err_table(:,3); 263 | 264 | writetable(error_table, output_file); 265 | 266 | % figure; 267 | % loglog(N_vec, error_summary(:,:,1)) 268 | % legend({'independent', 'rank r indep', 'TVART'}) 269 | % xlabel('N') 270 | % ylabel('||A - A_{true} ||_\infty') 271 | % print('-depsc', [figdir, 'compare_err_inf.eps']) 272 | 273 | % figure; 274 | % loglog(N_vec, error_summary(:,:,2)) 275 | % legend({'independent', 'rank r indep', 'TVART'}) 276 | % xlabel('N') 277 | % ylabel('||A - A_{true} ||_2') 278 | % print('-depsc', [figdir, 'compare_err_2.eps']) 279 | 280 | % figure; 281 | % loglog(N_vec, error_summary(:,:,3)) 282 | % legend({'independent', 'rank r indep', 'TVART'}) 283 | % xlabel('N') 284 | % ylabel('||A - A_{true} ||_{Fro}') 285 | % print('-depsc', [figdir, 'compare_err_fro.eps']) 286 | 287 | save(save_file) -------------------------------------------------------------------------------- /src/switching_linear.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | 4 | addpath('~/work/MATLAB/') 5 | addpath('~/work/MATLAB/unlocbox') 6 | init_unlocbox(); 7 | figdir = '../figures/'; 8 | 9 | %% Parameters 10 | %% test system 11 | %rng('shuffle'); 12 | %rng(1337) 13 | rng(1) 14 | N = 10; 15 | M = 20; 16 | r = 8; 17 | num_steps = M * 10 + 1; 18 | num_trans = 200; 19 | noise_std = 0.5; 20 | noise_process = 0.0; 21 | T = floor((num_steps - 1) / M); 22 | noise_compensate = 0; 23 | %% tensor DMD algorithm 24 | algorithm = 'alt_min'; % ALS or prox_grad 25 | regularization = 'TV'; 26 | max_iter = 300; % iterations 27 | offset = 0; 28 | center = 0; 29 | eta = 1. / N; % Tikhonov 30 | beta = 5; % regularization strength 31 | proximal = 0; 32 | save_plots = 1; 33 | 34 | % %% testing params 35 | % N_vec = [6,10,14,18,20,30,50,80,100,200,400,1000,2000,4000]; 36 | % %eta_vec = linspace(1, 0.01, length(N_vec)); 37 | % eta_vec = logspace(0, 1, length(N_vec)); 38 | % eta = 0.1 %eta_vec(end-2); 39 | % N = N_vec(end-2); 40 | % beta = 120 / N; 41 | 42 | 43 | % algorithm = 'ALS_aux'; % ALS or prox_grad 44 | % regularization = 'TV'; 45 | % max_iter = 300; % iterations 46 | % eta1 = 0.80; % Tikhonov 47 | % beta = 0.4; % regularization strength 48 | % eta2 = 1; % closeness to relaxation variable 49 | % 50 | % regularization = 'TVL0'; 51 | % max_iter = 300; % iterations 52 | % offset = 0; 53 | % center = 0; 54 | % eta = 1; % Tikhonov 55 | % beta = 0.25; 56 | % regularization strength 57 | % 58 | % regularization = 'spline'; 59 | % max_iter = 300; % iterations 60 | % offset = 0; 61 | % center = 0; 62 | % eta = 1; % Tikhonov 63 | % beta = 10; % regularization strength 64 | % 65 | % algorithm = 'ALS_aux'; % ALS or prox_grad 66 | % regularization = 'TV'; 67 | % max_iter = 300; % iterations 68 | % eta1 = 1.; % Tikhonova 69 | % beta = 0.4; % regularization strength 70 | % eta2 = 1; % closeness to relaxation variable 71 | % 72 | % algorithm = 'ALS_prox_mix'; 73 | % regularization = 'TV'; 74 | % max_iter = 400; % iterations 75 | % eta1 = 1e2; % Tikhonov 76 | % beta = 0.1; 77 | % regularization strength 78 | % 79 | % algorithm = 'ALS_aux'; ALS or prox_grad 80 | % regularization = 'TV'; 81 | % max_iter = 60; iterations 82 | % eta1 = 100.0; Tikhonov 83 | % eta2 = 0.1; closeness to relaxation variable 84 | % beta = 1.0; regularization strength 85 | % 86 | % algorithm = 'ALS_smooth'; % ALS or prox_grad 87 | % max_iter = 60; 88 | % eta1 = 50; 89 | % eta2 = 0.1; 90 | % beta = 1.0; % regularization strength 91 | 92 | %% Setup the problem 
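% A brief note on the setup below: switching_linear_setup (see
% src/switching_linear_setup.m) is assumed to return the observations X
% together with the two ground-truth system matrices A1 and A2; the
% evaluation further down treats windows k <= T/2 as generated by A1 and
% the remaining windows as generated by A2.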
93 | 94 | [X, A1, A2] = switching_linear_setup(N, num_steps, noise_std, offset); 95 | %load(sprintf('test_data_N_%d_M_%d_sigma_%f.mat', N, num_steps, ... 96 | % noise_std)); 97 | 98 | %% Preprocessing: 99 | [U,S,V] = svd(X, 0); 100 | 101 | %% reinitialize RNG 102 | %rng('shuffle') 103 | 104 | %% run tensor DMD 105 | if strcmp(algorithm, 'alt_min') 106 | [lambda, A, B, C, cost_vec, Xten, Yten, rmse_vec] = ... 107 | TVART_alt_min(X, M, r, ... 108 | 'center', center, ... 109 | 'eta', eta, ... 110 | 'beta', beta, ... 111 | 'max_iter', max_iter, ... 112 | 'verbosity', 2, ... 113 | 'proximal', proximal, ... 114 | 'regularization', regularization); 115 | figure(1) 116 | semilogy(cost_vec) 117 | hold on 118 | semilogy(rmse_vec, 'r-') 119 | legend({'cost', 'RMSE'}) 120 | ylabel('cost') 121 | xlabel('iteration') 122 | xlim([0, length(rmse_vec)]) 123 | elseif strcmp(algorithm, 'ALS') 124 | [lambda, A, B, C, Xten, Yten] = ... 125 | tensor_DMD_ALS(X, M, r, ... 126 | 'center', offset, ... 127 | 'eta', eta, ... 128 | 'nu', nu, ... 129 | 'beta', beta, ... 130 | 'max_iter', max_iter,... 131 | 'regularization', regularization); 132 | elseif strcmp(algorithm, 'alt_min_l0') 133 | [lambda, A, B, C, cost_vec, Xten, Yten, rmse_vec] = ... 134 | tensor_DMD_alt_min_l0(X, M, r, ... 135 | 'center', center, ... 136 | 'eta', eta, ... 137 | 'beta', beta, ... 138 | 'max_iter', max_iter, ... 139 | 'verbosity', 2); 140 | figure(1) 141 | semilogy(cost_vec) 142 | hold on 143 | semilogy(rmse_vec, 'r-') 144 | legend({'cost', 'RMSE'}) 145 | ylabel('cost') 146 | xlabel('iteration') 147 | xlim([0, length(rmse_vec)]) 148 | elseif strcmp(algorithm, 'ALS_aux') 149 | [lambda, A, B, C, cost_vec, Xten, Yten, rmse_vec, W] = ... 150 | tensor_DMD_ALS_aux(X, M, r, ... 151 | 'center', offset, ... 152 | 'eta1', eta1, ... 153 | 'eta2', eta2, ... 154 | 'beta', beta, ... 155 | 'max_iter', max_iter, ... 156 | 'proximal', 0, ... 157 | 'verbosity', 2, ... 158 | 'regularization', regularization); 159 | figure(1) 160 | semilogy(cost_vec) 161 | hold on 162 | semilogy(rmse_vec, 'r-') 163 | legend({'cost', 'RMSE'}) 164 | ylabel('cost') 165 | xlabel('iteration') 166 | xlim([0, length(rmse_vec)]) 167 | elseif strcmp(algorithm, 'ALS_prox_mix') 168 | [lambda, A, B, C, cost_vec, Xten, Yten, rmse_vec, W] = ... 169 | tensor_DMD_ALS_prox_mix(X, M, r, ... 170 | 'center', 0, ... 171 | 'eta1', eta1, ... 172 | 'beta', beta, ... 173 | 'max_iter', max_iter, ... 174 | 'proximal', 0, ... 175 | 'regularization', regularization); 176 | figure(1) 177 | semilogy(cost_vec) 178 | hold on 179 | semilogy(rmse_vec, 'r-') 180 | legend({'cost', 'RMSE'}) 181 | ylabel('cost') 182 | xlabel('iteration') 183 | %xlim([0,max_iter+1]) 184 | elseif strcmp(algorithm, 'ALS_smooth') 185 | [lambda, A, B, C, cost_vec, Xten, Yten, rmse_vec] = ... 186 | tensor_DMD_ALS_smooth(X, M, r, ... 187 | 'center', 1, ... 188 | 'eta1', eta1, ... 189 | 'eta2', eta2, ... 190 | 'beta', beta, ... 191 | 'max_iter', max_iter,... 192 | 'verbosity', 2); 193 | figure(1) 194 | semilogy(cost_vec) 195 | hold on 196 | semilogy(rmse_vec, 'r-') 197 | legend({'cost', 'RMSE'}) 198 | ylabel('cost') 199 | xlabel('iteration') 200 | %xlim([0,max_iter+1]) 201 | elseif strcmp(algorithm, 'prox_grad') 202 | [lambda, A, B, C, Xten, Yten, cost_vec, rmse_vec] = ... 203 | tensor_DMD_prox_grad(X, M, r, ... 204 | 'center', 0, ... 205 | 'eta', eta, ... 206 | 'step_size', step_size, ... 207 | 'beta', beta, ... 208 | 'iter_disp', 1, ... 
209 | 'max_iter', max_iter); 210 | figure(1) 211 | semilogy(cost_vec) 212 | hold on 213 | semilogy(rmse_vec, 'r-') 214 | legend({'cost', 'RMSE'}) 215 | ylabel('cost') 216 | xlabel('iteration') 217 | %xlim([0,max_iter+1]) 218 | else 219 | error('algorithm not implemented'); 220 | end 221 | 222 | %% noise compensation 223 | if noise_compensate 224 | C = C / (1-noise_std); 225 | end 226 | 227 | 228 | % [lambda_r, A_r, B_r, C_r, W_r] = rebalance_2(A, B, C, W, 1, 1e-6); 229 | % [lambda_r, A_r, B_r, C_r, W_r] = reorder_components(lambda_r, A_r, B_r, ... 230 | % C_r, W_r); 231 | 232 | % lambda = lambda_r; 233 | % A = A_r; 234 | % B = B_r; 235 | % C = C_r; 236 | % W = W_r; 237 | 238 | 239 | %% Figure 2 240 | figure(2) 241 | subplot(2,1,1) 242 | plot(X', 'color', [0,0,0] + 0.4, 'linewidth', 0.5) 243 | axis tight 244 | vline(M*[1:(T-1)], 'k--') 245 | xlim([1, num_steps]) 246 | xlabel('Time step') 247 | %ylabel('State variable') 248 | title('Observations') 249 | subplot(2,1,2) 250 | plot(C, '-'); 251 | %if strcmp(algorithm, 'ALS_aux') 252 | % hold on 253 | % plot(W, 'ko'); 254 | %end 255 | set(gca, 'xtick', 1:T) 256 | axis tight 257 | xlim([0.5, T+.5]) 258 | arr = ylim(); 259 | ylim([arr(1) - 0.05, arr(2) + 0.05]) 260 | xlabel('Window') 261 | title('Temporal modes') 262 | figure(2) 263 | set(gcf, 'Color', 'w', 'Position', [100 500 800 600]); 264 | set(gcf, 'PaperUnits', 'inches', 'PaperPosition', [0 0 6.5 4], 'PaperPositionMode', 'manual'); 265 | set(gcf,'renderer','painters') 266 | %export_fig('-depsc2', '-r200', [figdir 'switching_summary.eps']); 267 | if save_plots 268 | print('-depsc2', '-loose', '-r200', [figdir ... 269 | 'switching_summary.eps']); 270 | end 271 | %print('-dpng', '-r200', [figdir 'switching_summary.png']); 272 | 273 | % A = Aten.U{1}; 274 | % B = Aten.U{2}; 275 | % C = Aten.U{3}; 276 | 277 | 278 | %% test closeness 279 | figure(3); 280 | set(gcf, 'position', [246 66 431 1231]); 281 | for k = 1:T 282 | Dk = diag(C(k,:)); 283 | Ak = A * diag(lambda) * Dk * B'; 284 | %Ak = Aten(:, :, k) 285 | ten_MSE = norm(Yten(:,:,k) - Ak * Xten(:,:,k), 'fro')^2 / (N*M); 286 | if k <= T/2 287 | % system 1 288 | A_true = A1; 289 | system = 1; 290 | else 291 | % system 2 292 | A_true = A2; 293 | system = 2; 294 | end 295 | 296 | [D] = eig(Ak(:,1:N)); 297 | [Dtrue] = eig(A_true); 298 | 299 | A_err = max(max(abs(Ak(:, 1:N) - A_true))); 300 | A_err_2 = norm(Ak(:, 1:N) - A_true, 2); 301 | true_MSE = norm(Yten(:,:,k) - A_true * Xten(1:N,:,k), 'fro')^2 / ... 302 | (N*M); 303 | fprintf('\nWindow %d, system %d\n', k, system) 304 | fprintf('prediction MSE using A_tensor: %1.3g (SNR: %1.2g)\n', ten_MSE, ... 305 | ten_MSE / noise_std^2) 306 | fprintf('prediction MSE using A_truth: %1.3g (SNR: %1.2g)\n', true_MSE,... 307 | true_MSE / noise_std^2) 308 | fprintf('tensor MSE relative to truth: %1.3g%%\n', ... 309 | ten_MSE / true_MSE * 100) 310 | fprintf('A infty error: %1.3g\n', A_err) 311 | fprintf('A 2-error: %1.3g\n', A_err_2) 312 | fprintf('radius(A_tensor) = %1.2g, radius(A_true) = %1.2g\n', ... 
313 | max(abs(D)), max(abs(Dtrue))); 314 | 315 | set(0, 'CurrentFigure', 3); 316 | subplot(3,1,1); 317 | imagesc(Ak) 318 | colorbar 319 | colormap(brewermap([], 'PuOr')) 320 | caxis([-1,1]/sqrt(N)) 321 | %caxis auto 322 | title('TVAR system matrix', 'fontweight', 'normal') 323 | axis square 324 | 325 | subplot(3,1,2); 326 | imagesc(A_true); 327 | colorbar 328 | colormap(brewermap([], 'PuOr')) 329 | caxis([-1,1]/sqrt(N)) 330 | %caxis auto 331 | title(['True matrix (system ' num2str(system) ')'], 'fontweight', 'normal') 332 | axis square 333 | 334 | subplot(3,1,3); 335 | imagesc(Ak(:,1:N) - A_true); 336 | colorbar 337 | colormap(brewermap([], 'PuOr')) 338 | caxis([-1, 1]/sqrt(N)) 339 | %caxis auto 340 | title('Difference', 'fontweight', 'normal') 341 | axis square 342 | 343 | if k == 1 344 | set(gcf, 'Color', 'w', 'Position', [0 0 600 1000]); 345 | set(gcf, 'PaperUnits', 'inches', ... 346 | 'PaperPosition', [0 0 3 7.2], ... 347 | 'PaperPositionMode', 'manual'); 348 | set(gcf,'renderer','painters') 349 | %print('-dpng', '-painters', '-r200', [figdir 350 | %'switching_matrices.png']) 351 | if save_plots 352 | print('-depsc2', '-loose', '-r200', [figdir ... 353 | 'switching_matrices.eps']); 354 | end 355 | %export_fig('-depsc2', '-r200', [figdir 'switching_matrices.eps']); 356 | % print('-dpdf', '-painters', [figdir 'switching_matrices.pdf']); 357 | end 358 | 359 | pause 360 | if k ~= T 361 | clf(3) 362 | end 363 | end 364 | 365 | 366 | 367 | %% Figure 4 368 | figure(4) 369 | 370 | set(gcf, 'position', [693 63 560 1240]) 371 | subplot(3,1,1); 372 | imagesc(A); 373 | title('Left spatial modes', 'fontweight', 'normal') 374 | colormap(brewermap([], 'PuOr')) 375 | c = caxis(); 376 | c = max(abs(c)); 377 | caxis([-c, c]) 378 | colorbar 379 | axis square 380 | 381 | subplot(3,1,2); 382 | imagesc(B(1:N,:)); 383 | title('Right spatial modes', 'fontweight', 'normal') 384 | colormap(brewermap([], 'PuOr')) 385 | %caxis([-1, 1]) 386 | c = caxis(); 387 | c = max(abs(c)); 388 | caxis([-c, c]) 389 | colorbar 390 | axis square 391 | 392 | subplot(3,1,3); 393 | imagesc(C); 394 | title('Temporal modes', 'fontweight', 'normal') 395 | colormap(brewermap([], 'PuOr')) 396 | %caxis([-1, 1]) 397 | c = caxis(); 398 | c = max(abs(c)); 399 | caxis([-c, c]) 400 | colorbar 401 | axis square 402 | 403 | set(gcf, 'Color', 'w', 'Position', [0 0 600 1000]); 404 | set(gcf, 'PaperUnits', 'inches', 'PaperPosition', [0 0 3 7.2], ... 405 | 'PaperPositionMode', 'manual'); 406 | set(gcf,'renderer','painters') 407 | if save_plots 408 | print('-depsc2', '-loose', '-r200', [figdir ... 409 | 'switching_components.eps']); 410 | end 411 | %export_fig('-depsc2', '-r200', [figdir 'switching_components.eps']); 412 | %print('-dpng', '-r200', [figdir 'switching_components.png']); 413 | 414 | 415 | 416 | %% Figure 5 417 | figure(5); 418 | scatter(real(Dtrue), imag(Dtrue), 'ob'); 419 | hold on 420 | scatter(real(D), imag(D), '+k'); 421 | plot(cos(linspace(0,1,100)*2*pi), sin(linspace(0,1,100)*2*pi), 'k--') 422 | title('eigenvalues', 'fontweight', 'normal') 423 | xlim([-1.2, 1.2]) 424 | ylim([-1.2, 1.2]) 425 | axis square; 426 | set(gcf, 'Color', 'w', 'Position', [100 200 600 700]); 427 | set(gcf, 'PaperUnits', 'inches', 'PaperPosition', [0 0 4 4.5], ... 428 | 'PaperPositionMode', 'manual'); 429 | set(gcf,'renderer','painters') 430 | %export_fig('-depsc2', '-r200', [figdir 'switching_eigs.eps']); 431 | if save_plots 432 | print('-depsc2', '-loose', '-r200', [figdir ... 
433 | 'switching_eigs.eps']); 434 | end 435 | %print('-dpng', '-r200', [figdir 'switching_eigs.png']); 436 | 437 | 438 | 439 | %% Clustering 440 | %D = compute_pdist(A, B, C*diag(lambda)); 441 | D = pdist(C); 442 | Z = linkage(squareform(D), 'complete'); 443 | cluster_ids = cluster(Z, 'maxclust', 3); 444 | -------------------------------------------------------------------------------- /src/TVART_alt_min.m: -------------------------------------------------------------------------------- 1 | function [lambda, A, B, C, cost_vec, varargout] = ... 2 | TVART_alt_min(data, window_size, r, varargin) 3 | % TVART_alt_min 4 | % Solve TVART problem using alternating minimization. 5 | % 6 | % Parameters: 7 | % data : N by time matrix of data 8 | % window_size : number of time points in window 9 | % r : rank 10 | % 11 | % Output: 12 | % lambda : vector of scalings for components (all ones) 13 | % A : left spatial modes 14 | % B : right spatial modes 15 | % C : temporal modes 16 | % cost_vec : cost per iterate 17 | % X (optional) : data reshaped into tensor 18 | % Y (optional) : target data reshaped into tensor 19 | % rmse_vec : RMSE per iterate 20 | % 21 | % Optional parameters: 22 | % max_iter : maximum number of iterations (default: 20) 23 | % eta : Tikhonov regularization parameter (default: 0.1) 24 | % beta : temporal regularization parameter (default 0) 25 | % regularization : temporal regularization (default: 'TV', 26 | % options: 'TV', 'Spline', 'TV0') 27 | % center : fit an affine/centered model (default: 0) 28 | % iter_disp : display cost after this many iterations 29 | % (default: 1) 30 | % init : struct containing initialization (default: false) 31 | % atol : absolute tolerance (default: 1e-6) 32 | % rtol : relative tolerance (default: 1e-4) 33 | % 34 | 35 | % Input parsing, defaults 36 | max_iter = 20; 37 | iter_disp = 1; 38 | center = 0; 39 | eta = 0.1; % Tikhonov 40 | beta = 0.0; % total variation parameter 41 | regularization = 'TV'; 42 | rtol = 1e-4; 43 | atol = 1e-6; 44 | verbosity = 1; 45 | 46 | p = inputParser; 47 | addParameter(p, 'max_iter', max_iter); 48 | addParameter(p, 'iter_disp', iter_disp); 49 | addParameter(p, 'center', center); 50 | addParameter(p, 'eta', eta); 51 | addParameter(p, 'proximal', 0); 52 | addParameter(p, 'beta', beta); 53 | addParameter(p, 'regularization', regularization); 54 | addParameter(p, 'scale_params', false); 55 | addParameter(p, 'init', false); 56 | addParameter(p, 'atol', atol); 57 | addParameter(p, 'rtol', rtol); 58 | addParameter(p, 'verbosity', verbosity); 59 | parse(p, varargin{:}); 60 | max_iter = p.Results.max_iter; 61 | iter_disp = p.Results.iter_disp; 62 | center = p.Results.center; 63 | eta = p.Results.eta; 64 | proximal = p.Results.proximal; 65 | beta = p.Results.beta; 66 | regularization = p.Results.regularization; 67 | init = p.Results.init; 68 | atol = p.Results.atol; 69 | rtol = p.Results.rtol; 70 | verbosity = p.Results.verbosity; 71 | 72 | if verbosity > 0 73 | fprintf('Running TVART alternating minimization\n'); 74 | fprintf('\tregularization %s\n', regularization); 75 | fprintf('\tproximal = %d\n', proximal); 76 | fprintf('\tcenter = %d\n', center); 77 | end 78 | if verbosity > 1 79 | fprintf('\tbeta = %1.3g\n', beta); 80 | fprintf('\teta = %1.3g\n', eta); 81 | end 82 | fprintf('\n'); 83 | 84 | if ~(strcmp(regularization, 'TV') || ... 85 | strcmp(regularization, 'spline') || ... 
86 | strcmp(regularization, 'TVL0')) 87 | disp(['Received regularization = ' regularization]) 88 | error(['Unknown regularization type, should be one of ''TV'', ' ... 89 | '''spline,'' or ''TVL0''']); 90 | end 91 | 92 | N = size(data, 1); 93 | M = window_size; 94 | t = size(data, 2); 95 | 96 | T = floor((t-1) / window_size); 97 | % throws out some data if not an integer multiple of M 98 | X = data(:, 1:end-1); 99 | X = X(:, 1:M*T); 100 | Y = data(:, 2:end); 101 | Y = Y(:, 1:M*T); 102 | 103 | % initialize decomposition 104 | if isstruct(init) 105 | disp('using initialization given') 106 | A = init.A; 107 | B = init.B; 108 | C = init.C; 109 | if isfield(init, 'lambda') 110 | lambda = init.lambda; 111 | C = C * diag(lambda); 112 | end 113 | lambda = ones(r, 1); 114 | else 115 | [A, B, C] = init_dmd(X, Y, r, T); % initalization 116 | if center 117 | B = [B; zeros(1, r)]; 118 | end 119 | lambda = ones(r, 1); 120 | end 121 | 122 | % reshape the data into tensors 123 | if ~center 124 | X = reshape(X, [N, M, T]); 125 | else 126 | % append ones to input data 127 | X = reshape([X; 128 | ones(1, M*T)], [N+1, M, T]); 129 | end 130 | Y = reshape(Y, [N, M, T]); 131 | 132 | cost_vec = zeros(max_iter, 1); 133 | rmse_vec = zeros(max_iter, 1); 134 | % Tcurr = 0; 135 | c = cost(X, Y, A, B, C, N, M, T, r, eta, ... 136 | beta, regularization, proximal); 137 | %cost_vec(1) = c; 138 | fprintf('initial cost = %1.5g\n', c); 139 | %fprintf('eta = %1.3g, 1/eta = %1.5g\n', eta, 1/(eta)); 140 | update_iter = 0; 141 | for iter = 1:max_iter 142 | %% Update A 143 | old_A = A; 144 | old_c = c; 145 | A = solve_A(X, Y, A, B, C, N, M, T, r, eta, proximal); 146 | if verbosity >= 2 147 | [c,rmse] = cost(X, Y, A, B, C, N, M, T, r,... 148 | eta, beta, regularization, proximal); 149 | fprintf('\tpost A solve: cost = %1.5g, RMSE = %1.3g\n', ... 150 | c, rmse); 151 | end 152 | if old_c < c 153 | fprintf(['Warning: not updating A because cost went up!\' ... 154 | 'n']); 155 | A = old_A; 156 | c = old_c; 157 | end 158 | 159 | %% Update B 160 | old_B = B; 161 | old_c = c; 162 | B = solve_B_cg(X, Y, A, B, C, N, M, T, r, eta, proximal); 163 | if verbosity >= 2 164 | [c,rmse] = cost(X, Y, A, B, C, N, M, T, r,... 165 | eta, beta, regularization, proximal); 166 | fprintf('\tpost B solve: cost = %1.5g, RMSE = %1.3g\n', ... 167 | c, rmse); 168 | end 169 | if old_c < c 170 | fprintf(['Warning: not updating B because cost went up!\' ... 171 | 'n']); 172 | B = old_B; 173 | c = old_c; 174 | end 175 | 176 | %% Update C 177 | old_C = C; 178 | olc_c = c; 179 | if strcmp(regularization, 'spline') 180 | C = solve_C_cg(X, Y, A, B, C, N, M, T, r, eta, beta, proximal); 181 | else 182 | C = solve_C_proxg(X, Y, A, B, C, N, M, T, r, eta, beta, ... 183 | regularization, proximal); 184 | end 185 | [c,rmse] = cost(X, Y, A, B, C, N, M, T, r,... 186 | eta, beta, regularization, proximal); 187 | if verbosity >= 2 188 | fprintf('\tpost C solve: cost = %1.5g, RMSE = %1.3g\n', ... 189 | c, rmse); 190 | end 191 | if old_c < c 192 | fprintf(['Warning: not updating C because cost went up!\' ... 193 | 'n']); 194 | C = old_C; 195 | c = old_c; 196 | end 197 | 198 | cost_vec(iter) = c; 199 | rmse_vec(iter) = rmse; 200 | if mod(iter, iter_disp) == 0 201 | fprintf('iter %d: cost = %1.5g, RMSE = %1.3g\n', ... 202 | iter, c, rmse); 203 | end 204 | 205 | if iter >= 2 206 | %cost_rel = cost_vec(2); 207 | cost_rel = cost_vec(iter - 1); 208 | end 209 | if iter > 2 && abs(cost_vec(iter) - cost_vec(iter-1)) < ... 
210 | max([rtol * cost_rel, atol]) 211 | fprintf('tolerance reached!\n'); 212 | break 213 | end 214 | if isnan(cost_vec(iter)) 215 | fprintf('Error: algorithm diverged\n') 216 | break 217 | end 218 | end 219 | rmse_vec = rmse_vec(1:iter); 220 | cost_vec = cost_vec(1:iter); 221 | 222 | % var_tot = 0; 223 | % var_int = 0; 224 | % var_aff = 0; 225 | % for k = 1:T 226 | % Dk = diag(C(k, :)); 227 | % if center 228 | % Ypred_int = A * diag(lambda) * Dk * B(end, :)' * ... 229 | % ones(1, size(Ypred_k, 2)); 230 | % Ypred_k = A * diag(lambda) * Dk * B(1:end-1, :)' * X(:,:,k) ... 231 | % + Ypred_int; 232 | % var_aff = var_aff + norm(Y(:,:,k) - Ypred_k - Ypred_int, ... 233 | % 'fro')^2; 234 | % var_int = var_int + norm(Y(:,:,k) - Ypred_int, ... 235 | % 'fro')^2; 236 | % else 237 | % Ypred_k = A * diag(lambda) * Dk * B' * X(:,:,k); 238 | % var_aff = var_aff + norm(Y(:,:,k) - Ypred_k, 'fro')^2; 239 | % var_int = var_int + norm(Y(:,:,k), 'fro')^2; 240 | % end 241 | % var_tot = var_tot + norm(Y(:,:,k), 'fro')^2; 242 | % end 243 | % var_tot = var_tot / (N*M*T); 244 | % var_int = var_int / (N*M*T); 245 | % var_aff = var_aff / (N*M*T); 246 | % fprintf('total var:\t\t%1.3f\n', var_tot); 247 | % fprintf('intercept resid var:\t%1.3f, %1.3g%%\n', var_int, ... 248 | % var_int / var_tot* 100); 249 | % fprintf('affine model resid var:\t%1.3f, %1.3g%%\n', var_aff, ... 250 | % var_aff / var_tot* 100); 251 | 252 | if nargout >= 6 253 | varargout{1} = X; 254 | end 255 | if nargout >= 7 256 | varargout{2} = Y; 257 | end 258 | if nargout >= 8 259 | varargout{3} = rmse_vec; 260 | end 261 | end 262 | 263 | function Anew = solve_A(X, Y, A, B, C, N, M, T, r, eta, proximal) 264 | rhs = zeros(N, r); 265 | sysmat = zeros(r, r); 266 | for k = 1:T 267 | Dk = diag(C(k, :)); 268 | rhs = rhs + Y(:,:,k) * X(:,:,k)' * B * Dk; 269 | sysmat = sysmat + Dk * B' * X(:,:,k) * X(:,:,k)' * B * Dk; 270 | end 271 | sysmat = sysmat + (1/eta) * eye(size(sysmat)); 272 | if proximal 273 | rhs = rhs + (1/eta) * A; 274 | end 275 | %fprintf('\tcond(A) = %1.3g\n', cond(sysmat)) 276 | Anew = rhs / sysmat; 277 | % function y=apply_mat(x) 278 | % x = reshape(x, size(A)); 279 | % y = x*(sysmat + (1/eta)*eye(size(sysmat))); 280 | % y = y(:); 281 | % end 282 | % maxit = 10; 283 | % tol = 1e-4; 284 | % [x, flag, rr, iter] = pcg(@apply_mat, rhs(:), tol, maxit); 285 | % Anew = reshape(x, size(A)); 286 | % if flag ~= 0 287 | % disp('Warning: pcg did not converge in A solve') 288 | % fprintf('flag: %d, relres: %f, iter: %d\n', flag, rr, iter) 289 | % end 290 | end 291 | 292 | % function Bnew = solve_B(X, Y, A, B, C, N, M, T, r, eta) 293 | % % solve nasty sylvester equation the stupid way 294 | % rhs = zeros(size(B)); 295 | % sysmat = zeros(prod(size(B)), prod(size(B))); 296 | % %fprintf('\t\tsolve_B: 1/eta = %1.5g\n', 1/eta); 297 | % for k = 1:T 298 | % Dk = diag(C(k, :)); 299 | % rhs = rhs + X(:,:,k) * Y(:,:,k)' * A * Dk; 300 | % Rk = Dk * A' * A * Dk; 301 | % Lk = X(:,:,k) * X(:,:,k)'; 302 | % sysmat = sysmat + kron(Rk', Lk); 303 | % end 304 | % % if proximal 305 | % % rhs = rhs + (1/eta1) * B; 306 | % % end 307 | % rhs = rhs(:); 308 | % if size(B,1) > N 309 | % tmp = eye(size(sysmat)); 310 | % for i = 1:r 311 | % tmp(i*(N+1), i*(N+1)) = 0.0; 312 | % end 313 | % bvec = (sysmat + (1/eta) * tmp) \ rhs; 314 | % else 315 | % bvec = (sysmat + (1/eta) * eye(size(sysmat))) \ rhs; 316 | % end 317 | % Bnew = reshape(bvec, size(B)); 318 | % end 319 | 320 | function Bnew = solve_B_cg(X, Y, A, B, C, N, M, T, r, eta, proximal) 321 | maxit = 24; 322 | tol = 1e-3; 323 | 
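% A short note on this solver: the B update is a Sylvester-like linear
% system, solved matrix-free with pcg. apply_sylv (below) applies
% sum_k X_k X_k' * B * (D_k A' A D_k) plus the Tikhonov term B / eta; when
% the model is centered (size(B,1) > N) the affine row of B is left
% unpenalized.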
%fprintf('\t\tsolve_B: 1/eta = %1.5g\n', 1/eta); 324 | function z = apply_sylv(b) 325 | Bt = reshape(b, size(B)); 326 | Zt = zeros(size(X,1), r); 327 | for k = 1:T 328 | Dk = diag(C(k, :)); 329 | Rk = Dk * A' * A * Dk; 330 | Lk = X(:,:,k) * X(:,:,k)'; 331 | Zt = Zt + Lk * Bt * Rk; 332 | end 333 | if size(B,1) > N 334 | Zt = Zt + (1/eta) * [Bt(1:N,:); zeros(1,r)]; 335 | else 336 | Zt = Zt + (1/eta) * Bt; 337 | end 338 | z = Zt(:); 339 | end 340 | 341 | rhs = zeros(size(B)); 342 | for k = 1:T 343 | Dk = diag(C(k, :)); 344 | rhs = rhs + X(:,:,k) * Y(:,:,k)' * A * Dk; 345 | end 346 | if proximal 347 | rhs = rhs + (1/eta) * B; 348 | end 349 | rhs = rhs(:); 350 | 351 | % function z = precond(bx) 352 | % Bx = reshape(bx, size(B)); 353 | % Z = zeros(size(X,1), r); 354 | % L = zeros(size(X,1), size(X,1)); 355 | % R = zeros(r, r); 356 | % for k = 1:T 357 | % Dk = diag(C(k, :)); 358 | % R = R + Dk * A' * A * Dk; 359 | %p L = L + X(:,:,k) * X(:,:,k)'; 360 | % end 361 | % Z = L \ Bx / R; 362 | % z = Z(:); 363 | % end 364 | 365 | [x, flag, rr, iter] = pcg(@apply_sylv, rhs, tol, maxit, [], [], ... 366 | B(:)); 367 | Bnew = reshape(x, size(B)); 368 | if flag ~= 0 369 | disp('Warning: pcg did not converge in B solve') 370 | fprintf('flag: %d, relres: %f, iter: %d\n', flag, rr, iter) 371 | end 372 | end 373 | 374 | function Cnew = solve_C_proxg(X, Y, A, B, C, N, M, T, r, eta, beta, ... 375 | regularization, proximal) 376 | max_iter = 40; 377 | shrinkage = 1/10; 378 | %step_size = 2e-4 / T; 379 | 380 | function c = g(X, Y, A, B, C, N, M, T, r, eta, beta) 381 | c = 0; 382 | for k = 1:T 383 | Dk = diag(C(k, :)); 384 | Ypred_k = A * Dk * B' * X(:,:,k); 385 | c = c + ... 386 | 0.5 * norm(Y(:,:,k) - Ypred_k, 'fro')^2; 387 | end 388 | if ~proximal 389 | if size(B,1) > N 390 | c = c + 0.5 / (eta) * ... 391 | (norm(A, 'fro')^2 + ... 392 | norm(B(1:N,:), 'fro')^2 + ... 393 | norm(C, 'fro')^2); 394 | else 395 | c = c + 0.5 / (eta) * ... 396 | (norm(A, 'fro')^2 + ... 397 | norm(B, 'fro')^2 + ... 398 | norm(C, 'fro')^2); 399 | end 400 | end 401 | end 402 | 403 | step_size = 1; 404 | for iter = 1:max_iter 405 | if iter == 1 406 | gold = g(X, Y, A, B, C, N, M, T, r, eta, beta); 407 | grad = grad_C(X, Y, A, B, C, N, M, T, r, eta); 408 | Cmom = C; 409 | else 410 | Cmom = C + (iter - 2)/(iter + 1) * (C - Cold); 411 | gold = g(X, Y, A, B, Cmom, N, M, T, r, eta, beta); 412 | grad = grad_C(X, Y, A, B, Cmom, N, M, T, r, eta); 413 | end 414 | if proximal 415 | grad = grad - (1/eta) * C; 416 | end %fprintf('\t\tg_old = %1.3g', gold); 417 | while 1 % line search loop 418 | Cnew = Cmom - step_size * grad; 419 | if beta > 0 420 | if strcmp(regularization, 'TV') 421 | for k = 1:r 422 | %Cnew(:, k) = l1tv(Cnew(:, k), beta * step_size); 423 | param = {}; 424 | param.verbose = 0; 425 | param.use_fast = 1; 426 | param.init = Cnew(:, k); 427 | Cnew(:, k) = prox_tv1d(Cnew(:, k), ... 428 | beta * step_size,... 429 | param); 430 | end 431 | elseif strcmp(regularization, 'TVL0') 432 | for k = 1:r 433 | [Cnew(:,k), ~] = prox_tv0(Cnew(:, k), beta * step_size); 434 | end 435 | end 436 | end 437 | G = (Cmom(:) - Cnew(:)) / step_size; 438 | gnew = g(X, Y, A, B, Cnew, N, M, T, r, eta, beta); 439 | if gnew > (gold - step_size * grad(:)' * G + ... 440 | step_size / 2 * norm(G)^2) 441 | if step_size > 1e-12 442 | step_size = step_size * shrinkage; 443 | %fprintf('shrinking step size to %f\n', ... 
444 | % step_size); 445 | else 446 | %fprintf(' g_new = %1.3g\n', gnew); 447 | step_size = step_size * shrinkage; 448 | warning('step size very small!') 449 | % Cold = C; 450 | % C = Cnew; 451 | % break 452 | end 453 | else 454 | %fprintf(' g_new = %1.3g\n', gnew); 455 | Cold = C; 456 | C = Cnew; 457 | break 458 | end 459 | end 460 | end 461 | end 462 | 463 | function Cnew = solve_C_cg(X, Y, A, B, C, N, M, T, r, eta, beta, proximal) 464 | maxit = 24; 465 | tol = 1e-4; 466 | diff_mat = setup_smoother(T); 467 | 468 | function z = apply_mat(c) 469 | Ct = reshape(c, size(C)); 470 | Zt = zeros(size(C)); 471 | for k = 1:T 472 | Xk = X(:, :, k); 473 | Yk = Y(:, :, k); 474 | % rhs = diag(B' * Xk * Yk' * A); 475 | L = B'* Xk * Xk' * B; 476 | R = A' * A; 477 | sysmat = L .* R; 478 | sysmat = sysmat + (1/eta) * eye(r); 479 | Zt(k, :) = sysmat * Ct(k, :)'; 480 | end 481 | Zt = Zt + beta * diff_mat' * diff_mat * Ct; 482 | z = Zt(:); 483 | end 484 | 485 | rhs = zeros(size(C)); 486 | for k = 1:T 487 | Xk = X(:, :, k); 488 | Yk = Y(:, :, k); 489 | rhs(k, :) = diag(B' * Xk * Yk' * A); 490 | end 491 | if proximal 492 | rhs = rhs + (1/eta) * C; 493 | end 494 | rhs = rhs(:); 495 | 496 | [x, flag, rr, iter] = pcg(@apply_mat, rhs, tol, maxit, [], [], ... 497 | C(:)); 498 | Cnew = reshape(x, size(C)); 499 | if flag ~= 0 500 | disp('Warning: pcg did not converge in C solve') 501 | fprintf('flag: %d, relres: %f, iter: %d\n', flag, rr, iter) 502 | end 503 | end 504 | 505 | function [c,varargout] = cost(X, Y, A, B, C, N, M, T, r, ... 506 | eta, beta, regularization, proximal) 507 | c = 0; 508 | for k = 1:T 509 | Dk = diag(C(k, :)); 510 | Ypred_k = A * Dk * B' * X(:,:,k); 511 | c = c + ... 512 | 0.5 * norm(Y(:,:,k) - Ypred_k, 'fro')^2; 513 | end 514 | rmse = sqrt(2 * c / (N * M * T)); 515 | if beta > 0 516 | if strcmp(regularization, 'TV') 517 | %% L1 TV 518 | DC = diff(C); 519 | c = c + beta * sum(abs(DC(:))); 520 | % elseif strcmp(regularization, 'groupTV') 521 | % %% Group TV 522 | % for k = 2:T 523 | % c = c + beta * norm(C(k, :) - C(k-1, :), 2); 524 | % end 525 | elseif strcmp(regularization, 'spline') 526 | diff_mat = setup_smoother(T); 527 | c = c + 0.5 * beta * norm(diff_mat * C, 'fro')^2; 528 | %c = c + 0.5 * beta * norm(diff(C), 'fro')^2; 529 | elseif strcmp(regularization, 'TV0') 530 | %% L0 TV 531 | c = c + beta * nnz(diff(C)); 532 | end 533 | end 534 | if ~proximal 535 | if size(B,1) > N 536 | c = c + 0.5 / (eta) * ... 537 | (norm(A, 'fro')^2 + ... 538 | norm(B(1:N,:), 'fro')^2 + ... 539 | norm(C, 'fro')^2); 540 | else 541 | c = c + 0.5 / (eta) * ... 542 | (norm(A, 'fro')^2 + ... 543 | norm(B, 'fro')^2 + ... 
544 | norm(C, 'fro')^2); 545 | end 546 | end 547 | 548 | if nargout ==2 549 | varargout{1} = rmse; 550 | end 551 | end 552 | 553 | 554 | 555 | function g = grad_C(X, Y, A, B, C, N, M, T, r, eta) 556 | g = zeros(size(C)); 557 | for k = 1:T 558 | % grad for each row of C separately 559 | Dk = diag(C(k, :)); 560 | g1 = diag(B' * X(:,:,k) * Y(:,:,k)' * A); 561 | L = B'* X(:,:,k) * X(:,:,k)' * B; 562 | R = A' * A; 563 | sysmat = L .* R; 564 | g2 = diag(L * Dk * R); 565 | g(k, :) = - g1 + g2; 566 | end 567 | g = g + (1/eta) * C; 568 | end 569 | 570 | function D = setup_smoother(T) 571 | I2 = speye(T-1, T-1); 572 | O2 = zeros(T-1, 1); 573 | D = [I2 O2]+[O2 -1*I2]; % first difference matrix 574 | %Dchol = chol( (eta2 * beta) * (D'*D) + eye(T) ); 575 | %sysinv = pinv(sysmat); 576 | end 577 | -------------------------------------------------------------------------------- /src/example_neurotycho.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | set(0,'DefaultFigurePaperPositionMode','auto') 4 | 5 | %set(0,'DefaultFigurePosition',[25,25,500,500]); 6 | set(0,'DefaultAxesFontSize',10) 7 | set(0,'DefaultAxesFontName','Helvetica') 8 | set(0,'DefaultAxesLineWidth',1); 9 | 10 | set(0,'DefaultTextFontSize',12) 11 | set(0,'DefaultTextFontName','Helvetica') 12 | set(0,'DefaultLineLineWidth',1); 13 | 14 | addpath('~/work/MATLAB/unlocbox') 15 | init_unlocbox(); 16 | 17 | rng(1); 18 | 19 | data_dir = '../data_full/20090525S1_FTT_K1_ZenasChao_mat_ECoG64-Motion8'; 20 | 21 | figdir = '../figures/'; 22 | post_ecog_data = [data_dir '/Kam_Bands.mat']; 23 | %post_ecog_data= [data_dir '/Kam_Bands_post_Steve.mat']); 24 | t0 = 200; 25 | tf = 1000; 26 | center = 0; 27 | voltage = 1; 28 | regularization = 'TV'; 29 | save_file = '../data/test_neurotycho.mat'; 30 | 31 | load([data_dir '/Motion.mat']); 32 | ntime = size(MotionData{1}, 1); 33 | ncoord = size(MotionData{1}, 2); 34 | nchan = length(MotionData); 35 | X = zeros(ntime, ncoord*nchan); 36 | for i = 1:nchan 37 | X(:, (i-1) * ncoord + 1 : i * ncoord) = MotionData{i}; 38 | end 39 | X = X'; 40 | X = standardize(X); 41 | 42 | idx = (MotionTime >= t0 & MotionTime <= tf); 43 | MotionTime = MotionTime(idx); 44 | X = X(:, idx); 45 | 46 | M = 30; 47 | R = 8; 48 | %% tensor DMD algorithm TV 49 | max_iter = 1000; % iterations 50 | eta1 = 0.005; % Tikhonov 51 | eta2 = 0.01; % regularization 52 | beta = 0.15; % TV-regularization 53 | %% TV-l0 54 | max_iter = 1000; % iterations 55 | eta1 = 0.01; % Tikhonov 56 | eta2 = 0.01; % regularization 57 | beta = 0.01; % TV-regularization 58 | %% ALS 59 | nu = 0.0; % ALS noise level 60 | RALS = 1; 61 | 62 | 63 | % [lambda, A, B, C, Xten, Yten] = ... 64 | % tensor_DMD_ALS(X', M, R, ... 65 | % 'center', 1, ... 66 | % 'eta', eta, ... 67 | % 'nu', nu, ... 68 | % 'beta', beta, ... 69 | % 'max_iter', max_iter); 70 | 71 | % [lambda, A, B, C, cost, Xten, Yten, rmse_vec, W] = ... 72 | % tensor_DMD_ALS_aux(X, M, R, ... 73 | % 'center', center, ... 74 | % 'eta1', eta1, ... 75 | % 'eta2', eta2, ... 76 | % 'beta', beta, ... 77 | % 'regularization', 'TV', ... 78 | % 'rtol', 1e-5, ... 79 | % 'max_iter', max_iter); 80 | 81 | [lambda, A, B, C, cost, Xten, Yten, rmse_vec] = ... 82 | TVART_alt_min(X, M, R, ... 83 | 'center', center, ... 84 | 'eta', eta1, ... 85 | 'beta', beta, ... 86 | 'regularization', regularization, ... 87 | 'verbosity', 2, ... 
88 | 'max_iter', max_iter); 89 | %W = C; 90 | 91 | [lambda_r, A_r, B_r, C_r] = rebalance(A, B, C, 1); 92 | [lambda_r, A_r, B_r, C_r] = reorder_components(lambda_r, A_r, B_r, ... 93 | C_r); 94 | %A_r = A; B_r = B; C_r = C; lambda_r = lambda; 95 | %C_r = C_r * diag(lambda_r); 96 | 97 | %% Old ECoG preprocessing, now uses preprocess_neurotycho.py 98 | % ecog_time = load([data_dir '/ECoG_time.mat']); 99 | % ecog_time = ecog_time.('ECoGTime'); 100 | % n_chan = 64; 101 | % ecog_data = zeros(n_chan, length(ecog_time)); 102 | % for chan = 1:n_chan 103 | % fn = sprintf('%s/ECoG_ch%d.mat', data_dir, chan); 104 | % var = sprintf('ECoGData_ch%d', chan); 105 | % tmp = load(fn); 106 | % ecog_data(chan, :) = tmp.(var); 107 | % end 108 | % %% Common average referencing 109 | % ecog_data = ecog_data - repmat(mean(ecog_data, 1), n_chan, 1); 110 | % %ecog_data = ecog_data - repmat(mean(ecog_data, 2), 1, length(ecog_data)); 111 | % %ecog_data = ecog_data./ repmat(std(ecog_data, 0, 2), 1, 112 | % %length(ecog_data)); 113 | % ecog_data = standardize(ecog_data); 114 | % DT = ecog_time(2) - ecog_time(1); 115 | 116 | fprintf('\nNow working on ECoG data\n') 117 | 118 | load(post_ecog_data) 119 | %voltage = 0; 120 | % %% Miller comparison 121 | %load([data_dir '/Kam_Bands_param_Miller.mat']) 122 | %ecog_data = [ecog_post(1:3:end, :) ; 123 | % ecog_post(3:3:end, :) ]; 124 | % M_ecog = 6; 125 | % R_ecog = 3; 126 | % eta2 = 1e-4; 127 | % beta_ecog = 3e2; 128 | if voltage 129 | ecog_data = ecog_v_filtered; 130 | ecog_time = ecog_v_time; 131 | % Voltage params - TV! 132 | M_ecog = 200; 133 | R_ecog = 8; 134 | eta1 = 1.0; 135 | eta2 = 1e-4; 136 | beta_ecog = 100; 137 | % Voltage params - TV-L0 138 | % M_ecog = 200; 139 | % R_ecog = 10; 140 | % eta1 = 1.; 141 | % eta2 = 5e-5; 142 | % beta_ecog = 2.0; 143 | % downsample 144 | ecog_data = ecog_data(:, 1:2:end); 145 | ecog_time = ecog_time(1:2:end); 146 | else 147 | %% Power params - TV! 148 | M_ecog = 12; 149 | R_ecog = 10; 150 | eta1 = 1.; 151 | eta2 = 1e-4; 152 | beta_ecog = 300; 153 | ecog_data = ecog_power_post; 154 | ecog_time = ecog_power_time; 155 | end 156 | 157 | ecog_data = ecog_data(:, ecog_time <= MotionTime(end) & ecog_time >= MotionTime(1)); 158 | ecog_data = standardize(ecog_data); 159 | ecog_time = ecog_time(ecog_time <= MotionTime(end) & ecog_time >= MotionTime(1)); 160 | 161 | 162 | 163 | % [ecog_lambda, ecog_A, ecog_B, ecog_C, ecog_cost, ~,~, ecog_rmse, ecog_W] = ... 164 | % tensor_DMD_ALS_aux(ecog_data, M_ecog, R_ecog, ... 165 | % 'center', center, ... 166 | % 'eta1', eta1, ... 167 | % 'eta2', eta2, ... 168 | % 'beta', beta_ecog, ... 169 | % 'regularization', 'TV', ... 170 | % 'rtol', 1e-4,... 171 | % 'max_iter', max_iter); 172 | 173 | % [ecog_lambda, ecog_A, ecog_B, ecog_C, ecog_cost, ~,~, ecog_rmse] = ... 174 | % tensor_DMD_alt_min(ecog_data, M_ecog, R_ecog, ... 175 | % 'center', center, ... 176 | % 'eta', eta1, ... 177 | % 'beta', beta_ecog, ... 178 | % 'regularization', 'TV', ... 179 | % 'rtol', 1e-5,... 180 | % 'verbosity', 2,... 181 | % 'max_iter', max_iter); 182 | 183 | 184 | [ecog_lambda, ecog_A, ecog_B, ecog_C, ecog_cost, ~,~, ecog_rmse] = ... 185 | TVART_alt_min(ecog_data, M_ecog, R_ecog, ... 186 | 'center', center, ... 187 | 'eta', eta1, ... 188 | 'beta', beta_ecog, ... 189 | 'regularization', regularization, ... 190 | 'verbosity', 2,... 191 | 'max_iter', max_iter); 192 | 193 | 194 | %'init', struct('A', 195 | %ecog_A, 'B', ecog_B, 'C', 196 | %ecog_C)); 197 | 198 | [ecog_lambda, ecog_A, ecog_B, ecog_C, ecog_cost, ~,~, ecog_rmse] = ... 
199 | TVART_alt_min(ecog_data, M_ecog, R_ecog, ... 200 | 'center', center, ... 201 | 'eta', eta1, ... 202 | 'beta', beta_ecog, ... 203 | 'regularization', regularization, ... 204 | 'verbosity', 2,... 205 | 'max_iter', max_iter, ... 206 | 'rtol', 1e-4,... 207 | 'init', struct('A', ecog_A, 'B', ecog_B, 'C', ... 208 | ecog_C)); 209 | 210 | ecog_W = ecog_C; 211 | ecog_A_r = ecog_A; ecog_B_r = ecog_B; ecog_C_r = ecog_C; ... 212 | ecog_lambda_r = ecog_lambda; 213 | % [ecog_lambda_r, ecog_A_r, ecog_B_r, ecog_C_r, ecog_W_r] = ... 214 | % rebalance_2(ecog_A, ecog_B, ecog_C, ecog_W, 1); 215 | % [ecog_lambda_r, ecog_A_r, ecog_B_r, ecog_C_r] = ... 216 | % rebalance(ecog_A, ecog_B, ecog_C, 1); 217 | % [ecog_lambda_r, ecog_A_r, ecog_B_r, ecog_C_r] = ... 218 | % reorder_components(ecog_lambda_r, ecog_A_r, ecog_B_r, ecog_C_r); 219 | % ecog_C_r = ecog_C_r * diag(ecog_lambda_r); 220 | 221 | % s = struct('A', ecog_A, 'B', ecog_B, 'C', ecog_C, 'lambda', ecog_lambda, ... 222 | % 'W', ecog_W); 223 | 224 | % [ecog_lambda, ecog_A, ecog_B, ecog_C, ecog_cost, ~,~, ecog_rmse, ecog_W] = ... 225 | % tensor_DMD_ALS_aux(ecog_data, M_ecog, R_ecog, ... 226 | % 'center', 0, ... 227 | % 'eta1', eta1, ... 228 | % 'eta2', eta2, ... 229 | % 'beta', beta_ecog, ... 230 | % 'regularization', 'TV', ... 231 | % 'max_iter', max_iter, ... 232 | % 'init', s); 233 | 234 | % [ecog_lambda, ecog_A, ecog_B, ecog_C, ecog_cost, ~,~, ecog_rmse] = ... 235 | % tensor_DMD_prox_grad(ecog_data, M_ecog, R_ecog, ... 236 | % 'center', 0, ... 237 | % 'eta', eta1, ... 238 | % 'step_size', 1e-6, ... 239 | % 'beta', beta_ecog, ... 240 | % 'iter_disp', 1, ... 241 | % 'max_iter', 1000, ... 242 | % 'init', s ); 243 | 244 | 245 | % eta2 = 1e-3; 246 | % [ecog_lambda, ecog_A, ecog_B, ecog_C, ecog_cost, ~,~, ecog_rmse] = ... 247 | % tensor_DMD_ALS_smooth(ecog_post, M_ecog, R_ecog, ... 248 | % 'center', 0, ... 249 | % 'eta1', eta1, ... 250 | % 'eta2', eta2, ... 251 | % 'beta', 5, ... 
252 | % 'max_iter', 300); 253 | 254 | 255 | % figure; 256 | % ax1 = subplot(2,1,1); 257 | % plot(MotionTime(1:M:end-M), sum(C.^2, 2)) 258 | % title('MOCAP extracted components') 259 | % axis tight 260 | % ax2 = subplot(2,1,2); 261 | % plot(ecog_time(1:M_ecog:end-M_ecog), sum(ecog_C.^2, 2)) 262 | % title('ECoG extracted components') 263 | % xlabel('time (s)') 264 | % axis tight 265 | % linkaxes([ax1 ax2], 'x') 266 | % xlim([100, 300]) 267 | 268 | ts1 = timeseries(X, MotionTime); 269 | ts3 = timeseries(ecog_C, ecog_time(1:M_ecog:end-M_ecog)); 270 | ts2 = timeseries(C, MotionTime(1:M:end-M)); 271 | [ts2,ts3] = synchronize(ts2, ts3, 'union'); 272 | % tmp=corrcoef([ts2.Data, ts3.Data]); 273 | % disp(tmp); 274 | 275 | % figure; 276 | % semilogy(ecog_cost) 277 | % title('ECoG convergence') 278 | 279 | % figure; 280 | % semilogy(cost) 281 | % title('MOCAP convergence') 282 | 283 | clusters_mocap = kmeans(X', 2); 284 | clusters_mocap_modes = kmedoids(C_r, 2); 285 | clusters_ecog_modes = kmedoids(ecog_C_r, 2); 286 | 287 | % figure; 288 | % ax1 = subplot(3,1,1); 289 | % plot(MotionTime(1:M:end-M), C_r) 290 | % title('MOCAP extracted components') 291 | % axis tight; 292 | % ax2 = subplot(3,1,2); 293 | % plot(MotionTime(1:M:end-M), clusters, 'o') 294 | % ylim([0,4]); 295 | % title('MOCAP clusters') 296 | % ax3 = subplot(3,1,3); 297 | % plot(ecog_time(1:M_ecog:end-M_ecog), ecog_clusters, 'o') 298 | % ylim([0,4]); 299 | % title('ECoG clusters') 300 | % %plot(transitions(:,1)/M, transitions(:,2)+1, 'k') 301 | % linkaxes([ax1 ax2 ax3], 'x') 302 | % xlim([120, 320]) 303 | 304 | 305 | %% correlate ecog signals against movement` 306 | 307 | lf_power = sum(ecog_power_post(1:2:end,:),1); 308 | hf_power = sum(ecog_power_post(2:2:end,:),1); 309 | 310 | % [~, locs] = findpeaks(X(6,:), 'MinPeakWidth', 20, ... 
311 | % ' MinPeakProminence', 3); 312 | 313 | idx_changes = findchangepts(X, 'MaxNumChanges', 120); 314 | 315 | 316 | %% Save output 317 | save(save_file); 318 | 319 | 320 | %% Now plotting 321 | 322 | 323 | figure; 324 | ax1 = subplot(5,1,1:2); 325 | plot(MotionTime, X, 'color', [0,0,0] + 0.4, 'linewidth', 0.5); 326 | vline(MotionTime(idx_changes), 'k--') 327 | ylim([-8 8]) 328 | title('MOCAP data', 'fontweight','bold') 329 | %grid on; grid minor 330 | set(gca, 'xticklabels', {}) 331 | % ax2 = subplot(8,1,7); 332 | % %plot(MotionTime, clusters_mocap, 'o'); 333 | % imagesc([MotionTime(1), MotionTime(end)], [0,1], clusters_mocap'); 334 | % colormap(gca, 'winter') 335 | % ylabel('MOCAP', 'fontweight','bold') 336 | % %ylim([0.9, 2.1]) 337 | % %grid on; grid minor 338 | % %set(gca, 'xtick', []) 339 | % set(gca, 'ytick', []) 340 | % set(gca, 'xticklabels', {}) 341 | ax2 = subplot(5,1,3); 342 | %plot(MotionTime(1:M:end-M), clusters_mocap_modes, 'o') 343 | %plot(ecog_time(1:M_ecog:end-M_ecog), clusters_ecog_modes,'o'); 344 | imagesc([ecog_time(1), ecog_time(end-M_ecog)], [0 1], clusters_ecog_modes'); 345 | colormap(gca, 'summer') 346 | %set(gca, 'xtick', []) 347 | set(gca, 'ytick', []) 348 | set(gca, 'xticklabels', {}) 349 | title('ECoG clusters', 'fontweight','bold') 350 | ax3 = subplot(5,1,4); 351 | plot(ecog_power_time, lf_power) 352 | hold on; 353 | plot(ecog_power_time, hf_power,'r-') 354 | grid on; grid minor 355 | legend({'low freq', 'high freq'}) 356 | title('ECoG band power', 'fontweight', 'bold') 357 | axis tight 358 | set(gca, 'xticklabels', {}) 359 | ax5 = subplot(5,1,5); 360 | %tmp = ecog_C_r - repmat(mean(ecog_C_r,1), size(ecog_C_r,1), 1); 361 | tmp = ecog_C_r; 362 | %tmp = ecog_C_r(:,2) - ecog_C_r(:,1); 363 | plot(ecog_time(1:M_ecog:(length(tmp)*M_ecog)), tmp); 364 | %vline(MotionTime(clusters_mocap == 1), 'k:'); 365 | %vline(MotionTime(locs), 'k:'); 366 | title('Temporal modes', 'fontweight', 'bold') 367 | axis tight 368 | %ylim([0.9, 2.1]) 369 | %grid on; grid minor 370 | xlabel('Time (s)') 371 | linkaxes([ax1 ax2 ax3 ax5], 'x') 372 | xlim([700 800]) 373 | set(gcf,'renderer','Painters') 374 | set(gcf, 'Color', 'w', 'Position', [100 200 600 700]); 375 | set(gcf, 'PaperUnits', 'inches', ... 
376 | 'PaperPosition', [0 0 6.5 7.5], 'PaperPositionMode', 'manual'); 377 | print('-depsc2', '-loose', [figdir 'neurotycho_clusters.eps'], '-r300'); 378 | 379 | 380 | if ~voltage 381 | figure; 382 | ax1 = subplot(4,1,1); 383 | plot(MotionTime, X) 384 | grid on; grid minor 385 | %vline(MotionTime(clusters_mocap == 1), 'k:'); 386 | title('MOCAP data', 'fontweight', 'normal') 387 | axis tight; 388 | ax2 = subplot(4,1,2); 389 | plot(MotionTime(1:M:end-M), C_r) 390 | grid on; grid minor 391 | title('MOCAP temporal modes', 'fontweight', 'normal') 392 | axis tight 393 | ax4 = subplot(4,1,3); 394 | plot(ecog_power_time, lf_power) 395 | hold on; 396 | plot(ecog_power_time, hf_power,'r-') 397 | grid on; grid minor 398 | legend({'low freq', 'high freq'}) 399 | title('ECoG band power', 'fontweight', 'normal') 400 | axis tight 401 | ax3 = subplot(4,1,4); 402 | %tmp = ecog_C_r - repmat(mean(ecog_C_r,1), size(ecog_C_r,1), 1); 403 | tmp = ecog_C_r; 404 | %tmp = ecog_C_r(:,2) - ecog_C_r(:,1); 405 | plot(ecog_time(1:M_ecog:(length(tmp)*M_ecog)), tmp); 406 | grid on; grid minor 407 | %vline(MotionTime(clusters_mocap == 1), 'k:'); 408 | %vline(MotionTime(locs), 'k:'); 409 | title('ECoG power: temporal modes', 'fontweight', 'normal') 410 | xlabel('time (s)') 411 | legend 412 | axis tight 413 | linkaxes([ax1 ax2 ax3 ax4], 'x') 414 | xlim([t0, tf]) 415 | %xlim([200, 300]) 416 | set(gcf,'renderer','Painters') 417 | set(gcf, 'Color', 'w', 'Position', [100 500 800 980]); 418 | set(gcf, 'PaperUnits', 'inches', ... 419 | 'PaperPosition', [0 0 7.5 10], ... 420 | 'PaperPositionMode', 'manual'); 421 | print('-depsc2', '-loose', [figdir 'neurotycho_summary.eps'], '-r300'); 422 | else 423 | figure; 424 | ax1 = subplot(4,1,1); 425 | plot(MotionTime, X) 426 | grid on; grid minor 427 | %vline(MotionTime(clusters_mocap == 1), 'k:'); 428 | title('MOCAP data') 429 | axis tight; 430 | ax2 = subplot(4,1,2); 431 | plot(MotionTime(1:M:end-M), C_r) 432 | grid on; grid minor 433 | title('MOCAP temporal modes') 434 | axis tight 435 | ax4 = subplot(4,1,3); 436 | plot(ecog_power_time, lf_power) 437 | hold on; 438 | plot(ecog_power_time, hf_power,'r-') 439 | grid on; grid minor 440 | legend({'low freq', 'high freq'}) 441 | title('ECoG low / high power') 442 | axis tight 443 | ax3 = subplot(4,1,4); 444 | %tmp = ecog_C_r - repmat(mean(ecog_C_r,1), size(ecog_C_r,1), 1); 445 | tmp = ecog_C_r; 446 | %tmp = ecog_C_r(:,2) - ecog_C_r(:,1); 447 | plot(ecog_time(1:M_ecog:(length(tmp)*M_ecog)), tmp); 448 | grid on; grid minor 449 | %vline(MotionTime(clusters_mocap == 1), 'k:'); 450 | %vline(MotionTime(locs), 'k:'); 451 | title('ECoG voltage: temporal modes') 452 | xlabel('time (s)') 453 | %legend 454 | axis tight 455 | linkaxes([ax1 ax2 ax3 ax4], 'x') 456 | %xlim([t0, tf]) 457 | xlim([700, 760]) 458 | set(gcf,'renderer','Painters') 459 | set(gcf, 'Color', 'w', 'Position', [0 0 750 1000]); 460 | set(gcf, 'PaperUnits', 'inches', ... 461 | 'PaperPosition', [0 0 7.5 10], ... 
462 | 'PaperPositionMode', 'manual'); 463 | print('-depsc2', '-loose', [figdir 'neurotycho_summary.eps'], '-r200'); 464 | end 465 | 466 | 467 | 468 | 469 | %% Plot modes 470 | if ~voltage 471 | %% Plot modes - LOW/HIGH power 472 | addpath('~/work/MATLAB') 473 | img = imread([figdir 'monkeys/monkey_K1.png']); 474 | load('../data/K1_electrodes.mat') 475 | for mode = 1:R_ecog 476 | figure; 477 | ax = subplot(1,4,1); 478 | image(img); 479 | set(gca,'YTickLabel',[]); 480 | set(gca, 'XtickLabel', []); 481 | hold on 482 | vec = ecog_A_r(:, mode); 483 | scatter(img_electrodes_x, img_electrodes_y, 80, vec(1:2:end), ... 484 | 'filled'); 485 | colormap(flipud(brewermap([], 'PuOr'))) 486 | colorbar 487 | symmetrize_colorbar(); 488 | title('Low freq, left') 489 | ylabel(sprintf('Mode %d', mode), 'fontsize', 20) 490 | axis equal 491 | axis tight 492 | 493 | ax= subplot(1,4,2); 494 | image(img); 495 | set(gca,'YTickLabel',[]); 496 | set(gca, 'XtickLabel', []); 497 | hold on 498 | vec = ecog_A_r(:, mode); 499 | scatter(img_electrodes_x, img_electrodes_y, 80, vec(2:2:end), 'filled'); 500 | colormap(flipud(brewermap([], 'PuOr'))) 501 | colorbar 502 | symmetrize_colorbar(); 503 | title('High freq, left') 504 | axis equal 505 | axis tight 506 | 507 | ax = subplot(1,4,3); 508 | image(img); 509 | set(gca,'YTickLabel',[]); 510 | set(gca, 'XtickLabel', []); 511 | hold on 512 | if center 513 | vec = ecog_B_r(1:end-1, mode); 514 | else 515 | vec = ecog_B_r(:, mode); 516 | end 517 | scatter(img_electrodes_x, img_electrodes_y, 80, vec(1:2:end), 'filled'); 518 | colormap(flipud(brewermap([], 'PuOr'))) 519 | colorbar 520 | symmetrize_colorbar(); 521 | title('Low freq, right') 522 | axis equal 523 | axis tight 524 | 525 | ax = subplot(1,4,4); 526 | image(img); 527 | set(gca,'YTickLabel',[]); 528 | set(gca, 'XtickLabel', []); 529 | hold on 530 | if center 531 | vec = ecog_B_r(1:end-1, mode); 532 | else 533 | vec = ecog_B_r(:, mode); 534 | end 535 | scatter(img_electrodes_x, img_electrodes_y, 80, vec(2:2:end), 'filled'); 536 | colormap(flipud(brewermap([], 'PuOr'))) 537 | colorbar 538 | symmetrize_colorbar(); 539 | title('High freq, right') 540 | axis equal 541 | axis tight 542 | if center 543 | xlabel(sprintf('affine weight = %1.2g, %1.2g%%', ... 544 | ecog_B_r(end, mode), ... 545 | abs(ecog_B_r(end, mode)) / sum(abs(ecog_B_r(end,:))) * 100 ),... 546 | 'fontsize', 14); 547 | end 548 | 549 | set(gcf,'renderer','Painters') 550 | set(gcf, 'Color', 'w') 551 | set(gcf, 'position', [ 105 450 1653 ... 552 | 473]) 553 | set(gcf, 'PaperUnits', 'inches', ... 554 | 'PaperPosition', [0 0 11 3.2], ... 555 | 'PaperPositionMode', 'auto'); 556 | print('-depsc2', '-loose', '-r300', ... 557 | sprintf('%sneurotycho_mode_%d.eps', figdir, mode)); 558 | end 559 | else 560 | %% Plot modes - VOLTAGE 561 | addpath('~/work/MATLAB') 562 | img = imread([figdir 'monkeys/monkey_K1.png']); 563 | load('../data/K1_electrodes.mat') 564 | for mode = 1:R_ecog 565 | figure; 566 | ax = subplot(1,2,1); 567 | image(img); 568 | set(gca,'YTickLabel',[]); 569 | set(gca, 'XtickLabel', []); 570 | hold on 571 | vec = ecog_A_r(:, mode); 572 | scatter(img_electrodes_x, img_electrodes_y, 80, vec, ... 
573 | 'filled'); 574 | colormap(flipud(brewermap([], 'PuOr'))) 575 | colorbar 576 | symmetrize_colorbar(); 577 | title('Left') 578 | ylabel(sprintf('Mode %d', mode), 'fontsize', 20) 579 | axis equal 580 | axis tight 581 | 582 | ax = subplot(1,2,2); 583 | image(img); 584 | set(gca,'YTickLabel',[]); 585 | set(gca, 'XtickLabel', []); 586 | hold on 587 | if center 588 | vec = ecog_B_r(1:end-1, mode); 589 | else 590 | vec = ecog_B_r(:, mode); 591 | end 592 | scatter(img_electrodes_x, img_electrodes_y, 80, vec, 'filled'); 593 | colormap(flipud(brewermap([], 'PuOr'))) 594 | colorbar 595 | symmetrize_colorbar(); 596 | title('Right') 597 | axis equal 598 | axis tight 599 | if center 600 | xlabel(sprintf('affine weight = %1.2g, %1.2g%%', ... 601 | ecog_B_r(end, mode), ... 602 | abs(ecog_B_r(end, mode)) / sum(abs(ecog_B_r(end,:))) * 100 ), ... 603 | 'fontsize', 14); 604 | end 605 | 606 | set(gcf,'renderer','Painters') 607 | set(gcf, 'Color', 'w') 608 | set(gcf, 'position', [ 105 450 830 ... 609 | 473]) 610 | set(gcf, 'PaperUnits', 'inches', ... 611 | 'PaperPosition', [0 0 5.5 3.2], ... 612 | 'PaperPositionMode', 'auto'); 613 | print('-depsc2', '-loose', '-r300', ... 614 | sprintf('%sneurotycho_mode_%d.eps', figdir, mode)); 615 | end 616 | end 617 | -------------------------------------------------------------------------------- /src/run_all_tests_ARHMM-SSM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 25, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import scipy.io\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import pandas as pd\n", 13 | "from ssm_compare import fit_arhmm_and_return_errors" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 26, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "#N_array = np.array([10,14,18,20,30,50,80,100,200,400,1000,2000,4000])\n", 23 | "N_array = np.array([10,14,18,20,30,50,80,100,200, 400, 1000])\n", 24 | "# N_array = np.array([10,100,1000])\n", 25 | "#N_array = np.array([10,14,18,20,30,50,80,100,200])\n", 26 | "Kmax = 4\n", 27 | "num_iters = 1000\n", 28 | "num_restarts = 3\n", 29 | "model = 'ARHMM'\n", 30 | "rank = None\n", 31 | "\n", 32 | "#table_file = \"../data/comparison_output_rank4.csv\"\n", 33 | "#output_file = \"../data/comparison_final_rank4.csv\"\n", 34 | "#output_file = \"../data/comparison_output_slds_multiple_300_30_grow.csv\"\n", 35 | "# output_file = \"../data/comparison_output_2000_arhmm.csv\"\n", 36 | "output_file = \"../data/comparison_output_200_arhmm.csv\"" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 27, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "name": "stdout", 46 | "output_type": "stream", 47 | "text": [ 48 | "Empty DataFrame\n", 49 | "Columns: [N, window, model, err_inf, err_2, err_fro, model_MSE, true_MSE]\n", 50 | "Index: []\n", 51 | "Index(['N', 'window', 'model', 'err_inf', 'err_2', 'err_fro', 'model_MSE',\n", 52 | " 'true_MSE'],\n", 53 | " dtype='object')\n", 54 | "[]\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "# error_table = pd.read_csv(table_file)\n", 60 | "\n", 61 | "error_table = pd.DataFrame(columns = ['N', 'window', 'model', 'err_inf', 'err_2', 'err_fro', 'model_MSE', \\\n", 62 | " 'true_MSE'])\n", 63 | "\n", 64 | "print(error_table.head())\n", 65 | "print(error_table.columns)\n", 66 | "print(error_table['model'].unique())" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | 
"metadata": {}, 72 | "source": [ 73 | "## Fit SLDS with rank r = 4 & r = 6" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "scrolled": true 81 | }, 82 | "outputs": [ 83 | { 84 | "name": "stdout", 85 | "output_type": "stream", 86 | "text": [ 87 | "-----------------------------\n", 88 | "N=10, STEPS: 1000 \n", 89 | "restart 1 / 3\n" 90 | ] 91 | }, 92 | { 93 | "data": { 94 | "application/vnd.jupyter.widget-view+json": { 95 | "model_id": "d596704390804c9187d75853c416ece3", 96 | "version_major": 2, 97 | "version_minor": 0 98 | }, 99 | "text/plain": [ 100 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 101 | ] 102 | }, 103 | "metadata": {}, 104 | "output_type": "display_data" 105 | }, 106 | { 107 | "name": "stdout", 108 | "output_type": "stream", 109 | "text": [ 110 | "\n", 111 | "restart 2 / 3\n" 112 | ] 113 | }, 114 | { 115 | "data": { 116 | "application/vnd.jupyter.widget-view+json": { 117 | "model_id": "5bd44f7d464e41d3bb4a861b8a842b6d", 118 | "version_major": 2, 119 | "version_minor": 0 120 | }, 121 | "text/plain": [ 122 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 123 | ] 124 | }, 125 | "metadata": {}, 126 | "output_type": "display_data" 127 | }, 128 | { 129 | "name": "stdout", 130 | "output_type": "stream", 131 | "text": [ 132 | "\n", 133 | "restart 3 / 3\n" 134 | ] 135 | }, 136 | { 137 | "data": { 138 | "application/vnd.jupyter.widget-view+json": { 139 | "model_id": "9ccc44350c044023a95defaf757d0f47", 140 | "version_major": 2, 141 | "version_minor": 0 142 | }, 143 | "text/plain": [ 144 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 145 | ] 146 | }, 147 | "metadata": {}, 148 | "output_type": "display_data" 149 | }, 150 | { 151 | "name": "stdout", 152 | "output_type": "stream", 153 | "text": [ 154 | "\n", 155 | "predicted states:\n", 156 | "[3 3 3 1 0 3 1 1 0 1 3 3 1 3 3 1 0 1 1 1 1 2 1 2 2 3 2 0 0 1 0 1 0 3 1 1 1\n", 157 | " 1 0 1 1 1 0 3 2 3 1 1 3 1 0 2 0 3 3 1 1 1 2 3 2 2 0 3 3 1 0 2 2 1 1 1 1 0\n", 158 | " 1 1 1 2 2 1 3 1 1 0 2 2 2 2 2 2 0 0 1 3 0 1 1 2 0 2 1 1 1 3 0 0 1 0 2 2 1\n", 159 | " 1 1 1 1 1 1 1 0 3 2 2 2 2 2 1 1 2 2 1 1 0 1 0 2 0 3 3 1 1 1 2 2 1 1 1 1 0\n", 160 | " 1 3 3 3 1 1 1 2 1 1 1 1 1 1 1 3 0 3 1 0 1 1 1 0 1 1 2 1 1 3 3 3 3 3 3 1 1\n", 161 | " 1 1 1 2 1 2 0 2 2 1 1 2 1 1 0 0]\n", 162 | "N = 10 : err_inf = 0.809909, err_2 = 1.942150, err_fro = 2.629339, err_mse = 0.025584\n", 163 | "restart 1 / 3\n" 164 | ] 165 | }, 166 | { 167 | "data": { 168 | "application/vnd.jupyter.widget-view+json": { 169 | "model_id": "4eaeba53d412402ba59625f9cd5f3169", 170 | "version_major": 2, 171 | "version_minor": 0 172 | }, 173 | "text/plain": [ 174 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 175 | ] 176 | }, 177 | "metadata": {}, 178 | "output_type": "display_data" 179 | }, 180 | { 181 | "name": "stdout", 182 | "output_type": "stream", 183 | "text": [ 184 | "\n", 185 | "restart 2 / 3\n" 186 | ] 187 | }, 188 | { 189 | "data": { 190 | "application/vnd.jupyter.widget-view+json": { 191 | "model_id": "b5c406763f0b4f0cb8eed4a449946bd3", 192 | "version_major": 2, 193 | "version_minor": 0 194 | }, 195 | "text/plain": [ 196 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 197 | ] 198 | }, 199 | "metadata": {}, 200 | "output_type": "display_data" 201 | }, 202 | { 203 | "name": "stdout", 204 | "output_type": "stream", 205 | "text": [ 
206 | "\n", 207 | "restart 3 / 3\n" 208 | ] 209 | }, 210 | { 211 | "data": { 212 | "application/vnd.jupyter.widget-view+json": { 213 | "model_id": "01e92c3e14b64f3d850b2a25f37aa3cc", 214 | "version_major": 2, 215 | "version_minor": 0 216 | }, 217 | "text/plain": [ 218 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 219 | ] 220 | }, 221 | "metadata": {}, 222 | "output_type": "display_data" 223 | }, 224 | { 225 | "name": "stdout", 226 | "output_type": "stream", 227 | "text": [ 228 | "\n", 229 | "predicted states:\n", 230 | "[2 2 2 2 2 2 2 3 0 0 2 2 0 0 0 2 2 1 0 2 2 2 2 0 0 0 0 1 1 1 1 1 3 3 3 2 2\n", 231 | " 2 2 3 2 2 0 0 0 1 0 2 3 3 3 3 0 2 2 0 2 2 3 3 0 0 2 2 2 3 3 3 3 2 2 2 2 2\n", 232 | " 2 3 3 3 3 1 1 1 0 1 1 2 2 2 2 2 2 1 1 3 3 2 2 2 2 3 3 3 1 0 0 0 0 0 1 0 1\n", 233 | " 1 1 1 1 2 2 2 1 2 2 2 2 2 2 1 1 0 0 3 3 1 1 3 0 2 2 2 0 3 3 3 3 1 0 1 3 3\n", 234 | " 2 3 3 2 0 1 2 0 1 0 0 0 3 2 2 2 2 2 2 2 2 2 2 2 2 2 0 1 2 2 2 2 0 2 1 2 2\n", 235 | " 2 2 3 3 0 1 2 1 2 3 2 0 0 3 3 3]\n", 236 | "N = 10 : err_inf = 0.778965, err_2 = 1.938223, err_fro = 2.813311, err_mse = 0.025525\n", 237 | "restart 1 / 3\n" 238 | ] 239 | }, 240 | { 241 | "data": { 242 | "application/vnd.jupyter.widget-view+json": { 243 | "model_id": "944e3155b1884ac4bce3399461d28afb", 244 | "version_major": 2, 245 | "version_minor": 0 246 | }, 247 | "text/plain": [ 248 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 249 | ] 250 | }, 251 | "metadata": {}, 252 | "output_type": "display_data" 253 | }, 254 | { 255 | "name": "stdout", 256 | "output_type": "stream", 257 | "text": [ 258 | "\n", 259 | "restart 2 / 3\n" 260 | ] 261 | }, 262 | { 263 | "data": { 264 | "application/vnd.jupyter.widget-view+json": { 265 | "model_id": "451262a99794436784f364866ebcc25e", 266 | "version_major": 2, 267 | "version_minor": 0 268 | }, 269 | "text/plain": [ 270 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 271 | ] 272 | }, 273 | "metadata": {}, 274 | "output_type": "display_data" 275 | }, 276 | { 277 | "name": "stdout", 278 | "output_type": "stream", 279 | "text": [ 280 | "\n", 281 | "restart 3 / 3\n" 282 | ] 283 | }, 284 | { 285 | "data": { 286 | "application/vnd.jupyter.widget-view+json": { 287 | "model_id": "e1f88c1e8ad945e6914aeacc237764f4", 288 | "version_major": 2, 289 | "version_minor": 0 290 | }, 291 | "text/plain": [ 292 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 293 | ] 294 | }, 295 | "metadata": {}, 296 | "output_type": "display_data" 297 | }, 298 | { 299 | "name": "stdout", 300 | "output_type": "stream", 301 | "text": [ 302 | "\n", 303 | "predicted states:\n", 304 | "[1 1 0 0 0 0 0 2 3 3 3 2 2 0 1 1 1 1 0 0 3 2 0 0 0 0 0 1 0 3 2 1 1 2 0 2 2\n", 305 | " 1 1 1 0 3 2 2 0 0 2 1 0 3 3 2 0 3 0 2 2 2 1 0 0 1 1 1 1 2 2 1 1 1 0 1 0 0\n", 306 | " 0 0 0 0 3 3 1 1 0 3 2 2 2 1 1 1 3 3 2 2 2 1 0 0 3 3 2 3 3 2 1 0 1 0 0 0 0\n", 307 | " 2 2 0 0 0 0 3 2 0 3 2 3 0 1 1 1 2 0 0 0 0 0 0 0 3 0 0 3 1 2 3 2 2 0 2 1 1\n", 308 | " 1 0 3 2 0 0 0 2 2 0 1 1 2 1 1 1 1 0 0 3 1 0 3 1 1 1 1 3 2 2 2 2 3 3 3 2 2\n", 309 | " 0 0 0 0 0 0 0 3 2 0 0 0 0 2 3 0]\n", 310 | "N = 10 : err_inf = 0.755064, err_2 = 1.815993, err_fro = 2.571793, err_mse = 0.026484\n", 311 | "-----------------------------\n", 312 | "N=14, STEPS: 1000 \n", 313 | "restart 1 / 3\n" 314 | ] 315 | }, 316 | { 317 | "data": { 318 | "application/vnd.jupyter.widget-view+json": { 319 | "model_id": "67fa0fcd43d14a288fb0e953f4a21f45", 320 | 
"version_major": 2, 321 | "version_minor": 0 322 | }, 323 | "text/plain": [ 324 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 325 | ] 326 | }, 327 | "metadata": {}, 328 | "output_type": "display_data" 329 | }, 330 | { 331 | "name": "stdout", 332 | "output_type": "stream", 333 | "text": [ 334 | "\n", 335 | "restart 2 / 3\n" 336 | ] 337 | }, 338 | { 339 | "data": { 340 | "application/vnd.jupyter.widget-view+json": { 341 | "model_id": "46e84d74afab454cb31ce45f02eef958", 342 | "version_major": 2, 343 | "version_minor": 0 344 | }, 345 | "text/plain": [ 346 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 347 | ] 348 | }, 349 | "metadata": {}, 350 | "output_type": "display_data" 351 | }, 352 | { 353 | "name": "stdout", 354 | "output_type": "stream", 355 | "text": [ 356 | "\n", 357 | "restart 3 / 3\n" 358 | ] 359 | }, 360 | { 361 | "data": { 362 | "application/vnd.jupyter.widget-view+json": { 363 | "model_id": "39cc91b2618149d59a68cfad78db1653", 364 | "version_major": 2, 365 | "version_minor": 0 366 | }, 367 | "text/plain": [ 368 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 369 | ] 370 | }, 371 | "metadata": {}, 372 | "output_type": "display_data" 373 | }, 374 | { 375 | "name": "stdout", 376 | "output_type": "stream", 377 | "text": [ 378 | "\n", 379 | "predicted states:\n", 380 | "[0 0 0 1 3 3 1 1 2 0 1 1 2 1 2 3 0 1 3 3 0 2 1 3 3 2 3 0 0 3 1 0 0 0 0 2 1\n", 381 | " 3 0 0 1 0 0 0 1 3 2 0 1 2 0 2 0 2 2 3 3 3 3 0 3 1 0 0 0 3 1 0 1 1 1 0 1 3\n", 382 | " 2 1 0 3 3 1 2 2 1 1 1 0 0 0 3 3 1 0 0 1 3 0 2 1 0 1 0 0 0 2 0 0 0 2 2 0 0\n", 383 | " 0 0 2 2 3 0 0 3 2 2 0 3 1 2 3 3 3 3 3 3 0 1 1 0 0 0 2 3 0 0 1 0 3 0 1 0 0\n", 384 | " 3 2 0 0 3 2 0 2 3 2 2 2 3 3 0 0 3 2 3 0 1 0 2 1 1 0 1 0 0 2 3 1 1 1 0 0 3\n", 385 | " 0 3 3 3 2 3 2 2 3 1 1 3 2 0 0 0]\n", 386 | "N = 14 : err_inf = 0.699468, err_2 = 1.923152, err_fro = 2.995038, err_mse = 0.030382\n", 387 | "restart 1 / 3\n" 388 | ] 389 | }, 390 | { 391 | "data": { 392 | "application/vnd.jupyter.widget-view+json": { 393 | "model_id": "e4953c6b4210427dbc193910d9d70c56", 394 | "version_major": 2, 395 | "version_minor": 0 396 | }, 397 | "text/plain": [ 398 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 399 | ] 400 | }, 401 | "metadata": {}, 402 | "output_type": "display_data" 403 | }, 404 | { 405 | "name": "stdout", 406 | "output_type": "stream", 407 | "text": [ 408 | "\n", 409 | "restart 2 / 3\n" 410 | ] 411 | }, 412 | { 413 | "data": { 414 | "application/vnd.jupyter.widget-view+json": { 415 | "model_id": "992149c83855418a893ad6c8eac68bdd", 416 | "version_major": 2, 417 | "version_minor": 0 418 | }, 419 | "text/plain": [ 420 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 421 | ] 422 | }, 423 | "metadata": {}, 424 | "output_type": "display_data" 425 | }, 426 | { 427 | "name": "stdout", 428 | "output_type": "stream", 429 | "text": [ 430 | "\n", 431 | "restart 3 / 3\n" 432 | ] 433 | }, 434 | { 435 | "data": { 436 | "application/vnd.jupyter.widget-view+json": { 437 | "model_id": "61a5bb65152144db8e0d3642a66361d4", 438 | "version_major": 2, 439 | "version_minor": 0 440 | }, 441 | "text/plain": [ 442 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 443 | ] 444 | }, 445 | "metadata": {}, 446 | "output_type": "display_data" 447 | }, 448 | { 449 | "name": "stdout", 450 | "output_type": "stream", 451 | "text": [ 452 | 
"\n", 453 | "predicted states:\n", 454 | "[3 3 1 2 0 0 1 0 0 2 1 1 1 0 1 1 1 3 1 2 3 2 3 3 2 1 1 1 1 2 3 1 0 0 3 1 0\n", 455 | " 0 0 0 2 3 0 0 3 3 1 1 2 3 3 0 1 3 0 3 2 2 2 2 3 1 2 2 2 3 1 1 3 0 1 0 3 3\n", 456 | " 3 3 3 1 0 2 0 1 0 1 0 0 2 3 1 1 2 1 0 0 1 2 1 2 1 1 0 1 0 3 1 1 3 3 1 1 2\n", 457 | " 0 2 0 1 0 0 0 0 0 0 0 0 0 1 1 3 1 2 2 0 1 0 0 1 1 3 3 0 1 2 0 1 3 0 2 2 1\n", 458 | " 0 1 0 0 2 2 0 0 1 1 0 0 1 3 3 1 0 3 0 3 3 1 2 1 2 2 1 1 3 2 1 3 2 1 1 3 3\n", 459 | " 3 1 3 1 1 1 2 1 1 1 1 1 2 2 2 2]\n", 460 | "N = 14 : err_inf = 0.758821, err_2 = 2.097091, err_fro = 3.300425, err_mse = 0.033146\n", 461 | "restart 1 / 3\n" 462 | ] 463 | }, 464 | { 465 | "data": { 466 | "application/vnd.jupyter.widget-view+json": { 467 | "model_id": "16f281b39ce649b0ab2bbf10ffa5e17b", 468 | "version_major": 2, 469 | "version_minor": 0 470 | }, 471 | "text/plain": [ 472 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 473 | ] 474 | }, 475 | "metadata": {}, 476 | "output_type": "display_data" 477 | }, 478 | { 479 | "name": "stdout", 480 | "output_type": "stream", 481 | "text": [ 482 | "\n", 483 | "restart 2 / 3\n" 484 | ] 485 | }, 486 | { 487 | "data": { 488 | "application/vnd.jupyter.widget-view+json": { 489 | "model_id": "e28c2bb3205640ea8df5ccb9d1798311", 490 | "version_major": 2, 491 | "version_minor": 0 492 | }, 493 | "text/plain": [ 494 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 495 | ] 496 | }, 497 | "metadata": {}, 498 | "output_type": "display_data" 499 | }, 500 | { 501 | "name": "stdout", 502 | "output_type": "stream", 503 | "text": [ 504 | "\n", 505 | "restart 3 / 3\n" 506 | ] 507 | }, 508 | { 509 | "data": { 510 | "application/vnd.jupyter.widget-view+json": { 511 | "model_id": "a3fc7976462e44d2b54fb24433a17d40", 512 | "version_major": 2, 513 | "version_minor": 0 514 | }, 515 | "text/plain": [ 516 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 517 | ] 518 | }, 519 | "metadata": {}, 520 | "output_type": "display_data" 521 | }, 522 | { 523 | "name": "stdout", 524 | "output_type": "stream", 525 | "text": [ 526 | "\n", 527 | "predicted states:\n", 528 | "[2 1 0 0 1 3 2 2 0 3 3 1 2 2 1 3 1 3 1 1 1 1 1 1 0 2 0 2 2 1 3 3 0 2 1 3 2\n", 529 | " 0 2 2 0 0 0 1 2 2 2 3 0 2 2 3 3 3 2 2 1 3 1 3 3 3 0 1 0 2 1 0 0 0 3 3 0 1\n", 530 | " 0 2 2 1 2 3 2 3 2 3 3 3 2 0 2 1 0 2 3 1 3 0 3 3 3 3 0 0 2 2 1 2 2 3 3 0 0\n", 531 | " 0 0 0 3 3 2 0 3 3 0 0 3 1 3 3 2 3 3 3 2 2 1 0 2 1 3 1 1 2 3 0 0 0 1 1 3 3\n", 532 | " 3 1 0 0 3 1 0 2 1 2 3 3 1 2 0 1 0 1 0 0 0 2 3 3 3 1 3 3 3 3 3 3 1 3 2 3 0\n", 533 | " 0 0 1 3 3 2 0 3 1 3 3 0 1 1 2 3]\n", 534 | "N = 14 : err_inf = 0.777233, err_2 = 2.076351, err_fro = 3.297718, err_mse = 0.034824\n", 535 | "-----------------------------\n", 536 | "N=18, STEPS: 1000 \n", 537 | "restart 1 / 3\n" 538 | ] 539 | }, 540 | { 541 | "data": { 542 | "application/vnd.jupyter.widget-view+json": { 543 | "model_id": "138534be6da64a9b8cb15cc12dd0fe9c", 544 | "version_major": 2, 545 | "version_minor": 0 546 | }, 547 | "text/plain": [ 548 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 549 | ] 550 | }, 551 | "metadata": {}, 552 | "output_type": "display_data" 553 | }, 554 | { 555 | "name": "stdout", 556 | "output_type": "stream", 557 | "text": [ 558 | "\n", 559 | "restart 2 / 3\n" 560 | ] 561 | }, 562 | { 563 | "data": { 564 | "application/vnd.jupyter.widget-view+json": { 565 | "model_id": "0416c93527ba45f783f5bccc9d432ea1", 566 | 
"version_major": 2, 567 | "version_minor": 0 568 | }, 569 | "text/plain": [ 570 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 571 | ] 572 | }, 573 | "metadata": {}, 574 | "output_type": "display_data" 575 | }, 576 | { 577 | "name": "stdout", 578 | "output_type": "stream", 579 | "text": [ 580 | "\n", 581 | "restart 3 / 3\n" 582 | ] 583 | }, 584 | { 585 | "data": { 586 | "application/vnd.jupyter.widget-view+json": { 587 | "model_id": "334ea5587bd54299bb1f17fee6b34861", 588 | "version_major": 2, 589 | "version_minor": 0 590 | }, 591 | "text/plain": [ 592 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 593 | ] 594 | }, 595 | "metadata": {}, 596 | "output_type": "display_data" 597 | }, 598 | { 599 | "name": "stdout", 600 | "output_type": "stream", 601 | "text": [ 602 | "\n", 603 | "predicted states:\n", 604 | "[1 0 3 2 1 3 3 3 0 3 3 0 0 2 0 0 1 0 1 1 1 1 3 3 3 1 0 3 2 1 3 2 1 0 1 1 0\n", 605 | " 3 3 2 0 1 2 2 1 2 1 3 1 0 0 1 1 3 1 1 3 3 2 0 1 3 2 0 2 3 1 3 3 3 3 2 1 0\n", 606 | " 3 2 1 3 0 2 1 3 2 3 0 2 1 1 2 0 0 3 2 2 0 0 1 0 2 2 2 3 0 1 2 2 2 2 1 0 2\n", 607 | " 2 1 2 2 3 2 1 0 1 3 1 1 3 2 2 2 2 3 2 3 0 2 3 0 0 1 2 1 3 0 3 3 0 0 1 1 0\n", 608 | " 3 1 2 2 0 2 2 0 3 1 3 1 2 2 3 0 2 2 2 1 3 3 3 0 1 0 2 0 1 0 0 2 2 3 1 1 0\n", 609 | " 3 2 0 3 2 0 0 1 1 0 0 0 0 1 3 3]\n", 610 | "N = 18 : err_inf = 0.803452, err_2 = 2.415605, err_fro = 4.047948, err_mse = 0.045745\n", 611 | "restart 1 / 3\n" 612 | ] 613 | }, 614 | { 615 | "data": { 616 | "application/vnd.jupyter.widget-view+json": { 617 | "model_id": "da0be8e12a524dac83401712c9deed7e", 618 | "version_major": 2, 619 | "version_minor": 0 620 | }, 621 | "text/plain": [ 622 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 623 | ] 624 | }, 625 | "metadata": {}, 626 | "output_type": "display_data" 627 | }, 628 | { 629 | "name": "stdout", 630 | "output_type": "stream", 631 | "text": [ 632 | "\n", 633 | "restart 2 / 3\n" 634 | ] 635 | }, 636 | { 637 | "data": { 638 | "application/vnd.jupyter.widget-view+json": { 639 | "model_id": "d94c496fe4ba4339b384966c35caf4c3", 640 | "version_major": 2, 641 | "version_minor": 0 642 | }, 643 | "text/plain": [ 644 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 645 | ] 646 | }, 647 | "metadata": {}, 648 | "output_type": "display_data" 649 | }, 650 | { 651 | "name": "stdout", 652 | "output_type": "stream", 653 | "text": [ 654 | "\n", 655 | "restart 3 / 3\n" 656 | ] 657 | }, 658 | { 659 | "data": { 660 | "application/vnd.jupyter.widget-view+json": { 661 | "model_id": "887ee7159a9842068841440788adb4f8", 662 | "version_major": 2, 663 | "version_minor": 0 664 | }, 665 | "text/plain": [ 666 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 667 | ] 668 | }, 669 | "metadata": {}, 670 | "output_type": "display_data" 671 | }, 672 | { 673 | "name": "stdout", 674 | "output_type": "stream", 675 | "text": [ 676 | "\n", 677 | "predicted states:\n", 678 | "[0 1 1 1 0 1 2 3 2 3 1 1 2 0 3 0 1 2 0 1 2 3 1 0 2 1 1 2 1 2 3 0 1 3 2 3 0\n", 679 | " 1 3 0 1 3 0 2 1 3 3 2 1 3 1 2 0 1 1 0 3 0 3 3 2 3 2 2 3 3 2 1 3 0 0 1 1 0\n", 680 | " 3 3 2 0 0 1 1 0 1 0 0 1 3 1 2 2 1 0 2 2 0 0 0 0 2 1 0 1 0 1 0 1 0 0 1 3 3\n", 681 | " 0 1 2 1 1 0 2 0 1 1 1 2 3 1 3 2 2 0 3 0 0 2 3 0 3 0 1 0 1 1 0 0 2 0 3 2 3\n", 682 | " 2 2 0 1 1 2 0 2 0 1 3 0 1 0 2 1 0 2 1 1 2 1 1 2 3 0 2 2 1 0 2 1 3 1 1 1 2\n", 683 | " 2 3 0 3 1 3 0 2 0 1 3 1 0 3 2 1]\n", 684 | "N = 18 : 
err_inf = 0.799911, err_2 = 2.458965, err_fro = 4.075791, err_mse = 0.047006\n", 685 | "restart 1 / 3\n" 686 | ] 687 | }, 688 | { 689 | "data": { 690 | "application/vnd.jupyter.widget-view+json": { 691 | "model_id": "59b1aeab0b0946b2a02236a8c0ed967a", 692 | "version_major": 2, 693 | "version_minor": 0 694 | }, 695 | "text/plain": [ 696 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 697 | ] 698 | }, 699 | "metadata": {}, 700 | "output_type": "display_data" 701 | }, 702 | { 703 | "name": "stdout", 704 | "output_type": "stream", 705 | "text": [ 706 | "\n", 707 | "restart 2 / 3\n" 708 | ] 709 | }, 710 | { 711 | "data": { 712 | "application/vnd.jupyter.widget-view+json": { 713 | "model_id": "d65df1e4401d4e6f9415334b0dfc993c", 714 | "version_major": 2, 715 | "version_minor": 0 716 | }, 717 | "text/plain": [ 718 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 719 | ] 720 | }, 721 | "metadata": {}, 722 | "output_type": "display_data" 723 | }, 724 | { 725 | "name": "stdout", 726 | "output_type": "stream", 727 | "text": [ 728 | "\n", 729 | "restart 3 / 3\n" 730 | ] 731 | }, 732 | { 733 | "data": { 734 | "application/vnd.jupyter.widget-view+json": { 735 | "model_id": "b7f12045abbb45489edc6c4708ba28cb", 736 | "version_major": 2, 737 | "version_minor": 0 738 | }, 739 | "text/plain": [ 740 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 741 | ] 742 | }, 743 | "metadata": {}, 744 | "output_type": "display_data" 745 | }, 746 | { 747 | "name": "stdout", 748 | "output_type": "stream", 749 | "text": [ 750 | "\n", 751 | "predicted states:\n", 752 | "[1 1 2 0 0 2 1 2 0 1 0 1 2 2 2 2 1 1 1 1 0 0 2 1 2 2 3 1 1 3 0 3 3 0 0 1 2\n", 753 | " 0 3 2 2 2 2 0 2 1 0 2 3 0 1 3 1 0 1 1 2 3 3 1 3 2 2 1 1 2 0 3 1 1 0 0 2 1\n", 754 | " 0 3 3 3 2 3 1 3 3 0 2 3 1 3 2 1 1 0 1 0 1 3 0 3 3 1 2 3 1 1 1 1 3 3 2 2 1\n", 755 | " 1 1 1 2 1 1 3 0 0 0 2 2 2 3 3 0 2 1 3 0 3 2 3 2 2 0 1 2 2 3 1 3 2 2 0 0 0\n", 756 | " 1 2 3 2 0 3 0 2 2 1 0 3 0 1 2 3 3 2 2 2 3 2 1 1 2 0 2 2 2 3 3 2 3 1 0 3 0\n", 757 | " 0 3 3 1 1 2 0 1 0 2 2 2 2 0 3 0]\n", 758 | "N = 18 : err_inf = 0.915241, err_2 = 2.872617, err_fro = 4.324601, err_mse = 0.047059\n", 759 | "-----------------------------\n", 760 | "N=20, STEPS: 1000 \n", 761 | "restart 1 / 3\n" 762 | ] 763 | }, 764 | { 765 | "data": { 766 | "application/vnd.jupyter.widget-view+json": { 767 | "model_id": "e58e9eaee4a4425fb5999b024c1ca801", 768 | "version_major": 2, 769 | "version_minor": 0 770 | }, 771 | "text/plain": [ 772 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 773 | ] 774 | }, 775 | "metadata": {}, 776 | "output_type": "display_data" 777 | }, 778 | { 779 | "name": "stdout", 780 | "output_type": "stream", 781 | "text": [ 782 | "\n", 783 | "restart 2 / 3\n" 784 | ] 785 | }, 786 | { 787 | "data": { 788 | "application/vnd.jupyter.widget-view+json": { 789 | "model_id": "f06895c8c93d455a985f83949a87acb6", 790 | "version_major": 2, 791 | "version_minor": 0 792 | }, 793 | "text/plain": [ 794 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 795 | ] 796 | }, 797 | "metadata": {}, 798 | "output_type": "display_data" 799 | }, 800 | { 801 | "name": "stdout", 802 | "output_type": "stream", 803 | "text": [ 804 | "\n", 805 | "restart 3 / 3\n" 806 | ] 807 | }, 808 | { 809 | "data": { 810 | "application/vnd.jupyter.widget-view+json": { 811 | "model_id": "9f894852575c42a7a011cf1af692fa25", 812 | 
"version_major": 2, 813 | "version_minor": 0 814 | }, 815 | "text/plain": [ 816 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 817 | ] 818 | }, 819 | "metadata": {}, 820 | "output_type": "display_data" 821 | }, 822 | { 823 | "name": "stdout", 824 | "output_type": "stream", 825 | "text": [ 826 | "\n", 827 | "predicted states:\n", 828 | "[2 0 1 1 1 0 3 2 0 2 0 0 3 2 0 3 0 3 1 2 1 2 1 1 2 2 1 1 2 2 1 3 0 3 3 2 3\n", 829 | " 1 2 1 3 2 0 1 2 0 1 2 0 1 2 3 2 1 0 3 3 0 2 2 0 3 3 0 2 3 1 3 1 2 2 0 2 1\n", 830 | " 3 3 2 3 3 3 2 1 0 0 0 0 2 0 2 2 3 3 2 1 0 1 3 2 0 1 3 3 3 2 3 1 0 3 2 1 3\n", 831 | " 3 2 0 3 0 1 3 1 1 2 3 1 0 3 3 1 2 3 0 3 1 0 1 3 0 0 1 2 1 3 3 3 3 0 3 3 1\n", 832 | " 2 3 3 2 2 0 3 1 3 0 2 2 3 2 1 2 1 2 2 1 1 0 2 3 3 2 2 2 2 2 1 2 2 1 1 2 0\n", 833 | " 1 2 1 1 0 2 1 1 3 0 1 2 3 0 3 1]\n", 834 | "N = 20 : err_inf = 0.844948, err_2 = 2.853636, err_fro = 4.773794, err_mse = 0.054623\n", 835 | "restart 1 / 3\n" 836 | ] 837 | }, 838 | { 839 | "data": { 840 | "application/vnd.jupyter.widget-view+json": { 841 | "model_id": "db33b29dff0e43b8a62c533ea0ff2f2e", 842 | "version_major": 2, 843 | "version_minor": 0 844 | }, 845 | "text/plain": [ 846 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 847 | ] 848 | }, 849 | "metadata": {}, 850 | "output_type": "display_data" 851 | }, 852 | { 853 | "name": "stdout", 854 | "output_type": "stream", 855 | "text": [ 856 | "\n", 857 | "restart 2 / 3\n" 858 | ] 859 | }, 860 | { 861 | "data": { 862 | "application/vnd.jupyter.widget-view+json": { 863 | "model_id": "27452e9564ef43838a42dffd813b17d5", 864 | "version_major": 2, 865 | "version_minor": 0 866 | }, 867 | "text/plain": [ 868 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 869 | ] 870 | }, 871 | "metadata": {}, 872 | "output_type": "display_data" 873 | }, 874 | { 875 | "name": "stdout", 876 | "output_type": "stream", 877 | "text": [ 878 | "\n", 879 | "restart 3 / 3\n" 880 | ] 881 | }, 882 | { 883 | "data": { 884 | "application/vnd.jupyter.widget-view+json": { 885 | "model_id": "5ea6a4531704499d8015ae01c46ab992", 886 | "version_major": 2, 887 | "version_minor": 0 888 | }, 889 | "text/plain": [ 890 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 891 | ] 892 | }, 893 | "metadata": {}, 894 | "output_type": "display_data" 895 | }, 896 | { 897 | "name": "stdout", 898 | "output_type": "stream", 899 | "text": [ 900 | "\n", 901 | "predicted states:\n", 902 | "[0 3 0 3 2 0 2 0 1 0 1 2 0 3 2 2 2 2 2 2 3 1 2 0 3 2 3 3 2 0 3 2 1 3 2 0 1\n", 903 | " 2 0 3 3 1 1 2 1 0 3 1 2 2 2 3 3 3 2 2 1 3 1 0 3 0 1 1 0 3 0 0 3 2 2 3 3 3\n", 904 | " 2 1 3 3 3 2 2 3 1 3 2 0 3 1 0 0 1 3 3 1 1 0 3 2 3 2 0 1 2 1 2 0 0 0 1 1 2\n", 905 | " 2 0 0 1 1 2 3 3 1 0 3 0 0 3 1 3 2 2 3 0 0 2 2 0 1 2 1 3 0 1 0 1 2 2 1 0 0\n", 906 | " 3 3 0 3 2 1 3 1 0 1 2 3 0 0 3 2 0 1 0 3 0 2 1 0 3 3 3 1 3 1 3 3 3 1 0 3 2\n", 907 | " 0 2 2 0 1 2 0 0 3 0 3 3 2 3 1 2]\n", 908 | "N = 20 : err_inf = 0.794804, err_2 = 2.996576, err_fro = 4.684119, err_mse = 0.052172\n", 909 | "restart 1 / 3\n" 910 | ] 911 | }, 912 | { 913 | "data": { 914 | "application/vnd.jupyter.widget-view+json": { 915 | "model_id": "e8d88212f3f145b78750310e55c62608", 916 | "version_major": 2, 917 | "version_minor": 0 918 | }, 919 | "text/plain": [ 920 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 921 | ] 922 | }, 923 | "metadata": {}, 924 | "output_type": "display_data" 925 | }, 926 | { 
927 | "name": "stdout", 928 | "output_type": "stream", 929 | "text": [ 930 | "\n", 931 | "restart 2 / 3\n" 932 | ] 933 | }, 934 | { 935 | "data": { 936 | "application/vnd.jupyter.widget-view+json": { 937 | "model_id": "5bda5159a7914349b715030915834b92", 938 | "version_major": 2, 939 | "version_minor": 0 940 | }, 941 | "text/plain": [ 942 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 943 | ] 944 | }, 945 | "metadata": {}, 946 | "output_type": "display_data" 947 | }, 948 | { 949 | "name": "stdout", 950 | "output_type": "stream", 951 | "text": [ 952 | "\n", 953 | "restart 3 / 3\n" 954 | ] 955 | }, 956 | { 957 | "data": { 958 | "application/vnd.jupyter.widget-view+json": { 959 | "model_id": "d0720807ea9e47f296db7492642cb610", 960 | "version_major": 2, 961 | "version_minor": 0 962 | }, 963 | "text/plain": [ 964 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 965 | ] 966 | }, 967 | "metadata": {}, 968 | "output_type": "display_data" 969 | }, 970 | { 971 | "name": "stdout", 972 | "output_type": "stream", 973 | "text": [ 974 | "\n", 975 | "predicted states:\n", 976 | "[3 3 0 1 3 0 0 2 3 2 0 0 2 0 0 3 0 0 1 3 1 1 2 2 3 3 2 0 3 3 3 3 1 2 0 3 3\n", 977 | " 1 1 3 0 1 0 1 0 2 0 1 3 2 2 1 0 3 3 2 2 2 1 1 0 0 3 1 2 0 0 0 2 2 1 2 0 2\n", 978 | " 2 3 1 1 1 0 2 3 0 3 2 2 1 0 2 2 3 0 2 3 2 1 0 2 1 1 1 1 2 0 0 2 0 1 0 0 2\n", 979 | " 3 1 2 0 1 1 1 0 0 0 2 1 2 1 3 3 3 2 2 0 2 2 3 2 0 2 2 1 3 2 0 0 3 3 2 2 0\n", 980 | " 1 0 3 3 3 3 0 2 3 2 2 3 0 1 2 2 2 0 3 2 3 2 0 3 2 3 2 0 0 0 3 3 3 2 0 2 3\n", 981 | " 2 1 3 2 2 1 0 1 3 3 2 3 1 2 1 1]\n", 982 | "N = 20 : err_inf = 0.717674, err_2 = 2.490861, err_fro = 4.493759, err_mse = 0.049774\n", 983 | "-----------------------------\n", 984 | "N=30, STEPS: 1000 \n", 985 | "restart 1 / 3\n" 986 | ] 987 | }, 988 | { 989 | "data": { 990 | "application/vnd.jupyter.widget-view+json": { 991 | "model_id": "73bca2f9382e47a7a80088e8addcd34e", 992 | "version_major": 2, 993 | "version_minor": 0 994 | }, 995 | "text/plain": [ 996 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 997 | ] 998 | }, 999 | "metadata": {}, 1000 | "output_type": "display_data" 1001 | }, 1002 | { 1003 | "name": "stdout", 1004 | "output_type": "stream", 1005 | "text": [ 1006 | "\n", 1007 | "predicted states:\n", 1008 | "[0 2 0 2 0 2 2 1 3 1 1 0 1 2 3 1 0 2 2 3 3 0 3 1 1 0 3 0 0 2 0 0 0 1 0 0 3\n", 1009 | " 0 2 1 2 3 0 3 3 3 0 2 2 3 0 2 3 1 1 3 1 3 3 0 1 3 1 0 2 3 0 3 3 0 0 2 1 1\n", 1010 | " 2 0 0 3 2 3 3 0 2 0 0 3 3 3 3 3 1 0 2 1 3 3 3 2 3 0 0 2 3 3 2 1 1 0 3 1 1\n", 1011 | " 3 2 3 3 1 0 2 0 1 0 0 1 0 3 1 3 2 1 2 1 1 3 3 1 0 2 0 0 1 2 3 1 2 1 3 3 2\n", 1012 | " 0 2 2 0 2 1 1 0 0 3 2 0 2 1 3 2 1 3 1 2 1 0 0 0 1 1 1 3 3 3 3 3 1 0 2 2 1\n", 1013 | " 0 1 1 2 0 1 2 2 2 0 2 2 3 3 1 1]\n", 1014 | "N = 50 : err_inf = 2.968470, err_2 = 19.123944, err_fro = 29.101795, err_mse = 1.026776\n", 1015 | "restart 1 / 3\n" 1016 | ] 1017 | }, 1018 | { 1019 | "data": { 1020 | "application/vnd.jupyter.widget-view+json": { 1021 | "model_id": "77a570c9ce2b4cb5b3b6c58ac10cb17d", 1022 | "version_major": 2, 1023 | "version_minor": 0 1024 | }, 1025 | "text/plain": [ 1026 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 1027 | ] 1028 | }, 1029 | "metadata": {}, 1030 | "output_type": "display_data" 1031 | }, 1032 | { 1033 | "name": "stdout", 1034 | "output_type": "stream", 1035 | "text": [ 1036 | "\n", 1037 | "restart 2 / 3\n" 1038 | ] 1039 | }, 1040 | { 1041 | "data": { 
1042 | "application/vnd.jupyter.widget-view+json": { 1043 | "model_id": "4c3cd666157b412185c9a7106e040fe4", 1044 | "version_major": 2, 1045 | "version_minor": 0 1046 | }, 1047 | "text/plain": [ 1048 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 1049 | ] 1050 | }, 1051 | "metadata": {}, 1052 | "output_type": "display_data" 1053 | }, 1054 | { 1055 | "name": "stdout", 1056 | "output_type": "stream", 1057 | "text": [ 1058 | "\n", 1059 | "restart 3 / 3\n" 1060 | ] 1061 | }, 1062 | { 1063 | "data": { 1064 | "application/vnd.jupyter.widget-view+json": { 1065 | "model_id": "b339001129834f94ac704b40202c3fa1", 1066 | "version_major": 2, 1067 | "version_minor": 0 1068 | }, 1069 | "text/plain": [ 1070 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 1071 | ] 1072 | }, 1073 | "metadata": {}, 1074 | "output_type": "display_data" 1075 | }, 1076 | { 1077 | "name": "stdout", 1078 | "output_type": "stream", 1079 | "text": [ 1080 | "\n", 1081 | "predicted states:\n", 1082 | "[2 1 1 2 0 1 2 3 3 3 3 3 1 3 2 3 1 2 2 1 3 3 1 2 2 1 2 2 2 0 0 3 0 2 2 1 2\n", 1083 | " 0 0 3 3 3 2 1 2 3 1 2 1 3 0 1 2 3 2 1 0 0 2 3 1 3 2 2 1 2 3 3 2 3 1 0 2 1\n", 1084 | " 2 0 0 0 2 0 0 2 2 2 0 1 0 0 1 1 1 3 2 0 1 1 0 1 0 0 3 1 1 3 3 1 3 0 3 2 1\n", 1085 | " 1 1 0 0 2 3 1 2 1 1 1 2 3 0 0 1 2 2 3 1 1 0 0 0 3 3 3 2 3 1 3 2 3 2 3 0 1\n", 1086 | " 3 2 2 0 0 0 0 0 3 3 0 3 3 0 0 0 2 1 3 1 1 3 1 2 1 3 2 1 3 3 0 2 1 1 2 1 3\n", 1087 | " 0 2 3 0 1 3 3 3 3 0 0 0 3 1 0 0]\n", 1088 | "N = 50 : err_inf = 9.696988, err_2 = 66.542701, err_fro = 73.990241, err_mse = 7.015966\n", 1089 | "restart 1 / 3\n" 1090 | ] 1091 | }, 1092 | { 1093 | "data": { 1094 | "application/vnd.jupyter.widget-view+json": { 1095 | "model_id": "c13a656f5f54468596c842726cdeecb8", 1096 | "version_major": 2, 1097 | "version_minor": 0 1098 | }, 1099 | "text/plain": [ 1100 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 1101 | ] 1102 | }, 1103 | "metadata": {}, 1104 | "output_type": "display_data" 1105 | }, 1106 | { 1107 | "name": "stdout", 1108 | "output_type": "stream", 1109 | "text": [ 1110 | "\n", 1111 | "restart 2 / 3\n" 1112 | ] 1113 | }, 1114 | { 1115 | "data": { 1116 | "application/vnd.jupyter.widget-view+json": { 1117 | "model_id": "086735f458d3459da712128f20cdbe45", 1118 | "version_major": 2, 1119 | "version_minor": 0 1120 | }, 1121 | "text/plain": [ 1122 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 1123 | ] 1124 | }, 1125 | "metadata": {}, 1126 | "output_type": "display_data" 1127 | }, 1128 | { 1129 | "name": "stdout", 1130 | "output_type": "stream", 1131 | "text": [ 1132 | "\n", 1133 | "restart 3 / 3\n" 1134 | ] 1135 | }, 1136 | { 1137 | "data": { 1138 | "application/vnd.jupyter.widget-view+json": { 1139 | "model_id": "559b3d721d95471b916f218d3fe2570e", 1140 | "version_major": 2, 1141 | "version_minor": 0 1142 | }, 1143 | "text/plain": [ 1144 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 1145 | ] 1146 | }, 1147 | "metadata": {}, 1148 | "output_type": "display_data" 1149 | }, 1150 | { 1151 | "name": "stdout", 1152 | "output_type": "stream", 1153 | "text": [ 1154 | "\n", 1155 | "predicted states:\n", 1156 | "[2 1 0 0 2 3 0 1 2 3 3 3 3 0 2 1 3 1 1 1 0 0 0 0 2 2 0 0 2 1 1 0 2 1 0 0 2\n", 1157 | " 1 3 2 0 3 2 1 1 3 0 0 0 2 0 0 0 2 1 2 2 3 0 2 1 0 0 3 0 3 2 0 3 1 0 2 3 1\n", 1158 | " 0 3 3 1 1 2 1 2 2 1 3 1 2 0 3 2 3 2 3 0 1 1 0 2 3 3 1 2 0 0 1 2 3 0 0 0 0\n", 
1159 | " 1 1 2 2 0 2 3 1 1 2 1 3 3 0 1 0 0 2 1 2 1 3 3 3 3 3 3 2 2 3 1 3 2 1 1 3 3\n", 1160 | " 3 1 0 1 2 3 2 2 2 1 0 1 1 0 3 0 1 0 2 3 2 2 0 3 0 0 2 2 0 0 1 1 0 0 1 3 0\n", 1161 | " 1 0 3 3 3 3 2 2 2 3 3 0 2 1 3 1]\n", 1162 | "N = 50 : err_inf = 2.676320, err_2 = 16.685327, err_fro = 26.799433, err_mse = 0.889505\n", 1163 | "-----------------------------\n", 1164 | "N=80, STEPS: 1000 \n", 1165 | "restart 1 / 3\n" 1166 | ] 1167 | }, 1168 | { 1169 | "data": { 1170 | "application/vnd.jupyter.widget-view+json": { 1171 | "model_id": "ba75ec89f2b748b2b2d4361103a75b36", 1172 | "version_major": 2, 1173 | "version_minor": 0 1174 | }, 1175 | "text/plain": [ 1176 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 1177 | ] 1178 | }, 1179 | "metadata": {}, 1180 | "output_type": "display_data" 1181 | }, 1182 | { 1183 | "name": "stderr", 1184 | "output_type": "stream", 1185 | "text": [ 1186 | "IOPub message rate exceeded.\n", 1187 | "The notebook server will temporarily stop sending output\n", 1188 | "to the client in order to avoid crashing it.\n", 1189 | "To change this limit, set the config variable\n", 1190 | "`--NotebookApp.iopub_msg_rate_limit`.\n", 1191 | "\n", 1192 | "Current values:\n", 1193 | "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", 1194 | "NotebookApp.rate_limit_window=3.0 (secs)\n", 1195 | "\n" 1196 | ] 1197 | }, 1198 | { 1199 | "name": "stdout", 1200 | "output_type": "stream", 1201 | "text": [ 1202 | "\n", 1203 | "restart 3 / 3\n" 1204 | ] 1205 | }, 1206 | { 1207 | "data": { 1208 | "application/vnd.jupyter.widget-view+json": { 1209 | "model_id": "776ce9bd10cd4b75bddf39eb083088ce", 1210 | "version_major": 2, 1211 | "version_minor": 0 1212 | }, 1213 | "text/plain": [ 1214 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 1215 | ] 1216 | }, 1217 | "metadata": {}, 1218 | "output_type": "display_data" 1219 | }, 1220 | { 1221 | "name": "stdout", 1222 | "output_type": "stream", 1223 | "text": [ 1224 | "\n", 1225 | "predicted states:\n", 1226 | "[0 0 3 3 3 1 0 2 3 2 2 1 2 1 0 2 0 0 2 2 0 1 3 0 1 0 3 2 2 2 0 1 0 1 1 1 3\n", 1227 | " 1 1 3 2 2 0 2 1 0 3 2 1 0 0 1 0 1 3 2 2 1 3 2 1 2 2 1 1 1 0 0 2 2 2 3 1 3\n", 1228 | " 1 2 2 1 1 1 3 0 2 0 2 1 2 3 3 2 0 0 2 0 3 0 1 1 0 3 2 2 2 1 0 2 0 1 1 1 2\n", 1229 | " 2 3 1 1 0 2 1 1 1 1 0 2 1 3 3 2 3 0 2 1 3 0 0 0 0 3 1 2 2 3 3 1 3 1 2 3 0\n", 1230 | " 1 3 2 1 3 3 3 1 2 3 3 2 3 1 3 1 2 1 3 0 3 1 0 1 2 1 2 1 0 0 0 0 0 2 3 1 2\n", 1231 | " 2 1 1 1 1 1 3 2 1 0 1 1 2 2 1 1]\n", 1232 | "N = 100 : err_inf = 0.770405, err_2 = 4.238526, err_fro = 12.927433, err_mse = 0.282129\n", 1233 | "restart 1 / 3\n" 1234 | ] 1235 | }, 1236 | { 1237 | "data": { 1238 | "application/vnd.jupyter.widget-view+json": { 1239 | "model_id": "7551badde8c04ed795da1b1cbe3801b7", 1240 | "version_major": 2, 1241 | "version_minor": 0 1242 | }, 1243 | "text/plain": [ 1244 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 1245 | ] 1246 | }, 1247 | "metadata": {}, 1248 | "output_type": "display_data" 1249 | }, 1250 | { 1251 | "name": "stderr", 1252 | "output_type": "stream", 1253 | "text": [ 1254 | "IOPub message rate exceeded.\n", 1255 | "The notebook server will temporarily stop sending output\n", 1256 | "to the client in order to avoid crashing it.\n", 1257 | "To change this limit, set the config variable\n", 1258 | "`--NotebookApp.iopub_msg_rate_limit`.\n", 1259 | "\n", 1260 | "Current values:\n", 1261 | "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", 1262 | 
"NotebookApp.rate_limit_window=3.0 (secs)\n", 1263 | "\n" 1264 | ] 1265 | }, 1266 | { 1267 | "name": "stdout", 1268 | "output_type": "stream", 1269 | "text": [ 1270 | "\n", 1271 | "restart 3 / 3\n" 1272 | ] 1273 | }, 1274 | { 1275 | "data": { 1276 | "application/vnd.jupyter.widget-view+json": { 1277 | "model_id": "39cce6aabb464aeaa18025572a20b80a", 1278 | "version_major": 2, 1279 | "version_minor": 0 1280 | }, 1281 | "text/plain": [ 1282 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 1283 | ] 1284 | }, 1285 | "metadata": {}, 1286 | "output_type": "display_data" 1287 | }, 1288 | { 1289 | "name": "stderr", 1290 | "output_type": "stream", 1291 | "text": [ 1292 | "IOPub message rate exceeded.\n", 1293 | "The notebook server will temporarily stop sending output\n", 1294 | "to the client in order to avoid crashing it.\n", 1295 | "To change this limit, set the config variable\n", 1296 | "`--NotebookApp.iopub_msg_rate_limit`.\n", 1297 | "\n", 1298 | "Current values:\n", 1299 | "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", 1300 | "NotebookApp.rate_limit_window=3.0 (secs)\n", 1301 | "\n" 1302 | ] 1303 | }, 1304 | { 1305 | "name": "stdout", 1306 | "output_type": "stream", 1307 | "text": [ 1308 | "\n", 1309 | "restart 3 / 3\n" 1310 | ] 1311 | }, 1312 | { 1313 | "data": { 1314 | "application/vnd.jupyter.widget-view+json": { 1315 | "model_id": "c509266a6e3344d688ae75584ff8dc8c", 1316 | "version_major": 2, 1317 | "version_minor": 0 1318 | }, 1319 | "text/plain": [ 1320 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 1321 | ] 1322 | }, 1323 | "metadata": {}, 1324 | "output_type": "display_data" 1325 | }, 1326 | { 1327 | "name": "stderr", 1328 | "output_type": "stream", 1329 | "text": [ 1330 | "IOPub message rate exceeded.\n", 1331 | "The notebook server will temporarily stop sending output\n", 1332 | "to the client in order to avoid crashing it.\n", 1333 | "To change this limit, set the config variable\n", 1334 | "`--NotebookApp.iopub_msg_rate_limit`.\n", 1335 | "\n", 1336 | "Current values:\n", 1337 | "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", 1338 | "NotebookApp.rate_limit_window=3.0 (secs)\n", 1339 | "\n" 1340 | ] 1341 | }, 1342 | { 1343 | "name": "stdout", 1344 | "output_type": "stream", 1345 | "text": [ 1346 | "\n", 1347 | "restart 2 / 3\n" 1348 | ] 1349 | }, 1350 | { 1351 | "data": { 1352 | "application/vnd.jupyter.widget-view+json": { 1353 | "model_id": "b022afcf13044e56bd6f4dfe5469333a", 1354 | "version_major": 2, 1355 | "version_minor": 0 1356 | }, 1357 | "text/plain": [ 1358 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 1359 | ] 1360 | }, 1361 | "metadata": {}, 1362 | "output_type": "display_data" 1363 | }, 1364 | { 1365 | "name": "stderr", 1366 | "output_type": "stream", 1367 | "text": [ 1368 | "IOPub message rate exceeded.\n", 1369 | "The notebook server will temporarily stop sending output\n", 1370 | "to the client in order to avoid crashing it.\n", 1371 | "To change this limit, set the config variable\n", 1372 | "`--NotebookApp.iopub_msg_rate_limit`.\n", 1373 | "\n", 1374 | "Current values:\n", 1375 | "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", 1376 | "NotebookApp.rate_limit_window=3.0 (secs)\n", 1377 | "\n" 1378 | ] 1379 | } 1380 | ], 1381 | "source": [ 1382 | "for N in N_array:\n", 1383 | " print(\"-----------------------------\\nN=%d, STEPS: %d \" % (N, num_iters))\n", 1384 | " for rep in range(3):\n", 1385 | " # load data\n", 
1386 | " data = scipy.io.loadmat(\"../data/test_data_N_%d_M_201_sigma_0.500000_rep_%d.mat\" % (N, rep + 1))\n", 1387 | " X = data['X']\n", 1388 | " A1 = data['A1']\n", 1389 | " A2 = data['A2']\n", 1390 | " \n", 1391 | " # fit model\n", 1392 | " err_inf, err_2, err_fro, err_mse, _ = \\\n", 1393 | " fit_arhmm_and_return_errors(X.T, A1, A2, Kmax=Kmax, num_restarts=num_restarts, num_iters=num_iters, rank=rank)\n", 1394 | " \n", 1395 | " # print some output\n", 1396 | " print(\"N = %d : err_inf = %f, err_2 = %f, err_fro = %f, err_mse = %f\" % \\\n", 1397 | " (N, err_inf, err_2, err_fro, err_mse))\n", 1398 | " new_row = dict(zip(error_table.columns, \n", 1399 | " [N, np.nan, model, err_inf, err_2, err_fro, err_mse, np.nan]))\n", 1400 | " error_table = error_table.append(new_row, ignore_index=True)\n", 1401 | "\n", 1402 | "# write output\n", 1403 | "error_table.to_csv(output_file, header=True, index=False)" 1404 | ] 1405 | }, 1406 | { 1407 | "cell_type": "code", 1408 | "execution_count": null, 1409 | "metadata": {}, 1410 | "outputs": [], 1411 | "source": [ 1412 | "data = error_table\n", 1413 | "#plt.loglog(data['N'], data['err_2'])\n", 1414 | "fig, ax = plt.subplots()\n", 1415 | "\n", 1416 | "for key, grp in data.groupby(['model']):\n", 1417 | " grp = grp.groupby(['N']).mean()\n", 1418 | "# if key == 1:\n", 1419 | "# keystr = 'indep(N)'\n", 1420 | "# elif key == 2:\n", 1421 | "# keystr = 'indep(4)'\n", 1422 | "# elif key == 3:\n", 1423 | "# keystr = 'TVART(4)'\n", 1424 | "# elif key == 4:\n", 1425 | "# keystr = 'SLDS(4)'\n", 1426 | "# elif key == 5:\n", 1427 | "# keystr = 'SLDS(6)'\n", 1428 | "# elif key == 6:\n", 1429 | "# keystr = 'SLDS(2)'\n", 1430 | " keystr = key\n", 1431 | " ax = grp.plot(ax=ax, kind='line', y='err_2', label=keystr, logx=True, logy=True)\n", 1432 | " plt.ylabel('2-norm error')\n", 1433 | "\n", 1434 | "plt.legend(loc='best')\n", 1435 | "#plt.ylim([1e-2, 1e-1])\n", 1436 | "plt.show()\n", 1437 | "\n", 1438 | "#data.plot.line(x='N', y='err_inf', logx=True, logy=True)" 1439 | ] 1440 | }, 1441 | { 1442 | "cell_type": "code", 1443 | "execution_count": null, 1444 | "metadata": {}, 1445 | "outputs": [], 1446 | "source": [] 1447 | }, 1448 | { 1449 | "cell_type": "code", 1450 | "execution_count": null, 1451 | "metadata": {}, 1452 | "outputs": [], 1453 | "source": [] 1454 | } 1455 | ], 1456 | "metadata": { 1457 | "kernelspec": { 1458 | "display_name": "Python 3", 1459 | "language": "python", 1460 | "name": "python3" 1461 | }, 1462 | "language_info": { 1463 | "codemirror_mode": { 1464 | "name": "ipython", 1465 | "version": 3 1466 | }, 1467 | "file_extension": ".py", 1468 | "mimetype": "text/x-python", 1469 | "name": "python", 1470 | "nbconvert_exporter": "python", 1471 | "pygments_lexer": "ipython3", 1472 | "version": "3.8.5" 1473 | } 1474 | }, 1475 | "nbformat": 4, 1476 | "nbformat_minor": 2 1477 | } 1478 | --------------------------------------------------------------------------------