├── .ghci
├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── Setup.hs
├── arrayfire.cabal
├── cabal.project
├── cbits
└── wrapper.c
├── default.nix
├── doctests
└── Main.hs
├── exe
└── Main.hs
├── flake.lock
├── flake.nix
├── gen
├── Lex.hs
├── Main.hs
├── Parse.hs
├── Print.hs
└── Types.hs
├── include
├── algorithm.h
├── arith.h
├── array.h
├── backend.h
├── blas.h
├── cuda.h
├── data.h
├── defines.h
├── device.h
├── exception.h
├── features.h
├── graphics.h
├── image.h
├── index.h
├── internal.h
├── lapack.h
├── openCL.h
├── random.h
├── seq.h
├── signal.h
├── sparse.h
├── statistics.h
├── util.h
└── vision.h
├── nix
├── default.nix
└── no-download.patch
├── pkg.nix
├── shell.nix
├── src
├── ArrayFire.hs
└── ArrayFire
│ ├── Algorithm.hs
│ ├── Arith.hs
│ ├── Array.hs
│ ├── BLAS.hs
│ ├── Backend.hs
│ ├── Data.hs
│ ├── Device.hs
│ ├── Exception.hs
│ ├── FFI.hs
│ ├── Features.hs
│ ├── Graphics.hs
│ ├── Image.hs
│ ├── Index.hs
│ ├── Internal
│ ├── Algorithm.hsc
│ ├── Arith.hsc
│ ├── Array.hsc
│ ├── BLAS.hsc
│ ├── Backend.hsc
│ ├── CUDA.hsc
│ ├── Data.hsc
│ ├── Defines.hsc
│ ├── Device.hsc
│ ├── Exception.hsc
│ ├── Features.hsc
│ ├── Graphics.hsc
│ ├── Image.hsc
│ ├── Index.hsc
│ ├── Internal.hsc
│ ├── LAPACK.hsc
│ ├── OpenCL.hsc
│ ├── Random.hsc
│ ├── Seq.hsc
│ ├── Signal.hsc
│ ├── Sparse.hsc
│ ├── Statistics.hsc
│ ├── Types.hsc
│ ├── Util.hsc
│ └── Vision.hsc
│ ├── LAPACK.hs
│ ├── Orphans.hs
│ ├── Random.hs
│ ├── Signal.hs
│ ├── Sparse.hs
│ ├── Statistics.hs
│ ├── Types.hs
│ ├── Util.hs
│ └── Vision.hs
├── stack.yaml
├── stack.yaml.lock
└── test
├── ArrayFire
├── AlgorithmSpec.hs
├── ArithSpec.hs
├── ArraySpec.hs
├── BLASSpec.hs
├── BackendSpec.hs
├── DataSpec.hs
├── DeviceSpec.hs
├── FeaturesSpec.hs
├── GraphicsSpec.hs
├── ImageSpec.hs
├── IndexSpec.hs
├── LAPACKSpec.hs
├── RandomSpec.hs
├── SignalSpec.hs
├── SparseSpec.hs
├── StatisticsSpec.hs
├── UtilSpec.hs
└── VisionSpec.hs
├── Main.hs
├── Spec.hs
└── Test
└── Hspec
└── ApproxExpect.hs
/.ghci:
--------------------------------------------------------------------------------
1 | :set prompt "\x03BB> "
2 | :seti -XTypeApplications
3 | :load ArrayFire
4 | :m - ArrayFire
5 | import qualified ArrayFire as A
6 | :set -laf
7 | :set -isrc:test
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | dist*
2 | result/
3 | /TAGS
4 | /result
5 | *~
6 | /ctags
7 | cabal.project.local
8 | tags
9 | /.stack-work/
10 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Revision history for arrayfire
2 |
3 | ## 0.1.0.0 -- YYYY-mm-dd
4 |
5 | * First version. Released on an unsuspecting world.
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018, David Johnson
2 |
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | * Redistributions of source code must retain the above copyright
9 | notice, this list of conditions and the following disclaimer.
10 |
11 | * Redistributions in binary form must reproduce the above
12 | copyright notice, this list of conditions and the following
13 | disclaimer in the documentation and/or other materials provided
14 | with the distribution.
15 |
16 | * Neither the name of David Johnson nor the names of other
17 | contributors may be used to endorse or promote products derived
18 | from this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # arrayfire-haskell
2 | `ArrayFire` is a general-purpose library that simplifies the process of developing software that targets parallel and massively-parallel architectures including CPUs, GPUs, and other hardware acceleration devices.
3 |
4 | `arrayfire-haskell` is a [Haskell](https://haskell.org) binding to [ArrayFire](https://arrayfire.com).
5 |
6 | ## Table of Contents
7 | - [Installation](#installation)
8 | - [Haskell Installation](#haskell-installation)
9 | - [Documentation](#documentation)
10 | - [Hacking](#hacking)
11 | - [Example](#example)
12 |
13 |
14 | ## Installation
15 | Install `ArrayFire` via the download page.
16 | - https://arrayfire.com/download/
17 |
18 | `ArrayFire` can also be fetched from [nixpkgs](https://github.com/nixos/nixpkgs) `master`.
19 |
20 | ### Haskell Installation
21 |
22 | `arrayfire` can be installed w/ `cabal`, `stack` or `nix`.
23 |
24 | ```
25 | cabal install arrayfire
26 | ```
27 |
28 | ```
29 | stack install arrayfire
30 | ```
31 |
32 |
33 | Also note, if you plan on using ArrayFire's visualization features, you must install `fontconfig` and `glfw` on OSX or Linux.
34 |
35 | ## Documentation
36 | - [Hackage](http://hackage.haskell.org/package/arrayfire)
37 | - [ArrayFire](http://arrayfire.org/docs/gettingstarted.htm)
38 |
39 | ## Hacking
40 | To hack on this library locally, complete the installation step above. We recommend installing the [nix](https://nixos.org/nix/download.html) package manager to facilitate development.
41 |
42 | After the above tools are installed, clone the source from GitHub.
43 |
44 | ```bash
45 | git clone git@github.com:arrayfire/arrayfire-haskell.git
46 | cd arrayfire-haskell
47 | ```
48 |
49 | To build and run all tests in response to file changes
50 |
51 | ```bash
52 | nix-shell --run test-runner
53 | ```
54 |
55 | To perform interactive development w/ `ghcid`
56 |
57 | ```bash
58 | nix-shell --run ghcid
59 | ```
60 |
61 | To interactively evaluate code in the `repl`
62 |
63 | ```bash
64 | nix-shell --run repl
65 | ```
66 |
67 | To produce the haddocks and open them in a browser
68 |
69 | ```bash
70 | nix-shell --run docs
71 | ```
72 |
73 |
74 | ## Example
75 | ```haskell
76 | {-# LANGUAGE TypeApplications, ScopedTypeVariables #-}
77 | module Main where
78 |
79 | import qualified ArrayFire as A
80 | import Control.Exception (catch)
81 |
82 | main :: IO ()
83 | main = print newArray `catch` (\(e :: A.AFException) -> print e)
84 | where
85 | newArray = A.matrix @Double (2,2) [ [1..], [1..] ] * A.matrix @Double (2,2) [ [2..], [2..] ]
86 |
87 | {-|
88 |
89 | ArrayFire Array
90 | [2 2 1 1]
91 | 2.0000 6.0000
92 | 2.0000 6.0000
93 |
94 | -}
95 | ```
96 |
--------------------------------------------------------------------------------
/Setup.hs:
--------------------------------------------------------------------------------
1 | module Main where
2 |
3 | import Distribution.Extra.Doctest (defaultMainWithDoctests)
4 |
-- | Custom @Setup@ entry point: delegates to cabal-doctest's
-- 'defaultMainWithDoctests' for the test suite named @doctests@.
main :: IO ()
main = defaultMainWithDoctests suiteName
  where
    suiteName = "doctests"
7 |
--------------------------------------------------------------------------------
/arrayfire.cabal:
--------------------------------------------------------------------------------
1 | cabal-version: 3.0
2 | name: arrayfire
3 | version: 0.7.0.0
4 | synopsis: Haskell bindings to the ArrayFire general-purpose GPU library
5 | homepage: https://github.com/arrayfire/arrayfire-haskell
6 | license: BSD-3-Clause
7 | license-file: LICENSE
8 | author: David Johnson
9 | maintainer: code@dmj.io
10 | copyright: David Johnson (c) 2018-2025
11 | category: Math
12 | build-type: Custom
13 | extra-source-files: CHANGELOG.md
14 | description: High-level Haskell bindings to the ArrayFire General-purpose GPU library
15 | .
16 | <>
17 | .
18 |
19 | flag disable-default-paths
20 | description: When enabled, don't add default hardcoded include/link dirs by default. Needed for hermetic builds like in nix.
21 | default: False
22 | manual: True
23 |
24 | flag disable-build-tool-depends
25 | description: When enabled, don't add build-tool-depends fields to the Cabal file. Needed for working inside @nix develop@.
26 | default: False
27 | manual: True
28 |
29 | custom-setup
30 | setup-depends:
31 | base <5,
32 | Cabal,
33 | cabal-doctest >=1 && <1.1
34 |
35 | library
36 | exposed-modules:
37 | ArrayFire
38 | ArrayFire.Algorithm
39 | ArrayFire.Arith
40 | ArrayFire.Array
41 | ArrayFire.Backend
42 | ArrayFire.BLAS
43 | ArrayFire.Data
44 | ArrayFire.Device
45 | ArrayFire.Features
46 | ArrayFire.Graphics
47 | ArrayFire.Image
48 | ArrayFire.Index
49 | ArrayFire.LAPACK
50 | ArrayFire.Random
51 | ArrayFire.Signal
52 | ArrayFire.Sparse
53 | ArrayFire.Statistics
54 | ArrayFire.Types
55 | ArrayFire.Util
56 | ArrayFire.Vision
57 | other-modules:
58 | ArrayFire.FFI
59 | ArrayFire.Exception
60 | ArrayFire.Orphans
61 | ArrayFire.Internal.Algorithm
62 | ArrayFire.Internal.Arith
63 | ArrayFire.Internal.Array
64 | ArrayFire.Internal.Backend
65 | ArrayFire.Internal.BLAS
66 | ArrayFire.Internal.Data
67 | ArrayFire.Internal.Defines
68 | ArrayFire.Internal.Device
69 | ArrayFire.Internal.Exception
70 | ArrayFire.Internal.Features
71 | ArrayFire.Internal.Graphics
72 | ArrayFire.Internal.Image
73 | ArrayFire.Internal.Index
74 | ArrayFire.Internal.Internal
75 | ArrayFire.Internal.LAPACK
76 | ArrayFire.Internal.Random
77 | ArrayFire.Internal.Signal
78 | ArrayFire.Internal.Sparse
79 | ArrayFire.Internal.Statistics
80 | ArrayFire.Internal.Types
81 | ArrayFire.Internal.Util
82 | ArrayFire.Internal.Vision
83 | if !flag(disable-build-tool-depends)
84 | build-tool-depends:
85 | hsc2hs:hsc2hs
86 | extra-libraries:
87 | af
88 | c-sources:
89 | cbits/wrapper.c
90 | build-depends:
91 | base < 5, filepath, vector
92 | hs-source-dirs:
93 | src
94 | ghc-options:
95 | -Wall -Wno-missing-home-modules
96 | default-language:
97 | Haskell2010
98 |
99 | if os(linux) && !flag(disable-default-paths)
100 | include-dirs:
101 | /opt/arrayfire/include
102 | extra-lib-dirs:
103 | /opt/arrayfire/lib64
104 | ld-options:
105 | -Wl,-rpath /opt/arrayfire/lib64
106 |
107 | if os(OSX) && !flag(disable-default-paths)
108 | include-dirs:
109 | /opt/arrayfire/include
110 | extra-lib-dirs:
111 | /opt/arrayfire/lib
112 | ld-options:
113 | -Wl,-rpath /opt/arrayfire/lib
114 |
115 | executable main
116 | hs-source-dirs:
117 | exe
118 | main-is:
119 | Main.hs
120 | build-depends:
121 | base < 5, arrayfire, vector
122 | c-sources:
123 | cbits/wrapper.c
124 | default-language:
125 | Haskell2010
126 |
127 | executable gen
128 | main-is:
129 | Main.hs
130 | hs-source-dirs:
131 | gen
132 | build-depends:
133 | base < 5, parsec, text, directory
134 | default-language:
135 | Haskell2010
136 | other-modules:
137 | Lex
138 | Parse
139 | Print
140 | Types
141 |
142 | test-suite test
143 | type:
144 | exitcode-stdio-1.0
145 | main-is:
146 | Main.hs
147 | other-modules:
148 | Test.Hspec.ApproxExpect
149 | hs-source-dirs:
150 | test
151 | build-depends:
152 | arrayfire,
153 | base < 5,
154 | directory,
155 | hspec,
156 | HUnit,
157 | QuickCheck,
158 | quickcheck-classes,
159 | vector,
160 | call-stack >=0.4 && <0.5
161 | if !flag(disable-build-tool-depends)
162 | build-tool-depends:
163 | hspec-discover:hspec-discover
164 | default-language:
165 | Haskell2010
166 | other-modules:
167 | Spec
168 | ArrayFire.AlgorithmSpec
169 | ArrayFire.ArithSpec
170 | ArrayFire.ArraySpec
171 | ArrayFire.BLASSpec
172 | ArrayFire.BackendSpec
173 | ArrayFire.DataSpec
174 | ArrayFire.DeviceSpec
175 | ArrayFire.FeaturesSpec
176 | ArrayFire.GraphicsSpec
177 | ArrayFire.ImageSpec
178 | ArrayFire.IndexSpec
179 | ArrayFire.LAPACKSpec
180 | ArrayFire.RandomSpec
181 | ArrayFire.SignalSpec
182 | ArrayFire.SparseSpec
183 | ArrayFire.StatisticsSpec
184 | ArrayFire.UtilSpec
185 | ArrayFire.VisionSpec
186 |
187 | test-suite doctests
188 | type:
189 | exitcode-stdio-1.0
190 | buildable:
191 | False
192 | ghc-options:
193 | -threaded
194 | main-is:
195 | Main.hs
196 | hs-source-dirs:
197 | doctests
198 | build-depends:
199 | arrayfire
200 | , base < 5
201 | , doctest >= 0.8
202 | , split
203 | autogen-modules:
204 | Build_doctests
205 | other-modules:
206 | Build_doctests
207 | default-language:
208 | Haskell2010
209 |
210 | source-repository head
211 | type: git
212 | location: https://github.com/arrayfire/arrayfire-haskell.git
213 |
--------------------------------------------------------------------------------
/cabal.project:
--------------------------------------------------------------------------------
1 | packages: .
2 | ignore-project: False
3 | write-ghc-environment-files: always
4 | tests: True
5 | test-options: "--color"
6 | test-show-details: streaming
7 |
--------------------------------------------------------------------------------
/cbits/wrapper.c:
--------------------------------------------------------------------------------
#include "arrayfire.h"
#include <stdio.h>
#include <stdlib.h>
3 |
4 | af_err af_random_engine_set_type_(af_random_engine engine, const af_random_engine_type rtype) { return af_random_engine_set_type(&engine, rtype); }
5 |
6 | af_err af_random_engine_set_seed_(af_random_engine engine, const unsigned long long seed) {
7 | return af_random_engine_set_seed(&engine, seed);
8 | }
9 |
10 | void test_bool () {
11 | double * data = malloc (sizeof (int) * 5);
12 | data[0] = 2;
13 | data[1] = 2;
14 | data[2] = 2;
15 | data[3] = 2;
16 | data[4] = 2;
17 | data[5] = 2;
18 | dim_t * dims = malloc(sizeof(dim_t) * 4);
19 | dims[0] = 5;
20 | dims[1] = 1;
21 | dims[2] = 1;
22 | dims[3] = 1;
23 | af_array arrin;
24 | af_create_array(&arrin, data, 1, dims, f64);
25 | printf("printing input array\n");
26 | af_print_array(arrin);
27 | af_array arrout;
28 | af_product(&arrout, arrin, 0);
29 | printf("printing output array\n");
30 | af_print_array(arrout);
31 | }
32 |
33 | void test_window () {
34 | af_window window;
35 | af_create_window(&window, 100, 100, "foo");
36 | af_show(window);
37 | }
38 |
39 | void zeroOutArray (af_array * arr) {
40 | (*arr) = 0;
41 | }
42 |
--------------------------------------------------------------------------------
/default.nix:
--------------------------------------------------------------------------------
# Builds the arrayfire Haskell package from the cabal file in this directory.
# `pkgs` defaults to the `<nixpkgs>` channel with unfree packages allowed.
{ pkgs ? import <nixpkgs> { config.allowUnfree = true; } }:
# Latest arrayfire is not yet procured w/ nix.
let
  pkg = pkgs.haskellPackages.callCabal2nix "arrayfire" ./. {
    # The native `af` library is passed as null rather than taken from nixpkgs.
    af = null;
  };
in
pkg
9 |
--------------------------------------------------------------------------------
/doctests/Main.hs:
--------------------------------------------------------------------------------
1 | module Main (main) where
2 |
3 | import ArrayFire
4 | import Build_doctests (flags, pkgs, module_sources)
5 | import System.Environment
6 | import Test.DocTest (doctest)
7 | import Data.List.Split
8 |
-- | Runs the doctest suite.
--
-- Extra flags are read from the @NIX_TARGET_LDFLAGS@ environment variable
-- (split on spaces, first entry dropped). When the variable is not set, no
-- extra flags are added instead of aborting with an IO exception.
main :: IO ()
main = do
  -- Evaluates and prints a small array expression first (presumably a smoke
  -- check that the native library is usable; confirm).
  print $ 1 + (1 :: Array Int)
  moreFlags <- maybe [] (drop 1 . splitOn " ") <$> lookupEnv "NIX_TARGET_LDFLAGS"
  let args = moreFlags ++ flags ++ pkgs ++ module_sources
  mapM_ print moreFlags
  -- Echo the arguments in exactly the order they are handed to doctest.
  mapM_ print args
  doctest args
16 |
--------------------------------------------------------------------------------
/exe/Main.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE BangPatterns #-}
2 | {-# LANGUAGE ScopedTypeVariables #-}
3 | {-# LANGUAGE TypeApplications #-}
4 | {-# LANGUAGE DataKinds #-}
5 | module Main where
6 |
7 | import ArrayFire
8 | import Control.Concurrent
9 | import Control.Exception
10 |
11 | import Prelude hiding (sum, product)
12 | -- import GHC.RTS
13 |
14 | foreign import ccall safe "test_bool"
15 | testBool :: IO ()
16 |
17 | foreign import ccall safe "test_window"
18 | testWindow :: IO ()
19 |
-- | Builds two 2x2 'Double' matrices, combines them with '*' and prints the
-- result; an 'AFException' raised while doing so is printed instead.
main' :: IO ()
main' = catch (print (lhs * rhs)) report
  where
    lhs = matrix @Double (2,2) [ [1..], [1..] ]
    rhs = matrix @Double (2,2) [ [2..], [2..] ]

    report :: AFException -> IO ()
    report = print
24 |
-- | Runs 'main'' and then creates a 200x200 window titled "hey", makes it
-- visible and shows it, announcing each step on stdout, before sleeping for
-- @secs 10@ microseconds. The commented-out statements are scratch
-- experiments that are not executed.
main :: IO ()
main = do
  main'
  -- testWindow
  -- ks <- randn @Double [100,100]
  -- saveArray "key" ks "array.txt" False
  -- !ks' <- readArrayKey "array.txt" "key"
  -- print ks'

  -- info >> putStrLn "ok" >> afInit
  -- -- Info things
  -- print =<< getSizeOf (Proxy @ Double)
  -- print =<< getVersion
  -- print =<< getRevision
  -- -- getInfo
  -- -- print =<< errorToString afErrNoMem
  -- putStrLn =<< getInfoString
  -- print =<< getDeviceCount
  -- print =<< getDevice

  -- -- Create and print an array
  -- -- arr1 <- constant 1 1 1 f64
  -- -- arr2 <- constant 2 1 1 f64
  -- -- r <- addArray arr1 arr2 True
  -- -- printArray r

  -- -- print =<< isLAPACKAvailable
  -- -- print =<< getAvailableBackends
  -- -- print =<< getActiveBackend
  -- -- print =<< getAvailableBackends

  -- -- array <- constant @'(10,10) 200
  -- -- putStrLn "backend id"
  -- -- print (getBackendID array)
  -- -- putStrLn "device id"
  -- -- print (getDeviceID array)

  -- -- array <- randu @'(9,9,9) @Double
  -- -- printArray array -- printArray (mean array 0)

  -- -- printArray (add array 1)

  -- -- putStrLn "got eeem"
  -- -- print =<< getDataPtr x

  -- -- x <- constant 10 1 1 f64
  -- -- printArray =<< mean x 0

  -- -- print =<< isLAPACKAvailable

  -- a <- randu @'(3,3) @Float
  -- b <- randu @'(3,3) @Float
  -- printArray ((a `matmul` b) None None)
  -- `catch` (\(e :: AFException) -> do
  -- putStrLn "got one"
  -- print e)

  window <- announce "create window" (createWindow 200 200 "hey")
  announce "set visibility" (setVisibility window True)
  announce "show window" (showWindow window)
  threadDelay (secs 10)

  -- -- print =<< getActiveBackend
  -- -- print =<< getDeviceCount
  -- -- print =<< getDevice
  -- -- putStrLn "info"
  -- -- getInfo
  -- -- putStrLn "info string"
  -- -- putStrLn =<< getInfoString
  -- -- print =<< getVersion
  where
    -- Prints the step label, then runs the step.
    announce :: String -> IO a -> IO a
    announce label step = putStrLn label >> step
98 |
99 |
-- | Scales a number of seconds by 1000000, i.e. converts it to microseconds.
secs :: Int -> Int
secs n = n * 1000000
102 |
--------------------------------------------------------------------------------
/flake.lock:
--------------------------------------------------------------------------------
1 | {
2 | "nodes": {
3 | "arrayfire-nix": {
4 | "inputs": {
5 | "flake-utils": [
6 | "flake-utils"
7 | ],
8 | "nixpkgs": [
9 | "nixpkgs"
10 | ]
11 | },
12 | "locked": {
13 | "lastModified": 1692793973,
14 | "narHash": "sha256-6dG41ile3T+6dfRazlcPBdKBarGesswsBpb40Lcf35U=",
15 | "owner": "twesterhout",
16 | "repo": "arrayfire-nix",
17 | "rev": "4236770612b80a3f29adbd8d670f6cea2bc098ba",
18 | "type": "github"
19 | },
20 | "original": {
21 | "owner": "twesterhout",
22 | "repo": "arrayfire-nix",
23 | "type": "github"
24 | }
25 | },
26 | "flake-utils": {
27 | "inputs": {
28 | "systems": "systems"
29 | },
30 | "locked": {
31 | "lastModified": 1692792214,
32 | "narHash": "sha256-voZDQOvqHsaReipVd3zTKSBwN7LZcUwi3/ThMxRZToU=",
33 | "owner": "numtide",
34 | "repo": "flake-utils",
35 | "rev": "1721b3e7c882f75f2301b00d48a2884af8c448ae",
36 | "type": "github"
37 | },
38 | "original": {
39 | "owner": "numtide",
40 | "repo": "flake-utils",
41 | "type": "github"
42 | }
43 | },
44 | "nix-filter": {
45 | "locked": {
46 | "lastModified": 1687178632,
47 | "narHash": "sha256-HS7YR5erss0JCaUijPeyg2XrisEb959FIct3n2TMGbE=",
48 | "owner": "numtide",
49 | "repo": "nix-filter",
50 | "rev": "d90c75e8319d0dd9be67d933d8eb9d0894ec9174",
51 | "type": "github"
52 | },
53 | "original": {
54 | "owner": "numtide",
55 | "repo": "nix-filter",
56 | "type": "github"
57 | }
58 | },
59 | "nixpkgs": {
60 | "locked": {
61 | "lastModified": 1692638711,
62 | "narHash": "sha256-J0LgSFgJVGCC1+j5R2QndadWI1oumusg6hCtYAzLID4=",
63 | "owner": "nixos",
64 | "repo": "nixpkgs",
65 | "rev": "91a22f76cd1716f9d0149e8a5c68424bb691de15",
66 | "type": "github"
67 | },
68 | "original": {
69 | "owner": "nixos",
70 | "ref": "nixos-unstable",
71 | "repo": "nixpkgs",
72 | "type": "github"
73 | }
74 | },
75 | "root": {
76 | "inputs": {
77 | "arrayfire-nix": "arrayfire-nix",
78 | "flake-utils": "flake-utils",
79 | "nix-filter": "nix-filter",
80 | "nixpkgs": "nixpkgs"
81 | }
82 | },
83 | "systems": {
84 | "locked": {
85 | "lastModified": 1681028828,
86 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
87 | "owner": "nix-systems",
88 | "repo": "default",
89 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
90 | "type": "github"
91 | },
92 | "original": {
93 | "owner": "nix-systems",
94 | "repo": "default",
95 | "type": "github"
96 | }
97 | }
98 | },
99 | "root": "root",
100 | "version": 7
101 | }
102 |
--------------------------------------------------------------------------------
/flake.nix:
--------------------------------------------------------------------------------
1 | {
2 | description = "arrayfire/arrayfire-haskell: ArrayFire Haskell bindings";
3 |
4 | inputs = {
5 | nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
6 | flake-utils.url = "github:numtide/flake-utils";
7 | nix-filter.url = "github:numtide/nix-filter";
8 | arrayfire-nix = {
9 | url = "github:twesterhout/arrayfire-nix";
10 | inputs.flake-utils.follows = "flake-utils";
11 | inputs.nixpkgs.follows = "nixpkgs";
12 | };
13 | };
14 |
15 | outputs = inputs:
16 | let
17 | src = inputs.nix-filter.lib {
18 | root = ./.;
19 | include = [
20 | "cbits"
21 | "exe"
22 | "gen"
23 | "include"
24 | "src"
25 | "test"
26 | "arrayfire.cabal"
27 | "README.md"
28 | "CHANGELOG.md"
29 | "LICENSE"
30 | ];
31 | };
32 |
33 | # An overlay that lets us test arrayfire-haskell with different GHC versions
34 | arrayfire-haskell-overlay = self: super: {
35 | haskell = super.haskell // {
36 | packageOverrides = inputs.nixpkgs.lib.composeExtensions super.haskell.packageOverrides
37 | (hself: hsuper: {
38 | arrayfire = self.haskell.lib.appendConfigureFlags
39 | (hself.callCabal2nix "arrayfire" src { af = self.arrayfire; })
40 | [ "-f disable-default-paths" ];
41 | });
42 | };
43 | };
44 |
45 | devShell-for = pkgs:
46 | let
47 | ps = pkgs.haskellPackages;
48 | in
49 | ps.shellFor {
50 | packages = ps: with ps; [ arrayfire ];
51 | withHoogle = true;
52 | buildInputs = with pkgs; [ ocl-icd ];
53 | nativeBuildInputs = with pkgs; with ps; [
54 | # Building and testing
55 | cabal-install
56 | doctest
57 | hsc2hs
58 | hspec-discover
59 | # Language servers
60 | haskell-language-server
61 | nil
62 | # Formatters
63 | nixpkgs-fmt
64 | ];
65 | shellHook = ''
66 | '';
67 | };
68 |
69 | pkgs-for = system: import inputs.nixpkgs {
70 | inherit system;
71 | overlays = [
72 | inputs.arrayfire-nix.overlays.default
73 | arrayfire-haskell-overlay
74 | ];
75 | };
76 | in
77 | {
78 | packages = inputs.flake-utils.lib.eachDefaultSystemMap (system:
79 | with (pkgs-for system); {
80 | default = haskellPackages.arrayfire;
81 | haskell = haskell.packages;
82 | });
83 |
84 | devShells = inputs.flake-utils.lib.eachDefaultSystemMap (system: {
85 | default = devShell-for (pkgs-for system);
86 | });
87 |
88 | overlays.default = arrayfire-haskell-overlay;
89 | };
90 | }
91 |
--------------------------------------------------------------------------------
/gen/Lex.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE OverloadedStrings #-}
2 | module Lex where
3 |
4 | import Control.Arrow
5 | import Data.Text (Text)
6 | import qualified Data.Text as T
7 |
8 | import Data.Char
9 | import Types
10 |
-- | Characters that end an identifier: space, @*@, both parentheses, @;@
-- and @,@.
symbols :: String
symbols = [' ', '*', '(', ')', ';', ',']
13 |
-- | Splits one line of a C header into tokens.
--
-- Scanning alternates between two modes: 'NameMode' reads a run of
-- non-symbol characters as an identifier, 'TokenMode' turns the characters
-- in between into punctuation tokens. Spaces and unknown characters produce
-- no token, and the @const@ qualifier is dropped.
lex :: Text -> [Token]
lex = go NameMode
  where
    -- Punctuation token for a symbol character; anything else is skipped.
    tokenize '*' = [Star]
    tokenize '(' = [LParen]
    tokenize ')' = [RParen]
    tokenize ';' = [Semi]
    tokenize ',' = [Comma]
    tokenize _   = []

    go TokenMode xs =
      case T.uncons xs of
        Nothing -> []
        Just (c, cs)
          | isAlpha c -> go NameMode xs
          | otherwise -> tokenize c ++ go TokenMode cs

    go NameMode xs
      -- An empty identifier (the input is empty or starts with a symbol,
      -- e.g. an indented line) and the @const@ qualifier emit no token.
      | T.null match || match == "const" = go TokenMode rest
      | otherwise                        = Id match : go TokenMode rest
      where
        (match, rest) = partition xs
35 |
-- | Splits the text at the first character contained in 'symbols': the
-- first component is everything before it, the second starts at that
-- character. Without such a character the second component is empty.
partition :: Text -> (Text,Text)
partition = T.break (`elem` symbols)
40 |
--------------------------------------------------------------------------------
/gen/Main.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE ScopedTypeVariables #-}
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE OverloadedStrings #-}
4 | {-# LANGUAGE GeneralizedNewtypeDeriving #-}
5 | module Main where
6 |
7 | import Control.Monad
8 | import Data.Char
9 | import Data.Either
10 |
11 | import Data.Maybe
12 | import Data.Text (Text)
13 | import qualified Data.Text as T
14 | import qualified Data.Text.IO as T
15 | import System.Directory
16 | import System.Environment
17 | import System.Exit
18 | import Text.Printf
19 |
20 | import Lex
21 | import Parse
22 | import Print
23 | import Types
24 |
-- | Runs 'writeToDisk' for every header name returned by 'getDirectoryFiles'.
main :: IO ()
main = getDirectoryFiles >>= mapM_ writeToDisk
27 |
-- | Lists the entries of the @include@ directory, leaving out the header
-- names in the exclusion list.
getDirectoryFiles :: IO [String]
getDirectoryFiles = do
  entries <- listDirectory "include"
  pure [ entry | entry <- entries, not (isExcluded entry) ]
  where
    isExcluded entry = entry `elem` ["defines.h", "complex.h", "extra.h"]
36 |
-- | Parses the declarations of one header under @include/@ and writes the
-- generated foreign imports to the module path given by 'makePath'.
--
-- The first non-empty line of the header is skipped. If any remaining line
-- fails to parse, the first parse error is printed and nothing is written.
writeToDisk :: String -> IO ()
writeToDisk fileName = do
  contents <- T.readFile ("include/" <> fileName)
  let declarations = drop 1 (filter (not . T.null) (T.lines contents))
      (failures, successes) = partitionEithers (map run declarations)
  case failures of
    firstFailure : _ -> do
      print firstFailure
      printf "%s failed to generate bindings\n" fileName
    [] -> do
      -- Strip the two-character ".h" suffix to get the header's base name.
      let name = makeName (reverse . drop 2 . reverse $ fileName)
          target = makePath name
      T.writeFile target $
        file name <> T.intercalate "\n" (genBinding <$> successes)
      printf "Wrote bindings to %s\n" target
55 |
-- | Filename remappings: the header base names @lapack@, @blas@ and @cuda@
-- map to their upper-case spelling; every other name is returned unchanged.
makeName :: String -> String
makeName n =
  case n of
    "lapack" -> "LAPACK"
    "blas"   -> "BLAS"
    "cuda"   -> "CUDA"
    _        -> n
63 |
-- | Path of the generated @.hsc@ module for a header base name: the name is
-- passed through 'makeName' and 'capitalName' and placed under
-- @src/ArrayFire/Internal/@.
makePath :: String -> String
makePath s = printf "src/ArrayFire/Internal/%s.hsc" moduleName
  where
    moduleName = capitalName (makeName s)
67 |
-- | Renders the header of a generated module: the @CPP@ pragma, the module
-- line, the fixed imports and the @#include@ of the matching @af/@ header.
-- For the name @exception@ the include is @af/defines.h@ instead.
file :: String -> Text
file a = T.pack (printf template (capitalName a) header)
  where
    template = unlines
      [ "{-# LANGUAGE CPP #-}"
      , "module ArrayFire.Internal.%s where"
      , ""
      , "import ArrayFire.Internal.Defines"
      , "import ArrayFire.Internal.Types"
      , "import Foreign.Ptr"
      , "import Foreign.C.Types"
      , ""
      , "#include \"af/%s.h\""
      ]
    lowered = lowerCase a
    header
      | lowered == "exception" = "defines"
      | otherwise              = lowered
80 |
-- | 'capitalName' upper-cases the first character of a name and leaves the
-- rest untouched; 'lowerCase' lower-cases every character. Both are total:
-- the empty string maps to the empty string.
capitalName, lowerCase :: [Char] -> [Char]
capitalName []     = []
capitalName (x:xs) = toUpper x : xs
lowerCase = map toLower
84 |
--------------------------------------------------------------------------------
/gen/Parse.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE OverloadedStrings #-}
2 | module Parse where
3 |
4 | import Prelude hiding (lex)
5 | import Control.Monad
6 | import Data.Text (Text)
7 | import qualified Data.Text as T
8 | import Text.Parsec
9 |
10 | import Lex
11 | import Types
12 |
-- | Parses one C function declaration: an optional @AFAPI@ marker, the
-- return type, the function name, a parenthesised comma-separated parameter
-- list and the closing semicolon.
parseAST :: Parser AST
parseAST =
  AST
    <$> (afapi *> getType)
    <*> name
    <*> (lparen *> sepBy getParam comma <* rparen <* semi)
22 |
-- | Parses one parameter: a type followed by the parameter name. Only the
-- type is returned; the name is consumed and discarded.
getParam :: Parser Type
getParam = getType <* name
27 |
-- | Parses a type name followed by zero to three @*@ tokens and records the
-- number of stars in the resulting 'Type'. Longer runs are tried first.
getType :: Parser Type
getType = Type <$> name <*> pointerDepth
  where
    pointerDepth = choice [ try (n <$ replicateM_ n star) | n <- [3,2..0] ]
35 |
-- | Lexes one header line and parses the resulting tokens as a declaration.
-- The source name passed to Parsec is empty.
run :: Text -> Either ParseError AST
run txt = runParser parseAST () "" tokens
  where
    tokens = lex txt
38 |
-- | Parsers for single tokens. 'afapi' accepts an @AFAPI@ identifier when
-- it is present and succeeds without consuming input when it is not; the
-- others each require exactly their punctuation token.
afapi,lparen,rparen,semi,star,comma :: Parser ()
afapi  = optional (tok' (Id "AFAPI"))
lparen = tok' LParen
rparen = tok' RParen
comma  = tok' Comma
semi   = tok' Semi
star   = tok' Star
46 |
-- | Accepts exactly the given token and returns it; any other token fails.
tok :: Token -> Parser Token
tok expected = tokenPrim show ignore match
  where
    match actual
      | expected == actual = Just expected
      | otherwise          = Nothing
50 |
-- | Accepts exactly the given token and discards it; any other token fails.
tok' :: Token -> Parser ()
tok' expected = void (tok expected)
54 |
-- | Accepts an 'Id' token and wraps its text in 'Name'; any other token
-- fails.
name :: Parser Name
name = tokenPrim show ignore identifier
  where
    identifier t = case t of
      Id x -> Just (Name x)
      _    -> Nothing
60 |
-- | Returns its first argument and ignores the other two. Passed to
-- 'tokenPrim' as the position updater, so the reported source position
-- never advances.
ignore :: p -> t -> s -> p
ignore pos _ _ = pos
62 |
--------------------------------------------------------------------------------
/gen/Print.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE OverloadedStrings #-}
2 | module Print where
3 |
4 | import Data.Char
5 | import Data.Text (Text)
6 | import qualified Data.Text as T
7 | import Text.Printf
8 |
9 | import Types
10 |
11 | import Parse
12 |
-- | Converts a snake_case identifier to CamelCase: the name is split on
-- underscores, every segment gets an upper-case first character and the
-- segment @af@ becomes @AF@. Empty segments (from leading, trailing or
-- doubled underscores) contribute nothing.
camelize :: Text -> Text
camelize = T.concat . map go . T.splitOn "_"
  where
    go "af" = "AF"
    -- 'T.uncons' keeps this total: an empty segment has no first character.
    go xs   = maybe T.empty (\(c, cs) -> T.cons (toUpper c) cs) (T.uncons xs)
21 |
-- | True when the type's pointer depth is at least one.
isPtr :: Type -> Bool
isPtr (Type _ depth) = depth >= 1
23 |
-- | Renders one @foreign import ccall unsafe@ declaration: the import line
-- with the C name, the parameter types chained with @->@, and the result
-- type wrapped in @IO@ (parenthesised when the result is a pointer).
genBinding :: AST -> Text
genBinding (AST type' name params) =
  T.concat [importLine, printTypes params, "IO ", result]
  where
    importLine =
      T.pack (printf "foreign import ccall unsafe \"%s\"\n %s :: " name name)
    rendered = printType type'
    result
      | isPtr type' = "(" <> rendered <> ")"
      | otherwise   = rendered
32 |
-- | Renders every parameter type followed by an arrow; the result is empty
-- when there are no parameters.
printTypes :: [Type] -> Text
printTypes = foldMap (\t -> printType t <> " -> ")
42 |
-- | Renders a type for the generated import: depth 0 is the mapped type
-- name, depth 1 prefixes it with @Ptr@, and every further level wraps the
-- next lower level in @Ptr (...)@.
printType :: Type -> Text
printType (Type base depth) =
  case depth of
    0 -> plain
    1 -> "Ptr " <> plain
    _ -> "Ptr (" <> printType (Type base (depth - 1)) <> ")"
  where
    Name raw = base
    plain = showType raw
46 |
-- | Additional mappings, very important for CodeGen: C type names with a
-- fixed Haskell counterpart are looked up in a table; every other name is
-- passed through 'camelize'.
showType :: Text -> Text
showType x = maybe (camelize x) id (lookup x knownTypes)
  where
    knownTypes :: [(Text, Text)]
    knownTypes =
      [ ("char", "CChar")
      , ("void", "()")
      , ("unsigned", "CUInt")
      , ("dim_t", "DimT")
      , ("af_someenum_t", "AFSomeEnum")
      , ("size_t", "CSize")
      , ("uintl", "UIntL")
      , ("intl", "IntL")
      , ("int", "CInt")
      , ("bool", "CBool")
      , ("af_index_t", "AFIndex")
      , ("af_cspace_t", "AFCSpace")
      , ("afcl_platform", "AFCLPlatform")
      ]
63 |
--------------------------------------------------------------------------------
/gen/Types.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE GeneralizedNewtypeDeriving #-}
2 | module Types where
3 |
4 | import Data.Functor.Identity
5 | import Data.Text (Text)
6 | import qualified Data.Text as T
7 | import Text.Parsec
8 | import Text.Printf
9 |
-- | Parsec parser over a list of 'Token's with no user state.
type Parser = ParsecT [Token] () Identity

-- | A list of parameter types.
type Params = [Type]

-- | A parsed function declaration: return type, function name and
-- parameter types.
data AST
  = AST Type Name [Type]
  deriving (Show)

-- | A type name together with its pointer depth (the number of @*@ tokens
-- that followed it).
data Type
  = Type Name Int
  deriving (Show)

-- | An identifier.
newtype Name
  = Name Text
  deriving (Show, Eq, PrintfArg)

-- | Scanning mode of the lexer.
data Mode
  = TokenMode
  | NameMode
  deriving (Eq, Show)

-- | Lexer output: identifiers and the punctuation of a declaration.
data Token
  = Id Text
  | Star
  | LParen
  | RParen
  | Comma
  | Semi
  deriving (Show, Eq)
36 |
--------------------------------------------------------------------------------
/include/algorithm.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_sum(af_array *out, const af_array in, const int dim);
4 | af_err af_sum_nan(af_array *out, const af_array in, const int dim, const double nanval);
5 | af_err af_product(af_array *out, const af_array in, const int dim);
6 | af_err af_product_nan(af_array *out, const af_array in, const int dim, const double nanval);
7 | af_err af_min(af_array *out, const af_array in, const int dim);
8 | af_err af_max(af_array *out, const af_array in, const int dim);
9 | af_err af_all_true(af_array *out, const af_array in, const int dim);
10 | af_err af_any_true(af_array *out, const af_array in, const int dim);
11 | af_err af_count(af_array *out, const af_array in, const int dim);
12 | af_err af_sum_all(double *real, double *imag, const af_array in);
13 | af_err af_sum_nan_all(double *real, double *imag, const af_array in, const double nanval);
14 | af_err af_product_all(double *real, double *imag, const af_array in);
15 | af_err af_product_nan_all(double *real, double *imag, const af_array in, const double nanval);
16 | af_err af_min_all(double *real, double *imag, const af_array in);
17 | af_err af_max_all(double *real, double *imag, const af_array in);
18 | af_err af_all_true_all(double *real, double *imag, const af_array in);
19 | af_err af_any_true_all(double *real, double *imag, const af_array in);
20 | af_err af_count_all(double *real, double *imag, const af_array in);
21 | af_err af_imin(af_array *out, af_array *idx, const af_array in, const int dim);
22 | af_err af_imax(af_array *out, af_array *idx, const af_array in, const int dim);
23 | af_err af_imin_all(double *real, double *imag, unsigned *idx, const af_array in);
24 | af_err af_imax_all(double *real, double *imag, unsigned *idx, const af_array in);
25 | af_err af_accum(af_array *out, const af_array in, const int dim);
26 | af_err af_scan(af_array *out, const af_array in, const int dim, af_binary_op op, bool inclusive_scan);
27 | af_err af_scan_by_key(af_array *out, const af_array key, const af_array in, const int dim, af_binary_op op, bool inclusive_scan);
28 | af_err af_where(af_array *idx, const af_array in);
29 | af_err af_diff1(af_array *out, const af_array in, const int dim);
30 | af_err af_diff2(af_array *out, const af_array in, const int dim);
31 | af_err af_sort(af_array *out, const af_array in, const unsigned dim, const bool isAscending);
32 | af_err af_sort_index(af_array *out, af_array *indices, const af_array in, const unsigned dim, const bool isAscending);
33 | af_err af_sort_by_key(af_array *out_keys, af_array *out_values, const af_array keys, const af_array values, const unsigned dim, const bool isAscending);
34 | af_err af_set_unique(af_array *out, const af_array in, const bool is_sorted);
35 | af_err af_set_union(af_array *out, const af_array first, const af_array second, const bool is_unique);
36 | af_err af_set_intersect(af_array *out, const af_array first, const af_array second, const bool is_unique);
37 |
--------------------------------------------------------------------------------
/include/arith.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_add (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
4 | af_err af_sub (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
5 | af_err af_mul (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
6 | af_err af_div (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
7 | af_err af_lt (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
8 | af_err af_gt (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
9 | af_err af_le (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
10 | af_err af_ge (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
11 | af_err af_eq (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
12 | af_err af_neq (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
13 | af_err af_and (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
14 | af_err af_or (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
15 | af_err af_not (af_array *out, const af_array in);
16 | af_err af_bitand (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
17 | af_err af_bitor (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
18 | af_err af_bitxor (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
19 | af_err af_bitshiftl(af_array *out, const af_array lhs, const af_array rhs, const bool batch);
20 | af_err af_bitshiftr(af_array *out, const af_array lhs, const af_array rhs, const bool batch);
21 | af_err af_cast (af_array *out, const af_array in, const af_dtype type);
22 | af_err af_minof (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
23 | af_err af_maxof (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
24 | af_err af_clamp(af_array *out, const af_array in, const af_array lo, const af_array hi, const bool batch);
25 | af_err af_rem (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
26 | af_err af_mod (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
27 | af_err af_abs (af_array *out, const af_array in);
28 | af_err af_arg (af_array *out, const af_array in);
29 | af_err af_sign (af_array *out, const af_array in);
30 | af_err af_round (af_array *out, const af_array in);
31 | af_err af_trunc (af_array *out, const af_array in);
32 | af_err af_floor (af_array *out, const af_array in);
33 | af_err af_ceil (af_array *out, const af_array in);
34 | af_err af_hypot (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
35 | af_err af_sin (af_array *out, const af_array in);
36 | af_err af_cos (af_array *out, const af_array in);
37 | af_err af_tan (af_array *out, const af_array in);
38 | af_err af_asin (af_array *out, const af_array in);
39 | af_err af_acos (af_array *out, const af_array in);
40 | af_err af_atan (af_array *out, const af_array in);
41 | af_err af_atan2 (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
42 | af_err af_cplx2 (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
43 | af_err af_cplx (af_array *out, const af_array in);
44 | af_err af_real (af_array *out, const af_array in);
45 | af_err af_imag (af_array *out, const af_array in);
46 | af_err af_conjg (af_array *out, const af_array in);
47 | af_err af_sinh (af_array *out, const af_array in);
48 | af_err af_cosh (af_array *out, const af_array in);
49 | af_err af_tanh (af_array *out, const af_array in);
50 | af_err af_asinh (af_array *out, const af_array in);
51 | af_err af_acosh (af_array *out, const af_array in);
52 | af_err af_atanh (af_array *out, const af_array in);
53 | af_err af_root (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
54 | af_err af_pow (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
55 | af_err af_pow2 (af_array *out, const af_array in);
56 | af_err af_exp (af_array *out, const af_array in);
57 | af_err af_sigmoid (af_array *out, const af_array in);
58 | af_err af_expm1 (af_array *out, const af_array in);
59 | af_err af_erf (af_array *out, const af_array in);
60 | af_err af_erfc (af_array *out, const af_array in);
61 | af_err af_log (af_array *out, const af_array in);
62 | af_err af_log1p (af_array *out, const af_array in);
63 | af_err af_log10 (af_array *out, const af_array in);
64 | af_err af_log2 (af_array *out, const af_array in);
65 | af_err af_sqrt (af_array *out, const af_array in);
66 | af_err af_cbrt (af_array *out, const af_array in);
67 | af_err af_factorial (af_array *out, const af_array in);
68 | af_err af_tgamma (af_array *out, const af_array in);
69 | af_err af_lgamma (af_array *out, const af_array in);
70 | af_err af_iszero (af_array *out, const af_array in);
71 | af_err af_isinf (af_array *out, const af_array in);
72 | af_err af_isnan (af_array *out, const af_array in);
73 |
--------------------------------------------------------------------------------
/include/array.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_create_array(af_array *arr, const void * const data, const unsigned ndims, const dim_t * const dims, const af_dtype type);
4 | af_err af_create_handle(af_array *arr, const unsigned ndims, const dim_t * const dims, const af_dtype type);
5 | af_err af_copy_array(af_array *arr, const af_array in);
6 | af_err af_write_array(af_array arr, const void *data, const size_t bytes, af_source src);
7 | af_err af_get_data_ptr(void *data, const af_array arr);
8 | af_err af_release_array(af_array arr);
9 | af_err af_retain_array(af_array *out, const af_array in);
10 | af_err af_get_data_ref_count(int *use_count, const af_array in);
11 | af_err af_eval(af_array in);
12 | af_err af_eval_multiple(const int num, af_array *arrays);
13 | af_err af_set_manual_eval_flag(bool flag);
14 | af_err af_get_manual_eval_flag(bool *flag);
15 | af_err af_get_elements(dim_t *elems, const af_array arr);
16 | af_err af_get_type(af_dtype *type, const af_array arr);
17 | af_err af_get_dims(dim_t *d0, dim_t *d1, dim_t *d2, dim_t *d3, const af_array arr);
18 | af_err af_get_numdims(unsigned *result, const af_array arr);
19 | af_err af_is_empty (bool *result, const af_array arr);
20 | af_err af_is_scalar (bool *result, const af_array arr);
21 | af_err af_is_row (bool *result, const af_array arr);
22 | af_err af_is_column (bool *result, const af_array arr);
23 | af_err af_is_vector (bool *result, const af_array arr);
24 | af_err af_is_complex (bool *result, const af_array arr);
25 | af_err af_is_real (bool *result, const af_array arr);
26 | af_err af_is_double (bool *result, const af_array arr);
27 | af_err af_is_single (bool *result, const af_array arr);
28 | af_err af_is_realfloating (bool *result, const af_array arr);
29 | af_err af_is_floating (bool *result, const af_array arr);
30 | af_err af_is_integer (bool *result, const af_array arr);
31 | af_err af_is_bool (bool *result, const af_array arr);
32 | af_err af_is_sparse (bool *result, const af_array arr);
33 | af_err af_get_scalar(void* output_value, const af_array arr);
34 |
--------------------------------------------------------------------------------
/include/backend.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_set_backend(const af_backend bknd);
4 | af_err af_get_backend_count(unsigned* num_backends);
5 | af_err af_get_available_backends(int* backends);
6 | af_err af_get_backend_id(af_backend *backend, const af_array in);
7 | af_err af_get_active_backend(af_backend *backend);
8 | af_err af_get_device_id(int *device, const af_array in);
9 |
--------------------------------------------------------------------------------
/include/blas.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_matmul( af_array *out ,const af_array lhs, const af_array rhs, const af_mat_prop optLhs, const af_mat_prop optRhs);
4 | af_err af_dot(af_array *out, const af_array lhs, const af_array rhs, const af_mat_prop optLhs, const af_mat_prop optRhs);
5 | af_err af_dot_all(double *real, double *imag, const af_array lhs, const af_array rhs, const af_mat_prop optLhs, const af_mat_prop optRhs);
6 | af_err af_transpose(af_array *out, af_array in, const bool conjugate);
7 | af_err af_transpose_inplace(af_array in, const bool conjugate);
8 |
--------------------------------------------------------------------------------
/include/cuda.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err afcu_get_stream(cudaStream_t* stream, int id);
4 | af_err afcu_get_native_id(int* nativeid, int id);
5 | af_err afcu_set_native_id(int nativeid);
6 |
--------------------------------------------------------------------------------
/include/data.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_constant(af_array *arr, const double val, const unsigned ndims, const dim_t * const dims, const af_dtype type);
4 | af_err af_constant_complex(af_array *arr, const double real, const double imag, const unsigned ndims, const dim_t * const dims, const af_dtype type);
5 | af_err af_constant_long (af_array *arr, const intl val, const unsigned ndims, const dim_t * const dims);
6 | af_err af_constant_ulong(af_array *arr, const uintl val, const unsigned ndims, const dim_t * const dims);
7 | af_err af_range(af_array *out, const unsigned ndims, const dim_t * const dims, const int seq_dim, const af_dtype type);
8 | af_err af_iota(af_array *out, const unsigned ndims, const dim_t * const dims, const unsigned t_ndims, const dim_t * const tdims, const af_dtype type);
9 | af_err af_identity(af_array *out, const unsigned ndims, const dim_t * const dims, const af_dtype type);
10 | af_err af_diag_create(af_array *out, const af_array in, const int num);
11 | af_err af_diag_extract(af_array *out, const af_array in, const int num);
12 | af_err af_join(af_array *out, const int dim, const af_array first, const af_array second);
13 | af_err af_join_many(af_array *out, const int dim, const unsigned n_arrays, const af_array *inputs);
14 | af_err af_tile(af_array *out, const af_array in, const unsigned x, const unsigned y, const unsigned z, const unsigned w);
15 | af_err af_reorder(af_array *out, const af_array in, const unsigned x, const unsigned y, const unsigned z, const unsigned w);
16 | af_err af_shift(af_array *out, const af_array in, const int x, const int y, const int z, const int w);
17 | af_err af_moddims(af_array *out, const af_array in, const unsigned ndims, const dim_t * const dims);
18 | af_err af_flat(af_array *out, const af_array in);
19 | af_err af_flip(af_array *out, const af_array in, const unsigned dim);
20 | af_err af_lower(af_array *out, const af_array in, bool is_unit_diag);
21 | af_err af_upper(af_array *out, const af_array in, bool is_unit_diag);
22 | af_err af_select(af_array *out, const af_array cond, const af_array a, const af_array b);
23 | af_err af_select_scalar_r(af_array *out, const af_array cond, const af_array a, const double b);
24 | af_err af_select_scalar_l(af_array *out, const af_array cond, const double a, const af_array b);
25 | af_err af_replace(af_array a, const af_array cond, const af_array b);
26 | af_err af_replace_scalar(af_array a, const af_array cond, const double b);
27 |
--------------------------------------------------------------------------------
/include/device.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_info();
4 | af_err af_init();
5 | af_err af_info_string(char** str, const bool verbose);
6 | af_err af_device_info(char* d_name, char* d_platform, char *d_toolkit, char* d_compute);
7 | af_err af_get_device_count(int *num_of_devices);
8 | af_err af_get_dbl_support(bool* available, const int device);
9 | af_err af_set_device(const int device);
10 | af_err af_get_device(int *device);
11 | af_err af_sync(const int device);
12 | af_err af_alloc_device(void **ptr, const dim_t bytes);
13 | af_err af_free_device(void *ptr);
14 | af_err af_alloc_pinned(void **ptr, const dim_t bytes);
15 | af_err af_free_pinned(void *ptr);
16 | af_err af_alloc_host(void **ptr, const dim_t bytes);
17 | af_err af_free_host(void *ptr);
18 | af_err af_device_array(af_array *arr, const void *data, const unsigned ndims, const dim_t * const dims, const af_dtype type);
19 | af_err af_device_mem_info(size_t *alloc_bytes, size_t *alloc_buffers, size_t *lock_bytes, size_t *lock_buffers);
20 | af_err af_print_mem_info(const char *msg, const int device_id);
21 | af_err af_device_gc();
22 | af_err af_set_mem_step_size(const size_t step_bytes);
23 | af_err af_get_mem_step_size(size_t *step_bytes);
24 | af_err af_lock_device_ptr(const af_array arr);
25 | af_err af_unlock_device_ptr(const af_array arr);
26 | af_err af_lock_array(const af_array arr);
27 | af_err af_is_locked_array(bool *res, const af_array arr);
28 | af_err af_get_device_ptr(void **ptr, const af_array arr);
29 |
--------------------------------------------------------------------------------
/include/exception.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | void af_get_last_error(char **msg, dim_t *len);
4 | const char *af_err_to_string(const af_err err);
5 |
6 |
--------------------------------------------------------------------------------
/include/features.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_create_features(af_features *feat, dim_t num);
4 | af_err af_retain_features(af_features *out, const af_features feat);
5 | af_err af_get_features_num(dim_t *num, const af_features feat);
6 | af_err af_get_features_xpos(af_array *out, const af_features feat);
7 | af_err af_get_features_ypos(af_array *out, const af_features feat);
8 | af_err af_get_features_score(af_array *score, const af_features feat);
9 | af_err af_get_features_orientation(af_array *orientation, const af_features feat);
10 | af_err af_get_features_size(af_array *size, const af_features feat);
11 | af_err af_release_features(af_features feat);
12 |
--------------------------------------------------------------------------------
/include/graphics.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_create_window(af_window *out, const int width, const int height, const char* const title);
4 | af_err af_set_position(const af_window wind, const unsigned x, const unsigned y);
5 | af_err af_set_title(const af_window wind, const char* const title);
6 | af_err af_set_size(const af_window wind, const unsigned w, const unsigned h);
7 | af_err af_draw_image(const af_window wind, const af_array in, const af_cell* const props);
8 | af_err af_draw_plot(const af_window wind, const af_array X, const af_array Y, const af_cell* const props);
9 | af_err af_draw_plot3(const af_window wind, const af_array P, const af_cell* const props);
10 | af_err af_draw_plot_nd(const af_window wind, const af_array P, const af_cell* const props);
11 | af_err af_draw_plot_2d(const af_window wind, const af_array X, const af_array Y,const af_cell* const props);
12 | af_err af_draw_plot_3d(const af_window wind,const af_array X, const af_array Y, const af_array Z,const af_cell* const props);
13 | af_err af_draw_scatter(const af_window wind, const af_array X, const af_array Y, const af_marker_type marker, const af_cell* const props);
14 | af_err af_draw_scatter3(const af_window wind, const af_array P, const af_marker_type marker, const af_cell* const props);
15 | af_err af_draw_scatter_nd(const af_window wind, const af_array P, const af_marker_type marker, const af_cell* const props);
16 | af_err af_draw_scatter_2d(const af_window wind, const af_array X, const af_array Y,const af_marker_type marker, const af_cell* const props);
17 | af_err af_draw_scatter_3d(const af_window wind, const af_array X, const af_array Y, const af_array Z, const af_marker_type marker, const af_cell* const props);
18 | af_err af_draw_hist(const af_window wind, const af_array X, const double minval, const double maxval, const af_cell* const props);
19 | af_err af_draw_surface(const af_window wind, const af_array xVals, const af_array yVals, const af_array S, const af_cell* const props);
20 | af_err af_draw_vector_field_nd(const af_window wind, const af_array points, const af_array directions, const af_cell* const props);
21 | af_err af_draw_vector_field_3d(const af_window wind, const af_array xPoints, const af_array yPoints, const af_array zPoints, const af_array xDirs, const af_array yDirs, const af_array zDirs, const af_cell* const props);
22 | af_err af_draw_vector_field_2d(const af_window wind,const af_array xPoints, const af_array yPoints,const af_array xDirs, const af_array yDirs,const af_cell* const props);
23 | af_err af_grid(const af_window wind, const int rows, const int cols);
24 | af_err af_set_axes_limits_compute(const af_window wind, const af_array x, const af_array y, const af_array z,const bool exact, const af_cell* const props);
25 | af_err af_set_axes_limits_2d(const af_window wind, const float xmin, const float xmax, const float ymin, const float ymax, const bool exact, const af_cell* const props);
26 | af_err af_set_axes_limits_3d(const af_window wind, const float xmin, const float xmax, const float ymin, const float ymax, const float zmin, const float zmax, const bool exact, const af_cell* const props);
27 | af_err af_set_axes_titles(const af_window wind, const char * const xtitle, const char * const ytitle, const char * const ztitle, const af_cell* const props);
28 | af_err af_show(const af_window wind);
29 | af_err af_is_window_closed(bool *out, const af_window wind);
30 | af_err af_set_visibility(const af_window wind, const bool is_visible);
31 | af_err af_destroy_window(const af_window wind);
32 |
--------------------------------------------------------------------------------
/include/image.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_gradient(af_array *dx, af_array *dy, const af_array in);
4 | af_err af_load_image(af_array *out, const char* filename, const bool isColor);
5 | af_err af_save_image(const char* filename, const af_array in);
6 | af_err af_load_image_memory(af_array *out, const void* ptr);
7 | af_err af_save_image_memory(void** ptr, const af_array in, const af_image_format format);
8 | af_err af_delete_image_memory(void* ptr);
9 | af_err af_load_image_native(af_array *out, const char* filename);
10 | af_err af_save_image_native(const char* filename, const af_array in);
11 | af_err af_is_image_io_available(bool *out);
12 | af_err af_resize(af_array *out, const af_array in, const dim_t odim0, const dim_t odim1, const af_interp_type method);
13 | af_err af_transform(af_array *out, const af_array in, const af_array transform, const dim_t odim0, const dim_t odim1, const af_interp_type method, const bool inverse);
14 | af_err af_transform_coordinates(af_array *out, const af_array tf, const float d0, const float d1);
15 | af_err af_rotate(af_array *out, const af_array in, const float theta, const bool crop, const af_interp_type method);
16 | af_err af_translate(af_array *out, const af_array in, const float trans0, const float trans1, const dim_t odim0, const dim_t odim1, const af_interp_type method);
17 | af_err af_scale(af_array *out, const af_array in, const float scale0, const float scale1, const dim_t odim0, const dim_t odim1, const af_interp_type method);
18 | af_err af_skew(af_array *out, const af_array in, const float skew0, const float skew1, const dim_t odim0, const dim_t odim1, const af_interp_type method, const bool inverse);
19 | af_err af_histogram(af_array *out, const af_array in, const unsigned nbins, const double minval, const double maxval);
20 | af_err af_dilate(af_array *out, const af_array in, const af_array mask);
21 | af_err af_dilate3(af_array *out, const af_array in, const af_array mask);
22 | af_err af_erode(af_array *out, const af_array in, const af_array mask);
23 | af_err af_erode3(af_array *out, const af_array in, const af_array mask);
24 | af_err af_bilateral(af_array *out, const af_array in, const float spatial_sigma, const float chromatic_sigma, const bool isColor);
25 | af_err af_mean_shift(af_array *out, const af_array in, const float spatial_sigma, const float chromatic_sigma, const unsigned iter, const bool is_color);
26 | af_err af_minfilt(af_array *out, const af_array in, const dim_t wind_length, const dim_t wind_width, const af_border_type edge_pad);
27 | af_err af_maxfilt(af_array *out, const af_array in, const dim_t wind_length, const dim_t wind_width, const af_border_type edge_pad);
28 | af_err af_regions(af_array *out, const af_array in, const af_connectivity connectivity, const af_dtype ty);
29 | af_err af_sobel_operator(af_array *dx, af_array *dy, const af_array img, const unsigned ker_size);
30 | af_err af_rgb2gray(af_array* out, const af_array in, const float rPercent, const float gPercent, const float bPercent);
31 | af_err af_gray2rgb(af_array* out, const af_array in, const float rFactor, const float gFactor, const float bFactor);
32 | af_err af_hist_equal(af_array *out, const af_array in, const af_array hist);
33 | af_err af_gaussian_kernel(af_array *out, const int rows, const int cols, const double sigma_r, const double sigma_c);
34 | af_err af_hsv2rgb(af_array* out, const af_array in);
35 | af_err af_rgb2hsv(af_array* out, const af_array in);
36 | af_err af_color_space(af_array *out, const af_array image, const af_cspace_t to, const af_cspace_t from);
37 | af_err af_unwrap(af_array *out, const af_array in, const dim_t wx, const dim_t wy, const dim_t sx, const dim_t sy, const dim_t px, const dim_t py, const bool is_column);
38 | af_err af_wrap(af_array *out, const af_array in, const dim_t ox, const dim_t oy, const dim_t wx, const dim_t wy, const dim_t sx, const dim_t sy, const dim_t px, const dim_t py, const bool is_column);
39 | af_err af_sat(af_array *out, const af_array in);
40 | af_err af_ycbcr2rgb(af_array* out, const af_array in, const af_ycc_std standard);
41 | af_err af_rgb2ycbcr(af_array* out, const af_array in, const af_ycc_std standard);
42 | af_err af_moments(af_array *out, const af_array in, const af_moment_type moment);
43 | af_err af_moments_all(double* out, const af_array in, const af_moment_type moment);
44 | af_err af_canny(af_array* out, const af_array in, const af_canny_threshold threshold_type, const float low_threshold_ratio, const float high_threshold_ratio, const unsigned sobel_window, const bool is_fast);
45 | af_err af_anisotropic_diffusion(af_array* out, const af_array in, const float timestep, const float conductance, const unsigned iterations, const af_flux_function fftype,const af_diffusion_eq diffusion_kind);
46 |
--------------------------------------------------------------------------------
/include/index.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_index( af_array *out, const af_array in, const unsigned ndims, const af_seq* const index);
4 | af_err af_lookup( af_array *out, const af_array in, const af_array indices, const unsigned dim);
5 | af_err af_assign_seq( af_array *out, const af_array lhs, const unsigned ndims, const af_seq* const indices, const af_array rhs);
6 | af_err af_index_gen( af_array *out, const af_array in, const dim_t ndims, const af_index_t* indices);
7 | af_err af_assign_gen( af_array *out, const af_array lhs, const dim_t ndims, const af_index_t* indices, const af_array rhs);
8 | af_err af_create_indexers(af_index_t** indexers);
9 | af_err af_set_array_indexer(af_index_t* indexer, const af_array idx, const dim_t dim);
10 | af_err af_set_seq_indexer(af_index_t* indexer, const af_seq* idx, const dim_t dim, const bool is_batch);
11 | af_err af_set_seq_param_indexer(af_index_t* indexer, const double begin, const double end, const double step, const dim_t dim, const bool is_batch);
12 | af_err af_release_indexers(af_index_t* indexers);
13 |
--------------------------------------------------------------------------------
/include/internal.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_create_strided_array(af_array *arr, const void *data, const dim_t offset, const unsigned ndims, const dim_t *const dims, const dim_t *const strides, const af_dtype ty, const af_source location);
4 | af_err af_get_strides(dim_t *s0, dim_t *s1, dim_t *s2, dim_t *s3, const af_array arr);
5 | af_err af_get_offset(dim_t *offset, const af_array arr);
6 | af_err af_get_raw_ptr(void **ptr, const af_array arr);
7 | af_err af_is_linear(bool *result, const af_array arr);
8 | af_err af_is_owner(bool *result, const af_array arr);
9 | af_err af_get_allocated_bytes(size_t *bytes, const af_array arr);
10 |
--------------------------------------------------------------------------------
/include/lapack.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_svd(af_array *u, af_array *s, af_array *vt, const af_array in);
4 | af_err af_svd_inplace(af_array *u, af_array *s, af_array *vt, af_array in);
5 | af_err af_lu(af_array *lower, af_array *upper, af_array *pivot, const af_array in);
6 | af_err af_lu_inplace(af_array *pivot, af_array in, const bool is_lapack_piv);
7 | af_err af_qr(af_array *q, af_array *r, af_array *tau, const af_array in);
8 | af_err af_qr_inplace(af_array *tau, af_array in);
9 | af_err af_cholesky(af_array *out, int *info, const af_array in, const bool is_upper);
10 | af_err af_cholesky_inplace(int *info, af_array in, const bool is_upper);
11 | af_err af_solve(af_array *x, const af_array a, const af_array b, const af_mat_prop options);
12 | af_err af_solve_lu(af_array *x, const af_array a, const af_array piv, const af_array b, const af_mat_prop options);
13 | af_err af_inverse(af_array *out, const af_array in, const af_mat_prop options);
14 | af_err af_rank(unsigned *rank, const af_array in, const double tol);
15 | af_err af_det(double *det_real, double *det_imag, const af_array in);
16 | af_err af_norm(double *out, const af_array in, const af_norm_type type, const double p, const double q);
17 | af_err af_is_lapack_available(bool *out);
18 |
--------------------------------------------------------------------------------
/include/openCL.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err afcl_get_context(cl_context *ctx, const bool retain);
4 | af_err afcl_get_queue(cl_command_queue *queue, const bool retain);
5 | af_err afcl_get_device_id(cl_device_id *id);
6 | af_err afcl_set_device_id(cl_device_id id);
7 | af_err afcl_add_device_context(cl_device_id dev, cl_context ctx, cl_command_queue que);
8 | af_err afcl_set_device_context(cl_device_id dev, cl_context ctx);
9 | af_err afcl_delete_device_context(cl_device_id dev, cl_context ctx);
10 | af_err afcl_get_device_type(afcl_device_type *res);
11 | af_err afcl_get_platform(afcl_platform *res);
12 |
--------------------------------------------------------------------------------
/include/random.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_create_random_engine(af_random_engine *engine, af_random_engine_type rtype, uintl seed);
4 | af_err af_retain_random_engine(af_random_engine *out, const af_random_engine engine);
5 | af_err af_random_engine_set_type(af_random_engine *engine, const af_random_engine_type rtype);
6 | af_err af_random_engine_get_type(af_random_engine_type *rtype, const af_random_engine engine);
7 | af_err af_random_uniform(af_array *out, const unsigned ndims, const dim_t * const dims, const af_dtype type, af_random_engine engine);
8 | af_err af_random_normal(af_array *out, const unsigned ndims, const dim_t * const dims, const af_dtype type, af_random_engine engine);
9 | af_err af_random_engine_set_seed(af_random_engine *engine, const uintl seed);
10 | af_err af_get_default_random_engine(af_random_engine *engine);
11 | af_err af_set_default_random_engine_type(const af_random_engine_type rtype);
12 | af_err af_random_engine_get_seed(uintl * const seed, af_random_engine engine);
13 | af_err af_release_random_engine(af_random_engine engine);
14 | af_err af_randu(af_array *out, const unsigned ndims, const dim_t * const dims, const af_dtype type);
15 | af_err af_randn(af_array *out, const unsigned ndims, const dim_t * const dims, const af_dtype type);
16 | af_err af_set_seed(const uintl seed);
17 | af_err af_get_seed(uintl *seed);
18 |
--------------------------------------------------------------------------------
/include/seq.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_seq af_make_seq(double begin, double end, double step);
4 |
--------------------------------------------------------------------------------
/include/signal.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_approx1(af_array *out, const af_array in, const af_array pos, const af_interp_type method, const float off_grid);
4 | af_err af_approx2(af_array *out, const af_array in, const af_array pos0, const af_array pos1, const af_interp_type method, const float off_grid);
5 | af_err af_fft(af_array *out, const af_array in, const double norm_factor, const dim_t odim0);
6 | af_err af_fft_inplace(af_array in, const double norm_factor);
7 | af_err af_fft2(af_array *out, const af_array in, const double norm_factor, const dim_t odim0, const dim_t odim1);
8 | af_err af_fft2_inplace(af_array in, const double norm_factor);
9 | af_err af_fft3(af_array *out, const af_array in, const double norm_factor, const dim_t odim0, const dim_t odim1, const dim_t odim2);
10 | af_err af_fft3_inplace(af_array in, const double norm_factor);
11 | af_err af_ifft(af_array *out, const af_array in, const double norm_factor, const dim_t odim0);
12 | af_err af_ifft_inplace(af_array in, const double norm_factor);
13 | af_err af_ifft2(af_array *out, const af_array in, const double norm_factor, const dim_t odim0, const dim_t odim1);
14 | af_err af_ifft2_inplace(af_array in, const double norm_factor);
15 | af_err af_ifft3(af_array *out, const af_array in, const double norm_factor, const dim_t odim0, const dim_t odim1, const dim_t odim2);
16 | af_err af_ifft3_inplace(af_array in, const double norm_factor);
17 | af_err af_fft_r2c (af_array *out, const af_array in, const double norm_factor, const dim_t pad0);
18 | af_err af_fft2_r2c(af_array *out, const af_array in, const double norm_factor, const dim_t pad0, const dim_t pad1);
19 | af_err af_fft3_r2c(af_array *out, const af_array in, const double norm_factor, const dim_t pad0, const dim_t pad1, const dim_t pad2);
20 | af_err af_fft_c2r (af_array *out, const af_array in, const double norm_factor, const bool is_odd);
21 | af_err af_fft2_c2r(af_array *out, const af_array in, const double norm_factor, const bool is_odd);
22 | af_err af_fft3_c2r(af_array *out, const af_array in, const double norm_factor, const bool is_odd);
23 | af_err af_convolve1(af_array *out, const af_array signal, const af_array filter, const af_conv_mode mode, af_conv_domain domain);
24 | af_err af_convolve2(af_array *out, const af_array signal, const af_array filter, const af_conv_mode mode, af_conv_domain domain);
25 | af_err af_convolve3(af_array *out, const af_array signal, const af_array filter, const af_conv_mode mode, af_conv_domain domain);
26 | af_err af_convolve2_sep(af_array *out, const af_array col_filter, const af_array row_filter, const af_array signal, const af_conv_mode mode);
27 | af_err af_fft_convolve1(af_array *out, const af_array signal, const af_array filter, const af_conv_mode mode);
28 | af_err af_fft_convolve2(af_array *out, const af_array signal, const af_array filter, const af_conv_mode mode);
29 | af_err af_fft_convolve3(af_array *out, const af_array signal, const af_array filter, const af_conv_mode mode);
30 | af_err af_fir(af_array *y, const af_array b, const af_array x);
31 | af_err af_iir(af_array *y, const af_array b, const af_array a, const af_array x);
32 | af_err af_medfilt(af_array *out, const af_array in, const dim_t wind_length, const dim_t wind_width, const af_border_type edge_pad);
33 | af_err af_medfilt1(af_array *out, const af_array in, const dim_t wind_width, const af_border_type edge_pad);
34 | af_err af_medfilt2(af_array *out, const af_array in, const dim_t wind_length, const dim_t wind_width, const af_border_type edge_pad);
35 | af_err af_set_fft_plan_cache_size(size_t cache_size);
36 |
--------------------------------------------------------------------------------
/include/sparse.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_create_sparse_array(af_array *out, const dim_t nRows, const dim_t nCols, const af_array values, const af_array rowIdx, const af_array colIdx, const af_storage stype);
4 | af_err af_create_sparse_array_from_ptr(af_array *out, const dim_t nRows, const dim_t nCols, const dim_t nNZ, const void * const values, const int * const rowIdx, const int * const colIdx, const af_dtype type, const af_storage stype, const af_source src);
5 | af_err af_create_sparse_array_from_dense(af_array *out, const af_array dense, const af_storage stype);
6 | af_err af_sparse_convert_to(af_array *out, const af_array in, const af_storage destStorage);
7 | af_err af_sparse_to_dense(af_array *out, const af_array sparse);
8 | af_err af_sparse_get_info(af_array *values, af_array *rowIdx, af_array *colIdx, af_storage *stype, const af_array in);
9 | af_err af_sparse_get_values(af_array *out, const af_array in);
10 | af_err af_sparse_get_row_idx(af_array *out, const af_array in);
11 | af_err af_sparse_get_col_idx(af_array *out, const af_array in);
12 | af_err af_sparse_get_nnz(dim_t *out, const af_array in);
13 | af_err af_sparse_get_storage(af_storage *out, const af_array in);
14 |
--------------------------------------------------------------------------------
/include/statistics.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_mean(af_array *out, const af_array in, const dim_t dim);
4 | af_err af_mean_weighted(af_array *out, const af_array in, const af_array weights, const dim_t dim);
5 | af_err af_var(af_array *out, const af_array in, const bool isbiased, const dim_t dim);
6 | af_err af_var_weighted(af_array *out, const af_array in, const af_array weights, const dim_t dim);
7 | af_err af_stdev(af_array *out, const af_array in, const dim_t dim);
8 | af_err af_cov(af_array* out, const af_array X, const af_array Y, const bool isbiased);
9 | af_err af_median(af_array* out, const af_array in, const dim_t dim);
10 | af_err af_mean_all(double *real, double *imag, const af_array in);
11 | af_err af_mean_all_weighted(double *real, double *imag, const af_array in, const af_array weights);
12 | af_err af_var_all(double *realVal, double *imagVal, const af_array in, const bool isbiased);
13 | af_err af_var_all_weighted(double *realVal, double *imagVal, const af_array in, const af_array weights);
14 | af_err af_stdev_all(double *real, double *imag, const af_array in);
15 | af_err af_median_all(double *realVal, double *imagVal, const af_array in);
16 | af_err af_corrcoef(double *realVal, double *imagVal, const af_array X, const af_array Y);
17 | af_err af_topk(af_array *values, af_array *indices, const af_array in, const int k, const int dim, const af_topk_function order);
18 |
--------------------------------------------------------------------------------
/include/util.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_print_array(af_array arr);
4 | af_err af_print_array_gen(const char *exp, const af_array arr, const int precision);
5 | af_err af_save_array(int *index, const char* key, const af_array arr, const char *filename, const bool append);
6 | af_err af_read_array_index(af_array *out, const char *filename, const unsigned index);
7 | af_err af_read_array_key(af_array *out, const char *filename, const char* key);
8 | af_err af_read_array_key_check(int *index, const char *filename, const char* key);
9 | af_err af_array_to_string(char **output, const char *exp, const af_array arr, const int precision, const bool transpose);
10 | af_err af_example_function(af_array* out, const af_array in, const af_someenum_t param);
11 | af_err af_get_version(int *major, int *minor, int *patch);
12 | const char *af_get_revision();
13 | af_err af_get_size_of(size_t *size, af_dtype type);
14 |
--------------------------------------------------------------------------------
/include/vision.h:
--------------------------------------------------------------------------------
1 | #include "defines.h"
2 |
3 | af_err af_fast(af_features *out, const af_array in, const float thr, const unsigned arc_length, const bool non_max, const float feature_ratio, const unsigned edge);
4 | af_err af_harris(af_features *out, const af_array in, const unsigned max_corners, const float min_response, const float sigma, const unsigned block_size, const float k_thr);
5 | af_err af_orb(af_features *feat, af_array *desc, const af_array in, const float fast_thr, const unsigned max_feat, const float scl_fctr, const unsigned levels, const bool blur_img);
6 | af_err af_sift(af_features *feat, af_array *desc, const af_array in, const unsigned n_layers, const float contrast_thr, const float edge_thr, const float init_sigma, const bool double_input, const float intensity_scale, const float feature_ratio);
7 | af_err af_gloh(af_features *feat, af_array *desc, const af_array in, const unsigned n_layers, const float contrast_thr, const float edge_thr, const float init_sigma, const bool double_input, const float intensity_scale, const float feature_ratio);
8 | af_err af_hamming_matcher(af_array* idx, af_array* dist, const af_array query, const af_array train, const dim_t dist_dim, const unsigned n_dist);
9 | af_err af_nearest_neighbour(af_array* idx, af_array* dist, const af_array query, const af_array train, const dim_t dist_dim, const unsigned n_dist, const af_match_type dist_type);
10 | af_err af_match_template(af_array *out, const af_array search_img, const af_array template_img, const af_match_type m_type);
11 | af_err af_susan(af_features* out, const af_array in, const unsigned radius, const float diff_thr, const float geom_thr, const float feature_ratio, const unsigned edge);
12 | af_err af_dog(af_array *out, const af_array in, const int radius1, const int radius2);
13 | af_err af_homography(af_array *H, int *inliers, const af_array x_src, const af_array y_src, const af_array x_dst, const af_array y_dst, const af_homography_type htype, const float inlier_thr, const unsigned iterations, const af_dtype otype);
14 |
--------------------------------------------------------------------------------
/nix/default.nix:
--------------------------------------------------------------------------------
1 | { stdenv, fetchurl, fetchFromGitHub, cmake, pkgconfig
2 | , cudatoolkit, opencl-clhpp, ocl-icd, fftw, fftwFloat, mkl
3 | , blas, openblas, boost, mesa_noglu, libGLU_combined
4 | , freeimage, python, lib
5 | }:
6 | # Builds the ArrayFire library from the upstream "full" source tarball.
7 | let
8 |   version = "3.6.4";
9 |   # Pinned clFFT sources; postPatch copies them into the build tree.
10 |   clfftSource = fetchFromGitHub {
11 |     owner = "arrayfire";
12 |     repo = "clFFT";
13 |     rev = "16925fb93338b3cac66490b5cf764953d6a5dac7";
14 |     sha256 = "0y35nrdz7w4n1l17myhkni3hwm37z775xn6f76xmf1ph7dbkslsc";
15 |     fetchSubmodules = true;
16 |   };
17 |   # Pinned clBLAS sources; postPatch copies them into the build tree.
18 |   clblasSource = fetchFromGitHub {
19 |     owner = "arrayfire";
20 |     repo = "clBLAS";
21 |     rev = "1f3de2ae5582972f665c685b18ef0df43c1792bb";
22 |     sha256 = "154mz52r5hm0jrp5fqrirzzbki14c1jkacj75flplnykbl36ibjs";
23 |     fetchSubmodules = true;
24 |   };
25 |   # cl2.hpp header from the OpenCL-CLHPP release; postPatch copies it to build/include/CL.
26 |   cl2hppSource = fetchurl {
27 |     url = "https://github.com/KhronosGroup/OpenCL-CLHPP/releases/download/v2.0.10/cl2.hpp";
28 |     sha256 = "1v4q0g6b6mwwsi0kn7kbjn749j3qafb9r4ld3zdq1163ln9cwnvw";
29 |   };
30 |
31 | in stdenv.mkDerivation {
32 |   pname = "arrayfire";
33 |   inherit version;
34 |
35 |   src = fetchurl {
36 |     url = "http://arrayfire.com/arrayfire_source/arrayfire-full-${version}.tar.bz2";
37 |     sha256 = "1fin7a9rliyqic3z83agkpb8zlq663q6gdxsnm156cs8s7f7rc9h";
38 |   };
39 |   # lib.optionals (plural) appends the Linux-only flag as an element instead of nesting a list.
40 |   cmakeFlags = [
41 |     "-DAF_BUILD_OPENCL=OFF"
42 |     "-DAF_BUILD_EXAMPLES=OFF"
43 |     "-DBUILD_TESTING=OFF"
44 |   ] ++ lib.optionals stdenv.isLinux [ "-DCMAKE_LIBRARY_PATH=${cudatoolkit}/lib/stubs" ];
45 |   # Replaces the git download of clBLAS/clFFT with a no-op DOWNLOAD_COMMAND.
46 |   patches = [ ./no-download.patch ];
47 |   # Stage the pre-fetched sources; presumably these are the paths the patched build looks in.
48 |   postPatch = ''
49 |     mkdir -p ./build/third_party/clFFT/src
50 |     cp -R --no-preserve=mode,ownership ${clfftSource}/ ./build/third_party/clFFT/src/clFFT-ext/
51 |     mkdir -p ./build/third_party/clBLAS/src
52 |     cp -R --no-preserve=mode,ownership ${clblasSource}/ ./build/third_party/clBLAS/src/clBLAS-ext/
53 |     mkdir -p ./build/include/CL
54 |     cp -R --no-preserve=mode,ownership ${cl2hppSource} ./build/include/CL/cl2.hpp
55 |   '';
56 |   # CUDA_PATH is exported on Linux only.
57 |   preBuild = lib.optionalString stdenv.isLinux ''
58 |     export CUDA_PATH="${cudatoolkit}"
59 |   '';
60 |
61 |   enableParallelBuilding = true;
62 |
63 |   buildInputs = [
64 |     cmake pkgconfig
65 |     opencl-clhpp fftw fftwFloat
66 |     mkl openblas
67 |     libGLU_combined
68 |     mesa_noglu freeimage
69 |     boost.out boost.dev python
70 |   ] ++ lib.optionals stdenv.isLinux [ cudatoolkit ocl-icd ];
71 |
72 |   meta = with lib; {
73 |     description = "A general-purpose library that simplifies the process of developing software that targets parallel and massively-parallel architectures including CPUs, GPUs, and other hardware acceleration devices";
74 |     license = licenses.bsd3;
75 |     homepage = "https://arrayfire.com/";
76 |     maintainers = with maintainers; [ chessai ];
77 |     inherit version;
78 |   };
79 | }
80 |
--------------------------------------------------------------------------------
/nix/no-download.patch:
--------------------------------------------------------------------------------
1 | diff --git a/CMakeModules/build_clBLAS.cmake b/CMakeModules/build_clBLAS.cmake
2 | index 8de529e8..6361b613 100644
3 | --- a/CMakeModules/build_clBLAS.cmake
4 | +++ b/CMakeModules/build_clBLAS.cmake
5 | @@ -14,8 +14,7 @@ find_package(OpenCL)
6 |
7 | ExternalProject_Add(
8 | clBLAS-ext
9 | - GIT_REPOSITORY https://github.com/arrayfire/clBLAS.git
10 | - GIT_TAG arrayfire-release
11 | + DOWNLOAD_COMMAND true
12 | BUILD_BYPRODUCTS ${clBLAS_location}
13 | PREFIX "${prefix}"
14 | INSTALL_DIR "${prefix}"
15 | diff --git a/CMakeModules/build_clFFT.cmake b/CMakeModules/build_clFFT.cmake
16 | index 28be38a3..85e3915e 100644
17 | --- a/CMakeModules/build_clFFT.cmake
18 | +++ b/CMakeModules/build_clFFT.cmake
19 | @@ -20,8 +20,7 @@ ENDIF()
20 |
21 | ExternalProject_Add(
22 | clFFT-ext
23 | - GIT_REPOSITORY https://github.com/arrayfire/clFFT.git
24 | - GIT_TAG arrayfire-release
25 | + DOWNLOAD_COMMAND true
26 | PREFIX "${prefix}"
27 | INSTALL_DIR "${prefix}"
28 | UPDATE_COMMAND ""
29 |
--------------------------------------------------------------------------------
/pkg.nix:
--------------------------------------------------------------------------------
1 | { mkDerivation, base, directory, parsec, stdenv, text, vector
2 | , hspec, hspec-discover
3 | }:
4 | mkDerivation { # package description for the arrayfire Haskell bindings
5 |   pname = "arrayfire";
6 |   version = "0.6.0.0"; # every attribute binding must be terminated by ';'
7 |   src = ./.;
8 |   isLibrary = true;
9 |   isExecutable = true;
10 |   libraryHaskellDepends = [ base vector ];
11 |   executableHaskellDepends = [ base directory parsec text ];
12 |   testHaskellDepends = [ hspec hspec-discover ];
13 |   homepage = "https://github.com/arrayfire/arrayfire-haskell";
14 |   description = "Haskell bindings to ArrayFire";
15 |   license = stdenv.lib.licenses.bsd3;
16 | }
17 |
--------------------------------------------------------------------------------
/shell.nix:
--------------------------------------------------------------------------------
1 | { pkgs ? import <nixpkgs> {} }: # defaults to the nixpkgs found on the Nix search path
2 | let
3 | pkg = (import ./default.nix {}).env; # the `env` attribute of the project's default.nix
4 | in
5 | pkgs.lib.overrideDerivation pkg (drv: { # overrides shellHook to define helper shell functions
6 | shellHook = ''
7 | export AF_PRINT_ERRORS=1
8 | export PATH=$PATH:${pkgs.haskellPackages.doctest}/bin
9 | export PATH=$PATH:${pkgs.haskellPackages.cabal-install}/bin
10 | function ghcid () {
11 | ${pkgs.haskellPackages.ghcid.bin}/bin/ghcid -c 'cabal v1-repl lib:arrayfire'
12 | };
13 | function test-runner () {
14 | ${pkgs.silver-searcher}/bin/ag -l | \
15 | ${pkgs.entr}/bin/entr sh -c \
16 | 'cabal v1-configure --enable-tests && \
17 | cabal v1-build test && dist/build/test/test'
18 | }
19 | function doctest-runner () {
20 | ${pkgs.silver-searcher}/bin/ag -l | \
21 | ${pkgs.entr}/bin/entr sh -c \
22 | 'cabal v1-configure --enable-tests && \
23 | cabal v1-build doctests && dist/build/doctests/doctests src/ArrayFire/Algorithm.hs'
24 | }
25 | function exe () {
26 | cabal run main
27 | }
28 | function repl () {
29 | cabal v1-repl lib:arrayfire
30 | }
31 | function docs () {
32 | cabal haddock
33 | open ./dist-newstyle/*/*/*/*/doc/html/arrayfire/index.html
34 | }
35 | function upload-docs () {
36 | cabal haddock --haddock-for-hackage
37 | cabal upload -d dist-newstyle/arrayfire-*.*.*.*-docs.tar.gz --publish
38 | }
39 | '';
40 | })
41 |
--------------------------------------------------------------------------------
/src/ArrayFire.hs:
--------------------------------------------------------------------------------
1 | --------------------------------------------------------------------------------
2 | -- |
3 | -- Module : ArrayFire
4 | -- Copyright : David Johnson (c) 2019-2020
5 | -- License : BSD3
6 | -- Maintainer : David Johnson
7 | -- Stability : Experimental
8 | -- Portability : GHC
9 | --
10 | -- <>
11 | --
12 | --------------------------------------------------------------------------------
13 | module ArrayFire
14 | ( -- * Tutorial
15 | -- $tutorial
16 |
17 | -- ** Modules
18 | -- $modules
19 |
20 | -- ** Exceptions
21 | -- $exceptions
22 |
23 | -- ** Construction
24 | -- $construction
25 |
26 | -- ** Laws
27 | -- $laws
28 |
29 | -- ** Conversion
30 | -- $conversion
31 |
32 | -- ** Serialization
33 | -- $serialization
34 |
35 | -- ** Device
36 | -- $device
37 | module ArrayFire.Algorithm
38 | , module ArrayFire.Arith
39 | , module ArrayFire.Array
40 | , module ArrayFire.Backend
41 | , module ArrayFire.BLAS
42 | , module ArrayFire.Data
43 | , module ArrayFire.Device
44 | , module ArrayFire.Features
45 | , module ArrayFire.Graphics
46 | , module ArrayFire.Image
47 | , module ArrayFire.Index
48 | , module ArrayFire.LAPACK
49 | , module ArrayFire.Random
50 | , module ArrayFire.Signal
51 | , module ArrayFire.Sparse
52 | , module ArrayFire.Statistics
53 | , module ArrayFire.Types
54 | , module ArrayFire.Util
55 | , module ArrayFire.Vision
56 | , module Foreign.C.Types
57 | , module Data.Int
58 | , module Data.Word
59 | , module Data.Complex
60 | , module Foreign.Storable
61 | ) where
62 |
63 | import ArrayFire.Algorithm
64 | import ArrayFire.Arith
65 | import ArrayFire.Array
66 | import ArrayFire.Backend
67 | import ArrayFire.BLAS
68 | import ArrayFire.Data
69 | import ArrayFire.Device
70 | import ArrayFire.Features
71 | import ArrayFire.Graphics
72 | import ArrayFire.Image
73 | import ArrayFire.Index
74 | import ArrayFire.LAPACK
75 | import ArrayFire.Random
76 | import ArrayFire.Signal
77 | import ArrayFire.Sparse
78 | import ArrayFire.Statistics
79 | import ArrayFire.Types
80 | import ArrayFire.Util
81 | import ArrayFire.Vision
82 | import ArrayFire.Orphans ()
83 | import Foreign.Storable
84 | import Foreign.C.Types
85 | import Data.Int
86 | import Data.Complex
87 | import Data.Word
88 |
89 | -- $tutorial
90 | --
91 | -- [ArrayFire](http://arrayfire.org/docs/gettingstarted.htm) is a high performance parallel computing library that features modules for statistical and numerical methods.
92 | -- Example usage is depicted below.
93 | --
94 | -- @
95 | -- module Main where
96 | --
97 | -- import qualified ArrayFire as A
98 | --
99 | -- main :: IO ()
100 | -- main = print $ A.matrix @Double (3,2) [[1,2,3],[4,5,6]]
101 | -- @
102 | --
103 | -- Each 'Array' is constructed and displayed in column-major order.
104 | --
105 | -- @
106 | -- ArrayFire Array
107 | -- [3 2 1 1]
108 | -- 1.0000 4.0000
109 | -- 2.0000 5.0000
110 | -- 3.0000 6.0000
111 | -- @
112 |
113 | -- $modules
114 | --
115 | -- All child modules are re-exported top-level in the "ArrayFire" module.
116 | -- We recommend importing "ArrayFire" qualified so as to avoid naming collisions.
117 | --
118 | -- >>> import qualified ArrayFire as A
119 | --
120 |
121 | -- $exceptions
122 | --
123 | -- @
124 | -- {\-\# LANGUAGE TypeApplications \#\-}
125 | -- module Main where
126 | --
127 | -- import qualified ArrayFire as A
128 | -- import Control.Exception ( catch )
129 | --
130 | -- main :: IO ()
131 | -- main = A.printArray action \`catch\` (\\(e :: A.AFException) -> print e)
132 | -- where
133 | -- action =
134 | -- A.matrix \@Double (3,3) [[1..],[1..],[1..]]
135 | -- \`A.mul\` A.matrix \@Double (2,2) [[1..],[1..]]
136 | -- @
137 | --
138 | -- The above operation is invalid since the matrix multiply has improper dimensions. The caught exception produces the following error:
139 | --
140 | -- > AFException {afExceptionType = SizeError, afExceptionCode = 203, afExceptionMsg = "Invalid input size"}
141 | --
142 |
143 | -- $construction
144 | -- An 'Array' can be constructed using the following smart constructors:
145 | --
146 | -- /Note/: All smart constructors (and ArrayFire internally) assume column-major order.
147 | --
148 | -- @
149 | -- >>> scalar \@Double 2.0
150 | -- ArrayFire Array
151 | -- [1 1 1 1]
152 | -- 2.0000
153 | -- @
154 | --
155 | -- @
156 | -- >>> vector \@Double 10 [1..]
157 | -- ArrayFire Array
158 | -- [10 1 1 1]
159 | -- 1.0000
160 | -- 2.0000
161 | -- 3.0000
162 | -- 4.0000
163 | -- 5.0000
164 | -- 6.0000
165 | -- 7.0000
166 | -- 8.0000
167 | -- 9.0000
168 | -- 10.0000
169 | -- @
170 | --
171 | -- @
172 | -- >>> matrix \@Double (2,2) [[1,2],[3,4]]
173 | -- ArrayFire Array
174 | -- [2 2 1 1]
175 | -- 1.0000 3.0000
176 | -- 2.0000 4.0000
177 | -- @
178 | --
179 | -- @
180 | -- >>> cube \@Double (2,2,2) [[[2,2],[2,2]],[[2,2],[2,2]]]
181 | -- ArrayFire Array
182 | -- [2 2 2 1]
183 | -- 2.0000 2.0000
184 | -- 2.0000 2.0000
185 | --
186 | -- 2.0000 2.0000
187 | -- 2.0000 2.0000
188 | -- @
189 | --
190 | -- @
191 | -- >>> tensor \@Double (2,2,2,2) [[[[2,2],[2,2]],[[2,2],[2,2]]], [[[2,2],[2,2]],[[2,2],[2,2]]]]
192 | -- ArrayFire Array
193 | -- [2 2 2 2]
194 | -- 2.0000 2.0000
195 | -- 2.0000 2.0000
196 | --
197 | -- 2.0000 2.0000
198 | -- 2.0000 2.0000
199 | --
200 | --
201 | -- 2.0000 2.0000
202 | -- 2.0000 2.0000
203 | --
204 | -- 2.0000 2.0000
205 | -- 2.0000 2.0000
206 | -- @
207 | --
208 | -- Array construction can use Haskell's lazy lists, since 'take' is called on each dimension before sending to the C API.
209 | --
210 | -- >>> mkArray @Double [5,3] [1..]
211 | -- ArrayFire Array
212 | -- [5 3 1 1]
213 | -- 1.0000 6.0000 11.0000
214 | -- 2.0000 7.0000 12.0000
215 | -- 3.0000 8.0000 13.0000
216 | -- 4.0000 9.0000 14.0000
217 | -- 5.0000 10.0000 15.0000
218 | --
219 | -- Specifying up to 4 dimensions is allowed (anything higher is ignored).
220 |
221 | -- $laws
222 | -- Every 'Array' has an instance of 'Eq', 'Num', 'Fractional', 'Floating' and 'Show'
223 | --
224 | -- 'Num'
225 | --
226 | -- >>> 2.0 :: Array Double
227 | -- ArrayFire Array
228 | -- [1 1 1 1]
229 | -- 2.0000
230 | --
231 | -- >>> scalar @Int 1 + scalar @Int 1
232 | -- ArrayFire Array
233 | -- [1 1 1 1]
234 | -- 2
235 | --
236 | -- >>> scalar @Int 1 - scalar @Int 1
237 | -- ArrayFire Array
238 | -- [1 1 1 1]
239 | -- 0
240 | --
241 | -- >>> scalar @Double 10 / scalar @Double 10
242 | -- ArrayFire Array
243 | -- [1 1 1 1]
244 | -- 1.0000
245 | --
246 | -- >>> abs $ scalar @Double (-10)
247 | -- ArrayFire Array
248 | -- [1 1 1 1]
249 | -- 10.0000
250 | --
251 | -- >>> negate (scalar @Double 10)
252 | -- ArrayFire Array
253 | -- [1 1 1 1]
254 | -- -10.0000
255 | --
256 | -- >>> fromInteger 1 :: Array Double
257 | -- ArrayFire Array
258 | -- [1 1 1 1]
259 | -- 1.0000
260 | --
261 | -- 'Eq'
262 | --
263 | -- >>> scalar @Double 1 == scalar @Double 1
264 | -- True
265 | -- >>> scalar @Double 1 /= scalar @Double 1
266 | -- False
267 | --
268 | -- 'Floating'
269 | --
270 | -- >>> pi :: Array Double
271 | -- ArrayFire Array
272 | -- [1 1 1 1]
273 | -- 3.1416
274 | --
275 | -- >>> A.sqrt pi :: Array Double
276 | -- ArrayFire Array
277 | -- [1 1 1 1]
278 | -- 1.7725
279 | --
280 | -- 'Fractional'
281 | --
282 | -- >>> (pi :: Array Double) / pi
283 | -- ArrayFire Array
284 | -- [1 1 1 1]
285 | -- 1.0000
286 | --
287 | -- >>> recip 0.5 :: Array Double
288 | -- ArrayFire Array
289 | -- [1 1 1 1]
290 | -- 2.0000
291 | --
292 | -- 'Show'
293 | --
294 | -- >>> 0.0 :: Array Double
295 | -- ArrayFire Array
296 | -- [1 1 1 1]
297 | -- 0.0000
298 | --
299 |
300 | -- $conversion
301 | -- Any 'Array' can be exported into Haskell using 'toVector'. This will create a Storable vector suitable for use in other C programs.
302 | --
303 | -- >>> vector :: Vector Double <- toVector <$> randu @Double [10,10]
304 | --
305 |
306 | -- $serialization
307 | -- Each 'Array' can be serialized to disk and deserialized from disk efficiently.
308 | --
309 | -- @
310 | -- import qualified ArrayFire as A
311 | -- import Control.Monad
312 | --
313 | -- main :: IO ()
314 | -- main = do
315 | -- let arr = A.'constant' [1,1,1,1] 10
316 | -- idx <- A.'saveArray' "key" arr "file.array" False
317 | -- foundIndex <- A.'readArrayKeyCheck' "file.array" "key"
318 | -- when (idx == foundIndex) $ do
319 | -- array <- A.'readArrayKey' "file.array" "key"
320 | -- 'print' array
321 | --
322 | -- -- ArrayFire Array
323 | -- -- [ 1 1 1 1 ]
324 | -- -- 10
325 | -- @
326 | --
327 |
328 | -- $device
329 | -- The ArrayFire API is able to see which devices are present, and will by default use the GPU if available.
330 | --
331 | -- >>> afInfo
332 | -- ArrayFire v3.6.4 (OpenCL, 64-bit Mac OSX, build 1b8030c5)
333 | -- [0] APPLE: AMD Radeon Pro 555X Compute Engine, 4096 MB <-- brackets [] signify device being used.
334 | -- -1- APPLE: Intel(R) UHD Graphics 630, 1536 MB
335 | --
336 |
337 | -- $visualization
338 | -- The ArrayFire API is able to display visualizations using the Forge library
339 | -- >>> window <- createWindow 800 600 "Histogram"
340 | --
341 |
--------------------------------------------------------------------------------
/src/ArrayFire/BLAS.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE ViewPatterns #-}
2 | --------------------------------------------------------------------------------
3 | -- |
4 | -- Module : ArrayFire.BLAS
5 | -- Copyright : David Johnson (c) 2019-2020
6 | -- License : BSD3
7 | -- Maintainer : David Johnson
8 | -- Stability : Experimental
9 | -- Portability : GHC
10 | --
11 | -- Basic Linear Algebra Subprograms (BLAS) API
12 | --
13 | -- @
14 | -- main :: IO ()
15 | -- main = print (matmul x y xProp yProp)
16 | -- where
17 | -- x,y :: Array Double
18 | -- x = matrix (2,3) [[1,2],[3,4],[5,6]]
19 | -- y = matrix (3,2) [[1,2,3],[4,5,6]]
20 | --
21 | -- xProp, yProp :: MatProp
22 | -- xProp = None
23 | -- yProp = None
24 | -- @
25 | -- @
26 | -- ArrayFire Array
27 | -- [2 2 1 1]
28 | -- 22.0000 49.0000
29 | -- 28.0000 64.0000
30 | -- @
31 | --------------------------------------------------------------------------------
32 | module ArrayFire.BLAS where
33 |
34 | import Data.Complex
35 |
36 | import ArrayFire.FFI
37 | import ArrayFire.Internal.BLAS
38 | import ArrayFire.Internal.Types
39 |
40 | -- | The following applies for Sparse-Dense matrix multiplication.
41 | --
42 | -- This function can be used with one sparse input. The sparse input must always be the lhs and the dense matrix must be rhs.
43 | --
44 | -- The sparse array can only be of 'CSR' format.
45 | --
46 | -- The returned array is always dense.
47 | --
48 | -- optLhs can only be one of AF_MAT_NONE, AF_MAT_TRANS, AF_MAT_CTRANS.
49 | --
50 | -- optRhs can only be AF_MAT_NONE.
51 | --
52 | -- >>> matmul (matrix @Double (2,2) [[1,2],[3,4]]) (matrix @Double (2,2) [[1,2],[3,4]]) None None
53 | -- ArrayFire Array
54 | -- [2 2 1 1]
55 | -- 7.0000 15.0000
56 | -- 10.0000 22.0000
57 | matmul
58 |   :: AFType a
59 |   => Array a
60 |   -- ^ Left-hand side 2D matrix
61 |   -> Array a
62 |   -- ^ Right-hand side 2D matrix
63 |   -> MatProp
64 |   -- ^ 'MatProp' option for the left-hand side
65 |   -> MatProp
66 |   -- ^ 'MatProp' option for the right-hand side
67 |   -> Array a
68 |   -- ^ Result of 'matmul'
69 | matmul lhs rhs lhsProp rhsProp =
70 |   op2 lhs rhs $ \o x y -> af_matmul o x y (toMatProp lhsProp) (toMatProp rhsProp)
71 |
72 | -- | Scalar dot product between two vectors. Also referred to as the inner product.
73 | --
74 | -- >>> dot (vector @Double 10 [1..]) (vector @Double 10 [1..]) None None
75 | -- ArrayFire Array
76 | -- [1 1 1 1]
77 | -- 385.0000
78 | dot
79 |   :: AFType a
80 |   => Array a
81 |   -- ^ Left-hand side vector
82 |   -> Array a
83 |   -- ^ Right-hand side vector
84 |   -> MatProp
85 |   -- ^ Left-hand side option. Currently only AF_MAT_NONE and AF_MAT_CONJ are supported.
86 |   -> MatProp
87 |   -- ^ Right-hand side option. Currently only AF_MAT_NONE and AF_MAT_CONJ are supported.
88 |   -> Array a
89 |   -- ^ Result of 'dot'
90 | dot lhs rhs lhsProp rhsProp =
91 |   op2 lhs rhs $ \o x y -> af_dot o x y (toMatProp lhsProp) (toMatProp rhsProp)
92 |
93 | -- | Scalar dot product between two vectors. Also referred to as the inner product. Returns the result as a host scalar.
94 | --
95 | -- >>> dotAll (vector @Double 10 [1..]) (vector @Double 10 [1..]) None None
96 | -- 385.0 :+ 0.0
97 | dotAll
98 |   :: AFType a
99 |   => Array a
100 |   -- ^ Left-hand side vector
101 |   -> Array a
102 |   -- ^ Right-hand side vector
103 |   -> MatProp
104 |   -- ^ Left-hand side option. Currently only AF_MAT_NONE and AF_MAT_CONJ are supported.
105 |   -> MatProp
106 |   -- ^ Right-hand side option. Currently only AF_MAT_NONE and AF_MAT_CONJ are supported.
107 |   -> Complex Double
108 |   -- ^ Real part ':+' imaginary part of the result
109 | dotAll lhs rhs lhsProp rhsProp = re :+ im
110 |   where
111 |     (re, im) =
112 |       infoFromArray22 lhs rhs $ \p q x y ->
113 |         af_dot_all p q x y (toMatProp lhsProp) (toMatProp rhsProp)
114 |
115 | -- | Transposes a matrix.
116 | --
117 | -- >>> array = matrix @Double (2,3) [[2,3],[3,4],[5,6]]
118 | -- >>> array
119 | -- ArrayFire Array
120 | -- [2 3 1 1]
121 | -- 2.0000 3.0000 5.0000
122 | -- 3.0000 4.0000 6.0000
123 | --
124 | -- >>> transpose array True
125 | -- ArrayFire Array
126 | -- [3 2 1 1]
127 | -- 2.0000 3.0000
128 | -- 3.0000 4.0000
129 | -- 5.0000 6.0000
130 | --
131 | transpose
132 |   :: AFType a
133 |   => Array a
134 |   -- ^ Matrix to transpose
135 |   -> Bool
136 |   -- ^ Whether to perform a conjugate transposition
137 |   -> Array a
138 |   -- ^ Transposed matrix
139 | transpose input conjugate =
140 |   op1 input $ \o x -> af_transpose o x (fromIntegral (fromEnum conjugate))
141 |
142 | -- | Transposes a matrix.
143 | --
144 | -- * Warning: This function mutates an array in-place, all subsequent references will be changed. Use carefully.
145 | --
146 | -- >>> array = matrix @Double (2,2) [[1,2],[3,4]]
147 | -- >>> array
148 | -- ArrayFire Array
149 | -- [2 2 1 1]
150 | -- 1.0000 3.0000
151 | -- 2.0000 4.0000
152 | --
153 | --
154 | -- >>> transposeInPlace array False
155 | -- >>> array
156 | -- ArrayFire Array
157 | -- [2 2 1 1]
158 | -- 1.0000 2.0000
159 | -- 3.0000 4.0000
160 | --
161 | transposeInPlace
162 |   :: AFType a
163 |   => Array a
164 |   -- ^ Matrix to transpose in place
165 |   -> Bool
166 |   -- ^ Whether to perform a conjugate transposition
167 |   -> IO ()
168 | transposeInPlace input conjugate =
169 |   inPlace input $ \x -> af_transpose_inplace x (fromIntegral (fromEnum conjugate))
170 |
--------------------------------------------------------------------------------
/src/ArrayFire/Backend.hs:
--------------------------------------------------------------------------------
1 | --------------------------------------------------------------------------------
2 | -- |
3 | -- Module : ArrayFire.Backend
4 | -- Copyright : David Johnson (c) 2019-2020
5 | -- License : BSD 3
6 | -- Maintainer : David Johnson
7 | -- Stability : Experimental
8 | -- Portability : GHC
9 | --
10 | -- Set and get available ArrayFire 'Backend's.
11 | --
12 | -- @
13 | -- module Main where
14 | --
15 | -- import ArrayFire
16 | --
17 | -- main :: IO ()
18 | -- main = print =<< getAvailableBackends
19 | -- @
20 | --
21 | -- @
22 | -- [CPU,OpenCL]
23 | -- @
24 | --------------------------------------------------------------------------------
25 | module ArrayFire.Backend where
26 |
27 | import ArrayFire.FFI
28 | import ArrayFire.Internal.Backend
29 | import ArrayFire.Internal.Types
30 |
-- | Select the 'Backend' to use.
--
-- The 'Backend' is converted with 'toAFBackend' and handed to
-- @af_set_backend@ through 'afCall'.
--
-- >>> setBackend OpenCL
-- ()
setBackend
  :: Backend
  -- ^ 'Backend' to use for 'Array' construction
  -> IO ()
setBackend backend = afCall (af_set_backend (toAFBackend backend))
40 |
-- | Number of backends available, as reported by @af_get_backend_count@.
--
-- >>> getBackendCount
-- 2
--
getBackendCount :: IO Int
getBackendCount = fmap fromIntegral (afCall1 af_get_backend_count)
50 |
-- | List the available 'Backend's.
--
-- The raw value returned by @af_get_available_backends@ is decoded into a
-- list with 'toBackends'.
--
-- >>> mapM_ print =<< getAvailableBackends
-- CPU
-- OpenCL
getAvailableBackends :: IO [Backend]
getAvailableBackends = do
  raw <- afCall1 af_get_available_backends
  pure (toBackends (fromIntegral raw))
60 |
-- | The 'Backend' a given 'Array' was created from.
--
-- Queries @af_get_backend_id@ for the array and converts the answer with
-- 'toBackend'.
--
-- >>> getBackend (scalar @Double 2.0)
-- OpenCL
getBackend :: Array a -> Backend
getBackend arr = toBackend (infoFromArray arr af_get_backend_id)
67 |
-- | The currently active 'Backend', as reported by @af_get_active_backend@.
--
-- >>> getActiveBackend
-- OpenCL
getActiveBackend :: IO Backend
getActiveBackend = fmap toBackend (afCall1 af_get_active_backend)
74 |
-- | Identifier of the device a given 'Array' was created on, obtained from
-- @af_get_device_id@.
--
-- >>> getDeviceID (scalar @Double 2.0)
-- 1
getDeviceID :: Array a -> Int
getDeviceID arr = fromIntegral (infoFromArray arr af_get_device_id)
81 |
--------------------------------------------------------------------------------
/src/ArrayFire/Device.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE ViewPatterns #-}
2 | --------------------------------------------------------------------------------
3 | -- |
4 | -- Module : ArrayFire.Device
5 | -- Copyright : David Johnson (c) 2019-2020
6 | -- License : BSD3
7 | -- Maintainer : David Johnson
8 | -- Stability : Experimental
9 | -- Portability : GHC
10 | --
11 | -- Information about ArrayFire API and devices
12 | --
13 | -- >>> info
14 | -- ArrayFire v3.6.4 (OpenCL, 64-bit Mac OSX, build 1b8030c5)
15 | -- [0] APPLE: AMD Radeon Pro 555X Compute Engine, 4096 MB
16 | -- -1- APPLE: Intel(R) UHD Graphics 630, 1536 MB
17 | --
18 | --------------------------------------------------------------------------------
19 | module ArrayFire.Device where
20 |
21 | import Foreign.C.String
22 | import ArrayFire.Internal.Device
23 | import ArrayFire.FFI
24 |
-- | Runs the @af_info@ C function via 'afCall'.
--
-- Example of the text it reports:
--
-- @
-- ArrayFire v3.6.4 (OpenCL, 64-bit Mac OSX, build 1b8030c5)
-- [0] APPLE: AMD Radeon Pro 555X Compute Engine, 4096 MB
-- -1- APPLE: Intel(R) UHD Graphics 630, 1536 MB
-- @
info :: IO ()
info =
  afCall af_info
34 |
-- | Runs the /af_init/ C function from the ArrayFire API via 'afCall'.
--
-- >>> afInit
-- ()
afInit :: IO ()
afInit =
  afCall af_init
41 |
-- | Retrieves ArrayFire device information as a 'String', same as 'info'.
--
-- The C string produced by @af_info_string@ is copied into a Haskell
-- 'String' with 'peekCString'.
--
-- NOTE(review): the C buffer is never released here; confirm against the
-- ArrayFire documentation whether the caller owns it and must free it.
--
-- >>> getInfoString
-- "ArrayFire v3.6.4 (OpenCL, 64-bit Mac OSX, build 1b8030c5)\n[0] APPLE: AMD Radeon Pro 555X Compute Engine, 4096 MB\n-1- APPLE: Intel(R) UHD Graphics 630, 1536 MB\n"
getInfoString :: IO String
getInfoString = do
  -- The literal 1 is the second argument of @af_info_string@
  -- (presumably its "verbose" flag -- confirm against af/device.h).
  cstr <- afCall1 (\out -> af_info_string out 1)
  peekCString cstr
48 |
49 | -- af_err af_device_info(char* d_name, char* d_platform, char *d_toolkit, char* d_compute);
50 |
-- | Number of devices, as reported by @af_get_device_count@.
--
-- >>> getDeviceCount
-- 2
getDeviceCount :: IO Int
getDeviceCount = fmap fromIntegral (afCall1 af_get_device_count)
57 |
58 | -- af_err af_get_dbl_support(bool* available, const int device);
-- | Sets a device by 'Int'.
--
-- The identifier is converted with 'fromIntegral' and passed to
-- @af_set_device@.
--
-- >>> setDevice 0
-- ()
setDevice :: Int -> IO ()
setDevice deviceId = afCall (af_set_device (fromIntegral deviceId))
65 |
-- | Device identifier reported by @af_get_device@.
--
-- >>> getDevice
-- 0
getDevice :: IO Int
getDevice = fmap fromIntegral (afCall1 af_get_device)
72 |
73 | -- af_err af_sync(const int device);
74 | -- af_err af_alloc_device(void **ptr, const dim_t bytes);
75 | -- af_err af_free_device(void *ptr);
76 | -- af_err af_alloc_pinned(void **ptr, const dim_t bytes);
77 | -- af_err af_free_pinned(void *ptr);
78 | -- af_err af_alloc_host(void **ptr, const dim_t bytes);
79 | -- af_err af_free_host(void *ptr);
80 | -- af_err af_device_array(af_array *arr, const void *data, const unsigned ndims, const dim_t * const dims, const af_dtype type);
81 | -- af_err af_device_mem_info(size_t *alloc_bytes, size_t *alloc_buffers, size_t *lock_bytes, size_t *lock_buffers);
82 | -- af_err af_print_mem_info(const char *msg, const int device_id);
83 | -- af_err af_device_gc();
84 | -- af_err af_set_mem_step_size(const size_t step_bytes);
85 | -- af_err af_get_mem_step_size(size_t *step_bytes);
86 | -- af_err af_lock_device_ptr(const af_array arr);
87 | -- af_err af_unlock_device_ptr(const af_array arr);
88 | -- af_err af_lock_array(const af_array arr);
89 | -- af_err af_is_locked_array(bool *res, const af_array arr);
90 | -- af_err af_get_device_ptr(void **ptr, const af_array arr);
91 |
--------------------------------------------------------------------------------
/src/ArrayFire/Exception.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE ViewPatterns #-}
2 | {-# LANGUAGE RecordWildCards #-}
3 | --------------------------------------------------------------------------------
4 | -- |
5 | -- Module : ArrayFire.Exception
6 | -- Copyright : David Johnson (c) 2019-2020
7 | -- License : BSD 3
8 | -- Maintainer : David Johnson
9 | -- Stability : Experimental
10 | -- Portability : GHC
11 | --
12 | -- @
13 | -- module Main where
14 | --
15 | -- import ArrayFire
16 | --
17 | -- main :: IO ()
18 | -- main = print =<< getAvailableBackends
19 | -- @
20 | --
21 | -- @
22 | -- [nix-shell:~\/arrayfire]$ .\/main
23 | -- [CPU,OpenCL]
24 | -- @
25 | --------------------------------------------------------------------------------
26 | module ArrayFire.Exception where
27 |
28 | import Control.Exception hiding (TypeError)
29 | import Data.Typeable
30 | import Control.Monad
31 | import Foreign.C.String
32 | import Foreign.Ptr
33 | import ArrayFire.Internal.Exception
34 | import ArrayFire.Internal.Defines
35 |
-- | String representation of an ArrayFire error code.
--
-- Asks @af_err_to_string@ for the C string belonging to the code and copies
-- it into a Haskell 'String'.
errorToString :: AFErr -> IO String
errorToString err = af_err_to_string err >>= peekCString
39 |
-- | ArrayFire exception type.
--
-- Every constructor except 'UnhandledError' is produced by
-- 'toAFExceptionType' for exactly one numeric 'AFErr' code, noted beside it.
data AFExceptionType
  = NoMemoryError        -- ^ code 101
  | DriverError          -- ^ code 102
  | RuntimeError         -- ^ code 103
  | InvalidArrayError    -- ^ code 201
  | ArgError             -- ^ code 202
  | SizeError            -- ^ code 203
  | TypeError            -- ^ code 204
  | DiffTypeError        -- ^ code 205
  | BatchError           -- ^ code 207
  | DeviceError          -- ^ code 208
  | NotSupportedError    -- ^ code 301
  | NotConfiguredError   -- ^ code 302
  | NonFreeError         -- ^ code 303
  | NoDblError           -- ^ code 401
  | NoGfxError           -- ^ code 402
  | LoadLibError         -- ^ code 501
  | LoadSymError         -- ^ code 502
  | BackendMismatchError -- ^ code 503
  | InternalError        -- ^ code 998
  | UnknownError         -- ^ code 999
  | UnhandledError       -- ^ any code not listed above
  deriving (Eq, Show, Typeable)
64 |
-- | Exception type for the ArrayFire API.
--
-- 'throwAFError' builds and throws values of this type.
data AFException = AFException
  { -- | The Exception type to throw
    afExceptionType :: AFExceptionType
    -- | Code representing the exception
  , afExceptionCode :: Int
    -- | Exception message
  , afExceptionMsg :: String
  } deriving (Eq, Show, Typeable)

-- | Makes 'AFException' usable with 'throwIO' and 'catch'.
instance Exception AFException
77 |
-- | Conversion helper: maps the numeric code inside an 'AFErr' to the
-- matching 'AFExceptionType'.
--
-- Any code without a dedicated constructor yields 'UnhandledError'.
toAFExceptionType :: AFErr -> AFExceptionType
toAFExceptionType (AFErr code) =
  case code of
    101 -> NoMemoryError
    102 -> DriverError
    103 -> RuntimeError
    201 -> InvalidArrayError
    202 -> ArgError
    203 -> SizeError
    204 -> TypeError
    205 -> DiffTypeError
    207 -> BatchError
    208 -> DeviceError
    301 -> NotSupportedError
    302 -> NotConfiguredError
    303 -> NonFreeError
    401 -> NoDblError
    402 -> NoGfxError
    501 -> LoadLibError
    502 -> LoadSymError
    503 -> BackendMismatchError
    998 -> InternalError
    999 -> UnknownError
    _ -> UnhandledError
101 |
-- | Throws an ArrayFire exception for any status other than 'afSuccess'.
--
-- When the status equals 'afSuccess' nothing happens. Otherwise the message
-- is fetched with 'errorToString' and an 'AFException' carrying the category,
-- numeric code and message is thrown with 'throwIO'.
throwAFError :: AFErr -> IO ()
throwAFError exitCode
  | exitCode == afSuccess = pure ()
  | otherwise = do
      msg <- errorToString exitCode
      throwIO
        AFException
          { afExceptionType = toAFExceptionType exitCode
          , afExceptionCode = code
          , afExceptionMsg = msg
          }
  where
    -- Unwrap the numeric value carried by the status and widen it to 'Int'.
    AFErr rawCode = exitCode
    code = fromIntegral rawCode
110 |
-- | Address of the C function @af_release_random_engine@.
foreign import ccall unsafe "&af_release_random_engine" af_release_random_engine_finalizer
  :: FunPtr (AFRandomEngine -> IO ())

-- | Address of the C function @af_destroy_window@ (note that the Haskell name
-- says "release" while the imported symbol is "destroy").
foreign import ccall unsafe "&af_destroy_window" af_release_window_finalizer
  :: FunPtr (AFWindow -> IO ())

-- | Address of the C function @af_release_array@.
foreign import ccall unsafe "&af_release_array" af_release_array_finalizer
  :: FunPtr (AFArray -> IO ())
119 |
--------------------------------------------------------------------------------
/src/ArrayFire/Features.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE ViewPatterns #-}
2 | --------------------------------------------------------------------------------
3 | -- |
4 | -- Module : ArrayFire.Features
5 | -- Copyright : David Johnson (c) 2019-2020
6 | -- License : BSD 3
7 | -- Maintainer : David Johnson
8 | -- Stability : Experimental
9 | -- Portability : GHC
10 | --
11 | -- Functions for constructing and querying 'Features'
12 | --
13 | -- @
14 | -- >>> createFeatures 10
15 | -- @
16 | --
17 | --------------------------------------------------------------------------------
18 | module ArrayFire.Features where
19 |
20 | import Foreign.Marshal
21 | import Foreign.Storable
22 | import Foreign.ForeignPtr
23 | import System.IO.Unsafe
24 |
25 | import ArrayFire.Internal.Features
26 | import ArrayFire.Internal.Types
27 | import ArrayFire.FFI
28 | import ArrayFire.Exception
29 |
-- | Construct 'Features'.
--
-- Calls @af_create_features@ with the requested count, raises an
-- 'AFException' via 'throwAFError' if the call fails, and attaches
-- @af_release_features@ to the resulting handle as its finalizer.
--
-- >>> features = createFeatures 10
--
createFeatures
  :: Int
  -> Features
createFeatures n =
  unsafePerformIO $ do
    raw <- alloca $ \out -> do
      status <- af_create_features out (fromIntegral n)
      throwAFError status
      peek out
    Features <$> newForeignPtr af_release_features raw
45 |
-- | Retain 'Features' by applying @af_retain_features@ through 'op1f'.
--
-- >>> features = retainFeatures (createFeatures 10)
--
retainFeatures
  :: Features
  -> Features
retainFeatures feats = op1f feats af_retain_features
54 |
-- | Get the number of 'Features', as reported by @af_get_features_num@.
--
-- >>> getFeaturesNum (createFeatures 10)
-- 10
--
getFeaturesNum
  :: Features
  -> Int
getFeaturesNum feats =
  fromIntegral (infoFromFeatures feats af_get_features_num)
66 |
-- | Get the X-positions of 'Features' as an 'Array', read with
-- @af_get_features_xpos@.
--
-- The element type @a@ of the result is not constrained by the signature;
-- the caller picks it.
--
-- >>> getFeaturesXPos (createFeatures 10)
-- ArrayFire Array
-- [10 1 1 1]
-- 0.0000
-- 1.8750
-- 0.0000
-- 2.3750
-- 0.0000
-- 2.5938
-- 0.0000
-- 2.0000
-- 0.0000
-- 2.4375
getFeaturesXPos
  :: Features
  -> Array a
getFeaturesXPos feats = featuresToArray feats af_get_features_xpos
86 |
-- | Get the Y-positions of 'Features' as an 'Array', read with
-- @af_get_features_ypos@.
--
-- The element type @a@ of the result is not constrained by the signature;
-- the caller picks it.
--
-- >>> getFeaturesYPos (createFeatures 10)
-- ArrayFire Array
-- [10 1 1 1]
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
getFeaturesYPos
  :: Features
  -> Array a
getFeaturesYPos feats = featuresToArray feats af_get_features_ypos
106 |
-- | Get the scores of 'Features' as an 'Array', read with
-- @af_get_features_score@.
--
-- The element type @a@ of the result is not constrained by the signature;
-- the caller picks it.
--
-- >>> getFeaturesScore (createFeatures 10)
-- ArrayFire Array
-- [10 1 1 1]
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
getFeaturesScore
  :: Features
  -> Array a
getFeaturesScore feats = featuresToArray feats af_get_features_score
126 |
-- | Get the orientations of 'Features' as an 'Array', read with
-- @af_get_features_orientation@.
--
-- The element type @a@ of the result is not constrained by the signature;
-- the caller picks it.
--
-- >>> getFeaturesOrientation (createFeatures 10)
-- ArrayFire Array
-- [10 1 1 1]
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
getFeaturesOrientation
  :: Features
  -> Array a
getFeaturesOrientation feats =
  featuresToArray feats af_get_features_orientation
146 |
-- | Get the sizes of 'Features' as an 'Array', read with
-- @af_get_features_size@.
--
-- The element type @a@ of the result is not constrained by the signature;
-- the caller picks it.
--
-- >>> getFeaturesSize (createFeatures 10)
-- ArrayFire Array
-- [10 1 1 1]
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
-- nan
getFeaturesSize
  :: Features
  -> Array a
getFeaturesSize feats = featuresToArray feats af_get_features_size
166 |
--------------------------------------------------------------------------------
/src/ArrayFire/Index.hs:
--------------------------------------------------------------------------------
1 | --------------------------------------------------------------------------------
2 | -- |
3 | -- Module : ArrayFire.Index
4 | -- Copyright : David Johnson (c) 2019-2020
5 | -- License : BSD 3
6 | -- Maintainer : David Johnson
7 | -- Stability : Experimental
8 | -- Portability : GHC
9 | --
10 | -- Functions for indexing into an 'Array'
11 | --
12 | --------------------------------------------------------------------------------
13 | module ArrayFire.Index where
14 |
15 | import ArrayFire.Internal.Index
16 | import ArrayFire.Internal.Types
17 | import ArrayFire.FFI
18 | import ArrayFire.Exception
19 |
20 | import Foreign
21 |
22 | import System.IO.Unsafe
23 | import Control.Exception
24 |
-- | Index into an 'Array' by a list of 'Seq'.
--
-- The sequences are converted with 'toAFSeq', marshalled into a temporary C
-- array and passed to @af_index@ together with their count. A failing call
-- raises an 'AFException' via 'throwAFError'. The resulting handle is wrapped
-- in a 'ForeignPtr' whose finalizer is 'af_release_array_finalizer'. The
-- whole action runs under 'mask_'.
index
  :: Array a
  -- ^ 'Array' argument
  -> [Seq]
  -- ^ 'Seq' to use for indexing
  -> Array a
index (Array fptr) seqs =
  unsafePerformIO . mask_ $
    withForeignPtr fptr $ \arrPtr ->
      alloca $ \outPtr ->
        withArray (map toAFSeq seqs) $ \seqPtr -> do
          status <- af_index outPtr arrPtr count seqPtr
          throwAFError status
          result <- peek outPtr
          Array <$> newForeignPtr af_release_array_finalizer result
  where
    -- Number of sequences handed to @af_index@.
    count = fromIntegral (length seqs)
42 |
-- | Lookup an 'Array' by keys along a specified dimension.
--
-- Both arrays are passed to @af_lookup@ in the order given, followed by the
-- dimension converted with 'fromIntegral'.
-- NOTE(review): which argument holds the keys is not visible here --
-- presumably the second; confirm against the ArrayFire documentation.
--
-- This function shares its name with the Prelude's @lookup@.
lookup :: Array a -> Array a -> Int -> Array a
lookup a b n = op2 a b (\out x y -> af_lookup out x y dim)
  where
    dim = fromIntegral n
46 |
47 | -- af_err af_assign_seq( af_array *out, const af_array lhs, const unsigned ndims, const af_seq* const indices, const af_array rhs);
48 | -- | TODO: wrap @af_assign_seq@ (C signature above). The 'mean' example below is a copy-pasted placeholder and does not describe this stub.
49 | --
50 | -- @
51 | -- >>> print $ mean 0 ( vector @Int 10 [1..] )
52 | -- @
53 | -- @
54 | -- ArrayFire Array
55 | -- [1 1 1 1]
56 | -- 5.5000
57 | -- @
58 | -- assignSeq :: Array a -> Int -> [Seq] -> Array a -> Array a
59 | -- assignSeq = error "Not implemented"
60 |
61 | -- af_err af_index_gen( af_array *out, const af_array in, const dim_t ndims, const af_index_t* indices);
62 | -- | TODO: wrap @af_index_gen@ (C signature above). The 'mean' example below is a copy-pasted placeholder and does not describe this stub.
63 | --
64 | -- @
65 | -- >>> print $ mean 0 ( vector @Int 10 [1..] )
66 | -- @
67 | -- @
68 | -- ArrayFire Array
69 | -- [1 1 1 1]
70 | -- 5.5000
71 | -- @
72 | -- indexGen :: Array a -> Int -> [Index a] -> Array a -> Array a
73 | -- indexGen = error "Not implemented"
74 |
75 | -- af_err af_assign_gen( af_array *out, const af_array lhs, const dim_t ndims, const af_index_t* indices, const af_array rhs);
76 | -- | TODO: wrap @af_assign_gen@ (C signature above). The 'mean' example below is a copy-pasted placeholder and does not describe this stub.
77 | --
78 | -- @
79 | -- >>> print $ mean 0 ( vector @Int 10 [1..] )
80 | -- @
81 | -- @
82 | -- ArrayFire Array
83 | -- [1 1 1 1]
84 | -- 5.5000
85 | -- @
86 | -- assignGen :: Array a -> Int -> [Index a] -> Array a -> Array a
87 | -- assignGen = error "Not implemented"
88 |
89 | -- af_err af_create_indexers(af_index_t** indexers);
90 | -- af_err af_set_array_indexer(af_index_t* indexer, const af_array idx, const dim_t dim);
91 | -- af_err af_set_seq_indexer(af_index_t* indexer, const af_seq* idx, const dim_t dim, const bool is_batch);
92 | -- af_err af_set_seq_param_indexer(af_index_t* indexer, const double begin, const double end, const double step, const dim_t dim, const bool is_batch);
93 | -- af_err af_release_indexers(af_index_t* indexers);
94 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Algorithm.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Algorithm where
3 |
4 | import ArrayFire.Internal.Defines
5 |
6 | import Foreign.Ptr
7 | import Foreign.C.Types
8 |
9 | #include "af/algorithm.h"
-- Raw FFI bindings for the C functions named in the import strings below.
-- Every binding is an unsafe ccall that returns an 'AFErr' status in 'IO';
-- results are written through the leading pointer argument(s).
foreign import ccall unsafe "af_sum" af_sum :: Ptr AFArray -> AFArray -> CInt -> IO AFErr
foreign import ccall unsafe "af_sum_nan" af_sum_nan :: Ptr AFArray -> AFArray -> CInt -> Double -> IO AFErr
foreign import ccall unsafe "af_product" af_product :: Ptr AFArray -> AFArray -> CInt -> IO AFErr
foreign import ccall unsafe "af_product_nan" af_product_nan :: Ptr AFArray -> AFArray -> CInt -> Double -> IO AFErr
foreign import ccall unsafe "af_min" af_min :: Ptr AFArray -> AFArray -> CInt -> IO AFErr
foreign import ccall unsafe "af_max" af_max :: Ptr AFArray -> AFArray -> CInt -> IO AFErr
foreign import ccall unsafe "af_all_true" af_all_true :: Ptr AFArray -> AFArray -> CInt -> IO AFErr
foreign import ccall unsafe "af_any_true" af_any_true :: Ptr AFArray -> AFArray -> CInt -> IO AFErr
foreign import ccall unsafe "af_count" af_count :: Ptr AFArray -> AFArray -> CInt -> IO AFErr

-- Whole-array variants: results are written through two 'Double' pointers.
foreign import ccall unsafe "af_sum_all" af_sum_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr
foreign import ccall unsafe "af_sum_nan_all" af_sum_nan_all :: Ptr Double -> Ptr Double -> AFArray -> Double -> IO AFErr
foreign import ccall unsafe "af_product_all" af_product_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr
foreign import ccall unsafe "af_product_nan_all" af_product_nan_all :: Ptr Double -> Ptr Double -> AFArray -> Double -> IO AFErr
foreign import ccall unsafe "af_min_all" af_min_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr
foreign import ccall unsafe "af_max_all" af_max_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr
foreign import ccall unsafe "af_all_true_all" af_all_true_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr
foreign import ccall unsafe "af_any_true_all" af_any_true_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr
foreign import ccall unsafe "af_count_all" af_count_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr

-- Variants that write two output arrays, or an additional 'CUInt' output.
foreign import ccall unsafe "af_imin" af_imin :: Ptr AFArray -> Ptr AFArray -> AFArray -> CInt -> IO AFErr
foreign import ccall unsafe "af_imax" af_imax :: Ptr AFArray -> Ptr AFArray -> AFArray -> CInt -> IO AFErr
foreign import ccall unsafe "af_imin_all" af_imin_all :: Ptr Double -> Ptr Double -> Ptr CUInt -> AFArray -> IO AFErr
foreign import ccall unsafe "af_imax_all" af_imax_all :: Ptr Double -> Ptr Double -> Ptr CUInt -> AFArray -> IO AFErr

foreign import ccall unsafe "af_accum" af_accum :: Ptr AFArray -> AFArray -> CInt -> IO AFErr
foreign import ccall unsafe "af_scan" af_scan :: Ptr AFArray -> AFArray -> CInt -> AFBinaryOp -> CBool -> IO AFErr
foreign import ccall unsafe "af_scan_by_key" af_scan_by_key :: Ptr AFArray -> AFArray -> AFArray -> CInt -> AFBinaryOp -> CBool -> IO AFErr
foreign import ccall unsafe "af_where" af_where :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_diff1" af_diff1 :: Ptr AFArray -> AFArray -> CInt -> IO AFErr
foreign import ccall unsafe "af_diff2" af_diff2 :: Ptr AFArray -> AFArray -> CInt -> IO AFErr
foreign import ccall unsafe "af_sort" af_sort :: Ptr AFArray -> AFArray -> CUInt -> CBool -> IO AFErr
foreign import ccall unsafe "af_sort_index" af_sort_index :: Ptr AFArray -> Ptr AFArray -> AFArray -> CUInt -> CBool -> IO AFErr
foreign import ccall unsafe "af_sort_by_key" af_sort_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CUInt -> CBool -> IO AFErr
foreign import ccall unsafe "af_set_unique" af_set_unique :: Ptr AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_set_union" af_set_union :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_set_intersect" af_set_intersect :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
78 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Arith.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Arith where
3 |
4 | import ArrayFire.Internal.Defines
5 |
6 | import Foreign.Ptr
7 | import Foreign.C.Types
8 |
9 | #include "af/arith.h"
-- Raw FFI bindings for the C functions named in the import strings below.
-- Every binding is an unsafe ccall that returns an 'AFErr' status in 'IO'
-- and writes its result through the leading 'Ptr AFArray'.
--
-- Two shapes dominate:
--   * out pointer, two input arrays and a 'CBool' flag;
--   * out pointer and a single input array.
foreign import ccall unsafe "af_add" af_add :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_sub" af_sub :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_mul" af_mul :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_div" af_div :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_lt" af_lt :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_gt" af_gt :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_le" af_le :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_ge" af_ge :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_eq" af_eq :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_neq" af_neq :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_and" af_and :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_or" af_or :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_not" af_not :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_bitand" af_bitand :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_bitor" af_bitor :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_bitxor" af_bitxor :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_bitshiftl" af_bitshiftl :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_bitshiftr" af_bitshiftr :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_cast" af_cast :: Ptr AFArray -> AFArray -> AFDtype -> IO AFErr
foreign import ccall unsafe "af_minof" af_minof :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_maxof" af_maxof :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_clamp" af_clamp :: Ptr AFArray -> AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_rem" af_rem :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_mod" af_mod :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_abs" af_abs :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_arg" af_arg :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_sign" af_sign :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_round" af_round :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_trunc" af_trunc :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_floor" af_floor :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_ceil" af_ceil :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_hypot" af_hypot :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_sin" af_sin :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_cos" af_cos :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_tan" af_tan :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_asin" af_asin :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_acos" af_acos :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_atan" af_atan :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_atan2" af_atan2 :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_cplx2" af_cplx2 :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_cplx" af_cplx :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_real" af_real :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_imag" af_imag :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_conjg" af_conjg :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_sinh" af_sinh :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_cosh" af_cosh :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_tanh" af_tanh :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_asinh" af_asinh :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_acosh" af_acosh :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_atanh" af_atanh :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_root" af_root :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_pow" af_pow :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
foreign import ccall unsafe "af_pow2" af_pow2 :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_exp" af_exp :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_sigmoid" af_sigmoid :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_expm1" af_expm1 :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_erf" af_erf :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_erfc" af_erfc :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_log" af_log :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_log1p" af_log1p :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_log10" af_log10 :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_log2" af_log2 :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_sqrt" af_sqrt :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_cbrt" af_cbrt :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_factorial" af_factorial :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_tgamma" af_tgamma :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_lgamma" af_lgamma :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_iszero" af_iszero :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_isinf" af_isinf :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_isnan" af_isnan :: Ptr AFArray -> AFArray -> IO AFErr
150 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Array.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Array where
3 |
4 | import ArrayFire.Internal.Defines
5 |
6 | import Foreign.Ptr
7 | import Foreign.C.Types
8 |
9 | #include "af/array.h"
-- Raw FFI bindings for the C functions named in the import strings below.
-- Every binding is an unsafe ccall that returns an 'AFErr' status in 'IO'.

-- Creation, copying, writing and reference counting.
foreign import ccall unsafe "af_create_array" af_create_array :: Ptr AFArray -> Ptr () -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr
foreign import ccall unsafe "af_create_handle" af_create_handle :: Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr
foreign import ccall unsafe "af_copy_array" af_copy_array :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_write_array" af_write_array :: AFArray -> Ptr () -> CSize -> AFSource -> IO AFErr
foreign import ccall unsafe "af_get_data_ptr" af_get_data_ptr :: Ptr () -> AFArray -> IO AFErr
foreign import ccall unsafe "af_release_array" af_release_array :: AFArray -> IO AFErr
foreign import ccall unsafe "af_retain_array" af_retain_array :: Ptr AFArray -> AFArray -> IO AFErr
foreign import ccall unsafe "af_get_data_ref_count" af_get_data_ref_count :: Ptr CInt -> AFArray -> IO AFErr

-- Evaluation control.
foreign import ccall unsafe "af_eval" af_eval :: AFArray -> IO AFErr
foreign import ccall unsafe "af_eval_multiple" af_eval_multiple :: CInt -> Ptr AFArray -> IO AFErr
foreign import ccall unsafe "af_set_manual_eval_flag" af_set_manual_eval_flag :: CBool -> IO AFErr
foreign import ccall unsafe "af_get_manual_eval_flag" af_get_manual_eval_flag :: Ptr CBool -> IO AFErr

-- Queries: the answer is written through the leading pointer argument(s).
foreign import ccall unsafe "af_get_elements" af_get_elements :: Ptr DimT -> AFArray -> IO AFErr
foreign import ccall unsafe "af_get_type" af_get_type :: Ptr AFDtype -> AFArray -> IO AFErr
foreign import ccall unsafe "af_get_dims" af_get_dims :: Ptr DimT -> Ptr DimT -> Ptr DimT -> Ptr DimT -> AFArray -> IO AFErr
foreign import ccall unsafe "af_get_numdims" af_get_numdims :: Ptr CUInt -> AFArray -> IO AFErr
foreign import ccall unsafe "af_is_empty" af_is_empty :: Ptr CBool -> AFArray -> IO AFErr
foreign import ccall unsafe "af_is_scalar" af_is_scalar :: Ptr CBool -> AFArray -> IO AFErr
foreign import ccall unsafe "af_is_row" af_is_row :: Ptr CBool -> AFArray -> IO AFErr
foreign import ccall unsafe "af_is_column" af_is_column :: Ptr CBool -> AFArray -> IO AFErr
foreign import ccall unsafe "af_is_vector" af_is_vector :: Ptr CBool -> AFArray -> IO AFErr
foreign import ccall unsafe "af_is_complex" af_is_complex :: Ptr CBool -> AFArray -> IO AFErr
foreign import ccall unsafe "af_is_real" af_is_real :: Ptr CBool -> AFArray -> IO AFErr
foreign import ccall unsafe "af_is_double" af_is_double :: Ptr CBool -> AFArray -> IO AFErr
foreign import ccall unsafe "af_is_single" af_is_single :: Ptr CBool -> AFArray -> IO AFErr
foreign import ccall unsafe "af_is_realfloating" af_is_realfloating :: Ptr CBool -> AFArray -> IO AFErr
foreign import ccall unsafe "af_is_floating" af_is_floating :: Ptr CBool -> AFArray -> IO AFErr
foreign import ccall unsafe "af_is_integer" af_is_integer :: Ptr CBool -> AFArray -> IO AFErr
foreign import ccall unsafe "af_is_bool" af_is_bool :: Ptr CBool -> AFArray -> IO AFErr
foreign import ccall unsafe "af_is_sparse" af_is_sparse :: Ptr CBool -> AFArray -> IO AFErr
foreign import ccall unsafe "af_get_scalar" af_get_scalar :: Ptr () -> AFArray -> IO AFErr
72 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/BLAS.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.BLAS where
3 |
4 | import ArrayFire.Internal.Defines
5 |
6 | import Foreign.Ptr
7 | import Foreign.C.Types
8 |
9 | #include "af/blas.h"
-- Raw FFI imports for the BLAS-style routines. Every declaration is a
-- @ccall unsafe@ binding whose Haskell name matches the imported C symbol and
-- whose result is @IO AFErr@.
-- NOTE(review): a leading @Ptr AFArray@ is presumably the C out-parameter for
-- the result handle; confirm against the C header.
10 | foreign import ccall unsafe "af_matmul"
11 | af_matmul :: Ptr AFArray -> AFArray -> AFArray -> AFMatProp -> AFMatProp -> IO AFErr
12 | foreign import ccall unsafe "af_dot"
13 | af_dot :: Ptr AFArray -> AFArray -> AFArray -> AFMatProp -> AFMatProp -> IO AFErr
-- Unlike @af_dot@, this variant takes two @Ptr Double@ arguments instead of a
-- @Ptr AFArray@.
-- NOTE(review): presumably the real and imaginary parts of a scalar result, in
-- that order; confirm.
14 | foreign import ccall unsafe "af_dot_all"
15 | af_dot_all :: Ptr Double -> Ptr Double -> AFArray -> AFArray -> AFMatProp -> AFMatProp -> IO AFErr
16 | foreign import ccall unsafe "af_transpose"
17 | af_transpose :: Ptr AFArray -> AFArray -> CBool -> IO AFErr
-- No @Ptr AFArray@ argument here: the call receives only the array handle and
-- a flag.
18 | foreign import ccall unsafe "af_transpose_inplace"
19 | af_transpose_inplace :: AFArray -> CBool -> IO AFErr
20 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Backend.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Backend where
3 |
4 | import ArrayFire.Internal.Defines
5 |
6 | import Foreign.Ptr
7 | import Foreign.C.Types
8 |
9 | #include "af/backend.h"
-- Raw FFI imports for backend selection and queries. Every declaration is a
-- @ccall unsafe@ binding whose Haskell name matches the imported C symbol and
-- whose result is @IO AFErr@.
-- NOTE(review): a leading @Ptr@ argument is presumably the C out-parameter;
-- confirm against the C header.
10 | foreign import ccall unsafe "af_set_backend"
11 | af_set_backend :: AFBackend -> IO AFErr
12 | foreign import ccall unsafe "af_get_backend_count"
13 | af_get_backend_count :: Ptr CUInt -> IO AFErr
-- Writes through a @Ptr CInt@ rather than a @Ptr AFBackend@.
-- NOTE(review): presumably a bit mask of backend flags; confirm.
14 | foreign import ccall unsafe "af_get_available_backends"
15 | af_get_available_backends :: Ptr CInt -> IO AFErr
16 | foreign import ccall unsafe "af_get_backend_id"
17 | af_get_backend_id :: Ptr AFBackend -> AFArray -> IO AFErr
18 | foreign import ccall unsafe "af_get_active_backend"
19 | af_get_active_backend :: Ptr AFBackend -> IO AFErr
20 | foreign import ccall unsafe "af_get_device_id"
21 | af_get_device_id :: Ptr CInt -> AFArray -> IO AFErr
22 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/CUDA.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.CUDA where
3 |
4 | import ArrayFire.Internal.Defines
5 | import ArrayFire.Internal.Types
6 | import Data.Word
7 | import Data.Int
8 | import Foreign.Ptr
9 | import Foreign.C.Types
10 |
11 | #include "af/cuda.h"
-- Raw FFI imports for the @afcu_*@ symbols. Every declaration is a
-- @ccall unsafe@ binding whose Haskell name matches the imported C symbol and
-- whose result is @IO AFErr@.
-- NOTE(review): a leading @Ptr@ argument is presumably the C out-parameter and
-- the trailing @CInt@ a device id; confirm against the C header.
12 | foreign import ccall unsafe "afcu_get_stream"
13 | afcu_get_stream :: Ptr CudaStreamT -> CInt -> IO AFErr
14 | foreign import ccall unsafe "afcu_get_native_id"
15 | afcu_get_native_id :: Ptr CInt -> CInt -> IO AFErr
16 | foreign import ccall unsafe "afcu_set_native_id"
17 | afcu_set_native_id :: CInt -> IO AFErr
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Data.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Data where
3 |
4 | import ArrayFire.Internal.Defines
5 |
6 | import Foreign.Ptr
7 | import Foreign.C.Types
8 |
9 | #include "af/data.h"
-- Raw FFI imports for array construction and rearrangement. Every declaration
-- is a @ccall unsafe@ binding whose Haskell name matches the imported C symbol
-- and whose result is @IO AFErr@.
-- NOTE(review): a leading @Ptr AFArray@ is presumably the C out-parameter for
-- the result handle, and a @CUInt@ followed by @Ptr DimT@ presumably a
-- dimension count plus dimension array; confirm against the C header.
10 | foreign import ccall unsafe "af_constant"
11 | af_constant :: Ptr AFArray -> Double -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr
12 | foreign import ccall unsafe "af_constant_complex"
13 | af_constant_complex :: Ptr AFArray -> Double -> Double -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr
-- The @_long@ and @_ulong@ variants take no @AFDtype@ argument, unlike
-- @af_constant@.
14 | foreign import ccall unsafe "af_constant_long"
15 | af_constant_long :: Ptr AFArray -> IntL -> CUInt -> Ptr DimT -> IO AFErr
16 | foreign import ccall unsafe "af_constant_ulong"
17 | af_constant_ulong :: Ptr AFArray -> UIntL -> CUInt -> Ptr DimT -> IO AFErr
18 | foreign import ccall unsafe "af_range"
19 | af_range :: Ptr AFArray -> CUInt -> Ptr DimT -> CInt -> AFDtype -> IO AFErr
-- Takes two (@CUInt@, @Ptr DimT@) pairs.
-- NOTE(review): presumably array dimensions followed by tiling dimensions;
-- confirm.
20 | foreign import ccall unsafe "af_iota"
21 | af_iota :: Ptr AFArray -> CUInt -> Ptr DimT -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr
22 | foreign import ccall unsafe "af_identity"
23 | af_identity :: Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr
24 | foreign import ccall unsafe "af_diag_create"
25 | af_diag_create :: Ptr AFArray -> AFArray -> CInt -> IO AFErr
26 | foreign import ccall unsafe "af_diag_extract"
27 | af_diag_extract :: Ptr AFArray -> AFArray -> CInt -> IO AFErr
28 | foreign import ccall unsafe "af_join"
29 | af_join :: Ptr AFArray -> CInt -> AFArray -> AFArray -> IO AFErr
-- Takes a @CInt@, a @CUInt@ and a pointer to array handles.
-- NOTE(review): presumably join dimension, number of inputs and the inputs
-- themselves; confirm.
30 | foreign import ccall unsafe "af_join_many"
31 | af_join_many :: Ptr AFArray -> CInt -> CUInt -> Ptr AFArray -> IO AFErr
32 | foreign import ccall unsafe "af_tile"
33 | af_tile :: Ptr AFArray -> AFArray -> CUInt -> CUInt -> CUInt -> CUInt -> IO AFErr
34 | foreign import ccall unsafe "af_reorder"
35 | af_reorder :: Ptr AFArray -> AFArray -> CUInt -> CUInt -> CUInt -> CUInt -> IO AFErr
-- Same arity as @af_tile@ and @af_reorder@ but with signed @CInt@ arguments.
36 | foreign import ccall unsafe "af_shift"
37 | af_shift :: Ptr AFArray -> AFArray -> CInt -> CInt -> CInt -> CInt -> IO AFErr
38 | foreign import ccall unsafe "af_moddims"
39 | af_moddims :: Ptr AFArray -> AFArray -> CUInt -> Ptr DimT -> IO AFErr
40 | foreign import ccall unsafe "af_flat"
41 | af_flat :: Ptr AFArray -> AFArray -> IO AFErr
42 | foreign import ccall unsafe "af_flip"
43 | af_flip :: Ptr AFArray -> AFArray -> CUInt -> IO AFErr
44 | foreign import ccall unsafe "af_lower"
45 | af_lower :: Ptr AFArray -> AFArray -> CBool -> IO AFErr
46 | foreign import ccall unsafe "af_upper"
47 | af_upper :: Ptr AFArray -> AFArray -> CBool -> IO AFErr
48 | foreign import ccall unsafe "af_select"
49 | af_select :: Ptr AFArray -> AFArray -> AFArray -> AFArray -> IO AFErr
-- The @_r@ / @_l@ variants differ only in which of the last two arguments is
-- the @Double@ scalar.
50 | foreign import ccall unsafe "af_select_scalar_r"
51 | af_select_scalar_r :: Ptr AFArray -> AFArray -> AFArray -> Double -> IO AFErr
52 | foreign import ccall unsafe "af_select_scalar_l"
53 | af_select_scalar_l :: Ptr AFArray -> AFArray -> Double -> AFArray -> IO AFErr
-- The two @af_replace*@ imports take no @Ptr AFArray@ argument; they receive
-- array handles only.
-- NOTE(review): presumably the first handle is modified in place; confirm.
54 | foreign import ccall unsafe "af_replace"
55 | af_replace :: AFArray -> AFArray -> AFArray -> IO AFErr
56 | foreign import ccall unsafe "af_replace_scalar"
57 | af_replace_scalar :: AFArray -> AFArray -> Double -> IO AFErr
58 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Device.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Device where
3 |
4 | import ArrayFire.Internal.Defines
5 |
6 | import Foreign.Ptr
7 | import Foreign.C.Types
8 |
9 | #include "af/device.h"
-- Raw FFI imports for device management and memory. Every declaration is a
-- @ccall unsafe@ binding whose Haskell name matches the imported C symbol and
-- whose result is @IO AFErr@.
-- NOTE(review): a leading @Ptr@ argument is presumably the C out-parameter;
-- confirm against the C header.
10 | foreign import ccall unsafe "af_info"
11 | af_info :: IO AFErr
12 | foreign import ccall unsafe "af_init"
13 | af_init :: IO AFErr
-- Writes a C string pointer through @Ptr (Ptr CChar)@.
-- NOTE(review): ownership of the returned string is not visible here; confirm
-- who must free it and with which function.
14 | foreign import ccall unsafe "af_info_string"
15 | af_info_string :: Ptr (Ptr CChar) -> CBool -> IO AFErr
-- Takes four plain @Ptr CChar@ arguments (not pointers to pointers).
-- NOTE(review): presumably caller-allocated buffers; the required sizes are not
-- visible here, confirm before use.
16 | foreign import ccall unsafe "af_device_info"
17 | af_device_info :: Ptr CChar -> Ptr CChar -> Ptr CChar -> Ptr CChar -> IO AFErr
18 | foreign import ccall unsafe "af_get_device_count"
19 | af_get_device_count :: Ptr CInt -> IO AFErr
20 | foreign import ccall unsafe "af_get_dbl_support"
21 | af_get_dbl_support :: Ptr CBool -> CInt -> IO AFErr
22 | foreign import ccall unsafe "af_set_device"
23 | af_set_device :: CInt -> IO AFErr
24 | foreign import ccall unsafe "af_get_device"
25 | af_get_device :: Ptr CInt -> IO AFErr
-- NOTE(review): imported @unsafe@. If the C call can block for long (the name
-- suggests a device synchronisation), a @safe@ import may be preferable so the
-- runtime is not stalled for the duration of the call; confirm.
26 | foreign import ccall unsafe "af_sync"
27 | af_sync :: CInt -> IO AFErr
-- The alloc/free imports below come in three pairs (device, pinned, host), all
-- with the shapes @Ptr (Ptr ()) -> DimT -> IO AFErr@ and @Ptr () -> IO AFErr@.
-- NOTE(review): the @DimT@ is presumably a size; its unit (bytes or elements)
-- is not visible here, confirm.
28 | foreign import ccall unsafe "af_alloc_device"
29 | af_alloc_device :: Ptr (Ptr ()) -> DimT -> IO AFErr
30 | foreign import ccall unsafe "af_free_device"
31 | af_free_device :: Ptr () -> IO AFErr
32 | foreign import ccall unsafe "af_alloc_pinned"
33 | af_alloc_pinned :: Ptr (Ptr ()) -> DimT -> IO AFErr
34 | foreign import ccall unsafe "af_free_pinned"
35 | af_free_pinned :: Ptr () -> IO AFErr
36 | foreign import ccall unsafe "af_alloc_host"
37 | af_alloc_host :: Ptr (Ptr ()) -> DimT -> IO AFErr
38 | foreign import ccall unsafe "af_free_host"
39 | af_free_host :: Ptr () -> IO AFErr
40 | foreign import ccall unsafe "af_device_array"
41 | af_device_array :: Ptr AFArray -> Ptr () -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr
-- Writes four @CSize@ values through separate pointers.
42 | foreign import ccall unsafe "af_device_mem_info"
43 | af_device_mem_info :: Ptr CSize -> Ptr CSize -> Ptr CSize -> Ptr CSize -> IO AFErr
44 | foreign import ccall unsafe "af_print_mem_info"
45 | af_print_mem_info :: Ptr CChar -> CInt -> IO AFErr
46 | foreign import ccall unsafe "af_device_gc"
47 | af_device_gc :: IO AFErr
48 | foreign import ccall unsafe "af_set_mem_step_size"
49 | af_set_mem_step_size :: CSize -> IO AFErr
50 | foreign import ccall unsafe "af_get_mem_step_size"
51 | af_get_mem_step_size :: Ptr CSize -> IO AFErr
52 | foreign import ccall unsafe "af_lock_device_ptr"
53 | af_lock_device_ptr :: AFArray -> IO AFErr
54 | foreign import ccall unsafe "af_unlock_device_ptr"
55 | af_unlock_device_ptr :: AFArray -> IO AFErr
56 | foreign import ccall unsafe "af_lock_array"
57 | af_lock_array :: AFArray -> IO AFErr
58 | foreign import ccall unsafe "af_is_locked_array"
59 | af_is_locked_array :: Ptr CBool -> AFArray -> IO AFErr
60 | foreign import ccall unsafe "af_get_device_ptr"
61 | af_get_device_ptr :: Ptr (Ptr ()) -> AFArray -> IO AFErr
62 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Exception.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Exception where
3 |
4 | import ArrayFire.Internal.Defines
5 |
6 | import Foreign.Ptr
7 | import Foreign.C.Types
8 |
9 | #include "af/defines.h"
-- Raw @ccall unsafe@ imports for error reporting. Neither import yields an
-- @AFErr@ status: the first returns unit, the second returns a C string
-- pointer directly.
--
-- @af_get_last_error@ writes through a @Ptr (Ptr CChar)@ and a @Ptr DimT@.
-- NOTE(review): presumably the message and its length; ownership of the
-- message buffer is not visible here, confirm who frees it.
10 | foreign import ccall unsafe "af_get_last_error"
11 | af_get_last_error :: Ptr (Ptr CChar) -> Ptr DimT -> IO ()
-- Maps an @AFErr@ value to a @Ptr CChar@.
-- NOTE(review): presumably a static string that must not be freed; confirm.
12 | foreign import ccall unsafe "af_err_to_string"
13 | af_err_to_string :: AFErr -> IO (Ptr CChar)
14 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Features.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Features where
3 |
4 | import ArrayFire.Internal.Defines
5 |
6 | import Foreign.Ptr
7 | import Foreign.C.Types
8 |
9 | #include "af/features.h"
-- Raw FFI imports for feature handles. All declarations except the last are
-- @ccall unsafe@ bindings whose Haskell name matches the imported C symbol and
-- whose result is @IO AFErr@.
-- NOTE(review): a leading @Ptr@ argument is presumably the C out-parameter;
-- confirm against the C header.
10 | foreign import ccall unsafe "af_create_features"
11 | af_create_features :: Ptr AFFeatures -> DimT -> IO AFErr
12 | foreign import ccall unsafe "af_retain_features"
13 | af_retain_features :: Ptr AFFeatures -> AFFeatures -> IO AFErr
14 | foreign import ccall unsafe "af_get_features_num"
15 | af_get_features_num :: Ptr DimT -> AFFeatures -> IO AFErr
-- The five accessors that follow share the shape
-- @Ptr AFArray -> AFFeatures -> IO AFErr@.
-- NOTE(review): whether the written array handle is owned by the caller or by
-- the features object is not visible here; confirm before releasing it.
16 | foreign import ccall unsafe "af_get_features_xpos"
17 | af_get_features_xpos :: Ptr AFArray -> AFFeatures -> IO AFErr
18 | foreign import ccall unsafe "af_get_features_ypos"
19 | af_get_features_ypos :: Ptr AFArray -> AFFeatures -> IO AFErr
20 | foreign import ccall unsafe "af_get_features_score"
21 | af_get_features_score :: Ptr AFArray -> AFFeatures -> IO AFErr
22 | foreign import ccall unsafe "af_get_features_orientation"
23 | af_get_features_orientation :: Ptr AFArray -> AFFeatures -> IO AFErr
24 | foreign import ccall unsafe "af_get_features_size"
25 | af_get_features_size :: Ptr AFArray -> AFFeatures -> IO AFErr
-- Address import: the leading ampersand in the entity string binds the address
-- of the C symbol, so this is a 'FunPtr' value rather than a callable function.
-- NOTE(review): the pointee type has result @IO ()@, which is the shape a
-- finalizer pointer needs; confirm the actual C return type, since the
-- sibling imports all return a status code.
26 | foreign import ccall unsafe "&af_release_features"
27 | af_release_features :: FunPtr (AFFeatures -> IO ())
28 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Graphics.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Graphics where
3 |
4 | import ArrayFire.Internal.Defines
5 | import ArrayFire.Internal.Types
6 |
7 | import Foreign.Ptr
8 | import Foreign.C.Types
9 |
10 | #include "af/graphics.h"
-- Raw FFI imports for windows and plotting. Every declaration is a
-- @ccall unsafe@ binding whose Haskell name matches the imported C symbol and
-- whose result is @IO AFErr@.
-- Most drawing imports take the window first and a @Ptr AFCell@ last.
-- NOTE(review): the @Ptr AFCell@ is presumably per-cell rendering properties
-- for a gridded window; confirm against the C header.
11 | foreign import ccall unsafe "af_create_window"
12 | af_create_window :: Ptr AFWindow -> CInt -> CInt -> Ptr CChar -> IO AFErr
13 | foreign import ccall unsafe "af_set_position"
14 | af_set_position :: AFWindow -> CUInt -> CUInt -> IO AFErr
15 | foreign import ccall unsafe "af_set_title"
16 | af_set_title :: AFWindow -> Ptr CChar -> IO AFErr
17 | foreign import ccall unsafe "af_set_size"
18 | af_set_size :: AFWindow -> CUInt -> CUInt -> IO AFErr
19 | foreign import ccall unsafe "af_draw_image"
20 | af_draw_image :: AFWindow -> AFArray -> Ptr AFCell -> IO AFErr
21 | foreign import ccall unsafe "af_draw_plot"
22 | af_draw_plot :: AFWindow -> AFArray -> AFArray -> Ptr AFCell -> IO AFErr
23 | foreign import ccall unsafe "af_draw_plot3"
24 | af_draw_plot3 :: AFWindow -> AFArray -> Ptr AFCell -> IO AFErr
25 | foreign import ccall unsafe "af_draw_plot_nd"
26 | af_draw_plot_nd :: AFWindow -> AFArray -> Ptr AFCell -> IO AFErr
27 | foreign import ccall unsafe "af_draw_plot_2d"
28 | af_draw_plot_2d :: AFWindow -> AFArray -> AFArray -> Ptr AFCell -> IO AFErr
29 | foreign import ccall unsafe "af_draw_plot_3d"
30 | af_draw_plot_3d :: AFWindow -> AFArray -> AFArray -> AFArray -> Ptr AFCell -> IO AFErr
-- The scatter imports mirror the plot imports above, with an additional
-- @AFMarkerType@ argument before the @Ptr AFCell@.
31 | foreign import ccall unsafe "af_draw_scatter"
32 | af_draw_scatter :: AFWindow -> AFArray -> AFArray -> AFMarkerType -> Ptr AFCell -> IO AFErr
33 | foreign import ccall unsafe "af_draw_scatter3"
34 | af_draw_scatter3 :: AFWindow -> AFArray -> AFMarkerType -> Ptr AFCell -> IO AFErr
35 | foreign import ccall unsafe "af_draw_scatter_nd"
36 | af_draw_scatter_nd :: AFWindow -> AFArray -> AFMarkerType -> Ptr AFCell -> IO AFErr
37 | foreign import ccall unsafe "af_draw_scatter_2d"
38 | af_draw_scatter_2d :: AFWindow -> AFArray -> AFArray -> AFMarkerType -> Ptr AFCell -> IO AFErr
39 | foreign import ccall unsafe "af_draw_scatter_3d"
40 | af_draw_scatter_3d :: AFWindow -> AFArray -> AFArray -> AFArray -> AFMarkerType -> Ptr AFCell -> IO AFErr
41 | foreign import ccall unsafe "af_draw_hist"
42 | af_draw_hist :: AFWindow -> AFArray -> Double -> Double -> Ptr AFCell -> IO AFErr
43 | foreign import ccall unsafe "af_draw_surface"
44 | af_draw_surface :: AFWindow -> AFArray -> AFArray -> AFArray -> Ptr AFCell -> IO AFErr
-- The vector-field imports take 2, 6 and 4 array handles respectively.
-- NOTE(review): presumably points followed by directions, one array per axis
-- in the 2d/3d variants; confirm the argument order.
45 | foreign import ccall unsafe "af_draw_vector_field_nd"
46 | af_draw_vector_field_nd :: AFWindow -> AFArray -> AFArray -> Ptr AFCell -> IO AFErr
47 | foreign import ccall unsafe "af_draw_vector_field_3d"
48 | af_draw_vector_field_3d :: AFWindow -> AFArray -> AFArray -> AFArray -> AFArray -> AFArray -> AFArray -> Ptr AFCell -> IO AFErr
49 | foreign import ccall unsafe "af_draw_vector_field_2d"
50 | af_draw_vector_field_2d :: AFWindow -> AFArray -> AFArray -> AFArray -> AFArray -> Ptr AFCell -> IO AFErr
51 | foreign import ccall unsafe "af_grid"
52 | af_grid :: AFWindow -> CInt -> CInt -> IO AFErr
53 | foreign import ccall unsafe "af_set_axes_limits_compute"
54 | af_set_axes_limits_compute :: AFWindow -> AFArray -> AFArray -> AFArray -> CBool -> Ptr AFCell -> IO AFErr
-- The 2d/3d limit setters take 4 and 6 @Float@ values respectively.
-- NOTE(review): presumably min/max pairs per axis; confirm the order.
55 | foreign import ccall unsafe "af_set_axes_limits_2d"
56 | af_set_axes_limits_2d :: AFWindow -> Float -> Float -> Float -> Float -> CBool -> Ptr AFCell -> IO AFErr
57 | foreign import ccall unsafe "af_set_axes_limits_3d"
58 | af_set_axes_limits_3d :: AFWindow -> Float -> Float -> Float -> Float -> Float -> Float -> CBool -> Ptr AFCell -> IO AFErr
59 | foreign import ccall unsafe "af_set_axes_titles"
60 | af_set_axes_titles :: AFWindow -> Ptr CChar -> Ptr CChar -> Ptr CChar -> Ptr AFCell -> IO AFErr
-- NOTE(review): imported @unsafe@. If presenting a window can block for long,
-- a @safe@ import may be preferable so the runtime is not stalled for the
-- duration of the call; confirm.
61 | foreign import ccall unsafe "af_show"
62 | af_show :: AFWindow -> IO AFErr
63 | foreign import ccall unsafe "af_is_window_closed"
64 | af_is_window_closed :: Ptr CBool -> AFWindow -> IO AFErr
65 | foreign import ccall unsafe "af_set_visibility"
66 | af_set_visibility :: AFWindow -> CBool -> IO AFErr
67 | foreign import ccall unsafe "af_destroy_window"
68 | af_destroy_window :: AFWindow -> IO AFErr
69 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Image.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Image where
3 |
4 | import ArrayFire.Internal.Defines
5 |
6 | import Foreign.Ptr
7 | import Foreign.C.Types
8 |
9 | #include "af/image.h"
-- Raw FFI imports for image I/O and image processing. Every declaration is a
-- @ccall unsafe@ binding whose Haskell name matches the imported C symbol and
-- whose result is @IO AFErr@.
-- NOTE(review): a leading @Ptr AFArray@ is presumably the C out-parameter for
-- the result handle; confirm against the C header.
--
-- @af_gradient@ takes two @Ptr AFArray@ arguments before the input handle.
-- NOTE(review): presumably one result per axis; confirm the order.
10 | foreign import ccall unsafe "af_gradient"
11 | af_gradient :: Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr
12 | foreign import ccall unsafe "af_load_image"
13 | af_load_image :: Ptr AFArray -> Ptr CChar -> CBool -> IO AFErr
-- Note the argument order: the @Ptr CChar@ comes first, then the array handle.
14 | foreign import ccall unsafe "af_save_image"
15 | af_save_image :: Ptr CChar -> AFArray -> IO AFErr
16 | foreign import ccall unsafe "af_load_image_memory"
17 | af_load_image_memory :: Ptr AFArray -> Ptr () -> IO AFErr
-- Writes an untyped pointer through @Ptr (Ptr ())@.
-- NOTE(review): presumably the buffer must later be passed to
-- @af_delete_image_memory@; confirm.
18 | foreign import ccall unsafe "af_save_image_memory"
19 | af_save_image_memory :: Ptr (Ptr ()) -> AFArray -> AFImageFormat -> IO AFErr
20 | foreign import ccall unsafe "af_delete_image_memory"
21 | af_delete_image_memory :: Ptr () -> IO AFErr
22 | foreign import ccall unsafe "af_load_image_native"
23 | af_load_image_native :: Ptr AFArray -> Ptr CChar -> IO AFErr
24 | foreign import ccall unsafe "af_save_image_native"
25 | af_save_image_native :: Ptr CChar -> AFArray -> IO AFErr
26 | foreign import ccall unsafe "af_is_image_io_available"
27 | af_is_image_io_available :: Ptr CBool -> IO AFErr
28 | foreign import ccall unsafe "af_resize"
29 | af_resize :: Ptr AFArray -> AFArray -> DimT -> DimT -> AFInterpType -> IO AFErr
30 | foreign import ccall unsafe "af_transform"
31 | af_transform :: Ptr AFArray -> AFArray -> AFArray -> DimT -> DimT -> AFInterpType -> CBool -> IO AFErr
32 | foreign import ccall unsafe "af_transform_coordinates"
33 | af_transform_coordinates :: Ptr AFArray -> AFArray -> Float -> Float -> IO AFErr
34 | foreign import ccall unsafe "af_rotate"
35 | af_rotate :: Ptr AFArray -> AFArray -> Float -> CBool -> AFInterpType -> IO AFErr
-- @af_translate@, @af_scale@ and @af_skew@ share the prefix
-- @Float -> Float -> DimT -> DimT -> AFInterpType@; @af_skew@ adds a trailing
-- @CBool@.
36 | foreign import ccall unsafe "af_translate"
37 | af_translate :: Ptr AFArray -> AFArray -> Float -> Float -> DimT -> DimT -> AFInterpType -> IO AFErr
38 | foreign import ccall unsafe "af_scale"
39 | af_scale :: Ptr AFArray -> AFArray -> Float -> Float -> DimT -> DimT -> AFInterpType -> IO AFErr
40 | foreign import ccall unsafe "af_skew"
41 | af_skew :: Ptr AFArray -> AFArray -> Float -> Float -> DimT -> DimT -> AFInterpType -> CBool -> IO AFErr
42 | foreign import ccall unsafe "af_histogram"
43 | af_histogram :: Ptr AFArray -> AFArray -> CUInt -> Double -> Double -> IO AFErr
-- The four morphology imports share the shape
-- @Ptr AFArray -> AFArray -> AFArray -> IO AFErr@.
44 | foreign import ccall unsafe "af_dilate"
45 | af_dilate :: Ptr AFArray -> AFArray -> AFArray -> IO AFErr
46 | foreign import ccall unsafe "af_dilate3"
47 | af_dilate3 :: Ptr AFArray -> AFArray -> AFArray -> IO AFErr
48 | foreign import ccall unsafe "af_erode"
49 | af_erode :: Ptr AFArray -> AFArray -> AFArray -> IO AFErr
50 | foreign import ccall unsafe "af_erode3"
51 | af_erode3 :: Ptr AFArray -> AFArray -> AFArray -> IO AFErr
52 | foreign import ccall unsafe "af_bilateral"
53 | af_bilateral :: Ptr AFArray -> AFArray -> Float -> Float -> CBool -> IO AFErr
54 | foreign import ccall unsafe "af_mean_shift"
55 | af_mean_shift :: Ptr AFArray -> AFArray -> Float -> Float -> CUInt -> CBool -> IO AFErr
56 | foreign import ccall unsafe "af_minfilt"
57 | af_minfilt :: Ptr AFArray -> AFArray -> DimT -> DimT -> AFBorderType -> IO AFErr
58 | foreign import ccall unsafe "af_maxfilt"
59 | af_maxfilt :: Ptr AFArray -> AFArray -> DimT -> DimT -> AFBorderType -> IO AFErr
60 | foreign import ccall unsafe "af_regions"
61 | af_regions :: Ptr AFArray -> AFArray -> AFConnectivity -> AFDtype -> IO AFErr
-- Like @af_gradient@, takes two @Ptr AFArray@ arguments before the input.
62 | foreign import ccall unsafe "af_sobel_operator"
63 | af_sobel_operator :: Ptr AFArray -> Ptr AFArray -> AFArray -> CUInt -> IO AFErr
-- Both grayscale conversions take three @Float@ values.
-- NOTE(review): presumably per-channel weights in R, G, B order; confirm.
64 | foreign import ccall unsafe "af_rgb2gray"
65 | af_rgb2gray :: Ptr AFArray -> AFArray -> Float -> Float -> Float -> IO AFErr
66 | foreign import ccall unsafe "af_gray2rgb"
67 | af_gray2rgb :: Ptr AFArray -> AFArray -> Float -> Float -> Float -> IO AFErr
68 | foreign import ccall unsafe "af_hist_equal"
69 | af_hist_equal :: Ptr AFArray -> AFArray -> AFArray -> IO AFErr
-- Takes no input array: two @CInt@ and two @Double@ values only.
70 | foreign import ccall unsafe "af_gaussian_kernel"
71 | af_gaussian_kernel :: Ptr AFArray -> CInt -> CInt -> Double -> Double -> IO AFErr
72 | foreign import ccall unsafe "af_hsv2rgb"
73 | af_hsv2rgb :: Ptr AFArray -> AFArray -> IO AFErr
74 | foreign import ccall unsafe "af_rgb2hsv"
75 | af_rgb2hsv :: Ptr AFArray -> AFArray -> IO AFErr
-- Takes two @AFCSpace@ values.
-- NOTE(review): which is the source and which the target colour space is not
-- visible here; confirm the order before calling.
76 | foreign import ccall unsafe "af_color_space"
77 | af_color_space :: Ptr AFArray -> AFArray -> AFCSpace -> AFCSpace -> IO AFErr
-- @af_unwrap@ takes six @DimT@ values and @af_wrap@ takes eight, each followed
-- by a @CBool@.
-- NOTE(review): the meaning and order of the @DimT@ values is not visible
-- here; confirm against the C header.
78 | foreign import ccall unsafe "af_unwrap"
79 | af_unwrap :: Ptr AFArray -> AFArray -> DimT -> DimT -> DimT -> DimT -> DimT -> DimT -> CBool -> IO AFErr
80 | foreign import ccall unsafe "af_wrap"
81 | af_wrap :: Ptr AFArray -> AFArray -> DimT -> DimT -> DimT -> DimT -> DimT -> DimT -> DimT -> DimT -> CBool -> IO AFErr
82 | foreign import ccall unsafe "af_sat"
83 | af_sat :: Ptr AFArray -> AFArray -> IO AFErr
84 | foreign import ccall unsafe "af_ycbcr2rgb"
85 | af_ycbcr2rgb :: Ptr AFArray -> AFArray -> AFYccStd -> IO AFErr
86 | foreign import ccall unsafe "af_rgb2ycbcr"
87 | af_rgb2ycbcr :: Ptr AFArray -> AFArray -> AFYccStd -> IO AFErr
88 | foreign import ccall unsafe "af_moments"
89 | af_moments :: Ptr AFArray -> AFArray -> AFMomentType -> IO AFErr
-- Writes through a @Ptr Double@ instead of a @Ptr AFArray@.
-- NOTE(review): how many doubles are written is not visible here; confirm the
-- required buffer size.
90 | foreign import ccall unsafe "af_moments_all"
91 | af_moments_all :: Ptr Double -> AFArray -> AFMomentType -> IO AFErr
92 | foreign import ccall unsafe "af_canny"
93 | af_canny :: Ptr AFArray -> AFArray -> AFCannyThreshold -> Float -> Float -> CUInt -> CBool -> IO AFErr
94 | foreign import ccall unsafe "af_anisotropic_diffusion"
95 | af_anisotropic_diffusion :: Ptr AFArray -> AFArray -> Float -> Float -> CUInt -> AFFluxFunction -> AFDiffusionEq -> IO AFErr
96 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Index.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Index where
3 |
4 | import ArrayFire.Internal.Defines
5 | import ArrayFire.Internal.Types
6 |
7 | import Foreign.Ptr
8 | import Foreign.C.Types
9 |
10 | #include "af/index.h"
-- Raw FFI imports for indexing and assignment. Every declaration is a
-- @ccall unsafe@ binding whose Haskell name matches the imported C symbol and
-- whose result is @IO AFErr@.
-- NOTE(review): a leading @Ptr AFArray@ is presumably the C out-parameter for
-- the result handle; confirm against the C header.
--
-- The @_seq@-style imports take a @CUInt@ count with a @Ptr AFSeq@, while the
-- @_gen@ imports take a @DimT@ count with a @Ptr AFIndex@.
11 | foreign import ccall unsafe "af_index"
12 | af_index :: Ptr AFArray -> AFArray -> CUInt -> Ptr AFSeq -> IO AFErr
13 | foreign import ccall unsafe "af_lookup"
14 | af_lookup :: Ptr AFArray -> AFArray -> AFArray -> CUInt -> IO AFErr
15 | foreign import ccall unsafe "af_assign_seq"
16 | af_assign_seq :: Ptr AFArray -> AFArray -> CUInt -> Ptr AFSeq -> AFArray -> IO AFErr
17 | foreign import ccall unsafe "af_index_gen"
18 | af_index_gen :: Ptr AFArray -> AFArray -> DimT -> Ptr AFIndex -> IO AFErr
19 | foreign import ccall unsafe "af_assign_gen"
20 | af_assign_gen :: Ptr AFArray -> AFArray -> DimT -> Ptr AFIndex -> AFArray -> IO AFErr
-- Writes a @Ptr AFIndex@ through a pointer-to-pointer.
-- NOTE(review): presumably allocates an indexer array that must later be
-- passed to @af_release_indexers@; the element count is not visible here,
-- confirm.
21 | foreign import ccall unsafe "af_create_indexers"
22 | af_create_indexers :: Ptr (Ptr AFIndex) -> IO AFErr
-- The three setters take the indexer pointer first and a @DimT@.
-- NOTE(review): the @DimT@ presumably selects which dimension is set; confirm.
23 | foreign import ccall unsafe "af_set_array_indexer"
24 | af_set_array_indexer :: Ptr AFIndex -> AFArray -> DimT -> IO AFErr
25 | foreign import ccall unsafe "af_set_seq_indexer"
26 | af_set_seq_indexer :: Ptr AFIndex -> Ptr AFSeq -> DimT -> CBool -> IO AFErr
27 | foreign import ccall unsafe "af_set_seq_param_indexer"
28 | af_set_seq_param_indexer :: Ptr AFIndex -> Double -> Double -> Double -> DimT -> CBool -> IO AFErr
29 | foreign import ccall unsafe "af_release_indexers"
30 | af_release_indexers :: Ptr AFIndex -> IO AFErr
31 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Internal.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Internal where
3 |
4 | import ArrayFire.Internal.Defines
5 | import Foreign.Ptr
6 | import Foreign.C.Types
7 |
8 | #include "af/internal.h"
-- Raw FFI imports for array layout internals. Every declaration is a
-- @ccall unsafe@ binding whose Haskell name matches the imported C symbol and
-- whose result is @IO AFErr@.
-- NOTE(review): a leading @Ptr@ argument is presumably the C out-parameter;
-- confirm against the C header.
--
-- @af_create_strided_array@ takes a data pointer, a @DimT@, a @CUInt@ and two
-- @Ptr DimT@ arrays before the element type and source.
-- NOTE(review): presumably offset, number of dimensions, dimensions and
-- strides in that order; confirm.
9 | foreign import ccall unsafe "af_create_strided_array"
10 | af_create_strided_array :: Ptr AFArray -> Ptr () -> DimT -> CUInt -> Ptr DimT -> Ptr DimT -> AFDtype -> AFSource -> IO AFErr
-- Four separate @Ptr DimT@ arguments precede the array handle.
11 | foreign import ccall unsafe "af_get_strides"
12 | af_get_strides :: Ptr DimT -> Ptr DimT -> Ptr DimT -> Ptr DimT -> AFArray -> IO AFErr
13 | foreign import ccall unsafe "af_get_offset"
14 | af_get_offset :: Ptr DimT -> AFArray -> IO AFErr
15 | foreign import ccall unsafe "af_get_raw_ptr"
16 | af_get_raw_ptr :: Ptr (Ptr ()) -> AFArray -> IO AFErr
17 | foreign import ccall unsafe "af_is_linear"
18 | af_is_linear :: Ptr CBool -> AFArray -> IO AFErr
19 | foreign import ccall unsafe "af_is_owner"
20 | af_is_owner :: Ptr CBool -> AFArray -> IO AFErr
21 | foreign import ccall unsafe "af_get_allocated_bytes"
22 | af_get_allocated_bytes :: Ptr CSize -> AFArray -> IO AFErr
23 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/LAPACK.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.LAPACK where
3 |
4 | import ArrayFire.Internal.Defines
5 |
6 | import Foreign.Ptr
7 | import Foreign.C.Types
8 |
9 | #include "af/lapack.h"
-- Raw FFI imports for the linear-algebra routines. Every declaration is a
-- @ccall unsafe@ binding whose Haskell name matches the imported C symbol and
-- whose result is @IO AFErr@.
-- NOTE(review): leading @Ptr@ arguments are presumably C out-parameters;
-- confirm against the C header.
--
-- The decompositions @af_svd@, @af_lu@ and @af_qr@ each take three
-- @Ptr AFArray@ arguments before the input handle.
-- NOTE(review): the order of the three results is not visible here; confirm.
10 | foreign import ccall unsafe "af_svd"
11 | af_svd :: Ptr AFArray -> Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr
-- Same Haskell type as @af_svd@.
-- NOTE(review): the name suggests the input handle is overwritten; confirm.
12 | foreign import ccall unsafe "af_svd_inplace"
13 | af_svd_inplace :: Ptr AFArray -> Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr
14 | foreign import ccall unsafe "af_lu"
15 | af_lu :: Ptr AFArray -> Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr
-- The in-place LU and QR variants take a single @Ptr AFArray@.
16 | foreign import ccall unsafe "af_lu_inplace"
17 | af_lu_inplace :: Ptr AFArray -> AFArray -> CBool -> IO AFErr
18 | foreign import ccall unsafe "af_qr"
19 | af_qr :: Ptr AFArray -> Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr
20 | foreign import ccall unsafe "af_qr_inplace"
21 | af_qr_inplace :: Ptr AFArray -> AFArray -> IO AFErr
-- Both Cholesky imports write a @CInt@ through a pointer in addition to the
-- returned status.
-- NOTE(review): presumably a factorisation info code distinct from the
-- @AFErr@ result; confirm its meaning.
22 | foreign import ccall unsafe "af_cholesky"
23 | af_cholesky :: Ptr AFArray -> Ptr CInt -> AFArray -> CBool -> IO AFErr
24 | foreign import ccall unsafe "af_cholesky_inplace"
25 | af_cholesky_inplace :: Ptr CInt -> AFArray -> CBool -> IO AFErr
26 | foreign import ccall unsafe "af_solve"
27 | af_solve :: Ptr AFArray -> AFArray -> AFArray -> AFMatProp -> IO AFErr
28 | foreign import ccall unsafe "af_solve_lu"
29 | af_solve_lu :: Ptr AFArray -> AFArray -> AFArray -> AFArray -> AFMatProp -> IO AFErr
30 | foreign import ccall unsafe "af_inverse"
31 | af_inverse :: Ptr AFArray -> AFArray -> AFMatProp -> IO AFErr
32 | foreign import ccall unsafe "af_rank"
33 | af_rank :: Ptr CUInt -> AFArray -> Double -> IO AFErr
-- Writes through two @Ptr Double@ arguments.
-- NOTE(review): presumably real and imaginary parts, in that order; confirm.
34 | foreign import ccall unsafe "af_det"
35 | af_det :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr
36 | foreign import ccall unsafe "af_norm"
37 | af_norm :: Ptr Double -> AFArray -> AFNormType -> Double -> Double -> IO AFErr
38 | foreign import ccall unsafe "af_is_lapack_available"
39 | af_is_lapack_available :: Ptr CBool -> IO AFErr
40 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/OpenCL.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.OpenCL where
3 |
4 | import ArrayFire.Internal.Defines
5 | import ArrayFire.Internal.Types
6 | import Data.Word
7 | import Data.Int
8 | import Foreign.Ptr
9 | import Foreign.C.Types
10 |
11 | #include "af/opencl.h"
-- Raw FFI imports for the @afcl_*@ symbols. Every declaration is a
-- @ccall unsafe@ binding whose Haskell name matches the imported C symbol and
-- whose result is @IO AFErr@.
-- NOTE(review): a leading @Ptr@ argument is presumably the C out-parameter;
-- confirm against the C header.
--
-- The two getters below take a @CBool@ after the out-pointer.
-- NOTE(review): presumably a retain flag for the returned object; confirm.
12 | foreign import ccall unsafe "afcl_get_context"
13 | afcl_get_context :: Ptr ClContext -> CBool -> IO AFErr
14 | foreign import ccall unsafe "afcl_get_queue"
15 | afcl_get_queue :: Ptr ClCommandQueue -> CBool -> IO AFErr
16 | foreign import ccall unsafe "afcl_get_device_id"
17 | afcl_get_device_id :: Ptr ClDeviceId -> IO AFErr
18 | foreign import ccall unsafe "afcl_set_device_id"
19 | afcl_set_device_id :: ClDeviceId -> IO AFErr
-- Only the @add@ variant takes a command queue; @set@ and @delete@ take the
-- device id and context alone.
20 | foreign import ccall unsafe "afcl_add_device_context"
21 | afcl_add_device_context :: ClDeviceId -> ClContext -> ClCommandQueue -> IO AFErr
22 | foreign import ccall unsafe "afcl_set_device_context"
23 | afcl_set_device_context :: ClDeviceId -> ClContext -> IO AFErr
24 | foreign import ccall unsafe "afcl_delete_device_context"
25 | afcl_delete_device_context :: ClDeviceId -> ClContext -> IO AFErr
26 | foreign import ccall unsafe "afcl_get_device_type"
27 | afcl_get_device_type :: Ptr AfclDeviceType -> IO AFErr
28 | foreign import ccall unsafe "afcl_get_platform"
29 | afcl_get_platform :: Ptr AFCLPlatform -> IO AFErr
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Random.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Random where
3 |
4 | import ArrayFire.Internal.Defines
5 |
6 | import Foreign.Ptr
7 | import Foreign.C.Types
8 |
9 | #include "af/random.h"
-- Raw FFI imports for random engines and random arrays. Every declaration is a
-- @ccall unsafe@ binding whose Haskell name matches the imported C symbol and
-- whose result is @IO AFErr@.
-- NOTE(review): a leading @Ptr@ argument is presumably the C out-parameter;
-- confirm against the C header.
10 | foreign import ccall unsafe "af_create_random_engine"
11 | af_create_random_engine :: Ptr AFRandomEngine -> AFRandomEngineType -> UIntL -> IO AFErr
12 | foreign import ccall unsafe "af_retain_random_engine"
13 | af_retain_random_engine :: Ptr AFRandomEngine -> AFRandomEngine -> IO AFErr
-- This setter receives the engine as @Ptr AFRandomEngine@, whereas the getter
-- below receives the handle by value.
14 | foreign import ccall unsafe "af_random_engine_set_type"
15 | af_random_engine_set_type :: Ptr AFRandomEngine -> AFRandomEngineType -> IO AFErr
16 | foreign import ccall unsafe "af_random_engine_get_type"
17 | af_random_engine_get_type :: Ptr AFRandomEngineType -> AFRandomEngine -> IO AFErr
-- The engine-taking generators mirror @af_randu@ / @af_randn@ further down,
-- with an extra trailing @AFRandomEngine@.
18 | foreign import ccall unsafe "af_random_uniform"
19 | af_random_uniform :: Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> AFRandomEngine -> IO AFErr
20 | foreign import ccall unsafe "af_random_normal"
21 | af_random_normal :: Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> AFRandomEngine -> IO AFErr
-- Like @af_random_engine_set_type@, takes the engine by pointer.
22 | foreign import ccall unsafe "af_random_engine_set_seed"
23 | af_random_engine_set_seed :: Ptr AFRandomEngine -> UIntL -> IO AFErr
-- NOTE(review): whether the handle written here must be released by the
-- caller is not visible; confirm before pairing it with
-- @af_release_random_engine@.
24 | foreign import ccall unsafe "af_get_default_random_engine"
25 | af_get_default_random_engine :: Ptr AFRandomEngine -> IO AFErr
26 | foreign import ccall unsafe "af_set_default_random_engine_type"
27 | af_set_default_random_engine_type :: AFRandomEngineType -> IO AFErr
28 | foreign import ccall unsafe "af_random_engine_get_seed"
29 | af_random_engine_get_seed :: Ptr UIntL -> AFRandomEngine -> IO AFErr
30 | foreign import ccall unsafe "af_release_random_engine"
31 | af_release_random_engine :: AFRandomEngine -> IO AFErr
32 | foreign import ccall unsafe "af_randu"
33 | af_randu :: Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr
34 | foreign import ccall unsafe "af_randn"
35 | af_randn :: Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr
36 | foreign import ccall unsafe "af_set_seed"
37 | af_set_seed :: UIntL -> IO AFErr
38 | foreign import ccall unsafe "af_get_seed"
39 | af_get_seed :: Ptr UIntL -> IO AFErr
40 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Seq.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Seq where
3 |
4 | import ArrayFire.Internal.Defines
5 | import ArrayFire.Internal.Types
6 | import Data.Word
7 | import Data.Int
8 | import Foreign.Ptr
9 | import Foreign.C.Types
10 |
11 | #include "af/seq.h"
12 | foreign import ccall unsafe "af_make_seq"
13 | af_make_seq :: Double -> Double -> Double -> IO AFSeq
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Signal.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Signal where
3 |
4 | import ArrayFire.Internal.Defines
5 | import Foreign.Ptr
6 | import Foreign.C.Types
7 |
8 | #include "af/signal.h"
9 | foreign import ccall unsafe "af_approx1"
10 | af_approx1 :: Ptr AFArray -> AFArray -> AFArray -> AFInterpType -> Float -> IO AFErr
11 | foreign import ccall unsafe "af_approx2"
12 | af_approx2 :: Ptr AFArray -> AFArray -> AFArray -> AFArray -> AFInterpType -> Float -> IO AFErr
13 | foreign import ccall unsafe "af_fft"
14 | af_fft :: Ptr AFArray -> AFArray -> Double -> DimT -> IO AFErr
15 | foreign import ccall unsafe "af_fft_inplace"
16 | af_fft_inplace :: AFArray -> Double -> IO AFErr
17 | foreign import ccall unsafe "af_fft2"
18 | af_fft2 :: Ptr AFArray -> AFArray -> Double -> DimT -> DimT -> IO AFErr
19 | foreign import ccall unsafe "af_fft2_inplace"
20 | af_fft2_inplace :: AFArray -> Double -> IO AFErr
21 | foreign import ccall unsafe "af_fft3"
22 | af_fft3 :: Ptr AFArray -> AFArray -> Double -> DimT -> DimT -> DimT -> IO AFErr
23 | foreign import ccall unsafe "af_fft3_inplace"
24 | af_fft3_inplace :: AFArray -> Double -> IO AFErr
25 | foreign import ccall unsafe "af_ifft"
26 | af_ifft :: Ptr AFArray -> AFArray -> Double -> DimT -> IO AFErr
27 | foreign import ccall unsafe "af_ifft_inplace"
28 | af_ifft_inplace :: AFArray -> Double -> IO AFErr
29 | foreign import ccall unsafe "af_ifft2"
30 | af_ifft2 :: Ptr AFArray -> AFArray -> Double -> DimT -> DimT -> IO AFErr
31 | foreign import ccall unsafe "af_ifft2_inplace"
32 | af_ifft2_inplace :: AFArray -> Double -> IO AFErr
33 | foreign import ccall unsafe "af_ifft3"
34 | af_ifft3 :: Ptr AFArray -> AFArray -> Double -> DimT -> DimT -> DimT -> IO AFErr
35 | foreign import ccall unsafe "af_ifft3_inplace"
36 | af_ifft3_inplace :: AFArray -> Double -> IO AFErr
37 | foreign import ccall unsafe "af_fft_r2c"
38 | af_fft_r2c :: Ptr AFArray -> AFArray -> Double -> DimT -> IO AFErr
39 | foreign import ccall unsafe "af_fft2_r2c"
40 | af_fft2_r2c :: Ptr AFArray -> AFArray -> Double -> DimT -> DimT -> IO AFErr
41 | foreign import ccall unsafe "af_fft3_r2c"
42 | af_fft3_r2c :: Ptr AFArray -> AFArray -> Double -> DimT -> DimT -> DimT -> IO AFErr
43 | foreign import ccall unsafe "af_fft_c2r"
44 | af_fft_c2r :: Ptr AFArray -> AFArray -> Double -> CBool -> IO AFErr
45 | foreign import ccall unsafe "af_fft2_c2r"
46 | af_fft2_c2r :: Ptr AFArray -> AFArray -> Double -> CBool -> IO AFErr
47 | foreign import ccall unsafe "af_fft3_c2r"
48 | af_fft3_c2r :: Ptr AFArray -> AFArray -> Double -> CBool -> IO AFErr
49 | foreign import ccall unsafe "af_convolve1"
50 | af_convolve1 :: Ptr AFArray -> AFArray -> AFArray -> AFConvMode -> AFConvDomain -> IO AFErr
51 | foreign import ccall unsafe "af_convolve2"
52 | af_convolve2 :: Ptr AFArray -> AFArray -> AFArray -> AFConvMode -> AFConvDomain -> IO AFErr
53 | foreign import ccall unsafe "af_convolve3"
54 | af_convolve3 :: Ptr AFArray -> AFArray -> AFArray -> AFConvMode -> AFConvDomain -> IO AFErr
55 | foreign import ccall unsafe "af_convolve2_sep"
56 | af_convolve2_sep :: Ptr AFArray -> AFArray -> AFArray -> AFArray -> AFConvMode -> IO AFErr
57 | foreign import ccall unsafe "af_fft_convolve1"
58 | af_fft_convolve1 :: Ptr AFArray -> AFArray -> AFArray -> AFConvMode -> IO AFErr
59 | foreign import ccall unsafe "af_fft_convolve2"
60 | af_fft_convolve2 :: Ptr AFArray -> AFArray -> AFArray -> AFConvMode -> IO AFErr
61 | foreign import ccall unsafe "af_fft_convolve3"
62 | af_fft_convolve3 :: Ptr AFArray -> AFArray -> AFArray -> AFConvMode -> IO AFErr
63 | foreign import ccall unsafe "af_fir"
64 | af_fir :: Ptr AFArray -> AFArray -> AFArray -> IO AFErr
65 | foreign import ccall unsafe "af_iir"
66 | af_iir :: Ptr AFArray -> AFArray -> AFArray -> AFArray -> IO AFErr
67 | foreign import ccall unsafe "af_medfilt"
68 | af_medfilt :: Ptr AFArray -> AFArray -> DimT -> DimT -> AFBorderType -> IO AFErr
69 | foreign import ccall unsafe "af_medfilt1"
70 | af_medfilt1 :: Ptr AFArray -> AFArray -> DimT -> AFBorderType -> IO AFErr
71 | foreign import ccall unsafe "af_medfilt2"
72 | af_medfilt2 :: Ptr AFArray -> AFArray -> DimT -> DimT -> AFBorderType -> IO AFErr
73 | foreign import ccall unsafe "af_set_fft_plan_cache_size"
74 | af_set_fft_plan_cache_size :: CSize -> IO AFErr
75 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Sparse.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Sparse where
3 |
4 | import ArrayFire.Internal.Defines
5 | import Foreign.Ptr
6 | import Foreign.C.Types
7 |
8 | #include "af/sparse.h"
9 | foreign import ccall unsafe "af_create_sparse_array"
10 | af_create_sparse_array :: Ptr AFArray -> DimT -> DimT -> AFArray -> AFArray -> AFArray -> AFStorage -> IO AFErr
11 | foreign import ccall unsafe "af_create_sparse_array_from_ptr"
12 | af_create_sparse_array_from_ptr :: Ptr AFArray -> DimT -> DimT -> DimT -> Ptr () -> Ptr CInt -> Ptr CInt -> AFDtype -> AFStorage -> AFSource -> IO AFErr
13 | foreign import ccall unsafe "af_create_sparse_array_from_dense"
14 | af_create_sparse_array_from_dense :: Ptr AFArray -> AFArray -> AFStorage -> IO AFErr
15 | foreign import ccall unsafe "af_sparse_convert_to"
16 | af_sparse_convert_to :: Ptr AFArray -> AFArray -> AFStorage -> IO AFErr
17 | foreign import ccall unsafe "af_sparse_to_dense"
18 | af_sparse_to_dense :: Ptr AFArray -> AFArray -> IO AFErr
19 | foreign import ccall unsafe "af_sparse_get_info"
20 | af_sparse_get_info :: Ptr AFArray -> Ptr AFArray -> Ptr AFArray -> Ptr AFStorage -> AFArray -> IO AFErr
21 | foreign import ccall unsafe "af_sparse_get_values"
22 | af_sparse_get_values :: Ptr AFArray -> AFArray -> IO AFErr
23 | foreign import ccall unsafe "af_sparse_get_row_idx"
24 | af_sparse_get_row_idx :: Ptr AFArray -> AFArray -> IO AFErr
25 | foreign import ccall unsafe "af_sparse_get_col_idx"
26 | af_sparse_get_col_idx :: Ptr AFArray -> AFArray -> IO AFErr
27 | foreign import ccall unsafe "af_sparse_get_nnz"
28 | af_sparse_get_nnz :: Ptr DimT -> AFArray -> IO AFErr
29 | foreign import ccall unsafe "af_sparse_get_storage"
30 | af_sparse_get_storage :: Ptr AFStorage -> AFArray -> IO AFErr
31 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Statistics.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Statistics where
3 |
4 | import ArrayFire.Internal.Defines
5 | import Foreign.Ptr
6 | import Foreign.C.Types
7 |
8 | #include "af/statistics.h"
9 | foreign import ccall unsafe "af_mean"
10 | af_mean :: Ptr AFArray -> AFArray -> DimT -> IO AFErr
11 | foreign import ccall unsafe "af_mean_weighted"
12 | af_mean_weighted :: Ptr AFArray -> AFArray -> AFArray -> DimT -> IO AFErr
13 | foreign import ccall unsafe "af_var"
14 | af_var :: Ptr AFArray -> AFArray -> CBool -> DimT -> IO AFErr
15 | foreign import ccall unsafe "af_var_weighted"
16 | af_var_weighted :: Ptr AFArray -> AFArray -> AFArray -> DimT -> IO AFErr
17 | foreign import ccall unsafe "af_stdev"
18 | af_stdev :: Ptr AFArray -> AFArray -> DimT -> IO AFErr
19 | foreign import ccall unsafe "af_cov"
20 | af_cov :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr
21 | foreign import ccall unsafe "af_median"
22 | af_median :: Ptr AFArray -> AFArray -> DimT -> IO AFErr
23 | foreign import ccall unsafe "af_mean_all"
24 | af_mean_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr
25 | foreign import ccall unsafe "af_mean_all_weighted"
26 | af_mean_all_weighted :: Ptr Double -> Ptr Double -> AFArray -> AFArray -> IO AFErr
27 | foreign import ccall unsafe "af_var_all"
28 | af_var_all :: Ptr Double -> Ptr Double -> AFArray -> CBool -> IO AFErr
29 | foreign import ccall unsafe "af_var_all_weighted"
30 | af_var_all_weighted :: Ptr Double -> Ptr Double -> AFArray -> AFArray -> IO AFErr
31 | foreign import ccall unsafe "af_stdev_all"
32 | af_stdev_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr
33 | foreign import ccall unsafe "af_median_all"
34 | af_median_all :: Ptr Double -> Ptr Double -> AFArray -> IO AFErr
35 | foreign import ccall unsafe "af_corrcoef"
36 | af_corrcoef :: Ptr Double -> Ptr Double -> AFArray -> AFArray -> IO AFErr
37 | foreign import ccall unsafe "af_topk"
38 | af_topk :: Ptr AFArray -> Ptr AFArray -> AFArray -> CInt -> CInt -> AFTopkFunction -> IO AFErr
39 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Util.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Util where
3 |
4 | import ArrayFire.Internal.Defines
5 | import Foreign.Ptr
6 | import Foreign.C.Types
7 |
8 | #include "af/util.h"
9 | foreign import ccall unsafe "af_print_array"
10 | af_print_array :: AFArray -> IO AFErr
11 | foreign import ccall unsafe "af_print_array_gen"
12 | af_print_array_gen :: Ptr CChar -> AFArray -> CInt -> IO AFErr
13 | foreign import ccall unsafe "af_save_array"
14 | af_save_array :: Ptr CInt -> Ptr CChar -> AFArray -> Ptr CChar -> CBool -> IO AFErr
15 | foreign import ccall unsafe "af_read_array_index"
16 | af_read_array_index :: Ptr AFArray -> Ptr CChar -> CUInt -> IO AFErr
17 | foreign import ccall unsafe "af_read_array_key"
18 | af_read_array_key :: Ptr AFArray -> Ptr CChar -> Ptr CChar -> IO AFErr
19 | foreign import ccall unsafe "af_read_array_key_check"
20 | af_read_array_key_check :: Ptr CInt -> Ptr CChar -> Ptr CChar -> IO AFErr
21 | foreign import ccall unsafe "af_array_to_string"
22 | af_array_to_string :: Ptr (Ptr CChar) -> Ptr CChar -> AFArray -> CInt -> CBool -> IO AFErr
23 | foreign import ccall unsafe "af_example_function"
24 | af_example_function :: Ptr AFArray -> AFArray -> AFSomeEnum -> IO AFErr
25 | foreign import ccall unsafe "af_get_version"
26 | af_get_version :: Ptr CInt -> Ptr CInt -> Ptr CInt -> IO AFErr
27 | foreign import ccall unsafe "af_get_revision"
28 | af_get_revision :: IO (Ptr CChar)
29 | foreign import ccall unsafe "af_get_size_of"
30 | af_get_size_of :: Ptr CSize -> AFDtype -> IO AFErr
31 |
--------------------------------------------------------------------------------
/src/ArrayFire/Internal/Vision.hsc:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE CPP #-}
2 | module ArrayFire.Internal.Vision where
3 |
4 | import ArrayFire.Internal.Defines
5 | import Foreign.Ptr
6 | import Foreign.C.Types
7 |
8 | #include "af/vision.h"
9 | foreign import ccall unsafe "af_fast"
10 | af_fast :: Ptr AFFeatures -> AFArray -> Float -> CUInt -> CBool -> Float -> CUInt -> IO AFErr
11 | foreign import ccall unsafe "af_harris"
12 | af_harris :: Ptr AFFeatures -> AFArray -> CUInt -> Float -> Float -> CUInt -> Float -> IO AFErr
13 | foreign import ccall unsafe "af_orb"
14 | af_orb :: Ptr AFFeatures -> Ptr AFArray -> AFArray -> Float -> CUInt -> Float -> CUInt -> CBool -> IO AFErr
15 | foreign import ccall unsafe "af_sift"
16 | af_sift :: Ptr AFFeatures -> Ptr AFArray -> AFArray -> CUInt -> Float -> Float -> Float -> CBool -> Float -> Float -> IO AFErr
17 | foreign import ccall unsafe "af_gloh"
18 | af_gloh :: Ptr AFFeatures -> Ptr AFArray -> AFArray -> CUInt -> Float -> Float -> Float -> CBool -> Float -> Float -> IO AFErr
19 | foreign import ccall unsafe "af_hamming_matcher"
20 | af_hamming_matcher :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> DimT -> CUInt -> IO AFErr
21 | foreign import ccall unsafe "af_nearest_neighbour"
22 | af_nearest_neighbour :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> DimT -> CUInt -> AFMatchType -> IO AFErr
23 | foreign import ccall unsafe "af_match_template"
24 | af_match_template :: Ptr AFArray -> AFArray -> AFArray -> AFMatchType -> IO AFErr
25 | foreign import ccall unsafe "af_susan"
26 | af_susan :: Ptr AFFeatures -> AFArray -> CUInt -> Float -> Float -> Float -> CUInt -> IO AFErr
27 | foreign import ccall unsafe "af_dog"
28 | af_dog :: Ptr AFArray -> AFArray -> CInt -> CInt -> IO AFErr
29 | foreign import ccall unsafe "af_homography"
30 | af_homography :: Ptr AFArray -> Ptr CInt -> AFArray -> AFArray -> AFArray -> AFArray -> AFHomographyType -> Float -> CUInt -> AFDtype -> IO AFErr
31 |
--------------------------------------------------------------------------------
/src/ArrayFire/Orphans.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE FlexibleInstances #-}
3 | {-# LANGUAGE TypeApplications #-}
4 | {-# LANGUAGE ScopedTypeVariables #-}
5 | {-# OPTIONS_GHC -fno-warn-orphans #-}
6 | --------------------------------------------------------------------------------
7 | -- |
8 | -- Module : ArrayFire.Orphans
9 | -- Copyright : David Johnson (c) 2019-2020
10 | -- License : BSD 3
11 | -- Maintainer : David Johnson
12 | -- Stability : Experimental
13 | -- Portability : GHC
14 | --
15 | --------------------------------------------------------------------------------
16 | module ArrayFire.Orphans where
17 |
18 | import Prelude
19 |
20 | import qualified ArrayFire.Arith as A
21 | import qualified ArrayFire.Array as A
22 | import qualified ArrayFire.Algorithm as A
23 | import qualified ArrayFire.Data as A
24 | import ArrayFire.Types
25 | import ArrayFire.Util
26 |
27 | instance (AFType a, Eq a) => Eq (Array a) where -- whole-array equality built from an element-wise comparison
28 |   x == y = A.allTrueAll (A.eqBatched x y False) == (1.0,0.0) -- True only when every element pair compares equal
29 |   x /= y = not (x == y) -- exact complement of '==', so an array is never reported unequal to itself
30 |
31 | instance (Num a, AFType a) => Num (Array a) where -- arithmetic delegates to the ArrayFire wrappers
32 |   (+)           = A.add
33 |   (*)           = A.mul
34 |   abs arr       = A.abs arr
35 |   signum arr    = A.sign arr
36 |   negate arr    = -- computed as (0 - arr) using a zero-filled array with arr's dimensions
37 |     let (d0, d1, d2, d3) = A.getDims arr
38 |     in A.cast (A.constant @a [d0, d1, d2, d3] 0) `A.sub` arr
39 |   (-)           = A.sub
40 |   fromInteger n = A.scalar (fromIntegral n)
41 |
42 | instance Show (Array a) where -- rendering is delegated to 'arrayString'
43 |   show arr = arrayString arr
44 |
45 | instance forall a . (Fractional a, AFType a) => Fractional (Array a) where -- division via A.div
46 |   (/)            = A.div
47 |   fromRational r = A.scalar @a (fromRational r) -- lift the literal into a scalar array
48 |
49 | instance forall a . (Ord a, AFType a, Fractional a) => Floating (Array a) where -- each method delegates to the matching ArrayFire wrapper
50 |   pi    = A.scalar @a 3.141592653589793 -- pi written out to full Double precision
51 |   exp   = A.exp @a
52 |   log   = A.log @a
53 |   sqrt  = A.sqrt @a
54 |   (**)  = A.pow @a
55 |   sin   = A.sin @a
56 |   cos   = A.cos @a
57 |   tan   = A.tan @a
58 |   tanh  = A.tanh @a
59 |   asin  = A.asin @a
60 |   acos  = A.acos @a
61 |   atan  = A.atan @a
62 |   sinh  = A.sinh @a
63 |   cosh  = A.cosh @a
64 |   acosh = A.acosh @a
65 |   atanh = A.atanh @a
66 |   asinh = A.asinh @a
67 |
--------------------------------------------------------------------------------
/src/ArrayFire/Types.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE UndecidableInstances #-}
2 | {-# LANGUAGE TypeOperators #-}
3 | {-# LANGUAGE ScopedTypeVariables #-}
4 | {-# LANGUAGE PolyKinds #-}
5 | {-# LANGUAGE DataKinds #-}
6 | {-# LANGUAGE AllowAmbiguousTypes #-}
7 | {-# LANGUAGE FlexibleInstances #-}
8 | {-# LANGUAGE TypeApplications #-}
9 | {-# LANGUAGE ViewPatterns #-}
10 | {-# LANGUAGE KindSignatures #-}
11 | {-# LANGUAGE RecordWildCards #-}
12 | {-# LANGUAGE GADTs #-}
13 | {-# LANGUAGE TypeFamilies #-}
14 | --------------------------------------------------------------------------------
15 | -- |
16 | -- Module : ArrayFire.Types
17 | -- Copyright : David Johnson (c) 2019-2020
18 | -- License : BSD3
19 | -- Maintainer : David Johnson
20 | -- Stability : Experimental
21 | -- Portability : GHC
22 | --
23 | -- Various Types related to the ArrayFire API
24 | --
25 | --------------------------------------------------------------------------------
26 | module ArrayFire.Types
27 | ( AFException (..)
28 | , AFExceptionType (..)
29 | , Array
30 | , Window
31 | , RandomEngine
32 | , Features
33 | , AFType (..)
34 | , TopK (..)
35 | , Backend (..)
36 | , MatchType (..)
37 | , BinaryOp (..)
38 | , MatProp (..)
39 | , HomographyType (..)
40 | , RandomEngineType (..)
41 | , Cell (..)
42 | , MarkerType (..)
43 | , InterpType (..)
44 | , Connectivity (..)
45 | , CSpace (..)
46 | , YccStd (..)
47 | , MomentType (..)
48 | , CannyThreshold (..)
49 | , FluxFunction (..)
50 | , DiffusionEq (..)
51 | , IterativeDeconvAlgo (..)
52 | , InverseDeconvAlgo (..)
53 | , Seq (..)
54 | , Index (..)
55 | , NormType (..)
56 | , ConvMode (..)
57 | , ConvDomain (..)
58 | , BorderType (..)
59 | , Storage (..)
60 | , AFDType (..)
61 | , AFDtype (..)
62 | , ColorMap (..)
63 | ) where
64 |
65 | import ArrayFire.Exception
66 | import ArrayFire.Internal.Types
67 | import ArrayFire.Internal.Defines
68 |
--------------------------------------------------------------------------------
/src/ArrayFire/Util.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE ScopedTypeVariables #-}
2 | {-# LANGUAGE TypeApplications #-}
3 | {-# LANGUAGE ViewPatterns #-}
4 | --------------------------------------------------------------------------------
5 | -- |
6 | -- Module : ArrayFire.Util
7 | -- Copyright : David Johnson (c) 2019-2020
8 | -- License : BSD 3
9 | -- Maintainer : David Johnson
10 | -- Stability : Experimental
11 | -- Portability : GHC
12 | --
13 | -- Various utilities for working with the ArrayFire C library
14 | --
15 | -- @
16 | -- import qualified ArrayFire as A
17 | -- import Control.Monad
18 | --
19 | -- main :: IO ()
20 | -- main = do
21 | -- let arr = A.constant [1,1,1,1] 10
22 | -- idx <- A.saveArray "key" arr "file.array" False
23 | -- foundIndex <- A.readArrayKeyCheck "file.array" "key"
24 | -- when (idx == foundIndex) $ do
25 | -- array <- A.readArrayKey "file.array" "key"
26 | -- print array
27 | -- @
28 | -- @
29 | -- ArrayFire Array
30 | -- [ 1 1 1 1 ]
31 | -- 10
32 | -- @
33 | --------------------------------------------------------------------------------
34 | module ArrayFire.Util where
35 |
36 | import Control.Exception
37 |
38 | import Data.Proxy
39 | import Foreign.C.String
40 | import Foreign.ForeignPtr
41 | import Foreign.Marshal hiding (void)
42 | import Foreign.Storable
43 | import System.IO.Unsafe
44 |
45 | import ArrayFire.Internal.Types
46 | import ArrayFire.Internal.Util
47 |
48 | import ArrayFire.Exception
49 | import ArrayFire.FFI
50 |
51 | -- | Query the version triple reported by the ArrayFire library.
52 | --
53 | -- @
54 | -- >>> 'print' '=<<' 'getVersion'
55 | -- @
56 | -- @
57 | -- (3,6,4)
58 | -- @
59 | getVersion :: IO (Int,Int,Int)
60 | getVersion =
61 |   alloca $ \majorPtr ->
62 |     alloca $ \minorPtr ->
63 |       alloca $ \patchPtr -> do
64 |         af_get_version majorPtr minorPtr patchPtr >>= throwAFError
65 |         let component ptr = fromIntegral <$> peek ptr
66 |         (,,) <$> component majorPtr
67 |              <*> component minorPtr <*> component patchPtr
68 |
69 | -- | Writes the contents of an 'Array' to stdout.
70 | --
71 | -- @
72 | -- >>> 'printArray' (constant \@'Double' [1] 1)
73 | -- @
74 | -- @
75 | -- ArrayFire Array
76 | -- [ 1 1 1 1 ]
77 | -- 1.0
78 | -- @
79 | printArray
80 |   :: Array a
81 |   -- ^ The 'Array' to print
82 |   -> IO ()
83 | printArray (Array handle) =
84 |   mask_ $ withForeignPtr handle $ \raw ->
85 |     af_print_array raw >>= throwAFError
86 |
87 | -- | Returns the git revision string reported by ArrayFire.
88 | --
89 | -- @
90 | -- >>> 'putStrLn' '=<<' 'getRevision'
91 | -- @
92 | -- @
93 | -- 1b8030c5
94 | -- @
95 | getRevision :: IO String
96 | getRevision = af_get_revision >>= peekCString
97 |
98 | -- | Prints an 'Array' together with a caller-supplied expression, using the given display precision.
99 | --
100 | -- @
101 | -- >>> printArrayGen "test" (constant \@'Double' [1] 1) 2
102 | -- @
103 | -- @
104 | -- ArrayFire Array
105 | -- [ 1 1 1 1 ]
106 | -- 1.00
107 | -- @
108 | printArrayGen
109 |   :: String
110 |   -- ^ Expression or name to show for the array
111 |   -> Array a
112 |   -- ^ The 'Array' to print
113 |   -> Int
114 |   -- ^ Display precision
115 |   -> IO ()
116 | printArrayGen label (Array handle) (fromIntegral -> digits) =
117 |   mask_ $ withForeignPtr handle $ \raw ->
118 |     withCString label $ \clabel ->
119 |       af_print_array_gen clabel raw digits >>= throwAFError
120 |
121 | -- | Writes an 'Array' to a binary file on disk.
122 | --
123 | -- The array is stored under a key inside the file.
124 | -- 'saveArray' is the writing half of the save/read pair; 'readArrayIndex' and 'readArrayKey' load the data back.
125 | --
126 | --
127 | -- @
128 | -- >>> saveArray "my array" (constant \@'Double' [1] 1) "array.file" 'True'
129 | -- @
130 | -- @
131 | -- 0
132 | -- @
133 | saveArray
134 |   :: String
135 |   -- ^ Tag/key under which the 'Array' is stored, used later to read it back
136 |   -> Array a
137 |   -- ^ The 'Array' to save
138 |   -> FilePath
139 |   -- ^ File the 'Array' is written to
140 |   -> Bool
141 |   -- ^ 'True' appends to an existing file; 'False' creates the file or overwrites an existing one
142 |   -> IO Int
143 |   -- ^ Index location of the 'Array' within the file
144 | saveArray tag (Array handle) path (fromIntegral . fromEnum -> appendFlag) =
145 |   mask_ $
146 |     withForeignPtr handle $ \raw ->
147 |       alloca $ \idxOut -> do
148 |         withCString tag $ \ctag ->
149 |           withCString path $ \cpath ->
150 |             af_save_array idxOut ctag raw cpath appendFlag
151 |               >>= throwAFError
152 |         fromIntegral <$> peek idxOut
153 |
154 | -- | Loads an 'Array' from a file by its index.
155 | --
156 | -- This is the reading counterpart of 'saveArray', which returns the index an 'Array' was stored at.
157 | --
158 | --
159 | -- @
160 | -- >>> readArrayIndex "array.file" 0
161 | -- @
162 | -- @
163 | -- ArrayFire Array
164 | -- [ 1 1 1 1 ]
165 | -- 10.0000
166 | -- @
167 | readArrayIndex
168 |   :: FilePath
169 |   -- ^ File the 'Array' is stored in
170 |   -> Int
171 |   -- ^ Index of the 'Array' within the file
172 |   -> IO (Array a)
173 | readArrayIndex path (fromIntegral -> position) =
174 |   withCString path $ \cpath ->
175 |     createArray' $ \out -> af_read_array_index out cpath position
176 |
177 | -- | Loads an 'Array' from a file by the key it was saved under.
178 | --
179 | -- @
180 | -- >>> readArrayKey "array.file" "my array"
181 | -- @
182 | -- @
183 | -- ArrayFire 'Array'
184 | -- [ 1 1 1 1 ]
185 | -- 10.0000
186 | -- @
187 | readArrayKey
188 |   :: FilePath
189 |   -- ^ File the 'Array' is stored in
190 |   -> String
191 |   -- ^ Key the 'Array' was saved under
192 |   -> IO (Array a)
193 |   -- ^ The 'Array' read from disk
194 | readArrayKey path tag =
195 |   withCString path $ \cpath ->
196 |     withCString tag $ \ctag ->
197 |       createArray' $ \out -> af_read_array_key out cpath ctag
198 |
199 | -- | Checks whether a key exists in the given array file.
200 | --
201 | -- Before reading by key it can be useful to call this first, then load the 'Array' with 'readArrayIndex' using the result.
202 | --
203 | --
204 | -- @
205 | -- >>> readArrayKeyCheck "array.file" "my array"
206 | -- @
207 | -- @
208 | -- 0
209 | -- @
210 | readArrayKeyCheck
211 |   :: FilePath
212 |   -- ^ Path to file
213 |   -> String
214 |   -- ^ Tag/name of the array to look for; the key must match exactly
215 |   -> IO Int
216 |   -- ^ Index reported by ArrayFire for the key (the ArrayFire docs describe -1 for a missing key; confirm against the linked version)
217 | readArrayKeyCheck path tag =
218 |   withCString path $ \cpath ->
219 |     withCString tag $ \ctag -> do
220 |       found <- afCall1 $ \out -> af_read_array_key_check out cpath ctag
221 |       pure (fromIntegral found)
222 |
223 | -- | Renders an 'Array' as a 'String' with precision 4 and the transpose flag set; backs the 'Show' instance.
224 | --
225 | -- @
226 | -- >>> 'putStrLn' '$' 'arrayString' (constant \@'Double' [1,1,1,1] 10)
227 | -- @
228 | -- @
229 | -- ArrayFire 'Array'
230 | -- [ 1 1 1 1 ]
231 | -- 10.0000
232 | -- @
233 | arrayString
234 |   :: Array a
235 |   -- ^ The 'Array' to render
236 |   -> String
237 |   -- ^ Textual form of the 'Array'
238 | arrayString arr = arrayToString "ArrayFire Array" arr 4 True
239 |
240 | -- | Convert an ArrayFire 'Array' to a 'String'.
241 | -- The C string produced by ArrayFire is copied into a Haskell 'String' and then released, so repeated calls do not leak.
242 | -- @
243 | -- >>> 'putStrLn' '$' 'arrayToString' "ArrayFire Array" (constant \@'Double' [1,1,1,1] 10) 4 'False'
244 | -- @
245 | -- @
246 | -- ArrayFire 'Array'
247 | -- [ 1 1 1 1 ]
248 | -- 10.0000
249 | -- @
250 | arrayToString
251 |   :: String
252 |   -- ^ Name of 'Array'
253 |   -> Array a
254 |   -- ^ 'Array' input
255 |   -> Int
256 |   -- ^ Precision of 'Array' values.
257 |   -> Bool
258 |   -- ^ If 'True', takes the transpose before rendering to 'String'
259 |   -> String
260 |   -- ^ 'Array' rendered to 'String'
261 | arrayToString expr (Array fptr) (fromIntegral -> prec) (fromIntegral . fromEnum -> trans) =
262 |   unsafePerformIO . mask_ . withForeignPtr fptr $ \aptr ->
263 |     withCString expr $ \expCstr ->
264 |       alloca $ \ocstr -> do
265 |         throwAFError =<< af_array_to_string ocstr expCstr aptr prec trans
266 |         peek ocstr >>= \out -> peekCString out `finally` free out -- NOTE(review): ArrayFire names af_free_host as the deallocator; confirm 'free' is equivalent on the target platform
267 |
268 | -- | Returns the size ArrayFire reports for the data type mirrored by a Haskell type.
269 | --
270 | -- @
271 | -- >>> 'getSizeOf' ('Proxy' \@ 'Double')
272 | -- @
273 | -- @
274 | -- 8
275 | -- @
276 | getSizeOf
277 |   :: forall a . AFType a
278 |   => Proxy a
279 |   -- ^ Proxy selecting the Haskell type whose ArrayFire counterpart is measured
280 |   -> Int
281 |   -- ^ Size reported for that ArrayFire type
282 | getSizeOf witness =
283 |   unsafePerformIO $ mask_ $ alloca $ \sizeOut -> do
284 |     af_get_size_of sizeOut (afType witness) >>= throwAFError
285 |     fromIntegral <$> peek sizeOut
286 |
--------------------------------------------------------------------------------
/stack.yaml:
--------------------------------------------------------------------------------
1 | resolver:
2 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/19/25.yaml
3 |
4 | packages:
5 | - .
6 |
--------------------------------------------------------------------------------
/stack.yaml.lock:
--------------------------------------------------------------------------------
1 | # This file was autogenerated by Stack.
2 | # You should not edit this file by hand.
3 | # For more information, please see the documentation at:
4 | # https://docs.haskellstack.org/en/stable/lock_files
5 |
6 | packages: []
7 | snapshots:
8 | - completed:
9 | size: 619403
10 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/19/25.yaml
11 | sha256: 1ecad1f0bd2c27de88dbff6572446cfdf647c615d58a7e2e2085c6b7dfc04176
12 | original:
13 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/19/25.yaml
14 |
--------------------------------------------------------------------------------
/test/ArrayFire/AlgorithmSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeApplications #-}
2 | module ArrayFire.AlgorithmSpec where
3 |
4 | import qualified ArrayFire as A
5 |
6 | import Test.Hspec
7 |
8 | spec :: Spec -- hspec suite for the reductions: sum, product, min, allTrue, anyTrue, count and their NaN / whole-array variants
9 | spec =
10 |   describe "Algorithm tests" $ do
11 |     it "Should sum a scalar" $ do
12 |       A.sum (A.scalar @Int 10) 0 `shouldBe` 10
13 |       A.sum (A.scalar @A.Int64 10) 0 `shouldBe` 10
14 |       A.sum (A.scalar @A.Int32 10) 0 `shouldBe` 10
15 |       A.sum (A.scalar @A.Int16 10) 0 `shouldBe` 10
16 |       A.sum (A.scalar @Float 10) 0 `shouldBe` 10
17 |       A.sum (A.scalar @A.Word32 10) 0 `shouldBe` 10
18 |       A.sum (A.scalar @A.Word64 10) 0 `shouldBe` 10
19 |       A.sum (A.scalar @Double 10) 0 `shouldBe` 10.0
20 |       A.sum (A.scalar @(A.Complex Double) (1 A.:+ 1)) 0 `shouldBe` A.scalar (1 A.:+ 1)
21 |       A.sum (A.scalar @(A.Complex Float) (1 A.:+ 1)) 0 `shouldBe` A.scalar (1 A.:+ 1)
22 |       A.sum (A.scalar @A.CBool 1) 0 `shouldBe` 1
23 |       A.sum (A.scalar @A.CBool 0) 0 `shouldBe` 0
24 |     it "Should sum a vector" $ do
25 |       A.sum (A.vector @Int 10 [1..]) 0 `shouldBe` 55
26 |       A.sum (A.vector @A.Int64 10 [1..]) 0 `shouldBe` 55
27 |       A.sum (A.vector @A.Int32 10 [1..]) 0 `shouldBe` 55
28 |       A.sum (A.vector @A.Int16 10 [1..]) 0 `shouldBe` 55
29 |       A.sum (A.vector @Float 10 [1..]) 0 `shouldBe` 55
30 |       A.sum (A.vector @A.Word32 10 [1..]) 0 `shouldBe` 55
31 |       A.sum (A.vector @A.Word64 10 [1..]) 0 `shouldBe` 55
32 |       A.sum (A.vector @Double 10 [1..]) 0 `shouldBe` 55.0
33 |       A.sum (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (10.0 A.:+ 10.0)
34 |       A.sum (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (10.0 A.:+ 10.0)
35 |       A.sum (A.vector @A.CBool 10 (repeat 1)) 0 `shouldBe` 10
36 |       A.sum (A.vector @A.CBool 10 (repeat 0)) 0 `shouldBe` 0
37 |     it "Should sum a default value to replace NaN" $ do
38 |       A.sumNaN (A.vector @Float 10 [1..]) 0 1.0 `shouldBe` 55
39 |       A.sumNaN (A.vector @Double 2 [acos 2, acos 2]) 0 50 `shouldBe` 100 -- acos 2 is NaN, so both elements take the default 50
40 |       A.sumNaN (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` A.scalar (10.0 A.:+ 10.0)
41 |       A.sumNaN (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` A.scalar (10.0 A.:+ 10.0)
42 |     it "Should product a scalar" $ do
43 |       A.product (A.scalar @Int 10) 0 `shouldBe` 10
44 |       A.product (A.scalar @A.Int64 10) 0 `shouldBe` 10
45 |       A.product (A.scalar @A.Int32 10) 0 `shouldBe` 10
46 |       A.product (A.scalar @A.Int16 10) 0 `shouldBe` 10
47 |       A.product (A.scalar @Float 10) 0 `shouldBe` 10
48 |       A.product (A.scalar @A.Word32 10) 0 `shouldBe` 10
49 |       A.product (A.scalar @A.Word64 10) 0 `shouldBe` 10
50 |       A.product (A.scalar @Double 10) 0 `shouldBe` 10.0
51 |       A.product (A.scalar @(A.Complex Double) (1 A.:+ 1)) 0 `shouldBe` A.scalar (1 A.:+ 1)
52 |       A.product (A.scalar @(A.Complex Float) (1 A.:+ 1)) 0 `shouldBe` A.scalar (1 A.:+ 1)
53 |       A.product (A.scalar @A.CBool 1) 0 `shouldBe` 1
54 |       A.product (A.scalar @A.CBool 0) 0 `shouldBe` 0
55 |     it "Should product a vector" $ do
56 |       A.product (A.vector @Int 10 [1..]) 0 `shouldBe` 3628800
57 |       A.product (A.vector @A.Int64 10 [1..]) 0 `shouldBe` 3628800
58 |       A.product (A.vector @A.Int32 10 [1..]) 0 `shouldBe` 3628800
59 |       A.product (A.vector @A.Int16 5 [1..]) 0 `shouldBe` 120
60 |       A.product (A.vector @Float 10 [1..]) 0 `shouldBe` 3628800
61 |       A.product (A.vector @A.Word32 10 [1..]) 0 `shouldBe` 3628800
62 |       A.product (A.vector @A.Word64 10 [1..]) 0 `shouldBe` 3628800
63 |       A.product (A.vector @Double 10 [1..]) 0 `shouldBe` 3628800.0
64 |       A.product (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (0.0 A.:+ 32.0)
65 |       A.product (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (0.0 A.:+ 32.0)
66 |       A.product (A.vector @A.CBool 10 (repeat 1)) 0 `shouldBe` 1 -- FIXED in 3.8.2, vector product along 0-axis is 1 for vector size 10 of all 1's.
67 |       A.product (A.vector @A.CBool 10 (repeat 0)) 0 `shouldBe` 0
68 |     it "Should product a default value to replace NaN" $ do
69 |       A.productNaN (A.vector @Float 10 [1..]) 0 1.0 `shouldBe` 3628800.0
70 |       A.productNaN (A.vector @Double 2 [acos 2, acos 2]) 0 50 `shouldBe` 2500
71 |       A.productNaN (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` A.scalar (0.0 A.:+ 32)
72 |       A.productNaN (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` A.scalar (0 A.:+ 32)
73 |     it "Should take the minimum element of a vector" $ do
74 |       A.min (A.vector @Int 10 [1..]) 0 `shouldBe` 1
75 |       A.min (A.vector @A.Int64 10 [1..]) 0 `shouldBe` 1
76 |       A.min (A.vector @A.Int32 10 [1..]) 0 `shouldBe` 1
77 |       A.min (A.vector @A.Int16 10 [1..]) 0 `shouldBe` 1
78 |       A.min (A.vector @Float 10 [1..]) 0 `shouldBe` 1
79 |       A.min (A.vector @A.Word32 10 [1..]) 0 `shouldBe` 1
80 |       A.min (A.vector @A.Word64 10 [1..]) 0 `shouldBe` 1
81 |       A.min (A.vector @Double 10 [1..]) 0 `shouldBe` 1
82 |       A.min (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (1 A.:+ 1)
83 |       A.min (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (1 A.:+ 1)
84 |       A.min (A.vector @A.CBool 10 [1..]) 0 `shouldBe` 1
85 |       A.min (A.vector @A.CBool 10 [1..]) 0 `shouldBe` 1
86 |     it "Should find if all elements are true along dimension" $ do
87 |       A.allTrue (A.vector @Double 5 (repeat 12.0)) 0 `shouldBe` 1
88 |       A.allTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 1
89 |       A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0
90 |       A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0
91 |     it "Should find if any elements are true along dimension" $ do
92 |       A.anyTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 1
93 |       A.anyTrue (A.vector @Int 5 (repeat 23)) 0 `shouldBe` 1
94 |       A.anyTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0
95 |     it "Should get count of all elements" $ do
96 |       A.count (A.vector @Int 5 (repeat 1)) 0 `shouldBe` 5
97 |       A.count (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 5
98 |       A.count (A.vector @Double 5 (repeat 1)) 0 `shouldBe` 5
99 |       A.count (A.vector @Float 5 (repeat 1)) 0 `shouldBe` 5
100 |     it "Should get sum all elements" $ do
101 |       A.sumAll (A.vector @Int 5 (repeat 2)) `shouldBe` (10,0)
102 |       A.sumAll (A.vector @Double 5 (repeat 2)) `shouldBe` (10.0,0)
103 |       A.sumAll (A.vector @A.CBool 3800 (repeat 1)) `shouldBe` (3800,0)
104 |       A.sumAll (A.vector @(A.Complex Double) 5 (repeat (2 A.:+ 0))) `shouldBe` (10.0,0)
105 |     it "Should sum all elements, substituting a default for NaN" $ do
106 |       A.sumNaNAll (A.vector @Double 2 [10, acos 2]) 1 `shouldBe` (11.0,0)
107 |     it "Should product all elements in an Array" $ do
108 |       A.productAll (A.vector @Int 5 (repeat 2)) `shouldBe` (32,0)
109 |     it "Should product all elements, substituting a default for NaN" $ do
110 |       A.productNaNAll (A.vector @Double 2 [10,acos 2]) 10 `shouldBe` (100,0)
111 |     it "Should find minimum value of an Array" $ do
112 |       A.minAll (A.vector @Int 5 [0..]) `shouldBe` (0,0)
113 |     it "Should find maximum value of an Array" $ do
114 |       A.maxAll (A.vector @Int 5 [0..]) `shouldBe` (4,0)
115 | --  it "Should find if all elements are true" $ do
116 | --    A.allTrue (A.vector @A.CBool 5 (repeat 0)) `shouldBe` False
117 |
118 |
--------------------------------------------------------------------------------
/test/ArrayFire/ArithSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE RankNTypes #-}
2 | {-# LANGUAGE ScopedTypeVariables #-}
3 | {-# LANGUAGE TypeApplications #-}
4 |
5 | module ArrayFire.ArithSpec where
6 |
7 | import ArrayFire (AFType, Array, cast, clamp, getType, isInf, isZero, matrix, maxOf, minOf, mkArray, scalar, vector)
8 | import qualified ArrayFire
9 | import Control.Exception (throwIO)
10 | import Control.Monad (unless, when)
11 | import Foreign.C
12 | import GHC.Exts (IsList (..))
13 | import GHC.Stack
14 | import Test.HUnit.Lang (FailureReason (..), HUnitFailure (..))
15 | import Test.Hspec
16 | import Test.Hspec.QuickCheck
17 | import Prelude hiding (div)
18 |
19 | -- | Fail with an HUnit expected/actual report unless @comparator result expected@ holds.
20 | compareWith :: (HasCallStack, Show a) => (a -> a -> Bool) -> a -> a -> Expectation
21 | compareWith comparator result expected
22 |   | comparator result expected = pure ()
23 |   | otherwise = throwIO (HUnitFailure origin (ExpectedButGot Nothing (show expected) (show result)))
24 |   where
25 |     -- The failure location is the final entry of the captured call-stack list, if any.
26 |     origin = case reverse (getCallStack callStack) of
27 |       (_, loc) : _ -> Just loc
28 |       [] -> Nothing
29 |
30 | class (Num a) => HasEpsilon a where -- numeric types that supply a tolerance unit for approximate comparison
31 | eps :: a -- tolerance unit; 'approx' multiplies it to build its relative and absolute bounds
32 |
33 | instance HasEpsilon Float where
34 | eps = 1.1920929e-7 -- numerically 2^-23
35 |
36 | instance HasEpsilon Double where
37 | eps = 2.220446049250313e-16 -- numerically 2^-52
38 |
39 | approxWith :: (Ord a, Num a) => a -> a -> a -> a -> Bool -- True when |a - b| <= max atol (rtol * max |a| |b|)
40 | approxWith rtol atol a b = gap <= bound where { gap = abs (a - b); bound = Prelude.max atol (rtol * Prelude.max (abs a) (abs b)) }
41 |
42 | approx :: (Ord a, HasEpsilon a) => a -> a -> Bool -- approximate equality using tolerances derived from 'eps'
43 | approx a b = approxWith (2 * eps * Prelude.max (abs a) (abs b)) (4 * eps) a b -- NOTE(review): the rtol passed here is already scaled by max |a| |b| and approxWith scales by it again, so the relative bound grows with the square of the magnitude — confirm this is intended
44 |
45 | shouldBeApprox :: (Ord a, HasEpsilon a, Show a) => a -> a -> Expectation -- expectation that passes when 'approx' accepts the pair
46 | shouldBeApprox actual expected = compareWith approx actual expected
47 |
48 | evalf :: (AFType a) => Array a -> a -- read an array back as one plain value via 'ArrayFire.getScalar'
49 | evalf arr = ArrayFire.getScalar arr
50 |
51 | -- | Check an ArrayFire unary function against a reference function at one input.
52 | shouldMatchBuiltin ::
53 |   (AFType a, Ord a, RealFloat a, HasEpsilon a, Show a) =>
54 |   (Array a -> Array a) -> (a -> a) -> a -> Expectation
55 | shouldMatchBuiltin f f' x =
56 |   if bothInfinite || bothNaN
57 |     then pure () -- non-finite results only need to agree in kind; the sign of an infinity is not compared
58 |     else actual `shouldBeApprox` reference
59 |   where
60 |     actual = evalf (f (scalar x))
61 |     reference = f' x
62 |     bothInfinite = isInfinite actual && isInfinite reference
63 |     bothNaN = Prelude.isNaN actual && Prelude.isNaN reference
64 |
65 | -- | Binary variant of 'shouldMatchBuiltin': the first operand is fixed, the second is the tested input.
66 | shouldMatchBuiltin2 ::
67 |   (AFType a, Ord a, RealFloat a, HasEpsilon a, Show a) =>
68 |   (Array a -> Array a -> Array a) -> (a -> a -> a) -> a -> a -> Expectation
69 | shouldMatchBuiltin2 f f' a b = shouldMatchBuiltin arrayOp builtinOp b
70 |   where
71 |     arrayOp = f (scalar a)
72 |     builtinOp = f' a
73 |
74 | spec :: Spec
75 | spec =
76 | describe "Arith tests" $ do
77 | it "Should negate scalar value" $ do
78 | negate (scalar @Int 1) `shouldBe` (-1)
79 | it "Should negate a vector" $ do
80 | negate (vector @Int 3 [2, 2, 2]) `shouldBe` vector @Int 3 [-2, -2, -2]
81 | it "Should add two scalar arrays" $ do
82 | scalar @Int 1 + 2 `shouldBe` 3
83 | it "Should add two scalar bool arrays" $ do
84 | scalar @CBool 1 + 0 `shouldBe` 1
85 | it "Should subtract two scalar arrays" $ do
86 | scalar @Int 4 - 2 `shouldBe` 2
87 | it "Should multiply two scalar arrays" $ do
88 | scalar @Double 4 `ArrayFire.mul` 2 `shouldBe` 8
89 | it "Should divide two scalar arrays" $ do
90 | ArrayFire.div @Double 8 2 `shouldBe` 4
91 | it "Should add two matrices" $ do
92 | matrix @Int (2, 2) [[1, 1], [1, 1]] + matrix @Int (2, 2) [[1, 1], [1, 1]]
93 | `shouldBe` matrix @Int (2, 2) [[2, 2], [2, 2]]
94 | prop "Should take cubed root" $ \(x :: Double) ->
95 | evalf (ArrayFire.cbrt (scalar (x * x * x))) `shouldBeApprox` x
96 |
97 | it "Should lte Array" $ do
98 | 2 `ArrayFire.le` (3 :: Array Double) `shouldBe` 1
99 | it "Should gte Array" $ do
100 | 2 `ArrayFire.ge` (3 :: Array Double) `shouldBe` 0
101 | it "Should gt Array" $ do
102 | 2 `ArrayFire.gt` (3 :: Array Double) `shouldBe` 0
103 |     it "Should lt Array" $ do
104 |       2 `ArrayFire.lt` (3 :: Array Double) `shouldBe` 1 -- strict less-than: 2 < 3 holds
105 | it "Should eq Array" $ do
106 | 3 == (3 :: Array Double) `shouldBe` True
107 | it "Should and Array" $ do
108 | (mkArray @CBool [1] [0] `ArrayFire.and` mkArray [1] [1])
109 | `shouldBe` mkArray [1] [0]
110 | it "Should and Array" $ do
111 | (mkArray @CBool [2] [0, 0] `ArrayFire.and` mkArray [2] [1, 0])
112 | `shouldBe` mkArray [2] [0, 0]
113 | it "Should or Array" $ do
114 | (mkArray @CBool [2] [0, 0] `ArrayFire.or` mkArray [2] [1, 0])
115 | `shouldBe` mkArray [2] [1, 0]
116 | it "Should not Array" $ do
117 | ArrayFire.not (mkArray @CBool [2] [1, 0]) `shouldBe` mkArray [2] [0, 1]
118 | it "Should bitwise and array" $ do
119 | ArrayFire.bitAnd (scalar @Int 1) (scalar @Int 0)
120 | `shouldBe` 0
121 | it "Should bitwise or array" $ do
122 | ArrayFire.bitOr (scalar @Int 1) (scalar @Int 0)
123 | `shouldBe` 1
124 | it "Should bitwise xor array" $ do
125 | ArrayFire.bitXor (scalar @Int 1) (scalar @Int 1)
126 | `shouldBe` 0
127 | it "Should bitwise shift left an array" $ do
128 | ArrayFire.bitShiftL (scalar @Int 1) (scalar @Int 3)
129 | `shouldBe` 8
130 | it "Should cast an array" $ do
131 | getType (cast (scalar @Int 1) :: Array Double)
132 | `shouldBe` ArrayFire.F64
133 | it "Should find the minimum of two arrays" $ do
134 | minOf (scalar @Int 1) (scalar @Int 0)
135 | `shouldBe` 0
136 | it "Should find the max of two arrays" $ do
137 | maxOf (scalar @Int 1) (scalar @Int 0)
138 | `shouldBe` 1
139 | it "Should take the clamp of 3 arrays" $ do
140 | clamp (scalar @Int 2) (scalar @Int 1) (scalar @Int 3)
141 | `shouldBe` 2
142 | it "Should check if an array has positive or negative infinities" $ do
143 | isInf (scalar @Double (1 / 0)) `shouldBe` scalar @Double 1
144 | isInf (scalar @Double 10) `shouldBe` scalar @Double 0
145 | it "Should check if an array has any NaN values" $ do
146 | ArrayFire.isNaN (scalar @Double (acos 2)) `shouldBe` scalar @Double 1
147 | ArrayFire.isNaN (scalar @Double 10) `shouldBe` scalar @Double 0
148 | it "Should check if an array has any Zero values" $ do
149 | isZero (scalar @Double (acos 2)) `shouldBe` scalar @Double 0
150 | isZero (scalar @Double 0) `shouldBe` scalar @Double 1
151 | isZero (scalar @Double 1) `shouldBe` scalar @Double 0
152 |
153 | prop "Floating @Float (exp)" $ \(x :: Float) -> exp `shouldMatchBuiltin` exp $ x
154 | prop "Floating @Float (log)" $ \(x :: Float) -> log `shouldMatchBuiltin` log $ x
155 | prop "Floating @Float (sqrt)" $ \(x :: Float) -> sqrt `shouldMatchBuiltin` sqrt $ x
156 | prop "Floating @Float (**)" $ \(x :: Float) (y :: Float) -> ((**) `shouldMatchBuiltin2` (**)) x y
157 | prop "Floating @Float (sin)" $ \(x :: Float) -> sin `shouldMatchBuiltin` sin $ x
158 | prop "Floating @Float (cos)" $ \(x :: Float) -> cos `shouldMatchBuiltin` cos $ x
159 | prop "Floating @Float (tan)" $ \(x :: Float) -> tan `shouldMatchBuiltin` tan $ x
160 | prop "Floating @Float (asin)" $ \(x :: Float) -> asin `shouldMatchBuiltin` asin $ x
161 | prop "Floating @Float (acos)" $ \(x :: Float) -> acos `shouldMatchBuiltin` acos $ x
162 | prop "Floating @Float (atan)" $ \(x :: Float) -> atan `shouldMatchBuiltin` atan $ x
163 | prop "Floating @Float (sinh)" $ \(x :: Float) -> sinh `shouldMatchBuiltin` sinh $ x
164 | prop "Floating @Float (cosh)" $ \(x :: Float) -> cosh `shouldMatchBuiltin` cosh $ x
165 | prop "Floating @Float (tanh)" $ \(x :: Float) -> tanh `shouldMatchBuiltin` tanh $ x
166 | prop "Floating @Float (asinh)" $ \(x :: Float) -> asinh `shouldMatchBuiltin` asinh $ x
167 | prop "Floating @Float (acosh)" $ \(x :: Float) -> acosh `shouldMatchBuiltin` acosh $ x
168 | prop "Floating @Float (atanh)" $ \(x :: Float) -> atanh `shouldMatchBuiltin` atanh $ x
169 |
--------------------------------------------------------------------------------
/test/ArrayFire/ArraySpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE ScopedTypeVariables #-}
2 | {-# LANGUAGE TypeApplications #-}
3 | module ArrayFire.ArraySpec where
4 |
5 | import Control.Exception
6 | import Data.Complex
7 | import Data.Word
8 | import Foreign.C.Types
9 | import GHC.Int
10 | import Test.Hspec
11 |
12 | import ArrayFire
13 |
14 | spec :: Spec
15 | spec =
16 | describe "Array tests" $ do
17 | it "Should perform Array tests" $ do
18 | (1 + 1) `shouldBe` 2
19 | it "Should fail to create 0 dimension arrays" $ do
20 | let arr = mkArray @Int [0,0,0,0] [1..]
21 | evaluate arr `shouldThrow` anyException
22 | it "Should fail to create 0 length arrays" $ do
23 | let arr = mkArray @Int [0,0,0,1] []
24 | evaluate arr `shouldThrow` anyException
25 | it "Should fail to create 0 length arrays w/ 0 dimensions" $ do
26 | let arr = mkArray @Int [0,0,0,0] []
27 | evaluate arr `shouldThrow` anyException
28 | it "Should create a column vector" $ do
29 | let arr = mkArray @Int [9,1,1,1] (repeat 9)
30 | isColumn arr `shouldBe` True
31 | it "Should create a row vector" $ do
32 | let arr = mkArray @Int [1,9,1,1] (repeat 9)
33 | isRow arr `shouldBe` True
34 | it "Should create a vector" $ do
35 | let arr = mkArray @Int [9,1,1,1] (repeat 9)
36 | isVector arr `shouldBe` True
37 | it "Should create a vector" $ do
38 | let arr = mkArray @Int [1,9,1,1] (repeat 9)
39 | isVector arr `shouldBe` True
40 | it "Should copy an array" $ do
41 | let arr = mkArray @Int [9,9,1,1] (repeat 9)
42 | let newArray = copyArray arr
43 | newArray `shouldBe` arr
44 | it "Should modify manual eval flag" $ do
45 | setManualEvalFlag False
46 | (`shouldBe` False) =<< getManualEvalFlag
47 | it "Should return the number of elements" $ do
48 | let arr = mkArray @Int [9,9,1,1] [1..]
49 | getElements arr `shouldBe` 81
50 | -- it "Should give an empty array" $ do
51 | -- let arr = mkArray @Int [-1,1,1,1] []
52 | -- getElements arr `shouldBe` 0
53 | -- isEmpty arr `shouldBe` True
54 | it "Should create a scalar array" $ do
55 | let arr = mkArray @Int [1] [1]
56 | isScalar arr `shouldBe` True
57 | it "Should get number of dims specified" $ do
58 | let arr = mkArray @Int [1,1,1,1] [1]
59 | getNumDims arr `shouldBe` 1
60 | let arr = mkArray @Int [2,3,4,5] [1..]
61 | getNumDims arr `shouldBe` 4
62 | let arr = mkArray @Int [2,3,4] [1..]
63 | getNumDims arr `shouldBe` 3
64 | it "Should get value of dims specified" $ do
65 | let arr = mkArray @Int [2,3,4,5] (repeat 1)
66 | getDims arr `shouldBe` (2,3,4,5)
67 |
68 | it "Should test Sparsity" $ do
69 | let arr = mkArray @Double [2,2,1,1] (repeat 1)
70 | isSparse arr `shouldBe` False
71 |
72 | it "Should make a Bit array" $ do
73 | let arr = mkArray @CBool [2,2] [1,1,1,1]
74 | isBool arr `shouldBe` True
75 |
76 | it "Should make an integer array" $ do
77 | let arr = mkArray @Int [2,2] (repeat 1)
78 | isInteger arr `shouldBe` True
79 |
80 | it "Should make a Floating array" $ do
81 | let arr = mkArray @Double [2,2] (repeat 1)
82 | isFloating arr `shouldBe` True
83 | let arr = mkArray @CBool [2,2] (repeat 1)
84 | isFloating arr `shouldBe` False
85 |
86 | it "Should make a Complex array" $ do
87 | let arr = mkArray @(Complex Double) [2,2] (repeat 1)
88 | isComplex arr `shouldBe` True
89 | isReal arr `shouldBe` False
90 |
91 | it "Should make a Real array" $ do
92 | let arr = mkArray @Double [2,2] (repeat 1)
93 | isReal arr `shouldBe` True
94 | isComplex arr `shouldBe` False
95 |
96 | it "Should make a Double precision array" $ do
97 | let arr = mkArray @Double [2,2] (repeat 1)
98 | isDouble arr `shouldBe` True
99 | isSingle arr `shouldBe` False
100 |
101 | it "Should make a Single precision array" $ do
102 | let arr = mkArray @Float [2,2] (repeat 1)
103 | isDouble arr `shouldBe` False
104 | isSingle arr `shouldBe` True
105 |
106 | it "Should make a Real floating array" $ do
107 | let arr = mkArray @Float [2,2] (repeat 1)
108 | isRealFloating arr `shouldBe` True
109 | let arr = mkArray @Double [2,2] (repeat 1)
110 | isRealFloating arr `shouldBe` True
111 |
112 | it "Should get reference count" $ do
113 | let arr1 = mkArray @Float [2,2] (repeat 1)
114 | arr2 = retainArray arr1
115 | arr3 = retainArray arr2
116 | getDataRefCount arr3 `shouldBe` 3
117 |
118 | it "Should convert an array to a list" $ do
119 | let arr = mkArray @Double [30,30] (repeat 1)
120 | toList arr `shouldBe` Prelude.replicate (30 * 30) 1
121 |
122 | let arr = mkArray @Float [10,10] (repeat (5.5))
123 | toList arr `shouldBe` Prelude.replicate 100 5.5
124 |
125 | let arr = mkArray @CBool [4] [1,1,0,1]
126 | toList arr `shouldBe` [1,1,0,1]
127 |
128 | let arr = mkArray @Int16 [10] [1..]
129 | toList arr `shouldBe` [1..10]
130 |
131 | let arr = mkArray @Int32 [100] [1..100]
132 | toList arr `shouldBe` [1..100]
133 |
134 | let arr = mkArray @Int64 [100] [1..100]
135 | toList arr `shouldBe` [1..100]
136 |
137 | let arr = mkArray @Int [100] [1..100]
138 | toList arr `shouldBe` [1..100]
139 |
140 | let arr = mkArray @(Complex Float) [1] [1 :+ 1]
141 | toList arr `shouldBe` [1 :+ 1]
142 |
143 | let arr = mkArray @(Complex Double) [1] [1 :+ 1]
144 | toList arr `shouldBe` [1 :+ 1]
145 |
146 | let arr = mkArray @Word16 [10] [1..10]
147 | toList arr `shouldBe` [1..10]
148 |
149 | let arr = mkArray @Word32 [10] [1..10]
150 | toList arr `shouldBe` [1..10]
151 |
152 | let arr = mkArray @Word64 [10] [1..10]
153 | toList arr `shouldBe` [1..10]
154 |
155 | let arr = mkArray @Word [10] [1..10]
156 | toList arr `shouldBe` [1..10]
157 |
--------------------------------------------------------------------------------
/test/ArrayFire/BLASSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeApplications #-}
2 | module ArrayFire.BLASSpec where
3 |
4 | import ArrayFire hiding (not)
5 |
6 | import Data.Complex
7 | import Test.Hspec
8 |
9 | spec :: Spec
10 | spec =
11 | describe "BLAS spec" $ do
12 | it "Should matmul two matrices" $ do
13 | (matrix @Double (2,2) [[2,2],[2,2]] `matmul` matrix @Double (2,2) [[2,2],[2,2]]) None None
14 | `shouldBe` matrix @Double (2,2) [[8,8],[8,8]]
15 | it "Should dot product two vectors" $ do
16 | dot (vector @Double 2 (repeat 2)) (vector @Double 2 (repeat 2)) None None
17 | `shouldBe`
18 | scalar @Double 8
19 | it "Should produce scalar dot product between two vectors as a Complex number" $ do
20 | dotAll (vector @Double 2 (repeat 2)) (vector @Double 2 (repeat 2)) None None
21 | `shouldBe`
22 | 8.0 :+ 0.0
23 | it "Should take the transpose of a matrix" $ do
24 | transpose (matrix @Double (2,2) [[1,1],[2,2]]) False
25 | `shouldBe`
26 | matrix @Double (2,2) [[1,2],[1,2]]
27 | it "Should take the transpose of a matrix in place" $ do
28 | let m = matrix @Double (2,2) [[1,1],[2,2]]
29 | transposeInPlace m False
30 | m `shouldBe` matrix @Double (2,2) [[1,2],[1,2]]
31 |
32 |
33 |
34 |
35 |
36 |
--------------------------------------------------------------------------------
/test/ArrayFire/BackendSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeApplications #-}
2 | module ArrayFire.BackendSpec where
3 |
4 | import ArrayFire hiding (not)
5 |
6 | import Test.Hspec
7 |
8 | spec :: Spec
9 | spec =
10 | describe "Backend spec" $ do
11 | it "Should get backend count" $ do
12 | (`shouldSatisfy` (>0)) =<< getBackendCount
13 | it "Should get available backends" $ do
14 | backends <- getAvailableBackends
15 | backends `shouldSatisfy` (CPU `elem`)
16 | it "Should set backend to CPU" $ do
17 | backend <- getActiveBackend
18 | setBackend backend
19 | (`shouldBe` backend) =<< getActiveBackend
20 | let arr = matrix @Int (2,2) [[1,1],[1,1]]
21 | getBackend arr `shouldBe` backend
22 |
--------------------------------------------------------------------------------
/test/ArrayFire/DataSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE ScopedTypeVariables #-}
2 | {-# LANGUAGE TypeApplications #-}
3 | module ArrayFire.DataSpec where
4 |
5 | import Control.Exception
6 | import Data.Complex
7 | import Data.Word
8 | import Foreign.C.Types
9 | import GHC.Int
10 | import Test.Hspec
11 |
12 | import ArrayFire
13 |
14 | spec :: Spec
15 | spec =
16 | describe "Data tests" $ do
17 | it "Should create constant Array" $ do
18 | constant @Float [1] 1 `shouldBe` 1
19 | constant @Double [1] 1 `shouldBe` 1
20 | constant @Int16 [1] 1 `shouldBe` 1
21 | constant @Int32 [1] 1 `shouldBe` 1
22 | constant @Int64 [1] 1 `shouldBe` 1
23 | constant @Int [1] 1 `shouldBe` 1
24 | constant @Word16 [1] 1 `shouldBe` 1
25 | constant @Word32 [1] 1 `shouldBe` 1
26 | constant @Word64 [1] 1 `shouldBe` 1
27 | constant @Word [1] 1 `shouldBe` 1
28 | constant @CBool [1] 1 `shouldBe` 1
29 | constant @(Complex Double) [1] (1.0 :+ 1.0)
30 | `shouldBe`
31 | constant @(Complex Double) [1] (1.0 :+ 1.0)
32 | constant @(Complex Float) [1] (1.0 :+ 1.0)
33 | `shouldBe`
34 | constant @(Complex Float) [1] (1.0 :+ 1.0)
35 | it "Should join Arrays along the specified dimension" $ do
36 | join 0 (constant @Int [1, 3] 1) (constant @Int [1, 3] 2) `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2]
37 | join 1 (constant @Int [1, 2] 1) (constant @Int [1, 2] 2) `shouldBe` mkArray @Int [1, 4] [1, 1, 2, 2]
38 | joinMany 0 [constant @Int [1, 3] 1, constant @Int [1, 3] 2] `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2]
39 | joinMany 1 [constant @Int [1, 2] 1, constant @Int [1, 1] 2, constant @Int [1, 3] 3] `shouldBe` mkArray @Int [1, 6] [1, 1, 2, 3, 3, 3]
40 |
--------------------------------------------------------------------------------
/test/ArrayFire/DeviceSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeApplications #-}
2 | module ArrayFire.DeviceSpec where
3 |
4 | import qualified ArrayFire as A
5 | import Foreign.C.Types
6 | import Test.Hspec
7 |
8 | -- | Smoke tests for the device bindings: 'A.info', 'A.afInit', 'A.getInfoString',
9 | -- 'A.getDevice' and 'A.setDevice'.
10 | spec :: Spec
11 | spec =
12 |   describe "Device tests" $ do
13 |     it "Should show device info" $ do
14 |       A.info `shouldReturn` ()
15 |     it "Should show device init" $ do
16 |       A.afInit `shouldReturn` ()
17 |     it "Should get info string" $ do
18 |       A.getInfoString >>= (`shouldSatisfy` (not . null))
19 |     it "Should get device" $ do
20 |       A.getDevice >>= (`shouldSatisfy` (>= 0))
21 |     it "Should get and set device" $ do
22 |       (A.getDevice >>= A.setDevice) `shouldReturn` () -- re-selects the device that is already active
--------------------------------------------------------------------------------
/test/ArrayFire/FeaturesSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeApplications #-}
2 | module ArrayFire.FeaturesSpec where
3 |
4 | import ArrayFire hiding (acos)
5 | import Prelude
6 | import Test.Hspec
7 |
8 | -- | Checks that the count given to 'createFeatures' is reported back by 'getFeaturesNum'.
9 | spec :: Spec
10 | spec =
11 |   describe "Features tests" $ do
12 |     it "Should get the number of features" $ do
13 |       let feats = createFeatures 10
14 |       getFeaturesNum feats `shouldBe` 10
--------------------------------------------------------------------------------
/test/ArrayFire/GraphicsSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE ScopedTypeVariables #-}
2 | {-# LANGUAGE TypeApplications #-}
3 | module ArrayFire.GraphicsSpec where
4 |
5 | import Control.Exception
6 | import Data.Complex
7 | import Data.Word
8 | import Foreign.C.Types
9 | import GHC.Int
10 | import Test.Hspec
11 |
12 | import ArrayFire
13 |
14 | spec :: Spec
15 | spec =
16 | describe "Graphics tests" $ do
17 | it "Should create window" $ do
18 | (1 + 1) `shouldBe` 2
19 |
--------------------------------------------------------------------------------
/test/ArrayFire/ImageSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE ScopedTypeVariables #-}
2 | {-# LANGUAGE TypeApplications #-}
3 | module ArrayFire.ImageSpec where
4 |
5 | import Control.Exception
6 | import Data.Complex
7 | import Data.Word
8 | import Foreign.C.Types
9 | import GHC.Int
10 | import Test.Hspec
11 |
12 | import ArrayFire
13 |
14 | spec :: Spec
15 | spec =
16 | describe "Image tests" $ do
17 | it "Should test if Image I/O is available" $ do
18 | isImageIOAvailable `shouldReturn` True
19 |
--------------------------------------------------------------------------------
/test/ArrayFire/IndexSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE BangPatterns #-}
2 | {-# LANGUAGE TypeApplications #-}
3 | module ArrayFire.IndexSpec where
4 |
5 | import qualified ArrayFire as A
6 | import Control.Exception
7 | import Data.Complex
8 | import Data.Int
9 | import Data.Proxy
10 | import Data.Word
11 | import Foreign.C.Types
12 | import Test.Hspec
13 |
14 | spec :: Spec
15 | spec =
16 | describe "Index spec" $ do
17 | it "Should index into an array" $ do
18 | let arr = A.vector @Int 10 [1..]
19 | A.index arr [A.Seq 0 4 1]
20 | `shouldBe`
21 | A.vector @Int 5 [1..]
22 |
--------------------------------------------------------------------------------
/test/ArrayFire/LAPACKSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeApplications #-}
2 | module ArrayFire.LAPACKSpec where
3 |
4 | import qualified ArrayFire as A
5 | import Prelude
6 | import Test.Hspec
7 | import Test.Hspec.ApproxExpect
8 |
9 | spec :: Spec
10 | spec =
11 | describe "LAPACK spec" $ do
12 | it "Should have LAPACK available" $ do
13 | A.isLAPACKAvailable `shouldBe` True
14 | it "Should perform svd" $ do
15 | let (s,v,d) = A.svd $ A.matrix @Double (4,2) [ [1,2,3,4], [5,6,7,8] ]
16 | A.getDims s `shouldBe` (4,4,1,1)
17 | A.getDims v `shouldBe` (2,1,1,1)
18 | A.getDims d `shouldBe` (2,2,1,1)
19 | it "Should perform svd in place" $ do
20 | let (s,v,d) = A.svdInPlace $ A.matrix @Double (4,2) [ [1,2,3,4], [5,6,7,8] ]
21 | A.getDims s `shouldBe` (4,4,1,1)
22 | A.getDims v `shouldBe` (2,1,1,1)
23 | A.getDims d `shouldBe` (2,2,1,1)
24 | it "Should perform lu" $ do
25 | let (s,v,d) = A.lu $ A.matrix @Double (2,2) [[3,1],[4,2]]
26 | A.getDims s `shouldBe` (2,2,1,1)
27 | A.getDims v `shouldBe` (2,2,1,1)
28 | A.getDims d `shouldBe` (2,1,1,1)
29 |     it "Should perform qr" $ do
30 |       let (q,r,tau) = A.qr $ A.matrix @Double (3,3) [[12,6,4],[-51,167,24],[4,-68,-41]] -- exercise the QR binding itself
31 |       A.getDims q `shouldBe` (3,3,1,1)
32 |       A.getDims r `shouldBe` (3,3,1,1)
33 |       A.getDims tau `shouldBe` (3,1,1,1)
34 | it "Should get determinant of Double" $ do
35 | let eles = [[3 A.:+ 1, 8 A.:+ 1], [4 A.:+ 1, 6 A.:+ 1]]
36 | (x,y) = A.det (A.matrix @(A.Complex Double) (2,2) eles)
37 | x `shouldBeApprox` (-14)
38 | let (x,y) = A.det $ A.matrix @Double (2,2) [[3,8],[4,6]]
39 | x `shouldBeApprox` (-14)
40 | -- it "Should calculate inverse" $ do
41 | -- let x = flip A.inverse A.None $ A.matrix @Double (2,2) [[4.0,7.0],[2.0,6.0]]
42 | -- x `shouldBe` A.matrix (2,2) [[0.6,-0.7],[-0.2,0.4]]
43 | -- it "Should calculate psuedo inverse" $ do
44 | -- let x = A.pinverse (A.matrix @Double (2,2) [[4,7],[2,6]]) 1.0 A.None
45 | -- x `shouldBe` A.matrix @Double (2,2) [[0.6,-0.2],[-0.7,0.4]]
46 |
--------------------------------------------------------------------------------
/test/ArrayFire/RandomSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeApplications #-}
2 | module ArrayFire.RandomSpec where
3 |
4 | import ArrayFire
5 | import Control.Monad
6 |
7 | import Test.Hspec
8 |
9 | spec :: Spec
10 | spec =
11 | describe "Random engine spec" $ do
12 | it "Should create random engine" $ do
13 | (`shouldBe` Philox)
14 | =<< getRandomEngineType
15 | =<< createRandomEngine 5000 Philox
16 | (`shouldBe` Mersenne)
17 | =<< getRandomEngineType
18 | =<< createRandomEngine 5000 Mersenne
19 | (`shouldBe` ThreeFry)
20 | =<< getRandomEngineType
21 | =<< createRandomEngine 5000 ThreeFry
22 | it "Should set random engine" $ do
23 | r <- createRandomEngine 5000 ThreeFry
24 | setRandomEngine r Philox
25 | (`shouldBe` Philox) =<< getRandomEngineType r
26 | it "Should set and get seed" $ do
27 | setSeed 100
28 | (`shouldBe` 100) =<< getSeed
29 |
30 |
31 |
--------------------------------------------------------------------------------
/test/ArrayFire/SignalSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeApplications #-}
2 | module ArrayFire.SignalSpec where
3 |
4 | import qualified ArrayFire as A
5 | import Data.Int
6 | import Data.Word
7 | import Data.Complex
8 | import Data.Proxy
9 | import Foreign.C.Types
10 | import Test.Hspec
11 |
12 | spec :: Spec
13 | spec =
14 | describe "Signal spec" $ do
15 | it "Should do FFT in place" $ do
16 | A.fftInPlace (A.matrix @(Complex Double) (1,1) [[1 :+ 1]]) 10.2
17 | `shouldReturn` ()
18 | it "Should do FFT" $ do
19 | A.fft (A.matrix @(Complex Float) (1,1) [[1 :+ 1]]) 1 1
20 | `shouldBe` A.matrix @(Complex Float) (1,1) [[1 :+ 1]]
21 |
--------------------------------------------------------------------------------
/test/ArrayFire/SparseSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeApplications #-}
2 | module ArrayFire.SparseSpec where
3 |
4 | import qualified ArrayFire as A
5 | import Data.Int
6 | import Data.Word
7 | import Data.Complex
8 | import Data.Proxy
9 | import Foreign.C.Types
10 | import Test.Hspec
11 |
12 | spec :: Spec
13 | spec =
14 | describe "Sparse spec" $ do
15 | it "Should create a sparse array" $ do
16 | (1+1) `shouldBe` 2
17 | -- A.createSparseArrayFromDense (A.matrix @Double (10,10) [1..]) A.CSR
18 | -- `shouldBe`
19 | -- A.vector @Double 10 [0..]
20 |
--------------------------------------------------------------------------------
/test/ArrayFire/StatisticsSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeApplications #-}
2 | module ArrayFire.StatisticsSpec where
3 |
4 | import ArrayFire hiding (not)
5 |
6 | import Data.Complex
7 | import Test.Hspec
8 | import Test.Hspec.ApproxExpect
9 |
10 | spec :: Spec
11 | spec =
12 | describe "Statistics spec" $ do
13 | it "Should find the mean" $ do
14 | mean (vector @Double 10 [1..]) 0
15 | `shouldBe`
16 | 5.5
17 | it "Should find the weighted-mean" $ do
18 | meanWeighted (vector @Double 10 [1..]) (vector @Double 10 [1..]) 0
19 | `shouldBeApprox`
20 | 7.0
21 | it "Should find the variance" $ do
22 | var (vector @Double 8 [1..8]) False 0
23 | `shouldBe`
24 | 5.25
25 | it "Should find the weighted variance" $ do
26 | varWeighted (vector @Double 8 [1..]) (vector @Double 8 (repeat 1)) 0
27 | `shouldBe`
28 | 5.25
29 | it "Should find the standard deviation" $ do
30 | stdev (vector @Double 10 (cycle [1,-1])) 0
31 | `shouldBe`
32 | 1.0
33 | it "Should find the covariance" $ do
34 | cov (vector @Double 10 (repeat 1)) (vector @Double 10 (repeat 1)) False
35 | `shouldBe`
36 | 0.0
37 | it "Should find the median" $ do
38 | median (vector @Double 10 [1..]) 0
39 | `shouldBe`
40 | 5.5
41 | it "Should find the mean of all elements across all dimensions" $ do
42 | fst (meanAll (matrix @Double (2,2) [[10,10],[10,10]]))
43 | `shouldBe`
44 | 10
45 | it "Should find the weighted mean of all elements across all dimensions" $ do
46 | fst (meanAllWeighted (matrix @Double (2,2) [[10,10],[10,10]]) (matrix @Double (2,2) [[10,10],[10,10]]))
47 | `shouldBe`
48 | 10
49 | it "Should find the variance of all elements across all dimensions" $ do
50 | fst (varAll (vector @Double 10 (repeat 10)) False)
51 | `shouldBe`
52 | 0
53 | it "Should find the weighted variance of all elements across all dimensions" $ do
54 | fst (varAllWeighted (vector @Double 10 (repeat 10)) (vector @Double 10 (repeat 10)))
55 | `shouldBe`
56 | 0
57 | it "Should find the stdev of all elements across all dimensions" $ do
58 | fst (stdevAll (vector @Double 10 (repeat 10)))
59 | `shouldBe`
60 | 0
61 | it "Should find the median of all elements across all dimensions" $ do
62 | fst (medianAll (vector @Double 10 [1..]))
63 | `shouldBe`
64 | 5.5
65 | it "Should find the correlation coefficient" $ do
66 | fst (corrCoef (vector @Int 10 [1..] ) ( vector @Int 10 [10,9..] ))
67 | `shouldBe`
68 | (-1.0)
69 | it "Should find the top k elements" $ do
70 | let (vals,indexes) = topk ( vector @Double 10 [1..] ) 3 TopKDefault
71 | vals `shouldBe` vector @Double 3 [10,9,8]
72 | indexes `shouldBe` vector @Double 3 [9,8,7]
73 |
--------------------------------------------------------------------------------
/test/ArrayFire/UtilSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeApplications #-}
2 | module ArrayFire.UtilSpec where
3 |
4 | import qualified ArrayFire as A
5 |
6 | import Data.Complex
7 | import Data.Int
8 | import Data.Proxy
9 | import Data.Word
10 | import Foreign.C.Types
11 | import System.Directory
12 | import Test.Hspec
13 |
14 | spec :: Spec
15 | spec =
16 | describe "Util spec" $ do
17 | it "Should get size of" $ do
18 | A.getSizeOf (Proxy @Int) `shouldBe` 8
19 | A.getSizeOf (Proxy @Int64) `shouldBe` 8
20 | A.getSizeOf (Proxy @Int32) `shouldBe` 4
21 | A.getSizeOf (Proxy @Int16) `shouldBe` 2
22 | A.getSizeOf (Proxy @Word) `shouldBe` 8
23 | A.getSizeOf (Proxy @Word64) `shouldBe` 8
24 | A.getSizeOf (Proxy @Word32) `shouldBe` 4
25 | A.getSizeOf (Proxy @Word16) `shouldBe` 2
26 | A.getSizeOf (Proxy @Word8) `shouldBe` 1
27 | A.getSizeOf (Proxy @CBool) `shouldBe` 1
28 | A.getSizeOf (Proxy @Double) `shouldBe` 8
29 | A.getSizeOf (Proxy @Float) `shouldBe` 4
30 | A.getSizeOf (Proxy @(Complex Float)) `shouldBe` 8
31 | A.getSizeOf (Proxy @(Complex Double)) `shouldBe` 16
32 | it "Should get version" $ do
33 | (major, minor, patch) <- A.getVersion
34 | major `shouldBe` 3
35 | minor `shouldSatisfy` (>= 8)
36 | patch `shouldSatisfy` (>= 0)
37 | it "Should get revision" $ do
38 | x <- A.getRevision
39 | x `shouldSatisfy` (not . null)
40 | it "Should save / read array" $ do
41 | let arr = A.constant @Int [1,1,1,1] 10
42 | idx <- A.saveArray "key" arr "file.array" False
43 | doesFileExist "file.array" `shouldReturn` True
44 | (`shouldBe` idx) =<< A.readArrayKeyCheck "file.array" "key"
45 | (`shouldBe` arr) =<< A.readArrayIndex "file.array" idx
46 | (`shouldBe` arr) =<< A.readArrayKey "file.array" "key"
47 | removeFile "file.array"
48 |
49 |
--------------------------------------------------------------------------------
/test/ArrayFire/VisionSpec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeApplications #-}
2 | module ArrayFire.VisionSpec where
3 |
4 | import qualified ArrayFire as A
5 | import Test.Hspec
6 |
-- | Specs for the vision routines exposed by "ArrayFire".
--
-- NOTE(review): the result of 'A.fast' is bound but never referenced, so the
-- lazy binding is not demanded by this example; the only expectation that is
-- actually checked is the arithmetic placeholder. Confirm whether the feature
-- detection call is meant to be asserted on.
spec :: Spec
spec = describe "Vision spec" $
  it "Should construct Features for fast feature detection" $ do
    let input = A.vector @Int 30000 [1..]
        _features = A.fast input 1.0 9 False 1.0 3
    (1 + 1) `shouldBe` 2
14 |
15 |
--------------------------------------------------------------------------------
/test/Main.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeApplications #-}
2 | {-# LANGUAGE ScopedTypeVariables #-}
3 | module Main where
4 |
5 | import Control.Monad
6 |
7 | import Data.Proxy
8 | import Spec (spec)
9 | import Test.Hspec (hspec)
10 | import Test.QuickCheck
11 | import Test.QuickCheck.Classes
12 |
13 | import qualified ArrayFire as A
14 | import ArrayFire (Array)
15 |
16 | import System.IO.Unsafe
17 |
-- | Orphan 'Arbitrary' instance so that 'Array' values can be fed to the
-- QuickCheck law checks in this module.
--
-- The generator ignores QuickCheck's own randomness: it always returns the
-- array obtained from @A.randu [2,2]@. 'unsafePerformIO' is used because
-- 'A.randu' is an 'IO' action and the generator is built with 'return' from a
-- plain value.
--
-- NOTE(review): the @Arbitrary a@ constraint is not used by the body, and
-- because the array is produced through 'unsafePerformIO' the same value may
-- be shared between generated cases -- confirm this is acceptable.
instance (A.AFType a, Arbitrary a) => Arbitrary (Array a) where
  arbitrary = return randomArray
    where
      randomArray = unsafePerformIO $ A.randu [2,2]
20 |
-- | Entry point of the test suite: runs the hspec tree exported by "Spec".
main :: IO ()
main =
  -- QuickCheck law checks, currently disabled and kept for reference:
  -- checks (Proxy :: Proxy (A.Array (A.Complex Float)))
  -- checks (Proxy :: Proxy (A.Array (A.Complex Double)))
  -- checks (Proxy :: Proxy (A.Array Double))
  -- checks (Proxy :: Proxy (A.Array Float))
  -- checks (Proxy :: Proxy (A.Array Double))
  -- checks (Proxy :: Proxy (A.Array A.Int16))
  -- checks (Proxy :: Proxy (A.Array A.Int32))
  -- checks (Proxy :: Proxy (A.Array A.CBool))
  -- checks (Proxy :: Proxy (A.Array Word))
  -- checks (Proxy :: Proxy (A.Array A.Word8))
  -- checks (Proxy :: Proxy (A.Array A.Word16))
  -- checks (Proxy :: Proxy (A.Array A.Word32))
  -- lawsCheck $ semigroupLaws (Proxy :: Proxy (A.Array Double))
  -- lawsCheck $ semigroupLaws (Proxy :: Proxy (A.Array Float))
  hspec spec
38 |
-- | Run the 'Num', 'Eq' and 'Ord' law suites from "Test.QuickCheck.Classes"
-- for the type selected by the proxy, in that order.
checks :: (Arbitrary a, Show a, Num a, Ord a) => Proxy a -> IO ()
checks proxy =
  mapM_ lawsCheck
    [ numLaws proxy
    , eqLaws proxy
    , ordLaws proxy
    -- , semigroupLaws proxy
    ]
44 |
--------------------------------------------------------------------------------
/test/Spec.hs:
--------------------------------------------------------------------------------
1 | {-# OPTIONS_GHC -F -pgmF hspec-discover -optF --module-name=Spec #-}
2 |
--------------------------------------------------------------------------------
/test/Test/Hspec/ApproxExpect.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TypeApplications #-}
2 | {-# LANGUAGE ScopedTypeVariables #-}
3 | module Test.Hspec.ApproxExpect where
4 |
5 | import Data.CallStack (HasCallStack)
6 |
7 | import Test.Hspec (shouldSatisfy, Expectation)
8 |
infix 1 `shouldBeApprox`

-- | @actual \`shouldBeApprox\` tgt@ expects @actual@ to be approximately
-- equal to @tgt@.
--
-- The check deliberately avoids an 'Ord' constraint and any concrete
-- floating-point type. It scales the difference between the two values by
-- @1e-4@ and passes when adding that scaled difference to the target leaves
-- the target unchanged under '==' -- i.e. the values are close enough that
-- the shrunken difference is absorbed by the addition.
shouldBeApprox :: (HasCallStack, Show a, Fractional a, Eq a)
               => a -> a -> Expectation
shouldBeApprox actual tgt = actual `shouldSatisfy` isAbsorbed
  where
    -- True when the scaled-down difference is a no-op for addition to 'tgt'.
    isAbsorbed candidate = (candidate - tgt) * 1e-4 + tgt == tgt
19 |
20 |
--------------------------------------------------------------------------------