├── LICENSE
├── README.md
├── bmm.py
├── colab
│   ├── test_bmm.ipynb
│   ├── test_mbmm.ipynb
│   ├── test_minbmm.ipynb
│   └── test_topkbmm.ipynb
├── custom_kernel.py
├── imgs
│   ├── bmm_1.png
│   ├── bmm_2.png
│   ├── bmm_A[1,N,N] B[1,N,N].png
│   ├── bmm_A[1,N,N] B[1,N,N]_inverted.png
│   ├── bmm_A[128,N,128] B[128,128,N].png
│   ├── bmm_A[128,N,128] B[128,128,N]_inverted.png
│   ├── bmm_A[128,N,N] B[128,N,N].png
│   ├── mask1.png
│   ├── mask2.png
│   ├── mbmm_1.png
│   ├── mbmm_2.png
│   ├── mbmm_3.png
│   ├── mbmm_4.png
│   ├── mbmm_A[128,1024,X] B[128,X,1024].png
│   ├── mbmm_A[128,1024,X] B[128,X,1024]_inverted.png
│   ├── mbmm_A[128,X,512] B[128,512,X].png
│   ├── mbmm_A[128,X,64] B[128,64,X].png
│   ├── mbmm_A[128,X,64] B[128,64,X]_inverted.png
│   ├── mbmm_A[X,1024,128] B[X,128,1024].png
│   ├── mbmm_A[X,1024,128] B[X,128,1024]_inverted.png
│   ├── mbmm_A[X,1024,64] B[X,64,1024].png
│   ├── mbmm_A[X,512,64] B[X,64,512].png
│   ├── min_bmm_A[1,N,128] B[1,128,N].png
│   ├── min_bmm_A[1,N,16] B[1,16,N].png
│   ├── min_bmm_A[1,N,16] B[1,16,N]_inverted.png
│   ├── min_bmm_A[1,N,16] B[1,16,N]_memory.png
│   ├── min_bmm_A[1,N,16] B[1,16,N]_memory_inverted.png
│   ├── min_bmm_A[1,N,256] B[1,256,N].png
│   ├── min_bmm_A[1,N,256] B[1,256,N]_inverted.png
│   ├── min_bmm_A[1,N,256] B[1,256,N]_memory.png
│   ├── min_bmm_A[1,N,256] B[1,256,N]_memory_inverted.png
│   ├── min_bmm_A[1,N,32] B[1,32,N].png
│   ├── min_bmm_A[1,N,64] B[1,64,N].png
│   ├── min_bmm_A[1,N,64] B[1,64,N]_inverted.png
│   ├── min_bmm_A[1,N,64] B[1,64,N]_memory.png
│   ├── min_bmm_A[1,N,64] B[1,64,N]_memory_inverted.png
│   ├── min_bmm_A[64,N,256] B[64,256,N].png
│   ├── min_bmm_A[64,N,64] B[64,64,N].png
│   ├── min_bmm_A[64,N,64] B[64,64,N]_inverted.png
│   ├── min_bmm_A[64,N,64] B[64,64,N]_memory.png
│   ├── min_bmm_A[64,N,64] B[64,64,N]_memory_inverted.png
│   ├── topk_bmm_A[1,N,1024] B[1,1024,1024]_loglog.png
│   ├── topk_bmm_A[1,N,1024] B[1,1024,1024]_memory_semilogx.png
│   ├── topk_bmm_A[1,N,1024] B[1,1024,1024]_semilogx.png
│   ├── topk_bmm_A[1,N,1024] B[1,1024,1024]_semilogx_inverted.png
│   ├── topk_bmm_A[1,N,256] B[1,256,1024]_loglog.png
│   ├── topk_bmm_A[1,N,256] B[1,256,1024]_memory_semilogx.png
│   ├── topk_bmm_A[1,N,256] B[1,256,1024]_memory_semilogx_inverted.png
│   ├── topk_bmm_A[1,N,256] B[1,256,1024]_semilogx.png
│   ├── topk_bmm_A[1,N,256] B[1,256,1024]_semilogx_inverted.png
│   ├── topk_bmm_A[1,N,64] B[1,64,1024]_loglog.png
│   ├── topk_bmm_A[1,N,64] B[1,64,1024]_memory_semilogx.png
│   ├── topk_bmm_A[1,N,64] B[1,64,1024]_memory_semilogx_inverted.png
│   ├── topk_bmm_A[1,N,64] B[1,64,1024]_semilogx.png
│   └── topk_bmm_A[1,N,64] B[1,64,1024]_semilogx_inverted.png
├── kernels
│   ├── bitonic_sort.cu
│   ├── bmm.cu
│   ├── bmm_helpers.cu
│   ├── mbmm.cu
│   ├── minbmm.cu
│   └── topkbmm.cu
├── mbmm.py
├── minbmm.py
└── topkbmm.py

/LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users.
We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". 
"Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. 
For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 
204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 
268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. 
But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 
387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. 
You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. 
"Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 
564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 
628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year> <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail. 651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/philosophy/why-not-lgpl.html>. 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Custom Matmul Kernels 2 | This repository contains source code for [this blog post](https://demoriarty.github.io/BMM-1/).
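3 | 
4 | ## Example
5 | 
6 | A minimal usage sketch, assuming a working CUDA setup and the repository root as the working directory (`bmm.py` loads its CUDA sources from `kernels/` by relative path):
7 | 
8 | ```python
9 | import torch
10 | from bmm import BMMCUDA
11 | 
12 | bmm = BMMCUDA()  # builds the bmm_nn / bmm_nt / bmm_tn / bmm_tt RawKernels
13 | A = torch.randn(8, 512, 256, device="cuda")  # [l, m, k]
14 | B = torch.randn(8, 256, 512, device="cuda")  # [l, k, n]
15 | C = bmm(A, B)  # [l, m, n]
16 | assert torch.allclose(C, torch.bmm(A, B), rtol=1e-3, atol=1e-3)
17 | 
18 | # transposed views are detected from their strides and dispatched to the
19 | # "nt" kernel, so no copy is made:
20 | B_t = torch.randn(8, 512, 256, device="cuda").transpose(1, 2)  # [8, 256, 512]
21 | assert torch.allclose(bmm(A, B_t), torch.bmm(A, B_t), rtol=1e-3, atol=1e-3)
22 | ```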
23 | 24 | ## Dependencies 25 | - Python 3.7.10 or higher 26 | - CuPy 7.4.0 or higher 27 | - PyTorch 1.8.1 or higher 28 | - Only tested with CUDA 11.2 -------------------------------------------------------------------------------- /bmm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cupy as cp 3 | import numpy as np 4 | import math 5 | from custom_kernel import CustomKernel 6 | 7 | class BMMCUDA(CustomKernel): 8 | def __init__(self, patch_m=4, patch_n=4): 9 | super(BMMCUDA, self).__init__() 10 | self.patch_m = patch_m 11 | self.patch_n = patch_n 12 | 13 | with open("kernels/bmm_helpers.cu", "r") as f: 14 | helpers = f.read() 15 | 16 | with open("kernels/bmm.cu", "r") as f: 17 | self.kernel = helpers + f.read() 18 | 19 | self.kernel = (self.kernel 20 | .replace("_PM_", str(self.patch_m)) 21 | .replace("_PN_", str(self.patch_n)) 22 | .replace("__DISTANCE_FN__", "madd") 23 | ) 24 | 25 | self._fn_tt = cp.RawKernel( 26 | code=self.kernel, 27 | name="bmm_tt", 28 | backend='nvcc', 29 | options=('--maxrregcount=128', '--use_fast_math') 30 | ) 31 | self._fn_nn = cp.RawKernel( 32 | code=self.kernel, 33 | name="bmm_nn", 34 | backend='nvcc', 35 | options=( 36 | '--maxrregcount=128', 37 | '--use_fast_math', 38 | #'-Xptxas', 39 | #'-dlcm=cg', 40 | ) 41 | ) 42 | # print(self._fn_nn.attributes) 43 | self._fn_tn = cp.RawKernel( 44 | code=self.kernel, 45 | name="bmm_tn", 46 | backend='nvcc', 47 | options=('--maxrregcount=128', '--use_fast_math') 48 | ) 49 | self._fn_nt = cp.RawKernel( 50 | code=self.kernel, 51 | name="bmm_nt", 52 | backend='nvcc', 53 | options=('--maxrregcount=128', '--use_fast_math') 54 | ) 55 | 56 | def get_mode(self, A, B): 57 | mode = [None, None] 58 | if A.stride()[-1] == 1: 59 | mode[0] = "n" 60 | elif A.stride()[-2] == 1: 61 | mode[0] = "t" 62 | if B.stride()[-1] == 1: 63 | mode[1] = "n" 64 | elif B.stride()[-2] == 1: 65 | mode[1] = "t" 66 | return "".join(mode) 67 | 68 | def __call__(self, A, B): 69 | """ 70 | Computes C = A @ B; the kernel variant (nn/nt/tn/tt) is chosen from the strides of A and B 71 | A: torch.Tensor, shape : [l, m, k], contiguous or a transposed view 72 | B: torch.Tensor, shape : [l, k, n], contiguous or a transposed view 73 | returns C: torch.Tensor, shape : [l, m, n] 74 | """ 75 | assert len(A.shape) == len(B.shape) 76 | # A = A.contiguous() 77 | # B = B.contiguous() 78 | if len(A.shape) == 2 and len(B.shape) == 2: 79 | A = A[None] 80 | B = B[None] 81 | two_dimentional = True 82 | elif len(A.shape) == 3 and len(B.shape) == 3: 83 | two_dimentional = False 84 | else: 85 | raise ValueError("A and B need to be 2d or 3d") 86 | assert A.shape[0] == B.shape[0] 87 | assert A.shape[2] == B.shape[1] 88 | assert A.dtype == B.dtype 89 | assert A.dtype in [torch.float, torch.half] 90 | assert A.device.type == B.device.type == "cuda" 91 | 92 | mode = self.get_mode(A, B) 93 | 94 | if mode == "nn": 95 | kernel_fn = self._fn_nn 96 | elif mode == "tt": 97 | kernel_fn = self._fn_tt 98 | elif mode == "tn": 99 | kernel_fn = self._fn_tn 100 | elif mode == "nt": 101 | kernel_fn = self._fn_nt 102 | 103 | l, m, k = A.shape 104 | l, k, n = B.shape 105 | 106 | C = torch.zeros([l, m, n], device=A.device, dtype=A.dtype) 107 | 108 | threads_per_block = (256,) 109 | #blocks_per_grid = (math.ceil(n/128), math.ceil(m/128), l) 110 | 111 | n_ = math.ceil(n/(128*self.patch_n)) 112 | m_ = math.ceil(m/(128*self.patch_m)) 113 | blocks_per_grid = (self.patch_n*self.patch_m, n_ * m_, l) 114 | 115 | kernel_fn( 116 | grid=blocks_per_grid, 117 | block=threads_per_block, 118 | args=[ 119 | A.data_ptr(), 120 | B.data_ptr(), 121 | C.data_ptr(),
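# the scalar sizes m, n, k follow the three raw device pointers;
# blocks_per_grid groups neighboring 128x128 output tiles into patch_m x patch_n
# clusters (one 256-thread block per tile) to improve L1 cache hit rate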
122 | m, n, k, 123 | ], 124 | stream=self.stream 125 | ) 126 | 127 | if two_dimentional: 128 | C = C[0] 129 | return C 130 | -------------------------------------------------------------------------------- /colab/test_mbmm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "test_mbmm.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "code", 22 | "metadata": { 23 | "id": "GU4BcFUE5E0J" 24 | }, 25 | "source": [ 26 | "#!pip install --upgrade cupy-cuda112==8.5.0" 27 | ], 28 | "execution_count": 1, 29 | "outputs": [] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "metadata": { 34 | "id": "75DR78cKpM3m" 35 | }, 36 | "source": [ 37 | "import numpy as np\n", 38 | "import torch\n", 39 | "import torch.nn as nn\n", 40 | "import torch.nn.functional as F\n", 41 | "import matplotlib.pyplot as plt\n", 42 | "import cupy as cp\n", 43 | "import math\n", 44 | "from time import time" 45 | ], 46 | "execution_count": 2, 47 | "outputs": [] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "metadata": { 52 | "id": "12phuKJIFY9A", 53 | "cellView": "form" 54 | }, 55 | "source": [ 56 | "#@title CustomKernel\n", 57 | "import cupy as cp\n", 58 | "import torch\n", 59 | "\n", 60 | "@cp.util.memoize(for_each_device=True)\n", 61 | "def cunnex(func_name, func_body):\n", 62 | " return cp.cuda.compile_with_cache(func_body).get_function(func_name)\n", 63 | "\n", 64 | "class Stream:\n", 65 | " def __init__(self, ptr):\n", 66 | " self.ptr = ptr\n", 67 | " \n", 68 | "class CustomKernel:\n", 69 | " def __init__(self):\n", 70 | " self._use_torch_in_cupy_malloc()\n", 71 | " self.stream = Stream(torch.cuda.current_stream().cuda_stream)\n", 72 | "\n", 73 | " @staticmethod\n", 74 | " def _torch_alloc(size):\n", 75 | " device = cp.cuda.Device().id\n", 76 | " tensor = torch.empty(size, dtype=torch.uint8, device=device)\n", 77 | " return cp.cuda.MemoryPointer(\n", 78 | " cp.cuda.UnownedMemory(tensor.data_ptr(), size, tensor), 0)\n", 79 | "\n", 80 | " def _use_torch_in_cupy_malloc(self):\n", 81 | " cp.cuda.set_allocator(self._torch_alloc)\n", 82 | "\n", 83 | " def _compile_kernel_str(\n", 84 | " self,\n", 85 | " kernel,\n", 86 | " name,\n", 87 | " options=(),\n", 88 | " backend=\"nvrtc\",\n", 89 | " max_dynamic_smem=None\n", 90 | " ):\n", 91 | " fn = cp.RawKernel(\n", 92 | " kernel,\n", 93 | " name,\n", 94 | " options=options,\n", 95 | " backend=backend,\n", 96 | " )\n", 97 | " if max_dynamic_smem:\n", 98 | " fn.max_dynamic_shared_size_bytes = max_dynamic_smem\n", 99 | " return fn" 100 | ], 101 | "execution_count": 3, 102 | "outputs": [] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "metadata": { 107 | "id": "IDuqX_JrO-50", 108 | "cellView": "form" 109 | }, 110 | "source": [ 111 | "#@title BMMv2.5 Kernel\n", 112 | "kernel = \"\"\"\n", 113 | "#define _VOLATILE_ \n", 114 | "\n", 115 | "#define likely(x) __builtin_expect(!!(x), 1)\n", 116 | "#define unlikely(x) __builtin_expect(!!(x), 0)\n", 117 | "#define load(x) __ldcg(x)\n", 118 | "#define store(x, value) __stcs(x, value)\n", 119 | "\n", 120 | "typedef long long ll_t;\n", 121 | "typedef unsigned long long ull_t;\n", 122 | "\n", 123 | "typedef struct __builtin_align__(32) {\n", 124 | " float s0, s1, s2, s3, s4, s5, s6, s7;\n", 125 | "} 
_float8;\n", 126 | "\n", 127 | "typedef union {\n", 128 | " _float8 f8;\n", 129 | " float val[8];\n", 130 | "} float8;\n", 131 | "\n", 132 | "__device__ void init_cCache(\n", 133 | " float8 cCache[8]\n", 134 | ") {\n", 135 | " #pragma unroll\n", 136 | " for (int i=0; i<8; i++){\n", 137 | " #pragma unroll\n", 138 | " for (int j=0; j<8; j++){\n", 139 | " cCache[i].val[j] = 0.f;\n", 140 | " }\n", 141 | " }\n", 142 | "}\n", 143 | "\n", 144 | "__device__ void thread_matmul_v4(\n", 145 | " _VOLATILE_ float aSM[8][128+4],\n", 146 | " _VOLATILE_ float bSM[8][128+4],\n", 147 | " float8 cCache[8],\n", 148 | " int vx, int vy\n", 149 | ") {\n", 150 | " float aCache1[8];\n", 151 | " float aCache2[8];\n", 152 | " #pragma unroll\n", 153 | " for (int mi=0; mi<8; mi++){\n", 154 | " aCache1[mi] = aSM[0][8*vy + mi];\n", 155 | " }\n", 156 | "\n", 157 | " #pragma unroll\n", 158 | " for (int ki=0; ki<8; ki++){\n", 159 | " int is_odd = ki & 1;\n", 160 | " if (is_odd == 0){\n", 161 | " if (likely(ki < 7)){\n", 162 | " #pragma unroll\n", 163 | " for (int mi=0; mi<8; mi++){\n", 164 | " aCache2[mi] = aSM[ki+1][8*vy + mi];\n", 165 | " }\n", 166 | " }\n", 167 | " #pragma unroll\n", 168 | " for (int ni=0; ni<8; ni++){\n", 169 | " float b = bSM[ki][vx/4 + 8*vx + ni];\n", 170 | " #pragma unroll\n", 171 | " for (int mi=0; mi<8; mi++){\n", 172 | " float a = aCache1[mi];\n", 173 | " cCache[mi].val[ni] = fmaf(a, b, cCache[mi].val[ni]);\n", 174 | " }\n", 175 | " }\n", 176 | " } else {\n", 177 | " if (likely(ki < 7)){\n", 178 | " #pragma unroll\n", 179 | " for (int mi=0; mi<8; mi++){\n", 180 | " aCache1[mi] = aSM[ki+1][8*vy + mi];\n", 181 | " }\n", 182 | " }\n", 183 | " #pragma unroll\n", 184 | " for (int ni=0; ni<8; ni++){\n", 185 | " float b = bSM[ki][vx/4 + 8*vx + ni];\n", 186 | " #pragma unroll\n", 187 | " for (int mi=0; mi<8; mi++){\n", 188 | " float a = aCache2[mi];\n", 189 | " cCache[mi].val[ni] = fmaf(a, b, cCache[mi].val[ni]);\n", 190 | " }\n", 191 | " }\n", 192 | " }\n", 193 | " }\n", 194 | "}\n", 195 | "\n", 196 | "__device__ void thread_matmul_v3(\n", 197 | " _VOLATILE_ float aSM[8][128+4],\n", 198 | " _VOLATILE_ float bSM[8][128+4],\n", 199 | " float8 cCache[8],\n", 200 | " int vx, int vy\n", 201 | ") {\n", 202 | " float aCache[8];\n", 203 | "\n", 204 | " #pragma unroll\n", 205 | " for (int ki=0; ki<8; ki++){\n", 206 | " #pragma unroll\n", 207 | " for (int mi=0; mi<8; mi++){\n", 208 | " aCache[mi] = aSM[ki][8*vy + mi];\n", 209 | " }\n", 210 | " #pragma unroll\n", 211 | " for (int ni=0; ni<8; ni++){\n", 212 | " float b = bSM[ki][vx/4 + 8*vx + ni];\n", 213 | " #pragma unroll\n", 214 | " for (int mi=0; mi<8; mi++){\n", 215 | " float a = aCache[mi];\n", 216 | " cCache[mi].val[ni] = fmaf(a, b, cCache[mi].val[ni]);\n", 217 | " }\n", 218 | " }\n", 219 | " }\n", 220 | "}\n", 221 | "\n", 222 | "// Unsafe\n", 223 | "__device__ void write_c(\n", 224 | " float8 cCache[8],\n", 225 | " float* C,\n", 226 | " int gStartx, int gStarty,\n", 227 | " int vx, int vy, int bid,\n", 228 | " int M, int N\n", 229 | ") {\n", 230 | " #pragma unroll\n", 231 | " for (int i=0; i<8; i++){\n", 232 | " int iM = gStarty + vy*8 + i;\n", 233 | " if (likely(iM < M)){\n", 234 | " int iN_start = gStartx + vx*8;\n", 235 | " reinterpret_cast(C + (bid)*M*N + (iM)*N + (iN_start))[0] = cCache[i];\n", 236 | " /*\n", 237 | " if (likely(iN_start + 7 < N)){\n", 238 | " reinterpret_cast(C + (bid)*M*N + (iM)*N + (iN_start))[0] = cCache[i];\n", 239 | " } else {\n", 240 | " #pragma unroll\n", 241 | " for (int j=0; j<8; j++){\n", 242 | " int iN = iN_start + j;\n", 
243 | " if (iN < N){\n", 244 | " C[(bid)*M*N + (iM)*N + (iN)] = cCache[i].val[j];\n", 245 | " }\n", 246 | " }\n", 247 | " }\n", 248 | " */\n", 249 | " }\n", 250 | " }\n", 251 | "}\n", 252 | "\n", 253 | "__device__ void write_c_v3(\n", 254 | " float8 cCache[8],\n", 255 | " float* C,\n", 256 | " int gStartx, int gStarty,\n", 257 | " int vx, int vy, int bid,\n", 258 | " int M, int N\n", 259 | ") {\n", 260 | " __shared__ volatile float cSM[16][128];\n", 261 | " #pragma unroll\n", 262 | " for (int mi=0; mi<8; mi++){\n", 263 | " int iM = gStarty + vy*8 + mi;\n", 264 | " // Store 1 row from cCache to cSM\n", 265 | " if (iM < M){\n", 266 | " #pragma unroll\n", 267 | " for (int ni=0; ni<8; ni++){\n", 268 | " cSM[vy][vx*8 + ni] = cCache[mi].val[ni];\n", 269 | " }\n", 270 | " // Store to C\n", 271 | " #pragma unroll\n", 272 | " for (int ni=0; ni<8; ni++){\n", 273 | " int iN = gStartx + 16*ni + vx;\n", 274 | " if (iN < N){\n", 275 | " float cVal = cSM[vy][16*ni + vx];\n", 276 | " store(C+(bid)*M*N + (iM)*N + (iN), cVal);\n", 277 | " }\n", 278 | " }\n", 279 | " }\n", 280 | " } \n", 281 | "}\n", 282 | "\n", 283 | "extern \"C\"\n", 284 | "__global__ void bmm_tn(\n", 285 | " const float* __restrict__ A,\n", 286 | " const float* __restrict__ B,\n", 287 | " float* __restrict__ C,\n", 288 | " int M, int N, int K\n", 289 | "){\n", 290 | "}\n", 291 | "\n", 292 | "extern \"C\"\n", 293 | "__global__ void bmm_nt(\n", 294 | " const float* __restrict__ A,\n", 295 | " const float* __restrict__ B,\n", 296 | " float* __restrict__ C,\n", 297 | " int M, int N, int K\n", 298 | "){\n", 299 | "}\n", 300 | "\n", 301 | "extern \"C\"\n", 302 | "__global__ void bmm_nn(\n", 303 | " const float* __restrict__ A,\n", 304 | " const float* __restrict__ B,\n", 305 | " float* __restrict__ C,\n", 306 | " int M, int N, int K\n", 307 | "){\n", 308 | " int tid = threadIdx.x; // thread idx\n", 309 | " int bid = blockIdx.z; // batch idx\n", 310 | "\n", 311 | " // Neighboring blocks are grouped into PN x PM block groups in order to increase\n", 312 | " // L1 cache hit rate\n", 313 | " // There are ceil(M/PM) x ceil(N/PN) block groups in total.\n", 314 | " // Blocks within block groups are indexed with blockIdx.x % PN and blockIdx.x / PN\n", 315 | " int px = blockIdx.x % _PN_;\n", 316 | " int py = blockIdx.x / _PN_;\n", 317 | " int bDimX = (N + (128*_PN_) - 1) / (128*_PN_); \n", 318 | " int bDimY = (M + (128*_PM_) - 1) / (128*_PM_); \n", 319 | " int bIdxX = (blockIdx.y % bDimX) * _PN_ + px;\n", 320 | " int bIdxY = (blockIdx.y / bDimX) * _PM_ + py;\n", 321 | " int gStartx = bIdxX * 128; // starting index of block on N axis\n", 322 | " int gStarty = bIdxY * 128; // starting index of block on M axis\n", 323 | " if (gStartx > N || gStarty > M){\n", 324 | " return;\n", 325 | " }\n", 326 | " // These are used to re-arrange threads into different shapes\n", 327 | " // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8)\n", 328 | " int vx = tid % 16;\n", 329 | " int vy = tid / 16;\n", 330 | " int wx = tid % 32; // thread idx in warp\n", 331 | " int wy = tid / 32; // warp id\n", 332 | " int dx = tid % 8;\n", 333 | " int dy = tid / 8;\n", 334 | "\n", 335 | " __shared__ _VOLATILE_ float aSM1[8][128+4];\n", 336 | " __shared__ _VOLATILE_ float bSM1[8][128+4];\n", 337 | " __shared__ _VOLATILE_ float aSM2[8][128+4];\n", 338 | " __shared__ _VOLATILE_ float bSM2[8][128+4];\n", 339 | " float aBuffer1[4];\n", 340 | " float bBuffer1[4];\n", 341 | " float aBuffer2[4];\n", 342 | " float bBuffer2[4];\n", 343 | "\n", 344 | " float8 cCache[8];\n", 345 | " 
init_cCache(cCache);\n", 346 | "\n", 347 | " // Load initial 16 x 128 tile of A and B to buffer1 and buffer2\n", 348 | " #pragma unroll\n", 349 | " for (int i=0; i<4; i++){\n", 350 | " int iM = gStarty + dy + i*32;\n", 351 | " int iN = gStartx + wx + i*32;\n", 352 | " if (likely(iM < _M_)){\n", 353 | " if (likely(dx < _K_)){\n", 354 | " aBuffer1[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (dx));\n", 355 | " } else {\n", 356 | " aBuffer1[i] = 0.f;\n", 357 | " }\n", 358 | " if (likely(dx+8 < _K_)){\n", 359 | " aBuffer2[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (dx+8));\n", 360 | " } else {\n", 361 | " aBuffer2[i] = 0.f;\n", 362 | " }\n", 363 | " }\n", 364 | " if (likely(iN < N)){\n", 365 | " if (likely(wy < _K_)){\n", 366 | " bBuffer1[i] = load(B + (bid)*_N_*_K_ + (wy)*_N_ + (iN));\n", 367 | " } else {\n", 368 | " bBuffer1[i] = 0.f;\n", 369 | " }\n", 370 | " if (likely(wy+8 < _K_)){\n", 371 | " bBuffer2[i] = load(B + (bid)*_N_*_K_ + (wy+8)*_N_ + (iN));\n", 372 | " } else {\n", 373 | " bBuffer2[i] = 0.f;\n", 374 | " }\n", 375 | " }\n", 376 | " }\n", 377 | "\n", 378 | " // Number of main loop iterations is ceil(k/16)\n", 379 | " int nIt = (_K_ + 16 - 1) / 16;\n", 380 | " #pragma unroll\n", 381 | " for (int itr=0; itr A @ B\n", 593 | " \"tt\" --> A.T @ B.T\n", 594 | " \"nt\" --> A @ B.T\n", 595 | " \"tn\" --> A.T @ B\n", 596 | " \"\"\"\n", 597 | " assert len(A.shape) == len(B.shape)\n", 598 | " A = A.contiguous()\n", 599 | " B = B.contiguous()\n", 600 | " if len(A.shape) == 2 and len(B.shape) == 2:\n", 601 | " A2 = A[None]\n", 602 | " B2 = B[None]\n", 603 | " elif len(A.shape) == 3 and len(B.shape) == 3:\n", 604 | " A2 = A\n", 605 | " B2 = B\n", 606 | " else:\n", 607 | " raise ValueError(\"shape of A and B need to be 2d or 3d\")\n", 608 | "\n", 609 | " if mode == \"nn\":\n", 610 | " C = self._call_nn(A2, B2)\n", 611 | " elif mode == \"tt\":\n", 612 | " C = self._call_tt(A2, B2)\n", 613 | " elif mode == \"tn\":\n", 614 | " C = self._call_tn(A2, B2)\n", 615 | " elif mode == \"nt\":\n", 616 | " C = self._call_nt(A2, B2)\n", 617 | "\n", 618 | " if len(A.shape) == 2 and len(B.shape) == 2:\n", 619 | " C = C[0]\n", 620 | " return C" 621 | ], 622 | "execution_count": 5, 623 | "outputs": [] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "metadata": { 628 | "id": "zjdgI4ruqEve", 629 | "cellView": "form" 630 | }, 631 | "source": [ 632 | "#@title MBMMv2 Kernel\n", 633 | "kernel = \"\"\"\n", 634 | "#define _VOLATILE_ \n", 635 | "\n", 636 | "#define likely(x) __builtin_expect(!!(x), 1)\n", 637 | "#define unlikely(x) __builtin_expect(!!(x), 0)\n", 638 | "#define load(x) __ldcg(x)\n", 639 | "#define store(x, value) __stcs(x, value)\n", 640 | "\n", 641 | "typedef long long ll_t;\n", 642 | "typedef unsigned long long ull_t;\n", 643 | "typedef unsigned char uint8_t;\n", 644 | "\n", 645 | "typedef struct __builtin_align__(32) {\n", 646 | " float s0, s1, s2, s3, s4, s5, s6, s7;\n", 647 | "} _float8;\n", 648 | "\n", 649 | "typedef union {\n", 650 | " _float8 f8;\n", 651 | " float val[8];\n", 652 | "} float8;\n", 653 | "\n", 654 | "__device__ void init_cCache(\n", 655 | " float8 cCache[8]\n", 656 | ") {\n", 657 | " #pragma unroll\n", 658 | " for (int i=0; i<8; i++){\n", 659 | " #pragma unroll\n", 660 | " for (int j=0; j<8; j++){\n", 661 | " cCache[i].val[j] = 0.f;\n", 662 | " }\n", 663 | " }\n", 664 | "}\n", 665 | "\n", 666 | "__device__ void thread_matmul_v4(\n", 667 | " _VOLATILE_ float aSM[8][128+4],\n", 668 | " _VOLATILE_ float bSM[8][128+4],\n", 669 | " float8 cCache[8],\n", 670 | " int vx, int vy\n", 671 | ") 
{\n", 672 | " float aCache1[8];\n", 673 | " float aCache2[8];\n", 674 | " #pragma unroll\n", 675 | " for (int mi=0; mi<8; mi++){\n", 676 | " aCache1[mi] = aSM[0][8*vy + mi];\n", 677 | " }\n", 678 | "\n", 679 | " #pragma unroll\n", 680 | " for (int ki=0; ki<8; ki++){\n", 681 | " int is_odd = ki & 1;\n", 682 | " if (is_odd == 0){\n", 683 | " if (likely(ki < 7)){\n", 684 | " #pragma unroll\n", 685 | " for (int mi=0; mi<8; mi++){\n", 686 | " aCache2[mi] = aSM[ki+1][8*vy + mi];\n", 687 | " }\n", 688 | " }\n", 689 | " #pragma unroll\n", 690 | " for (int ni=0; ni<8; ni++){\n", 691 | " float b = bSM[ki][vx/4 + 8*vx + ni];\n", 692 | " #pragma unroll\n", 693 | " for (int mi=0; mi<8; mi++){\n", 694 | " float a = aCache1[mi];\n", 695 | " cCache[mi].val[ni] = fmaf(a, b, cCache[mi].val[ni]);\n", 696 | " }\n", 697 | " }\n", 698 | " } else {\n", 699 | " if (likely(ki < 7)){\n", 700 | " #pragma unroll\n", 701 | " for (int mi=0; mi<8; mi++){\n", 702 | " aCache1[mi] = aSM[ki+1][8*vy + mi];\n", 703 | " }\n", 704 | " }\n", 705 | " #pragma unroll\n", 706 | " for (int ni=0; ni<8; ni++){\n", 707 | " float b = bSM[ki][vx/4 + 8*vx + ni];\n", 708 | " #pragma unroll\n", 709 | " for (int mi=0; mi<8; mi++){\n", 710 | " float a = aCache2[mi];\n", 711 | " cCache[mi].val[ni] = fmaf(a, b, cCache[mi].val[ni]);\n", 712 | " }\n", 713 | " }\n", 714 | " }\n", 715 | " }\n", 716 | "}\n", 717 | "\n", 718 | "__device__ void thread_matmul_v3(\n", 719 | " _VOLATILE_ float aSM[8][128+4],\n", 720 | " _VOLATILE_ float bSM[8][128+4],\n", 721 | " float8 cCache[8],\n", 722 | " int vx, int vy\n", 723 | ") {\n", 724 | " float aCache[8];\n", 725 | "\n", 726 | " #pragma unroll\n", 727 | " for (int ki=0; ki<8; ki++){\n", 728 | " #pragma unroll\n", 729 | " for (int mi=0; mi<8; mi++){\n", 730 | " aCache[mi] = aSM[ki][8*vy + mi];\n", 731 | " }\n", 732 | " #pragma unroll\n", 733 | " for (int ni=0; ni<8; ni++){\n", 734 | " float b = bSM[ki][vx/4 + 8*vx + ni];\n", 735 | " #pragma unroll\n", 736 | " for (int mi=0; mi<8; mi++){\n", 737 | " float a = aCache[mi];\n", 738 | " cCache[mi].val[ni] = fmaf(a, b, cCache[mi].val[ni]);\n", 739 | " }\n", 740 | " }\n", 741 | " }\n", 742 | "}\n", 743 | "\n", 744 | "__device__ void mask_cCache(\n", 745 | " float8 cCache[8],\n", 746 | " const uint8_t* ElementMask,\n", 747 | " int gStartx,\n", 748 | " int gStarty,\n", 749 | " int vx, int vy, int bid,\n", 750 | " int M, int N\n", 751 | ") {\n", 752 | " #pragma unroll\n", 753 | " for (int i=0; i<8; i++){\n", 754 | " int iM = gStarty + vy*8 + i;\n", 755 | " if (likely(iM < M)){\n", 756 | " #pragma unroll\n", 757 | " for (int j=0; j<8; j++){\n", 758 | " int iN = gStartx + vx*8 + j;\n", 759 | " if (likely(iN < N)){\n", 760 | " uint8_t element_mask = ElementMask[(__MASK_BID__)*M*N + (iM)*N + (iN)];\n", 761 | " cCache[i].val[j] *= element_mask;\n", 762 | " }\n", 763 | " }\n", 764 | " }\n", 765 | " }\n", 766 | "}\n", 767 | "\n", 768 | "// Unsafe\n", 769 | "__device__ void write_c(\n", 770 | " float8 cCache[8],\n", 771 | " float* C,\n", 772 | " int gStartx, int gStarty,\n", 773 | " int vx, int vy, int bid,\n", 774 | " int M, int N\n", 775 | ") {\n", 776 | " #pragma unroll\n", 777 | " for (int i=0; i<8; i++){\n", 778 | " int iM = gStarty + vy*8 + i;\n", 779 | " if (likely(iM < M)){\n", 780 | " int iN_start = gStartx + vx*8;\n", 781 | " reinterpret_cast(C + (bid)*M*N + (iM)*N + (iN_start))[0] = cCache[i];\n", 782 | " }\n", 783 | " }\n", 784 | "}\n", 785 | "\n", 786 | "__device__ void write_c_v3(\n", 787 | " float8 cCache[8],\n", 788 | " float* C,\n", 789 | " int gStartx, int 
gStarty,\n", 790 | " int vx, int vy, int bid,\n", 791 | " int M, int N\n", 792 | ") {\n", 793 | " __shared__ volatile float cSM[16][128];\n", 794 | " #pragma unroll\n", 795 | " for (int mi=0; mi<8; mi++){\n", 796 | " int iM = gStarty + vy*8 + mi;\n", 797 | " // Store 1 row from cCache to cSM\n", 798 | " if (iM < M){\n", 799 | " #pragma unroll\n", 800 | " for (int ni=0; ni<8; ni++){\n", 801 | " cSM[vy][vx*8 + ni] = cCache[mi].val[ni];\n", 802 | " }\n", 803 | " // Store to C\n", 804 | " #pragma unroll\n", 805 | " for (int ni=0; ni<8; ni++){\n", 806 | " int iN = gStartx + 16*ni + vx;\n", 807 | " if (iN < N){\n", 808 | " float cVal = cSM[vy][16*ni + vx];\n", 809 | " //store(C+(bid)*M*N + (iM)*N + (iN), cVal);\n", 810 | " C[(bid)*M*N + (iM)*N + (iN)] = cVal;\n", 811 | " }\n", 812 | " }\n", 813 | " }\n", 814 | " } \n", 815 | "}\n", 816 | "\n", 817 | "extern \"C\"\n", 818 | "__global__ void mbmm_tn(\n", 819 | " const float* __restrict__ A,\n", 820 | " const float* __restrict__ B,\n", 821 | " float* __restrict__ C,\n", 822 | " const uint8_t* __restrict__ BlockMask,\n", 823 | " const uint8_t* __restrict__ ThreadMask,\n", 824 | " const uint8_t* __restrict__ ElementMask,\n", 825 | " int M, int N, int K\n", 826 | "){\n", 827 | "}\n", 828 | "\n", 829 | "extern \"C\"\n", 830 | "__global__ void mbmm_nt(\n", 831 | " const float* __restrict__ A,\n", 832 | " const float* __restrict__ B,\n", 833 | " float* __restrict__ C,\n", 834 | " const uint8_t* __restrict__ BlockMask,\n", 835 | " const uint8_t* __restrict__ ThreadMask,\n", 836 | " const uint8_t* __restrict__ ElementMask,\n", 837 | " int M, int N, int K\n", 838 | "){\n", 839 | "}\n", 840 | "\n", 841 | "extern \"C\"\n", 842 | "__global__ void mbmm_nn(\n", 843 | " const float* __restrict__ A,\n", 844 | " const float* __restrict__ B,\n", 845 | " float* __restrict__ C,\n", 846 | " const uint8_t* __restrict__ BlockMask,\n", 847 | " const uint8_t* __restrict__ ThreadMask,\n", 848 | " const uint8_t* __restrict__ ElementMask,\n", 849 | " int M, int N, int K\n", 850 | "){\n", 851 | " int tid = threadIdx.x; // thread idx\n", 852 | " int bid = blockIdx.z; // batch idx\n", 853 | "\n", 854 | " // Neighboring blocks are grouped into PN x PM block groups in order to increase\n", 855 | " // L1 cache hit rate\n", 856 | " // There are ceil(M/PM) x ceil(N/PN) block groups in total.\n", 857 | " // Blocks within block groups are indexed with blockIdx.x % PN and blockIdx.x / PN\n", 858 | " \n", 859 | " int px = blockIdx.x % _PN_;\n", 860 | " int py = blockIdx.x / _PN_;\n", 861 | " int bDimX = (N + (128*_PN_) - 1) / (128*_PN_); \n", 862 | " int bDimY = (M + (128*_PM_) - 1) / (128*_PM_); \n", 863 | " int bIdxX = (blockIdx.y % bDimX) * _PN_ + px;\n", 864 | " int bIdxY = (blockIdx.y / bDimX) * _PM_ + py;\n", 865 | " int gStartx = bIdxX * 128; // starting index of block on N axis\n", 866 | " int gStarty = bIdxY * 128; // starting index of block on M axis\n", 867 | " if (gStartx > N || gStarty > M){\n", 868 | " return;\n", 869 | " }\n", 870 | "\n", 871 | " // These are used to re-arrange threads into different shapes\n", 872 | " // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8)\n", 873 | " int vx = tid % 16;\n", 874 | " int vy = tid / 16;\n", 875 | " int wx = tid % 32; // thread idx in warp\n", 876 | " int wy = tid / 32; // warp id\n", 877 | " int dx = tid % 8;\n", 878 | " int dy = tid / 8;\n", 879 | "\n", 880 | " int bM = (M + 128 - 1) / 128;\n", 881 | " int bN = (N + 128 - 1) / 128;\n", 882 | " int tM = (M + 8 - 1) / 8;\n", 883 | " int tN = (N + 8 - 1) / 8;\n", 884 | " 
uint8_t block_mask = BlockMask[__MASK_BID__*bM*bN + (bIdxY)*bN + (bIdxX)];\n", 885 | " uint8_t thread_mask = ThreadMask[__MASK_BID__*tM*tN + (bIdxY*16 + vy)*tN + (bIdxX*16 + vx) ];\n", 886 | " if (block_mask == 0){\n", 887 | " return;\n", 888 | " }\n", 889 | "\n", 890 | " __shared__ _VOLATILE_ float aSM1[8][128+4];\n", 891 | " __shared__ _VOLATILE_ float bSM1[8][128+4];\n", 892 | " __shared__ _VOLATILE_ float aSM2[8][128+4];\n", 893 | " __shared__ _VOLATILE_ float bSM2[8][128+4];\n", 894 | " float aBuffer1[4];\n", 895 | " float bBuffer1[4];\n", 896 | " float aBuffer2[4];\n", 897 | " float bBuffer2[4];\n", 898 | "\n", 899 | " float8 cCache[8];\n", 900 | " init_cCache(cCache);\n", 901 | "\n", 902 | " // Load initial 16 x 128 tile of A and B to buffer1 and buffer2\n", 903 | " #pragma unroll\n", 904 | " for (int i=0; i<4; i++){\n", 905 | " int iM = gStarty + dy + i*32;\n", 906 | " int iN = gStartx + wx + i*32;\n", 907 | " if (likely(iM < _M_)){\n", 908 | " if (likely(dx < _K_)){\n", 909 | " aBuffer1[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (dx));\n", 910 | " } else {\n", 911 | " aBuffer1[i] = 0.f;\n", 912 | " }\n", 913 | " if (likely(dx+8 < _K_)){\n", 914 | " aBuffer2[i] = load(A + (bid)*_M_*_K_ + (iM)*_K_ + (dx+8));\n", 915 | " } else {\n", 916 | " aBuffer2[i] = 0.f;\n", 917 | " }\n", 918 | " }\n", 919 | " if (likely(iN < N)){\n", 920 | " if (likely(wy < _K_)){\n", 921 | " bBuffer1[i] = load(B + (bid)*_N_*_K_ + (wy)*_N_ + (iN));\n", 922 | " } else {\n", 923 | " bBuffer1[i] = 0.f;\n", 924 | " }\n", 925 | " if (likely(wy+8 < _K_)){\n", 926 | " bBuffer2[i] = load(B + (bid)*_N_*_K_ + (wy+8)*_N_ + (iN));\n", 927 | " } else {\n", 928 | " bBuffer2[i] = 0.f;\n", 929 | " }\n", 930 | " }\n", 931 | " }\n", 932 | "\n", 933 | " // Number of main loop iterations is ceil(k/16)\n", 934 | " int nIt = (_K_ + 16 - 1) / 16;\n", 935 | " #pragma unroll\n", 936 | " for (int itr=0; itr A @ B\n", 1218 | " \"tt\" --> A.T @ B.T\n", 1219 | " \"nt\" --> A @ B.T\n", 1220 | " \"tn\" --> A.T @ B\n", 1221 | " \"\"\"\n", 1222 | " assert len(A.shape) == len(B.shape)\n", 1223 | " A = A.contiguous()\n", 1224 | " B = B.contiguous()\n", 1225 | " if len(A.shape) == 2 and len(B.shape) == 2:\n", 1226 | " A2 = A[None]\n", 1227 | " B2 = B[None]\n", 1228 | " if not self.share_mask:\n", 1229 | " block_mask = block_mask[None]\n", 1230 | " thread_mask = thread_mask[None]\n", 1231 | " element_mask = element_mask[None]\n", 1232 | " elif len(A.shape) == 3 and len(B.shape) == 3:\n", 1233 | " A2 = A\n", 1234 | " B2 = B\n", 1235 | " else:\n", 1236 | " raise ValueError(\"shape of A and B need to be 2d or 3d\")\n", 1237 | "\n", 1238 | " if mode == \"nn\":\n", 1239 | " C = self._call_nn(A2, B2, block_mask, thread_mask, element_mask)\n", 1240 | " elif mode == \"tt\":\n", 1241 | " C = self._call_tt(A2, B2, block_mask, thread_mask, element_mask)\n", 1242 | " elif mode == \"tn\":\n", 1243 | " C = self._call_tn(A2, B2, block_mask, thread_mask, element_mask)\n", 1244 | " elif mode == \"nt\":\n", 1245 | " C = self._call_nt(A2, B2, block_mask, thread_mask, element_mask)\n", 1246 | "\n", 1247 | " if len(A.shape) == 2 and len(B.shape) == 2:\n", 1248 | " C = C[0]\n", 1249 | " return C" 1250 | ], 1251 | "execution_count": 7, 1252 | "outputs": [] 1253 | }, 1254 | { 1255 | "cell_type": "code", 1256 | "metadata": { 1257 | "id": "47aR82_Ygpp0", 1258 | "cellView": "form" 1259 | }, 1260 | "source": [ 1261 | "#@title test MBMMv2\n", 1262 | "def test_mbmm_v2(l, m, n, k, mode=\"nn\", n_iter=1, share_mask=True, verbose=0):\n", 1263 | " print(f\"l={l} m={m} n={n} 
k={k}\")\n", 1264 | " if mode[0] == \"n\":\n", 1265 | " A = torch.randn(l, m, k, device=\"cuda:0\")\n", 1266 | " elif mode[0] == \"t\":\n", 1267 | " A = torch.randn(l, k, m, device=\"cuda:0\")\n", 1268 | " \n", 1269 | " if mode[1] == \"n\":\n", 1270 | " B = torch.randn(l, k, n, device=\"cuda:0\")\n", 1271 | " elif mode[1] == \"t\":\n", 1272 | " B = torch.randn(l, n, k, device=\"cuda:0\")\n", 1273 | " custom_mbmm = MBMMCUDAv2(share_mask=share_mask)\n", 1274 | " custom_bmm = BMMCUDAv2_5()\n", 1275 | "\n", 1276 | " ### Create mask\n", 1277 | " # block_mask = np.random.choice(2, size=[\n", 1278 | " # l,\n", 1279 | " # math.ceil(m/128),\n", 1280 | " # math.ceil(n/128)\n", 1281 | " # ], p=[0.5, 0.5])\n", 1282 | " # thread_mask = np.random.choice(2, size=[\n", 1283 | " # l,\n", 1284 | " # math.ceil(m/8),\n", 1285 | " # math.ceil(n/8)\n", 1286 | " # ], p=[0.0, 1.0])\n", 1287 | " # block_mask = torch.tensor(block_mask).to(\"cuda\")\n", 1288 | " # thread_mask = torch.tensor(thread_mask).to(\"cuda\")\n", 1289 | " # big_block_mask = torch.repeat_interleave(block_mask, 128, dim=1)\n", 1290 | " # big_block_mask = torch.repeat_interleave(big_block_mask, 128, dim=2)\n", 1291 | " # big_thread_mask = torch.repeat_interleave(thread_mask, 8, dim=1)\n", 1292 | " # big_thread_mask = torch.repeat_interleave(big_thread_mask, 8, dim=2)\n", 1293 | " # final_mask = big_block_mask * big_thread_mask\n", 1294 | " ### mask type 2\n", 1295 | " if share_mask:\n", 1296 | " final_mask = torch.ones(m, n, device=\"cuda\")\n", 1297 | " final_mask = torch.tril(final_mask).to(\"cuda\").bool()\n", 1298 | " thread_mask = final_mask.view(math.ceil(m/8), 8, math.ceil(n/8), 8)\n", 1299 | " thread_mask = thread_mask.sum(dim=1).sum(dim=-1)\n", 1300 | " block_mask = final_mask.view(math.ceil(m/128), 128, math.ceil(n/128), 128)\n", 1301 | " block_mask = block_mask.sum(dim=1).sum(dim=-1).bool()\n", 1302 | " else:\n", 1303 | " final_mask = torch.ones(l, m, n, device=\"cuda\")\n", 1304 | " final_mask = torch.tril(final_mask).to(\"cuda\").bool()\n", 1305 | " thread_mask = final_mask.view(l, math.ceil(m/8), 8, math.ceil(n/8), 8)\n", 1306 | " thread_mask = thread_mask.sum(dim=2).sum(dim=-1)\n", 1307 | " block_mask = final_mask.view(l, math.ceil(m/128), 128, math.ceil(n/128), 128)\n", 1308 | " block_mask = block_mask.sum(dim=2).sum(dim=-1).bool()\n", 1309 | " \n", 1310 | " ### end mask\n", 1311 | " thread_mask = thread_mask.to(torch.uint8)\n", 1312 | " block_mask = block_mask.to(torch.uint8)\n", 1313 | " element_mask = final_mask.to(torch.uint8)\n", 1314 | " ###\n", 1315 | " batch_index = np.random.randint(l)\n", 1316 | " if verbose > 1:\n", 1317 | " print(\"Block Mask\")\n", 1318 | " plt.imshow(block_mask.cpu())\n", 1319 | " plt.show()\n", 1320 | " print(\"Thread Mask\")\n", 1321 | " plt.imshow(thread_mask.cpu())\n", 1322 | " plt.show()\n", 1323 | " print(\"Element Mask\")\n", 1324 | " plt.imshow(element_mask.cpu())\n", 1325 | " plt.show()\n", 1326 | "\n", 1327 | " mask = ~final_mask.bool()\n", 1328 | " del final_mask\n", 1329 | " flop = l*m*n*k*2 + l*m*n\n", 1330 | "\n", 1331 | " ### cuBLAS BMM\n", 1332 | " if mode[0] == \"t\":\n", 1333 | " A = A.transpose(1, 2)\n", 1334 | " if mode[1] == \"t\":\n", 1335 | " B = B.transpose(1, 2)\n", 1336 | " #warmup\n", 1337 | " for i in range(n_iter):\n", 1338 | " torch.bmm(A, B)\n", 1339 | " torch.cuda.synchronize()\n", 1340 | " tm = time()\n", 1341 | " for i in range(n_iter):\n", 1342 | " C = torch.bmm(A, B)\n", 1343 | " C.masked_fill_(mask = mask, value=0)\n", 1344 | " torch.cuda.synchronize()\n", 1345 | 
" time_cost_0 = (time() - tm) / n_iter\n", 1346 | " flops0 = (flop / time_cost_0) / 1000**4\n", 1347 | " if verbose > 0:\n", 1348 | " print(\"tflops:\", flops0)\n", 1349 | " print(\"time spent for torch.bmm + masked_fill:\", time_cost_0)\n", 1350 | " else:\n", 1351 | " del C\n", 1352 | " \n", 1353 | " \n", 1354 | " ### MBMMv2\n", 1355 | " for i in range(n_iter):\n", 1356 | " custom_mbmm(A, B, block_mask, thread_mask, element_mask, mode=mode)\n", 1357 | " torch.cuda.synchronize()\n", 1358 | " tm = time()\n", 1359 | " for i in range(n_iter):\n", 1360 | " C1 = custom_mbmm(A, B, block_mask, thread_mask, element_mask, mode=mode)\n", 1361 | " # C1.masked_fill_(mask = mask, value=0)\n", 1362 | " torch.cuda.synchronize()\n", 1363 | " time_cost_1 = (time() - tm) / n_iter\n", 1364 | " flops1 = (flop / time_cost_1) / 1000**4\n", 1365 | " if verbose > 0:\n", 1366 | " print(\"tflops:\", flops1)\n", 1367 | " print(\"time spent for custom_mbmmv2:\", time_cost_1)\n", 1368 | " else:\n", 1369 | " del C1\n", 1370 | "\n", 1371 | " for i in range(n_iter):\n", 1372 | " custom_bmm(A, B, mode=mode)\n", 1373 | " torch.cuda.synchronize()\n", 1374 | " tm = time()\n", 1375 | " for i in range(n_iter):\n", 1376 | " C2 = custom_bmm(A, B, mode=mode)\n", 1377 | " C2.masked_fill_(mask = mask, value=0)\n", 1378 | " torch.cuda.synchronize()\n", 1379 | " time_cost_2 = (time() - tm) / n_iter\n", 1380 | " flops2 = (flop / time_cost_2) / 1000**4\n", 1381 | " if verbose > 0:\n", 1382 | " print(\"tflops:\", flops2)\n", 1383 | " print(\"time spent for custom_bmmv2_5 + masked_fill:\", time_cost_2)\n", 1384 | " # del C2\n", 1385 | " \n", 1386 | " if verbose > 0:\n", 1387 | " dif = (C1 - C).abs()\n", 1388 | " print(\"Error:\", dif.mean())\n", 1389 | " print(\"ratio:\", time_cost_1 / time_cost_0)\n", 1390 | " \n", 1391 | " if verbose > 1:\n", 1392 | " print(\"mbmmv2\")\n", 1393 | " plt.imshow(C1[batch_index].cpu().abs())\n", 1394 | " plt.show()\n", 1395 | " print(\"torch.bmm + masked_fill\")\n", 1396 | " plt.imshow(C[batch_index].cpu().abs())\n", 1397 | " plt.show()\n", 1398 | " print(\"error map\")\n", 1399 | " plt.imshow(( dif < 1e-4)[batch_index].cpu())\n", 1400 | " plt.show()\n", 1401 | " \n", 1402 | " # return time_cost_0, time_cost_1, time_cost_2\n", 1403 | " return flops0, flops1, time_cost_2\n", 1404 | " \n", 1405 | "_ = test_mbmm_v2(1, 4096*2, 4096*2, 256,\n", 1406 | " mode=\"nn\", n_iter=10, share_mask=True,\n", 1407 | " verbose=1\n", 1408 | ")\n" 1409 | ], 1410 | "execution_count": null, 1411 | "outputs": [] 1412 | }, 1413 | { 1414 | "cell_type": "code", 1415 | "metadata": { 1416 | "id": "vXkMqJGQyPtj" 1417 | }, 1418 | "source": [ 1419 | "import os\n", 1420 | "if not os.path.exists(\"imgs\"):\n", 1421 | " os.mkdir(\"imgs\")" 1422 | ], 1423 | "execution_count": 9, 1424 | "outputs": [] 1425 | }, 1426 | { 1427 | "cell_type": "code", 1428 | "metadata": { 1429 | "id": "UsZoADEyOH29", 1430 | "cellView": "form" 1431 | }, 1432 | "source": [ 1433 | "#@title Grid test MBMMv2\n", 1434 | "ls = [i*128 for i in range(1, 3)]\n", 1435 | "ms = [512]\n", 1436 | "ns = ms\n", 1437 | "ks = [64]\n", 1438 | "mode=\"nn\"\n", 1439 | "\n", 1440 | "custom_res = dict()\n", 1441 | "cublass_res = dict()\n", 1442 | "custommf_res = dict()\n", 1443 | "for l in ls:\n", 1444 | " for m in ms:\n", 1445 | " for k in ks:\n", 1446 | " res = test_mbmm_v2(l, m, m, k, mode=mode, n_iter=50, share_mask=True)\n", 1447 | " custom_res[l] = res[1] *1e3\n", 1448 | " cublass_res[l] = res[0] *1e3\n", 1449 | " custommf_res[l] = res[2] *1e3\n", 1450 | "\n", 1451 | "\n", 1452 | 
"plt.figure(figsize=(15, 10) )\n", 1453 | "plt.tight_layout()\n", 1454 | "plt.xlabel(\"X\", fontsize=17)\n", 1455 | "plt.ylabel(\"milliseconds\", fontsize=17)\n", 1456 | "title = f\"A[X,{m},{k}] B[X,{k},{m}]\"\n", 1457 | "plt.title(title)\n", 1458 | "plt.rcParams[\"font.size\"] = \"17\"\n", 1459 | "plt.grid()\n", 1460 | "colors = [\"red\", \"blue\", \"green\"]\n", 1461 | "labels = [\"custom_mbmm\", \"torch.bmm + masked_fill\", \"custom_bmm + masked_fill\"]\n", 1462 | "for i, res in enumerate([custom_res, cublass_res, custommf_res]):\n", 1463 | " res_x = list(res.keys())\n", 1464 | " res_y = list(res.values())\n", 1465 | " plt.plot(\n", 1466 | " res_x,\n", 1467 | " res_y,\n", 1468 | " color=colors[i],\n", 1469 | " label=labels[i],\n", 1470 | " )\n", 1471 | "\n", 1472 | "plt.legend()\n", 1473 | "plt.savefig(\"imgs/mbmm_\" + title)\n", 1474 | "plt.show()\n" 1475 | ], 1476 | "execution_count": null, 1477 | "outputs": [] 1478 | } 1479 | ] 1480 | } -------------------------------------------------------------------------------- /custom_kernel.py: -------------------------------------------------------------------------------- 1 | import cupy as cp 2 | import torch 3 | 4 | @cp.memoize(for_each_device=True) 5 | def cunnex(func_name, func_body): 6 | return cp.cuda.compile_with_cache(func_body).get_function(func_name) 7 | 8 | class Stream: 9 | def __init__(self, ptr): 10 | self.ptr = ptr 11 | 12 | class CustomKernel: 13 | def __init__(self): 14 | self._use_torch_in_cupy_malloc() 15 | self.stream = Stream(torch.cuda.current_stream().cuda_stream) 16 | 17 | @staticmethod 18 | def _torch_alloc(size): 19 | device = cp.cuda.Device().id 20 | tensor = torch.empty(size, dtype=torch.uint8, device=device) 21 | return cp.cuda.MemoryPointer( 22 | cp.cuda.UnownedMemory(tensor.data_ptr(), size, tensor), 0) 23 | 24 | def _use_torch_in_cupy_malloc(self): 25 | cp.cuda.set_allocator(self._torch_alloc) 26 | 27 | def _compile_kernel_str( 28 | self, 29 | kernel, 30 | name, 31 | options=(), 32 | backend="nvrtc", 33 | max_dynamic_smem=None 34 | ): 35 | fn = cp.RawKernel( 36 | kernel, 37 | name, 38 | options=options, 39 | backend=backend, 40 | ) 41 | if max_dynamic_smem: 42 | fn.max_dynamic_shared_size_bytes = max_dynamic_smem 43 | return fn 44 | -------------------------------------------------------------------------------- /imgs/bmm_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/bmm_1.png -------------------------------------------------------------------------------- /imgs/bmm_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/bmm_2.png -------------------------------------------------------------------------------- /imgs/bmm_A[1,N,N] B[1,N,N].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/bmm_A[1,N,N] B[1,N,N].png -------------------------------------------------------------------------------- /imgs/bmm_A[1,N,N] B[1,N,N]_inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/bmm_A[1,N,N] B[1,N,N]_inverted.png 
-------------------------------------------------------------------------------- /imgs/bmm_A[128,N,128] B[128,128,N].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/bmm_A[128,N,128] B[128,128,N].png -------------------------------------------------------------------------------- /imgs/bmm_A[128,N,128] B[128,128,N]_inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/bmm_A[128,N,128] B[128,128,N]_inverted.png -------------------------------------------------------------------------------- /imgs/bmm_A[128,N,N] B[128,N,N].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/bmm_A[128,N,N] B[128,N,N].png -------------------------------------------------------------------------------- /imgs/mask1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/mask1.png -------------------------------------------------------------------------------- /imgs/mask2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/mask2.png -------------------------------------------------------------------------------- /imgs/mbmm_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/mbmm_1.png -------------------------------------------------------------------------------- /imgs/mbmm_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/mbmm_2.png -------------------------------------------------------------------------------- /imgs/mbmm_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/mbmm_3.png -------------------------------------------------------------------------------- /imgs/mbmm_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/mbmm_4.png -------------------------------------------------------------------------------- /imgs/mbmm_A[128,1024,X] B[128,X,1024].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/mbmm_A[128,1024,X] B[128,X,1024].png -------------------------------------------------------------------------------- /imgs/mbmm_A[128,1024,X] B[128,X,1024]_inverted.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/mbmm_A[128,1024,X] B[128,X,1024]_inverted.png -------------------------------------------------------------------------------- /imgs/mbmm_A[128,X,512] B[128,512,X].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/mbmm_A[128,X,512] B[128,512,X].png -------------------------------------------------------------------------------- /imgs/mbmm_A[128,X,64] B[128,64,X].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/mbmm_A[128,X,64] B[128,64,X].png -------------------------------------------------------------------------------- /imgs/mbmm_A[128,X,64] B[128,64,X]_inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/mbmm_A[128,X,64] B[128,64,X]_inverted.png -------------------------------------------------------------------------------- /imgs/mbmm_A[X,1024,128] B[X,128,1024].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/mbmm_A[X,1024,128] B[X,128,1024].png -------------------------------------------------------------------------------- /imgs/mbmm_A[X,1024,128] B[X,128,1024]_inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/mbmm_A[X,1024,128] B[X,128,1024]_inverted.png -------------------------------------------------------------------------------- /imgs/mbmm_A[X,1024,64] B[X,64,1024].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/mbmm_A[X,1024,64] B[X,64,1024].png -------------------------------------------------------------------------------- /imgs/mbmm_A[X,512,64] B[X,64,512].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/mbmm_A[X,512,64] B[X,64,512].png -------------------------------------------------------------------------------- /imgs/min_bmm_A[1,N,128] B[1,128,N].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[1,N,128] B[1,128,N].png -------------------------------------------------------------------------------- /imgs/min_bmm_A[1,N,16] B[1,16,N].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[1,N,16] B[1,16,N].png -------------------------------------------------------------------------------- /imgs/min_bmm_A[1,N,16] B[1,16,N]_inverted.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[1,N,16] B[1,16,N]_inverted.png -------------------------------------------------------------------------------- /imgs/min_bmm_A[1,N,16] B[1,16,N]_memory.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[1,N,16] B[1,16,N]_memory.png -------------------------------------------------------------------------------- /imgs/min_bmm_A[1,N,16] B[1,16,N]_memory_inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[1,N,16] B[1,16,N]_memory_inverted.png -------------------------------------------------------------------------------- /imgs/min_bmm_A[1,N,256] B[1,256,N].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[1,N,256] B[1,256,N].png -------------------------------------------------------------------------------- /imgs/min_bmm_A[1,N,256] B[1,256,N]_inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[1,N,256] B[1,256,N]_inverted.png -------------------------------------------------------------------------------- /imgs/min_bmm_A[1,N,256] B[1,256,N]_memory.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[1,N,256] B[1,256,N]_memory.png -------------------------------------------------------------------------------- /imgs/min_bmm_A[1,N,256] B[1,256,N]_memory_inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[1,N,256] B[1,256,N]_memory_inverted.png -------------------------------------------------------------------------------- /imgs/min_bmm_A[1,N,32] B[1,32,N].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[1,N,32] B[1,32,N].png -------------------------------------------------------------------------------- /imgs/min_bmm_A[1,N,64] B[1,64,N].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[1,N,64] B[1,64,N].png -------------------------------------------------------------------------------- /imgs/min_bmm_A[1,N,64] B[1,64,N]_inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[1,N,64] B[1,64,N]_inverted.png 
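Returning to custom_kernel.py above: it compiles raw CUDA source through CuPy's RawKernel/NVRTC path (wrapped by CustomKernel._compile_kernel_str) and points CuPy's allocator at PyTorch's caching allocator so both libraries draw from one memory pool. Below is a minimal end-to-end sketch of that workflow; the saxpy kernel and every name in it are illustrative, not part of this repo, and it assumes CuPy and PyTorch are built against the same CUDA runtime (cp.asarray wraps a CUDA torch tensor zero-copy through __cuda_array_interface__).

import cupy as cp
import torch

saxpy_src = r'''
extern "C" __global__
void saxpy(const float* x, const float* y, float* out, float a, int n){
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n){
        out[i] = a * x[i] + y[i];
    }
}
'''
# Same compilation pathway the repo uses for its bmm/mbmm kernels.
saxpy = cp.RawKernel(saxpy_src, "saxpy", backend="nvrtc")

n = 1 << 20
x = torch.randn(n, device="cuda")
y = torch.randn(n, device="cuda")
out = torch.empty_like(x)

# Wrap the torch storage for CuPy without copying, then launch with
# 256-thread blocks, matching the block size used throughout this repo.
grid = ((n + 255) // 256,)
block = (256,)
saxpy(grid, block, (cp.asarray(x), cp.asarray(y), cp.asarray(out),
                    cp.float32(2.0), cp.int32(n)))
torch.cuda.synchronize()
assert torch.allclose(out, 2 * x + y)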
-------------------------------------------------------------------------------- /imgs/min_bmm_A[1,N,64] B[1,64,N]_memory.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[1,N,64] B[1,64,N]_memory.png -------------------------------------------------------------------------------- /imgs/min_bmm_A[1,N,64] B[1,64,N]_memory_inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[1,N,64] B[1,64,N]_memory_inverted.png -------------------------------------------------------------------------------- /imgs/min_bmm_A[64,N,256] B[64,256,N].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[64,N,256] B[64,256,N].png -------------------------------------------------------------------------------- /imgs/min_bmm_A[64,N,64] B[64,64,N].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[64,N,64] B[64,64,N].png -------------------------------------------------------------------------------- /imgs/min_bmm_A[64,N,64] B[64,64,N]_inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[64,N,64] B[64,64,N]_inverted.png -------------------------------------------------------------------------------- /imgs/min_bmm_A[64,N,64] B[64,64,N]_memory.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[64,N,64] B[64,64,N]_memory.png -------------------------------------------------------------------------------- /imgs/min_bmm_A[64,N,64] B[64,64,N]_memory_inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/min_bmm_A[64,N,64] B[64,64,N]_memory_inverted.png -------------------------------------------------------------------------------- /imgs/topk_bmm_A[1,N,1024] B[1,1024,1024]_loglog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/topk_bmm_A[1,N,1024] B[1,1024,1024]_loglog.png -------------------------------------------------------------------------------- /imgs/topk_bmm_A[1,N,1024] B[1,1024,1024]_memory_semilogx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/topk_bmm_A[1,N,1024] B[1,1024,1024]_memory_semilogx.png -------------------------------------------------------------------------------- /imgs/topk_bmm_A[1,N,1024] B[1,1024,1024]_semilogx.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/topk_bmm_A[1,N,1024] B[1,1024,1024]_semilogx.png -------------------------------------------------------------------------------- /imgs/topk_bmm_A[1,N,1024] B[1,1024,1024]_semilogx_inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/topk_bmm_A[1,N,1024] B[1,1024,1024]_semilogx_inverted.png -------------------------------------------------------------------------------- /imgs/topk_bmm_A[1,N,256] B[1,256,1024]_loglog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/topk_bmm_A[1,N,256] B[1,256,1024]_loglog.png -------------------------------------------------------------------------------- /imgs/topk_bmm_A[1,N,256] B[1,256,1024]_memory_semilogx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/topk_bmm_A[1,N,256] B[1,256,1024]_memory_semilogx.png -------------------------------------------------------------------------------- /imgs/topk_bmm_A[1,N,256] B[1,256,1024]_memory_semilogx_inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/topk_bmm_A[1,N,256] B[1,256,1024]_memory_semilogx_inverted.png -------------------------------------------------------------------------------- /imgs/topk_bmm_A[1,N,256] B[1,256,1024]_semilogx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/topk_bmm_A[1,N,256] B[1,256,1024]_semilogx.png -------------------------------------------------------------------------------- /imgs/topk_bmm_A[1,N,256] B[1,256,1024]_semilogx_inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/topk_bmm_A[1,N,256] B[1,256,1024]_semilogx_inverted.png -------------------------------------------------------------------------------- /imgs/topk_bmm_A[1,N,64] B[1,64,1024]_loglog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/topk_bmm_A[1,N,64] B[1,64,1024]_loglog.png -------------------------------------------------------------------------------- /imgs/topk_bmm_A[1,N,64] B[1,64,1024]_memory_semilogx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/topk_bmm_A[1,N,64] B[1,64,1024]_memory_semilogx.png -------------------------------------------------------------------------------- /imgs/topk_bmm_A[1,N,64] B[1,64,1024]_memory_semilogx_inverted.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/topk_bmm_A[1,N,64] B[1,64,1024]_memory_semilogx_inverted.png -------------------------------------------------------------------------------- /imgs/topk_bmm_A[1,N,64] B[1,64,1024]_semilogx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/topk_bmm_A[1,N,64] B[1,64,1024]_semilogx.png -------------------------------------------------------------------------------- /imgs/topk_bmm_A[1,N,64] B[1,64,1024]_semilogx_inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeMoriarty/custom_matmul_kernels/79493921165444a1baebe45196412e3e75d8ae2d/imgs/topk_bmm_A[1,N,64] B[1,64,1024]_semilogx_inverted.png -------------------------------------------------------------------------------- /kernels/bitonic_sort.cu: -------------------------------------------------------------------------------- 1 | typedef long long ll_t; 2 | #define isnan(x) ( x != x ) 3 | 4 | #if (__CUDA_ARCH__ < 700) 5 | __device__ void __nanosleep(unsigned int ns){ 6 | clock_t start_clock = clock(); 7 | clock_t clock_offset = 0; 8 | while (clock_offset < ns) 9 | { 10 | clock_offset = clock() - start_clock; 11 | } 12 | } 13 | #endif 14 | 15 | /* 16 | mutex lock code from: 17 | https://stackoverflow.com/questions/18963293/cuda-atomics-change-flag/18968893#18968893 18 | */ 19 | 20 | __device__ void mutex_lock_v2( 21 | unsigned int *mutex 22 | ) { 23 | unsigned int ns = 8; 24 | __syncthreads(); 25 | if (threadIdx.x == 0){ 26 | while (atomicCAS(mutex, 0, 1) == 1) { 27 | __nanosleep(ns); 28 | if (ns < 256) { 29 | ns *= 2; 30 | } 31 | } 32 | } 33 | __syncthreads(); 34 | } 35 | 36 | __device__ void mutex_lock( 37 | unsigned int *mutex, 38 | unsigned int blockMutex[1] 39 | ) { 40 | unsigned int ns = 8; 41 | float old_value; 42 | if (threadIdx.x == 0){ 43 | old_value = atomicCAS(mutex, 0, 1); 44 | blockMutex[0] = old_value; 45 | } 46 | __syncthreads(); 47 | old_value = blockMutex[0]; 48 | while (old_value == 1) { 49 | __nanosleep(ns); 50 | if (ns < 256) { 51 | ns *= 2; 52 | } 53 | 54 | if (threadIdx.x == 0){ 55 | old_value = atomicCAS(mutex, 0, 1); 56 | blockMutex[0] = old_value; 57 | } 58 | __syncthreads(); 59 | old_value = blockMutex[0]; 60 | __syncthreads(); 61 | } 62 | } 63 | 64 | __device__ void mutex_unlock_v2(unsigned int *mutex) { 65 | __threadfence(); 66 | __syncthreads(); 67 | if (threadIdx.x == 0){ 68 | atomicExch(mutex, 0); 69 | __threadfence(); 70 | } 71 | __syncthreads(); 72 | } 73 | 74 | __device__ void mutex_unlock(unsigned int *mutex) { 75 | atomicExch(mutex, 0); 76 | } 77 | 78 | __device__ __forceinline__ unsigned int bfe( 79 | unsigned int source, 80 | unsigned int bitIndex 81 | ) { 82 | unsigned int bit; 83 | asm volatile("bfe.u32 %0, %1, %2, %3;" : "=r"(bit) : "r"((unsigned int) source), "r"(bitIndex), "r"(1)); 84 | return bit; 85 | } 86 | 87 | __device__ __forceinline__ void warpComparator( 88 | float &value, 89 | float &index, 90 | const int stride, 91 | const int direction 92 | ){ 93 | const float other_value = __shfl_xor_sync(0xFFFFFFFF, value, stride); 94 | const float other_index = __shfl_xor_sync(0xFFFFFFFF, index, stride); 95 | bool condition = value < other_value == direction; 96 | index = condition ? other_index : index; 97 | value = condition ? 
other_value : value; 98 | } 99 | 100 | __device__ __forceinline__ void blockComparator( 101 | float &value, 102 | float &index, 103 | const int stride, 104 | const int direction, 105 | const int laneID, 106 | float valSM[128], 107 | float idxSM[128] 108 | ){ 109 | valSM[laneID] = value; 110 | idxSM[laneID] = index; 111 | __syncthreads(); 112 | 113 | float other_value = valSM[laneID ^ stride]; 114 | float other_index = idxSM[laneID ^ stride]; 115 | __syncthreads(); 116 | 117 | bool condition = value < other_value == direction; 118 | index = condition ? other_index : index; 119 | value = condition ? other_value : value; 120 | } 121 | 122 | __device__ void bitonicSort256( 123 | float &value, 124 | float &index, 125 | float* values, 126 | ll_t* indices, 127 | float valSM[128], 128 | float idxSM[128], 129 | int gStartx, int Q 130 | ){ 131 | float other_value = values[threadIdx.x]; 132 | float other_index = indices[threadIdx.x] - gStartx; 133 | 134 | bool condition = value > other_value == 0; 135 | if (condition){ 136 | float temp_value = value; 137 | float temp_index = index; 138 | value = other_value; 139 | index = other_index; 140 | other_value = temp_value; 141 | other_index = temp_index; 142 | } 143 | 144 | int laneID = threadIdx.x % 128; 145 | int i = 7; 146 | for (int j = 6; j >= 0; j--){ 147 | unsigned int direction = bfe(laneID, 8) ^ bfe(laneID, j); 148 | int stride = pow(2, j); 149 | if (stride < 32){ 150 | warpComparator(value, index, stride, !direction); 151 | } else { 152 | blockComparator(value, index, stride, !direction, laneID, valSM, idxSM); 153 | } 154 | } 155 | 156 | if (threadIdx.x < Q){ 157 | values[threadIdx.x] = value; 158 | indices[threadIdx.x] = index + gStartx; 159 | } 160 | } 161 | 162 | __device__ void bitonicSort( 163 | float &value, 164 | float &index, 165 | float valSM[128], 166 | float idxSM[128] 167 | ) { 168 | unsigned int laneID = threadIdx.x % 128; 169 | for (int i=0; i < 7; i++){ 170 | for (int j=i; j >= 0; j--){ 171 | unsigned int direction = bfe(laneID, i + 1) ^ bfe(laneID, j); 172 | int stride = pow(2, j); 173 | if (stride < 32){ 174 | warpComparator(value, index, stride, direction); 175 | } else { 176 | blockComparator(value, index, stride, direction, laneID, valSM, idxSM); 177 | } 178 | } 179 | } 180 | } 181 | 182 | extern "C" 183 | __global__ void bitonic_sort( 184 | const float* __restrict__ arr, 185 | float* values, 186 | ll_t* indices, 187 | unsigned int* mutex, 188 | int L, int Q 189 | ){ 190 | int gStartx = blockIdx.x * 128; 191 | int tid = threadIdx.x; 192 | __shared__ float valSM[128]; 193 | __shared__ float idxSM[128]; 194 | 195 | float value; 196 | float index; 197 | int iL = gStartx + tid; 198 | if (iL < L){ 199 | value = arr[iL]; 200 | index = tid; 201 | } else { 202 | value = -INFINITY; 203 | } 204 | 205 | bitonicSort(value, index, valSM, idxSM); 206 | 207 | __shared__ unsigned int blockMutex[1]; 208 | mutex_lock_v2(mutex); 209 | 210 | bitonicSort256( 211 | value, index, values, indices, 212 | valSM, idxSM, gStartx, Q 213 | ); 214 | 215 | mutex_unlock_v2(mutex); 216 | } -------------------------------------------------------------------------------- /kernels/bmm.cu: -------------------------------------------------------------------------------- 1 | extern "C" 2 | __global__ void bmm_tn( 3 | const float* __restrict__ A, 4 | const float* __restrict__ B, 5 | float* __restrict__ C, 6 | int M, int N, int K 7 | ){ 8 | int tid = threadIdx.x; // thread idx 9 | int bid = blockIdx.z; // batch idx 10 | 11 | // Neighboring blocks are grouped into PN x 
PM block groups in order to increase 12 | // L2 cache hit rate 13 | // There are ceil(M/PM) x ceil(N/PN) block groups in total. 14 | // Blocks within block groups are indexed with blockIdx.x % PN and blockIdx.x / PN 15 | int px = blockIdx.x % _PN_; 16 | int py = blockIdx.x / _PN_; 17 | int bDimX = (N + (128*_PN_) - 1) / (128*_PN_); 18 | int bDimY = (M + (128*_PM_) - 1) / (128*_PM_); 19 | int bIdxX = (blockIdx.y % bDimX) * _PN_ + px; 20 | int bIdxY = (blockIdx.y / bDimX) * _PM_ + py; 21 | int gStartx = bIdxX * 128; // starting index of block on N axis 22 | int gStarty = bIdxY * 128; // starting index of block on M axis 23 | if (gStartx > N || gStarty > M){ 24 | return; 25 | } 26 | // These are used to re-arrange threads into different shapes 27 | // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8) 28 | int vx = tid % 16; 29 | int vy = tid / 16; 30 | int wx = tid % 32; // thread idx in warp 31 | int wy = tid / 32; // warp id 32 | int dx = tid % 8; 33 | int dy = tid / 8; 34 | 35 | __shared__ _VOLATILE_ float aSmem1[8][128+4]; 36 | __shared__ _VOLATILE_ float bSmem1[8][128+4]; 37 | __shared__ _VOLATILE_ float aSmem2[8][128+4]; 38 | __shared__ _VOLATILE_ float bSmem2[8][128+4]; 39 | float aBuffer1[4]; 40 | float bBuffer1[4]; 41 | float aBuffer2[4]; 42 | float bBuffer2[4]; 43 | 44 | float8 cCache[8]; 45 | init_cCache(cCache); 46 | 47 | // Load initial 16 x 128 tile of A and B to buffer1 and buffer2 48 | load_ab_tn( 49 | A, B, 50 | aBuffer1, aBuffer2, bBuffer1, bBuffer2, 51 | bid, gStartx, gStarty, 0, 52 | M, N, K 53 | ); 54 | 55 | // Number of main loop iterations is ceil(k/16) 56 | int nIt = (K + 16 - 1) / 16; 57 | #pragma unroll 58 | for (int itr=0; itr N || gStarty > M){ 116 | return; 117 | } 118 | // These are used to re-arrange threads into different shapes 119 | // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8) 120 | int vx = tid % 16; 121 | int vy = tid / 16; 122 | int wx = tid % 32; // thread idx in warp 123 | int wy = tid / 32; // warp id 124 | int dx = tid % 8; 125 | int dy = tid / 8; 126 | 127 | __shared__ _VOLATILE_ float aSmem1[8][128+4]; 128 | __shared__ _VOLATILE_ float bSmem1[8][128+4]; 129 | __shared__ _VOLATILE_ float aSmem2[8][128+4]; 130 | __shared__ _VOLATILE_ float bSmem2[8][128+4]; 131 | float aBuffer1[4]; 132 | float bBuffer1[4]; 133 | float aBuffer2[4]; 134 | float bBuffer2[4]; 135 | 136 | float8 cCache[8]; 137 | init_cCache(cCache); 138 | 139 | // Load initial 16 x 128 tile of A and B to buffer1 and buffer2 140 | load_ab_nt( 141 | A, B, 142 | aBuffer1, aBuffer2, bBuffer1, bBuffer2, 143 | bid, gStartx, gStarty, 0, 144 | M, N, K 145 | ); 146 | 147 | // Number of main loop iterations is ceil(k/16) 148 | int nIt = (K + 16 - 1) / 16; 149 | #pragma unroll 150 | for (int itr=0; itr N || gStarty > M){ 207 | return; 208 | } 209 | // These are used to re-arrange threads into different shapes 210 | // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8) 211 | int vx = tid % 16; 212 | int vy = tid / 16; 213 | int wx = tid % 32; // thread idx in warp 214 | int wy = tid / 32; // warp id 215 | int dx = tid % 8; 216 | int dy = tid / 8; 217 | 218 | __shared__ _VOLATILE_ float aSmem1[8][128+4]; 219 | __shared__ _VOLATILE_ float bSmem1[8][128+4]; 220 | __shared__ _VOLATILE_ float aSmem2[8][128+4]; 221 | __shared__ _VOLATILE_ float bSmem2[8][128+4]; 222 | float aBuffer1[4]; 223 | float bBuffer1[4]; 224 | float aBuffer2[4]; 225 | float bBuffer2[4]; 226 | 227 | float8 cCache[8]; 228 | init_cCache(cCache); 229 | 230 | // Load initial 16 x 128 tile of A and B to buffer1 and buffer2 
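// Editorial sketch (comments only, not repo code): the aBuffer/bBuffer register
// pairs and the twin shared-memory tiles declared above implement double
// buffering. Roughly:
//
//   load_ab_nn(...)                      // prefetch K-slice 0 into registers
//   for (itr = 0; itr < nIt; itr++){     // nIt = ceil(K/16)
//     commit registers to aSmem/bSmem (alternating 1 and 2), __syncthreads();
//     prefetch K-slice itr+1 into registers;
//     thread_matmul on the current smem tiles (8x8 outputs per thread);
//   }
//   write cCache back to C
//
// so global-memory latency for the next slice hides behind the FMAs of the
// current one. The _PM_ x _PN_ swizzle at the top of each kernel pursues the
// same goal at grid level: neighboring blocks reuse the same 128-wide stripes
// of A and B, raising the cache hit rate.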
231 | load_ab_nn( 232 | A, B, 233 | aBuffer1, aBuffer2, bBuffer1, bBuffer2, 234 | bid, gStartx, gStarty, 0, 235 | M, N, K 236 | ); 237 | 238 | // Number of main loop iterations is ceil(k/16) 239 | int nIt = (K + 16 - 1) / 16; 240 | #pragma unroll 241 | for (int itr=0; itr N || gStarty > M){ 299 | return; 300 | } 301 | // These are used to re-arrange threads into different shapes 302 | // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8) 303 | int vx = tid % 16; 304 | int vy = tid / 16; 305 | int wx = tid % 32; // thread idx in warp 306 | int wy = tid / 32; // warp id 307 | int dx = tid % 8; 308 | int dy = tid / 8; 309 | 310 | __shared__ _VOLATILE_ float aSmem1[8][128+4]; 311 | __shared__ _VOLATILE_ float bSmem1[8][128+4]; 312 | __shared__ _VOLATILE_ float aSmem2[8][128+4]; 313 | __shared__ _VOLATILE_ float bSmem2[8][128+4]; 314 | float aBuffer1[4]; 315 | float bBuffer1[4]; 316 | float aBuffer2[4]; 317 | float bBuffer2[4]; 318 | 319 | float8 cCache[8]; 320 | init_cCache(cCache); 321 | 322 | // Load initial 16 x 128 tile of A and B to buffer1 and buffer2 323 | load_ab_tt( 324 | A, B, 325 | aBuffer1, aBuffer2, bBuffer1, bBuffer2, 326 | bid, gStartx, gStarty, 0, 327 | M, N, K 328 | ); 329 | 330 | // Number of main loop iterations is ceil(k/16) 331 | int nIt = (K + 16 - 1) / 16; 332 | #pragma unroll 333 | for (int itr=0; itr(C + (bid)*M*N + (iM)*N + (iN_start))[0] = cCache[i]; 193 | /* 194 | if (likely(iN_start + 7 < N)){ 195 | reinterpret_cast(C + (bid)*M*N + (iM)*N + (iN_start))[0] = cCache[i]; 196 | } else { 197 | #pragma unroll 198 | for (int j=0; j<8; j++){ 199 | int iN = iN_start + j; 200 | if (iN < N){ 201 | C[(bid)*M*N + (iM)*N + (iN)] = cCache[i].val[j]; 202 | } 203 | } 204 | } 205 | */ 206 | } 207 | } 208 | } 209 | 210 | __device__ void write_c_v3( 211 | float8 cCache[8], 212 | float* C, 213 | int gStartx, int gStarty, 214 | int vx, int vy, int bid, 215 | int M, int N 216 | ) { 217 | __shared__ volatile float cSM[16][128]; 218 | #pragma unroll 219 | for (int mi=0; mi<8; mi++){ 220 | int iM = gStarty + vy*8 + mi; 221 | // Store 1 row from cCache to cSM 222 | if (iM < M){ 223 | #pragma unroll 224 | for (int ni=0; ni<8; ni++){ 225 | cSM[vy][vx*8 + ni] = cCache[mi].val[ni]; 226 | } 227 | // Store to C 228 | #pragma unroll 229 | for (int ni=0; ni<8; ni++){ 230 | int iN = gStartx + 16*ni + vx; 231 | if (iN < N){ 232 | float cVal = cSM[vy][16*ni + vx]; 233 | store(C+(bid)*M*N + (iM)*N + (iN), cVal); 234 | } 235 | } 236 | } 237 | } 238 | } 239 | 240 | __device__ void load_ab_nn( 241 | const float* A, 242 | const float* B, 243 | float aBuffer1[4], 244 | float aBuffer2[4], 245 | float bBuffer1[4], 246 | float bBuffer2[4], 247 | int bid, int gStartx, int gStarty, int gStartk, 248 | int M, int N, int K 249 | ){ 250 | int tid = threadIdx.x; 251 | int wx = tid % 32; 252 | int wy = tid / 32; 253 | int dx = tid % 8; 254 | int dy = tid / 8; 255 | int iKA = gStartk + dx; 256 | int iKB = gStartk + wy; 257 | #pragma unroll 258 | for (int i=0; i<4; i++){ 259 | int iM = gStarty + dy + i*32; 260 | int iN = gStartx + wx + i*32; 261 | if (likely(iM < M)){ 262 | if (likely(iKA < K)){ 263 | aBuffer1[i] = load(A + (bid)*M*K + (iM)*K + (iKA)); 264 | } else { 265 | aBuffer1[i] = 0.f; 266 | } 267 | if (likely(iKA+8 < K)){ 268 | aBuffer2[i] = load(A + (bid)*M*K + (iM)*K + (iKA+8)); 269 | } else { 270 | aBuffer2[i] = 0.f; 271 | } 272 | } 273 | if (likely(iN < N)){ 274 | if (likely(iKB < K)){ 275 | bBuffer1[i] = load(B + (bid)*N*K + (iKB)*N + (iN)); 276 | } else { 277 | bBuffer1[i] = 0.f; 278 | } 279 | if 
(likely(iKB+8 < K)){ 280 | bBuffer2[i] = load(B + (bid)*N*K + (iKB+8)*N + (iN)); 281 | } else { 282 | bBuffer2[i] = 0.f; 283 | } 284 | } 285 | } 286 | } 287 | 288 | __device__ void load_ab_tt( 289 | const float* A, 290 | const float* B, 291 | float aBuffer1[4], 292 | float aBuffer2[4], 293 | float bBuffer1[4], 294 | float bBuffer2[4], 295 | int bid, int gStartx, int gStarty, int gStartk, 296 | int M, int N, int K 297 | ){ 298 | int tid = threadIdx.x; 299 | int wx = tid % 32; 300 | int wy = tid / 32; 301 | int dx = tid % 8; 302 | int dy = tid / 8; 303 | int iKA = gStartk + wy; 304 | int iKB = gStartk + dx; 305 | #pragma unroll 306 | for (int i=0; i<4; i++){ 307 | int iM = gStarty + wx + i*32; 308 | int iN = gStartx + dy + i*32; 309 | if (likely(iM < M)){ 310 | if (likely(iKA < K)){ 311 | aBuffer1[i] = load(A + (bid)*M*K + (iKA)*M + (iM)); 312 | } else { 313 | aBuffer1[i] = 0.f; 314 | } 315 | if (likely(iKA+8 < K)){ 316 | aBuffer2[i] = load(A + (bid)*M*K + (iKA+8)*M + (iM)); 317 | } else { 318 | aBuffer2[i] = 0.f; 319 | } 320 | } 321 | if (likely(iN < N)){ 322 | if (likely(iKB < K)){ 323 | bBuffer1[i] = load(B + (bid)*N*K + (iN)*K + (iKB)); 324 | } else { 325 | bBuffer1[i] = 0.f; 326 | } 327 | if (likely(iKB+8 < K)){ 328 | bBuffer2[i] = load(B + (bid)*N*K + (iN)*K + (iKB+8)); 329 | } else { 330 | bBuffer2[i] = 0.f; 331 | } 332 | } 333 | } 334 | } 335 | 336 | __device__ void load_ab_nt( 337 | const float* A, 338 | const float* B, 339 | float aBuffer1[4], 340 | float aBuffer2[4], 341 | float bBuffer1[4], 342 | float bBuffer2[4], 343 | int bid, int gStartx, int gStarty, int gStartk, 344 | int M, int N, int K 345 | ){ 346 | int tid = threadIdx.x; 347 | int wx = tid % 32; 348 | int wy = tid / 32; 349 | int dx = tid % 8; 350 | int dy = tid / 8; 351 | int iKA = gStartk + dx; 352 | int iKB = gStartk + dx; 353 | #pragma unroll 354 | for (int i=0; i<4; i++){ 355 | int iM = gStarty + dy + i*32; 356 | int iN = gStartx + dy + i*32; 357 | if (likely(iM < M)){ 358 | if (likely(iKA < K)){ 359 | aBuffer1[i] = load(A + (bid)*M*K + (iM)*K + (iKA)); 360 | } else { 361 | aBuffer1[i] = 0.f; 362 | } 363 | if (likely(iKA+8 < K)){ 364 | aBuffer2[i] = load(A + (bid)*M*K + (iM)*K + (iKA+8)); 365 | } else { 366 | aBuffer2[i] = 0.f; 367 | } 368 | } 369 | if (likely(iN < N)){ 370 | if (likely(iKB < K)){ 371 | bBuffer1[i] = load(B + (bid)*N*K + (iN)*K + (iKB)); 372 | } else { 373 | bBuffer1[i] = 0.f; 374 | } 375 | if (likely(iKB+8 < K)){ 376 | bBuffer2[i] = load(B + (bid)*N*K + (iN)*K + (iKB+8)); 377 | } else { 378 | bBuffer2[i] = 0.f; 379 | } 380 | } 381 | } 382 | } 383 | 384 | __device__ void load_ab_tn( 385 | const float* A, 386 | const float* B, 387 | float aBuffer1[4], 388 | float aBuffer2[4], 389 | float bBuffer1[4], 390 | float bBuffer2[4], 391 | int bid, int gStartx, int gStarty, int gStartk, 392 | int M, int N, int K 393 | ){ 394 | int tid = threadIdx.x; 395 | int wx = tid % 32; 396 | int wy = tid / 32; 397 | int dx = tid % 8; 398 | int dy = tid / 8; 399 | int iKA = gStartk + wy; 400 | int iKB = gStartk + wy; 401 | #pragma unroll 402 | for (int i=0; i<4; i++){ 403 | int iM = gStarty + wx + i*32; 404 | int iN = gStartx + wx + i*32; 405 | if (likely(iM < M)){ 406 | if (likely(iKA < K)){ 407 | aBuffer1[i] = load(A + (bid)*M*K + (iKA)*M + (iM)); 408 | } else { 409 | aBuffer1[i] = 0.f; 410 | } 411 | if (likely(iKA+8 < K)){ 412 | aBuffer2[i] = load(A + (bid)*M*K + (iKA+8)*M + (iM)); 413 | } else { 414 | aBuffer2[i] = 0.f; 415 | } 416 | } 417 | if (likely(iN < N)){ 418 | if (likely(iKB < K)){ 419 | bBuffer1[i] = load(B 
+ (bid)*N*K + (iKB)*N + (iN)); 420 | } else { 421 | bBuffer1[i] = 0.f; 422 | } 423 | if (likely(iKB+8 < K)){ 424 | bBuffer2[i] = load(B + (bid)*N*K + (iKB+8)*N + (iN)); 425 | } else { 426 | bBuffer2[i] = 0.f; 427 | } 428 | } 429 | } 430 | } 431 | 432 | __device__ void buffer2smem_nn( 433 | _VOLATILE_ float aSM1[8][128+4], 434 | _VOLATILE_ float aSM2[8][128+4], 435 | _VOLATILE_ float bSM1[8][128+4], 436 | _VOLATILE_ float bSM2[8][128+4], 437 | float aBuffer1[4], 438 | float aBuffer2[4], 439 | float bBuffer1[4], 440 | float bBuffer2[4] 441 | ){ 442 | int tid = threadIdx.x; 443 | int wx = tid % 32; 444 | int wy = tid / 32; 445 | int dx = tid % 8; 446 | int dy = tid / 8; 447 | #pragma unroll 448 | for (int i=0; i<4; i++){ 449 | // Store buffered tiles into shared memory 450 | aSM1[dx][dy+i*32] = aBuffer1[i]; 451 | bSM1[wy][wx+i*32+i] = bBuffer1[i]; 452 | aSM2[dx][dy+i*32] = aBuffer2[i]; 453 | bSM2[wy][wx+i*32+i] = bBuffer2[i]; 454 | } 455 | } 456 | 457 | __device__ void buffer2smem_tt( 458 | _VOLATILE_ float aSM1[8][128+4], 459 | _VOLATILE_ float aSM2[8][128+4], 460 | _VOLATILE_ float bSM1[8][128+4], 461 | _VOLATILE_ float bSM2[8][128+4], 462 | float aBuffer1[4], 463 | float aBuffer2[4], 464 | float bBuffer1[4], 465 | float bBuffer2[4] 466 | ){ 467 | int tid = threadIdx.x; 468 | int wx = tid % 32; 469 | int wy = tid / 32; 470 | int dx = tid % 8; 471 | int dy = tid / 8; 472 | #pragma unroll 473 | for (int i=0; i<4; i++){ 474 | // Store buffered tiles into shared memory 475 | aSM1[wy][wx+i*32] = aBuffer1[i]; 476 | aSM2[wy][wx+i*32] = aBuffer2[i]; 477 | bSM1[dx][dy+i*32+i] = bBuffer1[i]; 478 | bSM2[dx][dy+i*32+i] = bBuffer2[i]; 479 | } 480 | } 481 | 482 | __device__ void buffer2smem_nt( 483 | _VOLATILE_ float aSM1[8][128+4], 484 | _VOLATILE_ float aSM2[8][128+4], 485 | _VOLATILE_ float bSM1[8][128+4], 486 | _VOLATILE_ float bSM2[8][128+4], 487 | float aBuffer1[4], 488 | float aBuffer2[4], 489 | float bBuffer1[4], 490 | float bBuffer2[4] 491 | ){ 492 | int tid = threadIdx.x; 493 | int wx = tid % 32; 494 | int wy = tid / 32; 495 | int dx = tid % 8; 496 | int dy = tid / 8; 497 | #pragma unroll 498 | for (int i=0; i<4; i++){ 499 | // Store buffered tiles into shared memory 500 | aSM1[dx][dy+i*32] = aBuffer1[i]; 501 | aSM2[dx][dy+i*32] = aBuffer2[i]; 502 | bSM1[dx][dy+i*32+i] = bBuffer1[i]; 503 | bSM2[dx][dy+i*32+i] = bBuffer2[i]; 504 | } 505 | } 506 | 507 | __device__ void buffer2smem_tn( 508 | _VOLATILE_ float aSM1[8][128+4], 509 | _VOLATILE_ float aSM2[8][128+4], 510 | _VOLATILE_ float bSM1[8][128+4], 511 | _VOLATILE_ float bSM2[8][128+4], 512 | float aBuffer1[4], 513 | float aBuffer2[4], 514 | float bBuffer1[4], 515 | float bBuffer2[4] 516 | ){ 517 | int tid = threadIdx.x; 518 | int wx = tid % 32; 519 | int wy = tid / 32; 520 | int dx = tid % 8; 521 | int dy = tid / 8; 522 | #pragma unroll 523 | for (int i=0; i<4; i++){ 524 | // Store buffered tiles into shared memory 525 | aSM1[wy][wx+i*32] = aBuffer1[i]; 526 | aSM2[wy][wx+i*32] = aBuffer2[i]; 527 | bSM1[wy][wx+i*32+i] = bBuffer1[i]; 528 | bSM2[wy][wx+i*32+i] = bBuffer2[i]; 529 | } 530 | } 531 | 532 | __device__ void buffer2smem_16_nn( 533 | _VOLATILE_ float aSM[16][128+4], 534 | _VOLATILE_ float bSM[16][128+4], 535 | float aBuffer1[4], 536 | float aBuffer2[4], 537 | float bBuffer1[4], 538 | float bBuffer2[4] 539 | ){ 540 | int tid = threadIdx.x; 541 | int wx = tid % 32; 542 | int wy = tid / 32; 543 | int dx = tid % 8; 544 | int dy = tid / 8; 545 | #pragma unroll 546 | for (int i=0; i<4; i++){ 547 | // Store buffered tiles into shared memory 548 
| aSM[dx][dy+i*32] = aBuffer1[i]; 549 | aSM[dx+8][dy+i*32] = aBuffer2[i]; 550 | bSM[wy][wx+i*32+i] = bBuffer1[i]; 551 | bSM[wy+8][wx+i*32+i] = bBuffer2[i]; 552 | } 553 | } 554 | 555 | __device__ void buffer2smem_16_tt( 556 | _VOLATILE_ float aSM[16][128+4], 557 | _VOLATILE_ float bSM[16][128+4], 558 | float aBuffer1[4], 559 | float aBuffer2[4], 560 | float bBuffer1[4], 561 | float bBuffer2[4] 562 | ){ 563 | int tid = threadIdx.x; 564 | int wx = tid % 32; 565 | int wy = tid / 32; 566 | int dx = tid % 8; 567 | int dy = tid / 8; 568 | #pragma unroll 569 | for (int i=0; i<4; i++){ 570 | // Store buffered tiles into shared memory 571 | aSM[wy][wx+i*32] = aBuffer1[i]; 572 | aSM[wy+8][wx+i*32] = aBuffer2[i]; 573 | bSM[dx][dy+i*32+i] = bBuffer1[i]; 574 | bSM[dx+8][dy+i*32+i] = bBuffer2[i]; 575 | } 576 | } 577 | 578 | __device__ void buffer2smem_16_nt( 579 | _VOLATILE_ float aSM[16][128+4], 580 | _VOLATILE_ float bSM[16][128+4], 581 | float aBuffer1[4], 582 | float aBuffer2[4], 583 | float bBuffer1[4], 584 | float bBuffer2[4] 585 | ){ 586 | int tid = threadIdx.x; 587 | int wx = tid % 32; 588 | int wy = tid / 32; 589 | int dx = tid % 8; 590 | int dy = tid / 8; 591 | #pragma unroll 592 | for (int i=0; i<4; i++){ 593 | // Store buffered tiles into shared memory 594 | aSM[dx][dy+i*32] = aBuffer1[i]; 595 | aSM[dx+8][dy+i*32] = aBuffer2[i]; 596 | bSM[dx][dy+i*32+i] = bBuffer1[i]; 597 | bSM[dx+8][dy+i*32+i] = bBuffer2[i]; 598 | } 599 | } 600 | 601 | __device__ void buffer2smem_16_tn( 602 | _VOLATILE_ float aSM[16][128+4], 603 | _VOLATILE_ float bSM[16][128+4], 604 | float aBuffer1[4], 605 | float aBuffer2[4], 606 | float bBuffer1[4], 607 | float bBuffer2[4] 608 | ){ 609 | int tid = threadIdx.x; 610 | int wx = tid % 32; 611 | int wy = tid / 32; 612 | int dx = tid % 8; 613 | int dy = tid / 8; 614 | #pragma unroll 615 | for (int i=0; i<4; i++){ 616 | // Store buffered tiles into shared memory 617 | aSM[wy][wx+i*32] = aBuffer1[i]; 618 | aSM[wy+8][wx+i*32] = aBuffer2[i]; 619 | bSM[wy][wx+i*32+i] = bBuffer1[i]; 620 | bSM[wy+8][wx+i*32+i] = bBuffer2[i]; 621 | } 622 | } 623 | -------------------------------------------------------------------------------- /kernels/mbmm.cu: -------------------------------------------------------------------------------- 1 | #define _VOLATILE_ 2 | 3 | #define likely(x) __builtin_expect(!!(x), 1) 4 | #define unlikely(x) __builtin_expect(!!(x), 0) 5 | #define load(x) __ldcg(x) 6 | #define store(x, value) __stcs(x, value) 7 | 8 | typedef long long ll_t; 9 | typedef unsigned long long ull_t; 10 | typedef unsigned char uint8_t; 11 | 12 | typedef struct __builtin_align__(32) { 13 | float s0, s1, s2, s3, s4, s5, s6, s7; 14 | } _float8; 15 | 16 | typedef union { 17 | _float8 f8; 18 | float val[8]; 19 | } float8; 20 | 21 | __device__ void init_cCache( 22 | float8 cCache[8] 23 | ) { 24 | #pragma unroll 25 | for (int i=0; i<8; i++){ 26 | #pragma unroll 27 | for (int j=0; j<8; j++){ 28 | cCache[i].val[j] = 0.f; 29 | } 30 | } 31 | } 32 | 33 | __device__ void thread_matmul_v4( 34 | _VOLATILE_ float aSM[8][128+4], 35 | _VOLATILE_ float bSM[8][128+4], 36 | float8 cCache[8], 37 | int vx, int vy 38 | ) { 39 | float aCache1[8]; 40 | float aCache2[8]; 41 | #pragma unroll 42 | for (int mi=0; mi<8; mi++){ 43 | aCache1[mi] = aSM[0][8*vy + mi]; 44 | } 45 | 46 | #pragma unroll 47 | for (int ki=0; ki<8; ki++){ 48 | int is_odd = ki & 1; 49 | if (is_odd == 0){ 50 | if (likely(ki < 7)){ 51 | #pragma unroll 52 | for (int mi=0; mi<8; mi++){ 53 | aCache2[mi] = aSM[ki+1][8*vy + mi]; 54 | } 55 | } 56 | #pragma 
unroll 57 | for (int ni=0; ni<8; ni++){ 58 | float b = bSM[ki][vx/4 + 8*vx + ni]; 59 | #pragma unroll 60 | for (int mi=0; mi<8; mi++){ 61 | float a = aCache1[mi]; 62 | cCache[mi].val[ni] = fmaf(a, b, cCache[mi].val[ni]); 63 | } 64 | } 65 | } else { 66 | if (likely(ki < 7)){ 67 | #pragma unroll 68 | for (int mi=0; mi<8; mi++){ 69 | aCache1[mi] = aSM[ki+1][8*vy + mi]; 70 | } 71 | } 72 | #pragma unroll 73 | for (int ni=0; ni<8; ni++){ 74 | float b = bSM[ki][vx/4 + 8*vx + ni]; 75 | #pragma unroll 76 | for (int mi=0; mi<8; mi++){ 77 | float a = aCache2[mi]; 78 | cCache[mi].val[ni] = fmaf(a, b, cCache[mi].val[ni]); 79 | } 80 | } 81 | } 82 | } 83 | } 84 | 85 | __device__ void thread_matmul_v3( 86 | _VOLATILE_ float aSM[8][128+4], 87 | _VOLATILE_ float bSM[8][128+4], 88 | float8 cCache[8], 89 | int vx, int vy 90 | ) { 91 | float aCache[8]; 92 | 93 | #pragma unroll 94 | for (int ki=0; ki<8; ki++){ 95 | #pragma unroll 96 | for (int mi=0; mi<8; mi++){ 97 | aCache[mi] = aSM[ki][8*vy + mi]; 98 | } 99 | #pragma unroll 100 | for (int ni=0; ni<8; ni++){ 101 | float b = bSM[ki][vx/4 + 8*vx + ni]; 102 | #pragma unroll 103 | for (int mi=0; mi<8; mi++){ 104 | float a = aCache[mi]; 105 | cCache[mi].val[ni] = fmaf(a, b, cCache[mi].val[ni]); 106 | } 107 | } 108 | } 109 | } 110 | 111 | __device__ void mask_cCache( 112 | float8 cCache[8], 113 | const uint8_t* ElementMask, 114 | int gStartx, 115 | int gStarty, 116 | int vx, int vy, int bid, 117 | int M, int N 118 | ) { 119 | #pragma unroll 120 | for (int i=0; i<8; i++){ 121 | int iM = gStarty + vy*8 + i; 122 | if (likely(iM < M)){ 123 | #pragma unroll 124 | for (int j=0; j<8; j++){ 125 | int iN = gStartx + vx*8 + j; 126 | if (likely(iN < N)){ 127 | uint8_t element_mask = ElementMask[(__MASK_BID__)*M*N + (iM)*N + (iN)]; 128 | cCache[i].val[j] *= element_mask; 129 | } 130 | } 131 | } 132 | } 133 | } 134 | 135 | // Unsafe 136 | __device__ void write_c( 137 | float8 cCache[8], 138 | float* C, 139 | int gStartx, int gStarty, 140 | int vx, int vy, int bid, 141 | int M, int N 142 | ) { 143 | #pragma unroll 144 | for (int i=0; i<8; i++){ 145 | int iM = gStarty + vy*8 + i; 146 | if (likely(iM < M)){ 147 | int iN_start = gStartx + vx*8; 148 | reinterpret_cast<float8*>(C + (bid)*M*N + (iM)*N + (iN_start))[0] = cCache[i]; 149 | } 150 | } 151 | } 152 | 153 | __device__ void write_c_v3( 154 | float8 cCache[8], 155 | float* C, 156 | int gStartx, int gStarty, 157 | int vx, int vy, int bid, 158 | int M, int N 159 | ) { 160 | __shared__ volatile float cSM[16][128]; 161 | #pragma unroll 162 | for (int mi=0; mi<8; mi++){ 163 | int iM = gStarty + vy*8 + mi; 164 | // Store 1 row from cCache to cSM 165 | if (iM < M){ 166 | #pragma unroll 167 | for (int ni=0; ni<8; ni++){ 168 | cSM[vy][vx*8 + ni] = cCache[mi].val[ni]; 169 | } 170 | // Store to C 171 | #pragma unroll 172 | for (int ni=0; ni<8; ni++){ 173 | int iN = gStartx + 16*ni + vx; 174 | if (iN < N){ 175 | float cVal = cSM[vy][16*ni + vx]; 176 | //store(C+(bid)*M*N + (iM)*N + (iN), cVal); 177 | C[(bid)*M*N + (iM)*N + (iN)] = cVal; 178 | } 179 | } 180 | } 181 | } 182 | } 183 | 184 | extern "C" 185 | __global__ void mbmm_tn( 186 | const float* __restrict__ A, 187 | const float* __restrict__ B, 188 | float* __restrict__ C, 189 | const uint8_t* __restrict__ BlockMask, 190 | const uint8_t* __restrict__ ThreadMask, 191 | const uint8_t* __restrict__ ElementMask, 192 | int M, int N, int K 193 | ){ 194 | } 195 |
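A small NumPy sketch (not part of the repo) of the element-level masking that mask_cCache implements above: the accumulated tile is multiplied entry-wise by the 0/1 ElementMask before write-back, so masked products are zeroed out.

import numpy as np

M = N = 16
acc = np.random.rand(M, N)                                    # accumulated tile of A @ B
element_mask = (np.random.rand(M, N) > 0.5).astype(np.uint8)  # 0/1, like ElementMask
masked = acc * element_mask                                   # what write_c / write_c_v3 would store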
196 | extern "C" 197 | __global__ void mbmm_nt( 198 | const float* __restrict__ A, 199 | const float* __restrict__ B, 200 | float* __restrict__ C, 201 | const uint8_t* __restrict__ BlockMask, 202 | const uint8_t* __restrict__ ThreadMask, 203 | const uint8_t* __restrict__ ElementMask, 204 | int M, int N, int K 205 | ){ 206 | } 207 | 208 | extern "C" 209 | __global__ void mbmm_nn( 210 | const float* __restrict__ A, 211 | const float* __restrict__ B, 212 | float* __restrict__ C, 213 | const uint8_t* __restrict__ BlockMask, 214 | const uint8_t* __restrict__ ThreadMask, 215 | const uint8_t* __restrict__ ElementMask, 216 | int M, int N, int K 217 | ){ 218 | int tid = threadIdx.x; // thread idx 219 | int bid = blockIdx.z; // batch idx 220 | 221 | // Neighboring blocks are grouped into PN x PM block groups in order to increase 222 | // L1 cache hit rate. 223 | // There are ceil(M/(128*PM)) x ceil(N/(128*PN)) block groups in total. 224 | // Blocks within block groups are indexed with blockIdx.x % PN and blockIdx.x / PN 225 | 226 | int px = blockIdx.x % _PN_; 227 | int py = blockIdx.x / _PN_; 228 | int bDimX = (N + (128*_PN_) - 1) / (128*_PN_); 229 | int bDimY = (M + (128*_PM_) - 1) / (128*_PM_); 230 | int bIdxX = (blockIdx.y % bDimX) * _PN_ + px; 231 | int bIdxY = (blockIdx.y / bDimX) * _PM_ + py; 232 | int gStartx = bIdxX * 128; // starting index of block on N axis 233 | int gStarty = bIdxY * 128; // starting index of block on M axis 234 | if (gStartx > N || gStarty > M){ 235 | return; 236 | } 237 | 238 | // These are used to re-arrange threads into different shapes 239 | // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8) 240 | int vx = tid % 16; 241 | int vy = tid / 16; 242 | int wx = tid % 32; // thread idx in warp 243 | int wy = tid / 32; // warp id 244 | int dx = tid % 8; 245 | int dy = tid / 8; 246 | 247 | int bM = (M + 128 - 1) / 128; 248 | int bN = (N + 128 - 1) / 128; 249 | int tM = (M + 8 - 1) / 8; 250 | int tN = (N + 8 - 1) / 8; 251 | uint8_t block_mask = BlockMask[__MASK_BID__*bM*bN + (bIdxY)*bN + (bIdxX)]; 252 | uint8_t thread_mask = ThreadMask[__MASK_BID__*tM*tN + (bIdxY*16 + vy)*tN + (bIdxX*16 + vx) ]; 253 | if (block_mask == 0){ 254 | return; 255 | } 256 | 257 | __shared__ _VOLATILE_ float aSM1[8][128+4]; 258 | __shared__ _VOLATILE_ float bSM1[8][128+4]; 259 | __shared__ _VOLATILE_ float aSM2[8][128+4]; 260 | __shared__ _VOLATILE_ float bSM2[8][128+4]; 261 | float aBuffer1[4]; 262 | float bBuffer1[4]; 263 | float aBuffer2[4]; 264 | float bBuffer2[4]; 265 | 266 | float8 cCache[8]; 267 | init_cCache(cCache); 268 | 269 | // Load initial 16 x 128 tile of A and B to buffer1 and buffer2 270 | #pragma unroll 271 | for (int i=0; i<4; i++){ 272 | int iM = gStarty + dy + i*32; 273 | int iN = gStartx + wx + i*32; 274 | if (likely(iM < M)){ 275 | if (likely(dx < K)){ 276 | aBuffer1[i] = load(A + (bid)*M*K + (iM)*K + (dx)); 277 | } else { 278 | aBuffer1[i] = 0.f; 279 | } 280 | if (likely(dx+8 < K)){ 281 | aBuffer2[i] = load(A + (bid)*M*K + (iM)*K + (dx+8)); 282 | } else { 283 | aBuffer2[i] = 0.f; 284 | } 285 | } 286 | if (likely(iN < N)){ 287 | if (likely(wy < K)){ 288 | bBuffer1[i] = load(B + (bid)*N*K + (wy)*N + (iN)); 289 | } else { 290 | bBuffer1[i] = 0.f; 291 | } 292 | if (likely(wy+8 < K)){ 293 | bBuffer2[i] = load(B + (bid)*N*K + (wy+8)*N + (iN)); 294 | } else { 295 | bBuffer2[i] = 0.f; 296 | } 297 | } 298 | } 299 | 300 | // Number of main loop iterations is ceil(k/16) 301 | int nIt = (K + 16 - 1) / 16; 302 | #pragma unroll 303 | for (int itr=0; itr<nIt; itr++){ -------------------------------------------------------------------------------- /kernels/minbmm.cu: -------------------------------------------------------------------------------- 178 | if (gStartx > N || gStarty > M){ 179 | return; 180 | } 181 | // These are used to re-arrange threads into different shapes 182 | // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8)
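// (annotation, not in the original source) The three views of the 256 threads:
//   vx = tid % 16, vy = tid / 16 : 16 x 16 grid; each thread accumulates the 8 x 8
//     sub-tile at rows 8*vy.. and columns 8*vx.. of the block's 128 x 128 output;
//   wx = tid % 32, wy = tid / 32 : lane id / warp id, used for one operand's global loads;
//   dx = tid % 8,  dy = tid / 8  : 8 x 32 view used for the other operand, dx picking
//     one of the 8 K-columns of the current slab (see the initial loads in mbmm_nn above).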
183 | int vx = tid % 16; 184 | int vy = tid / 16; 185 | int wx = tid % 32; // thread idx in warp 186 | int wy = tid / 32; // warp id 187 | int dx = tid % 8; 188 | int dy = tid / 8; 189 | 190 | __shared__ _VOLATILE_ float aSmem[16][128+4]; 191 | __shared__ _VOLATILE_ float bSmem[16][128+4]; 192 | 193 | float aBuffer1[4]; 194 | float bBuffer1[4]; 195 | float aBuffer2[4]; 196 | float bBuffer2[4]; 197 | 198 | float8 cCache[8]; 199 | init_cCache(cCache); 200 | 201 | // Load initial 16 x 128 tile of A and B to buffer1 and buffer2 202 | #pragma unroll 203 | load_ab_tn( 204 | A, B, 205 | aBuffer1, aBuffer2, bBuffer1, bBuffer2, 206 | bid, gStartx, gStarty, 0, 207 | M, N, K 208 | ); 209 | 210 | // Number of main loop iterations is ceil(k/16) 211 | int nIt = (K + 16 - 1) / 16; 212 | #pragma unroll 213 | for (int itr=0; itr<nIt; itr++){ 274 | if (gStartx > N || gStarty > M){ 275 | return; 276 | } 277 | // These are used to re-arrange threads into different shapes 278 | // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8) 279 | int vx = tid % 16; 280 | int vy = tid / 16; 281 | int wx = tid % 32; // thread idx in warp 282 | int wy = tid / 32; // warp id 283 | int dx = tid % 8; 284 | int dy = tid / 8; 285 | 286 | __shared__ _VOLATILE_ float aSmem[16][128+4]; 287 | __shared__ _VOLATILE_ float bSmem[16][128+4]; 288 | 289 | float aBuffer1[4]; 290 | float bBuffer1[4]; 291 | float aBuffer2[4]; 292 | float bBuffer2[4]; 293 | 294 | float8 cCache[8]; 295 | init_cCache(cCache); 296 | 297 | // Load initial 16 x 128 tile of A and B to buffer1 and buffer2 298 | #pragma unroll 299 | load_ab_nt( 300 | A, B, 301 | aBuffer1, aBuffer2, bBuffer1, bBuffer2, 302 | bid, gStartx, gStarty, 0, 303 | M, N, K 304 | ); 305 | 306 | // Number of main loop iterations is ceil(k/16) 307 | int nIt = (K + 16 - 1) / 16; 308 | #pragma unroll 309 | for (int itr=0; itr<nIt; itr++){ 370 | if (gStartx > N || gStarty > M){ 371 | return; 372 | } 373 | // These are used to re-arrange threads into different shapes 374 | // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8) 375 | int vx = tid % 16; 376 | int vy = tid / 16; 377 | int wx = tid % 32; // thread idx in warp 378 | int wy = tid / 32; // warp id 379 | int dx = tid % 8; 380 | int dy = tid / 8; 381 | 382 | __shared__ _VOLATILE_ float aSmem[16][128+4]; 383 | __shared__ _VOLATILE_ float bSmem[16][128+4]; 384 | 385 | float aBuffer1[4]; 386 | float bBuffer1[4]; 387 | float aBuffer2[4]; 388 | float bBuffer2[4]; 389 | 390 | float8 cCache[8]; 391 | init_cCache(cCache); 392 | 393 | // Load initial 16 x 128 tile of A and B to buffer1 and buffer2 394 | load_ab_nn( 395 | A, B, 396 | aBuffer1, aBuffer2, bBuffer1, bBuffer2, 397 | bid, gStartx, gStarty, 0, 398 | M, N, K 399 | ); 400 | 401 | // Number of main loop iterations is ceil(k/16) 402 | int nIt = (K + 16 - 1) / 16; 403 | #pragma unroll 404 | for (int itr=0; itr<nIt; itr++){ 465 | if (gStartx > N || gStarty > M){ 466 | return; 467 | } 468 | // These are used to re-arrange threads into different shapes 469 | // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8) 470 | int vx = tid % 16; 471 | int vy = tid / 16; 472 | int wx = tid % 32; // thread idx in warp 473 | int wy = tid / 32; // warp id 474 | int dx = tid % 8; 475 | int dy = tid / 8; 476 | 477 | __shared__ _VOLATILE_ float aSmem[16][128+4]; 478 | __shared__ _VOLATILE_ float bSmem[16][128+4]; 479 | 480 | float aBuffer1[4]; 481 | float bBuffer1[4]; 482 | float aBuffer2[4]; 483 | float bBuffer2[4]; 484 | 485 | float8 cCache[8]; 486 | init_cCache(cCache); 487 | 488 | // Load initial 16 x 128 tile of A and B to buffer1 and buffer2 489 | #pragma unroll 490 | load_ab_tt( 491 | A, B, 492 | aBuffer1, aBuffer2, bBuffer1, bBuffer2, 493 | bid, gStartx, gStarty, 0, 494 | M, N, K 495 | ); 496 | 497 | // Number of main loop iterations is ceil(k/16) 498 | int nIt = (K + 16 - 1) / 16; 499 | #pragma unroll 500 | for (int itr=0; itr<nIt; itr++){
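A host-side sketch (hypothetical helper, not in the repo) of the block-swizzling math every kernel prologue above repeats; it matches the grid shape blocks_per_grid = (PN*PM, ceil(N/(128*PN)) * ceil(M/(128*PM)), batch) used by minbmm.py and topkbmm.py below.

import math

def block_tile(bx, by, M, N, PM=4, PN=4):
    # mirrors px/py/bDimX/bIdxX/bIdxY from the kernel prologues
    px, py = bx % PN, bx // PN
    bDimX = math.ceil(N / (128 * PN))
    bIdxX = (by % bDimX) * PN + px
    bIdxY = (by // bDimX) * PM + py
    return bIdxY * 128, bIdxX * 128  # (gStarty, gStartx)

# the 16 blocks sharing blockIdx.y == 0 cover a 4 x 4 patch of adjacent
# 128 x 128 tiles, so their A/B reads overlap and hit in cache more often
for bx in range(16):
    print(bx, block_tile(bx, 0, M=2048, N=2048))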
-------------------------------------------------------------------------------- /kernels/topkbmm.cu: -------------------------------------------------------------------------------- 100000) break; 24 | if (ns < 256) { 25 | ns *= 2; 26 | } 27 | } 28 | } 29 | __syncthreads(); 30 | } 31 | 32 | __device__ void mutex_lock_noop( 33 | ) { 34 | __syncthreads(); 35 | } 36 | 37 | __device__ void mutex_unlock( 38 | unsigned int *mutex 39 | ) { 40 | __threadfence(); 41 | __syncthreads(); 42 | if (threadIdx.x == 0){ 43 | atomicExch(mutex, 0); 44 | __threadfence(); 45 | } 46 | __syncthreads(); 47 | } 48 | 49 | __device__ void mutex_unlock_noop(){ 50 | __syncthreads(); 51 | __syncthreads(); 52 | } 53 | 54 | __device__ __forceinline__ unsigned int bfe( 55 | unsigned int source, 56 | unsigned int bitIndex 57 | ) { 58 | unsigned int bit; 59 | asm volatile("bfe.u32 %0, %1, %2, %3;" : "=r"(bit) : "r"((unsigned int) source), "r"(bitIndex), "r"(1)); 60 | return bit; 61 | } 62 | 63 | __device__ __forceinline__ void warp_comparator( 64 | float &value, 65 | float &index, 66 | const int stride, 67 | const int direction 68 | ){ 69 | const float other_value = __shfl_xor_sync(0xFFFFFFFF, value, stride); 70 | const float other_index = __shfl_xor_sync(0xFFFFFFFF, index, stride); 71 | bool condition = value < other_value == direction; 72 | index = condition ? other_index : index; 73 | value = condition ? other_value : value; 74 | } 75 | 76 | __device__ __forceinline__ void block_comparator( 77 | float &value, 78 | float &index, 79 | const int stride, 80 | const int direction, 81 | const int laneID, 82 | _VOLATILE_ float valSmem[128+4], 83 | _VOLATILE_ float idxSmem[128+4] 84 | ){ 85 | valSmem[laneID] = value; 86 | idxSmem[laneID] = index; 87 | __syncthreads(); 88 | 89 | float other_value = valSmem[laneID ^ stride]; 90 | float other_index = idxSmem[laneID ^ stride]; 91 | __syncthreads(); 92 | 93 | bool condition = value < other_value == direction; 94 | index = condition ? other_index : index; 95 | value = condition ?
other_value : value; 96 | } 97 | 98 | __device__ void bitonic_sort_128( 99 | float &value, 100 | float &index, 101 | _VOLATILE_ float valSmem[128+4], 102 | _VOLATILE_ float idxSmem[128+4] 103 | ) { 104 | unsigned int laneID = threadIdx.x % 128; 105 | warp_comparator(value, index, 1, bfe(laneID, 1) ^ bfe(laneID, 0)); 106 | 107 | warp_comparator(value, index, 2, bfe(laneID, 2) ^ bfe(laneID, 1)); 108 | warp_comparator(value, index, 1, bfe(laneID, 2) ^ bfe(laneID, 0)); 109 | 110 | warp_comparator(value, index, 4, bfe(laneID, 3) ^ bfe(laneID, 2)); 111 | warp_comparator(value, index, 2, bfe(laneID, 3) ^ bfe(laneID, 1)); 112 | warp_comparator(value, index, 1, bfe(laneID, 3) ^ bfe(laneID, 0)); 113 | 114 | warp_comparator(value, index, 8, bfe(laneID, 4) ^ bfe(laneID, 3)); 115 | warp_comparator(value, index, 4, bfe(laneID, 4) ^ bfe(laneID, 2)); 116 | warp_comparator(value, index, 2, bfe(laneID, 4) ^ bfe(laneID, 1)); 117 | warp_comparator(value, index, 1, bfe(laneID, 4) ^ bfe(laneID, 0)); 118 | 119 | warp_comparator(value, index, 16, bfe(laneID, 5) ^ bfe(laneID, 4)); 120 | warp_comparator(value, index, 8, bfe(laneID, 5) ^ bfe(laneID, 3)); 121 | warp_comparator(value, index, 4, bfe(laneID, 5) ^ bfe(laneID, 2)); 122 | warp_comparator(value, index, 2, bfe(laneID, 5) ^ bfe(laneID, 1)); 123 | warp_comparator(value, index, 1, bfe(laneID, 5) ^ bfe(laneID, 0)); 124 | 125 | block_comparator(value, index, 32, bfe(laneID, 6) ^ bfe(laneID, 5), laneID, valSmem, idxSmem); 126 | warp_comparator(value, index, 16, bfe(laneID, 6) ^ bfe(laneID, 4)); 127 | warp_comparator(value, index, 8, bfe(laneID, 6) ^ bfe(laneID, 3)); 128 | warp_comparator(value, index, 4, bfe(laneID, 6) ^ bfe(laneID, 2)); 129 | warp_comparator(value, index, 2, bfe(laneID, 6) ^ bfe(laneID, 1)); 130 | warp_comparator(value, index, 1, bfe(laneID, 6) ^ bfe(laneID, 0)); 131 | 132 | block_comparator(value, index, 64, bfe(laneID, 6), laneID, valSmem, idxSmem); 133 | block_comparator(value, index, 32, bfe(laneID, 5), laneID, valSmem, idxSmem); 134 | warp_comparator(value, index, 16, bfe(laneID, 4)); 135 | warp_comparator(value, index, 8, bfe(laneID, 3)); 136 | warp_comparator(value, index, 4, bfe(laneID, 2)); 137 | warp_comparator(value, index, 2, bfe(laneID, 1)); 138 | warp_comparator(value, index, 1, bfe(laneID, 0)); 139 | } 140 | 141 | __device__ void bitonic_sort_256( 142 | float &value, 143 | float &index, 144 | float* g_values, 145 | ll_t* g_indices, 146 | float valSmem[128+4], 147 | float idxSmem[128+4], 148 | int Q, int adr, bool ok 149 | ){ 150 | int laneID = threadIdx.x % 128; 151 | float other_index; 152 | float other_value; 153 | if (ok){ 154 | other_value = g_values[adr]; 155 | other_index = g_indices[adr]; 156 | } else { 157 | other_value = -INFINITY; 158 | other_index = 0; 159 | } 160 | bool condition = value > other_value == 0; 161 | if (condition){ 162 | value = value + other_value; 163 | index = index + other_index; 164 | other_value = value - other_value; 165 | other_index = index - other_index; 166 | value = value - other_value; 167 | index = index - other_index; 168 | } 169 | 170 | block_comparator(value, index, 64, !bfe(laneID, 6), laneID, valSmem, idxSmem); 171 | block_comparator(value, index, 32, !bfe(laneID, 5), laneID, valSmem, idxSmem); 172 | warp_comparator(value, index, 16, !bfe(laneID, 4)); 173 | warp_comparator(value, index, 8, !bfe(laneID, 3)); 174 | warp_comparator(value, index, 4, !bfe(laneID, 2)); 175 | warp_comparator(value, index, 2, !bfe(laneID, 1)); 176 | warp_comparator(value, index, 1, !bfe(laneID, 0)); 177 | /* 178 
| */ 179 | if (ok){ 180 | g_values[adr] = value; 181 | g_indices[adr] = index; 182 | } 183 | } 184 | 185 | __device__ void bitonicSort256_noop() 186 | { 187 | __syncthreads(); 188 | __syncthreads(); 189 | __syncthreads(); 190 | __syncthreads(); 191 | } 192 | 193 | __device__ void topk_dim_1( 194 | float8 cCache[8], 195 | _VOLATILE_ float valSmem[16][128+4], 196 | _VOLATILE_ float idxSmem[16][128+4], 197 | float* values, 198 | ll_t* indices, 199 | unsigned int* mutex, 200 | int gStartx, int gStarty, int bid, 201 | int M, int N, int Q 202 | ){ 203 | int tid = threadIdx.x; 204 | int vx = tid % 16; 205 | int vy = tid / 16; 206 | int hx = tid % 128; 207 | int hy = tid / 128; 208 | #pragma unroll 209 | for (int ni=0; ni<8; ni++){ 210 | int iN = gStartx + vx*8 + ni; 211 | //if (iN < N) break; 212 | 213 | // Store cCache to cSM 214 | #pragma unroll 215 | for (int mi=0; mi<8; mi++){ 216 | int iM = gStarty + vy*8 + mi; 217 | if (likely(iM < M && iN < N)){ 218 | valSmem[vx][vy*8 + mi] = cCache[mi].val[ni]; 219 | idxSmem[vx][vy*8 + mi] = iM; 220 | } else { 221 | valSmem[vx][vy*8 + mi] = -INFINITY; 222 | idxSmem[vx][vy*8 + mi] = -1; 223 | } 224 | } 225 | __syncthreads(); 226 | // Load from cSM to cCache 227 | #pragma unroll 228 | for (int i=0; i<8; i++){ 229 | float value = valSmem[hy*8 + i][hx]; 230 | float index = idxSmem[hy*8 + i][hx]; 231 | bitonic_sort_128( 232 | value, index, 233 | valSmem[hy*8 + i], idxSmem[hy*8 + i] 234 | ); 235 | int iN = gStartx + (hy*8 + i)*8 + ni; 236 | int adr = (bid)*N*Q + iN*Q + hx; 237 | mutex_lock( &mutex[(bid)*N + iN] ); 238 | bitonic_sort_256( 239 | value, index, 240 | values, indices, 241 | valSmem[hy*8+i], idxSmem[hy*8+i], 242 | Q, adr, iN < N 243 | ); 244 | mutex_unlock( &mutex[(bid)*N + iN] ); 245 | } 246 | } 247 | } 248 | 249 | __device__ void topk_dim_2( 250 | float8 cCache[8], 251 | _VOLATILE_ float valSmem[16][128+4], 252 | _VOLATILE_ float idxSmem[16][128+4], 253 | float* values, 254 | ll_t* indices, 255 | unsigned int* mutex, 256 | int gStartx, int gStarty, int bid, 257 | int M, int N, int Q 258 | ){ 259 | int tid = threadIdx.x; 260 | int vx = tid % 16; 261 | int vy = tid / 16; 262 | int hx = tid % 128; 263 | int hy = tid / 128; 264 | #pragma unroll 265 | for (int mi=0; mi<8; mi++){ 266 | int iM = gStarty + vy*8 + mi; 267 | //if (iM >= M) break; 268 | 269 | // Store cCache to cSM 270 | #pragma unroll 271 | for (int ni=0; ni<8; ni++){ 272 | int iN = gStartx + vx*8 + ni; 273 | if (likely(iN < N && iM < M)){ 274 | valSmem[vy][vx*8 + ni] = cCache[mi].val[ni]; 275 | idxSmem[vy][vx*8 + ni] = iN; 276 | } else { 277 | valSmem[vy][vx*8 + ni] = -INFINITY; 278 | idxSmem[vy][vx*8 + ni] = -1; 279 | } 280 | } 281 | __syncthreads(); 282 | // Load from cSM to cCache 283 | #pragma unroll 284 | for (int i=0; i<8; i++){ 285 | float value = valSmem[hy*8 + i][hx]; 286 | float index = idxSmem[hy*8 + i][hx]; 287 | bitonic_sort_128( 288 | value, index, 289 | valSmem[hy*8 + i], idxSmem[hy*8 + i] 290 | ); 291 | int iM = gStarty + (hy*8 + i)*8 + mi; 292 | int adr = (bid)*M*Q + iM*Q + hx; 293 | mutex_lock( &mutex[(bid)*M + iM] ); 294 | bitonic_sort_256( 295 | value, index, 296 | values, indices, 297 | valSmem[hy*8+i], idxSmem[hy*8+i], 298 | Q, adr, iM < M 299 | ); 300 | mutex_unlock( &mutex[(bid)*M + iM] ); 301 | } 302 | } 303 | } 304 | 305 | extern "C" 306 | __global__ void topk_bmm_tn( 307 | const float* __restrict__ A, 308 | const float* __restrict__ B, 309 | float* values, 310 | ll_t* indices, 311 | unsigned int* mutex, 312 | int M, int N, int K, int DIM, int Q 313 | ){ 314 | int 
tid = threadIdx.x; // thread idx 315 | int bid = blockIdx.z; // batch idx 316 | 317 | // Neighboring blocks are grouped into PN x PM block groups in order to increase 318 | // L1 cache hit rate 319 | // There are ceil(M/(128*PM)) x ceil(N/(128*PN)) block groups in total. 320 | // Blocks within block groups are indexed with blockIdx.x % PN and blockIdx.x / PN 321 | int px = blockIdx.x % _PN_; 322 | int py = blockIdx.x / _PN_; 323 | int bDimX = (N + (128*_PN_) - 1) / (128*_PN_); 324 | int bDimY = (M + (128*_PM_) - 1) / (128*_PM_); 325 | int bIdxX = (blockIdx.y % bDimX) * _PN_ + px; 326 | int bIdxY = (blockIdx.y / bDimX) * _PM_ + py; 327 | int gStartx = bIdxX * 128; // starting index of block on N axis 328 | int gStarty = bIdxY * 128; // starting index of block on M axis 329 | if (gStartx > N || gStarty > M){ 330 | return; 331 | } 332 | // These are used to re-arrange threads into different shapes 333 | // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8) 334 | int vx = tid % 16; 335 | int vy = tid / 16; 336 | int wx = tid % 32; // thread idx in warp 337 | int wy = tid / 32; // warp id 338 | int dx = tid % 8; 339 | int dy = tid / 8; 340 | 341 | __shared__ _VOLATILE_ float aSmem[16][128+4]; 342 | __shared__ _VOLATILE_ float bSmem[16][128+4]; 343 | 344 | float aBuffer1[4]; 345 | float bBuffer1[4]; 346 | float aBuffer2[4]; 347 | float bBuffer2[4]; 348 | 349 | float8 cCache[8]; 350 | init_cCache(cCache); 351 | 352 | // Load initial 16 x 128 tile of A and B to buffer1 and buffer2 353 | load_ab_tn( 354 | A, B, 355 | aBuffer1, aBuffer2, bBuffer1, bBuffer2, 356 | bid, gStartx, gStarty, 0, 357 | M, N, K 358 | ); 359 | 360 | // Number of main loop iterations is ceil(k/16) 361 | int nIt = (K + 16 - 1) / 16; 362 | #pragma unroll 363 | for (int itr=0; itr<nIt; itr++){ 426 | if (gStartx > N || gStarty > M){ 427 | return; 428 | } 429 | // These are used to re-arrange threads into different shapes 430 | // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8) 431 | int vx = tid % 16; 432 | int vy = tid / 16; 433 | int wx = tid % 32; // thread idx in warp 434 | int wy = tid / 32; // warp id 435 | int dx = tid % 8; 436 | int dy = tid / 8; 437 | 438 | __shared__ _VOLATILE_ float aSmem[16][128+4]; 439 | __shared__ _VOLATILE_ float bSmem[16][128+4]; 440 | 441 | float aBuffer1[4]; 442 | float bBuffer1[4]; 443 | float aBuffer2[4]; 444 | float bBuffer2[4]; 445 | 446 | float8 cCache[8]; 447 | init_cCache(cCache); 448 | 449 | // Load initial 16 x 128 tile of A and B to buffer1 and buffer2 450 | load_ab_nt( 451 | A, B, 452 | aBuffer1, aBuffer2, bBuffer1, bBuffer2, 453 | bid, gStartx, gStarty, 0, 454 | M, N, K 455 | ); 456 | 457 | // Number of main loop iterations is ceil(k/16) 458 | int nIt = (K + 16 - 1) / 16; 459 | #pragma unroll 460 | for (int itr=0; itr<nIt; itr++){ 523 | if (gStartx > N || gStarty > M){ 524 | return; 525 | } 526 | // These are used to re-arrange threads into different shapes 527 | // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8) 528 | int vx = tid % 16; 529 | int vy = tid / 16; 530 | int wx = tid % 32; // thread idx in warp 531 | int wy = tid / 32; // warp id 532 | int dx = tid % 8; 533 | int dy = tid / 8; 534 | 535 | __shared__ _VOLATILE_ float aSmem[16][128+4]; 536 | __shared__ _VOLATILE_ float bSmem[16][128+4]; 537 | 538 | float aBuffer1[4]; 539 | float bBuffer1[4]; 540 | float aBuffer2[4]; 541 | float bBuffer2[4]; 542 | 543 | float8 cCache[8]; 544 | init_cCache(cCache); 545 | 546 | // Load initial 16 x 128 tile of A and B to buffer1 and buffer2 547 | load_ab_nn( 548 | A, B, 549 | aBuffer1, aBuffer2, bBuffer1, bBuffer2, 550 | bid, gStartx, gStarty, 0, 551 | M, N, K 552 | ); 553 | 554 | // Number of main loop iterations is ceil(k/16) 555 | int nIt = (K + 16 - 1) / 16; 556 | #pragma unroll 557 | for (int itr=0; itr<nIt; itr++){ 620 | if (gStartx > N || gStarty > M){ 621 | return; 622 | } 623 | // These are used to re-arrange threads into different shapes 624 | // for example: (256) -> (16, 16) -> (8, 32) -> (32, 8) 625 | int vx = tid % 16; 626 | int vy = tid / 16; 627 | int wx = tid % 32; // thread idx in warp 628 | int wy = tid / 32; // warp id 629 | int dx = tid % 8; 630 | int dy = tid / 8; 631 | 632 | __shared__ _VOLATILE_ float aSmem[16][128+4]; 633 | __shared__ _VOLATILE_ float bSmem[16][128+4]; 634 | 635 | float aBuffer1[4]; 636 | float bBuffer1[4]; 637 | float aBuffer2[4]; 638 | float bBuffer2[4]; 639 | 640 | float8 cCache[8]; 641 | init_cCache(cCache); 642 | 643 | // Load initial 16 x 128 tile of A and B to buffer1 and buffer2 644 | load_ab_tt( 645 | A, B, 646 | aBuffer1, aBuffer2, bBuffer1, bBuffer2, 647 | bid, gStartx, gStarty, 0, 648 | M, N, K 649 | ); 650 | 651 | // Number of main loop iterations is ceil(k/16) 652 | int nIt = (K + 16 - 1) / 16; 653 | #pragma unroll 654 | for (int itr=0; itr<nIt; itr++){
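A NumPy sketch (hypothetical helper, not in the repo) of the merge that topk_dim_1/topk_dim_2 perform after each main loop: every 128x128 tile contributes 128 candidate (value, index) pairs per output column, which are combined with the running global top-128 under the per-column mutex. The kernel does this with a 256-element bitonic merge network (bitonic_sort_128 then bitonic_sort_256); a plain argsort stands in for it here.

import numpy as np

def merge_top128(g_vals, g_idx, c_vals, c_idx):
    vals = np.concatenate([g_vals, c_vals])   # 256 candidates
    idx = np.concatenate([g_idx, c_idx])
    order = np.argsort(-vals)[:128]           # keep the 128 largest, descending
    return vals[order], idx[order]

g_vals, g_idx = np.full(128, -np.inf), np.full(128, -1)
for start in range(0, 1024, 128):             # one 128-row tile at a time
    c_vals = np.random.rand(128)
    g_vals, g_idx = merge_top128(g_vals, g_idx, c_vals, np.arange(start, start + 128))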
-------------------------------------------------------------------------------- /mbmm.py: -------------------------------------------------------------------------------- 165 | "nn" --> A @ B 166 | "tt" --> A.T @ B.T 167 | "nt" --> A @ B.T 168 | "tn" --> A.T @ B 169 | """ 170 | assert len(A.shape) == len(B.shape) 171 | A = A.contiguous() 172 | B = B.contiguous() 173 | if len(A.shape) == 2 and len(B.shape) == 2: 174 | A2 = A[None] 175 | B2 = B[None] 176 | if not self.share_mask: 177 | block_mask = block_mask[None] 178 | thread_mask = thread_mask[None] 179 | element_mask = element_mask[None] 180 | elif len(A.shape) == 3 and len(B.shape) == 3: 181 | A2 = A 182 | B2 = B 183 | else: 184 | raise ValueError("shape of A and B need to be 2d or 3d") 185 | 186 | if mode == "nn": 187 | C = self._call_nn(A2, B2, block_mask, thread_mask, element_mask) 188 | elif mode == "tt": 189 | C = self._call_tt(A2, B2, block_mask, thread_mask, element_mask) 190 | elif mode == "tn": 191 | C = self._call_tn(A2, B2, block_mask, thread_mask, element_mask) 192 | elif mode == "nt": 193 | C = self._call_nt(A2, B2, block_mask, thread_mask, element_mask) 194 | 195 | if len(A.shape) == 2 and len(B.shape) == 2: 196 | C = C[0] 197 | return C
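The mbmm kernels consume three mask granularities: ElementMask is per output element, ThreadMask covers each thread's 8x8 sub-tile, and BlockMask covers each block's 128x128 tile. mbmm.py's mask-construction code is not shown here, so treat the max-pooling below as an assumption: the coarser masks are just "any element alive in this region".

import numpy as np

M, N = 256, 512
element_mask = (np.random.rand(M, N) > 0.9).astype(np.uint8)                        # [M, N]
thread_mask = element_mask.reshape(M // 8, 8, N // 8, 8).max(axis=(1, 3))           # [M/8, N/8]
block_mask = element_mask.reshape(M // 128, 128, N // 128, 128).max(axis=(1, 3))    # [M/128, N/128]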
-------------------------------------------------------------------------------- /minbmm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cupy as cp 3 | import numpy as np 4 | import math 5 | from custom_kernel import CustomKernel 6 | 7 | class MinBMMCUDA(CustomKernel): 8 | def __init__(self, patch_m=4, patch_n=4, distance="inner"): 9 | super(MinBMMCUDA, self).__init__() 10 | self.patch_m = patch_m 11 | self.patch_n = patch_n 12 | if distance == "inner": 13 | dist_fn = "madd" 14 | elif distance in ["l2", "euclidean"]: 15 | dist_fn = "squared_l2" 16 | elif distance in ["l1", "manhattan"]: 17 | dist_fn = "l1" 18 | else: 19 | raise ValueError("Unrecognized distance type") 20 | 21 | self.distance = distance 22 | 23 | with open("kernels/bmm_helpers.cu", "r") as f: 24 | helpers = f.read() 25 | 26 | with open("kernels/minbmm.cu",'r') as f: ### 27 | self.kernel = helpers + f.read() 28 | 29 | self.kernel = (self.kernel 30 | .replace("_PM_", str(self.patch_m)) 31 | .replace("_PN_", str(self.patch_n)) 32 | .replace("__DISTANCE_FN__", dist_fn) 33 | ) 34 | 35 | self._fn_tt = cp.RawKernel( 36 | code=self.kernel, 37 | name="min_bmm_tt", 38 | backend='nvcc', 39 | options=('--maxrregcount=128', '--use_fast_math') 40 | ) 41 | self._fn_nn = cp.RawKernel( 42 | code=self.kernel, 43 | name="min_bmm_nn", 44 | backend='nvcc', 45 | options=( 46 | '--maxrregcount=128', 47 | '--use_fast_math', 48 | #'-Xptxas', 49 | #'-dlcm=cg', 50 | ) 51 | ) 52 | # print(self._fn_nn.attributes) 53 | self._fn_tn = cp.RawKernel( 54 | code=self.kernel, 55 | name="min_bmm_tn", 56 | backend='nvcc', 57 | options=('--maxrregcount=128', '--use_fast_math') 58 | ) 59 | self._fn_nt = cp.RawKernel( 60 | code=self.kernel, 61 | name="min_bmm_nt", 62 | backend='nvcc', 63 | options=('--maxrregcount=128', '--use_fast_math') 64 | ) 65 | 66 | def get_mode(self, A, B): 67 | mode = [None, None] 68 | if A.stride()[-1] == 1: 69 | mode[0] = "n" 70 | elif A.stride()[-2] == 1: 71 | mode[0] = "t" 72 | if B.stride()[-1] == 1: 73 | mode[1] = "n" 74 | elif B.stride()[-2] == 1: 75 | mode[1] = "t" 76 | return "".join(mode) 77 | 78 | def __call__(self, A, B, dim=1): 79 | """ 80 | Performs C = min(f(A) @ g(B)), argmin(f(A) @ g(B)) 81 | A: torch.Tensor, shape : [l, m, k] 82 | B: torch.Tensor, shape : [l, k, n] 83 | returns (values, indices): torch.Tensor, shape : [l, n] if dim == 1, [l, m] if dim == 2 84 | """ 85 | assert len(A.shape) == len(B.shape) 86 | if len(A.shape) == 2 and len(B.shape) == 2: 87 | A = A[None] 88 | B = B[None] 89 | dim += 1 90 | two_dimensional = True 91 | elif len(A.shape) == 3 and len(B.shape) == 3: 92 | two_dimensional = False 93 | else: 94 | raise ValueError("A and B need to be 2d or 3d") 95 | assert A.shape[0] == B.shape[0] 96 | assert A.shape[2] == B.shape[1] 97 | assert A.dtype == B.dtype 98 | assert A.dtype in [torch.float, torch.half] 99 | assert A.device.type == B.device.type == "cuda" 100 | assert dim in [1, 2] 101 | 102 | mode = self.get_mode(A, B) 103 | 104 | if mode == "nn": 105 | kernel_fn = self._fn_nn 106 | elif mode == "nt": 107 | kernel_fn = self._fn_nt 108 | elif mode == "tn": 109 | kernel_fn = self._fn_tn 110 | elif mode == "tt": 111 | kernel_fn = self._fn_tt 112 | 113 | l, m, k = A.shape 114 | l, k, n = B.shape 115 | 116 | if dim == 1: 117 | values = torch.empty([l, n], device="cuda:0", dtype=A.dtype) 118 | indices = torch.empty([l, n], device="cuda:0", dtype=torch.int64) 119 | elif dim == 2: 120 | values = torch.empty([l, m], device="cuda:0", dtype=A.dtype) 121 | indices = torch.empty([l, m], device="cuda:0", dtype=torch.int64) 122 | values.fill_(float("inf")) 123 | 124 | threads_per_block = (256,) 125 | #blocks_per_grid = (math.ceil(n/128), math.ceil(m/128), l) 126 | 127 | n_ = math.ceil(n/(128*self.patch_n)) 128 | m_ = math.ceil(m/(128*self.patch_m)) 129 | blocks_per_grid = (self.patch_n*self.patch_m, n_ * m_, l) 130 | # print(blocks_per_grid, m_, n_) 131 | 132 | kernel_fn( 133 | grid=blocks_per_grid, 134 | block=threads_per_block, 135 | args=[ 136 | A.data_ptr(), 137 | B.data_ptr(), 138 | values.data_ptr(), 139 | indices.data_ptr(), 140 | m, n, k, dim, 141 | ], 142 | stream=self.stream 143 | ) 144 | 145 | if two_dimensional: 146 | indices = indices[0] 147 | values = values[0] 148 | 149 | return values, indices
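A usage sketch for MinBMMCUDA (not part of the repo); it assumes a CUDA device and that the script runs from the repository root, since the kernel sources are opened with relative paths.

import torch
from minbmm import MinBMMCUDA

minbmm = MinBMMCUDA(patch_m=4, patch_n=4, distance="inner")
A = torch.randn(8, 1024, 64, device="cuda")
B = torch.randn(8, 64, 4096, device="cuda")
values, indices = minbmm(A, B, dim=1)           # both [8, 4096]: min over the m axis
ref_values, ref_indices = (A @ B).min(dim=1)    # dense reference
print(torch.allclose(values, ref_values, atol=1e-4))  # loose tolerance: --use_fast_math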
"manhattan"]: 20 | dist_fn = "l1" 21 | else: 22 | ValueError("Unrecognized distance type") 23 | 24 | self.distance = distance 25 | 26 | with open("kernels/bmm_helpers.cu",'r') as f: ### 27 | helpers = f.read() 28 | 29 | with open("kernels/topkbmm.cu",'r') as f: ### 30 | self.kernel = helpers + f.read() 31 | 32 | self.kernel = (self.kernel 33 | .replace("_PM_", str(self.patch_m)) 34 | .replace("_PN_", str(self.patch_n)) 35 | .replace("__DISTANCE_FN__", dist_fn) 36 | ) 37 | 38 | self._fn_tt = cp.RawKernel( 39 | code=self.kernel, 40 | name="topk_bmm_tt", 41 | backend='nvcc', 42 | options=('--maxrregcount=128', '--use_fast_math') 43 | ) 44 | self._fn_nn = cp.RawKernel( 45 | code=self.kernel, 46 | name="topk_bmm_nn", 47 | backend='nvcc', 48 | options=( 49 | '--maxrregcount=128', 50 | '--use_fast_math', 51 | #'-Xptxas', 52 | #'-dlcm=cg', 53 | ) 54 | ) 55 | # print(self._fn_nn.attributes) 56 | self._fn_tn = cp.RawKernel( 57 | code=self.kernel, 58 | name="topk_bmm_tn", 59 | backend='nvcc', 60 | options=('--maxrregcount=128', '--use_fast_math') 61 | ) 62 | self._fn_nt = cp.RawKernel( 63 | code=self.kernel, 64 | name="topk_bmm_nt", 65 | backend='nvcc', 66 | options=('--maxrregcount=128', '--use_fast_math') 67 | ) 68 | 69 | def get_mode(self, A, B): 70 | mode = [None, None] 71 | if A.stride()[-1] == 1: 72 | mode[0] = "n" 73 | elif A.stride()[-2] == 1: 74 | mode[0] = "t" 75 | if B.stride()[-1] == 1: 76 | mode[1] = "n" 77 | elif B.stride()[-2] == 1: 78 | mode[1] = "t" 79 | return "".join(mode) 80 | 81 | def __call__(self, A, B, k=128, dim=1): 82 | """ 83 | Performs C = min(f(A) @ g(B)), argmin(f(A) @ g(B)) 84 | A: torch.Tensor, shape : [l, m, k] 85 | B: torch.Tensor, shape : [l, k, n] 86 | returns C: torch.Tensor, shape : [l, m, n] 87 | """ 88 | assert len(A.shape) == len(B.shape) 89 | if len(A.shape) == 2 and len(B.shape) == 2: 90 | A = A[None] 91 | B = B[None] 92 | two_dimentional = True 93 | dim += 1 94 | elif len(A.shape) == 3 and len(B.shape) == 3: 95 | two_dimentional = False 96 | else: 97 | raise ValueError("shape of A and B need to be 2d or 3d") 98 | assert A.shape[0] == B.shape[0] 99 | assert A.shape[2] == B.shape[1] 100 | assert A.dtype == B.dtype 101 | assert A.dtype in [torch.float, torch.half] 102 | assert A.device.type == B.device.type == "cuda" 103 | assert dim in [1, 2] 104 | assert 0 < k <= 128 105 | 106 | mode = self.get_mode(A, B) 107 | if mode == "nn": 108 | kernel_fn = self._fn_nn 109 | elif mode == "nt": 110 | kernel_fn = self._fn_nt 111 | elif mode == "tn": 112 | kernel_fn = self._fn_tn 113 | elif mode == "tt": 114 | kernel_fn = self._fn_tt 115 | 116 | l, m, d = A.shape 117 | l, d, n = B.shape 118 | 119 | if dim == 1: 120 | values = torch.empty([l, n, 128], device="cuda:0", dtype=A.dtype) 121 | indices = torch.empty([l, n, 128], device="cuda:0", dtype=torch.int64) 122 | mutex = torch.zeros([l, n], device="cuda:0", dtype=torch.int32) 123 | elif dim == 2: 124 | values = torch.empty([l, m, 128], device="cuda:0", dtype=A.dtype) 125 | indices = torch.empty([l, m, 128], device="cuda:0", dtype=torch.int64) 126 | mutex = torch.zeros([l, m], device="cuda:0", dtype=torch.int32) 127 | values.fill_(float("-inf")) 128 | 129 | threads_per_block = (256,) 130 | #blocks_per_grid = (math.ceil(n/128), math.ceil(m/128), l) 131 | 132 | n_ = math.ceil(n/(128*self.patch_n)) 133 | m_ = math.ceil(m/(128*self.patch_m)) 134 | blocks_per_grid = (self.patch_n*self.patch_m, n_ * m_, l) 135 | # print(blocks_per_grid, m_, n_) 136 | 137 | kernel_fn( 138 | grid=blocks_per_grid, 139 | block=threads_per_block, 140 
81 | def __call__(self, A, B, k=128, dim=1): 82 | """ 83 | Performs values, indices = topk(f(A) @ g(B), k) 84 | A: torch.Tensor, shape : [l, m, k] 85 | B: torch.Tensor, shape : [l, k, n] 86 | returns (values, indices): torch.Tensor, shape : [l, n, k] if dim == 1, [l, m, k] if dim == 2 87 | """ 88 | assert len(A.shape) == len(B.shape) 89 | if len(A.shape) == 2 and len(B.shape) == 2: 90 | A = A[None] 91 | B = B[None] 92 | two_dimensional = True 93 | dim += 1 94 | elif len(A.shape) == 3 and len(B.shape) == 3: 95 | two_dimensional = False 96 | else: 97 | raise ValueError("shape of A and B need to be 2d or 3d") 98 | assert A.shape[0] == B.shape[0] 99 | assert A.shape[2] == B.shape[1] 100 | assert A.dtype == B.dtype 101 | assert A.dtype in [torch.float, torch.half] 102 | assert A.device.type == B.device.type == "cuda" 103 | assert dim in [1, 2] 104 | assert 0 < k <= 128 105 | 106 | mode = self.get_mode(A, B) 107 | if mode == "nn": 108 | kernel_fn = self._fn_nn 109 | elif mode == "nt": 110 | kernel_fn = self._fn_nt 111 | elif mode == "tn": 112 | kernel_fn = self._fn_tn 113 | elif mode == "tt": 114 | kernel_fn = self._fn_tt 115 | 116 | l, m, d = A.shape 117 | l, d, n = B.shape 118 | 119 | if dim == 1: 120 | values = torch.empty([l, n, 128], device="cuda:0", dtype=A.dtype) 121 | indices = torch.empty([l, n, 128], device="cuda:0", dtype=torch.int64) 122 | mutex = torch.zeros([l, n], device="cuda:0", dtype=torch.int32) 123 | elif dim == 2: 124 | values = torch.empty([l, m, 128], device="cuda:0", dtype=A.dtype) 125 | indices = torch.empty([l, m, 128], device="cuda:0", dtype=torch.int64) 126 | mutex = torch.zeros([l, m], device="cuda:0", dtype=torch.int32) 127 | values.fill_(float("-inf")) 128 | 129 | threads_per_block = (256,) 130 | #blocks_per_grid = (math.ceil(n/128), math.ceil(m/128), l) 131 | 132 | n_ = math.ceil(n/(128*self.patch_n)) 133 | m_ = math.ceil(m/(128*self.patch_m)) 134 | blocks_per_grid = (self.patch_n*self.patch_m, n_ * m_, l) 135 | # print(blocks_per_grid, m_, n_) 136 | 137 | kernel_fn( 138 | grid=blocks_per_grid, 139 | block=threads_per_block, 140 | args=[ 141 | A.data_ptr(), 142 | B.data_ptr(), 143 | values.data_ptr(), 144 | indices.data_ptr(), 145 | mutex.data_ptr(), 146 | m, n, d, dim, 128 147 | ], 148 | stream=self.stream 149 | ) 150 | indices = indices[:, :, :k] 151 | values = values[:, :, :k] 152 | 153 | if two_dimensional: 154 | indices = indices[0] 155 | values = values[0] 156 | 157 | return values, indices --------------------------------------------------------------------------------
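A usage sketch for TopkBMMCUDA (not part of the repo), again assuming a CUDA device and the repository root as working directory. Note the kernel always maintains a top-128 buffer (Q = 128) and the wrapper slices it down to k; results come back sorted descending, like torch.topk, but with the k axis last.

import torch
from topkbmm import TopkBMMCUDA

topk = TopkBMMCUDA(patch_m=4, patch_n=4, distance="inner")
A = torch.randn(1, 2048, 64, device="cuda")
B = torch.randn(1, 64, 4096, device="cuda")
values, indices = topk(A, B, k=32, dim=1)   # values, indices: [1, 4096, 32]
ref = (A @ B).topk(32, dim=1)               # torch puts the k axis at dim 1 here
print(torch.allclose(values, ref.values.transpose(1, 2), atol=1e-4))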