└── ThreadGroupTilingX.hlsl /ThreadGroupTilingX.hlsl: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | // Normally argument "dipatchGridDim" is passed through a constant buffer. However, if for some reason it is a 24 | // static value, some DXC compiler versions will be unable to compile the code. 25 | // If that's the case for you, change the DXC_STATIC_DISPATCH_GRID_DIM definition from 0 to 1.
// 0: dipatchGridDim is a run-time value (e.g. from a constant buffer) - no workaround needed.
// 1: dipatchGridDim is a compile-time constant - enable the DXC workaround below.
// NOTE: the value is tested with "#if", so setting it to 0 really disables the workaround
// (an "#ifdef" test would be true for both 0 and 1).
#define DXC_STATIC_DISPATCH_GRID_DIM 0

// Remaps the current thread group so that consecutively launched groups cover the dispatch
// grid tile by tile instead of row by row.
// The 2D dispatch grid is divided into vertical tiles of dimension [maxTileWidth, dipatchGridDim.y].
// Groups are enumerated row-major inside a tile, and tiles are enumerated left to right.
// If dipatchGridDim.x is not a multiple of maxTileWidth, the right-most tile is narrower:
// it is (dipatchGridDim.x % maxTileWidth) columns wide.
// "CTA" (Cooperative Thread Array) == Thread Group in DirectX terminology.
//
// Returns the swizzled 2D thread ID, i.e. ctaDim * swizzledGroupID + groupThreadID.
// Requires maxTileWidth > 0, dipatchGridDim.x > 0 and dipatchGridDim.y > 0, because all of them
// are used as divisors (directly or via Number_of_CTAs_in_a_perfect_tile).
uint2 ThreadGroupTilingX(
    const uint2 dipatchGridDim, // Arguments of the Dispatch call (typically from a ConstantBuffer)
    const uint2 ctaDim,         // Already known in HLSL, eg:[numthreads(8, 8, 1)] -> uint2(8, 8)
    const uint maxTileWidth,    // User parameter (N). Recommended values: 8, 16 or 32.
    const uint2 groupThreadID,  // SV_GroupThreadID
    const uint2 groupId         // SV_GroupID
)
{
    // A perfect tile is one with dimensions = [maxTileWidth, dipatchGridDim.y]
    const uint Number_of_CTAs_in_a_perfect_tile = maxTileWidth * dipatchGridDim.y;

    // Possible number of perfect tiles
    const uint Number_of_perfect_tiles = dipatchGridDim.x / maxTileWidth;

    // Total number of CTAs present in the perfect tiles
    const uint Total_CTAs_in_all_perfect_tiles = Number_of_perfect_tiles * Number_of_CTAs_in_a_perfect_tile;

    // Row-major linear index of the current CTA in the original (unswizzled) grid.
    const uint vThreadGroupIDFlattened = dipatchGridDim.x * groupId.y + groupId.x;

    // Tile_ID_of_current_CTA : current CTA to TILE-ID mapping.
    const uint Tile_ID_of_current_CTA = vThreadGroupIDFlattened / Number_of_CTAs_in_a_perfect_tile;
    const uint Local_CTA_ID_within_current_tile = vThreadGroupIDFlattened % Number_of_CTAs_in_a_perfect_tile;
    uint Local_CTA_ID_y_within_current_tile;
    uint Local_CTA_ID_x_within_current_tile;

    if (Total_CTAs_in_all_perfect_tiles <= vThreadGroupIDFlattened)
    {
        // Path taken only if the last tile has imperfect dimensions and CTAs from the last tile are launched.
        // At run time X_dimension_of_last_tile is therefore never 0: if dipatchGridDim.x were a
        // multiple of maxTileWidth, every CTA would lie in a perfect tile and this branch is skipped.
        uint X_dimension_of_last_tile = dipatchGridDim.x % maxTileWidth;
#if DXC_STATIC_DISPATCH_GRID_DIM
        // With a compile-time-constant grid, the divisor below may constant-fold to 0 in this
        // (dead) branch, which presumably is what trips some DXC versions. Clamp to 1 so it compiles.
        X_dimension_of_last_tile = max(1, X_dimension_of_last_tile);
#endif
        Local_CTA_ID_y_within_current_tile = Local_CTA_ID_within_current_tile / X_dimension_of_last_tile;
        Local_CTA_ID_x_within_current_tile = Local_CTA_ID_within_current_tile % X_dimension_of_last_tile;
    }
    else
    {
        Local_CTA_ID_y_within_current_tile = Local_CTA_ID_within_current_tile / maxTileWidth;
        Local_CTA_ID_x_within_current_tile = Local_CTA_ID_within_current_tile % maxTileWidth;
    }

    // Back to a row-major linear index in the full grid:
    // the tile's left column + the row offset + the column within the tile.
    const uint Swizzled_vThreadGroupIDFlattened =
        Tile_ID_of_current_CTA * maxTileWidth +
        Local_CTA_ID_y_within_current_tile * dipatchGridDim.x +
        Local_CTA_ID_x_within_current_tile;

    uint2 SwizzledvThreadGroupID;
    SwizzledvThreadGroupID.y = Swizzled_vThreadGroupIDFlattened / dipatchGridDim.x;
    SwizzledvThreadGroupID.x = Swizzled_vThreadGroupIDFlattened % dipatchGridDim.x;

    // Expand the swizzled group ID into a per-thread ID (same formula as SV_DispatchThreadID).
    uint2 SwizzledvThreadID;
    SwizzledvThreadID.x = ctaDim.x * SwizzledvThreadGroupID.x + groupThreadID.x;
    SwizzledvThreadID.y = ctaDim.y * SwizzledvThreadGroupID.y + groupThreadID.y;

    return SwizzledvThreadID;
}