├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── build.sbt ├── data └── ml-1m.7z ├── project ├── Dependencies.scala ├── Settings.scala └── build.properties └── src ├── main ├── resources │ └── log4j.properties └── scala │ └── com │ └── lendap │ └── spark │ └── lsh │ ├── Hasher.scala │ ├── LSH.scala │ ├── LSHModel.scala │ └── Main.scala └── test └── scala └── com └── lendap └── spark └── lsh ├── LSHTestSuit.scala └── LocalSparkContext.scala /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Scala template 3 | *.class 4 | *.log 5 | 6 | # sbt specific 7 | .cache 8 | .history 9 | .lib/ 10 | data/*.data 11 | dist/* 12 | target/ 13 | lib_managed/ 14 | src_managed/ 15 | project/boot/ 16 | project/plugins/project/ 17 | 18 | # Scala-IDE specific 19 | .scala_dependencies 20 | .worksheet 21 | .idea/* 22 | 23 | 24 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | scala: 3 | - 2.10.4 4 | - 2.11.7 5 | - 2.11.8 6 | - 2.11.10 7 | - 2.11.11 8 | jdk: 9 | - openjdk7 10 | - openjdk8 11 | sudo: false 12 | script: 13 | - sbt -jvm-opts travis/jvmopts.compile compile 14 | - sbt ++$TRAVIS_SCALA_VERSION test -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Locality Sensitive Hashing for Apache Spark # 2 | 3 | [![Build Status](https://travis-ci.org/marufaytekin/lsh-spark.svg?branch=master)](https://travis-ci.org/marufaytekin/lsh-spark) 4 | 5 | Locality-sensitive hashing (LSH) is an approximate nearest neighbor search and 6 | clustering method for high dimensional data points (http://www.mit.edu/~andoni/LSH/). 
7 | Locality-sensitive functions
8 | take two data points and decide whether or not they should be a candidate
9 | pair. LSH hashes input data points multiple times in such a way that similar data
10 | points map to the same "buckets" with higher probability than dissimilar data
11 | points. Data points that map to the same bucket are considered a candidate pair.
12 | 
13 | There are different LSH schemes for different distance measures. This implementation
14 | is based on Charikar's LSH scheme for cosine distance described in the
15 | [Similarity Estimation Techniques from Rounding Algorithms](http://www.cs.princeton.edu/courses/archive/spr04/cos598B/bib/CharikarEstim.pdf)
16 | paper. This scheme uses random hyperplane based hash functions on collections of
17 | vectors to produce hash values. The model building (preprocessing) and query answering
18 | algorithms are implemented as described in Figures 1 and 2 of
19 | http://www.vldb.org/conf/1999/P49.pdf.
20 | 
21 | The implementation is inspired by the spark-hash project on GitHub.
22 | 
23 | ## Build ##
24 | 
25 | This is an SBT project.
26 | ```
27 | sbt clean compile
28 | ```
29 | 
30 | ## Usage ##
31 | 
32 | "Main.scala", provided in this package, contains sample code showing how to use the LSH package.
33 | In the following sections we demonstrate how to use LSH to group similar users
34 | based on the items they rate. We use the well-known MovieLens data set for demonstration. The
35 | data set contains user/item ratings in (user::item::rating::time) format. A
36 | zipped version of the data set is provided in the "data" directory of this project.
37 | 
38 | We would like to group similar users using LSH. As a preprocessing step,
39 | we read the data set and create an RDD of data in (user, item, rating)
40 | format:
41 | 
42 | ```scala
43 | //read data file in as an RDD, partition RDD across cores
44 | val data = sc.textFile(dataFile, numPartitions)
45 | 
46 | //parse data and create (user, item, rating) tuples
47 | val ratingsRDD = data
48 |   .map(line => line.split("::"))
49 |   .map(elems => (elems(0).toInt, elems(1).toInt, elems(2).toDouble))
50 | ```
51 | 
52 | 
53 | We need to represent each user as a vector of ratings to be able to calculate
54 | cosine similarity between users. In order to convert a user's ratings to a SparseVector,
55 | we need to determine the largest possible vector index in the data set. We set it to
56 | the maximum item id plus one, since item ids serve as SparseVector indices. The maximum
57 | index value is also used for generating the random vectors in the hashers.
58 | 
59 | ```scala
60 | //list of distinct items
61 | val items = ratingsRDD.map(x => x._2).distinct()
62 | val maxIndex = items.max + 1
63 | ```
64 | 
65 | Now we are ready to convert the user data to an RDD of (user_id, SparseVector) tuples.
66 | A user's SparseVector is created from the list of (item, rating) pairs, used as (index,
67 | value) pairs.
68 | 
69 | 
70 | ```scala
71 | //user ratings grouped by user_id
72 | val userItemRatings = ratingsRDD.map(x => (x._1, (x._2, x._3))).groupByKey().cache()
73 | 
74 | //convert each user's ratings to a tuple of (user_id, SparseVector_of_ratings)
75 | val sparseVectorData = userItemRatings
76 |   .map(a => (a._1.toLong, Vectors.sparse(maxIndex, a._2.toSeq).asInstanceOf[SparseVector]))
77 | ```
78 | 
79 | Finally, we use sparseVectorData to build the LSH model as follows:
80 | 
81 | 
82 | ```scala
83 | //run locality sensitive hashing with 6 hash tables and 8 hash functions
84 | val lsh = new LSH(sparseVectorData, maxIndex, numHashFunc = 8, numHashTables = 6)
85 | val model = lsh.run()
86 | ```
87 | 
88 | The number of hash functions (rows) per hash table and the number of hash tables
89 | need to be given to LSH. See the implementation details section below for guidance on
90 | selecting the number of hash tables and hash functions.
91 | 
92 | ```scala
93 | //print sample hashed vectors in ((hashTableId#, hashValue), user_id) format
94 | model.hashTables.take(10) foreach println
95 | ```
96 | 
97 | Ten sample entries from the model are printed out as follows:
98 | 
99 | ```
100 | ((1,10100000),4289)
101 | ((5,01001100),649)
102 | ((3,10011011),5849)
103 | ((0,11000110),5221)
104 | ((1,01010100),3688)
105 | ((1,00001110),354)
106 | ((0,11000110),5118)
107 | ((3,00001011),3698)
108 | ((3,11010011),2941)
109 | ((2,11010010),4488)
110 | ```
111 | 
112 | ### Find Similar Users for User ID ###
113 | Find similar users for user id 4587 as follows:
114 | 
115 | ```scala
116 | //get the near neighbors of userId: 4587 in the model
117 | val candList = model.getCandidates(4587)
118 | println("Number of Candidates: " + candList.count())
119 | println("Candidate List: " + candList.collect().toList)
120 | ```
121 | 
122 | 172 candidate neighbors are found for user 4587:
123 | 
124 | ```
125 | Number of Candidates: 172
126 | Candidate List: List(1708, 5297, 1973, 4691, 2864, 903, 30, 501, 2433, 3317, 2268, 4759, 1593, 2617, 3794, 2958, 5918, 3743, 1527, 5030, 1271, 4713, 4095, 2615, 1948, 597, 818, 1084, 5592, 3334, 2342, 3740, 2647, 3476, 2115, 2676, 1385, 2606, 1809, 584, 2341, 5063, 320, 1162, 4899, 5343, 5998, 1423, 1374, 2121, 1846, 3985, 529, 5654, 810, 1028, 5727, 1549, 3126, 2376, 3258, 5573, 5291, 1752, 4727, 187, 1159, 2114, 1028, 4747, 4852, 2390, 3404, 900, 5016, 3576, 5855, 1959, 2964, 2171, 5940, 2521, 171, 5375, 2125, 3357, 2217, 1227, 5949, 2722, 4943, 1575, 1319, 1529, 618, 370, 1280, 5164, 5340, 1166, 4332, 1845, 4158, 5724, 1938, 4953, 2128, 492, 595, 3852, 2915, 4789, 159, 124, 989, 4702, 4259, 2733, 2623, 5431, 1398, 4172, 629, 86, 2726, 5690, 563, 5977, 3538, 2476, 1855, 2904, 3168, 769, 4429, 1470, 1829, 1461, 5335, 5125, 922, 5772, 5109, 643, 131, 4421, 5259, 1960, 738, 383, 5906, 1989, 1902, 469, 500, 15, 939, 1292, 53, 5437, 3721, 3143, 5393, 1789, 1465, 2519, 3001, 4016, 5967, 3203, 3295, 5208)
127 | ```
128 | 
129 | ### Find Similar Users for Vectors ###
130 | 
131 | We can also find similar users given a user's rating data on movies. We first
132 | convert this data to a SparseVector as follows:
133 | ```scala
134 | val movies = List(1,6,17,29,32,36,76,137,154,161,172,173,185,223,232,235,260,272,296,300,314,316,318,327,337,338,348)
135 | val ratings = List(5.0,4.0,4.0,5.0,5.0,4.0,5.0,3.0,4.0,4.0,4.0,4.0,4.0,5.0,5.0,4.0,5.0,5.0,4.0,4.0,4.0,5.0,5.0,5.0,4.0,4.0,4.0)
136 | val sampleVector = Vectors.sparse(maxIndex, movies zip ratings).asInstanceOf[SparseVector]
137 | ```
138 | Then query the LSH model for the candidate user list for sampleVector:
139 | ```scala
140 | val candidateList = model.getCandidates(sampleVector)
141 | println(candidateList.collect().toList)
142 | ```
143 | 
144 | The following user list is returned as the candidate list:
145 | ```
146 | List(3925, 4607, 3292, 2919, 240, 4182, 5244, 1452, 4526, 3831, 305, 4341, 2939, 2731, 627, 5685, 1656, 3597, 3268, 2908, 1675, 5124, 4588, 5112, 4620, 890, 3655, 5642, 4737, 372, 5916, 3806, 6037, 5384, 1888, 4059, 996, 660, 889, 5020, 2871, 2107, 5080, 1638, 588, 4486, 2945, 335, 2013, 363, 1257, 117, 2848, 417, 1101, 2171, 4526, 147, 411, 3709, 3941, 904, 4442, 1576, 1177, 3844, 5527, 5280, 2998, 287, 3575, 4461, 1548, 5698, 2039, 5283, 5454, 1288, 741, 1496, 11, 3829, 4201, 985, 3862, 2908, 3658, 3594, 5970, 1115, 5690, 5082, 5707, 6030, 555, 4260, 780, 6028, 1353, 5433, 1593, 3933, 5328, 3649, 2700, 3117, 215, 4944, 4266, 3388, 5079, 1483, 1762, 2654)
147 | ```
148 | 
149 | ### Find Similarity of Vectors ###
150 | 
151 | Let *a* and *b* be the sparse vectors of two users. We can compute the similarity of these
152 | users based on cosine similarity as follows:
153 | ```scala
154 | val similarity = lsh.cosine(a, b)
155 | ```
156 | 
157 | ### Hash Values for Vectors ###
158 | 
159 | We can retrieve the hash values for a vector as follows:
160 | ```scala
161 | val hashValues = model.hashValue(sampleVector)
162 | println(hashValues)
163 | ```
164 | 
165 | The generated list of hash values for each hash table, in (hashTable#, hashValue) format:
166 | 
167 | ```
168 | List((0,10101100), (5,01110100), (1,01001110), (2,10000000), (3,10101111), (4,00101100))
169 | ```
170 | 
171 | ### Hash Values per Hash Table ###
172 | 
173 | We can retrieve the list of hash values in each hash table as follows:
174 | 
175 | ```scala
176 | val bucketHashValues = model.hashTables.map(x => x._1).groupByKey()
177 | ```
178 | 
179 | This returns an RDD[(Int, Iterable[String])] of hash values keyed by hash table id.
180 | 
181 | 
182 | ### Add New User ###
183 | 
184 | We add a new user with a ratings vector as follows:
185 | ```scala
186 | val updatedModel = model.add(id, v, sc)
187 | ```
188 | where id, v, and sc are the user id, the SparseVector of ratings, and the SparkContext, respectively.
189 | 
190 | ### Remove an Existing User ###
191 | 
192 | We remove an existing user from the model as follows:
193 | ```scala
194 | val updatedModel = model.remove(id, sc)
195 | ```
196 | where id is the user id and sc is the SparkContext.
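
### Rank Candidates by Similarity ###

The candidate lookup and the cosine similarity shown above can be combined to rank the most similar users for a query vector. The snippet below is a minimal sketch along the lines of Main.scala; the top-10 cutoff and the sortBy step are illustrative choices rather than part of the library API:

```scala
//collect the ids of candidate users for the query vector
val candidateIds = model.getCandidates(sampleVector).collect()

//keep only the candidates' rating vectors and score each one against the query vector
val candidateVectors = sparseVectorData.filter(x => candidateIds.contains(x._1)).cache()
val scored = candidateVectors.map(x => (x._1, lsh.cosine(x._2, sampleVector)))

//take the 10 most similar users, highest cosine similarity first
val topSimilarUsers = scored.sortBy(_._2, ascending = false).take(10)
topSimilarUsers foreach println
```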
197 | 
198 | ### Save/Load The Model ###
199 | 
200 | The trained model can be saved to and loaded from HDFS as follows:
201 | 
202 | ```scala
203 | //save model
204 | val temp = "target/" + System.currentTimeMillis().toString
205 | model.save(sc, temp)
206 | 
207 | //load model
208 | val modelLoaded = LSHModel.load(sc, temp)
209 | 
210 | //print out 10 entries from loaded model
211 | modelLoaded.hashTables.take(10) foreach println
212 | ```
213 | Ten sample entries from the loaded model are printed out as follows:
214 | 
215 | ```
216 | ((1,11101110),4289)
217 | ((5,11100001),649)
218 | ((3,11001111),5849)
219 | ((0,10100101),5221)
220 | ((1,01110001),3688)
221 | ((1,11110010),354)
222 | ((0,10010100),5118)
223 | ((3,10011010),3698)
224 | ((3,10100010),2941)
225 | ((2,11010101),4488)
226 | ```
227 | 
228 | ## Implementation Details ##
229 | 
230 | - LSH hashes each vector multiple times (b * r hashes in total), where *b* is the number
231 | of hash tables (bands) and *r* is the number of hash functions (rows) in each hash table.
232 | 
233 | - If we define *t* as the similarity threshold above which a pair of vectors should be treated
234 | as a desired “similar pair,” the threshold of the banding scheme is approximately (1/b)^(1/r). Select *b* and *r*
235 | to produce a threshold lower than *t* to avoid false negatives; select *b* and *r* to
236 | produce a higher threshold to increase speed and decrease false positives (see section
237 | 3.4.3 of [Mining of Massive Datasets](http://mmds.org) for details). For example, with *b* = 6 and *r* = 8 the threshold is approximately (1/6)^(1/8) ≈ 0.80.
238 | 
239 | - The hash function is defined in the com.lendap.spark.lsh.Hasher class and uses random hyperplane
240 | based hash functions which operate on vectors.
241 | 
242 | - Hash functions use randomly generated vectors whose elements are in the [-1, 1]
243 | interval. It is sufficiently random to select vectors whose components
244 | are +1 and -1 (see section 3.7.3 of [Mining of Massive Datasets](http://mmds.org)).
245 | 
246 | 
247 | - The hashing function calculates the dot product of an input vector with a randomly generated
248 | hash vector, then produces a hash value (0 or 1) based on the sign of the dot product.
249 | Each hasher produces one hash value for the vector. All hash values are then combined
250 | with *AND-construction* to produce a hash signature (e.g. 11110010) for the input vector.
251 | 
252 | - Hash signatures of the input vectors are used as bucket ids as described in
253 | http://www.vldb.org/conf/1999/P49.pdf. The model building (preprocessing) and
254 | query answering algorithms are implemented as described in Figures 1 and 2 of this paper.
255 | 
256 | - Hashed vectors are stored in model.hashTables as an *RDD[((Int, String), Long)]* where each entry
257 | is a *((hashTable#, hash_value), vector_id)* tuple.
258 | 
259 | - The results can be filtered by passing a filter function to the model.
260 | 
261 | - The trained model can be saved to HDFS with the *model.save* function.
262 | 
263 | - A saved model can be loaded from HDFS with the *model.load* function.
264 | 
265 | 
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
1 | 
2 | lazy val lsh = (project in file(".")).
3 | settings(Settings.settings: _*).
4 | settings(Settings.lshSettings: _*).
5 | settings(libraryDependencies ++=Dependencies.lshDependencies ) -------------------------------------------------------------------------------- /data/ml-1m.7z: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marufaytekin/lsh-spark/5fd7f155c137dbf5e9a9c34d83570dd636d7d496/data/ml-1m.7z -------------------------------------------------------------------------------- /project/Dependencies.scala: -------------------------------------------------------------------------------- 1 | import sbt._ 2 | 3 | object Dependencies { 4 | 5 | lazy val version = new { 6 | val scalaTest = "2.2.2" 7 | val spark = "1.4.1" 8 | } 9 | 10 | lazy val library = new { 11 | val sparkCore ="org.apache.spark" %% "spark-core" % version.spark 12 | val sparkMLib ="org.apache.spark" %% "spark-mllib" % version.spark 13 | val test = "org.scalatest" %% "scalatest" % version.scalaTest % Test 14 | } 15 | 16 | val lshDependencies: Seq[ModuleID] = Seq( 17 | library.sparkCore, 18 | library.sparkMLib, 19 | library.test 20 | ) 21 | 22 | } 23 | -------------------------------------------------------------------------------- /project/Settings.scala: -------------------------------------------------------------------------------- 1 | import sbt._ 2 | import Keys._ 3 | 4 | object Settings { 5 | lazy val settings = Seq( 6 | organization := "com.lendap", 7 | version := "0.1." + sys.props.getOrElse("buildNumber", default="0-SNAPSHOT"), 8 | scalaVersion := "2.10.4", 9 | publishMavenStyle := true, 10 | publishArtifact in Test := false 11 | ) 12 | 13 | lazy val testSettings = Seq( 14 | fork in Test := false, 15 | parallelExecution in Test := false 16 | ) 17 | 18 | lazy val lshSettings = Seq( 19 | name := "lsh-scala" 20 | ) 21 | } 22 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=0.13.15 -------------------------------------------------------------------------------- /src/main/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # Set everything to be logged to the console 2 | log4j.rootCategory=ERROR, console 3 | log4j.appender.console=org.apache.log4j.ConsoleAppender 4 | log4j.appender.console.target=System.err 5 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 6 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 7 | 8 | # Settings to quiet third party logs that are too verbose 9 | log4j.logger.org.eclipse.jetty=WARN 10 | log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=WARN 11 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO 12 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO 13 | -------------------------------------------------------------------------------- /src/main/scala/com/lendap/spark/lsh/Hasher.scala: -------------------------------------------------------------------------------- 1 | package com.lendap.spark.lsh 2 | 3 | import org.apache.spark.mllib.linalg.SparseVector 4 | import scala.collection.mutable.ArrayBuffer 5 | import scala.util.Random 6 | 7 | 8 | /** 9 | * Simple hashing function implements random hyperplane based hash functions described in 10 | * http://www.cs.princeton.edu/courses/archive/spring04/cos598B/bib/CharikarEstim.pdf 11 | * r is a random vector. 
Hash function h_r(u) operates as follows: 12 | * if r.u < 0 //dot product of two vectors 13 | * h_r(u) = 0 14 | * else 15 | * h_r(u) = 1 16 | */ 17 | class Hasher(val r: Array[Double]) extends Serializable { 18 | 19 | /** hash SparseVector v with random vector r */ 20 | def hash(u : SparseVector) : Int = { 21 | val rVec: Array[Double] = u.indices.map(i => r(i)) 22 | val hashVal = (rVec zip u.values).map(_tuple => _tuple._1 * _tuple._2).sum 23 | if (hashVal > 0) 1 else 0 24 | } 25 | 26 | } 27 | 28 | object Hasher { 29 | 30 | /** create a new instance providing size of the random vector Array [Double] */ 31 | def apply (size: Int, seed: Long = System.nanoTime) = new Hasher(r(size, seed)) 32 | 33 | /** create a random vector whose whose components are -1 and +1 */ 34 | def r(size: Int, seed: Long) : Array[Double] = { 35 | val buf = new ArrayBuffer[Double] 36 | val rnd = new Random(seed) 37 | for (_ <- 0 until size) 38 | buf += (if (rnd.nextGaussian() < 0) -1 else 1) 39 | buf.toArray 40 | } 41 | 42 | } 43 | -------------------------------------------------------------------------------- /src/main/scala/com/lendap/spark/lsh/LSH.scala: -------------------------------------------------------------------------------- 1 | package com.lendap.spark.lsh 2 | 3 | /** 4 | * Created by maruf on 09/08/15. 5 | */ 6 | 7 | import org.apache.spark.mllib.linalg.SparseVector 8 | import org.apache.spark.rdd.RDD 9 | 10 | /** Build LSH model with data RDD. Hash each vector number of hashTable times and stores in a bucket. 11 | * 12 | * 13 | * @param data RDD of sparse vectors with vector Ids. RDD(vec_id, SparseVector) 14 | * @param m max number of possible elements in a vector 15 | * @param numHashFunc number of hash functions 16 | * @param numHashTables number of hashTables. 17 | * 18 | * */ 19 | class LSH(data : RDD[(Long, SparseVector)] = null, m: Int = 0, numHashFunc : Int = 4, numHashTables : Int = 4) extends Serializable { 20 | 21 | def run() : LSHModel = { 22 | 23 | //create a new model object 24 | val model = new LSHModel(m, numHashFunc, numHashTables) 25 | 26 | val dataRDD = data.cache() 27 | 28 | //compute hash keys for each vector 29 | // - hash each vector numHashFunc times 30 | // - concat each hash value to create a hash key 31 | // - position hashTable id hash keys and vector id into a new RDD. 32 | // - creates RDD of ((hashTable#, hash_key), vec_id) tuples. 33 | model.hashTables = dataRDD 34 | .map(v => (model.hashFunctions.map(h => (h._1.hash(v._2), h._2 % numHashTables)), v._1)) 35 | .map(x => x._1.map(a => ((a._2, x._2), a._1))) 36 | .flatMap(a => a).groupByKey() 37 | .map(x => ((x._1._1, x._2.mkString("")), x._1._2)).cache() 38 | 39 | model 40 | 41 | } 42 | 43 | def cosine(a: SparseVector, b: SparseVector): Double = { 44 | val intersection = a.indices.intersect(b.indices) 45 | val magnitudeA = intersection.map(x => Math.pow(a.apply(x), 2)).sum 46 | val magnitudeB = intersection.map(x => Math.pow(b.apply(x), 2)).sum 47 | intersection.map(x => a.apply(x) * b.apply(x)).sum / (Math.sqrt(magnitudeA) * Math.sqrt(magnitudeB)) 48 | } 49 | 50 | } 51 | -------------------------------------------------------------------------------- /src/main/scala/com/lendap/spark/lsh/LSHModel.scala: -------------------------------------------------------------------------------- 1 | package com.lendap.spark.lsh 2 | 3 | /** 4 | * Created by maytekin on 06.08.2015. 
5 | */
6 | 
7 | import org.apache.hadoop.fs.Path
8 | import org.apache.spark.SparkContext
9 | import org.apache.spark.mllib.linalg.SparseVector
10 | import org.apache.spark.rdd.RDD
11 | import scala.collection.mutable.ListBuffer
12 | import org.apache.spark.mllib.util.Saveable
13 | 
14 | import org.json4s._
15 | import org.json4s.JsonDSL._
16 | import org.json4s.jackson.JsonMethods._
17 | 
18 | 
19 | /** Create an LSH model for vectors with at most m elements.
20 |  *
21 |  * @param m max number of possible elements in a vector
22 |  * @param numHashFunc number of hash functions
23 |  * @param numHashTables number of hashTables.
24 |  *
25 |  * */
26 | class LSHModel(val m: Int, val numHashFunc : Int, val numHashTables: Int)
27 |   extends Serializable with Saveable {
28 | 
29 |   /** generate numHashFunc * numHashTables random hash functions and store them in hashFunctions */
30 |   private val _hashFunctions = ListBuffer[Hasher]()
31 |   for (_ <- 0 until numHashFunc * numHashTables)
32 |     _hashFunctions += Hasher(m)
33 |   final var hashFunctions: List[(Hasher, Int)] = _hashFunctions.toList.zipWithIndex
34 | 
35 |   /** the hash tables, stored as ((hashTableID, hash key), vector_id) entries */
36 |   var hashTables: RDD[((Int, String), Long)] = _
37 | 
38 |   /** generic filter function for hashTables. */
39 |   def filter(f: (((Int, String), Long)) => Boolean): RDD[((Int, String), Long)] =
40 |     hashTables.filter(f)
41 | 
42 |   /** hash a single vector against the model and return ids of vectors sharing a bucket with it */
43 |   def filter(data: SparseVector, model: LSHModel, itemID: Long): RDD[Long] = {
44 |     val hashVal = hashValue(data)
45 |     hashTables.filter(x => hashVal contains x._1).map(x => x._2).filter(x => x != itemID)
46 |   }
47 | 
48 |   /** creates hashValue for each hashTable.*/
49 |   def hashValue(data: SparseVector): List[(Int, String)] =
50 |     hashFunctions.map(a => (a._2 % numHashTables, a._1.hash(data)))
51 |       .groupBy(_._1)
52 |       .map(x => (x._1, x._2.map(_._2).mkString(""))).toList
53 | 
54 |   /** returns candidate set for given vector id.*/
55 |   def getCandidates(vId: Long): RDD[Long] = {
56 |     val buckets = hashTables.filter(x => x._2 == vId).map(x => x._1).distinct().collect()
57 |     hashTables.filter(x => buckets contains x._1).map(x => x._2).filter(x => x != vId)
58 |   }
59 | 
60 |   /** returns candidate set for given vector.*/
61 |   def getCandidates(v: SparseVector): RDD[Long] = {
62 |     val hashVal = hashValue(v)
63 |     hashTables.filter(x => hashVal contains x._1).map(x => x._2)
64 |   }
65 | 
66 |   /** adds a new sparse vector with vector Id: vId to the model. */
67 |   def add (vId: Long, v: SparseVector, sc: SparkContext): LSHModel = {
68 |     val newRDD = sc.parallelize(hashValue(v).map(a => (a, vId)))
69 |     //union the new entries into the hash tables so the addition actually takes effect
70 |     hashTables = hashTables ++ newRDD
71 |     this
72 |   }
73 |   /** remove sparse vector with vector Id: vId from the model.
*/ 74 | def remove (vId: Long, sc: SparkContext): LSHModel = { 75 | hashTables = hashTables.filter(x => x._2 != vId) 76 | this 77 | } 78 | 79 | override def save(sc: SparkContext, path: String): Unit = 80 | LSHModel.SaveLoadV0_0_1.save(sc, this, path) 81 | 82 | override protected def formatVersion: String = "0.0.1" 83 | 84 | } 85 | 86 | object LSHModel { 87 | 88 | def load(sc: SparkContext, path: String): LSHModel = { 89 | LSHModel.SaveLoadV0_0_1.load(sc, path) 90 | } 91 | 92 | private [lsh] object SaveLoadV0_0_1 { 93 | 94 | private val thisFormatVersion = "0.0.1" 95 | private val thisClassName = this.getClass.getName 96 | 97 | def save(sc: SparkContext, model: LSHModel, path: String): Unit = { 98 | 99 | val metadata = 100 | compact(render(("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) 101 | 102 | //save metadata info 103 | sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) 104 | 105 | //save hash functions as (hashTableId, randomVector) 106 | sc.parallelize(model.hashFunctions 107 | .map(x => (x._2, x._1.r.mkString(","))) 108 | .map(_.productIterator.mkString(","))) 109 | .saveAsTextFile(Loader.hasherPath(path)) 110 | 111 | //save data as (hashTableId#, hashValue, vectorId) 112 | model.hashTables 113 | .map(x => (x._1._1, x._1._2, x._2)) 114 | .map(_.productIterator.mkString(",")) 115 | .saveAsTextFile(Loader.dataPath(path)) 116 | 117 | } 118 | 119 | def load(sc: SparkContext, path: String): LSHModel = { 120 | 121 | implicit val formats: DefaultFormats.type = DefaultFormats 122 | val (className, formatVersion, _) = Loader.loadMetadata(sc, path) 123 | assert(className == thisClassName) 124 | assert(formatVersion == thisFormatVersion) 125 | val hashTables = sc.textFile(Loader.dataPath(path)) 126 | .map(x => x.split(",")) 127 | .map(x => ((x(0).toInt, x(1)), x(2).toLong)) 128 | val hashers = sc.textFile(Loader.hasherPath(path)) 129 | .map(a => a.split(",")) 130 | .map(x => (x.head, x.tail)) 131 | .map(x => (new Hasher(x._2.map(_.toDouble)), x._1.toInt)).collect().toList 132 | val numBands = hashTables.map(x => x._1._1).distinct.count() 133 | val numHashFunc = hashers.size / numBands 134 | 135 | //Validate loaded data 136 | //check size of data 137 | assert(hashTables.count != 0, s"Loaded hashTable data is empty") 138 | //check size of hash functions 139 | assert(hashers.nonEmpty, s"Loaded hasher data is empty") 140 | //check hashValue size. Should be equal to numHashFunc 141 | assert(hashTables.map(x => x._1._2).filter(x => x.length != numHashFunc).collect().length == 0, 142 | s"hashValues in data does not match with hash functions") 143 | 144 | //create model 145 | val model = new LSHModel(0, numHashFunc.toInt, numBands.toInt) 146 | model.hashFunctions = hashers 147 | model.hashTables = hashTables 148 | 149 | model 150 | } 151 | } 152 | } 153 | 154 | 155 | /** Helper functions for save/load data from mllib package. 156 | * TODO: Remove and use Loader functions from mllib. */ 157 | private[lsh] object Loader { 158 | 159 | /** Returns URI for path/data using the Hadoop filesystem */ 160 | def dataPath(path: String): String = new Path(path, "data").toUri.toString 161 | 162 | /** Returns URI for path/metadata using the Hadoop filesystem */ 163 | def metadataPath(path: String): String = new Path(path, "metadata").toUri.toString 164 | 165 | /** Returns URI for path/metadata using the Hadoop filesystem */ 166 | def hasherPath(path: String): String = new Path(path, "hasher").toUri.toString 167 | 168 | /** 169 | * Load metadata from the given path. 
170 | * @return (class name, version, metadata) 171 | */ 172 | def loadMetadata(sc: SparkContext, path: String): (String, String, JValue) = { 173 | implicit val formats: DefaultFormats.type = DefaultFormats 174 | val metadata = parse(sc.textFile(metadataPath(path)).first()) 175 | val clazz = (metadata \ "class").extract[String] 176 | val version = (metadata \ "version").extract[String] 177 | (clazz, version, metadata) 178 | } 179 | 180 | } -------------------------------------------------------------------------------- /src/main/scala/com/lendap/spark/lsh/Main.scala: -------------------------------------------------------------------------------- 1 | package com.lendap.spark.lsh 2 | 3 | import org.apache.spark.SparkContext 4 | import org.apache.spark.SparkConf 5 | 6 | import org.apache.spark.mllib.linalg.{Vectors, SparseVector} 7 | 8 | 9 | /** 10 | * Created by maytekin on 05.08.2015. 11 | */ 12 | object Main { 13 | 14 | /** Sample usage of LSH on movie rating data.*/ 15 | def main(args: Array[String]) { 16 | 17 | //init spark context 18 | val numPartitions = 8 19 | val dataFile = "data/ml-1m.data" 20 | val conf = new SparkConf() 21 | .setAppName("LSH") 22 | .setMaster("local[4]") 23 | val sc = new SparkContext(conf) 24 | 25 | //read data file in as a RDD, partition RDD across cores 26 | val data = sc.textFile(dataFile, numPartitions) 27 | 28 | //parse data and create (user, item, rating) tuples 29 | val ratingsRDD = data 30 | .map(line => line.split("::")) 31 | .map(elems => (elems(0).toInt, elems(1).toInt, elems(2).toDouble)) 32 | 33 | //list of distinct items 34 | val items = ratingsRDD.map(x => x._2).distinct() 35 | val maxIndex = items.max + 1 36 | 37 | //user ratings grouped by user_id 38 | val userItemRatings = ratingsRDD.map(x => (x._1, (x._2, x._3))).groupByKey().cache() 39 | 40 | //convert each user's rating to tuple of (user_id, SparseVector_of_ratings) 41 | val sparseVectorData = userItemRatings 42 | .map(a=>(a._1.toLong, Vectors.sparse(maxIndex, a._2.toSeq).asInstanceOf[SparseVector])) 43 | 44 | //run locality sensitive hashing model with 6 hashTables and 8 hash functions 45 | val lsh = new LSH(sparseVectorData, maxIndex, numHashFunc = 8, numHashTables = 6) 46 | val model = lsh.run() 47 | 48 | //print sample hashed vectors in ((hashTableId#, hashValue), vectorId) format 49 | model.hashTables.take(10) foreach println 50 | 51 | //get the near neighbors of userId: 4587 in the model 52 | val candList = model.getCandidates(4587) 53 | println("Number of Candidate Neighbors: ") 54 | println(candList.count()) 55 | println("Candidate List: " + candList.collect().toList) 56 | 57 | //save model 58 | val temp = "target/" + System.currentTimeMillis().toString 59 | model.save(sc, temp) 60 | 61 | //load model 62 | val modelLoaded = LSHModel.load(sc, temp) 63 | 64 | //print out 10 entries from loaded model 65 | modelLoaded.hashTables.take(15) foreach println 66 | 67 | //create a user vector with ratings on movies 68 | val movies = List(1,6,17,29,32,36,76,137,154,161,172,173,185,223,232,235,260,272,296,300,314,316,318,327,337,338,348) 69 | val ratings = List(5.0,4.0,4.0,5.0,5.0,4.0,5.0,3.0,4.0,4.0,4.0,4.0,4.0,5.0,5.0,4.0,5.0,5.0,4.0,4.0,4.0,5.0,5.0,5.0,4.0,4.0,4.0) 70 | val sampleVector = Vectors.sparse(maxIndex, movies zip ratings).asInstanceOf[SparseVector] 71 | println(sampleVector) 72 | 73 | //generate hash values for each bucket 74 | val hashValues = model.hashValue(sampleVector) 75 | println(hashValues) 76 | 77 | //query LSH model for candidate set 78 | val candidateList = 
model.getCandidates(sampleVector).collect() 79 | println(candidateList.toList) 80 | 81 | //compute similarity of sampleVector with users in candidate set 82 | val candidateVectors = sparseVectorData.filter(x => candidateList.contains(x._1)).cache() 83 | val similarities = candidateVectors.map(x => (x._1, lsh.cosine(x._2, sampleVector))) 84 | similarities foreach println 85 | 86 | } 87 | 88 | } 89 | -------------------------------------------------------------------------------- /src/test/scala/com/lendap/spark/lsh/LSHTestSuit.scala: -------------------------------------------------------------------------------- 1 | package com.lendap.spark.lsh 2 | 3 | import org.apache.spark.mllib.linalg.{SparseVector, Vectors} 4 | import org.scalatest.FunSuite 5 | 6 | /** 7 | * Created by maruf on 09/08/15. 8 | */ 9 | class LSHTestSuit extends FunSuite with LocalSparkContext { 10 | 11 | val simpleDataRDD = List( 12 | List(5.0,3.0,4.0,5.0,5.0,1.0,5.0,3.0,4.0,5.0).zipWithIndex.map(a=>a.swap), 13 | List(1.0,2.0,1.0,5.0,1.0,5.0,1.0,4.0,1.0,3.0).zipWithIndex.map(a=>a.swap), 14 | List(5.0,3.0,4.0,1.0,5.0,4.0,1.0,3.0,4.0,5.0).zipWithIndex.map(a=>a.swap), 15 | List(1.0,3.0,4.0,5.0,5.0,1.0,1.0,3.0,4.0,5.0).zipWithIndex.map(a=>a.swap)) 16 | 17 | test("hasher") { 18 | 19 | val h = Hasher(10, 12345678) 20 | val rdd = sc.parallelize(simpleDataRDD) 21 | 22 | //make sure we have 4 23 | assert(rdd.count() == 4) 24 | 25 | //convert data to RDD of SparseVector 26 | val vectorRDD = rdd.map(a => Vectors.sparse(a.size, a).asInstanceOf[SparseVector]) 27 | 28 | //make sure we still have 4 29 | assert(vectorRDD.count() == 4) 30 | 31 | val hashKey = vectorRDD.map(a => h.hash(a)).collect().mkString("") 32 | 33 | //check if calculated hash key correct 34 | assert(hashKey == "1010") 35 | 36 | } 37 | 38 | test ("lsh") { 39 | 40 | val numBands = 5 41 | val numHashFunc = 4 42 | val m = 50 //number of elements in each vector 43 | val n = 30 //number of data points (vectors) 44 | val rnd = new scala.util.Random 45 | 46 | //generate n random vectors whose elements range 1-5 47 | val dataRDD = List.range(1, n) 48 | .map(a => (a, List.fill(m)(1 + rnd.nextInt(5).toDouble).zipWithIndex.map(x => x.swap))) 49 | val vectorsRDD = sc.parallelize(dataRDD).map(a => (a._1.toLong, Vectors.sparse(a._2.size, a._2).asInstanceOf[SparseVector])) 50 | 51 | val lsh = new LSH(vectorsRDD, m, numHashFunc, numBands) 52 | val model = lsh.run() 53 | 54 | //make sure numBands hashTables created 55 | assert (model.hashTables.map(a => a._1._1).collect().distinct.length == numBands) 56 | 57 | //make sure each key size matches with number of hash functions 58 | assert (model.hashTables.filter(a => a._1._2.length != numHashFunc).count == 0) 59 | 60 | //make sure there is no empty bucket 61 | assert (model.hashTables 62 | .map(a => (a._1._2, a._2)) 63 | .groupByKey().filter(x => x._2.isEmpty) 64 | .count == 0) 65 | 66 | //make sure vectors are not clustered in one bucket 67 | assert (model.hashTables 68 | .map(a => (a._1._1, a._1._2)) 69 | .groupByKey().filter(x => x._2.size == n) 70 | .count == 0) 71 | 72 | //make sure number of buckets for each hashTables is in expected range (2 - 2^numHashFunc) 73 | assert (model.hashTables 74 | .map(a => (a._1._1, a._1._2)) 75 | .groupByKey() 76 | .map(a => (a._1, a._2.toList.distinct)) 77 | .filter(a => a._2.size < 1 || a._2.size > math.pow(2, numHashFunc)) 78 | .count == 0) 79 | 80 | //test save/load operations 81 | val temp = "target/test/" + System.currentTimeMillis().toString 82 | model.save(sc, temp) 83 | val model2 = 
LSHModel.load(sc, temp) 84 | 85 | //make sure size of saved and loaded models are the same 86 | assert(model.hashTables.count == model2.hashTables.count) 87 | assert(model.hashFunctions.lengthCompare(model2.hashFunctions.size) == 0) 88 | 89 | //make sure loaded model produce the same hashValue with the original 90 | val testRDD = vectorsRDD.take(10) 91 | testRDD.foreach(x => assert(model.hashValue(x._2) == model2.hashValue(x._2))) 92 | 93 | //test cosine similarity 94 | val rdd = sc.parallelize(simpleDataRDD) 95 | 96 | //convert data to RDD of SparseVector 97 | val vectorRDD = rdd.map(a => Vectors.sparse(a.size, a).asInstanceOf[SparseVector]) 98 | val a = vectorRDD.take(4)(0) 99 | val b = vectorRDD.take(4)(3) 100 | assert(lsh.cosine(a, b) === 0.9061030445113443) 101 | 102 | 103 | } 104 | 105 | } 106 | -------------------------------------------------------------------------------- /src/test/scala/com/lendap/spark/lsh/LocalSparkContext.scala: -------------------------------------------------------------------------------- 1 | package com.lendap.spark.lsh 2 | 3 | /** 4 | * Created by maruf on 09/08/15. 5 | */ 6 | import org.scalatest.Suite 7 | import org.scalatest.BeforeAndAfterAll 8 | 9 | import org.apache.spark.{SparkConf, SparkContext} 10 | 11 | trait LocalSparkContext extends BeforeAndAfterAll { self: Suite => 12 | @transient var sc: SparkContext = _ 13 | 14 | override def beforeAll() { 15 | val conf = new SparkConf() 16 | .setMaster("local") 17 | .setAppName("test") 18 | sc = new SparkContext(conf) 19 | super.beforeAll() 20 | } 21 | 22 | override def afterAll() { 23 | if (sc != null) { 24 | sc.stop() 25 | } 26 | super.afterAll() 27 | } 28 | } 29 | --------------------------------------------------------------------------------