├── .git-blame-ignore-revs ├── .gitattributes ├── .github └── workflows │ ├── bench.yaml │ ├── docs.yaml │ ├── haskell-ci.yaml │ ├── julia-ci.yaml │ └── python-ci.yaml ├── .gitignore ├── .hlint.yaml ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmarks ├── bfgs.py ├── continuous.py ├── conv.dx ├── conv_py.py ├── fused_sum.dx ├── gaussian.dx ├── jvp_matmul.dx ├── matmul_big.dx ├── matmul_small.dx ├── matvec_big.dx ├── matvec_small.dx ├── parboil │ ├── histogram.dx │ ├── mriq.dx │ └── stencil.dx ├── poly.dx ├── prepare-executables.py ├── rodinia │ ├── README.md │ ├── backprop.dx │ ├── backpropad.dx │ ├── hotspot.dx │ ├── kmeans.dx │ └── pathfinder.dx └── vjp_matmul.dx ├── cabal.project ├── default.nix ├── dex.cabal ├── doc ├── conditionals.dx ├── functions.dx └── syntax-philosophy.md ├── examples ├── bfgs.dx ├── brownian_motion.dx ├── ctc.dx ├── data-frames.dx ├── dither.dx ├── export │ ├── array.cpp │ ├── array.dx │ ├── scalar.cpp │ └── scalar.dx ├── fluidsim.dx ├── kernelregression.dx ├── latex.dx ├── levenshtein-distance.dx ├── linear-maps.dx ├── mandelbrot.dx ├── manifold-gradients.dx ├── mcmc.dx ├── mcts.dx ├── md.dx ├── mnist-nearest-neighbors.dx ├── nn.dx ├── ode-integrator.dx ├── particle-filter.dx ├── particle-swarm-optimizer.dx ├── pi.dx ├── psd.dx ├── quaternions.dx ├── raytrace.dx ├── regression.dx ├── rejection-sampler.dx ├── schrodinger.dx ├── sgd.dx ├── sierpinski.dx ├── simplex.dx ├── tutorial-old.dx ├── tutorial.dx └── vega-plotting.dx ├── flake.lock ├── flake.nix ├── julia ├── Project.toml ├── README.md ├── deps │ └── build.jl ├── src │ ├── DexCall.jl │ ├── api.jl │ ├── api_types.jl │ ├── evaluate.jl │ └── native_function.jl └── test │ ├── api.jl │ ├── evaluate.jl │ ├── native_function.jl │ └── runtests.jl ├── lib ├── complex.dx ├── diagram.dx ├── fft.dx ├── linalg.dx ├── netpbm.dx ├── parser.dx ├── plot.dx ├── png.dx ├── prelude.dx ├── set.dx ├── sort.dx └── stats.dx ├── makefile ├── misc ├── build-web-index ├── 
check-no-diff ├── check-quine ├── dex-completion.bash ├── dex.el └── file-check ├── python ├── dex │ ├── __init__.py │ ├── api.py │ ├── interop │ │ ├── __init__.py │ │ └── jax │ │ │ ├── __init__.py │ │ │ ├── apply.py │ │ │ ├── jax2dex.py │ │ │ └── jaxpr_json.py │ └── native_function.py ├── example.py ├── setup.py └── tests │ ├── api_test.py │ ├── dexjit_test.py │ ├── jax_test.py │ ├── jaxpr_json_test.py │ └── jit_test.py ├── shell.nix ├── src ├── Dex │ └── Foreign │ │ ├── API.hs │ │ ├── Context.hs │ │ ├── JAX.hs │ │ ├── JIT.hs │ │ ├── Serialize.hs │ │ ├── Util.hs │ │ └── rts.c ├── dex.hs ├── lib │ ├── AbstractSyntax.hs │ ├── Actor.hs │ ├── Algebra.hs │ ├── Builder.hs │ ├── CUDA.hs │ ├── CheapReduction.hs │ ├── CheckType.hs │ ├── ConcreteSyntax.hs │ ├── Core.hs │ ├── Err.hs │ ├── Export.hs │ ├── Generalize.hs │ ├── IRVariants.hs │ ├── Imp.hs │ ├── ImpToLLVM.hs │ ├── IncState.hs │ ├── Inference.hs │ ├── Inline.hs │ ├── JAX │ │ ├── Concrete.hs │ │ ├── Rename.hs │ │ └── ToSimp.hs │ ├── LLVM │ │ ├── CUDA.hs │ │ ├── Compile.hs │ │ ├── Link.hs │ │ └── Shims.hs │ ├── Lexing.hs │ ├── Linearize.hs │ ├── Live │ │ ├── Eval.hs │ │ └── Web.hs │ ├── Lower.hs │ ├── MTL1.hs │ ├── MonadUtil.hs │ ├── Name.hs │ ├── OccAnalysis.hs │ ├── Occurrence.hs │ ├── Optimize.hs │ ├── PPrint.hs │ ├── PeepholeOptimize.hs │ ├── QueryType.hs │ ├── QueryTypePure.hs │ ├── RawName.hs │ ├── RenderHtml.hs │ ├── Runtime.hs │ ├── RuntimePrint.hs │ ├── Serialize.hs │ ├── Simplify.hs │ ├── Simplify.hs-boot │ ├── SourceIdTraversal.hs │ ├── SourceRename.hs │ ├── Subst.hs │ ├── TopLevel.hs │ ├── Transpose.hs │ ├── Types │ │ ├── Core.hs │ │ ├── Imp.hs │ │ ├── OpNames.hs │ │ ├── Primitives.hs │ │ ├── Source.hs │ │ └── Top.hs │ ├── Util.hs │ ├── Vectorize.hs │ ├── dexrt.cpp │ └── work-stealing.c └── old │ ├── Imp │ ├── Builder.hs │ └── Optimize.hs │ ├── MLIR │ ├── Eval.hs │ └── Lower.hs │ └── Parallelize.hs ├── stack-llvm-head.yaml ├── stack-macos.yaml ├── stack.yaml ├── static ├── dynamic.html ├── index.ts └── 
style.css └── tests ├── ad-tests.dx ├── adt-tests.dx ├── algeff-tests.dx ├── cast-tests.dx ├── complex-tests.dx ├── eval-tests.dx ├── exception-tests.dx ├── fft-tests.dx ├── gpu-tests.dx ├── inline-tests.dx ├── instance-interface-syntax-tests.dx ├── instance-methods-tests.dx ├── io-tests.dx ├── linalg-tests.dx ├── linear-tests.dx ├── lower.dx ├── module-tests.dx ├── monad-tests.dx ├── opt-tests.dx ├── parser-combinator-tests.dx ├── parser-tests.dx ├── print-tests.dx ├── read-tests.dx ├── repl-multiline-test-expected-output ├── repl-multiline-test.dx ├── repl-regression-528-test-expected-output ├── repl-regression-528-test.dx ├── serialize-tests.dx ├── set-tests.dx ├── shadow-tests.dx ├── show-tests.dx ├── sort-tests.dx ├── stack-tests.dx ├── standalone-function-tests.dx ├── stats-tests.dx ├── struct-tests.dx ├── test_module_A.dx ├── test_module_B ├── test_module_B.dx ├── test_module_C.dx ├── trig-tests.dx ├── type-tests.dx ├── typeclass-tests.dx ├── uexpr-tests.dx └── unit ├── ConstantCastingSpec.hs ├── JaxADTSpec.hs ├── OccAnalysisSpec.hs ├── OccurrenceSpec.hs ├── RawNameSpec.hs ├── SourceInfoSpec.hs └── Spec.hs /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # Run this command to always ignore these in local `git blame`: 2 | # git config blame.ignoreRevsFile .git-blame-ignore-revs 3 | 4 | # Formatted TopLevel.hs to 80 character width 5 | ef500cc06f96bddcc355d0e29945a3d48dd19867 6 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.dx linguist-language=haskell 2 | -------------------------------------------------------------------------------- /.github/workflows/bench.yaml: -------------------------------------------------------------------------------- 1 | name: Continuous benchmarking 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | 7 | jobs: 8 | build: 9 | runs-on: 
${{ matrix.os }} 10 | strategy: 11 | matrix: 12 | os: [ubuntu-20.04] 13 | include: 14 | - os: ubuntu-20.04 15 | install_deps: sudo apt-get install llvm-12-tools llvm-12-dev pkg-config 16 | path_extension: /usr/lib/llvm-12/bin 17 | 18 | steps: 19 | - name: Checkout the repository 20 | uses: actions/checkout@v2 21 | with: 22 | fetch-depth: 0 23 | 24 | - name: Set up Python 3.9 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: 3.9 28 | 29 | - name: Install system dependencies 30 | run: | 31 | ${{ matrix.install_deps }} 32 | pip install numpy "jax[cpu]" 33 | echo "${{ matrix.path_extension }}" >> $GITHUB_PATH 34 | 35 | - name: Cache 36 | uses: actions/cache@v2 37 | with: 38 | path: | 39 | ~/.stack 40 | 41 | key: ${{ runner.os }}-bench-v1-${{ hashFiles('**/*.cabal', 'stack*.yaml') }} 42 | restore-keys: ${{ runner.os }}-bench-v1 43 | 44 | - name: Benchmark 45 | run: python3 benchmarks/continuous.py /tmp/new-perf-data.csv /tmp/new-commits.csv ${GITHUB_SHA} 46 | 47 | - name: Switch to the data branch 48 | uses: actions/checkout@v2 49 | with: 50 | ref: performance-data 51 | 52 | - name: Append new data points 53 | run: | 54 | cat /tmp/new-perf-data.csv >>performance.csv 55 | cat /tmp/new-commits.csv >>commits.csv 56 | 57 | - name: Commit new data points 58 | if: github.event_name == 'push' 59 | run: | 60 | git config --global user.name 'Dex CI' 61 | git config --global user.email 'apaszke@users.noreply.github.com' 62 | git add performance.csv commits.csv 63 | git commit -m "Add measurements for ${GITHUB_SHA}" 64 | git push 65 | -------------------------------------------------------------------------------- /.github/workflows/docs.yaml: -------------------------------------------------------------------------------- 1 | name: Update HTML docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build: 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | matrix: 13 | os: [ubuntu-20.04] 14 | include: 15 | - os: ubuntu-20.04 16 | install_deps: sudo 
apt-get install llvm-12-tools llvm-12-dev pkg-config wamerican 17 | path_extension: /usr/lib/llvm-12/bin 18 | 19 | steps: 20 | - name: Checkout the repository 21 | uses: actions/checkout@v2 22 | 23 | - name: Install system dependencies 24 | run: | 25 | ${{ matrix.install_deps }} 26 | echo "${{ matrix.path_extension }}" >> $GITHUB_PATH 27 | 28 | - name: Cache 29 | uses: actions/cache@v2 30 | with: 31 | path: | 32 | ~/.stack 33 | $GITHUB_WORKSPACE/.stack-work 34 | key: ${{ runner.os }}-${{ hashFiles('**/*.cabal', 'stack*.yaml') }} 35 | restore-keys: ${{ runner.os }}- 36 | 37 | - name: Build 38 | run: make build 39 | 40 | - name: Generate docs 41 | run: make docs 42 | 43 | - name: Deploy to GitHub Pages 44 | uses: "JamesIves/github-pages-deploy-action@3dbacc7e69578703f91f077118b3475862cb09b8" # 4.1.0 45 | with: 46 | token: ${{ secrets.GITHUB_TOKEN }} 47 | branch: gh-pages # The branch the action should deploy to. 48 | folder: pages/dex-lang # The folder the action should deploy. 49 | clean: false # If true, automatically remove deleted files from the deploy branch. 
50 | commit-message: Updating gh-pages from ${{ github.sha }} 51 | -------------------------------------------------------------------------------- /.github/workflows/haskell-ci.yaml: -------------------------------------------------------------------------------- 1 | name: Haskell tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | env: 10 | DEX_CI: 1 11 | 12 | concurrency: 13 | group: haskell-${{ github.head_ref || github.run_id }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | build: 18 | runs-on: ${{ matrix.os }} 19 | strategy: 20 | matrix: 21 | os: [ubuntu-20.04, macos-latest] 22 | include: 23 | - os: macos-latest 24 | install_deps: brew install llvm@12 pkg-config wget gzip coreutils 25 | path_extension: $(brew --prefix llvm@12)/bin 26 | - os: ubuntu-20.04 27 | install_deps: sudo apt-get install llvm-12-tools llvm-12-dev pkg-config wget gzip wamerican 28 | path_extension: /usr/lib/llvm-12/bin 29 | 30 | steps: 31 | - name: Checkout the repository 32 | uses: actions/checkout@v2 33 | 34 | - name: Cache 35 | uses: actions/cache@v2 36 | with: 37 | path: | 38 | ~/.stack 39 | ~/.ghcup/ghc/8.10.7 40 | $GITHUB_WORKSPACE/.stack-work 41 | $GITHUB_WORKSPACE/.stack-work-test 42 | $GITHUB_WORKSPACE/examples/t10k-images-idx3-ubyte 43 | $GITHUB_WORKSPACE/examples/t10k-labels-idx1-ubyte 44 | 45 | key: ${{ runner.os }}-v5-${{ hashFiles('**/*.cabal', 'stack*.yaml') }} 46 | restore-keys: ${{ runner.os }}-v5- 47 | 48 | - name: Install system dependencies 49 | run: | 50 | ${{ matrix.install_deps }} 51 | if [[ "$OSTYPE" == "darwin"* ]]; then ghcup install ghc 8.10.7; fi 52 | echo "${{ matrix.path_extension }}" >> $GITHUB_PATH 53 | 54 | # This step is a workaround. 
55 | # See issue for context: https://github.com/actions/cache/issues/445 56 | - name: Remove cached Setup executables 57 | run: rm -rf ~/.stack/setup-exe-cache 58 | if: runner.os == 'macOS' 59 | 60 | - name: Build, treating warnings as errors 61 | run: make build-ci 62 | if: runner.os == 'Linux' 63 | 64 | - name: Build 65 | run: make build 66 | 67 | - name: Run tests 68 | run: make tests 69 | -------------------------------------------------------------------------------- /.github/workflows/julia-ci.yaml: -------------------------------------------------------------------------------- 1 | name: Julia tests 2 | 3 | on: 4 | push: 5 | branches: [ ] 6 | pull_request: 7 | branches: [ ] 8 | 9 | env: 10 | DEX_CI: 1 11 | 12 | concurrency: 13 | group: julia-${{ github.head_ref || github.run_id }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | build: 18 | if: false # TODO: Fix Julia bindings! 19 | 20 | runs-on: ${{ matrix.os }} 21 | strategy: 22 | matrix: 23 | os: [ubuntu-20.04] 24 | include: 25 | - os: ubuntu-20.04 26 | install_deps: sudo apt-get install llvm-12-tools llvm-12-dev pkg-config wget gzip 27 | path_extension: /usr/lib/llvm-12/bin 28 | 29 | steps: 30 | - name: Checkout the repository 31 | uses: actions/checkout@v2 32 | 33 | - name: Setup Julia 34 | uses: julia-actions/setup-julia@ee66464cb7897ffcc5322800f4b18d449794af30 # v1.6.1 35 | with: 36 | version: '1.6' 37 | arch: x64 38 | 39 | - name: Cache 40 | uses: actions/cache@v2 41 | with: 42 | path: | 43 | ~/.stack 44 | $GITHUB_WORKSPACE/.stack-work 45 | ~/.julia/artifacts 46 | key: ${{ runner.os }}-v2-julia-${{ hashFiles('**/*.cabal', 'stack*.yaml', '**/Project.toml') }} 47 | restore-keys: | 48 | ${{ runner.os }}-v2-julia- 49 | ${{ runner.os }}-v2- 50 | 51 | - name: Build DexCall.jl 52 | uses: julia-actions/julia-buildpkg@f995fa4149fed4a8e9b95ba82f54cc107c1d832a #v1.2.0 53 | with: 54 | project: "julia/" 55 | 56 | - name: Test DexCall.jl 57 | uses: 
julia-actions/julia-runtest@eda4346d69c0d1653e483c397a83c7f32f4ef2ac # v1.6.0 58 | with: 59 | project: "julia/" 60 | -------------------------------------------------------------------------------- /.github/workflows/python-ci.yaml: -------------------------------------------------------------------------------- 1 | name: Python tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | env: 10 | DEX_CI: 1 11 | 12 | concurrency: 13 | group: python-${{ github.head_ref || github.run_id }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | build: 18 | runs-on: ${{ matrix.os }} 19 | strategy: 20 | matrix: 21 | os: [ubuntu-20.04] 22 | include: 23 | - os: ubuntu-20.04 24 | install_deps: sudo apt-get install llvm-12-tools llvm-12-dev pkg-config wget gzip 25 | path_extension: /usr/lib/llvm-12/bin 26 | 27 | steps: 28 | - name: Checkout the repository 29 | uses: actions/checkout@v2 30 | 31 | - name: Set up Python 3.9 32 | uses: actions/setup-python@v2 33 | with: 34 | python-version: 3.9 35 | 36 | - name: Install system dependencies 37 | run: | 38 | pip install pytest jax jaxlib 39 | ${{ matrix.install_deps }} 40 | echo "${{ matrix.path_extension }}" >> $GITHUB_PATH 41 | 42 | - name: Cache 43 | uses: actions/cache@v2 44 | with: 45 | path: | 46 | ~/.stack 47 | ~/.cache/pip 48 | $GITHUB_WORKSPACE/.stack-work 49 | key: ${{ runner.os }}-v2-python-${{ hashFiles('**/*.cabal', 'stack*.yaml') }} 50 | restore-keys: | 51 | ${{ runner.os }}-v2-python- 52 | ${{ runner.os }}-v2- 53 | 54 | - name: Build 55 | run: make build-ffis 56 | 57 | - name: Install Python bindings 58 | run: pip install -e $GITHUB_WORKSPACE/python 59 | 60 | - name: Run tests 61 | run: pytest python/tests 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.hi 2 | *.o 3 | test.db 4 | *.prof 5 | *.prof.html 6 | test-script.cd 7 | *.out 8 | *.so 9 | 
*.egg-info 10 | .stack-work 11 | .stack-work-opt 12 | .stack-work-dbg 13 | .stack-work-prof 14 | .stack-work-test 15 | .stack-work-test-dbg 16 | .stack-work-ffis 17 | garbage.hs 18 | *.lock 19 | *.pyc 20 | pages/ 21 | scratch* 22 | scratch/ 23 | test-scratch/ 24 | dist-newstyle/ 25 | *.cache 26 | *.db 27 | *.bc 28 | .benchmarks 29 | benchmarks/exe/**/* 30 | benchmarks/parboil/data 31 | benchmarks/rodinia/rodinia 32 | examples/export/scalar 33 | examples/export/array 34 | hie.yaml 35 | Manifest.toml 36 | julia/deps/build.log 37 | examples/t10k-images-idx3-ubyte 38 | examples/t10k-labels-idx1-ubyte 39 | examples/camera.ppm 40 | static/index.js 41 | -------------------------------------------------------------------------------- /.hlint.yaml: -------------------------------------------------------------------------------- 1 | - arguments: [--color] 2 | - ignore: {} 3 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - given-names: "Dougal" 5 | family-names: "Maclaurin" 6 | - given-names: "Adam" 7 | family-names: "Paszke" 8 | title: "Dex: typed and functional array processing" 9 | version: 0.0.0 10 | date-released: 2018-09-29 11 | url: "https://github.com/google-research/dex-lang" 12 | preferred-citation: 13 | type: article 14 | authors: 15 | - given-names: "Adam" 16 | family-names: "Paszke" 17 | - given-names: "Daniel D." 18 | family-names: "Johnson" 19 | - given-names: "David" 20 | family-names: "Duvenaud" 21 | - given-names: "Dimitrios" 22 | family-names: "Vytiniotis" 23 | - given-names: "Alexey" 24 | family-names: "Radul" 25 | - given-names: "Matthew J." 
26 | family-names: "Johnson" 27 | - given-names: "Jonathan" 28 | family-names: "Ragan-Kelley" 29 | - given-names: "Dougal" 30 | family-names: "Maclaurin" 31 | doi: "10.1145/3473593" 32 | title: "Getting to the Point: Index Sets and Parallelism-Preserving Autodiff for Pointful Array Programming" 33 | volume: 5 34 | number: "ICFP" 35 | year: 2021 36 | month: 8 37 | issue-date: "August 2021" 38 | journal: "Proceedings of the ACM on Programming Languages" 39 | pages: 29 40 | abstract: > 41 | We present a novel programming language design that attempts to combine the clarity 42 | and safety of high-level functional languages with the efficiency and parallelism 43 | of low-level numerical languages. We treat arrays as eagerly-memoized functions on 44 | typed index sets, allowing abstract function manipulations, such as currying, to work 45 | on arrays. In contrast to composing primitive bulk-array operations, we argue for 46 | an explicit nested indexing style that mirrors application of functions to arguments. 47 | We also introduce a fine-grained typed effects system which affords concise and automatically-parallelized 48 | in-place updates. Specifically, an associative accumulation effect allows reverse-mode 49 | automatic differentiation of in-place updates in a way that preserves parallelism. 50 | Empirically, we benchmark against the Futhark array programming language, and demonstrate 51 | that aggressive inlining and type-driven compilation allows array programs to be written 52 | in an expressive, "pointful" style with little performance penalty. 
53 | keywords: 54 | - "array programming" 55 | - "automatic differentiation" 56 | - "parallel computing" 57 | url: "https://doi.org/10.1145/3473593" 58 | publisher: 59 | name: "Association for Computing Machinery" 60 | city: "New York" 61 | region: "NY" 62 | country: "USA" 63 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to <https://cla.developers.google.com/> to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2019 Google LLC 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are 5 | met: 6 | 7 | * Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above 10 | copyright notice, this list of conditions and the following disclaimer 11 | in the documentation and/or other materials provided with the 12 | distribution. 
13 | * Neither the name of Google LLC nor the names of its 14 | contributors may be used to endorse or promote products derived from 15 | this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 21 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 23 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 25 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | -------------------------------------------------------------------------------- /benchmarks/bfgs.py: -------------------------------------------------------------------------------- 1 | 2 | from absl import app 3 | from absl import flags 4 | 5 | from jax import numpy as jnp 6 | import jaxopt 7 | import dex 8 | from dex.interop import jax as djax 9 | from sklearn import datasets 10 | 11 | import time 12 | 13 | FLAGS = flags.FLAGS 14 | 15 | flags.DEFINE_integer("maxiter", default=30, help="Max # of iterations.") 16 | flags.DEFINE_integer("maxls", default=15, help="Max # of linesearch iterations.") 17 | flags.DEFINE_float("tol", default=1e-3, help="Tolerance of the stopping criterion.") 18 | flags.DEFINE_integer("n_samples", default=1000, help="Number of samples.") 19 | flags.DEFINE_integer("n_features", default=20, help="Number of features.") 20 | flags.DEFINE_integer("n_classes", default=5, help="Number of classes.") 21 | flags.DEFINE_string("task", "binary_logreg", "Task to benchmark.") 22 | 23 | 24 | def multiclass_logreg_jaxopt(X, y): 25 | data = (X, y) 26 | fun = jaxopt.objective.multiclass_logreg 27 | init = jnp.zeros((X.shape[1], FLAGS.n_classes)) 28 | bfgs = jaxopt.BFGS( 29 | fun=fun, 30 | linesearch='zoom', 31 | maxiter=FLAGS.maxiter, 32 | maxls=FLAGS.maxls, 33 | tol=FLAGS.tol) 34 | 35 | start_time = time.time() 36 | _ = bfgs.run(init_params=init, data=data) 37 | compile_time = time.time() 38 | 39 | _, state = bfgs.run(init_params=init, data=data) 40 | run_time = time.time() 41 | 42 | return compile_time - start_time, run_time - compile_time, state.error, state.iter_num, state.value 43 | 44 | 45 | def main(argv): 46 | # Compare performance of Jaxopt and Dex BFGS on a multiclass logistic regression problem. 
47 | X, y = datasets.make_classification(n_samples=FLAGS.n_samples, 48 | n_features=FLAGS.n_features, 49 | n_classes=FLAGS.n_classes, 50 | n_informative=FLAGS.n_classes, 51 | random_state=0) 52 | time_incl_jit, time_excl_jit, _, _, dex_value = multiclass_logreg_jaxopt(X, y) 53 | print(f"> Jaxopt results:\n Time incl JIT: {time_incl_jit}\n" 54 | f" Time excl JIT: {time_excl_jit}\n Loss function value: {dex_value}") 55 | 56 | with open('examples/bfgs.dx', 'r') as f: 57 | m = dex.Module(f.read()) 58 | dex_bfgs = djax.primitive(m.multiclass_logreg_int) 59 | 60 | start_time = time.time() 61 | dex_value = dex_bfgs( 62 | jnp.array(X), 63 | jnp.array(y), 64 | FLAGS.n_classes, 65 | FLAGS.maxiter, 66 | FLAGS.maxls, 67 | FLAGS.tol) 68 | print(f"> Dex results:\n Total time: {time.time() - start_time}\n" 69 | f" Loss function value: {dex_value}") 70 | 71 | 72 | if __name__ == '__main__': 73 | app.run(main) 74 | -------------------------------------------------------------------------------- /benchmarks/conv.dx: -------------------------------------------------------------------------------- 1 | '# Diagonal convolution 2 | 3 | 'This computes a diagonally-indexed summation: 4 | ``` 5 | result.i.j = input.(i-1).(j-1) + input.i.j + input.(i+1).(j+1) 6 | ``` 7 | This computation is interesting because it occurs in the inner 8 | loop of computing the Neural Tangent Kernel of a convolutional 9 | layer. 10 | 11 | def unsafe_from_integer(i:Int) -> n given (n|Ix) = 12 | unsafe_from_ordinal $ unsafe_i_to_n i 13 | 14 | def conv_1d( 15 | kernel: (Fin d1)=>(Fin d2)=>Float, 16 | size: Nat) 17 | -> (Fin d1)=>(Fin d2)=>Float given (d1, d2) = 18 | half_kernel_size = (f_to_i $ (n_to_f size) / 2.0) 19 | for i j. sum for k: (Fin size). 
20 | i' = n_to_i $ ordinal i 21 | j' = n_to_i $ ordinal j 22 | k' = n_to_i $ ordinal k 23 | i'' = i' + k' - half_kernel_size 24 | j'' = j' + k' - half_kernel_size 25 | if i'' < 0 || i'' >= (n_to_i d1) || j'' < 0 || j'' >= (n_to_i d2) 26 | then 0 27 | else kernel[unsafe_from_integer i'', unsafe_from_integer j''] 28 | 29 | def conv( 30 | kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float, 31 | size: Int) 32 | -> (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float given (n, c, h, w) = 33 | for n' c'. conv_1d(kernel[n', c'], unsafe_i_to_n(size)) 34 | 35 | def conv_spec( 36 | kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float, 37 | size: Int) 38 | -> (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float given (n, c, h, w) = 39 | if size == 3 40 | then conv kernel 3 41 | else conv kernel size 42 | 43 | 'We benchmark it on a roughly representative input. 44 | 45 | width = 3 46 | side = 32 47 | n = 100 48 | 49 | x1 = for i:(Fin n) m:(Fin width) j:(Fin side) k:(Fin side). 50 | randn (ixkey (new_key 0) (i, m, j, k)) 51 | 52 | :t x1 53 | 54 | filter_size = +3 55 | 56 | %bench "Diagonal convolution" 57 | res = conv x1 filter_size 58 | 59 | :t res 60 | -------------------------------------------------------------------------------- /benchmarks/conv_py.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import dex 3 | from dex.interop import jax as djax 4 | import numpy as np 5 | 6 | import time 7 | import timeit 8 | 9 | def bench_python(f, loops=None): 10 | """Return average runtime of `f` in seconds and number of iterations used.""" 11 | if loops is None: 12 | f() 13 | s = time.perf_counter() 14 | f() 15 | e = time.perf_counter() 16 | duration = e - s 17 | loops = max(4, int(2 / duration)) # aim for 2s 18 | return (timeit.timeit(f, number=loops, globals=globals()) / loops, loops) 19 | 20 | 21 | def main(): 22 | with open('benchmarks/conv.dx', 'r') as f: 23 | m = dex.Module(f.read()) 24 | dex_conv = djax.primitive(m.conv_spec) 25 | shp = (int(m.n), 
int(m.width), int(m.side), int(m.side)) 26 | xs = jax.random.normal(jax.random.PRNGKey(1), shp, dtype=jax.numpy.float32) 27 | filter_size = int(m.filter_size) 28 | msg = ("TODO Make dex.interop.primitive return Jax Device Arrays, " 29 | "and change this assert to a block_until_ready() call.") 30 | assert isinstance(dex_conv(xs, filter_size), np.ndarray), msg 31 | time_s, loops = bench_python(lambda : dex_conv(xs, filter_size)) 32 | print(f"> Run time: {time_s} s \t(based on {loops} runs)") 33 | 34 | 35 | if __name__ == '__main__': 36 | main() 37 | -------------------------------------------------------------------------------- /benchmarks/fused_sum.dx: -------------------------------------------------------------------------------- 1 | n = if dex_test_mode() then 1000 else 1000000 2 | 3 | %bench "fused-sum" 4 | sum $ for i:(Fin n). 5 | x = n_to_i64 (ordinal i) 6 | x * x 7 | > 332833500 8 | > 9 | > fused-sum 10 | > Compile time: 22.192 ms 11 | > Run time: 5.186 us (based on 1 run) 12 | -------------------------------------------------------------------------------- /benchmarks/gaussian.dx: -------------------------------------------------------------------------------- 1 | n = if dex_test_mode() then 100 else 1000 * 1000 2 | 3 | %bench "Gaussian" 4 | res = rand_vec n randn (new_key 0) 5 | > 6 | > Gaussian 7 | > Compile time: 40.102 ms 8 | > Run time: 8.517 us (based on 1 run) 9 | -------------------------------------------------------------------------------- /benchmarks/jvp_matmul.dx: -------------------------------------------------------------------------------- 1 | n = if dex_test_mode() then 10 else 500 2 | 3 | m1 = rand_mat(n, n, randn, new_key 0) 4 | m2 = rand_mat(n, n, randn, new_key 1) 5 | 6 | def mmp'(m1:l=>m=>Float, m2:m=>n=>Float) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) = 7 | jvp (\m. 
m1 ** m) m2 m2 8 | 9 | %bench "jvp_matmul" 10 | res = mmp'(m1, m2) 11 | > 12 | > jvp_matmul 13 | > Compile time: 82.255 ms 14 | > Run time: 7.298 us (based on 1 run) 15 | -------------------------------------------------------------------------------- /benchmarks/matmul_big.dx: -------------------------------------------------------------------------------- 1 | n = if dex_test_mode() then 10 else 500 2 | 3 | m1 = rand_mat n n randn (new_key 0) 4 | m2 = rand_mat n n randn (new_key 1) 5 | 6 | %bench "matmul_big" 7 | res = m1 ** m2 8 | > 9 | > matmul_big 10 | > Compile time: 27.431 ms 11 | > Run time: 4.406 us (based on 1 run) 12 | -------------------------------------------------------------------------------- /benchmarks/matmul_small.dx: -------------------------------------------------------------------------------- 1 | n = 10 2 | width = 1000 3 | 4 | m1 = for i:(Fin width). rand_mat(n, n, randn, new_key 0) 5 | m2 = for i:(Fin width). rand_mat(n, n, randn, new_key 1) 6 | 7 | %bench "matmul_small" 8 | res = for i. (m1[i] ** m2[i]) 9 | > 10 | > matmul_small 11 | > Compile time: 33.241 ms 12 | > Run time: 358.211 us (based on 1 run) 13 | 14 | -------------------------------------------------------------------------------- /benchmarks/matvec_big.dx: -------------------------------------------------------------------------------- 1 | n = if dex_test_mode() then 10 else 10000 2 | 3 | m = rand_mat n n randn (new_key 0) 4 | v = rand_vec n randn (new_key 1) 5 | 6 | %bench "matvec_big" 7 | res = m **. v 8 | > 9 | > matvec_big 10 | > Compile time: 23.898 ms 11 | > Run time: 4.682 us (based on 1 run) 12 | -------------------------------------------------------------------------------- /benchmarks/matvec_small.dx: -------------------------------------------------------------------------------- 1 | n = 10 2 | width = 10000 3 | 4 | ms = for i:(Fin width). rand_mat(n, n, randn, new_key 0) 5 | vs = for i:(Fin width). 
rand_vec(n, randn, new_key 1) 6 | 7 | %bench "matvec_small" 8 | res = for i. ms[i] **. vs[i] 9 | > 10 | > matvec_small 11 | > Compile time: 29.506 ms 12 | > Run time: 387.630 us (based on 1 run) 13 | -------------------------------------------------------------------------------- /benchmarks/parboil/histogram.dx: -------------------------------------------------------------------------------- 1 | def (+^) (x: Int) (y: Int) : Int = min (x + y) 255 2 | 3 | def histogram (hist_size: Int) (input : h=>w=>Int) : (Fin hist_size)=>Int = 4 | snd $ with_accum \hist. 5 | for i j. 6 | pos = input.i.j 7 | case 0 <= pos && pos < hist_size of 8 | True -> hist!(unsafe_from_ordinal _ pos) += 1 9 | False -> () 10 | -------------------------------------------------------------------------------- /benchmarks/parboil/mriq.dx: -------------------------------------------------------------------------------- 1 | -- def mriq 2 | -- (kx : ks=>Float) (ky : ks=>Float) (kz : ks=>Float) 3 | -- (x : cs=>Float) (y : cs=>Float) (z : cs=>Float) 4 | -- (phiR : ks=>Float) (phiI : ks=>Float) 5 | -- : (cs=>Float & cs=>Float) = 6 | -- phiMags = for i. phiR.i * phiR.i + phiI.i * phiI.i 7 | -- expArgs = for i:cs j:ks. 2.0 * pi * (kx.j * x.i + ky.j * y.i + kz.j * z.i) 8 | -- qr = for i. sum $ for j. cos $ expArgs.i.j * phiMags.j 9 | -- qi = for i. sum $ for j. sin $ expArgs.i.j * phiMags.j 10 | -- (qr, qi) 11 | 12 | def mriq 13 | (kx : ks=>Float) (ky : ks=>Float) (kz : ks=>Float) 14 | (x : cs=>Float) (y : cs=>Float) (z : cs=>Float) 15 | (phiR : ks=>Float) (phiI : ks=>Float) 16 | : (cs=>Float & cs=>Float) = 17 | unzip $ for i. 18 | run_accum (AddMonoid Float) \qi. 19 | yield_accum (AddMonoid Float) \qr. 20 | for j. 
21 | phiMag = phiR.j * phiR.j + phiI.j * phiI.j 22 | expArg = kx.j * x.i + ky.j * y.i + kz.j * z.i 23 | t = 2.0 * pi * expArg * phiMag 24 | qr += cos t 25 | qi += sin t 26 | -------------------------------------------------------------------------------- /benchmarks/parboil/stencil.dx: -------------------------------------------------------------------------------- 1 | -- Assumes that off >= 0 2 | def (+|) (i:n) (off:Int) : n = 3 | newOrd = ordinal i + off 4 | case newOrd < size n of 5 | True -> unsafe_from_ordinal _ newOrd 6 | False -> i 7 | 8 | -- Assumes that off >= 0 9 | def (-|) (i:n) (off:Int) : n = 10 | newOrd = ordinal i - off 11 | case 0 <= newOrd of 12 | True -> unsafe_from_ordinal _ newOrd 13 | False -> i 14 | 15 | def stencil (input : nx=>ny=>nz=>Float) : nx=>ny=>nz=>Float = 16 | c0 = 1.0 / 6.0 17 | c1 = c0 * c0 18 | (xs, ys, zs) = (size nx, size ny, size nz) 19 | for x y z. 20 | (xi, yi, zi) = (ordinal x, ordinal y, ordinal z) 21 | case xi == 0 || xi == (xs-1) || yi == 0 || yi == (ys-1) || zi == 0 || zi == (zs-1) of 22 | True -> input.x.y.z 23 | False -> 24 | neigh = (input.x.y.(z -| 1) + input.x.y.(z +| 1) + 25 | input.x.(y -| 1).z + input.x.(y +| 1).z + 26 | input.(x -| 1).y.z + input.(x +| 1).y.z) 27 | input.x.y.z * c0 + neigh * c1 28 | -------------------------------------------------------------------------------- /benchmarks/poly.dx: -------------------------------------------------------------------------------- 1 | n = if dex_test_mode() then 1000 else 100000 2 | 3 | a = for i:(Fin n). n_to_f $ ordinal i 4 | 5 | %bench "poly" 6 | res = for i. 
evalpoly [0.0, 1.0, 2.0, 3.0, 4.0] a[i] 7 | > 8 | > poly 9 | > Compile time: 44.950 ms 10 | > Run time: 11.224 us (based on 1 run) 11 | -------------------------------------------------------------------------------- /benchmarks/rodinia/README.md: -------------------------------------------------------------------------------- 1 | # Dex implementation of the Rodinia benchmark suite 2 | 3 | Implementation of each benchmark can be found in the `XYZ.dx` file in this directory. 4 | They don't contain any IO code, so running them will only type-check the definitions. 5 | All runnable scripts are generated in the process of running `python prepare-executables.py`. 6 | Note that the Python script assumes that you have the Rodinia suite downloaded and placed in the `rodinia/` subdirectory of this one. 7 | The original benchmark suite is necessary to retrieve the standard example inputs. 8 | -------------------------------------------------------------------------------- /benchmarks/rodinia/backprop.dx: -------------------------------------------------------------------------------- 1 | ETA = 0.3 2 | MOMENTUM = 0.3 3 | 4 | def squash (x : Float) : Float = 1.0 / (1.0 + exp (-x)) 5 | 6 | def layerForward 7 | (input : in=>Float) 8 | (params : { b: Unit| w: in }=>out=>Float) 9 | : out=>Float = 10 | bias = params.{| b=() |} 11 | total = (sum $ for i:in j:out. params.{| w=i |}.j * input.i) + bias 12 | for i. squash total.i 13 | 14 | def adjustWeights 15 | (delta : out=>Float) 16 | (input : in=>Float) 17 | (weight : { b: Unit | w: in }=>out=>Float) 18 | (oldWeight : { b: Unit | w: in }=>out=>Float) 19 | : ({ b: Unit | w: in }=>out=>Float & { b: Unit | w: in }=>out=>Float) = 20 | weight' = for k:{ b: Unit | w: in } j:out. 
21 | i = case k of 22 | {| b=() |} -> 1.0 23 | {| w=k' |} -> input.k' 24 | d = ETA * delta.j * i + (MOMENTUM * oldWeight.k.j) 25 | weight.k.j + d 26 | (weight', weight) 27 | 28 | def outputError 29 | (target : out=>Float) 30 | (output : out=>Float) 31 | : (Float & out=>Float) = 32 | swap $ run_accum (AddMonoid Float) \err. 33 | for i. 34 | o = output.i 35 | d = o * (1.0 - o) * (target.i - o) 36 | err += abs d 37 | d 38 | 39 | def hiddenError 40 | (outputDelta : out=>Float) 41 | (hiddenWeights : { b: Unit | w: hid }=>out=>Float) 42 | (hidden : hid=>Float) 43 | : (Float & hid=>Float) = 44 | swap $ run_accum (AddMonoid Float) \err. 45 | for i:hid. 46 | mult = sum $ for j. outputDelta.j * hiddenWeights.{| w = i |}.j 47 | r = hidden.i * (1.0 - hidden.i) * mult 48 | err += abs r 49 | r 50 | 51 | def backprop 52 | (input : in=>Float) 53 | (target : out=>Float) 54 | (inputWeights : { b: Unit | w: in }=>hid=>Float) 55 | (hiddenWeights : { b: Unit | w: hid }=>out=>Float) 56 | (oldInputWeights : { b: Unit | w: in }=>hid=>Float) 57 | (oldHiddenWeights : { b: Unit | w: hid }=>out=>Float) 58 | : ( Float 59 | & Float 60 | & { b: Unit | w: in }=>hid=>Float 61 | & { b: Unit | w: hid }=>out=>Float) = 62 | hidden = layerForward input inputWeights 63 | output = layerForward hidden hiddenWeights 64 | 65 | (outputErr, outputDelta) = outputError target output 66 | (hiddenErr, hiddenDelta) = hiddenError outputDelta hiddenWeights hidden 67 | 68 | (hiddenWeights', oldHiddenWeights') = adjustWeights outputDelta hidden hiddenWeights oldHiddenWeights 69 | (inputWeights', oldInputWeights') = adjustWeights hiddenDelta input inputWeights oldInputWeights 70 | 71 | (outputErr, hiddenErr, inputWeights', hiddenWeights') 72 | -------------------------------------------------------------------------------- /benchmarks/rodinia/backpropad.dx: -------------------------------------------------------------------------------- 1 | ETA = 0.3 2 | MOMENTUM = 0.3 3 | 4 | def squash (x : Float) : Float = 1.0 / 
(1.0 + exp (-x)) 5 | 6 | def layerForward 7 | (input : in=>Float) 8 | (params : { b: Unit| w: in }=>out=>Float) 9 | : out=>Float = 10 | bias = params.{| b=() |} 11 | total = (sum $ for i:in j:out. params.{| w=i |}.j * input.i) + bias 12 | for i. squash total.i 13 | 14 | def lossForward (input : n=>Float) (target : n=>Float) : Float = 15 | sum $ input - target 16 | 17 | def adjustWeights 18 | (gradWeight : { b: Unit | w: in }=>out=>Float) 19 | (weight : { b: Unit | w: in }=>out=>Float) 20 | (oldWeight : { b: Unit | w: in }=>out=>Float) 21 | : ({ b: Unit | w: in }=>out=>Float) = 22 | for k j. 23 | d = ETA * gradWeight.k.j + MOMENTUM * oldWeight.k.j 24 | weight.k.j + d 25 | -- backpropad: one training step driven by automatic differentiation. Takes the gradient of lossForward(layerForward(layerForward input iw) hw) with respect to (iw, hw) at the current weights, then applies adjustWeights to each table. The loss closure must read its own arguments (iw, hw): closing over the outer weights makes it constant in its arguments, so `grad` would yield zeros. 26 | def backpropad 27 | (input : in=>Float) 28 | (target : out=>Float) 29 | (inputWeights : { b: Unit | w: in }=>hid=>Float) 30 | (hiddenWeights : { b: Unit | w: hid }=>out=>Float) 31 | (oldInputWeights : { b: Unit | w: in }=>hid=>Float) 32 | (oldHiddenWeights : { b: Unit | w: hid }=>out=>Float) 33 | : ( { b: Unit | w: in }=>hid=>Float 34 | & { b: Unit | w: hid }=>out=>Float) = 35 | (gradInputWeights, gradHiddenWeights) = 36 | flip grad (inputWeights, hiddenWeights) \(iw, hw). 
37 | hidden = layerForward input iw 38 | output = layerForward hidden hw 39 | lossForward output target 40 | 41 | hiddenWeights' = adjustWeights gradHiddenWeights hiddenWeights oldHiddenWeights 42 | inputWeights' = adjustWeights gradInputWeights inputWeights oldInputWeights 43 | (inputWeights', hiddenWeights') 44 | -------------------------------------------------------------------------------- /benchmarks/rodinia/hotspot.dx: -------------------------------------------------------------------------------- 1 | maxPD = 3000000.0 2 | precision = 0.001 3 | specHeatSI = 1750000.0 4 | kSI = 100.0 5 | 6 | factorChip = 0.5 7 | 8 | tChip = 0.0005 9 | chipHeight = 0.016 10 | chipWidth = 0.016 11 | 12 | tAmb = 80.0 13 | 14 | -- Assumes that off >= 0 15 | def (+|) (i:n) (off:Int) : n = 16 | newOrd = ordinal i + off 17 | case newOrd < size n of 18 | True -> unsafe_from_ordinal _ newOrd 19 | False -> i 20 | 21 | -- Assumes that off >= 0 22 | def (-|) (i:n) (off:Int) : n = 23 | newOrd = ordinal i - off 24 | case 0 <= newOrd of 25 | True -> unsafe_from_ordinal _ newOrd 26 | False -> i 27 | 28 | def hotspot 29 | (numIterations: Int) 30 | (tsInit : r=>c=>Float) 31 | (p : r=>c=>Float) 32 | : r=>c=>Float = 33 | gridHeight = chipHeight / (i_to_f $ size r) 34 | gridWidth = chipWidth / (i_to_f $ size c) 35 | cap = factorChip * specHeatSI * tChip * gridWidth * gridHeight 36 | Rx = gridWidth / (2.0 * kSI * tChip * gridHeight) 37 | Ry = gridHeight / (2.0 * kSI * tChip * gridWidth ) 38 | Rz = tChip / (kSI * gridHeight * gridWidth) 39 | maxSlope = maxPD / (factorChip * tChip * specHeatSI) 40 | step = precision / maxSlope 41 | yield_state tsInit $ \tsRef. 42 | for _:(Fin numIterations). 43 | ts = get tsRef 44 | tsRef := for r c. 
45 | t = ts.r.c 46 | dc = (ts.r.(c +| 1) + ts.r.(c -| 1) - 2.0 * t) / Rx 47 | dr = (ts.(r +| 1).c + ts.(r -| 1).c - 2.0 * t) / Ry 48 | d = (step / cap) * (p.r.c + dc + dr + (tAmb - t) / Rz) 49 | t + d 50 | -------------------------------------------------------------------------------- /benchmarks/rodinia/kmeans.dx: -------------------------------------------------------------------------------- 1 | def dist (x : d=>Float) (y : d=>Float) : Float = 2 | d = x - y 3 | sum $ for i. d.i * d.i 4 | 5 | def centroidsOf (points : n=>d=>Float) (membership : n=>k) : k=>d=>Float = 6 | clusterSums = yield_accum (AddMonoid Float) \clusterSums. 7 | for i. clusterSums!(membership.i) += points.i 8 | clusterSizes = yield_accum (AddMonoid Float) \clusterSizes. 9 | for i. clusterSizes!(membership.i) += 1.0 10 | for i. clusterSums.i / (max clusterSizes.i 1.0) 11 | 12 | def argminBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : n = 13 | minimum_by (\i. f xs.i) (for i. i) 14 | 15 | def kmeans 16 | (points : n=>d=>Float) 17 | (k : Int) 18 | (threshold : Int) 19 | (maxIterations : Int) 20 | : (Fin k)=>d=>Float = 21 | initCentroids = for i:(Fin k). points.(ordinal i@_) 22 | initMembership = for c:n. ((ordinal c `mod` k)@_) 23 | final = yield_state (initMembership, initCentroids, 0) \ref. 24 | while do 25 | (membership, centroids, i) = get ref 26 | membership' = for i. argminBy (dist points.i) centroids 27 | centroids' = centroidsOf points membership' 28 | delta = sum $ for i. 
b_to_i $ membership.i /= membership'.i 29 | ref := (membership', centroids', i + 1) 30 | delta > threshold && i < maxIterations 31 | (_, centroids, _) = final 32 | centroids 33 | -------------------------------------------------------------------------------- /benchmarks/rodinia/pathfinder.dx: -------------------------------------------------------------------------------- 1 | -- Assumes that off >= 0 2 | def (+|) (i:n) (off:Int) : n = 3 | newOrd = ordinal i + off 4 | case newOrd < size n of 5 | True -> unsafe_from_ordinal _ newOrd 6 | False -> i 7 | 8 | -- Assumes that off >= 0 9 | def (-|) (i:n) (off:Int) : n = 10 | newOrd = ordinal i - off 11 | case 0 <= newOrd of 12 | True -> unsafe_from_ordinal _ newOrd 13 | False -> i 14 | 15 | def pathfinder (world : rows=>cols=>Int) : cols=>Int = 16 | yield_state zero $ \costsRef. 17 | for r. 18 | costs = get costsRef 19 | costsRef := for c. world.r.c + (min costs.c $ (min costs.(c -| 1) 20 | costs.(c +| 1))) 21 | -------------------------------------------------------------------------------- /benchmarks/vjp_matmul.dx: -------------------------------------------------------------------------------- 1 | n = if dex_test_mode() then 10 else 500 2 | 3 | m1 = rand_mat n n randn (new_key 0) 4 | m2 = rand_mat n n randn (new_key 1) 5 | 6 | def mmp'(m1:l=>m=>Float, m2:m=>n=>Float) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) = 7 | snd(vjp(\m. transpose(m1) ** m, for _ _. 
0.0))(m2) 8 | 9 | %bench "vjp_matmul" 10 | res = mmp'(m1, m2) 11 | > 12 | > vjp_matmul 13 | > Compile time: 130.231 ms 14 | > Run time: 7.178 us (based on 1 run) 15 | -------------------------------------------------------------------------------- /cabal.project: -------------------------------------------------------------------------------- 1 | packages: dex.cabal 2 | 3 | source-repository-package 4 | type: git 5 | location: https://github.com/llvm-hs/llvm-hs 6 | tag: llvm-9 7 | subdir: llvm-hs 8 | 9 | source-repository-package 10 | type: git 11 | location: https://github.com/llvm-hs/llvm-hs 12 | tag: llvm-9 13 | subdir: llvm-hs-pure 14 | -------------------------------------------------------------------------------- /default.nix: -------------------------------------------------------------------------------- 1 | { pkgs ? import {}, 2 | llvm-hs-src ? pkgs.fetchFromGitHub { 3 | owner = "llvm-hs"; 4 | repo = "llvm-hs"; 5 | rev = "llvm-12"; 6 | sha256 = "IG4Mh89bY+PtBJtzlXKYsPljfHP7OSQk03pV6fSmdRY="; 7 | }, 8 | cudaPackage ? pkgs.cudaPackages.cudatoolkit_11, 9 | cuda ? false, 10 | optimized ? true, 11 | live ? 
true, 12 | }: 13 | let 14 | llvm-hs-pure = pkgs.haskellPackages.callCabal2nix "llvm-hs-pure" "${llvm-hs-src}/llvm-hs-pure" { 15 | }; 16 | llvm-hs = (pkgs.haskellPackages.callCabal2nix "llvm-hs" "${llvm-hs-src}/llvm-hs" { 17 | inherit llvm-hs-pure; 18 | }).overrideAttrs (oldAttrs: rec { 19 | buildInputs = oldAttrs.buildInputs ++ [ 20 | pkgs.llvm_12 21 | ]; 22 | }); 23 | buildFlags = pkgs.lib.optionals optimized [ 24 | "-foptimized" 25 | ] ++ pkgs.lib.optionals live [ 26 | "-flive" 27 | ] ++ pkgs.lib.optionals cuda [ 28 | "-fcuda" 29 | "--extra-include-dirs=${cudaPackage}/include" 30 | "--extra-lib-dirs=${cudaPackage}/lib64/stubs" 31 | ]; 32 | cxxFlags = [ 33 | "-fPIC" 34 | "-std=c++11" 35 | "-fno-exceptions" 36 | "-fno-rtti" 37 | ] ++ pkgs.lib.optional cuda "-DDEX_CUDA" 38 | ++ pkgs.lib.optional live "-DDEX_LIVE"; 39 | buildRuntimeCommand = '' 40 | ${pkgs.clang_9}/bin/clang++ \ 41 | ${builtins.concatStringsSep " " cxxFlags} \ 42 | -c \ 43 | -emit-llvm \ 44 | -I${pkgs.libpng}/include \ 45 | src/lib/dexrt.cpp \ 46 | -o src/lib/dexrt.bc 47 | ''; 48 | in 49 | # `callCabal2nix` converts `dex.cabal` into a Nix file and builds it. 50 | # Before we do the Haskell build though, we need to first compile the Dex runtime 51 | # so it's properly linked in when compiling Dex. Normally the makefile does this, 52 | # so we instead sneak compiling the runtime in the configuration phase for the Haskell build. 53 | (pkgs.haskellPackages.callCabal2nix "dex" ./. { 54 | inherit llvm-hs; 55 | inherit llvm-hs-pure; 56 | }).overrideAttrs (attrs: { 57 | configurePhase = '' 58 | # Compile the Dex runtime 59 | echo 'Compiling the Dex runtime...' 60 | set -x 61 | ${buildRuntimeCommand} 62 | set +x 63 | echo 'Done compiling the Dex runtime.' 
64 | 65 | # Run the Haskell configuration phase 66 | ${attrs.configurePhase} 67 | ''; 68 | configureFlags = builtins.concatStringsSep " " buildFlags; 69 | buildInputs = attrs.buildInputs ++ (pkgs.lib.optional cuda 70 | cudaPackage 71 | ); 72 | }) 73 | -------------------------------------------------------------------------------- /doc/conditionals.dx: -------------------------------------------------------------------------------- 1 | '# Syntax of if expressions 2 | 3 | 'The basic syntax of `if` in Dex is 4 | ``` 5 | if <condition> then <consequent> [else <alternative>] 6 | ``` 7 | 8 | 'It can be a bit confusing, though, because of all the tokens it may make sense to indent. 9 | 10 | 'The main rules are: 11 | - The `else` clause is optional (regardless of indentation) 12 | - The `then` and `else` keywords can be inline with the preceding 13 | code, or indented relative to the `if`. 14 | - The code for each arm of the `if` can be either an inline expression 15 | or start a new indentation level (relative to its keyword if that is 16 | indented, or relative to the whole `if` otherwise). 17 | 18 | 'This produces four combinations for one-armed `if`, all of which are legal: 19 | 20 | :p 21 | yield_accum (AddMonoid Float) \ref. 22 | if True then ref += 3. 23 | if True then 24 | ref += 1. 25 | ref += 2. 26 | if True 27 | then ref += 3. 28 | if False 29 | then 30 | ref += 1. 31 | ref += 2. 32 | > 9. 33 | 34 | 'However, not every one of the 16 conceivable combinations makes sense for two-armed `if`. 35 | To wit: 36 | - If the consequent is indented, it makes no sense to have the `else` 37 | inline (eliminating 4 combinations). 38 | - If `then` is inline, there can be no indented `else` either, because 39 | there is no readable level at which to indent it (eliminating 2 more 40 | combinations). 41 | 42 | 'The following contrived code block shows all the acceptable configurations: 43 | 44 | :p 45 | yield_accum (AddMonoid Float) \ref. 46 | -- Two-armed `if` with `then` and the consequent both inline. 
47 | x = if False then 1. else 3. 48 | if False then ref += 100. else 49 | ref += 1. 50 | ref += 2. 51 | if False then ref += 200. 52 | else ref += x 53 | if False then ref += 300. 54 | else 55 | ref += 1. 56 | ref += 2. 57 | 58 | -- Two-armed `if` with `then` indented but the consequent inline. 59 | y = if False 60 | then 1. else 3. 61 | if False 62 | then ref += 100. else 63 | ref += 1. 64 | ref += 2. 65 | if False 66 | then ref += 200. 67 | else ref += y 68 | if False 69 | then ref += 300. 70 | else 71 | ref += 1. 72 | ref += 2. 73 | 74 | -- Two-armed `if` with `then` and the consequent both indented. 75 | if False 76 | then 77 | ref += 100. 78 | ref += 200. 79 | else ref += 3. 80 | if False 81 | then 82 | ref += 100. 83 | ref += 200. 84 | else 85 | ref += 2. 86 | ref += 4. 87 | > 27. 88 | 89 | 'And here are example configurations that don't work, showing the resulting parse errors. 90 | 91 | 'Inline `else` is not allowed after indented consequent, whether the 92 | `then` keyword is indented or not: 93 | 94 | if True 95 | then 96 | x = 6 97 | x else 5 98 | 99 | > Parse error:97:12: 100 | > | 101 | > 97 | x else 5 102 | > | ^ 103 | > Same-line `else` may not follow indented consequent; put the `else` on the next line. 104 | 105 | if True then 106 | x = 6 107 | x else 5 108 | 109 | > Parse error:107:10: 110 | > | 111 | > 107 | x else 5 112 | > | ^ 113 | > No `else` may follow same-line `then` and indented consequent; indent and align both `then` and `else`, or write the whole `if` on one line. 114 | 115 | 'Indented `else` is not allowed after inline `then` and indented 116 | consequent either, because there is no indentation level for it to match. 117 | 118 | :p 119 | if True then 120 | x = 6 121 | x 122 | else 5 123 | 124 | > Parse error:122:8: 125 | > | 126 | > 122 | else 5 127 | > | ^ 128 | > No `else` may follow same-line `then` and indented consequent; indent and align both `then` and `else`, or write the whole `if` on one line. 
129 | 130 | -------------------------------------------------------------------------------- /examples/export/array.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | extern "C" { 7 | extern int addVecF32(int32_t n, float* x, float* y, float* result); 8 | } 9 | 10 | template 11 | void addArrays(std::array& x, std::array& y, std::array& z) { 12 | assert(addVecF32(N, x.data(), y.data(), z.data()) == 0); 13 | } 14 | 15 | int main() { 16 | std::array x = {0, 1, 2, 3, 4}; 17 | std::array y = {1.23, 4.12, 6.21, 9.64, 3.61}; 18 | std::array z; 19 | addArrays(x, y, z); 20 | for (size_t i = 0; i < x.size(); ++i) { 21 | assert(z[i] == x[i] + y[i]); 22 | } 23 | return EXIT_SUCCESS; 24 | } 25 | 26 | -------------------------------------------------------------------------------- /examples/export/array.dx: -------------------------------------------------------------------------------- 1 | 2 | def addVecF32 (n : Int) (x : (Fin n)=>Float) (y : (Fin n)=>Float) : (Fin n)=>Float = 3 | for i. 
x.i + y.i 4 | 5 | :export addVecF32 addVecF32 6 | -------------------------------------------------------------------------------- /examples/export/scalar.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | extern "C" { 7 | extern int addi32(int32_t x, int32_t y, int32_t* result); 8 | extern int addf32(float x, float y, float* result); 9 | } 10 | 11 | int32_t add(int32_t x, int32_t y) { 12 | int32_t result = 0; 13 | assert(addi32(x, y, &result) == 0); 14 | return result; 15 | } 16 | 17 | float add(float x, float y) { 18 | float result = 0; 19 | assert(addf32(x, y, &result) == 0); 20 | return result; 21 | } 22 | 23 | int main() { 24 | std::array i32_examples = {0, 1, 2, 3, 4}; 25 | std::array f32_examples = {1.23, 4.12, 6.21, 9.64, 3.61}; 26 | for (int32_t x : i32_examples) { 27 | for (int32_t y : i32_examples) { 28 | assert(add(x, y) == x + y); 29 | } 30 | } 31 | for (float x : f32_examples) { 32 | for (float y : f32_examples) { 33 | assert(add(x, y) == x + y); 34 | } 35 | } 36 | return EXIT_SUCCESS; 37 | } 38 | -------------------------------------------------------------------------------- /examples/export/scalar.dx: -------------------------------------------------------------------------------- 1 | addi32 : Int32 -> Int32 -> Int32 = (+) 2 | :export addi32 addi32 3 | 4 | addf32 : Float32 -> Float32 -> Float32 = (+) 5 | :export addf32 addf32 6 | -------------------------------------------------------------------------------- /examples/kernelregression.dx: -------------------------------------------------------------------------------- 1 | '# Kernel Regression 2 | 3 | import linalg 4 | import plot 5 | 6 | struct ConjGradState(a|VSpace) = 7 | x : a 8 | r : a 9 | p : a 10 | 11 | # Conjugate gradients solver 12 | def solve'(mat:m=>m=>Float, b:m=>Float) -> m=>Float given (m|Ix) = 13 | x0 = zero 14 | ax = mat **. 
x0 15 | r0 = b - ax 16 | init = ConjGradState(zero, r0, r0) 17 | final = fold init \_:m state. 18 | x = state.x 19 | r = state.r 20 | p = state.p 21 | ap = mat **. p 22 | alpha = vdot r r / vdot p ap 23 | x' = x + alpha .* p 24 | r' = r - alpha .* ap 25 | beta = vdot r' r' / (vdot r r + 0.000001) 26 | p' = r' + beta .* p 27 | ConjGradState(x', r', p') 28 | final.x 29 | 30 | def chol_solve(l:LowerTriMat m Float, b:m=>Float) -> m=>Float given (m|Ix) = 31 | b' = forward_substitute l b 32 | u = transpose_lower_to_upper l 33 | backward_substitute u b' 34 | 35 | ' # Kernel ridge regression 36 | 37 | ' To learn a function $f_{true}: \mathcal{X} \to \mathbb R$ 38 | from data $(x_1, y_1),\dots,(x_N, y_N)\in \mathcal{X}\times\mathbb R$,\ 39 | in kernel ridge regression the hypothesis takes the form 40 | $f(x)=\sum_{i=1}^N \alpha_i k(x_i, x)$,\ 41 | where $k:\mathcal X \times \mathcal X \to \mathbb R$ is a positive semidefinite kernel function.\ 42 | The optimal coefficients are found by solving a linear system $\alpha=G^{-1}y$,\ 43 | where $G_{ij}:=k(x_i, x_j)+\delta_{ij}\lambda$, $\lambda>0$ and $y = (y_1,\dots,y_N)^\top\in\mathbb R^N$ 44 | 45 | # Synthetic data 46 | Nx = Fin 100 47 | noise = 0.1 48 | [k1, k2] = split_key (new_key 0) 49 | 50 | def trueFun(x:Float) -> Float = 51 | x + sin (20.0 * x) 52 | 53 | xs : Nx=>Float = for i. rand (ixkey k1 i) 54 | ys : Nx=>Float = for i. trueFun xs[i] + noise * randn (ixkey k2 i) 55 | 56 | # Kernel ridge regression 57 | def regress(kernel: (a, a) -> Float, xs: Nx=>a, ys: Nx=>Float) -> (a) -> Float given (a) = 58 | gram = for i j. kernel xs[i] xs[j] + select (i==j) 0.0001 0.0 59 | alpha = solve' gram ys 60 | \x. sum for i. alpha[i] * kernel xs[i] x 61 | 62 | def rbf(lengthscale:Float, x:Float, y:Float) -> Float = 63 | exp (-0.5 * sq ((x - y) / lengthscale)) 64 | 65 | predict = regress (\x y. rbf 0.2 x y) xs ys 66 | 67 | # Evaluation 68 | Nxtest = Fin 1000 69 | xtest : Nxtest=>Float = for i. 
rand (ixkey k1 i) 70 | preds = map predict xtest 71 | 72 | :html show_plot $ xy_plot xs ys 73 | > 74 | 75 | :html show_plot $ xy_plot xtest preds 76 | > 77 | 78 | ' # Gaussian process regression 79 | 80 | ' GP regression (kriging) works in a similar way. Compared with kernel ridge regression, GP regression assumes a Gaussian-distributed prior. This, combined 81 | with Bayes' rule, gives the variance of the prediction. 82 | 83 | ' In this implementation, the conjugate gradient solver is replaced with the 84 | Cholesky solver from `lib/linalg.dx` for efficiency. 85 | 86 | def gp_regress( 87 | kernel: (a, a) -> Float, 88 | xs: n=>a, 89 | ys: n=>Float 90 | ) -> ((a) -> (Float, Float)) given (n|Ix, a) = 91 | noise_var = 0.0001 92 | gram = for i j. kernel xs[i] xs[j] 93 | c = chol (gram + eye *. noise_var) 94 | alpha = chol_solve c ys 95 | predict = \x. 96 | k' = for i. kernel xs[i] x 97 | mu = sum for i. alpha[i] * k'[i] 98 | alpha' = chol_solve c k' 99 | var = kernel x x + noise_var - sum for i. k'[i] * alpha'[i] 100 | (mu, var) 101 | predict 102 | 103 | gp_predict = gp_regress (\x y. rbf 0.2 x y) xs ys 104 | 105 | (gp_preds, vars) = unzip (map gp_predict xtest) 106 | 107 | :html show_plot $ xyc_plot xtest gp_preds (map sqrt vars) 108 | > 109 | 110 | :html show_plot $ xy_plot xtest vars 111 | > 112 | -------------------------------------------------------------------------------- /examples/latex.dx: -------------------------------------------------------------------------------- 1 | '# $\href{https://katex.org/}{\KaTeX}$ Rendering Examples 2 | 3 | 'This document demonstrates $\KaTeX$ rendering in literate Dex programs. 
4 | 5 | '## Random examples 6 | 7 | '$$\text{This is a multiline equation:} \\\\ \textbf{A}\textbf{x} = \textbf{b}$$ 8 | 9 | '$$f(\relax{x}) = \int_{-\infty}^\infty \hat{f}(\xi)\,e^{2 \pi i \xi x} \,d\xi$$ 10 | 11 | '## [Environments](https://katex.org/docs/supported.html#environments) 12 | 13 | '$$\begin{matrix} a & b \\\\ c & d \end{matrix}$$ 14 | 15 | '$$\begin{pmatrix} a & b \\\\ c & d \end{pmatrix}$$ 16 | 17 | '$$\begin{bmatrix} a & b \\\\ c & d \end{bmatrix}$$ 18 | 19 | '$$\def\arraystretch{1.5} \begin{array}{c:c:c} a & b & c \\\\ \hline d & e & f \\\\ \hdashline g & h & i \end{array}$$ 20 | 21 | '$$\begin{aligned} a&=b+c \\\\ d+e&=f \end{aligned}$$ 22 | 23 | '$$\begin{alignedat}{2} 10&x+ &3&y = 2 \\\\ 3&x+&13&y = 4 \end{alignedat}$$ 24 | 25 | '$$\begin{gathered} a=b \\\\ e=b+c \end{gathered}$$ 26 | 27 | '$$x = \begin{cases} a &\text{if } b \\\\ c &\text{if } d \end{cases}$$ 28 | 29 | '$$\begin{rcases} a &\text{if } b \\\\ c &\text{if } d \end{rcases} \Rightarrow \dots$$ 30 | 31 | '## No LaTeX rendering in non-prose-blocks 32 | 33 | 'LaTeX rendering should not occur in code blocks, nor in error output cells. 34 | 35 | def array_sum(x:a=>Int32) -> Int32 given (a|Ix) = 36 | # Note: the following line has `$ ... $`, but it should not trigger KaTeX. 37 | # Note: the incorrect usage of `with_state` below is intentional to verify 38 | # that `$ ... $` is not rendered as LaTeX in error output cells. 39 | snd $ with_state 0 \acc. 40 | for i. 41 | acc := (get acc) + x[i] 42 | > Type error: 43 | > Expected: (a.1, b) 44 | > Actual: (a => ()) 45 | > (Solving for: [a.1, b]) 46 | > 47 | > snd $ with_state 0 \acc. 
48 | > ^^^^^^^^^^^^^^^^^^^ 49 | 50 | '## [Layout annotation](https://katex.org/docs/supported.html#annotation) 51 | 52 | '$$\overbrace{a+b+c}^{\text{note}}$$ 53 | 54 | '$$\underbrace{a+b+c}_{\text{note}}$$ 55 | 56 | '$$\xcancel{\text{second-order array combinators}}$$ 57 | 58 | '## [Logic and Set Theory](https://katex.org/docs/supported.html#logic-and-set-theory) 59 | 60 | '$$\begin{aligned} \forall \\; & \texttt{\textbackslash forall} & \complement \\; & \texttt{\textbackslash complement} & \therefore \\; & \texttt{\textbackslash therefore} & \emptyset \\; & \texttt{\textbackslash emptyset} \\\\ \exists \\; & \texttt{\textbackslash exists} & \subset \\; & \texttt{\textbackslash subset} & \because \\; & \texttt{\textbackslash because} & \empty \\; & \texttt{\textbackslash empty} \\\\ \exist \\; & \texttt{\textbackslash exist} & \supset \\; & \texttt{\textbackslash supset} & \mapsto \\; & \texttt{\textbackslash mapsto} & \varnothing \\; & \texttt{\textbackslash varnothing} \\\\ \nexists \\; & \texttt{\textbackslash nexists} & \mid \\; & \texttt{\textbackslash mid} & \to \\; & \texttt{\textbackslash to} & \implies \\; & \texttt{\textbackslash implies} \\\\ \in \\; & \texttt{\textbackslash in} & \land \\; & \texttt{\textbackslash land} & \gets \\; & \texttt{\textbackslash gets} & \impliedby \\; & \texttt{\textbackslash impliedby} \\\\ \isin \\; & \texttt{\textbackslash isin} & \lor \\; & \texttt{\textbackslash lor} & \leftrightarrow \\; & \texttt{\textbackslash leftrightarrow} & \iff \\; & \texttt{\textbackslash iff} \\\\ \notin \\; & \texttt{\textbackslash notin} & \ni \\; & \texttt{\textbackslash ni} & \notni \\; & \texttt{\textbackslash notni} & \neg \\; & \texttt{\textbackslash neg} \\\\ \lnot \\; & \texttt{\textbackslash lnot} \\\\ \end{aligned}$$ 61 | -------------------------------------------------------------------------------- /examples/levenshtein-distance.dx: -------------------------------------------------------------------------------- 1 | '# 
Levenshtein Distance 2 | 3 | 'May 13, 2022 4 | 5 | 'The Levenshtein distance is the minimum edit distance between two 6 | sequences, counting insertions, deletions, and substitutions each as 1. 7 | 8 | 'Let's see how well Dex handles the conventional dynamic program for 9 | this: For each pair of prefixes of the input strings, the Levenshtein 10 | distance between those prefixes is the minimum obtainable from the 11 | prefixes one shorter by inserting, deleting, or substituting an 12 | element. 13 | 14 | 'Here is the helper that builds the dynamic program table. Dex's 15 | flexible index sets let us encode the fact that the table is 1 larger 16 | in each dimension than the inputs. By capturing the relationship 17 | statically we avoid both programmer off-by-one errors and runtime array 18 | bounds checks. 19 | 20 | def levenshtein_table(xs: n=>a, ys: m=>a) 21 | -> (Post n => Post m => Nat) given (n|Ix, m|Ix, a|Eq) = 22 | yield_state (for _ _. 0) \tab. 23 | for i:(Post n). tab!i!first_ix := ordinal i 24 | for j:(Post m). tab!first_ix!j := ordinal j 25 | for i:n j:m. 26 | subst_cost = if xs[i] == ys[j] then 0 else 1 27 | d_subst = get tab!(left_post i)!(left_post j) + subst_cost 28 | d_delete = get tab!(left_post i)!(right_post j) + 1 29 | d_insert = get tab!(right_post i)!(left_post j) + 1 30 | tab!(right_post i)!(right_post j) := 31 | minimum [d_subst, d_delete, d_insert] 32 | 33 | %time 34 | levenshtein_table ['k', 'i', 't', 't', 'e', 'n'] ['s', 'i', 't', 't', 'i', 'n', 'g'] 35 | > [[0, 1, 2, 3, 4, 5, 6, 7], [1, 1, 2, 3, 4, 5, 6, 7], [2, 2, 1, 2, 3, 4, 5, 6], [3, 3, 2, 1, 2, 3, 4, 5], [4, 4, 3, 2, 1, 2, 3, 4], [5, 5, 4, 3, 2, 2, 3, 4], [6, 6, 5, 4, 3, 3, 2, 3]] 36 | > 37 | > Compile time: 77.001 ms 38 | > Run time: 6.386 us 39 | 40 | 'The actual distance is of course just the last element of the table. 
41 | 42 | def levenshtein(xs: n=>a, ys: m=>a) -> Nat given (n|Ix, m|Ix, a|Eq) = 43 | levenshtein_table(xs, ys)[last_ix, last_ix] 44 | 45 | %time 46 | levenshtein (iota $ Fin 100) (iota $ Fin 100) 47 | > 0 48 | > 49 | > Compile time: 55.368 ms 50 | > Run time: 19.665 us 51 | 52 | '## Speed 53 | 54 | 'To check that we don't embarrass ourselves on performance, let's run 55 | the Sountsov benchmark: Compute Levenshtein distances for all pairs of 56 | `arange` arrays of size up to 100. 57 | 58 | %bench "Sountsov Benchmark" 59 | answer = for i:(Fin 100). 60 | for j:(Fin 100). 61 | iint = ordinal i 62 | jint = ordinal j 63 | levenshtein (iota $ Fin iint) (iota $ Fin jint) 64 | > 65 | > Sountsov Benchmark 66 | > Compile time: 81.135 ms 67 | > Run time: 79.545 ms (based on 26 runs) 68 | sum(sum(answer)) 69 | > 333300 70 | 71 | 'The straightforward C++ program for this takes about 35ms on my 72 | workstation, so Dex performance is in the right ballpark. (And we 73 | know several optimizations that should let us close the gap.) As of 74 | this writing, native JAX takes 15 minutes, due to tracing and 75 | compiling the body 10,000 times (once for each pair of input sizes). 76 | 77 | '## Real Data 78 | 79 | 'Just for fun, we can make a crude spelling correcter out of this 80 | distance function on words: 81 | 82 | AsList(ct, words) = lines $ unsafe_io \. read_file "/usr/share/dict/words" 83 | 84 | def closest_word(s:String) -> String = 85 | AsList(_, s') = s 86 | fst $ minimum_by snd for i. 
87 | AsList(_, word) = words[i] 88 | (words[i], levenshtein word s') 89 | 90 | %time 91 | closest_word "hello" 92 | > "hello" 93 | > 94 | > Compile time: 126.402 ms 95 | > Run time: 73.713 ms 96 | 97 | closest_word "kitttens" 98 | > "kittens" 99 | closest_word "functor" 100 | > "function" 101 | closest_word "applicative" 102 | > "application" 103 | closest_word "monoids" 104 | > "ovoids" 105 | closest_word "semigroup" 106 | > "subgroup" 107 | closest_word "paralllel" 108 | > "parallel" 109 | -------------------------------------------------------------------------------- /examples/mandelbrot.dx: -------------------------------------------------------------------------------- 1 | '# Mandelbrot Set 2 | 3 | import complex 4 | import plot 5 | 6 | # Escape time algorithm 7 | 8 | def update(c:Complex, z:Complex) -> Complex = c + (z * z) 9 | 10 | tol = 2.0 11 | def inBounds(z:Complex) -> Bool = complex_abs(z) < tol 12 | 13 | def escapeTime(c:Complex) -> Nat = 14 | z <- with_state(zero :: Complex) 15 | bounded_iter(1000, 1000) \i. 16 | case inBounds(get(z)) of 17 | False -> Done(i) 18 | True -> 19 | z := update(c, get(z)) 20 | Continue 21 | 22 | # Evaluate on a grid and plot the results 23 | 24 | xs = linspace(Fin 300, -2.0, 1.0) 25 | ys = linspace(Fin 200, -1.0, 1.0) 26 | 27 | escapeGrid = each(ys) \y. each xs \x. n_to_f(escapeTime(Complex(x, y))) 28 | 29 | :html matshow(-escapeGrid) 30 | > 31 | -------------------------------------------------------------------------------- /examples/mcmc.dx: -------------------------------------------------------------------------------- 1 | '# Markov Chain Monte Carlo 2 | 3 | '## General MCMC utilities 4 | 5 | import plot 6 | 7 | LogProb : Type = Float 8 | 9 | def runChain( 10 | initialize: (Key) -> a, 11 | step: (Key, a) -> a, 12 | numSamples: Nat, 13 | k:Key 14 | ) -> Fin numSamples => a given (a|Data) = 15 | [k1, k2] = split_key(n=2, k) 16 | with_state (initialize k1) \s. 17 | for i:(Fin numSamples). 
18 | x = step (ixkey k2 i) (get s) 19 | s := x 20 | x 21 | 22 | def propose( 23 | logDensity: (a) -> LogProb, 24 | cur: a, 25 | proposal: a, 26 | k: Key 27 | ) -> a given (a:Type) = 28 | accept = logDensity proposal > (logDensity cur + log (rand k)) 29 | select accept proposal cur 30 | 31 | def meanAndCovariance(xs:n=>d=>Float) -> (d=>Float, d=>d=>Float) given (n|Ix, d|Ix) = 32 | xsMean : d=>Float = (for i:d. sum for j:n. xs[j,i]) / n_to_f (size n) 33 | xsCov : d=>d=>Float = (for i:d i':d. sum for j:n. 34 | (xs[j,i'] - xsMean[i']) * 35 | (xs[j,i ] - xsMean[i ]) ) / (n_to_f (size n) - 1) 36 | (xsMean, xsCov) 37 | 38 | '## Metropolis-Hastings implementation 39 | 40 | MHParams : Type = Float # step size 41 | 42 | def mhStep( 43 | stepSize: MHParams, 44 | logProb: (d=>Float) -> LogProb, 45 | k:Key, 46 | x:d=>Float 47 | ) -> d=>Float given (d|Ix) = 48 | [k1, k2] = split_key(n=2, k) 49 | proposal = x + stepSize .* randn_vec k1 50 | propose logProb x proposal k2 51 | 52 | '## HMC implementation 53 | 54 | struct HMCParams = 55 | nsteps : Nat 56 | dt : Float 57 | 58 | struct HMCState(a|VSpace) = 59 | x: a 60 | p: a 61 | 62 | def leapfrogIntegrate( 63 | params: HMCParams, 64 | logProb: (a) -> LogProb, 65 | init: HMCState a 66 | ) -> HMCState a given (a|VSpace) = 67 | x = init.x + (0.5 * params.dt) .* init.p 68 | final = apply_n params.nsteps HMCState(x, init.p) \old. 
69 | pNew = old.p + params.dt .* grad logProb old.x 70 | xNew = old.x + params.dt .* pNew 71 | HMCState(xNew, pNew) 72 | p = final.p + (0.5 * params.dt) .* grad logProb final.x 73 | HMCState(final.x, p) 74 | 75 | def hmcStep( 76 | params: HMCParams, 77 | logProb: (d=>Float) -> LogProb, 78 | k: Key, 79 | x: d=>Float 80 | ) -> d=>Float given (d|Ix) = 81 | def hamiltonian(s:HMCState (d=>Float)) -> Float = 82 | logProb s.x - 0.5 * vdot s.p s.p 83 | [k1, k2] = split_key(n=2, k) 84 | p = randn_vec k1 :: d => Float 85 | proposal = leapfrogIntegrate params logProb HMCState(x, p) 86 | final = propose hamiltonian HMCState(x, p) proposal k2 87 | final.x 88 | 89 | '## Test it out 90 | 91 | 'Generate samples from a multivariate normal distribution N([1.5, 2.5], [[1., 0.], [0., 0.05]]). 92 | 93 | def myLogProb(x:(Fin 2)=>Float) -> LogProb = 94 | x' = x - [1.5, 2.5] 95 | neg $ 0.5 * inner x' [[1.,0.],[0.,20.]] x' 96 | def myInitializer(k:Key) -> Fin 2 => Float = 97 | randn_vec(k) 98 | 99 | numSamples : Nat = 100 | if dex_test_mode() 101 | then 1000 102 | else 10000 103 | k0 = new_key 1 104 | 105 | mhParams = 0.1 106 | mhSamples = runChain myInitializer (\k x. mhStep mhParams myLogProb k x) numSamples k0 107 | 108 | :p meanAndCovariance mhSamples 109 | > ([0.5455918, 2.522631], [[0.3552593, 0.05022133], [0.05022133, 0.08734216]]) 110 | 111 | :html show_plot $ y_plot $ 112 | slice (each mhSamples head) 0 (Fin 1000) 113 | > 114 | 115 | hmcParams = HMCParams(10, 0.1) 116 | hmcSamples = runChain myInitializer (\k x. 
hmcStep hmcParams myLogProb k x) numSamples k0 117 | 118 | :p meanAndCovariance hmcSamples 119 | > ([1.472011, 2.483082], [[1.054705, -0.002082013], [-0.002082013, 0.05058844]]) 120 | 121 | :html show_plot $ y_plot $ 122 | slice (each hmcSamples head) 0 (Fin 1000) 123 | > 124 | -------------------------------------------------------------------------------- /examples/mnist-nearest-neighbors.dx: -------------------------------------------------------------------------------- 1 | '# THIS FILE IS STALE 2 | 3 | '(But we plan to update it at some point) 4 | 5 | load dxbo "scratch/mnist.dxbo" as mnist 6 | 7 | :t mnist 8 | 9 | # TODO: these should come from the data set itself 10 | type Img = 28=>28=>Float 11 | type NTrain = 60000 12 | type NTest = 10000 13 | 14 | (xsTrain, ysTrain, xsTest, ysTest) = mnist 15 | 16 | findNearestNeighbor : (a -> a -> Float) -> n=>a -> a -> n 17 | findNearestNeighbor metric xs x = 18 | distances.i = metric xs.i x 19 | argmin distances 20 | 21 | imgDistance : Img -> Img -> Float 22 | imgDistance x y = sum for (i,j). sq (x.i.j - y.i.j) 23 | 24 | fracTrue : n=>Bool -> Float 25 | fracTrue xs = mean for i. float (b2i xs.i) 26 | 27 | example = asidx @NTest 123 28 | :plotmat xsTest.example 29 | 30 | # look at closest match in the training set 31 | closest = findNearestNeighbor imgDistance xsTrain (xsTest.example) 32 | :plotmat xsTrain.closest 33 | 34 | # Make a subset of the test set (evaluating a single test example takes ~80ms) 35 | type NTestSmall = 1000 36 | xsTest' = slice @NTestSmall xsTest 0 37 | ysTest' = slice @NTestSmall ysTest 0 38 | 39 | closestTrainExample : NTestSmall => NTrain 40 | closestTrainExample.i = findNearestNeighbor imgDistance xsTrain xsTest'.i 41 | 42 | :p fracTrue for i. 
ysTrain.(closestTrainExample.i) == ysTest'.i 43 | -------------------------------------------------------------------------------- /examples/particle-filter.dx: -------------------------------------------------------------------------------- 1 | '# Particle Filter 2 | 3 | struct Distribution(range:Type) = 4 | sample : (Key) -> range 5 | logprob : (range) -> Float 6 | 7 | struct Model(state:Type, observation:Type) = 8 | init : Distribution state 9 | dynamics : (state) -> Distribution state 10 | observe : (state) -> Distribution observation 11 | 12 | def simulate(model: Model s v, t: Nat, key: Key) -> Fin t=>(s, v) given (s|Data, v) = 13 | [key, subkey] = split_key key 14 | s0 = model.init.sample(subkey) 15 | with_state s0 \s_ref . 16 | for i. 17 | [k1, k2] = split_key (ixkey key i) 18 | s = get s_ref 19 | s_next = model.dynamics(s).sample(k1) 20 | v = model.observe(s).sample(k2) 21 | s_ref := s_next 22 | (s, v) 23 | 24 | def particleFilter( 25 | num_particles: Nat, 26 | num_timesteps: Nat, 27 | model: Model s v, 28 | summarize: (Fin num_particles => s) -> a, 29 | obs: Fin num_timesteps=>v, 30 | key: Key 31 | ) -> Fin num_timesteps => a given (s|Data, a, v) = 32 | [key, init_key] = split_key key 33 | init_particles = for i: (Fin num_particles). model.init.sample(ixkey init_key i) 34 | with_state init_particles \p_ref . 35 | for t: (Fin num_timesteps). 36 | p_prev = get p_ref 37 | logLikelihoods = for i. model.observe(p_prev[i]).logprob(obs[t]) 38 | [resample_key, dynamics_key] = split_key (ixkey key t) 39 | resampled_idxs = categorical_batch logLikelihoods resample_key 40 | p_resampled = for i. p_prev[resampled_idxs[i]] 41 | p_next = for i. model.dynamics(p_resampled[i]).sample(ixkey dynamics_key i) 42 | p_ref := p_next 43 | summarize p_resampled 44 | 45 | def normalDistn(mean: Float, var: Float) -> Distribution Float = 46 | Distribution( \k. (randn k) * (sqrt var) + mean 47 | , \v. 
-0.5 * (sq (v - mean)) / var - 0.5 * log (2.0 * pi * var)) 48 | 49 | gaussModel : Model Float Float = Model( 50 | normalDistn 0.1 0.1, 51 | \s. normalDistn s 1.0, 52 | \s. normalDistn s 1.0) 53 | 54 | timesteps = 10 55 | num_particles = 10000 56 | 57 | truth = for i:(Fin timesteps). 58 | s = n_to_f (ordinal i) 59 | (s, (normalDistn s 1.0).sample(ixkey (new_key 0) i)) 60 | 61 | filtered = particleFilter num_particles _ gaussModel mean (map snd truth) (new_key 0) 62 | 63 | # :p for i. (truth[i], filtered[i]) 64 | -------------------------------------------------------------------------------- /examples/pi.dx: -------------------------------------------------------------------------------- 1 | '# Monte Carlo Estimates of Pi 2 | 3 | 'Consider the unit circle centered at the origin. 4 | 5 | 'Consider the first quadrant: the unit circle quadrant and its $1 \times 1$ bounding unit square. 6 | 7 | '$$\text{Area of unit circle quadrant: } \\\\ A_{quadrant} = \frac{\pi r^2}{4} = \frac{\pi}{4}$$ 8 | 9 | '$$\text{Area of unit square: } \\\\ A_{square} = 1$$ 10 | 11 | '$$\text{Compute } \pi \text{ via ratios: } \\\\ \frac{A_{quadrant}}{A_{square}} = \frac{\pi}{4}, \\; \pi = 4 \thinspace \frac{A_{quadrant}}{A_{square}} $$ 12 | 13 | 'To compute $\pi$, randomly sample points in the first quadrant unit square to estimate the $\frac{A_{quadrant}}{A_{square}}$ ratio. Then, multiply by $4$. 14 | 15 | def estimatePiArea(key:Key) -> Float = 16 | [k1, k2] = split_key(n=2, key) 17 | x = rand k1 18 | y = rand k2 19 | inBounds = (sq x + sq y) < 1.0 20 | 4.0 * b_to_f inBounds 21 | 22 | def estimatePiAvgVal(key:Key) -> Float = 23 | x = rand key 24 | 4.0 * sqrt (1.0 - sq x) 25 | 26 | def meanAndStdDev(n:Nat, f: (Key) -> Float, key:Key) -> (Float, Float) = 27 | samps = for i:(Fin n). 
many f key i 28 | (mean samps, std samps) 29 | 30 | numSamps = 1000000 31 | 32 | :p meanAndStdDev numSamps estimatePiArea (new_key 0) 33 | > (3.141656, 1.642139) 34 | 35 | :p meanAndStdDev numSamps estimatePiAvgVal (new_key 0) 36 | > (3.145509, 0.8862508) 37 | -------------------------------------------------------------------------------- /examples/psd.dx: -------------------------------------------------------------------------------- 1 | '# PSD Solver Based on Cholesky Decomposition 2 | 3 | import linalg 4 | 5 | def psdsolve(mat:n=>n=>Float, b:n=>Float) -> n=>Float given (n|Ix) = 6 | l = chol mat 7 | b' = forward_substitute l b 8 | u = transpose_lower_to_upper l 9 | backward_substitute u b' 10 | 11 | ' Test 12 | 13 | N = Fin 4 14 | [k1, k2] = split_key $ new_key 0 15 | 16 | psd : N=>N=>Float = 17 | a = for i:N j:N. randn $ ixkey k1 (i, j) 18 | x = a ** transpose a 19 | x + eye 20 | 21 | def padLowerTriMat(mat:LowerTriMat n v) -> n=>n=>v given (n|Ix, v|Add) = 22 | for i j. 23 | if (ordinal j)<=(ordinal i) 24 | then mat[i,unsafe_project j] 25 | else zero 26 | 27 | l = chol psd 28 | l_full = padLowerTriMat l 29 | :p l_full 30 | > [[1.621016, 0., 0., 0.], [0.7793013, 2.965358, 0., 0.], [-0.6449394, 1.054188, 2.194109, 0.], [0.1620137, -1.009056, -1.49802, 1.355752]] 31 | 32 | psdReconstructed = l_full ** transpose l_full 33 | 34 | :p sum for pair. 35 | (i, j) = pair 36 | sq (psd[i,j] - psdReconstructed[i,j]) 37 | > 1.421085e-12 38 | 39 | vec : N=>Float = arb k2 40 | 41 | vec ~~ (psd **. 
psdsolve psd vec) 42 | > True 43 | -------------------------------------------------------------------------------- /examples/regression.dx: -------------------------------------------------------------------------------- 1 | '# Basis Function Regression 2 | 3 | import plot 4 | 5 | 6 | struct SolverState(n|Ix) = 7 | x : n=>Float 8 | r : n=>Float 9 | p : n=>Float 10 | 11 | # Conjugate gradients solver 12 | def solve(mat:m=>m=>Float, b:m=>Float) -> m=>Float given (m|Ix) = 13 | x0 = zero :: m=>Float 14 | r0 = b - (mat **. x0) 15 | n_iter = size m 16 | result = apply_n(n_iter, SolverState(x0, r0, r0)) \s. 17 | r = s.r 18 | p = s.p 19 | ap = mat **. p 20 | alpha = vdot r r / vdot p ap 21 | x' = s.x + alpha .* p 22 | r' = r - alpha .* ap 23 | beta = vdot r' r' / (vdot r r + 0.000001) 24 | p' = r' + beta .* p 25 | SolverState(x', r', p') 26 | result.x 27 | 28 | 'Make some synthetic data 29 | 30 | Nx = Fin 100 31 | noise = 0.1 32 | [k1, k2] = split_key (new_key 0) 33 | 34 | def trueFun(x:Float) -> Float = 35 | x + sin (5.0 * x) 36 | 37 | xs : Nx=>Float = for i. rand (ixkey k1 i) 38 | ys : Nx=>Float = for i. trueFun xs[i] + noise * randn (ixkey k2 i) 39 | 40 | :html show_plot $ xy_plot xs ys 41 | > 42 | 43 | 'Implement basis function regression 44 | 45 | def regress(featurize: (Float) -> d=>Float, xRaw:n=>Float, y:n=>Float) -> d=>Float given (d|Ix, n|Ix) = 46 | x = map featurize xRaw 47 | xT = transpose x 48 | solve (xT ** x) (xT **. y) 49 | 50 | 'Fit a third-order polynomial 51 | 52 | def poly(x:Float) -> d=>Float given (d|Ix) = 53 | for i. pow x (n_to_f (ordinal i)) 54 | 55 | params : (Fin 4)=>Float = regress poly xs ys 56 | 57 | def predict(x:Float) -> Float = 58 | vdot params (poly x) 59 | 60 | 61 | xsTest = linspace (Fin 200) 0.0 1.0 62 | 63 | :html show_plot $ xy_plot xsTest (map predict xsTest) 64 | > 65 | 66 | 'RMS error 67 | 68 | def rmsErr(truth:n=>Float, pred:n=>Float) -> Float given (n|Ix) = 69 | sqrt $ mean for i. 
sq (pred[i] - truth[i]) 70 | 71 | :p rmsErr ys (map predict xs) 72 | > 0.2455227 73 | 74 | def tabCat(xs:n=>a, ys:m=>a) -> (Either n m)=>a given (n|Ix, m|Ix, a) = 75 | for i. case i of 76 | Left i' -> xs[i'] 77 | Right i' -> ys[i'] 78 | 79 | xsPlot = tabCat xs xsTest 80 | ysPlot = tabCat ys $ map predict xsTest 81 | 82 | :html show_plot $ xyc_plot xsPlot ysPlot $ 83 | for i. case i of 84 | Left _ -> 0.0 85 | Right _ -> 1.0 86 | > 87 | -------------------------------------------------------------------------------- /examples/rejection-sampler.dx: -------------------------------------------------------------------------------- 1 | '# Rejection Sampler for a Binomial Distribution 2 | 3 | 'We implement rejection sampling from a Binomial distribution using a uniform proposal. 4 | 5 | def rejectionSample(try: (Key) -> Maybe a, k:Key) -> a given (a|Data) = 6 | iter \i. case try $ hash k i of 7 | Nothing -> Continue 8 | Just x -> Done x 9 | 10 | Prob = Float 11 | LogProb = Float 12 | 13 | # log probability density of a Binomial distribution 14 | def logBinomialProb(n':Nat, p:Prob, counts':Nat) -> LogProb = 15 | n = n_to_f n' 16 | counts = n_to_f counts' 17 | pSuccess = log p * counts 18 | pFailure = log1p (-p) * (n - counts) 19 | normConst = (lbeta (1. + counts) (1. + n - counts) + 20 | log1p (n)) 21 | pSuccess + pFailure - normConst 22 | 23 | def trySampleBinomial(n:Nat, p:Prob, k:Key) -> Maybe Nat = 24 | [k1, k2] = split_key k 25 | proposal = f_to_n $ floor $ rand k1 * n_to_f (n + 1) 26 | if proposal > n 27 | then Nothing 28 | else 29 | acceptance = log (rand k2) < logBinomialProb n p proposal 30 | if acceptance 31 | then Just proposal 32 | else Nothing 33 | 34 | '## Example 35 | 36 | 'We test the implementation by sampling from a Binomial distribution with 10 trials and success probability 0.4. 
37 | 38 | # parameters 39 | n = 10 40 | p = 0.4 41 | numSamples = 5000 42 | k0 = new_key 0 43 | 44 | # TODO: use currying sugar (or even better, effects) 45 | rejectionSamples = rand_vec numSamples (\k. rejectionSample (\k'. trySampleBinomial n p k') k) k0 46 | 47 | :p slice rejectionSamples 0 $ Fin 10 48 | > [5, 3, 3, 3, 4, 4, 5, 4, 3, 3] 49 | 50 | 'The Binomial distribution has mean 4 and variance 2.4. 51 | 52 | def meanAndVariance(xs:n=>Float) -> (Float, Float) given (n|Ix) = (mean xs, sq $ std xs) 53 | 54 | :p meanAndVariance $ map n_to_f rejectionSamples 55 | > (4.019, 2.434639) 56 | 57 | '## Alternative: Inversion sampling 58 | 59 | 'Alternatively, we can use inversion sampling. 60 | 61 | def binomialSample(n:Nat, p:Prob, k:Key) -> Nat = 62 | m = n + 1 63 | logprobs = for i:(Fin m). logBinomialProb n p $ ordinal i 64 | ordinal $ categorical logprobs k 65 | 66 | inversionSamples = rand_vec numSamples (\k. binomialSample n p k) k0 67 | 68 | :p slice inversionSamples 0 $ Fin 10 69 | > [6, 3, 3, 3, 4, 5, 1, 3, 3, 6] 70 | 71 | :p meanAndVariance $ map n_to_f inversionSamples 72 | > (3.9642, 2.468519) 73 | 74 | 'The following variant is guaranteed to evaluate the CDF only once. 75 | 76 | def binomialBatch(n:Nat, p:Prob, k:Key) -> a => Nat given (a|Ix) = 77 | m = n + 1 78 | logprobs = for i:(Fin m). 
logBinomialProb n p $ ordinal i 79 | map ordinal $ categorical_batch logprobs k 80 | 81 | inversionBatchSamples = (binomialBatch n p k0) :: Fin numSamples => Nat 82 | 83 | :p slice inversionBatchSamples 0 $ Fin 10 84 | > [6, 3, 3, 3, 4, 5, 1, 3, 3, 6] 85 | 86 | :p meanAndVariance $ map n_to_f inversionBatchSamples 87 | > (3.9642, 2.468519) 88 | -------------------------------------------------------------------------------- /examples/sgd.dx: -------------------------------------------------------------------------------- 1 | '# Stochastic Gradient Descent with Momentum 2 | 3 | def sgd_step( 4 | step_size: Float, 5 | decay: Float, 6 | gradfunc: (a, Nat) -> a, 7 | x: a, 8 | m: a, 9 | iter:Nat 10 | ) ->(a, a) given (a|VSpace) = 11 | g = gradfunc x iter 12 | new_m = decay .* m + g 13 | new_x = x - step_size .* new_m 14 | (new_x, new_m) 15 | 16 | # In-place optimization loop. 17 | def sgd( 18 | step_size:Float, 19 | decay:Float, 20 | num_steps:Nat, 21 | gradient: (a, Nat) -> a, 22 | x0: a 23 | ) -> a given (a|VSpace) = 24 | m0 = zero 25 | (x_final, m_final) = yield_state (x0, m0) \state. 26 | for i:(Fin num_steps). 27 | (x, m) = get state 28 | state := sgd_step step_size decay gradient x m (ordinal i) 29 | x_final 30 | 31 | 32 | '## Example quadratic optimization problem 33 | 34 | D = Fin 4 35 | optimum = for i:D. 1.1 36 | def objective(x:D=>Float) -> Float = 0.5 * sum for i. sq (optimum[i] - x[i]) 37 | def gradfunc(x:D=>Float, iter:Nat) -> D=>Float = grad objective x 38 | 39 | '## Run optimizer 40 | 41 | x_init = for i:D. 
0.0 42 | stepsize = 0.01 43 | decay = 0.9 44 | num_iters = 1000 45 | :p sgd stepsize decay num_iters gradfunc x_init 46 | > [1.1, 1.1, 1.1, 1.1] 47 | 48 | :p optimum 49 | > [1.1, 1.1, 1.1, 1.1] 50 | -------------------------------------------------------------------------------- /examples/sierpinski.dx: -------------------------------------------------------------------------------- 1 | '# Sierpinski Triangle ("Chaos Game") 2 | 3 | import diagram 4 | import plot 5 | 6 | def update(points:n=>Point, key:Key, p:Point) -> Point given (n|Ix) = 7 | p' = points[rand_idx key] 8 | Point(0.5 * (p.x + p'.x), 0.5 * (p.y + p'.y)) 9 | 10 | def runChain(n:Nat, key:Key, x0:a, f:(Key, a) -> a) -> Fin n => a given (a|Data) = 11 | ref <- with_state x0 12 | for i:(Fin n). 13 | new = ixkey key i | f(get ref) 14 | ref := new 15 | new 16 | 17 | trianglePoints : (Fin 3)=>Point = [Point(0.0, 0.0), Point(1.0, 0.0), Point(0.5, sqrt 0.75)] 18 | 19 | n = 3000 20 | points = runChain n (new_key 0) (Point 0.0 0.0) \k p. update trianglePoints k p 21 | 22 | (xs, ys) = unzip for i:(Fin n). 
(points[i].x, points[i].y) 23 | 24 | :html show_plot $ xy_plot xs ys 25 | > 26 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "flake-utils": { 4 | "locked": { 5 | "lastModified": 1644229661, 6 | "narHash": "sha256-1YdnJAsNy69bpcjuoKdOYQX0YxZBiCYZo4Twxerqv7k=", 7 | "owner": "numtide", 8 | "repo": "flake-utils", 9 | "rev": "3cecb5b042f7f209c56ffd8371b2711a290ec797", 10 | "type": "github" 11 | }, 12 | "original": { 13 | "owner": "numtide", 14 | "repo": "flake-utils", 15 | "type": "github" 16 | } 17 | }, 18 | "llvm-hs-src": { 19 | "flake": false, 20 | "locked": { 21 | "lastModified": 1644009200, 22 | "narHash": "sha256-IG4Mh89bY+PtBJtzlXKYsPljfHP7OSQk03pV6fSmdRY=", 23 | "owner": "llvm-hs", 24 | "repo": "llvm-hs", 25 | "rev": "eda85a2bbe362a0b89df5adce0cb65e4e755eac5", 26 | "type": "github" 27 | }, 28 | "original": { 29 | "owner": "llvm-hs", 30 | "ref": "llvm-12", 31 | "repo": "llvm-hs", 32 | "type": "github" 33 | } 34 | }, 35 | "nixpkgs": { 36 | "locked": { 37 | "lastModified": 1644151317, 38 | "narHash": "sha256-TpXGBYCFKvEN7Q+To45rn4kqTbLPY4f56rF6ymUGGRE=", 39 | "owner": "NixOS", 40 | "repo": "nixpkgs", 41 | "rev": "942b0817e898262cc6e3f0a5f706ce09d8f749f1", 42 | "type": "github" 43 | }, 44 | "original": { 45 | "id": "nixpkgs", 46 | "type": "indirect" 47 | } 48 | }, 49 | "root": { 50 | "inputs": { 51 | "flake-utils": "flake-utils", 52 | "llvm-hs-src": "llvm-hs-src", 53 | "nixpkgs": "nixpkgs" 54 | } 55 | } 56 | }, 57 | "root": "root", 58 | "version": 7 59 | } 60 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "Dex (named for \"index\") is a research language for typed, functional array processing."; 3 | 4 | inputs = { 5 | flake-utils.url = 
"github:numtide/flake-utils"; 6 | llvm-hs-src = { 7 | url = "github:llvm-hs/llvm-hs/llvm-12"; 8 | flake = false; 9 | }; 10 | }; 11 | 12 | outputs = { self, nixpkgs, flake-utils, llvm-hs-src }: 13 | flake-utils.lib.eachDefaultSystem (system: 14 | let 15 | pkgs = (import nixpkgs { 16 | inherit system; 17 | config.allowUnfree = true; # Needed for CUDA 18 | }); 19 | in rec { 20 | packages.dex = (pkgs.callPackage ./. { 21 | inherit pkgs; 22 | inherit llvm-hs-src; 23 | }); 24 | packages.dex-cuda = (pkgs.callPackage ./. { 25 | inherit pkgs; 26 | inherit llvm-hs-src; 27 | withCudaSupport = true; 28 | }); 29 | defaultPackage = packages.dex; 30 | 31 | devShell = (import ./shell.nix { 32 | inherit pkgs; 33 | }); 34 | }); 35 | } 36 | -------------------------------------------------------------------------------- /julia/Project.toml: -------------------------------------------------------------------------------- 1 | name = "DexCall" 2 | uuid = "bb22f25d-cb49-471c-b017-930e329a2928" 3 | version = "0.1.0" 4 | 5 | [deps] 6 | CombinedParsers = "5ae71ed2-6f8a-4ed1-b94f-e14e8158f19e" 7 | 8 | [compat] 9 | CombinedParsers = "^0.2" 10 | julia = "^1.6" 11 | 12 | [extras] 13 | Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" 14 | 15 | [targets] 16 | test = ["Test"] 17 | -------------------------------------------------------------------------------- /julia/deps/build.jl: -------------------------------------------------------------------------------- 1 | const dexlang_root = dirname(dirname(@__DIR__)) 2 | cd(dexlang_root) do 3 | run(`make build-ffis`) 4 | end -------------------------------------------------------------------------------- /julia/src/DexCall.jl: -------------------------------------------------------------------------------- 1 | "Calling Dex from Julia" 2 | module DexCall 3 | using CombinedParsers 4 | using CombinedParsers.Regexp 5 | 6 | export evaluate, DexError, DexModule, NativeFunction, @dex_func_str 7 | export Atom, dexize, juliaize, NativeFunction 8 | 9 | 
include("api_types.jl") 10 | include("api.jl") 11 | include("evaluate.jl") 12 | include("native_function.jl") 13 | 14 | # use this to disable free'ing haskell objects after we have closed the RTS 15 | const NO_FREE = Ref(false) 16 | 17 | function __init__() 18 | init() 19 | @eval const JIT = create_JIT() 20 | atexit() do 21 | destroy_JIT(JIT) 22 | NO_FREE[] = true 23 | fini() 24 | end 25 | 26 | @eval const PRELUDE = create_context() 27 | end 28 | end 29 | -------------------------------------------------------------------------------- /julia/src/api.jl: -------------------------------------------------------------------------------- 1 | const libdex = joinpath(dirname(@__DIR__), "deps", "libDex.so") 2 | isfile(libdex) || error("libDex not found in $libdex. Please run `Pkg.build()`") 3 | 4 | 5 | ########################################################################################## 6 | # Global State 7 | 8 | # These can only be called once in life-time of program, can not re-init after fini etc 9 | init() = @ccall libdex.dexInit()::Nothing 10 | fini() = @ccall libdex.dexFini()::Nothing 11 | 12 | # No reason to call these more than once in life-time of program 13 | create_JIT() = @ccall libdex.dexCreateJIT()::Ptr{HsJIT} 14 | destroy_JIT(jit) = NO_FREE[] || @ccall libdex.dexDestroyJIT(jit::Ptr{HsJIT})::Nothing 15 | 16 | ########################################################################################## 17 | "An error thrown from with-in Dex" 18 | struct DexError <: Exception 19 | msg::String 20 | end 21 | function Base.showerror(io::IO, err::DexError) 22 | if '\n' ∈ err.msg 23 | # If message is multiline then it may depend on exact alignment 24 | println(io, "DexError:\n", err.msg) 25 | else 26 | # If one line then short enough to happen on same line as everything else 27 | println(io, "DexError: ", err.msg) 28 | end 29 | end 30 | 31 | get_error_msg() = unsafe_string(@ccall libdex.dexGetError()::Cstring) 32 | throw_from_dex() = throw(DexError(get_error_msg())) 
33 | 34 | 35 | create_context() = @ccall libdex.dexCreateContext()::Ptr{HsContext} 36 | destroy_context(ctx) = NO_FREE[] || @ccall libdex.dexDestroyContext(ctx::Ptr{HsContext})::Nothing 37 | 38 | function context(f) 39 | ctx = create_context() 40 | try 41 | f(ctx) 42 | finally 43 | destroy_context(ctx) 44 | end 45 | end 46 | 47 | 48 | dex_eval(ctx, str) = @ccall libdex.dexEval(ctx::Ptr{HsContext}, str::Cstring)::Ptr{HsContext} 49 | 50 | insert(ctx, str, atm) = @ccall libdex.dexInsert(ctx::Ptr{HsContext}, str::Cstring, atm::Ptr{HsAtom})::Ptr{HsContext} 51 | eval_expr(ctx, str) = @ccall libdex.dexEvalExpr(ctx::Ptr{HsContext}, str::Cstring)::Ptr{HsAtom} 52 | lookup(ctx, str) = @ccall libdex.dexLookup(ctx::Ptr{HsContext}, str::Cstring)::Ptr{HsAtom} 53 | 54 | print(atm) = unsafe_string(@ccall libdex.dexPrint(atm::Ptr{HsAtom})::Cstring) 55 | 56 | compile(ctx, atm, jit=JIT) = @ccall libdex.dexCompile(jit::Ptr{HsJIT}, ctx::Ptr{HsContext}, atm::Ptr{HsAtom})::Ptr{NativeFunctionObj} 57 | unload(f, jit=JIT) = NO_FREE[] || @ccall libdex.dexUnload(jit::Ptr{HsJIT}, f::Ptr{NativeFunctionObj})::Nothing 58 | 59 | get_function_signature(f, jit=JIT) = @ccall libdex.dexGetFunctionSignature(jit::Ptr{HsJIT}, f::Ptr{NativeFunctionObj})::Ptr{NativeFunctionSignature} 60 | free_function_signature(s) = NO_FREE[] || @ccall libdex.dexFreeFunctionSignature(s::Ptr{NativeFunctionSignature})::Nothing 61 | 62 | 63 | to_CAtom(src, dest) = @ccall libdex.dexToCAtom(src::Ptr{HsAtom}, dest::Ptr{CAtom})::Int32 64 | from_CAtom(src) = @ccall libdex.dexFromCAtom(src::Ptr{CAtom})::Ptr{HsAtom} -------------------------------------------------------------------------------- /julia/src/api_types.jl: -------------------------------------------------------------------------------- 1 | struct HsAtom end 2 | struct HsContext end 3 | struct HsJIT end 4 | struct NativeFunctionObj end 5 | struct NativeFunctionSignature 6 | arg::Cstring 7 | res::Cstring 8 | _ccall::Cstring # can't name this field `ccall` as that is a 
built-in in julia 9 | end 10 | 11 | struct CRectArray 12 | data::Ptr{Nothing} 13 | shape_ptr::Ptr{Int64} 14 | strides_ptr::Ptr{Int64} 15 | end 16 | 17 | 18 | 19 | """ 20 | TaggedUnion{Tuple{A, B, ...}} 21 | 22 | Represents a tagged union over types `A`, `B` etc. 23 | Must have a first field `tag::UInt64` 24 | and a second field `payload` which must be some isbits type that can hold the largest 25 | element of the union, which you can (e.g.) declare as a custom `primitive`. 26 | This is required as Julia doesn't directly support Unions in the mapping to-from C 27 | so we store the data as arbitrary bits then force it to be reinterpreted based on the tag 28 | """ 29 | abstract type TaggedUnion{T<:Tuple} end 30 | function bust_union(x::TaggedUnion{U}) where U 31 | T = U.parameters[Int(x.tag) + 1] 32 | return bust_union(force_reinterpret(T, x.payload)) 33 | end 34 | bust_union(x) = x # not a union, leave it alone 35 | 36 | "Forces reinterpreting `raw` as a `T`. If sizeof(raw) < sizeof(T) high bits will be filled with junk." 
37 | function force_reinterpret(::Type{T}, raw) where T 38 | isbits(raw) || throw(ArgumentError("Can only reinterpret from a isbits type")) 39 | isbitstype(T) || throw(ArgumentError("Can only reinterpret into a isbits type")) 40 | return unsafe_load(Ptr{T}(Base.pointer_from_objref(Ref(raw)))) 41 | end 42 | 43 | "Holds data for CLit, big enough to hold the largest of its union members, such as a Float64" 44 | primitive type CLitPayload 64 end 45 | struct CLit <: TaggedUnion{Tuple{Int64, Int32, UInt8, Float64, Float32, UInt32, UInt64}} 46 | tag::UInt64 47 | payload::CLitPayload 48 | end 49 | CLit(x) = CLit(CLit_tag(x), force_reinterpret(CLitPayload, x)) 50 | 51 | CLit_tag(::Int64) = 0x0000_0000_0000_0000 52 | CLit_tag(::Int32) = 0x0000_0000_0000_0001 53 | CLit_tag(::UInt8) = 0x0000_0000_0000_0002 54 | CLit_tag(::Float64) = 0x0000_0000_0000_0003 55 | CLit_tag(::Float32) = 0x0000_0000_0000_0004 56 | CLit_tag(::UInt32) = 0x0000_0000_0000_0005 57 | CLit_tag(::UInt64) = 0x0000_0000_0000_0006 58 | 59 | "Holds data for CAtom, big enough to hold the largest of its union members, which is the CRectArray" 60 | primitive type CAtomPayload 3*64 end 61 | struct CAtom <: TaggedUnion{Tuple{CLit, CRectArray}} 62 | tag::UInt64 63 | payload::CAtomPayload 64 | end 65 | CAtom(x) = CAtom(0x0000_0000_0000_0000, force_reinterpret(CAtomPayload, CLit(x))) 66 | CAtom(x::AbstractArray) = throw(DomainError(typeof(x), "Arrays not yet supported")) 67 | function CAtom(atm::Ptr{HsAtom}) 68 | result = Ref{CAtom}() 69 | success = to_CAtom(atm, result) 70 | iszero(success) && throw_from_dex() 71 | return result[] 72 | end -------------------------------------------------------------------------------- /julia/src/evaluate.jl: -------------------------------------------------------------------------------- 1 | """ 2 | Atom 3 | A wrapped DexLang value. 4 | 5 | Scalar values can be converted to julia objects using [`juliaize`](@ref), or `convert`. 
6 | Functions can be called directly, if you pass in atoms, 7 | or they can be compiled and made usable on julia objects by using [`NativeFunction`](@ref). 8 | """ 9 | struct Atom 10 | ptr::Ptr{HsAtom} 11 | ctx::Ptr{HsContext} 12 | end 13 | 14 | Base.show(io::IO, atom::Atom) = show(io, print(atom.ptr)) 15 | 16 | """ 17 | juliaize(x) 18 | 19 | Get the corresponding Julia object from some Dex object. 20 | The inverse of [`dexize`](@ref). 21 | """ 22 | juliaize(x::CAtom) = bust_union(x) 23 | juliaize(x::Ptr{HsAtom}) = juliaize(CAtom(x)) 24 | juliaize(x::Atom) = juliaize(x.ptr) 25 | Base.convert(::Type{T}, atom::Atom) where {T<:Number} = convert(T, juliaize(atom)) 26 | 27 | """ 28 | dexize(x) 29 | 30 | Get the corresponding Dex object from some Julia object. 31 | The inverse of [`juliaize`](@ref). 32 | """ 33 | dexize(x) = Atom(from_CAtom(Ref(CAtom(x))), PRELUDE) 34 | # ^ Always defined in PRELUDE as it could be defined anywhere that has all the bindings, 35 | # but because it has no bindings anywhere will do 36 | 37 | Base.convert(::Type{Atom}, atom::Number) = dexize(atom) 38 | 39 | 40 | function (self::Atom)(args...) 41 | # TODO: Make those calls more hygienic 42 | env = self.ctx 43 | pieces = (self, args...) 44 | for (i, atom) in enumerate(pieces) 45 | # NB: Atoms can contain arbitrary references 46 | if atom.ctx !== PRELUDE && atom.ctx !== self.ctx 47 | throw(ArgumentError("Mixing atoms coming from different Dex modules is not supported yet!")) 48 | end 49 | old_env, env = env, insert(env, "julia_arg$i", atom.ptr) 50 | destroy_context(old_env) 51 | end 52 | return evaluate(join("julia_arg" .* string.(eachindex(pieces)), " "), self.ctx, env) 53 | end 54 | 55 | mutable struct DexModule 56 | # Needs to be mutable struct so can attach finalizer 57 | ctx::Ptr{HsContext} 58 | end 59 | 60 | """ 61 | DexModule(str) 62 | 63 | For running 1 or more Dex expressions, and keeping the state. 64 | You can get them back out of the module using `getproperty`. 
65 | They are returned as [`DexCall.Atom`](@ref)s. 66 | 67 | # Example: 68 | 69 | ```julia 70 | julia> m = DexModule(raw""" 71 | x = 42 72 | y = 2 * x 73 | """) 74 | DexModule(Ptr{DexCall.HsContext} @0x0000000000000031) 75 | 76 | julia> m.x 77 | "42" 78 | 79 | julia> m.y 80 | "84" 81 | ``` 82 | """ 83 | function DexModule(source::AbstractString) 84 | ctx = dex_eval(PRELUDE, source) 85 | ctx == C_NULL && throw_from_dex() 86 | m = DexModule(ctx) 87 | finalizer(m) do _m 88 | destroy_context(getfield(_m, :ctx)) 89 | end 90 | return m 91 | end 92 | 93 | function Base.getproperty(m::DexModule, name::Symbol) 94 | ctx = getfield(m, :ctx) 95 | ret = lookup(ctx, string(name)) 96 | ret == C_NULL && throw_from_dex() 97 | return Atom(ret, ctx) 98 | end 99 | 100 | @doc raw""" 101 | evaluate(str) 102 | 103 | A friendly function for running Dex code. 104 | The string `str` must contain a single Dex expression. 105 | Return a [`DexCall.Atom`](@ref) 106 | # Example: 107 | ```julia 108 | julia> evaluate(raw"sum $ for i. exp [log 2.0, log 4.0].i") 109 | "6." 
110 | ``` 111 | """ 112 | function evaluate(str, _module=PRELUDE, env=_module) 113 | result = eval_expr(env, str) 114 | result == C_NULL && throw_from_dex() 115 | return Atom(result, _module) 116 | end -------------------------------------------------------------------------------- /julia/test/api.jl: -------------------------------------------------------------------------------- 1 | @testset "api.jl" begin 2 | @testset "basic demo of eval and check errors" begin 3 | DexCall.context() do ctx 4 | DexCall.dex_eval(ctx, "(1 : Int) + 1.0\n") 5 | error_message = DexCall.get_error_msg() 6 | @test contains(error_message, r"Type error.*Expected: Int32.*Actual: Float32.*"s) 7 | end 8 | end 9 | end 10 | -------------------------------------------------------------------------------- /julia/test/evaluate.jl: -------------------------------------------------------------------------------- 1 | @testset "evaluate.jl" begin 2 | @testset "evaluate erroring" begin 3 | @test_throws DexError evaluate("(1 : Int) + 2.0") 4 | end 5 | 6 | @testset "evaluate show" begin 7 | @test repr(evaluate("1")) == repr("1") 8 | @test repr(evaluate("1.5")) == repr("1.5") 9 | @test repr(evaluate("[1, 2]")) == repr("[1, 2]") 10 | @test repr(evaluate("1+3")) == repr("4") 11 | @test repr(evaluate("for i. [1, 2].i + 1")) == repr("[2, 3]") 12 | 13 | # This seems weird: why is it doubly quoted? 
😕 14 | @test repr(evaluate("IToW8 65")) === repr(repr('A')) 15 | end 16 | 17 | @testset "evaluate juliaize" begin 18 | @test juliaize(evaluate("1")) === Int32(1) 19 | @test juliaize(evaluate("1.5")) === 1.5f0 20 | @test juliaize(evaluate("IToW8 65")) === UInt8(65) 21 | end 22 | 23 | @testset "juliaize-dexize round-trip" begin 24 | @test juliaize(dexize(Int64(3))) === Int64(3) 25 | @test juliaize(dexize(Int32(3))) === Int32(3) 26 | @test juliaize(dexize(UInt8(3))) === UInt8(3) 27 | @test juliaize(dexize(Float64(3))) === Float64(3) 28 | @test juliaize(dexize(Float32(3))) === Float32(3) 29 | @test juliaize(dexize(UInt64(3))) === UInt64(3) 30 | @test juliaize(dexize(UInt32(3))) === UInt32(3) 31 | end 32 | 33 | 34 | @testset "Atom function call" begin 35 | m = DexModule(""" 36 | def addOne (x: Float) : Float = x + 1.0 37 | """) 38 | x = evaluate("2.5") 39 | y = evaluate("[2, 3, 4]") 40 | @test repr(m.addOne(x)) == repr("3.5") 41 | 42 | # This is a function that is in `m` from dex's prelude 43 | @test repr(m.sum(y)) == repr("9") 44 | end 45 | 46 | @testset "convert Atom" begin 47 | atom = convert(Atom, 1f0) 48 | @test convert(Number, atom) === 1f0 49 | @test convert(Real, atom) === 1f0 50 | @test convert(Float64, atom) === 1.0 51 | 52 | atom = convert(Atom, Int32(2)) 53 | @test convert(Number, atom) === Int32(2) 54 | @test convert(Real, atom) === Int32(2) 55 | @test convert(Float64, atom) === 2.0 56 | end 57 | 58 | 59 | @testset "DexModule" begin 60 | m = DexModule(""" 61 | x = 2.5 62 | y = [2, 3, 4] 63 | """) 64 | @test repr(m.x) == repr("2.5") 65 | @test repr(m.y) == repr("[2, 3, 4]") 66 | end 67 | end 68 | -------------------------------------------------------------------------------- /julia/test/native_function.jl: -------------------------------------------------------------------------------- 1 | 2 | @testset "native_function.jl" begin 3 | @testset "signature parser" begin 4 | @testset "$example" for example in ( 5 | "arg0:f32", 6 | "arg0:f32,arg1:f32", 7 | 
"arg0:i64,arg1:i32", 8 | "arg0:f32[10]", 9 | "?arg0:i32,arg1:f32[arg0]", 10 | "arg2:f32[arg0]", 11 | "?arg0:i32,?arg1:i32,arg2:f32[arg0,arg1]", 12 | "arg3:f32[arg1,arg0]", 13 | "arg0:f32,?arg1:i32,arg2:f32[arg1]" 14 | ) 15 | # This is just a quick check to make sure the parser doesn't error. 16 | # Later integration tests will show it has the right behaviour. 17 | @test DexCall.parse_sig(example) isa Vector{DexCall.Binder} 18 | end 19 | end 20 | 21 | @testset "dex_func anon funcs" begin 22 | @test dex_func"\x:Float. exp x"(0f0) === 1f0 23 | @test dex_func"\x:Float. (2.0*x, x)"(1.5f0) === (3f0, 1.5f0) 24 | @test dex_func"\x:Int64 y:Int. I64ToI x + y"(Int64(1), Int32(2)) === Int32(3) 25 | @test dex_func"\x:((Fin 3)=>Float). sum x"([1f0, 2f0, 3f0]) === 6f0 26 | 27 | @test dex_func"\x:((Fin 3)=>Float). for i. 2.0 * x.i"([1f0, 2f0, 3f0]) isa Vector{Float32} 28 | @test dex_func"\x:((Fin 3)=>Float). for i. 2.0 * x.i"([1f0, 2f0, 3f0]) == [2f0, 4f0, 6f0] 29 | end 30 | 31 | @testset "dex_func named funcs" begin 32 | dex_func""" 33 | def myTranspose (n: Int) ?-> (m: Int) ?-> 34 | (x : (Fin n)=>(Fin m)=>Float) : (Fin m)=>(Fin n)=>Float = 35 | for i j. x.j.i 36 | """ 37 | 38 | @test myTranspose([1f0 2f0 3f0; 4f0 5f0 6f0]) isa AbstractMatrix{Float32} 39 | @test myTranspose([1f0 2f0 3f0; 4f0 5f0 6f0]) == [1f0 2f0 3f0; 4f0 5f0 6f0]' 40 | 41 | 42 | dex_func"double_it = \x:Float. 2.0 * x" 43 | @test double_it(4f0) === 8f0 44 | end 45 | 46 | @testset "dex_func not all implicits at start" begin 47 | dex_func"def f (a : Float) (n : Int) ?-> (b : (Fin n)=>Float) : Float = a + sum b" 48 | @test f(100f0, [10f0, 2f0, 0.1f0, 0.1f0, 0.1f0]) === 112.3f0 49 | end 50 | 51 | @testset "dex_func named const funcs" begin 52 | @eval dex_func"foo = \x:Int. 1.5"c # use @eval to run at global scope, so can declare const 53 | @test isconst(@__MODULE__, :foo) 54 | @test foo(Int32(4)) === 1.5f0 55 | end 56 | 57 | @testset "dex_func errors" begin 58 | @test_throws ArgumentError dex_func"\x:Float. 
exp x"(0.0) 59 | end 60 | 61 | @testset "NativeFunction directly" begin 62 | m = DexModule(raw"def addTwo (n: Int) ?-> (x: (Fin n)=>Float) : (Fin n)=>Float = for i. x.i + 2.0") 63 | add_two = NativeFunction(m.addTwo) 64 | @test add_two([1f0, 10f0]) == [3f0, 12f0] 65 | end 66 | end -------------------------------------------------------------------------------- /julia/test/runtests.jl: -------------------------------------------------------------------------------- 1 | using Test 2 | using DexCall 3 | 4 | @testset "DexCall" begin 5 | include("api.jl") 6 | include("evaluate.jl") 7 | include("native_function.jl") 8 | end -------------------------------------------------------------------------------- /lib/netpbm.dx: -------------------------------------------------------------------------------- 1 | '# Netpbm 2 | 3 | 'This is a basic loader for the .ppm P6 image format. 4 | 5 | import parser 6 | 7 | enum Image = 8 | MkImage(rows:Nat, cols:Nat, pixels:(Fin rows => Fin cols => Fin 3 => Word8)) 9 | 10 | parse_p6 : Parser Image = MkParser \h. 11 | # Loads a raw PPM file in P6 format. 12 | # The header will look something like: 13 | # P6 14 | # 220 220 (width, height) 15 | # 255 (max color value) 16 | # followed by a flat block of height x width x 3 chars. 17 | parse h $ p_char 'P' 18 | parse h $ p_char '6' 19 | parse h $ parse_any 20 | cols = i32_to_n $ parse h $ parse_unsigned_int 21 | parse h $ parse_any 22 | rows = i32_to_n $ parse h $ parse_unsigned_int 23 | parse h $ parse_any 24 | colorsize = i32_to_n $ parse h $ parse_unsigned_int 25 | parse h $ parse_any 26 | pixels = for r:(Fin rows). 27 | for c:(Fin cols). 28 | for c:(Fin 3). 29 | parse h parse_any 30 | MkImage rows cols pixels 31 | 32 | def load_image_p6(filename:String) -> Maybe Image = 33 | image_raw = unsafe_io \. 
read_file filename 34 | run_parser_partial image_raw parse_p6 35 | -------------------------------------------------------------------------------- /lib/set.dx: -------------------------------------------------------------------------------- 1 | '# Sets and Set-Indexed Arrays 2 | 3 | import sort 4 | 5 | '## Monoidal enforcement of uniqueness in sorted lists 6 | 7 | def last(xs:n=>a) -> Maybe a given (n|Ix, a) = 8 | s = size n 9 | case s == 0 of 10 | True -> Nothing 11 | False -> Just xs[unsafe_from_ordinal (unsafe_nat_diff s 1)] 12 | 13 | def first(xs:n=>a) -> Maybe a given (n|Ix, a) = 14 | s = size n 15 | case s == 0 of 16 | True -> Nothing 17 | False -> Just xs[unsafe_from_ordinal 0] 18 | 19 | def all_except_last(xs:n=>a) -> List a given (n|Ix, a) = 20 | shortSize = Fin (size n -| 1) 21 | allButLast = for i:shortSize. xs[unsafe_from_ordinal (ordinal i)] 22 | AsList _ allButLast 23 | 24 | def merge_unique_sorted_lists(xlist:List a, ylist:List a) -> List a given (a|Eq) = 25 | # This function is associative, for use in a monoidal reduction. 26 | # Assumes all xs are <= all ys. 27 | # The element at the end of xs might equal the 28 | # element at the beginning of ys. If so, this 29 | # function removes the duplicate when concatenating the lists. 30 | AsList(nx, xs) = xlist 31 | AsList(_ , ys) = ylist 32 | case last xs of 33 | Nothing -> ylist 34 | Just last_x -> case first ys of 35 | Nothing -> xlist 36 | Just first_y -> case last_x == first_y of 37 | False -> concat [xlist, ylist] 38 | True -> concat [all_except_last xs, ylist] 39 | 40 | def remove_duplicates_from_sorted(xs:n=>a) -> List a given (n|Ix, a|Eq) = 41 | xlists = for i:n. (AsList 1 [xs[i]]) 42 | reduce (AsList 0 []) merge_unique_sorted_lists xlists 43 | 44 | 45 | '## Sets 46 | 47 | enum Set(a|Ord) = 48 | # Guaranteed to be in sorted order with unique elements, 49 | # as long as no one else uses this constructor. 50 | # Instead use the "to_set" function below. 
51 | UnsafeAsSet(n:Nat, elements:(Fin n => a)) 52 | 53 | def to_set(xs:n=>a) -> Set a given (n|Ix, a|Ord) = 54 | sorted_xs = sort xs 55 | AsList(n', sorted_unique_xs) = remove_duplicates_from_sorted sorted_xs 56 | UnsafeAsSet n' sorted_unique_xs 57 | 58 | def set_size(p:Set a) -> Nat given (a|Ord) = 59 | UnsafeAsSet(n, _) = p 60 | n 61 | 62 | instance Eq(Set a) given (a|Ord) 63 | def (==)(sx, sy) = 64 | UnsafeAsSet(_, xs) = sx 65 | UnsafeAsSet(_, ys) = sy 66 | (AsList _ xs) == (AsList _ ys) 67 | 68 | def set_union( 69 | sx:Set a, 70 | sy:Set a 71 | ) -> Set a given (a|Ord) = 72 | UnsafeAsSet(nx, xs) = sx 73 | UnsafeAsSet(ny, ys) = sy 74 | combined = merge_sorted_tables xs ys 75 | AsList(n', sorted_unique_xs) = remove_duplicates_from_sorted combined 76 | UnsafeAsSet _ sorted_unique_xs 77 | 78 | def set_intersect( 79 | sx:Set a, 80 | sy:Set a 81 | ) -> Set a given (a|Ord) = 82 | UnsafeAsSet(nx, xs) = sx 83 | UnsafeAsSet(ny, ys) = sy 84 | # This could be done in O(nx + ny) instead of O(nx log ny). 85 | isInYs = \x. case search_sorted_exact ys x of 86 | Just x -> True 87 | Nothing -> False 88 | AsList(n', intersection) = filter xs isInYs 89 | UnsafeAsSet _ intersection 90 | 91 | 92 | '## Sets as a type, whose inhabitants can index arrays 93 | 94 | # TODO Implicit arguments to data definitions 95 | # (Probably `a` should be implicit) 96 | struct Element(set:(Set a)) given (a|Ord) = 97 | val: Nat 98 | 99 | # TODO The set argument could be implicit (inferred from the Element 100 | # type), but maybe it's easier to read if it's explicit. 
101 | def member(x:a, set:(Set a)) -> Maybe (Element set) given (a|Ord) = 102 | UnsafeAsSet(_, elts) = set 103 | case search_sorted_exact elts x of 104 | Just n -> Just $ Element(ordinal n) 105 | Nothing -> Nothing 106 | 107 | def value(x:Element set) -> a given (a|Ord, set:Set a) = 108 | UnsafeAsSet(_, elts) = set 109 | elts[unsafe_from_ordinal x.val] 110 | 111 | instance Ix(Element set) given (a|Ord, set:Set a) 112 | def size'() = set_size set 113 | def ordinal(n) = n.val 114 | def unsafe_from_ordinal(n) = Element(n) 115 | 116 | instance Eq(Element set) given (a|Ord, set:Set a) 117 | def (==)(ix1, ix2) = ordinal ix1 == ordinal ix2 118 | 119 | instance Ord(Element set) given (a|Ord, set:Set a) 120 | def (<)(ix1, ix2) = ordinal ix1 < ordinal ix2 121 | def (>)(ix1, ix2) = ordinal ix1 > ordinal ix2 122 | -------------------------------------------------------------------------------- /lib/sort.dx: -------------------------------------------------------------------------------- 1 | '# Monoidal Merge Sort 2 | 3 | 'Warning: Very slow for now!! 4 | 5 | 'Because merging sorted lists is associative, we can expose the 6 | parallelism of merge sort to the Dex compiler by making a monoid 7 | and using `reduce` or `yield_accum`. 8 | However, this approach puts a lot of pressure on the compiler. 9 | As noted on [StackOverflow](https://stackoverflow.com/questions/21877572/can-you-formulate-the-bubble-sort-as-a-monoid-or-semigroup), 10 | if the compiler does the reduction one element at a time, 11 | it's doing bubble / insertion sort with quadratic time cost. 12 | However, if it breaks the list in half recursively, it'll be doing parallel mergesort. 13 | Currently the Dex compiler will do the quadratic-time version. 14 | 15 | def concat_table(leftin: a=>v, rightin: b=>v) -> (Either a b=>v) given (a|Ix, b|Ix, v:Type) = 16 | for idx. 
case idx of 17 | Left i -> leftin[i] 18 | Right i -> rightin[i] 19 | 20 | def merge_sorted_tables(xs:m=>a, ys:n=>a) -> (Either m n=>a) given (a|Ord, m|Ix, n|Ix) = 21 | # Possible improvements: 22 | # 1) Using a SortedTable type. 23 | # 2) Avoid needlessly initializing the return array. 24 | init = concat_table xs ys # Initialize array of correct size. 25 | yield_state init \buf. 26 | with_state (0, 0) \countrefs. 27 | for i:(Either m n). 28 | (cur_x, cur_y) = get countrefs 29 | if cur_y >= size n # no ys left 30 | then 31 | countrefs := (cur_x + 1, cur_y) 32 | buf!i := xs[unsafe_from_ordinal cur_x] 33 | else 34 | if cur_x < size m # still xs left 35 | then 36 | if xs[unsafe_from_ordinal cur_x] <= ys[unsafe_from_ordinal cur_y] 37 | then 38 | countrefs := (cur_x + 1, cur_y) 39 | buf!i := xs[unsafe_from_ordinal cur_x] 40 | else 41 | countrefs := (cur_x, cur_y + 1) 42 | buf!i := ys[unsafe_from_ordinal cur_y] 43 | 44 | def merge_sorted_lists(lx: List a, ly: List a) -> List a given (a|Ord) = 45 | # Need this wrapper because Dex can't automatically weaken 46 | # (a | b)=>c to ((Fin d)=>c) 47 | AsList(nx, xs) = lx 48 | AsList(ny, ys) = ly 49 | sorted = merge_sorted_tables xs ys 50 | newsize = nx + ny 51 | AsList _ $ unsafe_cast_table(to=Fin newsize, sorted) 52 | 53 | # Warning: Has quadratic runtime cost for now. 54 | def sort(xs: n=>a) -> n=>a given (n|Ix, a|Ord) = 55 | xlists = each xs \x. to_list([x]) 56 | # reduce might someday recursively subdivide the problem. 57 | AsList(_, r) = reduce xlists mempty merge_sorted_lists 58 | unsafe_cast_table(to=n, r) 59 | 60 | def (+|)(i:n, delta:Nat) -> n given (n|Ix) = 61 | i' = ordinal i + delta 62 | from_ordinal $ select (i' >= size n) (size n -| 1) i' 63 | 64 | def is_sorted(xs:n=>a) -> Bool given (a|Ord, n|Ix) = 65 | all for i:n. 
xs[i] <= xs[i +| 1] 66 | -------------------------------------------------------------------------------- /misc/build-web-index: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # This script constructs a Markdown file (on standard output) that constitutes 4 | # an index of all web-rendered pages of Dex examples, documentation, and 5 | # libraries. 6 | 7 | # The rendered HTML of the index is meant to be placed at the root of the Dex 8 | # page tree, so that relative links work. 9 | 10 | # The script accepts three arguments: 11 | # - One argument containing a space-separated list of documentation files 12 | # - One argument containing a space-separated list of example files 13 | # - One argument containing a space-separated list of library files 14 | 15 | # NOTE: The indexing script requires the first line of every file to be its 16 | # title, identified by a "'# " prefix in the Dex source. A file whose first 17 | # line is not such a title makes the script fail loudly with a ValueError 18 | # instead of being silently omitted from the index. 
19 | 20 | import re 21 | import sys 22 | 23 | def file_block(files): 24 | for fname in files: 25 | if fname.startswith("doc/"): 26 | link_name = fname[len("doc/"):-len(".dx")] 27 | else: 28 | link_name = fname[:-len(".dx")] 29 | with open(fname, 'r') as f: 30 | line = f.readline() 31 | title = re.match(r"' *# ?(.*)", line) 32 | if title: 33 | print(f"- [{fname}]({link_name}.html) {title.group(1)}") 34 | else: 35 | raise ValueError(f"First line of file {fname} was not a title (top-level Markdown heading)") 36 | 37 | def main(): 38 | docs, examples, libraries = sys.argv[1:4] 39 | 40 | print("# InDex"); print("") 41 | 42 | print("## Documentation"); print("") 43 | 44 | file_block(docs.split()) 45 | 46 | print(""); print("## Examples"); print("") 47 | 48 | file_block(examples.split()) 49 | 50 | print(""); print("## Libraries"); print("") 51 | 52 | print("- [lib/prelude.dx](lib/prelude.html): The Dex Prelude (automatically imported)") 53 | 54 | file_block(libraries.split()) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /misc/check-no-diff: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script checks whether two input files have a diff. 4 | # 5 | # Usage: 6 | # 7 | # check-no-diff <first file> <second file> 8 | # 9 | # If there is no diff, the script exits with a zero success status. 10 | # If there is a diff, the script prints the diff to stdout and exits with a 11 | # non-zero error status. 12 | 13 | tmpdiff=$(mktemp) 14 | diff --left-column -y $1 $2 > $tmpdiff \ 15 | && echo OK || (cat $tmpdiff; false) 16 | status=$? 17 | 18 | rm $tmpdiff 19 | exit $status 20 | -------------------------------------------------------------------------------- /misc/check-quine: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script checks quine programs. 
4 | # https://en.wikipedia.org/wiki/Quine_(computing) 5 | # 6 | # Usage: 7 | # 8 | # check-quine <input file> <command> 9 | # check-quine foo.dx dex -- script --allow-errors 10 | # 11 | # Technically, quines are programs that take no input and produce their own 12 | # source as output. This script instead runs a command with an input file and 13 | # checks whether the output has no diff with the input file. 14 | # 15 | # If the output of the command applied to the input file has no diff with the 16 | # input file, then the input file is a quine and the script exits with a zero 17 | # success status. 18 | # 19 | # Otherwise, if the input file is not a quine, the script prints the diff to 20 | # stdout and exits with a non-zero error status. 21 | 22 | tmpout=$(mktemp) 23 | errout=$(mktemp) 24 | 25 | if ${@:2} $1 > $tmpout 2> $errout ; then 26 | # We check for differences up to timing outputs from the %time or %bench 27 | # commands, because those are expected to vary from run to run. 28 | misc/check-no-diff \ 29 | <(grep -vE "> (Compile|Run) time: " $1) \ 30 | <(grep -vE "> (Compile|Run) time: " $tmpout) 31 | status=$? 32 | else 33 | status=$? 
34 | cat $tmpout 35 | fi 36 | 37 | cat $errout 38 | 39 | rm $errout 40 | rm $tmpout 41 | 42 | exit $status 43 | -------------------------------------------------------------------------------- /misc/dex-completion.bash: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | _dex_completions() 4 | { 5 | COMPREPLY=(); 6 | local word="${COMP_WORDS[COMP_CWORD]}"; 7 | if [ "$COMP_CWORD" -eq 1 ]; then 8 | COMPREPLY=($(compgen -W "repl script web watch" -- "$word")); 9 | elif [ "${COMP_WORDS[1]}" = "repl" ]; then 10 | COMPREPLY=() 11 | elif [ "${COMP_WORDS[1]}" = "script" ]; then 12 | COMPREPLY=($(compgen -G "*.dx" -- "$word")); 13 | elif [ "${COMP_WORDS[1]}" = "watch" ]; then 14 | COMPREPLY=($(compgen -G "*.dx" -- "$word")); 15 | elif [ "${COMP_WORDS[1]}" = "web" ]; then 16 | COMPREPLY=($(compgen -G "*.dx" -- "$word")); 17 | else 18 | COMPREPLY=() 19 | fi 20 | } 21 | complete -F _dex_completions dex 22 | -------------------------------------------------------------------------------- /misc/dex.el: -------------------------------------------------------------------------------- 1 | ;; Copyright 2019 Google LLC 2 | ;; 3 | ;; Use of this source code is governed by a BSD-style 4 | ;; license that can be found in the LICENSE file or at 5 | ;; https://developers.google.com/open-source/licenses/bsd 6 | 7 | (setq dex-highlights 8 | `(("#.*$" . font-lock-comment-face) 9 | ("^> .*$" . font-lock-comment-face) 10 | ("^'\\(.\\|\n.\\)*\n\n" . font-lock-comment-face) 11 | ("\\w+:" . font-lock-comment-face) 12 | ("^:\\w*" . font-lock-preprocessor-face) 13 | (,(concat 14 | "\\bdef\\b\\|\\bfor\\b\\|\\brof\\b\\|\\bcase\\b\\|" 15 | "\\bstruct\\b\\|\\benum\\b\\|\\bwhere\\b\\|\\bof\\b\\|" 16 | "\\bif\\b\\|\\bthen\\b\\|\\belse\\b\\|\\binterface\\b\\|" 17 | "\\binstance\\b\\|\\bgiven\\b\\|\\bdo\\b\\|\\bview\\b\\|" 18 | "\\bwith\\b\\|\\bself\\b\\|" 19 | "\\bimport\\b\\|\\bforeign\\b\\|\\bsatisfying\\b") . 
20 | font-lock-keyword-face) 21 | ("[-.,!;$^&*:~+/=<>|?\\\\]" . font-lock-variable-name-face) 22 | ("\\b[[:upper:]][[:alnum:]]*\\b" . font-lock-type-face) 23 | ("^@[[:alnum:]]*\\b" . font-lock-keyword-face) 24 | ("\\bdef *\\([_[:alnum:]]*\\)\\b" . (1 font-lock-function-name-face)) 25 | )) 26 | 27 | (defun dex-font-lock-extend-region () 28 | (save-excursion 29 | (goto-char font-lock-beg) 30 | (re-search-backward "\n\n" nil t) 31 | (setq font-lock-beg (point)) 32 | (goto-char font-lock-end) 33 | (re-search-forward "\n\n" nil t) 34 | (setq font-lock-end (point)))) 35 | 36 | (define-derived-mode dex-mode fundamental-mode "dex" 37 | (setq font-lock-defaults '(dex-highlights)) 38 | (setq-local comment-start "#") 39 | (setq-local comment-end "") 40 | (setq-local syntax-propertize-function 41 | (syntax-propertize-rules (".>\\( +\\)" (1 ".")))) 42 | (set (make-local-variable 'font-lock-multiline) t) 43 | (add-hook 'font-lock-extend-region-functions 44 | 'dex-font-lock-extend-region)) 45 | 46 | (add-to-list 'auto-mode-alist '("\\.dx\\'" . dex-mode)) 47 | -------------------------------------------------------------------------------- /misc/file-check: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | declare -a possible_filecheck_locations=("FileCheck-12" 4 | "FileCheck") 5 | FILECHECK=$(\ 6 | for fc in "${possible_filecheck_locations[@]}" ; do \ 7 | if [[ $(command -v "$fc" 2>/dev/null) ]]; \ 8 | then echo "$fc" ; break ; \ 9 | fi ; \ 10 | done) 11 | 12 | if [[ -z "$FILECHECK" ]]; then 13 | echo "FileCheck not found" 14 | exit 1 15 | fi 16 | 17 | if ${@:2} $1 --outfmt result-only | $FILECHECK $1 ; then 18 | echo "OK" 19 | exit 0 20 | else 21 | exit $? 
22 | fi 23 | 24 | -------------------------------------------------------------------------------- /python/dex/interop/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Use of this source code is governed by a BSD-style 4 | # license that can be found in the LICENSE file or at 5 | # https://developers.google.com/open-source/licenses/bsd 6 | -------------------------------------------------------------------------------- /python/dex/interop/jax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Use of this source code is governed by a BSD-style 4 | # license that can be found in the LICENSE file or at 5 | # https://developers.google.com/open-source/licenses/bsd 6 | 7 | from .apply import primitive 8 | from .jax2dex import dexjit 9 | 10 | __all__ = [ 11 | 'primitive', 12 | 'dexjit', 13 | ] 14 | -------------------------------------------------------------------------------- /python/example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Use of this source code is governed by a BSD-style 4 | # license that can be found in the LICENSE file or at 5 | # https://developers.google.com/open-source/licenses/bsd 6 | 7 | import dex 8 | 9 | m = dex.Module(""" 10 | x = 2.5 11 | y = [2, 3, 4] 12 | """) 13 | 14 | print(m.x) 15 | print(m.y) 16 | print(int(m.x)) 17 | -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import os 3 | 4 | # Check dex so file exists in dex directory. 
5 | so_file = "libDex.so" 6 | dex_dir = os.path.join(os.path.dirname(__file__), 'dex') 7 | if not os.path.exists(os.path.join(dex_dir, so_file)): 8 | raise FileNotFoundError(f"{so_file} not found in dex/, " 9 | f"please run `make build-ffis`") 10 | 11 | setup( 12 | name='dex', 13 | version='0.0.1', 14 | description='A research language for typed, functional array processing', 15 | license='BSD', 16 | author='Adam Paszke', 17 | author_email='apaszke@google.com', 18 | packages=find_packages(), 19 | package_data={'dex': ['libDex.so']}, 20 | install_requires=['numpy'], 21 | ) 22 | -------------------------------------------------------------------------------- /python/tests/api_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Use of this source code is governed by a BSD-style 4 | # license that can be found in the LICENSE file or at 5 | # https://developers.google.com/open-source/licenses/bsd 6 | 7 | import unittest 8 | from textwrap import dedent 9 | 10 | import dex 11 | 12 | class APITest(unittest.TestCase): 13 | def test_eval(self): 14 | cases = [ 15 | "2.5", 16 | "4", 17 | "[2, 3, 4]", 18 | ] 19 | for expr in cases: 20 | assert str(dex.eval(expr)) == expr 21 | 22 | def test_module_attrs(self): 23 | m = dex.Module(dedent(""" 24 | x = 2.5 25 | y = [2, 3, 4] 26 | """)) 27 | assert str(m.x) == "2.5" 28 | assert str(m.y) == "[2, 3, 4]" 29 | 30 | @unittest.skip 31 | def test_function_call(self): 32 | m = dex.Module(dedent(""" 33 | def addOne (x: Float) : Float = x + 1.0 34 | """)) 35 | x = dex.eval("2.5") 36 | y = dex.eval("[2, 3, 4]") 37 | assert str(m.addOne(x)) == "3.5" 38 | assert str(m.sum(y)) == "9" 39 | 40 | def test_scalar_conversions(self): 41 | assert float(dex.eval("5.0")) == 5.0 42 | assert int(dex.eval("5")) == 5 43 | assert str(dex.Atom(5)) == "5" 44 | assert str(dex.Atom(5.0)) == "5." 
45 | 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /python/tests/jaxpr_json_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Use of this source code is governed by a BSD-style 4 | # license that can be found in the LICENSE file or at 5 | # https://developers.google.com/open-source/licenses/bsd 6 | 7 | import json 8 | import math 9 | import unittest 10 | 11 | import jax 12 | import jax.numpy as jnp 13 | from dex import api 14 | from dex import native_function as nf 15 | from dex import prelude 16 | import dex.interop.jax.jaxpr_json as jj 17 | 18 | def check_json_round_trip(jaxpr): 19 | dictified = jj.dump_jaxpr(jaxpr) 20 | dump_str = json.dumps(dictified, indent=2) 21 | reconstituted = json.loads(dump_str) 22 | jaxpr_recon = jj.load_jaxpr(reconstituted) 23 | assert str(jaxpr) == str(jaxpr_recon) 24 | 25 | def check_haskell_round_trip(jaxpr): 26 | dictified = jj.dump_jaxpr(jaxpr) 27 | dump_str = json.dumps(dictified, indent=2) 28 | returned = api.from_cstr(api.roundtripJaxprJson(api.as_cstr(dump_str))) 29 | try: 30 | assert dictified == json.loads(returned) 31 | except json.decoder.JSONDecodeError: 32 | assert False, returned 33 | 34 | class JaxprJsonTest(unittest.TestCase): 35 | 36 | def test_json_one_prim(self): 37 | jaxpr = jax.make_jaxpr(jax.numpy.sin)(3.) 38 | check_json_round_trip(jaxpr) 39 | 40 | def test_json_literal(self): 41 | f = lambda x: jax.numpy.sin(x + 1) + 3 42 | check_json_round_trip(jax.make_jaxpr(f)(3.)) 43 | 44 | def test_json_scan(self): 45 | def f(xs): 46 | return jax.lax.scan(lambda tot, z: (tot + z, tot), 0.25, xs) 47 | check_json_round_trip(jax.make_jaxpr(f)(jnp.array([1., 2., 3.]))) 48 | 49 | def test_haskell_one_prim(self): 50 | jaxpr = jax.make_jaxpr(jax.numpy.sin)(3.) 
51 | check_haskell_round_trip(jaxpr) 52 | 53 | def test_haskell_literal(self): 54 | f = lambda x: jax.numpy.sin(x + 1) + 3 55 | check_haskell_round_trip(jax.make_jaxpr(f)(3.)) 56 | 57 | def test_haskell_scan(self): 58 | def f(xs): 59 | return jax.lax.scan(lambda tot, z: (tot + z, tot), 0.25, xs) 60 | check_haskell_round_trip(jax.make_jaxpr(f)(jnp.array([1., 2., 3.]))) 61 | 62 | def test_compute_one_prim(self): 63 | jaxpr = jax.make_jaxpr(jax.numpy.sin)(3.) 64 | jaxpr_dict = jj.dump_jaxpr(jaxpr) 65 | jaxpr_dump = json.dumps(jaxpr_dict, indent=2) 66 | module = prelude 67 | cc = api.FlatCC 68 | compiled = api.compileJaxpr(module, cc, api.as_cstr(jaxpr_dump)) 69 | func = nf.NativeFunction(module, compiled, cc) 70 | self.assertAlmostEqual(func(3.), math.sin(3.)) 71 | 72 | # TODO Test bigger shapes (matrices?) 73 | # TODO Test dependent shapes (that have variables in them) 74 | # - How do I make a jaxpr with such a thing in it? 75 | # TODO Test more funny primitives like scan (scan itself works) 76 | -------------------------------------------------------------------------------- /python/tests/jit_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Use of this source code is governed by a BSD-style 4 | # license that can be found in the LICENSE file or at 5 | # https://developers.google.com/open-source/licenses/bsd 6 | 7 | import unittest 8 | import ctypes 9 | import numpy as np 10 | import itertools as it 11 | from textwrap import dedent 12 | 13 | import dex 14 | 15 | example_floats = list(map(np.float32, (-1.0, -0.5, 0.0, 0.5, 1.0))) 16 | example_ints = [-10, -5, 0, 5, 10] 17 | 18 | def check_atom(dex_atom, reference, args_iter): 19 | compiled = dex_atom.compile() 20 | ran_any_iter = False 21 | for args in args_iter: 22 | ran_any_iter = True 23 | np.testing.assert_allclose(compiled(*args), reference(*args), 24 | rtol=1e-4, atol=1e-6) 25 | assert ran_any_iter, "Empty argument iterator!" 
26 | 27 | def expr_test(dex_source, reference, args_iter): 28 | def test(self): 29 | return check_atom(dex.eval(dex_source), reference, args_iter) 30 | return test 31 | 32 | class JITTest(unittest.TestCase): 33 | test_sigmoid = expr_test(r"\x:Float. 1.0 / (1.0 + exp(-x))", 34 | lambda x: np.float32(1.0) / (np.float32(1.0) + np.exp(-x)), 35 | ((x,) for x in example_floats)) 36 | 37 | test_multi_arg = expr_test(r"\x:Float y:Float. atan2(x, y)", 38 | np.arctan2, 39 | ((x + 0.01, y) for x, y in it.product(example_floats, repeat=2) 40 | if (x, y) != (0.0, 0.0))) 41 | 42 | test_int_arg = expr_test(r"\x:Int64 y:Int. i64_to_i(x) + y", 43 | lambda x, y: x + y, 44 | it.product(example_ints, example_ints)) 45 | 46 | test_array_scalar = expr_test(r"\x:((Fin 10)=>Float). sum(x)", 47 | np.sum, 48 | [(np.arange(10, dtype=np.float32),)]) 49 | 50 | test_scalar_array = expr_test(r"\x:Int. for i:(Fin 10). x + n_to_i(ordinal(i))", 51 | lambda x: x + np.arange(10, dtype=np.int32), 52 | [(i,) for i in range(5)]) 53 | 54 | test_array_array = expr_test(r"\x:((Fin 10)=>Float). for i. exp(x[i])", 55 | np.exp, 56 | [(np.arange(10, dtype=np.float32),)]) 57 | 58 | def test_polymorphic_array_1d(self): 59 | m = dex.Module(dedent(""" 60 | def addTwo(x: (Fin n)=>Float) -> (Fin n)=>Float given (n) = for i. x[i] + 2.0 61 | """)) 62 | check_atom(m.addTwo, lambda x: x + 2, 63 | [(np.arange(l, dtype=np.float32),) for l in (2, 5, 10)]) 64 | 65 | def test_polymorphic_array_2d(self): 66 | m = dex.Module(dedent(""" 67 | def myTranspose(x : (Fin n)=>(Fin m)=>Float) -> (Fin m)=>(Fin n)=>Float given (n, m) = 68 | for i j. x[j, i] 69 | """)) 70 | check_atom(m.myTranspose, lambda x: x.T, 71 | [(np.arange(a*b, dtype=np.float32).reshape((a, b)),) 72 | for a, b in it.product((2, 5, 10), repeat=2)]) 73 | 74 | def test_tuple_return(self): 75 | dex_func = dex.eval(r"\x: ((Fin 10) => Float). (x, 2. .* x, 3. 
.* x)") 76 | reference = lambda x: (x, 2 * x, 3 * x) 77 | 78 | x = np.arange(10, dtype=np.float32) 79 | 80 | dex_output = dex_func.compile()(x) 81 | reference_output = reference(x) 82 | 83 | self.assertEqual(len(dex_output), len(reference_output)) 84 | for dex_array, ref_array in zip(dex_output, reference_output): 85 | np.testing.assert_allclose(dex_array, ref_array) 86 | 87 | def test_arrays_of_nats(self): 88 | dex_func = dex.eval(r"\x: ((Fin 10) => Nat). x + x") 89 | reference = lambda x: x + x 90 | 91 | x = np.arange(10, dtype=np.uint32) 92 | 93 | dex_output = dex_func.compile()(x) 94 | reference_output = reference(x) 95 | np.testing.assert_allclose(dex_output, reference_output) 96 | 97 | if __name__ == "__main__": 98 | unittest.main() 99 | -------------------------------------------------------------------------------- /shell.nix: -------------------------------------------------------------------------------- 1 | { pkgs ? import <nixpkgs> {} }: 2 | pkgs.stdenv.mkDerivation { 3 | name = "dex"; 4 | buildInputs = with pkgs; [ 5 | cabal-install 6 | cacert 7 | clang_12 8 | git 9 | haskell.compiler.ghc884 10 | libpng 11 | llvm_12 12 | pkg-config 13 | stack 14 | zlib 15 | ]; 16 | } 17 | -------------------------------------------------------------------------------- /src/Dex/Foreign/API.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2020 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | module Dex.Foreign.API where 8 | 9 | import Foreign.Ptr 10 | import Foreign.C 11 | 12 | import Dex.Foreign.Context 13 | import Dex.Foreign.Serialize 14 | import Dex.Foreign.JAX 15 | import Dex.Foreign.JIT 16 | 17 | -- Public API (commented out exports are defined in rts.c) 18 | 19 | -- Initialization and basic runtime 20 | -- foreign export ccall "dexInit" _ :: IO () 21 | -- 
foreign export ccall "dexFini" _ :: IO () 22 | -- foreign export ccall "dexGetError" _ :: CString 23 | 24 | -- Context 25 | foreign export ccall "dexCreateContext" dexCreateContext :: IO (Ptr Context) 26 | foreign export ccall "dexDestroyContext" dexDestroyContext :: Ptr Context -> IO () 27 | foreign export ccall "dexForkContext" dexForkContext :: Ptr Context -> IO (Ptr Context) 28 | foreign export ccall "dexInsert" dexInsert :: Ptr Context -> CString -> Ptr AtomEx -> IO () 29 | foreign export ccall "dexEval" dexEval :: Ptr Context -> CString -> IO CInt 30 | foreign export ccall "dexLookup" dexLookup :: Ptr Context -> CString -> IO (Ptr AtomEx) 31 | foreign export ccall "dexFreshName" dexFreshName :: Ptr Context -> IO CString 32 | 33 | -- Serialization 34 | foreign export ccall "dexPrint" dexPrint :: Ptr Context -> Ptr AtomEx -> IO CString 35 | foreign export ccall "dexToCAtom" dexToCAtom :: Ptr AtomEx -> Ptr CAtom -> IO CInt 36 | foreign export ccall "dexFromCAtom" dexFromCAtom :: Ptr CAtom -> IO (Ptr AtomEx) 37 | 38 | -- JIT 39 | foreign export ccall "dexCompile" dexCompile :: Ptr Context -> CInt -> Ptr AtomEx -> IO ExportNativeFunctionAddr 40 | foreign export ccall "dexUnload" dexUnload :: Ptr Context -> ExportNativeFunctionAddr -> IO () 41 | foreign export ccall "dexGetFunctionSignature" dexGetFunctionSignature :: Ptr Context -> ExportNativeFunctionAddr -> IO (Ptr ClosedExportedSignature) 42 | foreign export ccall "dexFreeFunctionSignature" dexFreeFunctionSignature :: Ptr ClosedExportedSignature -> IO () 43 | 44 | -- JAX serialization 45 | foreign export ccall "dexRoundtripJaxprJson" dexRoundtripJaxprJson :: CString -> IO CString 46 | foreign export ccall "dexCompileJaxpr" dexCompileJaxpr :: Ptr Context -> CInt -> CString -> IO ExportNativeFunctionAddr 47 | -------------------------------------------------------------------------------- /src/Dex/Foreign/JAX.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 
2023 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | module Dex.Foreign.JAX where 8 | 9 | import Control.Monad.IO.Class 10 | import Data.Aeson (encode, eitherDecode') 11 | import qualified Data.ByteString.Lazy.Char8 as B 12 | import Foreign.C 13 | import Foreign.Ptr 14 | 15 | import Dex.Foreign.Context 16 | import Export 17 | import JAX.Concrete 18 | import JAX.Rename 19 | import JAX.ToSimp 20 | import Name 21 | 22 | -- TODO newCString just mallocs the string; we have to 23 | -- arrange for the caller to free it. 24 | dexRoundtripJaxprJson :: CString -> IO CString 25 | dexRoundtripJaxprJson jsonPtr = do 26 | json <- B.pack <$> peekCString jsonPtr 27 | let maybeJaxpr :: Either String (ClosedJaxpr VoidS) = eitherDecode' json 28 | case maybeJaxpr of 29 | Right jaxpr -> do 30 | let redumped = encode jaxpr 31 | newCString $ B.unpack redumped 32 | Left err -> newCString err 33 | 34 | dexCompileJaxpr :: Ptr Context -> CInt -> CString -> IO ExportNativeFunctionAddr 35 | dexCompileJaxpr ctxPtr ccInt jsonPtr = do 36 | json <- B.pack <$> peekCString jsonPtr 37 | let maybeJaxpr :: Either String (ClosedJaxpr VoidS) = eitherDecode' json 38 | case maybeJaxpr of 39 | Right jaxpr -> runTopperMFromContext ctxPtr do 40 | Distinct <- getDistinct 41 | jRename <- liftRenameM $ renameClosedJaxpr (unsafeCoerceE jaxpr) 42 | sLam <- liftJaxSimpM $ simplifyClosedJaxpr jRename 43 | func <- prepareSLamForExport (intAsCC ccInt) sLam 44 | liftIO $ emitExport ctxPtr func 45 | Left err -> error err 46 | -------------------------------------------------------------------------------- /src/Dex/Foreign/JIT.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2020 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 
| -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | {-# OPTIONS_GHC -Wno-orphans #-} 8 | 9 | module Dex.Foreign.JIT ( 10 | NativeFunction, ClosedExportedSignature, 11 | ExportNativeFunction (..), ExportNativeFunctionAddr, 12 | dexGetFunctionSignature, dexFreeFunctionSignature, 13 | dexCompile, dexUnload 14 | ) where 15 | 16 | import Control.Concurrent.MVar 17 | import Control.Monad.State.Strict 18 | 19 | import Foreign.Ptr 20 | import Foreign.C.String 21 | import Foreign.C.Types 22 | import Foreign.Storable 23 | import Foreign.Marshal.Alloc 24 | 25 | import Data.Functor 26 | import qualified Data.Map.Strict as M 27 | 28 | import Export 29 | import Name 30 | import Types.Core 31 | 32 | import Dex.Foreign.Util 33 | import Dex.Foreign.Context 34 | 35 | dexCompile :: Ptr Context -> CInt -> Ptr AtomEx -> IO ExportNativeFunctionAddr 36 | dexCompile ctxPtr ccInt funcAtomPtr = catchErrors do 37 | AtomEx funcAtom <- fromStablePtr funcAtomPtr 38 | let cc = intAsCC ccInt 39 | runTopperMFromContext ctxPtr do 40 | -- TODO: Check if atom is compatible with context! Use module name? 
41 | func <- prepareFunctionForExport cc (unsafeCoerceE funcAtom) 42 | liftIO $ emitExport ctxPtr func 43 | 44 | dexGetFunctionSignature :: Ptr Context -> ExportNativeFunctionAddr -> IO (Ptr (ExportedSignature 'VoidS)) 45 | dexGetFunctionSignature ctxPtr funcPtr = do 46 | Context _ _ ptrTabMVar <- fromStablePtr ctxPtr 47 | addrTable <- readMVar ptrTabMVar 48 | case M.lookup funcPtr addrTable of 49 | Nothing -> setError "Invalid function address" $> nullPtr 50 | Just ExportNativeFunction{..} -> putOnHeap nativeSignature 51 | 52 | dexFreeFunctionSignature :: Ptr (ExportedSignature 'VoidS) -> IO () 53 | dexFreeFunctionSignature sigPtr = do 54 | let strPtr = castPtr @(ExportedSignature 'VoidS) @CString sigPtr 55 | free =<< peekElemOff strPtr 0 56 | free =<< peekElemOff strPtr 1 57 | free =<< peekElemOff strPtr 2 58 | free sigPtr 59 | 60 | dexUnload :: Ptr Context -> ExportNativeFunctionAddr -> IO () 61 | dexUnload ctxPtr funcPtr = do 62 | f <- popFromNativeFunctionTable ctxPtr funcPtr 63 | nativeFunTeardown $ nativeFunction f 64 | -------------------------------------------------------------------------------- /src/Dex/Foreign/Serialize.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2020 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | module Dex.Foreign.Serialize ( 8 | CAtom, 9 | dexPrint, dexToCAtom, dexFromCAtom 10 | ) where 11 | 12 | import Control.Monad.IO.Class 13 | import Data.Word 14 | import Data.Functor 15 | 16 | import Foreign.C 17 | import Foreign.Ptr 18 | import Foreign.Storable 19 | 20 | import IRVariants 21 | import Name 22 | import TopLevel 23 | import Types.Core hiding (CAtom) 24 | import Types.Primitives 25 | 26 | import Dex.Foreign.Context 27 | import Dex.Foreign.Util 28 | 29 | -- TODO: Free! 
30 | dexPrint :: Ptr Context -> Ptr AtomEx -> IO CString 31 | dexPrint contextPtr atomPtr = do 32 | AtomEx atom <- fromStablePtr atomPtr 33 | runTopperMFromContext contextPtr do 34 | -- TODO: Check consistency of atom and context 35 | liftIO . newCString =<< printCodegen (unsafeCoerceE atom) 36 | 37 | data CAtom = CLit LitVal | CRectArray (Ptr ()) [Int] [Int] 38 | 39 | instance Storable CAtom where 40 | sizeOf _ = tag + val + val + val 41 | where tag = 8; val = 8 42 | alignment _ = 8 43 | peek addr = do 44 | tag <- val @Word64 0 45 | case tag of 46 | 0 -> do 47 | litTag <- val @Word64 1 48 | CLit <$> case litTag of 49 | 0 -> Int64Lit <$> val 2 50 | 1 -> Int32Lit <$> val 2 51 | 2 -> Word8Lit <$> val 2 52 | 3 -> Float64Lit <$> val 2 53 | 4 -> Float32Lit <$> val 2 54 | 5 -> Word32Lit <$> val 2 55 | 6 -> Word64Lit <$> val 2 56 | _ -> error "Invalid tag" 57 | _ -> error "Invalid tag" 58 | where 59 | val :: forall a. Storable a => Int -> IO a 60 | val i = peekByteOff (castPtr addr) (i * 8) 61 | poke addr catom = case catom of 62 | CLit lit -> do 63 | val @Word64 0 0 64 | case lit of 65 | Int64Lit v -> val @Word64 1 0 >> val 2 v 66 | Int32Lit v -> val @Word64 1 1 >> val 2 v 67 | Word8Lit v -> val @Word64 1 2 >> val 2 v 68 | Float64Lit v -> val @Word64 1 3 >> val 2 v 69 | Float32Lit v -> val @Word64 1 4 >> val 2 v 70 | Word32Lit v -> val @Word64 1 5 >> val 2 v 71 | Word64Lit v -> val @Word64 1 6 >> val 2 v 72 | PtrLit _ _ -> error "Unsupported" 73 | CRectArray _ _ _ -> error "Unsupported" 74 | where 75 | val :: forall a. 
Storable a => Int -> a -> IO () 76 | val i v = pokeByteOff (castPtr addr) (i * 8) v 77 | 78 | dexToCAtom :: Ptr AtomEx -> Ptr CAtom -> IO CInt 79 | dexToCAtom atomPtr resultPtr = do 80 | AtomEx atom <- fromStablePtr atomPtr 81 | scalarAtomToCAtom atom 82 | where 83 | scalarAtomToCAtom :: Atom CoreIR n -> IO CInt 84 | scalarAtomToCAtom atom = case atom of 85 | Con con -> case con of 86 | Lit l -> poke resultPtr (CLit l) $> 1 87 | _ -> notSerializable 88 | NewtypeCon NatCon rep -> scalarAtomToCAtom rep 89 | _ -> notSerializable 90 | 91 | notSerializable = setError "Unserializable atom" $> 0 92 | 93 | dexFromCAtom :: Ptr CAtom -> IO (Ptr AtomEx) 94 | dexFromCAtom catomPtr = do 95 | catom <- peek catomPtr 96 | case catom of 97 | CLit lit -> toStablePtr $ AtomEx $ Con $ Lit lit 98 | CRectArray _ _ _ -> unsupported 99 | where 100 | unsupported = setError "Unsupported CAtom" $> nullPtr 101 | -------------------------------------------------------------------------------- /src/Dex/Foreign/Util.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2020 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | module Dex.Foreign.Util (fromStablePtr, toStablePtr, putOnHeap, setError, catchErrors, copyMVar) where 8 | 9 | import Control.Concurrent.MVar 10 | 11 | import Data.Int 12 | import Data.Functor 13 | 14 | import Foreign.Ptr 15 | import Foreign.StablePtr 16 | import Foreign.Storable 17 | import Foreign.C.String 18 | import Foreign.Marshal.Alloc 19 | 20 | import Err 21 | 22 | fromStablePtr :: Ptr a -> IO a 23 | fromStablePtr = deRefStablePtr . castPtrToStablePtr . castPtr 24 | 25 | toStablePtr :: a -> IO (Ptr a) 26 | toStablePtr x = castPtr . 
castStablePtrToPtr <$> newStablePtr x 27 | 28 | putOnHeap :: Storable a => a -> IO (Ptr a) 29 | putOnHeap x = do 30 | ptr <- malloc 31 | poke ptr x 32 | return ptr 33 | 34 | catchErrors :: IO (FunPtr a) -> IO (FunPtr a) 35 | catchErrors m = catchIOExcept m >>= \case 36 | Success ans -> return ans 37 | Failure err -> setError (pprint err) $> castPtrToFunPtr nullPtr 38 | 39 | foreign import ccall "_internal_dexSetError" internalSetErrorPtr :: CString -> Int64 -> IO () 40 | 41 | setError :: String -> IO () 42 | setError msg = withCStringLen msg $ \(ptr, len) -> 43 | internalSetErrorPtr ptr (fromIntegral len) 44 | 45 | copyMVar :: MVar a -> IO (MVar a) 46 | copyMVar mvar = readMVar mvar >>= newMVar 47 | -------------------------------------------------------------------------------- /src/Dex/Foreign/rts.c: -------------------------------------------------------------------------------- 1 | #include <stdlib.h> 2 | #include <stddef.h> 3 | #include <string.h> 4 | #include "HsFFI.h" 5 | 6 | void dexInit() { 7 | int argc = 4; 8 | char *argv[] = { "+RTS", "-I0", "-A16m", "-RTS", NULL }; 9 | char **pargv = argv; 10 | 11 | hs_init(&argc, &pargv); 12 | } 13 | 14 | void dexFini() { 15 | hs_exit(); 16 | } 17 | 18 | __thread char dex_err_storage[2048]; 19 | 20 | const char* dexGetError() { 21 | return dex_err_storage; 22 | } 23 | 24 | void _internal_dexSetError(char* new_err, int64_t len) { 25 | if (len > 2048) len = 2048; 26 | memcpy(dex_err_storage, new_err, len); 27 | dex_err_storage[2047] = 0; 28 | } 29 | 30 | typedef int64_t (*dex_xla_f)(void*, void**); 31 | void dexXLACPUTrampoline(void* out, void** in) { 32 | dex_xla_f f = *((dex_xla_f*)(*in)); 33 | f(out, in + 1); 34 | } 35 | -------------------------------------------------------------------------------- /src/lib/CUDA.hs: -------------------------------------------------------------------------------- 1 | 2 | module CUDA (hasCUDA, loadCUDAArray, synchronizeCUDA, getCudaArchitecture) where 3 | 4 | import Data.Int 5 | 
import Foreign.Ptr 6 | #ifdef DEX_CUDA 7 | import Foreign.C 8 | #else 9 | #endif 10 | 11 | hasCUDA :: Bool 12 | 13 | #ifdef DEX_CUDA 14 | hasCUDA = True 15 | 16 | foreign import ccall "dex_cuMemcpyDtoH" cuMemcpyDToH :: Int64 -> Ptr () -> Ptr () -> IO () 17 | foreign import ccall "dex_synchronizeCUDA" synchronizeCUDA :: IO () 18 | foreign import ccall "dex_ensure_has_cuda_context" ensureHasCUDAContext :: IO () 19 | foreign import ccall "dex_get_cuda_architecture" dex_getCudaArchitecture :: Int -> CString -> IO () 20 | 21 | getCudaArchitecture :: Int -> IO String 22 | getCudaArchitecture dev = 23 | withCString "sm_00" $ \cs -> 24 | dex_getCudaArchitecture dev cs >> peekCString cs 25 | #else 26 | hasCUDA = False 27 | 28 | cuMemcpyDToH :: Int64 -> Ptr () -> Ptr () -> IO () 29 | cuMemcpyDToH = error "Dex built without CUDA support" 30 | 31 | synchronizeCUDA :: IO () 32 | synchronizeCUDA = return () 33 | {-# SCC synchronizeCUDA #-} 34 | 35 | ensureHasCUDAContext :: IO () 36 | ensureHasCUDAContext = return () 37 | {-# SCC ensureHasCUDAContext #-} 38 | 39 | getCudaArchitecture :: Int -> IO String 40 | getCudaArchitecture _ = error "Dex built without CUDA support" 41 | #endif 42 | 43 | loadCUDAArray :: Ptr () -> Ptr () -> Int -> IO () 44 | loadCUDAArray hostPtr devicePtr bytes = do 45 | ensureHasCUDAContext 46 | cuMemcpyDToH (fromIntegral bytes) devicePtr hostPtr 47 | {-# SCC loadCUDAArray #-} 48 | -------------------------------------------------------------------------------- /src/lib/IRVariants.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2022 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | {-# LANGUAGE AllowAmbiguousTypes #-} 8 | 9 | module IRVariants 10 | ( IR (..), IRPredicate (..), Sat, Sat' 11 | , CoreToSimpIR, InferenceIR, IRRep (..), IRProxy (..), 
interpretIR 12 | , IRsEqual (..), eqIRRep, WhenIR (..)) where 13 | 14 | import GHC.Generics (Generic (..)) 15 | import Data.Store 16 | import Data.Hashable 17 | import Data.Store.Internal 18 | import Data.Kind 19 | 20 | import qualified Unsafe.Coerce as TrulyUnsafe 21 | 22 | data IR = 23 | CoreIR -- used after inference and before simplification 24 | | SimpIR -- used after simplification 25 | deriving (Eq, Ord, Generic, Show, Enum) 26 | instance Store IR 27 | 28 | type CoreToSimpIR = CoreIR -- used during the Core-to-Simp translation 29 | data IRFeature = 30 | DAMOps 31 | | CoreOps 32 | | SimpOps 33 | 34 | -- TODO: make this a hard distinction 35 | type InferenceIR = CoreIR -- used during type inference only 36 | 37 | data IRPredicate = 38 | Is IR 39 | -- TODO: find a way to make this safe and derive it automatically. For now, we 40 | -- assert it manually for the valid cases we know about. 41 | | IsSubsetOf IR 42 | | HasFeature IRFeature 43 | 44 | type Sat (r::IR) (p::IRPredicate) = (Sat' r p ~ True) :: Constraint 45 | type family Sat' (r::IR) (p::IRPredicate) where 46 | Sat' r (Is r) = True 47 | -- subsets 48 | Sat' SimpIR (IsSubsetOf CoreIR) = True 49 | -- DAMOps 50 | Sat' SimpIR (HasFeature DAMOps) = True 51 | -- SimpOps 52 | Sat' SimpIR (HasFeature SimpOps) = True 53 | -- CoreOps 54 | Sat' CoreIR (HasFeature CoreOps) = True 55 | -- otherwise 56 | Sat' _ _ = False 57 | 58 | class IRRep (r::IR) where 59 | getIRRep :: IR 60 | 61 | data IRProxy (r::IR) = IRProxy 62 | 63 | interpretIR :: IR -> (forall r. IRRep r => IRProxy r -> a) -> a 64 | interpretIR ir cont = case ir of 65 | CoreIR -> cont $ IRProxy @CoreIR 66 | SimpIR -> cont $ IRProxy @SimpIR 67 | 68 | instance IRRep CoreIR where getIRRep = CoreIR 69 | instance IRRep SimpIR where getIRRep = SimpIR 70 | 71 | data IRsEqual (r1::IR) (r2::IR) where 72 | IRsEqual :: IRsEqual r r 73 | 74 | eqIRRep :: forall r1 r2. 
(IRRep r1, IRRep r2) => Maybe (IRsEqual r1 r2) 75 | eqIRRep = if r1Rep == r2Rep 76 | then Just (TrulyUnsafe.unsafeCoerce (IRsEqual :: IRsEqual r1 r1) :: IRsEqual r1 r2) 77 | else Nothing 78 | where r1Rep = getIRRep @r1; r2Rep = getIRRep @r2 79 | {-# INLINE eqIRRep #-} 80 | 81 | data WhenIR (r::IR) (r'::IR) (a::Type) where 82 | WhenIR :: a -> WhenIR r r a 83 | 84 | instance (IRRep r, IRRep r', Store e) => Store (WhenIR r r' e) where 85 | size = VarSize \(WhenIR e) -> getSize e 86 | peek = case eqIRRep @r @r' of 87 | Just IRsEqual -> WhenIR <$> peek 88 | Nothing -> error "impossible" 89 | poke (WhenIR e) = poke e 90 | 91 | instance Hashable a => Hashable (WhenIR r r' a) where 92 | hashWithSalt salt (WhenIR a) = hashWithSalt salt a 93 | 94 | deriving instance Show a => Show (WhenIR r r' a) 95 | deriving instance Eq a => Eq (WhenIR r r' a) 96 | -------------------------------------------------------------------------------- /src/lib/JAX/Rename.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2023 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | module JAX.Rename (liftRenameM, renameClosedJaxpr, renameJaxpr) where 8 | 9 | import Control.Monad.Reader 10 | import Data.Map qualified as M 11 | 12 | import Core 13 | import IRVariants 14 | import JAX.Concrete 15 | import MTL1 16 | import Name 17 | 18 | newtype RenamerM (n::S) (a:: *) = 19 | RenamerM { runRenamerM :: ReaderT1 SourceMap (ScopeReaderM) n a } 20 | deriving ( Functor, Applicative, Monad 21 | , ScopeReader, ScopeExtender) 22 | 23 | newtype SourceMap (n::S) = SourceMap 24 | (M.Map JSourceName (Name (AtomNameC SimpIR) n)) 25 | deriving (Semigroup, Monoid) 26 | 27 | instance SinkableE SourceMap where 28 | sinkingProofE = undefined 29 | 30 | askSourceMap :: RenamerM n (SourceMap n) 31 | askSourceMap = RenamerM ask 
32 | 33 | extendSourceMap :: JSourceName -> (Name (AtomNameC SimpIR)) n 34 | -> RenamerM n a -> RenamerM n a 35 | extendSourceMap sname name (RenamerM cont) = RenamerM do 36 | let ext = SourceMap $ M.singleton sname name 37 | local (<> ext) cont 38 | 39 | liftRenameM :: EnvReader m => RenamerM n (e n) -> m n (e n) 40 | liftRenameM act = liftScopeReaderM $ runReaderT1 mempty $ runRenamerM act 41 | 42 | renameClosedJaxpr :: Distinct o => ClosedJaxpr i -> RenamerM o (ClosedJaxpr o) 43 | renameClosedJaxpr ClosedJaxpr{jaxpr, consts} = do 44 | jaxpr' <- renameJaxpr jaxpr 45 | return ClosedJaxpr{jaxpr=jaxpr', consts} 46 | 47 | renameJaxpr :: Distinct o => Jaxpr i -> RenamerM o (Jaxpr o) 48 | renameJaxpr (Jaxpr invars constvars eqns outvars) = 49 | renameJBinders invars \invars' -> 50 | renameJBinders constvars \constvars' -> 51 | renameJEqns eqns \eqns' -> do 52 | outvars' <- mapM renameJAtom outvars 53 | return $ Jaxpr invars' constvars' eqns' outvars' 54 | 55 | renameJBinder :: Distinct o 56 | => JBinder i i' 57 | -> (forall o'. DExt o o' => JBinder o o' -> RenamerM o' a) 58 | -> RenamerM o a 59 | renameJBinder binder cont = case binder of 60 | JBindSource sname ty -> do 61 | withFreshM (getNameHint sname) \freshName -> do 62 | Distinct <- getDistinct 63 | extendSourceMap sname (binderName freshName) $ 64 | cont $ JBind sname ty freshName 65 | JBind _ _ _ -> error "Shouldn't be source-renaming internal names" 66 | 67 | renameJBinders :: Distinct o 68 | => Nest JBinder i i' 69 | -> (forall o'. 
DExt o o' => Nest JBinder o o' -> RenamerM o' a) 70 | -> RenamerM o a 71 | renameJBinders Empty cont = cont Empty 72 | renameJBinders (Nest b bs) cont = 73 | renameJBinder b \b' -> 74 | renameJBinders bs \bs' -> 75 | cont $ Nest b' bs' 76 | 77 | renameJAtom :: JAtom i -> RenamerM o (JAtom o) 78 | renameJAtom = \case 79 | JVariable jvar -> JVariable <$> renameJVar jvar 80 | JLiteral jlit -> return $ JLiteral jlit 81 | 82 | renameJVar :: JVar i -> RenamerM o (JVar o) 83 | renameJVar JVar{sourceName, ty} = do 84 | sourceName' <- renameJSourceNameOr sourceName 85 | return $ JVar sourceName' ty 86 | 87 | renameJSourceNameOr :: JSourceNameOr (Name (AtomNameC SimpIR)) i 88 | -> RenamerM o (JSourceNameOr (Name (AtomNameC SimpIR)) o) 89 | renameJSourceNameOr = \case 90 | SourceName sname -> do 91 | SourceMap sm <- askSourceMap 92 | case M.lookup sname sm of 93 | (Just name) -> return $ InternalName sname name 94 | Nothing -> error $ "Unbound variable " ++ show sname 95 | InternalName _ _ -> error "Shouldn't be source-renaming internal names" 96 | 97 | renameJEqn :: Distinct o 98 | => JEqn i i' 99 | -> (forall o'. DExt o o' => JEqn o o' -> RenamerM o' a) 100 | -> RenamerM o a 101 | renameJEqn JEqn{outvars, primitive, invars} cont = do 102 | invars' <- mapM renameJAtom invars 103 | renameJBinders outvars \outvars' -> cont $ JEqn outvars' primitive invars' 104 | 105 | renameJEqns :: Distinct o 106 | => Nest JEqn i i' 107 | -> (forall o'. 
DExt o o' => Nest JEqn o o' -> RenamerM o' a) 108 | -> RenamerM o a 109 | renameJEqns Empty cont = cont Empty 110 | renameJEqns (Nest b bs) cont = 111 | renameJEqn b \b' -> 112 | renameJEqns bs \bs' -> 113 | cont $ Nest b' bs' 114 | 115 | -------------------------------------------------------------------------------- /src/lib/LLVM/Link.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2022 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | module LLVM.Link 8 | ( createLinker, destroyLinker 9 | , addExplicitLinkMap, addObjectFile, getFunctionPointer 10 | , ExplicitLinkMap 11 | ) where 12 | 13 | import Data.String (fromString) 14 | import Foreign.Ptr 15 | import qualified Data.ByteString as BS 16 | 17 | import System.IO 18 | import System.IO.Temp 19 | 20 | import qualified LLVM.OrcJIT as OrcJIT 21 | import qualified LLVM.Internal.OrcJIT as OrcJIT 22 | import qualified LLVM.Internal.Target as Target 23 | import qualified LLVM.Internal.FFI.Target as FFI 24 | 25 | import qualified LLVM.Shims 26 | 27 | data Linker = Linker 28 | { linkerExecutionSession :: OrcJIT.ExecutionSession 29 | #ifdef darwin_HOST_OS 30 | , linkerLinkLayer :: OrcJIT.ObjectLinkingLayer 31 | #else 32 | , linkerLinkLayer :: OrcJIT.RTDyldObjectLinkingLayer 33 | #endif 34 | , _linkerTargetMachine :: Target.TargetMachine 35 | -- We ought to just need the link layer and the mangler but llvm-hs 36 | -- requires a full `IRCompileLayer` for incidental reasons. TODO: fix. 37 | , linkerIRLayer :: OrcJIT.IRCompileLayer 38 | , linkerDylib :: OrcJIT.JITDylib } 39 | 40 | instance OrcJIT.IRLayer Linker where 41 | -- llvm-hs requires a compile/IR layer, but we don't actually need it for the 42 | -- linking functions we call. 
TODO: update llvm-hs to expose more precise 43 | -- requirements for its linking functions. 44 | getIRLayer l = OrcJIT.getIRLayer $ linkerIRLayer l 45 | getDataLayout l = OrcJIT.getDataLayout $ linkerIRLayer l 46 | getMangler l = OrcJIT.getMangler $ linkerIRLayer l 47 | 48 | type CName = String 49 | 50 | type ExplicitLinkMap = [(CName, Ptr ())] 51 | 52 | createLinker :: IO Linker 53 | createLinker = do 54 | -- TODO: should this be a parameter to `createLinker` instead? 55 | tm <- LLVM.Shims.newDefaultHostTargetMachine 56 | s <- OrcJIT.createExecutionSession 57 | #ifdef darwin_HOST_OS 58 | linkLayer <- OrcJIT.createObjectLinkingLayer s 59 | #else 60 | linkLayer <- OrcJIT.createRTDyldObjectLinkingLayer s 61 | #endif 62 | dylib <- OrcJIT.createJITDylib s "main_dylib" 63 | compileLayer <- OrcJIT.createIRCompileLayer s linkLayer tm 64 | OrcJIT.addDynamicLibrarySearchGeneratorForCurrentProcess compileLayer dylib 65 | return $ Linker s linkLayer tm compileLayer dylib 66 | 67 | destroyLinker :: Linker -> IO () 68 | destroyLinker (Linker session _ (Target.TargetMachine tm) _ _) = do 69 | -- dylib, link layer and IRLayer should get cleaned up automatically 70 | OrcJIT.disposeExecutionSession session 71 | FFI.disposeTargetMachine tm 72 | 73 | addExplicitLinkMap :: Linker -> ExplicitLinkMap -> IO () 74 | addExplicitLinkMap l linkMap = do 75 | let (linkedNames, linkedPtrs) = unzip linkMap 76 | let flags = OrcJIT.defaultJITSymbolFlags { OrcJIT.jitSymbolAbsolute = True } 77 | let ptrSymbols = [OrcJIT.JITSymbol (ptrToWordPtr ptr) flags | ptr <- linkedPtrs] 78 | mangledNames <- mapM (OrcJIT.mangleSymbol l . 
fromString) linkedNames 79 | OrcJIT.defineAbsoluteSymbols (linkerDylib l) $ zip mangledNames ptrSymbols 80 | mapM_ OrcJIT.disposeMangledSymbol mangledNames 81 | 82 | addObjectFile :: Linker -> BS.ByteString -> IO () 83 | addObjectFile l objFileContents = do 84 | withSystemTempFile "objfile.o" \path h -> do 85 | BS.hPut h objFileContents 86 | hFlush h 87 | OrcJIT.addObjectFile (linkerLinkLayer l) (linkerDylib l) path 88 | 89 | getFunctionPointer :: Linker -> CName -> IO (FunPtr a) 90 | getFunctionPointer l name = do 91 | OrcJIT.lookupSymbol (linkerExecutionSession l) (linkerIRLayer l) 92 | (linkerDylib l) (fromString name) >>= \case 93 | Right (OrcJIT.JITSymbol funcAddr _) -> 94 | return $ castPtrToFunPtr $ wordPtrToPtr funcAddr 95 | Left s -> error $ "Couldn't find function: " ++ name ++ "\n" ++ show s 96 | -------------------------------------------------------------------------------- /src/lib/LLVM/Shims.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2021 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | module LLVM.Shims ( 8 | newTargetMachine, newHostTargetMachine, disposeTargetMachine, 9 | newDefaultHostTargetMachine 10 | ) where 11 | 12 | import qualified Data.Map as M 13 | import qualified Data.ByteString.Char8 as BS 14 | import qualified Data.ByteString.Short as SBS 15 | 16 | import qualified LLVM.Relocation as R 17 | import qualified LLVM.CodeModel as CM 18 | import qualified LLVM.CodeGenOpt as CGO 19 | import qualified LLVM.Internal.Target as Target 20 | import qualified LLVM.Internal.FFI.Target as Target.FFI 21 | import LLVM.Prelude (ShortByteString, ByteString) 22 | import LLVM.Internal.Coding (encodeM) 23 | 24 | -- llvm-hs doesn't expose any way to manage target machines in a non-bracketed way 25 | 26 | newTargetMachine :: Target.Target 27 | -> 
ShortByteString 28 | -> ByteString 29 | -> M.Map Target.CPUFeature Bool 30 | -> Target.TargetOptions 31 | -> R.Model 32 | -> CM.Model 33 | -> CGO.Level 34 | -> IO Target.TargetMachine 35 | newTargetMachine (Target.Target targetFFI) triple cpu features 36 | (Target.TargetOptions targetOptFFI) 37 | relocModel codeModel cgoLevel = do 38 | SBS.useAsCString triple \tripleFFI -> do 39 | BS.useAsCString cpu \cpuFFI -> do 40 | let featuresStr = BS.intercalate "," $ fmap encodeFeature $ M.toList features 41 | BS.useAsCString featuresStr \featuresFFI -> do 42 | relocModelFFI <- encodeM relocModel 43 | codeModelFFI <- encodeM codeModel 44 | cgoLevelFFI <- encodeM cgoLevel 45 | Target.TargetMachine <$> Target.FFI.createTargetMachine 46 | targetFFI tripleFFI cpuFFI featuresFFI 47 | targetOptFFI relocModelFFI codeModelFFI cgoLevelFFI 48 | where encodeFeature (Target.CPUFeature f, on) = (if on then "+" else "-") <> f 49 | 50 | -- XXX: We need to use the large code model for macOS, because the libC functions 51 | -- are loaded very far away from the JITed code. This does not prevent the 52 | -- runtime linker from attempting to shove their offsets into 32-bit values 53 | -- which cannot represent them, leading to segfaults that are very fun to debug. 54 | -- It would be good to find a better solution, because larger code models might 55 | -- hurt performance if we were to end up doing a lot of function calls. 
56 | -- TODO: Consider changing the linking layer, as suggested in: 57 | -- http://llvm.1065342.n5.nabble.com/llvm-dev-ORC-JIT-Weekly-5-td135203.html 58 | newDefaultHostTargetMachine :: IO Target.TargetMachine 59 | newDefaultHostTargetMachine = LLVM.Shims.newHostTargetMachine R.PIC cm CGO.Aggressive 60 | where -- code model: Small when darwin_HOST_OS is set, Large otherwise 61 | #if darwin_HOST_OS 62 | cm = CM.Small 63 | #else 64 | cm = CM.Large 65 | #endif 66 | -- Builds a target machine for the host: looks up the target for the process triple and combines the host CPU name and features with the given relocation model, code model and CGO level. 67 | newHostTargetMachine :: R.Model -> CM.Model -> CGO.Level -> IO Target.TargetMachine 68 | newHostTargetMachine relocModel codeModel cgoLevel = do 69 | Target.initializeAllTargets 70 | triple <- Target.getProcessTargetTriple 71 | (target, _) <- Target.lookupTarget Nothing triple 72 | cpu <- Target.getHostCPUName 73 | features <- Target.getHostCPUFeatures 74 | Target.withTargetOptions \targetOptions -> 75 | newTargetMachine target triple cpu features targetOptions relocModel codeModel cgoLevel 76 | -- Unwraps the TargetMachine and disposes the underlying FFI handle. 77 | disposeTargetMachine :: Target.TargetMachine -> IO () 78 | disposeTargetMachine (Target.TargetMachine tmFFI) = Target.FFI.disposeTargetMachine tmFFI 79 | -------------------------------------------------------------------------------- /src/lib/Live/Web.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2019 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | module Live.Web (runWeb, generateHTML) where 8 | 9 | import Control.Concurrent (readChan) 10 | import Control.Monad (forever) 11 | 12 | import Network.Wai (Application, StreamingBody, pathInfo, 13 | responseStream, responseLBS, responseFile) 14 | import Network.Wai.Handler.Warp (run) 15 | import Network.HTTP.Types (status200, status404) 16 | import Data.Aeson (ToJSON, encode) 17 | import Data.Binary.Builder (fromByteString) 18 | import Data.ByteString.Lazy (toStrict) 19 | import qualified Data.ByteString as BS 20 |
import System.Directory (withCurrentDirectory) 21 | 22 | -- import Paths_dex (getDataFileName) 23 | import RenderHtml 24 | import Live.Eval 25 | import TopLevel 26 | -- Starts watching and evaluating the file via watchAndEvalFile, then serves the results over HTTP on port 8000. 27 | runWeb :: FilePath -> EvalConfig -> TopStateEx -> IO () 28 | runWeb fname opts env = do 29 | resultsChan <- watchAndEvalFile fname opts env 30 | putStrLn "Streaming output to http://localhost:8000/" 31 | run 8000 $ serveResults resultsChan 32 | -- Directory that generateHTML makes current before writing its output. 33 | pagesDir :: FilePath 34 | pagesDir = "pages" 35 | -- Evaluates sourcePath non-interactively, renders the results, and calls renderStandaloneHTML on destPath with pagesDir as the current directory. 36 | generateHTML :: FilePath -> FilePath -> EvalConfig -> TopStateEx -> IO () 37 | generateHTML sourcePath destPath cfg env = do 38 | finalState <- evalFileNonInteractive sourcePath cfg env 39 | results <- renderResults finalState 40 | withCurrentDirectory pagesDir do 41 | renderStandaloneHTML destPath results 42 | -- WAI application: prints each request path, serves the static page, stylesheet and script, streams results as text/event-stream on /getnext, and answers 404 for anything else. 43 | serveResults :: EvalServer -> Application 44 | serveResults resultsSubscribe request respond = do 45 | print (pathInfo request) 46 | case pathInfo request of 47 | [] -> respondWith "static/dynamic.html" "text/html" 48 | ["style.css"] -> respondWith "static/style.css" "text/css" 49 | ["index.js"] -> respondWith "static/index.js" "text/javascript" 50 | ["getnext"] -> respond $ responseStream status200 51 | [ ("Content-Type", "text/event-stream") 52 | , ("Cache-Control", "no-cache")] 53 | $ resultStream resultsSubscribe 54 | _ -> respond $ responseLBS status404 55 | [("Content-Type", "text/plain")] "404 - Not Found" 56 | where 57 | respondWith dataFname ctype = do 58 | fname <- return dataFname -- lets us skip rebuilding during development 59 | -- fname <- getDataFileName dataFname 60 | respond $ responseFile status200 [("Content-Type", ctype)] fname Nothing 61 | -- Streaming body: sends a "start" packet, subscribes to the server, sends the rendered initial result, then forwards every rendered update from the channel forever. 62 | resultStream :: EvalServer -> StreamingBody 63 | resultStream resultsServer write flush = do 64 | sendUpdate ("start"::String) 65 | (initResult, resultsChan) <- subscribeIO resultsServer 66 | (renderedInit, renderUpdateFun) <- renderResultsInc initResult 67 | sendUpdate renderedInit 68 | forever $ readChan resultsChan >>=
renderUpdateFun >>= sendUpdate 69 | where 70 | sendUpdate :: ToJSON a => a -> IO () 71 | sendUpdate x = write (fromByteString $ encodePacket x) >> flush 72 | -- JSON-encodes a value and frames it as a single "data:" packet terminated by a blank line. 73 | encodePacket :: ToJSON a => a -> BS.ByteString 74 | encodePacket = toStrict . wrap . encode 75 | where wrap s = "data:" <> s <> "\n\n" 76 | -------------------------------------------------------------------------------- /src/lib/Serialize.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2019 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | module Serialize (HasPtrs (..), takePtrSnapshot, restorePtrSnapshot) where 8 | 9 | import Prelude hiding (pi, abs) 10 | import Control.Monad 11 | import qualified Data.ByteString as BS 12 | import Data.ByteString.Internal (memcpy) 13 | import Data.ByteString.Unsafe (unsafeUseAsCString) 14 | import Data.Int 15 | import Data.Store hiding (size) 16 | import Foreign.Ptr 17 | import Foreign.Marshal.Array 18 | import GHC.Generics (Generic) 19 | 20 | import Types.Primitives 21 | -- C entry points of the runtime allocator: malloc_dex allocates, dex_allocation_size reports an allocation's size in bytes. 22 | foreign import ccall "malloc_dex" dexMalloc :: Int64 -> IO (Ptr ()) 23 | foreign import ccall "dex_allocation_size" dexAllocSize :: Ptr () -> IO Int64 24 | 25 | data WithSnapshot a = WithSnapshot a [PtrSnapshot] deriving Generic 26 | type RawPtr = Ptr () 27 | -- Types whose raw pointers can be visited (and replaced) by an applicative traversal. 28 | class HasPtrs a where 29 | traversePtrs :: Applicative f => (PtrType -> RawPtr -> f RawPtr) -> a -> f a 30 | -- Replaces a live CPU pointer by a snapshot of its allocation: pointer-typed elements are snapshotted recursively, anything else is copied as raw bytes. NullPtr is returned unchanged; GPU pointers and existing snapshots are errors. 31 | takePtrSnapshot :: PtrType -> PtrLitVal -> IO PtrLitVal 32 | takePtrSnapshot _ NullPtr = return NullPtr 33 | takePtrSnapshot (CPU, ptrTy) (PtrLitVal ptrVal) = case ptrTy of 34 | PtrType eltTy -> do 35 | childPtrs <- loadPtrPtrs ptrVal 36 | PtrSnapshot <$> PtrArray <$> mapM (takePtrSnapshot eltTy) childPtrs 37 | _ -> PtrSnapshot .
ByteArray <$> loadPtrBytes ptrVal 38 | takePtrSnapshot (GPU, _) _ = error "Snapshots of GPU memory not implemented" 39 | takePtrSnapshot _ (PtrSnapshot _) = error "Already a snapshot" 40 | {-# SCC takePtrSnapshot #-} 41 | -- Copies the whole allocation at ptr (size taken from dex_allocation_size) into a fresh ByteString. 42 | loadPtrBytes :: RawPtr -> IO BS.ByteString 43 | loadPtrBytes ptr = do 44 | numBytes <- fromIntegral <$> dexAllocSize ptr 45 | BS.packCStringLen (castPtr ptr, numBytes) -- one bulk copy instead of building a per-byte list 46 | -- Reads the allocation at ptr as an array of child pointers, mapping null entries to NullPtr. 47 | loadPtrPtrs :: RawPtr -> IO [PtrLitVal] 48 | loadPtrPtrs ptr = do 49 | numBytes <- fromIntegral <$> dexAllocSize ptr 50 | childPtrs <- peekArray (numBytes `div` ptrSize) $ castPtr ptr 51 | forM childPtrs \childPtr -> 52 | if childPtr == nullPtr 53 | then return NullPtr 54 | else return $ PtrLitVal childPtr 55 | -- Rebuilds freshly allocated memory from a snapshot (recursively for pointer arrays) and returns the new pointer literal; a PtrLitVal argument is an error. 56 | restorePtrSnapshot :: PtrLitVal -> IO PtrLitVal 57 | restorePtrSnapshot NullPtr = return NullPtr 58 | restorePtrSnapshot (PtrSnapshot snapshot) = case snapshot of 59 | PtrArray children -> do 60 | childrenPtrs <- forM children \child -> 61 | restorePtrSnapshot child >>= \case 62 | NullPtr -> return nullPtr 63 | PtrLitVal p -> return p 64 | PtrSnapshot _ -> error "expected a pointer literal" 65 | PtrLitVal <$> storePtrPtrs childrenPtrs 66 | ByteArray bytes -> PtrLitVal <$> storePtrBytes bytes 67 | restorePtrSnapshot (PtrLitVal _) = error "not a snapshot" 68 | {-# SCC restorePtrSnapshot #-} 69 | -- Copies a ByteString into a new malloc_dex allocation of the same length. 70 | storePtrBytes :: BS.ByteString -> IO RawPtr 71 | storePtrBytes xs = do 72 | let numBytes = BS.length xs 73 | destPtr <- dexMalloc $ fromIntegral numBytes 74 | -- this is safe because we don't modify srcPtr's memory or let it escape 75 | unsafeUseAsCString xs \srcPtr -> 76 | memcpy (castPtr destPtr) (castPtr srcPtr) numBytes 77 | return destPtr 78 | -- Writes a list of pointers into a new malloc_dex allocation of length * ptrSize bytes. 79 | storePtrPtrs :: [RawPtr] -> IO RawPtr 80 | storePtrPtrs ptrs = do 81 | ptr <- dexMalloc $ fromIntegral $ length ptrs * ptrSize 82 | pokeArray (castPtr ptr) ptrs 83 | return ptr 84 | 85 | -- === instances === 86 | 87 | instance Store a => Store (WithSnapshot a) 88 |
-------------------------------------------------------------------------------- /src/lib/Simplify.hs-boot: -------------------------------------------------------------------------------- 1 | -- Copyright 2023 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | module Simplify (linearizeTopFun) where 8 | 9 | import Name 10 | import Builder 11 | import Types.Core 12 | import Types.Top 13 | 14 | linearizeTopFun :: (Mut n, Fallible1 m, TopBuilder m) => LinearizationSpec n -> m n (TopFunName n, TopFunName n) 15 | -------------------------------------------------------------------------------- /src/lib/Types/OpNames.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2023 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | -- This module contains payload-free versions of the ops defined in Types.Core. 8 | -- It uses the same constructor names so it should be imported qualified. 
9 | 10 | module Types.OpNames where 11 | 12 | import IRVariants 13 | import Data.Hashable 14 | import GHC.Generics (Generic (..)) 15 | import Data.Store (Store (..)) 16 | 17 | import PPrint 18 | 19 | data TC = ProdType | SumType | RefType | TypeKind | HeapType 20 | data Con = ProdCon | SumCon Int | HeapVal 21 | 22 | data BinOp = 23 | IAdd | ISub | IMul | IDiv | ICmp CmpOp | FAdd | FSub | FMul 24 | | FDiv | FCmp CmpOp | FPow | BAnd | BOr | BShL | BShR | IRem | BXor 25 | 26 | data UnOp = 27 | Exp | Exp2 | Log | Log2 | Log10 | Log1p | Sin | Cos | Tan | Sqrt | Floor 28 | | Ceil | Round | LGamma | Erf | Erfc | FNeg | BNot 29 | 30 | data CmpOp = Less | Greater | Equal | LessEqual | GreaterEqual 31 | 32 | data MemOp = IOAlloc | IOFree | PtrOffset | PtrLoad | PtrStore 33 | 34 | data MiscOp = 35 | Select | CastOp | BitcastOp | UnsafeCoerce | GarbageVal | Effects 36 | | ThrowError | ThrowException | Tag | SumTag | Create | ToEnum 37 | | OutputStream | ShowAny | ShowScalar 38 | 39 | data VectorOp = VectorBroadcast | VectorIota | VectorIdx | VectorSubref 40 | 41 | data Hof (r::IR) = 42 | While | RunReader | RunWriter | RunState | RunIO | RunInit 43 | | CatchException | Linearize | Transpose 44 | 45 | data DAMOp = Seq | RememberDest | AllocDest | Place | Freeze 46 | 47 | data RefOp = MAsk | MExtend | MGet | MPut | IndexRef | ProjRef Projection 48 | 49 | data Projection = 50 | UnwrapNewtype -- TODO: add `HasCore r` constraint 51 | | ProjectProduct Int 52 | deriving (Show, Eq, Generic) 53 | 54 | data UserEffectOp = Handle | Resume | Perform 55 | -- Standalone Generic instances (Projection derives Generic at its declaration). 56 | deriving instance Generic BinOp 57 | deriving instance Generic UnOp 58 | deriving instance Generic CmpOp 59 | deriving instance Generic TC 60 | deriving instance Generic Con 61 | deriving instance Generic MemOp 62 | deriving instance Generic MiscOp 63 | deriving instance Generic VectorOp 64 | deriving instance Generic (Hof r) 65 | deriving instance Generic DAMOp 66 | deriving instance Generic RefOp 67 | deriving instance Generic
UserEffectOp 68 | -- Hashable instances with empty bodies: they rely on the class's default methods. 69 | instance Hashable BinOp 70 | instance Hashable UnOp 71 | instance Hashable CmpOp 72 | instance Hashable TC 73 | instance Hashable Con 74 | instance Hashable MemOp 75 | instance Hashable MiscOp 76 | instance Hashable VectorOp 77 | instance Hashable (Hof r) 78 | instance Hashable DAMOp 79 | instance Hashable RefOp 80 | instance Hashable UserEffectOp 81 | instance Hashable Projection 82 | -- Store instances with empty bodies (class default methods); the Hof instance additionally requires IRRep r. 83 | instance Store BinOp 84 | instance Store UnOp 85 | instance Store CmpOp 86 | instance Store TC 87 | instance Store Con 88 | instance Store MemOp 89 | instance Store MiscOp 90 | instance Store VectorOp 91 | instance IRRep r => Store (Hof r) 92 | instance Store DAMOp 93 | instance Store RefOp 94 | instance Store UserEffectOp 95 | instance Store Projection 96 | -- Show instances (Projection derives Show at its declaration). 97 | deriving instance Show BinOp 98 | deriving instance Show UnOp 99 | deriving instance Show CmpOp 100 | deriving instance Show TC 101 | deriving instance Show Con 102 | deriving instance Show MemOp 103 | deriving instance Show MiscOp 104 | deriving instance Show VectorOp 105 | deriving instance Show (Hof r) 106 | deriving instance Show DAMOp 107 | deriving instance Show RefOp 108 | deriving instance Show UserEffectOp 109 | -- Eq instances (Projection derives Eq at its declaration). 110 | deriving instance Eq BinOp 111 | deriving instance Eq UnOp 112 | deriving instance Eq CmpOp 113 | deriving instance Eq TC 114 | deriving instance Eq Con 115 | deriving instance Eq MemOp 116 | deriving instance Eq MiscOp 117 | deriving instance Eq VectorOp 118 | deriving instance Eq (Hof r) 119 | deriving instance Eq DAMOp 120 | deriving instance Eq RefOp 121 | deriving instance Eq UserEffectOp 122 | -- Pretty-prints UnwrapNewtype as "u" and ProjectProduct i as the index i. 123 | instance Pretty Projection where 124 | pretty = \case 125 | UnwrapNewtype -> "u" 126 | ProjectProduct i -> pretty i 127 | -------------------------------------------------------------------------------- /src/old/Imp/Optimize.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2020 Google LLC 2 | -- 3 | -- Use of this source
code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | module Imp.Optimize (liftCUDAAllocations) where 8 | 9 | import Control.Monad 10 | 11 | import PPrint 12 | import Env 13 | import Cat 14 | import Syntax 15 | import Imp.Builder 16 | 17 | -- TODO: DCE! 18 | -- AllocInfo is the type and constant size taken from a lifted Alloc; FuncAllocEnv collects them per kernel together with the new binder, ModAllocEnv records them per function name. 19 | type AllocInfo = (BaseType, Int) 20 | type FuncAllocEnv = [(IBinder, AllocInfo)] 21 | type ModAllocEnv = Env [AllocInfo] 22 | -- Rewrites CUDAKernelLaunch functions so that their constant-size Heap GPU allocations become extra arguments, and rewrites launch sites in all other functions to allocate and pass those buffers. FFI functions are returned unchanged. 23 | liftCUDAAllocations :: ImpModule -> ImpModule 24 | liftCUDAAllocations m = 25 | fst $ runCat (traverseImpModule liftFunc m) mempty 26 | where 27 | liftFunc :: Env IFunVar -> ImpFunction -> Cat ModAllocEnv ImpFunction 28 | liftFunc fenv f = case f of 29 | FFIFunction _ -> return f 30 | ImpFunction (fname:>IFunType cc argTys retTys) argBs' body' -> case cc of 31 | CUDAKernelLaunch -> do 32 | let ((argBs, body), fAllocEnv) = 33 | flip runCat mempty $ runISubstBuilderT (ISubstEnv mempty fenv) $ do 34 | ~args@(tid:wid:wsz:_) <- traverse freshIVar argBs' 35 | newBody <- extendValSubst (newEnv argBs' $ fmap IVar args) $ buildScoped $ do 36 | gtid <- iadd (IVar tid) =<< imul (IVar wid) (IVar wsz) 37 | evalImpBlock (liftAlloc gtid) body' 38 | return (fmap Bind args, newBody) 39 | let (allocBs, allocs) = unzip fAllocEnv 40 | extend $ fname @> allocs 41 | let newFunTy = IFunType cc (argTys ++ fmap binderAnn allocBs) retTys 42 | return $ ImpFunction (fname :> newFunTy) (argBs ++ allocBs) body 43 | _ -> traverseImpFunction amendLaunch fenv f 44 | -- Inside a kernel: a single-binder Heap GPU Alloc of literal size becomes a fresh argument offset by gtid * size; Stack allocations are traversed normally, any other Alloc is an error, and Free declarations are dropped. 45 | liftAlloc :: IExpr -> ITraversalDef (Cat FuncAllocEnv) 46 | liftAlloc gtid = (liftAllocDecl, traverseImpInstr rec) 47 | where 48 | rec = liftAlloc gtid 49 | liftAllocDecl decl = case decl of 50 | ImpLet [b] (Alloc addrSpace ty (IIdxRepVal size)) -> 51 | case addrSpace of 52 | Stack -> traverseImpDecl rec decl 53 | Heap CPU -> error "Unexpected CPU allocation in a CUDA kernel" 54 | Heap GPU -> do 55 | bArg <- freshIVar b 56 | liftSE $
extend $ [(Bind bArg, (ty, fromIntegral size))] 57 | ptr <- ptrOffset (IVar bArg) =<< imul gtid (IIdxRepVal size) 58 | return $ b @> ptr 59 | ImpLet _ (Alloc _ _ _) -> 60 | error $ "Failed to lift an allocation out of a CUDA kernel: " ++ pprint decl 61 | ImpLet _ (Free _) -> return mempty 62 | _ -> traverseImpDecl rec decl 63 | -- At each ILaunch: for every allocation recorded for the callee, allocate size * (numWorkgroups * workgroupSize) on the GPU heap and append the result to the launch arguments; launches with no recorded allocations are left with their original arguments. 64 | amendLaunch :: ITraversalDef (Cat ModAllocEnv) 65 | amendLaunch = (traverseImpDecl amendLaunch, amendLaunchInstr) 66 | where 67 | amendLaunchInstr :: ImpInstr -> ISubstBuilderT (Cat ModAllocEnv) ImpInstr 68 | amendLaunchInstr instr = case instr of 69 | ILaunch f' s' args' -> do 70 | s <- traverseIExpr s' 71 | args <- traverse traverseIExpr args' 72 | liftedAllocs <- liftSE $ looks (!f') 73 | f <- traverseIFunVar f' 74 | extraArgs <- case null liftedAllocs of 75 | True -> return [] 76 | False -> do 77 | ~[numWorkgroups, workgroupSize] <- emit $ IQueryParallelism f s 78 | nthreads <- imul numWorkgroups workgroupSize 79 | forM liftedAllocs $ \(ty, size) -> do 80 | totalSize <- imul (IIdxRepVal $ fromIntegral size) nthreads 81 | alloc (Heap GPU) ty totalSize 82 | return $ ILaunch f s (args ++ extraArgs) 83 | _ -> traverseImpInstr amendLaunch instr 84 | -------------------------------------------------------------------------------- /src/old/MLIR/Eval.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2021 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | module MLIR.Eval where 8 | 9 | import Data.Function 10 | import qualified Data.ByteString.Char8 as BSC8 11 | import qualified Data.ByteString as BS 12 | import GHC.Stack 13 | 14 | import qualified MLIR.AST as AST 15 | import qualified MLIR.AST.Serialize as AST 16 | import qualified MLIR.Native as Native 17 | import qualified MLIR.Native.Pass as Native 18 | import qualified
MLIR.Native.ExecutionEngine as Native 19 | 20 | 21 | import Syntax 22 | -- TODO(apaszke): Separate the LitVal operations from LLVMExec 23 | import LLVMExec 24 | -- Builds an MLIR module from the AST, verifies it, runs the bufferization and LLVM-conversion pass pipeline, then invokes the function named "entry" with pointers to the argument and result cells and loads results of the given types. The partial patterns (Just m, Just eng, Just ()) throw if the module, engine or invocation is unavailable. 25 | evalModule :: AST.Operation -> [LitVal] -> [BaseType] -> IO [LitVal] 26 | evalModule ast args resultTypes = 27 | Native.withContext \ctx -> do 28 | Native.registerAllDialects ctx 29 | mOp <- AST.fromAST ctx (mempty, mempty) ast 30 | Just m <- Native.moduleFromOperation mOp 31 | verifyModule m 32 | Native.withPassManager ctx \pm -> do 33 | throwOnFailure "Failed to parse pass pipeline" $ 34 | (Native.addParsedPassPipeline pm $ BS.intercalate "," 35 | [ "func(tensor-bufferize,std-bufferize,finalizing-bufferize)" 36 | , "convert-memref-to-llvm" 37 | , "convert-std-to-llvm" 38 | ]) 39 | Native.runPasses pm m & throwOnFailure "Failed to lower module" 40 | verifyModule m 41 | Native.withExecutionEngine m \(Just eng) -> do 42 | Native.withStringRef "entry" \name -> do 43 | allocaCells (length args) \argsPtr -> 44 | allocaCells (length resultTypes) \resultPtr -> do 45 | storeLitVals argsPtr args 46 | Just () <- Native.executionEngineInvoke @() eng name 47 | [Native.SomeStorable argsPtr, Native.SomeStorable resultPtr] 48 | loadLitVals resultPtr resultTypes 49 | -- Calls error with the printed module text when verification of the module fails. 50 | verifyModule :: HasCallStack => Native.Module -> IO () 51 | verifyModule m = do 52 | correct <- Native.verifyOperation =<< Native.moduleAsOperation m 53 | case correct of 54 | True -> return () 55 | False -> do 56 | modStr <- BSC8.unpack <$> Native.showModule m 57 | error $ "Invalid module:\n" ++ modStr 58 | -- Runs the action and calls error with msg if it reports Native.Failure. 59 | throwOnFailure :: String -> IO Native.LogicalResult -> IO () 60 | throwOnFailure msg m = do 61 | result <- m 62 | case result of 63 | Native.Success -> return () 64 | Native.Failure -> error msg 65 | -------------------------------------------------------------------------------- /stack-llvm-head.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Use of this
source code is governed by a BSD-style 4 | # license that can be found in the LICENSE file or at 5 | # https://developers.google.com/open-source/licenses/bsd 6 | 7 | resolver: lts-16.31 8 | 9 | packages: 10 | - . 11 | 12 | extra-deps: 13 | - github: llvm-hs/llvm-hs 14 | commit: aba6986a644916239ad414f0966b40f2faffa5f3 15 | subdirs: 16 | - llvm-hs 17 | - llvm-hs-pure 18 | - github: google/mlir-hs 19 | commit: 7a4f4984c71e8fb0d7730bc541e9f2daf1971073 20 | - megaparsec-8.0.0 21 | - prettyprinter-1.6.2 22 | - store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001 23 | - store-core-0.4.4.4@sha256:a19098ca8419ea4f6f387790e942a7a5d0acf62fe1beff7662f098cfb611334c,1430 24 | - th-utilities-0.2.4.1@sha256:b37d23c8bdabd678aee5a36dd4373049d4179e9a85f34eb437e9cd3f04f435ca,1869 25 | - floating-bits-0.3.0.0@sha256:742bcfcbc21b8daffc995990ee2399ab49550e8f4dd0dff1732d18f57a064c83,2442 26 | 27 | -------------------------------------------------------------------------------- /stack-macos.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Use of this source code is governed by a BSD-style 4 | # license that can be found in the LICENSE file or at 5 | # https://developers.google.com/open-source/licenses/bsd 6 | 7 | resolver: lts-18.23 8 | 9 | packages: 10 | - . 
11 | 12 | extra-deps: 13 | - github: llvm-hs/llvm-hs 14 | commit: 423220bffac4990d019fc088c46c5f25310d5a33 15 | subdirs: 16 | - llvm-hs 17 | - llvm-hs-pure 18 | - megaparsec-8.0.0 19 | - prettyprinter-1.6.2 20 | - store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001 21 | - store-core-0.4.4.4@sha256:a19098ca8419ea4f6f387790e942a7a5d0acf62fe1beff7662f098cfb611334c,1430 22 | - th-utilities-0.2.4.1@sha256:b37d23c8bdabd678aee5a36dd4373049d4179e9a85f34eb437e9cd3f04f435ca,1869 23 | - floating-bits-0.3.0.0@sha256:742bcfcbc21b8daffc995990ee2399ab49550e8f4dd0dff1732d18f57a064c83,2442 24 | 25 | flags: 26 | llvm-hs: 27 | shared-llvm: false 28 | -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Use of this source code is governed by a BSD-style 4 | # license that can be found in the LICENSE file or at 5 | # https://developers.google.com/open-source/licenses/bsd 6 | 7 | resolver: lts-18.23 8 | 9 | packages: 10 | - . 
11 | 12 | extra-deps: 13 | - github: llvm-hs/llvm-hs 14 | commit: 423220bffac4990d019fc088c46c5f25310d5a33 15 | subdirs: 16 | - llvm-hs 17 | - llvm-hs-pure 18 | - megaparsec-8.0.0 19 | - prettyprinter-1.6.2 20 | - store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001 21 | - store-core-0.4.4.4@sha256:a19098ca8419ea4f6f387790e942a7a5d0acf62fe1beff7662f098cfb611334c,1430 22 | - th-utilities-0.2.4.1@sha256:b37d23c8bdabd678aee5a36dd4373049d4179e9a85f34eb437e9cd3f04f435ca,1869 23 | - floating-bits-0.3.0.0@sha256:742bcfcbc21b8daffc995990ee2399ab49550e8f4dd0dff1732d18f57a064c83,2442 24 | 25 | nix: 26 | enable: false 27 | packages: [ libpng llvm_12 pkg-config zlib ] 28 | 29 | ghc-options: 30 | containers: -fno-prof-auto -O2 31 | hashable: -fno-prof-auto -O2 32 | llvm-hs-pure: -fno-prof-auto -O2 33 | llvm-hs: -fno-prof-auto -O2 34 | megaparsec: -fno-prof-auto -O2 35 | parser-combinators: -fno-prof-auto -O2 36 | prettyprinter: -fno-prof-auto -O2 37 | store-core: -fno-prof-auto -O2 38 | store: -fno-prof-auto -O2 39 | unordered-containers: -fno-prof-auto -O2 40 | -------------------------------------------------------------------------------- /static/dynamic.html: -------------------------------------------------------------------------------- 1 | <!DOCTYPE html> 2 | <html> 3 | <head> 4 | <meta charset="UTF-8"> 5 | <title>Dex Output 6 | 7 | 8 | 9 |
10 |
(hover over code for more information)
11 |
12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /static/style.css: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Google LLC */ 2 | /* */ 3 | /* Use of this source code is governed by a BSD-style */ 4 | /* license that can be found in the LICENSE file or at */ 5 | /* https://developers.google.com/open-source/licenses/bsd */ 6 | 7 | body { 8 | font-family: Helvetica, sans-serif; 9 | font-size: 100%; 10 | color: #333; 11 | overflow-x: hidden; 12 | padding-bottom:50vw; 13 | } 14 | 15 | #main-output { 16 | margin-left: 20px; 17 | } 18 | #minimap { 19 | display: flex; 20 | flex-direction: column; 21 | position: fixed; 22 | top: 0em; 23 | left: 0em; 24 | height: 85vh; 25 | width: 32px; 26 | overflow: hidden; 27 | } 28 | .status { 29 | flex: 1; 30 | width : 30px; 31 | border-top: 1px solid; 32 | border-color: lightgray; 33 | margin-left: 1px; 34 | } 35 | #hover-info { 36 | position: fixed; 37 | height: 15vh; 38 | bottom: 0em; 39 | width: 100vw; 40 | overflow: hidden; 41 | background-color: white; 42 | border-top: 1px solid firebrick; 43 | font-family: monospace; 44 | white-space: pre; 45 | } 46 | /* cell structure */ 47 | .cell { 48 | margin-left: 5px; 49 | display: flex; 50 | } 51 | .line-nums { 52 | flex: 0 0 3em; 53 | height: 100%; 54 | text-align: right; 55 | color: #808080; 56 | font-family: monospace; 57 | white-space: pre; 58 | } 59 | .contents { 60 | margin-left: 1em; 61 | font-family: monospace; 62 | white-space: pre; 63 | } 64 | 65 | /* special results */ 66 | .err-result { 67 | font-weight: bold; 68 | color: #B22222; 69 | } 70 | 71 | /* status colors */ 72 | .status-inert {} 73 | .status-waiting {background-color: gray;} 74 | .status-running {background-color: lightblue;} 75 | .status-err {background-color: red;} 76 | .status-success {background-color: white;} 77 | 78 | /* span highlighting */ 79 | .highlight-error { 80 | text-decoration: red underline; 
81 | text-decoration-thickness: 5px; 82 | text-decoration-skip-ink: none;} 83 | .highlight-group { background-color: yellow; } 84 | .highlight-scope { background-color: lightyellow; } 85 | .highlight-binder { background-color: lightblue; } 86 | .highlight-occ { background-color: yellow; } 87 | .highlight-leaf { background-color: lightgray; } 88 | 89 | /* lexeme colors */ 90 | .comment {color: gray;} 91 | .keyword {color: #0000DD;} 92 | .command {color: #A80000;} 93 | .symbol {color: #E07000;} 94 | .type-name {color: #A80000;} 95 | 96 | .status-hover { 97 | background-color: yellow; 98 | } 99 | -------------------------------------------------------------------------------- /tests/algeff-tests.dx: -------------------------------------------------------------------------------- 1 | effect Exn 2 | ctl raise : (a: Type) ?-> Unit -> a 3 | 4 | handler catch_ of Exn r : Maybe r 5 | ctl raise = \_. Nothing 6 | return = \x. Just x 7 | 8 | handler bad_catch_1 of Exn r : Maybe r 9 | ctl raise = \_. Nothing 10 | ctl raise = \_. Nothing -- duplicate! 11 | return = \x. Just x 12 | > Type error:Duplicate operation: raise 13 | 14 | handler bad_catch_2 of Exn r : Maybe r 15 | ctl raise = \_. Nothing 16 | > Type error:missing return 17 | -- return = \x. Just x -- missing! 18 | 19 | handler bad_catch_3 of Exn r : Maybe r 20 | -- ctl raise = \_. Nothing -- missing! 21 | return = \x. Just x 22 | > Type error:Missing operation: raise 23 | 24 | handler bad_catch_4 of Exn r : Maybe r 25 | ctl raise = \_. 42.0 -- definitely not Maybe 26 | return = \x. Just x 27 | > Type error: 28 | > Expected: (Maybe r) 29 | > Actual: Float32 30 | > 31 | > ctl raise = \_. 42.0 -- definitely not Maybe 32 | > ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 33 | 34 | handler bad_catch_5 of Exn r : Maybe r 35 | ctl raise = \_. Nothing 36 | return = \x. 42.0 -- definitely not Maybe 37 | > Type error: 38 | > Expected: (Maybe r) 39 | > Actual: Float32 40 | > 41 | > return = \x. 
42.0 -- definitely not Maybe 42 | > ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 43 | 44 | handler bad_catch_6 of Exn r : Maybe r 45 | def raise = \_. Nothing -- wrong policy! 46 | return = \x. Just x 47 | > Type error:operation raise was declared with def but defined with ctl 48 | 49 | def check (b:Bool) : {Exn} Unit = 50 | if not b then raise () 51 | 52 | def checkFloatNonNegative (x:Float) : {Exn} Float = 53 | check $ x >= 0.0 54 | x 55 | 56 | -- catch_ \_. 57 | -- checkFloatNonNegative (3.14) 58 | -- > Compiler bug! 59 | -- > Please report this at github.com/google-research/dex-lang/issues 60 | -- > 61 | -- > Not implemented 62 | -- > CallStack (from HasCallStack): 63 | -- > error, called at src/lib/Simplify.hs:214:19 in dex-0.1.0.0-8hDfthyGTXmzhkTo2ydOn:Simplify 64 | 65 | -- catch_ \_. 66 | -- checkFloatNonNegative (-1.0) 67 | -- > Compiler bug! 68 | -- > Please report this at github.com/google-research/dex-lang/issues 69 | -- > 70 | -- > Not implemented 71 | -- > CallStack (from HasCallStack): 72 | -- > error, called at src/lib/Simplify.hs:214:19 in dex-0.1.0.0-8hDfthyGTXmzhkTo2ydOn:Simplify 73 | 74 | effect Counter 75 | def inc : Unit -> Unit 76 | 77 | handler runCounter of Counter r {h} (ref : Ref h Nat) : {State h} (r & Nat) 78 | def inc = \_. 79 | ref := (1 + get ref) 80 | resume () 81 | return = \x. (x, get ref) 82 | > Error: variable not in scope: resume 83 | > 84 | > resume () 85 | > ^^^^^^^ 86 | -------------------------------------------------------------------------------- /tests/cast-tests.dx: -------------------------------------------------------------------------------- 1 | -- ==== Integral casts ==== 2 | -- 3 | -- Semantics of internal_cast on integral types are based on the bit representation 4 | -- of the values in question. All WordX types have a bit representation equal to their 5 | -- value in standard binary format. All IntX types use two's complement representation. 
6 | -- 7 | -- The cast is always performed by taking the source value to its bit representation, 8 | -- resizing that representation (depending on the target signedness), and interpreting 9 | -- the resulting bit pattern in the target type. 10 | -- 11 | -- The rules for resizing the bit pattern are as follows: 12 | -- 1. If the target bitwidth is smaller than source bitwidth, the maximum number of 13 | -- least significant bits are preserved. 14 | -- 2. If the target bitwidth is equal to the source bitwidth, nothing happens. 15 | -- 3. If the target bitwidth is greater than the source bitwidth, and: 16 | -- 3a. the target type is signed, the representation is sign-extended (the 17 | -- MSB is used to pad the value up to the desired width). 18 | -- 3b. the target type is unsigned, the representation is zero-extended. 19 | 20 | -- Casts to Int32 21 | 22 | internal_cast(to=Int32, 2147483647 :: Int64) 23 | > 2147483647 24 | 25 | internal_cast(to=Int32, 2147483648 :: Int64) 26 | > -2147483648 27 | 28 | internal_cast(to=Int32, 8589935826 :: Int64) -- 2^33 + 1234 29 | > 1234 30 | 31 | internal_cast(to=Int32, 123 :: Word8) 32 | > 123 33 | 34 | internal_cast(to=Int32, 1234 :: Word32) 35 | > 1234 36 | 37 | internal_cast(to=Int32, 4294967295 :: Word32) 38 | > -1 39 | 40 | internal_cast(to=Int32, 1234 :: Word64) 41 | > 1234 42 | 43 | internal_cast(to=Int32, 4294967295 :: Word64) 44 | > -1 45 | 46 | internal_cast(to=Int32, 4294967296 :: Word64) 47 | > 0 48 | 49 | internal_cast(to=Int32, 5000000000 :: Word64) 50 | > 705032704 51 | 52 | -- Casts to Int64 53 | 54 | internal_cast(to=Int64, 123 :: Int32) 55 | > 123 56 | 57 | internal_cast(to=Int64, -123 :: Int32) 58 | > -123 59 | 60 | internal_cast(to=Int64, 123 :: Word8) 61 | > 123 62 | 63 | internal_cast(to=Int64, 1234 :: Word32) 64 | > 1234 65 | 66 | internal_cast(to=Int64, 4294967296 :: Word64) -- 2^32 67 | > 4294967296 68 | 69 | -- Casts to Word8 70 | 71 | internal_cast(to=Word8, 1234 :: Int32) 72 | > 0xd2 73 | 74 | 
internal_cast(to=Word8, 1234 :: Int) 75 | > 0xd2 76 | 77 | internal_cast(to=Word8, 1234 :: Word32) 78 | > 0xd2 79 | 80 | internal_cast(to=Word8, 1234 :: Word64) 81 | > 0xd2 82 | 83 | -- Casts to Word32 84 | 85 | internal_cast(to=Word32, 1234 :: Int32) 86 | > 0x4d2 87 | 88 | internal_cast(to=Word32, -2147483648 :: Int32) 89 | > 0x80000000 90 | 91 | internal_cast(to=Word32, 1234 :: Int64) 92 | > 0x4d2 93 | 94 | internal_cast(to=Word32, 4294968530 :: Int64) -- 2^32 + 1234 95 | > 0x4d2 96 | 97 | internal_cast(to=Word32, -1 :: Int64) 98 | > 0xffffffff 99 | 100 | internal_cast(to=Word32, 123 :: Word8) 101 | > 0x7b 102 | 103 | internal_cast(to=Word32, 1234 :: Word64) 104 | > 0x4d2 105 | 106 | internal_cast(to=Word32, 4294967296 :: Word64) 107 | > 0x0 108 | 109 | -- Casts to Word64 110 | 111 | internal_cast(to=Word64, 1234 :: Int32) 112 | > 0x4d2 113 | 114 | internal_cast(to=Word64, -1 :: Int32) 115 | > 0xffffffff 116 | 117 | internal_cast(to=Word64, 1234 :: Int64) 118 | > 0x4d2 119 | 120 | internal_cast(to=Word64, -1 :: Int64) 121 | > 0xffffffffffffffff 122 | 123 | internal_cast(to=Word64, 123 :: Word8) 124 | > 0x7b 125 | 126 | internal_cast(to=Word64, 1234 :: Word32) 127 | > 0x4d2 128 | 129 | internal_cast(to=Word64, 4294967295 :: Word32) 130 | > 0xffffffff 131 | -------------------------------------------------------------------------------- /tests/complex-tests.dx: -------------------------------------------------------------------------------- 1 | import complex 2 | 3 | :p complex_floor $ Complex 0.3 0.6 4 | > Complex(0., 0.) 5 | :p complex_floor $ Complex 0.6 0.8 6 | > Complex(0., 1.) 7 | :p complex_floor $ Complex 0.8 0.6 8 | > Complex(1., 0.) 9 | :p complex_floor $ Complex 0.6 0.3 10 | > Complex(0., 0.) 
11 | 12 | a = Complex 2.1 0.4 13 | b = Complex (-1.1) 1.3 14 | :p (a + b - a) ~~ b 15 | > True 16 | :p (a * b) ~~ (b * a) 17 | > True 18 | :p divide (a * b) a ~~ b 19 | > True 20 | -- This next test can be added once we parameterize the field in the VSpace typeclass. 21 | --:p ((a * b) / a) ~~ b 22 | --> True 23 | :p a == b 24 | > False 25 | :p a == a 26 | > True 27 | :p log (exp a) ~~ a 28 | > True 29 | :p exp (log a) ~~ a 30 | > True 31 | :p log2 (exp2 a) ~~ a 32 | > True 33 | :p exp2 (log2 a) ~~ a 34 | > True 35 | :p sqrt (sq a) ~~ a 36 | > True 37 | :p sqrt (Complex (-1.0) 0.0) ~~ (Complex 0.0 1.0) 38 | > True 39 | :p log ((Complex 1.0 0.0) + a) ~~ log1p a 40 | > True 41 | :p sin (-a) ~~ (-(sin a)) 42 | > True 43 | :p cos (-a) ~~ cos a 44 | > True 45 | :p tan (-a) ~~ (- (tan a)) 46 | > True 47 | :p exp (pi .* (Complex 0.0 1.0)) ~~ (Complex (-1.0) 0.0) -- Euler's identity 48 | > True 49 | :p ((sq (sin a)) + (sq (cos a))) ~~ (Complex 1.0 0.0) 50 | > True 51 | :p complex_abs b > 0.0 52 | > True 53 | 54 | :p sinh (Complex 1.2 3.2) 55 | > Complex(-1.506887, -0.1056956) 56 | :p cosh (Complex 1.2 3.2) 57 | > Complex(-1.807568, 0.08811359) 58 | :p tanh (Complex 1.1 0.1) 59 | > Complex(0.8033752, 0.03580933) 60 | :p tan (Complex 1.2 3.2) 61 | > Complex(0.002250167, 1.002451) 62 | -------------------------------------------------------------------------------- /tests/exception-tests.dx: -------------------------------------------------------------------------------- 1 | 2 | 3 | def checkFloatInUnitInterval(x:Float) -> {Except} Float = 4 | assert $ x >= 0.0 5 | assert $ x <= 1.0 6 | x 7 | 8 | :p catch \. assert False 9 | > Nothing 10 | 11 | :p catch \. assert True 12 | > (Just ()) 13 | 14 | :p catch \. checkFloatInUnitInterval 1.2 15 | > Nothing 16 | 17 | :p catch \. checkFloatInUnitInterval (-1.2) 18 | > Nothing 19 | 20 | :p catch \. checkFloatInUnitInterval 0.2 21 | > (Just 0.2) 22 | 23 | :p yield_state 0 \ref. 24 | catch \. 
25 | ref := 1 26 | assert False 27 | ref := 2 28 | > 1 29 | 30 | :p catch \. 31 | for i:(Fin 5). 32 | if ordinal i > 3 33 | then throw() 34 | else 23 35 | > Nothing 36 | 37 | :p catch \. 38 | for i:(Fin 3). 39 | if ordinal i > 3 40 | then throw() 41 | else 23 42 | > (Just [23, 23, 23]) 43 | 44 | -- Is this the result we want? 45 | :p yield_state zero \ref. 46 | catch \. 47 | for i:(Fin 6). 48 | if (ordinal i `rem` 2) == 0 49 | then throw() 50 | else () 51 | ref!i := 1 52 | > [0, 1, 0, 1, 0, 1] 53 | 54 | :p catch \. 55 | run_state 0 \ref. 56 | ref := 1 57 | assert False 58 | ref := 2 59 | > Nothing 60 | 61 | -- https://github.com/google-research/dex-lang/issues/612 62 | def sashabug(h: ()) -> {Except} List Int = 63 | yield_state mempty \results. 64 | results := (get results) <> AsList 1 [2] 65 | 66 | catch \. (catch \. sashabug ()) 67 | > (Just (Just (AsList 1 [2]))) 68 | -------------------------------------------------------------------------------- /tests/fft-tests.dx: -------------------------------------------------------------------------------- 1 | import complex 2 | import fft 3 | 4 | :p map nextpow2 [0, 1, 2, 3, 4, 7, 8, 9, 1023, 1024, 1025] 5 | > [0, 0, 1, 2, 2, 3, 3, 4, 10, 10, 11] 6 | 7 | a : (Fin 4)=>Complex = arb $ new_key 0 8 | :p a ~~ (ifft $ fft a) 9 | > True 10 | :p a ~~ (fft $ ifft a) 11 | > True 12 | 13 | b : (Fin 20)=>(Fin 70)=>Complex = arb $ new_key 0 14 | :p b ~~ (ifft2 $ fft2 b) 15 | > True 16 | :p b ~~ (fft2 $ ifft2 b) 17 | > True 18 | -------------------------------------------------------------------------------- /tests/gpu-tests.dx: -------------------------------------------------------------------------------- 1 | 2 | x = for i:(Fin 5). i_to_f $ ordinal i 3 | x 4 | > [0., 1., 2., 3., 4.] 5 | 6 | x + x 7 | > [0., 2., 4., 6., 8.] 8 | 9 | -- TODO: Make it a FileCheck test 10 | testNestedParallelism = 11 | for i:(Fin 10). 12 | x = ordinal i 13 | q = for j:(Fin 2000). 
i_to_f $ x * ordinal j 14 | (2.0 .* q, 4.0 .* q) 15 | (fst testNestedParallelism.(2@_)).(5@_) 16 | > 20. 17 | 18 | -- TODO: Make it a FileCheck test 19 | testNestedLoops = 20 | for i:(Fin 10). 21 | for j:(Fin 20). 22 | ordinal i * ordinal j 23 | testNestedLoops.(4@_).(5@_) 24 | > 20 25 | 26 | -- The state is large enough such that it shouldn't fit on the stack of a 27 | -- single GPU thread. It should get lifted to a top-level allocation instead. 28 | -- allocationLiftingTest = 29 | -- for i:(Fin 100). 30 | -- yieldState (for j:(Fin 1000). ordinal i) $ \s. 31 | -- s!(0@_) := get s!(0@_) + 1 32 | -- (allocationLiftingTest.(4@_).(0@_), allocationLiftingTest.(4@_).(1@_)) 33 | -- > (5, 4) 34 | -------------------------------------------------------------------------------- /tests/inline-tests.dx: -------------------------------------------------------------------------------- 1 | -- The "=== inline ===" strings below are a hack around the fact that 2 | -- Dex currently does two passes of inlining and prints the results of 3 | -- both. Surrounding the CHECK block with these commands constrains 4 | -- the body to occur in the output from the first inlining pass. 5 | 6 | @noinline 7 | def id'(x:Nat) -> Nat = x 8 | 9 | -- CHECK-LABEL: Inline for into for 10 | "Inline for into for" 11 | 12 | %passes inline 13 | :pp 14 | xs = for i:(Fin 10). ordinal i 15 | for j. xs[j] + 2 16 | -- CHECK: === inline === 17 | -- CHECK: for 18 | -- CHECK-NOT: for 19 | -- CHECK: === inline === 20 | 21 | -- CHECK-LABEL: Inline for into sum 22 | "Inline for into sum" 23 | 24 | %passes inline 25 | :pp sum for i:(Fin 10). ordinal i 26 | -- CHECK: === inline === 27 | -- CHECK: for 28 | -- CHECK-NOT: for 29 | -- CHECK: === inline === 30 | 31 | -- CHECK-LABEL: Inline nested for into for 32 | "Inline nested for into for" 33 | 34 | %passes inline 35 | :pp 36 | xs = for i:(Fin 10). for j:(Fin 20). ordinal i * ordinal j 37 | for j i. 
xs[i, j] + 2 38 | -- CHECK: === inline === 39 | -- CHECK: for 40 | -- CHECK: for 41 | -- CHECK-NOT: for 42 | -- CHECK: === inline === 43 | 44 | -- CHECK-LABEL: Inlining does not reorder effects 45 | "Inlining does not reorder effects" 46 | 47 | -- Note that it _would be_ legal to reorder independent effects, but 48 | -- the inliner currently does not do that. But the effect in this 49 | -- example is not legal to reorder in any case. 50 | 51 | %passes inline 52 | :pp run_state 0 \ct. 53 | xs = for i:(Fin 10). 54 | ct := (get ct) + 1 55 | ordinal i 56 | for j. 57 | ct := (get ct) * 2 58 | xs[j] + 2 59 | -- CHECK: === inline === 60 | -- CHECK: for 61 | -- CHECK: for 62 | -- CHECK: === inline === 63 | 64 | -- CHECK-LABEL: Inlining does not duplicate the inlinee through beta reduction 65 | "Inlining does not duplicate the inlinee through beta reduction" 66 | 67 | -- The check is for the error call in the dynamic check that `ix` has 68 | -- type `Fin 100`. 69 | %passes inline 70 | :pp 71 | ix = (id' 20)@(Fin 100) 72 | (for i:(Fin 100). ordinal i + ordinal i)[ix] 73 | -- CHECK: === inline === 74 | -- CHECK: error 75 | -- CHECK-NOT: error 76 | -- CHECK: === inline === 77 | 78 | -- CHECK-LABEL: Inlining does not violate type IR through beta reduction 79 | "Inlining does not violate type IR through beta reduction" 80 | 81 | -- Beta reducing this ix into the `i` index of the `for` should stop 82 | -- before it produces anything a type expression can't handle, and 83 | -- thus execute. 84 | 85 | :p 86 | ix = (1@(Fin 2)) 87 | sum (for i:(Fin 2) j:(..i). ordinal j)[ix] 88 | -- CHECK: 1 89 | -- CHECK-NOT: Compiler bug 90 | 91 | -- CHECK-LABEL: Inlining simplifies case-of-known-constructor 92 | "Inlining simplifies case-of-known-constructor" 93 | 94 | -- Inlining xs exposes a case-of-known-constructor opportunity here; 95 | -- the first inlining pass doesn't take it (yet) because it's 96 | -- conservative about inlining `i` into the body of `xs`, but the 97 | -- second pass does. 
98 | %passes inline 99 | :pp 100 | xs = for i:(Either (Fin 3) (Fin 4)). 101 | case i of 102 | Left k -> 1 103 | Right k -> 2 104 | for j:(Fin 3). xs[Left j] 105 | -- CHECK: === inline === 106 | -- CHECK: for 107 | -- CHECK: case 108 | -- CHECK: === inline === 109 | -- CHECK: for 110 | -- CHECK-NOT: case 111 | 112 | -- CHECK-LABEL: Inlining carries out the case-of-case optimization 113 | "Inlining carries out the case-of-case optimization" 114 | 115 | -- Before inlining there are two cases, but attempting to inline `x` 116 | -- reveals a case-of-case opportunity, which in turn exposes 117 | -- case-of-known-constructor in each branch, leading to just one case 118 | -- in the end. 119 | %passes inline 120 | :pp 121 | x = if id'(3) > 2 122 | then Just 4 123 | else Nothing 124 | case x of 125 | Just a -> a * a 126 | Nothing -> 0 127 | -- CHECK: === inline === 128 | -- CHECK: case 129 | -- CHECK-NOT: case 130 | -- CHECK: === inline === 131 | -------------------------------------------------------------------------------- /tests/instance-interface-syntax-tests.dx: -------------------------------------------------------------------------------- 1 | 2 | interface Empty(a:Type) 3 | pass 4 | -- CHECK-NOT: Parse error 5 | 6 | instance Empty(Int) 7 | pass 8 | -- CHECK-NOT: Parse error 9 | 10 | instance Empty(Float32) 11 | def witness() = 0.0 12 | -- CHECK-NOT: Parse error 13 | -- CHECK: Error: variable not in scope: witness 14 | 15 | interface Inhabited(a) 16 | witness : a 17 | -- CHECK-NOT: Parse error 18 | 19 | instance Inhabited(Int) 20 | witness = 0 21 | -- CHECK-NOT: Parse error 22 | 23 | instance Inhabited(Float64) 24 | witness = f_to_f64(0.0) 25 | pass 26 | -- CHECK: Parse error 27 | -- CHECK: unexpected "pa" 28 | -- CHECK: expecting end of line 29 | 30 | instance Inhabited(Word32) 31 | witness = 0 32 | pass 33 | -- CHECK: Parse error 34 | -------------------------------------------------------------------------------- /tests/instance-methods-tests.dx:
-------------------------------------------------------------------------------- 1 | 2 | interface FooBar0(a) 3 | foo0 : (a) -> Int 4 | bar0 : (a) -> Int 5 | 6 | instance FooBar0(Int) 7 | def foo0(x) = x + 1 8 | def bar0(x) = foo0 x + 1 9 | 10 | w : Int = 42 11 | 12 | -- CHECK: 43 13 | foo0 w 14 | > 43 15 | 16 | -- CHECK: 44 17 | bar0 w 18 | > 44 19 | 20 | 21 | interface FooBar1(a) 22 | foo1 : (a) -> Int 23 | bar1 : (a) -> Int 24 | 25 | instance FooBar1(Int) 26 | foo1 = \x. x + 1 27 | -- Fails: Definition of `bar1` uses the class method `bar1` (with index 1); 28 | -- but the instance `FooBar1 Int` is currently still being defined and, at 29 | -- this point, can only grant access to method `foo1` (with index 0). 30 | bar1 = \x. bar1 x + 1 31 | > Type error:Wrong number of positional arguments provided. Expected 1 but got 0 32 | > 33 | > foo1 = \x. x + 1 34 | > ^^^^^^^^^^^^^^^^ 35 | -- CHECK: Type error:Couldn't synthesize a class dictionary for: (FooBar1 Int32) 36 | -- CHECK: bar1 = \x. bar1 x + 1 37 | -- CHECK: ^^^^^ 38 | 39 | 40 | interface FooBar2(a) 41 | foo2 : (a) -> Int 42 | bar2 : (a) -> Int 43 | 44 | def f2(x:a) given (a|FooBar2) = (\y. foo2 y + 1) x 45 | -- The definition of `f2` is OK because argument `d : FooBar2 a` grants access to 46 | -- all methods of class `FooBar2 a`. (Only one method of `FooBar2` is actually 47 | -- used in the body of `f2`.) 48 | 49 | def g2(x:a) given (a|FooBar2) = (\y z. foo2 y + z) x (bar2 x) 50 | -- The definition of `g2` is OK because argument `d : FooBar2 a` grants access to 51 | -- all methods of class `FooBar2 a`. 52 | 53 | 54 | instance FooBar2(Int) 55 | def foo2(x) = x + 1 56 | -- Fails: The definition of `bar2` uses `f2`, which requires a dictionary 57 | -- `d : FooBar2 Int` that has access to all methods of `FooBar2 Int`.
58 | def bar2(x) = f2 x + 1 59 | > Type error:Couldn't synthesize a class dictionary for: (FooBar2 Int32) 60 | > 61 | > def bar2(x) = f2 x + 1 62 | > ^^^^^ 63 | -- CHECK: Type error:Couldn't synthesize a class dictionary for: (FooBar2 Int32) 64 | -- CHECK: bar2 = \x. f2 x + 1 65 | -- CHECK: ^^^ 66 | 67 | 68 | interface Shows0(a) 69 | shows0 : (a) -> String 70 | showsList0 : (List a) -> String 71 | 72 | -- The body of method `showsList0` uses method `shows0` from the same instance. 73 | instance Shows0(Nat) 74 | def shows0(x) = show x 75 | def showsList0(xs) = 76 | AsList(n, ys) = xs 77 | strings = map shows0 ys 78 | reduce "" (<>) strings 79 | 80 | showsList0 (AsList 3 [0, 1, 2]) 81 | > "012" 82 | -- CHECK: "012" 83 | 84 | interface Shows1(a) 85 | shows1 : (a) -> String 86 | showsList1 : (List a) -> String 87 | 88 | instance Shows1(Nat) 89 | def shows1(x) = showsList1 (AsList 1 [x]) 90 | -- Methods `shows1` and `showsList1` refer to each other in a mutually recursive 91 | -- fashion: the body of method `shows1` uses method `showsList1` from the same 92 | -- instance, and the body of method `showsList1` uses method `shows1` also from 93 | -- this instance. 94 | def showsList1(xs) = 95 | AsList(n, ys) = xs 96 | strings = map shows1 ys 97 | reduce "" (<>) strings 98 | > Type error:Couldn't synthesize a class dictionary for: (Shows1 Nat) 99 | > 100 | > def shows1(x) = showsList1 (AsList 1 [x]) 101 | > ^^^^^^^^^^^^^^^^^^^^^^^^^ 102 | -- CHECK: Type error:Couldn't synthesize a class dictionary for: (Shows1 Nat) 103 | -- CHECK: shows1 = \x. showsList1 (AsList 1 [x]) 104 | -- CHECK: ^^^^^^^^^^^ 105 | -------------------------------------------------------------------------------- /tests/io-tests.dx: -------------------------------------------------------------------------------- 1 | 2 | :p unsafe_io \. 3 | with_temp_file \fname. 4 | with_file fname WriteMode \stream.
5 | fwrite stream "lorem ipsum\n" 6 | fwrite stream "dolor sit amet\n" 7 | read_file fname 8 | > "lorem ipsum 9 | > dolor sit amet 10 | > " 11 | 12 | :p unsafe_io \. 13 | with_alloc 4 \ptr:(Ptr Nat). 14 | for i:(Fin 4). store (ptr +>> ordinal i) (ordinal i) 15 | table_from_ptr(n=Fin 4, ptr) 16 | > [0, 1, 2, 3] 17 | 18 | unsafe_io \. 19 | print "testing log" 20 | 1.0 -- prevent DCE 21 | > testing log 22 | > 1. 23 | 24 | unsafe_io \. 25 | for i':(Fin 10). 26 | i = ordinal i' 27 | if rem i 2 == 0 28 | then print $ show i <> " is even" 29 | else print $ show i <> " is odd" 30 | 1.0 -- prevent DCE 31 | > 0 is even 32 | > 1 is odd 33 | > 2 is even 34 | > 3 is odd 35 | > 4 is even 36 | > 5 is odd 37 | > 6 is even 38 | > 7 is odd 39 | > 8 is even 40 | > 9 is odd 41 | > 1. 42 | 43 | :p storage_size(a=Int) 44 | > 4 45 | 46 | :p unsafe_io \. 47 | with_alloc 1 \ptr:(Ptr Int). 48 | store ptr 3 49 | load ptr 50 | > 3 51 | 52 | :p with_stack Nat \stack. 53 | stack.extend(for i:(Fin 1000). ordinal i) 54 | stack.extend(for i:(Fin 1000). ordinal i) 55 | AsList(_, xs) = stack.read() 56 | sum xs 57 | > 999000 58 | 59 | :p unsafe_io \. 60 | s = for i:(Fin 10000). i_to_w8 $ f_to_i $ 128.0 * rand (ixkey (new_key 0) i) 61 | with_temp_file \fname. 62 | with_file fname WriteMode \stream. 63 | fwrite stream $ AsList _ s 64 | AsList(_, s') = read_file fname 65 | sum (for i. w8_to_i s[i]) == sum (for i. w8_to_i s'[i]) 66 | > True 67 | 68 | :p unsafe_io \. get_env "NOT_AN_ENV_VAR" 69 | > Nothing 70 | 71 | :p unsafe_io \. get_env "DEX_TEST_MODE" 72 | > (Just "t") 73 | 74 | :p dex_test_mode() 75 | > True 76 | -------------------------------------------------------------------------------- /tests/linalg-tests.dx: -------------------------------------------------------------------------------- 1 | import linalg 2 | 3 | -- Check that the optimized matmul gives the same answers as the naive one 4 | amat = for i:(Fin 100) j:(Fin 100). 
n_to_f $ ordinal (i, j) 5 | 6 | :p tiled_matmul(amat, amat) ~~ naive_matmul amat amat 7 | > True 8 | 9 | -- Check that the inverse of the inverse is identity. 10 | mat = [[11.,9.,24.,2.],[1.,5.,2.,6.],[3.,17.,18.,1.],[2.,5.,7.,1.]] 11 | :p mat ~~ (invert (invert mat)) 12 | > True 13 | 14 | -- Check that solving gives the inverse. 15 | v = [1., 2., 3., 4.] 16 | :p v ~~ (mat **. (solve mat v)) 17 | > True 18 | 19 | -- Check that det and exp(logdet) are the same. 20 | (s, logdet) = sign_and_log_determinant mat 21 | :p (determinant mat) ~~ (s * (exp logdet)) 22 | > True 23 | 24 | -- Matrix integer powers. 25 | :p matrix_power mat 0 ~~ eye 26 | > True 27 | :p matrix_power mat 1 ~~ mat 28 | > True 29 | :p matrix_power mat 2 ~~ (mat ** mat) 30 | > True 31 | :p matrix_power mat 5 ~~ (mat ** mat ** mat ** mat ** mat) 32 | > True 33 | 34 | :p trace mat == (11. + 5. + 18. + 1.) 35 | > True 36 | 37 | -- Check that we can linearize LU decomposition 38 | -- This is a regression test for Issue #842. 39 | snd(linearize (\x. snd $ sign_and_log_determinant [[x]]) 1.0)(2.0) 40 | > 2. 41 | 42 | -- Check that we can differentiate through LU decomposition 43 | -- This is a regression test for Issue #848. 44 | grad (\x. (pivotize [[x]]).sign) 1.0 45 | > 0. 46 | 47 | grad (\x. snd $ sign_and_log_determinant [[x]]) 2.0 48 | > 0.5 49 | 50 | -- Check forward_substitute solve by comparing 51 | -- against zero-padding and doing the full solve. 52 | def padLowerTriMat(mat:LowerTriMat n v) -> n=>n=>v given (n|Ix, v|Add) = 53 | for i j. 
54 | if (ordinal j)<=(ordinal i) 55 | then mat[i,unsafe_project j] 56 | else zero 57 | 58 | lower : LowerTriMat (Fin 4) Float = arb $ new_key 0 59 | lower_padded = padLowerTriMat lower 60 | vec : (Fin 4)=>Float = arb $ new_key 0 61 | 62 | forward_substitute lower vec ~~ solve lower_padded vec 63 | > True 64 | -------------------------------------------------------------------------------- /tests/lower.dx: -------------------------------------------------------------------------------- 1 | for i:(Fin 2) j:(Fin 4). ordinal (i,j) 2 | > [[0, 1, 2, 3], [4, 5, 6, 7]] 3 | -------------------------------------------------------------------------------- /tests/module-tests.dx: -------------------------------------------------------------------------------- 1 | import test_module_A 2 | import test_module_B 3 | 4 | :p 1 + 1 5 | > 2 6 | 7 | :p test_module_A_val + 4 8 | > 7 9 | 10 | :p test_module_amb 11 | > Error: ambiguous variable: test_module_amb is defined: 12 | > in test_module_A 13 | > in test_module_B 14 | > 15 | > 16 | > :p test_module_amb 17 | > ^^^^^^^^^^^^^^^ 18 | 19 | :p test_module_B_val_from_C 20 | > 23 21 | 22 | :p test_module_C_val 23 | > Error: variable not in scope: test_module_C_val 24 | > 25 | > :p test_module_C_val 26 | > ^^^^^^^^^^^^^^^^^ 27 | 28 | :p test_module_A_fun 2 29 | > 4 30 | 31 | :p test_module_A_fun_noinline 3 32 | > 6 33 | 34 | :p fooMethodExportFromB 1 35 | > 2 36 | 37 | :p fooMethodExportFromB 1.0 38 | > 10. 39 | 40 | :p arrayVal 41 | > [1, 2, 3] 42 | 43 | :p arrayVal2 44 | > [2, 4, 6] 45 | -------------------------------------------------------------------------------- /tests/parser-combinator-tests.dx: -------------------------------------------------------------------------------- 1 | 2 | import parser 3 | 4 | parseABC : Parser () = MkParser \h. 
5 | parse h $ p_char 'A' 6 | parse h $ p_char 'B' 7 | parse h $ p_char 'C' 8 | 9 | :p run_parser "AAA" parseABC 10 | > Nothing 11 | 12 | :p run_parser "ABCABC" parseABC 13 | > Nothing 14 | 15 | :p run_parser "AB" parseABC 16 | > Nothing 17 | 18 | :p run_parser "ABC" parseABC 19 | > (Just ()) 20 | 21 | def parseT() ->> Parser Bool = MkParser \h. 22 | parse h $ p_char 'T' 23 | True 24 | 25 | def parseF() ->> Parser Bool = MkParser \h. 26 | parse h $ p_char 'F' 27 | False 28 | 29 | def parseTF() ->> Parser Bool = 30 | parseT <|> parseF 31 | 32 | def parserTFTriple() ->> Parser (Fin 3=>Bool) = MkParser \h. 33 | for i. parse h parseTF 34 | 35 | :p run_parser "TTF" parserTFTriple 36 | > (Just [True, True, False]) 37 | 38 | :p run_parser "TTFX" parserTFTriple 39 | > Nothing 40 | 41 | :p run_parser "TTFFTT" $ parse_many parseTF 42 | > (Just (AsList 6 [True, True, False, False, True, True])) 43 | 44 | :p run_parser "1021389" $ parse_many parse_digit 45 | > (Just (AsList 7 [1, 0, 2, 1, 3, 8, 9])) 46 | 47 | :p run_parser "1389" $ parse_int 48 | > (Just 1389) 49 | 50 | :p run_parser "01389" $ parse_int 51 | > (Just 1389) 52 | 53 | :p run_parser "-1389" $ parse_int 54 | > (Just -1389) 55 | 56 | split ' ' " This is a sentence. " 57 | > (AsList 4 ["This", "is", "a", "sentence."]) 58 | -------------------------------------------------------------------------------- /tests/print-tests.dx: -------------------------------------------------------------------------------- 1 | 2 | :pcodegen [(),(),()] 3 | > [(), (), ()] 4 | 5 | -- :pcodegen {x = 1.0, y = 2} 6 | -- > {x = 1., y = 2} 7 | 8 | :pcodegen (the Nat 60, the Int 60, the Float 60, the Int64 60, the Float64 60) 9 | > (60, 60, 60., 60, 60.) 
10 | 11 | :pcodegen (the Word8 60, the Word32 60, the Word64 60) 12 | > (0x3c, 0x3c, 0x3c) 13 | 14 | :pcodegen [Just (Just 1.0), Just Nothing, Nothing] 15 | > [(Just (Just 1.)), (Just Nothing), Nothing] 16 | 17 | data MyType = MyValue(Nat) 18 | 19 | :pcodegen MyValue 1 20 | > (MyValue 1) 21 | 22 | :pcodegen "the quick brown fox jumps over the lazy dog" 23 | > "the quick brown fox jumps over the lazy dog" 24 | 25 | :pcodegen ['a', 'b', 'c'] 26 | > [0x61, 0x62, 0x63] 27 | 28 | :pcodegen "abcd" 29 | > "abcd" 30 | -------------------------------------------------------------------------------- /tests/read-tests.dx: -------------------------------------------------------------------------------- 1 | parseString "123" :: Maybe Float 2 | > (Just 123.) 3 | parseString "123.4" :: Maybe Float 4 | > (Just 123.4) 5 | parseString "123x" :: Maybe Float 6 | > Nothing 7 | parseString "x123" :: Maybe Float 8 | > Nothing 9 | -------------------------------------------------------------------------------- /tests/repl-multiline-test-expected-output: -------------------------------------------------------------------------------- 1 | >=> >=> >=> >=> >=> >=> ... ... ... ... 30. 2 | >=> >=> ... ... >=> >=> (1, 1) 3 | >=> >=> >=> >=> 3. 4 | >=> -------------------------------------------------------------------------------- /tests/repl-multiline-test.dx: -------------------------------------------------------------------------------- 1 | 2 | -- comment 3 | 4 | 'Single-line multiline comment 5 | 6 | :p 7 | triple = \x. 8 | y = x + x 9 | x + y 10 | triple 10.0 11 | 12 | f = \x:Int. 13 | (x, 14 | x) 15 | 16 | f 1 17 | 18 | y = 1. * 3. 19 | 20 | :p y 21 | -------------------------------------------------------------------------------- /tests/repl-regression-528-test-expected-output: -------------------------------------------------------------------------------- 1 | >=> ... ... 
Parse error:1:6: 2 | | 3 | 1 | :help 4 | | ^ 5 | unrecognized command: "help" 6 | 7 | >=> -------------------------------------------------------------------------------- /tests/repl-regression-528-test.dx: -------------------------------------------------------------------------------- 1 | :help 2 | 3 | asdf 4 | -------------------------------------------------------------------------------- /tests/serialize-tests.dx: -------------------------------------------------------------------------------- 1 | :p 1 2 | > 1 3 | 4 | :p 1.0 5 | > 1. 6 | 7 | :p [1, 2, 3] 8 | > [1, 2, 3] 9 | 10 | :p [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] 11 | > [[1., 2., 3.], [4., 5., 6.]] 12 | 13 | :p from_ordinal(n=Fin 10, 7) 14 | > 7 15 | 16 | :p [True, False] 17 | > [True, False] 18 | 19 | :p () 20 | > () 21 | 22 | x = ['a', 'b'] 23 | :p for p. 24 | (i,j) = p 25 | [x[i], x[j]] 26 | > [[0x61, 0x61], [0x61, 0x62], [0x62, 0x61], [0x62, 0x62]] 27 | 28 | 'Values without a pretty-printer 29 | 30 | :p Int 31 | > Int32 32 | 33 | :p Fin 10 34 | > (Fin 10) 35 | 36 | :p (Fin 10, Fin 20) 37 | > ((Fin 10), (Fin 20)) 38 | -------------------------------------------------------------------------------- /tests/set-tests.dx: -------------------------------------------------------------------------------- 1 | import set 2 | 3 | -- check order invariance. 4 | :p (to_set ["Bob", "Alice", "Charlie"]) == (to_set ["Charlie", "Bob", "Alice"]) 5 | > True 6 | 7 | -- check uniqueness. 
8 | :p (to_set ["Bob", "Alice", "Alice", "Charlie"]) == (to_set ["Charlie", "Charlie", "Bob", "Alice"]) 9 | > True 10 | 11 | set1 = to_set ["Xeno", "Alice", "Bob"] 12 | set2 = to_set ["Bob", "Xeno", "Charlie"] 13 | 14 | :p set1 == set2 15 | > False 16 | 17 | :p set_union set1 set2 18 | > (UnsafeAsSet 4 ["Alice", "Bob", "Charlie", "Xeno"]) 19 | 20 | :p set_intersect set1 set2 21 | > (UnsafeAsSet 2 ["Bob", "Xeno"]) 22 | 23 | :p remove_duplicates_from_sorted ["Alice", "Alice", "Alice", "Bob", "Bob", "Charlie", "Charlie", "Charlie"] 24 | > (AsList 3 ["Alice", "Bob", "Charlie"]) 25 | 26 | :p set1 == (set_union set1 set1) 27 | > True 28 | 29 | :p set1 == (set_intersect set1 set1) 30 | > True 31 | 32 | '#### Empty set tests 33 | 34 | emptyset = to_set ([]::(Fin 0)=>String) 35 | 36 | :p emptyset == emptyset 37 | > True 38 | 39 | :p emptyset == (set_union emptyset emptyset) 40 | > True 41 | 42 | :p emptyset == (set_intersect emptyset emptyset) 43 | > True 44 | 45 | :p set1 == (set_union set1 emptyset) 46 | > True 47 | 48 | :p emptyset == (set_intersect set1 emptyset) 49 | > True 50 | 51 | '### Set Index Set tests 52 | 53 | names2 = to_set ["Bob", "Alice", "Charlie", "Alice"] 54 | 55 | Person : Type = Element names2 56 | 57 | :p size Person 58 | > 3 59 | 60 | -- Check that ordinal and unsafeFromOrdinal are inverses. 61 | roundTrip = for i:Person. 62 | i == (unsafe_from_ordinal (ordinal i)) 63 | :p all roundTrip 64 | > True 65 | 66 | -- Check that member and value are inverses. 67 | roundTrip2 = for i:Person. 
68 | s = value i 69 | ix = member s names2 70 | i == from_just ix 71 | :p all roundTrip2 72 | > True 73 | 74 | setix : Person = from_just $ member "Bob" names2 75 | :p setix 76 | > Element(1) 77 | 78 | setix2 : Person = from_just $ member "Charlie" names2 79 | :p setix2 80 | > Element(2) 81 | -------------------------------------------------------------------------------- /tests/shadow-tests.dx: -------------------------------------------------------------------------------- 1 | 2 | -- repeated vars in patterns not allowed 3 | :p 4 | (x, x) = (1, 1) 5 | x 6 | > Error: variable already defined within pattern: x 7 | > 8 | > (x, x) = (1, 1) 9 | > ^ 10 | 11 | :p 12 | f = \p. 13 | (x, x) = p 14 | x 15 | f (1, 1) 16 | > Error: variable already defined within pattern: x 17 | > 18 | > (x, x) = p 19 | > ^ 20 | 21 | -- TODO: re-enable if we choose to allow non-peer shadowing 22 | -- -- shouldn't cause error even though it shadows x elsewhere 23 | -- x = 50 24 | 25 | -- :p let x = 100 in (let x = 200 in x) 26 | 27 | -- > [200] 28 | 29 | arr = 10 30 | 31 | -- TODO: enable when we handle this case 32 | -- _ = 10 33 | -- _ = 10 -- underscore shadows allowed 34 | 35 | arr = 20 36 | > Error: variable already defined: arr 37 | > 38 | > arr = 20 39 | > ^^^^ 40 | 41 | :p arr 42 | > Error: ambiguous variable: arr is defined: 43 | > in this file 44 | > in this file 45 | > 46 | > 47 | > :p arr 48 | > ^^^ 49 | 50 | -- testing top-level shadowing 51 | f : (given (a:Type), a) -> a = \x. 
x 52 | 53 | x = 1 54 | 55 | :p f 1 56 | > 1 57 | 58 | :p y 59 | > Error: variable not in scope: y 60 | > 61 | > :p y 62 | > ^ 63 | 64 | (_, _, z) = (1,2,3) 65 | 66 | :p z 67 | > 3 68 | 69 | :p 70 | (_, _, w) = (1,2,4) 71 | w 72 | > 4 73 | 74 | -- Testing data shadowing 75 | data Shadow = Shadow 76 | > Error: variable already defined: Shadow 77 | 78 | data Shadow2 = 79 | Shadow1(Int) 80 | Shadow1 81 | > Error: variable already defined: Shadow1 82 | 83 | ShadowCon = 1 84 | data Shadow3 = ShadowCon 85 | > Error: variable already defined: ShadowCon 86 | 87 | data Shadow4 = ShadowCon' 88 | ShadowCon' = 1 89 | > Error: variable already defined: ShadowCon' 90 | > 91 | > ShadowCon' = 1 92 | > ^^^^^^^^^^^ 93 | -------------------------------------------------------------------------------- /tests/show-tests.dx: -------------------------------------------------------------------------------- 1 | '# `Show` instances 2 | -- String 3 | 4 | :p show "abc" 5 | > "abc" 6 | 7 | -- Int32 8 | 9 | :p show (1234 :: Int32) 10 | > "1234" 11 | 12 | :p show (-1234 :: Int32) 13 | > "-1234" 14 | 15 | :p show ((f_to_i (-(pow 2. 31.))) :: Int32) 16 | > "-2147483648" 17 | 18 | -- Int64 19 | 20 | :p show (i_to_i64 1234 :: Int64) 21 | > "1234" 22 | 23 | :p show (i_to_i64 (-1234) :: Int64) 24 | > "-1234" 25 | 26 | -- Float32 27 | 28 | :p show (123.456789 :: Float32) 29 | > "123.456787" 30 | 31 | :p show ((pow 2. 16.) :: Float32) 32 | > "65536" 33 | 34 | -- FIXME(https://github.com/google-research/dex-lang/issues/316): 35 | -- Unparenthesized expression with type ascription does not parse. 36 | -- :p show (nan: Float32) 37 | 38 | :p show (nan :: Float32) 39 | > "nan" 40 | 41 | -- Note: `show nan` (Dex runtime dtoa implementation) appears different from 42 | -- `:p nan` (Dex interpreter implementation). 
43 | :p nan 44 | > nan 45 | 46 | :p show (infinity :: Float32) 47 | > "inf" 48 | 49 | -- Note: `show infinity` (Dex runtime dtoa implementation) appears different from 50 | -- `:p infinity` (Dex interpreter implementation). 51 | :p infinity 52 | > inf 53 | 54 | -- Float64 55 | 56 | :p show (f_to_f64 123.456789:: Float64) 57 | > "123.456787109375" 58 | 59 | :p show (f_to_f64 (pow 2. 16.):: Float64) 60 | > "65536" 61 | 62 | :p show ((f_to_f64 nan):: Float64) 63 | > "nan" 64 | 65 | -- Note: `show nan` (Dex runtime dtoa implementation) appears different from 66 | -- `:p nan` (Dex interpreter implementation). 67 | :p (f_to_f64 nan) 68 | > nan 69 | 70 | :p show ((f_to_f64 infinity):: Float64) 71 | > "inf" 72 | 73 | -- Note: `show infinity` (Dex runtime dtoa implementation) appears different from 74 | -- `:p infinity` (Dex interpreter implementation). 75 | :p (f_to_f64 infinity) 76 | > inf 77 | 78 | -- Tuples 79 | 80 | :p show (123, 456) 81 | > "(123, 456)" 82 | 83 | :p show ("abc", 123) 84 | > "(abc, 123)" 85 | 86 | :p show ("abc", 123, ("def", 456)) 87 | > "(abc, 123, (def, 456))" 88 | -------------------------------------------------------------------------------- /tests/sort-tests.dx: -------------------------------------------------------------------------------- 1 | import sort 2 | 3 | :p is_sorted $ sort []::((Fin 0)=>Int) 4 | > True 5 | :p is_sorted $ sort [9, 3, 7, 4, 6, 1, 9, 1, 9, -1, 10, 10, 100, 0] 6 | > True 7 | 8 | :p 9 | xs = [1,2,4] 10 | for i:(Fin 6).
11 | search_sorted_exact(xs, ordinal i) 12 | > [Nothing, (Just 0), (Just 1), Nothing, (Just 2), Nothing] 13 | 14 | '### Lexical Sorting Tests 15 | 16 | :p "aaa" < "bbb" 17 | > True 18 | 19 | :p "aa" < "bbb" 20 | > True 21 | 22 | :p "a" < "aa" 23 | > True 24 | 25 | :p "aaa" > "bbb" 26 | > False 27 | 28 | :p "aa" > "bbb" 29 | > False 30 | 31 | :p "a" > "aa" 32 | > False 33 | 34 | :p "a" < "aa" 35 | > True 36 | 37 | :p ("" :: List Word8) > ("" :: List Word8) 38 | > False 39 | 40 | :p ("" :: List Word8) < ("" :: List Word8) 41 | > False 42 | 43 | :p "a" > "a" 44 | > False 45 | 46 | :p "a" < "a" 47 | > False 48 | 49 | :p "Thomas" < "Thompson" 50 | > True 51 | 52 | :p "Thomas" > "Thompson" 53 | > False 54 | 55 | :p is_sorted $ sort ["Charlie", "Alice", "Bob", "Aaron"] 56 | > True 57 | -------------------------------------------------------------------------------- /tests/stack-tests.dx: -------------------------------------------------------------------------------- 1 | 2 | with_stack Nat \stack. 3 | stack.push 10 4 | stack.push 11 5 | stack.pop() 6 | stack.pop() 7 | > (Just 10) 8 | 9 | with_stack Nat \stack. 10 | stack.push 10 11 | stack.push 11 12 | stack.pop() 13 | stack.pop() 14 | stack.pop() -- Check that popping an empty stack is OK. 15 | stack.push 20 16 | stack.push 21 17 | stack.pop() 18 | > (Just 21) 19 | 20 | with_stack Nat \stack. 
21 | stack.pop() 22 | > Nothing 23 | -------------------------------------------------------------------------------- /tests/standalone-function-tests.dx: -------------------------------------------------------------------------------- 1 | 2 | @noinline 3 | def standalone_sum(xs:n=>v) -> v given (n|Ix, v|Add) = 4 | sum xs 5 | 6 | vec3 = [1,2,3] 7 | vec2 = [4,5] 8 | 9 | -- TODO: test that we only get one copy inlined (hard to do without dumping IR 10 | -- until we have logging for that sort of thing) 11 | :p standalone_sum vec2 + standalone_sum vec3 12 | > 15 13 | 14 | mat23 = [[1,2,3],[4,5,6]] 15 | mat32 = [[1,2],[3,4],[5,6]] 16 | 17 | @noinline 18 | def standalone_transpose(x:n=>m=>a) -> m=>n=>a given (n|Ix, m|Ix, a) = 19 | for i j. x[j,i] 20 | 21 | :p (standalone_transpose mat23, standalone_transpose mat32) 22 | > ([[1, 4], [2, 5], [3, 6]], [[1, 3, 5], [2, 4, 6]]) 23 | 24 | xs = [1,2,3] 25 | 26 | @noinline 27 | def foo(_:()) -> Nat = sum xs 28 | 29 | foo () 30 | > 6 31 | 32 | 'Regression test for #1152. The standalone function is just here to 33 | make the size of the tables unknown. The actual bug is in Algebra 34 | handling an expression like `sum_{i=0}^k k * i` where the same 35 | name occurs in the monomial and the limit. 36 | 37 | def LowerTriMat(n|Ix, v:Type) -> Type = (i:n)=>(..i)=>v 38 | def UpperTriMat(n|Ix, v:Type) -> Type = (i:n)=>(i..)=>v 39 | 40 | @noinline 41 | def bar(n: Nat) -> Float = 42 | (for k. for j:(..k). 0.0, for k. for j:(k..). 0.0) :: (LowerTriMat (Fin n) Float, UpperTriMat (Fin n) Float) 43 | 0.0 44 | 45 | bar 2 46 | > 0.
47 | -------------------------------------------------------------------------------- /tests/struct-tests.dx: -------------------------------------------------------------------------------- 1 | 2 | struct MyStruct = 3 | field1 : Int 4 | field2 : Float 5 | field3 : String 6 | 7 | my_struct = MyStruct 1 2 "abc" 8 | 9 | :p my_struct.field3 10 | > "abc" 11 | 12 | :p my_struct.(1 + 1) 13 | > Syntax error: Field must be a name 14 | > 15 | > :p my_struct.(1 + 1) 16 | > ^^^^^^^ 17 | 18 | > Parse error:12:13: 19 | > | 20 | > 12 | :p my_struct.(1 + 1) 21 | > | ^^ 22 | > unexpected ".(" 23 | > expecting "->", "..", "<..", "with", backquoted name, end of input, end of line, infix operator, name, or symbol name 24 | :p my_struct 25 | > MyStruct(1, 2., "abc") 26 | 27 | :t my_struct 28 | > MyStruct 29 | 30 | struct MyParametricStruct(a) = 31 | foo : a 32 | bar : Nat 33 | 34 | :p 35 | foo = MyParametricStruct(1.0, 1) 36 | foo.bar 37 | > 1 38 | 39 | :p 40 | foo = MyParametricStruct(1.0, 1) 41 | foo.baz 42 | > Type error:Can't resolve field baz of type (MyParametricStruct Float32) 43 | > Known fields are: [bar, foo, 0, 1] 44 | > 45 | > foo.baz 46 | > ^^^ 47 | 48 | 49 | x = (1, 2) 50 | 51 | x.0 52 | > 1 53 | 54 | x.1 55 | > 2 56 | 57 | x.2 58 | > Type error:Can't resolve field 2 of type (Nat, Nat) 59 | > Known fields are: [0, 1] 60 | > 61 | > x.2 62 | > ^ 63 | 64 | x.foo 65 | > Type error:Can't resolve field foo of type (Nat, Nat) 66 | > Known fields are: [0, 1] 67 | > 68 | > x.foo 69 | > ^^^ 70 | 71 | struct Thing(a|Add) = 72 | x : a 73 | y : a 74 | 75 | def incby(n:a) -> Thing(a) = 76 | Thing(self.x + n, self.y + n) 77 | 78 | Thing(1,2).incby(10) 79 | > Thing(11, 12) 80 | 81 | struct MissingConstraint(n) = 82 | thing : n=>Float 83 | > Type error:Couldn't synthesize a class dictionary for: (Ix n) 84 | > 85 | > thing : n=>Float 86 | > ^^^^^^^^ 87 | 88 | data AnotherMissingConstraint(n) = 89 | MkAnotherMissingConstraint(n=>Float) 90 | > Type error:Couldn't synthesize a class 
dictionary for: (Ix n) 91 | > 92 | > MkAnotherMissingConstraint(n=>Float) 93 | > ^^^^^^^^ 94 | -------------------------------------------------------------------------------- /tests/test_module_A.dx: -------------------------------------------------------------------------------- 1 | 2 | import test_module_C 3 | 4 | test_module_amb = 10 5 | 6 | test_module_A_val = 1 + 2 7 | 8 | def test_module_A_fun(x:Int) -> Int = x + x 9 | 10 | @noinline 11 | def test_module_A_fun_noinline(x:Int) -> Int = x + x 12 | 13 | instance FooClass(Float) 14 | def fooMethod(x) = 10.0 * x 15 | -------------------------------------------------------------------------------- /tests/test_module_B: -------------------------------------------------------------------------------- 1 | 2 | import test_module_C 3 | 4 | test_module_B_val = 10 + 2 5 | 6 | -------------------------------------------------------------------------------- /tests/test_module_B.dx: -------------------------------------------------------------------------------- 1 | 2 | import test_module_C 3 | 4 | test_module_amb = 10 5 | 6 | test_module_B_val = 10 7 | 8 | test_module_B_val_from_C = test_module_C_val 9 | 10 | instance FooClass(Nat) 11 | def fooMethod(x) = x + x 12 | 13 | arrayVal = [1,2,3] 14 | 15 | arrayVal2 = for i. 
arrayVal[i] * 2 16 | 17 | def fooMethodExportFromB(x:a) -> a given (a|FooClass) = fooMethod x 18 | -------------------------------------------------------------------------------- /tests/test_module_C.dx: -------------------------------------------------------------------------------- 1 | 2 | 3 | test_module_C_val = 23 4 | 5 | interface FooClass(a) 6 | fooMethod : (a) -> a 7 | -------------------------------------------------------------------------------- /tests/trig-tests.dx: -------------------------------------------------------------------------------- 1 | :p isnan nan 2 | > True 3 | :p isnan 1.0 4 | > False 5 | :p isinf infinity 6 | > True 7 | :p isinf (-infinity) 8 | > True 9 | :p isinf 1.0 10 | > False 11 | 12 | :p either_is_nan infinity nan 13 | > True 14 | :p either_is_nan nan nan 15 | > True 16 | 17 | :p atan2 (sin 0.44) (cos 0.44) ~~ 0.44 18 | > True 19 | :p atan2 (sin (-0.44)) (cos (-0.44)) ~~ (-0.44) 20 | > True 21 | :p atan2 (-sin (-0.44)) (cos (-0.44)) ~~ (0.44) 22 | > True 23 | :p atan2 (-1.0) (-1.0) ~~ (-3.0/4.0*pi) 24 | > True 25 | 26 | -- Test all the way around the circle. 27 | angles = linspace (Fin 11) (-pi + 0.001) (pi) 28 | :p all for i:(Fin 11). 
29 | angles[i] ~~ atan2 (sin angles[i]) (cos angles[i]) 30 | > True 31 | 32 | :p (atan2 infinity 1.0) ~~ ( pi / 2.0) 33 | > True 34 | :p (atan2 (-infinity) 1.0) ~~ (-pi / 2.0) 35 | > True 36 | :p (atan2 1.0 infinity) ~~ 0.0 37 | > True 38 | :p (atan2 (-1.0) infinity) ~~ 0.0 39 | > True 40 | 41 | :p (atan2 infinity infinity) ~~ ( pi / 4.0) 42 | > True 43 | :p (atan2 infinity (-infinity)) ~~ ( 3.0 * pi / 4.0) 44 | > True 45 | :p (atan2 (-infinity) infinity) ~~ (-pi / 4.0) 46 | > True 47 | :p (atan2 (-infinity) (-infinity)) ~~ (-3.0 * pi / 4.0) 48 | > True 49 | 50 | :p isnan $ atan2 nan infinity 51 | > True 52 | :p isnan $ atan2 infinity nan 53 | > True 54 | :p isnan $ atan2 nan nan 55 | > True 56 | 57 | :p sinh 1.2 ~~ 1.5094614 58 | > True 59 | 60 | :p tanh 1.2 ~~ ((sinh 1.2) / (cosh 1.2)) 61 | > True 62 | 63 | :p tanh (f_to_f64 1.2) ~~ divide (sinh (f_to_f64 1.2)) (cosh (f_to_f64 1.2)) 64 | > True 65 | -------------------------------------------------------------------------------- /tests/unit/JaxADTSpec.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2023 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | module JaxADTSpec (spec) where 8 | 9 | import Data.Aeson (encode, decode) 10 | import Test.Hspec 11 | 12 | import Name 13 | import JAX.Concrete 14 | import JAX.Rename 15 | import JAX.ToSimp 16 | import Runtime 17 | import TopLevel 18 | import Types.Imp 19 | import Types.Primitives hiding (Sin) 20 | import Types.Source hiding (SourceName) 21 | import QueryType 22 | 23 | x_nm, y_nm :: JSourceName 24 | x_nm = JSourceName 0 0 "x" 25 | y_nm = JSourceName 1 0 "y" 26 | 27 | float :: JVarType 28 | float = (JArrayName [] F32) 29 | 30 | ten_vec :: JVarType 31 | ten_vec = (JArrayName [DimSize 10] F32) 32 | 33 | a_jaxpr :: JVarType -> Jaxpr VoidS 34 | a_jaxpr ty = Jaxpr 
35 | (Nest (JBindSource x_nm ty) Empty) 36 | Empty 37 | (Nest (JEqn 38 | (Nest (JBindSource y_nm ty) Empty) 39 | Sin 40 | [JVariable $ JVar (SourceName x_nm) ty]) Empty) 41 | [JVariable $ JVar (SourceName y_nm) ty] 42 | 43 | compile :: Jaxpr VoidS -> IO LLVMCallable 44 | compile jaxpr = do 45 | let cfg = EvalConfig LLVM [LibBuiltinPath] Nothing Nothing Nothing NoOptimize PrintCodegen 46 | env <- initTopState 47 | fst <$> runTopperM cfg env do 48 | -- TODO Implement GenericE for jaxprs, derive SinkableE, and properly sink 49 | -- the jaxpr instead of just coercing it. 50 | Distinct <- getDistinct 51 | jRename <- liftRenameM $ renameJaxpr (unsafeCoerceE jaxpr) 52 | jSimp <- liftJaxSimpM (simplifyJaxpr jRename) >>= asTopLam 53 | compileTopLevelFun (EntryFunCC CUDANotRequired) jSimp >>= packageLLVMCallable 54 | 55 | spec :: Spec 56 | spec = do 57 | describe "JaxADT" do 58 | it "round-trips to json" do 59 | let first = encode $ a_jaxpr ten_vec 60 | let (Just decoded) = (decode first :: Maybe (Jaxpr VoidS)) 61 | let second = encode decoded 62 | second `shouldBe` first 63 | it "executes" do 64 | jLLVM <- compile $ a_jaxpr float 65 | result <- callEntryFun jLLVM [Float32Lit 3.0] 66 | result `shouldBe` [Float32Lit $ sin 3.0] 67 | -------------------------------------------------------------------------------- /tests/unit/RawNameSpec.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2022 Google LLC 2 | -- 3 | -- Use of this source code is governed by a BSD-style 4 | -- license that can be found in the LICENSE file or at 5 | -- https://developers.google.com/open-source/licenses/bsd 6 | 7 | {-# OPTIONS_GHC -Wno-orphans #-} 8 | 9 | module RawNameSpec (spec) where 10 | 11 | import Control.Monad 12 | import Data.Char 13 | import Test.Hspec 14 | import Test.QuickCheck 15 | import RawName qualified as R 16 | 17 | newtype RawNameMap = RMap (R.RawNameMap ()) 18 | deriving (Show) 19 | 20 | instance Arbitrary RawNameMap where 21 | 
arbitrary = do 22 | s <- getSize 23 | RMap . R.fromList <$> (replicateM s $ (,()) <$> arbitrary) 24 | 25 | instance Arbitrary R.NameHint where 26 | arbitrary = do 27 | arbitrary >>= \case 28 | True -> R.getNameHint . fromStringNameHint <$> arbitrary 29 | False -> return R.noHint -- TODO: Generate more interesting non-string names 30 | 31 | instance Arbitrary R.RawName where 32 | arbitrary = R.rawNameFromHint <$> arbitrary 33 | 34 | newtype StringNameHint = StringNameHint { fromStringNameHint :: String } 35 | 36 | instance Show StringNameHint where 37 | show (StringNameHint s) = s 38 | 39 | instance Arbitrary StringNameHint where 40 | arbitrary = StringNameHint <$> do 41 | s <- chooseInt (1, 7) 42 | replicateM s $ arbitrary `suchThat` isNiceAscii 43 | 44 | isNiceAscii :: Char -> Bool 45 | isNiceAscii h = isAsciiLower h || isAsciiUpper h || isDigit h 46 | 47 | spec :: Spec 48 | spec = do 49 | describe "RawName" do 50 | it "generates a fresh name" do 51 | property \hint (RMap m) -> do 52 | let name = R.freshRawName hint m 53 | not $ name `R.member` m 54 | 55 | it "repeatedly generates fresh names from the same hint" do 56 | property \hint (RMap initM) -> do 57 | let n = 512 58 | let step = \(m, ok) () -> 59 | let name = R.freshRawName hint m in 60 | (R.insert name () m, ok && not (name `R.member` m)) 61 | snd $ foldl step (initM, True) (replicate n ()) 62 | 63 | it "string names are in a bijection with short strings" do 64 | property \(StringNameHint s) -> do 65 | let s' = show (R.rawNameFromHint (R.getNameHint s)) 66 | counterexample s' $ s == s' 67 | 68 | it "string names with non-zero counters print correctly" do 69 | property \(StringNameHint s) -> do 70 | let hint = R.getNameHint s 71 | let n = R.rawNameFromHint hint 72 | let scope = R.singleton n () 73 | show (R.freshRawName hint scope) == s ++ ".1" 74 | -------------------------------------------------------------------------------- /tests/unit/Spec.hs: 
-------------------------------------------------------------------------------- 1 | {-# OPTIONS_GHC -F -pgmF hspec-discover #-} 2 | --------------------------------------------------------------------------------