├── .gitignore ├── LICENSE ├── README.md ├── Setup.hs ├── examples ├── a02bac.hs ├── c06puce.hs ├── interpolation.hs ├── nelder-mead.hs ├── ode-runge-kutta.hs ├── ode.hs └── one-dim-fft.hs ├── inline-c-nag.cabal └── src └── Language └── C └── Inline ├── Nag.hs └── Nag └── Internal.hsc /.gitignore: -------------------------------------------------------------------------------- 1 | dist/ 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 FP Complete Corporation. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Using `inline-c` with NAG 2 | 3 | This package provides a `C.Context` and various utilities which make it easy to 4 | use the NAG library from Haskell. We present two examples which not 5 | only demonstrate that but also show a nice mix of the features available 6 | in `inline-c`. 7 | 8 | ## One dimensional FFT 9 | 10 | In this first example we will compute the forward discrete Fourier 11 | transform of a sequence of complex numbers, using the 12 | [`nag_sum_fft_complex_1d`](http://www.nag.com/numeric/CL/nagdoc_cl24/html/C06/c06pcc.html) 13 | function in the NAG library. 14 | 15 | While the example is short it showcases various features, including the 16 | already seen ordinary and vector anti-quoting; but also some NAG 17 | specific goodies such as handling of custom types (complex numbers) and 18 | error handling using the `withNagError` function, defined in the 19 | `Language.C.Inline.Nag` module provided by `inline-c-nag`. 20 | 21 | ``` 22 | {-# LANGUAGE QuasiQuotes #-} 23 | {-# LANGUAGE TemplateHaskell #-} 24 | import qualified Language.C.Inline.Nag as C 25 | import qualified Data.Vector.Storable as V 26 | import Foreign.C.Types 27 | 28 | -- Set the 'Context' to the one provided by "Language.C.Inline.Nag". 29 | -- This gives us access to NAG types such as 'Complex' and 'NagError', 30 | -- and also includes the vector and function pointers anti-quoters. 31 | C.context C.nagCtx 32 | 33 | -- Include the header files we need. 34 | C.include "" 35 | C.include "" 36 | 37 | -- | Computes the discrete Fourier transform for the given sequence of 38 | -- 'Complex' numbers. Returns 'Left' if some error occurred, together 39 | -- with the error message. 
40 | forwardFFT :: V.Vector C.Complex -> IO (Either String (V.Vector C.Complex)) 41 | forwardFFT x_orig = do 42 | -- "Thaw" the input vector -- the input is an immutable vector, and by 43 | -- "thawing" it we create a mutable copy of it. 44 | x <- V.thaw x_orig 45 | -- Use 'withNagError' to easily check whether the NAG operation was 46 | -- successful. 47 | C.withNagError $ \fail_ -> do 48 | [C.exp| void { 49 | nag_sum_fft_complex_1d( 50 | // We're computing a forward transform 51 | Nag_ForwardTransform, 52 | // We take the pointer underlying 'x' and its length, using the 53 | // appropriate anti-quoters 54 | $vec-ptr:(Complex *x), $vec-len:x, 55 | // And pass in the NagError structure given to us by 56 | // 'withNagError'. 57 | $(NagError *fail_)) 58 | } |] 59 | -- Turn the mutable vector back to an immutable one using 'V.freeze' 60 | -- (the inverse of 'V.thaw'). 61 | V.freeze x 62 | 63 | -- Run our function with some sample data and print the results. 64 | main :: IO () 65 | main = do 66 | let vec = V.fromList 67 | [ Complex 0.34907 (-0.37168) 68 | , Complex 0.54890 (-0.35669) 69 | , Complex 0.74776 (-0.31175) 70 | , Complex 0.94459 (-0.23702) 71 | , Complex 1.13850 (-0.13274) 72 | , Complex 1.32850 0.00074 73 | , Complex 1.51370 0.16298 74 | ] 75 | printVec vec 76 | Right vec_f <- forwardFFT vec 77 | printVec vec_f 78 | where 79 | printVec vec = do 80 | V.forM_ vec $ \(Complex re im) -> putStr $ show (re, im) ++ " " 81 | putStrLn "" 82 | ``` 83 | 84 | Note how we're able to use the `nag_sum_fft_complex_1d` function just 85 | for the feature we need, using the `Nag_ForwardTransform` enum directly 86 | in the C code, instead of having to define some Haskell interface for 87 | it. Using facilities provided by `inline-c-nag` we're also able to have 88 | nice error handling, automatically extracting the error returned by NAG 89 | if something goes wrong. 
90 | 91 | ### Nelder-Mead optimization 92 | 93 | For a more complex example, we'll write a Haskell function that 94 | performs Nelder-Mead optimization using the 95 | [`nag_opt_simplex_easy`](http://www.nag.com/numeric/CL/nagdoc_cl24/html/E04/e04cbc.html) 96 | function provided by NAG. 97 | 98 | ``` 99 | {-# LANGUAGE TemplateHaskell #-} 100 | {-# LANGUAGE QuasiQuotes #-} 101 | import qualified Data.Vector.Storable as V 102 | import Foreign.ForeignPtr (newForeignPtr_) 103 | import Foreign.Storable (poke) 104 | import qualified Language.C.Inline.Nag as C 105 | import Foreign.C.Types 106 | 107 | C.context C.nagCtx 108 | 109 | C.include "" 110 | C.include "" 111 | C.include "" 112 | C.include "" 113 | 114 | nelderMead 115 | :: V.Vector CDouble 116 | -- ^ Starting point 117 | -> (V.Vector CDouble -> CDouble) 118 | -- ^ Function to minimize 119 | -> C.Nag_Integer 120 | -- ^ Maximum number of iterations (must be >= 1). 121 | -> IO (Either String (CDouble, V.Vector CDouble)) 122 | -- ^ Position of the minimum. 'Left' if something went wrong, with 123 | -- error message. 'Right', together with the minimum cost and its 124 | -- position, if it could be found. 125 | nelderMead xImm pureFunct maxcal = do 126 | -- Create function that the C code will use. 127 | let funct n xc fc _comm = do 128 | xc' <- newForeignPtr_ xc 129 | let f = pureFunct $ V.unsafeFromForeignPtr0 xc' $ fromIntegral n 130 | poke fc f 131 | -- Create mutable input/output vector for C code 132 | x <- V.thaw xImm 133 | -- Call the C code 134 | C.withNagError $ \fail_ -> do 135 | minCost <- [C.block| double { 136 | // The function takes an exit parameter to store the minimum 137 | // cost. 138 | double f; 139 | // We hardcode sensible values (see NAG documentation) for the 140 | // error tolerance, computed using NAG's nag_machine_precision. 
141 | double tolf = sqrt(nag_machine_precision); 142 | double tolx = sqrt(tolf); 143 | // Call the function 144 | nag_opt_simplex_easy( 145 | // Get vector length and pointer. 146 | $vec-len:x, $vec-ptr:(double *x), 147 | &f, tolf, tolx, 148 | // Pass function pointer to our Haskell function using the fun 149 | // anti-quotation. 150 | $fun:(void (*funct)(Integer n, const double *xc, double *fc, Nag_Comm *comm)), 151 | // We do not provide a "monitoring" function. 152 | NULL, 153 | // Capture Haskell variable with the max number of iterations. 154 | $(Integer maxcal), 155 | // Do not provide the Nag_Comm parameter, which we don't need. 156 | NULL, 157 | // Pass the NagError parameter provided by withNagError 158 | $(NagError *fail_)); 159 | return f; 160 | } |] 161 | -- Get a new immutable vector by freezing the mutable one. 162 | minCostPos <- V.freeze x 163 | return (minCost, minCostPos) 164 | 165 | -- Optimize a two-dimensional function. Example taken from 166 | -- . 167 | main :: IO () 168 | main = do 169 | let funct = \x -> 170 | let x0 = x V.! 0 171 | x1 = x V.! 1 172 | in exp x0 * (4*x0*(x0+x1)+2*x1*(x1+1.0)+1.0) 173 | start = V.fromList [-1, 1] 174 | Right (minCost, minPos) <- nelderMead start funct 500 175 | putStrLn $ "Minimum cost: " ++ show minCost 176 | putStrLn $ "End positition: " ++ show (minPos V.! 0) ++ ", " ++ show (minPos V.! 1) 177 | ``` 178 | 179 | Again, in this example we use a function with a very complex and 180 | powerful signature, such as the one for the Nelder-Mead optimization in 181 | NAG, in a very specific way -- avoiding the high cost of having to 182 | specify a well-designed Haskell interface for it. 
183 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | main = defaultMain 3 | -------------------------------------------------------------------------------- /examples/a02bac.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TemplateHaskell #-} 2 | {-# LANGUAGE QuasiQuotes #-} 3 | {-# LANGUAGE ForeignFunctionInterface #-} 4 | import qualified Language.C.Inline.Nag as C 5 | import Text.RawString.QQ (r) 6 | 7 | C.include "" 8 | C.include "" 9 | C.include "" 10 | C.include "" 11 | 12 | -- Code dump test 13 | 14 | C.verbatim [r| 15 | int test_emitCode(void) 16 | { 17 | Integer exit_status = 0; 18 | Complex v, w, z; 19 | double r, theta, x, y; 20 | Nag_Boolean equal, not_equal; 21 | 22 | printf("nag_complex (a02bac) Example Program Results\n"); 23 | 24 | x = 2.0; 25 | y = -3.0; 26 | /* nag_complex (a02bac). 27 | * Complex number from real and imaginary parts 28 | */ 29 | z = nag_complex(x, y); 30 | 31 | printf(" %-21s %s %8s = %7.4f, %7.4f\n", "", "", "x, y", 32 | x, y); 33 | printf(" %-21s: %s = %8s = (%7.4f, %7.4f)\n", "nag_complex", 34 | "z", "(x,y)", z.re, z.im); 35 | /* nag_complex_real (a02bbc). 36 | * Real part of a complex number 37 | */ 38 | printf(" %-21s: %s %8s = %7.4f\n", "nag_complex_real", 39 | "", "real(z)", nag_complex_real(z)); 40 | /* nag_complex_imag (a02bcc). 41 | * Imaginary part of a complex number 42 | */ 43 | printf(" %-21s: %s %8s = %7.4f\n", "nag_complex_imag", 44 | "", "imag(z)", nag_complex_imag(z)); 45 | /* nag_complex (a02bac), see above. */ 46 | v = nag_complex(3.0, 1.25); 47 | /* nag_complex (a02bac), see above. 
*/ 48 | w = nag_complex(2.5, -1.75); 49 | printf(" %-21s: %s %8s = (%7.4f, %7.4f)\n", "nag_complex", "", 50 | "v", v.re, v.im); 51 | printf(" %-21s: %s %8s = (%7.4f, %7.4f)\n", "nag_complex", "", 52 | "w", w.re, w.im); 53 | /* nag_complex_add (a02cac). 54 | * Addition of two complex numbers 55 | */ 56 | z = nag_complex_add(v, w); 57 | printf(" %-21s: %s = %8s = (%7.4f, %7.4f)\n", "nag_complex_add", 58 | "z", "v+w", z.re, z.im); 59 | /* nag_complex_subtract (a02cbc). 60 | * Subtraction of two complex numbers 61 | */ 62 | z = nag_complex_subtract(v, w); 63 | printf(" %-21s: %s = %8s = (%7.4f, %7.4f)\n", 64 | "nag_complex_subtract", "z", "v-w", z.re, z.im); 65 | /* nag_complex_multiply (a02ccc). 66 | * Multiplication of two complex numbers 67 | */ 68 | z = nag_complex_multiply(v, w); 69 | printf(" %-21s: %s = %8s = (%7.4f, %7.4f)\n", 70 | "nag_complex_multiply", "z", "v*w", z.re, z.im); 71 | /* nag_complex_divide (a02cdc). 72 | * Quotient of two complex numbers 73 | */ 74 | z = nag_complex_divide(v, w); 75 | printf(" %-21s: %s = %8s = (%7.4f, %7.4f)\n", "nag_complex_divide", 76 | "z", "v/w", z.re, z.im); 77 | /* nag_complex_negate (a02cec). 78 | * Negation of a complex number 79 | */ 80 | z = nag_complex_negate(w); 81 | printf(" %-21s: %s = %8s = (%7.4f, %7.4f)\n", "nag_complex_negate", 82 | "z", "-w", z.re, z.im); 83 | /* nag_complex_conjg (a02cfc). 84 | * Conjugate of a complex number 85 | */ 86 | z = nag_complex_conjg(w); 87 | printf(" %-21s: %s = %8s = (%7.4f, %7.4f)\n", "nag_complex_conjg", 88 | "z", "conjg(w)", z.re, z.im); 89 | /* nag_complex_equal (a02cgc). 90 | * Equality of two complex numbers 91 | */ 92 | equal = nag_complex_equal(v, w); 93 | if (equal) 94 | printf(" %-21s: %s == %s\n", "nag_complex_equal", "v", "w"); 95 | else 96 | printf(" %-21s: %s != %s\n", "nag_complex_equal", "v", "w"); 97 | /* nag_complex_not_equal (a02chc). 
98 | * Inequality of two complex numbers 99 | */ 100 | not_equal = nag_complex_not_equal(w, z); 101 | if (not_equal) 102 | printf(" %-21s: %s != %s\n\n", "nag_complex_not_equal", "w", "z"); 103 | else 104 | printf(" %-21s: %s == %s\n\n", "nag_complex_not_equal", "w", "z"); 105 | 106 | /* nag_complex_arg (a02dac). 107 | * Argument of a complex number 108 | */ 109 | theta = nag_complex_arg(z); 110 | printf(" %-21s: %s %8s = %7.4f\n", "nag_complex_arg", "", 111 | "arg(z)", theta); 112 | /* nag_complex_abs (a02dbc). 113 | * Modulus of a complex number 114 | */ 115 | r = nag_complex_abs(z); 116 | printf(" %-21s: %s = %8s = %7.4f\n", "nag_complex_abs", "r", 117 | "abs(z)", r); 118 | /* nag_complex_sqrt (a02dcc). 119 | * Square root of a complex number 120 | */ 121 | v = nag_complex_sqrt(z); 122 | printf(" %-21s: %s = %8s = (%7.4f, %7.4f)\n", "nag_complex_sqrt", 123 | "v", "sqrt(z)", v.re, v.im); 124 | /* nag_complex_i_power (a02ddc). 125 | * Complex number raised to integer power 126 | */ 127 | v = nag_complex_i_power(z, (Integer) 3); 128 | printf(" %-21s: %s = %8s = (%7.4f, %7.4f)\n", "nag_complex_i_power", 129 | "v", "z**3", v.re, v.im); 130 | /* nag_complex_r_power (a02dec). 131 | * Complex number raised to real power 132 | */ 133 | v = nag_complex_r_power(z, 2.5); 134 | printf(" %-21s: %s = %8s = (%7.4f, %7.4f)\n", "nag_complex_r_power", 135 | "v", "z**2.5", v.re, v.im); 136 | /* nag_complex_c_power (a02dfc). 137 | * Complex number raised to complex power 138 | */ 139 | v = nag_complex_c_power(z, w); 140 | printf(" %-21s: %s = %8s = (%7.4f,%8.4f)\n", "nag_complex_c_power", 141 | "v", "z**w", v.re, v.im); 142 | /* nag_complex_log (a02dgc). 143 | * Complex logarithm 144 | */ 145 | v = nag_complex_log(z); 146 | printf(" %-21s: %s = %8s = (%7.4f, %7.4f)\n", "nag_complex_log", 147 | "v", "log(z)", v.re, v.im); 148 | /* nag_complex_exp (a02dhc). 
149 | * Complex exponential 150 | */ 151 | z = nag_complex_exp(v); 152 | printf(" %-21s: %s = %8s = (%7.4f, %7.4f)\n", "nag_complex_exp", 153 | "z", "exp(v)", z.re, z.im); 154 | /* nag_complex_sin (a02djc). 155 | * Complex sine 156 | */ 157 | v = nag_complex_sin(z); 158 | printf(" %-21s: %s = %8s = (%7.4f, %7.4f)\n", "nag_complex_sin", 159 | "v", "sin(z)", v.re, v.im); 160 | /* nag_complex_cos (a02dkc). 161 | * Complex cosine 162 | */ 163 | v = nag_complex_cos(z); 164 | printf(" %-21s: %s = %8s = (%7.4f, %7.4f)\n", "nag_complex_cos", 165 | "v", "cos(z)", v.re, v.im); 166 | /* nag_complex_tan (a02dlc). 167 | * Complex tangent 168 | */ 169 | v = nag_complex_tan(z); 170 | printf(" %-21s: %s = %8s = (%7.4f, %7.4f)\n", "nag_complex_tan", 171 | "v", "tan(z)", v.re, v.im); 172 | /* nag_complex_divide (a02cdc), see above. */ 173 | v = nag_complex_divide(nag_complex_sin(z), nag_complex_cos(z)); 174 | printf(" %-21s:%13s = (%7.4f, %7.4f)\n", "nag_complex_divide", 175 | "sin(z)/cos(z)", v.re, v.im); 176 | 177 | return exit_status; 178 | } 179 | |] 180 | 181 | main :: IO () 182 | main = do 183 | [C.exp| void{ test_emitCode() } |] 184 | 185 | -------------------------------------------------------------------------------- /examples/c06puce.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TemplateHaskell #-} 2 | {-# LANGUAGE QuasiQuotes #-} 3 | {-# LANGUAGE ForeignFunctionInterface #-} 4 | import Control.Applicative ((<*>)) 5 | import Control.Monad (void) 6 | import qualified Data.Array.Storable as A 7 | import Data.Functor ((<$>)) 8 | import Data.Int (Int64) 9 | import Foreign.C.String (withCString) 10 | import Foreign.C.Types 11 | import qualified Language.C.Inline.Nag as C 12 | 13 | C.context C.nagCtx 14 | 15 | C.include "" 16 | C.include "" 17 | C.include "" 18 | C.include "" 19 | 20 | fi :: (Integral a, Num b) => a -> b 21 | fi = fromIntegral 22 | 23 | parseBounds :: IO (Int64, Int64) 24 | parseBounds = do 25 | (m, n) <- 
C.withPtrs_ $ \(m, n) -> 26 | [C.exp| void{ scanf("%*[^\n] %ld%ld%*[^\n]", $(long *m), $(long *n)) } |] 27 | return (fi m, fi n) 28 | 29 | parseData :: (Int64, Int64) -> IO (A.StorableArray (Int64, Int64) C.Complex) 30 | parseData (m0, n0) = do 31 | x <- A.newArray ((0, 0), (m0, n0)) $ C.Complex 0 0 32 | let (m, n) = (fi m0, fi n0) 33 | A.withStorableArray x $ \xPtr -> [C.block| void(Complex *xPtr) { 34 | int i; 35 | for (i = 0; i < $(Integer m) * $(Integer n); i++) 36 | scanf(" ( %lf , %lf ) ", &xPtr[i].re, &xPtr[i].im); 37 | } |] 38 | return x 39 | 40 | printGenComplxMat 41 | :: String -> A.StorableArray (Int64, Int64) C.Complex -> IO CInt 42 | printGenComplxMat str x = do 43 | ((0, 0), (m0, n0)) <- A.getBounds x 44 | let (m, n) = (fi m0, fi n0) 45 | withCString str $ \str -> A.withStorableArray x $ \xPtr -> 46 | [C.block| int { 47 | NagError fail; INIT_FAIL(fail); 48 | nag_gen_complx_mat_print_comp( 49 | Nag_RowMajor, Nag_GeneralMatrix, Nag_NonUnitDiag, $(Integer n), $(Integer m), 50 | $(Complex *xPtr), $(Integer m), Nag_BracketForm, "%6.3f", $(char *str), 51 | Nag_NoLabels, 0, Nag_NoLabels, 0, 80, 0, NULL, &fail); 52 | return fail.code != NE_NOERROR; 53 | } |] 54 | 55 | sumFftComplex2d 56 | :: CInt -> A.StorableArray (Int64, Int64) C.Complex -> IO CInt 57 | sumFftComplex2d flag x = do 58 | ((0, 0), (m0, n0)) <- A.getBounds x 59 | let (m, n) = (fi m0, fi n0) 60 | A.withStorableArray x $ \xPtr -> [C.block| int { 61 | NagError fail; INIT_FAIL(fail); 62 | nag_sum_fft_complex_2d($(int flag), $(Integer m), $(Integer n), $(Complex *xPtr), &fail); 63 | return fail.code != NE_NOERROR; 64 | } |] 65 | 66 | main :: IO () 67 | main = do 68 | bounds <- parseBounds 69 | x <- parseData bounds 70 | void $ printGenComplxMat "\n Original data values\n" x 71 | void $ sumFftComplex2d <$> [C.exp| int{ Nag_ForwardTransform } |] <*> return x 72 | void $ printGenComplxMat "\n Components of discrete Fourier transform\n" x 73 | void $ sumFftComplex2d <$> [C.exp| int{ 
Nag_BackwardTransform } |] <*> return x 74 | void $ printGenComplxMat "\n Original sequence as restored by inverse transform\n" x 75 | -------------------------------------------------------------------------------- /examples/interpolation.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TemplateHaskell #-} 2 | {-# LANGUAGE QuasiQuotes #-} 3 | {-# LANGUAGE ViewPatterns #-} 4 | import Data.Coerce (coerce) 5 | import Data.Functor ((<$>)) 6 | import qualified Data.Vector.Storable as V 7 | import qualified Data.Vector.Storable.Mutable as VM 8 | import Foreign.C.Types 9 | import Graphics.Rendering.Chart.Backend.Cairo (toFile) 10 | import Graphics.Rendering.Chart.Easy (layout_title, (.=), plot, line, points, def) 11 | import qualified Language.C.Inline.Nag as C 12 | import System.Environment (getArgs) 13 | 14 | C.context C.nagCtx 15 | 16 | C.include "" 17 | C.include "" 18 | 19 | data Monotonic = Monotonic 20 | { _monotonicXs :: V.Vector CDouble 21 | , _monotonicYs :: V.Vector CDouble 22 | , _monotonicDs :: V.Vector CDouble 23 | } 24 | 25 | monotonicInterpolate 26 | :: V.Vector Double -> V.Vector Double -> IO (Either String Monotonic) 27 | monotonicInterpolate (coerce -> x) (coerce -> f) = do 28 | let n = V.length x 29 | if V.length f /= n 30 | then error "monotonicInterpolate: vectors of different lenghts" 31 | else do 32 | d <- VM.new n 33 | C.withNagError $ \fail_ -> do 34 | [C.exp| void{ nag_monotonic_interpolant( 35 | $vec-len:x, $vec-ptr:(double *x), $vec-ptr:(double *f), $vec-ptr:(double *d), $(NagError *fail_)) } |] 36 | dImm <- V.unsafeFreeze d 37 | return $ Monotonic x f dImm 38 | 39 | monotonicEvaluate :: Monotonic -> V.Vector Double -> IO (Either String (V.Vector Double)) 40 | monotonicEvaluate (Monotonic x f d) (coerce -> px) = do 41 | let m = V.length px 42 | pf <- VM.new m 43 | C.withNagError $ \fail_ -> do 44 | [C.exp| void{ nag_monotonic_evaluate( 45 | $vec-len:x, $vec-ptr:(double *x), $vec-ptr:(double 
*f), $vec-ptr:(double *d), 46 | $vec-len:px, $vec-ptr:(double *px), $vec-ptr:(double *pf), 47 | $(NagError *fail_)) } |] 48 | coerce <$> V.unsafeFreeze pf 49 | 50 | {- 51 | monotonicEvaluate_ :: Monotonic -> Double -> IO (Either String Double) 52 | monotonicEvaluate_ mntnc px = fmap (V.! 0) <$> monotonicEvaluate mntnc (V.fromList [px]) 53 | 54 | monotonicDeriv :: Monotonic -> V.Vector Double -> IO (Either String (V.Vector Double, V.Vector Double)) 55 | monotonicDeriv (Monotonic x f d) (coerce -> px) = do 56 | let m = V.length px 57 | pf <- VM.new m 58 | pd <- VM.new m 59 | withNagError $ \fail_ -> do 60 | [cexp| void{ nag_monotonic_deriv( 61 | $vec-len:x, $vec-ptr:(double *x), $vec-ptr:(double *f), $vec-ptr:(double *d), 62 | $vec-len:px, $vec-ptr:(double *px), $vec-ptr:(double *pf), $vec-ptr:(double *pd), 63 | $(NagError *fail_)) } |] 64 | coerce <$> ((,) <$> V.unsafeFreeze pf <*> V.unsafeFreeze pd) 65 | 66 | monotonicDeriv_ :: Monotonic -> Double -> IO (Either String (Double, Double)) 67 | monotonicDeriv_ mntnc px = do 68 | fmap (\(pf, pd) -> (pf V.! 0, pd V.! 
0)) <$> monotonicDeriv mntnc (V.fromList [px]) 69 | 70 | monotonicIntg :: Monotonic -> (Double, Double) -> IO (Either String Double) 71 | monotonicIntg (Monotonic x f d) (coerce -> (a, b)) = 72 | fmap coerce $ withNagError $ \fail_ -> withPtr_ $ \integral -> 73 | [cexp| void{ nag_monotonic_intg( 74 | $vec-len:x, $vec-ptr:(double *x), $vec-ptr:(double *f), $vec-ptr:(double *d), 75 | $(double a), $(double b), $(double *integral), $(NagError *fail_)) } |] 76 | -} 77 | 78 | main :: IO () 79 | main = do 80 | args <- getArgs 81 | case args of 82 | [fn] -> do 83 | let (xs, ys) = unzip pts 84 | mntnc <- assertNag $ monotonicInterpolate (V.fromList xs) (V.fromList ys) 85 | let lineXs = filter (< 20) [7.99,8..20] ++ [20] 86 | lineYs <- V.toList <$> assertNag (monotonicEvaluate mntnc (V.fromList lineXs)) 87 | toFile def fn $ do 88 | layout_title .= "Interpolation test" 89 | plot (line "interpolation" [zip lineXs lineYs]) 90 | plot (points "points" pts) 91 | _ -> do 92 | error "usage: interpolation FILE" 93 | where 94 | assertNag m = do 95 | x <- m 96 | case x of 97 | Left err -> error err 98 | Right y -> return y 99 | 100 | pts = 101 | [ ( 7.99, 0.00000E+0) 102 | , ( 8.09, 0.27643E-4) 103 | , ( 8.19, 0.43750E-1) 104 | , ( 8.70, 0.16918E+0) 105 | , ( 9.20, 0.46943E+0) 106 | , (10.00, 0.94374E+0) 107 | , (12.00, 0.99864E+0) 108 | , (15.00, 0.99992E+0) 109 | , (20.00, 0.99999E+0) 110 | ] 111 | 112 | {- 113 | main :: IO () 114 | main = do 115 | Right mntnc <- monotonicInterpolate x f 116 | Right pf <- monotonicEvaluate mntnc px 117 | print pf 118 | where 119 | n = 9 120 | 121 | x = V.fromList 122 | [ 7.99 123 | , 8.09 124 | , 8.19 125 | , 8.70 126 | , 9.20 127 | , 10.00 128 | , 12.00 129 | , 15.00 130 | , 20.00 131 | ] 132 | 133 | f = V.fromList 134 | [ 0.00000E+0 135 | , 0.27643E-4 136 | , 0.43750E-1 137 | , 0.16918E+0 138 | , 0.46943E+0 139 | , 0.94374E+0 140 | , 0.99864E+0 141 | , 0.99992E+0 142 | , 0.99999E+0 143 | ] 144 | 145 | m = 11 146 | 147 | first = x V.! 
0 148 | last = x V.! (n - 1) 149 | 150 | step = (last - first) / (m - 1); 151 | 152 | px = V.fromList $ [first,(first+step)..(last-1)] ++ [last] 153 | -} 154 | -------------------------------------------------------------------------------- /examples/nelder-mead.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TemplateHaskell #-} 2 | {-# LANGUAGE QuasiQuotes #-} 3 | import qualified Data.Vector.Storable as V 4 | import Foreign.C.Types 5 | import Foreign.ForeignPtr (newForeignPtr_) 6 | import Foreign.Storable (poke) 7 | import qualified Language.C.Inline.Nag as C 8 | import System.IO.Unsafe (unsafePerformIO) 9 | 10 | C.context C.nagCtx 11 | 12 | C.include "" 13 | C.include "" 14 | C.include "" 15 | C.include "" 16 | 17 | {-# NOINLINE nelderMead #-} 18 | nelderMead 19 | :: V.Vector CDouble 20 | -- ^ Starting point 21 | -> (V.Vector CDouble -> CDouble) 22 | -- ^ Function to minimize 23 | -> C.Nag_Integer 24 | -- ^ Maximum number of iterations (must be >= 1). 25 | -> Either String (CDouble, V.Vector CDouble) 26 | -- ^ Position of the minimum. 'Left' if something went wrong, with 27 | -- error message. 'Right', together with the minimum cost and its 28 | -- position, if it could be found. 29 | nelderMead xImm pureFunct maxcal = unsafePerformIO $ do 30 | -- Create function that the C code will use. 31 | let funct n xc fc _comm = do 32 | xc' <- newForeignPtr_ xc 33 | let f = pureFunct $ V.unsafeFromForeignPtr0 xc' $ fromIntegral n 34 | poke fc f 35 | -- Create mutable input/output vector for C code 36 | x <- V.thaw xImm 37 | -- Call the C code 38 | C.withNagError $ \fail_ -> do 39 | minCost <- [C.block| double { 40 | // The function takes an exit parameter to store the minimum 41 | // cost. 42 | double f; 43 | // We hardcode sensible values (see NAG documentation) for the 44 | // error tolerance, computed using NAG's nag_machine_precision. 
45 | double tolf = sqrt(nag_machine_precision); 46 | double tolx = sqrt(tolf); 47 | // Call the function 48 | nag_opt_simplex_easy( 49 | // Get vector length and pointer. 50 | $vec-len:x, $vec-ptr:(double *x), 51 | &f, tolf, tolx, 52 | // Pass function pointer to our Haskell function using the fun 53 | // anti-quotation. 54 | $fun:(void (*funct)(Integer n, const double *xc, double *fc, Nag_Comm *comm)), 55 | // We do not provide a "monitoring" function. 56 | NULL, 57 | // Capture Haskell variable with the max number of iterations. 58 | $(Integer maxcal), 59 | // Do not provide the Nag_Comm parameter, which we don't need. 60 | NULL, 61 | // Pass the NagError parameter provided by withNagError 62 | $(NagError *fail_)); 63 | return f; 64 | } |] 65 | -- Get a new immutable vector by freezing the mutable one. 66 | minCostPos <- V.freeze x 67 | return (minCost, minCostPos) 68 | 69 | {-# NOINLINE oneVar #-} 70 | oneVar 71 | :: (CDouble, CDouble) 72 | -- ^ Interval containing a minimum 73 | -> (CDouble -> CDouble) 74 | -- ^ Function to minimize 75 | -> C.Nag_Integer 76 | -- ^ Maximum number of iterations. 77 | -> Either String (CDouble, CDouble) 78 | oneVar (a, b) fun max_fun = unsafePerformIO $ do 79 | let funct xc fc _comm = poke fc $ fun xc 80 | C.withNagError $ \fail_ -> C.withPtrs_ $ \(x, f) -> do 81 | [C.block| void { 82 | double a = $(double a), b = $(double b); 83 | nag_opt_one_var_no_deriv( 84 | $fun:(void (*funct)(double, double*, Nag_Comm*)), 85 | 0.0, 0.0, &a, &b, $(Integer max_fun), $(double *x), $(double *f), 86 | NULL, $(NagError *fail_)); 87 | } |] 88 | 89 | -- Optimize a two-dimensional function. Example taken from 90 | -- . 91 | main :: IO () 92 | main = do 93 | let funct1 = \x -> 94 | let x0 = x V.! 0 95 | x1 = x V.! 
1 96 | in exp x0 * (4*x0*(x0+x1)+2*x1*(x1+1.0)+1.0) 97 | start = V.fromList [-1, 1] 98 | let Right (minCost1, minPos1) = nelderMead start funct1 500 99 | putStrLn $ "Nelder-Mead" 100 | putStrLn $ "Minimum cost: " ++ show minCost1 101 | putStrLn $ "End positition: " ++ show (minPos1 V.! 0) ++ ", " ++ show (minPos1 V.! 1) 102 | 103 | let funct2 x = sin x / x 104 | let Right (minCost2, minPos2) = oneVar (3.5, 5) funct2 30 105 | putStrLn $ "One variable" 106 | putStrLn $ "Minimum cost: " ++ show minCost2 107 | putStrLn $ "End position: " ++ show minPos2 108 | -------------------------------------------------------------------------------- /examples/ode-runge-kutta.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TemplateHaskell #-} 2 | {-# LANGUAGE QuasiQuotes #-} 3 | {-# LANGUAGE ScopedTypeVariables #-} 4 | import Control.Monad (forM, when) 5 | import Control.Monad.Trans.Class (lift) 6 | import Control.Monad.Trans.Control (liftBaseOp) 7 | import Control.Monad.Trans.Except (runExceptT, ExceptT(..), throwE) 8 | import Data.Functor ((<$>)) 9 | import qualified Data.Vector.Storable as V 10 | import qualified Data.Vector.Storable.Mutable as VM 11 | import Foreign.C.Types 12 | import Foreign.ForeignPtr (newForeignPtr_) 13 | import Foreign.Ptr (Ptr) 14 | import qualified Language.C.Inline.Nag as C 15 | import System.IO.Unsafe (unsafePerformIO) 16 | 17 | C.context C.nagCtx 18 | 19 | C.include "" 20 | C.include "" 21 | C.include "" 22 | C.include "" 23 | 24 | data Method 25 | = RK_2_3 26 | | RK_4_5 27 | | RK_7_8 28 | deriving (Eq, Show) 29 | 30 | data SolveOptions = SolveOptions 31 | { _soMethod :: Method 32 | , _soTolerance :: CDouble 33 | , _soInitialStepSize :: CDouble 34 | } deriving (Eq, Show) 35 | 36 | {-# NOINLINE solve #-} 37 | solve 38 | :: SolveOptions 39 | -> (CDouble -> V.Vector CDouble -> V.Vector CDouble) 40 | -- ^ ODE to solve 41 | -> [CDouble] 42 | -- ^ @x@ values at which to approximate the solution. 
43 | -> V.Vector CDouble 44 | -- ^ The initial values of the solution. 45 | -> Either String [(CDouble, V.Vector CDouble)] 46 | -- ^ Either an error, or the @y@ values corresponding to the @x@ 47 | -- values input. 48 | solve (SolveOptions method tol hstart) f xs y0 = unsafePerformIO $ runExceptT $ do 49 | when (length xs < 2) $ 50 | throwE "You have to provide a minimum of 2 values for @x@" 51 | let tstart = head xs 52 | let tend = last xs 53 | iwsav <- lift $ VM.new liwsav 54 | rwsav <- lift $ VM.new lrwsav 55 | let thresh = V.replicate n 0 56 | methodInt <- lift $ case method of 57 | RK_2_3 -> [C.exp| int{ Nag_RK_2_3 } |] 58 | RK_4_5 -> [C.exp| int{ Nag_RK_4_5 } |] 59 | RK_7_8 -> [C.exp| int{ Nag_RK_7_8 } |] 60 | ExceptT $ C.withNagError $ \fail_ -> 61 | [C.exp| void{ nag_ode_ivp_rkts_setup( 62 | $(Integer n_c), $(double tstart), $(double tend), $vec-ptr:(double *y0), 63 | $(double tol), $vec-ptr:(double *thresh), $(int methodInt), 64 | Nag_ErrorAssess_off, $(double hstart), $vec-ptr:(Integer *iwsav), 65 | $vec-ptr:(double *rwsav), $(NagError *fail_)) 66 | } |] 67 | ygot <- lift $ VM.new n 68 | ypgot <- lift $ VM.new n 69 | ymax <- lift $ VM.new n 70 | let fIO :: CDouble -> C.Nag_Integer -> Ptr CDouble -> Ptr CDouble -> Ptr C.Nag_Comm -> IO () 71 | fIO t n y _yp _comm = do 72 | yFore <- newForeignPtr_ y 73 | let yVec = VM.unsafeFromForeignPtr0 yFore $ fromIntegral n 74 | ypImm <- f t <$> V.unsafeFreeze yVec 75 | V.copy yVec ypImm 76 | liftBaseOp C.initNagError $ \fail_ -> do 77 | -- Tail because the first point is the start 78 | ys <- forM (tail xs) $ \t -> do 79 | ExceptT $ C.checkNagError fail_ $ [C.block| void { 80 | double tgot; 81 | nag_ode_ivp_rkts_range( 82 | $fun:(void (*fIO)(double t, Integer n, const double y[], double yp[], Nag_Comm *comm)), 83 | $(Integer n_c), $(double t), &tgot, $vec-ptr:(double *ygot), 84 | $vec-ptr:(double *ypgot), $vec-ptr:(double *ymax), NULL, 85 | $vec-ptr:(Integer *iwsav), $vec-ptr:(double *rwsav), 86 | $(NagError *fail_)); 
87 | } |] 88 | y <- lift $ V.freeze ygot 89 | return (t, y) 90 | return $ (tstart, y0) : ys 91 | where 92 | n = V.length y0 93 | liwsav = 130 94 | lrwsav = 350 + 32 * n 95 | n_c = fromIntegral n 96 | 97 | main :: IO () 98 | main = do 99 | let opts = SolveOptions RK_4_5 1e-8 0 100 | let f _t y = V.fromList [y V.! 1, -(y V.! 0)] 101 | case solve opts f [0,pi/4..pi] (V.fromList [0, 1]) of 102 | Left err -> putStrLn err 103 | Right x -> print x 104 | -------------------------------------------------------------------------------- /examples/ode.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TemplateHaskell #-} 2 | {-# LANGUAGE QuasiQuotes #-} 3 | {-# LANGUAGE RecordWildCards #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | import Data.Functor ((<$>)) 6 | import Data.IORef 7 | import qualified Data.Vector.Storable as V 8 | import qualified Data.Vector.Storable.Mutable as VM 9 | import Foreign.C.Types 10 | import Foreign.ForeignPtr (newForeignPtr_) 11 | import Foreign.Ptr (Ptr) 12 | import Foreign.Ptr (nullFunPtr) 13 | import Foreign.Storable (Storable, peek, poke) 14 | import qualified Language.C.Inline.Nag as C 15 | import System.IO.Unsafe (unsafePerformIO) 16 | 17 | C.context C.nagCtx 18 | 19 | C.include "" 20 | C.include "" 21 | C.include "" 22 | 23 | -- Main types 24 | ------------------------------------------------------------------------ 25 | 26 | type Fn 27 | = CDouble 28 | -- ^ The indipendent variable @x@ 29 | -> V.Vector CDouble 30 | -- ^ @y_i@ for @i = 1, 2, ..., neq@ 31 | -> V.Vector CDouble 32 | -- ^ @f_i@ for @i = 1, 2, ..., neq@ 33 | 34 | type Jacobian 35 | = CDouble 36 | -- ^ The indipendent variable @x@ 37 | -> V.Vector CDouble 38 | -- ^ @y_i@ for @i = 1, 2, ..., neq@ 39 | -> V.Vector CDouble 40 | -- ^ Jacobian matrix, @pw[(i - 1) * neq + j - 1@ must contain the 41 | -- value of @∂f_i/∂y_j@, for @i, j = 1, 2, ..., neq@ 42 | 43 | type Interval = (CDouble, CDouble) 44 | 45 | data Failure = Failure 46 
| { failureMessage :: String 47 | , _failureAt :: CDouble 48 | } deriving (Eq, Show) 49 | 50 | data ErrorControl 51 | = Relative 52 | | Absolute 53 | | Mixed 54 | deriving (Eq, Show) 55 | 56 | data Options = Options 57 | { optionsTolerance :: !CDouble 58 | , optionsErrorControl :: !ErrorControl 59 | } deriving (Eq, Show) 60 | 61 | -- IO solving 62 | ------------------------------------------------------------------------ 63 | 64 | data OutputIO a = OutputIO 65 | { outputIOStartState :: a 66 | , outputIOStep :: a -> CDouble -> V.Vector CDouble -> IO (a, CDouble) 67 | } 68 | 69 | solveIO 70 | :: Options 71 | -> Fn 72 | -> Maybe Jacobian 73 | -> Interval 74 | -> Maybe (OutputIO a) 75 | -> V.Vector CDouble 76 | -> IO (Either Failure (V.Vector CDouble, Maybe a)) 77 | solveIO Options{..} fcn mbJac (x, xend) mbOutput y = do 78 | -- IO version of the right-hand side function 79 | let fcnIO neq x y f _comm = do 80 | fImm <- fcn x <$> vectorFromC neq y 81 | vectorToC fImm neq f 82 | -- Function pointer for the Jacobian function. We use a function 83 | -- pointer directly because we want it to be NULL if the user hasn't 84 | -- provided a function.
85 | jacFunPtr <- case mbJac of 86 | Nothing -> return nullFunPtr 87 | Just jac -> do 88 | let jacIO neq x y pw _comm = do 89 | pwImm <- jac x <$> vectorFromC neq y 90 | vectorToC pwImm (neq*neq) pw 91 | $(C.mkFunPtr [t| C.Nag_Integer -> CDouble -> Ptr CDouble -> Ptr CDouble -> Ptr C.Nag_User -> IO () |]) jacIO 92 | (outputFunPtr, outputGetResult) <- case mbOutput of 93 | Nothing -> return (nullFunPtr, return Nothing) 94 | Just OutputIO{..} -> do 95 | outputStateRef <- newIORef outputIOStartState 96 | let outputIO neq xsol y _comm = do 97 | x <- peek xsol 98 | y' <- vectorFromC neq y 99 | outputState <- readIORef outputStateRef 100 | (outputState', x') <- outputIOStep outputState x y' 101 | writeIORef outputStateRef outputState' 102 | poke xsol x' 103 | outputFunPtr <- $(C.mkFunPtr [t| C.Nag_Integer -> Ptr CDouble -> Ptr CDouble -> Ptr C.Nag_User -> IO ()|]) outputIO 104 | return (outputFunPtr, Just <$> readIORef outputStateRef) 105 | -- Error control 106 | err <- case optionsErrorControl of 107 | Relative -> [C.exp| Nag_ErrorControl{ Nag_Relative } |] 108 | Absolute -> [C.exp| Nag_ErrorControl{ Nag_Absolute } |] 109 | Mixed -> [C.exp| Nag_ErrorControl{ Nag_Mixed } |] 110 | -- Record the last visited x in an 'IORef' to store it in the 111 | -- 'Failure' if there was a problem. 
112 | xendRef <- newIORef x 113 | yMut <- V.thaw y 114 | res <- C.withNagError $ \fail_ -> do 115 | xend' <- [C.block| double { 116 | double x = $(double x); 117 | Nag_User comm; 118 | nag_ode_ivp_bdf_gen( 119 | $vec-len:yMut, 120 | $fun:(void (*fcnIO)(Integer neq, double x, const double y[], double f[], Nag_User *comm)), 121 | $(void (*jacFunPtr)(Integer neq, double x, const double y[], double pw[], Nag_User *comm)), 122 | &x, $vec-ptr:(double yMut[]), $(double xend), 123 | $(double optionsTolerance), $(Nag_ErrorControl err), 124 | $(void (*outputFunPtr)(Integer neq, double *xsol, const double y[], Nag_User *comm)), 125 | NULLDFN, &comm, $(NagError *fail_)); 126 | return x; 127 | } |] 128 | writeIORef xendRef xend' 129 | case res of 130 | Left s -> do 131 | xend' <- readIORef xendRef 132 | return $ Left $ Failure s xend' 133 | Right () -> do 134 | y' <- V.freeze yMut 135 | mbOutput <- outputGetResult 136 | return $ Right (y', mbOutput) 137 | 138 | -- Pure solver 139 | ------------------------------------------------------------------------ 140 | 141 | data Output a = Output 142 | { outputStartState :: a 143 | , outputStep :: a -> CDouble -> V.Vector CDouble -> (a, CDouble) 144 | } 145 | 146 | {- 147 | outputInterval :: CDouble -> Output [(CDouble, V.Vector CDouble)] 148 | outputInterval interval = Output 149 | { outputStartState = [] 150 | , outputStep = \xs x y -> (xs ++ [(x, y)], x + interval) 151 | } 152 | 153 | outputFixed :: [CDouble] -> Output ([CDouble], [(CDouble, V.Vector CDouble)]) 154 | outputFixed xs = Output 155 | { outputStartState = (xs, []) 156 | , outputStep = \(steps, ys) x y -> case steps of 157 | [] -> ((steps, ys), x+1) 158 | step:steps -> ((steps, ys ++ [(x, y)]), step) 159 | } 160 | 161 | outputNothing :: CDouble -> Output () 162 | outputNothing x = Output 163 | { outputStartState = () 164 | , outputStep = \() _ _ -> ((), x) 165 | } 166 | -} 167 | 168 | {-# NOINLINE solve #-} 169 | solve 170 | :: Options 171 | -> Fn 172 | -> Maybe Jacobian 
173 | -> Interval 174 | -> Maybe (Output a) 175 | -> V.Vector CDouble 176 | -> Either Failure (V.Vector CDouble, Maybe a) 177 | solve opts fcn mbJac int mbOutput y = unsafePerformIO $ 178 | solveIO opts fcn mbJac int mbOutputIO y 179 | where 180 | mbOutputIO = case mbOutput of 181 | Nothing -> Nothing 182 | Just Output{..} -> Just $ OutputIO outputStartState $ \s x y -> return $ outputStep s x y 183 | 184 | -- Oregonator 185 | ------------------------------------------------------------------------ 186 | 187 | oregonator :: IO (V.Vector CDouble) 188 | oregonator = do 189 | let x = 0 190 | let xend = 360 191 | let tol = 1e-5 192 | let res = solve (Options tol Relative) f (Just jac) (x, xend) Nothing y0 193 | case res of 194 | Left err -> error $ "Oregonator failed " ++ failureMessage err 195 | Right (v, _) -> return v 196 | where 197 | f :: Fn 198 | f _ y = 199 | let y1 = y V.! 0 ; y2 = y V.! 1 ; y3 = y V.! 2 200 | in V.fromList 201 | [ s * (y2 - y1 * y2 + y1 - q * (y1 * y1)) 202 | , (-y2 - y1 * y2 + y3) / s 203 | , w * (y1 - y3) 204 | ] 205 | 206 | jac :: Jacobian 207 | jac _ y = 208 | let y1 = y V.! 0 ; y2 = y V.! 1 ; _y3 = y V.! 2 209 | in V.fromList 210 | [ s * (1 - y2 - 2 * q * y1), s * (1 - y1), 0 211 | , -y2 / s, (-1 - y1) / s, 1 / s 212 | , w, 0, -2 213 | ] 214 | 215 | y0 = V.fromList [1, 2, 3] 216 | 217 | s = 77.27 ; q = 8.375E-06 ; w = 0.161 218 | 219 | -- Hires 220 | ------------------------------------------------------------------------ 221 | 222 | hires :: IO (V.Vector CDouble) 223 | hires = do 224 | let x = 0 225 | let xend = 321.8122 226 | let res = solve (Options tol Relative) f (Just jac) (x, xend) Nothing y0 227 | case res of 228 | Left err -> error $ "Hires failed " ++ failureMessage err 229 | Right (v, _) -> return v 230 | where 231 | f :: Fn 232 | f _ y = 233 | let y1 = y V.! 0 ; y2 = y V.! 1 ; y3 = y V.! 2 ; y4 = y V.! 3 234 | y5 = y V.! 4 ; y6 = y V.! 5 ; y7 = y V.! 6 ; y8 = y V.! 
7 235 | in V.fromList 236 | [ -1.71 * y1 + 0.43 * y2 + 8.32 * y3 + 0.0007 237 | , 1.71 * y1 - 8.75 * y2 238 | , -10.03 * y3 + 0.43 * y4 + 0.035 * y5 239 | , 8.32 * y2 + 1.71 * y3 - 1.12 * y4 240 | , -1.745 * y5 + 0.43 * y6 + 0.43 * y7 241 | , -280 * y6 * y8 + 0.69 * y4 + 1.71 * y5 - 0.43 * y6 + 0.69 * y7 242 | , 280 * y6 * y8 - 1.81 * y7 243 | , -280 * y6 * y8 + 1.81 * y7 244 | ] 245 | 246 | jac :: Jacobian 247 | jac _ y = 248 | let _y1 = y V.! 0 ; _y2 = y V.! 1 ; _y3 = y V.! 2 ; _y4 = y V.! 3 249 | _y5 = y V.! 4 ; y6 = y V.! 5 ; _y7 = y V.! 6 ; y8 = y V.! 7 250 | in V.fromList 251 | [ -1.71, 0.43, 8.32, 0, 0, 0, 0, 0 252 | , 1.71, -8.75, 0, 0, 0, 0, 0, 0 253 | , 0, 0, -10.03, 0.43, 0.035, 0, 0, 0 254 | , 0, 8.32, 1.71, -1.12, 0, 0, 0, 0 255 | , 0, 0, 0, 0, -1.745, 0.43, 0.43, 0 256 | , 0, 0, 0, 0.69, 1.71, -280 * y8 - 0.43, 0.69, -280 * y6 257 | , 0, 0, 0, 0, 0, 280 * y8, -1.81, 280 * y6 258 | , 0, 0, 0, 0, 0, -280 * y8, 1.81, -280 * y6 259 | ] 260 | 261 | y0 = V.fromList [1, 0, 0, 0, 0, 0, 0, 0.0057] 262 | 263 | tol = 1.0e-6 264 | 265 | -- NAG 266 | ------------------------------------------------------------------------ 267 | 268 | nagTest :: IO (V.Vector CDouble) 269 | nagTest = do 270 | let fcn _x y = 271 | let y1 = y V.! 0 272 | y2 = y V.! 1 273 | y3 = y V.! 2 274 | in V.fromList 275 | [ y1 * (-0.04) + y2 * 1e4 * y3 276 | , y1 * 0.04 - y2 * 1e4 * y3 - y2 * 3e7 * y2 277 | , y2 * 3e7 * y2 278 | ] 279 | let jac _x y = 280 | let _y1 = y V.! 0 281 | y2 = y V.! 1 282 | y3 = y V.! 
2 283 | in V.fromList 284 | [ -0.04, y3 * 1e4, y2 * 1e4 285 | , 0.04, y3 * (-1e4) - y2 * 6e7, y2 * (-1e4) 286 | , 0.0, y2 * 6e7, 0.0 287 | ] 288 | let x = 1 289 | let y = V.fromList [1.0, 0.0, 0.0] 290 | let tol = 10**(-3) 291 | let res = solve (Options tol Relative) fcn (Just jac) (x, 10) Nothing y 292 | case res of 293 | Left err -> error $ "NAG test failed " ++ show err 294 | Right (y', _) -> return y' 295 | 296 | -- Main 297 | ------------------------------------------------------------------------ 298 | 299 | main :: IO () 300 | main = do 301 | print =<< nagTest 302 | print =<< oregonator 303 | print =<< hires 304 | 305 | -- Utils 306 | ------------------------------------------------------------------------ 307 | 308 | vectorFromC :: Storable a => C.Nag_Integer -> Ptr a -> IO (V.Vector a) 309 | vectorFromC len ptr = do 310 | ptr' <- newForeignPtr_ ptr 311 | V.freeze $ VM.unsafeFromForeignPtr0 ptr' $ fromIntegral len 312 | 313 | vectorToC :: Storable a => V.Vector a -> C.Nag_Integer -> Ptr a -> IO () 314 | vectorToC vec neq ptr = do 315 | ptr' <- newForeignPtr_ ptr 316 | V.copy (VM.unsafeFromForeignPtr0 ptr' $ fromIntegral neq) vec 317 | -------------------------------------------------------------------------------- /examples/one-dim-fft.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE QuasiQuotes #-} 2 | {-# LANGUAGE TemplateHaskell #-} 3 | import qualified Language.C.Inline.Nag as C 4 | import qualified Data.Vector.Storable as V 5 | import Foreign.C.Types 6 | 7 | -- Set the 'Context' to the one provided by "Language.C.Inline.Nag". 8 | -- This gives us access to NAG types such as 'C.Complex' and 'C.NagError', 9 | -- and also includes the vector and function pointers anti-quoters. 10 | C.context C.nagCtx 11 | 12 | -- Include the headers files we need. 13 | C.include "" 14 | C.include "" 15 | 16 | -- | Computes the discrete Fourier transform for the given sequence of 17 | -- 'Complex' numbers. 
Returns 'Left' if some error occurred, together 18 | -- with the error message. 19 | forwardFFT :: V.Vector C.Complex -> IO (Either String (V.Vector C.Complex)) 20 | forwardFFT x_orig = do 21 | -- "Thaw" the input vector -- the input is an immutable vector, and by 22 | -- "thawing" it we create a mutable copy of it. 23 | x <- V.thaw x_orig 24 | -- Use 'C.withNagError' to easily check whether the NAG operation was 25 | -- successful. 26 | C.withNagError $ \fail_ -> do 27 | [C.exp| void { 28 | nag_sum_fft_complex_1d( 29 | // We're computing a forward transform 30 | Nag_ForwardTransform, 31 | // We take the pointer underlying 'x' and it's length, using the 32 | // appropriate anti-quoters 33 | $vec-ptr:(Complex *x), $vec-len:x, 34 | // And pass in the NagError structure given to us by 35 | // 'withNagError'. 36 | $(NagError *fail_)) 37 | } |] 38 | -- Turn the mutable vector back to an immutable one using 'V.freeze' 39 | -- (the inverse of 'V.thaw'). 40 | V.freeze x 41 | 42 | -- Run our function with some sample data and print the results. 43 | main :: IO () 44 | main = do 45 | let vec = V.fromList 46 | [ C.Complex 0.34907 (-0.37168) 47 | , C.Complex 0.54890 (-0.35669) 48 | , C.Complex 0.74776 (-0.31175) 49 | , C.Complex 0.94459 (-0.23702) 50 | , C.Complex 1.13850 (-0.13274) 51 | , C.Complex 1.32850 0.00074 52 | , C.Complex 1.51370 0.16298 53 | ] 54 | printVec vec 55 | Right vec_f <- forwardFFT vec 56 | printVec vec_f 57 | where 58 | printVec vec = do 59 | V.forM_ vec $ \(C.Complex re im) -> putStr $ show (re, im) ++ " " 60 | putStrLn "" 61 | -------------------------------------------------------------------------------- /inline-c-nag.cabal: -------------------------------------------------------------------------------- 1 | name: inline-c-nag 2 | version: 0.1.0.0 3 | synopsis: Utilities to use inline-c with NAG. 
4 | description: See 5 | license: MIT 6 | license-file: LICENSE 7 | author: Francesco Mazzoli 8 | maintainer: francesco@fpcomplete.com 9 | copyright: (c) 2015 FP Complete Corporation 10 | category: Math 11 | build-type: Simple 12 | cabal-version: >=1.10 13 | 14 | source-repository head 15 | type: git 16 | location: https://github.com/fpco/inline-c-nag 17 | 18 | flag examples 19 | description: Build examples 20 | default: False 21 | 22 | library 23 | exposed-modules: Language.C.Inline.Nag 24 | other-modules: Language.C.Inline.Nag.Internal 25 | c-sources: src/Language/C/Inline/Nag.c 26 | build-depends: base >=4.7 && <5 27 | , inline-c 28 | , template-haskell 29 | , containers 30 | hs-source-dirs: src 31 | default-language: Haskell2010 32 | include-dirs: /opt/NAG/cll6i24dcl/include 33 | ghc-options: -Wall -fPIC 34 | cc-options: -m64 35 | 36 | -- Examples 37 | 38 | executable a02bac 39 | hs-source-dirs: examples 40 | main-is: a02bac.hs 41 | c-sources: examples/a02bac.c 42 | default-language: Haskell2010 43 | cc-options: -m64 44 | extra-libraries: nagc_nag imf svml irng intlc 45 | extra-lib-dirs: /opt/NAG/cll6i24dcl/lib /opt/NAG/cll6i24dcl/rtl/intel64 46 | include-dirs: /opt/NAG/cll6i24dcl/include 47 | ghc-options: -Wall -fPIC 48 | 49 | if flag(examples) 50 | buildable: True 51 | build-depends: base >=4 && <5 52 | , inline-c 53 | , inline-c-nag 54 | , raw-strings-qq 55 | else 56 | buildable: False 57 | 58 | executable c06puce 59 | hs-source-dirs: examples 60 | main-is: c06puce.hs 61 | c-sources: examples/c06puce.c 62 | default-language: Haskell2010 63 | cc-options: -m64 -w 64 | extra-libraries: nagc_nag imf svml irng intlc 65 | extra-lib-dirs: /opt/NAG/cll6i24dcl/lib /opt/NAG/cll6i24dcl/rtl/intel64 66 | include-dirs: /opt/NAG/cll6i24dcl/include 67 | ghc-options: -Wall -fPIC 68 | 69 | if flag(examples) 70 | buildable: True 71 | build-depends: base >=4 && <5 72 | , array 73 | , inline-c 74 | , inline-c-nag 75 | , raw-strings-qq 76 | else 77 | buildable: False 78 | 79 | 
executable one-dim-fft 80 | hs-source-dirs: examples 81 | main-is: one-dim-fft.hs 82 | c-sources: examples/one-dim-fft.c 83 | default-language: Haskell2010 84 | cc-options: -m64 85 | extra-libraries: nagc_nag imf svml irng intlc 86 | extra-lib-dirs: /opt/NAG/cll6i24dcl/lib /opt/NAG/cll6i24dcl/rtl/intel64 87 | include-dirs: /opt/NAG/cll6i24dcl/include 88 | ghc-options: -Wall -fPIC 89 | 90 | if flag(examples) 91 | buildable: True 92 | build-depends: base >=4 && <5 93 | , array 94 | , inline-c 95 | , inline-c-nag 96 | , vector 97 | else 98 | buildable: False 99 | 100 | executable nelder-mead 101 | hs-source-dirs: examples 102 | main-is: nelder-mead.hs 103 | c-sources: examples/nelder-mead.c 104 | default-language: Haskell2010 105 | cc-options: -m64 106 | extra-libraries: nagc_nag imf svml irng intlc 107 | extra-lib-dirs: /opt/NAG/cll6i24dcl/lib /opt/NAG/cll6i24dcl/rtl/intel64 108 | include-dirs: /opt/NAG/cll6i24dcl/include 109 | ghc-options: -Wall -fPIC 110 | 111 | if flag(examples) 112 | buildable: True 113 | build-depends: base >=4 && <5 114 | , array 115 | , inline-c 116 | , inline-c-nag 117 | , vector 118 | , async 119 | else 120 | buildable: False 121 | 122 | executable ode 123 | hs-source-dirs: examples 124 | main-is: ode.hs 125 | c-sources: examples/ode.c 126 | default-language: Haskell2010 127 | cc-options: -m64 128 | extra-libraries: nagc_nag imf svml irng intlc 129 | extra-lib-dirs: /opt/NAG/cll6i24dcl/lib /opt/NAG/cll6i24dcl/rtl/intel64 130 | include-dirs: /opt/NAG/cll6i24dcl/include 131 | ghc-options: -Wall -fPIC -fno-warn-name-shadowing 132 | 133 | if flag(examples) 134 | buildable: True 135 | build-depends: base >=4 && <5 136 | , inline-c 137 | , inline-c-nag 138 | , vector 139 | else 140 | buildable: False 141 | 142 | executable ode-runge-kutta 143 | hs-source-dirs: examples 144 | main-is: ode-runge-kutta.hs 145 | c-sources: examples/ode-runge-kutta.c 146 | default-language: Haskell2010 147 | cc-options: -m64 148 | extra-libraries: nagc_nag imf svml 
irng intlc 149 | extra-lib-dirs: /opt/NAG/cll6i24dcl/lib /opt/NAG/cll6i24dcl/rtl/intel64 150 | include-dirs: /opt/NAG/cll6i24dcl/include 151 | ghc-options: -Wall -fPIC -fno-warn-name-shadowing 152 | 153 | if flag(examples) 154 | buildable: True 155 | build-depends: base >=4 && <5 156 | , inline-c-nag 157 | , vector 158 | , transformers 159 | , transformers-compat 160 | , monad-control 161 | else 162 | buildable: False 163 | 164 | executable interpolation 165 | hs-source-dirs: examples 166 | main-is: interpolation.hs 167 | c-sources: examples/interpolation.c 168 | default-language: Haskell2010 169 | cc-options: -m64 170 | extra-libraries: nagc_nag imf svml irng intlc 171 | extra-lib-dirs: /opt/NAG/cll6i24dcl/lib /opt/NAG/cll6i24dcl/rtl/intel64 172 | include-dirs: /opt/NAG/cll6i24dcl/include 173 | ghc-options: -Wall -fPIC -fno-warn-name-shadowing 174 | 175 | if flag(examples) 176 | buildable: True 177 | build-depends: base >=4 && <5 178 | , inline-c-nag 179 | , vector 180 | , Chart 181 | , Chart-cairo 182 | else 183 | buildable: False 184 | -------------------------------------------------------------------------------- /src/Language/C/Inline/Nag.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE QuasiQuotes #-} 2 | {-# LANGUAGE TemplateHaskell #-} 3 | {-# LANGUAGE OverloadedStrings #-} 4 | {-# LANGUAGE RecordWildCards #-} 5 | module Language.C.Inline.Nag 6 | ( module Language.C.Inline 7 | -- * Context 8 | , nagCtx 9 | -- * Utilities 10 | , withNagError 11 | , initNagError 12 | , checkNagError 13 | -- * Types 14 | , Complex(..) 
15 | , NagError 16 | , Nag_Boolean 17 | , Nag_Integer 18 | , Nag_Comm 19 | , Nag_User 20 | ) where 21 | 22 | import Prelude hiding (exp) 23 | 24 | import Data.Functor ((<$>)) 25 | import Foreign.C.String (peekCString) 26 | import Foreign.C.Types 27 | import Foreign.Marshal.Alloc (alloca) 28 | import Foreign.Ptr (Ptr) 29 | 30 | import Language.C.Inline.Nag.Internal 31 | import Language.C.Inline 32 | 33 | context nagCtx 34 | 35 | include "" 36 | 37 | -- | Allocates a @'Ptr' 'NagError'@ which can be used with many of the 38 | -- NAG functions. After the action has run, it inspects its contents 39 | -- and reports an error if present. 40 | withNagError :: (Ptr NagError -> IO a) -> IO (Either String a) 41 | withNagError f = initNagError $ \ptr -> checkNagError ptr $ f ptr 42 | 43 | -- | Like 'withNagError', but with no error check. 44 | initNagError :: (Ptr NagError -> IO a) -> IO a 45 | initNagError f = alloca $ \ptr -> do 46 | [exp| void{ INIT_FAIL(*$(NagError *ptr)) } |] 47 | f ptr 48 | 49 | -- | Runs the provided action, and checks if the 'NagError' reports 50 | -- an error; if so, returns its message in 'Left', otherwise the action's result in 'Right'. 51 | checkNagError :: Ptr NagError -> IO a -> IO (Either String a) 52 | checkNagError ptr f = do 53 | x <- f 54 | errCode <- [exp| int { $(NagError *ptr)->code } |] 55 | if errCode /= _NE_NOERROR 56 | then do 57 | ch <- [exp| char * { $(NagError *ptr)->message } |] 58 | Left <$> peekCString ch 59 | else return $ Right x 60 | -------------------------------------------------------------------------------- /src/Language/C/Inline/Nag/Internal.hsc: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE QuasiQuotes #-} 2 | {-# LANGUAGE TemplateHaskell #-} 3 | {-# LANGUAGE OverloadedStrings #-} 4 | {-# LANGUAGE RecordWildCards #-} 5 | module Language.C.Inline.Nag.Internal 6 | ( -- * Types 7 | Complex(..)
8 | , NagError 9 | , _NE_NOERROR 10 | , Nag_Boolean 11 | , Nag_ErrorControl 12 | , Nag_Integer 13 | , Nag_Comm 14 | , Nag_User 15 | -- * Context 16 | , nagCtx 17 | ) where 18 | 19 | import qualified Data.Map as Map 20 | import Data.Monoid ((<>), mempty) 21 | import Foreign.C.Types 22 | import Foreign.Ptr (Ptr) 23 | import Foreign.Storable (Storable(..)) 24 | import qualified Language.Haskell.TH as TH 25 | 26 | import Language.C.Inline 27 | import Language.C.Inline.Context 28 | import qualified Language.C.Types as C 29 | 30 | #include 31 | 32 | -- * Records 33 | 34 | data Complex = Complex 35 | { complRe :: {-# UNPACK #-} !CDouble 36 | , complIm :: {-# UNPACK #-} !CDouble 37 | } deriving (Show, Read, Eq, Ord) 38 | 39 | instance Storable Complex where 40 | sizeOf _ = (#size Complex) 41 | alignment _ = alignment (undefined :: Ptr CDouble) 42 | peek ptr = do 43 | re <- (#peek Complex, re) ptr 44 | im <- (#peek Complex, im) ptr 45 | return Complex{complRe = re, complIm = im} 46 | poke ptr Complex{..} = do 47 | (#poke Complex, re) ptr complRe 48 | (#poke Complex, im) ptr complIm 49 | 50 | data NagError 51 | 52 | instance Storable NagError where 53 | sizeOf _ = (#size NagError) 54 | alignment _ = alignment (undefined :: Ptr ()) 55 | peek = error "peek not implemented for NagError" 56 | poke _ _ = error "poke not implemented for NagError" 57 | 58 | -- | Code indicating no errors (usually in a 'NagError' structure) 59 | _NE_NOERROR :: CInt 60 | _NE_NOERROR = (#const NE_NOERROR) 61 | 62 | data Nag_Comm 63 | instance Storable Nag_Comm where 64 | sizeOf _ = (#size Nag_Comm) 65 | alignment _ = alignment (undefined :: Ptr ()) 66 | peek _ = error "peek not implemented for Nag_Comm" 67 | poke _ _ = error "poke not implemented for Nag_Comm" 68 | 69 | data Nag_User 70 | instance Storable Nag_User where 71 | sizeOf _ = (#size Nag_User) 72 | alignment _ = alignment (undefined :: Ptr ()) 73 | peek _ = error "peek not implemented for Nag_User" 74 | poke _ _ = error "poke not implemented 
for Nag_User" 75 | 76 | -- * Enums 77 | 78 | type Nag_Boolean = CInt 79 | type Nag_ErrorControl = CInt 80 | 81 | -- * Utils 82 | 83 | type Nag_Integer = CLong 84 | 85 | -- * Context 86 | 87 | nagCtx :: Context 88 | nagCtx = baseCtx <> funCtx <> vecCtx <> ctx 89 | where 90 | ctx = mempty 91 | { ctxTypesTable = nagTypesTable 92 | } 93 | 94 | nagTypesTable :: Map.Map C.TypeSpecifier TH.TypeQ 95 | nagTypesTable = Map.fromList 96 | [ -- TODO this might not be a long, see nag_types.h 97 | (C.TypeName "Integer", [t| Nag_Integer |]) 98 | , (C.TypeName "Complex", [t| Complex |]) 99 | , (C.TypeName "NagError", [t| NagError |]) 100 | , (C.TypeName "Nag_Boolean", [t| Nag_Boolean |]) 101 | , (C.TypeName "Nag_Comm", [t| Nag_Comm |]) 102 | , (C.TypeName "Nag_User", [t| Nag_User |]) 103 | , (C.TypeName "Nag_ErrorControl", [t| Nag_ErrorControl |]) 104 | ] 105 | --------------------------------------------------------------------------------