├── .gitignore
├── COPYING
├── README.md
├── TODO
├── build.sbt
├── perf
├── build.sbt
├── lib
│ └── optimized-numeric-plugin_2.9.1-0.1.jar
└── src
│ └── main
│ └── scala
│ └── Main.scala
├── plugin
├── Makefile
├── build.sbt
└── src
│ ├── main
│ ├── resources
│ │ └── scalac-plugin.xml
│ └── scala
│ │ └── com
│ │ └── azavea
│ │ └── math
│ │ └── plugin
│ │ └── OptimizedNumeric.scala
│ └── test
│ └── scala
│ └── Example.scala
├── project
└── Build.scala
└── src
├── main
└── scala
│ └── com
│ └── azavea
│ └── math
│ ├── Convertable.scala
│ ├── EasyNumericOps.scala
│ ├── FastNumericOps.scala
│ ├── LiteralOps.scala
│ └── Numeric.scala
└── test
└── scala
└── Matrix.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | target/
2 | lib_managed/
3 | src_managed/
4 | project/boot/
5 | .ensime
6 | benchmark.html
7 |
--------------------------------------------------------------------------------
/COPYING:
--------------------------------------------------------------------------------
1 | Copyright (c) 2011 Erik Osheim, Azavea Inc.
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | this software and associated documentation files (the "Software"), to deal in
5 | the Software without restriction, including without limitation the rights to
6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
7 | of the Software, and to permit persons to whom the Software is furnished to do
8 | so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | NUMERIC
2 | =======
3 |
4 | This package (com.azavea.math.Numeric) is based on scala.math.Numeric, which is
5 | a type trait designed to allow abstraction across numeric types (e.g. Int,
6 | Double, etc). I (Erik Osheim) began this work as part of the
7 | specialized-numeric Scala Incubator project.
8 |
9 | This package is a much-improved version. For one, it is way faster, thanks to
10 | Scala's specialization system, as well as some restructuring and an optional
11 | compiler plugin. It is also more flexible, allowing the user to operate on
12 | generic numeric types, concrete numeric types, and literals without requiring
13 | ugly casting.
14 |
15 | Ultimately my hope is to push these changes back into Scala's library.
16 |
17 | REQUIREMENTS
18 | ============
19 |
20 | Build using SBT 0.10. Will port to SBT 0.11 soon.
21 |
22 |
23 | EXAMPLES
24 | ========
25 |
26 | Basic example of addition:
27 |
28 | import com.azavea.math.Numeric
29 | import com.azavea.math.FastImplicits._
30 |
31 | def adder[A:Numeric](a:A, b:A) = a + b
32 |
33 | Creating a Point3 object with coordinates of any numeric type:
34 |
35 | import com.azavea.math.Numeric
36 | import com.azavea.math.FastImplicits._
37 |
38 | case class Point3[T:Numeric](x:T, y:T, z:T) {
39 | def +(rhs:Point3[T]) = Point3(x + rhs.x, y + rhs.y, z + rhs.z)
40 |
41 | def distance(rhs:Point3[T]) = {
42 | val dx = x - rhs.x
43 | val dy = y - rhs.y
44 | val dz = z - rhs.z
45 | val d = dx * dx + dy * dy + dz * dz
46 | math.sqrt(d.toDouble)
47 | }
48 | }
49 |
50 | Mixing literals and variables:
51 |
52 | import com.azavea.math.Numeric
53 | import com.azavea.math.EasyImplicits._
54 | import Predef.{any2stringadd => _, _}
55 |
56 | def foo[T:Numeric](a:T, b:T):T = a * 100 + b
57 |
58 | Currently there are two different ways to use Numeric: EasyImplicits allows you
59 | to operate on mixed numeric types (e.g. T + U + Int). FastImplicits sacrifices
60 | this ability in exchange for speed. Both can be made equally fast (as fast as
61 | operating on direct types) with the optimized-numeric compiler plugin.
62 |
63 | If you plan to use the compiler plugin, or aren't worried about speed, you will
64 | probably want to use EasyImplicits.
65 |
66 |
67 | ANNOYING PREDEF
68 | ===============
69 |
70 | You may have noticed this ugly-looking import:
71 |
72 | import Predef.{any2stringadd => _, _}
73 |
74 | This is to work around a design problem in Scala. You may not always need to
75 | use this, but if you notice problems with + not working correctly you should
76 | add this top-level import. Sorry! :(
77 |
78 |
79 | PROJECT STRUCTURE
80 | =================
81 |
82 | Here is a list of the SBT projects:
83 |
84 | * root: contains the library code itself ("package" builds the library jar)
85 | * plugin: contains the compiler plugin code ("package" builds the plugin jar)
86 | * perf: contains the performance test
87 |
88 | Use "projects" to view them, and "project XYZ" to switch to XYZ.
89 |
90 | USING THE PLUGIN
91 | ================
92 |
93 | The optimized-numeric plugin is able to speed things up by rewriting certain
94 | constructions into other, faster ones. Here's an example:
95 |
96 | // written
97 | def foo[T:Numeric](a:T, b:T) = a + b
98 |
99 | // compiled
100 | def foo[T](a:T, b:T)(implicit ev:Numeric[T]) = new FastNumericOps(a).+(b)
101 |
102 | // compiled with plugin
103 | def foo[T](a:T, b:T)(implicit ev:Numeric[T]) = ev.add(a, b)
104 |
105 | In the future scalac might be able to do this for us (or hotspot might be able
106 | to optimize it away). But in the absence of these things the plugin helps make
107 | Numeric much faster (especially when using EasyImplicits).
108 |
109 | At the most basic level, you can add "-Xplugin:path/to/optimized-numeric.jar"
110 | to your scalac invocation to compile things with the plugin.
111 |
112 | When running the perf project, here are the steps you can take to build and
113 | enable the plugin:
114 |
115 | 1. in sbt: "project plugin", then "package"
116 | 2. cp plugin/target/scala-2.9.1.final/optimized-numeric-plugin_2.9.1-0.1.jar perf/lib/
117 | 3. make sure the perf/build.sbt line involving -Xplugin is uncommented (it is by default)
118 | 4. in sbt: "project perf", then "run"
119 |
120 | The plugin is enabled by default. To disable it, comment out the -Xplugin line in perf/build.sbt.
121 |
122 |
123 | BENCHMARKS
124 | ==========
125 |
126 | To run the benchmarks, do the following:
127 |
128 | 1. optionally build and install the plugin
129 | 2. in sbt: "project perf", then "run"
130 |
131 | The output shows the speed (in milliseconds) of a direct implementation
132 | (without generics), the new implementation (com.azavea.math.Numeric) and the
133 | old implementation (the built-in scala.math.Numeric). In some cases there is no
134 | old implementation--in those cases the test tries to hem as closely as possible
135 | to the direct implementation.
136 |
137 | * n:d is how new compares to direct
138 | * o:d is how old compares to direct
139 | * o:n is how old compares to new
140 |
141 | It also creates a benchmark.html file which colors the output.
142 |
143 | RESULTS
144 | =======
145 |
146 | There are some interesting results:
147 |
148 | 1. Both scala.math.Numeric and com.azavea.math.Numeric seem to perform worse
149 | on integral types than fractional ones. This is very pronounced for
150 | scala.math.Numeric and only slight for com.azavea.math.Numeric.
151 |
152 | 2. com.azavea.math.Numeric mostly* performs as well as direct implementations
153 | except when using infix operators without the compiler plugin. The built-in
154 | scala.math.Numeric is clearly inappropriate for any application where performance is
155 | important.
156 |
157 | 3. The asterisk in the previous item has to do with Quicksort. Basically,
158 | scala.util.Sorting uses Ordering\[A\] which is not specialized and which
159 | implements all its own (non-specialized) comparison operators in terms of
160 | compare().
161 |
162 | This ends up being really slow, so my Numeric trait doesn't extend it, but
163 | instead provides a getOrdering() method (which builds a separate Ordering
164 | instance wrapping the Numeric instance). As a result, it doesn't perform any
165 | better than scala.math.Numeric on this test (and in fact does a bit worse).
166 |
167 | I don't know how likely it is that Ordering will be specialized, but huge
168 | performance gains seem possible.
169 |
170 | 4. scala.util.Sorting.quickSort lacks a direct Long implementation, so using it
171 | with Longs is ~5x slower than Int, Float or Double.
172 |
173 | 5. It seems like scala.util.Sorting could use some love. My naive direct
174 | implementation of merge sort seems to beat Sorting.quickSort for Long
175 | (obviously), Float and Double. That said, optimizing sort algorithms can be
176 | tricky. But gains seem possible.
177 |
178 |
179 | DIFFERENCES
180 | ===========
181 |
182 | While very similar to scala.math.Numeric, com.azavea.math.Numeric has
183 | some differences. The most significant ones are:
184 |
185 | 1. It does not inherit from the Ordering type class, but rather directly
186 | implements the comparison methods. I will try to do some cleanup on this and
187 | make it more compatible with Ordering, but it was important to me that the
188 | comparison methods are also specialized.
189 |
190 | 2. It does not implement Integral/Fractional. I think that leaving
191 | division/modulo off of Numeric is a mistake, and don't think that forcing users
192 | to use Integral/Fractional is a good idea. Given that Scala uses the same
193 | symbol (/) to mean both "integer division" and "true division" it seems clear
194 | that Numeric can too.
195 |
196 | 3. It's in a different package.
197 |
198 | 4. It adds some operators that I thought would be nice to have:
199 |
200 | - <=> as an alias for compare
201 | - === as an alias for equiv
202 | - !== as an alias for !equiv
203 | - ** as an alias for math.pow
204 |
205 | 5. It adds a full suite of conversions. Unlike the existing Numeric, you can
206 | convert directly to/from any numeric type. This is useful since you might be
207 | going from a generic type to a known type (e.g. A -> Int) or a known type to a
208 | generic one (Int -> A). In both cases it is important not to do any unnecessary
209 | work (when A is an Int, you should not have to do any copying/casting).
210 |
211 | These conversions can be used from the instance of Numeric[T] or directly
212 | on the values of T themselves:
213 |
214 | def foo[T:Numeric](t:T):Double = {
215 | val i:Int = numeric.toInt(t)
216 | val d:Double = t.toDouble
217 | d - i
218 | }
219 |
220 | CAVEATS
221 | =======
222 |
223 | This section is just for "known problems" with the Numeric type class approach
224 | in Scala.
225 |
226 | Precision
227 | ---------
228 |
229 | Given the signatures of Numeric's functions, it's possible to mix literals and
230 | generic types in a way which will lose precision. Consider:
231 |
232 | def foo[T:Numeric](t:T) = t + 9.2
233 |
234 | The author might expect that foo\[Int\](5) will return a Double (14.2), but in
235 | fact foo\[T\] returns T and so foo\[Int\] will return an Int (14). The solution
236 | in cases like this (if you know you want a double) is to convert T to a Double
237 | first:
238 |
239 | def foo[T:Numeric](t:T) = t.toDouble + 9.2
240 |
241 | You could imagine that + could "figure out" whether T has more precision than
242 | Double (e.g. BigDecimal) or less precision (e.g. Int) and "do the right thing".
243 | Sadly, this is not possible using the current strategy.
244 |
245 | If you are interested in implementing a numeric tower in Scala (presumably by
246 | writing some pretty intense compiler plugins if not modifying Scala's type
247 | system) please get in touch with the author. :)
248 |
249 | Clunky syntax
250 | -------------
251 |
252 | Unfortunately in order to benefit from the speed of this library you must
253 | annotate all your numeric type parameters with @specialized. Also, you often need a
254 | manifest (for instance when you want to allocate Arrays of that type), so you
255 | have to provide those too. Your code will end up looking similar to this:
256 |
257 | def handleInt(a:Int) = ...
258 |
259 | def handleA[@specialized A:Numeric:Manifest](a:A) = ...
260 |
261 | When you compare these visually the second is obviously terrible. In many of
262 | the examples I have omitted the @specialized annotation and the Manifest type
263 | bound for clarity. Hopefully this is not too deceptive.
264 |
265 | It would be great if there was some way to create a type bound that "included"
266 | specialization. For instance:
267 |
268 | // written
269 | def bar[T:SNumeric](a:T, b:T, c:T) = a * b + c
270 |
271 | // compiled
272 | def bar[@specialized(Int,Long,Float,Double) T:Numeric](a:T, b:T, c:T) = a * b + c
273 |
274 | The current design of specialization works well when a library author expects
275 | users to call her generic functions with concrete types. But if her users are
276 | themselves defining generic functions, the library author is powerless to help.
277 |
278 | Obviously, this could pose a real problem in terms of a bytecode explosion, which
279 | is why it should not be a default behavior. In fact, Numeric may be one of the
280 | few places where this behavior would be desirable.
281 |
282 | Non-extensibility
283 | -----------------
284 |
285 | While you should be able to implement your own instances of Numeric for
286 | user-generated types (e.g. Complex or BigRational) you will not be able to use
287 | the full power of this library due to the need to add methods to the
288 | ConvertableFrom and ConvertableTo traits.
289 |
290 | This is unfortunate, and I'm exploring a pluggable numeric conversion
291 | system. However, in the interest of getting a fast, working implementation
292 | out there, I have shelved that for now.
293 |
294 | Lack of range support
295 | ---------------------
296 |
297 | I am still working on implementing a corresponding NumericRange. For now
298 | you'll need to convert to a known type (e.g. Long), or use while loops. This
299 | is definitely possible and will hopefully be done soon.
300 |
--------------------------------------------------------------------------------
/TODO:
--------------------------------------------------------------------------------
1 | This is a random list of things that either would be good to do, or would be
2 | good to consider doing. Some of these are mutually-incompatible (or at least,
3 | mutually-undesirable).
4 |
5 | Some of these may also be impossible! :)
6 |
7 | 1. Implement a NumericRange class
8 |
9 | Paul Phillips has encouraged me to do this. I managed to create something that
10 | worked but wasn't performing as well as the direct integer ranges. The problem
11 | seemed to be the difference between the way foo(Int => Unit) and foo[T](T =>
12 | Unit) are treated.
13 |
14 |
15 | 2. Fix support for user-created types
16 |
17 | Right now the strategy we're using with ConvertableTo/ConvertableFrom doesn't
18 | scale well when users are adding their own types. This may not actually be a
19 | big deal, but it'd be nice if the system was more pluggable. Right now if we
20 | want to provide support for converting to/from "Foo" objects we have to change the
21 | ConvertableFrom/ConvertableTo traits to include methods "toFoo" and "fromFoo".
22 |
23 | I can imagine using something like ConvertableBetween[A, B] to solve this, but
24 | I don't want to sacrifice performance to do so. Experimentation necessary.
25 | See https://github.com/nuttycom/salt/commits/master/src/main/scala/com/nommit/salt/Bijection.scala
26 | for ideas.
27 |
28 |
29 | 3. Specialize Ordering
30 |
31 | This is probably too big for this project, but it would be great if Ordering
32 | were specialized so that Numeric could extend it without taking a huge speed
33 | hit on things like Numeric.lt().
34 |
35 |
36 | 4. Rationalize working with two different numeric types
37 |
38 | Currently if you have "def complicated[T:Numeric, U:Numeric](t:T, u:U)" you
39 | will have to do manual conversions to one of those types without being sure
40 | that you aren't losing precision. I'm not sure there's a way to get around
41 | this, but it might be nice to have a way to compare two Numeric objects in
42 | terms of which one should "win" in terms of precision. That way you could at
43 | least say something like "if (n1.morePreciseThan(n2)) ... else ..."
44 |
45 | I'm not sure how often this would be used, but I could imagine it being nice.
46 |
47 |
48 | 5. Test on more hardware/JVM configurations
49 |
50 |
51 | 6. Port tests to one of the newer performance testing frameworks
52 |
53 |
54 | 7. Write more tests
55 |
56 |
57 | 8. Figure out how to deal with the "strict" infix operators versus the "fuzzy"
58 | infix operators. There is a compiler plugin to rewrite both into the faster
59 | form "n.add(lhs, rhs)" but without that the strict infix operators perform
60 | better, and also don't fall afoul of StringOps' + method. On the other hand,
61 | the fuzzy infix operators work better with literals (e.g. 3) and concrete types
62 | (e.g. i:Int).
63 |
64 |
65 | 9. Add "auto-specialization" to the compiler plugin
66 |
67 | I can imagine creating a type alias for Numeric ("SNumeric") and then
68 | making the compiler plugin automatically specialize any uses of it. This way a
69 | user's code could automatically get the gains of specialization without having
70 | to annotate every single function manually.
71 |
72 | I'm not sure if this would actually work, but it'd be worth looking into.
73 |
74 |
75 | 10. Think about how to support other math functions
76 |
77 | Right now there are still some functions (sqrt and others from scala.math)
78 | which aren't defined directly on the types. This is bad, because when you need
79 | to use these functions you end up either calling toDouble (and thus breaking
80 | support for BigDecimal and friends) or you call toBigDecimal (and slow things
81 | down for your fast primitive types). We could define these functions to all
82 | return T, but then you'd have no way of getting a fractional sqrt() when using
83 | Ints.
84 |
85 | This gets back to the issues with the lack of a numeric tower. When taking the
86 | sqrt of an Int we probably want to get back a Double, but when taking the sqrt
87 | of a BigInt we probably want a BigDecimal. I don't see any easy way of dealing
88 | with it, but I may not be creative enough. At various points I've had quixotic
89 | visions of a compiler plugin that could "figure it out"...
90 |
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
1 | // project name
2 | name := "Numeric"
3 |
4 | // project version
5 | version := "0.1"
6 |
7 | // ScalaTest. NOTE(review): not scoped with % "test", so it is also on the compile classpath -- confirm intended
8 | libraryDependencies += "org.scalatest" % "scalatest_2.9.0" % "1.6.1"
9 |
10 | // exclude editor backup files (names ending in ~) from the sources
11 | defaultExcludes ~= (filter => filter || "*~")
12 |
13 | scalacOptions += "-optimise"
14 |
15 | // any of these Scala versions work, although 2.9.1 performs the best
16 | //scalaVersion := "2.8.1"
17 | //scalaVersion := "2.9.0-1"
18 | scalaVersion := "2.9.1"
19 |
--------------------------------------------------------------------------------
/perf/build.sbt:
--------------------------------------------------------------------------------
1 | // project name
2 | name := "Numeric Performance Test"
3 |
4 | // project version
5 | version := "0.1"
6 |
7 | // exclude editor backup files (names ending in ~) from the sources
8 | defaultExcludes ~= (filter => filter || "*~")
9 |
10 | scalacOptions += "-optimise"
11 |
12 | //autoCompilerPlugins := true
13 |
14 | scalacOptions += "-Xplugin:perf/lib/optimized-numeric-plugin_2.9.1-0.1.jar"
15 | //scalacOptions += "-Xplugin:lib/optimized-numeric.jar"
16 |
17 | // any of these Scala versions work, although 2.9.1 performs the best
18 | //scalaVersion := "2.8.1"
19 | //scalaVersion := "2.9.0-1"
20 | scalaVersion := "2.9.1"
21 |
--------------------------------------------------------------------------------
/perf/lib/optimized-numeric-plugin_2.9.1-0.1.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/azavea/numeric/15a1e0381224826b418a9d9a6b2c9773ea917137/perf/lib/optimized-numeric-plugin_2.9.1-0.1.jar
--------------------------------------------------------------------------------
/perf/src/main/scala/Main.scala:
--------------------------------------------------------------------------------
1 | import scala.math.max
2 | import scala.math.{Numeric => OldNumeric, Integral, Fractional, min, max}
3 | import scala.util.Random
4 | import scala.testing.Benchmark
5 |
6 | import java.io.{FileWriter, PrintWriter}
7 |
8 | import Console.printf
9 |
10 | import com.azavea.math.Numeric
11 | import com.azavea.math.EasyImplicits._
12 | //import com.azavea.math.FastImplicits._
13 | import Predef.{any2stringadd => _, _}
14 |
/**
 * Constant sizes and pre-generated random input arrays shared by the
 * performance tests. If things run way too slow or way too fast, change
 * `factor` (multiplies every size) or `divisor` (divides every size).
 *
 * The arrays are built with Array.fill, which generates each element in
 * place instead of allocating a zero-filled array and mapping it into a
 * second, equally large one.
 */
object Constant {
  val factor = 1
  val divisor = 1

  val smallSize:Int = (10 * 1000 * factor) / divisor
  val smallIntArray = Array.fill(smallSize)(Random.nextInt())
  val smallLongArray = Array.fill(smallSize)(Random.nextLong())
  val smallFloatArray = Array.fill(smallSize)(Random.nextFloat())
  val smallDoubleArray = Array.fill(smallSize)(Random.nextDouble())

  val mediumSize:Int = (1 * 1000 * 1000 * factor) / divisor
  val mediumIntArray = Array.fill(mediumSize)(Random.nextInt())
  val mediumLongArray = Array.fill(mediumSize)(Random.nextLong())
  val mediumFloatArray = Array.fill(mediumSize)(Random.nextFloat())
  val mediumDoubleArray = Array.fill(mediumSize)(Random.nextDouble())

  val largeSize:Int = (10 * 1000 * 1000 * factor) / divisor
  val largeIntArray = Array.fill(largeSize)(Random.nextInt())
  val largeLongArray = Array.fill(largeSize)(Random.nextLong())
  val largeFloatArray = Array.fill(largeSize)(Random.nextFloat())
  val largeDoubleArray = Array.fill(largeSize)(Random.nextDouble())

  // NOTE(review): read outside this view, presumably to decide whether
  // benchmark.html is written -- confirm
  val createHTML = true
}
42 | import Constant._
43 |
/**
 * Timing summary for one benchmark case over its timed runs.
 *
 * @param tavg mean run time
 * @param tmax slowest run time
 * @param tmin fastest run time
 */
case class TestResult(
  tavg: Double,
  tmax: Long,
  tmin: Long
)
45 |
/**
 * A single performance test, implemented three ways so they can be compared:
 * directly on primitives, with the new com.azavea.math.Numeric, and with the
 * built-in scala.math.Numeric. Each implementation is wrapped in a
 * scala.testing.Benchmark and timed by `test`.
 */
trait TestCase {
  def name: String

  /** Direct implementation using primitives. */
  def direct(): Option[Any]

  /** Implementation using the new Numeric trait. */
  def newGeneric(): Option[Any]

  /** Implementation using the built-in Numeric trait. */
  def oldGeneric(): Option[Any]

  object CaseDirect extends Benchmark { def run = direct }
  object CaseGeneric extends Benchmark { def run = newGeneric }
  object CaseOld extends Benchmark { def run = oldGeneric }

  // order matters: test() reads the results back as (direct, new, old)
  val cases = List(CaseDirect, CaseGeneric, CaseOld)

  /**
   * Benchmarks `c`: `warmupRuns` runs whose timings are discarded, then
   * `liveRuns` timed runs summarized into a TestResult (average, max, min).
   *
   * @throws IllegalArgumentException if `liveRuns` is not positive, since
   *         there would be no timings to summarize
   */
  def runCase(c:Benchmark, warmupRuns:Int, liveRuns:Int) = {
    require(liveRuns > 0, "liveRuns must be positive")

    c.runBenchmark(warmupRuns)
    val results = c.runBenchmark(liveRuns)

    val total = results.foldLeft(0L)(_ + _)
    val tmax = results.foldLeft(0L)(max(_, _))
    val tmin = results.foldLeft(Long.MaxValue)(min(_, _))
    val tavg = total.toDouble / liveRuns

    TestResult(tavg, tmax, tmin)
  }

  /** Benchmarks every case with `n` warm-up and `m` timed runs, in `cases` order. */
  def run(n:Int, m:Int):List[TestResult] = cases.map(runCase(_, n, m))

  /** Labels a time ratio (lower is better); the label doubles as the CSS class of HTML cells. */
  def percStatus(p:Double) =
    if (p < 0.9) "great"
    else if (p < 1.1) "good"
    else if (p < 2.2) "ok"
    else if (p < 4.4) "poor"
    else if (p < 8.8) "bad"
    else "awful"

  /**
   * HTML cells for one generic implementation: its average time and its
   * formatted ratio, both tagged with the status class of `ratio`.
   * Times under 0.1ms yield two empty cells.
   */
  private def htmlCells(time:Double, ratio:Double, ratioStr:String): String =
    if (time < 0.1) {
      "<td></td><td></td>"
    } else {
      val status = percStatus(ratio)
      // every placeholder matches the type of the argument it formats
      "<td class='%s'>%.1f</td><td class='%s'>%s</td>".format(status, time, status, ratioStr)
    }

  /**
   * Benchmarks all three implementations (2 warm-up and 8 timed runs each) and
   * prints a one-line summary to stdout. When `p` is defined, it also writes the
   * results to `p` as one HTML table row. Ratios are reported as new:direct,
   * old:direct and old:new.
   *
   * NOTE(review): the row assumes the caller writes the enclosing <table> and the
   * CSS for the status classes -- confirm against the code that opens benchmark.html.
   */
  def test(p:Option[PrintWriter]): Unit = {
    val results = run(2, 8)

    val times = results.map(_.tavg)
    val tstrs = times.map(t => if (t < 0.1) " n/a" else "%6.1fms".format(t))

    val List(t1, t2, t3) = times

    // ratio a/b, or 0.0 when either time is zero
    def mkp(a:Double, b:Double) = if (a == 0.0 || b == 0.0) 0.0 else a / b

    val percs = List(mkp(t2, t1), mkp(t3, t1), mkp(t3, t2))
    val pstrs = percs.map(r => if (r < 0.01) " n/a" else "%5.2fx".format(r))

    p.foreach { pw =>
      val lead = "<tr><td>%s</td><td>%.1f</td>".format(name, t1)
      val newCells = htmlCells(t2, percs(0), pstrs(0))
      val oldCells = htmlCells(t3, percs(1), pstrs(1))
      pw.println(lead + newCells + oldCells + "</tr>")
    }

    val fields = ("%-24s".format(name) :: tstrs) ++ ("/" :: pstrs)
    println(fields.mkString(" "))
  }
}
130 |
131 |
132 | // ===============================================================
/**
 * Benchmarks converting an Array[Int] into a new array of Int, Long, Float or
 * Double. The direct* methods use primitive conversions, newFromInts uses the
 * new Numeric's fromInt, and oldFromInts uses scala.math.Numeric's fromInt.
 * Every variant fills a freshly allocated array with the same while-loop
 * shape, so only the conversion mechanism differs between them.
 */
trait FromIntToX extends TestCase {
  def directToInt(a:Array[Int]) = {
    val out = Array.ofDim[Int](a.length)
    var idx = 0
    while (idx < a.length) {
      out(idx) = a(idx).toInt
      idx += 1
    }
    out
  }

  def directToLong(a:Array[Int]) = {
    val out = Array.ofDim[Long](a.length)
    var idx = 0
    while (idx < a.length) {
      out(idx) = a(idx).toLong
      idx += 1
    }
    out
  }

  def directToFloat(a:Array[Int]) = {
    val out = Array.ofDim[Float](a.length)
    var idx = 0
    while (idx < a.length) {
      out(idx) = a(idx).toFloat
      idx += 1
    }
    out
  }

  def directToDouble(a:Array[Int]) = {
    val out = Array.ofDim[Double](a.length)
    var idx = 0
    while (idx < a.length) {
      out(idx) = a(idx).toDouble
      idx += 1
    }
    out
  }

  /** Converts each Int through the new Numeric type class. */
  def newFromInts[@specialized A:Numeric:Manifest](a:Array[Int]): Array[A] = {
    val out = Array.ofDim[A](a.length)
    var idx = 0
    while (idx < a.length) {
      out(idx) = numeric.fromInt(a(idx))
      idx += 1
    }
    out
  }

  /** Converts each Int through scala.math.Numeric. */
  def oldFromInts[A:OldNumeric:Manifest](a:Array[Int]): Array[A] = {
    val ev = implicitly[OldNumeric[A]]
    val out = Array.ofDim[A](a.length)
    var idx = 0
    while (idx < a.length) {
      out(idx) = ev.fromInt(a(idx))
      idx += 1
    }
    out
  }
}
194 |
/** Converts largeIntArray to Int three ways. */
final class FromIntToInt extends FromIntToX {
  def name: String = "from-int-to-int"

  def direct(): Option[Array[Int]] = Option(directToInt(largeIntArray))
  def newGeneric(): Option[Array[Int]] = Option(newFromInts[Int](largeIntArray))
  def oldGeneric(): Option[Array[Int]] = Option(oldFromInts[Int](largeIntArray))
}
201 |
/** Converts largeIntArray to Long three ways. */
final class FromIntToLong extends FromIntToX {
  def name: String = "from-int-to-long"

  def direct(): Option[Array[Long]] = Option(directToLong(largeIntArray))
  def newGeneric(): Option[Array[Long]] = Option(newFromInts[Long](largeIntArray))
  def oldGeneric(): Option[Array[Long]] = Option(oldFromInts[Long](largeIntArray))
}
208 |
/** Converts largeIntArray to Float three ways. */
final class FromIntToFloat extends FromIntToX {
  def name: String = "from-int-to-float"

  def direct(): Option[Array[Float]] = Option(directToFloat(largeIntArray))
  def newGeneric(): Option[Array[Float]] = Option(newFromInts[Float](largeIntArray))
  def oldGeneric(): Option[Array[Float]] = Option(oldFromInts[Float](largeIntArray))
}
215 |
/** Converts largeIntArray to Double three ways. */
final class FromIntToDouble extends FromIntToX {
  def name: String = "from-int-to-double"

  def direct(): Option[Array[Double]] = Option(directToDouble(largeIntArray))
  def newGeneric(): Option[Array[Double]] = Option(newFromInts[Double](largeIntArray))
  def oldGeneric(): Option[Array[Double]] = Option(oldFromInts[Double](largeIntArray))
}
222 |
223 |
224 |
225 | // =================================================================
/**
 * Shared pieces of the adder benchmarks: abstract generic adders for the new
 * and the built-in Numeric, plus monomorphic adders used by the direct variants.
 */
trait BaseAdder extends TestCase {
  /** Generic adder over the new Numeric; concrete traits (e.g. Adder) supply the body. */
  def newAdder[@specialized A](a: A, b: A)(implicit m: Numeric[A]): A

  /** Generic adder over scala.math.Numeric; concrete traits supply the body. */
  def oldAdder[A](a: A, b: A)(implicit m: OldNumeric[A]): A

  // monomorphic baselines called by the direct() variants
  def directIntAdder(a: Int, b: Int): Int = a + b
  def directLongAdder(a: Long, b: Long): Long = a + b
  def directFloatAdder(a: Float, b: Float): Float = a + b
  def directDoubleAdder(a: Double, b: Double): Double = a + b
}
235 |
/**
 * Adder benchmark over Int: each variant folds k, for k in 0 until largeSize,
 * into an Int accumulator starting at 0, using its own adder.
 */
trait BaseAdderInt extends BaseAdder {
  def direct() = {
    var acc = 0
    var k = 0
    while (k < largeSize) {
      acc = directIntAdder(acc, k)
      k += 1
    }
    Option(acc)
  }

  def newGeneric() = {
    var acc = 0
    var k = 0
    while (k < largeSize) {
      acc = newAdder(acc, k)
      k += 1
    }
    Option(acc)
  }

  def oldGeneric() = {
    var acc = 0
    var k = 0
    while (k < largeSize) {
      acc = oldAdder(acc, k)
      k += 1
    }
    Option(acc)
  }
}
256 |
/**
 * Adder benchmark over Long: each variant folds the Int counter k, for k in
 * 0 until largeSize, into a Long accumulator starting at 0L.
 */
trait BaseAdderLong extends BaseAdder {
  def direct() = {
    var acc = 0L
    var k = 0
    while (k < largeSize) {
      acc = directLongAdder(acc, k)
      k += 1
    }
    Option(acc)
  }

  def newGeneric() = {
    var acc = 0L
    var k = 0
    while (k < largeSize) {
      acc = newAdder(acc, k)
      k += 1
    }
    Option(acc)
  }

  def oldGeneric() = {
    var acc = 0L
    var k = 0
    while (k < largeSize) {
      acc = oldAdder(acc, k)
      k += 1
    }
    Option(acc)
  }
}
277 |
/**
 * Adder benchmark over Float: each variant folds the Int counter k, for k in
 * 0 until largeSize, into a Float accumulator starting at 0.0F.
 */
trait BaseAdderFloat extends BaseAdder {
  def direct() = {
    var acc = 0.0F
    var k = 0
    while (k < largeSize) {
      acc = directFloatAdder(acc, k)
      k += 1
    }
    Option(acc)
  }

  def newGeneric() = {
    var acc = 0.0F
    var k = 0
    while (k < largeSize) {
      acc = newAdder(acc, k)
      k += 1
    }
    Option(acc)
  }

  def oldGeneric() = {
    var acc = 0.0F
    var k = 0
    while (k < largeSize) {
      acc = oldAdder(acc, k)
      k += 1
    }
    Option(acc)
  }
}
298 |
/**
 * Adder benchmark over Double: each variant folds the Int counter k, for k in
 * 0 until largeSize, into a Double accumulator starting at 0.0.
 */
trait BaseAdderDouble extends BaseAdder {
  def direct() = {
    var acc = 0.0
    var k = 0
    while (k < largeSize) {
      acc = directDoubleAdder(acc, k)
      k += 1
    }
    Option(acc)
  }

  def newGeneric() = {
    var acc = 0.0
    var k = 0
    while (k < largeSize) {
      acc = newAdder(acc, k)
      k += 1
    }
    Option(acc)
  }

  def oldGeneric() = {
    var acc = 0.0
    var k = 0
    while (k < largeSize) {
      acc = oldAdder(acc, k)
      k += 1
    }
    Option(acc)
  }
}
319 |
320 |
321 | // =========================================================
/** Implements both generic adders with the corresponding type class's plus method. */
trait Adder extends BaseAdder {
  def newAdder[@specialized A](a: A, b: A)(implicit m: Numeric[A]): A =
    m.plus(a, b)

  def oldAdder[A](a: A, b: A)(implicit m: OldNumeric[A]): A =
    m.plus(a, b)
}
326 |
/** Int adder benchmark using Adder's plus-based generic adders. */
final class AdderInt extends Adder with BaseAdderInt {
  def name: String = "adder-int"
}
/** Long adder benchmark using Adder's plus-based generic adders. */
final class AdderLong extends Adder with BaseAdderLong {
  def name: String = "adder-long"
}
/** Float adder benchmark using Adder's plus-based generic adders. */
final class AdderFloat extends Adder with BaseAdderFloat {
  def name: String = "adder-float"
}
/** Double adder benchmark using Adder's plus-based generic adders. */
final class AdderDouble extends Adder with BaseAdderDouble {
  def name: String = "adder-double"
}
331 |
332 |
333 | // =====================================================
/**
 * Benchmarks that reduce an array to a single value. Starting from zero, each
 * element is combined into a running total by a binary op that concrete tests
 * supply: one per primitive type for the direct variants, plus generic ops
 * over the new and the built-in Numeric. All variants share the same
 * while-loop shape so only the op mechanism differs.
 */
trait BaseArrayOps extends TestCase {
  def directIntArrayOp(a:Int, b:Int): Int
  def directIntArrayOps(a:Array[Int]) = {
    var acc = 0
    var idx = 0
    while (idx < a.length) {
      acc = directIntArrayOp(acc, a(idx))
      idx += 1
    }
    acc
  }

  def directLongArrayOp(a:Long, b:Long): Long
  def directLongArrayOps(a:Array[Long]) = {
    var acc = 0L
    var idx = 0
    while (idx < a.length) {
      acc = directLongArrayOp(acc, a(idx))
      idx += 1
    }
    acc
  }

  def directFloatArrayOp(a:Float, b:Float): Float
  def directFloatArrayOps(a:Array[Float]) = {
    var acc = 0.0F
    var idx = 0
    while (idx < a.length) {
      acc = directFloatArrayOp(acc, a(idx))
      idx += 1
    }
    acc
  }

  def directDoubleArrayOp(a:Double, b:Double): Double
  def directDoubleArrayOps(a:Array[Double]) = {
    var acc = 0.0
    var idx = 0
    while (idx < a.length) {
      acc = directDoubleArrayOp(acc, a(idx))
      idx += 1
    }
    acc
  }

  /** Binary op over the new Numeric; the reduction starts from numeric.zero. */
  def newArrayOp[@specialized A:Numeric](a:A, b:A): A
  def newArrayOps[@specialized A:Numeric:Manifest](a:Array[A]) = {
    var acc = numeric.zero
    var idx = 0
    while (idx < a.length) {
      acc = newArrayOp(acc, a(idx))
      idx += 1
    }
    acc
  }

  /** Binary op over scala.math.Numeric; the reduction starts from its zero. */
  def oldArrayOp[A](a:A, b:A)(implicit m:OldNumeric[A]): A
  def oldArrayOps[A:OldNumeric:Manifest](a:Array[A]) = {
    val ev = implicitly[OldNumeric[A]]
    var acc = ev.zero
    var idx = 0
    while (idx < a.length) {
      acc = oldArrayOp(acc, a(idx))
      idx += 1
    }
    acc
  }
}
402 |
403 |
/**
 * Benchmarks that map a unary operation over an array.
 *
 * Every `*ArrayOps` method applies its `*ArrayOp` to each element and writes
 * the result back into the SAME array, which it then returns. The argument is
 * therefore mutated in place: callers that pass shared data must pass a copy.
 */
trait BaseArrayMapOps extends TestCase {
  def directIntArrayOp(a:Int): Int

  /** Replaces every element of `a` with `directIntArrayOp` of it; returns `a`. */
  def directIntArrayOps(a:Array[Int]) = {
    var i = 0
    while (i < a.length) {
      a(i) = directIntArrayOp(a(i))
      i += 1
    }
    a
  }

  def directLongArrayOp(a:Long): Long

  /** Replaces every element of `a` with `directLongArrayOp` of it; returns `a`. */
  def directLongArrayOps(a:Array[Long]) = {
    var i = 0
    while (i < a.length) {
      a(i) = directLongArrayOp(a(i))
      i += 1
    }
    a
  }

  def directFloatArrayOp(a:Float): Float

  /** Replaces every element of `a` with `directFloatArrayOp` of it; returns `a`. */
  def directFloatArrayOps(a:Array[Float]) = {
    var i = 0
    while (i < a.length) {
      a(i) = directFloatArrayOp(a(i))
      i += 1
    }
    a
  }

  def directDoubleArrayOp(a:Double): Double

  /** Replaces every element of `a` with `directDoubleArrayOp` of it; returns `a`. */
  def directDoubleArrayOps(a:Array[Double]) = {
    var i = 0
    while (i < a.length) {
      a(i) = directDoubleArrayOp(a(i))
      i += 1
    }
    a
  }

  def newArrayOp[@specialized A:Numeric](a:A): A

  /** Replaces every element of `a` with `newArrayOp` of it; returns `a`. */
  def newArrayOps[@specialized A:Numeric:Manifest](a:Array[A]) = {
    var i = 0
    while (i < a.length) {
      a(i) = newArrayOp(a(i))
      i += 1
    }
    a
  }

  def oldArrayOp[A](a:A)(implicit m:OldNumeric[A]): A

  /**
   * Replaces every element of `a` with `oldArrayOp` of it; returns `a`.
   * The OldNumeric instance reaches `oldArrayOp` through the context bound;
   * no local binding of it is needed.
   */
  def oldArrayOps[A:OldNumeric:Manifest](a:Array[A]) = {
    var i = 0
    while (i < a.length) {
      a(i) = oldArrayOp(a(i))
      i += 1
    }
    a
  }
}
466 |
467 |
468 | // ======================================================
/** Array-total benchmark: the fold step of BaseArrayOps is plain addition. */
trait ArrayAdder extends BaseArrayOps {
  // hand-written primitive additions
  def directIntArrayOp(a:Int, b:Int): Int = a + b
  def directLongArrayOp(a:Long, b:Long): Long = a + b
  def directFloatArrayOp(a:Float, b:Float): Float = a + b
  def directDoubleArrayOp(a:Double, b:Double): Double = a + b

  // generic additions through the specialized and the old type classes
  def newArrayOp[@specialized A:Numeric](a:A, b:A): A = numeric.plus(a, b)
  def oldArrayOp[A](a:A, b:A)(implicit m:OldNumeric[A]): A = m.plus(a, b)
}
478 |
/** Sums `largeIntArray` directly and through both generic paths. */
final class IntArrayAdder extends ArrayAdder {
  def name = "array-total-int"
  def direct() = {
    val data = largeIntArray
    Option(directIntArrayOps(data))
  }
  def newGeneric() = {
    val data = largeIntArray
    Option(newArrayOps(data))
  }
  def oldGeneric() = {
    val data = largeIntArray
    Option(oldArrayOps(data))
  }
}

/** Sums `largeLongArray` directly and through both generic paths. */
final class LongArrayAdder extends ArrayAdder {
  def name = "array-total-long"
  def direct() = {
    val data = largeLongArray
    Option(directLongArrayOps(data))
  }
  def newGeneric() = {
    val data = largeLongArray
    Option(newArrayOps(data))
  }
  def oldGeneric() = {
    val data = largeLongArray
    Option(oldArrayOps(data))
  }
}

/** Sums `largeFloatArray` directly and through both generic paths. */
final class FloatArrayAdder extends ArrayAdder {
  def name = "array-total-float"
  def direct() = {
    val data = largeFloatArray
    Option(directFloatArrayOps(data))
  }
  def newGeneric() = {
    val data = largeFloatArray
    Option(newArrayOps(data))
  }
  def oldGeneric() = {
    val data = largeFloatArray
    Option(oldArrayOps(data))
  }
}

/** Sums `largeDoubleArray` directly and through both generic paths. */
final class DoubleArrayAdder extends ArrayAdder {
  def name = "array-total-double"
  def direct() = {
    val data = largeDoubleArray
    Option(directDoubleArrayOps(data))
  }
  def newGeneric() = {
    val data = largeDoubleArray
    Option(newArrayOps(data))
  }
  def oldGeneric() = {
    val data = largeDoubleArray
    Option(oldArrayOps(data))
  }
}
506 |
507 |
508 |
509 |
510 | // ==========================================================================
/**
 * Rescale benchmark: each element `b` is mapped to `(b * 5) / 3`.
 *
 * NOTE(review): `oldArrayOp` goes through Double and then `toInt` before
 * `fromInt`. For Float/Double this drops the fractional part, and for Long it
 * limits the result to the Int range. Its output therefore does not match the
 * direct/new versions for those types.
 */
trait ArrayRescale extends BaseArrayMapOps {
  // hand-written primitive rescales
  def directIntArrayOp(b:Int): Int = (b * 5) / 3
  def directLongArrayOp(b:Long): Long = (b * 5L) / 3L
  def directFloatArrayOp(b:Float): Float = (b * 5.0F) / 3.0F
  def directDoubleArrayOp(b:Double): Double = (b * 5.0) / 3.0

  /** Multiplies by 5 then divides by 3 using the specialized Numeric's own operations. */
  def newArrayOp[@specialized A](b:A)(implicit m:Numeric[A]): A = {
    val scaled = m.times(b, m.fromDouble(5.0))
    m.div(scaled, m.fromDouble(3.0))
  }

  /** Rescales in Double space, then converts back through Int (see note above). */
  def oldArrayOp[A](b:A)(implicit m:OldNumeric[A]): A = {
    val rescaled = m.toDouble(b) * 5.0 / 3.0
    m.fromInt(rescaled.toInt)
  }
}
520 |
/*
 * The BaseArrayMapOps operations write their results back into the array they
 * are given. Each run therefore gets a fresh clone of the input. Without it,
 * every run would rescale the same data again, compounding the 5/3 factor
 * (Int/Long values overflow after repeated runs), and any later test reading
 * the same array would see rescaled data. This assumes the large*Array members
 * are shared vals, as the merge-sort tests' cloning also suggests.
 */

/** Rescales a copy of `largeIntArray` directly and through both generic paths. */
final class IntArrayRescale extends ArrayRescale {
  def name = "array-rescale-int"
  def direct() = Option(directIntArrayOps(largeIntArray.clone))
  def newGeneric() = Option(newArrayOps(largeIntArray.clone))
  def oldGeneric() = Option(oldArrayOps(largeIntArray.clone))
}

/** Rescales a copy of `largeLongArray` directly and through both generic paths. */
final class LongArrayRescale extends ArrayRescale {
  def name = "array-rescale-long"
  def direct() = Option(directLongArrayOps(largeLongArray.clone))
  def newGeneric() = Option(newArrayOps(largeLongArray.clone))
  def oldGeneric() = Option(oldArrayOps(largeLongArray.clone))
}

/** Rescales a copy of `largeFloatArray` directly and through both generic paths. */
final class FloatArrayRescale extends ArrayRescale {
  def name = "array-rescale-float"
  def direct() = Option(directFloatArrayOps(largeFloatArray.clone))
  def newGeneric() = Option(newArrayOps(largeFloatArray.clone))
  def oldGeneric() = Option(oldArrayOps(largeFloatArray.clone))
}

/** Rescales a copy of `largeDoubleArray` directly and through both generic paths. */
final class DoubleArrayRescale extends ArrayRescale {
  def name = "array-rescale-double"
  def direct() = Option(directDoubleArrayOps(largeDoubleArray.clone))
  def newGeneric() = Option(newArrayOps(largeDoubleArray.clone))
  def oldGeneric() = Option(oldArrayOps(largeDoubleArray.clone))
}
548 |
549 |
550 |
551 | // ==========================================================================
/**
 * Adder benchmark that uses infix `+` syntax instead of calling `plus`
 * explicitly. The exact syntax below is what this benchmark measures,
 * so it is intentionally left as written.
 */
trait InfixAdder extends BaseAdder {
  // `a + b` on a generic A compiles only through an implicit infix-ops
  // conversion for the specialized Numeric; presumably this comes from the
  // project's ops imports at the top of the file (not visible here).
  def newAdder[@specialized A:Numeric](a:A, b:A): A = a + b
  def oldAdder[A](a:A, b:A)(implicit m:OldNumeric[A]): A = {
    // Bring the OldNumeric instance's members into scope; this is what
    // makes `a + b` on a generic A compile here.
    import m._
    a + b
  }
}
559 |
/** Infix-`+` adder benchmark over Int. */
final class InfixAdderInt extends InfixAdder with BaseAdderInt {
  def name = "infix-adder-int"
}

/** Infix-`+` adder benchmark over Long. */
final class InfixAdderLong extends InfixAdder with BaseAdderLong {
  def name = "infix-adder-long"
}

/** Infix-`+` adder benchmark over Float. */
final class InfixAdderFloat extends InfixAdder with BaseAdderFloat {
  def name = "infix-adder-float"
}

/** Infix-`+` adder benchmark over Double. */
final class InfixAdderDouble extends InfixAdder with BaseAdderDouble {
  def name = "infix-adder-double"
}
564 |
565 | // ==========================================================
/**
 * Benchmark that scans an array for its largest element.
 *
 * Every variant seeds the running maximum with element 0, so an empty array
 * throws ArrayIndexOutOfBoundsException. The direct variants use
 * `scala.math.max`; the generic ones use the type class's `max`.
 */
trait FindMax extends TestCase {
  def directMaxInt(a:Array[Int]) = {
    var best = a(0)
    var idx = 1
    while (idx < a.length) {
      best = scala.math.max(best, a(idx))
      idx += 1
    }
    best
  }

  def directMaxLong(a:Array[Long]) = {
    var best = a(0)
    var idx = 1
    while (idx < a.length) {
      best = scala.math.max(best, a(idx))
      idx += 1
    }
    best
  }

  def directMaxFloat(a:Array[Float]) = {
    var best = a(0)
    var idx = 1
    while (idx < a.length) {
      best = scala.math.max(best, a(idx))
      idx += 1
    }
    best
  }

  def directMaxDouble(a:Array[Double]) = {
    var best = a(0)
    var idx = 1
    while (idx < a.length) {
      best = scala.math.max(best, a(idx))
      idx += 1
    }
    best
  }

  /** Generic maximum through the specialized Numeric's `max`. */
  def newGenericMax[@specialized A:Numeric](a:Array[A]) = {
    var best = a(0)
    var idx = 1
    while (idx < a.length) {
      best = numeric.max(best, a(idx))
      idx += 1
    }
    best
  }

  /** Generic maximum through the OldNumeric's `max`. */
  def oldGenericMax[A:OldNumeric](a:Array[A]) = {
    val num = implicitly[OldNumeric[A]]
    var best = a(0)
    var idx = 1
    while (idx < a.length) {
      best = num.max(best, a(idx))
      idx += 1
    }
    best
  }
}
610 |
/** Finds the maximum of `largeIntArray`. */
final class FindMaxInt extends FindMax {
  def name = "find-max-int"
  def direct() = {
    val data = largeIntArray
    Some(directMaxInt(data))
  }
  def newGeneric() = {
    val data = largeIntArray
    Some(newGenericMax(data))
  }
  def oldGeneric() = {
    val data = largeIntArray
    Some(oldGenericMax(data))
  }
}

/** Finds the maximum of `largeLongArray`. */
final class FindMaxLong extends FindMax {
  def name = "find-max-long"
  def direct() = {
    val data = largeLongArray
    Some(directMaxLong(data))
  }
  def newGeneric() = {
    val data = largeLongArray
    Some(newGenericMax(data))
  }
  def oldGeneric() = {
    val data = largeLongArray
    Some(oldGenericMax(data))
  }
}

/** Finds the maximum of `largeFloatArray`. */
final class FindMaxFloat extends FindMax {
  def name = "find-max-float"
  def direct() = {
    val data = largeFloatArray
    Some(directMaxFloat(data))
  }
  def newGeneric() = {
    val data = largeFloatArray
    Some(newGenericMax(data))
  }
  def oldGeneric() = {
    val data = largeFloatArray
    Some(oldGenericMax(data))
  }
}

/** Finds the maximum of `largeDoubleArray`. */
final class FindMaxDouble extends FindMax {
  def name = "find-max-double"
  def direct() = {
    val data = largeDoubleArray
    Some(directMaxDouble(data))
  }
  def newGeneric() = {
    val data = largeDoubleArray
    Some(newGenericMax(data))
  }
  def oldGeneric() = {
    val data = largeDoubleArray
    Some(oldGenericMax(data))
  }
}
638 |
639 | // ================================================================
/**
 * Common shape of the sorting benchmarks.
 *
 * Concrete algorithms implement the `*Sorter` methods. The `direct*`,
 * `newGenericSort` and `oldGenericSort` wrappers wrap each sorter's result in
 * an Option so the test classes can return it from a run.
 */
trait BaseSort extends TestCase {
  // sorters supplied by the concrete algorithm
  def directIntSorter(a:Array[Int]): Array[Int]
  def directLongSorter(a:Array[Long]): Array[Long]
  def directFloatSorter(a:Array[Float]): Array[Float]
  def directDoubleSorter(a:Array[Double]): Array[Double]

  def newGenericSorter[@specialized A:Numeric:Manifest](a:Array[A]): Array[A]
  def oldGenericSorter[A:OldNumeric:Manifest](a:Array[A]): Array[A]

  // Option-wrapping entry points used by the concrete test classes
  def directInt(a:Array[Int]): Option[Array[Int]] = Option(directIntSorter(a))
  def directLong(a:Array[Long]): Option[Array[Long]] = Option(directLongSorter(a))
  def directFloat(a:Array[Float]): Option[Array[Float]] = Option(directFloatSorter(a))
  def directDouble(a:Array[Double]): Option[Array[Double]] = Option(directDoubleSorter(a))

  def newGenericSort[@specialized A:Numeric:Manifest](a:Array[A]): Option[Array[A]] =
    Option(newGenericSorter(a))
  def oldGenericSort[A:OldNumeric:Manifest](a:Array[A]): Option[Array[A]] =
    Option(oldGenericSorter(a))
}
657 |
658 |
659 | // =======================================================================
/**
 * Quicksort benchmark built on `scala.util.Sorting.quickSort`.
 * Every sorter sorts and returns a clone, leaving its argument untouched.
 */
trait Quicksort extends BaseSort {
  def directIntSorter(a:Array[Int]) = {
    val sorted = a.clone
    scala.util.Sorting.quickSort(sorted)
    sorted
  }

  def directLongSorter(a:Array[Long]) = {
    val sorted = a.clone
    scala.util.Sorting.quickSort(sorted)
    sorted
  }

  def directFloatSorter(a:Array[Float]) = {
    val sorted = a.clone
    scala.util.Sorting.quickSort(sorted)
    sorted
  }

  def directDoubleSorter(a:Array[Double]) = {
    val sorted = a.clone
    scala.util.Sorting.quickSort(sorted)
    sorted
  }

  // NOTE: this will perform slowly just because Ordering is not specialized!
  def newGenericSorter[@specialized A:Numeric:Manifest](a:Array[A]) = {
    val sorted = a.clone
    // expose the Numeric's Ordering as an implicit so quickSort can find it
    implicit val ord = implicitly[Numeric[A]].getOrdering()
    scala.util.Sorting.quickSort(sorted)
    sorted
  }

  /** quickSort picks up its Ordering from the OldNumeric context bound. */
  def oldGenericSorter[A:OldNumeric:Manifest](a:Array[A]) = {
    val sorted = a.clone
    scala.util.Sorting.quickSort(sorted)
    sorted
  }
}
675 |
/** Quicksort benchmark over `mediumIntArray`. */
final class QuicksortInt extends Quicksort {
  def name = "quicksort-int"
  def direct() = {
    val data = mediumIntArray
    directInt(data)
  }
  def newGeneric() = {
    val data = mediumIntArray
    newGenericSort(data)
  }
  def oldGeneric() = {
    val data = mediumIntArray
    oldGenericSort(data)
  }
}

/** Quicksort benchmark over `mediumLongArray`. */
final class QuicksortLong extends Quicksort {
  def name = "quicksort-long"
  def direct() = {
    val data = mediumLongArray
    directLong(data)
  }
  def newGeneric() = {
    val data = mediumLongArray
    newGenericSort(data)
  }
  def oldGeneric() = {
    val data = mediumLongArray
    oldGenericSort(data)
  }
}

/** Quicksort benchmark over `mediumFloatArray`. */
final class QuicksortFloat extends Quicksort {
  def name = "quicksort-float"
  def direct() = {
    val data = mediumFloatArray
    directFloat(data)
  }
  def newGeneric() = {
    val data = mediumFloatArray
    newGenericSort(data)
  }
  def oldGeneric() = {
    val data = mediumFloatArray
    oldGenericSort(data)
  }
}

/** Quicksort benchmark over `mediumDoubleArray`. */
final class QuicksortDouble extends Quicksort {
  def name = "quicksort-double"
  def direct() = {
    val data = mediumDoubleArray
    directDouble(data)
  }
  def newGeneric() = {
    val data = mediumDoubleArray
    newGenericSort(data)
  }
  def oldGeneric() = {
    val data = mediumDoubleArray
    oldGenericSort(data)
  }
}
703 |
704 | // ==========================================
/**
 * Quadratic sorting benchmark (run on the small arrays).
 *
 * Despite its name the algorithm is a selection sort. For each position i it
 * finds the index k of the smallest element in a(i..) and swaps it into
 * position i. Every sorter works on a clone and leaves its argument untouched.
 *
 * The inner scan must compare against the running minimum a(k). Comparing
 * against a(i) makes k the LAST element smaller than a(i) instead of the
 * minimum, which does not sort: [3, 1, 2] would come out as [2, 1, 3].
 */
trait InsertionSort extends BaseSort {
  def directIntSorter(b:Array[Int]): Array[Int] = {
    val a = b.clone
    var i = 0
    while (i < a.length - 1) {
      var k = i // index of the smallest element seen so far in a(i..)
      var j = i + 1
      while (j < a.length) {
        if (a(j) < a(k)) k = j
        j += 1
      }
      val temp = a(i)
      a(i) = a(k)
      a(k) = temp
      i += 1
    }
    a
  }

  def directLongSorter(b:Array[Long]): Array[Long] = {
    val a = b.clone
    var i = 0
    while (i < a.length - 1) {
      var k = i // index of the smallest element seen so far in a(i..)
      var j = i + 1
      while (j < a.length) {
        if (a(j) < a(k)) k = j
        j += 1
      }
      val temp = a(i)
      a(i) = a(k)
      a(k) = temp
      i += 1
    }
    a
  }

  def directFloatSorter(b:Array[Float]): Array[Float] = {
    val a = b.clone
    var i = 0
    while (i < a.length - 1) {
      var k = i // index of the smallest element seen so far in a(i..)
      var j = i + 1
      while (j < a.length) {
        if (a(j) < a(k)) k = j
        j += 1
      }
      val temp = a(i)
      a(i) = a(k)
      a(k) = temp
      i += 1
    }
    a
  }

  def directDoubleSorter(b:Array[Double]): Array[Double] = {
    val a = b.clone
    var i = 0
    while (i < a.length - 1) {
      var k = i // index of the smallest element seen so far in a(i..)
      var j = i + 1
      while (j < a.length) {
        if (a(j) < a(k)) k = j
        j += 1
      }
      val temp = a(i)
      a(i) = a(k)
      a(k) = temp
      i += 1
    }
    a
  }

  /** Same algorithm, comparing through the specialized Numeric's `lt`. */
  def newGenericSorter[@specialized A:Numeric:Manifest](b:Array[A]): Array[A] = {
    val a = b.clone
    var i = 0
    while (i < a.length - 1) {
      var k = i // index of the smallest element seen so far in a(i..)
      var j = i + 1
      while (j < a.length) {
        if (numeric.lt(a(j), a(k))) k = j
        j += 1
      }
      val temp = a(i)
      a(i) = a(k)
      a(k) = temp
      i += 1
    }
    a
  }

  /** Same algorithm, comparing through the OldNumeric's `lt`. */
  def oldGenericSorter[A:OldNumeric:Manifest](b:Array[A]): Array[A] = {
    val n = implicitly[OldNumeric[A]]
    val a = b.clone
    var i = 0
    while (i < a.length - 1) {
      var k = i // index of the smallest element seen so far in a(i..)
      var j = i + 1
      while (j < a.length) {
        if (n.lt(a(j), a(k))) k = j
        j += 1
      }
      val temp = a(i)
      a(i) = a(k)
      a(k) = temp
      i += 1
    }
    a
  }
}
812 |
/** Quadratic-sort benchmark over `smallIntArray`. */
final class InsertionSortInt extends InsertionSort {
  def name = "insertion-sort-int"
  def direct() = {
    val data = smallIntArray
    directInt(data)
  }
  def newGeneric() = {
    val data = smallIntArray
    newGenericSort(data)
  }
  def oldGeneric() = {
    val data = smallIntArray
    oldGenericSort(data)
  }
}

/** Quadratic-sort benchmark over `smallLongArray`. */
final class InsertionSortLong extends InsertionSort {
  def name = "insertion-sort-long"
  def direct() = {
    val data = smallLongArray
    directLong(data)
  }
  def newGeneric() = {
    val data = smallLongArray
    newGenericSort(data)
  }
  def oldGeneric() = {
    val data = smallLongArray
    oldGenericSort(data)
  }
}

/** Quadratic-sort benchmark over `smallFloatArray`. */
final class InsertionSortFloat extends InsertionSort {
  def name = "insertion-sort-float"
  def direct() = {
    val data = smallFloatArray
    directFloat(data)
  }
  def newGeneric() = {
    val data = smallFloatArray
    newGenericSort(data)
  }
  def oldGeneric() = {
    val data = smallFloatArray
    oldGenericSort(data)
  }
}

/** Quadratic-sort benchmark over `smallDoubleArray`. */
final class InsertionSortDouble extends InsertionSort {
  def name = "insertion-sort-double"
  def direct() = {
    val data = smallDoubleArray
    directDouble(data)
  }
  def newGeneric() = {
    val data = smallDoubleArray
    newGenericSort(data)
  }
  def oldGeneric() = {
    val data = smallDoubleArray
    oldGenericSort(data)
  }
}
840 |
841 |
842 | // ========================================================
/**
 * Allocation benchmark: builds `num` rows, each a fresh array of length
 * `dim` filled with `const`, and returns the array of rows.
 */
trait ArrayAllocator extends TestCase {
  def directIntAllocator(num:Int, dim:Int, const:Int) = {
    val rows = Array.ofDim[Array[Int]](num)
    var r = 0
    while (r < num) {
      rows(r) = Array.fill(dim)(const)
      r += 1
    }
    rows
  }

  def directLongAllocator(num:Int, dim:Int, const:Long) = {
    val rows = Array.ofDim[Array[Long]](num)
    var r = 0
    while (r < num) {
      rows(r) = Array.fill(dim)(const)
      r += 1
    }
    rows
  }

  def directFloatAllocator(num:Int, dim:Int, const:Float) = {
    val rows = Array.ofDim[Array[Float]](num)
    var r = 0
    while (r < num) {
      rows(r) = Array.fill(dim)(const)
      r += 1
    }
    rows
  }

  def directDoubleAllocator(num:Int, dim:Int, const:Double) = {
    val rows = Array.ofDim[Array[Double]](num)
    var r = 0
    while (r < num) {
      rows(r) = Array.fill(dim)(const)
      r += 1
    }
    rows
  }

  /** Generic allocation; the Manifest lets it create primitive arrays of A. */
  def newAllocator[@specialized A:Numeric:Manifest](num:Int, dim:Int, const:A) = {
    val rows = Array.ofDim[Array[A]](num)
    var r = 0
    while (r < num) {
      rows(r) = Array.fill(dim)(const)
      r += 1
    }
    rows
  }

  /** Unspecialized generic allocation, for comparison. */
  def oldAllocator[A:OldNumeric:Manifest](num:Int, dim:Int, const:A) = {
    val rows = Array.ofDim[Array[A]](num)
    var r = 0
    while (r < num) {
      rows(r) = Array.fill(dim)(const)
      r += 1
    }
    rows
  }
}
886 |
/*
 * Allocator benchmarks: mediumSize rows of 5 elements, each filled with 13.
 * `direct()`, `newGeneric()` and `oldGeneric()` are declared with empty
 * parentheses, like every other TestCase subclass, since running them
 * performs work (allocation) rather than reading a property.
 */

final class ArrayAllocatorInt extends ArrayAllocator {
  def name = "array-allocator-int"
  def direct() = Option(directIntAllocator(mediumSize, 5, 13))
  def newGeneric() = Option(newAllocator(mediumSize, 5, 13))
  def oldGeneric() = Option(oldAllocator(mediumSize, 5, 13))
}

final class ArrayAllocatorLong extends ArrayAllocator {
  def name = "array-allocator-long"
  def direct() = Option(directLongAllocator(mediumSize, 5, 13L))
  def newGeneric() = Option(newAllocator(mediumSize, 5, 13L))
  def oldGeneric() = Option(oldAllocator(mediumSize, 5, 13L))
}

final class ArrayAllocatorFloat extends ArrayAllocator {
  def name = "array-allocator-float"
  def direct() = Option(directFloatAllocator(mediumSize, 5, 13.0F))
  def newGeneric() = Option(newAllocator(mediumSize, 5, 13.0F))
  def oldGeneric() = Option(oldAllocator(mediumSize, 5, 13.0F))
}

final class ArrayAllocatorDouble extends ArrayAllocator {
  def name = "array-allocator-double"
  def direct() = Option(directDoubleAllocator(mediumSize, 5, 13.0))
  def newGeneric() = Option(newAllocator(mediumSize, 5, 13.0))
  def oldGeneric() = Option(oldAllocator(mediumSize, 5, 13.0))
}
914 |
915 | // =================================================================
/**
 * Top-down merge sort benchmark.
 *
 * Every sorter sorts its argument IN PLACE and returns that same array. It
 * copies the two halves into temporary arrays, sorts them recursively, then
 * merges them back into `a`.
 *
 * The merge cursors are three plain `var`s rather than
 * `var (i, j, k) = (0, 0, 0)`. The tuple pattern allocates a Tuple3 and boxes
 * three Ints on every recursive call, which is allocation noise unrelated to
 * what the benchmark measures. The recursive sorters also declare their
 * result types explicitly.
 */
trait MergeSort extends BaseSort {
  def directIntSorter(a:Array[Int]): Array[Int] = {
    if (a.length > 1) {
      val llen = a.length / 2
      val rlen = a.length - llen

      val left = Array.ofDim[Int](llen)
      Array.copy(a, 0, left, 0, llen)
      directIntSorter(left)

      val right = Array.ofDim[Int](rlen)
      Array.copy(a, llen, right, 0, rlen)
      directIntSorter(right)

      var i = 0 // next unread index in left
      var j = 0 // next unread index in right
      var k = 0 // next write index in a
      while (i < llen || j < rlen) {
        if (j == rlen) {
          a(k) = left(i); i += 1
        } else if (i == llen) {
          a(k) = right(j); j += 1
        } else if (left(i) < right(j)) {
          a(k) = left(i); i += 1
        } else {
          a(k) = right(j); j += 1
        }
        k += 1
      }
    }
    a
  }

  def directLongSorter(a:Array[Long]): Array[Long] = {
    if (a.length > 1) {
      val llen = a.length / 2
      val rlen = a.length - llen

      val left = Array.ofDim[Long](llen)
      Array.copy(a, 0, left, 0, llen)
      directLongSorter(left)

      val right = Array.ofDim[Long](rlen)
      Array.copy(a, llen, right, 0, rlen)
      directLongSorter(right)

      var i = 0 // next unread index in left
      var j = 0 // next unread index in right
      var k = 0 // next write index in a
      while (i < llen || j < rlen) {
        if (j == rlen) {
          a(k) = left(i); i += 1
        } else if (i == llen) {
          a(k) = right(j); j += 1
        } else if (left(i) < right(j)) {
          a(k) = left(i); i += 1
        } else {
          a(k) = right(j); j += 1
        }
        k += 1
      }
    }
    a
  }

  def directFloatSorter(a:Array[Float]): Array[Float] = {
    if (a.length > 1) {
      val llen = a.length / 2
      val rlen = a.length - llen

      val left = Array.ofDim[Float](llen)
      Array.copy(a, 0, left, 0, llen)
      directFloatSorter(left)

      val right = Array.ofDim[Float](rlen)
      Array.copy(a, llen, right, 0, rlen)
      directFloatSorter(right)

      var i = 0 // next unread index in left
      var j = 0 // next unread index in right
      var k = 0 // next write index in a
      while (i < llen || j < rlen) {
        if (j == rlen) {
          a(k) = left(i); i += 1
        } else if (i == llen) {
          a(k) = right(j); j += 1
        } else if (left(i) < right(j)) {
          a(k) = left(i); i += 1
        } else {
          a(k) = right(j); j += 1
        }
        k += 1
      }
    }
    a
  }

  def directDoubleSorter(a:Array[Double]): Array[Double] = {
    if (a.length > 1) {
      val llen = a.length / 2
      val rlen = a.length - llen

      val left = Array.ofDim[Double](llen)
      Array.copy(a, 0, left, 0, llen)
      directDoubleSorter(left)

      val right = Array.ofDim[Double](rlen)
      Array.copy(a, llen, right, 0, rlen)
      directDoubleSorter(right)

      var i = 0 // next unread index in left
      var j = 0 // next unread index in right
      var k = 0 // next write index in a
      while (i < llen || j < rlen) {
        if (j == rlen) {
          a(k) = left(i); i += 1
        } else if (i == llen) {
          a(k) = right(j); j += 1
        } else if (left(i) < right(j)) {
          a(k) = left(i); i += 1
        } else {
          a(k) = right(j); j += 1
        }
        k += 1
      }
    }
    a
  }

  /** Same algorithm, comparing through the specialized Numeric's `lt`. */
  def newGenericSorter[@specialized A:Numeric:Manifest](a:Array[A]):Array[A] = {
    if (a.length > 1) {
      val llen = a.length / 2
      val rlen = a.length - llen

      val left = Array.ofDim[A](llen)
      Array.copy(a, 0, left, 0, llen)
      newGenericSorter(left)

      val right = Array.ofDim[A](rlen)
      Array.copy(a, llen, right, 0, rlen)
      newGenericSorter(right)

      var i = 0 // next unread index in left
      var j = 0 // next unread index in right
      var k = 0 // next write index in a
      while (i < llen || j < rlen) {
        if (j == rlen) {
          a(k) = left(i); i += 1
        } else if (i == llen) {
          a(k) = right(j); j += 1
        } else if (numeric.lt(left(i), right(j))) {
          a(k) = left(i); i += 1
        } else {
          a(k) = right(j); j += 1
        }
        k += 1
      }
    }
    a
  }

  /** Same algorithm, comparing through the OldNumeric's `lt`. */
  def oldGenericSorter[A:OldNumeric:Manifest](a:Array[A]):Array[A] = {
    if (a.length > 1) {
      val n = implicitly[OldNumeric[A]]

      val llen = a.length / 2
      val rlen = a.length - llen

      val left = Array.ofDim[A](llen)
      Array.copy(a, 0, left, 0, llen)
      oldGenericSorter(left)

      val right = Array.ofDim[A](rlen)
      Array.copy(a, llen, right, 0, rlen)
      oldGenericSorter(right)

      var i = 0 // next unread index in left
      var j = 0 // next unread index in right
      var k = 0 // next write index in a
      while (i < llen || j < rlen) {
        if (j == rlen) {
          a(k) = left(i); i += 1
        } else if (i == llen) {
          a(k) = right(j); j += 1
        } else if (n.lt(left(i), right(j))) {
          a(k) = left(i); i += 1
        } else {
          a(k) = right(j); j += 1
        }
        k += 1
      }
    }
    a
  }
}
1099 |
/*
 * The MergeSort sorters work in place, so every variant sorts a fresh clone of
 * the shared medium array. Passing the shared array itself would sort it on
 * the first run. Every later run would then time the already-sorted case, and
 * any other user of that array would see it mutated.
 */

/** Merge-sort benchmark over a copy of `mediumIntArray`. */
final class MergeSortInt extends MergeSort {
  def name = "merge-sort-int"
  def direct() = { val a = mediumIntArray.clone; directInt(a); Some(a) }
  def newGeneric() = { val a = mediumIntArray.clone; newGenericSort(a); Some(a) }
  def oldGeneric() = { val a = mediumIntArray.clone; oldGenericSort(a); Some(a) }
}

/** Merge-sort benchmark over a copy of `mediumLongArray`. */
final class MergeSortLong extends MergeSort {
  def name = "merge-sort-long"
  def direct() = { val a = mediumLongArray.clone; directLong(a); Some(a) }
  def newGeneric() = { val a = mediumLongArray.clone; newGenericSort(a); Some(a) }
  def oldGeneric() = { val a = mediumLongArray.clone; oldGenericSort(a); Some(a) }
}

/** Merge-sort benchmark over a copy of `mediumFloatArray`. */
final class MergeSortFloat extends MergeSort {
  def name = "merge-sort-float"
  def direct() = { val a = mediumFloatArray.clone; directFloat(a); Some(a) }
  def newGeneric() = { val a = mediumFloatArray.clone; newGenericSort(a); Some(a) }
  def oldGeneric() = { val a = mediumFloatArray.clone; oldGenericSort(a); Some(a) }
}

/** Merge-sort benchmark over a copy of `mediumDoubleArray`. */
final class MergeSortDouble extends MergeSort {
  def name = "merge-sort-double"
  def direct() = { val a = mediumDoubleArray.clone; directDouble(a); Some(a) }
  def newGeneric() = { val a = mediumDoubleArray.clone; newGenericSort(a); Some(a) }
  def oldGeneric() = { val a = mediumDoubleArray.clone; oldGenericSort(a); Some(a) }
}
1127 |
1128 |
1129 | // =================================================================
/**
 * Increment benchmark, variant 1. Each run adds 100 to an Int accumulator
 * `largeSize` times. The new generic version spells the increment as explicit
 * `plus` / `fromInt` calls on the type class.
 */
final class IncrementInt1 extends TestCase {
  def name = "increment-int1"

  def directIncrement(x:Int) = x + 100

  def direct() = {
    var acc = 0
    var remaining = largeSize
    while (remaining > 0) {
      acc = directIncrement(acc)
      remaining -= 1
    }
    Some(acc)
  }

  // the expression below is the variant being measured; keep it verbatim
  def newIncrement[@specialized A:Numeric](a:A):A = numeric.plus(a, numeric.fromInt(100))

  def newGeneric() = {
    var acc = 0
    var remaining = largeSize
    while (remaining > 0) {
      acc = newIncrement(acc)
      remaining -= 1
    }
    Some(acc)
  }

  def oldIncrement[A](a:A)(implicit n:OldNumeric[A]) = n.plus(a, n.fromInt(100))

  def oldGeneric() = {
    var acc = 0
    var remaining = largeSize
    while (remaining > 0) {
      acc = oldIncrement(acc)
      remaining -= 1
    }
    Some(acc)
  }
}
1166 |
/**
 * Increment benchmark, variant 2. Each run adds 100 to an Int accumulator
 * `largeSize` times. The new generic version uses infix `+` with an explicit
 * `fromInt(100)` right operand.
 */
final class IncrementInt2 extends TestCase {
  def name = "increment-int2"

  def directIncrement(x:Int) = x + 100

  def direct() = {
    var acc = 0
    var remaining = largeSize
    while (remaining > 0) {
      acc = directIncrement(acc)
      remaining -= 1
    }
    Some(acc)
  }

  // the expression below is the variant being measured; keep it verbatim
  def newIncrement[@specialized A:Numeric](a:A):A = a + numeric.fromInt(100)

  def newGeneric() = {
    var acc = 0
    var remaining = largeSize
    while (remaining > 0) {
      acc = newIncrement(acc)
      remaining -= 1
    }
    Some(acc)
  }

  def oldIncrement[A](a:A)(implicit n:OldNumeric[A]) = n.plus(a, n.fromInt(100))

  def oldGeneric() = {
    var acc = 0
    var remaining = largeSize
    while (remaining > 0) {
      acc = oldIncrement(acc)
      remaining -= 1
    }
    Some(acc)
  }
}
1203 |
/**
 * Increment benchmark, variant 3. Each run adds 100 to an Int accumulator
 * `largeSize` times. The new generic version adds a bare Int literal on the
 * right (`a + 100`), relying on the project's literal-conversion implicits.
 */
final class IncrementInt3 extends TestCase {
  def name = "increment-int3"

  def directIncrement(x:Int) = x + 100

  def direct() = {
    var acc = 0
    var remaining = largeSize
    while (remaining > 0) {
      acc = directIncrement(acc)
      remaining -= 1
    }
    Some(acc)
  }

  // the expression below is the variant being measured; keep it verbatim
  def newIncrement[@specialized A:Numeric](a:A):A = a + 100

  def newGeneric() = {
    var acc = 0
    var remaining = largeSize
    while (remaining > 0) {
      acc = newIncrement(acc)
      remaining -= 1
    }
    Some(acc)
  }

  def oldIncrement[A](a:A)(implicit n:OldNumeric[A]) = n.plus(a, n.fromInt(100))

  def oldGeneric() = {
    var acc = 0
    var remaining = largeSize
    while (remaining > 0) {
      acc = oldIncrement(acc)
      remaining -= 1
    }
    Some(acc)
  }
}
1240 |
/**
 * Increment benchmark, variant 4. Each run adds 100 to an Int accumulator
 * `largeSize` times. The new generic version puts the Int literal on the left
 * (`100 + a`), relying on the project's literal-conversion implicits.
 */
final class IncrementInt4 extends TestCase {
  def name = "increment-int4"

  def directIncrement(x:Int) = x + 100

  def direct() = {
    var acc = 0
    var remaining = largeSize
    while (remaining > 0) {
      acc = directIncrement(acc)
      remaining -= 1
    }
    Some(acc)
  }

  // the expression below is the variant being measured; keep it verbatim
  def newIncrement[@specialized A:Numeric](a:A):A = 100 + a

  def newGeneric() = {
    var acc = 0
    var remaining = largeSize
    while (remaining > 0) {
      acc = newIncrement(acc)
      remaining -= 1
    }
    Some(acc)
  }

  def oldIncrement[A](a:A)(implicit n:OldNumeric[A]) = n.plus(a, n.fromInt(100))

  def oldGeneric() = {
    var acc = 0
    var remaining = largeSize
    while (remaining > 0) {
      acc = oldIncrement(acc)
      remaining -= 1
    }
    Some(acc)
  }
}
1277 |
1278 |
/**
 * Benchmark driver. Runs every test group and prints a timing table to the
 * console. When `Constant.createHTML` is set, it also writes each result row
 * to benchmark.html.
 */
object Main {
  /** Benchmark groups; each inner list holds the Int/Long/Float/Double variants of one test. */
  val tests = List(List(new FromIntToInt,
                        new FromIntToLong,
                        new FromIntToFloat,
                        new FromIntToDouble),

                   List(new InfixAdderInt,
                        new InfixAdderLong,
                        new InfixAdderFloat,
                        new InfixAdderDouble),

                   List(new AdderInt,
                        new AdderLong,
                        new AdderFloat,
                        new AdderDouble),

                   List(new IntArrayAdder,
                        new LongArrayAdder,
                        new FloatArrayAdder,
                        new DoubleArrayAdder),

                   List(new IntArrayRescale,
                        new LongArrayRescale,
                        new FloatArrayRescale,
                        new DoubleArrayRescale),

                   List(new ArrayAllocatorInt,
                        new ArrayAllocatorLong,
                        new ArrayAllocatorFloat,
                        new ArrayAllocatorDouble),

                   List(new FindMaxInt,
                        new FindMaxLong,
                        new FindMaxFloat,
                        new FindMaxDouble),

                   List(new QuicksortInt,
                        new QuicksortLong,
                        new QuicksortFloat,
                        new QuicksortDouble),

                   List(new InsertionSortInt,
                        new InsertionSortLong,
                        new InsertionSortFloat,
                        new InsertionSortDouble),

                   List(new MergeSortInt,
                        new MergeSortLong,
                        new MergeSortFloat,
                        new MergeSortDouble),

                   List(new IncrementInt1,
                        new IncrementInt2,
                        new IncrementInt3,
                        new IncrementInt4))

  // NOTE(review): the HTML markup in this literal, in getHTMLFooter and in
  // the group separator written by main appears stripped/mangled in this copy
  // of the file. Source lines 1340-1353 are missing, and two ordinary string
  // literals span a line break. The literals are kept verbatim here; restore
  // the original markup from version control.
  def getHTMLHeader() = """








test | direct (ms) | new (ms) | old (ms) |

"""

  def getHTMLFooter() = "
\n \n\n"

  /** Prints the column header of the console timing table. */
  private def printConsoleHeader(): Unit =
    printf("%-24s %8s %8s %8s / %6s %6s %6s\n", "test", "direct", "new", "old", "n:d", "o:d", "o:n")

  def main(args:Array[String]): Unit = {
    if (Constant.createHTML) {
      println("creating benchmark.html...")
      printConsoleHeader()

      val p = new PrintWriter(new FileWriter("benchmark.html"))
      // Close (and so flush) the report even if a benchmark throws; otherwise
      // the file handle leaks and buffered output is lost.
      try {
        p.println(getHTMLHeader())
        for (group <- tests) {
          p.println(" |
\n")
          group.foreach(_.test(Some(p)))
        }
        p.println(getHTMLFooter())
      } finally {
        p.close()
      }
    } else {
      printConsoleHeader()
      for (group <- tests) group.foreach(_.test(None))
    }
  }
}
1390 |
--------------------------------------------------------------------------------
/plugin/Makefile:
--------------------------------------------------------------------------------
.PHONY: default clean plugin post tree browse postplug base test run runbase runtest all

# The plugin's source file and JAR to create
# PLUGIN=OptimizedNumeric.scala
PLUGIN=src/main/scala/com/azavea/math/plugin/OptimizedNumeric.scala
PLUGJAR=optimized-numeric.jar
PLUGXML=src/main/resources/scalac-plugin.xml

# The phase to run before inspecting the AST
PHASE=typer
# The plugin's own phase, inspected by the postplug target
PHASE2=optimized-numeric

# The test's source file, and class name
# TEST=Test.scala
# TESTBASE=Test
TEST=src/test/scala/Example.scala
TESTBASE=Example

# The classpath to use when building/running the test
## CP=lib/numeric_2.9.1-0.1.jar
CP=../target/scala-2.9.1.final/classes

# Some help output
default:
	@echo "targets: clean | plugin | post | tree | browse | postplug | base | test | run | runbase | runtest | all"

# Remove classfiles, the plugin's jar file, and the POST/TREE dumps written
# by the AST-inspection targets
clean:
	rm -rf classes
	find . -name '*.class' -exec rm -f {} \;
	rm -f $(PLUGJAR) POST TREE

# Build the plugin
plugin:
	mkdir -p classes
	scalac -cp $(CP) -d classes $(PLUGIN)
	cp $(PLUGXML) classes
	jar -C classes -c -f $(PLUGJAR) .

# Various targets for inspecting the AST:
# post shows the AST as source code
# tree prints the actual AST
# browse runs a graphical AST browser
# postplug shows the AST as source code after the plugin's phase
post:
	scalac -cp $(CP) -Xprint:$(PHASE) -Ystop-after:$(PHASE) $(TEST) | tee POST

tree:
	scalac -cp $(CP) -Xprint:$(PHASE) -Ystop-after:$(PHASE) -Yshow-trees $(TEST) | tee TREE

browse:
	scalac -cp $(CP) -Ybrowse:$(PHASE) -Ystop-after:$(PHASE) -Yshow-trees $(TEST)

postplug:
	scalac -cp $(CP) -Xplugin:$(PLUGJAR) -Xprint:$(PHASE2) -Ystop-after:$(PHASE2) $(TEST) | tee POST

# Compile the test without the plugin (the base case)
base:
	scalac -cp $(CP) $(TEST)

# Compile the test with the plugin (the test case)
test:
	scalac -cp $(CP) -Xplugin:$(PLUGJAR) $(TEST)

# Run the test. base/test run scalac without -d, so the test's classes are
# written to the current directory; "." must therefore be on the classpath.
run:
	scala -cp $(CP):. $(TESTBASE)

# Compile and run the base case
runbase: base run

# Compile and run the test case
runtest: test run

# Do a clean build of the plugin and test case, then run the test
all: clean plugin test run
76 |
--------------------------------------------------------------------------------
/plugin/build.sbt:
--------------------------------------------------------------------------------
1 | // project name
2 | name := "Optimized Numeric Plugin"
3 |
4 | //sbtPlugin := true
5 |
6 | libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-compiler" % _)
7 |
8 | resolvers += "Scala-Tools Maven2 Snapshots Repository" at "http://scala-tools.org/repo-snapshots"
9 |
10 | // shrug?
11 | version := "0.1"
12 |
13 | // hide backup files
14 | defaultExcludes ~= (filter => filter || "*~")
15 |
16 | scalacOptions += "-optimise"
17 |
18 | // any of these work, although 2.9.1 performs the best
19 | //scalaVersion := "2.8.1
20 | //scalaVersion := "2.9.0-1"
21 | scalaVersion := "2.9.1"
22 |
23 | //crossScalaVersions := List("2.8.1", "2.9.0-1", "2.9.1")
24 |
--------------------------------------------------------------------------------
/plugin/src/main/resources/scalac-plugin.xml:
--------------------------------------------------------------------------------
1 |
2 | optimized-numeric
3 | com.azavea.math.plugin.OptimizedNumeric
4 |
5 |
--------------------------------------------------------------------------------
/plugin/src/main/scala/com/azavea/math/plugin/OptimizedNumeric.scala:
--------------------------------------------------------------------------------
1 | package com.azavea.math.plugin
2 |
3 | import scala.tools.nsc
4 | import nsc.Global
5 | import nsc.Phase
6 | import nsc.plugins.Plugin
7 | import nsc.plugins.PluginComponent
8 | import nsc.transform.Transform
9 | import nsc.transform.InfoTransform
10 | import nsc.transform.TypingTransformers
11 | import nsc.symtab.Flags._
12 | import nsc.ast.TreeDSL
13 | import nsc.typechecker
14 |
15 | /**
16 | * Our shiny compiler plugin.
17 | */
18 | class OptimizedNumeric(val global: Global) extends Plugin {
19 | val name = "optimized-numeric"
20 | val description = "Optimizes com.azavea.math.Numeric usage."
21 | val components = List[PluginComponent](new RewriteInfixOps(this, global))
22 | }
23 |
24 | /**
25 | * This component turns things like:
26 | * 1. new FastNumericOps[T](m)(implicit ev).+(n)
27 | * 2. com.azavea.math.FastImplicits.infixOps[T](m)(implicit ev).*(n)
28 | *
29 | * Into:
30 | * 1. ev.plus(m, n)
31 | * 2. ev.times(m, n)
32 | */
33 | class RewriteInfixOps(plugin:Plugin, val global:Global) extends PluginComponent
34 | with Transform with TypingTransformers with TreeDSL {
35 | import global._
36 | import typer.typed
37 |
38 | // set to true to print a warning for each transform
39 | val debugging = false
40 |
41 | // TODO: maybe look up the definition of op and automatically figure mapping
42 | val unops = Map(
43 | newTermName("abs") -> "abs",
44 | newTermName("unary_$minus") -> "negate",
45 | newTermName("signum") -> "signum"
46 | )
47 |
48 | val binops = Map(
49 | newTermName("compare") -> "compare",
50 | newTermName("equiv") -> "equiv",
51 | newTermName("max") -> "max",
52 | newTermName("min") -> "min",
53 |
54 | newTermName("$less$eq$greater") -> "compare",
55 | newTermName("$div") -> "div",
56 | newTermName("$eq$eq$eq") -> "equiv",
57 | newTermName("$bang$eq$eq") -> "nequiv",
58 | newTermName("$greater") -> "gt",
59 | newTermName("$greater$eq") -> "gteq",
60 | newTermName("$less") -> "lt",
61 | newTermName("$less$eq") -> "lteq",
62 | newTermName("$minus") -> "minus",
63 | newTermName("$percent") -> "mod",
64 | newTermName("$plus") -> "plus",
65 | newTermName("$times") -> "times",
66 | newTermName("$times$times") -> "pow"
67 | )
68 |
69 | val runsAfter = List("typer");
70 | val phaseName = "optimized-numeric"
71 | def newTransformer(unit:CompilationUnit) = new MyTransformer(unit)
72 |
73 | // Determine if two type are equivalent
74 | def equivalentTypes(t1:Type, t2:Type) = {
75 | t1.dealias.deconst.widen =:= t2.dealias.deconst.widen
76 | }
77 |
78 | // TODO: figure out better type matching for Numeric, e.g. a.tpe <:< b.tpe
79 | val numericClass = definitions.getClass("com.azavea.math.Numeric")
80 | def isNumeric(t:Type) = t.typeSymbol == numericClass.tpe.typeSymbol
81 |
82 | // For built-in types, figure out whether or not we have a "fast" conversion method
83 | val BigIntClass = definitions.getClass("scala.math.BigInt")
84 | val BigDecimalClass = definitions.getClass("scala.math.BigDecimal")
85 | def getConverter(t:Type) = if (t <:< definitions.ByteClass.tpe) {
86 | Some("fromByte")
87 | } else if (t <:< definitions.ShortClass.tpe) {
88 | Some("fromShort")
89 | } else if (t <:< definitions.IntClass.tpe) {
90 | Some("fromInt")
91 | } else if (t <:< definitions.LongClass.tpe) {
92 | Some("fromLong")
93 | } else if (t <:< definitions.FloatClass.tpe) {
94 | Some("fromFloat")
95 | } else if (t <:< definitions.DoubleClass.tpe) {
96 | Some("fromDouble")
97 | } else if (t <:< BigIntClass.tpe) {
98 | Some("fromBigInt")
99 | } else if (t <:< BigDecimalClass.tpe) {
100 | Some("fromBigDecimal")
101 | } else {
102 | None
103 | }
104 |
105 | // TODO: maybe match further out on the implicit Numeric[T]?
106 | class MyTransformer(unit:CompilationUnit) extends TypingTransformer(unit) {
107 |
108 | override def transform(tree: Tree): Tree = {
109 | //def mylog(s:String) = if (debugging) unit.warning(tree.pos, s)
110 | def mylog(s:String) = Unit
111 |
112 | val tree2 = tree match {
113 |
114 | // match fuzzy binary operators
115 | case Apply(Apply(TypeApply(Select(Apply(Apply(_, List(m)), List(ev)), op), List(tt)), List(n)), List(ev2)) => {
116 | if (!isNumeric(ev.tpe)) {
117 | //mylog("fuzzy alarm #1")
118 | tree
119 |
120 | } else if (binops.contains(op)) {
121 | val op2 = binops(op)
122 | val conv = getConverter(n.tpe)
123 | conv match {
124 | case Some(meth) => {
125 | //mylog("fuzzy transformed %s (with %s)".format(op, meth))
126 | typed { Apply(Select(ev, op2), List(m, Apply(Select(ev, meth), List(n)))) }
127 | }
128 | case None => if (equivalentTypes(m.tpe, n.tpe)) {
129 | //mylog("fuzzy transformed %s (removed conversion)".format(op))
130 | typed { Apply(Select(ev, op2), List(m, n)) }
131 | } else {
132 | //mylog("fuzzy transformed %s".format(op))
133 | typed { Apply(Select(ev, op2), List(m, Apply(TypeApply(Select(ev, "fromType"), List(tt)), List(n)))) }
134 | }
135 | }
136 |
137 | } else {
138 | //mylog("fuzzy alarm #2")
139 | tree
140 | }
141 | }
142 |
143 | // match IntOps (and friends Float, Long, etc.)
144 | case Apply(Apply(TypeApply(Select(Apply(_, List(m)), op), List(tt)), List(n)), List(ev)) => {
145 | if (!isNumeric(ev.tpe)) {
146 | //mylog("literal ops alarm #1")
147 | tree
148 |
149 | } else if (binops.contains(op)) {
150 | val op2 = binops(op)
151 | val conv = getConverter(m.tpe)
152 | conv match {
153 | case Some(meth) => {
154 | //mylog("zzz literal ops transformed %s (with %s)".format(op, meth))
155 | typed { Apply(Select(ev, op2), List(Apply(Select(ev, meth), List(m)), n)) }
156 | }
157 | case None => {
158 | //mylog("zzz literal ops transformed %s".format(op))
159 | typed { Apply(Select(ev, op2), List(Apply(TypeApply(Select(ev, "fromType"), List(tt)), List(m)), n)) }
160 | }
161 | }
162 |
163 | } else {
164 | //mylog("literal ops alarm #2")
165 | tree
166 | }
167 | }
168 |
169 | // match binary operators
170 | case Apply(Select(Apply(Apply(_, List(m)), List(ev)), op), List(n)) => {
171 | if (!isNumeric(ev.tpe)) {
172 | unit.warning(tree.pos, "binop false alarm #1")
173 | tree
174 | } else if (binops.contains(op)) {
175 | val op2 = binops(op)
176 | //mylog("binop rewrote %s %s %s to n.%s(%s, %s)".format(m, op, n, op2, m, n))
177 | typed { Apply(Select(ev, op2), List(m, n)) }
178 | } else {
179 | unit.warning(tree.pos, "binop false alarm #2")
180 | tree
181 | }
182 | }
183 |
184 | // match unary operators
185 | case Select(Apply(Apply(_, List(m)), List(ev)), op) => {
186 | if (!isNumeric(ev.tpe)) {
187 | unit.warning(tree.pos, "unop false alarm #1")
188 | tree
189 | } else if (unops.contains(op)) {
190 | val op2 = unops(op)
191 | //mylog("unop rewrote %s to n.%s".format(op, op2))
192 | typed { Apply(Select(ev, op2), List(m)) }
193 | } else {
194 | unit.warning(tree.pos, "unop false alarm #2")
195 | tree
196 | }
197 | }
198 |
199 | case _ => tree
200 | }
201 |
202 | super.transform(tree2)
203 | }
204 | }
205 | }
206 |
--------------------------------------------------------------------------------
/plugin/src/test/scala/Example.scala:
--------------------------------------------------------------------------------
1 | import com.azavea.math.Numeric
2 |
3 | //import com.azavea.math.FastImplicits._
4 | import com.azavea.math.EasyImplicits._
5 | import Predef.{any2stringadd => _, _}
6 |
7 | object Example {
8 | def foo1[T:Numeric](m:T, n:T) = m + n
9 | def foo2[T](m:T, n:T)(implicit ev:Numeric[T]) = ev.plus(m, n)
10 |
11 | def bar1[T:Numeric](m:T) = -m
12 | def bar2[T](m:T)(implicit ev:Numeric[T]) = ev.negate(m)
13 |
14 | def duh1[T:Numeric](m:T) = m + 13
15 | def duh2[T](m:T)(implicit ev:Numeric[T]) = ev.plus(m, ev.fromInt(13))
16 |
17 | def yak1[T:Numeric](m:T) = m + BigInt(9)
18 | def yak2[T](m:T)(implicit ev:Numeric[T]) = ev.plus(m, ev.fromBigInt(9))
19 |
20 | def zug1[T:Numeric](n:T) = 1 + n
21 | def zug2[T](n:T)(implicit ev:Numeric[T]) = ev.plus(ev.fromInt(1), n)
22 |
23 | def main(args: Array[String]) {
24 | println(bar1(1))
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/project/Build.scala:
--------------------------------------------------------------------------------
1 | import sbt._
2 |
3 | object NumericBuild extends Build {
4 | // library
5 | lazy val root = Project("root", file("."))
6 |
7 | // plugin
8 | lazy val plugin = Project("plugin", file("plugin")) dependsOn(root)
9 |
10 | // performance testing
11 | lazy val perf = Project("perf", file("perf")) dependsOn(root, plugin)
12 | }
13 |
--------------------------------------------------------------------------------
/src/main/scala/com/azavea/math/Convertable.scala:
--------------------------------------------------------------------------------
1 | package com.azavea.math
2 |
3 | /**
4 | * @author Erik Osheim
5 | */
6 |
7 | /**
8 | * This package is used to provide concrete implementations of the conversions
9 | * between numeric primitives. The idea here is that the Numeric trait can
10 | * extend these traits to inherit the conversions.
11 | *
12 | * We can also use these implementations to provide a way to convert from
13 | * A -> B, where both A and B are generic Numeric types. Without a separate
14 | * trait, we'd have circular type definitions when compiling Numeric.
15 | */
16 |
17 | import scala.{specialized => spec}
18 |
19 | /**
20 | * Conversions to type.
21 | *
22 | * An object implementing ConvertableTo[A] provides methods to go
23 | * from number types to A.
24 | */
25 | trait ConvertableTo[@spec A] {
26 | implicit def fromByte(a:Byte): A
27 | implicit def fromShort(a:Short): A
28 | implicit def fromInt(a:Int): A
29 | implicit def fromLong(a:Long): A
30 | implicit def fromFloat(a:Float): A
31 | implicit def fromDouble(a:Double): A
32 | implicit def fromBigInt(a:BigInt): A
33 | implicit def fromBigDecimal(a:BigDecimal): A
34 | }
35 |
36 | trait ConvertableToByte extends ConvertableTo[Byte] {
37 | implicit def fromByte(a:Byte): Byte = a
38 | implicit def fromShort(a:Short): Byte = a.toByte
39 | implicit def fromInt(a:Int): Byte = a.toByte
40 | implicit def fromLong(a:Long): Byte = a.toByte
41 | implicit def fromFloat(a:Float): Byte = a.toByte
42 | implicit def fromDouble(a:Double): Byte = a.toByte
43 | implicit def fromBigInt(a:BigInt): Byte = a.toByte
44 | implicit def fromBigDecimal(a:BigDecimal): Byte = a.toByte
45 | }
46 |
47 | trait ConvertableToShort extends ConvertableTo[Short] {
48 | implicit def fromByte(a:Byte): Short = a.toShort
49 | implicit def fromShort(a:Short): Short = a
50 | implicit def fromInt(a:Int): Short = a.toShort
51 | implicit def fromLong(a:Long): Short = a.toShort
52 | implicit def fromFloat(a:Float): Short = a.toShort
53 | implicit def fromDouble(a:Double): Short = a.toShort
54 | implicit def fromBigInt(a:BigInt): Short = a.toShort
55 | implicit def fromBigDecimal(a:BigDecimal): Short = a.toShort
56 | }
57 |
58 | trait ConvertableToInt extends ConvertableTo[Int] {
59 | implicit def fromByte(a:Byte): Int = a.toInt
60 | implicit def fromShort(a:Short): Int = a.toInt
61 | implicit def fromInt(a:Int): Int = a
62 | implicit def fromLong(a:Long): Int = a.toInt
63 | implicit def fromFloat(a:Float): Int = a.toInt
64 | implicit def fromDouble(a:Double): Int = a.toInt
65 | implicit def fromBigInt(a:BigInt): Int = a.toInt
66 | implicit def fromBigDecimal(a:BigDecimal): Int = a.toInt
67 | }
68 |
69 | trait ConvertableToLong extends ConvertableTo[Long] {
70 | implicit def fromByte(a:Byte): Long = a.toLong
71 | implicit def fromShort(a:Short): Long = a.toLong
72 | implicit def fromInt(a:Int): Long = a.toLong
73 | implicit def fromLong(a:Long): Long = a
74 | implicit def fromFloat(a:Float): Long = a.toLong
75 | implicit def fromDouble(a:Double): Long = a.toLong
76 | implicit def fromBigInt(a:BigInt): Long = a.toLong
77 | implicit def fromBigDecimal(a:BigDecimal): Long = a.toLong
78 | }
79 |
80 | trait ConvertableToFloat extends ConvertableTo[Float] {
81 | implicit def fromByte(a:Byte): Float = a.toFloat
82 | implicit def fromShort(a:Short): Float = a.toFloat
83 | implicit def fromInt(a:Int): Float = a.toFloat
84 | implicit def fromLong(a:Long): Float = a.toFloat
85 | implicit def fromFloat(a:Float): Float = a
86 | implicit def fromDouble(a:Double): Float = a.toFloat
87 | implicit def fromBigInt(a:BigInt): Float = a.toFloat
88 | implicit def fromBigDecimal(a:BigDecimal): Float = a.toFloat
89 | }
90 |
91 | trait ConvertableToDouble extends ConvertableTo[Double] {
92 | implicit def fromByte(a:Byte): Double = a.toDouble
93 | implicit def fromShort(a:Short): Double = a.toDouble
94 | implicit def fromInt(a:Int): Double = a.toDouble
95 | implicit def fromLong(a:Long): Double = a.toDouble
96 | implicit def fromFloat(a:Float): Double = a.toDouble
97 | implicit def fromDouble(a:Double): Double = a
98 | implicit def fromBigInt(a:BigInt): Double = a.toDouble
99 | implicit def fromBigDecimal(a:BigDecimal): Double = a.toDouble
100 | }
101 |
102 | trait ConvertableToBigInt extends ConvertableTo[BigInt] {
103 | implicit def fromByte(a:Byte): BigInt = BigInt(a)
104 | implicit def fromShort(a:Short): BigInt = BigInt(a)
105 | implicit def fromInt(a:Int): BigInt = BigInt(a)
106 | implicit def fromLong(a:Long): BigInt = BigInt(a)
107 | implicit def fromFloat(a:Float): BigInt = BigInt(a.toLong)
108 | implicit def fromDouble(a:Double): BigInt = BigInt(a.toLong)
109 | implicit def fromBigInt(a:BigInt): BigInt = a
110 | implicit def fromBigDecimal(a:BigDecimal): BigInt = a.toBigInt
111 | }
112 |
113 | trait ConvertableToBigDecimal extends ConvertableTo[BigDecimal] {
114 | implicit def fromByte(a:Byte): BigDecimal = BigDecimal(a)
115 | implicit def fromShort(a:Short): BigDecimal = BigDecimal(a)
116 | implicit def fromInt(a:Int): BigDecimal = BigDecimal(a)
117 | implicit def fromLong(a:Long): BigDecimal = BigDecimal(a)
118 | implicit def fromFloat(a:Float): BigDecimal = BigDecimal(a)
119 | implicit def fromDouble(a:Double): BigDecimal = BigDecimal(a)
120 | implicit def fromBigInt(a:BigInt): BigDecimal = BigDecimal(a)
121 | implicit def fromBigDecimal(a:BigDecimal): BigDecimal = a
122 | }
123 |
124 | object ConvertableTo {
125 | implicit object ConvertableToByte extends ConvertableToByte
126 | implicit object ConvertableToShort extends ConvertableToShort
127 | implicit object ConvertableToInt extends ConvertableToInt
128 | implicit object ConvertableToLong extends ConvertableToLong
129 | implicit object ConvertableToFloat extends ConvertableToFloat
130 | implicit object ConvertableToDouble extends ConvertableToDouble
131 | implicit object ConvertableToBigInt extends ConvertableToBigInt
132 | implicit object ConvertableToBigDecimal extends ConvertableToBigDecimal
133 | }
134 |
135 |
136 | /**
137 | * Conversions from type.
138 | *
139 | * An object implementing ConvertableFrom[A] provides methods to go
140 | * from A to number types (and String).
141 | */
142 | trait ConvertableFrom[@spec A] {
143 | implicit def toByte(a:A): Byte
144 | implicit def toShort(a:A): Short
145 | implicit def toInt(a:A): Int
146 | implicit def toLong(a:A): Long
147 | implicit def toFloat(a:A): Float
148 | implicit def toDouble(a:A): Double
149 | implicit def toBigInt(a:A): BigInt
150 | implicit def toBigDecimal(a:A): BigDecimal
151 |
152 | implicit def toString(a:A): String
153 | }
154 |
155 | trait ConvertableFromByte extends ConvertableFrom[Byte] {
156 | implicit def toByte(a:Byte): Byte = a
157 | implicit def toShort(a:Byte): Short = a.toShort
158 | implicit def toInt(a:Byte): Int = a.toInt
159 | implicit def toLong(a:Byte): Long = a.toLong
160 | implicit def toFloat(a:Byte): Float = a.toFloat
161 | implicit def toDouble(a:Byte): Double = a.toDouble
162 | implicit def toBigInt(a:Byte): BigInt = BigInt(a)
163 | implicit def toBigDecimal(a:Byte): BigDecimal = BigDecimal(a)
164 |
165 | implicit def toString(a:Byte): String = a.toString
166 | }
167 |
168 | trait ConvertableFromShort extends ConvertableFrom[Short] {
169 | implicit def toByte(a:Short): Byte = a.toByte
170 | implicit def toShort(a:Short): Short = a
171 | implicit def toInt(a:Short): Int = a.toInt
172 | implicit def toLong(a:Short): Long = a.toLong
173 | implicit def toFloat(a:Short): Float = a.toFloat
174 | implicit def toDouble(a:Short): Double = a.toDouble
175 | implicit def toBigInt(a:Short): BigInt = BigInt(a)
176 | implicit def toBigDecimal(a:Short): BigDecimal = BigDecimal(a)
177 |
178 | implicit def toString(a:Short): String = a.toString
179 | }
180 |
181 | trait ConvertableFromInt extends ConvertableFrom[Int] {
182 | implicit def toByte(a:Int): Byte = a.toByte
183 | implicit def toShort(a:Int): Short = a.toShort
184 | implicit def toInt(a:Int): Int = a
185 | implicit def toLong(a:Int): Long = a.toLong
186 | implicit def toFloat(a:Int): Float = a.toFloat
187 | implicit def toDouble(a:Int): Double = a.toDouble
188 | implicit def toBigInt(a:Int): BigInt = BigInt(a)
189 | implicit def toBigDecimal(a:Int): BigDecimal = BigDecimal(a)
190 |
191 | implicit def toString(a:Int): String = a.toString
192 | }
193 |
194 | trait ConvertableFromLong extends ConvertableFrom[Long] {
195 | implicit def toByte(a:Long): Byte = a.toByte
196 | implicit def toShort(a:Long): Short = a.toShort
197 | implicit def toInt(a:Long): Int = a.toInt
198 | implicit def toLong(a:Long): Long = a
199 | implicit def toFloat(a:Long): Float = a.toFloat
200 | implicit def toDouble(a:Long): Double = a.toDouble
201 | implicit def toBigInt(a:Long): BigInt = BigInt(a)
202 | implicit def toBigDecimal(a:Long): BigDecimal = BigDecimal(a)
203 |
204 | implicit def toString(a:Long): String = a.toString
205 | }
206 |
207 | trait ConvertableFromFloat extends ConvertableFrom[Float] {
208 | implicit def toByte(a:Float): Byte = a.toByte
209 | implicit def toShort(a:Float): Short = a.toShort
210 | implicit def toInt(a:Float): Int = a.toInt
211 | implicit def toLong(a:Float): Long = a.toLong
212 | implicit def toFloat(a:Float): Float = a
213 | implicit def toDouble(a:Float): Double = a.toDouble
214 | implicit def toBigInt(a:Float): BigInt = BigInt(a.toLong)
215 | implicit def toBigDecimal(a:Float): BigDecimal = BigDecimal(a)
216 |
217 | implicit def toString(a:Float): String = a.toString
218 | }
219 |
220 | trait ConvertableFromDouble extends ConvertableFrom[Double] {
221 | implicit def toByte(a:Double): Byte = a.toByte
222 | implicit def toShort(a:Double): Short = a.toShort
223 | implicit def toInt(a:Double): Int = a.toInt
224 | implicit def toLong(a:Double): Long = a.toLong
225 | implicit def toFloat(a:Double): Float = a.toFloat
226 | implicit def toDouble(a:Double): Double = a
227 | implicit def toBigInt(a:Double): BigInt = BigInt(a.toLong)
228 | implicit def toBigDecimal(a:Double): BigDecimal = BigDecimal(a)
229 |
230 | implicit def toString(a:Double): String = a.toString
231 | }
232 |
233 | trait ConvertableFromBigInt extends ConvertableFrom[BigInt] {
234 | implicit def toByte(a:BigInt): Byte = a.toByte
235 | implicit def toShort(a:BigInt): Short = a.toShort
236 | implicit def toInt(a:BigInt): Int = a.toInt
237 | implicit def toLong(a:BigInt): Long = a.toLong
238 | implicit def toFloat(a:BigInt): Float = a.toFloat
239 | implicit def toDouble(a:BigInt): Double = a.toDouble
240 | implicit def toBigInt(a:BigInt): BigInt = a
241 | implicit def toBigDecimal(a:BigInt): BigDecimal = BigDecimal(a)
242 |
243 | implicit def toString(a:BigInt): String = a.toString
244 | }
245 |
246 | trait ConvertableFromBigDecimal extends ConvertableFrom[BigDecimal] {
247 | implicit def toByte(a:BigDecimal): Byte = a.toByte
248 | implicit def toShort(a:BigDecimal): Short = a.toShort
249 | implicit def toInt(a:BigDecimal): Int = a.toInt
250 | implicit def toLong(a:BigDecimal): Long = a.toLong
251 | implicit def toFloat(a:BigDecimal): Float = a.toFloat
252 | implicit def toDouble(a:BigDecimal): Double = a.toDouble
253 | implicit def toBigInt(a:BigDecimal): BigInt = a.toBigInt
254 | implicit def toBigDecimal(a:BigDecimal): BigDecimal = a
255 |
256 | implicit def toString(a:BigDecimal): String = a.toString
257 | }
258 |
259 | object ConvertableFrom {
260 | implicit object ConvertableFromByte extends ConvertableFromByte
261 | implicit object ConvertableFromShort extends ConvertableFromShort
262 | implicit object ConvertableFromInt extends ConvertableFromInt
263 | implicit object ConvertableFromLong extends ConvertableFromLong
264 | implicit object ConvertableFromFloat extends ConvertableFromFloat
265 | implicit object ConvertableFromDouble extends ConvertableFromDouble
266 | implicit object ConvertableFromBigInt extends ConvertableFromBigInt
267 | implicit object ConvertableFromBigDecimal extends ConvertableFromBigDecimal
268 | }
269 |
--------------------------------------------------------------------------------
/src/main/scala/com/azavea/math/EasyNumericOps.scala:
--------------------------------------------------------------------------------
1 | package com.azavea.math
2 |
3 | import scala.{specialized => spec}
4 |
5 | /**
6 | * @author Erik Osheim
7 | *
8 | * NumericOps adds operators to A. It's intended to be used as an implicit
9 | * decorator like so:
10 | *
11 | * def foo[A:Numeric](a:A, b:A) = a + b
12 | *
13 | * (compiled into) = new NumericOps(a).+(b)
14 | * (w/plugin into) = numeric.add(a, b)
15 | */
16 | final class EasyNumericOps[@spec(Int,Long,Float,Double) A:Numeric](val lhs:A) {
17 | val n = implicitly[Numeric[A]]
18 |
19 | def abs = n.abs(lhs)
20 | def unary_- = n.negate(lhs)
21 | def signum = n.signum(lhs)
22 |
23 | def compare[B:ConvertableFrom](rhs:B) = n.compare(lhs, n.fromType(rhs))
24 | def equiv[B:ConvertableFrom](rhs:B) = n.equiv(lhs, n.fromType(rhs))
25 | def max[B:ConvertableFrom](rhs:B) = n.max(lhs, n.fromType(rhs))
26 | def min[B:ConvertableFrom](rhs:B) = n.min(lhs, n.fromType(rhs))
27 |
28 | def <=>[B:ConvertableFrom](rhs:B) = n.compare(lhs, n.fromType(rhs))
29 | def ===[B:ConvertableFrom](rhs:B) = n.equiv(lhs, n.fromType(rhs))
30 | def !==[B:ConvertableFrom](rhs:B) = n.nequiv(lhs, n.fromType(rhs))
31 |
32 | def /[B:ConvertableFrom](rhs:B) = n.div(lhs, n.fromType(rhs))
33 | def >[B:ConvertableFrom](rhs:B) = n.gt(lhs, n.fromType(rhs))
34 | def >=[B:ConvertableFrom](rhs:B) = n.gteq(lhs, n.fromType(rhs))
35 | def <[B:ConvertableFrom](rhs:B) = n.lt(lhs, n.fromType(rhs))
36 | def <=[B:ConvertableFrom](rhs:B) = n.lteq(lhs, n.fromType(rhs))
37 | def -[B:ConvertableFrom](rhs:B) = n.minus(lhs, n.fromType(rhs))
38 | def %[B:ConvertableFrom](rhs:B) = n.mod(lhs, n.fromType(rhs))
39 | def +[B:ConvertableFrom](rhs:B) = n.plus(lhs, n.fromType(rhs))
40 | def *[B:ConvertableFrom](rhs:B) = n.times(lhs, n.fromType(rhs))
41 | def **[B:ConvertableFrom](rhs:B) = n.pow(lhs, n.fromType(rhs))
42 |
43 | def toByte = n.toByte(lhs)
44 | def toShort = n.toShort(lhs)
45 | def toInt = n.toInt(lhs)
46 | def toLong = n.toLong(lhs)
47 | def toFloat = n.toFloat(lhs)
48 | def toDouble = n.toDouble(lhs)
49 | def toBigInt = n.toBigInt(lhs)
50 | def toBigDecimal = n.toBigDecimal(lhs)
51 | }
52 |
--------------------------------------------------------------------------------
/src/main/scala/com/azavea/math/FastNumericOps.scala:
--------------------------------------------------------------------------------
1 | package com.azavea.math
2 |
3 | /**
4 | * @author Erik Osheim
5 | */
6 |
7 | import scala.{specialized => spec}
8 |
9 | /**
10 | * NumericOps adds things like inline operators to A. It's intended to
11 | * be used as an implicit decorator like so:
12 | *
13 | * def foo[A:Numeric](a:A, b:A) = a + b
14 | * (this is translated into) = new NumericOps(a).+(b)
15 | */
16 | final class FastNumericOps[@spec(Int,Long,Float,Double) A:Numeric](val lhs:A) {
17 | val n = implicitly[Numeric[A]]
18 |
19 | def abs = n.abs(lhs)
20 | def unary_- = n.negate(lhs)
21 | def signum = n.signum(lhs)
22 |
23 | def compare(rhs:A) = n.compare(lhs, rhs)
24 | def equiv(rhs:A) = n.equiv(lhs, rhs)
25 | def max(rhs:A) = n.max(lhs, rhs)
26 | def min(rhs:A) = n.min(lhs, rhs)
27 |
28 | def <=>(rhs:A) = n.compare(lhs, rhs)
29 | def ===(rhs:A) = n.equiv(lhs, rhs)
30 | def !==(rhs:A) = n.nequiv(lhs, rhs)
31 | def >(rhs:A) = n.gt(lhs, rhs)
32 | def >=(rhs:A) = n.gteq(lhs, rhs)
33 | def <(rhs:A) = n.lt(lhs, rhs)
34 | def <=(rhs:A) = n.lteq(lhs, rhs)
35 | def /(rhs:A) = n.div(lhs, rhs)
36 | def -(rhs:A) = n.minus(lhs, rhs)
37 | def %(rhs:A) = n.mod(lhs, rhs)
38 | def +(rhs:A) = n.plus(lhs, rhs)
39 | def *(rhs:A) = n.times(lhs, rhs)
40 | def **(rhs:A) = n.pow(lhs, rhs)
41 |
42 | def toByte = n.toByte(lhs)
43 | def toShort = n.toShort(lhs)
44 | def toInt = n.toInt(lhs)
45 | def toLong = n.toLong(lhs)
46 | def toFloat = n.toFloat(lhs)
47 | def toDouble = n.toDouble(lhs)
48 | def toBigInt = n.toBigInt(lhs)
49 | def toBigDecimal = n.toBigDecimal(lhs)
50 | }
51 |
--------------------------------------------------------------------------------
/src/main/scala/com/azavea/math/LiteralOps.scala:
--------------------------------------------------------------------------------
1 | package com.azavea.math
2 |
3 | import scala.{specialized => spec}
4 |
5 | /**
6 | * IntOps, LongOps and friends provide the same tilde operators as NumericOps
7 | * (such as +~, -~, *~, etc) for the number types we're interested in.
8 | *
9 | * Using these, we use these operators with literals, number types, and
10 | * generic types. For instance:
11 | *
12 | * def foo[A:Numeric](a:A, b:Int) = (a *~ b) +~ 1
13 | */
14 |
15 | final class LiteralIntOps(val lhs:Int) {
16 | def compare[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromInt(lhs), rhs)
17 | def equiv[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromInt(lhs), rhs)
18 | def max[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.max(n.fromInt(lhs), rhs)
19 | def min[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.min(n.fromInt(lhs), rhs)
20 |
21 | def <=>[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromInt(lhs), rhs)
22 | def ===[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromInt(lhs), rhs)
23 | def !==[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.nequiv(n.fromInt(lhs), rhs)
24 |
25 | def /[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.div(n.fromInt(lhs), rhs)
26 | def >[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gt(n.fromInt(lhs), rhs)
27 | def >=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gteq(n.fromInt(lhs), rhs)
28 | def <[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lt(n.fromInt(lhs), rhs)
29 | def <=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lteq(n.fromInt(lhs), rhs)
30 | def -[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.minus(n.fromInt(lhs), rhs)
31 | def %[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.mod(n.fromInt(lhs), rhs)
32 | def +[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.plus(n.fromInt(lhs), rhs)
33 | def *[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.times(n.fromInt(lhs), rhs)
34 | def **[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.pow(n.fromInt(lhs), rhs)
35 | def toNumeric[@spec(Int,Long,Float,Double) A](implicit n:Numeric[A]) = n.fromInt(lhs)
36 |
37 | def toBigInt() = BigInt(lhs)
38 | def toBigDecimal() = BigDecimal(lhs)
39 | }
40 |
41 |
42 | final class LiteralLongOps(val lhs:Long) {
43 | def compare[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromLong(lhs), rhs)
44 | def equiv[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromLong(lhs), rhs)
45 | def max[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.max(n.fromLong(lhs), rhs)
46 | def min[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.min(n.fromLong(lhs), rhs)
47 |
48 | def <=>[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.compare(n.fromLong(lhs), rhs)
49 | def ===[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.equiv(n.fromLong(lhs), rhs)
50 | def !==[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.nequiv(n.fromLong(lhs), rhs)
51 |
52 | def /[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.div(n.fromLong(lhs), rhs)
53 | def >[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gt(n.fromLong(lhs), rhs)
54 | def >=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.gteq(n.fromLong(lhs), rhs)
55 | def <[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lt(n.fromLong(lhs), rhs)
56 | def <=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.lteq(n.fromLong(lhs), rhs)
57 | def -[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.minus(n.fromLong(lhs), rhs)
58 | def %[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.mod(n.fromLong(lhs), rhs)
59 | def +[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.plus(n.fromLong(lhs), rhs)
60 | def *[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.times(n.fromLong(lhs), rhs)
61 | def **[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]) = n.pow(n.fromLong(lhs), rhs)
62 | def toNumeric[@spec(Int,Long,Float,Double) A](implicit n:Numeric[A]) = n.fromLong(lhs)
63 |
64 | def toBigInt() = BigInt(lhs)
65 | def toBigDecimal() = BigDecimal(lhs)
66 | }
67 |
68 |
/**
 * Infix operators for a `Float` appearing on the left-hand side of an
 * expression whose right-hand side is a generic numeric value of type `A`.
 *
 * Every operator first converts `lhs` into `A` with `n.fromFloat` and then
 * delegates to the corresponding `Numeric[A]` method, so arithmetic results
 * have the type of the right-hand operand.
 */
final class LiteralFloatOps(val lhs:Float) {
  // Three-way comparison, equality and extrema.
  def compare[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Int =
    n.compare(n.fromFloat(lhs), rhs)
  def equiv[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.equiv(n.fromFloat(lhs), rhs)
  def max[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.max(n.fromFloat(lhs), rhs)
  def min[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.min(n.fromFloat(lhs), rhs)

  // Symbolic aliases for compare / equiv / nequiv.
  def <=>[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Int =
    n.compare(n.fromFloat(lhs), rhs)
  def ===[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.equiv(n.fromFloat(lhs), rhs)
  def !==[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.nequiv(n.fromFloat(lhs), rhs)

  // Ordering operators.
  def >[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.gt(n.fromFloat(lhs), rhs)
  def >=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.gteq(n.fromFloat(lhs), rhs)
  def <[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.lt(n.fromFloat(lhs), rhs)
  def <=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.lteq(n.fromFloat(lhs), rhs)

  // Arithmetic operators.
  def /[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.div(n.fromFloat(lhs), rhs)
  def -[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.minus(n.fromFloat(lhs), rhs)
  def %[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.mod(n.fromFloat(lhs), rhs)
  def +[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.plus(n.fromFloat(lhs), rhs)
  def *[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.times(n.fromFloat(lhs), rhs)
  def **[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.pow(n.fromFloat(lhs), rhs)

  /** Converts `lhs` into the generic numeric type `A`. */
  def toNumeric[@spec(Int,Long,Float,Double) A](implicit n:Numeric[A]): A =
    n.fromFloat(lhs)

  // Arbitrary-precision views, built from `BigDecimal(lhs)` (lhs widened to
  // Double). NOTE(review): NaN/Infinity cannot be represented and will make
  // BigDecimal construction throw.
  def toBigInt(): BigInt = BigDecimal(lhs).toBigInt
  def toBigDecimal(): BigDecimal = BigDecimal(lhs)
}
94 |
95 |
/**
 * Infix operators for a `Double` appearing on the left-hand side of an
 * expression whose right-hand side is a generic numeric value of type `A`.
 *
 * Every operator first converts `lhs` into `A` with `n.fromDouble` and then
 * delegates to the corresponding `Numeric[A]` method, so arithmetic results
 * have the type of the right-hand operand.
 */
final class LiteralDoubleOps(val lhs:Double) {
  // Three-way comparison, equality and extrema.
  def compare[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Int =
    n.compare(n.fromDouble(lhs), rhs)
  def equiv[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.equiv(n.fromDouble(lhs), rhs)
  def max[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.max(n.fromDouble(lhs), rhs)
  def min[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.min(n.fromDouble(lhs), rhs)

  // Symbolic aliases for compare / equiv / nequiv.
  def <=>[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Int =
    n.compare(n.fromDouble(lhs), rhs)
  def ===[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.equiv(n.fromDouble(lhs), rhs)
  def !==[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.nequiv(n.fromDouble(lhs), rhs)

  // Ordering operators.
  def >[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.gt(n.fromDouble(lhs), rhs)
  def >=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.gteq(n.fromDouble(lhs), rhs)
  def <[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.lt(n.fromDouble(lhs), rhs)
  def <=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.lteq(n.fromDouble(lhs), rhs)

  // Arithmetic operators.
  def /[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.div(n.fromDouble(lhs), rhs)
  def -[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.minus(n.fromDouble(lhs), rhs)
  def %[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.mod(n.fromDouble(lhs), rhs)
  def +[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.plus(n.fromDouble(lhs), rhs)
  def *[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.times(n.fromDouble(lhs), rhs)
  def **[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.pow(n.fromDouble(lhs), rhs)

  /** Converts `lhs` into the generic numeric type `A`. */
  def toNumeric[@spec(Int,Long,Float,Double) A](implicit n:Numeric[A]): A =
    n.fromDouble(lhs)

  // Arbitrary-precision views built from `BigDecimal(lhs)`.
  // NOTE(review): NaN/Infinity cannot be represented and will make
  // BigDecimal construction throw.
  def toBigInt(): BigInt = BigDecimal(lhs).toBigInt
  def toBigDecimal(): BigDecimal = BigDecimal(lhs)
}
121 |
122 |
/**
 * Infix operators for a `BigInt` appearing on the left-hand side of an
 * expression whose right-hand side is a generic numeric value of type `A`.
 *
 * Every operator first converts `lhs` into `A` with `n.fromBigInt` and then
 * delegates to the corresponding `Numeric[A]` method, so arithmetic results
 * have the type of the right-hand operand.
 */
final class LiteralBigIntOps(val lhs:BigInt) {
  // Three-way comparison, equality and extrema.
  def compare[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Int =
    n.compare(n.fromBigInt(lhs), rhs)
  def equiv[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.equiv(n.fromBigInt(lhs), rhs)
  def max[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.max(n.fromBigInt(lhs), rhs)
  def min[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.min(n.fromBigInt(lhs), rhs)

  // Symbolic aliases for compare / equiv / nequiv.
  def <=>[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Int =
    n.compare(n.fromBigInt(lhs), rhs)
  def ===[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.equiv(n.fromBigInt(lhs), rhs)
  def !==[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.nequiv(n.fromBigInt(lhs), rhs)

  // Ordering operators.
  def >[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.gt(n.fromBigInt(lhs), rhs)
  def >=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.gteq(n.fromBigInt(lhs), rhs)
  def <[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.lt(n.fromBigInt(lhs), rhs)
  def <=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.lteq(n.fromBigInt(lhs), rhs)

  // Arithmetic operators.
  def /[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.div(n.fromBigInt(lhs), rhs)
  def -[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.minus(n.fromBigInt(lhs), rhs)
  def %[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.mod(n.fromBigInt(lhs), rhs)
  def +[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.plus(n.fromBigInt(lhs), rhs)
  def *[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.times(n.fromBigInt(lhs), rhs)
  def **[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.pow(n.fromBigInt(lhs), rhs)

  /** Converts `lhs` into the generic numeric type `A`. */
  def toNumeric[@spec(Int,Long,Float,Double) A](implicit n:Numeric[A]): A =
    n.fromBigInt(lhs)

  // lhs is already a BigInt; the BigDecimal view is exact.
  def toBigInt(): BigInt = lhs
  def toBigDecimal(): BigDecimal = BigDecimal(lhs)
}
148 |
149 |
/**
 * Infix operators for a `BigDecimal` appearing on the left-hand side of an
 * expression whose right-hand side is a generic numeric value of type `A`.
 *
 * Every operator first converts `lhs` into `A` with `n.fromBigDecimal` and
 * then delegates to the corresponding `Numeric[A]` method, so arithmetic
 * results have the type of the right-hand operand.
 */
final class LiteralBigDecimalOps(val lhs:BigDecimal) {
  // Three-way comparison, equality and extrema.
  def compare[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Int =
    n.compare(n.fromBigDecimal(lhs), rhs)
  def equiv[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.equiv(n.fromBigDecimal(lhs), rhs)
  def max[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.max(n.fromBigDecimal(lhs), rhs)
  def min[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.min(n.fromBigDecimal(lhs), rhs)

  // Symbolic aliases for compare / equiv / nequiv.
  def <=>[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Int =
    n.compare(n.fromBigDecimal(lhs), rhs)
  def ===[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.equiv(n.fromBigDecimal(lhs), rhs)
  def !==[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.nequiv(n.fromBigDecimal(lhs), rhs)

  // Ordering operators.
  def >[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.gt(n.fromBigDecimal(lhs), rhs)
  def >=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.gteq(n.fromBigDecimal(lhs), rhs)
  def <[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.lt(n.fromBigDecimal(lhs), rhs)
  def <=[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): Boolean =
    n.lteq(n.fromBigDecimal(lhs), rhs)

  // Arithmetic operators.
  def /[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.div(n.fromBigDecimal(lhs), rhs)
  def -[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.minus(n.fromBigDecimal(lhs), rhs)
  def %[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.mod(n.fromBigDecimal(lhs), rhs)
  def +[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.plus(n.fromBigDecimal(lhs), rhs)
  def *[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.times(n.fromBigDecimal(lhs), rhs)
  def **[@spec(Int,Long,Float,Double) A](rhs:A)(implicit n:Numeric[A]): A =
    n.pow(n.fromBigDecimal(lhs), rhs)

  /** Converts `lhs` into the generic numeric type `A`. */
  def toNumeric[@spec(Int,Long,Float,Double) A](implicit n:Numeric[A]): A =
    n.fromBigDecimal(lhs)

  // toBigInt drops any fractional part; toBigDecimal is the identity.
  def toBigInt(): BigInt = lhs.toBigInt
  def toBigDecimal(): BigDecimal = lhs
}
175 |
--------------------------------------------------------------------------------
/src/main/scala/com/azavea/math/Numeric.scala:
--------------------------------------------------------------------------------
1 | package com.azavea.math
2 |
3 | import scala.math.{abs, min, max, pow}
4 | import annotation.implicitNotFound
5 |
6 | /**
7 | * @author Erik Osheim
8 | */
9 |
10 |
/**
 * Numeric typeclass for doing operations on generic types.
 *
 * Importantly, this package does not deliver classes for you to instantiate.
 * Rather, it gives you a trait to associate with your generic types. This
 * allows actual uses of your generic code with concrete types (e.g. Int) to
 * link up with concrete implementations (e.g. IntIsNumeric) of Numeric's
 * methods for that type.
 *
 * @example {{{
 * import com.azavea.math.Numeric
 * import com.azavea.math.FastImplicits._
 *
 * def pythagoreanTheorem[T:Numeric](a:T, b:T): Double = {
 *   val c = (a * a) + (b * b)
 *   math.sqrt(numeric[T].toDouble(c))
 * }
 * }}}
 */
//@implicitNotFound(msg = "Cannot find Numeric type class for ${A}")
trait Numeric[@specialized(Int,Long,Float,Double) A]
extends ConvertableFrom[A] with ConvertableTo[A] {

  /**
   * Computes the absolute value of `a`.
   *
   * @return the absolute value of `a`
   */
  def abs(a:A):A

  /**
   * Returns an integer whose sign denotes the relationship between
   * `a` and `b`. If `a` < `b` it returns -1, if `a` > `b` it returns
   * 1, and otherwise it returns 0.
   *
   * The default implementation is built on `lt` and `gt`.
   *
   * @return -1, 0 or 1
   */
  def compare(a:A, b:A):Int = if (lt(a, b)) -1 else if (gt(a, b)) 1 else 0

  /**
   * Divides `a` by `b`.
   *
   * This method maintains the type of the arguments (`A`). If this
   * method is used with `Int` or `Long`, the quotient is returned (as in
   * integer division). Otherwise (with `Float` and `Double`), a
   * fractional result is returned.
   *
   * @return `a` / `b`
   */
  def div(a:A, b:A):A

  /**
   * Tests if `a` and `b` are equivalent.
   *
   * @return `a` == `b`
   */
  def equiv(a:A, b:A):Boolean

  /**
   * Tests if `a` and `b` are not equivalent.
   *
   * @return `a` != `b`
   */
  def nequiv(a:A, b:A):Boolean

  /**
   * Tests if `a` is greater than `b`.
   *
   * @return `a` > `b`
   */
  def gt(a:A, b:A):Boolean

  /**
   * Tests if `a` is greater than or equal to `b`.
   *
   * @return `a` >= `b`
   */
  def gteq(a:A, b:A):Boolean

  /**
   * Tests if `a` is less than `b`.
   *
   * @return `a` < `b`
   */
  def lt(a:A, b:A):Boolean

  /**
   * Tests if `a` is less than or equal to `b`.
   *
   * @return `a` <= `b`
   */
  def lteq(a:A, b:A):Boolean

  /**
   * Returns the larger of `a` and `b`.
   *
   * @return max(`a`, `b`)
   *
   * @see math.max
   */
  def max(a:A, b:A):A

  /**
   * Returns the smaller of `a` and `b`.
   *
   * @return min(`a`, `b`)
   *
   * @see math.min
   */
  def min(a:A, b:A):A

  /**
   * Returns `a` minus `b`.
   *
   * @return `a` - `b`
   */
  def minus(a:A, b:A):A

  /**
   * Returns `a` modulo `b`.
   *
   * @return `a` % `b`
   */
  def mod(a:A, b:A):A

  /**
   * Returns the additive inverse of `a`.
   *
   * @return -`a`
   */
  def negate(a:A):A

  /**
   * Returns one.
   *
   * @return 1
   */
  def one:A

  /**
   * Returns `a` plus `b`.
   *
   * @return `a` + `b`
   */
  def plus(a:A, b:A):A

  /**
   * Returns `a` to the `b`th power.
   *
   * Note that the Int and Long instances compute this with `math.pow`
   * and convert the Double result back. A result too large for the type
   * therefore becomes Infinity, which converts to that type's MaxValue.
   *
   * @return pow(`a`, `b`)
   *
   * @see math.pow
   */
  def pow(a:A, b:A):A

  /**
   * Returns an integer whose sign denotes the sign of `a`.
   * If `a` is negative it returns -1, if `a` is zero it
   * returns 0 and if `a` is positive it returns 1.
   *
   * The default implementation is `compare(a, zero)`.
   *
   * @return -1, 0 or 1
   */
  def signum(a:A):Int = compare(a, zero)

  /**
   * Returns `a` times `b`.
   *
   * @return `a` * `b`
   */
  def times(a:A, b:A):A

  /**
   * Returns zero.
   *
   * @return 0
   */
  def zero:A

  /**
   * Convert a value `b` of type `B` to type `A`.
   *
   * This method can be used to coerce one generic numeric type to
   * another, to allow operations on them jointly.
   *
   * @example {{{
   * def foo[A:Numeric,B:Numeric](a:A, b:B) = {
   *   val n = implicitly[Numeric[A]]
   *   n.plus(a, n.fromType(b))
   * }
   * }}}
   *
   * Note that `b` may lose precision when represented as an `A`
   * (e.g. if B is Long and A is Int).
   *
   * @return the value of `b` encoded in type `A`
   */
  def fromType[@specialized(Int, Long, Float, Double) B](b:B)(implicit c:ConvertableFrom[B]): A

  /**
   * Convert a value `a` of type `A` to type `B`; the inverse of `fromType`.
   *
   * As with `fromType`, precision may be lost if `B` cannot represent
   * `a` exactly.
   *
   * @return the value of `a` encoded in type `B`
   */
  def toType[@specialized(Int, Long, Float, Double) B](a:A)(implicit c:ConvertableTo[B]): B

  /**
   * Used to get an Ordering[A] instance backed by this Numeric.
   */
  def getOrdering():Ordering[A] = new NumericOrdering(this)
}
225 |
/**
 * Adapter that exposes a `Numeric[A]` as a standard `Ordering[A]`.
 *
 * Numeric deliberately does not extend Ordering: doing so would require
 * overriding all of Ordering's comparison operators, losing specialization
 * and other performance benefits. Instead, `Numeric.getOrdering` wraps the
 * instance in this class, which forwards `compare` to it.
 */
class NumericOrdering[A](n:Numeric[A]) extends Ordering[A] {
  def compare(a:A, b:A): Int = n.compare(a, b)
}
236 |
/**
 * Numeric instance for `Int`, using native 32-bit JVM integer arithmetic.
 * Division truncates toward zero, and `pow` is computed via `math.pow` and
 * narrowed back to Int.
 */
trait IntIsNumeric
extends Numeric[Int] with ConvertableFromInt with ConvertableToInt {
  // Constants.
  def zero: Int = 0
  def one: Int = 1

  // Equality and ordering.
  def equiv(a:Int, b:Int): Boolean = a == b
  def nequiv(a:Int, b:Int): Boolean = a != b
  def gt(a:Int, b:Int): Boolean = a > b
  def gteq(a:Int, b:Int): Boolean = a >= b
  def lt(a:Int, b:Int): Boolean = a < b
  def lteq(a:Int, b:Int): Boolean = a <= b
  def max(a:Int, b:Int): Int = if (a >= b) a else b
  def min(a:Int, b:Int): Int = if (a <= b) a else b

  // Arithmetic. As with math.abs, abs(Int.MinValue) is Int.MinValue.
  def negate(a:Int): Int = -a
  def abs(a:Int): Int = if (a < 0) -a else a
  def plus(a:Int, b:Int): Int = a + b
  def minus(a:Int, b:Int): Int = a - b
  def times(a:Int, b:Int): Int = a * b
  def div(a:Int, b:Int): Int = a / b
  def mod(a:Int, b:Int): Int = a % b
  def pow(a:Int, b:Int): Int = scala.math.pow(a, b).toInt

  // Conversions through the Convertable type classes.
  def fromType[@specialized(Int, Long, Float, Double) B](b:B)(implicit c:ConvertableFrom[B]) = c.toInt(b)
  def toType[@specialized(Int, Long, Float, Double) B](a:Int)(implicit c:ConvertableTo[B]) = c.fromInt(a)
}
261 |
/**
 * Numeric instance for `Long`, using native 64-bit JVM integer arithmetic.
 */
trait LongIsNumeric
extends Numeric[Long] with ConvertableFromLong with ConvertableToLong {
  def abs(a:Long): Long = scala.math.abs(a)
  def div(a:Long, b:Long): Long = a / b
  def equiv(a:Long, b:Long): Boolean = a == b
  def gt(a:Long, b:Long): Boolean = a > b
  def gteq(a:Long, b:Long): Boolean = a >= b
  def lt(a:Long, b:Long): Boolean = a < b
  def lteq(a:Long, b:Long): Boolean = a <= b
  def max(a:Long, b:Long): Long = scala.math.max(a, b)
  def min(a:Long, b:Long): Long = scala.math.min(a, b)
  def minus(a:Long, b:Long): Long = a - b
  def mod(a:Long, b:Long): Long = a % b
  def negate(a:Long): Long = -a
  def nequiv(a:Long, b:Long): Boolean = a != b
  def one: Long = 1L
  def plus(a:Long, b:Long): Long = a + b

  /**
   * Raises `a` to the `b`th power.
   *
   * `math.pow` works in Double. For integer arguments it is exact whenever
   * the result is representable, which is guaranteed below 2^53. Above
   * that, truncating the Double can be off by several units (e.g. 3^39).
   * In that range the result is recomputed exactly with BigInt.
   *
   * Negative exponents, and results that overflow Long, keep the
   * Double-based behaviour: truncation toward zero, or saturation at
   * Long.MinValue/MaxValue.
   */
  def pow(a:Long, b:Long): Long = {
    val approx = scala.math.pow(a, b)
    val magnitude = scala.math.abs(approx)
    // 9007199254740992.0 == 2^53. Above 9.3e18 the true result cannot
    // fit in a Long (whose limit is ~9.223e18), so saturation applies.
    if (b < 0 || magnitude < 9007199254740992.0 || magnitude > 9.3E18) {
      approx.toLong
    } else {
      // Here |a| >= 2 and the result is below ~2^63, so b <= 63 and the
      // BigInt computation is cheap and b.toInt is safe.
      val exact = BigInt(a).pow(b.toInt)
      if (exact.bitLength <= 63) exact.toLong else approx.toLong
    }
  }

  def times(a:Long, b:Long): Long = a * b
  def zero: Long = 0L

  def fromType[@specialized(Int, Long, Float, Double) B](b:B)(implicit c:ConvertableFrom[B]) = c.toLong(b)
  def toType[@specialized(Int, Long, Float, Double) B](a:Long)(implicit c:ConvertableTo[B]) = c.fromLong(a)
}
286 |
/**
 * Numeric instance for `Float`.
 *
 * abs/max/min delegate to scala.math, and pow is computed in Double
 * precision via `math.pow` before being narrowed back to Float.
 */
trait FloatIsNumeric
extends Numeric[Float] with ConvertableFromFloat with ConvertableToFloat {
  // Constants.
  def zero: Float = 0.0F
  def one: Float = 1.0F

  // Equality and ordering.
  def equiv(a:Float, b:Float): Boolean = a == b
  def nequiv(a:Float, b:Float): Boolean = a != b
  def gt(a:Float, b:Float): Boolean = a > b
  def gteq(a:Float, b:Float): Boolean = a >= b
  def lt(a:Float, b:Float): Boolean = a < b
  def lteq(a:Float, b:Float): Boolean = a <= b
  def max(a:Float, b:Float): Float = scala.math.max(a, b)
  def min(a:Float, b:Float): Float = scala.math.min(a, b)

  // Arithmetic.
  def negate(a:Float): Float = -a
  def abs(a:Float): Float = scala.math.abs(a)
  def plus(a:Float, b:Float): Float = a + b
  def minus(a:Float, b:Float): Float = a - b
  def times(a:Float, b:Float): Float = a * b
  def div(a:Float, b:Float): Float = a / b
  def mod(a:Float, b:Float): Float = a % b
  def pow(a:Float, b:Float): Float = scala.math.pow(a, b).toFloat

  // Conversions through the Convertable type classes.
  def fromType[@specialized(Int, Long, Float, Double) B](b:B)(implicit c:ConvertableFrom[B]) = c.toFloat(b)
  def toType[@specialized(Int, Long, Float, Double) B](a:Float)(implicit c:ConvertableTo[B]) = c.fromFloat(a)
}
311 |
/**
 * Numeric instance for `Double`.
 *
 * abs/max/min/pow delegate directly to the corresponding scala.math
 * functions.
 */
trait DoubleIsNumeric
extends Numeric[Double] with ConvertableFromDouble with ConvertableToDouble {
  // Constants.
  def zero: Double = 0.0
  def one: Double = 1.0

  // Equality and ordering.
  def equiv(a:Double, b:Double): Boolean = a == b
  def nequiv(a:Double, b:Double): Boolean = a != b
  def gt(a:Double, b:Double): Boolean = a > b
  def gteq(a:Double, b:Double): Boolean = a >= b
  def lt(a:Double, b:Double): Boolean = a < b
  def lteq(a:Double, b:Double): Boolean = a <= b
  def max(a:Double, b:Double): Double = scala.math.max(a, b)
  def min(a:Double, b:Double): Double = scala.math.min(a, b)

  // Arithmetic.
  def negate(a:Double): Double = -a
  def abs(a:Double): Double = scala.math.abs(a)
  def plus(a:Double, b:Double): Double = a + b
  def minus(a:Double, b:Double): Double = a - b
  def times(a:Double, b:Double): Double = a * b
  def div(a:Double, b:Double): Double = a / b
  def mod(a:Double, b:Double): Double = a % b
  def pow(a:Double, b:Double): Double = scala.math.pow(a, b)

  // Conversions through the Convertable type classes.
  def fromType[@specialized(Int, Long, Float, Double) B](b:B)(implicit c:ConvertableFrom[B]) = c.toDouble(b)
  def toType[@specialized(Int, Long, Float, Double) B](a:Double)(implicit c:ConvertableTo[B]) = c.fromDouble(a)
}
336 |
/**
 * Numeric instance for arbitrary-precision `BigInt`.
 */
trait BigIntIsNumeric
extends Numeric[BigInt] with ConvertableFromBigInt with ConvertableToBigInt {
  def abs(a:BigInt): BigInt = a.abs
  def div(a:BigInt, b:BigInt): BigInt = a / b
  def equiv(a:BigInt, b:BigInt): Boolean = a == b
  def gt(a:BigInt, b:BigInt): Boolean = a > b
  def gteq(a:BigInt, b:BigInt): Boolean = a >= b
  def lt(a:BigInt, b:BigInt): Boolean = a < b
  def lteq(a:BigInt, b:BigInt): Boolean = a <= b
  def max(a:BigInt, b:BigInt): BigInt = a.max(b)
  def min(a:BigInt, b:BigInt): BigInt = a.min(b)
  def minus(a:BigInt, b:BigInt): BigInt = a - b
  def mod(a:BigInt, b:BigInt): BigInt = a % b
  def negate(a:BigInt): BigInt = -a
  def nequiv(a:BigInt, b:BigInt): Boolean = a != b
  def one: BigInt = BigInt(1)
  def plus(a:BigInt, b:BigInt): BigInt = a + b

  /**
   * Raises `a` to the `b`th power.
   *
   * `BigInt.pow` only accepts an `Int` exponent, so `b` is narrowed
   * explicitly. An exponent outside the Int range is rejected with an
   * ArithmeticException rather than silently truncated. Negative exponents
   * are rejected by `java.math.BigInteger.pow`, also with an
   * ArithmeticException.
   */
  def pow(a:BigInt, b:BigInt): BigInt = {
    if (b.bitLength > 31) throw new ArithmeticException("exponent out of Int range: " + b)
    a.pow(b.toInt)
  }

  def times(a:BigInt, b:BigInt): BigInt = a * b
  def zero: BigInt = BigInt(0)

  def fromType[@specialized(Int, Long, Float, Double) B](b:B)(implicit c:ConvertableFrom[B]) = c.toBigInt(b)
  def toType[@specialized(Int, Long, Float, Double) B](a:BigInt)(implicit c:ConvertableTo[B]) = c.fromBigInt(a)
}
361 |
/**
 * Numeric instance for arbitrary-precision `BigDecimal`.
 */
trait BigDecimalIsNumeric
extends Numeric[BigDecimal] with ConvertableFromBigDecimal with ConvertableToBigDecimal {
  def abs(a:BigDecimal): BigDecimal = a.abs
  def div(a:BigDecimal, b:BigDecimal): BigDecimal = a / b
  def equiv(a:BigDecimal, b:BigDecimal): Boolean = a == b
  def gt(a:BigDecimal, b:BigDecimal): Boolean = a > b
  def gteq(a:BigDecimal, b:BigDecimal): Boolean = a >= b
  def lt(a:BigDecimal, b:BigDecimal): Boolean = a < b
  def lteq(a:BigDecimal, b:BigDecimal): Boolean = a <= b
  def max(a:BigDecimal, b:BigDecimal): BigDecimal = a.max(b)
  def min(a:BigDecimal, b:BigDecimal): BigDecimal = a.min(b)
  def minus(a:BigDecimal, b:BigDecimal): BigDecimal = a - b
  def mod(a:BigDecimal, b:BigDecimal): BigDecimal = a % b
  def negate(a:BigDecimal): BigDecimal = -a
  def nequiv(a:BigDecimal, b:BigDecimal): Boolean = a != b
  def one: BigDecimal = BigDecimal(1.0)
  def plus(a:BigDecimal, b:BigDecimal): BigDecimal = a + b

  /**
   * Raises `a` to the `b`th power.
   *
   * `BigDecimal.pow` only accepts an `Int` exponent. `intValueExact`
   * narrows `b` and throws an ArithmeticException if `b` has a nonzero
   * fractional part or does not fit in an Int, instead of silently
   * computing a different power.
   */
  def pow(a:BigDecimal, b:BigDecimal): BigDecimal = a.pow(b.bigDecimal.intValueExact)

  def times(a:BigDecimal, b:BigDecimal): BigDecimal = a * b
  def zero: BigDecimal = BigDecimal(0.0)

  def fromType[@specialized(Int, Long, Float, Double) B](b:B)(implicit c:ConvertableFrom[B]) = c.toBigDecimal(b)
  def toType[@specialized(Int, Long, Float, Double) B](a:BigDecimal)(implicit c:ConvertableTo[B]) = c.fromBigDecimal(a)
}
386 |
387 |
/**
 * Companion object holding the implicit Numeric instances (e.g.
 * IntIsNumeric) that associate the type class with each supported member
 * type. Because they live in Numeric's companion, implicit search finds
 * them without an explicit import.
 */
object Numeric {
  implicit object IntIsNumeric extends IntIsNumeric
  implicit object LongIsNumeric extends LongIsNumeric
  implicit object FloatIsNumeric extends FloatIsNumeric
  implicit object DoubleIsNumeric extends DoubleIsNumeric
  implicit object BigIntIsNumeric extends BigIntIsNumeric
  implicit object BigDecimalIsNumeric extends BigDecimalIsNumeric

  /** Summons the implicit Numeric instance for `A`. */
  def numeric[@specialized(Int, Long, Float, Double) A](implicit ev:Numeric[A]):Numeric[A] = ev
}
402 |
/**
 * Implicit conversions enabling infix syntax on generic numeric values
 * through `FastNumericOps`. It also provides conversions that let a concrete
 * literal (Int, Long, Float, Double, BigInt, BigDecimal) appear on the
 * left-hand side of an operator whose right-hand side is generic.
 *
 * Usage: `import com.azavea.math.FastImplicits._`
 */
object FastImplicits {
  implicit def infixOps[@specialized(Int, Long, Float, Double) A:Numeric](a:A) = new FastNumericOps(a)

  // Literal-on-the-left conversions.
  implicit def infixIntOps(i:Int) = new LiteralIntOps(i)
  implicit def infixLongOps(l:Long) = new LiteralLongOps(l)
  implicit def infixFloatOps(f:Float) = new LiteralFloatOps(f)
  implicit def infixDoubleOps(d:Double) = new LiteralDoubleOps(d)
  implicit def infixBigIntOps(f:BigInt) = new LiteralBigIntOps(f)
  implicit def infixBigDecimalOps(d:BigDecimal) = new LiteralBigDecimalOps(d)

  /** Summons the implicit Numeric instance for `A`. */
  def numeric[@specialized(Int, Long, Float, Double) A](implicit ev:Numeric[A]):Numeric[A] = ev
}
415 |
/**
 * Implicit conversions enabling infix syntax on generic numeric values
 * through `EasyNumericOps`. It also provides conversions that let a concrete
 * literal (Int, Long, Float, Double, BigInt, BigDecimal) appear on the
 * left-hand side of an operator whose right-hand side is generic.
 *
 * Usage: `import com.azavea.math.EasyImplicits._`
 */
object EasyImplicits {
  implicit def infixOps[@specialized(Int, Long, Float, Double) A:Numeric](a:A) = new EasyNumericOps(a)

  // Literal-on-the-left conversions.
  implicit def infixIntOps(i:Int) = new LiteralIntOps(i)
  implicit def infixLongOps(l:Long) = new LiteralLongOps(l)
  implicit def infixFloatOps(f:Float) = new LiteralFloatOps(f)
  implicit def infixDoubleOps(d:Double) = new LiteralDoubleOps(d)
  implicit def infixBigIntOps(f:BigInt) = new LiteralBigIntOps(f)
  implicit def infixBigDecimalOps(d:BigDecimal) = new LiteralBigDecimalOps(d)

  /** Summons the implicit Numeric instance for `A`. */
  def numeric[@specialized(Int, Long, Float, Double) A](implicit ev:Numeric[A]):Numeric[A] = ev
}
428 |
--------------------------------------------------------------------------------
/src/test/scala/Matrix.scala:
--------------------------------------------------------------------------------
1 | import com.azavea.math.Numeric
2 | import com.azavea.math.EasyImplicits._
3 | import Predef.{any2stringadd => _, _}
4 |
/** Raised for invalid matrix dimensions; `s` becomes the exception message. */
class MatrixException(s:String) extends Exception(s)
6 |
7 |
/**
 * Matrix implementation in terms of a generic numeric type.
 *
 * Elements are stored row-major in `data` and accessed as (y, x), i.e.
 * (row, column). Both `rows` and `cols` must be at least 1.
 */
class NMatrix[A:Numeric:Manifest](val data:Array[Array[A]],
                                  val rows:Int, val cols:Int) {
  if (rows < 1) throw new MatrixException("illegal height")
  if (cols < 1) throw new MatrixException("illegal width")

  /* build an empty matrix with the same dimensions */
  def createEmpty() = NMatrix.empty[A](rows, cols)

  /* access the element at (y, x) */
  def apply(y:Int, x:Int) = data(y)(x)

  /* update the element at (y, x) to value */
  def update(y:Int, x:Int, value:A) = data(y)(x) = value

  /* create a new Matrix by mapping each element through f */
  def map[B:Numeric:Manifest](f:A => B) = {
    new NMatrix[B](data.map(_.map(f).toArray).toArray, rows, cols)
  }

  /* combine two same-sized matrices element-by-element */
  def combine(rhs:NMatrix[A], f:(A, A) => A) = {
    // without this check a smaller rhs fails with an index error and a
    // larger one is silently truncated
    if (rows != rhs.rows || cols != rhs.cols) {
      throw new MatrixException("dimensions do not match")
    }
    val result = createEmpty
    for (y <- 0 until rows; x <- 0 until cols) {
      result(y, x) = f(this(y, x), rhs(y, x))
    }
    result
  }

  /* add a scalar value to each element */
  def +(a:A):NMatrix[A] = map(_ + a)

  /* add two matrices */
  def +(rhs:NMatrix[A]):NMatrix[A] = combine(rhs, _ + _)

  /* multiply each element by a scalar value */
  def *(a:A):NMatrix[A] = map(_ * a)

  /* multiply two matrices: (r x n) * (n x c) gives (r x c) */
  def *(rhs:NMatrix[A]):NMatrix[A] = {

    /* the inner dimensions must agree; the outer ones are unconstrained */
    if (this.cols != rhs.rows) {
      throw new MatrixException("dimensions do not match")
    }

    /* figure out the dimensions of the result matrix */
    val (rrows, rcols, n) = (this.rows, rhs.cols, this.cols)

    /* allocate the result matrix */
    val result = NMatrix.empty[A](rrows, rcols)

    /* loop over the cells in the result matrix */
    for (y <- 0 until rrows; x <- 0 until rcols) {
      /* for each pair of values in this-row/rhs-column, multiply them
       * and then sum to get the result value for this cell. */
      result(y, x) = (0 until n).foldLeft(numeric[A].zero) {
        case (sum, i) => sum + (this(y, i) * rhs(i, x))
      }
    }
    result
  }

  def toAscii() = "[" + data.map {
    _.foldLeft("")(_ + " " + _.toString)
  }.reduceLeft(_ + "\n" + _) + "]"
}

object NMatrix {
  def empty[A:Numeric:Manifest](rows:Int, cols:Int) = {
    new NMatrix(Array.ofDim[A](rows, cols), rows, cols)
  }

  def apply[A:Numeric:Manifest](data:Array[Array[A]], rows:Int, cols:Int) = {
    new NMatrix(data, rows, cols)
  }
}
87 |
/**
 * Matrix implementation in terms of Double.
 *
 * Elements are stored row-major in `data` and accessed as (y, x), i.e.
 * (row, column). Both `rows` and `cols` must be at least 1.
 */
class DMatrix(val data:Array[Array[Double]], val rows:Int, val cols:Int) {
  if (rows < 1) throw new MatrixException("illegal height")
  if (cols < 1) throw new MatrixException("illegal width")

  /* build an empty matrix with the same dimensions */
  def createEmpty() = DMatrix.empty(rows, cols)

  /* access the element at (y, x) */
  def apply(y:Int, x:Int) = data(y)(x)

  /* update the element at (y, x) to value */
  def update(y:Int, x:Int, value:Double) = data(y)(x) = value

  /* create a new Matrix by mapping each element through f */
  def map(f:Double => Double) = {
    new DMatrix(data.map(_.map(f).toArray).toArray, rows, cols)
  }

  /* combine two same-sized matrices element-by-element */
  def combine(rhs:DMatrix, f:(Double, Double) => Double) = {
    // without this check a smaller rhs fails with an index error and a
    // larger one is silently truncated
    if (rows != rhs.rows || cols != rhs.cols) {
      throw new MatrixException("dimensions do not match")
    }
    val result = createEmpty
    for (y <- 0 until rows; x <- 0 until cols) {
      result(y, x) = f(this(y, x), rhs(y, x))
    }
    result
  }

  /* add a scalar value to each element */
  def +(a:Double) = map(_ + a)

  /* add two matrices */
  def +(rhs:DMatrix) = combine(rhs, _ + _)

  /* multiply each element by a scalar value */
  def *(a:Double) = map(_ * a)

  /* multiply two matrices: (r x n) * (n x c) gives (r x c) */
  def *(rhs:DMatrix) = {

    /* the inner dimensions must agree; the outer ones are unconstrained */
    if (this.cols != rhs.rows) {
      throw new MatrixException("dimensions do not match")
    }

    val (rrows, rcols, n) = (this.rows, rhs.cols, this.cols)

    val result = DMatrix.empty(rrows, rcols)

    for (y <- 0 until rrows; x <- 0 until rcols) {
      result(y, x) = (0 until n).foldLeft(0.0) {
        case (sum, i) => sum + (this(y, i) * rhs(i, x))
      }
    }

    result
  }

  def toAscii() = "[" + data.map {
    _.foldLeft("")(_ + " " + _.toString)
  }.reduceLeft(_ + "\n" + _) + "]\n"
}

object DMatrix {
  def empty(rows:Int, cols:Int) = {
    new DMatrix(Array.ofDim[Double](rows, cols), rows, cols)
  }

  def apply(data:Array[Array[Double]], rows:Int, cols:Int) = {
    new DMatrix(data, rows, cols)
  }
}
162 |
163 |
164 |
165 | /**
166 | * Testing...
167 | */
168 | import org.scalatest.FunSuite
169 | import org.scalatest.matchers.ShouldMatchers
170 |
class MatrixSpec extends FunSuite with ShouldMatchers {
  val (h, w) = (2, 2)
  val data1 = Array(Array(1.0, 2.0), Array(3.0, 4.0))

  /**
   * Prints both matrices, then reports whether they have the same shape
   * and hold the same value in every cell.
   */
  def compare(m1:NMatrix[Double], m2:DMatrix):Boolean = {
    println(m1.toAscii)
    println(m2.toAscii)
    val sameShape = m1.rows == m2.rows && m1.cols == m2.cols
    sameShape && (0 until m1.rows).forall { y =>
      (0 until m1.cols).forall { x => m1(y, x) == m2(y, x) }
    }
  }

  // Array#clone is shallow; copy every row as well so the two matrices
  // never share (and cannot mutate) the same row arrays.
  val m1 = NMatrix(data1.map(_.clone), h, w)
  val m2 = DMatrix(data1.map(_.clone), h, w)

  test("test matrix representation") { assert(compare(m1, m2)) }

  test("scalar addition") { assert(compare(m1 + 13.0, m2 + 13.0)) }
  test("scalar multiplication") { assert(compare(m1 * 9.0, m2 * 9.0)) }

  test("matrix addition") { assert(compare(m1 + m1, m2 + m2)) }
  test("matrix multiplication") { assert(compare(m1 * m1, m2 * m2)) }
}
195 |
--------------------------------------------------------------------------------