├── .gitignore ├── LICENSE ├── README.md ├── codegen-test.rkt ├── codegen.rkt ├── cuda-synth.rkt ├── cuda.rkt ├── ex1-transpose.rkt ├── ex2-conv1d.rkt ├── ex2-stencil.rkt ├── ex2-stencil2d.rkt ├── ex3-poly-mult-load-only.rkt ├── ex3-poly-mult-noacc.rkt ├── ex3-poly-mult.rkt ├── ex4-aos-sum-noacc.rkt ├── ex4-aos-sum.rkt ├── ex5-aos-pure-load-sol.rkt ├── ex5-aos-pure-load.rkt └── util.rkt /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *.o 3 | *# 4 | *.out 5 | *.sass 6 | *.lib 7 | *.ptx 8 | *.ii 9 | *.stub.c 10 | *.fatbin 11 | *.fatbin.c 12 | *.module_id 13 | *.cubin 14 | *.reg.c 15 | *.gpu -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018-2019, University of California, Berkeley. 2 | 3 | Authored by Phitchaya Mangpo Phothilimthana. 4 | 5 | Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | 1. Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | 2. Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 21 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 23 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 24 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 25 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## References 2 | [Swizzle Inventor: Data Movement Synthesis for GPU Kernels, ASPLOS 2019](https://mangpo.net/papers/swizzle-inventor-asplos19.pdf) 3 | 4 | ## License 5 | Refer to [LICENSE](LICENSE) for the license and copyright information for this project. 6 | 7 | ## Software Prerequisites 8 | * [Racket](https://racket-lang.org/download/) 9 | * [Rosette 2.x](https://github.com/emina/rosette/releases/tag/2.2). Note: Swizzle Inventor has not been tested with Rosette 3.x. 10 | 11 | ## Running Synthesizer 12 | 13 | #### 1D stencil 14 | ex2-stencil.rkt 15 | 16 | #### 1D conv 17 | ex2-conv1d.rkt 18 | 19 | #### 2D conv 20 | ex2-stencil2d.rkt 21 | 22 | #### Finite field multiplication 23 | ex3-poly-mult.rkt 24 | 25 | #### AoS-load-sum 26 | ex4-aos-sum.rkt 27 | 28 | #### AoS-load-store 29 | ex5-aos-pure-load.rkt 30 | 31 | ## Racket to CUDA Code Generator 32 | codegen-test.rkt 33 | -------------------------------------------------------------------------------- /codegen-test.rkt: -------------------------------------------------------------------------------- 1 | #| 2 | | Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved. 
3 | | 4 | | Redistribution and use in source and binary forms, with or without 5 | | modification, are permitted provided that the following conditions are met: 6 | | 7 | | 1. Redistributions of source code must retain the above copyright notice, 8 | | this list of conditions and the following disclaimer. 9 | | 10 | | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | | this list of conditions and the following disclaimer in the documentation 12 | | and/or other materials provided with the distribution. 13 | | 14 | | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 17 | | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 18 | | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 19 | | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 20 | | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 21 | | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 22 | | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 23 | | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 24 | | POSSIBILITY OF SUCH DAMAGE. 
|#

#lang racket

;; Example driver for the Racket-to-CUDA code generator: translates one
;; quoted kernel with racket2cuda and prints the resulting CUDA lines.

(require "codegen.rkt")

(define struct-size 2)

;; Symbol -> value map handed to racket2cuda as #:const-map.
(define codegen-consts
  (hash 'struct-size struct-size
        'warpSize 32
        'n 64))

;; The kernel to translate. It is quoted data that racket2cuda pattern-matches;
;; it is not evaluated here.
(define aos-kernel
  '(define (AOS-loadsh-sketch-fan threadId blockID blockDim I O a b c)
     (define I-cached (create-matrix-local (x-y-z struct-size)))
     (define localId (modulo (get-x threadId) 32))
     (define offset
       (* struct-size (- (+ (* blockID blockDim) (get-x threadId)) localId)))
     (global-to-local
      I
      I-cached
      (x-y-z 1)
      offset
      (x-y-z (* warpSize struct-size))
      #f
      #:round
      struct-size
      #:shfl
      (lambda (localId i)
        (sw-xform localId warpSize 2 16 32 -1 i struct-size 0 1 17)))
     (define O-cached
       (permute-vector
        I-cached
        struct-size
        (lambda (i) (sw-xform i struct-size 1 2 2 1 localId warpSize 0 16 1))))
     (local-to-global
      O-cached
      O
      (x-y-z 1)
      offset
      (x-y-z (* warpSize struct-size))
      #f
      #:round
      struct-size
      #:shfl
      (lambda (localId i)
        (sw-xform localId warpSize 0 1 32 -1 i struct-size 15 1 16)))))

;; Translate as a 1-D kernel and print the generated CUDA.
(define cuda-lines (racket2cuda aos-kernel 1 #:const-map codegen-consts))
(print-cuda cuda-lines)
--------------------------------------------------------------------------------
/codegen.rkt:
--------------------------------------------------------------------------------
#|
| Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
|
| Redistribution and use in source and binary forms, with or without
| modification, are permitted provided that the following conditions are met:
|
| 1. Redistributions of source code must retain the above copyright notice,
| this list of conditions and the following disclaimer.
|
| 2. Redistributions in binary form must reproduce the above copyright notice,
| this list of conditions and the following disclaimer in the documentation
| and/or other materials provided with the distribution.
|
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
| ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
| CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
| SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
| INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
| CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
| ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
| POSSIBILITY OF SUCH DAMAGE.
|#

#lang racket

;; Racket-to-CUDA code generator: translates a quoted Racket kernel
;; s-expression into a nested list of CUDA source lines.

(provide racket2cuda print-cuda load-store?)

;; C types used for data elements and for index/temporary variables.
(define data-type "int")
(define index-type "int")

;; Translation state. These are set by racket2cuda / convert-statement and are
;; not cleared between calls to racket2cuda.
(define dims 1)            ; kernel dimensionality passed to racket2cuda
(define thread-id #f)      ; kernel parameter symbol bound to the thread id
(define block-id #f)       ; kernel parameter symbol bound to the block id
(define block-size #f)     ; kernel parameter symbol bound to the block size

(define env (make-hash))          ; symbol -> number of components (1 = scalar)
(define matrix-size (make-hash))  ; matrix name -> list of dimension sizes
(define cuda-vars (make-hash))    ; kernel symbol -> CUDA builtin ("threadIdx", ...)
(define env-consts (hash 'struct-size 3 'warpSize 32 'n 64)) ; constant symbol -> value
(define accumulators (make-hash)) ; accumulator -> (cons C-operator-list final-op)
(define acc-replace (make-hash))  ; final-op lambda argument -> accumulator symbol
(define temps (make-hash))        ; let-bound symbol -> its (simplified) expression

(define (cuda-var? v)
  (hash-has-key? cuda-vars v))
(define (load-store? f)
  (member f '(global-to-local local-to-global global-to-shared shared-to-global)))
(define (load? f)
  (member f '(global-to-local global-to-shared)))
(define (store? f)
  (member f '(local-to-global shared-to-global)))
(define (let? x)
  (member x '(let let*)))

;; Value of v in env-consts, or v itself when it is not a known constant.
(define (eval-const v)
  (if (hash-has-key? env-consts v)
      (hash-ref env-consts v)
      v))

;; Turn a Racket datum into C text: numbers pass through, booleans become
;; 1/0, and symbols/strings have "-" replaced with "_".
(define (sanitize x)
  (cond
    [(number? x) x]
    [(equal? x #t) 1]
    [(equal? x #f) 0]
    [(string? x) (string-replace x "-" "_")]
    [else (string-replace (symbol->string x) "-" "_")]))

;; Print every string of a (possibly nested) list of CUDA lines, one per line.
(define (print-cuda l)
  (for ([s (flatten l)])
    (pretty-display s)))

;; Convert a quoted Racket kernel definition to CUDA.
;; expr      -- s-expression of the form
;;              (define (name thread-id block-id block-dim arg ...) body ...)
;; d         -- kernel dimension (the code handles 1, 2 and 3)
;; const-map -- optional hash of constant symbol -> value; when given it
;;              replaces env-consts
;; Returns a nested list of CUDA lines (see print-cuda). Every kernel argument
;; is declared as `const <data-type> *`.
(define (racket2cuda expr d #:const-map [const-map #f])
  (set! dims d)
  (match expr
    [(list 'define (list func-name tid bid bsize args ...) body ...)
     (set! thread-id tid)
     (set! block-id bid)
     (set! block-size bsize)
     (hash-set! env tid d)
     (hash-set! env bid d)
     (hash-set! env bsize d)
     (hash-set! cuda-vars tid "threadIdx")
     (hash-set! cuda-vars bid "blockIdx")
     (hash-set! cuda-vars bsize "blockDim")
     (when const-map (set! env-consts const-map))

     (define args-str
       (for/list ([arg args])
         (format "const ~a *~a" data-type (sanitize arg))))
     (define declare
       (format-indent "__global__ void ~a(~a) {" (sanitize func-name) (string-join args-str ", ")))
     (inc-indent)
     (define body-ret (for/list ([st body]) (convert-statement st)))
     (dec-indent)

     (list declare
           body-ret
           (format-indent "}"))]))


;; Current indentation prefix and the text added/removed per nesting level.
;; inc-indent/dec-indent return '() so they can sit inside line lists, where
;; print-cuda's flatten drops them.
(define indent-step "  ")
(define indent-str "")
(define (inc-indent)
  (set! indent-str (string-append indent-str indent-step))
  (list))
(define (dec-indent)
  (set! indent-str (substring indent-str (string-length indent-step)))
  (list))
(define-syntax-rule (format-indent s args ...)
  (string-append indent-str (format s args ...)))

;; Default warp shape for the current kernel dimensionality.
(define (default-warp-shape)
  (cond [(= dims 1) 32]
        [(= dims 2) '(x-y-z 32 1)]
        [(= dims 3) '(x-y-z 32 1 1)]))

;; C text for a loop bound: symbols/numbers are sanitized; compound
;; expressions are translated with convert-expr.
(define (bound-str l)
  (if (pair? l)
      (let-values ([(n f) (convert-expr l)]) (f 0))
      (sanitize l)))

;; Linearize a multi-component index as
;; ((idx[k-1] * size[k-2]) + idx[k-2]) * ... + idx[0], so idx[0] varies fastest.
(define (linear-index-str idx size)
  (define-values (idx-n idx-f) (convert-expr idx))
  (define-values (size-n size-f) (convert-expr size))
  (for/fold ([ans (idx-f (sub1 idx-n))])
            ([i (in-range (- idx-n 2) -1 -1)])
    (format "(~a * ~a) + ~a" ans (size-f i) (idx-f i))))

;; Translate one kernel statement into a CUDA line or a nested list of lines.
;; Raises a match error for unsupported statement forms.
(define (convert-statement st)
  (match st
    [(list 'define matrix (list 'create-matrix-local (list 'x-y-z sizes ...) '#:type type))
     (hash-set! matrix-size matrix sizes)
     (format-indent "~a ~a~a;" type (sanitize matrix) (dims-str sizes))]

    [(list 'define matrix (list 'create-matrix-local (list 'x-y-z sizes ...)))
     (hash-set! matrix-size matrix sizes)
     (format-indent "~a ~a~a;" data-type (sanitize matrix) (dims-str sizes))]

    [(list 'define-shared matrix (list 'create-matrix (list 'x-y-z sizes ...) '#:type type))
     (hash-set! matrix-size matrix sizes)
     (format-indent "__shared__ ~a ~a~a;" type (sanitize matrix) (dims-str sizes))]

    [(list 'define-shared matrix (list 'create-matrix (list 'x-y-z sizes ...)))
     (hash-set! matrix-size matrix sizes)
     (format-indent "__shared__ ~a ~a~a;" data-type (sanitize matrix) (dims-str sizes))]

    ;; Vector permutation by rotation. A compile-time rotation amount gives a
    ;; single copy loop; otherwise log2(n) conditional rotation steps are
    ;; emitted, using temporaries _<x>1, _<x>2, ...
    [(list 'define y (list 'permute-vector x size
                           (list 'lambda (list i) (list 'sw-xform sw-args ...))))
     (define statements (list))
     (define (add-st s) (set! statements (cons s statements)))

     (define n (eval-const size))
     (define i-str (sanitize i)) ; same spelling convert-expr uses inside i-expr

     (add-st (convert-statement `(define ,y (make-vector ,size))))
     (hash-set! env i 1)
     (define-values (i-expr j-expr) (apply convert-sw-xform2 sw-args))

     (cond
       [(string->number j-expr)
        (add-st (rotate-one-step (sanitize x) (sanitize y) n i-str i-expr j-expr))]

       [else
        (add-st (format-indent "{"))
        (inc-indent)

        (define (def-loop skip)
          (when (< skip n)
            (add-st (convert-statement
                     `(define ,(format "_~a~a" x (/ skip 2))
                        (make-vector ,size))))
            (def-loop (* 2 skip))))
        (def-loop 2)
        (add-st (format-indent "int rot = (~a) % ~a;" j-expr n))

        (define (loop skip)
          (when (< skip n)
            (add-st (rotate-log-step (sanitize x) (sanitize y) n skip 'rot i-str i-expr
                                     (>= (* 2 skip) n)))
            (loop (* 2 skip))))

        (loop 1)
        (dec-indent)
        (add-st (format-indent "}"))])
     (reverse statements)]

    [(list 'define acc (list 'create-accumulator (list 'list op-list ...) final-op block-dim))
     (hash-set! accumulators acc (cons (map convert-op op-list) final-op))
     (format-indent "~a ~a = 0;" data-type (sanitize acc))] ;; TODO: where to insert final-op?

    [(list 'accumulate acc l)
     (convert-statement (list 'accumulate acc l '#:pred #t))]

    ;; acc <last-op>= combination of l using the remaining operators; guarded
    ;; by `if` unless the predicate simplifies to #t (or dropped if #f).
    [(list 'accumulate acc l* '#:pred pred*)
     (define pred (simplify pred*))
     (cond
       [(equal? pred #f)
        (list)]
       [else
        (define l
          (match l*
            [(list 'list x ...) x]
            [_ l*]))
        (define op-list (car (hash-ref accumulators acc)))
        (define res (accumulate l (cdr (reverse op-list))))
        ;; Sanitize the name so it matches the create-accumulator declaration.
        (define st (format "~a ~a= ~a;" (sanitize acc) (last op-list) res))
        (cond
          [(equal? pred #t)
           (format-indent "~a" st)]
          [else
           (define-values (pred-n pred-f) (convert-expr pred))
           (format-indent "if(~a) ~a" (pred-f 0) st)])])]

    [(list 'define matrix (list 'make-vector size))
     (hash-set! matrix-size matrix (list size))
     (format-indent "~a ~a~a;" data-type (sanitize matrix) (dims-str (list size)))]

    ;; Scalar definition; a multi-component value becomes var0, var1, ...
    [(list 'define var e)
     (define-values (n f) (convert-expr e))
     (hash-set! env var n)
     (if (= n 1)
         (format-indent "~a ~a = ~a;" index-type (sanitize var) (f 0))
         (for/list ([i n])
           (format-indent "~a ~a~a = ~a;" index-type (sanitize var) i (f i))))]

    [(list 'global-to-reg global reg idx)
     (format-indent "~a = ~a~a;" (sanitize reg) (sanitize global) (dims-str (list idx)))]

    [(list 'global-to-reg global reg idx '#:size size)
     (format-indent "~a = ~a[~a];" (sanitize reg) (sanitize global) (linear-index-str idx size))]

    [(list 'reg-to-global reg global idx)
     (define-values (reg-n reg-f) (convert-expr reg))
     (format-indent "~a~a = ~a;" (sanitize global) (dims-str (list idx)) (reg-f 0))]

    [(list 'reg-to-global reg global idx '#:size size)
     (define-values (reg-n reg-f) (convert-expr reg))
     (format-indent "~a[~a] = ~a;" (sanitize global) (linear-index-str idx size) (reg-f 0))]

    ;; Load/store forms. Missing #:warp-shape / #:round / #:size options are
    ;; filled in (default warp shape, round 1, size 1) and normalized onto
    ;; the fully specified form below.
    [(list (? load-store? f) A B stride offset size transpose)
     (convert-statement (list f A B stride offset size transpose
                              '#:warp-shape (default-warp-shape)))]

    [(list (? load-store? f) A B stride offset size transpose '#:warp-shape warp-shape)
     (define local (if (load? f) B A))
     (convert-global-to-local (sanitize f)
                              A B 1 stride offset size transpose warp-shape
                              (hash-ref matrix-size local))]

    [(list (? load-store? f) A B stride offset size transpose '#:round round)
     (convert-statement (list f A B stride offset size transpose '#:round round '#:size 1))]

    [(list (? load-store? f) A B stride offset size transpose '#:size gsize)
     (convert-statement (list f A B stride offset size transpose '#:round 1 '#:size gsize))]

    [(list (? load-store? f) A B stride offset size transpose '#:round round '#:size gsize)
     (convert-statement (list f A B stride offset size transpose
                              '#:warp-shape (default-warp-shape) '#:round round '#:size gsize))]

    [(list (? load-store? f) A B stride offset size transpose '#:warp-shape warp-shape '#:round round)
     (convert-statement (list f A B stride offset size transpose
                              '#:warp-shape warp-shape '#:round round '#:size 1))]

    [(list (? load-store? f) A B stride offset size transpose '#:warp-shape warp-shape '#:size gsize)
     (convert-statement (list f A B stride offset size transpose
                              '#:warp-shape warp-shape '#:round 1 '#:size gsize))]

    [(list (? load-store? f) A B stride offset size transpose
           '#:warp-shape warp-shape '#:round round '#:size gsize)
     (define local (if (load? f) B A))
     (convert-global-to-local (sanitize f)
                              A B round stride offset size transpose warp-shape
                              (hash-ref matrix-size local) #:size gsize)]

    ;; Load/store with a shuffle permutation: the permutation is emitted as a
    ;; C++ lambda perm_<A> and passed to the *_shfl helper.
    [(list (? load-store? f) A B stride offset size transpose '#:round round
           '#:shfl (list 'lambda (list tid i) (list 'sw-xform sw-args ...)))
     (define-values (sw-xform-n sw-xform-f) (apply convert-sw-xform sw-args))
     (define local (if (load? f) B A))
     (define perm-name (format "perm_~a" (sanitize A)))
     (list
      ;; Parameter names are sanitized like their uses in the body expression.
      (format-indent "auto ~a = [=] (int ~a, int ~a) -> int{ return ~a; };"
                     perm-name (sanitize tid) (sanitize i) (sw-xform-f 0))
      (convert-global-to-local (sanitize f)
                               A B round stride offset size transpose (default-warp-shape)
                               (hash-ref matrix-size local) #:shfl perm-name))]

    ;; Parallel loop variables in one C `for`; clauses are emitted in reverse
    ;; binding order.
    [(list 'for (list (list vs ls) ...) body)
     (for ([v vs]) (hash-set! env v 1))
     (define-values (inits conds incs)
       (for/fold ([inits '()] [conds '()] [incs '()])
                 ([v vs] [l ls])
         (let ([x (sanitize v)])
           (values (cons (format "~a = 0" x) inits)
                   (cons (format "(~a < ~a)" x (bound-str l)) conds)
                   (cons (format "~a++" x) incs)))))
     (define start
       (format-indent "for(int ~a; ~a; ~a) {"
                      (string-join inits ",") (string-join conds "&&") (string-join incs ",")))
     (inc-indent)
     (define body-ret (convert-statement body))
     (dec-indent)
     (define end (format-indent "}"))
     (list start body-ret end)]

    ;; Nested loops, outermost first.
    [(list 'for* (list (list vs ls) ...) body)
     (for ([v vs]) (hash-set! env v 1))
     (append
      (for/list ([v vs] [l ls])
        (let* ([x (sanitize v)]
               [header (format-indent "for(int ~a = 0; ~a < ~a; ~a++) {" x x (bound-str l) x)])
          (inc-indent)
          header))
      (list (convert-statement body))
      (for/list ([v vs])
        (dec-indent)
        (format-indent "}")))]

    ;; let / let*: each binding becomes a local; bindings are visible to
    ;; ite-const? only while the body is translated.
    [(list (? let?) (list (list vs es) ...) body ...)
     (define ret
       (append
        (for/list ([e es] [v vs])
          (let ([e (simplify e)])
            (let-values ([(n f) (convert-expr e)])
              (hash-set! env v n)
              (hash-set! temps v e)
              (if (= n 1)
                  (format-indent "~a ~a = ~a;" index-type (sanitize v) (f 0))
                  (for/list ([i n])
                    (format-indent "~a ~a~a = ~a;" index-type (sanitize v) i (f i)))))))
        (for/list ([st body])
          (convert-statement st))))
     (for ([v vs]) (hash-remove! temps v))
     ret]

    [(list 'set matrix idxs ... v)
     (define-values (v-n v-f) (convert-expr v))
     (format-indent "~a~a = ~a;" (sanitize matrix) (dims-str idxs) (v-f 0))]))

;; Emit a call to the CUDA helper <name>[_shfl]<d><data-type>(...) that moves
;; data between A and B. d is the larger component count of stride/offset.
;; For d > 1, the global sizes and local sizes of the first d-1 dimensions are
;; appended. load-size and transpose are accepted but not used.
(define (convert-global-to-local name A B round stride offset load-size transpose warp-shape local-size
                                 #:shfl [shfl #f] #:size [gsize 1])
  (define-values (stride-n stride-f) (convert-expr stride))
  (define-values (offset-n offset-f) (convert-expr offset))
  (define-values (shape-n shape-f) (convert-expr warp-shape))
  (define-values (round-n round-f) (convert-expr round))
  (define d (max stride-n offset-n))
  (define size-str
    (cond
      [(> d 1)
       (define-values (gsize-n gsize-f) (convert-expr gsize))
       (define l
         (append
          (for/list ([i (sub1 d)]) (gsize-f i))
          (for/list ([i (sub1 d)]) (format "~a" (list-ref local-size i)))))
       (format ",~a" (string-join l ","))]
      [else ""]))
  ;; First d components of a converted value, comma separated.
  (define (components f) (string-join (for/list ([i d]) (f i)) ", "))
  (format-indent "~a~a~a<~a>((~a*) ~a, (~a*) ~a,~a,~a,~a,~a~a~a);"
                 name (if shfl "_shfl" "") d data-type
                 data-type (sanitize A) data-type (sanitize B)
                 (components round-f) (components offset-f)
                 (components stride-f) (components shape-f)
                 (if shfl (format ",~a" shfl) "")
                 size-str))

;; Linearized thread index within the block as C text, for the current dims.
(define (linear-thread-id-str)
  (cond
    [(= dims 1) "threadIdx.x"]
    [(= dims 2) "(threadIdx.y*blockDim.x + threadIdx.x)"]
    [(= dims 3) "(threadIdx.z*blockDim.y*blockDim.x + threadIdx.y*blockDim.x + threadIdx.x)"]))

;; Translate an expression. Returns two values: the number of components n,
;; and a function from component index i (0 <= i < n) to C text.
(define (convert-expr expr)
  (match expr
    [(list 'get-warpId tid)
     (values 1 (lambda (i) (format "(~a/32)" (linear-thread-id-str))))]

    ;; Lane id uses the same linearized thread index as get-warpId.
    [(list 'get-idInWarp tid)
     (values 1 (lambda (i) (format "(~a&31)" (linear-thread-id-str))))]

    [(list 'get-global-threadId tid bid)
     (values
      dims
      (lambda (i)
        (cond
          [(= i 0) "(blockIdx.x * blockDim.x + threadIdx.x)"]
          [(= i 1) "(blockIdx.y * blockDim.y + threadIdx.y)"]
          [(= i 2) "(blockIdx.z * blockDim.z + threadIdx.z)"])))]

    [(list 'shfl e lane)
     (define-values (e-n e-f) (convert-expr e))
     (define-values (lane-n lane-f) (convert-expr lane))
     (values 1 (lambda (i) (format "__shfl_sync(FULL_MASK, ~a, ~a)" (e-f 0) (lane-f 0))))]

    ;; An ite inside an index is hoisted: c ? M[..a..] : M[..b..].
    [(list 'get matrix idxs1 ... (list 'ite c a b) idxs2 ...)
     (define-values (c-n c-f) (convert-expr c))
     (define-values (geta-n geta-f) (convert-expr `(get ,matrix ,@idxs1 ,a ,@idxs2)))
     (define-values (getb-n getb-f) (convert-expr `(get ,matrix ,@idxs1 ,b ,@idxs2)))
     (values 1 (lambda (i) (format "~a? ~a: ~a" (c-f 0) (geta-f 0) (getb-f 0))))]

    [(list 'get matrix idxs ...)
     (cond
       [(ormap ite-const? idxs)
        ;; Substitute let-bound ite temporaries so the clause above hoists them.
        (define new-idxs
          (for/list ([idx idxs])
            (if (ite-const? idx) (hash-ref temps idx) idx)))
        (convert-expr `(get ,matrix ,@new-idxs))]
       [else
        ;; Convert the raw index expressions. Re-converting already-rendered
        ;; strings would pass them through sanitize and turn "-" into "_".
        (define idx-str (dims-str idxs))
        (values 1 (lambda (i) (format "~a~a" (sanitize matrix) idx-str)))])]

    [(list 'sw-xform j n* cj* dj* group* conf-fw
           k m* ck* dk*)
     (convert-sw-xform j n* cj* dj* group* conf-fw
                       k m* ck* dk* 0)]

    [(list 'sw-xform j n* cj* dj* group* conf-fw
           k m* ck* dk* offset)
     (convert-sw-xform j n* cj* dj* group* conf-fw
                       k m* ck* dk* offset)]

    ;; Apply the accumulator's final-op: identity, or a one-argument lambda
    ;; whose argument is rendered as the accumulator's name.
    [(list 'accumulate-final acc)
     (define final-op (cdr (hash-ref accumulators acc)))
     (match final-op
       ['identity (convert-expr acc)]
       [(list 'lambda (list arg) body)
        (hash-set! acc-replace arg acc)
        (define-values (n f) (convert-expr body))
        (hash-remove! acc-replace arg)
        (values n f)])]

    [(list 'get-x v)
     (define-values (n f) (convert-expr v))
     (values 1 (lambda (i) (f 0)))]

    [(list 'get-y v)
     (define-values (n f) (convert-expr v))
     (values 1 (lambda (i) (f 1)))]

    [(list 'get-z v)
     (define-values (n f) (convert-expr v))
     (values 1 (lambda (i) (f 2)))]

    [(list '@dup x)
     (convert-expr x)]

    [(list 'x-y-z xs ...)
     (values (length xs)
             (lambda (i)
               (define-values (n f) (convert-expr (list-ref xs i)))
               (f 0)))]

    [(list 'ite c a b)
     (define new-ite (simplify expr))
     (match new-ite
       [(list 'ite _ _ _)
        (define-values (c-n c-f) (convert-expr c))
        (define-values (a-n a-f) (convert-expr a))
        (define-values (b-n b-f) (convert-expr b))
        (define max-d (max a-n b-n c-n))
        (values max-d
                (lambda (i)
                  (format "~a? ~a: ~a" (c-f i) (a-f i) (b-f i))))]
       [else (convert-expr new-ite)])]

    ;; Generic operator application, rendered as a parenthesized infix chain.
    [(list op args ...)
     (match (simplify expr)
       [(list op args ...)
        (define max-d 1)
        (define fs
          (for/list ([arg args])
            (let-values ([(n f) (convert-expr arg)])
              (when (> n max-d) (set! max-d n))
              f)))
        (values max-d
                (lambda (i)
                  (define parts (for/list ([f fs]) (f i)))
                  (if (and (eq? op '-) (= (length parts) 1))
                      ;; Unary minus: simplify keeps (- x) for symbolic x.
                      (format "(-~a)" (car parts))
                      (format "(~a)" (string-join parts (convert-op op))))))]
       [v (convert-expr v)])]

    [(? cuda-var? v)
     (define name (hash-ref cuda-vars v))
     (values dims
             (lambda (i)
               (cond [(= i 0) (format "~a.x" name)]
                     [(= i 1) (format "~a.y" name)]
                     [(= i 2) (format "~a.z" name)])))]

    [v
     (define d (if (hash-has-key? env v) (hash-ref env v) 1))
     (cond
       [(hash-has-key? acc-replace v)
        (define name (hash-ref acc-replace v))
        (values d (lambda (i) (format "~a" (sanitize name))))]
       [(= d 1) (values d (lambda (i) (format "~a" (sanitize v))))]
       [else (values d (lambda (i) (format "~a~a" (sanitize v) i)))])]))

;; #t when e is a let-bound temporary whose bound expression is an ite.
(define (ite-const? e)
  (and (hash-has-key? temps e)
       (match (hash-ref temps e)
         [(list 'ite _ _ _) #t]
         ;[(list 'ite c (? number?) (? number?)) #t]
         [_ #f])))

;; Render nested value lists: level k is joined with the k-th C operator in
;; ops; leaves are converted with convert-expr.
(define (accumulate vals ops)
  (cond
    [(null? ops)
     (define-values (v-n v-f) (convert-expr vals))
     (v-f 0)]
    [else
     (string-join (for/list ([v vals]) (accumulate v (cdr ops)))
                  (car ops))]))

;; Translate (sw-xform j n cj dj group conf-fw k m ck dk [offset]) to C.
;; The index expression is built symbolically with the @-helpers, which fold
;; away trivial terms, and then rendered with convert-expr.
(define (convert-sw-xform j n* cj* dj* group* conf-fw
                          k m* ck* dk* [offset 0])
  (define n (eval-const n*))
  (define cj (eval-const cj*))
  (define dj (eval-const dj*))
  (define group (eval-const group*))
  (define m (eval-const m*))
  (define ck (eval-const ck*))
  (define dk (eval-const dk*))

  (define offset1-a
    (cond
      [(equal? dj n) 0]
      [(equal? group n) (@quotient j dj)]
      [else (@quotient (@modulo j group) dj)]))
  (define offset1-b (@* k ck))
  (define offset1-c
    (if (equal? dk m) 0 (@quotient k dk)))
  (define offset1 (@+ offset1-a offset1-b offset1-c offset))

  (define common
    (if (and (number? group) (number? dj))
        (quotient group dj)
        (@quotient group dj)))
  (define offset2
    (if (or (= conf-fw 1) (equal? common group))
        offset1
        (@modulo offset1 common)))

  (define group-offset
    (if (equal? group n) 0 (@* (@quotient j group) group)))

  (convert-expr (@+ group-offset
                    (@modulo (@+ (@* j cj) offset2) group))))

;; Variant of convert-sw-xform for permute-vector. It returns two C strings:
;; the j-dependent part of the index (j*cj + offset-j) and the rotation amount
;; that depends only on k and offset. It requires group = n, and raises
;; exn:fail for configurations it cannot split this way.
(define (convert-sw-xform2 j n* cj* dj* group* conf-fw
                           k m* ck* dk* [offset 0])
  (define n (eval-const n*))
  (define cj (eval-const cj*))
  (define dj (eval-const dj*))
  (define group (eval-const group*))
  (define m (eval-const m*))
  (define ck (eval-const ck*))
  (define dk (eval-const dk*))

  (unless (equal? group n)
    (raise (exn:fail
            (format "sw-xform function for permute-vector must have (~a) n = (~a) group." n group)
            (current-continuation-marks))))

  (define offset-j
    (cond
      [(equal? dj n) 0]
      [(equal? group n) (@quotient j dj)]
      [else (@quotient (@modulo j group) dj)]))

  (define offset1-b (@* k ck))
  (define offset1-c
    (if (equal? dk m) 0 (@quotient k dk)))
  (define offset-k (@+ offset1-b offset1-c offset))

  (define common
    (if (and (number? group) (number? dj))
        (quotient group dj)
        (@quotient group dj)))

  (unless (or (= conf-fw 1) (equal? common group))
    (unless (equal? common 1)
      (raise (exn:fail "sw-xform function for permute-vector: invalid conf-fw, group, dj."
                       (current-continuation-marks))))
    ;; The offsets are taken modulo common = 1 (see convert-sw-xform), so both
    ;; offset terms vanish.
    (set! offset-j 0)
    (set! offset-k 0))

  (define-values (j-n j-f) (convert-expr (@+ (@* j cj) offset-j)))
  (define-values (k-n k-f) (convert-expr offset-k))
  (values (j-f 0) (k-f 0)))

;; Single loop: y[i] = x[(i-expr + j-expr) % n].
(define (rotate-one-step x y n i i-expr j-expr)
  (list
   (format-indent "for(int ~a=0; ~a<~a; ~a++) {" i i n i)
   (inc-indent)
   (format-indent "~a[~a] = ~a[(~a+~a)%~a];" y i x i-expr j-expr n)
   (dec-indent)
   (format-indent "}")))

;; One step of the log-time rotation. The step rotates by `skip` when that bit
;; of `rot` is set. It reads from x (or _x<skip/2>) and writes to _x<skip>, or
;; to y on the last step, which also applies i-expr.
(define (rotate-log-step x* y* n skip rot i i-expr last-iter)
  (define x (if (= skip 1) x* (format "_~a~a" x* (/ skip 2))))
  (define y (if last-iter y* (format "_~a~a" x* skip)))
  (list
   (format-indent "for(int ~a=0; ~a<~a; ~a++) {" i i n i)
   (inc-indent)
   (format-indent "if((~a & ~a)==0) ~a[~a] = ~a[(~a)%~a];" rot skip y i x (if last-iter i-expr i) n)
   (format-indent "else ~a[~a] = ~a[(~a+~a)%~a];" y i x (if last-iter i-expr i) skip n)
   (dec-indent)
   (format-indent "}")))

;; Map a Racket operator to its C spelling; strings pass through unchanged.
(define (convert-op op)
  (match op
    ['quotient "/"]
    ['modulo "%"]
    ['bvand "&"]
    ['bvxor "^"]
    [(? string?) op]
    [x (symbol->string x)]))

;; C array subscripts for idxs, emitted in reverse order: (i j) -> "[j][i]".
(define (dims-str idxs)
  (string-append*
   (for/list ([idx (reverse idxs)])
     (let-values ([(n idx-f) (convert-expr idx)])
       (format "[~a]" (idx-f 0))))))

;; Symbolic arithmetic builders that fold identities (x+0, x*1, x*0, ...).
(define (@++ xs)
  (cond
    [(null? xs) 0]
    [(null? (cdr xs)) (car xs)]
    [else
     (define y (@++ (cdr xs)))
     (define x (car xs))
     (cond
       [(equal? x 0) y]
       [(equal? y 0) x]
       [else `(+ ,x ,y)])]))

;; (- x) is negation; (- x y ...) is x minus the sum of the rest.
(define (@-- xs)
  (cond
    [(null? (cdr xs))
     (define x (car xs))
     (if (number? x) (- x) `(- ,x))]
    [else
     (define y (@++ (cdr xs)))
     (define x (car xs))
     (cond
       [(and (equal? x 0) (number? y)) (- 0 y)]
       [(equal? y 0) x]
       [else `(- ,x ,y)])]))

(define (@** xs)
  (cond
    [(null? xs) 1]
    [(null? (cdr xs)) (car xs)]
    [else
     (define y (@** (cdr xs)))
     (define x (car xs))
     (cond
       [(or (equal? x 0) (equal? y 0)) 0]
       [(equal? x 1) y]
       [(equal? y 1) x]
       [else `(* ,x ,y)])]))

(define-syntax-rule (@+ x ...) (@++ (list x ...)))
(define-syntax-rule (@- x ...) (@-- (list x ...)))
(define-syntax-rule (@* x ...) (@** (list x ...)))

(define (@quotient x y)
  (cond
    [(equal? y 1) x]
    [(equal? x 0) 0]
    [else `(quotient ,x ,y)]))

(define (@modulo x y)
  (cond
    [(equal? y 1) 0]
    [else `(modulo ,x ,y)]))

;; Sign tests that are safe on symbolic (non-numeric) terms.
(define (positive-real? v) (and (real? v) (positive? v)))
(define (negative-real? v) (and (real? v) (negative? v)))

;; Light symbolic simplification: constant ite conditions, arithmetic
;; identities, a few trivially decidable comparisons, and substitution of
;; env-consts symbols.
(define (simplify e)
  (match e
    [(list 'ite c a b)
     (define new-c (simplify c))
     (cond
       [(equal? new-c #t) (simplify a)]
       [(equal? new-c #f) (simplify b)]
       [else `(ite ,new-c ,(simplify a) ,(simplify b))])]
    [(list '+ args ...) (@++ (map simplify args))]
    [(list '- args ...) (@-- (map simplify args))]
    [(list '* args ...) (@** (map simplify args))]
    [(list 'quotient x y) (@quotient (simplify x) (simplify y))]
    [(list 'modulo x y) (@modulo (simplify x) (simplify y))]
    [(list op a b)
     (match `(,op ,(simplify a) ,(simplify b))
       [(list '= x x) #t]
       [(list '>= x x) #t]
       [(list '<= x x) #t]
       [(list '> x x) #f]
       [(list '< x x) #f]
       [(list '= x (list '+ (? number?) x)) #f]
       [(list '<= x (list '+ (? positive-real?) x)) #t]
       [(list '< x (list '+ (? positive-real?) x)) #t]
       [(list '<= x (list '+ (? negative-real?) x)) #f]
       [(list '< x (list '+ (? negative-real?) x)) #f]
       [new-expr new-expr])]
    [(list '@dup x) (simplify x)]
    [_
     (if (hash-has-key? env-consts e)
         (hash-ref env-consts e)
         e)]))

--------------------------------------------------------------------------------
/cuda-synth.rkt:
--------------------------------------------------------------------------------
#|
| Copyright (c) 2018-2019, University of California, Berkeley.
| Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
|
| Redistribution and use in source and binary forms, with or without
| modification, are permitted provided that the following conditions are met:
|
| 1. Redistributions of source code must retain the above copyright notice,
| this list of conditions and the following disclaimer.
|
| 2. Redistributions in binary form must reproduce the above copyright notice,
| this list of conditions and the following disclaimer in the documentation
| and/or other materials provided with the distribution.
|
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
| ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 19 | | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 20 | | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 21 | | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 22 | | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 23 | | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 24 | | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 25 | | POSSIBILITY OF SUCH DAMAGE. 26 | |# 27 | 28 | #lang rosette 29 | 30 | (require rosette/lib/synthax) 31 | (require "util.rkt" "cuda.rkt") 32 | (provide ?? ?lane ?lane-mod 33 | ?sw-xform ?sw-xform-easy ?sw-xform-extra 34 | ?cond ?cond-easy 35 | ?warp-size ?warp-offset 36 | print-forms choose 37 | ID get-grid-storage collect-inputs check-warp-input num-regs vector-list-append 38 | unique unique-warp unique-list) 39 | 40 | ;; Condition swizzle (easy template, smaller search space) 41 | (define-synthax ?cond-easy 42 | ([(?cond-easy x ...) 43 | (choose #t #f 44 | ((choose < <= > >= =) (choose x ...) (choose x ...)))] 45 | 46 | )) 47 | 48 | ;; Condition swizzle (full template) 49 | (define-synthax ?cond 50 | ([(?cond x ... #:mod mod) 51 | (choose #t #f 52 | ((choose < <= > >= =) 53 | (choose x ...) 54 | (modulo (+ (* (??) (choose x ...)) (??)) mod) 55 | ))] 56 | 57 | ;[(?cond x ... [c ...]) 58 | ; (choose #t #f 59 | ; ((choose < <= > >= =) (choose x ...) 60 | ; ((choose + -) (?const warpSize c ...) (choose x ...))))] 61 | 62 | [(?cond x ...) 63 | (choose #t #f 64 | ((choose < <= > >= =) (choose x ...) 65 | ((choose + -) (?const warpSize) (choose x ...))))] 66 | 67 | )) 68 | ;; Index-expression template over x ... whose constants come from (?const c ...), i.e. 0, 1, -1 or one of c ...; recurses up to depth levels through quotient/modulo/* by a constant and +/- of two subterms. 69 | (define-synthax (?lane-c x ... [c ...] depth) 70 | #:base (choose x ... (@dup (?const c ...))) 71 | #:else (choose 72 | x ... (@dup (?const c ...)) 73 | ((choose quotient modulo *) (?lane-c x ... [c ...] (- depth 1)) (?const c ...)) ; recurse into ?lane-c so [c ...] stays the constant list 74 | ((choose + -) 75 | (?lane-c x ...
[c ...] (- depth 1)) 76 | (?lane-c x ... [c ...] (- depth 1))))) 77 | ;; Index-expression template over x ... with unconstrained (??) constants; recurses up to depth levels through quotient/modulo/* by a constant and +/- of two subterms. 78 | (define-synthax (?lane x ... depth) 79 | #:base (choose x ... (@dup (??))) 80 | #:else (choose 81 | x ... (@dup (??)) 82 | ((choose quotient modulo *) (?lane x ... (- depth 1)) (??)) 83 | ((choose + -) 84 | (?lane x ... (- depth 1)) 85 | (?lane x ... (- depth 1))))) 86 | 87 | ;; Naive template for transformation index swizzle 88 | (define-synthax ?lane-mod 89 | ([(?lane-mod x ... depth n [c ...]) 90 | (modulo (?lane-c x ... [c ...] depth) n)] 91 | 92 | [(?lane-mod x ... depth n) 93 | (modulo (?lane x ... depth) n)] 94 | 95 | )) 96 | 97 | ;; Proposed template for transformation index swizzle (easy template, smaller search space) 98 | (define-synthax ?sw-xform-easy 99 | ([(?sw-xform-easy eid n k m) 100 | (sw-xform eid n (??) n n 1 ;(choose 1 -1) 101 | k m (??) m (??))] 102 | 103 | [(?sw-xform-easy eid n k m #:fw conf-fw) 104 | (sw-xform eid n (??) n n conf-fw 105 | k m (??) m (??))] 106 | 107 | [(?sw-xform-easy eid n k m [c ...]) 108 | (sw-xform eid n (?const c ...) n n 1 ;(choose 1 -1) 109 | k m (?const c ...) m (?const m c ...))] 110 | 111 | [(?sw-xform-easy eid n k m [c ...] #:fw conf-fw) 112 | (sw-xform eid n (?const c ...) n n conf-fw 113 | k m (?const c ...) m (?const m c ...))] 114 | ) 115 | ) 116 | 117 | ;; Proposed template for transformation index swizzle (full template) 118 | (define-synthax ?sw-xform 119 | ([(?sw-xform eid n k m) 120 | (sw-xform eid n (??) (??) (??) (choose 1 -1) 121 | k m (??) (??) (??))] 122 | 123 | [(?sw-xform eid n k m #:fw conf-fw) 124 | (sw-xform eid n (??) (??) (??) conf-fw 125 | k m (??) (??) (??))] 126 | 127 | [(?sw-xform eid n k m [c ...]) 128 | (sw-xform eid n (?const c ...) (?const n c ...) (?const n c ...) (choose 1 -1) 129 | k m (?const c ...) (?const m c ...) (?const m c ...))] 130 | 131 | [(?sw-xform eid n k m [c ...] #:fw conf-fw) 132 | (sw-xform eid n (?const c ...) (?const n c ...) (?const n c ...) conf-fw 133 | k m (?const c ...)
(?const m c ...) (?const m c ...))] 134 | ) 135 | ) 136 | 137 | ;; Proposed template for transformation index swizzle (advanced template, bigger search space) 138 | (define-synthax ?sw-xform-extra 139 | ([(?sw-xform-extra eid n k m) 140 | (sw-xform eid n (??) (??) (??) (choose 1 -1) 141 | k m (??) (??) (??) 142 | #:gcd (??) #:ecr (??) #:ec (??) 143 | )])) 144 | ;; Constant choice: 0, 1, -1, or one of c ... 145 | (define-synthax ?const 146 | ([(?const c ...) 147 | (choose 0 1 -1 c ...)]) 148 | ) 149 | ;; Like ?const, but additionally offers the negation (- 0 c) of every c. 150 | (define-synthax ?const- 151 | ([(?const c ...) 152 | (choose 0 1 c ... -1 (- 0 c) ...)]) 153 | ) 154 | ;; Template of sums/differences over x ... and (??) constants, up to depth levels. 155 | (define-synthax (?warp-size-const x ... depth) 156 | #:base (choose x ... (??)) 157 | #:else (choose 158 | x ... (??) 159 | ((choose + -) 160 | (?warp-size-const x ... (- depth 1)) 161 | (?warp-size-const x ... (- depth 1))))) 162 | ;; Warp-size expression template: one of x ..., a subterm +/- a ?warp-size-const term, a ?warp-size-const term minus a subterm, or (??) times a subterm. 163 | (define-synthax (?warp-size x ... depth) 164 | #:base (choose x ...) 165 | #:else (choose 166 | x ... 167 | ((choose + -) 168 | (?warp-size x ... (- depth 1)) 169 | (?warp-size-const x ... (- depth 1))) 170 | (- 171 | (?warp-size-const x ... (- depth 1)) 172 | (?warp-size x ... (- depth 1))) 173 | (* (??) (?warp-size x ... (- depth 1))) 174 | )) 175 | ;; Offset template: (??) plus the sum of (??) * id * size over every [id size] pair. 176 | (define-synthax ?warp-offset 177 | ([(?warp-offset [id size] ...) 178 | (+ (??) (* (??) id size) ...)]) 179 | ) 180 | 181 | 182 | ;;;;;;;;;;;;;;;;; for data loading synthesis ;;;;;;;;;;;;;;;;;;;; 183 | (struct ID (thread warp block)) 184 | ;; Allocate block-, warp- and thread-indexed storage via create-matrix with list as its init argument (presumably every cell starts as an empty list; verify against util.rkt). Returns (values threads warps blocks). 185 | (define (get-grid-storage) 186 | (define blocks (create-matrix (get-gridDim) list)) 187 | (define warps (create-matrix (cons (/ blockSize warpSize) (get-gridDim)) list)) 188 | (define threads (create-matrix (append (get-blockDim) (get-gridDim)) list)) 189 | (values threads warps blocks)) 190 | ;; Prepend the list v to the list stored at index i of M. 191 | (define (update-val M i v) 192 | (define current (get* M i)) 193 | (define update (append v current)) 194 | (set* M i update)) 195 | ;; Walk the output O alongside IDs; for every accumulator, record its flattened values at its block, (warp . block) and (thread ++ block) entries. 196 | (define (collect-inputs O IDs threads warps blocks) 197 | (define (f o id) 198 | (cond 199 | [(vector?
id) 200 | (for ([oi o] [idi id]) (f oi idi))] 201 | 202 | [(and (accumulator? o) (ID? id)) 203 | (define vals (flatten (accumulator-val o))) 204 | (update-val blocks (ID-block id) vals) 205 | (update-val warps (cons (ID-warp id) (ID-block id)) vals) 206 | (update-val threads (append (ID-thread id) (ID-block id)) vals)])) 207 | (f O IDs)) 208 | ;; Registers per thread needed for the largest number of elements of I found in any leaf list of warps: ceil(max / warpSize), but never less than 1. 209 | (define (num-regs warps I) 210 | (define all-inputs (list->set (to-list I))) 211 | (define max-num 0) 212 | (define (f x) 213 | (cond 214 | [(vector? x) (for ([xi x]) (f xi))] 215 | [else 216 | (define n (set-count (set-intersect (list->set x) all-inputs))) 217 | (when (> n max-num) (set! max-num n))])) 218 | (f warps) 219 | (+ (quotient (- max-num 1) warpSize) 1)) 220 | ;; Recursively convert nested vectors/lists into nested lists. 221 | (define (to-list x) 222 | (cond 223 | [(or (vector? x) (list? x)) (for/list ([xi x]) (to-list xi))] 224 | [else x])) 225 | ;; Element-wise append of two nested vectors whose leaves are lists. 226 | (define (vector-list-append x y) 227 | (cond 228 | [(and (vector? x) (vector? y)) (for/vector ([xi x] [yi y]) (vector-list-append xi yi))] 229 | [else (append x y)])) 230 | ;; Assert that every input listed in warp-input-spec for warp i of block blockId (and present in I) is among the values cached in I-cached by that warp's threads. 231 | (define (check-warp-input warp-input-spec I I-cached warpId blockId) 232 | (define all-inputs (to-list I)) 233 | (define n (/ (vector-length warpId) warpSize)) 234 | (define warp-input (create-matrix (list n) list)) 235 | ;(pretty-display `(I-cached ,I-cached)) 236 | (for ([my-input I-cached] 237 | [wid warpId]) 238 | (let* ([current (get warp-input wid)] 239 | [update (append (to-list my-input) current)]) 240 | (set warp-input wid update))) 241 | ;(pretty-display `(warp-input ,warp-input)) 242 | 243 | (for ([i n] 244 | [my-input warp-input]) 245 | (let ([spec (list->set (get* warp-input-spec (cons i blockId)))]) 246 | (for ([x spec]) 247 | (when (member x all-inputs) 248 | ;(pretty-display `(check ,i ,n ,x ,(list?
(member x my-input)))) 249 | (assert (member x my-input)))) 250 | ))) 251 | 252 | 253 | ;;;;;;;;;;;;;;;;;;;;;;;;;;; constraint ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; 254 | (define (unique-warp lane) ; Assert that each consecutive warpSize-long chunk of lane holds pairwise-distinct values. 255 | ;(pretty-display `(unique-warp ,lane)) 256 | (define len (vector-length lane)) 257 | (for ([o (quotient len warpSize)]) 258 | (let ([offset (* o warpSize)]) 259 | (let ([l (for/list ([i warpSize]) 260 | (vector-ref lane (+ offset i)))]) 261 | (assert (apply distinct? l)))))) 262 | ;; Assert that all elements of lane are pairwise distinct (one assertion per pair). 263 | (define (unique lane) 264 | ;(pretty-display `(unique ,lane)) 265 | (define len (vector-length lane)) 266 | (for ([i len]) 267 | (let ([x (vector-ref lane i)]) 268 | (for ([j (range (add1 i) len)]) 269 | (let ([y (vector-ref lane j)]) 270 | ;(pretty-display `(xy ,(+ offset i) ,(+ offset j) ,x ,y)) 271 | (assert (not (= x y)))))))) 272 | ;; Return (without asserting) whether all elements of l are distinct. 273 | (define (unique-list l) 274 | (apply distinct? l)) 275 | -------------------------------------------------------------------------------- /cuda.rkt: -------------------------------------------------------------------------------- 1 | #| 2 | | Copyright (c) 2018-2019, University of California, Berkeley. 3 | | Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved. 4 | | 5 | | Redistribution and use in source and binary forms, with or without 6 | | modification, are permitted provided that the following conditions are met: 7 | | 8 | | 1. Redistributions of source code must retain the above copyright notice, 9 | | this list of conditions and the following disclaimer. 10 | | 11 | | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | | this list of conditions and the following disclaimer in the documentation 13 | | and/or other materials provided with the distribution.
14 | | 15 | | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 19 | | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 20 | | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 21 | | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 22 | | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 23 | | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 24 | | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 25 | | POSSIBILITY OF SUCH DAMAGE. 26 | |# 27 | 28 | #lang rosette 29 | 30 | (require "util.rkt") 31 | ;; Strip the first '@' from name when name is '@' followed by at least one more character; otherwise return name unchanged. 32 | (define (drop@ name) 33 | (if (regexp-match? #rx"^@.+$" name) 34 | (regexp-replace #rx"@" name "") 35 | name)) 36 | 37 | (require (only-in racket [sort %sort] [< %<])) 38 | (provide (rename-out [@+ +] [@- -] [@* *] [@modulo modulo] [@quotient quotient] [@< <] [@<= <=] [@> >] [@>= >=] [@= =] [@ite ite] 39 | [@bvadd bvadd] [@bvsub bvsub] [@bvand bvand] [@bvxor bvxor] [@bvshl bvshl] [@bvlshr bvlshr] [@extract extract] [@bvlog bvlog]) 40 | @int @bv @dup gen-uid gen-sym gen-bv for/bounded 41 | define-shared 42 | create-matrix-local 43 | global-to-shared shared-to-global 44 | global-to-local local-to-global 45 | global-to-reg reg-to-global reg-to-global-update 46 | warpSize set-warpSize blockSize set-blockSize 47 | get-warpId get-idInWarp get-blockDim get-gridDim get-global-threadId 48 | shfl shfl-send sw-xform sw-xform-prime rotate-nogroup permute-vector 49 | accumulator accumulator? accumulator-val create-accumulator accumulate accumulate-merge accumulate-final 50 | get-accumulator-val acc-equal?
acc-print 51 | run-kernel get-cost reset-cost) 52 | 53 | ;; Default launch configuration; set-warpSize / set-blockSize mutate warpSize / blockSize. NOTE(review): blockDim is computed once from the initial blockSize and set-blockSize does not update it here; confirm it is reset elsewhere (e.g. run-kernel). 54 | (define warpSize 4) 55 | (define blockSize warpSize) 56 | (define blockDim (list blockSize)) 57 | (define gridDim (list 1)) 58 | 59 | ;; Return a vector of size blockSize with value x. 60 | (define (@dup x) (for/vector ([i blockSize]) x)) 61 | (define (get-blockDim) blockDim) 62 | (define (get-gridDim) gridDim) 63 | 64 | (define (set-warpSize s) 65 | (set! warpSize s)) 66 | (define (set-blockSize s) 67 | (set! blockSize s)) 68 | 69 | (define uid 0) 70 | 71 | ;; Generate a unique id. 72 | (define (gen-uid) 73 | (set! uid (add1 uid)) 74 | uid) 75 | 76 | ;; Generate a symbolic integer variable. 77 | (define (gen-sym) 78 | (define-symbolic* x integer?) 79 | x) 80 | 81 | ;; Generate a symbolic bitvector variable of fixed width 4. NOTE(review): @bv and @bvlog use width BW; confirm whether this should be BW too. 82 | (define (gen-bv) 83 | (define-symbolic* x (bitvector 4)) 84 | x) 85 | 86 | ;;;;;;;;;;;;;;;;;;;;;;;;;;; lifted operations ;;;;;;;;;;;;;;;;;;;;;;;;;;; 87 | 88 | ;; Create a variable in shared memory. 89 | (define-syntax-rule (define-shared x exp) (define x exp)) 90 | 91 | ;; Apply op on every element of x and y. 92 | (define (iterate x y op) 93 | (define (f x y) 94 | (cond 95 | [(and (vector? x) (vector? y)) (for/vector ([i (vector-length x)]) (f (get x i) (get y i)))] 96 | [(vector? x) (for/vector ([i (vector-length x)]) (f (get x i) y))] 97 | [(vector? y) (for/vector ([i (vector-length y)]) (f x (get y i)))] 98 | [(and (list? x) (list? y)) (map f x y)] 99 | [(list? x) (map (lambda (xi) (f xi y)) x)] 100 | [(list? y) (map (lambda (yi) (f x yi)) y)] 101 | [else (op x y)])) 102 | (f x y)) 103 | 104 | ;; Vectorized ite 105 | (define (@ite c x y) ;; TODO: not quite correct 106 | (define (f c x y) 107 | (cond 108 | [(and (vector? c) (vector? x) (vector? y)) (for/vector ([i (vector-length c)]) (f (get c i) (get x i) (get y i)))] 109 | [(and (vector? c) (vector? x)) (for/vector ([i (vector-length c)]) (f (get c i) (get x i) y))] 110 | [(and (vector? c) (vector?
y)) (for/vector ([i (vector-length c)]) (f (get c i) x (get y i)))] 111 | [(and (vector? x) (vector? y)) (for/vector ([i (vector-length x)]) (f c (get x i) (get y i)))] 112 | [(and (vector? c)) (for/vector ([i (vector-length c)]) (f (get c i) x y))] 113 | [(and (vector? x)) (for/vector ([i (vector-length x)]) (f c (get x i) y))] 114 | [(and (vector? y)) (for/vector ([i (vector-length y)]) (f c x (get y i)))] 115 | [else (if c x y)]) 116 | ) 117 | (f c x y)) 118 | ;; (define-operator my-op @op op) defines @op, which combines a list of operands element-wise with op via iterate, folding left so that (my-op a b c) = (op (op a b) c) as in Racket. A single operand is returned unchanged, except that op = - negates it as Racket's unary minus does. my-op is the variadic wrapper. inc-cost is consulted only for two or more operands because it reads (second args). 119 | (define-syntax-rule (define-operator my-op @op op) 120 | (begin 121 | (define (@op l) 122 | (define ret 123 | (cond 124 | [(= (length l) 1) (if (eq? op -) (iterate 0 (car l) op) (car l))] 125 | [(= (length l) 2) 126 | (iterate (first l) (second l) op)] 127 | [else (iterate (@op (drop-right l 1)) (last l) op)])) 128 | (when (>= (length l) 2) (inc-cost my-op ret l)) 129 | ret) 130 | (define my-op (lambda l (@op l))) 131 | )) 132 | 133 | ;; Vector operations with cost 0. 134 | (define-operator @++ $++ +) 135 | (define-operator @** $** *) 136 | 137 | ;; Vector operations. 138 | (define-operator @+ $+ +) 139 | (define-operator @- $- -) 140 | (define-operator @* $* *) 141 | (define-operator @> $> >) 142 | (define-operator @>= $>= >=) 143 | (define-operator @< $< <) 144 | (define-operator @<= $<= <=) 145 | (define-operator @= $= =) 146 | (define-operator @modulo $modulo modulo) 147 | (define-operator @quotient $quotient quotient) 148 | 149 | (define-operator @bvadd $bvadd bvadd) 150 | (define-operator @bvsub $bvsub bvsub) 151 | (define-operator @bvand $bvand bvand) 152 | (define-operator @bvxor $bvxor bvxor) 153 | (define-operator @bvshl $bvshl bvshl) 154 | (define-operator @bvlshr $bvlshr bvlshr) 155 | ;; Convert an integer (or, recursively, a vector of them) to a BW-bit bitvector. 156 | (define (@bv x) 157 | (if (vector? x) 158 | (for/vector ([i (vector-length x)]) 159 | (@bv (vector-ref x i))) 160 | (integer->bitvector x (bitvector BW)))) 161 | ;; Convert a bitvector (or, recursively, a vector of them) to an integer. 162 | (define (@int x) 163 | (if (vector?
x) 164 | (for/vector ([i (vector-length x)]) 165 | (@int (vector-ref x i))) 166 | (bitvector->integer x))) 167 | ;; Base-2 logarithm of x as a BW-bit bitvector; asserts that the logarithm is an integer. Exact positive integers use integer-length, which avoids floating-point error from log; the result is made exact with inexact->exact so that a whole float such as 3.0 becomes 3. 168 | (define (@bvlog x) 169 | (define y (if (exact-positive-integer? x) (sub1 (integer-length x)) (log x 2))) 170 | (assert (if (exact-positive-integer? x) (= x (arithmetic-shift 1 y)) (integer? y))) 171 | (integer->bitvector (inexact->exact y) (bitvector BW))) 172 | ;; Keep the low b bits of x (b is a BW-bit bitvector) by shifting left, then logically right, by BW - b; applied element-wise to vectors. 173 | (define (@extract x b) 174 | (if (vector? x) 175 | (for/vector ([i (vector-length x)]) 176 | (@extract (vector-ref x i) b)) 177 | (let ([s (bvsub (bv BW (bitvector BW)) b)]) 178 | (bvlshr (bvshl x s) s)))) 179 | 180 | ;; Compute GCD of x and y with recursive bound = 8. 181 | (define (gcd/bound x y [depth 8]) 182 | (assert (> depth 0)) 183 | (if (= y 0) 184 | x 185 | (gcd/bound y (modulo x y) (sub1 depth)))) 186 | 187 | ;; Produce a permutation of 1D vector x of size n according to 188 | ;; the shuffle function f. 189 | (define (permute-vector x n f) 190 | (pretty-display `(permute-vector ,n)) 191 | (define y (create-matrix-local (x-y-z n))) 192 | (for ([i n]) 193 | (pretty-display `(i ,i ,n)) 194 | (set y (@dup i) (get x (f i)))) 195 | y) 196 | 197 | ;; Transformation index swizzle. 198 | ;; Refer to Section 5.3 of https://mangpo.net/papers/swizzle-inventor-asplos19.pdf 199 | ;; The arguments' names should be consistent with the paper. 200 | (define (sw-xform i n cf df group wrap 201 | k m cr dr [c 0] 202 | #:gcd [gcd (quotient group df)] 203 | #:ecr [ecr 0] #:ec [ec 0] ; extra rot 204 | ;;#:cz [cz 1] #:nz [nz group] ; extra fan 205 | ) 206 | (assert (and (>= group 1) (<= group n))) 207 | (assert (and (>= cf -1) (< cf group))) 208 | (assert (and (>= cr -1) (< cr group))) 209 | (assert (and (>= c 0) (< c group))) 210 | 211 | (define rem (modulo n group)) 212 | (assert (= rem 0)) 213 | 214 | ;; df should be group/gcd(group, cf) 215 | ;; gcd = group/df 216 | (assert (= (modulo group df) 0)) 217 | (assert (= (modulo cf (quotient group df)) 0)) 218 | 219 | (assert (= (modulo m dr) 0)) 220 | ;; If we don't impose gcd to be actual gcd(group, cf), then our equation contains Eq(24) from Trove.
221 | ;; (assert (= gcd (quotient group df))) 222 | 223 | (define ii (@modulo (@+ i (@* ecr k) ec) group)) ; extra rot (before fan) 224 | ;; (define ii (@modulo (@+ (@* j cz) (@quotient j nz)) group)) ; extra fan 225 | 226 | (define offset1 (@+ (@quotient ii df) ; fan conflict 227 | (@* k cr) (@quotient k dr) c)) ; rot 228 | (define offset2 ; rotation (after fan) 229 | (if (= wrap 1) 230 | offset1 ; rot 231 | (@modulo offset1 gcd))) ; grouped rot 232 | 233 | (@+ (@* (@quotient ii group) group) ; top-level group 234 | (@modulo (@+ (@* ii cf) ; fan without fan conflict 235 | offset2) ; fan conflict + rotation 236 | group)) 237 | ) 238 | 239 | ;; sw-xform when cf and n are co-prime. 240 | (define-syntax-rule (sw-xform-prime i n cf 241 | k m cr dr) 242 | (sw-xform i n cf n n 1 243 | k m cr dr)) 244 | 245 | ;; rotation: sw-xform with cf = 1, df = group = n and wrap = 1. 246 | (define-syntax-rule (rotate-nogroup i n 247 | k m cr dr) 248 | (sw-xform i n 1 n n 1 249 | k m cr dr)) 250 | 251 | ;;;;;;;;;;;;;;;;;;;;;;;;;;; performance cost ;;;;;;;;;;;;;;;;;;;;;;;;;;; 252 | 253 | (define cost 0) 254 | (define (reset-cost) (set! cost 0)) 255 | (define (get-cost) cost) 256 | ;; Unit cost of one vector operation: 1 for add/sub/compare/bit ops, 2 for @* @modulo @quotient, 0 for @++ @**. 257 | (define (cost-of op) 258 | (cond 259 | [(member op (list @+ @- @> @>= @< @<= @= @bvadd @bvsub @bvand @bvxor @bvshl @bvlshr)) 1] 260 | [(member op (list @* @modulo @quotient)) 2] 261 | [(member op (list @++ @**)) 0] 262 | [else (assert `(cost-of ,op unimplemented))])) 263 | ;; Scalar predicates used by inc-cost (zero? shadows Racket's zero?). The bitvector predicates use equal?, because numeric = accepts only reals. 264 | (define (zero? x) (= x 0)) 265 | (define (one? x) (= x 1)) 266 | (define (minus-one? x) (= x -1)) 267 | (define (zero-bv? x) (equal? x (bv 0 BW))) 268 | (define (one-bv? x) (equal? x (bv 1 BW))) 269 | (define (minus-one-bv? x) (equal? x (bv -1 BW))) 270 | (define (true? x) (and (boolean? x) x)) 271 | (define (false? x) (and (boolean? x) (not x))) 272 | ;; #t when f holds for every leaf of x, recursing through vectors and lists. 273 | (define (all? x f) 274 | (cond 275 | [(vector? x) 276 | (define ret #t) 277 | (for ([i (vector-length x)]) 278 | (set! ret (and ret (all? (vector-ref x i) f)))) 279 | ret] 280 | 281 | [(list? x) 282 | (andmap (lambda (xi) (all?
xi f)) x)] 283 | 284 | [else (f x)])) 285 | ;; Number of leaves in a nested vector/list, assuming every element has the same shape as the first. 286 | (define (size-of x) 287 | (cond 288 | [(vector? x) 289 | (define len (vector-length x)) 290 | (if (> len 0) 291 | (* len (size-of (vector-ref x 0))) 292 | 0)] 293 | 294 | [(list? x) 295 | (define len (length x)) 296 | (if (> len 0) 297 | (* len (size-of (car x))) 298 | 0)] 299 | 300 | [else 1])) 301 | ;; Cost of one application of op with result ret and operand list args (at least two operands). The update to cost is commented out, so this returns (void). 302 | (define (inc-cost op ret args) 303 | (define op-cost (cost-of op)) 304 | 305 | (define inc 306 | (cond 307 | [(member op (list @+ @-)) 308 | (cond 309 | [(all? (first args) zero?) 0] 310 | [(all? (second args) zero?) 0] 311 | [(all? ret zero?) 0] 312 | [else op-cost])] 313 | 314 | [(member op (list @modulo)) 315 | (cond 316 | [(all? (second args) one?) 0] 317 | [else op-cost])] 318 | 319 | [(member op (list @*)) 320 | (cond 321 | [(all? (first args) zero?) 0] 322 | [(all? (first args) one?) 0] 323 | [(all? (first args) minus-one?) 0] 324 | [(all? (second args) zero?) 0] 325 | [(all? (second args) one?) 0] 326 | [(all? (second args) minus-one?) 0] 327 | [else op-cost])] 328 | 329 | [(member op (list @quotient)) 330 | (cond 331 | [(all? (second args) one?) 0] 332 | [else op-cost])] 333 | 334 | [(member op (list @bvadd @bvsub)) 335 | (cond 336 | [(all? ret zero-bv?) 0] 337 | [else op-cost])] 338 | 339 | [(member op (list @bvshl @bvlshr)) 340 | (cond 341 | [(all? (second args) zero-bv?) 0] 342 | [else op-cost])] 343 | 344 | [else op-cost] 345 | )) 346 | ;;(set! cost (+ cost inc)) 347 | (void) 348 | ) 349 | ;; Compute a per-level cost inc for applying ops over the nested values vals; currently inc is unused and only 1 is added to cost per call. 350 | (define (accumulate-cost ops vals) 351 | (define (f ops vals) 352 | (cond 353 | [(vector? vals) 354 | (* (vector-length vals) 355 | (+ (cost-of (car ops)) (f (cdr ops) (vector-ref vals 0))))] 356 | 357 | [(list? vals) 358 | (* (length vals) 359 | (+ (cost-of (car ops)) (f (cdr ops) (car vals))))] 360 | 361 | [(empty? ops) 0] 362 | [else (cost-of (car ops))] 363 | )) 364 | 365 | (define inc 366 | (cond 367 | [(vector?
vals) 368 | (+ (cost-of (last ops)) 369 | (f (cdr (reverse ops)) (vector-ref vals 0)))] 370 | 371 | [else 372 | (f (reverse ops) vals)])) 373 | ;;(set! cost (+ cost inc)) 374 | (set! cost (+ cost 1)) 375 | ) 376 | ;; Estimated global-memory cost of a transfer (4x when the x stride is not 1); the update to cost is commented out, so this returns (void). 377 | (define (global-cost pattern sizes) 378 | (define pattern-x (get-x pattern)) 379 | (define my-cost 380 | (if (= pattern-x 1) 381 | (+ 1 (quotient (apply * sizes) blockSize)) 382 | (* 4 (+ 1 (quotient (apply * sizes) blockSize))))) 383 | ;;(set! cost (+ cost my-cost)) 384 | (void) 385 | ) 386 | 387 | ;;;;;;;;;;;;;;;;;;;;;;;;;;; memory operations ;;;;;;;;;;;;;;;;;;;;;;;;;;; 388 | (define-syntax-rule ; Like (for ([i I]) body ...), but asserts #f if I exceeds 8 iterations. 389 | (for/bounded ([i I]) body ...) 390 | (letrec ([f (lambda (i bound) 391 | (when (< i I) 392 | (if (> bound 0) 393 | (begin 394 | body ... 395 | (f (+ i 1) (- bound 1))) 396 | (assert #f))))]) 397 | (f 0 8))) 398 | 399 | ;; Create a local matrix with dimensions dims plus a trailing dimension of size blockSize. NOTE(review): init is accepted but not forwarded to create-matrix; confirm that this is intended. 400 | (define (create-matrix-local dims [init (lambda () 0)]) 401 | (create-matrix (append dims (list blockSize)))) 402 | 403 | ;; Load I in global memory to I-shared in shared memory 404 | ;; pattern -- (x-y-z stride-x ...) 405 | ;; >> each thread load stride-x * stride-y * ... consecutive block in round-robin fashion 406 | ;; offset -- the starting x-y-z coordinate of global memory that the thread block loads. 407 | ;; sizes -- (x-y-z size-x ...) 408 | ;; >> each thread block loads size-x * size-y * ... values 409 | ;; transpose -- #t for load with transpose 410 | ;; round -- (x-y-z round-x ...) or just round-x for 1D. Round of the round robin to fully load 'sizes'. 411 | ;; gsize -- (x-y-z gsize-x ...) size of global memory, must be specified for 2D and 3D 412 | (define (global-to-shared I I-shared pattern offset sizes [transpose #f] 413 | #:round [round 1] #:size [gsize #f]) 414 | (global-cost pattern sizes) 415 | (define bounds (get-dims I)) 416 | (pretty-display `(sizes ,sizes)) 417 | (pretty-display `(bounds ,(@* blockDim pattern round))) 418 | (assert (all? (@<= sizes (@* blockDim pattern round)) true?)
"size 1") 419 | (assert (all? (@> sizes (@* blockDim pattern (@- round 1))) true?) "size 2") 420 | (when (> (length pattern) 1) (assert gsize "#:size must be specified for dimenion > 1")) 421 | 422 | (cond 423 | [(= (length offset) 1) 424 | (let ([size-x (get-x sizes)] 425 | [bound-x (get-x bounds)] 426 | [offset-x (get-x offset)]) 427 | (when (vector? offset-x) (set! offset-x (vector-ref offset-x 0))) 428 | (for ([i size-x]) 429 | (when (< (+ offset-x i) bound-x) 430 | (set I-shared i (get I (+ offset-x i))))))] 431 | 432 | [(= (length offset) 2) 433 | (let ([size-x (get-x sizes)] 434 | [size-y (get-y sizes)] 435 | [bound-x (get-x bounds)] 436 | [bound-y (get-y bounds)] 437 | [offset-x (get-x offset)] 438 | [offset-y (get-y offset)]) 439 | (when (vector? offset-x) 440 | (set! offset-x (vector-ref offset-x 0)) 441 | (set! offset-y (vector-ref offset-y 0)) 442 | ) 443 | (for* ([y size-y] [x size-x]) 444 | (when (and (< (+ offset-x x) bound-x) (< (+ offset-y y) bound-y)) 445 | (if transpose 446 | (set I-shared y x (get I (+ offset-x x) (+ offset-y y))) 447 | (set I-shared x y (get I (+ offset-x x) (+ offset-y y)))))))] 448 | 449 | [(= (length offset) 3) 450 | (let ([size-x (get-x sizes)] 451 | [size-y (get-y sizes)] 452 | [size-z (get-z sizes)] 453 | [bound-x (get-x bounds)] 454 | [bound-y (get-y bounds)] 455 | [bound-z (get-z bounds)] 456 | [offset-x (get-x offset)] 457 | [offset-y (get-y offset)] 458 | [offset-z (get-z offset)]) 459 | (when (vector? offset-x) 460 | (set! offset-x (vector-ref offset-x 0)) 461 | (set! offset-y (vector-ref offset-y 0)) 462 | (set!
offset-z (vector-ref offset-z 0)) 463 | ) 464 | (for* ([z size-z] [y size-y] [x size-x]) 465 | (when (and (< (+ offset-x x) bound-x) (< (+ offset-y y) bound-y) (< (+ offset-z z) bound-z)) 466 | (if transpose 467 | (set I-shared z y x (get I (+ offset-x x) (+ offset-y y) (+ offset-z z))) 468 | (set I-shared x y z (get I (+ offset-x x) (+ offset-y y) (+ offset-z z)))))))] 469 | )) 470 | 471 | ;; Similar to global-to-shared but 472 | ;; for storing I-shared in shared memory to I in global memory (a vector offset is reduced to its first element, as in global-to-shared) 473 | (define (shared-to-global I-shared I pattern offset sizes [transpose #f] #:round [round 1] #:size [s #f]) 474 | (if transpose 475 | (global-cost (reverse pattern) (reverse sizes)) 476 | (global-cost pattern sizes)) 477 | (define bounds (get-dims I)) 478 | (assert (all? (@<= sizes (@* blockDim pattern round)) true?)) 479 | (assert (all? (@> sizes (@* blockDim pattern (@- round 1))) true?)) 480 | 481 | (cond 482 | [(= (length offset) 1) 483 | (let ([size-x (get-x sizes)] 484 | [bound-x (get-x bounds)] 485 | [offset-x (get-x offset)]) 486 | (when (vector? offset-x) (set! offset-x (vector-ref offset-x 0))) (for ([i size-x]) 487 | (when (< (+ offset-x i) bound-x) 488 | (set I (+ offset-x i) (get I-shared i)))))] 489 | 490 | [(= (length offset) 2) 491 | (let ([size-x (get-x sizes)] 492 | [size-y (get-y sizes)] 493 | [bound-x (get-x bounds)] 494 | [bound-y (get-y bounds)] 495 | [offset-x (get-x offset)] 496 | [offset-y (get-y offset)]) 497 | (when (vector? offset-x) (set! offset-x (vector-ref offset-x 0)) (set! offset-y (vector-ref offset-y 0))) (for* ([y size-y] [x size-x]) 498 | (when (and (< (+ offset-x x) bound-x) (< (+ offset-y y) bound-y)) 499 | (if transpose 500 | (set I (+ offset-x y) (+ offset-y x) (get I-shared x y)) 501 | (set I (+ offset-x x) (+ offset-y y) (get I-shared x y))))))] 502 | 503 | [(= (length offset) 3) 504 | (let ([size-x (get-x sizes)] 505 | [size-y (get-y sizes)] 506 | [size-z (get-z sizes)] 507 | [bound-x (get-x bounds)] 508 | [bound-y (get-y bounds)] 509 | [bound-z (get-z bounds)] 510 | [offset-x (get-x offset)] 511 | [offset-y (get-y offset)] 512 | [offset-z (get-z offset)]) 513 | (when (vector? offset-x) (set! offset-x (vector-ref offset-x 0)) (set! offset-y (vector-ref offset-y 0)) (set! offset-z (vector-ref offset-z 0))) (for* ([z size-z] [y size-y] [x
size-x]) 514 | (when (and (< (+ offset-x x) bound-x) (< (+ offset-y y) bound-y) (< (+ offset-z z) bound-z)) 515 | (if transpose 516 | (set I (+ offset-x z) (+ offset-y y) (+ offset-z x) (get I-shared x y z)) 517 | (set I (+ offset-x x) (+ offset-y y) (+ offset-z z) (get I-shared x y z))))))] 518 | )) 519 | 520 | ;; Load I in global memory at offset to register I-reg. 521 | ;; gsize -- (x-y-z gsize-x ...) size of global memory, must be specified for 2D and 3D 522 | (define-syntax global-to-reg 523 | (syntax-rules () 524 | ((global-to-reg I I-reg offset) 525 | (let* ([bounds (get-dims I)] 526 | [blockSize (vector-length offset)] 527 | [new-I-reg (make-vector blockSize #f)]) 528 | (global-cost (list 1) (list (size-of I-reg))) 529 | (for ([t blockSize]) 530 | (set new-I-reg t (clone I-reg))) 531 | (set! I-reg new-I-reg) 532 | (for ([i blockSize] 533 | [global-i offset]) 534 | (when (for/and ([b bounds] [i global-i]) (< i b)) 535 | (set I-reg i (get* I global-i)))))) 536 | 537 | ((global-to-reg I I-reg offset #:size gsize) 538 | (global-to-reg I I-reg offset)))) 539 | 540 | 541 | ;; Store register I-reg to I in global memory at offset. 542 | ;; gsize -- (x-y-z gsize-x ...) size of global memory, must be specified for 2D and 3D 543 | (define (reg-to-global I-reg I offset #:size [gsize #f]) 544 | (let* ([bounds (get-dims I)] 545 | [blockSize (vector-length offset)]) 546 | (global-cost (list 1) (list (size-of I-reg))) 547 | (for ([i blockSize] 548 | [global-i offset]) 549 | (when (for/and ([b bounds] [i global-i]) (< i b)) 550 | (set* I global-i (get I-reg i)))))) 551 | 552 | ;; Update I in global memory at offset to f(old_value, I-reg) 553 | ;; gsize -- (x-y-z gsize-x ...)
size of global memory, must be specified for 2D and 3D 554 | (define (reg-to-global-update f I-reg I offset #:size [gsize #f] #:pred [pred (make-vector blockSize)]) 555 | (let* ([bounds (get-dims I)] 556 | [blockSize (vector-length offset)]) 557 | (global-cost (list 1) (list (size-of I-reg))) 558 | (for ([i blockSize] 559 | [global-i offset]) 560 | (when (and (vector-ref pred i) 561 | (for/and ([b bounds] [i global-i]) 562 | (< i b))) 563 | (set* I global-i (f (get* I global-i) (get I-reg i))))))) 564 | 565 | ;; Load I in global memory to I-reg in local memory/registers 566 | ;; pattern -- (x-y-z stride-x ...) 567 | ;; >> each thread load stride-x * stride-y * ... consecutive block in round-robin fashion 568 | ;; offset -- the starting x-y-z coordinate of global memory that the warp loads. 569 | ;; sizes -- (x-y-z size-x ...) 570 | ;; >> each warp loads size-x * size-y * ... values 571 | ;; transpose -- #t for load with transpose 572 | ;; warp-shape -- (x-y-z shape-x shape-y ...) must be specified for 2D and 3D 573 | ;; round -- (x-y-z round-x ...) or just round-x for 1D. Round of the round robin to fully load 'sizes'. 574 | ;; shfl -- shuffle function for load with shuffle. 'k' is the iteration of the round robin. 575 | ;; gsize -- (x-y-z gsize-x ...) size of global memory, must be specified for 2D and 3D 576 | (define (global-to-local I I-reg pattern offset sizes transpose 577 | #:warp-shape [warp-shape warpSize] 578 | #:round [round 1] 579 | #:shfl [shfl (lambda (tid k) tid)] 580 | #:size [gsize #f]) 581 | (global-cost pattern sizes) 582 | (assert (all? (@<= sizes (@* warp-shape pattern round)) true?)) 583 | (assert (all?
(@> sizes (@* warp-shape pattern (@- round 1))) true?)) 584 | (cond 585 | [(= (length blockDim) 1) 586 | (let* ([size-x (get-x sizes)] 587 | [stride-x (get-x pattern)] 588 | [blockSize (apply * blockDim)] 589 | [iter-x (add1 (quotient (sub1 size-x) (* warpSize stride-x)))] 590 | [I-len (vector-length I)] 591 | [I-reg-len (vector-length (vector-ref I-reg 0))]) 592 | (for ([warp (quotient blockSize warpSize)]) 593 | (let ([offset-x (if (vector? offset) 594 | (get-x (vector-ref offset (* warp warpSize))) 595 | (vector-ref (get-x offset) (* warp warpSize)))]) 596 | ;(pretty-display `(offset-x ,offset-x)) 597 | (for/bounded ([it iter-x]) 598 | (for ([t warpSize]) 599 | (let ([t-from (shfl t it)]) 600 | (for/bounded ([my-i stride-x]) 601 | ;(pretty-display `(loop ,warp ,it ,t ,my-i)) ;; (loop 1 1 0 0) 602 | (let ([global-x (+ offset-x (* it stride-x warpSize) (* stride-x t-from) my-i)] 603 | [local-x (+ my-i (* it stride-x))]) 604 | (when (and (< global-x I-len) 605 | (< local-x I-reg-len) 606 | ) 607 | (vector-set! (vector-ref I-reg (+ t (* warp warpSize))) ;; thread in a block 608 | local-x ;; local index 609 | (vector-ref I global-x)) 610 | ;(pretty-display `(loop-true)) 611 | )))))) 612 | ))) 613 | ] 614 | 615 | [(= (length blockDim) 2) 616 | (let* ([size-x (get-x sizes)] 617 | [size-y (get-y sizes)] 618 | [stride-x (get-x pattern)] 619 | [stride-y (get-y pattern)] 620 | [warp-shape-x (get-x warp-shape)] 621 | [warp-shape-y (get-y warp-shape)] 622 | [blockSize (apply * blockDim)] 623 | [iter-x (add1 (quotient (sub1 size-x) (* warp-shape-x stride-x)))] 624 | [iter-y (add1 (quotient (sub1 size-y) (* warp-shape-y stride-y)))] 625 | [I-len-x (vector-length (vector-ref I 0))] 626 | [I-len-y (vector-length I)] 627 | [I-reg-len-y (vector-length (vector-ref I-reg 0))] 628 | [I-reg-len-x (vector-length (vector-ref (vector-ref I-reg 0) 0))]) 629 | (for ([warp (quotient blockSize warpSize)]) 630 | ;(pretty-display `(>>> warp ,warp ,offset)) 631 | (let ([offset-x (if (vector?
offset) 632 | (get-x (vector-ref offset (* warp warpSize))) 633 | (vector-ref (get-x offset) (* warp warpSize)))] 634 | [offset-y (if (vector? offset) 635 | (get-y (vector-ref offset (* warp warpSize))) 636 | (vector-ref (get-y offset) (* warp warpSize)))]) 637 | ;(pretty-display `(offset-x ,offset-x)) 638 | (for/bounded ([it-y iter-y]) 639 | (for/bounded ([it-x iter-x]) 640 | ;(pretty-display `(iter ,warp ,it-y ,it-x)) 641 | (for ([t warpSize]) 642 | (let ([t-from (shfl t (+ (* it-y iter-x) it-x))]) 643 | (for/bounded ([my-y stride-y]) 644 | (for/bounded ([my-x stride-x]) 645 | ;(pretty-display `(loop ,warp ,it-x ,t ,my-x)) 646 | (let ([global-y (+ offset-y 647 | (* it-y warp-shape-y stride-y) ;; TODO (* size-y warp) 648 | (* (quotient t-from warp-shape-x) stride-y) my-y)] 649 | [global-x (+ offset-x 650 | (* it-x warp-shape-x stride-x) ;; TODO (* size-x warp) 651 | (* (modulo t-from warp-shape-x) stride-x) my-x)] 652 | [local-y (+ my-y (* it-y stride-y))] 653 | [local-x (+ my-x (* it-x stride-x))] 654 | ) 655 | ;(pretty-display `(info ,warp ,t ,global-y ,global-x ,local-y ,local-x)) 656 | (when (and (< global-y I-len-y) (< global-x I-len-x) 657 | (< local-x I-reg-len-x) (< local-y I-reg-len-y) 658 | ) 659 | (set I-reg local-x local-y 660 | (+ t (* warp warpSize)) ;; thread in a block 661 | (get I global-x global-y 662 | )))))))))) 663 | ))) 664 | ] 665 | 666 | ;; TODO 667 | [else (raise "unimplemented")] 668 | )) 669 | 670 | ;; Similar to global-to-local but 671 | ;; for storing I-reg in local memory/registers to I in global memory 672 | (define (local-to-global I-reg I pattern offset sizes transpose 673 | #:warp-shape [warp-shape warpSize] 674 | #:round [round 1] 675 | #:shfl [shfl (lambda (tid k) tid)] 676 | #:size [gsize #f]) 677 | (begin 678 | (if transpose 679 | (global-cost (reverse pattern) (reverse sizes)) 680 | (global-cost pattern sizes)) 681 | (assert (all? (@<= sizes (@* warp-shape pattern round)) true?)) 682 | (assert (all?
(@> sizes (@* warp-shape pattern (@- round 1))) true?)) 683 | (cond 684 | [(= (length blockDim) 1) 685 | (let* ([size-x (get-x sizes)] 686 | [stride-x (get-x pattern)] 687 | [blockSize (apply * blockDim)] 688 | [iter-x (add1 (quotient (sub1 size-x) (* warpSize stride-x)))] 689 | [I-len (vector-length I)] 690 | [I-reg-len (vector-length I-reg)] 691 | [new-I-reg (make-vector blockSize #f)]) 692 | ;(pretty-display `(iterate ,(quotient blockSize warpSize) ,iter-x ,stride-x)) 693 | (for ([warp (quotient blockSize warpSize)]) 694 | (let ([offset-x (if (vector? offset) 695 | (get-x (vector-ref offset (* warp warpSize))) 696 | (vector-ref (get-x offset) (* warp warpSize)))] 697 | [inc-x 0]) 698 | ;(pretty-display `(offset-x ,offset-x)) 699 | #;(for/bounded ([it iter-x]) 700 | (for ([t warpSize]) 701 | (for/bounded ([my-i stride-x]) 702 | (when (and (< inc-x size-x) 703 | (< (+ offset-x inc-x) I-len) 704 | (< (+ my-i (* it stride-x)) I-reg-len) 705 | ) 706 | (vector-set! I (+ offset-x inc-x) 707 | (vector-ref 708 | (vector-ref I-reg (+ t (* warp warpSize))) ;; thread in a block 709 | (+ my-i (* it stride-x)))) ;; local index 710 | ) 711 | (set! inc-x (+ inc-x 1))))) 712 | (for/bounded ([it iter-x]) 713 | (for ([t warpSize]) 714 | (let ([t-from (shfl t it)]) 715 | (for/bounded ([my-i stride-x]) 716 | ;(pretty-display `(loop ,warp ,it ,t ,my-i)) 717 | (when (and (< inc-x size-x) 718 | (< (+ offset-x inc-x) I-len) 719 | (< (+ my-i (* it stride-x)) I-reg-len) 720 | ) 721 | (vector-set!
I (+ offset-x (* it stride-x warpSize) (* stride-x t-from) my-i) 722 | (vector-ref 723 | (vector-ref I-reg (+ t (* warp warpSize))) ;; thread in a block 724 | (+ my-i (* it stride-x)))) ;; local index 725 | ))))) 726 | ))) 727 | ] 728 | 729 | ;; TODO 730 | [else (raise "unimplemented")] 731 | ))) 732 | 733 | ;;;;;;;;;;;;;;;;;;;;;;;;;;; intra-warp shuffle operations ;;;;;;;;;;;;;;;;;;;;;;;;;;; 734 | 735 | (define (shfl val lane) 736 | (define len (vector-length val)) 737 | (define res (make-vector len #f)) 738 | 739 | (define lane-vec 740 | (if (vector? lane) 741 | (for/vector ([i (vector-length lane)]) (modulo (get lane i) warpSize)) 742 | (for/vector ([i len]) (modulo lane warpSize)))) 743 | 744 | (for ([iter (quotient len warpSize)]) 745 | (let ([offset (* iter warpSize)]) 746 | (for ([i warpSize]) 747 | (let ([i-dest (+ offset i)] 748 | [i-src (+ offset (get lane-vec (+ offset i)))]) 749 | (set res i-dest (get val i-src)))))) 750 | 751 | ;(set! cost (+ cost 2)) 752 | res) 753 | 754 | ;; Scatter version of shuffle instruction. This instruction doesn't exist in GPU. 755 | ;; This function is for convenient uses. 756 | (define (shfl-send val lane) 757 | (define len (vector-length val)) 758 | (define res (make-vector len #f)) 759 | 760 | (define lane-vec 761 | (if (vector? lane) 762 | (for/vector ([i (vector-length lane)]) (modulo (get lane i) warpSize)) 763 | (for/vector ([i len]) (modulo lane warpSize)))) 764 | 765 | (for ([iter (quotient len warpSize)]) 766 | (let ([offset (* iter warpSize)]) 767 | (for ([i warpSize]) 768 | (let ([i-src (+ offset i)] 769 | [i-dest (+ offset (get lane-vec (+ offset i)))]) 770 | (set res i-dest (get val i-src)))))) 771 | 772 | ;(set! cost (+ cost 2)) 773 | res) 774 | 775 | ;;;;;;;;;;;;;;;;;;;;;;;;;;; accumulators ;;;;;;;;;;;;;;;;;;;;;;;;;;; 776 | 777 | (struct accumulator (val oplist opfinal veclen) #:mutable) 778 | 779 | ;; Multiset equal 780 | (define (multiset= x y) 781 | (cond 782 | [(and (list? x) (list? 
y)) 783 | (define ret (= (length x) (length y))) 784 | (for ([xi x]) 785 | (let ([f (lambda (yi) (multiset= xi yi))]) 786 | (set! ret (and ret (= (count f x) (count f y)))))) 787 | ret] 788 | 789 | [else (equal? x y)])) 790 | 791 | ;; Accumulator equal 792 | (define (acc=? x y recursive-equal?) 793 | (and (multiset= (accumulator-val x) (accumulator-val y)) 794 | (equal? (accumulator-oplist x) (accumulator-oplist y)) 795 | (equal? (accumulator-opfinal x) (accumulator-opfinal y)))) 796 | 797 | ;; Create an accumulator or a vector of accumulators. 798 | (define-syntax create-accumulator 799 | (syntax-rules () 800 | ((create-accumulator op-list final-op) 801 | (accumulator (list) op-list final-op #f)) 802 | ((create-accumulator op-list final-op blockDim) 803 | (build-vector (apply * blockDim) 804 | (lambda (i) (accumulator (list) op-list final-op (apply * blockDim))))))) 805 | 806 | (define-syntax-rule (get-accumulator-val x) 807 | (if (vector? x) 808 | (for/vector ([xi x]) (accumulator-val xi)) 809 | (accumulator-val x))) 810 | 811 | ;; Convert to a vector of sorted lists. 812 | (define (vector-of-list l veclen) 813 | (for/vector ([i veclen]) 814 | (let ([each (map (lambda (x) (if (vector? x) (get x i) x)) l)]) 815 | (%sort each (lambda (x y) (stringvector (reverse ret))) 942 | 943 | (define (run-grid kernel my-gridDim my-blockDim threadIds args) 944 | (set! gridDim my-gridDim) 945 | (set! blockDim my-blockDim) 946 | (set! blockSize (apply * my-blockDim)) 947 | (reset-cost) 948 | 949 | (define (f blockID sizes) 950 | (if (empty? sizes) 951 | (begin 952 | (pretty-display `(blockID ,blockID ,blockDim ,threadIds)) 953 | (apply kernel (append (list threadIds blockID blockDim) args))) 954 | (for ([i (car sizes)]) 955 | (f (cons i blockID) (cdr sizes))))) 956 | (f (list) (reverse gridDim)) 957 | ;;(pretty-display `(cost ,cost)) 958 | ) 959 | 960 | ;; Run a kernel. 961 | (define-syntax-rule (run-kernel kernel my-blockDim my-gridDim x ...) 
962 | (let ([Ids (get-threadId my-blockDim)]) 963 | (run-grid kernel my-gridDim my-blockDim Ids (list x ...)))) 964 | -------------------------------------------------------------------------------- /ex1-transpose.rkt: -------------------------------------------------------------------------------- 1 | #| 2 | | Copyright (c) 2018-2019, University of California, Berkeley. 3 | | Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved. 4 | | 5 | | Redistribution and use in source and binary forms, with or without 6 | | modification, are permitted provided that the following conditions are met: 7 | | 8 | | 1. Redistributions of source code must retain the above copyright notice, 9 | | this list of conditions and the following disclaimer. 10 | | 11 | | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | | this list of conditions and the following disclaimer in the documentation 13 | | and/or other materials provided with the distribution. 14 | | 15 | | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 19 | | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 20 | | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 21 | | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 22 | | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 23 | | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 24 | | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 25 | | POSSIBILITY OF SUCH DAMAGE. 
26 | |# 27 | 28 | #lang rosette 29 | 30 | (require "util.rkt" "cuda.rkt") 31 | 32 | (define (transpose-spec I O sizes) 33 | (for* ([y (get-y sizes)] 34 | [x (get-x sizes)]) 35 | (set O y x (get I x y)))) 36 | 37 | (define sizes (x-y-z 5 5)) 38 | (define I (create-matrix sizes 39 | (lambda () (define-symbolic* x integer?) x))) 40 | (define O (create-matrix (reverse sizes))) 41 | (define O* (create-matrix (reverse sizes))) 42 | 43 | (transpose-spec I O sizes) 44 | 45 | (define (transpose1 threadId blockID blockDim I O) 46 | (define-shared I-shared (create-matrix (reverse blockDim))) 47 | (define offset (* blockID blockDim)) 48 | (global-to-shared I I-shared 49 | (x-y-z 1 1) 50 | offset blockDim 51 | #:transpose #t) 52 | (shared-to-global I-shared O 53 | (x-y-z 1 1) 54 | (reverse offset) (reverse blockDim)) 55 | ) 56 | 57 | (define (transpose2 threadId blockID blockDim I O) 58 | (define tileDim (x-y-z 4 4)) 59 | (define-shared I-shared (create-matrix (reverse tileDim))) 60 | (define offset (* blockID tileDim)) 61 | (global-to-shared I I-shared 62 | (x-y-z 1 1) 63 | offset tileDim #t 64 | #:round (x-y-z 1 4) #:size sizes) 65 | (shared-to-global I-shared O 66 | (x-y-z 1 1) 67 | (reverse offset) (reverse tileDim) 68 | #:round (x-y-z 1 4) #:size sizes) 69 | ) 70 | 71 | ;;(run-kernel transpose1 (x-y-z 2 2) (x-y-z 3 3) I O*) 72 | (run-kernel transpose2 (x-y-z 4 1) (x-y-z 2 2) I O*) 73 | (pretty-display `(O* ,O*)) 74 | (verify #:guarantee (assert (equal? O O*))) 75 | -------------------------------------------------------------------------------- /ex2-conv1d.rkt: -------------------------------------------------------------------------------- 1 | #| 2 | | Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved. 3 | | 4 | | Redistribution and use in source and binary forms, with or without 5 | | modification, are permitted provided that the following conditions are met: 6 | | 7 | | 1. 
Redistributions of source code must retain the above copyright notice, 8 | | this list of conditions and the following disclaimer. 9 | | 10 | | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | | this list of conditions and the following disclaimer in the documentation 12 | | and/or other materials provided with the distribution. 13 | | 14 | | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 17 | | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 18 | | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 19 | | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 20 | | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 21 | | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 22 | | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 23 | | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 24 | | POSSIBILITY OF SUCH DAMAGE. 25 | |# 26 | 27 | #lang rosette 28 | 29 | (require "util.rkt" "cuda.rkt" "cuda-synth.rkt") 30 | 31 | (define n-block 2) 32 | (define /3 (lambda (x) (/ x 3))) 33 | 34 | ;; Create input and output matrices. 35 | (define (create-IO warpSize) 36 | (set-warpSize warpSize) 37 | (define block-size (* 2 warpSize)) 38 | (define I-sizes (x-y-z (* 2 block-size))) 39 | (define O-sizes (- I-sizes 2)) 40 | (define W (create-matrix (list 3) gen-uid)) 41 | (define I (create-matrix I-sizes gen-uid)) 42 | (define O (create-matrix O-sizes)) 43 | (define O* (create-matrix O-sizes)) 44 | (values block-size I-sizes O-sizes I O O* W)) 45 | 46 | ;; Run sequential program spec and GPU kernel kernel, and compare their outputs. 
;; Run the sequential spec and a GPU kernel on fresh inputs built for warp
;; size w, then compare their outputs with acc-equal?.
(define (run-with-warp-size spec kernel w)
  (define-values (block-size I-sizes O-sizes I O O* W)
    (create-IO w))

  ;; Reference result from the sequential spec goes into O ...
  (spec I O W O-sizes)
  ;; ... and the kernel's result into O*. run-kernel takes the block
  ;; dimensions first, then the grid dimensions.
  (run-kernel kernel (x-y-z block-size) (x-y-z n-block) I O* W I-sizes O-sizes)
  ;(acc-print O*)
  (acc-equal? O O*))

;; Sequential program spec: 3-tap 1D convolution.
;; O[i] accumulates the pairs (W[j], I[i + j]) for j = 0..2, combined with the
;; operator list (* +) and finalized with identity.
(define (conv1d-spec I O W o-sizes)
  (for ([i (get-x o-sizes)])
    (let ([o (create-accumulator (list * +) identity)])
      (for ([j 3])
        (accumulate o (list (get W j) (get I (+ i j)))))
      (set O i o))))

;; Complete (hand-written) kernel.
;; Each warp asks global-to-local for warpSize + 2 consecutive inputs starting
;; at the warp's first global index. The load uses stride 1 and two rounds, so
;; every thread holds two registers, I-cached[0] and I-cached[1]. Each thread
;; then gathers its three taps with warp shuffles.
(define (conv1d threadId blockID blockDim I O W I-sizes O-sizes)
  (define I-cached (create-matrix-local (x-y-z 2)))
  (define warpID (get-warpId threadId))
  ;; First global index loaded by this thread's warp.
  (define offset (+ (* blockID blockDim) (* warpID warpSize)))
  (define gid (get-global-threadId threadId blockID))
  (global-to-local I I-cached
                   (x-y-z 1)
                   offset
                   (x-y-z (+ warpSize 2)) #f #:round 2)

  (define localId (get-idInWarp threadId))
  (define o (create-accumulator (list * +) identity blockDim))
  (for ([i 3])
    (let* (;; shfl reads the source lane modulo warpSize. A reader with
           ;; localId + i >= warpSize therefore wraps around to a lane below i,
           ;; so those lanes offer their second register (index 1).
           [index (ite (< localId i) 1 0)]
           [lane (+ i localId)]
           [x (shfl (get I-cached index) lane)]
           [w (@dup (get W i))])
      ;(pretty-display `(lane ,i ,localId ,lane))
      (accumulate o (list w x))))
  (reg-to-global o O gid))

;; Kernel sketch for synthesis. It has the same structure as conv1d, but the
;; register-index condition, the shuffle lane, and the accumulate predicate are
;; placeholders (?cond, ?sw-xform from cuda-synth.rkt). `synthesis` asks the
;; solver to fill them in.
(define (conv1d-sketch threadId blockID blockDim I O W I-sizes O-sizes)
  (define I-cached (create-matrix-local (x-y-z 2)))
  (define gid (+ (* blockID blockDim) threadId))
  (define localId (get-idInWarp threadId))
  ;; Presumably the warp's first global index (assuming get-idInWarp is the
  ;; lane id) -- TODO confirm against cuda.rkt.
  (define offset (- gid localId))
  (global-to-local I I-cached
                   (x-y-z 1)
                   offset
                   (x-y-z (+ warpSize 2)) #f #:round 2)

  (define o (create-accumulator (list * +) identity blockDim))

  (for/bounded ([i 3])
    (let* ([index (ite (?cond (@dup i) localId) (@dup 0) (@dup 1))]
           [lane (?sw-xform localId warpSize
                            i warpSize [])]
           [x (shfl (get I-cached index) lane)]
           [w (@dup (get W i))])
      (accumulate o (list w x) #:pred (?cond localId (@dup i)))))

  (reg-to-global o O gid))

;; Check the complete kernel against the spec; prints the comparison result
;; for each warp size.
(define (test)
  (for ([w (list 32)])
    (let ([ret (run-with-warp-size conv1d-spec conv1d w)])
      (pretty-display `(test ,w ,ret)))))
;(test)

;; Synthesize the placeholders in conv1d-sketch so that it agrees with the spec
;; for every listed warp size, then print the completed forms.
(define (synthesis)
  (pretty-display "solving...")
  (define sol
    (time (solve
           (assert (andmap
                    (lambda (w) (run-with-warp-size conv1d-spec conv1d-sketch w))
                    (list 32))))))
  (print-forms sol))

;; Running this module performs synthesis and evaluates to the elapsed seconds.
(define t0 (current-seconds))
(synthesis)
(define t1 (current-seconds))
(- t1 t0)

;; Synthesize data loading from global memory.
;; Synthesize how each warp should load its inputs from global memory.
;; The steps are:
;;   1. record which output element each thread stores;
;;   2. derive the inputs each warp needs from the spec's accumulators;
;;   3. solve for the stride / offset / size of a single global-to-local call
;;      that provides those inputs.
(define (load-synth)
  ;; create-IO returns seven values, including the weights W. Binding only six
  ;; raised an arity error as soon as load-synth was called.
  (define-values (block-size I-sizes O-sizes I O O* W)
    (create-IO 4))

  ;; Store: every thread writes an ID built from its thread, warp and block
  ;; ids into the output element it owns.
  (define (conv1d-store threadId blockId blockDim O)
    (define warpID (get-warpId threadId))
    (define o
      (for/vector ([w warpID]
                   [t threadId])
        (ID t w blockId)))
    (reg-to-global o O (get-global-threadId threadId blockId)))

  ;; Run spec. conv1d-spec takes the weights as its third argument; the old
  ;; three-argument call was an arity error.
  (conv1d-spec I O W O-sizes)

  ;; Collect IDs: run the store kernel, then (presumably) gather per warp the
  ;; inputs its outputs depend on.
  ;; NOTE(review): conv1d's accumulators also contain elements of W; confirm
  ;; that collect-inputs / num-regs only account for elements of I.
  (define IDs (create-matrix O-sizes))
  (run-kernel conv1d-store (x-y-z block-size) (x-y-z n-block) IDs)
  (define-values (threads warps blocks) (get-grid-storage))
  (collect-inputs O IDs threads warps blocks)
  (define n-regs (num-regs warps I))
  (pretty-display `(n-regs ,n-regs))

  ;; Load: a single global-to-local call whose stride, warp offset and load
  ;; size are placeholders, checked against the per-warp input spec.
  (define (conv1d-load threadId blockId blockDim I warp-input-spec)
    (define warpId (get-warpId threadId))
    ;; sketch starts
    (define I-cached (create-matrix-local (x-y-z n-regs)))
    (global-to-local I I-cached
                     (x-y-z (??)) ;; stride
                     (x-y-z (?warp-offset [(get-x blockId) (get-x blockDim)] [warpId warpSize])) ;; offset
                     (x-y-z (?warp-size warpSize 1)) ;; load size
                     #f)
    ;; sketch ends
    (check-warp-input warp-input-spec I I-cached warpId blockId))

  (run-kernel conv1d-load (x-y-z block-size) (x-y-z n-block) I warps)
  (define sol (time (solve (assert #t))))
  (when (sat? sol)
    (print-forms sol)
    ;; Disabled debugging aid; it matches on "stencil:..." source locations.
    #;(define sol-hash (match sol [(model m) m]))
    #;(for ([key-val (hash->list sol-hash)])
        (let ([key (car key-val)]
              [val (cdr key-val)])
          (when (string-contains? (format "~a" key) "stencil:115") ;; stride
            (assert (not (equal? key val)))
            (pretty-display `(v ,key ,val ,(string-contains? (format "~a" key) "stencil:113")))))))
  )
;(load-synth)

--------------------------------------------------------------------------------
/ex2-stencil.rkt:
--------------------------------------------------------------------------------
#|
| Copyright (c) 2018-2019, University of California, Berkeley.
| Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
|
| Redistribution and use in source and binary forms, with or without
| modification, are permitted provided that the following conditions are met:
|
| 1. Redistributions of source code must retain the above copyright notice,
| this list of conditions and the following disclaimer.
|
| 2. Redistributions in binary form must reproduce the above copyright notice,
| this list of conditions and the following disclaimer in the documentation
| and/or other materials provided with the distribution.
|
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
| ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
| CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
| SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
| INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
| CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
| ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
| POSSIBILITY OF SUCH DAMAGE.
|#

#lang rosette

(require "util.rkt" "cuda.rkt" "cuda-synth.rkt")

(define n-block 2)
(define /3 (lambda (x) (/ x 3)))

;; Create input and output matrices.
;; Sets the global warp size and returns, in order:
;;   block-size -- 2 * warpSize threads per block
;;   I-sizes    -- 2 * block-size input elements
;;   O-sizes    -- two fewer output elements (a 3-point window)
;;   I          -- input filled by gen-uid
;;   O, O*      -- empty outputs for the spec and the kernel
(define (create-IO warpSize)
  (set-warpSize warpSize)
  (define block-size (* 2 warpSize))
  (define I-sizes (x-y-z (* 2 block-size)))
  (define O-sizes (- I-sizes 2))
  (define I (create-matrix I-sizes gen-uid))
  (define O (create-matrix O-sizes))
  (define O* (create-matrix O-sizes))
  (values block-size I-sizes O-sizes I O O*))

;; Run sequential program spec and GPU kernel kernel, and compare their outputs.
(define (run-with-warp-size spec kernel w)
  (define-values (block-size I-sizes O-sizes I O O*)
    (create-IO w))

  (spec I O O-sizes)
  ;; run-kernel takes the block dimensions first, then the grid dimensions.
  (run-kernel kernel (x-y-z block-size) (x-y-z n-block) I O* I-sizes O-sizes)
  ;(acc-print O*)
  (acc-equal? O O*))

;; Sequential program spec: O[i] accumulates I[i], I[i+1], I[i+2] with +
;; and is finalized with /3.
(define (stencil-1d-spec I O o-sizes)
  (for ([i (get-x o-sizes)])
    (let ([o (create-accumulator (list +) /3)])
      (for ([j 3])
        (accumulate o (get I (+ i j))))
      (set O i o))))

;; Complete (hand-written) kernel.
;; Each warp asks global-to-local for warpSize + 2 consecutive inputs starting
;; at the warp's first global index. The load uses stride 1 and two rounds, so
;; every thread holds two registers. Each thread then gathers its three
;; neighbours with warp shuffles.
(define (stencil-1d threadId blockID blockDim I O I-sizes O-sizes)
  (define I-cached (create-matrix-local (x-y-z 2)))
  (define warpID (get-warpId threadId))
  ;; First global index loaded by this thread's warp.
  (define offset (+ (* blockID blockDim) (* warpID warpSize)))
  (define gid (get-global-threadId threadId blockID))
  (global-to-local I I-cached
                   (x-y-z 1)
                   offset
                   (x-y-z (+ warpSize 2)) #f #:round 2)

  (define localId (get-idInWarp threadId))
  (define o (create-accumulator (list +) /3 blockDim))
  (for ([i 3])
    (let* (;; shfl reads the source lane modulo warpSize. A reader with
           ;; localId + i >= warpSize therefore wraps around to a lane below i,
           ;; so those lanes offer their second register (index 1).
           [index (ite (< localId i) 1 0)]
           [lane (+ i localId)]
           [x (shfl (get I-cached index) lane)])
      ;(pretty-display `(lane ,i ,localId ,lane))
      (accumulate o x)))
  (reg-to-global o O gid))

;; Kernel sketch for synthesis. It has the same structure as stencil-1d, but
;; the register-index condition, the shuffle lane, and the accumulate predicate
;; are placeholders (?cond, ?sw-xform from cuda-synth.rkt). `synthesis` asks
;; the solver to fill them in.
(define (stencil-1d-sketch threadId blockID blockDim I O I-sizes O-sizes)
  (define I-cached (create-matrix-local (x-y-z 2)))
  (define gid (+ (* blockID blockDim) threadId))
  (define localId (get-idInWarp threadId))
  ;; Presumably the warp's first global index (assuming get-idInWarp is the
  ;; lane id) -- TODO confirm against cuda.rkt.
  (define offset (- gid localId))
  (global-to-local I I-cached
                   (x-y-z 1)
                   offset
                   (x-y-z (+ warpSize 2)) #f #:round 2)

  (define o (create-accumulator (list +) /3 blockDim))

  (for/bounded ([i 3])
    (let* ([index (ite (?cond (@dup i) localId) (@dup 0) (@dup 1))]
           [lane (?sw-xform localId warpSize
                            i warpSize [])]
           [x (shfl (get I-cached index) lane)])
      (accumulate o x #:pred (?cond localId (@dup i)))))

  (reg-to-global o O gid))

;; Check correctness of the complete kernel against the spec.
;; This previously ran stencil-1d-sketch, whose placeholders are unsolved
;; outside `synthesis`, so it did not test the hand-written kernel at all.
(define (test)
  (for ([w (list 32)])
    (let ([ret (run-with-warp-size stencil-1d-spec stencil-1d w)])
      (pretty-display `(test ,w ,ret)))))
;(test)

;; Synthesize the placeholders in stencil-1d-sketch so that it agrees with the
;; spec for every listed warp size, then print the completed forms.
(define (synthesis)
  (pretty-display "solving...")
  (define sol
    (time (solve
           (assert (andmap
                    (lambda (w) (run-with-warp-size stencil-1d-spec stencil-1d-sketch w))
                    (list 32))))))
  (print-forms sol))

;; Running this module performs synthesis and evaluates to the elapsed seconds.
(define t0 (current-seconds))
(synthesis)
(define t1 (current-seconds))
(- t1 t0)

;; Synthesize data loading from global memory.
136 | (define (load-synth) 137 | (define-values (block-size I-sizes O-sizes I O O*) 138 | (create-IO 4)) 139 | 140 | ;; Store 141 | (define (stencil-1d-store threadId blockId blockDim O) 142 | (define warpID (get-warpId threadId)) 143 | (define o 144 | (for/vector ([w warpID] 145 | [t threadId]) 146 | (ID t w blockId))) 147 | (reg-to-global o O (get-global-threadId threadId blockId)) 148 | ) 149 | 150 | ;; Run spec 151 | (stencil-1d-spec I O O-sizes) 152 | 153 | ;; Collect IDs 154 | (define IDs (create-matrix O-sizes)) 155 | (run-kernel stencil-1d-store (x-y-z block-size) (x-y-z n-block) IDs) 156 | (define-values (threads warps blocks) (get-grid-storage)) 157 | (collect-inputs O IDs threads warps blocks) 158 | (define n-regs (num-regs warps I)) 159 | (pretty-display `(n-regs ,n-regs)) 160 | 161 | ;; Load 162 | (define (conv1d-load threadId blockId blockDim I warp-input-spec) 163 | (define warpId (get-warpId threadId)) 164 | ;; sketch starts 165 | (define I-cached (create-matrix-local (x-y-z n-regs))) 166 | (global-to-local I I-cached 167 | (x-y-z (??)) ;; stride 168 | (x-y-z (?warp-offset [(get-x blockId) (get-x blockDim)] [warpId warpSize])) ;; offset 169 | (x-y-z (?warp-size warpSize 1)) ;; load size 170 | #f) 171 | 172 | ;; sketch ends 173 | (check-warp-input warp-input-spec I I-cached warpId blockId) 174 | ) 175 | 176 | (run-kernel conv1d-load (x-y-z block-size) (x-y-z n-block) I warps) 177 | (define sol (time (solve (assert #t)))) 178 | (when (sat? sol) 179 | (print-forms sol) 180 | #;(define sol-hash (match sol [(model m) m])) 181 | #;(for ([key-val (hash->list sol-hash)]) 182 | (let ([key (car key-val)] 183 | [val (cdr key-val)]) 184 | (when (string-contains? (format "~a" key) "stencil:115") ;; stride 185 | (assert (not (equal? key val))) 186 | (pretty-display `(v ,key ,val ,(string-contains? 
(format "~a" key) "stencil:113"))))) 187 | )) 188 | ) 189 | ;(load-synth) 190 | -------------------------------------------------------------------------------- /ex2-stencil2d.rkt: -------------------------------------------------------------------------------- 1 | #| 2 | | Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved. 3 | | 4 | | Redistribution and use in source and binary forms, with or without 5 | | modification, are permitted provided that the following conditions are met: 6 | | 7 | | 1. Redistributions of source code must retain the above copyright notice, 8 | | this list of conditions and the following disclaimer. 9 | | 10 | | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | | this list of conditions and the following disclaimer in the documentation 12 | | and/or other materials provided with the distribution. 13 | | 14 | | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 17 | | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 18 | | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 19 | | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 20 | | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 21 | | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 22 | | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 23 | | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 24 | | POSSIBILITY OF SUCH DAMAGE. 
25 | |# 26 | 27 | #lang rosette 28 | 29 | (require "util.rkt" "cuda.rkt" "cuda-synth.rkt") 30 | 31 | (define WARP_SIZE 16) 32 | (define n-block (x-y-z 1 1)) 33 | (define /9 (lambda (x) (/ x 9))) 34 | (define W 4) 35 | (define H 4) 36 | (define warp-shape (x-y-z W H)) 37 | 38 | (define (create-IO warpSize) 39 | (pretty-display `(warpSize ,warpSize)) 40 | (set-warpSize warpSize) 41 | (define block-size (x-y-z (* 1 warpSize) 1)) 42 | (define I-sizes (+ warp-shape 6)) 43 | (define O-sizes (- I-sizes 6)) 44 | (define I (create-matrix I-sizes gen-uid)) 45 | (define O (create-matrix O-sizes)) 46 | (define O* (create-matrix O-sizes)) 47 | (values block-size I-sizes O-sizes I O O*)) 48 | 49 | (define (run-with-warp-size spec kernel w) 50 | (define-values (block-size I-sizes O-sizes I O O*) 51 | (create-IO w)) 52 | 53 | (spec I O O-sizes) 54 | (run-kernel kernel block-size n-block I O* I-sizes O-sizes) 55 | ;(pretty-display ">>> O") 56 | ;(acc-print O) 57 | ;(pretty-display ">>> O*") 58 | ;(acc-print O*) 59 | (acc-equal? 
O O*) 60 | ) 61 | 62 | (define (stencil-2d-spec I O o-sizes) 63 | (for* ([j (get-y o-sizes)] 64 | [i (get-x o-sizes)]) 65 | (let ([o (create-accumulator (list +) /9)]) 66 | (for* ([jj 7] [ii 7]) 67 | (accumulate o (get I (+ i ii) (+ j jj))) 68 | (set O i j o))))) 69 | 70 | (define (stencil-2d threadId blockID blockDim I O I-sizes O-sizes) 71 | (define gid (+ (* blockID blockDim) threadId)) 72 | (define gx (get-x gid)) 73 | (define gy (get-y gid)) 74 | (define id (modulo (get-x threadId) warpSize)) 75 | (define warp-col (modulo id W)) 76 | (define warp-row (quotient id W)) 77 | 78 | (define offset-x (* (quotient gx warpSize) W)) 79 | (define offset-y (* gy H)) 80 | 81 | (define I-cached (create-matrix-local (x-y-z 2 2))) 82 | (global-to-local I I-cached 83 | (x-y-z 1 1) 84 | (lov2vol (x-y-z offset-x offset-y)) 85 | (+ warp-shape 2) #f 86 | #:warp-shape warp-shape #:round (x-y-z 2 2)) 87 | 88 | (define o (create-accumulator (list +) /9 blockDim)) 89 | 90 | (for* ([ky 3] [kx 3]) 91 | (let* ([index-j (ite (< warp-row ky) 1 0)] 92 | [index-i (ite (< warp-col kx) 1 0)] 93 | [lane-x (sw-xform warp-col W 1 W W 1 94 | kx 3 1 3)] 95 | [lane-y (sw-xform warp-row H 1 H H 1 96 | ky 3 1 3)] 97 | [lane (+ (* lane-y W) lane-x)] 98 | [x (shfl (get I-cached index-i index-j) lane)]) 99 | (accumulate o x) 100 | )) 101 | ;(acc-print o) 102 | (reg-to-global (accumulate-final o) O 103 | (lov2vol (x-y-z (+ offset-x warp-col) (+ offset-y warp-row)))) 104 | ) 105 | 106 | (define (stencil-2d-sketch threadId blockID blockDim I O I-sizes O-sizes) 107 | (define gid (+ (* blockID blockDim) threadId)) 108 | (define gx (get-x gid)) 109 | (define gy (get-y gid)) 110 | (define id (modulo (get-x threadId) warpSize)) 111 | (define warp-col (modulo id W)) 112 | (define warp-row (quotient id W)) 113 | 114 | (define offset-x (* (quotient gx warpSize) W)) 115 | (define offset-y (* gy H)) 116 | 117 | (define I-cached (create-matrix-local (x-y-z 2 2))) 118 | (global-to-local I I-cached 119 | (x-y-z 1 1) 120 
| (lov2vol (x-y-z offset-x offset-y)) 121 | (+ warp-shape 2) #f 122 | #:warp-shape warp-shape #:round (x-y-z 2 2)) 123 | 124 | (define o (create-accumulator (list +) /9 blockDim)) 125 | 126 | (for* ([ky 3] [kx 3]) 127 | (let* ([index-j (ite (?cond warp-row ky) (@dup 0) (@dup 1))] 128 | [index-i (ite (?cond warp-col kx) (@dup 0) (@dup 1))] 129 | [lane-x (?sw-xform warp-col W 130 | kx 3 [])] 131 | [lane-y (?sw-xform warp-row H 132 | ky 3 [])] 133 | [lane (+ (* lane-y W) lane-x)] 134 | [x (shfl (get I-cached index-i index-j) lane)]) 135 | (accumulate o x) 136 | )) 137 | ;(acc-print o) 138 | (reg-to-global (accumulate-final o) O 139 | (lov2vol (x-y-z (+ offset-x warp-col) (+ offset-y warp-row)))) 140 | ) 141 | 142 | (define (stencil-2d-sketch2 threadId blockID blockDim I O I-sizes O-sizes) 143 | (define gid (+ (* blockID blockDim) threadId)) 144 | (define gx (get-x gid)) 145 | (define gy (get-y gid)) 146 | (define id (modulo (get-x threadId) warpSize)) 147 | (define warp-col (modulo id W)) 148 | (define warp-row (quotient id W)) 149 | 150 | (define offset-x (* (quotient gx warpSize) W)) 151 | (define offset-y (* gy H)) 152 | 153 | (define I-cached (create-matrix-local (x-y-z 3 3))) 154 | (global-to-local I I-cached 155 | (x-y-z 1 1) 156 | (lov2vol (x-y-z offset-x offset-y)) 157 | (+ warp-shape 6) #f 158 | #:warp-shape warp-shape #:round (x-y-z 3 3)) 159 | 160 | (define o (create-accumulator (list +) /9 blockDim)) 161 | 162 | (for* ([ky 7] [kx 7]) 163 | (let* ([index-j (ite (?cond warp-row ky [H]) (@dup 0) (ite (?cond warp-row ky [H]) (@dup 1) (@dup 2)))] 164 | [index-i (ite (<= kx warp-col) (@dup 0) (ite (<= kx (+ W warp-col)) (@dup 1) (@dup 2)))] 165 | [lane-x (?sw-xform-easy warp-col W 166 | kx 7 [])] 167 | [lane-y (?sw-xform-easy warp-row H 168 | ky 7 [])] 169 | [lane (+ (* lane-y W) lane-x)] 170 | [x (shfl (get I-cached index-i index-j) lane)]) 171 | (accumulate o x) 172 | )) 173 | ;(acc-print o) 174 | (reg-to-global (accumulate-final o) O 175 | (lov2vol (x-y-z (+ 
offset-x warp-col) (+ offset-y warp-row)))) 176 | ) 177 | 178 | (define (stencil-2d-sketch2-sol threadId blockID blockDim I O I-sizes O-sizes) 179 | (define gid (+ (* blockID blockDim) threadId)) 180 | (define gx (get-x gid)) 181 | (define gy (get-y gid)) 182 | (define id (modulo (get-x threadId) warpSize)) 183 | (define warp-col (modulo id W)) 184 | (define warp-row (quotient id W)) 185 | 186 | (define offset-x (* (quotient gx warpSize) W)) 187 | (define offset-y (* gy H)) 188 | 189 | (define I-cached (create-matrix-local (x-y-z 3 3))) 190 | (global-to-local I I-cached 191 | (x-y-z 1 1) 192 | (lov2vol (x-y-z offset-x offset-y)) 193 | (+ warp-shape 6) #f 194 | #:warp-shape warp-shape #:round (x-y-z 3 3)) 195 | 196 | (define o (create-accumulator (list +) /9 blockDim)) 197 | 198 | (for* ([ky 7] [kx 7]) 199 | (let* ([index-j (ite (<= ky warp-row) (@dup 0) (ite (<= ky (+ H warp-row)) (@dup 1) (@dup 2)))] 200 | [index-i (ite (<= kx warp-col) (@dup 0) (ite (<= kx (+ W warp-col)) (@dup 1) (@dup 2)))] 201 | [lane-x (sw-xform warp-col W 1 W W 1 ;(choose 1 -1) 202 | kx 7 1 7 0)] 203 | [lane-y (sw-xform warp-row H 1 H H 1 ;(choose 1 -1) 204 | ky 7 1 7 0)] 205 | [lane (+ (* lane-y W) lane-x)] 206 | [x (shfl (get I-cached index-i index-j) lane)]) 207 | (accumulate o x) 208 | )) 209 | ;(acc-print o) 210 | (reg-to-global (accumulate-final o) O 211 | (lov2vol (x-y-z (+ offset-x warp-col) (+ offset-y warp-row)))) 212 | ) 213 | 214 | (define (test) 215 | (for ([w (list WARP_SIZE)]) 216 | (let ([ret (run-with-warp-size stencil-2d-spec stencil-2d-sketch2-sol w)]) 217 | (pretty-display `(test ,w ,ret)))) 218 | ) 219 | ;(test) 220 | 221 | (define (synthesis) 222 | (pretty-display "solving...") 223 | (define sol 224 | (time (solve 225 | (assert (andmap 226 | (lambda (w) (run-with-warp-size stencil-2d-spec stencil-2d-sketch2 w)) 227 | (list WARP_SIZE)))))) 228 | (print-forms sol) 229 | ;(print-lane 'lane (evaluate my-lane sol) '#(localId i) '#()) 230 | ) 231 | (define t0 
(current-seconds)) 232 | (synthesis) 233 | (define t1 (current-seconds)) 234 | (- t1 t0) 235 | -------------------------------------------------------------------------------- /ex3-poly-mult-load-only.rkt: -------------------------------------------------------------------------------- 1 | #| 2 | | Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved. 3 | | 4 | | Redistribution and use in source and binary forms, with or without 5 | | modification, are permitted provided that the following conditions are met: 6 | | 7 | | 1. Redistributions of source code must retain the above copyright notice, 8 | | this list of conditions and the following disclaimer. 9 | | 10 | | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | | this list of conditions and the following disclaimer in the documentation 12 | | and/or other materials provided with the distribution. 13 | | 14 | | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 17 | | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 18 | | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 19 | | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 20 | | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 21 | | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 22 | | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 23 | | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 24 | | POSSIBILITY OF SUCH DAMAGE. 
|#

#lang rosette

(require rosette/lib/synthax)
(require "util.rkt" "cuda.rkt" "cuda-synth.rkt")

(define n-block 1)

;; Sets the warp size and builds the data for one run:
;;   A, B  -- n-element operands filled with gen-uid
;;   C, C* -- 2n-element outputs (C for the spec, C* for the kernel)
;; Returns (values block-size sizes A B C C*); block-size is the warp size.
(define (create-IO warpSize n)
  (set-warpSize warpSize)
  (define sizes (x-y-z n))
  (define out-sizes (* 2 sizes))
  (define A (create-matrix sizes gen-uid))
  (define B (create-matrix sizes gen-uid))
  (values warpSize sizes A B (create-matrix out-sizes) (create-matrix out-sizes)))

;; Runs `spec` and the GPU `kernel` on fresh inputs (warp size w, n elements)
;; and returns the result of acc-equal? on the two outputs.
(define (run-with-warp-size spec kernel w n)
  (define-values (block-size sizes A B expected actual) (create-IO w n))
  (spec A B expected sizes)
  (run-kernel kernel (x-y-z block-size) (x-y-z n-block) A B actual sizes)
  ;(acc-print actual)
  (acc-equal? expected actual))

;; Sequential spec with n = (get-x sizes). For every k in [0, n):
;;   C[k]     accumulates the pairs (A[i], B[k - i])     for i in [0, k]
;;   C[n + k] accumulates the pairs (A[i], B[n + k - i]) for i in (k, n)
;; Each output cell is an accumulator created with (list bvand bvxor).
(define (mult-spec A B C sizes)
  (define n (get-x sizes))
  (for ([k (in-range n)])
    (define lo (create-accumulator (list bvand bvxor) identity))
    (for ([i (in-range 0 (add1 k))])
      (accumulate lo (list (get A i) (get B (- k i)))))
    (set C k lo)
    (define hi (create-accumulator (list bvand bvxor) identity))
    (for ([i (in-range (add1 k) n)])
      (accumulate hi (list (get A i) (get B (- (+ k n) i)))))
    (set C (+ n k) hi)))

;; Synthesize data loading from global memory.
;; Synthesizes the global-memory load code (the ?warp-offset / ?warp-size holes
;; in mult-load) for the polynomial-multiplication kernel.
;; w is the warp size and n the number of input elements; both default to 4,
;; the size this function originally hard-coded.
(define (load-synth [w 4] [n w])
  ;; create-IO takes (warpSize n) and returns six values. C-full is the
  ;; 2n-element spec output; C-full* (the kernel output) is not used here.
  (define-values (block-size sizes A B C-full C-full*)
    (create-IO w n))

  ;; Run the spec so the output cells hold accumulators over the input pairs
  ;; they depend on (presumably what collect-inputs inspects -- TODO confirm).
  (mult-spec A B C-full sizes)

  ;; Split the spec output into the low half C (indices 0..n-1) and the high
  ;; half D (indices n..2n-1), matching the two regions mult-store writes.
  (define C (create-matrix sizes))
  (define D (create-matrix sizes))
  (for ([i (get-x sizes)])
    (set C i (get C-full i))
    (set D i (get C-full (+ (get-x sizes) i))))

  ;; Store: every thread writes its (thread, warp, block) ID to C and D.
  (define (mult-store threadId blockId blockDim C D)
    (define warpID (get-warpId threadId))
    (define o
      (for/vector ([w warpID]
                   [t threadId])
        (ID t w blockId)))
    (reg-to-global o C threadId)
    (reg-to-global o D threadId)
    )

  ;; Collect IDs
  (define C-IDs (create-matrix sizes))
  (define D-IDs (create-matrix sizes))
  (run-kernel mult-store sizes (x-y-z n-block) C-IDs D-IDs)

  (define-values (C-threads C-warps C-blocks) (get-grid-storage))
  (collect-inputs C C-IDs C-threads C-warps C-blocks)
  (define-values (D-threads D-warps D-blocks) (get-grid-storage))
  (collect-inputs D D-IDs D-threads D-warps D-blocks)

  ;; Number of registers each thread needs for A and B, derived from the
  ;; per-warp storage collected above.
  (define warps (vector-list-append C-warps D-warps))
  (define a-regs (num-regs warps A))
  (pretty-display `(a-regs ,a-regs))
  (define b-regs (num-regs warps B))
  (pretty-display `(b-regs ,b-regs))

  ;; Load
  (define (mult-load threadId blockId blockDim A B C-warp-spec D-warp-spec)
    (define warpId (get-warpId threadId))
    ;; sketch starts
    (define A-cached (create-matrix-local (x-y-z a-regs)))
    (define B-cached (create-matrix-local (x-y-z b-regs)))
    (global-to-local A A-cached
                     (x-y-z 1) ;; stride
                     (x-y-z (?warp-offset [(get-x blockId) (get-x blockDim)] [warpId warpSize])) ;; offset
                     (x-y-z (?warp-size warpSize 1)) ;; load size --> TODO: minimize load size
                     #f)
    (global-to-local B B-cached
                     (x-y-z 1) ;; stride
                     (x-y-z (?warp-offset [(get-x blockId) (get-x blockDim)] [warpId warpSize])) ;; offset
                     (x-y-z (?warp-size warpSize 1)) ;; load size
                     #f)
    ;; sketch ends
    (check-warp-input C-warp-spec A A-cached warpId blockId)
    (check-warp-input D-warp-spec A A-cached warpId blockId)
    (check-warp-input C-warp-spec B B-cached warpId blockId)
    (check-warp-input D-warp-spec B B-cached warpId blockId)
    )

  (run-kernel mult-load sizes (x-y-z n-block) A B C-warps D-warps)
  (define sol
    (time
     (synthesize
      #:forall (append (symbolics A) (symbolics B))
      #:guarantee (assert #t))))
  (print-forms sol)
  )
;(load-synth)

-------------------------------------------------------------------------------- /ex3-poly-mult-noacc.rkt: --------------------------------------------------------------------------------
#|
| Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
|
| Redistribution and use in source and binary forms, with or without
| modification, are permitted provided that the following conditions are met:
|
| 1. Redistributions of source code must retain the above copyright notice,
| this list of conditions and the following disclaimer.
|
| 2. Redistributions in binary form must reproduce the above copyright notice,
| this list of conditions and the following disclaimer in the documentation
| and/or other materials provided with the distribution.
|
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
| ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
| CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
| SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
| INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
| CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
| ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
| POSSIBILITY OF SUCH DAMAGE.
|#

#lang rosette

(require rosette/lib/synthax)
(require "util.rkt" "cuda.rkt" "cuda-synth.rkt")

(define n-block 1)
(define block-dim-y 1)

;; Symbolic constants of the inputs most recently built by create-IO; used as
;; the #:forall set in synthesis.
(define syms #f)

;; Sets the warp size and builds the matrices for one run.
;; A and B are (x-y-z n n-block) matrices filled by gen-bv; C (spec output) and
;; C* (kernel output) are (x-y-z (* 2 n) n-block). Also records the symbolic
;; constants of A and B in syms.
;; Returns (values A B C C*).
(define (create-IO warpSize n)
  (set-warpSize warpSize)
  (define A (create-matrix (x-y-z n n-block) gen-bv))
  (define B (create-matrix (x-y-z n n-block) gen-bv))
  (set! syms (append (symbolics A) (symbolics B)))
  (define C (create-matrix (x-y-z (* 2 n) n-block)))
  (define C* (create-matrix (x-y-z (* 2 n) n-block)))
  (values A B C C*))

;; Runs the sequential spec and the GPU kernel on fresh inputs for warp size w
;; and degree n. Returns #t iff every element of C is bveq to the matching
;; element of C*.
(define (run-with-warp-size spec kernel w n)
  (define-values (A B C C*)
    (create-IO w n))

  (spec A B C n n-block)
  (run-kernel kernel (x-y-z w block-dim-y) (x-y-z 1 n-block) A B C* n)
  ;(pretty-display ">>> C")
  ;(acc-print C)
  ;(pretty-display ">>> C*")
  ;(acc-print C*)
  ;(acc-equal? C C*)
  ;; Plain bitvector comparison: in this file the outputs hold bitvectors, not
  ;; accumulators (see mult-spec / update).
  (for/and ([row C] [row* C*])
    (for/and ([e row] [e* row*])
      (bveq e e*)))
  )

;; One multiply-accumulate step on bitvectors: returns c XOR (a AND b).
(define (update c a b)
  (bvxor
   c
   (bvand a b)))

;; Sequential spec. For each row in [0, rows) and each index in [0, n):
;;   C[index]     = XOR over i in [0, index] of A[i] AND B[index - i]
;;   C[n + index] = XOR over i in (index, n) of A[i] AND B[n + index - i]
;; Sums start from (bv 0 4); assumes gen-bv produces 4-bit values -- TODO confirm.
(define (mult-spec A B C n rows)
  (for ([row rows])
    (for ([index n])
      (let ([c (bv 0 4)])
        (for ([i (add1 index)])
          (let ([a (get A i row)]
                [b (get B (- index i) row)])
            (set! c (update c a b))))
        (set C index row c))
      (let ([d (bv 0 4)])
        (for ([i (range (add1 index) n)])
          (let ([a (get A i row)]
                [b (get B (- (+ index n) i) row)])
            (set! d (update d a b))))
        (set C (+ n index) row d)))))


;; Complete kernel (named for degree 32) that keeps plain bitvector sums instead
;; of accumulators. Each thread loads one element of A and B; for every i it
;; fetches a and b via shfl from lanes computed by sw-xform and folds them into
;; acc1 when i <= tidx, or into acc2 when i > tidx. acc2 is stored n elements
;; further along x than acc1.
(define (mult32 threadId blockID blockDim A B C n)
  (define globalID (+ threadId (* blockID blockDim)))
  (define a-cached 0)
  (define b-cached 0)
  (global-to-reg A a-cached globalID #:size (x-y-z n))
  (global-to-reg B b-cached globalID #:size (x-y-z n))

  (define tidx (modulo (get-x threadId) 32))
  (define acc1 (bv 0 4))
  (define acc2 (bv 0 4))

  (for ([i n])
    (let* ([lane-a (sw-xform tidx warpSize 0 warpSize warpSize 1
                             i warpSize 1 warpSize)]
           [lane-b (sw-xform tidx warpSize 1 warpSize warpSize 1
                             i warpSize (- warpSize 1) warpSize)]
           [a (shfl a-cached lane-a)]
           [b (shfl b-cached lane-b)]
           )
      (set! acc1 (ite (<= (@dup i) tidx) (update acc1 a b) acc1))
      (set! acc2 (ite (> (@dup i) tidx) (update acc2 a b) acc2))
      ))

  (reg-to-global acc1 C globalID #:size (x-y-z (* 2 n)))
  (reg-to-global acc2 C (+ globalID (@dup (x-y-z n 0))) #:size (x-y-z (* 2 n)))
  )

;; Sketch that doesn't use accumulators.
;; Register version: operands are exchanged with shfl; the lane indices are
;; ?lane-mod holes for the synthesizer to fill.
(define (mult32-sketch threadId blockID blockDim A B C n)
  ;; For 2D kernel like this, threadId, blockID, and blockDim contain two values: .x and .y.
  ;; (* blockID blockDim) = (x-y-z (* blockID.x blockDim.x) (* blockID.y blockDim.y))
  ;; x-y-z is for creating a tuple of values
  (define globalID (+ threadId (* blockID blockDim)))
  (define a-cached 0)
  (define b-cached 0)
  (global-to-reg A a-cached globalID #:size (x-y-z n))
  (global-to-reg B b-cached globalID #:size (x-y-z n))

  (define tidx (modulo (get-x threadId) 32)) ;; threadId.x % 32
  (define acc1 (bv 0 4)) ;; not accumulator
  (define acc2 (bv 0 4)) ;; not accumulator

  (for ([i n])
    (let* (;[lane-a (?sw-xform-easy tidx warpSize i warpSize [])]
           ;[lane-b (?sw-xform-easy tidx warpSize i warpSize [])]
           [lane-a (?lane-mod tidx (@dup i) 2 n [warpSize])]
           [lane-b (?lane-mod tidx (@dup i) 2 n [warpSize])]
           [a (shfl a-cached lane-a)]
           [b (shfl b-cached lane-b)]
           )
      (set! acc1 (ite (<= (@dup i) tidx) (update acc1 a b) acc1))
      (set! acc2 (ite (> (@dup i) tidx) (update acc2 a b) acc2))
      ))

  (reg-to-global acc1 C globalID #:size (x-y-z (* 2 n)))
  (reg-to-global acc2 C (+ globalID (@dup (x-y-z n 0))) #:size (x-y-z (* 2 n)))
  )

;; Shared-memory sketch that doesn't use accumulators.
;; A and B are staged in shared memory (define-shared + global-to-shared) and
;; read with get at ?lane-mod hole indices instead of being exchanged with shfl.
(define (mult32-shared-sketch threadId blockID blockDim A B C n)
  ;; For 2D kernel like this, threadId, blockID, and blockDim contain two values: .x and .y.
  ;; (* blockID blockDim) = (x-y-z (* blockID.x blockDim.x) (* blockID.y blockDim.y))
  ;; x-y-z is for creating a tuple of values
  (define globalID (+ threadId (* blockID blockDim)))
  (define-shared a-cached (create-matrix blockDim))
  (define-shared b-cached (create-matrix blockDim))
  (global-to-shared A a-cached
                    (x-y-z 1 1) ;; stride
                    (* blockDim blockID)
                    blockDim
                    #f #:round (x-y-z 1 1) #:size (x-y-z n 1))
  (global-to-shared B b-cached
                    (x-y-z 1 1) ;; stride
                    (* blockDim blockID)
                    blockDim
                    #f #:round (x-y-z 1 1) #:size (x-y-z n 1))

  (define tidx (modulo (get-x threadId) 32)) ;; threadId.x % 32
  (define tidy (get-y threadId))
  (define acc1 (bv 0 4))
  (define acc2 (bv 0 4))

  (for ([i n])
    (let* (;[lane-a (?sw-xform-easy tidx warpSize i warpSize [])]
           ;[lane-b (?sw-xform-easy tidx warpSize i warpSize [])]
           [lane-a (?lane-mod tidx (@dup i) 2 n [warpSize])]
           [lane-b (?lane-mod tidx (@dup i) 2 n [warpSize])]
           [a (get a-cached lane-a tidy)]
           [b (get b-cached lane-b tidy)]
           )
      (set! acc1 (ite (<= (@dup i) tidx) (update acc1 a b) acc1))
      (set! acc2 (ite (> (@dup i) tidx) (update acc2 a b) acc2))
      ))

  (reg-to-global acc1 C globalID #:size (x-y-z (* 2 n)))
  (reg-to-global acc2 C (+ globalID (@dup (x-y-z n 0))) #:size (x-y-z (* 2 n)))
  )


;; Checks the complete kernel mult32 against the spec at warp size 32 and
;; prints the result and the cost reported by get-cost.
(define (test)
  (for ([w (list 32)])
    (let ([ret (run-with-warp-size mult-spec mult32 w (* 1 w))])
      (pretty-display `(test ,w ,ret))
      (pretty-display `(cost ,(get-cost)))
      ))
  )
;(test)

;; Synthesizes the holes of mult32-shared-sketch so that it matches the spec
;; at warp size 4 for all values of the input symbolics (syms).
(define (synthesis)
  (pretty-display "solving...")
  (define t (andmap
             (lambda (w) (run-with-warp-size mult-spec mult32-shared-sketch w (* 1 w)))
             (list 4)))
  ;(define cost (get-cost))
  (define sol (time (synthesize #:forall syms
                                #:guarantee (assert t))))

  ;(define this-cost (evaluate cost sol))
  (print-forms sol)
  ;(pretty-display `(cost ,this-cost))
  )
;; Run synthesis when the module is loaded and report wall-clock seconds.
(define t0 (current-seconds))
(synthesis)
(define t1 (current-seconds))
(- t1 t0)

-------------------------------------------------------------------------------- /ex3-poly-mult.rkt: --------------------------------------------------------------------------------
#|
| Copyright (c) 2018-2019, University of California, Berkeley.
| Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
|
| Redistribution and use in source and binary forms, with or without
| modification, are permitted provided that the following conditions are met:
|
| 1. Redistributions of source code must retain the above copyright notice,
| this list of conditions and the following disclaimer.
|
| 2. Redistributions in binary form must reproduce the above copyright notice,
| this list of conditions and the following disclaimer in the documentation
| and/or other materials provided with the distribution.
|
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
| ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
| CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
| SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
| INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
| CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
| ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
| POSSIBILITY OF SUCH DAMAGE.
|#

#lang rosette

(require rosette/lib/synthax)
(require "util.rkt" "cuda.rkt" "cuda-synth.rkt")

(define n-block 1)
(define block-dim-y 1)

;; Create input and output matrices.
;; Sets the warp size. A and B are (x-y-z n n-block) matrices filled by gen-uid;
;; C (spec output) and C* (kernel output) are (x-y-z (* 2 n) n-block).
;; Returns (values A B C C*).
(define (create-IO warpSize n)
  (set-warpSize warpSize)
  (define A (create-matrix (x-y-z n n-block) gen-uid))
  (define B (create-matrix (x-y-z n n-block) gen-uid))
  (define C (create-matrix (x-y-z (* 2 n) n-block)))
  (define C* (create-matrix (x-y-z (* 2 n) n-block)))
  (values A B C C*))

;; Run sequential program spec and GPU kernel kernel, and compare their outputs.
;; n is the degree of polynomial multiplication.
;; Returns the result of acc-equal? on the spec output C and kernel output C*.
(define (run-with-warp-size spec kernel w n)
  (define-values (A B C C*)
    (create-IO w n))

  (spec A B C n n-block)
  (run-kernel kernel (x-y-z w block-dim-y) (x-y-z 1 n-block) A B C* n)
  ;(pretty-display ">>> C")
  ;(acc-print C)
  ;(pretty-display ">>> C*")
  ;(acc-print C*)
  (acc-equal? C C*))

;; Sequential program spec
;; For each row in [0, rows) and each index in [0, n):
;;   C[index]     accumulates (A[i], B[index - i])     for i in [0, index]
;;   C[n + index] accumulates (A[i], B[n + index - i]) for i in (index, n)
;; Each output cell is an accumulator created with (list bvand bvxor).
(define (mult-spec A B C n rows)
  (for ([row rows])
    (for ([index n])
      (let ([c (create-accumulator (list bvand bvxor) identity)])
        (for ([i (add1 index)])
          (let ([a (get A i row)]
                [b (get B (- index i) row)])
            (accumulate c (list a b))))
        (set C index row c))
      (let ([d (create-accumulator (list bvand bvxor) identity)])
        (for ([i (range (add1 index) n)])
          (let ([a (get A i row)]
                [b (get B (- (+ index n) i) row)])
            (accumulate d (list a b))))
        (set C (+ n index) row d)))))

;; Complete kernel for poly-mult degree 32
;; Each thread loads one element of A and B into registers. For every i it reads
;; lane i of A and lane (tidx - i) of B via shfl and accumulates into acc1 when
;; i <= tidx and into acc2 when i > tidx. acc2 is stored n elements further
;; along x than acc1.
(define (mult threadId blockID blockDim A B C n)
  (define block-offset (* blockID blockDim))
  (define globalID (+ threadId block-offset))
  (define a-cached 0)
  (define b-cached 0)
  (global-to-reg A a-cached globalID)
  (global-to-reg B b-cached globalID)

  (define tidx (get-x threadId))
  (define acc1 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc2 (create-accumulator (list bvand bvxor) identity blockDim))

  (for ([i n])
    (let ([a (shfl a-cached i)]
          [b (shfl b-cached (- tidx i))])
      (accumulate acc1 (list a b) #:pred (<= i tidx))))

  (for ([i n])
    (let ([a (shfl a-cached i)]
          [b (shfl b-cached (- tidx i))])
      (accumulate acc2 (list a b) #:pred (> i tidx))))

  (reg-to-global acc1 C (+ block-offset threadId))
  (reg-to-global acc2 C (+ block-offset (@dup (x-y-z n 0)) threadId))
  )

;; Complete kernel for poly-mult degree 32 using sw-xform function
;; Same structure as mult, but the shfl lanes come from sw-xform and both
;; accumulators are updated in a single loop.
(define (mult32 threadId blockID blockDim A B C n)
  (define globalID (+ threadId (* blockID blockDim)))
  (define a-cached 0)
  (define b-cached 0)
  (global-to-reg A a-cached globalID #:size (x-y-z n))
  (global-to-reg B b-cached globalID #:size (x-y-z n))

  (define tidx (modulo (get-x threadId) 32))
  (define acc1 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc2 (create-accumulator (list bvand bvxor) identity blockDim))

  (for ([i n])
    (let* ([lane-a (sw-xform tidx warpSize 0 warpSize warpSize 1
                             i warpSize 1 warpSize)]
           [lane-b (sw-xform tidx warpSize 1 warpSize warpSize 1
                             i warpSize (- warpSize 1) warpSize)]
           [a (shfl a-cached lane-a)]
           [b (shfl b-cached lane-b)]
           )
      (accumulate acc1 (list a b) #:pred (<= (@dup i) tidx))
      (accumulate acc2 (list a b) #:pred (> (@dup i) tidx))
      ))

  (reg-to-global acc1 C globalID #:size (x-y-z (* 2 n)))
  (reg-to-global acc2 C (+ globalID (@dup (x-y-z n 0))) #:size (x-y-z (* 2 n)))
  )

;; Kernel sketch for degree 32 using registers
;; Lane indices (?sw-xform-easy) and accumulation predicates (?cond-easy) are
;; holes for the synthesizer.
(define (mult32-sketch threadId blockID blockDim A B C n)
  ;; For 2D kernel like this, threadId, blockID, and blockDim contain two values: .x and .y.
  ;; (* blockID blockDim) = (x-y-z (* blockID.x blockDim.x) (* blockID.y blockDim.y))
  ;; x-y-z is for creating a tuple of values
  (define globalID (+ threadId (* blockID blockDim)))
  (define a-cached 0)
  (define b-cached 0)
  (global-to-reg A a-cached globalID #:size (x-y-z n))
  (global-to-reg B b-cached globalID #:size (x-y-z n))

  (define tidx (modulo (get-x threadId) 32)) ;; threadId.x % 32
  (define acc1 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc2 (create-accumulator (list bvand bvxor) identity blockDim))

  (for ([i n])
    (let* ([lane-a (?sw-xform-easy tidx warpSize i warpSize [])]
           [lane-b (?sw-xform-easy tidx warpSize i warpSize [])]
           ;[lane-a (?lane-mod tidx (@dup i) 2 n [warpSize])]
           ;[lane-b (?lane-mod tidx (@dup i) 2 n [warpSize])]
           [a (shfl a-cached lane-a)]
           [b (shfl b-cached lane-b)]
           )
      (accumulate acc1 (list a b) #:pred (?cond-easy tidx (@dup i)))
      (accumulate acc2 (list a b) #:pred (?cond-easy tidx (@dup i)))
      ))

  (reg-to-global acc1 C globalID #:size (x-y-z (* 2 n)))
  (reg-to-global acc2 C (+ globalID (@dup (x-y-z n 0))) #:size (x-y-z (* 2 n)))
  )

;; Kernel sketch for degree 32 using shared memory
;; Same holes as mult32-sketch, but A and B are staged in shared memory and read
;; with get instead of being exchanged with shfl.
(define (mult32-shared-sketch threadId blockID blockDim A B C n)
  ;; For 2D kernel like this, threadId, blockID, and blockDim contain two values: .x and .y.
  ;; (* blockID blockDim) = (x-y-z (* blockID.x blockDim.x) (* blockID.y blockDim.y))
  ;; x-y-z is for creating a tuple of values
  (define globalID (+ threadId (* blockID blockDim)))
  (define-shared a-cached (create-matrix blockDim))
  (define-shared b-cached (create-matrix blockDim))
  (global-to-shared A a-cached
                    (x-y-z 1 1) ;; stride
                    (* blockDim blockID)
                    blockDim
                    #f #:round (x-y-z 1 1) #:size (x-y-z n 1))
  (global-to-shared B b-cached
                    (x-y-z 1 1) ;; stride
                    (* blockDim blockID)
                    blockDim
                    #f #:round (x-y-z 1 1) #:size (x-y-z n 1))

  (define tidx (modulo (get-x threadId) 32)) ;; threadId.x % 32
  (define tidy (get-y threadId))
  (define acc1 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc2 (create-accumulator (list bvand bvxor) identity blockDim))

  (for ([i n])
    (let* ([lane-a (?sw-xform-easy tidx warpSize i warpSize [])]
           [lane-b (?sw-xform-easy tidx warpSize i warpSize [])]
           ;[lane-a (?lane-mod tidx (@dup i) 2 n [warpSize])]
           ;[lane-b (?lane-mod tidx (@dup i) 2 n [warpSize])]
           [a (get a-cached lane-a tidy)]
           [b (get b-cached lane-b tidy)]
           )
      (accumulate acc1 (list a b) #:pred (?cond-easy tidx (@dup i)))
      (accumulate acc2 (list a b) #:pred (?cond-easy tidx (@dup i)))
      ))

  (reg-to-global acc1 C globalID #:size (x-y-z (* 2 n)))
  (reg-to-global acc2 C (+ globalID (@dup (x-y-z n 0))) #:size (x-y-z (* 2 n)))
  )

;; Complete kernel for
degree 64 using registers
;; Each thread keeps two elements of A and of B in 2x1 local matrices and
;; produces four outputs, stored at x offsets 0, warpSize, 2*warpSize and
;; 3*warpSize from globalID.
(define (mult64 threadId blockID blockDim A B C n)
  (define globalID (+ threadId (* blockID blockDim)))
  (define a-cached (create-matrix-local (x-y-z 2 1)))
  (define b-cached (create-matrix-local (x-y-z 2 1)))
  (global-to-local A a-cached
                   (x-y-z 1 1) ;; stride
                   (* (quotient globalID (x-y-z warpSize 1))
                      (x-y-z n 1))
                   (x-y-z n 1)
                   #f #:warp-shape (x-y-z warpSize 1) #:round (x-y-z 2 1))
  (global-to-local B b-cached
                   (x-y-z 1 1) ;; stride
                   (* (quotient globalID (x-y-z warpSize 1))
                      (x-y-z n 1))
                   (x-y-z n 1)
                   #f #:warp-shape (x-y-z warpSize 1) #:round (x-y-z 2 1))

  (define tidx (get-idInWarp threadId))
  (define acc1 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc2 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc3 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc4 (create-accumulator (list bvand bvxor) identity blockDim))

  (for ([i warpSize])
    (let* ([lane-a1 (sw-xform tidx warpSize 0 warpSize warpSize 1
                              i warpSize 1 warpSize)]
           [lane-a2 (sw-xform tidx warpSize 0 warpSize warpSize 1
                              i warpSize 1 warpSize)]
           [lane-b1 (sw-xform tidx warpSize 1 warpSize warpSize 1
                              i warpSize (- warpSize 1) warpSize)]
           [lane-b2 (sw-xform tidx warpSize 1 warpSize warpSize 1
                              i warpSize (- warpSize 1) warpSize)]
           [a1 (shfl (get a-cached (@dup 0) (@dup 0)) lane-a1)]
           [a2 (shfl (get a-cached (@dup 1) (@dup 0)) lane-a2)]
           [b1 (shfl (get b-cached (@dup 0) (@dup 0)) lane-b1)]
           [b2 (shfl (get b-cached (@dup 1) (@dup 0)) lane-b2)]
           )
      (accumulate acc1 (list a1 b1) #:pred (<= i tidx))

      (accumulate acc2 (list a1 b1) #:pred (> i tidx))
      (accumulate acc2 (list a1 b2) #:pred (<= i tidx))
      (accumulate acc2 (list a2 b1) #:pred (<= i tidx))

      (accumulate acc3 (list a1 b2) #:pred (> i tidx))
      (accumulate acc3 (list a2 b1) #:pred (> i tidx))
      (accumulate acc3 (list a2 b2) #:pred (<= i tidx))

      (accumulate acc4 (list a2 b2) #:pred (> i tidx))
      ))

  (reg-to-global acc1 C globalID)
  (reg-to-global acc2 C (+ globalID (@dup (x-y-z warpSize 0))))
  (reg-to-global acc3 C (+ globalID (@dup (x-y-z (* 2 warpSize) 0))))
  (reg-to-global acc4 C (+ globalID (@dup (x-y-z (* 3 warpSize) 0))))
  )

;; Kernel sketch for degree 64 using registers
;; Same data layout as mult64. Lane indices (?sw-xform-easy), register indices
;; and accumulation predicates (?cond) are holes, and every (a, b) product is
;; offered to every accumulator.
(define (mult64-sketch threadId blockID blockDim A B C n)
  (define globalID (+ threadId (* blockID blockDim)))
  (define a-cached (create-matrix-local (x-y-z 2 1)))
  (define b-cached (create-matrix-local (x-y-z 2 1)))
  (global-to-local A a-cached
                   (x-y-z 1 1) ;; stride
                   (* (quotient globalID (x-y-z warpSize 1))
                      (x-y-z n 1))
                   (x-y-z n 1)
                   #f #:warp-shape (x-y-z warpSize 1) #:round (x-y-z 2 1))
  (global-to-local B b-cached
                   (x-y-z 1 1) ;; stride
                   (* (quotient globalID (x-y-z warpSize 1))
                      (x-y-z n 1))
                   (x-y-z n 1)
                   #f #:warp-shape (x-y-z warpSize 1) #:round (x-y-z 2 1))

  (define tidx (get-idInWarp threadId))
  (define acc1 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc2 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc3 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc4 (create-accumulator (list bvand bvxor) identity blockDim))

  (for ([i warpSize])
    (let* ([lane-a1 (?sw-xform-easy tidx warpSize
                                    i warpSize [])]
           [lane-a2 (?sw-xform-easy tidx warpSize
                                    i warpSize [])]
           [lane-b1 (?sw-xform-easy tidx warpSize
                                    i warpSize [])]
           [lane-b2 (?sw-xform-easy tidx warpSize
                                    i warpSize [])]
           [idx-a1 (ite (?cond tidx (@dup i)) (@dup 0) (@dup 1))]
           [idx-a2 (ite (?cond tidx (@dup i)) (@dup 0) (@dup 1))]
           [idx-b1 (ite (?cond tidx (@dup i)) (@dup 0) (@dup 1))]
           [idx-b2 (ite (?cond tidx (@dup i)) (@dup 0) (@dup 1))]
           [a1 (shfl (get a-cached idx-a1 (@dup 0)) lane-a1)]
           [a2 (shfl (get a-cached idx-a2 (@dup 0)) lane-a2)]
           [b1 (shfl (get b-cached idx-b1 (@dup 0)) lane-b1)]
           [b2 (shfl (get b-cached idx-b2 (@dup 0)) lane-b2)]
           )
      (accumulate acc1 (list a1 b1) #:pred (?cond tidx (@dup i)))
      (accumulate acc2 (list a1 b1) #:pred (?cond tidx (@dup i)))
      (accumulate acc3 (list a1 b1) #:pred (?cond tidx (@dup i)))
      (accumulate acc4 (list a1 b1) #:pred (?cond tidx (@dup i)))

      (accumulate acc1 (list a1 b2) #:pred (?cond tidx (@dup i)))
      (accumulate acc2 (list a1 b2) #:pred (?cond tidx (@dup i)))
      (accumulate acc3 (list a1 b2) #:pred (?cond tidx (@dup i)))
      (accumulate acc4 (list a1 b2) #:pred (?cond tidx (@dup i)))

      (accumulate acc1 (list a2 b1) #:pred (?cond tidx (@dup i)))
      (accumulate acc2 (list a2 b1) #:pred (?cond tidx (@dup i)))
      (accumulate acc3 (list a2 b1) #:pred (?cond tidx (@dup i)))
      (accumulate acc4 (list a2 b1) #:pred (?cond tidx (@dup i)))

      (accumulate acc1 (list a2 b2) #:pred (?cond tidx (@dup i)))
      (accumulate acc2 (list a2 b2) #:pred (?cond tidx (@dup i)))
      (accumulate acc3 (list a2 b2) #:pred (?cond tidx (@dup i)))
      (accumulate acc4 (list a2 b2) #:pred (?cond tidx (@dup i)))
      ))

  (reg-to-global acc1 C globalID)
  (reg-to-global acc2 C (+ globalID (@dup (x-y-z warpSize 0))))
  (reg-to-global acc3 C (+ globalID (@dup (x-y-z (* 2 warpSize) 0))))
  (reg-to-global acc4 C (+ globalID (@dup (x-y-z (* 3 warpSize) 0))))
  )

;; Complete kernel for degree 32 using shared memory
;; A and B are staged in shared memory and read with get at the indices
;; computed by sw-xform; acc2 is stored n elements further along x than acc1.
(define (mult32-shared threadId blockID blockDim A B C n)
  (define-shared a-cached (create-matrix (x-y-z warpSize block-dim-y)))
  (define-shared b-cached (create-matrix (x-y-z warpSize block-dim-y)))
  (define block-offset (* blockID blockDim))
  (global-to-shared A a-cached
                    (x-y-z 1 1) ;; stride
                    block-offset
                    blockDim #:size warpSize)
  (global-to-shared B b-cached
                    (x-y-z 1 1) ;; stride
                    block-offset
                    blockDim #:size warpSize)

  (define tidx (get-x threadId))
  (define tidy (get-y threadId))
  (define acc1 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc2 (create-accumulator (list bvand bvxor) identity blockDim))

  (for ([i n])
    (let* ([lane-a (sw-xform tidx warpSize 0 warpSize warpSize 1
                             i warpSize 1 warpSize)]
           [lane-b (sw-xform tidx warpSize 1 warpSize warpSize 1
                             i warpSize -1 warpSize)]
           #;[lane-a (?sw-xform tidx warpSize
                                i warpSize #:fw 1)]
           #;[lane-b (?sw-xform tidx warpSize
                                i warpSize #:fw 1)]
           [a (get a-cached lane-a tidy)]
           [b (get b-cached lane-b tidy)]
           )
      (accumulate acc1 (list a b) #:pred (<= i tidx) #;(?cond tidx (@dup i)))
      (accumulate acc2 (list a b) #:pred (> i tidx) #;(?cond tidx (@dup i)))
      ))

  (reg-to-global acc1 C (+ block-offset threadId))
  (reg-to-global acc2 C (+ block-offset (@dup (x-y-z n 0)) threadId))
  )

;; Complete kernel for degree 64 using shared memory
;; Each block stages a (2*blockDim.x) x blockDim.y tile of A and B in shared
;; memory; every thread produces four outputs, stored at x offsets 0,
;; warpSize, 2*warpSize and 3*warpSize from block-offset + threadId.
(define (mult64-shared threadId blockID blockDim A B C n)
  (define-shared a-cached (create-matrix (* (x-y-z 2 1) blockDim)))
  (define-shared b-cached (create-matrix (* (x-y-z 2 1) blockDim)))
  (define block-offset (* (x-y-z 2 1) blockID blockDim))
  (global-to-shared A a-cached
                    (x-y-z 1 1) ;; stride
                    block-offset
                    (* (x-y-z 2 1) blockDim))
  (global-to-shared B b-cached
                    (x-y-z 1 1) ;; stride
                    block-offset
                    (* (x-y-z 2 1) blockDim))

  ;(pretty-display `(a ,a-cached))

  (define tidx (modulo (get-x threadId) 32))
  (define tidy (get-y threadId))
  (define acc1 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc2 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc3 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc4 (create-accumulator (list bvand bvxor) identity blockDim))

  (for ([i warpSize])
    (let* ([lane-a1 (sw-xform tidx n 0 n n 1
                              i warpSize 1 warpSize 0)]
           [lane-a2 (sw-xform tidx n 0 n n 1
                              i warpSize 1 warpSize warpSize)]
           [lane-b1 (sw-xform tidx n 1 n n 1
                              i warpSize -1 warpSize 0)]
           [lane-b2 (sw-xform tidx n 1 n n 1
                              i warpSize -1 warpSize warpSize)]
           [a1 (get a-cached lane-a1 tidy)]
           [a2 (get a-cached lane-a2 tidy)]
           [b1 (get b-cached lane-b1 tidy)]
           [b2 (get b-cached lane-b2 tidy)]
           )
      (accumulate acc1 (list a1 b1) #:pred (<= i tidx))
      (accumulate acc3 (list a1 b1) #:pred (> i tidx))

      (accumulate acc2 (list a1 b2) #:pred #t)

      (accumulate acc2 (list a2 b1) #:pred (<= i tidx))
      (accumulate acc4 (list a2 b1) #:pred (> i tidx))

      (accumulate acc3 (list a2 b2) #:pred #t)
      ))

  (reg-to-global acc1 C (+ block-offset threadId))
  (reg-to-global acc2 C (+ block-offset (x-y-z warpSize 0) threadId))
  (reg-to-global acc3 C (+ block-offset (x-y-z (* 2 warpSize) 0) threadId))
  (reg-to-global acc4 C (+ block-offset (x-y-z (* 3 warpSize) 0) threadId))
  )

;; Kernel sketch for degree 64 using shared memory
;; Same staging as mult64-shared. Index computations (?sw-xform-easy) and
;; accumulation predicates (?cond) are holes, and every (a, b) product is
;; offered to every accumulator.
(define (mult64-shared-sketch threadId blockID blockDim A B C n)
  (define-shared a-cached (create-matrix (* (x-y-z 2 1) blockDim)))
  (define-shared b-cached (create-matrix (* (x-y-z 2 1) blockDim)))
  (define block-offset (* (x-y-z 2 1) blockID blockDim))
  (global-to-shared A a-cached
                    (x-y-z 1 1) ;; stride
                    block-offset
                    (* (x-y-z 2 1) blockDim) #:round (x-y-z 2 1) #:size (x-y-z n))
  (global-to-shared B b-cached
                    (x-y-z 1 1) ;; stride
                    block-offset
                    (* (x-y-z 2 1) blockDim) #:round (x-y-z 2 1) #:size (x-y-z n))

  ;(pretty-display `(a ,a-cached))

  (define tidx (get-idInWarp threadId))
  (define tidy (get-y threadId))
  (define acc1 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc2 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc3 (create-accumulator (list bvand bvxor) identity blockDim))
  (define acc4 (create-accumulator (list bvand bvxor) identity blockDim))

  (for ([i warpSize])
    (let* ([lane-a1 (?sw-xform-easy tidx n
                                    i warpSize [])]
           [lane-a2 (?sw-xform-easy tidx n
                                    i warpSize [])]
           [lane-b1 (?sw-xform-easy tidx n
                                    i warpSize [])]
           [lane-b2 (?sw-xform-easy tidx n
                                    i warpSize [])]
           [a1 (get a-cached lane-a1 tidy)]
           [a2 (get a-cached lane-a2 tidy)]
           [b1 (get b-cached lane-b1 tidy)]
           [b2 (get b-cached lane-b2 tidy)]
           )
      (accumulate acc1 (list a1 b1) #:pred (?cond tidx (@dup i)))
      (accumulate acc2 (list a1 b1) #:pred (?cond tidx (@dup i)))
      (accumulate acc3 (list a1 b1) #:pred (?cond tidx (@dup i)))
      (accumulate acc4 (list a1 b1) #:pred (?cond tidx (@dup i)))

      (accumulate acc1 (list a1 b2) #:pred (?cond tidx (@dup i)))
      (accumulate acc2 (list a1 b2) #:pred (?cond tidx (@dup i)))
      (accumulate acc3 (list a1 b2) #:pred (?cond tidx (@dup i)))
      (accumulate acc4 (list a1 b2) #:pred (?cond tidx (@dup i)))

      (accumulate acc1 (list a2 b1) #:pred (?cond tidx (@dup i)))
      (accumulate acc2 (list a2 b1) #:pred (?cond tidx (@dup i)))
      (accumulate acc3 (list a2 b1) #:pred (?cond tidx (@dup i)))
      (accumulate acc4 (list a2 b1) #:pred (?cond tidx (@dup i)))

      (accumulate acc1 (list a2 b2) #:pred (?cond tidx (@dup i)))
      (accumulate acc2 (list a2 b2) #:pred (?cond tidx (@dup i)))
      (accumulate acc3 (list a2 b2) #:pred (?cond tidx (@dup i)))
      (accumulate acc4 (list a2 b2) #:pred (?cond tidx (@dup i)))
      ))

  (reg-to-global acc1 C (+ block-offset threadId))
  (reg-to-global acc2 C (+ block-offset (x-y-z warpSize 0) threadId))
  (reg-to-global acc3 C (+ block-offset (x-y-z (* 2 warpSize) 0) threadId))
  (reg-to-global acc4 C (+ block-offset (x-y-z (* 3 warpSize) 0) threadId))
  )

;; Check correctness of a complete kernel against a spec.
;; Runs `kernel` (default: the complete mult32-shared) at warp size 32 and
;; prints the comparison result and the cost from get-cost. Pass a complete
;; kernel here: a sketch such as mult32-shared-sketch still contains holes, so
;; its result would not be a concrete check.
(define (test [kernel mult32-shared])
  (for ([w (list 32)])
    (let ([ret (run-with-warp-size mult-spec kernel w (* 1 w))])
      (pretty-display `(test ,w ,ret))
      (pretty-display `(cost ,(get-cost)))
      ))
  )
;(test)

;; Synthesize a kernel sketch given a spec.
;; Asserts that mult64-sketch (degree 2w) matches the spec at warp size 4, then
;; solves for hole values minimizing get-cost and prints the completed forms
;; and their cost.
(define (synthesis)
  (pretty-display "solving...")
  (assert (andmap
           ;(lambda (w) (run-with-warp-size mult-spec mult32-sketch w (* 1 w)))
           (lambda (w) (run-with-warp-size mult-spec mult64-sketch w (* 2 w)))
           (list 4)))
  (define cost (get-cost))
  (define sol (time (optimize #:minimize (list cost) #:guarantee (assert #t))))

  (define this-cost (evaluate cost sol))
  (print-forms sol)
  (pretty-display `(cost ,this-cost))
  )
;; Run synthesis when the module is loaded and report wall-clock seconds.
(define t0 (current-seconds))
(synthesis)
(define t1 (current-seconds))
(- t1 t0)

-------------------------------------------------------------------------------- /ex4-aos-sum-noacc.rkt: --------------------------------------------------------------------------------
#|
| Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
|
| Redistribution and use in source and binary forms, with or without
| modification, are permitted provided that the following conditions are met:
|
| 1. Redistributions of source code must retain the above copyright notice,
| this list of conditions and the following disclaimer.
|
| 2.
Redistributions in binary form must reproduce the above copyright notice, 11 | | this list of conditions and the following disclaimer in the documentation 12 | | and/or other materials provided with the distribution. 13 | | 14 | | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 17 | | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 18 | | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 19 | | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 20 | | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 21 | | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 22 | | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 23 | | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 24 | | POSSIBILITY OF SUCH DAMAGE. 25 | |# 26 | 27 | #lang rosette 28 | 29 | (require "util.rkt" "cuda.rkt" "cuda-synth.rkt") 30 | 31 | (define struct-size 3) 32 | (define n-block 1) 33 | 34 | (define syms #f) 35 | 36 | (define (create-IO warpSize) 37 | (set-warpSize warpSize) 38 | (define block-size (* 1 warpSize)) 39 | (define array-size (* n-block block-size)) 40 | (define I-sizes (x-y-z (* array-size struct-size))) 41 | (define O-sizes (x-y-z array-size)) 42 | (define I (create-matrix I-sizes gen-sym)) 43 | (set! 
syms (symbolics I)) 44 | (define O (create-matrix O-sizes)) 45 | (define O* (create-matrix O-sizes)) 46 | (values block-size I-sizes O-sizes I O O*)) 47 | 48 | (define (run-with-warp-size spec kernel w) 49 | (define-values (block-size I-sizes O-sizes I O O*) 50 | (create-IO w)) 51 | 52 | (define c (gcd struct-size warpSize)) 53 | (define a (/ struct-size c)) 54 | (define b (/ warpSize c)) 55 | 56 | (reset-cost) 57 | (spec I O O-sizes) 58 | (pretty-display `(spec-cost ,(get-cost))) 59 | (run-kernel kernel (x-y-z block-size) (x-y-z n-block) I O* I-sizes O-sizes a b c) 60 | ;(pretty-display O) 61 | ;(pretty-display O*) 62 | (equal? O O*) 63 | ) 64 | 65 | (define (AOS-sum-spec I O O-sizes) 66 | (for ([i (get-x O-sizes)]) 67 | (let ([o 0]) 68 | (for ([j struct-size]) 69 | (set! o (+ o (get I (+ (* i struct-size) j))))) 70 | (set O i o))) 71 | ) 72 | 73 | (define (AOS-sum3 threadId blockID blockDim I O I-sizes O-sizes a b c) 74 | (define I-cached (create-matrix-local (x-y-z struct-size))) 75 | (define gid (+ (* blockID blockDim) threadId)) 76 | (define localId (get-idInWarp threadId)) 77 | (global-to-local 78 | I 79 | I-cached 80 | (x-y-z 1) 81 | (* struct-size (- gid localId)) 82 | (x-y-z (* warpSize struct-size)) 83 | #f 84 | #:round 85 | struct-size) 86 | (define o 0) 87 | (define I-cached2 88 | (permute-vector 89 | I-cached 90 | struct-size 91 | (lambda (i) (sw-xform i struct-size 0 1 1 -1 localId warpSize -1 8 0)))) 92 | (for 93 | ((i struct-size)) 94 | (let* ((lane (sw-xform localId warpSize 2 -32 32 -1 i struct-size 16 3 10)) 95 | (x (shfl (get I-cached2 (@dup i)) lane))) 96 | (set! o (+ o x)))) 97 | (reg-to-global o O gid)) 98 | 99 | ;; Sketch that doesn't use accumulators. 
100 | (define (AOS-sum-sketch threadId blockID blockDim I O I-sizes O-sizes a b c) 101 | 102 | (define I-cached (create-matrix-local (x-y-z struct-size))) 103 | ;(define warpID (get-warpId threadId)) 104 | ;(define gid (get-global-threadId threadId blockID)) 105 | (define gid (+ (* blockID blockDim) threadId)) 106 | (define localId (get-idInWarp threadId)) 107 | (global-to-local I I-cached 108 | (x-y-z 1) ;; stride 109 | ;(+ (* struct-size blockID blockDim) (* struct-size warpID warpSize)) 110 | (* struct-size (- gid localId)) 111 | (x-y-z (* warpSize struct-size)) 112 | #f #:round struct-size) 113 | 114 | (define o 0) ;; not accumulator 115 | 116 | ;; column shuffle 117 | (define I-cached2 (permute-vector I-cached struct-size 118 | (lambda (i) (?sw-xform-easy i struct-size localId warpSize)))) 119 | 120 | ;; row shuffle 121 | (for ([i struct-size]) 122 | (let* ([lane (?sw-xform-easy localId warpSize i struct-size)] 123 | [x (shfl (get I-cached2 (@dup i)) lane)] 124 | ) 125 | (set! o (+ o x)) 126 | )) 127 | 128 | (reg-to-global o O gid) 129 | ) 130 | 131 | (define (test) 132 | (for ([w (list 32)]) 133 | (let ([ret (run-with-warp-size AOS-sum-spec AOS-sum3 w)]) 134 | (pretty-display `(test ,w ,ret ,(get-cost))))) 135 | ) 136 | ;(test) 137 | 138 | (define (synthesis) 139 | (pretty-display "solving...") 140 | (define t 141 | (andmap (lambda (w) (run-with-warp-size AOS-sum-spec AOS-sum-sketch w)) 142 | (list 32))) 143 | (define sol (time (synthesize #:forall syms 144 | #:guarantee (assert t)))) 145 | (print-forms sol) 146 | ) 147 | (define t0 (current-seconds)) 148 | (synthesis) 149 | (define t1 (current-seconds)) 150 | (- t1 t0) 151 | -------------------------------------------------------------------------------- /ex4-aos-sum.rkt: -------------------------------------------------------------------------------- 1 | #| 2 | | Copyright (c) 2018-2019, University of California, Berkeley. 3 | | Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved. 
4 | | 5 | | Redistribution and use in source and binary forms, with or without 6 | | modification, are permitted provided that the following conditions are met: 7 | | 8 | | 1. Redistributions of source code must retain the above copyright notice, 9 | | this list of conditions and the following disclaimer. 10 | | 11 | | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | | this list of conditions and the following disclaimer in the documentation 13 | | and/or other materials provided with the distribution. 14 | | 15 | | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 19 | | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 20 | | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 21 | | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 22 | | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 23 | | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 24 | | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 25 | | POSSIBILITY OF SUCH DAMAGE. 
26 | |# 27 | 28 | #lang rosette 29 | 30 | (require "util.rkt" "cuda.rkt" "cuda-synth.rkt") 31 | 32 | (define struct-size 2) 33 | (define n-block 1) 34 | 35 | (define (create-IO warpSize) 36 | (set-warpSize warpSize) 37 | (define block-size (* 1 warpSize)) 38 | (define array-size (* n-block block-size)) 39 | (define I-sizes (x-y-z (* array-size struct-size))) 40 | (define O-sizes (x-y-z array-size)) 41 | (define I (create-matrix I-sizes gen-uid)) 42 | (define O (create-matrix O-sizes)) 43 | (define O* (create-matrix O-sizes)) 44 | (values block-size I-sizes O-sizes I O O*)) 45 | 46 | (define (run-with-warp-size spec kernel w) 47 | (define-values (block-size I-sizes O-sizes I O O*) 48 | (create-IO w)) 49 | 50 | (define c (gcd struct-size warpSize)) 51 | (define a (/ struct-size c)) 52 | (define b (/ warpSize c)) 53 | 54 | (reset-cost) 55 | (spec I O O-sizes) 56 | (pretty-display `(spec-cost ,(get-cost))) 57 | (run-kernel kernel (x-y-z block-size) (x-y-z n-block) I O* I-sizes O-sizes a b c) 58 | ;(acc-print O) 59 | ;(acc-print O*) 60 | (acc-equal? 
O O*)) 61 | 62 | (define (AOS-sum-spec I O O-sizes) 63 | (for ([i (get-x O-sizes)]) 64 | (let ([o (create-accumulator (list +) identity)]) 65 | (for ([j struct-size]) 66 | (accumulate o (get I (+ (* i struct-size) j)))) 67 | (set O i o))) 68 | ) 69 | 70 | (define (AOS-sum-slow threadId blockID blockDim I O I-sizes O-sizes a b c) 71 | (define I-cached (create-matrix-local (x-y-z struct-size))) 72 | (define warpID (get-warpId threadId)) 73 | (define offset (+ (* struct-size blockID blockDim) (* struct-size warpID warpSize))) ;; warpID = (threadIdy * blockDimx + threadIdx)/warpSize 74 | (define gid (get-global-threadId threadId blockID)) 75 | (global-to-local I I-cached 76 | (x-y-z struct-size) 77 | offset (x-y-z (* warpSize struct-size)) #f) 78 | 79 | (define localId (get-idInWarp threadId)) 80 | (define o (create-accumulator (list +) identity blockDim)) 81 | (for ([i struct-size]) 82 | (let* ([index (@dup i)] 83 | [lane localId] 84 | [x (shfl (get I-cached index) lane)]) 85 | (accumulate o x) 86 | )) 87 | (reg-to-global o O gid) 88 | ) 89 | 90 | (define (AOS-sum-sketch threadId blockID blockDim I O I-sizes O-sizes a b c) 91 | 92 | (define I-cached (create-matrix-local (x-y-z struct-size))) 93 | (define gid (+ (* blockID blockDim) threadId)) 94 | (define localId (get-idInWarp threadId)) 95 | (global-to-local I I-cached 96 | (x-y-z 1) ;; stride 97 | (* struct-size (- gid localId)) 98 | (x-y-z (* warpSize struct-size)) 99 | #f #:round struct-size) 100 | 101 | (define o (create-accumulator (list +) identity blockDim)) 102 | 103 | ;; column shuffle 104 | (define I-cached2 (permute-vector I-cached struct-size 105 | (lambda (i) (?lane-mod localId (@dup i) 2 struct-size) 106 | #;(?sw-xform i struct-size localId warpSize)))) 107 | (pretty-display "finish permute-vector") 108 | 109 | ;; row shuffle 110 | (for ([i struct-size]) 111 | (let* ([lane (?sw-xform localId warpSize i struct-size)] 112 | ;[lane (?lane-mod localId (@dup i) 2 warpSize)] 113 | [x (shfl (get I-cached2 
(@dup i)) lane)] 114 | ) 115 | (accumulate o x #:pred #t) 116 | )) 117 | 118 | (reg-to-global o O gid) 119 | (pretty-display "finish kernel") 120 | ) 121 | 122 | (define (test) 123 | (for ([w (list 32)]) 124 | (let ([ret (run-with-warp-size AOS-sum-spec AOS-sum-sketch w)]) 125 | (pretty-display `(test ,w ,ret ,(get-cost))))) 126 | ) 127 | ;(test) 128 | 129 | (define (synthesis) 130 | (pretty-display "solving...") 131 | (assert 132 | (andmap (lambda (w) (run-with-warp-size AOS-sum-spec AOS-sum-sketch w)) 133 | (list 32))) 134 | (define cost (get-cost)) 135 | (define sol (time (optimize #:minimize (list cost) #:guarantee (assert #t)))) 136 | 137 | (define this-cost (evaluate cost sol)) 138 | (print-forms sol) 139 | (pretty-display `(cost ,this-cost)) 140 | 141 | ;(define sol2 (solve (assert (< cost this-cost)))) 142 | ;(pretty-display `(cost2 ,(evaluate cost sol2))) 143 | ) 144 | (define t0 (current-seconds)) 145 | (synthesis) 146 | (define t1 (current-seconds)) 147 | (- t1 t0) 148 | -------------------------------------------------------------------------------- /ex5-aos-pure-load-sol.rkt: -------------------------------------------------------------------------------- 1 | #| 2 | | Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved. 3 | | 4 | | Redistribution and use in source and binary forms, with or without 5 | | modification, are permitted provided that the following conditions are met: 6 | | 7 | | 1. Redistributions of source code must retain the above copyright notice, 8 | | this list of conditions and the following disclaimer. 9 | | 10 | | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | | this list of conditions and the following disclaimer in the documentation 12 | | and/or other materials provided with the distribution. 
13 | | 14 | | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 17 | | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 18 | | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 19 | | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 20 | | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 21 | | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 22 | | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 23 | | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 24 | | POSSIBILITY OF SUCH DAMAGE. 25 | |# 26 | 27 | #lang rosette 28 | 29 | (require "util.rkt" "cuda.rkt" "cuda-synth.rkt") 30 | 31 | (define struct-size 6) 32 | (define n-block 1) 33 | 34 | (define (create-IO warpSize) 35 | (set-warpSize warpSize) 36 | (define block-size (* 2 warpSize)) 37 | (define array-size (* n-block block-size)) 38 | (define I-sizes (x-y-z (* array-size struct-size))) 39 | (define I (create-matrix I-sizes gen-uid)) 40 | (define O (create-matrix I-sizes)) 41 | (define O* (create-matrix I-sizes)) 42 | (values block-size I-sizes I O O*)) 43 | 44 | (define (run-with-warp-size spec kernel w) 45 | (define-values (block-size I-sizes I O O*) 46 | (create-IO w)) 47 | 48 | (define c (gcd struct-size warpSize)) 49 | (define a (/ struct-size c)) 50 | (define b (/ warpSize c)) 51 | 52 | (run-kernel spec (x-y-z block-size) (x-y-z n-block) I O a b c) 53 | (run-kernel kernel (x-y-z block-size) (x-y-z n-block) I O* a b c) 54 | (define ret (equal? 
O O*)) 55 | (pretty-display `(O ,(print-vec O))) 56 | (pretty-display `(O* ,(print-vec O*))) 57 | ret) 58 | 59 | (define (AOS-load-spec threadId blockID blockDim I O a b c) 60 | (define I-cached (create-matrix-local (x-y-z struct-size))) 61 | (define warpID (get-warpId threadId)) 62 | (define offset (+ (* struct-size blockID blockDim) (* struct-size warpID warpSize))) ;; warpID = (threadIdy * blockDimx + threadIdx)/warpSize 63 | (define gid (get-global-threadId threadId blockID)) 64 | (global-to-local I I-cached 65 | (x-y-z struct-size) 66 | offset (x-y-z (* warpSize struct-size)) #f) 67 | (local-to-global I-cached O 68 | (x-y-z 1) offset (x-y-z (* warpSize struct-size)) #f #:round struct-size) 69 | ) 70 | 71 | (define (print-vec x) 72 | (format "#(~a)" (string-join (for/list ([xi x]) (format "~a" xi))))) 73 | 74 | (define (AOS-load2 threadId blockID blockDim I O a b c) 75 | (define I-cached (create-matrix-local (x-y-z struct-size))) 76 | (define O-cached (create-matrix-local (x-y-z struct-size))) 77 | (define warpID (get-warpId threadId)) 78 | (define offset 79 | (+ (* struct-size blockID blockDim) (* struct-size warpID warpSize))) 80 | (define gid (get-global-threadId threadId blockID)) 81 | (global-to-local 82 | I 83 | I-cached 84 | (x-y-z 1) 85 | offset 86 | (x-y-z (* warpSize struct-size) #:round struct-size) 87 | #f) 88 | (define indices (make-vector struct-size)) 89 | (define indices-o (make-vector struct-size)) 90 | (define localId (get-idInWarp threadId)) 91 | (for 92 | ((i struct-size)) 93 | (let* ((index (sw-xform i struct-size 0 1 2 1 localId warpSize 0 1)) 94 | (lane (sw-xform localId warpSize 2 16 32 -1 i struct-size 0 1)) 95 | (x (shfl (get I-cached index) lane)) 96 | (index-o (sw-xform i struct-size 0 1 2 1 localId warpSize 0 16))) 97 | (unique-warp (modulo lane warpSize)) 98 | (vector-set! indices i index) 99 | (vector-set! 
indices-o i index-o) 100 | (set O-cached index-o x))) 101 | (for 102 | ((t blockSize)) 103 | (let ((l 104 | (for/list ((i struct-size)) (vector-ref (vector-ref indices i) t))) 105 | (lo 106 | (for/list 107 | ((i struct-size)) 108 | (vector-ref (vector-ref indices-o i) t)))) 109 | (unique-list l) 110 | (unique-list lo))) 111 | (local-to-global 112 | O-cached 113 | O 114 | (x-y-z 1) 115 | offset 116 | (x-y-z (* warpSize struct-size)) 117 | #f #:round struct-size)) 118 | 119 | (define (AOS-load-rcr-2 threadId blockID blockDim I O a b c) 120 | (define I-cached (create-matrix-local (x-y-z struct-size))) 121 | (define temp (create-matrix-local (x-y-z struct-size))) 122 | (define O-cached (create-matrix-local (x-y-z struct-size))) 123 | (define warpID (get-warpId threadId)) 124 | (define offset 125 | (+ (* struct-size blockID blockDim) (* struct-size warpID warpSize))) 126 | (define gid (get-global-threadId threadId blockID)) 127 | (global-to-local 128 | I 129 | I-cached 130 | (x-y-z 1) 131 | offset 132 | (x-y-z (* warpSize struct-size)) 133 | #f #:round struct-size) 134 | (define indices (make-vector struct-size)) 135 | (define indices-o (make-vector struct-size)) 136 | (define localId (get-idInWarp threadId)) 137 | (for 138 | ((i struct-size)) 139 | (let* ((lane1 (sw-xform localId warpSize 0 1 2 1 i struct-size 0 1)) 140 | (x (shfl (get I-cached (@dup i)) lane1))) 141 | (set temp (@dup i) x))) 142 | (for 143 | ((i struct-size)) 144 | (let* ((index (sw-xform i struct-size 0 1 2 1 localId warpSize 0 1)) 145 | (lane2 (sw-xform localId warpSize 16 2 32 1 i struct-size 15 1)) 146 | (x (shfl-send (get temp index) lane2))) 147 | (set O-cached (@dup i) x))) 148 | (local-to-global 149 | O-cached 150 | O 151 | (x-y-z 1) 152 | offset 153 | (x-y-z (* warpSize struct-size)) 154 | #f #:round struct-size)) 155 | 156 | (define (AOS-load3 threadId blockID blockDim I O a b c) 157 | (define I-cached (create-matrix-local (x-y-z struct-size))) 158 | (define O-cached (create-matrix-local 
(x-y-z struct-size))) 159 | (define localId (modulo (get-x threadId) 32)) 160 | (define offset (* struct-size (- (+ (* blockID blockDim) (get-x threadId)) localId))) 161 | (global-to-local 162 | I 163 | I-cached 164 | (x-y-z 1) 165 | offset 166 | (x-y-z (* warpSize struct-size)) 167 | #f #:round struct-size) 168 | (for 169 | ((i struct-size)) 170 | (let* ((index (sw-xform i struct-size 2 3 3 1 localId warpSize 0 1)) 171 | (lane (sw-xform localId warpSize 3 32 32 1 i struct-size 0 1)) 172 | (x (shfl (get I-cached index) lane)) 173 | (index-o (sw-xform i struct-size 1 3 3 1 localId warpSize 0 warpSize))) 174 | (unique-warp (modulo lane warpSize)) 175 | (set O-cached index-o x))) 176 | (local-to-global 177 | O-cached 178 | O 179 | (x-y-z 1) 180 | offset 181 | (x-y-z (* warpSize struct-size)) 182 | #f #:round struct-size)) 183 | 184 | (define (AOS-loadhsh3 threadId blockID blockDim I O a b c) 185 | (define I-cached (create-matrix-local (x-y-z struct-size))) 186 | (define temp (create-matrix-local (x-y-z struct-size))) 187 | (define O-cached (create-matrix-local (x-y-z struct-size))) 188 | (define warpID (get-warpId threadId)) 189 | (define offset 190 | (+ (* struct-size blockID blockDim) (* struct-size warpID warpSize))) 191 | (define gid (get-global-threadId threadId blockID)) 192 | (global-to-local 193 | I 194 | I-cached 195 | (x-y-z 1) 196 | offset 197 | (x-y-z (* warpSize struct-size)) 198 | #f #:round struct-size) 199 | (define indices (make-vector struct-size)) 200 | (define indices-o (make-vector struct-size)) 201 | (define localId (get-idInWarp threadId)) 202 | (for 203 | ((i struct-size)) 204 | (let* ((lane1 (sw-xform localId warpSize 0 1 32 1 i struct-size 31 1)) 205 | (x (shfl (get I-cached (@dup i)) lane1))) 206 | (set temp (@dup i) x))) 207 | (for 208 | ((i struct-size)) 209 | (let* ((index (sw-xform i struct-size 2 3 3 1 localId warpSize 0 1)) 210 | (lane2 (sw-xform localId warpSize 11 32 32 1 i struct-size 20 1)) 211 | (x (shfl-send (get temp index) 
lane2))) 212 | (set O-cached (@dup i) x))) 213 | (local-to-global 214 | O-cached 215 | O 216 | (x-y-z 1) 217 | offset 218 | (x-y-z (* warpSize struct-size)) 219 | #f #:round struct-size)) 220 | 221 | (define (AOS-loadhsh3* threadId blockID blockDim I O a b c) 222 | (define I-cached (create-matrix-local (x-y-z struct-size))) 223 | (define warpID (get-warpId threadId)) 224 | (define offset 225 | (+ (* struct-size blockID blockDim) (* struct-size warpID warpSize))) 226 | (define gid (get-global-threadId threadId blockID)) 227 | (global-to-local 228 | I 229 | I-cached 230 | (x-y-z 1) 231 | offset 232 | (x-y-z (* warpSize struct-size)) 233 | #f #:round struct-size 234 | #:shfl (lambda (localId i) (sw-xform localId warpSize 0 1 32 1 i struct-size 31 1))) 235 | (define localId (get-idInWarp threadId)) 236 | (define O-cached (permute-vector I-cached struct-size 237 | (lambda (i) (sw-xform i struct-size 2 3 3 1 localId warpSize 0 1)))) 238 | (local-to-global 239 | O-cached 240 | O 241 | (x-y-z 1) 242 | offset 243 | (x-y-z (* warpSize struct-size)) 244 | #f #:round struct-size 245 | #:shfl (lambda (localId i) 246 | (sw-xform localId warpSize 11 32 32 1 i struct-size 20 1))) 247 | ) 248 | 249 | (define (AOS-load4 threadId blockID blockDim I O a b c) 250 | (define I-cached (create-matrix-local (x-y-z struct-size))) 251 | (define O-cached (create-matrix-local (x-y-z struct-size))) 252 | (define warpID (get-warpId threadId)) 253 | (define offset 254 | (+ (* struct-size blockID blockDim) (* struct-size warpID warpSize))) 255 | (define gid (get-global-threadId threadId blockID)) 256 | (global-to-local 257 | I 258 | I-cached 259 | (x-y-z 1) 260 | offset 261 | (x-y-z (* warpSize struct-size)) 262 | #f #:round struct-size) 263 | (define indices (make-vector struct-size)) 264 | (define indices-o (make-vector struct-size)) 265 | (define localId (get-idInWarp threadId)) 266 | (for 267 | ((i struct-size)) 268 | (let* ((index (sw-xform i struct-size 3 4 4 1 localId warpSize 0 1)) 269 | 
(lane (sw-xform localId warpSize 4 8 32 -1 270 | i struct-size 0 1)) 271 | (x (shfl (get I-cached index) lane)) 272 | (index-o (sw-xform i struct-size 0 1 4 1 localId warpSize 0 8))) 273 | (pretty-display `(lane ,lane)) 274 | (unique-warp (modulo lane warpSize)) 275 | (vector-set! indices i index) 276 | (vector-set! indices-o i index-o) 277 | (set O-cached index-o x))) 278 | (for 279 | ((t blockSize)) 280 | (let ((l 281 | (for/list ((i struct-size)) (vector-ref (vector-ref indices i) t))) 282 | (lo 283 | (for/list 284 | ((i struct-size)) 285 | (vector-ref (vector-ref indices-o i) t)))) 286 | (unique-list l) 287 | (unique-list lo))) 288 | (local-to-global 289 | O-cached 290 | O 291 | (x-y-z 1) 292 | offset 293 | (x-y-z (* warpSize struct-size)) 294 | #f #:round struct-size)) 295 | 296 | (define (AOS-load-rcr-4 threadId blockID blockDim I O a b c) 297 | (define I-cached (create-matrix-local (x-y-z struct-size))) 298 | (define temp (create-matrix-local (x-y-z struct-size))) 299 | (define O-cached (create-matrix-local (x-y-z struct-size))) 300 | (define warpID (get-warpId threadId)) 301 | (define offset 302 | (+ (* struct-size blockID blockDim) (* struct-size warpID warpSize))) 303 | (define gid (get-global-threadId threadId blockID)) 304 | (global-to-local 305 | I 306 | I-cached 307 | (x-y-z 1) 308 | offset 309 | (x-y-z (* warpSize struct-size)) 310 | #f #:round struct-size) 311 | (define indices (make-vector struct-size)) 312 | (define indices-o (make-vector struct-size)) 313 | (define localId (get-idInWarp threadId)) 314 | (for 315 | ((i struct-size)) 316 | (let* ((lane1 (sw-xform localId warpSize 0 1 4 1 i struct-size 0 1)) 317 | (x (shfl (get I-cached (@dup i)) lane1))) 318 | (set temp (@dup i) x))) 319 | (for 320 | ((i struct-size)) 321 | (let* ((index (sw-xform i struct-size 0 1 4 1 localId warpSize 0 -1)) 322 | (lane2 (sw-xform localId warpSize 24 4 32 1 i struct-size 7 1)) 323 | (x (shfl-send (get temp index) lane2))) 324 | (set O-cached (@dup i) x))) 325 | 
(local-to-global 326 | O-cached 327 | O 328 | (x-y-z 1) 329 | offset 330 | (x-y-z (* warpSize struct-size)) 331 | #f #:round struct-size)) 332 | 333 | (define (AOS-load5 threadId blockID blockDim I O a b c) 334 | (define I-cached (create-matrix-local (x-y-z struct-size))) 335 | (define O-cached (create-matrix-local (x-y-z struct-size))) 336 | (define warpID (get-warpId threadId)) 337 | (define offset 338 | (+ (* struct-size blockID blockDim) (* struct-size warpID warpSize))) 339 | (define gid (get-global-threadId threadId blockID)) 340 | (global-to-local 341 | I 342 | I-cached 343 | (x-y-z 1) 344 | offset 345 | (x-y-z (* warpSize struct-size)) 346 | #f #:round struct-size) 347 | (define indices (make-vector struct-size)) 348 | (define indices-o (make-vector struct-size)) 349 | (define localId (get-idInWarp threadId)) 350 | (for 351 | ((i struct-size)) 352 | (let* ((index (sw-xform i struct-size 3 5 5 1 localId warpSize 1 1)) 353 | (lane (sw-xform localId warpSize 5 32 32 1 i struct-size 0 1)) 354 | (x (shfl (get I-cached index) lane)) 355 | (index-o (sw-xform i struct-size 1 5 5 -1 localId warpSize 0 1))) 356 | (unique-warp (modulo lane warpSize)) 357 | (vector-set! indices i index) 358 | (vector-set! 
indices-o i index-o) 359 | (set O-cached index-o x))) 360 | (for 361 | ((t blockSize)) 362 | (let ((l 363 | (for/list ((i struct-size)) (vector-ref (vector-ref indices i) t))) 364 | (lo 365 | (for/list 366 | ((i struct-size)) 367 | (vector-ref (vector-ref indices-o i) t)))) 368 | (unique-list l) 369 | (unique-list lo))) 370 | (local-to-global 371 | O-cached 372 | O 373 | (x-y-z 1) 374 | offset 375 | (x-y-z (* warpSize struct-size)) 376 | #f #:round struct-size)) 377 | 378 | (define (AOS-load-rcr-5 threadId blockID blockDim I O a b c) 379 | (define I-cached (create-matrix-local (x-y-z struct-size))) 380 | (define temp (create-matrix-local (x-y-z struct-size))) 381 | (define O-cached (create-matrix-local (x-y-z struct-size))) 382 | (define warpID (get-warpId threadId)) 383 | (define offset 384 | (+ (* struct-size blockID blockDim) (* struct-size warpID warpSize))) 385 | (define gid (get-global-threadId threadId blockID)) 386 | (global-to-local 387 | I 388 | I-cached 389 | (x-y-z 1) 390 | offset 391 | (x-y-z (* warpSize struct-size)) 392 | #f #:round struct-size) 393 | (define indices (make-vector struct-size)) 394 | (define indices-o (make-vector struct-size)) 395 | (define localId (get-idInWarp threadId)) 396 | (for 397 | ((i struct-size)) 398 | (let* ((x (get I-cached (@dup i)))) 399 | (set temp (@dup i) x))) 400 | (for 401 | ((i struct-size)) 402 | (let* ((index (sw-xform i 5 3 5 5 1 403 | localId warpSize 2 warpSize)) 404 | #;(index 405 | (modulo (+ (* 3 i) (* localId 2)) 5)) 406 | (lane2 (sw-xform localId 32 13 32 32 1 407 | i 5 19 5)) 408 | #;(lane2 409 | (modulo 410 | (- (* 13 localId) (* 13 i)) 411 | 32)) 412 | (x (shfl-send (get temp index) lane2))) 413 | ;(pretty-display `(lane ,(print-vec (modulo lane2 32)))) 414 | (set O-cached (@dup i) x))) 415 | (local-to-global 416 | O-cached 417 | O 418 | (x-y-z 1) 419 | offset 420 | (x-y-z (* warpSize struct-size)) 421 | #f #:round struct-size)) 422 | 423 | (define (AOS-load6 threadId blockID blockDim I O a b c) 424 
| (define I-cached (create-matrix-local (x-y-z struct-size))) 425 | (define O-cached (create-matrix-local (x-y-z struct-size))) 426 | 427 | (define localId (modulo (get-x threadId) 32)) 428 | (define offset (* struct-size (- (+ (* blockID blockDim) (get-x threadId)) localId))) 429 | 430 | (global-to-local I I-cached 431 | (x-y-z 1) 432 | offset 433 | (x-y-z (* warpSize struct-size)) #f #:round struct-size) 434 | 435 | ;; column shuffle 436 | (define I-cached2 (permute-vector I-cached struct-size 437 | (lambda (i) 438 | #;(+ (modulo (quotient (+ (modulo (- i localId) struct-size) 1) 2) 3) 439 | (* 3 (modulo (- i localId) 2))) 440 | (sw-xform i struct-size 3 2 struct-size 1 441 | localId warpSize 0 warpSize #;offset 3 442 | #:ecr 5 #:ec 1) 443 | ))) 444 | 445 | ;; row shuffle 446 | (for ([i struct-size]) 447 | (let* ([lane 448 | #;(modulo 449 | (+ (* 6 localId) (modulo (+ i (quotient localId 16)) 6)) 450 | warpSize) 451 | (sw-xform localId warpSize 6 16 warpSize -1 452 | i struct-size 1 struct-size #;offset 0 453 | #:gcd 6)] 454 | [x (shfl (get I-cached2 (@dup i)) lane)] 455 | ) 456 | (set O-cached (@dup i) x)) 457 | ) 458 | 459 | ;; column shuffle 460 | (define O-cached2 (permute-vector O-cached struct-size 461 | (lambda (i) 462 | #;(modulo (- i (quotient localId 16)) struct-size) 463 | (sw-xform i struct-size 1 struct-size struct-size 1 464 | localId warpSize 0 -16 #;offset 0)))) 465 | 466 | (local-to-global O-cached2 O 467 | (x-y-z 1) 468 | offset 469 | (x-y-z (* warpSize struct-size)) #f #:round struct-size) 470 | ) 471 | 472 | (define (test) 473 | (for ([w (list 32)]) 474 | (let ([ret (run-with-warp-size AOS-load-spec AOS-load6 w)]) 475 | (pretty-display `(test ,w ,ret)))) 476 | ) 477 | (test) 478 | 479 | -------------------------------------------------------------------------------- /ex5-aos-pure-load.rkt: -------------------------------------------------------------------------------- 1 | #| 2 | | Copyright (c) 2018-2019, University of California, Berkeley. 
3 | | Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved. 4 | | 5 | | Redistribution and use in source and binary forms, with or without 6 | | modification, are permitted provided that the following conditions are met: 7 | | 8 | | 1. Redistributions of source code must retain the above copyright notice, 9 | | this list of conditions and the following disclaimer. 10 | | 11 | | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | | this list of conditions and the following disclaimer in the documentation 13 | | and/or other materials provided with the distribution. 14 | | 15 | | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 19 | | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 20 | | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 21 | | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 22 | | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 23 | | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 24 | | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 25 | | POSSIBILITY OF SUCH DAMAGE. 
|#

#lang rosette

(require "util.rkt" "cuda.rkt" "cuda-synth.rkt")

(define struct-size 3)
(define n-block 1)

;; Build one problem instance for the given warp width.
;; Calls set-warpSize first, then returns five values:
;;   block-size  -- threads per block (one warp)
;;   I-sizes     -- size list of the flat AoS array (n-block * block-size * struct-size)
;;   I           -- input matrix whose leaves come from gen-uid
;;   O, O*       -- zero-initialised outputs for the spec and the candidate kernel
(define (create-IO warpSize)
  (set-warpSize warpSize)
  (let* ([block-size (* 1 warpSize)]
         [total-elems (* (* n-block block-size) struct-size)]
         [I-sizes (x-y-z total-elems)]
         [input (create-matrix I-sizes gen-uid)]
         [spec-out (create-matrix I-sizes)]
         [kernel-out (create-matrix I-sizes)])
    (values block-size I-sizes input spec-out kernel-out)))

;; Run `spec` and `kernel` on the same input for warp width w and return
;; whether their outputs are equal?.  The extra kernel arguments are
;; c = gcd(struct-size, warpSize), a = struct-size / c, b = warpSize / c.
(define (run-with-warp-size spec kernel w)
  (let*-values ([(block-size I-sizes I O O*) (create-IO w)]
                ;; warpSize is read only after create-IO (which calls set-warpSize).
                [(c) (gcd struct-size warpSize)]
                [(a) (/ struct-size c)]
                [(b) (/ warpSize c)])
    (let ([threads (x-y-z block-size)]
          [blocks (x-y-z n-block)])
      (run-kernel spec threads blocks I O a b c)
      (run-kernel kernel threads blocks I O* a b c)
      ;(pretty-display `(O ,O))
      ;(pretty-display `(O* ,O*))
      (equal? O O*))))

;; Specification kernel: load this warp's chunk of I into the per-thread
;; buffer with one access pattern (stride argument (x-y-z struct-size)) and
;; store it to O with another ((x-y-z 1), #:round struct-size).  Candidate
;; kernels must leave O identical to what this produces.
(define (AOS-load-spec threadId blockID blockDim I O a b c)
  (define I-cached (create-matrix-local (x-y-z struct-size)))
  ;; warpID = (threadIdy * blockDimx + threadIdx)/warpSize
  (define warpID (get-warpId threadId))
  (define offset (+ (* struct-size blockID blockDim) (* struct-size warpID warpSize)))
  ;; gid is computed but not referenced below.
  (define gid (get-global-threadId threadId blockID))
  (define chunk (* warpSize struct-size))
  (global-to-local I I-cached (x-y-z struct-size) offset (x-y-z chunk) #f)
  (local-to-global I-cached O (x-y-z 1) offset (x-y-z chunk) #f #:round struct-size))

;; Render the elements of sequence x as "#(e0 e1 ...)".
(define (print-vec x)
  (define parts (for/list ([elem x]) (format "~a" elem)))
  (format "#(~a)" (string-join parts)))

;; Sketch that uses column-row-column shuffles.
;; Column-row-column sketch for the AoS load/store problem.  The ?sw-xform
;; holes are filled in by the synthesizer:
;;   load -> per-thread permutation (column shuffle) -> cross-lane shfl
;;   (row shuffle) -> per-thread permutation (column shuffle) -> store.
(define (AOS-load-sketch threadId blockID blockDim I O a b c)
  (define I-cached (create-matrix-local (x-y-z struct-size)))
  (define O-cached (create-matrix-local (x-y-z struct-size)))

  ;; Lane index within the warp.  Uses warpSize instead of a hard-coded 32 so
  ;; the sketch stays consistent with whatever width run-with-warp-size sets;
  ;; the chunk size and every ?sw-xform hole below are already warpSize-based.
  (define localId (modulo (get-x threadId) warpSize))
  ;; (global x thread index - lane index) * struct-size
  (define offset (* struct-size (- (+ (* blockID blockDim) (get-x threadId)) localId)))

  (global-to-local I I-cached
                   (x-y-z 1)
                   offset
                   (x-y-z (* warpSize struct-size)) #f #:round struct-size)

  ;; column shuffle: permute each thread's elements by a hole over (i, localId)
  (define I-cached2 (permute-vector I-cached struct-size
                                    (lambda (i) (?sw-xform i struct-size localId warpSize))))

  ;; row shuffle: element i of each thread is obtained via shfl from `lane`
  (for ([i struct-size])
    (let* ([lane (?sw-xform localId warpSize i struct-size)]
           [x (shfl (get I-cached2 (@dup i)) lane)])
      (set O-cached (@dup i) x)))

  ;; column shuffle before the store
  (define O-cached2 (permute-vector O-cached struct-size
                                    (lambda (i) (?sw-xform i struct-size localId warpSize))))

  (local-to-global O-cached2 O
                   (x-y-z 1)
                   offset
                   (x-y-z (* warpSize struct-size)) #f #:round struct-size))

;; Sketch that uses row-column-row shuffles.
;; Row-column-row sketch: the load and the store each carry a cross-lane
;; shuffle hole (#:shfl), with a per-thread permutation hole in between.
(define (AOS-load-rcr-sketch threadId blockID blockDim I O a b c)
  (define I-cached (create-matrix-local (x-y-z struct-size)))

  ;; Lane index within the warp; warpSize (not a literal 32) keeps the sketch
  ;; consistent with the width chosen by run-with-warp-size.
  (define localId (modulo (get-x threadId) warpSize))
  ;; (global x thread index - lane index) * struct-size
  (define offset (* struct-size (- (+ (* blockID blockDim) (get-x threadId)) localId)))

  ;; load with (row) shuffle
  (global-to-local
   I
   I-cached
   (x-y-z 1)
   offset
   (x-y-z (* warpSize struct-size))
   #f #:round struct-size
   ;; this lambda's localId parameter is supplied by global-to-local and
   ;; shadows the outer localId
   #:shfl (lambda (localId i) (?sw-xform localId warpSize i struct-size)))

  ;; column shuffle
  (define O-cached (permute-vector I-cached struct-size
                                   (lambda (i) (?sw-xform i struct-size localId warpSize))))

  ;; store with (row) shuffle
  (local-to-global
   O-cached
   O
   (x-y-z 1)
   offset
   (x-y-z (* warpSize struct-size))
   #f #:round struct-size
   #:shfl (lambda (localId i)
            (?sw-xform localId warpSize i struct-size))))

;; Compare AOS-load-sketch against AOS-load-spec for each listed warp width
;; and print `(test ,w ,ret)`.
(define (test)
  (for ([w (list 32)])
    (let ([ret (run-with-warp-size AOS-load-spec AOS-load-sketch w)])
      (pretty-display `(test ,w ,ret)))))
;(test)

;; Fill the holes of AOS-load-sketch: assert that it matches AOS-load-spec for
;; every listed warp width, then ask the solver for a solution minimizing
;; (get-cost).  Prints the completed forms and the cost, or a message when no
;; solution exists (print-forms and evaluate need a satisfiable solution).
(define (synthesis)
  (pretty-display "solving...")
  (assert (andmap (lambda (w) (run-with-warp-size AOS-load-spec AOS-load-sketch w))
                  (list 32)))
  (define cost (get-cost))

  (define sol (time (optimize #:minimize (list cost) #:guarantee (assert #t))))

  (if (sat? sol)
      (let ([this-cost (evaluate cost sol)])
        (print-forms sol)
        (pretty-display `(cost ,this-cost)))
      (pretty-display "no solution found")))

;; Run synthesis once; the last module-level expression's value is the
;; elapsed wall-clock time in seconds.
(define t0 (current-seconds))
(synthesis)
(define t1 (current-seconds))
(- t1 t0)

--------------------------------------------------------------------------------
/util.rkt:
--------------------------------------------------------------------------------
#|
 | Copyright (c) 2018-2019, University of California, Berkeley.
 | Copyright (c) 2018-2019, NVIDIA CORPORATION.
All rights reserved.
 |
 | Redistribution and use in source and binary forms, with or without
 | modification, are permitted provided that the following conditions are met:
 |
 | 1. Redistributions of source code must retain the above copyright notice,
 |    this list of conditions and the following disclaimer.
 |
 | 2. Redistributions in binary form must reproduce the above copyright notice,
 |    this list of conditions and the following disclaimer in the documentation
 |    and/or other materials provided with the distribution.
 |
 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 | POSSIBILITY OF SUCH DAMAGE.
|#

#lang rosette

(provide (all-defined-out))

;; Integer bit width handed to Rosette's current-bitwidth parameter.
(define BW 10)
(current-bitwidth BW)

;; Build a nested-vector matrix.
;; dims: x y z -- sizes listed innermost-first; the LAST dimension becomes the
;; outermost vector.  Every leaf is produced by calling (init) (default: 0).
(define (create-matrix dims [init (lambda () 0)])
  (define (f dims)
    (if (empty? dims)
        (init)
        (let ([vec (make-vector (car dims))])
          (for ([i (car dims)])
            (vector-set! vec i (f (cdr dims))))
          vec)))
  (f (reverse dims)))

;; Recover the dims list (innermost-first, as accepted by create-matrix) by
;; following the first element at each nesting level.  An empty vector
;; contributes a 0-length dimension instead of raising on (vector-ref M 0).
(define (get-dims M)
  (define (f M)
    (cond
      [(not (vector? M)) (list)]
      [(zero? (vector-length M)) (list 0)]
      [else (cons (vector-length M) (f (vector-ref M 0)))]))
  (reverse (f M)))

;; Indexed read with three modes:
;;  * index is a vector of the same length as vec, and vec holds vectors:
;;    per-element read, result[k] = vec[k][index[k]];
;;  * index is any other vector: gather, result[k] = vec[index[k]];
;;  * otherwise: plain (vector-ref vec index).
;; for*/all splits symbolic unions of vec/index before reading.
(define (my-vector-ref vec index)
  (cond
    [(and (vector? index)
          (= (vector-length vec) (vector-length index))
          (vector? (vector-ref vec 0)))
     (for*/all ([my-vec vec] [my-index index])
       (for/vector ([vec-i my-vec] [index-i my-index]) (vector-ref vec-i index-i)))]

    [(vector? index)
     (for*/all ([my-vec vec] [my-index index])
       (for/vector ([index-i my-index]) (vector-ref my-vec index-i)))]

    [else
     (for*/all ([my-vec vec] [my-index index])
       (vector-ref my-vec my-index))]))

;; Indexed write.  When index is a vector, sets vec[k][index[k]] for each k,
;; taking val[k] when val is a vector and val itself otherwise; when both index
;; and val are vectors, a Rosette assertion requires (vector-length vec) to
;; equal (vector-length val).  A non-vector index is a plain vector-set!.
;; NOTE(review): the two vector-index branches run identical code; consider
;; merging them once callers that rely on the first test are ruled out.
(define (my-vector-set! vec index val)
  (when (and (vector? index) (vector? val))
    (assert (= (vector-length vec) (vector-length val)) `(= (vector-length vec) (vector-length val))))
  (cond
    [(and (vector? index)
          (= (vector-length vec) (vector-length index))
          (vector? (vector-ref vec 0)))
     (if (vector? val)
         (for/vector ([vec-i vec] [index-i index] [val-i val]) (vector-set! vec-i index-i val-i))
         (for/vector ([vec-i vec] [index-i index]) (vector-set! vec-i index-i val)))]

    [(vector? index)
     (if (vector? val)
         (for/vector ([vec-i vec] [index-i index] [val-i val]) (vector-set! vec-i index-i val-i))
         (for/vector ([vec-i vec] [index-i index]) (vector-set! vec-i index-i val)))]

    [else
     (vector-set! vec index val)]))

;; Read M at the index list l, given innermost-first (x y z).
(define (get* M l)
  (define (f M l)
    (if (empty? l)
        M
        (f (vector-ref M (car l)) (cdr l))))
  (f M (reverse l)))

;; Write v into M at the index list l, given innermost-first (x y z).
;; l must be non-empty.
(define (set* M l v)
  (define (f M l)
    (if (= (length l) 1)
        (vector-set! M (car l) v)
        (f (vector-ref M (car l)) (cdr l))))
  (f M (reverse l)))

;; (get M i ... j): indices written innermost-first; the last index selects
;; the outermost vector, matching create-matrix's layout.
(define-syntax get
  (syntax-rules ()
    ((get M i)
     (my-vector-ref M i))
    ((get M i ... j)
     (get (my-vector-ref M j) i ...))))

;; (set M i ... j v): same index order as `get`; the final step uses
;; my-vector-set!.
(define-syntax set
  (syntax-rules ()
    ((set M i v) (my-vector-set! M i v))
    ((set M i ... j v) (set (my-vector-ref M j) i ... v))))

;; list-ref that maps element-wise over a vector of lists.
(define (my-list-ref l index)
  (if (vector? l)
      (for/vector ([x l]) (my-list-ref x index))
      (list-ref l index)))

;; Component accessors for coordinate lists built with x-y-z.
(define-syntax-rule (get-x l) (my-list-ref l 0))
(define-syntax-rule (get-y l) (my-list-ref l 1))
(define-syntax-rule (get-z l) (my-list-ref l 2))
;; Build a 1-, 2- or 3-component coordinate list.
(define-syntax x-y-z
  (syntax-rules ()
    ((x-y-z x) (list x))
    ((x-y-z x y) (list x y))
    ((x-y-z x y z) (list x y z))))

;; Per-component tid + bid * dim.
(define (global-threadID threadID blockID blockDIM)
  (map (lambda (tid bid dim) (+ tid (* bid dim))) threadID blockID blockDIM))


;; Deep-copy nested vectors; non-vector leaves are returned as-is.
(define (clone x)
  (cond
    [(vector? x) (for/vector ([xi x]) (clone xi))]
    [else x]))

;; Transpose a list of vectors into a vector of lists; the result length is
;; taken from the first vector, so x must be non-empty.
(define (lov2vol x)
  (define vec-len (vector-length (car x)))
  (define list-len (length x))

  (for/vector ([vi vec-len])
    (for/list ([li list-len])
      (vector-ref (list-ref x li) vi))))

;; Render the elements of sequence x as "#(e0 e1 ...)".
(define (print-vec x)
  (format "#(~a)" (string-join (for/list ([xi x]) (format "~a" xi)))))
--------------------------------------------------------------------------------