├── .gitignore
├── src
│   ├── means.jl
│   ├── datasets.jl
│   ├── scorefunctions.jl
│   ├── utilFunctions.jl
│   ├── nodeFunctions.jl
│   ├── finetuning.jl
│   ├── spnmatrix.jl
│   ├── DeepStructuredMixtures.jl
│   ├── optimisers.jl
│   ├── AdvancedCholeskey.jl
│   ├── optimize.jl
│   ├── gaussianprocess.jl
│   ├── regionGraphUtils.jl
│   ├── plot.jl
│   ├── kernels.jl
│   ├── common.jl
│   ├── fit.jl
│   └── treeStructure.jl
├── Project.toml
└── README.md
/.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.tar 3 | *.tar.gz 4 | *.ipynb_checkpoints 5 | Manifest.toml 6 | *.svg 7 | -------------------------------------------------------------------------------- /src/means.jl: -------------------------------------------------------------------------------- 1 | export MeanFunction, ConstMean 2 | 3 | # Mean functions for a GP 4 | 5 | abstract type MeanFunction end 6 | 7 | struct ConstMean{T<:AbstractFloat} <: MeanFunction 8 | m::T 9 | end 10 | 11 | function apply_subtract!(m::ConstMean, y::AbstractVector, yout::AbstractVector) 12 | map!(i -> i - m.m, yout, y) 13 | return yout 14 | end 15 | 16 | function get(m::ConstMean{T}, N::Int) where {T} 17 | return ones(T,N)*m.m 18 | end 19 | -------------------------------------------------------------------------------- /src/datasets.jl: -------------------------------------------------------------------------------- 1 | using StatsFuns 2 | 3 | export nonstationary 4 | 5 | function nonstationary(n; σ²=0.4) 6 | 7 | # Create toy data 8 | x = range(-200, stop = 200, length = n); 9 | 10 | f1 = 3.0*sin.(-3 .+ 0.2.*x[1:Int(ceil(0.25*n))]) 11 | f1 = vcat(f1, 0*sin.(0.1*x[Int(ceil(0.25*n))+1:Int(ceil(0.75*n))])) 12 | f1 = vcat(f1, 3.0*sin.(2.8 .+ 0.2.*x[Int(ceil(0.75*n)) .+ 1:end])) 13 | 14 | f2 = 100*normpdf.(110, 20, x) + 100*normpdf.(-10, 20, x) 15 | 16 | x = x .- mean(x) 17 | x = x / std(x) 18 | f1 = f1 .- mean(f1) 19 | f1 = f1 / std(f1) 20 | 21 | noise = sqrt.((σ².*exp.(f2))) 22 | y = f1 + noise.*randn(size(x)) 23 | x=x[:]*10 24 | y=y[:]; 25 | 26 | return reshape(x,:,1), y, noise 27 | end 28 | -------------------------------------------------------------------------------- /src/scorefunctions.jl: -------------------------------------------------------------------------------- 1 | export mse, sse 2 | export mae, sae 3 | export nlpd 4 | 5 | # squared error, mean squared error and standard error of mean squared error 6 | @inline se(y_true, y_pred) = (y_true - y_pred).^2 7 | @inline mse(y_true, y_pred) = mean(se(y_true, y_pred)) 8 | @inline sse(y_true, y_pred) = std(se(y_true, y_pred)) / sqrt(size(y_true,1)) 9 | 10 | # absolute error, mean absolute error and standard error of mean absolute error 11 | @inline ae(y_true, y_pred) = abs.(y_true - y_pred) 12 | @inline mae(y_true, y_pred) = mean(ae(y_true, y_pred)) 13 | @inline sae(y_true, y_pred) = std(ae(y_true, y_pred)) / sqrt(size(y_true,1)) 14 | 15 | # negative log predictive density 16 | @inline nlpd(y_true::AbstractVector, μ::AbstractVector, σ²::AbstractVector) = -mapreduce(i -> logpdf(Normal(μ[i], sqrt(σ²[i])), y_true[i]), +, 1:length(y_true))/length(y_true)
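# Example (a sketch, assuming Distributions and Statistics are loaded as in the
# rest of the package): all score functions operate on plain vectors, and nlpd
# additionally expects the predictive variances.
#
#   y_true = randn(100)
#   y_pred = y_true .+ 0.1 .* randn(100)
#   mse(y_true, y_pred)                    # mean squared error
#   sae(y_true, y_pred)                    # standard error of the absolute error
#   nlpd(y_true, y_pred, fill(0.01, 100))  # negative log predictive density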
"31c24e10-a181-5473-b8eb-7969acd0382f" 10 | Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" 11 | LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" 12 | Optim = "429524aa-4258-5aef-a3af-852621145aeb" 13 | ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" 14 | Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" 15 | RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" 16 | Reexport = "189a3867-3050-52da-a836-e630ba90ab69" 17 | Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" 18 | StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" 19 | StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" 20 | SumProductNetworks = "5f6e642e-680c-5a9c-a175-7c23ed4da89e" 21 | -------------------------------------------------------------------------------- /src/utilFunctions.jl: -------------------------------------------------------------------------------- 1 | function getAllSplits(spn) 2 | 3 | splitNodes = filter(n -> isa(n, FiniteSplitNode), SumProductNetworks.getOrderedNodes(spn)) 4 | allSplits = Dict{Int, Vector{Vector{Float64}}}() 5 | 6 | for splitNode in splitNodes 7 | d = depth(splitNode) 8 | if !haskey(allSplits, d) 9 | allSplits[d] = Vector{Vector{Float64}}(0) 10 | end 11 | 12 | push!(allSplits[d], splitNode.split) 13 | end 14 | 15 | return allSplits 16 | end 17 | 18 | function getSplits(spn, minDepth) 19 | 20 | splitNodes = filter(n -> isa(n, FiniteSplitNode), SumProductNetworks.getOrderedNodes(spn)) 21 | allSplits = Dict{Int, Vector{Vector{Float64}}}() 22 | 23 | for splitNode in splitNodes 24 | d = depth(splitNode) 25 | 26 | if d >= minDepth 27 | 28 | if !haskey(allSplits, d) 29 | allSplits[d] = Vector{Vector{Float64}}(0) 30 | end 31 | 32 | push!(allSplits[d], splitNode.split) 33 | end 34 | end 35 | 36 | return allSplits 37 | end 38 | 39 | function plotSplits!(plt, splits) 40 | depths = sort(collect(keys(splits))) 41 | for d in depths 42 | vline!(plt, [s[1] for s in splits[d]], label = "depth $(d) splits") 43 | end 44 | plt 45 | end -------------------------------------------------------------------------------- /src/nodeFunctions.jl: -------------------------------------------------------------------------------- 1 | function add!(parent::GPSumNode, child::SPNNode) 2 | if !(child in parent.children) 3 | push!(parent.children, child) 4 | push!(parent.prior_weights, 1.) 5 | push!(parent.posterior_weights, 1.) 6 | push!(child.parents, parent) 7 | 8 | parent.prior_weights ./= sum(parent.prior_weights) 9 | parent.posterior_weights ./= sum(parent.posterior_weights) 10 | end 11 | 12 | @assert sum(parent.prior_weights) ≈ 1. "Weights should sum up to one, sum(w) = $(sum(parent.prior_weights))" 13 | @assert sum(parent.posterior_weights) ≈ 1. 
"Weights should sum up to one, sum(w) = $(sum(parent.prior_weights))" 14 | end 15 | 16 | function add!(parent::FiniteSplitNode, child::SPNNode) 17 | if !(child in parent.children) 18 | push!(parent.children, child) 19 | push!(child.parents, parent) 20 | end 21 | end 22 | 23 | Base.show(io::IO, n::GPLeaf) = 24 | print(io, "Gaussian Process Leaf Node [ID: ", n.id, ", LLH: ", round(n.gp.target, 3), "]") 25 | 26 | Base.show(io::IO, n::FiniteSplitNode) = 27 | print(io, "Split (Product) Node [ID: ", n.id, ", split: ", round.(n.split, 3), "]") 28 | 29 | Base.show(io::IO, n::GPSumNode) = 30 | print(io, "Gaussian Process Sum Node [ID: ", n.id, ", \n\t w_prior: ", 31 | round.(n.prior_weights, 3), ", \n\t w_posterior: ", 32 | round.(n.posterior_weights, 3), "]") 33 | 34 | function isDirty(n::GPSumNode) 35 | return n.posteriorDirty 36 | end -------------------------------------------------------------------------------- /src/finetuning.jl: -------------------------------------------------------------------------------- 1 | export finetune! 2 | 3 | function finetune!(model::DSMGP, optim; iterations = 1000, λ = 0.5) 4 | return finetune!(model.root, model.D, model.gpmap, optim, iterations=iterations, λ=λ) 5 | end 6 | 7 | function finetune!(spn::Union{GPSumNode,GPSplitNode}, D::AbstractMatrix, gpmap, optim; 8 | iterations = 1000, 9 | λ = 0.5 # early stopping 10 | ) 11 | 12 | gp = leftGP(spn) 13 | 14 | n = gp isa Array ? sum(map(sum, nparams.(gp))) : sum(nparams(gp)) 15 | grad = zeros(n) 16 | 17 | nodes = SumProductNetworks.getOrderedNodes(spn); 18 | ids = map(n -> n.id, nodes) 19 | 20 | gps = filter(n -> n isa GPNode, nodes); 21 | 22 | hyp = Dict(gp.id => reduce(vcat, params(gp.dist, logscale=true)) for gp in gps) 23 | 24 | L = AxisArray(zeros(length(ids)), ids) 25 | l = 0.0 26 | δ = Inf 27 | c = 0 28 | ℓ = zeros(iterations) 29 | 30 | Dd = copy(D) 31 | Dd[diagind(Dd)] .= 1.0 32 | 33 | p = Progress(iterations, 1, "Training...") 34 | for iteration in 1:iterations 35 | 36 | l = 0.0 37 | for gp in gps 38 | hyp_ = hyp[gp.id] 39 | 40 | # set the parameter 41 | setparams!(spn, hyp_) 42 | 43 | # fit model 44 | fit!(spn, D, gpmap) 45 | 46 | # compute mll 47 | fill!(L, 0.0) 48 | mll!(spn, L) 49 | 50 | updategradients!(spn) 51 | l += L[gp.id] 52 | 53 | fill!(grad, 0.0) 54 | ∇mll!(spn, 0.0, 0.0, L, L[spn.id], grad, view(D, gpmap.x[gp.id], :), gpmap) 55 | Flux.Optimise.apply!(optim, hyp_, grad) 56 | hyp[gp.id] += grad 57 | end 58 | 59 | ℓ[iteration] = l 60 | 61 | δ = iteration > 10 ? abs(ℓ[(iteration)] - mean(ℓ[(iteration-9):(iteration-1)])) : Inf 62 | ProgressMeter.next!(p; showvalues = [(:iter,iteration), (:delta,δ), (:c,c), (:llh,L[spn.id])]) 63 | 64 | # early stopping 65 | if δ < λ 66 | c += 1 67 | else 68 | c = 0 69 | end 70 | 71 | if c >= 10 72 | @info "Early stopping at iteration $iteration with δ: $δ" 73 | 74 | for gp in gps 75 | setparams!(gp.dist, hyp[gp.id]) 76 | update_cholesky!(gp.dist) 77 | end 78 | return spn, ℓ 79 | end 80 | end 81 | 82 | for gp in gps 83 | setparams!(gp.dist, hyp[gp.id]) 84 | update_cholesky!(gp.dist) 85 | end 86 | 87 | return spn, ℓ 88 | end 89 | 90 | -------------------------------------------------------------------------------- /src/spnmatrix.jl: -------------------------------------------------------------------------------- 1 | export SDiagonal 2 | export copyvec 3 | 4 | using LinearAlgebra 5 | import LinearAlgebra: lmul! 6 | import Base.fill! 
7 | 8 | struct SDiagonal{T<:Real,N,MT<:AbstractArray{T}} <: AbstractArray{T,N} 9 | Ix :: Vector{Dict{Int,Int}} 10 | V :: Vector{MT} 11 | end 12 | 13 | function SDiagonal() 14 | return SDiagonal{Float64,1,Vector{Float64}}(Vector{Dict{Int,Int}}(), Vector{Vector{Float64}}()) 15 | end 16 | 17 | function SDiagonal(Ix::Vector{Vector{Int}}, V::Vector{MT}) where {T<:Real,N,MT<:AbstractArray{T,N}} 18 | Ixdict = Vector{Dict{Int,Int}}(undef,length(Ix)) 19 | Vsparse = similar(V) 20 | for i in Base.axes(Ix,1) 21 | I = Ix[i] 22 | @assert size(V[i],1) == length(I) 23 | Vsparse[i] = N < 2 ? copy(V[i][I .>= i]) : copy(V[i][I .>= i,:]) 24 | II = @view I[I .>= i] 25 | Ixdict[i] = Dict{Int,Int}(j => k for (k,j) in enumerate(II)) 26 | end 27 | SDiagonal{T,N+1,MT}(Ixdict, Vsparse) 28 | end 29 | 30 | function Base.getindex(SD::SDiagonal{T,2,MT}, I::Vararg{Int,2}) where {T,MT<:AbstractArray} 31 | i,j = I 32 | i,j = i>j ? (j,i) : (i,j) 33 | return haskey(SD.Ix[i],j) ? SD.V[i][SD.Ix[i][j]] : zero(T) 34 | end 35 | 36 | function Base.getindex(SD::SDiagonal{T,3,MT}, I::Vararg{Int,3}) where {T,MT<:AbstractArray} 37 | i,j,k = I 38 | i,j = i>j ? (j,i) : (i,j) 39 | return haskey(SD.Ix[i],j) ? SD.V[i][SD.Ix[i][j],k] : zero(T) 40 | end 41 | 42 | function Base.setindex!(SD::SDiagonal{T,2,MT}, v::Tv, I::Vararg{Int,2}) where {T,MT<:AbstractVector,Tv<:Number} 43 | i,j = I 44 | i,j = i>j ? (j,i) : (i,j) 45 | if haskey(SD.Ix[i],j) 46 | SD.V[i][SD.Ix[i][j]] = v 47 | end 48 | end 49 | 50 | function Base.setindex!(SD::SDiagonal{T,3,MT}, v::Tv, I::Vararg{Int,3}) where {T,MT<:AbstractArray,Tv<:Number} 51 | i,j,k = I 52 | i,j = i>j ? (j,i) : (i,j) 53 | SD.V[i][SD.Ix[i][j],k] = v 54 | end 55 | 56 | function Base.size(SD::SDiagonal{T,N,MT}) where {T,N,MT<:AbstractArray} 57 | return (length(SD.Ix), length(SD.Ix), size(SD.V[1])[2:(N-1)]...) 
58 | end 59 | 60 | function LinearAlgebra.lmul!(a::T, B::SDiagonal{TB,N,MT}) where {T<:Number,TB,N,MT} 61 | for i in eachindex(B.V) 62 | @inbounds lmul!(a, B.V[i]) 63 | end 64 | end 65 | 66 | function Base.fill!(A::SDiagonal{T,N,MT}, b::Tb) where {T,N,MT,Tb<:Number} 67 | for i in eachindex(A.V) 68 | @inbounds fill!(A.V[i],b) 69 | end 70 | end 71 | 72 | function Base.copy(M::SDiagonal{T,N,MT}) where {T,N,MT} 73 | V = deepcopy(M.V) 74 | Ix = deepcopy(M.Ix) 75 | return SDiagonal{T,N,MT}(Ix,V) 76 | end 77 | 78 | function Base.zero(M::SDiagonal{T,N,MT}) where {T,N,MT} 79 | Ix = deepcopy(M.Ix) 80 | V = deepcopy(M.V) 81 | for i in eachindex(V) 82 | @inbounds fill!(V[i],zero(T)) 83 | end 84 | return SDiagonal{T,N,MT}(Ix,V) 85 | end 86 | 87 | function copyvec(M::SDiagonal{T,2,MT}, dim::Int) where {T,MT} 88 | return deepcopy(M) 89 | end 90 | 91 | function copyvec(M::SDiagonal{T,3,MT}, dim::Int) where {T,MT} 92 | V = map(v -> copy(vec(v[:,dim])), M.V) 93 | return SDiagonal{T,2,typeof(V[1])}(deepcopy(M.Ix),V) 94 | end 95 | -------------------------------------------------------------------------------- /src/DeepStructuredMixtures.jl: -------------------------------------------------------------------------------- 1 | module DeepStructuredMixtures 2 | 3 | using Reexport 4 | @reexport using SumProductNetworks 5 | @reexport using Flux 6 | using RecipesBase 7 | using Distributions 8 | using StatsFuns 9 | using LinearAlgebra 10 | using AxisArrays 11 | 12 | import Base.rand 13 | import Base.get 14 | 15 | import SumProductNetworks.scope 16 | import SumProductNetworks.hasscope 17 | import SumProductNetworks.hasobs 18 | import SumProductNetworks.params 19 | 20 | import StatsBase.params 21 | 22 | export GPSumNode, GPSplitNode, GPNode, DSMGPConfig 23 | export DSMGP, PoE, gPoE, rBCM 24 | export getchild, leftGP, rightGP, predict, getx, gety, rand, 25 | buildTree, optimize!, stats, resample!, optim!, target 26 | 27 | const ϵ = 1e-8 28 | 29 | include("AdvancedCholeskey.jl") 30 | 31 | # custom block-diagonal matrix type 32 | include("spnmatrix.jl") 33 | 34 | # GP related codes 35 | include("means.jl") 36 | include("kernels.jl") 37 | include("gaussianprocess.jl") 38 | 39 | # Type definitions 40 | struct GPSumNode{T<:Real,C<:SPNNode} <: SumNode{T} 41 | id::Symbol 42 | parents::Vector{<:Node} 43 | children::Vector{C} 44 | logweights::Vector{T} 45 | end 46 | 47 | function Base.show(io::IO, ::MIME"text/plain", m::GPSumNode) 48 | print(io, "GP Sum Node [$(m.id)] \n weights: ",exp.(m.logweights)) 49 | end 50 | Base.show(io::IO, m::GPSumNode) = print(io, "GPSumNode(",m.id,")") 51 | 52 | struct GPSplitNode <: ProductNode 53 | id::Symbol 54 | parents::Vector{<:Node} 55 | children::Vector{<:SPNNode} 56 | lowerBound::Vector{Float64} 57 | upperBound::Vector{Float64} 58 | split::Vector{Tuple{Int, Float64}} 59 | end 60 | 61 | struct GPNode <: Leaf 62 | id::Symbol 63 | parents::Vector{<:Node} 64 | dist::GaussianProcess 65 | observations::BitArray{1} 66 | obs::Vector{Int} 67 | lb::Vector{Float64} 68 | ub::Vector{Float64} 69 | nobs::Int 70 | kernelid::Int 71 | end 72 | 73 | 74 | @inline id(node::Node) = node.id 75 | @inline id(node::GPNode) = node.id 76 | 77 | params(node::GPNode) = params(node.dist) 78 | 79 | @inline hasscope(node::GPNode) = true 80 | @inline hasscope(node::GPSumNode) = true 81 | @inline hasscope(node::GPSplitNode) = true 82 | 83 | @inline scope(node::GPNode) = node.observations 84 | @inline scope(node::GPSplitNode) = mapreduce(scope, vcat, children(node)) 85 | @inline scope(node::GPSumNode) = scope(node[1])
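# hasscope, scope, and hasobs implement the node interface expected by
# SumProductNetworks.jl (imported above): every node type reports a scope, and
# a sum node inherits the scope of its first child (its children are assumed
# to share a common scope).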
86 | 87 | @inline hasobs(node::GPNode) = false 88 | @inline hasobs(node::GPSplitNode) = false 89 | @inline hasobs(node::GPSumNode) = false 90 | 91 | struct DSMGPConfig 92 | meanFun::Union{Nothing,MeanFunction} 93 | kernels::Union{KernelFunction, Vector{KernelFunction}} 94 | observationNoise::Float64 95 | minData::Int 96 | K::Int # number of splits per GPSplitNode 97 | V::Int # number of children under GPSumNode 98 | depth::Int # maximum depth (consecutive sum-product nodes) 99 | bnoise::Float64 # split noise 100 | sumRoot::Bool # use sum root 101 | end 102 | 103 | struct BiDict 104 | x::Dict 105 | fx::Dict 106 | end 107 | 108 | struct DSMGP{T<:Real} 109 | root::Node 110 | D::Matrix{T} 111 | gpmap::BiDict 112 | end 113 | 114 | struct PoE{T<:Real} 115 | root::GPSplitNode 116 | D::Matrix{T} 117 | gpmap::BiDict 118 | end 119 | 120 | struct gPoE{T<:Real} 121 | root::GPSplitNode 122 | D::Matrix{T} 123 | gpmap::BiDict 124 | end 125 | 126 | struct rBCM{T<:Real} 127 | root::GPSplitNode 128 | D::Matrix{T} 129 | gpmap::BiDict 130 | end 131 | 132 | # codes 133 | include("common.jl") 134 | include("treeStructure.jl") 135 | include("fit.jl") 136 | include("optimize.jl") 137 | include("optimisers.jl") 138 | include("finetuning.jl") 139 | 140 | # utilities 141 | include("plot.jl") 142 | include("datasets.jl") 143 | include("scorefunctions.jl") 144 | end 145 | -------------------------------------------------------------------------------- /src/optimisers.jl: -------------------------------------------------------------------------------- 1 | using ProgressMeter 2 | export train! 3 | 4 | function train!(model::Union{DSMGP,PoE,gPoE,rBCM}, optim; iterations = 10_000, λ = 0.05, randinit = true, earlystop = 10) 5 | train!(model.root, model.D, model.gpmap, optim, iterations=iterations, λ=λ, randinit = randinit, earlystop = earlystop) 6 | end 7 | 8 | function train!(spn::Union{GPSumNode,GPSplitNode}, D::AbstractMatrix, gpmap::BiDict, optim; 9 | iterations = 10_000, 10 | λ = 0.05, # early stopping 11 | sharedGradients = false, 12 | randinit = true, earlystop = 10 13 | ) 14 | 15 | gp = leftGP(spn) 16 | 17 | n = gp isa Array ? sum(map(sum, nparams.(gp))) : sum(nparams(gp)) 18 | hyp = randinit ? randn(n) : reduce(vcat, params(gp, logscale=true)) 19 | grad = zeros(n) 20 | 21 | nodes = SumProductNetworks.getOrderedNodes(spn) 22 | ids = map(n -> n.id, nodes) 23 | L = AxisArray(zeros(length(ids)), ids) 24 | 25 | p = Progress(iterations, 1, "Training...") 26 | 27 | c = 0 28 | δ = 0.0 29 | ℓ = zeros(iterations) 30 | 31 | P = SDiagonal() 32 | K = SDiagonal() 33 | if sharedGradients && (gp isa GaussianProcess) 34 | P = distancematrix(spn, gp.kernel, getx(spn)) 35 | K = copyvec(P,1) 36 | fill!(K, 0.0) 37 | kernelmatrix!(gp.kernel, K, P) 38 | end 39 | 40 | for iteration in 1:iterations 41 | 42 | # set the parameter 43 | setparams!(spn, hyp) 44 | 45 | # fit model 46 | fit!(spn, D, gpmap) 47 | 48 | # compute mll 49 | fill!(L, 0.0) 50 | mll!(spn, L) 51 | 52 | ℓ[iteration] = L[spn.id] 53 | δ = iteration > 10 ? 
abs(ℓ[(iteration)] - mean(ℓ[(iteration-9):(iteration-1)])) : Inf 54 | ProgressMeter.next!(p; showvalues = [(:iter,iteration), (:delta,δ), (:c,c), (:llh,L[spn.id])]) 55 | 56 | # early stopping 57 | if δ < λ 58 | c += 1 59 | else 60 | c = 0 61 | end 62 | 63 | if c >= earlystop 64 | @info "Early stopping at iteration $iteration with δ: $δ" 65 | return spn, ℓ[1:iteration] 66 | end 67 | 68 | if sharedGradients && (gp isa GaussianProcess) 69 | fill!(K, 0.0) 70 | kernelmatrix!(gp.kernel, K, P) 71 | updategradients!(spn, K, P, D, gpmap) 72 | else 73 | updategradients!(spn) 74 | end 75 | 76 | fill!(grad, 0.0) 77 | ∇mll!(spn, 0.0, 0.0, L, L[spn.id], grad) 78 | Flux.Optimise.apply!(optim, hyp, grad) 79 | hyp += grad 80 | end 81 | 82 | setparams!(spn, hyp) 83 | fit!(spn, D, gpmap) 84 | 85 | @info "Exit after $iterations iterations with δ: $δ" 86 | return spn, ℓ 87 | end 88 | 89 | function train!(gp::GaussianProcess; 90 | iterations = 10_000, 91 | optim = RMSProp(), 92 | λ = 0.1 # early stopping 93 | ) 94 | 95 | n = sum(nparams(gp)) 96 | hyp = randn(n) 97 | oldhyp = hyp 98 | grad = zeros(n) 99 | 100 | δ = 0.0 101 | ℓ = zeros(iterations) 102 | p = Progress(iterations, 1, "Training...") 103 | 104 | for iteration in 1:iterations 105 | 106 | # set the parameter 107 | setparams!(gp, hyp) 108 | 109 | # fit model 110 | update_cholesky!(gp) 111 | 112 | # compute mll 113 | ℓ[iteration] = mll(gp) 114 | 115 | if isnan(ℓ[iteration]) 116 | setparams!(gp, oldhyp) 117 | update_cholesky!(gp) 118 | return gp, ℓ[1:iteration] 119 | end 120 | 121 | δ = iteration > 10 ? abs(ℓ[(iteration)] - mean(ℓ[(iteration-9):(iteration-1)])) : Inf 122 | ProgressMeter.next!(p; showvalues = [(:iter,iteration), (:delta,δ), (:llh,ℓ[iteration])]) 123 | 124 | # early stopping 125 | if δ < λ 126 | @info "Early stopping at iteration $iteration with δ: $δ" 127 | return gp, ℓ[1:iteration] 128 | end 129 | 130 | updategradients!(gp) 131 | 132 | fill!(grad, 0.0) 133 | ∇mll!(gp, grad) 134 | 135 | oldhyp = copy(hyp) 136 | Flux.Optimise.apply!(optim, hyp, grad) 137 | hyp += grad 138 | end 139 | 140 | setparams!(gp, hyp) 141 | update_cholesky!(gp) 142 | 143 | @info "Exit after $iterations iterations with δ: $δ" 144 | return gp, ℓ 145 | end 146 | -------------------------------------------------------------------------------- /src/AdvancedCholeskey.jl: -------------------------------------------------------------------------------- 1 | """ 2 | # AdvancedCholesky module 3 | 4 | This module aims to implement extensions to existing Cholesky factorisations 5 | available in Julia. 
6 | 7 | """ 8 | module AdvancedCholesky 9 | using LinearAlgebra, Statistics, Random 10 | import LinearAlgebra.lowrankupdate 11 | 12 | genCov(D::Int; uplo::Symbol=:U) = Symmetric(rand(D,D) .+ Matrix(I*D,D,D), uplo) 13 | 14 | 15 | function lowrankupdate!(C::Cholesky, v::StridedVector, k::Int) 16 | lowrankupdate!(C.factors, v, k, C.uplo) 17 | return C 18 | end 19 | 20 | function lowrankupdate!(A::Matrix, v::StridedVector, k::Int, uplo::Char) 21 | @assert k > 0 22 | 23 | n = length(v) 24 | if (size(A,1)-(k-1)) != n 25 | throw(DimensionMismatch("updating vector must fit size of factorization")) 26 | end 27 | if uplo == 'U' 28 | conj!(v) 29 | end 30 | 31 | for i = k:n 32 | 33 | # Compute Givens rotation 34 | @inbounds c, s, r = LinearAlgebra.givensAlgorithm(A[i,i], v[i-(k-1)]) 35 | 36 | # Store new diagonal element 37 | @inbounds A[i,i] = r 38 | 39 | # Update remaining elements in row/column 40 | if uplo == 'U' 41 | @inbounds for j = i + 1:n 42 | Aij = A[i,j] 43 | vj = v[j-(k-1)] 44 | A[i,j] = c*Aij + s*vj 45 | v[j-(k-1)] = -s'*Aij + c*vj 46 | end 47 | else 48 | for j = i + 1:n 49 | @inbounds begin 50 | Aji = A[j,i] 51 | vj = v[j-(k-1)] 52 | A[j,i] = c*Aji + s*vj 53 | v[j-(k-1)] = -s'*Aji + c*vj 54 | end 55 | end 56 | end 57 | end 58 | return A 59 | end 60 | 61 | function lrtest(;D = 1000, uplo = :L) 62 | 63 | missing_rows = shuffle(1:(D-1))[1:10] 64 | P = D - length(missing_rows) 65 | 66 | @info "A = $D x $D , and B = $P x $P" 67 | @info "# rank-1 updates: $(length(missing_rows)/P)" 68 | 69 | # B does not contain column/row 3 70 | idx = setdiff(1:D, missing_rows) 71 | 72 | runs = 10 73 | 74 | t1 = zeros(runs) 75 | t2 = zeros(runs) 76 | err = zeros(runs) 77 | 78 | for r = 1:runs 79 | 80 | A = genCov(D, uplo = uplo) 81 | B = A[idx,idx] 82 | 83 | # Cholesky of A 84 | C = cholesky(A) 85 | 86 | # Cholesky of B (for testing) 87 | t1[r] = @elapsed CCt = cholesky(Symmetric(B, uplo)) 88 | 89 | # Copy Cholesky of A 90 | CC = deepcopy(C) 91 | 92 | # rank-1 update 93 | t2[r] = @elapsed begin 94 | @inbounds for r in missing_rows 95 | if uplo == :U 96 | lowrankupdate!(CC, view(CC.factors,r,(r+1):D), (r+1)) 97 | else 98 | lowrankupdate!(CC, view(CC.factors,(r+1):D,r), (r+1)) 99 | end 100 | end 101 | end 102 | 103 | # compute error 104 | F = Cholesky(CC.factors[idx, idx], uplo, 0) 105 | err[r] = sum(abs.(F.U .- CCt.U)) 106 | end 107 | 108 | @info "Nunmerical error (avg): $(mean(err))" 109 | @info "Time difference (avg): $(mean(t1 .- t2)) sec" 110 | end 111 | 112 | 113 | """ 114 | Run a simple test on `chol_continue!`. 115 | 116 | Return: 117 | * `t1`: time taken for LinearAlgebra.cholesky! 118 | * `t2`: time taken for LAPACK.potrf! and chol_continue! 119 | * `Δ`: difference between Cholesky factorisations 120 | """ 121 | function test_chol_continue() 122 | D = 100 123 | P = 10 124 | Σ = genCov(D, uplo = :L) 125 | U = deepcopy(Σ) 126 | A = deepcopy(Σ) 127 | 128 | t1 = @elapsed LinearAlgebra.cholesky!(U) 129 | t2 = @elapsed begin 130 | LAPACK.potrf!('L', view(A.data, 1:P, 1:P)) 131 | chol_continue!(A.data, P+1, LowerTriangular) 132 | end 133 | 134 | return t1 - t2, sum(abs.(LowerTriangular(A) .- LowerTriangular(U))) 135 | end 136 | 137 | """ 138 | # Continue a Cholesky decomposition of `A` from `ki` on. 139 | `chol_continue!(A::AbstractMatrix, ki::Int)` 140 | 141 | Currently only works if a A contains elements of a lower-triangular Cholsky decomposition. 
137 | """ 138 | # Continue a Cholesky decomposition of `A` from `ki` on. 139 | `chol_continue!(A::AbstractMatrix, ki::Int)` 140 | 141 | Currently only works if A contains elements of a lower-triangular Cholesky decomposition. 142 | 143 | ## Usage: 144 | 145 | ```julia 146 | A = genCov(10) 147 | LAPACK.potrf!('L', view(A.data, 1:5, 1:5)) 148 | chol_continue!(A.data, 5+1) 149 | ``` 150 | 151 | """ 152 | function chol_continue!(A::AbstractMatrix{T}, ki::Int) where {T<:LinearAlgebra.BlasFloat} 153 | 154 | # clean up 155 | # necessary? 156 | tril!(A) 157 | 158 | N = size(A,1)-ki 159 | # update lower matrix 160 | v = @view A[1:(ki-1),1:(ki-1)] 161 | @inbounds A[ki:end, 1:(ki-1)] /= v' 162 | 163 | # symmetric rank-k update 164 | # check if this is correct! 165 | v = @view A[ki:end,1:(ki-1)] 166 | C = @view A[ki:end,ki:end] 167 | BLAS.syrk!('L', 'N', -one(T), v, one(T), C) 168 | 169 | # solve Cholesky of remainder 170 | C = @view A[ki:end,ki:end] 171 | _, info = LAPACK.potrf!('L', C) 172 | 173 | return LowerTriangular(A), info 174 | end 175 | 176 | end # module 177 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Mixtures of Gaussian Processes 2 | 3 | This package implements Deep Structured Mixtures of Gaussian Processes (DSMGP) [1] in Julia 1.3. 4 | 5 | ## Installation 6 | To use this package you need Julia 1.3 installed on your machine. 7 | 8 | Inside the Julia REPL, you can install the package in the Pkg mode (type `]` in the REPL): 9 | ```julia 10 | pkg> add https://github.com/trappmartin/DeepStructuredMixtures 11 | ``` 12 | 13 | After the installation you can load the package using: 14 | ```julia 15 | using DeepStructuredMixtures 16 | ``` 17 | 18 | Note that the package will be compiled the first time you load it. 19 | 20 | ## Python bridge 21 | The package can be used from python using the excellent pyjulia package: https://github.com/JuliaPy/pyjulia 22 | 23 | ## Usage 24 | The following example explains the usage of `DeepStructuredMixtures`. Note that this example assumes that you have `Plots` installed in your Julia environment. 25 | 26 | First, we load the necessary libraries: 27 | ```julia 28 | using Plots 29 | using DeepStructuredMixtures 30 | using Random 31 | ``` 32 | 33 | Now we can create some synthetic data or load some real data: 34 | ```julia 35 | xtrain = collect(range(0, stop=1, length = 100)) 36 | ytrain = sin.(xtrain*4*pi + randn(100)*0.2) 37 | ``` 38 | 39 | We will now use a squared exponential kernel-function with a constant mean-function to fit the DSMGP. See API for more options. 40 | ```julia 41 | kernelf = IsoSE(1.0, 1.0) 42 | meanf = ConstMean(mean(ytrain)) 43 | ``` 44 |
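For multi-dimensional inputs, an ARD kernel with one lengthscale per input dimension can be used instead (a sketch; see the API section below for all kernel options):
```julia
kernelf = ArdSE(ones(2), 1.0)  # e.g. for two-dimensional inputs
```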
45 | Now we can construct a DSMGP on our data and find optimal hyperparameters. 46 | ```julia 47 | K = 4 # Number of splits per product node 48 | V = 3 # Number of children per sum node 49 | M = 10 # Minimum number of observations per expert 50 | 51 | model = buildDSMGP(reshape(xtrain,:,1), ytrain, V, K; M = M, kernel = kernelf, meanFun = meanf) 52 | train!(model, ADAM()) 53 | 54 | # finally, we perform exact posterior inference 55 | update!(model) 56 | ``` 57 | 58 | Note that for large data sets it is recommended to train the DSMGP with `V = 1` and use the hyper-parameters to initialise the training of a model with `V > 1`: 59 | ```julia 60 | model1 = buildDSMGP(reshape(xtrain,:,1), ytrain, 1, K; M = M, kernel = kernelf, meanFun = meanf) 61 | train!(model1, ADAM()) 62 | 63 | # get hyper-parameters 64 | hyp = reduce(vcat, params(leftGP(model1.root), logscale=true)) 65 | 66 | model = buildDSMGP(reshape(xtrain,:,1), ytrain, V, K; M = M, kernel = kernelf, meanFun = meanf) 67 | 68 | # set hyper-parameters instead of learning from scratch 69 | setparams!(model.root, hyp) 70 | train!(model, ADAM(), randinit = false) 71 | ``` 72 | 73 | Finally, we can plot the model: 74 | ```julia 75 | plot(model) 76 | ``` 77 | 78 | and use it for predictions: 79 | ```julia 80 | xtest = collect(range(0.5, stop=1.5, length = 100)) 81 | m, s = predict(model, reshape(xtest,:,1)) 82 | ``` 83 | 84 | Note that all methods assume that `xtrain` and `xtest` are matrices, which is why we use `reshape(xtest,:,1)` to reshape the respective vectors to a matrix. 85 | 86 | 87 | ## API 88 | 89 | #### Mean functions 90 | ```julia 91 | # A constant mean of zero aka zero-mean function. 92 | ConstMean(0.0) 93 | ``` 94 | 95 | #### Kernel functions 96 | ```julia 97 | # A squared exponential kernel-function with lengthscale 1 and std of 1. 98 | IsoSE(1.0, 1.0) 99 | 100 | # A squared exponential kernel-function with ARD and lengthscales of 1 and std of 1. 101 | ArdSE(ones(10), 1.0) 102 | 103 | # A linear kernel-function with lengthscale of 1. 104 | IsoLinear(1.0) 105 | 106 | # A linear kernel-function with ARD and lengthscales of 1. 107 | ArdLinear(ones(10)) 108 | 109 | # Composition of kernel-functions for inference over kernel-functions. 110 | KernelFunction[IsoSE(1.0, 1.0), IsoLinear(1.0)] 111 | ``` 112 | 113 | #### Models 114 | ```julia 115 | # An exact Gaussian process 116 | GaussianProcess(trainx, trainy, mean = meanf, kernel = kernelf) 117 | 118 | # A (generalized) product of experts (PoE) model with K splits per node and a minimum of M observations per expert 119 | buildPoE(trainx, trainy, K; generalized = true, M = M, kernel = kernelf, meanFun = meanf) 120 | 121 | # A (robust) Bayesian committee machine (BCM) model with K splits per node and a minimum of M observations per expert 122 | # ! Training not implemented ! 123 | buildrBCM(x, y, K; M = M, kernel = kernelf, meanFun = meanf) 124 | 125 | # A deep structured mixture of GPs (DSMGP) model with K splits per product node, V children per sum node and a minimum of M observations per expert. 126 | buildDSMGP(x, y, V, K; M = M, kernel = kernelf, meanFun = meanf) 127 | ``` 128 |
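For example, a single exact GP can be fit and queried on its own (a minimal sketch reusing the toy data from the usage example above):
```julia
gp = GaussianProcess(reshape(xtrain,:,1), ytrain, mean = meanf, kernel = kernelf)
train!(gp)                                 # optimise the hyper-parameters
m, S = prediction(gp, reshape(xtest,:,1))  # predictive mean and covariance
```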
129 | #### Training 130 | Note that DeepStructuredMixtures reexports Flux.jl and uses the optimisers available in Flux. We refer to the Flux.jl documentation of the available optimisers. 131 | 132 | ```julia 133 | # train a model for 1000 iterations using ADAM 134 | train!(model, ADAM(), iterations = 1_000) 135 | 136 | # fine-tune a model for 1000 iterations using ADAM 137 | finetune!(model, ADAM(), iterations = 1_000) 138 | 139 | # fit the posterior of a hierarchical model, e.g. a gPoE 140 | fit_naive!(model.root) 141 | 142 | # fit the posterior of a DSMGP using shared Cholesky 143 | fit!(model) 144 | ``` 145 | 146 | #### Prediction 147 | ```julia 148 | # make predictions using a model, i.e., compute the mean (m) and variance (s). 149 | m, s = predict(model, testx) 150 | 151 | # plot a model and the training data. 152 | plot(model) 153 | ``` 154 | 155 | ## Reference 156 | [1] Martin Trapp, Robert Peharz, Franz Pernkopf and Carl Edward Rasmussen: Deep Structured Mixtures of Gaussian Processes. To appear at the International Conference on Artificial Intelligence and Statistics (AISTATS), 2020. 157 | 158 | ## Acknowledgments 159 | This project received funding from the European Union's Horizon 2020 research and innovation programme under the Marie Sklodowska-Curie Grant Agreement No. 797223 (HYBSPN). 160 | -------------------------------------------------------------------------------- /src/optimize.jl: -------------------------------------------------------------------------------- 1 | using Base.Threads 2 | using Optim 3 | 4 | export getparams, set_params! 5 | export optimize_restarts!, rand_init! 6 | export ∇mll, mll, mll! 7 | 8 | function rand_init!(spn::Union{GPSplitNode, GPSumNode}, D) 9 | gp = leftGP(spn) 10 | n = gp isa Array ? sum(map(sum, nparams.(gp))) : sum(nparams(gp)) 11 | hyp = randn(n) 12 | 13 | setparams!(spn, hyp) 14 | fit!(spn, D) 15 | return spn 16 | end 17 | 18 | @inline mll(node::GPNode) = mll(node.dist) 19 | @inline mll(node::GPSplitNode) = mapreduce(mll, +, children(node)) 20 | function mll(node::GPSumNode) 21 | K = length(node) 22 | StatsFuns.logsumexp(map(c -> -log(K)+mll(c), children(node))) 23 | end 24 | 25 | @inline mll(model::Union{DSMGP,PoE}) = mll(model.root) 26 | 27 | function mll!(node::GPNode, ℓ::AxisArray) 28 | ℓ[node.id] = mll(node.dist) 29 | return ℓ[node.id] 30 | end 31 | function mll!(node::GPSplitNode, ℓ::AxisArray) 32 | ℓ[node.id] = mapreduce(c -> mll!(c, ℓ), +, children(node)) 33 | return ℓ[node.id] 34 | end 35 | function mll!(node::GPSumNode, ℓ::AxisArray) 36 | K = length(node) 37 | ℓ[node.id] = StatsFuns.logsumexp(map(c -> -log(K)+mll!(c, ℓ), children(node))) 38 | return ℓ[node.id] 39 | end 40 | 41 | # == gradient propagation (global) == 42 | @inline function ∇mll!(node::GPNode, 43 | ∇parent::Float64, 44 | lρ::Float64, 45 | ℓ::AxisArray, 46 | logS::Float64, 47 | ∇::AbstractVector) 48 | w = exp(-logS+lρ+ℓ[node.id]+∇parent) 49 | ∇[:] += ∇mll(node.dist)*w 50 | end 51 | 52 | function ∇mll!(node::GPSplitNode, 53 | ∇parent::Float64, 54 | lρ::Float64, 55 | ℓ::AxisArray, 56 | logS::Float64, 57 | ∇::AbstractVector) 58 | Threads.@threads for child = children(node) 59 | lp = ℓ[node.id] - ℓ[child.id] 60 | ∇mll!(child, ∇parent + lp, lρ, ℓ, logS, ∇) 61 | end 62 | end 63 | 64 | function ∇mll!(node::GPSumNode, 65 | ∇parent::Float64, 66 | lρ::Float64, 67 | ℓ::AxisArray, 68 | logS::Float64, 69 | ∇::AbstractVector) 70 | K = length(node) 71 | @inbounds for child in children(node) 72 | ∇mll!(child, -log(K)+∇parent, log(K)+lρ, ℓ, logS, ∇) 73 | end 74 | end 75 | 76 | function ∇mll!(node::GPSumNode{T,V}, 77 | ∇parent::Float64, 78 | lρ::Float64, 79 | ℓ::AxisArray, 80 | logS::Float64, 81 | ∇::AbstractVector) where {T<:AbstractFloat, V<:GPNode} 82 | K = length(node) 83 | c = 1 84 | @inbounds for k = 1:K 85 | n = sum(nparams(node[k].dist))-1 86 | ∇mll!(node[k], ∇parent, lρ, ℓ, logS, view(∇, c:(c+n)) ) 87 | c += n+1 88 | end 89 | end 90 | 91 | # == gradient propagation (finetune) == 92 | @inline function ∇mll!(node::GPNode, 93 | ∇parent::Float64, 94 |
lρ::Float64, 95 | ℓ::AxisArray, 96 | logS::Float64, 97 | ∇::AbstractVector, 98 | D::AbstractVector, 99 | gpmap::BiDict) 100 | w = exp(-logS+lρ+ℓ[node.id]+∇parent) 101 | ∇[:] += ∇mll(node.dist)*w*D[gpmap.x[node.id]] 102 | end 103 | 104 | function ∇mll!(node::GPSplitNode, 105 | ∇parent::Float64, 106 | lρ::Float64, 107 | ℓ::AxisArray, 108 | logS::Float64, 109 | ∇::AbstractVector, 110 | D::AbstractVector, 111 | gpmap::BiDict) 112 | K = length(node) 113 | 114 | Threads.@threads for k = 1:K 115 | lp = ℓ[node.id] - ℓ[node[k].id] 116 | ∇mll!(node[k], ∇parent + lp, lρ, ℓ, logS, ∇, D, gpmap) 117 | end 118 | end 119 | 120 | function ∇mll!(node::GPSumNode, 121 | ∇parent::Float64, 122 | lρ::Float64, 123 | ℓ::AxisArray, 124 | logS::Float64, 125 | ∇::AbstractVector, 126 | D::AbstractVector, 127 | gpmap::BiDict) 128 | K = length(node) 129 | @inbounds for child in children(node) 130 | ∇mll!(child, -log(K)+∇parent, log(K)+lρ, ℓ, logS, ∇, D, gpmap) 131 | end 132 | end 133 | 134 | function ∇mll!(node::GPSumNode{T,V}, 135 | ∇parent::Float64, 136 | lρ::Float64, 137 | ℓ::AxisArray, 138 | logS::Float64, 139 | ∇::AbstractVector, 140 | D::AbstractVector, 141 | gpmap::BiDict) where {T<:AbstractFloat, V<:GPNode} 142 | @warn "should not be here.." 143 | K = length(node) 144 | c = 1 145 | @inbounds for k = 1:K 146 | n = sum(nparams(node[k].dist))-1 147 | ∇mll!(node[k], ∇parent, lρ, ℓ, logS, view(∇, c:(c+n)), D, gpmap) 148 | c += n+1 149 | end 150 | end 151 | 152 | 153 | function ∇mll(node::GPNode, ∇parent::Float64, ℓ, logS, ::Val{true}; kwargs...) 154 | ∇ = mll(node.dist) 155 | w = ℓ[node.id] -logS + ∇parent 156 | ∇ *= exp(w) 157 | return Dict(node.id => ∇) 158 | end 159 | 160 | function ∇mll(node::GPSplitNode, ∇parent::Float64, ℓ, logS, soft::Val{true}; kwargs...) 161 | K = length(node) 162 | 163 | lpchildren = map(c -> ℓ[c.id], children(node)) 164 | 165 | ∇ = ∇mll(node[1], ∇parent + sum(lpchildren[2:end]), ℓ, logS, soft; kwargs...) 166 | @inbounds for k = 2:K 167 | merge!(∇, ∇mll(node[k], ∇parent + sum(lpchildren[vcat(1:k-1, k+1:end)]), ℓ, logS, soft; kwargs...)) 168 | end 169 | return ∇ 170 | end 171 | 172 | function ∇mll(node::GPSumNode, ∇parent::Float64, ℓ, logS, soft::Val{true}; kwargs...) 173 | K = length(node) 174 | return mapreduce(c -> ∇mll(c, ∇parent - log(K), ℓ, logS, soft; kwargs...), merge, children(node)) 175 | end 176 | 177 | function ∇mll(node::SPNNode, soft; kwargs...) 178 | nodes = SumProductNetworks.getOrderedNodes(node) 179 | ids = map(n -> n.id, nodes) 180 | L = AxisArray(zeros(length(ids)), ids) 181 | mll!(node, L) 182 | ∇mll(node, 0.0, L, L[node.id], soft; kwargs...) 
183 | end 184 | 185 | getparams(node::GPNode) = Dict(node.id => vcat(DeepStructuredMixtures.params(node.dist)...)) 186 | getparams(node::Node) = mapreduce(c -> getparams(c), merge, children(node)) 187 | 188 | setparams!(node::GPNode, params) = setparams!(node.dist, params) 189 | setparams!(node::SPNNode, params) = setparams!(node, params) 190 | function setparams!(node::GPSumNode{T,V}, params) where {T<:AbstractFloat,V<:GPNode} 191 | c = 1 192 | for child in children(node) 193 | n = sum(nparams(child.dist))-1 194 | setparams!(child.dist, params[c:(c+n)]) 195 | c += n+1 196 | end 197 | end 198 | setparams!(node::Node, params) = map(c-> setparams!(c, params), children(node)) 199 | 200 | @inline getLeafIds(node::GPNode) = node.id 201 | @inline getLeafIds(node::Node) = mapreduce(getLeafIds, vcat, children(node)) 202 | 203 | -------------------------------------------------------------------------------- /src/gaussianprocess.jl: -------------------------------------------------------------------------------- 1 | using Distances 2 | import StatsBase.params 3 | 4 | export GaussianProcess 5 | export update_cholesky!, prediction 6 | export mll, ∇mll, ∇mll! 7 | export params, setparams!, nparams, getnoise 8 | export updategradients! 9 | 10 | mutable struct Scalar 11 | value::Float64 12 | end 13 | 14 | struct GaussianProcess{TX<:AbstractMatrix{<:AbstractFloat}, 15 | Ty<:AbstractVector{<:AbstractFloat}, 16 | Tm<:MeanFunction, 17 | Tk<:KernelFunction, 18 | Tchol<:AbstractFloat, 19 | Ta<:AbstractVector{Tchol}, 20 | Tdata<:AbstractArray} 21 | 22 | x::TX # inputs 23 | y::Ty # outputs 24 | 25 | mean::Tm # mean function 26 | kernel::Tk # kernel function 27 | logNoise::Scalar 28 | ∂ϵ::Scalar 29 | 30 | D::Int 31 | N::Int 32 | 33 | cK::Cholesky{Tchol} 34 | α::Ta 35 | P::Tdata 36 | 37 | end 38 | 39 | @inline getnoise(gp::GaussianProcess; logscale=false) = logscale ? 
gp.logNoise.value : exp(2*gp.logNoise.value) 40 | @inline function setnoise!(gp::GaussianProcess, noise::AbstractFloat) 41 | gp.logNoise.value = noise 42 | end 43 | 44 | function Base.show(io::IO, ::MIME"text/plain", m::GaussianProcess) 45 | ℓ = mll(m) 46 | print(io, "Gaussian process\n noise: ",getnoise(m),"\n kernel: ",m.kernel,"\n mean: ",m.mean,"\n mll:",ℓ) 47 | end 48 | Base.show(io::IO, m::GaussianProcess) = print(io, "GP(",m.kernel,", ",m.mean,")") 49 | 50 | function GaussianProcess(x::AbstractMatrix, 51 | y::AbstractVector; 52 | mean = ConstMean(mean(y)), 53 | kernel = IsoSE(0.0, 0.0), 54 | logNoise = log(7), 55 | run_cholesky = false 56 | ) 57 | P = getdistancematrix(kernel, x) 58 | return GaussianProcess(x, y, mean, kernel, logNoise, P; run_cholesky = run_cholesky) 59 | end 60 | 61 | function GaussianProcess(x::AbstractMatrix, 62 | y::AbstractVector, 63 | mean::MeanFunction, 64 | kernel::KernelFunction, 65 | logNoise::Float64, 66 | P::AbstractArray; 67 | run_cholesky = false 68 | ) 69 | N,D = size(x) 70 | cK = Cholesky(zeros(N,N), 'L', 0) 71 | α = zeros(N) 72 | yy = similar(y) 73 | apply_subtract!(mean, y, yy) 74 | gp = GaussianProcess(x, yy, mean, kernel, Scalar(logNoise), Scalar(0.0), D, N, cK, α, P) 75 | 76 | if run_cholesky 77 | update_cholesky!(gp) 78 | end 79 | return gp 80 | end 81 | 82 | function update_cholesky!(gp::GaussianProcess{Tx,Ty,Tm,Tk,Tchol,Ta,Tdata}) where {Tx,Ty,Tm,Tk,Tchol,Ta,Tdata} 83 | Knn = kernelmatrix(gp.kernel, gp.P) 84 | return update_cholesky!(gp, Knn) 85 | end 86 | 87 | function update_cholesky!(gp::GaussianProcess{Tx,Ty,Tm,Tk,Tchol,Ta,Tdata}, Knn::AbstractMatrix) where {Tx,Ty,Tm,Tk,Tchol,Ta,Tdata} 88 | 89 | # reset factors to kernel matrix 90 | F = gp.cK.factors 91 | @inbounds F[:] = Tchol.(Knn) 92 | 93 | # compute noise 94 | noise = Tchol(getnoise(gp) + ϵ) 95 | 96 | # add noise 97 | σ = @view F[diagind(F)] 98 | map!(i -> i+noise, σ, σ) 99 | 100 | # solve cholesky 101 | LAPACK.potrf!('L', F) 102 | 103 | # update α 104 | # See Rasmussen and Williams, Algorithm 2.1 105 | gp.α[:] = gp.cK.L' \ (gp.cK.L \ gp.y) 106 | 107 | return gp 108 | end 109 | 110 | function prediction(gp, 111 | Knt::AbstractMatrix, 112 | Ktt::AbstractMatrix, 113 | xtest::AbstractMatrix 114 | ) 115 | 116 | # See Rasmussen and Williams, Algorithm 2.1 117 | mx = get(gp.mean, size(xtest,1)) 118 | μ = mx + Knt' * gp.α 119 | 120 | V = gp.cK.L \ Knt 121 | Σ = Ktt - V' * V 122 | 123 | noise = eltype(Σ)(getnoise(gp)) 124 | 125 | σ = @view Σ[diagind(Σ)] 126 | map!(i -> i+noise, σ, σ) 127 | 128 | return μ, Σ 129 | end 130 | 131 | function prediction(gp, xtest::AbstractMatrix) 132 | 133 | Knt = kernelmatrix(gp.kernel, gp.x, xtest) 134 | Ktt = kernelmatrix(gp.kernel, xtest, xtest) 135 | 136 | return prediction(gp, Knt, Ktt, xtest) 137 | end 138 | 139 | @inline nparams(gp::GaussianProcess) = map(length, params(gp)) 140 | 141 | function params(gp::GaussianProcess; logscale = false) 142 | return (getlengthscales(gp.kernel, logscale=logscale), 143 | getvariance(gp.kernel, logscale=logscale), 144 | getnoise(gp, logscale=logscale)) 145 | end 146 | 147 | function setparams!(gp::GaussianProcess, lengthscale, variance, noise::AbstractFloat) 148 | setlengthscale!(gp.kernel, lengthscale) 149 | setvariance!(gp.kernel, variance) 150 | setnoise!(gp, noise) 151 | end 152 | 153 | function setparams!(gp::GaussianProcess, hyper::AbstractVector) 154 | setnoise!(gp, hyper[end]) 155 | setvariance!(gp.kernel, hyper[end-1]) 156 | if length(hyper) == 3 157 | setlengthscale!(gp.kernel, hyper[1]) 158 | else 159 | 
setlengthscale!(gp.kernel, hyper[1:end-2]) 160 | end 161 | end 162 | 163 | @inline mll(gp) = - (dot(gp.y, gp.α) + logdet(gp.cK) + log2π * gp.N) / 2 164 | 165 | updategradients!(gp::GaussianProcess) = updategradients!(gp, gp.P) 166 | function updategradients!(gp::GaussianProcess, P::AbstractArray) 167 | K = kernelmatrix(gp.kernel, P) 168 | return updategradients!(gp, K, P) 169 | end 170 | 171 | function updategradients!(gp::GaussianProcess, K::AbstractMatrix, P::AbstractArray ) 172 | T = eltype(K) 173 | precomp = zeros(T, gp.N, gp.N) 174 | ααinvcK!(precomp, gp.cK, gp.α) 175 | 176 | gp.∂ϵ.value = getnoise(gp) * tr(precomp) 177 | updategradients!(gp.kernel, precomp, K, P) 178 | end 179 | 180 | function copygradients(source::GaussianProcess, dest::GaussianProcess) 181 | dest.∂ϵ.value = source.∂ϵ.value 182 | setgradients!(dest.kernel, getgradients(source.kernel)) 183 | end 184 | 185 | function ∇mll(gp::GaussianProcess) 186 | updategradients!(gp) 187 | grad = zeros(sum(nparams(gp))) 188 | ∇mll!(gp, grad) 189 | return grad 190 | end 191 | 192 | function ∇mll!(gp::GaussianProcess{Tx,Ty,Tm,Tk,Tchol,Ta,Tdata}, 193 | grad::AbstractVector{Tg} 194 | ) where {Tx,Ty,Tm,Tk,Tchol,Ta,Tdata,Tg<:AbstractFloat} 195 | return ∇mll!(gp, grad, gp.P) 196 | end 197 | 198 | function ∇mll!(gp::GaussianProcess, 199 | grad::AbstractVector{Tg}, 200 | P::AbstractArray{T} 201 | ) where {Tg<:AbstractFloat,T} 202 | K = kernelmatrix(gp.kernel, P) 203 | return ∇mll!(gp, grad, K, P) 204 | end 205 | 206 | function ∇mll!(gp::GaussianProcess, 207 | grad::AbstractVector{Tg}, 208 | K::AbstractMatrix{T}, 209 | P::AbstractArray{V} 210 | ) where {Tg<:AbstractFloat,T,V} 211 | 212 | ∂v, ∂l = getgradients(gp.kernel) 213 | grad[1:end-1] = vcat(∂l, ∂v) 214 | grad[end] = gp.∂ϵ.value 215 | 216 | return grad 217 | end 218 | 219 | function ααinvcK!(out::AbstractMatrix{T}, cK::Cholesky{T}, α::AbstractVector{T}) where {T} 220 | o = @view out[diagind(out)] 221 | map!(i -> i-one(T), o, o) 222 | 223 | ldiv!(cK, out) 224 | BLAS.ger!(1.0, α, α, out) 225 | return out 226 | end 227 | -------------------------------------------------------------------------------- /src/regionGraphUtils.jl: -------------------------------------------------------------------------------- 1 | function convertToSPN(rootRegion::SumRegion, gpRegions, RegionIDs, PartitionIDS, X, y, meanFunction, kernelFunctions::Vector, kernelPriors::Vector, noise; overlap = 2, do_mcmc = false) 2 | 3 | nodes = Dict{Int, Vector{SPNNode}}() 4 | for r in gpRegions 5 | 6 | s = vec(all((X .> (r.min' - overlap)) .& (X .< (r.max' + overlap)), 2)) 7 | xx = X[s,:]' 8 | yy = y[s] 9 | 10 | # construct GPs 11 | gp_nodes = [] 12 | for (ki, kernel_function) in enumerate(kernelFunctions) 13 | 14 | if do_mcmc 15 | 16 | gp = GP(xx, yy, deepcopy(meanFunction), deepcopy(kernel_function), copy(noise)) 17 | set_priors!(gp.k, kernelPriors[ki]) 18 | 19 | samples = mcmc(gp; nIter=1000,burnin=0,thin=100); 20 | 21 | for i in 1:size(samples,2) 22 | 23 | node = GPLeaf{Any}(nextID(), 24 | GP(xx, yy, deepcopy(meanFunction), deepcopy(kernel_function), copy(noise)) 25 | ) 26 | set_params!(node.gp, samples[:,i]) 27 | update_target!(node.gp) 28 | node.parents = SPNNode[] 29 | node.minx = r.min 30 | node.maxx = r.max 31 | push!(gp_nodes, node) 32 | end 33 | else 34 | 35 | node = GPLeaf{Any}(nextID(), 36 | GP(xx, yy, deepcopy(meanFunction), deepcopy(kernel_function), copy(noise)) 37 | ) 38 | node.parents = SPNNode[] 39 | node.minx = r.min 40 | node.maxx = r.max 41 | push!(gp_nodes, node) 42 | end 43 | end 44 | nodes[RegionIDs[r]] 
= gp_nodes 45 | end 46 | 47 | return buildNodes(rootRegion, RegionIDs, PartitionIDS, nodes, rootRegion)[1] 48 | end 49 | 50 | function buildNodes(r::SumRegion, RegionIDs, PartitionIDS, nodes::Dict, rootRegion::SumRegion) 51 | @assert haskey(RegionIDs, r) 52 | if !haskey(nodes, RegionIDs[r]) 53 | 54 | childrn = reduce(vcat, map(p -> buildNodes(p, RegionIDs, PartitionIDS, nodes, rootRegion), r.partitions)) 55 | 56 | if r == rootRegion 57 | # construct only a single sum node 58 | node = GPSumNode(nextID(), Int[]); 59 | 60 | for child in childrn 61 | add!(node, child) 62 | end 63 | 64 | fill!(node.prior_weights, 1. / length(node)) 65 | fill!(node.posterior_weights, 1. / length(node)) 66 | 67 | @assert length(node) == length(node.posterior_weights) 68 | 69 | nodes[RegionIDs[r]] = [node] 70 | else 71 | n = SPNNode[] 72 | for s in 1:numSums 73 | 74 | # construct only a single sum node 75 | node = GPSumNode(nextID(), Int[]); 76 | 77 | for child in childrn 78 | add!(node, child) 79 | end 80 | 81 | node.prior_weights[:] = rand(Dirichlet(ones(length(node)))) # use Dirichlet instead ? 82 | fill!(node.prior_weights, 1. / length(node)) 83 | fill!(node.posterior_weights, 1. / length(node)) 84 | 85 | @assert length(node) == length(node.posterior_weights) 86 | push!(n, node) 87 | end 88 | nodes[RegionIDs[r]] = n 89 | end 90 | end 91 | 92 | return nodes[RegionIDs[r]] 93 | end 94 | 95 | function buildNodes(r::SumRegion, RegionIDs, PartitionIDS, nodes::Dict, rootRegion::SumRegion) 96 | @assert haskey(RegionIDs, r) 97 | if !haskey(nodes, RegionIDs[r]) 98 | 99 | childrn = reduce(vcat, map(p -> buildNodes(p, RegionIDs, PartitionIDS, nodes, rootRegion), r.partitions)) 100 | 101 | if r == rootRegion 102 | # construct only a single sum node 103 | node = GPSumNode(nextID(), Int[]); 104 | 105 | for child in childrn 106 | add!(node, child) 107 | end 108 | 109 | fill!(node.prior_weights, 1. / length(node)) 110 | fill!(node.posterior_weights, 1. / length(node)) 111 | 112 | @assert length(node) == length(node.posterior_weights) 113 | 114 | nodes[RegionIDs[r]] = [node] 115 | else 116 | n = SPNNode[] 117 | for s in 1:numSums 118 | 119 | # construct only a single sum node 120 | node = GPSumNode(nextID(), Int[]); 121 | 122 | for child in childrn 123 | add!(node, child) 124 | end 125 | 126 | node.prior_weights[:] = rand(Dirichlet(ones(length(node)))) # use Dirichlet instead ? 127 | fill!(node.prior_weights, 1. / length(node)) 128 | fill!(node.posterior_weights, 1. / length(node)) 129 | 130 | @assert length(node) == length(node.posterior_weights) 131 | push!(n, node) 132 | end 133 | nodes[RegionIDs[r]] = n 134 | end 135 | end 136 | 137 | return nodes[RegionIDs[r]] 138 | end 139 | 140 | function buildNodes(p::SplitPartition, RegionIDs, PartitionIDS, nodes::Dict, rootRegion::SumRegion) 141 | 142 | childrn = map(r -> buildNodes(r, RegionIDs, PartitionIDS, nodes, rootRegion), p.regions) 143 | 144 | n = SPNNode[] 145 | for ch1 in childrn[1] 146 | for ch2 in childrn[2] 147 | # construct node 148 | split = ones(p.dimensions) * -Inf 149 | split[p.dimension] = p.split 150 | node = FiniteSplitNode(nextID(), split) 151 | 152 | add!(node, ch1) 153 | add!(node, ch2) 154 | push!(n, node) 155 | end 156 | end 157 | 158 | node = GPSumNode(nextID(), Int[]); 159 | 160 | for child in n 161 | add!(node, child) 162 | end 163 | 164 | fill!(node.prior_weights, 1. / length(node)) # use Dirichlet instead ? 165 | fill!(node.posterior_weights, 1. 
/ length(node)) 166 | push!(n, node) 167 | 168 | @assert haskey(PartitionIDS, p) 169 | nodes[PartitionIDS[p]] = [node] 170 | return nodes[PartitionIDS[p]] 171 | end 172 | 173 | function buildNodes(p::SplitPartition, RegionIDs, PartitionIDS, nodes::Dict, rootRegion::SumRegion) 174 | 175 | childrn = map(r -> buildNodes(r, RegionIDs, PartitionIDS, nodes, rootRegion), p.regions) 176 | 177 | n = SPNNode[] 178 | for ch1 in childrn[1] 179 | for ch2 in childrn[2] 180 | # construct node 181 | split = [p.split] 182 | node = FiniteSplitNode(nextID(), split) 183 | 184 | add!(node, ch1) 185 | add!(node, ch2) 186 | push!(n, node) 187 | end 188 | end 189 | 190 | node = GPSumNode(nextID(), Int[]); 191 | 192 | for child in n 193 | add!(node, child) 194 | end 195 | 196 | fill!(node.prior_weights, 1. / length(node)) # use Dirichlet instead ? 197 | fill!(node.posterior_weights, 1. / length(node)) 198 | push!(n, node) 199 | 200 | @assert haskey(PartitionIDS, p) 201 | nodes[PartitionIDS[p]] = [node] 202 | return nodes[PartitionIDS[p]] 203 | end 204 | 205 | function buildNodes(r::GPRegion, RegionIDs, PartitionIDS, nodes::Dict, rootRegion::SumRegion) 206 | @assert haskey(RegionIDs, r) 207 | return nodes[RegionIDs[r]] 208 | end 209 | 210 | function buildNodes(r::GPRegion, RegionIDs, PartitionIDS, nodes::Dict, rootRegion::SumRegion) 211 | @assert haskey(RegionIDs, r) 212 | return nodes[RegionIDs[r]] 213 | end 214 | 215 | -------------------------------------------------------------------------------- /src/plot.jl: -------------------------------------------------------------------------------- 1 | export kernelidfunction 2 | 3 | const invΦ = norminvcdf 4 | 5 | function kernelidfunction(spn::Union{GPSplitNode, GPSumNode}) 6 | lgp = leftGP(spn) 7 | rgp = rightGP(spn) 8 | 9 | xmin = lgp isa AbstractArray ? mapreduce(gp -> minimum(gp.x), min, lgp) : minimum(lgp.x) 10 | xmax = rgp isa AbstractArray ? mapreduce(gp -> maximum(gp.x), max, rgp) : maximum(rgp.x) 11 | 12 | x = range(xmin, stop=xmax, length=100) 13 | y = kernelid(spn, reshape(collect(x), :, 1)) 14 | 15 | return x, y 16 | end 17 | 18 | @recipe function f(model::Union{DSMGP,PoE,gPoE,rBCM}; β=0.95, obsv=true, var=false, n=100, xmin=-Inf, xmax=Inf, filled = true) 19 | 20 | root = model.root 21 | 22 | lgp = leftGP(root) 23 | rgp = rightGP(root) 24 | 25 | D = lgp isa AbstractArray ? first(lgp).D : lgp.D 26 | 27 | if D == 1 28 | 29 | if isinf(xmin) 30 | xmin = lgp isa AbstractArray ? mapreduce(gp -> minimum(gp.x), min, lgp) : minimum(lgp.x) 31 | end 32 | if isinf(xmax) 33 | xmax = rgp isa AbstractArray ? mapreduce(gp -> maximum(gp.x), max, rgp) : maximum(rgp.x) 34 | end 35 | 36 | xlims --> (xmin, xmax) 37 | xmin, xmax = plotattributes[:xlims] 38 | x = range(xmin, stop=xmax, length=100) 39 | 40 | y, Σ = predict(model, reshape(x,:,1)) 41 | Σ[Σ .< 0] .= 0.0 42 | err = invΦ((1+β)/2)*sqrt.(Σ) 43 | 44 | if filled 45 | @series begin 46 | seriestype := :path 47 | ribbon := err 48 | fillcolor --> :orange 49 | linewidth --> 1 50 | linestyle := :dash 51 | model isa DSMGP ? label --> "DSMGP" : label --> "PoE" 52 | x,y 53 | end 54 | else 55 | @series begin 56 | seriestype := :path 57 | linewidth := 1.5 58 | model isa DSMGP ? 
label --> "DSMGP" : label --> "PoE" 59 | x,y 60 | end 61 | @series begin 62 | primary := false 63 | seriestype := :path 64 | linewidth := 1.4 65 | linestyle := :dot 66 | x,y.-err 67 | end 68 | @series begin 69 | primary := false 70 | seriestype := :path 71 | linewidth := 1.4 72 | linestyle := :dot 73 | x,y.+err 74 | end 75 | end 76 | if obsv 77 | @series begin 78 | primary := false 79 | #label := "observations" 80 | seriestype := :scatter 81 | markershape := :circle 82 | markercolor := :black 83 | #markersize := 0.7 84 | getx(root), gety(root) 85 | end 86 | end 87 | elseif D == 2 88 | 89 | xmin = lgp isa AbstractArray ? mapreduce(gp -> minimum(gp.x[:,1]), min, lgp) : minimum(lgp.x[:,1]) 90 | xmax = rgp isa AbstractArray ? mapreduce(gp -> maximum(gp.x[:,1]), max, rgp) : maximum(rgp.x[:,1]) 91 | ymin = lgp isa AbstractArray ? mapreduce(gp -> minimum(gp.x[:,2]), min, lgp) : minimum(lgp.x[:,2]) 92 | ymax = rgp isa AbstractArray ? mapreduce(gp -> maximum(gp.x[:,2]), max, rgp) : maximum(rgp.x[:,2]) 93 | 94 | xlims --> (xmin,xmax) 95 | ylims --> (ymin,ymax) 96 | xmin, xmax = plotattributes[:xlims] 97 | ymin, ymax = plotattributes[:ylims] 98 | x = range(xmin, stop=xmax, length=n) 99 | y = range(ymin, stop=ymax, length=n) 100 | xgrid = repeat(x', n, 1) 101 | ygrid = repeat(y, 1, n) 102 | 103 | μ, Σ = predict(model, hcat(vec(xgrid), vec(ygrid))) 104 | 105 | if var 106 | zgrid = reshape(Σ,n,n) 107 | else 108 | zgrid = reshape(μ,n,n) 109 | end 110 | x, y, zgrid 111 | end 112 | end 113 | 114 | @recipe function f(root::Union{GPSplitNode,GPSumNode}; β=0.95, obsv=true, var=false, n=100, 115 | xmin=-Inf, xmax=Inf, filled = true, show_splits = false) 116 | 117 | lgp = leftGP(root) 118 | rgp = rightGP(root) 119 | 120 | D = lgp isa AbstractArray ? first(lgp).D : lgp.D 121 | 122 | if D == 1 123 | 124 | if isinf(xmin) 125 | xmin = lgp isa AbstractArray ? mapreduce(gp -> minimum(gp.x), min, lgp) : minimum(lgp.x) 126 | end 127 | if isinf(xmax) 128 | xmax = rgp isa AbstractArray ? mapreduce(gp -> maximum(gp.x), max, rgp) : maximum(rgp.x) 129 | end 130 | 131 | xlims --> (xmin, xmax) 132 | xmin, xmax = plotattributes[:xlims] 133 | x = range(xmin, stop=xmax, length=100) 134 | 135 | y, Σ = predict(root, reshape(x,:,1)) 136 | Σ[Σ .< 0] .= 0.0 137 | err = DeepStructuredMixtures.invΦ((1+β)/2)*sqrt.(Σ) 138 | 139 | if filled 140 | @series begin 141 | seriestype := :path 142 | ribbon := err 143 | fillcolor := :orange 144 | linewidth --> 1 145 | opacity := 0.75 146 | x,y 147 | end 148 | else 149 | @series begin 150 | seriestype := :path 151 | linewidth := 3 152 | x,y 153 | end 154 | @series begin 155 | primary := false 156 | seriestype := :path 157 | linewidth := 1 158 | x,y.-err 159 | end 160 | @series begin 161 | primary := false 162 | seriestype := :path 163 | linewidth := 1 164 | x,y.+err 165 | end 166 | end 167 | if obsv 168 | @series begin 169 | primary := false 170 | #label := "observations" 171 | seriestype := :scatter 172 | markershape := :circle 173 | markercolor := :black 174 | markersize := 3 175 | getx(root), gety(root) 176 | end 177 | end 178 | if show_splits && (root isa GPSplitNode) 179 | @series begin 180 | primary := false 181 | seriestype := :vline 182 | linewidth := 1 183 | fillcolor := :red 184 | linestyle := :dash 185 | root.split[1:end-1] 186 | end 187 | end 188 | end 189 | end 190 | 191 | @recipe function f(gp::GaussianProcess; β=0.95, obsv=true, var=false, n=100, xmin=-Inf, xmax=Inf) 192 | 193 | @assert gp.D == 1 194 | 195 | xmin = isinf(xmin) ? minimum(gp.x) : xmin 196 | xmax = isinf(xmax) ? 
maximum(gp.x) : xmax 197 | 198 | xlims --> (xmin, xmax) 199 | xmin, xmax = plotattributes[:xlims] 200 | x = range(xmin, stop=xmax, length=100) 201 | 202 | y, Σ = prediction(gp, reshape(x, :, 1)) 203 | σ² = diag(Σ) 204 | err = invΦ((1+β)/2)*sqrt.(σ²) 205 | 206 | @series begin 207 | seriestype := :path 208 | ribbon := err 209 | fillcolor --> :lightblue 210 | linewidth --> 1 211 | linestyle := :dash 212 | label --> "GP" 213 | x,y 214 | end 215 | if obsv 216 | @series begin 217 | primary := false 218 | #label := "observations" 219 | seriestype := :scatter 220 | markershape := :circle 221 | markercolor := :black 222 | #markersize := 0.7 223 | gp.x, gp.y + get(gp.mean, gp.N) 224 | end 225 | end 226 | end 227 | -------------------------------------------------------------------------------- /src/kernels.jl: -------------------------------------------------------------------------------- 1 | using Distances 2 | 3 | export KernelFunction, IsoKernel, ArdKernel 4 | export IsoSE, IsoLinear 5 | export ArdSE, ArdLinear 6 | export getvariance, getlengthscales 7 | export kernelmatrix, kernelmatrix!, kappa 8 | export updategradients!, getgradients 9 | export getdistancematrix 10 | 11 | abstract type KernelFunction end 12 | abstract type IsoKernel <: KernelFunction end 13 | abstract type ArdKernel <: KernelFunction end 14 | 15 | function kernelmatrix(kernel::KernelFunction, x1::AbstractMatrix, x2::AbstractMatrix) 16 | P = getdistancematrix(kernel, x1, x2) 17 | return kernelmatrix(kernel, P) 18 | end 19 | 20 | # using pre-computed K 21 | function kernelmatrix!(kernel::IsoKernel, K::AbstractMatrix, P::AbstractMatrix) 22 | l = getlengthscales(kernel)^2 23 | map!(kappa(kernel, l), K, P) 24 | v = getvariance(kernel) 25 | return lmul!(v,K) 26 | end 27 | @inline kernelmatrix(kernel::KernelFunction, P::AbstractMatrix) = kernelmatrix!(kernel, zero(P), P) 28 | 29 | # using pre-computed K 30 | 31 | function umap!(f::Function, A::AbstractMatrix{T}, B::AbstractMatrix{T}) where {T} 32 | #Threads.@threads 33 | for (i,j) in zip(eachindex(A),eachindex(B)) 34 | v = f(@inbounds(B[j])) 35 | @inbounds A[i] += v 36 | end 37 | end 38 | 39 | function kernelmatrix!(kernel::ArdKernel, K::AbstractMatrix{T}, P::AbstractArray{T,3}) where {T<:Real} 40 | v = getvariance(kernel) 41 | ls = getlengthscales(kernel).^2 42 | fill!(K, zero(T)) 43 | for (d,p) in enumerate(eachslice(P, dims=[3])) 44 | umap!(kappa(kernel, @inbounds(ls[d])), K, p) 45 | end 46 | 47 | lmul!(v,K) 48 | return K 49 | end 50 | 51 | function kernelmatrix(kernel::KernelFunction, P::AbstractArray{T,3}) where {T} 52 | return kernelmatrix!(kernel, zeros(T,size(P,1),size(P,2)), P) 53 | end 54 | 55 | @inline getdistancematrix(k::KernelFunction, x1) = getdistancematrix(k, x1, x1) 56 | 57 | # == SE ISO kernel == 58 | 59 | mutable struct IsoSE{T<:AbstractFloat} <: IsoKernel 60 | logℓ::T 61 | logσ::T 62 | ∂ℓ::T 63 | ∂σ::T 64 | end 65 | 66 | IsoSE(logℓ, logσ) = IsoSE(logℓ, logσ, zero(logℓ), zero(logσ)) 67 | 68 | @inline getvariance(k::IsoSE; logscale=false) = logscale ? k.logσ : exp(2*k.logσ) 69 | @inline getstd(k::IsoSE) = exp(k.logσ) 70 | function setvariance!(k::IsoSE, v::AbstractFloat) 71 | k.logσ = v 72 | end 73 | @inline getlengthscales(k::IsoSE; logscale=false) = logscale ? 
k.logℓ : exp(k.logℓ) 74 | function setlengthscale!(k::IsoSE{T}, l::T) where {T} 75 | k.logℓ = l 76 | end 77 | 78 | @inline rbfkernel(z::T, l::T) where {T<:AbstractFloat} = exp(-0.5*(z/l)) 79 | 80 | function kappa(k::IsoSE{T}, l::T) where {T<:AbstractFloat} 81 | return z->rbfkernel(z, l) 82 | end 83 | @inline getdistancematrix(k::IsoSE, x1, x2) = pairwise(SqEuclidean(), x1, x2, dims=1) 84 | 85 | function updategradients!(k::IsoSE, precomp::AbstractMatrix, K::AbstractMatrix, P::AbstractMatrix) 86 | σ = getstd(k) 87 | l = getlengthscales(k)^2 88 | 89 | # σ * K 90 | lmul!(σ, K) 91 | 92 | # 0.5 * trace precomp * 2*σ*K) 93 | k.∂σ = 0.5*tr(precomp * 2*K) 94 | 95 | # 0.5 * trace precomp * σ*K*(P/l^2) 96 | K.*=P/l 97 | k.∂ℓ = 0.5*tr(precomp * K) 98 | return k 99 | end 100 | 101 | @inline getgradients(k::IsoSE) = (k.∂σ, k.∂ℓ) 102 | function setgradients!(k::IsoSE, grad) 103 | ∂σ, ∂ℓ = grad 104 | k.∂σ = ∂σ 105 | k.∂ℓ = ∂ℓ 106 | end 107 | 108 | # == SE ARD kernel == 109 | mutable struct ArdSE{T<:AbstractFloat} <: ArdKernel 110 | logℓ::Vector{T} 111 | logσ::T 112 | ∂ℓ::Vector{T} 113 | ∂σ::T 114 | end 115 | 116 | ArdSE(logℓ, logσ) = ArdSE(logℓ, logσ, zero(logℓ), zero(logσ)) 117 | 118 | @inline getvariance(k::ArdSE; logscale=false) = logscale ? k.logσ : exp(2*k.logσ) 119 | @inline getstd(k::ArdSE) = exp(k.logσ) 120 | function setvariance!(k::ArdSE, v::AbstractFloat) 121 | k.logσ = v 122 | end 123 | @inline getlengthscales(k::ArdSE; logscale=false) = logscale ? k.logℓ : exp.(k.logℓ) 124 | @inline getlengthscales(k::ArdSE, d::Int) = exp(k.logℓ[d]) 125 | function setlengthscale!(k::ArdSE{T}, l::AbstractVector{T}) where {T} 126 | k.logℓ[:] = l 127 | end 128 | function setlengthscale!(k::ArdSE{T}, l::T) where {T} 129 | @assert length(k.logℓ) == 1 130 | k.logℓ[1] = l 131 | end 132 | 133 | function kappa(k::ArdSE{T}, l::T) where {T<:AbstractFloat} 134 | return z->rbfkernel(z,l) 135 | end 136 | 137 | function getdistancematrix(k::ArdSE{T}, x1::AbstractMatrix{T}, x2::AbstractMatrix{T}) where {T} 138 | P = zeros(T, size(x1,1), size(x2,1), length(k.logℓ)) 139 | for d in Base.axes(P,3) 140 | @inbounds pairwise!(@view(P[:,:,d]), SqEuclidean(), @inbounds(@view(x1[:,d]))', @inbounds(@view(x2[:,d]))', dims=2) 141 | end 142 | return P 143 | #return @inbounds map(d -> pairwise(SqEuclidean(), view(x1,:,d)', view(x2,:,d)', dims=2), 1:length(k.logℓ)) 144 | end 145 | 146 | function updategradients!(k::ArdSE, 147 | precomp::AbstractMatrix{T}, 148 | K::AbstractMatrix{T}, 149 | P::AbstractArray{T,3}) where {T} 150 | σ = getstd(k) 151 | ls = getlengthscales(k).^2 152 | 153 | # σ * K 154 | lmul!(σ, K) 155 | 156 | # 0.5 * trace precomp * 2*σ*K) 157 | k.∂σ = T(0.5)*tr(precomp * 2*K) 158 | 159 | # 0.5 * trace precomp * σ*K*(P/l^2) 160 | for (d,p) in enumerate(eachslice(P, dims=[3])) 161 | @inbounds k.∂ℓ[d] = T(0.5)*tr(precomp * K.*(p/ls[d]) ) 162 | end 163 | return k 164 | end 165 | @inline getgradients(k::ArdSE) = (k.∂σ, k.∂ℓ) 166 | function setgradients!(k::ArdSE, grad) 167 | ∂σ, ∂ℓ = grad 168 | k.∂σ = ∂σ 169 | k.∂ℓ = ∂ℓ 170 | end 171 | 172 | # == Linear ISO kernel == 173 | 174 | mutable struct IsoLinear{T<:AbstractFloat} <: IsoKernel 175 | logℓ::T 176 | ∂ℓ::T 177 | end 178 | 179 | IsoLinear(logℓ) = IsoLinear(logℓ, zero(logℓ)) 180 | 181 | @inline getvariance(k::IsoLinear; logscale=false) = logscale ? 0.0 : 1.0 182 | @inline getstd(k::IsoLinear) = 1.0 183 | @inline setvariance!(k::IsoLinear, v) = nothing 184 | @inline getlengthscales(k::IsoLinear; logscale=false) = logscale ? 
k.logℓ : exp(k.logℓ)
185 | function setlengthscale!(k::IsoLinear, l::AbstractFloat)
186 | k.logℓ = l
187 | end
188 | 
189 | @inline linearkernel(z::T, l::T) where {T<:AbstractFloat} = z/l
190 | function kappa(k::IsoLinear, l)
191 | return z -> linearkernel(z, l)
192 | end
193 | 
194 | @inline getdistancematrix(k::IsoLinear, x1, x2) = x1 * x2'
195 | 
196 | function updategradients!(k::IsoLinear, precomp::AbstractMatrix, K::AbstractMatrix, P::AbstractMatrix)
197 | k.∂ℓ = 0.5*tr(precomp * -2*K)
198 | return k
199 | end
200 | @inline getgradients(k::IsoLinear) = (zero(typeof(k.∂ℓ)), k.∂ℓ)
201 | function setgradients!(k::IsoLinear, grad)
202 | _, ∂ℓ = grad
203 | k.∂ℓ = ∂ℓ
204 | end
205 | 
206 | # == Linear ARD kernel ==
207 | 
208 | mutable struct ArdLinear{T<:AbstractFloat} <: ArdKernel
209 | logℓ::Vector{T}
210 | ∂ℓ::Vector{T}
211 | end
212 | 
213 | ArdLinear(logℓ) = ArdLinear(logℓ, zero(logℓ))
214 | 
215 | @inline getvariance(k::ArdLinear; logscale=false) = logscale ? 0.0 : 1.0
216 | @inline getstd(k::ArdLinear) = 1.0
217 | @inline setvariance!(k::ArdLinear, v) = nothing
218 | @inline getlengthscales(k::ArdLinear; logscale=false) = logscale ? k.logℓ : exp.(k.logℓ)
219 | @inline getlengthscales(k::ArdLinear, d::Int) = exp(k.logℓ[d])
220 | function setlengthscale!(k::ArdLinear{T}, l::T) where {T}
221 | @assert length(k.logℓ) == 1
222 | k.logℓ[1] = l
223 | end
224 | function setlengthscale!(k::ArdLinear{T}, l::AbstractVector{T}) where {T}
225 | k.logℓ[:] = l
226 | end
227 | function kappa(k::ArdLinear, l)
228 | return z -> linearkernel(z, l)
229 | end
230 | 
231 | @inline getdistancematrix(k::ArdLinear, x1, x2) = map(d -> x1[:,d] * x2[:,d]', 1:length(k.logℓ))
232 | 
233 | function updategradients!(k::ArdLinear,
234 | precomp::AbstractMatrix{T},
235 | K::AbstractMatrix{T},
236 | P::AbstractVector{<:AbstractMatrix{T}}) where {T}
237 | ls = getlengthscales(k)
238 | 
239 | @inbounds for d = 1:length(P)
240 | map!(kappa(k, ls[d]), K, P[d])
241 | k.∂ℓ[d] = 0.5*tr(precomp * -2*K)
242 | end
243 | 
244 | return k
245 | end
246 | @inline getgradients(k::ArdLinear) = (zero(eltype(k.∂ℓ)), k.∂ℓ)
247 | function setgradients!(k::ArdLinear, grad)
248 | _, ∂ℓ = grad
249 | k.∂ℓ = ∂ℓ
250 | end
251 | 
-------------------------------------------------------------------------------- /src/common.jl:
--------------------------------------------------------------------------------
1 | export update!, reset_weights!, infer!
2 | export getLogNoise, kernelid
3 | export blockmatrix, blockindecies
4 | export nummixtures
5 | 
6 | nummixtures(node::GPNode) = 1
7 | nummixtures(node::GPSplitNode) = mapreduce(nummixtures, *, children(node))
8 | nummixtures(node::GPSumNode) = mapreduce(nummixtures, +, children(node))
9 | 
10 | 
11 | function blockmatrix(node::GPNode)
12 | M = zeros(length(node.observations), length(node.observations))
13 | idx = findall(node.observations)
14 | M[idx, idx] .+= 1
15 | return M
16 | end
17 | 
18 | function blockmatrix(node::Node)
19 | M = blockmatrix(node[1])
20 | for k = 2:length(node)
21 | M .+= blockmatrix(node[k])
22 | end
23 | return M
24 | end
25 | 
26 | function blockmatrix(node::GPSumNode)
27 | M = weights(node)[1]*blockmatrix(node[1])
28 | for k = 2:length(node)
29 | M .+= weights(node)[k]*blockmatrix(node[k])
30 | end
31 | return M
32 | end
33 | 
34 | 
35 | function blockindecies(node::GPNode, Ix::Vector{Vector{Int}})
36 | for n in node.obs
37 | append!(Ix[n], node.obs)
38 | end
39 | end
40 | 
41 | function blockindecies(node::Node, Ix::Vector{Vector{Int}})
42 | return map(child -> blockindecies(child, Ix), children(node))
43 | end
44 | 
45 | bestblockmatrix(node::GPNode) = blockmatrix(node)
46 | function bestblockmatrix(node::Node)
47 | mapreduce(bestblockmatrix, +, children(node))
48 | end
49 | 
50 | function bestblockmatrix(node::GPSumNode)
51 | i = argmax(node.logweights)
52 | return bestblockmatrix(node[i])
53 | end
54 | 
55 | @inline kernelid(node::GPNode, x::AbstractMatrix) = repeat([node.kernelid], size(x,1))
56 | function kernelid(node::GPSplitNode, x::AbstractMatrix)
57 | idx = getchild(node, x)
58 | kernel = zeros(Int, size(x,1))
59 | for (k, c) in enumerate(children(node))
60 | j = findall(idx .== k)
61 | if !isempty(j)
62 | kernel_ = kernelid(c, x[j,:])
63 | kernel[j] = kernel_
64 | end
65 | end
66 | return kernel
67 | end
68 | 
69 | function kernelid(node::GPSumNode{T,V}, x::AbstractMatrix) where {T,V<:SPNNode}
70 | w = weights(node)
71 | k_ = mapreduce(c -> kernelid(c,x), hcat, children(node))
72 | uk = unique(k_)
73 | c = mapreduce(kk -> sum((k_ .== kk) .* w', dims=2), hcat, uk)
74 | kernel = map(i -> uk[i[2]], argmax(c, dims=2))
75 | return kernel
76 | end
77 | 
78 | function kernelid(node::GPSumNode{T,V}, x::AbstractMatrix) where {T,V<:GPNode}
79 | i = argmax(node.logweights)
80 | kernel = kernelid(node[i], x)
81 | return kernel
82 | end
83 | 
84 | getLogNoise(node::GPNode, x::AbstractMatrix) = repeat([node.dist.logNoise.value], size(x,1))
85 | function getLogNoise(node::GPSplitNode, x::AbstractMatrix)
86 | idx = getchild(node, x)
87 | noise = zeros(size(x,1))
88 | for (k, c) in enumerate(children(node))
89 | j = findall(idx .== k)
90 | noise_ = getLogNoise(c, x[j,:])
91 | noise[j] = noise_
92 | end
93 | return noise
94 | end
95 | 
96 | function getLogNoise(node::GPSumNode, x::AbstractMatrix)
97 | return lse(mapreduce(k -> node.logweights[k].+getLogNoise(node[k],x), hcat, 1:length(node)))
98 | end
99 | 
100 | 
101 | function getchild(node::GPSplitNode, x::AbstractMatrix)
102 | idx = zeros(Int, size(x,1))
103 | @inbounds for n in 1:size(x,1)
104 | k = 1
105 | while idx[n] == 0
106 | split = node.split[k]
107 | d, s = split
108 | 
109 | accept = if k == 1
110 | x[n,d] <= s
111 | else
112 | (x[n,d] <= s) & (x[n,d] > node.split[k-1][2])
113 | end
114 | 
115 | if accept
116 | idx[n] = k
117 | end
118 | k += 1
119 | end
120 | end
121 | return idx
122 | end
123 | 
124 | leftGP(node::GPSumNode{T,C}) where {T<:Real,C<:SPNNode} = leftGP(first(children(node)))
125 | 
leftGP(node::GPSumNode{T,GPNode}) where {T<:Real} = mapreduce(leftGP, vcat, children(node)) 126 | leftGP(node::GPSplitNode) = leftGP(first(children(node))) 127 | leftGP(node::GPNode) = node.dist 128 | 129 | rightGP(node::GPSumNode{T,C}) where {T<:Real,C<:SPNNode} = rightGP(last(children(node))) 130 | rightGP(node::GPSumNode{T,GPNode}) where {T<:Real} = mapreduce(rightGP, vcat, children(node)) 131 | rightGP(node::GPSplitNode) = rightGP(last(children(node))) 132 | rightGP(node::GPNode) = node.dist 133 | 134 | function _predict(node::GPNode, x::AbstractMatrix, μmin) 135 | μ, Σ = prediction(node.dist, x) 136 | σ² = diag(Σ) 137 | σ²[σ² .<= 0] .= ϵ 138 | @assert all(μ .>= μmin) 139 | lm = log.(μ-μmin) 140 | lm2 = log.(μ.^2) 141 | ls = log.(σ²) 142 | return lm, lm2, ls, ones(Int, size(x,1)) 143 | end 144 | 145 | function _predictPoE(node::GPNode, x::AbstractMatrix) 146 | μ, Σ = prediction(node.dist, x) 147 | σ² = diag(Σ) 148 | return μ, inv.(σ²) 149 | end 150 | 151 | function _minpredict(node::GPNode, x::AbstractMatrix) 152 | μ, _ = prediction(node.dist, x) 153 | return μ 154 | end 155 | 156 | function _minpredict(node::GPSplitNode, x::AbstractMatrix) 157 | idx = getchild(node, x) 158 | μ = zeros(size(x,1)) 159 | for (k, c) in enumerate(children(node)) 160 | j = findall(idx .== k) 161 | μ_ = _minpredict(c, x[j,:]) 162 | μ[j] = μ_ 163 | end 164 | return μ 165 | end 166 | 167 | function _minpredict(node::GPSumNode, x::AbstractMatrix) 168 | μ = ones(size(x,1)) * Inf 169 | for (k, c) in enumerate(children(node)) 170 | μ = vec(minimum([μ _minpredict(c, x)], dims=2)) 171 | end 172 | return μ 173 | end 174 | 175 | function predict(node::GPNode, x::AbstractMatrix) 176 | μmin = _minpredict(node, x) 177 | lμ, _, lwσ², _ = _predict(node, x, μmin .- 1) 178 | return exp.(lμ) + μmin .- 1, exp.(lwσ²) 179 | end 180 | 181 | function _predict(node::GPSplitNode, x::AbstractMatrix, μmin) 182 | idx = getchild(node, x) 183 | lμ = zeros(size(x,1)) 184 | lwμ² = zeros(size(x,1)) 185 | lwσ² = zeros(size(x,1)) 186 | n = zeros(Int, size(x,1)) 187 | for (k, c) in enumerate(children(node)) 188 | j = findall(idx .== k) 189 | lμ_, lwμ²_, lwσ²_, n_ = _predict(c, x[j,:], μmin[j]) 190 | lμ[j] = lμ_ 191 | lwμ²[j] = lwμ²_ 192 | lwσ²[j] = lwσ²_ 193 | n[j] = n_ 194 | end 195 | return lμ, lwμ², lwσ², n 196 | end 197 | 198 | function _predictPoE(node::GPSplitNode, x::AbstractMatrix) 199 | μ = zeros(size(x,1)) 200 | t = zeros(size(x,1)) 201 | 202 | for (k,c) in enumerate(children(node)) 203 | μ_, t_ = _predictPoE(c, x) 204 | t[:] += t_ 205 | μ[:] += t_ .* μ_ 206 | end 207 | return μ ./ t, t 208 | end 209 | 210 | # same as for PoE 211 | function _predictgPoE(node::GPSplitNode, x::AbstractMatrix, M::Int) 212 | μ = zeros(size(x,1)) 213 | t = zeros(size(x,1)) 214 | M = length(node) 215 | β = 1/M 216 | for (k,c) in enumerate(children(node)) 217 | μ_, t_ = _predictPoE(c, x) 218 | t[:] += β*t_ 219 | μ[:] += β*t_ .* μ_ 220 | end 221 | return μ ./ t, t 222 | end 223 | 224 | function _predictrBCM(node::GPSplitNode, x::AbstractMatrix) 225 | μ = zeros(size(x,1)) 226 | 227 | gp = leftGP(node) 228 | s = diag(kernelmatrix(gp.kernel, x, x)) .+ getnoise(gp) 229 | 230 | C = deepcopy(1 ./ s) 231 | 232 | for (k,c) in enumerate(children(node)) 233 | μ_, t_ = _predictPoE(c, x) 234 | s_ = 1 ./ t_ 235 | β_ = 0.5 * (log.(s) - log.(s_)) 236 | C += (β_ .* t_) - (β_ ./ s) 237 | μ += μ_ .* (β_ .* t_) 238 | end 239 | 240 | return μ ./ C, C 241 | end 242 | 243 | function predict(node::GPSplitNode, x::AbstractMatrix) 244 | idx = getchild(node, x) 245 | μ = 
zeros(size(x,1)) 246 | σ² = zeros(size(x,1)) 247 | for (k, c) in enumerate(children(node)) 248 | j = findall(idx .== k) 249 | m, s = predict(c, x[j,:]) 250 | μ[j] = m 251 | σ²[j] = s 252 | end 253 | return μ, σ² 254 | end 255 | 256 | function predictPoE(node::GPSplitNode, x::AbstractMatrix) 257 | μ, t = _predictPoE(node, x) 258 | σ² = inv.(t) 259 | return μ, σ² 260 | end 261 | 262 | # We use β = 1/M as described in Deisenroth et al. (ICML 2015) 263 | function predictgPoE(node::GPSplitNode, x::AbstractMatrix) 264 | μ, t = _predictgPoE(node, x, 1) 265 | σ² = inv.(t) 266 | return μ, σ² 267 | end 268 | 269 | function predictrBCM(node::GPSplitNode, x::AbstractMatrix) 270 | μ, t = _predictrBCM(node, x) 271 | σ² = inv.(t) 272 | return μ, σ² 273 | end 274 | 275 | function _predict(node::GPSumNode, x::AbstractMatrix, μmin) 276 | 277 | lμ = zeros(size(x,1), length(node)) 278 | lwμ² = zeros(size(x,1), length(node)) 279 | lwσ² = zeros(size(x,1), length(node)) 280 | n = zeros(Int, size(x,1)) 281 | 282 | for (k, c) in enumerate(children(node)) 283 | lμ_, lwμ²_, lwσ²_, n_ = _predict(c, x, μmin) 284 | 285 | lμ[:,k] = lμ_ .+ logweights(node)[k] 286 | lwμ²[:,k] = lwμ²_ .+ logweights(node)[k] 287 | lwσ²[:,k] = lwσ²_ .+ logweights(node)[k] 288 | n += n_ 289 | end 290 | 291 | return vec(lse(lμ)), vec(lse(lwμ²)), vec(lse(lwσ²)), n 292 | end 293 | 294 | function predict(node::GPSumNode, x::AbstractMatrix) 295 | 296 | μmin = _minpredict(node, x) 297 | 298 | lμ, lwμ², lwσ², n = _predict(node, x, μmin .- 1) 299 | μ = exp.(lμ) + μmin .- 1 300 | v = exp.(lwσ²) + (exp.(lwμ²) - μ.^2) 301 | return μ, v 302 | end 303 | 304 | @inline predict(model::DSMGP, x::AbstractMatrix) = predict(model.root, x) 305 | @inline predict(model::PoE, x::AbstractMatrix) = predictPoE(model.root, x) 306 | @inline predict(model::gPoE, x::AbstractMatrix) = predictgPoE(model.root, x) 307 | @inline predict(model::rBCM, x::AbstractMatrix) = predictrBCM(model.root, x) 308 | 309 | function lse(x::AbstractMatrix{<:Real}; dims = 2) 310 | m = maximum(x, dims = dims) 311 | v = exp.(x .- m) 312 | return log.(sum(v, dims = dims)) + m 313 | end 314 | 315 | @inline getx(node::GPNode) = node.dist.x 316 | @inline getx(node::GPSplitNode) = mapreduce(c -> getx(c), vcat, children(node)) 317 | @inline getx(node::GPSumNode) = getx(node[1]) 318 | 319 | @inline gety(node::GPNode) = node.dist.y + get(node.dist.mean, node.dist.N) 320 | @inline gety(node::GPSplitNode) = mapreduce(c -> gety(c), vcat, children(node)) 321 | @inline gety(node::GPSumNode) = gety(node[1]) 322 | 323 | @inline update!(node::GPNode) = mll(node.dist) 324 | @inline update!(node::GPSplitNode) = mapreduce(update!, +, children(node)) 325 | 326 | function update!(node::GPSumNode) 327 | K = length(node) 328 | map!(c -> -log(K)+update!(c), node.logweights, children(node)) 329 | z = StatsFuns.logsumexp(node.logweights) 330 | map!(lw -> lw - z, node.logweights, node.logweights) 331 | return z 332 | end 333 | 334 | @inline update!(spn::DSMGP) = update!(spn.root) 335 | 336 | @inline infer!(node::GPNode) = mll(node.dist) 337 | @inline infer!(node::GPSplitNode) = mapreduce(infer!, +, children(node)) 338 | 339 | function infer!(node::GPSumNode{T,V}) where {T<:AbstractFloat,V<:GPNode} 340 | K = length(node) 341 | map!(c -> -log(K)+infer!(c), node.logweights, children(node)) 342 | z = StatsFuns.logsumexp(node.logweights) 343 | map!(lw -> lw - z, node.logweights, node.logweights) 344 | return z 345 | end 346 | 347 | function infer!(node::GPSumNode{T,V}) where {T<:AbstractFloat,V<:SPNNode} 348 | K = length(node) 
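# Sum node over internal (non-leaf) children: the children are scored, but the
# weights are reset to uniform below; only sums directly over GP experts
# (previous method) keep data-dependent posterior weights.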
349 | map!(c -> -log(K)+infer!(c), node.logweights, children(node))
350 | z = StatsFuns.logsumexp(node.logweights)
351 | fill!(node.logweights, -log(K))
352 | return z
353 | end
354 | 
355 | @inline infer!(spn::DSMGP) = infer!(spn.root)
356 | 
357 | function reset_weights!(spn::DSMGP)
358 | snodes = filter(n -> n isa SumNode, SumProductNetworks.getOrderedNodes(spn.root))
359 | for n in snodes
360 | K = length(n)
361 | fill!(n.logweights, -log(K))
362 | end
363 | end
364 | 
365 | function stats(node::GPNode; dict::Dict{Symbol,Any} = Dict{Symbol,Any}())
366 | dict[:gps] = get(dict, :gps, 0) + 1
367 | if !haskey(dict, :ndata)
368 | dict[:ndata] = Vector{Int}()
369 | end
370 | push!(dict[:ndata], node.dist.N)
371 | return dict
372 | end
373 | 
374 | function stats(node::GPSumNode; dict::Dict{Symbol,Any} = Dict{Symbol,Any}())
375 | for c in children(node)
376 | stats(c, dict = dict)
377 | end
378 | dict[:sumnodes] = get(dict, :sumnodes, 0) + 1
379 | return dict
380 | end
381 | 
382 | function stats(node::GPSplitNode; dict::Dict{Symbol,Any} = Dict{Symbol,Any}())
383 | for c in children(node)
384 | stats(c, dict = dict)
385 | end
386 | dict[:splitnodes] = get(dict, :splitnodes, 0) + 1
387 | if !haskey(dict, :bounds)
388 | dict[:bounds] = Vector{Tuple{Vector{Float64}, Vector{Float64}}}()
389 | end
390 | push!(dict[:bounds], (node.lowerBound, node.upperBound))
391 | if !haskey(dict, :ids)
392 | dict[:ids] = Vector{Symbol}()
393 | end
394 | push!(dict[:ids], node.id)
395 | return dict
396 | end
397 | 
-------------------------------------------------------------------------------- /src/fit.jl:
--------------------------------------------------------------------------------
1 | using Base.Threads
2 | using DeepStructuredMixtures.AdvancedCholesky
3 | 
4 | export getOverlap, updateK!, fit!, fit_naive!
5 | export updategradients!
6 | export getLeaves 7 | export distancematrix, kernelmatrix 8 | 9 | @inline getLeaves(node::GPNode) = [node] 10 | @inline getLeaves(node::Node) = mapreduce(getLeaves, vcat, children(node)) 11 | 12 | @inline function getOverlap(node::GPNode, D::Matrix{T}, idmap::BiDict) where {T<:Real} 13 | return [node] 14 | end 15 | @inline function getOverlap(node::GPSplitNode, D::Matrix{T}, idmap::BiDict) where {T<:Real} 16 | return mapreduce(c -> getOverlap(c, D, idmap), vcat, children(node)) 17 | end 18 | function getOverlap(node::GPSumNode, D::Matrix{T}, idmap::BiDict) where {T<:Real} 19 | r = map(c -> getOverlap(c, D, idmap), children(node)) 20 | @inbounds begin 21 | for i = 1:length(r) 22 | for j = (i+1):length(r) 23 | for nnode in r[i] 24 | n = idmap.x[nnode.id] 25 | for mnode in r[j] 26 | m = idmap.x[mnode.id] 27 | Δ = xor.(nnode.observations, mnode.observations) 28 | Δn = sum(Δ .& nnode.observations) * (nnode.kernelid == mnode.kernelid) 29 | Δm = sum(Δ .& mnode.observations) * (nnode.kernelid == mnode.kernelid) 30 | D[n,m] = one(T) - T(Δn / sum(nnode.observations)) 31 | D[m,n] = one(T) - T(Δm / sum(mnode.observations)) 32 | end 33 | end 34 | end 35 | end 36 | end 37 | 38 | return reduce(vcat, r) 39 | end 40 | 41 | function getObservationCount!(node::GPNode, P::Matrix{Int}) 42 | for n in node.observations 43 | for m in node.observations 44 | if n != m 45 | P[n,m] += 1 46 | end 47 | end 48 | end 49 | end 50 | @inline function getObservationCount!(node::GPSplitNode, P) 51 | map(c -> getObservationCount!(c, P), children(node)) 52 | end 53 | @inline function getObservationCount!(node::GPSumNode, P) 54 | map(c -> getObservationCount!(c, P), children(node)) 55 | end 56 | 57 | """ 58 | fit!(model::DSMGP; τ = 0.05) 59 | 60 | Update the Cholesky decompositions of the DSMGP. 61 | 62 | # Arguments: 63 | 64 | * `τ`: Minimal relative overlap required to use shared computation. Higher values will lead to higher inaccuracies in the solutions. 
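
A usage sketch (the model construction is illustrative); `fit!` is typically re-run after kernel hyperparameters have changed, and smaller values of `τ` permit fewer shared low-rank updates and hence give more accurate, but slower, fits:

    model = buildDSMGP(x, y, 2, 4)
    fit!(model; τ = 0.01)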
65 | 
66 | """
67 | function fit!(spn::DSMGP; τ = 0.05)
68 | return fit!(spn.root, spn.D, spn.gpmap, τ = τ)
69 | end
70 | 
71 | function fit!(spn::Union{GPSumNode, GPSplitNode}, D::Matrix, gpmap::BiDict; τ = 0.05)
72 | 
73 | leaves = getLeaves(spn)
74 | n = length(leaves)
75 | processed = falses(n)
76 | counts = zeros(Int,n)
77 | S = Vector{GPNode}(undef,n)
78 | for j in 1:n
79 | i = argmax(D[:,j] .* D[j,:])
80 | counts[i] += 1
81 | nid = gpmap.fx[i]
82 | S[j] = leaves[findfirst(n -> n.id == nid, leaves)]
84 | end
85 | 
86 | sort!(leaves, by = (node) -> counts[gpmap.x[node.id]])
87 | 
88 | ttotal = @elapsed for jNode in leaves
89 | j = gpmap.x[jNode.id]
90 | if !processed[j]
91 | 
92 | mainNode = S[j]
93 | i = gpmap.x[mainNode.id]
94 | mainGP = mainNode.dist
95 | 
96 | # solve the main GP
97 | if !processed[i]
98 | update_cholesky!(mainGP)
99 | processed[i] = true
100 | end
101 | 
102 | jGP = jNode.dist
103 | processed[j] = true
104 | 
107 | if mainNode.kernelid != jNode.kernelid
108 | # solve Cholesky
109 | update_cholesky!(jGP)
110 | elseif first(jNode.obs) < first(mainNode.obs)
111 | # solve Cholesky
112 | update_cholesky!(jGP)
113 | else
114 | ione = D[i,j] == one(eltype(D))
115 | jone = D[j,i] == one(eltype(D))
116 | fitcontained!(jNode, jGP, mainNode, mainGP, Val(ione), Val(jone), τ)
117 | end
118 | end
119 | end
120 | 
121 | ttotal
122 | end
123 | 
124 | function fitcontained!(jNode::GPNode,
125 | jGP::GaussianProcess,
126 | mainNode::GPNode,
127 | mainGP::GaussianProcess,
128 | ione, jone, τ::Float64)
129 | update_cholesky!(jGP)
130 | end
131 | 
132 | function fitcontained!(jNode::GPNode,
133 | jGP::GaussianProcess,
134 | mainNode::GPNode,
135 | mainGP::GaussianProcess,
136 | ione::Val{true},
137 | jone::Val{true},
138 | τ::Float64
139 | )
140 | # copy Cholesky
141 | jGP.cK.factors[:] = mainGP.cK.factors
142 | jGP.α[:] = mainGP.α
143 | end
144 | 
145 | function fitcontained!(jNode::GPNode,
146 | jGP::GaussianProcess,
147 | mainNode::GPNode,
148 | mainGP::GaussianProcess,
149 | ione::Val{false},
150 | jone::Val{true},
151 | τ::Float64
152 | )
153 | 
154 | # j is a sub-region or overlaps
155 | 
156 | # solve with low-rank update
158 | minJ = minimum(jNode.obs)
159 | maxJ = maximum(jNode.obs)
160 | 
161 | minM = minimum(mainNode.obs)
162 | maxM = maximum(mainNode.obs)
163 | 
164 | @assert minJ >= minM
165 | @assert maxJ <= maxM
166 | 
167 | s = minJ == minM ? 1 : findfirst(mainNode.obs .== minJ)
168 | e = maxJ == maxM ?
mainNode.nobs : findfirst(mainNode.obs .== maxJ) 169 | 170 | idx = collect(s:e) 171 | toupdate = setdiff(mainNode.obs[1:e], jNode.obs) 172 | 173 | # only do low-rank updates if sufficiently stable 174 | if (length(toupdate) / jNode.nobs) < τ 175 | 176 | CC = copy(mainGP.cK.factors) 177 | d = size(CC,1) 178 | 179 | for n in toupdate 180 | @inbounds begin 181 | i = findfirst(mainNode.obs .== n) 182 | AdvancedCholesky.lowrankupdate!(CC, 183 | view(CC,i,(i+1):d), 184 | (i+1), 185 | mainGP.cK.uplo) 186 | end 187 | end 188 | 189 | reverse!(toupdate) 190 | for n in toupdate 191 | i = findfirst(mainNode.obs[idx] .== n) 192 | !isnothing(i) && deleteat!(idx,i) 193 | end 194 | 195 | jGP.cK.factors[:] = CC[idx,idx] 196 | 197 | if all(diag(jGP.cK.factors) .>= 0) 198 | jGP.α[:] = jGP.cK.L' \ (jGP.cK.L \ jGP.y) 199 | else 200 | update_cholesky!(jGP) 201 | end 202 | else 203 | # solve Cholesky 204 | update_cholesky!(jGP) 205 | end 206 | end 207 | 208 | function fitcontained!(jNode::GPNode, 209 | jGP::GaussianProcess, 210 | mainNode::GPNode, 211 | mainGP::GaussianProcess, 212 | ione::Val{true}, 213 | jone::Val{false}, 214 | τ::Float64 215 | ) 216 | 217 | 218 | Knn = kernelmatrix(jGP.kernel, jGP.P) 219 | 220 | # reset factors to kernel matrix 221 | F = jGP.cK.factors 222 | Tchol = eltype(F) 223 | @inbounds F[:] = Tchol.(Knn) 224 | 225 | # compute noise 226 | noise = Tchol(getnoise(jGP) + ϵ) 227 | 228 | # add noise 229 | σ = @view F[diagind(F)] 230 | map!(i -> i+noise, σ, σ) 231 | 232 | # j isa larger than main region 233 | minJ = minimum(jNode.obs) 234 | maxJ = maximum(jNode.obs) 235 | 236 | minM = minimum(mainNode.obs) 237 | maxM = maximum(mainNode.obs) 238 | 239 | @assert minJ >= minM 240 | @assert maxJ >= maxM 241 | 242 | s = minJ == minM ? 1 : findfirst(mainNode.obs .== minJ) 243 | e = mainNode.nobs 244 | 245 | idx = collect(s:e) 246 | 247 | @inbounds s1 = jNode.obs[1:findfirst(jNode.obs .== maxM)] 248 | @inbounds s2 = mainNode.obs[idx] 249 | @inbounds toupdate = setdiff(mainNode.obs[1:e], jNode.obs[1:findfirst(jNode.obs .== maxM)]) 250 | 251 | if (length(s1) != length(s2)) && (minJ == minM) 252 | update_cholesky!(jGP) 253 | else 254 | 255 | # only do low-rank updates if sufficiently stable 256 | if (length(toupdate) / jNode.nobs) < τ 257 | CC = copy(mainGP.cK.factors) 258 | d = size(CC,1) 259 | 260 | for n in toupdate 261 | @inbounds begin 262 | i = findfirst(mainNode.obs .== n) 263 | AdvancedCholesky.lowrankupdate!(CC, 264 | view(CC,i,(i+1):d), 265 | (i+1), 266 | mainGP.cK.uplo) 267 | end 268 | end 269 | 270 | reverse!(toupdate) 271 | for n in toupdate 272 | i = findfirst(mainNode.obs[idx] .== n) 273 | !isnothing(i) && deleteat!(idx,i) 274 | end 275 | 276 | @inbounds F[1:length(s1), 1:length(s1)] = CC[idx,idx] 277 | 278 | _,info = AdvancedCholesky.chol_continue!(F, length(s1)+1) 279 | 280 | check = all(diag(F) .>= 0.0) 281 | 282 | if (info == 0) && check 283 | @inbounds jGP.α[:] = jGP.cK.L' \ (jGP.cK.L \ jGP.y) 284 | else 285 | update_cholesky!(jGP) 286 | end 287 | else 288 | # solve Cholesky 289 | update_cholesky!(jGP) 290 | end 291 | end 292 | end 293 | 294 | function fit_naive!(spn::Union{GPSplitNode,GPSumNode}) 295 | 296 | gpmapping = getLeaves(spn) 297 | gpids = collect(keys(gpmapping)) 298 | K = length(gpids) 299 | ttotal = @elapsed for gpid in gpids 300 | update_cholesky!(gpmapping[gpid].dist) 301 | end 302 | # @info "[fit_naive!] 
finished with $ttotal sec taken for Cholesky decompositions"
303 | ttotal
304 | end
305 | 
306 | function updategradients!(spn::Union{GPSumNode, GPSplitNode})
307 | leaves = getLeaves(spn)
308 | Threads.@threads for leaf in leaves
309 | updategradients!(leaf.dist)
310 | end
311 | end
312 | 
313 | function distancematrix(spn, kernel::IsoKernel, x::AbstractMatrix)
314 | N = length(gety(spn))
315 | Ix = map(n -> Vector{Int}(), 1:N)
316 | blockindecies(spn, Ix)
317 | V = map(i -> vec(getdistancematrix(kernel, reshape(x[i,:], 1, :), x[Ix[i],:])), 1:N)
318 | return SDiagonal(Ix, V)
319 | end
320 | 
321 | function distancematrix(spn, kernel::ArdKernel, x::AbstractMatrix)
322 | N = length(gety(spn))
323 | Ix = map(n -> Vector{Int}(), 1:N)
324 | blockindecies(spn, Ix)
325 | V = map(i -> dropdims(getdistancematrix(kernel, reshape(x[i,:], 1, :), x[Ix[i],:]), dims=1), 1:N)
326 | return SDiagonal(Ix, V)
327 | end
328 | 
329 | function updategradients!(spn::Union{GPSumNode, GPSplitNode},
330 | K::SDiagonal{Tp,2,<:AbstractVector},
331 | P::SDiagonal{Tp,2,<:AbstractVector},
332 | D::AbstractMatrix{T},
333 | gpmap) where {T,Tp}
334 | leaves = getLeaves(spn)
335 | nleaves = length(leaves)
336 | isprocessed = falses(nleaves)
337 | 
338 | S = Dict{Int, Int}()
339 | for j in 1:nleaves
340 | n = gpmap.x[leaves[j].id]
341 | m = argmax(D[:,n] .* D[n,:])
342 | if (D[n,m] * D[m,n]) == one(T)
343 | mid = gpmap.fx[m]
344 | S[j] = findfirst(l -> l.id == mid, leaves)
345 | end
346 | end
347 | 
348 | for (j, leaf) in enumerate(leaves)
349 | if haskey(S, j) && isprocessed[S[j]]
350 | copygradients(leaf.dist, leaves[S[j]].dist)
351 | else
352 | ix = leaf.obs
353 | updategradients!(leaf.dist, @view(K[ix,ix]), @view(P[ix,ix]))
354 | end
355 | isprocessed[j] = true
356 | end
357 | end
358 | 
359 | function updategradients!(spn::Union{GPSumNode, GPSplitNode},
360 | K::SDiagonal{Tp,2,<:AbstractVector},
361 | P::SDiagonal{Tp,3,<:AbstractArray},
362 | D::AbstractMatrix{T},
363 | gpmap) where {T,Tp}
364 | leaves = getLeaves(spn)
365 | nleaves = length(leaves)
366 | isprocessed = falses(nleaves)
367 | 
368 | S = Dict{Int, Int}()
369 | for j in 1:nleaves
370 | n = gpmap.x[leaves[j].id]
371 | m = argmax(D[:,n] .* D[n,:])
372 | if (D[n,m] * D[m,n]) == one(T)
373 | mid = gpmap.fx[m]
374 | S[j] = findfirst(l -> l.id == mid, leaves)
375 | end
376 | end
377 | 
378 | for (j, leaf) in enumerate(leaves)
379 | if haskey(S, j) && isprocessed[S[j]]
380 | copygradients(leaf.dist, leaves[S[j]].dist)
381 | else
382 | ix = leaf.obs
383 | updategradients!(leaf.dist, @view(K[ix,ix]), @view(P[ix,ix,:]))
384 | end
385 | isprocessed[j] = true
386 | end
387 | end
388 | 
-------------------------------------------------------------------------------- /src/treeStructure.jl:
--------------------------------------------------------------------------------
1 | export build
2 | export buildDSMGP, buildPoE, buildBCM
3 | 
4 | function buildTree(X::AbstractMatrix, y::AbstractVector, config::DSMGPConfig)
5 | 
6 | N,D = size(X)
7 | @assert N == length(y)
8 | 
9 | lowerBound = ones(D) * -Inf
10 | upperBound = ones(D) * Inf
11 | 
12 | observations = collect(1:N)
13 | 
14 | @assert all(isfinite.(X))
15 | 
16 | if config.sumRoot
17 | _buildSum(X, y, lowerBound, upperBound, config, 0, observations, N)
18 | else
19 | _buildSplit(X, y, lowerBound, upperBound, config, 0, observations, N)
20 | end
21 | end
22 | 
23 | function getSplits(X::AbstractMatrix{T},
24 | lowerBound::Vector{Float64},
25 | 
upperBound::Vector{Float64},
26 | minData::Int,
27 | ϵ::Float64,
28 | K::Int,
29 | d::Int;
30 | depth = 1) where {T<:Real}
31 | α = β = 2.0
32 | 
33 | K_ = depth^2
34 | s = Vector{Float64}()
35 | 
36 | l = max(lowerBound[d], minimum(X[:,d]))
37 | u = min(upperBound[d], maximum(X[:,d]))
38 | v = u-l
39 | 
40 | idx = findall((X[:,d] .> l) .& (X[:,d] .<= u))
41 | if length(idx) > minData*2
42 | s_new = Float64(mean(X[:,d]))
43 | 
44 | z1 = 0
45 | z2 = 0
46 | 
47 | c = 0
49 | m = median(X[idx,d])
50 | 
51 | while ((z1 == 0) || (z2 == 0))
52 | a = rand(Beta(α, β))*v + l
53 | 
54 | s_new = Float64(ϵ*a + (1-ϵ)*m)
55 | 
56 | z1 = sum(X[idx,d] .<= s_new)
57 | z2 = sum(X[idx,d] .> s_new)
58 | 
59 | c += 1
60 | 
61 | if c > 100
62 | @warn z1, z2, s_new, m, a
63 | return s
64 | end
65 | end
66 | 
67 | zi = rand(1:2)
68 | if zi == 1
69 | if (z1 > minData) && (K_ < K)
70 | ub = copy(upperBound)
71 | ub[d] = s_new
72 | append!(s, getSplits(X,
73 | lowerBound,
74 | ub,
75 | minData,
76 | ϵ,
77 | K,
78 | d,
79 | depth = depth+1
80 | ))
81 | K_ += 1
82 | end
83 | if (z2 > minData) && (K_ < K)
84 | lb = copy(lowerBound)
85 | lb[d] = s_new
86 | append!(s, getSplits(X,
87 | lb,
88 | upperBound,
89 | minData,
90 | ϵ,
91 | K,
92 | d,
93 | depth = depth+1
94 | ))
95 | end
96 | else
97 | if (z2 > minData) && (K_ < K)
98 | lb = copy(lowerBound)
99 | lb[d] = s_new
100 | append!(s, getSplits(X,
101 | lb,
102 | upperBound,
103 | minData,
104 | ϵ,
105 | K,
106 | d,
107 | depth = depth+1
108 | ))
109 | K_ += 1
110 | end
111 | if (z1 > minData) && (K_ < K)
112 | ub = copy(upperBound)
113 | ub[d] = s_new
114 | append!(s, getSplits(X,
115 | lowerBound,
116 | ub,
117 | minData,
118 | ϵ,
119 | K,
120 | d,
121 | depth = depth+1
122 | ))
123 | end
124 | end
125 | 
126 | push!(s, s_new)
127 | end
128 | return s
129 | end
130 | 
131 | function _buildSplit(
132 | X::AbstractMatrix,
133 | y::AbstractVector,
134 | lowerBound::Vector{Float64},
135 | upperBound::Vector{Float64},
136 | config::DSMGPConfig,
137 | depth::Int,
138 | observations::Vector{Int},
139 | N::Int;
140 | d = 1
141 | )
142 | 
143 | @assert all(isfinite.(X))
144 | 
150 | s = getSplits(X,
151 | lowerBound,
152 | upperBound,
153 | config.minData,
154 | config.bnoise,
155 | config.K,
156 | d)
157 | sort!(s)
158 | 
159 | split = Vector{Tuple{Int, Float64}}()
160 | if !isempty(s)
161 | for si in s
162 | push!(split, (d, si))
163 | end
164 | end
165 | push!(split, (d, upperBound[d]))
166 | 
167 | node = GPSplitNode(gensym("split"), Vector{Node}(), Vector{SPNNode}(),
168 | lowerBound, upperBound, split)
169 | 
170 | 
171 | lb = copy(lowerBound)
172 | ub = copy(upperBound)
173 | 
174 | if !isempty(s)
175 | for spliti in split
176 | (_, si) = spliti
177 | lb_ = copy(lb)
178 | ub_ = copy(ub)
179 | ub_[d] = si
180 | 
181 | idx = findall((X[:,d] .> lb_[d]) .& (X[:,d] .<= ub_[d]))
182 | if (depth < config.depth) && (length(idx) > config.minData)
183 | if config.sumRoot
184 | add!(node,
185 | _buildSum(view(X,idx,:), view(y,idx), lb_, ub_, config, depth,
186 | observations[idx], N)
187 | )
188 | else
189 | add!(node,
190 | _buildSplit(view(X,idx,:), view(y,idx), lb_, ub_, config, depth,
191 | observations[idx], N)
192 | )
193 | end
194 | else
195 | add!(node, _buildGP(view(X,idx,:), view(y,idx), lb_, ub_, config,
196 | observations[idx], N))
197 | end
198 | lb[d] = si
199 | end
200 | return node
201 | else
202 | l = lowerBound[d]
203 | u = upperBound[d]
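# No admissible split was found for this region, so it collapses into a single GP leaf.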
204 | 205 | idx = findall((X[:,d] .> l) .& (X[:,d] .<= u)) 206 | 207 | return _buildGP(view(X,idx,:), view(y, idx), copy(lowerBound), copy(upperBound), 208 | config, observations[idx], N) 209 | end 210 | end 211 | 212 | function _buildSum( 213 | X::AbstractMatrix, 214 | y::AbstractVector, 215 | lowerBound::Vector{Float64}, 216 | upperBound::Vector{Float64}, 217 | config::DSMGPConfig, 218 | depth::Int, 219 | observations::Vector{Int}, 220 | N::Int; 221 | d = 1 222 | ) 223 | @assert all(isfinite.(X)) 224 | 225 | V = config.V 226 | w = fill(-log(V), V) 227 | node = GPSumNode{Float64,SPNNode}(gensym("sum"), 228 | Vector{Node}(), 229 | Vector{SPNNode}(), 230 | Vector{Float64}()) 231 | 232 | dims = collect(1:size(X,2)) 233 | ϕ = map(d -> maximum(X[:,d])-minimum(X[:,d]), dims) 234 | ϕ ./= sum(ϕ) 235 | for v = 1:V 236 | d = rand(Categorical(ϕ)) 237 | add!(node, 238 | _buildSplit(X, y, lowerBound, upperBound, config, depth+1, 239 | observations, N, d = d), w[v] 240 | ) 241 | end 242 | return node 243 | end 244 | 245 | function _buildGP(X::AbstractMatrix, 246 | y::AbstractVector, 247 | lowerBound::Vector{Float64}, 248 | upperBound::Vector{Float64}, 249 | config::DSMGPConfig, 250 | observations::Vector{Int}, 251 | N::Int 252 | ) 253 | myobs = falses(N) 254 | myobs[observations] .= true 255 | 256 | @assert size(X,1) == sum(myobs) 257 | 258 | if config.kernels isa Vector 259 | 260 | w = rand(Dirichlet(length(config.kernels), 1.0)) 261 | node = GPSumNode{Float64,GPNode}(gensym("sum"), 262 | Vector{Node}(), 263 | Vector{GPNode}(), 264 | Vector{Float64}()) 265 | 266 | for v in 1:length(config.kernels) 267 | kern = deepcopy(config.kernels[v]) 268 | obsNoise = copy(config.observationNoise) 269 | 270 | # create a full GP 271 | mfun = config.meanFun == nothing ? ConstMean(mean(y)) : config.meanFun 272 | gp = GaussianProcess(X, y, kernel = kern, mean = mfun, 273 | logNoise = obsNoise, run_cholesky = false) 274 | 275 | add!(node, GPNode(gensym("GP"), 276 | Vector{Node}(), 277 | gp, 278 | myobs, 279 | observations, 280 | lowerBound, 281 | upperBound, 282 | sum(myobs), 283 | v 284 | ), log(w[v])) 285 | end 286 | return node 287 | else 288 | kern = deepcopy(config.kernels) 289 | obsNoise = copy(config.observationNoise) 290 | 291 | # create a full GP 292 | mfun = config.meanFun == nothing ? ConstMean(mean(y)) : config.meanFun 293 | gp = GaussianProcess(X, y, kernel = kern, mean = mfun, 294 | logNoise = obsNoise, run_cholesky = false) 295 | 296 | return GPNode(gensym("GP"), 297 | Vector{Node}(), 298 | gp, 299 | myobs, 300 | observations, 301 | lowerBound, 302 | upperBound, 303 | sum(myobs), 304 | 1 305 | ) 306 | end 307 | end 308 | 309 | """ 310 | buildDSMGP(x,y,K,V; ϵ=0.5, M=30, D=2, kernel=IsoSE(1.0,1.0), meanFun=nothing, logNoise=1.0, sum=true) 311 | 312 | Build a deep structured mixture of Gaussian processes (DSMGP). 
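
For example (a usage sketch; the hyperparameter values are illustrative and `nonstationary` is the toy-data generator exported by this package):

    x, y, _ = nonstationary(500)
    model = buildDSMGP(x, y, 2, 4; M = 30, D = 2)
    μ, σ² = predict(model, x)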
313 | 
314 | Arguments:
315 | 
316 | * x: Observed input data (Matrix)
317 | * y: Observed output data (Vector)
318 | * K: Number of children under each sum node
319 | * V: Number of splits at each split node
320 | * ϵ: Split position noise parameter (higher means less data-driven splits)
321 | * M: Minimum number of observations per expert
322 | * D: Maximum depth of model
323 | * kernel: Kernel function
324 | * meanFun: Mean function (if nothing use independent ConstMean for each expert)
325 | * logNoise: Log of the likelihood noise variance parameter
326 | * sum: Use a sum node as the root of the model
327 | """
328 | function buildDSMGP(x,y,K::Int,V::Int;
329 | ϵ = 0.5,
330 | M = 30,
331 | D = 2,
332 | kernel = IsoSE(1.0, 1.0),
333 | meanFun = nothing,
334 | logNoise = 1.0,
335 | sum = true
336 | )
337 | m,D,gpmap = build(x,y,K,V,ϵ,M,D,kernel,meanFun,logNoise,sum)
338 | return DSMGP(m, D, gpmap)
339 | end
340 | 
341 | """
342 | buildPoE(x,y,V; ϵ=0.0, M=30, D=2, kernel=IsoSE(1.0,1.0), meanFun=nothing, logNoise=1.0, generalized=false)
343 | 
344 | Build a (generalized) Product-of-Experts model.
345 | 
346 | Arguments:
347 | 
348 | * x: Observed input data (Matrix)
349 | * y: Observed output data (Vector)
350 | * V: Number of splits at each split node
351 | * ϵ: Split position noise parameter (higher means less data-driven splits)
352 | * M: Minimum number of observations per expert
353 | * D: Maximum depth of model
354 | * kernel: Kernel function
355 | * meanFun: Mean function (if nothing use independent ConstMean for each expert)
356 | * logNoise: Log of the likelihood noise variance parameter
357 | * generalized: Use the generalized formulation (Deisenroth et al. 2015)
358 | 
359 | """
360 | function buildPoE(x,y,V::Int;
361 | ϵ = 0.0,
362 | M = 30,
363 | D = 2,
364 | kernel = IsoSE(1.0, 1.0),
365 | meanFun = nothing,
366 | logNoise = 1.0,
367 | generalized = false
368 | )
369 | m,D,gpmap = build(x,y,1,V,ϵ,M,D,kernel,meanFun,logNoise,false)
370 | return generalized ? gPoE(m, D, gpmap) : PoE(m, D, gpmap)
371 | end
372 | 
373 | """
374 | buildBCM(x,y,V; ϵ=0.0, M=30, D=2, kernel=IsoSE(1.0,1.0), meanFun=nothing, logNoise=1.0, robust=false)
375 | 
376 | Build a robust Bayesian Committee Machine.
377 | 
378 | Arguments:
379 | 
380 | * x: Observed input data (Matrix)
381 | * y: Observed output data (Vector)
382 | * V: Number of splits at each split node
383 | * ϵ: Split position noise parameter (higher means less data-driven splits)
384 | * M: Minimum number of observations per expert
385 | * D: Maximum depth of model
386 | * kernel: Kernel function
387 | * meanFun: Mean function (if nothing use independent ConstMean for each expert)
388 | * logNoise: Log of the likelihood noise variance parameter
389 | * robust: Use robust formulation (Deisenroth et al.
2015)
390 | 
391 | """
392 | function buildBCM(x,y,V::Int;
393 | ϵ = 0.0,
394 | M = 30,
395 | D = 2,
396 | kernel = IsoSE(1.0, 1.0),
397 | meanFun = nothing,
398 | logNoise = 1.0,
399 | robust = false
400 | )
401 | m,D,gpmap = build(x,y,1,V,ϵ,M,D,kernel,meanFun,logNoise,false)
402 | return rBCM(m, D, gpmap)
403 | end
404 | 
405 | function build(x, y, K::Int, V::Int, ϵ, M, depth, kernel, meanFun, logNoise, useSum)
406 | 
407 | # DSMGP with multiple independent GPs
408 | config = DSMGPConfig(
409 | meanFun,
410 | kernel, # kernel function / kernel functions
411 | logNoise, # log σ - noise
412 | M, # minimum number of observations per expert
413 | V, # config.K: number of splits per split node
414 | K, # config.V: number of children under a sum node
415 | depth, # maximum depth of the tree
416 | ϵ, # relative noise used to displace split positions
417 | useSum # use a sum node as root
418 | )
419 | spn = buildTree(x, y, config);
420 | 
421 | gpids = getLeafIds(spn)
422 | 
423 | x = Dict(id => i for (i, id) in enumerate(gpids))
424 | fx = Dict(i => id for (i, id) in enumerate(gpids))
425 | 
426 | gpmap = BiDict(x,fx)
427 | 
428 | D = zeros(Float64, length(gpids), length(gpids));
429 | 
430 | # update D
431 | getOverlap(spn, D, gpmap);
432 | 
433 | # fit model
434 | fit!(spn,D,gpmap)
435 | 
436 | return spn, D, gpmap
437 | end
438 | 
--------------------------------------------------------------------------------
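
A minimal end-to-end sketch tying the files above together (assuming the package is installed; the data size and tree settings are illustrative, and all functions used are the ones defined in this repository):

    using DeepStructuredMixtures

    # toy 1-D data with input-dependent noise
    x, y, _ = nonstationary(1000)

    # DSMGP: 2 children per sum node, 4 splits per split node
    dsmgp = buildDSMGP(x, y, 2, 4; M = 50, D = 3)

    # Product-of-Experts baseline over the same kind of partitioning
    poe = buildPoE(x, y, 4; M = 50, D = 3)

    # posterior means and variances at the training inputs
    μd, _ = predict(dsmgp, x)
    μp, _ = predict(poe, x)

    # compare mean absolute errors
    @show mae(y, μd) mae(y, μp)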