├── .github └── workflows │ └── maven.yml ├── .travis.yml ├── README.md ├── pom.xml ├── src ├── main │ ├── java │ │ └── cc │ │ │ └── mallet │ │ │ ├── classify │ │ │ ├── KLDivergenceClassifier.java │ │ │ ├── KLDivergenceClassifierMultiCorpus.java │ │ │ └── evaluate │ │ │ │ └── EnhancedConfusionMatrix.java │ │ │ ├── configuration │ │ │ ├── ConfigFactory.java │ │ │ ├── Configuration.java │ │ │ ├── LDACommandLineParser.java │ │ │ ├── LDAConfiguration.java │ │ │ ├── LDARemoteConfiguration.java │ │ │ ├── LDATrainTestCommandLineParser.java │ │ │ ├── LDATrainTestConfiguration.java │ │ │ ├── ModelFactory.java │ │ │ ├── ParsedLDAConfiguration.java │ │ │ ├── ParsedLDATrainTestConfiguration.java │ │ │ ├── ParsedRemoteLDAConfiguration.java │ │ │ ├── SimpleLDAConfiguration.java │ │ │ └── SubConfig.java │ │ │ ├── pipe │ │ │ ├── KeepConnectorPunctuationNumericAlsoTokenizer.java │ │ │ ├── KeepConnectorPunctuationTokenizerLarge.java │ │ │ ├── NumericAlsoTokenizer.java │ │ │ ├── RawTokenizer.java │ │ │ ├── SimpleTokenizerLarge.java │ │ │ ├── TfIdfPipe.java │ │ │ └── TokenSequencePredicateMatcher.java │ │ │ ├── similarity │ │ │ ├── BM25Distance.java │ │ │ ├── BhattacharyyaDistance.java │ │ │ ├── CanberraDistance.java │ │ │ ├── ChebychevDistance.java │ │ │ ├── CoOccurrenceLDALikelihoodDistance.java │ │ │ ├── ContinousJaccardDistance.java │ │ │ ├── CorpusStatistics.java │ │ │ ├── CosineDistance.java │ │ │ ├── Distance.java │ │ │ ├── DocumentDistancer.java │ │ │ ├── EuclidianDistance.java │ │ │ ├── HellingerDistance.java │ │ │ ├── InstanceDistance.java │ │ │ ├── InstanceDistanceWrapper.java │ │ │ ├── JaccardDistance.java │ │ │ ├── JensenShannonDistance.java │ │ │ ├── KLDistance.java │ │ │ ├── KolmogorovSmirnovDistance.java │ │ │ ├── LDADistancer.java │ │ │ ├── LDALikelihoodDistance.java │ │ │ ├── LengthLDALikelihoodDistance.java │ │ │ ├── LikelihoodDistance.java │ │ │ ├── LongQueryLDALikelihoodDistance.java │ │ │ ├── ManhattanDistance.java │ │ │ ├── MultiInstanceDistanceWrapper.java │ │ │ ├── StatObj.java │ │ │ ├── StatisticalDistance.java │ │ │ ├── StreamCorpusStatistics.java │ │ │ ├── SymmetricKLDistance.java │ │ │ ├── TDistance.java │ │ │ ├── TfIdfVectorizer.java │ │ │ ├── TokenFrequencyVectorizer.java │ │ │ ├── TokenIndexVectorizer.java │ │ │ ├── TokenOccurenceVectorizer.java │ │ │ ├── TrainedDistance.java │ │ │ ├── UberDistance.java │ │ │ └── Vectorizer.java │ │ │ ├── topics │ │ │ ├── ADLDA.java │ │ │ ├── AbortableSampler.java │ │ │ ├── CollapsedLightLDA.java │ │ │ ├── DocTopicTokenFreqTable.java │ │ │ ├── EfficientUncollapsedParallelLDA.java │ │ │ ├── HDPSamplerWithPhi.java │ │ │ ├── HashLDA.java │ │ │ ├── LDADocSamplingContext.java │ │ │ ├── LDADocSamplingResult.java │ │ │ ├── LDADocSamplingResultDense.java │ │ │ ├── LDADocSamplingResultSparse.java │ │ │ ├── LDADocSamplingResultSparseSimple.java │ │ │ ├── LDAGibbsSampler.java │ │ │ ├── LDASamplerContinuable.java │ │ │ ├── LDASamplerInitiable.java │ │ │ ├── LDASamplerWithCallback.java │ │ │ ├── LDASamplerWithDocumentPriors.java │ │ │ ├── LDASamplerWithPhi.java │ │ │ ├── LDASamplerWithTopicPriors.java │ │ │ ├── LightPCLDA.java │ │ │ ├── LightPCLDAtypeTopicProposal.java │ │ │ ├── LogState.java │ │ │ ├── MarginalProbEstimatorPlain.java │ │ │ ├── ModifiedSimpleLDA.java │ │ │ ├── MyWorkerRunnable.java │ │ │ ├── NZVSSpaliasUncollapsedParallelLDA.java │ │ │ ├── PoissonPolyaUrnHDPLDA.java │ │ │ ├── PoissonPolyaUrnHDPLDAInfiniteTopics.java │ │ │ ├── PoissonPolyaUrnHLDA.java │ │ │ ├── PolyaUrnSpaliasLDA.java │ │ │ ├── PolyaUrnSpaliasLDAWithPriors.java │ │ │ ├── 
SerialCollapsedLDA.java │ │ │ ├── SpaliasUncollapsedParallelLDA.java │ │ │ ├── SpaliasUncollapsedParallelWithPriors.java │ │ │ ├── SparseHDPSampler.java │ │ │ ├── SparseUncollapsedSampler.java │ │ │ ├── TableBuilderFactory.java │ │ │ ├── TopicModelDiagnosticsPlain.java │ │ │ ├── TypeTopicParallelTableBuilder.java │ │ │ ├── UncollapsedLDADocSamplingContext.java │ │ │ ├── UncollapsedParallelLDA.java │ │ │ ├── WalkerAliasTableBuildResult.java │ │ │ ├── randomscan │ │ │ │ ├── document │ │ │ │ │ ├── AdaptiveBatchBuilder.java │ │ │ │ │ ├── BatchBuilderFactory.java │ │ │ │ │ ├── DocumentBatchBuilder.java │ │ │ │ │ ├── EvenSplitBatchBuilder.java │ │ │ │ │ ├── FixedSplitBatchBuilder.java │ │ │ │ │ └── PercentageBatchBuilder.java │ │ │ │ └── topic │ │ │ │ │ ├── AllWordsTopicIndexBuilder.java │ │ │ │ │ ├── DeltaNTopicIndexBuilder.java │ │ │ │ │ ├── EvenSplitTopicBatchBuilder.java │ │ │ │ │ ├── MandelbrotTopicIndexBuilder.java │ │ │ │ │ ├── MetaTopicIndexBuilder.java │ │ │ │ │ ├── MixedMandelbrotDeltaNTopicIndexBuilder.java │ │ │ │ │ ├── PercentageTopicBatchBuilder.java │ │ │ │ │ ├── ProportionalTopicIndexBuilder.java │ │ │ │ │ ├── TopWordsRandomFractionTopicIndexBuilder.java │ │ │ │ │ ├── TopicBatchBuilder.java │ │ │ │ │ ├── TopicBatchBuilderFactory.java │ │ │ │ │ ├── TopicIndexBuilder.java │ │ │ │ │ └── TopicIndexBuilderFactory.java │ │ │ └── tui │ │ │ │ ├── BM25Search.java │ │ │ │ ├── DocumentSimilarity.java │ │ │ │ ├── IterationListener.java │ │ │ │ ├── KLClassifier.java │ │ │ │ ├── LDASimilarity.java │ │ │ │ ├── LoglikelihoodCalculator.java │ │ │ │ ├── ParallelLDA.java │ │ │ │ ├── ParallelLDAInference.java │ │ │ │ ├── ParallelLDATrainTest.java │ │ │ │ ├── SvmLightExporter.java │ │ │ │ ├── TopicMassExperiment.java │ │ │ │ └── XValidationCreator.java │ │ │ ├── types │ │ │ ├── BinomialSampler.java │ │ │ ├── ConditionalDirichlet.java │ │ │ ├── DefaultSparseDirichletSamplerBuilder.java │ │ │ ├── MarsagliaSparseDirichlet.java │ │ │ ├── ParallelDirichlet.java │ │ │ ├── PoissonFixedCoeffSampler.java │ │ │ ├── PolyaUrnDirichlet.java │ │ │ ├── PolyaUrnDirichletFixedCoeffPoisson.java │ │ │ ├── PolyaUrnDirichletSamplerBuilder.java │ │ │ ├── PolyaUrnFixedCoeffPoissonDirichletSamplerBuilder.java │ │ │ ├── SimpleMultinomial.java │ │ │ ├── SparseDirichlet.java │ │ │ ├── SparseDirichletSamplerBuilder.java │ │ │ ├── StandardArgsDirichletBuilder.java │ │ │ ├── VSDirichlet.java │ │ │ ├── VSResult.java │ │ │ ├── VariableSelectionDirichlet.java │ │ │ └── VariableSelectionResult.java │ │ │ └── util │ │ │ ├── ArrayStringUtils.java │ │ │ ├── EclipseDetector.java │ │ │ ├── FileLoggingUtils.java │ │ │ ├── GentleAliasMethod.java │ │ │ ├── IndexSampler.java │ │ │ ├── IndexSorter.java │ │ │ ├── IntArraySortUtils.java │ │ │ ├── LDADatasetDirectoryLoadingUtils.java │ │ │ ├── LDADatasetFileLoadingUtils.java │ │ │ ├── LDADatasetStreamLoadingUtils.java │ │ │ ├── LDADatasetStringLoadingUtils.java │ │ │ ├── LDALoggingUtils.java │ │ │ ├── LDANullLogger.java │ │ │ ├── LDAThreadFactory.java │ │ │ ├── LDAUtils.java │ │ │ ├── LoggingUtils.java │ │ │ ├── MalletTopicIndicatorLogger.java │ │ │ ├── MoreFileUtils.java │ │ │ ├── NullOutputStream.java │ │ │ ├── NullPrintWriter.java │ │ │ ├── OptimizedGentleAliasMethod.java │ │ │ ├── OptimizedGentleAliasMethodDynamicSize.java │ │ │ ├── ParallelRandoms.java │ │ │ ├── PerplexityDatasetBuilder.java │ │ │ ├── ReMappedAliasTable.java │ │ │ ├── SparsityTools.java │ │ │ ├── StandardTopicIndicatorLogger.java │ │ │ ├── Stats.java │ │ │ ├── StringClassArrayIterator.java │ │ │ ├── SystematicSampling.java │ 
│ │ ├── TeeStream.java │ │ │ ├── Timer.java │ │ │ ├── Timing.java │ │ │ ├── TopicIndicatorLogger.java │ │ │ ├── WalkerAliasTable.java │ │ │ ├── WithoutReplacementSampler.java │ │ │ ├── XORShiftRandom.java │ │ │ └── resources │ │ │ └── logging.properties │ └── resources │ │ ├── configuration │ │ ├── Clinton.cfg │ │ ├── Configuration-README.md │ │ ├── Configuration-README.txt │ │ ├── GlobalPLDAConfig.cfg │ │ ├── KLClassification.cfg │ │ ├── LDASimilarity.cfg │ │ ├── PLDAConfig.cfg │ │ ├── PLDAConfigDeltaN.cfg │ │ ├── PoissonPolyaUrnHLDA.cfg │ │ ├── PolyaUrnConfig.cfg │ │ ├── RSConfig.cfg │ │ ├── SmokeTestConfig.cfg │ │ ├── SpaliasMainRemote.conf │ │ ├── SpaliasWorkerRemote.conf │ │ ├── TestConfig.cfg │ │ ├── TestPriorsConfig.cfg │ │ ├── TopicMassConfig.cfg │ │ ├── UCBaseLineConfig.cfg │ │ ├── UnitTestConfig.cfg │ │ ├── UnitTestConfigWithCommaDesc.cfg │ │ ├── VSConfig.cfg │ │ └── minimal.cfg │ │ ├── datasets │ │ ├── 20newsgroups.txt │ │ ├── README.txt │ │ ├── SmallTexts.txt │ │ ├── ap.txt │ │ ├── cgcbib.txt │ │ ├── enron.txt │ │ ├── nips.txt │ │ ├── small.txt │ │ ├── small.txt.gz │ │ ├── small.txt.zip │ │ ├── special_chars.txt │ │ └── tfidf-samples.txt │ │ └── topic_priors.txt └── test │ ├── java │ └── cc │ │ └── mallet │ │ ├── configuration │ │ ├── ConfigTest.java │ │ └── SmokeTest.java │ │ ├── misc │ │ ├── RandomTesting.java │ │ └── SystematicSamplingTest.java │ │ ├── pipe │ │ └── TfIdfPipeTest.java │ │ ├── similarity │ │ ├── CorpusStatisticsTest.java │ │ ├── CosineDistanceTest.java │ │ ├── LDALikelihoodTest.java │ │ ├── LikelihoodDistanceTest.java │ │ └── SimilarityTest.java │ │ ├── topics │ │ ├── DeltaWritingTest.java │ │ ├── DocTopicTokenFreqTableTest.java │ │ ├── DocumentProposalTest.java │ │ ├── LightXLDATest.java │ │ ├── LogLikelihoodTest.java │ │ ├── MarginalProbEstimatorPlainTest.java │ │ ├── ModifiedSimpleLDATest.java │ │ ├── ParanoidCollapsedLightLDA.java │ │ ├── ParanoidLightPCLDAtypeTopicProposal.java │ │ ├── ParanoidPoissonPolyaUrnHDP.java │ │ ├── ParanoidSpaliasUncollapsedLDA.java │ │ ├── ParanoidTest.java │ │ ├── ParanoidUncollapsedParallelLDA.java │ │ ├── ParanoidVSSpaliasUncollapsedLDA.java │ │ ├── PoissonPolyaUrnHDPLDATest.java │ │ ├── PoissonPolyaUrnTest.java │ │ ├── PolyaUrnSpaliasTest.java │ │ ├── PriorsTest.java │ │ ├── ReadWriteTest.java │ │ ├── SpaliasUncollapsedTest.java │ │ ├── SpaliasUncollapsedTestPhiPriors.java │ │ ├── SparsityToolsTest.java │ │ ├── TestBetweenProcessInitialization.java │ │ ├── TestInitialization.java │ │ └── tui │ │ │ └── LoglikelihoodCalculatorTest.java │ │ ├── types │ │ ├── BinomialSamplerTest.java │ │ ├── CondDirichletDraw.java │ │ ├── PoissonFixedCoeffSamplerTest.java │ │ ├── RawTokenizerTest.java │ │ ├── SamplerTest.java │ │ ├── SimpleTokenizerLargeTest.java │ │ ├── SparseDirichletDrawParameterizedTest.java │ │ ├── SparseDirichletDrawTest.java │ │ ├── TestSimpleMultinomial.java │ │ ├── TopicCalculationTests.java │ │ └── VSDirichletTest.java │ │ ├── util │ │ ├── LDAUtilsTest.java │ │ └── LogginUtilsTest.java │ │ └── utils │ │ ├── BatchBuilderTest.java │ │ ├── IndexSorterTest.java │ │ ├── InsertionSortTest.java │ │ ├── MultinomialSampler.java │ │ ├── TestPerplexityDatasetBuilder.java │ │ ├── TestUtils.java │ │ └── WalkerAliasTableTest.java │ └── resources │ ├── document_priors.txt │ ├── max_doc_buf-2.cfg │ ├── max_doc_buf-small.cfg │ ├── max_doc_buf.cfg │ ├── nips_document_priors.txt │ ├── nips_topic_priors.txt │ ├── special_chars.cfg │ ├── topic_priors.txt │ └── topic_priors_SmallTexts.txt ├── stoplist-20ng-large.txt ├── stoplist-empty.txt 
└── stoplist.txt /.github/workflows/maven.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a Java project with Maven 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/building-and-testing-java-with-maven 3 | 4 | name: Java CI with Maven 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up JDK 8 20 | uses: actions/setup-java@v2 21 | with: 22 | java-version: '8' 23 | distribution: 'adopt' 24 | - name: Build with Maven 25 | run: mvn -B package --file pom.xml 26 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # Travis CI file for java projects 2 | # See: https://docs.travis-ci.com/user/languages/java 3 | 4 | language: java 5 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/configuration/ConfigFactory.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.configuration; 2 | 3 | import org.apache.commons.configuration.ConfigurationException; 4 | 5 | public class ConfigFactory { 6 | protected static Configuration mainConfig = null; 7 | public static Configuration getMainConfiguration(LDACommandLineParser cp) throws ConfigurationException { 8 | if( mainConfig == null ) { 9 | mainConfig = new ParsedLDAConfiguration(cp); 10 | } 11 | 12 | return mainConfig; 13 | } 14 | 15 | public static Configuration getMainRemoteConfiguration(LDACommandLineParser cp) throws ConfigurationException { 16 | if( mainConfig == null ) { 17 | mainConfig = new ParsedRemoteLDAConfiguration(cp); 18 | } 19 | 20 | return mainConfig; 21 | } 22 | 23 | public static Configuration getTrainTestConfiguration(LDACommandLineParser cp) throws ConfigurationException { 24 | if( mainConfig == null ) { 25 | mainConfig = new ParsedLDATrainTestConfiguration(cp); 26 | } 27 | 28 | return mainConfig; 29 | } 30 | 31 | public static Configuration getMainConfiguration() { 32 | return mainConfig; 33 | } 34 | 35 | public static Configuration setMainConfiguration(Configuration conf) { 36 | return mainConfig = conf; 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/configuration/Configuration.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.configuration; 2 | 3 | public interface Configuration extends org.apache.commons.configuration.Configuration{ 4 | public String [] getSubConfigs(); 5 | public String getActiveSubConfig(); 6 | public String whereAmI(); 7 | public String [] getStringArrayProperty(String key); 8 | public int [] getIntArrayProperty(String key,int [] defaultValues); 9 | public String getStringProperty(String key); 10 | Object getConfProperty(String key); 11 | void activateSubconfig(String subConfName); 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/configuration/LDARemoteConfiguration.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.configuration; 2 | 3 | 4 | public interface LDARemoteConfiguration extends LDAConfiguration { 5 | 6 | public static final 
int REMOTE_PORT_DEFAULT = 5150; 7 | 8 | public String[] getRemoteWorkerMachines(); 9 | 10 | public String getRemoteMaster(); 11 | 12 | public int[] getRemoteWorkerCores(); 13 | 14 | public int[] getRemoteWorkerPorts(int defaultValue); 15 | 16 | public int getRemoteWorkerPort(int defaultValue); 17 | 18 | public String getAkkaMasterConfig(); 19 | 20 | public String getAkkaWorkerConfig(); 21 | 22 | public Boolean getSendPartials(); 23 | 24 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/configuration/LDATrainTestCommandLineParser.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.configuration; 2 | 3 | import org.apache.commons.cli.ParseException; 4 | 5 | public class LDATrainTestCommandLineParser extends LDACommandLineParser { 6 | 7 | public LDATrainTestCommandLineParser(String [] args) throws ParseException { 8 | super.addOptions(); 9 | addOptions(); 10 | parsedCommandLine = parseCommandLine(args); 11 | 12 | if( parsedCommandLine.hasOption( "cm" ) ) { 13 | comment = parsedCommandLine.getOptionValue( "comment" ); 14 | } 15 | if( parsedCommandLine.hasOption( "cf" ) ) { 16 | configFn = parsedCommandLine.getOptionValue( "run_cfg" ); 17 | } 18 | } 19 | 20 | protected void addOptions() { 21 | options.addOption( "ts", "testset", true, "a filename to a file containing which ids in the dataset to act as test ids " ); 22 | } 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/configuration/LDATrainTestConfiguration.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.configuration; 2 | 3 | public interface LDATrainTestConfiguration extends LDAConfiguration { 4 | 5 | String getTextDatasetTestIdsFilename(); 6 | 7 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/configuration/ParsedLDATrainTestConfiguration.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.configuration; 2 | 3 | import org.apache.commons.configuration.ConfigurationException; 4 | 5 | 6 | public class ParsedLDATrainTestConfiguration extends ParsedLDAConfiguration implements Configuration, LDATrainTestConfiguration { 7 | 8 | private static final long serialVersionUID = 1L; 9 | 10 | public ParsedLDATrainTestConfiguration(LDACommandLineParser cp) throws ConfigurationException { 11 | super(cp); 12 | } 13 | 14 | public ParsedLDATrainTestConfiguration(String path) throws ConfigurationException { 15 | super(path); 16 | } 17 | 18 | /* (non-Javadoc) 19 | * @see cc.mallet.configuration.LDATrainTestConfiguration#getTextDatasetTestFilename() 20 | */ 21 | @Override 22 | public String getTextDatasetTestIdsFilename() { 23 | return getStringProperty("textdataset_testids"); 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/pipe/RawTokenizer.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.pipe; 2 | 3 | import java.io.File; 4 | import java.util.ArrayList; 5 | import java.util.HashSet; 6 | 7 | import cc.mallet.types.Instance; 8 | 9 | /** 10 | * This tokenizer tries to do as little as possible with the input text 11 | * to facilitate doing text-preprocessing outside of LDA 12 | * 13 | * @author Leif Jonsson 14 | * 15 | */ 16 | public class RawTokenizer extends 
SimpleTokenizer { 17 | 18 | private static final long serialVersionUID = 1L; 19 | protected int tokenBufferSize = 10000; 20 | 21 | public RawTokenizer(File stopfile) { 22 | super(stopfile); 23 | } 24 | 25 | public RawTokenizer(int languageFlag) { 26 | super(languageFlag); 27 | } 28 | 29 | public RawTokenizer(HashSet stoplist) { 30 | super(stoplist); 31 | } 32 | 33 | public RawTokenizer(File stopfile, int bufferSize) { 34 | super(stopfile); 35 | tokenBufferSize = bufferSize; 36 | } 37 | 38 | public RawTokenizer(int languageFlag, int bufferSize) { 39 | super(languageFlag); 40 | tokenBufferSize = bufferSize; 41 | } 42 | 43 | public RawTokenizer(HashSet stoplist, int bufferSize) { 44 | super(stoplist); 45 | tokenBufferSize = bufferSize; 46 | } 47 | 48 | @SuppressWarnings("unchecked") 49 | @Override 50 | public RawTokenizer deepClone() { 51 | return new RawTokenizer((HashSet) stoplist.clone(), tokenBufferSize); 52 | } 53 | 54 | public Instance pipe(Instance instance) { 55 | 56 | if (instance.getData() instanceof CharSequence) { 57 | 58 | CharSequence characters = (CharSequence) instance.getData(); 59 | 60 | ArrayList tokens = new ArrayList(); 61 | 62 | int[] tokenBuffer = new int[tokenBufferSize]; 63 | int length = -1; 64 | 65 | int totalCodePoints = Character.codePointCount(characters, 0, characters.length()); 66 | 67 | for (int i=0; i < totalCodePoints; i++) { 68 | 69 | int codePoint = Character.codePointAt(characters, i); 70 | int codePointType = Character.getType(codePoint); 71 | 72 | if (codePointType == Character.SPACE_SEPARATOR || 73 | codePointType == Character.LINE_SEPARATOR) { 74 | 75 | if (length != -1) { 76 | String token = new String(tokenBuffer, 0, length + 1); 77 | if (! stoplist.contains(token)) { 78 | tokens.add(token); 79 | } 80 | length = -1; 81 | } 82 | } 83 | else { 84 | length++; 85 | tokenBuffer[length] = codePoint; 86 | } 87 | } 88 | 89 | if (length != -1) { 90 | String token = new String(tokenBuffer, 0, length + 1); 91 | if (! 
stoplist.contains(token)) { 92 | tokens.add(token); 93 | } 94 | } 95 | 96 | instance.setData(tokens); 97 | } 98 | else { 99 | throw new IllegalArgumentException("Looking for a CharSequence, found a " + 100 | instance.getData().getClass()); 101 | } 102 | 103 | return instance; 104 | } 105 | 106 | public int getTokenBufferSize() { 107 | return tokenBufferSize; 108 | } 109 | 110 | public void setTokenBufferSize(int tokenBufferSize) { 111 | this.tokenBufferSize = tokenBufferSize; 112 | } 113 | 114 | } 115 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/pipe/TokenSequencePredicateMatcher.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.pipe; 2 | 3 | import cc.mallet.types.Instance; 4 | import cc.mallet.types.Token; 5 | import cc.mallet.types.TokenSequence; 6 | 7 | public class TokenSequencePredicateMatcher extends Pipe 8 | { 9 | 10 | public interface Predicate<T> { 11 | boolean test(T query); 12 | } 13 | 14 | private static final long serialVersionUID = 1L; 15 | Predicate<String> predicate; 16 | 17 | public TokenSequencePredicateMatcher (Predicate<String> p) 18 | { 19 | this.predicate = p; 20 | } 21 | 22 | public Instance pipe (Instance carrier) 23 | { 24 | TokenSequence ts = (TokenSequence) carrier.getData(); 25 | TokenSequence newts = new TokenSequence(); 26 | for (int i = 0; i < ts.size(); i++) { 27 | Token t = ts.get(i); 28 | if(predicate.test(t.getText())) { 29 | newts.add(t.getText()); 30 | } 31 | } 32 | carrier.setData(newts); 33 | return carrier; 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/BhattacharyyaDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | public class BhattacharyyaDistance implements Distance { 4 | 5 | @Override 6 | public double calculate(double[] v1, double[] v2) { 7 | double of = 1.0 / 4.0; 8 | 9 | double var1 = StatisticalDistance.variance(v1); 10 | double var2 = StatisticalDistance.variance(v2); 11 | double mean1 = StatisticalDistance.mean(v1); 12 | double mean2 = StatisticalDistance.mean(v2); 13 | 14 | double t1 = Math.log(of * (var1 / var2 + var2 / var1 + 2)); 15 | double t2 = Math.pow(mean1-mean2,2) / (var1 + var2); 16 | 17 | return of * t1 + of * t2; 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/CanberraDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | public class CanberraDistance implements Distance { 4 | 5 | @Override 6 | public double calculate(double[] v1, double[] v2) { 7 | double sum = 0; 8 | for (int i = 0; i < v1.length; i++) { 9 | final double num = Math.abs(v1[i] - v2[i]); 10 | final double denom = Math.abs(v1[i]) + Math.abs(v2[i]); 11 | sum += denom == 0.0 ?
0.0 : num / denom; 12 | } 13 | return sum; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/ChebychevDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | 4 | public class ChebychevDistance implements Distance { 5 | 6 | @Override 7 | public double calculate(double[] v1, double[] v2) { 8 | double max = 0; 9 | for (int i = 0; i < v1.length; i++) { 10 | max = Math.max(max, Math.abs(v1[i] - v2[i])); 11 | } 12 | return max; 13 | } 14 | 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/CoOccurrenceLDALikelihoodDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | import cc.mallet.topics.LDASamplerWithPhi; 5 | 6 | public class CoOccurrenceLDALikelihoodDistance extends LDALikelihoodDistance { 7 | 8 | StreamCorpusStatistics cs; 9 | 10 | public CoOccurrenceLDALikelihoodDistance(double alpha, StreamCorpusStatistics cs) { 11 | super(alpha); 12 | this.cs = cs; 13 | } 14 | 15 | public CoOccurrenceLDALikelihoodDistance(LDASamplerWithPhi trainedSampler,StreamCorpusStatistics cs) { 16 | super(trainedSampler); 17 | this.cs = cs; 18 | } 19 | 20 | public CoOccurrenceLDALikelihoodDistance(LDAConfiguration config,StreamCorpusStatistics cs) { 21 | super(config); 22 | this.cs = cs; 23 | } 24 | 25 | /** 26 | * Calculate p(query|document) 27 | * @param query Frequency encoded query (query.length == vocabulary.length) 28 | * @param document Frequency encoded document (document.length == vocabulary.length) 29 | * @param theta 30 | * @return logLikelihood of document generating query 31 | */ 32 | public double ldaLoglikelihood(int[] query, int[] document, double[] theta) { 33 | //Map p_w_d = calcProbWordGivenDocMLFrequencyEncoding(document); 34 | double [] p_w_d = calcProbWordGivenDocMLFrequencyEncoding(document); 35 | 36 | double querylength = getDocLength(query); 37 | double doclength = getDocLength(document); 38 | 39 | // Some sanity check first 40 | if(querylength == 0 && doclength == 0) return 0; 41 | if(querylength == 0 && doclength != 0) return Double.POSITIVE_INFINITY; 42 | if(querylength != 0 && doclength == 0) return Double.POSITIVE_INFINITY; 43 | 44 | if(mixtureRatio<0) { 45 | mixtureRatio = (doclength / (doclength + mu)); 46 | } 47 | 48 | double p_q_d = 0.0; 49 | for (int i = 0; i < query.length; i++) { 50 | double wordProb = 0.0; 51 | int wordFreq = (int)query[i]; 52 | if(wordFreq > 0) { 53 | int word = i; 54 | double wordTopicProb; 55 | // No need to calculate this if it will have no effect 56 | if(lambda < 1) { 57 | wordTopicProb = calcProbWordGivenTheta(theta, word, phi); 58 | } else { 59 | wordTopicProb = 0.0; 60 | } 61 | double wordCorpusProb = calcProbWordGivenCorpus(word); 62 | 63 | if(p_w_d[word] != 0.0) { 64 | wordProb = p_w_d[word]; 65 | } else { 66 | wordProb = coOccurrenceScore(word,document,cs); 67 | } 68 | p_q_d += Math.log(lambda * (mixtureRatio * wordProb + 69 | (1-mixtureRatio) * wordCorpusProb) + 70 | (1-lambda) * wordTopicProb); 71 | } 72 | } 73 | 74 | return p_q_d; 75 | } 76 | 77 | static double coOccurrenceScore(int word, int[] document, StreamCorpusStatistics cs) { 78 | double score = 0.0; 79 | double documentLength = 0.0; 80 | for (int coOccurringWord = 0; coOccurringWord < document.length; 
coOccurringWord++) { 81 | int coOccurringWordFreq = document[coOccurringWord]; 82 | double probInc = cs.getCoOccurrence(word, coOccurringWord) / (double) cs.getDocFreqs()[word]; 83 | score += (coOccurringWordFreq * probInc); 84 | documentLength += coOccurringWordFreq; 85 | } 86 | return score / documentLength; 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/ContinousJaccardDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | public class ContinousJaccardDistance implements Distance { 4 | 5 | @Override 6 | public double calculate(double[] v1, double[] v2) { 7 | double intersection = 0.0; 8 | double union = 0.0; 9 | for (int i = 0; i < v1.length; i++) { 10 | intersection += Math.min(v1[i], v2[i]); 11 | union += Math.max(v1[i], v2[i]); 12 | } 13 | if (intersection > 0.0D) { 14 | return 1-(intersection / union); 15 | } else { 16 | return 0.0; 17 | } 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/CosineDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | public class CosineDistance implements Distance { 4 | 5 | @Override 6 | public double calculate(double[] v1, double[] v2) { 7 | if(v1.length != v2.length) throw new ArrayIndexOutOfBoundsException("Vectors have to be of equal length for cosine distance!"); 8 | double dotProduct = 0.0; 9 | double normA = 0.0; 10 | double normB = 0.0; 11 | for (int i = 0; i < v1.length; i++) { 12 | dotProduct += v1[i] * v2[i]; 13 | normA += Math.pow(v1[i], 2); 14 | normB += Math.pow(v2[i], 2); 15 | } 16 | return 1 - (dotProduct / (Math.sqrt(normA) * Math.sqrt(normB))); 17 | } 18 | 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/Distance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | public interface Distance { 4 | double calculate(double [] v1, double [] v2); 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/EuclidianDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | 4 | public class EuclidianDistance implements Distance { 5 | 6 | @Override 7 | public double calculate(double[] v1, double[] v2) { 8 | double sum = 0; 9 | for (int i = 0; i < v1.length; i++) { 10 | final double dp = v1[i] - v2[i]; 11 | sum += dp * dp; 12 | } 13 | return Math.sqrt(sum); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/HellingerDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import static java.lang.Math.*; 4 | 5 | public class HellingerDistance implements Distance { 6 | 7 | @Override 8 | public double calculate(double[] v1, double[] v2) { 9 | double sum = 0; 10 | for (int i = 0; i < v1.length; i++) { 11 | final double dp = sqrt(v1[i]) - sqrt(v2[i]); 12 | sum += (dp * dp); 13 | } 14 | return sum; 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/InstanceDistance.java: 
-------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import cc.mallet.types.Instance; 4 | 5 | public interface InstanceDistance { 6 | double distance(Instance v1, Instance v2); 7 | } 8 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/InstanceDistanceWrapper.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import cc.mallet.types.Instance; 4 | 5 | public class InstanceDistanceWrapper implements InstanceDistance, Distance { 6 | 7 | Distance dist; 8 | Vectorizer vectorizer; 9 | 10 | public InstanceDistanceWrapper(Distance dist, Vectorizer vectorizer) { 11 | super(); 12 | this.dist = dist; 13 | this.vectorizer = vectorizer; 14 | } 15 | 16 | @Override 17 | public double distance(Instance i1, Instance i2) { 18 | double [] v1 = vectorizer.instanceToVector(i1); 19 | double [] v2 = vectorizer.instanceToVector(i2); 20 | return dist.calculate(v1, v2); 21 | } 22 | 23 | @Override 24 | public double calculate(double[] v1, double[] v2) { 25 | return dist.calculate(v1, v2); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/JaccardDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import java.util.HashSet; 4 | import java.util.Set; 5 | 6 | public class JaccardDistance implements Distance { 7 | 8 | @Override 9 | public double calculate(double[] v1, double[] v2) { 10 | Set union = new HashSet<>(); 11 | Set intersection = new HashSet<>(); 12 | Set v1s = new HashSet<>(); 13 | 14 | for (int i = 0; i < v1.length; i++) { 15 | if(v1[i]>0) { 16 | union.add(i); 17 | v1s.add(i); 18 | } 19 | } 20 | for (int i = 0; i < v2.length; i++) { 21 | if(v2[i]>0) { 22 | intersection.add(i); 23 | } 24 | } 25 | 26 | union.addAll(intersection); 27 | intersection.retainAll(v1s); 28 | 29 | if (intersection.size() > 0) { 30 | return 1-(intersection.size() / (double) union.size()); 31 | } else { 32 | return Double.MAX_VALUE; 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/JensenShannonDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | public class JensenShannonDistance implements Distance { 4 | 5 | KLDistance kldist = new KLDistance(); 6 | 7 | @Override 8 | public double calculate(double[] v1, double[] v2) { 9 | double[] avg = new double[v1.length]; 10 | for (int i = 0; i < v1.length; ++i) { 11 | avg[i] += (v1[i] + v2[i])/2.0; 12 | } 13 | return (kldist.calculate(v1, avg) + kldist.calculate(v2, avg))/2; 14 | } 15 | 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/KLDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | public class KLDistance implements Distance { 4 | 5 | @Override 6 | public double calculate(double[] v1, double[] v2) { 7 | if(v1.length != v2.length) throw new IllegalArgumentException("Vectors have to be of equal length for KLDistance distance! 
v1.length=" 8 | + v1.length + " v2.legth=" + v2.length); 9 | return cc.mallet.util.Maths.klDivergence(v1, v2); 10 | } 11 | 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/KolmogorovSmirnovDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; 4 | 5 | public class KolmogorovSmirnovDistance implements Distance { 6 | 7 | KolmogorovSmirnovTest kstest = new KolmogorovSmirnovTest(); 8 | 9 | @Override 10 | public double calculate(double[] v1, double[] v2) { 11 | return kstest.kolmogorovSmirnovStatistic(v1, v2); 12 | } 13 | 14 | } 15 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/LengthLDALikelihoodDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import org.apache.commons.math3.distribution.TDistribution; 4 | 5 | import cc.mallet.configuration.LDAConfiguration; 6 | import cc.mallet.topics.LDASamplerWithPhi; 7 | 8 | public class LengthLDALikelihoodDistance extends LDALikelihoodDistance { 9 | 10 | public LengthLDALikelihoodDistance(int K, double alpha) { 11 | super(alpha); 12 | } 13 | 14 | public LengthLDALikelihoodDistance(LDASamplerWithPhi trainedSampler) { 15 | super(trainedSampler); 16 | } 17 | 18 | public LengthLDALikelihoodDistance(LDAConfiguration config) { 19 | super(config); 20 | } 21 | 22 | /** 23 | * Calculate p(query|document) 24 | * @param query Frequency encoded query (query.length == vocabulary.length) 25 | * @param document Frequency encoded document (document.length == vocabulary.length) 26 | * @param theta 27 | * @return logLikelihood of document generating query 28 | */ 29 | public double ldaLoglikelihood(int[] query, int[] document, double[] theta) { 30 | double p_q_d = super.ldaLoglikelihood(query, document, theta); 31 | 32 | double querylength = getDocLength(query); 33 | double doclength = getDocLength(document); 34 | 35 | double df = 150; 36 | TDistribution tdist = new TDistribution(df); 37 | double diff = Math.log(Math.abs(querylength-doclength)); 38 | double length_prob; 39 | if((querylength-doclength)==0) { 40 | length_prob = 0; 41 | } else { 42 | length_prob = Math.log(tdist.density(diff)); 43 | } 44 | 45 | p_q_d = p_q_d + length_prob; 46 | 47 | return p_q_d; 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/ManhattanDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | public class ManhattanDistance implements Distance { 4 | 5 | @Override 6 | public double calculate(double[] v1, double[] v2) { 7 | double sum = 0; 8 | for (int i = 0; i < v1.length; i++) { 9 | sum += Math.abs(v1[i] - v2[i]); 10 | } 11 | return sum; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/MultiInstanceDistanceWrapper.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import cc.mallet.types.Instance; 4 | import cc.mallet.types.InstanceList; 5 | 6 | public class MultiInstanceDistanceWrapper extends InstanceDistanceWrapper implements TrainedDistance { 7 | 8 | InstanceList trainingset; 9 
| public MultiInstanceDistanceWrapper(Distance dist, Vectorizer vectorizer, InstanceList trainingset) { 10 | super(dist, vectorizer); 11 | this.trainingset = trainingset; 12 | } 13 | 14 | @Override 15 | public void init(InstanceList trainingset) { 16 | this.trainingset = trainingset; 17 | if(dist instanceof TrainedDistance) { 18 | ((TrainedDistance) dist).init(trainingset); 19 | } 20 | } 21 | 22 | @Override 23 | public double distanceToTrainingSample(double[] query, int sampleId) { 24 | double [] v2 = vectorizer.instanceToVector(trainingset.get(sampleId)); 25 | return calculate(query, v2); 26 | } 27 | 28 | @Override 29 | public double[] distanceToAll(Instance testInstance) { 30 | double [] query = vectorizer.instanceToVector(testInstance); 31 | double [] results = new double [trainingset.size()]; 32 | for (int i = 0; i < results.length; i++) { 33 | results[i] = distanceToTrainingSample(query, i); 34 | } 35 | return results; 36 | } 37 | 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/StatisticalDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | public class StatisticalDistance implements Distance { 4 | 5 | @Override 6 | public double calculate(double[] v1, double[] v2) { 7 | return -(correlation(v1, v2)-1); 8 | } 9 | 10 | public static double mean(double [] vector) { 11 | double sum = 0.0; 12 | for (int i = 0; i < vector.length; i++) { 13 | sum +=vector[i]; 14 | } 15 | return sum/vector.length; 16 | } 17 | 18 | public static double variance(double [] v) 19 | { 20 | double mean = mean(v); 21 | double var = 0; 22 | for(double a : v) 23 | var += (mean-a)*(mean-a); 24 | return var/v.length; 25 | } 26 | 27 | public static double sd(double [] v) 28 | { 29 | return Math.sqrt(variance(v)); 30 | } 31 | 32 | public static double correlation(double [] v1, double [] v2) { 33 | return covariance(v1, v2) / Math.sqrt(variance(v1) * variance(v2)); 34 | } 35 | 36 | public static double covariance(double [] v1, double [] v2) { 37 | double result = 0.0; 38 | double m1 = mean(v1); 39 | double m2 = mean(v2); 40 | for (int i = 0; i < v1.length; i++) { 41 | double v1Deviation = v1[i] - m1; 42 | double v2Deviation = v2[i] - m2; 43 | result += (v1Deviation * v2Deviation - result) / (i + 1); 44 | } 45 | return result; 46 | } 47 | 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/SymmetricKLDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | public class SymmetricKLDistance implements Distance { 4 | 5 | @Override 6 | public double calculate(double[] v1, double[] v2) { 7 | // Symmetrisized KL divergence 8 | double u1 = cc.mallet.util.Maths.klDivergence(v1, v2); 9 | double u2 = cc.mallet.util.Maths.klDivergence(v2, v1); 10 | return (u1 + u2) / 2; 11 | } 12 | 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/TDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import org.apache.commons.math3.stat.inference.TTest; 4 | 5 | public class TDistance implements Distance { 6 | 7 | TTest t = new TTest(); 8 | 9 | @Override 10 | public double calculate(double[] v1, double[] v2) { 11 | return t.t(v1, v2); 12 | } 13 | 14 | } 15 | 
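The Distance implementations above operate on plain double[] vectors; to compare two MALLET Instance objects directly they are combined with a Vectorizer through InstanceDistanceWrapper (shown earlier in this package). A minimal sketch of that composition, assuming an InstanceList whose pipe produces FeatureSequence data; the class and method names of the sketch itself are illustrative only, not part of the library:

import cc.mallet.similarity.CosineDistance;
import cc.mallet.similarity.InstanceDistanceWrapper;
import cc.mallet.similarity.TokenFrequencyVectorizer;
import cc.mallet.types.InstanceList;

public class DistanceExampleSketch {
    // Hypothetical helper: distance between the first two documents of an
    // InstanceList, using cosine distance over raw token-frequency vectors.
    public static double firstPairDistance(InstanceList instances) {
        InstanceDistanceWrapper wrapper =
            new InstanceDistanceWrapper(new CosineDistance(), new TokenFrequencyVectorizer());
        return wrapper.distance(instances.get(0), instances.get(1));
    }
}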
-------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/TfIdfVectorizer.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import cc.mallet.pipe.TfIdfPipe; 4 | import cc.mallet.types.Alphabet; 5 | import cc.mallet.types.FeatureSequence; 6 | import cc.mallet.types.Instance; 7 | import cc.mallet.util.ArrayStringUtils; 8 | 9 | public class TfIdfVectorizer implements Vectorizer { 10 | 11 | TfIdfPipe tp; 12 | 13 | public TfIdfVectorizer(TfIdfPipe tp) { 14 | this.tp = tp; 15 | } 16 | 17 | @Override 18 | public double[] instanceToVector(Instance instance) { 19 | FeatureSequence trainTokenSeq = (FeatureSequence) instance.getData(); 20 | int [] tokenSequence = trainTokenSeq.getFeatures(); 21 | double [] coordinates = new double[instance.getAlphabet().size()]; 22 | for (int i = 0; i < tokenSequence.length; i++) { 23 | coordinates[tokenSequence[i]] = tp.getTfIdf().get(tokenSequence[i]); 24 | } 25 | return coordinates; 26 | } 27 | 28 | @Override 29 | public int[] instanceToIntVector(Instance instance) { 30 | FeatureSequence trainTokenSeq = (FeatureSequence) instance.getData(); 31 | int [] tokenSequence = trainTokenSeq.getFeatures(); 32 | int [] coordinates = new int[instance.getAlphabet().size()]; 33 | for (int i = 0; i < tokenSequence.length; i++) { 34 | coordinates[tokenSequence[i]] = (int) Math.round(tp.getTfIdf().get(tokenSequence[i])); 35 | } 36 | return coordinates; 37 | } 38 | 39 | public String toAnnotatedString(Instance instance) { 40 | double [] arr = instanceToVector(instance); 41 | Alphabet alphabet = instance.getAlphabet(); 42 | 43 | String res = ""; 44 | res += "[" + arr.length + "]:"; 45 | for (int j = 0; j < arr.length; j++) { 46 | if(arr[j]<0.00005) { 47 | res += ""; 48 | } else { 49 | String word = (String) alphabet.lookupObject(j); 50 | res += "(" + word + "):" + j + ":" + ArrayStringUtils.formatDouble(arr[j]) + ", "; 51 | } 52 | } 53 | return res; 54 | } 55 | 56 | 57 | } 58 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/TokenFrequencyVectorizer.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import cc.mallet.types.Alphabet; 4 | import cc.mallet.types.FeatureSequence; 5 | import cc.mallet.types.Instance; 6 | import cc.mallet.util.ArrayStringUtils; 7 | 8 | public class TokenFrequencyVectorizer implements Vectorizer { 9 | 10 | @Override 11 | public double[] instanceToVector(Instance instance) { 12 | FeatureSequence features = (FeatureSequence) instance.getData(); 13 | double [] coordinates = new double[instance.getAlphabet().size()]; 14 | for (int i = 0; i < features.size(); i++) { 15 | coordinates[features.getIndexAtPosition(i)]++; 16 | } 17 | return coordinates; 18 | } 19 | 20 | @Override 21 | public int[] instanceToIntVector(Instance instance) { 22 | FeatureSequence features = (FeatureSequence) instance.getData(); 23 | int [] coordinates = new int[instance.getAlphabet().size()]; 24 | for (int i = 0; i < features.size(); i++) { 25 | coordinates[features.getIndexAtPosition(i)]++; 26 | } 27 | return coordinates; 28 | } 29 | 30 | 31 | public String toAnnotatedString(Instance instance) { 32 | double [] arr = instanceToVector(instance); 33 | Alphabet alphabet = instance.getAlphabet(); 34 | 35 | String res = ""; 36 | int nonZero = 0; 37 | for (int j = 0; j < arr.length; j++) { 38 | if(arr[j]<0.00005) { 
39 | res += ""; 40 | } else { 41 | String word = (String) alphabet.lookupObject(j); 42 | res += "(" + word + "):" + j + ":" + ArrayStringUtils.formatDouble(arr[j]) + ", "; 43 | nonZero++; 44 | } 45 | } 46 | return "[" + nonZero + "]:" + res; 47 | } 48 | 49 | 50 | } 51 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/TokenIndexVectorizer.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import cc.mallet.types.Alphabet; 4 | import cc.mallet.types.FeatureSequence; 5 | import cc.mallet.types.Instance; 6 | import cc.mallet.util.ArrayStringUtils; 7 | 8 | public class TokenIndexVectorizer implements Vectorizer { 9 | 10 | @Override 11 | public double[] instanceToVector(Instance instance) { 12 | FeatureSequence features = (FeatureSequence) instance.getData(); 13 | double [] coordinates = new double[features.size()]; 14 | 15 | for(int i = 0; i < features.size(); i++) { 16 | coordinates[i] = features.getIndexAtPosition(i); 17 | } 18 | 19 | return coordinates; 20 | } 21 | 22 | @Override 23 | public int[] instanceToIntVector(Instance instance) { 24 | FeatureSequence features = (FeatureSequence) instance.getData(); 25 | int [] coordinates = new int[features.size()]; 26 | 27 | for(int i = 0; i < features.size(); i++) { 28 | coordinates[i] = features.getIndexAtPosition(i); 29 | } 30 | 31 | return coordinates; 32 | } 33 | 34 | public String toAnnotatedString(Instance instance) { 35 | double [] arr = instanceToVector(instance); 36 | Alphabet alphabet = instance.getAlphabet(); 37 | 38 | String res = ""; 39 | res += "[" + arr.length + "]:"; 40 | for (int j = 0; j < arr.length; j++) { 41 | String word = (String) alphabet.lookupObject((int)arr[j]); 42 | res += "(" + word + "):" + j + ":" + ArrayStringUtils.formatDouble(arr[j]) + ", "; 43 | } 44 | return res; 45 | } 46 | 47 | 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/TokenOccurenceVectorizer.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import cc.mallet.types.Alphabet; 4 | import cc.mallet.types.FeatureSequence; 5 | import cc.mallet.types.Instance; 6 | 7 | public class TokenOccurenceVectorizer implements Vectorizer { 8 | 9 | @Override 10 | public double[] instanceToVector(Instance instance) { 11 | FeatureSequence features = (FeatureSequence) instance.getData(); 12 | double [] coordinates = new double[instance.getAlphabet().size()]; 13 | for (int i = 0; i < features.size(); i++) { 14 | coordinates[features.getIndexAtPosition(i)] = 1; 15 | } 16 | return coordinates; 17 | } 18 | 19 | @Override 20 | public int[] instanceToIntVector(Instance instance) { 21 | FeatureSequence features = (FeatureSequence) instance.getData(); 22 | int [] coordinates = new int[instance.getAlphabet().size()]; 23 | for (int i = 0; i < features.size(); i++) { 24 | coordinates[features.getIndexAtPosition(i)] = 1; 25 | } 26 | return coordinates; 27 | } 28 | 29 | public String toAnnotatedString(Instance instance) { 30 | FeatureSequence features = (FeatureSequence) instance.getData(); 31 | int nrWords = features.size(); 32 | Alphabet alphabet = instance.getAlphabet(); 33 | String res = ""; 34 | res += "[" + nrWords + "]:"; 35 | for (int j = 0; j < nrWords; j++) { 36 | String word = (String) alphabet.lookupObject(features.getIndexAtPosition(j)); 37 | res += "(" + word + "):" + j + 
":1"; 38 | } 39 | return res; 40 | } 41 | 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/TrainedDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import cc.mallet.types.Instance; 4 | import cc.mallet.types.InstanceList; 5 | 6 | public interface TrainedDistance extends Distance { 7 | void init(InstanceList trainingset); 8 | double distanceToTrainingSample(double [] query, int sampleId); 9 | double [] distanceToAll(Instance testInstance); 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/UberDistance.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | public class UberDistance implements Distance { 4 | 5 | Distance [] measures = { 6 | new CanberraDistance(), 7 | new ChebychevDistance(), 8 | new CosineDistance(), 9 | new EuclidianDistance(), 10 | new ContinousJaccardDistance(), 11 | new KLDistance(), 12 | new ManhattanDistance() 13 | }; 14 | 15 | @Override 16 | public double calculate(double[] v1, double[] v2) { 17 | double sum = 0; 18 | for (int i = 0; i < measures.length; i++) { 19 | sum += measures[i].calculate(v1, v2); 20 | } 21 | return sum / measures.length; 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/similarity/Vectorizer.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import cc.mallet.types.Instance; 4 | 5 | public interface Vectorizer { 6 | double[] instanceToVector(Instance instance); 7 | int[] instanceToIntVector(Instance instance); 8 | String toAnnotatedString(Instance instance); 9 | } 10 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/AbortableSampler.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | public interface AbortableSampler { 4 | void abort(); 5 | boolean getAbort(); 6 | } 7 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/HDPSamplerWithPhi.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import java.util.List; 4 | 5 | public interface HDPSamplerWithPhi extends LDASamplerWithPhi { 6 | int [] getTopicOcurrenceCount(); 7 | List getActiveTopicHistory(); 8 | List getActiveTopicInDataHistory(); 9 | 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/HashLDA.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import java.util.ArrayList; 4 | import java.util.HashMap; 5 | 6 | public interface HashLDA { 7 | int [] getMapType(); 8 | public ArrayList> getHashTopicTypeCounts(); 9 | } 10 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/LDADocSamplingContext.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import cc.mallet.types.FeatureSequence; 4 | import cc.mallet.types.LabelSequence; 5 | 6 | public interface LDADocSamplingContext { 7 | 8 | 
FeatureSequence getTokens(); 9 | 10 | void setTokens(FeatureSequence tokens); 11 | 12 | LabelSequence getTopics(); 13 | 14 | void setTopics(LabelSequence topics); 15 | 16 | int getMyBatch(); 17 | 18 | void setMyBatch(int myBatch); 19 | 20 | int getDocIdx(); 21 | 22 | void setDocIdx(int docId); 23 | 24 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/LDADocSamplingResult.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | public interface LDADocSamplingResult { 4 | 5 | int [] getLocalTopicCounts(); 6 | 7 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/LDADocSamplingResultDense.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | public class LDADocSamplingResultDense implements LDADocSamplingResult { 4 | 5 | int[] localTopicCounts; 6 | 7 | public LDADocSamplingResultDense(int[] localTopicCounts) { 8 | super(); 9 | this.localTopicCounts = localTopicCounts; 10 | } 11 | 12 | @Override 13 | public int[] getLocalTopicCounts() { 14 | return localTopicCounts; 15 | } 16 | 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/LDADocSamplingResultSparse.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | public interface LDADocSamplingResultSparse extends LDADocSamplingResult { 4 | 5 | int getNonZeroTopicCounts(); 6 | int [] getNonZeroIndices(); 7 | 8 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/LDADocSamplingResultSparseSimple.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | public class LDADocSamplingResultSparseSimple extends LDADocSamplingResultDense implements LDADocSamplingResultSparse { 4 | 5 | int nonZeroTopicCnt; 6 | int [] nonZeroIndices; 7 | 8 | public LDADocSamplingResultSparseSimple(int[] localTopicCounts, int nonZeroTopicCnt, int[] nonZeroIndices) { 9 | super(localTopicCounts); 10 | this.nonZeroTopicCnt = nonZeroTopicCnt; 11 | this.nonZeroIndices = nonZeroIndices; 12 | } 13 | 14 | @Override 15 | public int getNonZeroTopicCounts() { 16 | return nonZeroTopicCnt; 17 | } 18 | 19 | @Override 20 | public int[] getNonZeroIndices() { 21 | return nonZeroIndices; 22 | } 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/LDAGibbsSampler.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import java.io.IOException; 4 | import java.util.ArrayList; 5 | 6 | import cc.mallet.configuration.LDAConfiguration; 7 | import cc.mallet.types.Alphabet; 8 | import cc.mallet.types.InstanceList; 9 | 10 | public interface LDAGibbsSampler extends AbortableSampler { 11 | void setConfiguration(LDAConfiguration config); 12 | LDAConfiguration getConfiguration(); 13 | void addInstances (InstanceList training); 14 | void addTestInstances (InstanceList testSet); 15 | void sample (int iterations) throws IOException; 16 | void setRandomSeed(int seed); 17 | int getNoTopics(); 18 | int getNumTopics(); 19 | int getNoTypes(); 20 | int getCurrentIteration(); 21 | int [][] getZIndicators(); 22 | double [][] getZbar(); 
23 | double[][] getThetaEstimate(); 24 | void setZIndicators(int[][] zIndicators); 25 | InstanceList getDataset(); 26 | InstanceList getTestSet(); 27 | ArrayList getData(); 28 | int[][] getDeltaStatistics(); 29 | int[] getTopTypeFrequencyIndices(); 30 | int[] getTypeFrequencies(); 31 | long getCorpusSize(); 32 | Alphabet getAlphabet(); 33 | int getStartSeed(); 34 | double[] getTypeMassCumSum(); 35 | int [][] getDocumentTopicMatrix(); 36 | int [][] getTypeTopicMatrix(); 37 | int [] getTopicTotals(); 38 | double getBeta(); 39 | double[] getAlpha(); 40 | void preIteration(); 41 | void postIteration(); 42 | void preSample(); 43 | void postSample(); 44 | void postZ(); 45 | void preZ(); 46 | double[] getLogLikelihood(); 47 | double[] getHeldOutLogLikelihood(); 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/LDASamplerContinuable.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import java.io.IOException; 4 | 5 | public interface LDASamplerContinuable extends LDAGibbsSampler { 6 | void continueSampling(int iterations) throws IOException; 7 | void preContinuedSampling(); 8 | void postContinuedSampling(); 9 | } 10 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/LDASamplerInitiable.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | public interface LDASamplerInitiable extends LDAGibbsSampler { 4 | void initFrom(LDAGibbsSampler source); 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/LDASamplerWithCallback.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import cc.mallet.topics.tui.IterationListener; 4 | 5 | public interface LDASamplerWithCallback extends LDAGibbsSampler { 6 | void setIterationCallback(IterationListener iterListener); 7 | } 8 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/LDASamplerWithDocumentPriors.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import java.util.Map; 4 | 5 | public interface LDASamplerWithDocumentPriors extends LDAGibbsSampler { 6 | Map getDocumentPriors(); 7 | } 8 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/LDASamplerWithPhi.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import cc.mallet.types.Alphabet; 4 | 5 | public interface LDASamplerWithPhi extends LDAGibbsSampler { 6 | double [][] getPhi(); 7 | void setPhi(double [][] phi, Alphabet dataAlphabet, Alphabet targetAlphabet); 8 | double [][] getPhiMeans(); 9 | public void prePhi(); 10 | public void postPhi(); 11 | void sampleZGivenPhi(int iterations); 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/LDASamplerWithTopicPriors.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | public interface LDASamplerWithTopicPriors extends LDAGibbsSampler { 4 | double [][] getTopicPriors(); 5 | } 6 | 
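The sampler interfaces above define the life cycle a driver walks through: set a configuration, add training instances, run a fixed number of Gibbs iterations, then read off the estimates. A minimal sketch of that flow, assuming config and training have already been built elsewhere and sampler is any concrete LDAGibbsSampler implementation; the driver class and method names are illustrative, not part of the library:

import java.io.IOException;

import cc.mallet.configuration.LDAConfiguration;
import cc.mallet.topics.LDAGibbsSampler;
import cc.mallet.types.InstanceList;

public class SamplerDriverSketch {
    // Hypothetical driver: runs a configured sampler and returns the estimated
    // document-topic proportions, one row per document.
    public static double[][] run(LDAGibbsSampler sampler, LDAConfiguration config,
            InstanceList training, int iterations) throws IOException {
        sampler.setConfiguration(config);   // interface methods as declared above
        sampler.addInstances(training);     // hand the training set to the sampler
        sampler.sample(iterations);         // run the Gibbs sampler
        return sampler.getThetaEstimate();  // document-topic proportion estimates
    }
}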
-------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/LogState.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import java.util.logging.Logger; 4 | 5 | public class LogState { 6 | public double logLik; 7 | public int iteration; 8 | public String wordsPerTopic; 9 | public String loggingPath; 10 | public Logger logger; 11 | 12 | public LogState(double logLik, int iteration, String wordsPerTopic, String loggingPath, Logger logger, 13 | long absoluteTime, long zSamplingTokenUpdateTime, long phiSamplingTime, double density) { 14 | this( logLik, iteration, wordsPerTopic, loggingPath, logger); 15 | } 16 | 17 | public LogState(double logLik, int iteration, String wordsPerTopic, String loggingPath, Logger logger) { 18 | this.logLik = logLik; 19 | this.iteration = iteration; 20 | this.wordsPerTopic = wordsPerTopic; 21 | this.loggingPath = loggingPath; 22 | this.logger = logger; 23 | } 24 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/SparseHDPSampler.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Arrays; 5 | import java.util.List; 6 | 7 | import cc.mallet.configuration.LDAConfiguration; 8 | 9 | public abstract class SparseHDPSampler extends PolyaUrnSpaliasLDA implements HDPSamplerWithPhi { 10 | 11 | private static final long serialVersionUID = 1L; 12 | 13 | List activeTopicHistory = new ArrayList(); 14 | List activeTopicInDataHistory = new ArrayList(); 15 | // Used to track the number of times a topic occurs in the dataset 16 | int [] topicOcurrenceCount; 17 | 18 | public SparseHDPSampler(LDAConfiguration config) { 19 | super(config); 20 | } 21 | 22 | public int[] getTopicOcurrenceCount() { 23 | return topicOcurrenceCount; 24 | } 25 | 26 | public void setTopicOcurrenceCount(int[] topicOcurrenceCount) { 27 | this.topicOcurrenceCount = topicOcurrenceCount; 28 | } 29 | 30 | public List getActiveTopicHistory() { 31 | return activeTopicHistory; 32 | } 33 | 34 | public void setActiveTopicHistory(List activeTopicHistory) { 35 | this.activeTopicHistory = activeTopicHistory; 36 | } 37 | 38 | public List getActiveTopicInDataHistory() { 39 | return activeTopicInDataHistory; 40 | } 41 | 42 | public void setActiveTopicInDataHistory(List activeInDataTopicHistory) { 43 | this.activeTopicInDataHistory = activeInDataTopicHistory; 44 | } 45 | 46 | public static int calcK(double percentile, int [] tokensPerTopic) { 47 | int [] sortedAllocation = Arrays.copyOf(tokensPerTopic, tokensPerTopic.length); 48 | Arrays.sort(sortedAllocation); 49 | int [] ecdf = calcEcdf(sortedAllocation); 50 | int k95 = findPercentile(ecdf,percentile); 51 | return k95; 52 | } 53 | 54 | public static int findPercentile(int[] ecdf, double percentile) { 55 | double total = ecdf[ecdf.length-1]; 56 | for (int j = 0; j < ecdf.length; j++) { 57 | if(ecdf[j]/total > percentile) { 58 | return j; 59 | } 60 | } 61 | return ecdf.length; 62 | } 63 | 64 | public static int[] calcEcdf(int[] sortedAllocation) { 65 | int [] ecdf = new int[sortedAllocation.length]; 66 | ecdf[0] = sortedAllocation[sortedAllocation.length-1]; 67 | for(int i = 1; i < sortedAllocation.length; i++) { 68 | ecdf[i] = sortedAllocation[sortedAllocation.length - i - 1] + ecdf[i-1]; 69 | } 70 | 71 | return ecdf; 72 | } 73 | 74 | } 75 | 
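Editor's note: the static helpers at the bottom of SparseHDPSampler (calcK, findPercentile, calcEcdf) estimate how many of the largest topics are needed to cover a given share of the token mass. A small worked example of that arithmetic (illustrative only, not part of the repository):

import cc.mallet.topics.SparseHDPSampler;

public class EcdfSketch {
    public static void main(String[] args) {
        // Three topics with 60, 10 and 30 tokens allocated to them.
        int[] tokensPerTopic = {60, 10, 30};
        // calcK sorts the allocation ascending ({10, 30, 60}), builds the reverse
        // cumulative sums {60, 90, 100} and returns the first index whose share of
        // the total exceeds the percentile: 0.60, 0.90, 1.00 -> index 2 for 0.95.
        int k95 = SparseHDPSampler.calcK(0.95, tokensPerTopic);
        System.out.println("k95 = " + k95); // prints "k95 = 2"
    }
}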
-------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/TableBuilderFactory.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import java.util.concurrent.Callable; 4 | 5 | interface TableBuilderFactory { 6 | Callable instance(int type); 7 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/UncollapsedLDADocSamplingContext.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import cc.mallet.types.FeatureSequence; 4 | import cc.mallet.types.LabelSequence; 5 | 6 | public class UncollapsedLDADocSamplingContext implements LDADocSamplingContext { 7 | FeatureSequence tokens; 8 | LabelSequence topics; 9 | int myBatch; 10 | int docIdx = -1; 11 | 12 | public UncollapsedLDADocSamplingContext(FeatureSequence tokens, LabelSequence topics, int myBatch, int docIdx) { 13 | super(); 14 | this.tokens = tokens; 15 | this.topics = topics; 16 | this.myBatch = myBatch; 17 | this.docIdx = docIdx; 18 | } 19 | /* (non-Javadoc) 20 | * @see cc.mallet.topics.LDADocSamplingContext#getTokens() 21 | */ 22 | @Override 23 | public FeatureSequence getTokens() { 24 | return tokens; 25 | } 26 | /* (non-Javadoc) 27 | * @see cc.mallet.topics.LDADocSamplingContext#setTokens(cc.mallet.types.FeatureSequence) 28 | */ 29 | @Override 30 | public void setTokens(FeatureSequence tokens) { 31 | this.tokens = tokens; 32 | } 33 | /* (non-Javadoc) 34 | * @see cc.mallet.topics.LDADocSamplingContext#getTopics() 35 | */ 36 | @Override 37 | public LabelSequence getTopics() { 38 | return topics; 39 | } 40 | /* (non-Javadoc) 41 | * @see cc.mallet.topics.LDADocSamplingContext#setTopics(cc.mallet.types.LabelSequence) 42 | */ 43 | @Override 44 | public void setTopics(LabelSequence topics) { 45 | this.topics = topics; 46 | } 47 | /* (non-Javadoc) 48 | * @see cc.mallet.topics.LDADocSamplingContext#getMyBatch() 49 | */ 50 | @Override 51 | public int getMyBatch() { 52 | return myBatch; 53 | } 54 | /* (non-Javadoc) 55 | * @see cc.mallet.topics.LDADocSamplingContext#setMyBatch(int) 56 | */ 57 | @Override 58 | public void setMyBatch(int myBatch) { 59 | this.myBatch = myBatch; 60 | } 61 | /* (non-Javadoc) 62 | * @see cc.mallet.topics.LDADocSamplingContext#getDocId() 63 | */ 64 | @Override 65 | public int getDocIdx() { 66 | return docIdx; 67 | } 68 | /* (non-Javadoc) 69 | * @see cc.mallet.topics.LDADocSamplingContext#setDocId(int) 70 | */ 71 | @Override 72 | public void setDocIdx(int docId) { 73 | this.docIdx = docId; 74 | } 75 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/WalkerAliasTableBuildResult.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import cc.mallet.util.WalkerAliasTable; 4 | 5 | public class WalkerAliasTableBuildResult { 6 | public int type; 7 | public WalkerAliasTable table; 8 | public double typeNorm; 9 | 10 | public WalkerAliasTableBuildResult(int type, WalkerAliasTable table, double typeNorm) { 11 | this.type = type; 12 | this.table = table; 13 | this.typeNorm = typeNorm; 14 | } 15 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/document/AdaptiveBatchBuilder.java: -------------------------------------------------------------------------------- 1 | 
package cc.mallet.topics.randomscan.document; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | import cc.mallet.topics.LDAGibbsSampler; 5 | 6 | public class AdaptiveBatchBuilder extends PercentageBatchBuilder { 7 | 8 | // Set default period to be 0 iterations 9 | // For the NIPS data 125 iterations seems to be a good value 10 | int deltaInstabilityPeriod = 0; 11 | 12 | public AdaptiveBatchBuilder(LDAConfiguration config, LDAGibbsSampler sampler) { 13 | super(config, sampler); 14 | deltaInstabilityPeriod = config.getInstabilityPeriod(0); 15 | } 16 | 17 | public int getDeltaInstabilityPeriod() { 18 | return deltaInstabilityPeriod; 19 | } 20 | 21 | public void setDeltaInstabilityPeriod(int instabilityPeriod) { 22 | this.deltaInstabilityPeriod = instabilityPeriod; 23 | } 24 | 25 | boolean inInstabilityPeriod() { 26 | return sampler.getCurrentIteration() i ? 1 : 0); 46 | int [] indices = new int[docsInBatch]; 47 | for (int j = 0; j < docsInBatch; j++) { 48 | indices[j] = docIdx++; 49 | } 50 | docBatches[i] = indices; 51 | } 52 | } else { 53 | calculateDocumentsToSample(config, docsPercentage); 54 | super.calculateBatch(); 55 | } 56 | } 57 | 58 | @Override 59 | int calcDocsToSample() { 60 | if(inInstabilityPeriod()) { 61 | return totalDocsAvailable; 62 | } else { 63 | return super.calcDocsToSample(); 64 | } 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/document/BatchBuilderFactory.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.document; 2 | 3 | import java.lang.reflect.InvocationTargetException; 4 | 5 | import cc.mallet.configuration.LDAConfiguration; 6 | import cc.mallet.topics.LDAGibbsSampler; 7 | 8 | public class BatchBuilderFactory { 9 | 10 | public static final String EVEN_SPLIT = "cc.mallet.topics.randomscan.document.EvenSplitBatchBuilder"; 11 | public static final String PERCENTAGE_SPLIT = "cc.mallet.topics.randomscan.document.PercentageBatchBuilder"; 12 | public static final String ADAPTIVE_SPLIT = "cc.mallet.topics.randomscan.document.AdaptiveBatchBuilder"; 13 | public static final String FIXED_SPLIT = "cc.mallet.topics.randomscan.document.FixedSplitBatchBuilder"; 14 | 15 | public BatchBuilderFactory() { 16 | } 17 | 18 | @SuppressWarnings("unchecked") 19 | public static synchronized DocumentBatchBuilder get(LDAConfiguration config, LDAGibbsSampler sampler) { 20 | String topic_building_scheme = config.getDocumentBatchBuildingScheme(LDAConfiguration.BATCH_BUILD_SCHEME_DEFAULT); 21 | 22 | @SuppressWarnings("rawtypes") 23 | Class batchBuilderClass = null; 24 | try { 25 | batchBuilderClass = Class.forName(topic_building_scheme); 26 | } catch (ClassNotFoundException e) { 27 | e.printStackTrace(); 28 | throw new IllegalArgumentException(e); 29 | } 30 | 31 | @SuppressWarnings("rawtypes") 32 | Class[] argumentTypes = new Class[2]; 33 | argumentTypes[0] = LDAConfiguration.class; 34 | argumentTypes[1] = LDAGibbsSampler.class; 35 | 36 | try { 37 | return (DocumentBatchBuilder) batchBuilderClass.getDeclaredConstructor(argumentTypes) 38 | .newInstance(config,sampler); 39 | } catch (InstantiationException | IllegalAccessException 40 | | InvocationTargetException 41 | | NoSuchMethodException | SecurityException e) { 42 | e.printStackTrace(); 43 | throw new IllegalArgumentException(e); 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- 
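Editor's note: BatchBuilderFactory above resolves the class named by the configuration's document batch building scheme and reflectively invokes its (LDAConfiguration, LDAGibbsSampler) constructor, so custom builders can be plugged in without touching the factory. A sketch of such a plug-in follows (illustrative only; the class and package names are hypothetical, and it implements the DocumentBatchBuilder interface defined in the next file).

package my.lda.extensions; // hypothetical package

import cc.mallet.configuration.LDAConfiguration;
import cc.mallet.topics.LDAGibbsSampler;
import cc.mallet.topics.randomscan.document.DocumentBatchBuilder;

// Puts every document in a single batch; usable by pointing the document batch
// building scheme config value at "my.lda.extensions.SingleBatchBuilder".
public class SingleBatchBuilder implements DocumentBatchBuilder {

    LDAConfiguration config;
    LDAGibbsSampler sampler;
    int[][] batches;

    // The factory requires exactly this constructor signature.
    public SingleBatchBuilder(LDAConfiguration config, LDAGibbsSampler sampler) {
        this.config = config;
        this.sampler = sampler;
    }

    @Override public void setSampler(LDAGibbsSampler sampler) { this.sampler = sampler; }

    @Override public void calculateBatch() {
        int n = sampler.getDataset().size();
        int[] all = new int[n];
        for (int i = 0; i < n; i++) all[i] = i;
        batches = new int[][] { all };
    }

    @Override public int[][] documentBatches() { return batches; }

    @Override public int getDocResultsSize() {
        return config.getResultSize(LDAConfiguration.RESULTS_SIZE_DEFAULT);
    }

    @Override public int getDocumentsInIteration(int currentIteration) {
        return sampler.getDataset().size();
    }
}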
/src/main/java/cc/mallet/topics/randomscan/document/DocumentBatchBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.document; 2 | 3 | import cc.mallet.topics.LDAGibbsSampler; 4 | 5 | 6 | /** 7 | * In general batch builders are NOT thread safe, they are intended to be called only from 8 | * the coordinator thread! 9 | * 10 | * 11 | */ 12 | public interface DocumentBatchBuilder { 13 | 14 | /** 15 | * Do the calculation of the batch size, this Algorithm can vary depending on scheme 16 | */ 17 | void calculateBatch(); 18 | 19 | /** 20 | * The result is a matrix that contains the document indices to sample for each worker 21 | * @return a matrix A indexed by A[batch][documentIdx] 22 | */ 23 | public int[][] documentBatches(); 24 | 25 | /** 26 | * @return how many documents the workers should buffer before sending the resulting 27 | * samples 28 | */ 29 | int getDocResultsSize(); 30 | 31 | /** 32 | * @param currentIteration 33 | * @return how many documenst should be sampled during this iteration 34 | */ 35 | int getDocumentsInIteration(int currentIteration); 36 | 37 | 38 | /** 39 | * Sets the sampler that wants to do random scan, we might need various statistics from 40 | * the sampler to decide which documents, types and topics to sample 41 | * @param sampler 42 | */ 43 | void setSampler(LDAGibbsSampler sampler); 44 | } 45 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/document/EvenSplitBatchBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.document; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | import cc.mallet.topics.LDAGibbsSampler; 5 | import cc.mallet.types.InstanceList; 6 | 7 | public class EvenSplitBatchBuilder implements DocumentBatchBuilder { 8 | 9 | LDAGibbsSampler sampler; 10 | LDAConfiguration config; 11 | InstanceList data; 12 | 13 | int[] batchSizeArray; 14 | int[] batchStartArray; 15 | int [][] documentBatches; 16 | int documentsPerIter; 17 | 18 | 19 | public EvenSplitBatchBuilder(LDAConfiguration config, LDAGibbsSampler sampler) { 20 | this.config = config; 21 | this.data = sampler.getDataset(); 22 | } 23 | 24 | @Override 25 | public void setSampler(LDAGibbsSampler sampler) { 26 | this.sampler = sampler; 27 | } 28 | 29 | @Override 30 | public synchronized void calculateBatch() { 31 | documentsPerIter = 0; 32 | int corpusSize = data.size(); 33 | // Initializes the batch sizes to be as even as possible 34 | int numBatches = config.getNoBatches(LDAConfiguration.NO_BATCHES_DEFAULT); 35 | int remainder = (corpusSize % numBatches); 36 | int batchSize = (corpusSize / numBatches); 37 | batchSizeArray = new int[numBatches]; 38 | batchStartArray = new int[numBatches]; 39 | for (int b = 0; b < numBatches; b++) { 40 | batchSizeArray[b] = batchSize + (remainder > b ? 
1 : 0); 41 | documentsPerIter += batchSizeArray[b]; 42 | if(b > 0) batchStartArray[b] = batchStartArray[b-1] + batchSizeArray[b-1]; 43 | } 44 | } 45 | 46 | @Override 47 | public synchronized int[][] documentBatches() { 48 | if(documentBatches==null) { 49 | documentBatches = new int[batchSizeArray.length][]; 50 | for (int i = 0; i < documentBatches.length; i++) { 51 | documentBatches[i] = new int[batchSizeArray[i]]; 52 | int idx = batchStartArray[i]; 53 | for (int j = 0; j < documentBatches[i].length; j++) { 54 | documentBatches[i][j] = idx++; 55 | } 56 | } 57 | return documentBatches; 58 | } else { 59 | return documentBatches; 60 | } 61 | } 62 | 63 | @Override 64 | public int getDocResultsSize() { 65 | return config.getResultSize(LDAConfiguration.RESULTS_SIZE_DEFAULT); 66 | } 67 | 68 | /* This implementation has the same number of documents in each iteration 69 | * @see utils.BatchBuilder#getDocumentsInIteration(int) 70 | */ 71 | @Override 72 | public int getDocumentsInIteration(int currentIteration) { 73 | return documentsPerIter; 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/document/FixedSplitBatchBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.document; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | import cc.mallet.topics.LDAGibbsSampler; 5 | import cc.mallet.types.InstanceList; 6 | 7 | /** 8 | * 9 | * This document batch builder builds batches per iteration that are a fixed %'age of the 10 | * total corpus. The percentage is read from the fixed_split_size_doc config 11 | * parameter. This can be an array of percentages (e.g. 0.01), in which case it will take 12 | * these in order: the first percentage the first iteration, the next the next iteration 13 | * and so on. If it is only one value this will be used for every iteration. The documents 14 | * are NOT randomly drawn! It takes X% from the beginning of the corpus and then the next 15 | * X% of the corpus and so on.
16 | * 17 | */ 18 | public class FixedSplitBatchBuilder implements DocumentBatchBuilder { 19 | 20 | LDAGibbsSampler sampler; 21 | LDAConfiguration config; 22 | InstanceList data; 23 | 24 | int[] batchSizeArray; 25 | int [][] documentBatches; 26 | int numBatches; 27 | 28 | int documentsInIter; 29 | double [] percentages; 30 | // We are looping over the different percentages 31 | // percentagePointer index is the current position 32 | int percentagePointer = 0; 33 | // Points to the next document to be added to a batch 34 | int globalDocPointer = 1; 35 | 36 | public FixedSplitBatchBuilder(LDAConfiguration config, LDAGibbsSampler sampler) { 37 | this.config = config; 38 | this.data = sampler.getDataset(); 39 | percentages = config.getFixedSplitSizeDoc(); 40 | numBatches = config.getNoBatches(LDAConfiguration.NO_BATCHES_DEFAULT); 41 | if(percentages==null||percentages.length==0) { 42 | throw new IllegalArgumentException("Using 'fixed_split_size_doc' but did not find valid config for it."); 43 | } 44 | } 45 | 46 | @Override 47 | public void setSampler(LDAGibbsSampler sampler) { 48 | this.sampler = sampler; 49 | } 50 | 51 | @Override 52 | public synchronized void calculateBatch() { 53 | int corpusSize = data.size(); 54 | documentsInIter = (int) Math.ceil(percentages[percentagePointer]*corpusSize); 55 | percentagePointer = (percentagePointer+1) % percentages.length; 56 | 57 | // Distribute the documents evenly over the processors 58 | int remainder = (documentsInIter % numBatches); 59 | int batchSize = (documentsInIter / numBatches); 60 | batchSizeArray = new int[numBatches]; 61 | for (int b = 0; b < numBatches; b++) { 62 | batchSizeArray[b] = batchSize + (remainder > b ? 1 : 0); 63 | } 64 | 65 | documentBatches = new int[numBatches][]; 66 | for (int i = 0; i < documentBatches.length; i++) { 67 | documentBatches[i] = new int[batchSizeArray[i]]; 68 | for (int j = 0; j < documentBatches[i].length; j++) { 69 | documentBatches[i][j] = globalDocPointer++; 70 | // If we have reached the end of the corpus, start over 71 | if(globalDocPointer>=corpusSize) { 72 | globalDocPointer = 0; 73 | } 74 | } 75 | } 76 | } 77 | 78 | @Override 79 | public synchronized int[][] documentBatches() { 80 | return documentBatches; 81 | } 82 | 83 | @Override 84 | public int getDocResultsSize() { 85 | return config.getResultSize(LDAConfiguration.RESULTS_SIZE_DEFAULT); 86 | } 87 | 88 | @Override 89 | public int getDocumentsInIteration(int currentIteration) { 90 | return documentsInIter; 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/document/PercentageBatchBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.document; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | import cc.mallet.topics.LDAGibbsSampler; 5 | import cc.mallet.util.IndexSampler; 6 | import cc.mallet.util.WithoutReplacementSampler; 7 | 8 | public class PercentageBatchBuilder implements DocumentBatchBuilder { 9 | 10 | LDAConfiguration config; 11 | int numBatches; 12 | int docsToSamplePerIteration; 13 | int docsPerBatch; 14 | int remainder; 15 | int totalDocsAvailable; 16 | double docsPercentage = 1.0; 17 | 18 | // The document batches 19 | int [][] docBatches; 20 | LDAGibbsSampler sampler; 21 | 22 | public PercentageBatchBuilder(LDAConfiguration config, LDAGibbsSampler sampler) { 23 | this.config = config; 24 | this.sampler = sampler; 25 | 26 | // Document part 27 | 
this.totalDocsAvailable = sampler.getDataset().size(); 28 | 29 | this.docsPercentage = config.getDocPercentageSplitSize(); 30 | // Calculate the number of docs to sample per batch 31 | calculateDocumentsToSample(config, docsPercentage); 32 | } 33 | 34 | protected void calculateDocumentsToSample(LDAConfiguration config, double docsPercentage) { 35 | this.docsToSamplePerIteration = (int) Math.ceil(docsPercentage * totalDocsAvailable); 36 | this.numBatches = config.getNoBatches(LDAConfiguration.NO_BATCHES_DEFAULT); 37 | this.docBatches = new int[numBatches][]; 38 | // The number of docs should be evenly split over the batches 39 | this.docsPerBatch = docsToSamplePerIteration / numBatches; 40 | // We split the remainder evenly over the batches 41 | this.remainder = docsToSamplePerIteration % numBatches; 42 | } 43 | 44 | @Override 45 | public void setSampler(LDAGibbsSampler sampler) { 46 | this.sampler = sampler; 47 | } 48 | 49 | @Override 50 | public void calculateBatch() { 51 | IndexSampler is = new WithoutReplacementSampler(0, totalDocsAvailable); 52 | 53 | for(int i = 0; i < numBatches; i++) { 54 | // This is needed since we split the remainder evenly over the batches 55 | int docsInBatch = docsPerBatch + (remainder > i ? 1 : 0); 56 | int [] indices = new int[docsInBatch]; 57 | 58 | // Sample w/o replacement 59 | for (int j = 0; j < docsInBatch; j++) { 60 | indices[j] = is.nextSample(); 61 | } 62 | docBatches[i] = indices; 63 | } 64 | } 65 | 66 | int calcDocsToSample() { 67 | return docsToSamplePerIteration; 68 | } 69 | 70 | @Override 71 | public int[][] documentBatches() { 72 | return docBatches; 73 | } 74 | 75 | @Override 76 | public int getDocResultsSize() { 77 | return config.getResultSize(LDAConfiguration.RESULTS_SIZE_DEFAULT); 78 | } 79 | 80 | @Override 81 | public int getDocumentsInIteration(int currentIteration) { 82 | return calcDocsToSample(); 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/topic/AllWordsTopicIndexBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.topic; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | import cc.mallet.topics.LDAGibbsSampler; 5 | 6 | public class AllWordsTopicIndexBuilder implements TopicIndexBuilder { 7 | 8 | LDAConfiguration config; 9 | LDAGibbsSampler sampler; 10 | int[][] topicTypeIndices; 11 | 12 | public AllWordsTopicIndexBuilder(LDAConfiguration config, LDAGibbsSampler sampler) { 13 | this.config = config; 14 | this.sampler = sampler; 15 | //topicTypeIndices = AllWordsTopicIndexBuilder.getAllIndicesMatrix(sampler.getAlphabet().size(), sampler.getNoTopics()); 16 | } 17 | 18 | /** 19 | * Sample all words 20 | */ 21 | @Override 22 | public synchronized int[][] getTopicTypeIndices() { 23 | //return topicTypeIndices; 24 | // null means sample all 25 | return null; 26 | } 27 | 28 | public static int[][] getAllIndicesMatrix(int vocabSize, int noTopics) { 29 | int [] indicesToSample = new int[vocabSize]; 30 | for (int i = 0; i < indicesToSample.length; i++) { 31 | indicesToSample[i] = i; 32 | } 33 | int [][] topicTypeIndices = new int [noTopics][]; 34 | // In the basic version we sample the same tokens (words) in all the topics 35 | for (int i = 0; i < topicTypeIndices.length; i++) { 36 | topicTypeIndices[i] = indicesToSample; 37 | } 38 | return topicTypeIndices; 39 | } 40 | 41 | } 42 | 
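Editor's note: AllWordsTopicIndexBuilder returns null from getTopicTypeIndices() to signal "sample every type in every topic"; its static getAllIndicesMatrix helper builds the explicit form of that answer. A short illustrative snippet (not part of the repository):

import java.util.Arrays;

import cc.mallet.topics.randomscan.topic.AllWordsTopicIndexBuilder;

public class AllIndicesSketch {
    public static void main(String[] args) {
        // One row per topic; every row is the same shared array of all type indices.
        int[][] m = AllWordsTopicIndexBuilder.getAllIndicesMatrix(4, 2);
        System.out.println(Arrays.deepToString(m)); // [[0, 1, 2, 3], [0, 1, 2, 3]]
    }
}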
-------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/topic/DeltaNTopicIndexBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.topic; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | import cc.mallet.topics.LDAGibbsSampler; 5 | 6 | public class DeltaNTopicIndexBuilder implements TopicIndexBuilder { 7 | 8 | LDAConfiguration config; 9 | LDAGibbsSampler sampler; 10 | int instabilityPeriod = 0; 11 | int fullPhiPeriod = -1; 12 | AllWordsTopicIndexBuilder allWords; 13 | 14 | public DeltaNTopicIndexBuilder(LDAConfiguration config, LDAGibbsSampler sampler) { 15 | this.config = config; 16 | this.sampler = sampler; 17 | instabilityPeriod = config.getInstabilityPeriod(0); 18 | fullPhiPeriod = config.getFullPhiPeriod(-1); 19 | allWords = new AllWordsTopicIndexBuilder(config,sampler); 20 | } 21 | 22 | /** 23 | * Decide which types (words) to sample in Phi proportional to how much they changed 24 | * in the last update round. 25 | */ 26 | @Override 27 | public int[][] getTopicTypeIndices() { 28 | // If we are in the instable period, sample everything (null means everything) 29 | int currentIteration = sampler.getCurrentIteration(); 30 | if(currentIteration<instabilityPeriod) return null; 31 | 32 | 33 | if(fullPhiPeriod>0 && ((currentIteration % fullPhiPeriod) == 0)) { 34 | return allWords.getTopicTypeIndices(); 35 | } else { 36 | return sampler.getDeltaStatistics(); 37 | } 38 | } 39 | 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/topic/EvenSplitTopicBatchBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.topic; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | import cc.mallet.topics.LDAGibbsSampler; 5 | import cc.mallet.types.InstanceList; 6 | 7 | public class EvenSplitTopicBatchBuilder implements TopicBatchBuilder { 8 | 9 | int[] topicBatchSizeArray; 10 | int[] topicBatchStartArray; 11 | int [][] topicBatches; 12 | LDAGibbsSampler sampler; 13 | LDAConfiguration config; 14 | InstanceList data; 15 | int numTopics = -1; // Make sure things crash if the default is used, fail fast! 16 | 17 | public EvenSplitTopicBatchBuilder(LDAConfiguration config, LDAGibbsSampler sampler) { 18 | this.config = config; 19 | this.data = sampler.getDataset(); 20 | } 21 | 22 | @Override 23 | public void setSampler(LDAGibbsSampler sampler) { 24 | this.sampler = sampler; 25 | } 26 | 27 | @Override 28 | public void calculateBatch() { 29 | int numBatches = config.getNoTopicBatches(LDAConfiguration.NO_TOPIC_BATCHES_DEFAULT); 30 | numTopics = config.getNoTopics(LDAConfiguration.NO_TOPICS_DEFAULT); 31 | topicBatchSizeArray = new int[numBatches]; 32 | topicBatchStartArray = new int[numBatches]; 33 | int topicRemainder = numTopics % numBatches; 34 | int topicBatchSize = numTopics / numBatches; 35 | for (int b = 0; b < numBatches; b++) { 36 | topicBatchSizeArray[b] = topicBatchSize + (topicRemainder > b ?
1 : 0); 37 | if(b > 0) topicBatchStartArray[b] = topicBatchStartArray[b-1] + topicBatchSizeArray[b-1]; 38 | } 39 | } 40 | 41 | @Override 42 | public synchronized int[][] topicBatches() { 43 | if(topicBatches==null) { 44 | topicBatches = new int[topicBatchSizeArray.length][]; 45 | for (int i = 0; i < topicBatches.length; i++) { 46 | topicBatches[i] = new int[topicBatchSizeArray[i]]; 47 | int idx = topicBatchStartArray[i]; 48 | for (int j = 0; j < topicBatches[i].length; j++) { 49 | topicBatches[i][j] = idx++; 50 | } 51 | } 52 | return topicBatches; 53 | } else { 54 | return topicBatches; 55 | } 56 | } 57 | 58 | /** 59 | * This builder samples all topics every iteration, this can be changed in subclasses 60 | * @see cc.mallet.topics.randomscan.topic.TopicBatchBuilder#getTopicsInIteration(int) 61 | */ 62 | @Override 63 | public int getTopicsInIteration(int currentIteration) { 64 | return numTopics; 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/topic/MandelbrotTopicIndexBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.topic; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | import cc.mallet.topics.LDAGibbsSampler; 5 | 6 | public class MandelbrotTopicIndexBuilder implements TopicIndexBuilder { 7 | 8 | LDAConfiguration config; 9 | LDAGibbsSampler sampler; 10 | int instabilityPeriod = 0; 11 | int fullPhiPeriod; 12 | AllWordsTopicIndexBuilder allWords; 13 | double percentToSample; 14 | 15 | public MandelbrotTopicIndexBuilder(LDAConfiguration config, LDAGibbsSampler sampler) { 16 | this.config = config; 17 | this.sampler = sampler; 18 | instabilityPeriod = config.getInstabilityPeriod(0); 19 | fullPhiPeriod = config.getFullPhiPeriod(-1); 20 | percentToSample = config.topTokensToSample(0.2); 21 | allWords = new AllWordsTopicIndexBuilder(config,sampler); 22 | } 23 | 24 | /** 25 | * Samples the top X% ('percent_top_tokens' from config) of the most frequent tokens in the corpus, 26 | * respects the full_phi_period variable 27 | */ 28 | @Override 29 | public int[][] getTopicTypeIndices() { 30 | // If we are in the instable period, sample everything (null means everything) 31 | int currentIteration = sampler.getCurrentIteration(); 32 | if(currentIteration<instabilityPeriod) return null; 33 | 34 | 35 | if(fullPhiPeriod>0 && ((currentIteration % fullPhiPeriod) == 0)) { 36 | return allWords.getTopicTypeIndices(); 37 | } else { 38 | int [] topIndices = sampler.getTopTypeFrequencyIndices(); 39 | 40 | int noToSample = (int) Math.ceil(percentToSample*topIndices.length); 41 | int [] indicesToSample = new int[noToSample]; 42 | for (int i = 0; i < noToSample; i++) { 43 | indicesToSample[i] = topIndices[i]; 44 | } 45 | int [][] topicTypeIndices = new int [sampler.getNoTopics()][]; 46 | // In the basic version we sample the same tokens (words) in all the topics 47 | for (int i = 0; i < topicTypeIndices.length; i++) { 48 | topicTypeIndices[i] = indicesToSample; 49 | } 50 | return topicTypeIndices; 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/topic/MetaTopicIndexBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.topic; 2 | 3 | import java.lang.reflect.InvocationTargetException; 4 | 5 | import cc.mallet.configuration.LDAConfiguration; 6 | import cc.mallet.topics.LDAGibbsSampler; 7 | 8 | public class
MetaTopicIndexBuilder implements TopicIndexBuilder { 9 | 10 | LDAConfiguration config; 11 | LDAGibbsSampler sampler; 12 | int instabilityPeriod = 0; 13 | int fullPhiPeriod; 14 | int builderIdx = 0; 15 | TopicIndexBuilder [] builders; 16 | 17 | public MetaTopicIndexBuilder(LDAConfiguration config, LDAGibbsSampler sampler) { 18 | this.config = config; 19 | this.sampler = sampler; 20 | instabilityPeriod = config.getInstabilityPeriod(0); 21 | fullPhiPeriod = config.getFullPhiPeriod(-1); 22 | String [] subbuilders = config.getSubTopicIndexBuilders(-1); 23 | initBuilders(subbuilders); 24 | } 25 | 26 | @SuppressWarnings({ "rawtypes", "unchecked" }) 27 | private void initBuilders(String[] subbuilders) { 28 | builders = new TopicIndexBuilder[subbuilders.length]; 29 | int cnt = 0; 30 | for (String builderName : subbuilders) { 31 | 32 | Class topicIndexBuilderClass = null; 33 | try { 34 | topicIndexBuilderClass = Class.forName(builderName); 35 | } catch (ClassNotFoundException e) { 36 | e.printStackTrace(); 37 | throw new IllegalArgumentException(e); 38 | } 39 | 40 | Class[] argumentTypes = new Class[2]; 41 | argumentTypes[0] = LDAConfiguration.class; 42 | argumentTypes[1] = LDAGibbsSampler.class; 43 | 44 | try { 45 | builders[cnt++] = (TopicIndexBuilder) topicIndexBuilderClass.getDeclaredConstructor(argumentTypes) 46 | .newInstance(config,sampler); 47 | } catch (InstantiationException | IllegalAccessException 48 | | InvocationTargetException 49 | | NoSuchMethodException | SecurityException e) { 50 | e.printStackTrace(); 51 | throw new IllegalArgumentException(e); 52 | } 53 | } 54 | } 55 | 56 | /** 57 | * Loop over the different builders and call them one by one 58 | */ 59 | @Override 60 | public int[][] getTopicTypeIndices() { 61 | // If we are in the instable period, sample everything (null means everything) 62 | int currentIteration = sampler.getCurrentIteration(); 63 | if(currentIterationfull_phi_period 28 | * variable 29 | */ 30 | @Override 31 | public int[][] getTopicTypeIndices() { 32 | // If we are in the instable period, sample everything (null means everything) 33 | int currentIteration = sampler.getCurrentIteration(); 34 | if(currentIterationpercentage_split_size_topic 11 | * config parameter. 
12 | * 13 | * Config Example: 14 | * percentage_split_size_topic = 0.01 # Samples 1 % of (the topics) rows of Phi 15 | * 16 | */ 17 | public class PercentageTopicBatchBuilder implements TopicBatchBuilder { 18 | 19 | LDAConfiguration config;; 20 | int numTopicBatches; 21 | int numTopics; 22 | int topicsToSamplePerIteration; 23 | int remainder; 24 | int topicRemainder; 25 | int topicsPerBatch; 26 | double phiPercentage = 1.0; 27 | 28 | // The topic batches 29 | int [][] topicBatches; 30 | LDAGibbsSampler sampler; 31 | 32 | public PercentageTopicBatchBuilder(LDAConfiguration config, LDAGibbsSampler sampler) { 33 | this.config = config; 34 | this.sampler = sampler; 35 | this.phiPercentage = config.getTopicPercentageSplitSize(); 36 | this.numTopicBatches = config.getNoTopicBatches(LDAConfiguration.NO_TOPIC_BATCHES_DEFAULT); 37 | this.numTopics = config.getNoTopics(LDAConfiguration.NO_TOPICS_DEFAULT); 38 | 39 | // Calculate the number of topics to sample per batch 40 | this.topicsToSamplePerIteration = (int) Math.ceil(phiPercentage * numTopics); 41 | this.topicsPerBatch = topicsToSamplePerIteration / numTopicBatches; 42 | this.topicRemainder = topicsToSamplePerIteration % numTopicBatches; 43 | this.topicBatches = new int[numTopicBatches][]; 44 | } 45 | 46 | @Override 47 | public void setSampler(LDAGibbsSampler sampler) { 48 | this.sampler = sampler; 49 | } 50 | 51 | @Override 52 | public void calculateBatch() { 53 | IndexSampler is = new WithoutReplacementSampler(0, numTopics); 54 | 55 | for (int b = 0; b < numTopicBatches; b++) { 56 | int topicsInBatch = topicsPerBatch + (topicRemainder > b ? 1 : 0); 57 | int [] topicIndices = new int[topicsInBatch]; 58 | for (int j = 0; j < topicsInBatch; j++) { 59 | topicIndices[j] = is.nextSample(); 60 | } 61 | topicBatches[b] = topicIndices; 62 | } 63 | } 64 | 65 | @Override 66 | public int[][] topicBatches() { 67 | return topicBatches; 68 | } 69 | 70 | @Override 71 | public int getTopicsInIteration(int currentIteration) { 72 | return topicsToSamplePerIteration; 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/topic/ProportionalTopicIndexBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.topic; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | import cc.mallet.topics.LDAGibbsSampler; 5 | import cc.mallet.util.SystematicSampling; 6 | 7 | public class ProportionalTopicIndexBuilder implements TopicIndexBuilder { 8 | 9 | LDAConfiguration config; 10 | LDAGibbsSampler sampler; 11 | int instabilityPeriod = 0; 12 | int fullPhiPeriod; 13 | AllWordsTopicIndexBuilder allWords; 14 | int skipStep=1; 15 | 16 | public ProportionalTopicIndexBuilder(LDAConfiguration config, LDAGibbsSampler sampler) { 17 | this.config = config; 18 | this.sampler = sampler; 19 | instabilityPeriod = config.getInstabilityPeriod(0); 20 | fullPhiPeriod = config.getFullPhiPeriod(-1); 21 | allWords = new AllWordsTopicIndexBuilder(config,sampler); 22 | skipStep = config.getProportionalTopicIndexBuilderSkipStep(); 23 | } 24 | 25 | /** 26 | * Samples the types in the corpus proportional to their frequency in the corpus 27 | * using systematic sampling. 
28 | * Respects the full_phi_period 29 | * Respects the instabilityPeriod 30 | */ 31 | @Override 32 | public int[][] getTopicTypeIndices() { 33 | // If we are in the instable period, sample everything (null means everything) 34 | int currentIteration = sampler.getCurrentIteration(); 35 | if(currentIteration<instabilityPeriod) return null; 36 | 37 | 38 | if(fullPhiPeriod>0 && ((currentIteration % fullPhiPeriod) == 0)) { 39 | return allWords.getTopicTypeIndices(); 40 | } else { 41 | //typeCounts 42 | int [] typeFreqs = sampler.getTypeFrequencies(); 43 | int [] indicesToSample = SystematicSampling.sample(typeFreqs, skipStep); 44 | 45 | int [][] topicTypeIndices = new int [sampler.getNoTopics()][]; 46 | for (int i = 0; i < topicTypeIndices.length; i++) { 47 | topicTypeIndices[i] = indicesToSample; 48 | } 49 | return topicTypeIndices; 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/topic/TopWordsRandomFractionTopicIndexBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.topic; 2 | 3 | import org.apache.commons.math3.distribution.BetaDistribution; 4 | 5 | import cc.mallet.configuration.LDAConfiguration; 6 | import cc.mallet.topics.LDAGibbsSampler; 7 | 8 | public class TopWordsRandomFractionTopicIndexBuilder implements TopicIndexBuilder { 9 | 10 | LDAConfiguration config; 11 | LDAGibbsSampler sampler; 12 | int instabilityPeriod = 0; 13 | double a = 2.0; 14 | double b = 5.0; 15 | double finalA = 5.0; 16 | double finalB = 0.05; 17 | double ainc; 18 | double binc; 19 | int fullPhiPeriod; 20 | AllWordsTopicIndexBuilder allWords; 21 | 22 | public TopWordsRandomFractionTopicIndexBuilder(LDAConfiguration config, LDAGibbsSampler sampler) { 23 | this.config = config; 24 | this.sampler = sampler; 25 | instabilityPeriod = config.getInstabilityPeriod(0); 26 | fullPhiPeriod = config.getFullPhiPeriod(-1); 27 | allWords = new AllWordsTopicIndexBuilder(config,sampler); 28 | int noIter = 200; 29 | ainc = (finalA-a)/noIter; 30 | binc = (b-finalB)/noIter; 31 | } 32 | 33 | /** 34 | * Decide which types (words) to sample in Phi proportional to the corpus frequency of the type. 35 | * Types that have high frequency should be sampled more often but according 36 | * to the random scan contract ALL types MUST have a small probability to be 37 | * sampled. We draw the proportion of types to sample from a Beta distribution 38 | * with a mode centered on 20 % from the start which tends towards 100% as the number 39 | * of iterations tends towards Inf. In the beginning we will sample the 20% most probable 40 | * words, but sometimes we will sample them all.
After a while we always sample full Phi 41 | * A Beta(2.0,5.0) will have the mode (a-1) / (a+b-2) = 0.2 = 20% 42 | */ 43 | @Override 44 | public int[][] getTopicTypeIndices() { 45 | // If we are in the instable period, sample everything (null means everything) 46 | if(sampler.getCurrentIteration()<instabilityPeriod) return null; 47 | 48 | 49 | if(a>finalA && b<finalB) return null; 50 | 51 | if(fullPhiPeriod>0 && ((sampler.getCurrentIteration() % fullPhiPeriod) == 0)) 52 | return allWords.getTopicTypeIndices(); 53 | 54 | int [] topIndices = sampler.getTopTypeFrequencyIndices(); 55 | 56 | BetaDistribution beta = new BetaDistribution(a,b); 57 | double percentToSample = beta.sample(); 58 | if(a<finalA) 59 | a += ainc; 60 | if(b>finalB) 61 | b -= binc; 62 | 63 | int noToSample = (int) Math.ceil(percentToSample*topIndices.length); 64 | int [] indicesToSample = new int[noToSample]; 65 | for (int i = 0; i < noToSample; i++) { 66 | indicesToSample[i] = topIndices[i]; 67 | } 68 | int [][] topicTypeIndices = new int [sampler.getNoTopics()][]; 69 | // In the basic version we sample the same tokens (words) in all the topics 70 | for (int i = 0; i < topicTypeIndices.length; i++) { 71 | topicTypeIndices[i] = indicesToSample; 72 | } 73 | return topicTypeIndices; 74 | } 75 | 76 | } 77 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/topic/TopicBatchBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.topic; 2 | 3 | import cc.mallet.topics.LDAGibbsSampler; 4 | 5 | 6 | /** 7 | * In general batch builders are NOT thread safe, they are intended to be called only from 8 | * the coordinator thread! 9 | * 10 | * 11 | */ 12 | public interface TopicBatchBuilder { 13 | 14 | /** 15 | * Do the calculation of the batch size, this Algorithm can vary depending on scheme 16 | */ 17 | void calculateBatch(); 18 | 19 | /** 20 | * The result is a matrix that contains the topic indices to sample for each worker 21 | * @return a matrix A indexed by A[batch][topicIdx] 22 | */ 23 | int [][] topicBatches(); 24 | 25 | 26 | /** 27 | * @param currentIteration 28 | * @return how many topics should be sampled during this iteration 29 | */ 30 | int getTopicsInIteration(int currentIteration); 31 | 32 | 33 | /** 34 | * Sets the sampler that wants to do random scan, we might need various statistics from 35 | * the sampler to decide which documents, types and topics to sample 36 | * @param sampler 37 | */ 38 | void setSampler(LDAGibbsSampler sampler); 39 | } 40 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/topic/TopicBatchBuilderFactory.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.topic; 2 | 3 | import java.lang.reflect.InvocationTargetException; 4 | 5 | import cc.mallet.configuration.LDAConfiguration; 6 | import cc.mallet.topics.LDAGibbsSampler; 7 | 8 | public class TopicBatchBuilderFactory { 9 | 10 | public static final String EVEN_SPLIT = "cc.mallet.topics.randomscan.topic.EvenSplitTopicBatchBuilder"; 11 | 12 | public TopicBatchBuilderFactory() { 13 | } 14 | 15 | @SuppressWarnings("unchecked") 16 | public static synchronized TopicBatchBuilder get(LDAConfiguration config, LDAGibbsSampler sampler) { 17 | String building_scheme = config.getTopicBatchBuildingScheme(LDAConfiguration.TOPIC_BATCH_BUILD_SCHEME_DEFAULT); 18 | 19 | @SuppressWarnings("rawtypes") 20 | Class batchBuilderClass = null; 21 | try { 22 | batchBuilderClass =
Class.forName(building_scheme); 23 | } catch (ClassNotFoundException e) { 24 | e.printStackTrace(); 25 | throw new IllegalArgumentException(e); 26 | } 27 | 28 | @SuppressWarnings("rawtypes") 29 | Class[] argumentTypes = new Class[2]; 30 | argumentTypes[0] = LDAConfiguration.class; 31 | argumentTypes[1] = LDAGibbsSampler.class; 32 | 33 | try { 34 | return (TopicBatchBuilder) batchBuilderClass.getDeclaredConstructor(argumentTypes) 35 | .newInstance(config,sampler); 36 | } catch (InstantiationException | IllegalAccessException 37 | | InvocationTargetException 38 | | NoSuchMethodException | SecurityException e) { 39 | e.printStackTrace(); 40 | throw new IllegalArgumentException(e); 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/topic/TopicIndexBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.topic; 2 | 3 | public interface TopicIndexBuilder { 4 | 5 | /** 6 | * A matrix that contains the types to sample for each topic 7 | * @return an matrix A indexed by A[topic][typeIdx], null means sample ALL types for all topics 8 | */ 9 | int[][] getTopicTypeIndices(); 10 | 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/randomscan/topic/TopicIndexBuilderFactory.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.randomscan.topic; 2 | 3 | import java.lang.reflect.InvocationTargetException; 4 | 5 | import cc.mallet.configuration.LDAConfiguration; 6 | import cc.mallet.topics.LDAGibbsSampler; 7 | 8 | public class TopicIndexBuilderFactory { 9 | 10 | public static final String ALL = "cc.mallet.topics.randomscan.topic.AllWordsTopicIndexBuilder"; 11 | public static final String ADAPTIVE_BETA_MIX = "cc.mallet.topics.randomscan.topic.TopWordsRandomFractionTopicIndexBuilder"; 12 | 13 | public TopicIndexBuilderFactory() { 14 | } 15 | 16 | @SuppressWarnings("unchecked") 17 | public static synchronized TopicIndexBuilder get(LDAConfiguration config, LDAGibbsSampler sampler) { 18 | String building_scheme = config.getTopicIndexBuildingScheme(LDAConfiguration.TOPIC_INDEX_BUILD_SCHEME_DEFAULT); 19 | 20 | @SuppressWarnings("rawtypes") 21 | Class topicIndexBuilderClass = null; 22 | try { 23 | topicIndexBuilderClass = Class.forName(building_scheme); 24 | } catch (ClassNotFoundException e) { 25 | e.printStackTrace(); 26 | throw new IllegalArgumentException(e); 27 | } 28 | 29 | @SuppressWarnings("rawtypes") 30 | Class[] argumentTypes = new Class[2]; 31 | argumentTypes[0] = LDAConfiguration.class; 32 | argumentTypes[1] = LDAGibbsSampler.class; 33 | 34 | try { 35 | return (TopicIndexBuilder) topicIndexBuilderClass.getDeclaredConstructor(argumentTypes) 36 | .newInstance(config,sampler); 37 | } catch (InstantiationException | IllegalAccessException 38 | | InvocationTargetException 39 | | NoSuchMethodException | SecurityException e) { 40 | e.printStackTrace(); 41 | throw new IllegalArgumentException(e); 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/topics/tui/IterationListener.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.tui; 2 | 3 | import cc.mallet.topics.LDAGibbsSampler; 4 | 5 | public interface IterationListener { 6 | void 
iterationCallback(LDAGibbsSampler model); 7 | } 8 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/types/ConditionalDirichlet.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.types; 2 | 3 | import java.util.Arrays; 4 | 5 | import cc.mallet.util.ParallelRandoms; 6 | 7 | public class ConditionalDirichlet extends ParallelDirichlet{ 8 | 9 | public ConditionalDirichlet(Alphabet dict, double alpha) { 10 | super(dict, alpha); 11 | // TODO Auto-generated constructor stub 12 | } 13 | 14 | public ConditionalDirichlet(Alphabet dict) { 15 | super(dict); 16 | // TODO Auto-generated constructor stub 17 | } 18 | 19 | public ConditionalDirichlet(double m, double[] p) { 20 | super(m, p); 21 | // TODO Auto-generated constructor stub 22 | } 23 | 24 | public ConditionalDirichlet(double[] alphas, Alphabet dict) { 25 | super(alphas, dict); 26 | // TODO Auto-generated constructor stub 27 | } 28 | 29 | public ConditionalDirichlet(double[] p) { 30 | super(p); 31 | // TODO Auto-generated constructor stub 32 | } 33 | 34 | public ConditionalDirichlet(int size, double alpha) { 35 | super(size, alpha); 36 | // TODO Auto-generated constructor stub 37 | } 38 | 39 | public ConditionalDirichlet(int size) { 40 | super(size); 41 | // TODO Auto-generated constructor stub 42 | } 43 | 44 | 45 | /** 46 | * Draw a conditional Dirichlet distribution. This version MODIFIES the input 47 | * argument Phi and updates the indices in phi_index with new draws 48 | * 49 | * @param phi Previous Dirichlet draw 50 | * @param phi_index Part of phi to produce new draw conditional draw for. 51 | * 52 | */ 53 | public void setNextConditionalDistribution(double[] phi, int[] phi_index) { 54 | // For each dimension in phi_index, draw a sample from Gamma(mp_i, 1) 55 | double sum_gamma = 0; 56 | double sum_phi = 0; 57 | for (int i = 0; i < phi_index.length; i++) { 58 | sum_phi += phi[phi_index[i]]; 59 | // Now phi in phi_index contain gammas. 60 | phi[phi_index[i]] = ParallelRandoms.rgamma(partition[phi_index[i]] * magnitude, 1, 0); 61 | if (phi[phi_index[i]] <= 0) { 62 | phi[phi_index[i]] = 0.0001; 63 | } 64 | sum_gamma += phi[phi_index[i]]; 65 | } 66 | 67 | // Normalize part of dirichlet 68 | for (int i = 0; i < phi_index.length; i++) { 69 | phi[phi_index[i]] = (phi[phi_index[i]] / sum_gamma) * sum_phi; 70 | } 71 | } 72 | 73 | /** 74 | * Draw a conditional Dirichlet distribution. 75 | * 76 | * @param phi Previous Dirichlet draw 77 | * @param phi_index Part of phi to produce new draw conditional draw for. 78 | * 79 | */ 80 | public double [] nextConditionalDistribution(double[] phiArg, int[] phi_index) { 81 | // Create the resulting Phi 82 | double [] phi = Arrays.copyOf(phiArg, phiArg.length); 83 | // For each dimension in phi_index, draw a sample from Gamma(mp_i, 1) 84 | double sum_gamma = 0; 85 | double sum_phi = 0; 86 | for (int i = 0; i < phi_index.length; i++) { 87 | sum_phi += phi[phi_index[i]]; 88 | // Now phi in phi_index contain gammas. 
89 | phi[phi_index[i]] = ParallelRandoms.rgamma(partition[phi_index[i]] * magnitude, 1, 0); 90 | if (phi[phi_index[i]] <= 0) { 91 | phi[phi_index[i]] = 0.0001; 92 | } 93 | sum_gamma += phi[phi_index[i]]; 94 | } 95 | 96 | // Normalize part of dirichlet 97 | for (int i = 0; i < phi_index.length; i++) { 98 | phi[phi_index[i]] = (phi[phi_index[i]] / sum_gamma) * sum_phi; 99 | } 100 | return phi; 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/types/DefaultSparseDirichletSamplerBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.types; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | 5 | public class DefaultSparseDirichletSamplerBuilder extends StandardArgsDirichletBuilder { 6 | 7 | @Override 8 | protected String getSparseDirichletSamplerClassName() { 9 | return LDAConfiguration.SPARSE_DIRICHLET_SAMPLER_DEFAULT; 10 | } 11 | 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/types/MarsagliaSparseDirichlet.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.types; 2 | 3 | import cc.mallet.util.ParallelRandoms; 4 | 5 | public class MarsagliaSparseDirichlet extends ParallelDirichlet implements SparseDirichlet { 6 | double [] cs; 7 | double [] ds; 8 | 9 | public MarsagliaSparseDirichlet(double[] prior) { 10 | super(prior); 11 | cs = new double[prior.length]; 12 | ds = new double[prior.length]; 13 | for (int idx = 0; idx < prior.length; idx++) { 14 | double [] params = ParallelRandoms.preCalcParams(partition[idx]*magnitude); 15 | cs[idx] = params[0]; 16 | ds[idx] = params[1]; 17 | } 18 | } 19 | 20 | public MarsagliaSparseDirichlet(int size, double prior) { 21 | super(size, prior); 22 | cs = new double[size]; 23 | ds = new double[size]; 24 | double [] params = ParallelRandoms.preCalcParams(partition[0]*magnitude); 25 | for (int idx = 0; idx < size; idx++) { 26 | cs[idx] = params[0]; 27 | ds[idx] = params[1]; 28 | } 29 | } 30 | 31 | public double[] nextDistribution(int [] counts) { 32 | double distribution[] = new double[partition.length]; 33 | 34 | double sum = 0; 35 | for (int i=0; i0) { 34 | for (int i=0; i 0) { 22 | u-=probs[category]; 23 | category++; 24 | } 25 | res[category-1]++; 26 | } 27 | return res; 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/types/SparseDirichlet.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.types; 2 | 3 | public interface SparseDirichlet { 4 | 5 | public double[] nextDistribution(); 6 | public double[] nextDistribution(int [] counts); 7 | public VSResult nextDistributionWithSparseness(); 8 | public VSResult nextDistributionWithSparseness(int [] counts); 9 | public VSResult nextDistributionWithSparseness(double prior); 10 | public int[] updateDistributionWithSparseness(double [] target, double prior); 11 | 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/types/SparseDirichletSamplerBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.types; 2 | 3 | import cc.mallet.topics.LDAGibbsSampler; 4 | 5 | public interface SparseDirichletSamplerBuilder { 6 | SparseDirichlet build(LDAGibbsSampler sampler); 7 | } 8 | 
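Editor's note: SparseDirichletSamplerBuilder decouples the samplers from the concrete SparseDirichlet implementation. A usage sketch follows (illustrative only; it assumes an existing LDAGibbsSampler instance, from which build() reads the number of types and beta, as shown in StandardArgsDirichletBuilder in the next file).

import cc.mallet.topics.LDAGibbsSampler;
import cc.mallet.types.DefaultSparseDirichletSamplerBuilder;
import cc.mallet.types.SparseDirichlet;
import cc.mallet.types.SparseDirichletSamplerBuilder;

public class DirichletDrawSketch {
    // Draws one row of phi for a topic from its per-type counts plus the beta prior.
    public static double[] samplePhiRow(LDAGibbsSampler sampler, int[] typeCountsForTopic) {
        SparseDirichletSamplerBuilder builder = new DefaultSparseDirichletSamplerBuilder();
        SparseDirichlet dirichlet = builder.build(sampler);
        return dirichlet.nextDistribution(typeCountsForTopic);
    }
}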
-------------------------------------------------------------------------------- /src/main/java/cc/mallet/types/StandardArgsDirichletBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.types; 2 | 3 | import java.lang.reflect.InvocationTargetException; 4 | 5 | import cc.mallet.configuration.LDAConfiguration; 6 | import cc.mallet.topics.LDAGibbsSampler; 7 | 8 | public abstract class StandardArgsDirichletBuilder implements SparseDirichletSamplerBuilder { 9 | 10 | @Override 11 | public SparseDirichlet build(LDAGibbsSampler sampler) { 12 | return instantiateSparseDirichletSampler(getSparseDirichletSamplerClassName(), 13 | sampler.getNoTypes(), 14 | sampler.getConfiguration().getBeta(LDAConfiguration.BETA_DEFAULT)); 15 | } 16 | 17 | @SuppressWarnings("unchecked") 18 | protected synchronized SparseDirichlet instantiateSparseDirichletSampler(String samplerClassName, int numTypes, double beta) { 19 | 20 | @SuppressWarnings("rawtypes") 21 | Class modelClass = null; 22 | try { 23 | modelClass = Class.forName(samplerClassName); 24 | } catch (ClassNotFoundException e) { 25 | e.printStackTrace(); 26 | throw new IllegalArgumentException(e); 27 | } 28 | 29 | @SuppressWarnings("rawtypes") 30 | Class[] argumentTypes = new Class[2]; 31 | argumentTypes[0] = int.class; 32 | argumentTypes[1] = double.class; 33 | 34 | try { 35 | return (SparseDirichlet) modelClass.getDeclaredConstructor(argumentTypes).newInstance(numTypes,beta); 36 | } catch (InstantiationException | IllegalAccessException 37 | | InvocationTargetException 38 | | NoSuchMethodException | SecurityException e) { 39 | System.err.println("Could not create sampler: " + samplerClassName); 40 | e.printStackTrace(); 41 | throw new IllegalArgumentException(e); 42 | } 43 | } 44 | 45 | protected abstract String getSparseDirichletSamplerClassName(); 46 | 47 | } 48 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/types/VSResult.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.types; 2 | 3 | public class VSResult implements VariableSelectionResult { 4 | public double [] phiRow; 5 | public int [] nonZeroIdxs; 6 | public VSResult(double[] phiRow, int [] nonZeroIdxs) { 7 | this.phiRow = phiRow; 8 | this.nonZeroIdxs = nonZeroIdxs; 9 | } 10 | @Override 11 | public double[] getPhi() { 12 | return phiRow; 13 | } 14 | @Override 15 | public int[] getNonZeroIdxs() { 16 | int [] res = new int[nonZeroIdxs.length]; 17 | for (int i = 0; i < nonZeroIdxs.length; i++) { 18 | res[i] = nonZeroIdxs[i]; 19 | } 20 | return res; 21 | } 22 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/types/VariableSelectionDirichlet.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.types; 2 | 3 | public interface VariableSelectionDirichlet { 4 | public VariableSelectionResult nextDistribution(int[] counts, double [] previousPhi); 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/types/VariableSelectionResult.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.types; 2 | 3 | public interface VariableSelectionResult { 4 | double[] getPhi(); 5 | int[] getNonZeroIdxs(); 6 | } -------------------------------------------------------------------------------- 
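Editor's note: VSResult pairs a drawn phi row with the indices of its non-zero entries, so callers can skip the zero part of the vocabulary. An illustrative consumer (not part of the repository):

import cc.mallet.types.VSResult;
import cc.mallet.types.VariableSelectionResult;

public class VsResultSketch {
    // Sums only the non-zero entries of the drawn phi row.
    public static double nonZeroMass(VariableSelectionResult result) {
        double[] phiRow = result.getPhi();
        double mass = 0.0;
        for (int typeIdx : result.getNonZeroIdxs()) {
            mass += phiRow[typeIdx];
        }
        return mass;
    }

    public static void main(String[] args) {
        VariableSelectionResult r = new VSResult(new double[] {0.7, 0.0, 0.3, 0.0}, new int[] {0, 2});
        System.out.println(nonZeroMass(r)); // 1.0
    }
}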
/src/main/java/cc/mallet/util/EclipseDetector.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | public class EclipseDetector { 7 | /** 8 | * Checks incoming arguments for '-runningInEclipse || --runningInEclipse' if 9 | * this is detected returns a new String [] with this removed, else returns null 10 | * 11 | * @param args 12 | * @return 13 | */ 14 | public static String [] runningInEclipse(String [] args) { 15 | boolean inEclipse = false; 16 | List resultList = new ArrayList(); 17 | for (int i = 0; i < args.length; i++) { 18 | if(args[i].equals("-runningInEclipse") || args[i].equals("--runningInEclipse")) { 19 | inEclipse = true; 20 | } else { 21 | resultList.add(args[i]); 22 | } 23 | } 24 | if(inEclipse) { 25 | String [] result = new String[resultList.size()]; 26 | return resultList.toArray(result); 27 | } else { 28 | return null; 29 | } 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/GentleAliasMethod.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | import java.util.Random; 6 | import java.util.concurrent.ThreadLocalRandom; 7 | 8 | public class GentleAliasMethod implements WalkerAliasTable { 9 | Random random = new Random(); 10 | 11 | int k; 12 | double [] ps; 13 | int [] a; 14 | 15 | public GentleAliasMethod() { 16 | 17 | } 18 | public GentleAliasMethod(double [] pis, double normalizer) { 19 | generateAliasTable(pis,normalizer); 20 | } 21 | 22 | public GentleAliasMethod(double [] pis) { 23 | generateAliasTable(pis); 24 | } 25 | 26 | @Override 27 | public void initTableNormalizedProbabilities(double[] probabilities) { 28 | generateAliasTable(probabilities); 29 | } 30 | 31 | @Override 32 | public void initTable(double[] probabilities, double normalizer) { 33 | generateAliasTable(probabilities,normalizer); 34 | 35 | } 36 | 37 | public void generateAliasTable(double [] pi, double normalizer) { 38 | k = pi.length; 39 | ps = new double[k]; 40 | double [] b = new double[k]; 41 | List low = new ArrayList<>(); 42 | List high = new ArrayList<>(); 43 | double k1 = 1.0/k; 44 | a = new int [k]; 45 | for (int i = 0; i < k; i++) { 46 | a[i] = i; 47 | b[i] = (pi[i]/normalizer) - k1; 48 | if(b[i]<0.0) { 49 | low.add(i); 50 | } else { 51 | high.add(i); 52 | } 53 | } 54 | int steps = 0; 55 | while(steps<=k&&low.size()>0&&high.size()>0) { 56 | int l = low.remove(0); 57 | int h = high.get(0); 58 | double c=b[l]; 59 | double d=b[h]; 60 | b[l] = 0; 61 | b[h] = c + d; 62 | if(b[h]<=0) {high.remove(0);} 63 | if(b[h]<0) {low.add(h);} 64 | a[l] = h; 65 | ps[l] = 1.0 + ((double) k) * c; 66 | } 67 | } 68 | 69 | @Override 70 | public void reGenerateAliasTable(double[] pi, double normalizer) { 71 | generateAliasTable(pi,normalizer); 72 | } 73 | 74 | public void generateAliasTable(double [] pi) { 75 | generateAliasTable(pi,1.0); 76 | } 77 | 78 | @Override 79 | public int generateSample() { 80 | int i=random.nextInt(k); if (ThreadLocalRandom.current().nextDouble()>ps[i]) i=a[i]; return i; 81 | } 82 | 83 | @Override 84 | public int [] generateSamples(int nrSamples) { 85 | int [] samples = new int[nrSamples]; 86 | for (int i = 0; i < nrSamples; i++) { 87 | double u = ThreadLocalRandom.current().nextDouble(); 88 | samples[i] = generateSample(u); 89 | } 90 | return samples; 91 
| } 92 | 93 | @Override 94 | public int generateSample(double u) { 95 | int i=random.nextInt(k); if (u>ps[i]) i=a[i]; return i; 96 | } 97 | 98 | public static void main(String [] args) { 99 | GentleAliasMethod ga = new GentleAliasMethod(); 100 | double [] pi = {0.3, 0.05, 0.2, 0.4, 0.05}; 101 | ga.generateAliasTable(pi); 102 | int [] counts = new int[pi.length]; 103 | int noSamples = 10_000_000; 104 | for(int i = 0; i memusage, 29 | String runner, String logfilename, String heading, String iterationType, int iterations, 30 | List metadata) throws Exception; 31 | 32 | void logTimings(String logfilename, String logdir) throws FileNotFoundException; 33 | 34 | PrintWriter checkCreateAndCreateLogPrinter(String dir,String filename); 35 | PrintWriter getLogPrinter(String filename); 36 | PrintWriter getAppendingLogPrinter(String filename); 37 | PrintStream getLogPrintStream(String filename,boolean append); 38 | boolean isFileLogger(); 39 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/LDAThreadFactory.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | import java.util.concurrent.ThreadFactory; 4 | import java.util.concurrent.atomic.AtomicInteger; 5 | 6 | public class LDAThreadFactory implements ThreadFactory { 7 | private static final AtomicInteger poolNumber = new AtomicInteger(1); 8 | private final ThreadGroup group; 9 | private final AtomicInteger threadNumber = new AtomicInteger(1); 10 | private final String namePrefix; 11 | 12 | public LDAThreadFactory(String namePrePrefix) { 13 | SecurityManager s = System.getSecurityManager(); 14 | group = (s != null) ? s.getThreadGroup() : 15 | Thread.currentThread().getThreadGroup(); 16 | namePrefix = namePrePrefix + "-" + 17 | poolNumber.getAndIncrement() + 18 | "-thread-"; 19 | } 20 | 21 | 22 | @Override 23 | public Thread newThread(Runnable r) { 24 | Thread t = new Thread(group, r, 25 | namePrefix + threadNumber.getAndIncrement(), 26 | 0); 27 | 28 | if (t.isDaemon()) 29 | t.setDaemon(false); 30 | // I can't really see that we want non-daemon threads... 31 | // Seems we might get problem with daemon-threads?? Does this mess up join after invoke?? 
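// For context: daemon threads do not keep the JVM alive, so a pool of daemon workers can be
// killed mid-task if the main thread returns early, whereas non-daemon threads keep the JVM
// running until the executor is shut down explicitly. Thread.join() and
// ExecutorService.invokeAll() block the same way in either case, so daemon status should not
// affect join-after-invoke, only JVM shutdown behaviour.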
32 | // t.setDaemon(true); 33 | if (t.getPriority() != Thread.NORM_PRIORITY) 34 | t.setPriority(Thread.NORM_PRIORITY); 35 | return t; 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/MalletTopicIndicatorLogger.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | import java.io.BufferedWriter; 4 | import java.io.File; 5 | import java.io.FileWriter; 6 | import java.io.PrintWriter; 7 | import java.util.ArrayList; 8 | 9 | import cc.mallet.configuration.LDAConfiguration; 10 | import cc.mallet.topics.TopicAssignment; 11 | import cc.mallet.types.Alphabet; 12 | import cc.mallet.types.FeatureSequence; 13 | import cc.mallet.types.LabelSequence; 14 | 15 | public class MalletTopicIndicatorLogger implements TopicIndicatorLogger { 16 | 17 | public void log(ArrayList data, LDAConfiguration config, int iteration) { 18 | Alphabet a = data.get(0).instance.getDataAlphabet(); 19 | File ld = config.getLoggingUtil().getLogDir(); 20 | File z_file = new File(ld.getAbsolutePath() + "/z_" + iteration + ".csv"); 21 | try (FileWriter fw = new FileWriter(z_file, false); 22 | BufferedWriter bw = new BufferedWriter(fw); 23 | PrintWriter pw = new PrintWriter(bw)) { 24 | pw.println ("#doc source pos typeindex type topic"); 25 | for (int di = 0; di < data.size(); di++) { 26 | FeatureSequence fs = (FeatureSequence) data.get(di).instance.getData(); 27 | LabelSequence topicSequence = 28 | (LabelSequence) data.get(di).topicSequence; 29 | 30 | String source = "NA"; 31 | if (data.get(di).instance.getSource() != null) { 32 | source = data.get(di).instance.getSource().toString(); 33 | } 34 | 35 | int [] oneDocTopics = topicSequence.getFeatures(); 36 | for (int si = 0; si < fs.size(); si++) { 37 | int type = fs.getIndexAtPosition(si); 38 | pw.print(di); pw.print(' '); 39 | pw.print(source); pw.print(' '); 40 | pw.print(si); pw.print(' '); 41 | pw.print(type); pw.print(' '); 42 | pw.print(a.lookupObject(type)); pw.print(' '); 43 | pw.print(oneDocTopics[si]); pw.println(); 44 | } 45 | } 46 | } catch (Exception e) { 47 | e.printStackTrace(); 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/NullOutputStream.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | import java.io.IOException; 4 | import java.io.OutputStream; 5 | 6 | public class NullOutputStream extends OutputStream { 7 | @Override 8 | public void write(int b) throws IOException { 9 | } 10 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/NullPrintWriter.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | import java.io.File; 4 | import java.io.FileNotFoundException; 5 | import java.io.PrintWriter; 6 | import java.util.Locale; 7 | 8 | public class NullPrintWriter extends PrintWriter { 9 | 10 | public NullPrintWriter() throws FileNotFoundException { 11 | super(new NullOutputStream()); 12 | } 13 | 14 | public NullPrintWriter(File file) throws FileNotFoundException { 15 | super(file); 16 | } 17 | 18 | @Override 19 | public void flush() {} 20 | 21 | @Override 22 | public void close() {} 23 | 24 | @Override 25 | public boolean checkError() { 26 | return false; 27 | } 28 | 29 | @Override 30 | protected void setError() {} 31 | 32 | @Override 33 
| protected void clearError() {} 34 | 35 | @Override 36 | public void write(int c) {} 37 | 38 | @Override 39 | public void write(char[] buf, int off, int len) {} 40 | 41 | @Override 42 | public void write(char[] buf) {} 43 | 44 | @Override 45 | public void write(String s, int off, int len) {} 46 | 47 | @Override 48 | public void write(String s) {} 49 | 50 | @Override 51 | public void print(boolean b) {} 52 | 53 | @Override 54 | public void print(char c) {} 55 | 56 | @Override 57 | public void print(int i) {} 58 | 59 | @Override 60 | public void print(long l) {} 61 | 62 | @Override 63 | public void print(float f) {} 64 | 65 | @Override 66 | public void print(double d) {} 67 | 68 | @Override 69 | public void print(char[] s) {} 70 | 71 | @Override 72 | public void print(String s) {} 73 | 74 | @Override 75 | public void print(Object obj) {} 76 | 77 | @Override 78 | public void println() {} 79 | 80 | @Override 81 | public void println(boolean x) {} 82 | 83 | @Override 84 | public void println(char x) {} 85 | 86 | @Override 87 | public void println(int x) {} 88 | 89 | @Override 90 | public void println(long x) {} 91 | 92 | @Override 93 | public void println(float x) {} 94 | 95 | @Override 96 | public void println(double x) {} 97 | 98 | @Override 99 | public void println(char[] x) {} 100 | 101 | @Override 102 | public void println(String x) {} 103 | 104 | @Override 105 | public void println(Object x) {} 106 | 107 | @Override 108 | public PrintWriter printf(String format, Object... args) { 109 | try { 110 | return new NullPrintWriter(); 111 | } catch (FileNotFoundException e) { 112 | e.printStackTrace(); 113 | } 114 | return null; 115 | } 116 | 117 | @Override 118 | public PrintWriter printf(Locale l, String format, Object... args) { 119 | return printf(format,args); 120 | } 121 | 122 | @Override 123 | public PrintWriter format(String format, Object... args) { 124 | return printf(format,args); 125 | } 126 | 127 | @Override 128 | public PrintWriter format(Locale l, String format, Object... 
args) { 129 | return printf(format,args); 130 | } 131 | 132 | @Override 133 | public PrintWriter append(CharSequence csq) { 134 | try { 135 | return new NullPrintWriter(); 136 | } catch (FileNotFoundException e) { 137 | e.printStackTrace(); 138 | } 139 | return null; 140 | } 141 | 142 | @Override 143 | public PrintWriter append(CharSequence csq, int start, int end) { 144 | return append(csq); 145 | } 146 | 147 | @Override 148 | public PrintWriter append(char c) { 149 | try { 150 | return new NullPrintWriter(); 151 | } catch (FileNotFoundException e) { 152 | e.printStackTrace(); 153 | } 154 | return null; 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/PerplexityDatasetBuilder.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | import gnu.trove.TIntHashSet; 4 | import cc.mallet.types.FeatureSequence; 5 | import cc.mallet.types.Instance; 6 | import cc.mallet.types.InstanceList; 7 | 8 | public class PerplexityDatasetBuilder { 9 | 10 | public static InstanceList [] buildPerplexityDataset(InstanceList instances, int noFolds) { 11 | InstanceList.CrossValidationIterator cviter = instances.crossValidationIterator(noFolds); 12 | InstanceList [] tSets = cviter.nextSplit(); 13 | InstanceList trainingSet = tSets[0]; 14 | InstanceList fullTestSet = tSets[1]; 15 | InstanceList testDocs = new InstanceList(trainingSet.getDataAlphabet(),trainingSet.getTargetAlphabet()); 16 | 17 | // For the perplexity calculations we want to split each "test document" into two parts where half the document 18 | // goes into the training set and the other into the test set (!) 19 | 20 | for(Instance testDoc : fullTestSet) { 21 | FeatureSequence tokens = (FeatureSequence) testDoc.getData(); 22 | 23 | // Select half of the words in the document to use in the test set 24 | int noWordsToSample = tokens.size() / 2; 25 | IndexSampler is = new WithoutReplacementSampler(0,tokens.size()); 26 | TIntHashSet inds = new TIntHashSet(); 27 | int [] testfeatures = new int[noWordsToSample]; 28 | for (int i = 0; i < noWordsToSample; i++) { 29 | int idx = is.nextSample(); 30 | testfeatures[i] = tokens.getIndexAtPosition(idx); 31 | inds.add(idx); 32 | } 33 | FeatureSequence testFs = new FeatureSequence(testDoc.getAlphabet(),testfeatures); 34 | Instance testPart = new Instance(testFs,testDoc.getTarget(), testDoc.getName(), testDoc.getSource()); 35 | testDocs.add(testPart); 36 | 37 | // Select the other half of the test document to put back in the training set 38 | int noWordsLeft = tokens.size()-noWordsToSample; 39 | int [] trainfeatures = new int[noWordsLeft]; 40 | int added = 0; 41 | for (int i = 0; i < tokens.size(); i++) { 42 | if(!inds.contains(i)) { 43 | trainfeatures[added++] = tokens.getIndexAtPosition(i); 44 | } 45 | } 46 | FeatureSequence trainFs = new FeatureSequence(testDoc.getAlphabet(),trainfeatures); 47 | Instance trainPart = new Instance(trainFs,testDoc.getTarget(), testDoc.getName(), testDoc.getSource()); 48 | // Now add this half to the training set 49 | trainingSet.add(trainPart); 50 | } 51 | 52 | InstanceList [] result = new InstanceList[2]; 53 | result[0] = trainingSet; 54 | result[1] = testDocs; 55 | 56 | return result; 57 | } 58 | 59 | } 60 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/ReMappedAliasTable.java: -------------------------------------------------------------------------------- 1 | package 
cc.mallet.util; 2 | 3 | import java.io.Serializable; 4 | import java.util.concurrent.ThreadLocalRandom; 5 | 6 | public class ReMappedAliasTable extends OptimizedGentleAliasMethod implements WalkerAliasTable, Serializable { 7 | private static final long serialVersionUID = 1L; 8 | 9 | int [] mapping; 10 | 11 | public ReMappedAliasTable(int [] mapping) { 12 | this.mapping = mapping; 13 | } 14 | 15 | public ReMappedAliasTable(double [] pis, double normalizer, int [] mapping) { 16 | this.mapping = mapping; 17 | generateAliasTable(pis,normalizer); 18 | } 19 | 20 | public ReMappedAliasTable(double [] pis, int [] mapping) { 21 | this.mapping = mapping; 22 | generateAliasTable(pis); 23 | } 24 | 25 | @Override 26 | public int generateSample() { 27 | double u = ThreadLocalRandom.current().nextDouble(); 28 | return mapping[generateSample(u)]; 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/SparsityTools.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | public class SparsityTools { 4 | public static int remove(int oldTopic, int[] nonZeroTopics, int[] nonZeroTopicsBackMapping, int nonZeroTopicCnt) { 5 | if (nonZeroTopicCnt<1) { 6 | throw new IllegalArgumentException ("Cannot remove, count is less than 1 => " + nonZeroTopicCnt); 7 | } 8 | // We have one less non-zero topic, move the last to its place, and decrease the non-zero count 9 | int nonZeroIdx = nonZeroTopicsBackMapping[oldTopic]; 10 | nonZeroTopics[nonZeroIdx] = nonZeroTopics[--nonZeroTopicCnt]; 11 | nonZeroTopicsBackMapping[nonZeroTopics[nonZeroIdx]] = nonZeroIdx; 12 | return nonZeroTopicCnt; 13 | } 14 | 15 | public static int insert(int newTopic, int[] nonZeroTopics, int[] nonZeroTopicsBackMapping, int nonZeroTopicCnt) { 16 | //// We have a new non-zero topic put it in the last empty slot and increase the count 17 | nonZeroTopics[nonZeroTopicCnt] = newTopic; 18 | nonZeroTopicsBackMapping[newTopic] = nonZeroTopicCnt; 19 | return ++nonZeroTopicCnt; 20 | } 21 | 22 | public static int removeSorted(int oldTopic, int[] nonZeroTopics, int[] nonZeroTopicsBackMapping, int nonZeroTopicCnt) { 23 | if (nonZeroTopicCnt<1) { 24 | throw new IllegalArgumentException ("Cannot remove, count is less than 1"); 25 | } 26 | //System.out.println("New empty topic. Cnt = " + nonZeroTopicCnt); 27 | int nonZeroIdx = nonZeroTopicsBackMapping[oldTopic]; 28 | nonZeroTopicCnt--; 29 | // Shift the ones above one step to the left 30 | for(int i=nonZeroIdx; i nonZeroTopics[slot] && slot < nonZeroTopicCnt) slot++; 43 | 44 | for(int i=nonZeroTopicCnt; i>slot;i--) { 45 | // Move the last non-zero topic to this new empty slot 46 | nonZeroTopics[i] = nonZeroTopics[i-1]; 47 | // Do the corresponding for the back mapping 48 | nonZeroTopicsBackMapping[nonZeroTopics[i]] = i; 49 | } 50 | nonZeroTopics[slot] = newTopic; 51 | nonZeroTopicsBackMapping[newTopic] = slot; 52 | nonZeroTopicCnt++; 53 | return nonZeroTopicCnt; 54 | } 55 | 56 | public static int findIdx(double[] cumsum, double u, int maxIdx) { 57 | if(cumsum.length<2000) { 58 | return findIdxLinSentinel(cumsum,u,maxIdx); 59 | } else { 60 | return findIdxBin(cumsum,u,maxIdx); 61 | } 62 | } 63 | 64 | public static int findIdxBin(double[] cumsum, double u, int maxIdx) { 65 | int slot = java.util.Arrays.binarySearch(cumsum,0,maxIdx,u); 66 | 67 | return slot >= 0 ? 
slot : -(slot+1); 68 | } 69 | 70 | public static int findIdxLinSentinel(double[] cumsum, double u, int maxIdx) { 71 | cumsum[cumsum.length-1] = Double.MAX_VALUE; 72 | int i = 0; 73 | while(true) { 74 | if(u<=cumsum[i]) return i; 75 | i++; 76 | } 77 | } 78 | 79 | public static int findIdxLin(double[] cumsum, double u, int maxIdx) { 80 | for (int i = 0; i < maxIdx; i++) { 81 | if(u<=cumsum[i]) return i; 82 | } 83 | return cumsum.length-1; 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/StandardTopicIndicatorLogger.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | import java.io.BufferedWriter; 4 | import java.io.File; 5 | import java.io.FileWriter; 6 | import java.io.PrintWriter; 7 | import java.util.ArrayList; 8 | 9 | import cc.mallet.configuration.LDAConfiguration; 10 | import cc.mallet.topics.TopicAssignment; 11 | import cc.mallet.types.LabelSequence; 12 | 13 | public class StandardTopicIndicatorLogger implements TopicIndicatorLogger { 14 | public void log(ArrayList data, LDAConfiguration config, int iteration) { 15 | File ld = config.getLoggingUtil().getLogDir(); 16 | File z_file = new File(ld.getAbsolutePath() + "/z_" + iteration + ".csv"); 17 | try (FileWriter fw = new FileWriter(z_file, false); 18 | BufferedWriter bw = new BufferedWriter(fw); 19 | PrintWriter pw = new PrintWriter(bw)) { 20 | for (int docIdx = 0; docIdx < data.size(); docIdx++) { 21 | String szs = ""; 22 | LabelSequence topicSequence = 23 | (LabelSequence) data.get(docIdx).topicSequence; 24 | int [] oneDocTopics = topicSequence.getFeatures(); 25 | for (int i = 0; i < topicSequence.size(); i++) { 26 | szs += oneDocTopics[i] + ","; 27 | } 28 | if(szs.length()>0) { 29 | szs = szs.substring(0, szs.length()-1); 30 | } 31 | pw.println(szs); 32 | } 33 | } catch (Exception e) { 34 | e.printStackTrace(); 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/Stats.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | 4 | public class Stats { 5 | public int iteration; 6 | public String loggingPath; 7 | public long absoluteTime; 8 | public long zSamplingTokenUpdateTime; 9 | public long phiSamplingTime; 10 | public double density; 11 | public double docDensity; 12 | public double phiDensity; 13 | public long [] zTimings; 14 | public long [] countTimings; 15 | public Double heldOutLL; 16 | 17 | public Stats(int iteration, String loggingPath, long absoluteTime, 18 | long zSamplingTokenUpdateTime, long phiSamplingTime, double density, 19 | double docDensity, long [] zTimings, long [] countTimings, double phiDensity) { 20 | this.iteration = iteration; 21 | this.loggingPath = loggingPath; 22 | this.absoluteTime = absoluteTime; 23 | this.zSamplingTokenUpdateTime = zSamplingTokenUpdateTime; 24 | this.phiSamplingTime = phiSamplingTime; 25 | this.density = density; 26 | this.docDensity = docDensity; 27 | this.phiDensity = phiDensity; 28 | this.zTimings = zTimings; 29 | this.countTimings = countTimings; 30 | } 31 | 32 | public Stats(int iteration, String loggingPath, long elapsedMillis, long zSamplingTokenUpdateTime, 33 | long phiSamplingTime, double density, double docDensity, long[] zTimings, long[] countTimings, 34 | double phiDensity, Double heldOutLL) { 35 | this(iteration, loggingPath,elapsedMillis, 36 | zSamplingTokenUpdateTime, 
phiSamplingTime, density, 37 | docDensity, zTimings, countTimings, phiDensity); 38 | this.heldOutLL = heldOutLL; 39 | } 40 | 41 | } 42 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/StringClassArrayIterator.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | import java.net.URI; 4 | import java.util.Iterator; 5 | 6 | import cc.mallet.types.Instance; 7 | 8 | /** 9 | * Simple Iterator for string data where you can supply the class it 10 | * belongs to 11 | * 12 | * @author Leif Jonsson 13 | * 14 | */ 15 | public class StringClassArrayIterator implements Iterator { 16 | 17 | String[] id; 18 | String[] data; 19 | int index = 0; 20 | String [] classNames; 21 | 22 | public StringClassArrayIterator (String[] data) 23 | { 24 | this.data = data; 25 | } 26 | 27 | public StringClassArrayIterator (String[] data, String className) 28 | { 29 | this.data = data; 30 | this.classNames = new String[]{className}; 31 | } 32 | 33 | public StringClassArrayIterator (String[] data, String [] classNames) 34 | { 35 | if(classNames != null && classNames.length != 1 && data.length != classNames.length) { 36 | throw new IllegalArgumentException("data.length != classNames.length when classNames.length != 1"); 37 | } 38 | this.data = data; 39 | this.classNames = classNames; 40 | } 41 | 42 | public StringClassArrayIterator (String[] data, String [] classNames, String [] ids) 43 | { 44 | if(classNames != null && classNames.length != 1 && data.length != classNames.length) { 45 | throw new IllegalArgumentException("data.length != classNames.length when classNames.length != 1"); 46 | } 47 | if(ids != null && data.length != ids.length) { 48 | throw new IllegalArgumentException("data.length != ids.length"); 49 | } 50 | this.data = data; 51 | this.classNames = classNames; 52 | this.id = ids; 53 | } 54 | 55 | public Instance next () 56 | { 57 | URI uri = null; 58 | try { 59 | if(id==null) 60 | uri = new URI ("" + index); 61 | else 62 | uri = new URI (id[index]); 63 | } 64 | catch (Exception e) { e.printStackTrace(); throw new IllegalStateException(); } 65 | String className = ""; 66 | if(classNames == null) className = ""; 67 | else if(classNames.length == 1) className = classNames[0]; 68 | else if(classNames.length > 1) className = classNames[index]; 69 | 70 | Instance i = new Instance (data[index], className, uri, null); 71 | index++; 72 | return i; 73 | } 74 | 75 | public boolean hasNext () { return index < data.length; } 76 | 77 | public void remove () { 78 | throw new IllegalStateException ("This Iterator does not support remove()."); 79 | } 80 | 81 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/SystematicSampling.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | import java.util.Arrays; 4 | import java.util.concurrent.ThreadLocalRandom; 5 | 6 | public class SystematicSampling { 7 | 8 | public SystematicSampling() { 9 | } 10 | 11 | public static int [] Tmpsample(int [] counts, int n) { 12 | if(n<1) throw new IllegalArgumentException("Step must be bigger than 1, given was: " + n); 13 | //<- function(size=c(4, 140, 14, 20, 13, 110, 29, 90, 34, 29, 230), n=100){ 14 | double l = ThreadLocalRandom.current().nextDouble() * (double)n; 15 | System.out.println("l: " + l); 16 | double cum_sum = 0.0; 17 | int j = 0; 18 | int [] res = new int[counts.length]; 19 | 
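// Systematic sampling over the cumulative counts: given a random start l ~ U(0, n), an index is
// selected whenever the running sum crosses l (or, in this variant, overshoots l + n), so units
// are picked roughly in proportion to their counts with step n. The println calls appear to be
// debug output matching the R prototype referenced in the comment near the top of the method.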
for(int i = 0; i < counts.length; i++) { 20 | double cum_sum_tmp = cum_sum + (double)counts[i]; 21 | System.out.println("sum_sum_tmp: " + cum_sum_tmp); 22 | if((cum_sum < l && cum_sum_tmp >= l) || cum_sum_tmp >= l + n) { 23 | res[j] = i; 24 | j++; 25 | } 26 | if(cum_sum > n){ 27 | cum_sum = cum_sum_tmp % n; 28 | } else { 29 | cum_sum = cum_sum_tmp; 30 | } 31 | } 32 | return Arrays.copyOf(res, j); 33 | } 34 | 35 | public static int [] origsample(int [] counts, int n) { 36 | if(n<1) throw new IllegalArgumentException("Step must be bigger than 1, given was: " + n); 37 | int l = (int) (ThreadLocalRandom.current().nextDouble() * (double)n); 38 | int countsum = l; 39 | int [] res = new int[counts.length]; 40 | int i = 0; 41 | int j = 0; 42 | while( i < counts.length ) { 43 | if(counts[i] max ? counts[i] : max; 89 | } 90 | return max; 91 | } 92 | 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/TeeStream.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | import java.io.PrintStream; 3 | public class TeeStream extends PrintStream { 4 | PrintStream out; 5 | public TeeStream(PrintStream out1, PrintStream out2) { 6 | super(out1); 7 | this.out = out2; 8 | } 9 | public void write(byte buf[], int off, int len) { 10 | try { 11 | super.write(buf, off, len); 12 | out.write(buf, off, len); 13 | } catch (Exception e) { 14 | } 15 | } 16 | public void flush() { 17 | super.flush(); 18 | out.flush(); 19 | } 20 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/Timer.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | import java.util.Date; 4 | 5 | public class Timer { 6 | 7 | protected Date startdate = new Date(); 8 | protected Date enddate; 9 | protected String start = startdate.toString(); 10 | protected long startTime; 11 | protected long estimatedTime; 12 | 13 | public Timer() { 14 | } 15 | 16 | public void start() { 17 | startdate = new Date(); 18 | start = startdate.toString(); 19 | startTime = System.nanoTime(); 20 | } 21 | 22 | public void stop() { 23 | enddate = new Date(); 24 | estimatedTime = System.nanoTime() - startTime; 25 | } 26 | 27 | public Date getStartdate() { 28 | return startdate; 29 | } 30 | 31 | public void setStartdate(Date startdate) { 32 | this.startdate = startdate; 33 | } 34 | 35 | public Date getEnddate() { 36 | return enddate; 37 | } 38 | 39 | public void setEnddate(Date enddate) { 40 | this.enddate = enddate; 41 | } 42 | 43 | public long getEllapsedTime() { 44 | return (enddate.getTime() - startdate.getTime()); 45 | } 46 | 47 | public void report(String prefix) { 48 | System.err.println(prefix + estimatedTime/1000000000 + " seconds"); 49 | System.out.println("Started: " + start); 50 | System.out.println("Ended : " + enddate.toString()); 51 | } 52 | 53 | } 54 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/Timing.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | public class Timing { 4 | long start; 5 | long stop; 6 | String label; 7 | 8 | public Timing() { 9 | 10 | } 11 | 12 | public Timing(long start, long stop, String label) { 13 | super(); 14 | this.start = start; 15 | this.stop = stop; 16 | this.label = label; 17 | } 18 | 19 | /** 20 | * @return the start 21 | */ 22 | 
public long getStart() { 23 | return start; 24 | } 25 | 26 | /** 27 | * @param start the start to set 28 | */ 29 | public void setStart(int start) { 30 | this.start = start; 31 | } 32 | 33 | /** 34 | * @return the stop 35 | */ 36 | public long getStop() { 37 | return stop; 38 | } 39 | 40 | /** 41 | * @param stop the stop to set 42 | */ 43 | public void setStop(int stop) { 44 | this.stop = stop; 45 | } 46 | 47 | /** 48 | * @return the label 49 | */ 50 | public String getLabel() { 51 | return label; 52 | } 53 | 54 | /** 55 | * @param label the label to set 56 | */ 57 | public void setLabel(String label) { 58 | this.label = label; 59 | } 60 | 61 | } 62 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/TopicIndicatorLogger.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | import java.util.ArrayList; 4 | 5 | import cc.mallet.configuration.LDAConfiguration; 6 | import cc.mallet.topics.TopicAssignment; 7 | 8 | public interface TopicIndicatorLogger { 9 | void log(ArrayList data, LDAConfiguration config, int iteration); 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/WalkerAliasTable.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | 4 | public interface WalkerAliasTable { 5 | public int generateSample(); 6 | public int [] generateSamples(int nrSamples); 7 | public int generateSample(double u); 8 | public void initTable(double [] probabilities, double normalizer); 9 | public void reGenerateAliasTable(double[] pi, double normalizer); 10 | public void initTableNormalizedProbabilities(double [] probabilities); 11 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/WithoutReplacementSampler.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | import gnu.trove.TIntArrayList; 4 | 5 | import java.util.concurrent.ThreadLocalRandom; 6 | 7 | public class WithoutReplacementSampler implements IndexSampler { 8 | 9 | TIntArrayList available; 10 | public WithoutReplacementSampler(int startRange, int endRange) { 11 | available = new TIntArrayList(); 12 | for (int i = startRange; i < endRange; i++) { 13 | available.add(i); 14 | } 15 | } 16 | 17 | @Override 18 | public int nextSample() { 19 | if(available.size()==0) { 20 | throw new IllegalStateException("Sampler is exausted, there are no more to sample"); 21 | } 22 | int idx = (int) (ThreadLocalRandom.current().nextDouble() * available.size()); 23 | int val = available.get(idx); 24 | available.remove(idx); 25 | return val; 26 | } 27 | 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/XORShiftRandom.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | import java.util.Random; 4 | 5 | public class XORShiftRandom extends Random { 6 | private static final long serialVersionUID = 1L; 7 | private long seed = System.nanoTime(); 8 | 9 | public XORShiftRandom() { 10 | } 11 | 12 | protected int next(int nbits) { 13 | // N.B. Not thread-safe! 
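// Marsaglia-style 64-bit xorshift: the three shift/xor steps below (21, 35, 4) scramble the
// 64-bit state, and the final mask keeps only the low nbits requested by java.util.Random.
// An all-zero state would map to itself forever; seeding from System.nanoTime() makes that
// essentially impossible in practice, though it is not strictly guaranteed.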
14 | long x = this.seed; 15 | x ^= (x << 21); 16 | x ^= (x >>> 35); 17 | x ^= (x << 4); 18 | this.seed = x; 19 | x &= ((1L << nbits) -1); 20 | return (int) x; 21 | } 22 | } -------------------------------------------------------------------------------- /src/main/java/cc/mallet/util/resources/logging.properties: -------------------------------------------------------------------------------- 1 | ############################################################ 2 | # Default Logging Configuration File 3 | # 4 | # You can use a different file by specifying a filename 5 | # with the java.util.logging.config.file system property. 6 | # For example java -Djava.util.logging.config.file=myfile 7 | ############################################################ 8 | 9 | ############################################################ 10 | # Global properties 11 | ############################################################ 12 | 13 | # "handlers" specifies a comma separated list of log Handler 14 | # classes. These handlers will be installed during VM startup. 15 | # Note that these classes must be on the system classpath. 16 | # By default we only configure a ConsoleHandler, which will only 17 | # show messages at the INFO and above levels. 18 | handlers= java.util.logging.ConsoleHandler 19 | 20 | # To also add the FileHandler, use the following line instead. 21 | #handlers= java.util.logging.FileHandler, java.util.logging.ConsoleHandler 22 | 23 | # Default global logging level. 24 | # This specifies which kinds of events are logged across 25 | # all loggers. For any given facility this global level 26 | # can be overriden by a facility specific level 27 | # Note that the ConsoleHandler also has a separate level 28 | # setting to limit messages printed to the console. 29 | .level= SEVERE 30 | 31 | ############################################################ 32 | # Handler specific properties. 33 | # Describes specific configuration info for Handlers. 34 | ############################################################ 35 | 36 | # default file output is in user's home directory. 37 | java.util.logging.FileHandler.pattern = %h/java%u.log 38 | java.util.logging.FileHandler.limit = 50000 39 | java.util.logging.FileHandler.count = 1 40 | java.util.logging.FileHandler.formatter = cc.mallet.util.PlainLogFormatter 41 | #java.util.logging.FileHandler.formatter = java.util.logging.XMLFormatter 42 | 43 | # Limit the message that are printed on the console. ALL means all messages are reported. Off means no messages are reported. 44 | java.util.logging.ConsoleHandler.level = INFO 45 | java.util.logging.ConsoleHandler.formatter = cc.mallet.util.PlainLogFormatter 46 | 47 | 48 | ############################################################ 49 | # Facility specific properties. 50 | # Provides extra control for each logger. 51 | ############################################################ 52 | 53 | # For example, set the com.xyz.foo logger to only log SEVERE 54 | # messages: 55 | 56 | #Put the level of specific loggers here. 
If not included, default is INFO 57 | 58 | #cc.mallet.fst.MaxLatticeDefault.level = FINE 59 | -------------------------------------------------------------------------------- /src/main/resources/configuration/Clinton.cfg: -------------------------------------------------------------------------------- 1 | configs = Spalias 2 | #configs = VSSpalias, Spalias 3 | #configs = VSSpalias, Spalias, ADLDA 4 | no_runs = 1 5 | #experiment_out_dir = test_exp 6 | 7 | [Spalias] 8 | title = PCPLDA 9 | description = PCP LDA on Clinton emails 10 | dataset = /Users/eralljn/Downloads/clinton-emails/Emails.lda 11 | #scheme = lightpclda 12 | scheme = spalias 13 | seed=4711 14 | topics = 40 15 | alpha = 0.1 16 | beta = 0.01 17 | iterations = 10000 18 | rare_threshold = 10 19 | #diagnostic_interval = 50, 150 20 | #dn_diagnostic_interval = 1,5 21 | batches = 4 22 | topic_batches = 4 23 | topic_interval = 10 24 | start_diagnostic = 500 25 | results_size = 200 26 | debug = 0 27 | log_type_topic_density = true 28 | log_document_density = true 29 | log_phi_density = true 30 | phi_mean_filename = phi-mean.csv 31 | phi_mean_burnin = 20 32 | phi_mean_thin = 5 33 | stoplist = stoplist_clinton.txt 34 | -------------------------------------------------------------------------------- /src/main/resources/configuration/GlobalPLDAConfig.cfg: -------------------------------------------------------------------------------- 1 | #configs = LightCollapsed 2 | #configs = ADLDA 3 | configs = Spalias-Polya 4 | #configs = Spalias-Outlook 5 | #configs = LightPCLDAW2, Light, LightCollapsed, Spalias 6 | no_runs = 1 7 | seed=4711 8 | topics = 25 9 | alpha = 0.01 10 | beta = 0.01 11 | iterations = 3000 12 | topic_interval = 10 13 | start_diagnostic = 500 14 | debug = 0 15 | rare_threshold = 5 16 | log_type_topic_density = true 17 | log_document_density = true 18 | log_phi_density = true 19 | save_doc_topic_means = true 20 | doc_topic_mean_filename = doc_topic_means.csv 21 | phi_mean_filename = phi_means.csv 22 | phi_mean_burnin = 20 23 | phi_mean_thin = 5 24 | save_doc_lengths = true 25 | doc_lengths_filename = doc_lengths.txt 26 | save_term_frequencies = true 27 | term_frequencies_filename = term_frequencies.txt 28 | save_vocabulary = true 29 | vocabulary_filename = lda_vocab.txt 30 | #dataset = src/main/resources/datasets/nips.txt 31 | #dataset = src/main/resources/datasets/enron.txt 32 | dataset = /Users/eralljn/Research/Datasets/20newsgroups.lda 33 | stoplist = stoplist-20ng.txt 34 | #sparse_dirichlet_sampler_name = cc.mallet.types.PolyaUrnDirichlet 35 | hyperparam_optim_interval = 100 36 | #symmetric_alpha = true 37 | 38 | [Spalias-EMR] 39 | title = PCPLDA 40 | description = PCP LDA on selected dataset 41 | scheme = spalias 42 | dataset = /Users/eralljn/Downloads/EMR/output/messages.dat 43 | iterations = 2000 44 | alpha = 0.01 45 | beta = 0.01 46 | stoplist = stoplist-emr.txt 47 | 48 | [Spalias-Outlook] 49 | title = PCPLDA 50 | description = PCP LDA on selected dataset 51 | dataset = /Users/eralljn/workspace/OutlookMail/mails.lda 52 | scheme = spalias 53 | topics = 40 54 | iterations = 5000 55 | alpha = 0.01 56 | beta = 0.01 57 | rare_threshold = 10 58 | stoplist = stoplist-mail.txt 59 | 60 | [Spalias] 61 | title = PCPLDA 62 | description = PCP LDA on selected dataset 63 | scheme = spalias 64 | 65 | [Spalias-Polya] 66 | title = PCPLDA 67 | description = PCP LDA on selected dataset 68 | scheme = polyaurn 69 | #sparse_dirichlet_sampler_name = cc.mallet.types.PolyaUrnDirichlet 70 | 71 | [Spalias-nips] 72 | title = PCPLDA 73 | description 
= PCP LDA on NIPS 74 | dataset = src/main/resources/datasets/nips.txt 75 | scheme = spalias 76 | iterations = 100 77 | 78 | [Light] 79 | title = LightPCLDA 80 | description = PCP LDA on selected dataset 81 | scheme = lightpclda 82 | 83 | [LightCollapsed] 84 | title = LightCollapsed 85 | description = Light Collapsed LDA on selected dataset 86 | scheme = lightcollapsed 87 | 88 | [LightPCLDAW2] 89 | title = LightPCLDAW2 90 | description = Light PCP LDA with type-topic proposal on selected dataset 91 | scheme = lightpcldaw2 92 | 93 | [ADLDA] 94 | title = ADLDA 95 | description = AD LDA on selected dataset 96 | scheme = adlda 97 | tfidf_vocab_size = 50 98 | 99 | -------------------------------------------------------------------------------- /src/main/resources/configuration/KLClassification.cfg: -------------------------------------------------------------------------------- 1 | configs = films-imdb-1000 2 | no_runs = 1 3 | test_iterations = 20 4 | x_folds = 3 5 | burn_in = 70 6 | lag = 10 7 | seed=4711 8 | 9 | 10 | [films-imdb-small] 11 | title = DOLDA 12 | description = DOLDA on selected dataset 13 | dataset = ../DOLDA/src/main/resources/datasets/films-imdb-141.lda 14 | topics = 40 15 | rare_threshold = 3 16 | alpha = 0.01 17 | beta = 0.01 18 | iterations = 30 19 | batches = 4 20 | topic_interval = 50 21 | debug = 0 22 | log_type_topic_density = true 23 | log_document_density = true 24 | log_phi_density = true 25 | save_betas = true 26 | save_doc_topic_means = true 27 | doc_topic_mean_filename = doc_topic_means.csv 28 | 29 | [films-imdb-1000] 30 | title = DOLDA 31 | description = DOLDA on selected dataset 32 | dataset = ../DOLDA/src/main/resources/datasets/films-imdb-1000.lda 33 | topics = 40 34 | rare_threshold = 3 35 | alpha = 0.01 36 | beta = 0.01 37 | iterations = 1000 38 | batches = 4 39 | topic_interval = 500 40 | debug = 0 41 | log_type_topic_density = true 42 | log_document_density = true 43 | log_phi_density = true 44 | save_betas = true 45 | save_doc_topic_means = true 46 | doc_topic_mean_filename = doc_topic_means.csv 47 | 48 | [films-imdb] 49 | title = DOLDA 50 | description = DOLDA on full IMDB dataset 51 | dataset = ../DOLDA/src/main/resources/datasets/films-imdb.lda 52 | topics = 40 53 | rare_threshold = 3 54 | alpha = 0.01 55 | beta = 0.01 56 | iterations = 1000 57 | batches = 4 58 | topic_interval = 100 59 | debug = 0 60 | log_type_topic_density = true 61 | log_document_density = true 62 | log_phi_density = true 63 | save_betas = true 64 | save_doc_topic_means = true 65 | doc_topic_mean_filename = doc_topic_means.csv 66 | 67 | -------------------------------------------------------------------------------- /src/main/resources/configuration/PLDAConfigDeltaN.cfg: -------------------------------------------------------------------------------- 1 | 2 | configs = PLDA 3 | 4 | [PLDA] 5 | title = DeltaN_exp1 6 | description = Analyze the delta N:s during different parts of the chain. 
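# Note: the dn_diagnostic_interval list below presumably encodes two iteration ranges
# (1-1000 and 10001-10500) during which the delta-N diagnostics are dumped; see
# LDAConfiguration / ParsedLDAConfiguration for the authoritative interpretation.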
7 | dataset = src/main/resources/datasets/nips.txt 8 | scheme = uncollapsed 9 | seed=121212 10 | topics = 20 11 | alpha = 0.1 12 | beta = 0.01 13 | iterations = 10500 14 | diagnostic_interval = 0 15 | dn_diagnostic_interval = 1, 1000, 10001, 10500 16 | batches = 4 17 | rare_threshold = 3 18 | topic_interval = 100 19 | start_diagnostic = 500 20 | debug = 0 21 | -------------------------------------------------------------------------------- /src/main/resources/configuration/RSConfig.cfg: -------------------------------------------------------------------------------- 1 | 2 | configs = spalias_adaptive 3 | #configs = spalias, spalias_adaptive 4 | no_runs = 1 5 | 6 | [spalias] 7 | title = PCPLDA 8 | description = Plain Partially Collapsed Parallel LDA 9 | dataset = src/main/resources/datasets/enron.txt 10 | #dataset = src/main/resources/datasets/smallnips.txt 11 | #dataset = src/main/resources/datasets/nips.txt 12 | #dataset = src/main/resources/datasets/pubmed.txt 13 | scheme = spalias 14 | #scheme = paranoid 15 | seed = 4711 16 | topics = 100 17 | alpha = 0.1 18 | beta = 0.01 19 | iterations = 1000 20 | diagnostic_interval = -1 21 | batches = 4 22 | topic_batches = 4 23 | rare_threshold = 10 24 | topic_interval = 10 25 | diagnostic_interval = -1 26 | dn_diagnostic_interval = -1 27 | start_diagnostic = 500 28 | measure_timing = false 29 | results_size = 5 30 | batch_building_scheme = utils.randomscan.document.EvenSplitBatchBuilder 31 | debug = 0 32 | 33 | [spalias_adaptive] 34 | title = PCPLDA 35 | description = Spalias Collapsed Parallel LDA with adaptive subsampling 36 | dataset = src/main/resources/datasets/enron.txt 37 | #dataset = src/main/resources/datasets/nips.txt 38 | #dataset = src/main/resources/datasets/pubmed.txt 39 | scheme = spalias 40 | #scheme = uncollapsed 41 | #scheme = paranoid 42 | seed = 4711 43 | topics = 100 44 | alpha = 1.0 45 | beta = 0.01 46 | iterations = 1000 47 | batches = 4 48 | topic_batches = 4 49 | rare_threshold = 10 50 | topic_interval = 10 51 | #batch_building_scheme = cc.mallet.topics.randomscan.document.EvenSplitBatchBuilder 52 | batch_building_scheme = cc.mallet.topics.randomscan.document.PercentageBatchBuilder 53 | percentage_split_size_doc = 0.2 54 | #batch_building_scheme = cc.mallet.topics.randomscan.document.AdaptiveBatchBuilder 55 | #batch_building_scheme = cc.mallet.topics.randomscan.document.FixedSplitBatchBuilder 56 | #fixed_split_size_doc = 0.1 57 | #topic_index_building_scheme = cc.mallet.topics.randomscan.topic.DeltaNTopicIndexBuilder 58 | topic_index_building_scheme = cc.mallet.topics.randomscan.topic.ProportionalTopicIndexBuilder 59 | proportional_ib_skip_step = 100 60 | #percent_top_tokens = 0.50 61 | #topic_index_building_scheme = cc.mallet.topics.randomscan.topic.TopWordsRandomFractionTopicIndexBuilder 62 | full_phi_period = 5 # Sample full Phi ever X:th interation 63 | #percentage_split_size_topic = 1.0 64 | #instability_period = 125 # Iterations 65 | #instability_period = 0 66 | debug = 0 67 | 68 | -------------------------------------------------------------------------------- /src/main/resources/configuration/SmokeTestConfig.cfg: -------------------------------------------------------------------------------- 1 | 2 | configs = demo 3 | 4 | [demo] 5 | title = Default test 6 | description = Standard LDA on AP dataset 7 | dataset = src/main/resources/datasets/ap.txt 8 | scheme = collapsed 9 | topics = 10 10 | alpha = 1.0 11 | beta = 0.01 12 | batches = 2 13 | rare_threshold = 0 14 | topic_interval = 10 15 | start_diagnostic = 500 16 | 
debug = 0 17 | stopwords = error, describe, ericsson, tr, trouble, fault, problem, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z 18 | -------------------------------------------------------------------------------- /src/main/resources/configuration/SpaliasMainRemote.conf: -------------------------------------------------------------------------------- 1 | akka { 2 | loglevel = "INFO" 3 | actor { 4 | provider = "akka.remote.RemoteActorRefProvider" 5 | } 6 | remote { 7 | enabled-transports = ["akka.remote.netty.tcp"] 8 | netty.tcp { 9 | hostname = "127.0.0.1" 10 | port = 0 11 | maximum-frame-size = 128000m 12 | } 13 | log-sent-messages = on 14 | log-received-messages = on 15 | } 16 | } -------------------------------------------------------------------------------- /src/main/resources/configuration/SpaliasWorkerRemote.conf: -------------------------------------------------------------------------------- 1 | akka { 2 | loglevel = "INFO" 3 | actor { 4 | provider = "akka.remote.RemoteActorRefProvider" 5 | } 6 | remote { 7 | enabled-transports = ["akka.remote.netty.tcp"] 8 | netty.tcp { 9 | hostname = "127.0.0.1" 10 | port = 5150 11 | maximum-frame-size = 128000m 12 | } 13 | log-sent-messages = on 14 | log-received-messages = on 15 | } 16 | } -------------------------------------------------------------------------------- /src/main/resources/configuration/TestConfig.cfg: -------------------------------------------------------------------------------- 1 | 2 | configs = demo-sl 3 | #configs = demo-st, demo-sl, demo-pl 4 | 5 | [demo-nips] 6 | title = Default test 7 | description = Standard LDA on AP dataset 8 | #dataset = src/main/resources/datasets/ap.txt 9 | dataset = src/main/resources/datasets/nips.txt 10 | #dataset = src/main/resources/datasets/Corpus.txt 11 | #dataset = src/main/resources/datasets/CorpusSmall.txt 12 | #scheme = ush_serial 13 | scheme = ush_parallel 14 | seed=4711 15 | topics = 20 16 | alpha = 1.0 17 | beta = 0.01 18 | iterations = 1000 19 | batches = 2 20 | rare_threshold = 0 21 | topic_interval = 10 22 | start_diagnostic = 500 23 | debug = 0 24 | 25 | 26 | [demo-st] 27 | title = Default test 28 | description = Standard LDA on AP dataset 29 | dataset = src/main/resources/datasets/Corpus.txt 30 | #dataset = src/main/resources/datasets/ap.txt 31 | scheme = standard 32 | topics = 10 33 | alpha = 1.0 34 | beta = 0.01 35 | iterations = 5000 36 | batches = 2 37 | rare_threshold = -1 38 | topic_interval = 10 39 | start_diagnostic = 500 40 | debug = 0 41 | 42 | [demo-sl] 43 | title = Default test 44 | description = Standard LDA on AP dataset 45 | #dataset = src/main/resources/datasets/Corpus.txt 46 | dataset = src/main/resources/datasets/nips.txt 47 | scheme = ush_serial 48 | topics = 10 49 | alpha = 1.0 50 | beta = 0.01 51 | iterations = 1500 52 | diagnostic_interval = 3500, 4000 53 | dn_diagnostic_interval = -1 54 | batches = 2 55 | rare_threshold = -1 56 | topic_interval = 10 57 | start_diagnostic = 500 58 | debug = 0 59 | 60 | 61 | [demo-pl] 62 | title = Default test 63 | description = Standard LDA on AP dataset 64 | dataset = src/main/resources/datasets/Corpus.txt 65 | #dataset = src/main/resources/datasets/ap.txt 66 | scheme = ush_parallel 67 | topics = 10 68 | alpha = 1.0 69 | beta = 0.01 70 | iterations = 1500 71 | batches = 2 72 | rare_threshold = -1 73 | topic_interval = 10 74 | start_diagnostic = 500 75 | debug = 0 76 | 77 | -------------------------------------------------------------------------------- 
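The .cfg files in this directory follow a simple INI-style layout: key = value pairs before any
section header act as global defaults, the configs key lists which [section] blocks are active
for a run, and keys inside a section appear to override the globals for that experiment. The
project reads these files through its own configuration classes (presumably
cc.mallet.configuration.ParsedLDAConfiguration, judging by the name); the snippet below is only
a minimal stand-alone sketch of that layout, not the project's parser (the CfgSketch class is
purely illustrative), and it ignores details such as multi-valued keys.

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;

public class CfgSketch {
    /** Reads an INI-style cfg file into section -> (key -> value) maps.
     *  The pseudo-section "" holds the global keys that precede any [section]. */
    public static Map<String, Map<String, String>> read(String path) throws IOException {
        Map<String, Map<String, String>> cfg = new HashMap<>();
        String section = "";
        cfg.put(section, new HashMap<>());
        for (String raw : Files.readAllLines(Paths.get(path))) {
            String line = raw.trim();
            if (line.isEmpty() || line.startsWith("#")) continue;   // skip blanks and comments
            if (line.startsWith("[") && line.endsWith("]")) {       // new section header
                section = line.substring(1, line.length() - 1);
                cfg.putIfAbsent(section, new HashMap<>());
            } else {
                int eq = line.indexOf('=');
                if (eq < 0) continue;
                String key = line.substring(0, eq).trim();
                String value = line.substring(eq + 1).trim();
                int hash = value.indexOf('#');                      // strip trailing comments
                if (hash >= 0) value = value.substring(0, hash).trim();
                cfg.get(section).put(key, value);
            }
        }
        return cfg;
    }

    public static void main(String[] args) throws IOException {
        Map<String, Map<String, String>> cfg = read("src/main/resources/configuration/minimal.cfg");
        System.out.println("topics = " + cfg.get("").get("topics"));   // prints "topics = 20"
    }
}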
/src/main/resources/configuration/TestPriorsConfig.cfg: -------------------------------------------------------------------------------- 1 | 2 | configs = demo 3 | 4 | [demo] 5 | title = Default test 6 | description = Standard LDA on AP dataset 7 | dataset = src/main/resources/datasets/SmallTexts.txt 8 | scheme = spalias_priors 9 | topics = 4 10 | alpha = 1.0 11 | beta = 0.01 12 | batches = 2 13 | rare_threshold = 0 14 | topic_interval = 10 15 | start_diagnostic = 500 16 | debug = 0 17 | stopwords = error, describe, ericsson, tr, trouble, fault, problem, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z 18 | topic_prior_filename = src/test/resources/topic_priors.txt 19 | 20 | -------------------------------------------------------------------------------- /src/main/resources/configuration/TopicMassConfig.cfg: -------------------------------------------------------------------------------- 1 | 2 | #configs = NIPS, ENRON, PUBMED 3 | configs = PUBMED 4 | 5 | [NIPS] 6 | dataset = src/main/resources/datasets/nips.txt 7 | dn_diagnostic_interval = -1 8 | 9 | [ENRON] 10 | dataset = src/main/resources/datasets/enron.txt 11 | dn_diagnostic_interval = -1 12 | 13 | [PUBMED] 14 | dataset = src/main/resources/datasets/pubmed.txt 15 | dn_diagnostic_interval = -1 16 | -------------------------------------------------------------------------------- /src/main/resources/configuration/UnitTestConfig.cfg: -------------------------------------------------------------------------------- 1 | 2 | configs = demo 3 | 4 | [demo] 5 | title = Default test 6 | description = Standard LDA on AP dataset 7 | dataset = src/main/resources/datasets/ap.txt 8 | scheme = ush_parallel 9 | topics = 100 10 | alpha = 1.0 11 | beta = 0.01 12 | iterations = 1500 13 | batches = 2 14 | rare_threshold = 3 15 | topic_interval = 10 16 | start_diagnostic = 500 17 | debug = 0 18 | stopwords = a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z 19 | dataset_sizes = 100, 200, 400, 800, 1600, 3200, 6400, 12800, -1 20 | -------------------------------------------------------------------------------- /src/main/resources/configuration/UnitTestConfigWithCommaDesc.cfg: -------------------------------------------------------------------------------- 1 | 2 | configs = demo 3 | 4 | [demo] 5 | title = Default test 6 | description = Standard LDA on AP, dataset 7 | dataset = src/main/resources/datasets/ap.txt 8 | scheme = ush_parallel 9 | topics = 100 10 | alpha = 1.0 11 | beta = 0.01 12 | iterations = 1500 13 | batches = 2 14 | rare_threshold = 3 15 | topic_interval = 10 16 | start_diagnostic = 500 17 | debug = 0 18 | stopwords = error, describe, ericsson, tr, trouble, fault, problem, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z 19 | dataset_sizes = 100, 200, 400, 800, 1600, 3200, 6400, 12800, -1 20 | -------------------------------------------------------------------------------- /src/main/resources/configuration/VSConfig.cfg: -------------------------------------------------------------------------------- 1 | configs = VSSpalias, ADLDA, Spalias 2 | #configs = VSSpalias 3 | no_runs = 3 4 | 5 | [VSSpalias] 6 | title = VSSpalias 7 | scheme = nzvsspalias 8 | description = PCP LDA on selected dataset 9 | #dataset = src/main/resources/datasets/20newsgroups.txt 10 | #dataset = src/main/resources/datasets/nips.txt 11 | dataset = src/main/resources/datasets/enron.txt 12 | seed=4711 13 | topics = 400 14 | alpha = 0.1 15 | beta = 0.01 16 | iterations = 2000 17 | 
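# The VSSpalias, Spalias and ADLDA sections in this file share the same dataset, seed and
# sampler settings and differ only in title and scheme, so the three runs isolate the effect
# of the sampling algorithm itself.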
#diagnostic_interval = 50, 150 18 | #dn_diagnostic_interval = 1,5 19 | batches = 4 20 | topic_batches = 4 21 | topic_interval = 10 22 | start_diagnostic = 500 23 | results_size = 200 24 | debug = 0 25 | log_type_topic_density = true 26 | log_document_density = true 27 | 28 | 29 | [Spalias] 30 | title = Spalias 31 | scheme = spalias 32 | description = PCP LDA on selected dataset 33 | #dataset = src/main/resources/datasets/20newsgroups.txt 34 | #dataset = src/main/resources/datasets/nips.txt 35 | dataset = src/main/resources/datasets/enron.txt 36 | seed=4711 37 | topics = 400 38 | alpha = 0.1 39 | beta = 0.01 40 | iterations = 2000 41 | #diagnostic_interval = 50, 150 42 | #dn_diagnostic_interval = 1,5 43 | batches = 4 44 | topic_batches = 4 45 | topic_interval = 10 46 | start_diagnostic = 500 47 | results_size = 200 48 | debug = 0 49 | log_type_topic_density = true 50 | log_document_density = true 51 | 52 | 53 | [ADLDA] 54 | title = ADLDA 55 | scheme = adlda 56 | description = PCP LDA on selected dataset 57 | #dataset = src/main/resources/datasets/20newsgroups.txt 58 | #dataset = src/main/resources/datasets/nips.txt 59 | dataset = src/main/resources/datasets/enron.txt 60 | seed=4711 61 | topics = 400 62 | alpha = 0.1 63 | beta = 0.01 64 | iterations = 2000 65 | #diagnostic_interval = 50, 150 66 | #dn_diagnostic_interval = 1,5 67 | batches = 4 68 | topic_batches = 4 69 | topic_interval = 10 70 | start_diagnostic = 500 71 | results_size = 200 72 | debug = 0 73 | log_type_topic_density = true 74 | log_document_density = true -------------------------------------------------------------------------------- /src/main/resources/configuration/minimal.cfg: -------------------------------------------------------------------------------- 1 | topics = 20 2 | dataset = src/main/resources/datasets/nips.txt 3 | iterations = 400 4 | -------------------------------------------------------------------------------- /src/main/resources/datasets/README.txt: -------------------------------------------------------------------------------- 1 | Orig data used in the article is available at source: https://archive.ics.uci.edu/ml/datasets/Bag+of+Words 2 | 3 | Enron Emails: 4 | orig source: www.cs.cmu.edu/~enron 5 | D=39861 6 | W=28102 7 | N=6,400,000 (approx) 8 | 9 | NIPS full papers: 10 | orig source: books.nips.cc 11 | D=1500 12 | W=12419 13 | N=1,900,000 (approx) 14 | 15 | KOS blog entries: 16 | orig source: dailykos.com 17 | D=3430 18 | W=6906 19 | N=467,714 20 | 21 | NYTimes news articles: 22 | orig source: ldc.upenn.edu 23 | D=300000 24 | W=102660 25 | N=100,000,000 (approx) 26 | 27 | PubMed abstracts: 28 | orig source: www.pubmed.gov 29 | D=8200000 30 | W=141043 31 | N=730,000,000 (approx) 32 | 33 | D is the number of documents, W is the number of words in the vocabulary, and N is the total number of words in the collection. 34 | -------------------------------------------------------------------------------- /src/main/resources/datasets/SmallTexts.txt: -------------------------------------------------------------------------------- 1 | docno:1 X 'INSERT DISK THREE' ? But I can only get two in the drive ! 2 | docno:2 X 'Intel Inside' is a Government Warning required by Law. 3 | docno:3 X 'Intel Inside': The world's most widely used warning label. 4 | docno:4 X A Freudian slip is when you say one thing but mean your mother 5 | docno:5 X A backward poet writes inverse. 
6 | -------------------------------------------------------------------------------- /src/main/resources/datasets/cgcbib.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lejon/PartiallyCollapsedLDA/0c1e588c2555811a5f5616f75a17497081560ba8/src/main/resources/datasets/cgcbib.txt -------------------------------------------------------------------------------- /src/main/resources/datasets/small.txt: -------------------------------------------------------------------------------- 1 | docno:1 X 1 2 3 4 5 6 7 2 | docno:2 X 8 9 10 11 12 13 14 15 16 17 3 | docno:3 X 18 19 20 21 22 23 4 | docno:4 X 24 25 26 27 28 29 30 31 32 33 34 5 | docno:5 X 35 36 37 38 39 40 41 42 43 44 45 46 47 48 6 | docno:6 X 49 50 51 52 53 54 55 56 57 7 | docno:7 X 58 59 60 61 62 63 8 | docno:8 X 64 65 66 67 68 69 70 9 | docno:9 X 71 72 73 74 75 76 77 78 79 80 10 | docno:10 X 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 11 | -------------------------------------------------------------------------------- /src/main/resources/datasets/small.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lejon/PartiallyCollapsedLDA/0c1e588c2555811a5f5616f75a17497081560ba8/src/main/resources/datasets/small.txt.gz -------------------------------------------------------------------------------- /src/main/resources/datasets/small.txt.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lejon/PartiallyCollapsedLDA/0c1e588c2555811a5f5616f75a17497081560ba8/src/main/resources/datasets/small.txt.zip -------------------------------------------------------------------------------- /src/main/resources/datasets/special_chars.txt: -------------------------------------------------------------------------------- 1 | docno:1 X 'INSERT DISK THREE' ? But_I_can only get two in the drive ! 2 | docno:2 X 'Intel Inside' is a Government Warning required by Law. 3 | docno:3 X 'Intel Inside': The world's most widely used warning label. 4 | docno:4 X A Freudian slip is when you say one thing but mean your mother 5 | docno:5 X A backward poet writes inverse. 
6 | -------------------------------------------------------------------------------- /src/main/resources/datasets/tfidf-samples.txt: -------------------------------------------------------------------------------- 1 | docno:1 X this is a sample 2 | docno:2 X this is a another another example example example -------------------------------------------------------------------------------- /src/main/resources/topic_priors.txt: -------------------------------------------------------------------------------- 1 | 0, cell, stimulus, visual, cortex, response, spatial 2 | 19, image, images, pixel -------------------------------------------------------------------------------- /src/test/java/cc/mallet/misc/RandomTesting.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.misc; 2 | 3 | import static org.junit.Assert.assertTrue; 4 | 5 | import org.apache.commons.math3.distribution.BetaDistribution; 6 | import org.apache.commons.math3.distribution.GammaDistribution; 7 | import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; 8 | import org.junit.Test; 9 | 10 | import cc.mallet.util.ParallelRandoms; 11 | import cc.mallet.util.Randoms; 12 | 13 | public class RandomTesting { 14 | 15 | @Test 16 | public void testBetaSampler() { 17 | int noDraws = 700; 18 | double [] samples = new double[noDraws]; 19 | double [] alphas = {0.5001, 1.0001, 2.0001, 4.0001, 8.0001, 16.0001, 32.0001, 1024.0001}; 20 | double [] betas = {0.5, 1.0, 2.0}; 21 | 22 | int loops = 10; 23 | for (int l = 0; l < loops; l++) { 24 | for (double alpha : alphas) { 25 | for (double beta : betas) { 26 | for (int i = 0; i < noDraws; i++) { 27 | samples[i] = ParallelRandoms.rbeta(alpha, beta); 28 | } 29 | BetaDistribution betaCdf = new BetaDistribution(alpha, beta); 30 | 31 | KolmogorovSmirnovTest ks = new KolmogorovSmirnovTest(); 32 | double test2 = ks.kolmogorovSmirnovTest(betaCdf, samples); 33 | assertTrue(test2 > 0.00001); 34 | } 35 | } 36 | } 37 | } 38 | 39 | @Test 40 | public void testMarsagliaKS() { 41 | ParallelRandoms pr = new ParallelRandoms(); 42 | Randoms malletRnd = new Randoms(); 43 | int noDraws = 500_000; 44 | double [] samplesM = new double[noDraws]; 45 | double [] samplesB = new double[noDraws]; 46 | double [] alphas = {0.5001, 1.0001, 2.0001, 4.0001, 8.0001, 16.0001, 32.0001, 1024.0001}; 47 | double [] betas = {0.5, 1.0, 2.0}; 48 | 49 | for (double alpha : alphas) { 50 | for (double beta : betas) { 51 | double lambda = 0; 52 | for (int i = 0; i < noDraws; i++) { 53 | samplesM[i] = pr.nextGamma(alpha, beta, lambda); // Marsaglia (2000) 54 | samplesB[i] = malletRnd.nextGamma(alpha, beta, lambda); // Best 55 | } 56 | KolmogorovSmirnovTest ks = new KolmogorovSmirnovTest(); 57 | double test1 = ks.kolmogorovSmirnovTest(samplesB, samplesM); 58 | // System.out.println(test1); 59 | assertTrue(test1 > 0.00001); 60 | } 61 | } 62 | } 63 | 64 | 65 | @Test 66 | public void testMarsagliaVsTrue() { 67 | int noDraws = 700; 68 | double [] samples = new double[noDraws]; 69 | double [] alphas = {0.5001, 1.0001, 2.0001, 4.0001, 8.0001, 16.0001, 32.0001, 1024.0001}; 70 | double [] betas = {0.5, 1.0, 2.0}; 71 | 72 | int loops = 10; 73 | for (int l = 0; l < loops; l++) { 74 | for (double alpha : alphas) { 75 | for (double beta : betas) { 76 | double lambda = 0; 77 | for (int i = 0; i < noDraws; i++) { 78 | samples[i] = ParallelRandoms.rgamma(alpha, beta, lambda); // Marsaglia (2000) 79 | } 80 | GammaDistribution gammaCdf = new GammaDistribution(alpha, beta); 81 | 82 | 
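// One-sample Kolmogorov-Smirnov test of the Marsaglia draws against the reference Gamma
// distribution; the assertion only requires the p-value to clear a very loose 1e-5 bound,
// so it guards against gross sampler errors rather than subtle bias.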
KolmogorovSmirnovTest ks = new KolmogorovSmirnovTest(); 83 | double test2 = ks.kolmogorovSmirnovTest(gammaCdf, samples); 84 | assertTrue(test2 > 0.00001); 85 | } 86 | } 87 | } 88 | 89 | } 90 | 91 | } 92 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/similarity/CosineDistanceTest.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import static org.junit.Assert.*; 4 | 5 | import org.junit.Test; 6 | 7 | public class CosineDistanceTest { 8 | 9 | @Test 10 | public void test() { 11 | double [] v1 = {3.0,8.0,7.0,5.0,2.0,9.0}; 12 | double [] v2 = {10.0,8.0,6.0,6.0,4.0,5.0}; 13 | 14 | CosineDistance cd = new CosineDistance(); 15 | double result = cd.calculate(v1, v2); 16 | 17 | assertEquals(1-0.8638935626791596, result, 0.000000000000001); 18 | } 19 | 20 | } 21 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/similarity/SimilarityTest.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.similarity; 2 | 3 | import static org.junit.Assert.assertTrue; 4 | 5 | import java.util.Arrays; 6 | import java.util.List; 7 | 8 | import org.junit.Before; 9 | import org.junit.Ignore; 10 | import org.junit.Test; 11 | import org.junit.runner.RunWith; 12 | import org.junit.runners.Parameterized; 13 | import org.junit.runners.Parameterized.Parameters; 14 | 15 | @RunWith(value = Parameterized.class) 16 | @Ignore("To be finalized") public class SimilarityTest { 17 | 18 | @SuppressWarnings("rawtypes") 19 | Class distanceClass; 20 | Distance dist; 21 | double epsilon = 0.001; 22 | 23 | public SimilarityTest(@SuppressWarnings("rawtypes") Class testClass) { 24 | distanceClass = testClass; 25 | } 26 | 27 | @Parameters 28 | public static List data() { 29 | Object[][] impls = new Object[][] { 30 | //{ BhattacharyyaDistance.class }, 31 | //{ BM25Distance.class }, 32 | { CanberraDistance.class }, 33 | { ChebychevDistance.class }, 34 | { CosineDistance.class }, 35 | { EuclidianDistance.class }, 36 | { HellingerDistance.class }, 37 | { JensenShannonDistance.class }, 38 | { JaccardDistance.class }, 39 | { KLDistance.class }, 40 | { KolmogorovSmirnovDistance.class }, 41 | { ManhattanDistance.class }, 42 | { StatisticalDistance.class } 43 | }; 44 | return Arrays.asList(impls); 45 | } 46 | 47 | @Before 48 | public void noSetup() throws InstantiationException, IllegalAccessException { 49 | dist = (Distance) distanceClass.newInstance(); 50 | } 51 | 52 | @Test 53 | public void testSame() { 54 | double [] v1 = {0.2, 0.3, 0.5, 0.7}; 55 | double calcDist = dist.calculate(v1, v1); 56 | assertTrue(dist.getClass().getSimpleName() + "Distance was: " + calcDist, calcDist < 0.00001); 57 | } 58 | 59 | @Test 60 | public void testNotSame() { 61 | double [] v1 = {0.2, 0.3, 0.5, 0.7}; 62 | double [] v2 = {0.5, 0.8, 0.1, 0.7}; 63 | double calcDistSame = dist.calculate(v1, v1); 64 | double calcDist = dist.calculate(v1, v2); 65 | assertTrue(dist.getClass().getSimpleName() + "Distance was: " + calcDist, calcDist > calcDistSame); 66 | } 67 | 68 | @Test 69 | public void testVsOne() { 70 | double [] v1 = {0.2, 0.3, 0.5, 0.7}; 71 | double [] v2 = {0.0, 0.1, 0.0, 0.0}; 72 | double calcDistSame = dist.calculate(v1, v1); 73 | double calcDist = dist.calculate(v1, v2); 74 | assertTrue(dist.getClass().getSimpleName() + "Distance was: " + calcDist, calcDist > calcDistSame); 75 | } 76 | 77 | } 78 | 
-------------------------------------------------------------------------------- /src/test/java/cc/mallet/topics/DocumentProposalTest.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import static org.junit.Assert.assertTrue; 4 | 5 | import java.util.Random; 6 | import java.util.concurrent.ThreadLocalRandom; 7 | 8 | import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; 9 | import org.junit.Test; 10 | 11 | 12 | public class DocumentProposalTest { 13 | 14 | Random rnd = new Random(); 15 | int numTopics = 12; 16 | int docLength = 11; 17 | double alphaSum = 0.1; 18 | int [] oneDocTopics = {11,0,1,3,10,0,6,0,8,9,10,11,1,9,8,7,0,4,11,11}; 19 | 20 | int loops = 1_000_000; 21 | int [] stats = new int[2]; 22 | double [] drawStats = new double[numTopics]; 23 | double [] priorDrawStats = new double[numTopics]; 24 | KolmogorovSmirnovTest ksTest = new KolmogorovSmirnovTest(); 25 | 26 | @Test 27 | public void testDocProposal() { 28 | for (int i = 0; i < loops; i++) { 29 | double u_i = ThreadLocalRandom.current().nextDouble() * (oneDocTopics.length + alphaSum); // (n_d + K*alpha) * u where u ~ U(0,1) 30 | int docTopicIndicatorProposal = -1; 31 | if(u_i < oneDocTopics.length) { 32 | stats[0]++; 33 | docTopicIndicatorProposal = oneDocTopics[(int) u_i]; 34 | } else { 35 | stats[1]++; 36 | docTopicIndicatorProposal = (int) (((u_i - oneDocTopics.length) / alphaSum) * numTopics); // assume symmetric alpha, just draws one alpha 37 | priorDrawStats[docTopicIndicatorProposal]++; 38 | } 39 | drawStats[docTopicIndicatorProposal]++; 40 | } 41 | 42 | for (int i = 0; i < drawStats.length; i++) { 43 | drawStats[i] /= loops; 44 | } 45 | for (int i = 0; i < priorDrawStats.length; i++) { 46 | priorDrawStats[i] /= loops; 47 | } 48 | 49 | //System.out.println("Less or more: "+ Arrays.toString(stats)); 50 | //System.out.println("Distr: "+ Arrays.toString(drawStats)); 51 | //System.out.println("Prior Distr: "+ Arrays.toString(priorDrawStats)); 52 | double [] expectedProportions = new double[priorDrawStats.length]; 53 | for (int i = 0; i < expectedProportions.length; i++) { 54 | expectedProportions[i] = stats[1] / (double) (priorDrawStats.length*loops); 55 | } 56 | //System.out.println("Expected Distr: "+ Arrays.toString(expectedProportions)); 57 | 58 | double pval = ksTest.kolmogorovSmirnovTest(priorDrawStats, expectedProportions); 59 | 60 | //System.out.println("Pval is: " + pval); 61 | 62 | assertTrue("Pval is: " + pval, pval > 0.00001); 63 | 64 | } 65 | 66 | 67 | } 68 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/topics/ParanoidCollapsedLightLDA.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import java.util.List; 4 | import java.util.concurrent.Future; 5 | 6 | import cc.mallet.configuration.LDAConfiguration; 7 | import cc.mallet.types.InstanceList; 8 | 9 | public class ParanoidCollapsedLightLDA extends CollapsedLightLDA { 10 | 11 | private static final long serialVersionUID = 6948198361119397002L; 12 | 13 | public ParanoidCollapsedLightLDA(LDAConfiguration config) { 14 | super(config); 15 | } 16 | 17 | @Override 18 | public void addInstances(InstanceList training) { 19 | super.addInstances(training); 20 | ensureConsistentTopicTypeCounts(typeTopicCounts); 21 | debugPrintMMatrix(); 22 | } 23 | 24 | @Override 25 | protected void updateCounts(List> futureResults) throws InterruptedException { 26 | 
super.updateCounts(futureResults); 27 | ensureConsistentTopicTypeCounts(typeTopicCounts); 28 | debugPrintMMatrix(); 29 | } 30 | 31 | @Override 32 | public void postIteration() { 33 | super.postIteration(); 34 | ensureTTEquals(); 35 | } 36 | 37 | @Override 38 | public void postSample() { 39 | super.postSample(); 40 | int updateCountSum = 0; 41 | for (int batch = 0; batch < batchLocalTopicTypeUpdates.length; batch++) { 42 | for (int topic = 0; topic < numTopics; topic++) { 43 | for (int type = 0; type < numTypes; type++) { 44 | //updateCountSum += batchLocalTopicTypeUpdates[batch][topic][type]; 45 | updateCountSum += batchLocalTopicTypeUpdates[topic][type].get(); 46 | } 47 | } 48 | if(updateCountSum!=0) throw new IllegalStateException("Update count does not sum to zero: " + updateCountSum); 49 | updateCountSum = 0; 50 | } 51 | } 52 | 53 | @Override 54 | protected void sampleTopicAssignmentsParallel(LDADocSamplingContext ctxIn) { 55 | //SamplingResult res = super.sampleTopicAssignmentsParallel(tokenSequence, oneDocTopics, myBatch); 56 | LightLDADocSamplingContext ctx = (LightLDADocSamplingContext) ctxIn; 57 | int [][] globalTypeTopicCounts = ctx.getMyTypeTopicCounts(); 58 | super.sampleTopicAssignmentsParallel(ctx); 59 | ensureConsistentTopicTypeCounts(globalTypeTopicCounts); 60 | //System.out.println("Glboal type topic count is consistent!"); 61 | //ensureConsistentTopicTypeCountDelta(batchLocalTopicTypeUpdates, ctx.getMyBatch()); 62 | } 63 | 64 | 65 | } 66 | 67 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/topics/ParanoidLightPCLDAtypeTopicProposal.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | import cc.mallet.types.InstanceList; 5 | 6 | public class ParanoidLightPCLDAtypeTopicProposal extends LightPCLDAtypeTopicProposal { 7 | 8 | private static final long serialVersionUID = 6948198361119397002L; 9 | 10 | public ParanoidLightPCLDAtypeTopicProposal(LDAConfiguration config) { 11 | super(config); 12 | } 13 | 14 | @Override 15 | protected void samplePhi() { 16 | super.samplePhi(); 17 | ensureConsistentPhi(phi); 18 | ensureConsistentTopicTypeCounts(topicTypeCountMapping, typeTopicCounts, tokensPerTopic); 19 | debugPrintMMatrix(); 20 | } 21 | 22 | @Override 23 | public void addInstances(InstanceList training) { 24 | super.addInstances(training); 25 | //ensureConsistentTopicTypeCounts(topicTypeCounts); 26 | ensureConsistentPhi(phi); 27 | ensureConsistentTopicTypeCounts(topicTypeCountMapping, typeTopicCounts, tokensPerTopic); 28 | debugPrintMMatrix(); 29 | } 30 | 31 | @Override 32 | protected void updateCounts() throws InterruptedException { 33 | super.updateCounts(); 34 | ensureConsistentPhi(phi); 35 | ensureConsistentTopicTypeCounts(topicTypeCountMapping, typeTopicCounts, tokensPerTopic); 36 | debugPrintMMatrix(); 37 | } 38 | 39 | 40 | 41 | @Override 42 | public void postSample() { 43 | super.postSample(); 44 | int updateCountSum = 0; 45 | for (int batch = 0; batch < batchLocalTopicTypeUpdates.length; batch++) { 46 | for (int topic = 0; topic < numTopics; topic++) { 47 | for (int type = 0; type < numTypes; type++) { 48 | //updateCountSum += batchLocalTopicTypeUpdates[batch][topic][type]; 49 | updateCountSum += batchLocalTopicTypeUpdates[topic][type].get(); 50 | } 51 | } 52 | if(updateCountSum!=0) throw new IllegalStateException("Update count does not sum to zero: " + updateCountSum); 53 | updateCountSum = 
0; 54 | } 55 | } 56 | 57 | @Override 58 | protected LDADocSamplingResult sampleTopicAssignmentsParallel(LDADocSamplingContext ctx) { 59 | //SamplingResult res = super.sampleTopicAssignmentsParallel(tokenSequence, oneDocTopics, myBatch); 60 | return super.sampleTopicAssignmentsParallel(ctx); 61 | // THIS CANNOT BE ENSURED with a job stealing implementation 62 | //ensureConsistentTopicTypeCountDelta(batchLocalTopicTypeUpdates, myBatch); 63 | //return res; 64 | } 65 | 66 | } 67 | 68 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/topics/ParanoidPoissonPolyaUrnHDP.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | import cc.mallet.types.InstanceList; 5 | 6 | public class ParanoidPoissonPolyaUrnHDP extends PoissonPolyaUrnHLDA { 7 | 8 | private static final long serialVersionUID = 1L; 9 | 10 | public ParanoidPoissonPolyaUrnHDP(LDAConfiguration config) { 11 | super(config); 12 | } 13 | 14 | @Override 15 | protected void samplePhi() { 16 | super.samplePhi(); 17 | ensureConsistentPhi(phi); 18 | ensureConsistentTopicTypeCounts(topicTypeCountMapping, typeTopicCounts, tokensPerTopic); 19 | debugPrintMMatrix(); 20 | } 21 | 22 | @Override 23 | public void addInstances(InstanceList training) { 24 | super.addInstances(training); 25 | //ensureConsistentTopicTypeCounts(topicTypeCounts); 26 | ensureConsistentPhi(phi); 27 | ensureConsistentTopicTypeCounts(topicTypeCountMapping, typeTopicCounts, tokensPerTopic); 28 | debugPrintMMatrix(); 29 | } 30 | 31 | @Override 32 | protected void updateCounts() throws InterruptedException { 33 | super.updateCounts(); 34 | ensureConsistentPhi(phi); 35 | ensureConsistentTopicTypeCounts(topicTypeCountMapping, typeTopicCounts, tokensPerTopic); 36 | debugPrintMMatrix(); 37 | } 38 | 39 | @Override 40 | public void postSample() { 41 | super.postSample(); 42 | int updateCountSum = 0; 43 | for (int batch = 0; batch < batchLocalTopicTypeUpdates.length; batch++) { 44 | for (int topic = 0; topic < numTopics; topic++) { 45 | for (int type = 0; type < numTypes; type++) { 46 | //updateCountSum += batchLocalTopicTypeUpdates[batch][topic][type]; 47 | updateCountSum += batchLocalTopicTypeUpdates[topic][type].get(); 48 | } 49 | } 50 | if(updateCountSum!=0) throw new IllegalStateException("Update count does not sum to zero: " + updateCountSum); 51 | updateCountSum = 0; 52 | } 53 | } 54 | 55 | 56 | @Override 57 | public void postIteration() { 58 | super.postIteration(); 59 | for (int topic = numTopics; topic < maxTopics; topic++) { 60 | if(tokensPerTopic[topic]>0) { 61 | throw new IllegalArgumentException("Topic count: " + topic + " has value > 0 for " + tokensPerTopic[topic] + ". numTopics:" + numTopics); 62 | } 63 | for(int type = 0; type < numTypes; type++) { 64 | if(topicTypeCountMapping[topic][type]>0) { 65 | throw new IllegalArgumentException("Topic: " + topic + " has value > 0 for " + topicTypeCountMapping[topic][type] + ". 
numTopics:" + numTopics); 66 | } 67 | } 68 | } 69 | 70 | ensureTTEquals(); 71 | } 72 | 73 | } 74 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/topics/ParanoidSpaliasUncollapsedLDA.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | import cc.mallet.types.InstanceList; 5 | 6 | public class ParanoidSpaliasUncollapsedLDA extends SpaliasUncollapsedParallelLDA { 7 | 8 | private static final long serialVersionUID = 6948198361119397002L; 9 | 10 | public ParanoidSpaliasUncollapsedLDA(LDAConfiguration config) { 11 | super(config); 12 | } 13 | 14 | @Override 15 | protected void samplePhi() { 16 | super.samplePhi(); 17 | ensureConsistentPhi(phi); 18 | ensureConsistentTopicTypeCounts(topicTypeCountMapping, typeTopicCounts, tokensPerTopic); 19 | debugPrintMMatrix(); 20 | } 21 | 22 | @Override 23 | public void addInstances(InstanceList training) { 24 | super.addInstances(training); 25 | //ensureConsistentTopicTypeCounts(topicTypeCounts); 26 | ensureConsistentPhi(phi); 27 | ensureConsistentTopicTypeCounts(topicTypeCountMapping, typeTopicCounts, tokensPerTopic); 28 | debugPrintMMatrix(); 29 | } 30 | 31 | @Override 32 | protected void updateCounts() throws InterruptedException { 33 | super.updateCounts(); 34 | ensureConsistentPhi(phi); 35 | ensureConsistentTopicTypeCounts(topicTypeCountMapping, typeTopicCounts, tokensPerTopic); 36 | debugPrintMMatrix(); 37 | } 38 | 39 | @Override 40 | public void postIteration() { 41 | super.postIteration(); 42 | ensureTTEquals(); 43 | } 44 | 45 | @Override 46 | public void postSample() { 47 | super.postSample(); 48 | int updateCountSum = 0; 49 | for (int batch = 0; batch < batchLocalTopicTypeUpdates.length; batch++) { 50 | for (int topic = 0; topic < numTopics; topic++) { 51 | for (int type = 0; type < numTypes; type++) { 52 | //updateCountSum += batchLocalTopicTypeUpdates[batch][topic][type]; 53 | updateCountSum += batchLocalTopicTypeUpdates[topic][type].get(); 54 | } 55 | } 56 | if(updateCountSum!=0) throw new IllegalStateException("Update count does not sum to zero: " + updateCountSum); 57 | updateCountSum = 0; 58 | } 59 | } 60 | 61 | @Override 62 | protected LDADocSamplingResult sampleTopicAssignmentsParallel(LDADocSamplingContext ctx) { 63 | //SamplingResult res = super.sampleTopicAssignmentsParallel(tokenSequence, oneDocTopics, myBatch); 64 | return super.sampleTopicAssignmentsParallel(ctx); 65 | // THIS CANNOT BE ENSURED with a job stealing implementation 66 | //ensureConsistentTopicTypeCountDelta(batchLocalTopicTypeUpdates, myBatch); 67 | //return res; 68 | } 69 | 70 | 71 | } 72 | 73 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/topics/ParanoidUncollapsedParallelLDA.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | import cc.mallet.types.InstanceList; 5 | 6 | public class ParanoidUncollapsedParallelLDA extends EfficientUncollapsedParallelLDA { 7 | 8 | private static final long serialVersionUID = 6948198361119397002L; 9 | 10 | public ParanoidUncollapsedParallelLDA(LDAConfiguration config) { 11 | super(config); 12 | } 13 | 14 | @Override 15 | protected void samplePhi() { 16 | super.samplePhi(); 17 | ensureConsistentPhi(phi); 18 | ensureConsistentTopicTypeCounts(topicTypeCountMapping, 
typeTopicCounts, tokensPerTopic); 19 | debugPrintMMatrix(); 20 | } 21 | 22 | @Override 23 | public void addInstances(InstanceList training) { 24 | super.addInstances(training); 25 | //ensureConsistentTopicTypeCounts(topicTypeCounts); 26 | ensureConsistentPhi(phi); 27 | ensureConsistentTopicTypeCounts(topicTypeCountMapping, typeTopicCounts, tokensPerTopic); 28 | debugPrintMMatrix(); 29 | } 30 | 31 | @Override 32 | protected void updateCounts() throws InterruptedException { 33 | super.updateCounts(); 34 | ensureConsistentPhi(phi); 35 | ensureConsistentTopicTypeCounts(topicTypeCountMapping, typeTopicCounts, tokensPerTopic); 36 | debugPrintMMatrix(); 37 | } 38 | 39 | 40 | 41 | @Override 42 | public void postSample() { 43 | super.postSample(); 44 | int updateCountSum = 0; 45 | for (int batch = 0; batch < batchLocalTopicTypeUpdates.length; batch++) { 46 | for (int topic = 0; topic < numTopics; topic++) { 47 | for (int type = 0; type < numTypes; type++) { 48 | //updateCountSum += batchLocalTopicTypeUpdates[batch][topic][type]; 49 | updateCountSum += batchLocalTopicTypeUpdates[topic][type].get(); 50 | } 51 | } 52 | if(updateCountSum!=0) throw new IllegalStateException("Update count does not sum to zero: " + updateCountSum); 53 | updateCountSum = 0; 54 | } 55 | } 56 | 57 | @Override 58 | protected LDADocSamplingResult sampleTopicAssignmentsParallel(LDADocSamplingContext ctx) { 59 | //SamplingResult res = super.sampleTopicAssignmentsParallel(tokenSequence, oneDocTopics, myBatch); 60 | return super.sampleTopicAssignmentsParallel(ctx); 61 | // THIS CANNOT BE ENSURED with a job stealing implementation 62 | //ensureConsistentTopicTypeCountDelta(batchLocalTopicTypeUpdates, myBatch); 63 | //return res; 64 | } 65 | 66 | } 67 | 68 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/topics/ParanoidVSSpaliasUncollapsedLDA.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import cc.mallet.configuration.LDAConfiguration; 4 | import cc.mallet.types.InstanceList; 5 | 6 | public class ParanoidVSSpaliasUncollapsedLDA extends NZVSSpaliasUncollapsedParallelLDA { 7 | 8 | private static final long serialVersionUID = 6948198361119397002L; 9 | 10 | public ParanoidVSSpaliasUncollapsedLDA(LDAConfiguration config) { 11 | super(config); 12 | } 13 | 14 | @Override 15 | protected void samplePhi() { 16 | super.samplePhi(); 17 | ensureConsistentPhi(phi); 18 | ensureConsistentTopicTypeCounts(topicTypeCountMapping, typeTopicCounts, tokensPerTopic); 19 | debugPrintMMatrix(); 20 | } 21 | 22 | @Override 23 | public void addInstances(InstanceList training) { 24 | super.addInstances(training); 25 | //ensureConsistentTopicTypeCounts(topicTypeCounts); 26 | ensureConsistentPhi(phi); 27 | ensureConsistentTopicTypeCounts(topicTypeCountMapping, typeTopicCounts, tokensPerTopic); 28 | debugPrintMMatrix(); 29 | } 30 | 31 | @Override 32 | protected void updateCounts() throws InterruptedException { 33 | super.updateCounts(); 34 | ensureConsistentPhi(phi); 35 | ensureConsistentTopicTypeCounts(topicTypeCountMapping, typeTopicCounts, tokensPerTopic); 36 | debugPrintMMatrix(); 37 | } 38 | 39 | 40 | 41 | @Override 42 | public void postSample() { 43 | super.postSample(); 44 | int updateCountSum = 0; 45 | for (int batch = 0; batch < batchLocalTopicTypeUpdates.length; batch++) { 46 | for (int topic = 0; topic < numTopics; topic++) { 47 | for (int type = 0; type < numTypes; type++) { 48 | //updateCountSum += 
batchLocalTopicTypeUpdates[batch][topic][type]; 49 | updateCountSum += batchLocalTopicTypeUpdates[topic][type].get(); 50 | } 51 | } 52 | if(updateCountSum!=0) throw new IllegalStateException("Update count does not sum to zero: " + updateCountSum); 53 | updateCountSum = 0; 54 | } 55 | } 56 | 57 | @Override 58 | protected LDADocSamplingResultSparseSimple sampleTopicAssignmentsParallel(LDADocSamplingContext ctx) { //SamplingResult res = super.sampleTopicAssignmentsParallel(tokenSequence, oneDocTopics, myBatch); 59 | return super.sampleTopicAssignmentsParallel(ctx); 60 | // THIS CANNOT BE ENSURED with a job stealing implementation 61 | //ensureConsistentTopicTypeCountDelta(batchLocalTopicTypeUpdates, myBatch); 62 | //return res; 63 | } 64 | 65 | 66 | } 67 | 68 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/topics/PoissonPolyaUrnHDPLDATest.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import static org.junit.Assert.assertEquals; 4 | 5 | import java.util.ArrayList; 6 | import java.util.Arrays; 7 | import java.util.List; 8 | 9 | import org.junit.Test; 10 | 11 | import cc.mallet.configuration.SimpleLDAConfiguration; 12 | 13 | public class PoissonPolyaUrnHDPLDATest { 14 | 15 | @Test 16 | public void testUpdateNrActiveTopics() { 17 | PoissonPolyaUrnHDPLDA s = new PoissonPolyaUrnHDPLDA(new SimpleLDAConfiguration()); 18 | List at = new ArrayList(); 19 | at.add(1); 20 | at.add(2); 21 | at.add(3); 22 | int [] et = new int [] {1,2}; 23 | int nt = s.updateNrActiveTopics(et, at); 24 | assertEquals(1, nt); 25 | assertEquals(1, at.size()); 26 | } 27 | 28 | @Test 29 | public void testUpdateNrActiveTopicsNoChange() { 30 | PoissonPolyaUrnHDPLDA s = new PoissonPolyaUrnHDPLDA(new SimpleLDAConfiguration()); 31 | List at = new ArrayList(); 32 | at.add(1); 33 | at.add(2); 34 | at.add(3); 35 | int [] et = new int [] {}; 36 | int nt = s.updateNrActiveTopics(et, at); 37 | assertEquals(3, nt); 38 | assertEquals(3, at.size()); 39 | } 40 | 41 | @Test 42 | public void testCalcNewTopicsEmpty() { 43 | PoissonPolyaUrnHDPLDA s = new PoissonPolyaUrnHDPLDA(new SimpleLDAConfiguration()); 44 | int [] nt = s.calcNewTopics(new ArrayList(), new int [] {}); 45 | assertEquals(0, nt.length); 46 | } 47 | 48 | @Test 49 | public void testCalcNewTopicsNoNew() { 50 | PoissonPolyaUrnHDPLDA s = new PoissonPolyaUrnHDPLDA(new SimpleLDAConfiguration()); 51 | List at = Arrays.asList(new Integer[]{1, 2, 3}); 52 | int [] nt = s.calcNewTopics(at, new int [] {1,2,3}); 53 | assertEquals(0, nt.length); 54 | } 55 | 56 | @Test 57 | public void testCalcNewTopics() { 58 | PoissonPolyaUrnHDPLDA s = new PoissonPolyaUrnHDPLDA(new SimpleLDAConfiguration()); 59 | List at = Arrays.asList(new Integer[]{1, 2, 3}); 60 | int [] nt = s.calcNewTopics(at, new int [] {2,3,4}); 61 | assertEquals(1, nt.length); 62 | assertEquals(4, nt[0]); 63 | } 64 | 65 | @Test 66 | public void testCalcNewTopicsDuplicateSampled() { 67 | PoissonPolyaUrnHDPLDA s = new PoissonPolyaUrnHDPLDA(new SimpleLDAConfiguration()); 68 | List at = Arrays.asList(new Integer[]{1, 2, 3}); 69 | int [] nt = s.calcNewTopics(at, new int [] {2,3,4,4,4,4}); 70 | assertEquals(1, nt.length); 71 | assertEquals(4, nt[0]); 72 | } 73 | 74 | @Test 75 | public void testCalcNewTopicsDisjoint() { 76 | PoissonPolyaUrnHDPLDA s = new PoissonPolyaUrnHDPLDA(new SimpleLDAConfiguration()); 77 | List at = Arrays.asList(new Integer[]{1, 2, 3}); 78 | int [] nt = s.calcNewTopics(at, new int [] 
{4,4,4,4,5,6}); 79 | assertEquals(3, nt.length); 80 | assertEquals(4, nt[0]); 81 | assertEquals(5, nt[1]); 82 | assertEquals(6, nt[2]); 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/topics/PolyaUrnSpaliasTest.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics; 2 | 3 | import static org.junit.Assert.assertEquals; 4 | 5 | import java.io.IOException; 6 | 7 | import org.junit.Test; 8 | 9 | import cc.mallet.configuration.LDAConfiguration; 10 | import cc.mallet.configuration.SimpleLDAConfiguration; 11 | import cc.mallet.types.InstanceList; 12 | import cc.mallet.util.LDALoggingUtils; 13 | import cc.mallet.util.LDAUtils; 14 | import cc.mallet.util.LoggingUtils; 15 | 16 | public class PolyaUrnSpaliasTest { 17 | 18 | SimpleLDAConfiguration getStdCfg(String whichModel, Integer numIter, Integer numBatches) { 19 | Integer numTopics = 20; 20 | Double alpha = 0.1; 21 | Double beta = 0.01; 22 | Integer rareWordThreshold = 0; 23 | Integer showTopicsInterval = 10; 24 | Integer startDiagnosticOutput = 0; 25 | 26 | SimpleLDAConfiguration config = new SimpleLDAConfiguration(new LoggingUtils(), whichModel, 27 | numTopics, alpha, beta, numIter, 28 | numBatches, rareWordThreshold, showTopicsInterval, 29 | startDiagnosticOutput,4711,"src/main/resources/datasets/nips.txt"); 30 | 31 | return config; 32 | } 33 | 34 | @Test 35 | public void testGetPhiMeans() throws IOException { 36 | String whichModel = "polyaurn"; 37 | Integer numBatches = 6; 38 | 39 | Integer numIter = 10; 40 | SimpleLDAConfiguration config = getStdCfg(whichModel, numIter, numBatches); 41 | config.setSavePhi(true); 42 | config.setPhiBurnIn(20); 43 | 44 | String dataset_fn = config.getDatasetFilename(); 45 | System.out.println("Using dataset: " + dataset_fn); 46 | System.out.println("Scheme: " + whichModel); 47 | LDALoggingUtils lu = new LoggingUtils(); 48 | lu.checkAndCreateCurrentLogDir("TestRuns"); 49 | config.setLoggingUtil(lu); 50 | 51 | InstanceList instances = LDAUtils.loadInstances(dataset_fn, 52 | "stoplist.txt", config.getRareThreshold(LDAConfiguration.RARE_WORD_THRESHOLD)); 53 | 54 | LDAGibbsSampler model = new PolyaUrnSpaliasLDA(config); 55 | System.out.println( 56 | String.format("Spalias Uncollapsed Parallell LDA (%d batches).", 57 | config.getNoBatches(LDAConfiguration.NO_BATCHES_DEFAULT))); 58 | 59 | System.out.println("Vocabulary size: " + instances.getDataAlphabet().size() + "\n"); 60 | System.out.println("Instance list is: " + instances.size()); 61 | System.out.println("Loading data instances..."); 62 | 63 | model.setRandomSeed(config.getSeed(LDAConfiguration.SEED_DEFAULT)); 64 | model.addInstances(instances); 65 | 66 | Integer noIterations = config.getNoIterations(LDAConfiguration.NO_ITER_DEFAULT); 67 | System.out.println("Starting iterations (" + noIterations + " total)."); 68 | 69 | // Runs the model 70 | model.sample(noIterations); 71 | 72 | LDASamplerWithPhi modelWithPhi = (LDASamplerWithPhi) model; 73 | double [][] means = modelWithPhi.getPhiMeans(); 74 | 75 | int burnInIter = (int)(((double)config.getPhiBurnInPercent(LDAConfiguration.PHI_BURN_IN_DEFAULT) / 100) * noIterations); 76 | assertEquals(numIter - burnInIter, ((PolyaUrnSpaliasLDA)model).getNoSampledPhi()); 77 | assertEquals(means.length,config.getNoTopics(LDAConfiguration.NO_TOPICS_DEFAULT).intValue()); 78 | assertEquals(means[0].length,instances.getDataAlphabet().size()); 79 | 80 | int noNonZero = 0; 81 | for (int i = 0; i < 
means.length; i++) { 82 | for (int j = 0; j < means[i].length; j++) { 83 | if(means[i][j]!=0) noNonZero++; 84 | } 85 | } 86 | 87 | assertEquals(means.length * means[0].length, noNonZero); 88 | } 89 | 90 | 91 | } 92 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/topics/tui/LoglikelihoodCalculatorTest.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.topics.tui; 2 | 3 | import static org.junit.Assert.assertArrayEquals; 4 | import static org.junit.Assert.assertEquals; 5 | 6 | import java.util.HashMap; 7 | import java.util.Map; 8 | 9 | import org.junit.Test; 10 | 11 | public class LoglikelihoodCalculatorTest { 12 | 13 | @Test 14 | public void test() { 15 | int M = 2; 16 | int V = 3; 17 | int numTopics; 18 | 19 | int[][] zs = {{0,2,2,3}, {3,1,1,1}}; 20 | int[][] w = {{1,1,0,2}, {2,2,0,0}}; 21 | Map vocab = new HashMap<>(); 22 | vocab.put(0, "a"); 23 | vocab.put(1, "b"); 24 | vocab.put(2, "c"); 25 | numTopics = LoglikelihoodCalculator.findNumTopics(zs); 26 | 27 | assertEquals(4,numTopics); 28 | 29 | LoglikelihoodCalculator llc = new LoglikelihoodCalculator(numTopics, w, vocab, zs); 30 | 31 | int [][] nmk = new int[M][numTopics]; 32 | int [][] nkt = new int[numTopics][V]; 33 | int [] nk = new int[numTopics]; 34 | 35 | for( int i = 0; i < zs.length; i++) { 36 | int [] row = zs[i]; 37 | llc.updateLocalTopicCounts(nmk[i],row); 38 | llc.updateTypeTopicMatrix(nkt, w[i], row); 39 | llc.updateTopicCounts(nk,row); 40 | } 41 | 42 | int [] nke = {1,3,2,2}; 43 | assertArrayEquals(nk, nke); 44 | 45 | int [] nmk0 = {1,0,2,1}; 46 | assertArrayEquals(nmk0, nmk[0]); 47 | int [] nmk1 = {0,3,0,1}; 48 | assertArrayEquals(nmk1, nmk[1]); 49 | 50 | int [] nk0t = {0,1,0}; 51 | assertArrayEquals(nk0t, nkt[0]); 52 | int [] nk1t = {2,0,1}; 53 | assertArrayEquals(nk1t, nkt[1]); 54 | 55 | System.out.println(llc.calcLL(0.1, 0.01)); 56 | } 57 | 58 | } 59 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/types/CondDirichletDraw.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.types; 2 | 3 | import static org.junit.Assert.assertTrue; 4 | 5 | import java.util.Arrays; 6 | 7 | import org.junit.Test; 8 | 9 | public class CondDirichletDraw { 10 | 11 | 12 | @Test 13 | public void test() { 14 | double phi0[] = {0.1, 0.2, 0.5, 0.2}; 15 | double phi1[] = Arrays.copyOf(phi0, phi0.length); 16 | double alpha[] = {1.0, 1.0, 1.0, 1.0}; 17 | int phi_index[] = {0, 2}; 18 | ConditionalDirichlet testDirichlet1 = new ConditionalDirichlet(1.0, alpha); 19 | phi1 = testDirichlet1.nextConditionalDistribution(phi0, phi_index); 20 | 21 | // Test that drawn phi are the same 22 | int[] phi_index_test0 = {1, 3}; 23 | for (int i = 0; i < phi_index_test0.length; i++){ 24 | double diff = phi0[phi_index_test0[i]] - phi1[phi_index_test0[i]]; 25 | boolean test = Math.abs(diff) < 0.000000001; 26 | assertTrue(test); 27 | } 28 | 29 | // Test that drawn phi are not the same 30 | int[] phi_index_test1 = phi_index; 31 | for (int i = 0; i < phi_index_test1.length; i++){ 32 | double diff = phi0[phi_index_test1[i]] - phi1[phi_index_test1[i]]; 33 | boolean test = Math.abs(diff) > 0.000000001; 34 | assertTrue(test); 35 | } 36 | // Assert that the sums of the new draw are correct 37 | double sum_full_phi = 0; 38 | for (int i = 0; i < phi1.length; i++){ 39 | sum_full_phi += phi1[i]; 40 | } 41 | double diff = sum_full_phi - 1.0; 42 | 
System.out.println(diff); 43 | boolean test = Math.abs(diff) < 0.000000001; 44 | assertTrue(test); 45 | 46 | double sum_part_phi0 = 0; 47 | double sum_part_phi1 = 0; 48 | for (int i = 0; i < phi_index_test0.length; i++){ 49 | sum_part_phi0 += phi0[phi_index_test0[i]]; 50 | sum_part_phi1 += phi1[phi_index_test0[i]]; 51 | } 52 | double diff1 = sum_part_phi0 - sum_part_phi1; 53 | boolean test1 = Math.abs(diff1) < 0.000000001; 54 | assertTrue(test1); 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/types/SamplerTest.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.types; 2 | 3 | import static org.junit.Assert.assertFalse; 4 | 5 | import org.apache.commons.math3.distribution.PoissonDistribution; 6 | import org.apache.commons.math3.stat.inference.ChiSquareTest; 7 | import org.junit.Test; 8 | 9 | public class SamplerTest { 10 | 11 | ChiSquareTest cs = new ChiSquareTest(); 12 | 13 | @Test 14 | public void testPoissonNormalApprox() { 15 | double beta = 0.01; 16 | int betaAdd = 30; 17 | double lambda = beta + betaAdd; 18 | int nrDraws = 10_000; 19 | 20 | int ll = 20; 21 | int ul = 40; 22 | long [] fepDraws = new long[ul-ll]; 23 | long [] stdDraws = new long[ul-ll]; 24 | PoissonDistribution stdPois = new PoissonDistribution(lambda); 25 | for (int i = 0; i < nrDraws; i++) { 26 | { 27 | long nd = PolyaUrnDirichlet.nextPoissonNormalApproximation(lambda); 28 | if(nd
<ul && nd>=ll) fepDraws[(int)(nd-ll)]++; 29 | } 30 | 31 | { 32 | long stdnd = stdPois.sample(); 33 | if(stdnd<ul && stdnd>
      =ll) stdDraws[(int)(stdnd-ll)]++; 34 | } 35 | 36 | } 37 | 38 | //System.out.println(Arrays.toString(fepDraws)); 39 | //System.out.println(Arrays.toString(stdDraws)); 40 | 41 | assertFalse(cs.chiSquareTestDataSetsComparison(fepDraws, stdDraws, 0.01)); 42 | } 43 | 44 | // TODO: Fix test 45 | // @Test 46 | // public void testBinomialNormalApprox() { 47 | // double p = 0.01; 48 | // int trials = 30; 49 | // int nrDraws = 10_000; 50 | // 51 | // int ll = 20; 52 | // int ul = 40; 53 | // long [] fepDraws = new long[ul-ll]; 54 | // long [] stdDraws = new long[ul-ll]; 55 | // 56 | // for (int i = 0; i < nrDraws; i++) { 57 | // { 58 | // double meanNormal = trials * p; 59 | // double variance = trials * p * (1-p); 60 | // long nd = (int) Math.round(Math.sqrt(variance) * ThreadLocalRandom.current().nextGaussian() + meanNormal); 61 | // if(nd
<ul && nd>=ll) fepDraws[(int)(nd-ll)]++; 62 | // } 63 | // 64 | // { 65 | // BinomialDistribution c_j_k = new BinomialDistribution(trials, p); 66 | // long stdnd = c_j_k.sample(); 67 | // if(stdnd<ul && stdnd>
          =ll) stdDraws[(int)(stdnd-ll)]++; 68 | // } 69 | // 70 | // } 71 | // 72 | // //System.out.println(Arrays.toString(fepDraws)); 73 | // //System.out.println(Arrays.toString(stdDraws)); 74 | // 75 | // assertFalse(cs.chiSquareTestDataSetsComparison(fepDraws, stdDraws, 0.01)); 76 | // } 77 | 78 | } 79 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/types/TestSimpleMultinomial.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.types; 2 | 3 | import static org.junit.Assert.*; 4 | 5 | import org.apache.commons.math3.stat.inference.ChiSquareTest; 6 | import org.junit.Test; 7 | 8 | public class TestSimpleMultinomial { 9 | 10 | @Test 11 | public void test() { 12 | double[] dirichletParams = {1, 1, 1}; 13 | ChiSquareTest cs = new ChiSquareTest(); 14 | Dirichlet dir = new ParallelDirichlet(dirichletParams); 15 | for (int loop = 0; loop < 10; loop++) { 16 | double [] expected = dir.nextDistribution(); 17 | SimpleMultinomial sm = new SimpleMultinomial(expected); 18 | int [] draw = sm.draw(10000); 19 | long [] observed = new long[expected.length]; 20 | for (int i = 0; i < observed.length; i++) { 21 | observed[i] = draw[i]; 22 | } 23 | assertFalse(cs.chiSquareTest(expected, observed, 0.001)); 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/types/VSDirichletTest.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.types; 2 | 3 | import static org.junit.Assert.assertEquals; 4 | import static org.junit.Assert.assertTrue; 5 | 6 | import org.junit.Test; 7 | 8 | public class VSDirichletTest { 9 | 10 | 11 | @Test 12 | public void testIndicatorSampling() { 13 | double beta = 0.1; 14 | int numTypes = 10; 15 | int [] zeroCount = new int[numTypes]; 16 | 17 | double vsPrior = 0.1; 18 | 19 | double[] dirichletParams = {1, 100, 1, 100, 100, 1200, 100, 100, 1, 1000}; 20 | Dirichlet dir = new ParallelDirichlet(dirichletParams); 21 | 22 | double[] unifDirichletParams = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; 23 | Dirichlet unifDir = new ParallelDirichlet(unifDirichletParams); 24 | 25 | double [] phiRow = unifDir.nextDistribution(); 26 | 27 | SimpleMultinomial mn = new SimpleMultinomial(dir.nextDistribution()); 28 | VariableSelectionDirichlet dist = new VSDirichlet(beta, vsPrior); 29 | 30 | int noLoops = 100; 31 | for (int loop = 0; loop < noLoops; loop++) { 32 | int [] relevantTypeTopicCounts = mn.draw(100); 33 | 34 | VariableSelectionResult res = dist.nextDistribution(relevantTypeTopicCounts, phiRow); 35 | phiRow = res.getPhi(); 36 | 37 | int [] zeroIdxs = res.getNonZeroIdxs(); 38 | 39 | assertTrue(res.getPhi().length==numTypes); 40 | assertTrue(zeroIdxs.length<=numTypes); 41 | double sum = 0.0; 42 | for (int i = 0; i < res.getPhi().length; i++) { 43 | assertTrue(res.getPhi()[i]>=0 && res.getPhi()[i]<=1); 44 | sum += res.getPhi()[i]; 45 | if(res.getPhi()[i]==0.0) { 46 | zeroCount[i]++; 47 | } 48 | } 49 | assertEquals(1.0,sum,0.0001); 50 | //System.out.println("Non zero Idxs:" + zeroIdxs.length); 51 | for (int i = 0; i < zeroIdxs.length; i++) { 52 | assertTrue(zeroIdxs[i]>=0 && zeroIdxs[i]<=numTypes); 53 | } 54 | System.out.println(arrToStr(phiRow, "Phi")); 55 | } 56 | System.out.println(arrToStr(zeroCount, "ZeroCount")); 57 | } 58 | 59 | String arrToStr(double [] arr, String title) { 60 | String res = ""; 61 | res += title + "[" + arr.length + "]:"; 62 | for (int j 
= 0; j < arr.length; j++) { 63 | res += String.format("%.4f",arr[j]) + ", "; 64 | } 65 | return res; 66 | } 67 | 68 | String arrToStr(int [] arr, String title) { 69 | String res = ""; 70 | res += title + "[" + arr.length + "]:"; 71 | for (int j = 0; j < arr.length; j++) { 72 | res += arr[j] + ", "; 73 | } 74 | return res; 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/util/LogginUtilsTest.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.util; 2 | 3 | import java.io.FileNotFoundException; 4 | import java.io.PrintWriter; 5 | 6 | import org.junit.Test; 7 | 8 | public class LogginUtilsTest { 9 | 10 | @SuppressWarnings("resource") 11 | @Test 12 | public void testCreation() throws FileNotFoundException { 13 | new NullPrintWriter(); 14 | } 15 | 16 | @Test 17 | public void testWriting() throws FileNotFoundException { 18 | LDALoggingUtils lu = new LDANullLogger(); 19 | PrintWriter pw = lu.checkCreateAndCreateLogPrinter( 20 | lu.getLogDir() + "/timing_data", 21 | "thr_Phi_sampling.txt"); 22 | pw.println("before" + "," + 20000); 23 | pw.flush(); 24 | pw.close(); 25 | } 26 | 27 | } 28 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/utils/IndexSorterTest.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.utils; 2 | 3 | import static org.junit.Assert.*; 4 | 5 | import org.junit.Test; 6 | 7 | import cc.mallet.util.IndexSorter; 8 | 9 | public class IndexSorterTest { 10 | 11 | @Test 12 | public void testSortInts() { 13 | int [] values = {4,2,7,3,8,0}; 14 | int [] si = IndexSorter.getSortedIndices(values); 15 | int [] expected = {4,2,0,3,1,5}; 16 | for (int i = 0; i < expected.length; i++) { 17 | assertEquals(expected[i], si[i]); 18 | } 19 | 20 | values = new int[0]; 21 | si = IndexSorter.getSortedIndices(values); 22 | assertEquals(0, si.length); 23 | 24 | int [] values2 = {4,2,7,-3,3,8,0}; 25 | si = IndexSorter.getSortedIndices(values2); 26 | int [] expected2 = {5,2,0,4,1,6,3}; 27 | for (int i = 0; i < expected.length; i++) { 28 | assertEquals(expected2[i], si[i]); 29 | } 30 | } 31 | 32 | @Test 33 | public void testSortDoubles() { 34 | double [] values = {4.0,2.0,7.0,3.0,8.0,0.0}; 35 | int [] si = IndexSorter.getSortedIndices(values); 36 | int [] expected = {4,2,0,3,1,5}; 37 | for (int i = 0; i < expected.length; i++) { 38 | assertEquals(expected[i], si[i]); 39 | } 40 | 41 | values = new double[0]; 42 | si = IndexSorter.getSortedIndices(values); 43 | assertEquals(0, si.length); 44 | 45 | int [] values2 = {4,2,7,-3,3,8,0}; 46 | si = IndexSorter.getSortedIndices(values2); 47 | int [] expected2 = {5,2,0,4,1,6,3}; 48 | for (int i = 0; i < expected.length; i++) { 49 | assertEquals(expected2[i], si[i]); 50 | } 51 | } 52 | 53 | @Test 54 | public void testSortMatrix() { 55 | double[][] matrix= { 56 | {1, 5}, 57 | {13, 1.55}, 58 | {12, 100.6}, 59 | {12.1, .85} }; 60 | int [] si = IndexSorter.getSortedIndices(matrix,1); 61 | int [] expected = {2,0,1,3}; 62 | for (int i = 0; i < expected.length; i++) { 63 | assertEquals(expected[i], si[i]); 64 | } 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/utils/MultinomialSampler.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.utils; 2 | 3 | import java.util.Arrays; 4 | import 
java.util.concurrent.ThreadLocalRandom; 5 | 6 | import org.apache.commons.math3.stat.inference.ChiSquareTest; 7 | 8 | public class MultinomialSampler { 9 | 10 | double [] probs; 11 | 12 | public MultinomialSampler(double[] probs) { 13 | super(); 14 | this.probs = probs; 15 | } 16 | 17 | public int generateSample() { 18 | return generateSample(probs); 19 | } 20 | 21 | public static int generateSample(double [] probs) { 22 | double U = ThreadLocalRandom.current().nextDouble(); 23 | int theSample = -1; 24 | while (U > 0.0) { 25 | theSample++; 26 | U -= probs[theSample]; 27 | } 28 | return theSample; 29 | } 30 | 31 | public static int [] multinomialSampler(double[] probs, int noSamples) { 32 | int[] multinomialSamples = new int[noSamples]; 33 | for (int i = 0; i < noSamples; i++) { 34 | multinomialSamples[i] = generateSample(probs); 35 | } 36 | return multinomialSamples; 37 | } 38 | 39 | public static void main(String [] args) { 40 | double [] pi = {2.0/15.0,7.0/15.0,6.0/15.0}; 41 | MultinomialSampler ga = new MultinomialSampler(pi); 42 | 43 | int noSamples = 20; 44 | int [] samples = new int[noSamples]; 45 | for (int i = 0; i < samples.length; i++) { 46 | samples[i] = ga.generateSample(); 47 | } 48 | long [] cnts = new long[pi.length]; 49 | for (int i = 0; i < samples.length; i++) { 50 | cnts[samples[i]]++; 51 | } 52 | 53 | double [] obsFreq = new double[cnts.length]; 54 | for (int i = 0; i < obsFreq.length; i++) { 55 | obsFreq[i] = cnts[i] / (double) noSamples; 56 | } 57 | 58 | ChiSquareTest cs = new ChiSquareTest(); 59 | if(cs.chiSquareTest(pi, cnts, 0.01)) { 60 | System.out.println("Probs: " + Arrays.toString(pi) + " are NOT equal to " + Arrays.toString(obsFreq)); 61 | } else { 62 | System.out.println("Probs: " + Arrays.toString(pi) + " ARE equal to " + Arrays.toString(obsFreq)); 63 | } 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /src/test/java/cc/mallet/utils/TestUtils.java: -------------------------------------------------------------------------------- 1 | package cc.mallet.utils; 2 | 3 | import static org.junit.Assert.assertEquals; 4 | 5 | public class TestUtils { 6 | 7 | public TestUtils() { 8 | } 9 | 10 | public static void assertEqualArrays(int[][] arr1, int[][] arr2) { 11 | assertEquals("Dimensions are not the same: " 12 | + arr2.length + "!=" + arr1.length, 13 | arr2.length, arr1.length); 14 | for (int i = 0; i < arr2.length; i++) { 15 | for (int j = 0; j < arr2[i].length; j++) { 16 | assertEquals("Collapsed and Uncollapsed token counts are not the same: " 17 | + arr2[i][j] + "!=" + arr1[i][j], 18 | arr2[i][j], arr1[i][j]); 19 | } 20 | } 21 | } 22 | 23 | public static void assertEqualArrays(double[][] arr1, double[][] arr2, double precision) { 24 | assertEquals("Dimensions are not the same: " 25 | + arr2.length + "!=" + arr1.length, 26 | arr2.length, arr1.length); 27 | for (int i = 0; i < arr2.length; i++) { 28 | for (int j = 0; j < arr2[i].length; j++) { 29 | assertEquals("Arrays are not the same: " 30 | + arr2[i][j] + "!=" + arr1[i][j], 31 | arr2[i][j], arr1[i][j], precision); 32 | } 33 | } 34 | } 35 | 36 | } 37 | -------------------------------------------------------------------------------- /src/test/resources/document_priors.txt: -------------------------------------------------------------------------------- 1 | pos, 0, 0,2,4 2 | pos, 19, 1,3,5 -------------------------------------------------------------------------------- /src/test/resources/max_doc_buf-2.cfg: 
-------------------------------------------------------------------------------- 1 | configs = large_wiki_random_100_spalias_cores_16_seed_4711 2 | no_runs = 1 3 | experiment_out_dir = stochvb_experiment 4 | 5 | [large_wiki_random_100_spalias_cores_16_seed_4711] 6 | title = large_wiki_random_100_spalias_cores_16_seed_4711 7 | description = Pubmed experiment (large_wiki_random_100_spalias_cores_16_seed_4711) 8 | dataset = src/main/resources/datasets/nips.txt 9 | scheme = spalias 10 | seed = 4711 11 | topics = 100 12 | alpha = 0.1 13 | beta = 0.01 14 | iterations = 1000 15 | batches = 16 16 | topic_batches = 16 17 | tfidf_vocab_size = 7700 18 | topic_interval = 10 19 | diagnostic_interval = -1 20 | dn_diagnostic_interval = -1 21 | results_size = 5 22 | debug = 0 23 | log_type_topic_density = true 24 | log_document_density = true 25 | max_doc_buf_size = 100000 -------------------------------------------------------------------------------- /src/test/resources/max_doc_buf-small.cfg: -------------------------------------------------------------------------------- 1 | configs = tf_idf_prune 2 | 3 | [tf_idf_prune] 4 | title = large_wiki_random_100_spalias_cores_16_seed_4711 5 | description = Pubmed experiment (large_wiki_random_100_spalias_cores_16_seed_4711) 6 | dataset = src/main/resources/datasets/nips.txt 7 | scheme = spalias 8 | seed = 4711 9 | topics = 100 10 | alpha = 0.1 11 | beta = 0.01 12 | iterations = 1000 13 | batches = 16 14 | topic_batches = 16 15 | tfidf_vocab_size = 7700 16 | topic_interval = 10 17 | diagnostic_interval = -1 18 | dn_diagnostic_interval = -1 19 | results_size = 5 20 | debug = 0 21 | log_type_topic_density = true 22 | log_document_density = true 23 | max_doc_buf_size = 10 24 | 25 | [rare_word_prune] 26 | title = large_wiki_random_100_spalias_cores_16_seed_4711 27 | description = Pubmed experiment (large_wiki_random_100_spalias_cores_16_seed_4711) 28 | dataset = src/main/resources/datasets/nips.txt 29 | scheme = spalias 30 | seed = 4711 31 | topics = 100 32 | alpha = 0.1 33 | beta = 0.01 34 | iterations = 1000 35 | batches = 16 36 | topic_batches = 16 37 | topic_interval = 10 38 | diagnostic_interval = -1 39 | dn_diagnostic_interval = -1 40 | results_size = 5 41 | debug = 0 42 | log_type_topic_density = true 43 | log_document_density = true 44 | max_doc_buf_size = 10 45 | rare_threshold = 3 46 | 47 | [default] 48 | title = large_wiki_random_100_spalias_cores_16_seed_4711 49 | description = Pubmed experiment (large_wiki_random_100_spalias_cores_16_seed_4711) 50 | dataset = src/main/resources/datasets/nips.txt 51 | scheme = spalias 52 | seed = 4711 53 | topics = 100 54 | alpha = 0.1 55 | beta = 0.01 56 | iterations = 1000 57 | batches = 16 58 | topic_batches = 16 59 | tfidf_vocab_size = 7700 60 | topic_interval = 10 61 | diagnostic_interval = -1 62 | dn_diagnostic_interval = -1 63 | results_size = 5 64 | debug = 0 65 | log_type_topic_density = true 66 | log_document_density = true 67 | max_doc_buf_size = 10 -------------------------------------------------------------------------------- /src/test/resources/max_doc_buf.cfg: -------------------------------------------------------------------------------- 1 | configs = large_wiki_random_100_spalias_cores_16_seed_4711 2 | 3 | [large_wiki_random_100_spalias_cores_16_seed_4711] 4 | title = large_wiki_random_100_spalias_cores_16_seed_4711 5 | description = Pubmed experiment (large_wiki_random_100_spalias_cores_16_seed_4711) 6 | dataset = ../datasets_pclda/wiki_random.txt 7 | scheme = spalias 8 | seed = 4711 9 | topics = 100 
10 | alpha = 0.1 11 | beta = 0.01 12 | iterations = 1000 13 | batches = 16 14 | topic_batches = 16 15 | tfidf_vocab_size = 7700 16 | topic_interval = 10 17 | diagnostic_interval = -1 18 | dn_diagnostic_interval = -1 19 | results_size = 5 20 | debug = 0 21 | log_type_topic_density = true 22 | log_document_density = true 23 | max_doc_buf_size = 471100 24 | -------------------------------------------------------------------------------- /src/test/resources/nips_document_priors.txt: -------------------------------------------------------------------------------- 1 | 0, 0,2,4,6 2 | 19, 1,3,5,7 -------------------------------------------------------------------------------- /src/test/resources/nips_topic_priors.txt: -------------------------------------------------------------------------------- 1 | 0, cell, stimulus, visual, cortex, response, spatial 2 | 19, image, images, pixel -------------------------------------------------------------------------------- /src/test/resources/special_chars.cfg: -------------------------------------------------------------------------------- 1 | configs = special 2 | 3 | [special] 4 | title = special_chars 5 | description = Special characters 6 | dataset = src/main/resources/datasets/special_chars.txt 7 | scheme = spalias 8 | seed = 4711 9 | topics = 100 10 | alpha = 0.1 11 | beta = 0.01 12 | iterations = 1000 13 | batches = 16 14 | topic_batches = 16 15 | tfidf_vocab_size = 7700 16 | topic_interval = 10 17 | diagnostic_interval = -1 18 | dn_diagnostic_interval = -1 19 | results_size = 5 20 | debug = 0 21 | log_type_topic_density = true 22 | log_document_density = true 23 | -------------------------------------------------------------------------------- /src/test/resources/topic_priors.txt: -------------------------------------------------------------------------------- 1 | 0, java, jvm, jre, NullPointerException 2 | 2, cell, control, cabinet -------------------------------------------------------------------------------- /src/test/resources/topic_priors_SmallTexts.txt: -------------------------------------------------------------------------------- 1 | 0, mother, slip 2 | 3, disk, drive -------------------------------------------------------------------------------- /stoplist-empty.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lejon/PartiallyCollapsedLDA/0c1e588c2555811a5f5616f75a17497081560ba8/stoplist-empty.txt --------------------------------------------------------------------------------