├── .gitignore ├── COPYING ├── README.md ├── TODO ├── build.sbt ├── perf ├── build.sbt ├── lib │ └── optimized-numeric-plugin_2.9.1-0.1.jar └── src │ └── main │ └── scala │ └── Main.scala ├── plugin ├── Makefile ├── build.sbt └── src │ ├── main │ ├── resources │ │ └── scalac-plugin.xml │ └── scala │ │ └── com │ │ └── azavea │ │ └── math │ │ └── plugin │ │ └── OptimizedNumeric.scala │ └── test │ └── scala │ └── Example.scala ├── project └── Build.scala └── src ├── main └── scala │ └── com │ └── azavea │ └── math │ ├── Convertable.scala │ ├── EasyNumericOps.scala │ ├── FastNumericOps.scala │ ├── LiteralOps.scala │ └── Numeric.scala └── test └── scala └── Matrix.scala /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | lib_managed/ 3 | src_managed/ 4 | project/boot/ 5 | .ensime 6 | benchmark.html 7 | -------------------------------------------------------------------------------- /COPYING: -------------------------------------------------------------------------------- 1 | Copyright (c) 2011 Erik Osheim, Azavea Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | of the Software, and to permit persons to whom the Software is furnished to do 8 | so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | NUMERIC 2 | ======= 3 | 4 | This package (com.azavea.math.Numeric) is based on scala.math.Numeric, which is 5 | a type trait designed to allow abstraction across numeric types (e.g. Int, 6 | Double, etc). I (Erik Osheim) began this work as part of the 7 | specialized-numeric Scala Incubator project. 8 | 9 | This package is a much-improved version. For one, it is way faster, thanks to 10 | Scala's specialization system, as well as some restructuring and an optional 11 | compiler plugin. It is also more flexible, allowing the user to operate on 12 | generic numeric types, concrete numeric types, and literals without requiring 13 | ugly casting. 14 | 15 | Ultimately my hope is to push these changes back into Scala's library. 16 | 17 | REQUIREMENTS 18 | ============ 19 | 20 | Build using SBT 10. Will port to SBT 11 soon. 
21 | 22 | 23 | EXAMPLES 24 | ======== 25 | 26 | Basic example of addition: 27 | 28 | import com.azavea.math.Numeric 29 | import com.azavea.math.FastImplicits._ 30 | 31 | def adder[A:Numeric](a:A, b:A) = a + b 32 | 33 | Creating a Point3 object with coordinates of any numeric type: 34 | 35 | import com.azavea.math.Numeric 36 | import com.azavea.math.FastImplicits._ 37 | 38 | case class Point3[T:Numeric](x:T, y:T, z:T) { 39 | def +(rhs:Point3[T]) = Point3(x + rhs.x, y + rhs.y, z + rhs.z) 40 | 41 | def distance(rhs:Point3[T]) = { 42 | val dx = x - rhs.x 43 | val dy = y - rhs.y 44 | val dz = z - rhs.z 45 | val d = dx * dx + dy * dy + dz * dz 46 | math.sqrt(d.toDouble) 47 | } 48 | } 49 | 50 | Mixing literals and variables: 51 | 52 | import com.azavea.math.Numeric 53 | import com.azavea.math.EasyImplicits._ 54 | import Predef.{any2stringadd => _, _} 55 | 56 | def foo[T:Numeric](a:T, b:T):T = a * 100 + b 57 | 58 | Currently there are two different ways to use Numeric: EasyImplicits allows you 59 | to operate on mixed numeric types (e.g. T + U + Int). FastImplicits sacrifices 60 | this ability in exchange for speed. Both can be made equally fast (as fast as 61 | operating on direct types) with the optimized-numeric compiler plugin. 62 | 63 | If you plan to use the compiler plugin, or aren't worried about speed, you will 64 | probably want to use EasyImplicits. 65 | 66 | 67 | ANNOYING PREDEF 68 | =============== 69 | 70 | You may have noticed this ugly-looking import: 71 | 72 | import Predef.{any2stringadd => _, _} 73 | 74 | This is to work around a design problem in Scala. You may not always need to 75 | use this, but if you notice problems with + not working correctly you should 76 | add this top-level import. Sorry! 
:( 77 | 78 | 79 | PROJECT STRUCTURE 80 | ================= 81 | 82 | Here is a list of the SBT projects: 83 | 84 | * root: contains the library code itself ("package" builds the library jar) 85 | * plugin: contains the compiler plugin code ("package" builds the plugin jar) 86 | * perf: contains the performance test 87 | 88 | Use "projects" to view them, and "project XYZ" to switch to XYZ. 89 | 90 | USING THE PLUGIN 91 | ================ 92 | 93 | The optimized-numeric plugin is able to speed things up by rewriting certain 94 | constructions into other, faster ones. Here's an example: 95 | 96 | // written 97 | def foo[T:Numeric](a:T, b:T) = a + b 98 | 99 | // compiled 100 | def foo[T](a:T, b:T)(implicit ev:Numeric[T]) = new FastNumericOps(a).+(b) 101 | 102 | // compiled with plugin 103 | def foo[T](a:T, b:T)(implicit ev:Numeric[T]) = ev.add(a, b) 104 | 105 | In the future scalac might be able to do this for us (or hotspot might be able 106 | to optimize it away). But in the absence of these things the plugin helps make 107 | Numeric much faster (especially when using EasyImplicits). 108 | 109 | At the most basic level, you can add "-Xplugin:path/to/optimized-numeric.jar" 110 | to your scalac invocation to compile things with the plugin. 111 | 112 | When running the perf project, here are the steps you can take to build and 113 | enable the plugin: 114 | 115 | 1. in sbt: "project plugin", then "package" 116 | 2. cp plugin/target/scala-2.9.1.final/optimized-numeric-plugin_2.9.1-0.1.jar perf/lib/ 117 | 3. uncomment perf/build.sbt line involving -Xplugin 118 | 4. in sbt: "project perf", then "run" 119 | 120 | The plugin is enabled by default. To disable the plugin just revert step #3. 121 | 122 | 123 | BENCHMARKS 124 | ========== 125 | 126 | To run the benchmarks, do the following: 127 | 128 | 1. optionally build and install the plugin 129 | 2. 
in sbt: "project perf", then "run" 130 | 131 | The output shows the speed (in milliseconds) of a direct implementation 132 | (without generics), the new implementation (com.azavea.math.Numeric) and the 133 | old implementation (the built-in scala.math.Numeric). In some cases there is no 134 | old implementation--in those cases the test tries to hew as closely as possible 135 | to the direct implementation. 136 | 137 | * n:d is how new compares to direct 138 | * o:d is how old compares to direct 139 | * o:n is how old compares to new 140 | 141 | It also creates a benchmark.html file which colors the output. 142 | 143 | RESULTS 144 | ======= 145 | 146 | There are some interesting results: 147 | 148 | 1. Both scala.math.Numeric and com.azavea.math.Numeric seem to perform worse 149 | on integral types than fractional ones. This is very pronounced for 150 | scala.math.Numeric and only slight for com.azavea.math.Numeric. 151 | 152 | 2. com.azavea.math.Numeric mostly* performs as well as direct implementations 153 | except when using infix operators without the compiler plugin. The current 154 | Numeric is clearly inappropriate for any application where performance is 155 | important. 156 | 157 | 3. The asterisk in the previous item has to do with Quicksort. Basically, 158 | scala.util.Sorting uses Ordering\[A\] which is not specialized and which 159 | implements all its own (non-specialized) comparison operators in terms of 160 | compare(). 161 | 162 | This ends up being really slow, so my Numeric trait doesn't extend it, but 163 | instead provides a getOrdering() method (which builds a separate Ordering 164 | instance wrapping the Numeric instance). As a result, it doesn't perform any 165 | better than scala.math.Numeric on this test (and in fact does a bit worse). 166 | 167 | I don't know how likely it is that Ordering will be specialized, but huge 168 | performance gains seem possible. 169 | 170 | 4. 
scala.util.Sorting.quickSort lacks a direct Long implementation, so using it 171 | with Longs is ~5x slower than Int, Float or Double. 172 | 173 | 5. It seems like scala.util.Sorting could use some love. My naive direct 174 | implementation of merge sort seems to beat Sorting.quickSort for Long 175 | (obviously), Float and Double. That said, optimizing sort algorithms can be 176 | tricky. But gains seem possible. 177 | 178 | 179 | DIFFERENCES 180 | =========== 181 | 182 | While very similar to scala.math.Numeric, com.azavea.math.Numeric has 183 | some differences. The most significant ones are: 184 | 185 | 1. It does not inherit from the Ordering type class, but rather directly 186 | implements the comparison methods. I will try to do some cleanup on this and 187 | make it more compatible with Ordering, but it was important to me that the 188 | comparison methods are also specialized. 189 | 190 | 2. It does not implement Integral/Fractional. I think that leaving 191 | division/modulo off of Numeric is a mistake, and don't think that forcing users 192 | to use Integral/Fractional is a good idea. Given that Scala uses the same 193 | symbol (/) to mean both "integer division" and "true division" it seems clear 194 | that Numeric can too. 195 | 196 | 3. It's in a different package. 197 | 198 | 4. It adds some operators that I thought would be nice to have: 199 | 200 | - <=> as an alias for compare 201 | - === as an alias for equiv 202 | - !== as an alias for !equiv 203 | - ** as an alias for math.pow 204 | 205 | 5. It adds a full suite of conversions. Unlike the existing Numeric, you can 206 | convert directly to/from any numeric type. This is useful since you might be 207 | going from a generic type to a known type (e.g. A -> Int) or a known type to a 208 | generic one (Int -> A). In both cases it is important not to do any unnecessary 209 | work (when A is an Int, you should not have to do any copying/casting). 
210 | 211 | These conversions can be used from the instance of Numeric[T] or directly 212 | on the values of T themselves: 213 | 214 | def foo[T:Numeric](t:T):Double = { 215 | val i:Int = numeric.toInt(t) 216 | val d:Double = t.toDouble 217 | d - i 218 | } 219 | 220 | CAVEATS 221 | ======= 222 | 223 | This section is just for "known problems" with the Numeric type class approach 224 | in Scala. 225 | 226 | Precision 227 | --------- 228 | 229 | Given the signatures of Numeric's functions, it's possible to mix literals and 230 | generic types in a way which will lose precision. Consider: 231 | 232 | def foo[T:Numeric](t:T) = t + 9.2 233 | 234 | The author might expect that foo\[Int\](5) will return a Double (14.2), but in 235 | fact foo\[T\] returns T and so foo\[Int\] will return an Int (14). The solution 236 | in cases like this (if you know you want a double) is to convert T to a Double 237 | first: 238 | 239 | def foo[T:Numeric](t:T) = t.toDouble + 9.2 240 | 241 | You could imagine that + could "figure out" whether T has more precision than 242 | Double (e.g. BigDecimal) or less precision (e.g. Int) and "do the right thing". 243 | Sadly, this is not possible using the current strategy. 244 | 245 | If you are interested in implementing a numeric tower in Scala (presumably by 246 | writing some pretty intense compiler plugins if not modifying Scala's type 247 | system) please get in touch with the author. :) 248 | 249 | Clunky syntax 250 | ------------- 251 | 252 | Unfortunately in order to benefit from the speed of this library you must 253 | annotate all your numeric type parameters with @specialized. Also, you often need a 254 | manifest (for instance when you want to allocate Arrays of that type), so you 255 | have to provide those too. Your code will end up looking similar to this: 256 | 257 | def handleInt(a:Int) = ... 258 | 259 | def handleA[@specialized A:Numeric:Manifest](a:A) = ... 260 | 261 | When you compare these visually the second is obviously terrible. 
In many of 262 | the examples I have omitted the @specialized annotation and the Manifest type 263 | bound for clarity. Hopefully this is not too deceptive. 264 | 265 | It would be great if there was some way to create a type bound that "included" 266 | specialization. For instance: 267 | 268 | // written 269 | def bar[T:SNumeric](a:T, b:T, c:T) = a * b + c 270 | 271 | // compiled 272 | def bar[@specialized(Int,Long,Float,Double) T:Numeric](a:T, b:T, c:T) = a * b + c 273 | 274 | The current design of specialization works well when a library author expects 275 | users to call her generic functions with concrete types. But if her users are 276 | themselves defining generic functions, the library author is powerless to help. 277 | 278 | Obviously, this could pose a real problem in terms of a bytecode explosion, which 279 | is why it should not be a default behavior. In fact, Numeric may be one of the 280 | few places where this behavior would be desirable. 281 | 282 | Non-extensibility 283 | ----------------- 284 | 285 | While you should be able to implement your own instances of Numeric for 286 | user-generated types (e.g. Complex or BigRational) you will not be able to use 287 | the full power of this library due to the need to add methods to the 288 | ConvertableFrom and ConvertableTo traits. 289 | 290 | This is unfortunate, and I am exploring a pluggable numeric conversion 291 | system. However, in the interest of getting a fast, working implementation 292 | out there, I have shelved that for now. 293 | 294 | Lack of range support 295 | --------------------- 296 | 297 | I am still working on implementing a corresponding NumericRange. For now 298 | you'll need to convert to a known type (e.g. Long), or use while loops. This 299 | is definitely possible and will hopefully be done soon. 
300 | -------------------------------------------------------------------------------- /TODO: -------------------------------------------------------------------------------- 1 | This is a random list of things that either would be good to do, or would be 2 | good to consider doing. Some of these are mutually-incompatible (or at least, 3 | mutually-undesirable). 4 | 5 | Some of these may also be impossible! :) 6 | 7 | 1. Implement a NumericRange class 8 | 9 | Paul Phillips has encouraged me to do this. I managed to create something that 10 | worked but wasn't performing as well as the direct integer ranges. The problem 11 | seemed to be the difference between the way foo(Int => Unit) and foo[T](T => 12 | Unit) are treated. 13 | 14 | 15 | 2. Fix support for user-created types 16 | 17 | Right now the strategy we're using with ConvertableTo/ConvertableFrom doesn't 18 | scale well when users are adding their own types. This may not actually be a 19 | big deal, but it'd be nice if the system was more pluggable. Right now if we 20 | want to provide support converting to/from "Foo" objects we have to change the 21 | ConvertableFrom/ConvertableTo traits to include methods "toFoo" and "fromFoo". 22 | 23 | I can imagine using something like ConvertableBetween[A, B] to solve this, but 24 | I don't want to sacrifice performance to do so. Experimentation necessary. 25 | See https://github.com/nuttycom/salt/commits/master/src/main/scala/com/nommit/salt/Bijection.scala 26 | for ideas. 27 | 28 | 29 | 3. Specialize Ordering 30 | 31 | This is probably too big for this project, but it would be great if Ordering 32 | were specialized so that Numeric could extend it without taking a huge speed 33 | hit on things like Numeric.lt(). 34 | 35 | 36 | 4. 
Rationalize working with two different numeric types 37 | 38 | Currently if you have "def complicated[T:Numeric, U:Numeric](t:T, u:U)" you 39 | will have to do manual conversions to one of those types without being sure 40 | that you aren't losing precision. I'm not sure there's a way to get around 41 | this, but it might be nice to have a way to compare two Numeric objects in 42 | terms of which one should "win" in terms of precision. That way you could at 43 | least say something like "if (n1.morePreciseThan(n2)) ... else ..." 44 | 45 | I'm not sure how often this would be used, but I could imagine it being nice. 46 | 47 | 48 | 5. Test on more hardware/JVM configurations 49 | 50 | 51 | 6. Port tests to one of the newer performance testing frameworks 52 | 53 | 54 | 7. Write more tests 55 | 56 | 57 | 8. Figure out how to deal with the "strict" infix operators versus the "fuzzy" 58 | infix operators. There is a compiler plugin to rewrite both into the faster 59 | form "n.add(lhs, rhs)" but without that the strict infix operators perform 60 | better, and also don't fall afoul of StringOps' + method. On the other hand, 61 | the fuzzy infix operators work better with literals (e.g. 3) and concrete types 62 | (e.g i:Int). 63 | 64 | 65 | 9. Add "auto-specialization" to the compiler plugin 66 | 67 | I can imagine creating a type alias for Numeric ("SNumeric") and then 68 | making the compiler plugin automatically specialize any uses of it. This way a 69 | user's code could automatically get the gains of specialization without having 70 | to annotate every single function manually. 71 | 72 | I'm not sure if this would actually work, but it'd be worth looking into. 73 | 74 | 75 | 10. Think about how to support other math functions 76 | 77 | Right now there are still some functions (sqrt and others from scala.math) 78 | which aren't defined directly on the types. 
This is bad, because when you need 79 | to use these functions you end up either calling toDouble (and thus breaking 80 | support for BigDecimal and friends) or you call toBigDecimal (and slow things 81 | down for your fast primitive types). We could define these functions to all 82 | return T, but then you'd have no way of getting a fractional sqrt() when using 83 | Ints. 84 | 85 | This gets back to the issues with the lack of a numeric tower. When taking the 86 | sqrt of an Int we probably want to get back a Double, but when taking the sqrt 87 | of a BigInt we probably want a BigDecimal. I don't see any easy way of dealing 88 | with it, but I may not be creative enough. At various points I've had quixotic 89 | visions of a compiler plugin that could "figure it out"... 90 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | // project name 2 | name := "Numeric" 3 | 4 | // shrug? 5 | version := "0.1" 6 | 7 | // test 8 | libraryDependencies += "org.scalatest" % "scalatest_2.9.0" % "1.6.1" 9 | 10 | // hide backup files 11 | defaultExcludes ~= (filter => filter || "*~") 12 | 13 | scalacOptions += "-optimise" 14 | 15 | // any of these work, although 2.9.1 performs the best 16 | //scalaVersion := "2.8.1 17 | //scalaVersion := "2.9.0-1" 18 | scalaVersion := "2.9.1" 19 | -------------------------------------------------------------------------------- /perf/build.sbt: -------------------------------------------------------------------------------- 1 | // project name 2 | name := "Numeric Performance Test" 3 | 4 | // shrug? 
5 | version := "0.1" 6 | 7 | // hide backup files 8 | defaultExcludes ~= (filter => filter || "*~") 9 | 10 | scalacOptions += "-optimise" 11 | 12 | //autoCompilerPlugins := true 13 | 14 | scalacOptions += "-Xplugin:perf/lib/optimized-numeric-plugin_2.9.1-0.1.jar" 15 | //scalacOptions += "-Xplugin:lib/optimized-numeric.jar" 16 | 17 | // any of these work, although 2.9.1 performs the best 18 | //scalaVersion := "2.8.1 19 | //scalaVersion := "2.9.0-1" 20 | scalaVersion := "2.9.1" 21 | -------------------------------------------------------------------------------- /perf/lib/optimized-numeric-plugin_2.9.1-0.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/azavea/numeric/15a1e0381224826b418a9d9a6b2c9773ea917137/perf/lib/optimized-numeric-plugin_2.9.1-0.1.jar -------------------------------------------------------------------------------- /perf/src/main/scala/Main.scala: -------------------------------------------------------------------------------- 1 | import scala.math.max 2 | import scala.math.{Numeric => OldNumeric, Integral, Fractional, min, max} 3 | import scala.util.Random 4 | import scala.testing.Benchmark 5 | 6 | import java.io.{FileWriter, PrintWriter} 7 | 8 | import Console.printf 9 | 10 | import com.azavea.math.Numeric 11 | import com.azavea.math.EasyImplicits._ 12 | //import com.azavea.math.FastImplicits._ 13 | import Predef.{any2stringadd => _, _} 14 | 15 | // define some constant sizes and random arrays that we can use for our various 16 | // performance tests. 
// if things run way too slow or way too fast you can try
// changing the factor or divisor :)
object Constant {
  // scaling knobs: every size below is (base * factor) / divisor
  val factor = 1
  val divisor = 1

  // base size: 10 thousand elements
  val smallSize:Int = (10 * 1000 * factor) / divisor
  val smallIntArray = Array.fill(smallSize)(Random.nextInt())
  val smallLongArray = Array.fill(smallSize)(Random.nextLong())
  val smallFloatArray = Array.fill(smallSize)(Random.nextFloat())
  val smallDoubleArray = Array.fill(smallSize)(Random.nextDouble())

  // base size: 1 million elements
  val mediumSize:Int = (1 * 1000 * 1000 * factor) / divisor
  val mediumIntArray = Array.fill(mediumSize)(Random.nextInt())
  val mediumLongArray = Array.fill(mediumSize)(Random.nextLong())
  val mediumFloatArray = Array.fill(mediumSize)(Random.nextFloat())
  val mediumDoubleArray = Array.fill(mediumSize)(Random.nextDouble())

  // base size: 10 million elements
  val largeSize:Int = (10 * 1000 * 1000 * factor) / divisor
  val largeIntArray = Array.fill(largeSize)(Random.nextInt())
  val largeLongArray = Array.fill(largeSize)(Random.nextLong())
  val largeFloatArray = Array.fill(largeSize)(Random.nextFloat())
  val largeDoubleArray = Array.fill(largeSize)(Random.nextDouble())

  // read by code outside this section; presumably toggles writing
  // benchmark.html -- TODO confirm against the caller
  val createHTML = true
}
import Constant._

// timing summary for one benchmark variant: average of the live runs plus
// the slowest and fastest single run, in the units Benchmark reports
case class TestResult(tavg:Double, tmax:Long, tmin:Long)

// represents a particular performance test we want to run
trait TestCase {
  def name: String

  // direct implementation using primitives
  def direct(): Option[Any]

  // implemented using the new Numeric trait
  def newGeneric(): Option[Any]

  // implemented using the built-in Numeric trait
  def oldGeneric(): Option[Any]

  object CaseDirect extends Benchmark { def run = direct }
  object CaseGeneric extends Benchmark { def run = newGeneric }
  object CaseOld extends Benchmark { def run = oldGeneric }

  // order matters: test() destructures the results as (direct, new, old)
  val cases = List(CaseDirect, CaseGeneric, CaseOld)

  // runs c for warmupRuns (results discarded), then for liveRuns, and
  // summarizes the live timings
  def runCase(c:Benchmark, warmupRuns:Int, liveRuns:Int) = {
    c.runBenchmark(warmupRuns)
    val results = c.runBenchmark(liveRuns)

    val total = results.sum

    val tmax = results.foldLeft(0L)(max(_, _))
    val tmin = results.foldLeft(Long.MaxValue)(min(_, _))
    val tavg = total.toDouble / liveRuns

    TestResult(tavg, tmax, tmin)
  }

  // n warmup runs and m live runs for each of the three variants
  def run(n:Int, m:Int):List[TestResult] = cases.map(runCase(_, n, m))

  // maps a slowdown ratio to a status label (used as a CSS class below)
  def percStatus(p:Double) =
    if (p < 0.9) "great"
    else if (p < 1.1) "good"
    else if (p < 2.2) "ok"
    else if (p < 4.4) "poor"
    else if (p < 8.8) "bad"
    else "awful"

  // runs the three variants, prints one aligned text row to stdout and, when
  // a writer is supplied, one HTML table row to it
  def test(p:Option[PrintWriter]): Unit = {
    val results = run(2, 8)

    val times = results.map(_.tavg)
    val tstrs = times.map { t =>
      if (t < 0.1) " n/a" else "%6.1fms".format(t)
    }

    val List(t1, t2, t3) = times

    // ratio a / b, or 0.0 when either side did not produce a measurement
    def mkp(a:Double, b:Double) = if (a == 0.0 || b == 0.0) 0.0 else a / b

    // n:d, o:d, o:n
    val percs = List(mkp(t2, t1), mkp(t3, t1), mkp(t3, t2))
    val pstrs = percs.map(r => if (r < 0.01) " n/a" else "%5.2fx".format(r))

    // NOTE(review): the format strings here previously read "%.1f%s" while
    // being handed (String, Double, String, String), which throws
    // IllegalFormatConversionException on the first argument. The markup
    // below consumes all four arguments in order; the exact original tags
    // are not recoverable from this file -- confirm against the code that
    // writes the surrounding table.
    def cells(time:Double, perc:Double, pstr:String) =
      if (time < 0.1) {
        // keep the columns aligned when a variant has no measurement
        "<td></td><td></td>"
      } else {
        val status = percStatus(perc)
        "<td class=\"%s\">%.1f</td><td class=\"%s\">%s</td>".format(status, time, status, pstr)
      }

    p.foreach { pw =>
      val a = "<tr><td>%s</td><td>%.1f</td>".format(name, t1)
      val b = cells(t2, percs(0), pstrs(0))
      val c = cells(t3, percs(1), pstrs(1))
      pw.println(a + b + c + "</tr>")
    }

    val fields = ("%-24s".format(name) :: tstrs) ++ ("/" :: pstrs)
    println(fields.mkString(" "))
  }
}


//
// ===============================================================
// Benchmarks converting an Array[Int] element-by-element into another
// numeric type: by hand, through the new Numeric and through the old one.
trait FromIntToX extends TestCase {
  // direct Int -> Int copy
  def directToInt(a:Array[Int]): Array[Int] = {
    val out = Array.ofDim[Int](a.length)
    var idx = 0
    while (idx < a.length) {
      out(idx) = a(idx).toInt
      idx += 1
    }
    out
  }

  // direct Int -> Long widening
  def directToLong(a:Array[Int]): Array[Long] = {
    val out = Array.ofDim[Long](a.length)
    var idx = 0
    while (idx < a.length) {
      out(idx) = a(idx).toLong
      idx += 1
    }
    out
  }

  // direct Int -> Float conversion
  def directToFloat(a:Array[Int]): Array[Float] = {
    val out = Array.ofDim[Float](a.length)
    var idx = 0
    while (idx < a.length) {
      out(idx) = a(idx).toFloat
      idx += 1
    }
    out
  }

  // direct Int -> Double conversion
  def directToDouble(a:Array[Int]): Array[Double] = {
    val out = Array.ofDim[Double](a.length)
    var idx = 0
    while (idx < a.length) {
      out(idx) = a(idx).toDouble
      idx += 1
    }
    out
  }

  // generic conversion through the new Numeric's fromInt
  def newFromInts[@specialized A:Numeric:Manifest](a:Array[Int]): Array[A] = {
    val out = Array.ofDim[A](a.length)
    var idx = 0
    while (idx < a.length) {
      out(idx) = numeric.fromInt(a(idx))
      idx += 1
    }
    out
  }

  // generic conversion through scala.math.Numeric's fromInt
  def oldFromInts[A:OldNumeric:Manifest](a:Array[Int]): Array[A] = {
    val ev = implicitly[OldNumeric[A]]
    val out = Array.ofDim[A](a.length)
    var idx = 0
    while (idx < a.length) {
      out(idx) = ev.fromInt(a(idx))
      idx += 1
    }
    out
  }
}

// Int -> Int over the large shared array
final class FromIntToInt extends FromIntToX {
  def name = "from-int-to-int"
  def direct() = Option(directToInt(largeIntArray))
  def newGeneric() = Option(newFromInts[Int](largeIntArray))
  def oldGeneric() = Option(oldFromInts[Int](largeIntArray))
}

// Int -> Long over the large shared array
final class FromIntToLong extends FromIntToX {
  def name = "from-int-to-long"
  def direct() = Option(directToLong(largeIntArray))
  def newGeneric() = Option(newFromInts[Long](largeIntArray))
  def oldGeneric() = Option(oldFromInts[Long](largeIntArray))
}

final class FromIntToFloat extends FromIntToX { 210 | def name = "from-int-to-float" 211 | def direct() = Option(directToFloat(largeIntArray)) 212 | def newGeneric() = Option(newFromInts[Float](largeIntArray)) 213 | def oldGeneric() = Option(oldFromInts[Float](largeIntArray)) 214 | } 215 | 216 | final class FromIntToDouble extends FromIntToX { 217 | def name = "from-int-to-double" 218 | def direct() = Option(directToDouble(largeIntArray)) 219 | def newGeneric() = Option(newFromInts[Double](largeIntArray)) 220 | def oldGeneric() = Option(oldFromInts[Double](largeIntArray)) 221 | } 222 | 223 | 224 | 225 | // ================================================================= 226 | trait BaseAdder extends TestCase { 227 | def newAdder[@specialized A](a:A, b:A)(implicit m:Numeric[A]): A 228 | def oldAdder[A](a:A, b:A)(implicit m:OldNumeric[A]): A 229 | 230 | def directIntAdder(a:Int, b:Int):Int = a + b 231 | def directLongAdder(a:Long, b:Long):Long = a + b 232 | def directFloatAdder(a:Float, b:Float):Float = a + b 233 | def directDoubleAdder(a:Double, b:Double):Double = a + b 234 | } 235 | 236 | trait BaseAdderInt extends BaseAdder { 237 | def direct() = { 238 | var s = 0 239 | var i = 0 240 | while (i < largeSize) { s = directIntAdder(s, i); i += 1 } 241 | Option(s) 242 | } 243 | def newGeneric() = { 244 | var s = 0 245 | var i = 0 246 | while (i < largeSize) { s = newAdder(s, i); i += 1 } 247 | Option(s) 248 | } 249 | def oldGeneric() = { 250 | var s = 0 251 | var i = 0 252 | while (i < largeSize) { s = oldAdder(s, i); i += 1 } 253 | Option(s) 254 | } 255 | } 256 | 257 | trait BaseAdderLong extends BaseAdder { 258 | def direct() = { 259 | var s = 0L 260 | var i = 0 261 | while (i < largeSize) { s = directLongAdder(s, i); i += 1 } 262 | Option(s) 263 | } 264 | def newGeneric() = { 265 | var s = 0L 266 | var i = 0 267 | while (i < largeSize) { s = newAdder(s, i); i += 1 } 268 | Option(s) 269 | } 270 | def oldGeneric() = { 271 | var s = 0L 272 | var i = 0 273 | while (i < 
largeSize) { s = oldAdder(s, i); i += 1 } 274 | Option(s) 275 | } 276 | } 277 | 278 | trait BaseAdderFloat extends BaseAdder { 279 | def direct() = { 280 | var s = 0.0F 281 | var i = 0 282 | while (i < largeSize) { s = directFloatAdder(s, i); i += 1 } 283 | Option(s) 284 | } 285 | def newGeneric() = { 286 | var s = 0.0F 287 | var i = 0 288 | while (i < largeSize) { s = newAdder(s, i); i += 1 } 289 | Option(s) 290 | } 291 | def oldGeneric() = { 292 | var s = 0.0F 293 | var i = 0 294 | while (i < largeSize) { s = oldAdder(s, i); i += 1 } 295 | Option(s) 296 | } 297 | } 298 | 299 | trait BaseAdderDouble extends BaseAdder { 300 | def direct() = { 301 | var s = 0.0 302 | var i = 0 303 | while (i < largeSize) { s = directDoubleAdder(s, i); i += 1 } 304 | Option(s) 305 | } 306 | def newGeneric() = { 307 | var s = 0.0 308 | var i = 0 309 | while (i < largeSize) { s = newAdder(s, i); i += 1 } 310 | Option(s) 311 | } 312 | def oldGeneric() = { 313 | var s = 0.0 314 | var i = 0 315 | while (i < largeSize) { s = oldAdder(s, i); i += 1 } 316 | Option(s) 317 | } 318 | } 319 | 320 | 321 | // ========================================================= 322 | trait Adder extends BaseAdder { 323 | def newAdder[@specialized A](a:A, b:A)(implicit m:Numeric[A]): A = m.plus(a, b) 324 | def oldAdder[A](a:A, b:A)(implicit m:OldNumeric[A]): A = m.plus(a, b) 325 | } 326 | 327 | final class AdderInt extends Adder with BaseAdderInt { def name = "adder-int" } 328 | final class AdderLong extends Adder with BaseAdderLong { def name = "adder-long" } 329 | final class AdderFloat extends Adder with BaseAdderFloat { def name = "adder-float" } 330 | final class AdderDouble extends Adder with BaseAdderDouble{ def name = "adder-double" } 331 | 332 | 333 | // ===================================================== 334 | trait BaseArrayOps extends TestCase { 335 | def directIntArrayOp(a:Int, b:Int): Int 336 | def directIntArrayOps(a:Array[Int]) = { 337 | var total = 0 338 | var i = 0 339 | while (i < 
a.length) { 340 | total = directIntArrayOp(total, a(i)) 341 | i += 1 342 | } 343 | total 344 | } 345 | 346 | def directLongArrayOp(a:Long, b:Long): Long 347 | def directLongArrayOps(a:Array[Long]) = { 348 | var total = 0L 349 | var i = 0 350 | while (i < a.length) { 351 | total = directLongArrayOp(total, a(i)) 352 | i += 1 353 | } 354 | total 355 | } 356 | 357 | def directFloatArrayOp(a:Float, b:Float): Float 358 | def directFloatArrayOps(a:Array[Float]) = { 359 | var total = 0.0F 360 | var i = 0 361 | while (i < a.length) { 362 | total = directFloatArrayOp(total, a(i)) 363 | i += 1 364 | } 365 | total 366 | } 367 | 368 | def directDoubleArrayOp(a:Double, b:Double): Double 369 | def directDoubleArrayOps(a:Array[Double]) = { 370 | var total = 0.0 371 | var i = 0 372 | while (i < a.length) { 373 | total = directDoubleArrayOp(total, a(i)) 374 | i += 1 375 | } 376 | total 377 | } 378 | 379 | def newArrayOp[@specialized A:Numeric](a:A, b:A): A 380 | def newArrayOps[@specialized A:Numeric:Manifest](a:Array[A]) = { 381 | var total = numeric.zero 382 | var i = 0 383 | while (i < a.length) { 384 | total = newArrayOp(total, a(i)) 385 | i += 1 386 | } 387 | total 388 | } 389 | 390 | def oldArrayOp[A](a:A, b:A)(implicit m:OldNumeric[A]): A 391 | def oldArrayOps[A:OldNumeric:Manifest](a:Array[A]) = { 392 | val m = implicitly[OldNumeric[A]] 393 | var total = m.zero 394 | var i = 0 395 | while (i < a.length) { 396 | total = oldArrayOp(total, a(i)) 397 | i += 1 398 | } 399 | total 400 | } 401 | } 402 | 403 | 404 | trait BaseArrayMapOps extends TestCase { 405 | def directIntArrayOp(a:Int): Int 406 | def directIntArrayOps(a:Array[Int]) = { 407 | var i = 0 408 | while (i < a.length) { 409 | a(i) = directIntArrayOp(a(i)) 410 | i += 1 411 | } 412 | a 413 | } 414 | 415 | def directLongArrayOp(a:Long): Long 416 | def directLongArrayOps(a:Array[Long]) = { 417 | var i = 0 418 | while (i < a.length) { 419 | a(i) = directLongArrayOp(a(i)) 420 | i += 1 421 | } 422 | a 423 | } 424 | 425 | def 
// ======================================================
/**
 * Array-total benchmark: supplies addition as the two-argument combining
 * operation that the BaseArrayOps drivers apply across an array.
 */
trait ArrayAdder extends BaseArrayOps {
  def directIntArrayOp(a:Int, b:Int): Int = a + b
  def directLongArrayOp(a:Long, b:Long): Long = a + b
  def directFloatArrayOp(a:Float, b:Float): Float = a + b
  def directDoubleArrayOp(a:Double, b:Double): Double = a + b

  /** Addition through the new Numeric instance. */
  def newArrayOp[@specialized A:Numeric](a:A, b:A): A = numeric.plus(a, b)

  /** Addition through the old Numeric instance. */
  def oldArrayOp[A](a:A, b:A)(implicit m:OldNumeric[A]): A = m.plus(a, b)
}

/** Totals `largeIntArray` with each of the three implementations. */
final class IntArrayAdder extends ArrayAdder {
  def name = "array-total-int"

  def direct() = {
    val total = directIntArrayOps(largeIntArray)
    Option(total)
  }

  def newGeneric() = {
    val total = newArrayOps(largeIntArray)
    Option(total)
  }

  def oldGeneric() = {
    val total = oldArrayOps(largeIntArray)
    Option(total)
  }
}

/** Totals `largeLongArray` with each of the three implementations. */
final class LongArrayAdder extends ArrayAdder {
  def name = "array-total-long"

  def direct() = {
    val total = directLongArrayOps(largeLongArray)
    Option(total)
  }

  def newGeneric() = {
    val total = newArrayOps(largeLongArray)
    Option(total)
  }

  def oldGeneric() = {
    val total = oldArrayOps(largeLongArray)
    Option(total)
  }
}
/** Totals `largeFloatArray` with each of the three implementations. */
final class FloatArrayAdder extends ArrayAdder {
  def name = "array-total-float"

  def direct() = {
    val total = directFloatArrayOps(largeFloatArray)
    Option(total)
  }

  def newGeneric() = {
    val total = newArrayOps(largeFloatArray)
    Option(total)
  }

  def oldGeneric() = {
    val total = oldArrayOps(largeFloatArray)
    Option(total)
  }
}

/** Totals `largeDoubleArray` with each of the three implementations. */
final class DoubleArrayAdder extends ArrayAdder {
  def name = "array-total-double"

  def direct() = {
    val total = directDoubleArrayOps(largeDoubleArray)
    Option(total)
  }

  def newGeneric() = {
    val total = newArrayOps(largeDoubleArray)
    Option(total)
  }

  def oldGeneric() = {
    val total = oldArrayOps(largeDoubleArray)
    Option(total)
  }
}


// ==========================================================================
/**
 * Array-rescale benchmark: every element is multiplied by 5 and then divided
 * by 3.
 */
trait ArrayRescale extends BaseArrayMapOps {
  def directIntArrayOp(b:Int): Int = (b * 5) / 3
  def directLongArrayOp(b:Long): Long = (b * 5L) / 3L
  def directFloatArrayOp(b:Float): Float = (b * 5.0F) / 3.0F
  def directDoubleArrayOp(b:Double): Double = (b * 5.0) / 3.0

  /** Multiplies by `fromDouble(5.0)` and then divides by `fromDouble(3.0)`. */
  def newArrayOp[@specialized A](b:A)(implicit m:Numeric[A]) = {
    val scaled = m.times(b, m.fromDouble(5.0))
    m.div(scaled, m.fromDouble(3.0))
  }

  /**
   * Goes through Double: the value is converted with `toDouble`, rescaled,
   * truncated with `toInt`, and converted back with `fromInt`.
   */
  def oldArrayOp[A](b:A)(implicit m:OldNumeric[A]) = {
    val rescaled = m.toDouble(b) * 5.0 / 3.0
    m.fromInt(rescaled.toInt)
  }
}

/** Rescales `largeIntArray` with each of the three implementations. */
final class IntArrayRescale extends ArrayRescale {
  def name = "array-rescale-int"

  def direct() = {
    val result = directIntArrayOps(largeIntArray)
    Option(result)
  }

  def newGeneric() = {
    val result = newArrayOps(largeIntArray)
    Option(result)
  }

  def oldGeneric() = {
    val result = oldArrayOps(largeIntArray)
    Option(result)
  }
}

/** Rescales `largeLongArray` with each of the three implementations. */
final class LongArrayRescale extends ArrayRescale {
  def name = "array-rescale-long"

  def direct() = {
    val result = directLongArrayOps(largeLongArray)
    Option(result)
  }

  def newGeneric() = {
    val result = newArrayOps(largeLongArray)
    Option(result)
  }

  def oldGeneric() = {
    val result = oldArrayOps(largeLongArray)
    Option(result)
  }
}

/** Rescales `largeFloatArray` with each of the three implementations. */
final class FloatArrayRescale extends ArrayRescale {
  def name = "array-rescale-float"

  def direct() = {
    val result = directFloatArrayOps(largeFloatArray)
    Option(result)
  }

  def newGeneric() = {
    val result = newArrayOps(largeFloatArray)
    Option(result)
  }

  def oldGeneric() = {
    val result = oldArrayOps(largeFloatArray)
    Option(result)
  }
}
/** Rescales `largeDoubleArray` with each of the three implementations. */
final class DoubleArrayRescale extends ArrayRescale {
  def name = "array-rescale-double"

  def direct() = {
    val result = directDoubleArrayOps(largeDoubleArray)
    Option(result)
  }

  def newGeneric() = {
    val result = newArrayOps(largeDoubleArray)
    Option(result)
  }

  def oldGeneric() = {
    val result = oldArrayOps(largeDoubleArray)
    Option(result)
  }
}



// ==========================================================================
/**
 * Adder benchmark whose generic variants are written with the infix `+`
 * operator rather than an explicit `plus` call.
 */
trait InfixAdder extends BaseAdder {
  /** `a + b` via the new Numeric's infix syntax. */
  def newAdder[@specialized A:Numeric](a:A, b:A): A = a + b

  /** `a + b` via the operators brought into scope by importing the old Numeric instance. */
  def oldAdder[A](a:A, b:A)(implicit m:OldNumeric[A]): A = {
    import m._
    a + b
  }
}

final class InfixAdderInt extends InfixAdder with BaseAdderInt {
  def name = "infix-adder-int"
}

final class InfixAdderLong extends InfixAdder with BaseAdderLong {
  def name = "infix-adder-long"
}

final class InfixAdderFloat extends InfixAdder with BaseAdderFloat {
  def name = "infix-adder-float"
}

final class InfixAdderDouble extends InfixAdder with BaseAdderDouble {
  def name = "infix-adder-double"
}

// ==========================================================
/**
 * Maximum-of-array benchmark.
 *
 * Each method seeds the running maximum with element 0 and scans from index 1,
 * so an empty array fails on the initial `a(0)` access.
 */
trait FindMax extends TestCase {
  def directMaxInt(a:Array[Int]) = {
    var best = a(0)
    var idx = 1
    while (idx < a.length) {
      best = scala.math.max(best, a(idx))
      idx += 1
    }
    best
  }

  def directMaxLong(a:Array[Long]) = {
    var best = a(0)
    var idx = 1
    while (idx < a.length) {
      best = scala.math.max(best, a(idx))
      idx += 1
    }
    best
  }

  def directMaxFloat(a:Array[Float]) = {
    var best = a(0)
    var idx = 1
    while (idx < a.length) {
      best = scala.math.max(best, a(idx))
      idx += 1
    }
    best
  }

  def directMaxDouble(a:Array[Double]) = {
    var best = a(0)
    var idx = 1
    while (idx < a.length) {
      best = scala.math.max(best, a(idx))
      idx += 1
    }
    best
  }

  /** Same scan, with `max` taken from the new Numeric instance. */
  def newGenericMax[@specialized A:Numeric](a:Array[A]) = {
    var best = a(0)
    var idx = 1
    while (idx < a.length) {
      best = numeric.max(best, a(idx))
      idx += 1
    }
    best
  }

  /** Same scan, with `max` taken from the old Numeric instance. */
  def oldGenericMax[A:OldNumeric](a:Array[A]) = {
    val n = implicitly[OldNumeric[A]]
    var best = a(0)
    var idx = 1
    while (idx < a.length) {
      best = n.max(best, a(idx))
      idx += 1
    }
    best
  }
}
/** Finds the maximum of `largeIntArray` with each implementation. */
final class FindMaxInt extends FindMax {
  def name = "find-max-int"

  def direct() = {
    val best = directMaxInt(largeIntArray)
    Some(best)
  }

  def newGeneric() = {
    val best = newGenericMax(largeIntArray)
    Some(best)
  }

  def oldGeneric() = {
    val best = oldGenericMax(largeIntArray)
    Some(best)
  }
}

/** Finds the maximum of `largeLongArray` with each implementation. */
final class FindMaxLong extends FindMax {
  def name = "find-max-long"

  def direct() = {
    val best = directMaxLong(largeLongArray)
    Some(best)
  }

  def newGeneric() = {
    val best = newGenericMax(largeLongArray)
    Some(best)
  }

  def oldGeneric() = {
    val best = oldGenericMax(largeLongArray)
    Some(best)
  }
}

/** Finds the maximum of `largeFloatArray` with each implementation. */
final class FindMaxFloat extends FindMax {
  def name = "find-max-float"

  def direct() = {
    val best = directMaxFloat(largeFloatArray)
    Some(best)
  }

  def newGeneric() = {
    val best = newGenericMax(largeFloatArray)
    Some(best)
  }

  def oldGeneric() = {
    val best = oldGenericMax(largeFloatArray)
    Some(best)
  }
}

/** Finds the maximum of `largeDoubleArray` with each implementation. */
final class FindMaxDouble extends FindMax {
  def name = "find-max-double"

  def direct() = {
    val best = directMaxDouble(largeDoubleArray)
    Some(best)
  }

  def newGeneric() = {
    val best = newGenericMax(largeDoubleArray)
    Some(best)
  }

  def oldGeneric() = {
    val best = oldGenericMax(largeDoubleArray)
    Some(best)
  }
}

// ================================================================
/**
 * Common shape of the sorting benchmarks.
 *
 * Concrete traits provide the six `*Sorter` methods; the helpers below wrap
 * each sorter's result in an `Option` for the benchmark entry points.
 */
trait BaseSort extends TestCase {
  def directIntSorter(a:Array[Int]): Array[Int]
  def directLongSorter(a:Array[Long]): Array[Long]
  def directFloatSorter(a:Array[Float]): Array[Float]
  def directDoubleSorter(a:Array[Double]): Array[Double]

  def newGenericSorter[@specialized A:Numeric:Manifest](a:Array[A]): Array[A]
  def oldGenericSorter[A:OldNumeric:Manifest](a:Array[A]): Array[A]

  def directInt(a:Array[Int]) = {
    val sorted = directIntSorter(a)
    Option(sorted)
  }

  def directLong(a:Array[Long]) = {
    val sorted = directLongSorter(a)
    Option(sorted)
  }

  def directFloat(a:Array[Float]) = {
    val sorted = directFloatSorter(a)
    Option(sorted)
  }

  def directDouble(a:Array[Double]) = {
    val sorted = directDoubleSorter(a)
    Option(sorted)
  }

  def newGenericSort[@specialized A:Numeric:Manifest](a:Array[A]) = {
    val sorted = newGenericSorter(a)
    Option(sorted)
  }

  def oldGenericSort[A:OldNumeric:Manifest](a:Array[A]) = {
    val sorted = oldGenericSorter(a)
    Option(sorted)
  }
}
// =======================================================================
/**
 * Sorting benchmark backed by `scala.util.Sorting.quickSort`.
 *
 * Every sorter clones its input, sorts the clone, and returns the clone; the
 * argument itself is left untouched.
 */
trait Quicksort extends BaseSort {
  def directIntSorter(a:Array[Int]) = {
    val copy = a.clone
    scala.util.Sorting.quickSort(copy)
    copy
  }

  def directLongSorter(a:Array[Long]) = {
    val copy = a.clone
    scala.util.Sorting.quickSort(copy)
    copy
  }

  def directFloatSorter(a:Array[Float]) = {
    val copy = a.clone
    scala.util.Sorting.quickSort(copy)
    copy
  }

  def directDoubleSorter(a:Array[Double]) = {
    val copy = a.clone
    scala.util.Sorting.quickSort(copy)
    copy
  }

  // NOTE: this will perform slowly just because Ordering is not specialized!
  def newGenericSorter[@specialized A:Numeric:Manifest](a:Array[A]) = {
    val copy = a.clone
    // quickSort needs an implicit Ordering; obtain it from the Numeric instance.
    implicit val ordering = implicitly[Numeric[A]].getOrdering()
    scala.util.Sorting.quickSort(copy)
    copy
  }

  def oldGenericSorter[A:OldNumeric:Manifest](a:Array[A]) = {
    val copy = a.clone
    scala.util.Sorting.quickSort(copy)
    copy
  }
}
/** Quicksorts `mediumIntArray` with each implementation. */
final class QuicksortInt extends Quicksort {
  def name = "quicksort-int"
  def direct() = directInt(mediumIntArray)
  def newGeneric() = newGenericSort(mediumIntArray)
  def oldGeneric() = oldGenericSort(mediumIntArray)
}

/** Quicksorts `mediumLongArray` with each implementation. */
final class QuicksortLong extends Quicksort {
  def name = "quicksort-long"
  def direct() = directLong(mediumLongArray)
  def newGeneric() = newGenericSort(mediumLongArray)
  def oldGeneric() = oldGenericSort(mediumLongArray)
}

/** Quicksorts `mediumFloatArray` with each implementation. */
final class QuicksortFloat extends Quicksort {
  def name = "quicksort-float"
  def direct() = directFloat(mediumFloatArray)
  def newGeneric() = newGenericSort(mediumFloatArray)
  def oldGeneric() = oldGenericSort(mediumFloatArray)
}

/** Quicksorts `mediumDoubleArray` with each implementation. */
final class QuicksortDouble extends Quicksort {
  def name = "quicksort-double"
  def direct() = directDouble(mediumDoubleArray)
  def newGeneric() = newGenericSort(mediumDoubleArray)
  def oldGeneric() = oldGenericSort(mediumDoubleArray)
}

// ==========================================
/**
 * Quadratic sorting benchmark.
 *
 * Each sorter clones its input and, for every position `i`, scans the rest of
 * the array for the index `k` of the smallest remaining element, then swaps
 * `a(i)` with `a(k)`. (Algorithmically this is a selection sort; the trait
 * keeps its historical name so existing references still resolve.)
 *
 * The scan must compare each candidate against the current minimum `a(k)`.
 * Comparing against `a(i)` would merely remember the last element smaller
 * than `a(i)` and leave the array unsorted (e.g. [3, 1, 2] -> [2, 1, 3]).
 */
trait InsertionSort extends BaseSort {
  /** Returns a sorted copy of `b`. */
  def directIntSorter(b:Array[Int]) = {
    val a = b.clone
    var i = 0
    while (i < a.length - 1) {
      var j = i + 1
      var k = i // index of the smallest element found so far in a(i..)
      while (j < a.length) {
        if (a(j) < a(k)) k = j
        j += 1
      }
      val temp = a(i)
      a(i) = a(k)
      a(k) = temp
      i += 1
    }
    a
  }

  /** Returns a sorted copy of `b`. */
  def directLongSorter(b:Array[Long]) = {
    val a = b.clone
    var i = 0
    while (i < a.length - 1) {
      var j = i + 1
      var k = i // index of the smallest element found so far in a(i..)
      while (j < a.length) {
        if (a(j) < a(k)) k = j
        j += 1
      }
      val temp = a(i)
      a(i) = a(k)
      a(k) = temp
      i += 1
    }
    a
  }

  /** Returns a sorted copy of `b`. */
  def directFloatSorter(b:Array[Float]) = {
    val a = b.clone
    var i = 0
    while (i < a.length - 1) {
      var j = i + 1
      var k = i // index of the smallest element found so far in a(i..)
      while (j < a.length) {
        if (a(j) < a(k)) k = j
        j += 1
      }
      val temp = a(i)
      a(i) = a(k)
      a(k) = temp
      i += 1
    }
    a
  }

  /** Returns a sorted copy of `b`. */
  def directDoubleSorter(b:Array[Double]) = {
    val a = b.clone
    var i = 0
    while (i < a.length - 1) {
      var j = i + 1
      var k = i // index of the smallest element found so far in a(i..)
      while (j < a.length) {
        if (a(j) < a(k)) k = j
        j += 1
      }
      val temp = a(i)
      a(i) = a(k)
      a(k) = temp
      i += 1
    }
    a
  }

  /** Returns a sorted copy of `b`, comparing through the new Numeric instance. */
  def newGenericSorter[@specialized A:Numeric:Manifest](b:Array[A]) = {
    val a = b.clone
    var i = 0
    while (i < a.length - 1) {
      var j = i + 1
      var k = i // index of the smallest element found so far in a(i..)
      while (j < a.length) {
        if (numeric.lt(a(j), a(k))) k = j
        j += 1
      }
      val temp = a(i)
      a(i) = a(k)
      a(k) = temp
      i += 1
    }
    a
  }

  /** Returns a sorted copy of `b`, comparing through the old Numeric instance. */
  def oldGenericSorter[A:OldNumeric:Manifest](b:Array[A]) = {
    val n = implicitly[OldNumeric[A]]
    val a = b.clone
    var i = 0
    while (i < a.length - 1) {
      var j = i + 1
      var k = i // index of the smallest element found so far in a(i..)
      while (j < a.length) {
        if (n.lt(a(j), a(k))) k = j
        j += 1
      }
      val temp = a(i)
      a(i) = a(k)
      a(k) = temp
      i += 1
    }
    a
  }
}
/** Runs the InsertionSort sorters over `smallIntArray`. */
final class InsertionSortInt extends InsertionSort {
  def name = "insertion-sort-int"

  def direct() = {
    directInt(smallIntArray)
  }

  def newGeneric() = {
    newGenericSort(smallIntArray)
  }

  def oldGeneric() = {
    oldGenericSort(smallIntArray)
  }
}

/** Runs the InsertionSort sorters over `smallLongArray`. */
final class InsertionSortLong extends InsertionSort {
  def name = "insertion-sort-long"

  def direct() = {
    directLong(smallLongArray)
  }

  def newGeneric() = {
    newGenericSort(smallLongArray)
  }

  def oldGeneric() = {
    oldGenericSort(smallLongArray)
  }
}
/** Runs the InsertionSort sorters over `smallFloatArray`. */
final class InsertionSortFloat extends InsertionSort {
  def name = "insertion-sort-float"

  def direct() = {
    directFloat(smallFloatArray)
  }

  def newGeneric() = {
    newGenericSort(smallFloatArray)
  }

  def oldGeneric() = {
    oldGenericSort(smallFloatArray)
  }
}

/** Runs the InsertionSort sorters over `smallDoubleArray`. */
final class InsertionSortDouble extends InsertionSort {
  def name = "insertion-sort-double"

  def direct() = {
    directDouble(smallDoubleArray)
  }

  def newGeneric() = {
    newGenericSort(smallDoubleArray)
  }

  def oldGeneric() = {
    oldGenericSort(smallDoubleArray)
  }
}


// ========================================================
/**
 * Allocation benchmark: builds an outer array of `num` rows, each row being a
 * fresh `dim`-element array filled with `const`.
 */
trait ArrayAllocator extends TestCase {
  def directIntAllocator(num:Int, dim:Int, const:Int) = {
    val rows = Array.ofDim[Array[Int]](num)
    var r = 0
    while (r < num) {
      rows(r) = Array.fill(dim)(const)
      r += 1
    }
    rows
  }

  def directLongAllocator(num:Int, dim:Int, const:Long) = {
    val rows = Array.ofDim[Array[Long]](num)
    var r = 0
    while (r < num) {
      rows(r) = Array.fill(dim)(const)
      r += 1
    }
    rows
  }

  def directFloatAllocator(num:Int, dim:Int, const:Float) = {
    val rows = Array.ofDim[Array[Float]](num)
    var r = 0
    while (r < num) {
      rows(r) = Array.fill(dim)(const)
      r += 1
    }
    rows
  }

  def directDoubleAllocator(num:Int, dim:Int, const:Double) = {
    val rows = Array.ofDim[Array[Double]](num)
    var r = 0
    while (r < num) {
      rows(r) = Array.fill(dim)(const)
      r += 1
    }
    rows
  }

  /** Same allocation, generic over the new Numeric. */
  def newAllocator[@specialized A:Numeric:Manifest](num:Int, dim:Int, const:A) = {
    val rows = Array.ofDim[Array[A]](num)
    var r = 0
    while (r < num) {
      rows(r) = Array.fill(dim)(const)
      r += 1
    }
    rows
  }

  /** Same allocation, generic over the old Numeric. */
  def oldAllocator[A:OldNumeric:Manifest](num:Int, dim:Int, const:A) = {
    val rows = Array.ofDim[Array[A]](num)
    var r = 0
    while (r < num) {
      rows(r) = Array.fill(dim)(const)
      r += 1
    }
    rows
  }
}
/** Allocates `mediumSize` rows of five Ints, each set to 13. */
final class ArrayAllocatorInt extends ArrayAllocator {
  def name = "array-allocator-int"

  def direct = {
    val rows = directIntAllocator(mediumSize, 5, 13)
    Option(rows)
  }

  def newGeneric = {
    val rows = newAllocator(mediumSize, 5, 13)
    Option(rows)
  }

  def oldGeneric = {
    val rows = oldAllocator(mediumSize, 5, 13)
    Option(rows)
  }
}

/** Allocates `mediumSize` rows of five Longs, each set to 13L. */
final class ArrayAllocatorLong extends ArrayAllocator {
  def name = "array-allocator-long"

  def direct = {
    val rows = directLongAllocator(mediumSize, 5, 13L)
    Option(rows)
  }

  def newGeneric = {
    val rows = newAllocator(mediumSize, 5, 13L)
    Option(rows)
  }

  def oldGeneric = {
    val rows = oldAllocator(mediumSize, 5, 13L)
    Option(rows)
  }
}

/** Allocates `mediumSize` rows of five Floats, each set to 13.0F. */
final class ArrayAllocatorFloat extends ArrayAllocator {
  def name = "array-allocator-float"

  def direct = {
    val rows = directFloatAllocator(mediumSize, 5, 13.0F)
    Option(rows)
  }

  def newGeneric = {
    val rows = newAllocator(mediumSize, 5, 13.0F)
    Option(rows)
  }

  def oldGeneric = {
    val rows = oldAllocator(mediumSize, 5, 13.0F)
    Option(rows)
  }
}

/** Allocates `mediumSize` rows of five Doubles, each set to 13.0. */
final class ArrayAllocatorDouble extends ArrayAllocator {
  def name = "array-allocator-double"

  def direct = {
    val rows = directDoubleAllocator(mediumSize, 5, 13.0)
    Option(rows)
  }

  def newGeneric = {
    val rows = newAllocator(mediumSize, 5, 13.0)
    Option(rows)
  }

  def oldGeneric = {
    val rows = oldAllocator(mediumSize, 5, 13.0)
    Option(rows)
  }
}

// =================================================================
/**
 * Top-down merge sort benchmark.
 *
 * Every sorter works IN PLACE on its argument: the two halves are copied into
 * fresh arrays, sorted recursively, and merged back into `a`, which is then
 * returned. On equal keys the element from the right half is emitted first.
 */
trait MergeSort extends BaseSort {
  def directIntSorter(a:Array[Int]): Array[Int] = {
    if (a.length > 1) {
      val llen = a.length / 2
      val rlen = a.length - llen

      val left = Array.ofDim[Int](llen)
      Array.copy(a, 0, left, 0, llen)
      directIntSorter(left)

      val right = Array.ofDim[Int](rlen)
      Array.copy(a, llen, right, 0, rlen)
      directIntSorter(right)

      var li = 0
      var ri = 0
      var out = 0
      while (li < llen || ri < rlen) {
        // Take from the left when the right is exhausted, or when both halves
        // still have elements and the left head is strictly smaller.
        val takeLeft = ri == rlen || (li != llen && left(li) < right(ri))
        if (takeLeft) {
          a(out) = left(li)
          li += 1
        } else {
          a(out) = right(ri)
          ri += 1
        }
        out += 1
      }
    }
    a
  }

  def directLongSorter(a:Array[Long]): Array[Long] = {
    if (a.length > 1) {
      val llen = a.length / 2
      val rlen = a.length - llen

      val left = Array.ofDim[Long](llen)
      Array.copy(a, 0, left, 0, llen)
      directLongSorter(left)

      val right = Array.ofDim[Long](rlen)
      Array.copy(a, llen, right, 0, rlen)
      directLongSorter(right)

      var li = 0
      var ri = 0
      var out = 0
      while (li < llen || ri < rlen) {
        val takeLeft = ri == rlen || (li != llen && left(li) < right(ri))
        if (takeLeft) {
          a(out) = left(li)
          li += 1
        } else {
          a(out) = right(ri)
          ri += 1
        }
        out += 1
      }
    }
    a
  }

  def directFloatSorter(a:Array[Float]): Array[Float] = {
    if (a.length > 1) {
      val llen = a.length / 2
      val rlen = a.length - llen

      val left = Array.ofDim[Float](llen)
      Array.copy(a, 0, left, 0, llen)
      directFloatSorter(left)

      val right = Array.ofDim[Float](rlen)
      Array.copy(a, llen, right, 0, rlen)
      directFloatSorter(right)

      var li = 0
      var ri = 0
      var out = 0
      while (li < llen || ri < rlen) {
        val takeLeft = ri == rlen || (li != llen && left(li) < right(ri))
        if (takeLeft) {
          a(out) = left(li)
          li += 1
        } else {
          a(out) = right(ri)
          ri += 1
        }
        out += 1
      }
    }
    a
  }

  def directDoubleSorter(a:Array[Double]): Array[Double] = {
    if (a.length > 1) {
      val llen = a.length / 2
      val rlen = a.length - llen

      val left = Array.ofDim[Double](llen)
      Array.copy(a, 0, left, 0, llen)
      directDoubleSorter(left)

      val right = Array.ofDim[Double](rlen)
      Array.copy(a, llen, right, 0, rlen)
      directDoubleSorter(right)

      var li = 0
      var ri = 0
      var out = 0
      while (li < llen || ri < rlen) {
        val takeLeft = ri == rlen || (li != llen && left(li) < right(ri))
        if (takeLeft) {
          a(out) = left(li)
          li += 1
        } else {
          a(out) = right(ri)
          ri += 1
        }
        out += 1
      }
    }
    a
  }

  /** In-place merge sort comparing through the new Numeric instance. */
  def newGenericSorter[@specialized A:Numeric:Manifest](a:Array[A]):Array[A] = {
    if (a.length > 1) {
      val llen = a.length / 2
      val rlen = a.length - llen

      val left = Array.ofDim[A](llen)
      Array.copy(a, 0, left, 0, llen)
      newGenericSorter(left)

      val right = Array.ofDim[A](rlen)
      Array.copy(a, llen, right, 0, rlen)
      newGenericSorter(right)

      var li = 0
      var ri = 0
      var out = 0
      while (li < llen || ri < rlen) {
        val takeLeft = ri == rlen || (li != llen && numeric.lt(left(li), right(ri)))
        if (takeLeft) {
          a(out) = left(li)
          li += 1
        } else {
          a(out) = right(ri)
          ri += 1
        }
        out += 1
      }
    }
    a
  }

  /** In-place merge sort comparing through the old Numeric instance. */
  def oldGenericSorter[A:OldNumeric:Manifest](a:Array[A]):Array[A] = {
    if (a.length > 1) {
      val n = implicitly[OldNumeric[A]]

      val llen = a.length / 2
      val rlen = a.length - llen

      val left = Array.ofDim[A](llen)
      Array.copy(a, 0, left, 0, llen)
      oldGenericSorter(left)

      val right = Array.ofDim[A](rlen)
      Array.copy(a, llen, right, 0, rlen)
      oldGenericSorter(right)

      var li = 0
      var ri = 0
      var out = 0
      while (li < llen || ri < rlen) {
        val takeLeft = ri == rlen || (li != llen && n.lt(left(li), right(ri)))
        if (takeLeft) {
          a(out) = left(li)
          li += 1
        } else {
          a(out) = right(ri)
          ri += 1
        }
        out += 1
      }
    }
    a
  }
}
// The MergeSort sorters reorder their argument in place, so every variant
// below sorts a private clone of the fixture array. Without the clone the
// fixture itself would be sorted by the first run, and later variants would
// be handed already-sorted input.

/** Merge-sorts a clone of `mediumIntArray` with each implementation. */
final class MergeSortInt extends MergeSort {
  def name = "merge-sort-int"
  def direct() = { val a = mediumIntArray.clone; directInt(a); Some(a) }
  def newGeneric() = { val a = mediumIntArray.clone; newGenericSort(a); Some(a) }
  def oldGeneric() = { val a = mediumIntArray.clone; oldGenericSort(a); Some(a) }
}

/** Merge-sorts a clone of `mediumLongArray` with each implementation. */
final class MergeSortLong extends MergeSort {
  def name = "merge-sort-long"
  def direct() = { val a = mediumLongArray.clone; directLong(a); Some(a) }
  def newGeneric() = { val a = mediumLongArray.clone; newGenericSort(a); Some(a) }
  def oldGeneric() = { val a = mediumLongArray.clone; oldGenericSort(a); Some(a) }
}

/** Merge-sorts a clone of `mediumFloatArray` with each implementation. */
final class MergeSortFloat extends MergeSort {
  def name = "merge-sort-float"
  def direct() = { val a = mediumFloatArray.clone; directFloat(a); Some(a) }
  def newGeneric() = { val a = mediumFloatArray.clone; newGenericSort(a); Some(a) }
  def oldGeneric() = { val a = mediumFloatArray.clone; oldGenericSort(a); Some(a) }
}

/** Merge-sorts a clone of `mediumDoubleArray` with each implementation. */
final class MergeSortDouble extends MergeSort {
  def name = "merge-sort-double"
  def direct() = { val a = mediumDoubleArray.clone; directDouble(a); Some(a) }
  def newGeneric() = { val a = mediumDoubleArray.clone; newGenericSort(a); Some(a) }
  def oldGeneric() = { val a = mediumDoubleArray.clone; oldGenericSort(a); Some(a) }
}


// =================================================================
/**
 * Repeated-increment benchmark, variant 1: the new-Numeric path spells the
 * addition as `numeric.plus(a, numeric.fromInt(100))`.
 *
 * Each entry point starts from 0 and applies its increment `largeSize` times.
 */
final class IncrementInt1 extends TestCase {
  def name = "increment-int1"

  def directIncrement(x:Int) = x + 100
  def direct() = {
    var i = 0
    var total = 0
    while (i < largeSize) {
      total = directIncrement(total)
      i += 1
    }
    Some(total)
  }

  def newIncrement[@specialized A:Numeric](a:A):A = numeric.plus(a, numeric.fromInt(100))
  def newGeneric() = {
    var i = 0
    var total = 0
    while (i < largeSize) {
      total = newIncrement(total)
      i += 1
    }
    Some(total)
  }

  def oldIncrement[A](a:A)(implicit n:OldNumeric[A]) = n.plus(a, n.fromInt(100))
  def oldGeneric() = {
    var i = 0
    var total = 0
    while (i < largeSize) {
      total = oldIncrement(total)
      i += 1
    }
    Some(total)
  }
}
/**
 * Repeated-increment benchmark, variant 2: the new-Numeric path uses the infix
 * operator with an explicitly converted constant, `a + numeric.fromInt(100)`.
 *
 * Each entry point starts from 0 and applies its increment `largeSize` times.
 */
final class IncrementInt2 extends TestCase {
  def name = "increment-int2"

  def directIncrement(x:Int) = x + 100

  def direct() = {
    var acc = 0
    var step = 0
    while (step < largeSize) {
      acc = directIncrement(acc)
      step += 1
    }
    Some(acc)
  }

  def newIncrement[@specialized A:Numeric](a:A):A = a + numeric.fromInt(100)

  def newGeneric() = {
    var acc = 0
    var step = 0
    while (step < largeSize) {
      acc = newIncrement(acc)
      step += 1
    }
    Some(acc)
  }

  def oldIncrement[A](a:A)(implicit n:OldNumeric[A]) = n.plus(a, n.fromInt(100))

  def oldGeneric() = {
    var acc = 0
    var step = 0
    while (step < largeSize) {
      acc = oldIncrement(acc)
      step += 1
    }
    Some(acc)
  }
}

/**
 * Repeated-increment benchmark, variant 3: the new-Numeric path adds a bare
 * Int literal on the right-hand side, `a + 100`.
 *
 * Each entry point starts from 0 and applies its increment `largeSize` times.
 */
final class IncrementInt3 extends TestCase {
  def name = "increment-int3"

  def directIncrement(x:Int) = x + 100

  def direct() = {
    var acc = 0
    var step = 0
    while (step < largeSize) {
      acc = directIncrement(acc)
      step += 1
    }
    Some(acc)
  }

  def newIncrement[@specialized A:Numeric](a:A):A = a + 100

  def newGeneric() = {
    var acc = 0
    var step = 0
    while (step < largeSize) {
      acc = newIncrement(acc)
      step += 1
    }
    Some(acc)
  }

  def oldIncrement[A](a:A)(implicit n:OldNumeric[A]) = n.plus(a, n.fromInt(100))

  def oldGeneric() = {
    var acc = 0
    var step = 0
    while (step < largeSize) {
      acc = oldIncrement(acc)
      step += 1
    }
    Some(acc)
  }
}
newIncrement(total) 1224 | i += 1 1225 | } 1226 | Some(total) 1227 | } 1228 | 1229 | def oldIncrement[A](a:A)(implicit n:OldNumeric[A]) = n.plus(a, n.fromInt(100)) 1230 | def oldGeneric() = { 1231 | var i = 0 1232 | var total = 0 1233 | while (i < largeSize) { 1234 | total = oldIncrement(total) 1235 | i += 1 1236 | } 1237 | Some(total) 1238 | } 1239 | } 1240 | 1241 | final class IncrementInt4 extends TestCase { 1242 | def name = "increment-int4" 1243 | 1244 | def directIncrement(x:Int) = x + 100 1245 | def direct() = { 1246 | var i = 0 1247 | var total = 0 1248 | while (i < largeSize) { 1249 | total = directIncrement(total) 1250 | i += 1 1251 | } 1252 | Some(total) 1253 | } 1254 | 1255 | def newIncrement[@specialized A:Numeric](a:A):A = 100 + a 1256 | def newGeneric() = { 1257 | var i = 0 1258 | var total = 0 1259 | while (i < largeSize) { 1260 | total = newIncrement(total) 1261 | i += 1 1262 | } 1263 | Some(total) 1264 | } 1265 | 1266 | def oldIncrement[A](a:A)(implicit n:OldNumeric[A]) = n.plus(a, n.fromInt(100)) 1267 | def oldGeneric() = { 1268 | var i = 0 1269 | var total = 0 1270 | while (i < largeSize) { 1271 | total = oldIncrement(total) 1272 | i += 1 1273 | } 1274 | Some(total) 1275 | } 1276 | } 1277 | 1278 | 1279 | object Main { 1280 | val tests = List(List(new FromIntToInt, 1281 | new FromIntToLong, 1282 | new FromIntToFloat, 1283 | new FromIntToDouble), 1284 | 1285 | List(new InfixAdderInt, 1286 | new InfixAdderLong, 1287 | new InfixAdderFloat, 1288 | new InfixAdderDouble), 1289 | 1290 | List(new AdderInt, 1291 | new AdderLong, 1292 | new AdderFloat, 1293 | new AdderDouble), 1294 | 1295 | List(new IntArrayAdder, 1296 | new LongArrayAdder, 1297 | new FloatArrayAdder, 1298 | new DoubleArrayAdder), 1299 | 1300 | List(new IntArrayRescale, 1301 | new LongArrayRescale, 1302 | new FloatArrayRescale, 1303 | new DoubleArrayRescale), 1304 | 1305 | List(new ArrayAllocatorInt, 1306 | new ArrayAllocatorLong, 1307 | new ArrayAllocatorFloat, 1308 | new 
ArrayAllocatorDouble), 1309 | 1310 | List(new FindMaxInt, 1311 | new FindMaxLong, 1312 | new FindMaxFloat, 1313 | new FindMaxDouble), 1314 | 1315 | List(new QuicksortInt, 1316 | new QuicksortLong, 1317 | new QuicksortFloat, 1318 | new QuicksortDouble), 1319 | 1320 | List(new InsertionSortInt, 1321 | new InsertionSortLong, 1322 | new InsertionSortFloat, 1323 | new InsertionSortDouble), 1324 | 1325 | List(new MergeSortInt, 1326 | new MergeSortLong, 1327 | new MergeSortFloat, 1328 | new MergeSortDouble), 1329 | 1330 | List(new IncrementInt1, 1331 | new IncrementInt2, 1332 | new IncrementInt3, 1333 | new IncrementInt4)) 1334 | 1335 | 1336 | def getHTMLHeader() = """ 1337 | 1338 | 1339 | 1354 | 1355 | 1356 | 1357 | 1358 | 1359 | 1360 | 1361 | """ 1362 | 1363 | def getHTMLFooter() = "
testdirect (ms)new (ms)old (ms)
\n \n\n" 1364 | 1365 | def main(args:Array[String]): Unit = { 1366 | if (Constant.createHTML) { 1367 | println("creating benchmark.html...") 1368 | printf("%-24s %8s %8s %8s / %6s %6s %6s\n", "test", "direct", "new", "old", "n:d", "o:d", "o:n") 1369 | 1370 | val p = new PrintWriter(new FileWriter("benchmark.html")) 1371 | 1372 | p.println(getHTMLHeader()) 1373 | tests.foreach { 1374 | group => { 1375 | p.println("\n") 1376 | group.foreach(_.test(Some(p))) 1377 | } 1378 | } 1379 | p.println(getHTMLFooter()) 1380 | p.close() 1381 | 1382 | } else { 1383 | printf("%-24s %8s %8s %8s / %6s %6s %6s\n", "test", "direct", "new", "old", "n:d", "o:d", "o:n") 1384 | tests.foreach { 1385 | group => group.foreach(_.test(None)) 1386 | } 1387 | } 1388 | } 1389 | } 1390 | -------------------------------------------------------------------------------- /plugin/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: default clean plugin post tree browse base test run runbase runtest all 2 | 3 | # The plugin's source file and JAR to create 4 | # PLUGIN=OptimizedNumeric.scala 5 | PLUGIN=src/main/scala/com/azavea/math/plugin/OptimizedNumeric.scala 6 | PLUGJAR=optimized-numeric.jar 7 | PLUGXML=src/main/resources/scalac-plugin.xml 8 | 9 | # The phase to run before inspecting the AST 10 | PHASE=typer 11 | PHASE2=optimized-numeric 12 | 13 | # The test's source file, and class name 14 | # TEST=Test.scala 15 | # TESTBASE=Test 16 | TEST=src/test/scala/Example.scala 17 | TESTBASE=Example 18 | 19 | # The classpath to use when building/running the test 20 | ## CP=lib/numeric_2.9.1-0.1.jar 21 | CP=../target/scala-2.9.1.final/classes 22 | 23 | # Some help output 24 | default: 25 | @echo "targets: clean | plugin | post | tree | base | test | run" 26 | 27 | # Remove classfiles and the plugin's jar file 28 | clean: 29 | rm -rf classes 30 | find . 
-name '*.class' -exec rm -f {} \; 31 | rm -f $(PLUGJAR) 32 | 33 | # Build the plugin 34 | plugin: 35 | mkdir -p classes 36 | scalac -cp $(CP) -d classes $(PLUGIN) 37 | cp $(PLUGXML) classes 38 | jar -C classes -c -f $(PLUGJAR) . 39 | 40 | # Various targets for inspecting the AST: 41 | # post shows the AST as source code 42 | # tree prints the actual AST 43 | # browse runs a graphical AST browser 44 | post: 45 | scalac -cp $(CP) -Xprint:$(PHASE) -Ystop-after:$(PHASE) $(TEST) | tee POST 46 | 47 | tree: 48 | scalac -cp $(CP) -Xprint:$(PHASE) -Ystop-after:$(PHASE) -Yshow-trees $(TEST) | tee TREE 49 | 50 | browse: 51 | scalac -cp $(CP) -Ybrowse:$(PHASE) -Ystop-after:$(PHASE) -Yshow-trees $(TEST) 52 | 53 | postplug: 54 | scalac -cp $(CP) -Xplugin:$(PLUGJAR) -Xprint:$(PHASE2) -Ystop-after:$(PHASE2) $(TEST) | tee POST 55 | 56 | # Compile the test without the plugin (the base case) 57 | base: 58 | scalac -cp $(CP) $(TEST) 59 | 60 | # Compile the test with the plugin (the test case) 61 | test: 62 | scalac -cp $(CP) -Xplugin:$(PLUGJAR) $(TEST) 63 | 64 | # Run the test 65 | run: 66 | scala -cp $(CP) $(TESTBASE) 67 | 68 | # Compile and run the base case 69 | runbase: base run 70 | 71 | # Compile and run the test case 72 | runtest: test run 73 | 74 | # Do a clean build of the plugin and test case, then run the test 75 | all: clean plugin test run 76 | -------------------------------------------------------------------------------- /plugin/build.sbt: -------------------------------------------------------------------------------- 1 | // project name 2 | name := "Optimized Numeric Plugin" 3 | 4 | //sbtPlugin := true 5 | 6 | libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-compiler" % _) 7 | 8 | resolvers += "Scala-Tools Maven2 Snapshots Repository" at "http://scala-tools.org/repo-snapshots" 9 | 10 | // shrug? 
11 | version := "0.1" 12 | 13 | // hide backup files 14 | defaultExcludes ~= (filter => filter || "*~") 15 | 16 | scalacOptions += "-optimise" 17 | 18 | // any of these work, although 2.9.1 performs the best 19 | //scalaVersion := "2.8.1" 20 | //scalaVersion := "2.9.0-1" 21 | scalaVersion := "2.9.1" 22 | 23 | //crossScalaVersions := List("2.8.1", "2.9.0-1", "2.9.1") 24 | -------------------------------------------------------------------------------- /plugin/src/main/resources/scalac-plugin.xml: -------------------------------------------------------------------------------- 1 | 2 | optimized-numeric 3 | com.azavea.math.plugin.OptimizedNumeric 4 | 5 | -------------------------------------------------------------------------------- /plugin/src/main/scala/com/azavea/math/plugin/OptimizedNumeric.scala: -------------------------------------------------------------------------------- 1 | package com.azavea.math.plugin 2 | 3 | import scala.tools.nsc 4 | import nsc.Global 5 | import nsc.Phase 6 | import nsc.plugins.Plugin 7 | import nsc.plugins.PluginComponent 8 | import nsc.transform.Transform 9 | import nsc.transform.InfoTransform 10 | import nsc.transform.TypingTransformers 11 | import nsc.symtab.Flags._ 12 | import nsc.ast.TreeDSL 13 | import nsc.typechecker 14 | 15 | /** 16 | * Our shiny compiler plugin. 17 | */ 18 | class OptimizedNumeric(val global: Global) extends Plugin { 19 | val name = "optimized-numeric" 20 | val description = "Optimizes com.azavea.math.Numeric usage." 21 | val components = List[PluginComponent](new RewriteInfixOps(this, global)) 22 | } 23 | 24 | /** 25 | * This component turns things like: 26 | * 1. new FastNumericOps[T](m)(implicit ev).+(n) 27 | * 2. com.azavea.math.FastImplicits.infixOps[T](m)(implicit ev).*(n) 28 | * 29 | * Into: 30 | * 1. ev.plus(m, n) 31 | * 2. 
ev.times(m, n) 32 | */ 33 | class RewriteInfixOps(plugin:Plugin, val global:Global) extends PluginComponent 34 | with Transform with TypingTransformers with TreeDSL { 35 | import global._ 36 | import typer.typed 37 | 38 | // set to true to print a warning for each transform 39 | val debugging = false 40 | 41 | // TODO: maybe look up the definition of op and automatically figure mapping 42 | val unops = Map( 43 | newTermName("abs") -> "abs", 44 | newTermName("unary_$minus") -> "negate", 45 | newTermName("signum") -> "signum" 46 | ) 47 | 48 | val binops = Map( 49 | newTermName("compare") -> "compare", 50 | newTermName("equiv") -> "equiv", 51 | newTermName("max") -> "max", 52 | newTermName("min") -> "min", 53 | 54 | newTermName("$less$eq$greater") -> "compare", 55 | newTermName("$div") -> "div", 56 | newTermName("$eq$eq$eq") -> "equiv", 57 | newTermName("$bang$eq$eq") -> "nequiv", 58 | newTermName("$greater") -> "gt", 59 | newTermName("$greater$eq") -> "gteq", 60 | newTermName("$less") -> "lt", 61 | newTermName("$less$eq") -> "lteq", 62 | newTermName("$minus") -> "minus", 63 | newTermName("$percent") -> "mod", 64 | newTermName("$plus") -> "plus", 65 | newTermName("$times") -> "times", 66 | newTermName("$times$times") -> "pow" 67 | ) 68 | 69 | val runsAfter = List("typer"); 70 | val phaseName = "optimized-numeric" 71 | def newTransformer(unit:CompilationUnit) = new MyTransformer(unit) 72 | 73 | // Determine if two type are equivalent 74 | def equivalentTypes(t1:Type, t2:Type) = { 75 | t1.dealias.deconst.widen =:= t2.dealias.deconst.widen 76 | } 77 | 78 | // TODO: figure out better type matching for Numeric, e.g. 
a.tpe <:< b.tpe 79 | val numericClass = definitions.getClass("com.azavea.math.Numeric") 80 | def isNumeric(t:Type) = t.typeSymbol == numericClass.tpe.typeSymbol 81 | 82 | // For built-in types, figure out whether or not we have a "fast" conversion method 83 | val BigIntClass = definitions.getClass("scala.math.BigInt") 84 | val BigDecimalClass = definitions.getClass("scala.math.BigDecimal") 85 | def getConverter(t:Type) = if (t <:< definitions.ByteClass.tpe) { 86 | Some("fromByte") 87 | } else if (t <:< definitions.ShortClass.tpe) { 88 | Some("fromShort") 89 | } else if (t <:< definitions.IntClass.tpe) { 90 | Some("fromInt") 91 | } else if (t <:< definitions.LongClass.tpe) { 92 | Some("fromLong") 93 | } else if (t <:< definitions.FloatClass.tpe) { 94 | Some("fromFloat") 95 | } else if (t <:< definitions.DoubleClass.tpe) { 96 | Some("fromDouble") 97 | } else if (t <:< BigIntClass.tpe) { 98 | Some("fromBigInt") 99 | } else if (t <:< BigDecimalClass.tpe) { 100 | Some("fromBigDecimal") 101 | } else { 102 | None 103 | } 104 | 105 | // TODO: maybe match further out on the implicit Numeric[T]? 
106 | class MyTransformer(unit:CompilationUnit) extends TypingTransformer(unit) { 107 | 108 | override def transform(tree: Tree): Tree = { 109 | //def mylog(s:String) = if (debugging) unit.warning(tree.pos, s) 110 | def mylog(s:String) = Unit 111 | 112 | val tree2 = tree match { 113 | 114 | // match fuzzy binary operators 115 | case Apply(Apply(TypeApply(Select(Apply(Apply(_, List(m)), List(ev)), op), List(tt)), List(n)), List(ev2)) => { 116 | if (!isNumeric(ev.tpe)) { 117 | //mylog("fuzzy alarm #1") 118 | tree 119 | 120 | } else if (binops.contains(op)) { 121 | val op2 = binops(op) 122 | val conv = getConverter(n.tpe) 123 | conv match { 124 | case Some(meth) => { 125 | //mylog("fuzzy transformed %s (with %s)".format(op, meth)) 126 | typed { Apply(Select(ev, op2), List(m, Apply(Select(ev, meth), List(n)))) } 127 | } 128 | case None => if (equivalentTypes(m.tpe, n.tpe)) { 129 | //mylog("fuzzy transformed %s (removed conversion)".format(op)) 130 | typed { Apply(Select(ev, op2), List(m, n)) } 131 | } else { 132 | //mylog("fuzzy transformed %s".format(op)) 133 | typed { Apply(Select(ev, op2), List(m, Apply(TypeApply(Select(ev, "fromType"), List(tt)), List(n)))) } 134 | } 135 | } 136 | 137 | } else { 138 | //mylog("fuzzy alarm #2") 139 | tree 140 | } 141 | } 142 | 143 | // match IntOps (and friends Float, Long, etc.) 
144 | case Apply(Apply(TypeApply(Select(Apply(_, List(m)), op), List(tt)), List(n)), List(ev)) => { 145 | if (!isNumeric(ev.tpe)) { 146 | //mylog("literal ops alarm #1") 147 | tree 148 | 149 | } else if (binops.contains(op)) { 150 | val op2 = binops(op) 151 | val conv = getConverter(m.tpe) 152 | conv match { 153 | case Some(meth) => { 154 | //mylog("zzz literal ops transformed %s (with %s)".format(op, meth)) 155 | typed { Apply(Select(ev, op2), List(Apply(Select(ev, meth), List(m)), n)) } 156 | } 157 | case None => { 158 | //mylog("zzz literal ops transformed %s".format(op)) 159 | typed { Apply(Select(ev, op2), List(Apply(TypeApply(Select(ev, "fromType"), List(tt)), List(m)), n)) } 160 | } 161 | } 162 | 163 | } else { 164 | //mylog("literal ops alarm #2") 165 | tree 166 | } 167 | } 168 | 169 | // match binary operators 170 | case Apply(Select(Apply(Apply(_, List(m)), List(ev)), op), List(n)) => { 171 | if (!isNumeric(ev.tpe)) { 172 | unit.warning(tree.pos, "binop false alarm #1") 173 | tree 174 | } else if (binops.contains(op)) { 175 | val op2 = binops(op) 176 | //mylog("binop rewrote %s %s %s to n.%s(%s, %s)".format(m, op, n, op2, m, n)) 177 | typed { Apply(Select(ev, op2), List(m, n)) } 178 | } else { 179 | unit.warning(tree.pos, "binop false alarm #2") 180 | tree 181 | } 182 | } 183 | 184 | // match unary operators 185 | case Select(Apply(Apply(_, List(m)), List(ev)), op) => { 186 | if (!isNumeric(ev.tpe)) { 187 | unit.warning(tree.pos, "unop false alarm #1") 188 | tree 189 | } else if (unops.contains(op)) { 190 | val op2 = unops(op) 191 | //mylog("unop rewrote %s to n.%s".format(op, op2)) 192 | typed { Apply(Select(ev, op2), List(m)) } 193 | } else { 194 | unit.warning(tree.pos, "unop false alarm #2") 195 | tree 196 | } 197 | } 198 | 199 | case _ => tree 200 | } 201 | 202 | super.transform(tree2) 203 | } 204 | } 205 | } 206 | -------------------------------------------------------------------------------- /plugin/src/test/scala/Example.scala: 
/**
 * Fixture compiled with and without the optimized-numeric plugin.
 *
 * Definitions come in pairs: the `...1` variant uses operator syntax on a
 * generic Numeric value, the `...2` variant spells out the equivalent direct
 * call on the Numeric instance by hand.
 */
object Example {
  // generic + generic
  def foo1[T:Numeric](m:T, n:T) = m + n
  def foo2[T](m:T, n:T)(implicit ev:Numeric[T]) = ev.plus(m, n)

  // unary minus
  def bar1[T:Numeric](m:T) = -m
  def bar2[T](m:T)(implicit ev:Numeric[T]) = ev.negate(m)

  // generic + Int literal (literal on the right)
  def duh1[T:Numeric](m:T) = m + 13
  def duh2[T](m:T)(implicit ev:Numeric[T]) = ev.plus(m, ev.fromInt(13))

  // generic + BigInt
  def yak1[T:Numeric](m:T) = m + BigInt(9)
  def yak2[T](m:T)(implicit ev:Numeric[T]) = ev.plus(m, ev.fromBigInt(9))

  // Int literal + generic (literal on the left)
  def zug1[T:Numeric](n:T) = 1 + n
  def zug2[T](n:T)(implicit ev:Numeric[T]) = ev.plus(ev.fromInt(1), n)

  // Prints the result of negating 1 through the operator-syntax variant.
  def main(args: Array[String]) {
    println(bar1(1))
  }
}
/**
 * Conversions to type.
 *
 * An object implementing ConvertableTo[A] provides methods to go
 * from number types to A.
 */
trait ConvertableTo[@spec A] {
  implicit def fromByte(a:Byte): A
  implicit def fromShort(a:Short): A
  implicit def fromInt(a:Int): A
  implicit def fromLong(a:Long): A
  implicit def fromFloat(a:Float): A
  implicit def fromDouble(a:Double): A
  implicit def fromBigInt(a:BigInt): A
  implicit def fromBigDecimal(a:BigDecimal): A
}

/** Conversions to Byte, using each source type's own `toByte`. */
trait ConvertableToByte extends ConvertableTo[Byte] {
  implicit def fromByte(a:Byte): Byte = a
  implicit def fromShort(a:Short): Byte = a.toByte
  implicit def fromInt(a:Int): Byte = a.toByte
  implicit def fromLong(a:Long): Byte = a.toByte
  implicit def fromFloat(a:Float): Byte = a.toByte
  implicit def fromDouble(a:Double): Byte = a.toByte
  implicit def fromBigInt(a:BigInt): Byte = a.toByte
  implicit def fromBigDecimal(a:BigDecimal): Byte = a.toByte
}

/** Conversions to Short, using each source type's own `toShort`. */
trait ConvertableToShort extends ConvertableTo[Short] {
  implicit def fromByte(a:Byte): Short = a.toShort
  implicit def fromShort(a:Short): Short = a
  implicit def fromInt(a:Int): Short = a.toShort
  implicit def fromLong(a:Long): Short = a.toShort
  implicit def fromFloat(a:Float): Short = a.toShort
  implicit def fromDouble(a:Double): Short = a.toShort
  implicit def fromBigInt(a:BigInt): Short = a.toShort
  implicit def fromBigDecimal(a:BigDecimal): Short = a.toShort
}

/** Conversions to Int, using each source type's own `toInt`. */
trait ConvertableToInt extends ConvertableTo[Int] {
  implicit def fromByte(a:Byte): Int = a.toInt
  implicit def fromShort(a:Short): Int = a.toInt
  implicit def fromInt(a:Int): Int = a
  implicit def fromLong(a:Long): Int = a.toInt
  implicit def fromFloat(a:Float): Int = a.toInt
  implicit def fromDouble(a:Double): Int = a.toInt
  implicit def fromBigInt(a:BigInt): Int = a.toInt
  implicit def fromBigDecimal(a:BigDecimal): Int = a.toInt
}

/** Conversions to Long, using each source type's own `toLong`. */
trait ConvertableToLong extends ConvertableTo[Long] {
  implicit def fromByte(a:Byte): Long = a.toLong
  implicit def fromShort(a:Short): Long = a.toLong
  implicit def fromInt(a:Int): Long = a.toLong
  implicit def fromLong(a:Long): Long = a
  implicit def fromFloat(a:Float): Long = a.toLong
  implicit def fromDouble(a:Double): Long = a.toLong
  implicit def fromBigInt(a:BigInt): Long = a.toLong
  implicit def fromBigDecimal(a:BigDecimal): Long = a.toLong
}

/** Conversions to Float, using each source type's own `toFloat`. */
trait ConvertableToFloat extends ConvertableTo[Float] {
  implicit def fromByte(a:Byte): Float = a.toFloat
  implicit def fromShort(a:Short): Float = a.toFloat
  implicit def fromInt(a:Int): Float = a.toFloat
  implicit def fromLong(a:Long): Float = a.toFloat
  implicit def fromFloat(a:Float): Float = a
  implicit def fromDouble(a:Double): Float = a.toFloat
  implicit def fromBigInt(a:BigInt): Float = a.toFloat
  implicit def fromBigDecimal(a:BigDecimal): Float = a.toFloat
}

/** Conversions to Double, using each source type's own `toDouble`. */
trait ConvertableToDouble extends ConvertableTo[Double] {
  implicit def fromByte(a:Byte): Double = a.toDouble
  implicit def fromShort(a:Short): Double = a.toDouble
  implicit def fromInt(a:Int): Double = a.toDouble
  implicit def fromLong(a:Long): Double = a.toDouble
  implicit def fromFloat(a:Float): Double = a.toDouble
  implicit def fromDouble(a:Double): Double = a
  implicit def fromBigInt(a:BigInt): Double = a.toDouble
  implicit def fromBigDecimal(a:BigDecimal): Double = a.toDouble
}

/**
 * Conversions to BigInt.
 *
 * Floating-point inputs are truncated toward zero. Finite values whose
 * magnitude does not fit in a Long are converted exactly rather than being
 * clamped to Long.MinValue / Long.MaxValue.
 */
trait ConvertableToBigInt extends ConvertableTo[BigInt] {
  implicit def fromByte(a:Byte): BigInt = BigInt(a)
  implicit def fromShort(a:Short): BigInt = BigInt(a)
  implicit def fromInt(a:Int): BigInt = BigInt(a)
  implicit def fromLong(a:Long): BigInt = BigInt(a)
  // Float -> Double is exact, so share the Double implementation.
  implicit def fromFloat(a:Float): BigInt = fromDouble(a.toDouble)
  implicit def fromDouble(a:Double): BigInt = {
    val l = a.toLong
    // The Double -> Long cast saturates at the Long bounds, so an unsaturated
    // result is already the exact truncation (NaN casts to 0 and lands here).
    // Infinities have no exact integer value; keep the saturated result for
    // them rather than throwing.
    if ((l != Long.MaxValue && l != Long.MinValue) || java.lang.Double.isInfinite(a)) BigInt(l)
    // Finite but beyond the Long range: expand the binary value exactly.
    else new BigInt(new java.math.BigDecimal(a).toBigInteger)
  }
  implicit def fromBigInt(a:BigInt): BigInt = a
  implicit def fromBigDecimal(a:BigDecimal): BigInt = a.toBigInt
}

/** Conversions to BigDecimal via the BigDecimal factory methods. */
trait ConvertableToBigDecimal extends ConvertableTo[BigDecimal] {
  implicit def fromByte(a:Byte): BigDecimal = BigDecimal(a)
  implicit def fromShort(a:Short): BigDecimal = BigDecimal(a)
  implicit def fromInt(a:Int): BigDecimal = BigDecimal(a)
  implicit def fromLong(a:Long): BigDecimal = BigDecimal(a)
  implicit def fromFloat(a:Float): BigDecimal = BigDecimal(a)
  implicit def fromDouble(a:Double): BigDecimal = BigDecimal(a)
  implicit def fromBigInt(a:BigInt): BigDecimal = BigDecimal(a)
  implicit def fromBigDecimal(a:BigDecimal): BigDecimal = a
}

/** Implicit ConvertableTo instances, one per supported target type. */
object ConvertableTo {
  implicit object ConvertableToByte extends ConvertableToByte
  implicit object ConvertableToShort extends ConvertableToShort
  implicit object ConvertableToInt extends ConvertableToInt
  implicit object ConvertableToLong extends ConvertableToLong
  implicit object ConvertableToFloat extends ConvertableToFloat
  implicit object ConvertableToDouble extends ConvertableToDouble
  implicit object ConvertableToBigInt extends ConvertableToBigInt
  implicit object ConvertableToBigDecimal extends ConvertableToBigDecimal
}
/**
 * Conversions from type.
 *
 * An object implementing ConvertableFrom[A] provides methods to go
 * from A to number types (and String).
 */
trait ConvertableFrom[@spec A] {
  implicit def toByte(a:A): Byte
  implicit def toShort(a:A): Short
  implicit def toInt(a:A): Int
  implicit def toLong(a:A): Long
  implicit def toFloat(a:A): Float
  implicit def toDouble(a:A): Double
  implicit def toBigInt(a:A): BigInt
  implicit def toBigDecimal(a:A): BigDecimal

  implicit def toString(a:A): String
}

/** Conversions from Byte. */
trait ConvertableFromByte extends ConvertableFrom[Byte] {
  implicit def toByte(a:Byte): Byte = a
  implicit def toShort(a:Byte): Short = a.toShort
  implicit def toInt(a:Byte): Int = a.toInt
  implicit def toLong(a:Byte): Long = a.toLong
  implicit def toFloat(a:Byte): Float = a.toFloat
  implicit def toDouble(a:Byte): Double = a.toDouble
  implicit def toBigInt(a:Byte): BigInt = BigInt(a)
  implicit def toBigDecimal(a:Byte): BigDecimal = BigDecimal(a)

  implicit def toString(a:Byte): String = a.toString
}

/** Conversions from Short. */
trait ConvertableFromShort extends ConvertableFrom[Short] {
  implicit def toByte(a:Short): Byte = a.toByte
  implicit def toShort(a:Short): Short = a
  implicit def toInt(a:Short): Int = a.toInt
  implicit def toLong(a:Short): Long = a.toLong
  implicit def toFloat(a:Short): Float = a.toFloat
  implicit def toDouble(a:Short): Double = a.toDouble
  implicit def toBigInt(a:Short): BigInt = BigInt(a)
  implicit def toBigDecimal(a:Short): BigDecimal = BigDecimal(a)

  implicit def toString(a:Short): String = a.toString
}

/** Conversions from Int. */
trait ConvertableFromInt extends ConvertableFrom[Int] {
  implicit def toByte(a:Int): Byte = a.toByte
  implicit def toShort(a:Int): Short = a.toShort
  implicit def toInt(a:Int): Int = a
  implicit def toLong(a:Int): Long = a.toLong
  implicit def toFloat(a:Int): Float = a.toFloat
  implicit def toDouble(a:Int): Double = a.toDouble
  implicit def toBigInt(a:Int): BigInt = BigInt(a)
  implicit def toBigDecimal(a:Int): BigDecimal = BigDecimal(a)

  implicit def toString(a:Int): String = a.toString
}

/** Conversions from Long. */
trait ConvertableFromLong extends ConvertableFrom[Long] {
  implicit def toByte(a:Long): Byte = a.toByte
  implicit def toShort(a:Long): Short = a.toShort
  implicit def toInt(a:Long): Int = a.toInt
  implicit def toLong(a:Long): Long = a
  implicit def toFloat(a:Long): Float = a.toFloat
  implicit def toDouble(a:Long): Double = a.toDouble
  implicit def toBigInt(a:Long): BigInt = BigInt(a)
  implicit def toBigDecimal(a:Long): BigDecimal = BigDecimal(a)

  implicit def toString(a:Long): String = a.toString
}

/**
 * Conversions from Float.
 *
 * `toBigInt` truncates toward zero and does not clamp finite values to the
 * Long range.
 */
trait ConvertableFromFloat extends ConvertableFrom[Float] {
  implicit def toByte(a:Float): Byte = a.toByte
  implicit def toShort(a:Float): Short = a.toShort
  implicit def toInt(a:Float): Int = a.toInt
  implicit def toLong(a:Float): Long = a.toLong
  implicit def toFloat(a:Float): Float = a
  implicit def toDouble(a:Float): Double = a.toDouble
  // Float -> Double is exact, so reuse the Double instance's conversion.
  implicit def toBigInt(a:Float): BigInt = ConvertableFrom.ConvertableFromDouble.toBigInt(a.toDouble)
  implicit def toBigDecimal(a:Float): BigDecimal = BigDecimal(a)

  implicit def toString(a:Float): String = a.toString
}

/**
 * Conversions from Double.
 *
 * `toBigInt` truncates toward zero and does not clamp finite values to the
 * Long range.
 */
trait ConvertableFromDouble extends ConvertableFrom[Double] {
  implicit def toByte(a:Double): Byte = a.toByte
  implicit def toShort(a:Double): Short = a.toShort
  implicit def toInt(a:Double): Int = a.toInt
  implicit def toLong(a:Double): Long = a.toLong
  implicit def toFloat(a:Double): Float = a.toFloat
  implicit def toDouble(a:Double): Double = a
  implicit def toBigInt(a:Double): BigInt = {
    val l = a.toLong
    // The Double -> Long cast saturates at the Long bounds, so an unsaturated
    // result is already the exact truncation (NaN casts to 0 and lands here).
    // Infinities have no exact integer value; keep the saturated result for
    // them rather than throwing.
    if ((l != Long.MaxValue && l != Long.MinValue) || java.lang.Double.isInfinite(a)) BigInt(l)
    // Finite but beyond the Long range: expand the binary value exactly.
    else new BigInt(new java.math.BigDecimal(a).toBigInteger)
  }
  implicit def toBigDecimal(a:Double): BigDecimal = BigDecimal(a)

  implicit def toString(a:Double): String = a.toString
}

/** Conversions from BigInt. */
trait ConvertableFromBigInt extends ConvertableFrom[BigInt] {
  implicit def toByte(a:BigInt): Byte = a.toByte
  implicit def toShort(a:BigInt): Short = a.toShort
  implicit def toInt(a:BigInt): Int = a.toInt
  implicit def toLong(a:BigInt): Long = a.toLong
  implicit def toFloat(a:BigInt): Float = a.toFloat
  implicit def toDouble(a:BigInt): Double = a.toDouble
  implicit def toBigInt(a:BigInt): BigInt = a
  implicit def toBigDecimal(a:BigInt): BigDecimal = BigDecimal(a)

  implicit def toString(a:BigInt): String = a.toString
}

/** Conversions from BigDecimal. */
trait ConvertableFromBigDecimal extends ConvertableFrom[BigDecimal] {
  implicit def toByte(a:BigDecimal): Byte = a.toByte
  implicit def toShort(a:BigDecimal): Short = a.toShort
  implicit def toInt(a:BigDecimal): Int = a.toInt
  implicit def toLong(a:BigDecimal): Long = a.toLong
  implicit def toFloat(a:BigDecimal): Float = a.toFloat
  implicit def toDouble(a:BigDecimal): Double = a.toDouble
  implicit def toBigInt(a:BigDecimal): BigInt = a.toBigInt
  implicit def toBigDecimal(a:BigDecimal): BigDecimal = a

  implicit def toString(a:BigDecimal): String = a.toString
}

/** Implicit ConvertableFrom instances, one per supported source type. */
object ConvertableFrom {
  implicit object ConvertableFromByte extends ConvertableFromByte
  implicit object ConvertableFromShort extends ConvertableFromShort
  implicit object ConvertableFromInt extends ConvertableFromInt
  implicit object ConvertableFromLong extends ConvertableFromLong
  implicit object ConvertableFromFloat extends ConvertableFromFloat
  implicit object ConvertableFromDouble extends ConvertableFromDouble
  implicit object ConvertableFromBigInt extends ConvertableFromBigInt
  implicit object ConvertableFromBigDecimal extends ConvertableFromBigDecimal
}
/**
 * @author Erik Osheim
 *
 * EasyNumericOps adds operators to A. It's intended to be used as an implicit
 * decorator like so:
 *
 *   def foo[A:Numeric](a:A, b:A) = a + b
 *
 *   (compiled into) = new EasyNumericOps(a).+(b)
 *   (w/plugin into) = numeric.plus(a, b)
 *
 * Unlike FastNumericOps, the binary operators accept a right-hand side of any
 * type B that has a ConvertableFrom instance; it is converted with
 * n.fromType before the Numeric method is called.
 */
final class EasyNumericOps[@spec(Int,Long,Float,Double) A:Numeric](val lhs:A) {
  // the Numeric instance every operator delegates to
  val n = implicitly[Numeric[A]]

  def abs = n.abs(lhs)
  def unary_- = n.negate(lhs)
  def signum = n.signum(lhs)

  def compare[B:ConvertableFrom](rhs:B) = n.compare(lhs, n.fromType(rhs))
  def equiv[B:ConvertableFrom](rhs:B) = n.equiv(lhs, n.fromType(rhs))
  def max[B:ConvertableFrom](rhs:B) = n.max(lhs, n.fromType(rhs))
  def min[B:ConvertableFrom](rhs:B) = n.min(lhs, n.fromType(rhs))

  // symbolic aliases for compare / equiv / nequiv
  def <=>[B:ConvertableFrom](rhs:B) = n.compare(lhs, n.fromType(rhs))
  def ===[B:ConvertableFrom](rhs:B) = n.equiv(lhs, n.fromType(rhs))
  def !==[B:ConvertableFrom](rhs:B) = n.nequiv(lhs, n.fromType(rhs))

  def /[B:ConvertableFrom](rhs:B) = n.div(lhs, n.fromType(rhs))
  def >[B:ConvertableFrom](rhs:B) = n.gt(lhs, n.fromType(rhs))
  def >=[B:ConvertableFrom](rhs:B) = n.gteq(lhs, n.fromType(rhs))
  def <[B:ConvertableFrom](rhs:B) = n.lt(lhs, n.fromType(rhs))
  def <=[B:ConvertableFrom](rhs:B) = n.lteq(lhs, n.fromType(rhs))
  def -[B:ConvertableFrom](rhs:B) = n.minus(lhs, n.fromType(rhs))
  def %[B:ConvertableFrom](rhs:B) = n.mod(lhs, n.fromType(rhs))
  def +[B:ConvertableFrom](rhs:B) = n.plus(lhs, n.fromType(rhs))
  def *[B:ConvertableFrom](rhs:B) = n.times(lhs, n.fromType(rhs))
  def **[B:ConvertableFrom](rhs:B) = n.pow(lhs, n.fromType(rhs))

  def toByte = n.toByte(lhs)
  def toShort = n.toShort(lhs)
  def toInt = n.toInt(lhs)
  def toLong = n.toLong(lhs)
  def toFloat = n.toFloat(lhs)
  def toDouble = n.toDouble(lhs)
  def toBigInt = n.toBigInt(lhs)
  def toBigDecimal = n.toBigDecimal(lhs)
}

/**
 * FastNumericOps adds things like inline operators to A. It's intended to
 * be used as an implicit decorator like so:
 *
 *   def foo[A:Numeric](a:A, b:A) = a + b
 *   (this is translated into) = new FastNumericOps(a).+(b)
 *
 * Every binary operator takes a right-hand side of the same type A and
 * passes it to the Numeric instance unchanged.
 */
final class FastNumericOps[@spec(Int,Long,Float,Double) A:Numeric](val lhs:A) {
  // the Numeric instance every operator delegates to
  val n = implicitly[Numeric[A]]

  def abs = n.abs(lhs)
  def unary_- = n.negate(lhs)
  def signum = n.signum(lhs)

  def compare(rhs:A) = n.compare(lhs, rhs)
  def equiv(rhs:A) = n.equiv(lhs, rhs)
  def max(rhs:A) = n.max(lhs, rhs)
  def min(rhs:A) = n.min(lhs, rhs)

  // symbolic aliases for compare / equiv / nequiv
  def <=>(rhs:A) = n.compare(lhs, rhs)
  def ===(rhs:A) = n.equiv(lhs, rhs)
  def !==(rhs:A) = n.nequiv(lhs, rhs)
  def >(rhs:A) = n.gt(lhs, rhs)
  def >=(rhs:A) = n.gteq(lhs, rhs)
  def <(rhs:A) = n.lt(lhs, rhs)
  def <=(rhs:A) = n.lteq(lhs, rhs)
  def /(rhs:A) = n.div(lhs, rhs)
  def -(rhs:A) = n.minus(lhs, rhs)
  def %(rhs:A) = n.mod(lhs, rhs)
  def +(rhs:A) = n.plus(lhs, rhs)
  def *(rhs:A) = n.times(lhs, rhs)
  def **(rhs:A) = n.pow(lhs, rhs)

  def toByte = n.toByte(lhs)
  def toShort = n.toShort(lhs)
  def toInt = n.toInt(lhs)
  def toLong = n.toLong(lhs)
  def toFloat = n.toFloat(lhs)
  def toDouble = n.toDouble(lhs)
  def toBigInt = n.toBigInt(lhs)
  def toBigDecimal = n.toBigDecimal(lhs)
}
/**
 * LiteralIntOps, LiteralLongOps and friends provide the same operators as
 * the NumericOps classes (such as +, -, *, etc) for the concrete number
 * types we're interested in, with the concrete value on the left-hand side.
 *
 * Each operator converts the left-hand value into the right-hand side's
 * type A (via n.fromInt, n.fromLong, ...) and then calls the matching method
 * on the implicit Numeric[A]. For instance LiteralIntOps turns
 *
 *   1 + n    into    numeric.plus(numeric.fromInt(1), n)
 *
 * NOTE(review): the implicit conversions that wrap a literal in these
 * classes are not defined in this file; confirm against the implicits
 * object before relying on the example above.
 */

/** Operators for an Int left-hand side; lhs is converted with n.fromInt. */
final class LiteralIntOps(val lhs:Int) {
  def compare[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromInt(lhs), rhs)
  def equiv[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromInt(lhs), rhs)
  def max[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.max(n.fromInt(lhs), rhs)
  def min[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.min(n.fromInt(lhs), rhs)

  def <=>[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromInt(lhs), rhs)
  def ===[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromInt(lhs), rhs)
  def !==[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.nequiv(n.fromInt(lhs), rhs)

  def /[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.div(n.fromInt(lhs), rhs)
  def >[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gt(n.fromInt(lhs), rhs)
  def >=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gteq(n.fromInt(lhs), rhs)
  def <[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lt(n.fromInt(lhs), rhs)
  def <=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lteq(n.fromInt(lhs), rhs)
  def -[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.minus(n.fromInt(lhs), rhs)
  def %[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.mod(n.fromInt(lhs), rhs)
  def +[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.plus(n.fromInt(lhs), rhs)
  def *[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.times(n.fromInt(lhs), rhs)
  def **[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.pow(n.fromInt(lhs), rhs)
  def toNumeric[@spec(Int,Long,Float,Double) A](implicit n:Numeric[A]) = n.fromInt(lhs)

  def toBigInt() = BigInt(lhs)
  def toBigDecimal() = BigDecimal(lhs)
}


/** Operators for a Long left-hand side; lhs is converted with n.fromLong. */
final class LiteralLongOps(val lhs:Long) {
  def compare[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromLong(lhs), rhs)
  def equiv[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromLong(lhs), rhs)
  def max[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.max(n.fromLong(lhs), rhs)
  def min[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.min(n.fromLong(lhs), rhs)

  def <=>[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromLong(lhs), rhs)
  def ===[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromLong(lhs), rhs)
  def !==[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.nequiv(n.fromLong(lhs), rhs)

  def /[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.div(n.fromLong(lhs), rhs)
  def >[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gt(n.fromLong(lhs), rhs)
  def >=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gteq(n.fromLong(lhs), rhs)
  def <[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lt(n.fromLong(lhs), rhs)
  def <=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lteq(n.fromLong(lhs), rhs)
  def -[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.minus(n.fromLong(lhs), rhs)
  def %[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.mod(n.fromLong(lhs), rhs)
  def +[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.plus(n.fromLong(lhs), rhs)
  def *[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.times(n.fromLong(lhs), rhs)
  def **[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.pow(n.fromLong(lhs), rhs)
  def toNumeric[@spec(Int,Long,Float,Double) A](implicit n:Numeric[A]) = n.fromLong(lhs)

  def toBigInt() = BigInt(lhs)
  def toBigDecimal() = BigDecimal(lhs)
}


/** Operators for a Float left-hand side; lhs is converted with n.fromFloat. */
final class LiteralFloatOps(val lhs:Float) {
  def compare[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromFloat(lhs), rhs)
  def equiv[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromFloat(lhs), rhs)
  def max[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.max(n.fromFloat(lhs), rhs)
  def min[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.min(n.fromFloat(lhs), rhs)

  def <=>[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromFloat(lhs), rhs)
  def ===[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromFloat(lhs), rhs)
  def !==[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.nequiv(n.fromFloat(lhs), rhs)

  def /[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.div(n.fromFloat(lhs), rhs)
  def >[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gt(n.fromFloat(lhs), rhs)
  def >=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gteq(n.fromFloat(lhs), rhs)
  def <[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lt(n.fromFloat(lhs), rhs)
  def <=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lteq(n.fromFloat(lhs), rhs)
  def -[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.minus(n.fromFloat(lhs), rhs)
  def %[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.mod(n.fromFloat(lhs), rhs)
  def +[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.plus(n.fromFloat(lhs), rhs)
  def *[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.times(n.fromFloat(lhs), rhs)
  def **[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.pow(n.fromFloat(lhs), rhs)
  def toNumeric[@spec(Int,Long,Float,Double) A](implicit n:Numeric[A]) = n.fromFloat(lhs)

  // goes through BigDecimal rather than Long
  def toBigInt() = BigDecimal(lhs).toBigInt
  def toBigDecimal() = BigDecimal(lhs)
}


/** Operators for a Double left-hand side; lhs is converted with n.fromDouble. */
final class LiteralDoubleOps(val lhs:Double) {
  def compare[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromDouble(lhs), rhs)
  def equiv[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromDouble(lhs), rhs)
  def max[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.max(n.fromDouble(lhs), rhs)
  def min[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.min(n.fromDouble(lhs), rhs)

  def <=>[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromDouble(lhs), rhs)
  def ===[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromDouble(lhs), rhs)
  def !==[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.nequiv(n.fromDouble(lhs), rhs)

  def /[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.div(n.fromDouble(lhs), rhs)
  def >[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gt(n.fromDouble(lhs), rhs)
  def >=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gteq(n.fromDouble(lhs), rhs)
  def <[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lt(n.fromDouble(lhs), rhs)
  def <=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lteq(n.fromDouble(lhs), rhs)
  def -[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.minus(n.fromDouble(lhs), rhs)
  def %[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.mod(n.fromDouble(lhs), rhs)
  def +[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.plus(n.fromDouble(lhs), rhs)
  def *[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.times(n.fromDouble(lhs), rhs)
  def **[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.pow(n.fromDouble(lhs), rhs)
  def toNumeric[@spec(Int,Long,Float,Double) A](implicit n:Numeric[A]) = n.fromDouble(lhs)

  // goes through BigDecimal rather than Long
  def toBigInt() = BigDecimal(lhs).toBigInt
  def toBigDecimal() = BigDecimal(lhs)
}


/** Operators for a BigInt left-hand side; lhs is converted with n.fromBigInt. */
final class LiteralBigIntOps(val lhs:BigInt) {
  def compare[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromBigInt(lhs), rhs)
  def equiv[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromBigInt(lhs), rhs)
  def max[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.max(n.fromBigInt(lhs), rhs)
  def min[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.min(n.fromBigInt(lhs), rhs)

  def <=>[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromBigInt(lhs), rhs)
  def ===[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromBigInt(lhs), rhs)
  def !==[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.nequiv(n.fromBigInt(lhs), rhs)

  def /[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.div(n.fromBigInt(lhs), rhs)
  def >[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gt(n.fromBigInt(lhs), rhs)
  def >=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gteq(n.fromBigInt(lhs), rhs)
  def <[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lt(n.fromBigInt(lhs), rhs)
  def <=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lteq(n.fromBigInt(lhs), rhs)
  def -[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.minus(n.fromBigInt(lhs), rhs)
  def %[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.mod(n.fromBigInt(lhs), rhs)
  def +[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.plus(n.fromBigInt(lhs), rhs)
  def *[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.times(n.fromBigInt(lhs), rhs)
  def **[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.pow(n.fromBigInt(lhs), rhs)
  def toNumeric[@spec(Int,Long,Float,Double) A](implicit n:Numeric[A]) = n.fromBigInt(lhs)

  def toBigInt() = lhs
  def toBigDecimal() = BigDecimal(lhs)
}


/** Operators for a BigDecimal left-hand side; lhs is converted with n.fromBigDecimal. */
final class LiteralBigDecimalOps(val lhs:BigDecimal) {
  def compare[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromBigDecimal(lhs), rhs)
  def equiv[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromBigDecimal(lhs), rhs)
  def max[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.max(n.fromBigDecimal(lhs), rhs)
  def min[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.min(n.fromBigDecimal(lhs), rhs)

  def <=>[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromBigDecimal(lhs), rhs)
  def ===[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromBigDecimal(lhs), rhs)
  def !==[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.nequiv(n.fromBigDecimal(lhs), rhs)

  def /[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.div(n.fromBigDecimal(lhs), rhs)
  def >[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gt(n.fromBigDecimal(lhs), rhs)
  def >=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gteq(n.fromBigDecimal(lhs), rhs)
  def <[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lt(n.fromBigDecimal(lhs), rhs)
  def <=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lteq(n.fromBigDecimal(lhs), rhs)
  def -[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.minus(n.fromBigDecimal(lhs), rhs)
  def %[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.mod(n.fromBigDecimal(lhs), rhs)
  def +[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.plus(n.fromBigDecimal(lhs), rhs)
  def *[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.times(n.fromBigDecimal(lhs), rhs)
  def **[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.pow(n.fromBigDecimal(lhs), rhs)
  def toNumeric[@spec(Int,Long,Float,Double) A](implicit n:Numeric[A]) = n.fromBigDecimal(lhs)

  def toBigInt() = lhs.toBigInt
  def toBigDecimal() = lhs
}
*
 * @example {{{
 *   import com.azavea.math.Numeric
 *   import com.azavea.math.FastImplicits._
 *
 *   def pythagoreanTheorem[T:Numeric](a:T, b:T): Double = {
 *     val c = (a * a) + (b * b)
 *     math.sqrt(c.toDouble)
 *   }
 * }}}
 *
 */
//@implicitNotFound(msg = "Cannot find Numeric type class for ${A}")
trait Numeric[@specialized(Int,Long,Float,Double) A]
extends ConvertableFrom[A] with ConvertableTo[A] {

  /**
   * Computes the absolute value of `a`.
   *
   * @return the absolute value of `a`
   */
  def abs(a:A):A

  /**
   * Returns an integer whose sign denotes the relationship between
   * `a` and `b`. If `a` < `b` it returns -1, if `a` > `b` it returns 1
   * and otherwise it returns 0.
   *
   * The default implementation is written in terms of `lt` and `gt`.
   *
   * @return -1, 0 or 1
   */
  def compare(a:A, b:A):Int =
    if (lt(a, b)) -1
    else if (gt(a, b)) 1
    else 0

  /**
   * Divides `a` by `b`.
   *
   * This method maintains the type of the arguments (`A`). If this
   * method is used with `Int` or `Long` then the result is the integer
   * quotient (as in integer division). Otherwise (with `Float` and
   * `Double`) a fractional result is returned.
   *
   * @return `a` / `b`
   */
  def div(a:A, b:A):A

  /**
   * Tests if `a` and `b` are equivalent.
   *
   * @return `a` == `b`
   */
  def equiv(a:A, b:A):Boolean

  /**
   * Tests if `a` and `b` are not equivalent.
   *
   * @return `a` != `b`
   */
  def nequiv(a:A, b:A):Boolean

  /**
   * Tests if `a` is greater than `b`.
   *
   * @return `a` > `b`
   */
  def gt(a:A, b:A):Boolean

  /**
   * Tests if `a` is greater than or equal to `b`.
   *
   * @return `a` >= `b`
   */
  def gteq(a:A, b:A):Boolean

  /**
   * Tests if `a` is less than `b`.
   *
   * @return `a` < `b`
   */
  def lt(a:A, b:A):Boolean

  /**
   * Tests if `a` is less than or equal to `b`.
   *
   * @return `a` <= `b`
   */
  def lteq(a:A, b:A):Boolean

  /**
   * Returns the larger of `a` and `b`.
   *
   * @return max(`a`, `b`)
   *
   * @see math.max
   */
  def max(a:A, b:A):A

  /**
   * Returns the smaller of `a` and `b`.
   *
   * @return min(`a`, `b`)
   *
   * @see math.min
   */
  def min(a:A, b:A):A

  /**
   * Returns `a` minus `b`.
   *
   * @return `a` - `b`
   */
  def minus(a:A, b:A):A

  /**
   * Returns `a` modulo `b`.
   *
   * @return `a` % `b`
   */
  def mod(a:A, b:A):A

  /**
   * Returns the additive inverse of `a`.
   *
   * @return -`a`
   */
  def negate(a:A):A

  /**
   * Returns one.
   *
   * @return 1
   */
  def one:A

  /**
   * Returns `a` plus `b`.
   *
   * @return `a` + `b`
   */
  def plus(a:A, b:A):A

  /**
   * Returns `a` to the `b`th power.
   *
   * Note that with large numbers this method will overflow and
   * return Infinity, which becomes MaxValue for whatever type
   * is being used. This behavior is inherited from `math.pow`.
   *
   * @return pow(`a`, `b`)
   *
   * @see math.pow
   */
  def pow(a:A, b:A):A

  /**
   * Returns an integer whose sign denotes the sign of `a`.
   * If `a` is negative it returns -1, if `a` is zero it
   * returns 0 and if `a` is positive it returns 1.
   *
   * The default implementation compares `a` against `zero`.
   *
   * @return -1, 0 or 1
   */
  def signum(a:A):Int = compare(a, zero)

  /**
   * Returns `a` times `b`.
   *
   * @return `a` * `b`
   */
  def times(a:A, b:A):A

  /**
   * Returns zero.
   *
   * @return 0
   */
  def zero:A

  /**
   * Convert a value `b` of type `B` to type `A`.
   *
   * This method can be used to coerce one generic numeric type to
   * another, to allow operations on them jointly.
   *
   * @example {{{
   *   def foo[A:Numeric,B:Numeric](a:A, b:B) = {
   *     val n = implicitly[Numeric[A]]
   *     n.plus(a, n.fromType(b))
   *   }
   * }}}
   *
   * Note that `b` may lose precision when represented as an `A`
   * (e.g. if B is Long and A is Int).
   *
   * @return the value of `b` encoded in type `A`
   */
  def fromType[@specialized(Int, Long, Float, Double) B](b:B)(implicit c:ConvertableFrom[B]): A

  /**
   * Convert a value `a` of type `A` to type `B`.
   *
   * This is the counterpart of `fromType`: the conversion is performed
   * through the implicit `ConvertableTo[B]`.
   *
   * @return the value of `a` encoded in type `B`
   */
  def toType[@specialized(Int, Long, Float, Double) B](a:A)(implicit c:ConvertableTo[B]): B

  /**
   * Used to get an Ordering[A] instance.
   *
   * A new `NumericOrdering` wrapping this instance is created on each call.
   */
  def getOrdering():Ordering[A] = new NumericOrdering(this)
}

/**
 * This is a little helper class that allows us to support the Ordering trait.
 *
 * If Numeric extended Ordering directly then we'd have to override all of
 * the comparison operators, losing specialization and other performance
 * benefits.
*/
class NumericOrdering[A](n:Numeric[A]) extends Ordering[A] {
  /* delegate to the wrapped Numeric instance */
  def compare(a:A, b:A): Int = n.compare(a, b)
}

/**
 * Numeric instance for `Int`, implemented with the primitive operators.
 *
 * `pow` goes through `scala.math.pow` (a `Double` computation) and converts
 * the result back with `toInt`.
 */
trait IntIsNumeric
extends Numeric[Int] with ConvertableFromInt with ConvertableToInt {
  def abs(a:Int): Int = scala.math.abs(a)
  def div(a:Int, b:Int): Int = a / b
  def equiv(a:Int, b:Int): Boolean = a == b
  def gt(a:Int, b:Int): Boolean = a > b
  def gteq(a:Int, b:Int): Boolean = a >= b
  def lt(a:Int, b:Int): Boolean = a < b
  def lteq(a:Int, b:Int): Boolean = a <= b
  def max(a:Int, b:Int): Int = scala.math.max(a, b)
  def min(a:Int, b:Int): Int = scala.math.min(a, b)
  def minus(a:Int, b:Int): Int = a - b
  def mod(a:Int, b:Int): Int = a % b
  def negate(a:Int): Int = -a
  def nequiv(a:Int, b:Int): Boolean = a != b
  def one: Int = 1
  def plus(a:Int, b:Int): Int = a + b
  def pow(a:Int, b:Int): Int = scala.math.pow(a, b).toInt
  def times(a:Int, b:Int): Int = a * b
  def zero: Int = 0

  def fromType[@specialized(Int, Long, Float, Double) B](b:B)(implicit c:ConvertableFrom[B]): Int = c.toInt(b)
  def toType[@specialized(Int, Long, Float, Double) B](a:Int)(implicit c:ConvertableTo[B]): B = c.fromInt(a)
}

/**
 * Numeric instance for `Long`, implemented with the primitive operators.
 *
 * `pow` goes through `scala.math.pow` (a `Double` computation) and converts
 * the result back with `toLong`.
 */
trait LongIsNumeric
extends Numeric[Long] with ConvertableFromLong with ConvertableToLong {
  def abs(a:Long): Long = scala.math.abs(a)
  def div(a:Long, b:Long): Long = a / b
  def equiv(a:Long, b:Long): Boolean = a == b
  def gt(a:Long, b:Long): Boolean = a > b
  def gteq(a:Long, b:Long): Boolean = a >= b
  def lt(a:Long, b:Long): Boolean = a < b
  def lteq(a:Long, b:Long): Boolean = a <= b
  def max(a:Long, b:Long): Long = scala.math.max(a, b)
  def min(a:Long, b:Long): Long = scala.math.min(a, b)
  def minus(a:Long, b:Long): Long = a - b
  def mod(a:Long, b:Long): Long = a % b
  def negate(a:Long): Long = -a
  def nequiv(a:Long, b:Long): Boolean = a != b
  def one: Long = 1L
  def plus(a:Long, b:Long): Long = a + b
  def pow(a:Long, b:Long): Long = scala.math.pow(a, b).toLong
  def times(a:Long, b:Long): Long = a * b
  def zero: Long = 0L

  def fromType[@specialized(Int, Long, Float, Double) B](b:B)(implicit c:ConvertableFrom[B]): Long = c.toLong(b)
  def toType[@specialized(Int, Long, Float, Double) B](a:Long)(implicit c:ConvertableTo[B]): B = c.fromLong(a)
}

/**
 * Numeric instance for `Float`, implemented with the primitive operators.
 *
 * `pow` goes through `scala.math.pow` (a `Double` computation) and converts
 * the result back with `toFloat`.
 */
trait FloatIsNumeric
extends Numeric[Float] with ConvertableFromFloat with ConvertableToFloat {
  def abs(a:Float): Float = scala.math.abs(a)
  def div(a:Float, b:Float): Float = a / b
  def equiv(a:Float, b:Float): Boolean = a == b
  def gt(a:Float, b:Float): Boolean = a > b
  def gteq(a:Float, b:Float): Boolean = a >= b
  def lt(a:Float, b:Float): Boolean = a < b
  def lteq(a:Float, b:Float): Boolean = a <= b
  def max(a:Float, b:Float): Float = scala.math.max(a, b)
  def min(a:Float, b:Float): Float = scala.math.min(a, b)
  def minus(a:Float, b:Float): Float = a - b
  def mod(a:Float, b:Float): Float = a % b
  def negate(a:Float): Float = -a
  def nequiv(a:Float, b:Float): Boolean = a != b
  def one: Float = 1.0F
  def plus(a:Float, b:Float): Float = a + b
  def pow(a:Float, b:Float): Float = scala.math.pow(a, b).toFloat
  def times(a:Float, b:Float): Float = a * b
  def zero: Float = 0.0F

  def fromType[@specialized(Int, Long, Float, Double) B](b:B)(implicit c:ConvertableFrom[B]): Float = c.toFloat(b)
  def toType[@specialized(Int, Long, Float, Double) B](a:Float)(implicit c:ConvertableTo[B]): B = c.fromFloat(a)
}

/**
 * Numeric instance for `Double`, implemented with the primitive operators
 * and `scala.math.pow`.
 */
trait DoubleIsNumeric
extends Numeric[Double] with ConvertableFromDouble with ConvertableToDouble {
  def abs(a:Double): Double = scala.math.abs(a)
  def div(a:Double, b:Double): Double = a / b
  def equiv(a:Double, b:Double): Boolean = a == b
  def gt(a:Double, b:Double): Boolean = a > b
  def gteq(a:Double, b:Double): Boolean = a >= b
  def lt(a:Double, b:Double): Boolean = a < b
  def lteq(a:Double, b:Double): Boolean = a <= b
  def max(a:Double, b:Double): Double = scala.math.max(a, b)
  def min(a:Double, b:Double): Double = scala.math.min(a, b)
  def minus(a:Double, b:Double): Double = a - b
  def mod(a:Double, b:Double): Double = a % b
  def negate(a:Double): Double = -a
  def nequiv(a:Double, b:Double): Boolean = a != b
  def one: Double = 1.0
  def plus(a:Double, b:Double): Double = a + b
  def pow(a:Double, b:Double): Double = scala.math.pow(a, b)
  def times(a:Double, b:Double): Double = a * b
  def zero: Double = 0.0

  def fromType[@specialized(Int, Long, Float, Double) B](b:B)(implicit c:ConvertableFrom[B]): Double = c.toDouble(b)
  def toType[@specialized(Int, Long, Float, Double) B](a:Double)(implicit c:ConvertableTo[B]): B = c.fromDouble(a)
}

/**
 * Numeric instance for `BigInt`, implemented with `BigInt`'s own operators.
 */
trait BigIntIsNumeric
extends Numeric[BigInt] with ConvertableFromBigInt with ConvertableToBigInt {
  def abs(a:BigInt): BigInt = a.abs
  def div(a:BigInt, b:BigInt): BigInt = a / b
  def equiv(a:BigInt, b:BigInt): Boolean = a == b
  def gt(a:BigInt, b:BigInt): Boolean = a > b
  def gteq(a:BigInt, b:BigInt): Boolean = a >= b
  def lt(a:BigInt, b:BigInt): Boolean = a < b
  def lteq(a:BigInt, b:BigInt): Boolean = a <= b
  def max(a:BigInt, b:BigInt): BigInt = a.max(b)
  def min(a:BigInt, b:BigInt): BigInt = a.min(b)
  def minus(a:BigInt, b:BigInt): BigInt = a - b
  def mod(a:BigInt, b:BigInt): BigInt = a % b
  def negate(a:BigInt): BigInt = -a
  def nequiv(a:BigInt, b:BigInt): Boolean = a != b
  def one: BigInt = BigInt(1)
  def plus(a:BigInt, b:BigInt): BigInt = a + b

  /**
   * Raises `a` to the `b`th power.
   *
   * `BigInt.pow` only accepts an `Int` exponent, so `b` is converted
   * explicitly. A two's-complement value fits in an `Int` exactly when its
   * `bitLength` (which excludes the sign bit) is below 32; anything larger
   * is rejected instead of being silently truncated.
   *
   * @throws IllegalArgumentException if `b` does not fit in an `Int`
   */
  def pow(a:BigInt, b:BigInt): BigInt = {
    require(b.bitLength < 32, "exponent does not fit in an Int: " + b)
    a.pow(b.toInt)
  }

  def times(a:BigInt, b:BigInt): BigInt = a * b
  def zero: BigInt = BigInt(0)

  def fromType[@specialized(Int, Long, Float, Double) B](b:B)(implicit c:ConvertableFrom[B]): BigInt = c.toBigInt(b)
  def toType[@specialized(Int, Long, Float, Double) B](a:BigInt)(implicit c:ConvertableTo[B]): B = c.fromBigInt(a)
}

/**
 * Numeric instance for `BigDecimal`, implemented with `BigDecimal`'s own
 * operators.
 */
trait BigDecimalIsNumeric
extends Numeric[BigDecimal] with ConvertableFromBigDecimal with ConvertableToBigDecimal {
  def abs(a:BigDecimal): BigDecimal = a.abs
  def div(a:BigDecimal, b:BigDecimal): BigDecimal = a / b
  def equiv(a:BigDecimal, b:BigDecimal): Boolean = a == b
  def gt(a:BigDecimal, b:BigDecimal): Boolean = a > b
  def gteq(a:BigDecimal, b:BigDecimal): Boolean = a >= b
  def lt(a:BigDecimal, b:BigDecimal): Boolean = a < b
  def lteq(a:BigDecimal, b:BigDecimal): Boolean = a <= b
  def max(a:BigDecimal, b:BigDecimal): BigDecimal = a.max(b)
  def min(a:BigDecimal, b:BigDecimal): BigDecimal = a.min(b)
  def minus(a:BigDecimal, b:BigDecimal): BigDecimal = a - b
  def mod(a:BigDecimal, b:BigDecimal): BigDecimal = a % b
  def negate(a:BigDecimal): BigDecimal = -a
  def nequiv(a:BigDecimal, b:BigDecimal): Boolean = a != b
  def one: BigDecimal = BigDecimal(1.0)
  def plus(a:BigDecimal, b:BigDecimal): BigDecimal = a + b

  /**
   * Raises `a` to the `b`th power.
   *
   * `BigDecimal.pow` only accepts an `Int` exponent, so `b` is converted
   * with `toIntExact`, which refuses exponents that have a fractional part
   * or do not fit in an `Int` rather than silently truncating them.
   */
  def pow(a:BigDecimal, b:BigDecimal): BigDecimal = a.pow(b.toIntExact)

  def times(a:BigDecimal, b:BigDecimal): BigDecimal = a * b
  def zero: BigDecimal = BigDecimal(0.0)

  def fromType[@specialized(Int, Long, Float, Double) B](b:B)(implicit c:ConvertableFrom[B]): BigDecimal = c.toBigDecimal(b)
  def toType[@specialized(Int, Long, Float, Double) B](a:BigDecimal)(implicit c:ConvertableTo[B]): B = c.fromBigDecimal(a)
}


/**
 * This companion object provides the instances (e.g. IntIsNumeric)
 * associating the type class (Numeric) with its member type (Int).
*/
object Numeric {
  implicit object IntIsNumeric extends IntIsNumeric
  implicit object LongIsNumeric extends LongIsNumeric
  implicit object FloatIsNumeric extends FloatIsNumeric
  implicit object DoubleIsNumeric extends DoubleIsNumeric
  implicit object BigIntIsNumeric extends BigIntIsNumeric
  implicit object BigDecimalIsNumeric extends BigDecimalIsNumeric

  /* summon the Numeric instance that is implicitly available for A */
  def numeric[@specialized(Int, Long, Float, Double) A:Numeric]:Numeric[A] = implicitly[Numeric[A]]
}

/**
 * Implicit conversions that add infix operators to numeric values.
 *
 * Generic values with a `Numeric` instance are wrapped in `FastNumericOps`;
 * values of the concrete types below are wrapped in the matching
 * `Literal*Ops` class.
 */
object FastImplicits {
  implicit def infixOps[@specialized(Int, Long, Float, Double) A:Numeric](a:A) = new FastNumericOps(a)

  implicit def infixIntOps(i:Int): LiteralIntOps = new LiteralIntOps(i)
  implicit def infixLongOps(l:Long): LiteralLongOps = new LiteralLongOps(l)
  implicit def infixFloatOps(f:Float): LiteralFloatOps = new LiteralFloatOps(f)
  implicit def infixDoubleOps(d:Double): LiteralDoubleOps = new LiteralDoubleOps(d)
  implicit def infixBigIntOps(f:BigInt): LiteralBigIntOps = new LiteralBigIntOps(f)
  implicit def infixBigDecimalOps(d:BigDecimal): LiteralBigDecimalOps = new LiteralBigDecimalOps(d)

  /* summon the Numeric instance that is implicitly available for A */
  def numeric[@specialized(Int, Long, Float, Double) A:Numeric]:Numeric[A] = implicitly[Numeric[A]]
}

/**
 * Same set of implicit conversions as `FastImplicits`, except that generic
 * values with a `Numeric` instance are wrapped in `EasyNumericOps`.
 */
object EasyImplicits {
  implicit def infixOps[@specialized(Int, Long, Float, Double) A:Numeric](a:A) = new EasyNumericOps(a)

  implicit def infixIntOps(i:Int): LiteralIntOps = new LiteralIntOps(i)
  implicit def infixLongOps(l:Long): LiteralLongOps = new LiteralLongOps(l)
  implicit def infixFloatOps(f:Float): LiteralFloatOps = new LiteralFloatOps(f)
  implicit def infixDoubleOps(d:Double): LiteralDoubleOps = new LiteralDoubleOps(d)
  implicit def infixBigIntOps(f:BigInt): LiteralBigIntOps = new LiteralBigIntOps(f)
  implicit def infixBigDecimalOps(d:BigDecimal): LiteralBigDecimalOps = new LiteralBigDecimalOps(d)

  /* summon the Numeric instance that is implicitly available for A */
  def numeric[@specialized(Int, Long, Float, Double) A:Numeric]:Numeric[A] = implicitly[Numeric[A]]
}

--------------------------------------------------------------------------------
/src/test/scala/Matrix.scala:
--------------------------------------------------------------------------------
import com.azavea.math.Numeric
import com.azavea.math.EasyImplicits._
import Predef.{any2stringadd => _, _}

/**
 * Raised when a matrix is built with, or multiplied using, illegal
 * dimensions. The message is forwarded to `Exception` so that it shows up
 * in `getMessage` and in stack traces.
 */
class MatrixException(s:String) extends Exception(s)


/**
 * Matrix implementation in terms of a generic numeric type.
 *
 * `data` is indexed as `data(row)(column)`; the array is used as given
 * (it is not copied), so `update` writes through to the caller's array.
 */
class NMatrix[A:Numeric:Manifest](val data:Array[Array[A]],
                                  val rows:Int, val cols:Int) {
  if (rows < 1) throw new MatrixException("illegal height")
  if (cols < 1) throw new MatrixException("illegal width")

  /* build an empty matrix with the same dimensions */
  def createEmpty() = NMatrix.empty[A](rows, cols)

  /* access the element at (y, x) */
  def apply(y:Int, x:Int) = data(y)(x)

  /* update the element at (y, x) to value */
  def update(y:Int, x:Int, value:A) = data(y)(x) = value

  /* create a new Matrix by mapping each element through f */
  def map[B:Numeric:Manifest](f:A => B) = {
    new NMatrix[B](data.map(_.map(f).toArray).toArray, rows, cols)
  }

  /* combine two matrices element-by-element */
  def combine(rhs:NMatrix[A], f:(A, A) => A) = {
    val result = createEmpty()
    for (y <- 0 until rows; x <- 0 until cols) {
      result(y, x) = f(this(y, x), rhs(y, x))
    }
    result
  }

  /* add a scalar value to each element */
  def +(a:A):NMatrix[A] = map(_ + a)

  /* add two matrices */
  def +(rhs:NMatrix[A]):NMatrix[A] = combine(rhs, _ + _)

  /* multiply each element by a scalar value */
  def *(a:A):NMatrix[A] = map(_ * a)

  /* multiply two matrices */
  def *(rhs:NMatrix[A]):NMatrix[A] = {

    /* an (r x n) matrix times an (n x c) matrix is defined exactly when
     * the inner dimensions agree; the outer ones are unconstrained. */
    if (this.cols != rhs.rows) {
      throw new MatrixException("dimensions do not match")
    }

    /* figure out the dimensions of the result matrix */
    val (rrows, rcols, n) = (this.rows, rhs.cols, this.cols)

    /* allocate the result matrix */
    val result = NMatrix.empty[A](rrows, rcols)

    /* loop over the cells in the result matrix */
    for(y <- 0 until rrows; x <- 0 until rcols) {
      /* for each pair of values in this-row/rhs-column, multiply them
       * and then sum to get the result value for this cell. */
      result(y, x) = (0 until n).foldLeft(numeric.zero) {
        case (sum, i) => sum + (this(y, i) * rhs(i, x))
      }
    }
    result
  }

  /* render the rows one per line, wrapped in square brackets */
  def toAscii() = "[" + data.map {
    _.foldLeft("")(_ + " " + _.toString)
  }.reduceLeft(_ + "\n" + _) + "]"
}

object NMatrix {
  /* allocate a rows x cols matrix backed by a freshly created array */
  def empty[A:Numeric:Manifest](rows:Int, cols:Int) = {
    new NMatrix(Array.ofDim[A](rows, cols), rows, cols)
  }

  /* wrap an existing array (not copied) in a matrix */
  def apply[A:Numeric:Manifest](data:Array[Array[A]], rows:Int, cols:Int) = {
    new NMatrix(data, rows, cols)
  }
}

/**
 * Matrix implementation in terms of Double.
 *
 * Mirrors NMatrix operation for operation so that the two can be compared.
 */
class DMatrix(val data:Array[Array[Double]], val rows:Int, val cols:Int) {
  if (rows < 1) throw new MatrixException("illegal height")
  if (cols < 1) throw new MatrixException("illegal width")

  /* build an empty matrix with the same dimensions */
  def createEmpty() = DMatrix.empty(rows, cols)

  /* access the element at (y, x) */
  def apply(y:Int, x:Int) = data(y)(x)

  /* update the element at (y, x) to value */
  def update(y:Int, x:Int, value:Double) = data(y)(x) = value

  /* create a new Matrix by mapping each element through f */
  def map(f:Double => Double) = {
    new DMatrix(data.map(_.map(f).toArray).toArray, rows, cols)
  }

  /* combine two matrices element-by-element */
  def combine(rhs:DMatrix, f:(Double, Double) => Double) = {
    val result = createEmpty()
    for (y <- 0 until rows; x <- 0 until cols) {
      result(y, x) = f(this(y, x), rhs(y, x))
    }
    result
  }

  /* add a scalar value to each element */
  def +(a:Double):DMatrix = map(_ + a)

  /* add two matrices */
  def +(rhs:DMatrix):DMatrix = combine(rhs, _ + _)

  /* multiply each element by a scalar value */
  def *(a:Double):DMatrix = map(_ * a)

  /* multiply two matrices */
  def *(rhs:DMatrix):DMatrix = {

    /* an (r x n) matrix times an (n x c) matrix is defined exactly when
     * the inner dimensions agree; the outer ones are unconstrained. */
    if (this.cols != rhs.rows) {
      throw new MatrixException("dimensions do not match")
    }

    val (rrows, rcols, n) = (this.rows, rhs.cols, this.cols)

    val result = DMatrix.empty(rrows, rcols)

    for(y <- 0 until rrows; x <- 0 until rcols) {
      result(y, x) = (0 until n).foldLeft(0.0) {
        case (sum, i) => sum + (this(y, i) * rhs(i, x))
      }
    }

    result
  }

  /* render the rows one per line, wrapped in square brackets; note that,
   * unlike NMatrix.toAscii, the output ends with a newline. */
  def toAscii() = "[" + data.map {
    _.foldLeft("")(_ + " " + _.toString)
  }.reduceLeft(_ + "\n" + _) + "]\n"
}

object DMatrix {
  /* allocate a rows x cols matrix backed by a freshly created array */
  def empty(rows:Int, cols:Int) = {
    new DMatrix(Array.ofDim[Double](rows, cols), rows, cols)
  }

  /* wrap an existing array (not copied) in a matrix */
  def apply(data:Array[Array[Double]], rows:Int, cols:Int) = {
    new DMatrix(data, rows, cols)
  }
}



/**
 * Testing...
*/
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers

/**
 * Runs the same operations on an NMatrix[Double] and on a DMatrix and checks
 * that both produce the same cells.
 */
class MatrixSpec extends FunSuite with ShouldMatchers {
  val (h, w) = (2, 2)
  val data1 = Array(Array(1.0, 2.0), Array(3.0, 4.0))

  /* true when every cell of the h x w region holds the same value in both
   * matrices; both matrices are printed first to ease debugging. */
  def compare(m1:NMatrix[Double], m2:DMatrix):Boolean = {
    println(m1.toAscii)
    println(m2.toAscii)
    val cells = for (y <- 0 until h; x <- 0 until w) yield (y, x)
    /* expressed with exists instead of a non-local return out of the loop */
    !cells.exists { case (y, x) => m1(y, x) !== m2(y, x) }
  }

  /* Array.clone is shallow: copy each row so that the two matrices do not
   * share (and cannot overwrite) the same underlying row arrays. */
  val m1 = NMatrix(data1.map(_.clone), h, w)
  val m2 = DMatrix(data1.map(_.clone), h, w)

  test("test matrix representation") { assert(compare(m1, m2)) }

  test("scalar addition") { assert(compare(m1 + 13.0, m2 + 13.0)) }
  test("scalar multiplication") { assert(compare(m1 * 9.0, m2 * 9.0)) }

  test("matrix addition") { assert(compare(m1 + m1, m2 + m2)) }
  test("matrix multiplication") { assert(compare(m1 * m1, m2 * m2)) }
}
--------------------------------------------------------------------------------