├── .gitignore ├── .gitmodules ├── .travis.yml ├── HerbiePlugin.cabal ├── LICENSE ├── README.md ├── Setup.hs ├── data └── Herbie.db ├── src ├── Herbie.hs ├── Herbie │ ├── CoreManip.hs │ ├── ForeignInterface.hs │ ├── MathExpr.hs │ ├── MathInfo.hs │ └── Options.hs └── Show.hs ├── stack.yaml └── test ├── SpecialFunctions.hs ├── Tests.hs └── ValidRewrite.hs /.gitignore: -------------------------------------------------------------------------------- 1 | *.Rout 2 | *.csv 3 | *.class 4 | *.o 5 | 6 | results 7 | cover_tree 8 | 9 | hlearn-allknn 10 | hlearn-linear 11 | *.swp 12 | *.swo 13 | dist/ 14 | gitignore/ 15 | examples/old/ 16 | scripts/ 17 | 18 | .cabal-sandbox/ 19 | cabal.sandbox.config 20 | 21 | .stack-work/ 22 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "herbie"] 2 | path = herbie 3 | url = https://github.com/uwplse/herbie 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # NB: don't set `language: haskell` here 2 | language: c 3 | sudo: false 4 | 5 | cache: 6 | directories: 7 | - $HOME/.cabsnap 8 | - $HOME/.cabal/packages 9 | - $HOME/racket 10 | 11 | before_cache: 12 | - rm -fv $HOME/.cabal/packages/hackage.haskell.org/build-reports.log 13 | - rm -fv $HOME/.cabal/packages/hackage.haskell.org/00-index.tar 14 | 15 | matrix: 16 | include: 17 | - env: CABALVER=1.22 GHCVER=7.10.1 RACKET_VERSION=6.1 18 | compiler: ": #GHC 7.10.1" 19 | addons: {apt: {packages: [cabal-install-1.22,ghc-7.10.1,alex-3.1.4,happy-1.19.5,libblas-dev,liblapack-dev], sources: [hvr-ghc]}} 20 | - env: CABALVER=1.22 GHCVER=7.10.2 RACKET_VERSION=6.1 21 | compiler: ": #GHC 7.10.2" 22 | addons: {apt: {packages: [cabal-install-1.22,ghc-7.10.2,alex-3.1.4,happy-1.19.5,libblas-dev,liblapack-dev], sources: [hvr-ghc]}} 23 | 24 
| # Note: the distinction between `before_install` and `install` is not important. 25 | before_install: 26 | - unset CC 27 | - export HAPPYVER=1.19.5 28 | - export ALEXVER=3.1.4 29 | - export RACKET_DIR=$HOME/racket/$RACKET_VERSION 30 | - export PATH=$HOME/bin:$HOME/.cabal/bin:/opt/ghc/$GHCVER/bin:/opt/cabal/$CABALVER/bin:/opt/happy/$HAPPYVER/bin:/opt/alex/$ALEXVER/bin:$RACKET_DIR/bin:$PATH 31 | - mkdir -p $HOME/bin 32 | # install Racket if it isn't already installed 33 | - if [ ! -d $RACKET_DIR ]; then 34 | git clone https://github.com/greghendershott/travis-racket.git; 35 | cat travis-racket/install-racket.sh | bash; 36 | fi 37 | 38 | install: 39 | # install herbie-exec 40 | - raco exe -o $HOME/bin/herbie-exec herbie/herbie/interface/inout.rkt 41 | 42 | # manually download the expression database 43 | - wget https://github.com/mikeizbicki/HerbiePlugin/raw/master/data/Herbie.db 44 | - mkdir -p $HOME/.cabal/share/x86_64-linux-ghc-$GHCVER/HerbiePlugin-0.2.0.0/ 45 | - cp Herbie.db $HOME/.cabal/share/x86_64-linux-ghc-$GHCVER/HerbiePlugin-0.2.0.0/ 46 | 47 | # display versions 48 | - cabal --version 49 | - echo "$(ghc --version) [$(ghc --print-project-git-commit-id 2> /dev/null || echo '?')]" 50 | 51 | # install HerbiePlugin 52 | - if [ -f $HOME/.cabal/packages/hackage.haskell.org/00-index.tar.gz ]; then 53 | zcat $HOME/.cabal/packages/hackage.haskell.org/00-index.tar.gz > 54 | $HOME/.cabal/packages/hackage.haskell.org/00-index.tar; 55 | fi 56 | - travis_retry cabal update 57 | - "sed -i 's/^jobs:.*$/jobs: 2/' $HOME/.cabal/config" 58 | - cabal install --only-dependencies --enable-tests --enable-benchmarks --dry -v > installplan.txt 59 | - sed -i -e '1,/^Resolving /d' installplan.txt; cat installplan.txt 60 | 61 | # check whether current requested install-plan matches cached package-db snapshot 62 | - if diff -u installplan.txt $HOME/.cabsnap/installplan.txt; then 63 | echo "cabal build-cache HIT"; 64 | rm -rfv .ghc; 65 | cp -a $HOME/.cabsnap/ghc $HOME/.ghc; 66 | cp -a 
$HOME/.cabsnap/lib $HOME/.cabsnap/share $HOME/.cabsnap/bin $HOME/.cabal/; 67 | else 68 | echo "cabal build-cache MISS"; 69 | rm -rf $HOME/.cabsnap; 70 | mkdir -p $HOME/.ghc $HOME/.cabal/lib $HOME/.cabal/share $HOME/.cabal/bin; 71 | cabal install --only-dependencies --enable-tests --enable-benchmarks; 72 | fi 73 | 74 | # snapshot package-db on cache miss 75 | - if [ ! -d $HOME/.cabsnap ]; then 76 | echo "snapshotting package-db to build-cache"; 77 | mkdir $HOME/.cabsnap; 78 | cp -a $HOME/.ghc $HOME/.cabsnap/ghc; 79 | cp -a $HOME/.cabal/lib $HOME/.cabal/share $HOME/.cabal/bin installplan.txt $HOME/.cabsnap/; 80 | fi 81 | 82 | # Here starts the actual work to be performed for the package under test; any command which exits with a non-zero exit code causes the build to fail. 83 | script: 84 | - cabal configure --enable-tests --enable-benchmarks -v2 # -v2 provides useful information for debugging 85 | - cabal build # this builds all libraries and executables (including tests/benchmarks) 86 | - cabal test 87 | - cabal check 88 | - cabal sdist # tests that a source-distribution can be generated 89 | 90 | # Check that the resulting source distribution can be built & installed. 91 | # If there are no other `.tar.gz` files in `dist`, this can be even simpler: 92 | # `cabal install --force-reinstalls dist/*-*.tar.gz` 93 | - SRC_TGZ=$(cabal info . | awk '{print $2;exit}').tar.gz && 94 | (cd dist && cabal install --force-reinstalls "$SRC_TGZ") 95 | -------------------------------------------------------------------------------- /HerbiePlugin.cabal: -------------------------------------------------------------------------------- 1 | -- Initial herbie-haskell.cabal generated by cabal init. For further 2 | -- documentation, see http://haskell.org/cabal/users-guide/ 3 | 4 | -- The name of the package. 5 | name: HerbiePlugin 6 | 7 | -- The package version. See the Haskell package versioning policy (PVP) 8 | -- for standards guiding when and how versions should be incremented. 
9 | -- http://www.haskell.org/haskellwiki/Package_versioning_policy 10 | -- PVP summary: +-+------- breaking API changes 11 | -- | | +----- non-breaking API additions 12 | -- | | | +--- code changes with no API change 13 | version: 0.2.0.0 14 | 15 | -- A short (one-line) description of the package. 16 | synopsis: automatically improve your code's numeric stability 17 | 18 | -- A longer description of the package. 19 | description: 20 | This package contains a GHC plugin that automatically improves the numerical stability of your Haskell code. 21 | See <https://github.com/mikeizbicki/HerbiePlugin> for details on how it works and how to use it. 22 | 23 | -- URL for the project homepage or repository. 24 | homepage: https://github.com/mikeizbicki/HerbiePlugin 25 | 26 | -- The license under which the package is released. 27 | license: BSD3 28 | 29 | -- The file containing the license text. 30 | license-file: LICENSE 31 | 32 | -- The package author(s). 33 | author: Mike Izbicki 34 | 35 | -- An email address to which users can send suggestions, bug reports, and 36 | -- patches. 37 | maintainer: mike@izbicki.me 38 | 39 | -- A copyright notice. 40 | -- copyright: 41 | 42 | category: Math 43 | 44 | build-type: Simple 45 | 46 | -- Extra files to be distributed with the package, such as examples or a 47 | -- README. 48 | -- extra-source-files: 49 | 50 | data-files: 51 | Herbie.db 52 | 53 | data-dir: 54 | data 55 | 56 | -- Constraint on the version of Cabal needed to build this package. 57 | cabal-version: >=1.10 58 | 59 | source-repository head 60 | type: git 61 | location: http://github.com/mikeizbicki/HerbiePlugin 62 | 63 | -------------------------------------------------------------------------------- 64 | 65 | library 66 | -- Modules exported by the library. 67 | exposed-modules: 68 | Herbie 69 | 70 | -- Modules included in this library but not exported.
71 | other-modules: 72 | Herbie.CoreManip 73 | Herbie.ForeignInterface 74 | Herbie.MathExpr 75 | Herbie.MathInfo 76 | Herbie.Options 77 | Show 78 | Paths_HerbiePlugin 79 | 80 | -- LANGUAGE extensions used by modules in this package. 81 | default-extensions: 82 | MultiWayIf 83 | ScopedTypeVariables 84 | DeriveGeneric 85 | DeriveAnyClass 86 | DeriveDataTypeable 87 | StandaloneDeriving 88 | 89 | -- Other library packages from which modules are imported. 90 | build-depends: base >=4.8 && <4.9 91 | , ghc 92 | , template-haskell 93 | , process >= 1.1.0.0 94 | , sqlite-simple 95 | , text 96 | , directory 97 | , deepseq 98 | , mtl 99 | , split 100 | 101 | -- Directories containing source files. 102 | hs-source-dirs: src 103 | 104 | -- Base language which the package is written in. 105 | default-language: Haskell2010 106 | 107 | -------------------------------------------------------------------------------- 108 | 109 | Test-Suite Tests 110 | default-language: Haskell2010 111 | type: exitcode-stdio-1.0 112 | hs-source-dirs: test 113 | main-is: Tests.hs 114 | 115 | ghc-options: 116 | -fplugin=Herbie 117 | 118 | build-depends: 119 | base, 120 | subhask, 121 | HerbiePlugin 122 | 123 | Test-Suite SpecialFunctions 124 | default-language: Haskell2010 125 | type: exitcode-stdio-1.0 126 | hs-source-dirs: test 127 | main-is: SpecialFunctions.hs 128 | 129 | ghc-options: 130 | -fplugin=Herbie 131 | 132 | build-depends: 133 | base, 134 | subhask, 135 | HerbiePlugin 136 | -- linear 137 | 138 | Test-Suite ValidRewrite 139 | default-language: Haskell2010 140 | type: exitcode-stdio-1.0 141 | hs-source-dirs: test 142 | main-is: ValidRewrite.hs 143 | 144 | ghc-options: 145 | -fplugin=Herbie 146 | -fplugin-opts=Herbie:tol=0 147 | 148 | build-depends: 149 | base, 150 | subhask, 151 | HerbiePlugin 152 | -- linear 153 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 
Copyright (c) 2015, Mike Izbicki 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above 11 | copyright notice, this list of conditions and the following 12 | disclaimer in the documentation and/or other materials provided 13 | with the distribution. 14 | 15 | * Neither the name of Mike Izbicki nor the names of other 16 | contributors may be used to endorse or promote products derived 17 | from this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Herbie GHC Plugin ![](https://travis-ci.org/mikeizbicki/HerbiePlugin.svg) 2 | 3 | The Herbie [GHC Plugin](https://downloads.haskell.org/~ghc/latest/docs/html/users_guide/compiler-plugins.html) automatically improves the [numerical stability](https://en.wikipedia.org/wiki/Numerical_stability) of your Haskell code. 4 | The Herbie plugin fully supports the [SubHask](http://github.com/mikeizbicki/subhask) numeric prelude, 5 | and partially supports the standard prelude (see the [known bugs section](#known-bugs) below). 6 | 7 | This README is organized into the following sections: 8 | 9 | * [Example: linear algebra and numerical instability](#example-linear-algebra-and-numerical-instability) 10 | * [How the Herbie plugin works](#how-it-works) 11 | * [Installing and using the Herbie plugin](#installing-and-using-the-herbie-plugin) 12 | * [Running Herbie on all of stackage](#running-herbie-on-all-of-stackage) 13 | * [Known bugs](#known-bugs) 14 | 15 | ## Example: linear algebra and numerical instability 16 | 17 | The popular [linear](https://hackage.haskell.org/package/linear) library contains [the following calculation](https://github.com/ekmett/linear/blob/35dcce4152c1a26e0d82e9ac75c3b77607b2aa3c/src/Linear/Projection.hs#L73): 18 | 19 | ``` 20 | w :: Double -> Double -> Double 21 | w far near = -(2 * far * near) / (far - near) 22 | ``` 23 | 24 | This code looks correct, but it can give the wrong answer. 25 | When the values of `far` and `near` are both very small (or very large), the product `far * near` will underflow to 0 (or overflow to infinity). 26 | In the worst case scenario using `Double`s, this calculation can lose up to 14 bits of information.
27 | 28 | The Herbie plugin automatically prevents this class of bugs (with some technical caveats, see [how it works](#how-it-works) below). 29 | If you compile the linear package with Herbie, the code above gets rewritten to: 30 | 31 | ``` 32 | w :: Double -> Double -> Double 33 | w far near = if far < -1.7210442634149447e81 34 | then ((-2 * far) / (far - near)) * near 35 | else if far < 8.364504563556443e16 36 | then -2 * far * (near / (far - near)) 37 | else ((-2 * far) / (far - near)) * near 38 | ``` 39 | 40 | This modified code is numerically stable. 41 | The if statements check to see which regime we are in (very small or very large) 42 | and select the calculation that is most appropriate for this regime. 43 | 44 | The [test suite](/test/Tests.hs) contains MANY more examples of the types of expressions the Herbie plugin can analyze. 45 | 46 | ## How it works 47 | 48 | [GHC Plugins](https://downloads.haskell.org/~ghc/latest/docs/html/users_guide/compiler-plugins.html) 49 | let library authors add new features to the Haskell compiler. 50 | The Herbie plugin gets run after type checking, but before any optimizations. 51 | Because GHC is so good at optimizing, 52 | the code generated by the Herbie plugin is just as fast as hand-written code. 53 | The plugin has two key components: 54 | first it finds the floating point computations in your code; 55 | then it replaces them with numerically stable versions. 56 | 57 | ### Finding the computations 58 | 59 | When the Herbie plugin is run, it traverses your code's [abstract syntax tree](https://en.wikipedia.org/wiki/Abstract_syntax_tree) looking for mathematical expressions. 
60 | These expressions may: 61 | 62 | * consist of the following operators: 63 | `/` 64 | , `*` 65 | , `-` 66 | , `+` 67 | , `**` 68 | , `^` 69 | , `^^` 70 | , `$` 71 | , `cos` 72 | , `sin` 73 | , `tan` 74 | , `acos` 75 | , `asin` 76 | , `atan` 77 | , `cosh` 78 | , `sinh` 79 | , `tanh` 80 | , `exp` 81 | , `log` 82 | , `sqrt` 83 | , `abs` 84 | 85 | * contain an arbitrary number of free variables 86 | 87 | * contain arbitrary non-mathematical subexpressions 88 | 89 | For example, in the function below: 90 | ``` 91 | test :: String -> String 92 | test str = show ( sqrt (1+fromIntegral (length str)) 93 | - sqrt (fromIntegral (length str)) 94 | :: Float 95 | ) 96 | ``` 97 | the Herbie plugin extracts the expression `sqrt (x+1) - sqrt x`; 98 | calculates the numerically stable version `1 / (sqrt (x+1) + sqrt x)`; 99 | and then substitutes the result back into your code. 100 | 101 | The Herbie plugin performs this procedure on both concrete types (`Float` and `Double`) and polymorphic types. 102 | For example, the plugin will rewrite 103 | ``` 104 | f :: Field a => a -> a -> a 105 | f far near = (far+near)/(far*near*2) 106 | ``` 107 | to 108 | ``` 109 | f :: Field a => a -> a -> a 110 | f far near = 0.5/far + 0.5/near 111 | ``` 112 | (The `Field` constraint comes from [SubHask](http://github.com/mikeizbicki/subhask).) 113 | 114 | These polymorphic rewrites are always guaranteed to preserve the semantics when the expression is evaluated on an exact numeric type. 115 | So both versions of `f` above are guaranteed to give the same result when called on `Rational` values, 116 | but the rewritten version will be more accurate when called on `Float` or `Double`. 117 | 118 | The main limitation of the Herbie plugin is that any recursive part of an expression is ignored. 119 | This is because analyzing the numeric stability of a Turing complete language is undecidable 120 | (and no practical heuristics are known).
121 | 122 | Fortunately, the Herbie plugin can still analyze the non-recursive subexpressions of a recursive expression. 123 | So in the code: 124 | ``` 125 | go :: Float -> Float -> Float 126 | go 0 b = b 127 | go a b = go (a-1) (sqrt $ (a+b) * (a+b)) 128 | ``` 129 | the expression `sqrt $ (a+b) * (a+b)` gets rewritten by Herbie into `abs (a+b)`. 130 | 131 | #### Disabling Herbie 132 | 133 | You can prevent Herbie from analyzing a function using an [annotation pragma](https://downloads.haskell.org/~ghc/latest/docs/html/users_guide/extending-ghc.html#annotation-pragmas). 134 | In the following example: 135 | ``` 136 | {-# ANN foo "NoHerbie" #-} 137 | foo x = (x + 1) - x 138 | ``` 139 | Herbie will NOT rewrite the code into 140 | ``` 141 | foo x = 1 142 | ``` 143 | These GHC annotations can only be applied to top level bindings. 144 | They prevent Herbie from searching anywhere inside the binding. 145 | 146 | ### Improving the stability of expressions 147 | 148 | The Herbie plugin uses two sources of information to find numerically stable replacements for expressions. 149 | The simplest source is a sqlite3 database. 150 | This database contains about 400 expressions that are known to be used by Haskell libraries. 151 | 152 | The more important source is the [Herbie program](http://herbie.uwplse.org/). 153 | Herbie is a recent research project on using probabilistic searches to find numerically stable expressions. 154 | Because Herbie is probabilistic, it provides only weak theoretical guarantees on the numerical stability of the resulting expressions; 155 | in practice, however, the improved expressions are significantly better. 156 | To ensure reproducible builds, the same random seed is used on all calls to the Herbie program. 157 | For more details on how Herbie works, check out the [PLDI15 paper](http://herbie.uwplse.org/pldi15.html). 158 | 159 | The Herbie program can take a long time to run.
160 | If the program doesn't return a solution within two minutes, 161 | then the Herbie plugin assumes that no better solution is possible and continues processing. 162 | To improve compile times, every time the Herbie program returns a new solution, 163 | the solution is added to the `Herbie.db` database. 164 | When compiling a file, if all the math expressions are already in the database, 165 | then the Herbie plugin imposes essentially no overhead on compile times. 166 | 167 | ## Installing and Using the Herbie Plugin 168 | 169 | The Herbie plugin requires GHC 7.10.1 or 7.10.2. 170 | It is installable via cabal using the command: 171 | ``` 172 | cabal update && cabal install HerbiePlugin 173 | ``` 174 | 175 | It is recommended (but not required) that you also install the Herbie program. 176 | (Without installing the program, the Herbie plugin can only replace expressions in the standard database.) 177 | The Herbie program is written in Racket, 178 | so you must first install Racket. 179 | Go to the [download page](http://download.racket-lang.org/racket-v6.1.html) for Racket 6.1.1 and install the version for your platform. 180 | Then run the commands: 181 | ``` 182 | git clone http://github.com/mikeizbicki/HerbiePlugin --recursive 183 | cd HerbiePlugin/herbie 184 | raco exe -o herbie-exec herbie/interface/inout.rkt 185 | ``` 186 | The last line creates an executable `herbie-exec` that the Herbie plugin will try to call. 187 | You must move the program somewhere into your `PATH` for these calls to succeed. 188 | One way to do this is with the command: 189 | ``` 190 | sudo mv herbie-exec /usr/local/bin 191 | ``` 192 | 193 | To compile a file with the Herbie plugin, you need to pass the flag `-fplugin=Herbie` to GHC.
194 | You can have `cabal install` automatically apply the Herbie plugin by passing the following flags (the `-package-id` value below comes from one particular installation; replace it with the id of your installed HerbiePlugin package): 195 | ``` 196 | --ghc-option=-fplugin=Herbie 197 | --ghc-option=-package-id herbie-haskell-0.1.0.0-50ba55c8f248a3301dce2d3339977982 198 | ``` 199 | 200 | ## Running Herbie on all of stackage 201 | 202 | [Stackage LTS-3.5](https://www.stackage.org/lts-3.5) is a collection of 1351 of the most popular Haskell libraries. 203 | The script [install.sh]() compiles all of stackage using the Herbie plugin. 204 | 205 | 48 of the 1351 packages (3.5%) use floating point computations internally. 206 | Of these, 40 packages (83%) contain expressions whose numerical stability is improved with the Herbie plugin. 207 | In total, there are 303 distinct numerical expressions used across all stackage packages, 208 | and 164 of these expressions (54%) are more stable using the Herbie plugin. 209 | 210 | The table below shows a detailed breakdown of which packages contain the unstable expressions. 211 | Notice that the unstable and stable expression columns may not add up to the total expressions column. 212 | The difference consists of the expressions that could not be analyzed because the Herbie program timed out.
213 | 214 | | package | total math expressions | unstable expressions | stable expressions | 215 | | ------- | ---------------------- | -------------------- | ------------------ | 216 | | math-functions-0.1.5.2|92|50|34 | 217 | | colour-2.3.3|28|8|18 | 218 | | linear-1.19.1.3|28|19|8 | 219 | | diagrams-lib-1.3.0.3|25|15|10 | 220 | | diagrams-solve-0.1|25|14|11 | 221 | | statistics-0.13.2.3|15|5|7 | 222 | | plot-0.2.3.4|12|9|2 | 223 | | random-fu-0.2.6.2|11|4|5 | 224 | | Chart-1.5.3|10|8|2 | 225 | | circle-packing-0.1.0.4|9|3|6 | 226 | | mwc-random-0.13.3.2|9|1|7 | 227 | | Rasterific-0.6.1|8|3|5 | 228 | | diagrams-contrib-1.3.0.5|6|6|0 | 229 | | log-domain-0.10.2.1|5|0|5 | 230 | | repa-algorithms-3.4.0.1|5|3|2 | 231 | | rasterific-svg-0.2.3.1|4|2|2 | 232 | | sbv-4.4|4|2|2 | 233 | | clustering-0.2.1|3|1|2 | 234 | | erf-2.0.0.0|3|2|1 | 235 | | hsignal-0.2.7.1|3|2|0 | 236 | | hyperloglog-0.3.4|3|1|2 | 237 | | integration-0.2.1|3|1|2 | 238 | | intervals-0.7.1|3|2|1 | 239 | | shake-0.15.5|3|2|1 | 240 | | Chart-diagrams-1.5.1|2|2|0 | 241 | | JuicyPixels-3.2.6.1|2|2|0 | 242 | | Yampa-0.10.2|2|2|0 | 243 | | YampaSynth-0.2|2|1|1 | 244 | | diagrams-rasterific-1.3.1.3|2|1|1 | 245 | | fay-base-0.20.0.1|2|1|1 | 246 | | histogram-fill-0.8.4.1|2|0|2 | 247 | | parsec-3.1.9|2|0|2 | 248 | | smoothie-0.4.0.2|2|2|0 | 249 | | Octree-0.5.4.2|1|1|0 | 250 | | ad-4.2.4|1|1|0 | 251 | | approximate-0.2.2.1|1|1|0 | 252 | | crypto-random-0.0.9|1|1|0 | 253 | | diagrams-cairo-1.3.0.3|1|1|0 | 254 | | diagrams-svg-1.3.1.4|1|1|0 | 255 | | dice-0.1|1|0|1 | 256 | | force-layout-0.4.0.2|1|1|0 | 257 | | gipeda-0.1.2.1|1|0|1 | 258 | | hakyll-4.7.2.3|1|0|1 | 259 | | hashtables-1.2.0.2|1|1|0 | 260 | | hmatrix-0.16.1.5|1|0|1 | 261 | | hscolour-1.23|1|0|1 | 262 | | metrics-0.3.0.2|1|1|0 | 263 | 264 | 299 | 300 | The easiest way to find the offending code in each package is to compile the package using the Herbie plugin. 
301 | 302 | ## Known Bugs 303 | 304 | There are no known bugs when compiling programs that use the [SubHask](https://github.com/mikeizbicki/subhask) numeric prelude. 305 | 306 | The standard Prelude is only partially supported. 307 | In particular, the Herbie plugin is able to extract mathematical expressions correctly and will print the stabilized version to stdout. 308 | But the plugin can only substitute the stabilized version on polymorphic expressions, and does not perform the substitution on non-polymorphic ones. 309 | The problem is that the `solveWantedsTcM` function (called within the `getDictionary` function inside of [Herbie/CoreManip.hs](https://github.com/mikeizbicki/HerbiePlugin/blob/master/src/Herbie/CoreManip.hs)) is unable to find the `Num` dictionary for `Float` and `Double` types. 310 | I have no idea why this is happening, and I'd be very happy to accept pull requests that fix the issue :) 311 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | main = defaultMain -- stock Cabal build driver, matching build-type: Simple in HerbiePlugin.cabal 3 | -------------------------------------------------------------------------------- /data/Herbie.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikeizbicki/HerbiePlugin/0d495e40961742d578746120cd728c1ec9237cb4/data/Herbie.db -------------------------------------------------------------------------------- /src/Herbie.hs: -------------------------------------------------------------------------------- 1 | module Herbie 2 | ( plugin -- the Plugin value GHC loads for -fplugin=Herbie 3 | , pass -- the Core-to-Core pass that plugin installs 4 | ) 5 | where 6 | 7 | import Class 8 | import DsBinds 9 | import DsMonad 10 | import ErrUtils 11 | import GhcPlugins 12 | import Id 13 | import Unique 14 | import MkId 15 | import PrelNames 16 | import TcRnMonad 17 | import TcSimplify 18 | 19 | import Control.Monad 20 | import Control.Monad.Except 21 | import Data.Data 22
| import Data.List 23 | import Data.Maybe 24 | import Data.Typeable 25 | 26 | import Herbie.CoreManip 27 | import Herbie.ForeignInterface 28 | import Herbie.MathExpr 29 | import Herbie.MathInfo 30 | import Herbie.Options 31 | 32 | import Debug.Trace 33 | 34 | import Prelude 35 | import Show 36 | import Data.IORef 37 | -- | The plugin GHC loads for -fplugin=Herbie; it only overrides 'installCoreToDos'. 38 | plugin :: Plugin 39 | plugin = defaultPlugin 40 | { installCoreToDos = install 41 | } 42 | -- | Prints a banner, calls 'reinitializeGlobals', and puts the Herbie Core pass (registered under the name MathInfo) at the front of the pass list, ahead of the passes already in @todo@. 43 | install :: [CommandLineOption] -> [CoreToDo] -> CoreM [CoreToDo] 44 | install opts todo = do 45 | putMsgS "Compiling with Herbie floating point stabilization" 46 | reinitializeGlobals 47 | return (CoreDoPluginPass "MathInfo" (pass opts) : todo) 48 | -- | The Herbie Core pass: stores the current 'DynFlags' in the IORef 'dynFlags_ref' (defined in another module; presumably read by code running outside 'CoreM' - TODO confirm), then applies 'modBind' to every top-level binding via 'bindsOnlyPass'. 49 | pass :: [CommandLineOption] -> ModGuts -> CoreM ModGuts 50 | pass opts guts = do 51 | dflags <- getDynFlags 52 | liftIO $ writeIORef dynFlags_ref dflags 53 | bindsOnlyPass (mapM (modBind opts guts)) guts 54 | 55 | -- | Runs on each top-level binding of the module. Recursive ('Rec') groups are returned unchanged; a 'NonRec' binding carrying a "NoHerbie" annotation is also left as is; otherwise its right-hand side is searched for math expressions by @go@. NOTE(review): the README's recursive @go@ example suggests top-level recursive functions are analyzed, yet 'Rec' bindings are skipped here - confirm which is intended. 56 | modBind :: [CommandLineOption] -> ModGuts -> CoreBind -> CoreM CoreBind 57 | modBind opts guts bndr@(Rec _) = return bndr 58 | modBind opts guts bndr@(NonRec b e) = do 59 | -- dflags <- getDynFlags 60 | -- putMsgS "" 61 | -- putMsgS $ showSDoc dflags (ppr b) 62 | -- ++ "::" 63 | -- ++ showSDoc dflags (ppr $ varType b) 64 | -- putMsgS $ myshow dflags e 65 | -- return bndr 66 | anns <- annotationsOn guts b :: CoreM [String] 67 | e' <- if "NoHerbie" `elem` anns 68 | then return e 69 | else go [] e 70 | return $ NonRec b e' 71 | where 72 | pluginOpts = parsePluginOpts opts 73 | 74 | -- Recursively descend into the expression e. 75 | -- For each math expression we find, run Herbie on it. 76 | -- We need to save each dictionary we find because 77 | -- it might be needed to create the replacement expressions.
78 | go dicts e = do 79 | dflags <- getDynFlags 80 | case mkMathInfo dflags dicts (varType b) e of 81 | -- mkMathInfo returns Just when e is recognised as a math expression; it is given the dictionaries collected so far and the type of the top-level binder b. 82 | -- not a math expression, so recurse into subexpressions 83 | Nothing -> case e of 84 | 85 | -- Lambda expression: 86 | -- If the variable is a dictionary, add it to the list; 87 | -- Always recurse into the subexpression 88 | -- 89 | -- FIXME: 90 | -- Currently, we're removing deadness annotations from any dead variables. 91 | -- This is so that we can use all the dictionaries that the type signatures allow. 92 | -- Core lint complains about using dead variables if we don't. 93 | -- This removes the deadness annotation from every lambda- and let-bound variable visited below (case binders and pattern variables are left untouched). 94 | -- The drawbacks of this are not yet understood. 95 | -- This could be fixed by having a second pass through the code 96 | -- to remove only the appropriate deadness annotations. 97 | Lam a b -> do 98 | let a' = undeadenId a 99 | b' <- go (extractDicts a'++dicts) b 100 | return $ Lam a' b' 101 | 102 | -- Let binding: 103 | -- If the variable is a dictionary, add it to the list; 104 | -- Always recurse into the subexpression 105 | Let (NonRec a e) b -> do 106 | let a' = undeadenId a 107 | e' <- go dicts e 108 | b' <- go (extractDicts a'++dicts) b 109 | return $ Let (NonRec a' e') b' 110 | -- Recursive let binding: undeaden every binder and recurse into each right-hand side and the body; unlike the NonRec case, these binders are NOT added to the dictionary list. 111 | Let (Rec bndrs) expr -> do 112 | bndrs' <- forM bndrs $ \(a,e) -> do 113 | let a' = undeadenId a 114 | e' <- go dicts e 115 | return (a',e') 116 | expr' <- go dicts expr 117 | return $ Let (Rec bndrs') expr' 118 | 119 | -- Function application: 120 | -- Math expressions may appear on either side, so recurse on both 121 | App a b -> do 122 | a' <- go dicts a 123 | b' <- go dicts b 124 | return $ App a' b' 125 | 126 | -- Case statement: 127 | -- Math expressions may appear in the condition or in any of the branches; the case binder w and the pattern variables xs are not added to the dictionary list 128 | Case cond w t es -> do 129 | cond' <- go dicts cond 130 | es' <- forM es $ \ (altcon, xs, expr) -> do 131 | expr' <- go dicts expr 132 | return (altcon, xs, expr') 133 | return $ Case cond' w t es' 134 | 135
| -- Ticks and Casts are just annotating extra information on an expression. 136 | -- We ignore the extra information and recurse into the expression. 137 | Tick a b -> do 138 | b' <- go dicts b 139 | return $ Tick a b' 140 | 141 | Cast a b -> do 142 | a' <- go dicts a 143 | return $ Cast a' b 144 | 145 | -- There's nothing to do for these statements. 146 | -- They form the recursion's base case. 147 | Var v -> return $ Var v 148 | Lit l -> return $ Lit l 149 | Type t -> return $ Type t 150 | Coercion c -> return $ Coercion c 151 | 152 | -- We found a math expression, so process it 153 | Just mathInfo -> do 154 | putMsgS $ "Found math expression within binding " 155 | ++ showSDoc dflags (ppr b) 156 | ++ " :: " 157 | ++ showSDoc dflags (ppr $ varType b) 158 | putMsgS $ " original expression = "++pprMathInfo mathInfo 159 | -- Initial Herbie options: a fixed random seed (-r) so every build passes Herbie the same seed, plus the option pair -o rules:numerics (interpreted by herbie-exec; presumably the setting that toggleNumerics changes on retry - confirm in Herbie.Options). 160 | simplifyExpression $ HerbieOptions 161 | [ [ "-r", "#(1461197085 2376054483 1553562171 1611329376 2497620867 2308122621)" ] 162 | , [ "-o", "rules:numerics" ] 163 | ] 164 | -- simplifyExpression sends the extracted expression to Herbie via stabilizeMathExpr and reports the outcome. If rewriting is enabled and the improvement exceeds optsTol, it converts the result back to Core with mathInfo2expr; certain conversion failures trigger a retry with modified options. 165 | where 166 | simplifyExpression :: HerbieOptions -> CoreM CoreExpr 167 | simplifyExpression herbieOpts = do 168 | let dbgInfo = DbgInfo 169 | { dbgComments = concat opts 170 | , modName = showSDoc dflags (ppr $ moduleName $ mg_module guts) 171 | , functionName = showSDoc dflags (ppr b) 172 | , functionType = showSDoc dflags (ppr $ varType b) 173 | } 174 | res <- liftIO $ stabilizeMathExpr dbgInfo herbieOpts $ getMathExpr mathInfo 175 | let mathInfo' = mathInfo { getMathExpr = cmdout res } 176 | 177 | -- Display the improved expression if found; canRewrite holds when the error reduction (errin - errout, reported in bits) exceeds the optsTol plugin option 178 | let canRewrite = errin res-errout res > optsTol pluginOpts 179 | if canRewrite 180 | then do 181 | putMsgS $ " improved expression = "++pprMathInfo mathInfo' 182 | putMsgS $ " original error = "++show (errin res)++" bits" 183 | putMsgS $ " improved error = "++show (errout res)++" bits" 184 | else do 185 | putMsgS $ " Herbie could not improve the stability of the original expression" 186 | 187 | -- Rewrite the expression, but only if rewriting is enabled (optsRewrite) and Herbie found a large enough improvement
188 | if not (optsRewrite pluginOpts) || not canRewrite 189 | then return e 190 | else do 191 | ret <- runExceptT $ mathInfo2expr guts mathInfo' 192 | case ret of 193 | Left (NotInScope var) -> do 194 | putMsgS $ " WARNING: Variable \""++var++"\" not in scope" 195 | if var `elem` fancyOps 196 | then do 197 | putMsgS " WARNING: Disabling fancy numerical operations and retrying" 198 | simplifyExpression $ toggleNumerics herbieOpts 199 | else do 200 | putMsgS " WARNING: Not substituting the improved expression into your code" 201 | return e 202 | -- NOTE(review): both retry paths (here and for BadCxt below) call simplifyExpression again with no retry limit. If toggleNumerics/toggleRegimes simply flip a setting, a second identical failure would flip it back and retry indefinitely - confirm in Herbie.Options. 203 | Left (BadCxt cxt) -> do 204 | putMsgS $ " WARNING: Cannot satisfy the constraint "++cxt 205 | if "Ord" `isInfixOf` cxt 206 | then do 207 | putMsgS " WARNING: Disabling if statements and retrying" 208 | simplifyExpression $ toggleRegimes herbieOpts 209 | else do 210 | putMsgS " WARNING: Not substituting the improved expression into your code" 211 | return e 212 | 213 | Left (OtherException str) -> do 214 | putMsgS " WARNING: Not substituting the improved expression into your code" 215 | putMsgS str 216 | return e 217 | 218 | Right e' -> return e' 219 | 220 | -- | Return a list containing the given variable if its type is a class constraint, an equality constraint, or a tuple of constraints (e.g. a dictionary or tuple of dictionaries), 221 | -- otherwise (an irreducible predicate) return [].
extractDicts :: Var -> [Var]
extractDicts v = case classifyPredType (varType v) of
    ClassPred _ _  -> [v]
    EqPred _ _ _   -> [v]
    TuplePred _    -> [v]
    -- Irreducible predicates are not treated as dictionary evidence.
    IrredPred _    -> []

-- | If a variable is marked as dead, remove the marking
-- (its occurrence info is reset to NoOccInfo).
undeadenId :: Var -> Var
undeadenId a
    | isDeadBinder a = setIdOccInfo a NoOccInfo
    | otherwise      = a

-- | Function taken from the docs:
-- https://downloads.haskell.org/~ghc/latest/docs/html/users_guide/compiler-plugins.html
--
-- Returns the annotations of type @a@ attached to the given binder,
-- or an empty list when there are none.
annotationsOn :: Data a => ModGuts -> CoreBndr -> CoreM [a]
annotationsOn guts bndr = do
    anns <- getAnnotations deserializeWithData guts
    return $ lookupWithDefaultUFM anns [] (varUnique bndr)

--------------------------------------------------------------------------------
/src/Herbie/CoreManip.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE CPP #-}
module Herbie.CoreManip
    where

import Class
import DsBinds
import DsMonad
import ErrUtils
import GhcPlugins hiding (trace)
import Unique
import MkId
import PrelNames
import UniqSupply
import TcRnMonad
import TcSimplify
import Type

import Control.Monad
import Control.Monad.Except
import Control.Monad.Trans
import Data.Char
import Data.List
import Data.Maybe
import Data.Ratio

import Herbie.MathExpr

import Prelude
import Show

-- Debug tracing is disabled: these stubs ignore their message.
-- Swap in the import below to re-enable it.
-- import Debug.Trace hiding (traceM)
trace a b = b
traceM a = return ()

--------------------------------------------------------------------------------

instance MonadUnique m => MonadUnique (ExceptT e m) where
    getUniqueSupplyM = lift getUniqueSupplyM

instance (Monad m, HasDynFlags m) => HasDynFlags (ExceptT e m) where
    getDynFlags = lift getDynFlags

instance MonadThings m => MonadThings (ExceptT e m) where
    lookupThing name = lift $ lookupThing name

-- | Failures that can occur while building Core expressions.
-- The plugin driver in Herbie.hs matches on these constructors to decide
-- whether to retry with different Herbie options.
data ExceptionType
    = NotInScope String
    | BadCxt String
    | OtherException String


----------------------------------------
-- core manipulation

-- | Converts a string into a Core variable.
--
-- The name is resolved through the module's global reader environment,
-- and its type is taken from the external package state's type environment.
-- Throws "NotInScope" if either lookup fails.
getVar :: ModGuts -> String -> ExceptT ExceptionType CoreM Var
getVar guts opstr = do
    opname <- getName opstr
    hscenv <- lift getHscEnv
    eps    <- liftIO $ hscEPS hscenv
    optype <- case lookupNameEnv (eps_PTE eps) opname of
        Just (AnId i) -> return $ varType i
        _             -> throwError (NotInScope opstr)
    return $ mkGlobalVar VanillaId opname optype vanillaIdInfo

    where
        -- Picks the first in-scope name whose occurrence string is exactly str.
        -- Except for "abs", only names that have a parent are accepted.
        getName :: String -> ExceptT ExceptionType CoreM Name
        getName str = case filter isCorrectVar (concat $ occEnvElts (mg_rdr_env guts)) of
            (x:_) -> return $ gre_name x
            []    -> throwError (NotInScope str)
            where
                isCorrectVar x = getString (gre_name x) == str
                              && (str == "abs" || case gre_par x of NoParent -> False; _ -> True)

-- | Like "decorateFunction", but first finds the function variable given a string.
getDecoratedFunction :: ModGuts -> String -> Type -> [CoreExpr] -> ExceptT ExceptionType CoreM CoreExpr
getDecoratedFunction guts str t preds = do
    f <- getVar guts str
    decorateFunction guts f t preds

-- | Given a variable that contains a function,
-- the type the function is being applied to,
-- and all in scope predicates,
-- apply the type and any needed dictionaries to the function.
decorateFunction :: ModGuts -> Var -> Type -> [CoreExpr] -> ExceptT ExceptionType CoreM CoreExpr
decorateFunction guts f t preds = case extractQuantifiers $ varType f of

    -- A function with exactly one quantified type variable,
    -- e.g. `forall a. Floating a => a -> a`: instantiate it at t
    -- and supply a dictionary for every constraint in the context.
    ([v], unquantified) -> do
        let (cxt, _) = extractContext unquantified
            cxt'     = substTysWith [v] [t] cxt
        cxt'' <- mapM getDict cxt'
        return $ mkApps (App (Var f) (Type t)) cxt''

    -- Any other number of quantifiers used to hit an irrefutable pattern and
    -- crash GHC; report it as a recoverable error instead.
    (vs, _) -> throwError $ OtherException $
        " WARNING: Cannot apply "++var2str f
        ++": expected exactly 1 quantified type variable, found "++show (length vs)

    where
        -- Ask GHC's constraint solver first; if that fails,
        -- build the dictionary from the evidence that is already in scope.
        getDict :: PredType -> ExceptT ExceptionType CoreM CoreExpr
        getDict pred = catchError
            (getDictionary guts pred)
            (\_ -> getPredEvidence guts pred preds)

-- | Given a non-polymorphic PredType (e.g. `Num Float`),
-- return the corresponding dictionary.
--
-- A single wanted constraint is handed to GHC's solver and the resulting
-- evidence bindings are desugared.
-- Throws "BadCxt" unless exactly one non-recursive binding comes back.
getDictionary :: ModGuts -> Type -> ExceptT ExceptionType CoreM CoreExpr
getDictionary guts dictTy = do
    let dictVar = mkGlobalVar
            VanillaId
            (mkSystemName (mkUnique 'z' 1337) (mkVarOcc "magicDictionaryName"))
            dictTy
            vanillaIdInfo

    bnds <- lift $ runTcM guts $ do
        loc <- getCtLoc $ GivenOrigin UnkSkol
        let nonC = mkNonCanonical CtWanted
                { ctev_pred = varType dictVar
                , ctev_evar = dictVar
                , ctev_loc  = loc
                }
            wCs = mkSimpleWC [nonC]
        (_, evBinds) <- solveWantedsTcM wCs
        initDsTc $ dsEvBinds evBinds

    case bnds of
        [NonRec _ dict] -> return dict
        _               -> throwError $ BadCxt $ dbg dictTy

-- | Given a predicate for which we don't have evidence
-- and a list of expressions that contain evidence for predicates,
-- construct an expression that contains
-- evidence for the given predicate.
--
-- The evidence is processed as a worklist (initially built by "prepEvidence"):
--
--  * equality evidence is used to rewrite the predicate; GHC's solver is then
--    asked for the rewritten dictionary, which is cast back with "castToType";
--
--  * class dictionaries are expanded into their selector applications,
--    which are appended to the end of the worklist;
--
--  * constraint tuples are split into their components,
--    which are pushed onto the front of the worklist.
--
-- Throws "BadCxt" once the worklist is exhausted.
getPredEvidence :: ModGuts -> PredType -> [CoreExpr] -> ExceptT ExceptionType CoreM CoreExpr
getPredEvidence guts pred evidenceExprs = search (prepEvidence evidenceExprs)
    where

        search :: [(CoreExpr,Type)] -> ExceptT ExceptionType CoreM CoreExpr

        -- All evidence has been examined without finding a match.
        search [] = throwError $ BadCxt $ dbg pred

        -- Each entry pairs an evidence expression (which changes as we descend)
        -- with the baseTy that gave rise to it (which stays constant).
        search ((expr,baseTy):rest)
            | exprTy == pred = return expr
            | otherwise      = case classifyPredType exprTy of
                IrredPred _       -> search rest
                EqPred _ t1 t2    -> viaEquality t1 t2
                -- FIXME: Multiparameter classes broken
                ClassPred cls [_] -> viaClass cls
                ClassPred _ _     -> search rest
                TuplePred preds   -> viaTuple preds
            where
                exprTy = exprType expr

                -- Message for the (currently disabled) trace helpers.
                describe tag = "getPredEvidence.go."++tag++": pred="++dbg pred
                    ++"; origType="++dbg baseTy
                    ++"; exprType="++dbg exprTy

                -- Equality evidence t1 ~ t2: when pred is `tyCon tyApp` and tyApp is
                -- one side of the equality, solve for `tyCon <other side>` instead.
                viaEquality t1 t2 = trace (describe "EP") $ case splitAppTy_maybe pred of
                    Nothing -> trace " A" $ search rest
                    Just (tyCon,tyApp) -> trace " A'" $
                        if t1 /= tyApp && t2 /= tyApp
                            then trace (" B: baseTy="++dbg baseTy++"; tyApp="++dbg tyApp) $ search rest
                            else do
                                let other = if t1 == tyApp then t2 else t1
                                getDictionary guts (mkAppTy tyCon other) >>= castToType evidenceExprs pred

                -- A class dictionary: each selector applied to it may itself be evidence.
                viaClass cls = trace (describe "CP") $ search $
                    rest ++
                    [ ( App (App (Var selId) (Type baseTy)) expr , baseTy )
                    | selId <- classAllSelIds cls
                    ]

                -- A constraint tuple: project every component out with a case
                -- expression and search those components first.
                viaTuple preds = do
                    trace (describe "TP") $ return ()

                    uniqs <- getUniquesM
                    traceM $ " tupelems: baseTy="++dbg baseTy++"; preds="++dbg preds
                    let tupelems =
                            [ mkLocalVar
                                VanillaId
                                (mkSystemName u (mkVarOcc $ "a"++show j))
                                t'
                                vanillaIdInfo
                            | (j,t',u) <- zip3 [0..] preds uniqs
                            ]

                    wildUniq <- getUniqueM
                    let wildVar  = mkLocalVar VanillaId (mkSystemName wildUniq (mkVarOcc "wild")) exprTy vanillaIdInfo
                        tupleAlt = DataAlt $ tupleCon ConstraintTuple $ length preds
                        projections =
                            [ ( Case expr wildVar (varType elemVar) [ (tupleAlt, tupelems, Var elemVar) ]
                              , baseTy
                              )
                            | elemVar <- tupelems
                            ]

                    sequence_
                        [ traceM $ " ret!!"++show i++"="++myshow dynFlags (fst p)
                        | (i,p) <- zip [0 :: Int ..] projections
                        ]

                    search $ projections ++ rest

-- | Given some evidence, an expression, and a type:
-- try to prove that the expression can be cast to the type.
-- If it can, return the cast expression.
castToType :: [CoreExpr] -> Type -> CoreExpr -> ExceptT ExceptionType CoreM CoreExpr
castToType xs castTy inputExpr = if exprType inputExpr == castTy
    then return inputExpr
    else go $ prepEvidence xs
    where

        go :: [(CoreExpr,Type)] -> ExceptT ExceptionType CoreM CoreExpr

        -- base case: we've searched through all the evidence, but couldn't create a cast
        go [] = throwError $ OtherException $
            " WARNING: Could not cast expression of type "++dbg (exprType inputExpr)++" to "++dbg castTy

        -- recursively try each evidence expression looking for a cast
        go ((expr,baseTy):exprs) = case classifyPredType $ exprType expr of

            IrredPred _ -> go exprs

            EqPred _ t1 t2 -> trace ("castToType.go.EP: castTy="++dbg castTy
                                    ++"; origType="++dbg baseTy
                                    ++"; exprType="++dbg (exprType expr)
                                    ) $ goEqPred [] castTy (exprType inputExpr)
                where
                    -- Check if a cast is possible.
                    -- Type constructors are peeled off castTyRHS and inputTyRHS in
                    -- lock-step; as long as they match, the evidence t1 ~ t2 may
                    -- apply at any level of the peeling.
                    -- tyCons accumulates the peeled constructors, most recent
                    -- (i.e. innermost) first.
                    goEqPred :: [TyCon] -> Type -> Type -> ExceptT ExceptionType CoreM CoreExpr
                    goEqPred tyCons castTyRHS inputTyRHS = if
                        | t1==castTyRHS && t2==inputTyRHS -> mkCast True
                        | t2==castTyRHS && t1==inputTyRHS -> mkCast False
                        | otherwise -> case ( splitTyConApp_maybe castTyRHS
                                            , splitTyConApp_maybe inputTyRHS
                                            ) of
                            (Just (castTyCon,[castTyRHS']), Just (inputTyCon,[inputTyRHS'])) ->
                                if castTyCon == inputTyCon
                                    then goEqPred (castTyCon:tyCons) castTyRHS' inputTyRHS'
                                    else go exprs
                            _ -> go exprs
                      where

                        -- Constructs the actual cast from one variable type to another.
                        --
                        -- There's some subtle voodoo in here involving GHC's Roles.
                        -- Basically, everything gets created as a Nominal role,
                        -- but the final Coercion needs to be Representational.
                        -- mkSubCo converts from Nominal into Representational.
                        -- See https://ghc.haskell.org/trac/ghc/wiki/RolesImplementation
                        mkCast :: Bool -> ExceptT ExceptionType CoreM CoreExpr
                        mkCast isFlipped = do
                            coboxUniq <- getUniqueM
                            let coboxName = mkSystemName coboxUniq (mkVarOcc "cobox")
                                coboxType = if isFlipped
                                    then mkCoercionType Nominal castTyRHS inputTyRHS
                                    else mkCoercionType Nominal inputTyRHS castTyRHS
                                coboxVar = mkLocalVar VanillaId coboxName coboxType vanillaIdInfo

                                -- Coercion between the innermost (fully peeled) types.
                                innerCo = if isFlipped
                                    then mkSymCo $ mkCoVarCo coboxVar
                                    else mkCoVarCo coboxVar

                                -- Reapply the peeled tyCons from the inside out.
                                -- The head of tyCons is the innermost constructor,
                                -- so a left fold wraps it first and the outermost last.
                                -- (Wrapping the head outermost gave an ill-typed coercion
                                -- whenever two different constructors had been peeled,
                                -- e.g. Maybe [a] vs Maybe [b].)
                                fullCo = foldl (\co tc -> mkTyConAppCo Nominal tc [co]) innerCo tyCons

                            wildUniq <- getUniqueM
                            let wildName = mkSystemName wildUniq (mkVarOcc "wild")
                                wildType = exprType expr
                                wildVar  = mkLocalVar VanillaId wildName wildType vanillaIdInfo

                            return $ Case
                                expr
                                wildVar
                                castTy
                                [ ( DataAlt eqBoxDataCon
                                  , [coboxVar]
                                  , Cast inputExpr $ mkSubCo fullCo
                                  ) ]

            -- FIXME: ClassPred and TuplePred are both handled the same
            -- within castToType and getPredEvidence.
            -- They should be factored out?
            ClassPred c' [_] -> go $
                exprs++
                [ ( App (App (Var selId) (Type baseTy)) expr
                  , baseTy
                  )
                | selId <- classAllSelIds c'
                ]

            ClassPred _ _ -> go exprs

            TuplePred preds -> do
                uniqs <- getUniquesM
                let tupelems =
                        [ mkLocalVar
                            VanillaId
                            (mkSystemName uniq (mkVarOcc $ "a"++show j))
                            t'
                            vanillaIdInfo
                        | (j,t',uniq) <- zip3 [0..] preds uniqs
                        ]

                uniq <- getUniqueM
                let wildName = mkSystemName uniq (mkVarOcc "wild")
                    wildVar  = mkLocalVar VanillaId wildName (exprType expr) vanillaIdInfo

                -- one projection (case expression) per tuple component
                let ret =
                        [ ( Case expr wildVar (varType tupelem)
                                [ ( DataAlt $ tupleCon ConstraintTuple $ length preds
                                  , tupelems
                                  , Var tupelem
                                  )
                                ]
                          , baseTy
                          )
                        | tupelem <- tupelems
                        ]

                go $ ret++exprs

-- | Each element in the input list must contain evidence of a predicate.
-- The output list contains evidence of a predicate along with a type that will be used for casting.
prepEvidence :: [CoreExpr] -> [(CoreExpr,Type)]
prepEvidence = mapMaybe withBaseTy
    where
        -- Keep a piece of evidence only when its base type can be determined.
        withBaseTy x = fmap (\t -> (x,t)) $ extractBaseTy $ exprType x

        -- The type a piece of evidence is about: the argument of a
        -- single-parameter class constraint, or the non-Bool side of an
        -- equality whose other side is Bool.  Anything else yields Nothing.
        extractBaseTy :: Type -> Maybe Type
        extractBaseTy t = case classifyPredType t of
            ClassPred _ [x] -> Just x
            EqPred _ t1 t2
                | t1 == boolTy -> Just t2
                | t2 == boolTy -> Just t1
            _ -> Nothing

-- | Collect the TyVars reached by descending through
-- type-constructor arguments and type-application arguments.
extractTyVars :: Type -> [TyVar]
extractTyVars t
    | Just x  <- getTyVar_maybe t     = [x]
    | Just xs <- tyConAppArgs_maybe t = concatMap extractTyVars xs
    | otherwise                       = concatMap extractTyVars $ snd $ splitAppTys t

-- | Given a quantified type of the form:
--
-- > forall a. (Num a, Ord a) => a -> a
--
-- The first element of the returned tuple is the list of quantified variables,
-- and the second element is the unquantified type.
extractQuantifiers :: Type -> ([Var],Type)
extractQuantifiers t = maybe ([],t) peel (splitForAllTy_maybe t)
    where
        -- strip one forall and recurse into its body
        peel (a,b) = let (as,b') = extractQuantifiers b in (a:as,b')

-- | Given unquantified types of the form:
--
-- > (Num a, Ord a) => a -> a
--
-- The first element of the returned tuple contains everything to the left of "=>",
-- and the second element contains everything to the right.
extractContext :: Type -> ([Type],Type)
extractContext t = case splitTyConApp_maybe t of
    Nothing -> ([],t)
    Just (tycon,xs) -> if getString tycon /= "(->)" || not hasCxt
        then ([],t)
        else (head xs:cxt',t')
        where
            -- the remainder of the arrow chain after its first argument
            (cxt',t') = extractContext $ head $ tail xs

            -- the first argument belongs to the context unless
            -- classifyPredType deems it irreducible
            hasCxt = case classifyPredType $ head xs of
                IrredPred _ -> False
                _           -> True

-- | given a function, get the type of the parameters
--
-- FIXME: this should be deleted
extractParam :: Type -> Maybe Type
extractParam t = case splitTyConApp_maybe t of
    Nothing -> Nothing
    Just (tycon,xs) -> if getString tycon /= "(->)"
        then Just t -- Nothing
        else Just (head xs)


-- | Given a type of the form
--
-- > A -> ... -> C
--
-- returns C (outer foralls are stripped first)
getReturnType :: Type -> Type
getReturnType t = go $ snd $ splitForAllTys t
    where
        go ty = case splitTyConApp_maybe ty of
            Just (tycon,[_,ty']) -> if getString tycon=="(->)"
                then go ty'
                else ty
            _ -> ty


--------------------------------------------------------------------------------
--

-- | Runs a typechecker action inside CoreM for the given module.
-- Calls `fail` with the collected errors and warnings if the action fails.
runTcM :: ModGuts -> TcM a -> CoreM a
runTcM guts tcm = do
    env <- getHscEnv
    dflags <- getDynFlags
#if __GLASGOW_HASKELL__ < 710 || (__GLASGOW_HASKELL__ == 710 && __GLASGOW_HASKELL_PATCHLEVEL1__ < 2)
    (msgs, mr) <- liftIO $ initTc env HsSrcFile False (mg_module guts) tcm
#else
    -- initTc takes a source span from GHC 7.10.2 on; a dummy span is passed.
    let realSrcSpan = mkRealSrcSpan
            (mkRealSrcLoc (mkFastString "a") 0 1)
            (mkRealSrcLoc (mkFastString "b") 2 3)
    (msgs, mr) <- liftIO $ initTc env HsSrcFile False (mg_module guts) realSrcSpan tcm
#endif
    let showMsgs (warns, errs) = showSDoc dflags $ vcat
            $ text "Errors:" : pprErrMsgBag errs
            ++ text "Warnings:" : pprErrMsgBag warns
    maybe (fail $ showMsgs msgs) return mr
    where
        pprErrMsgBag = pprErrMsgBagWithLoc

--------------------------------------------------------------------------------
-- utils

getString :: NamedThing a => a -> String
getString = occNameString . getOccName

-- | A name-like string for an expression: variables become "name_unique",
-- anything else is pretty-printed with non-alphanumerics replaced by '_'.
expr2str :: DynFlags -> Expr Var -> String
expr2str dflags (Var v) = {-"var_" ++-} var2str v++"_"++showSDoc dflags (ppr $ getUnique v)
expr2str dflags e       = "expr_" ++ map sanitize (showSDoc dflags (ppr e))
    where
        sanitize x = if isAlphaNum x then x else '_'

-- | Converts a numeric Core literal into an exact Rational.
--
-- Non-numeric literals (chars, strings, addresses, labels) have no numeric
-- value; they are rejected with an explicit error instead of an anonymous
-- pattern-match failure.
lit2rational :: Literal -> Rational
lit2rational l = case l of
    MachInt i      -> toRational i
    MachInt64 i    -> toRational i
    MachWord i     -> toRational i
    MachWord64 i   -> toRational i
    MachFloat r    -> r
    MachDouble r   -> r
    LitInteger i _ -> toRational i
    _              -> error "lit2rational: literal is not numeric"

var2str :: Var -> String
var2str = occNameString . occName . varName

maybeHead :: [a] -> Maybe a
maybeHead (a:_) = Just a
maybeHead _     = Nothing

-- | Debug pretty-printer for Core expressions, one constructor per node,
-- with nested nodes indented by 4 spaces per level.
myshow :: DynFlags -> Expr Var -> String
myshow dflags = go 1
    where
        go i (Var v) = "Var "++showSDoc dflags (ppr v)
                     ++"_"++showSDoc dflags (ppr $ getUnique v)
                     ++"::"++showSDoc dflags (ppr $ varType v)
        go i (Lit (MachFloat l  )) = "FloatLiteral " ++show (fromRational l :: Double)
        go i (Lit (MachDouble l )) = "DoubleLiteral " ++show (fromRational l :: Double)
        go i (Lit (MachInt l    )) = "IntLiteral " ++show (fromIntegral l :: Double)
        go i (Lit (MachInt64 l  )) = "Int64Literal " ++show (fromIntegral l :: Double)
        go i (Lit (MachWord l   )) = "WordLiteral " ++show (fromIntegral l :: Double)
        go i (Lit (MachWord64 l )) = "Word64Literal " ++show (fromIntegral l :: Double)
        go i (Lit (LitInteger l t)) = "IntegerLiteral "++show (fromIntegral l :: Double)++
                                      "::"++showSDoc dflags (ppr t)
        go i (Lit l) = "Lit"
        go i (Type t) = "Type "++showSDoc dflags (ppr t)
        go i (Tick a b) = "Tick (" ++ show a ++ ") ("++go (i+1) b++")"
        go i (Coercion l) = "Coercion "++myCoercionShow dflags l
        go i (Cast a b)
            = "Cast \n"
            ++white++"(" ++ go (i+1) a ++ ")\n"
            ++white++"("++myshow dflags (Coercion b)++")\n"
            ++drop 4 white
            where
                white=replicate (4*i) ' '
        go i (Let (NonRec a e) b)
            = "Let "++getString a++"_"++showSDoc dflags (ppr $ getUnique a)
            ++"::"++showSDoc dflags (ppr $ varType a)++"\n"
            ++white++"("++go (i+1) e++")\n"
            ++white++"("++go (i+1) b++")\n"
            ++drop 4 white
            where
                white=replicate (4*i) ' '
        go i (Let _ _) = error "myshow: recursive let"
        go i (Lam a b)
            = "Lam "++getString a++"_"++showSDoc dflags (ppr $ getUnique a)
            ++"::"++showSDoc dflags (ppr $ varType a)
            ++"; coercion="++show (isCoVar a)++"\n"
            ++white++"("++go (i+1) b++")\n"
            ++drop 4 white
            where
                white=replicate (4*i) ' '
        go i (App a b)
            = "App\n"
            ++white++"(" ++ go (i+1) a ++ ")\n"
            ++white++"("++go (i+1) b++")\n"
            ++drop 4 white
            where
                white=replicate (4*i) ' '
        go i (Case a b c d)
            = "Case\n"
            ++white++"("++go (i+1) a++")\n"
            ++white++"("++getString b++"_"++showSDoc dflags (ppr $ getUnique b)
                        ++"::"++showSDoc dflags (ppr $ varType b)++")\n"
            ++white++"("++showSDoc dflags (ppr c)++"; "++show (fmap (myshow dflags . Var) $ getTyVar_maybe c)++")\n"
            ++white++"["++concatMap altShow d++"]\n"
            ++drop 4 white
            where
                white=replicate (4*i) ' '

                altShow :: Alt Var -> String
                altShow (con,xs,expr) = "("++con'++", "++xs'++", "++go (i+1) expr++")\n"++white
                    where
                        con' = case con of
                            DataAlt x -> showSDoc dflags (ppr x)
                            LitAlt x  -> showSDoc dflags (ppr x)
                            DEFAULT   -> "DEFAULT"

                        xs' = show $ map (myshow dflags . Var) xs

-- | Debug printer for coercions; only the constructor name is shown for
-- most cases, with a few cases printing their arguments.
myCoercionShow :: DynFlags -> Coercion -> String
myCoercionShow dflags c = go c
    where
        go (Refl _ _            ) = "Refl"
        go (TyConAppCo a b c    ) = "TyConAppCo "++showSDoc dflags (ppr a)++" "
                                                ++showSDoc dflags (ppr b)++" "
                                                ++showSDoc dflags (ppr c)
        go (AppCo _ _           ) = "AppCo"
        go (ForAllCo _ _        ) = "ForAllCo"
        go (CoVarCo v           ) = "CoVarCo ("++myshow dflags (Var v)++")"
        go (AxiomInstCo _ _ _   ) = "AxiomInstCo"
        go (UnivCo _ _ _ _      ) = "UnivCo"
        go (SymCo c'            ) = "SymCo ("++myCoercionShow dflags c'++")"
        go (TransCo _ _         ) = "TransCo"
        go (AxiomRuleCo _ _ _   ) = "AxiomRuleCo"
        go (NthCo _ _           ) = "NthCo"
        go (LRCo _ _            ) = "LRCo"
        go (InstCo _ _          ) = "InstCo"
        go (SubCo c'            ) = "SubCo ("++myCoercionShow dflags c'++")"


-- instance Show (Coercion) where
--     show _ = "Coercion"
--
-- instance Show b => Show (Bind b) where
--     show _ = "Bind"
--
-- instance Show (Tickish Id) where
--     show _ = "(Tickish Id)"
--
-- instance Show Type where
--     show _ = "Type"
--
-- instance Show AltCon where
--     show _ = "AltCon"
--
-- instance Show Var where
--     show v = getString v


--------------------------------------------------------------------------------
/src/Herbie/ForeignInterface.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE OverloadedStrings, FlexibleContexts #-}

module Herbie.ForeignInterface
    where

import Control.Applicative
import Control.Exception
import Control.DeepSeq
import Data.List
import Data.String
import qualified Data.Text as T
import Database.SQLite.Simple
import Database.SQLite.Simple.FromRow
import Database.SQLite.Simple.FromField
import Database.SQLite.Simple.ToField
import GHC.Generics hiding (modName)
import System.Directory
import System.Process
import System.Timeout

import Paths_HerbiePlugin
import Herbie.MathInfo
import Herbie.MathExpr

import Prelude

-- | Stores the flags that will get passed to the Herbie executable
newtype HerbieOptions = HerbieOptions [[String]]
    deriving (Show,Generic,NFData)

-- | Flattens the flags into one space-separated string
-- (this string is also the `opts` column of the results database).
opts2string :: HerbieOptions -> String
opts2string (HerbieOptions opts) = unwords (concat opts)

-- | Inverse of "opts2string": all words end up in a single flag group.
string2opts :: String -> HerbieOptions
string2opts = HerbieOptions . (:[]) . words

-- | Removes the `-o rules:numerics` flag if present, otherwise appends it.
toggleNumerics :: HerbieOptions -> HerbieOptions
toggleNumerics (HerbieOptions xs) = HerbieOptions $ walk [] xs
    where
        flag = ["-o", "rules:numerics"]

        -- NOTE: the groups scanned before the match are accumulated in reverse,
        -- so toggling reorders flags (and hence the opts2string key).
        walk acc []     = acc ++ [flag]
        walk acc (y:ys)
            | y == flag = acc ++ ys
            | otherwise = walk (y:acc) ys

-- | Removes the `-o reduce:regimes` flag if present, otherwise appends it.
toggleRegimes :: HerbieOptions -> HerbieOptions
toggleRegimes (HerbieOptions xs) = HerbieOptions $ walk [] xs
    where
        flag = ["-o", "reduce:regimes"]

        -- NOTE: same reordering behaviour as in toggleNumerics.
        walk acc []     = acc ++ [flag]
        walk acc (y:ys)
            | y == flag = acc ++ ys
            | otherwise = walk (y:acc) ys

-- | This information gets stored in a separate db table for debugging purposes
data DbgInfo = DbgInfo
    { dbgComments   :: String
    , modName       :: String
    , functionName  :: String
    , functionType  :: String
    }

-- | The result of running Herbie
data HerbieResult a = HerbieResult
    { cmdin  :: !a
    , cmdout :: !a
    , opts   :: !HerbieOptions
    , errin  :: !Double
    , errout :: !Double
    }
    deriving (Show,Generic,NFData)

-- Column order: cmdin, cmdout, opts (as a string), errin, errout.
instance FromField a => FromRow (HerbieResult a) where
    fromRow = HerbieResult <$> field <*> field <*> (string2opts <$> field) <*> field <*> field

instance ToField a => ToRow (HerbieResult a) where
    toRow r = toRow
        ( cmdin r
        , cmdout r
        , opts2string (opts r)
        , errin r
        , errout r
        )

-- | Given a MathExpr, return a numerically stable version.
-- If Herbie's answer cannot be converted back, the input is kept as the output.
stabilizeMathExpr :: DbgInfo -> HerbieOptions -> MathExpr -> IO (HerbieResult MathExpr)
stabilizeMathExpr dbgInfo opts cmdin = do
    let (cmdinLisp,varmap) = getCanonicalLispCmd $ haskellOpsToHerbieOps cmdin
    res <- stabilizeLisp dbgInfo opts cmdinLisp

    -- FIXME:
    -- Due to a bug in Herbie, fromCanonicalLispCmd sometimes throws an exception.
    -- The result is forced completely inside `try` so that the exception is caught here.
    converted <- try $ evaluate $ force $
        herbieOpsToHaskellOps $ fromCanonicalLispCmd (cmdout res,varmap)
    cmdout' <- case converted of
        Left (SomeException e) -> do
            putStrLn $ "WARNING in stabilizeMathExpr: "++show e
            return cmdin
        Right x -> return x

    -- putStrLn $ "  cmdin:   "++cmdinLisp
    -- putStrLn $ "  cmdout:  "++cmdout res
    -- putStrLn $ "  stabilizeLisp': "++mathExpr2lisp (fromCanonicalLispCmd (cmdout res,varmap))
    return res { cmdin = cmdin, cmdout = cmdout' }

-- | Given a Lisp command, return a numerically stable version.
-- It first checks if the command is in the global database;
-- if it's not, then it runs "execHerbie" and stores the answer.
stabilizeLisp :: DbgInfo -> HerbieOptions -> String -> IO (HerbieResult String)
stabilizeLisp dbgInfo opts cmd = do
    ret <- lookupDatabase opts cmd >>= maybe runHerbie return
    insertDatabaseDbgInfo dbgInfo ret

    -- FIXME:
    -- Herbie has a bug where it sometimes outputs a less numerically stable version.
    -- So we need to check to make sure we return the more stable output.
    return $ if errin ret > errout ret
        then ret
        else ret { errout = errin ret, cmdout = cmdin ret }

    where
        -- cache miss: run Herbie and remember the answer
        runHerbie = do
            putStrLn " Not found in database. Running Herbie..."
            res <- execHerbie opts cmd
            insertDatabase res
            return res

-- | Run the `herbie-exec` command and return the result.
-- Parse failures and timeouts yield a result with NaN errors
-- whose output equals the input.
execHerbie :: HerbieOptions -> String -> IO (HerbieResult String)
execHerbie (HerbieOptions opts) lisp = do

    -- build the command string we will pass to Herbie
    let varstr = "("++unwords (lisp2vars lisp)++")"
        stdin  = "(herbie-test "++varstr++" \"cmd\" "++lisp++") \n"

    -- Herbie can take a long time to run; limit it to 2 minutes
    -- (the timeout is given in microseconds).
    --
    -- FIXME:
    -- This should be a parameter the user can pass to the plugin
    outcome <- timeout 120000000 $ do

        -- the flags include a fixed seed to ensure reproducible builds
        (_,stdout,_) <- readProcessWithExitCode "herbie-exec" (concat opts) stdin

        -- if Herbie's output can't be parsed, Herbie had an error and we abort gracefully
        parsed <- try $ evaluate $ force $ parseOutput stdout
        case parsed of
            Left (SomeException _) -> do
                putStrLn $ " WARNING: error when calling `herbie-exec`"
                putStrLn $ " stdin="++stdin
                putStrLn $ " stdout="++stdout
                return failedResult
            Right x -> return x

    case outcome of
        Just x  -> return x
        Nothing -> do
            putStrLn $ "WARNING: Call to Herbie timed out after 2 minutes."
            return failedResult

    where
        -- what we report when Herbie gives no usable answer
        failedResult = HerbieResult
            { errin  = 0/0
            , errout = 0/0
            , opts   = HerbieOptions opts
            , cmdin  = lisp
            , cmdout = lisp
            }

        -- The first two lines carry the input and output error after a ':',
        -- the third line holds the parenthesized test whose third group is the output.
        parseOutput out = HerbieResult
            { errin  = read $ drop 1 $ dropWhile (/=':') line1
            , errout = read $ drop 1 $ dropWhile (/=':') line2
            , opts   = HerbieOptions opts
            , cmdin  = lisp
            , cmdout = (!!2) $ groupByParens $ init $ tail line3
            }
            where
                (line1:line2:line3:_) = lines out


-- | Returns a connection to the sqlite3 database
mkConn :: IO Connection
mkConn = getDataFileName "Herbie.db" >>= open

-- | Check the database to see if we already know the answer for running Herbie
--
-- FIXME:
-- When Herbie times out, NULL gets inserted into the database for errin and errout.
-- The Sqlite3 bindings don't support putting NULL into Doubles as NaNs,
-- so the query below raises an exception.
-- This isn't so bad, except a nasty error message gets printed,
-- and the plugin attempts to run Herbie again (wasting a lot of time).
--
-- Returns @Nothing@ when no row matches @(cmdin, opts)@ or when the lookup fails;
-- failures are reported on stdout as a warning.
lookupDatabase :: HerbieOptions -> String -> IO (Maybe (HerbieResult String))
lookupDatabase opts cmdin = do
    ret <- withHerbieDbConn $ \conn ->
        queryNamed
            conn
            "SELECT cmdin,cmdout,opts,errin,errout from HerbieResults where cmdin = :cmdin and opts = :opts"
            [":cmdin" := cmdin, ":opts" := opts2string opts]
            :: IO [HerbieResult String]
    case ret of
        Left (SomeException e) -> do
            putStrLn $ " WARNING in lookupDatabase: "++show e
            return Nothing
        -- (cmdin, opts) is declared UNIQUE, so at most one row is expected;
        -- take the first one instead of failing a pattern match
        Right (x:_) -> return (Just x)
        Right [] -> return Nothing

-- | Inserts a "HerbieResult" into the global database of commands,
-- creating the HerbieResults table and its index if they do not exist yet.
--
-- Failures (including a UNIQUE violation when the (cmdin, opts) pair is already
-- stored) are reported on stdout as a warning and otherwise ignored.
insertDatabase :: HerbieResult String -> IO ()
insertDatabase res = do
    ret <- withHerbieDbConn $ \conn -> do
        execute_ conn $ fromString $
            "CREATE TABLE IF NOT EXISTS HerbieResults "
            ++"( id INTEGER PRIMARY KEY"
            ++", cmdin TEXT NOT NULL"
            ++", cmdout TEXT NOT NULL"
            ++", opts TEXT NOT NULL"
            ++", errin DOUBLE "
            ++", errout DOUBLE "
            ++", UNIQUE (cmdin, opts)"
            ++")"
        execute_ conn "CREATE INDEX IF NOT EXISTS HerbieResultsIndex ON HerbieResults(cmdin)"
        execute conn "INSERT INTO HerbieResults (cmdin,cmdout,opts,errin,errout) VALUES (?,?,?,?,?)" res
    case ret of
        Left (SomeException e) -> putStrLn $ " WARNING in insertDatabase: "++show e
        Right _ -> return ()

-- | Records debugging information for a "HerbieResult" that is already stored.
--
-- Creates the DbgInfo table if needed, looks up the id of the HerbieResults row
-- with the same cmdin and opts as @res@, and inserts @dbgInfo@ linked to that id.
-- Failures, or a missing HerbieResults row, are reported on stdout as a warning.
insertDatabaseDbgInfo :: DbgInfo -> HerbieResult String -> IO ()
insertDatabaseDbgInfo dbgInfo res = do
    ret <- withHerbieDbConn $ \conn -> do
        execute_ conn $ fromString $
            "CREATE TABLE IF NOT EXISTS DbgInfo "
            ++"( id INTEGER PRIMARY KEY"
            ++", resid INTEGER NOT NULL"
            ++", dbgComments TEXT"
            ++", modName TEXT"
            ++", functionName TEXT"
            ++", functionType TEXT"
            ++")"
        -- Match on opts as well as cmdin: the same cmdin may be stored once per set
        -- of options. Assumes the opts column holds opts2string, as lookupDatabase does.
        rows <- queryNamed
            conn
            "SELECT id,cmdout from HerbieResults where cmdin = :cmdin and opts = :opts"
            [":cmdin" := cmdin res, ":opts" := opts2string (opts res)]
            :: IO [(Int,String)]
        case rows of
            (resid,_):_ -> execute conn "INSERT INTO DbgInfo (resid,dbgComments,modName,functionName,functionType) VALUES (?,?,?,?,?)" (resid,dbgComments dbgInfo,modName dbgInfo,functionName dbgInfo,functionType dbgInfo)
            [] -> putStrLn $ "WARNING in insertDatabaseDbgInfo: no HerbieResults row for cmdin="++cmdin res
    case ret of
        Left (SomeException e) -> putStrLn $ "WARNING in insertDatabaseDbgInfo: "++show e
        Right _ -> return ()

-- | Opens a connection with 'mkConn', runs the action on it, and always closes the
-- connection afterwards, even when the action throws.
--
-- Any exception raised while opening, running, or closing is returned as 'Left';
-- when the action itself fails, its exception is the one returned.
withHerbieDbConn act = do
    opened <- try mkConn
    case opened of
        Left err -> return (Left (err :: SomeException))
        Right conn -> do
            result <- try (act conn)
            closed <- try (close conn)
            return $ case (result, closed) of
                (Left err, _) -> Left err
                (Right _, Left err) -> Left err
                (Right x, Right ()) -> Right x

--------------------------------------------------------------------------------
/src/Herbie/MathExpr.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE DeriveAnyClass,DeriveGeneric #-}
module Herbie.MathExpr
    where

import Control.DeepSeq
import Data.List
import Data.List.Split
import Data.Maybe
import GHC.Generics

import Debug.Trace
import Prelude

-- | if-then-else as an ordinary function
-- (presumably for modules compiled with RebindableSyntax -- verify against callers)
ifThenElse True t f = t
ifThenElse False t f = f

-------------------------------------------------------------------------------
-- constants that define valid math expressions

-- | Names of the unary math functions recognized in Haskell code
monOpList =
    [ "cos"
    , "sin"
    , "tan"
    , "acos"
    , "asin"
    , "atan"
    , "cosh"
    , "sinh"
    , "tanh"
    , "exp"
    , "log"
    , "sqrt"
    , "abs"
    , "size"
    , "negate"
    ]

-- | Names of the binary math operators recognized in Haskell code
binOpList = [ "^", "**", "^^", "/", "-", "expt" ] ++ commutativeOpList

-- | Binary operators whose operands may be reordered during canonicalization
commutativeOpList = [ "*", "+"] -- , "max", "min" ]

-- | Operators that are printed in prefix form by the pretty printer
fancyOps = [ "hypot", "log1p", "expm1" ]

--------------------------------------------------------------------------------

-- | Stores the AST for a math expression in a generic form that requires no knowledge of Core syntax.
data MathExpr
    = EBinOp String MathExpr MathExpr
    | EMonOp String MathExpr
    | EIf MathExpr MathExpr MathExpr
    | ELit Rational
    | ELeaf String
    deriving (Show,Eq,Generic,NFData)

-- | Structural ordering used to put the operands of commutative operators into a
-- canonical order (see 'toCanonicalMathExpr').
--
-- Constructors are ordered ELeaf < ELit < EMonOp < EBinOp < EIf; nodes built with
-- the same constructor are compared field by field, left to right.
-- All leaves compare EQ regardless of their names, so the ordering does not depend
-- on variable names (which canonicalization rewrites).
-- NOTE: this makes 'compare' coarser than the derived '=='.
instance Ord MathExpr where
    compare x y = case compare (ctorRank x) (ctorRank y) of
        EQ -> sameCtor x y
        r -> r
      where
        ctorRank :: MathExpr -> Int
        ctorRank (ELeaf _) = 0
        ctorRank (ELit _) = 1
        ctorRank (EMonOp _ _) = 2
        ctorRank (EBinOp _ _ _) = 3
        ctorRank (EIf _ _ _) = 4

        -- lexicographic combination of two comparisons
        thenCmp :: Ordering -> Ordering -> Ordering
        thenCmp EQ next = next
        thenCmp r _ = r

        sameCtor (ELit r1) (ELit r2) = compare r1 r2
        sameCtor (EMonOp op1 a1) (EMonOp op2 a2) =
            compare op1 op2 `thenCmp` compare a1 a2
        sameCtor (EBinOp op1 a1 b1) (EBinOp op2 a2 b2) =
            compare op1 op2 `thenCmp` compare a1 a2 `thenCmp` compare b1 b2
        sameCtor (EIf c1 a1 b1) (EIf c2 a2 b2) =
            compare c1 c2 `thenCmp` compare a1 a2 `thenCmp` compare b1 b2
        -- only two leaves reach this case: leaf names are deliberately ignored
        sameCtor _ _ = EQ

-- | Converts all Haskell operators in the MathExpr into Herbie operators
haskellOpsToHerbieOps :: MathExpr -> MathExpr
haskellOpsToHerbieOps = go
    where
        go (EBinOp op e1 e2) = EBinOp op' (go e1) (go e2)
            where
                op' = case op of
                    "**" -> "expt"
                    "^^" -> "expt"
                    "^" -> "expt"
                    x -> x

        go (EMonOp op e1) = EMonOp op' (go e1)
            where
                op' = case op of
                    "size" -> "abs"
                    x -> x

        go (EIf cond e1 e2) = EIf (go cond) (go e1) (go e2)
        go x = x

-- | Converts all Herbie operators in the MathExpr into Haskell operators
herbieOpsToHaskellOps :: MathExpr -> MathExpr
herbieOpsToHaskellOps = go
    where
        go (EBinOp op e1 e2) = EBinOp op' (go e1) (go e2)
            where
                op' = case op of
                    "^" -> "**"
                    "expt" -> "**"
                    x -> x

        go (EMonOp "sqr" e1) = EBinOp "*" (go e1) (go e1)
        go (EMonOp op e1) = EMonOp op' (go e1)
            where
                op' = case op of
                    "-" -> "negate"
                    "abs" -> "size"
                    x -> x

        go (EIf cond e1 e2) = EIf (go cond) (go e1) (go e2)
        go x = x

-- | Replace all the variables in the MathExpr with canonical names (herbie0,herbie1,herbie2...)
-- and reorder commutative binary operations.
-- This lets us more easily compare MathExpr's based on their structure.
-- The returned map holds (original name, canonical name) pairs and lets us convert
-- the canonical MathExpr back into the original.
toCanonicalMathExpr :: MathExpr -> (MathExpr,[(String,String)])
toCanonicalMathExpr e = go [] e
    where
        go :: [(String,String)] -> MathExpr -> (MathExpr,[(String,String)])
        go acc (EBinOp op e1 e2) = (EBinOp op e1' e2',acc2')
            where
                (e1_,e2_) = if op `elem` commutativeOpList
                    then (min e1 e2,max e1 e2)
                    else (e1,e2)

                (e1',acc1') = go acc e1_
                (e2',acc2') = go acc1' e2_

        go acc (EMonOp op e1) = (EMonOp op e1', acc1')
            where
                (e1',acc1') = go acc e1

        -- names are assigned in traversal order: condition, then branch, then else branch
        go acc (EIf cond e1 e2) = (EIf cond' e1' e2', acc3')
            where
                (cond',acc1') = go acc cond
                (e1',acc2') = go acc1' e1
                (e2',acc3') = go acc2' e2

        go acc (ELit r) = (ELit r,acc)
        go acc (ELeaf str) = (ELeaf str',acc')
            where
                (acc',str') = case lookup str acc of
                    Nothing -> ((str,"herbie"++show (length acc)):acc, "herbie"++show (length acc))
                    Just x -> (acc,x)

-- | Convert a canonical MathExpr into its original form.
--
-- FIXME:
-- A bug in Herbie causes it to sometimes output infinities,
-- which break this function and cause it to error.
fromCanonicalMathExpr :: (MathExpr,[(String,String)]) -> MathExpr
fromCanonicalMathExpr (e,xs) = go e
    where
        -- map canonical names back to the original names
        xs' = map (\(a,b) -> (b,a)) xs

        go (EMonOp op e1) = EMonOp op (go e1)
        go (EBinOp op e1 e2) = EBinOp op (go e1) (go e2)
        -- FIXME: workaround for the infinity bug described above.
        -- @x < -inf.0@ is always false, so only the else branch can ever be taken.
        -- lisp2mathExpr parses the token "-inf.0" as (negate inf.0), so both spellings are matched.
        go (EIf (EBinOp "<" _ limit) _ e2) | isNegInf limit = go e2
        go (EIf cond e1 e2) = EIf (go cond) (go e1) (go e2)
        go (ELit r) = ELit r
        go (ELeaf str) = case lookup str xs' of
            Just x -> ELeaf x
            Nothing -> error $ "fromCanonicalMathExpr: str="++str++"; xs="++show xs'

        isNegInf (ELeaf "-inf.0") = True
        isNegInf (EMonOp "negate" (ELeaf "inf.0")) = True
        isNegInf _ = False

-- | Calculates the maximum depth of the AST.
-- Leaves and literals have depth 0; every operator or if-expression adds 1.
mathExprDepth :: MathExpr -> Int
mathExprDepth (EBinOp _ e1 e2) = 1+max (mathExprDepth e1) (mathExprDepth e2)
mathExprDepth (EMonOp _ e1 ) = 1+mathExprDepth e1
mathExprDepth (EIf cond e1 e2) = 1+maximum [mathExprDepth cond, mathExprDepth e1, mathExprDepth e2]
mathExprDepth _ = 0

--------------------------------------------------------------------------------
-- functions for manipulating math expressions in lisp form

-- | Canonicalizes a MathExpr (see 'toCanonicalMathExpr') and renders it as a lisp command.
-- Also returns the (original name, canonical name) map needed to undo the renaming.
getCanonicalLispCmd :: MathExpr -> (String,[(String,String)])
getCanonicalLispCmd me = (mathExpr2lisp me',varmap)
    where
        (me',varmap) = toCanonicalMathExpr me

-- | Inverse of 'getCanonicalLispCmd': parses the lisp command and restores the original names.
fromCanonicalLispCmd :: (String,[(String,String)]) -> MathExpr
fromCanonicalLispCmd (lisp,varmap) = fromCanonicalMathExpr (lisp2mathExpr lisp,varmap)

-- | Converts MathExpr into a lisp command suitable for passing to Herbie.
-- Integral literals are printed as integers, all other literals as Doubles.
mathExpr2lisp :: MathExpr -> String
mathExpr2lisp = go
    where
        go (EBinOp op a1 a2) = "("++op++" "++go a1++" "++go a2++")"
        go (EMonOp "negate" a) = "(- "++go a++")"
        go (EMonOp op a) = "("++op++" "++go a++")"
        go (EIf cond e1 e2) = "(if "++go cond++" "++go e1++" "++go e2++")"
        go (ELeaf e) = e
        go (ELit r) = if (toRational (floor r::Integer) == r)
            then show (floor r :: Integer)
            else show (fromRational r :: Double)

-- | Converts a lisp command into a MathExpr
lisp2mathExpr :: String -> MathExpr
lisp2mathExpr ('-':xs) = EMonOp "negate" (lisp2mathExpr xs)
lisp2mathExpr ('(':xs) = if length xs > 1 && last xs==')'
    then case groupByParens $ init xs of
        [op,e1] -> EMonOp op (lisp2mathExpr e1)
        [op,e1,e2] -> EBinOp op (lisp2mathExpr e1) (lisp2mathExpr e2)
        ["if",cond,e1,e2] -> EIf (lisp2mathExpr cond) (lisp2mathExpr e1) (lisp2mathExpr e2)
        _ -> error $ "lisp2mathExpr: "++xs
    else error $ "lisp2mathExpr: malformed input '("++xs++"'"
lisp2mathExpr xs = case splitOn "/" xs of
    [num,den] -> lisp2mathExpr $ "(/ "++num++" "++den++")"
    _ -> case readMaybe xs :: Maybe Double of
        Just x -> ELit $ toRational x
        Nothing -> ELeaf xs

-- | Extracts all the variables from the lisp commands with no duplicates.
lisp2vars :: String -> [String]
lisp2vars = nub . lisp2varsNoNub

-- | Extracts all the variables from the lisp commands, in sorted order.
-- Each variable occurs once in the output for each time it occurs in the input.
--
-- Parentheses, known operators, lisp keywords, and numeric literals
-- (including negative ones such as "-1.5") are not variables.
lisp2varsNoNub :: String -> [String]
lisp2varsNoNub lisp
    = sort
    $ filter isVariable
    $ tokenize lisp
    where
        isVariable :: String -> Bool
        isVariable x = x/="("
            && x/=")"
            && (x `notElem` binOpList)
            && (x `notElem` monOpList)
            && (x `notElem` nonVariableKeywords)
            && (head x `notElem` ("1234567890"::String))
            && not (isJust (readMaybe x :: Maybe Double))

        -- keywords and operators that can occur in Herbie's lisp output
        -- but appear in neither binOpList nor monOpList
        nonVariableKeywords :: [String]
        nonVariableKeywords = ["if","sqr","<",">","<=",">="] ++ fancyOps

        -- We just need to add spaces around the parens before calling "words"
        tokenize :: String -> [String]
        tokenize = words . concatMap pad
            where
                pad '(' = " ( "
                pad ')' = " ) "
                pad c = [c]

-- | True when some variable occurs more than once in the lisp command.
lispHasRepeatVars :: String -> Bool
lispHasRepeatVars lisp = length (lisp2vars lisp) /= length (lisp2varsNoNub lisp)

-------------------------------------------------------------------------------
-- utilities

-- | Parses a value from the start of the string, ignoring any unparsed remainder
-- (unlike Text.Read.readMaybe, which requires the whole string to parse).
readMaybe :: Read a => String -> Maybe a
readMaybe = fmap fst . listToMaybe . reads

-- | Given an expression, break it into tokens only outside parentheses.
--
-- Spaces at nesting depth 0 separate tokens; parenthesized groups are kept whole
-- (including their parentheses). A ')' at depth 0 ends the current token and is
-- dropped. The final token is always emitted, even when empty, so
-- @groupByParens ""@ is @[""]@.
--
-- Tokens and characters are accumulated in reverse and flipped once, keeping
-- this linear in the input length.
groupByParens :: String -> [String]
groupByParens str = go (0 :: Int) str [] []
    where
        -- move the (reversed) current token onto the (reversed) token list
        flush acc ret = reverse acc : ret

        go 0 (' ':xs) [] ret = go 0 xs [] ret
        go 0 (' ':xs) acc ret = go 0 xs [] (flush acc ret)
        go 0 (')':xs) acc ret = go 0 xs [] (flush acc ret)
        go i (')':xs) acc ret = go (i-1) xs (')':acc) ret
        go i ('(':xs) acc ret = go (i+1) xs ('(':acc) ret
        go i (x :xs) acc ret = go i xs (x:acc) ret
        go _ [] acc ret = reverse (flush acc ret)


--------------------------------------------------------------------------------
/src/Herbie/MathInfo.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE FlexibleInstances,FlexibleContexts,MultiWayIf,CPP #-}
module Herbie.MathInfo
    where

import Class
import DsBinds
import DsMonad
import ErrUtils
import GhcPlugins hiding (trace)
import Unique
import MkId
import PrelNames
import UniqSupply
import TcRnMonad
import TcSimplify
import Type

import Control.Monad
import Control.Monad.Except
import Control.Monad.Trans
import Data.Char
import Data.List
import Data.List.Split
import Data.Maybe
import Data.Ratio

import Herbie.CoreManip
import Herbie.MathExpr

import Prelude
import Show

-- Tracing is disabled by these no-op stubs; swap in the Debug.Trace import below to enable it.
-- import Debug.Trace hiding (traceM)
trace a b = b
traceM a = return ()

--------------------------------------------------------------------------------

-- | The fields of this type correspond to the sections of a function type.
--
-- Must satisfy the invariant that every class in "getCxt" has an associated dictionary in "getDicts".
data ParamType = ParamType
    { getQuantifier :: [Var]
    , getCxt :: [Type]
    , getDicts :: [CoreExpr]
    , getParam :: Type
    }

-- | This type is a simplified version of the CoreExpr type.
-- It only supports math expressions.
-- We first convert a CoreExpr into a MathInfo,
-- perform all the manipulation on the MathExpr within the MathInfo,
-- then use the information in MathInfo to convert the MathExpr back into a CoreExpr.
data MathInfo = MathInfo
    { getMathExpr :: MathExpr
    , getParamType :: ParamType
    , getExprs :: [(String,Expr Var)]
    -- ^ the fst value is the unique name assigned to non-mathematical expressions
    -- the snd value is the expression
    }

-- | Pretty print a math expression.
--
-- Leaves registered in 'getExprs' are shown via 'pprVariable' (for variables) or
-- 'pprExpr' (for other expressions); unregistered leaves are shown by their raw name.
pprMathInfo :: MathInfo -> String
pprMathInfo mathInfo = go 1 False $ getMathExpr mathInfo
    where
        isLitOrLeaf :: MathExpr -> Bool
        isLitOrLeaf (ELit _ ) = True
        isLitOrLeaf (ELeaf _) = True
        isLitOrLeaf _ = False

        -- i is the nesting level of if-expressions (controls indentation);
        -- b says whether a non-atomic expression must be wrapped in parentheses
        go :: Int -> Bool -> MathExpr -> String
        go i b e = if b && not (isLitOrLeaf e)
            then "("++str++")"
            else str
            where
                str = case e of
                    -- parenthesize compound operands so that negate (a+b) is not shown as -a + b
                    EMonOp "negate" e1 -> "-"++go i True e1
                    EMonOp op e1 -> op++" "++go i True e1

                    EBinOp op e1 e2 -> if op `elem` fancyOps
                        then op++" "++go i True e1++" "++go i True e2
                        else go i parens1 e1++" "++op++" "++go i parens2 e2
                        where
                            parens1 = case e1 of
                                -- (EBinOp op' _ _) -> op/=op'
                                (EMonOp _ _) -> False
                                _ -> True

                            parens2 = case e2 of
                                -- (EBinOp op' _ _) -> op/=op' || not (op `elem` commutativeOpList)
                                (EMonOp _ _) -> False
                                _ -> True

                    ELit l -> if toRational (floor l) == l
                        then if length (show (floor l :: Integer)) < 10
                            then show (floor l :: Integer)
                            else show (fromRational l :: Double)
                        else show (fromRational l :: Double)

                    ELeaf l -> case lookup l $ getExprs mathInfo of
                        Just (Var _) -> pprVariable l mathInfo
                        Just _ -> pprExpr l mathInfo
                        Nothing -> l

                    EIf cond e1 e2 -> "if "++go i False cond++"\n"
                        ++white++"then "++go (i+1) False e1++"\n"
                        ++white++"else "++go (i+1) False e2
                        where
                            white = replicate (4*i) ' '

-- | If there is no ambiguity, the variable is displayed without the unique.
-- Otherwise, it is returned with the unique
pprVariable :: String -> MathInfo -> String
pprVariable var mathInfo = if length (filter (==pprvar) pprvars)
                            > length (filter (==var) $ map fst $ getExprs mathInfo)
    then var
    else pprvar
    where
        pprvar = ppr var
        pprvars = map ppr $ map fst $ getExprs mathInfo

        -- drop the trailing "_<unique>" component; names without an underscore are kept intact
        ppr v = case splitOn "_" v of
            [_] -> v
            parts -> intercalate "_" (init parts)

-- | The names of expressions are long and awkward.
-- This gives us a display-friendly version: "?" followed by the expression's
-- position among the non-variable entries of 'getExprs'.
-- Names that are not such an entry are returned unchanged.
pprExpr :: String -> MathInfo -> String
pprExpr var mathInfo = case findIndex (==var) notvars of
    Just index -> "?"++show index
    Nothing -> var
    where
        notvars
            = map fst
            $ filter (not . isVarExpr . snd)
            $ getExprs mathInfo

        isVarExpr (Var _) = True
        isVarExpr _ = False

-- | If the given expression is a math expression,
-- returns the type of the variable that the math expression operates on.
varTypeIfValidExpr :: CoreExpr -> Maybe Type
varTypeIfValidExpr e = case e of
    -- might be a binary math operation
    App (App (App (App (Var v) (Type t)) _) _) _ -> checkOp binOpList v t

    -- might be a unary math operation
    App (App (App (Var v) (Type t)) _) _ -> checkOp monOpList v t

    -- first function is anything else means that we're not a math expression
    _ -> Nothing
  where
    -- the head function must be one of the given operators, applied at a valid type
    checkOp ops v t = if var2str v `elem` ops && isValidType t
        then Just t
        else Nothing

    isValidType :: Type -> Bool
    isValidType t = isTyVarTy t || case splitTyConApp_maybe t of
        Nothing -> True
        Just (tyCon,_) -> tyCon == floatTyCon || tyCon == doubleTyCon

-- | Converts a CoreExpr into a MathInfo.
--
-- Returns Nothing unless the expression is a math expression (see
-- 'varTypeIfValidExpr') that is deeper than one operator or repeats a variable.
mkMathInfo :: DynFlags -> [Var] -> Type -> Expr Var -> Maybe MathInfo
mkMathInfo dflags dicts bndType e = case varTypeIfValidExpr e of
    Nothing -> Nothing
    Just t -> if mathExprDepth getMathExpr>1 || lispHasRepeatVars (mathExpr2lisp getMathExpr)
        then Just $ MathInfo
            getMathExpr
            ParamType
                { getQuantifier = quantifier
                , getCxt = cxt
                , getDicts = map Var dicts
                , getParam = t
                }
            exprs
        else Nothing
  where
    (getMathExpr,exprs) = go e []

    -- this should never return Nothing if validExpr is not Nothing
    (quantifier,unquantified) = extractQuantifiers bndType
    (cxt,uncxt) = extractContext unquantified

    -- recursively converts the `Expr Var` into a MathExpr and a dictionary
    go :: Expr Var
       -> [(String,Expr Var)]
       -> (MathExpr
          ,[(String,Expr Var)]
          )

    -- we need to special case the $ operator for when MathExpr is run before any rewrite rules
    go e@(App (App (App (App (Var v) (Type _)) (Type _)) a1) a2) exprs
        = if var2str v == "$"
            then go (App a1 a2) exprs
            else (ELeaf $ expr2str dflags e,[(expr2str dflags e,e)])

    -- polymorphic literals created via fromInteger
    go e@(App (App (App (Var v) (Type _)) dict) (Lit l)) exprs
        = (ELit $ lit2rational l, exprs)

    -- polymorphic literals created via fromRational
    go e@(App (App (App (Var v) (Type _)) dict)
            (App (App (App (Var _) (Type _)) (Lit l1)) (Lit l2))) exprs
        = (ELit $ lit2rational l1 / lit2rational l2, exprs)

    -- non-polymorphic literals
    go e@(App (Var _) (Lit l)) exprs
        = (ELit $ lit2rational l, exprs)

    -- binary operators
    go e@(App (App (App (App (Var v) (Type _)) dict) a1) a2) exprs
        = if var2str v `elem` binOpList
            then let (a1',exprs1) = go a1 []
                     (a2',exprs2) = go a2 []
                 in ( EBinOp (var2str v) a1' a2'
                    , exprs++exprs1++exprs2
                    )
            else (ELeaf $ expr2str dflags e,[(expr2str dflags e,e)])

    -- unary operators
    go e@(App (App (App (Var v) (Type _)) dict) a) exprs
        = if var2str v `elem` monOpList
            then let (a',exprs') = go a []
                 in ( EMonOp (var2str v) a'
                    , exprs++exprs'
                    )
            else (ELeaf $ expr2str dflags e,(expr2str dflags e,e):exprs)

    -- everything else
    go e exprs = (ELeaf $ expr2str dflags e,[(expr2str dflags e,e)])

-- | Converts a MathInfo back into a CoreExpr.
-- Operators are looked up with 'getDecoratedFunction' at the parameter type and
-- applied to the dictionaries recorded in the MathInfo's ParamType.
mathInfo2expr :: ModGuts -> MathInfo -> ExceptT ExceptionType CoreM CoreExpr
mathInfo2expr guts herbie = go (getMathExpr herbie)
  where
    pt = getParamType herbie

    -- binary operators
    go (EBinOp opstr a1 a2) = do
        a1' <- go a1
        a2' <- go a2
        f <- getDecoratedFunction guts opstr (getParam pt) (getDicts pt)
        return $ App (App f a1') a2'

    -- unary operators
    go (EMonOp opstr a) = do
        a' <- go a
        f <- getDecoratedFunction guts opstr (getParam pt) (getDicts pt)
        castToType
            (getDicts pt)
            (getParam pt)
            $ App f a'

    -- if statements become a case on the Bool condition
    go (EIf cond a1 a2) = do
        cond' <- go cond >>= castToType (getDicts pt) boolTy
        a1' <- go a1
        a2' <- go a2

        wildUniq <- getUniqueM
        let wildName = mkSystemName wildUniq (mkVarOcc "wild")
            wildVar = mkLocalVar VanillaId wildName boolTy vanillaIdInfo

        return $ Case
            cond'
            wildVar
            (getParam pt)
            [ (DataAlt falseDataCon, [], a2')
            , (DataAlt trueDataCon, [], a1')
            ]

    -- leaf is a numeric literal: build fromRational (numerator :% denominator)
    go (ELit r) = do
        fromRationalExpr <- getDecoratedFunction guts "fromRational" (getParam pt) (getDicts pt)

        integerTyCon <- lookupTyCon integerTyConName
        let integerTy = mkTyConTy integerTyCon

        ratioTyCon <- lookupTyCon ratioTyConName
        tmpUniq <- getUniqueM
        let tmpName = mkSystemName tmpUniq (mkVarOcc "a")
            tmpVar = mkTyVar tmpName liftedTypeKind
            tmpVarT = mkTyVarTy tmpVar
            ratioConTy = mkForAllTy tmpVar $ mkFunTys [tmpVarT,tmpVarT] $ mkAppTy (mkTyConTy ratioTyCon) tmpVarT
            ratioConVar = mkGlobalVar VanillaId ratioDataConName ratioConTy vanillaIdInfo

        return $ App
            fromRationalExpr
            (App
                (App
                    (App
                        (Var ratioConVar )
                        (Type integerTy)
                    )
                    (Lit $ LitInteger (numerator r) integerTy)
                )
                (Lit $ LitInteger (denominator r) integerTy)
            )

    -- leaf is any other expression
    go (ELeaf str) = return $ case lookup str (getExprs herbie) of
        Just x -> x
        Nothing -> error $ "mathInfo2expr: var " ++ str ++ " not in scope"
            ++"; in scope vars="++show (nub $ map fst $ getExprs herbie)

--------------------------------------------------------------------------------
/src/Herbie/Options.hs:
--------------------------------------------------------------------------------
-- | This module handles parsing the options that can get passed into the HerbiePlugin
module Herbie.Options
    where

import GhcPlugins
import Prelude

data PluginOpts = PluginOpts
    {
    -- | This comment will be stored in the Herbie database for each expression that is found
      optsComments :: String

    -- | Controls whether rewriting is enabled or not
    , optsRewrite :: Bool

    -- | Perform the rewrite only if the improved expression reduces instability
    -- by this number of bits
    , optsTol :: Double
    }

-- | Options used when the user passes none: rewriting enabled, tolerance 0.5 bits, empty comment.
defPluginOpts :: PluginOpts
defPluginOpts = PluginOpts
    { optsComments = ""
    , optsRewrite = True
    , optsTol = 0.5
    }

-- | Parses the command line options given to the plugin.
--
-- Recognized options (matched by prefix):
--
-- * @noRewrite@ disables rewriting
-- * @tol=N@ sets 'optsTol' to the number N
-- * @comment=TEXT@ sets 'optsComments'
--
-- Unrecognized options are ignored. A @tol=@ value that is not a number raises
-- an error naming the option (instead of a bare "no parse") when the tolerance is used.
parsePluginOpts :: [CommandLineOption] -> PluginOpts
parsePluginOpts xs = go xs defPluginOpts
    where
        go [] opts = opts
        go (x:xs) opts
            | take 9 x == "noRewrite" = go xs $ opts { optsRewrite = False }
            | take 4 x == "tol=" = go xs $ opts { optsTol = parseTol (drop 4 x) }
            | take 8 x == "comment=" = go xs $ opts { optsComments = drop 8 x }
            | otherwise = go xs opts

        -- like 'read', but with an error message that names the offending option
        parseTol :: String -> Double
        parseTol str = case [ t | (t,rest) <- reads str, all (`elem` " \t\r\n") rest ] of
            t:_ -> t
            [] -> error $ "HerbiePlugin: cannot parse option tol="++str++"; expected a number"

--------------------------------------------------------------------------------
/src/Show.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE FlexibleInstances, MultiWayIf, StandaloneDeriving,
             TypeSynonymInstances #-}

{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | We define lots of orphan Show instances here, for debugging and learning
-- purposes.
--
-- Most of the time, while trying to figure out when a constructor is used or how
-- a term is compiled, it's easiest to just create an example and run the plugin
-- on it.
--
-- Without Show instances though, we can't easily inspect compiled outputs.
-- Outputable-generated strings hide lots of details (especially constructors),
-- but we still export a `showOutputable` here, for similar reasons.
16 | -- 17 | module Show where 18 | 19 | import Data.IORef 20 | import Data.List (intercalate) 21 | import System.IO.Unsafe (unsafePerformIO) 22 | 23 | import Class 24 | import CostCentre 25 | import ForeignCall 26 | import Demand 27 | import GhcPlugins 28 | import IdInfo 29 | import PrimOp 30 | import TypeRep 31 | 32 | import Prelude 33 | 34 | -------------------------------------------------------------------------------- 35 | 36 | dbg :: Outputable a => a -> String 37 | dbg a = showSDoc dynFlags (ppr a) 38 | 39 | {-# NOINLINE dynFlags_ref #-} 40 | dynFlags_ref :: IORef DynFlags 41 | dynFlags_ref = unsafePerformIO (newIORef undefined) -- starts out undefined; must be written (by code not shown here) before dbg/showOutputable are used 42 | 43 | {-# NOINLINE dynFlags #-} 44 | dynFlags :: DynFlags 45 | dynFlags = unsafePerformIO (readIORef dynFlags_ref) -- NOINLINE CAF: the ref is read only once, when dynFlags is first forced 46 | 47 | showOutputable :: Outputable a => a -> String 48 | showOutputable = showSDoc dynFlags . ppr 49 | 50 | -------------------------------------------------------------------------------- 51 | -- Orphan Show instances 52 | 53 | deriving instance Show a => Show (Expr a) 54 | deriving instance Show Type 55 | deriving instance Show Literal 56 | deriving instance Show a => Show (Tickish a) 57 | deriving instance Show a => Show (Bind a) 58 | deriving instance Show AltCon 59 | deriving instance Show TyLit 60 | deriving instance Show FunctionOrData 61 | deriving instance Show Module 62 | deriving instance Show CostCentre 63 | deriving instance Show Role 64 | deriving instance Show LeftOrRight 65 | deriving instance Show IsCafCC 66 | 67 | instance Show Class where 68 | show _ = "" 69 | 70 | deriving instance Show IdDetails 71 | deriving instance Show PrimOp 72 | deriving instance Show ForeignCall 73 | deriving instance Show TickBoxOp 74 | deriving instance Show PrimOpVecCat 75 | deriving instance Show CCallSpec 76 | deriving instance Show CCallTarget 77 | deriving instance Show CCallConv 78 | deriving instance Show SpecInfo 79 | deriving instance Show OccInfo 80 | deriving instance Show InlinePragma 81 | deriving
instance Show OneShotInfo 82 | deriving instance Show CafInfo 83 | deriving instance Show Unfolding 84 | deriving instance Show UnfoldingSource 85 | deriving instance Show UnfoldingGuidance 86 | deriving instance Show Activation 87 | deriving instance Show CoreRule 88 | -- deriving instance Show IsOrphan 89 | deriving instance Show StrictSig 90 | deriving instance Show DmdType 91 | 92 | instance Show RuleFun where 93 | show _ = "" 94 | 95 | instance Show (UniqFM a) where 96 | show _ = "" 97 | 98 | instance Show IdInfo where 99 | show info = 100 | "Info{" ++ intercalate "," [show arityInfo_, show specInfo_, show unfoldingInfo_, 101 | show cafInfo_, show oneShotInfo_, show inlinePragInfo_, 102 | show occInfo_, show strictnessInfo_, show demandInfo_, 103 | show callArityInfo_] ++ "}" 104 | where 105 | arityInfo_ = arityInfo info 106 | specInfo_ = specInfo info 107 | unfoldingInfo_ = unfoldingInfo info 108 | cafInfo_ = cafInfo info 109 | oneShotInfo_ = oneShotInfo info 110 | inlinePragInfo_ = inlinePragInfo info 111 | occInfo_ = occInfo info 112 | strictnessInfo_ = strictnessInfo info 113 | demandInfo_ = demandInfo info 114 | callArityInfo_ = callArityInfo info 115 | 116 | instance Show Var where 117 | show v = 118 | if | isId v -> 119 | let details = idDetails v 120 | info = idInfo v 121 | in "Id{" ++ intercalate "," [show name, show uniq, show ty, show details, show info] ++ "}" 122 | | isTyVar v -> "TyVar{" ++ show name ++ "}" 123 | | otherwise -> "TcTyVar{" ++ show name ++ "}" 124 | where 125 | name = varName v 126 | uniq = varUnique v 127 | ty = varType v 128 | 129 | instance Show DataCon where 130 | show = show . dataConName 131 | 132 | instance Show TyCon where 133 | show = show . tyConName 134 | 135 | instance Show ModuleName where 136 | show = show . moduleNameString 137 | 138 | instance Show PackageKey where 139 | show = show . packageKeyString 140 | 141 | instance Show Name where 142 | show = showOutputable .
nameOccName 143 | 144 | -- deriving instance Show Name 145 | instance Show OccName where 146 | show = showOutputable 147 | 148 | instance Show Coercion where 149 | show _ = "" 150 | 151 | 152 | -- Instances for types that are not part of Core terms. 153 | 154 | deriving instance Show CoreToDo 155 | deriving instance Show SimplifierMode 156 | deriving instance Show CompilerPhase 157 | deriving instance Show FloatOutSwitches 158 | 159 | instance Show PluginPass where 160 | show _ = "PluginPass" 161 | -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | flags: {} 2 | packages: 3 | - '.' 4 | extra-deps: 5 | [] 6 | resolver: lts-3.3 7 | -------------------------------------------------------------------------------- /test/SpecialFunctions.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE GADTs,RebindableSyntax,CPP,FlexibleContexts,FlexibleInstances,ConstraintKinds #-} 2 | {-# OPTIONS_GHC -dcore-lint #-} 3 | 4 | {- 5 | - This test suite demonstrates that the special functions `log1p`, `expm1`, and `hypot` all work.
6 | -} 7 | module Main 8 | where 9 | 10 | import SubHask 11 | -- import Prelude as P 12 | -- 13 | -- fromRational = P.fromRational 14 | -- 15 | -- (<) :: Ord a => a -> a -> Bool 16 | -- (<) = (P.<) 17 | 18 | -------------------------------------------------------------------------------- 19 | 20 | -- test1 :: Floating a => a -> a -> a 21 | test1 :: Double -> Double -> Double 22 | test1 a b = sqrt (a*a + b*b) 23 | 24 | test2 :: Double -> Double 25 | test2 a = log (1 + a) 26 | 27 | test3 :: Double -> Double 28 | test3 a = exp a - 1 29 | 30 | -------------------------------------------------------------------------------- 31 | 32 | main = return () 33 | -------------------------------------------------------------------------------- /test/Tests.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE GADTs,RebindableSyntax,CPP,FlexibleContexts,FlexibleInstances,ConstraintKinds #-} 2 | 3 | {-# OPTIONS_GHC -dcore-lint #-} 4 | 5 | {- 6 | - The idea of this test suite is that it should be compiled 7 | - with the -fplugin=Herbie and -dcore-lint flags. 8 | - Then we check to make sure GHC didn't throw any errors during 9 | - the core type checking process. 10 | -} 11 | module Main 12 | where 13 | 14 | import SubHask 15 | 16 | -------------------------------------------------------------------------------- 17 | 18 | -- This section tests that Herbie obeys the code annotations 19 | 20 | {-# ANN ann "NoHerbie" #-} 21 | ann :: Float -> Float 22 | ann x = x+x+x*x*x*x 23 | 24 | noann :: Float -> Float 25 | noann x = x+x+x*x*x*x 26 | 27 | -------------------------------------------------------------------------------- 28 | 29 | -- This section tests that Herbie gets run on the correct types. 30 | -- Herbie should be run on all the functions below.
31 | 32 | #define f1(x) (sqrt ((x)+1) - sqrt (x)) 33 | 34 | herbie1 :: Real a => a -> a 35 | herbie1 x = f1(x) 36 | 37 | herbie2 :: Real a => a -> a -> a -> a -> a 38 | herbie2 a b c d = f1(a)+f1(b)+f1(c)+f1(d) 39 | 40 | herbie3 :: Float -> Float 41 | herbie3 x = f1(x) 42 | 43 | herbie4 :: String -> String 44 | herbie4 str = show $ f1(x1) 45 | where 46 | x1 = fromIntegral (length str) :: Float 47 | 48 | herbie5 :: (Show a, Real a) => String -> a -> String 49 | herbie5 str x1 = show $ f1(x1) 50 | 51 | herbie6 :: (Show a, Real a) => a -> String -> String 52 | herbie6 x1 str = show $ f1(x1) 53 | 54 | herbie7 :: Semigroup a => a -> a 55 | herbie7 x1 = x1+x1+x1+x1+x1 56 | 57 | herbie8 :: Float -> Float 58 | herbie8 x1 = case x1 of 59 | 1.0 -> f1(x1) 60 | 2.0 -> x1 61 | 62 | herbie9 :: Float -> Float 63 | herbie9 x1 = go 4 x1 64 | where 65 | go :: Float -> Float -> Float 66 | go 0 b = b 67 | go a b = go (a-1) (sqrt (b-1)) 68 | 69 | herbie10 :: String -> String 70 | herbie10 str = show ( sqrt (1+fromIntegral (length str)) 71 | - sqrt (fromIntegral (length str)) 72 | :: Float 73 | ) 74 | 75 | -- Herbie should not be run on any of the functions in this section (String, Rational, and Int are not floating-point types). 76 | 77 | #define f2(a,b) a+b*(a+b*a)+a*b 78 | 79 | noherbie1 :: String -> String 80 | noherbie1 x = x++"hello world" 81 | 82 | noherbie2 :: Rational -> Rational -> Rational 83 | noherbie2 a b = f2(a,b) 84 | 85 | noherbie3 :: Int -> Int -> Int 86 | noherbie3 a b = f2(a,b) 87 | 88 | noherbie4 :: x -> Int -> Int -> Int 89 | noherbie4 x a b = f2(a,b) 90 | 91 | -------------------------------------------------------------------------------- 92 | 93 | -- Herbie shouldn't process these because the expression size is too small. 94 | -- We're unlikely to get any benefit, and it might take a long time.
95 | 96 | toosmall1 :: Float -> Float 97 | toosmall1 a = a+a 98 | 99 | toosmall2 :: Float -> Float -> Float -> Float 100 | toosmall2 a b c = a+b*c 101 | 102 | -- These are big enough and should get processed 103 | 104 | bigenough1 :: Float -> Float 105 | bigenough1 a = a+a*a 106 | 107 | bigenough2 :: Float -> Float -> Float -> Float 108 | bigenough2 a b c = a+b*(c+a) 109 | 110 | bigenough3 :: Float -> Float -> Float -> Float 111 | bigenough3 a b c = f1(c) 112 | 113 | -------------------------------------------------------------------------------- 114 | 115 | -- This section contains lots of examples of expressions that the Herbie plugin can parse 116 | -- and for which Herbie can find improved versions. 117 | 118 | example1 x1 x2 = sqrt (x1*x1 + x2*x2) 119 | 120 | example2 x = exp(log(x)+8) 121 | 122 | example3 x = sqrt(x*x +1) -1 123 | 124 | example4 x = exp(x)-1 125 | 126 | example5 x = log(1+x) 127 | 128 | example6 x y = sqrt(x+ y) - sqrt(y) 129 | 130 | example7 k r a = k*(r-a)^3 131 | 132 | example8 k r a = k*(r-a)^2 133 | 134 | example9 x y = sin(x - y) 135 | 136 | example10 p1x p2x p1y p2y = sqrt((p1x - p2x) * (p1x - p2x) + (p1y - p2y) * (p1y - p2y)) 137 | 138 | example11 x = sin(x)-x 139 | 140 | example12 x = 1-cos(x) 141 | 142 | example13 x1 x2 = sqrt((x1 - x2) * (x1 - x2)) 143 | 144 | example14 x y z = sqrt(x*x + y*y + z*z) 145 | 146 | example15 x y z c = sqrt(x*x + y*y + z*z)/c 147 | 148 | example16 tdx dx tdy dy = (tdx * dx + tdy * dy) / (dx * dx + dy * dy) 149 | 150 | example17 tdx dx tdy dy sl2 = (tdx * dx + tdy * dy) / sl2 151 | 152 | example18 x = (x + 0.1)-x 153 | 154 | example19 x = log(x) - sin(x+1) 155 | 156 | example20 a b = exp(1+log(a) + log(b)) 157 | 158 | example21 x = (1+sqrt(x-1))/(x-1)^2 159 | 160 | example22 x = (1+sqrt(x))/(x-1)^2 161 | 162 | example23 a b c d e f = a+b+(((d-c)*(d-c))*e*f/(e+f)) 163 | 164 | example24 q = sqrt(q*(q-1)) 165 | 166 | example25 a = sqrt(a^2-1) 167 | 168 | example26 a b c d = ((a*b)+(c*d))/(a+c) 169 | 170 | example27 x =
sqrt(x^2) 171 | 172 | example28 x y = sqrt(x) * y * y 173 | 174 | example29 x y z = sqrt(x*x+y*y+z*z) 175 | 176 | example30 x y = 1.75 * x * y*y + sqrt(x/y) 177 | 178 | example31 x = exp(3*log(x)+2) 179 | 180 | example32 x = exp(2*log(x)) 181 | 182 | example33 x = sqrt(1/x + 1) - sqrt(1/x) 183 | 184 | example34 left i right count = left + i * ((left - right) / count) 185 | 186 | example35 left right count = left + count * ((left - right) / count) 187 | 188 | example36 x y = sqrt(x*x) - sqrt(y*y) 189 | 190 | example37 x = log(x+1)-log(x) 191 | 192 | example38 x = log(x+1)^x 193 | 194 | example39 minval minstep val = (minval/minstep + val) * minstep 195 | 196 | example40 x = x*x*cos(x/2 - sqrt(x)) 197 | 198 | example41 x = sqrt(4+x^2+x) 199 | 200 | example42 x y z = x / sqrt(x*x + y*y + z*z) 201 | 202 | example43 x = sin(sqrt(x+1)) 203 | 204 | example44 x = sqrt(x-2)-sqrt(x*x-3) 205 | 206 | example45 x = (sin(x) - tan(x)) / x 207 | 208 | example46 x y = 1 / sqrt(x^2 - y^2) 209 | 210 | example47 x1 x2 = sqrt((x1 - x2)^2) 211 | 212 | example48 x = x - sin(x) 213 | 214 | example49 x = sqrt(x + 1) - 1 + x 215 | 216 | example50 a b c = (a*a - c*c)/b 217 | 218 | example51 x y = sin(x+y)-cos(x+y) 219 | 220 | example52 x = (x + 1)^2 - 1 221 | 222 | example53 x = sqrt(1+x) - sqrt(x) 223 | 224 | example54 x = sqrt(x + 1) / (x*x) 225 | 226 | example55 x = sqrt(x^2 / 3) 227 | 228 | example56 a b = 100*(a-b)/a 229 | 230 | example57 x = abs(x^3)-x^3 231 | 232 | example58 x = log(x) - log(x+1) 233 | 234 | example59 x = 1/x - 1/(x+1) 235 | 236 | example60 a b c = -b + sqrt(b*b-4*a*c)/(2*a) 237 | 238 | example61 a c an cn = log(exp(a)*an + exp(c)*cn) - log(an+cn) 239 | 240 | example62 x = sqrt(sin(x)) - sqrt(x) 241 | 242 | example63 x = log(1+x) 243 | 244 | example64 a b = a * b / (1 - b + a * b) 245 | 246 | example65 a b = b*sqrt(a * a + 1.0) 247 | 248 | example65' a = sqrt(a * a + 1.0) 249 | 250 | example66 x y = x * y * x*pi/y 251 | 252 | example67 x = sqrt(x + 1) - sqrt(x - 1)
example68 x = cos(x + 1) * x^2

example69 a b = b*(a/b - log(1 + a/b))

example70 a b = b*(a/b - 1 - log(a/b))

example71 x = (6/(x^99))*(x^101)

example72 x = (1/(x^99))*(x^101)

example73 x = (1/(x^100))*(x^100)

example74 x y z = cos(sqrt(x*x+y*y+z*z))

example75 x = sqrt(sqrt(x*x+1)+1)

example76 a k = a + sqrt(a*a-k)

example77 a k = -a - sqrt(a*a-k)

example78 a b x = x*x*a+x*(a+b) +x*b

example79 x = (x + x) ^ 3 / x

example80 x = sqrt(x+1)-sqrt 1

example81 x = (x+1)-x

example82 x = sqrt(x+100)-sqrt(x)

-- Same expression as example12.
example83 x = 1-cos(x)

example84 u v = sqrt(sqrt(u^2 + v^2) - u)

example85 x = exp(log(x))

example86 x = sqrt(x + 1) - sqrt x + sin(x - 1)

example87 x = exp x / sqrt(exp x - 1) * sqrt x

example88 x = (exp(x) - 1) / x

example89 x = sqrt(x + 2) - sqrt(x)

--------------------------------------------------------------------------------

-- The main function does nothing.  It exists only so that this module builds
-- as an executable; presumably the plugin does its work while compiling the
-- definitions above.
main = return ()

--------------------------------------------------------------------------------
/test/ValidRewrite.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE GADTs,RebindableSyntax,CPP,FlexibleContexts,FlexibleInstances,ConstraintKinds #-}
{-# LANGUAGE StandaloneDeriving,DeriveDataTypeable #-}
{-# OPTIONS_GHC -dcore-lint #-}
{-
 - This test suite ensures that the rewrites that HerbiePlugin performs
 - give the correct results.
 -
 - NOTE(review): main only prints each pair of results for manual
 - comparison.  It never compares them, so an incorrect rewrite does not
 - make this program fail.
 -}

module Main
    where

import SubHask

import System.IO
-- import Data.Complex
-- import Linear.Quaternion
-- import Linear.V3
-- import Linear.Vector

--------------------------------------------------------------------------------

-- The expression under test, with no annotation, so the plugin may rewrite it.
test1a :: Double -> Double -> Double
test1a far near = -(2 * far * near) / (far - near)

-- Textually identical to test1a but annotated NoHerbie; presumably this keeps
-- the original, unrewritten expression as a baseline.
{-# ANN test1b "NoHerbie" #-}
test1b :: Double -> Double -> Double
test1b far near = -(2 * far * near) / (far - near)

-- A hand-written piecewise form of the same expression, branching on far at
-- -1.7210442634149447e81 and 8.364504563556443e16, also annotated NoHerbie.
-- NOTE(review): presumably this is the rewrite Herbie is expected to produce
-- for test1a, which is why main prints the two side by side.  Confirm.
{-# ANN test1c "NoHerbie" #-}
test1c :: Double -> Double -> Double
test1c far near = if far < -1.7210442634149447e81
    then ((-2 * far) / (far - near)) * near
    else if far < 8.364504563556443e16
        then -2 * far * (near / (far - near))
        else ((-2 * far) / (far - near)) * near

-- Disabled tests, kept for reference: test2a and test2b, the quaternion
-- test3a and test3b, the Yo record with test4, test5, and atanh_.
{-
--------------------

test2a :: Double -> Double -> Double
test2a a b = a + ((b - a) / 2)

{-# ANN test2b "NoHerbie" #-}
test2b :: Double -> Double -> Double
test2b a b = a + ((b - a) / 2)

--------------------

-- test3a :: Quaternion Double -> Quaternion Double -> Quaternion Double
-- test3a (Quaternion q0 (V3 q1 q2 q3)) (Quaternion r0 (V3 r1 r2 r3)) =
--     Quaternion (r0*q0+r1*q1+r2*q2+r3*q3)
--                (V3 (r0*q1-r1*q0-r2*q3+r3*q2)
--                    (r0*q2+r1*q3-r2*q0-r3*q1)
--                    (r0*q3-r1*q2+r2*q1-r3*q0))
--     ^/ (r0*r0 + r1*r1 + r2*r2 + r3*r3)
--
-- {-# ANN test3b "NoHerbie" #-}
-- test3b :: Quaternion Double -> Quaternion Double -> Quaternion Double
-- test3b (Quaternion q0 (V3 q1 q2 q3)) (Quaternion r0 (V3 r1 r2 r3)) =
--     Quaternion (r0*q0+r1*q1+r2*q2+r3*q3)
--                (V3 (r0*q1-r1*q0-r2*q3+r3*q2)
--                    (r0*q2+r1*q3-r2*q0-r3*q1)
--                    (r0*q3-r1*q2+r2*q1-r3*q0))
--     ^/ (r0*r0 + r1*r1 + r2*r2 + r3*r3)

--------------------

data Yo a = Yo
    { yo_x2y :: a
    , yo_y2x :: a
    }

test4 :: Real a => a -> a -> Yo a
test4 x y = Yo
    { yo_x2y = x * x * y
    , yo_y2x = y * y * x
    }

test5 :: Float -> Float -> Float
test5 x y = (x * x) + (2 * x * y) + (y * y)

--------------------------------------------------------------------------------

-- asinh_ :: Complex Double -> Complex Double
-- asinh_ x = log (x + sqrt (1.0+x*x))
--
-- acosh_ :: Complex Double -> Complex Double
-- acosh_ x = log (x + (x+1.0) * sqrt ((x-1.0)/(x+1.0)))

atanh_ :: Double -> Double
atanh_ x = 0.5 * log ((1.0+x) / (1.0-x))

--------------------------------------------------------------------------------
-}

-- The mkTest macro expands to three statements for a do block.  It prints the
-- first function applied to both arguments, then the second function applied
-- to the same arguments (each line labelled mkTest:), then an empty line.
-- Arguments are parenthesised in the expansion, so negative literals are safe.
#define mkTest(f1,f2,a,b) \
    putStrLn $ "mkTest: " ++ show (f1 (a) (b)); \
    putStrLn $ "mkTest: " ++ show (f2 (a) (b)); \
    putStrLn ""

-- mkTestB is the one-argument variant of mkTest.
#define mkTestB(f1,f2,a) \
    putStrLn $ "mkTest: " ++ show (f1 (a)); \
    putStrLn $ "mkTest: " ++ show (f2 (a)); \
    putStrLn ""

main = do
    -- test1a (rewritable) against test1b (NoHerbie copy of the same expression).
    mkTest(test1a,test1b,-2e90,6)
    mkTest(test1a,test1b,3,4)
    mkTest(test1a,test1b,2e90,6)

    -- test1a against the hand-written piecewise form test1c.
    mkTest(test1a,test1c,-2e90,6)
    mkTest(test1a,test1c,3,4)
    mkTest(test1a,test1c,2e90,6)

    -- Calls for the disabled tests above.
    {-
    mkTest(test2a,test2b,1,2)

    -- mkTest(test3a,test3b,(Quaternion 1 (V3 1 2 3)),(Quaternion 2 (V3 2 3 4)))

    -- mkTestB(asinh,asinh_,5e-17::Complex Double)
    -- mkTestB(acosh,acosh_,5e-17::Complex Double)
    mkTestB(atanh,atanh_,5e-17::Double)
    -}

    putStrLn "done"

--------------------------------------------------------------------------------