├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── build.sbt ├── build ├── sbt ├── sbt-launch-0.13.6.jar └── sbt-launch-lib.bash ├── project ├── build.properties └── plugins.sbt └── src ├── main └── scala │ └── com │ └── brkyvz │ └── spark │ └── linalg │ ├── BLASUtils.scala │ ├── MatrixLike.scala │ ├── VectorLike.scala │ ├── funcs.scala │ └── package.scala └── test └── scala └── com └── brkyvz └── spark ├── linalg ├── BLASUtilsSuite.scala ├── MatricesSuite.scala └── VectorsSuite.scala └── util └── TestingUtils.scala /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | *.log 3 | *.pyc 4 | sbt/*.jar 5 | .idea/ 6 | 7 | # sbt specific 8 | .cache/ 9 | .history/ 10 | .lib/ 11 | dist/* 12 | target/ 13 | lib_managed/ 14 | src_managed/ 15 | project/boot/ 16 | project/plugins/project/ 17 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | scala: 3 | - 2.10.4 4 | jdk: 5 | - openjdk7 6 | sudo: false 7 | before_install: 8 | - pip install --user codecov 9 | script: 10 | - sbt -jvm-opts travis/jvmopts.compile compile 11 | - sbt -jvm-opts travis/jvmopts.test test 12 | after_success: 13 | - codecov -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License, Version 2.0 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 
13 | 14 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 15 | 16 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 17 | 18 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 19 | 20 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 21 | 22 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 23 | 24 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
25 | 26 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 27 | 28 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 29 | 30 | 2. Grant of Copyright License. 31 | 32 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 33 | 34 | 3. Grant of Patent License. 
35 | 36 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 37 | 38 | 4. Redistribution. 39 | 40 | You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 41 | 42 | You must give any other recipients of the Work or Derivative Works a copy of this License; and 43 | You must cause any modified files to carry prominent notices stating that You changed the files; and 44 | You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and 45 | If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE 
text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. 46 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 47 | 48 | 5. Submission of Contributions. 49 | 50 | Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 51 | 52 | 6. Trademarks. 53 | 54 | This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 55 | 56 | 7. Disclaimer of Warranty. 
57 | 58 | Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 59 | 60 | 8. Limitation of Liability. 61 | 62 | In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 63 | 64 | 9. Accepting Warranty or Additional Liability. 65 | 66 | While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 
67 | 68 | END OF TERMS AND CONDITIONS 69 | 70 | APPENDIX: How to apply the Apache License to your work 71 | 72 | To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 73 | 74 | Copyright [yyyy] [name of copyright owner] 75 | 76 | Licensed under the Apache License, Version 2.0 (the "License"); 77 | you may not use this file except in compliance with the License. 78 | You may obtain a copy of the License at 79 | 80 | http://www.apache.org/licenses/LICENSE-2.0 81 | 82 | Unless required by applicable law or agreed to in writing, software 83 | distributed under the License is distributed on an "AS IS" BASIS, 84 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 85 | See the License for the specific language governing permissions and 86 | limitations under the License. 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | lazy-linalg 2 | ----------- 3 | 4 | Linear algebra operators for Apache Spark MLlib's linalg package. Works best with Spark 1.5. 5 | 6 | Goal 7 | ==== 8 | 9 | Most of the code in this repository was written as a part of 10 | [SPARK-6442](https://issues.apache.org/jira/browse/SPARK-6442). The aim was to support the most 11 | common local linear algebra operations on top of Spark without having to depend on an external 12 | library. 
13 | 14 | It is somewhat cumbersome to write code where you have to convert the MLlib representation of a 15 | vector or matrix to Breeze to perform the simplest arithmetic operations like addition, subtraction, etc. 16 | This package aims to lift that burden, and provide efficient implementations for some of these methods. 17 | 18 | By keeping operations lazy, this package provides some of the optimizations that you would see 19 | in C++ libraries like Armadillo, Eigen, etc. 20 | 21 | Installation 22 | ============ 23 | 24 | Include this package in your Spark Applications using: 25 | 26 | ### spark-shell, pyspark, or spark-submit 27 | 28 | ``` 29 | > $SPARK_HOME/bin/spark-shell --packages brkyvz:lazy-linalg:0.1.0 30 | ``` 31 | 32 | ### sbt 33 | 34 | If you use the [sbt-spark-package plugin](https://github.com/databricks/sbt-spark-package), 35 | in your sbt build file, add: 36 | 37 | ```scala 38 | spDependencies += "brkyvz/lazy-linalg:0.1.0" 39 | ``` 40 | 41 | Otherwise, 42 | 43 | ```scala 44 | resolvers += "Spark Packages Repo" at "http://dl.bintray.com/spark-packages/maven" 45 | 46 | libraryDependencies += "brkyvz" % "lazy-linalg" % "0.1.0" 47 | ``` 48 | 49 | ### Maven 50 | 51 | In your pom.xml, add: 52 | ```xml 53 | 54 | 55 | 56 | brkyvz 57 | lazy-linalg 58 | 0.1.0 59 | 60 | 61 | 62 | 63 | 64 | SparkPackagesRepo 65 | http://dl.bintray.com/spark-packages/maven 66 | 67 | 68 | ``` 69 | 70 | Examples 71 | ======== 72 | 73 | Import `com.brkyvz.spark.linalg._` and all the implicits will kick in for Scala users. 
74 | 75 | ```scala 76 | scala> import com.brkyvz.spark.linalg._ 77 | scala> import org.apache.spark.mllib.linalg._ 78 | 79 | scala> val rnd = new java.util.Random 80 | 81 | scala> val A = DenseMatrix.eye(3) 82 | A: org.apache.spark.mllib.linalg.DenseMatrix = 83 | 1.0 0.0 0.0 84 | 0.0 1.0 0.0 85 | 0.0 0.0 1.0 86 | 87 | scala> val B = DenseMatrix.rand(3, 3, rnd) 88 | B: org.apache.spark.mllib.linalg.DenseMatrix = 89 | 0.6133402813080373 0.7162729054788076 0.15011768207263143 90 | 0.3078993912354502 0.23923486751376188 0.05973497171994935 91 | 0.49892408305838276 0.9534484503645188 0.48047741591983717 92 | 93 | scala> val C = DenseMatrix.zeros(3, 3) 94 | C: org.apache.spark.mllib.linalg.DenseMatrix = 95 | 0.0 0.0 0.0 96 | 0.0 0.0 0.0 97 | 0.0 0.0 0.0 98 | 99 | scala> C := A + B + A * B 100 | res0: com.brkyvz.spark.linalg.MatrixLike = 101 | 2.2266805626160746 1.4325458109576152 0.30023536414526286 102 | 0.6157987824709004 1.4784697350275238 0.1194699434398987 103 | 0.9978481661167655 1.9068969007290375 1.9609548318396746 104 | 105 | scala> C += -1 106 | res1: com.brkyvz.spark.linalg.DenseMatrixWrapper = 107 | 1.2266805626160746 0.43254581095761524 -0.6997646358547371 108 | -0.38420121752909964 0.47846973502752377 -0.8805300565601013 109 | -0.0021518338832344774 0.9068969007290375 0.9609548318396746 110 | 111 | scala> val D = A * 2 - 1 112 | scala> D.compute() 113 | res2: com.brkyvz.spark.linalg.MatrixLike = 114 | 1.0 -1.0 -1.0 115 | -1.0 1.0 -1.0 116 | -1.0 -1.0 1.0 117 | ``` 118 | 119 | Matrix multiplication vs. 
element-wise multiplication: 120 | 121 | ```scala 122 | scala> C := A * B 123 | res4: com.brkyvz.spark.linalg.MatrixLike = 124 | 0.6133402813080373 0.7162729054788076 0.15011768207263143 125 | 0.3078993912354502 0.23923486751376188 0.05973497171994935 126 | 0.49892408305838276 0.9534484503645188 0.48047741591983717 127 | 128 | scala> C := A :* B 129 | res5: com.brkyvz.spark.linalg.MatrixLike = 130 | 0.6133402813080373 0.0 0.0 131 | 0.0 0.23923486751376188 0.0 132 | 0.0 0.0 0.48047741591983717 133 | ``` 134 | 135 | Support for element-wise basic math functions: 136 | 137 | ```scala 138 | import com.brkyvz.spark.linalg.funcs._ 139 | 140 | scala> C := pow(B, A) 141 | res7: com.brkyvz.spark.linalg.MatrixLike = 142 | 0.6133402813080373 1.0 1.0 143 | 1.0 0.23923486751376188 1.0 144 | 1.0 1.0 0.48047741591983717 145 | 146 | scala> C := exp(A) 147 | res8: com.brkyvz.spark.linalg.MatrixLike = 148 | 2.718281828459045 1.0 1.0 149 | 1.0 2.718281828459045 1.0 150 | 1.0 1.0 2.718281828459045 151 | 152 | scala> C := asin(A) - math.Pi / 2 153 | res12: com.brkyvz.spark.linalg.MatrixLike = 154 | 0.0 -1.5707963267948966 -1.5707963267948966 155 | -1.5707963267948966 0.0 -1.5707963267948966 156 | -1.5707963267948966 -1.5707963267948966 0.0 157 | ``` 158 | 159 | All operations work similarly for vectors. 
160 | 161 | ```scala 162 | scala> import com.brkyvz.spark.linalg._ 163 | scala> import org.apache.spark.mllib.linalg._ 164 | 165 | scala> val x = Vectors.dense(1, 0, 2, 1) 166 | x: org.apache.spark.mllib.linalg.Vector = [1.0,0.0,2.0,1.0] 167 | 168 | scala> val y = Vectors.dense(0, 0, 0, 0) 169 | y: org.apache.spark.mllib.linalg.Vector = [0.0,0.0,0.0,0.0] 170 | 171 | scala> y := x / 2 + x * 0.5 172 | res15: com.brkyvz.spark.linalg.DenseVectorWrapper = [1.0,0.0,2.0,1.0] 173 | 174 | scala> val z = Vectors.dense(1, -2, 3) 175 | z: org.apache.spark.mllib.linalg.Vector = [1.0,-2.0,3.0] 176 | 177 | scala> val m = A * z + 1 178 | scala> m.compute() 179 | res16: com.brkyvz.spark.linalg.VectorLike = [2.0,-1.0,4.0] 180 | ``` 181 | 182 | Caveats 183 | ======= 184 | 185 | 1- Implicits may not work perfectly for vectors and matrices generated through `Matrices.` and 186 | `Vectors.` constructors. 187 | 188 | 2- Scalars need to be **after** the matrix or the vector during operations, e.g. use `x * 2` instead 189 | of `2 * x`. 190 | 191 | What's to Come 192 | ============== 193 | 194 | 1- Support for more linear algebra operations: determinant, matrix inverse, trace, slicing, 195 | reshaping... 196 | 197 | 2- Supporting such methods with BlockMatrix. 198 | 199 | 3- Better, smarter lazy evaluation with codegen. This should bring performance closer to C++ 200 | libraries. 201 | 202 | 203 | Contributing 204 | ============ 205 | 206 | If you run across any bugs, please file issues, and please feel free to submit pull requests! 
207 | 208 | 209 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | scalaVersion := "2.10.4" 2 | 3 | sparkVersion := "1.5.0" 4 | 5 | spName := "brkyvz/lazy-linalg" 6 | 7 | version := "0.1.0" 8 | 9 | licenses := Seq("Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0")) 10 | 11 | sparkComponents += "mllib" 12 | 13 | libraryDependencies += "org.scalatest" %% "scalatest" % "1.9.1" % "test" 14 | 15 | libraryDependencies += "holdenk" % "spark-testing-base" % "1.4.1_0.1.1" % "test" 16 | 17 | parallelExecution in Test := false 18 | 19 | ScoverageSbtPlugin.ScoverageKeys.coverageHighlighting := { 20 | if (scalaBinaryVersion.value == "2.10") false 21 | else true 22 | } 23 | 24 | spShortDescription := "Linear algebra operators for Apache Spark MLlib's linalg package" 25 | 26 | spDescription := 27 | """It is somewhat cumbersome to write code where you have to convert the MLlib representation of a 28 | |vector or matrix to Breeze perform the simplest arithmetic operations like addition, subtraction, etc. 29 | |This package aims to lift that burden, and provide efficient implementations for some of these methods. 30 | | 31 | |By keeping operations lazy, this package provides some of the optimizations that you would see 32 | |in C++ libraries like Armadillo, Eigen, etc. 33 | """.stripMargin 34 | 35 | credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials") 36 | -------------------------------------------------------------------------------- /build/sbt: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so 4 | # that we can run Hive to generate the golden answer. This is not required for normal development 5 | # or testing. 
6 | for i in $HIVE_HOME/lib/* 7 | do HADOOP_CLASSPATH=$HADOOP_CLASSPATH:$i 8 | done 9 | export HADOOP_CLASSPATH 10 | 11 | realpath () { 12 | ( 13 | TARGET_FILE=$1 14 | 15 | cd $(dirname $TARGET_FILE) 16 | TARGET_FILE=$(basename $TARGET_FILE) 17 | 18 | COUNT=0 19 | while [ -L "$TARGET_FILE" -a $COUNT -lt 100 ] 20 | do 21 | TARGET_FILE=$(readlink $TARGET_FILE) 22 | cd $(dirname $TARGET_FILE) 23 | TARGET_FILE=$(basename $TARGET_FILE) 24 | COUNT=$(($COUNT + 1)) 25 | done 26 | 27 | echo $(pwd -P)/$TARGET_FILE 28 | ) 29 | } 30 | 31 | . $(dirname $(realpath $0))/sbt-launch-lib.bash 32 | 33 | 34 | declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy" 35 | declare -r sbt_opts_file=".sbtopts" 36 | declare -r etc_sbt_opts_file="/etc/sbt/sbtopts" 37 | 38 | usage() { 39 | cat < path to global settings/plugins directory (default: ~/.sbt) 47 | -sbt-boot path to shared boot directory (default: ~/.sbt/boot in 0.11 series) 48 | -ivy path to local Ivy repository (default: ~/.ivy2) 49 | -mem set memory options (default: $sbt_mem, which is $(get_mem_opts $sbt_mem)) 50 | -no-share use all local caches; no sharing 51 | -no-global uses global caches, but does not use global ~/.sbt directory. 52 | -jvm-debug Turn on JVM debugging, open at the given port. 
53 | -batch Disable interactive mode 54 | # sbt version (default: from project/build.properties if present, else latest release) 55 | -sbt-version use the specified version of sbt 56 | -sbt-jar use the specified jar as the sbt launcher 57 | -sbt-rc use an RC version of sbt 58 | -sbt-snapshot use a snapshot version of sbt 59 | # java version (default: java from PATH, currently $(java -version 2>&1 | grep version)) 60 | -java-home alternate JAVA_HOME 61 | # jvm options and output control 62 | JAVA_OPTS environment variable, if unset uses "$java_opts" 63 | SBT_OPTS environment variable, if unset uses "$default_sbt_opts" 64 | .sbtopts if this file exists in the current directory, it is 65 | prepended to the runner args 66 | /etc/sbt/sbtopts if this file exists, it is prepended to the runner args 67 | -Dkey=val pass -Dkey=val directly to the java runtime 68 | -J-X pass option -X directly to the java runtime 69 | (-J is stripped) 70 | -S-X add -X to sbt's scalacOptions (-J is stripped) 71 | -PmavenProfiles Enable a maven profile for the build. 72 | In the case of duplicated or conflicting options, the order above 73 | shows precedence: JAVA_OPTS lowest, command line options highest. 
74 | EOM 75 | } 76 | 77 | process_my_args () { 78 | while [[ $# -gt 0 ]]; do 79 | case "$1" in 80 | -no-colors) addJava "-Dsbt.log.noformat=true" && shift ;; 81 | -no-share) addJava "$noshare_opts" && shift ;; 82 | -no-global) addJava "-Dsbt.global.base=$(pwd)/project/.sbtboot" && shift ;; 83 | -sbt-boot) require_arg path "$1" "$2" && addJava "-Dsbt.boot.directory=$2" && shift 2 ;; 84 | -sbt-dir) require_arg path "$1" "$2" && addJava "-Dsbt.global.base=$2" && shift 2 ;; 85 | -debug-inc) addJava "-Dxsbt.inc.debug=true" && shift ;; 86 | -batch) exec &2 "$@" 31 | } 32 | vlog () { 33 | [[ $verbose || $debug ]] && echoerr "$@" 34 | } 35 | dlog () { 36 | [[ $debug ]] && echoerr "$@" 37 | } 38 | 39 | acquire_sbt_jar () { 40 | SBT_VERSION=`awk -F "=" '/sbt\\.version/ {print $2}' ./project/build.properties` 41 | URL1=https://dl.bintray.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar 42 | JAR=build/sbt-launch-${SBT_VERSION}.jar 43 | 44 | sbt_jar=$JAR 45 | 46 | if [[ ! -f "$sbt_jar" ]]; then 47 | # Download sbt launch jar if it hasn't been downloaded yet 48 | if [ ! -f ${JAR} ]; then 49 | # Download 50 | printf "Attempting to fetch sbt\n" 51 | JAR_DL=${JAR}.part 52 | if hash curl 2>/dev/null; then 53 | curl --fail --location --silent ${URL1} > "${JAR_DL}" &&\ 54 | mv "${JAR_DL}" "${JAR}" 55 | elif hash wget 2>/dev/null; then 56 | wget --quiet ${URL1} -O "${JAR_DL}" &&\ 57 | mv "${JAR_DL}" "${JAR}" 58 | else 59 | printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" 60 | exit -1 61 | fi 62 | fi 63 | if [ ! -f ${JAR} ]; then 64 | # We failed to download 65 | printf "Our attempt to download sbt locally to ${JAR} failed. 
Please install sbt manually from http://www.scala-sbt.org/\n" 66 | exit -1 67 | fi 68 | printf "Launching sbt from ${JAR}\n" 69 | fi 70 | } 71 | 72 | execRunner () { 73 | # print the arguments one to a line, quoting any containing spaces 74 | [[ $verbose || $debug ]] && echo "# Executing command line:" && { 75 | for arg; do 76 | if printf "%s\n" "$arg" | grep -q ' '; then 77 | printf "\"%s\"\n" "$arg" 78 | else 79 | printf "%s\n" "$arg" 80 | fi 81 | done 82 | echo "" 83 | } 84 | 85 | exec "$@" 86 | } 87 | 88 | addJava () { 89 | dlog "[addJava] arg = '$1'" 90 | java_args=( "${java_args[@]}" "$1" ) 91 | } 92 | 93 | enableProfile () { 94 | dlog "[enableProfile] arg = '$1'" 95 | maven_profiles=( "${maven_profiles[@]}" "$1" ) 96 | export SBT_MAVEN_PROFILES="${maven_profiles[@]}" 97 | } 98 | 99 | addSbt () { 100 | dlog "[addSbt] arg = '$1'" 101 | sbt_commands=( "${sbt_commands[@]}" "$1" ) 102 | } 103 | addResidual () { 104 | dlog "[residual] arg = '$1'" 105 | residual_args=( "${residual_args[@]}" "$1" ) 106 | } 107 | addDebugger () { 108 | addJava "-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=n,address=$1" 109 | } 110 | 111 | # a ham-fisted attempt to move some memory settings in concert 112 | # so they need not be dicked around with individually. 
113 | get_mem_opts () { 114 | local mem=${1:-2048} 115 | local perm=$(( $mem / 4 )) 116 | (( $perm > 256 )) || perm=256 117 | (( $perm < 4096 )) || perm=4096 118 | local codecache=$(( $perm / 2 )) 119 | 120 | echo "-Xms${mem}m -Xmx${mem}m -XX:MaxPermSize=${perm}m -XX:ReservedCodeCacheSize=${codecache}m" 121 | } 122 | 123 | require_arg () { 124 | local type="$1" 125 | local opt="$2" 126 | local arg="$3" 127 | if [[ -z "$arg" ]] || [[ "${arg:0:1}" == "-" ]]; then 128 | die "$opt requires <$type> argument" 129 | fi 130 | } 131 | 132 | is_function_defined() { 133 | declare -f "$1" > /dev/null 134 | } 135 | 136 | process_args () { 137 | while [[ $# -gt 0 ]]; do 138 | case "$1" in 139 | -h|-help) usage; exit 1 ;; 140 | -v|-verbose) verbose=1 && shift ;; 141 | -d|-debug) debug=1 && shift ;; 142 | 143 | -ivy) require_arg path "$1" "$2" && addJava "-Dsbt.ivy.home=$2" && shift 2 ;; 144 | -mem) require_arg integer "$1" "$2" && sbt_mem="$2" && shift 2 ;; 145 | -jvm-debug) require_arg port "$1" "$2" && addDebugger $2 && shift 2 ;; 146 | -batch) exec JavaDouble} 5 | 6 | import org.apache.spark.mllib.linalg._ 7 | 8 | /** Util methods that use reflection to call into MLlib's private BLAS methods. 
*/ 9 | object BLASUtils { 10 | 11 | @transient private lazy val clazz: Class[_] = Class.forName("org.apache.spark.mllib.linalg.BLAS$") 12 | 13 | @transient private lazy val _blas: Any = { 14 | val constructor = clazz.getDeclaredConstructors.head 15 | constructor.setAccessible(true) 16 | constructor.newInstance() 17 | } 18 | 19 | private def castMatrix(mat: MatrixLike, toDense: Boolean = false): Matrix = mat match { 20 | case dn: DenseMatrixWrapper => dn.asInstanceOf[DenseMatrix] 21 | case sp: SparseMatrixWrapper => 22 | if (toDense) sp.toDense else sp.asInstanceOf[SparseMatrix] 23 | case lzy: LazyMatrix => lzy.compute().asInstanceOf[DenseMatrix] 24 | case _ => throw new UnsupportedOperationException(s"${mat.getClass} can't be cast to Matrix.") 25 | } 26 | 27 | private def castVector(mat: VectorLike, toDense: Boolean = false): Vector = mat match { 28 | case dn: DenseVectorWrapper => dn.asInstanceOf[DenseVector] 29 | case sp: SparseVectorWrapper => 30 | if (toDense) sp.toDense else sp.asInstanceOf[SparseVector] 31 | case lzy: LazyVector => lzy.compute().asInstanceOf[DenseVector] 32 | case _ => throw new UnsupportedOperationException(s"${mat.getClass} can't be cast to Vector.") 33 | } 34 | 35 | private def invokeMethod(methodName: String, args: (Class[_], AnyRef)*): Any = { 36 | val (types, values) = args.unzip 37 | val method = clazz.getDeclaredMethod(methodName, types: _*) 38 | method.setAccessible(true) 39 | try { 40 | method.invoke(_blas, values.toSeq: _*) 41 | } catch { 42 | case ex: InvocationTargetException => 43 | throw new IllegalArgumentException(s"$methodName is not supported for arguments: $values") 44 | } 45 | } 46 | 47 | /** 48 | * y += a * x 49 | */ 50 | def axpy(a: Double, x: VectorLike, y: VectorLike): Unit = { 51 | val args: Seq[(Class[_], AnyRef)] = Seq((classOf[Double], new JavaDouble(a)), 52 | (classOf[Vector], castVector(x)), (classOf[Vector], castVector(y, toDense = true))) 53 | invokeMethod("axpy", args: _*) 54 | } 55 | 56 | /** 57 | * x^T^y 58 
| */ 59 | def dot(x: VectorLike, y: VectorLike): Double = { 60 | val args: Seq[(Class[_], AnyRef)] = Seq( 61 | (classOf[Vector], castVector(x)), (classOf[Vector], castVector(y))) 62 | invokeMethod("dot", args: _*).asInstanceOf[Double] 63 | } 64 | 65 | /** 66 | * x = a * x 67 | */ 68 | def scal(a: Double, x: VectorLike): Unit = { 69 | val cx = castVector(x) 70 | val args: Seq[(Class[_], AnyRef)] = Seq( 71 | (classOf[Double], new JavaDouble(a)), (classOf[Vector], cx)) 72 | invokeMethod("scal", args: _*) 73 | } 74 | 75 | /** 76 | * A := alpha * x * x^T^ + A 77 | * @param alpha a real scalar that will be multiplied to x * x^T^. 78 | * @param x the vector x that contains the n elements. 79 | * @param A the symmetric matrix A. Size of n x n. 80 | */ 81 | def syr(alpha: Double, x: Vector, A: MatrixLike): Unit = { 82 | val args: Seq[(Class[_], AnyRef)] = Seq((classOf[Double], new JavaDouble(alpha)), 83 | (classOf[Vector], castVector(x)), (classOf[DenseMatrix], castMatrix(A, toDense = true))) 84 | invokeMethod("syr", args: _*) 85 | } 86 | 87 | /** 88 | * C := alpha * A * B + beta * C 89 | * @param alpha a scalar to scale the multiplication A * B. 90 | * @param A the matrix A that will be left multiplied to B. Size of m x k. 91 | * @param B the matrix B that will be left multiplied by A. Size of k x n. 92 | * @param beta a scalar that can be used to scale matrix C. 93 | * @param C the resulting matrix C. Size of m x n. C.isTransposed must be false. 
94 | */ 95 | def gemm(alpha: Double, A: MatrixLike, B: MatrixLike, beta: Double, C: DenseMatrix): Unit = { 96 | B match { 97 | case dnB: DenseMatrixWrapper => mllibGemm(alpha, castMatrix(A), dnB, beta, C) 98 | case spB: SparseMatrixWrapper => 99 | A match { 100 | case dnA: DenseMatrixWrapper => dsgemm(alpha, dnA, spB, beta, C) 101 | case spA: SparseMatrixWrapper => mllibGemm(alpha, spA, spB.toDense, beta, C) 102 | case lzy: LazyMatrix => 103 | dsgemm(alpha, lzy.compute().asInstanceOf[DenseMatrixWrapper], spB, beta, C) 104 | } 105 | case lzy: LazyMatrix => 106 | mllibGemm(alpha, castMatrix(A), lzy.compute().asInstanceOf[DenseMatrix], beta, C) 107 | } 108 | } 109 | 110 | private def mllibGemm( 111 | alpha: Double, 112 | A: Matrix, 113 | B: DenseMatrix, 114 | beta: Double, 115 | C: DenseMatrix): Unit = { 116 | val args: Seq[(Class[_], AnyRef)] = Seq( 117 | (classOf[Double], new JavaDouble(alpha)), (classOf[Matrix], A), (classOf[DenseMatrix], B), 118 | (classOf[Double], new JavaDouble(beta)), (classOf[DenseMatrix], C)) 119 | invokeMethod("gemm", args: _*) 120 | } 121 | 122 | private def dsgemm( 123 | alpha: Double, 124 | A: DenseMatrixWrapper, 125 | B: SparseMatrixWrapper, 126 | beta: Double, 127 | C: DenseMatrix): Unit = { 128 | val mA: Int = A.numRows 129 | val nB: Int = B.numCols 130 | val kA: Int = A.numCols 131 | val kB: Int = B.numRows 132 | 133 | require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") 134 | require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") 135 | require(nB == C.numCols, 136 | s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB") 137 | 138 | val Avals = A.values 139 | val Bvals = B.values 140 | val Cvals = C.values 141 | val BrowIndices = B.rowIndices 142 | val BcolPtrs = B.colPtrs 143 | 144 | // Slicing is easy in this case. 
This is the optimal multiplication setting for sparse matrices 145 | if (!B.isTransposed){ 146 | var colCounterForB = 0 147 | if (A.isTransposed) { // Expensive to put the check inside the loop 148 | while (colCounterForB < nB) { 149 | var rowCounterForA = 0 150 | val Cstart = colCounterForB * mA 151 | val Bstart = BcolPtrs(colCounterForB) 152 | while (rowCounterForA < mA) { 153 | var i = Bstart 154 | val indEnd = BcolPtrs(colCounterForB + 1) 155 | val Astart = rowCounterForA * kA 156 | var sum = 0.0 157 | while (i < indEnd) { 158 | sum += Avals(Astart + BrowIndices(i)) * Bvals(i) 159 | i += 1 160 | } 161 | val Cindex = Cstart + rowCounterForA 162 | Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha 163 | rowCounterForA += 1 164 | } 165 | colCounterForB += 1 166 | } 167 | } else { 168 | while (colCounterForB < nB) { 169 | var rowCounterForA = 0 170 | val Cstart = colCounterForB * mA 171 | while (rowCounterForA < mA) { 172 | var i = BcolPtrs(colCounterForB) 173 | val indEnd = BcolPtrs(colCounterForB + 1) 174 | var sum = 0.0 175 | while (i < indEnd) { 176 | sum += A(rowCounterForA, BrowIndices(i)) * Bvals(i) 177 | i += 1 178 | } 179 | val Cindex = Cstart + rowCounterForA 180 | Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha 181 | rowCounterForA += 1 182 | } 183 | colCounterForB += 1 184 | } 185 | } 186 | } else { 187 | // Scale matrix first if `beta` is not equal to 0.0 188 | if (beta != 1.0) { 189 | scal(beta, new DenseVectorWrapper(C.values)) 190 | } 191 | // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of 192 | // B, and added to C. 
193 | var rowCounterForB = 0 // the column to be updated in C 194 | if (!A.isTransposed) { // Expensive to put the check inside the loop 195 | while (rowCounterForB < kB) { 196 | var i = BcolPtrs(rowCounterForB) 197 | val indEnd = BcolPtrs(rowCounterForB + 1) 198 | while (i < indEnd) { 199 | var rowCounterForA = 0 // The column of A to multiply with the row of B 200 | val Bval = Bvals(i) * alpha 201 | val Cstart = BrowIndices(i) * mA 202 | val Astart = rowCounterForB * mA 203 | while (rowCounterForA < mA) { 204 | Cvals(Cstart + rowCounterForA) += Avals(Astart + rowCounterForA) * Bval 205 | rowCounterForA += 1 206 | } 207 | i += 1 208 | } 209 | rowCounterForB += 1 210 | } 211 | } else { 212 | while (rowCounterForB < kB) { 213 | var i = BcolPtrs(rowCounterForB) 214 | val indEnd = BcolPtrs(rowCounterForB + 1) 215 | while (i < indEnd) { 216 | var rowCounterForA = 0 // The column of A to multiply with the row of B 217 | val Bval = Bvals(i) * alpha 218 | val Bcol = BrowIndices(i) 219 | val Cstart = Bcol * mA 220 | while (rowCounterForA < mA) { 221 | Cvals(Cstart + rowCounterForA) += A(rowCounterForA, rowCounterForB) * Bval 222 | rowCounterForA += 1 223 | } 224 | i += 1 225 | } 226 | rowCounterForB += 1 227 | } 228 | } 229 | } 230 | } 231 | 232 | /** 233 | * y := alpha * A * x + beta * y 234 | * @param alpha a scalar to scale the multiplication A * x. 235 | * @param A the matrix A that will be left multiplied to x. Size of m x n. 236 | * @param x the vector x that will be left multiplied by A. Size of n x 1. 237 | * @param beta a scalar that can be used to scale vector y. 238 | * @param y the resulting vector y. Size of m x 1. 
239 | */ 240 | def gemv( 241 | alpha: Double, 242 | A: MatrixLike, 243 | x: VectorLike, 244 | beta: Double, 245 | y: VectorLike): Unit = { 246 | val a: Matrix = castMatrix(A) 247 | val _x: Vector = castVector(x) 248 | val _y: Vector = castVector(y) 249 | val args: Seq[(Class[_], AnyRef)] = Seq((classOf[Double], new JavaDouble(alpha)), 250 | (classOf[Matrix], a), (classOf[Vector], x), 251 | (classOf[Double], new JavaDouble(beta)), (classOf[DenseVector], y)) 252 | invokeMethod("gemv", args: _*) 253 | } 254 | } 255 | -------------------------------------------------------------------------------- /src/main/scala/com/brkyvz/spark/linalg/MatrixLike.scala: -------------------------------------------------------------------------------- 1 | package com.brkyvz.spark.linalg 2 | 3 | import org.apache.spark.mllib.linalg.{Matrix, DenseMatrix, SparseMatrix} 4 | 5 | trait MatrixLike extends Serializable { 6 | 7 | /** Number of rows. */ 8 | def numRows: Int 9 | 10 | /** Number of columns. */ 11 | def numCols: Int 12 | 13 | def size: Int = numRows * numCols 14 | 15 | def apply(i: Int): Double 16 | 17 | import funcs._ 18 | def +(y: MatrixLike): LazyMatrix = add(this, y) 19 | def -(y: MatrixLike): LazyMatrix = sub(this, y) 20 | def :*(y: MatrixLike): LazyMatrix = emul(this, y) 21 | def *(y: MatrixLike): LazyMatrix 22 | def /(y: MatrixLike): LazyMatrix = div(this, y) 23 | } 24 | 25 | /** Dense and Sparse Matrices can be mutated. Lazy matrices are immutable. */ 26 | sealed trait MutableMatrix extends MatrixLike { 27 | override def *(y: MatrixLike): LazyMatrix = { 28 | require(this.numCols == y.numRows || y.isInstanceOf[Scalar], 29 | s"numCols of left side doesn't match numRows of right. ${this.numCols} vs. 
${y.numRows}") 30 | y match { 31 | case mm: MutableMatrix => new LazyMM_MMultOp(this, mm) 32 | case lzy: LazyMatrix => new LazyML_MMultOp(this, lzy) 33 | case scalar: Scalar => funcs.emul(this, scalar) 34 | } 35 | } 36 | def *(y: VectorLike): LazyVector = { 37 | require(this.numCols == y.size, 38 | s"numCols of left side doesn't match numRows of right. ${this.numCols} vs. ${y.size}") 39 | y match { 40 | case dn: DenseVectorWrapper => new LazyMM_MV_MultOp(this, dn) 41 | case sp: SparseVectorWrapper => new LazyMM_MV_MultOp(this, sp) 42 | case lzy: LazyVector => new LazyMM_LV_MultOp(this, lzy) 43 | } 44 | } 45 | } 46 | 47 | class DenseMatrixWrapper( 48 | override val numRows: Int, 49 | override val numCols: Int, 50 | override val values: Array[Double], 51 | override val isTransposed: Boolean) 52 | extends DenseMatrix(numRows, numCols, values, isTransposed) with MutableMatrix { 53 | 54 | def this(numRows: Int, numCols: Int, values: Array[Double]) = 55 | this(numRows, numCols, values, isTransposed = false) 56 | 57 | override def apply(i: Int): Double = values(i) 58 | 59 | def +=(y: MatrixLike): this.type = { 60 | require(y.numRows == this.numRows || y.isInstanceOf[Scalar], 61 | s"Rows don't match for in-place addition. ${this.numRows} vs. ${y.numRows}") 62 | require(y.numCols == this.numCols || y.isInstanceOf[Scalar], 63 | s"Cols don't match for in-place addition. ${this.numCols} vs. 
${y.numCols}") 64 | y match { 65 | case dd: LazyMM_MMultOp => 66 | new LazyMM_MMultOp(dd.left, dd.right, Option(this.values), 1.0).compute() 67 | case dl: LazyML_MMultOp => 68 | new LazyML_MMultOp(dl.left, dl.right, Option(this.values), 1.0).compute() 69 | case ld: LazyLM_MMultOp => 70 | new LazyLM_MMultOp(ld.left, ld.right, Option(this.values), 1.0).compute() 71 | case ll: LazyLL_MMultOp => 72 | new LazyLL_MMultOp(ll.left, ll.right, Option(this.values), 1.0).compute() 73 | case _ => new LazyImDenseMMOp(this, y, _ + _).compute(Option(this.values)) 74 | } 75 | this 76 | } 77 | 78 | def :=(y: LazyMatrix): MatrixLike = { 79 | require(y.numRows == this.numRows, 80 | s"Rows don't match for in-place evaluation. ${this.numRows} vs. ${y.numRows}") 81 | require(y.numCols == this.numCols, 82 | s"Cols don't match for in-place evaluation. ${this.numCols} vs. ${y.numCols}") 83 | y match { 84 | case dd: LazyMM_MMultOp => 85 | new LazyMM_MMultOp(dd.left, dd.right, Option(this.values), 0.0).compute() 86 | case dl: LazyML_MMultOp => 87 | new LazyML_MMultOp(dl.left, dl.right, Option(this.values), 0.0).compute() 88 | case ld: LazyLM_MMultOp => 89 | new LazyLM_MMultOp(ld.left, ld.right, Option(this.values), 0.0).compute() 90 | case ll: LazyLL_MMultOp => 91 | new LazyLL_MMultOp(ll.left, ll.right, Option(this.values), 0.0).compute() 92 | case _ => y.compute(Option(this.values)) 93 | } 94 | this 95 | } 96 | } 97 | 98 | object DenseMatrixWrapper { 99 | def apply(mat: DenseMatrix): DenseMatrixWrapper = 100 | new DenseMatrixWrapper(mat.numRows, mat.numCols, mat.values, mat.isTransposed) 101 | } 102 | 103 | class SparseMatrixWrapper( 104 | override val numRows: Int, 105 | override val numCols: Int, 106 | override val colPtrs: Array[Int], 107 | override val rowIndices: Array[Int], 108 | override val values: Array[Double], 109 | override val isTransposed: Boolean) 110 | extends SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) 111 | with MutableMatrix { 112 | 113 | def 
this( 114 | numRows: Int, 115 | numCols: Int, 116 | colPtrs: Array[Int], 117 | rowIndices: Array[Int], 118 | values: Array[Double]) = 119 | this(numRows, numCols, colPtrs, rowIndices, values, isTransposed = false) 120 | 121 | override def apply(i: Int): Double = this(i % numRows, i / numRows) 122 | } 123 | 124 | object SparseMatrixWrapper { 125 | def apply(mat: SparseMatrix): SparseMatrixWrapper = new SparseMatrixWrapper(mat.numRows, 126 | mat.numCols, mat.colPtrs, mat.rowIndices, mat.values, mat.isTransposed) 127 | } 128 | 129 | sealed trait LazyMatrix extends MatrixLike { 130 | def compute(into: Option[Array[Double]] = None): MatrixLike = { 131 | val values = into.getOrElse(new Array[Double](size)) 132 | require(values.length == size, 133 | s"Size of buffer (${values.length}) not equal to size of matrix ($size).") 134 | var i = 0 135 | while (i < size) { 136 | values(i) = this(i) 137 | i += 1 138 | } 139 | new DenseMatrixWrapper(numRows, numCols, values) 140 | } 141 | override def *(y: MatrixLike): LazyMatrix = { 142 | require(this.numCols == y.numRows || y.isInstanceOf[Scalar], 143 | s"numCols of left side doesn't match numRows of right. ${this.numCols} vs. ${y.numRows}") 144 | y match { 145 | case mm: MutableMatrix => new LazyLM_MMultOp(this, mm) 146 | case lzy: LazyMatrix => new LazyLL_MMultOp(this, lzy) 147 | case scalar: Scalar => funcs.emul(this, scalar) 148 | } 149 | } 150 | def *(y: VectorLike): LazyVector = { 151 | require(this.numCols == y.size, 152 | s"numCols of left side doesn't match numRows of right. ${this.numCols} vs. 
${y.size}") 153 | y match { 154 | case dn: DenseVectorWrapper => new LazyLM_MV_MultOp(this, dn) 155 | case sp: SparseVectorWrapper => new LazyLM_MV_MultOp(this, sp) 156 | case lzy: LazyVector => new LazyLM_LV_MultOp(this, lzy) 157 | } 158 | } 159 | } 160 | 161 | private[linalg] abstract class LazyMMOp( 162 | left: MatrixLike, 163 | right: MatrixLike, 164 | operation: (Double, Double) => Double) extends LazyMatrix { 165 | require(left.numRows == right.numRows || left.isInstanceOf[Scalar] || right.isInstanceOf[Scalar], 166 | s"Rows don't match for in-place addition. ${left.numRows} vs. ${right.numRows}") 167 | require(left.numCols == right.numCols || left.isInstanceOf[Scalar] || right.isInstanceOf[Scalar], 168 | s"Cols don't match for in-place addition. ${left.numCols} vs. ${right.numCols}") 169 | override def numRows = math.max(left.numRows, right.numRows) 170 | override def numCols = math.max(left.numCols, right.numCols) 171 | } 172 | 173 | private[linalg] class LazyImDenseMMOp( 174 | left: MatrixLike, 175 | right: MatrixLike, 176 | operation: (Double, Double) => Double) extends LazyMMOp(left, right, operation) { 177 | override def apply(i: Int): Double = operation(left(i), right(i)) 178 | } 179 | 180 | private[linalg] case class LazyImDenseScaleOp( 181 | left: Scalar, 182 | right: MatrixLike) extends LazyImDenseMMOp(left, right, _ * _) 183 | 184 | private[linalg] class LazyMatrixMapOp( 185 | parent: MatrixLike, 186 | operation: Double => Double) extends LazyMatrix { 187 | override def numRows = parent.numRows 188 | override def numCols = parent.numCols 189 | override def apply(i: Int): Double = operation(parent(i)) 190 | } 191 | 192 | private[linalg] abstract class LazyMMultOp( 193 | left: MatrixLike, 194 | right: MatrixLike, 195 | into: Option[Array[Double]] = None, 196 | beta: Double = 1.0) extends LazyMatrix { 197 | override def numRows = left.numRows 198 | override def numCols = right.numCols 199 | } 200 | 201 | private[linalg] class LazyLL_MMultOp( 202 | val 
left: LazyMatrix, 203 | val right: LazyMatrix, 204 | into: Option[Array[Double]] = None, 205 | beta: Double = 1.0) extends LazyMMultOp(left, right, into, beta) { 206 | override def apply(i: Int): Double = result(i) 207 | 208 | private var buffer: Option[Array[Double]] = into 209 | 210 | lazy val result: DenseMatrixWrapper = { 211 | var leftScale = 1.0 212 | val (effLeft: DenseMatrixWrapper, leftRes) = left match { 213 | case scaled: LazyImDenseScaleOp => 214 | leftScale = scaled.left.value 215 | (scaled.right, None) 216 | case ll: LazyLL_MMultOp => 217 | if (ll.size < ll.right.size) { 218 | (ll.compute(), None) 219 | } else { 220 | (ll.right.compute(), Option(ll.left)) 221 | } 222 | case ld: LazyLM_MMultOp => 223 | if (ld.size < ld.right.size) { 224 | (ld.compute(), None) 225 | } else { 226 | (ld.right, Option(ld.left)) 227 | } 228 | case dl: LazyML_MMultOp => 229 | if (dl.size < dl.right.size) { 230 | (dl.compute(), None) 231 | } else { 232 | (dl.right.compute(), Option(dl.left)) 233 | } 234 | case dd: LazyMM_MMultOp => 235 | if (dd.size < dd.right.size) { 236 | (dd.compute(), None) 237 | } else { 238 | (dd.right, Option(dd.left)) 239 | } 240 | case _ => (left.compute(), None) 241 | } 242 | var rightScale = 1.0 243 | val (effRight: DenseMatrixWrapper, rightRes) = right match { 244 | case scaled: LazyImDenseScaleOp => 245 | rightScale = scaled.left.value 246 | (scaled.right, None) 247 | case ll: LazyLL_MMultOp => 248 | if (ll.size < ll.right.size) { 249 | (ll.compute(), None) 250 | } else { 251 | (ll.right.compute(), Option(ll.left)) 252 | } 253 | case ld: LazyLM_MMultOp => 254 | if (ld.size < ld.right.size) { 255 | (ld.compute(), None) 256 | } else { 257 | (ld.right, Option(ld.left)) 258 | } 259 | case dl: LazyML_MMultOp => 260 | if (dl.size < dl.right.size) { 261 | (dl.compute(), None) 262 | } else { 263 | (dl.right.compute(), Option(dl.left)) 264 | } 265 | case dd: LazyMM_MMultOp => 266 | if (dd.size < dd.right.size) { 267 | (dd.compute(), None) 268 | } else { 
269 | (dd.right, Option(dd.left)) 270 | } 271 | case _ => (right.compute(), None) 272 | } 273 | val middle = 274 | if (leftRes.isEmpty && rightRes.isEmpty) { 275 | val inside = new DenseMatrixWrapper(effLeft.numRows, effRight.numCols, 276 | buffer.getOrElse(new Array[Double](effLeft.numRows * effRight.numCols))) 277 | BLASUtils.gemm(leftScale * rightScale, effLeft, effRight, beta, inside) 278 | inside 279 | } else { 280 | val inside = DenseMatrix.zeros(effLeft.numRows, effRight.numCols) 281 | BLASUtils.gemm(leftScale * rightScale, effLeft, effRight, 1.0, inside) 282 | inside 283 | } 284 | 285 | val rebuildRight = rightRes.getOrElse(None) match { 286 | case l: LazyMatrix => new LazyML_MMultOp(middle, l) 287 | case d: DenseMatrixWrapper => new LazyMM_MMultOp(middle, d) 288 | case None => middle 289 | } 290 | leftRes.getOrElse(None) match { 291 | case l: LazyMatrix => 292 | rebuildRight match { 293 | case r: LazyMatrix => new LazyLL_MMultOp(l, r, buffer, beta).compute() 294 | case d: DenseMatrixWrapper => new LazyLM_MMultOp(l, d, buffer, beta).compute() 295 | } 296 | case ld: DenseMatrixWrapper => 297 | rebuildRight match { 298 | case r: LazyMatrix => new LazyML_MMultOp(ld, r, buffer, beta).compute() 299 | case d: DenseMatrixWrapper => new LazyMM_MMultOp(ld, d, buffer, beta).compute() 300 | } 301 | case None => 302 | rebuildRight match { 303 | case r: LazyMM_MMultOp => new LazyMM_MMultOp(r.left, r.right, buffer, beta).compute() 304 | case l: LazyML_MMultOp => new LazyML_MMultOp(l.left, l.right, buffer, beta).compute() 305 | case d: DenseMatrixWrapper => d 306 | } 307 | } 308 | } 309 | override def compute(into: Option[Array[Double]] = None): DenseMatrixWrapper = { 310 | into.foreach(b => buffer = Option(b)) 311 | result 312 | } 313 | } 314 | 315 | private[linalg] class LazyLM_MMultOp( 316 | val left: LazyMatrix, 317 | val right: MutableMatrix, 318 | into: Option[Array[Double]] = None, 319 | beta: Double = 1.0) extends LazyMMultOp(left, right, into, beta) { 320 | 
override def apply(i: Int): Double = result(i) 321 | 322 | private var buffer: Option[Array[Double]] = into 323 | 324 | lazy val result: DenseMatrixWrapper = { 325 | var leftScale = 1.0 326 | val (effLeft: DenseMatrixWrapper, leftRes) = left match { 327 | case scaled: LazyImDenseScaleOp => 328 | leftScale = scaled.left.value 329 | (scaled.right, None) 330 | case ll: LazyLL_MMultOp => 331 | if (ll.size < ll.right.size) { 332 | (ll.compute(), None) 333 | } else { 334 | (ll.right.compute(), Option(ll.left)) 335 | } 336 | case ld: LazyLM_MMultOp => 337 | if (ld.size < ld.right.size) { 338 | (ld.compute(), None) 339 | } else { 340 | (ld.right, Option(ld.left)) 341 | } 342 | case dl: LazyML_MMultOp => 343 | if (dl.size < dl.right.size) { 344 | (dl.compute(), None) 345 | } else { 346 | (dl.right.compute(), Option(dl.left)) 347 | } 348 | case dd: LazyMM_MMultOp => 349 | if (dd.size < dd.right.size) { 350 | (dd.compute(), None) 351 | } else { 352 | (dd.right, Option(dd.left)) 353 | } 354 | case _ => (left.compute(), None) 355 | } 356 | 357 | val middle = 358 | if (leftRes.isEmpty) { 359 | val inside = new DenseMatrixWrapper(effLeft.numRows, right.numCols, 360 | buffer.getOrElse(new Array[Double](effLeft.numRows * right.numCols))) 361 | BLASUtils.gemm(leftScale, effLeft, right, beta, inside) 362 | inside 363 | } else { 364 | val inside = DenseMatrix.zeros(effLeft.numRows, right.numCols) 365 | BLASUtils.gemm(leftScale, effLeft, right, 1.0, inside) 366 | inside 367 | } 368 | 369 | leftRes.getOrElse(None) match { 370 | case l: LazyMatrix => new LazyLM_MMultOp(l, middle, buffer, beta).compute() 371 | case ld: DenseMatrixWrapper => new LazyMM_MMultOp(ld, middle, buffer, beta).compute() 372 | case None => middle 373 | } 374 | } 375 | 376 | override def compute(into: Option[Array[Double]] = None): DenseMatrixWrapper = { 377 | into.foreach(b => buffer = Option(b)) 378 | result 379 | } 380 | } 381 | 382 | private[linalg] class LazyML_MMultOp( 383 | val left: MutableMatrix, 384 | val 
right: LazyMatrix, 385 | into: Option[Array[Double]] = None, 386 | beta: Double = 1.0) extends LazyMMultOp(left, right, into, beta) { 387 | override def apply(i: Int): Double = result(i) 388 | 389 | private var buffer: Option[Array[Double]] = into 390 | 391 | lazy val result: DenseMatrixWrapper = { 392 | var rightScale = 1.0 393 | val (effRight: DenseMatrixWrapper, rightRes) = right match { 394 | case scaled: LazyImDenseScaleOp => 395 | rightScale = scaled.left.value 396 | (scaled.right, None) 397 | case ll: LazyLL_MMultOp => 398 | if (ll.size < ll.right.size) { 399 | (ll.compute(), None) 400 | } else { 401 | (ll.right.compute(), Option(ll.left)) 402 | } 403 | case ld: LazyLM_MMultOp => 404 | if (ld.size < ld.right.size) { 405 | (ld.compute(), None) 406 | } else { 407 | (ld.right, Option(ld.left)) 408 | } 409 | case dl: LazyML_MMultOp => 410 | if (dl.size < dl.right.size) { 411 | (dl.compute(), None) 412 | } else { 413 | (dl.right.compute(), Option(dl.left)) 414 | } 415 | case dd: LazyMM_MMultOp => 416 | if (dd.size < dd.right.size) { 417 | (dd.compute(), None) 418 | } else { 419 | (dd.right, Option(dd.left)) 420 | } 421 | case _ => (right.compute(), None) 422 | } 423 | val middle = 424 | if (rightRes.isEmpty) { 425 | val inside = new DenseMatrixWrapper(left.numRows, effRight.numCols, 426 | buffer.getOrElse(new Array[Double](left.numRows * effRight.numCols))) 427 | BLASUtils.gemm(rightScale, left, effRight, beta, inside) 428 | inside 429 | } else { 430 | val inside = DenseMatrix.zeros(left.numRows, effRight.numCols) 431 | BLASUtils.gemm(rightScale, left, effRight, 0.0, inside) 432 | inside 433 | } 434 | 435 | rightRes.getOrElse(None) match { 436 | case l: LazyMatrix => new LazyML_MMultOp(middle, l, buffer, beta).compute() 437 | case d: DenseMatrixWrapper => new LazyMM_MMultOp(middle, d, buffer, beta).compute() 438 | case None => middle 439 | } 440 | } 441 | 442 | override def compute(into: Option[Array[Double]] = None): DenseMatrixWrapper = { 443 | into.foreach(b 
=> buffer = Option(b)) 444 | result 445 | } 446 | } 447 | 448 | private[linalg] class LazyMM_MMultOp( 449 | val left: MutableMatrix, 450 | val right: MutableMatrix, 451 | into: Option[Array[Double]] = None, 452 | beta: Double = 1.0) extends LazyMMultOp(left, right, into, beta) { 453 | override def apply(i: Int): Double = result(i) 454 | 455 | private var buffer: Option[Array[Double]] = into 456 | 457 | lazy val result: DenseMatrixWrapper = { 458 | val inside = new DenseMatrixWrapper(left.numRows, right.numCols, 459 | buffer.getOrElse(new Array[Double](left.numRows * right.numCols))) 460 | BLASUtils.gemm(1.0, left, right, beta, inside) 461 | inside 462 | } 463 | 464 | override def compute(into: Option[Array[Double]] = None): DenseMatrixWrapper = { 465 | into.foreach(b => buffer = Option(b)) 466 | result 467 | } 468 | } 469 | -------------------------------------------------------------------------------- /src/main/scala/com/brkyvz/spark/linalg/VectorLike.scala: -------------------------------------------------------------------------------- 1 | package com.brkyvz.spark.linalg 2 | 3 | import java.{util => ju} 4 | 5 | import org.apache.spark.mllib.linalg.{SparseVector, DenseVector} 6 | 7 | sealed trait VectorLike extends Serializable { 8 | 9 | /** 10 | * Size of the vector. 11 | */ 12 | def size: Int 13 | 14 | def apply(i: Int): Double 15 | 16 | import funcs._ 17 | def +(y: VectorLike): LazyVector = add(this, y) 18 | def -(y: VectorLike): LazyVector = sub(this, y) 19 | def *(y: VectorLike): LazyVector = emul(this, y) 20 | def /(y: VectorLike): LazyVector = div(this, y) 21 | 22 | def +(y: Scalar): LazyVector = new LazyDenseVSOp(this, y, _ + _) 23 | def -(y: Scalar): LazyVector = new LazyDenseVSOp(this, y, _ - _) 24 | def *(y: Scalar): LazyVector = LazyVectorScaleOp(y, this) 25 | def /(y: Scalar): LazyVector = new LazyDenseVSOp(this, y, _ / _) 26 | } 27 | 28 | /** Dense and Sparse Vectors can be mutated. Lazy vectors are immutable. 
*/ 29 | sealed trait MutableVector extends VectorLike 30 | 31 | /** 32 | * A dense vector represented by a value array. 33 | */ 34 | class DenseVectorWrapper(override val values: Array[Double]) 35 | extends DenseVector(values) with MutableVector { 36 | 37 | override def foreachActive(f: (Int, Double) => Unit) = { 38 | var i = 0 39 | val localValuesSize = values.length 40 | val localValues = values 41 | 42 | while (i < localValuesSize) { 43 | f(i, localValues(i)) 44 | i += 1 45 | } 46 | } 47 | 48 | def :=(x: LazyVector): this.type = x.compute(Option(this.values)).asInstanceOf[this.type] 49 | 50 | def +=(y: VectorLike): this.type = { 51 | y match { 52 | case dd: LazyMM_MV_MultOp => 53 | new LazyMM_MV_MultOp(dd.left, dd.right, Option(this.values)).compute() 54 | case dl: LazyMM_LV_MultOp => 55 | new LazyMM_LV_MultOp(dl.left, dl.right, Option(this.values)).compute() 56 | case ld: LazyLM_MV_MultOp => 57 | new LazyLM_MV_MultOp(ld.left, ld.right, Option(this.values)).compute() 58 | case ll: LazyLM_LV_MultOp => 59 | new LazyLM_LV_MultOp(ll.left, ll.right, Option(this.values)).compute() 60 | case axpy: LazyVectorScaleOp => 61 | new LazyVectorAxpyOp(axpy.left, axpy.right, Option(this.values)).compute() 62 | case mv: MutableVector => new LazyDenseVVOp(this, mv, _ + _).compute(Option(this.values)) 63 | case lzy: LazyVector => new LazyDenseVVOp(this, lzy, _ + _).compute(Option(this.values)) 64 | case _ => throw new UnsupportedOperationException 65 | } 66 | this 67 | } 68 | } 69 | 70 | /** 71 | * A sparse vector represented by an index array and an value array. 72 | * 73 | * @param size size of the vector. 74 | * @param indices index array, assume to be strictly increasing. 75 | * @param values value array, must have the same length as the index array. 
76 | */ 77 | class SparseVectorWrapper( 78 | override val size: Int, 79 | override val indices: Array[Int], 80 | override val values: Array[Double]) 81 | extends SparseVector(size, indices, values) with MutableVector { 82 | 83 | override def foreachActive(f: (Int, Double) => Unit) = { 84 | var i = 0 85 | val localValuesSize = values.length 86 | val localIndices = indices 87 | val localValues = values 88 | 89 | while (i < localValuesSize) { 90 | f(localIndices(i), localValues(i)) 91 | i += 1 92 | } 93 | } 94 | 95 | override def apply(i: Int): Double = { 96 | val index = ju.Arrays.binarySearch(indices, i) 97 | if (index < 0) 0.0 else values(index) 98 | } 99 | } 100 | 101 | 102 | sealed trait LazyVector extends VectorLike { 103 | def compute(into: Option[Array[Double]] = None): VectorLike = { 104 | val values = into.getOrElse(new Array[Double](size)) 105 | require(values.length == size, 106 | s"Size of buffer not equal to size of vector. Buffer size: ${values.length} vs. $size") 107 | var i = 0 108 | while (i < size) { 109 | values(i) = this(i) 110 | i += 1 111 | } 112 | new DenseVectorWrapper(values) 113 | } 114 | } 115 | 116 | private[linalg] class LazyDenseVVOp( 117 | left: VectorLike, 118 | right: VectorLike, 119 | operation: (Double, Double) => Double) extends LazyVector { 120 | require(left.size == right.size, 121 | s"Sizes of vectors don't match. left: ${left.size} vs. 
right: ${right.size}") 122 | override def size = left.size 123 | override def apply(i: Int): Double = operation(left(i), right(i)) 124 | } 125 | 126 | private[linalg] class LazyDenseVSOp( 127 | left: VectorLike, 128 | right: Scalar, 129 | operation: (Double, Double) => Double) extends LazyVector { 130 | override def size = left.size 131 | override def apply(i: Int): Double = operation(left(i), right.value) 132 | } 133 | 134 | private[linalg] class LazyDenseSVOp( 135 | left: Scalar, 136 | right: VectorLike, 137 | operation: (Double, Double) => Double) extends LazyVector { 138 | override def size = right.size 139 | override def apply(i: Int): Double = operation(left.value, right(i)) 140 | } 141 | 142 | private[linalg] class LazySparseVVOp( 143 | left: SparseVectorWrapper, 144 | right: SparseVectorWrapper, 145 | operation: (Double, Double) => Double) extends LazyVector { 146 | require(left.size == right.size, 147 | s"Sizes of vectors don't match. left: ${left.size} vs. right: ${right.size}") 148 | override def size = left.size 149 | override def apply(i: Int): Double = operation(left(i), right(i)) 150 | 151 | private case class IndexMatcher(index: Int, fromLeft: Boolean) 152 | 153 | override def compute(into: Option[Array[Double]] = None): VectorLike = { 154 | val leftIndices = left.indices 155 | val rightIndices = right.indices 156 | val nonZeroIndices = (leftIndices.toSet ++ rightIndices).toArray.sorted 157 | val numNonZeros = nonZeroIndices.length 158 | val values = into.getOrElse(new Array[Double](numNonZeros)) 159 | require(values.length == numNonZeros, "Size of buffer not equal to number of non-zeros. " + 160 | s"Buffer size: ${values.length} vs. 
$numNonZeros") 161 | var x = 0 162 | var y = 0 163 | var z = 0 164 | val leftValues = left.values 165 | val rightValues = right.values 166 | while (z < numNonZeros) { 167 | val effLeftIndex = if (x == leftIndices.length) size else leftIndices(x) 168 | val effRightIndex = if (x == rightIndices.length) size else rightIndices(x) 169 | if (effLeftIndex == effRightIndex) { 170 | values(z) = operation(leftValues(x), rightValues(y)) 171 | x += 1 172 | y += 1 173 | } else if (effLeftIndex < effRightIndex) { 174 | values(z) = operation(leftValues(x), 0.0) 175 | x += 1 176 | } else { 177 | values(z) = operation(0.0, rightValues(y)) 178 | y += 1 179 | } 180 | z += 1 181 | } 182 | new SparseVectorWrapper(size, nonZeroIndices, values) 183 | } 184 | } 185 | 186 | private[linalg] class LazyVectorMapOp( 187 | parent: VectorLike, 188 | operation: Double => Double) extends LazyVector { 189 | override def size = parent.size 190 | override def apply(i: Int): Double = operation(parent(i)) 191 | 192 | override def compute(into: Option[Array[Double]] = None): VectorLike = { 193 | parent match { 194 | case sp: SparseVectorWrapper => 195 | val indices = sp.indices 196 | val nnz = indices.length 197 | val values = into.getOrElse(new Array[Double](nnz)) 198 | var i = 0 199 | if (values.length == nnz) { 200 | while (i < nnz) { 201 | values(i) = this(indices(i)) 202 | i += 1 203 | } 204 | new SparseVectorWrapper(size, sp.indices, values) 205 | } else if (values.length == size) { 206 | var i = 0 207 | while (i < size) { 208 | values(i) = this(i) 209 | i += 1 210 | } 211 | new DenseVectorWrapper(values) 212 | } else { 213 | throw new IllegalArgumentException("Size of buffer not equal to size of vector. " + 214 | s"Buffer size: ${values.length} vs. $nnz") 215 | } 216 | case _ => 217 | val values = into.getOrElse(new Array[Double](size)) 218 | require(values.length == size, 219 | s"Size of buffer not equal to size of vector. Buffer size: ${values.length} vs. 
/** Base class for lazily evaluated matrix-vector (and scaled-vector) operations.
 *
 * `into` optionally carries a pre-allocated output buffer so that results can be written
 * in place instead of allocating a fresh array on `compute`.
 */
private[linalg] abstract class LazyMVMultOp(
    left: MatrixLike,
    right: VectorLike,
    into: Option[Array[Double]] = None) extends LazyVector

/** Lazily evaluated `alpha * x`, accumulated via BLAS axpy into an optional buffer. */
private[linalg] case class LazyVectorAxpyOp(
    left: Scalar,
    right: VectorLike,
    into: Option[Array[Double]]) extends LazyMVMultOp(left, right, into) {
  override def size: Int = right.size

  // Mutable so a buffer supplied later through `compute` can be used by `result`.
  private var buffer = into

  override def apply(i: Int): Double = result(i)

  lazy val result: VectorLike = {
    val scale = left.value
    if (scale == 1.0) {
      // Scaling by one is the identity: just force the operand if it is itself lazy.
      right match {
        case lzy: LazyVector => lzy.compute(buffer)
        case _ => right
      }
    } else {
      val values = buffer.getOrElse(new Array[Double](size))
      // Consistency/robustness fix: sibling ops validate caller-supplied buffers; do the
      // same here instead of failing obscurely inside BLAS with a mismatched array.
      require(values.length == size,
        s"Size of buffer not equal to size of vector. Buffer size: ${values.length} vs. $size")
      val inside = new DenseVectorWrapper(values)
      BLASUtils.axpy(scale, right, inside)
      inside
    }
  }

  /** Forces evaluation, optionally redirecting output into `into`. */
  override def compute(into: Option[Array[Double]] = None): VectorLike = {
    into.foreach(b => buffer = Option(b))
    result
  }
}

/** Lazily evaluated `alpha * x` that preserves sparsity when scaling into a matching buffer. */
private[linalg] case class LazyVectorScaleOp(
    left: Scalar,
    right: VectorLike,
    into: Option[Array[Double]] = None) extends LazyMVMultOp(left, right, into) {
  override def size: Int = right.size

  private var buffer = into

  override def apply(i: Int): Double = result(i)

  lazy val result: VectorLike = {
    val scale = left.value
    if (scale == 1.0) {
      right match {
        case lzy: LazyVector => lzy.compute(buffer)
        case _ => right
      }
    } else {
      right match {
        case dn: DenseVectorWrapper =>
          buffer match {
            case Some(values) =>
              require(values.length == size,
                "Size of buffer not equal to size of vector. " +
                  s"Buffer size: ${values.length} vs. $size")
              dn.foreachActive { case (i, v) =>
                values(i) = scale * v
              }
              new DenseVectorWrapper(values)
            case None =>
              // axpy into a fresh zero array is equivalent to a scaled copy.
              val inside = new DenseVectorWrapper(new Array[Double](size))
              BLASUtils.axpy(scale, dn, inside)
              inside
          }
        case sp: SparseVectorWrapper =>
          buffer match {
            case Some(values) =>
              if (values.length == size) {
                // Dense-sized buffer: scatter the scaled non-zeros densely.
                sp.foreachActive { case (i, v) =>
                  values(i) = scale * v
                }
                new DenseVectorWrapper(values)
              } else if (values.length == sp.values.length) {
                // nnz-sized buffer: stay sparse and reuse the original index array.
                var i = 0
                val length = sp.values.length
                val vals = sp.values
                while (i < length) {
                  values(i) = scale * vals(i)
                  i += 1
                }
                new SparseVectorWrapper(size, sp.indices, values)
              } else {
                throw new IllegalArgumentException("The sizes of the vectors don't match for " +
                  s"scaling into. ${values.length} vs nnz: ${sp.values.length} and size: $size")
              }
            case None =>
              val inside = new DenseVectorWrapper(new Array[Double](size))
              BLASUtils.axpy(scale, sp, inside)
              inside
          }
        case lzy: LazyVector =>
          // Force the lazy operand first, then scale in place.
          val inside = lzy.compute(buffer)
          BLASUtils.scal(scale, inside)
          inside
      }
    }
  }

  override def compute(into: Option[Array[Double]] = None): VectorLike = {
    into.foreach(b => buffer = Option(b))
    result
  }
}

/** Lazy (LazyMatrix) * (MutableVector) product, evaluated with BLAS gemv. */
private[linalg] case class LazyLM_MV_MultOp(
    left: LazyMatrix,
    right: MutableVector,
    into: Option[Array[Double]] = None) extends LazyMVMultOp(left, right, into) {
  override def size: Int = left.numRows

  override def apply(i: Int): Double = result(i)

  private var buffer = into

  lazy val result: VectorLike = {
    val inside = new DenseVectorWrapper(buffer.getOrElse(new Array[Double](size)))
    require(inside.size == size,
      s"Size of buffer not equal to size of vector. Buffer size: ${inside.size} vs. $size")
    BLASUtils.gemv(1.0, left.compute(), right, 1.0, inside)
    inside
  }

  override def compute(into: Option[Array[Double]] = None): VectorLike = {
    into.foreach(b => buffer = Option(b))
    result
  }
}

/** Lazy (LazyMatrix) * (LazyVector) product; peels off a pending scale to fold it into gemv's alpha. */
private[linalg] case class LazyLM_LV_MultOp(
    left: LazyMatrix,
    right: LazyVector,
    into: Option[Array[Double]] = None) extends LazyMVMultOp(left, right, into) {
  override def size: Int = left.numRows

  override def apply(i: Int): Double = result(i)

  private var buffer = into

  lazy val result: VectorLike = {
    var rightScale = 1.0
    // A * (c * x) == c * (A * x): pass c as gemv's alpha instead of materializing c * x.
    val effRight: VectorLike = right match {
      case scaled: LazyVectorScaleOp =>
        rightScale = scaled.left.value
        scaled.right
      case _ => right.compute()
    }
    val inside = new DenseVectorWrapper(buffer.getOrElse(new Array[Double](size)))
    require(inside.size == size,
      s"Size of buffer not equal to size of vector. Buffer size: ${inside.size} vs. $size")
    BLASUtils.gemv(rightScale, left.compute(), effRight, 1.0, inside)
    inside
  }

  override def compute(into: Option[Array[Double]] = None): VectorLike = {
    into.foreach(b => buffer = Option(b))
    result
  }
}

/** (MutableMatrix) * (LazyVector) product; peels off a pending scale to fold it into gemv's alpha. */
private[linalg] case class LazyMM_LV_MultOp(
    left: MutableMatrix,
    right: LazyVector,
    into: Option[Array[Double]] = None) extends LazyMVMultOp(left, right, into) {
  override def size: Int = left.numRows

  override def apply(i: Int): Double = result(i)

  private var buffer = into

  lazy val result: VectorLike = {
    var rightScale = 1.0
    val effRight: VectorLike = right match {
      case scaled: LazyVectorScaleOp =>
        rightScale = scaled.left.value
        scaled.right
      case _ => right.compute()
    }
    val inside = new DenseVectorWrapper(buffer.getOrElse(new Array[Double](size)))
    require(inside.size == size,
      s"Size of buffer not equal to size of vector. Buffer size: ${inside.size} vs. $size")
    BLASUtils.gemv(rightScale, left, effRight, 1.0, inside)
    inside
  }

  override def compute(into: Option[Array[Double]] = None): VectorLike = {
    into.foreach(b => buffer = Option(b))
    result
  }
}

/** (MutableMatrix) * (MutableVector) product, evaluated with BLAS gemv. */
private[linalg] case class LazyMM_MV_MultOp(
    left: MutableMatrix,
    right: MutableVector,
    into: Option[Array[Double]] = None) extends LazyMVMultOp(left, right, into) {
  override def size: Int = left.numRows

  override def apply(i: Int): Double = result(i)

  private var buffer = into

  lazy val result: VectorLike = {
    // Fix: use DenseVectorWrapper like every sibling op (the original constructed a raw
    // mllib DenseVector here), so the result type and buffer handling are consistent.
    val inside = new DenseVectorWrapper(buffer.getOrElse(new Array[Double](size)))
    require(inside.size == size,
      s"Size of buffer not equal to size of vector. Buffer size: ${inside.size} vs. $size")
    BLASUtils.gemv(1.0, left, right, 1.0, inside)
    inside
  }

  override def compute(into: Option[Array[Double]] = None): VectorLike = {
    into.foreach(b => buffer = Option(b))
    result
  }
}
31 | */ 32 | def toRadians(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.toRadians) 33 | 34 | /** Converts an angle measured in radians to an approximately equivalent 35 | * angle measured in degrees. 36 | * 37 | * @param x angle, in radians 38 | * @return the measurement of the angle `x` in degrees. 39 | */ 40 | def toDegrees(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.toDegrees) 41 | 42 | /** Returns Euler's number `e` raised to the power of a `double` value. 43 | * 44 | * @param x the exponent to raise `e` to. 45 | * @return the value `e^a^`, where `e` is the base of the natural 46 | * logarithms. 47 | */ 48 | def exp(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.exp) 49 | def log(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.log) 50 | def sqrt(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.sqrt) 51 | 52 | def ceil(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.ceil) 53 | def floor(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.floor) 54 | 55 | /** Returns the `double` value that is closest in value to the 56 | * argument and is equal to a mathematical integer. 57 | * 58 | * @param x a `double` value 59 | * @return the closest floating-point value to a that is equal to a 60 | * mathematical integer. 61 | */ 62 | def rint(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.rint) 63 | 64 | /** Converts rectangular coordinates `(x, y)` to polar `(r, theta)`. 65 | * 66 | * @param x the ordinate coordinate 67 | * @param y the abscissa coordinate 68 | * @return the ''theta'' component of the point `(r, theta)` in polar 69 | * coordinates that corresponds to the point `(x, y)` in 70 | * Cartesian coordinates. 
71 | */ 72 | def atan2(y: MatrixLike, x: MatrixLike): MatrixLike = 73 | new LazyImDenseMMOp(y, x, java.lang.Math.atan2) 74 | 75 | /** Returns the value of the first argument raised to the power of the 76 | * second argument. 77 | * 78 | * @param x the base. 79 | * @param y the exponent. 80 | * @return the value `x^y^`. 81 | */ 82 | def pow(x: MatrixLike, y: MatrixLike): LazyMatrix = new LazyImDenseMMOp(x, y, java.lang.Math.pow) 83 | 84 | def abs(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.abs) 85 | 86 | def max(x: MatrixLike, y: MatrixLike): LazyMatrix = new LazyImDenseMMOp(x, y, java.lang.Math.max) 87 | 88 | def min(x: MatrixLike, y: MatrixLike): LazyMatrix = new LazyImDenseMMOp(x, y, java.lang.Math.min) 89 | 90 | def signum(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.signum) 91 | 92 | // ----------------------------------------------------------------------- 93 | // root functions 94 | // ----------------------------------------------------------------------- 95 | 96 | /** Returns the cube root of the given `MatrixLike` value. */ 97 | def cbrt(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.cbrt) 98 | 99 | // ----------------------------------------------------------------------- 100 | // exponential functions 101 | // ----------------------------------------------------------------------- 102 | 103 | /** Returns `exp(x) - 1`. */ 104 | def expm1(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.expm1) 105 | 106 | // ----------------------------------------------------------------------- 107 | // logarithmic functions 108 | // ----------------------------------------------------------------------- 109 | 110 | /** Returns the natural logarithm of the sum of the given `MatrixLike` value and 1. */ 111 | def log1p(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.log1p) 112 | 113 | /** Returns the base 10 logarithm of the given `MatrixLike` value. 
*/ 114 | def log10(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.log10) 115 | 116 | // ----------------------------------------------------------------------- 117 | // trigonometric functions 118 | // ----------------------------------------------------------------------- 119 | 120 | /** Returns the hyperbolic sine of the given `MatrixLike` value. */ 121 | def sinh(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.sinh) 122 | 123 | /** Returns the hyperbolic cosine of the given `MatrixLike` value. */ 124 | def cosh(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.cosh) 125 | 126 | /** Returns the hyperbolic tangent of the given `MatrixLike` value. */ 127 | def tanh(x: MatrixLike): LazyMatrix = new LazyMatrixMapOp(x, java.lang.Math.tanh) 128 | 129 | // ----------------------------------------------------------------------- 130 | // miscellaneous functions 131 | // ----------------------------------------------------------------------- 132 | 133 | /** Returns the square root of the sum of the squares of both given `MatrixLike` 134 | * values without intermediate underflow or overflow. 
135 | */ 136 | def hypot(x: MatrixLike, y: MatrixLike): LazyMatrix = 137 | new LazyImDenseMMOp(x, y, java.lang.Math.hypot) 138 | 139 | 140 | ////////////////////////////////////////////////// 141 | // Vector Functions 142 | ////////////////////////////////////////////////// 143 | 144 | def add(x: VectorLike, y: VectorLike): LazyVector = { 145 | (x, y) match { 146 | case (a: SparseVectorWrapper, b: SparseVectorWrapper) => new LazySparseVVOp(a, b, _ + _) 147 | case _ => new LazyDenseVVOp(x, y, _ + _) 148 | } 149 | } 150 | def sub(x: VectorLike, y: VectorLike): LazyVector = { 151 | (x, y) match { 152 | case (a: SparseVectorWrapper, b: SparseVectorWrapper) => new LazySparseVVOp(a, b, _ - _) 153 | case _ => new LazyDenseVVOp(x, y, _ - _) 154 | } 155 | } 156 | def emul(x: VectorLike, y: VectorLike): LazyVector = { 157 | (x, y) match { 158 | case (a: SparseVectorWrapper, b: SparseVectorWrapper) => new LazySparseVVOp(a, b, _ * _) 159 | case _ => new LazyDenseVVOp(x, y, _ * _) 160 | } 161 | } 162 | def div(x: VectorLike, y: VectorLike): LazyVector = { 163 | (x, y) match { 164 | case (a: SparseVectorWrapper, b: SparseVectorWrapper) => new LazySparseVVOp(a, b, _ / _) 165 | case _ => new LazyDenseVVOp(x, y, _ / _) 166 | } 167 | } 168 | 169 | def apply(x: VectorLike, y: VectorLike, f: (Double, Double) => Double): LazyVector = 170 | new LazyDenseVVOp(x, y, f) 171 | 172 | def apply(x: VectorLike, y: Scalar, f: (Double, Double) => Double): LazyVector = 173 | new LazyDenseVSOp(x, y, f) 174 | 175 | def apply(x: Scalar, y: VectorLike, f: (Double, Double) => Double): LazyVector = 176 | new LazyDenseSVOp(x, y, f) 177 | 178 | def apply(x: VectorLike, f: (Double) => Double): LazyVector = new LazyVectorMapOp(x, f) 179 | 180 | def sin(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.sin) 181 | def cos(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.cos) 182 | def tan(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.tan) 183 | def 
asin(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.asin) 184 | def acos(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.acos) 185 | def atan(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.atan) 186 | 187 | /** Converts an angle measured in degrees to an approximately equivalent 188 | * angle measured in radians. 189 | * 190 | * @param x an angle, in degrees 191 | * @return the measurement of the angle `x` in radians. 192 | */ 193 | def toRadians(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.toRadians) 194 | 195 | /** Converts an angle measured in radians to an approximately equivalent 196 | * angle measured in degrees. 197 | * 198 | * @param x angle, in radians 199 | * @return the measurement of the angle `x` in degrees. 200 | */ 201 | def toDegrees(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.toDegrees) 202 | 203 | /** Returns Euler's number `e` raised to the power of a `double` value. 204 | * 205 | * @param x the exponent to raise `e` to. 206 | * @return the value `e^a^`, where `e` is the base of the natural 207 | * logarithms. 208 | */ 209 | def exp(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.exp) 210 | def log(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.log) 211 | def sqrt(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.sqrt) 212 | 213 | def ceil(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.ceil) 214 | def floor(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.floor) 215 | 216 | /** Returns the `double` value that is closest in value to the 217 | * argument and is equal to a mathematical integer. 218 | * 219 | * @param x a `double` value 220 | * @return the closest floating-point value to a that is equal to a 221 | * mathematical integer. 
222 | */ 223 | def rint(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.rint) 224 | 225 | /** Converts rectangular coordinates `(x, y)` to polar `(r, theta)`. 226 | * 227 | * @param x the ordinate coordinate 228 | * @param y the abscissa coordinate 229 | * @return the ''theta'' component of the point `(r, theta)` in polar 230 | * coordinates that corresponds to the point `(x, y)` in 231 | * Cartesian coordinates. 232 | */ 233 | def atan2(y: VectorLike, x: VectorLike): LazyVector = 234 | new LazyDenseVVOp(y, x, java.lang.Math.atan2) 235 | 236 | /** Returns the value of the first argument raised to the power of the 237 | * second argument. 238 | * 239 | * @param x the base. 240 | * @param y the exponent. 241 | * @return the value `x^y^`. 242 | */ 243 | def pow(x: VectorLike, y: VectorLike): LazyVector = new LazyDenseVVOp(x, y, java.lang.Math.pow) 244 | 245 | def abs(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.abs) 246 | 247 | def max(x: VectorLike, y: VectorLike): LazyVector = new LazyDenseVVOp(x, y, java.lang.Math.max) 248 | 249 | def min(x: VectorLike, y: VectorLike): LazyVector = new LazyDenseVVOp(x, y, java.lang.Math.min) 250 | 251 | def signum(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.signum) 252 | 253 | // ----------------------------------------------------------------------- 254 | // root functions 255 | // ----------------------------------------------------------------------- 256 | 257 | /** Returns the cube root of the given `VectorLike` value. */ 258 | def cbrt(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.cbrt) 259 | 260 | // ----------------------------------------------------------------------- 261 | // exponential functions 262 | // ----------------------------------------------------------------------- 263 | 264 | /** Returns `exp(x) - 1`. 
*/ 265 | def expm1(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.expm1) 266 | 267 | // ----------------------------------------------------------------------- 268 | // logarithmic functions 269 | // ----------------------------------------------------------------------- 270 | 271 | /** Returns the natural logarithm of the sum of the given `VectorLike` value and 1. */ 272 | def log1p(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.log1p) 273 | 274 | /** Returns the base 10 logarithm of the given `VectorLike` value. */ 275 | def log10(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.log10) 276 | 277 | // ----------------------------------------------------------------------- 278 | // trigonometric functions 279 | // ----------------------------------------------------------------------- 280 | 281 | /** Returns the hyperbolic sine of the given `VectorLike` value. */ 282 | def sinh(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.sinh) 283 | 284 | /** Returns the hyperbolic cosine of the given `VectorLike` value. */ 285 | def cosh(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.cosh) 286 | 287 | /** Returns the hyperbolic tangent of the given `VectorLike` value. */ 288 | def tanh(x: VectorLike): LazyVector = new LazyVectorMapOp(x, java.lang.Math.tanh) 289 | 290 | // ----------------------------------------------------------------------- 291 | // miscellaneous functions 292 | // ----------------------------------------------------------------------- 293 | 294 | /** Returns the square root of the sum of the squares of both given `VectorLike` 295 | * values without intermediate underflow or overflow. 
296 | */ 297 | def hypot(x: VectorLike, y: VectorLike): LazyVector = 298 | new LazyDenseVVOp(x, y, java.lang.Math.hypot) 299 | 300 | } 301 | -------------------------------------------------------------------------------- /src/main/scala/com/brkyvz/spark/linalg/package.scala: -------------------------------------------------------------------------------- 1 | package com.brkyvz.spark 2 | 3 | import scala.language.implicitConversions 4 | 5 | import org.apache.spark.mllib.linalg._ 6 | 7 | package object linalg { 8 | 9 | implicit def wrapDenseMatrix(x: DenseMatrix): DenseMatrixWrapper = DenseMatrixWrapper(x) 10 | implicit def wrapSparseMatrix(x: SparseMatrix): SparseMatrixWrapper = SparseMatrixWrapper(x) 11 | 12 | implicit def wrapMatrix(x: Matrix): MatrixLike = x match { 13 | case dn: DenseMatrix => DenseMatrixWrapper(dn) 14 | case sp: SparseMatrix => SparseMatrixWrapper(sp) 15 | } 16 | 17 | implicit def wrapDenseVector(x: DenseVector): DenseVectorWrapper = 18 | new DenseVectorWrapper(x.values) 19 | 20 | implicit def wrapSparseVector(x: SparseVector): SparseVectorWrapper = 21 | new SparseVectorWrapper(x.size, x.indices, x.values) 22 | 23 | implicit def wrapVector(x: Vector): VectorLike = x match { 24 | case dn: DenseVector => new DenseVectorWrapper(dn.values) 25 | case sp: SparseVector => new SparseVectorWrapper(sp.size, sp.indices, sp.values) 26 | } 27 | 28 | trait Scalar extends MatrixLike { 29 | val value: Double 30 | 31 | def *(y: MatrixLike): LazyMatrix = LazyImDenseScaleOp(this, y) 32 | def *(y: VectorLike): LazyVector = LazyVectorScaleOp(this, y) 33 | } 34 | 35 | implicit def double2Scalar(x: Double): Scalar = new Scalar { 36 | override val value: Double = x 37 | def apply(i: Int): Double = value 38 | override def numRows = 1 39 | override def numCols = 1 40 | } 41 | 42 | implicit def int2Scalar(x: Int): Scalar = new Scalar { 43 | override val value: Double = x.toDouble 44 | def apply(i: Int): Double = value 45 | override def numRows = 1 46 | override def 
/** Unit tests for the BLAS wrappers in [[BLASUtils]]: scal, axpy, dot, syr, gemm and gemv,
 * covering dense/sparse operand combinations and argument validation.
 */
class BLASUtilsSuite extends FunSuite {

  test("scal") {
    val a = 0.1
    val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0))
    val dx = Vectors.dense(1.0, 0.0, -2.0)

    scal(a, sx)
    assert(sx ~== Vectors.sparse(3, Array(0, 2), Array(0.1, -0.2)) absTol 1e-15)

    scal(a, dx)
    assert(dx ~== Vectors.dense(0.1, 0.0, -0.2) absTol 1e-15)
  }

  test("axpy") {
    val alpha = 0.1
    val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0))
    val dx = Vectors.dense(1.0, 0.0, -2.0)
    val dy = Array(2.0, 1.0, 0.0)
    val expected = Vectors.dense(2.1, 1.0, -0.2)

    val dy1 = Vectors.dense(dy.clone())
    axpy(alpha, sx, dy1)
    assert(dy1 ~== expected absTol 1e-15)

    val dy2 = Vectors.dense(dy.clone())
    axpy(alpha, dx, dy2)
    assert(dy2 ~== expected absTol 1e-15)

    // axpy cannot accumulate into a sparse destination.
    val sy = Vectors.sparse(4, Array(0, 1), Array(2.0, 1.0))

    intercept[IllegalArgumentException] {
      axpy(alpha, sx, sy)
    }

    intercept[IllegalArgumentException] {
      axpy(alpha, dx, sy)
    }

    withClue("vector sizes must match") {
      intercept[Exception] {
        axpy(alpha, sx, Vectors.dense(1.0, 2.0))
      }
    }
  }

  test("dot") {
    val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0))
    val dx = Vectors.dense(1.0, 0.0, -2.0)
    val sy = Vectors.sparse(3, Array(0, 1), Array(2.0, 1.0))
    val dy = Vectors.dense(2.0, 1.0, 0.0)

    // dot must be symmetric across all dense/sparse pairings.
    assert(dot(sx, sy) ~== 2.0 absTol 1e-15)
    assert(dot(sy, sx) ~== 2.0 absTol 1e-15)
    assert(dot(sx, dy) ~== 2.0 absTol 1e-15)
    assert(dot(dy, sx) ~== 2.0 absTol 1e-15)
    assert(dot(dx, dy) ~== 2.0 absTol 1e-15)
    assert(dot(dy, dx) ~== 2.0 absTol 1e-15)

    assert(dot(sx, sx) ~== 5.0 absTol 1e-15)
    assert(dot(dx, dx) ~== 5.0 absTol 1e-15)
    assert(dot(sx, dx) ~== 5.0 absTol 1e-15)
    assert(dot(dx, sx) ~== 5.0 absTol 1e-15)

    // Sparse-sparse with only partially overlapping indices.
    val sx1 = Vectors.sparse(10, Array(0, 3, 5, 7, 8), Array(1.0, 2.0, 3.0, 4.0, 5.0))
    val sx2 = Vectors.sparse(10, Array(1, 3, 6, 7, 9), Array(1.0, 2.0, 3.0, 4.0, 5.0))
    assert(dot(sx1, sx2) ~== 20.0 absTol 1e-15)
    assert(dot(sx2, sx1) ~== 20.0 absTol 1e-15)

    withClue("vector sizes must match") {
      intercept[Exception] {
        dot(sx, Vectors.dense(2.0, 1.0))
      }
    }
  }

  test("syr") {
    val dA = new DenseMatrix(4, 4,
      Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8))
    val x = new DenseVector(Array(0.0, 2.7, 3.5, 2.1))
    val alpha = 0.15

    val expected = new DenseMatrix(4, 4,
      Array(0.0, 1.2, 2.2, 3.1, 1.2, 4.2935, 6.7175, 5.4505, 2.2, 6.7175, 3.6375, 4.1025, 3.1,
        5.4505, 4.1025, 1.4615))

    syr(alpha, x, dA)

    assert(dA ~== expected absTol 1e-15)

    val dB =
      new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0))

    withClue("Matrix A must be a symmetric Matrix") {
      intercept[Exception] {
        syr(alpha, x, dB)
      }
    }

    val dC =
      new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8))

    withClue("Size of vector must match the rank of matrix") {
      intercept[Exception] {
        syr(alpha, x, dC)
      }
    }

    val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5))

    withClue("Size of vector must match the rank of matrix") {
      intercept[Exception] {
        syr(alpha, y, dA)
      }
    }

    val xSparse = new SparseVector(4, Array(0, 2, 3), Array(1.0, 3.0, 4.0))
    val dD = new DenseMatrix(4, 4,
      Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8))
    syr(0.1, xSparse, dD)
    val expectedSparse = new DenseMatrix(4, 4,
      Array(0.1, 1.2, 2.5, 3.5, 1.2, 3.2, 5.3, 4.6, 2.5, 5.3, 2.7, 4.2, 3.5, 4.6, 4.2, 2.4))
    assert(dD ~== expectedSparse absTol 1e-15)
  }

  test("gemm") {
    val dA =
      new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
    val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))

    val dB = new DenseMatrix(3, 2, Array(1.0, 0.0, 0.0, 0.0, 2.0, 1.0))
    val sB = new SparseMatrix(3, 2, Array(0, 1, 3), Array(0, 1, 2), Array(1.0, 2.0, 1.0))
    val expected = new DenseMatrix(4, 2, Array(0.0, 1.0, 0.0, 0.0, 4.0, 0.0, 2.0, 3.0))
    // Manually transposed B, used below to exercise the transposed-operand code paths.
    // (Fix: removed the unused `sBTman` local; `dBTman.toSparse` encodes the same matrix.)
    val dBTman = new DenseMatrix(2, 3, Array(1.0, 0.0, 0.0, 2.0, 0.0, 1.0))

    assert(dA.multiply(dB) ~== expected absTol 1e-15)
    assert(sA.multiply(dB) ~== expected absTol 1e-15)

    val C1 = new DenseMatrix(4, 2, Array(1.0, 0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0))
    val C2 = C1.copy
    val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
    val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
    val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0))
    val expected5 = new DenseMatrix(4, 2, Array(1.0, 0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0))

    gemm(1.0, dA, dB, 0.0, C2)
    assert(C2 ~== expected absTol 1e-15)
    gemm(1.0, sA, dB, 0.0, C2)
    assert(C2 ~== expected absTol 1e-15)

    withClue("columns of A don't match the rows of B") {
      intercept[Exception] {
        gemm(1.0, dA.transpose, dB, 2.0, C1)
      }
    }

    val dATman =
      new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
    val sATman =
      new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))

    // Double-transposed operands should behave identically to the originals.
    val dATT = dATman.transpose
    val sATT = sATman.transpose
    val BTT = dBTman.transpose
    val sBTT = dBTman.toSparse.transpose

    val combinations = Seq((1.0, 0.0, expected), (1.0, 2.0, expected2), (2.0, 2.0, expected3),
      (0.0, 5.0, expected4), (0.0, 1.0, expected5))

    combinations.foreach { case (alpha, beta, expectation) =>
      def checkResult(a: MatrixLike, b: MatrixLike): Unit = {
        val Cres = C1.copy
        gemm(alpha, a, b, beta, Cres)
        assert(Cres ~== expectation absTol 1e-15)
      }
      checkResult(dA, dB)
      checkResult(dA, sB)
      checkResult(dA, BTT)
      checkResult(dA, sBTT)
      checkResult(sA, dB)
      checkResult(sA, sB)
      checkResult(sA, BTT)
      checkResult(sA, sBTT)
      checkResult(dATT, dB)
      checkResult(dATT, BTT)
      checkResult(dATT, sB)
      checkResult(dATT, sBTT)
      checkResult(sATT, dB)
      checkResult(sATT, BTT)
      checkResult(sATT, sB)
      checkResult(sATT, sBTT)
    }
  }

  test("gemv") {
    val dA =
      new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
    val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))

    val dx = new DenseVector(Array(1.0, 2.0, 3.0))
    val sx = dx.toSparse
    val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0))

    assert(dA.multiply(dx) ~== expected absTol 1e-15)
    assert(sA.multiply(dx) ~== expected absTol 1e-15)
    assert(dA.multiply(sx) ~== expected absTol 1e-15)
    assert(sA.multiply(sx) ~== expected absTol 1e-15)

    val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))

    val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0))
    val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0))

    withClue("columns of A don't match the rows of B") {
      intercept[Exception] {
        gemv(1.0, dA.transpose, dx, 2.0, y1)
      }
      intercept[Exception] {
        gemv(1.0, sA.transpose, dx, 2.0, y1)
      }
      intercept[Exception] {
        gemv(1.0, dA.transpose, sx, 2.0, y1)
      }
      intercept[Exception] {
        gemv(1.0, sA.transpose, sx, 2.0, y1)
      }
    }

    val dAT =
      new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
    val sAT =
      new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))

    val dATT = dAT.transpose
    val sATT = sAT.transpose

    val combinations = Seq((1.0, 0.0, expected), (1.0, 2.0, expected2), (2.0, 2.0, expected3))

    combinations.foreach { case (alpha, beta, expectation) =>
      def checkResult(a: MatrixLike, b: VectorLike): Unit = {
        val Yres = y1.copy
        gemv(alpha, a, b, beta, Yres)
        assert(Yres ~== expectation absTol 1e-15)
      }
      checkResult(dA, dx)
      checkResult(dA, sx)
      checkResult(sA, dx)
      checkResult(sA, sx)
      checkResult(dATT, dx)
      checkResult(dATT, sx)
      checkResult(sATT, dx)
      checkResult(sATT, sx)
    }
  }
}
/** Tests for the lazy matrix DSL: element-wise arithmetic, buffer validation, scalar ops,
 * `funcs` helpers, BLAS-backed multiplication, and use of lazy matrices inside Spark RDDs.
 */
class MatricesSuite extends FunSuite with PerTestSparkContext {

  private val a = Matrices.dense(2, 2, Array(1, 2, 3, 4))
  private val b = new DenseMatrix(2, 2, Array(0, -2, 0, -2))
  private val c = Matrices.sparse(2, 2, Array(0, 1, 1), Array(0), Array(1.0))
  private val x = Matrices.sparse(3, 2, Array(0, 1, 2), Array(0, 2), Array(0.5, 2.0))

  test("basic arithmetic") {
    val buffer = new Array[Double](4)
    val wrapper = new DenseMatrix(2, 2, buffer)

    // `:=` evaluates the lazy expression into the wrapper's backing array.
    wrapper := a + b
    assert(wrapper.values.toSeq === Seq(1.0, 0.0, 3.0, 2.0))
    assert(buffer.toSeq === Seq(1.0, 0.0, 3.0, 2.0))

    val buffer2 = new Array[Double](4)
    (a + b).compute(Option(buffer2))
    assert(buffer2.toSeq === Seq(1.0, 0.0, 3.0, 2.0))

    wrapper := a * 2
    assert(wrapper.values.toSeq === Seq(2.0, 4.0, 6.0, 8.0))

    wrapper := a - c
    assert(wrapper.values.toSeq === Seq(0.0, 2.0, 3.0, 4.0))

    val d = b.copy

    d += -2
    assert(d.values.toSeq === Seq(-2.0, -4.0, -2.0, -4.0))
  }

  test("requires right buffer size") {
    val wrongSizedBuffer = new Array[Double](5)
    intercept[IllegalArgumentException]((a + b).compute(Option(wrongSizedBuffer)))
  }

  test("size mismatch throws error") {
    intercept[IllegalArgumentException]((a + x).compute())
  }

  test("scalar op") {
    val buffer = new Array[Double](4)
    (a + 2).compute(Option(buffer))
    assert(buffer.toSeq === Seq(3.0, 4.0, 5.0, 6.0))
    (c + 2).compute(Option(buffer))
    assert(buffer.toSeq === Seq(3.0, 2.0, 2.0, 2.0))
    val sparseBuffer = new Array[Double](6)
    (x * 3).compute(Option(sparseBuffer))
    assert(sparseBuffer.toSeq === Seq(1.5, 0.0, 0.0, 0.0, 0.0, 6.0))
  }

  test("funcs") {
    import com.brkyvz.spark.linalg.funcs._
    // Fix: removed the unused `buffer2` local that was allocated but never written to.
    val buffer = new Array[Double](4)
    pow(a, c).compute(Option(buffer))
    assert(buffer.toSeq === Seq(1.0, 1.0, 1.0, 1.0))
    val sparseBuffer = new Array[Double](6)
    exp(x).compute(Option(sparseBuffer))
    assert(sparseBuffer.toSeq ===
      Seq(java.lang.Math.exp(0.5), 1.0, 1.0, 1.0, 1.0, java.lang.Math.exp(2.0)))
    apply(a, c, (m: Double, n: Double) => m + n).compute(Option(buffer))
    assert(buffer.toSeq === Seq(2.0, 2.0, 3.0, 4.0))
  }

  test("blas methods") {
    // Fix: `d` is never reassigned (`+=` is a method call), so it can be a val.
    val d = new DenseMatrixWrapper(2, 2, a.copy.toArray)
    d += a * 3
    val e = (a * 4).compute()
    assert(d.asInstanceOf[DenseMatrix].values.toSeq === e.asInstanceOf[DenseMatrix].values.toSeq)

    val A = DenseMatrix.eye(2)
    A += c * a
    assert(A.values.toSeq === Seq(2.0, 0.0, 3.0, 1.0))

    // Re-evaluating the same expression into the same target must be idempotent.
    val B = DenseMatrix.zeros(2, 2)
    B := a * b
    val firstVals = B.values.clone().toSeq
    B := a * b
    assert(B.values.toSeq === firstVals)
  }

  test("rdd methods") {
    val rdd = sc.parallelize(Seq(a, b, c))
    val Array(res1, res2, res3) =
      rdd.map(v => (v + 2).compute().asInstanceOf[DenseMatrix]).collect()
    assert(res1.values.toSeq === Seq(3.0, 4.0, 5.0, 6.0))
    assert(res2.values.toSeq === Seq(2.0, 0.0, 2.0, 0.0))
    assert(res3.values.toSeq === Seq(3.0, 2.0, 2.0, 2.0))
    // Lazy expressions should survive being shipped through multiple map stages.
    val Array(res4, res5, res6) = rdd.map(v => v + 2).map(_ - 1).collect()
    assert(res4.compute().asInstanceOf[DenseMatrix].values.toSeq === Seq(2.0, 3.0, 4.0, 5.0))
    assert(res5.compute().asInstanceOf[DenseMatrix].values.toSeq === Seq(1.0, -1.0, 1.0, -1.0))
    assert(res6.compute().asInstanceOf[DenseMatrix].values.toSeq === Seq(2.0, 1.0, 1.0, 1.0))

    val sum = rdd.aggregate(DenseMatrix.zeros(2, 2))(
      seqOp = (base, element) => base += element,
      combOp = (base1, base2) => base1 += base2
    )
    assert(sum.values.toSeq === Seq(2.0, 0.0, 3.0, 2.0))
    val sum2 = rdd.aggregate(DenseMatrix.zeros(2, 2))(
      seqOp = (base, element) => base += element * 2 - 1,
      combOp = (base1, base2) => base1 += base2
    )
    assert(sum2.values.toSeq === Seq(1.0, -3.0, 3.0, 1.0))
  }
}
+ x).compute()) 43 | } 44 | 45 | test("scalar op") { 46 | val buffer = new Array[Double](4) 47 | (a + 2).compute(Option(buffer)) 48 | assert(buffer.toSeq === Seq(3.0, 4.0, 5.0, 6.0)) 49 | (c + 2).compute(Option(buffer)) 50 | assert(buffer.toSeq === Seq(3.0, 2.0, 2.0, 2.0)) 51 | val sparseBuffer = new Array[Double](1) 52 | (x * 3).compute(Option(sparseBuffer)) 53 | assert(sparseBuffer.toSeq === Seq(1.5)) 54 | } 55 | 56 | test("sparse ops remain sparse") { 57 | val d = Vectors.sparse(4, Seq((1, 1.0), (3, 2.0))) 58 | val res = (c + d).compute() 59 | assert(res(0) === 1.0) 60 | assert(res(1) === 1.0) 61 | assert(res(2) === 0.0) 62 | assert(res(3) === 2.0) 63 | val sparse = res.asInstanceOf[SparseVectorWrapper] 64 | assert(sparse.values.length === 3) 65 | assert(sparse.indices.length === 3) 66 | assert(sparse.indices.toSeq === Seq(0, 1, 3)) 67 | assert(sparse.size === 4) 68 | } 69 | 70 | test("funcs") { 71 | import funcs._ 72 | val buffer = new Array[Double](4) 73 | val buffer2 = new Array[Double](5) 74 | pow(a, c).compute(Option(buffer)) 75 | assert(buffer.toSeq === Seq(1.0, 1.0, 1.0, 1.0)) 76 | val sparseBuffer = new Array[Double](1) 77 | exp(x).compute(Option(sparseBuffer)) 78 | assert(sparseBuffer.toSeq === Seq(java.lang.Math.exp(0.5))) 79 | exp(x).compute(Option(buffer2)) 80 | assert(buffer2.toSeq === Seq(1.0, 1.0, 1.0, java.lang.Math.exp(0.5), 1.0)) 81 | apply(a, c, (m: Double, n: Double) => m + n).compute(Option(buffer)) 82 | assert(buffer.toSeq === Seq(2.0, 2.0, 3.0, 4.0)) 83 | } 84 | 85 | test("blas methods") { 86 | var d = new DenseVectorWrapper(a.copy.toArray) 87 | d += a * 3 88 | val e = (a * 4).compute() 89 | assert(d.asInstanceOf[DenseVector].values.toSeq === e.asInstanceOf[DenseVector].values.toSeq) 90 | 91 | val A = DenseMatrix.rand(5, 4, new Random()) 92 | val resSpark = A.multiply(a) 93 | val buffer = new Array[Double](5) 94 | val res = (A * a).compute(Option(buffer)) 95 | assert(resSpark.values.toSeq === buffer.toSeq) 96 | } 97 | 98 | test("rdd 
methods") { 99 | val rdd = sc.parallelize(Seq(a, b, c)) 100 | val Array(res1, res2, res3) = 101 | rdd.map(v => (v + 2).compute().asInstanceOf[DenseVector]).collect() 102 | assert(res1.values.toSeq === Seq(3.0, 4.0, 5.0, 6.0)) 103 | assert(res2.values.toSeq === Seq(2.0, 0.0, 2.0, 0.0)) 104 | assert(res3.values.toSeq === Seq(3.0, 2.0, 2.0, 2.0)) 105 | val Array(res4, res5, res6) = rdd.map(v => v + 2).map(_ - 1).collect() 106 | assert(res4.compute().asInstanceOf[DenseVector].values.toSeq === Seq(2.0, 3.0, 4.0, 5.0)) 107 | assert(res5.compute().asInstanceOf[DenseVector].values.toSeq === Seq(1.0, -1.0, 1.0, -1.0)) 108 | assert(res6.compute().asInstanceOf[DenseVector].values.toSeq === Seq(2.0, 1.0, 1.0, 1.0)) 109 | 110 | val sum = rdd.aggregate(new DenseVector(Array(0, 0, 0, 0)))( 111 | seqOp = (base, element) => base += element, 112 | combOp = (base1, base2) => base1 += base2 113 | ) 114 | assert(sum.values.toSeq === Seq(2.0, 0.0, 3.0, 2.0)) 115 | val sum2 = rdd.aggregate(new DenseVector(Array(0, 0, 0, 0)))( 116 | seqOp = (base, element) => base += element * 2 - 1, 117 | combOp = (base1, base2) => base1 += base2 118 | ) 119 | assert(sum2.values.toSeq === Seq(1.0, -3.0, 3.0, 1.0)) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /src/test/scala/com/brkyvz/spark/util/TestingUtils.scala: -------------------------------------------------------------------------------- 1 | package com.brkyvz.spark.util 2 | 3 | import org.apache.spark.mllib.linalg.{Vector, Matrix} 4 | import org.scalatest.exceptions.TestFailedException 5 | 6 | object TestingUtils { 7 | 8 | val ABS_TOL_MSG = " using absolute tolerance" 9 | val REL_TOL_MSG = " using relative tolerance" 10 | 11 | /** 12 | * Private helper function for comparing two values using relative tolerance. 
13 | * Note that if x or y is extremely close to zero, i.e., smaller than Double.MinPositiveValue, 14 | * the relative tolerance is meaningless, so the exception will be raised to warn users. 15 | */ 16 | private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = { 17 | val absX = math.abs(x) 18 | val absY = math.abs(y) 19 | val diff = math.abs(x - y) 20 | if (x == y) { 21 | true 22 | } else if (absX < Double.MinPositiveValue || absY < Double.MinPositiveValue) { 23 | throw new TestFailedException( 24 | s"$x or $y is extremely close to zero, so the relative tolerance is meaningless.", 0) 25 | } else { 26 | diff < eps * math.min(absX, absY) 27 | } 28 | } 29 | 30 | /** 31 | * Private helper function for comparing two values using absolute tolerance. 32 | */ 33 | private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = { 34 | math.abs(x - y) < eps 35 | } 36 | 37 | case class CompareDoubleRightSide( 38 | fun: (Double, Double, Double) => Boolean, y: Double, eps: Double, method: String) 39 | 40 | /** 41 | * Implicit class for comparing two double values using relative tolerance or absolute tolerance. 42 | */ 43 | implicit class DoubleWithAlmostEquals(val x: Double) { 44 | 45 | /** 46 | * When the difference of two values are within eps, returns true; otherwise, returns false. 47 | */ 48 | def ~=(r: CompareDoubleRightSide): Boolean = r.fun(x, r.y, r.eps) 49 | 50 | /** 51 | * When the difference of two values are within eps, returns false; otherwise, returns true. 52 | */ 53 | def !~=(r: CompareDoubleRightSide): Boolean = !r.fun(x, r.y, r.eps) 54 | 55 | /** 56 | * Throws exception when the difference of two values are NOT within eps; 57 | * otherwise, returns true. 
58 | */ 59 | def ~==(r: CompareDoubleRightSide): Boolean = { 60 | if (!r.fun(x, r.y, r.eps)) { 61 | throw new TestFailedException( 62 | s"Expected $x and ${r.y} to be within ${r.eps}${r.method}.", 0) 63 | } 64 | true 65 | } 66 | 67 | /** 68 | * Throws exception when the difference of two values are within eps; otherwise, returns true. 69 | */ 70 | def !~==(r: CompareDoubleRightSide): Boolean = { 71 | if (r.fun(x, r.y, r.eps)) { 72 | throw new TestFailedException( 73 | s"Did not expect $x and ${r.y} to be within ${r.eps}${r.method}.", 0) 74 | } 75 | true 76 | } 77 | 78 | /** 79 | * Comparison using absolute tolerance. 80 | */ 81 | def absTol(eps: Double): CompareDoubleRightSide = 82 | CompareDoubleRightSide(AbsoluteErrorComparison, x, eps, ABS_TOL_MSG) 83 | 84 | /** 85 | * Comparison using relative tolerance. 86 | */ 87 | def relTol(eps: Double): CompareDoubleRightSide = 88 | CompareDoubleRightSide(RelativeErrorComparison, x, eps, REL_TOL_MSG) 89 | 90 | override def toString: String = x.toString 91 | } 92 | 93 | case class CompareVectorRightSide( 94 | fun: (Vector, Vector, Double) => Boolean, y: Vector, eps: Double, method: String) 95 | 96 | /** 97 | * Implicit class for comparing two vectors using relative tolerance or absolute tolerance. 98 | */ 99 | implicit class VectorWithAlmostEquals(val x: Vector) { 100 | 101 | /** 102 | * When the difference of two vectors are within eps, returns true; otherwise, returns false. 103 | */ 104 | def ~=(r: CompareVectorRightSide): Boolean = r.fun(x, r.y, r.eps) 105 | 106 | /** 107 | * When the difference of two vectors are within eps, returns false; otherwise, returns true. 108 | */ 109 | def !~=(r: CompareVectorRightSide): Boolean = !r.fun(x, r.y, r.eps) 110 | 111 | /** 112 | * Throws exception when the difference of two vectors are NOT within eps; 113 | * otherwise, returns true. 
114 | */ 115 | def ~==(r: CompareVectorRightSide): Boolean = { 116 | if (!r.fun(x, r.y, r.eps)) { 117 | throw new TestFailedException( 118 | s"Expected $x and ${r.y} to be within ${r.eps}${r.method} for all elements.", 0) 119 | } 120 | true 121 | } 122 | 123 | /** 124 | * Throws exception when the difference of two vectors are within eps; otherwise, returns true. 125 | */ 126 | def !~==(r: CompareVectorRightSide): Boolean = { 127 | if (r.fun(x, r.y, r.eps)) { 128 | throw new TestFailedException( 129 | s"Did not expect $x and ${r.y} to be within ${r.eps}${r.method} for all elements.", 0) 130 | } 131 | true 132 | } 133 | 134 | /** 135 | * Comparison using absolute tolerance. 136 | */ 137 | def absTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide( 138 | (x: Vector, y: Vector, eps: Double) => { 139 | x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 absTol eps) 140 | }, x, eps, ABS_TOL_MSG) 141 | 142 | /** 143 | * Comparison using relative tolerance. Note that comparing against sparse vector 144 | * with elements having value of zero will raise exception because it involves with 145 | * comparing against zero. 146 | */ 147 | def relTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide( 148 | (x: Vector, y: Vector, eps: Double) => { 149 | x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) 150 | }, x, eps, REL_TOL_MSG) 151 | 152 | override def toString: String = x.toString 153 | } 154 | 155 | case class CompareMatrixRightSide( 156 | fun: (Matrix, Matrix, Double) => Boolean, y: Matrix, eps: Double, method: String) 157 | 158 | /** 159 | * Implicit class for comparing two matrices using relative tolerance or absolute tolerance. 160 | */ 161 | implicit class MatrixWithAlmostEquals(val x: Matrix) { 162 | 163 | /** 164 | * When the difference of two matrices are within eps, returns true; otherwise, returns false. 
165 | */ 166 | def ~=(r: CompareMatrixRightSide): Boolean = r.fun(x, r.y, r.eps) 167 | 168 | /** 169 | * When the difference of two matrices are within eps, returns false; otherwise, returns true. 170 | */ 171 | def !~=(r: CompareMatrixRightSide): Boolean = !r.fun(x, r.y, r.eps) 172 | 173 | /** 174 | * Throws exception when the difference of two matrices are NOT within eps; 175 | * otherwise, returns true. 176 | */ 177 | def ~==(r: CompareMatrixRightSide): Boolean = { 178 | if (!r.fun(x, r.y, r.eps)) { 179 | throw new TestFailedException( 180 | s"Expected \n$x\n and \n${r.y}\n to be within ${r.eps}${r.method} for all elements.", 0) 181 | } 182 | true 183 | } 184 | 185 | /** 186 | * Throws exception when the difference of two matrices are within eps; otherwise, returns true. 187 | */ 188 | def !~==(r: CompareMatrixRightSide): Boolean = { 189 | if (r.fun(x, r.y, r.eps)) { 190 | throw new TestFailedException( 191 | s"Did not expect \n$x\n and \n${r.y}\n to be within " + 192 | "${r.eps}${r.method} for all elements.", 0) 193 | } 194 | true 195 | } 196 | 197 | /** 198 | * Comparison using absolute tolerance. 199 | */ 200 | def absTol(eps: Double): CompareMatrixRightSide = CompareMatrixRightSide( 201 | (x: Matrix, y: Matrix, eps: Double) => { 202 | x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 absTol eps) 203 | }, x, eps, ABS_TOL_MSG) 204 | 205 | /** 206 | * Comparison using relative tolerance. Note that comparing against sparse vector 207 | * with elements having value of zero will raise exception because it involves with 208 | * comparing against zero. 209 | */ 210 | def relTol(eps: Double): CompareMatrixRightSide = CompareMatrixRightSide( 211 | (x: Matrix, y: Matrix, eps: Double) => { 212 | x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) 213 | }, x, eps, REL_TOL_MSG) 214 | 215 | override def toString: String = x.toString 216 | } 217 | } 218 | --------------------------------------------------------------------------------