├── .github └── workflows │ └── run-tests.yaml ├── .gitignore ├── .project ├── .settings ├── ca.ubc.stat.blang.BlangDsl.prefs ├── org.eclipse.emf.ecore.xcore.Xcore.prefs ├── org.eclipse.jdt.core.prefs ├── org.eclipse.xtend.core.Xtend.prefs └── org.eclipse.xtext.java.Java.prefs ├── .travis.yml ├── LICENSE.txt ├── README.md ├── build.gradle ├── doc ├── blang.js ├── build.sh ├── deploy.sh ├── download-deps.sh ├── eclipse-release-assembly │ ├── assemble.sh │ └── plain-eclipse │ │ └── download-eclipse-xtext.sh ├── www │ ├── GitHub-logo.png │ ├── ide.jpg │ ├── jumbotron-narrow.css │ └── jupiter.jpg └── xtend.js ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── jitpack.yml ├── settings.gradle ├── setup-cli.sh ├── setup-eclipse.sh └── src ├── main └── java │ └── blang │ ├── distributions │ ├── Bernoulli.bl │ ├── Beta.bl │ ├── BetaBinomial.bl │ ├── BetaNegativeBinomial.bl │ ├── Binomial.bl │ ├── Categorical.bl │ ├── ChiSquared.bl │ ├── ContinuousUniform.bl │ ├── Dirichlet.bl │ ├── DiscreteUniform.bl │ ├── Exponential.bl │ ├── F.bl │ ├── Gamma.bl │ ├── GammaMeanParam.bl │ ├── Generators.java │ ├── Geometric.bl │ ├── Gompertz.bl │ ├── Gumbel.bl │ ├── HalfStudentT.bl │ ├── HyperGeometric.bl │ ├── Laplace.bl │ ├── LnUniform.bl │ ├── LogLogistic.bl │ ├── LogPotential.bl │ ├── LogUniform.bl │ ├── Logistic.bl │ ├── MultivariateNormal.bl │ ├── NegativeBinomial.bl │ ├── NegativeBinomialMeanParam.bl │ ├── Normal.bl │ ├── NormalField.bl │ ├── Poisson.bl │ ├── SimplexUniform.bl │ ├── StudentT.bl │ ├── SymmetricDirichlet.bl │ ├── Weibull.bl │ ├── YuleSimon.bl │ └── internals │ │ └── Helpers.java │ ├── engines │ ├── AdaptiveJarzynski.java │ ├── ParallelTempering.java │ └── internals │ │ ├── CovarAccumulator.java │ │ ├── EngineStaticUtils.java │ │ ├── LogSumAccumulator.java │ │ ├── PosteriorInferenceEngine.java │ │ ├── Spline.java │ │ ├── SplineDerivatives.xtend │ │ ├── factories │ │ ├── AIS.java │ │ ├── Exact.java │ │ ├── 
Forward.java │ │ ├── IAIS.xtend │ │ ├── ISCM.xtend │ │ ├── MCMC.java │ │ ├── None.java │ │ ├── PT.java │ │ ├── Pigeons.java │ │ └── SCM.java │ │ ├── ladders │ │ ├── EquallySpaced.java │ │ ├── FromAnotherExec.java │ │ ├── Geometric.java │ │ ├── Polynomial.java │ │ ├── TemperatureLadder.java │ │ └── UserSpecified.java │ │ ├── ptanalysis │ │ ├── PathViz.xtend │ │ └── Paths.xtend │ │ └── schedules │ │ ├── AdaptiveTemperatureSchedule.java │ │ ├── FixedTemperatureSchedule.java │ │ ├── TemperatureSchedule.java │ │ └── UserSpecified.java │ ├── io │ ├── BlangTidySerializer.xtend │ ├── DataSource.xtend │ ├── GlobalDataSource.xtend │ ├── NA.xtend │ ├── Parsers.xtend │ └── internals │ │ ├── CSV.xtend │ │ ├── DataSourceReader.xtend │ │ └── GlobalDataSourceStore.xtend │ ├── mcmc │ ├── CategoricalSampler.xtend │ ├── ConnectedFactor.java │ ├── EllipticalSliceSampler.xtend │ ├── IntSliceSampler.java │ ├── MHSampler.java │ ├── RealSliceSampler.java │ ├── SampledVariable.java │ ├── Sampler.java │ ├── Samplers.java │ ├── SimplexSampler.xtend │ ├── UniformSampler.xtend │ └── internals │ │ ├── BuiltSamplers.java │ │ ├── Callback.java │ │ ├── ExponentiatedFactor.java │ │ ├── SamplerBuilder.java │ │ ├── SamplerBuilderContext.java │ │ ├── SamplerBuilderOptions.java │ │ ├── SamplerMatch.java │ │ ├── SamplerMatchingUtils.java │ │ ├── SamplerSet.java │ │ ├── SimplexWritableVariable.xtend │ │ └── bps │ │ ├── Likelihood2EnergyAdaptor.java │ │ └── RealVar2MutableDouble.java │ ├── runtime │ ├── Observations.xtend │ ├── PostProcessor.xtend │ ├── Runner.xtend │ ├── SampledModel.java │ └── internals │ │ ├── ComputeESS.java │ │ ├── CreateBlangGradleProject.java │ │ ├── DefaultPostProcessor.xtend │ │ ├── Main.xtend │ │ ├── RecursiveAnnotationProducer.java │ │ ├── StandaloneCompiler.java │ │ ├── doc │ │ ├── Categories.java │ │ ├── MakeHTMLDoc.xtend │ │ └── contents │ │ │ ├── BlangCLI.xtend │ │ │ ├── BlangIDE.xtend │ │ │ ├── BlangWeb.xtend │ │ │ ├── BuiltInDistributions.xtend │ │ │ ├── 
BuiltInFunctions.xtend │ │ │ ├── BuiltInRandomVariables.xtend │ │ │ ├── CreatingTypes.xtend │ │ │ ├── Empty.xtend │ │ │ ├── Examples.xtend │ │ │ ├── GettingStarted.xtend │ │ │ ├── Home.xtend │ │ │ ├── InferenceAndRuntime.xtend │ │ │ ├── InputOutput.xtend │ │ │ ├── Javadoc.xtend │ │ │ ├── Syntax.xtend │ │ │ └── Testing.xtend │ │ └── objectgraph │ │ ├── AccessibilityGraph.java │ │ ├── AnnealingStructure.java │ │ ├── ArrayConstituentNode.java │ │ ├── ArrayView.java │ │ ├── ConstituentNode.java │ │ ├── DeepCloner.java │ │ ├── DoubleArrayView.java │ │ ├── ExplorationRule.java │ │ ├── ExplorationRules.java │ │ ├── FieldConstituentNode.java │ │ ├── GraphAnalysis.java │ │ ├── IntArrayView.java │ │ ├── MapConstituentNode.java │ │ ├── MatrixConstituentNode.java │ │ ├── Node.java │ │ ├── ObjectArrayView.java │ │ ├── ObjectNode.java │ │ ├── SkipDependency.java │ │ ├── SkippedFieldConstituentNode.java │ │ ├── StaticUtils.java │ │ ├── VariableUtils.java │ │ └── ViewedArray.java │ ├── types │ ├── AnnealingParameter.xtend │ ├── DenseSimplex.xtend │ ├── DenseTransitionMatrix.xtend │ ├── ExtensionUtils.xtend │ ├── Index.xtend │ ├── Plate.xtend │ ├── Plated.xtend │ ├── PlatedMatrix.xtend │ ├── Precision.java │ ├── Simplex.java │ ├── SpikedRealVar.xtend │ ├── StaticUtils.xtend │ ├── TransitionMatrix.java │ └── internals │ │ ├── ColumnName.xtend │ │ ├── Delegator.java │ │ ├── HashPlate.xtend │ │ ├── HashPlated.xtend │ │ ├── IndexedDataSource.xtend │ │ ├── IntScalar.xtend │ │ ├── InvalidParameter.java │ │ ├── LatentFactoryAsParser.xtend │ │ ├── Parser.xtend │ │ ├── PlatedSlice.xtend │ │ ├── Query.xtend │ │ ├── RealScalar.xtend │ │ ├── SimpleParser.xtend │ │ └── SimplePlate.xtend │ └── validation │ ├── DeterminismTest.java │ ├── DiscreteMCTest.java │ ├── ExactInvarianceTest.java │ ├── Instance.xtend │ ├── NormalizationTest.java │ ├── UnbiasnessTest.xtend │ └── internals │ ├── Helpers.xtend │ └── fixtures │ ├── AutoBoxDeboxTests.bl │ ├── BadNormal.bl │ ├── BadPlate.bl │ ├── 
BadRealSliceSampler.java │ ├── CustomAnnealRef.bl │ ├── CustomAnnealTest.bl │ ├── Cyclic.bl │ ├── Diffusion.bl │ ├── Doomsday.bl │ ├── DynamicNormalMixture.bl │ ├── Empty.bl │ ├── ExactHMMCalculations.java │ ├── Examples.xtend │ ├── FixedMatrix.bl │ ├── Functions.xtend │ ├── GenerateTwice.bl │ ├── Growth.bl │ ├── HierarchicalModel.bl │ ├── IfElse.bl │ ├── IntNaiveMHSampler.java │ ├── IntRealizationSquared.java │ ├── Ising.bl │ ├── LinRegression.bl │ ├── ListHash.java │ ├── MarkovChain.bl │ ├── MixtureModel.bl │ ├── Multimodal.bl │ ├── NoGen.bl │ ├── NormalFieldExamples.bl │ ├── NotNormalForm.bl │ ├── Operations.bl │ ├── PCR.bl │ ├── PlatedMatrixTests.bl │ ├── PoissonAllInOne.bl │ ├── PoissonNormalField.bl │ ├── RealNaiveMHSampler.java │ ├── RealRealizationSquared.java │ ├── Scalability.bl │ ├── Simple.bl │ ├── SimpleHierarchicalModel.bl │ ├── SmallHMM.bl │ ├── SometimesNaN.bl │ ├── SpikeAndSlab.bl │ ├── SpikedGLM.bl │ ├── Unid.bl │ ├── UnspecifiedParam.bl │ └── VectorHash.java └── test ├── java └── blang │ ├── TestCloning.xtend │ ├── TestDiscreteModels.xtend │ ├── TestDocumentation.xtend │ ├── TestESS.xtend │ ├── TestEndToEnd.xtend │ ├── TestExactTest.xtend │ ├── TestFixedMatrix.xtend │ ├── TestLadders.java │ ├── TestMoments.java │ ├── TestRunner.java │ ├── TestSDKDistributions.xtend │ ├── TestSDKNormalizations.java │ ├── TestSMCUnbiasness.xtend │ ├── TestSparseDirichletAndBetaWarnings.xtend │ ├── TestStandaloneCompiler.xtend │ ├── TestSyntax.java │ └── runtime │ └── TestSampledModel.java └── resource └── data.csv /.github/workflows/run-tests.yaml: -------------------------------------------------------------------------------- 1 | name: Java CI 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - uses: actions/checkout@v3 11 | - name: Set up JDK 11 12 | uses: actions/setup-java@v3 13 | with: 14 | java-version: '11' 15 | distribution: 'adopt' 16 | - name: Validate Gradle wrapper 17 | uses: gradle/wrapper-validation-action@v1 
18 | - name: Assemble 19 | uses: gradle/gradle-build-action@v2 20 | with: 21 | arguments: installDist 22 | - name: Build with Gradle 23 | uses: gradle/gradle-build-action@v2 24 | with: 25 | arguments: test -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .gradle 2 | .classpath 3 | bin/ 4 | .DS_Store 5 | build 6 | results 7 | java.hprof.txt 8 | src/main/xtend-gen/blang/examples/.gitignore 9 | src/main/xtend-gen/blang/mcmc/.gitignore 10 | src/main/xtend-gen/blang/runtime/.gitignore 11 | src/main/xtend-gen/blang/types/.gitignore 12 | src/main/xtend-gen/blang/utils/.gitignore 13 | samples 14 | failed-test-info-* 15 | doc/www/ace 16 | doc/ace-master 17 | doc/bootstrap 18 | doc/www/dist 19 | Home.html 20 | Quick_Start.html 21 | data.csv 22 | Reference.html 23 | doc/eclipse-release-assembly/plain-eclipse/Eclipse.app 24 | doc/eclipse-release-assembly/blang 25 | doc/www/downloads 26 | Getting_started.html 27 | doc/www/Blang_IDE.html 28 | doc/www/Blang_in_browser.html 29 | doc/www/Blang_via_web.html 30 | logNormEstimate.txt 31 | runningTimeSummary.tsv 32 | index.html 33 | Useful_types.html 34 | Built-in_random_variables.html 35 | Random_variables.html 36 | Reading_from_command_line.html 37 | Syntax_reference.html 38 | doc/www/Inference_and_runtime.html 39 | doc/www/Template.html 40 | Creating_random_types.html 41 | Input_and_output.html 42 | Testing_Blang_models.html 43 | Distributions.html 44 | Functions.html 45 | Examples.html 46 | CLI.html 47 | doc/www/javadoc-dsl 48 | doc/www/javadoc-inits 49 | doc/www/javadoc-sdk 50 | doc/www/javadoc-xlinear 51 | Javadoc.html 52 | data/ 53 | *.swp 54 | lin-reg-data.csv 55 | -------------------------------------------------------------------------------- /.project: -------------------------------------------------------------------------------- 1 | 2 | 3 | blang 4 | 5 | 6 | 7 | org.eclipse.jdt.core.javanature 
8 | org.eclipse.xtext.ui.shared.xtextNature 9 | 10 | 11 | 12 | org.eclipse.jdt.core.javabuilder 13 | 14 | 15 | 16 | org.eclipse.xtext.ui.shared.xtextBuilder 17 | 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /.settings/ca.ubc.stat.blang.BlangDsl.prefs: -------------------------------------------------------------------------------- 1 | //outlet.DEFAULT_OUTPUT.sourceFolder.src/main/java.directory=build/blang/main 2 | //outlet.DEFAULT_OUTPUT.sourceFolder.src/main/resources.directory=build/blang/main 3 | //outlet.DEFAULT_OUTPUT.sourceFolder.src/test/java.directory=build/blang/test 4 | //outlet.DEFAULT_OUTPUT.sourceFolder.src/test/resources.directory=build/blang/test 5 | BuilderConfiguration.is_project_specific=true 6 | ValidatorConfiguration.is_project_specific=true 7 | eclipse.preferences.version=1 8 | generateGeneratedAnnotation=false 9 | generateSuppressWarnings=true 10 | includeDateInGenerated=false 11 | outlet.DEFAULT_OUTPUT.hideLocalSyntheticVariables=true 12 | outlet.DEFAULT_OUTPUT.installDslAsPrimarySource=false 13 | outlet.DEFAULT_OUTPUT.userOutputPerSourceFolder=true 14 | targetJavaVersion=Java8 15 | useJavaCompilerCompliance=true 16 | -------------------------------------------------------------------------------- /.settings/org.eclipse.emf.ecore.xcore.Xcore.prefs: -------------------------------------------------------------------------------- 1 | //outlet.DEFAULT_OUTPUT.sourceFolder.build/blang/main.directory= 2 | //outlet.DEFAULT_OUTPUT.sourceFolder.build/blang/main.ignore= 3 | //outlet.DEFAULT_OUTPUT.sourceFolder.build/xtend/main.directory= 4 | //outlet.DEFAULT_OUTPUT.sourceFolder.build/xtend/main.ignore= 5 | //outlet.DEFAULT_OUTPUT.sourceFolder.src/main/java.directory= 6 | //outlet.DEFAULT_OUTPUT.sourceFolder.src/main/java.ignore= 7 | //outlet.DEFAULT_OUTPUT.sourceFolder.src/main/xtend-gen.directory= 8 | //outlet.DEFAULT_OUTPUT.sourceFolder.src/main/xtend-gen.ignore= 9 | 
//outlet.DEFAULT_OUTPUT.sourceFolder.src/test/java.directory= 10 | //outlet.DEFAULT_OUTPUT.sourceFolder.src/test/java.ignore= 11 | BuilderConfiguration.is_project_specific=true 12 | autobuilding=true 13 | eclipse.preferences.version=1 14 | generateGeneratedAnnotation=false 15 | generateSuppressWarnings=true 16 | generatedAnnotationComment= 17 | includeDateInGenerated=false 18 | outlet.DEFAULT_OUTPUT.cleanDirectory=false 19 | outlet.DEFAULT_OUTPUT.cleanupDerived=true 20 | outlet.DEFAULT_OUTPUT.createDirectory=true 21 | outlet.DEFAULT_OUTPUT.derived=true 22 | outlet.DEFAULT_OUTPUT.directory=./src-gen 23 | outlet.DEFAULT_OUTPUT.hideLocalSyntheticVariables=true 24 | outlet.DEFAULT_OUTPUT.installDslAsPrimarySource=false 25 | outlet.DEFAULT_OUTPUT.keepLocalHistory=true 26 | outlet.DEFAULT_OUTPUT.override=true 27 | outlet.DEFAULT_OUTPUT.sourceFolder.xtend-gen.directory= 28 | outlet.DEFAULT_OUTPUT.sourceFolder.xtend-gen.ignore= 29 | outlet.DEFAULT_OUTPUT.userOutputPerSourceFolder= 30 | targetJavaVersion=JAVA5 31 | useJavaCompilerCompliance=false 32 | -------------------------------------------------------------------------------- /.settings/org.eclipse.jdt.core.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled 3 | org.eclipse.jdt.core.compiler.codegen.methodParameters=do not generate 4 | org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8 5 | org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve 6 | org.eclipse.jdt.core.compiler.compliance=1.8 7 | org.eclipse.jdt.core.compiler.debug.lineNumber=generate 8 | org.eclipse.jdt.core.compiler.debug.localVariable=generate 9 | org.eclipse.jdt.core.compiler.debug.sourceFile=generate 10 | org.eclipse.jdt.core.compiler.problem.assertIdentifier=error 11 | org.eclipse.jdt.core.compiler.problem.enumIdentifier=error 12 | org.eclipse.jdt.core.compiler.source=1.8 13 | 
-------------------------------------------------------------------------------- /.settings/org.eclipse.xtend.core.Xtend.prefs: -------------------------------------------------------------------------------- 1 | //outlet.DEFAULT_OUTPUT.sourceFolder.build/xtend/main.directory= 2 | //outlet.DEFAULT_OUTPUT.sourceFolder.build/xtend/main.ignore= 3 | //outlet.DEFAULT_OUTPUT.sourceFolder.build/xtend/test.directory= 4 | //outlet.DEFAULT_OUTPUT.sourceFolder.build/xtend/test.ignore= 5 | //outlet.DEFAULT_OUTPUT.sourceFolder.src/main/java.directory=build/xtend/main 6 | //outlet.DEFAULT_OUTPUT.sourceFolder.src/main/java.ignore= 7 | //outlet.DEFAULT_OUTPUT.sourceFolder.src/main/resources.directory=build/xtend/main 8 | //outlet.DEFAULT_OUTPUT.sourceFolder.src/test/java.directory=build/xtend/test 9 | //outlet.DEFAULT_OUTPUT.sourceFolder.src/test/java.ignore= 10 | //outlet.DEFAULT_OUTPUT.sourceFolder.src/test/resources.directory=build/xtend/test 11 | BuilderConfiguration.is_project_specific=true 12 | ValidatorConfiguration.is_project_specific=true 13 | autobuilding=true 14 | eclipse.preferences.version=1 15 | generateGeneratedAnnotation=false 16 | generateSuppressWarnings=true 17 | generatedAnnotationComment= 18 | includeDateInGenerated=false 19 | org.eclipse.xtend.core.Xtend.useProjectSettings=true 20 | outlet.DEFAULT_OUTPUT.cleanDirectory=false 21 | outlet.DEFAULT_OUTPUT.cleanupDerived=true 22 | outlet.DEFAULT_OUTPUT.createDirectory=true 23 | outlet.DEFAULT_OUTPUT.derived=true 24 | outlet.DEFAULT_OUTPUT.directory=xtend-gen 25 | outlet.DEFAULT_OUTPUT.hideLocalSyntheticVariables=true 26 | outlet.DEFAULT_OUTPUT.installDslAsPrimarySource=false 27 | outlet.DEFAULT_OUTPUT.keepLocalHistory=false 28 | outlet.DEFAULT_OUTPUT.override=true 29 | outlet.DEFAULT_OUTPUT.sourceFolder.xtend-gen.directory= 30 | outlet.DEFAULT_OUTPUT.sourceFolder.xtend-gen.ignore= 31 | outlet.DEFAULT_OUTPUT.userOutputPerSourceFolder=true 32 | targetJavaVersion=JAVA8 33 | useJavaCompilerCompliance=false 34 | 
-------------------------------------------------------------------------------- /.settings/org.eclipse.xtext.java.Java.prefs: -------------------------------------------------------------------------------- 1 | #Tue Mar 21 11:54:13 GMT 2017 2 | includeDateInGenerated=false 3 | BuilderConfiguration.is_project_specific=true 4 | eclipse.preferences.version=1 5 | generateGeneratedAnnotation=false 6 | useJavaCompilerCompliance=false 7 | generateSuppressWarnings=true 8 | ValidatorConfiguration.is_project_specific=true 9 | targetJavaVersion=Java8 10 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: java 2 | script: 3 | - ./gradlew check -i --stacktrace 4 | jdk: 5 | - openjdk15 6 | - openjdk13 7 | - openjdk11 8 | - openjdk8 9 | install: travis_wait 30 ./gradlew installDist -i --stacktrace 10 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016, The Blang Development Team 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Summary 2 | ------- 3 | 4 | **Prospective/current users**: please visit the [project web page](https://www.stat.ubc.ca/~bouchard/blang/index.html) for more information as well as our [JSS paper](https://www.jstatsoft.org/article/view/v103i11). 5 | 6 | **Docker image**: [link to documentation](https://github.com/UBC-Stat-ML/containers/tree/main/blang) 7 | 8 | **New feature:** Blang can now run on 1000s of machines using MPI via [the Pigeons-Blang bridge](https://julia-tempering.github.io/Pigeons.jl/dev/reference/#Pigeons.BlangTarget) 9 | 10 | **Blang developers**: in addition to the above resources, see also the [documentation repository](https://github.com/UBC-Stat-ML/blangDoc) for more information. 11 | 12 | This is one of the repositories hosting Blang's code. This one contains Blang's SDK (Software Development Kit), including: 13 | 14 | - Basic datatypes suitable for sampling. 15 | - Infrastructure to create new data types and distributions. 16 | - Inference algorithms for such datatypes, such as [Adaptive Non-Reversible Parallel Tempering](https://www.stat.ubc.ca/~bouchard/pub/Syed2019NRPT.pdf) and Sequential Change of Measure. 17 | - Standard probability distributions. 18 | - MCMC testing infrastructure. 
19 | - Runtime to perform static analysis to infer the factor graph and its sparsity patterns. 20 | - Automated post-processing facilities (MCMC diagnostic, trace/density/pmf/summaries generation, etc). 21 | 22 | See [this readme](https://github.com/UBC-Stat-ML/blangDoc/blob/master/README.md) for a roadmap of the other key repositories (language infrastructure, examples, supporting libraries, etc) 23 | 24 | **Citing Blang**: if you find Blang useful for your work, consider citing our [JSS paper](https://www.jstatsoft.org/article/view/v103i11): 25 | 26 | ``` 27 | Alexandre Bouchard-Côté, Kevin Chern, Davor Cubranic, Sahand Hosseini, Justin Hume, Matteo Lepur, Zihui Ouyang, Giorgio Sgarbi (2022) 28 | Journal of Statistical Software 103:1–98 29 | ``` 30 | -------------------------------------------------------------------------------- /doc/blang.js: -------------------------------------------------------------------------------- 1 | define(function(require, exports, module) { 2 | "use strict"; 3 | 4 | var oop = require("../lib/oop"); 5 | var mText = require("./text"); 6 | var mTextHighlightRules = require("./text_highlight_rules"); 7 | 8 | var HighlightRules = function() { 9 | var keywords = "as|case|catch|default|do|else|extends|extension|false|finally|for|generate|if|import|indicator|instanceof|is|laws|logf|model|new|null|package|param|random|return|static|super|switch|synchronized|throw|true|try|typeof|val|var|while"; 10 | this.$rules = { 11 | "start": [ 12 | {token: "comment", regex: "\\/\\/.*$"}, 13 | {token: "comment", regex: "\\/\\*", next : "comment"}, 14 | {token: "string", regex: '["](?:(?:\\\\.)|(?:[^"\\\\]))*?["]'}, 15 | {token: "string", regex: "['](?:(?:\\\\.)|(?:[^'\\\\]))*?[']"}, 16 | {token: "constant.numeric", regex: "[+-]?\\d+(?:(?:\\.\\d*)?(?:[eE][+-]?\\d+)?)?\\b"}, 17 | {token: "constant.numeric", regex: "0[xX][0-9a-fA-F]+\\b"}, 18 | {token: "lparen", regex: "[\\[({]"}, 19 | {token: "rparen", regex: "[\\])}]"}, 20 | {token: "keyword", regex: 
"\\b(?:" + keywords + ")\\b"} 21 | ], 22 | "comment": [ 23 | {token: "comment", regex: ".*?\\*\\/", next : "start"}, 24 | {token: "comment", regex: ".+"} 25 | ] 26 | }; 27 | }; 28 | oop.inherits(HighlightRules, mTextHighlightRules.TextHighlightRules); 29 | 30 | var Mode = function() { 31 | this.HighlightRules = HighlightRules; 32 | }; 33 | oop.inherits(Mode, mText.Mode); 34 | Mode.prototype.$id = "xtext/bl"; 35 | Mode.prototype.getCompletions = function(state, session, pos, prefix) { 36 | return []; 37 | } 38 | 39 | return { 40 | Mode: Mode 41 | }; 42 | 43 | 44 | }); 45 | -------------------------------------------------------------------------------- /doc/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## First, setup the ace hack 4 | 5 | # copy blang file 6 | cp blang.js ace-master/lib/ace/mode/ 7 | cp xtend.js ace-master/lib/ace/mode/ 8 | 9 | cd ace-master 10 | npm clean 11 | npm install 12 | node Makefile.dryice.js 13 | cd .. 14 | 15 | rm -rf www/ace 16 | cp -r ace-master/build/src/ www/ace 17 | 18 | 19 | ## Then, generate the actual documentation 20 | 21 | # Rebuild source 22 | cd .. 
23 | ./setup-cli.sh 24 | cd - 25 | 26 | # Run the document generator 27 | cd www 28 | java -cp ../../build/install/blang/lib/\* blang.runtime.internals.doc.MakeHTMLDoc 29 | cd - 30 | 31 | 32 | ##### Javadocs 33 | 34 | ## DSL 35 | 36 | cd ../../blangDSL/ca.ubc.stat.blang.parent 37 | ./gradlew assemble 38 | cd - 39 | 40 | rm -rf www/javadoc-dsl 41 | mv ../../blangDSL/ca.ubc.stat.blang.parent/ca.ubc.stat.blang/build/docs/javadoc www/javadoc-dsl 42 | 43 | 44 | 45 | ## xlinear 46 | 47 | cd ../../xlinear 48 | ./gradlew assemble 49 | cd - 50 | 51 | rm -rf www/javadoc-xlinear 52 | mv ../../xlinear/build/docs/javadoc www/javadoc-xlinear 53 | 54 | 55 | ## inits 56 | 57 | cd ../../inits 58 | ./gradlew assemble 59 | cd - 60 | 61 | rm -rf www/javadoc-inits 62 | mv ../../inits/build/docs/javadoc www/javadoc-inits 63 | 64 | 65 | ## SDK 66 | 67 | cd .. 68 | ./gradlew assemble 69 | cd - 70 | 71 | rm -rf www/javadoc-sdk 72 | mv ../build/docs/javadoc www/javadoc-sdk 73 | 74 | -------------------------------------------------------------------------------- /doc/deploy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | remote='s2:~/public_html/blang/' 4 | chmod -R 755 www 5 | rsync -t --rsh=/usr/bin/ssh --recursive --perms --group www/ $remote; echo "Finished pushing blang documentation site" & -------------------------------------------------------------------------------- /doc/download-deps.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ACE 4 | rm -rf ace-master 5 | git clone https://github.com/ajaxorg/ace.git 6 | mv ace ace-master 7 | cd ace-master 8 | git reset --hard c3403f1fbdf22cfff2cb1dda584b8e04467cd372 9 | cd - 10 | 11 | 12 | # Bootstrap 13 | rm -rf bootstrap 14 | rm -rf www/dist 15 | git clone https://github.com/twbs/bootstrap.git 16 | cd bootstrap 17 | git reset --hard v3.3.7 18 | cp -r dist ../www/ 19 | cd - 
-------------------------------------------------------------------------------- /doc/eclipse-release-assembly/assemble.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # nb: eclipse executable in Eclipse.app/Contents/MacOS 5 | # list features: ./eclipse -clean -purgeHistory -application org.eclipse.equinox.p2.director -noSplash -repository https://www.stat.ubc.ca/~bouchard/maven/blang-eclipse-plugin-latest/ -list 6 | 7 | CUR=`pwd` 8 | 9 | cd ../.. 10 | ./setup-cli.sh 11 | cd - 12 | 13 | 14 | blang_folder=blang 15 | rm -rf $blang_folder 16 | mkdir $blang_folder 17 | 18 | ### Setup eclipse 19 | 20 | cp -r plain-eclipse/Eclipse.app $blang_folder/BlangIDE.app 21 | 22 | $blang_folder/BlangIDE.app/Contents/MacOS/eclipse \ 23 | -clean -purgeHistory \ 24 | -application org.eclipse.equinox.p2.director \ 25 | -noSplash \ 26 | -repository https://www.stat.ubc.ca/~bouchard/maven/blang-eclipse-plugin-latest/ \ 27 | -installIUs ca.ubc.stat.blang.feature.feature.group 28 | 29 | sudo codesign --force --sign - $blang_folder/BlangIDE.app 30 | 31 | 32 | ### Setup blang-related projects in workspace 33 | 34 | cd $blang_folder 35 | mkdir workspace 36 | cd workspace 37 | 38 | create-blang-gradle-project --name blangExample --githubOrganization UBC-Stat-ML 39 | 40 | git clone https://github.com/UBC-Stat-ML/blangSDK.git 41 | 42 | 43 | ### Package things up into a zip 44 | 45 | cd $CUR 46 | 47 | zip -r $blang_folder $blang_folder 48 | mkdir ../www/downloads 49 | mv ${blang_folder}.zip ../www/downloads/blang-mac-latest.zip 50 | 51 | rm -rf $blang_folder 52 | -------------------------------------------------------------------------------- /doc/eclipse-release-assembly/plain-eclipse/download-eclipse-xtext.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget 
https://ftp.osuosl.org/pub/eclipse/technology/epp/downloads/release/2020-12/R/eclipse-dsl-2020-12-R-macosx-cocoa-x86_64.dmg 4 | 5 | echo "Unpack and put Eclipse.app in here" -------------------------------------------------------------------------------- /doc/www/GitHub-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UBC-Stat-ML/blangSDK/b8642c9c2a0adab8a5b6da96f2a7889f1b81b6cc/doc/www/GitHub-logo.png -------------------------------------------------------------------------------- /doc/www/ide.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UBC-Stat-ML/blangSDK/b8642c9c2a0adab8a5b6da96f2a7889f1b81b6cc/doc/www/ide.jpg -------------------------------------------------------------------------------- /doc/www/jumbotron-narrow.css: -------------------------------------------------------------------------------- 1 | /* Space out content a bit */ 2 | body { 3 | padding-top: 20px; 4 | padding-bottom: 20px; 5 | } 6 | 7 | /* Everything but the jumbotron gets side spacing for mobile first views */ 8 | .header, 9 | .marketing, 10 | .footer { 11 | padding-right: 15px; 12 | padding-left: 15px; 13 | } 14 | 15 | /* Custom page header */ 16 | .header { 17 | padding-bottom: 20px; 18 | border-bottom: 1px solid #e5e5e5; 19 | } 20 | /* Make the masthead heading the same height as the navigation */ 21 | .header h3 { 22 | margin-top: 0; 23 | margin-bottom: 0; 24 | line-height: 40px; 25 | } 26 | 27 | /* Custom page footer */ 28 | .footer { 29 | padding-top: 19px; 30 | color: #777; 31 | border-top: 1px solid #e5e5e5; 32 | } 33 | 34 | /* Customize container */ 35 | @media (min-width: 768px) { 36 | .container { 37 | max-width: 730px; 38 | } 39 | } 40 | .container-narrow > hr { 41 | margin: 30px 0; 42 | } 43 | 44 | /* Main marketing message and sign up button */ 45 | .jumbotron { 46 | text-align: center; 47 | border-bottom: 1px solid #e5e5e5; 48 
| } 49 | 50 | .jumbotron-bg { 51 | color: white; 52 | background-image:url('jupiter.jpg'); 53 | background-repeat: no-repeat; 54 | background-size: cover; 55 | } 56 | 57 | .jumbotron-bg h1 { 58 | color: white; 59 | text-shadow: 2px 2px 4px #000000; 60 | } 61 | 62 | .jumbotron span{ 63 | background: black; 64 | } 65 | 66 | .jumbotron-bg p{ 67 | color: white; 68 | text-shadow: 2px 2px 4px #000000; 69 | } 70 | 71 | .jumbotron .btn { 72 | padding: 14px 24px; 73 | font-size: 21px; 74 | } 75 | 76 | .marketing div { 77 | font-size: 18px; 78 | } 79 | 80 | 81 | /* Supporting marketing content */ 82 | .marketing { 83 | margin: 40px 0; 84 | } 85 | .marketing p + h4 { 86 | margin-top: 28px; 87 | } 88 | 89 | /* Responsive: Portrait tablets and up */ 90 | @media screen and (min-width: 768px) { 91 | /* Remove the padding we set earlier */ 92 | .header, 93 | .marketing, 94 | .footer { 95 | padding-right: 0; 96 | padding-left: 0; 97 | } 98 | /* Space out the masthead */ 99 | .header { 100 | margin-bottom: 30px; 101 | } 102 | /* Remove the bottom border on the jumbotron for visual effect */ 103 | .jumbotron { 104 | border-bottom: 0; 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /doc/www/jupiter.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UBC-Stat-ML/blangSDK/b8642c9c2a0adab8a5b6da96f2a7889f1b81b6cc/doc/www/jupiter.jpg -------------------------------------------------------------------------------- /doc/xtend.js: -------------------------------------------------------------------------------- 1 | define(function(require, exports, module) { 2 | "use strict"; 3 | 4 | var oop = require("../lib/oop"); 5 | var mText = require("./text"); 6 | var mTextHighlightRules = require("./text_highlight_rules"); 7 | 8 | var HighlightRules = function() { 9 | var keywords = 
"this|it|null|abstract|annotation|boolean|case|catch|char|class|create|def|default|do|double|enum|else|extends|extension|final|finally|float|for|if|implements|import|int|interface|long|new|override|package|private|protected|return|short|static|super|switch|throw|throws|try|typeof|val|var|void|while|FOR|ENDFOR|IF|ENDIF|ELSEIF|BEFORE|AFTER|SEPARATOR"; 10 | this.$rules = { 11 | "start": [ 12 | {token: "comment", regex: "\\/\\/.*$"}, 13 | {token: "comment", regex: "\\/\\*", next : "comment"}, 14 | {token: "string", regex: '["](?:(?:\\\\.)|(?:[^"\\\\]))*?["]'}, 15 | {token: "string", regex: "['](?:(?:\\\\.)|(?:[^'\\\\]))*?[']"}, 16 | {token: "constant.numeric", regex: "[+-]?\\d+(?:(?:\\.\\d*)?(?:[eE][+-]?\\d+)?)?\\b"}, 17 | {token: "constant.numeric", regex: "0[xX][0-9a-fA-F]+\\b"}, 18 | {token: "lparen", regex: "[\\[({]"}, 19 | {token: "rparen", regex: "[\\])}]"}, 20 | {token: "keyword", regex: "\\b(?:" + keywords + ")\\b"} 21 | ], 22 | "comment": [ 23 | {token: "comment", regex: ".*?\\*\\/", next : "start"}, 24 | {token: "comment", regex: ".+"} 25 | ] 26 | }; 27 | }; 28 | oop.inherits(HighlightRules, mTextHighlightRules.TextHighlightRules); 29 | 30 | var Mode = function() { 31 | this.HighlightRules = HighlightRules; 32 | }; 33 | oop.inherits(Mode, mText.Mode); 34 | Mode.prototype.$id = "xtend"; 35 | Mode.prototype.getCompletions = function(state, session, pos, prefix) { 36 | return []; 37 | } 38 | 39 | return { 40 | Mode: Mode 41 | }; 42 | 43 | 44 | }); 45 | -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UBC-Stat-ML/blangSDK/b8642c9c2a0adab8a5b6da96f2a7889f1b81b6cc/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: 
#!/bin/bash
#
# Builds the Blang command line interface with Gradle and, when the `blang`
# command is not already reachable through PATH, appends an `export PATH=...`
# line to the start-up file of the user's default shell (Zsh or Bash).
# Must be run from the root of the repository (it calls ./gradlew).

echo
echo " INSTALLING BLANG (COMMAND LINE INTERFACE)"
echo " This may take some time as dependencies"
echo " are being downloaded"
echo

# Some gradle-xtext-blang problems may be caused by the daemon trying to
# handle two Blang versions; avoid this by stopping the daemon after a Blang update.
./gradlew --stop || exit 1

./gradlew clean || exit 1
./gradlew installDist || exit 1

# Fix a problem arising if Eclipse is used jointly.
# -p: create missing parents and do not fail when the directory already exists.
mkdir -p build/xtend/test
mkdir -p build/blang/test

echo
echo " INSTALLATION WAS SUCCESSFUL"
echo " Type 'blang' to try it"
echo

if hash blang 2>/dev/null; then
  echo
else
  echo "NOTE: We are adding a line into ~/.bash_profile or ~/.zshenv to make the blang CLI command"
  echo " accessible from any directory (as blang is not found in PATH right now)."
  echo
  to_add="$(pwd)/build/install/blang/bin/"
  # Single quotes keep the literal text $PATH so that it is expanded when the
  # start-up file is sourced, not now.
  existing='$PATH'
  # The value is double-quoted inside the written line so that an install
  # directory containing spaces still yields a valid export statement.
  line="export PATH=\"${existing}:${to_add}\""
  export PATH="$PATH:${to_add}"
  shell_name="$(basename "$SHELL")"
  if [[ "$shell_name" == "zsh" ]]; then
    echo "$line" >>~/.zshenv
  elif [[ "$shell_name" == "bash" ]]; then
    echo "$line" >>~/.bash_profile
  else
    echo "Default shell is not Bash nor Zsh."
    # Was "$(to_add)", a command substitution that tried to run a command
    # named to_add instead of printing the directory.
    echo "Please add ${to_add} to PATH manually."
  fi
fi
/** Beta random variable on the open interval \((0, 1)\). */
model Beta {
  random RealVar realization

  /** Higher values bring the mean closer to one. \(\alpha > 0 \) */
  param RealVar alpha

  /** Higher values bring the mean closer to zero. \(\beta > 0 \) */
  param RealVar beta

  // The log-density
  //   lnGamma(alpha + beta) - lnGamma(alpha) - lnGamma(beta)
  //     + (alpha - 1) * log(x) + (beta - 1) * log(1 - x)
  // is written as one factor per additive term; every factor lists only the
  // variables it reads and returns NEGATIVE_INFINITY for inadmissible inputs.
  laws {
    // (alpha - 1) * log(x)
    logf(alpha, realization) {
      val outsideSupport = realization <= 0.0 || realization >= 1.0
      if (outsideSupport || alpha <= 0.0) return NEGATIVE_INFINITY
      // Only reached once the realization and alpha have passed the guards.
      Helpers::checkDirichletOrBetaParam(alpha)
      return (alpha - 1.0) * log(realization)
    }
    // (beta - 1) * log(1 - x); log1p(-x) is used for log(1 - x).
    logf(beta, realization) {
      val outsideSupport = realization <= 0.0 || realization >= 1.0
      if (outsideSupport || beta <= 0.0) return NEGATIVE_INFINITY
      Helpers::checkDirichletOrBetaParam(beta)
      return (beta - 1.0) * log1p(-realization)
    }
    // lnGamma(alpha + beta)
    logf(alpha, beta) {
      if (alpha <= 0.0 || beta <= 0.0) NEGATIVE_INFINITY
      else lnGamma(alpha + beta)
    }
    // -lnGamma(alpha)
    logf(alpha) {
      if (alpha <= 0.0) NEGATIVE_INFINITY
      else - lnGamma(alpha)
    }
    // -lnGamma(beta)
    logf(beta) {
      if (beta <= 0.0) NEGATIVE_INFINITY
      else - lnGamma(beta)
    }
  }

  // Forward simulation delegates to the generator's beta(alpha, beta).
  generate(rand) {
    rand.beta(alpha, beta)
  }
}
/** A sum of \(n\) iid Bernoulli variables, with a marginalized Beta prior on the success probability. Values in \(\{0, 1, 2, \dots, n\}\). */
model BetaBinomial {
  random IntVar realization

  /** The number \(n\) of Bernoulli variables being summed. \(n > 0\) */
  param IntVar numberOfTrials

  /** Higher values bring the mean closer to one. \(\alpha > 0 \) */
  param RealVar alpha

  /** Higher values bring the mean closer to zero. \(\beta > 0 \) */
  param RealVar beta

  laws {
    // Single factor holding the whole log-pmf:
    //   log C(n, k) + log B(k + alpha, n - k + beta) - log B(alpha, beta)
    // expanded into lnGamma terms.
    logf(realization, numberOfTrials, alpha, beta) {
      val invalidShape = alpha <= 0.0 || beta <= 0.0
      val invalidCount = realization < 0.0 || numberOfTrials <= 0.0 || realization > numberOfTrials
      if (invalidShape || invalidCount) return NEGATIVE_INFINITY
      // The terms are summed in the same order as before so rounding is unchanged.
      return lnGamma(numberOfTrials + 1)
        + lnGamma(realization + alpha)
        + lnGamma(numberOfTrials - realization + beta)
        + lnGamma(alpha + beta)
        - lnGamma(realization + 1)
        - lnGamma(numberOfTrials - realization + 1)
        - lnGamma(numberOfTrials + alpha + beta)
        - lnGamma(alpha)
        - lnGamma(beta)
    }
  }

  // Forward simulation delegates to the generator's betaBinomial(alpha, beta, n).
  generate (rand) {
    rand.betaBinomial(alpha, beta, numberOfTrials)
  }
}
/** A sum of \(n\) iid Bernoulli variables. Values in \(\{0, 1, 2, \dots, n\}\). */
model Binomial {
  random IntVar numberOfSuccesses

  /** The number \(n\) of Bernoulli variables being summed. \(n > 0\) */
  param IntVar numberOfTrials

  /** The parameter \(p \in [0, 1]\) shared by all the Bernoulli variables (probability that they be equal to 1). */
  param RealVar probabilityOfSuccess

  laws {
    // k * log(p) + (n - k) * log(1 - p)
    logf(numberOfSuccesses, numberOfTrials, probabilityOfSuccess) {
      val invalidProbability = probabilityOfSuccess < 0.0 || probabilityOfSuccess > 1.0
      val invalidCounts = numberOfSuccesses < 0 || numberOfTrials <= 0 || numberOfSuccesses > numberOfTrials
      if (invalidProbability || invalidCounts) return NEGATIVE_INFINITY
      return numberOfSuccesses * log(probabilityOfSuccess) + (numberOfTrials - numberOfSuccesses) * log(1.0 - probabilityOfSuccess)
    }
    // log C(n, k): the combinatorial term, which does not involve p.
    logf(numberOfTrials, numberOfSuccesses) {
      val invalidCounts = numberOfSuccesses < 0 || numberOfTrials <= 0 || numberOfSuccesses > numberOfTrials
      if (invalidCounts) NEGATIVE_INFINITY
      else logBinomial(numberOfTrials, numberOfSuccesses)
    }
  }

  // Forward simulation delegates to the generator's binomial(n, p).
  generate (rand) {
    rand.binomial(numberOfTrials, probabilityOfSuccess)
  }
}
/** The Dirichlet distribution over vectors of probabilities \((p_0, p_1, \dots, p_{n-1})\). \(p_i \in (0, 1), \sum_i p_i = 1.\) */
model Dirichlet {
  random Simplex realization

  /** Vector \((\alpha_0, \alpha_1, \dots, \alpha_{n-1})\) such that increasing the \(i\)th component increases the mean of entry \(p_i\). \(\alpha_i > 0\) */
  param Matrix concentrations

  laws {
    // Kernel: sum_i (alpha_i - 1) * log(p_i)
    logf(concentrations, realization) {
      var sum = 0.0
      for (int dim : 0 ..< concentrations.nEntries) {
        val concentration = concentrations.get(dim)
        // Concentrations must be strictly positive. Zero is rejected here (it
        // used to pass a `< 0.0` test), matching the `<= 0.0` guards of Beta.
        if (concentration <= 0.0) return NEGATIVE_INFINITY
        Helpers::checkDirichletOrBetaParam(concentration)
        sum += (concentration - 1.0) * log(realization.get(dim))
      }
      return sum
    }
    // Normalization: lnGamma(sum_i alpha_i) - sum_i lnGamma(alpha_i)
    logf(concentrations) {
      var sum = 0.0
      for (int dim : 0 ..< concentrations.nEntries) {
        val concentration = concentrations.get(dim)
        // Same strict-positivity guard, so lnGamma is never evaluated at zero.
        if (concentration <= 0.0) return NEGATIVE_INFINITY
        sum += - lnGamma(concentration)
      }
      return sum + lnGamma(concentrations.sum)
    }
    realization is Constrained
  }

  // Forward simulation writes the draw into `realization` in place.
  generate(rand) {
    rand.dirichletInPlace(concentrations, realization)
  }
}
/** The F-distribution. Also known as Fisher-Snedecor distribution. Values in \((0, +\infty) \) */
model F {
  random RealVar realization

  /** The degrees of freedom \( d_1 \) and \( d_2 \). \( d_1, d_2 > 0 \) */
  param RealVar d1, d2

  // Density:
  //   f(x) = d1^(d1/2) * d2^(d2/2) * x^(d1/2 - 1) * (d1 * x + d2)^(-(d1 + d2)/2) / B(d1/2, d2/2)
  // with log B(a, b) = lnGamma(a) + lnGamma(b) - lnGamma(a + b).
  // The log-density is split into the three factors below.
  laws {
    // Normalization: 0.5 * (d1 log d1 + d2 log d2) - log B(d1/2, d2/2).
    // The Beta function is in the denominator of the density, so its log is
    // subtracted. The previous version added it (inverted lnGamma signs),
    // which gave a wrong normalizing constant.
    logf(d1, d2) {
      if (d1 <= 0.0) return NEGATIVE_INFINITY
      if (d2 <= 0.0) return NEGATIVE_INFINITY
      return 0.5 * (d1*log(d1) + d2*log(d2)) - lnGamma(d1/2) - lnGamma(d2/2) + lnGamma((d1 + d2)/2)
    }

    // (d1/2 - 1) * log(x)
    logf(d1, realization) {
      if (d1 <= 0.0) return NEGATIVE_INFINITY
      if ((d1 == 1) && (realization <= 0.0)) return NEGATIVE_INFINITY
      if (realization < 0.0) return NEGATIVE_INFINITY
      return ( (d1 / 2) - 1.0) * log(realization)
    }

    // -(d1 + d2)/2 * log(d1 * x + d2)
    logf(d1, d2, realization) {
      if (d1 <= 0.0) return NEGATIVE_INFINITY
      if (d2 <= 0.0) return NEGATIVE_INFINITY
      if ((d1 == 1) && (realization <= 0.0)) return NEGATIVE_INFINITY
      if (realization < 0.0) return NEGATIVE_INFINITY
      return - 0.5 * (d1 + d2) * log( (d1 * realization) + d2)
    }
  }

  // Forward simulation delegates to the generator's fDist(d1, d2).
  generate(rand) {
    rand.fDist(d1, d2)
  }
}
/** Gamma random variable specified through its mean and variance rather than shape and rate. */
model GammaMeanParam {
  random RealVar realization
  param RealVar mean
  param RealVar variance

  laws {
    // Delegates to Gamma with shape = mean^2 / variance and rate = mean / variance.
    realization | mean, variance ~ Gamma(mean * mean / variance, mean / variance)
  }

  // Forward simulation uses the same shape/rate conversion as the law above.
  generate (rand) {
    val shape = mean * mean / variance
    val rate = mean / variance
    return rand.gamma(shape, rate)
  }
}
/** The Gompertz distribution. Values in \([0, \infty) \). */
model Gompertz {
  random RealVar realization

  /** The shape parameter \(\nu \). \(\nu > 0 \) */
  param RealVar shape

  /** The scale parameter \(b\). \(b > 0 \) */
  param RealVar scale

  laws {
    // Normalization: log(shape / scale)
    logf(shape, scale) {
      if (shape <= 0.0 || scale <= 0.0) NEGATIVE_INFINITY
      else log(shape / scale)
    }
    // x / scale - shape * (exp(x / scale) - 1)
    logf(realization, scale, shape) {
      val invalid = realization < 0.0 || scale <= 0.0 || shape <= 0.0
      if (invalid) return NEGATIVE_INFINITY
      val rescaled = realization / scale
      return rescaled - (shape * (exp(rescaled) - 1))
    }
  }

  // Forward simulation delegates to the generator's gompertz(shape, scale).
  generate(rand) {
    rand.gompertz(shape, scale)
  }
}
/** HalfStudentT random variable. Values in \((0, \infty)\) */
model HalfStudentT {
  random RealVar realization

  /** A degree of freedom parameter \(\nu\). \( \nu > 0 \) */
  param RealVar nu

  /** A scale parameter \(\sigma\). \( \sigma > 0 \). */
  param RealVar sigma

  laws {
    // Terms in nu only: log 2 + lnGamma((nu + 1)/2) - lnGamma(nu/2) - 0.5 * log(nu * pi)
    logf(nu) {
      if (nu <= 0.0) NEGATIVE_INFINITY
      else log(2.0) + lnGamma((nu + 1)/ 2.0) - lnGamma(nu / 2.0) - 0.5 * log(nu * PI)
    }
    // Term in sigma only: -log(sigma)
    logf(sigma) {
      if (sigma <= 0.0) NEGATIVE_INFINITY
      else - log(sigma)
    }
    // Kernel: -((nu + 1)/2) * log(1 + (x / sigma)^2 / nu)
    logf(nu, sigma, realization) {
      val invalid = realization < 0.0 || sigma <= 0.0 || nu <= 0.0
      if (invalid) return NEGATIVE_INFINITY
      val squaredRatio = pow(realization / sigma, 2)
      return -((nu + 1.0) / 2.0) * log(1.0 + 1.0 / nu * squaredRatio)
    }
  }

  // Forward simulation delegates to the generator's halfstudentt(nu, sigma).
  generate(rand) {
    rand.halfstudentt(nu, sigma)
  }
}
/** The Laplace Distribution over \(\mathbb{R}\) */
model Laplace {
  random RealVar realization

  /** The mean parameter. */
  param RealVar location

  /** The scale parameter \( b \), equal to the square root of half of the variance. \( b > 0 \) */
  param RealVar scale

  laws {
    // Whole log-density in one factor: -log(2 b) - |x - location| / b
    logf(realization, location, scale) {
      if (scale <= 0) NEGATIVE_INFINITY
      else -log(2 * scale) - abs(realization - location) / scale
    }
  }

  // Forward simulation delegates to the generator's laplace(location, scale).
  generate(rand) {
    rand.laplace(location, scale)
  }
}
/** The random variable \(X = b^Y\) where \(Y \sim \text{ContinuousUniform}[m, M]\). */
model LogUniform {
  random RealVar realization

  /** The left end point \(m\) of the interval. \(m \in (-\infty, M)\) */
  param RealVar min

  /** The right end point \(M\) of the interval. \(M \in (m, \infty)\) */
  param RealVar max

  /** The base \(b\). \(b > 0\) */
  param RealVar base

  laws {
    // Density of the underlying uniform: -log(M - m)
    logf(min, max) {
      if (max - min <= 0.0) NEGATIVE_INFINITY
      else - log(max - min)
    }
    // Support check on the exponent log_b(x), plus the change-of-variable
    // term -log(x) - log(log(b)).
    logf(realization, min, max, base) {
      if (base <= 0 || realization <= 0) return NEGATIVE_INFINITY
      val lnRealization = log(realization)
      val lnBase = log(base)
      val exponent = lnRealization / lnBase
      val insideInterval = min <= exponent && exponent <= max
      return
        if (insideInterval) -lnRealization - log(lnBase)
        else NEGATIVE_INFINITY
    }
  }

  // Forward simulation: draw the exponent uniformly on [min, max], then raise the base to it.
  generate (rand) {
    val exponent = rand.uniform(min, max)
    return pow(base, exponent)
  }
}
/** Arbitrary linear transformations of \(n\) iid standard normal random variables. */
model MultivariateNormal {
  random Matrix realization

  /** An \(n \times 1\) vector \(\mu\). \(\mu \in \mathbb{R}^n\) */
  param Matrix mean

  // Note: there is no need to mark this as constrained, since
  // CholeskyDecomposition is read-only and so naive sampling is not attempted
  // by default.
  /** Inverse covariance matrix \(\Lambda\), a positive definite \(n \times n\) matrix. */
  param CholeskyDecomposition precision

  laws {
    // Constant term: -(n / 2) * log(2 pi), with n the number of entries of the realization.
    logf(double dim = realization.nEntries) {
      - 0.5 * dim * log(2.0*PI)
    }
    // Half the log-determinant of the precision matrix.
    logf(precision) {
      0.5 * precision.logDet
    }
    // Quadratic form -0.5 * (mu - x)^T L L^T (mu - x), with L taken from the decomposition.
    logf(mean, precision, realization) {
      val Matrix difference = mean - realization
      val Matrix factor = precision.L
      // Keep the left-to-right product order: it makes this quadratic, not cubic.
      val quadraticForm = (difference.transpose * factor * factor.transpose * difference).doubleValue
      return - 0.5 * quadraticForm
    }
  }

  // Forward simulation overwrites `realization` with a fresh draw.
  generate(rand) {
    realization.setTo(rand.multivariateNormal(mean, precision))
  }
}
\(p \in (0, 1)\) */ 11 | param RealVar p 12 | 13 | laws { 14 | logf(k, r) { 15 | if (r <= 0 || k < 0) return NEGATIVE_INFINITY 16 | val result = logBinomial(k+r-1.0, k) 17 | if (result.isNaN) 18 | return NEGATIVE_INFINITY // E.g: if k = 0 and r = 1.4761528506003524E-63, logBinomial gives -INF 19 | return result 20 | } 21 | logf(r, k, p) { 22 | if (p <= 0.0 || p >= 1.0) return NEGATIVE_INFINITY 23 | if (r <= 0 || k < 0) return NEGATIVE_INFINITY 24 | return k * log(p) + r * log(1.0 - p) 25 | } 26 | } 27 | 28 | generate(rand) { 29 | rand.negativeBinomial(r, p) 30 | } 31 | } -------------------------------------------------------------------------------- /src/main/java/blang/distributions/NegativeBinomialMeanParam.bl: -------------------------------------------------------------------------------- 1 | package blang.distributions 2 | 3 | model NegativeBinomialMeanParam { 4 | random IntVar k 5 | param RealVar mean, overdispersion 6 | 7 | laws { 8 | k | mean, overdispersion ~ NegativeBinomial( 9 | mean * mean / overdispersion, 10 | 1.0 - mean/(mean + overdispersion) 11 | ) 12 | } 13 | 14 | generate (rand) { 15 | rand.negativeBinomial(mean * mean / overdispersion, 1.0 - mean/(mean + overdispersion)) 16 | } 17 | } -------------------------------------------------------------------------------- /src/main/java/blang/distributions/Normal.bl: -------------------------------------------------------------------------------- 1 | package blang.distributions 2 | 3 | /** Normal random variables. Values in \(\mathbb{R}\) */ 4 | model Normal { 5 | random RealVar realization 6 | 7 | /** Mean \(\mu\). \(\mu \in \mathbb{R}\) */ 8 | param RealVar mean 9 | 10 | /** Variance \(\sigma^2\). 
\(\sigma^2 > 0\) */ 11 | param RealVar variance 12 | 13 | laws { 14 | logf() { 15 | - 0.5 * log(2.0*PI) 16 | } 17 | logf(variance) { 18 | if (variance <= 0.0) return NEGATIVE_INFINITY 19 | return - 0.5 * log(variance) 20 | } 21 | logf(mean, variance, realization) { 22 | if (variance <= 0.0) return NEGATIVE_INFINITY 23 | return - 0.5 * pow(mean - realization, 2) / variance 24 | } 25 | } 26 | 27 | generate(rand) { 28 | rand.normal(mean, variance) 29 | } 30 | } -------------------------------------------------------------------------------- /src/main/java/blang/distributions/NormalField.bl: -------------------------------------------------------------------------------- 1 | package blang.distributions 2 | 3 | import briefj.collections.UnorderedPair 4 | import blang.types.Precision 5 | import briefj.Indexer 6 | 7 | /** A mean-zero normal, sparse-precision Markov random field. 8 | * For small problem, use MultivariateNormal instead, 9 | * but for problems with a large, sparse precision matrix, this implementation 10 | * allows the user to specify a 'support' 11 | * for the precision, outside of which the precision is guaranteed to be zero. This 12 | * can speed up sampling considerably. 13 | */ 14 | @Samplers(EllipticalSliceSampler) 15 | model NormalField { 16 | /** Precision matrix structure. 17 | * precision.support is assumed to be constant. 
18 | * TODO: add some construct that test this exponentially less and less frequently 19 | */ 20 | param Precision precision 21 | random Plated realization 22 | 23 | laws { 24 | for (UnorderedPair pair : precision.support) { 25 | logf( 26 | precision, 27 | pair, 28 | RealVar x0 = realization.get(precision.plate.index(pair.first)), 29 | RealVar x1 = realization.get(precision.plate.index(pair.second)) 30 | ) { 31 | if (pair.first == pair.second) { 32 | return - 0.5 * precision.get(pair) * x0 * x0 33 | } else { 34 | // 0.5 * 2 = 1 (because we iterate over set of unordered pairs) 35 | return - precision.get(pair) * x0 * x1 36 | } 37 | } 38 | } 39 | logf(int dim = precision.plate.indices.size) { 40 | - dim * log(2*PI) / 2.0 41 | } 42 | logf(precision) { 43 | 0.5 * precision.logDet 44 | } 45 | } 46 | 47 | generate (rand) { 48 | val Precision p = precision 49 | val Indexer indexer = Precision::indexer(p.plate) 50 | val Matrix precisionMatrix = Precision::asMatrix(p, indexer) 51 | val Matrix result = sampleNormalByPrecision(rand, precisionMatrix) 52 | for (Index index : p.plate.indices) { 53 | (realization.get(index) as WritableRealVar).set(result.get(indexer.o2i(index.key))) 54 | } 55 | } 56 | } -------------------------------------------------------------------------------- /src/main/java/blang/distributions/Poisson.bl: -------------------------------------------------------------------------------- 1 | package blang.distributions 2 | 3 | /** Poisson random variable. Values in \(0, 1, 2, \dots\) */ 4 | model Poisson { 5 | random IntVar realization 6 | 7 | /** Mean parameter \(\lambda\). 
\(\lambda > 0\) */ 8 | param RealVar mean 9 | 10 | laws { 11 | logf(realization, mean) { 12 | if (mean <= 0) return NEGATIVE_INFINITY 13 | if (realization < 0) return NEGATIVE_INFINITY 14 | return realization * log(mean) 15 | } 16 | logf(mean) { 17 | if (mean <= 0) return NEGATIVE_INFINITY 18 | return - mean; 19 | } 20 | logf(realization) { 21 | if (realization < 0) return NEGATIVE_INFINITY 22 | return - logFactorial(realization) 23 | } 24 | } 25 | 26 | generate(rand) { 27 | rand.poisson(mean) 28 | } 29 | } -------------------------------------------------------------------------------- /src/main/java/blang/distributions/SimplexUniform.bl: -------------------------------------------------------------------------------- 1 | package blang.distributions 2 | 3 | /** \(n\) dimensional Dirichlet with all concentrations equal to one. */ 4 | model SimplexUniform { 5 | random Simplex realization 6 | 7 | /** The dimensionality \(n\). \( n > 0 \) */ 8 | param Integer dim 9 | 10 | laws { 11 | realization | dim ~ Dirichlet(ones(dim)) 12 | } 13 | 14 | generate (rand) { 15 | rand.dirichletInPlace(ones(dim), realization) 16 | } 17 | } -------------------------------------------------------------------------------- /src/main/java/blang/distributions/StudentT.bl: -------------------------------------------------------------------------------- 1 | package blang.distributions 2 | 3 | /** Student T random variable. Values in \(\mathbb{R}\) */ 4 | model StudentT { 5 | random RealVar realization 6 | 7 | /** The degrees of freedom \(\nu\). \( \nu > 0 \) */ 8 | param RealVar nu 9 | 10 | /** Location parameter \(\mu\). \(\mu \in \mathbb{R}\) */ 11 | param RealVar mu 12 | 13 | /** Scale parameter \(\sigma\). 
\(\sigma > 0\) */ 14 | param RealVar sigma 15 | 16 | laws{ 17 | logf(){ 18 | return - 0.5 * log(PI) 19 | } 20 | 21 | logf(sigma){ 22 | if (sigma <= 0.0) return NEGATIVE_INFINITY 23 | return -log(sigma) 24 | } 25 | 26 | logf(nu){ 27 | if (nu <= 0.0) return NEGATIVE_INFINITY 28 | return lnGamma((nu + 1.0) / 2.0) - 0.5 * log(nu) - lnGamma(nu / 2.0) 29 | } 30 | 31 | logf(mu, nu, sigma, realization){ 32 | if (nu <= 0.0 || sigma <= 0.0) return NEGATIVE_INFINITY 33 | return - ((nu + 1.0) / 2.0) * log(1.0 + (1.0 / nu) * ( 1.0 / (pow(sigma, 2)) ) * pow((realization - mu), 2)) 34 | } 35 | 36 | } 37 | 38 | generate(rand){ 39 | rand.studentt(nu, mu, sigma) 40 | } 41 | } -------------------------------------------------------------------------------- /src/main/java/blang/distributions/SymmetricDirichlet.bl: -------------------------------------------------------------------------------- 1 | package blang.distributions 2 | 3 | /** \(n\) dimensional Dirichlet with all concentrations equal to \(\alpha / n\). */ 4 | model SymmetricDirichlet { 5 | random Simplex realization 6 | 7 | /** The dimensionality \(n\). \( n > 0 \) */ 8 | param Integer dim 9 | 10 | /** The shared concentration parameter \(\alpha\) before normalization by the dimensionality. \(\alpha > 0\) */ 11 | param RealVar concentration 12 | 13 | laws { 14 | realization | dim, concentration ~ Dirichlet(ones(dim) * (concentration.doubleValue / dim)) 15 | } 16 | 17 | generate(rand) { 18 | rand.dirichletInPlace(ones(dim) * (concentration.doubleValue / dim), realization) 19 | } 20 | } -------------------------------------------------------------------------------- /src/main/java/blang/distributions/Weibull.bl: -------------------------------------------------------------------------------- 1 | package blang.distributions 2 | 3 | /** The Weibull Distribution. Values in \((0, \infty)\).*/ 4 | model Weibull { 5 | random RealVar realization 6 | 7 | /** The scale parameter \(\lambda\). 
\( \lambda \in (0, +\infty) \) */ 8 | param RealVar scale 9 | 10 | /** The shape parameter \(k\). \( k \in (0, +\infty) \) */ 11 | param RealVar shape 12 | 13 | laws { 14 | logf(scale, shape) { 15 | if (scale <= 0.0) return NEGATIVE_INFINITY 16 | if (shape <= 0.0) return NEGATIVE_INFINITY 17 | return log(shape) - (shape*log(scale)) 18 | } 19 | logf(scale, shape, realization) { 20 | if (scale <= 0.0) return NEGATIVE_INFINITY 21 | if (shape <= 0.0) return NEGATIVE_INFINITY 22 | if (realization <= 0.0) return NEGATIVE_INFINITY 23 | return (shape - 1) * log(realization) 24 | } 25 | logf(scale, shape, realization) { 26 | if (scale <= 0.0) return NEGATIVE_INFINITY 27 | if (shape <= 0.0) return NEGATIVE_INFINITY 28 | if (realization <= 0.0) return NEGATIVE_INFINITY 29 | return - pow((realization / scale), shape) 30 | } 31 | } 32 | 33 | generate(rand) { 34 | rand.weibull(scale, shape) 35 | } 36 | } -------------------------------------------------------------------------------- /src/main/java/blang/distributions/YuleSimon.bl: -------------------------------------------------------------------------------- 1 | package blang.distributions 2 | 3 | import static org.apache.commons.math3.special.Beta.logBeta 4 | 5 | /** An exponential-geometric mixture. */ 6 | model YuleSimon { 7 | random IntVar realization 8 | 9 | /** The rate of the mixing exponential distribution. */ 10 | param RealVar rho 11 | 12 | laws { 13 | logf(rho, realization) { 14 | if (rho <= 0) return NEGATIVE_INFINITY 15 | if (realization < 0) return NEGATIVE_INFINITY 16 | return log(rho) + logBeta(1.0 + realization, rho + 1) 17 | } 18 | } 19 | 20 | generate(rand) { 21 | val w = rand.exponential(rho) 22 | return rand.negativeBinomial(1.0, 1.0 - exp(-w)) // This is the correct formula. Sometimes reported in the literature as "-exp(w)" (e.g. wikipedia, which I have edited now to fix, due to confusion b/w success pr/failure pr. 
of geometric) 23 | } 24 | } -------------------------------------------------------------------------------- /src/main/java/blang/distributions/internals/Helpers.java: -------------------------------------------------------------------------------- 1 | package blang.distributions.internals; 2 | 3 | import bayonet.math.NumericalUtils; 4 | 5 | public class Helpers 6 | { 7 | public static boolean warnedUnstableConcentration = false; 8 | public static double concentrationWarningThreshold = 0.5; 9 | public static void checkDirichletOrBetaParam(double concentration) 10 | { 11 | if (!warnedUnstableConcentration && concentration < concentrationWarningThreshold - NumericalUtils.THRESHOLD) 12 | { 13 | warnedUnstableConcentration = true; 14 | System.err.println("Warning: small concentrations may cause numeric instability to Dirichlet and Beta distributions. " 15 | + "Consider enforcing a lower bound of say " + concentrationWarningThreshold + " " 16 | + "This message may also occur when slice samling outside of such constraint, you can then ignore this message. 
"); 17 | } 18 | } 19 | 20 | private Helpers() {} 21 | } 22 | -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/CovarAccumulator.java: -------------------------------------------------------------------------------- 1 | package blang.engines.internals; 2 | 3 | public class CovarAccumulator 4 | { 5 | double meanx = 0, meany = 0, C = 0; 6 | int n = 0; 7 | public void add(double x, double y) { 8 | n += 1; 9 | double dx = x - meanx; 10 | meanx += dx / n; 11 | meany += (y - meany) / n; 12 | C += dx * (y - meany); 13 | } 14 | public double sampleCovariance() { 15 | return C / n; 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/LogSumAccumulator.java: -------------------------------------------------------------------------------- 1 | package blang.engines.internals; 2 | 3 | import bayonet.math.NumericalUtils; 4 | 5 | /** 6 | * Stores a sum in log scale and allow adding 7 | * one term stored in log scale 8 | */ 9 | public class LogSumAccumulator { 10 | double logSum = Double.NEGATIVE_INFINITY; 11 | long n = 0; 12 | 13 | public double logSum() { 14 | return logSum; 15 | } 16 | 17 | public long numberOfTerms() { 18 | return n; 19 | } 20 | 21 | /** 22 | * Conceptually, performs logSum <- log ( exp(logSum) + exp(logTerm) ) 23 | * but in a numerically stable and efficient fashion. 
24 | * 25 | * @param logTerm 26 | */ 27 | public void add(double logTerm) { 28 | logSum = NumericalUtils.logAdd(logSum, logTerm); 29 | n++; 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/PosteriorInferenceEngine.java: -------------------------------------------------------------------------------- 1 | package blang.engines.internals; 2 | 3 | import blang.engines.internals.factories.AIS; 4 | import blang.engines.internals.factories.Exact; 5 | import blang.engines.internals.factories.Forward; 6 | import blang.engines.internals.factories.ISCM; 7 | import blang.engines.internals.factories.IAIS; 8 | import blang.engines.internals.factories.MCMC; 9 | import blang.engines.internals.factories.None; 10 | import blang.engines.internals.factories.PT; 11 | import blang.engines.internals.factories.SCM; 12 | import blang.inits.Implementations; 13 | import blang.runtime.SampledModel; 14 | import blang.runtime.internals.objectgraph.GraphAnalysis; 15 | 16 | @Implementations({SCM.class, PT.class, MCMC.class, AIS.class, Forward.class, Exact.class, None.class, ISCM.class, IAIS.class}) 17 | public interface PosteriorInferenceEngine 18 | { 19 | public void setSampledModel(SampledModel model); 20 | public void performInference(); 21 | public void check(GraphAnalysis analysis); 22 | } 23 | -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/SplineDerivatives.xtend: -------------------------------------------------------------------------------- 1 | package blang.engines.internals 2 | 3 | import blang.engines.internals.Spline.MonotoneCubicSpline 4 | import org.apache.commons.math3.analysis.differentiation.DerivativeStructure 5 | 6 | import static extension xlinear.AutoDiff.* 7 | 8 | class SplineDerivatives { 9 | 10 | def static double derivative(MonotoneCubicSpline it, double _x) { 11 | val n = mX.length 12 | if (Double.isNaN(_x)) { 
13 | return _x; 14 | } 15 | if (_x <= mX.get(0)) { 16 | return 0.0 17 | } 18 | if (_x >= mX.get(n - 1)) { 19 | return 0.0 20 | } 21 | // Find the index 'i' of the last point with smaller X. 22 | // We know this will be within the spline due to the boundary tests. 23 | var i = 0; 24 | while (_x >= mX.get(i + 1)) { 25 | i += 1 26 | } 27 | val h = mX.get(i + 1) - mX.get(i) 28 | val x = new DerivativeStructure(1, 1, 0, _x) 29 | val t = (x - mX.get(i)) / h 30 | val result = (mY.get(i) * (1 + 2 * t) + h * mM.get(i) * t) * (1 - t) * (1 - t) + 31 | (mY.get(i + 1) * (3 - 2 * t) + h * mM.get(i + 1) * (t - 1)) * t * t 32 | return result.getPartialDerivative(1) 33 | } 34 | } -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/factories/AIS.java: -------------------------------------------------------------------------------- 1 | package blang.engines.internals.factories; 2 | 3 | import blang.runtime.SampledModel; 4 | 5 | public class AIS extends SCM { 6 | @Override 7 | public void setSampledModel(SampledModel model) 8 | { 9 | resamplingESSThreshold = 0.0; 10 | super.setSampledModel(model); 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/factories/Forward.java: -------------------------------------------------------------------------------- 1 | package blang.engines.internals.factories; 2 | 3 | 4 | import bayonet.distributions.Random; 5 | import blang.engines.internals.PosteriorInferenceEngine; 6 | import blang.inits.Arg; 7 | import blang.inits.DefaultValue; 8 | import blang.inits.GlobalArg; 9 | import blang.inits.experiments.ExperimentResults; 10 | import blang.io.BlangTidySerializer; 11 | import blang.runtime.Runner; 12 | import blang.runtime.SampledModel; 13 | import blang.runtime.internals.objectgraph.GraphAnalysis; 14 | 15 | public class Forward implements PosteriorInferenceEngine 16 | { 17 | @Arg @DefaultValue("1") 
18 | public Random random = new Random(1); 19 | 20 | @GlobalArg ExperimentResults results; 21 | 22 | SampledModel model; 23 | 24 | @Override 25 | public void setSampledModel(SampledModel model) 26 | { 27 | this.model = model; 28 | } 29 | 30 | @SuppressWarnings("unchecked") 31 | @Override 32 | public void performInference() 33 | { 34 | BlangTidySerializer tidySerializer = new BlangTidySerializer(results.child(Runner.SAMPLES_FOLDER)); 35 | model.forwardSample(random, true); 36 | model.getSampleWriter(tidySerializer).write(); 37 | } 38 | 39 | @Override 40 | public void check(GraphAnalysis analysis) 41 | { 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/factories/IAIS.xtend: -------------------------------------------------------------------------------- 1 | package blang.engines.internals.factories 2 | 3 | import blang.runtime.SampledModel 4 | 5 | class IAIS extends ISCM { 6 | override void setSampledModel(SampledModel model) { 7 | resamplingESSThreshold = 0.0 8 | super.setSampledModel(model) 9 | } 10 | } -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/factories/MCMC.java: -------------------------------------------------------------------------------- 1 | package blang.engines.internals.factories; 2 | 3 | import blang.runtime.SampledModel; 4 | 5 | public class MCMC extends PT { 6 | @Override 7 | public void setSampledModel(SampledModel m) 8 | { 9 | nPassesPerScan = 1; 10 | nChains = 1; 11 | usePriorSamples = false; 12 | initialization = InitType.COPIES; 13 | super.setSampledModel(m); 14 | } 15 | 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/factories/None.java: -------------------------------------------------------------------------------- 1 | package blang.engines.internals.factories; 2 | 3 | import 
blang.engines.internals.PosteriorInferenceEngine; 4 | import blang.runtime.SampledModel; 5 | import blang.runtime.internals.objectgraph.GraphAnalysis; 6 | 7 | public class None implements PosteriorInferenceEngine 8 | { 9 | 10 | @Override 11 | public void setSampledModel(SampledModel model) 12 | { 13 | } 14 | 15 | @Override 16 | public void performInference() 17 | { 18 | } 19 | 20 | @Override 21 | public void check(GraphAnalysis analysis) 22 | { 23 | } 24 | 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/ladders/EquallySpaced.java: -------------------------------------------------------------------------------- 1 | package blang.engines.internals.ladders; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | public class EquallySpaced implements TemperatureLadder 7 | { 8 | @Override 9 | public List temperingParameters(int nChains) 10 | { 11 | List temperingParameters = new ArrayList<>(); 12 | if (nChains == 1) 13 | temperingParameters.add(1.0); 14 | else 15 | for (int i = nChains - 1; i >= 0; i--) 16 | temperingParameters.add(((double) i) / ((double) nChains - 1.0)); 17 | return temperingParameters; 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/ladders/FromAnotherExec.java: -------------------------------------------------------------------------------- 1 | package blang.engines.internals.ladders; 2 | 3 | import static blang.inits.experiments.tabwriters.TidySerializer.VALUE; 4 | 5 | import java.io.File; 6 | import java.util.ArrayList; 7 | import java.util.List; 8 | import java.util.Map; 9 | 10 | import blang.inits.Arg; 11 | import blang.inits.DefaultValue; 12 | import briefj.BriefIO; 13 | 14 | public class FromAnotherExec implements TemperatureLadder 15 | { 16 | @Arg @DefaultValue("An 'annealingParameters.csv[.gz] file from the monitoring folder " 17 | + "of an ealier 
execution. The schedule from the final round will be used as the initialiation " 18 | + "of this one. ") 19 | public File annealingParameters; 20 | 21 | @Arg(description = "If the command line argument 'nChains' is different, than the number of " 22 | + "provided grid points, allow the use of spline interpolation/extrapolation.") 23 | @DefaultValue("false") 24 | public boolean allowSplineGeneralization = false; 25 | 26 | @Override 27 | public List temperingParameters(int nChains) 28 | { 29 | List parsed = new ArrayList(); 30 | parsed.add(0.0); 31 | for (Map line : BriefIO.readLines(annealingParameters).indexCSV()) 32 | if (line.get("isAdapt").equals("false")) 33 | { 34 | parsed.add(Double.parseDouble(line.get(VALUE))); 35 | } 36 | UserSpecified userSpecified = new UserSpecified(); 37 | userSpecified.annealingParameters = parsed; 38 | userSpecified.allowSplineGeneralization = this.allowSplineGeneralization; 39 | return userSpecified.temperingParameters(nChains); 40 | } 41 | 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/ladders/Geometric.java: -------------------------------------------------------------------------------- 1 | package blang.engines.internals.ladders; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | import blang.inits.Arg; 7 | import blang.inits.DefaultValue; 8 | 9 | public class Geometric implements TemperatureLadder 10 | { 11 | @Arg @DefaultValue("0.8") 12 | public double annealingScaling = 0.8; 13 | 14 | @Override 15 | public List temperingParameters(int nChains) 16 | { 17 | List temperingParameters = new ArrayList<>(); 18 | if (annealingScaling < 0.0 || annealingScaling >= 1.0) 19 | throw new RuntimeException("Annealing scaling must be between 0 and 1 exclusively."); 20 | if (nChains == 1) 21 | temperingParameters.add(1.0); 22 | else 23 | { 24 | for (int i = 0; i < nChains - 1; i++) 25 | temperingParameters.add(Math.pow(annealingScaling, 
i)); 26 | temperingParameters.add(0.0); 27 | } 28 | return temperingParameters; 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/ladders/Polynomial.java: -------------------------------------------------------------------------------- 1 | package blang.engines.internals.ladders; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | import blang.inits.Arg; 7 | import blang.inits.DefaultValue; 8 | 9 | public class Polynomial implements TemperatureLadder 10 | { 11 | @Arg @DefaultValue("3") 12 | public double power = 3; 13 | 14 | @Override 15 | public List temperingParameters(int nChains) 16 | { 17 | List temperingParameters = new ArrayList<>(); 18 | if (power < 1.0) 19 | throw new RuntimeException("Annealing scaling must be between 0 and 1 exclusively."); 20 | if (nChains == 1) 21 | temperingParameters.add(1.0); 22 | else 23 | for (int i = 0; i < nChains; i++) { 24 | double fraction = (double) i / ((double) nChains - 1.0); 25 | temperingParameters.add(Math.pow((1.0 - fraction), power)); 26 | } 27 | return temperingParameters; 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/ladders/TemperatureLadder.java: -------------------------------------------------------------------------------- 1 | package blang.engines.internals.ladders; 2 | 3 | import java.util.List; 4 | 5 | import blang.inits.Implementations; 6 | 7 | /** 8 | * Provides a temperature ladder for parallel tempering-type algorithms. 9 | * Difference with a TemperatureSchedule is that the whole chain has to 10 | * be provided at once. 11 | */ 12 | @Implementations({Geometric.class, EquallySpaced.class, Polynomial.class, UserSpecified.class, FromAnotherExec.class}) 13 | public interface TemperatureLadder 14 | { 15 | /** 16 | * Fill the provided temperingParameters with annealing parameters (index 0, i.e. 
first one, should be 1 - i.e. room temperature) 17 | */ 18 | List temperingParameters(int nTemperatures); 19 | } -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/ptanalysis/PathViz.xtend: -------------------------------------------------------------------------------- 1 | package blang.engines.internals.ptanalysis 2 | 3 | import viz.core.Viz 4 | import viz.core.PublicSize 5 | import viz.core.Viz.PrivateSize 6 | import blang.inits.ConstructorArg 7 | import blang.inits.experiments.Experiment 8 | import blang.inits.DesignatedConstructor 9 | import blang.inits.DefaultValue 10 | import blang.inits.Arg 11 | import java.util.Optional 12 | 13 | class PathViz extends Viz { 14 | val Paths paths 15 | 16 | @Arg public Optional boldTrajectory = Optional.empty 17 | 18 | @DefaultValue("true") 19 | @Arg public boolean useAcceptRejectColours = true 20 | 21 | public float ratio = 0.5f 22 | 23 | @DesignatedConstructor 24 | new( 25 | @ConstructorArg("swapIndicators") Paths paths, 26 | @ConstructorArg("size") @DefaultValue("height", "300") PublicSize publicSize 27 | ) { 28 | super(publicSize) 29 | this.paths = paths 30 | } 31 | 32 | val baseWeight = 0.05f 33 | override protected draw() { 34 | translate(0.5f, 0.5f) 35 | val boldStroke = 6 * baseWeight 36 | val minY = 0f 37 | val maxY = paths.nChains - 1 38 | val minX = 0f 39 | val maxX = ratio * (paths.nIterations - 1) 40 | for (c : 0 ..< paths.nChains) { 41 | if (boldTrajectory.orElse(-1) == c) 42 | strokeWeight(boldStroke) 43 | else 44 | strokeWeight(baseWeight) 45 | if (!useAcceptRejectColours) 46 | setColour(c) 47 | val path = paths.get(c) 48 | for (i : 1 ..< paths.nIterations) { 49 | if (useAcceptRejectColours) 50 | setColour(path.get(i-1) != path.get(i)) 51 | val y0 = path.get(i-1) 52 | val y1 =path.get(i) 53 | line(ratio*(i-1), y0, ratio*i, y1) 54 | if (useAcceptRejectColours) 55 | stroke(0, 0, 0) 56 | ellipse(ratio*(i - 1), y0, 0.1f, 0.1f) 57 | } 58 | 
ellipse(maxX, path.get(paths.nIterations - 1), 0.1f, 0.1f) 59 | } 60 | // black boundaries (masks corner case red/green color off there) 61 | stroke(0, 0, 0) 62 | strokeWeight(boldStroke) 63 | line(minX, minY, maxX, minY) 64 | line(minX, maxY, maxX, maxY) 65 | } 66 | 67 | def void setColour(boolean accepted) { 68 | if (accepted) stroke(0, 204, 0) 69 | else stroke(204, 0, 0) 70 | } 71 | 72 | def void setColour(int chainIndex) { 73 | val from = color(204, 102, 0) 74 | val to = color(0, 102, 153) 75 | val interpolated = lerpColor(from, to, 1.0f * chainIndex / paths.nChains) 76 | stroke(interpolated) 77 | } 78 | 79 | override protected privateSize() { new PrivateSize(paths.nIterations * ratio, paths.nChains) } 80 | 81 | static def void main(String [] args) { Experiment::start(args) } 82 | } -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/schedules/FixedTemperatureSchedule.java: -------------------------------------------------------------------------------- 1 | package blang.engines.internals.schedules; 2 | 3 | import bayonet.smc.ParticlePopulation; 4 | import blang.inits.Arg; 5 | import blang.inits.DefaultValue; 6 | import blang.runtime.SampledModel; 7 | 8 | public class FixedTemperatureSchedule implements TemperatureSchedule 9 | { 10 | @Arg @DefaultValue("100") 11 | public int nTemperatures = 100; 12 | 13 | @Override 14 | public double nextTemperature(ParticlePopulation population, double temperature, double maxAnnealingParameter) 15 | { 16 | if (nTemperatures < 1) 17 | throw new RuntimeException("Number of temperatures should be positive: " + nTemperatures); 18 | return Math.min(maxAnnealingParameter, temperature + maxAnnealingParameter / ((double) nTemperatures)); 19 | } 20 | } -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/schedules/TemperatureSchedule.java: 
-------------------------------------------------------------------------------- 1 | package blang.engines.internals.schedules; 2 | 3 | import bayonet.smc.ParticlePopulation; 4 | import blang.inits.Implementations; 5 | import blang.runtime.SampledModel; 6 | 7 | @Implementations({AdaptiveTemperatureSchedule.class, FixedTemperatureSchedule.class, UserSpecified.class}) 8 | public interface TemperatureSchedule 9 | { 10 | double nextTemperature(ParticlePopulation population, double annealingParam, double maxAnnealParam); 11 | } -------------------------------------------------------------------------------- /src/main/java/blang/engines/internals/schedules/UserSpecified.java: -------------------------------------------------------------------------------- 1 | package blang.engines.internals.schedules; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Collections; 5 | import java.util.HashMap; 6 | import java.util.List; 7 | import java.util.Map; 8 | 9 | import bayonet.smc.ParticlePopulation; 10 | import blang.inits.ConstructorArg; 11 | import blang.inits.DesignatedConstructor; 12 | import blang.runtime.SampledModel; 13 | 14 | public class UserSpecified implements TemperatureSchedule 15 | { 16 | Map next; 17 | 18 | @DesignatedConstructor 19 | public UserSpecified(@ConstructorArg("annealingParameters") List annealingParameters) 20 | { 21 | set(annealingParameters); 22 | } 23 | 24 | public void set(List annealingParameters) 25 | { 26 | annealingParameters = new ArrayList<>(annealingParameters); 27 | Collections.sort(annealingParameters); 28 | if (annealingParameters.get(0) != 0.0 || annealingParameters.get(annealingParameters.size()-1) != 1.0) 29 | throw new RuntimeException(); 30 | next = new HashMap<>(); 31 | for (int i = 0; i < annealingParameters.size() - 1; i++) 32 | next.put(annealingParameters.get(i), annealingParameters.get(i + 1)); 33 | if (next.size() != annealingParameters.size() - 1) 34 | throw new RuntimeException(); 35 | } 36 | 37 | @Override 38 | public 
double nextTemperature(ParticlePopulation population, double temperature, double maxAnnealingParameter) 39 | { 40 | return next.get(temperature); 41 | } 42 | } -------------------------------------------------------------------------------- /src/main/java/blang/io/BlangTidySerializer.xtend: -------------------------------------------------------------------------------- 1 | package blang.io 2 | 3 | import blang.inits.experiments.tabwriters.TidySerializer 4 | import blang.inits.DesignatedConstructor 5 | import blang.inits.GlobalArg 6 | import blang.inits.experiments.ExperimentResults 7 | import xlinear.Matrix 8 | import blang.inits.experiments.tabwriters.TabularWriter 9 | import xlinear.StaticUtils 10 | import blang.types.Plated 11 | import java.util.Map.Entry 12 | import blang.types.internals.Query 13 | import blang.types.Index 14 | import blang.types.PlatedMatrix 15 | 16 | class BlangTidySerializer extends TidySerializer { 17 | 18 | @DesignatedConstructor 19 | new(@GlobalArg ExperimentResults result) { 20 | super(result) 21 | } 22 | 23 | def dispatch protected void serializeImplementation(Matrix m, TabularWriter writer) { 24 | if (m.isVector) { 25 | for (var int i = 0; i < m.nEntries; i++) { 26 | writer.write("entry" -> i, TidySerializer::VALUE -> m.get(i)) 27 | } 28 | } else { 29 | StaticUtils::visitSkippingSomeZeros(m) [ int row, int col, double value | 30 | writer.write("row" -> row, "col" -> col, TidySerializer::VALUE -> value) 31 | ] 32 | } 33 | } 34 | 35 | def dispatch protected void serializeImplementation(Plated p, TabularWriter writer) { 36 | for (_entry : p.entries) { 37 | val Entry entry = _entry as Entry // work around type inference bug 38 | var TabularWriter childWriter = writer 39 | for (Index index : entry.key.indices) { 40 | childWriter = childWriter.child(index.plate.name.string, index.key) 41 | } 42 | serializeImplementation(entry.value, childWriter) 43 | } 44 | } 45 | 46 | def dispatch protected void serializeImplementation(PlatedMatrix p, 
/**
 * Description of an optional data source.
 */
class DataSource {
  
  /*
   * The optional path could be file system path, database connection path, etc
   * Not specifying the path is useful in situations such as generation from prior.
   */
  public val Optional<String> path 
  
  /** Strategy used to parse the content found at the path; CSV unless configured otherwise. */
  @Arg @DefaultValue("CSV")
  @Accessors(PUBLIC_SETTER)
  DataSourceReader reader = new CSV
  
  @DesignatedConstructor
  new(@Input(formatDescription = "Path to the DataSource.") Optional<String> path) {
    this.path = path
  }
  
  /**
   * Reads all entries via the configured reader, one map from column name to value per entry.
   * 
   * @throws java.util.NoSuchElementException if no path was specified (check isPresent() first)
   */
  def Iterable<Map<ColumnName,String>> read() {
    if (!path.present) {
      // same exception type as Optional.get, but with an actionable message
      throw new java.util.NoSuchElementException("Cannot read from a DataSource with no path specified.")
    }
    return reader.read(path.get)
  }
  
  /**
   * The column names, taken from the keys of the first entry; empty if there are no entries.
   */
  def Set<ColumnName> columnNames() {
    val Map<ColumnName,String> head = read().head
    if (head === null) {
      return Collections::emptySet
    } else {
      return head.keySet
    }
  }
  
  /** Whether a path was specified. */
  def boolean isPresent() {
    return path.present
  }
  
  /** A data source with no path. */
  def static DataSource empty() {
    return new DataSource(Optional.empty)
  }
  
  /**
   * Resolution order: the local data source if present, else the global one 
   * if present, else an empty data source.
   */
  def static DataSource scopedDataSource(DataSource local, GlobalDataSourceStore global) {
    if (local.present) {
      return local
    } else if (global.dataSource.present) {
      return global.dataSource
    } else {
      return empty
    }
  }
  
}
/**
 * DataSourceReader for delimiter-separated text files. The first row is 
 * interpreted as the column names, the following rows as the entries.
 */
class CSV implements DataSourceReader {
  
  @Arg @DefaultValue(",")
  char separator = ','
  
  /*
   * Avoid making the following two fields command line 
   * options as they cause problem down the pipeline when 
   * command line option values are reported
   */
  
  // @Arg @DefaultValue("\"")
  char quotechar = "\""
  
  // @Arg @DefaultValue("\\")
  char escape = "\\"
  
  @Arg @DefaultValue("false")
  boolean strictQuotes = false
  
  @Arg @DefaultValue("true")
  boolean ignoreLeadingWhiteSpace = true
  
  // NOTE(review): element type reconstructed as Character from the way the value 
  // is handed to splitCSV below - confirm against the splitCSV signature
  @Arg
  Optional<Character> commentCharacter
  
  /**
   * Lazily iterates over the rows following the header row, each one exposed as 
   * an insertion-ordered map from column name to the raw string value.
   * 
   * @throws RuntimeException while iterating, if a row does not have the same number 
   *   of fields as the header row
   */
  override Iterable<Map<ColumnName,String>> read(String path) {
    val fileIterator = BriefIO::readLines(path)
    val CSVParser parser = new CSVParser(
      separator, 
      quotechar,
      escape,
      strictQuotes,
      ignoreLeadingWhiteSpace
    )
    val commentChar = commentCharacter.orElse(null)
    // header: first parsed row, or no keys at all if the file has no rows
    val List<ColumnName> keys = Iterables.getFirst(fileIterator.splitCSV(parser, commentChar), Collections.EMPTY_LIST).map[String key | new ColumnName(key)]
    // body: a second pass over the same lines, skipping the header row
    val FluentIterable<List<String>> bodyIterable = fileIterator.splitCSV(parser, commentChar).skip(1)
    return bodyIterable.transform[List<String> values |
      val int size = keys.size()
      if (size != values.size())
        throw new RuntimeException("The number of keys should have the same length as the number of values:" + size + " vs " + values.size())
      val Map<ColumnName,String> result = Maps.newLinkedHashMap()
      for (var int i = 0; i < size; i++) {
        result.put(keys.get(i), values.get(i))
      }
      return result
    ]
  }
}
/**
 * Sampler for the realization of a Categorical distribution, delegating 
 * the actual move to an IntSliceSampler over the range [0, number of categories).
 */
class CategoricalSampler implements Sampler {
  
  @SampledVariable
  Categorical categorical
  
  // Populated by the sampler matching machinery, then discarded in setup(..) 
  // in favour of the narrower logScaleFactors list.
  @ConnectedFactor
  List<Factor> _factors 
  
  // Factors connected to the realization only; computed in setup(..)
  List<LogScaleFactor> logScaleFactors = null
  
  override void execute(Random rand) {
    val int max = categorical.probabilities.nEntries
    val IntSliceSampler sampler = IntSliceSampler.build(categorical.getRealization as WritableIntVar, logScaleFactors, 0, max)
    sampler.execute(rand)
  }
  
  /**
   * @return false if the realization is not latent, is not a WritableIntVar, 
   *   or if its connected factors are not supported (see extractFactorsFor)
   */
  @SuppressWarnings("unchecked")
  override boolean setup(SamplerBuilderContext context) {
    if (!context.isLatent(categorical.getRealization) || 
        !(categorical.getRealization instanceof WritableIntVar)
    ) {
      return false
    }
    /*
     * More complex init needed to avoid pulling too many 
     * dependencies (i.e. those coming from categorical.probabilities)
     */
    _factors = null
    logScaleFactors = extractFactorsFor(categorical.getRealization, context)
    if (logScaleFactors === null) 
      return false
    return true
  }
  
  /**
   * Collects the LogScaleFactor's connected to the node of the given object.
   * 
   * At most one Constrained factor is tolerated (and left out of the result).
   * 
   * @return the list of LogScaleFactor's, or null if a second Constrained factor 
   *   or a factor that is neither Constrained nor a LogScaleFactor is encountered
   */
  static def List<LogScaleFactor> extractFactorsFor(Object object, SamplerBuilderContext context) {
    val result = new ArrayList<LogScaleFactor>
    var boolean constrainedFound = false
    for (Factor f : context.connectedFactors(StaticUtils.node(object))) {
      if (f instanceof Constrained) {
        if (constrainedFound) {
          return null
        }
        constrainedFound = true
      }
      else if (f instanceof LogScaleFactor) {
        result.add(f as LogScaleFactor)
      }
      else 
        return null
    }
    return result
  }
  
}
18 | * 19 | * The rules used to match up the fields of a sampler to the 20 | * factors in a ProbabilityModel are implemented in 21 | * blang.mcmc.internals.SamplerMatchingUtils. 22 | * 23 | * @author Alexandre Bouchard (alexandre.bouchard@gmail.com) 24 | */ 25 | @Retention(RetentionPolicy.RUNTIME) 26 | @Target({ElementType.FIELD}) 27 | public @interface ConnectedFactor 28 | { 29 | } 30 | -------------------------------------------------------------------------------- /src/main/java/blang/mcmc/MHSampler.java: -------------------------------------------------------------------------------- 1 | package blang.mcmc; 2 | 3 | import java.util.List; 4 | import bayonet.distributions.Random; 5 | 6 | import blang.core.LogScaleFactor; 7 | import blang.distributions.Generators; 8 | import blang.mcmc.internals.Callback; 9 | 10 | public abstract class MHSampler implements Sampler 11 | { 12 | @ConnectedFactor 13 | protected List numericFactors; 14 | 15 | @Override 16 | public void execute(Random random) 17 | { 18 | // record likelihood before 19 | final double logBefore = logDensity(); 20 | Callback callback = new Callback() 21 | { 22 | private Double proposalLogRatio = null; 23 | @Override 24 | public void setProposalLogRatio(double logRatio) 25 | { 26 | this.proposalLogRatio = logRatio; 27 | } 28 | @Override 29 | public boolean sampleAcceptance() 30 | { 31 | if (proposalLogRatio == null) 32 | throw new RuntimeException("Use setProposalLogRatio(..) 
before calling sampleAcceptance()"); 33 | final double logAfter = logDensity(); 34 | final double ratio = Math.exp(proposalLogRatio + logAfter - logBefore); 35 | if (Double.isNaN(ratio)) { 36 | System.err.println("NaN MH ratio: " + proposalLogRatio + " " + logAfter + " " + logBefore); 37 | return false; 38 | } 39 | return Generators.bernoulli(random, Math.min(1.0, ratio)); 40 | } 41 | }; 42 | propose(random, callback); 43 | } 44 | 45 | private double logDensity() { 46 | double sum = 0.0; 47 | for (LogScaleFactor f : numericFactors) 48 | sum += f.logDensity(); 49 | return sum; 50 | } 51 | 52 | public abstract void propose(Random random, Callback callback); 53 | 54 | } -------------------------------------------------------------------------------- /src/main/java/blang/mcmc/SampledVariable.java: -------------------------------------------------------------------------------- 1 | package blang.mcmc; 2 | 3 | import java.lang.annotation.ElementType; 4 | import java.lang.annotation.Retention; 5 | import java.lang.annotation.RetentionPolicy; 6 | import java.lang.annotation.Target; 7 | 8 | 9 | @Retention(RetentionPolicy.RUNTIME) 10 | @Target({ElementType.FIELD}) 11 | public @interface SampledVariable 12 | { 13 | boolean skipFactorsFromSampledModel() default false; 14 | } 15 | -------------------------------------------------------------------------------- /src/main/java/blang/mcmc/Sampler.java: -------------------------------------------------------------------------------- 1 | package blang.mcmc; 2 | 3 | import bayonet.distributions.Random; 4 | import blang.mcmc.internals.SamplerBuilderContext; 5 | 6 | 7 | 8 | 9 | public interface Sampler 10 | { 11 | 12 | public void execute(Random rand); 13 | 14 | /** 15 | * @return Is the sampler compatible? 
/**
 * Sampler for a DenseSimplex. Each execution picks one entry uniformly at random 
 * and slice samples it through a SimplexWritableVariable, a view that writes the 
 * picked entry together with its successor entry.
 */
class SimplexSampler implements Sampler {
  @SampledVariable DenseSimplex simplex
  @ConnectedFactor List<LogScaleFactor> numericFactors
  @ConnectedFactor Constrained constrained
  
  override void execute(Random rand) {
    val int sampledDim = rand.nextInt(simplex.nEntries)
    val SimplexWritableVariable sampled = new SimplexWritableVariable(sampledDim, simplex)
    // the picked entry ranges between 0 and the current sum of the picked entry and its successor
    val RealSliceSampler slicer = RealSliceSampler::build(sampled, numericFactors, 0.0, sampled.sum)
    slicer.execute(rand)
  }
  
  /**
   * @return true only if the simplex has at least two entries and the 
   *   Constrained factor is set and refers to a DenseSimplex
   */
  override boolean setup(SamplerBuilderContext context) {
    return 
      simplex.nEntries >= 2 && 
      constrained !== null && 
      constrained.object instanceof DenseSimplex
  }
}
-------------------------------------------------------------------------------- /src/main/java/blang/mcmc/UniformSampler.xtend: -------------------------------------------------------------------------------- 1 | package blang.mcmc 2 | 3 | import bayonet.distributions.Random 4 | import blang.core.Factor 5 | import blang.mcmc.internals.SamplerBuilderContext 6 | import java.util.List 7 | import blang.core.LogScaleFactor 8 | import blang.core.WritableIntVar 9 | import blang.distributions.DiscreteUniform 10 | 11 | class UniformSampler implements Sampler { 12 | 13 | @SampledVariable 14 | DiscreteUniform uniform 15 | 16 | @ConnectedFactor 17 | List _factors 18 | 19 | List logScaleFactors = null 20 | 21 | override void execute(Random rand) { 22 | val int min = uniform.minInclusive.intValue 23 | val int max = uniform.maxExclusive.intValue 24 | val IntSliceSampler sampler = IntSliceSampler.build(uniform.getRealization as WritableIntVar, logScaleFactors, min, max) 25 | sampler.execute(rand) 26 | } 27 | 28 | @SuppressWarnings("unchecked") 29 | override boolean setup(SamplerBuilderContext context) { 30 | if (!context.isLatent(uniform.getRealization) || 31 | !(uniform.getRealization instanceof WritableIntVar) 32 | ) { 33 | return false 34 | } 35 | /* 36 | * More complex init needed to avoid pulling too many 37 | * dependencies (i.e. 
those coming from categorical.probabilities 38 | */ 39 | _factors = null 40 | logScaleFactors = CategoricalSampler::extractFactorsFor(uniform.getRealization, context) 41 | if (logScaleFactors === null) 42 | return false 43 | return true 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/main/java/blang/mcmc/internals/BuiltSamplers.java: -------------------------------------------------------------------------------- 1 | package blang.mcmc.internals; 2 | 3 | import java.util.ArrayList; 4 | import java.util.LinkedHashSet; 5 | import java.util.List; 6 | import java.util.Set; 7 | import java.util.stream.Collectors; 8 | 9 | import blang.mcmc.Sampler; 10 | import blang.runtime.internals.objectgraph.Node; 11 | 12 | public class BuiltSamplers 13 | { 14 | public final List list = new ArrayList(); 15 | public final List correspondingVariables = new ArrayList<>(); 16 | public final Set matchingReport = new LinkedHashSet<>(); 17 | 18 | @Override 19 | public String toString() 20 | { 21 | return "" + list.size() + " samplers constructed with following prototypes:\n" + 22 | matchingReport.stream().map(line -> blang.System.out.indentString + line).collect(Collectors.joining("\n")); 23 | } 24 | } -------------------------------------------------------------------------------- /src/main/java/blang/mcmc/internals/Callback.java: -------------------------------------------------------------------------------- 1 | package blang.mcmc.internals; 2 | 3 | public interface Callback 4 | { 5 | public void setProposalLogRatio(double logRatio); 6 | public boolean sampleAcceptance(); 7 | } -------------------------------------------------------------------------------- /src/main/java/blang/mcmc/internals/SamplerBuilderContext.java: -------------------------------------------------------------------------------- 1 | package blang.mcmc.internals; 2 | 3 | import java.util.ArrayList; 4 | import java.util.LinkedHashSet; 5 | import java.util.List; 6 
| import java.util.Set; 7 | import java.util.stream.Collectors; 8 | 9 | import blang.core.Factor; 10 | import blang.core.RealVar; 11 | import blang.inits.experiments.ExperimentResults; 12 | import blang.runtime.internals.objectgraph.GraphAnalysis; 13 | import blang.runtime.internals.objectgraph.Node; 14 | import blang.runtime.internals.objectgraph.ObjectNode; 15 | import blang.runtime.internals.objectgraph.StaticUtils; 16 | 17 | public class SamplerBuilderContext 18 | { 19 | private GraphAnalysis graphAnalysis; 20 | private Node sampledVariable; 21 | private Set _sampledNodes = null; 22 | public final ExperimentResults monitoringStatistics; 23 | 24 | SamplerBuilderContext(GraphAnalysis graphAnalysis, Node sampledVariable, ExperimentResults monitoringStatistics) 25 | { 26 | this.graphAnalysis = graphAnalysis; 27 | this.sampledVariable = sampledVariable; 28 | this.monitoringStatistics = monitoringStatistics; 29 | } 30 | 31 | private Set getSampledNodes() 32 | { 33 | if (_sampledNodes == null) 34 | _sampledNodes = graphAnalysis.accessibilityGraph 35 | .getAccessibleNodes(sampledVariable) 36 | .collect(Collectors.toSet()); 37 | return _sampledNodes; 38 | } 39 | 40 | public Set sampledObjectsAccessibleFrom(Factor factor) 41 | { 42 | return graphAnalysis.accessibilityGraph 43 | .getAccessibleNodes(factor) 44 | .filter(n -> getSampledNodes().contains(n)) 45 | .collect(Collectors.toSet()); 46 | } 47 | 48 | public List connectedFactors(Node node) 49 | { 50 | List result = new ArrayList<>(); 51 | for (ObjectNode n : graphAnalysis.getConnectedFactor(node)) 52 | result.add(n.object); 53 | return result; 54 | } 55 | 56 | public boolean isLatent(Object object) 57 | { 58 | return contain(graphAnalysis.getLatentVariables(), object); 59 | } 60 | 61 | public RealVar getAnnealingParameter() 62 | { 63 | return graphAnalysis.annealingParameter; 64 | } 65 | 66 | public static boolean contain(Set nodes, Object object) 67 | { 68 | return nodes.contains(StaticUtils.node(object)); 69 | } 70 
| 71 | // Make sure the graph analysis does not get cloned later on 72 | // if that instance is saved somehow 73 | void tearDown() 74 | { 75 | this.graphAnalysis = null; 76 | this.sampledVariable = null; 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/main/java/blang/mcmc/internals/SamplerBuilderOptions.java: -------------------------------------------------------------------------------- 1 | package blang.mcmc.internals; 2 | 3 | import blang.inits.Arg; 4 | import blang.inits.DefaultValue; 5 | import blang.inits.experiments.ExperimentResults; 6 | import blang.mcmc.Sampler; 7 | 8 | public class SamplerBuilderOptions 9 | { 10 | @Arg(description = "If the arguments of the annotations @Samplers should be used to " 11 | + "determine a starting set of sampler types.") 12 | @DefaultValue("true") 13 | public boolean useAnnotation = true; 14 | 15 | @Arg(description = "Samplers to be added.") 16 | public SamplerSet additional = new SamplerSet(); 17 | 18 | @Arg(description = "Samplers to be excluded (only useful if useAnnotation = true).") 19 | public SamplerSet excluded = new SamplerSet(); 20 | 21 | public ExperimentResults monitoringStatistics = new ExperimentResults(); 22 | 23 | public static SamplerBuilderOptions startWithOnly(Class thisTypeOfSampler) 24 | { 25 | SamplerBuilderOptions result = new SamplerBuilderOptions(); 26 | result.useAnnotation = false; 27 | result.additional.add(thisTypeOfSampler); 28 | return result; 29 | } 30 | } -------------------------------------------------------------------------------- /src/main/java/blang/mcmc/internals/SamplerMatch.java: -------------------------------------------------------------------------------- 1 | package blang.mcmc.internals; 2 | 3 | import java.util.LinkedHashSet; 4 | import java.util.Set; 5 | import java.util.stream.Collectors; 6 | 7 | import blang.mcmc.Sampler; 8 | 9 | public class SamplerMatch 10 | { 11 | public final Class latentClass; 12 | public final 
Set> matchedSamplers; 13 | public SamplerMatch(Object latentObject) { 14 | this.latentClass = latentObject.getClass(); 15 | this.matchedSamplers = new LinkedHashSet<>(); 16 | } 17 | 18 | @Override 19 | public String toString() { 20 | return latentClass.getSimpleName() + " sampled via: " + matchedSamplers.stream().map(c -> c.getSimpleName()).collect(Collectors.toList()); 21 | } 22 | 23 | @Override 24 | public int hashCode() { 25 | final int prime = 31; 26 | int result = 1; 27 | result = prime * result + ((latentClass == null) ? 0 : latentClass.hashCode()); 28 | result = prime * result + ((matchedSamplers == null) ? 0 : matchedSamplers.hashCode()); 29 | return result; 30 | } 31 | @Override 32 | public boolean equals(Object obj) { 33 | if (this == obj) 34 | return true; 35 | if (obj == null) 36 | return false; 37 | if (getClass() != obj.getClass()) 38 | return false; 39 | SamplerMatch other = (SamplerMatch) obj; 40 | if (latentClass == null) { 41 | if (other.latentClass != null) 42 | return false; 43 | } else if (!latentClass.equals(other.latentClass)) 44 | return false; 45 | if (matchedSamplers == null) { 46 | if (other.matchedSamplers != null) 47 | return false; 48 | } else if (!matchedSamplers.equals(other.matchedSamplers)) 49 | return false; 50 | return true; 51 | } 52 | } -------------------------------------------------------------------------------- /src/main/java/blang/mcmc/internals/SamplerSet.java: -------------------------------------------------------------------------------- 1 | package blang.mcmc.internals; 2 | 3 | import java.util.Collections; 4 | import java.util.LinkedHashSet; 5 | import java.util.List; 6 | import java.util.Optional; 7 | import java.util.Set; 8 | 9 | import blang.inits.DesignatedConstructor; 10 | import blang.inits.Input; 11 | import blang.mcmc.Sampler; 12 | 13 | public class SamplerSet 14 | { 15 | public Set> samplers = new LinkedHashSet<>(); 16 | 17 | @SuppressWarnings("unchecked") 18 | public void add(Class additional) 19 | { 20 | 
samplers.add((Class) additional); 21 | } 22 | 23 | @SuppressWarnings("unchecked") 24 | @DesignatedConstructor 25 | public static SamplerSet parse( 26 | @Input(formatDescription = "Fully qualified instances of blang.mcmc.Sampler") 27 | Optional> qualifiedNames) throws ClassNotFoundException 28 | { 29 | SamplerSet result = new SamplerSet(); 30 | for (String qualifiedName : qualifiedNames.orElse(Collections.emptyList())) 31 | result.samplers.add((Class) Class.forName(qualifiedName)); 32 | return result; 33 | } 34 | } -------------------------------------------------------------------------------- /src/main/java/blang/mcmc/internals/SimplexWritableVariable.xtend: -------------------------------------------------------------------------------- 1 | package blang.mcmc.internals 2 | 3 | import org.eclipse.xtend.lib.annotations.Data 4 | import blang.core.WritableRealVar 5 | import blang.types.DenseSimplex 6 | 7 | @Data 8 | class SimplexWritableVariable implements WritableRealVar { 9 | 10 | val int index 11 | val DenseSimplex simplex 12 | 13 | def double sum() 14 | { 15 | return simplex.get(index) + simplex.get(nextIndex); 16 | } 17 | 18 | def int nextIndex() { 19 | if (index === simplex.nEntries - 1) { 20 | return 0 21 | } else { 22 | return index + 1 23 | } 24 | } 25 | 26 | override set(double value) { 27 | val sum = sum() 28 | val complement = Math.max(0.0, sum - value) // avoid rounding errors creating negative values 29 | simplex.setPair(index, value, nextIndex, complement) 30 | } 31 | 32 | override doubleValue() { 33 | return simplex.get(index) 34 | } 35 | } -------------------------------------------------------------------------------- /src/main/java/blang/mcmc/internals/bps/Likelihood2EnergyAdaptor.java: -------------------------------------------------------------------------------- 1 | package blang.mcmc.internals.bps; 2 | 3 | import java.util.List; 4 | 5 | import org.apache.commons.math3.analysis.differentiation.FiniteDifferencesDifferentiator; 6 | 7 | import 
/**
 * One of several mechanisms to set nodes as observed. Another mechanism is to make the 
 * observed immutable. The present mechanism is needed for cases such as matrices or arrays where
 * some but not all the entries are required to be observed.
 */
@Data
class Observations {
  /**
   * All nodes accessible from these roots will be marked as observed in the accessibility graph analysis.
   */
  val LinkedHashSet<Node> observationRoots = new LinkedHashSet
  
  /**
   * Registers the node built from the given object as an observation root.
   * 
   * @return the object passed in, so that the call can be used inline
   */
  def <T> T markAsObserved(T object) {
    val Node node = StaticUtils::node(object)
    observationRoots.add(node)
    return object
  }
  
  /**
   * Registers the given node as an observation root.
   */
  def void markAsObserved(Node node) {
    observationRoots.add(node)
  }
  
  // Catch some possible mistakes: these overloads intercept primitive 
  // arguments and always fail.
  
  def void markAsObserved(double object) {
    throw new RuntimeException("markAsObserved(..) should not be called on a primitive double.")
  }
  
  def void markAsObserved(int object) {
    throw new RuntimeException("markAsObserved(..) should not be called on a primitive int.")
  }
}
18 | System.out.println("No post-processing requested. Use '--postProcessor DefaultPostProcessor' or run after the fact using 'postprocess --help'") 19 | } 20 | } 21 | 22 | } -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/doc/Categories.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.doc; 2 | 3 | public class Categories 4 | { 5 | public static final String 6 | reference = "Reference", 7 | tools = "Tools"; 8 | 9 | private Categories() {} 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/doc/contents/BlangCLI.xtend: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.doc.contents 2 | 3 | import blang.xdoc.components.Document 4 | import blang.runtime.internals.doc.Categories 5 | import blang.xdoc.components.Code.Language 6 | 7 | class BlangCLI { 8 | 9 | public val static Document page = new Document("CLI") [ 10 | 11 | category = Categories::tools 12 | 13 | section("Installing Blang Command Line Interface (CLI)") [ 14 | 15 | it += '''The prerequisites for the CLI installation process are:''' 16 | 17 | orderedList[ 18 | 19 | it += '''A UNIX-compatible environment running «SYMB»bash«ENDSYMB». This includes, in particular, Mac OS 20 | X, where bash is the default terminal interpreter when launching Terminal.app.''' 21 | 22 | it += '''The «SYMB»git«ENDSYMB» command''' 23 | 24 | it += '''The Java Software Development Kit (SDK), version 11. The Java runtime environment is not sufficient, as 25 | compilation of models requires compilation into the Java Virtual Machine. Type «SYMB»javac -version«ENDSYMB» to 26 | test if the Java SDK is installed. 
If not, the Java SDK is freely available at 27 | «LINK("https://openjdk.java.net/")»https://openjdk.java.net/«ENDLINK».''' 28 | ] 29 | 30 | it += '''The following installation process is most thoroughly tested on Mac OS X, which is the primary 31 | supported platform at the moment, however users have reported installing it suc- cessfully on certain Linux 32 | and Windows configurations and we plan to expand the set of officially supported platforms to both in the near future. 33 | To install the CLI tools, input the following commands in a bash terminal interpreter: 34 | ''' 35 | 36 | code(Language::sh, ''' 37 | git clone https://github.com/UBC-Stat-ML/blangSDK.git 38 | cd blangSDK 39 | source setup-cli.sh 40 | cd .. 41 | ''') 42 | 43 | ] 44 | ] 45 | } -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/doc/contents/BlangWeb.xtend: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.doc.contents 2 | 3 | import blang.xdoc.components.Document 4 | import blang.runtime.internals.doc.Categories 5 | 6 | class BlangWeb { 7 | 8 | public val static Document page = new Document("Blang via web") [ 9 | 10 | category = Categories::tools 11 | 12 | section("Blang on the cloud via the browser") [ 13 | 14 | it += '''Link to web app «LINK("https://silico.io")»is available here«ENDLINK».''' 15 | 16 | section("System requirements") [ 17 | it += ''' 18 | A modern browser (tested on Chrome and Firefox).''' 19 | ] 20 | 21 | section("Creating a Blang project") [ 22 | orderedList[ 23 | it += '''After signing up, create a «SYMB»Model«ENDSYMB».''' 24 | it += '''Using the gear icon, create a new file, name it to end in «SYMB».bl«ENDSYMB».''' 25 | it += '''With the gear icon again, set the file as «SYMB»entry«ENDSYMB»''' 26 | ] 27 | ] 28 | 29 | section("Using a Blang model") [ 30 | orderedList[ 31 | it += '''Click on «SYMB»Run«ENDSYMB».''' 32 | it += '''Command 
line options can be provided with a file name «SYMB»configuration.txt«ENDSYMB».''' 33 | ] 34 | ] 35 | 36 | ] 37 | 38 | ] 39 | 40 | } -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/doc/contents/CreatingTypes.xtend: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.doc.contents 2 | 3 | import blang.xdoc.components.Document 4 | import blang.runtime.internals.doc.Categories 5 | 6 | class CreatingTypes { 7 | 8 | public val static Document page = new Document("Creating random types") [ 9 | 10 | category = Categories::reference 11 | 12 | section("Creating new types: overview") [ 13 | it += ''' 14 | The basic steps involved to create custom types are described in 15 | the «LINK(GettingStarted::page)»Getting Started page«ENDLINK». 16 | 17 | To handle more complicated cases, read the following: 18 | ''' 19 | unorderedList[ 20 | it += ''' 21 | In the «LINK(InferenceAndRuntime::page)»Inference and 22 | Runtime page«ENDLINK», you can find how custom samplers are 23 | automatically matched to target types. 24 | ''' 25 | it += ''' 26 | In the «LINK(InputOutput::page)»Input and Output page«ENDLINK», 27 | you can find how to load observations for the custom types, and 28 | how to output samples. 29 | ''' 30 | it += ''' 31 | In the «LINK(Testing::page)»Testing page«ENDLINK», you can find 32 | information on setting automated tests to check correctness of your 33 | implementation. 34 | ''' 35 | ] 36 | it += ''' 37 | Also consider transforming the problem of sampling your new type into 38 | a problem that can be handled using built-in sampler, which include: 39 | ''' 40 | orderedList[ 41 | it += ''' 42 | «SYMB»RealSliceSampler«ENDSYMB»: implementation of the Slice Sampler 43 | «LINK("https://projecteuclid.org/download/pdf_1/euclid.aos/1056562461")»(Neal, 2003)«ENDLINK» 44 | with doubling and shrinking. 
A fixed starting interval can also be provided 45 | if only the shrinking procedure is required (for example this second 46 | variant is used internally for simplex sampling in «SYMB»SimplexSampler«ENDSYMB»). 47 | ''' 48 | it += ''' 49 | «SYMB»IntSliceSampler«ENDSYMB»: which provides the same facilities as 50 | above but for integers. The fixed starting variant is used internally in 51 | categorical realization sampling, «SYMB»CategoricalSampler«ENDSYMB». 52 | ''' 53 | it += ''' 54 | «SYMB»MHSampler«ENDSYMB»: an abstract class providing a basis for custom 55 | Metropolis-Hastings samplers. See 56 | «SYMB»blang.validation.internals.fixtures.IntNaiveSampler«ENDSYMB» for an example. 57 | ''' 58 | ] 59 | ] 60 | ] 61 | 62 | } -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/doc/contents/Empty.xtend: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.doc.contents 2 | 3 | import blang.xdoc.components.Document 4 | import blang.runtime.internals.doc.Categories 5 | 6 | 7 | class Empty { 8 | 9 | public val static Document page = new Document("Template") [ 10 | 11 | category = Categories::reference 12 | 13 | // TODO 14 | ] 15 | 16 | } 17 | 18 | -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/doc/contents/Javadoc.xtend: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.doc.contents 2 | 3 | import blang.xdoc.components.Document 4 | import blang.runtime.internals.doc.Categories 5 | 6 | 7 | class Javadoc { 8 | 9 | public val static Document page = new Document("Javadoc") [ 10 | 11 | category = Categories::reference 12 | 13 | unorderedList[ 14 | it += '''«LINK("javadoc-inits/index.html")»Inits JavaDoc«ENDLINK»''' 15 | it += '''«LINK("javadoc-dsl/index.html")»Blang DSL JavaDoc«ENDLINK»''' 16 | it += 
'''«LINK("javadoc-sdk/index.html")»Blang SDK JavaDoc«ENDLINK»''' 17 | it += '''«LINK("javadoc-xlinear/index.html")»xlinear JavaDoc«ENDLINK»''' 18 | ] 19 | ] 20 | 21 | } -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/AnnealingStructure.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | import blang.core.AnnealedFactor; 7 | import blang.core.Factor; 8 | import blang.core.LogScaleFactor; 9 | import blang.mcmc.internals.ExponentiatedFactor; 10 | import blang.types.internals.RealScalar; 11 | 12 | public class AnnealingStructure 13 | { 14 | public final RealScalar annealingParameter; 15 | 16 | // Those that are not annealed (e.g. priors, Constrained, etc) 17 | public final List fixedLogScaleFactors = new ArrayList<>(); 18 | public final List otherFactors = new ArrayList<>(); 19 | 20 | // Those that are annealed 21 | public final List exponentiatedFactors = new ArrayList<>(); 22 | public final List otherAnnealedFactors = new ArrayList<>(); // custom (not yet used) 23 | 24 | public AnnealingStructure(RealScalar annealingParameter) 25 | { 26 | this.annealingParameter = annealingParameter; 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/ArrayConstituentNode.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | 4 | 5 | public class ArrayConstituentNode extends ConstituentNode 6 | { 7 | public ArrayConstituentNode(Object container, Integer key) 8 | { 9 | super(container, key); 10 | } 11 | 12 | @Override 13 | public Object resolve() 14 | { 15 | if (container.getClass().getComponentType().isPrimitive()) 16 | return null; 17 | Object [] array = 
(Object[]) container; 18 | return array[key]; 19 | } 20 | 21 | @Override 22 | public String toStringSummary() 23 | { 24 | return "" + key; 25 | } 26 | 27 | @Override 28 | public boolean isMutable() 29 | { 30 | return true; 31 | } 32 | } -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/ArrayView.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | import com.google.common.collect.ImmutableList; 4 | 5 | 6 | 7 | abstract class ArrayView 8 | { 9 | public final ImmutableList viewedIndices; 10 | 11 | public ArrayView(ImmutableList viewedIndices) 12 | { 13 | this.viewedIndices = viewedIndices; 14 | } 15 | } -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/ConstituentNode.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | public abstract class ConstituentNode implements Node 4 | { 5 | /** 6 | * 7 | * @return null if a primitive, the object referred to otherwise 8 | */ 9 | public abstract Object resolve(); 10 | 11 | public boolean resolvesToObject() 12 | { 13 | return resolve() != null; 14 | } 15 | 16 | // These should stay protected and without getter/setter 17 | // Making them accessible is probably symptom of a bug, see e.g. 
commit b8c2f2f6df416c2d64527c373356965b1daec583 18 | protected final Object container; 19 | protected final K key; 20 | 21 | public ConstituentNode(Object container, K key) 22 | { 23 | if (container == null) 24 | throw new RuntimeException(); 25 | this.container = container; 26 | this.key = key; 27 | } 28 | 29 | @Override 30 | public String toString() 31 | { 32 | return "ConstituentNode[containerClass=" + container.getClass() + ",containerObjectId=" + System.identityHashCode(container) + ",key=" + key + "]"; 33 | } 34 | 35 | @Override 36 | public int hashCode() 37 | { 38 | final int prime = 31; 39 | int result = 1; 40 | result = prime * result 41 | + ((container == null) ? 0 : 42 | // IMPORTANT distinction from automatically generated hashCode(): 43 | // use identity hash code for the container (but not the key), 44 | // as e.g. large integer keys will not point to the same address 45 | System.identityHashCode(container)); 46 | result = prime * result + ((key == null) ? 0 : key.hashCode()); 47 | return result; 48 | } 49 | @Override 50 | public boolean equals(Object obj) 51 | { 52 | if (this == obj) 53 | return true; 54 | if (obj == null) 55 | return false; 56 | if (getClass() != obj.getClass()) 57 | return false; 58 | @SuppressWarnings("rawtypes") 59 | ConstituentNode other = (ConstituentNode) obj; 60 | if (container == null) 61 | { 62 | if (other.container != null) 63 | return false; 64 | } else if ( 65 | // IMPORTANT: see similar comment in hashCode() 66 | container != other.container) 67 | //!container.equals(other.container)) 68 | return false; 69 | if (key == null) 70 | { 71 | if (other.key != null) 72 | return false; 73 | } else if (!key.equals(other.key)) 74 | return false; 75 | return true; 76 | } 77 | } -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/DeepCloner.java: -------------------------------------------------------------------------------- 1 | package 
blang.runtime.internals.objectgraph; 2 | 3 | import com.rits.cloning.Cloner; 4 | 5 | public class DeepCloner 6 | { 7 | public static final Cloner cloner = new Cloner(); // thread safe 8 | // cloner.nullTransient = true // not a good idea, Java SDK uses transient liberally e.g the data array in ArrayList! 9 | // override registerFastCloners() {} // not needed after all, but may want to add LinkedHashSet at some point 10 | 11 | public static T deepClone(T object) 12 | { 13 | return cloner.deepClone(object); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/DoubleArrayView.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | import com.google.common.collect.ImmutableList; 4 | 5 | 6 | 7 | public final class DoubleArrayView extends ArrayView 8 | { 9 | @ViewedArray 10 | private final double[] viewedArray; 11 | 12 | public DoubleArrayView(ImmutableList viewedIndices, double[] viewedArray) 13 | { 14 | super(viewedIndices); 15 | this.viewedArray = viewedArray; 16 | } 17 | 18 | public double get(int indexIndex) 19 | { 20 | return viewedArray[viewedIndices.get(indexIndex)]; 21 | } 22 | 23 | public void set(int indexIndex, double object) 24 | { 25 | viewedArray[viewedIndices.get(indexIndex)] = object; 26 | } 27 | } -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/ExplorationRule.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | import java.util.List; 4 | 5 | 6 | 7 | public interface ExplorationRule 8 | { 9 | /** 10 | * return null if the rule does not apply to this object, else, a list of constituents to recurse to 11 | */ 12 | public List> explore(Object object); 13 | } 
-------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/FieldConstituentNode.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | import java.lang.reflect.Field; 4 | import java.lang.reflect.Modifier; 5 | 6 | import briefj.ReflexionUtils; 7 | 8 | 9 | 10 | public class FieldConstituentNode extends ConstituentNode 11 | { 12 | public FieldConstituentNode(Object container, Field key) 13 | { 14 | super(container, key); 15 | } 16 | 17 | @Override 18 | public Object resolve() 19 | { 20 | if (key.getType().isPrimitive()) 21 | return null; 22 | return ReflexionUtils.getFieldValue(key, container); 23 | } 24 | 25 | @Override 26 | public boolean isMutable() 27 | { 28 | return !Modifier.isFinal(key.getModifiers()); 29 | } 30 | 31 | @Override 32 | public String toStringSummary() 33 | { 34 | return key.getName(); 35 | } 36 | } -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/IntArrayView.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | import com.google.common.collect.ImmutableList; 4 | 5 | 6 | 7 | public final class IntArrayView extends ArrayView 8 | { 9 | @ViewedArray 10 | private final int[] viewedArray; 11 | 12 | public IntArrayView(ImmutableList viewedIndices, int[] viewedArray) 13 | { 14 | super(viewedIndices); 15 | this.viewedArray = viewedArray; 16 | } 17 | 18 | public int get(int indexIndex) 19 | { 20 | return viewedArray[viewedIndices.get(indexIndex)]; 21 | } 22 | 23 | public void set(int indexIndex, int object) 24 | { 25 | viewedArray[viewedIndices.get(indexIndex)] = object; 26 | } 27 | } -------------------------------------------------------------------------------- 
/src/main/java/blang/runtime/internals/objectgraph/MapConstituentNode.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | import java.util.Map; 4 | 5 | public class MapConstituentNode extends ConstituentNode 6 | { 7 | public MapConstituentNode(Object container, Object key) 8 | { 9 | super(container, key); 10 | } 11 | 12 | @Override 13 | public Object resolve() 14 | { 15 | Map map = (Map) container; 16 | return map.get(key); 17 | } 18 | 19 | @Override 20 | public String toStringSummary() 21 | { 22 | return "" + key; 23 | } 24 | 25 | @Override 26 | public boolean isMutable() 27 | { 28 | return true; 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/Node.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | /** 4 | * A node (factor, variable, or component) in the accessibility graph. 5 | * 6 | * Each such node is associated with a unique address in memory (i.e. an instance is essentially a pointer). 7 | * To be more precise, this association is established in one of two ways: 8 | * 9 | * 1. a reference to an object o (with hashCode and equals based on o's identity instead of o's potentially overridden hashCode and equals) 10 | * 2. a reference to a container c (e.g., an array, or List) as well as a key k (in this case, hashCode and equals are based on a 11 | * combination of the identity of c, and the standard hashCode and equals of k) 12 | * 13 | * Case (1) is called an object node, and case (2) is called a constituent node. 14 | * 15 | * An important special case of a constituent node: the container c being a regular object, and the key k being a Field of c's class 16 | * 17 | * Constituent nodes are needed for example to obtain slices of 18 | * a matrix, partially observed arrays, etc. 
19 | * 20 | * We assume all implementations provide appropriate hashCode and equals, in particular, by-passing custom hashCode and 21 | * equals of enclosed objects. 22 | * 23 | * @author Alexandre Bouchard (alexandre.bouchard@gmail.com) 24 | * 25 | */ 26 | public interface Node 27 | { 28 | public default String toStringSummary() 29 | { 30 | return toString(); 31 | } 32 | public boolean isMutable(); 33 | } -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/ObjectArrayView.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | import com.google.common.collect.ImmutableList; 4 | 5 | 6 | 7 | public final class ObjectArrayView extends ArrayView 8 | { 9 | @ViewedArray 10 | private final T[] viewedArray; 11 | 12 | public ObjectArrayView(ImmutableList viewedIndices, T[] viewedArray) 13 | { 14 | super(viewedIndices); 15 | this.viewedArray = viewedArray; 16 | } 17 | 18 | public T get(int indexIndex) 19 | { 20 | return viewedArray[viewedIndices.get(indexIndex)]; 21 | } 22 | 23 | public void set(int indexIndex, T object) 24 | { 25 | viewedArray[viewedIndices.get(indexIndex)] = object; 26 | } 27 | } -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/ObjectNode.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | 4 | public class ObjectNode implements Node 5 | { 6 | public final T object; 7 | 8 | public ObjectNode(T object) 9 | { 10 | if (object == null) 11 | throw new RuntimeException(); 12 | this.object = object; 13 | } 14 | 15 | @Override 16 | public int hashCode() 17 | { 18 | return System.identityHashCode(object); 19 | } 20 | 21 | @Override 22 | public String toString() 23 | { 24 | return "ObjectNode[class=" + object.getClass().getName() + 
",objectId=" + System.identityHashCode(object) + "]"; 25 | } 26 | 27 | @Override 28 | public String toStringSummary() 29 | { 30 | return "" + object.getClass().getName() + "@" + System.identityHashCode(object); 31 | } 32 | 33 | @Override 34 | public boolean equals(Object obj) 35 | { 36 | if (this == obj) 37 | return true; 38 | if (!(obj instanceof ObjectNode)) 39 | return false; 40 | return ((ObjectNode) obj).object == this.object; 41 | } 42 | 43 | @Override 44 | public boolean isMutable() 45 | { 46 | return false; // fields or array entries only are deemed mutable 47 | } 48 | } -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/SkipDependency.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | import java.lang.annotation.ElementType; 4 | import java.lang.annotation.Retention; 5 | import java.lang.annotation.RetentionPolicy; 6 | import java.lang.annotation.Target; 7 | 8 | @Retention(RetentionPolicy.RUNTIME) 9 | @Target(ElementType.FIELD) 10 | public @interface SkipDependency { 11 | public boolean isMutable(); 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/SkippedFieldConstituentNode.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | import java.lang.reflect.Field; 4 | 5 | 6 | 7 | public class SkippedFieldConstituentNode extends ConstituentNode 8 | { 9 | public SkippedFieldConstituentNode(Object container, Field key) 10 | { 11 | super(container, key); 12 | } 13 | 14 | @Override 15 | public Object resolve() 16 | { 17 | return null; 18 | } 19 | 20 | @Override 21 | public boolean isMutable() 22 | { 23 | /* 24 | * If the field is skipped, it is to hide mutable stuff, so stating that 25 | * the field is modifiable 
will have the correct behavior when doing recursive 26 | * analysis of mutability. 27 | * See Issue #62 in blang DSL project. 28 | */ 29 | return true; 30 | } 31 | 32 | @Override 33 | public String toStringSummary() 34 | { 35 | return key.getName(); 36 | } 37 | } -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/StaticUtils.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | import java.lang.reflect.Field; 4 | import java.util.Iterator; 5 | import java.util.List; 6 | 7 | import briefj.ReflexionUtils; 8 | 9 | public class StaticUtils 10 | { 11 | public static Node node(T object) 12 | { 13 | if (object instanceof Node) 14 | return (Node) object; 15 | else 16 | return new ObjectNode(object); 17 | } 18 | 19 | @SuppressWarnings("unchecked") 20 | public static T tryCasting(Node node, Class type) 21 | { 22 | // node instanceof type 23 | if (type.isAssignableFrom(node.getClass())) 24 | return (T) node; 25 | else if (node instanceof ObjectNode) 26 | { 27 | ObjectNode objectNode = (ObjectNode) node; 28 | if (type.isAssignableFrom(objectNode.object.getClass())) 29 | return (T) objectNode.object; 30 | } 31 | return null; 32 | } 33 | 34 | public static List getDeclaredFields(Class aClass) 35 | { 36 | List result = ReflexionUtils.getDeclaredFields(aClass, true); 37 | Iterator resultsIter = result.iterator(); 38 | while (resultsIter.hasNext()) 39 | if (resultsIter.next().getName().equals("$jacocoData")) // work around required for checking test-case coverage 40 | resultsIter.remove(); 41 | return result; 42 | } 43 | 44 | private StaticUtils() {} 45 | } 46 | -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/VariableUtils.java: -------------------------------------------------------------------------------- 1 | package 
blang.runtime.internals.objectgraph; 2 | 3 | import java.lang.annotation.Annotation; 4 | import java.util.LinkedHashSet; 5 | import java.util.Set; 6 | 7 | import blang.mcmc.Sampler; 8 | import blang.mcmc.Samplers; 9 | import blang.runtime.internals.RecursiveAnnotationProducer; 10 | 11 | public class VariableUtils 12 | { 13 | public static boolean isVariable(Class c) 14 | { 15 | for (Annotation a : c.getAnnotations()) 16 | if (a instanceof Samplers) 17 | return true; 18 | return false; 19 | } 20 | 21 | public static Set> annotatedSamplers(Class latentNode) 22 | { 23 | Set> result = new LinkedHashSet<>(); 24 | RecursiveAnnotationProducer> annotationsProducer = RecursiveAnnotationProducer.ofClasses(Samplers.class, true); 25 | result.addAll(annotationsProducer.getProducts(latentNode)); 26 | return result; 27 | } 28 | 29 | private VariableUtils() {} 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/blang/runtime/internals/objectgraph/ViewedArray.java: -------------------------------------------------------------------------------- 1 | package blang.runtime.internals.objectgraph; 2 | 3 | import java.lang.annotation.ElementType; 4 | import java.lang.annotation.Retention; 5 | import java.lang.annotation.RetentionPolicy; 6 | import java.lang.annotation.Target; 7 | 8 | 9 | 10 | @Retention(RetentionPolicy.RUNTIME) 11 | @Target({ElementType.FIELD}) 12 | public @interface ViewedArray 13 | { 14 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/AnnealingParameter.xtend: -------------------------------------------------------------------------------- 1 | package blang.types 2 | 3 | import blang.core.RealVar 4 | 5 | class AnnealingParameter implements RealVar { 6 | var RealVar param = null 7 | 8 | def void _set(RealVar _param) { 9 | this.param = _param 10 | } 11 | 12 | override doubleValue() { 13 | return param.doubleValue 14 | } 15 | 16 | } 
-------------------------------------------------------------------------------- /src/main/java/blang/types/DenseSimplex.xtend: -------------------------------------------------------------------------------- 1 | package blang.types 2 | 3 | import org.eclipse.xtend.lib.annotations.Delegate 4 | import blang.mcmc.Samplers 5 | import blang.mcmc.SimplexSampler 6 | import org.eclipse.xtend.lib.annotations.Accessors 7 | import blang.types.internals.Delegator 8 | import xlinear.DenseMatrix 9 | import bayonet.math.NumericalUtils 10 | 11 | import static extension xlinear.MatrixExtensions.* 12 | import xlinear.internals.MatrixVisitorEditInPlace 13 | 14 | /** Vector of entries summing to one. 15 | * 16 | * We do not enforce positive constraints here to facilitate undoing sampling moves. 17 | */ 18 | @Samplers(SimplexSampler) 19 | class DenseSimplex implements Simplex, DenseMatrix, Delegator { 20 | @Accessors(PUBLIC_GETTER) 21 | @Delegate DenseMatrix delegate 22 | 23 | new (DenseMatrix matrix) { 24 | NumericalUtils::checkIsClose(matrix.sum, 1.0) 25 | this.delegate = matrix 26 | } 27 | 28 | override void editInPlace(MatrixVisitorEditInPlace visitor) { 29 | delegate.editInPlace(visitor) 30 | NumericalUtils::checkIsClose(this.sum, 1.0) 31 | } 32 | 33 | /** Set a pair of entries, checking their sum is the same before and after */ 34 | def void setPair(int index1, double value1, int index2, double value2) { 35 | val double old = get(index1) + get(index2) 36 | NumericalUtils::checkIsClose(old, value1 + value2) 37 | delegate.set(index1, value1) 38 | delegate.set(index2, value2) 39 | } 40 | 41 | override void set(int i, int j, double value) { 42 | throw new RuntimeException("Use setPair instead"); 43 | } 44 | 45 | override void set(int i, double value) { 46 | throw new RuntimeException("Use setPair instead"); 47 | } 48 | 49 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/DenseTransitionMatrix.xtend: 
-------------------------------------------------------------------------------- 1 | package blang.types 2 | 3 | import org.eclipse.xtend.lib.annotations.Data 4 | import org.eclipse.xtend.lib.annotations.Delegate 5 | import org.eclipse.xtend.lib.annotations.Accessors 6 | import blang.types.internals.Delegator 7 | import xlinear.DenseMatrix 8 | import xlinear.internals.MatrixVisitorEditInPlace 9 | 10 | import static extension xlinear.MatrixExtensions.* 11 | import bayonet.math.NumericalUtils 12 | 13 | /** Matrix where each row is a DenseSimplex. */ 14 | @Data 15 | class DenseTransitionMatrix implements TransitionMatrix, DenseMatrix, Delegator { 16 | @Accessors(PUBLIC_GETTER) 17 | @Delegate 18 | val DenseMatrix delegate 19 | 20 | /** Get a view into a row. */ 21 | override DenseSimplex row(int i) { 22 | return new DenseSimplex(delegate.row(i)) 23 | } 24 | 25 | override void editInPlace(MatrixVisitorEditInPlace visitor) { 26 | delegate.editInPlace(visitor) 27 | for (var int rowIndex = 0; rowIndex < nRows; rowIndex++) { 28 | NumericalUtils::checkIsClose(row(rowIndex).sum, 1.0) 29 | } 30 | } 31 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/Index.xtend: -------------------------------------------------------------------------------- 1 | package blang.types 2 | 3 | import org.eclipse.xtend.lib.annotations.Data 4 | 5 | /** 6 | * An Index of type K in a specified Plate. 7 | * 8 | * K: the type of key, such as Integer, String, date or space coordinate 9 | * It is assumed that K is Immutable (and not a random variable). 10 | */ 11 | @Data // important! 
this is used in hash tables 12 | class Index { 13 | public val Plate plate 14 | public val K key 15 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/Simplex.java: -------------------------------------------------------------------------------- 1 | package blang.types; 2 | 3 | import xlinear.Matrix; 4 | 5 | public interface Simplex extends Matrix 6 | { 7 | 8 | } 9 | -------------------------------------------------------------------------------- /src/main/java/blang/types/SpikedRealVar.xtend: -------------------------------------------------------------------------------- 1 | package blang.types 2 | 3 | import blang.core.RealVar 4 | import blang.core.IntVar 5 | import blang.types.StaticUtils 6 | 7 | class SpikedRealVar implements RealVar { 8 | public val IntVar selected = StaticUtils::latentInt 9 | public val RealVar continuousPart = StaticUtils::latentReal 10 | 11 | override doubleValue() { 12 | if (selected.intValue < 0 || selected.intValue > 1) 13 | StaticUtils::invalidParameter() 14 | if (selected.intValue == 0) return 0.0 15 | else return continuousPart.doubleValue 16 | } 17 | 18 | override toString() { "" + doubleValue } 19 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/TransitionMatrix.java: -------------------------------------------------------------------------------- 1 | package blang.types; 2 | 3 | import xlinear.Matrix; 4 | 5 | public interface TransitionMatrix extends Matrix 6 | { 7 | public Simplex row(int i); 8 | 9 | @Override 10 | default public void set(int i, int j, double value) { 11 | throw new RuntimeException("Get row(..).setPair instead"); 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/blang/types/internals/ColumnName.xtend: -------------------------------------------------------------------------------- 1 | package blang.types.internals 2 | 3 | 
import org.eclipse.xtend.lib.annotations.Data import blang.inits.DesignatedConstructor 4 | import blang.inits.Input 5 | import com.rits.cloning.Immutable 6 | 7 | /** 8 | * Column names for data set read from files or databases. 9 | * Case insensitive (lower-cased with the locale-neutral Locale.ROOT) and all whitespace is dropped. 10 | */ 11 | @Data 12 | @Immutable 13 | class ColumnName { 14 | /** Normalized name: whitespace removed, lower case. */ public val String string 15 | /** Normalizes the raw name. Locale.ROOT keeps the result independent of the JVM default locale, so that e.g. "ID" is never lower-cased with a dotless i under a Turkish locale. */ @DesignatedConstructor 16 | new(@Input String string) { 17 | this.string = string.replaceAll("\\s+", "").toLowerCase(java.util.Locale.ROOT) 18 | } 19 | override String toString() { string } 20 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/internals/Delegator.java: -------------------------------------------------------------------------------- 1 | package blang.types.internals; 2 | 3 | public interface Delegator 4 | { 5 | T getDelegate(); 6 | } 7 | -------------------------------------------------------------------------------- /src/main/java/blang/types/internals/HashPlate.xtend: -------------------------------------------------------------------------------- 1 | package blang.types.internals 2 | 3 | import blang.types.Plate 4 | import java.util.Map 5 | import java.util.LinkedHashMap 6 | import java.util.Set 7 | import blang.types.Index 8 | import org.eclipse.xtend.lib.annotations.Accessors 9 | import java.util.LinkedHashSet 10 | import blang.io.DataSource 11 | import java.util.Optional 12 | import com.rits.cloning.Immutable 13 | import java.util.Collection 14 | 15 | /** 16 | * A Plate using a DataSource to load and store indices in a hash table. 17 | */ 18 | @Immutable 19 | class HashPlate implements Plate { 20 | 21 | @Accessors(PUBLIC_GETTER) 22 | val ColumnName name 23 | 24 | /** 25 | * Maximum number of items to load. 
26 | */ 27 | val int maxSize 28 | 29 | /** Cache of the indices already computed, keyed by query. */ val Map>> indices = new LinkedHashMap 30 | 31 | val IndexedDataSource index 32 | 33 | val Parser parser 34 | 35 | /** Returns the indices matching the query, parsing at most maxSize keys (in the iteration order of the key set) and caching the result per query. NOTE(review): index.getStrings(query) is a plain map lookup and can return null when no row matches, which would fail in the loop below - confirm callers only pass queries present in the data. */ override Collection> indices(Query query) { 36 | if (indices.containsKey(query)) { 37 | return indices.get(query) 38 | } 39 | val Set keys = index.getStrings(query) 40 | val Set> result = new LinkedHashSet 41 | var int i = 0 42 | for (String key : keys) { 43 | if (i++ < maxSize) { 44 | val K parsed = parser.parse(key) 45 | result.add(new Index(this, parsed)) 46 | } 47 | } 48 | indices.put(query, result) 49 | return result 50 | } 51 | 52 | /** Parses a key with the parser supplied at construction. */ override K parse(String string) { 53 | return parser.parse(string) 54 | } 55 | 56 | /** 57 | * If optionalMaxSize missing, maxSize is set to Integer.MAX_VALUE. 58 | */ 59 | new(ColumnName name, DataSource dataSource, Parser parser, Optional optionalMaxSize) { 60 | this.name = name 61 | this.index = new IndexedDataSource(name, dataSource) 62 | this.parser = parser 63 | this.maxSize = optionalMaxSize.orElse(Integer.MAX_VALUE) 64 | } 65 | 66 | /** The normalized column name of this plate. */ override String toString() { 67 | return name.string 68 | } 69 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/internals/HashPlated.xtend: -------------------------------------------------------------------------------- 1 | package blang.types.internals 2 | 3 | import blang.types.Plated 4 | import blang.io.DataSource 5 | import java.util.Map 6 | import blang.io.NA 7 | import java.util.LinkedHashMap 8 | import blang.types.internals.Query.QueryType 9 | 10 | /** 11 | * A Plated using a DataSource to load and store random variables or parameters in a hash table. 
12 | */ 13 | class HashPlated implements Plated { 14 | 15 | val ColumnName columnName 16 | 17 | /** Cache of the values already parsed, keyed by query. */ val Map variables = new LinkedHashMap 18 | 19 | val IndexedDataSource index 20 | 21 | val Parser parser 22 | 23 | /** Returns the value for the query: parsed from the data source on first access, served from the cache afterwards. */ override T get(Query query) { 24 | if (variables.containsKey(query)) { 25 | return variables.get(query) 26 | } 27 | val T result = parser.parse(getString(query)) 28 | variables.put(query, result) 29 | return result 30 | } 31 | 32 | /** Raw string for the query: NA::SYMBOL when no data source is present, otherwise whatever index.getString(query) returns. */ private def String getString(Query query) { 33 | if (!index.dataSource.present) { 34 | return NA::SYMBOL 35 | } 36 | return index.getString(query) 37 | } 38 | 39 | new(ColumnName columnName, DataSource dataSource, Parser parser) { 40 | this.columnName = columnName 41 | this.index = new IndexedDataSource(columnName, dataSource) 42 | this.parser = parser 43 | } 44 | 45 | /** Only the entries accessed so far (the contents of the cache). */ override entries() { 46 | return variables.entrySet 47 | } 48 | 49 | override String toString() { 50 | return toString(this) 51 | } 52 | 53 | /** Tab-separated rendering of a Plated: a header line with the plate names followed by "value", then one line per entry with its index keys and its value. The header is only written when there is at least one entry. */ def static String toString(Plated plated) { 54 | val StringBuilder result = new StringBuilder 55 | var boolean first = true 56 | for (entry : plated.entries) { 57 | if (first) { 58 | result.append(entry.key.indices.map[plate.name].join("\t") + "\tvalue" + "\n") 59 | first = false 60 | } 61 | result.append(entry.key.indices.map[key].join("\t") + "\t" + entry.value + "\n") 62 | } 63 | return result.toString 64 | } 65 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/internals/IndexedDataSource.xtend: -------------------------------------------------------------------------------- 1 | package blang.types.internals 2 | 3 | import blang.io.DataSource 4 | import java.util.Map 5 | import java.util.Set 6 | import blang.types.internals.Query.QueryType 7 | import java.util.LinkedHashMap 8 | import blang.types.Plate 9 | import blang.types.Index 10 | import briefj.BriefMaps 11 | import org.eclipse.xtend.lib.annotations.Accessors 12 | import com.rits.cloning.Immutable 13 | 14 | 
/** 15 | * Utility to quickly access entries in DataSource. 16 | */ 17 | @Immutable 18 | class IndexedDataSource { 19 | 20 | val ColumnName columnName 21 | 22 | @Accessors(PUBLIC_GETTER) 23 | val DataSource dataSource 24 | 25 | /** One lookup table per QueryType, built lazily by computeCache. */ val Map>> cache = new LinkedHashMap 26 | 27 | new(ColumnName columnName, DataSource dataSource) { 28 | this.columnName = columnName 29 | this.dataSource = dataSource 30 | } 31 | 32 | /** Returns the values of this column recorded for the query, building the table for the query's type on first use. The result is a plain map lookup, so it is null when the query was not seen in the data. */ def Set getStrings(Query query) { 33 | val QueryType queryType = query.type 34 | if (!cache.containsKey(queryType)) { 35 | computeCache(queryType) 36 | } 37 | return cache.get(queryType).get(query) 38 | } 39 | 40 | /** Single-value variant of getStrings: null when there is no match, RuntimeException when the set holds more than one value. */ def String getString(Query query) { 41 | val Set strings = getStrings(query) 42 | if (strings === null) { 43 | return null 44 | } 45 | if (strings.size > 1) { 46 | throw new RuntimeException("More than one match for " + query) 47 | } 48 | return strings.iterator.next 49 | } 50 | 51 | /** Reads every line of the data source and groups the values of this column by the query obtained from parsing, for each plate of the query type, that plate's column on the line. */ def void computeCache(QueryType queryType) { 52 | val Map> currentCache = new LinkedHashMap 53 | for (Map line : dataSource.read) { 54 | val Query curQuery = Query.build 55 | for (Plate curPlate : queryType.plates) { 56 | val Object parsed = curPlate.parse(line.get(curPlate.name)) 57 | curQuery.indices.add(new Index(curPlate, parsed)) 58 | } 59 | BriefMaps.getOrPutSet(currentCache, curQuery).add(line.get(columnName)) 60 | } 61 | cache.put(queryType, currentCache) 62 | } 63 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/internals/IntScalar.xtend: -------------------------------------------------------------------------------- 1 | package blang.types.internals 2 | 3 | import blang.mcmc.Samplers 4 | import blang.core.WritableIntVar 5 | import blang.mcmc.IntSliceSampler 6 | 7 | /** A latent integer random variable. 
*/ 8 | @Samplers(IntSliceSampler) 9 | class IntScalar implements WritableIntVar { 10 | 11 | var int value 12 | 13 | new(int value) { this.value = value } 14 | 15 | override int intValue() { 16 | return value 17 | } 18 | 19 | override void set(int newValue) { 20 | this.value = newValue 21 | } 22 | 23 | override String toString() { 24 | return Integer.toString(value) 25 | } 26 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/internals/InvalidParameter.java: -------------------------------------------------------------------------------- 1 | package blang.types.internals; 2 | 3 | /** Exception with a single shared instance and no stack trace capture (fillInStackTrace returns this without recording anything). */ public class InvalidParameter extends RuntimeException 4 | { 5 | private static final long serialVersionUID = 1L; 6 | 7 | public static final InvalidParameter instance = new InvalidParameter(); 8 | 9 | private InvalidParameter() 10 | { 11 | super("Invalid parameter. Assigned Double.NEGATIVE_INFINITY to that factor."); 12 | } 13 | 14 | @Override 15 | public synchronized Throwable fillInStackTrace() 16 | { 17 | return this; 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/blang/types/internals/LatentFactoryAsParser.xtend: -------------------------------------------------------------------------------- 1 | package blang.types.internals 2 | 3 | import org.eclipse.xtend.lib.annotations.Data 4 | import java.util.function.Supplier 5 | import static extension blang.io.NA.isNA 6 | import com.rits.cloning.Immutable 7 | 8 | /** Parser that creates a fresh value from the supplier for NA entries and rejects every other entry. */ @Data 9 | @Immutable 10 | class LatentFactoryAsParser implements Parser { 11 | 12 | val Supplier supplier 13 | 14 | /** Returns supplier.get for an NA entry; throws a RuntimeException naming the offending entry otherwise. */ override parse(String string) { 15 | if (string.isNA) { 16 | return supplier.get 17 | } else { 18 | throw new RuntimeException("Only NA entries can be turned into latent variables, but found: " + string) 19 | } 20 | } 21 | 22 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/internals/Parser.xtend: 
-------------------------------------------------------------------------------- 1 | package blang.types.internals 2 | 3 | 4 | /** 5 | * Utility for HashPlate and HashPlated. 6 | */ 7 | @FunctionalInterface 8 | interface Parser { 9 | def T parse(String string) 10 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/internals/PlatedSlice.xtend: -------------------------------------------------------------------------------- 1 | package blang.types.internals 2 | 3 | import blang.types.Plated 4 | import blang.runtime.internals.objectgraph.SkipDependency 5 | import java.util.Map 6 | import java.util.LinkedHashMap 7 | 8 | /** 9 | * Variables loaded through this class will be inserted into 10 | * the parent(s) as well. Not vice versa. 11 | * 12 | * Dependency analysis for instance of PlatedSlice 13 | * will only point to the entries in the 14 | * slice, as expected. 15 | */ 16 | class PlatedSlice implements Plated { 17 | 18 | /** Entries loaded through this slice, keyed by the query extended with the slice indices. */ val Map variables = new LinkedHashMap 19 | 20 | val Query sliceIndices 21 | 22 | @SkipDependency(isMutable = false) // Otherwise dependency would be too large 23 | val Plated parent 24 | 25 | /** Adds the slice indices to the query, then returns the cached value or fetches it from the parent and caches it. NOTE(review): the slice indices are added to the caller's query object in place (query.indices.addAll) - confirm callers never reuse that query or use it as a key elsewhere. */ override T get(Query query) { 26 | query.indices.addAll(sliceIndices.indices) 27 | if (variables.containsKey(query)) { 28 | return variables.get(query) 29 | } 30 | val T result = parent.get(query.indices) 31 | variables.put(query, result) 32 | return result 33 | } 34 | 35 | /** Only the entries loaded through this slice. */ override entries() { 36 | return variables.entrySet 37 | } 38 | 39 | new(Plated parent, Query sliceIndices) { 40 | this.parent = parent 41 | this.sliceIndices = sliceIndices 42 | } 43 | 44 | override String toString() { 45 | HashPlated::toString(this) 46 | } 47 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/internals/Query.xtend: -------------------------------------------------------------------------------- 1 | package blang.types.internals 2 | 3 | import 
org.eclipse.xtend.lib.annotations.Data 4 | import java.util.LinkedHashSet 5 | import blang.types.Index 6 | import blang.types.Plate 7 | import java.util.Set 8 | import com.rits.cloning.Immutable 9 | 10 | /** 11 | * Utility for HashPlate and HashPlated. 12 | */ 13 | @Data // important! this is used in hash tables 14 | @Immutable 15 | class Query { 16 | 17 | val Set> indices 18 | 19 | // While queries can use arbitrary order, for storage we expect a deterministic order 20 | /** Creates a query backed by a LinkedHashSet, so the indices keep the order in which they were given. */ def static Query build(Index ... indices) { 21 | return new Query(new LinkedHashSet(indices)) 22 | } 23 | /** The set of plates appearing in the indices of this query. */ def QueryType type() { 24 | return new QueryType(indices.map[index | index.plate].toSet) 25 | } 26 | /** Value object holding the set of plates a query involves. */ @Data 27 | @Immutable 28 | static class QueryType { 29 | val Set> plates 30 | } 31 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/internals/RealScalar.xtend: -------------------------------------------------------------------------------- 1 | package blang.types.internals 2 | 3 | import blang.mcmc.Samplers 4 | import blang.core.WritableRealVar 5 | import blang.mcmc.RealSliceSampler 6 | 7 | /** A latent random real variable. 
*/ 8 | @Samplers(RealSliceSampler) 9 | class RealScalar implements WritableRealVar { 10 | 11 | var double value = 0.0 12 | 13 | new (double value) { this.value = value } 14 | 15 | override double doubleValue() { 16 | return value 17 | } 18 | 19 | override void set(double newValue) { 20 | this.value = newValue 21 | } 22 | 23 | override String toString() { 24 | return Double.toString(value) 25 | } 26 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/internals/SimpleParser.xtend: -------------------------------------------------------------------------------- 1 | package blang.types.internals 2 | 3 | import org.eclipse.xtend.lib.annotations.Data 4 | import blang.inits.Creator 5 | import com.google.inject.TypeLiteral 6 | import com.rits.cloning.Immutable 7 | 8 | /** Parser that builds a value of type typeArgument from a string using the given Creator. */ @Data 9 | @Immutable 10 | class SimpleParser implements Parser { 11 | val Creator creator 12 | val TypeLiteral typeArgument 13 | /** Parses the string with creator.init; any failure is rethrown as a RuntimeException that names the string and target type and keeps the original exception as its cause. */ override parse(String string) { 14 | try { 15 | return creator.init(typeArgument, blang.inits.parsing.SimpleParser.parse(string)) 16 | } catch (Exception e) { 17 | throw new RuntimeException("Failed to parse " + string + " as " + typeArgument, e) // removed details as creator.errorReport crashes if not yet initialized (i.e. crashed in SimpleParser.parse(..)) + ", details:\n" + creator.errorReport) 18 | } 19 | } 20 | } -------------------------------------------------------------------------------- /src/main/java/blang/types/internals/SimplePlate.xtend: -------------------------------------------------------------------------------- 1 | package blang.types.internals 2 | 3 | import blang.types.Plate 4 | import blang.types.Index 5 | import org.eclipse.xtend.lib.annotations.Accessors 6 | import java.util.Set 7 | import java.util.LinkedHashSet 8 | import com.rits.cloning.Immutable 9 | import java.util.Collection 10 | 11 | /** 12 | * Plate implementation based on an explicit list of indices. 
13 | */ 14 | @Immutable 15 | class SimplePlate implements Plate { 16 | 17 | @Accessors(PUBLIC_GETTER) 18 | val ColumnName name 19 | 20 | val Set> indices 21 | 22 | /** 23 | * Assume a non-jagged array so that parentIndices are ignored. 24 | */ 25 | override Collection> indices(Query parentIndices) { 26 | return indices 27 | } 28 | 29 | /** 30 | * This is not needed for SimplePlates. 31 | */ 32 | override parse(String string) { 33 | throw new UnsupportedOperationException 34 | } 35 | 36 | /** Convenience constructor: wraps the raw name in a ColumnName. */ new(String name, Set keys) { 37 | this(new ColumnName(name), keys) 38 | } 39 | 40 | /** Wraps each key in an Index attached to this plate, keeping the iteration order of keys (LinkedHashSet). */ new(ColumnName name, Set keys) { 41 | this.name = name 42 | this.indices = new LinkedHashSet 43 | for (T key : keys) { 44 | indices.add(new Index(this, key)) 45 | } 46 | } 47 | 48 | override String toString() { 49 | return name.string 50 | } 51 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/DeterminismTest.java: -------------------------------------------------------------------------------- 1 | package blang.validation; 2 | 3 | import java.util.function.Function; 4 | 5 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics; 6 | import org.junit.Assert; 7 | 8 | import com.google.common.collect.Table; 9 | 10 | import bayonet.distributions.Random; 11 | import blang.inits.Arg; 12 | import blang.inits.DefaultValue; 13 | import blang.mcmc.Sampler; 14 | import blang.runtime.SampledModel; 15 | 16 | /** Checks that forward and posterior simulation give identical results when re-run with the same random seed. */ public class DeterminismTest 17 | { 18 | @Arg @DefaultValue("10") 19 | public int nIndependentSamples = 10; 20 | 21 | @Arg @DefaultValue("10") 22 | public int nPosteriorSamplesPerIndep = 10; 23 | 24 | /** For each sampler type of the instance, builds a model restricted to that sampler type and checks determinism first without, then with, posterior sampling. */ public void check(Instance instance) 25 | { 26 | System.out.print("Running DeterminismTest on model " + instance.model.getClass().getSimpleName()); 27 | for (Class currentSamplerType : instance.samplerTypes()) 28 | { 29 | System.out.print(" [" + currentSamplerType.getSimpleName() + "]"); 30 | SampledModel sampledModel = 
instance.restrictedSampledModel(currentSamplerType); 31 | checkDeterministic(sampledModel, instance.testFunctions, false); 32 | checkDeterministic(sampledModel, instance.testFunctions, true); 33 | } 34 | System.out.println(); 35 | } 36 | 37 | /** Runs ExactInvarianceTest.sample twice, each time with a fresh Random(1), and asserts that the two result tables are equal; the failure message prints both tables. */ private void checkDeterministic(SampledModel sampledModel, Function[] testFunctions, boolean usePosterior) 38 | { 39 | 40 | Table, Double> 41 | list1 = ExactInvarianceTest.sample(new Random(1), sampledModel, testFunctions, usePosterior, nIndependentSamples, nPosteriorSamplesPerIndep, new SummaryStatistics()), 42 | list2 = ExactInvarianceTest.sample(new Random(1), sampledModel, testFunctions, usePosterior, nIndependentSamples, nPosteriorSamplesPerIndep, new SummaryStatistics()); 43 | Assert.assertTrue( 44 | "Problem with model " + sampledModel.model.getClass().getSimpleName() + ": " + 45 | (usePosterior ? 46 | "Posterior simulation should be deterministic given a random seed. " 47 | + "Problematic kernel: " + sampledModel.getPosteriorInvariantSamplers().get(0).getClass().getSimpleName() : 48 | "Forward simulation should be deterministic given a random seed.") + "\n" + 49 | list1.toString() + "\nvs\n" + list2.toString(), 50 | list1.equals(list2)); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/main/java/blang/validation/Instance.xtend: -------------------------------------------------------------------------------- 1 | package blang.validation 2 | 3 | import blang.core.Model 4 | import blang.mcmc.Sampler 5 | import blang.mcmc.internals.BuiltSamplers 6 | import blang.mcmc.internals.SamplerBuilder 7 | import blang.mcmc.internals.SamplerBuilderOptions 8 | import blang.runtime.Observations 9 | import blang.runtime.SampledModel 10 | import blang.runtime.internals.objectgraph.GraphAnalysis 11 | import java.util.LinkedHashSet 12 | import java.util.Set 13 | import java.util.function.Function 14 | 15 | /** Bundles a model under test with its test functions and the graph analysis, samplers and SampledModel built from it. */ class Instance { 16 | public val M model 17 | public val Function [] 
testFunctions 18 | public val GraphAnalysis graphAnalysis 19 | public val BuiltSamplers allKernels 20 | public val SampledModel sampledModel 21 | /** Same as the three-argument constructor, with default SamplerBuilderOptions. */ new (M model, Function ... testFunctions) { 22 | this(model, new SamplerBuilderOptions(), testFunctions) 23 | } 24 | /** Analyses the model with empty Observations, builds its samplers with the given options and fails with a RuntimeException when no sampler is produced. */ new (M model, SamplerBuilderOptions samplerOptions, Function ... testFunctions) { 25 | this.model = model 26 | this.testFunctions = testFunctions 27 | this.graphAnalysis = new GraphAnalysis(model, new Observations()) 28 | this.allKernels = SamplerBuilder.build(graphAnalysis, samplerOptions) 29 | if (allKernels.list.isEmpty()) 30 | throw new RuntimeException("No kernels produced by model to be tested") 31 | this.sampledModel = new SampledModel(graphAnalysis, allKernels); 32 | } 33 | /** Distinct classes of the built samplers, in order of first appearance. */ def Set> samplerTypes() { 34 | val Set> samplerTypes = new LinkedHashSet() 35 | for (Sampler sampler : allKernels.list) { 36 | samplerTypes.add(sampler.getClass()) 37 | } 38 | return samplerTypes 39 | } 40 | /** Builds a new SampledModel on the same graph analysis, with sampler options created by SamplerBuilderOptions.startWithOnly(currentSamplerType). */ def SampledModel restrictedSampledModel(Class currentSamplerType) { 41 | val SamplerBuilderOptions options = SamplerBuilderOptions.startWithOnly(currentSamplerType) 42 | val BuiltSamplers currentKernel = SamplerBuilder.build(graphAnalysis, options) 43 | return new SampledModel(graphAnalysis, currentKernel) 44 | } 45 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/NormalizationTest.java: -------------------------------------------------------------------------------- 1 | package blang.validation; 2 | 3 | import org.apache.commons.math3.analysis.UnivariateFunction; 4 | import org.apache.commons.math3.analysis.integration.SimpsonIntegrator; 5 | import org.apache.commons.math3.analysis.integration.UnivariateIntegrator; 6 | import org.junit.Assert; 7 | 8 | import bayonet.math.NumericalUtils; 9 | import blang.core.DistributionAdaptor; 10 | import blang.core.RealDistribution; 11 | import blang.core.RealDistributionAdaptor; 12 | import blang.core.RealVar; 13 | 
import blang.core.UnivariateModel; 14 | 15 | /** Numerically integrates the density of a univariate real model and checks that it integrates to one. */ public class NormalizationTest 16 | { 17 | protected UnivariateIntegrator integrator = new SimpsonIntegrator(); 18 | /** Evaluation budget passed to the integrator on each call. */ protected int maxEval = 100_000_000; 19 | /** Half-width of the first integration interval when no bounds are given. */ protected int initialAutoRadius = 8; 20 | /** Maximum number of integration attempts; the interval is doubled between attempts in automatic mode. */ protected int maxNExpansions = 10; 21 | 22 | /** Automatic mode: integrates over [-initialAutoRadius, initialAutoRadius] and doubles the interval until the integral is close to one or the attempts are exhausted. */ protected void checkNormalization(UnivariateModel distribution) 23 | { 24 | checkNormalization(distribution, Double.NaN, Double.NaN); 25 | } 26 | 27 | /** Integrates exp(logDensity) over [left, right]. A NaN left bound selects automatic mode (both bounds are then replaced and expanded); otherwise a single integration is performed. Fails as soon as an integral exceeds 1 + THRESHOLD, returns when it is close to one, and fails if no attempt reaches one. */ protected void checkNormalization(UnivariateModel distribution, double left, double right) 28 | { 29 | System.out.println("Checking normalization for " + distribution.getClass().getSimpleName()); 30 | 31 | DistributionAdaptor adaptor = new DistributionAdaptor<>(distribution); 32 | RealDistribution realDist = new RealDistributionAdaptor(adaptor); 33 | UnivariateFunction function = x -> Math.exp(realDist.logDensity(x)); 34 | 35 | boolean expand = Double.isNaN(left); 36 | if (expand) 37 | { 38 | left = - initialAutoRadius; 39 | right = initialAutoRadius; 40 | } 41 | 42 | for (int i = 0; i < maxNExpansions; i++) 43 | { 44 | 45 | double integral = integrator.integrate(maxEval, function, left, right); 46 | System.out.println("\tIntegrating from " + left + " -- " + right + " -> " + integral); 47 | if (integral > 1.0 + NumericalUtils.THRESHOLD) 48 | Assert.fail("Normalization greater than one."); 49 | if (NumericalUtils.isClose(1.0, integral, NumericalUtils.THRESHOLD)) 50 | return; 51 | 52 | if (!expand) 53 | break; 54 | 55 | left *= 2.0; 56 | right *= 2.0; 57 | } 58 | Assert.fail("Seems to normalize to less than one."); 59 | 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/main/java/blang/validation/UnbiasnessTest.xtend: -------------------------------------------------------------------------------- 1 | package blang.validation 2 | 3 | import java.util.function.Supplier 4 | import bayonet.distributions.ExhaustiveDebugRandom 5 | 6 | class UnbiasnessTest { 7 | def static double 
expectedZEstimate(Supplier logZEstimator, ExhaustiveDebugRandom exhausiveRand) { 8 | /* Sums exp(logZ) * exhausiveRand.lastProbability over every iteration allowed by exhausiveRand.hasNext, and prints how many iterations were run. */ var expectation = 0.0 9 | var nProgramTraces = 0 10 | while (exhausiveRand.hasNext) { 11 | val logZ = logZEstimator.get 12 | expectation += Math.exp(logZ) * exhausiveRand.lastProbability 13 | nProgramTraces++ 14 | } 15 | println("nProgramTraces = " + nProgramTraces) 16 | return expectation 17 | } 18 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/AutoBoxDeboxTests.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | import java.util.function.Function 4 | import blang.types.internals.IntScalar 5 | 6 | model AutoBoxDeboxTests { 7 | laws { 8 | logf() { 9 | 10 | // Boxing/unboxing 11 | 12 | val Function d2d = Function.identity 13 | val Function i2i = Function.identity 14 | val RealVar rv = latentReal 15 | val IntVar iv = new IntScalar(1) 16 | val List list = #[1,2] 17 | 18 | d2d.apply(rv) // test 1: realvar -> Double 19 | // d2d.apply(iv) // test 2: intvar -> Double // does not work, apparently java unboxing int -> Double not supported 20 | // val Double test = iv // same as above 21 | i2i.apply(iv) // test 3: intvar -> Integer 22 | 23 | Math.log(rv) // test 4: realvar -> double 24 | Math.log(iv) // test 5: intvar -> double 25 | 26 | list.get(iv) // test 6: intvar -> int 27 | 28 | val RealVar v0 = 0.0 // test 7: double -> realvar 29 | 30 | val RealVar v1 = new Double(0.0) // test 8: Double -> realvar 31 | 32 | val IntVar v2 = 0 // test 9: int -> intvar 33 | val RealVar v3 = 0 // test 10: int -> realvar 34 | 35 | val IntVar v4 = new Integer(0) // test 11: Integer -> intvar 36 | val RealVar v5 = new Integer(0) // test 12: Integer -> realvar 37 | 38 | return v0 + v1 + v2 + v3 + v4 + v5 39 | } 40 | } 41 | } -------------------------------------------------------------------------------- 
/src/main/java/blang/validation/internals/fixtures/BadNormal.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | import blang.validation.internals.Helpers 4 | 5 | model BadNormal { 6 | 7 | random RealVar realization 8 | 9 | param RealVar mean, 10 | variance 11 | 12 | laws { 13 | 14 | logf() { 15 | - log(Math.sqrt(2*Math.PI)) 16 | } 17 | 18 | logf(variance) { 19 | - 0.5 * log(variance) 20 | } 21 | 22 | logf(mean, variance, realization) { 23 | Helpers::checkOkToUseDefectiveImplementation 24 | return - (pow((mean - realization), 2)) / variance // intentionally missing 1/2 to make sure our tests catch it 25 | } 26 | 27 | logf(variance) { 28 | if (variance > 0) return 0.0 29 | else return NEGATIVE_INFINITY 30 | } 31 | } 32 | 33 | generate(rand) { 34 | rand.nextGaussian * sqrt(variance) + mean 35 | } 36 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/BadPlate.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model BadPlate { 4 | param Plate bad 5 | 6 | laws { 7 | 8 | } 9 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/CustomAnnealRef.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model CustomAnnealRef { 4 | 5 | random RealVar mu ?: latentReal 6 | random RealVar x ?: fixedReal(10.0) 7 | 8 | laws { 9 | 10 | mu ~ Normal(0.0, 1.0) 11 | 12 | x | mu ~ Normal(mu, 1.0) 13 | } 14 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/CustomAnnealTest.bl: -------------------------------------------------------------------------------- 1 | package 
blang.validation.internals.fixtures 2 | 3 | model CustomAnnealTest { 4 | 5 | random RealVar mu ?: latentReal 6 | random RealVar x ?: fixedReal(10.0) 7 | 8 | laws { 9 | 10 | mu ~ Normal(0.0, 1.0) 11 | 12 | | x, mu, RealVar beta = new AnnealingParameter ~ LogPotential({ 13 | val dist = Normal::distribution(mu, 1.0) 14 | return dist.logDensity(x) * beta 15 | }) 16 | } 17 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/Cyclic.bl: -------------------------------------------------------------------------------- 1 | package blang.testmodel 2 | 3 | //import blang.distributions.Normal 4 | 5 | model Cyclic { 6 | 7 | random RealVar 8 | x ?: latentReal, 9 | y ?: latentReal 10 | 11 | laws { 12 | x | y ~ Normal(y, 1.0) 13 | y | x ~ Normal(x, 1.0) 14 | } 15 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/Diffusion.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model Diffusion { 4 | param Double startPoint ?: 0.5 5 | param Double variance ?: 0.01 6 | random List process ?: latentRealList(10) 7 | param int length ?: process.size 8 | 9 | laws { 10 | process.get(0) | startPoint, variance ~ Normal(startPoint, variance * sqrt(startPoint * (1.0 - startPoint))) 11 | 12 | for (int i : 1 ..< length) { 13 | process.get(i) | variance, RealVar prev = process.get(i - 1) 14 | ~ Normal( 15 | prev, 16 | { 17 | if (prev <= 0.0 || prev >= 1.0) return 1e-5 18 | else variance * sqrt(prev * (1.0 - prev)) 19 | } 20 | ) 21 | } 22 | } 23 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/Doomsday.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model Doomsday { 4 | 
random RealVar z 5 | random RealVar y 6 | param RealVar rate 7 | laws { 8 | z | rate ~ Exponential(rate) 9 | y | z ~ ContinuousUniform(0.0, z) 10 | } 11 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/DynamicNormalMixture.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model DynamicNormalMixture { 4 | 5 | param int nLatentStates 6 | 7 | random List observations 8 | random List states ?: latentIntList(observations.size) 9 | 10 | random DenseSimplex initialDistribution ?: latentSimplex(nLatentStates) 11 | random DenseTransitionMatrix transitionProbabilities ?: latentTransitionMatrix(nLatentStates) 12 | random List means ?: latentRealList(nLatentStates), variances ?: latentRealList(nLatentStates) 13 | 14 | param Matrix concentrations ?: ones(nLatentStates).readOnlyView 15 | 16 | laws { 17 | // Priors on initial and transition probabilities 18 | initialDistribution | concentrations ~ Dirichlet(concentrations) 19 | for (int latentStateIdx : 0 ..< means.size) { 20 | transitionProbabilities.row(latentStateIdx) | concentrations ~ Dirichlet(concentrations) 21 | } 22 | 23 | // Priors on means and variances 24 | for (int latentStateIdx : 0 ..< means.size) { 25 | means.get(latentStateIdx) ~ Normal(0.0, 1.0) 26 | variances.get(latentStateIdx) ~ Gamma(1.0, 1.0) 27 | } 28 | 29 | states | initialDistribution, transitionProbabilities 30 | ~ MarkovChain(initialDistribution, transitionProbabilities) 31 | 32 | // Gaussian emissions 33 | for (int obsIdx : 0 ..< observations.size) { 34 | observations.get(obsIdx) | 35 | means, 36 | variances, 37 | IntVar curIndic = states.get(obsIdx) 38 | ~ Normal(means.get(curIndic), variances.get(curIndic)) 39 | } 40 | } 41 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/Empty.bl: 
-------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model Empty { 4 | 5 | laws {} 6 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/FixedMatrix.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model FixedMatrix { 4 | 5 | random Matrix m ?: fixedVector(#[2.1, 4.2]) 6 | 7 | laws { 8 | m.getRealVar(0,0) ~ Normal(0, 1) 9 | } 10 | 11 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/Functions.xtend: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | import briefj.collections.UnorderedPair 4 | import java.util.List 5 | import java.util.ArrayList 6 | 7 | class Functions { 8 | /** Edges of an N-by-N grid whose node (i, j) is numbered N*i+j: first loop adds, per row i, the pairs (N*i+j, N*i+j+1) plus the wrap-around pair (N*i, N*i+N-1); second loop adds, per column j, the pairs (N*i+j, N*(i+1)+j) plus the wrap-around pair (j, N*(N-1)+j). NOTE(review): for N = 1 the wrap-around pairs are self-pairs and for N = 2 they repeat pairs already added - confirm callers only use N of at least 3 or tolerate this. */ def static List> squareIsingEdges(int N){ 9 | val result = new ArrayList 10 | for (int i : 0 ..< N){ 11 | for (int j : 0 ..< N-1){ 12 | result.add(new UnorderedPair(N*i+j, N*i+j+1)) 13 | } 14 | result.add(new UnorderedPair(N*i,N*i+N-1)) 15 | } 16 | for (int j : 0 ..< N){ 17 | for (int i : 0 ..< N-1){ 18 | result.add(new UnorderedPair(N*i+j, N*(i+1)+j)) 19 | } 20 | result.add(new UnorderedPair(j,N*(N-1)+j)) 21 | } 22 | return result 23 | } 24 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/GenerateTwice.bl: -------------------------------------------------------------------------------- 1 | package blang.testmodels 2 | 3 | model GenerateTwice { 4 | 5 | random RealVar x ?: latentReal 6 | 7 | laws { 8 | x ~ Normal(0.0, 1.0) 9 | x ~ Normal(0.0, 1.0) 10 | } 11 | } -------------------------------------------------------------------------------- 
/src/main/java/blang/validation/internals/fixtures/Growth.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | import blang.validation.internals.Helpers 4 | 5 | model Growth { 6 | param IntVar current ?: latentInt 7 | random IntVar next ?: latentInt 8 | 9 | laws { 10 | logf(current, next) { 11 | if (next == current - 1) return log(1.0/10.0) 12 | else if (next == current) return log(7.0/10.0) 13 | else if (next == current + 1) return log(2.0/10.0) 14 | else return NEGATIVE_INFINITY 15 | } 16 | } 17 | generate (rand) { 18 | val int delta = Generators::categorical(rand, #[1.0/10.0, 7.0/10.0, 2.0/10.0]) - 1 19 | Helpers.checkOkToUseDefectiveImplementation 20 | return /*current + */ delta 21 | } 22 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/HierarchicalModel.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model HierarchicalModel { 4 | param GlobalDataSource data 5 | param Plate rocketTypes 6 | param Plated numberOfLaunches 7 | random Plated failureProbabilities 8 | random Plated numberOfFailures 9 | random RealVar a ?: latentReal, b ?: latentReal 10 | 11 | laws { 12 | a ~ Exponential(1) 13 | b ~ Exponential(1) 14 | for (Index rocketType : rocketTypes.indices) { 15 | failureProbabilities.get(rocketType) | a, b ~ Beta(a, b) 16 | numberOfFailures.get(rocketType) 17 | | RealVar failureProbability = failureProbabilities.get(rocketType), 18 | IntVar numberOfLaunch = numberOfLaunches.get(rocketType) 19 | ~ Binomial(numberOfLaunch, failureProbability) 20 | } 21 | } 22 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/IfElse.bl: -------------------------------------------------------------------------------- 1 | package 
blang.validation.internals.fixtures 2 | 3 | model Loop { 4 | 5 | random RealVar x ?: latentReal 6 | 7 | random RealVar y ?: latentReal 8 | 9 | laws { 10 | 11 | // if (true) { 12 | // x ~ Normal(0, 1) 13 | // } 14 | // 15 | // 16 | // if (false) { 17 | // y ~ Normal(1, 2) 18 | // } else { 19 | // y ~ Exponential(1.0) 20 | // } 21 | 22 | 23 | } 24 | 25 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/IntNaiveMHSampler.java: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures; 2 | 3 | import java.util.List; 4 | 5 | import bayonet.distributions.Random; 6 | import blang.core.Constrained; 7 | import blang.core.LogScaleFactor; 8 | import blang.core.WritableIntVar; 9 | import blang.mcmc.ConnectedFactor; 10 | import blang.mcmc.MHSampler; 11 | import blang.mcmc.SampledVariable; 12 | import blang.mcmc.internals.Callback; 13 | 14 | 15 | 16 | /** 17 | * Warning: not a general purpose move - specialized to SmallHMM test or similar simple binary cases 18 | */ 19 | public class IntNaiveMHSampler extends MHSampler 20 | { 21 | @SampledVariable 22 | WritableIntVar variable; 23 | 24 | @ConnectedFactor 25 | List constrained; 26 | 27 | public static IntNaiveMHSampler build(WritableIntVar variable, List numericFactors) 28 | { 29 | IntNaiveMHSampler result = new IntNaiveMHSampler(); 30 | result.variable = variable; 31 | result.numericFactors = numericFactors; 32 | return result; 33 | } 34 | 35 | @Override 36 | public void propose(Random random, Callback callback) 37 | { 38 | final int oldValue = variable.intValue(); 39 | callback.setProposalLogRatio(0.0); 40 | variable.set(1 - oldValue); 41 | if (!callback.sampleAcceptance()) 42 | variable.set(oldValue); 43 | } 44 | } 45 | -------------------------------------------------------------------------------- 
/src/main/java/blang/validation/internals/fixtures/IntRealizationSquared.java: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures; 2 | 3 | import java.util.function.Function; 4 | 5 | import blang.core.IntVar; 6 | import blang.core.UnivariateModel; 7 | 8 | public class IntRealizationSquared implements Function, Double> 9 | { 10 | 11 | @Override 12 | public Double apply(UnivariateModel t) 13 | { 14 | return Math.pow(t.realization().intValue(), 2.0); 15 | } 16 | 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/Ising.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | import briefj.collections.UnorderedPair 4 | import static blang.validation.internals.fixtures.Functions.squareIsingEdges 5 | 6 | model Ising { 7 | param Double moment ?: 0.0 8 | param Double beta ?: log(1 + sqrt(2.0)) / 2.0 // critical point 9 | param Integer N ?: 5 10 | random List vertices ?: latentIntList(N*N) 11 | 12 | laws { 13 | 14 | // Pairwise potentials 15 | for (UnorderedPair pair : squareIsingEdges(N)) { 16 | | IntVar first = vertices.get(pair.getFirst), 17 | IntVar second = vertices.get(pair.getSecond), 18 | beta 19 | ~ LogPotential( 20 | if ((first < 0 || first > 1 || second < 0 || second > 1)) 21 | return NEGATIVE_INFINITY 22 | else 23 | return beta*(2*first-1)*(2*second-1)) 24 | } 25 | 26 | // Node potentials 27 | for (IntVar vertex : vertices) { 28 | vertex | moment ~ Bernoulli(logistic(-2.0*moment)) 29 | } 30 | } 31 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/LinRegression.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model LinRegression { 4 | 
param GlobalDataSource data 5 | param Plate observationPlate 6 | param Plated x 7 | 8 | random RealVar alpha, beta, sigma 9 | random Plated y 10 | 11 | laws { 12 | alpha ~ Normal(0, 25) 13 | beta ~ Normal(0, 25) 14 | sigma ~ ContinuousUniform(0, 10) 15 | for (Index i : observationPlate.indices) { 16 | y.get(i) | beta, alpha, sigma, RealVar x_i = x.get(i) 17 | ~ Normal(beta * x_i + alpha, sigma * sigma) 18 | } 19 | } 20 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/ListHash.java: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures; 2 | 3 | import java.util.List; 4 | import java.util.function.Function; 5 | 6 | import blang.core.IntVar; 7 | import blang.core.UnivariateModel; 8 | 9 | public class ListHash implements Function>, Double> 10 | { 11 | 12 | @Override 13 | public Double apply(UnivariateModel> t) 14 | { 15 | return hash(t.realization()); 16 | } 17 | 18 | public static double hash(List list) 19 | { 20 | double sum = 0.0; 21 | for (int i = 0; i < list.size(); i++) 22 | sum += (i+1) * Math.pow(list.get(i).intValue(), 2.0); 23 | return sum; 24 | } 25 | 26 | } 27 | -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/MarkovChain.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model MarkovChain { 4 | 5 | param Simplex initialDistribution 6 | param TransitionMatrix transitionProbabilities 7 | 8 | random List chain 9 | 10 | laws { 11 | 12 | // Initial distribution: 13 | chain.get(0) | initialDistribution ~ Categorical(initialDistribution) 14 | // Transitions: 15 | for (int step : 1 ..< chain.size) { 16 | chain.get(step) | 17 | IntVar previous = chain.get(step - 1), 18 | transitionProbabilities 19 | ~ Categorical( 20 | if (previous >= 0 
&& previous < transitionProbabilities.nRows) 21 | transitionProbabilities.row(previous) 22 | else 23 | transitionProbabilities.row(0) 24 | ) 25 | } 26 | } 27 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/MixtureModel.bl: -------------------------------------------------------------------------------- 1 | package blang.examples 2 | 3 | model MixtureModel { 4 | 5 | random List observations 6 | random List clusterIndicators ?: latentIntList(observations.size) 7 | random Simplex pi ?: latentSimplex(2) 8 | random List means ?: latentRealList(2), 9 | variances ?: latentRealList(2) 10 | param Matrix concentration ?: fixedVector(1.0, 1.0) 11 | 12 | laws { 13 | 14 | pi | concentration ~ Dirichlet(concentration) 15 | 16 | // priors on each mixture component mean and variance 17 | for (int mixIdx : 0 ..< means.size) { 18 | means.get(mixIdx) ~ Normal(0.0, 1.0) 19 | variances.get(mixIdx) ~ Gamma(1.0, 1.0) 20 | } 21 | 22 | for (int obsIdx : 0 ..< observations.size) { 23 | // prior over mixture indicators 24 | clusterIndicators.get(obsIdx) | pi ~ Categorical(pi) 25 | // likelihood: 26 | observations.get(obsIdx) | 27 | means, variances, 28 | IntVar curIndic = clusterIndicators.get(obsIdx) 29 | ~ Normal( 30 | means.get(curIndic), 31 | variances.get(curIndic) 32 | ) 33 | } 34 | } 35 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/Multimodal.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | 4 | model Multimodal { 5 | 6 | param RealDistribution dist1 ?: Normal::distribution(0.0, 1.0) 7 | param RealDistribution dist2 ?: Normal::distribution(3.0, 1.0) 8 | 9 | random RealVar x ?: latentReal 10 | 11 | laws { 12 | logf(x, dist1, dist2) { 13 | log(0.5 * exp(dist1.logDensity(x)) + 0.5 * exp(dist2.logDensity(x))) 14 | } 15 | } 
16 | 17 | generate (rand) { 18 | if (Generators.bernoulli(rand, 0.5)) 19 | dist1.sample(rand) 20 | else 21 | dist2.sample(rand) 22 | } 23 | 24 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/NoGen.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model NoGen { 4 | 5 | random RealVar 6 | x1 ?: latentReal, 7 | x2 ?: latentReal 8 | 9 | laws { 10 | 11 | x1 ~ Normal(0,1) 12 | logf(x2) { 13 | 0.0 14 | } 15 | 16 | } 17 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/NormalFieldExamples.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | import blang.types.Precision.Diagonal 4 | 5 | model NormalFieldExamples { 6 | 7 | param Boolean diagonal ?: true 8 | 9 | param Plate plate 10 | 11 | random Plated latents 12 | 13 | laws { 14 | 15 | // Prior: 16 | //hyperParam ~ Exponential(1.0) 17 | latents 18 | | Precision precision = 19 | if (diagonal) diagonalPrecision(1.0, plate) 20 | else simpleBrownian(1.0, plate) 21 | ~ NormalField(precision) 22 | 23 | } 24 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/NotNormalForm.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model NotNormalForm { 4 | 5 | random RealVar x ?: latentReal 6 | random RealVar y ?: latentReal 7 | 8 | laws{ 9 | logf(y) {-y*y } 10 | x | y ~ Normal(y, 1.0) 11 | } 12 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/PCR.bl: -------------------------------------------------------------------------------- 1 | 
package blang.validation.internals.fixtures 2 | 3 | model PCR { 4 | random List states ?: latentIntList(5) 5 | 6 | laws { 7 | for (int i : 1 ..< states.size) { 8 | states.get(i) | IntVar previous = states.get(i - 1) ~ Growth(previous) 9 | } 10 | } 11 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/PlatedMatrixTests.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model PlatedMatrixTests { 4 | 5 | param Plate dims 6 | param Plate replicates 7 | 8 | random PlatedMatrix xs 9 | 10 | 11 | 12 | laws { 13 | for (Index n : replicates.indices) { 14 | 15 | xs.getDenseVector(dims, n) ~ MultivariateNormal(ones(3), identity(3).cholesky) 16 | 17 | } 18 | } 19 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/PoissonAllInOne.bl: -------------------------------------------------------------------------------- 1 | package blang.distributions 2 | 3 | model PoissonAllInOne { 4 | random IntVar realization 5 | param RealVar mean 6 | 7 | laws { 8 | logf(realization, mean) { 9 | if (mean <= 0) return NEGATIVE_INFINITY 10 | if (realization < 0) return NEGATIVE_INFINITY 11 | return realization * log(mean) - mean - factorialLog(realization) 12 | } 13 | } 14 | 15 | generate(rand) { 16 | Generators::poisson(rand, mean) 17 | } 18 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/PoissonNormalField.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | import blang.types.Precision.Diagonal 4 | 5 | model PoissonNormalField { 6 | 7 | param Boolean diagonal ?: true 8 | 9 | param Plate plate 10 | 11 | random Plated latents 12 | random Plated observations 13 | 14 | laws 
{ 15 | 16 | // Prior: 17 | //hyperParam ~ Exponential(1.0) 18 | latents 19 | | Precision precision = 20 | if (diagonal) diagonalPrecision(1.0, plate) 21 | else simpleBrownian(1.0, plate) 22 | ~ NormalField(precision) 23 | 24 | // Likelihood 25 | for (Index index : plate.indices) { 26 | observations.get(index) | RealVar latent = latents.get(index) ~ Poisson({ 27 | val double result = exp(latent) 28 | if (result == 0) 29 | Generators::ZERO_PLUS_EPS 30 | else if (result == Double::POSITIVE_INFINITY) 31 | Double::MAX_VALUE 32 | else 33 | result 34 | }) 35 | } 36 | 37 | } 38 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/RealNaiveMHSampler.java: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures; 2 | 3 | import java.util.List; 4 | 5 | import bayonet.distributions.Random; 6 | import blang.core.Constrained; 7 | import blang.core.LogScaleFactor; 8 | import blang.core.WritableIntVar; 9 | import blang.core.WritableRealVar; 10 | import blang.mcmc.ConnectedFactor; 11 | import blang.mcmc.MHSampler; 12 | import blang.mcmc.SampledVariable; 13 | import blang.mcmc.internals.Callback; 14 | 15 | 16 | 17 | /** 18 | * Warning: not a general purpose move - naive Metropolis-Hastings on a real variable whose proposal is the current value plus one standard Gaussian draw with no tuning (lives in the validation fixtures, so presumably meant for tests only) 19 | */ 20 | public class RealNaiveMHSampler extends MHSampler 21 | { 22 | @SampledVariable 23 | WritableRealVar variable; 24 | 25 | @ConnectedFactor 26 | List constrained; 27 | 28 | public static RealNaiveMHSampler build(WritableRealVar variable, List numericFactors) 29 | { 30 | RealNaiveMHSampler result = new RealNaiveMHSampler(); 31 | result.variable = variable; 32 | result.numericFactors = numericFactors; 33 | return result; 34 | } 35 | // Proposes oldValue + standard Gaussian noise; the increment is symmetric, hence the 0.0 proposal log ratio. The old value is put back when the move is rejected. 36 | @Override 37 | public void propose(Random random, Callback callback) 38 | { 39 | final double oldValue = variable.doubleValue(); 40 | callback.setProposalLogRatio(0.0);
41 | variable.set(oldValue + random.nextGaussian()); 42 | if (!callback.sampleAcceptance()) 43 | variable.set(oldValue); 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/RealRealizationSquared.java: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures; 2 | 3 | import java.util.function.Function; 4 | 5 | import blang.core.RealVar; 6 | import blang.core.UnivariateModel; 7 | 8 | public class RealRealizationSquared implements Function, Double> 9 | { 10 | 11 | @Override 12 | public Double apply(UnivariateModel t) 13 | { 14 | return Math.pow(t.realization().doubleValue(), 2.0); 15 | } 16 | 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/Scalability.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model Scalability { 4 | 5 | param Plate plate 6 | 7 | param Plated variables 8 | 9 | laws { 10 | 11 | for (Index index : plate.indices) { 12 | variables.get(index) ~ Normal(0.0, 1.0) 13 | } 14 | 15 | } 16 | 17 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/Simple.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model Simple { 4 | 5 | random RealVar x ?: latentReal 6 | 7 | laws { 8 | 9 | x ~ Normal(0, 1) 10 | 11 | } 12 | 13 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/SimpleHierarchicalModel.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model 
SimpleHierarchicalModel { 4 | param GlobalDataSource data 5 | param Plate rocketTypes 6 | random Plated numberOfLaunches 7 | random Plated failureProbabilities 8 | random Plated numberOfFailures 9 | random RealVar a ?: latentReal, b ?: latentReal 10 | 11 | laws { 12 | a ~ Exponential(1) 13 | b ~ Exponential(1) 14 | for (Index rocketType : rocketTypes.indices) { 15 | failureProbabilities.get(rocketType) | a, b ~ Beta(a + 0.5, b + 0.5) 16 | numberOfLaunches.get(rocketType) ~ Poisson(2.0) 17 | numberOfFailures.get(rocketType) 18 | | RealVar failureProbability = failureProbabilities.get(rocketType), 19 | IntVar numberOfLaunch = numberOfLaunches.get(rocketType) 20 | ~ Binomial(1+numberOfLaunch, failureProbability) 21 | // we add one here since for testing here we need to generate 22 | // numberOfLaunches, and we use a Poisson here so adding one so that 23 | // we avoid getting zero as argument 24 | } 25 | } 26 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/SmallHMM.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model SmallHMM { 4 | 5 | param TransitionMatrix trMatrix ?: fixedTransitionMatrix(new ExactHMMCalculations.SimpleTwoStates().transitionPrs) 6 | 7 | random List observations 8 | random List latents ?: latentIntList(observations.size) 9 | 10 | laws { 11 | 12 | latents | trMatrix ~ MarkovChain(fixedSimplex(new ExactHMMCalculations.SimpleTwoStates().initialPrs), trMatrix) 13 | 14 | for (int i : 0 ..< latents.size) { 15 | observations.get(i) | trMatrix, IntVar latent = latents.get(i) ~ Categorical({ 16 | if (latent >= 0 && latent < trMatrix.nRows) { 17 | trMatrix.row(latent) 18 | } else { 19 | trMatrix.row(0) 20 | } 21 | }) 22 | } 23 | 24 | } 25 | } -------------------------------------------------------------------------------- 
/src/main/java/blang/validation/internals/fixtures/SometimesNaN.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model SometimesNaN { 4 | random RealVar test ?: latentReal 5 | // The log-density is 0.0 on [0, 1] and NaN (not NEGATIVE_INFINITY) outside of it; generate draws uniformly on [0, 1] 6 | laws { 7 | logf(test) { 8 | if (test < 0.0 || test > 1.0) return Double.NaN 9 | return 0.0 10 | } 11 | } 12 | 13 | generate (rand) { 14 | rand.uniform(0.0, 1.0) 15 | } 16 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/SpikeAndSlab.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | 4 | model SpikeAndSlab { 5 | 6 | random List variables 7 | 8 | param RealVar zeroProbability 9 | param RealDistribution nonZeroLogDensity 10 | 11 | laws { 12 | for (int index : 0 ..< variables.size) { 13 | logf(zeroProbability, nonZeroLogDensity, RealVar variable = variables.get(index)) { 14 | if (zeroProbability < 0.0 || zeroProbability > 1.0) return NEGATIVE_INFINITY 15 | if (variable == 0.0) { 16 | log(zeroProbability) 17 | } else { 18 | log(1.0 - zeroProbability) + nonZeroLogDensity.logDensity(variable) 19 | } 20 | } 21 | logf(SpikedRealVar variable = variables.get(index)) { 22 | if (variable.selected.isBool) return 0.0 23 | else return NEGATIVE_INFINITY 24 | } 25 | variables.get(index) is Constrained 26 | } 27 | 28 | } 29 | // NOTE(review): `selected` is drawn below as Bernoulli(zeroProbability) while logf above assigns log(zeroProbability) to variable == 0.0; the two agree only if selected == 1 encodes the zero case (or zeroProbability == 0.5) - confirm against SpikedRealVar 30 | generate(rand) { 31 | for (SpikedRealVar variable : variables) { 32 | (variable.selected as WritableIntVar).set(Generators::bernoulli(rand, zeroProbability).asInt) 33 | (variable.continuousPart as WritableRealVar).set(nonZeroLogDensity.sample(rand)) 34 | } 35 | } 36 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/SpikedGLM.bl: -------------------------------------------------------------------------------- 1 | package
blang.validation.internals.fixtures 2 | 3 | 4 | 5 | model SpikedGLM { 6 | 7 | param Matrix designMatrix // n by p 8 | random List output // size n 9 | random List coefficients ?: { 10 | val p = designMatrix.nCols 11 | return new ArrayList(p) => [ 12 | for (int i : 0 ..< p) 13 | add(new SpikedRealVar) 14 | ] 15 | } 16 | 17 | random RealVar zeroProbability ?: latentReal 18 | 19 | laws { 20 | 21 | zeroProbability ~ Beta(1,1) 22 | 23 | coefficients | zeroProbability ~ 24 | SpikeAndSlab(zeroProbability, Normal::distribution(0, 1)) 25 | 26 | for (int index : 0 ..< output.size) { 27 | output.get(index) | coefficients, 28 | Matrix predictors = designMatrix.row(index) 29 | ~ Bernoulli(logistic({ 30 | var sum = 0.0 31 | for (i : 0 ..< coefficients.size) 32 | sum += coefficients.get(i).doubleValue * predictors.get(i) 33 | sum 34 | })) 35 | } 36 | 37 | } 38 | 39 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/Unid.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model Unid { 4 | random RealVar p1 ?: latentReal 5 | random RealVar p2 ?: latentReal 6 | param IntVar nTrials ?: 100000 7 | random IntVar nFails ?: nTrials/2 8 | laws { 9 | p1 ~ ContinuousUniform(0, 1) 10 | p2 ~ ContinuousUniform(0, 1) 11 | nFails | nTrials, p1, p2 ~ Binomial(nTrials, p1 * p2) 12 | } 13 | } -------------------------------------------------------------------------------- /src/main/java/blang/validation/internals/fixtures/UnspecifiedParam.bl: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures 2 | 3 | model UnspecifiedParam { 4 | 5 | random RealVar myRandom ?: latentReal 6 | 7 | param RealVar myParam 8 | 9 | laws { 10 | 11 | } 12 | } -------------------------------------------------------------------------------- 
/src/main/java/blang/validation/internals/fixtures/VectorHash.java: -------------------------------------------------------------------------------- 1 | package blang.validation.internals.fixtures; 2 | 3 | import java.util.function.Function; 4 | 5 | import blang.core.UnivariateModel; 6 | import xlinear.Matrix; 7 | 8 | public class VectorHash implements Function, Double> 9 | { 10 | 11 | @Override 12 | public Double apply(UnivariateModel t) 13 | { 14 | return hash(t.realization()); 15 | } 16 | 17 | public static double hash(Matrix m) 18 | { 19 | double sum = 0.0; 20 | for (int i = 0; i < m.nEntries(); i++) 21 | sum += (i+1) * Math.pow(m.get(i), 2.0); 22 | return sum; 23 | } 24 | 25 | } 26 | -------------------------------------------------------------------------------- /src/test/java/blang/TestDiscreteModels.xtend: -------------------------------------------------------------------------------- 1 | package blang 2 | 3 | import org.junit.Test 4 | import blang.validation.internals.fixtures.Ising 5 | import blang.validation.DiscreteMCTest 6 | import blang.runtime.SampledModel 7 | import blang.validation.internals.fixtures.IntNaiveMHSampler 8 | import blang.runtime.internals.objectgraph.GraphAnalysis 9 | import blang.mcmc.internals.SamplerBuilder 10 | import blang.mcmc.internals.SamplerBuilderOptions 11 | import java.util.List 12 | import java.util.ArrayList 13 | 14 | class TestDiscreteModels { 15 | 16 | 17 | @Test 18 | def void isingTests() { 19 | val n = 2 20 | val options = SamplerBuilderOptions::startWithOnly(IntNaiveMHSampler) 21 | val ising = new Ising.Builder().setN(n).build 22 | val graphAnalysis = new GraphAnalysis(ising) 23 | val kernels = SamplerBuilder.build(graphAnalysis, options) 24 | val model = new SampledModel(graphAnalysis, kernels) 25 | val rep = [isingState(it)] 26 | 27 | val test = new DiscreteMCTest(model, rep) 28 | test.checkStateSpaceSize((2 ** (n*n)) as int) 29 | test.checkInvariance 30 | test.checkIrreducibility 31 | } 32 | 33 | def static List 
isingState(SampledModel m) { 34 | return new ArrayList((m.model as Ising).vertices.map[intValue].toList) 35 | } 36 | } -------------------------------------------------------------------------------- /src/test/java/blang/TestDocumentation.xtend: -------------------------------------------------------------------------------- 1 | package blang 2 | 3 | import org.junit.Test 4 | import blang.runtime.internals.doc.contents.BuiltInDistributions 5 | import blang.xdoc.DocElementExtensions 6 | // Constructs BuiltInDistributions with the static flag DocElementExtensions::checkCommentsComplete switched on (presumably making incomplete doc comments fail the construction - confirm in DocElementExtensions) 7 | class TestDocumentation { 8 | @Test 9 | def void checkComplete() { 10 | DocElementExtensions::checkCommentsComplete = true 11 | try { new BuiltInDistributions } 12 | finally { DocElementExtensions::checkCommentsComplete = false } // always reset the static flag, even when the construction throws, so it cannot leak into other tests 13 | } 14 | } -------------------------------------------------------------------------------- /src/test/java/blang/TestExactTest.xtend: -------------------------------------------------------------------------------- 1 | package blang 2 | 3 | import blang.mcmc.internals.SamplerBuilderOptions 4 | import blang.types.StaticUtils 5 | import blang.validation.ExactInvarianceTest 6 | import blang.validation.internals.Helpers 7 | import blang.validation.internals.fixtures.BadNormal 8 | import blang.validation.internals.fixtures.BadRealSliceSampler 9 | import blang.validation.internals.fixtures.Multimodal 10 | import blang.validation.internals.fixtures.RealRealizationSquared 11 | import org.junit.After 12 | import org.junit.Assert 13 | import org.junit.Before 14 | import org.junit.Test 15 | import blang.validation.Instance 16 | import blang.types.internals.RealScalar 17 | 18 | /** 19 | * A test for the exact test, to make sure it catches some common types of errors.
20 | */ 21 | class TestExactTest { 22 | 23 | @SuppressWarnings("unchecked") @Test def void checkBadNormalDetected() { 24 | var ExactInvarianceTest test = new ExactInvarianceTest() 25 | test.add(new Instance( 26 | new BadNormal.Builder().setMean(StaticUtils.fixedReal(0.2)).setVariance(StaticUtils.fixedReal(0.1)).setRealization(new RealScalar(1.0)).build(), 27 | new RealRealizationSquared())) 28 | ensureTestFails(test) 29 | } 30 | 31 | @Test 32 | def void checkBadSliceSamplerDetected() { 33 | var ExactInvarianceTest test = new ExactInvarianceTest() 34 | var SamplerBuilderOptions samplers = SamplerBuilderOptions.startWithOnly(BadRealSliceSampler) 35 | 36 | test.add(new Instance( 37 | new Multimodal.Builder().build, 38 | samplers, 39 | new RealRealizationSquared() 40 | )) 41 | ensureTestFails(test) 42 | } 43 | /** Asserts that `test` holds at least one test and that every one of them fails at the threshold returned by getMainTestPValue (the corrected p-value of the TestSDKDistributions setup). */ 44 | def void ensureTestFails(ExactInvarianceTest test) { 45 | val double referenceFamilyWiseErrorThreshold = getMainTestPValue() 46 | Assert.assertTrue(test.nTests() > 0) 47 | println("Threshold derived from TestSDKDistributions:" + referenceFamilyWiseErrorThreshold) 48 | println("Expecting " + test.nTests() + " failed test: \n" + ExactInvarianceTest::format(test.results)) 49 | Assert.assertEquals(test.nTests(), test.failedTests(referenceFamilyWiseErrorThreshold).size()) // JUnit order is (expected, actual): all registered tests are expected to fail 50 | println 51 | } 52 | 53 | @Before 54 | def void before() { 55 | Helpers.setDefectiveImplementationStatus(true) 56 | } 57 | 58 | @After 59 | def void after() { 60 | Helpers.setDefectiveImplementationStatus(false) 61 | } 62 | 63 | def private double getMainTestPValue() { 64 | var ExactInvarianceTest lazyTest = new ExactInvarianceTest(true) 65 | TestSDKDistributions.setup(lazyTest) 66 | return lazyTest.correctedPValue 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /src/test/java/blang/TestFixedMatrix.xtend: -------------------------------------------------------------------------------- 1 | package blang 2 | 3 | import org.junit.Test 4 |
import blang.validation.internals.fixtures.FixedMatrix 5 | import blang.mcmc.internals.SamplerBuilder 6 | import blang.runtime.internals.objectgraph.GraphAnalysis 7 | import org.junit.Assert 8 | import xlinear.MatrixOperations 9 | import blang.types.StaticUtils 10 | 11 | class TestFixedMatrix { 12 | 13 | 14 | @Test 15 | def void testFixed() { 16 | val model = new FixedMatrix.Builder().build 17 | val built = SamplerBuilder::build(new GraphAnalysis(model)) 18 | Assert::assertTrue(built.list.empty) 19 | } 20 | 21 | @Test 22 | def void testMutable() { 23 | val model = new FixedMatrix.Builder().setM(MatrixOperations::dense(2)).build 24 | val built = SamplerBuilder::build(new GraphAnalysis(model)) 25 | Assert::assertTrue(!built.list.empty) 26 | } 27 | 28 | @Test 29 | def void testRecurse() { 30 | val simplex = StaticUtils::fixedSimplex(0.5, 0.5) 31 | val model = new FixedMatrix.Builder().setM(simplex.row(0)).build 32 | val built = SamplerBuilder::build(new GraphAnalysis(model)) 33 | Assert::assertTrue(built.list.empty) 34 | } 35 | } -------------------------------------------------------------------------------- /src/test/java/blang/TestRunner.java: -------------------------------------------------------------------------------- 1 | package blang; 2 | 3 | import org.junit.Assert; 4 | import org.junit.Test; 5 | 6 | import blang.core.ModelBuilder; 7 | import blang.inits.experiments.Experiment; 8 | import blang.runtime.Runner; 9 | import blang.testmodel.Cyclic; 10 | import blang.testmodels.GenerateTwice; 11 | import blang.validation.internals.Helpers; 12 | import blang.validation.internals.fixtures.Simple; 13 | import blang.validation.internals.fixtures.UnspecifiedParam; 14 | 15 | public class TestRunner 16 | { 17 | @Test 18 | public void checkCyclesDetected() 19 | { 20 | checkDAGViolation(new Cyclic.Builder()); 21 | } 22 | 23 | @Test 24 | public void checkGeneratedTwiceDetected() 25 | { 26 | checkDAGViolation(new GenerateTwice.Builder()); 27 | } 28 | 29 | @Test 30 | public 
void checkSimpleOK() 31 | { 32 | new Runner(new Simple.Builder()).run(); 33 | } 34 | 35 | @Test 36 | public void testMissingParam() 37 | { 38 | // A RealVar as parameter without default ?: provided should prompt a CLI parsing error 39 | // Keep this check to ensure parsing behaviour of RealVar, etc does not get too liberal 40 | Assert.assertEquals(Runner.start(UnspecifiedParam.class.getCanonicalName()), Experiment.CLI_PARSING_ERROR_CODE); 41 | } 42 | 43 | public void checkDAGViolation(ModelBuilder builder) 44 | { 45 | Runner runner = new Runner(builder); 46 | Helpers.assertTypeOfThrownExceptionMatches(() -> runner.run(), new Runner.NotDAG("")); 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/test/java/blang/TestSDKDistributions.xtend: -------------------------------------------------------------------------------- 1 | package blang 2 | 3 | import blang.validation.ExactInvarianceTest 4 | import org.junit.Test 5 | import blang.validation.DeterminismTest 6 | import blang.distributions.Generators 7 | 8 | import blang.validation.internals.fixtures.Examples 9 | 10 | class TestSDKDistributions { 11 | 12 | @Test 13 | def void exactInvarianceTest() { 14 | val oldThreshold = Generators._poissonSwitchToNormalThreshold 15 | Generators._poissonSwitchToNormalThreshold = Examples.largeLambda - 10 16 | test(new ExactInvarianceTest) 17 | Generators._poissonSwitchToNormalThreshold = oldThreshold 18 | } 19 | 20 | def static void test(ExactInvarianceTest test) { 21 | setup(test) 22 | println("Corrected pValue = " + test.correctedPValue) 23 | test.check() 24 | } 25 | 26 | def static void setup(ExactInvarianceTest test) { 27 | test => [ 28 | nPosteriorSamplesPerIndep = 500 // 1000 creates a travis time out 29 | for (instance : new Examples().all) { 30 | test.add(instance) 31 | } 32 | ] 33 | } 34 | 35 | @Test 36 | def void determinismTest() { 37 | new DeterminismTest => [ 38 | for (instance : new Examples().all) { 39 | 
check(instance) 40 | } 41 | ] 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/test/java/blang/TestSDKNormalizations.java: -------------------------------------------------------------------------------- 1 | package blang; 2 | 3 | import org.junit.Assert; 4 | import org.junit.Test; 5 | 6 | import blang.core.IntDistribution; 7 | import blang.distributions.Generators; 8 | import blang.distributions.YuleSimon; 9 | import blang.types.StaticUtils; 10 | import blang.validation.NormalizationTest; 11 | import blang.distributions.BetaNegativeBinomial; 12 | 13 | import blang.validation.internals.fixtures.Examples; 14 | 15 | public class TestSDKNormalizations extends NormalizationTest 16 | { 17 | private Examples examples = new Examples(); 18 | 19 | @Test 20 | public void normal() 21 | { 22 | // check norm from -infty to +infty (by doubling domain of integration) 23 | checkNormalization(examples.normal.model); 24 | } 25 | 26 | @Test 27 | public void beta() 28 | { 29 | // check norm on a close interval 30 | checkNormalization(examples.beta.model, Generators.ZERO_PLUS_EPS, Generators.ONE_MINUS_EPS); 31 | } 32 | 33 | @Test 34 | public void testExponential() 35 | { 36 | // approximate 0, infty interval 37 | checkNormalization(examples.exp.model, 0.0, 10.0); 38 | } 39 | 40 | @Test 41 | public void testGamma() 42 | { 43 | checkNormalization(examples.gamma.model, 0.0, 15.0); 44 | } 45 | 46 | @Test 47 | public void testYuleSimon() 48 | { 49 | IntDistribution distribution = YuleSimon.distribution(StaticUtils.fixedReal(3.5)); 50 | double sum = 0.0; 51 | for (int i = 0; i < 100; i++) 52 | sum += Math.exp(distribution.logDensity(i)); 53 | Assert.assertEquals(1.0, sum, 0.01); 54 | } 55 | 56 | @Test 57 | public void testBNB() 58 | { 59 | IntDistribution distribution = BetaNegativeBinomial.distribution(StaticUtils.fixedReal(3.5), StaticUtils.fixedReal(1.2), StaticUtils.fixedReal(3.0)); 60 | double sum = 0.0; 61 | for (int i = 0; i < 
1000; i++) 62 | sum += Math.exp(distribution.logDensity(i)); 63 | Assert.assertEquals(1.0, sum, 0.01); 64 | } 65 | 66 | } 67 | -------------------------------------------------------------------------------- /src/test/java/blang/TestSparseDirichletAndBetaWarnings.xtend: -------------------------------------------------------------------------------- 1 | package blang 2 | 3 | import blang.validation.ExactInvarianceTest 4 | import org.junit.Test 5 | import bayonet.distributions.Random 6 | import blang.distributions.Dirichlet 7 | import blang.validation.Instance 8 | import xlinear.MatrixOperations 9 | import blang.types.StaticUtils 10 | import org.junit.Assert 11 | import blang.distributions.internals.Helpers 12 | import blang.distributions.Beta 13 | 14 | class TestSparseDirichletAndBetaWarnings { 15 | 16 | @Test 17 | def void testSimpleDiri() 18 | { 19 | Helpers.warnedUnstableConcentration = false 20 | new ExactInvarianceTest => [ 21 | random = new Random(14) 22 | nPosteriorSamplesPerIndep = 1 //500 23 | val instance = new Instance( 24 | new Dirichlet.Builder() 25 | .setConcentrations(MatrixOperations::denseCopy(#[0.1, 0.1])) 26 | .setRealization(StaticUtils::latentSimplex(2)).build, 27 | [getRealization.get(0)]) 28 | add(instance) 29 | ] //.check(0.05) After changing 1->500 above this would crash (p value is 0.036631052707119305 on commit of Nov 10 4pm). 
See Issue #62 30 | Assert.assertTrue(Helpers.warnedUnstableConcentration) 31 | } 32 | 33 | @Test 34 | def void testBeta() 35 | { 36 | Helpers.warnedUnstableConcentration = false 37 | new ExactInvarianceTest => [ 38 | nPosteriorSamplesPerIndep = 1 //500 39 | val instance = new Instance( 40 | new Beta.Builder() 41 | .setAlpha(StaticUtils::fixedReal(0.1)) 42 | .setBeta(StaticUtils::fixedReal(0.1)) 43 | .setRealization(StaticUtils::latentReal).build, 44 | [getRealization.doubleValue] 45 | ) 46 | add(instance) 47 | ] //.check(0.05) After changing 1->500 above this would crash (p value is 0.02330809853328797 on commit of Nov 10 4pm). See Issue #62 48 | Assert.assertTrue(Helpers.warnedUnstableConcentration) 49 | } 50 | } -------------------------------------------------------------------------------- /src/test/java/blang/TestSyntax.java: -------------------------------------------------------------------------------- 1 | package blang; 2 | 3 | import org.junit.Test; 4 | 5 | import blang.core.LogScaleFactor; 6 | import blang.core.ModelBuilder; 7 | import blang.validation.internals.fixtures.AutoBoxDeboxTests; 8 | import blang.validation.internals.fixtures.Operations; 9 | 10 | public class TestSyntax 11 | { 12 | @Test 13 | public void boxDebox() 14 | { 15 | test(new AutoBoxDeboxTests.Builder()); 16 | } 17 | 18 | @Test 19 | public void operations() 20 | { 21 | test(new Operations.Builder()); 22 | } 23 | 24 | private static void test(ModelBuilder builder) 25 | { 26 | ((LogScaleFactor) builder.build().components().iterator().next()).logDensity(); 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/test/java/blang/runtime/TestSampledModel.java: -------------------------------------------------------------------------------- 1 | package blang.runtime; 2 | 3 | import org.junit.Assert; 4 | import org.junit.Test; 5 | 6 | import blang.validation.internals.Helpers; 7 | import blang.validation.internals.fixtures.PCR; 8 | 9 | public 
class TestSampledModel 10 | { 11 | 12 | @Test 13 | public void test() 14 | { 15 | Helpers.setDefectiveImplementationStatus(true); 16 | Runner runner = new Runner(new PCR.Builder()); 17 | try 18 | { 19 | runner.run(); 20 | } 21 | catch (RuntimeException re) 22 | { 23 | Assert.assertEquals(re.getMessage(), SampledModel.INVALID_LOG_RATIO); 24 | return; 25 | } 26 | Assert.fail(); 27 | } 28 | 29 | } 30 | -------------------------------------------------------------------------------- /src/test/resource/data.csv: -------------------------------------------------------------------------------- 1 | sample,observations 2 | 0,2.4343 3 | 1,-23.45 --------------------------------------------------------------------------------