├── .ghci ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── Setup.hs ├── arrayfire.cabal ├── cabal.project ├── cbits └── wrapper.c ├── default.nix ├── doctests └── Main.hs ├── exe └── Main.hs ├── flake.lock ├── flake.nix ├── gen ├── Lex.hs ├── Main.hs ├── Parse.hs ├── Print.hs └── Types.hs ├── include ├── algorithm.h ├── arith.h ├── array.h ├── backend.h ├── blas.h ├── cuda.h ├── data.h ├── defines.h ├── device.h ├── exception.h ├── features.h ├── graphics.h ├── image.h ├── index.h ├── internal.h ├── lapack.h ├── openCL.h ├── random.h ├── seq.h ├── signal.h ├── sparse.h ├── statistics.h ├── util.h └── vision.h ├── nix ├── default.nix └── no-download.patch ├── pkg.nix ├── shell.nix ├── src ├── ArrayFire.hs └── ArrayFire │ ├── Algorithm.hs │ ├── Arith.hs │ ├── Array.hs │ ├── BLAS.hs │ ├── Backend.hs │ ├── Data.hs │ ├── Device.hs │ ├── Exception.hs │ ├── FFI.hs │ ├── Features.hs │ ├── Graphics.hs │ ├── Image.hs │ ├── Index.hs │ ├── Internal │ ├── Algorithm.hsc │ ├── Arith.hsc │ ├── Array.hsc │ ├── BLAS.hsc │ ├── Backend.hsc │ ├── CUDA.hsc │ ├── Data.hsc │ ├── Defines.hsc │ ├── Device.hsc │ ├── Exception.hsc │ ├── Features.hsc │ ├── Graphics.hsc │ ├── Image.hsc │ ├── Index.hsc │ ├── Internal.hsc │ ├── LAPACK.hsc │ ├── OpenCL.hsc │ ├── Random.hsc │ ├── Seq.hsc │ ├── Signal.hsc │ ├── Sparse.hsc │ ├── Statistics.hsc │ ├── Types.hsc │ ├── Util.hsc │ └── Vision.hsc │ ├── LAPACK.hs │ ├── Orphans.hs │ ├── Random.hs │ ├── Signal.hs │ ├── Sparse.hs │ ├── Statistics.hs │ ├── Types.hs │ ├── Util.hs │ └── Vision.hs ├── stack.yaml ├── stack.yaml.lock └── test ├── ArrayFire ├── AlgorithmSpec.hs ├── ArithSpec.hs ├── ArraySpec.hs ├── BLASSpec.hs ├── BackendSpec.hs ├── DataSpec.hs ├── DeviceSpec.hs ├── FeaturesSpec.hs ├── GraphicsSpec.hs ├── ImageSpec.hs ├── IndexSpec.hs ├── LAPACKSpec.hs ├── RandomSpec.hs ├── SignalSpec.hs ├── SparseSpec.hs ├── StatisticsSpec.hs ├── UtilSpec.hs └── VisionSpec.hs ├── Main.hs ├── Spec.hs └── Test └── Hspec └── ApproxExpect.hs /.ghci: 
-------------------------------------------------------------------------------- 1 | :set prompt "\x03BB> " 2 | :seti -XTypeApplications 3 | :load ArrayFire 4 | :m - ArrayFire 5 | import qualified ArrayFire as A 6 | :set -laf 7 | :set -isrc:test 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dist* 2 | result/ 3 | /TAGS 4 | /result 5 | *~ 6 | /ctags 7 | cabal.project.local 8 | tags 9 | /.stack-work/ 10 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Revision history for fire 2 | 3 | ## 0.1.0.0 -- YYYY-mm-dd 4 | 5 | * First version. Released on an unsuspecting world. 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, David Johnson 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above 12 | copyright notice, this list of conditions and the following 13 | disclaimer in the documentation and/or other materials provided 14 | with the distribution. 15 | 16 | * Neither the name of David Johnson nor the names of other 17 | contributors may be used to endorse or promote products derived 18 | from this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 2 | `ArrayFire` is a general-purpose library that simplifies the process of developing software that targets parallel and massively-parallel architectures including CPUs, GPUs, and other hardware acceleration devices. 3 | 4 | `arrayfire-haskell` is a [Haskell](https://haskell.org) binding to [ArrayFire](https://arrayfire.com). 5 | 6 | ## Table of Contents 7 | - [Installation](#Installation) 8 | - [Haskell Installation](#haskell-installation) 9 | - [Documentation](#Documentation) 10 | - [Hacking](#Hacking) 11 | - [Example](#Example) 12 | 13 | 14 | ## Installation 15 | Install `ArrayFire` via the download page. 16 | - https://arrayfire.com/download/ 17 | 18 | `ArrayFire` can also be fetched from [nixpkgs](https://github.com/nixos/nixpkgs) `master`. 19 | 20 | ### Haskell Installation 21 | 22 | `arrayfire` can be installed w/ `cabal`, `stack` or `nix`. 
23 | 24 | ``` 25 | cabal install arrayfire 26 | ``` 27 | 28 | ``` 29 | stack install arrayfire 30 | ``` 31 | 32 | 33 | Also note, if you plan on using ArrayFire's visualization features, you must install `fontconfig` and `glfw` on OSX or Linux. 34 | 35 | ## Documentation 36 | - [Hackage](http://hackage.haskell.org/package/arrayfire) 37 | - [ArrayFire](http://arrayfire.org/docs/gettingstarted.htm) 38 | 39 | ## Hacking 40 | To hack on this library locally, complete the installation step above. We recommend installing the [nix](https://nixos.org/nix/download.html) package manager to facilitate development. 41 | 42 | After the above tools are installed, clone the source from Github. 43 | 44 | ```bash 45 | git clone git@github.com:arrayfire/arrayfire-haskell.git 46 | cd arrayfire-haskell 47 | ``` 48 | 49 | To build and run all tests in response to file changes 50 | 51 | ```bash 52 | nix-shell --run test-runner 53 | ``` 54 | 55 | To perform interactive development w/ `ghcid` 56 | 57 | ```bash 58 | nix-shell --run ghcid 59 | ``` 60 | 61 | To interactively evaluate code in the `repl` 62 | 63 | ```bash 64 | nix-shell --run repl 65 | ``` 66 | 67 | To produce the haddocks and open them in a browser 68 | 69 | ```bash 70 | nix-shell --run docs 71 | ``` 72 | 73 | 74 | ## Example 75 | ```haskell 76 | {-# LANGUAGE TypeApplications, ScopedTypeVariables #-} 77 | module Main where 78 | 79 | import qualified ArrayFire as A 80 | import Control.Exception (catch) 81 | 82 | main :: IO () 83 | main = print newArray `catch` (\(e :: A.AFException) -> print e) 84 | where 85 | newArray = A.matrix @Double (2,2) [ [1..], [1..] ] * A.matrix @Double (2,2) [ [2..], [2..] 
] 86 | 87 | {-| 88 | 89 | ArrayFire Array 90 | [2 2 1 1] 91 | 2.0000 6.0000 92 | 2.0000 6.0000 93 | 94 | -} 95 | ``` 96 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | module Main where 2 | 3 | import Distribution.Extra.Doctest (defaultMainWithDoctests) 4 | 5 | main :: IO () 6 | main = defaultMainWithDoctests "doctests" 7 | -------------------------------------------------------------------------------- /arrayfire.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 3.0 2 | name: arrayfire 3 | version: 0.7.0.0 4 | synopsis: Haskell bindings to the ArrayFire general-purpose GPU library 5 | homepage: https://github.com/arrayfire/arrayfire-haskell 6 | license: BSD-3-Clause 7 | license-file: LICENSE 8 | author: David Johnson 9 | maintainer: code@dmj.io 10 | copyright: David Johnson (c) 2018-2025 11 | category: Math 12 | build-type: Custom 13 | extra-source-files: CHANGELOG.md 14 | description: High-level Haskell bindings to the ArrayFire General-purpose GPU library 15 | . 16 | <> 17 | . 18 | 19 | flag disable-default-paths 20 | description: When enabled, don't add default hardcoded include/link dirs by default. Needed for hermetic builds like in nix. 21 | default: False 22 | manual: True 23 | 24 | flag disable-build-tool-depends 25 | description: When enabled, don't add build-tool-depends fields to the Cabal file. Needed for working inside @nix develop@. 
26 | default: False 27 | manual: True 28 | 29 | custom-setup 30 | setup-depends: 31 | base <5, 32 | Cabal, 33 | cabal-doctest >=1 && <1.1 34 | 35 | library 36 | exposed-modules: 37 | ArrayFire 38 | ArrayFire.Algorithm 39 | ArrayFire.Arith 40 | ArrayFire.Array 41 | ArrayFire.Backend 42 | ArrayFire.BLAS 43 | ArrayFire.Data 44 | ArrayFire.Device 45 | ArrayFire.Features 46 | ArrayFire.Graphics 47 | ArrayFire.Image 48 | ArrayFire.Index 49 | ArrayFire.LAPACK 50 | ArrayFire.Random 51 | ArrayFire.Signal 52 | ArrayFire.Sparse 53 | ArrayFire.Statistics 54 | ArrayFire.Types 55 | ArrayFire.Util 56 | ArrayFire.Vision 57 | other-modules: 58 | ArrayFire.FFI 59 | ArrayFire.Exception 60 | ArrayFire.Orphans 61 | ArrayFire.Internal.Algorithm 62 | ArrayFire.Internal.Arith 63 | ArrayFire.Internal.Array 64 | ArrayFire.Internal.Backend 65 | ArrayFire.Internal.BLAS 66 | ArrayFire.Internal.Data 67 | ArrayFire.Internal.Defines 68 | ArrayFire.Internal.Device 69 | ArrayFire.Internal.Exception 70 | ArrayFire.Internal.Features 71 | ArrayFire.Internal.Graphics 72 | ArrayFire.Internal.Image 73 | ArrayFire.Internal.Index 74 | ArrayFire.Internal.Internal 75 | ArrayFire.Internal.LAPACK 76 | ArrayFire.Internal.Random 77 | ArrayFire.Internal.Signal 78 | ArrayFire.Internal.Sparse 79 | ArrayFire.Internal.Statistics 80 | ArrayFire.Internal.Types 81 | ArrayFire.Internal.Util 82 | ArrayFire.Internal.Vision 83 | if !flag(disable-build-tool-depends) 84 | build-tool-depends: 85 | hsc2hs:hsc2hs 86 | extra-libraries: 87 | af 88 | c-sources: 89 | cbits/wrapper.c 90 | build-depends: 91 | base < 5, filepath, vector 92 | hs-source-dirs: 93 | src 94 | ghc-options: 95 | -Wall -Wno-missing-home-modules 96 | default-language: 97 | Haskell2010 98 | 99 | if os(linux) && !flag(disable-default-paths) 100 | include-dirs: 101 | /opt/arrayfire/include 102 | extra-lib-dirs: 103 | /opt/arrayfire/lib64 104 | ld-options: 105 | -Wl,-rpath /opt/arrayfire/lib64 106 | 107 | if os(OSX) && !flag(disable-default-paths) 108 | 
include-dirs: 109 | /opt/arrayfire/include 110 | extra-lib-dirs: 111 | /opt/arrayfire/lib 112 | ld-options: 113 | -Wl,-rpath /opt/arrayfire/lib 114 | 115 | executable main 116 | hs-source-dirs: 117 | exe 118 | main-is: 119 | Main.hs 120 | build-depends: 121 | base < 5, arrayfire, vector 122 | c-sources: 123 | cbits/wrapper.c 124 | default-language: 125 | Haskell2010 126 | 127 | executable gen 128 | main-is: 129 | Main.hs 130 | hs-source-dirs: 131 | gen 132 | build-depends: 133 | base < 5, parsec, text, directory 134 | default-language: 135 | Haskell2010 136 | other-modules: 137 | Lex 138 | Parse 139 | Print 140 | Types 141 | 142 | test-suite test 143 | type: 144 | exitcode-stdio-1.0 145 | main-is: 146 | Main.hs 147 | other-modules: 148 | Test.Hspec.ApproxExpect 149 | hs-source-dirs: 150 | test 151 | build-depends: 152 | arrayfire, 153 | base < 5, 154 | directory, 155 | hspec, 156 | HUnit, 157 | QuickCheck, 158 | quickcheck-classes, 159 | vector, 160 | call-stack >=0.4 && <0.5 161 | if !flag(disable-build-tool-depends) 162 | build-tool-depends: 163 | hspec-discover:hspec-discover 164 | default-language: 165 | Haskell2010 166 | other-modules: 167 | Spec 168 | ArrayFire.AlgorithmSpec 169 | ArrayFire.ArithSpec 170 | ArrayFire.ArraySpec 171 | ArrayFire.BLASSpec 172 | ArrayFire.BackendSpec 173 | ArrayFire.DataSpec 174 | ArrayFire.DeviceSpec 175 | ArrayFire.FeaturesSpec 176 | ArrayFire.GraphicsSpec 177 | ArrayFire.ImageSpec 178 | ArrayFire.IndexSpec 179 | ArrayFire.LAPACKSpec 180 | ArrayFire.RandomSpec 181 | ArrayFire.SignalSpec 182 | ArrayFire.SparseSpec 183 | ArrayFire.StatisticsSpec 184 | ArrayFire.UtilSpec 185 | ArrayFire.VisionSpec 186 | 187 | test-suite doctests 188 | type: 189 | exitcode-stdio-1.0 190 | buildable: 191 | False 192 | ghc-options: 193 | -threaded 194 | main-is: 195 | Main.hs 196 | hs-source-dirs: 197 | doctests 198 | build-depends: 199 | arrayfire 200 | , base < 5 201 | , doctest >= 0.8 202 | , split 203 | autogen-modules: 204 | Build_doctests 205 | 
other-modules: 206 | Build_doctests 207 | default-language: 208 | Haskell2010 209 | 210 | source-repository head 211 | type: git 212 | location: https://github.com/arrayfire/arrayfire-haskell.git 213 | -------------------------------------------------------------------------------- /cabal.project: -------------------------------------------------------------------------------- 1 | packages: . 2 | ignore-project: False 3 | write-ghc-environment-files: always 4 | tests: True 5 | test-options: "--color" 6 | test-show-details: streaming 7 | -------------------------------------------------------------------------------- /cbits/wrapper.c: -------------------------------------------------------------------------------- 1 | #include "arrayfire.h" 2 | #include <stdio.h> 3 | 4 | af_err af_random_engine_set_type_(af_random_engine engine, const af_random_engine_type rtype) { return af_random_engine_set_type(&engine, rtype); } 5 | 6 | af_err af_random_engine_set_seed_(af_random_engine engine, const unsigned long long seed) { 7 | return af_random_engine_set_seed(&engine, seed); 8 | } 9 | 10 | /* Smoke test: creates a 5-element f64 array filled with 2, prints it, then prints its product along dimension 0. */ 11 | void test_bool () { 12 | /* Fixed-size stack buffers: the element count (5) matches dims[0], so nothing is written out of bounds and nothing needs free(). */ 13 | double data[5] = { 2, 2, 2, 2, 2 }; 14 | dim_t dims[4] = { 5, 1, 1, 1 }; 15 | af_array arrin; 16 | af_create_array(&arrin, data, 1, dims, f64); 17 | printf("printing input array\n"); 18 | af_print_array(arrin); 19 | af_array arrout; 20 | af_product(&arrout, arrin, 0); 21 | printf("printing output array\n"); 22 | af_print_array(arrout); 23 | /* Drop both handles so repeated calls do not accumulate arrays. */ 24 | af_release_array(arrout); 25 | af_release_array(arrin); 26 | } 27 | 28 | void test_window () { 29 | af_window window; 30 | af_create_window(&window, 100, 100, "foo"); 31 | af_show(window); 32 | } 33 | 34 | void zeroOutArray (af_array * arr) { 35 | (*arr) = 0; 36 | } 37 | -------------------------------------------------------------------------------- /default.nix: 
-------------------------------------------------------------------------------- 1 | { pkgs ? import <nixpkgs> { config.allowUnfree = true; } }: 2 | # Latest arrayfire is not yet procured w/ nix. 3 | let 4 | pkg = pkgs.haskellPackages.callCabal2nix "arrayfire" ./. { 5 | af = null; 6 | }; 7 | in 8 | pkg 9 | -------------------------------------------------------------------------------- /doctests/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import ArrayFire 4 | import Build_doctests (flags, pkgs, module_sources) 5 | import System.Environment 6 | import Test.DocTest (doctest) 7 | import Data.List.Split 8 | 9 | main :: IO () 10 | main = do 11 | print $ 1 + (1 :: Array Int) 12 | -- NIX_TARGET_LDFLAGS is optional: when it is unset no extra flags are passed instead of aborting with an exception. 13 | moreFlags <- maybe [] (drop 1 . splitOn " ") <$> lookupEnv "NIX_TARGET_LDFLAGS" 14 | mapM_ print moreFlags 15 | mapM_ print (flags ++ pkgs ++ module_sources ++ moreFlags) 16 | doctest (moreFlags ++ flags ++ pkgs ++ module_sources) 17 | -------------------------------------------------------------------------------- /exe/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | {-# LANGUAGE ScopedTypeVariables #-} 3 | {-# LANGUAGE TypeApplications #-} 4 | {-# LANGUAGE DataKinds #-} 5 | module Main where 6 | 7 | import ArrayFire 8 | import Control.Concurrent 9 | import Control.Exception 10 | 11 | import Prelude hiding (sum, product) 12 | -- import GHC.RTS 13 | 14 | foreign import ccall safe "test_bool" 15 | testBool :: IO () 16 | 17 | foreign import ccall safe "test_window" 18 | testWindow :: IO () 19 | 20 | main' :: IO () 21 | main' = print newArray `catch` (\(e :: AFException) -> print e) 22 | where 23 | newArray = matrix @Double (2,2) [ [1..], [1..] ] * matrix @Double (2,2) [ [2..], [2..] 
] 24 | 25 | main :: IO () 26 | main = do 27 | main' 28 | -- testWindow 29 | -- ks <- randn @Double [100,100] 30 | -- saveArray "key" ks "array.txt" False 31 | -- !ks' <- readArrayKey "array.txt" "key" 32 | -- print ks' 33 | 34 | -- info >> putStrLn "ok" >> afInit 35 | -- -- Info things 36 | -- print =<< getSizeOf (Proxy @ Double) 37 | -- print =<< getVersion 38 | -- print =<< getRevision 39 | -- -- getInfo 40 | -- -- print =<< errorToString afErrNoMem 41 | -- putStrLn =<< getInfoString 42 | -- print =<< getDeviceCount 43 | -- print =<< getDevice 44 | 45 | -- -- Create and print an array 46 | -- -- arr1 <- constant 1 1 1 f64 47 | -- -- arr2 <- constant 2 1 1 f64 48 | -- -- r <- addArray arr1 arr2 True 49 | -- -- printArray r 50 | 51 | -- -- print =<< isLAPACKAvailable 52 | -- -- print =<< getAvailableBackends 53 | -- -- print =<< getActiveBackend 54 | -- -- print =<< getAvailableBackends 55 | 56 | -- -- array <- constant @'(10,10) 200 57 | -- -- putStrLn "backend id" 58 | -- -- print (getBackendID array) 59 | -- -- putStrLn "device id" 60 | -- -- print (getDeviceID array) 61 | 62 | -- -- array <- randu @'(9,9,9) @Double 63 | -- -- printArray array -- printArray (mean array 0) 64 | 65 | -- -- printArray (add array 1) 66 | 67 | -- -- putStrLn "got eeem" 68 | -- -- print =<< getDataPtr x 69 | 70 | -- -- x <- constant 10 1 1 f64 71 | -- -- printArray =<< mean x 0 72 | 73 | -- -- print =<< isLAPACKAvailable 74 | 75 | -- a <- randu @'(3,3) @Float 76 | -- b <- randu @'(3,3) @Float 77 | -- printArray ((a `matmul` b) None None) 78 | -- `catch` (\(e :: AFException) -> do 79 | -- putStrLn "got one" 80 | -- print e) 81 | 82 | putStrLn "create window" 83 | window <- createWindow 200 200 "hey" 84 | putStrLn "set visibility" 85 | setVisibility window True 86 | putStrLn "show window" 87 | showWindow window 88 | threadDelay (secs 10) 89 | 90 | -- -- print =<< getActiveBackend 91 | -- -- print =<< getDeviceCount 92 | -- -- print =<< getDevice 93 | -- -- putStrLn "info" 94 | -- -- 
getInfo 95 | -- -- putStrLn "info string" 96 | -- -- putStrLn =<< getInfoString 97 | -- -- print =<< getVersion 98 | 99 | 100 | secs :: Int -> Int 101 | secs = (*1000000) 102 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "arrayfire-nix": { 4 | "inputs": { 5 | "flake-utils": [ 6 | "flake-utils" 7 | ], 8 | "nixpkgs": [ 9 | "nixpkgs" 10 | ] 11 | }, 12 | "locked": { 13 | "lastModified": 1692793973, 14 | "narHash": "sha256-6dG41ile3T+6dfRazlcPBdKBarGesswsBpb40Lcf35U=", 15 | "owner": "twesterhout", 16 | "repo": "arrayfire-nix", 17 | "rev": "4236770612b80a3f29adbd8d670f6cea2bc098ba", 18 | "type": "github" 19 | }, 20 | "original": { 21 | "owner": "twesterhout", 22 | "repo": "arrayfire-nix", 23 | "type": "github" 24 | } 25 | }, 26 | "flake-utils": { 27 | "inputs": { 28 | "systems": "systems" 29 | }, 30 | "locked": { 31 | "lastModified": 1692792214, 32 | "narHash": "sha256-voZDQOvqHsaReipVd3zTKSBwN7LZcUwi3/ThMxRZToU=", 33 | "owner": "numtide", 34 | "repo": "flake-utils", 35 | "rev": "1721b3e7c882f75f2301b00d48a2884af8c448ae", 36 | "type": "github" 37 | }, 38 | "original": { 39 | "owner": "numtide", 40 | "repo": "flake-utils", 41 | "type": "github" 42 | } 43 | }, 44 | "nix-filter": { 45 | "locked": { 46 | "lastModified": 1687178632, 47 | "narHash": "sha256-HS7YR5erss0JCaUijPeyg2XrisEb959FIct3n2TMGbE=", 48 | "owner": "numtide", 49 | "repo": "nix-filter", 50 | "rev": "d90c75e8319d0dd9be67d933d8eb9d0894ec9174", 51 | "type": "github" 52 | }, 53 | "original": { 54 | "owner": "numtide", 55 | "repo": "nix-filter", 56 | "type": "github" 57 | } 58 | }, 59 | "nixpkgs": { 60 | "locked": { 61 | "lastModified": 1692638711, 62 | "narHash": "sha256-J0LgSFgJVGCC1+j5R2QndadWI1oumusg6hCtYAzLID4=", 63 | "owner": "nixos", 64 | "repo": "nixpkgs", 65 | "rev": "91a22f76cd1716f9d0149e8a5c68424bb691de15", 66 | "type": "github" 67 | }, 68 | 
"original": { 69 | "owner": "nixos", 70 | "ref": "nixos-unstable", 71 | "repo": "nixpkgs", 72 | "type": "github" 73 | } 74 | }, 75 | "root": { 76 | "inputs": { 77 | "arrayfire-nix": "arrayfire-nix", 78 | "flake-utils": "flake-utils", 79 | "nix-filter": "nix-filter", 80 | "nixpkgs": "nixpkgs" 81 | } 82 | }, 83 | "systems": { 84 | "locked": { 85 | "lastModified": 1681028828, 86 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 87 | "owner": "nix-systems", 88 | "repo": "default", 89 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 90 | "type": "github" 91 | }, 92 | "original": { 93 | "owner": "nix-systems", 94 | "repo": "default", 95 | "type": "github" 96 | } 97 | } 98 | }, 99 | "root": "root", 100 | "version": 7 101 | } 102 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "arrayfire/arrayfire-haskell: ArrayFire Haskell bindings"; 3 | 4 | inputs = { 5 | nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; 6 | flake-utils.url = "github:numtide/flake-utils"; 7 | nix-filter.url = "github:numtide/nix-filter"; 8 | arrayfire-nix = { 9 | url = "github:twesterhout/arrayfire-nix"; 10 | inputs.flake-utils.follows = "flake-utils"; 11 | inputs.nixpkgs.follows = "nixpkgs"; 12 | }; 13 | }; 14 | 15 | outputs = inputs: 16 | let 17 | src = inputs.nix-filter.lib { 18 | root = ./.; 19 | include = [ 20 | "cbits" 21 | "exe" 22 | "gen" 23 | "include" 24 | "src" 25 | "test" 26 | "arrayfire.cabal" 27 | "README.md" 28 | "CHANGELOG.md" 29 | "LICENSE" 30 | ]; 31 | }; 32 | 33 | # An overlay that lets us test arrayfire-haskell with different GHC versions 34 | arrayfire-haskell-overlay = self: super: { 35 | haskell = super.haskell // { 36 | packageOverrides = inputs.nixpkgs.lib.composeExtensions super.haskell.packageOverrides 37 | (hself: hsuper: { 38 | arrayfire = self.haskell.lib.appendConfigureFlags 39 | (hself.callCabal2nix 
"arrayfire" src { af = self.arrayfire; }) 40 | [ "-f disable-default-paths" ]; 41 | }); 42 | }; 43 | }; 44 | 45 | devShell-for = pkgs: 46 | let 47 | ps = pkgs.haskellPackages; 48 | in 49 | ps.shellFor { 50 | packages = ps: with ps; [ arrayfire ]; 51 | withHoogle = true; 52 | buildInputs = with pkgs; [ ocl-icd ]; 53 | nativeBuildInputs = with pkgs; with ps; [ 54 | # Building and testing 55 | cabal-install 56 | doctest 57 | hsc2hs 58 | hspec-discover 59 | # Language servers 60 | haskell-language-server 61 | nil 62 | # Formatters 63 | nixpkgs-fmt 64 | ]; 65 | shellHook = '' 66 | ''; 67 | }; 68 | 69 | pkgs-for = system: import inputs.nixpkgs { 70 | inherit system; 71 | overlays = [ 72 | inputs.arrayfire-nix.overlays.default 73 | arrayfire-haskell-overlay 74 | ]; 75 | }; 76 | in 77 | { 78 | packages = inputs.flake-utils.lib.eachDefaultSystemMap (system: 79 | with (pkgs-for system); { 80 | default = haskellPackages.arrayfire; 81 | haskell = haskell.packages; 82 | }); 83 | 84 | devShells = inputs.flake-utils.lib.eachDefaultSystemMap (system: { 85 | default = devShell-for (pkgs-for system); 86 | }); 87 | 88 | overlays.default = arrayfire-haskell-overlay; 89 | }; 90 | } 91 | -------------------------------------------------------------------------------- /gen/Lex.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | module Lex where 3 | 4 | import Control.Arrow 5 | import Data.Text (Text) 6 | import qualified Data.Text as T 7 | 8 | import Data.Char 9 | import Types 10 | 11 | symbols :: String 12 | symbols = " *();," 13 | 14 | lex :: Text -> [Token] 15 | lex = go NameMode 16 | where 17 | tokenize ' ' = [] 18 | tokenize '*' = [Star] 19 | tokenize '(' = [LParen] 20 | tokenize ')' = [RParen] 21 | tokenize ';' = [Semi] 22 | tokenize ',' = [Comma] 23 | tokenize _ = [] 24 | go TokenMode xs = do 25 | case T.uncons xs of 26 | Nothing -> [] 27 | Just (c,cs) 28 | | isAlpha c -> go NameMode (T.cons c cs) 29 | | 
otherwise -> tokenize c ++ go TokenMode cs 30 | go NameMode xs = do 31 | let (match, rest) = partition xs 32 | if match == "const" 33 | then [] ++ go TokenMode rest 34 | else Id match : go TokenMode rest 35 | 36 | partition :: Text -> (Text,Text) 37 | partition = 38 | T.takeWhile (`notElem` symbols) &&& 39 | T.dropWhile (`notElem` symbols) 40 | -------------------------------------------------------------------------------- /gen/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ScopedTypeVariables #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE OverloadedStrings #-} 4 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 5 | module Main where 6 | 7 | import Control.Monad 8 | import Data.Char 9 | import Data.Either 10 | 11 | import Data.Maybe 12 | import Data.Text (Text) 13 | import qualified Data.Text as T 14 | import qualified Data.Text.IO as T 15 | import System.Directory 16 | import System.Environment 17 | import System.Exit 18 | import Text.Printf 19 | 20 | import Lex 21 | import Parse 22 | import Print 23 | import Types 24 | 25 | main :: IO () 26 | main = mapM_ writeToDisk =<< getDirectoryFiles 27 | 28 | getDirectoryFiles :: IO [String] 29 | getDirectoryFiles = do 30 | filter (`notElem` exclude) <$> listDirectory "include" 31 | where 32 | exclude = [ "defines.h" 33 | , "complex.h" 34 | , "extra.h" 35 | ] 36 | 37 | writeToDisk :: String -> IO () 38 | writeToDisk fileName = do 39 | bindings 40 | <- map run 41 | . drop 1 42 | . filter (not . T.null) 43 | . T.lines <$> T.readFile ("include/" <> fileName) 44 | case partitionEithers bindings of 45 | (failures, successes) -> do 46 | if length failures > 0 47 | then do 48 | mapM_ print (listToMaybe failures) 49 | printf "%s failed to generate bindings\n" fileName 50 | else do 51 | let name = makeName (reverse . drop 2 . 
reverse $ fileName) 52 | T.writeFile (makePath name) $ 53 | file name <> T.intercalate "\n" (genBinding <$> successes) 54 | printf "Wrote bindings to %s\n" (makePath name) 55 | 56 | -- | Filename remappings 57 | makeName :: String -> String 58 | makeName n 59 | | n == "lapack" = "LAPACK" 60 | | n == "blas" = "BLAS" 61 | | n == "cuda" = "CUDA" 62 | | otherwise = n 63 | 64 | makePath :: String -> String 65 | makePath s = 66 | printf "src/ArrayFire/Internal/%s.hsc" (capitalName (makeName s)) 67 | 68 | file :: String -> Text 69 | file a = T.pack $ printf 70 | "{-# LANGUAGE CPP #-}\n\ 71 | \module ArrayFire.Internal.%s where\n\n\ 72 | \import ArrayFire.Internal.Defines\n\ 73 | \import ArrayFire.Internal.Types\n\ 74 | \import Foreign.Ptr\n\ 75 | \import Foreign.C.Types\n\n\ 76 | \#include \"af/%s.h\"\n\ 77 | \" (capitalName a) (if lowerCase a == "exception" 78 | then "defines" 79 | else lowerCase a) 80 | 81 | -- | 'capitalName' upper-cases the first character; 'lowerCase' lower-cases every character. Both are total and map the empty string to itself. 82 | capitalName, lowerCase :: [Char] -> [Char] 83 | capitalName [] = [] 84 | capitalName (x:xs) = toUpper x : xs 85 | lowerCase = map toLower 86 | -------------------------------------------------------------------------------- /gen/Parse.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | module Parse where 3 | 4 | import Prelude hiding (lex) 5 | import Control.Monad 6 | import Data.Text (Text) 7 | import qualified Data.Text as T 8 | import Text.Parsec 9 | 10 | import Lex 11 | import Types 12 | 13 | parseAST :: Parser AST 14 | parseAST = do 15 | afapi 16 | type' <- getType 17 | funcName <- name 18 | lparen 19 | params <- getParam `sepBy` comma 20 | rparen >> semi 21 | pure (AST type' funcName params) 22 | 23 | getParam :: Parser Type 24 | getParam = do 25 | t <- getType 26 | t <$ name 27 | 28 | getType :: Parser Type 29 | getType = do 30 | typeName <- name 31 | stars <- msum [ try (rep x) | x <- [3,2..0] ] 32 | pure (Type typeName stars) 33 | where 34 | rep n = replicateM_ n star >> pure n 35 | 36 | run 
:: Text -> Either ParseError AST 37 | run txt = parse parseAST mempty (lex txt) 38 | 39 | afapi,lparen,rparen,semi,star,comma :: Parser () 40 | lparen = tok' LParen 41 | rparen = tok' RParen 42 | afapi = tok' (Id "AFAPI") <|> pure () 43 | comma = tok' Comma 44 | semi = tok' Semi 45 | star = tok' Star 46 | 47 | tok :: Token -> Parser Token 48 | tok x = tokenPrim show ignore 49 | (\t -> if x == t then Just x else Nothing) 50 | 51 | tok' :: Token -> Parser () 52 | tok' x = tokenPrim show ignore 53 | (\t -> if x == t then Just () else Nothing) 54 | 55 | name :: Parser Name 56 | name = tokenPrim show ignore go 57 | where 58 | go (Id x) = Just (Name x) 59 | go _ = Nothing 60 | 61 | ignore x _ _ = x 62 | -------------------------------------------------------------------------------- /gen/Print.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | module Print where 3 | 4 | import Data.Char 5 | import Data.Text (Text) 6 | import qualified Data.Text as T 7 | import Text.Printf 8 | 9 | import Types 10 | 11 | import Parse 12 | 13 | camelize :: Text -> Text 14 | camelize = T.concat . map go . 
T.splitOn "_" 15 | where 16 | go "af" = "AF" 17 | go xs = T.cons (toUpper c) cs 18 | where 19 | c = T.head xs 20 | cs = T.tail xs 21 | 22 | isPtr (Type _ x) = x > 0 23 | 24 | genBinding :: AST -> Text 25 | genBinding (AST type' name params) = 26 | header <> dumpBody <> dumpOutput 27 | where 28 | dumpOutput | isPtr type' = "IO (" <> printType type' <> ")" 29 | | otherwise = "IO " <> printType type' 30 | header = T.pack $ printf "foreign import ccall unsafe \"%s\"\n %s :: " name name 31 | dumpBody = printTypes params 32 | 33 | printTypes :: [Type] -> Text 34 | printTypes [] = mempty 35 | printTypes [x] = printType x <> " -> " 36 | printTypes (x:xs) = 37 | mconcat [ 38 | printType x 39 | , " -> " 40 | , printTypes xs 41 | ] 42 | 43 | printType (Type (Name x) 0) = showType x 44 | printType (Type (Name x) 1) = "Ptr " <> showType x 45 | printType (Type t n) = "Ptr (" <> printType (Type t (n-1)) <> ")" 46 | 47 | -- | Additional mappings, very important for CodeGen 48 | showType :: Text -> Text 49 | showType "char" = "CChar" 50 | showType "void" = "()" 51 | showType "unsigned" = "CUInt" 52 | showType "dim_t" = "DimT" 53 | showType "af_someenum_t" = "AFSomeEnum" 54 | showType "size_t" = "CSize" 55 | showType "uintl" = "UIntL" 56 | showType "intl" = "IntL" 57 | showType "int" = "CInt" 58 | showType "bool" = "CBool" 59 | showType "af_index_t" = "AFIndex" 60 | showType "af_cspace_t" = "AFCSpace" 61 | showType "afcl_platform" = "AFCLPlatform" 62 | showType x = camelize x 63 | -------------------------------------------------------------------------------- /gen/Types.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 2 | module Types where 3 | 4 | import Data.Functor.Identity 5 | import Data.Text (Text) 6 | import qualified Data.Text as T 7 | import Text.Parsec 8 | import Text.Printf 9 | 10 | type Parser = ParsecT [Token] () Identity 11 | 12 | type Params = [Type] 13 | 14 | data AST = AST Type Name 
[Type] 15 | deriving (Show) 16 | 17 | data Type = Type Name Int 18 | deriving (Show) 19 | 20 | newtype Name = Name Text 21 | deriving (Show, Eq, PrintfArg) 22 | 23 | data Mode 24 | = TokenMode 25 | | NameMode 26 | deriving (Eq, Show) 27 | 28 | data Token 29 | = Id Text 30 | | Star 31 | | LParen 32 | | RParen 33 | | Comma 34 | | Semi 35 | deriving (Show, Eq) 36 | -------------------------------------------------------------------------------- /include/algorithm.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_sum(af_array *out, const af_array in, const int dim); 4 | af_err af_sum_nan(af_array *out, const af_array in, const int dim, const double nanval); 5 | af_err af_product(af_array *out, const af_array in, const int dim); 6 | af_err af_product_nan(af_array *out, const af_array in, const int dim, const double nanval); 7 | af_err af_min(af_array *out, const af_array in, const int dim); 8 | af_err af_max(af_array *out, const af_array in, const int dim); 9 | af_err af_all_true(af_array *out, const af_array in, const int dim); 10 | af_err af_any_true(af_array *out, const af_array in, const int dim); 11 | af_err af_count(af_array *out, const af_array in, const int dim); 12 | af_err af_sum_all(double *real, double *imag, const af_array in); 13 | af_err af_sum_nan_all(double *real, double *imag, const af_array in, const double nanval); 14 | af_err af_product_all(double *real, double *imag, const af_array in); 15 | af_err af_product_nan_all(double *real, double *imag, const af_array in, const double nanval); 16 | af_err af_min_all(double *real, double *imag, const af_array in); 17 | af_err af_max_all(double *real, double *imag, const af_array in); 18 | af_err af_all_true_all(double *real, double *imag, const af_array in); 19 | af_err af_any_true_all(double *real, double *imag, const af_array in); 20 | af_err af_count_all(double *real, double *imag, const af_array in); 21 | af_err af_imin(af_array 
*out, af_array *idx, const af_array in, const int dim); 22 | af_err af_imax(af_array *out, af_array *idx, const af_array in, const int dim); 23 | af_err af_imin_all(double *real, double *imag, unsigned *idx, const af_array in); 24 | af_err af_imax_all(double *real, double *imag, unsigned *idx, const af_array in); 25 | af_err af_accum(af_array *out, const af_array in, const int dim); 26 | af_err af_scan(af_array *out, const af_array in, const int dim, af_binary_op op, bool inclusive_scan); 27 | af_err af_scan_by_key(af_array *out, const af_array key, const af_array in, const int dim, af_binary_op op, bool inclusive_scan); 28 | af_err af_where(af_array *idx, const af_array in); 29 | af_err af_diff1(af_array *out, const af_array in, const int dim); 30 | af_err af_diff2(af_array *out, const af_array in, const int dim); 31 | af_err af_sort(af_array *out, const af_array in, const unsigned dim, const bool isAscending); 32 | af_err af_sort_index(af_array *out, af_array *indices, const af_array in, const unsigned dim, const bool isAscending); 33 | af_err af_sort_by_key(af_array *out_keys, af_array *out_values, const af_array keys, const af_array values, const unsigned dim, const bool isAscending); 34 | af_err af_set_unique(af_array *out, const af_array in, const bool is_sorted); 35 | af_err af_set_union(af_array *out, const af_array first, const af_array second, const bool is_unique); 36 | af_err af_set_intersect(af_array *out, const af_array first, const af_array second, const bool is_unique); 37 | -------------------------------------------------------------------------------- /include/arith.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_add (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 4 | af_err af_sub (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 5 | af_err af_mul (af_array *out, const af_array lhs, const af_array rhs, const bool 
batch); 6 | af_err af_div (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 7 | af_err af_lt (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 8 | af_err af_gt (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 9 | af_err af_le (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 10 | af_err af_ge (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 11 | af_err af_eq (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 12 | af_err af_neq (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 13 | af_err af_and (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 14 | af_err af_or (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 15 | af_err af_not (af_array *out, const af_array in); 16 | af_err af_bitand (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 17 | af_err af_bitor (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 18 | af_err af_bitxor (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 19 | af_err af_bitshiftl(af_array *out, const af_array lhs, const af_array rhs, const bool batch); 20 | af_err af_bitshiftr(af_array *out, const af_array lhs, const af_array rhs, const bool batch); 21 | af_err af_cast (af_array *out, const af_array in, const af_dtype type); 22 | af_err af_minof (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 23 | af_err af_maxof (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 24 | af_err af_clamp(af_array *out, const af_array in, const af_array lo, const af_array hi, const bool batch); 25 | af_err af_rem (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 26 | af_err af_mod (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 27 | af_err af_abs (af_array *out, const 
af_array in); 28 | af_err af_arg (af_array *out, const af_array in); 29 | af_err af_sign (af_array *out, const af_array in); 30 | af_err af_round (af_array *out, const af_array in); 31 | af_err af_trunc (af_array *out, const af_array in); 32 | af_err af_floor (af_array *out, const af_array in); 33 | af_err af_ceil (af_array *out, const af_array in); 34 | af_err af_hypot (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 35 | af_err af_sin (af_array *out, const af_array in); 36 | af_err af_cos (af_array *out, const af_array in); 37 | af_err af_tan (af_array *out, const af_array in); 38 | af_err af_asin (af_array *out, const af_array in); 39 | af_err af_acos (af_array *out, const af_array in); 40 | af_err af_atan (af_array *out, const af_array in); 41 | af_err af_atan2 (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 42 | af_err af_cplx2 (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 43 | af_err af_cplx (af_array *out, const af_array in); 44 | af_err af_real (af_array *out, const af_array in); 45 | af_err af_imag (af_array *out, const af_array in); 46 | af_err af_conjg (af_array *out, const af_array in); 47 | af_err af_sinh (af_array *out, const af_array in); 48 | af_err af_cosh (af_array *out, const af_array in); 49 | af_err af_tanh (af_array *out, const af_array in); 50 | af_err af_asinh (af_array *out, const af_array in); 51 | af_err af_acosh (af_array *out, const af_array in); 52 | af_err af_atanh (af_array *out, const af_array in); 53 | af_err af_root (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 54 | af_err af_pow (af_array *out, const af_array lhs, const af_array rhs, const bool batch); 55 | af_err af_pow2 (af_array *out, const af_array in); 56 | af_err af_exp (af_array *out, const af_array in); 57 | af_err af_sigmoid (af_array *out, const af_array in); 58 | af_err af_expm1 (af_array *out, const af_array in); 59 | af_err af_erf (af_array *out, const 
af_array in); 60 | af_err af_erfc (af_array *out, const af_array in); 61 | af_err af_log (af_array *out, const af_array in); 62 | af_err af_log1p (af_array *out, const af_array in); 63 | af_err af_log10 (af_array *out, const af_array in); 64 | af_err af_log2 (af_array *out, const af_array in); 65 | af_err af_sqrt (af_array *out, const af_array in); 66 | af_err af_cbrt (af_array *out, const af_array in); 67 | af_err af_factorial (af_array *out, const af_array in); 68 | af_err af_tgamma (af_array *out, const af_array in); 69 | af_err af_lgamma (af_array *out, const af_array in); 70 | af_err af_iszero (af_array *out, const af_array in); 71 | af_err af_isinf (af_array *out, const af_array in); 72 | af_err af_isnan (af_array *out, const af_array in); 73 | -------------------------------------------------------------------------------- /include/array.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_create_array(af_array *arr, const void * const data, const unsigned ndims, const dim_t * const dims, const af_dtype type); 4 | af_err af_create_handle(af_array *arr, const unsigned ndims, const dim_t * const dims, const af_dtype type); 5 | af_err af_copy_array(af_array *arr, const af_array in); 6 | af_err af_write_array(af_array arr, const void *data, const size_t bytes, af_source src); 7 | af_err af_get_data_ptr(void *data, const af_array arr); 8 | af_err af_release_array(af_array arr); 9 | af_err af_retain_array(af_array *out, const af_array in); 10 | af_err af_get_data_ref_count(int *use_count, const af_array in); 11 | af_err af_eval(af_array in); 12 | af_err af_eval_multiple(const int num, af_array *arrays); 13 | af_err af_set_manual_eval_flag(bool flag); 14 | af_err af_get_manual_eval_flag(bool *flag); 15 | af_err af_get_elements(dim_t *elems, const af_array arr); 16 | af_err af_get_type(af_dtype *type, const af_array arr); 17 | af_err af_get_dims(dim_t *d0, dim_t *d1, dim_t *d2, dim_t *d3, 
const af_array arr); 18 | af_err af_get_numdims(unsigned *result, const af_array arr); 19 | af_err af_is_empty (bool *result, const af_array arr); 20 | af_err af_is_scalar (bool *result, const af_array arr); 21 | af_err af_is_row (bool *result, const af_array arr); 22 | af_err af_is_column (bool *result, const af_array arr); 23 | af_err af_is_vector (bool *result, const af_array arr); 24 | af_err af_is_complex (bool *result, const af_array arr); 25 | af_err af_is_real (bool *result, const af_array arr); 26 | af_err af_is_double (bool *result, const af_array arr); 27 | af_err af_is_single (bool *result, const af_array arr); 28 | af_err af_is_realfloating (bool *result, const af_array arr); 29 | af_err af_is_floating (bool *result, const af_array arr); 30 | af_err af_is_integer (bool *result, const af_array arr); 31 | af_err af_is_bool (bool *result, const af_array arr); 32 | af_err af_is_sparse (bool *result, const af_array arr); 33 | af_err af_get_scalar(void* output_value, const af_array arr); 34 | -------------------------------------------------------------------------------- /include/backend.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_set_backend(const af_backend bknd); 4 | af_err af_get_backend_count(unsigned* num_backends); 5 | af_err af_get_available_backends(int* backends); 6 | af_err af_get_backend_id(af_backend *backend, const af_array in); 7 | af_err af_get_active_backend(af_backend *backend); 8 | af_err af_get_device_id(int *device, const af_array in); 9 | -------------------------------------------------------------------------------- /include/blas.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_matmul( af_array *out ,const af_array lhs, const af_array rhs, const af_mat_prop optLhs, const af_mat_prop optRhs); 4 | af_err af_dot(af_array *out, const af_array lhs, const af_array rhs, const 
af_mat_prop optLhs, const af_mat_prop optRhs); 5 | af_err af_dot_all(double *real, double *imag, const af_array lhs, const af_array rhs, const af_mat_prop optLhs, const af_mat_prop optRhs); 6 | af_err af_transpose(af_array *out, af_array in, const bool conjugate); 7 | af_err af_transpose_inplace(af_array in, const bool conjugate); 8 | -------------------------------------------------------------------------------- /include/cuda.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err afcu_get_stream(cudaStream_t* stream, int id); 4 | af_err afcu_get_native_id(int* nativeid, int id); 5 | af_err afcu_set_native_id(int nativeid); 6 | -------------------------------------------------------------------------------- /include/data.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_constant(af_array *arr, const double val, const unsigned ndims, const dim_t * const dims, const af_dtype type); 4 | af_err af_constant_complex(af_array *arr, const double real, const double imag, const unsigned ndims, const dim_t * const dims, const af_dtype type); 5 | af_err af_constant_long (af_array *arr, const intl val, const unsigned ndims, const dim_t * const dims); 6 | af_err af_constant_ulong(af_array *arr, const uintl val, const unsigned ndims, const dim_t * const dims); 7 | af_err af_range(af_array *out, const unsigned ndims, const dim_t * const dims, const int seq_dim, const af_dtype type); 8 | af_err af_iota(af_array *out, const unsigned ndims, const dim_t * const dims, const unsigned t_ndims, const dim_t * const tdims, const af_dtype type); 9 | af_err af_identity(af_array *out, const unsigned ndims, const dim_t * const dims, const af_dtype type); 10 | af_err af_diag_create(af_array *out, const af_array in, const int num); 11 | af_err af_diag_extract(af_array *out, const af_array in, const int num); 12 | af_err af_join(af_array *out, const int 
dim, const af_array first, const af_array second); 13 | af_err af_join_many(af_array *out, const int dim, const unsigned n_arrays, const af_array *inputs); 14 | af_err af_tile(af_array *out, const af_array in, const unsigned x, const unsigned y, const unsigned z, const unsigned w); 15 | af_err af_reorder(af_array *out, const af_array in, const unsigned x, const unsigned y, const unsigned z, const unsigned w); 16 | af_err af_shift(af_array *out, const af_array in, const int x, const int y, const int z, const int w); 17 | af_err af_moddims(af_array *out, const af_array in, const unsigned ndims, const dim_t * const dims); 18 | af_err af_flat(af_array *out, const af_array in); 19 | af_err af_flip(af_array *out, const af_array in, const unsigned dim); 20 | af_err af_lower(af_array *out, const af_array in, bool is_unit_diag); 21 | af_err af_upper(af_array *out, const af_array in, bool is_unit_diag); 22 | af_err af_select(af_array *out, const af_array cond, const af_array a, const af_array b); 23 | af_err af_select_scalar_r(af_array *out, const af_array cond, const af_array a, const double b); 24 | af_err af_select_scalar_l(af_array *out, const af_array cond, const double a, const af_array b); 25 | af_err af_replace(af_array a, const af_array cond, const af_array b); 26 | af_err af_replace_scalar(af_array a, const af_array cond, const double b); 27 | -------------------------------------------------------------------------------- /include/device.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_info(); 4 | af_err af_init(); 5 | af_err af_info_string(char** str, const bool verbose); 6 | af_err af_device_info(char* d_name, char* d_platform, char *d_toolkit, char* d_compute); 7 | af_err af_get_device_count(int *num_of_devices); 8 | af_err af_get_dbl_support(bool* available, const int device); 9 | af_err af_set_device(const int device); 10 | af_err af_get_device(int *device); 11 | af_err 
af_sync(const int device); 12 | af_err af_alloc_device(void **ptr, const dim_t bytes); 13 | af_err af_free_device(void *ptr); 14 | af_err af_alloc_pinned(void **ptr, const dim_t bytes); 15 | af_err af_free_pinned(void *ptr); 16 | af_err af_alloc_host(void **ptr, const dim_t bytes); 17 | af_err af_free_host(void *ptr); 18 | af_err af_device_array(af_array *arr, const void *data, const unsigned ndims, const dim_t * const dims, const af_dtype type); 19 | af_err af_device_mem_info(size_t *alloc_bytes, size_t *alloc_buffers, size_t *lock_bytes, size_t *lock_buffers); 20 | af_err af_print_mem_info(const char *msg, const int device_id); 21 | af_err af_device_gc(); 22 | af_err af_set_mem_step_size(const size_t step_bytes); 23 | af_err af_get_mem_step_size(size_t *step_bytes); 24 | af_err af_lock_device_ptr(const af_array arr); 25 | af_err af_unlock_device_ptr(const af_array arr); 26 | af_err af_lock_array(const af_array arr); 27 | af_err af_is_locked_array(bool *res, const af_array arr); 28 | af_err af_get_device_ptr(void **ptr, const af_array arr); 29 | -------------------------------------------------------------------------------- /include/exception.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | void af_get_last_error(char **msg, dim_t *len); 4 | const char *af_err_to_string(const af_err err); 5 | 6 | -------------------------------------------------------------------------------- /include/features.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_create_features(af_features *feat, dim_t num); 4 | af_err af_retain_features(af_features *out, const af_features feat); 5 | af_err af_get_features_num(dim_t *num, const af_features feat); 6 | af_err af_get_features_xpos(af_array *out, const af_features feat); 7 | af_err af_get_features_ypos(af_array *out, const af_features feat); 8 | af_err af_get_features_score(af_array *score, 
const af_features feat); 9 | af_err af_get_features_orientation(af_array *orientation, const af_features feat); 10 | af_err af_get_features_size(af_array *size, const af_features feat); 11 | af_err af_release_features(af_features feat); 12 | -------------------------------------------------------------------------------- /include/graphics.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_create_window(af_window *out, const int width, const int height, const char* const title); 4 | af_err af_set_position(const af_window wind, const unsigned x, const unsigned y); 5 | af_err af_set_title(const af_window wind, const char* const title); 6 | af_err af_set_size(const af_window wind, const unsigned w, const unsigned h); 7 | af_err af_draw_image(const af_window wind, const af_array in, const af_cell* const props); 8 | af_err af_draw_plot(const af_window wind, const af_array X, const af_array Y, const af_cell* const props); 9 | af_err af_draw_plot3(const af_window wind, const af_array P, const af_cell* const props); 10 | af_err af_draw_plot_nd(const af_window wind, const af_array P, const af_cell* const props); 11 | af_err af_draw_plot_2d(const af_window wind, const af_array X, const af_array Y,const af_cell* const props); 12 | af_err af_draw_plot_3d(const af_window wind,const af_array X, const af_array Y, const af_array Z,const af_cell* const props); 13 | af_err af_draw_scatter(const af_window wind, const af_array X, const af_array Y, const af_marker_type marker, const af_cell* const props); 14 | af_err af_draw_scatter3(const af_window wind, const af_array P, const af_marker_type marker, const af_cell* const props); 15 | af_err af_draw_scatter_nd(const af_window wind, const af_array P, const af_marker_type marker, const af_cell* const props); 16 | af_err af_draw_scatter_2d(const af_window wind, const af_array X, const af_array Y,const af_marker_type marker, const af_cell* const props); 17 | af_err 
af_draw_scatter_3d(const af_window wind, const af_array X, const af_array Y, const af_array Z, const af_marker_type marker, const af_cell* const props); 18 | af_err af_draw_hist(const af_window wind, const af_array X, const double minval, const double maxval, const af_cell* const props); 19 | af_err af_draw_surface(const af_window wind, const af_array xVals, const af_array yVals, const af_array S, const af_cell* const props); 20 | af_err af_draw_vector_field_nd(const af_window wind, const af_array points, const af_array directions, const af_cell* const props); 21 | af_err af_draw_vector_field_3d(const af_window wind, const af_array xPoints, const af_array yPoints, const af_array zPoints, const af_array xDirs, const af_array yDirs, const af_array zDirs, const af_cell* const props); 22 | af_err af_draw_vector_field_2d(const af_window wind,const af_array xPoints, const af_array yPoints,const af_array xDirs, const af_array yDirs,const af_cell* const props); 23 | af_err af_grid(const af_window wind, const int rows, const int cols); 24 | af_err af_set_axes_limits_compute(const af_window wind, const af_array x, const af_array y, const af_array z,const bool exact, const af_cell* const props); 25 | af_err af_set_axes_limits_2d(const af_window wind, const float xmin, const float xmax, const float ymin, const float ymax, const bool exact, const af_cell* const props); 26 | af_err af_set_axes_limits_3d(const af_window wind, const float xmin, const float xmax, const float ymin, const float ymax, const float zmin, const float zmax, const bool exact, const af_cell* const props); 27 | af_err af_set_axes_titles(const af_window wind, const char * const xtitle, const char * const ytitle, const char * const ztitle, const af_cell* const props); 28 | af_err af_show(const af_window wind); 29 | af_err af_is_window_closed(bool *out, const af_window wind); 30 | af_err af_set_visibility(const af_window wind, const bool is_visible); 31 | af_err af_destroy_window(const af_window wind); 32 | 
-------------------------------------------------------------------------------- /include/image.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_gradient(af_array *dx, af_array *dy, const af_array in); 4 | af_err af_load_image(af_array *out, const char* filename, const bool isColor); 5 | af_err af_save_image(const char* filename, const af_array in); 6 | af_err af_load_image_memory(af_array *out, const void* ptr); 7 | af_err af_save_image_memory(void** ptr, const af_array in, const af_image_format format); 8 | af_err af_delete_image_memory(void* ptr); 9 | af_err af_load_image_native(af_array *out, const char* filename); 10 | af_err af_save_image_native(const char* filename, const af_array in); 11 | af_err af_is_image_io_available(bool *out); 12 | af_err af_resize(af_array *out, const af_array in, const dim_t odim0, const dim_t odim1, const af_interp_type method); 13 | af_err af_transform(af_array *out, const af_array in, const af_array transform, const dim_t odim0, const dim_t odim1, const af_interp_type method, const bool inverse); 14 | af_err af_transform_coordinates(af_array *out, const af_array tf, const float d0, const float d1); 15 | af_err af_rotate(af_array *out, const af_array in, const float theta, const bool crop, const af_interp_type method); 16 | af_err af_translate(af_array *out, const af_array in, const float trans0, const float trans1, const dim_t odim0, const dim_t odim1, const af_interp_type method); 17 | af_err af_scale(af_array *out, const af_array in, const float scale0, const float scale1, const dim_t odim0, const dim_t odim1, const af_interp_type method); 18 | af_err af_skew(af_array *out, const af_array in, const float skew0, const float skew1, const dim_t odim0, const dim_t odim1, const af_interp_type method, const bool inverse); 19 | af_err af_histogram(af_array *out, const af_array in, const unsigned nbins, const double minval, const double maxval); 20 | af_err 
af_dilate(af_array *out, const af_array in, const af_array mask); 21 | af_err af_dilate3(af_array *out, const af_array in, const af_array mask); 22 | af_err af_erode(af_array *out, const af_array in, const af_array mask); 23 | af_err af_erode3(af_array *out, const af_array in, const af_array mask); 24 | af_err af_bilateral(af_array *out, const af_array in, const float spatial_sigma, const float chromatic_sigma, const bool isColor); 25 | af_err af_mean_shift(af_array *out, const af_array in, const float spatial_sigma, const float chromatic_sigma, const unsigned iter, const bool is_color); 26 | af_err af_minfilt(af_array *out, const af_array in, const dim_t wind_length, const dim_t wind_width, const af_border_type edge_pad); 27 | af_err af_maxfilt(af_array *out, const af_array in, const dim_t wind_length, const dim_t wind_width, const af_border_type edge_pad); 28 | af_err af_regions(af_array *out, const af_array in, const af_connectivity connectivity, const af_dtype ty); 29 | af_err af_sobel_operator(af_array *dx, af_array *dy, const af_array img, const unsigned ker_size); 30 | af_err af_rgb2gray(af_array* out, const af_array in, const float rPercent, const float gPercent, const float bPercent); 31 | af_err af_gray2rgb(af_array* out, const af_array in, const float rFactor, const float gFactor, const float bFactor); 32 | af_err af_hist_equal(af_array *out, const af_array in, const af_array hist); 33 | af_err af_gaussian_kernel(af_array *out, const int rows, const int cols, const double sigma_r, const double sigma_c); 34 | af_err af_hsv2rgb(af_array* out, const af_array in); 35 | af_err af_rgb2hsv(af_array* out, const af_array in); 36 | af_err af_color_space(af_array *out, const af_array image, const af_cspace_t to, const af_cspace_t from); 37 | af_err af_unwrap(af_array *out, const af_array in, const dim_t wx, const dim_t wy, const dim_t sx, const dim_t sy, const dim_t px, const dim_t py, const bool is_column); 38 | af_err af_wrap(af_array *out, const af_array in, 
const dim_t ox, const dim_t oy, const dim_t wx, const dim_t wy, const dim_t sx, const dim_t sy, const dim_t px, const dim_t py, const bool is_column); 39 | af_err af_sat(af_array *out, const af_array in); 40 | af_err af_ycbcr2rgb(af_array* out, const af_array in, const af_ycc_std standard); 41 | af_err af_rgb2ycbcr(af_array* out, const af_array in, const af_ycc_std standard); 42 | af_err af_moments(af_array *out, const af_array in, const af_moment_type moment); 43 | af_err af_moments_all(double* out, const af_array in, const af_moment_type moment); 44 | af_err af_canny(af_array* out, const af_array in, const af_canny_threshold threshold_type, const float low_threshold_ratio, const float high_threshold_ratio, const unsigned sobel_window, const bool is_fast); 45 | af_err af_anisotropic_diffusion(af_array* out, const af_array in, const float timestep, const float conductance, const unsigned iterations, const af_flux_function fftype,const af_diffusion_eq diffusion_kind); 46 | -------------------------------------------------------------------------------- /include/index.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_index( af_array *out, const af_array in, const unsigned ndims, const af_seq* const index); 4 | af_err af_lookup( af_array *out, const af_array in, const af_array indices, const unsigned dim); 5 | af_err af_assign_seq( af_array *out, const af_array lhs, const unsigned ndims, const af_seq* const indices, const af_array rhs); 6 | af_err af_index_gen( af_array *out, const af_array in, const dim_t ndims, const af_index_t* indices); 7 | af_err af_assign_gen( af_array *out, const af_array lhs, const dim_t ndims, const af_index_t* indices, const af_array rhs); 8 | af_err af_create_indexers(af_index_t** indexers); 9 | af_err af_set_array_indexer(af_index_t* indexer, const af_array idx, const dim_t dim); 10 | af_err af_set_seq_indexer(af_index_t* indexer, const af_seq* idx, const dim_t 
dim, const bool is_batch); 11 | af_err af_set_seq_param_indexer(af_index_t* indexer, const double begin, const double end, const double step, const dim_t dim, const bool is_batch); 12 | af_err af_release_indexers(af_index_t* indexers); 13 | -------------------------------------------------------------------------------- /include/internal.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_create_strided_array(af_array *arr, const void *data, const dim_t offset, const unsigned ndims, const dim_t *const dims, const dim_t *const strides, const af_dtype ty, const af_source location); 4 | af_err af_get_strides(dim_t *s0, dim_t *s1, dim_t *s2, dim_t *s3, const af_array arr); 5 | af_err af_get_offset(dim_t *offset, const af_array arr); 6 | af_err af_get_raw_ptr(void **ptr, const af_array arr); 7 | af_err af_is_linear(bool *result, const af_array arr); 8 | af_err af_is_owner(bool *result, const af_array arr); 9 | af_err af_get_allocated_bytes(size_t *bytes, const af_array arr); 10 | -------------------------------------------------------------------------------- /include/lapack.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_svd(af_array *u, af_array *s, af_array *vt, const af_array in); 4 | af_err af_svd_inplace(af_array *u, af_array *s, af_array *vt, af_array in); 5 | af_err af_lu(af_array *lower, af_array *upper, af_array *pivot, const af_array in); 6 | af_err af_lu_inplace(af_array *pivot, af_array in, const bool is_lapack_piv); 7 | af_err af_qr(af_array *q, af_array *r, af_array *tau, const af_array in); 8 | af_err af_qr_inplace(af_array *tau, af_array in); 9 | af_err af_cholesky(af_array *out, int *info, const af_array in, const bool is_upper); 10 | af_err af_cholesky_inplace(int *info, af_array in, const bool is_upper); 11 | af_err af_solve(af_array *x, const af_array a, const af_array b, const af_mat_prop 
options); 12 | af_err af_solve_lu(af_array *x, const af_array a, const af_array piv, const af_array b, const af_mat_prop options); 13 | af_err af_inverse(af_array *out, const af_array in, const af_mat_prop options); 14 | af_err af_rank(unsigned *rank, const af_array in, const double tol); 15 | af_err af_det(double *det_real, double *det_imag, const af_array in); 16 | af_err af_norm(double *out, const af_array in, const af_norm_type type, const double p, const double q); 17 | af_err af_is_lapack_available(bool *out); 18 | -------------------------------------------------------------------------------- /include/openCL.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err afcl_get_context(cl_context *ctx, const bool retain); 4 | af_err afcl_get_queue(cl_command_queue *queue, const bool retain); 5 | af_err afcl_get_device_id(cl_device_id *id); 6 | af_err afcl_set_device_id(cl_device_id id); 7 | af_err afcl_add_device_context(cl_device_id dev, cl_context ctx, cl_command_queue que); 8 | af_err afcl_set_device_context(cl_device_id dev, cl_context ctx); 9 | af_err afcl_delete_device_context(cl_device_id dev, cl_context ctx); 10 | af_err afcl_get_device_type(afcl_device_type *res); 11 | af_err afcl_get_platform(afcl_platform *res); 12 | -------------------------------------------------------------------------------- /include/random.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_create_random_engine(af_random_engine *engine, af_random_engine_type rtype, uintl seed); 4 | af_err af_retain_random_engine(af_random_engine *out, const af_random_engine engine); 5 | af_err af_random_engine_set_type(af_random_engine *engine, const af_random_engine_type rtype); 6 | af_err af_random_engine_get_type(af_random_engine_type *rtype, const af_random_engine engine); 7 | af_err af_random_uniform(af_array *out, const unsigned ndims, const dim_t * 
const dims, const af_dtype type, af_random_engine engine); 8 | af_err af_random_normal(af_array *out, const unsigned ndims, const dim_t * const dims, const af_dtype type, af_random_engine engine); 9 | af_err af_random_engine_set_seed(af_random_engine *engine, const uintl seed); 10 | af_err af_get_default_random_engine(af_random_engine *engine); 11 | af_err af_set_default_random_engine_type(const af_random_engine_type rtype); 12 | af_err af_random_engine_get_seed(uintl * const seed, af_random_engine engine); 13 | af_err af_release_random_engine(af_random_engine engine); 14 | af_err af_randu(af_array *out, const unsigned ndims, const dim_t * const dims, const af_dtype type); 15 | af_err af_randn(af_array *out, const unsigned ndims, const dim_t * const dims, const af_dtype type); 16 | af_err af_set_seed(const uintl seed); 17 | af_err af_get_seed(uintl *seed); 18 | -------------------------------------------------------------------------------- /include/seq.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_seq af_make_seq(double begin, double end, double step); 4 | -------------------------------------------------------------------------------- /include/signal.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_approx1(af_array *out, const af_array in, const af_array pos, const af_interp_type method, const float off_grid); 4 | af_err af_approx2(af_array *out, const af_array in, const af_array pos0, const af_array pos1, const af_interp_type method, const float off_grid); 5 | af_err af_fft(af_array *out, const af_array in, const double norm_factor, const dim_t odim0); 6 | af_err af_fft_inplace(af_array in, const double norm_factor); 7 | af_err af_fft2(af_array *out, const af_array in, const double norm_factor, const dim_t odim0, const dim_t odim1); 8 | af_err af_fft2_inplace(af_array in, const double norm_factor); 9 | af_err 
af_fft3(af_array *out, const af_array in, const double norm_factor, const dim_t odim0, const dim_t odim1, const dim_t odim2); 10 | af_err af_fft3_inplace(af_array in, const double norm_factor); 11 | af_err af_ifft(af_array *out, const af_array in, const double norm_factor, const dim_t odim0); 12 | af_err af_ifft_inplace(af_array in, const double norm_factor); 13 | af_err af_ifft2(af_array *out, const af_array in, const double norm_factor, const dim_t odim0, const dim_t odim1); 14 | af_err af_ifft2_inplace(af_array in, const double norm_factor); 15 | af_err af_ifft3(af_array *out, const af_array in, const double norm_factor, const dim_t odim0, const dim_t odim1, const dim_t odim2); 16 | af_err af_ifft3_inplace(af_array in, const double norm_factor); 17 | af_err af_fft_r2c (af_array *out, const af_array in, const double norm_factor, const dim_t pad0); 18 | af_err af_fft2_r2c(af_array *out, const af_array in, const double norm_factor, const dim_t pad0, const dim_t pad1); 19 | af_err af_fft3_r2c(af_array *out, const af_array in, const double norm_factor, const dim_t pad0, const dim_t pad1, const dim_t pad2); 20 | af_err af_fft_c2r (af_array *out, const af_array in, const double norm_factor, const bool is_odd); 21 | af_err af_fft2_c2r(af_array *out, const af_array in, const double norm_factor, const bool is_odd); 22 | af_err af_fft3_c2r(af_array *out, const af_array in, const double norm_factor, const bool is_odd); 23 | af_err af_convolve1(af_array *out, const af_array signal, const af_array filter, const af_conv_mode mode, af_conv_domain domain); 24 | af_err af_convolve2(af_array *out, const af_array signal, const af_array filter, const af_conv_mode mode, af_conv_domain domain); 25 | af_err af_convolve3(af_array *out, const af_array signal, const af_array filter, const af_conv_mode mode, af_conv_domain domain); 26 | af_err af_convolve2_sep(af_array *out, const af_array col_filter, const af_array row_filter, const af_array signal, const af_conv_mode mode); 27 | af_err 
af_fft_convolve1(af_array *out, const af_array signal, const af_array filter, const af_conv_mode mode); 28 | af_err af_fft_convolve2(af_array *out, const af_array signal, const af_array filter, const af_conv_mode mode); 29 | af_err af_fft_convolve3(af_array *out, const af_array signal, const af_array filter, const af_conv_mode mode); 30 | af_err af_fir(af_array *y, const af_array b, const af_array x); 31 | af_err af_iir(af_array *y, const af_array b, const af_array a, const af_array x); 32 | af_err af_medfilt(af_array *out, const af_array in, const dim_t wind_length, const dim_t wind_width, const af_border_type edge_pad); 33 | af_err af_medfilt1(af_array *out, const af_array in, const dim_t wind_width, const af_border_type edge_pad); 34 | af_err af_medfilt2(af_array *out, const af_array in, const dim_t wind_length, const dim_t wind_width, const af_border_type edge_pad); 35 | af_err af_set_fft_plan_cache_size(size_t cache_size); 36 | -------------------------------------------------------------------------------- /include/sparse.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_create_sparse_array(af_array *out, const dim_t nRows, const dim_t nCols, const af_array values, const af_array rowIdx, const af_array colIdx, const af_storage stype); 4 | af_err af_create_sparse_array_from_ptr(af_array *out, const dim_t nRows, const dim_t nCols, const dim_t nNZ, const void * const values, const int * const rowIdx, const int * const colIdx, const af_dtype type, const af_storage stype, const af_source src); 5 | af_err af_create_sparse_array_from_dense(af_array *out, const af_array dense, const af_storage stype); 6 | af_err af_sparse_convert_to(af_array *out, const af_array in, const af_storage destStorage); 7 | af_err af_sparse_to_dense(af_array *out, const af_array sparse); 8 | af_err af_sparse_get_info(af_array *values, af_array *rowIdx, af_array *colIdx, af_storage *stype, const af_array in); 9 | 
af_err af_sparse_get_values(af_array *out, const af_array in); 10 | af_err af_sparse_get_row_idx(af_array *out, const af_array in); 11 | af_err af_sparse_get_col_idx(af_array *out, const af_array in); 12 | af_err af_sparse_get_nnz(dim_t *out, const af_array in); 13 | af_err af_sparse_get_storage(af_storage *out, const af_array in); 14 | -------------------------------------------------------------------------------- /include/statistics.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_mean(af_array *out, const af_array in, const dim_t dim); 4 | af_err af_mean_weighted(af_array *out, const af_array in, const af_array weights, const dim_t dim); 5 | af_err af_var(af_array *out, const af_array in, const bool isbiased, const dim_t dim); 6 | af_err af_var_weighted(af_array *out, const af_array in, const af_array weights, const dim_t dim); 7 | af_err af_stdev(af_array *out, const af_array in, const dim_t dim); 8 | af_err af_cov(af_array* out, const af_array X, const af_array Y, const bool isbiased); 9 | af_err af_median(af_array* out, const af_array in, const dim_t dim); 10 | af_err af_mean_all(double *real, double *imag, const af_array in); 11 | af_err af_mean_all_weighted(double *real, double *imag, const af_array in, const af_array weights); 12 | af_err af_var_all(double *realVal, double *imagVal, const af_array in, const bool isbiased); 13 | af_err af_var_all_weighted(double *realVal, double *imagVal, const af_array in, const af_array weights); 14 | af_err af_stdev_all(double *real, double *imag, const af_array in); 15 | af_err af_median_all(double *realVal, double *imagVal, const af_array in); 16 | af_err af_corrcoef(double *realVal, double *imagVal, const af_array X, const af_array Y); 17 | af_err af_topk(af_array *values, af_array *indices, const af_array in, const int k, const int dim, const af_topk_function order); 18 | 
-------------------------------------------------------------------------------- /include/util.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_print_array(af_array arr); 4 | af_err af_print_array_gen(const char *exp, const af_array arr, const int precision); 5 | af_err af_save_array(int *index, const char* key, const af_array arr, const char *filename, const bool append); 6 | af_err af_read_array_index(af_array *out, const char *filename, const unsigned index); 7 | af_err af_read_array_key(af_array *out, const char *filename, const char* key); 8 | af_err af_read_array_key_check(int *index, const char *filename, const char* key); 9 | af_err af_array_to_string(char **output, const char *exp, const af_array arr, const int precision, const bool transpose); 10 | af_err af_example_function(af_array* out, const af_array in, const af_someenum_t param); 11 | af_err af_get_version(int *major, int *minor, int *patch); 12 | const char *af_get_revision(); 13 | af_err af_get_size_of(size_t *size, af_dtype type); 14 | -------------------------------------------------------------------------------- /include/vision.h: -------------------------------------------------------------------------------- 1 | #include "defines.h" 2 | 3 | af_err af_fast(af_features *out, const af_array in, const float thr, const unsigned arc_length, const bool non_max, const float feature_ratio, const unsigned edge); 4 | af_err af_harris(af_features *out, const af_array in, const unsigned max_corners, const float min_response, const float sigma, const unsigned block_size, const float k_thr); 5 | af_err af_orb(af_features *feat, af_array *desc, const af_array in, const float fast_thr, const unsigned max_feat, const float scl_fctr, const unsigned levels, const bool blur_img); 6 | af_err af_sift(af_features *feat, af_array *desc, const af_array in, const unsigned n_layers, const float contrast_thr, const float edge_thr, const float 
init_sigma, const bool double_input, const float intensity_scale, const float feature_ratio); 7 | af_err af_gloh(af_features *feat, af_array *desc, const af_array in, const unsigned n_layers, const float contrast_thr, const float edge_thr, const float init_sigma, const bool double_input, const float intensity_scale, const float feature_ratio); 8 | af_err af_hamming_matcher(af_array* idx, af_array* dist, const af_array query, const af_array train, const dim_t dist_dim, const unsigned n_dist); 9 | af_err af_nearest_neighbour(af_array* idx, af_array* dist, const af_array query, const af_array train, const dim_t dist_dim, const unsigned n_dist, const af_match_type dist_type); 10 | af_err af_match_template(af_array *out, const af_array search_img, const af_array template_img, const af_match_type m_type); 11 | af_err af_susan(af_features* out, const af_array in, const unsigned radius, const float diff_thr, const float geom_thr, const float feature_ratio, const unsigned edge); 12 | af_err af_dog(af_array *out, const af_array in, const int radius1, const int radius2); 13 | af_err af_homography(af_array *H, int *inliers, const af_array x_src, const af_array y_src, const af_array x_dst, const af_array y_dst, const af_homography_type htype, const float inlier_thr, const unsigned iterations, const af_dtype otype); 14 | -------------------------------------------------------------------------------- /nix/default.nix: -------------------------------------------------------------------------------- 1 | { stdenv, fetchurl, fetchFromGitHub, cmake, pkgconfig 2 | , cudatoolkit, opencl-clhpp, ocl-icd, fftw, fftwFloat, mkl 3 | , blas, openblas, boost, mesa_noglu, libGLU_combined 4 | , freeimage, python, lib 5 | }: 6 | 7 | let 8 | version = "3.6.4"; 9 | 10 | clfftSource = fetchFromGitHub { 11 | owner = "arrayfire"; 12 | repo = "clFFT"; 13 | rev = "16925fb93338b3cac66490b5cf764953d6a5dac7"; 14 | sha256 = "0y35nrdz7w4n1l17myhkni3hwm37z775xn6f76xmf1ph7dbkslsc"; 15 | fetchSubmodules = true; 
16 | }; 17 | 18 | clblasSource = fetchFromGitHub { 19 | owner = "arrayfire"; 20 | repo = "clBLAS"; 21 | rev = "1f3de2ae5582972f665c685b18ef0df43c1792bb"; 22 | sha256 = "154mz52r5hm0jrp5fqrirzzbki14c1jkacj75flplnykbl36ibjs"; 23 | fetchSubmodules = true; 24 | }; 25 | 26 | cl2hppSource = fetchurl { 27 | url = "https://github.com/KhronosGroup/OpenCL-CLHPP/releases/download/v2.0.10/cl2.hpp"; 28 | sha256 = "1v4q0g6b6mwwsi0kn7kbjn749j3qafb9r4ld3zdq1163ln9cwnvw"; 29 | }; 30 | 31 | in stdenv.mkDerivation { 32 | pname = "arrayfire"; 33 | inherit version; 34 | 35 | src = fetchurl { 36 | url = "http://arrayfire.com/arrayfire_source/arrayfire-full-${version}.tar.bz2"; 37 | sha256 = "1fin7a9rliyqic3z83agkpb8zlq663q6gdxsnm156cs8s7f7rc9h"; 38 | }; 39 | 40 | cmakeFlags = [ 41 | "-DAF_BUILD_OPENCL=OFF" 42 | "-DAF_BUILD_EXAMPLES=OFF" 43 | "-DBUILD_TESTING=OFF" 44 | ] ++ (lib.optionals stdenv.isLinux ["-DCMAKE_LIBRARY_PATH=${cudatoolkit}/lib/stubs"]); 45 | 46 | patches = [ ./no-download.patch ]; 47 | 48 | postPatch = '' 49 | mkdir -p ./build/third_party/clFFT/src 50 | cp -R --no-preserve=mode,ownership ${clfftSource}/ ./build/third_party/clFFT/src/clFFT-ext/ 51 | mkdir -p ./build/third_party/clBLAS/src 52 | cp -R --no-preserve=mode,ownership ${clblasSource}/ ./build/third_party/clBLAS/src/clBLAS-ext/ 53 | mkdir -p ./build/include/CL 54 | cp -R --no-preserve=mode,ownership ${cl2hppSource} ./build/include/CL/cl2.hpp 55 | ''; 56 | 57 | preBuild = lib.optionalString stdenv.isLinux '' 58 | export CUDA_PATH="${cudatoolkit}" 59 | ''; 60 | 61 | enableParallelBuilding = true; 62 | 63 | buildInputs = [ 64 | cmake pkgconfig 65 | opencl-clhpp fftw fftwFloat 66 | mkl openblas 67 | libGLU_combined 68 | mesa_noglu freeimage 69 | boost.out boost.dev python 70 | ] ++ (lib.optionals stdenv.isLinux [ cudatoolkit ocl-icd ]); 71 | 72 | meta = with stdenv.lib; { 73 | description = "A general-purpose library that simplifies the process of developing software that targets parallel and massively-parallel 
architectures including CPUs, GPUs, and other hardware acceleration devices"; 74 | license = licenses.bsd3; 75 | homepage = https://arrayfire.com/ ; 76 | maintainers = with stdenv.lib.maintainers; [ chessai ]; 77 | inherit version; 78 | }; 79 | } 80 | -------------------------------------------------------------------------------- /nix/no-download.patch: -------------------------------------------------------------------------------- 1 | diff --git a/CMakeModules/build_clBLAS.cmake b/CMakeModules/build_clBLAS.cmake 2 | index 8de529e8..6361b613 100644 3 | --- a/CMakeModules/build_clBLAS.cmake 4 | +++ b/CMakeModules/build_clBLAS.cmake 5 | @@ -14,8 +14,7 @@ find_package(OpenCL) 6 | 7 | ExternalProject_Add( 8 | clBLAS-ext 9 | - GIT_REPOSITORY https://github.com/arrayfire/clBLAS.git 10 | - GIT_TAG arrayfire-release 11 | + DOWNLOAD_COMMAND true 12 | BUILD_BYPRODUCTS ${clBLAS_location} 13 | PREFIX "${prefix}" 14 | INSTALL_DIR "${prefix}" 15 | diff --git a/CMakeModules/build_clFFT.cmake b/CMakeModules/build_clFFT.cmake 16 | index 28be38a3..85e3915e 100644 17 | --- a/CMakeModules/build_clFFT.cmake 18 | +++ b/CMakeModules/build_clFFT.cmake 19 | @@ -20,8 +20,7 @@ ENDIF() 20 | 21 | ExternalProject_Add( 22 | clFFT-ext 23 | - GIT_REPOSITORY https://github.com/arrayfire/clFFT.git 24 | - GIT_TAG arrayfire-release 25 | + DOWNLOAD_COMMAND true 26 | PREFIX "${prefix}" 27 | INSTALL_DIR "${prefix}" 28 | UPDATE_COMMAND "" 29 | -------------------------------------------------------------------------------- /pkg.nix: -------------------------------------------------------------------------------- 1 | { mkDerivation, base, directory, parsec, stdenv, text, vector 2 | , hspec, hspec-discover 3 | }: 4 | mkDerivation { 5 | pname = "arrayfire"; 6 | version = "0.6.0.0"; 7 | src = ./.; 8 | isLibrary = true; 9 | isExecutable = true; 10 | libraryHaskellDepends = [ base vector ]; 11 | executableHaskellDepends = [ base directory parsec text ]; 12 | testHaskellDepends = [ hspec hspec-discover ]; 13 | 
homepage = "https://github.com/arrayfire/arrayfire-haskell"; 14 | description = "Haskell bindings to ArrayFire"; 15 | license = stdenv.lib.licenses.bsd3; 16 | } 17 | -------------------------------------------------------------------------------- /shell.nix: -------------------------------------------------------------------------------- 1 | { pkgs ? import <nixpkgs> {} }: 2 | let 3 | pkg = (import ./default.nix {}).env; 4 | in 5 | pkgs.lib.overrideDerivation pkg (drv: { 6 | shellHook = '' 7 | export AF_PRINT_ERRORS=1 8 | export PATH=$PATH:${pkgs.haskellPackages.doctest}/bin 9 | export PATH=$PATH:${pkgs.haskellPackages.cabal-install}/bin 10 | function ghcid () { 11 | ${pkgs.haskellPackages.ghcid.bin}/bin/ghcid -c 'cabal v1-repl lib:arrayfire' 12 | }; 13 | function test-runner () { 14 | ${pkgs.silver-searcher}/bin/ag -l | \ 15 | ${pkgs.entr}/bin/entr sh -c \ 16 | 'cabal v1-configure --enable-tests && \ 17 | cabal v1-build test && dist/build/test/test' 18 | } 19 | function doctest-runner () { 20 | ${pkgs.silver-searcher}/bin/ag -l | \ 21 | ${pkgs.entr}/bin/entr sh -c \ 22 | 'cabal v1-configure --enable-tests && \ 23 | cabal v1-build doctests && dist/build/doctests/doctests src/ArrayFire/Algorithm.hs' 24 | } 25 | function exe () { 26 | cabal run main 27 | } 28 | function repl () { 29 | cabal v1-repl lib:arrayfire 30 | } 31 | function docs () { 32 | cabal haddock 33 | open ./dist-newstyle/*/*/*/*/doc/html/arrayfire/index.html 34 | } 35 | function upload-docs () { 36 | cabal haddock --haddock-for-hackage 37 | cabal upload -d dist-newstyle/arrayfire-*.*.*.*-docs.tar.gz --publish 38 | } 39 | ''; 40 | }) 41 | -------------------------------------------------------------------------------- /src/ArrayFire.hs: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 2 | -- | 3 | -- Module : ArrayFire 4 | -- Copyright : David Johnson (c) 2019-2020 5 | -- License : BSD3 6 | -- 
Maintainer : David Johnson 7 | -- Stability : Experimental 8 | -- Portability : GHC 9 | -- 10 | -- <> 11 | -- 12 | -------------------------------------------------------------------------------- 13 | module ArrayFire 14 | ( -- * Tutorial 15 | -- $tutorial 16 | 17 | -- ** Modules 18 | -- $modules 19 | 20 | -- ** Exceptions 21 | -- $exceptions 22 | 23 | -- ** Construction 24 | -- $construction 25 | 26 | -- ** Laws 27 | -- $laws 28 | 29 | -- ** Conversion 30 | -- $conversion 31 | 32 | -- ** Serialization 33 | -- $serialization 34 | 35 | -- ** Device 36 | -- $device 37 | module ArrayFire.Algorithm 38 | , module ArrayFire.Arith 39 | , module ArrayFire.Array 40 | , module ArrayFire.Backend 41 | , module ArrayFire.BLAS 42 | , module ArrayFire.Data 43 | , module ArrayFire.Device 44 | , module ArrayFire.Features 45 | , module ArrayFire.Graphics 46 | , module ArrayFire.Image 47 | , module ArrayFire.Index 48 | , module ArrayFire.LAPACK 49 | , module ArrayFire.Random 50 | , module ArrayFire.Signal 51 | , module ArrayFire.Sparse 52 | , module ArrayFire.Statistics 53 | , module ArrayFire.Types 54 | , module ArrayFire.Util 55 | , module ArrayFire.Vision 56 | , module Foreign.C.Types 57 | , module Data.Int 58 | , module Data.Word 59 | , module Data.Complex 60 | , module Foreign.Storable 61 | ) where 62 | 63 | import ArrayFire.Algorithm 64 | import ArrayFire.Arith 65 | import ArrayFire.Array 66 | import ArrayFire.Backend 67 | import ArrayFire.BLAS 68 | import ArrayFire.Data 69 | import ArrayFire.Device 70 | import ArrayFire.Features 71 | import ArrayFire.Graphics 72 | import ArrayFire.Image 73 | import ArrayFire.Index 74 | import ArrayFire.LAPACK 75 | import ArrayFire.Random 76 | import ArrayFire.Signal 77 | import ArrayFire.Sparse 78 | import ArrayFire.Statistics 79 | import ArrayFire.Types 80 | import ArrayFire.Util 81 | import ArrayFire.Vision 82 | import ArrayFire.Orphans () 83 | import Foreign.Storable 84 | import Foreign.C.Types 85 | import Data.Int 86 | import Data.Complex 
87 | import Data.Word 88 | 89 | -- $tutorial 90 | -- 91 | -- [ArrayFire](http://arrayfire.org/docs/gettingstarted.htm) is a high performance parallel computing library that features modules for statistical and numerical methods. 92 | -- Example usage is depicted below. 93 | -- 94 | -- @ 95 | -- module Main where 96 | -- 97 | -- import qualified ArrayFire as A 98 | -- 99 | -- main :: IO () 100 | -- main = print $ A.matrix @Double (3,2) [[1,2,3],[4,5,6]] 101 | -- @ 102 | -- 103 | -- Each 'Array' is constructed and displayed in column-major order. 104 | -- 105 | -- @ 106 | -- ArrayFire Array 107 | -- [3 2 1 1] 108 | -- 1.0000 4.0000 109 | -- 2.0000 5.0000 110 | -- 3.0000 6.0000 111 | -- @ 112 | 113 | -- $modules 114 | -- 115 | -- All child modules are re-exported top-level in the "ArrayFire" module. 116 | -- We recommend importing "ArrayFire" qualified so as to avoid naming collisions. 117 | -- 118 | -- >>> import qualified ArrayFire as A 119 | -- 120 | 121 | -- $exceptions 122 | -- 123 | -- @ 124 | -- {\-\# LANGUAGE TypeApplications \#\-} 125 | -- module Main where 126 | -- 127 | -- import qualified ArrayFire as A 128 | -- import Control.Exception ( catch ) 129 | -- 130 | -- main :: IO () 131 | -- main = A.printArray action \`catch\` (\\(e :: A.AFException) -> print e) 132 | -- where 133 | -- action = 134 | -- A.matrix \@Double (3,3) [[1..],[1..],[1..]] 135 | -- \`A.mul\` A.matrix \@Double (2,2) [[1..],[1..]] 136 | -- @ 137 | -- 138 | -- The above operation is invalid since the matrix multiply has improper dimensions. The caught exception produces the following error: 139 | -- 140 | -- > AFException {afExceptionType = SizeError, afExceptionCode = 203, afExceptionMsg = "Invalid input size"} 141 | -- 142 | 143 | -- $construction 144 | -- An 'Array' can be constructed using the following smart constructors: 145 | -- 146 | -- /Note/: All smart constructors (and ArrayFire internally) assume column-major order. 
147 | -- 148 | -- @ 149 | -- >>> scalar \@Double 2.0 150 | -- ArrayFire Array 151 | -- [1 1 1 1] 152 | -- 2.0000 153 | -- @ 154 | -- 155 | -- @ 156 | -- >>> vector \@Double 10 [1..] 157 | -- ArrayFire Array 158 | -- [10 1 1 1] 159 | -- 1.0000 160 | -- 2.0000 161 | -- 3.0000 162 | -- 4.0000 163 | -- 5.0000 164 | -- 6.0000 165 | -- 7.0000 166 | -- 8.0000 167 | -- 9.0000 168 | -- 10.0000 169 | -- @ 170 | -- 171 | -- @ 172 | -- >>> matrix \@Double (2,2) [[1,2],[3,4]] 173 | -- ArrayFire Array 174 | -- [2 2 1 1] 175 | -- 1.0000 3.0000 176 | -- 2.0000 4.0000 177 | -- @ 178 | -- 179 | -- @ 180 | -- >>> cube \@Double (2,2,2) [[[2,2],[2,2]],[[2,2],[2,2]]] 181 | -- ArrayFire Array 182 | -- [2 2 2 1] 183 | -- 2.0000 2.0000 184 | -- 2.0000 2.0000 185 | -- 186 | -- 2.0000 2.0000 187 | -- 2.0000 2.0000 188 | -- @ 189 | -- 190 | -- @ 191 | -- >>> tensor \@Double (2,2,2,2) [[[[2,2],[2,2]],[[2,2],[2,2]]], [[[2,2],[2,2]],[[2,2],[2,2]]]] 192 | -- ArrayFire Array 193 | -- [2 2 2 2] 194 | -- 2.0000 2.0000 195 | -- 2.0000 2.0000 196 | -- 197 | -- 2.0000 2.0000 198 | -- 2.0000 2.0000 199 | -- 200 | -- 201 | -- 2.0000 2.0000 202 | -- 2.0000 2.0000 203 | -- 204 | -- 2.0000 2.0000 205 | -- 2.0000 2.0000 206 | -- @ 207 | -- 208 | -- Array construction can use Haskell's lazy lists, since 'take' is called on each dimension before sending to the C API. 209 | -- 210 | -- >>> mkArray @Double [5,3] [1..] 211 | -- ArrayFire Array 212 | -- [5 3 1 1] 213 | -- 1.0000 6.0000 11.0000 214 | -- 2.0000 7.0000 12.0000 215 | -- 3.0000 8.0000 13.0000 216 | -- 4.0000 9.0000 14.0000 217 | -- 5.0000 10.0000 15.0000 218 | -- 219 | -- Specifying up to 4 dimensions is allowed (anything higher is ignored). 
220 | 221 | -- $laws 222 | -- Every 'Array' has an instance of 'Eq', 'Num', 'Fractional', 'Floating' and 'Show'. 223 | -- 224 | -- 'Num' 225 | -- 226 | -- >>> 2.0 :: Array Double 227 | -- ArrayFire Array 228 | -- [1 1 1 1] 229 | -- 2.0000 230 | -- 231 | -- >>> scalar @Int 1 + scalar @Int 1 232 | -- ArrayFire Array 233 | -- [1 1 1 1] 234 | -- 2 235 | -- 236 | -- >>> scalar @Int 1 - scalar @Int 1 237 | -- ArrayFire Array 238 | -- [1 1 1 1] 239 | -- 0 240 | -- 241 | -- >>> scalar @Double 10 / scalar @Double 10 242 | -- ArrayFire Array 243 | -- [1 1 1 1] 244 | -- 1.0000 245 | -- 246 | -- >>> abs $ scalar @Double (-10) 247 | -- ArrayFire Array 248 | -- [1 1 1 1] 249 | -- 10.0000 250 | -- 251 | -- >>> negate (scalar @Double 10) 252 | -- ArrayFire Array 253 | -- [1 1 1 1] 254 | -- -10.0000 255 | -- 256 | -- >>> fromInteger 1 :: Array Double 257 | -- ArrayFire Array 258 | -- [1 1 1 1] 259 | -- 1.0000 260 | -- 261 | -- 'Eq' 262 | -- 263 | -- >>> scalar @Double 1 == scalar @Double 1 264 | -- True 265 | -- >>> scalar @Double 1 /= scalar @Double 1 266 | -- False 267 | -- 268 | -- 'Floating' 269 | -- 270 | -- >>> pi :: Array Double 271 | -- ArrayFire Array 272 | -- [1 1 1 1] 273 | -- 3.1416 274 | -- 275 | -- >>> A.sqrt pi :: Array Double 276 | -- ArrayFire Array 277 | -- [1 1 1 1] 278 | -- 1.7725 279 | -- 280 | -- 'Fractional' 281 | -- 282 | -- >>> (pi :: Array Double) / pi 283 | -- ArrayFire Array 284 | -- [1 1 1 1] 285 | -- 1.0000 286 | -- 287 | -- >>> recip 0.5 :: Array Double 288 | -- ArrayFire Array 289 | -- [1 1 1 1] 290 | -- 2.0000 291 | -- 292 | -- 'Show' 293 | -- 294 | -- >>> 0.0 :: Array Double 295 | -- ArrayFire Array 296 | -- [1 1 1 1] 297 | -- 0.0000 298 | -- 299 | 300 | -- $conversion 301 | -- Any 'Array' can be exported into Haskell using `toVector'. This will create a Storable vector suitable for use in other C programs. 
302 | -- 303 | -- >>> vector :: Vector Double <- toVector <$> randu @Double [10,10] 304 | -- 305 | 306 | -- $serialization 307 | -- Each 'Array' can be serialized to disk and deserialized from disk efficiently. 308 | -- 309 | -- @ 310 | -- import qualified ArrayFire as A 311 | -- import Control.Monad 312 | -- 313 | -- main :: IO () 314 | -- main = do 315 | -- let arr = A.'constant' [1,1,1,1] 10 316 | -- idx <- A.'saveArray' "key" arr "file.array" False 317 | -- foundIndex <- A.'readArrayKeyCheck' "file.array" "key" 318 | -- when (idx == foundIndex) $ do 319 | -- array <- A.'readArrayKey' "file.array" "key" 320 | -- 'print' array 321 | -- 322 | -- -- ArrayFire Array 323 | -- -- [ 1 1 1 1 ] 324 | -- -- 10 325 | -- @ 326 | -- 327 | 328 | -- $device 329 | -- The ArrayFire API is able to see which devices are present, and will by default use the GPU if available. 330 | -- 331 | -- >>> afInfo 332 | -- ArrayFire v3.6.4 (OpenCL, 64-bit Mac OSX, build 1b8030c5) 333 | -- [0] APPLE: AMD Radeon Pro 555X Compute Engine, 4096 MB <-- brackets [] signify device being used. 
334 | -- -1- APPLE: Intel(R) UHD Graphics 630, 1536 MB 335 | -- 336 | 337 | -- $visualization 338 | -- The ArrayFire API is able to display visualizations using the Forge library 339 | -- >>> window <- createWindow 800 600 "Histogram" 340 | -- 341 | -------------------------------------------------------------------------------- /src/ArrayFire/BLAS.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ViewPatterns #-} 2 | -------------------------------------------------------------------------------- 3 | -- | 4 | -- Module : ArrayFire.BLAS 5 | -- Copyright : David Johnson (c) 2019-2020 6 | -- License : BSD3 7 | -- Maintainer : David Johnson 8 | -- Stability : Experimental 9 | -- Portability : GHC 10 | -- 11 | -- Basic Linear Algebra Subprograms (BLAS) API 12 | -- 13 | -- @ 14 | -- main :: IO () 15 | -- main = print (matmul x y xProp yProp) 16 | -- where 17 | -- x,y :: Array Double 18 | -- x = matrix (2,3) [[1,2],[3,4],[5,6]] 19 | -- y = matrix (3,2) [[1,2,3],[4,5,6]] 20 | -- 21 | -- xProp, yProp :: MatProp 22 | -- xProp = None 23 | -- yProp = None 24 | -- @ 25 | -- @ 26 | -- ArrayFire Array 27 | -- [2 2 1 1] 28 | -- 22.0000 49.0000 29 | -- 28.0000 64.0000 30 | -- @ 31 | -------------------------------------------------------------------------------- 32 | module ArrayFire.BLAS where 33 | 34 | import Data.Complex 35 | 36 | import ArrayFire.FFI 37 | import ArrayFire.Internal.BLAS 38 | import ArrayFire.Internal.Types 39 | 40 | -- | The following applies for Sparse-Dense matrix multiplication. 41 | -- 42 | -- This function can be used with one sparse input. The sparse input must always be the lhs and the dense matrix must be rhs. 43 | -- 44 | -- The sparse array can only be of 'CSR' format. 45 | -- 46 | -- The returned array is always dense. 47 | -- 48 | -- optLhs can only be one of AF_MAT_NONE, AF_MAT_TRANS, AF_MAT_CTRANS. 49 | -- 50 | -- optRhs can only be AF_MAT_NONE. 
51 | -- 52 | -- >>> matmul (matrix @Double (2,2) [[1,2],[3,4]]) (matrix @Double (2,2) [[1,2],[3,4]]) None None 53 | -- ArrayFire Array 54 | -- [2 2 1 1] 55 | -- 7.0000 15.0000 56 | -- 10.0000 22.0000 57 | matmul 58 | :: AFType a 59 | => Array a 60 | -- ^ 2D matrix of Array a, left-hand side 61 | -> Array a 62 | -- ^ 2D matrix of Array a, right-hand side 63 | -> MatProp 64 | -- ^ Left hand side matrix options 65 | -> MatProp 66 | -- ^ Right hand side matrix options 67 | -> Array a 68 | -- ^ Output of 'matmul' 69 | matmul arr1 arr2 prop1 prop2 = do 70 | op2 arr1 arr2 (\p a b -> af_matmul p a b (toMatProp prop1) (toMatProp prop2)) 71 | 72 | -- | Scalar dot product between two vectors. Also referred to as the inner product. 73 | -- 74 | -- >>> dot (vector @Double 10 [1..]) (vector @Double 10 [1..]) None None 75 | -- ArrayFire Array 76 | -- [1 1 1 1] 77 | -- 385.0000 78 | dot 79 | :: AFType a 80 | => Array a 81 | -- ^ Left-hand side input 82 | -> Array a 83 | -- ^ Right-hand side input 84 | -> MatProp 85 | -- ^ Options for left-hand side. Currently only AF_MAT_NONE and AF_MAT_CONJ are supported. 86 | -> MatProp 87 | -- ^ Options for right-hand side. Currently only AF_MAT_NONE and AF_MAT_CONJ are supported. 88 | -> Array a 89 | -- ^ Output of 'dot' 90 | dot arr1 arr2 prop1 prop2 = 91 | op2 arr1 arr2 (\p a b -> af_dot p a b (toMatProp prop1) (toMatProp prop2)) 92 | 93 | -- | Scalar dot product between two vectors. Also referred to as the inner product. Returns the result as a host scalar. 94 | -- 95 | -- >>> dotAll (vector @Double 10 [1..]) (vector @Double 10 [1..]) None None 96 | -- 385.0 :+ 0.0 97 | dotAll 98 | :: AFType a 99 | => Array a 100 | -- ^ Left-hand side array 101 | -> Array a 102 | -- ^ Right-hand side array 103 | -> MatProp 104 | -- ^ Options for left-hand side. Currently only AF_MAT_NONE and AF_MAT_CONJ are supported. 105 | -> MatProp 106 | -- ^ Options for right-hand side. Currently only AF_MAT_NONE and AF_MAT_CONJ are supported. 
107 | -> Complex Double 108 | -- ^ Real and imaginary component result 109 | dotAll arr1 arr2 prop1 prop2 = do 110 | let (real,imag) = 111 | infoFromArray22 arr1 arr2 $ \a b c d -> 112 | af_dot_all a b c d (toMatProp prop1) (toMatProp prop2) 113 | real :+ imag 114 | 115 | -- | Transposes a matrix. 116 | -- 117 | -- >>> array = matrix @Double (2,3) [[2,3],[3,4],[5,6]] 118 | -- >>> array 119 | -- ArrayFire Array 120 | -- [2 3 1 1] 121 | -- 2.0000 3.0000 5.0000 122 | -- 3.0000 4.0000 6.0000 123 | -- 124 | -- >>> transpose array True 125 | -- ArrayFire Array 126 | -- [3 2 1 1] 127 | -- 2.0000 3.0000 128 | -- 3.0000 4.0000 129 | -- 5.0000 6.0000 130 | -- 131 | transpose 132 | :: AFType a 133 | => Array a 134 | -- ^ Input matrix to be transposed 135 | -> Bool 136 | -- ^ Should perform conjugate transposition 137 | -> Array a 138 | -- ^ The transposed matrix 139 | transpose arr1 (fromIntegral . fromEnum -> b) = 140 | arr1 `op1` (\x y -> af_transpose x y b) 141 | 142 | -- | Transposes a matrix. 143 | -- 144 | -- * Warning: This function mutates an array in-place, all subsequent references will be changed. Use carefully. 145 | -- 146 | -- >>> array = matrix @Double (2,2) [[1,2],[3,4]] 147 | -- >>> array 148 | -- ArrayFire Array 149 | -- [2 2 1 1] 150 | -- 1.0000 3.0000 151 | -- 2.0000 4.0000 152 | -- 153 | -- 154 | -- >>> transposeInPlace array False 155 | -- >>> array 156 | -- ArrayFire Array 157 | -- [2 2 1 1] 158 | -- 1.0000 2.0000 159 | -- 3.0000 4.0000 160 | -- 161 | transposeInPlace 162 | :: AFType a 163 | => Array a 164 | -- ^ Input matrix to be transposed 165 | -> Bool 166 | -- ^ Should perform conjugate transposition 167 | -> IO () 168 | transposeInPlace arr (fromIntegral . 
fromEnum -> b) = 169 | arr `inPlace` (`af_transpose_inplace` b) 170 | -------------------------------------------------------------------------------- /src/ArrayFire/Backend.hs: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 2 | -- | 3 | -- Module : ArrayFire.Backend 4 | -- Copyright : David Johnson (c) 2019-2020 5 | -- License : BSD 3 6 | -- Maintainer : David Johnson 7 | -- Stability : Experimental 8 | -- Portability : GHC 9 | -- 10 | -- Set and get available ArrayFire 'Backend's. 11 | -- 12 | -- @ 13 | -- module Main where 14 | -- 15 | -- import ArrayFire 16 | -- 17 | -- main :: IO () 18 | -- main = print =<< getAvailableBackends 19 | -- @ 20 | -- 21 | -- @ 22 | -- [CPU,OpenCL] 23 | -- @ 24 | -------------------------------------------------------------------------------- 25 | module ArrayFire.Backend where 26 | 27 | import ArrayFire.FFI 28 | import ArrayFire.Internal.Backend 29 | import ArrayFire.Internal.Types 30 | 31 | -- | Set specific 'Backend' to use 32 | -- 33 | -- >>> setBackend OpenCL 34 | -- () 35 | setBackend 36 | :: Backend 37 | -- ^ 'Backend' to use for 'Array' construction 38 | -> IO () 39 | setBackend = afCall . af_set_backend . toAFBackend 40 | 41 | -- | Retrieve count of Backends available 42 | -- 43 | -- >>> getBackendCount 44 | -- 2 45 | -- 46 | getBackendCount :: IO Int 47 | getBackendCount = 48 | fromIntegral <$> 49 | afCall1 af_get_backend_count 50 | 51 | -- | Retrieve available 'Backend's 52 | -- 53 | -- >>> mapM_ print =<< getAvailableBackends 54 | -- CPU 55 | -- OpenCL 56 | getAvailableBackends :: IO [Backend] 57 | getAvailableBackends = 58 | toBackends . fromIntegral <$> 59 | afCall1 af_get_available_backends 60 | 61 | -- | Retrieve 'Backend' that specific 'Array' was created from 62 | -- 63 | -- >>> getBackend (scalar @Double 2.0) 64 | -- OpenCL 65 | getBackend :: Array a -> Backend 66 | getBackend = toBackend . 
flip infoFromArray af_get_backend_id 67 | 68 | -- | Retrieve active 'Backend' 69 | -- 70 | -- >>> getActiveBackend 71 | -- OpenCL 72 | getActiveBackend :: IO Backend 73 | getActiveBackend = toBackend <$> afCall1 af_get_active_backend 74 | 75 | -- | Retrieve the ID of the device that the 'Array' was created on 76 | -- 77 | -- >>> getDeviceID (scalar @Double 2.0) 78 | -- 1 79 | getDeviceID :: Array a -> Int 80 | getDeviceID = fromIntegral . flip infoFromArray af_get_device_id 81 | -------------------------------------------------------------------------------- /src/ArrayFire/Device.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ViewPatterns #-} 2 | -------------------------------------------------------------------------------- 3 | -- | 4 | -- Module : ArrayFire.Device 5 | -- Copyright : David Johnson (c) 2019-2020 6 | -- License : BSD3 7 | -- Maintainer : David Johnson 8 | -- Stability : Experimental 9 | -- Portability : GHC 10 | -- 11 | -- Information about ArrayFire API and devices 12 | -- 13 | -- >>> info 14 | -- ArrayFire v3.6.4 (OpenCL, 64-bit Mac OSX, build 1b8030c5) 15 | -- [0] APPLE: AMD Radeon Pro 555X Compute Engine, 4096 MB 16 | -- -1- APPLE: Intel(R) UHD Graphics 630, 1536 MB 17 | -- 18 | -------------------------------------------------------------------------------- 19 | module ArrayFire.Device where 20 | 21 | import Foreign.C.String 22 | import ArrayFire.Internal.Device 23 | import ArrayFire.FFI 24 | 25 | -- | Retrieve info from ArrayFire API 26 | -- 27 | -- @ 28 | -- ArrayFire v3.6.4 (OpenCL, 64-bit Mac OSX, build 1b8030c5) 29 | -- [0] APPLE: AMD Radeon Pro 555X Compute Engine, 4096 MB 30 | -- -1- APPLE: Intel(R) UHD Graphics 630, 1536 MB 31 | -- @ 32 | info :: IO () 33 | info = afCall af_info 34 | 35 | -- | Calls /af_init/ C function from ArrayFire API 36 | -- 37 | -- >>> afInit 38 | -- () 39 | afInit :: IO () 40 | afInit = afCall af_init 41 | 42 | -- | Retrieves ArrayFire device information as 'String', 
same as 'info'. 43 | -- 44 | -- >>> getInfoString 45 | -- "ArrayFire v3.6.4 (OpenCL, 64-bit Mac OSX, build 1b8030c5)\n[0] APPLE: AMD Radeon Pro 555X Compute Engine, 4096 MB\n-1- APPLE: Intel(R) UHD Graphics 630, 1536 MB\n" 46 | getInfoString :: IO String 47 | getInfoString = peekCString =<< afCall1 (flip af_info_string 1) 48 | 49 | -- af_err af_device_info(char* d_name, char* d_platform, char *d_toolkit, char* d_compute); 50 | 51 | -- | Retrieves count of devices 52 | -- 53 | -- >>> getDeviceCount 54 | -- 2 55 | getDeviceCount :: IO Int 56 | getDeviceCount = fromIntegral <$> afCall1 af_get_device_count 57 | 58 | -- af_err af_get_dbl_support(bool* available, const int device); 59 | -- | Sets a device by 'Int' 60 | -- 61 | -- >>> setDevice 0 62 | -- () 63 | setDevice :: Int -> IO () 64 | setDevice (fromIntegral -> x) = afCall (af_set_device x) 65 | 66 | -- | Retrieves device identifier 67 | -- 68 | -- >>> getDevice 69 | -- 0 70 | getDevice :: IO Int 71 | getDevice = fromIntegral <$> afCall1 af_get_device 72 | 73 | -- af_err af_sync(const int device); 74 | -- af_err af_alloc_device(void **ptr, const dim_t bytes); 75 | -- af_err af_free_device(void *ptr); 76 | -- af_err af_alloc_pinned(void **ptr, const dim_t bytes); 77 | -- af_err af_free_pinned(void *ptr); 78 | -- af_err af_alloc_host(void **ptr, const dim_t bytes); 79 | -- af_err af_free_host(void *ptr); 80 | -- af_err af_device_array(af_array *arr, const void *data, const unsigned ndims, const dim_t * const dims, const af_dtype type); 81 | -- af_err af_device_mem_info(size_t *alloc_bytes, size_t *alloc_buffers, size_t *lock_bytes, size_t *lock_buffers); 82 | -- af_err af_print_mem_info(const char *msg, const int device_id); 83 | -- af_err af_device_gc(); 84 | -- af_err af_set_mem_step_size(const size_t step_bytes); 85 | -- af_err af_get_mem_step_size(size_t *step_bytes); 86 | -- af_err af_lock_device_ptr(const af_array arr); 87 | -- af_err af_unlock_device_ptr(const af_array arr); 88 | -- af_err af_lock_array(const 
af_array arr); 89 | -- af_err af_is_locked_array(bool *res, const af_array arr); 90 | -- af_err af_get_device_ptr(void **ptr, const af_array arr); 91 | -------------------------------------------------------------------------------- /src/ArrayFire/Exception.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ViewPatterns #-} 2 | {-# LANGUAGE RecordWildCards #-} 3 | -------------------------------------------------------------------------------- 4 | -- | 5 | -- Module : ArrayFire.Exception 6 | -- Copyright : David Johnson (c) 2019-2020 7 | -- License : BSD 3 8 | -- Maintainer : David Johnson 9 | -- Stability : Experimental 10 | -- Portability : GHC 11 | -- 12 | -- @ 13 | -- module Main where 14 | -- 15 | -- import ArrayFire 16 | -- 17 | -- main :: IO () 18 | -- main = print =<< getAvailableBackends 19 | -- @ 20 | -- 21 | -- @ 22 | -- [nix-shell:~\/arrayfire]$ .\/main 23 | -- [CPU,OpenCL] 24 | -- @ 25 | -------------------------------------------------------------------------------- 26 | module ArrayFire.Exception where 27 | 28 | import Control.Exception hiding (TypeError) 29 | import Data.Typeable 30 | import Control.Monad 31 | import Foreign.C.String 32 | import Foreign.Ptr 33 | import ArrayFire.Internal.Exception 34 | import ArrayFire.Internal.Defines 35 | 36 | -- | String representation of ArrayFire exception 37 | errorToString :: AFErr -> IO String 38 | errorToString = peekCString <=< af_err_to_string 39 | 40 | -- | ArrayFire exception type 41 | data AFExceptionType 42 | = NoMemoryError 43 | | DriverError 44 | | RuntimeError 45 | | InvalidArrayError 46 | | ArgError 47 | | SizeError 48 | | TypeError 49 | | DiffTypeError 50 | | BatchError 51 | | DeviceError 52 | | NotSupportedError 53 | | NotConfiguredError 54 | | NonFreeError 55 | | NoDblError 56 | | NoGfxError 57 | | LoadLibError 58 | | LoadSymError 59 | | BackendMismatchError 60 | | InternalError 61 | | UnknownError 62 | | UnhandledError 63 | deriving (Show, Eq, 
Typeable) 64 | 65 | -- | Exception type for ArrayFire API 66 | data AFException 67 | = AFException 68 | { afExceptionType :: AFExceptionType 69 | -- ^ The Exception type to throw 70 | , afExceptionCode :: Int 71 | -- ^ Code representing the exception 72 | , afExceptionMsg :: String 73 | -- ^ Exception message 74 | } deriving (Show, Eq, Typeable) 75 | 76 | instance Exception AFException 77 | 78 | -- | Conversion function helper 79 | toAFExceptionType :: AFErr -> AFExceptionType 80 | toAFExceptionType (AFErr 101) = NoMemoryError 81 | toAFExceptionType (AFErr 102) = DriverError 82 | toAFExceptionType (AFErr 103) = RuntimeError 83 | toAFExceptionType (AFErr 201) = InvalidArrayError 84 | toAFExceptionType (AFErr 202) = ArgError 85 | toAFExceptionType (AFErr 203) = SizeError 86 | toAFExceptionType (AFErr 204) = TypeError 87 | toAFExceptionType (AFErr 205) = DiffTypeError 88 | toAFExceptionType (AFErr 207) = BatchError 89 | toAFExceptionType (AFErr 208) = DeviceError 90 | toAFExceptionType (AFErr 301) = NotSupportedError 91 | toAFExceptionType (AFErr 302) = NotConfiguredError 92 | toAFExceptionType (AFErr 303) = NonFreeError 93 | toAFExceptionType (AFErr 401) = NoDblError 94 | toAFExceptionType (AFErr 402) = NoGfxError 95 | toAFExceptionType (AFErr 501) = LoadLibError 96 | toAFExceptionType (AFErr 502) = LoadSymError 97 | toAFExceptionType (AFErr 503) = BackendMismatchError 98 | toAFExceptionType (AFErr 998) = InternalError 99 | toAFExceptionType (AFErr 999) = UnknownError 100 | toAFExceptionType (AFErr _) = UnhandledError 101 | 102 | -- | Throws an ArrayFire Exception 103 | throwAFError :: AFErr -> IO () 104 | throwAFError exitCode = 105 | unless (exitCode == afSuccess) $ do 106 | let AFErr (fromIntegral -> afExceptionCode) = exitCode 107 | afExceptionType = toAFExceptionType exitCode 108 | afExceptionMsg <- errorToString exitCode 109 | throwIO AFException {..} 110 | 111 | foreign import ccall unsafe "&af_release_random_engine" 112 | af_release_random_engine_finalizer :: 
FunPtr (AFRandomEngine -> IO ()) 113 | 114 | foreign import ccall unsafe "&af_destroy_window" 115 | af_release_window_finalizer :: FunPtr (AFWindow -> IO ()) 116 | 117 | foreign import ccall unsafe "&af_release_array" 118 | af_release_array_finalizer :: FunPtr (AFArray -> IO ()) 119 | -------------------------------------------------------------------------------- /src/ArrayFire/Features.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ViewPatterns #-} 2 | -------------------------------------------------------------------------------- 3 | -- | 4 | -- Module : ArrayFire.Features 5 | -- Copyright : David Johnson (c) 2019-2020 6 | -- License : BSD 3 7 | -- Maintainer : David Johnson 8 | -- Stability : Experimental 9 | -- Portability : GHC 10 | -- 11 | -- Functions for constructing and querying 'Features' 12 | -- 13 | -- @ 14 | -- >>> createFeatures 10 15 | -- @ 16 | -- 17 | -------------------------------------------------------------------------------- 18 | module ArrayFire.Features where 19 | 20 | import Foreign.Marshal 21 | import Foreign.Storable 22 | import Foreign.ForeignPtr 23 | import System.IO.Unsafe 24 | 25 | import ArrayFire.Internal.Features 26 | import ArrayFire.Internal.Types 27 | import ArrayFire.FFI 28 | import ArrayFire.Exception 29 | 30 | -- | Construct Features 31 | -- 32 | -- >>> features = createFeatures 10 33 | -- 34 | createFeatures 35 | :: Int 36 | -> Features 37 | createFeatures (fromIntegral -> n) = 38 | unsafePerformIO $ do 39 | ptr <- 40 | alloca $ \ptrInput -> do 41 | throwAFError =<< ptrInput `af_create_features` n 42 | peek ptrInput 43 | fptr <- newForeignPtr af_release_features ptr 44 | pure (Features fptr) 45 | 46 | -- | Retain Features 47 | -- 48 | -- >>> features = retainFeatures (createFeatures 10) 49 | -- 50 | retainFeatures 51 | :: Features 52 | -> Features 53 | retainFeatures = (`op1f` af_retain_features) 54 | 55 | -- | Get number of Features 56 | -- 57 | -- link 58 | -- 59 | 
-- >>> getFeaturesNum (createFeatures 10) 60 | -- 10 61 | -- 62 | getFeaturesNum 63 | :: Features 64 | -> Int 65 | getFeaturesNum = fromIntegral . (`infoFromFeatures` af_get_features_num) 66 | 67 | -- | Get Feature X-position 68 | -- 69 | -- >>> getFeaturesXPos (createFeatures 10) 70 | -- ArrayFire Array 71 | -- [10 1 1 1] 72 | -- 0.0000 73 | -- 1.8750 74 | -- 0.0000 75 | -- 2.3750 76 | -- 0.0000 77 | -- 2.5938 78 | -- 0.0000 79 | -- 2.0000 80 | -- 0.0000 81 | -- 2.4375 82 | getFeaturesXPos 83 | :: Features 84 | -> Array a 85 | getFeaturesXPos = (`featuresToArray` af_get_features_xpos) 86 | 87 | -- | Get Feature Y-position 88 | -- 89 | -- >>> getFeaturesYPos (createFeatures 10) 90 | -- ArrayFire Array 91 | -- [10 1 1 1] 92 | -- nan 93 | -- nan 94 | -- nan 95 | -- nan 96 | -- nan 97 | -- nan 98 | -- nan 99 | -- nan 100 | -- nan 101 | -- nan 102 | getFeaturesYPos 103 | :: Features 104 | -> Array a 105 | getFeaturesYPos = (`featuresToArray` af_get_features_ypos) 106 | 107 | -- | Get Feature Score 108 | -- 109 | -- >>> getFeaturesScore (createFeatures 10) 110 | -- ArrayFire Array 111 | -- [10 1 1 1] 112 | -- nan 113 | -- nan 114 | -- nan 115 | -- nan 116 | -- nan 117 | -- nan 118 | -- nan 119 | -- nan 120 | -- nan 121 | -- nan 122 | getFeaturesScore 123 | :: Features 124 | -> Array a 125 | getFeaturesScore = (`featuresToArray` af_get_features_score) 126 | 127 | -- | Get Feature orientation 128 | -- 129 | -- >>> getFeaturesOrientation (createFeatures 10) 130 | -- ArrayFire Array 131 | -- [10 1 1 1] 132 | -- nan 133 | -- nan 134 | -- nan 135 | -- nan 136 | -- nan 137 | -- nan 138 | -- nan 139 | -- nan 140 | -- nan 141 | -- nan 142 | getFeaturesOrientation 143 | :: Features 144 | -> Array a 145 | getFeaturesOrientation = (`featuresToArray` af_get_features_orientation) 146 | 147 | -- | Get Feature size 148 | -- 149 | -- >>> getFeaturesSize (createFeatures 10) 150 | -- ArrayFire Array 151 | -- [10 1 1 1] 152 | -- nan 153 | -- nan 154 | -- nan 155 | -- nan 156 | -- nan 157 | 
-- nan 158 | -- nan 159 | -- nan 160 | -- nan 161 | -- nan 162 | getFeaturesSize 163 | :: Features 164 | -> Array a 165 | getFeaturesSize = (`featuresToArray` af_get_features_size) 166 | -------------------------------------------------------------------------------- /src/ArrayFire/Index.hs: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 2 | -- | 3 | -- Module : ArrayFire.Index 4 | -- Copyright : David Johnson (c) 2019-2020 5 | -- License : BSD 3 6 | -- Maintainer : David Johnson 7 | -- Stability : Experimental 8 | -- Portability : GHC 9 | -- 10 | -- Functions for indexing into an 'Array' 11 | -- 12 | -------------------------------------------------------------------------------- 13 | module ArrayFire.Index where 14 | 15 | import ArrayFire.Internal.Index 16 | import ArrayFire.Internal.Types 17 | import ArrayFire.FFI 18 | import ArrayFire.Exception 19 | 20 | import Foreign 21 | 22 | import System.IO.Unsafe 23 | import Control.Exception 24 | 25 | -- | Index into an 'Array' by 'Seq' 26 | index 27 | :: Array a 28 | -- ^ 'Array' argument 29 | -> [Seq] 30 | -- ^ 'Seq' to use for indexing 31 | -> Array a 32 | index (Array fptr) seqs = 33 | unsafePerformIO . mask_ . withForeignPtr fptr $ \ptr -> do 34 | alloca $ \aptr -> 35 | withArray (toAFSeq <$> seqs) $ \sptr -> do 36 | throwAFError =<< af_index aptr ptr n sptr 37 | Array <$> do 38 | newForeignPtr af_release_array_finalizer 39 | =<< peek aptr 40 | where 41 | n = fromIntegral (length seqs) 42 | 43 | -- | Lookup an Array by keys along a specified dimension 44 | lookup :: Array a -> Array a -> Int -> Array a 45 | lookup a b n = op2 a b $ \p x y -> af_lookup p x y (fromIntegral n) 46 | 47 | -- af_err af_assign_seq( af_array *out, const af_array lhs, const unsigned ndims, const af_seq* const indices, const af_array rhs); 48 | -- | TODO: 'assignSeq' (binding for af_assign_seq) is not implemented yet; the example below was copied from 'mean' and does not describe it. 
49 | -- 50 | -- @ 51 | -- >>> print $ mean 0 ( vector @Int 10 [1..] ) 52 | -- @ 53 | -- @ 54 | -- ArrayFire Array 55 | -- [1 1 1 1] 56 | -- 5.5000 57 | -- @ 58 | -- assignSeq :: Array a -> Int -> [Seq] -> Array a -> Array a 59 | -- assignSeq = error "Not implemented" 60 | 61 | -- af_err af_index_gen( af_array *out, const af_array in, const dim_t ndims, const af_index_t* indices); 62 | -- | TODO: 'indexGen' (binding for af_index_gen) is not implemented yet; the example below was copied from 'mean' and does not describe it. 63 | -- 64 | -- @ 65 | -- >>> print $ mean 0 ( vector @Int 10 [1..] ) 66 | -- @ 67 | -- @ 68 | -- ArrayFire Array 69 | -- [1 1 1 1] 70 | -- 5.5000 71 | -- @ 72 | -- indexGen :: Array a -> Int -> [Index a] -> Array a -> Array a 73 | -- indexGen = error "Not implemented" 74 | 75 | -- af_err af_assign_gen( af_array *out, const af_array lhs, const dim_t ndims, const af_index_t* indices, const af_array rhs); 76 | -- | TODO: 'assignGen' (binding for af_assign_gen) is not implemented yet; the example below was copied from 'mean' and does not describe it. 77 | -- 78 | -- @ 79 | -- >>> print $ mean 0 ( vector @Int 10 [1..] ) 80 | -- @ 81 | -- @ 82 | -- ArrayFire Array 83 | -- [1 1 1 1] 84 | -- 5.5000 85 | -- @ 86 | -- assignGen :: Array a -> Int -> [Index a] -> Array a -> Array a 87 | -- assignGen = error "Not implemented" 88 | 89 | -- af_err af_create_indexers(af_index_t** indexers); 90 | -- af_err af_set_array_indexer(af_index_t* indexer, const af_array idx, const dim_t dim); 91 | -- af_err af_set_seq_indexer(af_index_t* indexer, const af_seq* idx, const dim_t dim, const bool is_batch); 92 | -- af_err af_set_seq_param_indexer(af_index_t* indexer, const double begin, const double end, const double step, const dim_t dim, const bool is_batch); 93 | -- af_err af_release_indexers(af_index_t* indexers); 94 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Algorithm.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Algorithm where 3 | 4 | import 
ArrayFire.Internal.Defines 5 | 6 | import Foreign.Ptr 7 | import Foreign.C.Types 8 | 9 | #include "af/algorithm.h" 10 | foreign import ccall unsafe "af_sum" 11 | af_sum :: Ptr AFArray -> AFArray -> CInt -> IO AFErr 12 | foreign import ccall unsafe "af_sum_nan" 13 | af_sum_nan :: Ptr AFArray -> AFArray -> CInt -> Double -> IO AFErr 14 | foreign import ccall unsafe "af_product" 15 | af_product :: Ptr AFArray -> AFArray -> CInt -> IO AFErr 16 | foreign import ccall unsafe "af_product_nan" 17 | af_product_nan :: Ptr AFArray -> AFArray -> CInt -> Double -> IO AFErr 18 | foreign import ccall unsafe "af_min" 19 | af_min :: Ptr AFArray -> AFArray -> CInt -> IO AFErr 20 | foreign import ccall unsafe "af_max" 21 | af_max :: Ptr AFArray -> AFArray -> CInt -> IO AFErr 22 | foreign import ccall unsafe "af_all_true" 23 | af_all_true :: Ptr AFArray -> AFArray -> CInt -> IO AFErr 24 | foreign import ccall unsafe "af_any_true" 25 | af_any_true :: Ptr AFArray -> AFArray -> CInt -> IO AFErr 26 | foreign import ccall unsafe "af_count" 27 | af_count :: Ptr AFArray -> AFArray -> CInt -> IO AFErr 28 | foreign import ccall unsafe "af_sum_all" 29 | af_sum_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr 30 | foreign import ccall unsafe "af_sum_nan_all" 31 | af_sum_nan_all :: Ptr Double -> Ptr Double -> AFArray -> Double -> IO AFErr 32 | foreign import ccall unsafe "af_product_all" 33 | af_product_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr 34 | foreign import ccall unsafe "af_product_nan_all" 35 | af_product_nan_all :: Ptr Double -> Ptr Double -> AFArray -> Double -> IO AFErr 36 | foreign import ccall unsafe "af_min_all" 37 | af_min_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr 38 | foreign import ccall unsafe "af_max_all" 39 | af_max_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr 40 | foreign import ccall unsafe "af_all_true_all" 41 | af_all_true_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr 42 | foreign import ccall unsafe "af_any_true_all" 43 | 
af_any_true_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr 44 | foreign import ccall unsafe "af_count_all" 45 | af_count_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr 46 | foreign import ccall unsafe "af_imin" 47 | af_imin :: Ptr AFArray -> Ptr AFArray -> AFArray -> CInt -> IO AFErr 48 | foreign import ccall unsafe "af_imax" 49 | af_imax :: Ptr AFArray -> Ptr AFArray -> AFArray -> CInt -> IO AFErr 50 | foreign import ccall unsafe "af_imin_all" 51 | af_imin_all :: Ptr Double -> Ptr Double -> Ptr CUInt -> AFArray -> IO AFErr 52 | foreign import ccall unsafe "af_imax_all" 53 | af_imax_all :: Ptr Double -> Ptr Double -> Ptr CUInt -> AFArray -> IO AFErr 54 | foreign import ccall unsafe "af_accum" 55 | af_accum :: Ptr AFArray -> AFArray -> CInt -> IO AFErr 56 | foreign import ccall unsafe "af_scan" 57 | af_scan :: Ptr AFArray -> AFArray -> CInt -> AFBinaryOp -> CBool -> IO AFErr 58 | foreign import ccall unsafe "af_scan_by_key" 59 | af_scan_by_key :: Ptr AFArray -> AFArray -> AFArray -> CInt -> AFBinaryOp -> CBool -> IO AFErr 60 | foreign import ccall unsafe "af_where" 61 | af_where :: Ptr AFArray -> AFArray -> IO AFErr 62 | foreign import ccall unsafe "af_diff1" 63 | af_diff1 :: Ptr AFArray -> AFArray -> CInt -> IO AFErr 64 | foreign import ccall unsafe "af_diff2" 65 | af_diff2 :: Ptr AFArray -> AFArray -> CInt -> IO AFErr 66 | foreign import ccall unsafe "af_sort" 67 | af_sort :: Ptr AFArray -> AFArray -> CUInt -> CBool -> IO AFErr 68 | foreign import ccall unsafe "af_sort_index" 69 | af_sort_index :: Ptr AFArray -> Ptr AFArray -> AFArray -> CUInt -> CBool -> IO AFErr 70 | foreign import ccall unsafe "af_sort_by_key" 71 | af_sort_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CUInt -> CBool -> IO AFErr 72 | foreign import ccall unsafe "af_set_unique" 73 | af_set_unique :: Ptr AFArray -> AFArray -> CBool -> IO AFErr 74 | foreign import ccall unsafe "af_set_union" 75 | af_set_union :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 
76 | foreign import ccall unsafe "af_set_intersect" 77 | af_set_intersect :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 78 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Arith.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Arith where 3 | 4 | import ArrayFire.Internal.Defines 5 | 6 | import Foreign.Ptr 7 | import Foreign.C.Types 8 | 9 | #include "af/arith.h" 10 | foreign import ccall unsafe "af_add" 11 | af_add :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 12 | foreign import ccall unsafe "af_sub" 13 | af_sub :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 14 | foreign import ccall unsafe "af_mul" 15 | af_mul :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 16 | foreign import ccall unsafe "af_div" 17 | af_div :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 18 | foreign import ccall unsafe "af_lt" 19 | af_lt :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 20 | foreign import ccall unsafe "af_gt" 21 | af_gt :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 22 | foreign import ccall unsafe "af_le" 23 | af_le :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 24 | foreign import ccall unsafe "af_ge" 25 | af_ge :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 26 | foreign import ccall unsafe "af_eq" 27 | af_eq :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 28 | foreign import ccall unsafe "af_neq" 29 | af_neq :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 30 | foreign import ccall unsafe "af_and" 31 | af_and :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 32 | foreign import ccall unsafe "af_or" 33 | af_or :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 34 | foreign import ccall unsafe "af_not" 35 | af_not :: Ptr AFArray -> AFArray -> IO AFErr 36 | foreign import ccall unsafe "af_bitand" 37 
| af_bitand :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 38 | foreign import ccall unsafe "af_bitor" 39 | af_bitor :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 40 | foreign import ccall unsafe "af_bitxor" 41 | af_bitxor :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 42 | foreign import ccall unsafe "af_bitshiftl" 43 | af_bitshiftl :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 44 | foreign import ccall unsafe "af_bitshiftr" 45 | af_bitshiftr :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 46 | foreign import ccall unsafe "af_cast" 47 | af_cast :: Ptr AFArray -> AFArray -> AFDtype -> IO AFErr 48 | foreign import ccall unsafe "af_minof" 49 | af_minof :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 50 | foreign import ccall unsafe "af_maxof" 51 | af_maxof :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 52 | foreign import ccall unsafe "af_clamp" 53 | af_clamp :: Ptr AFArray -> AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 54 | foreign import ccall unsafe "af_rem" 55 | af_rem :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 56 | foreign import ccall unsafe "af_mod" 57 | af_mod :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 58 | foreign import ccall unsafe "af_abs" 59 | af_abs :: Ptr AFArray -> AFArray -> IO AFErr 60 | foreign import ccall unsafe "af_arg" 61 | af_arg :: Ptr AFArray -> AFArray -> IO AFErr 62 | foreign import ccall unsafe "af_sign" 63 | af_sign :: Ptr AFArray -> AFArray -> IO AFErr 64 | foreign import ccall unsafe "af_round" 65 | af_round :: Ptr AFArray -> AFArray -> IO AFErr 66 | foreign import ccall unsafe "af_trunc" 67 | af_trunc :: Ptr AFArray -> AFArray -> IO AFErr 68 | foreign import ccall unsafe "af_floor" 69 | af_floor :: Ptr AFArray -> AFArray -> IO AFErr 70 | foreign import ccall unsafe "af_ceil" 71 | af_ceil :: Ptr AFArray -> AFArray -> IO AFErr 72 | foreign import ccall unsafe "af_hypot" 73 | af_hypot :: Ptr AFArray -> AFArray -> AFArray -> 
CBool -> IO AFErr 74 | foreign import ccall unsafe "af_sin" 75 | af_sin :: Ptr AFArray -> AFArray -> IO AFErr 76 | foreign import ccall unsafe "af_cos" 77 | af_cos :: Ptr AFArray -> AFArray -> IO AFErr 78 | foreign import ccall unsafe "af_tan" 79 | af_tan :: Ptr AFArray -> AFArray -> IO AFErr 80 | foreign import ccall unsafe "af_asin" 81 | af_asin :: Ptr AFArray -> AFArray -> IO AFErr 82 | foreign import ccall unsafe "af_acos" 83 | af_acos :: Ptr AFArray -> AFArray -> IO AFErr 84 | foreign import ccall unsafe "af_atan" 85 | af_atan :: Ptr AFArray -> AFArray -> IO AFErr 86 | foreign import ccall unsafe "af_atan2" 87 | af_atan2 :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 88 | foreign import ccall unsafe "af_cplx2" 89 | af_cplx2 :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 90 | foreign import ccall unsafe "af_cplx" 91 | af_cplx :: Ptr AFArray -> AFArray -> IO AFErr 92 | foreign import ccall unsafe "af_real" 93 | af_real :: Ptr AFArray -> AFArray -> IO AFErr 94 | foreign import ccall unsafe "af_imag" 95 | af_imag :: Ptr AFArray -> AFArray -> IO AFErr 96 | foreign import ccall unsafe "af_conjg" 97 | af_conjg :: Ptr AFArray -> AFArray -> IO AFErr 98 | foreign import ccall unsafe "af_sinh" 99 | af_sinh :: Ptr AFArray -> AFArray -> IO AFErr 100 | foreign import ccall unsafe "af_cosh" 101 | af_cosh :: Ptr AFArray -> AFArray -> IO AFErr 102 | foreign import ccall unsafe "af_tanh" 103 | af_tanh :: Ptr AFArray -> AFArray -> IO AFErr 104 | foreign import ccall unsafe "af_asinh" 105 | af_asinh :: Ptr AFArray -> AFArray -> IO AFErr 106 | foreign import ccall unsafe "af_acosh" 107 | af_acosh :: Ptr AFArray -> AFArray -> IO AFErr 108 | foreign import ccall unsafe "af_atanh" 109 | af_atanh :: Ptr AFArray -> AFArray -> IO AFErr 110 | foreign import ccall unsafe "af_root" 111 | af_root :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 112 | foreign import ccall unsafe "af_pow" 113 | af_pow :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 114 
| foreign import ccall unsafe "af_pow2" 115 | af_pow2 :: Ptr AFArray -> AFArray -> IO AFErr 116 | foreign import ccall unsafe "af_exp" 117 | af_exp :: Ptr AFArray -> AFArray -> IO AFErr 118 | foreign import ccall unsafe "af_sigmoid" 119 | af_sigmoid :: Ptr AFArray -> AFArray -> IO AFErr 120 | foreign import ccall unsafe "af_expm1" 121 | af_expm1 :: Ptr AFArray -> AFArray -> IO AFErr 122 | foreign import ccall unsafe "af_erf" 123 | af_erf :: Ptr AFArray -> AFArray -> IO AFErr 124 | foreign import ccall unsafe "af_erfc" 125 | af_erfc :: Ptr AFArray -> AFArray -> IO AFErr 126 | foreign import ccall unsafe "af_log" 127 | af_log :: Ptr AFArray -> AFArray -> IO AFErr 128 | foreign import ccall unsafe "af_log1p" 129 | af_log1p :: Ptr AFArray -> AFArray -> IO AFErr 130 | foreign import ccall unsafe "af_log10" 131 | af_log10 :: Ptr AFArray -> AFArray -> IO AFErr 132 | foreign import ccall unsafe "af_log2" 133 | af_log2 :: Ptr AFArray -> AFArray -> IO AFErr 134 | foreign import ccall unsafe "af_sqrt" 135 | af_sqrt :: Ptr AFArray -> AFArray -> IO AFErr 136 | foreign import ccall unsafe "af_cbrt" 137 | af_cbrt :: Ptr AFArray -> AFArray -> IO AFErr 138 | foreign import ccall unsafe "af_factorial" 139 | af_factorial :: Ptr AFArray -> AFArray -> IO AFErr 140 | foreign import ccall unsafe "af_tgamma" 141 | af_tgamma :: Ptr AFArray -> AFArray -> IO AFErr 142 | foreign import ccall unsafe "af_lgamma" 143 | af_lgamma :: Ptr AFArray -> AFArray -> IO AFErr 144 | foreign import ccall unsafe "af_iszero" 145 | af_iszero :: Ptr AFArray -> AFArray -> IO AFErr 146 | foreign import ccall unsafe "af_isinf" 147 | af_isinf :: Ptr AFArray -> AFArray -> IO AFErr 148 | foreign import ccall unsafe "af_isnan" 149 | af_isnan :: Ptr AFArray -> AFArray -> IO AFErr 150 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Array.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module 
ArrayFire.Internal.Array where 3 | 4 | import ArrayFire.Internal.Defines 5 | 6 | import Foreign.Ptr 7 | import Foreign.C.Types 8 | 9 | #include "af/array.h" 10 | foreign import ccall unsafe "af_create_array" 11 | af_create_array :: Ptr AFArray -> Ptr () -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr 12 | foreign import ccall unsafe "af_create_handle" 13 | af_create_handle :: Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr 14 | foreign import ccall unsafe "af_copy_array" 15 | af_copy_array :: Ptr AFArray -> AFArray -> IO AFErr 16 | foreign import ccall unsafe "af_write_array" 17 | af_write_array :: AFArray -> Ptr () -> CSize -> AFSource -> IO AFErr 18 | foreign import ccall unsafe "af_get_data_ptr" 19 | af_get_data_ptr :: Ptr () -> AFArray -> IO AFErr 20 | foreign import ccall unsafe "af_release_array" 21 | af_release_array :: AFArray -> IO AFErr 22 | foreign import ccall unsafe "af_retain_array" 23 | af_retain_array :: Ptr AFArray -> AFArray -> IO AFErr 24 | foreign import ccall unsafe "af_get_data_ref_count" 25 | af_get_data_ref_count :: Ptr CInt -> AFArray -> IO AFErr 26 | foreign import ccall unsafe "af_eval" 27 | af_eval :: AFArray -> IO AFErr 28 | foreign import ccall unsafe "af_eval_multiple" 29 | af_eval_multiple :: CInt -> Ptr AFArray -> IO AFErr 30 | foreign import ccall unsafe "af_set_manual_eval_flag" 31 | af_set_manual_eval_flag :: CBool -> IO AFErr 32 | foreign import ccall unsafe "af_get_manual_eval_flag" 33 | af_get_manual_eval_flag :: Ptr CBool -> IO AFErr 34 | foreign import ccall unsafe "af_get_elements" 35 | af_get_elements :: Ptr DimT -> AFArray -> IO AFErr 36 | foreign import ccall unsafe "af_get_type" 37 | af_get_type :: Ptr AFDtype -> AFArray -> IO AFErr 38 | foreign import ccall unsafe "af_get_dims" 39 | af_get_dims :: Ptr DimT -> Ptr DimT -> Ptr DimT -> Ptr DimT -> AFArray -> IO AFErr 40 | foreign import ccall unsafe "af_get_numdims" 41 | af_get_numdims :: Ptr CUInt -> AFArray -> IO AFErr 42 | foreign import ccall unsafe "af_is_empty" 
43 | af_is_empty :: Ptr CBool -> AFArray -> IO AFErr 44 | foreign import ccall unsafe "af_is_scalar" 45 | af_is_scalar :: Ptr CBool -> AFArray -> IO AFErr 46 | foreign import ccall unsafe "af_is_row" 47 | af_is_row :: Ptr CBool -> AFArray -> IO AFErr 48 | foreign import ccall unsafe "af_is_column" 49 | af_is_column :: Ptr CBool -> AFArray -> IO AFErr 50 | foreign import ccall unsafe "af_is_vector" 51 | af_is_vector :: Ptr CBool -> AFArray -> IO AFErr 52 | foreign import ccall unsafe "af_is_complex" 53 | af_is_complex :: Ptr CBool -> AFArray -> IO AFErr 54 | foreign import ccall unsafe "af_is_real" 55 | af_is_real :: Ptr CBool -> AFArray -> IO AFErr 56 | foreign import ccall unsafe "af_is_double" 57 | af_is_double :: Ptr CBool -> AFArray -> IO AFErr 58 | foreign import ccall unsafe "af_is_single" 59 | af_is_single :: Ptr CBool -> AFArray -> IO AFErr 60 | foreign import ccall unsafe "af_is_realfloating" 61 | af_is_realfloating :: Ptr CBool -> AFArray -> IO AFErr 62 | foreign import ccall unsafe "af_is_floating" 63 | af_is_floating :: Ptr CBool -> AFArray -> IO AFErr 64 | foreign import ccall unsafe "af_is_integer" 65 | af_is_integer :: Ptr CBool -> AFArray -> IO AFErr 66 | foreign import ccall unsafe "af_is_bool" 67 | af_is_bool :: Ptr CBool -> AFArray -> IO AFErr 68 | foreign import ccall unsafe "af_is_sparse" 69 | af_is_sparse :: Ptr CBool -> AFArray -> IO AFErr 70 | foreign import ccall unsafe "af_get_scalar" 71 | af_get_scalar :: Ptr () -> AFArray -> IO AFErr 72 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/BLAS.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.BLAS where 3 | 4 | import ArrayFire.Internal.Defines 5 | 6 | import Foreign.Ptr 7 | import Foreign.C.Types 8 | 9 | #include "af/blas.h" 10 | foreign import ccall unsafe "af_matmul" 11 | af_matmul :: Ptr AFArray -> AFArray -> AFArray -> AFMatProp -> 
AFMatProp -> IO AFErr 12 | foreign import ccall unsafe "af_dot" 13 | af_dot :: Ptr AFArray -> AFArray -> AFArray -> AFMatProp -> AFMatProp -> IO AFErr 14 | foreign import ccall unsafe "af_dot_all" 15 | af_dot_all :: Ptr Double -> Ptr Double -> AFArray -> AFArray -> AFMatProp -> AFMatProp -> IO AFErr 16 | foreign import ccall unsafe "af_transpose" 17 | af_transpose :: Ptr AFArray -> AFArray -> CBool -> IO AFErr 18 | foreign import ccall unsafe "af_transpose_inplace" 19 | af_transpose_inplace :: AFArray -> CBool -> IO AFErr 20 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Backend.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Backend where 3 | 4 | import ArrayFire.Internal.Defines 5 | 6 | import Foreign.Ptr 7 | import Foreign.C.Types 8 | 9 | #include "af/backend.h" 10 | foreign import ccall unsafe "af_set_backend" 11 | af_set_backend :: AFBackend -> IO AFErr 12 | foreign import ccall unsafe "af_get_backend_count" 13 | af_get_backend_count :: Ptr CUInt -> IO AFErr 14 | foreign import ccall unsafe "af_get_available_backends" 15 | af_get_available_backends :: Ptr CInt -> IO AFErr 16 | foreign import ccall unsafe "af_get_backend_id" 17 | af_get_backend_id :: Ptr AFBackend -> AFArray -> IO AFErr 18 | foreign import ccall unsafe "af_get_active_backend" 19 | af_get_active_backend :: Ptr AFBackend -> IO AFErr 20 | foreign import ccall unsafe "af_get_device_id" 21 | af_get_device_id :: Ptr CInt -> AFArray -> IO AFErr 22 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/CUDA.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.CUDA where 3 | 4 | import ArrayFire.Internal.Defines 5 | import ArrayFire.Internal.Types 6 | import Data.Word 7 | import Data.Int 8 | import Foreign.Ptr 9 | 
import Foreign.C.Types 10 | {- The three imports below bind C symbols with the afcu_ prefix; each returns an AFErr in IO. -} 11 | #include "af/cuda.h" 12 | foreign import ccall unsafe "afcu_get_stream" 13 | afcu_get_stream :: Ptr CudaStreamT -> CInt -> IO AFErr 14 | foreign import ccall unsafe "afcu_get_native_id" 15 | afcu_get_native_id :: Ptr CInt -> CInt -> IO AFErr 16 | foreign import ccall unsafe "afcu_set_native_id" 17 | afcu_set_native_id :: CInt -> IO AFErr -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Data.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Data where 3 | {- Raw FFI imports for af/data.h. Each declaration binds the C symbol quoted in its foreign import line through ccall unsafe and returns an AFErr in IO. -} 4 | import ArrayFire.Internal.Defines 5 | 6 | import Foreign.Ptr 7 | import Foreign.C.Types 8 | 9 | #include "af/data.h" 10 | foreign import ccall unsafe "af_constant" 11 | af_constant :: Ptr AFArray -> Double -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr 12 | foreign import ccall unsafe "af_constant_complex" 13 | af_constant_complex :: Ptr AFArray -> Double -> Double -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr 14 | foreign import ccall unsafe "af_constant_long" 15 | af_constant_long :: Ptr AFArray -> IntL -> CUInt -> Ptr DimT -> IO AFErr 16 | foreign import ccall unsafe "af_constant_ulong" 17 | af_constant_ulong :: Ptr AFArray -> UIntL -> CUInt -> Ptr DimT -> IO AFErr 18 | foreign import ccall unsafe "af_range" 19 | af_range :: Ptr AFArray -> CUInt -> Ptr DimT -> CInt -> AFDtype -> IO AFErr 20 | foreign import ccall unsafe "af_iota" 21 | af_iota :: Ptr AFArray -> CUInt -> Ptr DimT -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr 22 | foreign import ccall unsafe "af_identity" 23 | af_identity :: Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr 24 | foreign import ccall unsafe "af_diag_create" 25 | af_diag_create :: Ptr AFArray -> AFArray -> CInt -> IO AFErr 26 | foreign import ccall unsafe "af_diag_extract" 27 | af_diag_extract :: Ptr AFArray -> AFArray -> CInt -> IO AFErr 28 | foreign import ccall unsafe "af_join" 29 | af_join :: Ptr 
AFArray -> CInt -> AFArray -> AFArray -> IO AFErr 30 | foreign import ccall unsafe "af_join_many" 31 | af_join_many :: Ptr AFArray -> CInt -> CUInt -> Ptr AFArray -> IO AFErr 32 | foreign import ccall unsafe "af_tile" 33 | af_tile :: Ptr AFArray -> AFArray -> CUInt -> CUInt -> CUInt -> CUInt -> IO AFErr 34 | foreign import ccall unsafe "af_reorder" 35 | af_reorder :: Ptr AFArray -> AFArray -> CUInt -> CUInt -> CUInt -> CUInt -> IO AFErr 36 | foreign import ccall unsafe "af_shift" 37 | af_shift :: Ptr AFArray -> AFArray -> CInt -> CInt -> CInt -> CInt -> IO AFErr 38 | foreign import ccall unsafe "af_moddims" 39 | af_moddims :: Ptr AFArray -> AFArray -> CUInt -> Ptr DimT -> IO AFErr 40 | foreign import ccall unsafe "af_flat" 41 | af_flat :: Ptr AFArray -> AFArray -> IO AFErr 42 | foreign import ccall unsafe "af_flip" 43 | af_flip :: Ptr AFArray -> AFArray -> CUInt -> IO AFErr 44 | foreign import ccall unsafe "af_lower" 45 | af_lower :: Ptr AFArray -> AFArray -> CBool -> IO AFErr 46 | foreign import ccall unsafe "af_upper" 47 | af_upper :: Ptr AFArray -> AFArray -> CBool -> IO AFErr 48 | foreign import ccall unsafe "af_select" 49 | af_select :: Ptr AFArray -> AFArray -> AFArray -> AFArray -> IO AFErr 50 | foreign import ccall unsafe "af_select_scalar_r" 51 | af_select_scalar_r :: Ptr AFArray -> AFArray -> AFArray -> Double -> IO AFErr 52 | foreign import ccall unsafe "af_select_scalar_l" 53 | af_select_scalar_l :: Ptr AFArray -> AFArray -> Double -> AFArray -> IO AFErr 54 | foreign import ccall unsafe "af_replace" 55 | af_replace :: AFArray -> AFArray -> AFArray -> IO AFErr 56 | foreign import ccall unsafe "af_replace_scalar" 57 | af_replace_scalar :: AFArray -> AFArray -> Double -> IO AFErr {- af_replace and af_replace_scalar take no Ptr AFArray argument, unlike every other import in this module; presumably they modify their first array in place (TODO confirm). -} 58 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Device.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Device where 3 | {- Raw FFI imports for af/device.h; every declaration in this module binds the C symbol quoted in its foreign import line via ccall unsafe. -} 4 | 
import ArrayFire.Internal.Defines 5 | 6 | import Foreign.Ptr 7 | import Foreign.C.Types 8 | 9 | #include "af/device.h" 10 | foreign import ccall unsafe "af_info" 11 | af_info :: IO AFErr 12 | foreign import ccall unsafe "af_init" 13 | af_init :: IO AFErr 14 | foreign import ccall unsafe "af_info_string" 15 | af_info_string :: Ptr (Ptr CChar) -> CBool -> IO AFErr 16 | foreign import ccall unsafe "af_device_info" 17 | af_device_info :: Ptr CChar -> Ptr CChar -> Ptr CChar -> Ptr CChar -> IO AFErr 18 | foreign import ccall unsafe "af_get_device_count" 19 | af_get_device_count :: Ptr CInt -> IO AFErr 20 | foreign import ccall unsafe "af_get_dbl_support" 21 | af_get_dbl_support :: Ptr CBool -> CInt -> IO AFErr 22 | foreign import ccall unsafe "af_set_device" 23 | af_set_device :: CInt -> IO AFErr 24 | foreign import ccall unsafe "af_get_device" 25 | af_get_device :: Ptr CInt -> IO AFErr 26 | foreign import ccall unsafe "af_sync" 27 | af_sync :: CInt -> IO AFErr {- NOTE(review): imported unsafe. If this call waits for device work to finish, other Haskell threads on the calling capability and any pending GC are held up until it returns; confirm whether a safe import is wanted. -} 28 | foreign import ccall unsafe "af_alloc_device" 29 | af_alloc_device :: Ptr (Ptr ()) -> DimT -> IO AFErr 30 | foreign import ccall unsafe "af_free_device" 31 | af_free_device :: Ptr () -> IO AFErr 32 | foreign import ccall unsafe "af_alloc_pinned" 33 | af_alloc_pinned :: Ptr (Ptr ()) -> DimT -> IO AFErr 34 | foreign import ccall unsafe "af_free_pinned" 35 | af_free_pinned :: Ptr () -> IO AFErr 36 | foreign import ccall unsafe "af_alloc_host" 37 | af_alloc_host :: Ptr (Ptr ()) -> DimT -> IO AFErr 38 | foreign import ccall unsafe "af_free_host" 39 | af_free_host :: Ptr () -> IO AFErr 40 | foreign import ccall unsafe "af_device_array" 41 | af_device_array :: Ptr AFArray -> Ptr () -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr 42 | foreign import ccall unsafe "af_device_mem_info" 43 | af_device_mem_info :: Ptr CSize -> Ptr CSize -> Ptr CSize -> Ptr CSize -> IO AFErr 44 | foreign import ccall unsafe "af_print_mem_info" 45 | af_print_mem_info :: Ptr CChar -> CInt -> IO AFErr 46 | foreign import ccall unsafe 
"af_device_gc" 47 | af_device_gc :: IO AFErr 48 | foreign import ccall unsafe "af_set_mem_step_size" 49 | af_set_mem_step_size :: CSize -> IO AFErr 50 | foreign import ccall unsafe "af_get_mem_step_size" 51 | af_get_mem_step_size :: Ptr CSize -> IO AFErr 52 | foreign import ccall unsafe "af_lock_device_ptr" 53 | af_lock_device_ptr :: AFArray -> IO AFErr 54 | foreign import ccall unsafe "af_unlock_device_ptr" 55 | af_unlock_device_ptr :: AFArray -> IO AFErr 56 | foreign import ccall unsafe "af_lock_array" 57 | af_lock_array :: AFArray -> IO AFErr 58 | foreign import ccall unsafe "af_is_locked_array" 59 | af_is_locked_array :: Ptr CBool -> AFArray -> IO AFErr 60 | foreign import ccall unsafe "af_get_device_ptr" 61 | af_get_device_ptr :: Ptr (Ptr ()) -> AFArray -> IO AFErr 62 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Exception.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Exception where 3 | {- Raw FFI imports for the two error-reporting symbols below. NOTE(review): this module includes af/defines.h, whereas the sibling modules in this dump include a header named after the module; confirm whether af/exception.h was intended. -} 4 | import ArrayFire.Internal.Defines 5 | 6 | import Foreign.Ptr 7 | import Foreign.C.Types 8 | 9 | #include "af/defines.h" 10 | foreign import ccall unsafe "af_get_last_error" 11 | af_get_last_error :: Ptr (Ptr CChar) -> Ptr DimT -> IO () {- Returns IO () rather than an AFErr, unlike most imports in this dump. -} 12 | foreign import ccall unsafe "af_err_to_string" 13 | af_err_to_string :: AFErr -> IO (Ptr CChar) 14 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Features.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Features where 3 | {- Raw FFI imports for af/features.h. All but the last declaration are ccall unsafe calls returning AFErr in IO; the last one imports a function address. -} 4 | import ArrayFire.Internal.Defines 5 | 6 | import Foreign.Ptr 7 | import Foreign.C.Types 8 | 9 | #include "af/features.h" 10 | foreign import ccall unsafe "af_create_features" 11 | af_create_features :: Ptr AFFeatures -> DimT -> IO AFErr 12 | foreign import ccall unsafe "af_retain_features" 13 | 
af_retain_features :: Ptr AFFeatures -> AFFeatures -> IO AFErr 14 | foreign import ccall unsafe "af_get_features_num" 15 | af_get_features_num :: Ptr DimT -> AFFeatures -> IO AFErr 16 | foreign import ccall unsafe "af_get_features_xpos" 17 | af_get_features_xpos :: Ptr AFArray -> AFFeatures -> IO AFErr 18 | foreign import ccall unsafe "af_get_features_ypos" 19 | af_get_features_ypos :: Ptr AFArray -> AFFeatures -> IO AFErr 20 | foreign import ccall unsafe "af_get_features_score" 21 | af_get_features_score :: Ptr AFArray -> AFFeatures -> IO AFErr 22 | foreign import ccall unsafe "af_get_features_orientation" 23 | af_get_features_orientation :: Ptr AFArray -> AFFeatures -> IO AFErr 24 | foreign import ccall unsafe "af_get_features_size" 25 | af_get_features_size :: Ptr AFArray -> AFFeatures -> IO AFErr 26 | foreign import ccall unsafe "&af_release_features" 27 | af_release_features :: FunPtr (AFFeatures -> IO ()) {- Address import (note the ampersand in the symbol string): this yields a FunPtr to the C function instead of calling it. NOTE(review): presumably meant for use as a finalizer; the pointed-to type is declared as returning IO (), so confirm that is compatible with the C prototype. -} 28 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Graphics.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Graphics where 3 | {- Raw FFI imports for af/graphics.h. Every import is ccall unsafe and returns AFErr in IO. NOTE(review): if any of these calls waits on rendering or window events, an unsafe import stalls the calling capability meanwhile; confirm. -} 4 | import ArrayFire.Internal.Defines 5 | import ArrayFire.Internal.Types 6 | 7 | import Foreign.Ptr 8 | import Foreign.C.Types 9 | 10 | #include "af/graphics.h" 11 | foreign import ccall unsafe "af_create_window" 12 | af_create_window :: Ptr AFWindow -> CInt -> CInt -> Ptr CChar -> IO AFErr 13 | foreign import ccall unsafe "af_set_position" 14 | af_set_position :: AFWindow -> CUInt -> CUInt -> IO AFErr 15 | foreign import ccall unsafe "af_set_title" 16 | af_set_title :: AFWindow -> Ptr CChar -> IO AFErr 17 | foreign import ccall unsafe "af_set_size" 18 | af_set_size :: AFWindow -> CUInt -> CUInt -> IO AFErr 19 | foreign import ccall unsafe "af_draw_image" 20 | af_draw_image :: AFWindow -> AFArray -> Ptr AFCell -> IO AFErr 21 | foreign import ccall unsafe "af_draw_plot" 22 | 
af_draw_plot :: AFWindow -> AFArray -> AFArray -> Ptr AFCell -> IO AFErr 23 | foreign import ccall unsafe "af_draw_plot3" 24 | af_draw_plot3 :: AFWindow -> AFArray -> Ptr AFCell -> IO AFErr 25 | foreign import ccall unsafe "af_draw_plot_nd" 26 | af_draw_plot_nd :: AFWindow -> AFArray -> Ptr AFCell -> IO AFErr 27 | foreign import ccall unsafe "af_draw_plot_2d" 28 | af_draw_plot_2d :: AFWindow -> AFArray -> AFArray -> Ptr AFCell -> IO AFErr 29 | foreign import ccall unsafe "af_draw_plot_3d" 30 | af_draw_plot_3d :: AFWindow -> AFArray -> AFArray -> AFArray -> Ptr AFCell -> IO AFErr 31 | foreign import ccall unsafe "af_draw_scatter" 32 | af_draw_scatter :: AFWindow -> AFArray -> AFArray -> AFMarkerType -> Ptr AFCell -> IO AFErr 33 | foreign import ccall unsafe "af_draw_scatter3" 34 | af_draw_scatter3 :: AFWindow -> AFArray -> AFMarkerType -> Ptr AFCell -> IO AFErr 35 | foreign import ccall unsafe "af_draw_scatter_nd" 36 | af_draw_scatter_nd :: AFWindow -> AFArray -> AFMarkerType -> Ptr AFCell -> IO AFErr 37 | foreign import ccall unsafe "af_draw_scatter_2d" 38 | af_draw_scatter_2d :: AFWindow -> AFArray -> AFArray -> AFMarkerType -> Ptr AFCell -> IO AFErr 39 | foreign import ccall unsafe "af_draw_scatter_3d" 40 | af_draw_scatter_3d :: AFWindow -> AFArray -> AFArray -> AFArray -> AFMarkerType -> Ptr AFCell -> IO AFErr 41 | foreign import ccall unsafe "af_draw_hist" 42 | af_draw_hist :: AFWindow -> AFArray -> Double -> Double -> Ptr AFCell -> IO AFErr 43 | foreign import ccall unsafe "af_draw_surface" 44 | af_draw_surface :: AFWindow -> AFArray -> AFArray -> AFArray -> Ptr AFCell -> IO AFErr 45 | foreign import ccall unsafe "af_draw_vector_field_nd" 46 | af_draw_vector_field_nd :: AFWindow -> AFArray -> AFArray -> Ptr AFCell -> IO AFErr 47 | foreign import ccall unsafe "af_draw_vector_field_3d" 48 | af_draw_vector_field_3d :: AFWindow -> AFArray -> AFArray -> AFArray -> AFArray -> AFArray -> AFArray -> Ptr AFCell -> IO AFErr 49 | foreign import ccall unsafe 
"af_draw_vector_field_2d" 50 | af_draw_vector_field_2d :: AFWindow -> AFArray -> AFArray -> AFArray -> AFArray -> Ptr AFCell -> IO AFErr 51 | foreign import ccall unsafe "af_grid" 52 | af_grid :: AFWindow -> CInt -> CInt -> IO AFErr 53 | foreign import ccall unsafe "af_set_axes_limits_compute" 54 | af_set_axes_limits_compute :: AFWindow -> AFArray -> AFArray -> AFArray -> CBool -> Ptr AFCell -> IO AFErr 55 | foreign import ccall unsafe "af_set_axes_limits_2d" 56 | af_set_axes_limits_2d :: AFWindow -> Float -> Float -> Float -> Float -> CBool -> Ptr AFCell -> IO AFErr 57 | foreign import ccall unsafe "af_set_axes_limits_3d" 58 | af_set_axes_limits_3d :: AFWindow -> Float -> Float -> Float -> Float -> Float -> Float -> CBool -> Ptr AFCell -> IO AFErr 59 | foreign import ccall unsafe "af_set_axes_titles" 60 | af_set_axes_titles :: AFWindow -> Ptr CChar -> Ptr CChar -> Ptr CChar -> Ptr AFCell -> IO AFErr 61 | foreign import ccall unsafe "af_show" 62 | af_show :: AFWindow -> IO AFErr 63 | foreign import ccall unsafe "af_is_window_closed" 64 | af_is_window_closed :: Ptr CBool -> AFWindow -> IO AFErr 65 | foreign import ccall unsafe "af_set_visibility" 66 | af_set_visibility :: AFWindow -> CBool -> IO AFErr 67 | foreign import ccall unsafe "af_destroy_window" 68 | af_destroy_window :: AFWindow -> IO AFErr 69 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Image.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Image where 3 | {- Raw FFI imports for af/image.h. Each declaration binds the C symbol quoted in its foreign import line through ccall unsafe and returns an AFErr in IO. -} 4 | import ArrayFire.Internal.Defines 5 | 6 | import Foreign.Ptr 7 | import Foreign.C.Types 8 | 9 | #include "af/image.h" 10 | foreign import ccall unsafe "af_gradient" 11 | af_gradient :: Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr 12 | foreign import ccall unsafe "af_load_image" 13 | af_load_image :: Ptr AFArray -> Ptr CChar -> CBool -> IO AFErr 14 | foreign import ccall unsafe "af_save_image" 
15 | af_save_image :: Ptr CChar -> AFArray -> IO AFErr 16 | foreign import ccall unsafe "af_load_image_memory" 17 | af_load_image_memory :: Ptr AFArray -> Ptr () -> IO AFErr 18 | foreign import ccall unsafe "af_save_image_memory" 19 | af_save_image_memory :: Ptr (Ptr ()) -> AFArray -> AFImageFormat -> IO AFErr 20 | foreign import ccall unsafe "af_delete_image_memory" 21 | af_delete_image_memory :: Ptr () -> IO AFErr 22 | foreign import ccall unsafe "af_load_image_native" 23 | af_load_image_native :: Ptr AFArray -> Ptr CChar -> IO AFErr 24 | foreign import ccall unsafe "af_save_image_native" 25 | af_save_image_native :: Ptr CChar -> AFArray -> IO AFErr 26 | foreign import ccall unsafe "af_is_image_io_available" 27 | af_is_image_io_available :: Ptr CBool -> IO AFErr 28 | foreign import ccall unsafe "af_resize" 29 | af_resize :: Ptr AFArray -> AFArray -> DimT -> DimT -> AFInterpType -> IO AFErr 30 | foreign import ccall unsafe "af_transform" 31 | af_transform :: Ptr AFArray -> AFArray -> AFArray -> DimT -> DimT -> AFInterpType -> CBool -> IO AFErr 32 | foreign import ccall unsafe "af_transform_coordinates" 33 | af_transform_coordinates :: Ptr AFArray -> AFArray -> Float -> Float -> IO AFErr 34 | foreign import ccall unsafe "af_rotate" 35 | af_rotate :: Ptr AFArray -> AFArray -> Float -> CBool -> AFInterpType -> IO AFErr 36 | foreign import ccall unsafe "af_translate" 37 | af_translate :: Ptr AFArray -> AFArray -> Float -> Float -> DimT -> DimT -> AFInterpType -> IO AFErr 38 | foreign import ccall unsafe "af_scale" 39 | af_scale :: Ptr AFArray -> AFArray -> Float -> Float -> DimT -> DimT -> AFInterpType -> IO AFErr 40 | foreign import ccall unsafe "af_skew" 41 | af_skew :: Ptr AFArray -> AFArray -> Float -> Float -> DimT -> DimT -> AFInterpType -> CBool -> IO AFErr 42 | foreign import ccall unsafe "af_histogram" 43 | af_histogram :: Ptr AFArray -> AFArray -> CUInt -> Double -> Double -> IO AFErr 44 | foreign import ccall unsafe "af_dilate" 45 | af_dilate :: Ptr AFArray 
-> AFArray -> AFArray -> IO AFErr 46 | foreign import ccall unsafe "af_dilate3" 47 | af_dilate3 :: Ptr AFArray -> AFArray -> AFArray -> IO AFErr 48 | foreign import ccall unsafe "af_erode" 49 | af_erode :: Ptr AFArray -> AFArray -> AFArray -> IO AFErr 50 | foreign import ccall unsafe "af_erode3" 51 | af_erode3 :: Ptr AFArray -> AFArray -> AFArray -> IO AFErr 52 | foreign import ccall unsafe "af_bilateral" 53 | af_bilateral :: Ptr AFArray -> AFArray -> Float -> Float -> CBool -> IO AFErr 54 | foreign import ccall unsafe "af_mean_shift" 55 | af_mean_shift :: Ptr AFArray -> AFArray -> Float -> Float -> CUInt -> CBool -> IO AFErr 56 | foreign import ccall unsafe "af_minfilt" 57 | af_minfilt :: Ptr AFArray -> AFArray -> DimT -> DimT -> AFBorderType -> IO AFErr 58 | foreign import ccall unsafe "af_maxfilt" 59 | af_maxfilt :: Ptr AFArray -> AFArray -> DimT -> DimT -> AFBorderType -> IO AFErr 60 | foreign import ccall unsafe "af_regions" 61 | af_regions :: Ptr AFArray -> AFArray -> AFConnectivity -> AFDtype -> IO AFErr 62 | foreign import ccall unsafe "af_sobel_operator" 63 | af_sobel_operator :: Ptr AFArray -> Ptr AFArray -> AFArray -> CUInt -> IO AFErr 64 | foreign import ccall unsafe "af_rgb2gray" 65 | af_rgb2gray :: Ptr AFArray -> AFArray -> Float -> Float -> Float -> IO AFErr 66 | foreign import ccall unsafe "af_gray2rgb" 67 | af_gray2rgb :: Ptr AFArray -> AFArray -> Float -> Float -> Float -> IO AFErr 68 | foreign import ccall unsafe "af_hist_equal" 69 | af_hist_equal :: Ptr AFArray -> AFArray -> AFArray -> IO AFErr 70 | foreign import ccall unsafe "af_gaussian_kernel" 71 | af_gaussian_kernel :: Ptr AFArray -> CInt -> CInt -> Double -> Double -> IO AFErr 72 | foreign import ccall unsafe "af_hsv2rgb" 73 | af_hsv2rgb :: Ptr AFArray -> AFArray -> IO AFErr 74 | foreign import ccall unsafe "af_rgb2hsv" 75 | af_rgb2hsv :: Ptr AFArray -> AFArray -> IO AFErr 76 | foreign import ccall unsafe "af_color_space" 77 | af_color_space :: Ptr AFArray -> AFArray -> AFCSpace -> 
AFCSpace -> IO AFErr 78 | foreign import ccall unsafe "af_unwrap" 79 | af_unwrap :: Ptr AFArray -> AFArray -> DimT -> DimT -> DimT -> DimT -> DimT -> DimT -> CBool -> IO AFErr 80 | foreign import ccall unsafe "af_wrap" 81 | af_wrap :: Ptr AFArray -> AFArray -> DimT -> DimT -> DimT -> DimT -> DimT -> DimT -> DimT -> DimT -> CBool -> IO AFErr 82 | foreign import ccall unsafe "af_sat" 83 | af_sat :: Ptr AFArray -> AFArray -> IO AFErr 84 | foreign import ccall unsafe "af_ycbcr2rgb" 85 | af_ycbcr2rgb :: Ptr AFArray -> AFArray -> AFYccStd -> IO AFErr 86 | foreign import ccall unsafe "af_rgb2ycbcr" 87 | af_rgb2ycbcr :: Ptr AFArray -> AFArray -> AFYccStd -> IO AFErr 88 | foreign import ccall unsafe "af_moments" 89 | af_moments :: Ptr AFArray -> AFArray -> AFMomentType -> IO AFErr 90 | foreign import ccall unsafe "af_moments_all" 91 | af_moments_all :: Ptr Double -> AFArray -> AFMomentType -> IO AFErr 92 | foreign import ccall unsafe "af_canny" 93 | af_canny :: Ptr AFArray -> AFArray -> AFCannyThreshold -> Float -> Float -> CUInt -> CBool -> IO AFErr 94 | foreign import ccall unsafe "af_anisotropic_diffusion" 95 | af_anisotropic_diffusion :: Ptr AFArray -> AFArray -> Float -> Float -> CUInt -> AFFluxFunction -> AFDiffusionEq -> IO AFErr 96 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Index.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Index where 3 | {- Raw FFI imports for af/index.h. Each declaration binds the C symbol quoted in its foreign import line through ccall unsafe and returns an AFErr in IO. -} 4 | import ArrayFire.Internal.Defines 5 | import ArrayFire.Internal.Types 6 | 7 | import Foreign.Ptr 8 | import Foreign.C.Types 9 | 10 | #include "af/index.h" 11 | foreign import ccall unsafe "af_index" 12 | af_index :: Ptr AFArray -> AFArray -> CUInt -> Ptr AFSeq -> IO AFErr 13 | foreign import ccall unsafe "af_lookup" 14 | af_lookup :: Ptr AFArray -> AFArray -> AFArray -> CUInt -> IO AFErr 15 | foreign import ccall unsafe "af_assign_seq" 16 | af_assign_seq :: Ptr AFArray 
-> AFArray -> CUInt -> Ptr AFSeq -> AFArray -> IO AFErr 17 | foreign import ccall unsafe "af_index_gen" 18 | af_index_gen :: Ptr AFArray -> AFArray -> DimT -> Ptr AFIndex -> IO AFErr 19 | foreign import ccall unsafe "af_assign_gen" 20 | af_assign_gen :: Ptr AFArray -> AFArray -> DimT -> Ptr AFIndex -> AFArray -> IO AFErr 21 | foreign import ccall unsafe "af_create_indexers" 22 | af_create_indexers :: Ptr (Ptr AFIndex) -> IO AFErr 23 | foreign import ccall unsafe "af_set_array_indexer" 24 | af_set_array_indexer :: Ptr AFIndex -> AFArray -> DimT -> IO AFErr 25 | foreign import ccall unsafe "af_set_seq_indexer" 26 | af_set_seq_indexer :: Ptr AFIndex -> Ptr AFSeq -> DimT -> CBool -> IO AFErr 27 | foreign import ccall unsafe "af_set_seq_param_indexer" 28 | af_set_seq_param_indexer :: Ptr AFIndex -> Double -> Double -> Double -> DimT -> CBool -> IO AFErr 29 | foreign import ccall unsafe "af_release_indexers" 30 | af_release_indexers :: Ptr AFIndex -> IO AFErr 31 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Internal.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Internal where 3 | {- Raw FFI imports for af/internal.h. Each declaration binds the C symbol quoted in its foreign import line through ccall unsafe and returns an AFErr in IO. -} 4 | import ArrayFire.Internal.Defines 5 | import Foreign.Ptr 6 | import Foreign.C.Types 7 | 8 | #include "af/internal.h" 9 | foreign import ccall unsafe "af_create_strided_array" 10 | af_create_strided_array :: Ptr AFArray -> Ptr () -> DimT -> CUInt -> Ptr DimT -> Ptr DimT -> AFDtype -> AFSource -> IO AFErr 11 | foreign import ccall unsafe "af_get_strides" 12 | af_get_strides :: Ptr DimT -> Ptr DimT -> Ptr DimT -> Ptr DimT -> AFArray -> IO AFErr 13 | foreign import ccall unsafe "af_get_offset" 14 | af_get_offset :: Ptr DimT -> AFArray -> IO AFErr 15 | foreign import ccall unsafe "af_get_raw_ptr" 16 | af_get_raw_ptr :: Ptr (Ptr ()) -> AFArray -> IO AFErr 17 | foreign import ccall unsafe "af_is_linear" 18 | af_is_linear :: Ptr CBool -> AFArray 
-> IO AFErr 19 | foreign import ccall unsafe "af_is_owner" 20 | af_is_owner :: Ptr CBool -> AFArray -> IO AFErr 21 | foreign import ccall unsafe "af_get_allocated_bytes" 22 | af_get_allocated_bytes :: Ptr CSize -> AFArray -> IO AFErr 23 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/LAPACK.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.LAPACK where 3 | {- Raw FFI imports for af/lapack.h. Each declaration binds the C symbol quoted in its foreign import line through ccall unsafe and returns an AFErr in IO. -} 4 | import ArrayFire.Internal.Defines 5 | 6 | import Foreign.Ptr 7 | import Foreign.C.Types 8 | 9 | #include "af/lapack.h" 10 | foreign import ccall unsafe "af_svd" 11 | af_svd :: Ptr AFArray -> Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr 12 | foreign import ccall unsafe "af_svd_inplace" 13 | af_svd_inplace :: Ptr AFArray -> Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr 14 | foreign import ccall unsafe "af_lu" 15 | af_lu :: Ptr AFArray -> Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr 16 | foreign import ccall unsafe "af_lu_inplace" 17 | af_lu_inplace :: Ptr AFArray -> AFArray -> CBool -> IO AFErr 18 | foreign import ccall unsafe "af_qr" 19 | af_qr :: Ptr AFArray -> Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr 20 | foreign import ccall unsafe "af_qr_inplace" 21 | af_qr_inplace :: Ptr AFArray -> AFArray -> IO AFErr 22 | foreign import ccall unsafe "af_cholesky" 23 | af_cholesky :: Ptr AFArray -> Ptr CInt -> AFArray -> CBool -> IO AFErr 24 | foreign import ccall unsafe "af_cholesky_inplace" 25 | af_cholesky_inplace :: Ptr CInt -> AFArray -> CBool -> IO AFErr 26 | foreign import ccall unsafe "af_solve" 27 | af_solve :: Ptr AFArray -> AFArray -> AFArray -> AFMatProp -> IO AFErr 28 | foreign import ccall unsafe "af_solve_lu" 29 | af_solve_lu :: Ptr AFArray -> AFArray -> AFArray -> AFArray -> AFMatProp -> IO AFErr 30 | foreign import ccall unsafe "af_inverse" 31 | af_inverse :: Ptr AFArray -> AFArray -> AFMatProp -> IO AFErr 32 | foreign import ccall 
unsafe "af_rank" 33 | af_rank :: Ptr CUInt -> AFArray -> Double -> IO AFErr 34 | foreign import ccall unsafe "af_det" 35 | af_det :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr {- Two Ptr Double slots; presumably the real and imaginary parts of the determinant (TODO confirm against af/lapack.h). -} 36 | foreign import ccall unsafe "af_norm" 37 | af_norm :: Ptr Double -> AFArray -> AFNormType -> Double -> Double -> IO AFErr 38 | foreign import ccall unsafe "af_is_lapack_available" 39 | af_is_lapack_available :: Ptr CBool -> IO AFErr 40 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/OpenCL.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.OpenCL where 3 | {- Raw FFI imports for af/opencl.h (symbols with the afcl_ prefix); each is ccall unsafe and returns AFErr in IO. NOTE(review): Data.Word and Data.Int are imported, yet no name from either appears in the signatures of this module; confirm they are needed. -} 4 | import ArrayFire.Internal.Defines 5 | import ArrayFire.Internal.Types 6 | import Data.Word 7 | import Data.Int 8 | import Foreign.Ptr 9 | import Foreign.C.Types 10 | 11 | #include "af/opencl.h" 12 | foreign import ccall unsafe "afcl_get_context" 13 | afcl_get_context :: Ptr ClContext -> CBool -> IO AFErr 14 | foreign import ccall unsafe "afcl_get_queue" 15 | afcl_get_queue :: Ptr ClCommandQueue -> CBool -> IO AFErr 16 | foreign import ccall unsafe "afcl_get_device_id" 17 | afcl_get_device_id :: Ptr ClDeviceId -> IO AFErr 18 | foreign import ccall unsafe "afcl_set_device_id" 19 | afcl_set_device_id :: ClDeviceId -> IO AFErr 20 | foreign import ccall unsafe "afcl_add_device_context" 21 | afcl_add_device_context :: ClDeviceId -> ClContext -> ClCommandQueue -> IO AFErr 22 | foreign import ccall unsafe "afcl_set_device_context" 23 | afcl_set_device_context :: ClDeviceId -> ClContext -> IO AFErr 24 | foreign import ccall unsafe "afcl_delete_device_context" 25 | afcl_delete_device_context :: ClDeviceId -> ClContext -> IO AFErr 26 | foreign import ccall unsafe "afcl_get_device_type" 27 | afcl_get_device_type :: Ptr AfclDeviceType -> IO AFErr 28 | foreign import ccall unsafe "afcl_get_platform" 29 | afcl_get_platform :: Ptr AFCLPlatform -> IO AFErr 
-------------------------------------------------------------------------------- /src/ArrayFire/Internal/Random.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Random where 3 | {- Raw FFI imports for af/random.h. Each declaration binds the C symbol quoted in its foreign import line through ccall unsafe and returns an AFErr in IO. -} 4 | import ArrayFire.Internal.Defines 5 | 6 | import Foreign.Ptr 7 | import Foreign.C.Types 8 | 9 | #include "af/random.h" 10 | foreign import ccall unsafe "af_create_random_engine" 11 | af_create_random_engine :: Ptr AFRandomEngine -> AFRandomEngineType -> UIntL -> IO AFErr 12 | foreign import ccall unsafe "af_retain_random_engine" 13 | af_retain_random_engine :: Ptr AFRandomEngine -> AFRandomEngine -> IO AFErr 14 | foreign import ccall unsafe "af_random_engine_set_type" 15 | af_random_engine_set_type :: Ptr AFRandomEngine -> AFRandomEngineType -> IO AFErr 16 | foreign import ccall unsafe "af_random_engine_get_type" 17 | af_random_engine_get_type :: Ptr AFRandomEngineType -> AFRandomEngine -> IO AFErr 18 | foreign import ccall unsafe "af_random_uniform" 19 | af_random_uniform :: Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> AFRandomEngine -> IO AFErr 20 | foreign import ccall unsafe "af_random_normal" 21 | af_random_normal :: Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> AFRandomEngine -> IO AFErr 22 | foreign import ccall unsafe "af_random_engine_set_seed" 23 | af_random_engine_set_seed :: Ptr AFRandomEngine -> UIntL -> IO AFErr {- Both setters (set_type, set_seed) take the engine as Ptr AFRandomEngine, while the getters take a plain AFRandomEngine. -} 24 | foreign import ccall unsafe "af_get_default_random_engine" 25 | af_get_default_random_engine :: Ptr AFRandomEngine -> IO AFErr 26 | foreign import ccall unsafe "af_set_default_random_engine_type" 27 | af_set_default_random_engine_type :: AFRandomEngineType -> IO AFErr 28 | foreign import ccall unsafe "af_random_engine_get_seed" 29 | af_random_engine_get_seed :: Ptr UIntL -> AFRandomEngine -> IO AFErr 30 | foreign import ccall unsafe "af_release_random_engine" 31 | af_release_random_engine :: AFRandomEngine -> IO AFErr 32 | foreign import ccall unsafe "af_randu" 33 | 
af_randu :: Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr 34 | foreign import ccall unsafe "af_randn" 35 | af_randn :: Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr 36 | foreign import ccall unsafe "af_set_seed" 37 | af_set_seed :: UIntL -> IO AFErr 38 | foreign import ccall unsafe "af_get_seed" 39 | af_get_seed :: Ptr UIntL -> IO AFErr 40 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Seq.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Seq where 3 | 4 | import ArrayFire.Internal.Defines 5 | import ArrayFire.Internal.Types 6 | import Data.Word 7 | import Data.Int 8 | import Foreign.Ptr 9 | import Foreign.C.Types 10 | 11 | #include "af/seq.h" 12 | foreign import ccall unsafe "af_make_seq" 13 | af_make_seq :: Double -> Double -> Double -> IO AFSeq {- NOTE(review): this is the only import in the module and it returns AFSeq directly instead of writing through a pointer. GHC accepts only basic FFI types (or newtypes of them) as foreign results, so check how AFSeq is defined in ArrayFire.Internal.Types. Data.Word and Data.Int are imported but not used by this signature. -} -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Signal.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Signal where 3 | {- Raw FFI imports for af/signal.h; every declaration visible in this module binds the C symbol quoted in its foreign import line through ccall unsafe and returns an AFErr in IO. -} 4 | import ArrayFire.Internal.Defines 5 | import Foreign.Ptr 6 | import Foreign.C.Types 7 | 8 | #include "af/signal.h" 9 | foreign import ccall unsafe "af_approx1" 10 | af_approx1 :: Ptr AFArray -> AFArray -> AFArray -> AFInterpType -> Float -> IO AFErr 11 | foreign import ccall unsafe "af_approx2" 12 | af_approx2 :: Ptr AFArray -> AFArray -> AFArray -> AFArray -> AFInterpType -> Float -> IO AFErr 13 | foreign import ccall unsafe "af_fft" 14 | af_fft :: Ptr AFArray -> AFArray -> Double -> DimT -> IO AFErr 15 | foreign import ccall unsafe "af_fft_inplace" 16 | af_fft_inplace :: AFArray -> Double -> IO AFErr 17 | foreign import ccall unsafe "af_fft2" 18 | af_fft2 :: Ptr AFArray -> AFArray -> Double -> DimT -> DimT -> IO AFErr 19 | foreign import ccall unsafe "af_fft2_inplace" 20 | 
af_fft2_inplace :: AFArray -> Double -> IO AFErr 21 | foreign import ccall unsafe "af_fft3" 22 | af_fft3 :: Ptr AFArray -> AFArray -> Double -> DimT -> DimT -> DimT -> IO AFErr 23 | foreign import ccall unsafe "af_fft3_inplace" 24 | af_fft3_inplace :: AFArray -> Double -> IO AFErr 25 | foreign import ccall unsafe "af_ifft" 26 | af_ifft :: Ptr AFArray -> AFArray -> Double -> DimT -> IO AFErr 27 | foreign import ccall unsafe "af_ifft_inplace" 28 | af_ifft_inplace :: AFArray -> Double -> IO AFErr 29 | foreign import ccall unsafe "af_ifft2" 30 | af_ifft2 :: Ptr AFArray -> AFArray -> Double -> DimT -> DimT -> IO AFErr 31 | foreign import ccall unsafe "af_ifft2_inplace" 32 | af_ifft2_inplace :: AFArray -> Double -> IO AFErr 33 | foreign import ccall unsafe "af_ifft3" 34 | af_ifft3 :: Ptr AFArray -> AFArray -> Double -> DimT -> DimT -> DimT -> IO AFErr 35 | foreign import ccall unsafe "af_ifft3_inplace" 36 | af_ifft3_inplace :: AFArray -> Double -> IO AFErr 37 | foreign import ccall unsafe "af_fft_r2c" 38 | af_fft_r2c :: Ptr AFArray -> AFArray -> Double -> DimT -> IO AFErr 39 | foreign import ccall unsafe "af_fft2_r2c" 40 | af_fft2_r2c :: Ptr AFArray -> AFArray -> Double -> DimT -> DimT -> IO AFErr 41 | foreign import ccall unsafe "af_fft3_r2c" 42 | af_fft3_r2c :: Ptr AFArray -> AFArray -> Double -> DimT -> DimT -> DimT -> IO AFErr 43 | foreign import ccall unsafe "af_fft_c2r" 44 | af_fft_c2r :: Ptr AFArray -> AFArray -> Double -> CBool -> IO AFErr 45 | foreign import ccall unsafe "af_fft2_c2r" 46 | af_fft2_c2r :: Ptr AFArray -> AFArray -> Double -> CBool -> IO AFErr 47 | foreign import ccall unsafe "af_fft3_c2r" 48 | af_fft3_c2r :: Ptr AFArray -> AFArray -> Double -> CBool -> IO AFErr 49 | foreign import ccall unsafe "af_convolve1" 50 | af_convolve1 :: Ptr AFArray -> AFArray -> AFArray -> AFConvMode -> AFConvDomain -> IO AFErr 51 | foreign import ccall unsafe "af_convolve2" 52 | af_convolve2 :: Ptr AFArray -> AFArray -> AFArray -> AFConvMode -> AFConvDomain -> IO AFErr 
53 | foreign import ccall unsafe "af_convolve3" 54 | af_convolve3 :: Ptr AFArray -> AFArray -> AFArray -> AFConvMode -> AFConvDomain -> IO AFErr 55 | foreign import ccall unsafe "af_convolve2_sep" 56 | af_convolve2_sep :: Ptr AFArray -> AFArray -> AFArray -> AFArray -> AFConvMode -> IO AFErr 57 | foreign import ccall unsafe "af_fft_convolve1" 58 | af_fft_convolve1 :: Ptr AFArray -> AFArray -> AFArray -> AFConvMode -> IO AFErr 59 | foreign import ccall unsafe "af_fft_convolve2" 60 | af_fft_convolve2 :: Ptr AFArray -> AFArray -> AFArray -> AFConvMode -> IO AFErr 61 | foreign import ccall unsafe "af_fft_convolve3" 62 | af_fft_convolve3 :: Ptr AFArray -> AFArray -> AFArray -> AFConvMode -> IO AFErr {- The af_fft_convolve* imports take no AFConvDomain argument, unlike af_convolve1 to af_convolve3. -} 63 | foreign import ccall unsafe "af_fir" 64 | af_fir :: Ptr AFArray -> AFArray -> AFArray -> IO AFErr 65 | foreign import ccall unsafe "af_iir" 66 | af_iir :: Ptr AFArray -> AFArray -> AFArray -> AFArray -> IO AFErr 67 | foreign import ccall unsafe "af_medfilt" 68 | af_medfilt :: Ptr AFArray -> AFArray -> DimT -> DimT -> AFBorderType -> IO AFErr 69 | foreign import ccall unsafe "af_medfilt1" 70 | af_medfilt1 :: Ptr AFArray -> AFArray -> DimT -> AFBorderType -> IO AFErr 71 | foreign import ccall unsafe "af_medfilt2" 72 | af_medfilt2 :: Ptr AFArray -> AFArray -> DimT -> DimT -> AFBorderType -> IO AFErr {- af_medfilt and af_medfilt2 are declared with identical Haskell signatures; only the bound C symbol differs. -} 73 | foreign import ccall unsafe "af_set_fft_plan_cache_size" 74 | af_set_fft_plan_cache_size :: CSize -> IO AFErr 75 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Sparse.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Sparse where 3 | {- Raw FFI imports for af/sparse.h. Each declaration binds the C symbol quoted in its foreign import line through ccall unsafe and returns an AFErr in IO. -} 4 | import ArrayFire.Internal.Defines 5 | import Foreign.Ptr 6 | import Foreign.C.Types 7 | 8 | #include "af/sparse.h" 9 | foreign import ccall unsafe "af_create_sparse_array" 10 | af_create_sparse_array :: Ptr AFArray -> DimT -> DimT -> AFArray -> AFArray -> AFArray -> AFStorage -> IO AFErr 11 
| foreign import ccall unsafe "af_create_sparse_array_from_ptr" 12 | af_create_sparse_array_from_ptr :: Ptr AFArray -> DimT -> DimT -> DimT -> Ptr () -> Ptr CInt -> Ptr CInt -> AFDtype -> AFStorage -> AFSource -> IO AFErr 13 | foreign import ccall unsafe "af_create_sparse_array_from_dense" 14 | af_create_sparse_array_from_dense :: Ptr AFArray -> AFArray -> AFStorage -> IO AFErr 15 | foreign import ccall unsafe "af_sparse_convert_to" 16 | af_sparse_convert_to :: Ptr AFArray -> AFArray -> AFStorage -> IO AFErr 17 | foreign import ccall unsafe "af_sparse_to_dense" 18 | af_sparse_to_dense :: Ptr AFArray -> AFArray -> IO AFErr 19 | foreign import ccall unsafe "af_sparse_get_info" 20 | af_sparse_get_info :: Ptr AFArray -> Ptr AFArray -> Ptr AFArray -> Ptr AFStorage -> AFArray -> IO AFErr 21 | foreign import ccall unsafe "af_sparse_get_values" 22 | af_sparse_get_values :: Ptr AFArray -> AFArray -> IO AFErr 23 | foreign import ccall unsafe "af_sparse_get_row_idx" 24 | af_sparse_get_row_idx :: Ptr AFArray -> AFArray -> IO AFErr 25 | foreign import ccall unsafe "af_sparse_get_col_idx" 26 | af_sparse_get_col_idx :: Ptr AFArray -> AFArray -> IO AFErr 27 | foreign import ccall unsafe "af_sparse_get_nnz" 28 | af_sparse_get_nnz :: Ptr DimT -> AFArray -> IO AFErr 29 | foreign import ccall unsafe "af_sparse_get_storage" 30 | af_sparse_get_storage :: Ptr AFStorage -> AFArray -> IO AFErr 31 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Statistics.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Statistics where 3 | 4 | import ArrayFire.Internal.Defines 5 | import Foreign.Ptr 6 | import Foreign.C.Types 7 | 8 | #include "af/statistics.h" 9 | foreign import ccall unsafe "af_mean" 10 | af_mean :: Ptr AFArray -> AFArray -> DimT -> IO AFErr 11 | foreign import ccall unsafe "af_mean_weighted" 12 | af_mean_weighted :: Ptr AFArray -> AFArray 
-> AFArray -> DimT -> IO AFErr 13 | foreign import ccall unsafe "af_var" 14 | af_var :: Ptr AFArray -> AFArray -> CBool -> DimT -> IO AFErr 15 | foreign import ccall unsafe "af_var_weighted" 16 | af_var_weighted :: Ptr AFArray -> AFArray -> AFArray -> DimT -> IO AFErr 17 | foreign import ccall unsafe "af_stdev" 18 | af_stdev :: Ptr AFArray -> AFArray -> DimT -> IO AFErr 19 | foreign import ccall unsafe "af_cov" 20 | af_cov :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr 21 | foreign import ccall unsafe "af_median" 22 | af_median :: Ptr AFArray -> AFArray -> DimT -> IO AFErr 23 | foreign import ccall unsafe "af_mean_all" 24 | af_mean_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr 25 | foreign import ccall unsafe "af_mean_all_weighted" 26 | af_mean_all_weighted :: Ptr Double -> Ptr Double -> AFArray -> AFArray -> IO AFErr 27 | foreign import ccall unsafe "af_var_all" 28 | af_var_all :: Ptr Double -> Ptr Double -> AFArray -> CBool -> IO AFErr 29 | foreign import ccall unsafe "af_var_all_weighted" 30 | af_var_all_weighted :: Ptr Double -> Ptr Double -> AFArray -> AFArray -> IO AFErr 31 | foreign import ccall unsafe "af_stdev_all" 32 | af_stdev_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr 33 | foreign import ccall unsafe "af_median_all" 34 | af_median_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr 35 | foreign import ccall unsafe "af_corrcoef" 36 | af_corrcoef :: Ptr Double -> Ptr Double -> AFArray -> AFArray -> IO AFErr 37 | foreign import ccall unsafe "af_topk" 38 | af_topk :: Ptr AFArray -> Ptr AFArray -> AFArray -> CInt -> CInt -> AFTopkFunction -> IO AFErr 39 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Util.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Util where 3 | 4 | import ArrayFire.Internal.Defines 5 | import Foreign.Ptr 6 | import Foreign.C.Types 7 | 8 | #include 
"af/util.h" 9 | foreign import ccall unsafe "af_print_array" 10 | af_print_array :: AFArray -> IO AFErr 11 | foreign import ccall unsafe "af_print_array_gen" 12 | af_print_array_gen :: Ptr CChar -> AFArray -> CInt -> IO AFErr 13 | foreign import ccall unsafe "af_save_array" 14 | af_save_array :: Ptr CInt -> Ptr CChar -> AFArray -> Ptr CChar -> CBool -> IO AFErr 15 | foreign import ccall unsafe "af_read_array_index" 16 | af_read_array_index :: Ptr AFArray -> Ptr CChar -> CUInt -> IO AFErr 17 | foreign import ccall unsafe "af_read_array_key" 18 | af_read_array_key :: Ptr AFArray -> Ptr CChar -> Ptr CChar -> IO AFErr 19 | foreign import ccall unsafe "af_read_array_key_check" 20 | af_read_array_key_check :: Ptr CInt -> Ptr CChar -> Ptr CChar -> IO AFErr 21 | foreign import ccall unsafe "af_array_to_string" 22 | af_array_to_string :: Ptr (Ptr CChar) -> Ptr CChar -> AFArray -> CInt -> CBool -> IO AFErr 23 | foreign import ccall unsafe "af_example_function" 24 | af_example_function :: Ptr AFArray -> AFArray -> AFSomeEnum -> IO AFErr 25 | foreign import ccall unsafe "af_get_version" 26 | af_get_version :: Ptr CInt -> Ptr CInt -> Ptr CInt -> IO AFErr 27 | foreign import ccall unsafe "af_get_revision" 28 | af_get_revision :: IO (Ptr CChar) 29 | foreign import ccall unsafe "af_get_size_of" 30 | af_get_size_of :: Ptr CSize -> AFDtype -> IO AFErr 31 | -------------------------------------------------------------------------------- /src/ArrayFire/Internal/Vision.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module ArrayFire.Internal.Vision where 3 | 4 | import ArrayFire.Internal.Defines 5 | import Foreign.Ptr 6 | import Foreign.C.Types 7 | 8 | #include "af/vision.h" 9 | foreign import ccall unsafe "af_fast" 10 | af_fast :: Ptr AFFeatures -> AFArray -> Float -> CUInt -> CBool -> Float -> CUInt -> IO AFErr 11 | foreign import ccall unsafe "af_harris" 12 | af_harris :: Ptr AFFeatures -> AFArray -> CUInt -> Float 
-> Float -> CUInt -> Float -> IO AFErr 13 | foreign import ccall unsafe "af_orb" 14 | af_orb :: Ptr AFFeatures -> Ptr AFArray -> AFArray -> Float -> CUInt -> Float -> CUInt -> CBool -> IO AFErr 15 | foreign import ccall unsafe "af_sift" 16 | af_sift :: Ptr AFFeatures -> Ptr AFArray -> AFArray -> CUInt -> Float -> Float -> Float -> CBool -> Float -> Float -> IO AFErr 17 | foreign import ccall unsafe "af_gloh" 18 | af_gloh :: Ptr AFFeatures -> Ptr AFArray -> AFArray -> CUInt -> Float -> Float -> Float -> CBool -> Float -> Float -> IO AFErr 19 | foreign import ccall unsafe "af_hamming_matcher" 20 | af_hamming_matcher :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> DimT -> CUInt -> IO AFErr 21 | foreign import ccall unsafe "af_nearest_neighbour" 22 | af_nearest_neighbour :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> DimT -> CUInt -> AFMatchType -> IO AFErr 23 | foreign import ccall unsafe "af_match_template" 24 | af_match_template :: Ptr AFArray -> AFArray -> AFArray -> AFMatchType -> IO AFErr 25 | foreign import ccall unsafe "af_susan" 26 | af_susan :: Ptr AFFeatures -> AFArray -> CUInt -> Float -> Float -> Float -> CUInt -> IO AFErr 27 | foreign import ccall unsafe "af_dog" 28 | af_dog :: Ptr AFArray -> AFArray -> CInt -> CInt -> IO AFErr 29 | foreign import ccall unsafe "af_homography" 30 | af_homography :: Ptr AFArray -> Ptr CInt -> AFArray -> AFArray -> AFArray -> AFArray -> AFHomographyType -> Float -> CUInt -> AFDtype -> IO AFErr 31 | -------------------------------------------------------------------------------- /src/ArrayFire/Orphans.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE FlexibleInstances #-} 3 | {-# LANGUAGE TypeApplications #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# OPTIONS_GHC -fno-warn-orphans #-} 6 | -------------------------------------------------------------------------------- 7 | -- | 8 | -- Module : ArrayFire.Orphans 9 | -- 
Copyright : David Johnson (c) 2019-2020 10 | -- License : BSD 3 11 | -- Maintainer : David Johnson 12 | -- Stability : Experimental 13 | -- Portability : GHC 14 | -- 15 | -------------------------------------------------------------------------------- 16 | module ArrayFire.Orphans where 17 | 18 | import Prelude 19 | 20 | import qualified ArrayFire.Arith as A 21 | import qualified ArrayFire.Array as A 22 | import qualified ArrayFire.Algorithm as A 23 | import qualified ArrayFire.Data as A 24 | import ArrayFire.Types 25 | import ArrayFire.Util 26 | 27 | instance (AFType a, Eq a) => Eq (Array a) where -- whole-array equality, reduced to a single Bool 28 | x == y = A.allTrueAll (A.eqBatched x y False) == (1.0,0.0) -- equal only when every element-wise comparison holds 29 | x /= y = not (x == y) -- negation of (==); the former allTrueAll-of-neqBatched test reported identical arrays as unequal 30 | 31 | instance (Num a, AFType a) => Num (Array a) where 32 | x + y = A.add x y 33 | x * y = A.mul x y 34 | abs = A.abs 35 | signum = A.sign 36 | negate arr = do -- computed as (zeros with the dimensions of arr) minus arr 37 | let (w,x,y,z) = A.getDims arr 38 | A.cast (A.constant @a [w,x,y,z] 0) `A.sub` arr 39 | x - y = A.sub x y 40 | fromInteger = A.scalar . fromIntegral 41 | 42 | instance Show (Array a) where 43 | show = arrayString -- rendering is delegated to ArrayFire.Util.arrayString 44 | 45 | instance forall a . (Fractional a, AFType a) => Fractional (Array a) where 46 | x / y = A.div x y 47 | fromRational n = A.scalar @a (fromRational n) 48 | 49 | instance forall a . 
(Ord a, AFType a, Fractional a) => Floating (Array a) where 50 | pi = A.scalar @a 3.141592653589793 -- full double-precision literal (was truncated to 3.14159); converted to the element type through its Fractional instance 51 | exp = A.exp @a 52 | log = A.log @a 53 | sqrt = A.sqrt @a 54 | (**) = A.pow @a 55 | sin = A.sin @a 56 | cos = A.cos @a 57 | tan = A.tan @a 58 | tanh = A.tanh @a 59 | asin = A.asin @a 60 | acos = A.acos @a 61 | atan = A.atan @a 62 | sinh = A.sinh @a 63 | cosh = A.cosh @a 64 | acosh = A.acosh @a 65 | atanh = A.atanh @a 66 | asinh = A.asinh @a 67 | -------------------------------------------------------------------------------- /src/ArrayFire/Types.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE UndecidableInstances #-} 2 | {-# LANGUAGE TypeOperators #-} 3 | {-# LANGUAGE ScopedTypeVariables #-} 4 | {-# LANGUAGE PolyKinds #-} 5 | {-# LANGUAGE DataKinds #-} 6 | {-# LANGUAGE AllowAmbiguousTypes #-} 7 | {-# LANGUAGE FlexibleInstances #-} 8 | {-# LANGUAGE TypeApplications #-} 9 | {-# LANGUAGE ViewPatterns #-} 10 | {-# LANGUAGE KindSignatures #-} 11 | {-# LANGUAGE RecordWildCards #-} 12 | {-# LANGUAGE GADTs #-} 13 | {-# LANGUAGE TypeFamilies #-} 14 | -------------------------------------------------------------------------------- 15 | -- | 16 | -- Module : ArrayFire.Types 17 | -- Copyright : David Johnson (c) 2019-2020 18 | -- License : BSD3 19 | -- Maintainer : David Johnson 20 | -- Stability : Experimental 21 | -- Portability : GHC 22 | -- 23 | -- Various Types related to the ArrayFire API 24 | -- 25 | -------------------------------------------------------------------------------- 26 | module ArrayFire.Types 27 | ( AFException (..) 28 | , AFExceptionType (..) 29 | , Array 30 | , Window 31 | , RandomEngine 32 | , Features 33 | , AFType (..) 34 | , TopK (..) 35 | , Backend (..) 36 | , MatchType (..) 37 | , BinaryOp (..) 38 | , MatProp (..) 39 | , HomographyType (..) 40 | , RandomEngineType (..) 41 | , Cell (..) 42 | , MarkerType (..) 43 | , InterpType (..) 44 | , Connectivity (..) 45 | , CSpace (..) 46 | , YccStd (..) 
47 | , MomentType (..) 48 | , CannyThreshold (..) 49 | , FluxFunction (..) 50 | , DiffusionEq (..) 51 | , IterativeDeconvAlgo (..) 52 | , InverseDeconvAlgo (..) 53 | , Seq (..) 54 | , Index (..) 55 | , NormType (..) 56 | , ConvMode (..) 57 | , ConvDomain (..) 58 | , BorderType (..) 59 | , Storage (..) 60 | , AFDType (..) 61 | , AFDtype (..) 62 | , ColorMap (..) 63 | ) where 64 | 65 | import ArrayFire.Exception 66 | import ArrayFire.Internal.Types 67 | import ArrayFire.Internal.Defines 68 | -------------------------------------------------------------------------------- /src/ArrayFire/Util.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ScopedTypeVariables #-} 2 | {-# LANGUAGE TypeApplications #-} 3 | {-# LANGUAGE ViewPatterns #-} 4 | -------------------------------------------------------------------------------- 5 | -- | 6 | -- Module : ArrayFire.Util 7 | -- Copyright : David Johnson (c) 2019-2020 8 | -- License : BSD 3 9 | -- Maintainer : David Johnson 10 | -- Stability : Experimental 11 | -- Portability : GHC 12 | -- 13 | -- Various utilities for working with the ArrayFire C library 14 | -- 15 | -- @ 16 | -- import qualified ArrayFire as A 17 | -- import Control.Monad 18 | -- 19 | -- main :: IO () 20 | -- main = do 21 | -- let arr = A.constant [1,1,1,1] 10 22 | -- idx <- A.saveArray "key" arr "file.array" False 23 | -- foundIndex <- A.readArrayKeyCheck "file.array" "key" 24 | -- when (idx == foundIndex) $ do 25 | -- array <- A.readArrayKey "file.array" "key" 26 | -- print array 27 | -- @ 28 | -- @ 29 | -- ArrayFire Array 30 | -- [ 1 1 1 1 ] 31 | -- 10 32 | -- @ 33 | -------------------------------------------------------------------------------- 34 | module ArrayFire.Util where 35 | 36 | import Control.Exception 37 | 38 | import Data.Proxy 39 | import Foreign.C.String 40 | import Foreign.ForeignPtr 41 | import Foreign.Marshal hiding (void) 42 | import Foreign.Storable 43 | import System.IO.Unsafe 44 | 45 | 
import ArrayFire.Internal.Types 46 | import ArrayFire.Internal.Util 47 | 48 | import ArrayFire.Exception 49 | import ArrayFire.FFI 50 | 51 | -- | Retrieve the ArrayFire library version as a triple of integers 52 | -- 53 | -- @ 54 | -- >>> 'print' '=<<' 'getVersion' 55 | -- @ 56 | -- @ 57 | -- (3,6,4) 58 | -- @ 59 | getVersion :: IO (Int,Int,Int) 60 | getVersion = 61 | alloca $ \x -> 62 | alloca $ \y -> 63 | alloca $ \z -> do 64 | throwAFError =<< af_get_version x y z 65 | (,,) <$> (fromIntegral <$> peek x) 66 | <*> (fromIntegral <$> peek y) 67 | <*> (fromIntegral <$> peek z) 68 | 69 | -- | Prints array to stdout 70 | -- 71 | -- @ 72 | -- >>> 'printArray' (constant \@'Double' [1] 1) 73 | -- @ 74 | -- @ 75 | -- ArrayFire Array 76 | -- [ 1 1 1 1 ] 77 | -- 1.0 78 | -- @ 79 | printArray 80 | :: Array a 81 | -- ^ Input 'Array' 82 | -> IO () 83 | printArray (Array fptr) = 84 | mask_ . withForeignPtr fptr $ \ptr -> 85 | throwAFError =<< af_print_array ptr 86 | 87 | -- | Gets git revision of ArrayFire 88 | -- 89 | -- @ 90 | -- >>> 'putStrLn' '=<<' 'getRevision' 91 | -- @ 92 | -- @ 93 | -- 1b8030c5 94 | -- @ 95 | getRevision :: IO String 96 | getRevision = peekCString =<< af_get_revision 97 | 98 | -- | Prints 'Array' labelled with a name, at the given display precision 99 | -- 100 | -- @ 101 | -- >>> printArrayGen "test" (constant \@'Double' [1] 1) 2 102 | -- @ 103 | -- @ 104 | -- ArrayFire Array 105 | -- [ 1 1 1 1 ] 106 | -- 1.00 107 | -- @ 108 | printArrayGen 109 | :: String 110 | -- ^ Expression or name of the 'Array' 111 | -> Array a 112 | -- ^ Input 'Array' 113 | -> Int 114 | -- ^ Precision for the display 115 | -> IO () 116 | printArrayGen s (Array fptr) (fromIntegral -> prec) = do 117 | mask_ . withForeignPtr fptr $ \ptr -> 118 | withCString s $ \cstr -> 119 | throwAFError =<< af_print_array_gen cstr ptr prec 120 | 121 | -- | Saves 'Array' to disk 122 | -- 123 | -- Save an array to a binary file. 
124 | -- The 'saveArray', 'readArrayIndex' and 'readArrayKey' functions are designed to provide store and read access to arrays using files written to disk. 125 | -- 126 | -- 127 | -- @ 128 | -- >>> saveArray "my array" (constant \@'Double' [1] 1) "array.file" 'True' 129 | -- @ 130 | -- @ 131 | -- 0 132 | -- @ 133 | saveArray 134 | :: String 135 | -- ^ An expression used as tag/key for the 'Array' when reading it back with 'readArrayKey' 136 | -> Array a 137 | -- ^ Input 'Array' 138 | -> FilePath 139 | -- ^ Path where the 'Array' will be saved 140 | -> Bool 141 | -- ^ Used to append to an existing file when 'True' and create or overwrite an existing file when 'False' 142 | -> IO Int 143 | -- ^ The index location of the 'Array' in the file 144 | saveArray key (Array fptr) filename (fromIntegral . fromEnum -> append) = do 145 | mask_ . withForeignPtr fptr $ \ptr -> 146 | alloca $ \ptrIdx -> do 147 | withCString key $ \keyCstr -> 148 | withCString filename $ \filenameCstr -> do 149 | throwAFError =<< 150 | af_save_array ptrIdx keyCstr 151 | ptr filenameCstr append 152 | fromIntegral <$> peek ptrIdx 153 | 154 | -- | Reads Array by index 155 | -- 156 | -- The 'saveArray', 'readArrayIndex' and 'readArrayKey' functions are designed to provide store and read access to arrays using files written to disk. 
157 | -- 158 | -- 159 | -- @ 160 | -- >>> readArrayIndex "array.file" 0 161 | -- @ 162 | -- @ 163 | -- ArrayFire Array 164 | -- [ 1 1 1 1 ] 165 | -- 10.0000 166 | -- @ 167 | readArrayIndex 168 | :: FilePath 169 | -- ^ Path to 'Array' location 170 | -> Int 171 | -- ^ Index of the 'Array' within the file 172 | -> IO (Array a) 173 | readArrayIndex str (fromIntegral -> idx') = 174 | withCString str $ \cstr -> 175 | createArray' (\p -> af_read_array_index p cstr idx') 176 | 177 | -- | Reads 'Array' by key 178 | -- 179 | -- @ 180 | -- >>> readArrayKey "array.file" "my array" 181 | -- @ 182 | -- @ 183 | -- ArrayFire Array 184 | -- [ 1 1 1 1 ] 185 | -- 10.0000 186 | -- @ 187 | readArrayKey 188 | :: FilePath 189 | -- ^ Path to 'Array' 190 | -> String 191 | -- ^ Key of 'Array' on disk 192 | -> IO (Array a) 193 | -- ^ Returned 'Array' 194 | readArrayKey fn key = 195 | withCString fn $ \fcstr -> 196 | withCString key $ \kcstr -> 197 | createArray' (\p -> af_read_array_key p fcstr kcstr) 198 | 199 | -- | Checks whether a key exists in the specified file 200 | -- 201 | -- When reading by key, it may be a good idea to run this function first to check for the key and then call 'readArrayIndex' using the returned index. 202 | -- 203 | -- 204 | -- @ 205 | -- >>> readArrayKeyCheck "array.file" "my array" 206 | -- @ 207 | -- @ 208 | -- 0 209 | -- @ 210 | readArrayKeyCheck 211 | :: FilePath 212 | -- ^ Path to file 213 | -> String 214 | -- ^ Tag/name of the 'Array' to look up. The key needs to have an exact match. 215 | -> IO Int 216 | -- ^ Index of the 'Array' in the file (the ArrayFire C API documents -1 when the key is not found; TODO confirm) 217 | readArrayKeyCheck a b = 218 | withCString a $ \acstr -> 219 | withCString b $ \bcstr -> 220 | fromIntegral <$> 221 | afCall1 (\p -> af_read_array_key_check p acstr bcstr) 222 | 223 | -- | Convert ArrayFire 'Array' to 'String', used for 'Show' instance. 
224 | -- 225 | -- @ 226 | -- >>> 'putStrLn' '$' 'arrayString' (constant \@'Double' [1,1,1,1] 10) 227 | -- @ 228 | -- @ 229 | -- ArrayFire Array 230 | -- [ 1 1 1 1 ] 231 | -- 10.0000 232 | -- @ 233 | arrayString 234 | :: Array a 235 | -- ^ Input 'Array' 236 | -> String 237 | -- ^ 'String' representation of 'Array' 238 | arrayString a = arrayToString "ArrayFire Array" a 4 True 239 | 240 | -- | Convert ArrayFire Array to String 241 | -- 242 | -- @ 243 | -- >>> 'putStrLn' '$' 'arrayToString' "ArrayFire Array" (constant \@'Double' [1,1,1,1] 10) 4 'False' 244 | -- @ 245 | -- @ 246 | -- ArrayFire Array 247 | -- [ 1 1 1 1 ] 248 | -- 10.0000 249 | -- @ 250 | arrayToString 251 | :: String 252 | -- ^ Name of 'Array' 253 | -> Array a 254 | -- ^ 'Array' input 255 | -> Int 256 | -- ^ Precision of 'Array' values. 257 | -> Bool 258 | -- ^ If 'True', takes the transpose before rendering to 'String' 259 | -> String 260 | -- ^ 'Array' rendered to 'String' 261 | arrayToString expr (Array fptr) (fromIntegral -> prec) (fromIntegral . fromEnum -> trans) = 262 | unsafePerformIO . mask_ . withForeignPtr fptr $ \aptr -> 263 | withCString expr $ \expCstr -> 264 | alloca $ \ocstr -> do 265 | throwAFError =<< af_array_to_string ocstr expCstr aptr prec trans 266 | peekCString =<< peek ocstr -- NOTE(review): the buffer behind ocstr is never released here; presumably ArrayFire allocates it and expects the caller to free it (af_free_host), TODO confirm 267 | 268 | -- | Retrieve size of ArrayFire data type 269 | -- 270 | -- @ 271 | -- >>> 'getSizeOf' ('Proxy' \@ 'Double') 272 | -- @ 273 | -- @ 274 | -- 8 275 | -- @ 276 | getSizeOf 277 | :: forall a . AFType a 278 | => Proxy a 279 | -- ^ Witness of Haskell type that mirrors ArrayFire type. 280 | -> Int 281 | -- ^ Size of ArrayFire type 282 | getSizeOf proxy = 283 | unsafePerformIO . mask_ . 
alloca $ \csize -> do 284 | throwAFError =<< af_get_size_of csize (afType proxy) 285 | fromIntegral <$> peek csize 286 | -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | resolver: 2 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/19/25.yaml 3 | 4 | packages: 5 | - . 6 | -------------------------------------------------------------------------------- /stack.yaml.lock: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by Stack. 2 | # You should not edit this file by hand. 3 | # For more information, please see the documentation at: 4 | # https://docs.haskellstack.org/en/stable/lock_files 5 | 6 | packages: [] 7 | snapshots: 8 | - completed: 9 | size: 619403 10 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/19/25.yaml 11 | sha256: 1ecad1f0bd2c27de88dbff6572446cfdf647c615d58a7e2e2085c6b7dfc04176 12 | original: 13 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/19/25.yaml 14 | -------------------------------------------------------------------------------- /test/ArrayFire/AlgorithmSpec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TypeApplications #-} 2 | module ArrayFire.AlgorithmSpec where 3 | 4 | import qualified ArrayFire as A 5 | 6 | import Test.Hspec 7 | 8 | spec :: Spec 9 | spec = 10 | describe "Algorithm tests" $ do 11 | it "Should sum a scalar" $ do 12 | A.sum (A.scalar @Int 10) 0 `shouldBe` 10 13 | A.sum (A.scalar @A.Int64 10) 0 `shouldBe` 10 14 | A.sum (A.scalar @A.Int32 10) 0 `shouldBe` 10 15 | A.sum (A.scalar @A.Int16 10) 0 `shouldBe` 10 16 | A.sum (A.scalar @Float 10) 0 `shouldBe` 10 17 | A.sum (A.scalar @A.Word32 10) 0 `shouldBe` 10 18 | A.sum (A.scalar @A.Word64 10) 0 `shouldBe` 10 19 | 
A.sum (A.scalar @Double 10) 0 `shouldBe` 10.0 20 | A.sum (A.scalar @(A.Complex Double) (1 A.:+ 1)) 0 `shouldBe` A.scalar (1 A.:+ 1) 21 | A.sum (A.scalar @(A.Complex Float) (1 A.:+ 1)) 0 `shouldBe` A.scalar (1 A.:+ 1) 22 | A.sum (A.scalar @A.CBool 1) 0 `shouldBe` 1 23 | A.sum (A.scalar @A.CBool 0) 0 `shouldBe` 0 24 | it "Should sum a vector" $ do 25 | A.sum (A.vector @Int 10 [1..]) 0 `shouldBe` 55 26 | A.sum (A.vector @A.Int64 10 [1..]) 0 `shouldBe` 55 27 | A.sum (A.vector @A.Int32 10 [1..]) 0 `shouldBe` 55 28 | A.sum (A.vector @A.Int16 10 [1..]) 0 `shouldBe` 55 29 | A.sum (A.vector @Float 10 [1..]) 0 `shouldBe` 55 30 | A.sum (A.vector @A.Word32 10 [1..]) 0 `shouldBe` 55 31 | A.sum (A.vector @A.Word64 10 [1..]) 0 `shouldBe` 55 32 | A.sum (A.vector @Double 10 [1..]) 0 `shouldBe` 55.0 33 | A.sum (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (10.0 A.:+ 10.0) 34 | A.sum (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (10.0 A.:+ 10.0) 35 | A.sum (A.vector @A.CBool 10 (repeat 1)) 0 `shouldBe` 10 36 | A.sum (A.vector @A.CBool 10 (repeat 0)) 0 `shouldBe` 0 37 | it "Should sum a default value to replace NaN" $ do 38 | A.sumNaN (A.vector @Float 10 [1..]) 0 1.0 `shouldBe` 55 39 | A.sumNaN (A.vector @Double 2 [acos 2, acos 2]) 0 50 `shouldBe` 100 40 | A.sumNaN (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` A.scalar (10.0 A.:+ 10.0) 41 | A.sumNaN (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` A.scalar (10.0 A.:+ 10.0) 42 | it "Should product a scalar" $ do 43 | A.product (A.scalar @Int 10) 0 `shouldBe` 10 44 | A.product (A.scalar @A.Int64 10) 0 `shouldBe` 10 45 | A.product (A.scalar @A.Int32 10) 0 `shouldBe` 10 46 | A.product (A.scalar @A.Int16 10) 0 `shouldBe` 10 47 | A.product (A.scalar @Float 10) 0 `shouldBe` 10 48 | A.product (A.scalar @A.Word32 10) 0 `shouldBe` 10 49 | A.product (A.scalar @A.Word64 10) 0 `shouldBe` 10 50 | A.product (A.scalar @Double 10) 0 `shouldBe` 10.0 
51 | A.product (A.scalar @(A.Complex Double) (1 A.:+ 1)) 0 `shouldBe` A.scalar (1 A.:+ 1) 52 | A.product (A.scalar @(A.Complex Float) (1 A.:+ 1)) 0 `shouldBe` A.scalar (1 A.:+ 1) 53 | A.product (A.scalar @A.CBool 1) 0 `shouldBe` 1 54 | A.product (A.scalar @A.CBool 0) 0 `shouldBe` 0 55 | it "Should product a vector" $ do 56 | A.product (A.vector @Int 10 [1..]) 0 `shouldBe` 3628800 57 | A.product (A.vector @A.Int64 10 [1..]) 0 `shouldBe` 3628800 58 | A.product (A.vector @A.Int32 10 [1..]) 0 `shouldBe` 3628800 59 | A.product (A.vector @A.Int16 5 [1..]) 0 `shouldBe` 120 60 | A.product (A.vector @Float 10 [1..]) 0 `shouldBe` 3628800 61 | A.product (A.vector @A.Word32 10 [1..]) 0 `shouldBe` 3628800 62 | A.product (A.vector @A.Word64 10 [1..]) 0 `shouldBe` 3628800 63 | A.product (A.vector @Double 10 [1..]) 0 `shouldBe` 3628800.0 64 | A.product (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (0.0 A.:+ 32.0) 65 | A.product (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (0.0 A.:+ 32.0) 66 | A.product (A.vector @A.CBool 10 (repeat 1)) 0 `shouldBe` 1 -- FIXED in 3.8.2, vector product along 0-axis is 1 for vector size 10 of all 1's. 
67 | A.product (A.vector @A.CBool 10 (repeat 0)) 0 `shouldBe` 0 68 | it "Should product a default value to replace NaN" $ do 69 | A.productNaN (A.vector @Float 10 [1..]) 0 1.0 `shouldBe` 3628800.0 70 | A.productNaN (A.vector @Double 2 [acos 2, acos 2]) 0 50 `shouldBe` 2500 71 | A.productNaN (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` A.scalar (0.0 A.:+ 32) 72 | A.productNaN (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` A.scalar (0 A.:+ 32) 73 | it "Should take the minimum element of a vector" $ do 74 | A.min (A.vector @Int 10 [1..]) 0 `shouldBe` 1 75 | A.min (A.vector @A.Int64 10 [1..]) 0 `shouldBe` 1 76 | A.min (A.vector @A.Int32 10 [1..]) 0 `shouldBe` 1 77 | A.min (A.vector @A.Int16 10 [1..]) 0 `shouldBe` 1 78 | A.min (A.vector @Float 10 [1..]) 0 `shouldBe` 1 79 | A.min (A.vector @A.Word32 10 [1..]) 0 `shouldBe` 1 80 | A.min (A.vector @A.Word64 10 [1..]) 0 `shouldBe` 1 81 | A.min (A.vector @Double 10 [1..]) 0 `shouldBe` 1 82 | A.min (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (1 A.:+ 1) 83 | A.min (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (1 A.:+ 1) 84 | A.min (A.vector @A.CBool 10 [1..]) 0 `shouldBe` 1 85 | A.min (A.vector @A.CBool 10 [1..]) 0 `shouldBe` 1 86 | it "Should find if all elements are true along dimension" $ do 87 | A.allTrue (A.vector @Double 5 (repeat 12.0)) 0 `shouldBe` 1 88 | A.allTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 1 89 | A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0 90 | A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0 91 | it "Should find if any elements are true along dimension" $ do 92 | A.anyTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 1 93 | A.anyTrue (A.vector @Int 5 (repeat 23)) 0 `shouldBe` 1 94 | A.anyTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0 95 | it "Should get count of all elements" $ do 96 | A.count (A.vector @Int 5 (repeat 1)) 0 `shouldBe` 5 97 | A.count (A.vector 
@A.CBool 5 (repeat 1)) 0 `shouldBe` 5 98 | A.count (A.vector @Double 5 (repeat 1)) 0 `shouldBe` 5 99 | A.count (A.vector @Float 5 (repeat 1)) 0 `shouldBe` 5 100 | it "Should get sum all elements" $ do 101 | A.sumAll (A.vector @Int 5 (repeat 2)) `shouldBe` (10,0) 102 | A.sumAll (A.vector @Double 5 (repeat 2)) `shouldBe` (10.0,0) 103 | A.sumAll (A.vector @A.CBool 3800 (repeat 1)) `shouldBe` (3800,0) 104 | A.sumAll (A.vector @(A.Complex Double) 5 (repeat (2 A.:+ 0))) `shouldBe` (10.0,0) 105 | it "Should get sum all elements" $ do 106 | A.sumNaNAll (A.vector @Double 2 [10, acos 2]) 1 `shouldBe` (11.0,0) 107 | it "Should product all elements in an Array" $ do 108 | A.productAll (A.vector @Int 5 (repeat 2)) `shouldBe` (32,0) 109 | it "Should product all elements in an Array" $ do 110 | A.productNaNAll (A.vector @Double 2 [10,acos 2]) 10 `shouldBe` (100,0) 111 | it "Should find minimum value of an Array" $ do 112 | A.minAll (A.vector @Int 5 [0..]) `shouldBe` (0,0) 113 | it "Should find maximum value of an Array" $ do 114 | A.maxAll (A.vector @Int 5 [0..]) `shouldBe` (4,0) 115 | -- it "Should find if all elements are true" $ do 116 | -- A.allTrue (A.vector @A.CBool 5 (repeat 0)) `shouldBe` False 117 | 118 | -------------------------------------------------------------------------------- /test/ArrayFire/ArithSpec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RankNTypes #-} 2 | {-# LANGUAGE ScopedTypeVariables #-} 3 | {-# LANGUAGE TypeApplications #-} 4 | 5 | module ArrayFire.ArithSpec where 6 | 7 | import ArrayFire (AFType, Array, cast, clamp, getType, isInf, isZero, matrix, maxOf, minOf, mkArray, scalar, vector) 8 | import qualified ArrayFire 9 | import Control.Exception (throwIO) 10 | import Control.Monad (unless, when) 11 | import Foreign.C 12 | import GHC.Exts (IsList (..)) 13 | import GHC.Stack 14 | import Test.HUnit.Lang (FailureReason (..), HUnitFailure (..)) 15 | import Test.Hspec 16 | import 
Test.Hspec.QuickCheck
import Prelude hiding (div)

-- | Build an 'Expectation' from an arbitrary binary predicate.
--
-- Succeeds when @comparator result expected@ holds; otherwise throws an
-- 'HUnitFailure' carrying the 'show'n expected and actual values so hspec
-- reports them like a regular @shouldBe@ mismatch.
compareWith :: (HasCallStack, Show a) => (a -> a -> Bool) -> a -> a -> Expectation
compareWith comparator result expected =
  unless (comparator result expected) $ do
    throwIO (HUnitFailure location $ ExpectedButGot Nothing expectedMsg actualMsg)
  where
    expectedMsg = show expected
    actualMsg = show result
    -- The call stack is reversed so the last entry (the outermost recorded
    -- call site) is the one attached to the failure.
    location = case reverse (toList callStack) of
      (_, loc) : _ -> Just loc
      [] -> Nothing

-- | Numeric types that expose a per-type tolerance unit 'eps'.
class (Num a) => HasEpsilon a where
  eps :: a

-- | 1.1920929e-7 (2^-23).
instance HasEpsilon Float where
  eps = 1.1920929e-7

-- | 2.220446049250313e-16 (2^-52).
instance HasEpsilon Double where
  eps = 2.220446049250313e-16

-- | @approxWith rtol atol a b@ holds when @|a - b|@ does not exceed the larger
-- of the absolute tolerance @atol@ and @rtol * max |a| |b|@.
approxWith :: (Ord a, Num a) => a -> a -> a -> a -> Bool
approxWith rtol atol a b = abs (a - b) <= Prelude.max atol (rtol * Prelude.max (abs a) (abs b))

-- | Approximate equality using 'eps' for the element type, with an absolute
-- tolerance of @4 * eps@.
--
-- NOTE(review): the relative tolerance passed here is already multiplied by
-- @max |a| |b|@, and 'approxWith' multiplies by that magnitude again, so the
-- relative bound grows with the square of the magnitude. Confirm whether a
-- plain @2 * eps@ was intended before tightening; the property tests below
-- currently rely on this behaviour.
approx :: (Ord a, HasEpsilon a) => a -> a -> Bool
approx a b = approxWith (2 * eps * Prelude.max (abs a) (abs b)) (4 * eps) a b

-- | Expectation form of 'approx'.
shouldBeApprox :: (Ord a, HasEpsilon a, Show a) => a -> a -> Expectation
shouldBeApprox = compareWith approx

-- | Extract the scalar held by an array via 'ArrayFire.getScalar'.
evalf :: (AFType a) => Array a -> a
evalf = ArrayFire.getScalar

-- | Check that an array function @f@ applied to @scalar x@ agrees with the
-- plain Haskell function @f'@ applied to @x@.
--
-- Two infinite results, or two NaN results, are accepted as matching;
-- everything else is compared with 'shouldBeApprox'.
shouldMatchBuiltin ::
  (AFType a, Ord a, RealFloat a, HasEpsilon a, Show a) =>
  (Array a -> Array a) ->
  (a -> a) ->
  a ->
  Expectation
shouldMatchBuiltin f f' x
  | isInfinite y && isInfinite y' = pure ()
  | Prelude.isNaN y && Prelude.isNaN y' = pure ()
  | otherwise = y `shouldBeApprox` y'
  where
    y = evalf (f (scalar x))
    y' = f' x

-- | Two-argument variant of 'shouldMatchBuiltin': the first argument is fixed
-- on both sides and the second is checked through 'shouldMatchBuiltin'.
shouldMatchBuiltin2 ::
  (AFType a, Ord a, RealFloat a, HasEpsilon a, Show a) =>
  (Array a -> Array a -> Array a) ->
  (a -> a -> a) ->
  a ->
  a ->
  Expectation
shouldMatchBuiltin2 f f' a = shouldMatchBuiltin (f (scalar a)) (f' a)

-- | Arithmetic, comparison, logical and bitwise operations on arrays, plus
-- property tests comparing the 'Floating' instance against Prelude functions.
spec :: Spec
spec =
  describe "Arith tests" $ do
    it "Should negate scalar value" $ do
      negate (scalar @Int 1) `shouldBe` (-1)
    it "Should negate a vector" $ do
      negate (vector @Int 3 [2, 2, 2]) `shouldBe` vector @Int 3 [-2, -2, -2]
    it "Should add two scalar arrays" $ do
      scalar @Int 1 + 2 `shouldBe` 3
    it "Should add two scalar bool arrays" $ do
      scalar @CBool 1 + 0 `shouldBe` 1
    it "Should subtract two scalar arrays" $ do
      scalar @Int 4 - 2 `shouldBe` 2
    it "Should multiply two scalar arrays" $ do
      scalar @Double 4 `ArrayFire.mul` 2 `shouldBe` 8
    it "Should divide two scalar arrays" $ do
      ArrayFire.div @Double 8 2 `shouldBe` 4
    it "Should add two matrices" $ do
      matrix @Int (2, 2) [[1, 1], [1, 1]] + matrix @Int (2, 2) [[1, 1], [1, 1]]
        `shouldBe` matrix @Int (2, 2) [[2, 2], [2, 2]]
    prop "Should take cubed root" $ \(x :: Double) ->
      evalf (ArrayFire.cbrt (scalar (x * x * x))) `shouldBeApprox` x

    it "Should lte Array" $ do
      2 `ArrayFire.le` (3 :: Array Double) `shouldBe` 1
    it "Should gte Array" $ do
      2 `ArrayFire.ge` (3 :: Array Double) `shouldBe` 0
    it "Should gt Array" $ do
      2 `ArrayFire.gt` (3 :: Array Double) `shouldBe` 0
    it "Should lt Array" $ do
      -- Exercise the strict comparison; this case previously repeated `le`.
      2 `ArrayFire.lt` (3 :: Array Double) `shouldBe` 1
    it "Should eq Array" $ do
      3 == (3 :: Array Double) `shouldBe` True
    it "Should and Array" $ do
      (mkArray @CBool [1] [0] `ArrayFire.and` mkArray [1] [1])
        `shouldBe` mkArray [1] [0]
    it "Should and Array" $ do
      (mkArray @CBool [2] [0, 0] `ArrayFire.and` mkArray [2] [1, 0])
        `shouldBe` mkArray [2] [0, 0]
    it "Should or Array" $ do
      (mkArray @CBool [2] [0, 0] `ArrayFire.or` mkArray [2] [1, 0])
        `shouldBe` mkArray [2] [1, 0]
    it "Should not Array" $ do
      ArrayFire.not (mkArray @CBool [2] [1, 0]) `shouldBe` mkArray [2] [0, 1]
    it "Should bitwise and array" $ do
      ArrayFire.bitAnd (scalar @Int 1) (scalar @Int 0)
        `shouldBe` 0
    it "Should bitwise or array" $ do
      ArrayFire.bitOr (scalar @Int 1) (scalar @Int 0)
        `shouldBe` 1
    it "Should bitwise xor array" $ do
      ArrayFire.bitXor (scalar @Int 1) (scalar @Int 1)
        `shouldBe` 0
    it "Should bitwise shift left an array" $ do
      ArrayFire.bitShiftL (scalar @Int 1) (scalar @Int 3)
        `shouldBe` 8
    it "Should cast an array" $ do
      getType (cast (scalar @Int 1) :: Array Double)
        `shouldBe` ArrayFire.F64
    it "Should find the minimum of two arrays" $ do
      minOf (scalar @Int 1) (scalar @Int 0)
        `shouldBe` 0
    it "Should find the max of two arrays" $ do
      maxOf (scalar @Int 1) (scalar @Int 0)
        `shouldBe` 1
    it "Should take the clamp of 3 arrays" $ do
      clamp (scalar @Int 2) (scalar @Int 1) (scalar @Int 3)
        `shouldBe` 2
    it "Should check if an array has positive or negative infinities" $ do
      isInf (scalar @Double (1 / 0)) `shouldBe` scalar @Double 1
      isInf (scalar @Double 10) `shouldBe` scalar @Double 0
    it "Should check if an array has any NaN values" $ do
      -- `acos 2` is used throughout as a NaN producer.
      ArrayFire.isNaN (scalar @Double (acos 2)) `shouldBe` scalar @Double 1
      ArrayFire.isNaN (scalar @Double 10) `shouldBe` scalar @Double 0
    it "Should check if an array has any Zero values" $ do
      isZero (scalar @Double (acos 2)) `shouldBe` scalar @Double 0
      isZero (scalar @Double 0) `shouldBe` scalar @Double 1
      isZero (scalar @Double 1) `shouldBe` scalar @Double 0

    prop "Floating @Float (exp)" $ \(x :: Float) -> exp `shouldMatchBuiltin` exp $ x
    prop "Floating @Float (log)" $ \(x :: Float) -> log `shouldMatchBuiltin` log $ x
    prop "Floating @Float (sqrt)" $ \(x :: Float) -> sqrt `shouldMatchBuiltin` sqrt $ x
    prop "Floating @Float (**)" $ \(x :: Float) (y :: Float) -> ((**) `shouldMatchBuiltin2` (**)) x y
    prop "Floating @Float (sin)" $ \(x :: Float) -> sin `shouldMatchBuiltin` sin $ x
    prop "Floating @Float (cos)" $ \(x :: Float) -> cos `shouldMatchBuiltin` cos $ x
    prop "Floating @Float (tan)" $ \(x :: Float) -> tan `shouldMatchBuiltin` tan $ x
    prop "Floating @Float (asin)" $ \(x :: Float) -> asin `shouldMatchBuiltin` asin $ x
    prop "Floating @Float (acos)" $ \(x :: Float) -> acos `shouldMatchBuiltin` acos $ x
    prop "Floating @Float (atan)" $ \(x :: Float) -> atan `shouldMatchBuiltin` atan $ x
    prop "Floating @Float (sinh)" $ \(x :: Float) -> sinh `shouldMatchBuiltin` sinh $ x
    prop "Floating @Float (cosh)" $ \(x :: Float) -> cosh `shouldMatchBuiltin` cosh $ x
    prop "Floating @Float (tanh)" $ \(x :: Float) -> tanh `shouldMatchBuiltin` tanh $ x
    prop "Floating @Float (asinh)" $ \(x :: Float) -> asinh `shouldMatchBuiltin` asinh $ x
    prop "Floating @Float (acosh)" $ \(x :: Float) -> acosh `shouldMatchBuiltin` acosh $ x
    prop "Floating @Float (atanh)" $ \(x :: Float) -> atanh `shouldMatchBuiltin` atanh $ x

--------------------------------------------------------------------------------
/test/ArrayFire/ArraySpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module ArrayFire.ArraySpec where

import Control.Exception
import Data.Complex
import Data.Word
import Foreign.C.Types
import GHC.Int
import Test.Hspec

import ArrayFire

-- | Array construction, shape/type predicates, reference counting and
-- conversion back to lists.
spec :: Spec
spec =
  describe "Array tests" $ do
    it "Should perform Array tests" $ do
      (1 + 1) `shouldBe` 2
    it "Should fail to create 0 dimension arrays" $ do
      let arr = mkArray @Int [0,0,0,0] [1..]
      evaluate arr `shouldThrow` anyException
    it "Should fail to create 0 length arrays" $ do
      let arr = mkArray @Int [0,0,0,1] []
      evaluate arr `shouldThrow` anyException
    it "Should fail to create 0 length arrays w/ 0 dimensions" $ do
      let arr = mkArray @Int [0,0,0,0] []
      evaluate arr `shouldThrow` anyException
    it "Should create a column vector" $ do
      let arr = mkArray @Int [9,1,1,1] (repeat 9)
      isColumn arr `shouldBe` True
    it "Should create a row vector" $ do
      let arr = mkArray @Int [1,9,1,1] (repeat 9)
      isRow arr `shouldBe` True
    it "Should create a vector" $ do
      let arr = mkArray @Int [9,1,1,1] (repeat 9)
      isVector arr `shouldBe` True
    it "Should create a vector" $ do
      let arr = mkArray @Int [1,9,1,1] (repeat 9)
      isVector arr `shouldBe` True
    it "Should copy an array" $ do
      let arr = mkArray @Int [9,9,1,1] (repeat 9)
      let newArray = copyArray arr
      newArray `shouldBe` arr
    it "Should modify manual eval flag" $ do
      setManualEvalFlag False
      (`shouldBe` False) =<< getManualEvalFlag
    it "Should return the number of elements" $ do
      let arr = mkArray @Int [9,9,1,1] [1..]
      getElements arr `shouldBe` 81
    -- it "Should give an empty array" $ do
    --   let arr = mkArray @Int [-1,1,1,1] []
    --   getElements arr `shouldBe` 0
    --   isEmpty arr `shouldBe` True
    it "Should create a scalar array" $ do
      let arr = mkArray @Int [1] [1]
      isScalar arr `shouldBe` True
    it "Should get number of dims specified" $ do
      let arr = mkArray @Int [1,1,1,1] [1]
      getNumDims arr `shouldBe` 1
      let arr = mkArray @Int [2,3,4,5] [1..]
      getNumDims arr `shouldBe` 4
      let arr = mkArray @Int [2,3,4] [1..]
      getNumDims arr `shouldBe` 3
    it "Should get value of dims specified" $ do
      let arr = mkArray @Int [2,3,4,5] (repeat 1)
      getDims arr `shouldBe` (2,3,4,5)

    it "Should test Sparsity" $ do
      let arr = mkArray @Double [2,2,1,1] (repeat 1)
      isSparse arr `shouldBe` False

    it "Should make a Bit array" $ do
      let arr = mkArray @CBool [2,2] [1,1,1,1]
      isBool arr `shouldBe` True

    it "Should make an integer array" $ do
      let arr = mkArray @Int [2,2] (repeat 1)
      isInteger arr `shouldBe` True

    it "Should make a Floating array" $ do
      let arr = mkArray @Double [2,2] (repeat 1)
      isFloating arr `shouldBe` True
      let arr = mkArray @CBool [2,2] (repeat 1)
      isFloating arr `shouldBe` False

    it "Should make a Complex array" $ do
      let arr = mkArray @(Complex Double) [2,2] (repeat 1)
      isComplex arr `shouldBe` True
      isReal arr `shouldBe` False

    it "Should make a Real array" $ do
      let arr = mkArray @Double [2,2] (repeat 1)
      isReal arr `shouldBe` True
      isComplex arr `shouldBe` False

    it "Should make a Double precision array" $ do
      let arr = mkArray @Double [2,2] (repeat 1)
      isDouble arr `shouldBe` True
      isSingle arr `shouldBe` False

    it "Should make a Single precision array" $ do
      let arr = mkArray @Float [2,2] (repeat 1)
      isDouble arr `shouldBe` False
      isSingle arr `shouldBe` True

    it "Should make a Real floating array" $ do
      let arr = mkArray @Float [2,2] (repeat 1)
      isRealFloating arr `shouldBe` True
      let arr = mkArray @Double [2,2] (repeat 1)
      isRealFloating arr `shouldBe` True

    it "Should get reference count" $ do
      let arr1 = mkArray @Float [2,2] (repeat 1)
          arr2 = retainArray arr1
          arr3 = retainArray arr2
      getDataRefCount arr3 `shouldBe` 3

    it "Should convert an array to a list" $ do
      let arr = mkArray @Double [30,30] (repeat 1)
      toList arr `shouldBe` Prelude.replicate (30 * 30) 1

      let arr = mkArray @Float [10,10] (repeat (5.5))
      toList arr `shouldBe` Prelude.replicate 100 5.5

      let arr = mkArray @CBool [4] [1,1,0,1]
      toList arr `shouldBe` [1,1,0,1]

      let arr = mkArray @Int16 [10] [1..]
      toList arr `shouldBe` [1..10]

      let arr = mkArray @Int32 [100] [1..100]
      toList arr `shouldBe` [1..100]

      let arr = mkArray @Int64 [100] [1..100]
      toList arr `shouldBe` [1..100]

      let arr = mkArray @Int [100] [1..100]
      toList arr `shouldBe` [1..100]

      let arr = mkArray @(Complex Float) [1] [1 :+ 1]
      toList arr `shouldBe` [1 :+ 1]

      let arr = mkArray @(Complex Double) [1] [1 :+ 1]
      toList arr `shouldBe` [1 :+ 1]

      let arr = mkArray @Word16 [10] [1..10]
      toList arr `shouldBe` [1..10]

      let arr = mkArray @Word32 [10] [1..10]
      toList arr `shouldBe` [1..10]

      let arr = mkArray @Word64 [10] [1..10]
      toList arr `shouldBe` [1..10]

      let arr = mkArray @Word [10] [1..10]
      toList arr `shouldBe` [1..10]

--------------------------------------------------------------------------------
/test/ArrayFire/BLASSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE TypeApplications #-}
module ArrayFire.BLASSpec where

import ArrayFire hiding (not)

import Data.Complex
import Test.Hspec

-- | Matrix multiplication, dot products and transposition.
spec :: Spec
spec =
  describe "BLAS spec" $ do
    it "Should matmul two matrices" $ do
      (matrix @Double (2,2) [[2,2],[2,2]] `matmul` matrix @Double (2,2) [[2,2],[2,2]]) None None
        `shouldBe` matrix @Double (2,2) [[8,8],[8,8]]
    it "Should dot product two vectors" $ do
      dot (vector @Double 2 (repeat 2)) (vector @Double 2 (repeat 2)) None None
        `shouldBe`
          scalar @Double 8
    it "Should produce scalar dot product between two vectors as a Complex number" $ do
      dotAll (vector @Double 2 (repeat 2)) (vector @Double 2 (repeat 2)) None None
        `shouldBe`
          8.0 :+ 0.0
    it "Should take the transpose of a matrix" $ do
      transpose (matrix @Double (2,2) [[1,1],[2,2]]) False
        `shouldBe`
          matrix @Double (2,2) [[1,2],[1,2]]
    it "Should take the transpose of a matrix in place" $ do
      let m = matrix @Double (2,2) [[1,1],[2,2]]
      transposeInPlace m False
      m `shouldBe` matrix @Double (2,2) [[1,2],[1,2]]

--------------------------------------------------------------------------------
/test/ArrayFire/BackendSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE TypeApplications #-}
module ArrayFire.BackendSpec where

import ArrayFire hiding (not)

import Test.Hspec

-- | Backend discovery and selection.
spec :: Spec
spec =
  describe "Backend spec" $ do
    it "Should get backend count" $ do
      (`shouldSatisfy` (>0)) =<< getBackendCount
    it "Should get available backends" $ do
      backends <- getAvailableBackends
      backends `shouldSatisfy` (CPU `elem`)
    it "Should set backend to CPU" $ do
      -- Re-selects whichever backend is currently active and checks that new
      -- arrays report that same backend.
      backend <- getActiveBackend
      setBackend backend
      (`shouldBe` backend) =<< getActiveBackend
      let arr = matrix @Int (2,2) [[1,1],[1,1]]
      getBackend arr `shouldBe` backend

--------------------------------------------------------------------------------
/test/ArrayFire/DataSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module ArrayFire.DataSpec where

import Control.Exception
import Data.Complex
import Data.Word
import Foreign.C.Types
import GHC.Int
import Test.Hspec

import ArrayFire

-- | Constant array creation for every element type, and joining arrays along
-- a dimension.
spec :: Spec
spec =
  describe "Data tests" $ do
    it "Should create constant Array" $ do
      constant @Float [1] 1 `shouldBe` 1
      constant @Double [1] 1 `shouldBe` 1
      constant @Int16 [1] 1 `shouldBe` 1
      constant @Int32 [1] 1 `shouldBe` 1
      constant @Int64 [1] 1 `shouldBe` 1
      constant @Int [1] 1 `shouldBe` 1
      constant @Word16 [1] 1 `shouldBe` 1
      constant @Word32 [1] 1 `shouldBe` 1
      constant @Word64 [1] 1 `shouldBe` 1
      constant @Word [1] 1 `shouldBe` 1
      constant @CBool [1] 1 `shouldBe` 1
      constant @(Complex Double) [1] (1.0 :+ 1.0)
        `shouldBe`
          constant @(Complex Double) [1] (1.0 :+ 1.0)
      constant @(Complex Float) [1] (1.0 :+ 1.0)
        `shouldBe`
          constant @(Complex Float) [1] (1.0 :+ 1.0)
    it "Should join Arrays along the specified dimension" $ do
      join 0 (constant @Int [1, 3] 1) (constant @Int [1, 3] 2) `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2]
      join 1 (constant @Int [1, 2] 1) (constant @Int [1, 2] 2) `shouldBe` mkArray @Int [1, 4] [1, 1, 2, 2]
      joinMany 0 [constant @Int [1, 3] 1, constant @Int [1, 3] 2] `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2]
      joinMany 1 [constant @Int [1, 2] 1, constant @Int [1, 1] 2, constant @Int [1, 3] 3] `shouldBe` mkArray @Int [1, 6] [1, 1, 2, 3, 3, 3]

--------------------------------------------------------------------------------
/test/ArrayFire/DeviceSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE TypeApplications #-}
module ArrayFire.DeviceSpec where

import qualified ArrayFire as A
import Foreign.C.Types
import Test.Hspec

-- | Device information and selection.
spec :: Spec
spec =
  -- Label corrected from the copy-pasted "Algorithm tests" so the hspec
  -- report attributes these cases to the device module.
  describe "Device tests" $ do
    it "Should show device info" $ do
      A.info `shouldReturn` ()
    it "Should show device init" $ do
      A.afInit `shouldReturn` ()
    it "Should get info string" $ do
      A.getInfoString >>= (`shouldSatisfy` (not . null))
    it "Should get device" $ do
      A.getDevice >>= (`shouldSatisfy` (>= 0))
    it "Should get and set device" $ do
      (A.getDevice >>= A.setDevice) `shouldReturn` ()

--------------------------------------------------------------------------------
/test/ArrayFire/FeaturesSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE TypeApplications #-}
module ArrayFire.FeaturesSpec where

import ArrayFire hiding (acos)
import Prelude
import Test.Hspec

-- | Feature container creation.
spec :: Spec
spec =
  describe "Features tests" $ do
    it "Should get features number an array" $ do
      let feats = createFeatures 10
      getFeaturesNum feats `shouldBe` 10

--------------------------------------------------------------------------------
/test/ArrayFire/GraphicsSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module ArrayFire.GraphicsSpec where

import Control.Exception
import Data.Complex
import Data.Word
import Foreign.C.Types
import GHC.Int
import Test.Hspec

import ArrayFire

-- | Placeholder: no graphics call is made, only a trivial arithmetic check.
spec :: Spec
spec =
  describe "Graphics tests" $ do
    it "Should create window" $ do
      (1 + 1) `shouldBe` 2

--------------------------------------------------------------------------------
/test/ArrayFire/ImageSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module ArrayFire.ImageSpec where

import Control.Exception
import Data.Complex
import Data.Word
import Foreign.C.Types
import GHC.Int
import Test.Hspec

import ArrayFire

-- | Image I/O availability.
spec :: Spec
spec =
  describe "Image tests" $ do
    it "Should test if Image I/O is available" $ do
      isImageIOAvailable `shouldReturn` True

--------------------------------------------------------------------------------
/test/ArrayFire/IndexSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeApplications #-}
module ArrayFire.IndexSpec where

import qualified ArrayFire as A
import Control.Exception
import Data.Complex
import Data.Int
import Data.Proxy
import Data.Word
import Foreign.C.Types
import Test.Hspec

-- | Sequence-based indexing.
spec :: Spec
spec =
  describe "Index spec" $ do
    it "Should index into an array" $ do
      let arr = A.vector @Int 10 [1..]
      A.index arr [A.Seq 0 4 1]
        `shouldBe`
          A.vector @Int 5 [1..]

--------------------------------------------------------------------------------
/test/ArrayFire/LAPACKSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE TypeApplications #-}
module ArrayFire.LAPACKSpec where

import qualified ArrayFire as A
import Prelude
import Test.Hspec
import Test.Hspec.ApproxExpect

-- | Shape checks for the matrix decompositions and value checks for the
-- determinant.
spec :: Spec
spec =
  describe "LAPACK spec" $ do
    it "Should have LAPACK available" $ do
      A.isLAPACKAvailable `shouldBe` True
    it "Should perform svd" $ do
      let (s,v,d) = A.svd $ A.matrix @Double (4,2) [ [1,2,3,4], [5,6,7,8] ]
      A.getDims s `shouldBe` (4,4,1,1)
      A.getDims v `shouldBe` (2,1,1,1)
      A.getDims d `shouldBe` (2,2,1,1)
    it "Should perform svd in place" $ do
      let (s,v,d) = A.svdInPlace $ A.matrix @Double (4,2) [ [1,2,3,4], [5,6,7,8] ]
      A.getDims s `shouldBe` (4,4,1,1)
      A.getDims v `shouldBe` (2,1,1,1)
      A.getDims d `shouldBe` (2,2,1,1)
    it "Should perform lu" $ do
      let (s,v,d) = A.lu $ A.matrix @Double (2,2) [[3,1],[4,2]]
      A.getDims s `shouldBe` (2,2,1,1)
      A.getDims v `shouldBe` (2,2,1,1)
      A.getDims d `shouldBe` (2,1,1,1)
    it "Should perform qr" $ do
      -- Call the QR decomposition itself; this case previously invoked `lu`
      -- again, leaving `qr` untested.
      let (s,v,d) = A.qr $ A.matrix @Double (3,3) [[12,6,4],[-51,167,24],[4,-68,-41]]
      A.getDims s `shouldBe` (3,3,1,1)
      A.getDims v `shouldBe` (3,3,1,1)
      A.getDims d `shouldBe` (3,1,1,1)
    it "Should get determinant of Double" $ do
      -- Only the first component of the result pair is checked in both cases.
      let eles = [[3 A.:+ 1, 8 A.:+ 1], [4 A.:+ 1, 6 A.:+ 1]]
          (x,y) = A.det (A.matrix @(A.Complex Double) (2,2) eles)
      x `shouldBeApprox` (-14)
      let (x,y) = A.det $ A.matrix @Double (2,2) [[3,8],[4,6]]
      x `shouldBeApprox` (-14)
    -- it "Should calculate inverse" $ do
    --   let x = flip A.inverse A.None $ A.matrix @Double (2,2) [[4.0,7.0],[2.0,6.0]]
    --   x `shouldBe` A.matrix (2,2) [[0.6,-0.7],[-0.2,0.4]]
    -- it "Should calculate psuedo inverse" $ do
    --   let x = A.pinverse (A.matrix @Double (2,2) [[4,7],[2,6]]) 1.0 A.None
    --   x `shouldBe` A.matrix @Double (2,2) [[0.6,-0.2],[-0.7,0.4]]

--------------------------------------------------------------------------------
/test/ArrayFire/RandomSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE TypeApplications #-}
module ArrayFire.RandomSpec where

import ArrayFire
import Control.Monad

import Test.Hspec

-- | Random engine creation, engine type switching and seed round-trip.
spec :: Spec
spec =
  describe "Random engine spec" $ do
    it "Should create random engine" $ do
      (`shouldBe` Philox)
        =<< getRandomEngineType
        =<< createRandomEngine 5000 Philox
      (`shouldBe` Mersenne)
        =<< getRandomEngineType
        =<< createRandomEngine 5000 Mersenne
      (`shouldBe` ThreeFry)
        =<< getRandomEngineType
        =<< createRandomEngine 5000 ThreeFry
    it "Should set random engine" $ do
      r <- createRandomEngine 5000 ThreeFry
      setRandomEngine r Philox
      (`shouldBe` Philox) =<< getRandomEngineType r
    it "Should set and get seed" $ do
      setSeed 100
      (`shouldBe` 100) =<< getSeed

--------------------------------------------------------------------------------
/test/ArrayFire/SignalSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE TypeApplications #-}
module ArrayFire.SignalSpec where

import qualified ArrayFire as A
import Data.Int
import Data.Word
import Data.Complex
import Data.Proxy
import Foreign.C.Types
import Test.Hspec

-- | FFT smoke tests on 1x1 complex matrices holding the single value @1 :+ 1@.
spec :: Spec
spec = describe "Signal spec" $ do
  it "Should do FFT in place" $
    -- The in-place transform is an IO action; only its completion is checked.
    A.fftInPlace inPlaceInput 10.2 `shouldReturn` ()
  it "Should do FFT" $
    A.fft fftInput 1 1 `shouldBe` fftExpected
  where
    inPlaceInput = A.matrix @(Complex Double) (1,1) [[1 :+ 1]]
    fftInput     = A.matrix @(Complex Float) (1,1) [[1 :+ 1]]
    fftExpected  = A.matrix @(Complex Float) (1,1) [[1 :+ 1]]

--------------------------------------------------------------------------------
/test/ArrayFire/SparseSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE TypeApplications #-}
module ArrayFire.SparseSpec where

import qualified ArrayFire as A
import Data.Int
import Data.Word
import Data.Complex
import Data.Proxy
import Foreign.C.Types
import Test.Hspec

-- | Placeholder: the sparse-array assertion is commented out, so only a
-- trivial arithmetic check runs.
spec :: Spec
spec = describe "Sparse spec" $
  it "Should create a sparse array" $ do
    let total = 1 + 1
    total `shouldBe` 2
    -- A.createSparseArrayFromDense (A.matrix @Double (10,10) [1..]) A.CSR
    --   `shouldBe`
    --      A.vector @Double 10 [0..]
--------------------------------------------------------------------------------
/test/ArrayFire/StatisticsSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE TypeApplications #-}
module ArrayFire.StatisticsSpec where

import ArrayFire hiding (not)

import Data.Complex
import Test.Hspec
import Test.Hspec.ApproxExpect

-- | Mean, variance, standard deviation, covariance, median, correlation and
-- top-k, both along a dimension and over all elements.
spec :: Spec
spec =
  describe "Statistics spec" $ do
    it "Should find the mean" $ do
      mean (vector @Double 10 [1..]) 0
        `shouldBe`
           5.5
    it "Should find the weighted-mean" $ do
      meanWeighted (vector @Double 10 [1..]) (vector @Double 10 [1..]) 0
        `shouldBeApprox`
           7.0
    it "Should find the variance" $ do
      var (vector @Double 8 [1..8]) False 0
        `shouldBe`
           5.25
    it "Should find the weighted variance" $ do
      varWeighted (vector @Double 8 [1..]) (vector @Double 8 (repeat 1)) 0
        `shouldBe`
           5.25
    it "Should find the standard deviation" $ do
      stdev (vector @Double 10 (cycle [1,-1])) 0
        `shouldBe`
           1.0
    it "Should find the covariance" $ do
      cov (vector @Double 10 (repeat 1)) (vector @Double 10 (repeat 1)) False
        `shouldBe`
           0.0
    it "Should find the median" $ do
      median (vector @Double 10 [1..]) 0
        `shouldBe`
           5.5
    it "Should find the mean of all elements across all dimensions" $ do
      fst (meanAll (matrix @Double (2,2) [[10,10],[10,10]]))
        `shouldBe`
           10
    it "Should find the weighted mean of all elements across all dimensions" $ do
      fst (meanAllWeighted (matrix @Double (2,2) [[10,10],[10,10]]) (matrix @Double (2,2) [[10,10],[10,10]]))
        `shouldBe`
           10
    it "Should find the variance of all elements across all dimensions" $ do
      fst (varAll (vector @Double 10 (repeat 10)) False)
        `shouldBe`
           0
    it "Should find the weighted variance of all elements across all dimensions" $ do
      fst (varAllWeighted (vector @Double 10 (repeat 10)) (vector @Double 10 (repeat 10)))
        `shouldBe`
           0
    it "Should find the stdev of all elements across all dimensions" $ do
      fst (stdevAll (vector @Double 10 (repeat 10)))
        `shouldBe`
           0
    it "Should find the median of all elements across all dimensions" $ do
      fst (medianAll (vector @Double 10 [1..]))
        `shouldBe`
           5.5
    it "Should find the correlation coefficient" $ do
      fst (corrCoef (vector @Int 10 [1..] ) ( vector @Int 10 [10,9..] ))
        `shouldBe`
           (-1.0)
    it "Should find the top k elements" $ do
      let (vals,indexes) = topk ( vector @Double 10 [1..] ) 3 TopKDefault
      vals `shouldBe` vector @Double 3 [10,9,8]
      indexes `shouldBe` vector @Double 3 [9,8,7]

--------------------------------------------------------------------------------
/test/ArrayFire/UtilSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE TypeApplications #-}
module ArrayFire.UtilSpec where

import qualified ArrayFire as A

import Control.Exception (finally)
import Control.Monad (when)
import Data.Complex
import Data.Int
import Data.Proxy
import Data.Word
import Foreign.C.Types
import System.Directory
import Test.Hspec

-- | Element sizes, library version/revision and array save/read round-trip.
spec :: Spec
spec =
  describe "Util spec" $ do
    it "Should get size of" $ do
      A.getSizeOf (Proxy @Int) `shouldBe` 8
      A.getSizeOf (Proxy @Int64) `shouldBe` 8
      A.getSizeOf (Proxy @Int32) `shouldBe` 4
      A.getSizeOf (Proxy @Int16) `shouldBe` 2
      A.getSizeOf (Proxy @Word) `shouldBe` 8
      A.getSizeOf (Proxy @Word64) `shouldBe` 8
      A.getSizeOf (Proxy @Word32) `shouldBe` 4
      A.getSizeOf (Proxy @Word16) `shouldBe` 2
      A.getSizeOf (Proxy @Word8) `shouldBe` 1
      A.getSizeOf (Proxy @CBool) `shouldBe` 1
      A.getSizeOf (Proxy @Double) `shouldBe` 8
      A.getSizeOf (Proxy @Float) `shouldBe` 4
      A.getSizeOf (Proxy @(Complex Float)) `shouldBe` 8
      A.getSizeOf (Proxy @(Complex Double)) `shouldBe` 16
    it "Should get version" $ do
      (major, minor, patch) <- A.getVersion
      major `shouldBe` 3
      minor `shouldSatisfy` (>= 8)
      patch `shouldSatisfy` (>= 0)
    it "Should get revision" $ do
      x <- A.getRevision
      x `shouldSatisfy` (not . null)
    it "Should save / read array" $ do
      let arr = A.constant @Int [1,1,1,1] 10
          path = "file.array"
          -- Remove the scratch file only if it was actually created, so a
          -- failure inside saveArray is not masked by a removeFile error.
          cleanup = do
            present <- doesFileExist path
            when present (removeFile path)
      -- `finally` guarantees the scratch file is deleted even when one of the
      -- assertions below fails; previously it was left behind on failure.
      flip finally cleanup $ do
        idx <- A.saveArray "key" arr path False
        doesFileExist path `shouldReturn` True
        (`shouldBe` idx) =<< A.readArrayKeyCheck path "key"
        (`shouldBe` arr) =<< A.readArrayIndex path idx
        (`shouldBe` arr) =<< A.readArrayKey path "key"

--------------------------------------------------------------------------------
/test/ArrayFire/VisionSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE TypeApplications #-}
module ArrayFire.VisionSpec where

import qualified ArrayFire as A
import Test.Hspec

-- | Vision smoke test.
spec :: Spec
spec =
  describe "Vision spec" $ do
    it "Should construct Features for fast feature detection" $ do
      let arr = A.vector @Int 30000 [1..]
      -- NOTE(review): `feats` is bound but never inspected, so this binding is
      -- presumably never forced and the case only checks the arithmetic below.
      let feats = A.fast arr 1.0 9 False 1.0 3
      (1 + 1) `shouldBe` 2

--------------------------------------------------------------------------------
/test/Main.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Main where

import Control.Monad

import Data.Proxy
import Spec (spec)
import Test.Hspec (hspec)
import Test.QuickCheck
import Test.QuickCheck.Classes

import qualified ArrayFire as A
import ArrayFire (Array)

import System.IO.Unsafe

-- | Generator yielding a 2x2 array obtained from 'A.randu'.
--
-- The IO action is run through 'unsafePerformIO' and wrapped with 'pure', so
-- QuickCheck's own size and seed are not consulted.
instance (A.AFType a, Arbitrary a) => Arbitrary (Array a) where
  arbitrary = pure $ unsafePerformIO (A.randu [2,2])

-- | Test entry point: runs the discovered hspec suite. The law checks are
-- currently commented out.
main :: IO ()
main = do
  -- checks (Proxy :: Proxy (A.Array (A.Complex Float)))
  -- checks (Proxy :: Proxy (A.Array (A.Complex Double)))
  -- checks (Proxy :: Proxy (A.Array Double))
  -- checks (Proxy :: Proxy (A.Array Float))
  -- checks (Proxy :: Proxy (A.Array Double))
  -- checks (Proxy :: Proxy (A.Array A.Int16))
  -- checks (Proxy :: Proxy (A.Array A.Int32))
  -- checks (Proxy :: Proxy (A.Array A.CBool))
  -- checks (Proxy :: Proxy (A.Array Word))
  -- checks (Proxy :: Proxy (A.Array A.Word8))
  -- checks (Proxy :: Proxy (A.Array A.Word16))
  -- checks (Proxy :: Proxy (A.Array A.Word32))
  -- lawsCheck $ semigroupLaws (Proxy :: Proxy (A.Array Double))
  -- lawsCheck $ semigroupLaws (Proxy :: Proxy (A.Array Float))
  hspec spec

-- | Run the 'Num', 'Eq' and 'Ord' law checks for the type named by @proxy@.
checks proxy = do
  lawsCheck (numLaws proxy)
  lawsCheck (eqLaws proxy)
  lawsCheck (ordLaws proxy)
  -- lawsCheck (semigroupLaws proxy)

--------------------------------------------------------------------------------
/test/Spec.hs:
--------------------------------------------------------------------------------
{-# OPTIONS_GHC -F -pgmF hspec-discover -optF --module-name=Spec #-}

--------------------------------------------------------------------------------
/test/Test/Hspec/ApproxExpect.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Test.Hspec.ApproxExpect where

import Data.CallStack (HasCallStack)

import Test.Hspec (shouldSatisfy, Expectation)

infix 1 `shouldBeApprox`

-- | Expect @actual@ to lie within a few epsilons of @tgt@.
--
-- The check needs neither a concrete floating-point type nor an 'Ord'
-- instance: the difference to the target is scaled down by @1e-4@ and added
-- back onto the target. If that addition leaves the target unchanged, the
-- difference is small enough to be absorbed and the values count as equal.
shouldBeApprox :: (HasCallStack, Show a, Fractional a, Eq a)
               => a -> a -> Expectation
shouldBeApprox actual tgt = shouldSatisfy actual absorbedByTarget
  where
    -- Scaled-down distance between a candidate and the target.
    scaledGap candidate = (candidate - tgt) * 1e-4
    -- True when adding the scaled gap is a no-op on the target.
    absorbedByTarget candidate = scaledGap candidate + tgt == tgt

--------------------------------------------------------------------------------