├── .gitignore
├── LICENSE
├── README.md
├── docs
│   └── tree15.pdf
├── examples
│   ├── dirichrnd.m
│   ├── sample_mog.m
│   ├── test_bimodal.m
│   ├── test_mixture_of_gaussian.m
│   └── test_rare_binary.m
└── src
    ├── PART
    │   ├── MultiStageMCMC.m
    │   ├── NormalizeForest.m
    │   ├── NormalizeTree.m
    │   ├── OneStageMCMC.m
    │   ├── aggregate_PART_onestage.m
    │   ├── aggregate_PART_pairwise.m
    │   ├── buildForest.m
    │   ├── buildHistogram.m
    │   ├── buildTree.m
    │   ├── cleanTree.m
    │   ├── copyTree.m
    │   ├── cut
    │   │   ├── consensusMLCut.m
    │   │   ├── fastMLCut.m
    │   │   ├── kdCut.m
    │   │   ├── meanCut.m
    │   │   └── midPointCut.m
    │   ├── empiricalLogLikelihood.m
    │   ├── part_options.m
    │   ├── treeDensity.m
    │   ├── treeNode.m
    │   ├── treeNormalize.m
    │   └── treeSampling.m
    ├── alternatives
    │   ├── aggregate_average.m
    │   ├── aggregate_uai_nonparametric.m
    │   ├── aggregate_uai_parametric.m
    │   ├── aggregate_uai_semiparametric.m
    │   ├── aggregate_weighted_average.m
    │   └── logmvnpdf.m
    ├── init.m
    └── utils
        ├── approximate_KL.m
        ├── parfor_progress.m
        ├── performance_table.m
        ├── plot_marginal_compare.m
        ├── plot_pairwise_compare.m
        ├── plot_tree_blocks.m
        ├── relative_error.m
        ├── rmse_posterior_cov.m
        ├── rmse_posterior_mean.m
        └── thinning.m

/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.gitignore.io/api/matlab
2 |
3 | ### Matlab ###
4 | ##---------------------------------------------------
5 | ## Remove autosaves generated by the Matlab editor
6 | ## We have git for backups!
7 | ##---------------------------------------------------
8 |
9 | # Windows default autosave extension
10 | *.asv
11 |
12 | # OSX / *nix default autosave extension
13 | *.m~
14 | .DS_Store
15 |
16 | # Compiled MEX binaries (all platforms)
17 | *.mex*
18 |
19 | # Simulink Code Generation
20 | slprj/
21 |
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2015 Richard Kwo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PART: Parallelizing MCMC with Random Partition Trees
2 |
3 | MATLAB implementation of **PART**[1](#myfootnote1): a fast algorithm for aggregating MCMC sub-chain samples
4 |
5 | [TOC]
6 |
7 | ## Introduction
8 |
9 | The modern scale of data has brought new challenges to Bayesian inference. In particular, conventional MCMC algorithms are computationally very expensive for large data sets. A promising approach to this problem is **embarrassingly parallel MCMC** (EP-MCMC), which first partitions the data into multiple subsets and runs independent sampling algorithms on each subset. The subset posterior draws are then aggregated via a combination rule to obtain the final approximation.
10 |
11 | **PART**[1](#myfootnote1) is an EP-MCMC algorithm that uses random partition trees to combine the subset posterior draws. The resulting estimator is distribution-free, easy to resample from, and adapts to multiple scales. This repository maintains a MATLAB implementation of PART.
12 |
13 | ## QuickStart
14 |
15 | Running parallel MCMC with PART consists of the following steps.
16 |
17 | 1. Partition the data into M subsets.
18 |
19 |    For example, given 1,000 data points, one can partition them into M=10 subsets of 100 data points each. The subsets need not be equal-sized.
20 |
21 | 2. Run MCMC on each subset in parallel with your favorite sampler.
22 |
23 |    For example, if the model has p=4 parameters and you draw 5,000 posterior **sub-chain samples** on each subset i (i=1,2,…,10), you obtain 10 matrices of size 5,000 x 4.
24 |
25 |    > **Note**: PART currently supports real-valued samples only.
26 |
27 | 3. Call PART to aggregate the **sub-chain samples** into **posterior samples**, as if the posterior samples came from running MCMC on the whole dataset.
28 |
29 |    With the running example, do the following in MATLAB.
30 |
31 |    1. Run `init` in `src/` to add the required paths.
32 |
33 |    2. Store all the sub-chain samples in a cell array `sub_chain = cell(1, 10)`, with `sub_chain{i}` being the 5,000 x 4 matrix of MCMC samples from subset `i`.
34 |
35 |    3. Call PART to aggregate the samples. The following code draws 10,000 samples from the combined posterior using the **PART-KD** algorithm with **pairwise** aggregation. `combined_posterior_kd_pairwise` will be a 10,000 x 4 matrix.
36 |
37 |       ``` octave
38 |       options = part_options('cut_type', 'kd', 'resample_N', 10000);
39 |       combined_posterior_kd_pairwise = aggregate_PART_pairwise(sub_chain, options);
40 |       ```
41 |
42 | > **Before running the code above**: run `src/init.m` to make sure paths are included.
43 |
44 | We strongly recommend reading and running the demos under the `examples/` directory.
45 |
46 | ### Aggregation Schemes
47 |
48 | We provide two schemes for aggregation: **pairwise** (recommended) and **one-stage**.
49 |
50 | 1. Calling `combined_posterior = aggregate_PART_pairwise(sub_chain, options)` aggregates subset samples recursively in pairs. For example, with 4 sub-chains, it first aggregates 1 and 2 into (1+2) and 3 and 4 into (3+4), then aggregates (1+2) and (3+4) into the final samples.
51 |
52 |    Setting `options.match=true` (the default) pairs subsets based on the distances between their sub-chain means, instead of simply in index order.
53 |
54 | 2. Calling `combined_posterior = aggregate_PART_onestage(sub_chain, options)` aggregates all subsets at once. This is not recommended when the number of subsets is large.
55 |
56 | ### Options
57 |
58 | The parameters of PART are set by the `options` structure, which is configured by
59 |
60 | ```
61 | options = part_options('Option_1', Value_1, 'Option_2', Value_2, ...)
62 | ```
63 |
64 | in terms of *(Name, Value)* pairs. `part_options(…)` accepts the following configurations.
65 |
66 | 1. `ntree`: number of trees (default = 16)
67 | 2. `resample_N`: number of samples drawn from the combined posterior (default = 10,000)
68 | 3. `parallel`: if `true`, build trees in parallel with the MATLAB Parallel Computing Toolbox (default = `true`)
69 | 4. `cut_type`: partition rule[1](#myfootnote1), currently either `kd` (median cut, default) or `ml` (maximum likelihood cut). `kd` can be significantly faster than `ml` for large sample sizes.
70 | 5. `min_cut_length`: stopping rule, the minimum side length of a block, corresponding to $\delta_a$ in the paper[1](#myfootnote1) (default = 0.001). We recommend rescaling the samples so that a single `min_cut_length` is appropriate for all dimensions. Alternatively, `min_cut_length` can be set to a vector of length p, with one entry per dimension of the parameter.
71 | 6. `min_fraction_block`: stopping rule, corresponding to $\delta_{\rho}$ in the paper[1](#myfootnote1): each leaf node must contain at least this fraction of the samples, so it is roughly the minimum probability of a block. Must be a number in (0,1) (default = 0.01).
72 | 7. `local_gaussian_smoothing`: if `true`, local Gaussian smoothing[1](#myfootnote1) is applied to the block-wise densities (default = `true`).
73 | 8. `match`: if `true`, use improved pairwise matching in `aggregate_PART_pairwise(…)` (default = `true`).
74 |
75 | Please refer to `help part_options` for more options.
76 |
77 | ## Evaluation and Visualization Tools
78 |
79 | We have implemented the following functions in `src/utils` for evaluation and visualization.
80 |
81 | 1. `plot_marginal_compare({chain_1, chain_2, ...}, {name_1, name_2, ...}, ...)` plots the marginals of several samples.
82 | 2. `plot_pairwise_compare({chain_1, chain_2, ...}, {name_1, name_2, ...}, ...)` plots the bivariate distributions of several samples.
83 | 3. `approximate_KL(samples_left, samples_right)` computes an approximate KL divergence KL(left||right) between two samples, based on Gaussian (Laplace) approximations fitted to both.
84 | 4. `performance_table({chain_1, chain_2, ...}, {'chain_1', 'chain_2', ...}, full_posterior, theta)` outputs a table comparing the approximation accuracy of each chain against `full_posterior`. `theta` is a point estimate of the sampled parameters (e.g., the true value in a simulation).
85 |
86 | ## Other Algorithms
87 |
88 | We also implement several other popular aggregation algorithms under `src/alternatives`.
89 |
90 | 1. `aggregate_average(sub_chain)` aggregates by simple averaging.
91 | 2. `aggregate_weighted_average(sub_chain)` aggregates by weighted averaging, with weights that are optimal for Gaussian distributions.
92 | 3. `aggregate_uai_parametric(sub_chain, n)` draws n samples from the product of Gaussian (Laplace) approximations to the subset posteriors[2](#myfootnote2).
93 | 4. `aggregate_uai_nonparametric(sub_chain, n)` draws n samples from the product of kernel density estimates of the subset posteriors[2](#myfootnote2).
94 | 5. `aggregate_uai_semiparametric(sub_chain, n)` draws n samples from the product of semiparametric estimates of the subset posteriors[2](#myfootnote2).
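The parametric alternative[2](#myfootnote2) has a closed form. If each subset posterior is approximated by a Gaussian $N(\mu_i, \Sigma_i)$, their product is, up to normalization, Gaussian with $\Sigma = (\sum_i \Sigma_i^{-1})^{-1}$ and $\mu = \Sigma \sum_i \Sigma_i^{-1} \mu_i$. Below is a minimal MATLAB/Octave sketch of that rule. It is not the code in `aggregate_uai_parametric.m`: it assumes each Gaussian is fitted by the sample mean and covariance of a sub-chain, and the toy sub-chains and variable names (`sub_cov`, `Lambda`, `eta`) are illustrative only.

```octave
% Minimal sketch (illustrative, not the repository's implementation):
% combine sub-chains by multiplying Gaussian approximations of the
% subset posteriors, each fitted by sample mean and covariance.

% Toy, deterministic sub-chains: M subsets, p = 2 parameters.
M = 4; p = 2; n_sub = 200;
t = linspace(0, 2*pi, n_sub)';
base = [cos(t), sin(2*t)];
sub_chain = cell(1, M);
for i = 1:M
    sub_chain{i} = 0.5 * i * base + repmat([i, -i], n_sub, 1);
end

% Accumulate the summed precision and the precision-weighted mean.
Lambda = zeros(p);        % sum_i inv(Sigma_i)
eta = zeros(p, 1);        % sum_i inv(Sigma_i) * mu_i
sub_cov = cell(1, M);
for i = 1:M
    mu_i = mean(sub_chain{i}, 1)';
    sub_cov{i} = cov(sub_chain{i});
    Lambda = Lambda + inv(sub_cov{i});
    eta = eta + sub_cov{i} \ mu_i;
end

% The product of Gaussians is Gaussian with precision Lambda.
Sigma = inv(Lambda);
Sigma = (Sigma + Sigma') / 2;    % symmetrize against round-off
mu = Lambda \ eta;

% Draw combined samples from N(mu, Sigma) via a Cholesky factor.
n_draws = 10000;
L = chol(Sigma, 'lower');
combined_samples = repmat(mu', n_draws, 1) + randn(n_draws, p) * L';
```

The combined covariance is never larger than any subset covariance, which matches the intuition that pooling subsets sharpens the posterior. Because the rule is exact only when every subset posterior is Gaussian, it can misrepresent multimodal posteriors such as the one in `examples/test_bimodal.m`.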
95 |
96 | ## Inner Workings
97 |
98 | ### Files
99 |
100 | ```
101 | ├── README.md                           readme
102 | ├── docs
103 | │   └── tree15.pdf                      PART paper
104 | ├── examples/                           QuickStart examples
105 | │   ├── dirichrnd.m                     sample from Dirichlet
106 | │   ├── sample_mog.m                    Gibbs sampling for the means of a mixture of Gaussians
107 | │   ├── test_bimodal.m                  demo - bimodal
108 | │   ├── test_mixture_of_gaussian.m      demo - mixture of Gaussians
109 | │   └── test_rare_binary.m              demo - rare Bernoulli
110 | └── src
111 |     ├── PART                            implementation of PART
112 |     │   ├── MultiStageMCMC.m            pairwise aggregation
113 |     │   ├── NormalizeForest.m           normalize trees in a forest
114 |     │   ├── NormalizeTree.m             flatten and normalize a tree
115 |     │   ├── OneStageMCMC.m              one-stage aggregation
116 |     │   ├── aggregate_PART_onestage.m   one-stage aggregation, used with part_options
117 |     │   ├── aggregate_PART_pairwise.m   pairwise aggregation, used with part_options
118 |     │   ├── buildForest.m               build a forest
119 |     │   ├── buildHistogram.m            generic 1-D histogram building with a supplied cutting function
120 |     │   ├── buildTree.m                 build a partition tree
121 |     │   ├── cleanTree.m                 clean a given forest of random partition trees
122 |     │   ├── copyTree.m                  deep-copy a forest of random partition trees
123 |     │   ├── cut                         various "partition rules"
124 |     │   │   ├── consensusMLCut.m        ML (maximum likelihood)
125 |     │   │   ├── fastMLCut.m             faster ML with stochastic approximation
126 |     │   │   ├── kdCut.m                 KD/median
127 |     │   │   ├── meanCut.m               average
128 |     │   │   └── midPointCut.m           mid-point
129 |     │   ├── empiricalLogLikelihood.m    evaluate the empirical log likelihood of x under cuts
130 |     │   ├── part_options.m              configure options for running PART
131 |     │   ├── treeDensity.m               evaluate the combined density at a given point
132 |     │   ├── treeNode.m                  tree structure
133 |     │   ├── treeNormalize.m             normalizer
134 |     │   └── treeSampling.m              resample from a tree density estimator
135 |     ├── alternatives/                   a few other EP-MCMC algorithms
136 |     ├── init.m                          add paths
137 |     └── utils/                          a few evaluation and visualization functions
138 | ```
139 |
140 | ### Key Functions
141 |
142 | We give a brief introduction to the following key functions in the PART MATLAB implementation.
143 |
144 | #### Low level
145 |
146 | 1. `[obj, nodes, log_probs] = buildTree(...)` builds a single tree by first pooling the data across all subsets to determine the tree structure. The unnormalized probability and the corresponding density of each block are then estimated by block-wise multiplication. `obj` is the root tree node; `nodes` are the leaf nodes, each corresponding to a block in the sample space; `log_probs` are the logarithms of the **unnormalized** probabilities of those leaf nodes, in the same order as `nodes`.
147 | 2. `[tree, sampler, sampler_prob] = NormalizeTree(...)` (1) recomputes or (2) normalizes a tree by summing over its leaf nodes.
148 |
149 | #### Medium level
150 |
151 | 1. `[trees, sampler, sampler_prob] = buildForest(RawMCMC, display_info)` builds a random ensemble of trees by calling `buildTree(...)` in parallel. Each tree is then normalized with `NormalizeTree(...)`.
152 |
153 | #### High level
154 |
155 | 1. `[trees, sampler, sampler_prob, RawMCMC] = OneStageMCMC(MCdraws, option)` is a wrapper of `buildForest(...)`. The resulting ensemble consists of multiple trees, each of which combines the posteriors across subsets by block-wise direct multiplication.
156 | 2. `[trees, sampler, sampler_prob] = MultiStageMCMC(MCdraws, option)` combines subsets pairwise over multiple stages until a single posterior remains. Subsets are first matched into pairs by some criterion, say A with B and C with D, and `OneStageMCMC(...)` is called on each pair, e.g. to combine A and B into (A+B). Samples are then drawn from (A+B) and (C+D) to represent the distributions after the first stage. Next, (A+B) and (C+D) are combined into (A+B+C+D) by again calling `OneStageMCMC(...)` on their samples.
157 | 3. `aggregated_samples = aggregate_PART_pairwise(sub_chain, options)` is a high-level wrapper of `MultiStageMCMC(...)` that draws samples from the pairwise aggregated sub-chains.
158 | 4. `aggregated_samples = aggregate_PART_onestage(sub_chain, options)` is a high-level wrapper of `OneStageMCMC(…)` that draws samples from the one-stage aggregated sub-chains.
159 |
160 | ### References
161 |
162 | 1: Xiangyu Wang, Fangjian Guo, Katherine A. Heller and David B. Dunson. Parallelizing MCMC with random partition trees. NIPS 2015.
163 |
164 | 2: Willie Neiswanger, Chong Wang and Eric P. Xing. Asymptotically Exact, Embarrassingly Parallel MCMC. UAI 2014.
165 |
166 | ## Authors
167 |
168 | [Richard Guo](https://github.com/richardkwo) and [Samuel Wang](https://github.com/wwrechard). We welcome your feedback on the [*issues*](https://github.com/richardkwo/random-tree-parallel-MCMC/issues) page of this repository.
169 |
170 | ## License
171 |
172 | © Contributors, 2015. Released under the [MIT](https://github.com/richardkwo/random-tree-parallel-MCMC/blob/master/LICENSE) license.
--------------------------------------------------------------------------------
/docs/tree15.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/richardkwo/random-tree-parallel-MCMC/321142b1e5ac3bf1d6142f8946f3ac6dea7dc030/docs/tree15.pdf
--------------------------------------------------------------------------------
/examples/dirichrnd.m:
--------------------------------------------------------------------------------
1 | function p = dirichrnd(n, alpha)
2 | % Draws n samples from Dirichlet(alpha) (alpha a row vector); density proportional to prod(x_i^(alpha_i - 1))
3 | m = length(alpha);
4 | internal = gamrnd(repmat(alpha, n, 1), 1); % independent Gamma(alpha_i, 1) draws
5 | p = internal./repmat(sum(internal, 2), 1, m); % normalize each row to sum to 1
6 | end
--------------------------------------------------------------------------------
/examples/sample_mog.m:
--------------------------------------------------------------------------------
1 | function mu = sample_mog(data, Sigmas, Mus, N)
2 | n = size(data, 1);
3 | K = length(Sigmas);
4 | p = size(Mus, 2);
5 | mu = zeros(N, p);
6 | Z = zeros(n, K);
7 | W = zeros(n, K);
8 | for i=1:N
9 |     % sample assignments
10 |     for k=1:K
11 |         W(:,k) = mvnpdf(data, Mus(k,:), Sigmas{k});
12 |     end
13 |     for j=1:n
14 |         Z(j, :) = 0;
15 |         Z(j, randsample(K, 1, true, W(j,:))) = 1;
16 |     end
17 |     % sample component means (average over rows, even if only one point is assigned)
18 |     for k=1:K
19 |         Mus(k,:) = mvnrnd(mean(data(Z(:,k)==1, :), 1), Sigmas{k}/sum(Z(:,k)));
20 |     end
21 |     kk = randsample(K, 1, true, sum(Z));
22 |     mu(i,:) = Mus(kk, :);
23 | end
24 | mu = mu(randperm(N), :);
25 | end
--------------------------------------------------------------------------------
/examples/test_bimodal.m:
--------------------------------------------------------------------------------
1 | addpath ../src/PART ../src/PART/cut ../src/alternatives ../src/utils
2 |
3 | %% Combine densities with two modes
4 | % The data is generated using the following generative model:
5 | %
6 | % $$x \sim \prod_{i = 1}^m p_i(x)$$
7 | %
8 | % where m is the number of subsets, n is the number of observations and
9 | % each $p_i(\cdot)$ is a mixture of two Gaussians,
10 | %
11 | % $$ p_i(x) = w_{i1} N(u_{i1}, s_{i1}^2) + w_{i2} N(u_{i2}, s_{i2}^2)$$
12 |
13 | p = 1; % dim
14 | m = 10; % number of subsets
15 | n = 10000; % total number of observations
16 | u = [0 + normrnd(-5,0.5,m,1), 5 + normrnd(0,0.5,m,1)]; % means
17 | s = [1 + abs(normrnd(0,0.1,m,1)), 4 + abs(normrnd(0,0.1,m,1))]; % standard deviations
18 | w = zeros(m,2); % weights
19 | w(:,1) = 0.8/3;
20 | w(:,2) = 1 - w(:,1);
21 |
22 | % Plot the true density of the product distribution
23 | x = -20:0.001:20;
24 | y = zeros(m, length(x));
25 | for i = 1:length(x)
26 |     y(:,i) = w(:,1).*normpdf(x(i), u(:,1), s(:,1)) + w(:,2).*normpdf(x(i), u(:,2), s(:,2));
27 | end
28 |
29 | y1 = prod(y);
30 | y1 = y1/sum(y1)/0.001; % normalize on the grid (spacing 0.001)
31 |
32 | figure;
33 | ymax = max(y1) + 0.05;
34 | axis_f = [-10 10 0 ymax];
35 | plot(x,y1,'-r');
36 | axis(axis_f)
37 | hold on;
38 | title('The true density for the multiplied distribution')
39 |
40 | N = 5000; % number of samples drawn from each subset
41 | posterior_N = N;
42 | sub_chains = cell(1,m);
43 | sub_density = cell(1,m);
44 | for i = 1:m
45 |     sub_chains{i} = [normrnd(u(i,1), s(i,1), floor(w(i,1)*N), 1);...
46 |         normrnd(u(i,2), s(i,2), N - floor(w(i,1)*N), 1)];
47 |     sub_chains{i} = sub_chains{i}(randperm(size(sub_chains{i},1)),:);
48 |     % subset density
49 |     sub_density{i} = w(i,1) * normpdf(x, u(i,1), s(i,1)) + w(i,2) * normpdf(x, u(i,2), s(i,2));
50 |     plot(x, sub_density{i}, '--');
51 | end
52 |
53 |
54 | %% Combine via PART: one-stage aggregation
55 | options = part_options('min_cut_length', 0.001, 'min_fraction_block', 0.01, 'ntree', 8, 'verbose', 2);
56 | combined_posterior_kd_onestage = aggregate_PART_onestage(sub_chains, options);
57 | options.cut_type = 'ml';
58 | combined_posterior_ml_onestage = aggregate_PART_onestage(sub_chains, options);
59 | options.local_gaussian_smoothing = false;
60 | combined_posterior_ml_onestage_NoVar = aggregate_PART_onestage(sub_chains, options);
61 | options.cut_type = 'kd';
62 | combined_posterior_kd_onestage_NoVar = aggregate_PART_onestage(sub_chains, options);
63 |
64 |
65 |
66 | figure;
67 | plot_marginal_compare({combined_posterior_kd_onestage, ...
68 |     combined_posterior_ml_onestage, ...
69 |     combined_posterior_kd_onestage_NoVar, ...
70 |     combined_posterior_ml_onestage_NoVar}, ...
71 |     { 'KD-onestage', 'ML-onestage','KD-onestage-NoSmoothing','ML-onestage-NoSmoothing'});
72 | subplot(1,1,1);
73 | hold on;
74 | plot(x,y1,'--','DisplayName', 'True', 'LineWidth',2);
75 | legend('-DynamicLegend', 'Location', 'best');
76 | title('PART: one-stage aggregation'); hold off;
77 |
78 |
79 | %% Combine via PART: pairwise aggregation
80 | options = part_options('min_cut_length', 0.001, 'min_fraction_block', 0.01, 'ntree', 8, 'verbose', 2);
81 | combined_posterior_kd_pairwise = aggregate_PART_pairwise(sub_chains, options);
82 | options.cut_type = 'ml';
83 | combined_posterior_ml_pairwise = aggregate_PART_pairwise(sub_chains, options);
84 |
85 | figure;
86 | plot_marginal_compare({combined_posterior_kd_pairwise, ...
87 |     combined_posterior_ml_pairwise}, ...
88 |     { 'KD-pairwise','ML-pairwise'});
89 | subplot(1,1,1);
90 | hold on;
91 | plot(x,y1,'--','DisplayName', 'True', 'LineWidth',2);
92 | legend('-DynamicLegend', 'Location', 'best');
93 | title('PART: pairwise (multi-stage) aggregation'); hold off;
94 |
95 |
96 | %% Other methods for comparison
97 |
98 | % simple averaging
99 | combined_posterior_averaging = aggregate_average(sub_chains);
100 |
101 | % weighted averaging
102 | combined_posterior_weighted_averaging = aggregate_weighted_average(sub_chains);
103 |
104 | % UAI - parametric
105 | combined_posterior_parametric = aggregate_uai_parametric(sub_chains);
106 |
107 | % UAI - nonparametric
108 | combined_posterior_nonparametric = aggregate_uai_nonparametric(sub_chains, 1e4);
109 |
110 | % UAI - semiparametric
111 | combined_posterior_semiparametric = aggregate_uai_semiparametric(sub_chains, 1e4);
112 |
113 | figure;
114 | plot_marginal_compare({combined_posterior_kd_pairwise, combined_posterior_ml_pairwise, ...
115 |     combined_posterior_averaging, ...
116 |     combined_posterior_weighted_averaging, ...
117 |     combined_posterior_parametric, ...
118 |     combined_posterior_nonparametric, ...
119 |     combined_posterior_semiparametric}, ...
120 |     {'PART-KD', 'PART-ML', ...
121 |     'average', 'weighted average', ...
122 |     'Neiswanger - parametric', 'Neiswanger - nonparametric', 'Neiswanger - semiparametric'});
123 | subplot(1,1,1);
124 | hold on;
125 | plot(x,y1,'--','DisplayName', 'True', 'LineWidth',2);
126 | legend('-DynamicLegend', 'Location', 'best');
127 | title('Comparison of various algorithms'); hold off;
128 |
--------------------------------------------------------------------------------
/examples/test_mixture_of_gaussian.m:
--------------------------------------------------------------------------------
1 | addpath ../src/PART ../src/PART/cut ../src/alternatives ../src/utils
2 |
3 | %% generate mixture of mv-normal sample, 2-D
4 | K = 2; % number of components
5 | M = 4; % number of subsets
6 | n = 400;
7 | Sigmas = cell(1,K);
8 | Mus = zeros(K, 2);
9 |
10 | % draw means
11 | mean_mu = [5, 5];
12 | sigma_mu = 10 * eye(2);
13 | Mus = mvnrnd(repmat(mean_mu, K, 1), sigma_mu);
14 |
15 | % use the same cov for now
16 | for k=1:K
17 |     Sigmas{k} = eye(2);
18 | end
19 |
20 | % mixture weights
21 | mixture_weights = dirichrnd(1, 1 * ones(1,K));
22 |
23 | % draw samples
24 | all_data = zeros(n, 2);
25 | for i=1:n
26 |     z = randsample(K, 1, true, mixture_weights);
27 |     all_data(i,:) = mvnrnd(Mus(z,:), Sigmas{z});
28 | end
29 |
30 | % split into subsets of data
31 | subset_data = cell(1,M);
32 | assign_vec = repmat(1:M, 1, n/M);
33 | for c=1:M
34 |     subset_data{c} = all_data(assign_vec==c, :);
35 | end
36 |
37 | rr = floor(sqrt(M+1)); ss = ceil(sqrt(M+1));
38 | for c=0:M
39 |     if c==0
40 |         figure;
41 |         subplot(rr,ss,1);
42 |         plot(all_data(:,1), all_data(:,2), 'k.');
43 |         hold on;
44 |         plot(Mus(:,1)', Mus(:,2)', 'ro');
45 |         title('full data');
46 |     else
47 |         subplot(rr,ss,c+1);
48 |         plot(subset_data{c}(:,1), subset_data{c}(:,2), 'k.');
49 |         hold on;
50 |         plot(Mus(:,1)', Mus(:,2)', 'ro');
51 |         title(sprintf('subset data %d', c));
52 |     end
53 | end
54 |
55 | %% mcmc
56 |
57 | N = 5000;
58 | fprintf('Sampling full chain...\n');
59 | full_chain = sample_mog(all_data, Sigmas, Mus, N);
60 | sub_chain = cell(1, M);
61 | for c=1:M
62 |     fprintf('Sampling subset %d...\n', c);
63 |
64 |     sub_chain{c} = sample_mog(subset_data{c}, Sigmas, Mus, N);
65 | end
66 |
67 | rr = floor(sqrt(M+1)); ss = ceil(sqrt(M+1));
68 | for c=0:M
69 |     if c==0
70 |         figure;
71 |         subplot(rr,ss,1);
72 |         plot(full_chain(:,1), full_chain(:,2), 'k.');
73 |         hold on;
74 |         plot(Mus(:,1)', Mus(:,2)', 'ro');
75 |         title('full chain');
76 |     else
77 |         subplot(rr,ss,c+1);
78 |         plot(sub_chain{c}(:,1), sub_chain{c}(:,2), 'k.');
79 |         hold on;
80 |         plot(Mus(:,1)', Mus(:,2)', 'ro');
81 |         title(sprintf('subchain %d', c));
82 |     end
83 | end
84 |
85 |
86 |
87 | %% averaging, weighted averaging & parametric
88 | combined_posterior_averaging = aggregate_average(sub_chain);
89 | combined_posterior_weighted_averaging = aggregate_weighted_average(sub_chain);
90 | combined_posterior_parametric = aggregate_uai_parametric(sub_chain);
91 |
92 | figure;
93 | plot(combined_posterior_averaging(:,1), combined_posterior_averaging(:,2), 'rx'); hold on;
94 | plot(combined_posterior_weighted_averaging(:,1), combined_posterior_weighted_averaging(:,2), 'bo');
95 | plot(combined_posterior_parametric(:,1), combined_posterior_parametric(:,2), 'gs');
96 | plot(full_chain(:,1), full_chain(:,2), 'k.');
97 | legend('average', 'weighted average', 'parametric', 'True');
98 |
99 | %% KD/ML aggregations
100 |
101 | options = part_options('min_cut_length', 0.01, 'min_fraction_block', 0.01, 'resample_N', 5000, 'ntree', 16);
102 | combined_posterior_kd_pairwise = aggregate_PART_pairwise(sub_chain, options);
103 | options.cut_type = 'ml';
104 | [combined_posterior_ml_pairwise, sampler, ~] = aggregate_PART_pairwise(sub_chain, options);
105 |
106 | figure;
107 | plot(combined_posterior_kd_pairwise(:,1), combined_posterior_kd_pairwise(:,2), 'rx'); hold on;
108 | plot(combined_posterior_ml_pairwise(:,1), combined_posterior_ml_pairwise(:,2), 'bo');
109 | plot(full_chain(:,1), full_chain(:,2), 'k.');
110 | legend('PART-KD', 'PART-ML', 'True', 'Location', 'best');
111 |
112 | figure;
113 | for i=1:16
114 |     subplot(4,4,i); axis off;
115 |     plot_tree_blocks(sampler{i}, []);
116 |     title(['Tree ', num2str(i)]);
117 | end
118 |
119 | %% other combinations
120 | N = 1e5; ii = randsample(N, 1000);
121 | % UAI - nonparametric
122 | combined_posterior_nonparametric = aggregate_uai_nonparametric(sub_chain, N);
123 | % UAI - semiparametric
124 | combined_posterior_semiparametric = aggregate_uai_semiparametric(sub_chain, N);
125 |
126 | figure;
127 | plot(combined_posterior_nonparametric(ii,1), combined_posterior_nonparametric(ii,2), 'rx'); hold on;
128 | plot(combined_posterior_semiparametric(ii,1), combined_posterior_semiparametric(ii,2), 'bo');
129 | plot(full_chain(:,1), full_chain(:,2), 'k.');
130 | legend('nonparametric', 'semiparametric', 'True');
131 |
132 | %% summary
133 | performance_table({full_chain, ...
134 |     combined_posterior_kd_pairwise, ...
135 |     combined_posterior_ml_pairwise, ...
136 |     combined_posterior_averaging, ...
137 |     combined_posterior_weighted_averaging, ...
138 |     combined_posterior_parametric, ...
139 |     combined_posterior_nonparametric, ...
140 |     combined_posterior_semiparametric}, ...
141 |     {'true', ...
142 |     'KD', 'ML', ...
143 |     'average', 'weighted average', ...
144 |     'Neiswanger - parametric', 'Neiswanger - nonparametric', 'Neiswanger - semiparametric'}, ...
145 |     full_chain)
146 | fprintf('NOTE: this table may not be a good evaluation since the posterior is non-Gaussian.\n');
147 |
--------------------------------------------------------------------------------
/examples/test_rare_binary.m:
--------------------------------------------------------------------------------
1 | addpath ../src/PART ../src/PART/cut ../src/alternatives ../src/utils
2 |
3 | %% Binomial data with rare event
4 | % The data is generated using the following generative model:
5 | %
6 | % $$x \sim bin(p), \quad p = 2M/n$$
7 | %
8 | % where M is the number of subsets and n is the number of observations.
9 |
10 | p = 1; % 1 dim
11 | cat = 2; % number of categories (binomial)
12 | M = 15; % number of subsets
13 | n = 10000; % number of observations
14 | theta = zeros(1, cat); % probability
15 | theta(1) = 2*M/n;
16 | theta(2) = 1 - theta(1);
17 |
18 | % beta prior on p
19 | pseudocounts = [2 2];
20 |
21 | N = 10000; % number of posterior draws
22 | X = mnrnd(1, theta, n);
23 |
24 | % We then partition the data set into M subsets (round-robin)
25 | sub_X = cell(1,M);
26 | for m = 1:M
27 |     sub_X{m} = X(m:M:end, :);
28 | end
29 | subset_totals = cellfun(@(x) sum(x(:,1)), sub_X)
30 |
31 | % We draw N posterior samples for the full data
32 | full_chain = dirichrnd(N, sum(X, 1) + pseudocounts);
33 | full_chain = full_chain(:,1);
34 |
35 | % We draw N posterior samples for each subset posterior
36 | sub_chain = cell(1,M);
37 | for m = 1:M
38 |     sub_chain{m} = dirichrnd(N, sum(sub_X{m}, 1) + (pseudocounts/M + (M-1)/M));
39 |     sub_chain{m} = sub_chain{m}(:,1);
40 | end
41 |
42 |
43 | do_posterior_comparison = true;
44 |
45 | if do_posterior_comparison
46 |     fprintf('Plotting full chain and sub-chains...\n');
47 |     figure;
48 |     hold on;
49 |     [f,xi] = ksdensity(full_chain);
50 |     plot(xi,f,'DisplayName','full');
51 |     for l=1:M
52 |         [f,xi] = ksdensity(sub_chain{l});
53 |         plot(xi,f,'DisplayName',['subset ',num2str(l)]);
54 |     end
55 |     % legend('-DynamicLegend', 'Location', 'best');
56 |     % title('theta');
57 |     hold off;
58 |
59 | end
60 |
61 | %% parametric
62 | combined_posterior_averaging = aggregate_average(sub_chain);
63 | combined_posterior_weighted_averaging = aggregate_weighted_average(sub_chain);
64 | combined_posterior_parametric = aggregate_uai_parametric(sub_chain);
65 | figure;
66 | plot_marginal_compare({full_chain, combined_posterior_parametric, combined_posterior_averaging, combined_posterior_weighted_averaging}, ...
67 |     {'true', 'parametric','average','weighted'});
68 |
69 | %% KD/ML aggregations
70 |
71 | options = part_options('min_cut_length', 0.01, 'resample_N', 10000);
72 | combined_posterior_kd_pairwise = aggregate_PART_pairwise(sub_chain, options);
73 | options.cut_type = 'ml';
74 | combined_posterior_ml_pairwise = aggregate_PART_pairwise(sub_chain, options);
75 |
76 | figure;
77 | plot_marginal_compare({full_chain, ...
78 |     combined_posterior_kd_pairwise, combined_posterior_ml_pairwise}, ...
79 |     {'true', 'kd-pairwise', 'ml-pairwise'});
80 | xlim([-2 14]*1e-3);
81 |
82 |
83 | %% other combinations
84 | % UAI - nonparametric
85 | combined_posterior_nonparametric = aggregate_uai_nonparametric(sub_chain, 1e4);
86 |
87 | % UAI - semiparametric
88 | combined_posterior_semiparametric = aggregate_uai_semiparametric(sub_chain, 1e4);
89 |
90 | %% plotting
91 | % marginal density plot
92 | figure;
93 | plot_marginal_compare({full_chain, sub_chain{1}, ...
94 |     combined_posterior_kd_pairwise, ...
95 |     combined_posterior_ml_pairwise, ...
96 |     combined_posterior_averaging, ...
97 |     combined_posterior_weighted_averaging, ...
98 |     combined_posterior_parametric, ...
99 |     combined_posterior_nonparametric, ...
100 |     combined_posterior_semiparametric}, ...
101 |     {'true', 'subset 1', ...
102 |     'KD', 'ML', ...
103 |     'average', 'weighted average', ...
104 |     'Neiswanger - parametric', 'Neiswanger - nonparametric', 'Neiswanger - semiparametric'});
105 | xlim([0 15]*1e-3);
106 |
107 | % comparison of accuracy
108 | performance_table({full_chain, ...
109 |     combined_posterior_kd_pairwise, ...
110 |     combined_posterior_ml_pairwise, ...
111 |     combined_posterior_averaging, ...
112 |     combined_posterior_weighted_averaging, ...
113 |     combined_posterior_parametric, ...
114 |     combined_posterior_nonparametric, ...
115 |     combined_posterior_semiparametric}, ...
116 |     {'true', ...
117 |     'KD', 'ML', ...
118 |     'average', 'weighted average', ...
119 | 'Neiswanger - parametric', 'Neiswanger - nonparametric', 'Neiswanger - semiparametric'}, ... 120 | full_chain, theta(1)) -------------------------------------------------------------------------------- /src/PART/MultiStageMCMC.m: -------------------------------------------------------------------------------- 1 | function [trees, sampler, sampler_prob] = MultiStageMCMC(MCdraws, option) 2 | % MultiStageMCMC is a function that takes an input of raw subset posteriors 3 | % and combine the posterior densities via a multistage manner, i.e., a 4 | % pairwise combining until a final one density is reached. The function 5 | % calls OneStageMCMC for combining each pair of posterior samples 6 | % 7 | % Calls: 8 | % [trees, sampler, sampler_prob] = MultiStageMCMC(MCdraws, option) 9 | % 10 | % Arguments: 11 | % 12 | % MCdraws: a 1 x m cell containing subset posteriors from m subsets 13 | % option: option configured by part_options(...). Use `option = part_options()` for default settings. 14 | % 15 | % Outputs: 16 | % 17 | % trees: A list containing multiple partition trees, each corresponding to a density estimation of the combined posterior 18 | % sampler: A list containing multiple flattened partition trees 19 | % sampler_prob: A list containing the probabilities for each flattened partition tree 20 | % 21 | % See also: 22 | % part_options, buildForest, OneStageMCMC 23 | 24 | 25 | %Setting the default value for parameters 26 | if ~isfield(option,'area'), area = []; else area = option.area;end 27 | min_fraction_block = option.min_fraction_block; 28 | if option.halving 29 | fprintf('\nPairwise aggregation with min fraction = %f, halving enabled. \n', min_fraction_block); 30 | else 31 | fprintf('\nPairwise aggregation with min fraction = %f, halving disabled. 
\n', min_fraction_block); 32 | end 33 | 34 | %Initialization 35 | m = length(MCdraws); %number of subsets 36 | d = size(MCdraws{1}, 2); %number of dimensions 37 | N = zeros(1, m); % number of posterior samples in each subset 38 | 39 | tic; 40 | %To avoid duplicate computation, we compute the default area once, at the top level 41 | if isempty(area) 42 | area = zeros(d, 2); 43 | area(:,1) = min(MCdraws{1}); 44 | area(:,2) = max(MCdraws{1}); 45 | for i = 1:m 46 | N(i) = size(MCdraws{i},1); %update posterior sample sizes 47 | area(:,1) = min([area(:,1)';MCdraws{i}]); %update boundaries 48 | area(:,2) = max([area(:,2)';MCdraws{i}]); 49 | end 50 | 51 | %enlarge the area by a factor of 1.01 52 | for j = 1:d 53 | if area(j,1)>0 54 | area(j,1) = area(j,1)/1.01; 55 | else 56 | area(j,1) = area(j,1)*1.01; 57 | end 58 | if area(j,2)>0 59 | area(j,2) = area(j,2)*1.01; 60 | else 61 | area(j,2) = area(j,2)/1.01; 62 | end 63 | end 64 | else 65 | for i = 1:m 66 | N(i) = size(MCdraws{i},1); 67 | end 68 | end 69 | 70 | %inner_option is used by OneStageMCMC for combining paired subsets. We 71 | %suppress its progress output 72 | inner_option = option; 73 | inner_option.verbose = 0; 74 | number_of_stages = floor(log(m)/log(2)); 75 | if option.halving 76 | current_min_fraction_block = min_fraction_block * (2^number_of_stages); 77 | else 78 | current_min_fraction_block = min_fraction_block; 79 | end 80 | inner_option.min_number_points_block = max(ceil(min(N) * current_min_fraction_block), 3); 81 | 82 | %Now we start combining posterior densities in a pairwise manner 83 | 84 | %We use multiple trees at each stage 85 | 86 | MC_left = m; %Number of remaining subsets 87 | Internal_list = MCdraws; %Remaining subset posteriors. 
88 | 89 | %setting up waitbar 90 | if option.verbose>1 91 | h = waitbar(0,'Start building multistage tree ensemble...'); 92 | end 93 | 94 | %Initialize stage number (for waitbar use) 95 | stage = 1; 96 | 97 | %Starting pairwise combining!
98 | while MC_left > 3 99 | %update waitbar information 100 | if option.verbose>1 101 | waitbar(0.01, h, ['Building multistage tree ensemble for stage ', num2str(stage), ' ...']); 102 | end 103 | 104 | MC_match = cell(1,floor(MC_left/2)); %Storing all pairs 105 | 106 | %Initialize pairs as 1-2, 3-4, 5-6, ... 107 | for MC = 1:floor(MC_left/2) 108 | MC_match{MC} = [2*MC-1 2*MC]; 109 | end 110 | 111 | %If the number of subsets is odd, we add the last subset to the 112 | %previous pair, e.g., 1-2, 3-4, 5-6-7 113 | if MC_left > floor(MC_left/2)*2 114 | MC_match{floor(MC_left/2)} = [MC_match{floor(MC_left/2)} MC_left]; 115 | end 116 | 117 | %Optionally, pair subsets by the distances between their means 118 | if option.match 119 | %updating waitbar 120 | if option.verbose>1 121 | waitbar(0.01,h,'Starting matching subsets...'); 122 | end 123 | 124 | %Compute means for all subset posteriors 125 | MC_means = zeros(d,MC_left); 126 | for MC = 1:MC_left 127 | MC_means(:,MC) = mean(Internal_list{MC})'; 128 | end 129 | 130 | %match each subset with the remaining subset at the median distance 131 | MC_used = []; %store subsets that have been used 132 | MC_left_set = 1:MC_left; %initialize the subsets that haven't been paired up 133 | pair = 1; %current number of pairs + 1 134 | for pair_first = 1:MC_left 135 | %update waitbar 136 | if option.verbose>1 137 | waitbar(1-length(MC_used)/MC_left,h,'Matching subsets...'); 138 | end 139 | 140 | if ~ismember(pair_first, MC_used) %check if the subset has been used 141 | %compute distances 142 | dis = MC_means(:,MC_left_set) -...
143 | repmat(MC_means(:,pair_first),1,length(MC_left_set)); 144 | dis = sum(dis.^2,1); 145 | 146 | %sort distances 147 | [~,ix] = sort(dis); 148 | 149 | %Pair with the subset at the median distance. MC_left_set still contains pair_first itself (distance 0, sorted first), so position 1 is never taken; otherwise a subset would be paired with itself when only two remain 150 | pair_second = MC_left_set(ix(max(ceil(length(MC_left_set)/2), 2))); 151 | 152 | %update the used subsets and remaining subsets 153 | MC_used = [MC_used, pair_first, pair_second]; 154 | MC_left_set = setdiff(1:MC_left, MC_used); 155 | 156 | %saving the pair into MC_match 157 | MC_match{pair} = [pair_first pair_second]; 158 | 159 | %Stop after floor(MC_left/2) pairs; if MC_left is odd, append the leftover subset to the last pair 160 | if pair == floor(MC_left/2) && MC_left > 2*pair 161 | last = setdiff(1:MC_left, MC_used); 162 | MC_match{pair} = [MC_match{pair}, last]; 163 | break; 164 | elseif pair == floor(MC_left/2) 165 | break; 166 | end 167 | pair = pair+1; 168 | end 169 | end 170 | end 171 | 172 | MC_left = floor(MC_left/2); %Number of subsets left after the current round 173 | list_temp = cell(1, MC_left); %Storing the combined posterior samples for each pair 174 | 175 | %Start merging 176 | for MC_pair = 1:MC_left 177 | %setting up waitbar 178 | if option.verbose>1 179 | waitbar(MC_pair/MC_left,h,['Pairwise combining subsets for stage ', num2str(stage), '...']); 180 | end 181 | 182 | %Call OneStageMCMC to combine the pair of subsets 183 | [~, sampler, prob, RawMCMC] = OneStageMCMC(Internal_list(MC_match{MC_pair}), inner_option); 184 | if option.resample_data_only 185 | inner_option.mark = RawMCMC.mark; 186 | inner_option.list = RawMCMC.list; 187 | else 188 | inner_option.list = []; 189 | inner_option.mark = []; 190 | end 191 | 192 | %Resample from the combined density 193 | inner_N = max([max(N), option.resample_N]); 194 | list_temp{MC_pair} = treeSampling(sampler, prob, inner_N, inner_option); 195 | end 196 | 197 | Internal_list = list_temp; %update the posterior samples for the remaining subsets 198 | stage = stage + 1; %moving forward!
199 | if option.halving 200 | % with halving enabled, the minimum block size is halved after each stage 201 | current_min_fraction_block = current_min_fraction_block / 2; 202 | end 203 | inner_option.min_number_points_block = max(ceil(current_min_fraction_block * inner_N), 3); 204 | end 205 | 206 | %close waitbar 207 | if option.verbose>1 208 | close(h); 209 | end 210 | 211 | %For the final stage, we call OneStageMCMC one last time to combine the 212 | %remaining two or three subsets. 213 | option.min_number_points_block = inner_option.min_number_points_block; 214 | option.min_cut_length = inner_option.min_cut_length; 215 | [trees, sampler, sampler_prob] = OneStageMCMC(Internal_list,option); 216 | 217 | fprintf('Pairwise aggregation: finished in %f seconds \n', toc); 218 | 219 | end -------------------------------------------------------------------------------- /src/PART/NormalizeForest.m: -------------------------------------------------------------------------------- 1 | function [newForest, sampler, sampler_prob] = NormalizeForest(Forest, MCdraws) 2 | %NormalizeForest is a function to process a forest of unprocessed 3 | %random partition trees. It is needed when a tree structure is given 4 | %and subset posteriors are inserted into that structure: the forest must then 5 | %be processed to obtain the probability, density, node.mean and node.var of each leaf 6 | %and to output the corresponding partition trees and samplers. 7 | %NormalizeForest calls NormalizeTree to normalize each 8 | %tree and produce its sampler. 9 | % 10 | %Calls: 11 | %[newForest, sampler, sampler_prob] = NormalizeForest(Forest, MCdraws) 12 | % 13 | %Arguments: 14 | % 15 | %Forest: A forest of unprocessed partition trees.
Typically produced by 16 | % InsertNode.m 17 | %MCdraws: A list of raw MCMC samples from subsets 18 | % 19 | %Outputs: 20 | % 21 | %newForest: Normalized Forest that can be used by treeDensity 22 | %sampler: Flattened normalized forest that can be used by treeSampling 23 | %sampler_prob: corresponding log(probability) of sampler 24 | % 25 | %See also: 26 | %NormalizeTree 27 | % 28 | 29 | ntree = length(Forest); %Number of trees 30 | 31 | %Initialization 32 | newForest = cell(1,ntree); 33 | sampler = cell(1,ntree); 34 | sampler_prob = cell(1,ntree); 35 | 36 | %Setting up waitbar 37 | h = waitbar(0, 'Starting normalizing forest...'); 38 | 39 | %Call NormalizeTree on each tree 40 | for i = 1:ntree 41 | waitbar((i-1)/ntree, h, ['Evaluating tree ', num2str(i), ' of ', num2str(ntree), '...']); 42 | [a, b, p] = NormalizeTree(Forest{i}, MCdraws); 43 | newForest{i} = a; 44 | sampler{i} = b; 45 | sampler_prob{i} = p; 46 | end 47 | close(h); 48 | end -------------------------------------------------------------------------------- /src/PART/NormalizeTree.m: -------------------------------------------------------------------------------- 1 | function [tree, sampler, sampler_prob] = NormalizeTree(tree, MCdraws, log_total_sum) 2 | %NormalizeTree is a function that takes an unnormalized tree 3 | %and returns a normalized tree. When NormalizeTree is called with 4 | %only two arguments, it first traverses the whole tree, computes 5 | %treeNode.log_prob (leaf probability) and treeNode.log_density (leaf density) 6 | %for each leaf from the information in treeNode.point, and updates treeNode.mean 7 | %and treeNode.var using MCdraws.
It simultaneously adds up 8 | %the total (unnormalized) probability, uses it to 9 | %normalize the tree, and returns the normalized tree, the normalized 10 | %flattened tree (sampler) and the corresponding log-probabilities (sampler_prob) 11 | % 12 | %When NormalizeTree is called with three arguments, it only traverses the 13 | %tree and normalizes its leaves by log_total_sum, skipping the leaf computations; 14 | %sampler and sampler_prob are returned empty 15 | % 16 | %It is called by NormalizeForest to handle forest input 17 | % 18 | %Call: 19 | %tree = NormalizeTree(tree, [], log_total_sum) 20 | %[tree, sampler, sampler_prob] = NormalizeTree(tree, MCdraws) 21 | % 22 | %Arguments: 23 | % 24 | %tree: an output of buildTree, or an element of the trees cell output by 25 | % buildForest, OneStageMCMC or MultiStageMCMC, or any copied structure. 26 | %MCdraws: a cell list of posterior samples, 27 | % one entry per subset 28 | %log_total_sum: the logarithm of the sum of exp(leaf.log_prob) over all leaves 29 | % 30 | %Outputs: 31 | % 32 | %tree: Normalized random partition tree for density 33 | % evaluation.
34 | %sampler: Normalized flattened tree for resampling (only returned when 35 | % called with two inputs) 36 | %sampler_prob: Corresponding log-probabilities for sampler (only returned when 37 | % called with two inputs) 38 | % 39 | %See also: 40 | %buildTree, NormalizeForest, buildForest 41 | % 42 | 43 | if nargin == 2 44 | %Two inputs: compute all components of the leaves, then 45 | %normalize the tree and output the sampler 46 | 47 | total_set = length(MCdraws);%total subsets 48 | 49 | %Initialize sampler, sampler_prob, dimension and subset posterior sizes 50 | sampler = {}; 51 | sampler_prob = []; 52 | d = size(MCdraws{1},2); 53 | % size of each subset 54 | N = zeros(1,total_set); 55 | for i = 1:total_set 56 | N(i) = size(MCdraws{i},1); 57 | end 58 | 59 | %Initialize the total probability 60 | %total_prob = 0; 61 | 62 | %We use a stack to traverse the tree non-recursively 63 | stack = tree; 64 | 65 | while ~isempty(stack) 66 | %while the stack is not empty, pop the last element 67 | current_node = stack(end); 68 | stack = stack(1:end-1); %pop the last one 69 | 70 | if isempty(current_node.left) && isempty(current_node.right) 71 | %If this node is a leaf, do the calculation... 72 | 73 | l = current_node.area(:,2) - current_node.area(:,1);%side lengths of the block 74 | 75 | %If multiple subsets 76 | if total_set > 1 77 | %count numbers of points in different subsets.
78 | if ~isempty(current_node.point) 79 | c = histc(current_node.point(:,2), 1:total_set); % treeNode.point(k,:) = [i, j]: the i-th sample of subset j 80 | else 81 | c = zeros(1, total_set); 82 | end 83 | counts = c + 0.01;%bound away from 0 84 | 85 | %compute the probability (product of the subset densities c_j/(N_j*|A|), times the block volume |A|) and the density 86 | current_node.log_prob = sum(log(counts)) - sum(log(N)) - (total_set - 1)*sum(log(l)); %leave normalizing to next step in logarithm 87 | current_node.log_density = current_node.log_prob - sum(log(l)); %divide by the area in logarithm 88 | 89 | %If only one subset 90 | else 91 | n = size(current_node.point,1); 92 | current_node.log_prob = log(n) - log(N); 93 | current_node.log_density = current_node.log_prob - sum(log(l)); 94 | end 95 | 96 | %Adding up the unnormalized probability 97 | %total_prob = total_prob + exp(current_node.prob); 98 | 99 | %Next, compute node.mean and node.var 100 | if total_set>1 101 | %With multiple subsets: we have just counted the number of 102 | %points from each subset in "c", so we can use these counts 103 | %to pick the original data points from 104 | %MCdraws.
105 | 106 | num_point = size(current_node.point,1); %total number of points 107 | list = zeros(num_point,d); %temporary list to cache these points 108 | cached = 0; %number of points cached already 109 | 110 | %looping over all subsets 111 | for cached_set = 1:total_set 112 | if c(cached_set)>0 113 | %If there are points from cached_set, take their within-subset 114 | %indices from node.point(:,1) and fetch the original draws (this assumes the 115 | %rows of node.point are grouped by subset, in increasing subset order) 116 | index = (cached+1): (cached + c(cached_set)); 117 | list(index,:) = MCdraws{cached_set}(current_node.point(index,1),:); 118 | cached = cached + c(cached_set); 119 | end 120 | end 121 | 122 | %Compute mean and variance 123 | current_node.mean = mean(list); 124 | current_node.var = var(list)/total_set; 125 | else 126 | %when there is only one subset 127 | current_node.mean = mean(MCdraws{1}(current_node.point(:,1),:)); 128 | current_node.var = var(MCdraws{1}(current_node.point(:,1),:)); 129 | end 130 | 131 | %augmenting the sampler and sampler probability 132 | sampler = [sampler, {current_node}]; 133 | sampler_prob = [sampler_prob, current_node.log_prob]; 134 | 135 | %If the current node is not a leaf, we push its left and right 136 | %children onto the stack. 137 | else 138 | stack = [stack, current_node.left, current_node.right]; 139 | end 140 | end 141 | 142 | %Normalize sampler_prob by the total probability; the max is subtracted first (log-sum-exp) for numerical stability 143 | max_prob = max(sampler_prob); 144 | 145 | %Now we call NormalizeTree again, with three inputs, to get the tree 146 | %normalized by log(total_prob) 147 | tree = NormalizeTree(tree, [], max_prob + log(sum(exp(sampler_prob - max_prob)))); 148 | sampler_prob = sampler_prob - max_prob - log(sum(exp(sampler_prob - max_prob))); 149 | 150 | else 151 | %When the function is called with three inputs, we just use 152 | %log_total_sum to normalize it.
153 | 154 | %Traverse the tree iteratively with a stack (this relies on treeNode being a handle class, so that updating stack(end) updates the node inside tree) 155 | stack = tree; 156 | 157 | %Return empty sampler outputs in this mode 158 | sampler = {}; 159 | sampler_prob = []; 160 | 161 | %Traverse the tree and normalize node.log_prob and node.log_density of every leaf 162 | while ~isempty(stack) 163 | if isempty(stack(end).left) && isempty(stack(end).right) 164 | % this is a leaf node 165 | stack(end).log_prob = stack(end).log_prob - log_total_sum; 166 | % 1-1 mapped: normalized density = normalized prob / (area of that block) 167 | stack(end).log_density = stack(end).log_density - log_total_sum; 168 | stack = stack(1:end-1); %pop the last one 169 | else 170 | stack = [stack(1:end-1), stack(end).left, stack(end).right]; 171 | end 172 | end 173 | end 174 | end -------------------------------------------------------------------------------- /src/PART/OneStageMCMC.m: -------------------------------------------------------------------------------- 1 | function [trees, sampler, sampler_prob, RawMCMC] = OneStageMCMC(MCdraws, option) 2 | % OneStageMCMC is a function that takes an input of raw subset posteriors 3 | % and packages them into a RawMCMC structure. It then calls buildForest 4 | % to produce the combined density in a one-stage manner 5 | % 6 | % Calls: 7 | % [trees, sampler, sampler_prob, RawMCMC] = OneStageMCMC(MCdraws, option) 8 | % 9 | % Arguments: 10 | % 11 | % MCdraws: a 1 x m cell containing subset posteriors from m subsets 12 | % option: option configured by part_options(...). Use `option = part_options()` for default settings. 13 | % 14 | % Outputs: 15 | % 16 | % trees: A list containing multiple partition trees, each corresponding 17 | % to a density estimation of the combined posterior.
Used by 18 | % treeDensity to evaluate density at a given point 19 | % sampler: A list containing multiple flattened partition trees, used by 20 | % treeSampling to resample from the forest 21 | % sampler_prob: A list containing the (log) probabilities for each flattened 22 | % partition tree, used by treeSampling 23 | % 24 | % See also: 25 | % buildForest, MultiStageMCMC, NormalizeTree, treeSampling, treeDensity 26 | % 27 | 28 | % setting up the default values for parameters. 29 | 30 | if ~isfield(option,'area'), area = [];else area = option.area;end 31 | if ~isfield(option, 'rule') 32 | switch option.cut_type 33 | case 'kd' 34 | option.rule = @kdCut; 35 | case 'ml' 36 | option.rule = @consensusMLCut; 37 | end 38 | end 39 | m = length(MCdraws); %number of subsets 40 | d = size(MCdraws{1}, 2); %number of dimensions 41 | N = zeros(1, m); %Store number of posterior samples for different subsets 42 | 43 | %Initialization 44 | list = []; %To concatenate MCMC samples 45 | mark = []; %To index different points 46 | 47 | for i = 1:m 48 | N(i) = size(MCdraws{i},1); %Obtain the subset posterior size 49 | list = [list;MCdraws{i}]; %Concatenate the posterior draws 50 | mark = [mark;[1:N(i);ones(1,N(i))*i]']; %augmenting the index list 51 | end 52 | 53 | if ~isfield(option, 'min_number_points_block') 54 | RawMCMC.min_number_points_block = ceil(max(N) * option.min_fraction_block); 55 | else 56 | RawMCMC.min_number_points_block = option.min_number_points_block; 57 | end 58 | 59 | 60 | %Packaging all information into RawMCMC 61 | RawMCMC.list = list; 62 | RawMCMC.area = area; 63 | RawMCMC.N = N; 64 | RawMCMC.para = 1:d; 65 | RawMCMC.mark = mark; 66 | RawMCMC.total_set = m; 67 | RawMCMC.ntree = option.ntree; 68 | RawMCMC.rule = option.rule; 69 | RawMCMC.parallel = option.parallel; 70 | RawMCMC.verbose = option.verbose; 71 | RawMCMC.min_cut_length = option.min_cut_length; 72 | 73 | %Call buildForest to produce the combined density 74 | fprintf('Building tree ensemble: minimum %d
points per block, minimum side length = %f\n', RawMCMC.min_number_points_block, RawMCMC.min_cut_length); 75 | [trees, sampler, sampler_prob] = buildForest(RawMCMC, option.verbose); 76 | block_numbers = zeros(1, option.ntree); 77 | for t=1:option.ntree 78 | block_numbers(t) = length(sampler{t}); 79 | end 80 | fprintf('\nRandom tree ensemble constructed -- %d trees and %f leaves per tree\n', option.ntree, mean(block_numbers)); 81 | end -------------------------------------------------------------------------------- /src/PART/aggregate_PART_onestage.m: -------------------------------------------------------------------------------- 1 | function [aggregated_samples, sampler_onestage, prob_onestage] = aggregate_PART_onestage(sub_chains, options) 2 | % Aggregating sub-chain samples with PART by one-stage combination 3 | % 4 | % Call: 5 | % [aggregated_samples, sampler_onestage, prob_onestage] = aggregate_PART_onestage(sub_chains, options) 6 | % 7 | % Arguments: 8 | % sub_chains: a 1 x m cell of m sub-chains. 9 | % sub_chains{i} should be an N_i x p matrix of N_i p-dimensional MCMC samples. 10 | % options: options configured by part_options(...). Use options = part_options() for default settings. 11 | % 12 | % Output: 13 | % aggregated_samples: aggregated posterior samples; max(size(sub_chains{1},1), options.resample_N) samples are drawn. 14 | % sampler_onestage: flattened partition trees, used by treeSampling(...) and treeDensity(...). 15 | % prob_onestage: normalized leaf log-probabilities corresponding to sampler_onestage, used by treeSampling(...) and treeDensity(...).
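% Example (an illustrative sketch only; the four Gaussian toy sub-chains are
% hypothetical stand-ins for real subset MCMC output):
%   sub_chains = cell(1, 4);
%   for i = 1:4
%       sub_chains{i} = 0.1*i + randn(2000, 2);   % hypothetical 2000 x 2 draws
%   end
%   options = part_options('min_cut_length', 0.01, 'resample_N', 10000);
%   options.cut_type = 'ml';   % 'ml' selects consensusMLCut, 'kd' selects kdCut
%   [samples, sampler, prob] = aggregate_PART_onestage(sub_chains, options);
%   % here N = max(2000, 10000) = 10000 draws are requested from treeSampling
% All m sub-chains enter a single tree ensemble; see aggregate_PART_pairwise
% for the stage-by-stage pairwise alternative.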
16 | 17 | N = max([size(sub_chains{1}, 1), options.resample_N]); 18 | M = length(sub_chains); 19 | 20 | options.UseData = false; 21 | options.list = []; 22 | options.mark = []; 23 | 24 | if strcmp(options.cut_type, 'ml') 25 | options.rule = @consensusMLCut; 26 | elseif strcmp(options.cut_type, 'kd') 27 | options.rule = @kdCut; 28 | end 29 | 30 | [~, sampler_onestage, prob_onestage] = OneStageMCMC(sub_chains, options); 31 | aggregated_samples = treeSampling(sampler_onestage, prob_onestage, N, options); 32 | 33 | end 34 | 35 | -------------------------------------------------------------------------------- /src/PART/aggregate_PART_pairwise.m: -------------------------------------------------------------------------------- 1 | function [aggregated_samples, sampler_pairwise, prob_pairwise] = aggregate_PART_pairwise(sub_chains, options) 2 | % Aggregating sub-chain samples with PART by pairwise combination 3 | % 4 | % Call: 5 | % [aggregated_samples, sampler_pairwise, prob_pairwise] = aggregate_PART_pairwise(sub_chains, options) 6 | % 7 | % Arguments: 8 | % sub_chains: a 1 x m cell of m sub-chains. 9 | % sub_chains{i} should be an N_i x p matrix of N_i p-dimensional MCMC samples. 10 | % options: options configured by part_options(...). Use options = part_options() for default settings. 11 | % 12 | % Output: 13 | % aggregated_samples: aggregated posterior samples; max(size(sub_chains{1},1), options.resample_N) samples are drawn.
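% Example (a sketch following the pattern of the test scripts; sub_chain is
% assumed to be a 1 x m cell of N_i x p MCMC draws as described above):
%   options = part_options('min_cut_length', 0.01, 'resample_N', 10000);
%   samples_kd = aggregate_PART_pairwise(sub_chain, options);   % cut_type left at its default ('kd' in the test script labels)
%   options.cut_type = 'ml';
%   samples_ml = aggregate_PART_pairwise(sub_chain, options);
% MultiStageMCMC combines the sub-chains two at a time, stage by stage,
% until two or three remain, then merges those with a final OneStageMCMC call.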
14 | 15 | 16 | N = max([size(sub_chains{1}, 1), options.resample_N]); 17 | 18 | options.UseData = false; 19 | options.list = []; 20 | options.mark = []; 21 | 22 | if strcmp(options.cut_type, 'ml') 23 | options.rule = @consensusMLCut; 24 | elseif strcmp(options.cut_type, 'kd') 25 | options.rule = @kdCut; 26 | end 27 | 28 | [~, sampler_pairwise, prob_pairwise] = MultiStageMCMC(sub_chains, options); 29 | aggregated_samples = treeSampling(sampler_pairwise, prob_pairwise, N, options); 30 | 31 | end 32 | 33 | -------------------------------------------------------------------------------- /src/PART/buildForest.m: -------------------------------------------------------------------------------- 1 | function [trees, sampler, sampler_prob] = buildForest(RawMCMC, verbose) 2 | %buildForest is a function that takes as input one or multiple subset posteriors 3 | %and produces the combined posterior density using an ENSEMBLE of random partition trees. 4 | %The function calls buildTree multiple times to build a number (specified by the user) 5 | %of iid partition trees. This function can be used alone or 6 | %called through OneStageMCMC and MultiStageMCMC for combining with different 7 | %strategies 8 | % 9 | %Calls: 10 | %[trees, sampler, sampler_prob] = buildForest(RawMCMC, verbose) 11 | % 12 | %Arguments: 13 | % 14 | %RawMCMC is a data structure containing: 15 | % list: a list of MCMC samples from different subsets 16 | % area: the total area for cutting, a d x 2 matrix of [min, max] boundaries per dimension (computed from the data by buildTree if empty) 17 | % N: number of posterior samples in each subset 18 | % para: the columns/dimensions that represent parameters of interest 19 | % mark: Index of data point. Storing the data directly in the tree nodes 20 | % would be too costly, so we instead store their index within each node 21 | % in a form of (i, j) where j is the index for the subset and i is 22 | % the index of the data point in that subset.
For example, (1, 2) 23 | % means the first point from the second subset 24 | % total_set: total number of subsets 25 | % ntree: the number of random partition trees to grow 26 | % rule: A function handle that determines the optimal cut at each gate 27 | % parallel: if true, build the trees of the forest in parallel 28 | % max_n_unique: a region with fewer than max_n_unique unique samples is forced to be a leaf node 29 | % min_number_points_block: stopping criterion: a block is split only if each child contains at least min_number_points_block points 30 | % min_cut_length: either a scalar or a vector of the same dim as the parameters. Every side of a block has length >= min_cut_length 31 | % 32 | %verbose: verbosity level; a waitbar is shown when verbose > 1 33 | % 34 | %Outputs: 35 | % 36 | %trees: A list containing multiple partition trees, each corresponding 37 | % to a density estimation of the combined posterior. Used by 38 | % treeDensity to evaluate density at a given point 39 | %sampler: A list containing multiple flattened partition trees, used by 40 | % treeSampling to resample from the forest: 41 | % {(leaves/blocks of tree_1), (leaves/blocks of tree_2), ...} 42 | %sampler_prob: A list containing the normalized log-probabilities for each flattened 43 | % partition tree, used by treeSampling: 44 | % {(log-probs for blocks of tree_1), (log-probs for blocks of tree_2), ...} 45 | % 46 | %See also: 47 | %buildTree, OneStageMCMC, MultiStageMCMC, NormalizeTree, treeSampling, 48 | %treeDensity 49 | % 50 | 51 | %By default (verbose = 0), no waitbar is shown 52 | if nargin == 1 53 | verbose = 0; 54 | end 55 | 56 | %Obtain the list of subset posteriors 57 | if ~isfield(RawMCMC, 'list'), error('no data loaded.');else list = RawMCMC.list; end 58 | [n, d] = size(list); 59 | 60 | %Check the different components of RawMCMC and set the default values 61 | if ~isfield(RawMCMC, 'area'), area = []; else area = RawMCMC.area; end 62 | if ~isfield(RawMCMC, 'N'), error('RawMCMC.N (the number of posterior samples on each subset) is missing.'),
else N = RawMCMC.N; end 63 | if ~isfield(RawMCMC, 'para'), para = 1:d; else para = RawMCMC.para; end 64 | if ~isfield(RawMCMC, 'mark'), mark = [1:n;ones(1,n)]'; else mark = RawMCMC.mark; end 65 | if ~isfield(RawMCMC, 'total_set'), total_set = 1; else total_set = RawMCMC.total_set; end 66 | if ~isfield(RawMCMC, 'ntree'), ntree = 20; else ntree = RawMCMC.ntree; end 67 | if ~isfield(RawMCMC, 'rule'), rule = @kdCut; else rule = RawMCMC.rule; end 68 | if ~isfield(RawMCMC, 'parallel'), do_parallel=false; else do_parallel=RawMCMC.parallel; end 69 | if ~isfield(RawMCMC, 'max_n_unique'), max_n_unique=10; else max_n_unique=RawMCMC.max_n_unique; end 70 | if ~isfield(RawMCMC, 'min_number_points_block'), min_number_points_block=20; else min_number_points_block=RawMCMC.min_number_points_block; end 71 | if ~isfield(RawMCMC, 'min_cut_length'), min_cut_length=0.1; else min_cut_length=RawMCMC.min_cut_length; end 72 | 73 | %Initialization. Cell lists to store information for the different random 74 | %partition trees 75 | trees = cell(1,ntree); 76 | sampler = cell(1,ntree); 77 | sampler_prob = cell(1,ntree); 78 | 79 | %Set up waitbar 80 | if verbose>1 81 | h = waitbar(0,'The final stage...'); 82 | end 83 | 84 | if do_parallel && ntree>1, parfor_progress(ntree); end % text progress is only used (and cleaned up by parfor_progress(0)) in the parallel branch 85 | %Start building forest... 86 | if do_parallel && ntree>1 87 | fprintf('\nBuilding a forest of %d trees in parallel...\n', ntree); 88 | parfor i = 1:ntree 89 | %Call buildTree to build a random partition tree 90 | [tree, nodes, log_probs] = buildTree(list, area, N, para, mark, total_set, ... 91 | min_number_points_block, min_cut_length, 0, rule, max_n_unique); 92 | % log of normalizer 93 | log_normalizer = 0; 94 | %When there are multiple subsets, we need an extra normalizing step 95 | if total_set > 1 96 | %normalizing densities using the total probability 97 | % first subtracting the max to prevent overflow in exp(..)
98 | max_log_prob = max(log_probs); 99 | log_normalizer = max_log_prob + log(sum(exp(log_probs - max_log_prob))); 100 | tree = NormalizeTree(tree, [], log_normalizer); 101 | end 102 | 103 | %save the results 104 | trees{i} = tree; 105 | sampler{i} = nodes; 106 | sampler_prob{i} = log_probs - log_normalizer; % normalized probs in log scale 107 | 108 | assert(abs(sum(exp(sampler_prob{i}))-1)<1e-10) % FIXME: comment it out 109 | 110 | % text progress update (only this works inside parfor) 111 | parfor_progress; 112 | end 113 | parfor_progress(0); 114 | else 115 | fprintf('\nBuilding a forest of %d trees...\n', ntree); 116 | for i = 1:ntree 117 | fprintf('Building tree %d of %d...\n', i, ntree); 118 | %Call buildTree to build a random partition tree 119 | [tree, nodes, log_probs] = buildTree(list, area, N, para, mark, total_set, ... 120 | min_number_points_block, min_cut_length, 0, rule, max_n_unique); 121 | % log of normalizer 122 | log_normalizer = 0; 123 | %When there are multiple subsets, we need an extra normalizing step 124 | if total_set > 1 125 | %normalizing densities using the total probability 126 | % first subtracting the max to prevent overflow in exp(..)
127 | max_log_prob = max(log_probs); 128 | log_normalizer = max_log_prob + log(sum(exp(log_probs - max_log_prob))); 129 | tree = NormalizeTree(tree, [], log_normalizer); 130 | end 131 | 132 | %save the results 133 | trees{i} = tree; 134 | sampler{i} = nodes; 135 | sampler_prob{i} = log_probs - log_normalizer; % normalized probs in log scale 136 | 137 | assert(abs(sum(exp(sampler_prob{i}))-1)<1e-10) % FIXME: comment it out 138 | 139 | % update waitbar 140 | if verbose>1 141 | waitbar(i/ntree,h); 142 | end 143 | end 144 | end 145 | %close waitbar 146 | if verbose>1 147 | close(h); 148 | end 149 | end -------------------------------------------------------------------------------- /src/PART/buildHistogram.m: -------------------------------------------------------------------------------- 1 | function cuts = buildHistogram(x, l, r, cutFun, varargin) 2 | % generic 1-D histogram building with a supplied cutting function 3 | % cuts = buildHistogram(x, l, r, cutFun, 'minNodeSize', 10, 'minDepth', 3) 4 | % 5 | % x: data 6 | % l: left boundary 7 | % r: right boundary 8 | % cutFun: [index value] = cutFun(x, l, r) 9 | % optional name/value pairs: 10 | % 'minNodeSize': 10 (default); a cut is accepted only if both sides hold more than minNodeSize points 11 | % 'minDepth': 3 (default); a warning is issued if fewer rounds of cuts are accepted 12 | % 13 | % cuts: the sorted interior cutting points (empty if no cut is accepted) 14 | 15 | % cuts = histogramMaxLCut(x, l, r, varargin) 16 | % 17 | assert(all(x>=l) && all(x<=r), 'not all x in the range [l,r]'); 18 | options = struct('minNodeSize', 10, 'minDepth', 3); 19 | 20 | %# read the acceptable names 21 | optionNames = fieldnames(options); 22 | 23 | %# count arguments 24 | nArgs = length(varargin); 25 | if round(nArgs/2)~=nArgs/2 26 | error('buildHistogram needs propertyName/propertyValue pairs') 27 | end 28 | 29 | for pair = reshape(varargin,2,[]) %# pair is {propName;propValue} 30 | inpName = pair{1}; %# option names are case sensitive 31 | if any(strcmp(inpName,optionNames)) 32 | options.(inpName) = pair{2}; 33 | else 34 | error('%s is not a recognized parameter name',inpName) 35 | end 36 | end 37 | 38 | cuts = [l r]; 39 | depth = 0; 40 | while true 41 | new_cuts
= []; 42 | for k=1:length(cuts)-1 43 | x_sub = x(x>=cuts(k) & x %f\n', cuts(k), cuts(k+1), cut); 49 | n_left = sum(x_sub<=cut); 50 | n_right = sum(x_sub>cut); 51 | if n_left>options.minNodeSize && n_right>options.minNodeSize 52 | % accept this good cut 53 | new_cuts = [new_cuts, cut]; 54 | end 55 | end 56 | 57 | if isempty(new_cuts) 58 | if depth < options.minDepth 59 | warning('Depth = %d < minimum depth = %d', depth, options.minDepth); 60 | end 61 | break 62 | else 63 | depth = depth + 1; 64 | cuts = sort([cuts, new_cuts]); 65 | end 66 | end 67 | if length(cuts)>2 68 | cuts = cuts(2:end-1); 69 | else 70 | cuts = []; 71 | end 72 | 73 | end -------------------------------------------------------------------------------- /src/PART/buildTree.m: -------------------------------------------------------------------------------- 1 | function [obj, nodes, log_probs] = buildTree(list, area, N, para, mark, total_set, min_number_points_block, min_cut_length, height, rule, max_n_unique) 2 | %[obj, nodes, log_probs] = buildTree(list, area, N, para, mark, total_set, min_number_points_block, min_cut_length, height, rule, max_n_unique) 3 | % 4 | %For a given list of MCMC draws from one or multiple subsets (merged 5 | %into one matrix), 6 | %buildTree returns the corresponding partition tree in "obj", the 7 | %flattened tree (its leaves) in "nodes" and the corresponding log-probabilities (for sampling) 8 | %in "log_probs". buildTree is an inner function 9 | %to be called by buildForest. Do not use buildTree directly, as it lacks 10 | %a normalizing step when combining multiple subsets. Use 11 | %buildForest instead (a forest may contain a single 12 | %tree as well) 13 | % 14 | %[obj, nodes, log_probs] = buildTree(list, area, N, para, mark, total_set, min_number_points_block, 15 | %min_cut_length, height, rule, max_n_unique) returns the partition tree that combines the product of 16 | %densities from subset posterior samples.
17 | % 18 | %Arguments: 19 | % 20 | %list: A sum(N) x d matrix containing "pooled" MCMC samples. d is the dimension 21 | % and N is a vector containing the number of samples on different subsets. If there are 22 | % multiple subsets, all points are merged in one matrix, but the data 23 | % will be indexed by "mark" (defined below) for reference to the original 24 | % subsets 25 | %area: the total area for cutting. It is a d x 2 matrix, each row 26 | % corresponding to one dimension, the first column is the minimum 27 | % boundary and the second column is the maximum boundary. The 28 | % default value is given by the smallest and largest value in 29 | % each dimension, pushed outward by a factor of 1.001 30 | %N: number of posterior samples for each subset 31 | %para: the columns/dimensions that buildTree is to cut. 32 | %mark: Index of data point. Storing the data directly in "obj" and 33 | % "nodes" would be too costly, so we instead store their index within each node 34 | % in a form of (i, j) where j is the index for the subset and i is 35 | % the index of the data point in that subset. For example, (1, 2) 36 | % means the first point from the second subset 37 | %total_set: total number of subsets 38 | %min_number_points_block: a block is split only if each child contains at least min_number_points_block samples 39 | %min_cut_length: either a scalar or a vector of the same dim as the 40 | %parameters. min_cut_length(j) corresponds to the minimum length of the 41 | %block for dim j.
%height: buildTree is recursive; height is the depth of the current node
%   in the tree being built
%rule: A function handle that determines the optimal cut at each gate.
%   It is called as
%   [index, value] = rule(x, l, r, M, subset_index), where x holds the
%   samples along the chosen dimension, [l, r] is the block's extent in
%   that dimension, M = total_set, subset_index = mark(:,2), and value is
%   the cutting location
%max_n_unique: maximum number of unique samples allowed in a leaf node
%   (currently unused; it is only passed on to the recursive calls)
%
%Outputs:
%
%obj: the partition tree that gives the (combined) density estimate
%nodes: A flattened obj containing all (leaf) nodes; they represent a blockwise-constant density estimate
%log_probs: the log-probability of each node/block in nodes
%
% (NOTE: node = leaf, gate = non-leaf node)
% Gates are only relevant to tree building; all subsequent sampling
% and combining use only nodes
%
%See also:
%buildForest, MultiStageMCMC, OneStageMCMC

%Obtain the total sample size and dimension.
[n, d] = size(list);

% parameter parsing
if isscalar(min_cut_length)
    min_cut_length = min_cut_length * ones(1,d);
end

% number of unique samples (using dim 1 would suffice for a continuous
% variable)
% n_unique = length(unique(list(:,1)));

%Specify the default area
if isempty(area) && height == 0
    area = zeros(d,2);
    area(:,1) = min(list)'; %min-boundary
    area(:,2) = max(list)'; %max-boundary
    assert(all(area(:,2)>area(:,1)), 'not strictly bigger');

    %Enlarge the boundary of every dimension i by a factor of 1.001 to avoid boundary issues
    for i = 1:d
        if area(i,1) > 0
            area(i,1) = area(i,1)/1.001;
        else
            area(i,1) = area(i,1)*1.001;
        end

        if area(i,2) > 0
            area(i,2) = area(i,2)*1.001;
        else
            area(i,2) = area(i,2)/1.001;
        end
    end
end

%Start building nodes


% decide if this should be a leaf/block vs. continue cutting
to_be_leaf = true;
dim_to_cut = nan;
cutting_point = nan;
if isempty(para)
    to_be_leaf = true;
else
    % only those dimensions with length > min_cut_length are passed via para
    candidate_dims = para;
    % randomly choose a dimension and fetch the cutting point
    % until either (1) it results in a valid split (non-leaf) or (2)
    % the candidate dimensions are exhausted
    while ~isempty(candidate_dims)
        dim_index = randsample(length(candidate_dims), 1);
        dim_to_cut = candidate_dims(dim_index);
        % call: rule(x, l, r, M, subset_index)
        [~, cutting_point] = feval(rule, list(:,dim_to_cut)', area(dim_to_cut,1), area(dim_to_cut,2), total_set, mark(:,2));
        if (cutting_point - area(dim_to_cut,1) >= min_cut_length(dim_to_cut)) && (area(dim_to_cut, 2) - cutting_point >= min_cut_length(dim_to_cut))
            left = list(:,dim_to_cut)<cutting_point; %bool vector: points that go to the left child
            right = list(:,dim_to_cut)>cutting_point; %bool vector: points that go to the right child
            if sum(left) >= min_number_points_block && sum(right) >= min_number_points_block
                % if min(histc(mark(left,2), 1:total_set)) >= min_number_points_block && min(histc(mark(right,2), 1:total_set)) >= min_number_points_block
                % this is a valid cut
                to_be_leaf = false;
                break;
            end
        end
        % non-valid cut
        candidate_dims(dim_index) = [];
    end
end

l = area(:,2) - area(:,1); %length of each side of the block
%Check if it is a node (meets the termination condition)
if to_be_leaf
    %Yes!! This is a leaf!
    %Check if the area condition is violated (non-zero area)
    assert(all(l>0));

    %create a new node for the leaf
    obj = treeNode;

    % compute the aggregated density and probability via proper normalization
    %count numbers of points in different subsets.
    c = histc(mark(:,2), 1:total_set);
    if total_set > 1 %if there are multiple subsets
        counts = c + 0.1; %this number is used to bound the counts away from 0

        %compute the unnormalized probability & density in log scale
        % density = density_1 x density_2 x ...
        %         = n_1/N_1/(prod of l) x n_2/N_2/(prod of l) x ...
        % prob    = density x (prod of l)
        % i.e. the unnormalized probability multiplies the subset-wise
        % densities and then multiplies by the block volume
        obj.log_prob = sum(log(counts)) - sum(log(N)) - (total_set - 1)*sum(log(l));
        % unnormalized density in log scale
        obj.log_density = obj.log_prob - sum(log(l));
    else
        % only one subset: compute the probability and density directly
        obj.log_prob = log(n) - log(N);
        obj.log_density = obj.log_prob - sum(log(l));
    end

    % FIXME: comment them out
    assert(isfinite(sum(log(l))));
    assert(isfinite(obj.log_prob) && isfinite(obj.log_density));

    % store information in the leaf node
    obj.point = mark; %the indices of the data points
    obj.area = area; % the area

    % a local Gaussian kernel: a Gaussian (instead of uniform) density estimated
    % from the points in the block, using a product of per-subset Laplace
    % (Gaussian) approximations
    if size(mark, 1)==1
        % cannot estimate from a single point
        obj.cov = nan;
        obj.mean = nan;
    elseif total_set==1 || ~all(c>d)
        % one subset, or too few samples in some subset: fit one Gaussian to the pooled points
        obj.cov = cov(list)/total_set;
        obj.mean = mean(list, 1);
    else
        % product of per-subset Gaussian (Laplace) approximations
        tmp = zeros(d);
        agg_Mu = zeros(1,d);
        for m=1:total_set
            tmp_cov = cov(list(mark(:,2)==m, :));
            tmp = tmp + inv(tmp_cov);
            agg_Mu = agg_Mu + mean(list(mark(:,2)==m,:),1) / tmp_cov;
        end
        obj.cov = inv(tmp);
        obj.mean = agg_Mu * obj.cov;
        % test positive definiteness
        [~, chol_number] = chol(obj.cov);
        if chol_number~=0
            % the matrix is not positive definite:
            % fall back to a crude estimate from the pooled samples
            obj.cov = cov(list)/total_set;
            obj.mean = mean(list, 1);
        end
    end

    % singleton return values
    nodes = {obj}; %Output the leaf for constructing the flattened tree (sampler)
    log_probs = [obj.log_prob]; %Output the corresponding log-probability of the leaf
else
    %No!! This is not yet a leaf, so this is a gate! We need
    %to partition the data and build the partition rule on this
    %gate.
    obj = treeNode; %create the gate
    % NOTE: randomness of the ensemble
    obj.dim = dim_to_cut; %randomly selected dimension for the partition rule
    if height == 0 && d == 1
        %In the one-dimensional case, in order to introduce
        %randomness when building different trees, we always make a
        %random cut at the first step
        obj.value = list(randsample(1:n,1),obj.dim);
        left = list(:,obj.dim)<obj.value; %bool vector: points that go to the left child
        right = list(:,obj.dim)>obj.value; %bool vector: points that go to the right child
    else
        %Otherwise, use the cutting point already determined by the
        %function handle @rule
        obj.value = cutting_point;
    end

    % split into left vs. right depending on obj.value
    % We save all information at the gate
    obj.area = area; %the whole area
    middle = find(list(:,obj.dim)==obj.value); %points tied with the cutting point (e.g. when an MCMC chain got stuck)
    % randomly split the middle into left and right with probability 1/2 each
    tmp = rand(size(middle))>0.5;
    left(middle(tmp)) = true;
    right(middle(~tmp)) = true;
    assert(all(left+right==1), 'wrong split'); % FIXME: comment it out
    list1 = list(left,:);
    list2 = list(right,:);
    mark1 = mark(left,:); %indices that go to the left
    mark2 = mark(right,:); %indices that go to the right

    area1 = area; %the area of the left child
    area1(obj.dim,:) = [area(obj.dim,1), obj.value];
    assert(all(area1(:,2)>area1(:,1)));

    area2 = area; %the area of the right child
    area2(obj.dim,:) = [obj.value, area(obj.dim,2)];
    assert(all(area2(:,2)>area2(:,1)));

    %We then recursively build the left sub-tree and right
    %sub-tree by calling the same buildTree function with the
    %corresponding information

    % the candidate dimensions for the next cut are restricted to those wider than min_cut_length
    para_1 = para(area1(para,2)-area1(para,1)>min_cut_length(para)');
    para_2 = para(area2(para,2)-area2(para,1)>min_cut_length(para)');

    %Build the left sub-tree
    [leftNode, nodes_left, log_probs_left] = buildTree(list1, area1, N, para_1, mark1, total_set, ...
        min_number_points_block, min_cut_length, height+1, rule, max_n_unique);

    %Build the right sub-tree
    [rightNode, nodes_right, log_probs_right] = buildTree(list2, area2, N, para_2, mark2, total_set, ...
        min_number_points_block, min_cut_length, height+1, rule, max_n_unique);

    %Connect the current gate to its left and right children
    obj.left = leftNode;
    obj.right = rightNode;

    %Augment the information for the flattened tree (sampler)
    nodes = [nodes_left, nodes_right]; % collect the leaf nodes
    log_probs = [log_probs_left, log_probs_right];
end
end
--------------------------------------------------------------------------------
/src/PART/cleanTree.m:
--------------------------------------------------------------------------------
function [trees, numNode] = cleanTree(trees)
%cleanTree cleans a given forest of random partition trees. In particular,
%it clears node.log_prob, node.log_density, node.point, node.mean and
%node.cov on every leaf, but leaves all gates (partition points) and area
%information unchanged (i.e., the tree structure is preserved).
%Note: the second output, numNode, is never assigned, so request only
%"trees".

m = length(trees);
h = waitbar(0,'Starting cleaning trees...');
for i = 1:m
    Node = trees{i};
    stack = Node;
    while ~isempty(stack)
        current_node = stack(end);
        stack = stack(1:end-1); %pop the last one
        if isempty(current_node.left) && isempty(current_node.right)
            current_node.log_prob = nan;
            current_node.log_density = nan;
            current_node.point = [];
            current_node.mean = nan;
            current_node.cov = nan;
        else
            stack = [stack, current_node.left, current_node.right];
        end
    end
    waitbar(i/m, h, ['Cleaning tree ', num2str(i), ' of ', num2str(m), '...']);
end
close(h);
end
--------------------------------------------------------------------------------
/src/PART/copyTree.m:
--------------------------------------------------------------------------------
function newTrees = copyTree(trees)
%copyTree deep-copies a forest of random partition trees

m = length(trees);
newTrees = cell(1,m);
h = waitbar(0,'Starting copying trees...');
for i = 1:m
    Node = trees{i};
    newNode = copyNode(Node);
    newTrees{i} = newNode;
    stack = Node;
    newStack = newNode;
    while ~isempty(stack)
        current_node = stack(end);
        current_newNode = newStack(end);
        stack = stack(1:end-1); %pop the last one
        newStack = newStack(1:end-1);
        if ~isempty(current_node.left) || ~isempty(current_node.right)
            stack = [stack, current_node.left, current_node.right];
            newLeft = copyNode(current_node.left);
            newRight = copyNode(current_node.right);
            current_newNode.left = newLeft;
            current_newNode.right = newRight;
            newStack = [newStack, newLeft, newRight];
        end
    end
    waitbar(i/m, h, ['Copying tree ', num2str(i), ' of ', num2str(m), '...']);
end
close(h);
end

function new = copyNode(this)
% copy all properties of a single node into a new node of the same class
new = feval(class(this));
p = properties(this);
for i = 1:length(p)
    new.(p{i}) = this.(p{i});
end
end
--------------------------------------------------------------------------------
/src/PART/cut/consensusMLCut.m:
--------------------------------------------------------------------------------
function [index, value, goodness_of_cuts] = consensusMLCut(x, l, r, M, subset_index)
% [index, value, goodness_of_cuts] = consensusMLCut(x, l, r, M, subset_index)
% get the cut that maximizes the summed empirical likelihood from multiple subsets via a line search
%
% ARGUMENT
% x: 1-D array
% l: scalar, left boundary
% r: scalar, right boundary
% M: scalar, # of subsets
% subset_index: (i.e. mark(:,2)) indicates which subset each sample comes from;
%   same length as x
%
% OUTPUT
% index: the index into x of the sample at which the cut is placed
% value: the value of the cutting point, x(index)
% goodness_of_cuts: the score (summed empirical log-likelihood, up to a constant) of every candidate cut
%
x = x(:)';
N = length(x);
[y, I] = sort(x); % increasing order
assert(length(x) == length(subset_index), 'length of x must match length of subset_index');
assert(all(subset_index>=1) && all(subset_index<=M), 'subset index must take 1...M');
assert(y(1)>=l & y(end)<=r, 'x not in range of [l,r]');
% candidate cut j is placed at the j-th smallest sample y(j), j = 1..N-1
index_vec = 1:(N-1);
cuts = y(index_vec);
num_cuts = length(index_vec);
subset_sizes = histc(subset_index, 1:M);
% profile an (M x num_cuts) cum-sum matrix: cum_sum_matrix(m, j) = # of samples in subset m that fall at or left of cut j
sorted_subset_indices = subset_index(I(index_vec));
cum_sum_matrix = zeros(M, num_cuts);
complement_cum_sum_matrix = zeros(M, num_cuts);
for m=1:M
    cum_sum_matrix(m,:) = cumsum(sorted_subset_indices==m); % zero counts are handled below
    complement_cum_sum_matrix(m,:) = subset_sizes(m) - cum_sum_matrix(m,:);
end
% compute goodness of cut from matrix operations
L = cuts - l;
R = r - cuts;
L(L==0) = 1e-4;
R(R==0) = 1e-4;
A1 = repmat(L, M, 1);
A2 = repmat(R, M, 1);
tmp_mat_1 = cum_sum_matrix .* (log(cum_sum_matrix) - log(A1));
tmp_mat_1(cum_sum_matrix==0) = 0;
tmp_mat_2 = complement_cum_sum_matrix .* (log(complement_cum_sum_matrix) - log(A2));
tmp_mat_2(complement_cum_sum_matrix==0) = 0;
goodness_of_cuts = sum(tmp_mat_1 + tmp_mat_2, 1);
% pick the ML
[~, u] = max(goodness_of_cuts);
index = I(u);
value = x(index);
% fprintf('%d: goc = %f\n', u, goodness_of_cuts(u));
end
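consensusMLCut scores all N-1 candidate cuts at once using cumulative sums. For a cut c placed at the j-th smallest sample, the score is a sum over subsets m of nL*log(nL/(c-l)) + nR*log(nR/(r-c)). Here nL counts the subset-m samples at or left of c, and nR counts those right of c. The two-block histogram log-likelihood of subset m equals this term minus N_m*log(N_m), which does not depend on c. Maximizing the score therefore maximizes the summed log-likelihood.

The MATLAB script below recomputes the same score by brute force with explicit loops. It is illustrative only and not part of the repository; the toy data and variable names are assumptions. It also assumes there are no tied sample values: with ties, counting by sorted position (as consensusMLCut does) and counting by value (as here) can differ.

```matlab
% Illustrative brute-force recomputation (not part of the repository) of the
% score that consensusMLCut maximizes. Toy data; assumes no tied values.
x = [0.1 0.2 0.8 0.9 0.15 0.85];   % pooled samples
subset_index = [1 1 1 1 2 2];      % subset label of each sample (like mark(:,2))
l = 0; r = 1; M = 2;               % block boundaries and number of subsets

y = sort(x);
N = numel(x);
goc = zeros(1, N-1);               % score of each candidate cut
for j = 1:N-1
    c = y(j);                      % candidate cut at the j-th smallest sample
    Lc = c - l; if Lc == 0, Lc = 1e-4; end   % same zero-width guard as consensusMLCut
    Rc = r - c; if Rc == 0, Rc = 1e-4; end
    for m = 1:M
        nL = sum(x <= c & subset_index == m);  % subset-m samples at or left of c
        nR = sum(x >  c & subset_index == m);  % subset-m samples right of c
        % empty sides contribute 0 (0*log(0) is taken as 0)
        if nL > 0, goc(j) = goc(j) + nL*log(nL/Lc); end
        if nR > 0, goc(j) = goc(j) + nR*log(nR/Rc); end
    end
end
[~, j_best] = max(goc);
fprintf('best cut at %.2f\n', y(j_best));
```

On these toy values the script prints `best cut at 0.20`, which is the gap between the two clusters. Given the same inputs, the third output of consensusMLCut(x, l, r, M, subset_index) should match goc up to round-off.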
--------------------------------------------------------------------------------
/src/PART/cut/fastMLCut.m:
--------------------------------------------------------------------------------
function [index, value] = fastMLCut(x, l, r, varargin)
% [index, value] = fastMLCut(x, l, r, varargin)
% get the cut that maximizes the empirical likelihood via a line search
%
% x: 1-D array
% l: left boundary
% r: right boundary
% varargin: extra arguments (e.g. M and subset_index passed by buildTree)
%   are accepted but ignored; no plotting is performed
%
% index: the index into x of the sample returned as the cutting point
% value: the value of the cutting point, x(index)
x = x(:)';
N = length(x);
[y, I] = sort(x); % increasing order
jvec = 1:(N-1); % cuts are numbered 1 ... N-1
assert(y(1)>=l & y(end)<=r, 'x not in range of [l,r]');
if N>500
    % for large N, score only 500 candidate cuts sampled with replacement
    index_vec = randsample(N, 500, true)';
    cuts = y(index_vec);
    goodness_of_cuts = index_vec .* log(index_vec ./ (cuts-l)) + (N-index_vec) .* log((N-index_vec) ./ (r-cuts));
    [~, j] = max(goodness_of_cuts);
    index = I(index_vec(j));
    value = x(index);
else
    cuts = [(y(1)+y(2))/2, y(2:end-1)];
    % goodness of cut by directly computing the empirical log likelihood
    goodness_of_cuts = jvec .* log(jvec ./ (cuts-l)) + (N-jvec) .* log((N-jvec) ./ (r-cuts));
    % smooth with sliding-window averaging
    % goodness_of_cuts = conv(goodness_of_cuts, ones(1,5)/5, 'same');
    % do not allow cutting too small an area
    % imbalance = (cuts - l) ./ (r - cuts);
    % J = imbalance<1;
    % imbalance(J) = 1 ./ imbalance(J);
    % goodness_of_cuts(imbalance>10) = -Inf;

    [~, j] = max(goodness_of_cuts);
    index = I(j);
    value = x(index);
end
end
--------------------------------------------------------------------------------
/src/PART/cut/kdCut.m:
--------------------------------------------------------------------------------
function [index, value] = kdCut(x, l, r, varargin)
assert(x(1)>=l & x(end)<=r, 'x not in range of [l,r]');
index = ceil(length(x)/2);
value = median(x);
end

--------------------------------------------------------------------------------
/src/PART/cut/meanCut.m:
--------------------------------------------------------------------------------
function [index, value] = meanCut(x, l, r, varargin)
assert(x(1)>=l & x(end)<=r, 'x not in range of [l,r]');
value = mean(x);
[~, index] = min(abs(x-value));
end

--------------------------------------------------------------------------------
/src/PART/cut/midPointCut.m:
--------------------------------------------------------------------------------
function [index, value] = midPointCut(x, l, r, varargin)
assert(x(1)>=l & x(end)<=r, 'x not in range of [l,r]');
value = (max(x)+min(x))/2;
[~, index] = min(abs(x-value));
end

--------------------------------------------------------------------------------
/src/PART/empiricalLogLikelihood.m:
--------------------------------------------------------------------------------
function ll = empiricalLogLikelihood(x, l, r, cuts)
% ll = empiricalLogLikelihood(x, l, r, cuts)
% evaluates the empirical log likelihood of x under cuts
%
% x: data
% l: left boundary
% r: right boundary
% cuts: cutting points
%
% ll: empirical log likelihood
assert(all(x>=l) && all(x<=r), 'not all x in the range [l,r]');
cuts = [l sort(cuts) r];
N = length(x);
ll = 0;
for k=1:(length(cuts)-1)
    lk = cuts(k);
    rk = cuts(k+1);
    nk = sum(lk<=x & x0, 'minimum length of block must be positive');
            case 'min_fraction_block'
                assert(pair{2}>0 && pair{2}<1, 'minimum fraction of samples contained in a block must be >0 and <1');
            case 'cut_type'
                assert(strcmp(pair{2}, 'kd') || strcmp(pair{2}, 'ml'), 'cut_type must be kd or ml');
            case 'resample_N'
                assert(pair{2}>=1000, 'resample_N too small');
            case 'resample_data_only'
                warning('resample_data_only = true only for debugging');
            case 'ntree'
                assert(pair{2}>0, 'ntree must be a positive integer');
        end
    else
        error('%s is not a recognized parameter name',inpName)
    end
end

fprintf('Options configured:\n');
disp(option);

end
--------------------------------------------------------------------------------
/src/PART/treeDensity.m:
--------------------------------------------------------------------------------
function y = treeDensity(x, trees, dens)
%treeDensity takes the output of buildForest, OneStageMCMC or
%MultiStageMCMC and evaluates the combined density at the given
%points.
%
%Call:
%y = treeDensity(x, trees)
%y = treeDensity(x, trees, dens)
%
%Arguments:
%
%x: an N x d matrix containing the points at which the density is to be
%   evaluated. N is the number of points and d is the dimension.
%trees: forest of random partition trees (or of flattened trees, depending
%   on "dens") output by buildForest, OneStageMCMC or MultiStageMCMC.
%   When dens is "uniform", trees should be the "trees" output of
%   these functions. When dens is "normal", trees should be their
%   "sampler" output.
%dens: the density used on each leaf. Options are
%   "uniform" and "normal". Default is "uniform".
%
%Outputs:
%
%y: the density values at x (a 1 x N row vector)
%
%See also:
%buildForest, OneStageMCMC, MultiStageMCMC, treeSampling
%


m = length(trees); %number of trees
[n, d] = size(x); %number of points and the dimension

%Set the default value for parameters
if nargin <3
    dens = 'uniform';
end

%If the choice is a uniform density on each leaf
if strcmp(dens, 'uniform')
    %setting up waitbar...
    h = waitbar(0,'Density evaluation (uniform) ...');

    %Initializing z: z(i,k) is the density of tree i at point k
    z = zeros(m,n);

    % loop over trees and points
    for i=1:m
        for k=1:n
            Node = trees{i};
            % descend from the root to the leaf that contains x(k,:)
            while ~isnan(Node.dim)
                if x(k, Node.dim) <= Node.value
                    Node = Node.left;
                else
                    Node = Node.right;
                end
            end
            %Store the density of tree i at x(k,:)
            z(i, k) = exp(Node.value);
        end
    end

    %Averaging over all trees
    y = mean(z, 1);

    %updating waitbar...
    waitbar(1, h, 'Evaluating densities ...');
    close(h);
end

%If the choice is to use a normal density on each leaf
if strcmp(dens,'normal')
    %setting up the waitbar
    h = waitbar(0,'Starting density evaluation (normal fit) ...');

    %initialize the output, which is accumulated over trees below
    y = zeros(1, n);

    %looping over all points
    for t = 1:n
        %looping over all trees
        for i = 1:m
            y1 = 0;
            %looping over all leaves in that tree
            for k = 1:length(trees{i})
                %l = trees{i}{k}.area(:,2) - trees{i}{k}.area(:,1);
                y1 = y1 + mvnpdf(x(t,:),trees{i}{k}.mean',trees{i}{k}.cov);
            end
            y1 = y1/length(trees{i});
            y(t) = y(t) + y1;
        end
        y(t) = y(t)/m;

        %updating waitbar
        waitbar(t/n, h, 'Evaluating densities (Gaussian fit) ...');
    end
    close(h);
end
end
--------------------------------------------------------------------------------
/src/PART/treeNode.m:
--------------------------------------------------------------------------------
classdef treeNode < handle
    % treeNode is either a gate (a non-leaf node that splits into two children) or a
    % node (a leaf that represents a constant-density block)
    properties
        dim = nan %The cutting dimension (for gates)
        value = nan %The cutting point (for gates)
        log_density = nan % The density (for nodes) (in log scale)
        log_prob = nan %The node probability (for nodes) (in log scale)
        point = [] %Indices of the original posterior samples. (i,j) means the ith point in the jth subset
        area = nan %Area information
        cov = nan %covariance of the normal density approximation
        left = [] %reference to the left child
        right = [] %reference to the right child
        mean = nan %mean of the normal density approximation
    end
    methods
        function new = copyNode(this)
            new = feval(class(this));
            p = properties(this);
            for i = 1:length(p)
                new.(p{i}) = this.(p{i});
            end
        end
    end
end
--------------------------------------------------------------------------------
/src/PART/treeNormalize.m:
--------------------------------------------------------------------------------
function [sum2, prob] = treeNormalize(tree, sum1)
if nargin == 1
    %to get the sum and the probability vector
    if isnan(tree.left) && isnan(tree.right)
        sum2 = tree.value;
        prob = [tree.prob];
    else
        [sum2_left, prob_left] = treeNormalize(tree.left);
        [sum2_right, prob_right] = treeNormalize(tree.right);
        sum2 = sum2_left + sum2_right;
        prob = [prob_left prob_right];
    end

elseif nargin == 2
    %to normalize using sum1
    sum2 = 0;
    prob = 0;
    if isnan(tree.left) && isnan(tree.right)
        tree.value = tree.value/sum1;
    else
        treeNormalize(tree.left, sum1);
        treeNormalize(tree.right, sum1);
    end
end
end
--------------------------------------------------------------------------------
/src/PART/treeSampling.m:
--------------------------------------------------------------------------------
function y = treeSampling(sampler, sampler_prob, n, option)
%treeSampling uses the output of buildForest, OneStageMCMC or
%MultiStageMCMC to resample points from the combined density.
%
%Call:
%y = treeSampling(sampler, sampler_prob)
%y = treeSampling(sampler, sampler_prob, n)
%y = treeSampling(sampler, sampler_prob, n, option)
%
%Arguments:
%
%sampler: the output of buildForest, OneStageMCMC or MultiStageMCMC,
%   which is a forest of flattened random partition trees
%sampler_prob: the corresponding leaf log-probabilities for sampler, also
%   output by the three functions mentioned above
%n: number of resamples (default 1)
%option: optional struct with the fields:
%   resample_data_only: whether the resampling within each stage
%     should use the original data points, i.e., when resampling at a
%     given leaf, we only resample original data points that lie within
%     that leaf.
%   local_gaussian_smoothing: whether the resampling should use a normal
%     density on each leaf, i.e., when resampling at a given leaf,
%     we use the original data points located in the leaf to fit a
%     normal density (the variance is shrunk by a factor of
%     1/total_subset) and resample from it. Leaves without a finite
%     covariance fall back to uniform resampling
%   list: the list containing the original subset posteriors concatenated
%     in one matrix (See buildForest)
%   mark: indices of the original subset posteriors (See buildForest)
%
%Outputs:
%
%y: the resampled posterior samples from the combined density
%
%See also:
%buildForest, OneStageMCMC, MultiStageMCMC, treeDensity
%

%Default number of resamples
if nargin < 3
    n = 1;
end

%Default options; otherwise check that the given options are consistent
if nargin < 4
    option.list = [];
    option.mark = [];
    option.resample_data_only = false;
    option.local_gaussian_smoothing = false;
else
    if ~isempty(option.list) && isempty(option.mark)
        error('mark must be provided if list exists');
    elseif option.resample_data_only && isempty(option.list)
        error('Need to provide the data in order to resample_data_only');
    end
end

m = length(sampler); %number of trees
d = size(sampler{1}{1}.area,1); %data dimension
y = zeros(n, d); %stores the resamples

%h = waitbar(0,['Sampling ',num2str(n), ' %samplers from the distribution...']);

if option.resample_data_only
    %when resampling the original subset posterior samples, we need
    %the sizes of the posterior samples on the different subsets
    N = zeros(1,length(unique(option.mark(:,2))));
    for set = 1:length(N)
        N(set) = sum(option.mark(:,2)< set);
    end
    %N holds cumulative sizes, i.e., N(1) = 0, N(2) = |subset1|, N(3) =
    %|subset1| + |subset2|
end
%sample n times and count how many times each tree is picked
sampled_tree = randsample(1:m, n, true);
tree_counted = histc(sampled_tree, 1:m);

%loop over all trees and sample points from the corresponding trees
%with the associated counts
num_sampled = 0;

fprintf('Resampling %d points...\n',n);
%start resampling
for ctree = 1:m

    tree = sampler{ctree}; %The current tree for resampling
    prob = exp(sampler_prob{ctree}); %the (unnormalized) probability of each leaf

    %Sample nodes with replacement tree_counted(ctree) times
    sampled_nodes = randsample(1:length(tree),tree_counted(ctree),true, prob);

    %Count how many times each unique node is picked
    unique_nodes = unique(sampled_nodes);
    nodes_count = histc(sampled_nodes, unique_nodes);

    %loop over all picked nodes to generate the resamples
    for node_index = 1:length(unique_nodes)
        node = tree{unique_nodes(node_index)}; %the current sampled node
        if option.resample_data_only
            %resample the observed data. data_sampled contains
            %the rows of node.point that have been sampled
            data_sampled = randsample(1:size(node.point,1),nodes_count(node_index), true);

            %Now add the resampled points to y (data_sampled(:)' makes the
            %loop visit one row at a time even if randsample returns a column)
            for k = data_sampled(:)'
                data_loc = node.point(k,:);
                y(num_sampled + 1,:) = option.list(N(data_loc(2)) + data_loc(1),:);
                num_sampled = num_sampled + 1;
            end

        elseif option.local_gaussian_smoothing && all(isfinite(node.cov(:)))
            %resample using the normal density
            y(num_sampled+1:num_sampled+nodes_count(node_index),:) = ...
                mvnrnd(node.mean,node.cov,nodes_count(node_index));
            num_sampled = num_sampled + nodes_count(node_index);

        else
            %resample from the uniform distribution on the leaf's block
            l = node.area(:,2) - node.area(:,1);
            y(num_sampled+1:num_sampled+nodes_count(node_index),:) = ...
                (ones(nodes_count(node_index),1)*l').*rand( ...
                nodes_count(node_index),d) + repmat(node.area(:,1)',nodes_count(node_index),1);

            num_sampled = num_sampled + nodes_count(node_index);
        end
    end
end
end
--------------------------------------------------------------------------------
/src/alternatives/aggregate_average.m:
--------------------------------------------------------------------------------
function averaged_chain = aggregate_average(sub_chain)
fprintf('Combining chains with simple averaging...\n');
% aggregation via simple averaging
k = length(sub_chain);
[~, p] = size(sub_chain{1});
chain_lengths = zeros(1,k);
for c=1:k
    chain_lengths(c) = size(sub_chain{c},1);
end
n = min(chain_lengths);
averaged_chain = zeros(n, p);
for c=1:k
    averaged_chain = averaged_chain + sub_chain{c}(1:n, :) * 1/k;
end

end
--------------------------------------------------------------------------------
/src/alternatives/aggregate_uai_nonparametric.m:
-------------------------------------------------------------------------------- 1 | function aggregated_samples = aggregate_uai_nonparametric(sub_chain, varargin) 2 | % drawing N posterior samples from multiplicated KD estimated subchain posteriors 3 | % "nonparametric" method from the uai paper. See its Algorithm 1. 4 | 5 | if ~isempty(varargin) 6 | N = varargin{1}; 7 | else 8 | N = size(sub_chain{1},1); 9 | end 10 | 11 | M = length(sub_chain); 12 | p = size(sub_chain{1},2); 13 | subchain_sizes = zeros(1,M); 14 | for c=1:M 15 | subchain_sizes(c) = size(sub_chain{c},1); 16 | end 17 | 18 | aggregated_samples = zeros(N, p); 19 | index_t = zeros(1, M); 20 | % uniformly init the index 21 | for c=1:M 22 | index_t(c) = randsample(subchain_sizes(c), 1); 23 | end 24 | x_t = zeros(M, p); 25 | for c=1:M 26 | x_t(c,:) = sub_chain{c}(index_t(c), :); 27 | end 28 | x_t_mean = mean(x_t, 1); 29 | ss_t = sum(sum((x_t - repmat(x_t_mean, M, 1)).^2)); 30 | 31 | fprintf('Getting %d samples with nonparametric method...\n', N); 32 | tic; 33 | count_rej = 0; 34 | for i=1:N 35 | h = i^(-1/(4+p)); 36 | for c=1:M 37 | index_new = index_t; 38 | % move the index for this subchain 39 | index_new(c) = randsample(subchain_sizes(c), 1); 40 | x_t_new = x_t; 41 | x_t_new(c,:) = sub_chain{c}(index_new(c), :); 42 | x_t_new_mean = mean(x_t_new, 1); 43 | ss_new = sum(sum((x_t_new - repmat(x_t_new_mean, M, 1)).^2)); 44 | u = rand; 45 | if u < exp(-(ss_new - ss_t)/(2*h^2)) 46 | % accept the move 47 | index_t = index_new; 48 | x_t = x_t_new; 49 | x_t_mean = x_t_new_mean; 50 | ss_t = ss_new; 51 | else 52 | count_rej = count_rej + 1; 53 | end 54 | end 55 | % sample with the index 56 | aggregated_samples(i, :) = mvnrnd(x_t_mean, h^2/M * eye(p)); 57 | end 58 | fprintf('Sampling done (finished in %.1f seconds). 
Acceptance rate=%f\n', toc, 1-count_rej/M/N); 59 | end 60 | -------------------------------------------------------------------------------- /src/alternatives/aggregate_uai_parametric.m: -------------------------------------------------------------------------------- 1 | function aggregated_samples = aggregate_uai_parametric(sub_chain, varargin) 2 | % "parametric" aggregation from the UAI paper 3 | % multiplicated gaussian with laplacian approximation to each subchain 4 | if ~isempty(varargin) 5 | n = varargin{1}; 6 | else 7 | n = size(sub_chain{1},1); 8 | end 9 | M = length(sub_chain); 10 | [~,p] = size(sub_chain{1}); 11 | Prec = cell(1,M); 12 | agg_Sig = zeros(p); 13 | agg_Mu = zeros(1,p); 14 | for c=1:M 15 | Prec{c} = inv(cov(sub_chain{c})); 16 | agg_Sig = agg_Sig + Prec{c}; 17 | agg_Mu = agg_Mu + mean(sub_chain{c},1) * Prec{c}; 18 | end 19 | agg_Sig = inv(agg_Sig); 20 | agg_Mu = agg_Mu * agg_Sig; 21 | 22 | aggregated_samples = mvnrnd(repmat(agg_Mu,n,1), agg_Sig); 23 | end -------------------------------------------------------------------------------- /src/alternatives/aggregate_uai_semiparametric.m: -------------------------------------------------------------------------------- 1 | function aggregated_samples = aggregate_uai_semiparametric(sub_chain, varargin) 2 | % Drawing N samples from aggregated posterior with Neiswager et al's 3 | % semiparametric method 4 | 5 | if ~isempty(varargin) 6 | N = varargin{1}; 7 | else 8 | N = size(sub_chain{1},1); 9 | end 10 | 11 | M = length(sub_chain); 12 | [~,p] = size(sub_chain{1}); 13 | subchain_sizes = zeros(1,M); 14 | for c=1:M 15 | subchain_sizes(c) = size(sub_chain{c},1); 16 | end 17 | 18 | subchain_Sigma = cell(1,M); 19 | subchain_Mu = cell(1,M); 20 | Prec = cell(1,M); 21 | sum_of_Prec = zeros(p); 22 | agg_Mu = zeros(1,p); 23 | for c=1:M 24 | subchain_Sigma{c} = cov(sub_chain{c}); 25 | subchain_Mu{c} = mean(sub_chain{c},1); 26 | Prec{c} = inv(subchain_Sigma{c}); 27 | sum_of_Prec = sum_of_Prec + Prec{c}; 28 | agg_Mu 
= agg_Mu + subchain_Mu{c} * Prec{c}; 29 | end 30 | agg_Sigma = inv(sum_of_Prec); 31 | agg_Mu = agg_Mu / sum_of_Prec; 32 | 33 | aggregated_samples = zeros(N, p); 34 | index_t = zeros(1, M); 35 | % initialize each subchain index uniformly at random 36 | for c=1:M 37 | index_t(c) = randsample(subchain_sizes(c), 1); 38 | end 39 | x_t = zeros(M, p); 40 | for c=1:M 41 | x_t(c,:) = sub_chain{c}(index_t(c), :); 42 | end 43 | x_t_mean = mean(x_t, 1); 44 | ss_t = sum(sum((x_t - repmat(x_t_mean, M, 1)).^2)); 45 | whole_normal = log_pdf_normal_whole(x_t_mean, agg_Mu, agg_Sigma, 1, M, p); 46 | subchain_normals = log_pdf_subchain_normals(x_t, subchain_Mu, subchain_Sigma, M); 47 | 48 | fprintf('Getting %d samples with semi-parametric method...\n', N); 49 | tic 50 | count_rej = 0; 51 | for i=1:N 52 | h = i^(-1/(4+p)); whole_normal = log_pdf_normal_whole(x_t_mean, agg_Mu, agg_Sigma, h, M, p); % re-evaluate the current state at the new bandwidth h 53 | for c=1:M 54 | index_new = index_t; 55 | % propose a new index for this subchain 56 | index_new(c) = randsample(subchain_sizes(c), 1); 57 | x_t_new = x_t; 58 | x_t_new(c,:) = sub_chain{c}(index_new(c), :); 59 | x_t_new_mean = mean(x_t_new, 1); 60 | ss_new = sum(sum((x_t_new - repmat(x_t_new_mean, M, 1)).^2)); 61 | whole_normal_new = log_pdf_normal_whole(x_t_new_mean, agg_Mu, agg_Sigma, h, M, p); 62 | subchain_normals_new = log_pdf_subchain_normals(x_t_new, subchain_Mu, subchain_Sigma, M); 63 | u = rand; 64 | if u < exp(-(ss_new - ss_t)/(2*h^2) + ... 65 | whole_normal_new - whole_normal + subchain_normals - subchain_normals_new) 66 | % accept the move 67 | index_t = index_new; 68 | x_t = x_t_new; 69 | x_t_mean = x_t_new_mean; 70 | ss_t = ss_new; 71 | whole_normal = whole_normal_new; 72 | subchain_normals = subchain_normals_new; 73 | else 74 | count_rej = count_rej + 1; 75 | end 76 | end 77 | % sample with the index 78 | aggregated_samples(i, :) = mvnrnd(x_t_mean, h^2/M * eye(p)); 79 | end 80 | fprintf('Sampling done (finished in %.1f seconds).
Acceptance rate=%f\n', toc, 1-count_rej/M/N); 81 | 82 | end 83 | 84 | function l = log_pdf_normal_whole(x_t_mean, agg_Mu, agg_Sig, h, M, p) 85 | l = logmvnpdf(x_t_mean, agg_Mu, agg_Sig + h/M*eye(p)); 86 | end 87 | 88 | function l = log_pdf_subchain_normals(x_t, subchain_Mu, subchain_Sigma, M) 89 | l = 0; 90 | for c=1:M 91 | l = l + logmvnpdf(x_t(c,:), subchain_Mu{c}, subchain_Sigma{c}); 92 | end 93 | end 94 | -------------------------------------------------------------------------------- /src/alternatives/aggregate_weighted_average.m: -------------------------------------------------------------------------------- 1 | function averaged_chain = aggregate_weighted_average(sub_chain) 2 | fprintf('Aggregating by weighted averaging...\n'); 3 | k = length(sub_chain); 4 | p = size(sub_chain{1},2); 5 | chain_length = zeros(1,k); 6 | for c=1:k 7 | chain_length(c) = size(sub_chain{c},1); 8 | end 9 | n = min(chain_length); 10 | averaged_chain = zeros(n, p); 11 | 12 | w = cell(1,k); 13 | w_inv = cell(1,k); 14 | sum_inv = zeros(p); 15 | for i = 1:k 16 | w{i} = cov(sub_chain{i}); 17 | w_inv{i} = inv(w{i}); 18 | sum_inv = sum_inv + w_inv{i}; 19 | end 20 | 21 | for c=1:k % precision-weighted average, applied row by row 22 | averaged_chain = averaged_chain + sub_chain{c}(1:n,:) * w_inv{c} / sum_inv; 23 | end 24 | 25 | 26 | end -------------------------------------------------------------------------------- /src/alternatives/logmvnpdf.m: -------------------------------------------------------------------------------- 1 | function [logp] = logmvnpdf(x,mu,Sigma) 2 | % returns a 1xN row of log densities for the rows of x, where x_n ~ N(mu,Sigma) 3 | % x is NxD, mu is 1xD, Sigma is DxD 4 | 5 | [N,D] = size(x); 6 | const = -0.5 * D * log(2*pi); 7 | 8 | xc = bsxfun(@minus,x,mu); 9 | 10 | term1 = -0.5 * sum((xc / Sigma) .* xc, 2); % N x 1 11 | term2 = const - 0.5 * logdet(Sigma); % scalar 12 | logp = term1' + term2; 13 | 14 | end 15 | 16 | function y = logdet(A) 17 | 18 | U = chol(A); 19 | y = 2*sum(log(diag(U))); 20 | 21 | end
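Both `aggregate_uai_parametric` and `aggregate_weighted_average` rest on the same precision-weighted combination: if subchain m has mean `mu_m` and covariance `Sigma_m`, the product of the Gaussian approximations has covariance `inv(sum_m inv(Sigma_m))` and mean `(sum_m mu_m * inv(Sigma_m)) * inv(sum_m inv(Sigma_m))`, in the row-vector convention the code uses. The script below is a minimal sketch of that arithmetic on two hand-picked 2-D Gaussians; the means and covariances are made up for illustration (they are not taken from the repo's examples) and were chosen so the result, mean `[1 1]` and covariance `diag([0.5 0.75])`, can be checked by hand.

```matlab
% Hypothetical subchain summaries (row-vector means, as in the repo's code)
mus    = {[0 0], [2 4]};
Sigmas = {eye(2), diag([1 3])};

p = numel(mus{1});
sum_prec    = zeros(p);     % accumulates sum_m inv(Sigma_m)
weighted_mu = zeros(1, p);  % accumulates sum_m mu_m * inv(Sigma_m)
for m = 1:numel(mus)
    prec_m      = inv(Sigmas{m});
    sum_prec    = sum_prec + prec_m;
    weighted_mu = weighted_mu + mus{m} * prec_m;
end
agg_Sigma = inv(sum_prec);           % covariance of the product of Gaussians
agg_Mu    = weighted_mu / sum_prec;  % same as weighted_mu * agg_Sigma

disp(agg_Mu)     % approximately [1 1]
disp(agg_Sigma)  % approximately diag([0.5 0.75])
```

`aggregate_weighted_average` applies the same weights sample by sample: row r of its output is `sum_c x_r^(c) * inv(Sigma_c) * inv(sum_m inv(Sigma_m))`, so each term has to be weighted by the precision of its own subchain c.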
-------------------------------------------------------------------------------- /src/init.m: -------------------------------------------------------------------------------- 1 | addpath ./PART ./PART/cut ./alternatives ./utils 2 | fprintf('Path added.\n'); -------------------------------------------------------------------------------- /src/utils/approximate_KL.m: -------------------------------------------------------------------------------- 1 | function kl = approximate_KL(samples_left, samples_right) 2 | % compute the approximate KL(p_left || p_right) 3 | % The KL is computed in closed form from Gaussian approximations 4 | % (sample mean and covariance) fitted to both distributions 5 | assert(size(samples_left,2)==size(samples_right,2),'dim not match'); 6 | p = size(samples_left,2); 7 | mu_left = mean(samples_left,1); 8 | mu_right = mean(samples_right, 1); 9 | Sigma_left = cov(samples_left); 10 | Sigma_right = cov(samples_right); 11 | inv_Sigma_right = inv(Sigma_right); 12 | d_mu = mu_right - mu_left; 13 | kl = 1/2 * (trace(inv_Sigma_right * Sigma_left) + ... 14 | d_mu * inv_Sigma_right * d_mu' - p + log(det(Sigma_right)) - log(det(Sigma_left))); 15 | end 16 | 17 | -------------------------------------------------------------------------------- /src/utils/parfor_progress.m: -------------------------------------------------------------------------------- 1 | function percent = parfor_progress(N) 2 | %PARFOR_PROGRESS Progress monitor (progress bar) that works with parfor. 3 | % PARFOR_PROGRESS works by creating a file called parfor_progress.txt in 4 | % your temporary directory (tempdir), and then keeping track of the parfor loop's 5 | % progress within that file. This workaround is necessary because parfor 6 | % workers cannot communicate with one another so there is no simple way 7 | % to know which iterations have finished and which haven't. 8 | % 9 | % PARFOR_PROGRESS(N) initializes the progress monitor for a set of N 10 | % upcoming calculations.
11 | % 12 | % PARFOR_PROGRESS updates the progress inside your parfor loop and 13 | % displays an updated progress bar. 14 | % 15 | % PARFOR_PROGRESS(0) deletes parfor_progress.txt and finalizes progress 16 | % bar. 17 | % 18 | % To suppress output from any of these functions, just ask for a return 19 | % variable from the function calls, like PERCENT = PARFOR_PROGRESS which 20 | % returns the percentage of completion. 21 | % 22 | % Example: 23 | % 24 | % N = 100; 25 | % parfor_progress(N); 26 | % parfor i=1:N 27 | % pause(rand); % Replace with real code 28 | % parfor_progress; 29 | % end 30 | % parfor_progress(0); 31 | % 32 | % See also PARFOR. 33 | 34 | % By Jeremy Scheff - jdscheff@gmail.com - http://www.jeremyscheff.com/ 35 | 36 | narginchk(0, 1); 37 | 38 | if nargin < 1 39 | N = -1; 40 | end 41 | 42 | percent = 0; 43 | w = 50; % Width of progress bar 44 | 45 | if N > 0 46 | f = fopen([tempdir,'parfor_progress.txt'], 'w'); 47 | if f<0 48 | error('Do you have write permissions for %s?', tempdir); 49 | end 50 | fprintf(f, '%d\n', N); % Save N at the top of parfor_progress.txt 51 | fclose(f); 52 | 53 | if nargout == 0 54 | disp([' 0%[>', repmat(' ', 1, w), ']']); 55 | end 56 | elseif N == 0 57 | delete([tempdir,'parfor_progress.txt']); 58 | percent = 100; 59 | 60 | if nargout == 0 61 | disp([repmat(char(8), 1, (w+9)), char(10), '100%[', repmat('=', 1, w+1), ']']); 62 | end 63 | else 64 | if ~exist([tempdir,'parfor_progress.txt'], 'file') 65 | error('parfor_progress.txt not found.
Run PARFOR_PROGRESS(N) before PARFOR_PROGRESS to initialize parfor_progress.txt.'); 66 | end 67 | 68 | f = fopen([tempdir,'parfor_progress.txt'], 'a'); 69 | fprintf(f, '1\n'); 70 | fclose(f); 71 | 72 | f = fopen([tempdir,'parfor_progress.txt'], 'r'); 73 | progress = fscanf(f, '%d'); 74 | fclose(f); 75 | percent = (length(progress)-1)/progress(1)*100; 76 | 77 | if nargout == 0 78 | perc = sprintf('%3.0f%%', percent); % 4 characters wide, percentage 79 | disp([repmat(char(8), 1, (w+9)), char(10), perc, '[', repmat('=', 1, round(percent*w/100)), '>', repmat(' ', 1, w - round(percent*w/100)), ']']); 80 | end 81 | end 82 | -------------------------------------------------------------------------------- /src/utils/performance_table.m: -------------------------------------------------------------------------------- 1 | function T = performance_table(aggregated_posteriors, labels, full_posterior, varargin) 2 | L = length(aggregated_posteriors); 3 | assert(L==length(labels)); 4 | metric_names = {'rmse_cov', 'rmse_var', 'rmse_mean', 'KL_true_vs_est', 'KL_est_vs_true', 'relative_l2_error'}; 5 | if isempty(varargin) 6 | metric_names = metric_names(1:end-1); 7 | end 8 | metrics = zeros(L, length(metric_names)); 9 | for l=1:L 10 | for i=1:length(metric_names) 11 | switch metric_names{i} 12 | case 'rmse_cov' 13 | metrics(l, i) = rmse_posterior_cov(aggregated_posteriors{l}, full_posterior); 14 | case 'rmse_var' 15 | metrics(l, i) = rmse_posterior_cov(aggregated_posteriors{l}, full_posterior, 'diag'); 16 | case 'rmse_mean' 17 | metrics(l, i) = rmse_posterior_mean(aggregated_posteriors{l}, full_posterior); 18 | case 'KL_true_vs_est' 19 | metrics(l, i) = approximate_KL(full_posterior, aggregated_posteriors{l}); 20 | case 'KL_est_vs_true' 21 | metrics(l, i) = approximate_KL(aggregated_posteriors{l}, full_posterior); 22 | case 'relative_l2_error' 23 | metrics(l, i) = relative_error(aggregated_posteriors{l}, full_posterior, varargin{1}); 24 | end 25 | end 26 | end 27 | T = 
array2table(metrics, 'VariableNames', metric_names, 'RowNames', labels); 28 | end -------------------------------------------------------------------------------- /src/utils/plot_marginal_compare.m: -------------------------------------------------------------------------------- 1 | function plot_marginal_compare(chains, labels, varargin) 2 | % compare posteriors by plotting kernel estimates of the marginal densities (every 10th sample) 3 | % Call 4 | % plot_marginal_compare({chain_1, chain_2, ...}, {name_1, name_2, ...}) 5 | % plot_marginal_compare({chain_1, chain_2, ...}, {name_1, name_2, ...}, [p1 6 | % p2]) if only a few dims are compared 7 | % plot_marginal_compare({chain_1, chain_2, ...}, {name_1, name_2, ...}, dims, [..true values of theta]) 8 | % to also annotate the true values (a vector of length p) 9 | 10 | k = length(chains); 11 | assert(k==length(labels)); 12 | [~, p] = size(chains{1}); 13 | true_values = []; 14 | if isempty(varargin) 15 | dims = 1:p; 16 | else 17 | dims = varargin{1}; 18 | assert(all(dims>=1) && all(dims<=p)); 19 | if length(varargin)>1 20 | true_values = varargin{2}; 21 | assert(length(true_values)==p); 22 | end 23 | end 24 | 25 | lineStyles = {'-','--',':','-.'}; 26 | a = floor(sqrt(length(dims))); 27 | b = ceil(length(dims)/a); 28 | for jj=1:length(dims) 29 | % plot dim j 30 | j = dims(jj); 31 | cnt = 0; 32 | subplot(a,b,jj); 33 | hold on; 34 | for c=1:k 35 | [f, x] = ksdensity(chains{c}(1:10:end,j)); 36 | if cnt==0 37 | % plot the first (reference) chain in red 38 | plot(x, f, '-r', 'DisplayName', labels{c}, 'LineWidth', 2); 39 | else 40 | plot(x, f, 'DisplayName', labels{c}, 'LineWidth', 2, 'LineStyle', lineStyles{mod(cnt, length(lineStyles))+1}); 41 | end 42 | cnt = cnt + 1; 43 | end 44 | if ~isempty(true_values) 45 | % also show the true values 46 | line([true_values(j), true_values(j)], ylim, 'Color','red', 'LineStyle', '--'); 47 | end 48 | 49 | if j==1, legend('-DynamicLegend', 'Location', 'best'); end 50 | title(['parameter ',num2str(j)]); 51 | hold off; 52 | end 53 | end 54 | 55 |
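`approximate_KL`, which `performance_table` uses for its `KL_true_vs_est` and `KL_est_vs_true` columns, does not estimate the divergence from the samples directly: it plugs the sample mean and covariance of each set into the closed-form KL divergence between two Gaussians. As a sanity check of that expression, the sketch below evaluates it for two hypothetical Gaussians specified by their moments, `N([0 0], I)` and `N([1 0], 2I)`, where the exact value is `0.5*(1 + 0.5 - 2 + log(4)) = log(2) - 1/4`. It swaps the explicit `inv` for backslash solves, which is an illustrative choice here, not what the repo's function does.

```matlab
% Hypothetical moments standing in for mean(samples) and cov(samples)
mu_left  = [0 0];  Sigma_left  = eye(2);
mu_right = [1 0];  Sigma_right = 2 * eye(2);

p    = numel(mu_left);
d_mu = mu_right - mu_left;
% KL(N_left || N_right): same terms as approximate_KL, using backslash
% solves (Sigma_right \ X) instead of forming inv(Sigma_right)
kl = 0.5 * (trace(Sigma_right \ Sigma_left) ...
            + d_mu * (Sigma_right \ d_mu') ...
            - p + log(det(Sigma_right)) - log(det(Sigma_left)));

fprintf('KL = %.6f (closed form: %.6f)\n', kl, log(2) - 0.25);
```

For high-dimensional or badly scaled posteriors, `log(det(.))` can underflow or overflow; a Cholesky-based log-determinant, like the `logdet` helper in `logmvnpdf.m`, is the more robust choice.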
-------------------------------------------------------------------------------- /src/utils/plot_pairwise_compare.m: -------------------------------------------------------------------------------- 1 | function plot_pairwise_compare(chains, labels, varargin) 2 | % compare posteriors with pairwise scatter plots of samples (at most 1000 per chain) 3 | % Call 4 | % plot_pairwise_compare({chain_1, chain_2, ...}, {name_1, name_2, ...}) 5 | % plot_pairwise_compare({chain_1, chain_2, ...}, {name_1, name_2, ...}, [p1 6 | % p2]) if only a few dims are compared 7 | 8 | k = length(chains); 9 | assert(k==length(labels)); 10 | [~, p] = size(chains{1}); 11 | if isempty(varargin) 12 | dims = 1:p; 13 | else 14 | dims = varargin{1}; 15 | assert(all(dims>=1) && all(dims<=p)); 16 | end 17 | 18 | a = length(dims); 19 | for u=1:length(dims) 20 | i = dims(u); 21 | for v=1:length(dims) 22 | j = dims(v); 23 | h = subplot(a,a,length(dims)*(u-1)+v); 24 | if i==j 25 | text(0.5, 0.5, ['parameter ',num2str(j)], 'Parent', h); 26 | else 27 | hold on; 28 | for c=k:-1:1 29 | if size(chains{c},1)>1000 30 | x = randsample(size(chains{c},1), 1000); 31 | else 32 | x = 1:size(chains{c},1); 33 | end 34 | if c==1 35 | plot(chains{c}(x,i), chains{c}(x,j), 'k.', 'MarkerSize', 1, 'DisplayName', labels{c}); 36 | else 37 | plot(chains{c}(x,i), chains{c}(x,j), 'x', 'MarkerSize', 2, 'DisplayName', labels{c}); 38 | end 39 | end 40 | hold off; 41 | if i==1 && j==2 42 | legend('-DynamicLegend', 'Location', 'best'); 43 | end 44 | xlabel(['parameter ',num2str(i)]); 45 | ylabel(['parameter ',num2str(j)]); 46 | end 47 | end 48 | end 49 | end 50 | 51 | -------------------------------------------------------------------------------- /src/utils/plot_tree_blocks.m: -------------------------------------------------------------------------------- 1 | function plot_tree_blocks(tree, sub_chain, varargin) 2 | % visualize the blocks estimated by a tree by drawing pairwise block 3 | % structures 4 | % Call 5 | % plot_tree_blocks(treeNode) 6 | %
plot_tree_blocks(tree_sampler, []) 7 | % plot_tree_blocks(tree_sampler, sub_chain) 8 | % plot_tree_blocks(tree_sampler, sub_chain, [p1 ... pn]), if only a few dimensions are plotted 9 | 10 | 11 | M = length(sub_chain); 12 | cc1 = hsv(M); 13 | cc2 = flipud(gray(1000)); 14 | 15 | if isa(tree, 'treeNode') 16 | % get its leaf nodes 17 | else 18 | max_logdens = 0; 19 | n_leaf = length(tree); 20 | for t=1:n_leaf 21 | if tree{t}.log_density>max_logdens 22 | max_logdens = tree{t}.log_density; 23 | end 24 | end 25 | 26 | assert(isa(tree{1}, 'treeNode'), 'wrong input'); 27 | fprintf('Plotting %d blocks...\n', n_leaf); 28 | % tree is a set of leaf nodes (samplers) 29 | p = size(tree{1}.area, 1); 30 | if ~isempty(varargin) 31 | dims = varargin{1}; 32 | assert(all(dims>=1) && all(dims<=p)); 33 | else 34 | dims = 1:p; 35 | end 36 | a = length(dims); 37 | if a==2 38 | a=1; 39 | end 40 | % pairwise plots 41 | for i=dims 42 | for j=dims 43 | if a>1 44 | hs = subplot(a,a,(i-1)*a+j); 45 | end 46 | axis off; 47 | hold on; 48 | if i==j && a>1 49 | text(0.5, 0.5, ['parameter ',num2str(j)], 'Parent', hs); 50 | elseif i20 68 | kvec = randsample(B, 20)'; 69 | else 70 | kvec = 1:B; 71 | end 72 | for k=kvec 73 | c = tree{t}.point(k,2); 74 | internal_index = tree{t}.point(k,1); 75 | points(k,1) = sub_chain{c}(internal_index, i); 76 | points(k,2) = sub_chain{c}(internal_index, j); 77 | end 78 | plot(points(:,1), points(:,2), '.', 'color', cc1(c,:), 'MarkerSize', 0.1); 79 | else 80 | % draw filled rectangle 81 | rectangle('Position',[x y w h],... 82 | 'FaceColor', cc2(floor(dens * size(cc2,1))+1,:), ... 
83 | 'LineWidth', 0.5); 84 | % rectangle('Position',[x y w h], 'LineWidth', 0.1); 85 | end 86 | end 87 | end 88 | hold off; 89 | xlabel(['parameter ', num2str(i)]); 90 | ylabel(['parameter ', num2str(j)]); 91 | end 92 | end 93 | end 94 | 95 | end -------------------------------------------------------------------------------- /src/utils/relative_error.m: -------------------------------------------------------------------------------- 1 | function re = relative_error(samples, full_chain, theta) 2 | % re = relative_error(samples, full_chain, theta) 3 | % ||posterior_theta - theta||_2 / ||posterior_theta (full chain) - theta||_2 4 | theta = theta(:)'; 5 | assert(size(samples,2)==length(theta) && size(full_chain,2)==length(theta)); 6 | error_full_chain = mean((full_chain - repmat(theta, size(full_chain,1), 1)).^2, 1); 7 | error_samples = mean((samples - repmat(theta, size(samples, 1), 1)).^2, 1); 8 | re = sqrt(sum(error_samples) / sum(error_full_chain)); 9 | end 10 | 11 | -------------------------------------------------------------------------------- /src/utils/rmse_posterior_cov.m: -------------------------------------------------------------------------------- 1 | function rmse = rmse_posterior_cov(samples_left, samples_right, varargin) 2 | % compute the rmse of posterior cov between two samples 3 | assert(size(samples_left,2)==size(samples_right,2),'dim not match'); 4 | cov_left = cov(samples_left); 5 | cov_right = cov(samples_right); 6 | errors = (cov_left - cov_right).^2; 7 | if ~isempty(varargin) && strcmp(varargin{1},'diag') 8 | errors = diag(errors); 9 | end 10 | rmse = sqrt(mean(errors(:))); 11 | end -------------------------------------------------------------------------------- /src/utils/rmse_posterior_mean.m: -------------------------------------------------------------------------------- 1 | function rmse = rmse_posterior_mean(samples_left, samples_right) 2 | % rmse of the posterior mean of two samples 3 | 
assert(size(samples_left,2)==size(samples_right,2),'dim not match'); 4 | mean_left = mean(samples_left,1); 5 | mean_right = mean(samples_right,1); 6 | rmse = sqrt(mean((mean_left - mean_right).^2)); 7 | end 8 | 9 | -------------------------------------------------------------------------------- /src/utils/thinning.m: -------------------------------------------------------------------------------- 1 | function output_chains = thinning(chains, N, varargin) 2 | % Call: 3 | % output_chain = thinning(chain, N) 4 | % {chain_1, chain_2,..} = thinning({chain_1, chain_2,..}, N) 5 | % chains = thinning(chains, N, 'burn', burn_in) 6 | % 7 | % thin a chain (or a cell of chains) to the desired length N, 8 | % optionally after discarding a burn-in period at head 9 | 10 | burn_in = 0; 11 | if ~isempty(varargin) && strcmp(varargin{1},'burn') 12 | burn_in = varargin{2}; 13 | end 14 | 15 | if isa(chains, 'double') 16 | chains = chains(burn_in+1:end, :); 17 | if size(chains,1)>N 18 | idx = floor(linspace(1,size(chains,1),N)); 19 | output_chains = chains(idx, :); 20 | else 21 | output_chains = chains; 22 | end 23 | else 24 | assert(isa(chains, 'cell'), 'wrong chain'); 25 | output_chains = cell(size(chains)); 26 | for m=1:numel(chains) 27 | chains{m} = chains{m}(burn_in+1:end, :); 28 | if size(chains{m},1)>N 29 | idx = floor(linspace(1,size(chains{m},1),N)); 30 | output_chains{m} = chains{m}(idx, :); 31 | else 32 | output_chains{m} = chains{m}; 33 | end 34 | end 35 | end 36 | 37 | end 38 | 39 | --------------------------------------------------------------------------------
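`thinning` first discards `burn_in` rows from the head of a chain and then keeps `N` evenly spaced rows chosen by `floor(linspace(1, n, N))`. The sketch below traces that index arithmetic on a hypothetical 12-row chain with `burn_in = 2` and `N = 4`. Because thinning is only applied when `n > N`, consecutive `linspace` points are more than 1 apart, so `floor` never yields duplicate row indices.

```matlab
% Hypothetical chain: 12 draws of a 2-D parameter, row r is [r, 10*r]
chain   = (1:12)' * [1 10];
burn_in = 2;
N       = 4;

kept    = chain(burn_in+1:end, :);             % drop burn-in; rows 3..12 remain (n = 10)
idx     = floor(linspace(1, size(kept,1), N)); % evenly spaced row indices into kept
thinned = kept(idx, :);

disp(idx)            % 1 4 7 10
disp(thinned(:,1)')  % 3 6 9 12, i.e. original rows 3, 6, 9 and 12
```

Calling `thinning(chain, 4, 'burn', 2)` on the same matrix is expected to return these four rows.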