├── .gitignore ├── .travis.yml ├── README.md ├── lgpl-3.0.txt ├── pom.xml ├── project.clj └── src ├── main ├── clojure │ ├── data_readers.clj │ └── mikera │ │ └── vectorz │ │ ├── core.clj │ │ ├── matrix.clj │ │ ├── matrix_api.clj │ │ └── readers.clj └── java │ ├── .gitignore │ └── mikera │ └── vectorz │ ├── FnOp.java │ ├── FnOp2.java │ ├── PrimitiveFnOp.java │ └── PrimitiveFnOp2.java └── test ├── clojure ├── mikera │ └── vectorz │ │ ├── benchmark_matrix.clj │ │ ├── benchmark_stats.clj │ │ ├── blank.clj │ │ ├── examples.clj │ │ ├── generators.clj │ │ ├── implementation_check.clj │ │ ├── large_matrix_benchmark.clj │ │ ├── matrix_benchmarks.clj │ │ ├── test_core.clj │ │ ├── test_linear.clj │ │ ├── test_matrix.clj │ │ ├── test_matrix_api.clj │ │ ├── test_ops.clj │ │ ├── test_properties.clj │ │ ├── test_readers.clj │ │ ├── test_sparse.clj │ │ └── test_stats.clj └── test │ └── misc │ └── loading.clj └── java └── mikera └── vectorz └── ClojureTests.java /.gitignore: -------------------------------------------------------------------------------- 1 | /.project 2 | /.classpath 3 | /target 4 | /.settings 5 | /classes 6 | /.lein-failures 7 | /.nrepl-port 8 | /bin/ 9 | /.lein-repl-history 10 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: java 2 | jdk: 3 | - oraclejdk8 4 | - openjdk7 5 | sudo: false -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | vectorz-clj 2 | =========== 3 | 4 | [![Join the chat at https://gitter.im/mikera/vectorz-clj](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/mikera/vectorz-clj?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) 5 | 6 | [![Clojars 
Project](http://clojars.org/net.mikera/vectorz-clj/latest-version.svg)](http://clojars.org/net.mikera/vectorz-clj) 7 | 8 | [![Build Status](https://travis-ci.org/mikera/vectorz-clj.png?branch=develop)](https://travis-ci.org/mikera/vectorz-clj) [![Dependency Status](https://www.versioneye.com/user/projects/54deed26271c93696000004a/badge.svg?style=flat)](https://www.versioneye.com/user/projects/54deed26271c93696000004a) 9 | 10 | Fast vector and matrix library for Clojure, building on the [Vectorz](https://github.com/mikera/vectorz) library and designed to work with the [core.matrix](https://github.com/mikera/core.matrix) array programming API. 11 | 12 | `vectorz-clj` is designed so that you don't have to compromise, offering all of the following: 13 | 14 | - An idiomatic high-level Clojure API using **core.matrix** 15 | - General purpose **multi-dimensional** arrays 16 | - High **performance** (about as fast as you can get on the JVM). vectorz-clj is currently the fastest pure-JVM vector/matrix library available for Clojure 17 | 18 | The library was originally designed for games, simulations and machine learning applications, 19 | but should be applicable to any situation where you need numerical `double` arrays. 20 | 21 | Important features: 22 | 23 | - **"Pure"** functions for an idiomatic functional programming style are provided. These return new vectors without mutating their arguments. 24 | - **Primitive-backed** special purpose vectors and matrices for performance, e.g. `Vector3` for fast 3D maths. 25 | - **Flexible DSL-style** functions for manipulating vectors and matrices, e.g. the ability to create a "view" into a subspace of a large vector. 26 | - **core.matrix** fully supported - see: https://github.com/mikera/core.matrix 27 | - **Pure cross-platform JVM code** - no native dependencies 28 | - **"Impure"** functions that mutate vectors are available for performance when you need it: i.e.
you can use a nice functional style most of the time, but switch to mutation when you hit a bottleneck. 29 | 30 | ## Documentation 31 | 32 | *vectorz-clj* is intended to be used primarily as a `core.matrix` implementation. As such, the main API to understand is `core.matrix` itself. See the `core.matrix` wiki for more information: 33 | 34 | - https://github.com/mikera/core.matrix/wiki 35 | 36 | For more information about the specific details of vectorz-clj itself, see the [vectorz-clj Wiki](https://github.com/mikera/vectorz-clj/wiki). 37 | 38 | ### Status 39 | 40 | `vectorz-clj` requires Clojure 1.4 or above, Java 1.7 or above, and an up to date version of *core.matrix* 41 | 42 | `vectorz-clj` is reasonably stable, and implements all of the *core.matrix* API feature set. 43 | 44 | 45 | ### License 46 | 47 | Like `Vectorz`, `vectorz-clj` is licensed under the LGPL license: 48 | 49 | - http://www.gnu.org/licenses/lgpl.html 50 | 51 | ### Usage 52 | 53 | Follow the instructions to install with Leiningen / Maven from Clojars: 54 | 55 | - https://clojars.org/net.mikera/vectorz-clj 56 | 57 | You can then use `Vectorz` as a standard `core.matrix` implementation. Example: 58 | 59 | ```clojure 60 | (use 'clojure.core.matrix) 61 | (use 'clojure.core.matrix.operators) ;; overrides *, + etc. 
for matrices 62 | 63 | (set-current-implementation :vectorz) ;; use Vectorz as default matrix implementation 64 | 65 | ;; define a 2x2 Matrix 66 | (def M (matrix [[1 2] [3 4]])) 67 | M 68 | => #vectorz/matrix [[1.0,2.0],[3.0,4.0]] 69 | 70 | ;; define a length 2 vector (a 1D matrix is considered equivalent to a vector in core.matrix) 71 | (def v (matrix [1 2])) 72 | v 73 | => #vectorz/vector [1.0,2.0] 74 | 75 | ;; Matrix x Vector elementwise multiply 76 | (mul M v) 77 | => #vectorz/matrix [[1.0,4.0],[3.0,8.0]] 78 | 79 | ;; Matrix x Vector matrix multiply (inner product) 80 | (inner-product M v) 81 | => #vectorz/vector [5.0,11.0] 82 | ``` 83 | 84 | For more examples, see [Wiki Examples](https://github.com/mikera/vectorz-clj/wiki/Examples) 85 | -------------------------------------------------------------------------------- /lgpl-3.0.txt: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library.
26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. 
You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 
108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 
145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 
166 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 4.0.0 3 | vectorz-clj 4 | 0.48.1-SNAPSHOT 5 | Fast matrix and vector maths library for Clojure - as a pure JVM core.matrix implementation 6 | 7 | 8 | 9 | GNU Lesser General Public License (LGPL) 10 | http://www.gnu.org/licenses/lgpl.html 11 | 12 | 13 | 14 | 15 | 16 | clojars.org 17 | Clojars repository 18 | https://clojars.org/repo 19 | 20 | 21 | 22 | 23 | scm:git:git@github.com:mikera/${project.artifactId}.git 24 | scm:git:git@github.com:mikera/${project.artifactId}.git 25 | scm:git:git@github.com:mikera/${project.artifactId}.git 26 | HEAD 27 | 28 | 29 | 30 | net.mikera 31 | clojure-pom 32 | 0.6.0 33 | 34 | 35 | 36 | 37 | net.mikera 38 | vectorz 39 | 0.66.0 40 | 41 | 42 | net.mikera 43 | core.matrix 44 | 0.62.0 45 | 46 | 47 | net.mikera 48 | core.matrix 49 | 0.62.0 50 | tests 51 | test 52 | 53 | 54 | criterium 55 | criterium 56 | 0.4.4 57 | test 58 | 59 | 60 | net.mikera 61 | cljunit 62 | 0.6.0 63 | test 64 | 65 | 66 | org.clojure 67 | clojure 68 | 1.9.0 69 | 70 | 71 | org.clojure 72 | tools.analyzer 73 | 0.6.9 74 | test 75 | 76 | 77 | net.mikera 78 | clojure-utils 79 | 0.8.0 80 | 81 | 82 | org.clojure 83 | test.check 84 | 0.9.0 85 | test 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /project.clj: -------------------------------------------------------------------------------- 1 | ;; This file is provided as a convenience for Leiningen users 2 | ;; 3 | ;; The pom.xml is used for official builds, and should be considered the 4 | ;; definitive source for build configuration.
5 | ;;
6 | ;; If you are having trouble building, please check the pom.xml for latest dependency versions
7 | 
8 | (defproject net.mikera/vectorz-clj "0.48.1-SNAPSHOT"
9 |   :description "Fast vector library for Clojure, building on Vectorz and using core.matrix"
10 |   :url "https://github.com/mikera/vectorz-clj"
11 |   :license {:name "GNU Lesser General Public License (LGPL)"
12 |             :url "http://www.gnu.org/licenses/lgpl.html"}
13 |   :source-paths ["src/main/clojure"]
14 |   :test-paths ["src/test/clojure"]
15 |   :dependencies [[org.clojure/clojure "1.9.0"]
16 |                  [criterium/criterium "0.4.4"]
17 |                  [org.clojure/tools.analyzer "0.6.9"]
18 |                  [org.clojure/test.check "0.9.0"]
19 |                  [net.mikera/clojure-utils "0.8.0"]
20 |                  [net.mikera/core.matrix "0.62.0"]
21 |                  ;; [net.mikera/core.matrix "0.47.1" :type "test-jar"] ;; bug in Lein!!!! see: https://github.com/technomancy/leiningen/issues/1975
22 |                  [net.mikera/vectorz "0.66.0"]]
23 | 
24 |   :profiles {:dev {:java-source-paths ["src/test/java"]
25 |                    :dependencies [[net.mikera/cljunit "0.6.0"]
26 |                                   [net.mikera/core.matrix "0.62.0" :classifier "tests"]]}}
27 | 
28 |   :repositories [["clojars.org" {:url "https://clojars.org/repo"}]])
29 | 
--------------------------------------------------------------------------------
/src/main/clojure/data_readers.clj:
--------------------------------------------------------------------------------
1 | {vectorz/vector mikera.vectorz.readers/vector
2 |  vectorz/matrix mikera.vectorz.readers/matrix
3 |  vectorz/scalar mikera.vectorz.readers/scalar
4 |  vectorz/array mikera.vectorz.readers/array}
--------------------------------------------------------------------------------
/src/main/clojure/mikera/vectorz/core.clj:
--------------------------------------------------------------------------------
1 | (ns mikera.vectorz.core
2 |   "Clojure API for directly accessing Vectorz vector functions.
3 | 
4 |    In most cases these are relatively lightweight wrappers over equivalent functions in Vectorz,
5 |    but specialised with type hints for handling Vectorz vectors for performance purposes.
6 | 
7 |    These are generally equivalent to similar functions in clojure.core.matrix API. If performance is
8 |    less of a concern, consider using the clojure.core.matrix API directly, which offers more functionality
9 |    and works with a much broader range of array shapes and argument types."
10 |   (:import [mikera.vectorz AVector Vectorz Vector Vector1 Vector2 Vector3 Vector4])
11 |   (:import [mikera.arrayz INDArray])
12 |   (:import [mikera.transformz Transformz])
13 |   (:require [mikera.cljutils.error :refer [error]])
14 |   (:refer-clojure :exclude [+ - * vec vec? vector subvec get set to-array empty]))
15 | 
16 | (set! *warn-on-reflection* true)
17 | (set! *unchecked-math* :warn-on-boxed)
18 | 
19 | ;; ==================================================
20 | ;; Protocols
21 | 
22 | (defprotocol PVectorisable
23 |   (to-vector [a]))
24 | 
25 | (extend-protocol PVectorisable
26 |   (Class/forName "[D")
27 |     (to-vector [coll]
28 |       (Vectorz/create ^doubles coll))
29 |   java.util.List
30 |     (to-vector [coll]
31 |       (Vectorz/create ^java.util.List coll))
32 |   java.lang.Iterable
33 |     (to-vector [coll]
34 |       (Vectorz/create ^java.lang.Iterable coll))
35 |   mikera.vectorz.AVector
36 |     (to-vector [coll]
37 |       (.clone coll)))
38 | 
39 | ;; ==================================================
40 | ;; basic functions
41 | 
42 | (defn clone
43 |   "Creates a (mutable) clone of a vectorz array. May not be exactly the same class as the original array."
44 |   (^INDArray [^INDArray v]
45 |     (.clone v)))
46 | 
47 | (defn ecount
48 |   "Returns the number of elements in a vectorz array"
49 |   (^long [^INDArray v]
50 |     (.elementCount v)))
51 | 
52 | (defn vec?
53 |   "Returns true if v is a vectorz vector
54 |    (i.e. an instance of mikera.vectorz.AVector or a 1-dimensional INDArray)"
55 |   ([v]
56 |     (or
57 |       (instance? AVector v)
58 |       (and (instance? INDArray v) (== 1 (.dimensionality ^INDArray v))))))
59 | 
60 | (defn vectorz?
61 |   "Returns true if v is a vectorz array class (i.e. any instance of mikera.arrayz.INDArray)"
62 |   ([a]
63 |     (instance? INDArray a)))
64 | 
65 | (defn get
66 |   "DEPRECATED: use mget instead for consistency with core.matrix
67 | 
68 |    Returns the component of a vector at a specific index position"
69 |   (^double [^INDArray v ^long index]
70 |     (.get v (int index))))
71 | 
72 | (defn mget
73 |   "Returns the component of an array at the given index position(s). Accepts 0, 1 or 2 indices."
74 |   (^double [^INDArray v]
75 |     (.get v))
76 |   (^double [^INDArray v ^long i]
77 |     (.get v (int i)))
78 |   (^double [^INDArray v ^long i ^long j]
79 |     (.get v (int i) (int j))))
80 | 
81 | (defn set
82 |   "DEPRECATED: use mset! instead for consistency with core.matrix
83 | 
84 |    Sets the component of a vector at position i (mutates in place)"
85 |   ([^AVector v ^long index ^double value]
86 |     (.set v (int index) value)
87 |     v))
88 | 
89 | (defn mset!
90 |   "Sets the component of an array at the given index position(s) (mutates in place) and returns the array"
91 |   ([^INDArray v ^double value]
92 |     (.set v value)
93 |     v)
94 |   ([^INDArray v ^long i ^double value]
95 |     (.set v (int i) value)
96 |     v)
97 |   ([^INDArray v ^long i ^long j ^double value]
98 |     (.set v (int i) (int j) value)
99 |     v))
100 | 
101 | ;; =====================================================
102 | ;; vector predicates
103 | 
104 | (defn normalised?
105 |   "Returns true if a vector is normalised (has unit length)"
106 |   [^AVector v]
107 |   (.isUnitLengthVector v))
108 | 
109 | ;; ====================================================
110 | ;; vector constructors
111 | 
112 | (defn of
113 |   "Creates a vector from its numerical components"
114 |   (^AVector [& xs]
115 |     (let [len (int (count xs))
116 |           ss (seq xs)
117 |           ^AVector v (Vectorz/newVector len)]
118 |       (loop [i (int 0) ss ss]
119 |         (if ss
120 |           (do
121 |             (.set v i (double (first ss)))
122 |             (recur (inc i) (next ss)))
123 |           v)))))
124 | 
125 | (defn vec
126 |   "Creates a vector from a collection, a sequence or anything else that implements the PVectorisable protocol"
127 |   (^AVector [coll]
128 |     (cond
129 |       (vec? coll) (clone coll)
130 |       (satisfies? PVectorisable coll) (to-vector coll)
131 |       (sequential? coll) (apply of coll)
132 |       :else (error "Can't create vector from: " (class coll)))))
133 | 
134 | (defn vec1
135 |   "Creates a Vector1 instance"
136 |   (^Vector1 []
137 |     (Vector1.))
138 |   (^Vector1 [coll]
139 |     (if (number? coll)
140 |       (Vector1/of (double coll))
141 |       (let [v (vec coll)]
142 |         (if (instance? Vector1 v)
143 |           v
144 |           (error "Can't create Vector1 from: " (str coll)))))))
145 | 
146 | (defn vec2
147 |   "Creates a Vector2 instance"
148 |   (^Vector2 []
149 |     (Vector2.))
150 |   (^Vector2 [coll]
151 |     (let [v (vec coll)]
152 |       (if (instance? Vector2 v)
153 |         v
154 |         (error "Can't create Vector2 from: " (str coll)))))
155 |   (^Vector2 [^double x ^double y]
156 |     (Vector2/of (double x) (double y))))
157 | 
158 | (defn vec3
159 |   "Creates a Vector3 instance"
160 |   (^Vector3 []
161 |     (Vector3.))
162 |   (^Vector3 [coll]
163 |     (let [v (vec coll)]
164 |       (if (instance? Vector3 v)
165 |         v
166 |         (error "Can't create Vector3 from: " (str coll)))))
167 |   (^Vector3 [^double x ^double y ^double z]
168 |     (Vector3/of (double x) (double y) (double z))))
169 | 
170 | (defn vec4
171 |   "Creates a Vector4 instance"
172 |   (^Vector4 []
173 |     (Vector4.))
174 |   (^Vector4 [coll]
175 |     (let [v (vec coll)]
176 |       (if (instance? Vector4 v)
177 |         v
178 |         (error "Can't create Vector4 from: " (str coll)))))
179 |   (^Vector4 [^double x ^double y ^double z ^double t]
180 |     (Vector4/of (double x) (double y) (double z) (double t))))
181 | 
182 | (defn vector
183 |   "Creates a vector from zero or more numerical components. With no arguments, returns a length-0 vector."
184 |   (^AVector [& xs]
185 |     (if (seq xs) (vec xs) (of)))) ;; (vec nil) matches no cond branch and throws, so the zero-arg case goes through `of`
186 | 
187 | (defn create-length
188 |   "Creates a vector of a specified length. Will use optimised primitive vectors for small lengths"
189 |   (^AVector [len]
190 |     (Vectorz/newVector (int len))))
191 | 
192 | (defn empty
193 |   "Creates an empty vector of a specified length. Will use optimised primitive vectors for small lengths"
194 |   (^AVector [^long len]
195 |     (Vectorz/newVector (int len))))
196 | 
197 | (defn subvec
198 |   "Returns a subvector of a vector. The subvector is a reference (i.e. can be used to modify the original vector)"
199 |   (^AVector [^AVector v ^long start ^long end] ;; NOTE(review): AVector.subVector appears to take (offset, length), so `end` may act as a length here - confirm before relying on end-exclusive semantics
200 |     (.subVector v (int start) (int end))))
201 | 
202 | (defn join
203 |   "Joins two vectors together. The returned vector is a new reference vector that refers to the originals."
204 |   (^AVector [^AVector a ^AVector b]
205 |     (.join a b)))
206 | 
207 | ;; ======================================
208 | ;; Conversions
209 | 
210 | 
211 | (defn to-array
212 |   "Converts a vector to a double array"
213 |   (^doubles [^AVector a]
214 |     (.toDoubleArray a)))
215 | 
216 | (defn to-list
217 |   "Converts a vector to a list of doubles"
218 |   (^java.util.List [^AVector a]
219 |     (.toList a)))
220 | 
221 | ;; =====================================
222 | ;; In-place operations
223 | 
224 | (defn assign!
225 |   "Fills a vector in place with the value of another vector"
226 |   (^AVector [^AVector a ^AVector new-value]
227 |     (.set a new-value)
228 |     a))
229 | 
230 | (defn add!
231 |   "Add a vector to another (in-place)"
232 |   (^AVector [^AVector dest ^AVector source]
233 |     (.add dest source)
234 |     dest))
235 | 
236 | (defn add-multiple!
237 |   "Adds source multiplied by factor to dest (in-place) and returns dest"
238 |   (^AVector [^AVector dest ^AVector source ^double factor]
239 |     (.addMultiple dest source factor)
240 |     dest))
241 | 
242 | (defn sub!
243 |   "Subtract a vector from another (in-place)"
244 |   (^AVector [^AVector dest ^AVector source]
245 |     (.sub dest source)
246 |     dest))
247 | 
248 | (defn mul!
249 |   "Multiply a vector with another vector or scalar (in-place)"
250 |   (^AVector [^AVector dest source]
251 |     (if (number? source)
252 |       (.multiply dest (double source))
253 |       (.multiply dest ^AVector source))
254 |     dest))
255 | 
256 | (defn div!
257 |   "Divide a vector by another vector or scalar (in-place)"
258 |   (^AVector [^AVector dest source]
259 |     (if (number? source)
260 |       (.divide dest (double source))
261 |       (.divide dest ^AVector source))
262 |     dest))
263 | 
264 | (defn normalise!
265 |   "Normalises a vector in place to unit length and returns it"
266 |   (^AVector [^AVector a]
267 |     (.normalise a)
268 |     a))
269 | 
270 | (defn normalise-get-magnitude!
271 |   "Normalises a vector in place to unit length and returns its magnitude"
272 |   (^double [^AVector a]
273 |     (.normalise a)))
274 | 
275 | (defn negate!
276 |   "Negates a vector in place and returns it"
277 |   (^AVector [^AVector a]
278 |     (.negate a)
279 |     a))
280 | 
281 | (defn abs!
282 |   "Computes the absolute value of a vector in place and returns it"
283 |   (^AVector [^AVector a]
284 |     (.abs a)
285 |     a))
286 | 
287 | (defn scale!
288 |   "Scales a vector in place by a scalar numerical factor"
289 |   (^AVector [^AVector a ^double factor]
290 |     (.scale a factor)
291 |     a))
292 | 
293 | (defn scale-add!
294 |   "Scales a vector in place by a scalar numerical factor and adds a second vector"
295 |   (^AVector [^AVector a ^double factor ^AVector b]
296 |     (.scaleAdd a factor b)
297 |     a))
298 | 
299 | (defn add-weighted!
300 |   "Create a weighted average of a vector with another in place. Numerical weight specifies the proportion of the second vector to use"
301 |   (^AVector [^AVector dest ^AVector source weight]
302 |     (.addWeighted dest source (double weight))
303 |     dest))
304 | 
305 | (defn fill!
306 |   "Fills a vector in place with a specific numerical value"
307 |   (^AVector [^AVector a ^double value]
308 |     (.fill a value)
309 |     a))
310 | 
311 | 
312 | 
313 | ;; =====================================
314 | ;; Special 3D functions
315 | 
316 | (defn cross-product!
317 |   "Calculates the cross product of 3D vector a with vector b, storing the result in a (mutates in place) and returning a"
318 |   (^Vector3 [^Vector3 a ^AVector b]
319 |     (.crossProduct a b)
320 |     a))
321 | 
322 | ;; =====================================
323 | ;; Pure functional operations
324 | 
325 | (defn add
326 |   "Add a vector to another"
327 |   (^AVector [^AVector dest ^AVector source]
328 |     (.addCopy dest source)))
329 | 
330 | (defn add-multiple
331 |   "Returns a new vector equal to dest plus source multiplied by factor"
332 |   (^AVector [^AVector dest ^AVector source ^double factor]
333 |     (.addMultipleCopy dest source factor)))
334 | 
335 | (defn sub
336 |   "Subtract a vector from another"
337 |   (^AVector [^AVector dest ^AVector source]
338 |     (.subCopy dest source)))
339 | 
340 | (defn mul
341 |   "Multiply a vector with another vector or scalar"
342 |   (^AVector [^AVector dest source]
343 |     (if (number? source)
344 |       (.scaleCopy dest (double source))
345 |       (.multiplyCopy dest ^AVector source))))
346 | 
347 | (defn div
348 |   "Divide a vector by another vector or scalar"
349 |   (^AVector [^AVector dest source]
350 |     (if (number? source)
351 |       (.scaleCopy dest (/ 1.0 (double source)))
352 |       (.divideCopy dest ^AVector source))))
353 | 
354 | (defn interpolate "Returns a new vector interpolated from a towards b by the given numerical position. a is cloned, so it is left unchanged."
355 |   (^AVector [^AVector a ^AVector b position]
356 |     (let [^AVector result (clone a)]
357 |       (.interpolate result b (double position))
358 |       result)))
359 | 
360 | (defn normalise
361 |   "Normalises a vector to unit length and returns it"
362 |   (^AVector [^AVector a]
363 |     (.normaliseCopy a)))
364 | 
365 | (defn negate
366 |   "Negates a vector and returns it"
367 |   (^AVector [^AVector a]
368 |     (.negateCopy a)))
369 | 
370 | (defn abs
371 |   "Computes the absolute value of a vector and returns it"
372 |   (^AVector [^AVector a]
373 |     (.absCopy a)))
374 | 
375 | (defn scale
376 |   "Scales a vector by a scalar numerical factor"
377 |   ([^AVector a ^double factor]
378 |     (.scaleCopy a factor)))
379 | 
380 | (defn scale-add
381 |   "Scales a vector by a scalar numerical factor and adds a second vector"
382 |   ([^AVector a factor ^AVector b]
383 |     (scale-add! (clone a) factor b)))
384 | 
385 | (defn fill
386 |   "Fills a vector with a specific numerical value"
387 |   ([^AVector a ^double value]
388 |     (fill! (clone a) value)))
389 | 
390 | 
391 | (defn add-weighted
392 |   "Create a weighted average of a vector with another. Numerical weight specifies the proportion of the second vector to use"
393 |   (^AVector [^AVector dest ^AVector source weight]
394 |     (add-weighted! (clone dest) source (double weight))))
395 | 
396 | 
397 | ;; =====================================
398 | ;; Arithmetic functions and operators
399 | 
400 | (defn approx=
401 |   "Returns a boolean indicating whether the two vectors are approximately equal, +/- an optional tolerance"
402 |   ([^AVector a ^AVector b]
403 |     (.epsilonEquals a b))
404 |   ([^AVector a ^AVector b epsilon]
405 |     (.epsilonEquals a b (double epsilon))))
406 | 
407 | (defn dot
408 |   "Compute the dot product of two vectors"
409 |   (^double [^AVector a ^AVector b]
410 |     (.dotProduct a b)))
411 | 
412 | (defn magnitude
413 |   "Return the magnitude of a vector (geometric length)"
414 |   (^double [^AVector a]
415 |     (.magnitude a)))
416 | 
417 | (defn magnitude-squared
418 |   "Return the squared magnitude of a vector. Slightly more efficient than getting the magnitude directly."
419 |   (^double [^AVector a]
420 |     (.magnitudeSquared a)))
421 | 
422 | (defn distance
423 |   "Return the euclidean distance between two vectors"
424 |   (^double [^AVector a ^AVector b]
425 |     (.distance a b)))
426 | 
427 | (defn distance-squared
428 |   "Return the squared euclidean distance between two vectors"
429 |   (^double [^AVector a ^AVector b]
430 |     (.distanceSquared a b)))
431 | 
432 | (defn angle
433 |   "Return the angle between two vectors"
434 |   (^double [^AVector a ^AVector b]
435 |     (.angle a b)))
436 | 
437 | (defn +
438 |   "Add one or more vectors, returning a new vector as the result"
439 |   (^AVector [^AVector a] (clone a))
440 |   (^AVector [^AVector a ^AVector b]
441 |     (let [r (clone a)]
442 |       (.add r b)
443 |       r))
444 |   (^AVector [^AVector a ^AVector b & vs]
445 |     (let [r (clone a)]
446 |       (.add r b)
447 |       (doseq [^AVector v vs]
448 |         (.add r v))
449 |       r)))
450 | 
451 | (defn -
452 |   "Subtracts one or more vectors from the first vector, returning a new vector as the result"
453 |   (^AVector [^AVector a] (clone a))
454 |   (^AVector [^AVector a ^AVector b]
455 |     (let [r (clone a)]
456 |       (.sub r b)
457 |       r))
458 |   (^AVector [^AVector a ^AVector b & vs]
459 |     (let [^AVector r (- a b)] ;; explicit hint (as in divide) so the .sub call below does not need reflection
460 |       (doseq [^AVector v vs]
461 |         (.sub r v))
462 |       r)))
463 | 
464 | (defn *
465 |   "Multiply one or more vectors, element-wise"
466 |   (^AVector [^AVector a] (clone a))
467 |   (^AVector [^AVector a ^AVector b]
468 |     (let [r (clone a)]
469 |       (.multiply r b)
470 |       r))
471 |   (^AVector [^AVector a ^AVector b & vs]
472 |     (let [^AVector r (* a b)] ;; explicit hint (as in divide) so the .multiply call below does not need reflection
473 |       (doseq [^AVector v vs]
474 |         (.multiply r v))
475 |       r)))
476 | 
477 | (defn divide
478 |   "Divide one or more vectors, element-wise"
479 |   (^AVector [^AVector a] (clone a))
480 |   (^AVector [^AVector a ^AVector b]
481 |     (let [r (clone a)]
482 |       (.divide ^AVector r ^AVector b)
483 |       r))
484 |   (^AVector [^AVector a ^AVector b & vs]
485 |     (let [^AVector r (divide a b)]
486 |       (doseq [^AVector v vs]
487 |         (.divide r v))
488 |       r)))
--------------------------------------------------------------------------------
/src/main/clojure/mikera/vectorz/matrix.clj:
--------------------------------------------------------------------------------
1 | (ns mikera.vectorz.matrix
2 |   "Clojure API for directly accessing Vectorz matrix functions.
3 | 
4 |    In most cases these are relatively lightweight wrappers over equivalent functions in Vectorz,
5 |    but specialised with type hints for handling Vectorz matrices for performance purposes.
6 | 
7 |    These are generally equivalent to similar functions in clojure.core.matrix API. If performance is
8 |    less of a concern, consider using the clojure.core.matrix API directly, which offers more functionality
9 |    and works with a much broader range of array shapes and argument types."
10 |   (:import [mikera.vectorz AVector Vectorz Vector Vector3])
11 |   (:import [mikera.matrixx AMatrix Matrixx Matrix])
12 |   (:import [mikera.transformz Transformz ATransform AAffineTransform MatrixTransform])
13 |   (:import [mikera.arrayz INDArray])
14 |   (:require [mikera.vectorz.core :as v])
15 |   (:refer-clojure :exclude [* get set zero?]))
16 | 
17 | (set! *warn-on-reflection* true)
18 | (set! *unchecked-math* :warn-on-boxed)
19 | 
20 | ;; ============================================
21 | ;; Core functions
22 | 
23 | (defn clone
24 |   "Creates a (mutable) deep clone of a matrix. May not be exactly the same class as the original matrix."
25 |   (^AMatrix [^AMatrix v]
26 |     (.clone v)))
27 | 
28 | (defn to-transform
29 |   "Coerces a matrix or transform to an ATransform instance"
30 |   (^ATransform [a]
31 |     (if (instance? AMatrix a)
32 |       (MatrixTransform. a)
33 |       a)))
34 | 
35 | (defn transform?
36 |   "Returns true if m is a transform (i.e. an instance of mikera.transformz.ATransform)"
37 |   ([m]
38 |     (instance? ATransform m)))
39 | 
40 | (defn affine-transform?
41 |   "Returns true if m is an affine transform, i.e. after coercion with to-transform it is an instance of mikera.transformz.AAffineTransform"
42 |   ([m]
43 |     (instance? AAffineTransform (to-transform m))))
44 | 
45 | (defn matrix?
46 |   "Returns true if m is a Vectorz matrix (i.e. an instance of mikera.matrixx.AMatrix)"
47 |   ([m]
48 |     (instance? AMatrix m)))
49 | 
50 | (defn get
51 |   "Returns the component of a matrix at a specific (row,column) position"
52 |   (^double [^AMatrix m ^long row ^long column]
53 |     (.get m (int row) (int column))))
54 | 
55 | (defn set
56 |   "Sets the component of a matrix at a (row,column) position (mutates in place)"
57 |   ([^AMatrix m ^long row ^long column ^double value]
58 |     (.set m (int row) (int column) value)
59 |     m))
60 | 
61 | (defn get-row
62 |   "Gets a row of the matrix as a vector"
63 |   (^AVector [^AMatrix m ^long row]
64 |     (.getRow m (int row))))
65 | 
66 | (defn get-column
67 |   "Gets a column of the matrix as a vector"
68 |   (^AVector [^AMatrix m ^long column]
69 |     (.getColumn m (int column))))
70 | 
71 | 
72 | ;; ============================================
73 | ;; Matrix predicates
74 | 
75 | (defn fully-mutable?
76 |   "Returns true if the matrix is fully mutable"
77 |   ([^AMatrix m]
78 |     (.isFullyMutable m)))
79 | 
80 | (defn zero?
81 |   "Returns true if the matrix is a zero-filled matrix (i.e. maps every vector to zero)"
82 |   ([^AMatrix m]
83 |     (.isZero m)))
84 | 
85 | (defn square?
86 |   "Returns true if the matrix is a square matrix"
87 |   ([^AMatrix m]
88 |     (.isSquare m)))
89 | 
90 | (defn identity?
91 |   "Returns true if the matrix is an identity matrix"
92 |   ([^AMatrix m]
93 |     (.isIdentity m)))
94 | 
95 | 
96 | ;; ============================================
97 | ;; General transform constructors
98 | 
99 | (defn constant-transform
100 |   "Converts a vector to a constant transform"
101 |   (^ATransform [^AVector v
102 |                 & {:keys [input-dimensions]}]
103 |     (mikera.transformz.Transformz/constantTransform (int (or input-dimensions (.length v))) v)))
104 | 
105 | ;; ============================================
106 | ;; Matrix constructors
107 | 
108 | (defn new-matrix
109 |   "Creates a new, mutable, zero-filled matrix with the given number of rows and columns"
110 |   (^AMatrix [row-count column-count]
111 |     (Matrixx/newMatrix (int row-count) (int column-count))))
112 | 
113 | (defn matrix
114 |   "Creates a new, mutable matrix using the specified data, which should be a sequence of row vectors. The column count is the length of the longest row."
115 |   (^AMatrix [rows]
116 |     (let [vecs (vec (map v/vec rows))
117 |           cc (reduce max 0 (map v/ecount vecs)) ;; start from 0 so an empty `rows` does not call (max) with no arguments
118 |           rc (count rows)
119 |           mat (new-matrix rc cc)]
120 |       (dotimes [i rc]
121 |         (let [^AVector v (vecs i)
122 |               ^AVector row (.getRowView mat i)]
123 |           (.copyTo v row (int 0))))
124 |       mat)))
125 | 
126 | (defn identity-matrix
127 |   "Returns an immutable identity matrix for the given number of dimensions."
128 | (^AMatrix [dimensions] 129 | (Matrixx/createIdentityMatrix (int dimensions)))) 130 | 131 | (defn diagonal-matrix 132 | "Creates a diagonal matrix, using the sequence of diagonal values provided" 133 | (^AMatrix [diagonal-values] 134 | (mikera.matrixx.impl.DiagonalMatrix/create (double-array diagonal-values)))) 135 | 136 | (defn scale-matrix 137 | "Creates a diagonal scaling matrix" 138 | (^AMatrix [scale-factors] 139 | (Matrixx/createScaleMatrix (double-array (seq scale-factors)))) 140 | (^AMatrix [dimensions factor] 141 | (Matrixx/createScaleMatrix (int dimensions) (double factor)))) 142 | 143 | (defn scalar-matrix 144 | "Creates a diagonal scalar matrix (multiplies all components by same factor)" 145 | (^AMatrix [dimensions factor] 146 | (Matrixx/createScalarMatrix (int dimensions) (double factor)))) 147 | 148 | (defn x-axis-rotation-matrix 149 | "Creates a rotation matrix with the given number of radians around the x axis" 150 | (^AMatrix [angle] 151 | (Matrixx/createXAxisRotationMatrix (double angle)))) 152 | 153 | (defn y-axis-rotation-matrix 154 | "Creates a rotation matrix with the given number of radians around the y axis" 155 | (^AMatrix [angle] 156 | (Matrixx/createYAxisRotationMatrix (double angle)))) 157 | 158 | (defn z-axis-rotation-matrix 159 | "Creates a rotation matrix with the given number of radians around the z axis" 160 | (^AMatrix [angle] 161 | (Matrixx/createZAxisRotationMatrix (double angle)))) 162 | 163 | ;; ============================================ 164 | ;; matrix operations 165 | 166 | (defn scale 167 | "Scales a matrix by a scalar factor" 168 | (^AMatrix [^AMatrix m factor] 169 | (let [^AMatrix m (clone m)] 170 | (.addMultiple m m (- (double factor) 1.0)) 171 | m))) 172 | 173 | (defn input-dimensions 174 | "Gets the number of input dimensions (columns) of a matrix or other transform" 175 | (^long [m] 176 | (if (instance? 
AMatrix m) 177 | (.columnCount ^AMatrix m) 178 | (.inputDimensions ^ATransform m)))) 179 | 180 | (defn output-dimensions 181 | "Gets the number of output dimensions (rows) of a matrix or other transform" 182 | (^long [m] 183 | (if (instance? AMatrix m) 184 | (.rowCount ^AMatrix m) 185 | (.outputDimensions ^ATransform m)))) 186 | 187 | (defn transpose! 188 | "Transposes a matrix in place, if possible" 189 | (^AMatrix [^AMatrix m] 190 | (.transposeInPlace m) 191 | m)) 192 | 193 | (defn transpose 194 | "Gets the transpose of a matrix as a transposed reference to the original matrix" 195 | (^AMatrix [^AMatrix m] 196 | (.getTranspose m))) 197 | 198 | (defn as-vector 199 | "Returns a vector view over all elements of a matrix (in row major order)" 200 | (^AVector [^AMatrix m] 201 | (.asVector m))) 202 | 203 | (defn inverse 204 | "Gets the inverse of a square matrix as a new matrix." 205 | (^AMatrix [^AMatrix m] 206 | (.inverse m))) 207 | 208 | (defn compose! 209 | "Composes a transform with another transform (in-place). Second transform should be square." 210 | (^ATransform [a b] 211 | (.composeWith (to-transform a) (to-transform b)) 212 | a)) 213 | 214 | (defn compose 215 | "Composes a transform with another transform" 216 | (^ATransform [a b] 217 | (.compose (to-transform a) (to-transform b)))) 218 | 219 | (defn determinant 220 | "Gets the determinant of a (square) matrix" 221 | (^double [^AMatrix m] 222 | (.determinant m))) 223 | 224 | 225 | 226 | ;; ============================================ 227 | ;; Matrix application 228 | 229 | (defn transform! 230 | "Applies a matrix transform to a vector, modifying the vector in place" 231 | (^AVector [^ATransform m ^AVector a] 232 | (.transformInPlace m a) 233 | a)) 234 | 235 | (defn transform 236 | "Applies a matrix transform to a vector, returning a new vector" 237 | (^AVector [m ^AVector a] 238 | (if (instance? 
ATransform m) 239 | (let [^ATransform m m] 240 | (.transform m a)) 241 | (.innerProduct ^AMatrix m a)))) 242 | 243 | (defn transform-normal 244 | "Applies a an affine transform to a normal vector, storing the result in dest" 245 | (^AVector [^AAffineTransform m ^AVector src ^AVector dest] 246 | (.transformNormal m src dest))) 247 | 248 | (defn * 249 | "Applies a matrix to a vector or matrix, returning a new vector or matrix. If applied to a vector, the vector is transformed. If applied to a matrix, the two matrices are composed" 250 | ([^AMatrix m a] 251 | (cond 252 | (instance? AVector a) (.innerProduct m ^AVector a) 253 | :else (.innerProduct m ^INDArray a)))) -------------------------------------------------------------------------------- /src/main/clojure/mikera/vectorz/matrix_api.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.matrix-api 2 | "Namespace for vectorz-clj core.matrix implementation. Loading this namespace either 3 | directly or indirectly is required to enable the :vectorz implementation for core.matrix." 
4 | (:refer-clojure :exclude [abs vector?]) 5 | (:use clojure.core.matrix) 6 | (:use clojure.core.matrix.utils) 7 | (:require [clojure.core.matrix.implementations :as imp]) 8 | (:require [clojure.core.matrix.protocols :as mp]) 9 | (:require [mikera.vectorz.readers]) 10 | (:import [mikera.matrixx AMatrix Matrixx Matrix]) 11 | (:import [mikera.matrixx.impl SparseRowMatrix SparseColumnMatrix]) 12 | (:import [mikera.vectorz AVector Vectorz Vector AScalar Vector3 Ops Op Op2]) 13 | (:import [mikera.vectorz Scalar]) 14 | (:import [mikera.vectorz FnOp FnOp2]) 15 | (:import [mikera.vectorz.impl ASparseIndexedVector SparseHashedVector ZeroVector SparseIndexedVector]) 16 | (:import [mikera.arrayz Arrayz INDArray Array]) 17 | (:import [mikera.arrayz.impl SliceArray]) 18 | (:import [mikera.indexz AIndex Index]) 19 | (:import [java.util List Arrays]) 20 | (:import [java.io Writer]) 21 | (:import [mikera.transformz ATransform]) 22 | (:import [mikera.matrixx.decompose QR IQRResult Cholesky ICholeskyResult ICholeskyLDUResult]) 23 | (:import [mikera.matrixx.decompose SVD ISVDResult LUP ILUPResult Eigen IEigenResult]) 24 | (:import [mikera.matrixx.solve Linear])) 25 | 26 | ;; ====================================================================== 27 | ;; General implementation notes 28 | ;; 29 | ;; Vectorz supports double element types only. All internal calculation is done with 30 | ;; unboxed double primitives for speed, however boxing may be required for return values, 31 | ;; passing to other core.matrix protocols etc. 32 | ;; 33 | ;; Arguments other then the first argument are *not* guaranteed to be Vectorz types 34 | ;; so we need to coerce to appropriate forms before use. This ensures that we can work with 35 | ;; types from all other working numerical implementations. 36 | ;; Utility functions to do this include: vectorz-coerce, double-coerce, avector-coerce 37 | 38 | (set! *warn-on-reflection* true) 39 | (set! 
*unchecked-math* :warn-on-boxed) 40 | (declare vectorz-coerce* avector-coerce* double-coerce) 41 | 42 | ;; ======================================================================= 43 | ;; Macros and helper functions 44 | ;; 45 | ;; Intended to be internal to vectorz-clj implementation 46 | 47 | (defmacro tag-symbol [tag form] 48 | (let [tagged-sym (vary-meta (gensym "res") assoc :tag tag)] 49 | `(let [~tagged-sym ~form] ~tagged-sym))) 50 | 51 | (defn vectorz-type? [tag] 52 | (let [^String stag (if (class? tag) (.getName ^Class tag) (str tag))] 53 | (or (.startsWith stag "mikera.vectorz.") 54 | (.startsWith stag "mikera.matrixx.") 55 | (.startsWith stag "mikera.indexz.") 56 | (.startsWith stag "mikera.arrayz.")))) 57 | 58 | (defmacro vectorz? 59 | "Returns true if v is a vectorz class (i.e. an instance of mikera.arrayz.INDArray)" 60 | ([a] 61 | `(instance? INDArray ~a))) 62 | 63 | (defmacro vectorz-coerce 64 | "Coerces the argument to a vectorz INDArray. Broadcasts to the shape of an optional target if provided." 65 | ([x] 66 | (if (and (symbol? x) (vectorz-type? (:tag (meta x)))) 67 | x ;; return tagged symbol unchanged 68 | `(tag-symbol mikera.arrayz.INDArray 69 | (let [x# ~x] 70 | (if (instance? INDArray x#) x# (vectorz-coerce* x#)))))) 71 | ([target x] 72 | `(let [m# ~target 73 | x# (vectorz-coerce ~x)] 74 | (tag-symbol mikera.arrayz.INDArray 75 | (if (< (.dimensionality x#) (.dimensionality m#)) 76 | (.broadcastLike x# m#) 77 | x#))))) 78 | 79 | (defmacro vectorz-clone 80 | "Coerces the argument to a new (cloned) vectorz INDArray" 81 | ([x] 82 | `(tag-symbol mikera.arrayz.INDArray 83 | (let [x# ~x] 84 | (if (instance? INDArray x#) (.clone ^INDArray x#) (vectorz-coerce* x#)))))) 85 | 86 | (defmacro avector-coerce 87 | "Coerces an argument x to an AVector instance, of the same size as m" 88 | ([m x] 89 | `(tag-symbol mikera.vectorz.AVector 90 | (let [x# ~x] 91 | (if (instance? 
AVector x#) x# (avector-coerce* ~m x#))))) 92 | ([x] 93 | `(tag-symbol mikera.vectorz.AVector 94 | (let [x# ~x] 95 | (if (instance? AVector x#) x# (avector-coerce* x#)))))) 96 | 97 | (defmacro amatrix-coerce 98 | "Coerces an argument x to an AMatrix instance" 99 | ([m x] 100 | `(tag-symbol mikera.matrixx.AMatrix 101 | (let [x# ~x] 102 | (if (instance? AMatrix x#) x# (amatrix-coerce* ~m x#))))) 103 | ([x] 104 | `(tag-symbol mikera.matrixx.AMatrix 105 | (let [x# ~x] 106 | (if (instance? AMatrix x#) x# (amatrix-coerce* x#)))))) 107 | 108 | (defmacro with-clone 109 | "Executes the body with a cloned version of the specfied symbol/expression binding. Returns the cloned object." 110 | ([[sym exp] & body] 111 | (let [] 112 | (when-not (symbol? sym) (error "Symbol required for with-clone binding")) 113 | `(let [~sym (.clone ~(if exp exp sym))] 114 | ~@body 115 | ~sym)))) 116 | 117 | (defmacro with-vectorz-clone 118 | "Executes the body with a cloned version of the specfied symbol/expression binding. Returns the cloned object." 119 | ([[sym exp] & body] 120 | (let [] 121 | (when-not (symbol? sym) (error "Symbol required for with-clone binding")) 122 | `(let [~sym (vectorz-clone ~(if exp exp sym))] 123 | ~@body 124 | ~sym)))) 125 | 126 | (defmacro with-broadcast-clone 127 | "Executes body with a broadcasted clone of a and a broadcasted INDArray version of b. 128 | Returns the broadcasted clone of a." 129 | ([[a b] & body] 130 | (when-not (and (symbol? a) (symbol? b)) (error "Symbols required for with-broadcast-clone binding")) 131 | (let [] 132 | `(let [~b (vectorz-coerce ~a ~b) 133 | ~a (.broadcastCloneLike ~a ~b)] 134 | ~@body 135 | ~a)))) 136 | 137 | (defmacro with-broadcast-coerce 138 | "Executes body with a and a coerced INDArray version of b. Returns result of body." 139 | ([[a b] & body] 140 | (when-not (and (symbol? a) (symbol? 
b)) (error "Symbols required for with-broadcast-clone binding")) 141 | (let [] 142 | `(let [~b (vectorz-coerce ~a ~b) 143 | ~a (.broadcastLike ~a ~b)] 144 | ~@body)))) 145 | 146 | (def ^{:tag Class :const true} INT-ARRAY-CLASS (Class/forName "[I")) 147 | 148 | (defmacro int-array-coerce 149 | "Coerces an arbitrary object to an int array" 150 | ([m] 151 | `(tag-symbol ~'ints 152 | (let [m# ~m] 153 | (cond 154 | (instance? INT-ARRAY-CLASS m#) m# 155 | (sequential? m#) (int-array m#) 156 | :else (int-array (mp/element-seq m#))))))) 157 | 158 | 159 | (defmacro with-indexes 160 | "Executes body after binding int indexes from the given indexing object" 161 | ([[syms ixs] & body] 162 | (let [n (count syms) 163 | isym (gensym)] 164 | `(let [~isym ~ixs] 165 | (cond 166 | (instance? INT-ARRAY-CLASS ~isym) 167 | (let [~isym ~(vary-meta isym assoc :tag "[I") 168 | ~@(interleave 169 | syms 170 | (map (fn [i] `(int (aget ~isym ~i)) ) (range n)))] ~@body) 171 | (instance? clojure.lang.IPersistentVector ~isym) 172 | (let [~isym ~(vary-meta isym assoc :tag 'clojure.lang.IPersistentVector) 173 | ~@(interleave 174 | syms 175 | (map (fn [i] `(int (.nth ~isym ~i)) ) (range n)))] ~@body) 176 | :else 177 | (let [[~@syms] (seq ~isym)] ~@body)))))) 178 | 179 | (defn avector-coerce* 180 | "Coerces any numerical array to an AVector instance. 181 | May broadcast to the shape of an optional target if necessary. 182 | Does *not* guarantee a new copy - may return same data." 183 | (^AVector [^AVector target m] 184 | (cond 185 | (instance? INDArray m) 186 | (.broadcastLike ^INDArray m target) 187 | (number? m) 188 | (Vectorz/createRepeatedElement (.length target) (double m)) 189 | (== 0 (long (mp/dimensionality m))) 190 | (Vectorz/createRepeatedElement (.length target) (double-coerce m)) 191 | :else (.broadcastLike (Vector/wrap ^doubles (mp/to-double-array m)) target))) 192 | (^AVector [m] 193 | (cond 194 | (instance? 
AVector m) m 195 | (== (dimensionality m) 1) 196 | (Vector/wrap ^doubles (mp/to-double-array m)) 197 | :else (error "Can't coerce to AVector with shape: " (mp/get-shape m))))) 198 | 199 | (defn amatrix-coerce* 200 | "Coerces any numerical array to an AMatrix instance. 201 | May broadcast to the shape of an optional target if necessary. 202 | Does *not* guarantee a new copy - may return same data." 203 | (^AMatrix [^AMatrix target m] 204 | (.broadcastLike ^INDArray (vectorz-coerce* m) target)) 205 | (^AMatrix [m] 206 | (if (instance? AMatrix m) m 207 | (let [rows (long (mp/dimension-count m 0)) 208 | cols (long (mp/dimension-count m 1)) 209 | elems (* rows cols)] 210 | (if (< elems 10000) 211 | ;; dense default for small matrices 212 | (Matrix/wrap rows cols (mp/to-double-array m)) 213 | ;; for larger matrices - TODO think aboutr sparse case? 214 | (Matrixx/create ^java.util.List (mapv vectorz-coerce* (slices m)))))))) 215 | 216 | (defn vectorz-coerce* 217 | "Function to attempt conversion to a Vectorz INDArray object. Should work on any core.matrix 218 | numerical array or scalar. Does *not* guarantee a new copy - may return same data." 219 | (^INDArray [p] 220 | (if (number? p) 221 | (Scalar. (double p)) 222 | (let [dims (long (mp/dimensionality p))] 223 | (cond 224 | (== 0 dims) 225 | (cond 226 | (instance? AScalar p) p 227 | (nil? p) (error "Can't convert nil to vectorz format") 228 | :else (do 229 | ;; (println (str "Coercing " p)) 230 | (Scalar. (double (mp/get-0d p))))) 231 | (== 1 dims) 232 | (avector-coerce* p) 233 | (== 2 dims) 234 | (amatrix-coerce* p) 235 | :else 236 | (let [^List sv (mapv (fn [sl] (vectorz-coerce sl)) (slices p))] 237 | (and (seq sv) (sv 0) (Arrayz/create sv)))))))) 238 | 239 | (defmacro double-coerce 240 | "Macro to coerce to a primitive double value. Works on numbers and 0d arrays." 241 | ([x] 242 | `(let [x# ~x] 243 | (double (if (number? 
x#) x# (mp/get-0d x#)))))) 244 | 245 | (eval 246 | `(extend-protocol mp/PImplementation 247 | ~@(mapcat 248 | (fn [sym] 249 | (cons sym 250 | '( 251 | (implementation-key [m] :vectorz) 252 | (supports-dimensionality? [m dims] true) 253 | (new-vector [m length] (Vectorz/newVector (int length))) 254 | (new-matrix [m rows columns] (Matrixx/newMatrix (int rows) (int columns))) 255 | (new-matrix-nd [m shape] 256 | (case (count shape) 257 | 0 (Scalar/create 0.0) 258 | 1 (Vector/createLength (int (first shape))) 259 | 2 (Matrix/create (int (first shape)) (int (second shape))) 260 | (Array/newArray (int-array shape)))) 261 | (construct-matrix [m data] 262 | (cond 263 | (instance? INDArray data) 264 | (.clone ^INDArray data) 265 | (mp/is-scalar? data) 266 | (double-coerce data) 267 | (array? data) 268 | (if (== 0 (dimensionality data)) 269 | (double-coerce data) 270 | (vectorz-coerce data)) 271 | :default 272 | (let [vm (mp/construct-matrix [] data)] 273 | ;; (println m vm (shape vm)) 274 | (assign! 
(mp/new-matrix-nd m (shape vm)) vm))))))) 275 | ['mikera.vectorz.AVector 'mikera.matrixx.AMatrix 'mikera.vectorz.AScalar 'mikera.arrayz.INDArray 'mikera.indexz.AIndex]) )) 276 | 277 | (defmacro with-keys 278 | [available required] 279 | (let [result-sym (gensym)] 280 | `(if ~required 281 | (let 282 | [~result-sym {} 283 | ~@(mapcat (fn [[k v]] `(~result-sym (if (some #{~k} ~required) 284 | (assoc ~result-sym ~k ~v) 285 | ~result-sym))) 286 | available)] 287 | ~result-sym) 288 | ~available))) 289 | 290 | (extend-protocol mp/PQRDecomposition 291 | INDArray 292 | (qr [m options] 293 | (let [dims (dimensionality m)] 294 | (if (== 2 dims) 295 | (mp/qr (Matrixx/toMatrix m) options) 296 | (error "Can't compute QR on an array of dimensionality " dims)))) 297 | AMatrix 298 | (qr [m options] 299 | (let 300 | [result (cond 301 | (:compact options) (QR/decompose m true) 302 | :else (QR/decompose m))] 303 | (with-keys {:Q (.getQ result) :R (.getR result)} (:return options))))) 304 | 305 | (extend-protocol mp/PLUDecomposition 306 | INDArray 307 | (lu [m options] 308 | (let [dims (dimensionality m)] 309 | (if (== 2 dims) 310 | (mp/lu (Matrixx/toMatrix m) options) 311 | (error "Can't compute LU on an array of dimensionality " dims)))) AMatrix 312 | AMatrix 313 | (lu [m options] 314 | (let 315 | [result (LUP/decompose m)] 316 | (with-keys {:L (.getL result) :U (.getU result) :P (.getP result)} (:return options))))) 317 | 318 | (extend-protocol mp/PCholeskyDecomposition 319 | INDArray 320 | (cholesky [m options] 321 | (let [dims (dimensionality m)] 322 | (if (== 2 dims) 323 | (mp/cholesky (Matrixx/toMatrix m) options) 324 | (error "Can't compute cholesky on an array of dimensionality " dims)))) AMatrix 325 | AMatrix 326 | (cholesky [m options] 327 | (when-let [result (Cholesky/decompose m)] 328 | (with-keys {:L (.getL result) :L* (.getU result)} (:return options))))) 329 | 330 | (extend-protocol mp/PSVDDecomposition 331 | INDArray 332 | (svd [m options] 333 | (let [dims 
(dimensionality m)] 334 | (if (== 2 dims) 335 | (mp/svd (Matrixx/toMatrix m) options) 336 | (error "Can't compute SVD on an array of dimensionality " dims)))) AMatrix 337 | AMatrix 338 | (svd [m options] 339 | (when-let [result (SVD/decompose m)] 340 | (with-keys {:U (.getU result) :S (diagonal (.getS result)) :V* (.getTranspose (.getV result))} (:return options))))) 341 | 342 | ;; TODO: complete Eigendecomposition 343 | ;; Need to handle complex numbers! 344 | ; 345 | ;(extend-protocol mp/PEigenDecomposition 346 | ; INDArray 347 | ; (eigen [m options] 348 | ; (let [dims (dimensionality m)] 349 | ; (if (== 2 dims) 350 | ; (mp/eigen (Matrixx/toMatrix m) options) 351 | ; (error "Can't compute Eigendecomposition on an array of dimensionality " dims)))) AMatrix 352 | ; AMatrix 353 | ; (eigen [m options] 354 | ; (when-let [result (Eigen/decompose m)] 355 | ; (let [eigenvalues (.getEigenvalues result) 356 | ; eigenvectors (.getEigenVectors result) 357 | ; Qt (VectorMatrixMN/wrap eigenvectors)] 358 | ; (with-keys {:Q (.transpose Qt) 359 | ; :Qt Qt 360 | ; :A (DiagonalMatrix/create eigenvectors) 361 | ; :eigenvalues eigenvalues 362 | ; :eigenvectorss eigenvectors} 363 | ; (:return options)))))) 364 | 365 | (extend-protocol mp/PNorm 366 | INDArray 367 | (norm [m p] 368 | (cond 369 | (= java.lang.Double/POSITIVE_INFINITY p) (.elementMaxAbs m) 370 | (number? 
p) (Math/pow (.elementAbsPowSum m p) (/ 1.0 (double p))) 371 | :else (error "p must be a number")))) 372 | 373 | (extend-protocol mp/PMatrixRank 374 | INDArray 375 | (rank [m options] 376 | (let [dims (dimensionality m)] 377 | (if (== 2 dims) 378 | (mp/rank (Matrixx/toMatrix m)) 379 | (error "Can't compute matrix rank on an array of dimensionality " dims)))) AMatrix 380 | AMatrix 381 | (rank [m] 382 | (let [{:keys [S]} (mp/svd m {:return [:S]}) 383 | eps 1e-10] 384 | (reduce (fn [^long n x] (if (< (java.lang.Math/abs (double x)) eps) n (inc n))) 0 S)))) 385 | 386 | (extend-protocol mp/PSolveLinear 387 | AMatrix 388 | (solve [a b] 389 | (Linear/solve a (avector-coerce b)))) 390 | 391 | (extend-protocol mp/PLeastSquares 392 | AMatrix 393 | (least-squares [a b] 394 | (Linear/solveLeastSquares a (avector-coerce b)))) 395 | 396 | (extend-protocol mp/PTypeInfo 397 | INDArray 398 | (element-type [m] (Double/TYPE)) 399 | AIndex 400 | (element-type [m] (Integer/TYPE))) 401 | 402 | (extend-protocol mp/PGenericValues 403 | INDArray 404 | (generic-zero [m] 405 | 0.0) 406 | (generic-one [m] 407 | 1.0) 408 | (generic-value [m] 409 | 0.0)) 410 | 411 | (extend-protocol mp/PMutableMatrixConstruction 412 | INDArray 413 | (mutable-matrix [m] (.clone m)) 414 | AIndex 415 | (mutable-matrix [m] (.clone m))) 416 | 417 | (extend-protocol mp/PMatrixMutableScaling 418 | INDArray 419 | (scale! [m a] 420 | (.scale m (double-coerce a))) 421 | (pre-scale! [m a] 422 | (.scale m (double-coerce a))) 423 | AVector 424 | (scale! [m a] 425 | (.scale m (double-coerce a))) 426 | (pre-scale! [m a] 427 | (.scale m (double-coerce a)))) 428 | 429 | (extend-protocol mp/PNumerical 430 | INDArray 431 | (numerical? [m] 432 | true) 433 | AIndex 434 | (numerical? [m] 435 | true)) 436 | 437 | (extend-protocol mp/PSameShape 438 | INDArray 439 | (same-shape? [a b] 440 | (if (instance? INDArray b) 441 | (.isSameShape a ^INDArray b) 442 | (clojure.core.matrix.utils/same-shape-object? 
(mp/get-shape a) (mp/get-shape b))))) 443 | 444 | (extend-protocol mp/PDoubleArrayOutput 445 | INDArray 446 | (to-double-array [m] (.toDoubleArray m)) 447 | (as-double-array [m] nil) 448 | Array 449 | (to-double-array [m] (.toDoubleArray m)) 450 | (as-double-array [m] (.getArray m)) 451 | AScalar 452 | (to-double-array [m] (let [arr (double-array 1)] (aset arr (int 0) (.get m)) arr)) 453 | (as-double-array [m] nil) 454 | Vector 455 | (to-double-array [m] (.toDoubleArray m)) 456 | (as-double-array [m] (.getArray m)) 457 | AVector 458 | (to-double-array [m] (.toDoubleArray m)) 459 | (as-double-array [m] nil) 460 | AMatrix 461 | (to-double-array [m] (.toDoubleArray (.asVector m))) 462 | (as-double-array [m] nil) 463 | Matrix 464 | (to-double-array [m] (.toDoubleArray (.asVector m))) 465 | (as-double-array [m] (.data m))) 466 | 467 | (extend-protocol mp/PObjectArrayOutput 468 | INDArray 469 | (to-object-array [m] 470 | (let [ec (.elementCount m) 471 | ^objects obs (object-array ec)] 472 | (.getElements m obs (int 0)) 473 | obs)) 474 | (as-object-array [m] 475 | nil)) 476 | 477 | (extend-protocol mp/PVectorisable 478 | INDArray 479 | (to-vector [m] 480 | (.toVector m)) 481 | AVector 482 | (to-vector [m] 483 | (.clone m))) 484 | 485 | (extend-protocol mp/PMutableFill 486 | INDArray 487 | (fill! 488 | [m value] 489 | (.fill m (double-coerce value)))) 490 | 491 | (extend-protocol mp/PDimensionInfo 492 | INDArray 493 | (dimensionality [m] 494 | (.dimensionality m)) 495 | (is-vector? [m] 496 | (== 1 (.dimensionality m))) 497 | (is-scalar? [m] 498 | false) 499 | (get-shape [m] 500 | (.getShape m)) 501 | (dimension-count [m x] 502 | (.getShape m (int x))) 503 | AScalar 504 | (dimensionality [m] 505 | 0) 506 | (is-vector? [m] 507 | false) 508 | (is-scalar? 
[m] 509 | false) ;; this isn't an immutable scalar value in the core.matrix sense 510 | (get-shape [m] 511 | []) 512 | (dimension-count [m x] 513 | (error "Scalar does not have dimension: " x)) 514 | AVector 515 | (dimensionality [m] 516 | 1) 517 | (is-vector? [m] 518 | true) 519 | (is-scalar? [m] 520 | false) 521 | (get-shape [m] 522 | [(long (.length m))]) 523 | (dimension-count [m x] 524 | (if (== 0 (long x)) 525 | (.length m) 526 | (error "Vector does not have dimension: " x))) 527 | AMatrix 528 | (dimensionality [m] 529 | 2) 530 | (is-vector? [m] 531 | false) 532 | (is-scalar? [m] 533 | false) 534 | (get-shape [m] 535 | [(long (.rowCount m)) (long (.columnCount m))]) 536 | (dimension-count [m x] 537 | (let [x (int x)] 538 | (cond 539 | (== x 0) (.rowCount m) 540 | (== x 1) (.columnCount m) 541 | :else (error "Matrix does not have dimension: " x)))) 542 | AIndex 543 | (dimensionality [m] 544 | 1) 545 | (is-vector? [m] 546 | true) 547 | (is-scalar? [m] 548 | false) 549 | (get-shape [m] 550 | [(long (.length m))]) 551 | (dimension-count [m x] 552 | (let [x (int x)] 553 | (cond 554 | (== x 0) (.length m) 555 | :else (error "Index does not have dimension: " x))))) 556 | 557 | (extend-protocol mp/PIndexedAccess 558 | INDArray 559 | (get-1d [m x] 560 | (.get m (int x))) 561 | (get-2d [m x y] 562 | (.get m (int x) (int y))) 563 | (get-nd [m indexes] 564 | (.get m (int-array indexes))) 565 | AScalar 566 | (get-1d [m x] 567 | (error "Can't access 1-dimensional index of a scalar")) 568 | (get-2d [m x y] 569 | (error "Can't access 2-dimensional index of a scalar")) 570 | (get-nd [m indexes] 571 | (if-let [ni (seq indexes)] 572 | (error "Can't access multi-dimensional index of a scalar") 573 | (.get m))) 574 | AVector 575 | (get-1d [m x] 576 | (.get m (int x))) 577 | (get-2d [m x y] 578 | (error "Can't access 2-dimensional index of a vector")) 579 | (get-nd [m indexes] 580 | (with-indexes [[x] indexes] 581 | (.get m (int x)))) 582 | AMatrix 583 | (get-1d [m x] 584 | (error 
"Can't access 1-dimensional index of a matrix")) 585 | (get-2d [m x y] 586 | (.get m (int x) (int y))) 587 | (get-nd [m indexes] 588 | (with-indexes [[x y] indexes] 589 | (.get m (int x) (int y)))) 590 | AIndex 591 | (get-1d [m x] 592 | (.get m (int x))) 593 | (get-2d [m x y] 594 | (error "Can't access 2-dimensional index of an Index")) 595 | (get-nd [m indexes] 596 | (with-indexes [[x] indexes] 597 | (.get m (int x))))) 598 | 599 | (extend-protocol mp/PZeroDimensionConstruction 600 | INDArray 601 | (new-scalar-array 602 | ([m] 603 | (Scalar/create 0.0)) 604 | ([m value] 605 | (Scalar/create (double-coerce value))))) 606 | 607 | (extend-protocol mp/PZeroDimensionAccess 608 | INDArray 609 | (get-0d [m] 610 | (.get m)) 611 | (set-0d! [m value] 612 | (.set m (double-coerce value))) 613 | AScalar 614 | (get-0d [m] 615 | (.get m)) 616 | (set-0d! [m value] 617 | (.set m (double-coerce value)))) 618 | 619 | (extend-protocol mp/PZeroDimensionSet 620 | INDArray 621 | (set-0d [m value] 622 | (if (== 0 (.dimensionality m)) 623 | (Scalar/create (double-coerce value)) 624 | (error "Can't do 0-d set on " (class m)))) 625 | AScalar 626 | (set-0d [m value] 627 | (Scalar/create (double-coerce value)))) 628 | 629 | (extend-protocol mp/PImmutableMatrixConstruction 630 | INDArray 631 | (immutable-matrix [m] 632 | (.immutable m))) 633 | 634 | ;; TODO semantics are tricky re: cloning or not? 
;(extend-protocol mp/PImmutableAssignment
;  INDArray
;    (assign
;      [m source]
;      (broadcast-coerce m source)))

(extend-protocol mp/PSpecialisedConstructors
  INDArray
    (identity-matrix [m dims]
      (Matrixx/createIdentityMatrix (int dims)))
    (diagonal-matrix [m diagonal-values]
      (Matrixx/createDiagonalMatrix (Vectorz/toVector diagonal-values))))

(extend-protocol mp/PPermutationMatrix
  INDArray
    (permutation-matrix [m permutation]
      (let [v (int-array-coerce permutation)]
        (mikera.matrixx.impl.PermutationMatrix/create v))))

(extend-protocol mp/PBroadcast
  INDArray
    (broadcast [m target-shape]
      (.broadcast m (int-array-coerce target-shape))))

(extend-protocol mp/PBroadcastLike
  INDArray
    (broadcast-like [m a]
      (vectorz-coerce m a)))

(extend-protocol mp/PBroadcastCoerce
  INDArray
    (broadcast-coerce [m a]
      (vectorz-coerce m a)))

(extend-protocol mp/PReshaping
  INDArray
    (reshape [m target-shape]
      (.reshape m (int-array target-shape))))

(extend-protocol mp/PZeroCount
  INDArray
    (zero-count [m] (- (.elementCount m) (.nonZeroCount m))))


(extend-protocol mp/PArrayMetrics
  INDArray
    (nonzero-count [m] (.nonZeroCount m)))

(extend-protocol mp/PMatrixTypes
  AMatrix
    (diagonal? [m] (.isDiagonal m))
    (upper-triangular? [m] (.isUpperTriangular m))
    (lower-triangular? [m] (.isLowerTriangular m))
    (positive-definite? [m] (mikera.matrixx.algo.Definite/isPositiveDefinite m))
    (positive-semidefinite? [m] (mikera.matrixx.algo.Definite/isPositiveSemiDefinite m))
    (orthogonal? [m eps] (.isOrthogonal m (double-coerce eps))))

;; Non-mutating setters: each returns a modified clone and leaves the argument untouched.
(extend-protocol mp/PIndexedSetting
  INDArray
    (set-1d [m row v]
      (with-clone [m] (.set m (int row) (double v))))
    (set-2d [m row column v]
      (with-clone [m] (.set m (int row) (int column) (double v))))
    (set-nd [m indexes v]
      (with-clone [m] (.set m (int-array indexes) (double v))))
    (is-mutable? [m] (.isFullyMutable m))

  AScalar
    (set-1d [m row v] (error "Can't do 1-dimensional set on a 0-d array!"))
    (set-2d [m row column v] (error "Can't do 2-dimensional set on a 0-d array!"))
    (set-nd [m indexes v]
      (if (== 0 (count indexes))
        (Scalar/create (double v))
        (error "Can't do " (count indexes) "-dimensional set on a 0-d array!")))
    (is-mutable? [m] (.isFullyMutable m))
  AVector
    (set-1d [m row v]
      (let [m (.clone m)] (.set m (int row) (double v)) m))
    (set-2d [m row column v] (error "Can't do 2-dimensional set on a 1D vector!"))
    (set-nd [m indexes v]
      (if (== 1 (count indexes))
        (with-clone [m] (.set m (int (first indexes)) (double v)))
        (error "Can't do " (count indexes) "-dimensional set on a 1D vector!")))
    (is-mutable? [m] (.isFullyMutable m))
  AMatrix
    (set-1d [m row v] (error "Can't do 1-dimensional set on a 2D matrix!"))
    (set-2d [m row column v]
      (with-clone [m] (.set m (int row) (int column) (double v))))
    (set-nd [m indexes v]
      (with-clone [m] (.set m (int-array indexes) (double v))))
    (is-mutable? [m] (.isFullyMutable m)))

;; In-place setters: each mutates the argument directly.
(extend-protocol mp/PIndexedSettingMutable
  INDArray
    (set-1d! [m row v]
      (.set m (int row) (double v))) ;; double is OK: v should only be a java.lang.Number instance
    (set-2d! [m row column v]
      (.set m (int row) (int column) (double v)))
    (set-nd! [m indexes v]
      (.set m (int-array indexes) (double v)))
  AScalar
    (set-1d! [m row v] (error "Can't do 1-dimensional set on a 0D array!"))
    (set-2d! [m row column v] (error "Can't do 2-dimensional set on a 0D array!"))
    (set-nd! [m indexes v]
      (if (== 0 (count indexes))
        (.set m (double v))
        (error "Can't do " (count indexes) "-dimensional set on a 0D array!")))
  AVector
    (set-1d! [m row v] (.set m (int row) (double v)))
    (set-2d! [m row column v] (error "Can't do 2-dimensional set on a 1D vector!"))
    (set-nd! [m indexes v]
      (if (== 1 (count indexes))
        (.set m (int (first indexes)) (double v))
        (error "Can't do " (count indexes) "-dimensional set on a 1D vector!")))
  AIndex
    (set-1d! [m row v] (.set m (int row) (int v)))
    (set-2d! [m row column v] (error "Can't do 2-dimensional set on a 1D index!"))
    (set-nd! [m indexes v]
      (if (== 1 (count indexes))
        ;; index elements are ints: coerce v exactly as set-1d! does above
        (.set m (int (first indexes)) (int v))
        (error "Can't do " (count indexes) "-dimensional set on a 1D index!")))
  AMatrix
    (set-1d! [m row v] (error "Can't do 1-dimensional set on a 2D matrix!"))
    (set-2d! [m row column v] (.set m (int row) (int column) (double v)))
    (set-nd! [m indexes v]
      (if (== 2 (count indexes))
        (.set m (int (first indexes)) (int (second indexes)) (double v))
        (error "Can't do " (count indexes) "-dimensional set on a 2D matrix!"))))

(extend-protocol mp/PSparseArray
  INDArray
    (is-sparse? [m]
      (.isSparse m)))

(extend-protocol mp/PNewSparseArray
  INDArray
    (new-sparse-array [m shape]
      (Arrayz/createSparseArray (int-array-coerce shape))))

;; Equality: the second argument may be any core.matrix array, so it is coerced first.
(extend-protocol mp/PMatrixEquality
  INDArray
    (matrix-equals [a b]
      (.equals a (vectorz-coerce b)))
  AMatrix
    (matrix-equals [a b]
      (.equals a (vectorz-coerce b)))
  AVector
    (matrix-equals [a b]
      (.equals a (vectorz-coerce b))))

(extend-protocol mp/PValueEquality
  INDArray
    (value-equals [a b]
      (let [b (vectorz-coerce b)]
        (and
          (.isSameShape a b)
          (.equals a b)))))

(extend-protocol mp/PMatrixEqualityEpsilon
  INDArray
    (matrix-equals-epsilon [a b eps]
      (.epsilonEquals a (vectorz-coerce b) (double eps))))

(extend-protocol mp/PMatrixSlices
  INDArray
    (get-row [m i]
      (if (== 2 (dimensionality m))
        (.slice m (int i))
        (error "Can't get row of array with dimensionality: " (dimensionality m))))
    (get-column [m i]
      (if (== 2 (dimensionality m))
        (.slice m (int 1) (int i))
        (error "Can't get column of array with dimensionality: " (dimensionality m))))
    (get-major-slice [m i]
      (.sliceValue m (int i)))
    (get-slice [m dimension i]
      (let [dimension (int dimension)]
        (.slice m dimension (int i))))
  AVector
    (get-row [m i]
      (error "Can't access row of a 1D vector!"))
    (get-column [m i]
      (error "Can't access column of a 1D vector!"))
    (get-major-slice [m i]
      (.sliceValue m (int i)))
    (get-slice [m dimension i]
      (if (== 0 (long dimension))
        (.sliceValue m (int i))
        (error "Can't get slice from vector with dimension: " dimension)))
  AMatrix
    (get-row [m i]
      (.getRow m (int i)))
    (get-column [m i]
      (.getColumn m (int i)))
    (get-major-slice [m i]
      (.slice m (int i)))
    (get-slice [m dimension i]
(.slice m (int dimension) (int i)))) 833 | 834 | (extend-protocol mp/PRotate 835 | INDArray 836 | (rotate [m dim places] 837 | (let [dim (int dim)] 838 | (if (<= 0 dim (dec (.dimensionality m))) 839 | (.rotateView m dim (int places)) 840 | m)))) 841 | 842 | (extend-protocol mp/PShift 843 | AVector 844 | (shift [m dim shift] 845 | (if (== (long dim) 0) 846 | (.shiftCopy m (int shift)) 847 | (error "Can't shift vector along dimension: " dim))) 848 | (shift-all [m shifts] 849 | (let [n (count shifts)] 850 | (cond 851 | (== n 0) m 852 | (== n 1) (.shiftCopy m (int (first shifts))) 853 | :else (error "Can't shift vector along more than one dimension"))))) 854 | 855 | (extend-protocol mp/POrder 856 | INDArray 857 | (order 858 | ([m indices] 859 | (.reorder m (int-array-coerce indices))) 860 | ([m dimension indices] 861 | (.reorder m (int dimension) (int-array-coerce indices))))) 862 | 863 | (extend-protocol mp/PMatrixRows 864 | AMatrix 865 | (get-rows [m] 866 | (.getRows m))) 867 | 868 | (extend-protocol mp/PMatrixColumns 869 | AMatrix 870 | (get-columns [m] 871 | (.getColumns m))) 872 | 873 | (extend-protocol mp/PRowSetting 874 | AMatrix 875 | ;; note: use avector-coerce on the argument to ensure correct broadcasting 876 | (set-row [m i row] 877 | (with-clone [m] 878 | (.setRow m (int i) (avector-coerce (.getRow m 0) row)))) 879 | (set-row! [m i row] 880 | (.setRow m (int i) (avector-coerce (.getRow m 0) row)))) 881 | 882 | (extend-protocol mp/PColumnSetting 883 | AMatrix 884 | ;; note: use avector-coerce on the argument to ensure correct broadcasting 885 | (set-column [m i v] 886 | (with-clone [m] 887 | (.setColumn m (int i) (avector-coerce (.getColumn m 0) v)))) 888 | (set-column! 
[m i v] 889 | (.setColumn m (int i) (avector-coerce (.getColumn m 0) v)))) 890 | 891 | (extend-protocol mp/PSliceView 892 | INDArray 893 | (get-major-slice-view [m i] 894 | (.slice m (int i)))) 895 | 896 | (extend-protocol mp/PSliceView2 897 | INDArray 898 | (get-slice-view [m dim i] 899 | (.slice m (int dim) (int i)))) 900 | 901 | (extend-protocol mp/PSliceSeq 902 | INDArray 903 | (get-major-slice-seq [m] 904 | (seq (.getSlices m))) 905 | AVector 906 | (get-major-slice-seq [m] 907 | ;; we want Clojure to produce an efficient ArraySeq, so we convert to double array first 908 | (seq (or (.asDoubleArray m) (.toDoubleArray m)))) 909 | Index 910 | (get-major-slice-seq [m] 911 | (seq (.getData m)))) 912 | 913 | (extend-protocol mp/PSliceSeq2 914 | INDArray 915 | (get-slice-seq [m dimension] 916 | (let [ldimension (long dimension)] 917 | (cond 918 | (== ldimension 0) (mp/get-major-slice-seq m) 919 | (< ldimension 0) (error "Can't get slices of a negative dimension: " dimension) 920 | :else (map #(mp/get-slice m dimension %) (range (mp/dimension-count m dimension)))))) 921 | AVector 922 | (get-slice-seq [m dimension] 923 | (if (== 0 (long dimension)) 924 | m 925 | (error "Can't access dimension " dimension " of a vector")))) 926 | 927 | (extend-protocol mp/PSliceViewSeq 928 | INDArray 929 | (get-major-slice-view-seq [m] 930 | (seq (.getSliceViews m)))) 931 | 932 | (extend-protocol mp/PMatrixSubComponents 933 | AMatrix 934 | (main-diagonal [m] 935 | (.getLeadingDiagonal m))) 936 | 937 | (extend-protocol mp/PAssignment 938 | AScalar 939 | (assign! 940 | [m source] (.set m (double-coerce source))) 941 | (assign-array! 942 | ([m arr] (.set m (double (nth arr 0)))) 943 | ([m arr start length] (.set m (double (nth arr 0))))) 944 | AVector 945 | (assign! [m source] 946 | (cond 947 | (number? source) 948 | (.fill m (double source)) 949 | (instance? 
INDArray source) 950 | (.set m ^INDArray source) 951 | (== 0 (dimensionality source)) 952 | (.fill m (double-coerce source)) 953 | :else 954 | (.set m (vectorz-coerce source)))) 955 | (assign-array! 956 | ([m arr] (dotimes [i (count arr)] (.set m (int i) (double (nth arr i))))) 957 | ([m arr start length] 958 | (let [length (long length) start (long start)] 959 | (dotimes [i length] (.set m (int i) (double (nth arr (+ i start)))))))) 960 | 961 | INDArray 962 | (assign! [m source] 963 | (.set m (vectorz-coerce m source))) 964 | (assign-array! 965 | ([m arr] 966 | (let [alen (long (count arr))] 967 | (if (mp/is-vector? m) 968 | (dotimes [i alen] 969 | (mp/set-1d! m i (nth arr i))) 970 | (mp/assign-array! m arr 0 alen)))) 971 | ([m arr start length] 972 | (let [length (long length) 973 | start (long start)] 974 | (if (mp/is-vector? m) 975 | (dotimes [i length] 976 | (mp/set-1d! m i (nth arr (+ start i)))) 977 | (let [ss (seq (mp/get-major-slice-seq m)) 978 | skip (long (if ss (mp/element-count (first (mp/get-major-slice-seq m))) 0))] 979 | (doseq-indexed [s ss i] 980 | (mp/assign-array! 
s arr (+ start (* skip i)) skip)))))))) 981 | 982 | (extend-protocol mp/PSubVector 983 | INDArray 984 | (subvector [m start length] 985 | (let [dims (.dimensionality m)] 986 | (if (== 1 dims) 987 | (.subVector (.asVector m) (int start) (int length)) 988 | (error "Can't take subvector of " dims "-D array")))) 989 | 990 | AVector 991 | (subvector [m start length] 992 | (.subVector m (int start) (int length)))) 993 | 994 | (extend-protocol mp/PSubMatrix 995 | AMatrix 996 | (submatrix [m index-ranges] 997 | (let [[rr cr] index-ranges 998 | s1 (int (if rr (first rr) 0)) 999 | s2 (int (if cr (first cr) 0)) 1000 | l1 (int (if rr (second rr) (.rowCount m))) 1001 | l2 (int (if cr (second cr) (.columnCount m)))] 1002 | (.subMatrix m s1 l1 s2 l2))) 1003 | AVector 1004 | (submatrix [m index-ranges] 1005 | (let [[rr] index-ranges 1006 | s1 (int (if rr (first rr) 0)) 1007 | l1 (int (if rr (second rr) (.length m)))] 1008 | (.subVector m s1 l1)))) 1009 | 1010 | ;; protocols for indexed access 1011 | 1012 | (extend-protocol mp/PSelect 1013 | ;; note that select needs to create a view 1014 | INDArray 1015 | (select [a args] 1016 | (if (empty? args) 1017 | a 1018 | (let [args (mapv #(int-array-coerce %) args) 1019 | dims (.dimensionality a) 1020 | next-args (next args) 1021 | ^ints ixs (first args) 1022 | n (alength ixs) 1023 | oa (object-array n)] 1024 | (cond 1025 | (> dims 1) 1026 | (do 1027 | (dotimes [i n] 1028 | (aset oa i (mp/select (.slice a (aget ixs i)) next-args))) 1029 | (SliceArray/create ^List (vec oa))) 1030 | :else 1031 | (.select (.asVector a) ixs))))) 1032 | ;; TODO AMatrix override 1033 | AVector 1034 | (select [a args] 1035 | (if (empty? 
args) 1036 | a 1037 | (let [ixs (int-array-coerce (first args))] 1038 | (.select a ixs))))) 1039 | 1040 | (extend-protocol mp/PSetSelection 1041 | AVector 1042 | (set-selection [a args values] 1043 | (let [ixs (int-array-coerce (first args)) 1044 | sv (.select a ixs) 1045 | vs (avector-coerce sv values)] 1046 | (.set sv vs)))) 1047 | 1048 | (extend-protocol mp/PIndicesAccess 1049 | INDArray 1050 | (get-indices [a indices] 1051 | (let [c (int (count indices)) 1052 | r (Vectorz/newVector c)] 1053 | (doseq-indexed [ix indices i] 1054 | (.unsafeSet r (int i) (.get a (int-array-coerce ix)))) 1055 | r))) 1056 | 1057 | (extend-protocol mp/PIndicesSetting 1058 | INDArray 1059 | (set-indices [a indices values] 1060 | (let [result (.clone a)] 1061 | (mp/set-indices! result indices values) 1062 | result)) 1063 | (set-indices! [a indices values] 1064 | (let [c (int (count indices)) 1065 | vs (avector-coerce values)] 1066 | (doseq-indexed [ix indices i] 1067 | (.set a (int-array-coerce ix) (.get vs (int i)))) 1068 | a))) 1069 | 1070 | (extend-protocol mp/PNonZeroIndices 1071 | AVector 1072 | (non-zero-indices 1073 | [m] 1074 | (.nonZeroIndices m)) 1075 | AMatrix 1076 | (non-zero-indices 1077 | [m] 1078 | (vec (for [i (range (mp/dimension-count m 0))] 1079 | (mp/non-zero-indices (mp/get-major-slice m i)))))) 1080 | 1081 | ;; protocols for elementwise ops 1082 | 1083 | (extend-protocol mp/PSummable 1084 | INDArray 1085 | (element-sum [m] 1086 | (.elementSum m)) 1087 | AVector 1088 | (element-sum [m] 1089 | (.elementSum m)) 1090 | AMatrix 1091 | (element-sum [m] 1092 | (.elementSum m)) 1093 | AScalar 1094 | (element-sum [m] 1095 | (.get m))) 1096 | 1097 | (extend-protocol mp/PMatrixAdd 1098 | mikera.vectorz.AScalar 1099 | (matrix-add [m a] 1100 | (with-broadcast-coerce [m a] (.addCopy m a))) 1101 | (matrix-sub [m a] 1102 | (with-broadcast-coerce [m a] (.subCopy m a))) 1103 | mikera.vectorz.AVector 1104 | (matrix-add [m a] 1105 | (with-broadcast-coerce [m a] (.addCopy m a))) 1106 | 
(matrix-sub [m a] 1107 | (with-broadcast-coerce [m a] (.subCopy m a))) 1108 | SparseIndexedVector 1109 | (matrix-add [m a] 1110 | (with-broadcast-clone [m a] (.add m a))) 1111 | (matrix-sub [m a] 1112 | (with-broadcast-clone [m a] (.sub m a))) 1113 | mikera.matrixx.AMatrix 1114 | (matrix-add [m a] 1115 | (with-broadcast-coerce [m a] (.addCopy m a))) 1116 | (matrix-sub [m a] 1117 | (with-broadcast-coerce [m a] (.subCopy m a))) 1118 | INDArray 1119 | (matrix-add [m a] 1120 | (with-broadcast-coerce [m a] (.addCopy m a))) 1121 | (matrix-sub [m a] 1122 | (with-broadcast-coerce [m a] (.subCopy m a)))) 1123 | 1124 | (extend-protocol mp/PMatrixAddMutable 1125 | INDArray 1126 | (matrix-add! [m a] 1127 | (.add m (vectorz-coerce a))) 1128 | (matrix-sub! [m a] 1129 | (.sub m (vectorz-coerce a))) 1130 | AScalar 1131 | (matrix-add! [m a] 1132 | (.add m (double-coerce a))) 1133 | (matrix-sub! [m a] 1134 | (.sub m (double-coerce a))) 1135 | AVector 1136 | (matrix-add! [m a] 1137 | (.add m (avector-coerce m a))) 1138 | (matrix-sub! [m a] 1139 | (.sub m (avector-coerce m a))) 1140 | AMatrix 1141 | (matrix-add! [m a] 1142 | (.add m (amatrix-coerce m a))) 1143 | (matrix-sub! [m a] 1144 | (.sub m (amatrix-coerce m a)))) 1145 | 1146 | (extend-protocol mp/PScaleAdd 1147 | INDArray 1148 | (scale-add! 
[m a m2 b c] 1149 | (.scaleAdd m (double-coerce a) (vectorz-coerce m2) (double-coerce b) (double-coerce c)) 1150 | m)) 1151 | 1152 | (extend-protocol mp/PVectorOps 1153 | INDArray 1154 | (vector-dot [a b] 1155 | (.dotProduct (avector-coerce a) (avector-coerce a b))) 1156 | (length [a] 1157 | (.magnitude (avector-coerce a))) 1158 | (length-squared [a] 1159 | (.magnitudeSquared (avector-coerce a))) 1160 | (normalise [a] 1161 | (with-clone [a] (.toNormal (avector-coerce a)))) 1162 | 1163 | AVector 1164 | (vector-dot [a b] 1165 | (.dotProduct a (avector-coerce a b))) 1166 | (length [a] 1167 | (.magnitude a)) 1168 | (length-squared [a] 1169 | (.magnitudeSquared a)) 1170 | (normalise [a] 1171 | (.toNormal a))) 1172 | 1173 | (extend-protocol mp/PNegation 1174 | AScalar (negate [m] (with-clone [m] (.negate m))) 1175 | AVector (negate [m] (with-clone [m] (.negate m))) 1176 | AMatrix (negate [m] (with-clone [m] (.negate m))) 1177 | INDArray (negate [m] (with-clone [m] (.negate m)))) 1178 | 1179 | (extend-protocol mp/PTranspose 1180 | INDArray (transpose [m] (.getTranspose m)) 1181 | AScalar (transpose [m] m) 1182 | AVector (transpose [m] m) 1183 | AMatrix (transpose [m] (.getTranspose m))) 1184 | 1185 | (extend-protocol mp/PTransposeInPlace 1186 | AMatrix (transpose! [m] (.transposeInPlace m)) 1187 | AVector (transpose! [m] m) 1188 | AScalar (transpose! [m] m)) 1189 | 1190 | (extend-protocol mp/PVectorCross 1191 | INDArray 1192 | (cross-product [a b] 1193 | (let [v (Vector3. (avector-coerce a))] 1194 | (.crossProduct v (avector-coerce a b)) 1195 | v)) 1196 | (cross-product! [a b] 1197 | (assign! a (mp/cross-product a b))) 1198 | AVector 1199 | (cross-product [a b] 1200 | (let [v (Vector3. a)] 1201 | (.crossProduct v (avector-coerce a b)) 1202 | v)) 1203 | (cross-product! 
[a b] 1204 | (.crossProduct a (avector-coerce a b)))) 1205 | 1206 | (extend-protocol mp/PMatrixCloning 1207 | INDArray (clone [m] (.clone m)) 1208 | AScalar (clone [m] (.clone m)) 1209 | AVector (clone [m] (.clone m)) 1210 | AMatrix (clone [m] (.clone m)) 1211 | AIndex (clone [m] (.clone m))) 1212 | 1213 | (extend-protocol mp/PCoercion 1214 | INDArray 1215 | (coerce-param [m param] 1216 | (if (number? param) 1217 | param 1218 | (vectorz-coerce param))) 1219 | AIndex 1220 | (coerce-param [m param] 1221 | (if (== 1 (dimensionality param)) 1222 | (Index/of (int-array (mp/element-seq param))) 1223 | (error "Cannot coerce to Index with shape: " (vec (mp/get-shape param)))))) 1224 | 1225 | (extend-protocol mp/PRowColMatrix 1226 | INDArray 1227 | (column-matrix [m data] 1228 | (mikera.matrixx.impl.ColumnMatrix. (avector-coerce data))) 1229 | (row-matrix [m data] 1230 | (mikera.matrixx.impl.RowMatrix. (avector-coerce data)))) 1231 | 1232 | (extend-protocol mp/PValidateShape 1233 | INDArray 1234 | (validate-shape 1235 | ([m] 1236 | (.validate m) 1237 | (vec (.getShape m))) 1238 | ([m expected-shape] 1239 | (.validate m) 1240 | (if (nil? 
expected-shape) 1241 | (error "Shape validation failed, expected a scalar but was a Vectorz array") 1242 | (let [shape (vec (.getShape m))] 1243 | (if (= shape (vec expected-shape)) 1244 | shape 1245 | (error "Shape validation failed, expected " expected-shape "but was " shape))))))) 1246 | 1247 | (extend-protocol mp/PConversion 1248 | AScalar 1249 | (convert-to-nested-vectors [m] 1250 | (.get m)) 1251 | AVector 1252 | (convert-to-nested-vectors [m] 1253 | (into [] m)) 1254 | AMatrix 1255 | (convert-to-nested-vectors [m] 1256 | (mapv mp/convert-to-nested-vectors (.getSlices m))) 1257 | INDArray 1258 | (convert-to-nested-vectors [m] 1259 | (if (== 0 (.dimensionality m)) 1260 | (mp/get-0d m) 1261 | (mapv mp/convert-to-nested-vectors (.getSlices m)))) 1262 | AIndex 1263 | (convert-to-nested-vectors [m] 1264 | (vec m)) 1265 | Index 1266 | (convert-to-nested-vectors [m] 1267 | (vec (.getData m)))) 1268 | 1269 | (extend-protocol mp/PMatrixDivide 1270 | INDArray 1271 | (element-divide 1272 | ([m] 1273 | (with-clone [m] (.reciprocal m))) 1274 | ([m a] 1275 | (with-broadcast-clone [m a] (.divide m a)))) 1276 | AVector 1277 | (element-divide 1278 | ([m] 1279 | (with-clone [m] (.reciprocal m))) 1280 | ([m a] 1281 | (with-clone [m] (.divide m (vectorz-coerce a))))) 1282 | Index ;; we need this special override, since division doesn't work with integer indexes! 1283 | (element-divide 1284 | ([m] 1285 | (let [v (Vectorz/create m)] (.reciprocal v) v)) 1286 | ([m a] 1287 | (let [v (Vectorz/create m)] (.divide v (vectorz-coerce a)) v)))) 1288 | 1289 | (extend-protocol mp/PMatrixDivideMutable 1290 | INDArray 1291 | (element-divide! 1292 | ([m] 1293 | (.reciprocal m)) 1294 | ([m a] 1295 | (.divide m (vectorz-coerce a)))) 1296 | AVector 1297 | (element-divide! 
1298 | ([m] 1299 | (.reciprocal m)) 1300 | ([m a] 1301 | (.divide m (vectorz-coerce a))))) 1302 | 1303 | (extend-protocol mp/PMatrixMultiply 1304 | AScalar 1305 | (matrix-multiply [m a] 1306 | (with-vectorz-clone [a] (.multiply a (.get m)))) 1307 | (element-multiply [m a] 1308 | (with-vectorz-clone [a] (.multiply a (.get m)))) 1309 | AVector 1310 | (matrix-multiply [m a] 1311 | (.innerProduct m (vectorz-coerce a))) 1312 | (element-multiply [m a] 1313 | (with-broadcast-coerce [m a] (.multiplyCopy m a))) 1314 | AMatrix 1315 | (matrix-multiply [m a] 1316 | (cond 1317 | (instance? AVector a) 1318 | (let [^AVector r (Vectorz/newVector (.rowCount m))] 1319 | (.transform m ^AVector a r) 1320 | r) 1321 | (instance? AMatrix a) (.innerProduct m ^AMatrix a) 1322 | (number? a) (.multiplyCopy m (double a)) 1323 | :else (.innerProduct m (vectorz-coerce a)))) 1324 | (element-multiply [m a] 1325 | (with-broadcast-clone [m a] (.multiply m a))) 1326 | INDArray 1327 | (matrix-multiply [m a] 1328 | (if-let [^INDArray a (vectorz-coerce a)] 1329 | (.innerProduct m a) 1330 | (error "Can't convert to vectorz representation: " a))) 1331 | (element-multiply [m a] 1332 | (with-broadcast-clone [m a] (.multiply m a)))) 1333 | 1334 | (extend-protocol mp/PMatrixMultiplyMutable 1335 | AVector 1336 | (matrix-multiply! [m a] 1337 | (mp/assign! m (mp/inner-product m (vectorz-coerce a)))) 1338 | (element-multiply! [m a] 1339 | (.multiply m (vectorz-coerce a))) 1340 | INDArray 1341 | (matrix-multiply! [m a] 1342 | (mp/assign! m (mp/inner-product m (vectorz-coerce a)))) 1343 | (element-multiply! [m a] 1344 | (.multiply m (vectorz-coerce a)))) 1345 | 1346 | (extend-protocol mp/PMatrixDivideMutable 1347 | INDArray 1348 | (element-divide! 
1349 | ([m] (.reciprocal m)) 1350 | ([m a] (.divide m (vectorz-coerce a))))) 1351 | 1352 | (extend-protocol mp/PMatrixProducts 1353 | INDArray 1354 | (inner-product [m a] (.innerProduct m (vectorz-coerce a))) 1355 | (outer-product [m a] (.outerProduct m (vectorz-coerce a)))) 1356 | 1357 | (defn vectorz-scale 1358 | "Scales a vectorz array, return a new scaled array" 1359 | ([^INDArray m ^double a] 1360 | (with-clone [m] (.scale m (double a))))) 1361 | 1362 | (extend-protocol mp/PAddProduct 1363 | AVector 1364 | (add-product [m a b] 1365 | (with-clone [m] 1366 | (.addProduct m (avector-coerce m a) (avector-coerce m b))))) 1367 | 1368 | (extend-protocol mp/PAddProductMutable 1369 | INDArray 1370 | (add-product! [m a b] 1371 | (.add m (vectorz-coerce (mp/element-multiply a b)))) 1372 | AVector 1373 | (add-product! [m a b] 1374 | (.addProduct m (avector-coerce m a) (avector-coerce m b)))) 1375 | 1376 | (extend-protocol mp/PAddInnerProductMutable 1377 | INDArray 1378 | (add-inner-product! 1379 | ([m a b] 1380 | (let [a (vectorz-coerce a) 1381 | b (vectorz-coerce b)] 1382 | (.addInnerProduct m a b))) 1383 | ([m a b factor] 1384 | (let [factor (double-coerce factor)] 1385 | (cond 1386 | (== 0.0 factor) m 1387 | (== 1.0 factor) (mp/add-inner-product! m a b) 1388 | :else (mp/add-inner-product! m a (mp/scale b factor))))))) 1389 | 1390 | (extend-protocol mp/PAddOuterProductMutable 1391 | INDArray 1392 | (add-outer-product! 1393 | ([m a b] 1394 | (let [a (vectorz-coerce a) 1395 | b (vectorz-coerce b)] 1396 | (.addOuterProduct m a b))) 1397 | ([m a b factor] 1398 | (let [factor (double-coerce factor)] 1399 | (cond 1400 | (== 0.0 factor) m 1401 | (== 1.0 factor) (mp/add-outer-product! m a b) 1402 | :else (mp/add-outer-product! m a (mp/scale b factor))))))) 1403 | 1404 | (extend-protocol mp/PSetInnerProductMutable 1405 | INDArray 1406 | (set-inner-product! 
1407 | ([m a b] 1408 | (let [a (vectorz-coerce a) 1409 | b (vectorz-coerce b)] 1410 | (.setInnerProduct m a b))) 1411 | ([m a b factor] 1412 | (let [factor (double-coerce factor)] 1413 | (cond 1414 | (== 0.0 factor) m 1415 | (== 1.0 factor) (mp/set-inner-product! m a b) 1416 | :else (mp/set-inner-product! m a (mp/scale b factor))))))) 1417 | 1418 | (extend-protocol mp/PAddScaled 1419 | INDArray 1420 | (add-scaled [m a factor] 1421 | (with-clone [m] 1422 | (.scaleAdd m 1.0 (vectorz-coerce a) (double-coerce factor) 0.0))) 1423 | AVector 1424 | (add-scaled [m a factor] 1425 | (with-clone [m] 1426 | (.addMultiple m (avector-coerce m a) (double-coerce factor))))) 1427 | 1428 | (extend-protocol mp/PAddScaledMutable 1429 | INDArray 1430 | (add-scaled! [m a factor] 1431 | (.addMultiple m (vectorz-coerce a) (double-coerce factor))) 1432 | AMatrix 1433 | (add-scaled! [m a factor] 1434 | (.addMultiple m (vectorz-coerce a) (double-coerce factor))) 1435 | AVector 1436 | (add-scaled! [m a factor] 1437 | (.addMultiple m (avector-coerce m a) (double-coerce factor)))) 1438 | 1439 | (extend-protocol mp/PAddScaledProduct 1440 | INDArray 1441 | (add-scaled-product [m a b factor] 1442 | (with-clone [m] 1443 | (.addMultiple m (vectorz-coerce (mul a b)) (double factor)))) 1444 | AVector 1445 | (add-scaled-product [m a b factor] 1446 | (with-clone [m] 1447 | (.addProduct m (avector-coerce m a) (avector-coerce m b) (double factor))))) 1448 | 1449 | (extend-protocol mp/PAddScaledProductMutable 1450 | INDArray 1451 | (add-scaled-product! [m a b factor] 1452 | (.addMultiple m (vectorz-coerce (mul a b)) (double factor))) 1453 | AVector 1454 | (add-scaled-product! [m a b factor] 1455 | (.addProduct m (avector-coerce m a) (avector-coerce m b) (double factor)))) 1456 | 1457 | (extend-protocol mp/PLerp 1458 | INDArray 1459 | (lerp [a b factor] 1460 | (let [factor (double-coerce factor)] 1461 | (with-clone [a] 1462 | (.scaleAdd a (- 1.0 factor) (vectorz-coerce b) factor 0.0)))) 1463 | (lerp! 
[a b factor] 1464 | (let [factor (double-coerce factor)] 1465 | (.scaleAdd a (- 1.0 factor) (vectorz-coerce b) factor 0.0))) 1466 | AVector 1467 | (lerp [a b factor] 1468 | (let [factor (double-coerce factor)] 1469 | (with-clone [a] 1470 | (.scaleAdd a (- 1.0 factor) (avector-coerce a b) factor 0.0)))) 1471 | (lerp! [a b factor] 1472 | (let [factor (double-coerce factor)] 1473 | (.scaleAdd a (- 1.0 factor) (avector-coerce a b) factor 0.0)))) 1474 | 1475 | (extend-protocol mp/PMatrixScaling 1476 | AScalar 1477 | (scale [m a] (vectorz-scale m (double-coerce a))) 1478 | (pre-scale [m a] (vectorz-scale m (double-coerce a))) 1479 | AVector 1480 | (scale [m a] (vectorz-scale m (double-coerce a))) 1481 | (pre-scale [m a] (vectorz-scale m (double-coerce a))) 1482 | AMatrix 1483 | (scale [m a] (vectorz-scale m (double-coerce a))) 1484 | (pre-scale [m a] (vectorz-scale m (double-coerce a))) 1485 | INDArray 1486 | (scale [m a] (vectorz-scale m (double-coerce a))) 1487 | (pre-scale [m a] (vectorz-scale m (double-coerce a)))) 1488 | 1489 | (extend-protocol mp/PVectorTransform 1490 | ATransform 1491 | (vector-transform [m v] 1492 | (if (instance? AVector v) 1493 | (.transform m ^AVector v) 1494 | (.transform m (avector-coerce v)))) 1495 | (vector-transform! [m v] 1496 | (if (instance? AVector v) 1497 | (.transformInPlace m ^AVector v) 1498 | (assign! v (transform m v))))) 1499 | 1500 | (extend-protocol mp/PMutableVectorOps 1501 | INDArray 1502 | (normalise! [a] 1503 | (if (== 1 (.dimensionality a)) 1504 | (.normalise (.asVector a)) 1505 | (error "Can't normalise something that isn't a 1D vector!"))) 1506 | AVector 1507 | (normalise! 
[a] 1508 | (.normalise a))) 1509 | 1510 | (extend-protocol mp/PMatrixOps 1511 | INDArray 1512 | (trace [m] 1513 | (.trace (amatrix-coerce m))) 1514 | (determinant [m] 1515 | (.determinant (amatrix-coerce m))) 1516 | (inverse [m] 1517 | (.inverse (amatrix-coerce m))) 1518 | AMatrix 1519 | (trace [m] 1520 | (.trace m)) 1521 | (determinant [m] 1522 | (.determinant m)) 1523 | (inverse [m] 1524 | (.inverse m))) 1525 | 1526 | (extend-protocol mp/PMatrixPredicates 1527 | INDArray 1528 | (identity-matrix? 1529 | [m] 1530 | (and 1531 | (== 2 (.dimensionality m)) 1532 | (.isIdentity (Matrixx/toMatrix m)))) 1533 | (zero-matrix? 1534 | [m] 1535 | (.isZero m)) 1536 | (symmetric? 1537 | [m] 1538 | (case (.dimensionality m) ; should be 1, 3, 4, ...; never 2 1539 | 1 true 1540 | (equals m (transpose m)))) 1541 | AMatrix 1542 | (identity-matrix? 1543 | [m] 1544 | (.isIdentity m)) 1545 | (zero-matrix? 1546 | [m] 1547 | (.isZero m)) 1548 | (symmetric? 1549 | [m] 1550 | (.isSymmetric m))) 1551 | 1552 | (extend-protocol mp/PSquare 1553 | INDArray 1554 | (square [m] 1555 | (with-clone [m] (.square m))) 1556 | AVector 1557 | (square [m] 1558 | (with-clone [m] (.square m)))) 1559 | 1560 | (extend-protocol mp/PExponent 1561 | INDArray 1562 | (element-pow [m exponent] 1563 | (if (number? m) 1564 | (with-clone [m] (.pow m (double-coerce exponent))) 1565 | (mp/element-map m (fn ^double [^double x ^double y] (Math/pow x y)) (vectorz-coerce m exponent))))) 1566 | 1567 | (extend-protocol mp/PLogistic 1568 | INDArray 1569 | (logistic [m] 1570 | (.applyOpCopy m Ops/LOGISTIC))) 1571 | 1572 | (extend-protocol mp/PLogisticMutable 1573 | INDArray 1574 | (logistic! [m] 1575 | (.applyOp m Ops/LOGISTIC))) 1576 | 1577 | (extend-protocol mp/PSoftplus 1578 | INDArray 1579 | (softplus [m] 1580 | (.applyOpCopy m Ops/SOFTPLUS))) 1581 | 1582 | (extend-protocol mp/PSoftplusMutable 1583 | INDArray 1584 | (softplus! 
[m] 1585 | (.applyOp m Ops/SOFTPLUS))) 1586 | 1587 | (extend-protocol mp/PSoftmax 1588 | AVector 1589 | (softmax [m] 1590 | (.softmaxCopy m))) 1591 | 1592 | (extend-protocol mp/PSoftmaxMutable 1593 | AVector 1594 | (softmax! [m] 1595 | (.softmax m))) 1596 | 1597 | (extend-protocol mp/PReLU 1598 | INDArray 1599 | (relu [m] 1600 | (.applyOpCopy m Ops/RECTIFIER))) 1601 | 1602 | (extend-protocol mp/PReLUMutable 1603 | INDArray 1604 | (relu! [m] 1605 | (.applyOp m Ops/RECTIFIER))) 1606 | 1607 | 1608 | (extend-protocol mp/PElementCount 1609 | INDArray 1610 | (element-count [m] 1611 | (.elementCount m)) 1612 | AMatrix 1613 | (element-count [m] 1614 | (.elementCount m)) 1615 | AVector 1616 | (element-count [m] 1617 | (.elementCount m)) 1618 | AScalar 1619 | (element-count [m] 1620 | 1) 1621 | AIndex 1622 | (element-count [m] 1623 | (.length m))) 1624 | 1625 | (extend-protocol mp/PSparse 1626 | INDArray 1627 | (sparse-coerce [m data] 1628 | (cond 1629 | (== 0 (dimensionality data)) (Scalar. (double-coerce data)) 1630 | (vectorz? 
data) (.sparse ^INDArray data) 1631 | :else (let [ss (map (fn [s] (.sparse (vectorz-coerce s))) (mp/get-major-slice-seq data))] 1632 | (.sparse (Arrayz/create (object-array ss)))))) 1633 | (sparse [m] 1634 | (.sparse m))) 1635 | 1636 | (extend-protocol mp/PSliceJoin 1637 | INDArray 1638 | (join [m a] 1639 | (.join m (vectorz-coerce a) (int 0))) 1640 | AVector 1641 | (join [m a] 1642 | (.join m (avector-coerce a)))) 1643 | 1644 | (extend-protocol mp/PSliceJoinAlong 1645 | INDArray 1646 | (join-along [m a dim] 1647 | (.join m (vectorz-coerce a) (int dim)))) 1648 | 1649 | (extend-protocol mp/PVectorView 1650 | INDArray 1651 | (as-vector [m] 1652 | (.asVector m)) 1653 | AVector 1654 | (as-vector [m] 1655 | m)) 1656 | 1657 | (extend-protocol mp/PVectorDistance 1658 | AVector 1659 | (distance [a b] 1660 | (.distance a (avector-coerce b)))) 1661 | 1662 | (extend-protocol mp/PElementMinMax 1663 | INDArray 1664 | (element-min [m] 1665 | (.elementMin m)) 1666 | (element-max [m] 1667 | (.elementMax m))) 1668 | 1669 | (extend-protocol mp/PComputeMatrix 1670 | INDArray 1671 | (compute-matrix [m shape f] 1672 | (let [dims (long (count shape))] 1673 | (cond 1674 | (== 0 dims) (double (f)) 1675 | (== 1 dims) 1676 | (let [n (int (first shape)) 1677 | v (Vector/createLength n)] 1678 | (dotimes [i n] (.set v (int i) (double (f i)))) 1679 | v) 1680 | (== 2 dims) 1681 | (let [n (int (first shape)) 1682 | m (int (second shape)) 1683 | v (Matrix/create n m)] 1684 | (dotimes [i n] 1685 | (dotimes [j m] 1686 | (.set v (int i) (int j) (double (f i j))))) 1687 | v) 1688 | :else 1689 | (Arrayz/create 1690 | (let [ns (next shape)] 1691 | (mapv #(mp/compute-matrix m ns (fn [& ixs] (apply f % ixs))) (range (first shape))))))))) 1692 | 1693 | (extend-protocol mp/PFunctionalOperations 1694 | INDArray 1695 | (element-seq 1696 | [m] 1697 | (let [ec (.elementCount m) 1698 | ^doubles data (or (.asDoubleArray m) (.toDoubleArray m))] 1699 | (seq data))) 1700 | (element-map 1701 | ([m f] 1702 | 
(.applyOpCopy m (FnOp/wrap f))) 1703 | ([m f a] 1704 | (with-clone [m] 1705 | (.applyOp m ^Op2 (FnOp2/wrap f) ^INDArray (vectorz-coerce a)))) 1706 | ([m f a more] 1707 | (mp/coerce-param m (mp/element-map (mp/convert-to-nested-vectors m) f a more)))) 1708 | (element-map! 1709 | ([m f] 1710 | (.applyOp m ^Op (FnOp/wrap f))) 1711 | ([m f a] 1712 | (.applyOp m ^Op2 (FnOp2/wrap f) ^INDArray (vectorz-coerce a))) 1713 | ([m f a more] 1714 | (mp/assign! m (mp/element-map m f a more)))) 1715 | (element-reduce 1716 | ([m f] 1717 | (.reduce m (FnOp2/wrap f))) 1718 | ([m f init] 1719 | (.reduce m (FnOp2/wrap f) (double init)))) 1720 | 1721 | AVector 1722 | (element-seq 1723 | [m] 1724 | (let [ec (.length m) 1725 | ^doubles data (or (.asDoubleArray m) (.toDoubleArray m))] 1726 | (seq data))) 1727 | (element-map 1728 | ([m f] 1729 | (.applyOpCopy m (FnOp/wrap f))) 1730 | ([m f a] 1731 | (with-clone [m] 1732 | (.applyOp m ^Op2 (FnOp2/wrap f) ^INDArray (vectorz-coerce a)))) 1733 | ([m f a more] 1734 | (mp/coerce-param m (mp/element-map (mp/convert-to-nested-vectors m) f a more)))) 1735 | (element-map! 1736 | ([m f] 1737 | (.applyOp m ^Op (FnOp/wrap f))) 1738 | ([m f a] 1739 | (.applyOp m ^Op2 (FnOp2/wrap f) ^INDArray (vectorz-coerce a))) 1740 | ([m f a more] 1741 | (mp/assign! m (mp/element-map m f a more)))) 1742 | (element-reduce 1743 | ([m f] 1744 | (.reduce m (FnOp2/wrap f))) 1745 | ([m f init] 1746 | (.reduce m (FnOp2/wrap f) (double init)))) 1747 | 1748 | AMatrix 1749 | (element-seq 1750 | [m] 1751 | (let [ec (.elementCount m) 1752 | ^doubles data (or (.asDoubleArray m) (.toDoubleArray m))] 1753 | (seq data))) 1754 | (element-map 1755 | ([m f] 1756 | (.applyOpCopy m (FnOp/wrap f))) 1757 | ([m f a] 1758 | (with-clone [m] 1759 | (.applyOp m ^Op2 (FnOp2/wrap f) ^INDArray (vectorz-coerce a)))) 1760 | ([m f a more] 1761 | (mp/coerce-param m (mp/element-map (mp/convert-to-nested-vectors m) f a more)))) 1762 | (element-map! 
1763 | ([m f] 1764 | (.applyOp m ^Op (FnOp/wrap f))) 1765 | ([m f a] 1766 | (.applyOp m ^Op2 (FnOp2/wrap f) ^INDArray (vectorz-coerce a))) 1767 | ([m f a more] 1768 | (mp/assign! m (mp/element-map m f a more)))) 1769 | (element-reduce 1770 | ([m f] 1771 | (.reduce m (FnOp2/wrap f))) 1772 | ([m f init] 1773 | (.reduce m (FnOp2/wrap f) (double init)))) 1774 | 1775 | AIndex 1776 | (element-seq 1777 | [m] 1778 | (seq m)) 1779 | (element-map 1780 | ([m f] 1781 | (let [ec (.length m) 1782 | ^ints data (int-array ec)] 1783 | (dotimes [i ec] (aset data i (int (f (.get m i))))) 1784 | (Index/of data))) 1785 | ([m f a] 1786 | (let [ec (.length m) 1787 | ^ints data (int-array ec)] 1788 | (dotimes [i ec] (aset data i (int (f (.get m i) (mp/get-1d a i))))) 1789 | (Index/of data))) 1790 | ([m f a more] 1791 | (mp/element-map (mp/convert-to-nested-vectors m) f a more))) 1792 | (element-map! 1793 | ([m f] 1794 | (let [ec (.length m)] 1795 | (dotimes [i ec] (.set m i (int (f (.get m i))))) )) 1796 | ([m f a] 1797 | (let [ec (.length m)] 1798 | (dotimes [i ec] (.set m i (int (f (.get m i) (mp/get-1d a i))))) )) 1799 | ([m f a more] 1800 | (mp/assign! 
m (mp/element-map m f a more)))) 1801 | (element-reduce 1802 | ([m f] 1803 | (let [n (.length m)] 1804 | (cond 1805 | (== 0 n) (f) 1806 | (== 1 n) (.get m 0) 1807 | :else (loop [v ^Object (.get m 0) i 1] 1808 | (if (< i n) 1809 | (recur (f v (.get m i)) (inc i)) 1810 | v))))) 1811 | ([m f init] 1812 | (let [n (.length m)] 1813 | (cond 1814 | (== 0 n) init 1815 | :else (loop [v init i 0] 1816 | (if (< i n) 1817 | (recur (f v (.get m i)) (inc i)) 1818 | v))))))) 1819 | ;; Elementwise comparisons test the two operands directly instead of the sign of their difference: (- x y) is NaN when x and y are the same infinity, which made le/ge/eq return 0.0 (and ne return 1.0) for equal infinities. element-compare applies signum to the mutable difference vector and returns that vector explicitly. 1820 | (extend-protocol mp/PCompare 1821 | AVector 1822 | (element-compare [a b] 1823 | (let [b (avector-coerce a b)] 1824 | (let [r (.mutable (.subCopy a b))] (.signum r) r))) 1825 | (element-if [m a b] 1826 | (let [n (.length m) 1827 | a (avector-coerce m a) 1828 | b (avector-coerce m b) 1829 | r (Vectorz/newVector n)] 1830 | (dotimes [i (long n)] 1831 | (let [i (int i) 1832 | test (.unsafeGet m i)] 1833 | (.set r i (if (> test 0) (.unsafeGet a i) (.unsafeGet b i))))) 1834 | r)) 1835 | (element-lt [m a] 1836 | (let [n (.length m) 1837 | a (avector-coerce m a) 1838 | r (Vectorz/newVector n)] 1839 | (dotimes [i (long n)] 1840 | (let [i (int i) 1841 | x (.unsafeGet m i) y (.unsafeGet a i)] 1842 | (.set r i (if (< x y) 1.0 0.0)))) 1843 | r)) 1844 | (element-le [m a] 1845 | (let [n (.length m) 1846 | a (avector-coerce m a) 1847 | r (Vectorz/newVector n)] 1848 | (dotimes [i (long n)] 1849 | (let [i (int i) 1850 | x (.unsafeGet m i) y (.unsafeGet a i)] 1851 | (.set r i (if (<= x y) 1.0 0.0)))) 1852 | r)) 1853 | (element-gt [m a] 1854 | (let [n (.length m) 1855 | a (avector-coerce m a) 1856 | r (Vectorz/newVector n)] 1857 | (dotimes [i (long n)] 1858 | (let [i (int i) 1859 | x (.unsafeGet m i) y (.unsafeGet a i)] 1860 | (.set r i (if (> x y) 1.0 0.0)))) 1861 | r)) 1862 | (element-ge [m a] 1863 | (let [n (.length m) 1864 | a (avector-coerce m a) 1865 | r (Vectorz/newVector n)] 1866 | (dotimes [i (long n)] 1867 | (let [i (int i) 1868 | x (.unsafeGet m i) y (.unsafeGet a i)] 1869 | (.set r i (if
(>= x y) 1.0 0.0)))) 1870 | r)) 1871 | (element-ne [m a] 1872 | (let [n (.length m) 1873 | a (avector-coerce m a) 1874 | r (Vectorz/newVector n)] 1875 | (dotimes [i (long n)] 1876 | (let [i (int i) 1877 | x (.unsafeGet m i) y (.unsafeGet a i)] 1878 | (.set r i (if-not (== x y) 1.0 0.0)))) 1879 | r)) 1880 | (element-eq [m a] 1881 | (let [n (.length m) 1882 | a (avector-coerce m a) 1883 | r (Vectorz/newVector n)] 1884 | (dotimes [i (long n)] 1885 | (let [i (int i) 1886 | x (.unsafeGet m i) y (.unsafeGet a i)] 1887 | (.set r i (if (== x y) 1.0 0.0)))) 1888 | r))) 1889 | 1890 | ;; ============================================================== 1891 | ;; Generator for mathematical functions 1892 | 1893 | (def math-op-mapping 1894 | '[(abs Ops/ABS) 1895 | (acos Ops/ACOS) 1896 | (asin Ops/ASIN) 1897 | (atan Ops/ATAN) 1898 | (cbrt Ops/CBRT) 1899 | (ceil Ops/CEIL) 1900 | (cos Ops/COS) 1901 | (cosh Ops/COSH) 1902 | (exp Ops/EXP) 1903 | (floor Ops/FLOOR) 1904 | (log Ops/LOG) 1905 | (log10 Ops/LOG10) 1906 | (round Ops/RINT) 1907 | (signum Ops/SIGNUM) 1908 | (sin Ops/SIN) 1909 | (sinh Ops/SINH) 1910 | (sqrt Ops/SQRT) 1911 | (tan Ops/TAN) 1912 | (tanh Ops/TANH) 1913 | (to-degrees Ops/TO_DEGREES) 1914 | (to-radians Ops/TO_RADIANS)]) 1915 | 1916 | (eval 1917 | `(extend-protocol mp/PMathsFunctions 1918 | INDArray 1919 | ~@(map 1920 | (fn [[fname op]] 1921 | `(~fname [~'m] (with-clone [~'m] (.applyTo ~op ~'m)))) 1922 | math-op-mapping) 1923 | AMatrix 1924 | ~@(map 1925 | (fn [[fname op]] 1926 | `(~fname [~'m] 1927 | (with-clone [~'m] (.applyTo ~op ~'m)))) 1928 | math-op-mapping) 1929 | AVector 1930 | ~@(map 1931 | (fn [[fname op]] 1932 | `(~fname [~'m] 1933 | (with-clone [~'m] (.applyTo ~op ~'m)))) 1934 | math-op-mapping) 1935 | AScalar 1936 | ~@(map 1937 | (fn [[fname op]] 1938 | `(~fname [~'m] 1939 | (.apply ~op (.get ~'m)))) 1940 | math-op-mapping))) 1941 | 1942 | (eval 1943 | `(extend-protocol mp/PMathsFunctionsMutable 1944 | INDArray 1945 | ~@(map 1946 | (fn
[[fname op]] 1947 | (let [fname (symbol (str fname "!"))] 1948 | `(~fname [~'m] (.applyTo ~op ~'m)))) 1949 | math-op-mapping) 1950 | AMatrix 1951 | ~@(map 1952 | (fn [[fname op]] 1953 | (let [fname (symbol (str fname "!"))] 1954 | `(~fname [~'m] (.applyTo ~op ~'m)))) 1955 | math-op-mapping) 1956 | AVector 1957 | ~@(map 1958 | (fn [[fname op]] 1959 | (let [fname (symbol (str fname "!"))] 1960 | `(~fname [~'m] (.applyTo ~op ~'m)))) 1961 | math-op-mapping))) 1962 | 1963 | ;; Printing methods 1964 | (defmethod print-dup AVector [^AVector x ^Writer writer] 1965 | (.write writer (str "#vectorz/vector " x))) 1966 | 1967 | (defmethod print-dup AScalar [^AScalar x ^Writer writer] 1968 | (.write writer (str "#vectorz/scalar " x))) 1969 | 1970 | (defmethod print-dup AMatrix [^AMatrix x ^Writer writer] 1971 | (.write writer (str "#vectorz/matrix " x))) 1972 | 1973 | (defmethod print-dup INDArray [^INDArray x ^Writer writer] 1974 | (.write writer (str "#vectorz/array "x))) 1975 | 1976 | (defmethod print-method AVector [^AVector x ^Writer writer] 1977 | (.write writer (str "#vectorz/vector " x))) 1978 | 1979 | (defmethod print-method AScalar [^AScalar x ^Writer writer] 1980 | (.write writer (str "#vectorz/scalar " x))) 1981 | 1982 | (defmethod print-method AMatrix [^AMatrix x ^Writer writer] 1983 | (.write writer (str "#vectorz/matrix " x))) 1984 | 1985 | (defmethod print-method INDArray [^INDArray x ^Writer writer] 1986 | (.write writer (str "#vectorz/array "x))) 1987 | 1988 | ;; registration 1989 | 1990 | (imp/register-implementation (vectorz-coerce [[1 2] [3 4]])) 1991 | -------------------------------------------------------------------------------- /src/main/clojure/mikera/vectorz/readers.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.readers 2 | "Namespace for vectorz-clj data literal readers." 
3 | (:require [clojure.core.matrix :as m]) 4 | (:import [mikera.vectorz AVector Vectorz Vector Vector1 Vector2 Vector3 Vector4 AScalar Scalar]) 5 | (:import [mikera.arrayz INDArray Arrayz]) 6 | (:import [mikera.matrixx AMatrix Matrixx]) 7 | (:import [java.util List]) 8 | (:require [mikera.cljutils.error :refer [error]]) 9 | (:refer-clojure :exclude [vector])) 10 | 11 | (set! *warn-on-reflection* true) 12 | (set! *unchecked-math* true) 13 | 14 | (defn array 15 | "Reads a data structure into a Vectorz array, delegating to Arrayz/create." 16 | (^INDArray [a] 17 | (Arrayz/create a))) 18 | 19 | (defn vector 20 | "Reads a java.util.List or other sequential data structure (e.g. a vector literal) into a Vectorz vector. Any other input raises an error." 21 | (^AVector [a] 22 | (cond 23 | (instance? List a) (Vectorz/create ^List a) 24 | (sequential? a) (Vectorz/create ^List (vec a)) 25 | :else (error "Vector must be read from a vector literal")))) 26 | 27 | (defn matrix 28 | "Reads a java.util.List or other sequential data structure (a vector of row vectors) into a Vectorz matrix. Any other input raises an error." 29 | (^AMatrix [a] 30 | (cond 31 | (instance? List a) (Matrixx/create ^List a) 32 | (sequential? a) (Matrixx/create ^List (vec a)) 33 | :else (error "Matrix must be read as a vector of vectors")))) 34 | 35 | (defn scalar 36 | "Reads a numerical value into a Vectorz scalar (coerced to double). Any non-numeric input raises an error." 37 | (^AScalar [a] 38 | (if (number?
a) 39 | (Scalar/create (double a)) 40 | (error "Scalar must be read as a numerical value")))) 41 | 42 | -------------------------------------------------------------------------------- /src/main/java/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikera/vectorz-clj/bf575e5b9089f03b482071c83615af443b105525/src/main/java/.gitignore -------------------------------------------------------------------------------- /src/main/java/mikera/vectorz/FnOp.java: -------------------------------------------------------------------------------- 1 | package mikera.vectorz; 2 | 3 | import clojure.lang.IFn; 4 | import clojure.lang.IFn.DD; 5 | 6 | 7 | /** 8 | * Wrapper class for a Clojure arity-1 function 9 | * @author Mike 10 | * 11 | */ 12 | public final class FnOp extends mikera.vectorz.ops.AFunctionOp { 13 | protected IFn fn; 14 | 15 | public static Op wrap(Object f) { 16 | if (f instanceof Op) { 17 | return (Op) f; 18 | } else if (f instanceof DD) { 19 | return new PrimitiveFnOp((DD)f); 20 | } else { 21 | return new FnOp(f); 22 | } 23 | } 24 | 25 | public FnOp(Object f) { 26 | fn=(IFn)f; 27 | } 28 | 29 | public FnOp(IFn f) { 30 | fn=f; 31 | } 32 | 33 | @Override 34 | public double apply(double x) { 35 | return ((Number)fn.invoke(x)).doubleValue(); 36 | } 37 | 38 | @Override 39 | public void applyTo(double[] xs, int offset, int length) { 40 | for (int i=0; i 1.5 ns per addition 24 | 25 | ;; core.matrix mutable add 26 | (let [a (v/vec [1 2 3]) 27 | b (v/vec [1 2 3])] 28 | (c/quick-bench (dotimes [i 1000] (add! a b)))) 29 | ;; => 7.5 ns per addition 30 | 31 | ;; mikera.vectorz.core mutable add 32 | (let [^Vector3 a (v/vec [1 2 3]) 33 | ^Vector3 b (v/vec [1 2 3])] 34 | (c/quick-bench (dotimes [i 1000] (v/add! 
a b)))) 35 | ;; => 3.9 ns per addition 36 | 37 | 38 | ;; core.matrix add 39 | (let [a (v/vec [1 2 3]) 40 | b (v/vec [1 2 3])] 41 | (c/quick-bench (dotimes [i 1000] (add a b)))) 42 | ;; => 15.8 ns per addition 43 | 44 | ;; direct persistent vector add 45 | (let [a [1 2 3] 46 | b [1 2 3]] 47 | (c/quick-bench (dotimes [i 1000] (mapv + a b)))) 48 | 49 | ;; persistent vector core.matrix add 50 | (let [a [1 2 3] 51 | b [1 2 3]] 52 | (c/quick-bench (dotimes [i 1000] (add a b)))) 53 | 54 | ;; Adding two regular Clojure vectors with clojure.core/+ 55 | (let [a [1 2 3 4 5 6 7 8 9 10] 56 | b [1 2 3 4 5 6 7 8 9 10]] 57 | (c/quick-bench (dotimes [i 1000] (mapv clojure.core/+ a b)))) 58 | ;; => Execution time mean per addition : 285 ns 59 | 60 | ;; Adding two regular Clojure vectors with + 61 | (let [a [1 2 3 4 5 6 7 8 9 10] 62 | b [1 2 3 4 5 6 7 8 9 10]] 63 | (c/quick-bench (dotimes [i 1000] (+ a b)))) 64 | ;; => Execution time mean per addition : 285 ns 65 | 66 | ;; Adding two core.matrix vectors (pure functions, i.e. creating a new vector) 67 | (let [a (array :vectorz [1 2 3 4 5 6 7 8 9 10]) 68 | b (array :vectorz [1 2 3 4 5 6 7 8 9 10])] 69 | (c/quick-bench (dotimes [i 1000] (add a b)))) 70 | ;; => Execution time mean per addition: 120 ns 71 | 72 | ;; Adding two core.matrix vectors (mutable operation, i.e. adding to the first vector) 73 | (let [a (array :vectorz [1 2 3 4 5 6 7 8 9 10]) 74 | b (array :vectorz [1 2 3 4 5 6 7 8 9 10])] 75 | (c/quick-bench (dotimes [i 1000] (add! 
a b)))) 76 | ;; => Execution time mean per addition: 28 ns 77 | 78 | ;; Adding two core.matrix vectors using low level Java interop 79 | (let [a (Vectorz/create [1 2 3 4 5 6 7 8 9 10]) 80 | b (Vectorz/create [1 2 3 4 5 6 7 8 9 10])] 81 | (c/quick-bench (dotimes [i 1000] (.add a b)))) 82 | ;; => Execution time mean per addition: 11 ns 83 | 84 | ;; Indexed lookup with Clojure vector 85 | (let [a (Vectorz/create [1 2 3 4 5 6 7 8 9 10])] 86 | (c/quick-bench (dotimes [i 1000] (mp/get-nd a [1])))) 87 | ;; => Execution time mean per lookup: 14 ns 88 | 89 | 90 | ) -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/benchmark_stats.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.benchmark-stats 2 | (:use clojure.core.matrix) 3 | (:use clojure.core.matrix.stats) 4 | (:require [clojure.core.matrix.operators :refer [+ - *]]) 5 | (:refer-clojure :exclude [+ - * ]) 6 | (:require [criterium.core :as c]) 7 | (:require [mikera.vectorz.matrix-api]) 8 | (:require [mikera.vectorz.core :as v]) 9 | (:require [mikera.vectorz.matrix :as m]) 10 | (:import [mikera.vectorz Vector3 Vectorz])) 11 | 12 | (set! *warn-on-reflection* true) 13 | (set! 
*unchecked-math* :warn-on-boxed) 14 | 15 | (set-current-implementation :vectorz) 16 | 17 | (defn benchmarks [] 18 | (let [vs (vec (for [i (range 100)] (let [m (Vectorz/newVector 100)] (.set m (double i)) m)))] 19 | (c/quick-bench (mean vs))) 20 | ;; => ~2ns per element 21 | 22 | (let [vs (vec (for [i (range 100)] (let [m (Vectorz/newVector 100)] (.set m (double i)) m)))] 23 | (c/quick-bench (variance vs))) 24 | ;; => ~4ns per element 25 | 26 | ) -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/blank.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.blank 2 | (:use clojure.core.matrix)) -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/examples.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.examples 2 | (:refer-clojure :exclude [+ - *]) 3 | (:use clojure.core.matrix) 4 | (:require [clojure.core.matrix.operators :refer [+ - *]]) 5 | (:require mikera.vectorz.matrix-api)) 6 | 7 | ;; in which we run a sequence of examples to demonstrate vectorz-clj features 8 | 9 | (defn example [] 10 | 11 | ;; first up, tell core.matrix that we want to use vectorz as our default implementation 12 | (set-current-implementation :vectorz) 13 | 14 | ;; create a new 3D vector 15 | (def a (new-vector 3)) 16 | a 17 | ;; => # 18 | 19 | ;; check the class of our vector. 
it should be a Java class from the Vectorz packages 20 | (class a) 21 | ;; => mikera.vectorz.Vector3 22 | 23 | 24 | ;; convert a vectorz vector into a regular Clojure persistent vector 25 | ;; our vector should be empty at the moment (all zeros) 26 | (coerce [] a) 27 | ;; => [0.0 0.0 0.0] 28 | 29 | 30 | ;; assign to a vector using core.matrix functions 31 | ;; 32 | ;; Note 1: you can use clojure vectors quite happily as arguments to core.matrix functions: 33 | ;; they are considered as valid matrices by core.matrix 34 | ;; Note 2: functions with a ! cause mutation 35 | (assign! a [1 2 3]) 36 | ;; => # 37 | 38 | 39 | ;; create a normalised version of our vector 40 | (def n (normalise a)) 41 | n 42 | ;; => # 43 | 44 | ;; normalised vector should have a length of 1.0, or very close (subject to numerical error) 45 | (length n) 46 | ;;=> 1.0 47 | 48 | 49 | 50 | ) -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/generators.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.generators 2 | (:require [clojure.core.matrix :as m] 3 | [clojure.test.check.generators :as gen] 4 | [clojure.core.matrix.generators :as gm] 5 | [clojure.test.check.properties :as prop] 6 | [mikera.cljutils.error :refer [error]])) 7 | 8 | (set! *warn-on-reflection* true) 9 | (set! 
*unchecked-math* :warn-on-boxed) 10 | 11 | (defn subvector 12 | "Creates a subvector generator from a vector generator" 13 | ([g-vector] 14 | (gen/bind g-vector 15 | (fn [v] 16 | (let [n (m/ecount v)] 17 | ()))))) 18 | 19 | (defn mutable-array 20 | "Create a generator for fully mutable arrays" 21 | ([] 22 | (gen/one-of 23 | [(gm/gen-array (gm/gen-shape) gm/gen-double :vectorz)]))) 24 | 25 | (defn mutable-vector 26 | "Create a generator for fully mutable vectors" 27 | ([] 28 | (gen/one-of 29 | [(gm/gen-vector gm/gen-double :vectorz) 30 | (gen/fmap m/as-vector (mutable-array))]))) 31 | 32 | -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/implementation_check.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.implementation-check 2 | (:use clojure.core.matrix) 3 | (:require [clojure.test :as test]) 4 | (:require [clojure.core.matrix.utils :as utils])) 5 | 6 | (set! *warn-on-reflection* true) 7 | (set! *unchecked-math* :warn-on-boxed) 8 | 9 | (set-current-implementation :vectorz) 10 | 11 | (def protos (utils/extract-protocols)) 12 | 13 | (defn test-impls 14 | "Gets a map of vectorz types to unimplemented interfaces. 15 | Intended to allow us to check which protocols still need to be implemented." 
16 | ([] 17 | (array :vectorz [1]) 18 | (into {} (mapv #(do [% (utils/unimplemented %)]) 19 | [mikera.arrayz.INDArray 20 | mikera.matrixx.AMatrix 21 | mikera.vectorz.AVector 22 | mikera.vectorz.AScalar])))) 23 | 24 | -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/large_matrix_benchmark.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.large-matrix-benchmark 2 | (:use clojure.core.matrix) 3 | (:require [clojure.core.matrix.operators :refer [+ - *]]) 4 | (:refer-clojure :exclude [+ - *]) 5 | (:require [criterium.core :as c]) 6 | (:require [mikera.vectorz.core :as v]) 7 | (:require [mikera.vectorz.matrix :as m]) 8 | (:import [mikera.vectorz Vector3 Vectorz])) 9 | 10 | (set! *warn-on-reflection* true) 11 | (set! *unchecked-math* :warn-on-boxed) 12 | 13 | (set-current-implementation :vectorz) 14 | 15 | (defn benchmarks [] 16 | ;; direct vectorz add 17 | (let [a (matrix (range 1000)) 18 | b (matrix (range 1000))] 19 | (c/quick-bench (dotimes [i 1000] (+ a b)))) 20 | ;; 2437 ns per add 21 | 22 | ;; 100x100 matrix construction 23 | (let [] 24 | (c/quick-bench (matrix (map (fn [r] (range 100)) (range 100))))) 25 | ;; 45 ns per element 26 | 27 | (let [m (matrix (map (fn [r] (range 100)) (range 100)))] 28 | (c/quick-bench (+ m m))) 29 | ;; 10 ns per element - OK-ish 30 | 31 | (let [m (matrix (map (fn [r] (range 100)) (range 100)))] 32 | (c/quick-bench (* m m))) 33 | ;; 1.3ns per multiply?? 34 | 35 | (let [m (matrix (map (fn [r] (range 100)) (range 100)))] 36 | (c/quick-bench (mul m m))) 37 | ;; 1.3ns per multiply?? 38 | 39 | (let [m (matrix (map (fn [r] (range 100)) (range 100)))] 40 | (c/quick-bench (abs! m))) 41 | ;; ~1.0ns per element 42 | 43 | (let [m (matrix (map (fn [r] (range 100)) (range 100)))] 44 | (c/quick-bench (sqrt! 
m))) 45 | ;; ~3.4ns per element 46 | 47 | (let [m (matrix (map (fn [r] (range 100)) (range 100)))] 48 | (c/quick-bench (cbrt! m))) 49 | ;; ~19ns per element 50 | 51 | 52 | ) -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/matrix_benchmarks.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.matrix-benchmarks 2 | (:use clojure.core.matrix) 3 | (:use clojure.core.matrix.stats) 4 | (:require [criterium.core :as c]) 5 | (:require [mikera.vectorz.matrix-api]) 6 | (:require [mikera.vectorz.core :as v]) 7 | (:require [mikera.vectorz.matrix :as m]) 8 | (:import [mikera.vectorz Vector3 Vectorz]) 9 | (:import [mikera.matrixx Matrixx])) 10 | 11 | (set! *warn-on-reflection* true) 12 | (set! *unchecked-math* true) 13 | 14 | (set-current-implementation :vectorz) 15 | 16 | (defn benchmarks [] 17 | ;; elementwise mutation of 10x10 matrix, followed by computation of the sum 18 | ;; => about 3,000 ns 19 | (defn buildsum [n] 20 | (let [bv (zero-matrix n n)] 21 | (dotimes [i n] 22 | (dotimes [j n] 23 | (mset! bv i j (* i j)))) 24 | (esum bv))) 25 | (c/quick-bench (buildsum 10)) 26 | 27 | 28 | ;; multiplication of two 100x100 matrices 29 | ;; => about 1.3 ms 30 | (defn multiply [n] 31 | (let [ma (zero-matrix n n) 32 | mb (zero-matrix n n)] 33 | (mmul ma mb))) 34 | (c/quick-bench (multiply 100)) 35 | 36 | 37 | ;; multiplication of two 10x10 matrices 38 | ;; => about 1,800 ns 39 | (c/quick-bench (multiply 10)) 40 | 41 | ) -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/test_core.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.test-core 2 | (:use clojure.test) 3 | (:require mikera.vectorz.examples) 4 | (:require [mikera.vectorz.core :as v]) 5 | (:import [mikera.vectorz AVector Vectorz Vector])) 6 | 7 | (set! 
*warn-on-reflection* true) 8 | (set! *unchecked-math* :warn-on-boxed) 9 | 10 | (deftest test-arithmetic 11 | (testing "addition" 12 | (is (= (v/of 1 2) 13 | (v/+ (v/of 0 2) (v/of 1 0))))) 14 | (testing "subtraction" 15 | (is (= (v/of 1 2) 16 | (v/- (v/of 3 3) (v/of 2 1))))) 17 | (testing "division" 18 | (is (= (v/of 1 2) 19 | (v/divide (v/of 10 10) (v/of 10 5))))) 20 | (testing "multiplication" 21 | (is (= (v/of 2 6) 22 | (v/* (v/of 1 2) (v/of 2 3)))))) 23 | 24 | (deftest test-vector-ops 25 | (testing "dot product" 26 | (is (= 10.0 (v/dot (v/of 2 3 1) (v/of 1 2 2))))) 27 | (testing "distance" 28 | (is (= 5.0 (v/distance (v/of 1 0 0) (v/of 1 3 4)))))) 29 | 30 | (deftest test-more-vector-ops 31 | (testing "add weighted" 32 | (is (= (v/of 2 6) 33 | (v/add-weighted (v/of 1 5) (v/of 5 9) 1/4)))) 34 | (testing "interpolate" 35 | (is (= (v/of 2 6) 36 | (v/interpolate (v/of 5 9) (v/of 1 5) 3/4))))) 37 | 38 | (deftest test-get-set 39 | (testing "get" 40 | (is (= 10.0 (v/get (v/of 5 10 15) 1)))) 41 | (testing "set" 42 | (is (= (v/of 5 10 15) 43 | (v/set (v/of 5 0 15) (long 1) 10.0))) 44 | (is (= (v/of 3 2) (v/to-vector (v/set (v/of 1 2) 0 3)))))) 45 | 46 | (deftest test-seq 47 | (testing "to seq" 48 | (is (= [2.0 3.0] (seq (v/of 2 3)))))) 49 | 50 | (deftest test-refs 51 | (testing "join" 52 | (let [v1 (v/of 1 2) 53 | v2 (v/of 3 4) 54 | jv (v/join v1 v2)] 55 | (is (= (v/of 1 2 3 4) jv)) 56 | (v/fill! (v/subvec jv 1 2) 10) 57 | (is (== 10.0 (v/get v1 1))) 58 | (is (== 10.0 (v/get v2 0))) 59 | (v/fill! (v/clone jv) 20) 60 | (is (not= 20.0 (v/get v2 0)))))) 61 | 62 | (deftest test-assign 63 | (testing "assign" 64 | (is (= (v/of 1 2) (v/assign! 
(v/of 2 3) (v/of 1 2)))))) 65 | 66 | (deftest test-construction 67 | (testing "from double arrays" 68 | (is (= (v/of 1 2 3) (v/vec (double-array [1 2 3])))))) 69 | 70 | (deftest test-primitive-vector-constructors 71 | (testing "vec1" 72 | (is (= (v/of 1) (v/vec1 1) )) 73 | (is (= (v/of 1) (v/vec1 [1]))) 74 | (is (= (v/of 0) (v/vec1))) 75 | (is (thrown? Throwable (v/vec1 [2 3])))) 76 | (testing "vec2" 77 | (is (= (v/of 1 2) (v/vec2 1 2) )) 78 | (is (= (v/of 1 2) (v/vec2 [1 2]))) 79 | (is (= (v/of 0 0) (v/vec2))) 80 | (is (thrown? Throwable (v/vec2 [2])))) 81 | (testing "vec3" 82 | (is (= (v/of 1 2 3) (v/vec3 1 2 3) )) 83 | (is (= (v/of 1 2 3) (v/vec3 [1 2 3]))) 84 | (is (= (v/of 0 0 0) (v/vec3))) 85 | (is (thrown? Throwable (v/vec3 [2 3])))) 86 | (testing "vec4" 87 | (is (= (v/of 1 2 3 4) (v/vec4 1 2 3 4) )) 88 | (is (= (v/of 1 2 3 4) (v/vec4 [1 2 3 4]))) 89 | (is (= (v/of 0 0 0 0) (v/vec4))) 90 | (is (thrown? Throwable (v/vec4 [2 3]))))) -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/test_linear.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.test-linear 2 | (:use [clojure.test] 3 | [clojure.core.matrix :exclude [rank]] 4 | [clojure.core.matrix.linear]) 5 | (:require [mikera.vectorz.core :as v])) 6 | 7 | (set! *warn-on-reflection* true) 8 | (set! *unchecked-math* :warn-on-boxed) 9 | 10 | (set-current-implementation :vectorz) 11 | 12 | (deftest test-svd 13 | (let [result (svd [[2 0] [0 1]])] 14 | (is (every? v/vectorz? (vals result))) 15 | (is (equals [2 1] (:S result))) 16 | (is (every? orthogonal? ((juxt :V* :U) result))))) 17 | 18 | (deftest test-qr 19 | (let [A (matrix [[2 0] [0 1]]) 20 | result (qr A) 21 | R (:R result) 22 | Q (:Q result)] 23 | (is (orthogonal? Q)) 24 | (is (upper-triangular? 
R)) 25 | (is (equals A (mmul Q R))))) 26 | 27 | ;; Simple 2x2 system. Named distinctly from the later test-solve deftest in this file, which would otherwise redefine this var so this test never ran. A is coerced with matrix so the vectorz solver under test is exercised. 28 | (deftest test-solve-2x2 29 | (let [A [[1 2] 30 | [2 1]] 31 | b [22 32 | 26]] 33 | (is (equals [10 6] (solve (matrix A) b))))) 34 | 35 | 36 | (deftest test-QR-decomposition 37 | (let [epsilon 0.00001] 38 | (testing "test0" 39 | (let [M (matrix [[1 2 3][4 5 6][7 8 9]]) 40 | {:keys [Q R]} (qr M {:return [:Q]})] 41 | (is (orthogonal? Q)) 42 | (is (and Q (not R))))) 43 | (testing "test1" 44 | (let [M (matrix [[1 2 3][4 5 6][7 8 9]]) 45 | {:keys [Q R]} (qr M {:return [:Q :R]})] 46 | (is (orthogonal? Q)) 47 | (is (upper-triangular? R)) 48 | (is (equals M (mmul Q R) epsilon)))) 49 | (testing "test2" 50 | (let [M (matrix [[111 222 333][444 555 666][777 888 999]]) 51 | {:keys [Q R]} (qr M nil)] 52 | (is (orthogonal? Q)) 53 | (is (upper-triangular? R)) 54 | (is (equals M (mmul Q R) epsilon)))) 55 | (testing "test3" 56 | (let [M (matrix [[-1 2 0][14 51 6.23][7.1242 -8.4 119]]) 57 | {:keys [Q R]} (qr M)] 58 | (is (orthogonal? Q)) 59 | (is (upper-triangular? R)) 60 | (is (equals M (mmul Q R) epsilon)))))) 61 | 62 | (deftest test-QR-decomposition-rectangular 63 | (let [epsilon 0.00001] 64 | (testing "should decompose wide matrices" 65 | (let [M (matrix [[1 2 3 4 5][6 7 8 9 10][11 12 13 14 15]]) 66 | {:keys [Q R]} (qr M)] 67 | (is (= [3 3](shape Q))) 68 | (is (orthogonal? Q)) 69 | (is (= [3 5](shape R))) 70 | (is (equals M (mmul Q R) epsilon)))) 71 | (testing "should decompose tall matrices" 72 | (let [M (matrix [[1 2 3][4 5 6][7 8 9][10 11 12][13 14 15]]) 73 | {:keys [Q R]} (qr M)] 74 | (is (= [5 5](shape Q))) 75 | (is (orthogonal? Q)) 76 | (is (= [5 3](shape R))) 77 | (is (equals M (mmul Q R) epsilon)))))) 78 | 79 | (deftest test-LUP-decomposition 80 | (let [epsilon 0.00001] 81 | (testing "test0" 82 | (let [M (matrix [[1 2 3][4 5 6][7 8 9]]) 83 | {:keys [L U P]} (lu M {:return [:U]})] 84 | (is (upper-triangular?
U)) 85 | (is (and U (not L) (not P))))) 86 | (testing "test1" 87 | (let [M (matrix [[1 2 3][4 5 6][7 8 9]]) 88 | {:keys [L U P]} (lu M {:return [:L :U :P]}) 89 | p (matrix [[0.0 1.0 0.0][0.0 0.0 1.0][1.0 0.0 0.0]])] 90 | (is (lower-triangular? L)) 91 | (is (upper-triangular? U)) 92 | (is (equals P p epsilon)) 93 | (is (equals M (mmul P L U) epsilon)))) 94 | (testing "test2" 95 | (let [M (matrix [[76 87 98][11 21 32][43 54 65]]) 96 | {:keys [L U P]} (lu M) 97 | p (matrix [[1.0 0.0 0.0][0.0 1.0 0.0][0.0 0.0 1.0]])] 98 | (is (lower-triangular? L)) 99 | (is (upper-triangular? U)) 100 | (is (equals P p epsilon)) 101 | (is (equals M (mmul P L U) epsilon)))))) 102 | 103 | (deftest test-SVD-decomposition 104 | (let [epsilon 0.00001] 105 | (testing "test0" 106 | (let [M (matrix [[1 2 3][4 5 6][7 8 9]]) 107 | {:keys [U S V*]} (svd M {:return [:S]})] 108 | (is (and S (not U) (not V*))))) 109 | (testing "test1" 110 | (let [M (matrix [[1 2 3][4 5 6][7 8 9]]) 111 | {:keys [U S V*]} (svd M {:return [:U :S :V*]}) 112 | S_matrix (diagonal-matrix :vectorz S)] 113 | (is (orthogonal? U)) 114 | (is (orthogonal? V*)) 115 | (is (equals M (mmul U S_matrix V*) epsilon)))) 116 | (testing "test2" 117 | (let [M (matrix [[12 234 3.23][-2344 -235 61][-7 18.34 9]]) 118 | {:keys [U S V*]} (svd M) 119 | S_matrix (diagonal-matrix :vectorz S)] 120 | (is (orthogonal? U)) 121 | (is (orthogonal? V*)) 122 | (is (equals M (mmul U S_matrix V*) epsilon)))) 123 | (testing "test3" 124 | (let [M (matrix [[76 87 98][11 21 32][43 54 65]]) 125 | {:keys [U S V*]} (svd M nil) 126 | S_matrix (diagonal-matrix :vectorz S)] 127 | (is (orthogonal? U)) 128 | (is (orthogonal? V*)) 129 | (is (equals M (mmul U S_matrix V*) epsilon)))))) 130 | 131 | (deftest test-Cholesky-decomposition 132 | (let [epsilon 0.00001] 133 | (testing "test0" 134 | (let [M (matrix [[1 2 3][4 5 6][7 8 9]]) 135 | result (cholesky M)] 136 | (is (nil? 
result)))) 137 | (testing "test1" 138 | (let [M (matrix [[4 12 -16][12 37 -43][-16 -43 98]]) 139 | {:keys [L L*]} (cholesky M {:return [:L]})] 140 | (is (lower-triangular? L)) 141 | (is (and L (not L*))) 142 | (is (equals M (mmul L (transpose L)) epsilon)) 143 | (is (reduce (fn [a b] (and a (> (double b) 0))) true (diagonal L))))) 144 | (testing "test1" 145 | (let [M (matrix [[2 -1 0][-1 2 -1][0 -1 2]]) 146 | {:keys [L L*]} (cholesky M {:return [:L :L*]})] 147 | (is (lower-triangular? L)) 148 | (is (upper-triangular? L*)) 149 | (is (equals L (transpose L*) epsilon)) 150 | (is (equals M (mmul L L*) epsilon)) 151 | (is (reduce (fn [a b] (and a (> (double b) 0))) true (diagonal L))) 152 | (is (reduce (fn [a b] (and a (> (double b) 0))) true (diagonal L*))))))) 153 | 154 | (deftest test-norm 155 | (let [M (matrix [[1 2 3][4 5 6][7 8 9]])] 156 | (is (equals 45.0 (norm M 1) 1e-10)) 157 | (is (equals 16.88194301613 (norm M 2) 1e-10)) 158 | (is (equals 9 (norm M java.lang.Double/POSITIVE_INFINITY) 1e-10)) 159 | (is (equals 16.88194301613 (norm M) 1e-10)) 160 | (is (equals 12.65148997952 (norm M 3) 1e-10)) 161 | (let [V (as-vector M)] 162 | (is (equals 45.0 (norm V 1) 1e-10)) 163 | (is (equals 16.88194301613 (norm V 2) 1e-10)) 164 | (is (equals 9 (norm V java.lang.Double/POSITIVE_INFINITY) 1e-10)) 165 | (is (equals 16.88194301613 (norm V) 1e-10)) 166 | (is (equals 0.941944314533 (norm V -3) 1e-10)) 167 | (is (equals 12.65148997952 (norm V 3) 1e-10))))) 168 | 169 | (deftest test-rank 170 | (let [M1 (matrix [[1 2 3][4 5 6][7 8 9]]) 171 | M2 (identity-matrix 3) 172 | M3 (matrix [[1 1 1][1 1 1][1 1 1]])] 173 | (is (equals 2 (rank M1))) 174 | (is (equals 3 (rank M2))) 175 | (is (equals 1 (rank M3))))) 176 | 177 | (deftest test-solve 178 | (let [M1 (matrix [[1 -2 1][0 1 6][0 0 1]]) 179 | M2 (matrix [[1 2 3][4 5 6][7 8 9]]) ;Singular matrix 180 | V1 (array [4 -1 2]) 181 | V2 (vec [4 -1 2]) 182 | A1 (array [-24 -13 2])] 183 | (is (equals A1 (solve M1 V1) 1e-8)) 184 | (is (equals A1 
(solve M1 V2) 1e-8)) 185 | (is (nil? (solve M2 V1))))) 186 | 187 | (deftest test-least-squares 188 | (let [M1 (matrix [[1 2][3 4][5 6]]) 189 | V1 (array [1 2 3]) 190 | V2 (vec [1 2 3]) 191 | A1 (array [0 0.5])] 192 | (is (equals A1 (least-squares M1 V1) 1e-8)) 193 | (is (equals A1 (least-squares M1 V2) 1e-8)))) 194 | -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/test_matrix.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.test-matrix 2 | (:use [clojure test]) 3 | (:require [mikera.vectorz.core :as v]) 4 | (:require [mikera.vectorz.matrix :as m]) 5 | (:import [mikera.matrixx AMatrix Matrixx Matrix]) 6 | (:import [mikera.vectorz AVector Vectorz Vector])) 7 | 8 | (set! *warn-on-reflection* true) 9 | (set! *unchecked-math* :warn-on-boxed) 10 | 11 | (deftest test-constructors 12 | (testing "identity" 13 | (is (= (m/matrix [[1 0] [0 1]]) (m/identity-matrix 2))) 14 | (is (= (m/identity-matrix 3) (m/scale-matrix [1 1 1])))) 15 | (testing "scale matrix" 16 | (is (= (m/matrix [[1 0] [0 1]]) (m/scale-matrix 2 1)))) 17 | (testing "diagonal matrix" 18 | (is (= (m/matrix [[2 0] [0 3]]) (m/diagonal-matrix [2 3])))) 19 | (testing "rotation matrix" 20 | (is (v/approx= (v/of 1 2 3) 21 | (m/* (m/x-axis-rotation-matrix (* 2 Math/PI)) 22 | (v/of 1 2 3)))))) 23 | 24 | (deftest test-compose 25 | (testing "composing scales" 26 | (is (= (m/scale-matrix [3 6]) (m/* (m/scale-matrix [1 2]) (m/scale-matrix 2 3)))))) 27 | 28 | (deftest test-ops 29 | (testing "as-vector" 30 | (is (= (v/of 1 0 0 1) (m/as-vector (m/identity-matrix 2)))) 31 | (is (= (v/of 1 0) (m/get-row (m/identity-matrix 2) 0))))) 32 | 33 | (deftest test-get-set 34 | (testing "setting" 35 | (let [m (m/clone (m/identity-matrix 2))] 36 | (is (= 1.0 (m/get m 0 0))) 37 | (m/set m 0 0 2.0) 38 | (is (= 2.0 (m/get m 0 0)))))) 39 | 40 | (deftest test-arithmetic 41 | (testing "identity" 42 | (let [a (v/of 2 3) 43 
| m (m/identity-matrix 2) 44 | r (m/* m a)] 45 | (is (= a r))))) 46 | 47 | (deftest test-predicates 48 | (testing "fully mutable" 49 | (is (m/fully-mutable? (m/new-matrix 3 3))) 50 | ) 51 | (testing "square" 52 | (is (m/square? (m/new-matrix 3 3))) 53 | (is (m/square? (m/identity-matrix 10))) 54 | (is (not (m/square? (m/new-matrix 4 3)))) 55 | ) 56 | (testing "identity" 57 | (is (m/identity? (m/identity-matrix 3))) 58 | (is (not (m/identity? (m/new-matrix 2 2 )))) 59 | (is (m/identity? (m/scale-matrix [1 1 1]))) 60 | (is (not (m/identity? (m/scale-matrix [1 2 3])))) 61 | ) 62 | (testing "zero" 63 | (is (m/zero? (m/new-matrix 2 3))) 64 | (is (m/zero? (m/scale-matrix [0 0 0 0 0]))) 65 | ) 66 | (testing "affine" 67 | (is (m/affine-transform? (m/new-matrix 2 3))))) 68 | 69 | (deftest test-dimensions 70 | (testing "inputs" 71 | (is (= 3 (m/input-dimensions (m/new-matrix 2 3))))) 72 | (testing "outputs" 73 | (is (= 2 (m/output-dimensions (m/new-matrix 2 3))))) 74 | ) -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/test_matrix_api.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.test-matrix-api 2 | (:refer-clojure :exclude [vector? 
* - +]) 3 | (:use [clojure test]) 4 | (:use clojure.core.matrix) 5 | (:require [clojure.core.matrix.operators :refer [+ - *]]) 6 | (:require clojure.core.matrix.compliance-tester) 7 | (:require [clojure.core.matrix.protocols :as mp]) 8 | (:require [clojure.core.matrix.linear :as li]) 9 | (:require [mikera.vectorz.core :as v]) 10 | (:require [mikera.vectorz.matrix :as m]) 11 | (:require [mikera.vectorz.matrix-api]) 12 | (:require clojure.core.matrix.impl.persistent-vector) 13 | (:require [clojure.core.matrix.impl.wrappers :as wrap]) 14 | (:import [mikera.matrixx AMatrix Matrixx Matrix]) 15 | (:import [mikera.vectorz Scalar]) 16 | (:import [mikera.indexz AIndex Index]) 17 | (:import [mikera.vectorz AVector Vectorz Vector]) 18 | (:import [mikera.arrayz INDArray Array NDArray])) 19 | 20 | (set! *warn-on-reflection* true) 21 | (set! *unchecked-math* :warn-on-boxed) 22 | 23 | ;; note - all the operators are core.matrix operators 24 | 25 | (set-current-implementation :vectorz) 26 | 27 | (deftest test-misc-regressions 28 | (let [v1 (v/vec1 1.0)] 29 | (is (array? v1)) 30 | (is (== 1 (dimensionality v1))) 31 | (is (== 1 (ecount v1))) 32 | (is (not (matrix? v1)))) 33 | (let [m (coerce (matrix [[1 2]]) [[1 2] [3 4]])] 34 | (is (every? true? (map == (range 1 (inc (long (ecount m)))) (eseq m))))) 35 | (let [m (matrix [[1 2] [3 4]])] 36 | (is (== 2 (ecount (first (slices m))))) 37 | (scale! (first (slices m)) 2.0) 38 | (is (equals m [[2 4] [3 4]]))) 39 | (let [m (matrix [[0 0] [0 0]])] 40 | (assign! m [[1 2] [3 4]]) 41 | (is (equals m [[1 2] [3 4]])) 42 | (assign! m [[0 0] [0 0]]) 43 | (is (equals m [[0 0] [0 0]])) 44 | (mp/assign-array! m (double-array [2 4 6 8])) 45 | (is (equals m [[2 4] [6 8]])) 46 | (mp/assign-array! 
m (double-array 4)) 47 | (is (equals m [[0 0] [0 0]]))) 48 | (let [v (v/vec [1 2 3])] 49 | (is (equals [2 4 6] (add v v)))) 50 | (let [v (Vector/of (double-array 0))] 51 | (is (== 10 (reduce (fn [acc _] (inc (long acc))) 10 (eseq v)))) 52 | (is (== 10 (ereduce (fn [acc _] (inc (long acc))) 10 v)))) 53 | (let [m (reshape (array (double-array (range 9))) [3 3])] 54 | (is (equals [[0 1 2]] (submatrix m 0 [0 1]))) 55 | (is (equals [[0 1 2]] (submatrix m [[0 1] nil]))) 56 | (is (equals [[0] [3] [6]] (submatrix m 1 [0 1])))) 57 | (let [v (array (range 9))] 58 | (is (equals v (submatrix v [nil]))) 59 | (is (equals v (submatrix v [[0 9]]))) 60 | (is (equals [2 3 4] (submatrix v 0 [2 3]))) 61 | (is (equals [2 3 4] (submatrix v [[2 3]])))) 62 | (is (instance? AVector (array [1 2]))) 63 | (is (equals [1 1 1] (div (array [2 2 2]) 2))) 64 | (is (equals [[1 2] [3 4] [5 6]] (join (array [[1 2] [3 4]]) (array [[5 6]])))) 65 | (is (equals [[1 3] [2 4] [5 6]] (join (transpose (array [[1 2] [3 4]])) (array [[5 6]])))) 66 | (is (= 1.0 (slice (array [0 1 2]) 1))) 67 | (is (mp/set-nd (matrix :vectorz [[1 2][3 4]]) [0 1] 3)) 68 | (testing "Regression with matrix applyOp arity 2" 69 | (let [t (array :vectorz [[10] [20]])] 70 | (is (equals [[1] [2]] (emap / t 10)))))) 71 | 72 | (deftest test-set-indices-62 ;; fix for #62 issue 73 | (is (equals [[1 1] [2 9]] (set-indices (matrix :vectorz [[1 1] [2 2]]) [[1 1]] [9])))) 74 | 75 | (deftest test-infinity-norm ;; fix for #67 issue 76 | (is (equals 2 (mp/norm (array :vectorz [-2 1]) Double/POSITIVE_INFINITY))) 77 | (is (equals 3 (mp/norm (array :vectorz [-2 3]) Double/POSITIVE_INFINITY)))) 78 | 79 | (deftest test-set-column ;; fix for #63 issue 80 | (let [m (matrix :vectorz [[1 2 3] [3 4 5]])] 81 | (is (equals [[1 2 10] [3 4 11]] (set-column m 2 [10 11]))))) 82 | 83 | (deftest test-row-column-matrix 84 | (let [m (matrix :vectorz [1 2 3]) 85 | rm (row-matrix m) 86 | cm (column-matrix m)] 87 | (is (not (equals rm cm))) 88 | (is (equals rm 
(transpose cm))) 89 | (is (equals rm [[1 2 3]])))) 90 | 91 | (deftest test-mget-regressions 92 | (is (== 3 (mget (mset (zero-array [4 4]) 0 2 3) 0 2))) 93 | (is (== 3 (mget (mset (zero-array [4]) 2 3) 2))) 94 | (is (== 3 (mget (mset (zero-array []) 3))))) 95 | 96 | (deftest test-scalar-arrays 97 | (is (equals 0 (new-scalar-array :vectorz))) 98 | (is (equals 3 (scalar-array 3))) 99 | (is (equals 2 (add 1 (array 1)))) 100 | (is (equals [2 3] (add 1 (array [1 2])))) 101 | (is (equals [2 3] (add (scalar-array 1) (array [1 2]))))) 102 | 103 | (deftest test-symmetric? 104 | (is (symmetric? (array [[1 2] [2 3]]))) 105 | (is (not (symmetric? (array [[1 2] [3 4]])))) 106 | (is (symmetric? (array [1 2 3]))) 107 | (is (symmetric? (array [[[1 2] [0 0]] [[2 0] [0 1]]]))) 108 | (is (symmetric? (array [[[1 2] [3 0]] [[2 0] [0 1]]]))) 109 | (is (not (symmetric? (array [[[1 2] [0 0]] [[3 0] [0 1]]]))))) 110 | 111 | (deftest test-broadcasting-cases 112 | (is (equals [[2 3] [4 5]] (add (array [[1 2] [3 4]]) (array [1 1])))) 113 | (is (equals [[2 3] [4 5]] (add (array [1 1]) (array [[1 2] [3 4]])))) 114 | (is (equals [[2 4] [6 8]] (mul (array [[1 2] [3 4]]) (scalar-array 2)))) 115 | (is (equals [[2 6] [6 12]] (mul (array [[1 2] [3 4]]) [2 3]))) 116 | (is (equals [[1 4] [3 16]] (pow (array [[1 2] [3 4]]) [1 2])))) 117 | 118 | (deftest test-broadcasts 119 | (is (equals [[2 2] [2 2]] (broadcast 2 [2 2]))) 120 | (is (not (equals [[2 2] [2 2]] (broadcast 2 [2]))))) 121 | 122 | (deftest test-scalar-add 123 | (is (equals [2 3 4] (add 1 (array [1 2 3])))) 124 | (is (equals [2 3 4] (add (array [1 2 3]) 1 0)))) 125 | 126 | (deftest test-ecount 127 | (is (== 1 (ecount (Scalar. 
10)))) 128 | (is (== 2 (ecount (v/of 1 2)))) 129 | (is (== 0 (ecount (Vector/of (double-array 0))))) 130 | (is (== 0 (count (eseq (Vector/of (double-array 0)))))) 131 | (is (== 0 (ecount (coerce :vectorz [])))) 132 | (is (== 4 (ecount (coerce :vectorz [[1 2] [3 4]])))) 133 | (is (== 8 (ecount (coerce :vectorz [[[1 2] [3 4]] [[1 2] [3 4]]]))))) 134 | 135 | (deftest test-mutability 136 | (let [v (v/of 1 2)] 137 | (is (mutable? v)) 138 | (is (mutable? (first (slice-views v))))) 139 | (let [v (new-array [3 4 5 6])] 140 | (is (v/vectorz? v)) 141 | (is (mutable? v)) 142 | (is (mutable? (first (slice-views v)))))) 143 | 144 | (deftest test-new-array 145 | (is (instance? AVector (new-array [10]))) 146 | (is (instance? AMatrix (new-array [10 10]))) 147 | (is (instance? INDArray (new-array [3 4 5 6])))) 148 | 149 | (deftest test-sub 150 | (let [a (v/vec [1 2 3 0 0]) 151 | b (v/vec [1 1 4 0 0])] 152 | (is (equals [0 1 -1 0 0] (sub a b))))) 153 | 154 | (deftest test-add-product 155 | (let [a (v/vec [1 2]) 156 | b (v/vec [1 1])] 157 | (is (equals [2 5] (add-product b a a))) 158 | (is (equals [3 9] (add-scaled-product b a [1 2] 2))) 159 | (is (equals [11 21] (add-product b a 10))) 160 | (is (equals [11 21] (add-product b 10 a))))) 161 | 162 | (deftest test-add-product! 163 | (let [a (v/vec [1 2]) 164 | b (v/vec [1 1])] 165 | (add-product! b a a) 166 | (is (equals [2 5] b)) 167 | (add-scaled! b a -1) 168 | (is (equals [1 3] b)) 169 | (add-scaled-product! b [0 1] [3 4] 2) 170 | (is (equals [1 11] b)))) 171 | 172 | (deftest test-coerce 173 | (is (equals (array [1 2]) (coerce :vectorz [1 2]))) 174 | (is (equals (array [[1 2] [3 4]]) (coerce :vectorz [[1 2] [3 4]]))) 175 | (let [a (v/vec [1 2 3 0 0]) 176 | b (v/vec [1 1 4 0 0]) 177 | r (sub a b)] 178 | (is (equals [0 1 -1 0 0] (coerce [] r))) 179 | (is (instance? clojure.lang.IPersistentVector (coerce [] r))) 180 | ;; (is (instance? INDArray (coerce :vectorz 10.0))) ;; TODO: what should this be?? 
181 | )) 182 | 183 | (deftest test-ndarray 184 | (is (equals [[[1]]] (matrix :vectorz [[[1]]]))) 185 | (is (equals [[[[1]]]] (matrix :vectorz [[[[1]]]]))) 186 | (is (equals [[[1]]] (slice (matrix :vectorz [[[[1]]]]) 0))) 187 | (is (== 4 (dimensionality (matrix :vectorz [[[[1]]]])))) 188 | (is (equals [[[1]]] (wrap/wrap-slice (matrix :vectorz [[[[1]]]]) 0))) 189 | (is (equals [[[[1]]]] (wrap/wrap-nd (matrix :vectorz [[[[1]]]]))))) 190 | 191 | (deftest test-element-equality 192 | (is (e= (matrix :vectorz [[0.5 0] [0 2]]) 193 | [[0.5 0.0] [0.0 2.0]])) 194 | ;; TODO: enable this test once fixed version of core.matrix is released 195 | ;; (is (not (e= (matrix :vectorz [[1 2] [3 4]]) 196 | ;; [[5 6] [7 8]]))) 197 | ) 198 | 199 | (deftest test-inverse 200 | (let [m (matrix :vectorz [[0.5 0] [0 2]])] 201 | (is (equals [[2 0] [0 0.5]] (inverse m))))) 202 | 203 | (deftest test-det 204 | (is (== -1.0 (det (matrix :vectorz [[0 1] [1 0]]))))) 205 | 206 | (defn test-round-trip [m] 207 | (is (equals m (read-string (str m)))) 208 | ;; TODO edn round-tripping? 
209 | ) 210 | 211 | (deftest test-round-trips 212 | (test-round-trip (v/of 1 2)) 213 | (test-round-trip (v/of 1 2 3 4 5)) 214 | (test-round-trip (matrix :vectorz [[1 2 3] [4 5 6]])) 215 | (test-round-trip (matrix :vectorz [[1 2] [3 4]])) 216 | (test-round-trip (first (slices (v/of 1 2 3)))) 217 | ) 218 | 219 | (deftest test-equals 220 | (is (equals (v/of 1 2) [1 2]))) 221 | 222 | (deftest test-vector-ops 223 | (testing "addition" 224 | (is (= (v/of 1 2) (+ (v/of 1 1) [0 1]))) 225 | (is (= (v/of 3 4) (+ (v/of 1 1) (v/of 2 3)))) 226 | (is (= [1.0 2.0] (+ [0 2] (v/of 1 0))))) 227 | 228 | (testing "scaling" 229 | (is (= (v/of 2 4) (* (v/of 1 2) 2))) 230 | (is (= (v/of 2 4) (scale (v/of 1 2) 2))) 231 | (is (= (v/of 2 4) (scale (v/of 1 2) 2N))) 232 | (is (= (v/of 2 4) (scale (v/of 1 2) 2.0)))) 233 | 234 | (testing "subtraction" 235 | (is (= (v/of 2 4) (- (v/of 3 5) [1 1]))) 236 | (is (= (v/of 1 2) (- (v/of 2 3) (v/of 1 0) (v/of 0 1)))))) 237 | 238 | (deftest test-matrix-ops 239 | (testing "addition" 240 | (is (= (m/matrix [[2 2] [2 2]]) (+ (m/matrix [[1 1] [2 0]]) 241 | (m/matrix [[1 1] [0 2]])))) 242 | (is (= (m/matrix [[2 2] [2 2]]) (+ (m/matrix [[1 1] [2 0]]) 243 | [[1 1] [0 2]]))) 244 | (is (= [[2.0 2.0] [2.0 2.0]] (+ [[1 1] [0 2]] 245 | (m/matrix [[1 1] [2 0]]))))) 246 | (testing "scaling" 247 | (is (= (m/matrix [[2 2] [2 2]]) (scale (m/matrix [[1 1] [1 1]]) 2)))) 248 | 249 | (testing "multiplication" 250 | (is (= (m/matrix [[8]]) (mmul (m/matrix [[2 2]]) (m/matrix [[2] [2]])))) 251 | (is (= (m/matrix [[8]]) (mmul (m/matrix [[2 2]]) [[2] [2]]))) 252 | ;; (is (= [[8.0]] (* [[2 2]] (m/matrix [[2] [2]])))) 253 | )) 254 | 255 | (deftest test-join 256 | (is (= (array [[[1]] [[2]]]) (join (array [[[1]]]) (array [[[2]]]))))) 257 | 258 | (deftest test-pm 259 | (is (string? 
(clojure.core.matrix.impl.pprint/pm (array :vectorz [1 2]))))) 260 | 261 | (deftest test-matrix-transform 262 | (testing "vector multiple" 263 | (is (= (v/of 2 4) (mmul (m/matrix [[2 0] [0 2]]) (v/of 1 2)))) 264 | (is (= (v/of 2 4) (mmul (m/scalar-matrix 2 2.0) (v/of 1 2)))) 265 | (is (= (v/of 2 4) (mmul (m/scalar-matrix 2 2.0) [1 2])))) 266 | (testing "persistent vector transform" 267 | (is (= (v/of 1 2) (transform (m/identity-matrix 2) [1 2])))) 268 | (testing "transform in place" 269 | (let [v (matrix [1 2]) 270 | m (matrix [[2 0] [0 2]])] 271 | (transform! m v) 272 | (is (= (v/of 2 4) v))))) 273 | 274 | (deftest test-slices 275 | (testing "slice row and column from matrix" 276 | (is (equals [1 2] (first (slices (matrix [[1 2] [3 4]]))))) 277 | (is (equals [3 4] (second (slices (matrix [[1 2] [3 4]]))))) 278 | (is (equals [3 4] (slice (matrix [[1 2] [3 4]]) 0 1))) 279 | (is (equals [2 4] (slice (matrix [[1 2] [3 4]]) 1 1)))) 280 | (testing "slices of vector" 281 | (is (equals '(1.0 2.0 3.0) (slices (matrix [1 2 3])))))) 282 | 283 | ;; verify scalar operators should still work on numbers! 
284 | (deftest test-scalar-operators 285 | (testing "addition" 286 | (is (== 2.0 (+ 1.0 1.0))) 287 | (is (== 3 (+ 1 2)))) 288 | (testing "multiplication" 289 | (is (== 2.0 (* 4 0.5))) 290 | (is (== 6 (* 1 2 3)))) 291 | (testing "subtraction" 292 | (is (== 2.0 (- 4 2.0))) 293 | (is (== 6 (- 10 2 2))))) 294 | 295 | (deftest test-compare 296 | (testing "eif" 297 | (is (= (eif (array [1 0 0]) (array [1 2 3]) (array [4 5 6])) (array [1 5 6])))) 298 | (testing "lt" 299 | (is (= (lt (array [0 2 -1 2]) 0) (array [0 0 1 0]))) 300 | (is (= (lt (array [0 2 -1 2]) (array [1 2 3 4])) (array [1 0 1 1])))) 301 | (testing "le" 302 | (is (= (le (array [0 2 -1 2]) 0) (array [1 0 1 0]))) 303 | (is (= (le (array [0 2 -1 2]) (array [1 2 3 4])) (array [1 1 1 1])))) 304 | (testing "gt" 305 | (is (= (gt (array [-1 2 0 4]) 0) (array [0 1 0 1]))) 306 | (is (= (gt (array [-1 2 0 4]) (array [1 2 3 4])) (array [0 0 0 0])))) 307 | (testing "ge" 308 | (is (= (ge (array [-1 2 0 4]) 0) (array [0 1 1 1]))) 309 | (is (= (ge (array [-1 2 0 4]) (array [1 2 3 4])) (array [0 1 0 1])))) 310 | (testing "ne" 311 | (is (= (ne (array [-1 2 0 4]) 0) (array [1 1 0 1]))) 312 | (is (= (ne (array [-1 2 0 4]) (array [1 2 3 4])) (array [1 0 1 0])))) 313 | (testing "eq" 314 | (is (= (eq (array [-1 2 0 4]) 0) (array [0 0 1 0]))) 315 | (is (= (eq (array [-1 2 0 4]) (array [1 2 3 4])) (array [0 1 0 1]))))) 316 | 317 | (deftest test-construction 318 | (testing "1D" 319 | (is (= (v/of 1.0) (matrix [1]))) 320 | (is (instance? AVector (matrix [1])))) 321 | (testing "2D" 322 | (is (= (m/matrix [[1 2] [3 4]]) (matrix [[1 2] [3 4]]))) 323 | (is (instance?
AMatrix (matrix [[1]]))))) 324 | 325 | (deftest test-conversion 326 | (testing "vector" 327 | (is (= [1.0] (to-nested-vectors (v/of 1.0)))) 328 | (is (= [1.0] (coerce [] (v/of 1.0))))) 329 | (testing "matrix" 330 | (is (= [[1.0]] (to-nested-vectors (m/matrix [[1.0]]))))) 331 | (testing "coercion" 332 | (is (equals [[1 2] [3 4]] (coerce (m/matrix [[1.0]]) [[1 2] [3 4]]))) 333 | (is (number? (coerce :vectorz 10))) 334 | (is (instance? AVector (coerce :vectorz [1 2 3]))) 335 | (is (instance? AMatrix (coerce :vectorz [[1 2] [3 4]]))))) 336 | 337 | (deftest test-functional-ops 338 | (testing "eseq" 339 | (is (= [1.0 2.0 3.0 4.0] (eseq (matrix [[1 2] [3 4]])))) 340 | (is (empty? (eseq (coerce :vectorz [])))) 341 | (is (= [10.0] (eseq (array :vectorz 10)))) 342 | (is (= [10.0] (eseq (array :vectorz [[[10]]])))) 343 | (is (== 1 (first (eseq (v/of 1 2)))))) 344 | (testing "emap" 345 | (is (equals [1 2] (emap inc (v/of 0 1)))) 346 | (is (equals [1 3] (emap + (v/of 0 1) [1 2]))) 347 | ;; (is (equals [2 3] (emap + (v/of 0 1) 2))) shouldn't work - no broadcast support in emap?
348 | (is (equals [3 6] (emap + (v/of 0 1) [1 2] (v/of 2 3))))) 349 | (testing "long args" 350 | ;; TODO: fix in core.matrix 0.15.0 351 | ; (is (equals [10] (emap + 352 | ; (v/of 1) 353 | ; [2] 354 | ; (array :vectorz [3]) 355 | ; (broadcast 4 [1])))) 356 | (is (equals [10] (emap + (array :vectorz [1]) [2] [3] [4]))) 357 | (is (equals 10 (ereduce + (array :vectorz [[1 2] [3 4]])))))) 358 | 359 | (deftest test-compute-array ;; expected values are the sum of each element's indices 360 | (is (equals [[0 1] [1 2]] (compute-matrix :vectorz [2 2] +))) 361 | (is (equals [[[0 1] [1 2]][[1 2][2 3]]] (compute-matrix [2 2 2] +)))) 362 | 363 | (deftest test-maths-functions 364 | (testing "abs" 365 | (is (equals [1 2 3] (abs [-1 2 -3]))) 366 | (is (equals [1 2 3] (abs (v/of -1 2 -3)))))) 367 | 368 | (deftest test-assign ;; includes assign! through a subvector view, which must write into the parent array 369 | (is (e== [2 2] (assign (v/of 1 2) 2))) 370 | (let [m (array :vectorz [1 2 3 4 5 6])] 371 | (is (e== [1 2 3] (subvector m 0 3))) 372 | (is (e== [4 5 6] (subvector m 3 3))) 373 | (assign! (subvector m 0 3) (subvector m 3 3)) 374 | (is (e== [4 5 6 4 5 6] m))) 375 | (testing "mutable assign" 376 | (let [a (array [[1 2] [3 4]])] 377 | (assign! a [0 1]) 378 | (is (equals [[0 1] [0 1]] a))))) 379 | 380 | ;; vectorz operations should return a vectorz datatype 381 | (deftest test-vectorz-results 382 | (is (v/vectorz? (+ (v/of 1 2) [1 2]))) 383 | (is (v/vectorz? (+ (v/of 1 2) 1))) 384 | (is (v/vectorz? (- 2 (v/of 1 2)))) 385 | (is (v/vectorz? (* (v/of 1 2) 2.0))) 386 | (is (v/vectorz? (emap inc (v/of 1 2)))) 387 | (is (v/vectorz? (array [[[1]]]))) 388 | (is (v/vectorz? (to-vector (array [[[1]]])))) 389 | (is (v/vectorz? (identity-matrix 3))) 390 | (is (v/vectorz? (reshape (identity-matrix 3) [5 1]))) 391 | (is (v/vectorz? (slice (identity-matrix 3) 1))) 392 | (is (v/vectorz? (* (identity-matrix 3) [1 2 3]))) 393 | (is (v/vectorz? (inner-product (v/of 1 2) [1 2]))) 394 | (is (v/vectorz? (outer-product (v/of 1 2) [1 2]))) 395 | (is (v/vectorz? (add! (Scalar. 1.0) 10)))) 396 | 397 | (deftest test-shift ;; shifting [1 2] by 1 gives [2 0]: elements move towards index 0 and the vacated slot is zero 398 | (is (v/vectorz?
(shift (v/of 1 2) [1]))) 399 | (is (equals [2 0] (shift (v/of 1 2) [1])))) 400 | 401 | (deftest test-defensive-copy-on-double-array 402 | (let [a (double-array [1 2 3 4 5]) 403 | v (array a)] 404 | (aset-double a 4 9999) 405 | (is (equals v [1 2 3 4 5])))) 406 | 407 | (deftest test-validate-shape 408 | (is (equals [2] (mp/validate-shape (v/of 1 2))))) 409 | 410 | (deftest test-add-inner-product! 411 | (let [m (array :vectorz [1 2]) 412 | a (array :vectorz [[0 2] [1 0]]) 413 | b (array :vectorz [10 100])] 414 | (add-inner-product! m a b) 415 | (is (equals [201 12] m)) 416 | (add-inner-product! m a b -1) 417 | (is (equals [1 2] m))) 418 | (is (equals [101 102] (add-inner-product! (array :vectorz [1 2]) 10 10))) 419 | (is (equals [101 102] (add-inner-product! (array :vectorz [1 2]) [1 2 3] [1 3 1] 10)))) 420 | 421 | (deftest test-add-outer-product! 422 | (let [m (array :vectorz [[1 2] [3 4]]) 423 | a (array :vectorz [10 100]) 424 | b (array :vectorz [7 9])] 425 | (add-outer-product! m a b) 426 | (is (equals [[71 92] [703 904]] m)) 427 | (add-outer-product! m a b -1) 428 | (is (equals [[1 2] [3 4]] m))) 429 | (is (equals [11 32] (add-outer-product! (array :vectorz [1 2]) [1 3] 10))) 430 | (is (equals [11 32] (add-outer-product! (array :vectorz [1 2]) 10 [1 3]))) 431 | (is (equals [101 302] (add-outer-product! (array :vectorz [1 2]) 10 [1 3] 10)))) 432 | 433 | (deftest test-logistic 434 | (is (equals [0 0.5 1] (logistic (array :vectorz [-1000 0 1000]))))) 435 | 436 | (deftest test-select-regression 437 | (let [m (new-matrix 3 4) 438 | col (select m :all 1)] 439 | (assign! col [3 4 5]) 440 | (is (equals [[0 3 0 0] [0 4 0 0] [0 5 0 0]] m)))) 441 | 442 | (deftest test-array-add-product-regression 443 | (let [a (array :vectorz [[[1]]])] 444 | (is (equals [[[61]]] (add-scaled-product a [2] [3] 10))) 445 | (is (equals [[[1]]] a))) 446 | (let [a (array :vectorz [[[1]]])] 447 | (is (equals [[[61]]] (add-scaled-product!
a [2] [3] 10))) 448 | (is (equals [[[61]]] a)))) 449 | 450 | ;; regression test for #54 451 | (deftest test-diagonal-inverse-regression ;; inverting a diagonal-matrix must agree with inverting the equivalent dense matrix 452 | (is (equals (inverse [[1 0] [0 2]]) (inverse (diagonal-matrix [1 2]))))) 453 | 454 | ;; run compliance tests 455 | 456 | (deftest instance-tests ;; core.matrix instance-test on scalar, vector, view, matrix, higher-dimensional array and Index instances 457 | (clojure.core.matrix.compliance-tester/instance-test (Scalar. 2.0)) 458 | (clojure.core.matrix.compliance-tester/instance-test (v/of 1 2)) 459 | (clojure.core.matrix.compliance-tester/instance-test (v/of 1 2 3)) 460 | (clojure.core.matrix.compliance-tester/instance-test (v/of 1 2 3 4 5 6 7)) 461 | (clojure.core.matrix.compliance-tester/instance-test (subvector (v/of 1 2 3 4 5 6 7) 2 3)) 462 | (clojure.core.matrix.compliance-tester/instance-test (matrix :vectorz [[[1 2] [3 4]] [[5 6] [7 8]]])) 463 | (clojure.core.matrix.compliance-tester/instance-test (clone (first (slices (v/of 1 2 3))))) 464 | (clojure.core.matrix.compliance-tester/instance-test (first (slices (v/of 1 2 3)))) 465 | ;; (clojure.core.matrix.compliance-tester/instance-test (Vector/of (double-array 0))) ;; TODO: needs fixed compliance tests 466 | (clojure.core.matrix.compliance-tester/instance-test (first (slices (v/of 1 2 3 4 5 6)))) 467 | (clojure.core.matrix.compliance-tester/instance-test (array :vectorz [[1 2] [3 4]])) 468 | (clojure.core.matrix.compliance-tester/instance-test (array :vectorz [[[[4]]]])) 469 | (clojure.core.matrix.compliance-tester/instance-test (Array/create (array :vectorz [[[[4 3]]]]))) 470 | (clojure.core.matrix.compliance-tester/instance-test (Index/of (int-array [1 2 3])))) 471 | 472 | (deftest compliance-test ;; presumably the full core.matrix implementation compliance suite, using (v/of 1 2) as the sample array - confirm 473 | (clojure.core.matrix.compliance-tester/compliance-test (v/of 1 2))) 474 | -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/test_ops.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.test-ops ;; emap, add-emap! and set-emap! with mikera.vectorz.Ops function objects 2 | (:use clojure.test) 3 | (:use clojure.core.matrix) 4 |
(:import [mikera.vectorz AVector Vectorz Vector] 5 | [mikera.vectorz Ops])) 6 | 7 | (set! *warn-on-reflection* true) 8 | (set! *unchecked-math* :warn-on-boxed) 9 | 10 | (deftest test-vector-ops 11 | (testing "Vectorz Op" 12 | (is (equals [1 2] (emap Ops/ABS (array :vectorz [1 -2])))))) 13 | 14 | (deftest test-add-emap 15 | (testing "Vectorz Op" 16 | (let [dest (array :vectorz [10 100])] 17 | (is (equals [14 106] (add-emap! dest Ops/ADD (array :vectorz [1 2]) [3 4])))))) 18 | 19 | (deftest test-set-emap 20 | (testing "Vectorz Op" 21 | (let [dest (array :vectorz [10 100])] 22 | (is (equals [4 6] (set-emap! dest Ops/ADD (array :vectorz [1 2]) [3 4])))))) 23 | -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/test_properties.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.test-properties 2 | (:use clojure.core.matrix) 3 | (:require [clojure.test.check :as sc] 4 | [clojure.test.check.generators :as gen :refer (sample)] 5 | [clojure.core.matrix.generators :as genm] 6 | [clojure.core.matrix.compliance-tester :as ctest] 7 | [clojure.test.check.properties :as prop] 8 | [clojure.test.check.clojure-test :as ct :refer (defspec)]) 9 | (:import [mikera.vectorz AVector Vectorz Vector])) 10 | 11 | (set! *warn-on-reflection* true) 12 | (set!
*unchecked-math* :warn-on-boxed) 13 | 14 | (set-current-implementation :vectorz) 15 | 16 | ;; ==================================================== 17 | ;; vectorz-specific generator functions 18 | 19 | (def gen-vectorz-arrays (genm/gen-array (genm/gen-shape) genm/gen-double (gen/return :vectorz))) ;; arrays of generated shape filled with generated doubles, built with the :vectorz implementation 20 | 21 | 22 | ;; ==================================================== 23 | ;; property based tests 24 | 25 | (defspec generative-instance-tests 20 ;; core.matrix instance-test on each generated array 26 | (prop/for-all [v gen-vectorz-arrays] 27 | (ctest/instance-test v))) 28 | 29 | (defspec add-test 20 ;; property: (mul v 2) equals (add v v) 30 | (prop/for-all [v gen-vectorz-arrays] 31 | (equals (mul v 2) (add v v)))) 32 | 33 | (defspec clone-test 20 ;; property: a clone equals the original array 34 | (prop/for-all [v gen-vectorz-arrays] 35 | (equals v (clone v)))) -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/test_readers.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.test-readers ;; generated vectorz arrays must survive a pr-str / read-string round trip 2 | (:use clojure.core.matrix) 3 | (:use clojure.test) 4 | (:require [clojure.test.check :as sc] 5 | [clojure.test.check.generators :as gen :refer (sample)] 6 | [clojure.core.matrix.generators :as genm] 7 | [clojure.core.matrix.compliance-tester :as ctest] 8 | [clojure.test.check.properties :as prop] 9 | [clojure.test.check.clojure-test :as ct :refer (defspec)])) 10 | 11 | (set! *warn-on-reflection* true) 12 | (set!
*unchecked-math* :warn-on-boxed) 13 | 14 | (set-current-implementation :vectorz) 15 | 16 | (def gen-vectorz-arrays (genm/gen-array (genm/gen-shape) genm/gen-double (gen/return :vectorz))) 17 | 18 | (defspec reader-round-trip 100 19 | (prop/for-all [v gen-vectorz-arrays] 20 | (is (equals v (read-string (pr-str v)))))) 21 | 22 | -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/test_sparse.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.test-sparse 2 | (:use [clojure test]) 3 | (:use clojure.core.matrix) 4 | (:require clojure.core.matrix.compliance-tester) 5 | (:require [clojure.core.matrix.protocols :as mp]) 6 | (:require [clojure.core.matrix.linear :as li]) 7 | (:require [mikera.vectorz.matrix-api]) 8 | (:import [mikera.matrixx AMatrix Matrixx Matrix]) 9 | (:import [mikera.vectorz Scalar]) 10 | (:import [mikera.indexz AIndex Index]) 11 | (:import [mikera.vectorz AVector Vectorz Vector]) 12 | (:import [mikera.arrayz INDArray Array NDArray])) 13 | 14 | (set! *warn-on-reflection* true) 15 | (set! *unchecked-math* :warn-on-boxed) 16 | 17 | (deftest test-new-sparse-array 18 | (let [s [1000 1000] 19 | a (new-sparse-array :vectorz s)] 20 | (is (instance? INDArray a)) 21 | (is (= s (shape a))))) 22 | 23 | (deftest test-sparse-assign 24 | (let [pm [[6 7] [8 9]] 25 | sm (sparse :vectorz (matrix :vectorz [[1 2] [3 4]]))] 26 | (is (instance? INDArray sm)) 27 | (is (sparse? sm)) 28 | (assign! sm pm) 29 | (is (== 30 (esum sm))) 30 | (assign! sm 2) 31 | (is (== 8 (esum sm))))) 32 | 33 | (deftest test-sparse-assign-double 34 | (let [pm [[6 7] [8 9]] 35 | sm (sparse :vectorz pm)] 36 | (is (instance? INDArray sm)) 37 | (is (sparse? sm)) 38 | (assign!
sm 1) 39 | (is (== 4 (esum sm))))) 40 | 41 | (deftest test-non-zero-indices 42 | (let [pm [[2 0] [0 1]] 43 | sm (sparse (matrix :vectorz pm))] 44 | (is (equals [0] (first (non-zero-indices sm)))) 45 | (is (equals [1] (second (non-zero-indices sm)))))) 46 | 47 | (deftest test-ops 48 | (let [pm [[2 0] [0 1]] 49 | sm (sparse (matrix :vectorz pm))] 50 | (is (equals [[4 0] [ 0 2]] (div! sm 0.5))) 51 | (is (equals pm (mul! sm 0.5))))) -------------------------------------------------------------------------------- /src/test/clojure/mikera/vectorz/test_stats.clj: -------------------------------------------------------------------------------- 1 | (ns mikera.vectorz.test-stats 2 | (:use clojure.core.matrix) 3 | (:use clojure.core.matrix.stats) 4 | (:use clojure.test) 5 | (:require [mikera.vectorz.core :as v])) 6 | 7 | (set! *warn-on-reflection* true) 8 | (set! *unchecked-math* :warn-on-boxed) 9 | 10 | (deftest test-mean 11 | (let [vs (map v/vec [[1 2] [3 4] [5 6] [7 8]])] 12 | (is (e== [4 5] (mean vs))))) 13 | 14 | -------------------------------------------------------------------------------- /src/test/clojure/test/misc/loading.clj: -------------------------------------------------------------------------------- 1 | (ns test.misc.loading 2 | (:use [clojure.core.matrix]) 3 | (:use [clojure.core.matrix.utils])) 4 | 5 | (defn foo [] 6 | (doall (map deref (for [i (range 10)] (future (matrix :vectorz [0 1])))))) -------------------------------------------------------------------------------- /src/test/java/mikera/vectorz/ClojureTests.java: -------------------------------------------------------------------------------- 1 | package mikera.vectorz; 2 | 3 | import mikera.cljunit.ClojureTest; 4 | 5 | public class ClojureTests extends ClojureTest { // presumably the JUnit entry point that lets cljunit run the Clojure test namespaces - confirm 6 | 7 | @Override 8 | public String filter() { 9 | return "mikera.vectorz"; // presumably restricts the run to namespaces under mikera.vectorz - confirm against cljunit 10 | } 11 | 12 | } 13 | --------------------------------------------------------------------------------