├── CloudDistanceType.m
├── compute-optimal-transport
│   ├── OTSolver.m
│   ├── compute_single_ot_distance_sinkhorn.m
│   ├── compute_single_ot_distance_gurobi.m
│   ├── compute_single_ot_distance_linprog.m
│   ├── compute_single_ot_distance.m
│   ├── compute_single_ot_distance_mosek.m
│   └── compute_single_ot_distance_fastemd.m
├── plotting
│   ├── centroid_distance_matrix.m
│   ├── group_data_by_location.m
│   ├── prepare_globe_plots.m
│   ├── group_by_opts.m
│   ├── centroid_fractions_by_location.m
│   ├── estimate_metric_kmeans_objective.m
│   ├── make_globe_plots.m
│   ├── subaxis.m
│   ├── set_my_colormap.m
│   ├── parseArgs.m
│   └── weather_state_plots.m
├── wasserstein_streaming_kmeans_plusplus.m
├── euclidean_streaming_kmeans_plusplus.m
├── initialize_kmeans_plusplus.m
├── README.md
├── metric_streaming_kmeans_plusplus.m
├── cluster_histograms.m
└── extract_cloud_histograms.m

/CloudDistanceType.m:
--------------------------------------------------------------------------------
classdef CloudDistanceType
    %CLOUDDISTANCETYPE Selects how the ground cost between histogram bins is built.
    %   The value is stored in opts.CloudDistanceType and handed to
    %   build_pairwise_distance_matrix by cluster_histograms.
    enumeration
        EuclideanBasedOnValues, EuclideanBasedOnGrid
    end
end
--------------------------------------------------------------------------------
/compute-optimal-transport/OTSolver.m:
--------------------------------------------------------------------------------
classdef OTSolver
    %OTSOLVER Selects the backend used by compute_single_ot_distance.
    %   Each member maps to one compute_single_ot_distance_<backend>.m file.
    enumeration
        FastEMD, Gurobi, Linprog, Mosek, Sinkhorn
    end
end
--------------------------------------------------------------------------------
/compute-optimal-transport/compute_single_ot_distance_sinkhorn.m:
--------------------------------------------------------------------------------
function [ val, grad ] = compute_single_ot_distance_sinkhorn( C, a, b )
%COMPUTE_SINGLE_OT_DISTANCE_SINKHORN Entropy-regularized OT cost via Sinkhorn.
%   [val, grad] = COMPUTE_SINGLE_OT_DISTANCE_SINKHORN(C, a, b) calls the
%   third-party sinkhornTransport (see README) on histograms a and b with
%   ground cost C and a fixed regularization strength lambda = 50.
%
%   Inputs:
%     C    - ground cost matrix between histogram bins
%     a, b - histograms (any orientation; flattened to columns)
%   Outputs:
%     val  - first output of sinkhornTransport (the transport cost)
%     grad - -log(u)/lambda shifted to zero mean, where u is the left
%            scaling vector returned by sinkhornTransport; callers use it
%            as the gradient of val with respect to a

lambda = 5e1;               % regularization strength (hard-coded)
K = exp(-lambda * C);       % Gibbs kernel

% Only the cost and the left scaling vector are needed.
[val, ~, u, ~] = sinkhornTransport(a(:), b(:), K, K .* C, lambda);

% log(0) = -Inf for bins with a zero scaling; give those a zero potential.
log_u = log(u);
log_u(isinf(log_u)) = 0;

grad = -log_u / lambda;
grad = grad - sum(grad) / length(grad);   % fix the additive constant: zero mean

end
--------------------------------------------------------------------------------
/plotting/centroid_distance_matrix.m:
--------------------------------------------------------------------------------
function [ D ] = centroid_distance_matrix( C, centroids, opts )
%CENTROID_DISTANCE_MATRIX Pairwise optimal-transport distances between centroids.
%   D = CENTROID_DISTANCE_MATRIX(C, centroids, opts) returns a symmetric
%   k-by-k matrix with zero diagonal, where k = size(centroids, 1) and
%   D(i,j) = compute_single_ot_distance(C, centroids(i,:), centroids(j,:),
%   opts.OTSolver, opts.p).
%
%   Inputs:
%     C         - ground cost matrix passed through to the OT solver
%     centroids - one histogram per row
%     opts      - struct with fields OTSolver and p

k = size(centroids, 1);
D = zeros(k);
for aa = 1:k
    % Only the upper triangle is solved; the result is mirrored.
    for bb = (aa+1):k
        D(aa,bb) = compute_single_ot_distance(C, centroids(aa,:), centroids(bb,:), opts.OTSolver, opts.p);
        D(bb,aa) = D(aa,bb);
    end
end

end

%function [x_hist] = add_last_col(x_row)
%    curr_sum = sum(x_row);
%    x_hist = [x_row(:)' 1-curr_sum];
%end
--------------------------------------------------------------------------------
/wasserstein_streaming_kmeans_plusplus.m:
--------------------------------------------------------------------------------
function [ centroids, out_struct ] = wasserstein_streaming_kmeans_plusplus( C, X, opts )
%WASSERSTEIN_STREAMING_KMEANS_PLUSPLUS Wasserstein histogram clustering with
%smart initialization
%   C is the pairwise ground cost matrix between histogram bins. It is
%     passed to compute_single_ot_distance, which squares it elementwise
%     itself when opts.p == 2, so C should hold unsquared distances.
%   X is a matrix where each row is a histogram
%   opts contains the parameters for metric_streaming_kmeans_plusplus plus
%     the fields OTSolver and p used to build the distance oracle

oracle = @(x,y) compute_single_ot_distance(C, x, y, opts.OTSolver, opts.p);

% Centroids are projected back onto the simplex after every gradient step.
opts.project = 1;
[centroids, out_struct] = metric_streaming_kmeans_plusplus(oracle, X, opts);
end
--------------------------------------------------------------------------------
/euclidean_streaming_kmeans_plusplus.m:
--------------------------------------------------------------------------------
function [ centroids, out_struct ] = euclidean_streaming_kmeans_plusplus( X, opts )
%EUCLIDEAN_STREAMING_KMEANS_PLUSPLUS Squared-Euclidean histogram clustering
%with smart initialization
%   X is a matrix where each row is a histogram
%   opts contains the parameters for metric_streaming_kmeans_plusplus

oracle = @(x,y) oracle_helper(x, y);

% Plain Euclidean updates: no projection onto the simplex.
opts.project = 0;
[centroids, out_struct] = metric_streaming_kmeans_plusplus(oracle, X, opts);
end

function [D, grad] = oracle_helper(x, y)
% Squared Euclidean distance between x and y and its gradient wrt x.
x = x(:); y = y(:);
D = norm(x - y, 2)^2;
grad = 2*(x-y);
end
--------------------------------------------------------------------------------
/plotting/group_data_by_location.m:
--------------------------------------------------------------------------------
function [ X_grouped, lats_unique, lons_unique ] = group_data_by_location( X, lats, lons )
%GROUP_DATA_BY_LOCATION Bucket the rows of X by their (latitude, longitude).
%   [X_grouped, lats_unique, lons_unique] = GROUP_DATA_BY_LOCATION(X, lats, lons)
%
%   Inputs:
%     X    - one observation per row
%     lats - latitude of each row of X
%     lons - longitude of each row of X
%   Outputs:
%     lats_unique - unique(lats)
%     lons_unique - unique(lons)
%     X_grouped   - numel(lats_unique)-by-numel(lons_unique) cell array;
%                   cell {i,j} holds the rows of X located at
%                   (lats_unique(i), lons_unique(j)) in their original
%                   order, or [] when no row falls there

[lats_unique, ~, lat_idx] = unique(lats);
[lons_unique, ~, lon_idx] = unique(lons);

n_lat = numel(lats_unique);
n_lon = numel(lons_unique);
X_grouped = cell(n_lat, n_lon);      % cells start out as []

n_rows = size(X, 1);
if n_rows == 0
    return;
end

% Linear index of the destination cell for every row of X.
lat_idx = lat_idx(:);
lon_idx = lon_idx(:);
cell_idx = sub2ind([n_lat, n_lon], lat_idx(1:n_rows), lon_idx(1:n_rows));

% sort is stable, so rows keep their original relative order within a cell.
% Grouping this way copies each row once instead of re-concatenating the
% destination cell for every row.
[sorted_idx, order] = sort(cell_idx);
boundaries = [0; find(diff(sorted_idx) ~= 0); n_rows];
for gg = 1:numel(boundaries)-1
    members = order(boundaries(gg)+1 : boundaries(gg+1));
    X_grouped{sorted_idx(boundaries(gg)+1)} = X(members, :);
end

end
--------------------------------------------------------------------------------
/plotting/prepare_globe_plots.m:
--------------------------------------------------------------------------------
% PREPARE_GLOBE_PLOTS Script: for every option set, pick the best clustering
% run and save the per-location centroid frequencies for make_globe_plots.
%
% NOTE(review): presumably emd_runs_7_21.mat provides out_struct_cell,
% centroids_cell, opts_vec and C, and grouped_histograms_tropics.mat
% provides X_grouped -- confirm against the files.

%load('emd_runs_7_20_gurobionly.mat');
load('emd_runs_7_21.mat');

%% select plots
[g, c, p] = group_by_opts(out_struct_cell, centroids_cell, opts_vec);

% For each option group keep the centroids of the run with the lowest
% final estimated cost.
centroids_to_plot = cell(1, length(g));
for ii=1:length(g)
    costs = cellfun(@(z) z.estimated_costs(end), g{ii});
    [~, I] = min(costs);
    centroids_to_plot{ii} = c{ii}{I};
end

oracle = @(x,y) compute_single_ot_distance(C, x, y, OTSolver.Gurobi, opts_vec(1).p);

load('grouped_histograms_tropics.mat');

for ii=1:length(centroids_to_plot)
    relative = centroid_fractions_by_location(X_grouped, centroids_to_plot{ii}, oracle);
    save(sprintf('relative_7_22_%d.mat', ii), 'relative');
end
--------------------------------------------------------------------------------
/compute-optimal-transport/compute_single_ot_distance_gurobi.m:
--------------------------------------------------------------------------------
function [ val, grad ] = compute_single_ot_distance_gurobi( C, a, b )
%COMPUTE_SINGLE_OT_DISTANCE_GUROBI OT cost and dual potential via Gurobi.
%   Solves the dual transport LP over x = [alpha; beta]:
%       maximize   a'*alpha + b'*beta
%       subject to A'*x <= C(:),  sum(alpha) = 1
%   where A is the 2n-by-n^2 marginal-constraint matrix.
%
%   Inputs:
%     C    - n-by-n ground cost matrix
%     a, b - histograms of length n
%   Outputs:
%     val  - optimal objective value
%     grad - alpha, the first n entries of the optimal x
%
%   NOTE(review): C(:) is column-major, so alpha(i) is paired with column i
%   of C; this is harmless when C is symmetric -- confirm for other costs.
%
%   Raises compute_single_ot_distance_gurobi:noSolution when Gurobi
%   returns no solution vector.

n = length(a);

% Row i of the top half selects block i of n consecutive variables; the
% bottom half is n copies of the identity placed side by side.
A = [kron(speye(n), ones(1,n)); kron(ones(1,n), speye(n))];

model.A = [A'; ones(1,n), zeros(1,n)];
model.obj = [a(:); b(:)];
model.modelsense = 'max';
model.rhs = [C(:); 1];
model.sense = [repmat('<', 1, n^2), '='];
model.lb = -Inf(2*n,1);            % potentials are free variables

param = [];
param.OutputFlag = 0;              % silence solver output
result = gurobi(model, param);

if ~isfield(result, 'x')
    error('compute_single_ot_distance_gurobi:noSolution', ...
        'Gurobi returned no solution (status: %s).', result.status);
end

val = result.objval;
grad = result.x(1:n);

end
--------------------------------------------------------------------------------
/compute-optimal-transport/compute_single_ot_distance_linprog.m:
--------------------------------------------------------------------------------
function [ val, grad ] = compute_single_ot_distance_linprog( C, a, b )
%COMPUTE_SINGLE_OT_DISTANCE_LINPROG OT cost and dual potential via linprog.
%   Solves the dual transport LP over x = [alpha; beta]:
%       minimize   -(a'*alpha + b'*beta)
%       subject to A'*x <= C(:),  sum(alpha) = 1
%   where A is the 2n-by-n^2 marginal-constraint matrix.
%
%   Inputs:
%     C    - n-by-n ground cost matrix
%     a, b - histograms of length n
%   Outputs:
%     val  - negated optimal linprog objective
%     grad - alpha, the first n entries of the optimal x
%
%   Raises compute_single_ot_distance_linprog:noSolution when linprog
%   returns an empty solution.

n = length(a);

% Row i of the top half selects block i of n consecutive variables; the
% bottom half is n copies of the identity placed side by side.
A = [kron(speye(n), ones(1,n)); kron(ones(1,n), speye(n))];

options = optimset('linprog');
options.Display = 'off';
problem.options = options;
problem.solver = 'linprog';
problem.f = -[a(:); b(:)];         % linprog minimizes, so negate
problem.Aineq = A';
problem.bineq = C(:);
problem.Aeq = [ones(1,n), zeros(1,n)];
problem.beq = 1;
[x, fval, exitflag] = linprog(problem);
%[x,fval] = linprog(-[a(:); b(:)], A', C(:), [ones(1,n), zeros(1,n)], 1);

if isempty(x)
    error('compute_single_ot_distance_linprog:noSolution', ...
        'linprog returned no solution (exitflag: %d).', exitflag);
end

val = -fval;
grad = x(1:n);

end
--------------------------------------------------------------------------------
/plotting/group_by_opts.m:
--------------------------------------------------------------------------------
function [ grouped_out_structs, grouped_centroids, opts_prototypes ] = group_by_opts(out_struct_cell, centroid_cell, opts_vec)
%GROUP_BY_OPTS Group clustering runs that were produced with equal options.
%   Inputs:
%     out_struct_cell - cell array, one output struct per run
%     centroid_cell   - cell array, one centroid matrix per run
%     opts_vec        - struct array, the options used for each run
%   Outputs:
%     opts_prototypes     - cell row of the distinct option structs, in
%                           order of first appearance (compared with isequal)
%     grouped_out_structs - cell row; element i is the sub-cell of
%                           out_struct_cell whose options equal prototype i
%     grouped_centroids   - same grouping applied to centroid_cell

n_opts = length(opts_vec);
if n_opts == 0
    grouped_out_structs = {};
    grouped_centroids = {};
    opts_prototypes = {};
    return;
end

opts_prototypes = {opts_vec(1)};
proto_inx = zeros(1, n_opts);      % proto_inx(ii): prototype index of run ii
proto_inx(1) = 1;
for ii=2:n_opts
    match = find(cellfun(@(proto) isequal(proto, opts_vec(ii)), opts_prototypes), 1);
    if isempty(match)
        % Not seen before: this run defines a new prototype.
        opts_prototypes{end+1} = opts_vec(ii); %#ok<AGROW>
        match = numel(opts_prototypes);
    end
    proto_inx(ii) = match;
end

n_groups = numel(opts_prototypes);
grouped_out_structs = cell(1, n_groups);
grouped_centroids = cell(1, n_groups);
for gg=1:n_groups
    members = (proto_inx == gg);
    grouped_out_structs{gg} = out_struct_cell(members);
    grouped_centroids{gg} = centroid_cell(members);
end

end
--------------------------------------------------------------------------------
/compute-optimal-transport/compute_single_ot_distance.m:
--------------------------------------------------------------------------------
function [ val, grad ] = compute_single_ot_distance( C, a, b, solver, p )
%COMPUTE_SINGLE_OT_DISTANCE Dispatch an OT computation to the chosen backend.
%   [val, grad] = COMPUTE_SINGLE_OT_DISTANCE(C, a, b, solver, p)
%
%   Inputs:
%     C      - ground cost matrix between histogram bins
%     a, b   - histograms
%     solver - an OTSolver member; anything that matches none of FastEMD,
%              Gurobi, Linprog or Mosek falls through to Sinkhorn
%     p      - order of the distance
%   Outputs:
%     val, grad - value and gradient (wrt a) reported by the backend,
%                 post-processed as follows:
%       p == 2: the backend is run on C.^2 and its outputs returned as is
%       p == 1: the backend is run on C; val is then squared and grad is
%               multiplied by 2*val (chain rule for val^2)
%       other : the backend is run on C and its outputs returned as is

if p == 2
    Cmat = C.^2;
else
    Cmat = C;
end

if solver == OTSolver.FastEMD
    assert(p == 1);                % the FastEMD path is only enabled for p == 1
    [val, grad] = compute_single_ot_distance_fastemd(Cmat, a, b);
elseif solver == OTSolver.Gurobi
    [val, grad] = compute_single_ot_distance_gurobi(Cmat, a, b);
elseif solver == OTSolver.Linprog
    [val, grad] = compute_single_ot_distance_linprog(Cmat, a, b);
elseif solver == OTSolver.Mosek
    [val, grad] = compute_single_ot_distance_mosek(Cmat, a, b);
else %solver == OTSolver.Sinkhorn
    [val, grad] = compute_single_ot_distance_sinkhorn(Cmat, a, b);
end

if p == 1
    % Order matters: the gradient needs the unsquared value.
    grad = 2*val*grad;
    val = val^2;
end

end
--------------------------------------------------------------------------------
/compute-optimal-transport/compute_single_ot_distance_mosek.m:
--------------------------------------------------------------------------------
function [ val, grad ] = compute_single_ot_distance_mosek( C, a, b )
%COMPUTE_SINGLE_OT_DISTANCE_MOSEK OT cost and dual potential via MOSEK.
%   Solves the dual transport LP over x = [alpha; beta]:
%       minimize   -(a'*alpha + b'*beta)
%       subject to A'*x <= C(:),  sum(alpha) <= 1,  -sum(alpha) <= -1
%   where A is the 2n-by-n^2 marginal-constraint matrix (the last two rows
%   together encode sum(alpha) = 1).
%
%   Inputs:
%     C    - n-by-n ground cost matrix
%     a, b - histograms of length n
%   Outputs:
%     val  - negated primal objective of the basic solution
%     grad - alpha, the first n entries of the basic solution
%
%   Raises compute_single_ot_distance_mosek:noSolution when MOSEK returns
%   no basic solution.

n = length(a);

% Row i of the top half selects block i of n consecutive variables; the
% bottom half is n copies of the identity placed side by side.
A = [kron(speye(n), ones(1,n)); kron(ones(1,n), speye(n))];

param.MSK_IPAR_LOG = 0;
param.MSK_IPAR_LOG_HEAD = 0;
param.MSK_IPAR_NUM_THREADS = 1;
%param.MSK_IPAR_OPTIMIZER = 'MSK_OPTIMIZER_FREE_SIMPLEX';
%[res] = msklpopt(-[a(:); b(:)], [A'; ones(1,n), zeros(1,n); -ones(1,n), zeros(1,n)], [], [C(:); 1; -1], [], [], param);
% for interior point solver
%val = -res.sol.itr.pobjval;
%grad = res.sol.itr.xx(1:n);

prob.c = -[a(:); b(:)];
prob.a = [A'; ones(1,n), zeros(1,n); -ones(1,n), zeros(1,n)];
prob.blc = [];
prob.buc = [C(:); 1; -1];
prob.blx = [];
prob.bux = [];
[~, res] = mosekopt('minimize echo(0)', prob, param);

if ~isfield(res, 'sol') || ~isfield(res.sol, 'bas')
    error('compute_single_ot_distance_mosek:noSolution', ...
        'MOSEK returned no basic solution.');
end

% for simplex solver
val = -res.sol.bas.pobjval;
grad = res.sol.bas.xx(1:n);

end
--------------------------------------------------------------------------------
/compute-optimal-transport/compute_single_ot_distance_fastemd.m:
--------------------------------------------------------------------------------
function [ val, grad ] = compute_single_ot_distance_fastemd( C, a, b )
%COMPUTE_SINGLE_OT_DISTANCE_FASTEMD OT cost via FastEMD with recovered duals.
%   Inputs:
%     C    - ground cost matrix
%     a, b - histograms of length n
%   Outputs:
%     val  - value returned by emd_hat_mex for cost sqrt(C)
%     grad - zero-mean dual potential for a, recovered from the flow
%
%   NOTE(review): the value is computed with cost sqrt(C) while the dual
%   variables are recovered with cost C, and the only caller passes an
%   unsquared C (it asserts p == 1). Left unchanged -- confirm which cost
%   is intended.

n = length(a);

% [val, F] = emd_hat_gd_metric_mex(a(:),b(:),C,-1);%,FType)
% uv = dual_variables(C, F);
% grad = uv(1:n) - 1/n*sum(uv(1:n));

% assume C is already made up of _squared_ distances
[val_unsq, F] = emd_hat_mex(a(:),b(:),sqrt(C),-1);%,FType)
%[val_unsq, F] = emd_hat_gd_metric_mex(a(:),b(:),sqrt(C),-1);%,FType)
uv = dual_variables(C, F);
grad_unsq = uv(1:n) - 1/n*sum(uv(1:n));

val = val_unsq;%^2;
grad = grad_unsq;%2*val_unsq * grad_unsq;
end

function [uv] = dual_variables(cost, flow)
% Recover potentials uv = [u; v] from a transport plan by solving, in the
% least-squares sense, u(i) + v(j) = cost(i,j) on every arc with nonzero
% flow, together with the gauge equation u(1) = 0.
n1 = size(flow,1);
n2 = size(flow,2);

[I,J,~] = find(flow);
I = I(:);
J = J(:);
n_arcs = numel(I);

% One equation per active arc plus the gauge row, assembled from triplets.
eq_rows = [(1:n_arcs)'; (1:n_arcs)'; n_arcs+1];
eq_cols = [I; n1 + J; 1];
A = sparse(eq_rows, eq_cols, 1, n_arcs+1, n1+n2);
C_sub = [cost(sub2ind(size(cost), I, J)); 0];

% The system is usually rank deficient; silence the solver's warnings for
% this one solve and restore the caller's warning state afterwards (also
% on error) instead of unconditionally switching every warning back on.
saved_state = warning('off', 'all');
restore_warnings = onCleanup(@() warning(saved_state));
uv = A\C_sub;
end
--------------------------------------------------------------------------------
/initialize_kmeans_plusplus.m:
--------------------------------------------------------------------------------
function [ centroids ] = initialize_kmeans_plusplus( oracle, X, k, use_mcmc )
%INITIALIZE_KMEANS_PLUSPLUS Assigns centroids via D2 weighting
%   centroids = INITIALIZE_KMEANS_PLUSPLUS(oracle, X, k, use_mcmc)
%
%   Inputs:
%     oracle   - dist = oracle(c, x): (squared) distance between a centroid
%                row c and a data row x
%     X        - one data point per row
%     k        - number of centroids to pick
%     use_mcmc - true: approximate the D2 distribution with a Markov chain
%                (Bachem/Krause); false: exact D2 sampling over all of X
%   Output:
%     centroids - k-by-size(X,2) matrix whose rows are rows of X

n = size(X, 1);
d = size(X, 2);
centroids = zeros(k, d);

% The first centroid is drawn uniformly at random.
centroids(1,:) = X(randi(n),:);

% MCMC version from the Bachem/Krause paper
if use_mcmc
    % number of burn-in steps, hardcoded, arbitrary...
    m = 2000;

    for ii=2:k
        chosen = centroids(1:ii-1,:);

        % The chain targets the distribution proportional to the distance
        % to the NEAREST centroid chosen so far, so every chosen centroid
        % has to be consulted (costs ii-1 oracle calls per step).
        x = X(randi(n),:);
        dx = distance_to_nearest(oracle, chosen, x);

        for jj=2:m
            y = X(randi(n),:);
            dy = distance_to_nearest(oracle, chosen, y);
            % Metropolis step; when dx == 0 any candidate with dy > 0 is
            % accepted (Inf > rand) and one with dy == 0 is not (NaN).
            if dy/dx > rand
                x = y;
                dx = dy;
            end
        end

        centroids(ii,:) = x;
    end
else
    % Running minimum of the distance from every point to the centroids
    % chosen so far; updated with only the newest centroid each round.
    squared_distances = Inf(n, 1);
    for ii=2:k
        for jj=1:n
            D = oracle(centroids(ii-1,:), X(jj,:));
            squared_distances(jj) = min(squared_distances(jj), D);
        end
        new_centroid_inx = randsample(n,1,true,squared_distances);

        centroids(ii,:) = X(new_centroid_inx,:);
    end
end

end

function [ dmin ] = distance_to_nearest(oracle, chosen, x)
% Smallest oracle distance between x and any row of chosen.
dmin = Inf;
for cc=1:size(chosen, 1)
    dmin = min(dmin, oracle(chosen(cc,:), x));
end
end
--------------------------------------------------------------------------------
/plotting/centroid_fractions_by_location.m:
--------------------------------------------------------------------------------
function [ relative_occurence_by_centroid ] = centroid_fractions_by_location(X_grouped, centroids, oracle)
%CENTROID_FRACTIONS_BY_LOCATION Per-location frequency of each centroid.
%   Inputs:
%     X_grouped - m-by-n cell array; each cell holds the histograms
%                 observed at one location (rows, stored out of 255) or []
%     centroids - one centroid per row
%     oracle    - distance oracle forwarded to
%                 estimate_metric_kmeans_objective
%   Output:
%     relative_occurence_by_centroid - k-by-1 cell array; element kk is an
%                 m-by-n matrix with the fraction of sampled histograms at
%                 each location whose closest centroid is kk (0 for
%                 locations without data)

[m, n] = size(X_grouped);
N = 100; % number of samples to take

k = size(centroids,1);
relative_occurence_by_centroid = cell(k, 1);
for kk=1:k
    relative_occurence_by_centroid{kk} = zeros(m, n);
end

ws_counts_cell = cell(m, n);

for ii=1:m
    parfor jj=1:n
        X = X_grouped{ii,jj};
        if isempty(X)
            continue;
        end

        X = double(X); % won't scale well to the full dataset because uint8 more efficient
        X = X / 255;
        X = [X, 1 - sum(X, 2)];   % append the remainder bin so rows sum to 1

        % Sampling is without replacement, so never ask for more samples
        % than this location has histograms.
        n_samples = min(N, size(X, 1));
        [~, ws_counts] = estimate_metric_kmeans_objective(oracle, X, centroids, n_samples);

        ws_counts_cell{ii,jj} = ws_counts;
    end
    fprintf('Finished ii=%d of %d\n', ii, m);
end

% Scatter the per-location fractions into one map per centroid.
for ii=1:m
    for jj=1:n
        counts = ws_counts_cell{ii,jj};
        if isempty(counts)
            continue
        end
        for kk=1:k
            relative_occurence_by_centroid{kk}(ii, jj) = counts(kk);
        end
    end
end

end

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Wasserstein k-means++ for Cloud Regime Histogram Clustering
This repository contains the supporting code for the paper:

Staib, Matthew and Jegelka, Stefanie. Wasserstein k-means++ for Cloud Regime Histogram Clustering. In _Proceedings of the Seventh International Workshop on Climate Informatics_, 2017.

```
@inproceedings{staib2017wasserstein,
  author = {Staib, Matthew and Jegelka, Stefanie},
  title = {Wasserstein k-means++ for Cloud Regime Histogram Clustering},
  booktitle = {Proceedings of the Seventh International Workshop on Climate
    Informatics: CI 2017},
  year = {2017}
}
```

## Dependencies
* [TFOCS](https://github.com/cvxr/TFOCS)
* [Gurobi](https://www.gurobi.com/) (for the Wasserstein gradient oracle)
* [MOSEK](https://www.mosek.com/) (for an alternative Wasserstein gradient oracle)
* [FastEMD](http://www.ariel.ac.il/sites/ofirpele/FastEMD/code/) (for an alternative Wasserstein gradient oracle)
* [sinkhornTransport.m from Marco Cuturi](http://marcocuturi.net/SI.html) (for an alternative Wasserstein gradient oracle)
* ISCCP D1 dataset

## Getting started
1. First parse the ISCCP dataset by navigating to the directory with all the .hdf files, then running `extract_cloud_histograms`
2. Add the `compute-optimal-transport` and `plotting` directories to the path, as well as any other third party code (e.g. TFOCS)
3. Run `cluster_histograms`, which will load the preprocessed ISCCP data and run various clustering algorithms. (Be sure to modify the first few lines of `cluster_histograms` to load the `.mat` file from step 1, wherever it was stored.)

At this point, the figures from the paper can be generated by running
* `weather_state_plots`
* `prepare_globe_plots` followed by `make_globe_plots`
--------------------------------------------------------------------------------
/plotting/estimate_metric_kmeans_objective.m:
--------------------------------------------------------------------------------
function [ val, ws_counts ] = estimate_metric_kmeans_objective( oracle, X, centroids_cell, N)
%ESTIMATE_METRIC_KMEANS_OBJECTIVE Estimate the k-means cost of candidate
%centroid sets on a random sample of the data
%   Inputs:
%     oracle         - takes the form [dist, grad] = oracle(x,y) where the
%                      gradient is wrt x
%     X              - a matrix where each row is a histogram
%     centroids_cell - a centroid matrix (one centroid per row) or a cell
%                      array of such matrices, one per candidate
%     N              - number of rows of X to sample without replacement
%                      (optional, default 100; clamped to size(X,1))
%   Outputs:
%     val       - one entry per candidate: summed distance from each
%                 sampled row to its closest centroid, scaled by
%                 size(X,1)/N
%     ws_counts - fraction of sampled rows assigned to each centroid, for
%                 the LAST candidate only
%
%   Candidates whose rows sum to more than 1 are taken to be stored out of
%   255 and are rescaled; this is decided from the first candidate and
%   applied to all. Candidates with a different number of columns than X
%   get a remainder column 1 - sum(row) appended.

if nargin < 4 || isempty(N)
    N = 100;
end

if ~iscell(centroids_cell)
    centroids_cell = {centroids_cell};
end

% Sampling is without replacement, so the sample cannot exceed the data.
N = min(N, size(X,1));
batch = X(randsample(size(X,1), N), :);

divide_by_255 = any(sum(centroids_cell{1}, 2) > 1 + 1e-8);

% use the same batch for all candidate centroids
val = zeros(length(centroids_cell), 1);
for ii=1:length(centroids_cell)
    % Always start from this candidate, so that unscaled candidates are
    % used as they are instead of leaving centroids unset or stale.
    centroids = centroids_cell{ii};
    if divide_by_255
        centroids = centroids / 255; %Euclidean ones are stored out of 255
    end
    if size(centroids,2) ~= size(X,2)
        centroids = [centroids 1-sum(centroids,2)];
    end
    [vals, ws_counts] = find_closest_centroid(centroids, oracle, batch);
    val(ii) = sum(vals) * size(X,1) / N;
end

end

function [vals, ws_counts] = find_closest_centroid(centroids, oracle, X)
% For every row of X: distance to its closest centroid (vals) and the
% fraction of rows assigned to each centroid (ws_counts).
n = size(X, 1);

vals = zeros(n,1);
ws_counts = zeros(size(centroids, 1), 1);
for ii=1:n
    Xii = X(ii,:);
    [vals(ii), inx] = cost_single_point(centroids, oracle, Xii);
    ws_counts(inx) = ws_counts(inx) + 1;
end

ws_counts = ws_counts / sum(ws_counts);
end

function [estimate, inx] = cost_single_point(centroids, oracle, x)
% Smallest oracle distance from x to any centroid, and that centroid's row.
k = size(centroids, 1);
dists = zeros(k,1);
for jj=1:k
    [D, ~] = oracle(centroids(jj,:), x);
    dists(jj) = D;
end

[estimate, inx] = min(dists);
end
--------------------------------------------------------------------------------
/plotting/make_globe_plots.m:
--------------------------------------------------------------------------------
%% actually make the plots (heatmap of cluster frequency)
% Script. Reads from the workspace:
%   X_grouped, lats_unique, lons_unique -- NOTE(review): presumably loaded
%       from grouped_histograms_tropics.mat; confirm
%   relative -- cell array with one frequency map per cluster, as saved by
%       prepare_globe_plots
% Writes weather-state-frequencies.eps.
valid_inx = ~cellfun(@isempty, X_grouped);
[lons, lats] = meshgrid(double(lons_unique), double(lats_unique));
valid_lons = lons(valid_inx);
valid_lats = lats(valid_inx);

scale = 2;
%'Position',scale*[0 0 3.4 3.4], ... %for k=8
figure('Units','inches', ...
    'Position',scale*[0 0 3.4 2.2], ... %for k=5, 2.6 for k=6
    'PaperPositionMode', 'auto');
set(gcf, 'Renderer', 'painters'); % so that it is vector graphics
set(0, 'defaultaxeslooseinset', [0 0 0 0])

% The interpolation grid and the coastline data do not depend on the
% cluster, so they are prepared once.
glon = double(linspace(min(lons_unique), max(lons_unique), 2*140));
glat = double(linspace(min(lats_unique), max(lats_unique), 2*length(lats_unique)));
[gx, gy] = meshgrid(glon, glat);
load coastlines;

for kk=1:length(relative)
    %subplot(length(relative), 1, kk);
    subaxis(length(relative), 1, kk, 'sv', 0, 'sh', 0, 'Padding', 0, 'mr', 0.0, 'ml', 0.0, 'PaddingBottom', 0.0, 'PaddingRight', 0.1, 'PaddingLeft', 0.02);

    % Interpolate the frequencies of the locations that have data.
    freqs = relative{kk};
    valid_freqs = freqs(valid_inx);
    SI = scatteredInterpolant(valid_lons, valid_lats, valid_freqs, 'natural', 'linear');

    ax = worldmap([-15 15], [0 360]);

    setm(ax,...
        'FontUnits','points',...
        'FontWeight','normal',...
        'FontSize',scale*6,...
        'FontName','Times');
    setm(ax, 'mlabelparallel', -15);
    setm(ax, 'MLabelLocation', 0:60:360);
    if kk ~= length(relative)
        mlabel('off');            % meridian labels only on the bottom panel
    end
    setm(ax, 'ParallelLabel', 'off');

    %geoshow(geoid, geoidrefvec, 'DisplayType', 'texturemap'); hold on;
    %surfm(1:10, 1:10, 3*ones(1,10));
    %surfm(valid_lats, valid_lons, valid_freqs);
    gridded = SI(gx,gy);
    % Latitudes are shifted by 90 for plotting; frequencies are clipped at
    % 0 and scaled by 255 to fit the color axis set below.
    surfm(glat - 90, glon, 255*max(0, gridded));

    set_my_colormap;

    h = plotm(coastlat, coastlon, 'Color', 'black');

    %axis tight;
end

hp_bottom = [0.1300    0.1100    0.7750    0.0760];
hp_top = [0.1300    0.8490    0.7750    0.0760];

ticks = 0:20:100;%*max(max(gridded));
tick_labels = cellfun(@num2str, num2cell(ticks), 'uniformoutput',false);

% Map the tick values 0..100 onto the 0..256 color axis.
rescaled_ticks = 256/max(ticks) * ticks;
caxis([0 256])
colorbar('Position', [hp_bottom(1)+hp_bottom(3)+0.01, hp_bottom(2)+0.004, 0.03, hp_top(2)+hp_top(4)-hp_bottom(2)-0.045], ...
    'XTick', rescaled_ticks, ...
    'XTickLabel', tick_labels);

print(gcf, '-depsc','weather-state-frequencies.eps');
--------------------------------------------------------------------------------
/plotting/subaxis.m:
--------------------------------------------------------------------------------
function h=subaxis(varargin)
%SUBAXIS Create axes in tiled positions. (just like subplot)
%   Usage:
%      h=subaxis(rows,cols,cellno[,settings])
%      h=subaxis(rows,cols,cellx,celly[,settings])
%      h=subaxis(rows,cols,cellx,celly,spanx,spany[,settings])
%
% SETTINGS: Spacing,SpacingHoriz,SpacingVert
%           Padding,PaddingRight,PaddingLeft,PaddingTop,PaddingBottom
%           Margin,MarginRight,MarginLeft,MarginTop,MarginBottom
%           Holdaxis
%
%           all units are relative (i.e. from 0 to 1)
%
%           Abbreviations of parameters can be used.. (Eg MR instead of MarginRight)
%           (holdaxis means that it wont delete any axes below.)
%
%
% Example:
%
%   >> subaxis(2,1,1,'SpacingVert',0,'MR',0);
%   >> imagesc(magic(3))
%   >> subaxis(2,'p',.02);
%   >> imagesc(magic(4))
%
% 2001-2014 / Aslak Grinsted  (Feel free to modify this code.)

f=gcf;

% Settings from a previous call are kept in the figure's UserData so that
% later calls can omit them.
UserDataArgsOK=0;
Args=get(f,'UserData');
if isstruct(Args)
    UserDataArgsOK=isfield(Args,'SpacingHorizontal')&&isfield(Args,'Holdaxis')&&isfield(Args,'rows')&&isfield(Args,'cols');
end
% Only write settings back when that does not clobber someone else's data.
OKToStoreArgs=isempty(Args)||UserDataArgsOK;

% Start from the defaults unless the figure already carries subaxis
% settings; this also covers UserData that belongs to other code.
if ~UserDataArgsOK
    Args=struct('Holdaxis',0, ...
        'SpacingVertical',0.05,'SpacingHorizontal',0.05, ...
        'PaddingLeft',0,'PaddingRight',0,'PaddingTop',0,'PaddingBottom',0, ...
        'MarginLeft',.1,'MarginRight',.1,'MarginTop',.1,'MarginBottom',.1, ...
        'rows',[],'cols',[]);
end
Args=parseArgs(varargin,Args,{'Holdaxis'},{'Spacing' {'sh','sv'}; 'Padding' {'pl','pr','pt','pb'}; 'Margin' {'ml','mr','mt','mb'}});

if (length(Args.NumericArguments)>2)
    Args.rows=Args.NumericArguments{1};
    Args.cols=Args.NumericArguments{2};
    %remove these 2 numerical arguments
    Args.NumericArguments=Args.NumericArguments(3:end);
end

if OKToStoreArgs
    set(f,'UserData',Args);
end


switch length(Args.NumericArguments)
    case 0
        return % no arguments but rows/cols....
    case 1
        if numel(Args.NumericArguments{1}) > 1 % restore subplot(m,n,[x y]) behaviour
            [x1, y1] = ind2sub([Args.cols Args.rows],Args.NumericArguments{1}(1)); % subplot and ind2sub count differently (column instead of row first) --> switch cols/rows
            [x2, y2] = ind2sub([Args.cols Args.rows],Args.NumericArguments{1}(end));
        else
            x1=mod((Args.NumericArguments{1}-1),Args.cols)+1; x2=x1;
            y1=floor((Args.NumericArguments{1}-1)/Args.cols)+1; y2=y1;
        end
        %       x1=mod((Args.NumericArguments{1}-1),Args.cols)+1; x2=x1;
        %       y1=floor((Args.NumericArguments{1}-1)/Args.cols)+1; y2=y1;
    case 2
        x1=Args.NumericArguments{1};x2=x1;
        y1=Args.NumericArguments{2};y2=y1;
    case 4
        x1=Args.NumericArguments{1};x2=x1+Args.NumericArguments{3}-1;
        y1=Args.NumericArguments{2};y2=y1+Args.NumericArguments{4}-1;
    otherwise
        error('subaxis argument error')
end


% Size of one grid cell after removing margins and inter-cell spacing.
cellwidth=((1-Args.MarginLeft-Args.MarginRight)-(Args.cols-1)*Args.SpacingHorizontal)/Args.cols;
cellheight=((1-Args.MarginTop-Args.MarginBottom)-(Args.rows-1)*Args.SpacingVertical)/Args.rows;
% Corners of the requested span; y is measured from the top of the figure.
xpos1=Args.MarginLeft+Args.PaddingLeft+cellwidth*(x1-1)+Args.SpacingHorizontal*(x1-1);
xpos2=Args.MarginLeft-Args.PaddingRight+cellwidth*x2+Args.SpacingHorizontal*(x2-1);
ypos1=Args.MarginTop+Args.PaddingTop+cellheight*(y1-1)+Args.SpacingVertical*(y1-1);
ypos2=Args.MarginTop-Args.PaddingBottom+cellheight*y2+Args.SpacingVertical*(y2-1);

if Args.Holdaxis
    h=axes('position',[xpos1 1-ypos2 xpos2-xpos1 ypos2-ypos1]);
else
    h=subplot('position',[xpos1 1-ypos2 xpos2-xpos1 ypos2-ypos1]);
end


set(h,'box','on');
%h=axes('position',[x1 1-y2 x2-x1 y2-y1]);
set(h,'units',get(gcf,'defaultaxesunits'));
set(h,'tag','subaxis');



if (nargout==0), clear h; end;

--------------------------------------------------------------------------------
/metric_streaming_kmeans_plusplus.m:
--------------------------------------------------------------------------------
function [ centroids, out_struct ] = metric_streaming_kmeans_plusplus( oracle, X, opts)
%METRIC_STREAMING_KMEANS_PLUSPLUS Performs histogram clustering with smart
%initialization wrt general metric
%   oracle takes the form [dist, grad] = oracle(x,y) where the gradient is wrt x
%   X is a matrix where each row is a histogram
%   opts is a struct with the following options:
%       k -- the number of cluster
%       smart_seeding -- whether to use the k-means++ seeding
%       use_mcmc -- (read only if smart_seeding) whether the seeding uses
%       the MCMC approximation
%       iters -- number of iterations to run online method
%       batch_size -- number of points to consider in one batch
%       stepsize -- step size/learning rate for stochastic gradient
%       use_decaying_stepsize -- whether to replace stepsize by
%       1/(number of points assigned to the cluster so far)
%       batch_gradients_in_parallel -- whether to use the gradients
%       computed when assigning points to clusters (as compared to updating
%       the gradients later on, during the pass)
%       project -- whether to project each updated centroid onto the
%       simplex (uses TFOCS proj_simplex)
%
%   Returns the k-by-size(X,2) centroid matrix and out_struct with fields:
%       estimated_costs -- per-iteration batch cost scaled to the full data
%       cluster_inx_last_iter -- cluster assignment of every point in the
%       last batch
%       cluster_sizes -- number of points assigned to each cluster overall

% extract options
k = opts.k;
smart_seeding = opts.smart_seeding;
iters = opts.iters;
batch_size = opts.batch_size;
stepsize = opts.stepsize;
use_decaying_stepsize = opts.use_decaying_stepsize;
batch_gradients_in_parallel = opts.batch_gradients_in_parallel;

% The projection operator is only built when it is going to be used, so
% unprojected (Euclidean) runs do not need TFOCS on the path.
if opts.project
    proj_simplex_helper = proj_simplex(); %TFOCS
end

% initialize centroids
if smart_seeding
    use_mcmc = opts.use_mcmc;

    fprintf('initializing centroids via k-means++\n');
    centroids = initialize_kmeans_plusplus( oracle, X, k, use_mcmc );
else
    centroids = X(randsample(size(X,1), k), :);
end

fprintf('fixing random order\n');
perm_inx = randperm(size(X,1));

% now iterate centroid assignment and updates
fprintf('starting main loop\n');
cluster_sizes = zeros(k, 1);
estimated_costs = zeros(iters,1);
cluster_indices = [];              % stays empty when iters == 0
for ii=1:iters
    fprintf('\titeration %d...', ii);
    starttime = tic;

    % Walk through the fixed permutation batch by batch, wrapping around.
    range = mod((0:batch_size-1) + (ii-1)*batch_size, size(X,1)) + 1;
    batch = X(perm_inx(range),:);

    [cluster_indices, vals, grads] = find_closest_centroid(centroids, oracle, batch);
    total_cost_this_iter = 0;
    for jj=1:length(range)
        Xjj = batch(jj,:);
        cluster_inx = cluster_indices(jj);

        % determining stepsizes as in Sculley
        cluster_sizes(cluster_inx) = cluster_sizes(cluster_inx) + 1;

        if use_decaying_stepsize
            stepsize = 1 / cluster_sizes(cluster_inx);
        end

        if batch_gradients_in_parallel
            % Reuse the value/gradient from the assignment step (computed
            % at the centroid position at the start of the batch).
            D = vals(jj);
            grad = grads(jj,:);
        else
            % Recompute at the centroid's current, already updated position.
            [D, grad] = oracle(centroids(cluster_inx,:), Xjj);
        end

        total_cost_this_iter = total_cost_this_iter + D;

        % is this valid?
        grad(isinf(grad)) = 0;

        updated = centroids(cluster_inx,:) - stepsize * grad(:)';
        if opts.project
            [~, updated] = proj_simplex_helper(updated, 1);
        end
        centroids(cluster_inx,:) = updated;
    end
    estimated_costs(ii) = total_cost_this_iter * size(X,1) / batch_size;
    fprintf('done, took %d seconds, cost: %f\n', ceil(toc(starttime)), estimated_costs(ii));
end

out_struct.estimated_costs = estimated_costs;
% Assignments of the whole last batch, not just of its last point.
out_struct.cluster_inx_last_iter = cluster_indices;
out_struct.cluster_sizes = cluster_sizes;

end

function [cluster_inx, vals, grads] = find_closest_centroid(centroids, oracle, X)
% For every row of X: index of the closest centroid, the distance to it
% and the oracle gradient (wrt the centroid) at that centroid.
k = size(centroids, 1);
dim = size(centroids, 2);
n = size(X, 1);

cluster_inx = zeros(n,1);
vals = zeros(n,1);
grads = zeros(n,dim);
parfor ii=1:n
    Xii = X(ii,:);

    dists = zeros(k,1);
    grads_ii = zeros(k,dim);
    for jj=1:k
        [D, grad] = oracle(centroids(jj,:), Xii);
        dists(jj) = D;
        grads_ii(jj,:) = grad;
    end
    [~, inx] = min(dists);
    cluster_inx(ii) = inx;
    vals(ii) = dists(inx);
    grads(ii,:) = grads_ii(inx,:);
end
end

--------------------------------------------------------------------------------
/cluster_histograms.m:
--------------------------------------------------------------------------------
% CLUSTER_HISTOGRAMS Script: run Euclidean and Wasserstein streaming
% k-means++ on the tropics histograms for a grid of option settings, then
% compute pairwise OT distances between the resulting centroids.
%
% NOTE(review): build_pairwise_distance_matrix is called below but does not
% appear in the repository's file listing -- confirm where it comes from.
clear;
[~,hostname] = system('hostname');
if strcmp(strtrim(hostname), 'mstaib.mit.edu')
    load('/mnt/data/climate-data/ISCCP-D1-full/all_histograms_tropics.mat');
else
    load('all_histograms_tropics.mat');
end
X = X_tropics_full; clear X_tropics_full;

%% shuffle the rows (the dataset is not truncated)
inx = randperm(size(X,1));
X = X(inx,:);

% won't scale well to the full dataset because uint8 more efficient;
% long term we shouldn't load the whole dataset into memory...
X = double(X);
X = X / 255;

%% do k-means++ clustering
% Build one options struct per point of the parameter grid; the innermost
% loop repeats every setting 8 times.
ii = 1;
for lambda=[0.5]% 1]%[0.1 0.5 1 2]
for k=[4 5 6 7 8]
for smart_seeding=[1]
for use_decaying_stepsize=[1]
for batch_gradients_in_parallel=[0]
for stepsize=[1e-5]% 1e-3]
for cloud_distance_type=CloudDistanceType.EuclideanBasedOnGrid %[CloudDistanceType.EuclideanBasedOnValues CloudDistanceType.EuclideanBasedOnGrid]
for p=[1] % 2]
for ot_solver=[OTSolver.Gurobi]
for dummy=1:8
    % experiments showed best results from decaying stepsize and NOT batching gradients
    if use_decaying_stepsize && batch_gradients_in_parallel
        continue
    end
    opts.lambda = lambda;
    opts.k = k;
    opts.smart_seeding = smart_seeding;   % honor the loop variable
    opts.iters = 20;
    opts.batch_size = 1000;
    opts.stepsize = stepsize; %1e-4;
    opts.use_decaying_stepsize = use_decaying_stepsize;
    opts.batch_gradients_in_parallel = batch_gradients_in_parallel;
    opts.OTSolver = ot_solver;
    opts.CloudDistanceType = cloud_distance_type;
    opts.use_mcmc = 1;
    opts.p = p;

    opts_vec(ii) = opts;
    ii = ii + 1;
end
end
end
end
end
end
end
end
end
end

num_params = length(opts_vec);
centroids_cell = cell(num_params,1);
out_struct_cell = cell(num_params,1);

euclidean_centroids_cell = cell(num_params,1);
euclidean_out_struct_cell = cell(num_params,1);

fprintf('Starting all Euclidean k-means runs\n');
for ii=1:num_params
    opts = opts_vec(ii);

    [e_centroids, e_out_struct] = euclidean_streaming_kmeans_plusplus(X, opts);
    euclidean_centroids_cell{ii} = e_centroids;
    euclidean_out_struct_cell{ii} = e_out_struct;
end

fprintf('Starting all Wasserstein k-means runs\n');
% Append the remainder bin so that every histogram sums to 1.
X_expanded = [X, 1 - sum(X, 2)];
for ii=1:num_params
    opts = opts_vec(ii);
    lambda = opts.lambda;

    C = build_pairwise_distance_matrix(opts.CloudDistanceType, lambda);
    [centroids, out_struct] = wasserstein_streaming_kmeans_plusplus(C, X_expanded, opts);
    centroids_cell{ii} = centroids;
    out_struct_cell{ii} = out_struct;
end

%% build up pairwise distances
centroid_dists_cell = cell(num_params,1);
for ii=1:num_params
    opts = opts_vec(ii);
    lambda = opts.lambda;
    C = build_pairwise_distance_matrix(opts.CloudDistanceType, lambda);

    centroid_dists_cell{ii} = centroid_distance_matrix(C, centroids_cell{ii}, opts_vec(ii));
end

% %% plot the centroids
% for ii=1:k
%     figure;
%     centroid_mat = reshape(centroids(ii,:), 6, 7);
%     imagesc(centroid_mat);
% end
--------------------------------------------------------------------------------
/extract_cloud_histograms.m:
--------------------------------------------------------------------------------
1 | clear; 2 | files = dir('*.hdf'); 3 | 4 | n = length(files); 5 | skip = 10; 6 | 7 | for ii=1:ceil(n/skip) 8 | start = (ii-1)*skip + 1; 9 | stop = min( ii*skip, n ); 10 | these_files = 
files(start:stop); 11 | 12 | % matrix which stores a histogram in each row 13 | X_tropics = []; 14 | lats_tropics = []; 15 | lons_tropics = []; 16 | %TODO: also save time information 17 | 18 | for file=these_files' 19 | for k=[6 8 10 12 14 16 18 20] 20 | dataset_name = strcat('/Data-Set-', num2str(k)); 21 | data_all = hdfread(strcat(file.folder, '/', file.name), dataset_name, 'Index', {[1 1],[1 1],[6596 202]}); 22 | data = data_all(:, 30:71); 23 | 24 | ANC2 = hdfread(strcat(file.folder, '/', file.name), '/ANC2'); 25 | lats = ANC2{6}; lons = ANC2{7}; 26 | 27 | tropics_inx = abs(lats - 90) <= 15; 28 | tropics_inx = tropics_inx(:); 29 | undefined_inx = any(data == 255, 2); 30 | undefined_inx = undefined_inx(:); 31 | saved_inx = tropics_inx & ~undefined_inx; 32 | 33 | %data_tropics = data(saved_inx,:); 34 | X_tropics = [X_tropics; data(saved_inx,:)]; 35 | lats_tropics = [lats_tropics lats(saved_inx)]; 36 | lons_tropics = [lons_tropics lons(saved_inx)]; 37 | 38 | % 39 | % 40 | % data_tropics = data(tropics_inx,:); 41 | % 42 | % % 255 means UNDEFINED, so remove all those rows 43 | % any_undefined_tropics = any(data_tropics == 255, 2); 44 | % data_tropics = data_tropics(~any_undefined_tropics, :); 45 | % 46 | % % if all entries are zero, we will have problems normalizing 47 | % %% WAIT: no we won't! 
Because now we have an extra "no cloud" state 48 | % %all_zeros = all(data_extended_tropics == 0, 2); 49 | % %data_extended_tropics = data_extended_tropics(~all_zeros, :); 50 | % 51 | % % don't normalize, so that we use less memory 52 | % %%% now normalize, to make it a histogram 53 | % %%%data = double(data) ./ sum(data, 2); 54 | % 55 | % % append to X 56 | % X_tropics = [X_tropics; data_tropics]; 57 | end 58 | 59 | fprintf('Finished file: %s\n', file.name); 60 | end 61 | 62 | %save(strcat('all_histograms_extended_tropics', num2str(ii), '.mat'), 'X_extended_tropics', '-v7.3'); 63 | save(strcat('all_histograms_tropics', num2str(ii), '.mat'), 'X_tropics', 'lats_tropics', 'lons_tropics', '-v7.3'); 64 | end 65 | 66 | % extended_tropics_files = dir('all_histograms_extended_tropics*.mat'); 67 | % X_extended_tropics_full = []; 68 | % for file=extended_tropics_files' 69 | % load(file.name); 70 | % X_extended_tropics_full = [X_extended_tropics_full; X_extended_tropics]; 71 | % end 72 | % save('all_histograms_extended_tropics.mat', 'X_extended_tropics_full'); 73 | 74 | tropics_files = dir('all_histograms_extended_tropics*.mat'); 75 | X_tropics_full = []; 76 | lats_tropics_full = []; 77 | lons_tropics_full = []; 78 | for file=tropics_files' 79 | load(file.name); 80 | X_tropics_full = [X_tropics_full; X_tropics]; 81 | lats_tropics_full = [lats_tropics_full lats_tropics]; 82 | lons_tropics_full = [lons_tropics_full lons_tropics]; 83 | end 84 | save('all_histograms_tropics.mat', 'X_tropics_full', 'lats_tropics_full', 'lons_tropics_full', '-v7.3'); 85 | 86 | % for file=files' 87 | % for k=[6 8 10 12 14 16 18 20] 88 | % dataset_name = strcat('/Data-Set-', num2str(k)); 89 | % data_all = hdfread(strcat(file.folder, '/', file.name), dataset_name, 'Index', {[1 1],[1 1],[6596 202]}); 90 | % data = data_all(:, 30:71); 91 | % 92 | % % 255 means UNDEFINED, so remove all those rows 93 | % any_undefined = any(data == 255, 2); 94 | % data = data(~any_undefined, :); 95 | % 96 | % % if all 
entries are zero, we will have problems normalizing 97 | % all_zeros = all(data == 0, 2); 98 | % data = data(~all_zeros, :); 99 | % 100 | % % now normalize, to make it a histogram 101 | % data = double(data) ./ sum(data, 2); 102 | % 103 | % % append to X 104 | % X = [X; data]; 105 | % end 106 | % end 107 | % 108 | % save('all_histograms.mat', 'X'); 109 |
--------------------------------------------------------------------------------
/plotting/set_my_colormap.m:
--------------------------------------------------------------------------------
% SET_MY_COLORMAP  Define the 256-entry RGB table myColormap (values 0-255,
% one colour per row, starting at white) and apply it, rescaled to [0,1],
% through colormap. Four table rows are written per source line.
myColormap = [ ...
    255,255,255; 252,253,255; 249,252,254; 246,250,254;
    244,249,254; 241,247,254; 238,246,253; 235,244,253;
    232,243,253; 229,241,252; 227,240,252; 224,238,252;
    221,237,252; 218,235,251; 215,233,251; 212,232,251;
    209,230,250; 207,229,250; 204,227,250; 201,226,249;
    198,224,249; 195,223,249; 192,221,249; 189,220,248;
    187,218,248; 184,217,248; 181,215,247; 178,213,247;
    175,212,247; 172,210,247; 170,209,246; 167,207,246;
    164,206,246; 161,204,245; 158,203,245; 155,201,245;
    152,200,245; 150,198,244; 147,197,244; 144,195,244;
    141,193,243; 138,192,243; 135,190,243; 133,189,242;
    130,187,242; 127,186,242; 124,184,242; 121,183,241;
    118,181,241; 115,180,241; 113,178,240; 110,176,240;
    107,175,240; 104,173,240; 101,172,239; 98,170,239;
    95,169,239; 93,167,238; 90,166,238; 87,164,238;
    84,163,238; 81,161,237; 78,160,237; 76,158,237;
    73,157,236; 73,158,235; 73,159,233; 73,160,232;
    72,161,231; 72,162,229; 72,162,228; 72,163,226;
    71,164,225; 71,165,224; 71,166,222; 71,167,221;
    70,168,219; 70,169,218; 70,170,217; 70,171,215;
    69,171,214; 69,172,212; 69,173,211; 69,174,210;
    68,175,208; 68,176,207; 68,177,205; 67,178,204;
    67,179,203; 67,180,201; 67,180,200; 66,181,198;
    66,182,197; 66,183,196; 66,184,194; 65,185,193;
    65,186,192; 65,187,190; 65,188,189; 64,188,187;
    64,189,186; 64,190,185; 64,191,183; 63,192,182;
    63,193,180; 63,194,179; 63,195,178; 62,196,176;
    62,197,175; 62,197,173; 62,198,172; 61,199,171;
    61,200,169; 61,201,168; 61,202,166; 60,203,165;
    60,204,164; 60,205,162; 60,206,161; 59,206,159;
    59,207,158; 59,208,157; 59,209,155; 58,210,154;
    58,211,152; 58,212,151; 58,213,150; 57,214,148;
    59,214,147; 62,215,147; 65,216,147; 68,216,146;
    71,217,146; 74,218,146; 77,218,145; 80,219,145;
    83,220,144; 86,220,144; 89,221,144; 92,222,143;
    95,222,143; 98,223,143; 101,224,142; 104,224,142;
    107,225,141; 110,226,141; 113,226,141; 117,227,140;
    120,227,140; 123,228,140; 126,229,139; 129,229,139;
    132,230,138; 135,231,138; 138,231,138; 141,232,137;
    144,233,137; 147,233,137; 150,234,136; 153,235,136;
    156,235,135; 159,236,135; 162,237,135; 165,237,134;
    168,238,134; 171,239,133; 174,239,133; 177,240,133;
    180,241,132; 183,241,132; 186,242,132; 190,243,131;
    193,243,131; 196,244,130; 199,244,130; 202,245,130;
    205,246,129; 208,246,129; 211,247,129; 214,248,128;
    217,248,128; 220,249,127; 223,250,127; 226,250,127;
    229,251,126; 232,252,126; 235,252,126; 238,253,125;
    241,254,125; 244,254,124; 247,255,124; 250,255,124;
    251,253,122; 251,249,121; 251,245,119; 251,241,118;
    251,237,116; 252,234,115; 252,230,113; 252,226,112;
    252,222,110; 252,218,109; 252,214,107; 252,210,105;
    252,207,104; 252,203,102; 252,199,101; 252,195,99;
    252,191,98; 252,187,96; 253,184,95; 253,180,93;
    253,176,91; 253,172,90; 253,168,88; 253,164,87;
    253,160,85; 253,157,84; 253,153,82; 253,149,81;
    253,145,79; 253,141,77; 253,137,76; 253,133,74;
    254,130,73; 254,126,71; 254,122,70; 254,118,68;
    254,114,67; 254,110,65; 254,106,64; 254,103,62;
    254,99,60; 254,95,59; 254,91,57; 254,87,56;
    254,83,54; 254,79,53; 255,76,51; 255,72,50;
    255,68,48; 255,64,46; 255,60,45; 255,56,43;
    255,52,42; 255,49,40; 255,45,39; 255,41,37;
    255,37,36; 255,33,34; 255,29,32; 255,25,31;
    255,22,29; 255,18,28; 255,14,26; 255,10,25];

% colormap expects intensities in [0,1], hence the division by 255.
colormap(myColormap ./ 255);
--------------------------------------------------------------------------------
/plotting/parseArgs.m:
--------------------------------------------------------------------------------
1 | function ArgStruct=parseArgs(args,ArgStruct,varargin) 2 | % Helper function for parsing varargin. 3 | % 4 | % 5 | % ArgStruct=parseArgs(varargin,ArgStruct[,FlagtypeParams[,Aliases]]) 6 | % 7 | % * ArgStruct is the structure full of named arguments with default values. 8 | % * Flagtype params is params that don't require a value. 
(the value will be set to 1 if it is present) 9 | % * Aliases can be used to map one argument-name to several argstruct fields 10 | % 11 | % 12 | % example usage: 13 | % -------------- 14 | % function parseargtest(varargin) 15 | % 16 | % %define the acceptable named arguments and assign default values 17 | % Args=struct('Holdaxis',0, ... 18 | % 'SpacingVertical',0.05,'SpacingHorizontal',0.05, ... 19 | % 'PaddingLeft',0,'PaddingRight',0,'PaddingTop',0,'PaddingBottom',0, ... 20 | % 'MarginLeft',.1,'MarginRight',.1,'MarginTop',.1,'MarginBottom',.1, ... 21 | % 'rows',[],'cols',[]); 22 | % 23 | % %The capital letters define abbreviations. 24 | % % Eg. parseargtest('spacingvertical',0) is equivalent to parseargtest('sv',0) 25 | % 26 | % Args=parseArgs(varargin,Args, ... % fill the arg-struct with values entered by the user 27 | % {'Holdaxis'}, ... %this argument has no value (flag-type) 28 | % {'Spacing' {'sh','sv'}; 'Padding' {'pl','pr','pt','pb'}; 'Margin' {'ml','mr','mt','mb'}}); 29 | % 30 | % disp(Args) 31 | % 32 | % 33 | % 34 | % 35 | % Aslak Grinsted 2004 36 | 37 | % ------------------------------------------------------------------------- 38 | % Copyright (C) 2002-2004, Aslak Grinsted 39 | % This software may be used, copied, or redistributed as long as it is not 40 | % sold and this copyright notice is reproduced on each copy made. This 41 | % routine is provided as is without any express or implied warranties 42 | % whatsoever. 
43 | 44 | persistent matlabver 45 | 46 | if isempty(matlabver) 47 | matlabver=ver('MATLAB'); 48 | matlabver=str2double(matlabver.Version); 49 | end 50 | 51 | Aliases={}; 52 | FlagTypeParams=''; 53 | 54 | if (length(varargin)>0) 55 | FlagTypeParams=lower(strvcat(varargin{1})); %#ok 56 | if length(varargin)>1 57 | Aliases=varargin{2}; 58 | end 59 | end 60 | 61 | 62 | %---------------Get "numeric" arguments 63 | NumArgCount=1; 64 | while (NumArgCount<=size(args,2))&&(~ischar(args{NumArgCount})) 65 | NumArgCount=NumArgCount+1; 66 | end 67 | NumArgCount=NumArgCount-1; 68 | if (NumArgCount>0) 69 | ArgStruct.NumericArguments={args{1:NumArgCount}}; 70 | else 71 | ArgStruct.NumericArguments={}; 72 | end 73 | 74 | 75 | %--------------Make an accepted fieldname matrix (case insensitive) 76 | Fnames=fieldnames(ArgStruct); 77 | for i=1:length(Fnames) 78 | name=lower(Fnames{i,1}); 79 | Fnames{i,2}=name; %col2=lower 80 | Fnames{i,3}=[name(Fnames{i,1}~=name) ' ']; %col3=abreviation letters (those that are uppercase in the ArgStruct) e.g. SpacingHoriz->sh 81 | %the space prevents strvcat from removing empty lines 82 | Fnames{i,4}=isempty(strmatch(Fnames{i,2},FlagTypeParams)); %Does this parameter have a value? 83 | end 84 | FnamesFull=strvcat(Fnames{:,2}); %#ok 85 | FnamesAbbr=strvcat(Fnames{:,3}); %#ok 86 | 87 | if length(Aliases)>0 88 | for i=1:length(Aliases) 89 | name=lower(Aliases{i,1}); 90 | FieldIdx=strmatch(name,FnamesAbbr,'exact'); %try abbreviations (must be exact) 91 | if isempty(FieldIdx) 92 | FieldIdx=strmatch(name,FnamesFull); %&??????? exact or not? 
93 | end 94 | Aliases{i,2}=FieldIdx; 95 | Aliases{i,3}=[name(Aliases{i,1}~=name) ' ']; %the space prevents strvcat from removing empty lines 96 | Aliases{i,1}=name; %dont need the name in uppercase anymore for aliases 97 | end 98 | %Append aliases to the end of FnamesFull and FnamesAbbr 99 | FnamesFull=strvcat(FnamesFull,strvcat(Aliases{:,1})); %#ok 100 | FnamesAbbr=strvcat(FnamesAbbr,strvcat(Aliases{:,3})); %#ok 101 | end 102 | 103 | %--------------get parameters-------------------- 104 | l=NumArgCount+1; 105 | while (l<=length(args)) 106 | a=args{l}; 107 | if ischar(a) 108 | paramHasValue=1; % assume that the parameter has is of type 'param',value 109 | a=lower(a); 110 | FieldIdx=strmatch(a,FnamesAbbr,'exact'); %try abbreviations (must be exact) 111 | if isempty(FieldIdx) 112 | FieldIdx=strmatch(a,FnamesFull); 113 | end 114 | if (length(FieldIdx)>1) %shortest fieldname should win 115 | [mx,mxi]=max(sum(FnamesFull(FieldIdx,:)==' ',2));%#ok 116 | FieldIdx=FieldIdx(mxi); 117 | end 118 | if FieldIdx>length(Fnames) %then it's an alias type. 119 | FieldIdx=Aliases{FieldIdx-length(Fnames),2}; 120 | end 121 | 122 | if isempty(FieldIdx) 123 | error(['Unknown named parameter: ' a]) 124 | end 125 | for curField=FieldIdx' %if it is an alias it could be more than one. 
126 | if (Fnames{curField,4}) 127 | if (l+1>length(args)) 128 | error(['Expected a value for parameter: ' Fnames{curField,1}]) 129 | end 130 | val=args{l+1}; 131 | else %FLAG PARAMETER 132 | if (l=6 150 | ArgStruct.(Fnames{curField,1})=val; %try the line below if you get an error here 151 | else 152 | ArgStruct=setfield(ArgStruct,Fnames{curField,1},val); %#ok <-works in old matlab versions 153 | end 154 | end 155 | l=l+1+paramHasValue; %if a wildcard matches more than one 156 | else 157 | error(['Expected a named parameter: ' num2str(a)]) 158 | end 159 | end
--------------------------------------------------------------------------------
/plotting/weather_state_plots.m:
--------------------------------------------------------------------------------
% WEATHER_STATE_PLOTS  Figures for one saved clustering run.
%
% Produces, in order:
%   1. estimated-cost curves, split by whether smart seeding was used;
%   2. the weather-state panel figure, printed to weather-states.pdf;
%   3. pairwise distances between the selected centroids;
%   4. the base map for a cluster-frequency heatmap.

% replace this with the saved run to visualise; the code below uses
% out_struct_cell, centroids_cell and opts_vec, which are expected to come
% from this file
load('emd_runs_7_20_gurobionly.mat');

load('all_histograms_tropics.mat');
X = X_tropics_full; clear X_tropics_full;
inx = randperm(size(X,1));
X = X(inx,:);
X = double(X) / 255; % won't scale well to the full dataset because uint8 more efficient
X_expanded = [X, 1 - sum(X, 2)];

%% estimated cost per run
figure;
for kk = 1:length(out_struct_cell)
    % top panel: runs without smart seeding, bottom panel: runs with it
    if opts_vec(kk).smart_seeding == 0
        subplot(2,1,1);
    else
        subplot(2,1,2);
    end
    plot(out_struct_cell{kk}.estimated_costs); hold on;
end
%
% cellfun(@(x) x.estimated_costs(end), out_struct_cell([opts_vec.smart_seeding] == 0));
%
% cellfun(@(x) x.estimated_costs(end), out_struct_cell([opts_vec.smart_seeding] == 1));

% NOTE(review): the cost matrix is built with the second argument fixed to 1
% and the solver fixed to Gurobi, independent of the per-run options --
% confirm this is the intended common yardstick.
C = build_pairwise_distance_matrix(opts_vec(1).CloudDistanceType, 1);
oracle = @(x,y) compute_single_ot_distance(C, x, y, OTSolver.Gurobi, opts_vec(1).p);
% cost_of_euclidean_clusters = estimate_metric_kmeans_objective(oracle, X_expanded, euclidean_centroids_cell);
%

%% select plots
[g, c, p] = group_by_opts(out_struct_cell, centroids_cell, opts_vec);

% For every group keep the centroids of the run with the lowest final
% estimated cost.
centroids_to_plot = cell(1, length(g));
for ii = 1:length(g)
    costs = cellfun(@(z) z.estimated_costs(end), g{ii});
    [~, I] = min(costs);
    centroids_to_plot{ii} = c{ii}{I};
end

%% Weather State plots

font_args = {'FontUnits','points', 'FontSize',4, 'FontName','Times'};

% Only group 2 is rendered; looping over 1:length(centroids_to_plot) would
% overwrite weather-states.pdf on every iteration.
for ii = 2
    centroids = centroids_to_plot{ii};
    % Drop the last column and rescale by 255.
    centroids = centroids(:, 1:42) * 255;

    % NOTE(review): the centroids passed here have 42 columns on the 0-255
    % scale while X_expanded has 43 columns on the 0-1 scale; verify that
    % estimate_metric_kmeans_objective expects exactly this.
    [~, rfo] = estimate_metric_kmeans_objective(oracle, X_expanded, centroids, 2000);
    tcc = sum(centroids * 100/255, 2);

    %x_cutoffs = {'0.02' '1.27' '3.55' '9.38' '22.63' '60.36' '378.65'};
    x_cutoffs = {'0' '1.3' '3.6' '9.4' '23' '60' '379'};
    y_cutoffs = {'10' '180' '310' '440' '560' '680' '800' '1000'};

    num_ws = size(centroids,1);
    num_cols = ceil(num_ws / 2);
    max_val = max(centroids(:));

    figure('Units','inches', ...
        'Position',[0 0 3.4 2.0], ...
        'PaperPositionMode', 'auto');
    set(gcf, 'Renderer', 'painters'); % so that it is vector graphics

    % A dedicated, preallocated handle array: `ax` is reused further down
    % for the world map.
    ws_ax = gobjects(1, num_ws);
    for kk = 1:num_ws
        % 42 bins -> 7 rows by 6 columns
        ws = reshape(centroids(kk,:), 6, 7)';

        subplot(2,num_cols,kk);
        subaxis(2,num_cols,kk,'PaddingBottom', 0.07);

        ws_ax(kk) = gca;
        set(ws_ax(kk),...
            'Units','normalized',...
            'FontUnits','points',...
            'FontWeight','normal',...
            'FontSize',4,...
            'FontName','Times');

        % Stretch so that the largest bin over all centroids maps to 255.
        ws_scaled = ws * 255 / max_val;
        image(ws_scaled);

        % Panels in the bottom row (including the last panel of the top row
        % when the panel count is odd) carry the x annotations; the first
        % panel of each row carries the y annotations.
        is_bottom = kk > num_cols || (kk == num_cols && mod(num_ws,2) ~= 0);
        is_left = kk == 1 || kk == 1+num_cols;

        if is_bottom
            xlabel(ws_ax(kk), 'Cloud Optical Thickness', font_args{:});
        end
        if is_left
            ylabel(ws_ax(kk), 'Cloud Top Pressure (mb)', font_args{:});
        end

        this_title = sprintf('WS%d, RFO=%d%%, TCC=%d%%', kk, round(100*rfo(kk)), round(tcc(kk)));
        title(this_title, ...
            'FontUnits','points',...
            'FontWeight','normal',...
            'FontSize',4,...
            'FontName','Times');

        axis on;
        grid on;
        % ticks sit on the cell edges of the 7x6 image
        set(ws_ax(kk), 'xtick', (1:7) - 0.5);
        set(ws_ax(kk), 'ytick', (1:8) - 0.5);

        if is_bottom
            set(ws_ax(kk), 'xticklabel', x_cutoffs, font_args{:});
        else
            set(ws_ax(kk), 'xticklabel', cell(0));
        end

        if is_left
            set(ws_ax(kk), 'yticklabel', y_cutoffs, font_args{:});
        else
            set(ws_ax(kk), 'yticklabel', cell(0));
        end
    end

    ticks = 0:5:max_val;
    tick_labels = cellfun(@num2str, num2cell(ticks), 'uniformoutput',false);

    % Place the ticks with the same factor that scales the images
    % (255/max_val). The previous 255/max(ticks) only matched when max_val
    % was a multiple of 5 and produced NaN when max_val < 5.
    rescaled_ticks = 255/max_val * ticks + 1;
    cbar = colorbar('SouthOutside', 'XTick', rescaled_ticks, ...
        'XTickLabel', tick_labels);
    set(cbar, 'Position', [.1 .05 .8150 .05]);

    myColorMap = parula(255); % start from the parula map
    % Replace its first row with white (all 1's).
    myColorMap(1, :) = [1 1 1];
    colormap(myColorMap); % Apply the colormap

    % Move the second-row panels to a fixed vertical position.
    for kk = (num_cols+1):num_ws
        hp = get(ws_ax(kk), 'Position');
        hp(2) = 0.22;
        set(ws_ax(kk), 'Position', hp);
    end

    print(gcf, '-dpdf','weather-states.pdf');
end

%% distances between centroids
centroid_dists_cell = cell(length(centroids_to_plot), 1);
for ii = 1:length(centroids_to_plot)
    opts = p{ii};
    C = build_pairwise_distance_matrix(opts.CloudDistanceType, opts.lambda);

    centroid_dists_cell{ii} = centroid_distance_matrix(C, centroids_to_plot{ii}, opts);
end

% Smallest non-zero centroid-to-centroid distance per group; NaN when a
% group has no non-zero distance (previously an assignment error).
min_between_cluster_distance = zeros(length(centroid_dists_cell),1);
for ii = 1:length(centroid_dists_cell)
    dists = centroid_dists_cell{ii};
    nonzero_dists = dists(dists ~= 0);
    if isempty(nonzero_dists)
        min_between_cluster_distance(ii) = NaN;
    else
        min_between_cluster_distance(ii) = min(nonzero_dists);
    end
end

%% heatmap of cluster frequency

figure('Units','inches', ...
    'Position',[0 0 3.4 1.0], ...
    'PaperPositionMode', 'auto');
set(gcf, 'Renderer', 'painters'); % so that it is vector graphics

ax = worldmap([-15 15], [0 360]);
setm(ax, 'mlabelparallel', -15);
setm(ax, 'MLabelLocation', 0:60:360)

load geoid;
geoshow(geoid, geoidrefvec, 'DisplayType', 'texturemap'); hold on;

load coastlines;
[latcells, loncells] = polysplit(coastlat, coastlon);
h = plotm(coastlat, coastlon, 'Color', 'black');

% load topo
% R = georasterref('RasterSize', size(topo), ...
%     'LatitudeLimits', [-15 15], 'LongitudeLimits', [0 360], 'RasterSize', [30 360]);
% grid2image(topo(75:104,:), R);

--------------------------------------------------------------------------------