├── .gitignore ├── .travis.yml ├── .travis ├── install_menoh.sh └── install_mkldnn.sh ├── LICENSE ├── README.md ├── Setup.hs ├── app ├── mnist_example.hs └── vgg16_example.hs ├── appveyor.yml ├── data ├── 0.png ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png ├── 9.png └── mnist.onnx ├── menoh.cabal ├── mnist_example.ipynb ├── retrieve_data.hs ├── retrieve_data.sh ├── src ├── Menoh.hs └── Menoh │ └── Base.hsc ├── stack-ghc-7.10.yaml ├── stack-ghc-7.8.yaml ├── stack-ghc-8.0.yaml ├── stack-ghc-8.2.yaml ├── stack.yaml ├── test └── test.hs └── vgg16_example.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.o 2 | *.hi 3 | .stack-work 4 | dist 5 | 6 | 7 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | dist: trusty 3 | 4 | # Do not choose a language; we provide our own build tools. 
5 | language: generic 6 | 7 | addons: 8 | apt: 9 | sources: 10 | - ubuntu-toolchain-r-test 11 | - sourceline: 'ppa:maarten-fonville/protobuf' 12 | packages: 13 | - libgmp-dev 14 | - gcc-7 15 | - g++-7 16 | - cmake-data 17 | - cmake 18 | - libopencv-dev 19 | - libprotobuf-dev 20 | - protobuf-compiler 21 | 22 | matrix: 23 | include: 24 | - env: STACK_YAML=stack.yaml 25 | compiler: ": #GHC 8.4.3" 26 | - env: STACK_YAML=stack.yaml 27 | compiler: ": #GHC 8.4.3" 28 | os: osx 29 | - env: STACK_YAML=stack-ghc-8.2.yaml 30 | compiler: ": #GHC 8.2.2" 31 | - env: STACK_YAML=stack-ghc-8.0.yaml 32 | compiler: ": #GHC 8.0.2" 33 | - env: STACK_YAML=stack-ghc-7.10.yaml 34 | compiler: ": #GHC 7.10.3" 35 | - env: STACK_YAML=stack-ghc-7.8.yaml 36 | compiler: ": #GHC 7.8.4" 37 | 38 | env: 39 | global: 40 | - MKL_DNN_REV: v0.16 41 | - MKL_DNN_INSTALL_SUFFIX: -0.16 42 | - MENOH_REV: 423af528c9627ec41f0dfbc81890e27d66121bde 43 | - MENOH_INSTALL_SUFFIX: -423af528c9627ec41f0dfbc81890e27d66121bde 44 | 45 | # Caching so the next build will be fast too. 
46 | cache: 47 | directories: 48 | - $HOME/.stack 49 | - $HOME/mkl-dnn${MKL_DNN_INSTALL_SUFFIX} 50 | - $HOME/menoh${MENOH_INSTALL_SUFFIX} 51 | 52 | before_install: 53 | # Download and unpack the stack executable 54 | - mkdir -p $HOME/.local/bin 55 | - export PATH=$HOME/.local/bin:$PATH 56 | - | 57 | if [ "$TRAVIS_OS_NAME" = "osx" ]; then 58 | travis_retry curl -L https://www.stackage.org/stack/osx-x86_64 | tar xz --strip-components=1 -C ~/.local/bin '*/stack' 59 | else 60 | travis_retry curl -L https://www.stackage.org/stack/linux-x86_64 | tar xz --wildcards --strip-components=1 -C ~/.local/bin '*/stack' 61 | fi 62 | 63 | install: 64 | # Build dependencies 65 | - | 66 | if [ "$TRAVIS_OS_NAME" = "osx" ]; then 67 | brew update 68 | brew upgrade python 69 | export PATH=/usr/local/opt/python/libexec/bin:$PATH 70 | brew install mkl-dnn protobuf 71 | else 72 | bash -e .travis/install_mkldnn.sh 73 | fi 74 | - bash -e .travis/install_menoh.sh # -e: fail this step on the first failing command 75 | - export PKG_CONFIG_PATH=$HOME/menoh${MENOH_INSTALL_SUFFIX}/share/pkgconfig:$PKG_CONFIG_PATH 76 | - export LD_LIBRARY_PATH=$HOME/menoh${MENOH_INSTALL_SUFFIX}/lib:$HOME/mkl-dnn${MKL_DNN_INSTALL_SUFFIX}/lib:$LD_LIBRARY_PATH 77 | - stack --jobs 2 --no-terminal --install-ghc build --test --bench --only-dependencies $FLAGS 78 | 79 | script: 80 | # Build the package, its tests, and its docs and run the tests 81 | - stack --jobs 2 --no-terminal test --bench --no-run-benchmarks $FLAGS 82 | - stack exec mnist_example 83 | -------------------------------------------------------------------------------- /.travis/install_menoh.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$HOME/menoh${MENOH_INSTALL_SUFFIX}/lib" ]; then 2 | if [ "$TRAVIS_OS_NAME" = "linux" ]; then export CXX="g++-7" CC="gcc-7"; fi 3 | git clone https://github.com/pfnet-research/menoh.git --recurse-submodules 4 | cd menoh 5 | git checkout $MENOH_REV 6 | # the clone checked submodules out for the default branch; re-sync them to $MENOH_REV 7 | git submodule update --init --recursive 8 | mkdir -p build 9 | cd build 10 | if [ "$TRAVIS_OS_NAME" = "linux" ]; then 11 | cmake \ 12 | -DENABLE_TEST=OFF -DENABLE_BENCHMARK=OFF -DENABLE_EXAMPLE=OFF \ 13 | -DMKLDNN_INCLUDE_DIR="$HOME/mkl-dnn${MKL_DNN_INSTALL_SUFFIX}/include" \ 14 | -DMKLDNN_LIBRARY="$HOME/mkl-dnn${MKL_DNN_INSTALL_SUFFIX}/lib/libmkldnn.so" \ 15 | -DCMAKE_INSTALL_PREFIX=$HOME/menoh${MENOH_INSTALL_SUFFIX} \ 16 | .. 17 | else 18 | cmake \ 19 | -DENABLE_TEST=OFF -DENABLE_BENCHMARK=OFF -DENABLE_EXAMPLE=OFF \ 20 | -DCMAKE_INSTALL_PREFIX=$HOME/menoh${MENOH_INSTALL_SUFFIX} \ 21 | -DCMAKE_INSTALL_NAME_DIR=$HOME/menoh${MENOH_INSTALL_SUFFIX}/lib \ 22 | .. 23 | fi 24 | make 25 | make install 26 | cd ../.. 27 | else 28 | echo "Using cached directory." 29 | fi 30 | -------------------------------------------------------------------------------- /.travis/install_mkldnn.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$HOME/mkl-dnn${MKL_DNN_INSTALL_SUFFIX}/lib" ]; then 2 | git clone https://github.com/intel/mkl-dnn.git 3 | cd mkl-dnn 4 | git checkout $MKL_DNN_REV 5 | cd scripts && bash ./prepare_mkl.sh && cd .. 6 | sed -i 's/add_subdirectory(examples)//g' CMakeLists.txt 7 | sed -i 's/add_subdirectory(tests)//g' CMakeLists.txt 8 | mkdir -p build && cd build && cmake -DCMAKE_INSTALL_PREFIX=$HOME/mkl-dnn${MKL_DNN_INSTALL_SUFFIX} .. && make 9 | make install 10 | cd ../.. 11 | else 12 | echo "Using cached directory." 13 | fi 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) 2018 Preferred Networks, Inc. 
5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # menoh-haskell 2 | 3 | [![Hackage](https://img.shields.io/hackage/v/menoh.svg)](https://hackage.haskell.org/package/menoh) 4 | [![Hackage Deps](https://img.shields.io/hackage-deps/v/menoh.svg)](https://packdeps.haskellers.com/feed?needle=menoh) 5 | [![Build Status (Travis CI)](https://travis-ci.org/pfnet-research/menoh-haskell.svg?branch=master)](https://travis-ci.org/pfnet-research/menoh-haskell) 6 | [![Build Status (AppVeyor)](https://ci.appveyor.com/api/projects/status/x4yicemyr55cj6na/branch/master?svg=true)](https://ci.appveyor.com/project/pfnet-research/menoh-haskell/branch/master) 7 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 8 | 9 | Haskell binding for the [Menoh](https://github.com/pfnet-research/menoh/) DNN inference library. 10 | 11 | # Requirements 12 | 13 | - [Menoh](https://github.com/pfnet-research/menoh/) 14 | - [The Haskell Tool Stack](https://www.haskellstack.org/) 15 | 16 | # Build 17 | 18 | Run the following commands in the repository's root directory: 19 | 20 | ``` 21 | sh retrieve_data.sh 22 | stack build 23 | ``` 24 | 25 | # Running VGG16 example 26 | 27 | Run the following commands, replacing the path with the location of your clone of this repository: 28 | 29 | ``` 30 | cd path/to/menoh-haskell 31 | stack exec vgg16_example 32 | ``` 33 | 34 | The output should look like this: 35 | 36 | ``` 37 | vgg16 example 38 | fc6_out: -21.936756 -27.385506 -18.64326 16.917625 5.599732 ... 
39 | top 5 categories are: 40 | 8 0.93506306 n01514859 hen 41 | 7 0.05933844 n01514668 cock 42 | 86 0.0033869066 n01807496 partridge 43 | 82 0.0008452002 n01797886 ruffed grouse, partridge, Bonasa umbellus 44 | 97 0.0003278699 n01847000 drake 45 | ``` 46 | 47 | Pass the `--help` option for details: 48 | 49 | ``` 50 | stack exec vgg16_example --help 51 | ``` 52 | 53 | # Installation 54 | 55 | ``` 56 | stack install 57 | ``` 58 | 59 | # Documentation 60 | 61 | * The [API reference manual](http://hackage.haskell.org/package/menoh) is available on Hackage. 62 | 63 | * See [mnist_example.hs](app/mnist_example.hs) and [vgg16_example.hs](app/vgg16_example.hs) for example usage. 64 | 65 | # License 66 | 67 | Note: `retrieve_data.sh` downloads `data/VGG16.onnx`, which was generated by onnx-chainer from a pre-trained model available 68 | at http://www.robots.ox.ac.uk/%7Evgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel 69 | 70 | That pre-trained model is released under the Creative Commons Attribution License.
71 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | main = defaultMain 3 | -------------------------------------------------------------------------------- /app/mnist_example.hs: -------------------------------------------------------------------------------- 1 | {-# OPTIONS_GHC -Wall #-} 2 | {-# LANGUAGE ScopedTypeVariables #-} 3 | {-# LANGUAGE FlexibleContexts #-} 4 | module Main (main) where 5 | 6 | import qualified Codec.Picture as Picture 7 | import qualified Codec.Picture.Types as Picture 8 | import Control.Applicative 9 | import Control.Monad 10 | import Data.Monoid 11 | import qualified Data.Vector.Generic as VG 12 | import qualified Data.Vector.Storable as VS 13 | import Data.Version 14 | import Options.Applicative 15 | import Menoh 16 | import System.FilePath 17 | import Text.Printf 18 | 19 | import Paths_menoh (getDataDir) 20 | 21 | main :: IO () 22 | main = do 23 | putStrLn "mnist example" 24 | dataDir <- getDataDir 25 | opt <- execParser (parserInfo (combine dataDir "data")) 26 | 27 | let input_dir = optInputPath opt 28 | -- Read 0.png .. 9.png from input_dir, keeping only each image's luma plane 29 | images <- forM [(0::Int)..9] $ \i -> do 30 | let fname :: String 31 | fname = printf "%d.png" i 32 | ret <- Picture.readImage $ combine input_dir fname 33 | case ret of 34 | Left e -> error e 35 | Right img -> return (Picture.extractLumaPlane $ Picture.convertRGB8 img, i, fname) 36 | -- Model input is NCHW: one 1x28x28 channel per image in the batch 37 | let batch_size = length images 38 | channel_num = 1 39 | height = 28 40 | width = 28 41 | category_num = 10 42 | 43 | input_dims, output_dims :: Dims 44 | input_dims = [batch_size, channel_num, height, width] 45 | output_dims = [batch_size, category_num] 46 | 47 | -- Aliases to onnx's node input and output tensor name 48 | mnist_in_name = "139900320569040" 49 | mnist_out_name = "139898462888656" 50 | 51 | -- Load ONNX model data 52 | model_data <- makeModelDataFromONNXFile (optModelPath opt) 53 | 54 |
-- Specify inputs and outputs 55 | vpt <- makeVariableProfileTable 56 | [(mnist_in_name, DTypeFloat, input_dims)] 57 | [mnist_out_name] 58 | model_data 59 | optimizeModelData model_data vpt 60 | 61 | -- Construct computation primitive list and memories 62 | model <- makeModel vpt model_data "mkldnn" 63 | 64 | -- Copy input image data to model's input array 65 | writeBuffer model mnist_in_name [VG.map fromIntegral (Picture.imageData img) :: VS.Vector Float | (img,_,_) <- images] 66 | 67 | -- Run inference 68 | run model 69 | 70 | -- Get output 71 | (vs :: [VS.Vector Float]) <- readBuffer model mnist_out_name 72 | 73 | -- Examine the results 74 | forM_ (zip images vs) $ \((_img,expected,fname), scores) -> do 75 | let guessed = VG.maxIndex scores 76 | putStrLn fname 77 | printf "Expected: %d Guessed: %d\n" expected guessed 78 | putStrLn $ "Scores: " ++ show (zip [(0::Int)..] (VG.toList scores)) 79 | putStrLn $ "Probabilities: " ++ show (zip [(0::Int)..] (VG.toList (softmax scores))) 80 | putStrLn "" 81 | 82 | -- ------------------------------------------------------------------------- 83 | 84 | data Options 85 | = Options 86 | { optInputPath :: FilePath 87 | , optModelPath :: FilePath 88 | } 89 | 90 | optionsParser :: FilePath -> Parser Options 91 | optionsParser dataDir = Options 92 | <$> inputPathOption 93 | <*> modelPathOption 94 | where 95 | inputPathOption = strOption 96 | $ long "input" 97 | <> short 'i' 98 | <> metavar "DIR" 99 | <> help "input image path" 100 | <> value dataDir 101 | <> showDefault 102 | modelPathOption = strOption 103 | $ long "model" 104 | <> short 'm' 105 | <> metavar "PATH" 106 | <> help "onnx model path" 107 | <> value (combine dataDir "mnist.onnx") 108 | <> showDefault 109 | 110 | parserInfo :: FilePath -> ParserInfo Options 111 | parserInfo dir = info (helper <*> versionOption <*> optionsParser dir) 112 | $ fullDesc 113 | <> header "mnist_example - an example program of Menoh haskell binding" 114 | where 115 | versionOption :: Parser (a -> a)
116 | versionOption = infoOption (showVersion version) 117 | $ hidden 118 | <> long "version" 119 | <> help "Show version" 120 | 121 | -- ------------------------------------------------------------------------- 122 | -- Softmax that subtracts the maximum before exp (numerical stability); an empty vector maps to an empty vector 123 | softmax :: (Real a, Floating a, VG.Vector v a) => v a -> v a 124 | softmax v | VG.null v = VG.empty 125 | softmax v = VG.map (/ s) v' 126 | where 127 | m = VG.maximum v 128 | v' = VG.map (\x -> exp (x - m)) v 129 | s = VG.sum v' 130 | 131 | -- ------------------------------------------------------------------------- 132 | -------------------------------------------------------------------------------- /app/vgg16_example.hs: -------------------------------------------------------------------------------- 1 | {-# OPTIONS_GHC -Wall #-} 2 | {-# LANGUAGE ScopedTypeVariables #-} 3 | {-# LANGUAGE FlexibleContexts #-} 4 | module Main (main) where 5 | 6 | import qualified Codec.Picture as Picture 7 | import Control.Applicative 8 | import Control.Monad 9 | import Data.List 10 | import Data.Monoid 11 | import Data.Ord 12 | import qualified Data.Vector as V 13 | import qualified Data.Vector.Storable as VS 14 | import Data.Version 15 | import Options.Applicative 16 | import Menoh 17 | import Text.Printf 18 | 19 | main :: IO () 20 | main = do 21 | putStrLn "vgg16 example" 22 | opt <- execParser parserInfo 23 | 24 | let batch_size = 1 25 | channel_num = 3 26 | height = 224 27 | width = 224 28 | category_num = 1000 29 | input_dims, output_dims :: Dims 30 | input_dims = [batch_size, channel_num, height, width] 31 | output_dims = [batch_size, category_num] 32 | -- Load the input image and convert it to a 3x224x224 NCHW Float vector (see convert) 33 | ret <- Picture.readImage (optInputImagePath opt) 34 | let image_data = 35 | case ret of 36 | Left e -> error e 37 | Right img -> convert width height img 38 | 39 | -- Aliases to onnx's node input and output tensor name 40 | let conv1_1_in_name = "Input_0" 41 | fc6_out_name = "Gemm_0" 42 | softmax_out_name = "Softmax_0" 43 | 44 | -- Load ONNX model data 45 | model_data <- makeModelDataFromONNXFile
(optModelPath opt) 46 | 47 | -- Specify inputs and outputs 48 | vpt <- makeVariableProfileTable 49 | [(conv1_1_in_name, DTypeFloat, input_dims)] 50 | [fc6_out_name, softmax_out_name] 51 | model_data 52 | optimizeModelData model_data vpt 53 | 54 | -- Construct computation primitive list and memories 55 | model <- makeModel vpt model_data "mkldnn" 56 | 57 | -- Copy input image data to model's input array 58 | writeBuffer model conv1_1_in_name [image_data] 59 | 60 | -- Run inference 61 | run model 62 | 63 | -- Get output 64 | ([fc6_out] :: [V.Vector Float]) <- readBuffer model fc6_out_name 65 | putStr "fc6_out: " 66 | forM_ [0..4] $ \i -> do 67 | putStr $ show $ fc6_out V.! i 68 | putStr " " 69 | putStrLn "..." 70 | 71 | ([softmax_out] :: [V.Vector Float]) <- readBuffer model softmax_out_name 72 | 73 | categories <- liftM lines $ readFile (optSynsetWordsPath opt) 74 | let k = 5 75 | scores <- forM [0 .. V.length softmax_out - 1] $ \i -> do 76 | return (i, softmax_out V.! i) 77 | printf "top %d categories are:\n" k 78 | forM_ (take k $ sortBy (flip (comparing snd)) scores) $ \(i,p) -> do 79 | printf "%d %f %s\n" i p (categories !!
i) 80 | 81 | -- ------------------------------------------------------------------------- 82 | 83 | data Options 84 | = Options 85 | { optInputImagePath :: FilePath 86 | , optModelPath :: FilePath 87 | , optSynsetWordsPath :: FilePath 88 | } 89 | 90 | optionsParser :: Parser Options 91 | optionsParser = Options 92 | <$> inputImageOption 93 | <*> modelPathOption 94 | <*> synsetWordsPathOption 95 | where 96 | inputImageOption = strOption 97 | $ long "input-image" 98 | <> short 'i' 99 | <> metavar "PATH" 100 | <> help "input image path" 101 | <> value "data/Light_sussex_hen.jpg" 102 | <> showDefault 103 | modelPathOption = strOption 104 | $ long "model" 105 | <> short 'm' 106 | <> metavar "PATH" 107 | <> help "onnx model path" 108 | <> value "data/vgg16.onnx" 109 | <> showDefault 110 | synsetWordsPathOption = strOption 111 | $ long "synset-words" 112 | <> short 's' 113 | <> metavar "PATH" 114 | <> help "synset words path" 115 | <> value "data/synset_words.txt" 116 | <> showDefault 117 | 118 | parserInfo :: ParserInfo Options 119 | parserInfo = info (helper <*> versionOption <*> optionsParser) 120 | $ fullDesc 121 | <> header "vgg16_example - an example program of Menoh haskell binding" 122 | where 123 | versionOption :: Parser (a -> a) 124 | versionOption = infoOption (showVersion version) 125 | $ hidden 126 | <> long "version" 127 | <> help "Show version" 128 | 129 | -- ------------------------------------------------------------------------- 130 | -- Center-crop to a square, resize to w x h, and lay the pixels out as NCHW floats 131 | convert :: Int -> Int -> Picture.DynamicImage -> VS.Vector Float 132 | convert w h = reorderToNCHW . resize (w,h) . crop .
Picture.convertRGB8 133 | -- Largest square centred in the image 134 | crop :: Picture.Pixel a => Picture.Image a -> Picture.Image a 135 | crop img = Picture.generateImage (\x y -> Picture.pixelAt img (base_x + x) (base_y + y)) shortEdge shortEdge 136 | where 137 | shortEdge = min (Picture.imageWidth img) (Picture.imageHeight img) 138 | base_x = (Picture.imageWidth img - shortEdge) `div` 2 139 | base_y = (Picture.imageHeight img - shortEdge) `div` 2 140 | -- Nearest-neighbour sampling: each output pixel copies one source pixel 141 | -- TODO: Should we do some kind of interpolation? 142 | resize :: Picture.Pixel a => (Int,Int) -> Picture.Image a -> Picture.Image a 143 | resize (w,h) img = Picture.generateImage (\x y -> Picture.pixelAt img (x * orig_w `div` w) (y * orig_h `div` h)) w h 144 | where 145 | orig_w = Picture.imageWidth img 146 | orig_h = Picture.imageHeight img 147 | -- Planar layout: all R values, then all G, then all B, each minus a fixed per-channel mean 148 | reorderToNCHW :: Picture.Image Picture.PixelRGB8 -> VS.Vector Float 149 | reorderToNCHW img = VS.generate (3 * Picture.imageHeight img * Picture.imageWidth img) f 150 | where 151 | f i = 152 | case Picture.pixelAt img x y of 153 | Picture.PixelRGB8 r g b -> 154 | case ch of 155 | 0 -> fromIntegral r - 123.68 156 | 1 -> fromIntegral g - 116.779 157 | 2 -> fromIntegral b - 103.939 158 | _ -> error "reorderToNCHW: channel index out of range" 159 | where 160 | (ch,m) = i `divMod` (Picture.imageWidth img * Picture.imageHeight img) 161 | (y,x) = m `divMod` Picture.imageWidth img 162 | 163 | -- ------------------------------------------------------------------------- 164 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | platform: 2 | - x64 3 | 4 | environment: 5 | global: 6 | STACK_ROOT: "c:\\sr" 7 | matrix: 8 | - TARGET: mingw 9 | - TARGET: msvc 10 | 11 | cache: 12 | - "c:\\sr" # stack root, short paths == less problems 13 | 14 | install: 15 | - set SSL_CERT_FILE=C:\msys64\mingw64\ssl\cert.pem 16 | - set SSL_CERT_DIR=C:\msys64\mingw64\ssl\certs 17 | # Some conditional statements are split to avoid "\Microsoft
was unexpected at this time." error. 18 | # https://support.microsoft.com/ja-jp/help/2524009/error-running-command-shell-scripts-that-include-parentheses 19 | - if [%TARGET%]==[mingw] set PATH=C:\msys64\mingw64\bin;C:\msys64\usr\bin;%PATH% 20 | - if [%TARGET%]==[msvc] set PATH=C:\msys64\mingw64\bin;%PATH% 21 | - if [%TARGET%]==[mingw] ( 22 | curl -omingw-w64-x86_64-mkl-dnn-0.16-1-x86_64.pkg.tar.xz -L https://github.com/msakai/mkl-dnn/releases/download/v0.16/mingw-w64-x86_64-mkl-dnn-0.16-1-x86_64.pkg.tar.xz && 23 | pacman -U --noconfirm mingw-w64-x86_64-mkl-dnn-0.16-1-x86_64.pkg.tar.xz && 24 | curl -omingw-w64-x86_64-menoh-1.1.1-1-x86_64.pkg.tar.xz -L https://github.com/pfnet-research/menoh/releases/download/v1.1.1/mingw-w64-x86_64-menoh-1.1.1-1-x86_64.pkg.tar.xz && 25 | pacman -U --noconfirm mingw-w64-x86_64-menoh-1.1.1-1-x86_64.pkg.tar.xz 26 | ) else ( 27 | curl -omkl-dnn-0.16-win64.7z -L --insecure https://github.com/msakai/mkl-dnn/releases/download/v0.16/mkl-dnn-0.16-win64.7z && 28 | 7z x mkl-dnn-0.16-win64.7z && 29 | curl -omenoh_prebuild_win_v1.1.1.zip -L --insecure https://github.com/pfnet-research/menoh/releases/download/v1.1.1/menoh_prebuild_win_v1.1.1.zip && 30 | 7z x menoh_prebuild_win_v1.1.1.zip 31 | ) 32 | - if [%TARGET%]==[msvc] set PKG_CONFIG_PATH=c:\projects\menoh-haskell\menoh_prebuild_win_v1.1.1\share\pkgconfig;%PKG_CONFIG_PATH% 33 | - if [%TARGET%]==[msvc] set PATH=c:\projects\menoh-haskell\menoh_prebuild_win_v1.1.1\bin;c:\projects\menoh-haskell\mkl-dnn-0.16-win64\bin;c:\projects\menoh-haskell\mkl-dnn-0.16-win64\lib;%PATH% 34 | 35 | - curl -ostack.zip -L --insecure http://www.stackage.org/stack/windows-x86_64 36 | - 7z x stack.zip stack.exe 37 | - stack setup > nul 38 | 39 | build_script: 40 | - echo "" | stack --no-terminal build --test --no-run-tests 41 | 42 | test_script: 43 | - echo "" | stack --no-terminal test 44 | - stack exec mnist_example 45 | -------------------------------------------------------------------------------- /data/0.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/menoh-haskell/d8a3dd48296461499d054625ef97169350812c23/data/0.png -------------------------------------------------------------------------------- /data/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/menoh-haskell/d8a3dd48296461499d054625ef97169350812c23/data/1.png -------------------------------------------------------------------------------- /data/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/menoh-haskell/d8a3dd48296461499d054625ef97169350812c23/data/2.png -------------------------------------------------------------------------------- /data/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/menoh-haskell/d8a3dd48296461499d054625ef97169350812c23/data/3.png -------------------------------------------------------------------------------- /data/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/menoh-haskell/d8a3dd48296461499d054625ef97169350812c23/data/4.png -------------------------------------------------------------------------------- /data/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/menoh-haskell/d8a3dd48296461499d054625ef97169350812c23/data/5.png -------------------------------------------------------------------------------- /data/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/menoh-haskell/d8a3dd48296461499d054625ef97169350812c23/data/6.png 
-------------------------------------------------------------------------------- /data/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/menoh-haskell/d8a3dd48296461499d054625ef97169350812c23/data/7.png -------------------------------------------------------------------------------- /data/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/menoh-haskell/d8a3dd48296461499d054625ef97169350812c23/data/8.png -------------------------------------------------------------------------------- /data/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/menoh-haskell/d8a3dd48296461499d054625ef97169350812c23/data/9.png -------------------------------------------------------------------------------- /data/mnist.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/menoh-haskell/d8a3dd48296461499d054625ef97169350812c23/data/mnist.onnx -------------------------------------------------------------------------------- /menoh.cabal: -------------------------------------------------------------------------------- 1 | name: menoh 2 | version: 0.4.0 3 | license: MIT 4 | license-file: LICENSE 5 | author: Masahiro Sakai 6 | maintainer: Masahiro Sakai 7 | copyright: Copyright 2018 Preferred Networks, Inc. 8 | category: Machine Learning, Deep Learning 9 | synopsis: Haskell binding for Menoh DNN inference library 10 | description: Menoh is a MKL-DNN based DNN inference library for ONNX models. See https://github.com/pfnet-research/menoh/ for details. 
11 | build-type: Simple 12 | cabal-version: >=1.10 13 | extra-source-files: 14 | README.md 15 | retrieve_data.hs 16 | retrieve_data.sh 17 | data-files: 18 | data/*.png 19 | data/mnist.onnx 20 | tested-with: 21 | GHC ==7.8.4 22 | GHC ==7.10.3 23 | GHC ==8.0.2 24 | GHC ==8.2.2 25 | GHC ==8.4.3 26 | 27 | source-repository head 28 | type: git 29 | location: https://github.com/pfnet-research/menoh-haskell/ 30 | 31 | flag buildexamples 32 | description: build example programs 33 | default: True 34 | manual: True 35 | 36 | library 37 | hs-source-dirs: src 38 | exposed-modules: 39 | Menoh 40 | Menoh.Base 41 | other-modules: 42 | Paths_menoh 43 | other-extensions: 44 | CPP 45 | , DeriveDataTypeable 46 | , FlexibleContexts 47 | , ForeignFunctionInterface 48 | , ScopedTypeVariables 49 | build-depends: 50 | base >=4.7 && <5 51 | , aeson >=0.8 && <1.5 52 | , bytestring >=0.10 && <0.11 53 | , containers >=0.5 && <0.7 54 | , monad-control >=1.0 && <1.1 55 | , transformers >=0.3 && <0.6 56 | , vector >=0.10 && <0.13 57 | pkgconfig-depends: 58 | menoh >=1.1.1 && <2.0.0 59 | default-language: Haskell2010 60 | 61 | executable vgg16_example 62 | hs-source-dirs: app 63 | main-is: vgg16_example.hs 64 | build-depends: 65 | base 66 | -- convertRGB8 requires JuicyPixels >=3.2.7 67 | , JuicyPixels >=3.2.7 && <3.4 68 | , optparse-applicative >=0.11 && <0.15 69 | , menoh 70 | , vector 71 | default-language: Haskell2010 72 | If !flag(buildexamples) 73 | buildable: False 74 | 75 | executable mnist_example 76 | hs-source-dirs: app 77 | other-modules: Paths_menoh 78 | main-is: mnist_example.hs 79 | build-depends: 80 | base 81 | , filepath >=1.3 && <1.5 82 | , JuicyPixels 83 | , optparse-applicative 84 | , menoh 85 | , vector 86 | default-language: Haskell2010 87 | If !flag(buildexamples) 88 | buildable: False 89 | 90 | Test-suite Test 91 | type: exitcode-stdio-1.0 92 | hs-source-dirs: test 93 | main-is: test.hs 94 | other-modules: 95 | Paths_menoh 96 | build-depends: 97 | base >=4 && <5 98 | , 
bytestring 99 | , async >=2.0.2 100 | , filepath >=1.3 && <1.5 101 | , JuicyPixels 102 | , menoh 103 | , vector 104 | , tasty >=0.10.1 105 | , tasty-hunit >=0.9 && <0.11 106 | , tasty-th 107 | default-language: Haskell2010 108 | ghc-options: -threaded 109 | -------------------------------------------------------------------------------- /mnist_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# MNIST example for Menoh Haskell " 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "{-# LANGUAGE ScopedTypeVariables #-}\n", 17 | "import Control.Applicative\n", 18 | "import Control.Monad\n", 19 | "import System.FilePath\n", 20 | "import Text.Printf\n", 21 | "import Menoh" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "data": { 31 | "text/plain": [ 32 | "Version {versionBranch = [1,0,2], versionTags = []}" 33 | ] 34 | }, 35 | "metadata": {}, 36 | "output_type": "display_data" 37 | } 38 | ], 39 | "source": [ 40 | "Menoh.version" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "data": { 50 | "text/plain": [ 51 | "Version {versionBranch = [0,2,0], versionTags = []}" 52 | ] 53 | }, 54 | "metadata": {}, 55 | "output_type": "display_data" 56 | } 57 | ], 58 | "source": [ 59 | "Menoh.bindingVersion" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "dataDir = \"data\"" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "## Step 1: Prepare input for inference" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 5, 81 | "metadata": {}, 82 | "outputs": 
[], 83 | "source": [ 84 | "import qualified Codec.Picture as Picture\n", 85 | "import qualified Codec.Picture.Types as Picture" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 6, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "data": { 95 | "text/plain": [] 96 | }, 97 | "metadata": {}, 98 | "output_type": "display_data" 99 | } 100 | ], 101 | "source": [ 102 | "images <- forM [0..9] $ \\i -> do\n", 103 | " let fname :: String\n", 104 | " fname = printf \"%d.png\" i\n", 105 | " ret <- Picture.readImage $ dataDir fname\n", 106 | " case ret of\n", 107 | " Left e -> error e\n", 108 | " Right img -> return (Picture.extractLumaPlane $ Picture.convertRGB8 img, i, fname)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 7, 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "data": { 118 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABLklEQVR4nGNgoBPg4eHBJWVfMHNlgTE2PZptC9f////uVW+BPaau5S/+f5k0qbf3//8Ce1SzeewLPv7+vtEWZPTsVz2aKJKePS+3fD4eA2annDzphixnXP6/vHD7dm0wp/769WAkOfm0V7M85e3t5cE2TPp/xQZJ0vfYyQReGEdzy7vZ8kgaK1/NgjuBJ/HkUV8cGnkKD3yqR9LIU/lqMVyj/ebf90KRNNosfpUD1ShvX/Du+hRtJBtzX82yhbIjth9L9kUylCH47AWYRp6m/z3IUgwMXa+vu0KZnitfpaHIMSz8dR7qdPvy/wUmqJKTvz8OA8WCfETTf4zYSj55fWlrbW1rz0lMOZDzn70Cgv8nscQyA4NJ2uKVK1c2RchjkWNg4JVUVZXkxSpFdQAAJYxzwP/IyZQAAAAASUVORK5CYII=" 119 | }, 120 | "metadata": {}, 121 | "output_type": "display_data" 122 | }, 123 | { 124 | "data": { 125 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAApElEQVR4nGNgoAfgsbc3xikp3vM4Bqek6cpXhTglE06e9MUlJ9/0v0kel2TylSs5uOR4Jn2/YoNL0mbzp0k4TU15vdYfp6ltrxZL4JLUXPyqkheXZOjJk6G45OQrX83SxCVpc+xkDk5Tk0+edMUlx1P5ajFOU0O3vMPt1vrv13G6lWH2p/04TTWe9L8Wp6kxr1YG4jQ19nEPznBlMLG3x2kq6QAAM4AxxP3dxw0AAAAASUVORK5CYII=" 126 | }, 127 | "metadata": {}, 128 | "output_type": "display_data" 129 | }, 130 | { 131 | "data": { 132 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABLklEQVR4nGNgGHDAo+kaHBxsI49FSt63cvHJkyfPnCwwxpCzLzj56hVQ8uy9VzFoUjYRTf9PLq5MDg5O2fGmjQdZis130rPTOYnavEC2RM+ruSiSPKWf75ZD3eG6610XL6rklkptqPnTP7xKRrGRTV5eCqos5/W7STY4fOq5/F0vDjl5z0nP0kxwyEXsPhCBQ84m58SBCGxhxyAf3HP084t5MTyYUjbJXTe//P73/c3+ck90afkprz89/f375tnX7/6vnJSD4lz5iOvfL57/sjrFP6er99X/1+tndyXD7OaJOHB15Z7/BWB3mqStev3///eT9lBJzeUfjx591wTj2uRM2vHsegSUp33yw6efx3yR7HHbsQNmr3zB9Xer0pA9KG9vD+fKAxMO9nChAQAAB+d8B+JVXPsAAAAASUVORK5CYII=" 133 | }, 134 | "metadata": {}, 135 | "output_type": "display_data" 136 | }, 137 | { 138 | "data": { 139 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABIUlEQVR4nGNgGEggbxMMBK6aPFik7AsOnASBiZ6YsvLzt0MkX/XYY9EZATI1OKXyZQGmLAzY1v8vkMcqw6PpmrPj+klNrO7NmXXy5OtXPRiS8vIRBUDnnDzWlmbPiyZnP+nuif+PH26ZHIzFQvv/v/99/vTx05VKYyySHceP71q55d7r1zk2mKHAK6mqqmqf0bnp3coIrD4BqUm8970JTau8PMwd2j3v1nuiyBkXbIcLmEz5X44iGfPqSiCMzTPlfy+KZO2bLXYwtvjKV60okoWvVprCNAaufBmLIul74HQGKEB5eNxylr9MM0F1bMTJmwunTOmZtHjmu1dockDZglf/geD1u9eLC9DlgB5I650yZcry3nxb9CiBAV5eXDJYAABkB3YahSR/xQAAAABJRU5ErkJggg==" 140 | }, 141 | "metadata": {}, 142 | "output_type": "display_data" 143 | }, 144 | { 145 | "data": { 146 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABEElEQVR4nGNgGGAgLy+PW9J+vj2QlDI2lsIiV7AdJJm8f78rhhxPwf9KoLE8Xd8fYErarHyVA6I2fJhugKEx99ViWyCd+2qlLYZGzcWvKnkZGIx7/5fzYkiWXL+eDNSf9mqxJ4ac/KT/F2zg+tFAxIXrTUC3hp48GYrpycr/UzQZGMTbXi/XxJTsebWSh4HB8/inHkxTGdL+zw6Vly9/9yYNUw4Ydq8O5GTcfLCjPDg4uTLNHk225+SFxx/+PzgJBK/+o0nyaiec/PznO1hyZQGaJNCrwUe+PGsKDg5208biKIamN5tNsAiDAU/bq8USuCRDt7zDEnZQOzu+X/DFpVH+wJUmnClM3stLG5ccJgAAgTNlADVeUsoAAAAASUVORK5CYII=" 147 | }, 148 | "metadata": {}, 149 | "output_type": "display_data" 150 | }, 151 | { 152 | "data": { 153 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA+0lEQVR4nGNgoC1gE1NSV1YCA3EedEn50vVb1q4Egx57dEmb3e8+vX4FBo8xJO3P3lrXUgsGsSbokq4rXyXjdFDoyZOhuOTkK1/N0sQlqXnsZCUvDjme+JUvfYGUOBZPMvB0vVk3ubVnysqVi9OM0SUlet49fv3q/3+gL19iyPK0vLqyduWU3rbatsUYsrz2EWFmkiAX8XpNwtSLUBe89lUaLkkGibXvenBK8s78PwlFQF5eHsaySZ79vwBZzn7KlcWg+CisrOw5evJVAUqc2f///w4clSf/vzp5Ei2y7ZsWbgEngpNTZuW4aaMHMq+kKhCYaWvz4gp+OgEA7pds7vwUkw0AAAAASUVORK5CYII=" 154 | }, 155 | "metadata": {}, 156 | "output_type": "display_data" 157 | }, 158 | { 159 | "data": { 160 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABFElEQVR4nGNgoAOQT24qMMYhZ19w/fsre6xSPDbl/2/eW4ld0r7l1axqfzNebHLGBS/L7bHKAN2S9nKyLS6H+h47mYBDH4N85atZmuRo5Kl8tRinRs3FH7ZMri0sjXHlwZQ0vf7y97tXr66/3hwhjyFZ+eD/x/0rV15//e7CwiYbNMn8/3/O2quaxZZ3nf3/uglNc/6t36sgXspZ++FehyuKpPKWN71Qpk3fp4/1KJISy98thlmV8fBXH4okb87rlzlQdvKVK1moltrM/D8JkgawhYfnrFdpIFdqJqx/mYYekLwRB47lBMfkdFx8lWbCgA7kI7YfOHDmyuujBZhyQFn7ggNnTzb5YoYfRNom2AaHFCYAAAwlbZ2jv1/DAAAAAElFTkSuQmCC" 161 | }, 162 | "metadata": {}, 163 | "output_type": "display_data" 164 | }, 165 | { 166 | "data": { 167 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA5klEQVR4nGNgoC2Qkpdnwynpu6pWE6dk0+uXG6cAQW+lrzyG5LRf/yHg3ZWFBb5oklO/3VoF0rnw3usvP1fao8jx9LxeLgFiaJdue/37P4okj+fil2m8YIf13vj2ZT6KpOakV5UmIIZx0/fXC/NtUUzV3r4d7Aj5tDcnO2zQnCNvbw/2QM6Vm6WSODxrPOn/GXR9cP1pb7aUYwYBRC7i2IEIXhwaIw6cjMChD+id/7245Bhy7r3LwSVnP+n7Slsccja9/2fjco3nzC/v0nDo4yn//6XXBJfkpNeTcFnIwGtvj0sf9QEABIxZTHLDDfMAAAAASUVORK5CYII=" 168 | }, 169 | "metadata": {}, 170 | "output_type": "display_data" 171 | }, 172 | { 173 | "data": { 174 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABT0lEQVR4nGNgGEAg71tQCwS9hb5YJCNOvnvz6tWrd9+a5BlCy8s1UeQWfni8ZeXKlfvvXUwzrv9+BUWy8v+9JntVVVX7jANn+5a+m6KNLDnl3X4TqN0zVv1PM0GxsvfdfjswQzO+41WhPap7Kr7fiwXRvh03X8Wi6mNgCF35rt6eQd6+99u9cnQ5Bl77gv+T5m65evN0kza6HNgz/3/++X99UgQWOfuC6/9//v728X+BMboUD9DUk5NnTenr2fGuxwZN0rPnVWWEPMju+tcvg1HljMv/w9yYcf16BqpkzKuVgVBm8MmTaDqb3+33xSmZ9fp1FpSZceECmqTr/o99PGCW/Mx3111RJXkL/h+alCwvH1ze9/DXLk1USQb78k/fr/T0XHn98eGbSl40SQbb3itX/r+6/np7UwRGuAMts8lpmlUf64ahDWYzLy4ZbAAA1QuPQFv7dcYAAAAASUVORK5CYII=" 175 | }, 176 | "metadata": {}, 177 | "output_type": "display_data" 178 | }, 179 | { 180 | "data": { 181 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABM0lEQVR4nGNgoCWQtwl2NVECAnEeDDn7gpMnT64EgcW5Njzocv+Bkq8gYFa8JrI0T8H/2TnByYW1QNC2+NXJWRHyCEnNla8qYWxe2/xjr44W2MM1u508mYzkNt/KQy97PGGywSdPBiM7QbN226f5mlCOzcmTOUg6I5oOf/p80Q3GLf8/yVeThwdq6NH/n35/ugE3y7Pn1dEpiyuTgzPqJ518dbJ3y7MvV+CSvPYFx/6/enXy5PXX/4Ee0az8+PpiMLJFlb2LQZILE7UlPBe/u7LbBiWYeLXdgjOCgWKak151pdnLM2ADPAknj/pilQEC1+VvKrFrA2rseH0Ft8Zjv5bj0sgQ++ZNBy45nvrrV4JxSbpufteF09Ss16+zcMnJd73b4opTcvP2CF6ckvY4wg0TAADdg4Q1GF8YEAAAAABJRU5ErkJggg==" 182 | }, 183 | "metadata": {}, 184 | "output_type": "display_data" 185 | } 186 | ], 187 | "source": [ 188 | "[img | (img, _, _) <- images]" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "## Step 2: Loading ONNX model" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 8, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "batch_size = length images\n", 205 | "channel_num = 1\n", 206 | "height = 28\n", 207 | "width = 28\n", 208 | "category_num = 10\n", 209 | "\n", 210 | "input_dims, output_dims :: Dims\n", 211 | "input_dims = [batch_size, channel_num, height, width]\n", 212 | "output_dims = [batch_size, category_num]\n", 213 | "\n", 214 | "-- Aliases to onnx's node input and output tensor name\n", 215 | "mnist_in_name = \"139900320569040\"\n", 216 | 
"mnist_out_name = \"139898462888656\"" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 9, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "-- Load ONNX model data\n", 226 | "model_data <- makeModelDataFromONNXFile (dataDir \"mnist.onnx\")" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 10, 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "-- Specify inputs and outputs\n", 236 | "vpt <- makeVariableProfileTable\n", 237 | " [(mnist_in_name, DTypeFloat, input_dims)]\n", 238 | " [mnist_out_name]\n", 239 | " model_data\n", 240 | "optimizeModelData model_data vpt" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 11, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "-- Construct computation primitive list and memories\n", 250 | "model <- makeModel vpt model_data \"mkldnn\"" 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": {}, 256 | "source": [ 257 | "## Step 3: Run inference" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 12, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "import Data.Word\n", 267 | "import qualified Data.Vector.Generic as VG\n", 268 | "import qualified Data.Vector.Storable as VS\n", 269 | "\n", 270 | "-- Copy input image data to model's input array\n", 271 | "writeBuffer model mnist_in_name [VG.map fromIntegral (Picture.imageData img) :: VS.Vector Float | (img,_,_) <- images]\n", 272 | "\n", 273 | "-- Run inference\n", 274 | "run model\n", 275 | "\n", 276 | "-- Get output\n", 277 | "(vs :: [VS.Vector Float]) <- readBuffer model mnist_out_name" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "## Step 4: Examine the results" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 13, 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "softmax :: (Real a, 
Floating a, VG.Vector v a) => v a -> v a\n", 294 | "softmax v | VG.null v = VG.empty\n", 295 | "softmax v = VG.map (/ s) v'\n", 296 | " where\n", 297 | " m = VG.maximum v\n", 298 | " v' = VG.map (\\x -> exp (x - m)) v\n", 299 | " s = VG.sum v'" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": 14, 305 | "metadata": {}, 306 | "outputs": [ 307 | { 308 | "data": { 309 | "text/plain": [ 310 | "0.png\n", 311 | "Expected: 0 Guessed: 0\n", 312 | "Scores: [(0,9777.892),(1,-7585.3193),(2,-2358.066),(3,-6089.8984),(4,-6311.9727),(5,-6329.25),(6,-2478.122),(7,-4269.2573),(8,-5181.277),(9,-2872.3535)]\n", 313 | "Probabilities: [(0,1.0),(1,0.0),(2,0.0),(3,0.0),(4,0.0),(5,0.0),(6,0.0),(7,0.0),(8,0.0),(9,0.0)]\n", 314 | "\n", 315 | "1.png\n", 316 | "Expected: 1 Guessed: 1\n", 317 | "Scores: [(0,-2165.809),(1,4201.37),(2,-1893.1384),(3,-7467.487),(4,-1335.353),(5,-4800.9795),(6,-416.61957),(7,-2426.936),(8,-503.85828),(9,-5068.9653)]\n", 318 | "Probabilities: [(0,0.0),(1,1.0),(2,0.0),(3,0.0),(4,0.0),(5,0.0),(6,0.0),(7,0.0),(8,0.0),(9,0.0)]\n", 319 | "\n", 320 | "2.png\n", 321 | "Expected: 2 Guessed: 2\n", 322 | "Scores: [(0,-6348.131),(1,-337.71472),(2,7276.1235),(3,-1548.1252),(4,-4145.504),(5,-6463.499),(6,-3721.0723),(7,-1077.4471),(8,-3267.8303),(9,-6660.838)]\n", 323 | "Probabilities: [(0,0.0),(1,0.0),(2,1.0),(3,0.0),(4,0.0),(5,0.0),(6,0.0),(7,0.0),(8,0.0),(9,0.0)]\n", 324 | "\n", 325 | "3.png\n", 326 | "Expected: 3 Guessed: 3\n", 327 | "Scores: [(0,-6894.149),(1,-2907.7124),(2,-1893.7008),(3,7371.175),(4,-8884.441),(5,-4217.8105),(6,-8523.427),(7,-4278.061),(8,-1147.6962),(9,-2374.1099)]\n", 328 | "Probabilities: [(0,0.0),(1,0.0),(2,0.0),(3,1.0),(4,0.0),(5,0.0),(6,0.0),(7,0.0),(8,0.0),(9,0.0)]\n", 329 | "\n", 330 | "4.png\n", 331 | "Expected: 4 Guessed: 4\n", 332 | "Scores: [(0,-4422.2188),(1,-1456.4163),(2,-4136.614),(3,-2221.086),(4,3840.482),(5,-3867.4666),(6,-71.5983),(7,-2693.855),(8,-2997.8684),(9,-4552.251)]\n", 333 | "Probabilities: 
[(0,0.0),(1,0.0),(2,0.0),(3,0.0),(4,1.0),(5,0.0),(6,0.0),(7,0.0),(8,0.0),(9,0.0)]\n", 334 | "\n", 335 | "5.png\n", 336 | "Expected: 5 Guessed: 5\n", 337 | "Scores: [(0,-7696.7534),(1,-1373.3462),(2,-6049.2095),(3,-311.19458),(4,-5299.1133),(5,5255.896),(6,-1519.8872),(7,-2771.51),(8,-2157.697),(9,-1522.519)]\n", 338 | "Probabilities: [(0,0.0),(1,0.0),(2,0.0),(3,0.0),(4,0.0),(5,1.0),(6,0.0),(7,0.0),(8,0.0),(9,0.0)]\n", 339 | "\n", 340 | "6.png\n", 341 | "Expected: 6 Guessed: 6\n", 342 | "Scores: [(0,-1738.4187),(1,-2735.6375),(2,-3410.437),(3,-3023.2148),(4,-2893.8752),(5,-354.15915),(6,3736.5344),(7,-4364.4053),(8,-287.73703),(9,-3295.0278)]\n", 343 | "Probabilities: [(0,0.0),(1,0.0),(2,0.0),(3,0.0),(4,0.0),(5,0.0),(6,1.0),(7,0.0),(8,0.0),(9,0.0)]\n", 344 | "\n", 345 | "7.png\n", 346 | "Expected: 7 Guessed: 7\n", 347 | "Scores: [(0,-2490.2102),(1,-631.0392),(2,1605.6711),(3,-832.4991),(4,-4712.144),(5,-6243.2437),(6,-5990.4595),(7,4322.048),(8,-4437.4653),(9,-2763.0308)]\n", 348 | "Probabilities: [(0,0.0),(1,0.0),(2,0.0),(3,0.0),(4,0.0),(5,0.0),(6,0.0),(7,1.0),(8,0.0),(9,0.0)]\n", 349 | "\n", 350 | "8.png\n", 351 | "Expected: 8 Guessed: 8\n", 352 | "Scores: [(0,-1720.1462),(1,-1791.0916),(2,-1960.5614),(3,17.458065),(4,-4540.445),(5,-2249.1724),(6,-2721.0393),(7,-4319.779),(8,3329.3308),(9,-2291.6396)]\n", 353 | "Probabilities: [(0,0.0),(1,0.0),(2,0.0),(3,0.0),(4,0.0),(5,0.0),(6,0.0),(7,0.0),(8,1.0),(9,0.0)]\n", 354 | "\n", 355 | "9.png\n", 356 | "Expected: 9 Guessed: 9\n", 357 | "Scores: [(0,-3192.5137),(1,-2531.5234),(2,-2891.6475),(3,-388.6603),(4,-1259.5007),(5,-921.98096),(6,-3483.3315),(7,1143.936),(8,-2964.6362),(9,1423.8708)]\n", 358 | "Probabilities: [(0,0.0),(1,0.0),(2,0.0),(3,0.0),(4,0.0),(5,0.0),(6,0.0),(7,0.0),(8,0.0),(9,1.0)]" 359 | ] 360 | }, 361 | "metadata": {}, 362 | "output_type": "display_data" 363 | } 364 | ], 365 | "source": [ 366 | "forM_ (zip images vs) $ \\((img,expected,fname), scores) -> do\n", 367 | " let guessed = VG.maxIndex scores\n", 
368 | " putStrLn fname\n", 369 | " printf \"Expected: %d Guessed: %d\\n\" expected guessed\n", 370 | " putStrLn $ \"Scores: \" ++ show (zip [0..] (VG.toList scores))\n", 371 | " putStrLn $ \"Probabilities: \" ++ show (zip [0..] (VG.toList (softmax scores)))\n", 372 | " putStrLn \"\"" 373 | ] 374 | } 375 | ], 376 | "metadata": { 377 | "kernelspec": { 378 | "display_name": "Haskell", 379 | "language": "haskell", 380 | "name": "haskell" 381 | }, 382 | "language_info": { 383 | "codemirror_mode": "ihaskell", 384 | "file_extension": ".hs", 385 | "name": "haskell", 386 | "pygments_lexer": "Haskell", 387 | "version": "8.2.2" 388 | } 389 | }, 390 | "nbformat": 4, 391 | "nbformat_minor": 2 392 | } 393 | -------------------------------------------------------------------------------- /retrieve_data.hs: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env stack 2 | -- stack runghc --package conduit --package conduit-extra --package http-conduit 3 | {-# OPTIONS_GHC -Wall #-} 4 | import Control.Monad.Trans.Resource 5 | import Data.Conduit.Binary (sinkFile) 6 | import Network.HTTP.Simple 7 | 8 | main :: IO () 9 | main = do 10 | let downloadTo :: String -> FilePath -> IO () 11 | downloadTo req fname = do 12 | putStrLn $ req ++ " -> " ++ fname 13 | request <- parseRequest req 14 | runResourceT $ httpSink request $ \_ -> sinkFile fname 15 | downloadTo "https://preferredjp.box.com/shared/static/o2xip23e3f0knwc5ve78oderuglkf2wt.onnx" "./data/vgg16.onnx" 16 | downloadTo "https://raw.githubusercontent.com/HoldenCaulfieldRye/caffe/master/data/ilsvrc12/synset_words.txt" "./data/synset_words.txt" 17 | downloadTo "https://upload.wikimedia.org/wikipedia/commons/5/54/Light_sussex_hen.jpg" "./data/Light_sussex_hen.jpg" 18 | -------------------------------------------------------------------------------- /retrieve_data.sh: -------------------------------------------------------------------------------- 1 | wget 
https://www.dropbox.com/s/bjfn9kehukpbmcm/VGG16.onnx?dl=1 -O ./data/VGG16.onnx 2 | wget https://raw.githubusercontent.com/HoldenCaulfieldRye/caffe/master/data/ilsvrc12/synset_words.txt -O ./data/synset_words.txt 3 | wget https://upload.wikimedia.org/wikipedia/commons/5/54/Light_sussex_hen.jpg -O ./data/Light_sussex_hen.jpg 4 | -------------------------------------------------------------------------------- /src/Menoh.hs: -------------------------------------------------------------------------------- 1 | {-# OPTIONS_GHC -Wall #-} 2 | {-# LANGUAGE CPP #-} 3 | {-# LANGUAGE DeriveDataTypeable #-} 4 | {-# LANGUAGE FlexibleContexts #-} 5 | {-# LANGUAGE FlexibleInstances #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | ----------------------------------------------------------------------------- 8 | -- | 9 | -- Module : Menoh 10 | -- Copyright : Copyright (c) 2018 Preferred Networks, Inc. 11 | -- License : MIT (see the file LICENSE) 12 | -- 13 | -- Maintainer : Masahiro Sakai 14 | -- Stability : experimental 15 | -- Portability : non-portable 16 | -- 17 | -- Haskell binding for /Menoh/ DNN inference library. 18 | -- 19 | -- = Basic usage 20 | -- 21 | -- 1. Load computation graph from ONNX file using 'makeModelDataFromONNXFile'. 22 | -- 2. Specify input variable type/dimentions (in particular batch size) and 23 | -- which output variables you want to retrieve. This can be done by 24 | -- constructing 'VariableProfileTable' using 'makeVariableProfileTable'. 25 | -- 3. Optimize 'ModelData' with respect to your 'VariableProfileTable' by using 26 | -- 'optimizeModelData'. 27 | -- 4. Construct a 'Model' using 'makeModel' or 'makeModelWithConfig'. 28 | -- If you want to use custom buffers instead of internally allocated ones, 29 | -- You need to use more low level 'ModelBuilder'. 30 | -- 5. Load input data. This can be done conveniently using 'writeBuffer'. 31 | -- There are also more low-level API such as 'unsafeGetBuffer' and 'withBuffer'. 32 | -- 6. 
Run inference using 'run'. 33 | -- 7. Retrieve the result data. This can be done conveniently using 'readBuffer'. 34 | -- 35 | -- = Note on thread safety 36 | -- 37 | -- TL;DR: If you want to use Menoh from multiple haskell threads, you need to 38 | -- use /threaded/ RTS by supplying @-threaded@ option to GHC. 39 | -- 40 | -- Menoh uses thread local storage (TLS) for storing error information, and 41 | -- the only way to use TLS safely is to use in /bound/ threads 42 | -- (see "Control.Concurrent#boundthreads"). 43 | -- 44 | -- * In /threaded RTS/ (i.e. 'rtsSupportsBoundThreads' is True), this module 45 | -- runs computation in bound threads by using 'runInBoundThread'. (If the 46 | -- calling thread is not bound, 'runInBoundThread' create a bound thread 47 | -- temporarily and run the computation inside it). 48 | -- 49 | -- * In /non-threaded RTS/, this module /does not/ use 'runInBoundThread' and 50 | -- is therefore unsafe to use from multiple haskell threads. Using non-threaded 51 | -- RTS is allowed for the sake of convenience (e.g. running in GHCi) despite 52 | -- its unsafety. 53 | -- 54 | ----------------------------------------------------------------------------- 55 | 56 | #include "MachDeps.h" 57 | #include 58 | 59 | #define MIN_VERSION_libmenoh(major,minor,patch) (\ 60 | (major) < MENOH_MAJOR_VERSION || \ 61 | (major) == MENOH_MAJOR_VERSION && (minor) < MENOH_MINOR_VERSION || \ 62 | (major) == MENOH_MAJOR_VERSION && (minor) == MENOH_MINOR_VERSION && (patch) <= MENOH_PATCH_VERSION) 63 | 64 | module Menoh 65 | ( 66 | -- * Basic data types 67 | Dims 68 | , DType (..) 69 | , Error (..) 70 | 71 | -- * ModelData type 72 | , ModelData (..) 
73 | , makeModelDataFromONNXFile 74 | , makeModelDataFromONNX 75 | , makeModelDataFromONNXByteString 76 | , optimizeModelData 77 | -- ** Manual model data construction API 78 | , makeModelData 79 | , addParameterFromPtr 80 | , addNewNode 81 | , addInputNameToCurrentNode 82 | , addOutputNameToCurrentNode 83 | , AttributeType (..) 84 | , addAttribute 85 | 86 | -- * VariableProfileTable 87 | , VariableProfileTable (..) 88 | , makeVariableProfileTable 89 | , vptGetDType 90 | , vptGetDims 91 | 92 | -- * Model type 93 | , Model (..) 94 | , makeModel 95 | , makeModelWithConfig 96 | , run 97 | , getDType 98 | , getDims 99 | -- ** Accessors for buffers 100 | , ToBuffer (..) 101 | , FromBuffer (..) 102 | , writeBuffer 103 | , readBuffer 104 | -- ** Low-level accessors for buffers 105 | , unsafeGetBuffer 106 | , withBuffer 107 | -- ** Deprecated accessors for buffers 108 | , HasDType (..) 109 | , writeBufferFromVector 110 | , writeBufferFromStorableVector 111 | , readBufferToVector 112 | , readBufferToStorableVector 113 | 114 | -- * Misc 115 | , version 116 | , bindingVersion 117 | 118 | -- * Low-level API 119 | 120 | -- ** Builder for 'VariableProfileTable' 121 | , VariableProfileTableBuilder (..) 122 | , makeVariableProfileTableBuilder 123 | , addInputProfileDims2 124 | , addInputProfileDims4 125 | , addOutputName 126 | , addOutputProfile 127 | , AddOutput (..) 128 | , buildVariableProfileTable 129 | 130 | -- ** Builder for 'Model' 131 | , ModelBuilder (..) 
132 | , makeModelBuilder 133 | , attachExternalBuffer 134 | , buildModel 135 | , buildModelWithConfig 136 | ) where 137 | 138 | import Control.Applicative 139 | import Control.Concurrent 140 | import Control.Monad 141 | import Control.Monad.Trans.Control (MonadBaseControl, liftBaseOp) 142 | import Control.Monad.IO.Class 143 | import Control.Exception 144 | import qualified Data.Aeson as J 145 | import qualified Data.ByteString as BS 146 | import qualified Data.ByteString.Lazy as BL 147 | import Data.Proxy 148 | import Data.Typeable 149 | import qualified Data.Vector as V 150 | import qualified Data.Vector.Generic as VG 151 | import qualified Data.Vector.Storable as VS 152 | import qualified Data.Vector.Storable.Mutable as VSM 153 | import qualified Data.Vector.Unboxed as VU 154 | import Data.IntMap (IntMap) 155 | import qualified Data.IntMap as IntMap 156 | import Data.Version 157 | import Foreign 158 | import Foreign.C 159 | 160 | import qualified Menoh.Base as Base 161 | import qualified Paths_menoh 162 | 163 | -- ------------------------------------------------------------------------ 164 | 165 | -- | Functions in this module can throw this exception type. 
--
-- Each constructor carries the error message that Menoh reported for the
-- failing call. Menoh error codes without a dedicated constructor are
-- reported as 'ErrorUnknownError' with the numeric code appended.
data Error
  = ErrorStdError String
  | ErrorUnknownError String
  | ErrorInvalidFilename String
  | ErrorONNXParseError String
  | ErrorInvalidDType String
  | ErrorInvalidAttributeType String
  | ErrorUnsupportedOperatorAttribute String
  | ErrorDimensionMismatch String
  | ErrorVariableNotFound String
  | ErrorIndexOutOfRange String
  | ErrorJSONParseError String
  | ErrorInvalidBackendName String
  | ErrorUnsupportedOperator String
  | ErrorFailedToConfigureOperator String
  | ErrorBackendError String
  | ErrorSameNamedVariableAlreadyExist String
  | UnsupportedInputDims String
  | SameNamedParameterAlreadyExist String
  | SameNamedAttributeAlreadyExist String
  | InvalidBackendConfigError String
  | InputNotFoundError String
  | OutputNotFoundError String
  deriving (Eq, Ord, Show, Read, Typeable)

instance Exception Error

-- Run a raw Menoh C call and turn a non-success error code into an 'Error'
-- exception.
--
-- The call and the subsequent read of the last error message happen inside
-- 'runInBoundThread'' so that both execute on the same OS thread: Menoh keeps
-- its error information in thread local storage.
runMenoh :: IO Base.MenohErrorCode -> IO ()
runMenoh m = runInBoundThread' $ do
  e <- m
  unless (e == Base.menohErrorCodeSuccess) $ do
    s <- peekCString =<< Base.menoh_get_last_error_message
    case IntMap.lookup (fromIntegral e) table of
      Just ex -> throwIO $ ex s
      Nothing ->
        -- Unknown code: keep the message and append the raw code, separated
        -- by a space so the two parts stay readable.
        throwIO $ ErrorUnknownError $
          s ++ " (error code: " ++ show (fromIntegral e :: Int) ++ ")"
  where
    -- Mapping from Menoh error codes to the corresponding 'Error' constructor.
    table :: IntMap (String -> Error)
    table = IntMap.fromList $ map (\(k,v) -> (fromIntegral k, v)) $
      [ (Base.menohErrorCodeStdError                      , ErrorStdError)
      , (Base.menohErrorCodeUnknownError                  , ErrorUnknownError)
      , (Base.menohErrorCodeInvalidFilename               , ErrorInvalidFilename)
      , (Base.menohErrorCodeOnnxParseError                , ErrorONNXParseError)
      , (Base.menohErrorCodeInvalidDtype                  , ErrorInvalidDType)
      , (Base.menohErrorCodeInvalidAttributeType          , ErrorInvalidAttributeType)
      , (Base.menohErrorCodeUnsupportedOperatorAttribute  , ErrorUnsupportedOperatorAttribute)
      , (Base.menohErrorCodeDimensionMismatch             , ErrorDimensionMismatch)
      , (Base.menohErrorCodeVariableNotFound              , ErrorVariableNotFound)
      , (Base.menohErrorCodeIndexOutOfRange               , ErrorIndexOutOfRange)
      , (Base.menohErrorCodeJsonParseError                , ErrorJSONParseError)
      , (Base.menohErrorCodeInvalidBackendName            , ErrorInvalidBackendName)
      , (Base.menohErrorCodeUnsupportedOperator           , ErrorUnsupportedOperator)
      , (Base.menohErrorCodeFailedToConfigureOperator     , ErrorFailedToConfigureOperator)
      , (Base.menohErrorCodeBackendError                  , ErrorBackendError)
      , (Base.menohErrorCodeSameNamedVariableAlreadyExist , ErrorSameNamedVariableAlreadyExist)
      , (Base.menohErrorCodeUnsupportedInputDims          , UnsupportedInputDims)
      , (Base.menohErrorCodeSameNamedParameterAlreadyExist, SameNamedParameterAlreadyExist)
      , (Base.menohErrorCodeSameNamedAttributeAlreadyExist, SameNamedAttributeAlreadyExist)
      , (Base.menohErrorCodeInvalidBackendConfigError     , InvalidBackendConfigError)
      , (Base.menohErrorCodeInputNotFoundError            , InputNotFoundError)
      , (Base.menohErrorCodeOutputNotFoundError           , OutputNotFoundError)
      ]

-- Like 'runInBoundThread', but in the non-threaded RTS (where bound threads
-- are not supported) the action is simply run in the current thread.
runInBoundThread' :: IO a -> IO a
runInBoundThread' action
  | rtsSupportsBoundThreads = runInBoundThread action
  | otherwise = action

-- ------------------------------------------------------------------------

-- | Data type of array elements
data DType
  = DTypeFloat  -- ^ single precision floating point number
  | DTypeUnknown !Base.MenohDType  -- ^ types that this binding is unaware of
  deriving (Eq, Ord, Show, Read)

-- Conversion to and from the raw dtype code used by the C API. Any code other
-- than the float code is preserved inside 'DTypeUnknown'.
instance Enum DType where
  toEnum x
    | x == fromIntegral Base.menohDtypeFloat = DTypeFloat
    | otherwise = DTypeUnknown (fromIntegral x)

  fromEnum DTypeFloat = fromIntegral Base.menohDtypeFloat
  fromEnum (DTypeUnknown i) = fromIntegral i

-- Size in bytes of a single element of the given 'DType'.
-- Calls 'error' for 'DTypeUnknown', whose element size is not known here.
dtypeSize :: DType -> Int
dtypeSize DTypeFloat = sizeOf (undefined :: CFloat)
dtypeSize (DTypeUnknown _) = error "Menoh.dtypeSize: unknown DType"

{-# DEPRECATED HasDType "use FromBuffer/ToBuffer instead" #-}
-- | Haskell types that have associated 'DType' type code.
class Storable a => HasDType a where
  dtypeOf :: Proxy a -> DType

instance HasDType CFloat where
  dtypeOf _ = DTypeFloat

-- Haskell 'Float' can only stand in for a C float when both have the same size
-- (sizes come from MachDeps.h).
#if SIZEOF_HSFLOAT == SIZEOF_FLOAT

instance HasDType Float where
  dtypeOf _ = DTypeFloat

#endif

-- ------------------------------------------------------------------------

-- | Dimensions of array
type Dims = [Int]

-- ------------------------------------------------------------------------

-- | @ModelData@ contains model parameters and computation graph structure.
newtype ModelData = ModelData (ForeignPtr Base.MenohModelData)

{-# DEPRECATED makeModelDataFromONNX "use makeModelDataFromONNXFile instead" #-}
-- | Load an ONNX file and make a 'ModelData'.
makeModelDataFromONNX :: MonadIO m => FilePath -> m ModelData
makeModelDataFromONNX = makeModelDataFromONNXFile

-- | Load an ONNX file and make a 'ModelData'.
--
-- The returned handle is released by Menoh's deleter when it is garbage
-- collected.
makeModelDataFromONNXFile :: MonadIO m => FilePath -> m ModelData
makeModelDataFromONNXFile fpath = liftIO $ withCString fpath $ \fpath' -> alloca $ \ret -> do
  runMenoh $ Base.menoh_make_model_data_from_onnx fpath' ret
  liftM ModelData $ newForeignPtr Base.menoh_delete_model_data_funptr =<< peek ret

-- | Make a 'ModelData' from ONNX data held in an in-memory 'BS.ByteString'.
makeModelDataFromONNXByteString :: MonadIO m => BS.ByteString -> m ModelData
makeModelDataFromONNXByteString bytes = liftIO $
  BS.useAsCStringLen bytes $ \(buf, bufLen) ->
    alloca $ \handleOut -> do
      runMenoh $
        Base.menoh_make_model_data_from_onnx_data_on_memory buf (fromIntegral bufLen) handleOut
      rawHandle <- peek handleOut
      -- Hand ownership of the raw handle to a finalizer-backed ForeignPtr.
      ModelData <$> newForeignPtr Base.menoh_delete_model_data_funptr rawHandle

-- | Optimize a 'ModelData' with respect to the given 'VariableProfileTable'.
--
-- The given 'ModelData' is modified in place; nothing is returned.
optimizeModelData :: MonadIO m => ModelData -> VariableProfileTable -> m ()
optimizeModelData (ModelData modelFp) (VariableProfileTable vptFp) =
  liftIO . withForeignPtr modelFp $ \modelPtr ->
    withForeignPtr vptFp $ \vptPtr ->
      runMenoh (Base.menoh_model_data_optimize modelPtr vptPtr)

-- | Make an empty 'ModelData', to be populated with the manual construction
-- API ('addParameterFromPtr', 'addNewNode', ...).
makeModelData :: MonadIO m => m ModelData
makeModelData = liftIO . alloca $ \handleOut -> do
  runMenoh (Base.menoh_make_model_data handleOut)
  rawHandle <- peek handleOut
  ModelData <$> newForeignPtr Base.menoh_delete_model_data_funptr rawHandle

-- | Add a new parameter to 'ModelData'.
--
-- This API is tentative and will be changed in the future.
--
-- Adding a parameter whose name is already registered is not allowed and
-- throws an error.
addParameterFromPtr :: MonadIO m => ModelData -> String -> DType -> Dims -> Ptr a -> m ()
addParameterFromPtr (ModelData modelFp) name dtype dims buf = liftIO $
  withForeignPtr modelFp $ \modelPtr ->
    withCString name $ \cname ->
      -- The dims are marshalled as a C array together with their count.
      withArrayLen (map fromIntegral dims) $ \rank cdims ->
        runMenoh $
          Base.menoh_model_data_add_parameter
            modelPtr cname (fromIntegral (fromEnum dtype)) (fromIntegral rank) cdims buf

-- | Add a new node to 'ModelData'.
--
-- The node added last is the one that 'addInputNameToCurrentNode',
-- 'addOutputNameToCurrentNode' and 'addAttribute' operate on.
addNewNode :: MonadIO m => ModelData -> String -> m ()
addNewNode (ModelData modelFp) nodeName = liftIO $
  withForeignPtr modelFp $ \modelPtr ->
    withCString nodeName $ \cNodeName ->
      runMenoh (Base.menoh_model_data_add_new_node modelPtr cNodeName)

-- | Add a new input name to the most recently added node in 'ModelData'.
addInputNameToCurrentNode :: MonadIO m => ModelData -> String -> m ()
addInputNameToCurrentNode (ModelData modelFp) inputName = liftIO $
  withForeignPtr modelFp $ \modelPtr ->
    withCString inputName $ \cInputName ->
      runMenoh (Base.menoh_model_data_add_input_name_to_current_node modelPtr cInputName)

-- | Add a new output name to the most recently added node in 'ModelData'.
addOutputNameToCurrentNode :: MonadIO m => ModelData -> String -> m ()
addOutputNameToCurrentNode (ModelData modelFp) outputName = liftIO $
  withForeignPtr modelFp $ \modelPtr ->
    withCString outputName $ \cOutputName ->
      runMenoh (Base.menoh_model_data_add_output_name_to_current_node modelPtr cOutputName)

-- | A class of types that can be added to nodes using 'addAttribute'.
class AttributeType value where
  -- | Attach @value@ as an attribute, named by the given C string, to the most
  -- recently added node of the raw model data handle.
  basicAddAttribute :: Ptr Base.MenohModelData -> CString -> value -> IO ()

instance AttributeType Int where
  basicAddAttribute modelPtr cname x =
    runMenoh (Base.menoh_model_data_add_attribute_int_to_current_node modelPtr cname (fromIntegral x))

instance AttributeType Float where
  basicAddAttribute modelPtr cname x =
    runMenoh (Base.menoh_model_data_add_attribute_float_to_current_node modelPtr cname (realToFrac x))

-- List attributes are marshalled as a C array plus its element count.
instance AttributeType [Int] where
  basicAddAttribute modelPtr cname xs =
    withArrayLen (fromIntegral <$> xs) $ \len arr ->
      runMenoh (Base.menoh_model_data_add_attribute_ints_to_current_node modelPtr cname (fromIntegral len) arr)

instance AttributeType [Float] where
  basicAddAttribute modelPtr cname xs =
    withArrayLen (realToFrac <$> xs) $ \len arr ->
      runMenoh (Base.menoh_model_data_add_attribute_floats_to_current_node modelPtr cname (fromIntegral len) arr)

-- | Add a new attribute to the most recently added node in the 'ModelData'.
addAttribute :: (AttributeType value, MonadIO m) => ModelData -> String -> value -> m ()
addAttribute (ModelData modelFp) name value = liftIO $
  withForeignPtr modelFp $ \modelPtr ->
    withCString name $ \cname ->
      basicAddAttribute modelPtr cname value

-- ------------------------------------------------------------------------

-- | Builder for creation of 'VariableProfileTable'.
newtype VariableProfileTableBuilder
  = VariableProfileTableBuilder (ForeignPtr Base.MenohVariableProfileTableBuilder)

-- | Factory function for 'VariableProfileTableBuilder'.
makeVariableProfileTableBuilder :: MonadIO m => m VariableProfileTableBuilder
makeVariableProfileTableBuilder = liftIO . alloca $ \handleOut -> do
  runMenoh (Base.menoh_make_variable_profile_table_builder handleOut)
  rawHandle <- peek handleOut
  VariableProfileTableBuilder <$>
    newForeignPtr Base.menoh_delete_variable_profile_table_builder_funptr rawHandle

-- | Add an input profile of arbitrary rank: the variable name, its 'DType'
-- and its full 'Dims' (e.g. @[batch, channel, height, width]@).
--
-- NOTE(review): the DEPRECATED pragmas of 'addInputProfileDims2' and
-- 'addInputProfileDims4' direct users here, but this function is not listed
-- in the module's export list; consider exporting it.
addInputProfileDims :: MonadIO m => VariableProfileTableBuilder -> String -> DType -> Dims -> m ()
addInputProfileDims (VariableProfileTableBuilder builderFp) name dtype dims = liftIO $
  withForeignPtr builderFp $ \builderPtr ->
    withCString name $ \cname ->
      withArrayLen (map fromIntegral dims) $ \rank cdims ->
        runMenoh $
          Base.menoh_variable_profile_table_builder_add_input_profile
            builderPtr cname (fromIntegral (fromEnum dtype)) (fromIntegral rank) cdims

-- | Add 2D input profile.
--
-- Input profile contains name, dtype and dims @(num, size)@.
-- This 2D input is conventional batched 1D inputs.
{-# DEPRECATED addInputProfileDims2 "use addInputProfileDims instead" #-}
addInputProfileDims2
  :: MonadIO m
  => VariableProfileTableBuilder
  -> String
  -> DType
  -> (Int, Int) -- ^ (num, size)
  -> m ()
addInputProfileDims2 (VariableProfileTableBuilder builderFp) name dtype (num, size) =
  liftIO . withForeignPtr builderFp $ \builderPtr ->
    withCString name $ \cname ->
      runMenoh $
        Base.menoh_variable_profile_table_builder_add_input_profile_dims_2
          builderPtr cname (fromIntegral (fromEnum dtype))
          (fromIntegral num) (fromIntegral size)

-- | Add a 4D input profile.
--
-- The profile consists of the variable name, its dtype and dims
-- @(num, channel, height, width)@, i.e. a batch of 3D
-- @(channel, height, width)@ image inputs.
{-# DEPRECATED addInputProfileDims4 "use addInputProfileDims instead" #-}
addInputProfileDims4
  :: MonadIO m
  => VariableProfileTableBuilder
  -> String
  -> DType
  -> (Int, Int, Int, Int) -- ^ (num, channel, height, width)
  -> m ()
addInputProfileDims4 (VariableProfileTableBuilder builderFp) name dtype (num, channel, height, width) =
  liftIO . withForeignPtr builderFp $ \builderPtr ->
    withCString name $ \cname ->
      runMenoh $
        Base.menoh_variable_profile_table_builder_add_input_profile_dims_4
          builderPtr cname (fromIntegral (fromEnum dtype))
          (fromIntegral num) (fromIntegral channel) (fromIntegral height) (fromIntegral width)

-- | Register a variable name as a required output.
--
-- Only the name is given: the output's 'Dims' and 'DType' are calculated
-- automatically, so you don't need to specify them explicitly.
addOutputName :: MonadIO m => VariableProfileTableBuilder -> String -> m ()
addOutputName (VariableProfileTableBuilder builderFp) name =
  liftIO . withForeignPtr builderFp $ \builderPtr ->
    withCString name $ \cname ->
      runMenoh (Base.menoh_variable_profile_table_builder_add_output_name builderPtr cname)

{-# DEPRECATED addOutputProfile "use addOutputName instead" #-}
-- | Add output profile
--
-- The output profile contains a name and a dtype. Its 'Dims' are calculated
-- automatically, so you don't need to specify them explicitly.
addOutputProfile :: MonadIO m => VariableProfileTableBuilder -> String -> DType -> m ()
addOutputProfile (VariableProfileTableBuilder builderFp) name dtype =
  liftIO . withForeignPtr builderFp $ \builderPtr ->
    withCString name $ \cname ->
      runMenoh $
        Base.menoh_variable_profile_table_builder_add_output_profile
          builderPtr cname (fromIntegral (fromEnum dtype))

-- | Type class abstracting over 'addOutputProfile' and 'addOutputName'.
class AddOutput a where
  -- | Register the given output description on the builder.
  addOutput :: VariableProfileTableBuilder -> a -> IO ()

-- A bare output variable name.
instance AddOutput String where
  addOutput = addOutputName

-- A name paired with a 'DType', kept for backward compatibility. The 'DType'
-- is ignored: only the name is registered, via 'addOutputName'.
instance AddOutput (String, DType) where
  addOutput builder (name, _dtype) = addOutputName builder name

-- | Factory function for 'VariableProfileTable'.
buildVariableProfileTable
  :: MonadIO m
  => VariableProfileTableBuilder
  -> ModelData
  -> m VariableProfileTable
buildVariableProfileTable (VariableProfileTableBuilder b) (ModelData m) = liftIO $
  withForeignPtr b $ \b' -> withForeignPtr m $ \m' -> alloca $ \ret -> do
    runMenoh $ Base.menoh_build_variable_profile_table b' m' ret
    -- The table handle is released by Menoh's deleter when garbage collected.
    VariableProfileTable <$> (newForeignPtr Base.menoh_delete_variable_profile_table_funptr =<< peek ret)

-- ------------------------------------------------------------------------

-- | @VariableProfileTable@ contains information about the dtype and dims of
-- variables.
--
-- Users can access the dtype and dims via 'vptGetDType' and 'vptGetDims'.
newtype VariableProfileTable
  = VariableProfileTable (ForeignPtr Base.MenohVariableProfileTable)

-- | Convenient function for constructing 'VariableProfileTable'.
--
-- If you need finer control, you can use 'VariableProfileTableBuilder'.
makeVariableProfileTable
  :: (AddOutput a, MonadIO m)
  => [(String, DType, Dims)] -- ^ input names with dtypes and dims
  -> [a] -- ^ required outputs (each a 'String' or a @('String', 'DType')@)
  -> ModelData -- ^ model data
  -> m VariableProfileTable
makeVariableProfileTable input_name_and_dims_pair_list required_output_name_list model_data = liftIO $ runInBoundThread' $ do
  b <- makeVariableProfileTableBuilder
  forM_ input_name_and_dims_pair_list $ \(name, dtype, dims) ->
    addInputProfileDims b name dtype dims
  mapM_ (addOutput b) required_output_name_list
  buildVariableProfileTable b model_data

-- | Accessor function for 'VariableProfileTable'.
--
-- Select variable name and get its 'DType'.
vptGetDType :: MonadIO m => VariableProfileTable -> String -> m DType
vptGetDType (VariableProfileTable vpt) name = liftIO $
  withForeignPtr vpt $ \vpt' -> withCString name $ \name' -> alloca $ \ret -> do
    -- Ask the C API for the variable's dtype code (the same way 'getDType'
    -- does for a 'Model'), then map it back to the Haskell 'DType'.
    runMenoh $ Base.menoh_variable_profile_table_get_dtype vpt' name' ret
    (toEnum . fromIntegral) <$> peek ret

-- | Accessor function for 'VariableProfileTable'.
--
-- Select variable name and get its 'Dims'.
vptGetDims :: MonadIO m => VariableProfileTable -> String -> m Dims
vptGetDims (VariableProfileTable vpt) name = liftIO $ runInBoundThread' $
  withForeignPtr vpt $ \vpt' -> withCString name $ \name' -> alloca $ \ret -> do
    -- First query the number of dimensions, then fetch each extent by index,
    -- reusing the same output slot.
    runMenoh $ Base.menoh_variable_profile_table_get_dims_size vpt' name' ret
    size <- peek ret
    forM [0 .. size - 1] $ \i -> do
      runMenoh $ Base.menoh_variable_profile_table_get_dims_at vpt' name' (fromIntegral i) ret
      fromIntegral <$> peek ret

-- ------------------------------------------------------------------------

-- | Helper for creating a 'Model'.
newtype ModelBuilder = ModelBuilder (ForeignPtr Base.MenohModelBuilder)

-- | Factory function for 'ModelBuilder'.
--
-- The builder is created from a 'VariableProfileTable'. Its handle is released
-- by the C deleter registered as the 'ForeignPtr' finalizer.
makeModelBuilder :: MonadIO m => VariableProfileTable -> m ModelBuilder
makeModelBuilder (VariableProfileTable tableFP) =
  liftIO $ withForeignPtr tableFP $ \tablePtr ->
    alloca $ \out -> do
      runMenoh (Base.menoh_make_model_builder tablePtr out)
      rawBuilder <- peek out
      fmap ModelBuilder (newForeignPtr Base.menoh_delete_model_builder_funptr rawBuilder)

-- | Attach a user-allocated buffer to a variable.
--
-- Variables that have no external buffer attached get internal buffers
-- allocated automatically. Either kind of buffer can be obtained later with
-- 'unsafeGetBuffer' or 'withBuffer'.
--
-- NOTE(review): the raw pointer is passed straight to the C API. Presumably
-- the caller must keep the buffer valid for as long as the built model uses
-- it; confirm against the Menoh documentation.
attachExternalBuffer :: MonadIO m => ModelBuilder -> String -> Ptr a -> m ()
attachExternalBuffer (ModelBuilder builderFP) varName bufPtr =
  liftIO $ withForeignPtr builderFP $ \builderPtr ->
    withCString varName $ \cName ->
      runMenoh (Base.menoh_model_builder_attach_external_buffer builderPtr cName bufPtr)

-- | Factory function for 'Model'.
--
-- Equivalent to 'buildModelWithConfigString' with an empty configuration string.
buildModel
  :: MonadIO m
  => ModelBuilder
  -> ModelData
  -> String -- ^ backend name
  -> m Model
buildModel builder modelData backendName =
  liftIO (withCString "" (buildModelWithConfigString builder modelData backendName))

-- | Similar to 'buildModel', but backend specific configuration can be supplied as JSON.
buildModelWithConfig
  :: (MonadIO m, J.ToJSON a)
  => ModelBuilder
  -> ModelData
  -> String -- ^ backend name
  -> a -- ^ backend config
  -> m Model
buildModelWithConfig builder modelData backendName backendConfig = liftIO $ do
  -- Encode the configuration as JSON and hand it over as a NUL-terminated C string.
  let encodedConfig = BL.toStrict (J.encode backendConfig)
  BS.useAsCString encodedConfig (buildModelWithConfigString builder modelData backendName)

-- | Lower-level variant of 'buildModelWithConfig' that takes the backend
-- configuration as an already-encoded C string. 'buildModel' passes an empty
-- string here.
buildModelWithConfigString
  :: MonadIO m
  => ModelBuilder
  -> ModelData
  -> String -- ^ backend name
  -> CString -- ^ backend config
  -> m Model
buildModelWithConfigString (ModelBuilder builderFP) (ModelData modelFP) backendName configPtr =
  liftIO $
    withForeignPtr builderFP $ \builderPtr ->
      withForeignPtr modelFP $ \modelPtr ->
        withCString backendName $ \cBackend ->
          alloca $ \out -> do
            runMenoh (Base.menoh_build_model builderPtr modelPtr cBackend configPtr out)
            rawModel <- peek out
            Model <$> newForeignPtr Base.menoh_delete_model_funptr rawModel

-- ------------------------------------------------------------------------

-- | ONNX model with input/output buffers.
newtype Model = Model (ForeignPtr Base.MenohModel)

-- | Run model inference.
--
-- This function can't be called asynchronously.
run :: MonadIO m => Model -> m ()
run (Model modelFP) =
  liftIO (withForeignPtr modelFP (runMenoh . Base.menoh_model_run))

-- | Get 'DType' of target variable.
getDType :: MonadIO m => Model -> String -> m DType
getDType (Model modelFP) varName =
  liftIO $ withForeignPtr modelFP $ \modelPtr ->
    withCString varName $ \cName ->
      alloca $ \out -> do
        runMenoh (Base.menoh_model_get_variable_dtype modelPtr cName out)
        code <- peek out
        return (toEnum (fromIntegral code))

-- | Get 'Dims' of target variable.
getDims :: MonadIO m => Model -> String -> m Dims
getDims (Model m) name = liftIO $ runInBoundThread' $
  withForeignPtr m $ \m' -> withCString name $ \name' -> alloca $ \ret -> do
    -- First query the number of dimensions, then fetch each extent by index,
    -- reusing the same output slot.
    runMenoh $ Base.menoh_model_get_variable_dims_size m' name' ret
    size <- peek ret
    forM [0 .. size - 1] $ \i -> do
      runMenoh $ Base.menoh_model_get_variable_dims_at m' name' (fromIntegral i) ret
      fromIntegral <$> peek ret

-- ------------------------------------------------------------------------
-- Accessing buffers

-- | Get a buffer handle attached to target variable.
--
-- If the buffer was allocated by the user and attached to the variable with
-- 'attachExternalBuffer', the returned pointer is that same buffer. Otherwise
-- it points to the buffer allocated internally.
--
-- This function is unsafe because it does not prevent the model from being
-- GC'ed, in which case the returned pointer becomes dangling.
--
-- See also 'withBuffer'.
unsafeGetBuffer :: MonadIO m => Model -> String -> m (Ptr a)
unsafeGetBuffer (Model m) name = liftIO $
  withForeignPtr m $ \m' -> withCString name $ \name' -> alloca $ \ret -> do
    runMenoh $ Base.menoh_model_get_variable_buffer_handle m' name' ret
    peek ret

-- | Apply an action to the buffer associated with the specified variable.
--
-- The model is kept alive at least during the whole action, even if it is not
-- used directly inside, so the pointer stays valid for that duration.
-- Note that it is not safe to return the pointer from the action and use it
-- after the action completes.
--
-- See also 'unsafeGetBuffer'.
withBuffer :: forall m r a. (MonadIO m, MonadBaseControl IO m) => Model -> String -> (Ptr a -> m r) -> m r
withBuffer (Model m) name f =
  liftBaseOp (withForeignPtr m) $ \m' -> do
    -- Only the model pointer has to stay alive while @f@ runs. The temporary
    -- C string and the output slot are released before the action starts.
    p <- liftIO $ withCString name $ \name' -> alloca $ \ret -> do
      runMenoh $ Base.menoh_model_get_variable_buffer_handle m' name' ret
      peek ret
    f p

-- | Type that can be written to menoh's buffer.
class ToBuffer a where
  -- | Basic method for implementing the @ToBuffer@ class.
  -- Normal users should use 'writeBuffer' instead.
  basicWriteBuffer :: DType -> Dims -> Ptr () -> a -> IO ()

-- | Type that can be read from menoh's buffer.
class FromBuffer a where
  -- | Basic method for implementing the @FromBuffer@ class.
  -- Normal users should use 'readBuffer' instead.
  basicReadBuffer :: DType -> Dims -> Ptr () -> IO a

-- | Read values from the given model's buffer.
--
-- The variable's 'DType' and 'Dims' are queried from the model and passed,
-- together with the buffer pointer, to 'basicReadBuffer'.
readBuffer :: (FromBuffer a, MonadIO m) => Model -> String -> m a
readBuffer model name = liftIO $ withBuffer model name $ \p -> do
  dtype <- getDType model name
  dims <- getDims model name
  basicReadBuffer dtype dims p

-- | Write values to the given model's buffer.
--
-- The variable's 'DType' and 'Dims' are queried from the model and passed,
-- together with the buffer pointer, to 'basicWriteBuffer'.
writeBuffer :: (ToBuffer a, MonadIO m) => Model -> String -> a -> m ()
writeBuffer model name a = liftIO $ withBuffer model name $ \p -> do
  dtype <- getDType model name
  dims <- getDims model name
  basicWriteBuffer dtype dims p a


-- | Default implementation of 'basicWriteBuffer' for 'VG.Vector' class,
-- for the cases where the 'Storable' representation is compatible with the
-- representation in buffers.
--
-- The first 'DType' is the vector's element dtype and the second is the
-- variable's dtype. Throws 'ErrorInvalidDType' if they differ, and
-- 'ErrorDimensionMismatch' if the vector length is not @product dims@.
basicWriteBufferGenericVectorStorable
  :: forall v a. (VG.Vector v a, Storable a)
  => DType -> DType -> Dims -> Ptr () -> v a -> IO ()
basicWriteBufferGenericVectorStorable dtype0 dtype dims p vec = do
  let n = product dims
      p' = castPtr p
  checkDTypeAndSize "Menoh.basicWriteBufferGenericVectorStorable" (dtype, n) (dtype0, VG.length vec)
  forM_ [0 .. n - 1] $ \i ->
    pokeElemOff p' i (vec VG.! i)

-- | Default implementation of 'basicReadBuffer' for 'VG.Vector' class,
-- for the cases where the 'Storable' representation is compatible with the
-- representation in buffers.
--
-- Throws 'ErrorInvalidDType' if the element dtype (first argument) differs
-- from the variable's dtype (second argument).
basicReadBufferGenericVectorStorable
  :: forall v a. (VG.Vector v a, Storable a)
  => DType -> DType -> Dims -> Ptr () -> IO (v a)
basicReadBufferGenericVectorStorable dtype0 dtype dims p = do
  checkDType "Menoh.basicReadBufferGenericVectorStorable" dtype dtype0
  let n = product dims
      p' = castPtr p
  VG.generateM n $ peekElemOff p'


-- | Default implementation of 'basicWriteBuffer' for 'VS.Vector' class,
-- for the cases where the 'Storable' representation is compatible with the
-- representation in buffers. Copies the whole vector with one 'copyArray'.
basicWriteBufferStorableVector
  :: forall a. (Storable a)
  => DType -> DType -> Dims -> Ptr () -> VS.Vector a -> IO ()
basicWriteBufferStorableVector dtype0 dtype dims p vec = do
  let n = product dims
  checkDTypeAndSize "Menoh.basicWriteBufferStorableVector" (dtype, n) (dtype0, VG.length vec)
  VS.unsafeWith vec $ \src ->
    copyArray (castPtr p) src n

-- | Default implementation of 'basicReadBuffer' for 'VS.Vector' class,
-- for the cases where the 'Storable' representation is compatible with the
-- representation in buffers. Copies the whole buffer into a fresh vector.
basicReadBufferStorableVector
  :: forall a. (Storable a)
  => DType -> DType -> Dims -> Ptr () -> IO (VS.Vector a)
basicReadBufferStorableVector dtype0 dtype dims p = do
  checkDType "Menoh.basicReadBufferStorableVector" dtype dtype0
  let n = product dims
  vec <- VSM.new n
  VSM.unsafeWith vec $ \dst -> copyArray dst (castPtr p) n
  VS.unsafeFreeze vec


instance ToBuffer (V.Vector Float) where
  basicWriteBuffer = basicWriteBufferGenericVectorStorable DTypeFloat
instance FromBuffer (V.Vector Float) where
  basicReadBuffer = basicReadBufferGenericVectorStorable DTypeFloat


instance ToBuffer (VU.Vector Float) where
  basicWriteBuffer = basicWriteBufferGenericVectorStorable DTypeFloat
instance FromBuffer (VU.Vector Float) where
  basicReadBuffer = basicReadBufferGenericVectorStorable DTypeFloat


instance ToBuffer (VS.Vector Float) where
  basicWriteBuffer = basicWriteBufferStorableVector DTypeFloat
instance FromBuffer (VS.Vector Float) where
  basicReadBuffer = basicReadBufferStorableVector DTypeFloat


-- Lists peel off the outermost dimension: the list length must equal it, and
-- each element is written at its own offset using the remaining dims.
instance ToBuffer a => ToBuffer [a] where
  basicWriteBuffer _dtype [] _p _xs =
    throwIO $ ErrorDimensionMismatch $ "ToBuffer{[a]}.basicWriteBuffer: empty dims"
  basicWriteBuffer dtype (dim : dims) p xs = do
    unless (dim == length xs) $ do
      throwIO $ ErrorDimensionMismatch $ "ToBuffer{[a]}.basicWriteBuffer: dimension mismatch"
    -- Byte stride ('plusPtr' offsets are in bytes) of one outermost element.
    let s = product dims * dtypeSize dtype
    forM_ (zip [0,s..] xs) $ \(offset,x) -> do
      basicWriteBuffer dtype dims (p `plusPtr` offset) x

-- Lists peel off the outermost dimension: one element is read per index,
-- each from its own offset using the remaining dims.
instance FromBuffer a => FromBuffer [a] where
  basicReadBuffer _dtype [] _p =
    throwIO $ ErrorDimensionMismatch $ "FromBuffer{[a]}.basicReadBuffer: empty dims"
  basicReadBuffer dtype (dim : dims) p = do
    -- Byte stride ('plusPtr' offsets are in bytes) of one outermost element.
    let s = product dims * dtypeSize dtype
    forM [0..dim-1] $ \i -> do
      basicReadBuffer dtype dims (p `plusPtr` (i*s))


-- Throws 'ErrorInvalidDType' (prefixed with the caller name) when the dtypes differ.
checkDType :: String -> DType -> DType -> IO ()
checkDType name dtype1 dtype2
  | dtype1 /= dtype2 = throwIO $ ErrorInvalidDType $ name ++ ": dtype mismatch"
  | otherwise = return ()

-- Like 'checkDType', and additionally throws 'ErrorDimensionMismatch' when
-- the element counts differ. The dtype is checked first.
checkDTypeAndSize :: String -> (DType,Int) -> (DType,Int) -> IO ()
checkDTypeAndSize name (dtype1,n1) (dtype2,n2) = do
  checkDType name dtype1 dtype2
  unless (n1 == n2) $
    throwIO $ ErrorDimensionMismatch $ name ++ ": dimension mismatch"


{-# DEPRECATED writeBufferFromVector, writeBufferFromStorableVector "Use ToBuffer class and writeBuffer instead" #-}

-- | Copy all elements of a 'VG.Vector' into a model's buffer.
writeBufferFromVector :: forall v a m. (VG.Vector v a, HasDType a, MonadIO m) => Model -> String -> v a -> m ()
writeBufferFromVector model name vec = liftIO $ withBuffer model name $ \p -> do
  dtype <- getDType model name
  dims <- getDims model name
  basicWriteBufferGenericVectorStorable (dtypeOf (Proxy :: Proxy a)) dtype dims p vec

-- | Copy all elements of a @'VS.Vector' a@ into a model's buffer.
writeBufferFromStorableVector :: forall a m. (HasDType a, MonadIO m) => Model -> String -> VS.Vector a -> m ()
writeBufferFromStorableVector model name vec = liftIO $ withBuffer model name $ \p -> do
  dtype <- getDType model name
  dims <- getDims model name
  basicWriteBufferStorableVector (dtypeOf (Proxy :: Proxy a)) dtype dims p vec

{-# DEPRECATED readBufferToVector, readBufferToStorableVector "Use FromBuffer class and readBuffer instead" #-}

-- | Read all elements of a model's buffer and return them as a 'VG.Vector'.
readBufferToVector :: forall v a m. (VG.Vector v a, HasDType a, MonadIO m) => Model -> String -> m (v a)
readBufferToVector model name = liftIO $ withBuffer model name $ \p -> do
  dtype <- getDType model name
  dims <- getDims model name
  basicReadBufferGenericVectorStorable (dtypeOf (Proxy :: Proxy a)) dtype dims p

-- | Read all elements of a model's buffer and return them as a 'VS.Vector'.
readBufferToStorableVector :: forall a m. (HasDType a, MonadIO m) => Model -> String -> m (VS.Vector a)
readBufferToStorableVector model name = liftIO $ withBuffer model name $ \p -> do
  dtype <- getDType model name
  dims <- getDims model name
  basicReadBufferStorableVector (dtypeOf (Proxy :: Proxy a)) dtype dims p

-- ------------------------------------------------------------------------

-- | Convenience function for constructing a 'Model'.
makeModel
  :: MonadIO m
  => VariableProfileTable -- ^ variable profile table
  -> ModelData -- ^ model data
  -> String -- ^ backend name
  -> m Model
makeModel vpt model_data backend_name = liftIO $ do
  b <- makeModelBuilder vpt
  buildModel b model_data backend_name

-- | Similar to 'makeModel' but backend-specific configuration can be supplied.
makeModelWithConfig
  :: (MonadIO m, J.ToJSON a)
  => VariableProfileTable -- ^ variable profile table
  -> ModelData -- ^ model data
  -> String -- ^ backend name
  -> a -- ^ backend config
  -> m Model
makeModelWithConfig table modelData backendName backendConfig =
  liftIO $
    makeModelBuilder table >>= \builder ->
      buildModelWithConfig builder modelData backendName backendConfig

-- ------------------------------------------------------------------------

-- | Version of the Menoh C library, taken from the @MENOH_*_VERSION@ macros
-- at compilation time.
version :: Version
version = mk [Base.menoh_major_version, Base.menoh_minor_version, Base.menoh_patch_version]
  where
#if MIN_VERSION_base(4,8,0)
    mk = makeVersion
#else
    mk branch = Version branch []
#endif

-- | Version of this Haskell binding (not the version of /Menoh/ itself),
-- taken from the Cabal-generated @Paths_menoh@ module.
bindingVersion :: Version
bindingVersion = Paths_menoh.version

--------------------------------------------------------------------------------
/src/Menoh/Base.hsc:
--------------------------------------------------------------------------------
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE ForeignFunctionInterface #-}
-----------------------------------------------------------------------------
-- |
-- Module      : Menoh.Base
-- Copyright   : Copyright (c) 2018 Preferred Networks, Inc.
-- License     : MIT (see the file LICENSE)
--
-- Maintainer  : Masahiro Sakai
-- Stability   : experimental
-- Portability : non-portable
--
-- FFI imports of Menoh library.
--
-- See https://pfnet-research.github.io/menoh/ for details of this API.
--
-----------------------------------------------------------------------------
module Menoh.Base where

import Data.Int
import Foreign
import Foreign.C

#include <menoh/menoh.h>
#include <menoh/version.h>

-- True when the Menoh headers seen at build time are version
-- major.minor.patch or newer.
#define MIN_VERSION_libmenoh(major,minor,patch) (\
  (major) <  MENOH_MAJOR_VERSION || \
  (major) == MENOH_MAJOR_VERSION && (minor) <  MENOH_MINOR_VERSION || \
  (major) == MENOH_MAJOR_VERSION && (minor) == MENOH_MINOR_VERSION && (patch) <= MENOH_PATCH_VERSION)

-- C scalar types mirrored by hsc2hs.
type MenohDType = #type menoh_dtype

type MenohErrorCode = #type menoh_error_code

#enum MenohDType,,menoh_dtype_float

#enum MenohErrorCode,, \
  menoh_error_code_success, \
  menoh_error_code_std_error, \
  menoh_error_code_unknown_error, \
  menoh_error_code_invalid_filename, \
  menoh_error_code_onnx_parse_error, \
  menoh_error_code_invalid_dtype, \
  menoh_error_code_invalid_attribute_type, \
  menoh_error_code_unsupported_operator_attribute, \
  menoh_error_code_dimension_mismatch, \
  menoh_error_code_variable_not_found, \
  menoh_error_code_index_out_of_range, \
  menoh_error_code_json_parse_error, \
  menoh_error_code_invalid_backend_name, \
  menoh_error_code_unsupported_operator, \
  menoh_error_code_failed_to_configure_operator, \
  menoh_error_code_backend_error, \
  menoh_error_code_same_named_variable_already_exist, \
  menoh_error_code_unsupported_input_dims, \
  menoh_error_code_same_named_parameter_already_exist, \
  menoh_error_code_same_named_attribute_already_exist, \
  menoh_error_code_invalid_backend_config_error, \
  menoh_error_code_input_not_found_error, \
  menoh_error_code_output_not_found_error

foreign import ccall unsafe menoh_get_last_error_message
  :: IO CString

-- ---------------------------------------------------------------------------
-- Model data
--
-- Note: calls imported as @safe@ let other Haskell threads run while the C
-- function executes. They are used here for the presumably long-running
-- operations (parsing, optimizing, building, running).

data MenohModelData
type MenohModelDataHandle = Ptr MenohModelData

foreign import ccall safe menoh_make_model_data_from_onnx
  :: CString -> Ptr MenohModelDataHandle -> IO MenohErrorCode

foreign import ccall safe menoh_make_model_data_from_onnx_data_on_memory
  :: Ptr a -> Int32 -> Ptr MenohModelDataHandle -> IO MenohErrorCode

-- Finalizer pointers ("&name") are intended for 'newForeignPtr'.
foreign import ccall "&menoh_delete_model_data" menoh_delete_model_data_funptr
  :: FunPtr (MenohModelDataHandle -> IO ())

-- ---------------------------------------------------------------------------
-- Variable profile table builder

data MenohVariableProfileTableBuilder
type MenohVariableProfileTableBuilderHandle = Ptr MenohVariableProfileTableBuilder

foreign import ccall unsafe menoh_make_variable_profile_table_builder
  :: Ptr MenohVariableProfileTableBuilderHandle -> IO MenohErrorCode

foreign import ccall "&menoh_delete_variable_profile_table_builder"
  menoh_delete_variable_profile_table_builder_funptr
  :: FunPtr (MenohVariableProfileTableBuilderHandle -> IO ())

foreign import ccall unsafe menoh_variable_profile_table_builder_add_input_profile
  :: MenohVariableProfileTableBuilderHandle -> CString -> MenohDType -> Int32 -> Ptr Int32 -> IO MenohErrorCode

foreign import ccall unsafe menoh_variable_profile_table_builder_add_input_profile_dims_2
  :: MenohVariableProfileTableBuilderHandle -> CString -> MenohDType -> Int32 -> Int32 -> IO MenohErrorCode

foreign import ccall unsafe menoh_variable_profile_table_builder_add_input_profile_dims_4
  :: MenohVariableProfileTableBuilderHandle -> CString -> MenohDType -> Int32 -> Int32 -> Int32 -> Int32 -> IO MenohErrorCode

foreign import ccall unsafe menoh_variable_profile_table_builder_add_output_profile
  :: MenohVariableProfileTableBuilderHandle -> CString -> MenohDType -> IO MenohErrorCode

foreign import ccall unsafe menoh_variable_profile_table_builder_add_output_name
  :: MenohVariableProfileTableBuilderHandle -> CString -> IO MenohErrorCode

-- ---------------------------------------------------------------------------
-- Variable profile table

data MenohVariableProfileTable
type MenohVariableProfileTableHandle = Ptr MenohVariableProfileTable

foreign import ccall safe menoh_build_variable_profile_table
  :: MenohVariableProfileTableBuilderHandle -> MenohModelDataHandle
  -> Ptr MenohVariableProfileTableHandle -> IO MenohErrorCode

foreign import ccall "&menoh_delete_variable_profile_table"
  menoh_delete_variable_profile_table_funptr
  :: FunPtr (MenohVariableProfileTableHandle -> IO ())

foreign import ccall unsafe menoh_variable_profile_table_get_dtype
  :: MenohVariableProfileTableHandle -> CString -> Ptr MenohDType -> IO MenohErrorCode

foreign import ccall unsafe menoh_variable_profile_table_get_dims_size
  :: MenohVariableProfileTableHandle -> CString -> Ptr Int32 -> IO MenohErrorCode

foreign import ccall unsafe menoh_variable_profile_table_get_dims_at
  :: MenohVariableProfileTableHandle -> CString -> Int32 -> Ptr Int32 -> IO MenohErrorCode

foreign import ccall safe menoh_model_data_optimize
  :: MenohModelDataHandle -> MenohVariableProfileTableHandle -> IO MenohErrorCode

-- ---------------------------------------------------------------------------
-- Model builder

data MenohModelBuilder
type MenohModelBuilderHandle = Ptr MenohModelBuilder

foreign import ccall unsafe menoh_make_model_builder
  :: MenohVariableProfileTableHandle -> Ptr MenohModelBuilderHandle -> IO MenohErrorCode

foreign import ccall "&menoh_delete_model_builder" menoh_delete_model_builder_funptr
  :: FunPtr (MenohModelBuilderHandle -> IO ())

foreign import ccall unsafe menoh_model_builder_attach_external_buffer
  :: MenohModelBuilderHandle -> CString -> Ptr a -> IO MenohErrorCode

-- ---------------------------------------------------------------------------
-- Model

data MenohModel
type MenohModelHandle = Ptr MenohModel

foreign import ccall safe menoh_build_model
  :: MenohModelBuilderHandle -> MenohModelDataHandle -> CString -> CString
  -> Ptr MenohModelHandle -> IO MenohErrorCode

foreign import ccall "&menoh_delete_model" menoh_delete_model_funptr
  :: FunPtr (MenohModelHandle -> IO ())

foreign import ccall unsafe menoh_model_get_variable_buffer_handle
  :: MenohModelHandle -> CString -> Ptr (Ptr a) -> IO MenohErrorCode

foreign import ccall unsafe menoh_model_get_variable_dtype
  :: MenohModelHandle -> CString -> Ptr MenohDType -> IO MenohErrorCode

foreign import ccall unsafe menoh_model_get_variable_dims_size
  :: MenohModelHandle -> CString -> Ptr Int32 -> IO MenohErrorCode

foreign import ccall unsafe menoh_model_get_variable_dims_at
  :: MenohModelHandle -> CString -> Int32 -> Ptr Int32 -> IO MenohErrorCode

foreign import ccall safe menoh_model_run
  :: MenohModelHandle -> IO MenohErrorCode

-- ---------------------------------------------------------------------------
-- Constructing model data programmatically

foreign import ccall unsafe menoh_make_model_data
  :: Ptr MenohModelDataHandle -> IO MenohErrorCode

foreign import ccall safe menoh_model_data_add_parameter
  :: MenohModelDataHandle -> CString -> MenohDType -> Int32 -> Ptr Int32 -> Ptr a -> IO MenohErrorCode

foreign import ccall unsafe menoh_model_data_add_new_node
  :: MenohModelDataHandle -> CString -> IO MenohErrorCode

foreign import ccall unsafe menoh_model_data_add_input_name_to_current_node
  :: MenohModelDataHandle -> CString -> IO MenohErrorCode

foreign import ccall unsafe menoh_model_data_add_output_name_to_current_node
  :: MenohModelDataHandle -> CString -> IO MenohErrorCode

foreign import ccall unsafe menoh_model_data_add_attribute_int_to_current_node
  :: MenohModelDataHandle -> CString -> Int32 -> IO MenohErrorCode

foreign import ccall unsafe menoh_model_data_add_attribute_float_to_current_node
  :: MenohModelDataHandle -> CString -> CFloat -> IO MenohErrorCode

foreign import ccall unsafe menoh_model_data_add_attribute_ints_to_current_node
  :: MenohModelDataHandle -> CString -> Int32 -> Ptr CInt -> IO MenohErrorCode

foreign import ccall unsafe menoh_model_data_add_attribute_floats_to_current_node
  :: MenohModelDataHandle -> CString -> Int32 -> Ptr CFloat -> IO MenohErrorCode

-- ---------------------------------------------------------------------------
-- Version of the Menoh headers used at build time

menoh_major_version :: Int
menoh_major_version = #const MENOH_MAJOR_VERSION

menoh_minor_version :: Int
menoh_minor_version = #const MENOH_MINOR_VERSION

menoh_patch_version :: Int
menoh_patch_version = #const MENOH_PATCH_VERSION

menoh_version_string :: String
menoh_version_string = #const_str MENOH_VERSION_STRING

--------------------------------------------------------------------------------
/stack-ghc-7.10.yaml:
--------------------------------------------------------------------------------
# This file was automatically generated by 'stack init'
#
# Some commonly used options have been documented as comments in this file.
# For advanced use and comprehensive documentation of the format, please see:
# https://docs.haskellstack.org/en/stable/yaml_configuration/

# Resolver to choose a 'specific' stackage snapshot or a compiler version.
# A snapshot resolver dictates the compiler version and the set of packages
# to be used for project dependencies. For example:
#
# resolver: lts-3.5
# resolver: nightly-2015-09-21
# resolver: ghc-7.10.2
# resolver: ghcjs-0.1.0_ghc-7.10.2
# resolver:
#  name: custom-snapshot
#  location: "./custom-snapshot.yaml"
resolver: lts-6.35

# User packages to be built.
# Various formats can be used as shown in the example below.
22 | # 23 | # packages: 24 | # - some-directory 25 | # - https://example.com/foo/bar/baz-0.0.2.tar.gz 26 | # - location: 27 | # git: https://github.com/commercialhaskell/stack.git 28 | # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a 29 | # - location: https://github.com/commercialhaskell/stack/commit/e7b331f14bcffb8367cd58fbfc8b40ec7642100a 30 | # extra-dep: true 31 | # subdirs: 32 | # - auto-update 33 | # - wai 34 | # 35 | # A package marked 'extra-dep: true' will only be built if demanded by a 36 | # non-dependency (i.e. a user package), and its test suites and benchmarks 37 | # will not be run. This is useful for tweaking upstream packages. 38 | packages: 39 | - . 40 | 41 | # Dependency packages to be pulled from upstream that are not in the resolver 42 | # (e.g., acme-missiles-0.3) 43 | extra-deps: [] 44 | 45 | # Override default flag values for local packages and extra-deps 46 | # flags: {} 47 | 48 | # Extra package databases containing global packages 49 | # extra-package-dbs: [] 50 | 51 | # Control whether we use the GHC we find on the path 52 | # system-ghc: true 53 | # 54 | # Require a specific version of stack, using version ranges 55 | # require-stack-version: -any # Default 56 | # require-stack-version: ">=1.6" 57 | # 58 | # Override the architecture used by stack, especially useful on Windows 59 | # arch: i386 60 | # arch: x86_64 61 | # 62 | # Extra directories used by stack for building 63 | # extra-include-dirs: [/path/to/dir] 64 | # extra-lib-dirs: [/path/to/dir] 65 | # 66 | # Allow a newer minor version of GHC than the snapshot specifies 67 | # compiler-check: newer-minor 68 | -------------------------------------------------------------------------------- /stack-ghc-7.8.yaml: -------------------------------------------------------------------------------- 1 | # This file was automatically generated by 'stack init' 2 | # 3 | # Some commonly used options have been documented as comments in this file. 
4 | # For advanced use and comprehensive documentation of the format, please see: 5 | # https://docs.haskellstack.org/en/stable/yaml_configuration/ 6 | 7 | # Resolver to choose a 'specific' stackage snapshot or a compiler version. 8 | # A snapshot resolver dictates the compiler version and the set of packages 9 | # to be used for project dependencies. For example: 10 | # 11 | # resolver: lts-3.5 12 | # resolver: nightly-2015-09-21 13 | # resolver: ghc-7.10.2 14 | # resolver: ghcjs-0.1.0_ghc-7.10.2 15 | # resolver: 16 | # name: custom-snapshot 17 | # location: "./custom-snapshot.yaml" 18 | resolver: lts-2.22 19 | 20 | # User packages to be built. 21 | # Various formats can be used as shown in the example below. 22 | # 23 | # packages: 24 | # - some-directory 25 | # - https://example.com/foo/bar/baz-0.0.2.tar.gz 26 | # - location: 27 | # git: https://github.com/commercialhaskell/stack.git 28 | # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a 29 | # - location: https://github.com/commercialhaskell/stack/commit/e7b331f14bcffb8367cd58fbfc8b40ec7642100a 30 | # extra-dep: true 31 | # subdirs: 32 | # - auto-update 33 | # - wai 34 | # 35 | # A package marked 'extra-dep: true' will only be built if demanded by a 36 | # non-dependency (i.e. a user package), and its test suites and benchmarks 37 | # will not be run. This is useful for tweaking upstream packages. 38 | packages: 39 | - . 
40 | 41 | # Dependency packages to be pulled from upstream that are not in the resolver 42 | # (e.g., acme-missiles-0.3) 43 | extra-deps: 44 | - JuicyPixels-3.2.7.2 45 | 46 | # Override default flag values for local packages and extra-deps 47 | # flags: {} 48 | 49 | # Extra package databases containing global packages 50 | # extra-package-dbs: [] 51 | 52 | # Control whether we use the GHC we find on the path 53 | # system-ghc: true 54 | # 55 | # Require a specific version of stack, using version ranges 56 | # require-stack-version: -any # Default 57 | # require-stack-version: ">=1.6" 58 | # 59 | # Override the architecture used by stack, especially useful on Windows 60 | # arch: i386 61 | # arch: x86_64 62 | # 63 | # Extra directories used by stack for building 64 | # extra-include-dirs: [/path/to/dir] 65 | # extra-lib-dirs: [/path/to/dir] 66 | # 67 | # Allow a newer minor version of GHC than the snapshot specifies 68 | # compiler-check: newer-minor 69 | -------------------------------------------------------------------------------- /stack-ghc-8.0.yaml: -------------------------------------------------------------------------------- 1 | # This file was automatically generated by 'stack init' 2 | # 3 | # Some commonly used options have been documented as comments in this file. 4 | # For advanced use and comprehensive documentation of the format, please see: 5 | # https://docs.haskellstack.org/en/stable/yaml_configuration/ 6 | 7 | # Resolver to choose a 'specific' stackage snapshot or a compiler version. 8 | # A snapshot resolver dictates the compiler version and the set of packages 9 | # to be used for project dependencies. For example: 10 | # 11 | # resolver: lts-3.5 12 | # resolver: nightly-2015-09-21 13 | # resolver: ghc-7.10.2 14 | # resolver: ghcjs-0.1.0_ghc-7.10.2 15 | # resolver: 16 | # name: custom-snapshot 17 | # location: "./custom-snapshot.yaml" 18 | resolver: lts-9.21 19 | 20 | # User packages to be built. 
21 | # Various formats can be used as shown in the example below. 22 | # 23 | # packages: 24 | # - some-directory 25 | # - https://example.com/foo/bar/baz-0.0.2.tar.gz 26 | # - location: 27 | # git: https://github.com/commercialhaskell/stack.git 28 | # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a 29 | # - location: https://github.com/commercialhaskell/stack/commit/e7b331f14bcffb8367cd58fbfc8b40ec7642100a 30 | # extra-dep: true 31 | # subdirs: 32 | # - auto-update 33 | # - wai 34 | # 35 | # A package marked 'extra-dep: true' will only be built if demanded by a 36 | # non-dependency (i.e. a user package), and its test suites and benchmarks 37 | # will not be run. This is useful for tweaking upstream packages. 38 | packages: 39 | - . 40 | 41 | # Dependency packages to be pulled from upstream that are not in the resolver 42 | # (e.g., acme-missiles-0.3) 43 | extra-deps: [] 44 | 45 | # Override default flag values for local packages and extra-deps 46 | # flags: {} 47 | 48 | # Extra package databases containing global packages 49 | # extra-package-dbs: [] 50 | 51 | # Control whether we use the GHC we find on the path 52 | # system-ghc: true 53 | # 54 | # Require a specific version of stack, using version ranges 55 | # require-stack-version: -any # Default 56 | # require-stack-version: ">=1.6" 57 | # 58 | # Override the architecture used by stack, especially useful on Windows 59 | # arch: i386 60 | # arch: x86_64 61 | # 62 | # Extra directories used by stack for building 63 | # extra-include-dirs: [/path/to/dir] 64 | # extra-lib-dirs: [/path/to/dir] 65 | # 66 | # Allow a newer minor version of GHC than the snapshot specifies 67 | # compiler-check: newer-minor 68 | -------------------------------------------------------------------------------- /stack-ghc-8.2.yaml: -------------------------------------------------------------------------------- 1 | # This file was automatically generated by 'stack init' 2 | # 3 | # Some commonly used options have been documented as 
comments in this file. 4 | # For advanced use and comprehensive documentation of the format, please see: 5 | # https://docs.haskellstack.org/en/stable/yaml_configuration/ 6 | 7 | # Resolver to choose a 'specific' stackage snapshot or a compiler version. 8 | # A snapshot resolver dictates the compiler version and the set of packages 9 | # to be used for project dependencies. For example: 10 | # 11 | # resolver: lts-3.5 12 | # resolver: nightly-2015-09-21 13 | # resolver: ghc-7.10.2 14 | # resolver: ghcjs-0.1.0_ghc-7.10.2 15 | # resolver: 16 | # name: custom-snapshot 17 | # location: "./custom-snapshot.yaml" 18 | resolver: lts-11.17 19 | 20 | # User packages to be built. 21 | # Various formats can be used as shown in the example below. 22 | # 23 | # packages: 24 | # - some-directory 25 | # - https://example.com/foo/bar/baz-0.0.2.tar.gz 26 | # - location: 27 | # git: https://github.com/commercialhaskell/stack.git 28 | # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a 29 | # - location: https://github.com/commercialhaskell/stack/commit/e7b331f14bcffb8367cd58fbfc8b40ec7642100a 30 | # extra-dep: true 31 | # subdirs: 32 | # - auto-update 33 | # - wai 34 | # 35 | # A package marked 'extra-dep: true' will only be built if demanded by a 36 | # non-dependency (i.e. a user package), and its test suites and benchmarks 37 | # will not be run. This is useful for tweaking upstream packages. 38 | packages: 39 | - . 
40 | 41 | # Dependency packages to be pulled from upstream that are not in the resolver 42 | # (e.g., acme-missiles-0.3) 43 | extra-deps: [] 44 | 45 | # Override default flag values for local packages and extra-deps 46 | # flags: {} 47 | 48 | # Extra package databases containing global packages 49 | # extra-package-dbs: [] 50 | 51 | # Control whether we use the GHC we find on the path 52 | # system-ghc: true 53 | # 54 | # Require a specific version of stack, using version ranges 55 | # require-stack-version: -any # Default 56 | # require-stack-version: ">=1.6" 57 | # 58 | # Override the architecture used by stack, especially useful on Windows 59 | # arch: i386 60 | # arch: x86_64 61 | # 62 | # Extra directories used by stack for building 63 | # extra-include-dirs: [/path/to/dir] 64 | # extra-lib-dirs: [/path/to/dir] 65 | # 66 | # Allow a newer minor version of GHC than the snapshot specifies 67 | # compiler-check: newer-minor 68 | -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | # This file was automatically generated by 'stack init' 2 | # 3 | # Some commonly used options have been documented as comments in this file. 4 | # For advanced use and comprehensive documentation of the format, please see: 5 | # https://docs.haskellstack.org/en/stable/yaml_configuration/ 6 | 7 | # Resolver to choose a 'specific' stackage snapshot or a compiler version. 8 | # A snapshot resolver dictates the compiler version and the set of packages 9 | # to be used for project dependencies. For example: 10 | # 11 | # resolver: lts-3.5 12 | # resolver: nightly-2015-09-21 13 | # resolver: ghc-7.10.2 14 | # resolver: ghcjs-0.1.0_ghc-7.10.2 15 | # resolver: 16 | # name: custom-snapshot 17 | # location: "./custom-snapshot.yaml" 18 | resolver: lts-12.19 19 | 20 | # User packages to be built. 21 | # Various formats can be used as shown in the example below. 
22 | # 23 | # packages: 24 | # - some-directory 25 | # - https://example.com/foo/bar/baz-0.0.2.tar.gz 26 | # - location: 27 | # git: https://github.com/commercialhaskell/stack.git 28 | # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a 29 | # - location: https://github.com/commercialhaskell/stack/commit/e7b331f14bcffb8367cd58fbfc8b40ec7642100a 30 | # extra-dep: true 31 | # subdirs: 32 | # - auto-update 33 | # - wai 34 | # 35 | # A package marked 'extra-dep: true' will only be built if demanded by a 36 | # non-dependency (i.e. a user package), and its test suites and benchmarks 37 | # will not be run. This is useful for tweaking upstream packages. 38 | packages: 39 | - . 40 | 41 | # Dependency packages to be pulled from upstream that are not in the resolver 42 | # (e.g., acme-missiles-0.3) 43 | extra-deps: [] 44 | 45 | # Override default flag values for local packages and extra-deps 46 | # flags: {} 47 | 48 | # Extra package databases containing global packages 49 | # extra-package-dbs: [] 50 | 51 | # Control whether we use the GHC we find on the path 52 | # system-ghc: true 53 | # 54 | # Require a specific version of stack, using version ranges 55 | # require-stack-version: -any # Default 56 | # require-stack-version: ">=1.6" 57 | # 58 | # Override the architecture used by stack, especially useful on Windows 59 | # arch: i386 60 | # arch: x86_64 61 | # 62 | # Extra directories used by stack for building 63 | # extra-include-dirs: [/path/to/dir] 64 | # extra-lib-dirs: [/path/to/dir] 65 | # 66 | # Allow a newer minor version of GHC than the snapshot specifies 67 | # compiler-check: newer-minor 68 | -------------------------------------------------------------------------------- /test/test.hs: -------------------------------------------------------------------------------- 1 | {-# OPTIONS_GHC -Wall #-} 2 | {-# LANGUAGE CPP #-} 3 | {-# LANGUAGE ScopedTypeVariables #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | 6 | import qualified Codec.Picture as Picture 7 | import 
qualified Codec.Picture.Types as Picture
import Control.Concurrent.Async
import Control.Exception
import Control.Monad
import qualified Data.ByteString as BS
import qualified Data.Vector as V
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as VU
import Foreign
import System.FilePath

import Test.Tasty.HUnit
import Test.Tasty.TH

import Menoh
import Paths_menoh (getDataDir)

-- Header of the linked libmenoh C library that provides the version
-- constants compared against by the macro below.
#include <menoh/version.h>

-- MIN_VERSION_libmenoh(major,minor,patch) is true when the libmenoh headers
-- report a version that is at least major.minor.patch, mirroring the
-- MIN_VERSION_ macros Cabal generates for Haskell packages.
-- NOTE(review): no test in this module currently uses it.
#define MIN_VERSION_libmenoh(major,minor,patch) (\
  (major) < MENOH_MAJOR_VERSION || \
  (major) == MENOH_MAJOR_VERSION && (minor) < MENOH_MINOR_VERSION || \
  (major) == MENOH_MAJOR_VERSION && (minor) == MENOH_MINOR_VERSION && (patch) <= MENOH_PATCH_VERSION)

------------------------------------------------------------------------
-- basicWriteBuffer

-- | Writing a boxed vector into a 3x3 float buffer stores its 9 elements in
-- order. The source is a tail slice, presumably to exercise vectors with a
-- non-zero offset (confirm).
case_basicWriteBuffer_vector :: Assertion
case_basicWriteBuffer_vector = do
  allocaArray 9 $ \(p :: Ptr Float) -> do
    basicWriteBuffer DTypeFloat [3,3] (castPtr p) (VG.tail (V.fromList xs))
    ys <- peekArray 9 p
    ys @?= tail xs
  where
    xs = [0..9]

-- | Same as case_basicWriteBuffer_vector, for a storable vector.
case_basicWriteBuffer_vector_storable :: Assertion
case_basicWriteBuffer_vector_storable = do
  allocaArray 9 $ \(p :: Ptr Float) -> do
    basicWriteBuffer DTypeFloat [3,3] (castPtr p) (VG.tail (VS.fromList xs))
    ys <- peekArray 9 p
    ys @?= tail xs
  where
    xs = [0..9]

-- | Same as case_basicWriteBuffer_vector, for an unboxed vector.
case_basicWriteBuffer_vector_unboxed :: Assertion
case_basicWriteBuffer_vector_unboxed = do
  allocaArray 9 $ \(p :: Ptr Float) -> do
    basicWriteBuffer DTypeFloat [3,3] (castPtr p) (VG.tail (VU.fromList xs))
    ys <- peekArray 9 p
    ys @?= tail xs
  where
    xs = [0..9]

-- | Writing a list of rows (one boxed vector per row) fills the 3x3 buffer
-- with the rows laid out one after another.
case_basicWriteBuffer_list :: Assertion
case_basicWriteBuffer_list = do
  allocaArray 9 $ \(p :: Ptr Float) -> do
    basicWriteBuffer DTypeFloat [3,3] (castPtr p) (map V.fromList xss)
    ys <- peekArray 9 p
    ys @?= concat xss
  where
    xss = [[1,2,3], [4,5,6], [7,8,9]]

------------------------------------------------------------------------
-- Error handling

-- | Resolve a file inside the data directory installed with this package,
-- i.e. the package data dir joined with "data" and the given file name.
getTestDataFile :: FilePath -> IO FilePath
getTestDataFile name = do
  dataDir <- getDataDir
  return (dataDir </> "data" </> name)

-- | Loading a model from a path that does not exist raises
-- ErrorInvalidFilename.
case_loading_nonexistent_model_file :: Assertion
case_loading_nonexistent_model_file = do
  path <- getTestDataFile "nonexistent_model.onnx"
  ret <- try $ makeModelDataFromONNXFile path
  case ret of
    Left (ErrorInvalidFilename _msg) -> return ()
    _ -> assertFailure "should throw ErrorInvalidFilename"


-- | A model built with an empty list of requested outputs can still be fed
-- and run; there is simply nothing to read back afterwards.
case_empty_output :: Assertion
case_empty_output = do
  images <- loadMNISTImages
  let batch_size = length images

  model_data <- makeModelDataFromONNXFile =<< getTestDataFile "mnist.onnx"
  vpt <- makeVariableProfileTable
    [(mnist_in_name, DTypeFloat, [batch_size, mnist_channel_num, mnist_height, mnist_width])]
    ([] :: [String])
    model_data
  optimizeModelData model_data vpt
  model <- makeModel vpt model_data "mkldnn"

  -- Run the model
  writeBuffer model mnist_in_name images
  run model

  -- but we cannot retrieve results
  return ()


-- | Requesting an output without declaring the inputs it depends on raises
-- ErrorVariableNotFound.
case_insufficient_input :: Assertion
case_insufficient_input = do
  model_data <- makeModelDataFromONNXFile =<< getTestDataFile "mnist.onnx"
  ret <- try $ makeVariableProfileTable
    []
    [mnist_out_name]
    model_data
  case ret of
    Left (ErrorVariableNotFound _msg) -> return ()
    _ -> assertFailure "should throw ErrorVariableNotFound"


-- | Declaring an input name that the model does not have raises
-- InputNotFoundError, even when the real input is also declared.
case_bad_input :: Assertion
case_bad_input = do
  images <- loadMNISTImages

  model_data <- makeModelDataFromONNXFile =<< getTestDataFile "mnist.onnx"
  ret <- try $ makeVariableProfileTable
    [ (mnist_in_name, DTypeFloat, [length images, mnist_channel_num, mnist_height, mnist_width])
    , ("bad input name", DTypeFloat, [1,8])
    ]
    [mnist_out_name]
    model_data
  case ret of
    Left (InputNotFoundError _msg) -> return ()
    _ -> assertFailure "should throw InputNotFoundError"

-- | Requesting an output name that the model does not have raises
-- OutputNotFoundError, even when the real output is also requested.
case_bad_output :: Assertion
case_bad_output = do
  images <- loadMNISTImages

  model_data <- makeModelDataFromONNXFile =<< getTestDataFile "mnist.onnx"
  ret <- try $ makeVariableProfileTable
    [(mnist_in_name, DTypeFloat, [length images, mnist_channel_num, mnist_height, mnist_width])]
    [mnist_out_name, "bad output name"]
    model_data
  case ret of
    Left (OutputNotFoundError _msg) -> return ()
    _ -> assertFailure "should throw OutputNotFoundError"

------------------------------------------------------------------------
-- MNIST fixtures

-- Aliases to onnx's node input and output tensor name
mnist_in_name, mnist_out_name :: String
mnist_in_name = "139900320569040"
mnist_out_name = "139898462888656"

-- Per-image input dimensions: one channel of 28x28 pixels.
mnist_channel_num, mnist_height, mnist_width :: Int
mnist_channel_num = 1
mnist_height = 28
mnist_width = 28

-- | Load data/0.png .. data/9.png, each converted to RGB8, reduced to its
-- luma plane and flattened into a vector of Float pixel values (0-255, no
-- normalisation). The MNIST tests expect image i to be classified as digit i.
loadMNISTImages :: IO [VS.Vector Float]
loadMNISTImages =
  forM [(0::Int)..9] $ \i -> do
    path <- getTestDataFile (show i ++ ".png")
    ret <- Picture.readImage path
    case ret of
      Left e -> error e
      Right img -> return
        $ VG.map fromIntegral
        $ Picture.imageData
        $ Picture.extractLumaPlane
        $ Picture.convertRGB8
        $ img

-- | Build an MNIST model on the "mkldnn" backend from data/mnist.onnx, for
-- an input batch of the given size and with mnist_out_name as the only
-- requested output.
loadMNISTModel :: Int -> IO Model
loadMNISTModel batch_size = do
  model_data <- makeModelDataFromONNXFile =<< getTestDataFile "mnist.onnx"
  vpt <- makeVariableProfileTable
    [(mnist_in_name, DTypeFloat, [batch_size, mnist_channel_num, mnist_height, mnist_width])]
    [mnist_out_name]
    model_data
  optimizeModelData model_data vpt
  makeModel vpt model_data "mkldnn"
-- | Same as loadMNISTModel, but the ONNX file is first read into a strict
-- ByteString and parsed with makeModelDataFromONNXByteString.
loadMNISTModelFromByteString :: Int -> IO Model
loadMNISTModelFromByteString batch_size = do
  dataDir <- getDataDir
  b <- BS.readFile $ dataDir </> "data" </> "mnist.onnx"
  model_data <- makeModelDataFromONNXByteString b
  vpt <- makeVariableProfileTable
    [(mnist_in_name, DTypeFloat, [batch_size, mnist_channel_num, mnist_height, mnist_width])]
    [mnist_out_name]
    model_data
  optimizeModelData model_data vpt
  makeModel vpt model_data "mkldnn"

-- | Classify the ten digit images and check the argmax of every score
-- vector; then rerun the same model on the reversed batch to check that
-- buffers can be rewritten and the model run again.
case_MNIST :: Assertion
case_MNIST = do
  images <- loadMNISTImages
  model <- loadMNISTModel (length images)

  -- Run the model
  writeBuffer model mnist_in_name images
  run model
  (vs :: [V.Vector Float]) <- readBuffer model mnist_out_name
  -- Compare whole lists rather than zipping, so a short or empty output
  -- batch fails instead of passing vacuously.
  map V.maxIndex vs @?= [0..9]

  -- Run the same model more than once, but with the different order
  writeBuffer model mnist_in_name (reverse images)
  run model
  (vs' :: [V.Vector Float]) <- readBuffer model mnist_out_name
  map V.maxIndex vs' @?= [9,8..0]

-- | Build ten models from one ModelData and profile table, then run them
-- concurrently, ten iterations each; every iteration must classify all ten
-- digits correctly.
case_MNIST_concurrently :: Assertion
case_MNIST_concurrently = do
  images <- loadMNISTImages
  let batch_size = length images

  dataDir <- getDataDir
  model_data <- makeModelDataFromONNXFile $ dataDir </> "data" </> "mnist.onnx"
  vpt <- makeVariableProfileTable
    [(mnist_in_name, DTypeFloat, [batch_size, mnist_channel_num, mnist_height, mnist_width])]
    [mnist_out_name]
    model_data
  optimizeModelData model_data vpt
  models <- replicateM 10 $ makeModel vpt model_data "mkldnn"

  -- The result list is discarded explicitly; mapConcurrently_ is presumably
  -- avoided for compatibility with the older async in the GHC 7.x snapshots.
  _ <- flip mapConcurrently models $ \model ->
    replicateM_ 10 $ do
      writeBuffer model mnist_in_name images
      run model
      (vs :: [V.Vector Float]) <- readBuffer model mnist_out_name
      map V.maxIndex vs @?= [0..9]
  return ()

-- | A model loaded from a ByteString produces exactly the same scores as one
-- loaded from the file path.
case_makeModelDataFromONNXByteString :: Assertion
case_makeModelDataFromONNXByteString = do
  images <- loadMNISTImages
  model1 <- loadMNISTModel (length images)
  model2 <- loadMNISTModelFromByteString (length images)

  -- Run the model (1)
  writeBuffer model1 mnist_in_name images
  run model1
  (vs1 :: [V.Vector Float]) <- readBuffer model1 mnist_out_name

  -- Run the model (2)
  writeBuffer model2 mnist_in_name images
  run model2
  (vs2 :: [V.Vector Float]) <- readBuffer model2 mnist_out_name

  vs2 @?= vs1

-- | Build a single FC node by hand. With W = [[1,2,3],[4,5,6]] (dims [2,3])
-- and b = [7,8], the input [1,2,3] must give W times input plus b = [21,40].
case_makeModelData :: Assertion
case_makeModelData = do
  md <- makeModelData
  -- NOTE(review): the arrays are freed when withArray returns, so this
  -- assumes addParameterFromPtr copies the data; confirm against Menoh.
  withArray [1,2,3,4,5,6] $ \(p :: Ptr Float) ->
    addParameterFromPtr md "W" DTypeFloat [2,3] p
  withArray [7,8] $ \(p :: Ptr Float) ->
    addParameterFromPtr md "b" DTypeFloat [2] p
  addNewNode md "FC"
  addInputNameToCurrentNode md "input"
  addInputNameToCurrentNode md "W"
  addInputNameToCurrentNode md "b"
  addOutputNameToCurrentNode md "output"

  vpt <- makeVariableProfileTable
    [("input", DTypeFloat, [1, 3])]
    ["output"]
    md

  optimizeModelData md vpt
  m <- makeModel vpt md "mkldnn"

  writeBuffer m "input" [VS.fromList [1::Float,2,3]]
  run m
  [r] <- readBuffer m "output"

  r @?= VS.fromList [21::Float,40]

------------------------------------------------------------------------
-- Test harness

-- | Run every case_ assertion in this module, collected by tasty-th.
main :: IO ()
main = $(defaultMainGenerator)

--------------------------------------------------------------------------------