├── CHANGELOG.md ├── app ├── exampleCount.txt ├── exampleSort.txt ├── Sort.hs └── Count.hs ├── Setup.hs ├── docs ├── meta.json ├── synopsis.png ├── src │ ├── highlight.js │ └── style.css ├── index.html ├── quick-jump.css ├── doc-index.html ├── raskell.txt ├── Lib.html ├── doc-index.json ├── linuwial.css ├── RaskellLib.html ├── Core.html ├── quick-jump.min.js └── RaskellCore.html ├── .gitignore ├── stack.yaml ├── bench └── Bench.hs ├── README.md ├── test ├── testCore.hs └── testLib.hs ├── .vscode └── tasks.json ├── LICENSE ├── package.yaml ├── raskell.cabal └── src ├── RaskellLib.hs └── RaskellCore.hs /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/exampleCount.txt: -------------------------------------------------------------------------------- 1 | 1 3 2 | 101 117 3 | 42 64 4 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | main :: IO () 3 | main = defaultMain 4 | -------------------------------------------------------------------------------- /docs/meta.json: -------------------------------------------------------------------------------- 1 | {"haddock_version":"2.27.0","quickjump_version":1} -------------------------------------------------------------------------------- /app/exampleSort.txt: -------------------------------------------------------------------------------- 1 | 1 0 -1 2 | 10 9 8 7 6 5 4 3 2 1 3 | 111 45 -13 33 91 -117 28 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .stack-work/ 2 | bench/*.html 3 | *~ 4 | stack.yaml.lock 5 | dist-newstyle/ 6 | -------------------------------------------------------------------------------- 
/docs/synopsis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charlesfrye/raskell/HEAD/docs/synopsis.png -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | resolver: 2 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/21/21.yaml 3 | 4 | packages: 5 | - . 6 | -------------------------------------------------------------------------------- /docs/src/highlight.js: -------------------------------------------------------------------------------- 1 | 2 | var highlight = function (on) { 3 | return function () { 4 | var links = document.getElementsByTagName('a'); 5 | for (var i = 0; i < links.length; i++) { 6 | var that = links[i]; 7 | 8 | if (this.href != that.href) { 9 | continue; 10 | } 11 | 12 | if (on) { 13 | that.classList.add("hover-highlight"); 14 | } else { 15 | that.classList.remove("hover-highlight"); 16 | } 17 | } 18 | } 19 | }; 20 | 21 | window.onload = function () { 22 | var links = document.getElementsByTagName('a'); 23 | for (var i = 0; i < links.length; i++) { 24 | links[i].onmouseover = highlight(true); 25 | links[i].onmouseout = highlight(false); 26 | } 27 | }; 28 | -------------------------------------------------------------------------------- /bench/Bench.hs: -------------------------------------------------------------------------------- 1 | import Criterion.Main 2 | import RaskellCore 3 | import RaskellLib (gt, sample) 4 | 5 | sort :: Sequence -> Sequence 6 | sort xs = sample (maximum xs) raspSort (xs ++ [minBound]) (fromIntegral (length xs)) 7 | where 8 | raspSort :: Sequence -> Sequence 9 | raspSort s = minKQV s s gt s 10 | 11 | -- | Benchmarks 'sort' on reversed inputs of exactly 2, 4, 8, 16 and 32 tokens. 12 | -- 13 | -- Uses 'nf' rather than 'whnf': 'sample' returns its list without forcing 14 | -- the token appended on the final step, so 'whnf' would skip the last (and 15 | -- longest) evaluation of the RASP-L program. 16 | main :: IO () 17 | main = 18 | defaultMain 19 | [ bgroup 20 | "sort" 21 | [ bench "2" $ nf sort (reverse [1 .. 2]), 22 | bench "4" $ nf sort (reverse [1 ..
4]), 23 | bench "8" $ nf sort (reverse [1 .. 8]), 24 | bench "16" $ nf sort (reverse [1 .. 16]), 25 | bench "32" $ nf sort (reverse [1 .. 32]) 26 | ] 27 | ] 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `raskell`: RASP-L in Haskell 2 | 3 | Implements the RASP-L language from 4 | [What Algorithms Can Transformers Learn](https://arxiv.org/abs/2310.16028) 5 | by Zhou et al. in Haskell. 6 | 7 | Includes the entire RASP-L Core from Listing 2 (`src/RaskellCore.hs`), 8 | some of the RASP-L Standard Library from Listing 3 (`src/RaskellLib.hs`), 9 | and some sample programs (in `app/`). 10 | 11 | RASP-L is a domain-specific language that models the behavior of Transformers 12 | on algorithmic tasks. 13 | 14 | Here it is running just about the worst sorting algorithm imaginable: 15 | 16 | ```haskell 17 | -- | Θ(n^2) time sorting algorithm 18 | sort :: [Token] -> [Token] 19 | sort xs = sample endOfSequence raspSort (xs ++ [startOfSequence]) numTokens 20 | where 21 | startOfSequence = minBound 22 | endOfSequence = maximum xs 23 | raspSort :: Sequence -> Sequence 24 | -- Sort by looking for the smallest element greater than the current token 25 | raspSort s = minKQV s s gt s 26 | numTokens = fromIntegral (length xs) 27 | ``` 28 | -------------------------------------------------------------------------------- /test/testCore.hs: -------------------------------------------------------------------------------- 1 | import Control.Monad (when) 2 | import RaskellCore 3 | import Test.QuickCheck (Property, Result (..), quickCheckResult, (===), (==>)) 4 | 5 | any' :: Predicate 6 | any' _ _ = True 7 | 8 | prop_maxKQV_is_maximum :: Sequence -> Property 9 | prop_maxKQV_is_maximum xs = (length xs > 1) ==> scanl1 max xs === maxKQV xs xs any' xs 10 | 11 | prop_minKQV_is_minimum :: Sequence -> Property 12 | prop_minKQV_is_minimum xs = (length xs > 1) ==> scanl1 min
xs === minKQV xs xs any' xs 13 | 14 | prop_selWidth_is_num_true :: Selector -> Property 15 | prop_selWidth_is_num_true s = (length s > 1) ==> map numTrue s === selWidth s 16 | where 17 | numTrue = fromIntegral . length . filter id 18 | 19 | main :: IO () 20 | main = do 21 | results <- 22 | sequence 23 | [ quickCheckResult prop_maxKQV_is_maximum, 24 | quickCheckResult prop_minKQV_is_minimum, 25 | quickCheckResult prop_selWidth_is_num_true 26 | ] 27 | let failed = not (all isSuccess results) 28 | when failed $ error "Some tests failed" 29 | 30 | isSuccess :: Result -> Bool 31 | isSuccess Success {} = True 32 | isSuccess _ = False 33 | -------------------------------------------------------------------------------- /app/Sort.hs: -------------------------------------------------------------------------------- 1 | module Sort (main) where 2 | 3 | import RaskellCore 4 | import RaskellLib (gt, sample) 5 | 6 | -- | Θ(n^2) time, Θ(n) space sorting algorithm 7 | sort :: Sequence -> Sequence 8 | sort xs = sample endOfSequence raspSort (prep xs) seqLength 9 | where 10 | endOfSequence = maxBound 11 | seqLength = fromIntegral (length xs) 12 | raspSort :: Sequence -> Sequence 13 | -- Sort by looking for the smallest token greater than the current token 14 | raspSort s = minKQV s s gt s 15 | prep :: Sequence -> Sequence 16 | -- Prepare a sequence for sorting by padding with max/min values 17 | prep s = [maxBound] ++ s ++ [minBound] 18 | 19 | -- | Sorts a newline-separated collection of space-separated int8 sequences 20 | -- 21 | -- Try it with 22 | -- cat exampleSort.txt | raskell-sort 23 | main :: IO () 24 | main = do 25 | contents <- getContents 26 | let sorts = map sort lists 27 | lists = map parse $ lines contents 28 | mapM_ printClean sorts 29 | 30 | parse :: String -> [Token] 31 | parse = map read . words 32 | 33 | printClean :: Sequence -> IO () 34 | printClean = putStrLn . unwords . map show . dropPrefix 35 | where 36 | dropPrefix = drop 1 .
dropWhile (> minBound) 37 | -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | // See https://go.microsoft.com/fwlink/?LinkId=733558 3 | // for the documentation about the tasks.json format 4 | "version": "2.0.0", 5 | "tasks": [ 6 | { 7 | "label": "test", 8 | "type": "shell", 9 | "command": "stack", 10 | "args": [ 11 | "build", 12 | "--test", 13 | "--fast", 14 | "--file-watch", 15 | "--haddock" 16 | ], 17 | "runOptions": { 18 | "runOn": "folderOpen" 19 | }, 20 | "group": { 21 | "kind": "test", 22 | "isDefault": true 23 | } 24 | }, 25 | { 26 | "label": "bench", 27 | "type": "shell", 28 | "command": "stack build --bench", 29 | "group": { 30 | "kind": "build", 31 | "isDefault": true 32 | } 33 | }, 34 | { 35 | "label": "docs", 36 | "type": "shell", 37 | "command": "stack build --haddock --haddock-arguments='-o docs/'", 38 | "group": { 39 | "kind": "build", 40 | "isDefault": false // only one default build task per group; "bench" keeps it 41 | } 42 | } 43 | ] 44 | } 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Charles Frye (c) 2023 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above 12 | copyright notice, this list of conditions and the following 13 | disclaimer in the documentation and/or other materials provided 14 | with the distribution.
15 | 16 | * Neither the name of Charles Frye nor the names of other 17 | contributors may be used to endorse or promote products derived 18 | from this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | -------------------------------------------------------------------------------- /docs/src/style.css: -------------------------------------------------------------------------------- 1 | body { 2 | background-color: #fdf6e3; 3 | } 4 | 5 | .hs-identifier { 6 | color: #073642; 7 | } 8 | 9 | .hs-identifier.hs-var { 10 | } 11 | 12 | .hs-identifier.hs-type { 13 | color: #5f5faf; 14 | } 15 | 16 | .hs-keyword { 17 | color: #af005f; 18 | } 19 | 20 | .hs-string, .hs-char { 21 | color: #cb4b16; 22 | } 23 | 24 | .hs-number { 25 | color: #268bd2; 26 | } 27 | 28 | .hs-operator { 29 | color: #d33682; 30 | } 31 | 32 | .hs-glyph, .hs-special { 33 | color: #dc322f; 34 | } 35 | 36 | .hs-comment { 37 | color: #8a8a8a; 38 | } 39 | 40 | .hs-pragma { 41 | color: #2aa198; 42 | } 43 | 44 | .hs-cpp { 45 | color: #859900; 46 | } 47 | 48 | a:link, a:visited { 49 | text-decoration: none; 50 | border-bottom: 1px solid #eee8d5; 51 | } 52 | 53 | a:hover, a.hover-highlight { 54 | background-color: #eee8d5; 55 | } 56 | 57 | span.annot{ 58 | position:relative; 59 | color:#000; 60 | text-decoration:none 61 | } 62 | 63 | span.annot:hover{z-index:25; background-color:#ff0} 64 | 65 | span.annot span.annottext{ 66 | display: none; 67 | border-radius: 5px 5px; 68 | 69 | -moz-border-radius: 5px; 70 | -webkit-border-radius: 5px; 71 | 72 | box-shadow: 5px 5px 5px rgba(0, 0, 0, 0.1); 73 | -webkit-box-shadow: 5px 5px rgba(0, 0, 0, 0.1); 74 | -moz-box-shadow: 5px 5px rgba(0, 0, 0, 0.1); 75 | 76 | position: absolute; 77 | left: 1em; top: 2em; 78 | z-index: 99; 79 | margin-left: 5px; /* explicit unit: a bare "5" is not a valid CSS length and is ignored */ 80 | background: #FFFFAA; 81 | border: 2px solid #FFAD33; 82 | padding: 0.8em 1em; 83 | } 84 | 85 | span.annot:hover span.annottext{ 86 | display:block; 87 | } 88 | 89 | /* This bridges the gap so you can mouse into the tooltip without it disappearing */ 90 | span.annot span.annottext:before{ 91 | content: ""; 92 | position: absolute; 93 | left: -1em; top: -1em; 94 | background: #FFFFFF00; 95 | z-index:-1; 96 | padding: 2em 2em;
97 | } 98 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | raskell-0.0.0.2: RASP-L in Haskell
raskell-0.0.0.2: RASP-L in Haskell

raskell-0.0.0.2: RASP-L in Haskell

Implements the RASP-L language from "What Algorithms Can Transformers Learn" by Zhou et al. For more information, see the README at https://github.com/charlesfrye/raskell#readme.

Modules

raskell-0.0.0.2

-------------------------------------------------------------------------------- /package.yaml: -------------------------------------------------------------------------------- 1 | name: raskell 2 | version: 0.0.0.2 3 | github: "charlesfrye/raskell" 4 | license: BSD-3-Clause 5 | author: "Charles Frye" 6 | maintainer: "cfrye59@gmail.com" 7 | copyright: "Copyright (c) 2023 Charles Frye" 8 | 9 | extra-source-files: 10 | - README.md 11 | - CHANGELOG.md 12 | 13 | synopsis: RASP-L in Haskell 14 | category: nlp 15 | 16 | description: Implements the RASP-L language from "What Algorithms Can Transformers Learn" by Zhou et al. For more information, see the README at <https://github.com/charlesfrye/raskell#readme>. 17 | 18 | dependencies: 19 | - base >= 4.7 && < 5 20 | 21 | ghc-options: 22 | - -Wall 23 | - -Wcompat 24 | - -Widentities 25 | - -Wincomplete-record-updates 26 | - -Wincomplete-uni-patterns 27 | - -Wmissing-export-lists 28 | - -Wmissing-home-modules 29 | - -Wpartial-fields 30 | - -Wredundant-constraints 31 | 32 | library: 33 | source-dirs: src 34 | 35 | executables: 36 | 37 | raskell-sort: 38 | main: Sort 39 | source-dirs: app 40 | ghc-options: 41 | - -threaded 42 | - -rtsopts 43 | - -with-rtsopts=-N 44 | dependencies: 45 | - raskell 46 | 47 | raskell-count: 48 | main: Count 49 | source-dirs: app 50 | ghc-options: 51 | - -threaded 52 | - -rtsopts 53 | - -with-rtsopts=-N 54 | dependencies: 55 | - raskell 56 | 57 | tests: 58 | raskell-lib-test: 59 | main: testLib.hs 60 | source-dirs: test 61 | ghc-options: 62 | - -threaded 63 | - -rtsopts 64 | - -with-rtsopts=-N 65 | dependencies: 66 | - raskell 67 | - QuickCheck 68 | raskell-core-test: 69 | main: testCore.hs 70 | source-dirs: test 71 | ghc-options: 72 | - -threaded 73 | - -rtsopts 74 | - -with-rtsopts=-N 75 | dependencies: 76 | - raskell 77 | - QuickCheck 78 | 79 | benchmarks: 80 | raskell-bench: 81 | main: Bench.hs 82 | source-dirs: bench 83 | ghc-options: 84 | - -threaded 85 | - -rtsopts 86 | - -with-rtsopts=-N 87 | dependencies: 88 | - raskell 89 | -
criterion 90 | -------------------------------------------------------------------------------- /app/Count.hs: -------------------------------------------------------------------------------- 1 | module Count (main) where 2 | 3 | import RaskellCore 4 | import RaskellLib 5 | 6 | sos :: Token 7 | sos = -1 8 | 9 | eos :: Token 10 | eos = -2 11 | 12 | equals :: Token -> Token -> Token 13 | equals x y 14 | | x == y = 1 15 | | otherwise = 0 16 | 17 | -- TODO: rewrite with startCounting before withEos? 18 | raspCount :: Sequence -> Sequence 19 | raspCount inputs = finalCounts 20 | where 21 | finalCounts = 22 | -- to get the final next token prediction 23 | map (> 0) startCounting -- figure out where we start counting 24 | ? (countFroms, withEOS) -- and merge the starting number into our output there 25 | withEOS = 26 | -- add the EOS tokens in the right spots 27 | map (> 0) transitions ? (inputs `filledWith` eos, succs) 28 | succs = 29 | -- increment all of the tokens 30 | -- that's our basic prediction, the rest of this is just patching edge cases 31 | map (+ 1) inputs 32 | transitions = 33 | -- determine where we switch from counting to reading task tokens or vice versa 34 | zipWith equals inputs countTos -- it's actually just whenever we hit a number we're counting to 35 | countFroms = 36 | -- determine what number we're in the middle of counting from for each token 37 | maxKQV idxs (map (+ 1) lastSOS) (==) inputs 38 | countTos = 39 | -- determine what number we're in the middle of counting to for each token 40 | maxKQV idxs (map (+ 2) lastSOS) (==) inputs 41 | startCounting = zipWith equals idxs (map (+ 2) lastSOS) 42 | lastSOS = 43 | -- sequence of indices of most recent SoS token 44 | maxKQV inputs (inputs `filledWith` sos) (==) idxs 45 | idxs = indicesOf inputs 46 | 47 | -- showSequence :: Sequence -> String 48 | -- showSequence tokens = "[" ++ intercalate ", " (map showToken tokens) ++ "]" 49 | 50 | showToken :: Token -> String 51 | showToken token 52 | | token ==
sos = "SOS" 53 | | token == eos = "EOS" 54 | | otherwise = show token 55 | 56 | -- | Counts between numbers provided in a newline-separated collection of space-separated non-negative int8 pairs. 57 | -- 58 | -- Try it with 59 | -- cat exampleCount.txt | raskell-count 60 | main :: IO () 61 | main = do 62 | contents <- getContents 63 | let counts = map count lists 64 | lists = map parse $ lines contents 65 | mapM_ printClean counts 66 | 67 | count :: Sequence -> Sequence 68 | count xs = sample eos raspCount (prep xs) seqLength 69 | where 70 | seqLength = 24 71 | prep :: Sequence -> Sequence 72 | -- Prepare a sequence for counting by prepending an SOS token 73 | prep s = sos : s 74 | 75 | parse :: String -> [Token] 76 | parse = map read . words 77 | 78 | printClean :: Sequence -> IO () 79 | printClean = putStrLn . unwords . map showToken . dropSuffix . dropPrefix 80 | where 81 | dropPrefix = drop 3 -- skip the SOS token and the two prompt numbers 82 | dropSuffix = takeWhile (/= eos) -- cut at EOS; keeps every token if the 24-step limit hit first 83 | -------------------------------------------------------------------------------- /raskell.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 2.2 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.35.2. 4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: raskell 8 | version: 0.0.0.2 9 | synopsis: RASP-L in Haskell 10 | description: Implements the RASP-L language from "What Algorithms Can Transformers Learn" by Zhou et al. For more information, see the README at <https://github.com/charlesfrye/raskell#readme>.
11 | category: nlp 12 | homepage: https://github.com/charlesfrye/raskell#readme 13 | bug-reports: https://github.com/charlesfrye/raskell/issues 14 | author: Charles Frye 15 | maintainer: cfrye59@gmail.com 16 | copyright: Copyright (c) 2023 Charles Frye 17 | license: BSD-3-Clause 18 | license-file: LICENSE 19 | build-type: Simple 20 | extra-source-files: 21 | README.md 22 | CHANGELOG.md 23 | 24 | source-repository head 25 | type: git 26 | location: https://github.com/charlesfrye/raskell 27 | 28 | library 29 | exposed-modules: 30 | RaskellCore 31 | RaskellLib 32 | other-modules: 33 | Paths_raskell 34 | autogen-modules: 35 | Paths_raskell 36 | hs-source-dirs: 37 | src 38 | ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints 39 | build-depends: 40 | base >=4.7 && <5 41 | default-language: Haskell2010 42 | 43 | executable raskell-count 44 | main-is: Count.hs 45 | other-modules: 46 | Sort 47 | Paths_raskell 48 | autogen-modules: 49 | Paths_raskell 50 | hs-source-dirs: 51 | app 52 | ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N -main-is Count 53 | build-depends: 54 | base >=4.7 && <5 55 | , raskell 56 | default-language: Haskell2010 57 | 58 | executable raskell-sort 59 | main-is: Sort.hs 60 | other-modules: 61 | Count 62 | Paths_raskell 63 | autogen-modules: 64 | Paths_raskell 65 | hs-source-dirs: 66 | app 67 | ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N -main-is Sort 68 | build-depends: 69 | base >=4.7 && <5 70 | , raskell 71 | default-language: Haskell2010 72 | 73 | test-suite raskell-core-test 74 | 
type: exitcode-stdio-1.0 75 | main-is: testCore.hs 76 | other-modules: 77 | Paths_raskell 78 | autogen-modules: 79 | Paths_raskell 80 | hs-source-dirs: 81 | test 82 | ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N 83 | build-depends: 84 | QuickCheck 85 | , base >=4.7 && <5 86 | , raskell 87 | default-language: Haskell2010 88 | 89 | test-suite raskell-lib-test 90 | type: exitcode-stdio-1.0 91 | main-is: testLib.hs 92 | other-modules: 93 | Paths_raskell 94 | autogen-modules: 95 | Paths_raskell 96 | hs-source-dirs: 97 | test 98 | ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N 99 | build-depends: 100 | QuickCheck 101 | , base >=4.7 && <5 102 | , raskell 103 | default-language: Haskell2010 104 | 105 | benchmark raskell-bench 106 | type: exitcode-stdio-1.0 107 | main-is: Bench.hs 108 | other-modules: 109 | Paths_raskell 110 | autogen-modules: 111 | Paths_raskell 112 | hs-source-dirs: 113 | bench 114 | ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N 115 | build-depends: 116 | base >=4.7 && <5 117 | , criterion 118 | , raskell 119 | default-language: Haskell2010 120 | -------------------------------------------------------------------------------- /docs/quick-jump.css: -------------------------------------------------------------------------------- 1 | /* @group Fundamentals */ 2 | 3 | .hidden { 4 | display: none; 5 | } 6 | 7 | /* @end */ 8 | 9 | /* @group Search box layout */ 10 | 11 | #search { 12 | position: fixed; 13 | top: 3.2em; 14 | bottom: 0; 15 | left: calc(50% 
- 22em); 16 | width: 44em; 17 | z-index: 1000; 18 | overflow-y: auto; 19 | } 20 | 21 | @media only screen and (max-width: 999px) { 22 | #search { 23 | top: 5.7em; 24 | } 25 | } 26 | 27 | #search-form, #search-results { 28 | box-shadow: 2px 2px 6px rgb(199, 204, 208); 29 | pointer-events: all; 30 | } 31 | 32 | #search-form input { 33 | font-size: 1.25em; line-height: 2.3em; height: 2.4em; 34 | display: block; 35 | box-sizing: border-box; 36 | width: 100%; 37 | margin: 0; 38 | padding: 0 0.75em; 39 | border: 0.05em solid rgb(151, 179, 202); 40 | } 41 | 42 | #search input:focus { 43 | outline: none; 44 | } 45 | 46 | #search p.error { 47 | color: rgb(107, 24, 24); 48 | font-weight: bold; 49 | } 50 | 51 | #search-results { 52 | box-sizing: border-box; 53 | border: 0.05em solid #b2d5fb; 54 | background: #e8f3ff; 55 | max-height: 80%; 56 | overflow: scroll; 57 | } 58 | 59 | #search-form input + #search-results { 60 | border-top: none; 61 | top: 3em; 62 | max-height: calc(100% - 3em); 63 | } 64 | 65 | /* @end */ 66 | 67 | /* @group search results */ 68 | 69 | #search-results > ul { 70 | margin: 0; 71 | list-style: none; 72 | } 73 | 74 | #search-results > ul > li, 75 | #search-results > p, 76 | #search-results > table { 77 | padding: 0.5em 1em; 78 | margin: 0; 79 | } 80 | 81 | #search-results > ul > li { 82 | border-bottom: 1px solid #b2d5fb; 83 | } 84 | 85 | #search-results > ul > li > ul { 86 | list-style: none; 87 | } 88 | 89 | .search-module h4 { 90 | margin: 0; 91 | } 92 | 93 | .search-module > ul { 94 | margin: 0.5em 0 0.5em 2em; 95 | } 96 | 97 | .search-module > ul > li > a[href] { 98 | display: block; 99 | color: inherit; 100 | padding: 0.25em 0.5em; 101 | } 102 | 103 | .search-module > ul > li > a[href].active-link { 104 | background: #faf9dc; 105 | } 106 | 107 | .search-module a[href]:hover { 108 | text-decoration: none; 109 | } 110 | 111 | .search-result a a { 112 | pointer-events: none; 113 | } 114 | 115 | .search-result ul.subs { 116 | display: inline-block; 
117 | margin: 0; padding: 0; 118 | } 119 | 120 | .search-result ul.subs li { 121 | display: none; 122 | } 123 | 124 | .search-result ul.subs::after { 125 | display: inline-block; 126 | content: "..."; 127 | color: rgb(78,98,114); 128 | margin: 0 0.25em; 129 | } 130 | 131 | .more-results { 132 | color: rgb(99, 141, 173); 133 | position: relative; 134 | } 135 | 136 | .more-results::before { 137 | content: "+"; 138 | display: inline-block; 139 | color: #b2d5fb; 140 | font-weight: bold; 141 | font-size: 1.25em; line-height: inherit; 142 | position: absolute; 143 | left: -1em; 144 | } 145 | 146 | /* @end */ 147 | 148 | /* @group Keyboard shortcuts table */ 149 | 150 | .keyboard-shortcuts { 151 | line-height: 1.6em; 152 | } 153 | 154 | .keyboard-shortcuts th { 155 | color: rgb(78,98,114); 156 | } 157 | 158 | .keyboard-shortcuts td:first-child, 159 | .keyboard-shortcuts th:first-child { 160 | text-align: right; 161 | padding-right: 0.6em; 162 | } 163 | 164 | .key { 165 | display: inline-block; 166 | font-size: 0.9em; 167 | min-width: 0.8em; line-height: 1.2em; 168 | text-align: center; 169 | background: #b2d5fb; 170 | border: 1px solid #74a3d6; 171 | padding: 0 0.2em; 172 | margin: 0 0.1em; 173 | } 174 | 175 | /* @end */ 176 | 177 | /* @group Dropdown menus */ 178 | 179 | /* Based on #search styling above. */ 180 | 181 | .dropdown-menu { 182 | position: fixed; 183 | /* Not robust to window size changes. */ 184 | top: 3.2em; 185 | right: 0; 186 | /* To display on top of synopsis menu on right side. 
*/ 187 | z-index: 1000; 188 | border: 0.05em solid #b2d5fb; 189 | background: #e8f3ff; 190 | } 191 | 192 | @media only screen and (max-width: 999px) { 193 | .dropdown-menu { 194 | top: 5.7em; 195 | } 196 | } 197 | 198 | .dropdown-menu * { 199 | margin: 0.1em; 200 | } 201 | 202 | .dropdown-menu button { 203 | border: 1px #5E5184 solid; 204 | border-radius: 3px; 205 | background: #5E5184; 206 | padding: 3px; 207 | color: #f4f4f4; 208 | min-width: 6em; 209 | } 210 | 211 | .dropdown-menu button:hover { 212 | color: #5E5184; 213 | background: #f4f4f4; 214 | } 215 | 216 | .dropdown-menu button:active { 217 | color: #f4f4f4; 218 | background: #5E5184; 219 | } 220 | 221 | /* @end */ 222 | -------------------------------------------------------------------------------- /src/RaskellLib.hs: -------------------------------------------------------------------------------- 1 | -- | This module provides convenience functions built from the core of the RASP-L language. 2 | -- 3 | -- It is based on Listing 3 of 4 | -- "What Algorithms Can Transformers Learn", https://arxiv.org/abs/2310.16028, 5 | -- by Zhou et al. 6 | module RaskellLib 7 | ( -- * Logical Operations 8 | (?), 9 | shiftRight, 10 | toBool, 11 | mask, 12 | 13 | -- * Running Aggregations 14 | cumSum, 15 | maximum', 16 | minimum', 17 | argmax, 18 | argmin, 19 | 20 | -- * Aggregations with `Queries` 21 | numPrev, 22 | hasSeen, 23 | 24 | -- * Indexing with `Queries` 25 | firsts, 26 | lasts, 27 | indexSelect, 28 | 29 | -- * Token Comparisons 30 | leq, 31 | geq, 32 | lt, 33 | gt, 34 | 35 | -- * Sampling 36 | sample, 37 | 38 | -- * Compatibility 39 | where', 40 | sampleAutoregressive, 41 | ) 42 | where 43 | 44 | import Data.Int (Int8) 45 | import Data.Word (Word8) 46 | import RaskellCore 47 | 48 | -- | Use a boolean sequence to select between two sequences. 49 | -- Also known in Python RASP-L as "where", see `where'`. 50 | (?) :: [Bool] -> (Sequence, Sequence) -> Sequence 51 | bs ? 
(xs, ys) = seqMap (\xm ym -> if xm == 0 then ym else xm) xms yms 52 | where 53 | xms = seqMap (\bt x -> if bt == 1 then x else 0) bts xs 54 | yms = seqMap (\bt y -> if bt == 0 then y else 0) bts ys 55 | bts = fromBoolSeq bs 56 | 57 | -- | Use a boolean sequence to select between two sequences. 58 | -- Provided for compatibility with Listing 3, but with 59 | -- an apostrophe to avoid a name clash with the "where" keyword. 60 | where' :: [Bool] -> Sequence -> Sequence -> Sequence 61 | where' bs xs ys = bs ? (xs, ys) 62 | 63 | -- | Shift a sequence to the right by a given number of elements, 64 | -- filling the vacated positions with the provided `Token`. 65 | shiftRight :: 66 | -- | Filler `Token` 67 | Token -> 68 | -- | Number of positions to shift 69 | Int8 -> 70 | -- | Input `Sequence` 71 | Sequence -> 72 | Sequence 73 | shiftRight filler n xs = kqv filler Mean shiftedIdxs idxs (==) xs 74 | where 75 | shiftedIdxs = map (+ n) idxs 76 | idxs = indices xs 77 | 78 | -- | Maps tokens onto bools using Python's "truthiness" rules. 79 | toBool :: Token -> Bool 80 | toBool x 81 | | x == 0 = False 82 | | otherwise = True 83 | 84 | -- | Converts a list of bools to a sequence of tokens. 85 | fromBoolSeq :: [Bool] -> Sequence 86 | fromBoolSeq = map fromBool 87 | 88 | -- | Computes the cumulative sum of a boolean sequence. 89 | cumSum :: [Bool] -> Sequence 90 | cumSum bs = selWidth (selectCausal bTokens bTokens first) 91 | where 92 | bTokens = fromBoolSeq bs 93 | first x _ = toBool x 94 | 95 | -- | Masks a `Sequence` with a boolean sequence, using the provided `Token` as the mask. 96 | mask :: Token -> [Bool] -> Sequence -> Sequence 97 | mask maskT bs xs = bs ? (xs, xs `filledWith` maskT) 98 | 99 | -- | Computes the running maximum of a `Sequence`. 100 | maximum' :: Sequence -> Sequence 101 | maximum' xs = maxKQV xs xs always xs 102 | where 103 | always _ _ = True 104 | 105 | -- | Computes the running minimum of a `Sequence`. 
106 | minimum' :: Sequence -> Sequence 107 | minimum' xs = minKQV xs xs always xs 108 | where 109 | always _ _ = True 110 | 111 | -- | Computes the indices of the running maximum values in a `Sequence`. 112 | argmax :: Sequence -> Sequence 113 | argmax xs = maxKQV xs maxs (==) (indicesOf xs) 114 | where 115 | maxs = maximum' xs 116 | 117 | -- | Computes the indices of the running minimum values in a `Sequence`. 118 | argmin :: Sequence -> Sequence 119 | argmin xs = maxKQV xs mins (==) (indicesOf xs) 120 | where 121 | mins = minimum' xs 122 | 123 | -- | Computes the number of previous tokens in a `Sequence` that are equal to each `Token` from `Queries`. 124 | numPrev :: Sequence -> Queries -> Sequence 125 | numPrev xs queries = selWidth (selectCausal xs queries (==)) 126 | 127 | -- | Returns 1s where the `Token` from the `Queries` has been seen before in the `Sequence`. 128 | hasSeen :: Sequence -> Queries -> Sequence 129 | hasSeen xs queries = kqv 0 Max xs queries (==) (queries `filledWith` 1) 130 | 131 | -- | Finds the first occurrence of each query token in a `Sequence`. 132 | firsts :: Token -> Sequence -> Queries -> Sequence 133 | firsts filler xs queries = kqv filler Min xs queries (==) (indicesOf xs) 134 | 135 | -- | Finds the last occurrence of each query token in a `Sequence`. 136 | lasts :: Token -> Sequence -> Queries -> Sequence 137 | lasts filler xs queries = kqv filler Max xs queries (==) (indicesOf xs) 138 | 139 | -- | Selects the tokens from a `Sequence` at the indices provided by another sequence. 
--
-- Output position @i@ receives the token of @xs@ at index @idxs !! i@, as
-- long as that index is not later than @i@ (the lookup is causal);
-- otherwise the filler token is used.
indexSelect :: Token -> Sequence -> Sequence -> Sequence
indexSelect filler xs idxs = kqv filler Max (indicesOf xs) idxs (==) xs

-- | Token comparison usable as a `Predicate`: @leq k q == (k <= q)@.
leq :: Token -> Token -> Bool
leq = (<=)

-- | Token comparison usable as a `Predicate`: @geq k q == (k >= q)@.
geq :: Token -> Token -> Bool
geq = (>=)

-- | Token comparison usable as a `Predicate`: @lt k q == (k < q)@.
lt :: Token -> Token -> Bool
lt = (<)

-- | Token comparison usable as a `Predicate`: @gt k q == (k > q)@.
gt :: Token -> Token -> Bool
gt = (>)

-- | Greedily and autoregressively sample the output of a RASP-L program on a sequence.
--
-- Decoding stops after the requested number of steps, as soon as the last
-- token is the end-of-sequence token, or when there is no token to work with
-- (an empty prompt, or a program that returns an empty sequence).
sample ::
  -- | End of sequence token
  Token ->
  -- | RASP-L program to extend the sequence
  (Sequence -> Sequence) ->
  -- | Initial/prompt sequence
  Sequence ->
  -- | Number of steps to decode
  Word8 ->
  -- | Output (including prompt)
  Sequence
sample _ _ xs 0 = xs
sample endOfSequence prog xs n =
  -- The second component is only forced when the prompt is non-empty and has
  -- not ended yet, so `prog` runs exactly when it did before.
  case (lastToken xs, lastToken (prog xs)) of
    (Just current, _) | current == endOfSequence -> xs
    (Just _, Just next) -> sample endOfSequence prog (xs ++ [next]) (n - 1)
    -- Empty prompt, or the program produced nothing to append: stop instead
    -- of crashing on a partial `last`.
    _ -> xs
  where
    lastToken :: Sequence -> Maybe Token
    lastToken [] = Nothing
    lastToken ys = Just (last ys)

-- | Greedily and autoregressively sample the output of a RASP-L program on a sequence.
--
-- Provided for compatibility with Listing 3.
sampleAutoregressive :: Token -> (Sequence -> Sequence) -> Sequence -> Word8 -> Sequence
sampleAutoregressive = sample

--------------------------------------------------------------------------------
/test/testLib.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE FlexibleInstances #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}

import Control.Monad (when)
import Data.Int (Int8)
import Data.List (inits)
import RaskellCore
import RaskellLib
import Test.QuickCheck (Arbitrary (arbitrary, shrink), Property, Result (..), choose, quickCheckResult, vectorOf, (===), (==>))

prop_where_allTrue_is_idLeft :: Sequence -> Property
prop_where_allTrue_is_idLeft xs = xs === allTrue ? (xs, xs `filledWith` undefined)
  where
    allTrue = replicate (length xs) True

prop_where_allFalse_is_idRight :: Sequence -> Property
prop_where_allFalse_is_idRight xs = xs === allFalse ? (xs `filledWith` undefined, xs)
  where
    allFalse = replicate (length xs) False

prop_where_alternating_alternates :: Sequence -> Property
prop_where_alternating_alternates xs = take l (cycle [1, -1]) === alternating ? (xs `filledWith` 1, xs `filledWith` (-1))
  where
    alternating = cycle [True, False]
    l = length xs

prop_shiftRight_zero_is_id :: Sequence -> Property
prop_shiftRight_zero_is_id xs = xs === shiftRight 0 0 xs

prop_shiftRight_length_matches_replicate :: Sequence -> Property
prop_shiftRight_length_matches_replicate xs = replicate (fromIntegral l) 1 === shiftRight 1 l xs
  where
    l = fromIntegral . length $ xs

prop_shiftRight_matches_rotateFill :: Token -> Int8 -> Sequence -> Property
prop_shiftRight_matches_rotateFill t n xs = n >= 0 && l > 0 ==> rotateFill xs === shiftRight t n xs
  where
    -- Uses normal list operations to shift the sequence.
    rotateFill :: Sequence -> Sequence
    rotateFill s = take l $ replicate n' t ++ take (l - n') s

    n' = fromIntegral n
    l = length xs

prop_cumSum_matches_scanl :: [Bool] -> Property
prop_cumSum_matches_scanl bs = scanl1 (+) (map fromBool bs) === cumSum bs

prop_mask_matches_zipWith :: Token -> [Bool] -> Sequence -> Property
prop_mask_matches_zipWith t bs xs = zipWith (\b x -> if b then x else t) bs xs === mask t bs xs

prop_maximum'_matches_scanl :: Sequence -> Property
prop_maximum'_matches_scanl xs = scanl1 max xs === maximum' xs

prop_minimum'_matches_scanl :: Sequence -> Property
prop_minimum'_matches_scanl xs = scanl1 min xs === minimum' xs

prop_argmax_matches_scanl :: Sequence -> Property
prop_argmax_matches_scanl xs = map fst (scanl1 argmax' (enumerate xs)) === argmax xs
  where
    argmax' :: (Token, Token) -> (Token, Token) -> (Token, Token)
    argmax' (accIdx, accVal) (idx, val)
      | val >= accVal = (idx, val)
      | otherwise = (accIdx, accVal)

    enumerate = zip [0 ..]

prop_argmin_matches_scanl :: Sequence -> Property
prop_argmin_matches_scanl xs = map fst (scanl1 argmin' (enumerate xs)) === argmin xs
  where
    argmin' :: (Token, Token) -> (Token, Token) -> (Token, Token)
    argmin' (accIdx, accVal) (idx, val)
      | val <= accVal = (idx, val)
      | otherwise = (accIdx, accVal)

    enumerate = zip [0 ..]
-- | A pair of token lists that always have the same length, so properties
-- can compare them position by position.
newtype EqualLengthSequences = EqualLengthSequences ([Token], [Token])
  deriving (Show)

-- Generates both lists with one shared length and shrinks them together.
instance Arbitrary EqualLengthSequences where
  arbitrary = do
    len <- choose (0, 100) -- Choose a length between 0 and 100
    list1 <- vectorOf len arbitrary
    list2 <- vectorOf len arbitrary
    return $ EqualLengthSequences (list1, list2)

  -- Shrink the lists as one list of pairs. Every candidate then keeps both
  -- lists the same length by construction, instead of relying on two
  -- independent shrink streams happening to line up when zipped.
  shrink (EqualLengthSequences (l1, l2)) =
    [EqualLengthSequences (unzip pairs) | pairs <- shrink (zip l1 l2)]

prop_numPrev_matches_zipWith :: EqualLengthSequences -> Property
prop_numPrev_matches_zipWith (EqualLengthSequences (xs, qs)) =
  l > 0 ==> zipWith numPrev' (tail (inits xs)) qs === numPrev xs qs
  where
    numPrev' :: [Token] -> Token -> Int8
    numPrev' (x : xs) q = numPrev' xs q + fromBool (x == q)
    numPrev' [] _ = 0

    l = length qs

prop_hasSeen_matches_zipWith :: EqualLengthSequences -> Property
prop_hasSeen_matches_zipWith (EqualLengthSequences (xs, qs)) =
  l > 0 ==> zipWith hasSeen' (tail (inits xs)) qs === hasSeen xs qs
  where
    hasSeen' :: [Token] -> Token -> Int8
    hasSeen' (x : xs) q = max (hasSeen' xs q) (fromBool (x == q))
    hasSeen' [] _ = 0

    l = length qs

prop_firsts_matches_zipWith :: Token -> EqualLengthSequences -> Property
prop_firsts_matches_zipWith filler (EqualLengthSequences (xs, qs)) =
  l > 0 ==> zipWith firsts' (tail (inits (enumerate xs))) qs === firsts filler xs qs
  where
    firsts' :: [(Int8, Token)] -> Token -> Int8
    firsts' ((idx, x) : xs) q =
      if x == q then idx else firsts' xs q
    firsts' [] _ = filler

    enumerate = zip [0 ..]

    l = length qs

prop_lasts_matches_zipWith :: Token -> EqualLengthSequences -> Property
prop_lasts_matches_zipWith filler (EqualLengthSequences (xs, qs)) =
  l > 0 ==> zipWith lasts' (tail (inits (enumerate xs))) qs === lasts filler xs qs
  where
    lasts' :: [(Int8, Token)] -> Token -> Int8
    lasts' xs q = case filter (\(_, x) -> x == q) xs of
      [] -> filler
      xs' -> fst $ last xs'

    enumerate = zip [0 ..]

    l = length qs

prop_indexSelect_matches_zipWith :: Token -> EqualLengthSequences -> Property
prop_indexSelect_matches_zipWith filler (EqualLengthSequences (xs, idxs)) =
  l > 0 ==> zipWith indexSelect' (tail (inits (enumerate xs))) idxs === indexSelect filler xs idxs
  where
    indexSelect' :: [(Int8, Token)] -> Token -> Token
    indexSelect' xs q = case filter (\(idx, _) -> idx == q) xs of
      [] -> filler
      xs' -> snd $ last xs'
    enumerate = zip [0 ..]

    l = length idxs

-- | Runs every property and fails the test suite if any of them fails.
main :: IO ()
main = do
  results <-
    sequence
      [ quickCheckResult prop_where_allTrue_is_idLeft,
        quickCheckResult prop_where_allFalse_is_idRight,
        quickCheckResult prop_where_alternating_alternates,
        quickCheckResult prop_shiftRight_zero_is_id,
        quickCheckResult prop_shiftRight_length_matches_replicate,
        quickCheckResult prop_shiftRight_matches_rotateFill,
        quickCheckResult prop_cumSum_matches_scanl,
        quickCheckResult prop_mask_matches_zipWith,
        quickCheckResult prop_maximum'_matches_scanl,
        quickCheckResult prop_minimum'_matches_scanl,
        quickCheckResult prop_argmax_matches_scanl,
        quickCheckResult prop_argmin_matches_scanl,
        quickCheckResult prop_numPrev_matches_zipWith,
        quickCheckResult prop_hasSeen_matches_zipWith,
        quickCheckResult prop_firsts_matches_zipWith,
        quickCheckResult prop_lasts_matches_zipWith,
        quickCheckResult prop_indexSelect_matches_zipWith
      ]
  let failed = not (all isSuccess results)
  when failed $ error "Some tests failed"

-- | Whether a QuickCheck run passed.
isSuccess :: Result -> Bool
isSuccess Success {} = True
isSuccess _ = False

--------------------------------------------------------------------------------
/docs/doc-index.html:
--------------------------------------------------------------------------------
1 | raskell-0.0.0.2: RASP-L in Haskell (Index)
raskell-0.0.0.2: RASP-L in Haskell
-------------------------------------------------------------------------------- /docs/raskell.txt: -------------------------------------------------------------------------------- 1 | -- Hoogle documentation, generated by Haddock 2 | -- See Hoogle, http://www.haskell.org/hoogle/ 3 | 4 | 5 | -- | RASP-L in Haskell 6 | -- 7 | -- Implements the RASP-L language from "What Algorithms Can Transformers 8 | -- Learn" by Zhou et al. For more information, see the README at 9 | -- https://github.com/charlesfrye/raskell#readme. 10 | @package raskell 11 | @version 0.0.0.2 12 | 13 | 14 | -- | This module provides the core of the RASP-L language. 15 | -- 16 | -- It is based on Listing 2 of "What Algorithms Can Transformers Learn", 17 | -- https://arxiv.org/abs/2310.16028, by Zhou et al. 18 | module RaskellCore 19 | 20 | -- | A Token in a Sequence is a small integer. RASP-L uses 21 | -- Int8 to ensure all maps of type Token -> Token 22 | -- are learnable. 23 | type Token = Int8 24 | 25 | -- | A Sequence is a list of Tokens. 26 | type Sequence = [Token] 27 | 28 | -- | A collection of keys is a list of Tokens. 29 | type Keys = Sequence 30 | 31 | -- | A collection of queries is a list of Tokens. 32 | type Queries = Sequence 33 | 34 | -- | A collection of values is a list of Tokens. 35 | type Values = Sequence 36 | 37 | -- | We can compare Keys and Queries to determine if they 38 | -- match. 39 | type Predicate = Token -> Token -> Bool 40 | 41 | -- | The equivalents of "attention maps" are collections of Boolean 42 | -- sequences. 43 | type Selector = [BoolSequence] 44 | 45 | -- | Internally, we sometimes need to operate on collections of 46 | -- Bools. 47 | type BoolSequence = [Bool] 48 | 49 | -- | Type alias for "fully-specified" aggregators that are ready to 50 | -- aggregate a sequence of values with a selector. 
51 | type Aggregator = Selector -> Values -> Sequence 52 | 53 | -- | Enum for the three methods for aggregating selected values 54 | data AggregationType 55 | Min :: AggregationType 56 | Mean :: AggregationType 57 | Max :: AggregationType 58 | 59 | -- | Performs a key-query-value lookup operation and aggregates over 60 | -- values. 61 | -- 62 | -- Given a filler token, an aggregation type, two sequences (keys and 63 | -- queries), and a predicate, it returns a processed sequence. It first 64 | -- selects elements based on the predicate and then aggregates them. 65 | -- 66 | -- Roughly matches the attention layer of a Transformer. 67 | kqv :: Token -> AggregationType -> Keys -> Queries -> Predicate -> Values -> Sequence 68 | 69 | -- | Performs Key-Query-Value lookup with maximum aggregation of values. 70 | maxKQV :: Keys -> Queries -> Predicate -> Values -> Sequence 71 | 72 | -- | Performs Key-Query-Value lookup with minimum aggregation of values. 73 | minKQV :: Keys -> Queries -> Predicate -> Values -> Sequence 74 | 75 | -- | Compares pairs of elements from sequences with a predicate subject to 76 | -- a causal constraint. 77 | selectCausal :: Keys -> Queries -> Predicate -> Selector 78 | 79 | -- | Creates a matched-length constant sequence with the provided token. 80 | filledWith :: Sequence -> Token -> Sequence 81 | 82 | -- | Extracts the indices of the elements in a sequence. 83 | indicesOf :: Sequence -> Sequence 84 | 85 | -- | Aggregates values with some aggregation, filling in with a default 86 | -- token. 87 | aggregate :: AggregationType -> Token -> Aggregator 88 | 89 | -- | Aggregates values by selecting the largest value. 90 | aggrMax :: Token -> Aggregator 91 | 92 | -- | Aggregates values by taking the mean. 93 | aggrMean :: Token -> Aggregator 94 | 95 | -- | Aggregates values by selecting the smallest value. 96 | aggrMin :: Token -> Aggregator 97 | 98 | -- | Computes the "width", or number of nonzero entries, of the rows of a 99 | -- Selector. 
100 | selWidth :: Selector -> Sequence 101 | fromBool :: Bool -> Token 102 | 103 | -- | Applies an elementwise operation to a sequence of tokens. 104 | -- 105 | -- Roughly matches the MLP layer in a Transformer. Alias for map. 106 | tokMap :: (Token -> Token) -> Sequence -> Sequence 107 | 108 | -- | Applies an elementwise operation for pairs of tokens on a pair of 109 | -- sequences. Alias for zipWith. 110 | seqMap :: (Token -> Token -> Token) -> Sequence -> Sequence -> Sequence 111 | 112 | -- | Creates a sequence of the same length as the provided sequence filled 113 | -- with the provided token. Alias for filledWith. 114 | full :: Sequence -> Token -> Sequence 115 | 116 | -- | Extracts the indices of the elements in a sequence. Alias for 117 | -- indicesOf. 118 | indices :: Sequence -> Sequence 119 | 120 | -- | Creates an aggregator with a given aggregation type. Alias for 121 | -- aggregate. 122 | aggr :: AggregationType -> Token -> Aggregator 123 | 124 | -- | Produces a selector indicating which pairs of Keys and 125 | -- Queries match. 126 | select :: Bool -> Keys -> Queries -> Predicate -> Selector 127 | 128 | 129 | -- | This module provides convenience functions built from the core of the 130 | -- RASP-L language. 131 | -- 132 | -- It is based on Listing 3 of "What Algorithms Can Transformers Learn", 133 | -- https://arxiv.org/abs/2310.16028, by Zhou et al. 134 | module RaskellLib 135 | 136 | -- | Use a boolean sequence to select between two sequences. Also known in 137 | -- Python RASP-L as "where", see where'. 138 | (?) :: [Bool] -> (Sequence, Sequence) -> Sequence 139 | 140 | -- | Shift a sequence to the right by a given number of elements, filling 141 | -- the vacated positions with the provided Token. 142 | shiftRight :: Token -> Int8 -> Sequence -> Sequence 143 | 144 | -- | Maps tokens onto bools using Python's "truthiness" rules. 
145 | toBool :: Token -> Bool 146 | 147 | -- | Masks a Sequence with a boolean sequence, using the provided 148 | -- Token as the mask. 149 | mask :: Token -> [Bool] -> Sequence -> Sequence 150 | 151 | -- | Computes the cumulative sum of a boolean sequence. 152 | cumSum :: [Bool] -> Sequence 153 | 154 | -- | Computes the running maximum of a Sequence. 155 | maximum' :: Sequence -> Sequence 156 | 157 | -- | Computes the running minimum of a Sequence. 158 | minimum' :: Sequence -> Sequence 159 | 160 | -- | Computes the indices of the running maximum values in a 161 | -- Sequence. 162 | argmax :: Sequence -> Sequence 163 | 164 | -- | Computes the indices of the running minimum values in a 165 | -- Sequence. 166 | argmin :: Sequence -> Sequence 167 | 168 | -- | Computes the number of previous tokens in a Sequence that are 169 | -- equal to each Token from Queries. 170 | numPrev :: Sequence -> Queries -> Sequence 171 | 172 | -- | Returns 1s where the Token from the Queries has been 173 | -- seen before in the Sequence. 174 | hasSeen :: Sequence -> Queries -> Sequence 175 | 176 | -- | Finds the first occurrence of each query token in a Sequence. 177 | firsts :: Token -> Sequence -> Queries -> Sequence 178 | 179 | -- | Finds the last occurrence of each query token in a Sequence. 180 | lasts :: Token -> Sequence -> Queries -> Sequence 181 | 182 | -- | Selects the tokens from a Sequence at the indices provided by 183 | -- another sequence. 184 | indexSelect :: Token -> Sequence -> Sequence -> Sequence 185 | leq :: Token -> Token -> Bool 186 | geq :: Token -> Token -> Bool 187 | lt :: Token -> Token -> Bool 188 | gt :: Token -> Token -> Bool 189 | 190 | -- | Greedily and autoregressively sample the output of a RASP-L program on 191 | -- a sequence. 192 | sample :: Token -> (Sequence -> Sequence) -> Sequence -> Word8 -> Sequence 193 | 194 | -- | Use a boolean sequence to select between two sequences. 
Provided for 195 | -- compatibility with Listing 3, but with an apostrophe to avoid a name 196 | -- clash with the "where" keyword. 197 | where' :: [Bool] -> Sequence -> Sequence -> Sequence 198 | 199 | -- | Greedily and autoregressively sample the output of a RASP-L program on 200 | -- a sequence. 201 | -- 202 | -- Provided for compatibility with Listing 3. 203 | sampleAutoregressive :: Token -> (Sequence -> Sequence) -> Sequence -> Word8 -> Sequence 204 | -------------------------------------------------------------------------------- /src/RaskellCore.hs: -------------------------------------------------------------------------------- 1 | -- | This module provides the core of the RASP-L language. 2 | -- 3 | -- It is based on Listing 2 of 4 | -- "What Algorithms Can Transformers Learn", https://arxiv.org/abs/2310.16028, 5 | -- by Zhou et al. 6 | module RaskellCore 7 | ( -- * Types 8 | 9 | -- | Most of these types are merely aliases. 10 | 11 | -- ** Tokens and Sequences 12 | Token, 13 | Sequence, 14 | Keys, 15 | Queries, 16 | Values, 17 | 18 | -- ** Predicates and Selectors 19 | Predicate, 20 | Selector, 21 | BoolSequence, 22 | 23 | -- ** Aggregation 24 | Aggregator, 25 | AggregationType (..), 26 | 27 | -- * Functions 28 | 29 | -- ** Key-Query-Value lookup 30 | kqv, 31 | maxKQV, 32 | minKQV, 33 | selectCausal, 34 | filledWith, 35 | indicesOf, 36 | 37 | -- ** Aggregation 38 | aggregate, 39 | aggrMax, 40 | aggrMean, 41 | aggrMin, 42 | 43 | -- ** Selection 44 | selWidth, 45 | fromBool, 46 | 47 | -- * Compatibility 48 | 49 | -- | These functions are provided for closer compatibility with the original Python implementation of RASP-L. 50 | tokMap, 51 | seqMap, 52 | full, 53 | indices, 54 | aggr, 55 | select, 56 | ) 57 | where 58 | 59 | import Data.Int (Int8) 60 | import Data.Maybe (fromMaybe) 61 | 62 | -- | A `Token` in a `Sequence` is a small integer. 63 | -- RASP-L uses `Int8` to ensure all maps of type `Token` -> `Token` are learnable. 
type Token = Int8

-- | A `Sequence` is an ordered list of `Token`s.
type Sequence = [Token]

-- | The keys of a key-query-value lookup, as a `Sequence` of `Token`s.
type Keys = Sequence

-- | The queries of a key-query-value lookup, as a `Sequence` of `Token`s.
type Queries = Sequence

-- | The values of a key-query-value lookup, as a `Sequence` of `Token`s.
type Values = Sequence

-- | Decides whether a key (first argument) matches a query (second argument).
type Predicate = Token -> Token -> Bool

-- | The analogue of an attention map: one `BoolSequence` per query, marking
-- which keys that query selects.
type Selector = [BoolSequence]

-- | A list of `Bool`s; each row of a `Selector` is one.
type BoolSequence = [Bool]

-- | How the values selected by a query are combined into a single `Token`.
data AggregationType
  = -- | Keep the smallest selected value.
    Min
  | -- | Average the selected values, rounding down.
    Mean
  | -- | Keep the largest selected value.
    Max

-- | Performs a key-query-value lookup and aggregates the selected values.
--
-- Builds a causal selector from the keys, queries and predicate. It then
-- combines the values each query selects using the requested aggregation,
-- substituting the filler token wherever a query selects nothing.
--
-- Roughly matches the attention layer of a Transformer.
kqv ::
  -- | Filler token used when a query selects no values
  Token ->
  -- | Type of aggregation (Min, Mean, Max)
  AggregationType ->
  -- | Sequence of keys
  Keys ->
  -- | Sequence of queries
  Queries ->
  -- | A boolean predicate that determines whether a key and query match
  Predicate ->
  -- | Sequence of values
  Values ->
  -- | The output sequence, one token per query
  Sequence
kqv filler agg keys queries predicate values =
  aggregate agg filler selector values
  where
    selector = selectCausal keys queries predicate

-- | Performs Key-Query-Value lookup with maximum aggregation of values.
--
-- Queries that select no key produce @minBound :: Int8@ (-128).
maxKQV :: Keys -> Queries -> Predicate -> Values -> Sequence
maxKQV = kqv minInt8 Max
  where
    minInt8 = minBound :: Int8

-- | Performs Key-Query-Value lookup with minimum aggregation of values.
--
-- Queries that select no key produce @maxBound :: Int8@ (127).
minKQV :: Keys -> Queries -> Predicate -> Values -> Sequence
minKQV = kqv maxInt8 Min
  where
    maxInt8 = maxBound :: Int8

-- | Compares pairs of elements from sequences with a predicate subject to a causal constraint.
--
-- Row @q@ of the result has one entry per key. Entry @k@ is 'True' exactly
-- when @k <= q@ (a query only sees keys at or before its own position) and
-- @predicate (keys !! k) (queries !! q)@ holds.
selectCausal :: Keys -> Queries -> Predicate -> Selector
selectCausal keys queries predicate =
  -- Pair every element with its position once, instead of re-walking both
  -- lists with (!!) for every cell (which made this cubic in the length).
  [ [ keyIndex <= queryIndex && predicate key query
    | (keyIndex, key) <- zip [0 :: Int ..] keys
    ]
  | (queryIndex, query) <- zip [0 :: Int ..] queries
  ]

-- | Creates a matched-length constant sequence with the provided token.
filledWith :: Sequence -> Token -> Sequence
filledWith = replicate . length

-- | Extracts the indices of the elements in a sequence.
--
-- Indices are `Token`s (`Int8`), so they are only correct for sequences of
-- at most 128 elements; longer inputs overflow.
indicesOf :: Sequence -> Sequence
indicesOf x = [0 .. (fromIntegral (length x) - 1)]

-- | Type alias for "fully-specified" aggregators that are ready to aggregate a sequence of values with a selector.
type Aggregator = Selector -> Values -> Sequence

-- | Aggregates values with some aggregation, filling in with a default token.
aggregate :: AggregationType -> Token -> Aggregator
aggregate Max = aggrMax
aggregate Mean = aggrMean
aggregate Min = aggrMin

-- | Aggregates values by selecting the largest value.
--
-- Produces one token per selector row; rows that select nothing get the filler.
aggrMax :: Token -> Aggregator
aggrMax filler a v = map (aggrMaxByRow filler v) a

-- | Aggregates values by taking the mean.
--
-- Produces one token per selector row; rows that select nothing get the filler.
aggrMean :: Token -> Aggregator
aggrMean filler a v = map (aggrMeanByRow filler v) a

-- | Aggregates values by selecting the smallest value.
--
-- Produces one token per selector row; rows that select nothing get the filler.
aggrMin :: Token -> Aggregator
aggrMin filler a v = map (aggrMinByRow filler v) a

-- | Largest value selected by one selector row, or the filler if none is.
aggrMaxByRow :: Token -> Sequence -> BoolSequence -> Token
aggrMaxByRow filler v a = fromMaybe filler maybeMax
  where
    maybeMax = safeMaximum (filterByList a v)

-- | Mean (rounded down) of the values selected by one row, or the filler.
aggrMeanByRow :: Token -> Sequence -> BoolSequence -> Token
aggrMeanByRow filler v a = fromMaybe filler maybeMean
  where
    maybeMean = safeInt8Mean (filterByList a v)

-- | Smallest value selected by one selector row, or the filler if none is.
aggrMinByRow :: Token -> Sequence -> BoolSequence -> Token
aggrMinByRow filler v a = fromMaybe filler maybeMin
  where
    maybeMin = safeMinimum (filterByList a v)

-- | Keeps the elements whose corresponding flag is 'True', stopping at the
-- end of the shorter list.
filterByList :: [Bool] -> [a] -> [a]
filterByList (True : bs) (x : xs) = x : filterByList bs xs
filterByList (False : bs) (_ : xs) = filterByList bs xs
filterByList _ _ = []

-- | `maximum` that returns 'Nothing' instead of failing on an empty list.
safeMaximum :: (Ord a) => [a] -> Maybe a
safeMaximum [] = Nothing
safeMaximum xs = Just (maximum xs)

-- | `minimum` that returns 'Nothing' instead of failing on an empty list.
safeMinimum :: (Ord a) => [a] -> Maybe a
safeMinimum [] = Nothing
safeMinimum xs = Just (minimum xs)

-- | Mean of a non-empty `Sequence`, rounded down ('Nothing' when empty).
--
-- The result always lies between the smallest and largest input, so it
-- fits back into a `Token`.
safeInt8Mean :: Sequence -> Maybe Token
safeInt8Mean [] = Nothing
safeInt8Mean xs = Just (fromIntegral (total `div` count))
  where
    -- Accumulate in `Int`. Summing in `Int8` wraps around (e.g. 100 + 100),
    -- and a `length` converted to `Int8` can wrap to zero for long rows.
    total = sum (map fromIntegral xs) :: Int
    count = length xs

-- | Computes the "width", or number of nonzero entries, of the rows of a `Selector`.
--
-- Widths are `Token`s, so rows with more than 127 selected entries wrap around.
selWidth :: Selector -> Sequence
selWidth = map (sum . map fromBool)

-- | Converts a `Bool` to a `Token`: 1 for 'True', 0 for 'False'.
fromBool :: Bool -> Token
fromBool True = 1
fromBool _ = 0

-- | Applies an elementwise operation to a sequence of tokens.
--
-- Roughly matches the MLP layer in a Transformer. Alias for `map`.
tokMap :: (Token -> Token) -> Sequence -> Sequence
tokMap = map

-- | Applies an elementwise operation for pairs of tokens on a pair of sequences.
-- Alias for `zipWith`.
seqMap :: (Token -> Token -> Token) -> Sequence -> Sequence -> Sequence
seqMap = zipWith

-- | Builds a sequence as long as the given one, with every position set to
-- the given token. Alias for `filledWith`.
full :: Sequence -> Token -> Sequence
full = filledWith

-- | Lists the positions of a sequence's elements. Alias for `indicesOf`.
indices :: Sequence -> Sequence
indices = indicesOf

-- | Builds an aggregator from an aggregation type and a filler token.
-- Alias for `aggregate`.
aggr :: AggregationType -> Token -> Aggregator
aggr = aggregate

-- | Produces a selector indicating which pairs of `Keys` and `Queries` match.
select ::
  -- | Whether to use causal selection
  Bool ->
  -- | A collection of `Keys` to check against `Queries`
  Keys ->
  -- | A collection of `Queries` to check against `Keys`
  Queries ->
  -- | A boolean predicate that determines whether a key and query match
  Predicate ->
  -- | A collection of boolean sequences indicating which pairs of `Keys` and `Queries` match
  Selector
select causal
  | causal = selectCausal
  | otherwise = selectAcausal

-- | Selector without the causal constraint: every query is compared with
-- every key, regardless of position (one row per query). Reachable through
-- @select False@ for compatibility with the Python RASP-L @select@.
selectAcausal :: Keys -> Queries -> Predicate -> Selector
selectAcausal keys queries predicate =
  map (\query -> map (`predicate` query) keys) queries

--------------------------------------------------------------------------------
/docs/Lib.html:
--------------------------------------------------------------------------------
1 | Lib
raskell-0.0.0.1: RASP-L in Haskell
Safe HaskellSafe-Inferred
LanguageHaskell2010

Lib

Description

This module provides convenience functions built from the core of the RASP-L language.

It is based on Listing 3 of "What Algorithms Can Transformers Learn", https://arxiv.org/abs/2310.16028, by Zhou et al.

Logical Operations

(?) :: [Bool] -> (Sequence, Sequence) -> Sequence Source #

Use a boolean sequence to select between two sequences. Also known in Python RASP-L as "where", see where'.

where' :: [Bool] -> Sequence -> Sequence -> Sequence Source #

Use a boolean sequence to select between two sequences. Provided for compatibility with Listing 3, but with an apostrophe to avoid a name clash with the "where" keyword.

shiftRight Source #

Arguments

:: Token

Filler Token

-> Int8

Number of positions to shift

-> Sequence

Input Sequence

-> Sequence 

Shift a sequence to the right by a given number of elements, filling the vacated positions with the provided Token.

Running Aggregation

Token Comparison

Sampling

sample Source #

Arguments

:: Token

End of sequence token

-> (Sequence -> Sequence)

RASP-L program to extend the sequence

-> Sequence

Initial sequence

-> Word8

Number of steps to decode

-> Sequence 

Greedily and autoregressively sample the output of a RASP-L program on a sequence.

-------------------------------------------------------------------------------- /docs/doc-index.json: -------------------------------------------------------------------------------- 1 | [{"display_html":"type Token = Int8","name":"Token","module":"RaskellCore","link":"RaskellCore.html#t:Token"},{"display_html":"type Sequence = [Token]","name":"Sequence","module":"RaskellCore","link":"RaskellCore.html#t:Sequence"},{"display_html":"type Keys = Sequence","name":"Keys","module":"RaskellCore","link":"RaskellCore.html#t:Keys"},{"display_html":"type Queries = Sequence","name":"Queries","module":"RaskellCore","link":"RaskellCore.html#t:Queries"},{"display_html":"type Values = Sequence","name":"Values","module":"RaskellCore","link":"RaskellCore.html#t:Values"},{"display_html":"type Predicate = Token -> Token -> Bool","name":"Predicate","module":"RaskellCore","link":"RaskellCore.html#t:Predicate"},{"display_html":"type Selector = [BoolSequence]","name":"Selector","module":"RaskellCore","link":"RaskellCore.html#t:Selector"},{"display_html":"type BoolSequence = [Bool]","name":"BoolSequence","module":"RaskellCore","link":"RaskellCore.html#t:BoolSequence"},{"display_html":"type Aggregator = Selector -> Values -> Sequence","name":"Aggregator","module":"RaskellCore","link":"RaskellCore.html#t:Aggregator"},{"display_html":"data AggregationType","name":"AggregationType Min Max Mean","module":"RaskellCore","link":"RaskellCore.html#t:AggregationType"},{"display_html":"kqv :: Token -> AggregationType -> Keys -> Queries -> Predicate -> Values -> Sequence","name":"kqv","module":"RaskellCore","link":"RaskellCore.html#v:kqv"},{"display_html":"maxKQV :: Keys -> Queries -> Predicate -> Values -> Sequence","name":"maxKQV","module":"RaskellCore","link":"RaskellCore.html#v:maxKQV"},{"display_html":"minKQV :: Keys -> Queries -> Predicate -> Values -> Sequence","name":"minKQV","module":"RaskellCore","link":"RaskellCore.html#v:minKQV"},{"display_html":"selectCausal :: Keys -> Queries -> 
Predicate -> Selector","name":"selectCausal","module":"RaskellCore","link":"RaskellCore.html#v:selectCausal"},{"display_html":"filledWith :: Sequence -> Token -> Sequence","name":"filledWith","module":"RaskellCore","link":"RaskellCore.html#v:filledWith"},{"display_html":"indicesOf :: Sequence -> Sequence","name":"indicesOf","module":"RaskellCore","link":"RaskellCore.html#v:indicesOf"},{"display_html":"aggregate :: AggregationType -> Token -> Aggregator","name":"aggregate","module":"RaskellCore","link":"RaskellCore.html#v:aggregate"},{"display_html":"aggrMax :: Token -> Aggregator","name":"aggrMax","module":"RaskellCore","link":"RaskellCore.html#v:aggrMax"},{"display_html":"aggrMean :: Token -> Aggregator","name":"aggrMean","module":"RaskellCore","link":"RaskellCore.html#v:aggrMean"},{"display_html":"aggrMin :: Token -> Aggregator","name":"aggrMin","module":"RaskellCore","link":"RaskellCore.html#v:aggrMin"},{"display_html":"selWidth :: Selector -> Sequence","name":"selWidth","module":"RaskellCore","link":"RaskellCore.html#v:selWidth"},{"display_html":"fromBool :: Bool -> Token","name":"fromBool","module":"RaskellCore","link":"RaskellCore.html#v:fromBool"},{"display_html":"tokMap :: (Token -> Token) -> Sequence -> Sequence","name":"tokMap","module":"RaskellCore","link":"RaskellCore.html#v:tokMap"},{"display_html":"seqMap :: (Token -> Token -> Token) -> Sequence -> Sequence -> Sequence","name":"seqMap","module":"RaskellCore","link":"RaskellCore.html#v:seqMap"},{"display_html":"full :: Sequence -> Token -> Sequence","name":"full","module":"RaskellCore","link":"RaskellCore.html#v:full"},{"display_html":"indices :: Sequence -> Sequence","name":"indices","module":"RaskellCore","link":"RaskellCore.html#v:indices"},{"display_html":"aggr :: AggregationType -> Token -> Aggregator","name":"aggr","module":"RaskellCore","link":"RaskellCore.html#v:aggr"},{"display_html":"select :: Bool -> Keys -> Queries -> Predicate -> 
Selector","name":"select","module":"RaskellCore","link":"RaskellCore.html#v:select"},{"display_html":"(?) :: [Bool] -> (Sequence, Sequence) -> Sequence","name":"?","module":"RaskellLib","link":"RaskellLib.html#v:-63-"},{"display_html":"shiftRight :: Token -> Int8 -> Sequence -> Sequence","name":"shiftRight","module":"RaskellLib","link":"RaskellLib.html#v:shiftRight"},{"display_html":"toBool :: Token -> Bool","name":"toBool","module":"RaskellLib","link":"RaskellLib.html#v:toBool"},{"display_html":"mask :: Token -> [Bool] -> Sequence -> Sequence","name":"mask","module":"RaskellLib","link":"RaskellLib.html#v:mask"},{"display_html":"cumSum :: [Bool] -> Sequence","name":"cumSum","module":"RaskellLib","link":"RaskellLib.html#v:cumSum"},{"display_html":"maximum' :: Sequence -> Sequence","name":"maximum'","module":"RaskellLib","link":"RaskellLib.html#v:maximum-39-"},{"display_html":"minimum' :: Sequence -> Sequence","name":"minimum'","module":"RaskellLib","link":"RaskellLib.html#v:minimum-39-"},{"display_html":"argmax :: Sequence -> Sequence","name":"argmax","module":"RaskellLib","link":"RaskellLib.html#v:argmax"},{"display_html":"argmin :: Sequence -> Sequence","name":"argmin","module":"RaskellLib","link":"RaskellLib.html#v:argmin"},{"display_html":"numPrev :: Sequence -> Queries -> Sequence","name":"numPrev","module":"RaskellLib","link":"RaskellLib.html#v:numPrev"},{"display_html":"hasSeen :: Sequence -> Queries -> Sequence","name":"hasSeen","module":"RaskellLib","link":"RaskellLib.html#v:hasSeen"},{"display_html":"firsts :: Token -> Sequence -> Queries -> Sequence","name":"firsts","module":"RaskellLib","link":"RaskellLib.html#v:firsts"},{"display_html":"lasts :: Token -> Sequence -> Queries -> Sequence","name":"lasts","module":"RaskellLib","link":"RaskellLib.html#v:lasts"},{"display_html":"indexSelect :: Token -> Sequence -> Sequence -> Sequence","name":"indexSelect","module":"RaskellLib","link":"RaskellLib.html#v:indexSelect"},{"display_html":"leq :: Token -> Token -> 
Bool","name":"leq","module":"RaskellLib","link":"RaskellLib.html#v:leq"},{"display_html":"geq :: Token -> Token -> Bool","name":"geq","module":"RaskellLib","link":"RaskellLib.html#v:geq"},{"display_html":"lt :: Token -> Token -> Bool","name":"lt","module":"RaskellLib","link":"RaskellLib.html#v:lt"},{"display_html":"gt :: Token -> Token -> Bool","name":"gt","module":"RaskellLib","link":"RaskellLib.html#v:gt"},{"display_html":"sample :: Token -> (Sequence -> Sequence) -> Sequence -> Word8 -> Sequence","name":"sample","module":"RaskellLib","link":"RaskellLib.html#v:sample"},{"display_html":"where' :: [Bool] -> Sequence -> Sequence -> Sequence","name":"where'","module":"RaskellLib","link":"RaskellLib.html#v:where-39-"},{"display_html":"sampleAutoregressive :: Token -> (Sequence -> Sequence) -> Sequence -> Word8 -> Sequence","name":"sampleAutoregressive","module":"RaskellLib","link":"RaskellLib.html#v:sampleAutoregressive"}] -------------------------------------------------------------------------------- /docs/linuwial.css: -------------------------------------------------------------------------------- 1 | /* @group Fundamentals */ 2 | 3 | * { margin: 0; padding: 0 } 4 | 5 | /* Is this portable? 
*/ 6 | html { 7 | background-color: white; 8 | width: 100%; 9 | height: 100%; 10 | } 11 | 12 | body { 13 | background: #fefefe; 14 | color: #111; 15 | text-align: left; 16 | min-height: 100vh; 17 | position: relative; 18 | -webkit-text-size-adjust: 100%; 19 | -webkit-font-feature-settings: "kern" 1, "liga" 0; 20 | -moz-font-feature-settings: "kern" 1, "liga" 0; 21 | -o-font-feature-settings: "kern" 1, "liga" 0; 22 | font-feature-settings: "kern" 1, "liga" 0; 23 | letter-spacing: 0.0015rem; 24 | } 25 | 26 | #content a { 27 | overflow-wrap: break-word; 28 | } 29 | 30 | p { 31 | margin: 0.8em 0; 32 | } 33 | 34 | ul, ol { 35 | margin: 0.8em 0 0.8em 2em; 36 | } 37 | 38 | dl { 39 | margin: 0.8em 0; 40 | } 41 | 42 | dt { 43 | font-weight: bold; 44 | } 45 | dd { 46 | margin-left: 2em; 47 | } 48 | 49 | a { text-decoration: none; } 50 | a[href]:link { color: #9E358F; } 51 | a[href]:visited {color: #6F5F9C; } 52 | a[href]:hover { text-decoration:underline; } 53 | 54 | a[href].def:link, a[href].def:visited { color: rgba(69, 59, 97, 0.8); } 55 | a[href].def:hover { color: rgb(78, 98, 114); } 56 | 57 | /* @end */ 58 | 59 | /* @group Show and hide with JS */ 60 | 61 | body.js-enabled .hide-when-js-enabled { 62 | display: none; 63 | } 64 | 65 | /* @end */ 66 | 67 | 68 | /* @group responsive */ 69 | 70 | #package-header .caption { 71 | margin: 0px 1em 0 2em; 72 | } 73 | 74 | @media only screen and (min-width: 1280px) { 75 | #content { 76 | width: 63vw; 77 | max-width: 1450px; 78 | } 79 | 80 | #table-of-contents { 81 | position: fixed; 82 | max-width: 10vw; 83 | top: 10.2em; 84 | left: 2em; 85 | bottom: 1em; 86 | overflow-y: auto; 87 | } 88 | 89 | #synopsis { 90 | display: block; 91 | position: fixed; 92 | float: left; 93 | top: 5em; 94 | bottom: 1em; 95 | right: 0; 96 | max-width: 65vw; 97 | overflow-y: auto; 98 | /* Ensure that synopsis covers everything (including MathJAX markup) */ 99 | z-index: 1; 100 | } 101 | 102 | #synopsis .show { 103 | border: 1px solid #5E5184; 104 | 
padding: 0.7em; 105 | max-height: 65vh; 106 | } 107 | 108 | } 109 | 110 | @media only screen and (max-width: 1279px) { 111 | #content { 112 | width: 80vw; 113 | } 114 | 115 | #synopsis { 116 | display: block; 117 | padding: 0; 118 | position: relative; 119 | margin: 0; 120 | width: 100%; 121 | } 122 | } 123 | 124 | @media only screen and (max-width: 999px) { 125 | #content { 126 | width: 93vw; 127 | } 128 | } 129 | 130 | 131 | /* menu for wider screens 132 | 133 | Display the package name at the left and the menu links at the right, 134 | inline with each other: 135 | The package name Source . Contents . Index 136 | */ 137 | @media only screen and (min-width: 1000px) { 138 | #package-header { 139 | text-align: left; 140 | white-space: nowrap; 141 | height: 40px; 142 | padding: 4px 1.5em 0px 1.5em; 143 | overflow: visible; 144 | 145 | display: flex; 146 | justify-content: space-between; 147 | align-items: center; 148 | } 149 | 150 | #package-header .caption { 151 | display: inline-block; 152 | margin: 0; 153 | } 154 | 155 | #package-header ul.links { 156 | margin: 0; 157 | display: inline-table; 158 | } 159 | 160 | #package-header .caption + ul.links { 161 | margin-left: 1em; 162 | } 163 | } 164 | 165 | /* menu for smaller screens 166 | 167 | Display the package name on top of the menu links and center both elements: 168 | The package name 169 | Source . Contents . 
Index 170 | */ 171 | @media only screen and (max-width: 999px) { 172 | #package-header { 173 | text-align: center; 174 | padding: 6px 0 4px 0; 175 | overflow: hidden; 176 | } 177 | 178 | #package-header ul.links { 179 | display: block; 180 | text-align: center; 181 | margin: 0; 182 | 183 | /* Hide scrollbar but allow scrolling menu links horizontally */ 184 | white-space: nowrap; 185 | overflow-x: auto; 186 | overflow-y: hidden; 187 | margin-bottom: -17px; 188 | height: 50px; 189 | } 190 | 191 | #package-header .caption { 192 | display: block; 193 | margin: 4px 0; 194 | text-align: center; 195 | } 196 | 197 | #package-header ul.links::-webkit-scrollbar { 198 | display: none; 199 | } 200 | 201 | #package-header ul.links li:first-of-type { 202 | padding-left: 1em; 203 | } 204 | 205 | #package-header ul.links li:last-of-type { 206 | /* 207 | The last link of the menu should offer the same distance to the right 208 | as the #package-header enforces at the left. 209 | */ 210 | padding-right: 1em; 211 | } 212 | 213 | #package-header .caption + ul.links { 214 | padding-top: 9px; 215 | } 216 | 217 | #module-header table.info { 218 | float: none; 219 | top: 0; 220 | margin: 0 auto; 221 | overflow: hidden; 222 | max-width: 80vw; 223 | } 224 | } 225 | 226 | /* @end */ 227 | 228 | 229 | /* @group Fonts & Sizes */ 230 | 231 | /* Basic technique & IE workarounds from YUI 3 232 | For reasons, see: 233 | http://yui.yahooapis.com/3.1.1/build/cssfonts/fonts.css 234 | */ 235 | 236 | body, button { 237 | font: 400 14px/1.4 'PT Sans', 238 | /* Fallback Font Stack */ 239 | -apple-system, 240 | BlinkMacSystemFont, 241 | 'Segoe UI', 242 | Roboto, 243 | Oxygen-Sans, 244 | Cantarell, 245 | 'Helvetica Neue', 246 | sans-serif; 247 | *font-size: medium; /* for IE */ 248 | *font:x-small; /* for IE in quirks mode */ 249 | } 250 | 251 | h1 { font-size: 146.5%; /* 19pt */ } 252 | h2 { font-size: 131%; /* 17pt */ } 253 | h3 { font-size: 116%; /* 15pt */ } 254 | h4 { font-size: 100%; /* 13pt */ } 
255 | h5 { font-size: 100%; /* 13pt */ } 256 | 257 | table { 258 | font-size:inherit; 259 | font:100%; 260 | } 261 | 262 | pre, code, kbd, samp, tt, .src { 263 | font-family:monospace; 264 | } 265 | 266 | .links, .link { 267 | font-size: 85%; /* 11pt */ 268 | } 269 | 270 | #module-header .caption { 271 | font-size: 182%; /* 24pt */ 272 | } 273 | 274 | #module-header .caption sup { 275 | font-size: 80%; 276 | font-weight: normal; 277 | } 278 | 279 | #package-header #page-menu a:link, #package-header #page-menu a:visited { color: white; } 280 | 281 | 282 | .info { 283 | font-size: 90%; 284 | } 285 | 286 | 287 | /* @end */ 288 | 289 | /* @group Common */ 290 | 291 | .caption, h1, h2, h3, h4, h5, h6, summary { 292 | font-weight: bold; 293 | color: #5E5184; 294 | margin: 1.5em 0 1em 0; 295 | } 296 | 297 | 298 | * + h1, * + h2, * + h3, * + h4, * + h5, * + h6 { 299 | margin-top: 2em; 300 | } 301 | 302 | h1 + h2, h2 + h3, h3 + h4, h4 + h5, h5 + h6 { 303 | margin-top: inherit; 304 | } 305 | 306 | ul li + li { 307 | margin-top: 0.2rem; 308 | } 309 | 310 | ul + p { 311 | margin-top: 0.93em; 312 | } 313 | 314 | p + ul { 315 | margin-top: 0.5em; 316 | } 317 | 318 | p { 319 | margin-top: 0.7rem; 320 | } 321 | 322 | ul, ol { 323 | margin: 0.8em 0 0.8em 2em; 324 | } 325 | 326 | ul.links { 327 | list-style: none; 328 | text-align: left; 329 | font-size: 0.95em; 330 | } 331 | 332 | #package-header ul.links, #package-header ul.links button { 333 | font-size: 1rem; 334 | } 335 | 336 | ul.links li { 337 | display: inline; 338 | white-space: nowrap; 339 | padding: 0; 340 | } 341 | 342 | ul.links > li + li:before { 343 | content: '\00B7'; 344 | } 345 | 346 | ul.links li a { 347 | padding: 0.2em 0.5em; 348 | } 349 | 350 | .hide { display: none; } 351 | .show { display: inherit; } 352 | .clear { clear: both; } 353 | 354 | .collapser:before, .expander:before, .noexpander:before { 355 | font-size: 1.2em; 356 | color: #9C5791; 357 | display: inline-block; 358 | padding-right: 7px; 359 | } 360 
| 361 | .collapser:before { 362 | content: '▿'; 363 | } 364 | .expander:before { 365 | content: '▹'; 366 | } 367 | .noexpander:before { 368 | content: '▿'; 369 | visibility: hidden; 370 | } 371 | 372 | .collapser, .expander { 373 | cursor: pointer; 374 | } 375 | 376 | .instance.collapser, .instance.expander { 377 | margin-left: 0px; 378 | background-position: left center; 379 | min-width: 9px; 380 | min-height: 9px; 381 | } 382 | 383 | summary { 384 | cursor: pointer; 385 | outline: none; 386 | } 387 | 388 | pre { 389 | padding: 0.5rem 1rem; 390 | margin: 1em 0 0 0; 391 | background-color: #f7f7f7; 392 | overflow: auto; 393 | border: 1px solid #ddd; 394 | border-radius: 0.3em; 395 | } 396 | 397 | pre + p { 398 | margin-top: 1em; 399 | } 400 | 401 | pre + pre { 402 | margin-top: 0.5em; 403 | } 404 | 405 | blockquote { 406 | border-left: 3px solid #c7a5d3; 407 | background-color: #eee4f1; 408 | margin: 0.5em; 409 | padding: 0.0005em 0.3em 0.5em 0.5em; 410 | } 411 | 412 | .src { 413 | background: #f2f2f2; 414 | padding: 0.2em 0.5em; 415 | } 416 | 417 | .keyword { font-weight: normal; } 418 | .def { font-weight: bold; } 419 | 420 | @media print { 421 | #footer { display: none; } 422 | } 423 | 424 | /* @end */ 425 | 426 | /* @group Page Structure */ 427 | 428 | #content { 429 | margin: 3em auto 6em auto; 430 | padding: 0; 431 | } 432 | 433 | #package-header { 434 | background: #5E5184; 435 | border-bottom: 5px solid rgba(69, 59, 97, 0.5); 436 | color: #ddd; 437 | position: relative; 438 | font-size: 1.2em; 439 | text-align: left; 440 | margin: 0 auto; 441 | } 442 | 443 | #package-header .caption { 444 | color: white; 445 | font-style: normal; 446 | font-size: 1rem; 447 | font-weight: bold; 448 | } 449 | 450 | #module-header .caption { 451 | font-weight: bold; 452 | border-bottom: 1px solid #ddd; 453 | } 454 | 455 | table.info { 456 | float: right; 457 | padding: 0.5em 1em; 458 | border: 1px solid #ddd; 459 | color: rgb(78,98,114); 460 | background-color: #fff; 461 | 
max-width: 60%; 462 | border-spacing: 0; 463 | position: relative; 464 | top: -0.78em; 465 | margin: 0 0 0 2em; 466 | } 467 | 468 | .info th { 469 | padding: 0 1em 0 0; 470 | text-align: right; 471 | } 472 | 473 | #style-menu li { 474 | display: block; 475 | border-style: none; 476 | list-style-type: none; 477 | } 478 | 479 | #footer { 480 | background: #ededed; 481 | border-top: 1px solid #aaa; 482 | padding: 0.5em 0; 483 | color: #222; 484 | text-align: center; 485 | width: 100%; 486 | height: 3em; 487 | margin-top: 3em; 488 | position: relative; 489 | clear: both; 490 | } 491 | 492 | /* @end */ 493 | 494 | /* @group Front Matter */ 495 | 496 | #synopsis .caption, 497 | #contents-list .caption { 498 | font-size: 1rem; 499 | } 500 | 501 | #synopsis, #table-of-contents { 502 | font-size: 16px; 503 | } 504 | 505 | #contents-list { 506 | background: #f4f4f4; 507 | padding: 1em; 508 | margin: 0; 509 | } 510 | 511 | #contents-list .caption { 512 | text-align: left; 513 | margin: 0; 514 | } 515 | 516 | #contents-list ul { 517 | list-style: none; 518 | margin: 0; 519 | margin-top: 10px; 520 | font-size: 14px; 521 | } 522 | 523 | #contents-list ul ul { 524 | margin-left: 1.5em; 525 | } 526 | 527 | #description .caption { 528 | display: none; 529 | } 530 | 531 | #synopsis summary { 532 | display: block; 533 | float: right; 534 | width: 29px; 535 | color: rgba(255,255,255,0); 536 | height: 110px; 537 | margin: 0; 538 | font-size: 1px; 539 | padding: 0; 540 | background: url(synopsis.png) no-repeat 0px -8px; 541 | } 542 | 543 | #synopsis details[open] > summary { 544 | background: url(synopsis.png) no-repeat -75px -8px; 545 | } 546 | 547 | #synopsis details:not([open]) > ul { 548 | visibility: hidden; 549 | } 550 | 551 | #synopsis ul { 552 | height: 100%; 553 | overflow: auto; 554 | padding: 0.5em; 555 | margin: 0; 556 | } 557 | 558 | #synopsis ul ul { 559 | overflow: hidden; 560 | } 561 | 562 | #synopsis ul, 563 | #synopsis ul li.src { 564 | background-color: 
rgb(250,247,224); 565 | white-space: nowrap; 566 | list-style: none; 567 | margin-left: 0; 568 | } 569 | 570 | #interface td.src { 571 | white-space: nowrap; 572 | } 573 | 574 | /* @end */ 575 | 576 | /* @group Main Content */ 577 | 578 | #interface div.top + div.top { 579 | margin-top: 1.5em; 580 | } 581 | 582 | #interface p + div.top, 583 | #interface h1 + div.top, 584 | #interface h2 + div.top, 585 | #interface h3 + div.top, 586 | #interface h4 + div.top, 587 | #interface h5 + div.top { 588 | margin-top: 1em; 589 | } 590 | #interface .src .selflink, 591 | #interface .src .link { 592 | float: right; 593 | color: #888; 594 | padding: 0 7px; 595 | -moz-user-select: none; 596 | font-weight: bold; 597 | line-height: 30px; 598 | } 599 | #interface .src .selflink { 600 | margin: 0 -0.5em 0 0.5em; 601 | } 602 | 603 | #interface span.fixity { 604 | color: #919191; 605 | border-left: 1px solid #919191; 606 | padding: 0.2em 0.5em 0.2em 0.5em; 607 | margin: 0 -1em 0 1em; 608 | } 609 | 610 | #interface span.rightedge { 611 | border-left: 1px solid #919191; 612 | padding: 0.2em 0 0.2em 0; 613 | margin: 0 0 0 1em; 614 | } 615 | 616 | #interface table { border-spacing: 2px; } 617 | #interface td { 618 | vertical-align: top; 619 | padding-left: 0.5em; 620 | } 621 | 622 | #interface td.doc p { 623 | margin: 0; 624 | } 625 | #interface td.doc p + p { 626 | margin-top: 0.8em; 627 | } 628 | 629 | .doc table { 630 | border-collapse: collapse; 631 | border-spacing: 0px; 632 | } 633 | 634 | .doc th, 635 | .doc td { 636 | padding: 5px; 637 | border: 1px solid #ddd; 638 | } 639 | 640 | .doc th { 641 | background-color: #f0f0f0; 642 | } 643 | 644 | .clearfix:after { 645 | clear: both; 646 | content: " "; 647 | display: block; 648 | height: 0; 649 | visibility: hidden; 650 | } 651 | 652 | .subs, .top > .doc, .subs > .doc { 653 | padding-left: 1em; 654 | border-left: 1px solid gainsboro; 655 | margin-bottom: 1em; 656 | } 657 | 658 | .top .subs { 659 | margin-bottom: 0.6em; 660 | } 661 | 662 
| .subs.fields ul { 663 | list-style: none; 664 | display: table; 665 | margin: 0; 666 | } 667 | 668 | .subs.fields ul li { 669 | display: table-row; 670 | } 671 | 672 | .subs ul li dfn { 673 | display: table-cell; 674 | font-style: normal; 675 | font-weight: bold; 676 | margin: 1px 0; 677 | white-space: nowrap; 678 | } 679 | 680 | .subs ul li > .doc { 681 | display: table-cell; 682 | padding-left: 0.5em; 683 | margin-bottom: 0.5em; 684 | } 685 | 686 | .subs ul li > .doc p { 687 | margin: 0; 688 | } 689 | 690 | .subs .subs p.src { 691 | border: none; 692 | background-color: #f8f8f8; 693 | } 694 | 695 | .subs .subs .caption { 696 | margin-top: 1em ; 697 | margin-bottom: 0px; 698 | } 699 | 700 | .subs p.caption { 701 | margin-top: 0; 702 | } 703 | 704 | .subs .subs .caption + .src { 705 | margin: 0px; 706 | margin-top: 8px; 707 | } 708 | 709 | .subs .subs .src + .src { 710 | margin: 7px 0 0 0; 711 | } 712 | 713 | /* Render short-style data instances */ 714 | .inst ul { 715 | height: 100%; 716 | padding: 0.5em; 717 | margin: 0; 718 | } 719 | 720 | .inst, .inst li { 721 | list-style: none; 722 | margin-left: 1em; 723 | } 724 | 725 | /* Workaround for bug in Firefox (issue #384) */ 726 | .inst-left { 727 | float: left; 728 | } 729 | 730 | .top p.src { 731 | border-bottom: 3px solid #e5e5e5; 732 | line-height: 2rem; 733 | margin-bottom: 1em; 734 | } 735 | 736 | .warning { 737 | color: red; 738 | } 739 | 740 | .arguments { 741 | margin-top: -0.4em; 742 | } 743 | .arguments .caption { 744 | display: none; 745 | } 746 | 747 | .fields { padding-left: 1em; } 748 | 749 | .fields .caption { display: none; } 750 | 751 | .fields p { margin: 0 0; } 752 | 753 | /* this seems bulky to me 754 | .methods, .constructors { 755 | background: #f8f8f8; 756 | border: 1px solid #eee; 757 | } 758 | */ 759 | 760 | /* @end */ 761 | 762 | /* @group Auxillary Pages */ 763 | 764 | 765 | .extension-list { 766 | list-style-type: none; 767 | margin-left: 0; 768 | } 769 | 770 | #mini { 771 | margin: 0 
auto; 772 | padding: 0 1em 1em; 773 | } 774 | 775 | #mini > * { 776 | font-size: 93%; /* 12pt */ 777 | } 778 | 779 | #mini #module-list .caption, 780 | #mini #module-header .caption { 781 | font-size: 125%; /* 15pt */ 782 | } 783 | 784 | #mini #interface h1, 785 | #mini #interface h2, 786 | #mini #interface h3, 787 | #mini #interface h4 { 788 | font-size: 109%; /* 13pt */ 789 | margin: 1em 0 0; 790 | } 791 | 792 | #mini #interface .top, 793 | #mini #interface .src { 794 | margin: 0; 795 | } 796 | 797 | #mini #module-list ul { 798 | list-style: none; 799 | margin: 0; 800 | } 801 | 802 | #alphabet ul { 803 | list-style: none; 804 | padding: 0; 805 | margin: 0.5em 0 0; 806 | text-align: center; 807 | } 808 | 809 | #alphabet li { 810 | display: inline; 811 | margin: 0 0.25em; 812 | } 813 | 814 | #alphabet a { 815 | font-weight: bold; 816 | } 817 | 818 | #index .caption, 819 | #module-list .caption { font-size: 131%; /* 17pt */ } 820 | 821 | #index table { 822 | margin-left: 2em; 823 | } 824 | 825 | #index .src { 826 | font-weight: bold; 827 | } 828 | #index .alt { 829 | font-size: 77%; /* 10pt */ 830 | font-style: italic; 831 | padding-left: 2em; 832 | } 833 | 834 | #index td + td { 835 | padding-left: 1em; 836 | } 837 | 838 | #module-list ul { 839 | list-style: none; 840 | margin: 0 0 0 2em; 841 | } 842 | 843 | #module-list li { 844 | clear: right; 845 | } 846 | 847 | #module-list span.collapser, 848 | #module-list span.expander { 849 | background-position: 0 0.3em; 850 | } 851 | 852 | #module-list .package { 853 | float: right; 854 | } 855 | 856 | :target { 857 | background: -webkit-linear-gradient(top, transparent 0%, transparent 65%, #fbf36d 60%, #fbf36d 100%); 858 | background: -moz-linear-gradient(top, transparent 0%, transparent 65%, #fbf36d 60%, #fbf36d 100%); 859 | background: -o-linear-gradient(top, transparent 0%, transparent 65%, #fbf36d 60%, #fbf36d 100%); 860 | background: -ms-linear-gradient(top, transparent 0%, transparent 65%, #fbf36d 60%, #fbf36d 
100%); 861 | background: linear-gradient(to bottom, transparent 0%, transparent 65%, #fbf36d 60%, #fbf36d 100%); 862 | } 863 | 864 | :target:hover { 865 | background: -webkit-linear-gradient(top, transparent 0%, transparent 0%, #fbf36d 0%, #fbf36d 100%); 866 | background: -moz-linear-gradient(top, transparent 0%, transparent 0%, #fbf36d 0%, #fbf36d 100%); 867 | background: -o-linear-gradient(top, transparent 0%, transparent 0%, #fbf36d 0%, #fbf36d 100%); 868 | background: -ms-linear-gradient(top, transparent 0%, transparent 0%, #fbf36d 0%, #fbf36d 100%); 869 | background: linear-gradient(to bottom, transparent 0%, transparent 0%, #fbf36d 0%, #fbf36d 100%); 870 | } 871 | 872 | /* @end */ 873 | 874 | /* @group Dropdown menus */ 875 | 876 | #preferences-menu, #style-menu { 877 | width: 25em; 878 | overflow-y: auto; 879 | } 880 | 881 | /* @end */ 882 | -------------------------------------------------------------------------------- /docs/RaskellLib.html: -------------------------------------------------------------------------------- 1 | RaskellLib
raskell-0.0.0.2: RASP-L in Haskell
Safe HaskellSafe-Inferred
LanguageHaskell2010

RaskellLib

Description

This module provides convenience functions built from the core of the RASP-L language.

It is based on Listing 3 of 2 | "What Algorithms Can Transformers Learn", https://arxiv.org/abs/2310.16028, 3 | by Zhou et al.

Logical Operations

(?) :: [Bool] -> (Sequence, Sequence) -> Sequence Source #

Use a boolean sequence to select between two sequences. 4 | Also known in Python RASP-L as "where", see where'.

shiftRight Source #

Arguments

:: Token

Filler Token

-> Int8

Number of positions to shift

-> Sequence

Input Sequence

-> Sequence 

Shift a sequence to the right by a given number of elements, 5 | filling the vacated positions with the provided Token.

toBool :: Token -> Bool Source #

Maps tokens onto bools using Python's "truthiness" rules.

mask :: Token -> [Bool] -> Sequence -> Sequence Source #

Masks a Sequence with a boolean sequence, using the provided Token as the mask.

Running Aggregations

cumSum :: [Bool] -> Sequence Source #

Computes the cumulative sum of a boolean sequence.

maximum' :: Sequence -> Sequence Source #

Computes the running maximum of a Sequence.

minimum' :: Sequence -> Sequence Source #

Computes the running minimum of a Sequence.

argmax :: Sequence -> Sequence Source #

Computes the indices of the running maximum values in a Sequence.

argmin :: Sequence -> Sequence Source #

Computes the indices of the running minimum values in a Sequence.

Aggregations with Queries

numPrev :: Sequence -> Queries -> Sequence Source #

Computes the number of previous tokens in a Sequence that are equal to each Token from Queries.

hasSeen :: Sequence -> Queries -> Sequence Source #

Returns 1s where the Token from the Queries has been seen before in the Sequence.

Indexing with Queries

firsts :: Token -> Sequence -> Queries -> Sequence Source #

Finds the first occurrence of each query token in a Sequence.

lasts :: Token -> Sequence -> Queries -> Sequence Source #

Finds the last occurrence of each query token in a Sequence.

indexSelect :: Token -> Sequence -> Sequence -> Sequence Source #

Selects the tokens from a Sequence at the indices provided by another sequence.

Token Comparisons

Sampling

sample Source #

Arguments

:: Token

End of sequence token

-> (Sequence -> Sequence)

RASP-L program to extend the sequence

-> Sequence

Initial/prompt sequence

-> Word8

Number of steps to decode

-> Sequence

Output (including prompt)

Greedily and autoregressively sample the output of a RASP-L program on a sequence.

Compatibility

where' :: [Bool] -> Sequence -> Sequence -> Sequence Source #

Use a boolean sequence to select between two sequences. 6 | Provided for compatibility with Listing 3, but with 7 | an apostrophe to avoid a name clash with the "where" keyword.

sampleAutoregressive :: Token -> (Sequence -> Sequence) -> Sequence -> Word8 -> Sequence Source #

Greedily and autoregressively sample the output of a RASP-L program on a sequence.

Provided for compatibility with Listing 3.

-------------------------------------------------------------------------------- /docs/Core.html: -------------------------------------------------------------------------------- 1 | Core
raskell-0.0.0.1: RASP-L in Haskell
Safe HaskellSafe-Inferred
LanguageHaskell2010

Core

Description

This module provides the core of the RASP-L language.

It is based on Listing 2 of 2 | "What Algorithms Can Transformers Learn", https://arxiv.org/abs/2310.16028, 3 | by Zhou et al.

Types

Most of these types are merely aliases.

Tokens and Sequences

type Token = Int8 Source #

A Token in a Sequence is a small integer. 4 | RASP-L uses Int8 to ensure all maps of type Token -> Token are learnable.

type Sequence = [Token] Source #

A Sequence is a list of Tokens.

type Keys = Sequence Source #

A collection of keys is a list of Tokens.

type Queries = Sequence Source #

A collection of queries is a list of Tokens.

type Values = Sequence Source #

A collection of values is a list of Tokens.

Predicates and Selectors

type Predicate = Token -> Token -> Bool Source #

We can compare Keys and Queries to determine if they match.

type Selector = [BoolSequence] Source #

The equivalents of "attention maps" are collections of Boolean sequences.

type BoolSequence = [Bool] Source #

Internally, we sometimes need to operate on collections of Bools.

Aggregation

type Aggregator = Selector -> Values -> Sequence Source #

Type alias for "fully-specified" aggregators that are ready to aggregate a sequence of values with a selector.

data AggregationType Source #

Enum for the three methods for aggregating selected values.

Constructors

Min 
Mean 
Max 

Functions

Key-Query-Value lookup

kqv Source #

Arguments

:: Token

Filler token used in aggregation

-> AggregationType

Type of aggregation (Min, Mean, Max)

-> Keys

Sequence of keys

-> Queries

Sequence of queries

-> Predicate

A boolean predicate that determines whether a key and query match

-> Values

Sequence of values

-> Sequence

The output sequence

Performs a key-query-value lookup operation and aggregates over values.

Given a filler token, an aggregation type, two sequences (keys and queries), 5 | a predicate, and a sequence of values, it returns a processed sequence. It first selects 6 | key-query pairs that satisfy the predicate and then aggregates the corresponding values.

Roughly matches the attention layer of a Transformer.

maxKQV :: Keys -> Queries -> Predicate -> Values -> Sequence Source #

Performs Key-Query-Value lookup with maximum aggregation of values.

minKQV :: Keys -> Queries -> Predicate -> Values -> Sequence Source #

Performs Key-Query-Value lookup with minimum aggregation of values.

selectCausal :: Keys -> Queries -> Predicate -> Selector Source #

Compares pairs of elements from sequences with a predicate subject to a causal constraint.

filledWith :: Sequence -> Token -> Sequence Source #

Creates a matched-length constant sequence with the provided token.

indicesOf :: Sequence -> Sequence Source #

Extracts the indices of the elements in a sequence.

Aggregation

aggregate :: AggregationType -> Token -> Aggregator Source #

Aggregates values with some aggregation, filling in with a default token.

aggrMax :: Token -> Aggregator Source #

Aggregates values by selecting the largest value.

aggrMean :: Token -> Aggregator Source #

Aggregates values by taking the mean.

aggrMin :: Token -> Aggregator Source #

Aggregates values by selecting the smallest value.

Selection

selWidth :: Selector -> Sequence Source #

Computes the "width", or number of True entries, of each row of a Selector.

Compatibility

These functions are provided for closer compatibility with the original Python implementation of RASP-L.

tokMap :: (Token -> Token) -> Sequence -> Sequence Source #

Applies an elementwise operation to a sequence of tokens.

Roughly matches the MLP layer in a Transformer. Alias for map.

seqMap :: (Token -> Token -> Token) -> Sequence -> Sequence -> Sequence Source #

Applies an elementwise operation for pairs of tokens on a pair of sequences. 7 | Alias for zipWith.

full :: Sequence -> Token -> Sequence Source #

Creates a sequence of the same length as the provided sequence filled with the provided token. 8 | Alias for filledWith.

indices :: Sequence -> Sequence Source #

Extracts the indices of the elements in a sequence. 9 | Alias for indicesOf.

aggr :: AggregationType -> Token -> Aggregator Source #

Creates an aggregator with a given aggregation type. 10 | Alias for aggregate.

select Source #

Arguments

:: Bool

Whether to use causal selection

-> Keys

A collection of Keys to check against Queries

-> Queries

A collection of Queries to check against Keys

-> Predicate

A boolean predicate that determines whether a key and query match

-> Selector

A collection of boolean sequences indicating which pairs of Keys and Queries match

Produces a selector indicating which pairs of Keys and Queries match.

-------------------------------------------------------------------------------- /docs/quick-jump.min.js: -------------------------------------------------------------------------------- 1 | !function i(s,a,l){function c(t,e){if(!a[t]){if(!s[t]){var n="function"==typeof require&&require;if(!e&&n)return n(t,!0);if(u)return u(t,!0);var o=new Error("Cannot find module '"+t+"'");throw o.code="MODULE_NOT_FOUND",o}var r=a[t]={exports:{}};s[t][0].call(r.exports,function(e){return c(s[t][1][e]||e)},r,r.exports,i,s,a,l)}return a[t].exports}for(var u="function"==typeof require&&require,e=0;ewindow.innerHeight?this.searchResults.scrollTop+=e.bottom-window.innerHeight+80:e.topn)return u(e,this.pattern,o);var r=this.options,i=r.location,s=r.distance,a=r.threshold,l=r.findAllMatches,c=r.minMatchCharLength;return h(e,this.pattern,this.patternAlphabet,{location:i,distance:s,threshold:a,findAllMatches:l,minMatchCharLength:c})}}]),m}();e.exports=r},function(e,t,n){"use strict";var u=n(0);e.exports=function(e,t){return function e(t,n,o){if(n){var r=n.indexOf("."),i=n,s=null;-1!==r&&(i=n.slice(0,r),s=n.slice(r+1));var a=t[i];if(null!=a)if(s||"string"!=typeof a&&"number"!=typeof a)if(u(a))for(var l=0,c=a.length;l 0 and <= 1");d=d.name}else a[d]={weight:1};this._analyze({key:d,value:this.options.getFn(u,d),record:u,index:l},{resultMap:o,results:r,tokenSearchers:e,fullSearcher:t})}return{weights:a,results:r}}},{key:"_analyze",value:function(e,t){var n=e.key,o=e.arrayIndex,r=void 0===o?-1:o,i=e.value,s=e.record,a=e.index,l=t.tokenSearchers,c=void 0===l?[]:l,u=t.fullSearcher,h=void 0===u?[]:u,p=t.resultMap,d=void 0===p?{}:p,f=t.results,v=void 0===f?[]:f;if(null!=i){var g=!1,_=-1,m=0;if("string"==typeof i){this._log("\nKey: "+(""===n?"-":n));var y=h.search(i);if(this._log('Full text: "'+i+'", score: '+y.score),this.options.tokenize){for(var k=i.split(this.options.tokenSeparator),b=[],x=0;x=c.length;if(this._log("\nCheck Matches: "+O),(g||y.isMatch)&&O){var 
P=d[a];P?P.output.push({key:n,arrayIndex:r,value:i,score:T,matchedIndices:y.matchedIndices}):(d[a]={item:s,output:[{key:n,arrayIndex:r,value:i,score:T,matchedIndices:y.matchedIndices}]},v.push(d[a]))}}else if(U(i))for(var j=0,E=i.length;jRaskellCore
raskell-0.0.0.2: RASP-L in Haskell
Safe HaskellSafe-Inferred
LanguageHaskell2010

RaskellCore

Description

This module provides the core of the RASP-L language.

It is based on Listing 2 of 2 | "What Algorithms Can Transformers Learn", https://arxiv.org/abs/2310.16028, 3 | by Zhou et al.

Types

Most of these types are merely aliases.

Tokens and Sequences

type Token = Int8 Source #

A Token in a Sequence is a small integer. 4 | RASP-L uses Int8 to ensure all maps of type Token -> Token are learnable.

type Sequence = [Token] Source #

A Sequence is a list of Tokens.

type Keys = Sequence Source #

A collection of keys is a list of Tokens.

type Queries = Sequence Source #

A collection of queries is a list of Tokens.

type Values = Sequence Source #

A collection of values is a list of Tokens.

Predicates and Selectors

type Predicate = Token -> Token -> Bool Source #

We can compare Keys and Queries to determine if they match.

type Selector = [BoolSequence] Source #

The equivalents of "attention maps" are collections of Boolean sequences.

type BoolSequence = [Bool] Source #

Internally, we sometimes need to operate on collections of Bools.

Aggregation

type Aggregator = Selector -> Values -> Sequence Source #

Type alias for "fully-specified" aggregators that are ready to aggregate a sequence of values with a selector.

data AggregationType Source #

Enum for the three methods for aggregating selected values.

Constructors

Min 
Mean 
Max 

Functions

Key-Query-Value lookup

kqv Source #

Arguments

:: Token

Filler token used in aggregation

-> AggregationType

Type of aggregation (Min, Mean, Max)

-> Keys

Sequence of keys

-> Queries

Sequence of queries

-> Predicate

A boolean predicate that determines whether a key and query match

-> Values

Sequence of values

-> Sequence

The output sequence

Performs a key-query-value lookup operation and aggregates over values.

Given a filler token, an aggregation type, two sequences (keys and queries), 5 | a predicate, and a sequence of values, it returns a processed sequence. It first selects 6 | key-query pairs that satisfy the predicate and then aggregates the corresponding values.

Roughly matches the attention layer of a Transformer.

maxKQV :: Keys -> Queries -> Predicate -> Values -> Sequence Source #

Performs Key-Query-Value lookup with maximum aggregation of values.

minKQV :: Keys -> Queries -> Predicate -> Values -> Sequence Source #

Performs Key-Query-Value lookup with minimum aggregation of values.

selectCausal :: Keys -> Queries -> Predicate -> Selector Source #

Compares pairs of elements from sequences with a predicate subject to a causal constraint.

filledWith :: Sequence -> Token -> Sequence Source #

Creates a matched-length constant sequence with the provided token.

indicesOf :: Sequence -> Sequence Source #

Extracts the indices of the elements in a sequence.

Aggregation

aggregate :: AggregationType -> Token -> Aggregator Source #

Aggregates values with some aggregation, filling in with a default token.

aggrMax :: Token -> Aggregator Source #

Aggregates values by selecting the largest value.

aggrMean :: Token -> Aggregator Source #

Aggregates values by taking the mean.

aggrMin :: Token -> Aggregator Source #

Aggregates values by selecting the smallest value.

Selection

selWidth :: Selector -> Sequence Source #

Computes the "width", or number of True entries, of each row of a Selector.

Compatibility

These functions are provided for closer compatibility with the original Python implementation of RASP-L.

tokMap :: (Token -> Token) -> Sequence -> Sequence Source #

Applies an elementwise operation to a sequence of tokens.

Roughly matches the MLP layer in a Transformer. Alias for map.

seqMap :: (Token -> Token -> Token) -> Sequence -> Sequence -> Sequence Source #

Applies an elementwise operation for pairs of tokens on a pair of sequences. 7 | Alias for zipWith.

full :: Sequence -> Token -> Sequence Source #

Creates a sequence of the same length as the provided sequence filled with the provided token. 8 | Alias for filledWith.

indices :: Sequence -> Sequence Source #

Extracts the indices of the elements in a sequence. 9 | Alias for indicesOf.

aggr :: AggregationType -> Token -> Aggregator Source #

Creates an aggregator with a given aggregation type. 10 | Alias for aggregate.

select Source #

Arguments

:: Bool

Whether to use causal selection

-> Keys

A collection of Keys to check against Queries

-> Queries

A collection of Queries to check against Keys

-> Predicate

A boolean predicate that determines whether a key and query match

-> Selector

A collection of boolean sequences indicating which pairs of Keys and Queries match

Produces a selector indicating which pairs of Keys and Queries match.

--------------------------------------------------------------------------------