├── .gitignore ├── LICENSE ├── README.adoc ├── docs └── examples │ ├── csv.adoc │ └── twitter.adoc ├── pom.xml └── src ├── main ├── java │ └── com │ │ └── lucidworks │ │ └── spark │ │ ├── BatchSizeType.java │ │ ├── ShardIndexPartitioner.java │ │ ├── SparkApp.java │ │ ├── SparkSolrAccumulator.scala │ │ ├── example │ │ ├── hadoop │ │ │ ├── HdfsToSolrRDDProcessor.java │ │ │ └── Logs2SolrRDDProcessor.java │ │ ├── ml │ │ │ ├── MLPipeline.java │ │ │ └── UseML.java │ │ ├── query │ │ │ ├── KMeansAnomaly.java │ │ │ └── ReadTermVectors.java │ │ └── streaming │ │ │ ├── DocumentFilteringStreamProcessor.java │ │ │ └── TwitterToSolrStreamProcessor.java │ │ ├── filter │ │ └── DocFilterContext.java │ │ ├── fusion │ │ └── FusionPipelineClient.java │ │ ├── query │ │ ├── PagedResultsIterator.java │ │ ├── ResultsIterator.java │ │ ├── SolrStreamIterator.java │ │ ├── SparkSolrClientCache.java │ │ ├── StreamingExpressionResultIterator.java │ │ ├── StreamingResultsIterator.java │ │ ├── TupleStreamIterator.java │ │ └── sql │ │ │ └── SolrSQLSupport.java │ │ ├── rdd │ │ ├── SolrJavaRDD.java │ │ └── SolrStreamJavaRDD.java │ │ └── util │ │ ├── EmbeddedSolrServerFactory.java │ │ ├── FusionAuthHttpClient.java │ │ ├── ObjectSizeCalculator.java │ │ ├── SQLQuerySupport.java │ │ ├── ScalaUtil.java │ │ └── Utils.java ├── resources │ └── embedded │ │ ├── solr.xml │ │ └── solrconfig.xml └── scala │ ├── com │ └── lucidworks │ │ └── spark │ │ ├── JsonFacetUtil.scala │ │ ├── Logging.scala │ │ ├── Partitioner.scala │ │ ├── SolrConf.scala │ │ ├── SolrRDDPartition.scala │ │ ├── SolrRelation.scala │ │ ├── SolrStreamWriter.scala │ │ ├── SparkSolrAccumulatorContext.scala │ │ ├── TimePartitioningQuery.scala │ │ ├── analysis │ │ └── LuceneTextAnalyzer.scala │ │ ├── example │ │ ├── NewRDDExample.scala │ │ ├── RDDExample.scala │ │ ├── events │ │ │ └── EventsimIndexer.scala │ │ ├── ml │ │ │ ├── MLPipelineScala.scala │ │ │ └── NewsgroupsIndexer.scala │ │ └── query │ │ │ ├── QueryBenchmark.scala │ │ │ └── WordCount.scala │ │ ├── ml │ │ └── feature │ │ │ └── LuceneTextAnalyzerTransformer.scala │ │ ├── rdd │ │ ├── SelectSolrRDD.scala │ │ ├── SolrRDD.scala │ │ └── StreamingSolrRDD.scala │ │ └── util │ │ ├── ConfigurationConstants.scala │ │ ├── Constants.scala │ │ ├── JavaApiHelper.scala │ │ ├── JsonUtil.scala │ │ ├── QueryConstants.scala │ │ ├── SolrDataFrameImplicits.scala │ │ ├── SolrQuerySupport.scala │ │ ├── SolrRelationUtil.scala │ │ └── SolrSupport.scala │ ├── org │ └── apache │ │ └── spark │ │ ├── ml │ │ └── HasInputColsTransformer.scala │ │ └── solr │ │ └── SparkInternalObjects.scala │ └── solr │ └── DefaultSource.scala └── test ├── java └── com │ └── lucidworks │ └── spark │ ├── RDDProcessorTestBase.java │ ├── SolrRDDTest.java │ ├── SolrRelationTest.java │ ├── SolrSqlTest.java │ ├── StreamProcessorTestBase.java │ ├── TestSolrCloudClusterSupport.java │ ├── analysis │ └── LuceneTextAnalyzerTest.java │ ├── example │ ├── hadoop │ │ ├── HdfsToSolrRDDProcessorTest.java │ │ └── Logs2SolrRDDProcessorTest.java │ ├── query │ │ ├── BuildQueryTest.java │ │ ├── ReadTermVectorsTest.java │ │ └── WordCountTest.java │ └── streaming │ │ ├── BasicIndexingTest.java │ │ └── DocumentFilteringStreamProcessorTest.java │ ├── fusion │ └── FusionPipelineClientTest.java │ ├── ml │ └── feature │ │ └── LuceneTextAnalyzerTransformerTest.java │ ├── query │ ├── StreamingResultsIteratorTest.java │ └── sql │ │ └── SolrSQLSupportTest.java │ ├── solr │ └── TestEmbeddedSolrServer.java │ └── util │ └── EventsimUtil.java ├── resources ├── conf │ ├── lang │ │ └── 
stopwords_en.txt │ ├── managed-schema │ └── solrconfig.xml ├── custom-solrconfig.xml ├── eventsim │ ├── fields_schema.json │ └── sample_eventsim_1000.json ├── hive-site.xml ├── log4j.properties ├── ml-100k │ ├── README │ ├── movielens_movies.json │ ├── movielens_ratings.json │ ├── movielens_ratings_10k.json │ └── movielens_users.json ├── solr.xml ├── test-data │ ├── child_documents.json │ ├── em_sample.json │ ├── events.json │ ├── nyc_yellow_taxi_sample_1k.csv │ ├── oneusagov │ │ └── oneusagov_sample.json │ └── simple.csv └── wire-mock-props.xml └── scala └── com └── lucidworks └── spark ├── EventsimTestSuite.scala ├── MovieLensTestSuite.scala ├── RDDTestSuite.scala ├── RelationTestSuite.scala ├── SparkSolrFunSuite.scala ├── TestChildDocuments.scala ├── TestFacetQuerying.scala ├── TestFramework.scala ├── TestIndexing.scala ├── TestPartitionByTimeQuerySupport.scala ├── TestQuerying.scala ├── TestShardSplits.scala ├── TestSolrRelation.scala ├── TestSolrStreamWriter.scala ├── analysis └── LuceneTextAnalyzerSuite.scala ├── examples └── TwitterTestSuite.scala ├── ml ├── SparkMLExamples.scala └── feature │ └── LuceneTextAnalyzerTransformerSuite.scala └── util └── SolrCloudUtil.scala /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | 3 | # Mobile Tools for Java (J2ME) 4 | .mtj.tmp/ 5 | 6 | # Package Files # 7 | *.jar 8 | *.war 9 | *.ear 10 | 11 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 12 | hs_err_pid* 13 | 14 | dependency-reduced-pom.xml 15 | spark-solr.iml 16 | spark-solr.ipr 17 | spark-solr.iws 18 | 19 | .idea/ 20 | target/ 21 | 22 | -------------------------------------------------------------------------------- /docs/examples/csv.adoc: -------------------------------------------------------------------------------- 1 | == Indexing and Querying NYC yellow taxi csv data 2 | 3 | `localhost:9983` will be used as zkhost in this example. Instead of the main jar file, the shaded artifact should be used for these examples. 4 | 5 | Once the shaded artifact is downloaded or built, it can be imported to the spark-shell by using the `--jars` config 6 | 7 | ./bin/spark-shell --jars spark-solr-3.0.0-alpha-shaded.jar 8 | 9 | === Writing data 10 | 11 | * Create a collection in Solr to index data to. 12 | 13 | Example: The below HTTP call creates a Solr collection with the name 'test-spark-solr' 14 | curl -X GET "http://localhost:8983/solr/admin/collections?action=create&name=test-spark-solr&collection.configName=techproducts&numShards=2&maxShardsPerNode=2" 15 | 16 | * Read the csv file as a Spark DataFrame. 
The CSV file used in this example is located https://github.com/lucidworks/spark-solr/blob/master/src/test/resources/test-data/nyc_yellow_taxi_sample_1k.csv[here] 17 | 18 | [source,scala] 19 | val csvFileLocation = "src/test/resources/test-data/nyc_yellow_taxi_sample_1k.csv" 20 | var csvDF = spark.read.format("com.databricks.spark.csv") 21 | .option("header", "true") 22 | .option("inferSchema", "true") 23 | .load(csvFileLocation) 24 | 25 | * Clean up the data and create the `pickup` and `dropoff` fields 26 | 27 | [source,scala] 28 | -------------- 29 | // Filter out invalid lat/lon cols 30 | csvDF = csvDF.filter("pickup_latitude >= -90 AND pickup_latitude <= 90 AND pickup_longitude >= -180 AND pickup_longitude <= 180") 31 | csvDF = csvDF.filter("dropoff_latitude >= -90 AND dropoff_latitude <= 90 AND dropoff_longitude >= -180 AND dropoff_longitude <= 180") 32 | 33 | // concat the lat/lon cols into a single value expected by solr location fields 34 | csvDF = csvDF.withColumn("pickup", concat_ws(",", col("pickup_latitude"),col("pickup_longitude"))).drop("pickup_latitude").drop("pickup_longitude") 35 | csvDF = csvDF.withColumn("dropoff", concat_ws(",", col("dropoff_latitude"),col("dropoff_longitude"))).drop("dropoff_latitude").drop("dropoff_longitude") 36 | -------------- 37 | 38 | * Write data to Solr. Before writing, spark-solr tries to create any fields that exist in `csvDF` but not in Solr via the Schema API. For the Schema API to be usable, the https://cwiki.apache.org/confluence/display/solr/Schema+Factory+Definition+in+SolrConfig[ManagedIndexSchemaFactory] must be enabled in Solr. If you do not want to enable the managed schema, manually create all of the fields from the CSV file in Solr first. 39 | 40 | [source,scala] 41 | -------------- 42 | val options = Map( 43 | "zkhost" -> "localhost:9983", 44 | "collection" -> "test-spark-solr", 45 | "gen_uniq_key" -> "true" // Generate unique key if the 'id' field does not exist 46 | ) 47 | 48 | // Write to Solr 49 | csvDF.write.format("solr").options(options).mode(org.apache.spark.sql.SaveMode.Overwrite).save 50 | -------------- 51 | 52 | * 999 documents should now appear in Solr. If not all of the documents are visible yet, an explicit commit can be issued via an HTTP call. 53 | 54 | === Reading data 55 | 56 | In this section, we read the CSV data that was indexed to the Solr collection `test-spark-solr`. 57 | 58 | * Load the Solr collection as a DataFrame 59 | 60 | [source,scala] 61 | -------------- 62 | val options = Map( 63 | "zkHost" -> "localhost:9983", 64 | "collection" -> "test-spark-solr" 65 | ) 66 | 67 | val df = spark.read.format("solr").options(options).load 68 | -------------- 69 | 70 | * Every DataFrame has a schema. 
You can use the `printSchema()` function to get information about the fields available in the taxi trips DataFrame 71 | 72 | [source,scala] 73 | scala> df.printSchema() 74 | root 75 | |-- improvement_surcharge: double (nullable = true) 76 | |-- vendor_id: long (nullable = true) 77 | |-- trip_distance: double (nullable = true) 78 | |-- tolls_amount: double (nullable = true) 79 | |-- tip_amount: double (nullable = true) 80 | |-- id: string (nullable = false) 81 | |-- pickup: string (nullable = true) 82 | |-- payment_type: long (nullable = true) 83 | |-- fare_amount: double (nullable = true) 84 | |-- passenger_count: long (nullable = true) 85 | |-- dropoff: string (nullable = true) 86 | |-- store_and_fwd_flag: string (nullable = true) 87 | |-- extra: double (nullable = true) 88 | |-- dropoff_datetime: timestamp (nullable = true) 89 | |-- rate_code_id: long (nullable = true) 90 | |-- total_amount: double (nullable = true) 91 | |-- pickup_datetime: timestamp (nullable = true) 92 | |-- mta_tax: double (nullable = true) 93 | 94 | * To be able to query with SQL syntax, we need to register this DataFrame as a table 95 | 96 | [source,scala] 97 | df.registerTempTable("trips") 98 | 99 | * Fire off SQL queries 100 | 101 | [source,scala] 102 | -------------- 103 | // Cache the DataFrame for efficiency. See http://spark.apache.org/docs/latest/sql-programming-guide.html#caching-data-in-memory 104 | scala> df.cache() 105 | scala> sqlContext.sql("SELECT avg(tip_amount), avg(fare_amount) FROM trips").show() 106 | +-----------------+-----------------+ 107 | |              _c0|              _c1| 108 | +-----------------+-----------------+ 109 | |1.630050050050051|12.27087087087087| 110 | +-----------------+-----------------+ 111 | 112 | scala> sqlContext.sql("SELECT max(tip_amount), max(fare_amount) FROM trips WHERE trip_distance > 10").show() 113 | +-----+----+ 114 | |  _c0| _c1| 115 | +-----+----+ 116 | |16.44|83.5| 117 | +-----+----+ 118 | -------------- 119 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/BatchSizeType.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark; 2 | 3 | /** 4 | * Specifies what type of Solr batch you are using: 5 | * A) based on the number of documents in the batch, 6 | * or B) based on the number of bytes in the batch. 7 | */ 8 | public enum BatchSizeType { 9 | NUM_DOCS, 10 | NUM_BYTES 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/ShardIndexPartitioner.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark; 2 | 3 | import com.lucidworks.spark.util.SolrSupport; 4 | import org.apache.solr.client.solrj.impl.CloudSolrClient; 5 | import org.apache.solr.common.SolrInputDocument; 6 | import org.apache.solr.common.cloud.*; 7 | import org.apache.spark.Partitioner; 8 | 9 | import java.io.Serializable; 10 | import java.util.Collection; 11 | import java.util.HashMap; 12 | import java.util.Map; 13 | 14 | /** 15 | * Partition using SolrCloud's sharding scheme. 
16 | */ 17 | public class ShardIndexPartitioner extends Partitioner implements Serializable { 18 | 19 | protected String zkHost; 20 | protected String collection; 21 | protected String idField; 22 | 23 | protected transient CloudSolrClient cloudSolrServer = null; 24 | protected transient DocCollection docCollection = null; 25 | protected transient Map shardIndexCache = null; 26 | 27 | public ShardIndexPartitioner(String zkHost, String collection) { 28 | this(zkHost, collection, "id"); 29 | } 30 | 31 | public ShardIndexPartitioner(String zkHost, String collection, String idField) { 32 | this.zkHost = zkHost; 33 | this.collection = collection; 34 | this.idField = idField; 35 | } 36 | 37 | @Override 38 | public int numPartitions() { 39 | return getDocCollection().getActiveSlices().size(); 40 | } 41 | 42 | public String getShardId(SolrInputDocument doc) { 43 | return getShardId((String)doc.getFieldValue(idField)); 44 | } 45 | 46 | public String getShardId(String docId) { 47 | DocCollection dc = getDocCollection(); 48 | Slice slice = dc.getRouter().getTargetSlice(docId, null, null, null, dc); 49 | return slice.getName(); 50 | } 51 | 52 | @Override 53 | public int getPartition(Object o) { 54 | 55 | Object docId = null; 56 | if (o instanceof SolrInputDocument) { 57 | SolrInputDocument doc = (SolrInputDocument)o; 58 | docId = doc.getFieldValue(idField); 59 | if (docId == null) 60 | throw new IllegalArgumentException("SolrInputDocument must contain a non-null value for "+idField); 61 | } else { 62 | docId = o; 63 | } 64 | 65 | if (!(docId instanceof String)) 66 | throw new IllegalArgumentException("Only String document IDs are supported by this Partitioner!"); 67 | 68 | DocCollection dc = getDocCollection(); 69 | Slice slice = dc.getRouter().getTargetSlice((String)docId, null, null, null, dc); 70 | return getShardIndex(slice.getName(), dc); 71 | } 72 | 73 | protected final synchronized int getShardIndex(String shardId, DocCollection dc) { 74 | if (shardIndexCache == null) 75 | shardIndexCache = new HashMap<>(20); 76 | 77 | Integer idx = shardIndexCache.get(shardId); 78 | if (idx != null) 79 | return idx; // meh auto-boxing 80 | 81 | int s = 0; 82 | for (Slice slice : dc.getSlices()) { 83 | if (shardId.equals(slice.getName())) { 84 | shardIndexCache.put(shardId, s); 85 | return s; 86 | } 87 | ++s; 88 | } 89 | throw new IllegalStateException("Cannot find index of shard '"+shardId+"' in collection: "+collection); 90 | } 91 | 92 | protected final synchronized DocCollection getDocCollection() { 93 | if (docCollection == null) { 94 | ZkStateReader zkStateReader = getCloudSolrServer().getZkStateReader(); 95 | docCollection = zkStateReader.getClusterState().getCollection(collection); 96 | 97 | // do basic checks once 98 | DocRouter docRouter = docCollection.getRouter(); 99 | if (docRouter instanceof ImplicitDocRouter) 100 | throw new IllegalStateException("Implicit document routing not supported by this Partitioner!"); 101 | Collection shards = getDocCollection().getSlices(); 102 | if (shards == null || shards.size() == 0) 103 | throw new IllegalStateException("Collection '"+collection+"' does not have any shards!"); 104 | } 105 | return docCollection; 106 | } 107 | 108 | protected final synchronized CloudSolrClient getCloudSolrServer() { 109 | if (cloudSolrServer == null) 110 | cloudSolrServer = SolrSupport.getCachedCloudClient(zkHost); 111 | return cloudSolrServer; 112 | } 113 | } 114 | -------------------------------------------------------------------------------- 
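A minimal usage sketch for the partitioner above (not taken from the repository): it assumes a `JavaPairRDD` keyed by document id, and the ZooKeeper connect string and collection name shown are illustrative values. Calling `partitionBy` with this partitioner lines each Spark partition up with one SolrCloud shard, so a subsequent per-partition indexing step only needs to talk to a single shard.

[source,java]
--------------
import com.lucidworks.spark.ShardIndexPartitioner;
import org.apache.solr.common.SolrInputDocument;
import org.apache.spark.api.java.JavaPairRDD;

public class ShardPartitionerSketch {
  // Repartition (id, doc) pairs so each Spark partition maps to exactly one SolrCloud shard.
  // "localhost:9983" and "collection1" are assumed example values.
  public static JavaPairRDD<String, SolrInputDocument> alignWithShards(
      JavaPairRDD<String, SolrInputDocument> docs) {
    ShardIndexPartitioner byShard = new ShardIndexPartitioner("localhost:9983", "collection1");
    // getPartition(...) resolves the target shard for each document id via the collection's DocRouter
    return docs.partitionBy(byShard);
  }
}
--------------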
/src/main/java/com/lucidworks/spark/SparkSolrAccumulator.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark 2 | 3 | import java.lang.Long 4 | import org.apache.spark.util.AccumulatorV2 5 | 6 | class SparkSolrAccumulator extends AccumulatorV2[java.lang.Long, java.lang.Long] { 7 | private var _count = 0L 8 | 9 | override def isZero: Boolean = _count == 0 10 | 11 | override def copy(): SparkSolrAccumulator = { 12 | val newAcc = new SparkSolrAccumulator 13 | newAcc._count = this._count 14 | newAcc 15 | } 16 | 17 | override def reset(): Unit = { 18 | _count = 0L 19 | } 20 | 21 | def count: Long = _count 22 | 23 | override def value: Long = _count 24 | 25 | def inc(): Unit = _count += 1 26 | 27 | override def add(v: Long): Unit = { 28 | _count += v 29 | } 30 | 31 | override def merge(other: AccumulatorV2[Long, Long]): Unit = other match { 32 | case o: SparkSolrAccumulator => 33 | _count += o.count 34 | case _ => 35 | throw new UnsupportedOperationException( 36 | s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/example/hadoop/HdfsToSolrRDDProcessor.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.hadoop; 2 | 3 | import com.lucidworks.spark.BatchSizeType; 4 | import com.lucidworks.spark.util.SolrSupport; 5 | import com.lucidworks.spark.SparkApp; 6 | import org.apache.commons.cli.CommandLine; 7 | import org.apache.commons.cli.Option; 8 | import org.apache.log4j.Logger; 9 | import org.apache.solr.client.solrj.impl.CloudSolrClient; 10 | import org.apache.solr.common.SolrInputDocument; 11 | import org.apache.spark.SparkConf; 12 | import org.apache.spark.api.java.JavaPairRDD; 13 | import org.apache.spark.api.java.JavaRDD; 14 | import org.apache.spark.api.java.JavaSparkContext; 15 | import org.apache.spark.api.java.function.PairFunction; 16 | import scala.Tuple2; 17 | 18 | public class HdfsToSolrRDDProcessor implements SparkApp.RDDProcessor { 19 | 20 | public static Logger log = Logger.getLogger(HdfsToSolrRDDProcessor.class); 21 | 22 | public String getName() { 23 | return "hdfs-to-solr"; 24 | } 25 | 26 | public Option[] getOptions() { 27 | return new Option[]{ 28 | Option.builder("hdfsPath") 29 | .argName("PATH") 30 | .hasArg() 31 | .required(false) 32 | .desc("HDFS path identifying the directories / files to index") 33 | .build(), 34 | Option.builder("queueSize") 35 | .argName("INT") 36 | .hasArg() 37 | .required(false) 38 | .desc("Queue size for ConcurrentUpdateSolrClient; default is 1000") 39 | .build(), 40 | Option.builder("numRunners") 41 | .argName("INT") 42 | .hasArg() 43 | .required(false) 44 | .desc("Number of runner threads per ConcurrentUpdateSolrClient instance; default is 2") 45 | .build(), 46 | Option.builder("pollQueueTime") 47 | .argName("INT") 48 | .hasArg() 49 | .required(false) 50 | .desc("Number of millis to wait until CUSS sees a doc on the queue before it closes the current request and starts another; default is 20 ms") 51 | .build() 52 | }; 53 | } 54 | 55 | // Benchmarking dataset generated by Solr Scale Toolkit 56 | private static final String[] pigSchema = 57 | ("id,integer1_i,integer2_i,long1_l,long2_l,float1_f,float2_f,double1_d,double2_d,timestamp1_tdt," + 58 | 
"timestamp2_tdt,string1_s,string2_s,string3_s,boolean1_b,boolean2_b,text1_en,text2_en,text3_en,random_bucket").split(","); 59 | 60 | public int run(SparkConf conf, CommandLine cli) throws Exception { 61 | try (JavaSparkContext jsc = new JavaSparkContext(conf)) { 62 | JavaRDD textFiles = jsc.textFile(cli.getOptionValue("hdfsPath")); 63 | JavaPairRDD pairs = textFiles.mapToPair(new PairFunction() { 64 | public Tuple2 call(String line) throws Exception { 65 | SolrInputDocument doc = new SolrInputDocument(); 66 | String[] row = line.split("\t"); 67 | if (row.length != pigSchema.length) 68 | return null; 69 | 70 | for (int c = 0; c < row.length; c++) 71 | if (row[c] != null && row[c].length() > 0) 72 | doc.setField(pigSchema[c], row[c]); 73 | 74 | return new Tuple2<>((String) doc.getFieldValue("id"), doc); 75 | } 76 | }); 77 | 78 | String zkHost = cli.getOptionValue("zkHost", "localhost:9983"); 79 | String collection = cli.getOptionValue("collection", "collection1"); 80 | int queueSize = Integer.parseInt(cli.getOptionValue("queueSize", "1000")); 81 | int numRunners = Integer.parseInt(cli.getOptionValue("numRunners", "2")); 82 | int pollQueueTime = Integer.parseInt(cli.getOptionValue("pollQueueTime", "20")); 83 | //SolrSupport.streamDocsIntoSolr(zkHost, collection, "id", pairs, queueSize, numRunners, pollQueueTime); 84 | SolrSupport.indexDocs(zkHost, collection, 100, BatchSizeType.NUM_DOCS, pairs.values().rdd()); 85 | 86 | // send a final commit in case soft auto-commits are not enabled 87 | CloudSolrClient cloudSolrClient = SolrSupport.getCachedCloudClient(zkHost); 88 | cloudSolrClient.setDefaultCollection(collection); 89 | cloudSolrClient.commit(true, true); 90 | } 91 | return 0; 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/example/hadoop/Logs2SolrRDDProcessor.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.hadoop; 2 | 3 | import com.lucidworks.spark.util.SolrSupport; 4 | import com.lucidworks.spark.SparkApp; 5 | import org.apache.commons.cli.CommandLine; 6 | import org.apache.commons.cli.Option; 7 | import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream; 8 | import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream; 9 | import org.apache.log4j.Logger; 10 | import org.apache.solr.client.solrj.SolrClient; 11 | import org.apache.solr.common.SolrInputDocument; 12 | import org.apache.spark.SparkConf; 13 | import org.apache.spark.api.java.JavaSparkContext; 14 | import org.apache.spark.api.java.function.VoidFunction; 15 | import org.apache.spark.input.PortableDataStream; 16 | import scala.Tuple2; 17 | 18 | import java.io.BufferedReader; 19 | import java.io.InputStream; 20 | import java.io.InputStreamReader; 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | import java.util.zip.ZipInputStream; 24 | 25 | import scala.collection.JavaConverters; 26 | 27 | public class Logs2SolrRDDProcessor implements SparkApp.RDDProcessor { 28 | 29 | public static Logger log = Logger.getLogger(Logs2SolrRDDProcessor.class); 30 | 31 | public String getName() { return "logs2solr"; } 32 | 33 | public Option[] getOptions() { 34 | return new Option[]{ 35 | Option.builder("hdfsPath") 36 | .argName("PATH") 37 | .hasArg() 38 | .required(false) 39 | .desc("HDFS path identifying the directories / files to index") 40 | .build() 41 | }; 42 | } 43 | 44 | public int run(SparkConf conf, CommandLine cli) 
throws Exception { 45 | try(JavaSparkContext jsc = new JavaSparkContext(conf)){ 46 | final String zkHost = cli.getOptionValue("zkHost", "localhost:9983"); 47 | final String collection = cli.getOptionValue("collection", "collection1"); 48 | final int batchSize = Integer.parseInt(cli.getOptionValue("batchSize", "1000")); 49 | jsc.binaryFiles(cli.getOptionValue("hdfsPath")).foreach( 50 | new VoidFunction>() { 51 | public void call(Tuple2 t2) throws Exception { 52 | final SolrClient solrServer = SolrSupport.getCachedCloudClient(zkHost); 53 | List batch = new ArrayList(batchSize); 54 | String path = t2._1(); 55 | BufferedReader br = null; 56 | String line = null; 57 | int lineNum = 0; 58 | try { 59 | br = new BufferedReader(new InputStreamReader(openPortableDataStream(t2._2()), "UTF-8")); 60 | while ((line = br.readLine()) != null) { 61 | ++lineNum; 62 | SolrInputDocument doc = new SolrInputDocument(); 63 | doc.setField("id", path + ":" + lineNum); 64 | doc.setField("path_s", path); 65 | doc.setField("line_t", line); 66 | batch.add(doc); 67 | if (batch.size() >= batchSize) 68 | SolrSupport.sendBatchToSolr(solrServer, collection, JavaConverters.collectionAsScalaIterable(batch)); 69 | 70 | if (lineNum % 10000 == 0) 71 | log.info("Sent " + lineNum + " docs to Solr from " + path); 72 | } 73 | if (!batch.isEmpty()) 74 | SolrSupport.sendBatchToSolr(solrServer, collection, JavaConverters.collectionAsScalaIterable(batch)); 75 | } catch (Exception exc) { 76 | log.error("Failed to read '" + path + "' due to: " + exc); 77 | } finally { 78 | if (br != null) { 79 | try { 80 | br.close(); 81 | } catch (Exception ignore) { 82 | } 83 | } 84 | } 85 | } 86 | 87 | InputStream openPortableDataStream(PortableDataStream pds) throws Exception { 88 | InputStream in = null; 89 | String path = pds.getPath(); 90 | log.info("Opening InputStream to " + path); 91 | if (path.endsWith(".zip")) { 92 | try(ZipInputStream zipIn = new ZipInputStream(pds.open())) { 93 | zipIn.getNextEntry(); 94 | in = zipIn; 95 | } 96 | } else if (path.endsWith(".bz2")) { 97 | in = new BZip2CompressorInputStream(pds.open()); 98 | } else if (path.endsWith(".gz")) { 99 | in = new GzipCompressorInputStream(pds.open()); 100 | } 101 | return in; 102 | } 103 | }); 104 | } 105 | return 0; 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/example/ml/UseML.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.ml; 2 | 3 | import com.lucidworks.spark.SparkApp; 4 | import org.apache.commons.cli.CommandLine; 5 | import org.apache.commons.cli.Option; 6 | import org.apache.spark.SparkConf; 7 | import org.apache.spark.api.java.JavaSparkContext; 8 | import org.apache.spark.ml.PipelineModel; 9 | import org.apache.spark.ml.tuning.CrossValidatorModel; 10 | import org.apache.spark.sql.*; 11 | import org.apache.spark.sql.types.StructType; 12 | 13 | import java.util.Collections; 14 | import java.util.HashMap; 15 | import java.util.List; 16 | import java.util.Map; 17 | 18 | public class UseML implements SparkApp.RDDProcessor { 19 | 20 | @Override 21 | public String getName() { 22 | return "use-ml"; 23 | } 24 | 25 | @Override 26 | public Option[] getOptions() { 27 | return new Option[0]; 28 | } 29 | 30 | @Override 31 | public int run(SparkConf conf, CommandLine cli) throws Exception { 32 | 33 | long startMs = System.currentTimeMillis(); 34 | 35 | conf.set("spark.ui.enabled", "false"); 36 | 37 | SparkSession 
sparkSession = SparkSession.builder().config(conf).getOrCreate(); 38 | try (JavaSparkContext jsc = new JavaSparkContext(sparkSession.sparkContext())) { 39 | 40 | long diffMs = (System.currentTimeMillis() - startMs); 41 | System.out.println(">> took " + diffMs + " ms to create SparkSession"); 42 | 43 | Map<String, String> options = new HashMap<>(); 44 | options.put("zkhost", "localhost:9983"); 45 | options.put("collection", "ml20news"); 46 | options.put("query", "content_txt:[* TO *]"); 47 | options.put("fields", "content_txt"); 48 | 49 | Dataset<Row> solrData = sparkSession.read().format("solr").options(options).load(); 50 | Dataset<Row> sample = solrData.sample(false, 0.1d, 5150).select("content_txt"); 51 | List<Row> rows = sample.collectAsList(); 52 | System.out.println(">> loaded " + rows.size() + " docs to classify"); 53 | 54 | StructType schema = sample.schema(); 55 | 56 | CrossValidatorModel cvModel = CrossValidatorModel.load("ml-pipeline-model"); 57 | PipelineModel bestModel = (PipelineModel) cvModel.bestModel(); 58 | 59 | int r = 0; 60 | startMs = System.currentTimeMillis(); 61 | for (Object o : rows) { 62 | Row next = (Row) o; 63 | Row oneRow = RowFactory.create(next.getString(0)); 64 | Dataset<Row> oneRowDF = sparkSession.createDataFrame(Collections.singletonList(oneRow), schema); 65 | Dataset<Row> scored = bestModel.transform(oneRowDF); 66 | Object o1 = scored.collectAsList().get(0); 67 | Row scoredRow = (Row) o1; 68 | String predictedLabel = scoredRow.getString(scoredRow.fieldIndex("predictedLabel")); 69 | 70 | // an actual app would save the predictedLabel 71 | //System.out.println(">> for row["+r+"], model returned "+scoredRows.length+" rows, "+scoredRows[0]); 72 | 73 | r++; 74 | } 75 | diffMs = (System.currentTimeMillis() - startMs); 76 | System.out.println(">> took " + diffMs + " ms to score " + rows.size() + " docs"); 77 | } 78 | return 0; 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/example/query/ReadTermVectors.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.query; 2 | 3 | import com.lucidworks.spark.SparkApp; 4 | import com.lucidworks.spark.rdd.SolrJavaRDD; 5 | import org.apache.commons.cli.CommandLine; 6 | import org.apache.commons.cli.Option; 7 | import org.apache.solr.client.solrj.SolrQuery; 8 | import org.apache.spark.SparkConf; 9 | import org.apache.spark.api.java.JavaSparkContext; 10 | 11 | import java.util.ArrayList; 12 | import java.util.List; 13 | 14 | /** 15 | * Generate a {@code JavaRDD} from term vector information in the Solr index. 
16 | */ 17 | public class ReadTermVectors implements SparkApp.RDDProcessor { 18 | 19 | public String getName() { 20 | return "term-vectors"; 21 | } 22 | 23 | public Option[] getOptions() { 24 | return new Option[]{ 25 | Option.builder("query") 26 | .argName("QUERY") 27 | .hasArg() 28 | .required(false) 29 | .desc("URL encoded Solr query to send to Solr; default is *:*") 30 | .build(), 31 | Option.builder("field") 32 | .argName("FIELD") 33 | .hasArg() 34 | .required(true) 35 | .desc("Field to generate term vectors from") 36 | .build(), 37 | Option.builder("numFeatures") 38 | .argName("NUM") 39 | .hasArg() 40 | .required(false) 41 | .desc("Number of features; defaults to 500") 42 | .build(), 43 | Option.builder("numIterations") 44 | .argName("NUM") 45 | .hasArg() 46 | .required(false) 47 | .desc("Number of iterations for K-Means clustering; defaults to 20") 48 | .build(), 49 | Option.builder("numClusters") 50 | .argName("NUM") 51 | .hasArg() 52 | .required(false) 53 | .desc("Number of clusters (k) for K-Means clustering; defaults to 5") 54 | .build() 55 | }; 56 | } 57 | 58 | public int run(SparkConf conf, CommandLine cli) throws Exception { 59 | 60 | String zkHost = cli.getOptionValue("zkHost", "localhost:9983"); 61 | String collection = cli.getOptionValue("collection", "collection1"); 62 | String queryStr = cli.getOptionValue("query", "*:*"); 63 | String field = cli.getOptionValue("field"); 64 | int numFeatures = Integer.parseInt(cli.getOptionValue("numFeatures", "500")); 65 | int numClusters = Integer.parseInt(cli.getOptionValue("numClusters", "5")); 66 | int numIterations = Integer.parseInt(cli.getOptionValue("numIterations", "20")); 67 | 68 | JavaSparkContext jsc = new JavaSparkContext(conf); 69 | 70 | final SolrQuery solrQuery = new SolrQuery(queryStr); 71 | solrQuery.setFields("id"); 72 | 73 | // sorts are needed for deep-paging 74 | List sorts = new ArrayList(); 75 | sorts.add(new SolrQuery.SortClause("id", "asc")); 76 | sorts.add(new SolrQuery.SortClause("created_at_tdt", "asc")); 77 | solrQuery.setSorts(sorts); 78 | 79 | SolrJavaRDD solrRDD = SolrJavaRDD.get(zkHost, collection, jsc.sc()); 80 | 81 | //TODO: Commented out until we implement term vectors in Base RDD 82 | // // query Solr for term vectors 83 | // JavaRDD termVectorsFromSolr = 84 | // solrRDD.queryTermVectors(solrQuery, field, numFeatures); 85 | // termVectorsFromSolr.cache(); 86 | // 87 | // // Cluster the data using KMeans 88 | // KMeansModel clusters = KMeans.train(termVectorsFromSolr.rdd(), numClusters, numIterations); 89 | // 90 | // // TODO: do something interesting with the clusters 91 | // 92 | // // Evaluate clustering by computing Within Set Sum of Squared Errors 93 | // double WSSSE = clusters.computeCost(termVectorsFromSolr.rdd()); 94 | // System.out.println("Within Set Sum of Squared Errors = " + WSSSE); 95 | 96 | jsc.stop(); 97 | 98 | return 0; 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/example/streaming/DocumentFilteringStreamProcessor.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.streaming; 2 | 3 | import com.lucidworks.spark.util.SolrSupport; 4 | import com.lucidworks.spark.SparkApp; 5 | import com.lucidworks.spark.filter.DocFilterContext; 6 | import org.apache.commons.cli.CommandLine; 7 | import org.apache.commons.cli.Option; 8 | import org.apache.log4j.Logger; 9 | import org.apache.solr.client.solrj.SolrQuery; 10 | import 
org.apache.solr.common.SolrInputDocument; 11 | import org.apache.spark.api.java.function.Function; 12 | import org.apache.spark.streaming.api.java.JavaDStream; 13 | import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; 14 | import org.apache.spark.streaming.api.java.JavaStreamingContext; 15 | import org.apache.spark.streaming.dstream.DStream; 16 | import org.apache.spark.streaming.twitter.TwitterUtils; 17 | import twitter4j.Status; 18 | 19 | import java.util.ArrayList; 20 | import java.util.List; 21 | 22 | /** 23 | * Example showing how to match documents against a set of known queries; useful 24 | * for doing things like alerts, etc. 25 | */ 26 | public class DocumentFilteringStreamProcessor extends SparkApp.StreamProcessor { 27 | 28 | public static Logger log = Logger.getLogger(DocumentFilteringStreamProcessor.class); 29 | 30 | /** 31 | * A DocFilterContext is responsible for loading queries from some external system and 32 | * then doing something with each doc that is matched to a query. 33 | */ 34 | class ExampleDocFilterContextImpl implements DocFilterContext { 35 | 36 | public void init(JavaStreamingContext jssc, CommandLine cli) { 37 | // nothing to init for this basic impl 38 | } 39 | 40 | public String getDocIdFieldName() { return "id"; } 41 | 42 | public List getQueries() { 43 | List queryList = new ArrayList(); 44 | 45 | // a real impl would pull queries from an external system, such as Solr or a DB or a file 46 | SolrQuery q1 = new SolrQuery("type_s:post"); 47 | q1.setParam("_qid_", "POSTS"); // identify the query when tagging matching docs 48 | queryList.add(q1); 49 | 50 | SolrQuery q2 = new SolrQuery("type_s:echo"); 51 | q2.setParam("_qid_", "ECHOS"); 52 | queryList.add(q2); 53 | 54 | return queryList; 55 | } 56 | 57 | public void onMatch(SolrQuery query, SolrInputDocument inputDoc) { 58 | String[] qids = query.getParams("_qid_"); 59 | if (qids == null || qids.length < 1) return; // not one of ours 60 | 61 | if (log.isDebugEnabled()) 62 | log.debug("document [" + inputDoc.getFieldValue("id") + "] matches query: " + qids[0]); 63 | 64 | // just index the matching query for later analysis 65 | inputDoc.addField("_qid_ss", qids[0]); 66 | } 67 | } 68 | 69 | public String getName() { return "docfilter"; } 70 | 71 | @Override 72 | public void setup(JavaStreamingContext jssc, CommandLine cli) throws Exception { 73 | 74 | // load the DocFilterContext implementation, which knows how to load queries 75 | DocFilterContext docFilterContext = loadDocFilterContext(jssc, cli); 76 | final String idFieldName = docFilterContext.getDocIdFieldName(); 77 | 78 | // start receiving a stream of tweets ... 79 | String filtersArg = cli.getOptionValue("tweetFilters"); 80 | String[] filters = (filtersArg != null) ? filtersArg.split(",") : new String[0]; 81 | JavaReceiverInputDStream tweets = TwitterUtils.createStream(jssc, null, filters); 82 | 83 | // map incoming tweets into SolrInputDocument objects for indexing in Solr 84 | JavaDStream docs = tweets.map( 85 | (Function) status -> { 86 | SolrInputDocument doc = 87 | SolrSupport.autoMapToSolrInputDoc(idFieldName, "tweet-"+status.getId(), status, null); 88 | doc.setField("provider_s", "twitter"); 89 | doc.setField("author_s", status.getUser().getScreenName()); 90 | doc.setField("type_s", status.isRetweet() ? 
"echo" : "post"); 91 | return doc; 92 | } 93 | ); 94 | 95 | // run each doc through a list of filters pulled from our DocFilterContext 96 | String filterCollection = cli.getOptionValue("filterCollection", collection); 97 | DStream enriched = 98 | SolrSupport.filterDocuments(docFilterContext, zkHost, filterCollection, docs.dstream()); 99 | 100 | // now index the enriched docs into Solr (or do whatever after the matching process runs) 101 | SolrSupport.indexDStreamOfDocs(zkHost, collection, batchSize, batchSizeType, enriched); 102 | } 103 | 104 | protected DocFilterContext loadDocFilterContext(JavaStreamingContext jssc, CommandLine cli) 105 | throws Exception 106 | { 107 | DocFilterContext ctxt = null; 108 | String docFilterContextImplClass = cli.getOptionValue("docFilterContextImplClass"); 109 | if (docFilterContextImplClass != null) { 110 | Class implClass = 111 | (Class)getClass().getClassLoader().loadClass(docFilterContextImplClass); 112 | ctxt = implClass.newInstance(); 113 | } else { 114 | ctxt = new ExampleDocFilterContextImpl(); 115 | } 116 | ctxt.init(jssc, cli); 117 | return ctxt; 118 | } 119 | 120 | public Option[] getOptions() { 121 | return new Option[]{ 122 | Option.builder("tweetFilters") 123 | .argName("LIST") 124 | .hasArg() 125 | .required(false) 126 | .desc("List of Twitter keywords to filter on, separated by commas") 127 | .build(), 128 | Option.builder("filterCollection") 129 | .argName("NAME") 130 | .hasArg() 131 | .required(false) 132 | .desc("Collection to pull configuration files to create an " + 133 | "EmbeddedSolrServer for document matching; defaults to the value of the collection option.") 134 | .build(), 135 | Option.builder("docFilterContextImplClass") 136 | .argName("CLASS") 137 | .hasArg() 138 | .required(false) 139 | .desc("Name of the DocFilterContext implementation class; defaults to an internal example impl: "+ 140 | ExampleDocFilterContextImpl.class.getName()) 141 | .build() 142 | }; 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/example/streaming/TwitterToSolrStreamProcessor.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.streaming; 2 | 3 | import com.lucidworks.spark.SparkApp; 4 | import com.lucidworks.spark.util.SolrSupport; 5 | import org.apache.commons.cli.CommandLine; 6 | import org.apache.commons.cli.Option; 7 | import org.apache.log4j.Logger; 8 | import org.apache.solr.common.SolrInputDocument; 9 | import org.apache.spark.api.java.function.Function; 10 | import org.apache.spark.streaming.api.java.JavaDStream; 11 | import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; 12 | import org.apache.spark.streaming.api.java.JavaStreamingContext; 13 | import org.apache.spark.streaming.twitter.TwitterUtils; 14 | import twitter4j.Status; 15 | 16 | /** 17 | * Simple example of indexing tweets into Solr using Spark streaming; be sure to update the 18 | * twitter4j.properties file on the classpath with your Twitter API credentials. 19 | */ 20 | public class TwitterToSolrStreamProcessor extends SparkApp.StreamProcessor { 21 | 22 | public static Logger log = Logger.getLogger(TwitterToSolrStreamProcessor.class); 23 | 24 | public String getName() { 25 | return "twitter-to-solr"; 26 | } 27 | 28 | /** 29 | * Sends a stream of tweets to Solr. 
30 | */ 31 | @Override 32 | public void setup(JavaStreamingContext jssc, CommandLine cli) throws Exception { 33 | String filtersArg = cli.getOptionValue("tweetFilters"); 34 | String[] filters = (filtersArg != null) ? filtersArg.split(",") : new String[0]; 35 | 36 | // start receiving a stream of tweets ... 37 | JavaReceiverInputDStream tweets = 38 | TwitterUtils.createStream(jssc, null, filters); 39 | 40 | String fusionUrl = cli.getOptionValue("fusion"); 41 | if (fusionUrl != null) { 42 | // just send JSON directly to Fusion 43 | SolrSupport.sendDStreamOfDocsToFusion(fusionUrl, cli.getOptionValue("fusionCredentials"), tweets.dstream(), batchSize); 44 | } else { 45 | // map incoming tweets into PipelineDocument objects for indexing in Solr 46 | JavaDStream docs = tweets.map( 47 | new Function() { 48 | 49 | /** 50 | * Convert a twitter4j Status object into a SolrJ SolrInputDocument 51 | */ 52 | public SolrInputDocument call(Status status) { 53 | 54 | if (log.isDebugEnabled()) { 55 | log.debug("Received tweet: " + status.getId() + ": " + status.getText().replaceAll("\\s+", " ")); 56 | } 57 | 58 | // simple mapping from primitives to dynamic Solr fields using reflection 59 | SolrInputDocument doc = 60 | SolrSupport.autoMapToSolrInputDoc("tweet-" + status.getId(), status, null); 61 | doc.setField("provider_s", "twitter"); 62 | doc.setField("author_s", status.getUser().getScreenName()); 63 | doc.setField("type_s", status.isRetweet() ? "echo" : "post"); 64 | 65 | if (log.isDebugEnabled()) 66 | log.debug("Transformed document: " + doc.toString()); 67 | return doc; 68 | } 69 | } 70 | ); 71 | 72 | // when ready, send the docs into a SolrCloud cluster 73 | SolrSupport.indexDStreamOfDocs(zkHost, collection, batchSize, batchSizeType, docs.dstream()); 74 | } 75 | } 76 | 77 | public Option[] getOptions() { 78 | return new Option[]{ 79 | Option.builder("tweetFilters") 80 | .argName("LIST") 81 | .hasArg() 82 | .required(false) 83 | .desc("List of Twitter keywords to filter on, separated by commas") 84 | .build(), 85 | Option.builder("fusion") 86 | .argName("URL(s)") 87 | .hasArg() 88 | .required(false) 89 | .desc("Fusion endpoint") 90 | .build(), 91 | Option.builder("fusionCredentials") 92 | .argName("user:password:realm") 93 | .hasArg() 94 | .required(false) 95 | .desc("Fusion credentials user:password:realm") 96 | .build() 97 | }; 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/filter/DocFilterContext.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.filter; 2 | 3 | import org.apache.commons.cli.CommandLine; 4 | import org.apache.solr.client.solrj.SolrQuery; 5 | import org.apache.solr.common.SolrInputDocument; 6 | import org.apache.spark.streaming.api.java.JavaStreamingContext; 7 | 8 | import java.io.Serializable; 9 | import java.util.List; 10 | 11 | /** 12 | * Used by the document filtering framework to delegate the loading and 13 | * of queries used to filter documents and what to do when a document 14 | * matches a query. 
15 | */ 16 | public interface DocFilterContext extends Serializable { 17 | void init(JavaStreamingContext jssc, CommandLine cli); 18 | String getDocIdFieldName(); 19 | List getQueries(); 20 | void onMatch(SolrQuery query, SolrInputDocument inputDoc); 21 | } 22 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/query/PagedResultsIterator.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.query; 2 | 3 | import com.lucidworks.spark.util.SolrQuerySupport; 4 | import org.apache.solr.client.solrj.SolrClient; 5 | import org.apache.solr.client.solrj.SolrQuery; 6 | import org.apache.solr.client.solrj.SolrServerException; 7 | import org.apache.solr.client.solrj.impl.CloudSolrClient; 8 | import org.apache.solr.client.solrj.response.QueryResponse; 9 | import org.apache.solr.common.SolrDocumentList; 10 | import scala.Option; 11 | 12 | import java.util.Iterator; 13 | import java.util.List; 14 | import java.util.NoSuchElementException; 15 | 16 | /** 17 | * Base class for iterating over paged results in a Solr QueryResponse, with the 18 | * most obvious example being iterating over SolrDocument objects matching a query. 19 | */ 20 | public abstract class PagedResultsIterator implements Iterator, Iterable { 21 | 22 | protected static final int DEFAULT_PAGE_SIZE = 50; 23 | 24 | protected SolrClient solrServer; 25 | protected SolrQuery solrQuery; 26 | protected int currentPageSize = 0; 27 | protected int iterPos = 0; 28 | protected long totalDocs = 0; 29 | protected long numDocs = 0; 30 | protected String cursorMark = null; 31 | protected boolean closeAfterIterating = false; 32 | 33 | protected List currentPage; 34 | 35 | public PagedResultsIterator(SolrClient solrServer, SolrQuery solrQuery) { 36 | this(solrServer, solrQuery, null); 37 | } 38 | 39 | public PagedResultsIterator(SolrClient solrServer, SolrQuery solrQuery, String cursorMark) { 40 | this.solrServer = solrServer; 41 | this.closeAfterIterating = !(solrServer instanceof CloudSolrClient); 42 | this.solrQuery = solrQuery; 43 | this.cursorMark = cursorMark; 44 | if (solrQuery.getRows() == null) 45 | solrQuery.setRows(DEFAULT_PAGE_SIZE); // default page size 46 | } 47 | 48 | public boolean hasNext() { 49 | if (currentPage == null || iterPos == currentPageSize) { 50 | try { 51 | currentPage = fetchNextPage(); 52 | currentPageSize = currentPage.size(); 53 | iterPos = 0; 54 | } catch (SolrServerException sse) { 55 | throw new RuntimeException(sse); 56 | } 57 | } 58 | boolean hasNext = (iterPos < currentPageSize); 59 | if (!hasNext && closeAfterIterating) { 60 | try { 61 | solrServer.close(); 62 | } catch (Exception exc) { 63 | exc.printStackTrace(); 64 | } 65 | } 66 | return hasNext; 67 | } 68 | 69 | protected int getStartForNextPage() { 70 | Integer currentStart = solrQuery.getStart(); 71 | return (currentStart != null) ? currentStart + solrQuery.getRows() : 0; 72 | } 73 | 74 | protected List fetchNextPage() throws SolrServerException { 75 | int start = (cursorMark != null) ? 
0 : getStartForNextPage(); 76 | Option resp = SolrQuerySupport.querySolr(solrServer, solrQuery, start, cursorMark); 77 | if (resp.isDefined()) { 78 | if (cursorMark != null) 79 | cursorMark = resp.get().getNextCursorMark(); 80 | 81 | iterPos = 0; 82 | SolrDocumentList docs = resp.get().getResults(); 83 | totalDocs = docs.getNumFound(); 84 | return processQueryResponse(resp.get()); 85 | } else { 86 | throw new SolrServerException("Found None Query response"); 87 | } 88 | 89 | } 90 | 91 | protected abstract List processQueryResponse(QueryResponse resp); 92 | 93 | public T next() { 94 | if (currentPage == null || iterPos >= currentPageSize) 95 | throw new NoSuchElementException("No more docs available!"); 96 | 97 | ++numDocs; 98 | 99 | return currentPage.get(iterPos++); 100 | } 101 | 102 | public void remove() { 103 | throw new UnsupportedOperationException("remove is not supported"); 104 | } 105 | 106 | public Iterator iterator() { 107 | return this; 108 | } 109 | } -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/query/ResultsIterator.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.query; 2 | 3 | import com.lucidworks.spark.SparkSolrAccumulator; 4 | 5 | import java.util.Iterator; 6 | 7 | public abstract class ResultsIterator implements Iterator, Iterable { 8 | 9 | protected SparkSolrAccumulator acc; 10 | 11 | public SparkSolrAccumulator getAccumulator() { 12 | return this.acc; 13 | } 14 | 15 | public abstract long getNumDocs(); 16 | 17 | public void setAccumulator(SparkSolrAccumulator acc) { 18 | this.acc = acc; 19 | } 20 | 21 | protected void increment() { 22 | if (acc != null) { 23 | acc.inc(); 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/query/SolrStreamIterator.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.query; 2 | 3 | import org.apache.log4j.Logger; 4 | import org.apache.solr.client.solrj.SolrQuery; 5 | import org.apache.solr.client.solrj.impl.CloudSolrClient; 6 | import org.apache.solr.client.solrj.impl.HttpSolrClient; 7 | import org.apache.solr.client.solrj.io.SolrClientCache; 8 | import org.apache.solr.client.solrj.io.stream.SolrStream; 9 | import org.apache.solr.client.solrj.io.stream.StreamContext; 10 | import org.apache.solr.client.solrj.io.stream.TupleStream; 11 | import org.apache.solr.common.params.CommonParams; 12 | 13 | import java.io.IOException; 14 | 15 | /** 16 | * An iterator over a stream of query results from one Solr Core. It is a 17 | * wrapper over the SolrStream to adapt it to an iterator interface. 18 | *
19 | * This iterator is not thread safe. It is intended to be used within the 20 | * context of a single thread. 21 | */ 22 | public class SolrStreamIterator extends TupleStreamIterator { 23 | 24 | private static final Logger log = Logger.getLogger(SolrStreamIterator.class); 25 | 26 | protected SolrQuery solrQuery; 27 | protected String shardUrl; 28 | protected int numWorkers; 29 | protected int workerId; 30 | protected SolrClientCache solrClientCache; 31 | protected HttpSolrClient httpSolrClient; 32 | protected CloudSolrClient cloudSolrClient; 33 | 34 | // Remove the whole code around StreamContext, numWorkers, workerId once SOLR-10490 is fixed. 35 | // It should just work if an 'fq' passed in the params with HashQ filter 36 | public SolrStreamIterator(String shardUrl, CloudSolrClient cloudSolrClient, HttpSolrClient httpSolrClient, SolrQuery solrQuery, int numWorkers, int workerId) { 37 | super(solrQuery); 38 | 39 | this.shardUrl = shardUrl; 40 | this.cloudSolrClient = cloudSolrClient; 41 | this.httpSolrClient = httpSolrClient; 42 | this.solrQuery = solrQuery; 43 | this.numWorkers = numWorkers; 44 | this.workerId = workerId; 45 | 46 | if (solrQuery.getRequestHandler() == null) { 47 | solrQuery = solrQuery.setRequestHandler("/export"); 48 | } 49 | solrQuery.setRows(null); 50 | solrQuery.set(CommonParams.WT, CommonParams.JAVABIN); 51 | //SolrQuerySupport.validateExportHandlerQuery(solrServer, solrQuery); 52 | } 53 | 54 | protected TupleStream openStream() { 55 | SolrStream stream; 56 | try { 57 | stream = new SolrStream(shardUrl, solrQuery); 58 | stream.setStreamContext(getStreamContext()); 59 | stream.open(); 60 | } catch (IOException e1) { 61 | throw new RuntimeException(e1); 62 | } 63 | return stream; 64 | } 65 | 66 | // We have to set the streaming context so that we can pass our own cloud client with authentication 67 | protected StreamContext getStreamContext() { 68 | StreamContext context = new StreamContext(); 69 | solrClientCache = new SparkSolrClientCache(cloudSolrClient, httpSolrClient); 70 | context.setSolrClientCache(solrClientCache); 71 | context.numWorkers = numWorkers; 72 | context.workerID = workerId; 73 | return context; 74 | } 75 | 76 | protected void afterStreamClosed() throws Exception { 77 | // No need to close http or cloudClient because they are re-used from cache 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/query/SparkSolrClientCache.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.query; 2 | 3 | import org.apache.solr.client.solrj.impl.CloudSolrClient; 4 | import org.apache.solr.client.solrj.impl.HttpSolrClient; 5 | import org.apache.solr.client.solrj.io.SolrClientCache; 6 | 7 | /** 8 | * Overriding so that we can pass our own cloud client for StreamingContext 9 | * zkhost param is not used since we have an existing cloud instance for the ZK 10 | */ 11 | public class SparkSolrClientCache extends SolrClientCache { 12 | 13 | private final CloudSolrClient solrClient; 14 | private final HttpSolrClient httpSolrClient; 15 | 16 | public SparkSolrClientCache(CloudSolrClient solrClient, HttpSolrClient httpSolrClient) { 17 | this.solrClient = solrClient; 18 | this.httpSolrClient = httpSolrClient; 19 | } 20 | 21 | public synchronized CloudSolrClient getCloudSolrClient(String zkHost) { 22 | return solrClient; 23 | } 24 | 25 | public synchronized HttpSolrClient getHttpSolrClient(String host) { 26 | if (host != null && 
host.endsWith("/")) { 27 | host = host.substring(0, host.length() - 1); 28 | } 29 | httpSolrClient.setBaseURL(host); 30 | return httpSolrClient; 31 | } 32 | 33 | } 34 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/query/StreamingExpressionResultIterator.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.query; 2 | 3 | import org.apache.log4j.Logger; 4 | import org.apache.solr.client.solrj.impl.CloudSolrClient; 5 | import org.apache.solr.client.solrj.impl.HttpSolrClient; 6 | import org.apache.solr.client.solrj.io.SolrClientCache; 7 | import org.apache.solr.client.solrj.io.stream.SolrStream; 8 | import org.apache.solr.client.solrj.io.stream.StreamContext; 9 | import org.apache.solr.client.solrj.io.stream.TupleStream; 10 | import org.apache.solr.common.cloud.Replica; 11 | import org.apache.solr.common.cloud.Slice; 12 | import org.apache.solr.common.cloud.ZkCoreNodeProps; 13 | import org.apache.solr.common.cloud.ZkStateReader; 14 | import org.apache.solr.common.params.CommonParams; 15 | import org.apache.solr.common.params.ModifiableSolrParams; 16 | import org.apache.solr.common.params.SolrParams; 17 | 18 | import java.util.ArrayList; 19 | import java.util.Collection; 20 | import java.util.List; 21 | import java.util.Random; 22 | 23 | public class StreamingExpressionResultIterator extends TupleStreamIterator { 24 | 25 | private static final Logger log = Logger.getLogger(StreamingExpressionResultIterator.class); 26 | 27 | protected String zkHost; 28 | protected String collection; 29 | protected String qt; 30 | protected CloudSolrClient cloudSolrClient; 31 | protected HttpSolrClient httpSolrClient; 32 | protected SolrClientCache solrClientCache; 33 | 34 | private final Random random = new Random(5150L); 35 | 36 | public StreamingExpressionResultIterator(CloudSolrClient cloudSolrClient, HttpSolrClient httpSolrClient, String collection, SolrParams solrParams) { 37 | super(solrParams); 38 | this.cloudSolrClient = cloudSolrClient; 39 | this.httpSolrClient = httpSolrClient; 40 | this.collection = collection; 41 | 42 | qt = solrParams.get(CommonParams.QT); 43 | if (qt == null) qt = "/stream"; 44 | } 45 | 46 | protected TupleStream openStream() { 47 | TupleStream stream; 48 | 49 | ModifiableSolrParams params = new ModifiableSolrParams(); 50 | params.set(CommonParams.QT, qt); 51 | 52 | String aggregationMode = solrParams.get("aggregationMode"); 53 | 54 | log.info("aggregationMode=" + aggregationMode + ", solrParams: " + solrParams); 55 | if (aggregationMode != null) { 56 | params.set("aggregationMode", aggregationMode); 57 | } else { 58 | params.set("aggregationMode", "facet"); // use facet by default as it is faster 59 | } 60 | 61 | if ("/sql".equals(qt)) { 62 | String sql = solrParams.get("sql").replaceAll("\\s+", " "); 63 | log.info("Executing SQL statement " + sql + " against collection " + collection); 64 | params.set("stmt", sql); 65 | } else { 66 | String expr = solrParams.get("expr").replaceAll("\\s+", " "); 67 | log.info("Executing streaming expression " + expr + " against collection " + collection); 68 | params.set("expr", expr); 69 | } 70 | 71 | try { 72 | String url = (new ZkCoreNodeProps(getRandomReplica())).getCoreUrl(); 73 | log.info("Sending "+qt+" request to replica "+url+" of "+collection+" with params: "+params); 74 | long startMs = System.currentTimeMillis(); 75 | stream = new SolrStream(url, params); 76 | 
stream.setStreamContext(getStreamContext()); 77 | stream.open(); 78 | long diffMs = (System.currentTimeMillis() - startMs); 79 | log.debug("Open stream to "+url+" took "+diffMs+" (ms)"); 80 | } catch (Exception e) { 81 | log.error("Failed to execute request ["+solrParams+"] due to: "+e, e); 82 | if (e instanceof RuntimeException) { 83 | throw (RuntimeException)e; 84 | } else { 85 | throw new RuntimeException(e); 86 | } 87 | } 88 | return stream; 89 | } 90 | 91 | // We have to set the streaming context so that we can pass our own cloud client with authentication 92 | protected StreamContext getStreamContext() { 93 | StreamContext context = new StreamContext(); 94 | solrClientCache = new SparkSolrClientCache(cloudSolrClient, httpSolrClient); 95 | context.setSolrClientCache(solrClientCache); 96 | return context; 97 | } 98 | 99 | protected Replica getRandomReplica() { 100 | ZkStateReader zkStateReader = cloudSolrClient.getZkStateReader(); 101 | Collection slices = zkStateReader.getClusterState().getCollection(collection.split(",")[0]).getActiveSlices(); 102 | if (slices == null || slices.size() == 0) 103 | throw new IllegalStateException("No active shards found "+collection); 104 | 105 | List shuffler = new ArrayList<>(); 106 | for (Slice slice : slices) { 107 | shuffler.addAll(slice.getReplicas()); 108 | } 109 | return shuffler.get(random.nextInt(shuffler.size())); 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/query/TupleStreamIterator.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.query; 2 | 3 | import org.apache.solr.client.solrj.io.Tuple; 4 | import org.apache.solr.client.solrj.io.stream.TupleStream; 5 | import org.apache.solr.common.SolrDocument; 6 | import org.apache.solr.common.params.SolrParams; 7 | import org.slf4j.Logger; 8 | import org.slf4j.LoggerFactory; 9 | 10 | import java.io.IOException; 11 | import java.util.Iterator; 12 | import java.util.Map; 13 | import java.util.NoSuchElementException; 14 | 15 | /** 16 | * An iterator over a stream of tuples from Solr. 17 | *
18 | * This iterator is not thread safe. It is intended to be used within the 19 | * context of a single thread. 20 | */ 21 | public abstract class TupleStreamIterator extends ResultsIterator { 22 | 23 | private static final Logger log = LoggerFactory.getLogger(TupleStreamIterator.class); 24 | 25 | protected TupleStream stream; 26 | protected long numDocs = 0; 27 | protected SolrParams solrParams; 28 | private Tuple currentTuple = null; 29 | private long openedAt; 30 | private boolean isClosed = false; 31 | 32 | public TupleStreamIterator(SolrParams solrParams) { 33 | this.solrParams = solrParams; 34 | } 35 | 36 | public synchronized boolean hasNext() { 37 | if (isClosed) { 38 | return false; 39 | } 40 | 41 | if (stream == null) { 42 | stream = openStream(); 43 | openedAt = System.currentTimeMillis(); 44 | } 45 | 46 | try { 47 | if (currentTuple == null) { 48 | currentTuple = fetchNextTuple(); 49 | } 50 | } catch (IOException e) { 51 | log.error("Failed to fetch next Tuple for query: " + solrParams.toQueryString(), e); 52 | throw new RuntimeException(e); 53 | } 54 | 55 | if (currentTuple == null) { 56 | try { 57 | stream.close(); 58 | } catch (IOException e) { 59 | log.error("Failed to close the SolrStream.", e); 60 | throw new RuntimeException(e); 61 | } finally { 62 | this.isClosed = true; 63 | } 64 | 65 | long diffMs = System.currentTimeMillis() - openedAt; 66 | log.debug("Took {} (ms) to read {} from stream", diffMs, numDocs); 67 | 68 | try { 69 | afterStreamClosed(); 70 | } catch (Exception exc) { 71 | log.warn("Exception: {}", exc); 72 | } 73 | } 74 | 75 | return currentTuple != null; 76 | } 77 | 78 | protected void afterStreamClosed() throws Exception { 79 | // no-op - sub-classes can override if needed 80 | } 81 | 82 | protected Tuple fetchNextTuple() throws IOException { 83 | Tuple tuple = stream.read(); 84 | if (tuple.EOF) 85 | return null; 86 | 87 | return tuple; 88 | } 89 | 90 | protected abstract TupleStream openStream(); 91 | 92 | public synchronized Map nextTuple() { 93 | if (isClosed) 94 | throw new NoSuchElementException("already closed"); 95 | 96 | if (currentTuple == null) 97 | throw new NoSuchElementException(); 98 | 99 | final Tuple tempCurrentTuple = currentTuple; 100 | currentTuple = null; 101 | ++numDocs; 102 | increment(); 103 | return tempCurrentTuple.getFields(); 104 | } 105 | 106 | public synchronized Map next() { 107 | return nextTuple(); 108 | } 109 | 110 | public void remove() { 111 | throw new UnsupportedOperationException("remove is not supported"); 112 | } 113 | 114 | public Iterator iterator() { 115 | return this; 116 | } 117 | 118 | public long getNumDocs() { 119 | return numDocs; 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/query/sql/SolrSQLSupport.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.query.sql; 2 | 3 | import java.util.Collections; 4 | import java.util.HashMap; 5 | import java.util.Map; 6 | 7 | /** 8 | * Helper for working with Solr SQL statements, such as parsing out the column list. 9 | */ 10 | public class SolrSQLSupport { 11 | 12 | private static final String SELECT = "select "; 13 | private static final String FROM = " from"; 14 | private static final String AS = "as "; 15 | private static final String DISTINCT = "distinct "; 16 | 17 | /** 18 | * Given a valid Solr SQL statement, parse out the columns and aliases as a map. 
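 * For example (illustrative, not from the original Javadoc): parseColumns("SELECT col1, col2 AS foo FROM collection1") returns a map of col1 -> col1 and col2 -> foo; a leading DISTINCT and any backticks or single quotes around identifiers are stripped, and "SELECT *" yields an empty map.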
19 | */ 20 | public static Map parseColumns(String sqlStmt) throws Exception { 21 | 22 | // NOTE: While I prefer using a SQL parser here, the presto / calcite / Spark parsers were too complex 23 | // for this basic task and pulled in unwanted / incompatible dependencies, e.g. presto requires a different 24 | // version of guava than what Spark supports 25 | 26 | String tmp = sqlStmt.replaceAll("\\s+", " ").trim(); 27 | 28 | String lc = tmp.toLowerCase(); 29 | if (!lc.startsWith(SELECT)) 30 | throw new IllegalArgumentException("Expected SQL to start with '"+SELECT+"' but found ["+sqlStmt+"] instead!"); 31 | 32 | int fromAt = lc.indexOf(FROM, SELECT.length()); 33 | if (fromAt == -1) 34 | throw new IllegalArgumentException("No FROM keyword found in SQL: "+sqlStmt); 35 | 36 | String columnList = tmp.substring(SELECT.length(),fromAt).trim(); 37 | 38 | // SELECT * not supported yet 39 | if ("*".equals(columnList)) 40 | return Collections.emptyMap(); 41 | 42 | Map columns = new HashMap<>(); 43 | for (String pair : columnList.split(",")) { 44 | pair = pair.trim(); 45 | 46 | // trim off distinct indicator 47 | if (pair.toLowerCase().startsWith(DISTINCT)) { 48 | pair = pair.substring(DISTINCT.length()); 49 | } 50 | 51 | String col; 52 | String alias; 53 | int spaceAt = pair.indexOf(" "); 54 | if (spaceAt != -1) { 55 | col = pair.substring(0,spaceAt); 56 | alias = pair.substring(spaceAt+1); 57 | if (alias.toLowerCase().startsWith(AS)) { 58 | alias = alias.substring(AS.length()); 59 | } 60 | } else { 61 | col = pair; 62 | alias = pair; 63 | } 64 | 65 | columns.put(col.replace("`","").replace("'",""), alias.replace("`","").replace("'","")); 66 | } 67 | return columns; 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/rdd/SolrJavaRDD.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.rdd; 2 | 3 | import org.apache.solr.client.solrj.SolrQuery; 4 | import org.apache.solr.common.SolrDocument; 5 | import org.apache.spark.SparkContext; 6 | import org.apache.spark.api.java.JavaRDD; 7 | 8 | public class SolrJavaRDD extends JavaRDD { 9 | 10 | private final SelectSolrRDD solrRDD; 11 | 12 | public SolrJavaRDD(SelectSolrRDD solrRDD) { 13 | super(solrRDD, solrRDD.elementClassTag()); 14 | this.solrRDD = solrRDD; 15 | } 16 | 17 | protected SolrJavaRDD wrap(SelectSolrRDD rdd) { 18 | return new SolrJavaRDD(rdd); 19 | } 20 | 21 | public JavaRDD query(String query) { 22 | return wrap(rdd().query(query)); 23 | } 24 | 25 | public JavaRDD queryShards(String query) { 26 | return wrap(rdd().query(query)); 27 | } 28 | 29 | public JavaRDD queryShards(SolrQuery solrQuery) { 30 | return wrap(rdd().query(solrQuery)); 31 | } 32 | 33 | public JavaRDD queryShards(SolrQuery solrQuery, String splitFieldName, int splitsPerShard) { 34 | return wrap(rdd().query(solrQuery).splitField(splitFieldName).splitsPerShard(splitsPerShard)); 35 | } 36 | 37 | public JavaRDD queryNoSplits(String query) { 38 | return wrap(rdd().query(query).splitsPerShard(1)); 39 | } 40 | 41 | @Override 42 | public SelectSolrRDD rdd() { 43 | return solrRDD; 44 | } 45 | 46 | public static SolrJavaRDD get(String zkHost, String collection, SparkContext sc) { 47 | SelectSolrRDD solrRDD = SelectSolrRDD$.MODULE$.apply(zkHost, collection, sc); 48 | return new SolrJavaRDD(solrRDD); 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- 
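A minimal usage sketch for the SolrJavaRDD wrapper above (illustrative only; the ZooKeeper address and collection name are the same placeholder values used by the examples in this repo, and the snippet itself is not taken from the project's sources):

    import com.lucidworks.spark.rdd.SolrJavaRDD;
    import org.apache.solr.common.SolrDocument;
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;

    public class SolrJavaRDDSketch {
      public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("solr-rdd-sketch").setMaster("local[*]");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        // Wrap the collection in a SolrJavaRDD and count documents matching a query
        SolrJavaRDD solrRDD = SolrJavaRDD.get("localhost:9983", "collection1", jsc.sc());
        JavaRDD<SolrDocument> docs = solrRDD.query("*:*");
        System.out.println("count = " + docs.count());
        jsc.stop();
      }
    }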
/src/main/java/com/lucidworks/spark/rdd/SolrStreamJavaRDD.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.rdd; 2 | 3 | import org.apache.solr.client.solrj.SolrQuery; 4 | import org.apache.spark.SparkContext; 5 | import org.apache.spark.api.java.JavaRDD; 6 | 7 | import java.util.Map; 8 | 9 | public class SolrStreamJavaRDD extends JavaRDD> { 10 | 11 | private final StreamingSolrRDD solrRDD; 12 | 13 | public SolrStreamJavaRDD(StreamingSolrRDD solrRDD) { 14 | super(solrRDD, solrRDD.elementClassTag()); 15 | this.solrRDD = solrRDD; 16 | } 17 | 18 | protected SolrStreamJavaRDD wrap(StreamingSolrRDD rdd) { 19 | return new SolrStreamJavaRDD(rdd); 20 | } 21 | 22 | public JavaRDD> query(String query) { 23 | return wrap(rdd().query(query)); 24 | } 25 | 26 | public JavaRDD> queryShards(String query) { 27 | return wrap(rdd().query(query)); 28 | } 29 | 30 | public JavaRDD> queryShards(SolrQuery solrQuery) { 31 | return wrap(rdd().query(solrQuery)); 32 | } 33 | 34 | public JavaRDD> queryShards(SolrQuery solrQuery, String splitFieldName, int splitsPerShard) { 35 | return wrap(rdd().query(solrQuery).splitField(splitFieldName).splitsPerShard(splitsPerShard)); 36 | } 37 | 38 | public JavaRDD> queryNoSplits(String query) { 39 | return wrap(rdd().query(query).splitsPerShard(1)); 40 | } 41 | 42 | @Override 43 | public StreamingSolrRDD rdd() { 44 | return solrRDD; 45 | } 46 | 47 | public static SolrStreamJavaRDD get(String zkHost, String collection, SparkContext sc) { 48 | StreamingSolrRDD solrRDD = StreamingSolrRDD$.MODULE$.apply(zkHost, collection, sc); 49 | return new SolrStreamJavaRDD(solrRDD); 50 | } 51 | 52 | } 53 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/util/EmbeddedSolrServerFactory.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.util; 2 | 3 | import java.io.*; 4 | import java.nio.charset.StandardCharsets; 5 | import java.util.Collections; 6 | import java.util.HashMap; 7 | import java.util.Map; 8 | 9 | import org.apache.solr.client.solrj.embedded.EmbeddedSolrServer; 10 | import org.apache.solr.client.solrj.impl.CloudSolrClient; 11 | import org.apache.solr.common.cloud.ZkConfigManager; 12 | import org.apache.solr.common.cloud.ZkStateReader; 13 | import org.apache.solr.core.CoreContainer; 14 | import org.apache.solr.core.SolrCore; 15 | 16 | import org.apache.log4j.Logger; 17 | 18 | import org.apache.commons.io.FileUtils; 19 | 20 | /** 21 | * Supports one or more embedded Solr servers in the same JVM 22 | */ 23 | public class EmbeddedSolrServerFactory implements Serializable { 24 | 25 | private static final Logger log = Logger.getLogger(EmbeddedSolrServerFactory.class); 26 | 27 | public static final EmbeddedSolrServerFactory singleton = new EmbeddedSolrServerFactory(); 28 | 29 | private transient Map servers = new HashMap(); 30 | 31 | public synchronized EmbeddedSolrServer getEmbeddedSolrServer(String zkHost, String collection) { 32 | return getEmbeddedSolrServer(zkHost, collection, null, null); 33 | } 34 | 35 | public synchronized EmbeddedSolrServer getEmbeddedSolrServer(String zkHost, String collection, String solrConfigXml, String solrXml) { 36 | String key = zkHost+"/"+collection; 37 | 38 | EmbeddedSolrServer solr = servers.get(key); 39 | if (solr == null) { 40 | try { 41 | solr = bootstrapEmbeddedSolrServer(zkHost, collection, solrConfigXml, solrXml); 42 | } catch (Exception exc) { 43 
| if (exc instanceof RuntimeException) { 44 | throw (RuntimeException) exc; 45 | } else { 46 | throw new RuntimeException(exc); 47 | } 48 | } 49 | servers.put(key, solr); 50 | } 51 | return solr; 52 | } 53 | 54 | private EmbeddedSolrServer bootstrapEmbeddedSolrServer(String zkHost, String collection, String solrConfigXml, String solrXml) throws Exception { 55 | 56 | CloudSolrClient cloudClient = SolrSupport.getCachedCloudClient(zkHost); 57 | cloudClient.connect(); 58 | 59 | ZkStateReader zkStateReader = cloudClient.getZkStateReader(); 60 | if (!zkStateReader.getClusterState().hasCollection(collection)) 61 | throw new IllegalStateException("Collection '"+collection+"' not found!"); 62 | 63 | String configName = zkStateReader.readConfigName(collection); 64 | if (configName == null) 65 | throw new IllegalStateException("No configName found for Collection: "+collection); 66 | 67 | File tmpDir = FileUtils.getTempDirectory(); 68 | File solrHomeDir = new File(tmpDir, "solr"+System.currentTimeMillis()); 69 | 70 | log.info("Setting up embedded Solr server in local directory: "+solrHomeDir.getAbsolutePath()); 71 | 72 | FileUtils.forceMkdir(solrHomeDir); 73 | 74 | writeSolrXml(solrHomeDir, solrXml); 75 | 76 | String coreName = "embedded"; 77 | 78 | File instanceDir = new File(solrHomeDir, coreName); 79 | FileUtils.forceMkdir(instanceDir); 80 | 81 | File confDir = new File(instanceDir, "conf"); 82 | ZkConfigManager zkConfigManager = 83 | new ZkConfigManager(cloudClient.getZkStateReader().getZkClient()); 84 | zkConfigManager.downloadConfigDir(configName, confDir.toPath()); 85 | if (!confDir.isDirectory()) 86 | throw new IOException("Failed to download /configs/"+configName+" from ZooKeeper!"); 87 | 88 | writeSolrConfigXml(confDir, solrConfigXml); 89 | 90 | log.info(String.format("Attempting to bootstrap EmbeddedSolrServer instance in dir: %s", 91 | instanceDir.getAbsolutePath())); 92 | 93 | CoreContainer coreContainer = new CoreContainer(solrHomeDir.toPath(), null); 94 | coreContainer.load(); 95 | 96 | SolrCore core = coreContainer.create(coreName, instanceDir.toPath(), Collections.emptyMap(), false); 97 | return new EmbeddedSolrServer(coreContainer, coreName); 98 | } 99 | 100 | protected File writeSolrConfigXml(File confDir, String solrConfigXml) throws IOException { 101 | if (solrConfigXml != null && !solrConfigXml.trim().isEmpty()) { 102 | return writeClasspathResourceToLocalFile(solrConfigXml, new File(confDir, "solrconfig.xml")); 103 | } else { 104 | return writeClasspathResourceToLocalFile("embedded/solrconfig.xml", new File(confDir, "solrconfig.xml")); 105 | } 106 | } 107 | 108 | protected File writeSolrXml(File solrHomeDir, String solrXml) throws IOException { 109 | if (solrXml != null && !solrXml.trim().isEmpty()) { 110 | return writeClasspathResourceToLocalFile(solrXml, new File(solrHomeDir, "solr.xml")); 111 | } else { 112 | return writeClasspathResourceToLocalFile("embedded/solr.xml", new File(solrHomeDir, "solr.xml")); 113 | } 114 | } 115 | 116 | protected File writeClasspathResourceToLocalFile(String resourceId, File destFile) throws IOException { 117 | InputStreamReader isr = null; 118 | OutputStreamWriter osw = null; 119 | int r = 0; 120 | char[] ach = new char[1024]; 121 | try { 122 | InputStream in = getClass().getClassLoader().getResourceAsStream(resourceId); 123 | if (in == null) 124 | throw new IOException("Resource "+resourceId+" not found on classpath!"); 125 | 126 | isr = new InputStreamReader(in, StandardCharsets.UTF_8); 127 | osw = new OutputStreamWriter(new 
FileOutputStream(destFile), StandardCharsets.UTF_8); 128 | while ((r = isr.read(ach)) != -1) osw.write(ach, 0, r); 129 | osw.flush(); 130 | } finally { 131 | if (isr != null) { 132 | try { 133 | isr.close(); 134 | } catch (Exception ignoreMe){ 135 | ignoreMe.printStackTrace(); 136 | } 137 | } 138 | if (osw != null) { 139 | try { 140 | osw.close(); 141 | } catch (Exception ignoreMe){ 142 | ignoreMe.printStackTrace(); 143 | } 144 | } 145 | } 146 | return destFile; 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/util/FusionAuthHttpClient.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.util; 2 | 3 | import org.apache.solr.client.solrj.impl.HttpSolrClient; 4 | 5 | public abstract class FusionAuthHttpClient { 6 | 7 | private final String zkHost; 8 | 9 | public FusionAuthHttpClient(String zkHost) { 10 | this.zkHost = zkHost; 11 | } 12 | 13 | public String getZkHost() { 14 | return zkHost; 15 | } 16 | 17 | public abstract HttpSolrClient.Builder getHttpClientBuilder() throws Exception; 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/util/SQLQuerySupport.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.util; 2 | 3 | import org.apache.spark.mllib.linalg.MatrixUDT; 4 | import org.apache.spark.mllib.linalg.VectorUDT; 5 | import org.apache.spark.sql.types.DataType; 6 | import org.apache.spark.sql.types.DataTypes; 7 | 8 | 9 | public class SQLQuerySupport { 10 | public static DataType getsqlDataType(String s) { 11 | if (s.equalsIgnoreCase("double")) { 12 | return DataTypes.DoubleType; 13 | } 14 | if (s.equalsIgnoreCase("byte")) { 15 | return DataTypes.ByteType; 16 | } 17 | if (s.equalsIgnoreCase("short")) { 18 | return DataTypes.ShortType; 19 | } 20 | if (((s.equalsIgnoreCase("int")) || (s.equalsIgnoreCase("integer")))) { 21 | return DataTypes.IntegerType; 22 | } 23 | if (s.equalsIgnoreCase("long")) { 24 | return DataTypes.LongType; 25 | } 26 | if (s.equalsIgnoreCase("String")) { 27 | return DataTypes.StringType; 28 | } 29 | if (s.equalsIgnoreCase("boolean")) { 30 | return DataTypes.BooleanType; 31 | } 32 | if (s.equalsIgnoreCase("timestamp")) { 33 | return DataTypes.TimestampType; 34 | } 35 | if (s.equalsIgnoreCase("date")) { 36 | return DataTypes.DateType; 37 | } 38 | if (s.equalsIgnoreCase("vector")) { 39 | return new VectorUDT(); 40 | } 41 | if (s.equalsIgnoreCase("matrix")) { 42 | return new MatrixUDT(); 43 | } 44 | if (s.contains(":") && s.split(":")[0].equalsIgnoreCase("array")) { 45 | return getArrayTypeRecurse(s,0); 46 | } 47 | return DataTypes.StringType; 48 | } 49 | 50 | public static DataType getArrayTypeRecurse(String s, int fromIdx) { 51 | if (s.contains(":") && s.split(":")[1].equalsIgnoreCase("array")) { 52 | fromIdx = s.indexOf(":", fromIdx); 53 | s = s.substring(fromIdx+1, s.length()); 54 | return DataTypes.createArrayType(getArrayTypeRecurse(s,fromIdx)); 55 | } 56 | return DataTypes.createArrayType(getsqlDataType(s.split(":")[1])); 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/main/java/com/lucidworks/spark/util/Utils.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.util; 2 | 3 | 4 | import scala.collection.JavaConverters$; 5 | 6 | public class Utils 
{ 7 | /** 8 | * 9 | * Gets the class for the given name using the Context ClassLoader on this thread or, 10 | * if not present, the ClassLoader that loaded Spark Solr. 11 | * 12 | * Copied from Spark org.apache.spark.util.Utils.scala 13 | */ 14 | public static Class classForName(String className) throws ClassNotFoundException { 15 | return Class.forName(className, true, getContextOrSparkSolrClassLoader()); 16 | } 17 | 18 | /** 19 | * 20 | * Get the Context ClassLoader on this thread or, if not present, the ClassLoader that 21 | * loaded spark-solr. 22 | * 23 | * 24 | * This should be used whenever passing a ClassLoader to Class.forName or finding the currently 25 | * active loader when setting up ClassLoader delegation chains. 26 | * 27 | * Copied from Spark org.apache.spark.util.Utils.scala
28 | */ 29 | public static ClassLoader getContextOrSparkSolrClassLoader() { 30 | ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); 31 | return classLoader == null ? getSparkSolrClassLoader() : classLoader; 32 | } 33 | 34 | /** Get the ClassLoader that loaded spark-solr. */ 35 | public static ClassLoader getSparkSolrClassLoader() { 36 | return Utils.class.getClassLoader(); 37 | } 38 | 39 | public static scala.collection.immutable.Map convertJavaMapToScalaImmmutableMap(final java.util.Map m) { 40 | return JavaConverters$.MODULE$.mapAsScalaMapConverter(m).asScala().toMap(scala.Predef$.MODULE$.$conforms()); 41 | } 42 | } -------------------------------------------------------------------------------- /src/main/resources/embedded/solr.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | ${socketTimeout:0} 6 | ${connTimeout:0} 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /src/main/resources/embedded/solrconfig.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | LATEST 4 | ${solr.data.dir:} 5 | 6 | 7 | 8 | single 9 | false 10 | 11 | 12 | 13 | ${solr.autoCommit.maxTime:15000} 14 | false 15 | 16 | 17 | ${solr.autoSoftCommit.maxTime:-1} 18 | 19 | 20 | 21 | 1024 22 | 26 | 27 | 31 | 32 | 36 | 37 | 43 | 44 | true 45 | 20 46 | 200 47 | 48 | true 49 | 5 50 | 51 | 52 | 53 | 57 | 58 | 59 | 60 | 61 | 62 | 1 63 | 64 | 65 | 66 | 67 | 68 | true 69 | json 70 | true 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | text/plain; charset=UTF-8 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/Logging.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark 2 | 3 | import org.slf4j.{LoggerFactory, Logger => Underlying} 4 | 5 | 6 | final class Logger private (val underlying: Underlying) extends Serializable { 7 | 8 | def info(msg: String): Unit = if (underlying.isInfoEnabled) underlying.info(msg) 9 | 10 | def info(msg: String, t: Throwable): Unit = if (underlying.isInfoEnabled) underlying.info(msg, t) 11 | 12 | def info(msg: String, args: Any*): Unit = if (underlying.isInfoEnabled) underlying.info(msg, args) 13 | 14 | def debug(msg: String): Unit = if (underlying.isDebugEnabled) underlying.debug(msg) 15 | 16 | def debug(msg: String, t: Throwable): Unit = if (underlying.isDebugEnabled) underlying.debug(msg, t) 17 | 18 | def debug(msg: String, args: Any*): Unit = if (underlying.isDebugEnabled) underlying.debug(msg, args) 19 | 20 | def trace(msg: String): Unit = if (underlying.isTraceEnabled) underlying.trace(msg) 21 | 22 | def trace(msg: String, t: Throwable): Unit = if (underlying.isTraceEnabled) underlying.trace(msg, t) 23 | 24 | def trace(msg: String, args: Any*): Unit = if (underlying.isTraceEnabled) underlying.trace(msg, args) 25 | 26 | def error(msg: String): Unit = if (underlying.isErrorEnabled) underlying.error(msg) 27 | 28 | def error(msg: String, t: Throwable): Unit = if (underlying.isErrorEnabled) underlying.error(msg, t) 29 | 30 | def error(msg: String, args: Any*): Unit = if (underlying.isErrorEnabled) underlying.error(msg, args) 31 | 32 | def warn(msg: String): Unit = if (underlying.isWarnEnabled) underlying.warn(msg) 33 | 34 | def warn(msg: String, t: Throwable): Unit = if (underlying.isWarnEnabled) underlying.warn(msg, t) 35 | 36 | def warn(msg: String, args: Any*): Unit = if 
(underlying.isWarnEnabled) underlying.warn(msg, args) 37 | 38 | } 39 | 40 | /** 41 | * Companion for [[Logger]], providing a factory for [[Logger]]s. 42 | */ 43 | object Logger { 44 | 45 | /** 46 | * Create a [[Logger]] wrapping the given underlying `org.slf4j.Logger`. 47 | */ 48 | def apply(underlying: Underlying): Logger = 49 | new Logger(underlying) 50 | 51 | /** 52 | * Create a [[Logger]] wrapping the created underlying `org.slf4j.Logger`. 53 | */ 54 | def apply(clazz: Class[_]): Logger = 55 | new Logger(LoggerFactory.getLogger(clazz.getName)) 56 | } 57 | 58 | 59 | trait LazyLogging { 60 | protected lazy val logger: Logger = Logger(LoggerFactory.getLogger(getClass.getName)) 61 | } 62 | 63 | trait StrictLogging { 64 | protected val logger: Logger = Logger(LoggerFactory.getLogger(getClass.getName)) 65 | } -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/Partitioner.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark 2 | 3 | import java.net.InetAddress 4 | 5 | import com.lucidworks.spark.rdd.SolrRDD 6 | import com.lucidworks.spark.util.SolrSupport 7 | import org.apache.solr.client.solrj.SolrQuery 8 | import org.apache.spark.Partition 9 | 10 | import scala.collection.mutable.ArrayBuffer 11 | 12 | // Is there a need to override {@code Partitioner.scala} and define our own partition id's 13 | object SolrPartitioner { 14 | 15 | def getShardPartitions(shards: List[SolrShard], query: SolrQuery) : Array[Partition] = { 16 | shards.zipWithIndex.map{ case (shard, i) => 17 | // Chose any of the replicas as the active shard to query 18 | SelectSolrRDDPartition(i, "*", shard, query, SolrRDD.randomReplica(shard))}.toArray 19 | } 20 | 21 | def getSplitPartitions( 22 | shards: List[SolrShard], 23 | query: SolrQuery, 24 | splitFieldName: String, 25 | splitsPerShard: Int): Array[Partition] = { 26 | var splitPartitions = ArrayBuffer.empty[SelectSolrRDDPartition] 27 | var counter = 0 28 | shards.foreach(shard => { 29 | val splits = SolrSupport.getShardSplits(query, shard, splitFieldName, splitsPerShard) 30 | splits.foreach(split => { 31 | splitPartitions += SelectSolrRDDPartition(counter, "*", shard, split.query, split.replica) 32 | counter = counter + 1 33 | }) 34 | }) 35 | splitPartitions.toArray 36 | } 37 | 38 | // Workaround for SOLR-10490. TODO: Remove once fixed 39 | def getExportHandlerPartitions( 40 | shards: List[SolrShard], 41 | query: SolrQuery): Array[Partition] = { 42 | shards.zipWithIndex.map{ case (shard, i) => 43 | // Chose any of the replicas as the active shard to query 44 | ExportHandlerPartition(i, shard, query, SolrRDD.randomReplica(shard), 0, 0)}.toArray 45 | } 46 | 47 | // Workaround for SOLR-10490. 
TODO: Remove once fixed 48 | def getExportHandlerPartitions( 49 | shards: List[SolrShard], 50 | query: SolrQuery, 51 | splitFieldName: String, 52 | splitsPerShard: Int): Array[Partition] = { 53 | val splitPartitions = ArrayBuffer.empty[ExportHandlerPartition] 54 | var counter = 0 55 | shards.foreach(shard => { 56 | // Form a continuous iterator list so that we can pick different replicas for different partitions in round-robin mode 57 | val splits = SolrSupport.getExportHandlerSplits(query, shard, splitFieldName, splitsPerShard) 58 | splits.foreach(split => { 59 | splitPartitions += ExportHandlerPartition(counter, shard, split.query, split.replica, split.numWorkers, split.workerId) 60 | counter = counter+1 61 | }) 62 | }) 63 | splitPartitions.toArray 64 | } 65 | 66 | } 67 | 68 | case class SolrShard(shardName: String, replicas: List[SolrReplica]) 69 | 70 | case class SolrReplica( 71 | replicaNumber: Int, 72 | replicaName: String, 73 | replicaUrl: String, 74 | replicaHostName: String, 75 | locations: Array[InetAddress]) { 76 | def getHostAndPort: String = {replicaHostName.substring(0, replicaHostName.indexOf('_'))} 77 | override def toString: String = { 78 | s"SolrReplica(${replicaNumber}) ${replicaName}: url=${replicaUrl}, hostName=${replicaHostName}, locations="+locations.mkString(",") 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/SolrRDDPartition.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark 2 | 3 | import org.apache.solr.client.solrj.SolrQuery 4 | import org.apache.solr.common.params.SolrParams 5 | import org.apache.spark.Partition 6 | 7 | trait SolrRDDPartition extends Partition { 8 | val solrShard: SolrShard 9 | val query: SolrQuery 10 | var preferredReplica: SolrReplica // Preferred replica to query 11 | } 12 | 13 | case class CloudStreamPartition( 14 | index: Int, 15 | zkhost:String, 16 | collection:String, 17 | params: SolrParams) 18 | extends Partition 19 | 20 | case class SelectSolrRDDPartition( 21 | index: Int, 22 | cursorMark: String, 23 | solrShard: SolrShard, 24 | query: SolrQuery, 25 | var preferredReplica: SolrReplica) 26 | extends SolrRDDPartition 27 | 28 | case class ExportHandlerPartition( 29 | index: Int, 30 | solrShard: SolrShard, 31 | query: SolrQuery, 32 | var preferredReplica: SolrReplica, 33 | numWorkers: Int, 34 | workerId: Int) 35 | extends SolrRDDPartition 36 | 37 | case class SolrLimitPartition( 38 | index: Int = 0, 39 | zkhost:String, 40 | collection:String, 41 | maxRows: Int, 42 | query: SolrQuery) 43 | extends Partition 44 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/SolrStreamWriter.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark 2 | 3 | import com.lucidworks.spark.util.{SolrQuerySupport, SolrSupport} 4 | import org.apache.spark.sql.{DataFrame, SparkSession} 5 | import org.apache.spark.sql.execution.streaming.Sink 6 | import org.apache.spark.sql.streaming.OutputMode 7 | import com.lucidworks.spark.util.ConfigurationConstants._ 8 | import org.apache.spark.sql.types.StructType 9 | 10 | 11 | /** 12 | * Writes a Spark stream to Solr 13 | * @param sparkSession 14 | * @param parameters 15 | * @param partitionColumns 16 | * @param outputMode 17 | * @param solrConf 18 | */ 19 | class SolrStreamWriter( 20 | val sparkSession: SparkSession, 21 | parameters: 
Map[String, String], 22 | val partitionColumns: Seq[String], 23 | val outputMode: OutputMode)( 24 | implicit val solrConf : SolrConf = new SolrConf(parameters)) 25 | extends Sink with LazyLogging { 26 | 27 | require(solrConf.getZkHost.isDefined, s"Parameter ${SOLR_ZK_HOST_PARAM} not defined") 28 | require(solrConf.getCollection.isDefined, s"Parameter ${SOLR_COLLECTION_PARAM} not defined") 29 | 30 | val collection : String = solrConf.getCollection.get 31 | val zkhost: String = solrConf.getZkHost.get 32 | 33 | lazy val solrVersion : String = SolrSupport.getSolrVersion(solrConf.getZkHost.get) 34 | lazy val uniqueKey: String = SolrQuerySupport.getUniqueKey(zkhost, collection.split(",")(0)) 35 | 36 | lazy val dynamicSuffixes: Set[String] = SolrQuerySupport.getFieldTypes( 37 | Set.empty, 38 | SolrSupport.getSolrBaseUrl(zkhost), 39 | SolrSupport.getCachedCloudClient(zkhost), 40 | collection, 41 | skipDynamicExtensions = false) 42 | .keySet 43 | .filter(f => f.startsWith("*_") || f.endsWith("_*")) 44 | .map(f => if (f.startsWith("*_")) f.substring(1) else f.substring(0, f.length-1)) 45 | 46 | @volatile private var latestBatchId: Long = -1L 47 | val acc: SparkSolrAccumulator = new SparkSolrAccumulator 48 | val accName: String = if (solrConf.getAccumulatorName.isDefined) solrConf.getAccumulatorName.get else "Records Written" 49 | sparkSession.sparkContext.register(acc, accName) 50 | SparkSolrAccumulatorContext.add(accName, acc.id) 51 | 52 | override def addBatch(batchId: Long, df: DataFrame): Unit = { 53 | if (batchId <= latestBatchId) { 54 | logger.info(s"Skipping already processed batch $batchId") 55 | } else { 56 | val rows = df.collect() 57 | if (rows.nonEmpty) { 58 | val schema: StructType = df.schema 59 | val solrClient = SolrSupport.getCachedCloudClient(zkhost) 60 | 61 | // build up a list of updates to send to the Solr Schema API 62 | val fieldsToAddToSolr = SolrRelation.getFieldsToAdd(schema, solrConf, solrVersion, dynamicSuffixes) 63 | 64 | if (fieldsToAddToSolr.nonEmpty) { 65 | SolrRelation.addFieldsForInsert(fieldsToAddToSolr, collection, solrClient) 66 | } 67 | 68 | val solrDocs = rows.toStream.map(row => SolrRelation.convertRowToSolrInputDocument(row, solrConf, uniqueKey)) 69 | acc.add(solrDocs.length.toLong) 70 | SolrSupport.sendBatchToSolrWithRetry(zkhost, solrClient, collection, solrDocs, solrConf.commitWithin) 71 | logger.info(s"Written ${solrDocs.length} documents to Solr collection $collection from batch $batchId") 72 | latestBatchId = batchId 73 | } 74 | } 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/SparkSolrAccumulatorContext.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark 2 | 3 | import scala.collection.concurrent.TrieMap 4 | 5 | /** 6 | * Spark made it impossible to lookup an accumulator by name. 
Holding a global singleton here, so that external 7 | * clients that use this library can access the accumulators that are created by spark-solr for reading/writing 8 | * Get rid of this once Spark ties accumulators to the context SPARK-13051 9 | * 10 | * Not really happy about the global singleton but I don't see any other way to do it 11 | */ 12 | object SparkSolrAccumulatorContext { 13 | 14 | private val accMapping = TrieMap.empty[String, Long] 15 | 16 | def remove(name: String): Unit = { 17 | accMapping.remove(name) 18 | } 19 | 20 | def add(name: String, id: Long): Unit = { 21 | accMapping.put(name, id) 22 | } 23 | 24 | def getId(name: String): Option[Long] = { 25 | accMapping.get(name) 26 | } 27 | 28 | override def toString = s"SparkSolrAccumulatorContext($accMapping)" 29 | } 30 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/example/NewRDDExample.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example 2 | 3 | import com.lucidworks.spark.{LazyLogging, SparkApp} 4 | import com.lucidworks.spark.rdd.SelectSolrRDD 5 | import com.lucidworks.spark.util.SolrSupport 6 | import org.apache.commons.cli.{CommandLine, Option} 7 | import org.apache.solr.client.solrj.request.CollectionAdminRequest 8 | import org.apache.spark.{SparkConf, SparkContext} 9 | 10 | class NewRDDExample extends SparkApp.RDDProcessor with LazyLogging { 11 | 12 | override def getName: String = "new-rdd-example" 13 | 14 | override def getOptions: Array[Option] = Array( 15 | Option.builder().longOpt("query").hasArg.required(true).desc("Query to field").build() 16 | ) 17 | 18 | override def run(conf: SparkConf, cli: CommandLine): Int = { 19 | val zkHost = cli.getOptionValue("zkHost", "localhost:9983") 20 | val collection = cli.getOptionValue("collection", "collection1") 21 | val queryStr = cli.getOptionValue("query", "*:*") 22 | 23 | // IMPORTANT: reload the collection to flush caches 24 | println(s"\nReloading collection $collection to flush caches!\n") 25 | val cloudSolrClient = SolrSupport.getCachedCloudClient(zkHost) 26 | val req = CollectionAdminRequest.reloadCollection(collection) 27 | cloudSolrClient.request(req) 28 | 29 | val sc = new SparkContext(conf) 30 | val rdd = new SelectSolrRDD(zkHost, collection, sc).query(queryStr) 31 | val count = rdd.count() 32 | 33 | logger.info("Count is " + count) 34 | sc.stop() 35 | 0 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/example/RDDExample.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example 2 | 3 | import com.lucidworks.spark.{LazyLogging, SparkApp} 4 | import com.lucidworks.spark.rdd.SelectSolrRDD 5 | import com.lucidworks.spark.util.SolrSupport 6 | import org.apache.commons.cli.{CommandLine, Option} 7 | import org.apache.solr.client.solrj.request.CollectionAdminRequest 8 | import org.apache.spark.{SparkConf, SparkContext} 9 | 10 | class RDDExample extends SparkApp.RDDProcessor with LazyLogging { 11 | 12 | override def getName: String = "old-rdd-example" 13 | 14 | override def getOptions: Array[Option] = Array( 15 | Option.builder().longOpt("query").hasArg.required(true).desc("Query to field").build() 16 | ) 17 | 18 | override def run(conf: SparkConf, cli: CommandLine): Int = { 19 | val zkHost = cli.getOptionValue("zkHost", "localhost:9983") 20 | val collection = 
cli.getOptionValue("collection", "collection1") 21 | val queryStr = cli.getOptionValue("query", "*:*") 22 | 23 | // IMPORTANT: reload the collection to flush caches 24 | println(s"\nReloading collection $collection to flush caches!\n") 25 | val cloudSolrClient = SolrSupport.getCachedCloudClient(zkHost) 26 | val req = CollectionAdminRequest.reloadCollection(collection) 27 | cloudSolrClient.request(req) 28 | 29 | val sc = new SparkContext(conf) 30 | val rdd = new SelectSolrRDD(zkHost, collection, sc) 31 | val count = rdd.query(queryStr).count() 32 | 33 | logger.info("Count is " + count) 34 | sc.stop() 35 | 0 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/example/events/EventsimIndexer.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.events 2 | 3 | import java.net.URL 4 | import java.util.{Calendar, TimeZone} 5 | import com.lucidworks.spark.SparkApp.RDDProcessor 6 | import com.lucidworks.spark.fusion.FusionPipelineClient 7 | import org.apache.commons.cli.{CommandLine, Option} 8 | import org.apache.spark.SparkConf 9 | import org.apache.spark.sql.{Row, SparkSession} 10 | 11 | import scala.collection.JavaConverters.bufferAsJavaList 12 | import scala.collection.mutable.ListBuffer 13 | 14 | class EventsimIndexer extends RDDProcessor { 15 | val DEFAULT_ENDPOINT = 16 | "http://localhost:8764/api/apollo/index-pipelines/eventsim-default/collections/eventsim/index" 17 | 18 | def getName: String = "eventsim" 19 | 20 | def getOptions: Array[Option] = { 21 | Array( 22 | Option.builder() 23 | .hasArg().required(true) 24 | .desc("Path to an eventsim JSON file") 25 | .longOpt("eventsimJson").build, 26 | Option.builder() 27 | .hasArg() 28 | .desc("Fusion endpoint(s); default is " + DEFAULT_ENDPOINT) 29 | .longOpt("fusion").build, 30 | Option.builder() 31 | .hasArg() 32 | .desc("Fusion username; default is admin") 33 | .longOpt("fusionUser").build, 34 | Option.builder() 35 | .hasArg() 36 | .desc("Fusion password; required if fusionAuthEnbled=true") 37 | .longOpt("fusionPass").build, 38 | Option.builder() 39 | .hasArg() 40 | .desc("Fusion security realm; default is native") 41 | .longOpt("fusionRealm").build, 42 | Option.builder() 43 | .hasArg() 44 | .desc("Fusion authentication enabled; default is true") 45 | .longOpt("fusionAuthEnabled").build, 46 | Option.builder() 47 | .hasArg() 48 | .desc("Fusion indexing batch size; default is 100") 49 | .longOpt("fusionBatchSize").build 50 | ) 51 | } 52 | 53 | def run(conf: SparkConf, cli: CommandLine): Int = { 54 | val fusionEndpoints: String = cli.getOptionValue("fusion", DEFAULT_ENDPOINT) 55 | val fusionAuthEnabled: Boolean = 56 | "true".equalsIgnoreCase(cli.getOptionValue("fusionAuthEnabled", "true")) 57 | val fusionUser: String = cli.getOptionValue("fusionUser", "admin") 58 | 59 | val fusionPass: String = cli.getOptionValue("fusionPass") 60 | if (fusionAuthEnabled && (fusionPass == null || fusionPass.isEmpty)) 61 | throw new IllegalArgumentException("Fusion password is required when authentication is enabled!") 62 | 63 | val fusionRealm: String = cli.getOptionValue("fusionRealm", "native") 64 | val fusionBatchSize: Int = cli.getOptionValue("fusionBatchSize", "100").toInt 65 | 66 | val urls = fusionEndpoints.split(",").distinct 67 | val url = new URL(urls(0)) 68 | val pipelinePath = url.getPath 69 | 70 | val sparkSession: SparkSession = SparkSession.builder().config(conf).getOrCreate() 71 | 72 | 
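// Summary of the per-partition indexing flow below (descriptive comment, not in the original source): each partition opens its own
// FusionPipelineClient (authenticated when fusionAuthEnabled), converts every Row into an {id, fields} map keyed by userId-sessionId-ts,
// and posts documents to the pipeline path in batches of fusionBatchSize, flushing any remainder once the partition is exhausted.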
sparkSession.read.json(cli.getOptionValue("eventsimJson")).foreachPartition((rows: Iterator[Row]) => { 73 | val fusion: FusionPipelineClient = 74 | if (fusionAuthEnabled) new FusionPipelineClient(fusionEndpoints, fusionUser, fusionPass, fusionRealm) 75 | else new FusionPipelineClient(fusionEndpoints) 76 | 77 | val batch = new ListBuffer[Map[String, _]]() 78 | rows.foreach(next => { 79 | var userId: String = "" 80 | var sessionId: String = "" 81 | var ts: Long = 0 82 | 83 | val fields = new ListBuffer[Map[String, _]]() 84 | for (c <- 0 until next.length) { 85 | val obj = next.get(c) 86 | if (obj != null) { 87 | var colValue = obj 88 | val fieldName = next.schema.fieldNames(c) 89 | if ("ts" == fieldName || "registration" == fieldName) { 90 | ts = obj.asInstanceOf[Long] 91 | val cal = Calendar.getInstance(TimeZone.getTimeZone("UTC")) 92 | cal.setTimeInMillis(ts) 93 | colValue = cal.getTime.toInstant.toString 94 | } else if ("userId" == fieldName) { 95 | userId = obj.toString 96 | } else if ("sessionId" == fieldName) { 97 | sessionId = obj.toString 98 | } 99 | fields += Map("name" -> fieldName, "value" -> colValue) 100 | } 101 | } 102 | 103 | batch += Map("id" -> s"$userId-$sessionId-$ts", "fields" -> fields) 104 | 105 | if (batch.size == fusionBatchSize) { 106 | fusion.postBatchToPipeline(pipelinePath, bufferAsJavaList(batch)) 107 | batch.clear 108 | } 109 | }) 110 | 111 | // post the final batch if any left over 112 | if (batch.nonEmpty) { 113 | fusion.postBatchToPipeline(pipelinePath, bufferAsJavaList(batch)) 114 | batch.clear 115 | } 116 | }) 117 | 118 | sparkSession.stop() 119 | 120 | 0 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/example/query/QueryBenchmark.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.query 2 | 3 | import com.lucidworks.spark.SparkApp 4 | import com.lucidworks.spark.rdd.SelectSolrRDD 5 | import com.lucidworks.spark.util.SolrSupport 6 | import org.apache.commons.cli.{CommandLine, Option} 7 | import org.apache.solr.client.solrj.SolrQuery 8 | import org.apache.solr.client.solrj.request.CollectionAdminRequest 9 | import org.apache.spark.{SparkConf, SparkContext} 10 | 11 | class QueryBenchmark extends SparkApp.RDDProcessor { 12 | def getName: String = "query-solr-benchmark" 13 | 14 | def getOptions: Array[Option] = { 15 | Array( 16 | Option.builder().longOpt("query").hasArg.required(false).desc("URL encoded Solr query to send to Solr, default is *:* (all docs)").build, 17 | Option.builder().longOpt("rows").hasArg.required(false).desc("Number of rows to fetch at once, default is 1000").build, 18 | Option.builder().longOpt("splitsPerShard").hasArg.required(false).desc("Number of splits per shard, default is 3").build, 19 | Option.builder().longOpt("splitField").hasArg.required(false).desc("Name of an indexed numeric field (preferably long type) used to split a shard, default is _version_").build, 20 | Option.builder().longOpt("fields").hasArg.required(false).desc("Comma-delimited list of fields to be returned from the query; default is all fields").build 21 | ) 22 | } 23 | 24 | def run(conf: SparkConf, cli: CommandLine): Int = { 25 | 26 | val zkHost = cli.getOptionValue("zkHost", "localhost:9983") 27 | val collection = cli.getOptionValue("collection", "collection1") 28 | val queryStr = cli.getOptionValue("query", "*:*") 29 | val rows = cli.getOptionValue("rows", "1000").toInt 30 | val splitsPerShard = 
cli.getOptionValue("splitsPerShard", "3").toInt 31 | val splitField = cli.getOptionValue("splitField", "_version_") 32 | 33 | val sc = new SparkContext(conf) 34 | 35 | val solrQuery: SolrQuery = new SolrQuery(queryStr) 36 | 37 | val fields = cli.getOptionValue("fields", "") 38 | if (!fields.isEmpty) 39 | fields.split(",").foreach(solrQuery.addField) 40 | 41 | solrQuery.addSort(new SolrQuery.SortClause("id", "asc")) 42 | solrQuery.setRows(rows) 43 | 44 | val solrRDD: SelectSolrRDD = new SelectSolrRDD(zkHost, collection, sc) 45 | 46 | var startMs: Long = System.currentTimeMillis 47 | 48 | var count = solrRDD.query(solrQuery).splitField(splitField).splitsPerShard(splitsPerShard).count() 49 | 50 | var tookMs: Long = System.currentTimeMillis - startMs 51 | println(s"\nTook $tookMs ms read $count docs using queryShards with $splitsPerShard splits") 52 | 53 | // IMPORTANT: reload the collection to flush caches 54 | println(s"\nReloading collection $collection to flush caches!\n") 55 | val cloudSolrClient = SolrSupport.getCachedCloudClient(zkHost) 56 | val req = CollectionAdminRequest.reloadCollection(collection) 57 | cloudSolrClient.request(req) 58 | 59 | startMs = System.currentTimeMillis 60 | 61 | count = solrRDD.query(solrQuery).count() 62 | 63 | tookMs = System.currentTimeMillis - startMs 64 | println(s"\nTook $tookMs ms read $count docs using queryShards") 65 | 66 | sc.stop() 67 | 0 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/example/query/WordCount.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.query 2 | 3 | import com.lucidworks.spark.SparkApp.RDDProcessor 4 | import com.lucidworks.spark.rdd.{SelectSolrRDD} 5 | import com.lucidworks.spark.util.ConfigurationConstants._ 6 | import org.apache.commons.cli.{CommandLine, Option} 7 | import org.apache.solr.common.SolrDocument 8 | import org.apache.spark.rdd.RDD 9 | import org.apache.spark.sql.{DataFrame, SparkSession} 10 | import org.apache.spark.{SparkConf, SparkContext} 11 | 12 | import scala.collection.immutable.HashMap 13 | 14 | /** 15 | * Example of an wordCount spark app to process tweets from a Solr collection 16 | */ 17 | class WordCount extends RDDProcessor{ 18 | def getName: String = "word-count" 19 | 20 | def getOptions: Array[Option] = { 21 | Array( 22 | Option.builder() 23 | .argName("QUERY") 24 | .longOpt("query") 25 | .hasArg 26 | .required(false) 27 | .desc("URL encoded Solr query to send to Solr") 28 | .build() 29 | ) 30 | } 31 | 32 | def run(conf: SparkConf, cli: CommandLine): Int = { 33 | val zkHost = cli.getOptionValue("zkHost", "localhost:9983") 34 | val collection = cli.getOptionValue("collection", "collection1") 35 | val queryStr = cli.getOptionValue("query", "*:*") 36 | 37 | val sc = SparkContext.getOrCreate(conf) 38 | val solrRDD: SelectSolrRDD = new SelectSolrRDD(zkHost, collection, sc) 39 | val rdd: RDD[SolrDocument] = solrRDD.query(queryStr) 40 | 41 | val words: RDD[String] = rdd.map(doc => if (doc.containsKey("text_t")) doc.get("text_t").toString else "") 42 | val pWords: RDD[String] = words.flatMap(s => s.toLowerCase.replaceAll("[.,!?\n]", " ").trim().split(" ")) 43 | 44 | val wordsCountPairs: RDD[(String, Int)] = pWords.map(s => (s, 1)) 45 | .reduceByKey((a,b) => a+b) 46 | .map(item => item.swap) 47 | .sortByKey(false) 48 | .map(item => item.swap) 49 | 50 | wordsCountPairs.take(20).iterator.foreach(println) 51 | 52 | val sparkSession: 
SparkSession = SparkSession.builder().config(conf).getOrCreate() 53 | // Now use schema information in Solr to build a queryable SchemaRDD 54 | 55 | // Pro Tip: SolrRDD will figure out the schema if you don't supply a list of field names in your query 56 | val options = HashMap[String, String]( 57 | SOLR_ZK_HOST_PARAM -> zkHost, 58 | SOLR_COLLECTION_PARAM -> collection, 59 | SOLR_QUERY_PARAM -> queryStr 60 | ) 61 | 62 | val df: DataFrame = sparkSession.read.format("solr").options(options).load() 63 | val numEchos = df.filter(df.col("type_s").equalTo("echo")).count() 64 | println("numEchos >> " + numEchos) 65 | 66 | sc.stop() 67 | 0 68 | } 69 | } 70 | 71 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/rdd/SolrRDD.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.rdd 2 | 3 | import com.lucidworks.spark._ 4 | import com.lucidworks.spark.util.QueryConstants._ 5 | import com.lucidworks.spark.util.{CacheCloudSolrClient, CacheHttpSolrClient, SolrQuerySupport} 6 | import org.apache.solr.client.solrj.SolrQuery 7 | import org.apache.spark._ 8 | import org.apache.spark.rdd.RDD 9 | import org.apache.spark.scheduler.{SparkListenerApplicationEnd, SparkListenerEvent} 10 | 11 | import scala.reflect.ClassTag 12 | import scala.util.Random 13 | 14 | abstract class SolrRDD[T: ClassTag]( 15 | val zkHost: String, 16 | val collection: String, 17 | @transient private val sc: SparkContext, 18 | requestHandler: Option[String] = None, 19 | query : Option[String] = None, 20 | fields: Option[Array[String]] = None, 21 | rows: Option[Int] = None, 22 | splitField: Option[String] = None, 23 | splitsPerShard: Option[Int] = None, 24 | solrQuery: Option[SolrQuery] = None, 25 | uKey: Option[String] = None) 26 | extends RDD[T](sc, Seq.empty) 27 | with LazyLogging { 28 | 29 | sparkContext.addSparkListener(new SparkFirehoseListener() { 30 | override def onEvent(event: SparkListenerEvent): Unit = event match { 31 | case e: SparkListenerApplicationEnd => 32 | logger.debug(s"Invalidating cloud client and http client caches for event ${e}") 33 | CacheCloudSolrClient.cache.invalidateAll() 34 | CacheHttpSolrClient.cache.invalidateAll() 35 | case _ => 36 | } 37 | }) 38 | 39 | val uniqueKey: String = if (uKey.isDefined) uKey.get else SolrQuerySupport.getUniqueKey(zkHost, collection.split(",")(0)) 40 | 41 | def buildQuery: SolrQuery 42 | 43 | def query(q: String): SolrRDD[T] 44 | 45 | def query(solrQuery: SolrQuery): SolrRDD[T] 46 | 47 | def select(fl: String): SolrRDD[T] 48 | 49 | def select(fl: Array[String]): SolrRDD[T] 50 | 51 | def rows(rows: Int): SolrRDD[T] 52 | 53 | def doSplits(): SolrRDD[T] 54 | 55 | def splitField(field: String): SolrRDD[T] 56 | 57 | def splitsPerShard(splitsPerShard: Int): SolrRDD[T] 58 | 59 | def requestHandler(requestHandler: String): SolrRDD[T] 60 | 61 | def solrCount: BigInt = SolrQuerySupport.getNumDocsFromSolr(collection, zkHost, solrQuery) 62 | 63 | def getReplicaToQuery(partition: SolrRDDPartition, attempt_no: Int): String = { 64 | val preferredReplicaUrl = partition.preferredReplica.replicaUrl 65 | if (attempt_no == 0) 66 | preferredReplicaUrl 67 | else { 68 | logger.info(s"Task attempt no. ${attempt_no}. 
Checking if replica ${preferredReplicaUrl} is healthy") 69 | // can't do much if there is only one replica for this shard 70 | if (partition.solrShard.replicas.length == 1) 71 | return preferredReplicaUrl 72 | 73 | // Switch to another replica as the task has failed 74 | val newReplicaToQuery = SolrRDD.randomReplica(partition.solrShard, partition.preferredReplica) 75 | logger.info(s"Switching from $preferredReplicaUrl to ${newReplicaToQuery.replicaUrl}") 76 | partition.preferredReplica = newReplicaToQuery 77 | newReplicaToQuery.replicaUrl 78 | } 79 | } 80 | 81 | def calculateSplitsPerShard(solrQuery: SolrQuery, shardSize: Int, replicaSize: Int, docsPerTask: Int = 10000, maxRows: Option[Int] = None): Int = { 82 | val minSplitSize = 2 * replicaSize 83 | if (maxRows.isDefined) return minSplitSize 84 | val noOfDocs = SolrQuerySupport.getNumDocsFromSolr(collection, zkHost, Some(solrQuery)) 85 | val splits = noOfDocs / (docsPerTask * shardSize) 86 | val splitsPerShard = if (splits > minSplitSize) { 87 | splits.toInt 88 | } else { 89 | minSplitSize 90 | } 91 | logger.debug(s"Suggested split size: ${splitsPerShard} for collection size: ${noOfDocs} using query: ${solrQuery}") 92 | splitsPerShard 93 | } 94 | } 95 | 96 | object SolrRDD { 97 | 98 | def randomReplicaLocation(solrShard: SolrShard): String = { 99 | randomReplica(solrShard).replicaUrl 100 | } 101 | 102 | def randomReplica(solrShard: SolrShard): SolrReplica = { 103 | solrShard.replicas(Random.nextInt(solrShard.replicas.size)) 104 | } 105 | 106 | def randomReplica(solrShard: SolrShard, replicaToExclude: SolrReplica): SolrReplica = { 107 | val filteredReplicas = solrShard.replicas.filter(p => p.equals(replicaToExclude)) 108 | solrShard.replicas(Random.nextInt(filteredReplicas.size)) 109 | } 110 | 111 | def apply( 112 | zkHost: String, 113 | collection: String, 114 | @transient sc: SparkContext, 115 | requestHandler: Option[String] = None, 116 | query : Option[String] = Option(DEFAULT_QUERY), 117 | fields: Option[Array[String]] = None, 118 | rows: Option[Int] = Option(DEFAULT_PAGE_SIZE), 119 | splitField: Option[String] = None, 120 | splitsPerShard: Option[Int] = None, 121 | solrQuery: Option[SolrQuery] = None, 122 | uKey: Option[String] = None, 123 | maxRows: Option[Int] = None, 124 | accumulator: Option[SparkSolrAccumulator] = None): SolrRDD[_] = { 125 | if (requestHandler.isDefined) { 126 | if (requiresStreamingRDD(requestHandler.get)) { 127 | // streaming doesn't support maxRows 128 | new StreamingSolrRDD(zkHost, collection, sc, requestHandler, query, fields, rows, splitField, splitsPerShard, solrQuery, uKey, accumulator) 129 | } else { 130 | new SelectSolrRDD(zkHost, collection, sc, requestHandler, query, fields, rows, splitField, splitsPerShard, solrQuery, uKey, maxRows, accumulator) 131 | } 132 | } else { 133 | new SelectSolrRDD(zkHost, collection, sc, Some(DEFAULT_REQUEST_HANDLER), query, fields, rows, splitField, splitsPerShard, solrQuery, uKey, maxRows, accumulator) 134 | } 135 | } 136 | 137 | def requiresStreamingRDD(rq: String): Boolean = { 138 | rq == QT_EXPORT || rq == QT_STREAM || rq == QT_SQL 139 | } 140 | 141 | } 142 | 143 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/util/ConfigurationConstants.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.util 2 | 3 | import com.lucidworks.spark.BatchSizeType 4 | 5 | // This should only be used for config options for the sql statements 
[SolrRelation] 6 | object ConfigurationConstants { 7 | val SOLR_ZK_HOST_PARAM: String = "zkhost" 8 | val SOLR_COLLECTION_PARAM: String = "collection" 9 | 10 | // Query params 11 | val SOLR_QUERY_PARAM: String = "query" 12 | val SOLR_FIELD_PARAM: String = "fields" 13 | val SOLR_FILTERS_PARAM: String = "filters" 14 | val SOLR_ROWS_PARAM: String = "rows" 15 | val SOLR_DO_SPLITS: String = "splits" 16 | val SOLR_SPLIT_FIELD_PARAM: String = "split_field" 17 | val SOLR_SPLITS_PER_SHARD_PARAM: String = "splits_per_shard" 18 | val ESCAPE_FIELDNAMES_PARAM: String = "escape_fieldnames" 19 | val SKIP_NON_DOCVALUE_FIELDS: String = "skip_non_dv" 20 | val SOLR_DOC_VALUES: String = "dv" 21 | val FLATTEN_MULTIVALUED: String = "flatten_multivalued" 22 | val REQUEST_HANDLER: String = "request_handler" 23 | val USE_CURSOR_MARKS: String = "use_cursor_marks" 24 | val SOLR_STREAMING_EXPR: String = "expr" 25 | val SOLR_SQL_STMT: String = "sql" 26 | val SORT_PARAM: String = "sort" 27 | 28 | // Index params 29 | val SOFT_AUTO_COMMIT_SECS: String = "soft_commit_secs" 30 | val BATCH_SIZE: String = "batch_size" 31 | // num_docs or num_bytes 32 | val BATCH_SIZE_TYPE: String = "batch_size_type" 33 | val GENERATE_UNIQUE_KEY: String = "gen_uniq_key" 34 | val GENERATE_UNIQUE_CHILD_KEY: String = "gen_uniq_child_key" 35 | val COMMIT_WITHIN_MILLI_SECS: String = "commit_within" 36 | val CHILD_DOC_FIELDNAME: String = "child_doc_fieldname" 37 | val SOLR_FIELD_TYPES: String = "solr_field_types" 38 | 39 | val SAMPLE_SEED: String = "sample_seed" 40 | val SAMPLE_PCT: String = "sample_pct" 41 | 42 | // Time series partitioning params 43 | 44 | val PARTITION_BY:String="partition_by" 45 | val TIMESTAMP_FIELD_NAME:String="timestamp_field_name" 46 | val TIME_PERIOD:String="time_period" 47 | val DATETIME_PATTERN:String="datetime_pattern" 48 | val TIMEZONE_ID:String="timezone_id" 49 | val MAX_ACTIVE_PARTITIONS:String="max_active_partitions" 50 | val COLLECTION_ALIAS:String="collection_alias" 51 | 52 | val ARBITRARY_PARAMS_STRING: String = "solr.params" 53 | 54 | val SCHEMA: String = "schema" 55 | val MAX_SHARDS_FOR_SCHEMA_SAMPLING = "max_schema_sampling_shards" 56 | val STREAMING_EXPR_SCHEMA: String = "expr_schema" 57 | val SOLR_SQL_SCHEMA: String = "sql_schema" 58 | val EXCLUDE_FIELDS: String = "exclude_fields" 59 | val MAX_ROWS: String = "max_rows" 60 | 61 | val ACCUMULATOR_NAME: String = "acc_name" 62 | } 63 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/util/Constants.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.util 2 | 3 | object Constants { 4 | val SOLR_FORMAT = "solr" 5 | val PROMOTE_TO_DOUBLE = "promote_to_double" 6 | } 7 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/util/JavaApiHelper.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.util 2 | 3 | import scala.reflect.ClassTag 4 | 5 | object JavaApiHelper { 6 | 7 | // Copied from {@JavaApiHelper} in spark-cassandra-connector project 8 | /** Returns a `ClassTag` of a given runtime class. 
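 * Handy when calling from Java code where an implicit ClassTag is not available, e.g. JavaApiHelper.getClassTag(SolrDocument.class) (illustrative call, not from the original Scaladoc).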
*/ 9 | def getClassTag[T](clazz: Class[T]): ClassTag[T] = ClassTag(clazz) 10 | } 11 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/util/JsonUtil.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.util 2 | 3 | import org.json4s._ 4 | 5 | object JsonUtil { 6 | 7 | implicit class JValueExtended(value: JValue) { 8 | def has(childString: String): Boolean = { 9 | (value \ childString) != JNothing 10 | } 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/util/QueryConstants.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.util 2 | 3 | object QueryConstants { 4 | // Request handlers 5 | val QT_STREAM = "/stream" 6 | val QT_SQL = "/sql" 7 | val QT_EXPORT = "/export" 8 | val QT_SELECT = "/select" 9 | 10 | val DEFAULT_REQUIRED_FIELD: String = "id" 11 | val DEFAULT_PAGE_SIZE: Int = 5000 12 | val DEFAULT_QUERY: String = "*:*" 13 | val DEFAULT_SPLITS_PER_SHARD: Int = 10 14 | val DEFAULT_SPLIT_FIELD: String = "_version_" 15 | val DEFAULT_REQUEST_HANDLER: String = QT_SELECT 16 | val DEFAULT_TIMESTAMP_FIELD_NAME: String = "timestamp_tdt" 17 | val DEFAULT_TIME_PERIOD: String = "1DAYS" 18 | val DEFAULT_TIMEZONE_ID: String = "UTC" 19 | val DEFAULT_DATETIME_PATTERN: String = "yyyy_MM_dd" 20 | val DEFAULT_CHILD_DOC_FIELD_NAME: String = "_childDocuments_" 21 | } 22 | -------------------------------------------------------------------------------- /src/main/scala/com/lucidworks/spark/util/SolrDataFrameImplicits.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.util 2 | 3 | import org.apache.spark.sql.{DataFrameReader, DataFrameWriter, Row, SaveMode} 4 | 5 | /** 6 | * Usage: 7 | * 8 | * import SolrDataFrameImplicits._ 9 | * // then you can: 10 | * val spark: SparkSession 11 | * val collectionName: String 12 | * val df = spark.read.solr(collectionName) 13 | * // do stuff 14 | * df.write.solr(collectionName, overwrite = true) 15 | * // or various other combinations, like setting your own options earlier 16 | * df.write.option("zkhost", "some other solr cluster's zk host").solr(collectionName) 17 | */ 18 | object SolrDataFrameImplicits { 19 | 20 | implicit class SolrReader(reader: DataFrameReader) { 21 | def solr(collection: String, query: String = "*:*") = 22 | reader.format("solr").option("collection", collection).option("query", query).load() 23 | def solr(collection: String, options: Map[String, String]) = 24 | reader.format("solr").option("collection", collection).options(options).load() 25 | } 26 | 27 | implicit class SolrWriter(writer: DataFrameWriter[Row]) { 28 | def solr(collectionName: String, softCommitSecs: Int = 10, overwrite: Boolean = false, format: String = "solr") = { 29 | writer 30 | .format(format) 31 | .option("collection", collectionName) 32 | .option("soft_commit_secs", softCommitSecs.toString) 33 | .mode(if(overwrite) SaveMode.Overwrite else SaveMode.Append) 34 | .save() 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/ml/HasInputColsTransformer.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license 
agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml 19 | 20 | import org.apache.spark.ml.param.shared.{HasOutputCol, HasInputCols} 21 | import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWritable} 22 | 23 | /** Shim to allow access to Spark's private[ml] traits HasInputCols and HasOutputCol */ 24 | abstract class HasInputColsTransformer extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { } 25 | 26 | class TransformerParamsReader[T] extends DefaultParamsReader[T] { 27 | // just to expose the private[ml] stuff 28 | } 29 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/solr/SparkInternalObjects.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.solr 2 | 3 | import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} 4 | 5 | object SparkInternalObjects { 6 | 7 | def getAccumulatorById(id: Long): Option[AccumulatorV2[_, _]] = { 8 | AccumulatorContext.get(id) 9 | } 10 | 11 | } 12 | -------------------------------------------------------------------------------- /src/main/scala/solr/DefaultSource.scala: -------------------------------------------------------------------------------- 1 | package solr 2 | 3 | import com.lucidworks.spark.{SolrRelation, SolrStreamWriter} 4 | import com.lucidworks.spark.util.Constants 5 | import org.apache.spark.sql.execution.streaming.Sink 6 | import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} 7 | import org.apache.spark.sql.sources._ 8 | import org.apache.spark.sql.streaming.OutputMode 9 | 10 | class DefaultSource extends RelationProvider with CreatableRelationProvider with StreamSinkProvider with DataSourceRegister { 11 | 12 | override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { 13 | try { 14 | new SolrRelation(parameters, sqlContext.sparkSession) 15 | } catch { 16 | case re: RuntimeException => throw re 17 | case e: Exception => throw new RuntimeException(e) 18 | } 19 | } 20 | 21 | override def createRelation( 22 | sqlContext: SQLContext, 23 | mode: SaveMode, 24 | parameters: Map[String, String], 25 | df: DataFrame): BaseRelation = { 26 | try { 27 | // TODO: What to do with the saveMode? 
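// A minimal sketch (hypothetical, not the current behavior) of how the requested SaveMode
// could be honored before delegating to SolrRelation; as written below, the relation always
// performs an overwrite-style insert regardless of the mode that was passed in:
//
//   mode match {
//     case SaveMode.ErrorIfExists =>
//       throw new UnsupportedOperationException("SaveMode.ErrorIfExists is not supported by the solr format")
//     case SaveMode.Ignore =>
//       // skip the write entirely and just return a read-only relation
//     case SaveMode.Append | SaveMode.Overwrite =>
//       // fall through to the insert below
//   }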
28 | val solrRelation: SolrRelation = new SolrRelation(parameters, Some(df), sqlContext.sparkSession) 29 | solrRelation.insert(df, overwrite = true) 30 | solrRelation 31 | } catch { 32 | case re: RuntimeException => throw re 33 | case e: Exception => throw new RuntimeException(e) 34 | } 35 | } 36 | 37 | override def shortName(): String = Constants.SOLR_FORMAT 38 | 39 | override def createSink( 40 | sqlContext: SQLContext, 41 | parameters: Map[String, String], 42 | partitionColumns: Seq[String], 43 | outputMode: OutputMode): Sink = { 44 | new SolrStreamWriter(sqlContext.sparkSession, parameters, partitionColumns, outputMode) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/test/java/com/lucidworks/spark/RDDProcessorTestBase.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark; 2 | 3 | import com.lucidworks.spark.rdd.SolrJavaRDD; 4 | import com.lucidworks.spark.util.SolrSupport; 5 | import org.apache.solr.common.SolrDocument; 6 | import org.apache.solr.common.SolrInputDocument; 7 | import org.apache.solr.util.DateMathParser; 8 | import org.apache.spark.api.java.JavaRDD; 9 | import org.apache.spark.api.java.JavaSparkContext; 10 | import org.apache.spark.api.java.function.Function; 11 | import org.apache.spark.sql.SparkSession; 12 | import org.junit.AfterClass; 13 | import org.junit.BeforeClass; 14 | 15 | import java.io.File; 16 | import java.io.Serializable; 17 | import java.util.Arrays; 18 | 19 | import static org.junit.Assert.assertTrue; 20 | 21 | /** 22 | * Base class for testing RDDProcessor implementations. 23 | */ 24 | public class RDDProcessorTestBase extends TestSolrCloudClusterSupport implements Serializable{ 25 | 26 | protected static transient JavaSparkContext jsc; 27 | protected static transient SparkSession sparkSession; 28 | protected int batchSize = 1000; 29 | protected BatchSizeType batchSizeType = BatchSizeType.NUM_DOCS; 30 | 31 | public JavaSparkContext getJsc() { 32 | return jsc; 33 | } 34 | 35 | @BeforeClass 36 | public static void setupSparkSession() { 37 | sparkSession = SparkSession.builder() 38 | .appName("test") 39 | .master("local") 40 | .config("spark.ui.enabled", "false") 41 | .config("spark.default.parallelism", "1") 42 | .getOrCreate(); 43 | jsc = new JavaSparkContext(sparkSession.sparkContext()); 44 | } 45 | 46 | @AfterClass 47 | public static void stopSparkSession() { 48 | try { 49 | sparkSession.stop(); 50 | } finally { 51 | SparkSession.clearActiveSession(); 52 | SparkSession.clearDefaultSession(); 53 | } 54 | } 55 | 56 | protected void buildCollection(String zkHost, String collection) throws Exception { 57 | String[] inputDocs = new String[] { 58 | collection+"-1,foo,bar,1,[a;b],[1;2]", 59 | collection+"-2,foo,baz,2,[c;d],[3;4]", 60 | collection+"-3,bar,baz,3,[e;f],[5;6]" 61 | }; 62 | buildCollection(zkHost, collection, inputDocs, 2); 63 | } 64 | 65 | protected void buildCollection(String zkHost, String collection, int numDocs) throws Exception { 66 | buildCollection(zkHost, collection, numDocs, 2); 67 | } 68 | 69 | protected void buildCollection(String zkHost, String collection, int numDocs, int numShards) throws Exception { 70 | String[] inputDocs = new String[numDocs]; 71 | for (int n=0; n < numDocs; n++) 72 | inputDocs[n] = collection+"-"+n+",foo"+n+",bar"+n+","+n+",[a;b],[1;2]"; 73 | buildCollection(zkHost, collection, inputDocs, numShards); 74 | } 75 | 76 | protected void buildCollection(String zkHost, String collection, String[] 
inputDocs, int numShards) throws Exception { 77 | String confName = "testConfig"; 78 | File confDir = new File("src/test/resources/conf"); 79 | int replicationFactor = 1; 80 | createCollection(collection, numShards, replicationFactor, numShards /* maxShardsPerNode */, confName, confDir); 81 | 82 | // index some docs into the new collection 83 | if (inputDocs != null) { 84 | int numDocsIndexed = indexDocs(zkHost, collection, inputDocs); 85 | SolrSupport.getCachedCloudClient(zkHost).commit(collection); 86 | // verify docs got indexed ... relies on soft auto-commits firing frequently 87 | SolrJavaRDD solrRDD = SolrJavaRDD.get(zkHost, collection, jsc.sc()); 88 | JavaRDD resultsRDD = solrRDD.query("*:*"); 89 | long numFound = resultsRDD.count(); 90 | assertTrue("expected " + numDocsIndexed + " docs in query results from " + collection + ", but got " + numFound, 91 | numFound == (long) numDocsIndexed); 92 | } 93 | } 94 | 95 | protected int indexDocs(String zkHost, String collection, String[] inputDocs) { 96 | JavaRDD input = jsc.parallelize(Arrays.asList(inputDocs), 1); 97 | JavaRDD docs = input.map(new Function() { 98 | public SolrInputDocument call(String row) throws Exception { 99 | String[] fields = row.split(","); 100 | if (fields.length < 6) 101 | throw new IllegalArgumentException("Each test input doc should have at least 6 fields! invalid doc: "+row); 102 | 103 | SolrInputDocument doc = new SolrInputDocument(); 104 | doc.setField("id", fields[0]); 105 | doc.setField("field1_s", fields[1]); 106 | doc.setField("field2_s", fields[2]); 107 | doc.setField("field3_i", Integer.parseInt(fields[3])); 108 | 109 | String[] list = fields[4].substring(1,fields[4].length()-1).split(";"); 110 | for (int i=0; i < list.length; i++) 111 | doc.addField("field4_ss", list[i]); 112 | 113 | list = fields[5].substring(1,fields[5].length()-1).split(";"); 114 | for (int i=0; i < list.length; i++) 115 | doc.addField("field5_ii", Integer.parseInt(list[i])); 116 | 117 | if (fields.length > 6) { 118 | list = fields[6].substring(1,fields[6].length()-1).split(";"); 119 | for (int i=0; i < list.length; i++) { 120 | if (list[i].endsWith("Z")) 121 | doc.addField("field6_tdts", DateMathParser.parseMath(null, list[i])); 122 | else 123 | doc.addField("field6_tdts", DateMathParser.parseMath(null, list[i] + "Z")); 124 | } 125 | 126 | } 127 | 128 | return doc; 129 | } 130 | }); 131 | SolrSupport.indexDocs(zkHost, collection, batchSize, batchSizeType, docs.rdd()); 132 | return inputDocs.length; 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /src/test/java/com/lucidworks/spark/SolrRDDTest.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark; 2 | 3 | import com.lucidworks.spark.rdd.SolrJavaRDD; 4 | import org.apache.solr.client.solrj.SolrQuery; 5 | import org.apache.solr.client.solrj.request.QueryRequest; 6 | import org.apache.solr.common.SolrDocument; 7 | import org.apache.solr.common.SolrInputDocument; 8 | import org.apache.solr.common.cloud.Aliases; 9 | import org.apache.solr.common.params.CollectionParams; 10 | import org.apache.solr.common.params.CoreAdminParams; 11 | import org.apache.solr.common.params.ModifiableSolrParams; 12 | import org.apache.spark.api.java.JavaRDD; 13 | import org.junit.Ignore; 14 | import org.junit.Test; 15 | 16 | import java.util.List; 17 | 18 | import static org.junit.Assert.assertEquals; 19 | import static org.junit.Assert.assertTrue; 20 | 21 | /** 22 | * Test basic 
functionality of the SolrRDD implementations. 23 | */ 24 | public class SolrRDDTest extends RDDProcessorTestBase { 25 | 26 | @Test 27 | public void testCollectionAliasSupport() throws Exception { 28 | String zkHost = cluster.getZkServer().getZkAddress(); 29 | buildCollection(zkHost, "test1"); 30 | buildCollection(zkHost, "test2"); 31 | 32 | // create a collection alias that uses test1 and test2 under the covers 33 | String aliasName = "test"; 34 | String createAliasCollectionsList = "test1,test2"; 35 | ModifiableSolrParams modParams = new ModifiableSolrParams(); 36 | modParams.set(CoreAdminParams.ACTION, CollectionParams.CollectionAction.CREATEALIAS.name()); 37 | modParams.set("name", aliasName); 38 | modParams.set("collections", createAliasCollectionsList); 39 | QueryRequest request = new QueryRequest(modParams); 40 | request.setPath("/admin/collections"); 41 | cloudSolrServer.request(request); 42 | 43 | Aliases aliases = cloudSolrServer.getZkStateReader().getAliases(); 44 | assertEquals(createAliasCollectionsList, aliases.getCollectionAliasMap().get(aliasName)); 45 | 46 | // ok, alias is setup ... now fire a query against it 47 | long expectedNumDocs = 6; 48 | SolrJavaRDD solrRDD = SolrJavaRDD.get(zkHost, aliasName, jsc.sc()); 49 | JavaRDD resultsRDD = solrRDD.query("*:*"); 50 | long numFound = resultsRDD.count(); 51 | assertTrue("expected " + expectedNumDocs + " docs in query results from alias " + aliasName + ", but got " + numFound, 52 | numFound == expectedNumDocs); 53 | } 54 | 55 | @Test 56 | public void testQueryShards() throws Exception { 57 | String zkHost = cluster.getZkServer().getZkAddress(); 58 | String testCollection = "queryShards"; 59 | int numDocs = 2000; 60 | buildCollection(zkHost, testCollection, numDocs, 3); 61 | 62 | SolrJavaRDD solrRDD = SolrJavaRDD.get(zkHost, testCollection, jsc.sc()); 63 | 64 | 65 | SolrQuery testQuery = new SolrQuery(); 66 | testQuery.setQuery("*:*"); 67 | testQuery.setRows(57); 68 | testQuery.addSort(new SolrQuery.SortClause("id", SolrQuery.ORDER.asc)); 69 | JavaRDD docs = solrRDD.queryShards(testQuery); 70 | List docList = docs.collect(); 71 | assertTrue("expected "+numDocs+" from queryShards but only found "+docList.size(), docList.size() == numDocs); 72 | 73 | deleteCollection(testCollection); 74 | } 75 | 76 | @Ignore //Ignore until real-time GET is implemented 77 | @Test 78 | public void testGet() throws Exception { 79 | String zkHost = cluster.getZkServer().getZkAddress(); 80 | String testCollection = "queryGet"; 81 | deleteCollection(testCollection); 82 | buildCollection(zkHost, testCollection, new String[0], 1); 83 | 84 | SolrInputDocument doc = new SolrInputDocument(); 85 | doc.addField("id", "new-dummy-doc"); 86 | doc.addField("field1_s", "value1"); 87 | doc.addField("field2_s", "value2"); 88 | cloudSolrServer.add(testCollection, doc, -1); 89 | 90 | SolrJavaRDD rdd = SolrJavaRDD.get(zkHost, testCollection, jsc.sc()); 91 | // List docs = rdd.get(doc.getField("id").getValue().toString()).collect(); 92 | // assert docs.size() == 1; 93 | // assert docs.get(0).get("id").equals(doc.getField("id").getValue()); 94 | } 95 | 96 | @Test 97 | public void testSolrQuery() throws Exception { 98 | String testCollection = "testSolrQuery"; 99 | 100 | try { 101 | String zkHost = cluster.getZkServer().getZkAddress(); 102 | String[] inputDocs = new String[] { 103 | testCollection+"-1,foo,bar,1,[a;b],[1;2]", 104 | testCollection+"-2,foo,baz,2,[c;d],[3;4]", 105 | testCollection+"-3,bar,baz,3,[e;f],[5;6]" 106 | }; 107 | 108 | buildCollection(zkHost, 
testCollection, inputDocs, 1); 109 | 110 | { 111 | String queryStr = "q=*:*&sort=id asc&fq=field1_s:foo"; 112 | 113 | SolrJavaRDD solrRDD = SolrJavaRDD.get(zkHost, testCollection, jsc.sc()); 114 | List docs = solrRDD.query(queryStr).collect(); 115 | 116 | assert(docs.size() == 2); 117 | assert docs.get(0).get("id").equals(testCollection + "-1"); 118 | } 119 | 120 | { 121 | String queryStr = "q=*:*&sort=id&fq=field3_i:[2 TO 3]"; 122 | 123 | SolrJavaRDD solrRDD = SolrJavaRDD.get(zkHost, testCollection, jsc.sc()); 124 | 125 | List docs = solrRDD.queryNoSplits(queryStr).collect(); 126 | 127 | assert docs.size() == 2; 128 | assert docs.get(0).get("id").equals(testCollection + "-2"); 129 | } 130 | 131 | } finally { 132 | deleteCollection(testCollection); 133 | } 134 | 135 | } 136 | 137 | 138 | } 139 | -------------------------------------------------------------------------------- /src/test/java/com/lucidworks/spark/SolrSqlTest.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark; 2 | 3 | import com.lucidworks.spark.util.EventsimUtil; 4 | import junit.framework.Assert; 5 | import org.apache.spark.sql.Dataset; 6 | import org.apache.spark.sql.Row; 7 | import org.apache.spark.sql.SQLContext; 8 | import org.apache.spark.sql.types.DataTypes; 9 | import org.apache.spark.sql.types.StructType; 10 | import org.junit.Test; 11 | 12 | import java.util.*; 13 | 14 | import static com.lucidworks.spark.util.ConfigurationConstants.*; 15 | 16 | public class SolrSqlTest extends RDDProcessorTestBase { 17 | 18 | 19 | /** 20 | * 1. Create a collection 21 | * 2. Modify the schema to enable docValues for some fields 22 | * 3. Index sample dataset 23 | * 4. Do a series of SQL queries and make sure they return valid results 24 | * @throws Exception 25 | */ 26 | //@Ignore 27 | @Test 28 | public void testSQLQueries() throws Exception { 29 | String testCollectionName = "testSQLQueries"; 30 | try { 31 | 32 | String zkHost = cluster.getZkServer().getZkAddress(); 33 | 34 | HashMap options = new HashMap<>(); 35 | 36 | deleteCollection(testCollectionName); 37 | buildCollection(zkHost, testCollectionName, null, 2); 38 | EventsimUtil.loadEventSimDataSet(zkHost, testCollectionName, sparkSession); 39 | 40 | options.put(SOLR_ZK_HOST_PARAM(), zkHost); 41 | options.put(SOLR_COLLECTION_PARAM(), testCollectionName); 42 | options.put(SOLR_QUERY_PARAM(), "*:*"); 43 | 44 | { 45 | Dataset eventsim = sparkSession.read().format("solr").options(options).option(SOLR_DOC_VALUES(), "true").load(); 46 | eventsim.createOrReplaceTempView("eventsim"); 47 | 48 | Dataset records = sparkSession.sql("SELECT * FROM eventsim"); 49 | StructType schema = records.schema(); 50 | List rows = records.collectAsList(); 51 | assert records.count() == 1000; 52 | 53 | String[] fieldNames = schema.fieldNames(); 54 | // list of fields that are indexed from {@code EventsimUtil#loadEventSimDataSet} 55 | Assert.assertEquals(21, fieldNames.length); // 18 fields from the file + id + _root_ + artist_txt 56 | //assert fieldNames.length == 20; 57 | 58 | Assert.assertEquals(schema.apply("ts").dataType().typeName(), DataTypes.TimestampType.typeName()); 59 | Assert.assertEquals(schema.apply("sessionId").dataType().typeName(), DataTypes.LongType.typeName()); 60 | Assert.assertEquals(schema.apply("length").dataType().typeName(), DataTypes.DoubleType.typeName()); 61 | Assert.assertEquals(schema.apply("song").dataType().typeName(), DataTypes.StringType.typeName()); 62 | 63 | Assert.assertEquals(21, 
((Row)rows.get(0)).length()); 64 | } 65 | 66 | // Filter using SQL syntax and escape field names 67 | { 68 | Dataset eventsim = sparkSession.read().format("solr").options(options).load(); 69 | eventsim.createOrReplaceTempView("eventsim"); 70 | 71 | Dataset records = sparkSession.sql("SELECT `userId`, `ts` from eventsim WHERE `gender` = 'M'"); 72 | assert records.count() == 567; 73 | } 74 | 75 | // Configure the sql query to do splits using an int type field. TODO: Assert the number of partitions based on the field values 76 | { 77 | options.put(SOLR_SPLIT_FIELD_PARAM(), "sessionId"); 78 | options.put(SOLR_SPLITS_PER_SHARD_PARAM(), "10"); 79 | options.put(SOLR_DOC_VALUES(), "false"); 80 | Dataset eventsim = sparkSession.read().format("solr").options(options).load(); 81 | 82 | List rows = eventsim.collectAsList(); 83 | assert rows.size() == 1000; 84 | } 85 | } finally { 86 | deleteCollection(testCollectionName); 87 | } 88 | } 89 | 90 | @Test(expected=IllegalArgumentException.class) 91 | public void testInvalidOptions() { 92 | sparkSession.read().format("solr").load(); 93 | } 94 | 95 | @Test(expected=IllegalArgumentException.class) 96 | public void testInvalidCollectionOption() { 97 | 98 | Map options = Collections.singletonMap("zkHost", cluster.getZkServer().getZkAddress()); 99 | sparkSession.read().format("solr").options(options).load(); 100 | } 101 | 102 | } 103 | -------------------------------------------------------------------------------- /src/test/java/com/lucidworks/spark/StreamProcessorTestBase.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark; 2 | 3 | import java.io.Serializable; 4 | 5 | import org.apache.spark.SparkConf; 6 | import org.apache.spark.streaming.Duration; 7 | import org.apache.spark.streaming.api.java.JavaStreamingContext; 8 | import org.junit.After; 9 | import org.junit.Before; 10 | 11 | /** 12 | * Base class for tests that need a SolrCloud cluster and a JavaStreamingContext. 
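 * A fresh local-mode JavaStreamingContext with a 500 ms batch duration is created before each test
 * and stopped afterwards along with its SparkContext, so every test starts from a clean context.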
13 | */ 14 | public abstract class StreamProcessorTestBase extends TestSolrCloudClusterSupport implements Serializable { 15 | 16 | protected transient JavaStreamingContext jssc; 17 | 18 | @Before 19 | public void setupSparkStreamingContext() { 20 | SparkConf conf = new SparkConf() 21 | .setMaster("local") 22 | .setAppName("test") 23 | .set("spark.default.parallelism", "1"); 24 | jssc = new JavaStreamingContext(conf, new Duration(500)); 25 | } 26 | 27 | @After 28 | public void stopSparkStreamingContext() { 29 | jssc.stop(true, true); 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /src/test/java/com/lucidworks/spark/example/hadoop/HdfsToSolrRDDProcessorTest.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.hadoop; 2 | 3 | import com.lucidworks.spark.SparkApp; 4 | import org.junit.Ignore; 5 | import org.junit.Test; 6 | 7 | import static org.junit.Assert.fail; 8 | 9 | public class HdfsToSolrRDDProcessorTest { 10 | 11 | @Ignore 12 | @Test 13 | public void testRDDProcessor() { 14 | String[] args = new String[] { 15 | "hdfs-to-solr", "-zkHost", "localhost:9983", 16 | "-collection", "gettingstarted", 17 | "-hdfsPath", "hdfs://localhost:9000/user/timpotter/perf", 18 | "-master", "local[2]", "-v" 19 | }; 20 | 21 | try { 22 | SparkApp.main(args); 23 | } catch (Exception exc) { 24 | exc.printStackTrace(); 25 | fail(getClass().getSimpleName()+" failed due to: "+exc); 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/test/java/com/lucidworks/spark/example/hadoop/Logs2SolrRDDProcessorTest.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.hadoop; 2 | 3 | import com.lucidworks.spark.SparkApp; 4 | import org.junit.Ignore; 5 | import org.junit.Test; 6 | 7 | import static org.junit.Assert.fail; 8 | 9 | public class Logs2SolrRDDProcessorTest { 10 | 11 | @Ignore 12 | @Test 13 | public void testRDDProcessor() { 14 | String[] args = new String[] { 15 | "logs2solr", "-zkHost", "localhost:9983", 16 | "-collection", "gettingstarted", 17 | "-hdfsPath", "hdfs://localhost:9000/user/timpotter/gc_logs", 18 | "-master", "local[2]", "-v" 19 | }; 20 | 21 | try { 22 | SparkApp.main(args); 23 | } catch (Exception exc) { 24 | exc.printStackTrace(); 25 | fail(getClass().getSimpleName()+" failed due to: "+exc); 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/test/java/com/lucidworks/spark/example/query/BuildQueryTest.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.query; 2 | 3 | import com.lucidworks.spark.util.SolrQuerySupport; 4 | import org.apache.solr.client.solrj.SolrQuery; 5 | import org.junit.Test; 6 | 7 | import java.net.URLEncoder; 8 | import java.nio.charset.StandardCharsets; 9 | import java.util.List; 10 | 11 | import static org.junit.Assert.assertEquals; 12 | import static org.junit.Assert.assertNotNull; 13 | import static org.junit.Assert.assertTrue; 14 | import static com.lucidworks.spark.util.QueryConstants.*; 15 | 16 | public class BuildQueryTest { 17 | 18 | @Test 19 | public void testQueryBuilder() { 20 | SolrQuery q = null; 21 | 22 | q = SolrQuerySupport.toQuery(null); 23 | assertEquals("*:*", q.getQuery()); 24 | assertEquals(new Integer(DEFAULT_PAGE_SIZE()), q.getRows()); 25 | 26 | q = 
SolrQuerySupport.toQuery("q=*:*") ; 27 | assertEquals("*:*", q.getQuery()); 28 | assertEquals(new Integer(DEFAULT_PAGE_SIZE()), q.getRows()); 29 | 30 | q = SolrQuerySupport.toQuery("q={!geofilt sfield=geo_location pt=44.9609,-93.2642 d=50}") ; 31 | assertEquals("{!geofilt sfield=geo_location pt=44.9609,-93.2642 d=50}", q.getQuery()); 32 | 33 | q = SolrQuerySupport.toQuery("{!geofilt sfield=geo_location pt=44.9609,-93.2642 d=50}") ; 34 | assertEquals("{!geofilt sfield=geo_location pt=44.9609,-93.2642 d=50}", q.getQuery()); 35 | 36 | String qs = "text:hello"; 37 | String fq = "price:[100 TO *]"; 38 | String sort = "id"; 39 | q = SolrQuerySupport.toQuery("q="+encode(qs)+"&fq="+encode(fq)+"&sort="+sort); 40 | assertEquals(qs, q.getQuery()); 41 | assertEquals(new Integer(DEFAULT_PAGE_SIZE()), q.getRows()); 42 | assertTrue(q.getFilterQueries().length == 1); 43 | assertEquals(fq, q.getFilterQueries()[0]); 44 | List sorts = q.getSorts(); 45 | assertNotNull(sorts); 46 | assertTrue(sorts.size() == 1); 47 | SolrQuery.SortClause sortClause = sorts.get(0); 48 | assertEquals(SolrQuery.SortClause.create("id","asc"), sortClause); 49 | } 50 | 51 | private String encode(String val) { 52 | try { 53 | return URLEncoder.encode(val, StandardCharsets.UTF_8.name()); 54 | } catch (java.io.UnsupportedEncodingException uee) { 55 | throw new RuntimeException(uee); 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/test/java/com/lucidworks/spark/example/query/ReadTermVectorsTest.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.query; 2 | 3 | import com.lucidworks.spark.SparkApp; 4 | import org.junit.Ignore; 5 | import org.junit.Test; 6 | 7 | import static org.junit.Assert.fail; 8 | 9 | public class ReadTermVectorsTest { 10 | 11 | @Ignore 12 | @Test 13 | public void testQueryProcessor() { 14 | String[] args = new String[] { 15 | "term-vectors", "-zkHost", "localhost:9983", 16 | "-collection", "gettingstarted", "-query", "*:*", 17 | "-field", "name", 18 | "-master", "local[2]", "-v" 19 | }; 20 | 21 | try { 22 | SparkApp.main(args); 23 | } catch (Exception exc) { 24 | exc.printStackTrace(); 25 | fail("QueryProcessor failed due to: "+exc); 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/test/java/com/lucidworks/spark/example/query/WordCountTest.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.query; 2 | 3 | import com.lucidworks.spark.SparkApp; 4 | import org.junit.Ignore; 5 | import org.junit.Test; 6 | 7 | import static org.junit.Assert.fail; 8 | 9 | public class WordCountTest { 10 | 11 | @Ignore 12 | @Test 13 | public void testQueryProcessor() { 14 | String[] args = new String[] { 15 | "com.lucidworks.spark.example.query.WordCount", "-zkHost", "localhost:9983", 16 | "-collection", "gettingstarted", "-query", "*:*", 17 | "-master", "local[2]" 18 | }; 19 | 20 | try { 21 | SparkApp.main(args); 22 | } catch (Exception exc) { 23 | exc.printStackTrace(); 24 | fail("WordCount failed due to: "+exc); 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/test/java/com/lucidworks/spark/example/streaming/BasicIndexingTest.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.streaming; 2 | 3 | import java.io.File; 4 | import 
java.util.Arrays; 5 | import java.util.concurrent.LinkedBlockingDeque; 6 | 7 | import com.lucidworks.spark.BatchSizeType; 8 | import com.lucidworks.spark.StreamProcessorTestBase; 9 | import com.lucidworks.spark.rdd.SolrJavaRDD; 10 | import com.lucidworks.spark.util.SolrSupport; 11 | import org.apache.solr.client.solrj.SolrQuery; 12 | import org.apache.solr.common.SolrDocument; 13 | import org.apache.solr.common.SolrInputDocument; 14 | import org.apache.spark.api.java.JavaRDD; 15 | import org.apache.spark.api.java.function.Function; 16 | import org.apache.spark.streaming.api.java.JavaDStream; 17 | import org.junit.Ignore; 18 | import org.junit.Test; 19 | 20 | import static org.junit.Assert.assertTrue; 21 | 22 | /** 23 | * Indexes some docs into Solr and then verifies they were indexed correctly from Spark. 24 | */ 25 | @Ignore 26 | public class BasicIndexingTest extends StreamProcessorTestBase { 27 | 28 | @Test 29 | public void testIndexing() throws Exception { 30 | // create a collection named "test" 31 | String confName = "testConfig"; 32 | File confDir = new File("src/test/resources/conf"); 33 | String testCollection = "test"; 34 | int numShards = 1; 35 | int replicationFactor = 1; 36 | 37 | createCollection(testCollection, numShards, replicationFactor, 1, confName, confDir); 38 | 39 | // Create a stream of input docs to be indexed 40 | String[] inputDocs = new String[] { 41 | "1,foo,bar", 42 | "2,foo,baz", 43 | "3,bar,baz" 44 | }; 45 | 46 | // transform the test RDD into an input stream 47 | JavaRDD input = jssc.sparkContext().parallelize(Arrays.asList(inputDocs),1); 48 | LinkedBlockingDeque> queue = new LinkedBlockingDeque>(); 49 | queue.add(input); 50 | 51 | // map input data to SolrInputDocument objects to be indexed 52 | JavaDStream docs = jssc.queueStream(queue).map( 53 | new Function() { 54 | public SolrInputDocument call(String row) { 55 | String[] fields = row.split(","); 56 | SolrInputDocument doc = new SolrInputDocument(); 57 | doc.setField("id", fields[0]); 58 | doc.setField("field1", fields[1]); 59 | doc.setField("field2", fields[2]); 60 | return doc; 61 | } 62 | } 63 | ); 64 | 65 | // Send to Solr 66 | String zkHost = cluster.getZkServer().getZkAddress(); 67 | SolrSupport.indexDStreamOfDocs(zkHost, testCollection, 1, BatchSizeType.NUM_DOCS, docs.dstream()); 68 | 69 | // Actually start processing the stream here ... 70 | jssc.start(); 71 | 72 | // let the docs flow through the streaming job 73 | Thread.sleep(2000); 74 | 75 | // verify docs got indexed ... 
relies on soft auto-commits firing frequently 76 | SolrJavaRDD solrRDD = SolrJavaRDD.get(zkHost, testCollection, jssc.sparkContext().sc()); 77 | JavaRDD resultsRDD = 78 | solrRDD.queryShards(new SolrQuery("*:*")); 79 | 80 | long numFound = resultsRDD.count(); 81 | assertTrue("expected "+inputDocs.length+" docs in query results, but got "+numFound, 82 | numFound == inputDocs.length); 83 | 84 | // Commented out until we implement real-time get in BaseRDD 85 | // JavaRDD doc1 = solrRDD.get("1"); 86 | // assertEquals("foo", doc1.collect().get(0).getFirstValue("field1")); 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /src/test/java/com/lucidworks/spark/example/streaming/DocumentFilteringStreamProcessorTest.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.example.streaming; 2 | 3 | import com.lucidworks.spark.SparkApp; 4 | import org.junit.Ignore; 5 | import org.junit.Test; 6 | 7 | import static org.junit.Assert.fail; 8 | 9 | public class DocumentFilteringStreamProcessorTest { 10 | 11 | @Ignore 12 | @Test 13 | public void testIndexing() throws Exception { 14 | String[] args = new String[] { 15 | "docfilter", "-zkHost", "localhost:9983", 16 | "-collection", "gettingstarted", 17 | "-master", "local[2]", "-v" 18 | }; 19 | 20 | try { 21 | SparkApp.main(args); 22 | } catch (Exception exc) { 23 | exc.printStackTrace(); 24 | fail("QueryProcessor failed due to: "+exc); 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/test/java/com/lucidworks/spark/query/StreamingResultsIteratorTest.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.query; 2 | 3 | import com.lucidworks.spark.RDDProcessorTestBase; 4 | import com.lucidworks.spark.util.SolrQuerySupport; 5 | import com.lucidworks.spark.util.SolrRelationUtil; 6 | import org.apache.solr.client.solrj.SolrQuery; 7 | import org.apache.solr.client.solrj.impl.CloudSolrClient; 8 | import org.apache.solr.common.SolrDocument; 9 | import org.apache.solr.common.SolrInputDocument; 10 | import org.apache.spark.sql.types.StructType; 11 | import org.junit.Test; 12 | 13 | import java.util.ArrayList; 14 | import java.util.Collections; 15 | import java.util.List; 16 | import java.util.Random; 17 | 18 | import static org.junit.Assert.*; 19 | import com.lucidworks.spark.util.QueryResultsIterator; 20 | import scala.Some$; 21 | 22 | public class StreamingResultsIteratorTest extends RDDProcessorTestBase { 23 | 24 | //@Ignore 25 | @Test 26 | public void testCursorsWithUnstableIdSort() throws Exception { 27 | final CloudSolrClient cloudSolrClient = cloudSolrServer; // from base 28 | 29 | String zkHost = cluster.getZkServer().getZkAddress(); 30 | String testCollection = "testStreamingResultsIterator"; 31 | buildCollection(zkHost, testCollection, null, 1); 32 | cloudSolrClient.setDefaultCollection(testCollection); 33 | 34 | final Random random = new Random(5150); 35 | final int numDocs = 100; 36 | final List ids = new ArrayList(numDocs); 37 | for (int i=0; i < numDocs; i++) 38 | ids.add(String.valueOf(i)); 39 | Collections.shuffle(ids); 40 | 41 | // need two threads for this: 1) to send docs to Solr with randomized keys 42 | Thread sendDocsThread = new Thread() { 43 | @Override 44 | public void run() { 45 | for (int i=0; i < numDocs; i++) { 46 | SolrInputDocument doc = new SolrInputDocument(); 47 | doc.setField("id", ids.get(i)); 48 
| try { 49 | cloudSolrClient.add(doc, 50); 50 | } catch (Exception e) { 51 | throw new RuntimeException(e); 52 | } 53 | 54 | long sleepMs = random.nextInt(10) * 10L; 55 | try { 56 | Thread.sleep(sleepMs); 57 | } catch (Exception exc) { exc.printStackTrace(); } 58 | 59 | if (i % 10 == 0) 60 | System.out.println("sendDocsThread has sent "+(i+1)+" docs so far ..."); 61 | } 62 | 63 | System.out.println("sendDocsThread finished sending "+numDocs+" docs"); 64 | } 65 | }; 66 | 67 | SolrQuery solrQuery = new SolrQuery("*:*"); 68 | solrQuery.setFields("id"); 69 | solrQuery.setRows(5); 70 | solrQuery.setSort(new SolrQuery.SortClause("id", "asc")); 71 | solrQuery.set("collection", testCollection); 72 | 73 | sendDocsThread.start(); 74 | Thread.sleep(2000); 75 | 76 | //StreamingResultsIterator sri = new StreamingResultsIterator(cloudSolrClient, solrQuery, "*"); 77 | QueryResultsIterator sri = new QueryResultsIterator(cloudSolrClient, solrQuery, "*") ; 78 | int numDocsFound = 0; 79 | boolean hasNext = false; 80 | do { 81 | Thread.sleep(500); 82 | hasNext = sri.hasNext(); 83 | } while (hasNext == false); 84 | 85 | while (sri.hasNext()) { 86 | SolrDocument next = sri.next(); 87 | assertNotNull(next); 88 | ++numDocsFound; 89 | 90 | // sleep a little to let the underlying results change 91 | long sleepMs = random.nextInt(10) * 5L; 92 | try { 93 | Thread.sleep(sleepMs); 94 | } catch (Exception exc) { exc.printStackTrace(); } 95 | } 96 | 97 | try { 98 | sendDocsThread.interrupt(); 99 | } catch (Exception ignore) {} 100 | 101 | //assertTrue("Iterator didn't return all docs! Num found: "+numDocsFound, numDocs == numDocsFound); 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /src/test/java/com/lucidworks/spark/query/sql/SolrSQLSupportTest.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.query.sql; 2 | 3 | import org.junit.Test; 4 | 5 | import java.util.Map; 6 | 7 | import static org.junit.Assert.assertEquals; 8 | import static org.junit.Assert.assertNotNull; 9 | import static org.junit.Assert.assertTrue; 10 | 11 | public class SolrSQLSupportTest { 12 | @Test 13 | public void testSQLParse() throws Exception { 14 | String sqlStmt = "SELECT DISTINCT movie_id, COUNT(*) as agg_count, avg(rating) as avg_rating, sum(rating) as sum_rating, min(rating) as min_rating, max(rating) as max_rating " + 15 | "FROM ratings GROUP BY movie_id ORDER BY movie_id asc"; 16 | Map cols = SolrSQLSupport.parseColumns(sqlStmt); 17 | assertNotNull(cols); 18 | assertTrue(cols.size() == 6); 19 | assertEquals("agg_count", cols.get("COUNT(*)")); 20 | assertEquals("movie_id", cols.get("movie_id")); 21 | assertEquals("avg_rating", cols.get("avg(rating)")); 22 | assertEquals("sum_rating", cols.get("sum(rating)")); 23 | assertEquals("min_rating", cols.get("min(rating)")); 24 | assertEquals("max_rating", cols.get("max(rating)")); 25 | 26 | String selectStar = "SELECT * FROM ratings"; 27 | cols = SolrSQLSupport.parseColumns(selectStar); 28 | assertNotNull(cols); 29 | assertTrue(cols.size() == 0); 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/test/java/com/lucidworks/spark/solr/TestEmbeddedSolrServer.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.solr; 2 | 3 | import com.lucidworks.spark.BatchSizeType; 4 | import com.lucidworks.spark.RDDProcessorTestBase; 5 | import 
com.lucidworks.spark.util.EmbeddedSolrServerFactory; 6 | import org.apache.solr.client.solrj.SolrQuery; 7 | import org.apache.solr.client.solrj.embedded.EmbeddedSolrServer; 8 | import org.apache.solr.client.solrj.response.QueryResponse; 9 | import org.junit.Test; 10 | 11 | public class TestEmbeddedSolrServer extends RDDProcessorTestBase { 12 | 13 | @Test 14 | public void testEmbeddedSolrServerUseNumDocsAsBatchSize() throws Exception { 15 | this.batchSizeType = BatchSizeType.NUM_DOCS; 16 | this.batchSize = 3; // 3 docs 17 | runEmbeddedServerTest("batchNumDocs", 10); 18 | } 19 | 20 | @Test 21 | public void testEmbeddedSolrServerUseNumBytesAsBatchSize() throws Exception { 22 | this.batchSizeType = BatchSizeType.NUM_BYTES; 23 | this.batchSize = 1000; // 1000 bytes 24 | runEmbeddedServerTest("batchNumBytes", 100); 25 | } 26 | 27 | private void runEmbeddedServerTest(String testCollection, int numDocs) throws Exception { 28 | EmbeddedSolrServer embeddedSolrServer = null; 29 | try { 30 | String zkHost = cluster.getZkServer().getZkAddress(); 31 | buildCollection(zkHost, testCollection, numDocs, 1); 32 | embeddedSolrServer = EmbeddedSolrServerFactory.singleton.getEmbeddedSolrServer(zkHost, testCollection); 33 | QueryResponse queryResponse = embeddedSolrServer.query(new SolrQuery("*:*")); 34 | assert (queryResponse.getStatus() == 0); 35 | } finally { 36 | if (embeddedSolrServer != null) { 37 | embeddedSolrServer.close(); 38 | } 39 | deleteCollection(testCollection); 40 | } 41 | } 42 | 43 | @Test 44 | public void testEmbeddedSolrServerCustomConfig() throws Exception { 45 | String testCollection = "testEmbeddedSolrServerConfig"; 46 | EmbeddedSolrServer embeddedSolrServer = null; 47 | try { 48 | String zkHost = cluster.getZkServer().getZkAddress(); 49 | buildCollection(zkHost, testCollection, 10, 1); 50 | embeddedSolrServer = EmbeddedSolrServerFactory.singleton.getEmbeddedSolrServer(zkHost, testCollection, "custom-solrconfig.xml", null); 51 | QueryResponse queryResponse = embeddedSolrServer.query(new SolrQuery("*:*")); 52 | assert(queryResponse.getStatus() == 0); 53 | } finally { 54 | if (embeddedSolrServer != null) { 55 | embeddedSolrServer.close(); 56 | } 57 | deleteCollection(testCollection); 58 | } 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/test/java/com/lucidworks/spark/util/EventsimUtil.java: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.util; 2 | 3 | import com.fasterxml.jackson.databind.ObjectMapper; 4 | import org.apache.log4j.Logger; 5 | import org.apache.solr.client.solrj.SolrQuery; 6 | import org.apache.solr.client.solrj.impl.CloudSolrClient; 7 | import org.apache.spark.sql.Dataset; 8 | import org.apache.spark.sql.SparkSession; 9 | import org.apache.solr.client.solrj.request.schema.SchemaRequest; 10 | import org.apache.solr.client.solrj.response.schema.SchemaResponse; 11 | import org.apache.solr.common.SolrException; 12 | import org.apache.solr.common.params.ModifiableSolrParams; 13 | import org.apache.spark.sql.api.java.UDF1; 14 | import org.apache.spark.sql.types.DataTypes; 15 | 16 | import java.sql.Timestamp; 17 | import java.util.HashMap; 18 | import java.util.Collections; 19 | import java.util.Map; 20 | 21 | public class EventsimUtil { 22 | static final Logger log = Logger.getLogger(EventsimUtil.class); 23 | private static ObjectMapper objectMapper = new ObjectMapper(); 24 | 25 | /** 26 | * Load the eventsim json dataset and write it using Solr writer 27 | * 
@throws Exception 28 | */ 29 | public static void loadEventSimDataSet(String zkHost, String collectionName, SparkSession sparkSession) throws Exception { 30 | String datasetPath = "src/test/resources/eventsim/sample_eventsim_1000.json"; 31 | Dataset df = sparkSession.read().json(datasetPath); 32 | // Modify the unix timestamp to ISO format for Solr 33 | log.info("Indexing eventsim documents from file " + datasetPath); 34 | 35 | df.createOrReplaceTempView("jdbcDF"); 36 | sparkSession.udf().register("ts2iso", new UDF1() { 37 | public Timestamp call(Long ts) { 38 | return asDate(ts); 39 | } 40 | }, DataTypes.TimestampType); 41 | 42 | // Registering an UDF and re-using it via DataFrames is not available through Java right now. 43 | Dataset newDF = sparkSession.sql("SELECT userAgent, userId, artist, auth, firstName, gender, itemInSession, lastName, " + 44 | "length, level, location, method, page, sessionId, song, " + 45 | "ts2iso(registration) AS registration, ts2iso(ts) AS ts, status from jdbcDF"); 46 | 47 | HashMap options = new HashMap(); 48 | options.put("zkhost", zkHost); 49 | options.put("collection", collectionName); 50 | options.put(ConfigurationConstants.GENERATE_UNIQUE_KEY(), "true"); 51 | 52 | newDF = newDF.withColumn("artist_txt", df.col("artist")); 53 | newDF.write().format("solr").options(options).mode(org.apache.spark.sql.SaveMode.Overwrite).save(); 54 | 55 | CloudSolrClient cloudSolrClient = SolrSupport.getCachedCloudClient(zkHost); 56 | cloudSolrClient.commit(collectionName, true, true); 57 | 58 | long docsInSolr = SolrQuerySupport.getNumDocsFromSolr(collectionName, zkHost, scala.Option.apply((SolrQuery) null)); 59 | if (!(docsInSolr == 1000)) { 60 | throw new Exception("All eventsim documents did not get indexed. Expected '1000'. Actual docs in Solr '" + docsInSolr + "'"); 61 | } 62 | } 63 | 64 | public static void defineTextFields(CloudSolrClient solrCloud, String collection) throws Exception { 65 | Map fieldParams = new HashMap<>(); 66 | fieldParams.put("name", "artist_txt"); 67 | fieldParams.put("indexed", "true"); 68 | fieldParams.put("stored", "true"); 69 | fieldParams.put("multiValued", "false"); 70 | fieldParams.put("type", "text_en"); 71 | ModifiableSolrParams solrParams = new ModifiableSolrParams(); 72 | solrParams.add("updateTimeoutSecs", "30"); 73 | SchemaRequest.AddField addField = new SchemaRequest.AddField(fieldParams); 74 | SchemaRequest.MultiUpdate addFieldsMultiUpdate = new SchemaRequest.MultiUpdate(Collections.singletonList(addField), solrParams); 75 | 76 | // Add the fields using SolrClient 77 | SchemaResponse.UpdateResponse response = addFieldsMultiUpdate.process(solrCloud, collection); 78 | if (response.getStatus() > 400) { 79 | throw new SolrException(SolrException.ErrorCode.getErrorCode(response.getStatus()), "Error indexing fields to the Schema"); 80 | } 81 | log.info("Added new field 'artist_txt' to Solr schema for collection" + collection); 82 | } 83 | 84 | private static Timestamp asDate(Object tsObj) { 85 | if (tsObj != null) { 86 | long tsLong = (tsObj instanceof Number) ? ((Number)tsObj).longValue() : Long.parseLong(tsObj.toString()); 87 | return new Timestamp(tsLong); 88 | } 89 | return null; 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /src/test/resources/conf/lang/stopwords_en.txt: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one or more 2 | # contributor license agreements. 
See the NOTICE file distributed with 3 | # this work for additional information regarding copyright ownership. 4 | # The ASF licenses this file to You under the Apache License, Version 2.0 5 | # (the "License"); you may not use this file except in compliance with 6 | # the License. You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # a couple of test stopwords to test that the words are really being 17 | # configured from this file: 18 | stopworda 19 | stopwordb 20 | 21 | # Standard english stop words taken from Lucene's StopAnalyzer 22 | a 23 | an 24 | and 25 | are 26 | as 27 | at 28 | be 29 | but 30 | by 31 | for 32 | if 33 | in 34 | into 35 | is 36 | it 37 | no 38 | not 39 | of 40 | on 41 | or 42 | such 43 | that 44 | the 45 | their 46 | then 47 | there 48 | these 49 | they 50 | this 51 | to 52 | was 53 | will 54 | with -------------------------------------------------------------------------------- /src/test/resources/conf/solrconfig.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | LATEST 4 | ${solr.data.dir:} 5 | 6 | 7 | 8 | 9 | true 10 | managed-schema 11 | 12 | 13 | ${solr.lock.type:native} 14 | false 15 | 16 | 17 | 18 | ${solr.ulog.dir:} 19 | 20 | 21 | 1000 22 | false 23 | 24 | 25 | 500 26 | 27 | 28 | 29 | 1024 30 | 34 | 38 | 42 | 48 | 49 | true 50 | 20 51 | 200 52 | false 53 | 2 54 | 55 | 56 | 57 | 61 | 62 | 63 | 64 | 65 | 66 | 10 67 | 68 | 69 | 70 | 71 | 72 | json 73 | 74 | 75 | 76 | 77 | 78 | true 79 | json 80 | 81 | 82 | 83 | 84 | 85 | 88 | 89 | 92 | 93 | 94 | 95 | solrpingquery 96 | 97 | 98 | all 99 | 100 | 101 | 102 | 103 | 104 | explicit 105 | true 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | text/plain; charset=UTF-8 114 | 115 | 116 | -------------------------------------------------------------------------------- /src/test/resources/eventsim/fields_schema.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "userId", "type": "string", "indexed": "true", "stored": "true", "docValues": "true" 4 | }, 5 | { 6 | "name": "sessionId", "type": "tint", "indexed": "true", "stored": "true" 7 | }, 8 | { 9 | "name": "page", "type": "string", "indexed": "true", "stored": "true" 10 | }, 11 | { 12 | "name": "auth", "type": "string", "indexed": "true", "stored": "true" 13 | }, 14 | { 15 | "name": "method", "type": "string", "indexed": "true", "stored": "true" 16 | }, 17 | { 18 | "name": "status", "type": "int", "indexed": "true", "stored": "true", "docValues": "true" 19 | }, 20 | { 21 | "name": "level", "type": "string", "indexed": "true", "stored": "true" 22 | }, 23 | { 24 | "name": "itemInSession", "type": "int", "indexed": "true", "stored": "true" 25 | }, 26 | { 27 | "name": "location", "type": "string", "indexed": "true", "stored": "true" 28 | }, 29 | { 30 | "name": "userAgent", "type": "string", "indexed": "true", "stored": "true" 31 | }, 32 | { 33 | "name": "lastName", "type": "string", "indexed": "true", "stored": "true" 34 | }, 35 | { 36 | "name": "firstName", "type": "string", "indexed": "true", "stored": "true" 37 | }, 38 | { 39 | "name": "gender", "type": "string", "indexed": 
"true", "stored": "true" 40 | }, 41 | { 42 | "name": "artist", "type": "string", "indexed": "true", "stored": "true", "docValues": "true" 43 | }, 44 | { 45 | "name": "song", "type": "string", "indexed": "true", "stored": "true", "docValues": "true" 46 | }, 47 | { 48 | "name": "length", "type": "double", "indexed": "true", "stored": "true", "docValues": "true" 49 | }, 50 | { 51 | "name": "timestamp", "type": "tdate", "indexed": "false", "stored": "false", "docValues": "true" 52 | }, 53 | { 54 | "name": "registration", "type": "tdate", "indexed": "true", "stored": "true" 55 | } 56 | ] 57 | -------------------------------------------------------------------------------- /src/test/resources/hive-site.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | javax.jdo.option.ConnectionURL 5 | jdbc:derby:memory:databaseName=metastore_db;create=true 6 | 7 | 8 | javax.jdo.option.ConnectionDriverName 9 | org.apache.derby.jdbc.EmbeddedDriver 10 | 11 | 12 | -------------------------------------------------------------------------------- /src/test/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | log4j.rootLogger=INFO, stdout 2 | log4j.appender.file=org.apache.log4j.RollingFileAppender 3 | log4j.appender.file.MaxFileSize=100MB 4 | log4j.appender.file.MaxBackupIndex=20 5 | log4j.appender.file.File=logs/solr.log 6 | log4j.appender.file.layout=org.apache.log4j.PatternLayout 7 | log4j.appender.file.layout.ConversionPattern=%d{ISO8601} [%t] %-5p %c{3} %x - %m%n 8 | log4j.logger.org.apache.zookeeper=ERROR 9 | log4j.logger.org.apache.http=WARN 10 | log4j.logger.org.apache.solr.core.SolrCore=WARN 11 | log4j.logger.org.apache.solr.update.processor.LogUpdateProcessor=WARN 12 | 13 | log4j.appender.stdout=org.apache.log4j.ConsoleAppender 14 | log4j.appender.stdout.layout=org.apache.log4j.PatternLayout 15 | log4j.appender.stdout.layout.ConversionPattern=%d{ISO8601} [%t] %-5p %c{1} %x - %m%n 16 | 17 | #log4j.logger.org.apache.spark.streaming=DEBUG 18 | #log4j.logger.org.apache.spark.streaming.receiver=DEBUG 19 | log4j.logger.org.apache.spark=WARN 20 | log4j.logger.org.apache.spark.sql.execution.streaming=INFO 21 | log4j.logger.org.apache.spark.sql.execution.WholeStageCodegenExec=WARN 22 | log4j.logger.org.apache.spark.storage.BlockManager=ERROR 23 | log4j.logger.org.apache.spark.sql.catalyst.expressions.codegen=WARN 24 | 25 | log4j.logger.org.apache.spark.ml.tuning.CrossValidator=INFO 26 | log4j.logger.com.lucidworks.spark.fusion.FusionPipelineClient=INFO 27 | log4j.logger.org.apache.solr=WARN 28 | log4j.logger.org.eclipse.jetty=ERROR 29 | log4j.logger.org.spark-project.jetty=ERROR 30 | -------------------------------------------------------------------------------- /src/test/resources/solr.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | ${shareSchema:false} 4 | 5 | 6 | 127.0.0.1 7 | ${hostContext:solr} 8 | ${hostPort:8983} 9 | ${solr.zkclienttimeout:30000} 10 | ${genericCoreNodeNames:true} 11 | ${distribUpdateConnTimeout:45000} 12 | ${distribUpdateSoTimeout:340000} 13 | ${autoReplicaFailoverWaitAfterExpiration:10000} 14 | ${autoReplicaFailoverWorkLoopDelay:10000} 15 | ${autoReplicaFailoverBadNodeExpiration:60000} 16 | 17 | 18 | 20 | ${urlScheme:} 21 | ${socketTimeout:90000} 22 | ${connTimeout:15000} 23 | 24 | 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- 
/src/test/resources/test-data/child_documents.json: -------------------------------------------------------------------------------- 1 | { "user":"a", "dates":[ "2017-05-02", "2017-05-03"], "status":"OK", "tags":[{ "foo":123, "bar":"val1", "parent": "a"},{ "foo":456, "bar":"val2", "parent": "a"}]} 2 | { "user":"b", "dates":[ "2017-04-29", "2017-04-30"], "status":"OK", "tags":[{ "foo":789, "bar":"val1", "parent": "b"},{ "foo":1011, "bar":"val2", "parent": "b"}]} -------------------------------------------------------------------------------- /src/test/resources/test-data/oneusagov/oneusagov_sample.json: -------------------------------------------------------------------------------- 1 | { "h": "10OBm3W", "g": "15r91", "l": "pontifier", "hh": "bitly.com", "u": "http:\/\/www.nsa.gov\/", "r": "direct", "a": "Mozilla\/5.0 (Windows NT 6.1) AppleWebKit\/537.36 (KHTML, like Gecko) Chrome\/39.0.2171.95 Safari\/537.36 OPR\/26.0.1656.60", "t": 1422056470, "nk": 0, "hc": 1365701422, "_id": "f59b204e-4768-a343-5f79-d0907e3ea965", "al": "es-ES,es;q=0.8", "c": "AR", "tz": "America\/Argentina\/Catamarca", "gr": "22", "cy": "Santiago Del Estero", "ll": [ -27.750000, -64.583300 ] } 2 | { "h": "YTSU3p", "g": "YTSUjC", "l": "o_33avl0ri1b", "hh": "bit.ly", "u": "http:\/\/fwp.mt.gov\/education\/hunter\/hunterLandowner\/", "r": "http:\/\/fwp.mt.gov\/", "a": "Mozilla\/5.0 (Windows NT 6.1; WOW64) AppleWebKit\/537.36 (KHTML, like Gecko) Chrome\/39.0.2171.99 Safari\/537.36", "t": 1422056470, "nk": 0, "hc": 1412175311, "_id": "3a60cffa-ec48-db15-fc06-1064c46b6260", "al": "en-US,en;q=0.8", "c": "US", "tz": "America\/Denver", "gr": "MT", "cy": "Butte", "mc": 754, "ll": [ 46.003800, -112.534700 ] } 3 | { "h": "1IRWtqY", "g": "1Bm4KEl", "l": "edelmancms", "hh": "1.usa.gov", "u": "http:\/\/www.cuidadodesalud.gov\/es\/subscribe\/?utm_source=resonate&utm_medium=display&utm_campaign=deadline&utm_content=Deadline_C_300x250_011615", "r": "http:\/\/www.oninstagram.com\/static\/pg\/300x250_ATF.html", "a": "Mozilla\/5.0 (Linux; U; Android 4.1.2; es-us; SGH-T599N Build\/JZO54K) AppleWebKit\/534.30 (KHTML, like Gecko) Version\/4.0 Mobile Safari\/534.30", "t": 1422056470, "nk": 0, "hc": 1421449911, "_id": "af53145c-62c4-8a7d-7441-ce29d81f1be0", "al": "es-US, en-US", "c": "US", "tz": "America\/Chicago", "gr": "TX", "ll": [ 32.783100, -96.806700 ] } 4 | { "h": "1Cmco19", "g": "1Cmco1a", "l": "theusnavy", "hh": "1.usa.gov", "u": "http:\/\/www.navy.mil\/submit\/display.asp?story_id=85295", "r": "http:\/\/m.facebook.com", "a": "Mozilla\/5.0 (iPad; CPU OS 8_1_2 like Mac OS X) AppleWebKit\/600.1.4 (KHTML, like Gecko) Mobile\/12B440 [FBAN\/FBIOS;FBAV\/22.0.0.11.27;FBBV\/6183821;FBDV\/iPad2,4;FBMD\/iPad;FBSN\/iPhone OS;FBSV\/8.1.2;FBSS\/1; FBCR\/;FBID\/tablet;FBLC\/en_US;FBOP\/1]", "t": 1422056470, "nk": 0, "hc": 1421956006, "_id": "d45991ba-65a4-a50c-fde0-d318a47e22aa", "al": "en-us", "c": "PH", "tz": "Asia\/Manila", "gr": "D9", "cy": "Manila", "ll": [ 14.604200, 120.982200 ] } 5 | { "h": "6V2MBv", "g": "36VHYC", "l": "vastateparks", "hh": "bit.ly", "u": "http:\/\/www.dcr.virginia.gov\/state_parks\/ycc.shtml", "r": "direct", "a": "Mozilla\/5.0 (Windows NT 6.3; WOW64) AppleWebKit\/537.36 (KHTML, like Gecko) Chrome\/39.0.2171.99 Safari\/537.36", "t": 1422056470, "nk": 0, "hc": 1260624621, "_id": "efa198d0-d9de-08fb-27a7-aee6293c8693", "al": "en-US,en;q=0.8", "kw": "vspycc", "c": "US", "tz": "America\/New_York", "gr": "VA", "cy": "Centreville", "mc": 511, "ll": [ 38.815900, -77.460700 ] } 6 | { "h": "1CwuSfp", "g": "1y9KZsH", "l": 
"foodsafety", "hh": "1.usa.gov", "u": "http:\/\/www.fda.gov\/Safety\/Recalls\/ucm431432.htm?utm_source=twitterfeed&utm_medium=twitter", "r": "direct", "a": "Mozilla\/5.0 (iPhone; CPU iPhone OS 8_1_2 like Mac OS X) AppleWebKit\/600.1.4 (KHTML, like Gecko) Version\/8.0 Mobile\/12B440 Safari\/600.1.4", "t": 1422056471, "nk": 0, "hc": 1422052296, "_id": "994e1df2-f290-7d27-54b6-05bfbb1dc820", "al": "en-us", "c": "US", "tz": "America\/New_York", "gr": "MI", "cy": "Detroit", "mc": 505, "ll": [ 42.331400, -83.045700 ] } 7 | { "h": "1JuLmVe", "g": "1JuLmVf", "l": "vangheem", "hh": "1.usa.gov", "u": "http:\/\/www.fbi.gov\/cincinnati\/press-releases\/2015\/two-men-charged-with-sex-trafficking", "r": "direct", "a": "ShortLinkTranslate", "t": 1422056471, "nk": 1, "hc": 1422045024, "_id": "89ad071e-7e12-abbd-8d2a-5b234357c157", "c": "JP", "tz": "Asia\/Tokyo", "gr": "40", "cy": "Tokyo", "ll": [ 35.685000, 139.751400 ] } 8 | { "h": "1xD0tqa", "g": "1BKkzFq", "l": "bufferapp", "hh": "buff.ly", "u": "http:\/\/www.ncbi.nlm.nih.gov\/pubmed\/23113567?utm_content=buffer786b3&utm_medium=social&utm_source=facebook.com&utm_campaign=buffer", "r": "https:\/\/www.facebook.com\/", "a": "Mozilla\/5.0 (Macintosh; Intel Mac OS X 10_9_5) AppleWebKit\/600.2.5 (KHTML, like Gecko) Version\/7.1.2 Safari\/537.85.11", "t": 1422056471, "nk": 1, "hc": 1421929534, "_id": "1e0fd542-5806-994f-77b9-4bf0c2983d4e", "al": "pt-pt", "c": "PT", "tz": "Europe\/Lisbon", "gr": "19", "cy": "Barreiro", "ll": [ 38.663100, -9.072400 ] } 9 | { "h": "1xRnyEy", "g": "1ouiTdf", "l": "cdcsocialmedia", "hh": "1.usa.gov", "u": "http:\/\/www.cdc.gov\/hai\/pdfs\/patientsafety\/HAI-Patient-Empowerment.pdf", "r": "http:\/\/emergencyed.net", "a": "Mozilla\/5.0 (Windows NT 5.1) AppleWebKit\/534.25 (KHTML, like Gecko) Chrome\/12.0.704.0 Safari\/534.25", "t": 1422056472, "nk": 0, "hc": 1421770566, "_id": "40b4723f-673d-84db-343b-6b59c453e42f", "al": "en-us,en;q=0.5", "c": "US", "tz": "America\/Los_Angeles", "gr": "CA", "cy": "Newport Beach", "mc": 803, "ll": [ 33.627500, -117.873400 ] } 10 | { "h": "1D0FIbC", "g": "1D0FIbD", "l": "ifttt", "hh": "ift.tt", "u": "http:\/\/www.sec.gov\/Archives\/edgar\/data\/1224437\/000120919115006299\/0001209191-15-006299-index.htm", "r": "direct", "a": "WordPress.com; http:\/\/dontime.net", "t": 1422056472, "nk": 0, "hc": 1422056429, "_id": "ff1cad46-2a0b-80c0-3203-3ddecf61a854", "c": "US", "tz": "America\/Los_Angeles", "gr": "CA", "cy": "San Francisco", "mc": 807, "ll": [ 37.748400, -122.415600 ] } 11 | { "h": "189myqw", "g": "189myqx", "l": "ifttt", "hh": "ift.tt", "u": "http:\/\/www.sec.gov\/Archives\/edgar\/data\/1057706\/000120919115006300\/0001209191-15-006300-index.htm", "r": "direct", "a": "WordPress.com; http:\/\/dontime.net", "t": 1422056473, "nk": 0, "hc": 1422056465, "_id": "45cfc753-7fff-94c7-e2ca-93fe25859c27", "c": "US", "tz": "America\/Los_Angeles", "gr": "CA", "cy": "San Francisco", "mc": 807, "ll": [ 37.748400, -122.415600 ] } 12 | { "h": "189muqF", "g": "189muqG", "l": "ifttt", "hh": "ift.tt", "u": "http:\/\/www.sec.gov\/Archives\/edgar\/data\/1114238\/000110465915004247\/0001104659-15-004247-index.htm", "r": "direct", "a": "WordPress.com; http:\/\/dontime.net", "t": 1422056473, "nk": 0, "hc": 1422056436, "_id": "c9133960-d3aa-31a3-0c1f-9eacc1ee1850", "c": "US", "tz": "America\/Los_Angeles", "gr": "CA", "cy": "San Francisco", "mc": 807, "ll": [ 37.748400, -122.415600 ] } 13 | { "h": "1ySsq3h", "g": "1ySsq3i", "l": "dmolyaiecms", "hh": "go.hc.gov", "u": 
"http:\/\/www.cuidadodesalud.gov\/es\/subscribe?utm_medium=social&utm_source=facebook&utm_campaign=affordability&utm_content=como_van_tus_012215&linkId=11891582", "r": "direct", "a": "Mozilla\/5.0 (Linux; Android 4.4.2; en-us; SAMSUNG SM-G900T Build\/KOT49H) AppleWebKit\/537.36 (KHTML, like Gecko) Version\/1.6 Chrome\/28.0.1500.94 Mobile Safari\/537.36", "t": 1422056473, "nk": 1, "hc": 1421946019, "_id": "5b73386d-bbf1-a3d5-2a3d-3e594013d6f2", "al": "en-US,en;q=0.8", "c": "US", "tz": "America\/New_York", "gr": "FL", "cy": "Hollywood", "mc": 528, "ll": [ 26.011200, -80.149500 ] } 14 | -------------------------------------------------------------------------------- /src/test/resources/test-data/simple.csv: -------------------------------------------------------------------------------- 1 | id,nrating,ntitle 2 | 1,5,One Piece 3 | 2,6,Hunter x Hunter 4 | 3,4,Attack on Titan 5 | -------------------------------------------------------------------------------- /src/test/resources/wire-mock-props.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Set useWireMockRule to 'false' if wanting to connect to an actual Fusion instance. A value of 'true' 5 | will use WireMock features instead to fake the connection to Fusion. 6 | 7 | true 8 | http:// 9 | localhost 10 | 8764 11 | agentCollection 12 | agentCollection-default 13 | /api/apollo 14 | /index-pipelines 15 | admin 16 | password123 17 | /api/session?realmName= 18 | native 19 | localhost 20 | 8089 21 | /solr 22 | 23 | -------------------------------------------------------------------------------- /src/test/scala/com/lucidworks/spark/RDDTestSuite.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark 2 | 3 | import java.util.UUID 4 | 5 | import com.lucidworks.spark.rdd.{SelectSolrRDD, StreamingSolrRDD} 6 | import com.lucidworks.spark.util.ConfigurationConstants._ 7 | import com.lucidworks.spark.util.QueryConstants._ 8 | import com.lucidworks.spark.util.SolrCloudUtil 9 | import org.apache.solr.client.solrj.SolrQuery 10 | 11 | class RDDTestSuite extends TestSuiteBuilder with LazyLogging { 12 | 13 | test("Test Simple Query") { 14 | val collectionName = "testSimpleQuery" + UUID.randomUUID().toString 15 | SolrCloudUtil.buildCollection(zkHost, collectionName, 3, 2, cloudClient, sc) 16 | try { 17 | val newRDD = new SelectSolrRDD(zkHost, collectionName, sc) 18 | val docs = newRDD.collect() 19 | assert(newRDD.count() === 3) 20 | } finally { 21 | SolrCloudUtil.deleteCollection(collectionName, cluster) 22 | } 23 | } 24 | 25 | test("Test RDD Partitions") { 26 | val collectionName = "testRDDPartitions" + UUID.randomUUID().toString 27 | SolrCloudUtil.buildCollection(zkHost, collectionName, 2, 4, cloudClient, sc) 28 | try { 29 | val newRDD = new SelectSolrRDD(zkHost, collectionName, sc) 30 | val partitions = newRDD.partitions 31 | assert(partitions.length == 8) 32 | } finally { 33 | SolrCloudUtil.deleteCollection(collectionName, cluster) 34 | } 35 | } 36 | 37 | ignore("Test Simple Query that uses ExportHandler") { 38 | val collectionName = "testSimpleQuery" + UUID.randomUUID().toString 39 | SolrCloudUtil.buildCollection(zkHost, collectionName, 3999, 2, cloudClient, sc) 40 | try { 41 | val newRDD = new StreamingSolrRDD(zkHost, collectionName, sc, rows=Option(Integer.MAX_VALUE)).requestHandler(QT_EXPORT) 42 | val cnt = newRDD.count() 43 | print("\n********************** RDD COUNT IS = " + cnt + "\n\n") 44 | assert(cnt === 3999) 45 | } finally { 46 | 
SolrCloudUtil.deleteCollection(collectionName, cluster) 47 | } 48 | } 49 | 50 | ignore("Test RDD Partitions with an RDD that uses query using ExportHandler") { 51 | val collectionName = "testRDDPartitions" + UUID.randomUUID().toString 52 | SolrCloudUtil.buildCollection(zkHost, collectionName, 1002, 14, cloudClient, sc) 53 | try { 54 | val newRDD = new StreamingSolrRDD(zkHost, collectionName, sc, rows=Option(Integer.MAX_VALUE)).requestHandler(QT_EXPORT) 55 | val partitions = newRDD.partitions 56 | assert(partitions.length === 14) 57 | } finally { 58 | SolrCloudUtil.deleteCollection(collectionName, cluster) 59 | } 60 | } 61 | 62 | test("Test Streaming Expression") { 63 | val collectionName = "testStreamingExpr" + UUID.randomUUID().toString 64 | val numDocs = 10 65 | SolrCloudUtil.buildCollection(zkHost, collectionName, numDocs, 1, cloudClient, sc) 66 | 67 | val expr : String = 68 | s""" 69 | |search(${collectionName}, 70 | | q="*:*", 71 | | fl="field1_s", 72 | | sort="field1_s asc", 73 | | qt="/export") 74 | """.stripMargin 75 | try { 76 | val solrQuery = new SolrQuery() 77 | solrQuery.set(SOLR_STREAMING_EXPR, expr) 78 | val streamExprRDD = new StreamingSolrRDD(zkHost, collectionName, sc, Some(QT_STREAM)) 79 | val results = streamExprRDD.query(solrQuery).collect() 80 | assert(results.size == numDocs) 81 | } finally { 82 | SolrCloudUtil.deleteCollection(collectionName, cluster) 83 | } 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /src/test/scala/com/lucidworks/spark/SparkSolrFunSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.lucidworks.spark 19 | 20 | import org.scalatest.Outcome 21 | import org.scalatest.funsuite.AnyFunSuite 22 | /** 23 | * Base abstract class for all Scala unit tests in spark-solr for handling common functionality. 24 | * 25 | * Copied from SparkFunSuite, which is inaccessible from here. 26 | */ 27 | trait SparkSolrFunSuite extends AnyFunSuite with LazyLogging { 28 | 29 | /** 30 | * Log the suite name and the test name before and after each test. 31 | * 32 | * Subclasses should never override this method. If they wish to run 33 | * custom code before and after each test, they should mix in the 34 | * {{org.scalatest.BeforeAndAfter}} trait instead. 
35 | */ 36 | final protected override def withFixture(test: NoArgTest): Outcome = { 37 | val testName = test.text 38 | val suiteName = this.getClass.getName 39 | val shortSuiteName = suiteName.replaceAll("com.lucidworks.spark", "c.l.s") 40 | try { 41 | logger.info(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n") 42 | test() 43 | } finally { 44 | logger.info(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n") 45 | } 46 | } 47 | } 48 | 49 | -------------------------------------------------------------------------------- /src/test/scala/com/lucidworks/spark/TestChildDocuments.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark 2 | 3 | import java.util.UUID 4 | 5 | import com.lucidworks.spark.util.{SolrCloudUtil, SolrSupport} 6 | import org.apache.spark.sql.DataFrame 7 | import org.apache.spark.sql.SaveMode._ 8 | 9 | class TestChildDocuments extends TestSuiteBuilder { 10 | test("child document should have root as parent document's id") { 11 | val collectionName = "testChildDocuments-" + UUID.randomUUID().toString 12 | SolrCloudUtil.buildCollection(zkHost, collectionName, null, 1, cloudClient, sc) 13 | 14 | try { 15 | val testDf = buildTestDataFrame() 16 | val solrOpts = Map( 17 | "zkhost" -> zkHost, 18 | "collection" -> collectionName, 19 | "gen_uniq_key" -> "true", 20 | "gen_uniq_child_key" -> "true", 21 | "child_doc_fieldname" -> "tags", 22 | "flatten_multivalued" -> "false" // so the multi-valued "dates" column is read back correctly 23 | ) 24 | testDf.write.format("solr").options(solrOpts).mode(Overwrite).save() 25 | 26 | // Explicit commit to make sure all docs are visible 27 | val solrCloudClient = SolrSupport.getCachedCloudClient(zkHost) 28 | solrCloudClient.commit(collectionName, true, true) 29 | 30 | val solrDf = sparkSession.read.format("solr").options(solrOpts).load() 31 | solrDf.show() 32 | 33 | val userA = solrDf.filter(solrDf("user") === "a") 34 | val userB = solrDf.filter(solrDf("user") === "b") 35 | val childrenFromA = solrDf.filter(solrDf("parent") === "a") 36 | val childrenFromB = solrDf.filter(solrDf("parent") === "b") 37 | 38 | assert(userA.count == 1) 39 | assert(userB.count == 1) 40 | assert(childrenFromA.count == 2) 41 | assert(childrenFromB.count == 2) 42 | 43 | val idOfUserA = userA.select("id").rdd.map(r => r(0).asInstanceOf[String]).collect().head 44 | val idOfUserB = userB.select("id").rdd.map(r => r(0).asInstanceOf[String]).collect().head 45 | 46 | val rootsOfChildrenFromA = childrenFromA.select("_root_").rdd.map(r => r(0).asInstanceOf[String]).collect() 47 | val rootsOfChildrenFromB = childrenFromB.select("_root_").rdd.map(r => r(0).asInstanceOf[String]).collect() 48 | rootsOfChildrenFromA.foreach (root => assert(root == idOfUserA)) 49 | rootsOfChildrenFromB.foreach (root => assert(root == idOfUserB)) 50 | } finally { 51 | SolrCloudUtil.deleteCollection(collectionName, cluster) 52 | } 53 | } 54 | 55 | def buildTestDataFrame(): DataFrame = { 56 | val df = sparkSession.read.json("src/test/resources/test-data/child_documents.json") 57 | df.printSchema() 58 | df.show() 59 | assert(df.count == 2) 60 | return df 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/test/scala/com/lucidworks/spark/TestFacetQuerying.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark 2 | 3 | import com.lucidworks.spark.util.SolrQuerySupport 4 | import org.apache.spark.sql.{DataFrame, Row} 5 | 6 | 
class TestFacetQuerying extends MovielensBuilder { 7 | 8 | test("JSON Facet terms") { 9 | val facetQuery = 10 | """ 11 | | { 12 | | aggr_genre: { 13 | | type: terms, 14 | | field: genre, 15 | | limit: 30, 16 | | refine: true 17 | | } 18 | | } 19 | """.stripMargin 20 | val queryString = s"q=*:*&json.facet=${facetQuery}" 21 | val dataFrame: DataFrame = SolrQuerySupport.getDataframeFromFacetQuery(SolrQuerySupport.toQuery(queryString), moviesColName, zkHost, sparkSession) 22 | assert(dataFrame.schema.fieldNames.length === 2) 23 | assert(dataFrame.schema.fieldNames.toSet === Set("aggr_genre", "aggr_genre_count")) 24 | 25 | val rows = dataFrame.collect() 26 | assert(rows(0) == Row("drama", 531)) 27 | assert(rows(rows.length - 1) == Row("fantasy", 1)) 28 | 29 | // dataFrame.printSchema() 30 | // dataFrame.show(20) 31 | } 32 | 33 | 34 | test("JSON facet nested top 2") { 35 | val facetQuery = 36 | """ 37 | | { 38 | | aggr_rating: { 39 | | type: terms, 40 | | field: "rating", 41 | | limit: 30, 42 | | refine: true, 43 | | facet: { 44 | | top_movies: { 45 | | type: terms, 46 | | field: movie_id, 47 | | limit: 2 48 | | } 49 | | } 50 | | } 51 | | } 52 | """.stripMargin 53 | val queryString = s"q=*:*&json.facet=${facetQuery}" 54 | val dataFrame: DataFrame = SolrQuerySupport.getDataframeFromFacetQuery(SolrQuerySupport.toQuery(queryString), ratingsColName, zkHost, sparkSession) 55 | assert(dataFrame.schema.fieldNames.length === 4) 56 | assert(dataFrame.schema.fieldNames.toSet === Set("aggr_rating", "aggr_rating_count", "top_movies", "top_movies_count")) 57 | 58 | val rows = dataFrame.collect() 59 | assert(rows(0) == Row(4, 3383, "9", 23)) 60 | assert(rows(1) == Row(4, 3383, "237", 21)) 61 | assert(rows(2) == Row(3, 2742, "294", 20)) 62 | assert(rows(3) == Row(3, 2742, "405", 19)) 63 | assert(rows(4) == Row(5, 2124, "50", 40)) 64 | assert(rows(5) == Row(5, 2124, "56", 25)) 65 | assert(rows(6) == Row(2, 1149, "118", 12)) 66 | assert(rows(7) == Row(2, 1149, "678", 10)) 67 | assert(rows(8) == Row(1, 602, "21", 5)) 68 | assert(rows(9) == Row(1, 602, "225", 4)) 69 | // dataFrame.printSchema() 70 | // dataFrame.show(20) 71 | } 72 | 73 | test("JSON facet aggrs") { 74 | val facetQuery = 75 | """ 76 | | { 77 | | "avg_rating" : "avg(rating)", 78 | | "num_users" : "unique(user_id)", 79 | | "no_of_movies_rated" : "unique(movie_id)", 80 | | "median_rating" : "percentile(rating, 50)" 81 | | } 82 | """.stripMargin 83 | val queryString = s"q=*:*&json.facet=${facetQuery}" 84 | val dataFrame: DataFrame = SolrQuerySupport.getDataframeFromFacetQuery(SolrQuerySupport.toQuery(queryString), ratingsColName, zkHost, sparkSession) 85 | 86 | assert(dataFrame.schema.fieldNames.toSet === Set("count", "num_users", "avg_rating", "no_of_movies_rated", "median_rating")) 87 | val data = dataFrame.collect() 88 | 89 | assert(data(0) == Row(10000, 922, 4.0, 1238, 3.5278)) 90 | // dataFrame.printSchema() 91 | // dataFrame.show() 92 | } 93 | 94 | 95 | test("Test cores") { 96 | val response = SolrQuerySupport.getSolrCores(cloudClient) 97 | assert(response.responseHeader.status == 0) 98 | assert(response.status.nonEmpty) 99 | } 100 | 101 | } 102 | -------------------------------------------------------------------------------- /src/test/scala/com/lucidworks/spark/TestIndexing.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark 2 | 3 | import java.util.UUID 4 | 5 | import com.lucidworks.spark.util.SolrDataFrameImplicits._ 6 | import 
com.lucidworks.spark.util.{ConfigurationConstants, SolrCloudUtil, SolrQuerySupport, SolrSupport} 7 | import org.apache.spark.sql.functions.{concat, lit} 8 | import org.apache.spark.sql.types.{DataTypes, StructField, StructType} 9 | 10 | class TestIndexing extends TestSuiteBuilder { 11 | 12 | test("Load csv file and index to Solr") { 13 | val collectionName = "testIndexing-" + UUID.randomUUID().toString 14 | SolrCloudUtil.buildCollection(zkHost, collectionName, null, 2, cloudClient, sc) 15 | try { 16 | val csvFileLocation = "src/test/resources/test-data/nyc_yellow_taxi_sample_1k.csv" 17 | val csvDF = sparkSession.read.format("com.databricks.spark.csv") 18 | .option("header", "true") 19 | .option("inferSchema", "true") 20 | .load(csvFileLocation) 21 | assert(csvDF.count() == 999) 22 | 23 | val solrOpts = Map("zkhost" -> zkHost, "collection" -> collectionName) 24 | val newDF = csvDF 25 | .withColumn("pickup_location", concat(csvDF.col("pickup_latitude"), lit(","), csvDF.col("pickup_longitude"))) 26 | .withColumn("dropoff_location", concat(csvDF.col("dropoff_latitude"), lit(","), csvDF.col("dropoff_longitude"))) 27 | newDF.write.option("zkhost", zkHost).option(ConfigurationConstants.GENERATE_UNIQUE_KEY, "true").solr(collectionName) 28 | 29 | // Explicit commit to make sure all docs are visible 30 | val solrCloudClient = SolrSupport.getCachedCloudClient(zkHost) 31 | solrCloudClient.commit(collectionName, true, true) 32 | 33 | val solrDF = sparkSession.read.format("solr").options(solrOpts).load() 34 | solrDF.printSchema() 35 | assert (solrDF.count() == 999) 36 | solrDF.take(10) 37 | } finally { 38 | SolrCloudUtil.deleteCollection(collectionName, cluster) 39 | } 40 | } 41 | 42 | test("Solr field types config") { 43 | val collectionName = "testIndexing-" + UUID.randomUUID().toString 44 | SolrCloudUtil.buildCollection(zkHost, collectionName, null, 2, cloudClient, sc) 45 | try { 46 | val csvFileLocation = "src/test/resources/test-data/simple.csv" 47 | val csvDF = sparkSession.read.format("com.databricks.spark.csv") 48 | .option("header", "true") 49 | .option("inferSchema", "true") 50 | .load(csvFileLocation) 51 | val solrOpts = Map("zkhost" -> zkHost, "collection" -> collectionName, ConfigurationConstants.SOLR_FIELD_TYPES -> "ntitle:text_en,nrating:string") 52 | csvDF.write.options(solrOpts).solr(collectionName) 53 | 54 | // Explicit commit to make sure all docs are visible 55 | val solrCloudClient = SolrSupport.getCachedCloudClient(zkHost) 56 | solrCloudClient.commit(collectionName, true, true) 57 | 58 | val solrBaseUrl = SolrSupport.getSolrBaseUrl(zkHost) 59 | val solrUrl = solrBaseUrl + collectionName + "/" 60 | 61 | val fieldTypes = SolrQuerySupport.getFieldTypes(Set.empty, solrUrl, cloudClient, collectionName) 62 | assert(fieldTypes("nrating").fieldType === "string") 63 | assert(fieldTypes("ntitle").fieldType === "text_en") 64 | } finally { 65 | SolrCloudUtil.deleteCollection(collectionName, cluster) 66 | } 67 | } 68 | 69 | 70 | test("Field additions") { 71 | val insertSchema = StructType(Array( 72 | StructField("index_only_field", DataTypes.StringType, nullable = true), 73 | StructField("store_only_field", DataTypes.BooleanType, nullable = true), 74 | StructField("a_s", DataTypes.StringType, nullable = true), 75 | StructField("s_b", DataTypes.StringType, nullable = true) 76 | )) 77 | val collection = "testFieldAdditions" + UUID.randomUUID().toString.replace("-", "_") 78 | try { 79 | SolrCloudUtil.buildCollection(zkHost, collection, null, 2, cloudClient, sc) 80 | val opts = Map("zkhost" -> 
zkHost, "collection" -> collection) 81 | 82 | val solrRelation = new SolrRelation(opts, sparkSession) 83 | val fieldsToAdd = SolrRelation.getFieldsToAdd(insertSchema, solrRelation.conf, solrRelation.solrVersion, solrRelation.dynamicSuffixes) 84 | assert(fieldsToAdd.isEmpty) 85 | } finally { 86 | SolrCloudUtil.deleteCollection(collection, cluster) 87 | } 88 | } 89 | 90 | } 91 | -------------------------------------------------------------------------------- /src/test/scala/com/lucidworks/spark/TestQuerying.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark 2 | 3 | import java.util.UUID 4 | 5 | import com.lucidworks.spark.util.{SolrCloudUtil, SolrSupport} 6 | import org.apache.spark.sql.SaveMode.Overwrite 7 | import org.apache.spark.sql._ 8 | import org.apache.spark.sql.types._ 9 | 10 | class TestQuerying extends TestSuiteBuilder { 11 | 12 | test("Solr version") { 13 | val solrVersion = SolrSupport.getSolrVersion(zkHost) 14 | assert(solrVersion == "8.11.0") 15 | assert(SolrSupport.isSolrVersionAtleast(solrVersion, 7, 5, 0)) 16 | assert(SolrSupport.isSolrVersionAtleast(solrVersion, 7, 3, 0)) 17 | assert(SolrSupport.isSolrVersionAtleast(solrVersion, 7, 1, 0)) 18 | assert(SolrSupport.isSolrVersionAtleast(solrVersion, 8, 0, 0)) 19 | assert(SolrSupport.isSolrVersionAtleast(solrVersion, 8, 1, 0)) 20 | // TODO: add more here? 21 | assert(!SolrSupport.isSolrVersionAtleast(solrVersion, 9, 0, 0)) 22 | } 23 | 24 | test("vary queried columns") { 25 | val collectionName = "testQuerying-" + UUID.randomUUID().toString 26 | SolrCloudUtil.buildCollection(zkHost, collectionName, null, 1, cloudClient, sc) 27 | try { 28 | val csvDF = buildTestData() 29 | val solrOpts = Map("zkhost" -> zkHost, "collection" -> collectionName) 30 | csvDF.write.format("solr").options(solrOpts).mode(Overwrite).save() 31 | 32 | // Explicit commit to make sure all docs are visible 33 | val solrCloudClient = SolrSupport.getCachedCloudClient(zkHost) 34 | solrCloudClient.commit(collectionName, true, true) 35 | 36 | val solrDF = sparkSession.read.format("solr").options(solrOpts).load() 37 | assert(solrDF.count == 3) 38 | assert(solrDF.schema.fields.length === 5) // _root_ id one_txt two_txt three_s 39 | val oneColFirstRow = solrDF.select("one_txt").head()(0) // query for one column 40 | assert(oneColFirstRow != null) 41 | val firstRow = solrDF.head.toSeq // query for all columns 42 | assert(firstRow.size === 5) 43 | firstRow.foreach(col => assert(col != null)) // no missing values 44 | 45 | } finally { 46 | SolrCloudUtil.deleteCollection(collectionName, cluster) 47 | } 48 | } 49 | 50 | 51 | test("vary queried columns with fields option") { 52 | val collectionName = "testQuerying-" + UUID.randomUUID().toString 53 | SolrCloudUtil.buildCollection(zkHost, collectionName, null, 2, cloudClient, sc) 54 | try { 55 | val csvDF = buildTestData() 56 | val solrOpts = Map("zkhost" -> zkHost, "collection" -> collectionName, "fields" -> "id,one_txt,two_txt") 57 | csvDF.write.format("solr").options(solrOpts).mode(Overwrite).save() 58 | 59 | // Explicit commit to make sure all docs are visible 60 | val solrCloudClient = SolrSupport.getCachedCloudClient(zkHost) 61 | solrCloudClient.commit(collectionName, true, true) 62 | 63 | val solrDF = sparkSession.read.format("solr").options(solrOpts).load() 64 | assert(solrDF.count == 3) 65 | assert(solrDF.schema.fields.length === 3) 66 | 67 | // Query for one column 68 | val oneColFirstRow = solrDF.select("one_txt").head()(0) // query for one 
column 69 | assert(oneColFirstRow != null) 70 | 71 | // Query for all columns 72 | val firstRow = solrDF.head.toSeq 73 | assert(firstRow.size === 3) 74 | firstRow.foreach(col => assert(col != null)) // no missing values 75 | } finally { 76 | SolrCloudUtil.deleteCollection(collectionName, cluster) 77 | } 78 | } 79 | 80 | test("querying multiple collections") { 81 | val collection1Name = "testQuerying-" + UUID.randomUUID().toString 82 | val collection2Name="testQuerying-" + UUID.randomUUID().toString 83 | SolrCloudUtil.buildCollection(zkHost, collection1Name, null, 2, cloudClient, sc) 84 | SolrCloudUtil.buildCollection(zkHost, collection2Name, null, 2, cloudClient, sc) 85 | try { 86 | val csvDF = buildTestData() 87 | val solrOpts_writing1 = Map("zkhost" -> zkHost, "collection" -> collection1Name) 88 | val solrOpts_writing2 = Map("zkhost" -> zkHost, "collection" -> collection2Name) 89 | val solrOpts = Map("zkhost" -> zkHost, "collection" -> s"$collection1Name,$collection2Name") 90 | 91 | 92 | csvDF.write.format("solr").options(solrOpts_writing1).mode(Overwrite).save() 93 | csvDF.write.format("solr").options(solrOpts_writing2).mode(Overwrite).save() 94 | 95 | // Explicit commit to make sure all docs are visible 96 | val solrCloudClient = SolrSupport.getCachedCloudClient(zkHost) 97 | solrCloudClient.commit(collection1Name, true, true) 98 | solrCloudClient.commit(collection2Name, true, true) 99 | 100 | val solrDF = sparkSession.read.format("solr").options(solrOpts).load() 101 | assert(solrDF.count == 6) 102 | assert(solrDF.schema.fields.length === 5) // _root_ id one_txt two_txt three_s 103 | val oneColFirstRow = solrDF.select("one_txt").head()(0) // query for one column 104 | assert(oneColFirstRow != null) 105 | val firstRow = solrDF.head.toSeq // query for all columns 106 | assert(firstRow.size === 5) 107 | firstRow.foreach(col => assert(col != null)) // no missing values 108 | } finally { 109 | SolrCloudUtil.deleteCollection(collection1Name, cluster) 110 | SolrCloudUtil.deleteCollection(collection2Name, cluster) 111 | } 112 | } 113 | 114 | 115 | def buildTestData() : DataFrame = { 116 | val testDataSchema : StructType = StructType( 117 | StructField("id", IntegerType, true) :: 118 | StructField("one_txt", StringType, false) :: 119 | StructField("two_txt", StringType, false) :: 120 | StructField("three_s", StringType, false) :: Nil) 121 | 122 | val rows = Seq( 123 | Row(1, "A", "B", "C"), 124 | Row(2, "C", "D", "E"), 125 | Row(3, "F", "G", "H") 126 | ) 127 | 128 | val csvDF : DataFrame = sparkSession.createDataFrame(sparkSession.sparkContext.makeRDD(rows, 1), testDataSchema) 129 | assert(csvDF.count == 3) 130 | return csvDF 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /src/test/scala/com/lucidworks/spark/TestShardSplits.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark 2 | 3 | import com.lucidworks.spark.util.SolrSupport 4 | import com.lucidworks.spark.util.SolrSupport.WorkerShardSplit 5 | import org.apache.solr.client.solrj.SolrQuery 6 | 7 | // Unit tests for testing splits for select and export handler 8 | class TestShardSplits extends SparkSolrFunSuite{ 9 | val splitFieldName = "_version_" 10 | 11 | test("Shard splits with 1 shard and 2 replicas") { 12 | val query = new SolrQuery("*:*") 13 | val solrShard = SolrShard("shard1", List( 14 | SolrReplica(0, "replica1", "http://replica1", "localhost", Array()), 15 | SolrReplica(1, "replica2", "http://replica2", 
"localhost", Array())) 16 | ) 17 | 18 | val splits: List[WorkerShardSplit] = SolrSupport.getShardSplits(query, solrShard, splitFieldName, 5) 19 | assert(splits.size == 5) 20 | splits.zipWithIndex.foreach { 21 | case (split, i) => 22 | assert(split.query.get("partitionKeys") == splitFieldName) 23 | assert(split.query.getFilterQueries.apply(0) == s"{!hash workers=5 worker=$i}") 24 | if (i%2 == 0) 25 | assert(split.replica.replicaName == "replica1") 26 | else 27 | assert(split.replica.replicaName == "replica2") 28 | } 29 | } 30 | 31 | 32 | test("Shard partitions with 2 shards and 2 replicas") { 33 | val query = new SolrQuery("*:*") 34 | val solrShard1 = SolrShard("shard1", List( 35 | SolrReplica(0, "replica1", "http://replica1", "localhost", Array()), 36 | SolrReplica(1, "replica2", "http://replica2", "localhost", Array())) 37 | ) 38 | val solrShard2 = SolrShard("shard2", List( 39 | SolrReplica(0, "replica1", "http://replica1", "localhost", Array()), 40 | SolrReplica(1, "replica2", "http://replica2", "localhost", Array())) 41 | ) 42 | val solrShards = List(solrShard1, solrShard2) 43 | 44 | val partitions = SolrPartitioner.getSplitPartitions(solrShards, query, splitFieldName, 2) 45 | assert(partitions.length == 4) 46 | partitions.zipWithIndex.foreach { 47 | case (partition, i) => 48 | val spartition = partition.asInstanceOf[SelectSolrRDDPartition] 49 | assert(spartition.cursorMark == "*") 50 | assert(spartition.query.get("partitionKeys") == splitFieldName) 51 | assert(spartition.query.getFilterQueries.apply(0) == s"{!hash workers=2 worker=${i%2}}") 52 | if (i < 2) 53 | assert(spartition.solrShard == solrShard1) 54 | if (i > 2) 55 | assert(spartition.solrShard == solrShard2) 56 | if (i%2 == 0) 57 | assert(spartition.preferredReplica.replicaName == "replica1") 58 | else 59 | assert(spartition.preferredReplica.replicaName == "replica2") 60 | } 61 | } 62 | 63 | test("Export handler splits with 1 shard and 2 replicas") { 64 | val query = new SolrQuery("*:*") 65 | val shard = SolrShard("shard1", List( 66 | SolrReplica(0, "replica1", "http://replica1", "localhost", Array()), 67 | SolrReplica(1, "replica2", "http://replica2", "localhost", Array())) 68 | ) 69 | val splits = SolrSupport.getExportHandlerSplits(query, shard, splitFieldName, 4) 70 | splits.zipWithIndex.foreach{ 71 | case (split, i) => 72 | assert(split.query.equals(query)) 73 | assert(split.numWorkers == 4) 74 | assert(split.workerId == i) 75 | if (i%2 == 0) 76 | assert(split.replica.replicaName == "replica1") 77 | else 78 | assert(split.replica.replicaName == "replica2") 79 | 80 | } 81 | } 82 | 83 | test("Export handler partitions with 2 shards and 2 replicas") { 84 | val query = new SolrQuery("*:*") 85 | val solrShard1 = SolrShard("shard1", List( 86 | SolrReplica(0, "replica1", "http://replica1", "localhost", Array()), 87 | SolrReplica(1, "replica2", "http://replica2", "localhost", Array())) 88 | ) 89 | val solrShard2 = SolrShard("shard2", List( 90 | SolrReplica(0, "replica1", "http://replica1", "localhost", Array()), 91 | SolrReplica(1, "replica2", "http://replica2", "localhost", Array())) 92 | ) 93 | val solrShards = List(solrShard1, solrShard2) 94 | 95 | val partitions = SolrPartitioner.getExportHandlerPartitions(solrShards, query, splitFieldName, 4) 96 | assert(partitions.length == 8) 97 | partitions.zipWithIndex.foreach { 98 | case (partition, i) => 99 | val hpartition = partition.asInstanceOf[ExportHandlerPartition] 100 | assert(hpartition.index == i) 101 | assert(hpartition.numWorkers == 4) 102 | if (i < 4) 103 | 
assert(hpartition.solrShard == solrShard1) 104 | if (i >= 4) 105 | assert(hpartition.solrShard == solrShard2) 106 | if (i%2 == 0) 107 | assert(hpartition.preferredReplica.replicaName == "replica1") 108 | else 109 | assert(hpartition.preferredReplica.replicaName == "replica2") 110 | } 111 | } 112 | 113 | } 114 | -------------------------------------------------------------------------------- /src/test/scala/com/lucidworks/spark/TestSolrRelation.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark 2 | 3 | import com.lucidworks.spark.util.SolrRelationUtil 4 | import org.apache.commons.lang3.StringEscapeUtils 5 | import org.apache.solr.client.solrj.SolrQuery 6 | import org.apache.spark.sql.sources.{And, EqualTo, Or} 7 | import org.apache.spark.sql.types._ 8 | 9 | class TestSolrRelation extends SparkSolrFunSuite with SparkSolrContextBuilder { 10 | 11 | test("Streaming expr schema") { 12 | val arrayJson = StringEscapeUtils.escapeJson(ArrayType(StringType).json) 13 | val longJson = StringEscapeUtils.escapeJson(LongType.json) 14 | val exprSchema = s"""arrayField:"${arrayJson}",longField:long""" 15 | val structFields = SolrRelation.parseSchemaExprSchemaToStructFields(exprSchema) 16 | assert(structFields.size == 2) 17 | logger.info(s"Parsed fields: ${structFields}") 18 | val arrayStructField = structFields.head 19 | val longStructField = structFields.last 20 | assert(arrayStructField.name === "arrayField") 21 | assert(arrayStructField.dataType === ArrayType(StringType)) 22 | assert(arrayStructField.metadata.getBoolean("multiValued")) 23 | assert(longStructField.name === "longField") 24 | assert(longStructField.dataType === LongType) 25 | } 26 | 27 | test("empty solr relation") { 28 | intercept[IllegalArgumentException] { 29 | new SolrRelation(Map.empty, None, sparkSession) 30 | } 31 | } 32 | 33 | test("Missing collection property") { 34 | intercept[IllegalArgumentException] { 35 | new SolrRelation(Map("zkhost" -> "localhost:121"), None, sparkSession).collection 36 | } 37 | } 38 | 39 | test("relation object creation") { 40 | val options = Map("zkhost" -> "dummy:9983", "collection" -> "test") 41 | val relation = new SolrRelation(options, None, sparkSession) 42 | assert(relation != null) 43 | } 44 | 45 | test("Scala filter expressions") { 46 | val filterExpr = Or(And(EqualTo("gender", "F"), EqualTo("artist", "Bernadette Peters")),And(EqualTo("gender", "M"), EqualTo("artist", "Girl Talk"))) 47 | val solrQuery = new SolrQuery 48 | val schema = StructType(Seq(StructField("gender", DataTypes.StringType), StructField("artist", DataTypes.StringType))) 49 | SolrRelationUtil.applyFilter(filterExpr, solrQuery, schema) 50 | val fq = solrQuery.getFilterQueries 51 | assert(fq.length == 1) 52 | assert(fq(0) === """((gender:"F" AND artist:"Bernadette Peters") OR (gender:"M" AND artist:"Girl Talk"))""") 53 | } 54 | 55 | test("custom field types option") { 56 | val fieldTypeOption = "a:b,c,d:e" 57 | val fieldTypes = SolrRelation.parseUserSuppliedFieldTypes(fieldTypeOption) 58 | assert(fieldTypes.size === 2) 59 | assert(fieldTypes.keySet === Set("a", "d")) 60 | assert(fieldTypes("a") === "b") 61 | assert(fieldTypes("d") === "e") 62 | } 63 | 64 | test("test commas in filter values") { 65 | val fieldValues = """a:"c,d e",f:g,h:"1, 35, 2"""" 66 | val parsedFilters = SolrRelationUtil.parseCommaSeparatedValuesToList(fieldValues) 67 | assert(parsedFilters.head === """a:"c,d e"""") 68 | assert(parsedFilters(1) === "f:g") 69 | assert(parsedFilters(2) === 
"""h:"1, 35, 2"""") 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/test/scala/com/lucidworks/spark/TestSolrStreamWriter.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark 2 | 3 | import java.io.File 4 | import java.util.UUID 5 | 6 | import com.lucidworks.spark.util.{ConfigurationConstants, SolrCloudUtil, SolrQuerySupport, SolrSupport} 7 | import org.apache.commons.io.FileUtils 8 | import org.apache.spark.solr.SparkInternalObjects 9 | 10 | class TestSolrStreamWriter extends TestSuiteBuilder { 11 | 12 | test("Stream data into Solr") { 13 | val collectionName = "testStreaming-" + UUID.randomUUID().toString 14 | SolrCloudUtil.buildCollection(zkHost, collectionName, null, 1, cloudClient, sc) 15 | sparkSession.conf.set("spark.sql.streaming.schemaInference", "true") 16 | sparkSession.sparkContext.setLogLevel("DEBUG") 17 | val offsetsDir = FileUtils.getTempDirectory + "/spark-stream-offsets-" + UUID.randomUUID().toString 18 | try { 19 | val datasetPath = "src/test/resources/test-data/oneusagov" 20 | val streamingJsonDF = sparkSession.readStream.json(datasetPath) 21 | val accName = "acc-" + UUID.randomUUID().toString 22 | assert(streamingJsonDF.isStreaming) 23 | val writeOptions = Map( 24 | "collection" -> collectionName, 25 | "zkhost" -> zkHost, 26 | "checkpointLocation" -> offsetsDir, 27 | ConfigurationConstants.GENERATE_UNIQUE_KEY -> "true", 28 | ConfigurationConstants.ACCUMULATOR_NAME -> accName) 29 | val streamingQuery = streamingJsonDF 30 | .drop("_id") 31 | .writeStream 32 | .outputMode("append") 33 | .format("solr") 34 | .options(writeOptions) 35 | .start() 36 | try { 37 | logger.info(s"Explain ${streamingQuery.explain()}") 38 | streamingQuery.processAllAvailable() 39 | logger.info(s"Status ${streamingQuery.status}") 40 | SolrSupport.getCachedCloudClient(zkHost).commit(collectionName) 41 | assert(SolrQuerySupport.getNumDocsFromSolr(collectionName, zkHost, None) === 13) 42 | val acc = SparkInternalObjects.getAccumulatorById(SparkSolrAccumulatorContext.getId(accName).get) 43 | assert(acc.isDefined) 44 | assert(acc.get.value == 13) 45 | } finally { 46 | streamingQuery.stop() 47 | } 48 | } finally { 49 | SolrCloudUtil.deleteCollection(collectionName, cluster) 50 | FileUtils.deleteDirectory(new File(offsetsDir)) 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/test/scala/com/lucidworks/spark/examples/TwitterTestSuite.scala: -------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.examples 2 | 3 | import com.lucidworks.spark.SparkSolrFunSuite 4 | import com.lucidworks.spark.util.SolrSupport 5 | import org.apache.solr.common.SolrInputDocument 6 | import twitter4j.TwitterObjectFactory 7 | 8 | class TwitterTestSuite extends SparkSolrFunSuite { 9 | 10 | test("Test twitter field object mapping") { 11 | val tweetJSON = "{ \"scopes\":{ \"place_ids\":[ \"place one\",\"place two\"]}, \"created_at\":\"Tue Mar 05 23:57:32 +0000 2013\", \"id\":309090333021581313, \"id_str\":\"309090333021581313\", \"text\":\"As announced, @anywhere has been retired per https:\\/\\/t.co\\/bWXjhurvwp The js file now logs a message to the console and exits quietly. 
^ARK\", \"source\":\"web\", \"truncated\":false, \"in_reply_to_status_id\":null, \"in_reply_to_status_id_str\":null, \"in_reply_to_user_id\":null, \"in_reply_to_user_id_str\":null, \"in_reply_to_screen_name\":null, \"user\":{ \"id\":6253282, \"id_str\":\"6253282\", \"name\":\"Twitter API\", \"screen_name\":\"twitterapi\", \"location\":\"San Francisco, CA\", \"description\":\"The Real Twitter API. I tweet about API changes, service issues and happily answer questions about Twitter and our API. Don't get an answer? It's on my website.\", \"url\":\"http:\\/\\/dev.twitter.com\", \"entities\":{ \"url\":{ \"urls\":[ { \"url\":\"http:\\/\\/dev.twitter.com\", \"expanded_url\":null, \"indices\":[ 0, 22 ] } ] }, \"description\":{ \"urls\":[ ] } }, \"protected\":false, \"followers_count\":1533137, \"friends_count\":33, \"listed_count\":11369, \"created_at\":\"Wed May 23 06:01:13 +0000 2007\", \"favourites_count\":25, \"utc_offset\":-28800, \"time_zone\":\"Pacific Time (US & Canada)\", \"geo_enabled\":true, \"verified\":true, \"statuses_count\":3392, \"lang\":\"en\", \"contributors_enabled\":true, \"is_translator\":false, \"profile_background_color\":\"C0DEED\", \"profile_background_image_url\":\"http:\\/\\/a0.twimg.com\\/profile_background_images\\/656927849\\/miyt9dpjz77sc0w3d4vj.png\", \"profile_background_image_url_https\":\"https:\\/\\/si0.twimg.com\\/profile_background_images\\/656927849\\/miyt9dpjz77sc0w3d4vj.png\", \"profile_background_tile\":true, \"profile_image_url\":\"http:\\/\\/a0.twimg.com\\/profile_images\\/2284174872\\/7df3h38zabcvjylnyfe3_normal.png\", \"profile_image_url_https\":\"https:\\/\\/si0.twimg.com\\/profile_images\\/2284174872\\/7df3h38zabcvjylnyfe3_normal.png\", \"profile_banner_url\":\"https:\\/\\/si0.twimg.com\\/profile_banners\\/6253282\\/1347394302\", \"profile_link_color\":\"0084B4\", \"profile_sidebar_border_color\":\"C0DEED\", \"profile_sidebar_fill_color\":\"DDEEF6\", \"profile_text_color\":\"333333\", \"profile_use_background_image\":true, \"default_profile\":false, \"default_profile_image\":false, \"following\":null, \"follow_request_sent\":false, \"notifications\":null }, \"geo\":null, \"coordinates\":null, \"place\":null, \"contributors\":[ 7588892 ], \"retweet_count\":74, \"entities\":{ \"hashtags\":[ ], \"urls\":[ { \"url\":\"https:\\/\\/t.co\\/bWXjhurvwp\", \"expanded_url\":\"https:\\/\\/dev.twitter.com\\/blog\\/sunsetting-anywhere\", \"display_url\":\"dev.twitter.com\\/blog\\/sunsettin…\", \"indices\":[ 45, 68 ] } ], \"user_mentions\":[ { \"screen_name\":\"anywhere\", \"name\":\"Anywhere\", \"id\":9576402, \"id_str\":\"9576402\", \"indices\":[ 14, 23 ] } ] }, \"favorited\":false, \"retweeted\":false, \"possibly_sensitive\":false, \"lang\":\"en\" }" 12 | val tweetStatusObj = TwitterObjectFactory.createStatus(tweetJSON) 13 | // simple mapping from primitives to dynamic Solr fields using reflection 14 | val doc: SolrInputDocument = SolrSupport.autoMapToSolrInputDoc("tweet-" + tweetStatusObj.getId, tweetStatusObj, null) 15 | logger.info("Mapped to Document: " + doc.toString) 16 | 17 | assert(doc.containsKey("createdAt_tdt")) 18 | assert(doc.containsKey("lang_s")) 19 | assert(doc.containsKey("favoriteCount_i")) 20 | assert(doc.containsKey("source_s")) 21 | assert(doc.containsKey("retweeted_b")) 22 | assert(doc.containsKey("retweetCount_i")) 23 | assert(doc.containsKey("inReplyToStatusId_l")) 24 | } 25 | 26 | } 27 | -------------------------------------------------------------------------------- /src/test/scala/com/lucidworks/spark/ml/SparkMLExamples.scala: 
-------------------------------------------------------------------------------- 1 | package com.lucidworks.spark.ml 2 | 3 | import com.lucidworks.spark.{SparkSolrContextBuilder, SparkSolrFunSuite} 4 | import org.apache.spark.ml.{Pipeline, PipelineModel} 5 | import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} 6 | import org.apache.spark.ml.feature.{HashingTF, Tokenizer} 7 | import org.apache.spark.ml.linalg.{Vector, Vectors} 8 | import org.apache.spark.ml.param.ParamMap 9 | import org.apache.spark.sql.Row 10 | 11 | class SparkMLExamples extends SparkSolrFunSuite with SparkSolrContextBuilder { 12 | 13 | test("test ML example with estimators, transformers and params") { 14 | 15 | // Prepare training data from a list of (label, features) tuples. 16 | val training = sparkSession.createDataFrame(Seq( 17 | (1.0, Vectors.dense(0.0, 1.1, 0.1)), 18 | (0.0, Vectors.dense(2.0, 1.0, -1.0)), 19 | (0.0, Vectors.dense(2.0, 1.3, 1.0)), 20 | (1.0, Vectors.dense(0.0, 1.2, -0.5)) 21 | )).toDF("label", "features") 22 | 23 | // Create a LogisticRegression instance. This instance is an Estimator. 24 | val lr = new LogisticRegression() 25 | // Print out the parameters, documentation, and any default values. 26 | println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") 27 | 28 | // We may set parameters using setter methods. 29 | lr.setMaxIter(10) 30 | .setRegParam(0.01) 31 | 32 | // Learn a LogisticRegression model. This uses the parameters stored in lr. 33 | val model1: LogisticRegressionModel = lr.fit(training) 34 | 35 | // Since model1 is a Model (i.e., a Transformer produced by an Estimator), 36 | // we can view the parameters it used during fit(). 37 | // This prints the parameter (name: value) pairs, where names are unique IDs for this 38 | // LogisticRegression instance. 39 | println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) 40 | 41 | // We may alternatively specify parameters using a ParamMap, 42 | // which supports several methods for specifying parameters. 43 | val paramMap = ParamMap(lr.maxIter -> 20) 44 | .put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. 45 | .put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. 46 | 47 | // One can also combine ParamMaps. 48 | val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name. 49 | val paramMapCombined = paramMap ++ paramMap2 50 | 51 | // Now learn a new model using the paramMapCombined parameters. 52 | // paramMapCombined overrides all parameters set earlier via lr.set* methods. 53 | val model2 = lr.fit(training, paramMapCombined) 54 | println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) 55 | 56 | // Prepare test data. 57 | val test = sparkSession.createDataFrame(Seq( 58 | (1.0, Vectors.dense(-1.0, 1.5, 1.3)), 59 | (0.0, Vectors.dense(3.0, 2.0, -0.1)), 60 | (1.0, Vectors.dense(0.0, 2.2, -1.5)) 61 | )).toDF("label", "features") 62 | 63 | // Make predictions on test data using the Transformer.transform() method. 64 | // LogisticRegression.transform will only use the 'features' column. 65 | // Note that model2.transform() outputs a 'myProbability' column instead of the usual 66 | // 'probability' column since we renamed the lr.probabilityCol parameter previously. 
67 | model2.transform(test) 68 | .select("features", "label", "myProbability", "prediction") 69 | .collect() 70 | .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => 71 | println(s"($features, $label) -> prob=$prob, prediction=$prediction") 72 | } 73 | } 74 | 75 | test("Pipeline in Spark ML world") { 76 | // Prepare training documents from a list of (id, text, label) tuples. 77 | val training = sparkSession.createDataFrame(Seq( 78 | (0L, "a b c d e spark", 1.0), 79 | (1L, "b d", 0.0), 80 | (2L, "spark f g h", 1.0), 81 | (3L, "hadoop mapreduce", 0.0) 82 | )).toDF("id", "text", "label") 83 | 84 | // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. 85 | val tokenizer = new Tokenizer() 86 | .setInputCol("text") 87 | .setOutputCol("words") 88 | val hashingTF = new HashingTF() 89 | .setNumFeatures(1000) 90 | .setInputCol(tokenizer.getOutputCol) 91 | .setOutputCol("features") 92 | val lr = new LogisticRegression() 93 | .setMaxIter(10) 94 | .setRegParam(0.001) 95 | val pipeline = new Pipeline() 96 | .setStages(Array(tokenizer, hashingTF, lr)) 97 | 98 | // Fit the pipeline to training documents. 99 | val model = pipeline.fit(training) 100 | 101 | // Now we can optionally save the fitted pipeline to disk 102 | model.write.overwrite().save("/tmp/spark-logistic-regression-model") 103 | 104 | // We can also save this unfit pipeline to disk 105 | pipeline.write.overwrite().save("/tmp/unfit-lr-model") 106 | 107 | // And load it back in during production 108 | val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model") 109 | 110 | // Prepare test documents, which are unlabeled (id, text) tuples. 111 | val test = sparkSession.createDataFrame(Seq( 112 | (4L, "spark i j k"), 113 | (5L, "l m n"), 114 | (6L, "spark hadoop spark"), 115 | (7L, "apache hadoop") 116 | )).toDF("id", "text") 117 | 118 | // Make predictions on test documents. 119 | model.transform(test) 120 | .select("id", "text", "probability", "prediction") 121 | .collect() 122 | .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => 123 | println(s"($id, $text) --> prob=$prob, prediction=$prediction") 124 | } 125 | 126 | } 127 | } 128 | --------------------------------------------------------------------------------
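The test suites above all drive the connector through the same handful of DataFrame options: "zkhost" and "collection" to locate the data, "gen_uniq_key" to let the connector mint document ids, "fields" to restrict the columns read back, and an explicit commit before querying so freshly written documents are visible. The sketch below distills that round trip into a standalone program; it is only an illustration under assumptions, not part of the repository: the object name SolrRoundTripSketch, the ZooKeeper address, and the collection name are placeholders, the collection is assumed to already exist, and the shaded spark-solr jar is assumed to be on the classpath.

import com.lucidworks.spark.util.SolrSupport
import org.apache.spark.sql.{SaveMode, SparkSession}

object SolrRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("spark-solr round trip").getOrCreate()

    val zkHost = "localhost:9983"     // placeholder ZooKeeper connect string
    val collection = "my-collection"  // placeholder collection; assumed to already exist

    // Read a small CSV with a header row, as the indexing tests do.
    val csvDF = spark.read
      .option("header", "true")
      .option("inferSchema", "true")
      .csv("src/test/resources/test-data/simple.csv")

    // Write to Solr, letting the connector generate unique keys.
    csvDF.write.format("solr")
      .option("zkhost", zkHost)
      .option("collection", collection)
      .option("gen_uniq_key", "true")
      .mode(SaveMode.Overwrite)
      .save()

    // Explicit commit so the documents are visible to the read below (same pattern as the tests).
    SolrSupport.getCachedCloudClient(zkHost).commit(collection, true, true)

    // Read the collection back, restricting the columns with the "fields" option.
    val solrDF = spark.read.format("solr")
      .option("zkhost", zkHost)
      .option("collection", collection)
      .option("fields", "id,nrating,ntitle")
      .load()
    solrDF.show()

    spark.stop()
  }
}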