├── LICENSE ├── README.md ├── setup.py └── src ├── ctlib.cpp ├── fan_ea_cuda.h ├── fan_ea_kernel.cu ├── fan_ed_cuda.h ├── fan_ed_kernel.cu ├── laplacian_cuda.h ├── laplacian_cuda_kernel.cu ├── para_cuda.h └── para_kernel.cu /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 xwj01 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CTLIB 2 | A library of CT projectors and backprojectors based on PyTorch 3 | 4 | Implemented with the distance-driven method [1, 2] 5 | 6 | If you use the code, please cite our work: 7 | ``` 8 | @article{xia2021magic, 9 | title={MAGIC: Manifold and Graph Integrative Convolutional Network for Low-Dose CT Reconstruction}, 10 | author={Xia, Wenjun and Lu, Zexin and Huang, Yongqiang and Shi, Zuoqiang and Liu, Yan and Chen, Hu and Chen, Yang and Zhou, Jiliu and Zhang, Yi}, 11 | journal={IEEE Transactions on Medical Imaging}, 12 | year={2021}, 13 | publisher={IEEE} 14 | } 15 | ``` 16 | ## Installation 17 | The following are step-by-step instructions for installing this library with conda. 18 | 19 | Create a conda environment 20 | ``` 21 | conda create -n ctlib 22 | ``` 23 | Activate the environment and install PyTorch 24 | ``` 25 | conda activate ctlib 26 | conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia 27 | ``` 28 | Note: the cudatoolkit installed from the nvidia channel usually provides the compiler. However, if you installed an earlier PyTorch version with a cudatoolkit that is not from the nvidia channel, that cudatoolkit may not include the compiler, and the installation below will then fail with an error that nvcc cannot be found.
If so, install cudatoolkit-dev as follows: 29 | ``` 30 | conda install cudatoolkit-dev -c conda-forge 31 | ``` 32 | Move to the directory of this library, then compile and install 33 | ``` 34 | python setup.py install 35 | ``` 36 | 37 | ## API 38 | ``projection(image, options)``: CT projector 39 | 40 | ``projection_t(projection, options)``: Transpose of the projector 41 | 42 | ``backprojection_t(image, options)``: Transpose of the backprojector 43 | 44 | ``backprojection(projection, options)``: CT backprojector 45 | 46 | ``fbp(projection, options)``: Filtered backprojection (FBP) with the Ram-Lak (RL) filter 47 | 48 | ``laplacian(input, k)``: Computation of the adjacency matrix in [3] 49 | 50 | ``image``: 4D torch tensor, B x 1 x H x W 51 | 52 | ``projection``: 4D torch tensor, B x 1 x V x D, where V is the total number of scanning views and D is the total number of detector bins 53 | 54 | ``options``: 12-element torch vector for fan beam and 9-element torch vector for parallel beam, holding the scanning geometry parameters: 55 | 56 | ``views``: Number of scanning views 57 | 58 | ``dets``: Number of detector bins 59 | 60 | ``width`` and ``height``: Spatial resolution of images 61 | 62 | ``dImg``: Physical length of a pixel 63 | 64 | ``dDet``: Interval between two adjacent detector bins; for equal-angle fan beam this interval is in radians (``rad``) 65 | 66 | ``Ang0``: Starting angle 67 | 68 | ``dAng``: Interval between two adjacent scanning views, in radians (``rad``) 69 | 70 | ``s2r``: Distance between the x-ray source and the rotation center; not needed for parallel beam 71 | 72 | ``d2r``: Distance between the detector and the rotation center; not needed for parallel beam 73 | 74 | ``binshift``: Shift of the detector 75 | 76 | ``scan_type``: ``0`` for equal-distance fan beam, ``1`` for equal-angle fan beam, and ``2`` for parallel beam 77 | 78 | [1] B. De Man and S. Basu, "Distance-driven projection and backprojection," 79 | in IEEE Nucl. Sci. Symp. Conf. Record, vol. 3, 2002, pp. 1477-1480. 80 | 81 | [2] B. De Man and S. Basu, "Distance-driven projection and backprojection in three dimensions," 82 | Phys. Med. Biol., vol. 49, no. 11, p. 2463, 2004. 83 | 84 | [3] W. Xia, Z. Lu, Y. Huang, Z. Shi, Y. Liu, H. Chen, Y. Chen, J. Zhou, and Y. Zhang, "MAGIC: Manifold and graph integrative convolutional network for low-dose CT reconstruction," IEEE Trans. Med. Imaging, 2021.
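## Example

A minimal usage sketch (not part of the original repo): it builds the fan-beam ``options`` vector in the order listed above and runs the projector and FBP. The geometry numbers are illustrative placeholders; substitute your own scanner parameters.
```python
import math
import torch
import ctlib

views, dets = 512, 736   # illustrative values
# [views, dets, width, height, dImg, dDet, Ang0, dAng, s2r, d2r, binshift, scan_type]
options = torch.tensor([views, dets, 256, 256,      # geometry sizes
                        0.0066, 0.0013,             # pixel size, detector bin interval
                        0.0, 2 * math.pi / views,   # starting angle, angular step (rad)
                        5.95, 4.91, 0.0,            # s2r, d2r, binshift (placeholders)
                        0],                         # scan_type 0: equal-distance fan beam
                       dtype=torch.float32).cuda()  # inputs must be contiguous CUDA tensors

image = torch.randn(1, 1, 256, 256, dtype=torch.float32).cuda()  # B x 1 x H x W
sino = ctlib.projection(image, options)                          # B x 1 x V x D
recon = ctlib.fbp(sino, options)                                 # FBP with the RL filter
```
Since ``projection_t`` and ``backprojection_t`` expose the exact transposes, the operator pair can be wrapped in a ``torch.autograd.Function`` to make the projector differentiable in training pipelines.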
85 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='ctlib', 6 | version='0.2.0', 7 | author='Wenjun Xia', 8 | ext_modules=[ 9 | CUDAExtension('ctlib', [ 10 | 'src/ctlib.cpp', 11 | 'src/fan_ed_kernel.cu', 12 | 'src/fan_ea_kernel.cu', 13 | 'src/para_kernel.cu', 14 | 'src/laplacian_cuda_kernel.cu', 15 | ], 16 | ), 17 | ], 18 | cmdclass={ 19 | 'build_ext': BuildExtension 20 | }) 21 | -------------------------------------------------------------------------------- /src/ctlib.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "fan_ed_cuda.h" 3 | #include "fan_ea_cuda.h" 4 | #include "para_cuda.h" 5 | #include "laplacian_cuda.h" 6 | 7 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | #define Fan_Equal_Distance 0 12 | #define Fan_Equal_Angle 1 13 | #define Para 2 14 | 15 | torch::Tensor backprojection_t(torch::Tensor image, torch::Tensor options) { 16 | CHECK_INPUT(image); 17 | CHECK_INPUT(options); 18 | int scan_type = options[static_cast<int>(options.size(0))-1].item<int>(); 19 | if (scan_type == Fan_Equal_Distance){ 20 | return bprj_t_fan_ed_cuda(image, options); 21 | } else if (scan_type == Fan_Equal_Angle) { 22 | return bprj_t_fan_ea_cuda(image, options); 23 | } else if (scan_type == Para) { 24 | return bprj_t_para_cuda(image, options); 25 | } else { 26 | exit(0); 27 | } 28 | } 29 | 30 | torch::Tensor backprojection(torch::Tensor projection, torch::Tensor options) { 31 | CHECK_INPUT(projection); 32 | CHECK_INPUT(options); 33 | int scan_type = options[static_cast<int>(options.size(0))-1].item<int>(); 34 | if (scan_type == Fan_Equal_Distance){ 35 | return bprj_fan_ed_cuda(projection, options); 36 | } else if (scan_type == Fan_Equal_Angle) { 37 | return bprj_fan_ea_cuda(projection, options); 38 | } else if (scan_type == Para) { 39 | return bprj_para_cuda(projection, options); 40 | } else { 41 | exit(0); 42 | } 43 | } 44 | 45 | torch::Tensor backprojection_sv(torch::Tensor projection, torch::Tensor options) { 46 | CHECK_INPUT(projection); 47 | CHECK_INPUT(options); 48 | int scan_type = options[static_cast<int>(options.size(0))-1].item<int>(); 49 | if (scan_type == Fan_Equal_Distance){ 50 | return bprj_sv_fan_ed_cuda(projection, options); 51 | } else { 52 | exit(0); 53 | } 54 | } 55 | 56 | torch::Tensor projection(torch::Tensor image, torch::Tensor options) { 57 | CHECK_INPUT(image); 58 | CHECK_INPUT(options); 59 | int scan_type = options[static_cast<int>(options.size(0))-1].item<int>(); 60 | if (scan_type == Fan_Equal_Distance){ 61 | return prj_fan_ed_cuda(image, options); 62 | } else if (scan_type == Fan_Equal_Angle) { 63 | return prj_fan_ea_cuda(image, options); 64 | } else if (scan_type == Para) { 65 | return prj_para_cuda(image, options); 66 | } else { 67 | exit(0); 68 | } 69 | } 70 | 71 | torch::Tensor projection_t(torch::Tensor projection, torch::Tensor options) { 72 | CHECK_INPUT(projection); 73 | CHECK_INPUT(options); 74 | int scan_type = options[static_cast<int>(options.size(0))-1].item<int>(); 75 | if (scan_type == Fan_Equal_Distance){ 76 | return prj_t_fan_ed_cuda(projection, options); 77 | } else if (scan_type == Fan_Equal_Angle) { 
78 | return prj_t_fan_ea_cuda(projection, options); 79 | } else if (scan_type == Para) { 80 | return prj_t_para_cuda(projection, options); 81 | } else { 82 | exit(0); 83 | } 84 | } 85 | 86 | torch::Tensor fbp(torch::Tensor projection, torch::Tensor options) { 87 | CHECK_INPUT(projection); 88 | CHECK_INPUT(options); 89 | int scan_type = options[static_cast<int>(options.size(0))-1].item<int>(); 90 | if (scan_type == Fan_Equal_Distance){ 91 | return fbp_fan_ed_cuda(projection, options); 92 | } else if (scan_type == Fan_Equal_Angle) { 93 | return fbp_fan_ea_cuda(projection, options); 94 | } else if (scan_type == Para) { 95 | return fbp_para_cuda(projection, options); 96 | } else { 97 | exit(0); 98 | } 99 | } 100 | 101 | torch::Tensor laplacian(torch::Tensor input, int k) { 102 | CHECK_INPUT(input); 103 | return laplacian_cuda_forward(input, k); 104 | } 105 | 106 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 107 | m.def("projection", &projection, "CT projection (CUDA)"); 108 | m.def("projection_t", &projection_t, "Transpose of CT projection (CUDA)"); 109 | m.def("backprojection_t", &backprojection_t, "Transpose of backprojection (CUDA)"); 110 | m.def("backprojection", &backprojection, "CT backprojection (CUDA)"); 111 | m.def("backprojection_sv", &backprojection_sv, "CT backprojection single view (CUDA)"); 112 | m.def("fbp", &fbp, "CT filtered backprojection with RL filter (CUDA)"); 113 | m.def("laplacian", &laplacian, "Graph Laplacian computation (CUDA)"); 114 | } 115 | -------------------------------------------------------------------------------- /src/fan_ea_cuda.h: -------------------------------------------------------------------------------- 1 | #ifndef FAN_EA_CUDA_H 2 | #define FAN_EA_CUDA_H 3 | 4 | #include <torch/extension.h> 5 | 6 | torch::Tensor prj_fan_ea_cuda(torch::Tensor image, torch::Tensor options); 7 | torch::Tensor prj_t_fan_ea_cuda(torch::Tensor projection, torch::Tensor options); 8 | torch::Tensor bprj_t_fan_ea_cuda(torch::Tensor image, torch::Tensor options); 9 | torch::Tensor bprj_fan_ea_cuda(torch::Tensor image, torch::Tensor options); 10 | torch::Tensor fbp_fan_ea_cuda(torch::Tensor projection, torch::Tensor options); 11 | 12 | #endif -------------------------------------------------------------------------------- /src/fan_ea_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <cuda.h> 3 | 4 | #define BLOCK_DIM 256 5 | #define GRID_DIM 512 6 | 7 | template <typename scalar_t> 8 | __device__ scalar_t map_x(scalar_t sourcex, scalar_t sourcey, scalar_t detx, scalar_t dety) { 9 | return (sourcex * dety - sourcey * detx) / (dety - sourcey); 10 | } 11 | 12 | template <typename scalar_t> 13 | __device__ scalar_t map_y(scalar_t sourcex, scalar_t sourcey, scalar_t detx, scalar_t dety) { 14 | return (sourcey * detx - sourcex * dety) / (detx - sourcex); 15 | } 16 | 17 | template <typename scalar_t> 18 | __device__ scalar_t cweight(scalar_t sourcex, scalar_t sourcey, scalar_t detx, scalar_t dety) { 19 | return (sourcex - detx) * (sourcex - detx) + (sourcey - dety) * (sourcey - dety); 20 | } 21 | 22 | template <typename scalar_t> 23 | __global__ void prj_fan_ea( 24 | const torch::PackedTensorAccessor<scalar_t,4> image, 25 | torch::PackedTensorAccessor<scalar_t,4> projection, 26 | const scalar_t* __restrict__ views, const scalar_t* __restrict__ dets, 27 | const scalar_t* __restrict__ width, const scalar_t* __restrict__ height, 28 | const scalar_t* __restrict__ dImg, const scalar_t* __restrict__ dDet, 29 | const scalar_t* __restrict__ Ang0, const scalar_t* __restrict__ dAng, 30 | const scalar_t* __restrict__ s2r, const scalar_t* __restrict__ d2r, 31 | const 
scalar_t* __restrict__ binshift) { 32 | 33 | __shared__ unsigned int nblocks; 34 | __shared__ unsigned int idxchannel; 35 | __shared__ unsigned int idxview; 36 | nblocks = ceil(*views / gridDim.y); 37 | idxchannel = blockIdx.x % nblocks; 38 | idxview = idxchannel * gridDim.y + blockIdx.y; 39 | if (idxview >= *views) return; 40 | idxchannel = blockIdx.x / nblocks; 41 | __shared__ scalar_t prj[BLOCK_DIM]; 42 | __shared__ scalar_t dPoint[BLOCK_DIM]; 43 | __shared__ scalar_t coef[BLOCK_DIM]; 44 | __shared__ scalar_t dImage; 45 | __shared__ scalar_t sourcex; 46 | __shared__ scalar_t sourcey; 47 | __shared__ scalar_t dPoint0; 48 | __shared__ double ang; 49 | __shared__ double beta0; 50 | __shared__ double PI; 51 | __shared__ double ang_error; 52 | __shared__ double cosval; 53 | __shared__ double sinval; 54 | __shared__ unsigned int dIndex0; 55 | 56 | PI = acos(-1.0); 57 | ang = idxview * *dAng + *Ang0; 58 | dImage = *dImg; 59 | ang_error = abs(ang - round(ang / PI) * PI) * 4 / PI; 60 | cosval = cos(ang); 61 | sinval = sin(ang); 62 | sourcex = - sinval * *s2r; 63 | sourcey = cosval * *s2r; 64 | dIndex0 = blockIdx.z * blockDim.x; 65 | unsigned int tx = threadIdx.x; 66 | unsigned int dIndex = dIndex0 + tx; 67 | prj[tx] = 0; 68 | __syncthreads(); 69 | if (ang_error <= 1) { 70 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 71 | if (ang_error >= 3 && ang_error < 7) { 72 | beta0 = (*dets / 2 - dIndex0) * *dDet + *binshift + ang; 73 | dPoint0 = sourcex + sin(beta0) / cos(beta0) * sourcey; 74 | if (dIndex < *dets) { 75 | double beta = (*dets / 2 - dIndex - 1) * *dDet + *binshift + ang; 76 | dPoint[tx] = sourcex + sin(beta) / cos(beta) * sourcey; 77 | } 78 | } else { 79 | beta0 = (dIndex0 - *dets / 2) * *dDet + *binshift + ang; 80 | dPoint0 = sourcex + sin(beta0) / cos(beta0) * sourcey; 81 | if (dIndex < *dets) { 82 | double beta = (dIndex + 1 - *dets / 2) * *dDet + *binshift + ang; 83 | dPoint[tx] = sourcex + sin(beta) / cos(beta) * sourcey; 84 | } 85 | } 86 | __syncthreads(); 87 | if (tx == 0){ 88 | coef[tx] = dPoint[tx] - dPoint0; 89 | } else { 90 | coef[tx] = dPoint[tx] - dPoint[tx-1]; 91 | } 92 | __syncthreads(); 93 | for (int i = 0; i < ceil(*height / blockDim.x); i++){ 94 | int idxrow = i * blockDim.x + tx; 95 | if (idxrow < *height) { 96 | scalar_t i0y = (*height / 2 - idxrow - 0.5) * dImage; 97 | scalar_t i0x = - *width / 2 * dImage; 98 | int idx0col = floor(((i0y - sourcey) / (- sourcey) * 99 | (dPoint0 - sourcex) + sourcex - i0x) / dImage); 100 | idx0col = max(idx0col, 0); 101 | i0x += idx0col * dImage; 102 | scalar_t threadprj = 0; 103 | scalar_t prebound = map_x(sourcex, sourcey, i0x, i0y); 104 | prebound = max(prebound, dPoint0); 105 | i0x += dImage; 106 | scalar_t pixbound = map_x(sourcex, sourcey, i0x, i0y); 107 | scalar_t detbound = dPoint[0]; 108 | int idxd = 0, idxi = idx0col; 109 | while (idxi < *width && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 110 | if (detbound <= prebound) { 111 | idxd ++; 112 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 113 | }else if (pixbound <= prebound){ 114 | idxi ++; 115 | i0x += dImage; 116 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 117 | }else if (pixbound < detbound) { 118 | threadprj += (pixbound - prebound) * image[idxchannel][0][idxrow][idxi] / coef[idxd]; 119 | prebound = pixbound; 120 | idxi ++; 121 | i0x += dImage; 122 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 123 | } else { 124 | threadprj += (detbound - prebound) * image[idxchannel][0][idxrow][idxi] / coef[idxd]; 125 | prebound = detbound; 126 | 
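// detector-cell boundary reached: flush this thread's partial overlap sum into the shared per-bin accumulator (several image rows contribute to the same bin, hence atomicAdd)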
atomicAdd(prj+idxd, threadprj); 127 | threadprj = 0; 128 | idxd ++; 129 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 130 | } 131 | } 132 | if (threadprj != 0) atomicAdd(prj+idxd, threadprj); 133 | } 134 | } 135 | __syncthreads(); 136 | dPoint0 = abs(sourcey) / sqrt((dPoint0 - sourcex) * (dPoint0 - sourcex) + sourcey * sourcey); 137 | if (dIndex < *dets) { 138 | dPoint[tx] = abs(sourcey) / sqrt((dPoint[tx] - sourcex) * (dPoint[tx] - sourcex) + sourcey * sourcey); 139 | __syncthreads(); 140 | if (tx == 0){ 141 | coef[tx] = (dPoint[tx] + dPoint0) / 2; 142 | } else { 143 | coef[tx] = (dPoint[tx] + dPoint[tx-1]) / 2; 144 | } 145 | __syncthreads(); 146 | prj[tx] *= dImage; 147 | prj[tx] /= coef[tx]; 148 | if (ang_error >= 3 && ang_error < 7) { 149 | projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex] = prj[tx]; 150 | } else { 151 | projection[idxchannel][0][idxview][dIndex] = prj[tx]; 152 | } 153 | } 154 | } else { 155 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 156 | if (ang_error >= 3 && ang_error < 7) { 157 | beta0 = (*dets / 2 - dIndex0) * *dDet + *binshift + ang; 158 | dPoint0 = sourcey + cos(beta0) / sin(beta0) * sourcex; 159 | if (dIndex < *dets) { 160 | double beta = (*dets / 2 - dIndex - 1) * *dDet + *binshift + ang; 161 | dPoint[tx] = sourcey + cos(beta) / sin(beta) * sourcex; 162 | } 163 | } else { 164 | beta0 = (dIndex0 - *dets / 2) * *dDet + *binshift + ang; 165 | dPoint0 = sourcey + cos(beta0) / sin(beta0) * sourcex; 166 | if (dIndex < *dets) { 167 | double beta = (dIndex + 1 - *dets / 2) * *dDet + *binshift + ang; 168 | dPoint[tx] = sourcey + cos(beta) / sin(beta) * sourcex; 169 | } 170 | } 171 | __syncthreads(); 172 | if (tx == 0){ 173 | coef[tx] = dPoint[tx] - dPoint0; 174 | } else { 175 | coef[tx] = dPoint[tx] - dPoint[tx-1]; 176 | } 177 | __syncthreads(); 178 | for (int i = 0; i < ceil(*width / blockDim.x); i++){ 179 | int idxcol = i * blockDim.x + tx; 180 | if (idxcol < *width) { 181 | scalar_t i0x = (idxcol - *width / 2 + 0.5) * dImage; 182 | scalar_t i0y = - *height / 2 * dImage; 183 | int idx0row = floor(((i0x - sourcex) / (- sourcex) * 184 | (dPoint0 - sourcey) + sourcey - i0y) / dImage); 185 | idx0row = max(idx0row, 0); 186 | i0y += idx0row * dImage; 187 | scalar_t threadprj = 0; 188 | scalar_t prebound = map_y(sourcex, sourcey, i0x, i0y); 189 | prebound = max(prebound, dPoint0); 190 | i0y += dImage; 191 | scalar_t pixbound = map_y(sourcex, sourcey, i0x, i0y); 192 | scalar_t detbound = dPoint[0]; 193 | int idxd = 0, idxi = idx0row; 194 | while (idxi < *height && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 195 | if (detbound <= prebound) { 196 | idxd ++; 197 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 198 | }else if (pixbound <= prebound) { 199 | idxi ++; 200 | i0y += dImage; 201 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 202 | }else if (pixbound < detbound) { 203 | threadprj += (pixbound - prebound) * image[idxchannel][0][static_cast(*height)-1-idxi][idxcol] / coef[idxd]; 204 | prebound = pixbound; 205 | idxi ++; 206 | i0y += dImage; 207 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 208 | } else { 209 | threadprj += (detbound - prebound) * image[idxchannel][0][static_cast(*height)-1-idxi][idxcol] / coef[idxd]; 210 | prebound = detbound; 211 | atomicAdd(prj+idxd, threadprj); 212 | threadprj = 0; 213 | idxd ++; 214 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 215 | } 216 | } 217 | if (threadprj != 0) atomicAdd(prj+idxd, threadprj); 218 | } 219 | } 220 | __syncthreads(); 221 | dPoint0 = abs(sourcex) / sqrt((dPoint0 - 
sourcey) * (dPoint0 - sourcey) + sourcex * sourcex); 222 | if (dIndex < *dets) { 223 | dPoint[tx] = abs(sourcex) / sqrt((dPoint[tx] - sourcey) * (dPoint[tx] - sourcey) + sourcex * sourcex); 224 | __syncthreads(); 225 | if (tx == 0){ 226 | coef[tx] = (dPoint[tx] + dPoint0) / 2; 227 | } else { 228 | coef[tx] = (dPoint[tx] + dPoint[tx-1]) / 2; 229 | } 230 | __syncthreads(); 231 | prj[tx] *= dImage; 232 | prj[tx] /= coef[tx]; 233 | if (ang_error >= 3 && ang_error < 7) { 234 | projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex] = prj[tx]; 235 | } else { 236 | projection[idxchannel][0][idxview][dIndex] = prj[tx]; 237 | } 238 | } 239 | } 240 | } 241 | 242 | template 243 | __global__ void prj_t_fan_ea( 244 | torch::PackedTensorAccessor image, 245 | const torch::PackedTensorAccessor projection, 246 | const scalar_t* __restrict__ views, const scalar_t* __restrict__ dets, 247 | const scalar_t* __restrict__ width, const scalar_t* __restrict__ height, 248 | const scalar_t* __restrict__ dImg, const scalar_t* __restrict__ dDet, 249 | const scalar_t* __restrict__ Ang0, const scalar_t* __restrict__ dAng, 250 | const scalar_t* __restrict__ s2r, const scalar_t* __restrict__ d2r, 251 | const scalar_t* __restrict__ binshift) { 252 | 253 | __shared__ unsigned int nblocks; 254 | __shared__ unsigned int idxchannel; 255 | __shared__ unsigned int idxview; 256 | nblocks = ceil(*views / gridDim.y); 257 | idxchannel = blockIdx.x % nblocks; 258 | idxview = idxchannel * gridDim.y + blockIdx.y; 259 | if (idxview >= *views) return; 260 | idxchannel = blockIdx.x / nblocks; 261 | __shared__ scalar_t prj[BLOCK_DIM]; 262 | __shared__ scalar_t dPoint[BLOCK_DIM]; 263 | __shared__ scalar_t coef[BLOCK_DIM]; 264 | __shared__ scalar_t dImage; 265 | __shared__ scalar_t sourcex; 266 | __shared__ scalar_t sourcey; 267 | __shared__ scalar_t dPoint0; 268 | __shared__ double ang; 269 | __shared__ double beta0; 270 | __shared__ double PI; 271 | __shared__ double ang_error; 272 | __shared__ double cosval; 273 | __shared__ double sinval; 274 | __shared__ unsigned int dIndex0; 275 | 276 | PI = acos(-1.0); 277 | ang = idxview * *dAng + *Ang0; 278 | ang_error = abs(ang - round(ang / PI) * PI) * 4 / PI; 279 | cosval = cos(ang); 280 | sinval = sin(ang); 281 | sourcex = - sinval * *s2r; 282 | sourcey = cosval * *s2r; 283 | dIndex0 = blockIdx.z * blockDim.x; 284 | unsigned int tx = threadIdx.x; 285 | unsigned int dIndex = dIndex0 + tx; 286 | __syncthreads(); 287 | if (ang_error <= 1) { 288 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 289 | if (ang_error >= 3 && ang_error < 7) { 290 | beta0 = (*dets / 2 - dIndex0) * *dDet + *binshift + ang; 291 | dPoint0 = sourcex + sin(beta0) / cos(beta0) * sourcey; 292 | if (dIndex < *dets) { 293 | double beta = (*dets / 2 - dIndex - 1) * *dDet + *binshift + ang; 294 | dPoint[tx] = sourcex + sin(beta) / cos(beta) * sourcey; 295 | } 296 | } else { 297 | beta0 = (dIndex0 - *dets / 2) * *dDet + *binshift + ang; 298 | dPoint0 = sourcex + sin(beta0) / cos(beta0) * sourcey; 299 | if (dIndex < *dets) { 300 | double beta = (dIndex + 1 - *dets / 2) * *dDet + *binshift + ang; 301 | dPoint[tx] = sourcex + sin(beta) / cos(beta) * sourcey; 302 | } 303 | } 304 | __syncthreads(); 305 | dImage = abs(sourcey) / sqrt((dPoint0 - sourcex) * (dPoint0 - sourcex) + sourcey * sourcey); 306 | if (dIndex < *dets) { 307 | prj[tx] = abs(sourcey) / sqrt((dPoint[tx] - sourcex) * (dPoint[tx] - sourcex) + sourcey * sourcey); 308 | } else { 309 | prj[tx] = 0; 310 | } 311 | __syncthreads(); 312 | if (dIndex < *dets) { 
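// average the ray weights at the two edges of each detector cell to form the per-bin normalization coefficient used below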
313 | if (tx == 0){ 314 | coef[tx] = (prj[tx] + dImage) / 2; 315 | } else { 316 | coef[tx] = (prj[tx] + prj[tx-1]) / 2; 317 | } 318 | } 319 | __syncthreads(); 320 | if (dIndex < *dets) { 321 | if (ang_error >= 3 && ang_error < 7) { 322 | prj[tx] = projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex]; 323 | } else { 324 | prj[tx] = projection[idxchannel][0][idxview][dIndex]; 325 | } 326 | prj[tx] *= *dImg; 327 | prj[tx] /= coef[tx]; 328 | if (tx == 0){ 329 | coef[tx] = dPoint[tx] - dPoint0; 330 | } else { 331 | coef[tx] = dPoint[tx] - dPoint[tx-1]; 332 | } 333 | prj[tx] /= coef[tx]; 334 | } 335 | __syncthreads(); 336 | dImage = *dImg; 337 | for (int i = 0; i < ceil(*height / blockDim.x); i++){ 338 | int idxrow = i * blockDim.x + tx; 339 | if (idxrow < *height) { 340 | scalar_t i0y = (*height / 2 - idxrow - 0.5) * dImage; 341 | scalar_t i0x = - *width / 2 * dImage; 342 | int idx0col = floor(((i0y - sourcey) / (- sourcey) * 343 | (dPoint0- sourcex) + sourcex - i0x) / dImage); 344 | idx0col = max(idx0col, 0); 345 | i0x += idx0col * dImage; 346 | scalar_t threadprj = 0; 347 | scalar_t prebound = map_x(sourcex, sourcey, i0x, i0y); 348 | prebound = max(prebound, dPoint0); 349 | i0x += dImage; 350 | scalar_t pixbound = map_x(sourcex, sourcey, i0x, i0y); 351 | scalar_t detbound = dPoint[0]; 352 | int idxd = 0, idxi = idx0col; 353 | while (idxi < *width && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 354 | if (detbound <= prebound) { 355 | idxd ++; 356 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 357 | }else if (pixbound <= prebound){ 358 | idxi ++; 359 | i0x += dImage; 360 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 361 | }else if (pixbound <= detbound) { 362 | threadprj += (pixbound - prebound) * prj[idxd]; 363 | prebound = pixbound; 364 | atomicAdd(&(image[idxchannel][0][idxrow][idxi]), threadprj); 365 | threadprj = 0; 366 | idxi ++; 367 | i0x += dImage; 368 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 369 | } else { 370 | threadprj += (detbound - prebound) * prj[idxd]; 371 | prebound = detbound; 372 | idxd ++; 373 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 374 | } 375 | } 376 | if (threadprj !=0 ) atomicAdd(&(image[idxchannel][0][idxrow][idxi]), threadprj); 377 | } 378 | } 379 | } else { 380 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 381 | if (ang_error >= 3 && ang_error < 7) { 382 | beta0 = (*dets / 2 - dIndex0) * *dDet + *binshift + ang; 383 | dPoint0 = sourcey + cos(beta0) / sin(beta0) * sourcex; 384 | if (dIndex < *dets) { 385 | double beta = (*dets / 2 - dIndex - 1) * *dDet + *binshift + ang; 386 | dPoint[tx] = sourcey + cos(beta) / sin(beta) * sourcex; 387 | } 388 | } else { 389 | beta0 = (dIndex0 - *dets / 2) * *dDet + *binshift + ang; 390 | dPoint0 = sourcey + cos(beta0) / sin(beta0) * sourcex; 391 | if (dIndex < *dets) { 392 | double beta = (dIndex + 1 - *dets / 2) * *dDet + *binshift + ang; 393 | dPoint[tx] = sourcey + cos(beta) / sin(beta) * sourcex; 394 | } 395 | } 396 | __syncthreads(); 397 | dImage = abs(sourcex) / sqrt((dPoint0 - sourcey) * (dPoint0 - sourcey) + sourcex * sourcex); 398 | if (dIndex < *dets) { 399 | prj[tx] = abs(sourcex) / sqrt((dPoint[tx] - sourcey) * (dPoint[tx] - sourcey) + sourcex * sourcex); 400 | } else { 401 | prj[tx] = 0; 402 | } 403 | __syncthreads(); 404 | if (dIndex < *dets) { 405 | if (tx == 0){ 406 | coef[tx] = (prj[tx] + dImage) / 2; 407 | } else { 408 | coef[tx] = (prj[tx] + prj[tx-1]) / 2; 409 | } 410 | } 411 | __syncthreads(); 412 | if (dIndex < *dets) { 413 | if (ang_error >= 3 && ang_error < 7) 
{ 414 | prj[tx] = projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex]; 415 | } else { 416 | prj[tx] = projection[idxchannel][0][idxview][dIndex]; 417 | } 418 | prj[tx] *= *dImg; 419 | prj[tx] /= coef[tx]; 420 | if (tx == 0){ 421 | coef[tx] = dPoint[tx] - dPoint0; 422 | } else { 423 | coef[tx] = dPoint[tx] - dPoint[tx-1]; 424 | } 425 | prj[tx] /= coef[tx]; 426 | } 427 | __syncthreads(); 428 | dImage = *dImg; 429 | for (int i = 0; i < ceil(*width / blockDim.x); i++){ 430 | int idxcol = i * blockDim.x + tx; 431 | if (idxcol < *width) { 432 | scalar_t i0x = (idxcol - *width / 2 + 0.5) * dImage; 433 | scalar_t i0y = - *height / 2 * dImage; 434 | int idx0row = floor(((i0x - sourcex) / (- sourcex) * 435 | (dPoint0- sourcey) + sourcey - i0y) / dImage); 436 | idx0row = max(idx0row, 0); 437 | i0y += idx0row * dImage; 438 | scalar_t threadprj = 0; 439 | scalar_t prebound = map_y(sourcex, sourcey, i0x, i0y); 440 | prebound = max(prebound, dPoint0); 441 | i0y += dImage; 442 | scalar_t pixbound = map_y(sourcex, sourcey, i0x, i0y); 443 | scalar_t detbound = dPoint[0]; 444 | int idxd = 0, idxi = idx0row; 445 | while (idxi < *height && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 446 | if (detbound <= prebound) { 447 | idxd ++; 448 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 449 | }else if (pixbound <= prebound) { 450 | idxi ++; 451 | i0y += dImage; 452 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 453 | }else if (pixbound <= detbound) { 454 | threadprj += (pixbound - prebound) * prj[idxd]; 455 | prebound = pixbound; 456 | atomicAdd(&(image[idxchannel][0][static_cast(*height)-1-idxi][idxcol]), threadprj); 457 | threadprj = 0; 458 | idxi ++; 459 | i0y += dImage; 460 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 461 | } else { 462 | threadprj += (detbound - prebound) * prj[idxd]; 463 | prebound = detbound; 464 | idxd ++; 465 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 466 | } 467 | } 468 | if (threadprj !=0 ) atomicAdd(&(image[idxchannel][0][static_cast(*height)-1-idxi][idxcol]), threadprj); 469 | } 470 | } 471 | } 472 | } 473 | 474 | template 475 | __global__ void bprj_t_fan_ea( 476 | const torch::PackedTensorAccessor image, 477 | torch::PackedTensorAccessor projection, 478 | const scalar_t* __restrict__ views, const scalar_t* __restrict__ dets, 479 | const scalar_t* __restrict__ width, const scalar_t* __restrict__ height, 480 | const scalar_t* __restrict__ dImg, const scalar_t* __restrict__ dDet, 481 | const scalar_t* __restrict__ Ang0, const scalar_t* __restrict__ dAng, 482 | const scalar_t* __restrict__ s2r, const scalar_t* __restrict__ d2r, 483 | const scalar_t* __restrict__ binshift) { 484 | 485 | __shared__ unsigned int nblocks; 486 | __shared__ unsigned int idxchannel; 487 | __shared__ unsigned int idxview; 488 | nblocks = ceil(*views / gridDim.y); 489 | idxchannel = blockIdx.x % nblocks; 490 | idxview = idxchannel * gridDim.y + blockIdx.y; 491 | if (idxview >= *views) return; 492 | idxchannel = blockIdx.x / nblocks; 493 | __shared__ scalar_t prj[BLOCK_DIM]; 494 | __shared__ scalar_t dPoint[BLOCK_DIM]; 495 | __shared__ scalar_t coef[BLOCK_DIM]; 496 | __shared__ scalar_t dImage; 497 | __shared__ scalar_t sourcex; 498 | __shared__ scalar_t sourcey; 499 | __shared__ scalar_t dPoint0; 500 | __shared__ double ang; 501 | __shared__ double beta0; 502 | __shared__ double PI; 503 | __shared__ double ang_error; 504 | __shared__ double cosval; 505 | __shared__ double sinval; 506 | __shared__ unsigned int dIndex0; 507 | 508 | PI = acos(-1.0); 509 | ang = idxview * *dAng + *Ang0; 
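// per-view setup: the view angle fixes the source position; ang_error below selects whether boundaries are mapped onto the x-axis or the y-axis for the distance-driven sweep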
510 | dImage = *dImg; 511 | ang_error = abs(ang - round(ang / PI) * PI) * 4 / PI; 512 | cosval = cos(ang); 513 | sinval = sin(ang); 514 | sourcex = - sinval * *s2r; 515 | sourcey = cosval * *s2r; 516 | dIndex0 = blockIdx.z * blockDim.x; 517 | unsigned int tx = threadIdx.x; 518 | unsigned int dIndex = dIndex0 + tx; 519 | prj[tx] = 0; 520 | __syncthreads(); 521 | if (ang_error <= 1) { 522 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 523 | if (ang_error >= 3 && ang_error < 7) { 524 | beta0 = (*dets / 2 - dIndex0) * *dDet + *binshift + ang; 525 | dPoint0 = sourcex + sin(beta0) / cos(beta0) * sourcey; 526 | if (dIndex < *dets) { 527 | double beta = (*dets / 2 - dIndex - 1) * *dDet + *binshift + ang; 528 | dPoint[tx] = sourcex + sin(beta) / cos(beta) * sourcey; 529 | } 530 | } else { 531 | beta0 = (dIndex0 - *dets / 2) * *dDet + *binshift + ang; 532 | dPoint0 = sourcex + sin(beta0) / cos(beta0) * sourcey; 533 | if (dIndex < *dets) { 534 | double beta = (dIndex + 1 - *dets / 2) * *dDet + *binshift + ang; 535 | dPoint[tx] = sourcex + sin(beta) / cos(beta) * sourcey; 536 | } 537 | } 538 | __syncthreads(); 539 | if (tx == 0){ 540 | coef[tx] = dPoint[tx] - dPoint0; 541 | } else { 542 | coef[tx] = dPoint[tx] - dPoint[tx-1]; 543 | } 544 | __syncthreads(); 545 | for (int i = 0; i < ceil(*height / blockDim.x); i++){ 546 | int idxrow = i * blockDim.x + tx; 547 | if (idxrow < *height) { 548 | scalar_t i0y = (*height / 2 - idxrow - 0.5) * dImage; 549 | scalar_t i0x = - *width / 2 * dImage; 550 | int idx0col = floor(((i0y - sourcey) / (- sourcey) * 551 | (dPoint0 - sourcex) + sourcex - i0x) / dImage); 552 | idx0col = max(idx0col, 0); 553 | i0x += idx0col * dImage; 554 | scalar_t threadprj = 0; 555 | scalar_t prebound = map_x(sourcex, sourcey, i0x, i0y); 556 | scalar_t prepixbound = prebound; 557 | prebound = max(prebound, dPoint0); 558 | i0x += dImage; 559 | scalar_t pixbound = map_x(sourcex, sourcey, i0x, i0y); 560 | scalar_t detbound = dPoint[0]; 561 | int idxd = 0, idxi = idx0col; 562 | while (idxi < *width && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 563 | if (detbound <= prebound) { 564 | idxd ++; 565 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 566 | }else if (pixbound <= prebound){ 567 | idxi ++; 568 | i0x += dImage; 569 | prepixbound = pixbound; 570 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 571 | }else if (pixbound < detbound) { 572 | threadprj += (pixbound - prebound) * image[idxchannel][0][idxrow][idxi] / (pixbound - prepixbound) / cweight(sourcex, sourcey, i0x - dImage / 2, i0y); 573 | prebound = pixbound; 574 | idxi ++; 575 | i0x += dImage; 576 | prepixbound = pixbound; 577 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 578 | } else { 579 | threadprj += (detbound - prebound) * image[idxchannel][0][idxrow][idxi] / (pixbound - prepixbound) / cweight(sourcex, sourcey, i0x - dImage / 2, i0y); 580 | prebound = detbound; 581 | atomicAdd(prj+idxd, threadprj); 582 | threadprj = 0; 583 | idxd ++; 584 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 585 | } 586 | } 587 | if (threadprj != 0) atomicAdd(prj+idxd, threadprj); 588 | } 589 | } 590 | __syncthreads(); 591 | dPoint0 = abs(sourcey) / sqrt((dPoint0 - sourcex) * (dPoint0 - sourcex) + sourcey * sourcey); 592 | if (dIndex < *dets) { 593 | dPoint[tx] = abs(sourcey) / sqrt((dPoint[tx] - sourcex) * (dPoint[tx] - sourcex) + sourcey * sourcey); 594 | __syncthreads(); 595 | if (tx == 0){ 596 | coef[tx] = (dPoint[tx] + dPoint0) / 2; 597 | } else { 598 | coef[tx] = (dPoint[tx] + dPoint[tx-1]) / 2; 599 | } 600 | 
__syncthreads(); 601 | prj[tx] *= dImage; 602 | prj[tx] /= coef[tx]; 603 | if (ang_error >= 3 && ang_error < 7) { 604 | projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex] = prj[tx]; 605 | } else { 606 | projection[idxchannel][0][idxview][dIndex] = prj[tx]; 607 | } 608 | } 609 | } else { 610 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 611 | if (ang_error >= 3 && ang_error < 7) { 612 | beta0 = (*dets / 2 - dIndex0) * *dDet + *binshift + ang; 613 | dPoint0 = sourcey + cos(beta0) / sin(beta0) * sourcex; 614 | if (dIndex < *dets) { 615 | double beta = (*dets / 2 - dIndex - 1) * *dDet + *binshift + ang; 616 | dPoint[tx] = sourcey + cos(beta) / sin(beta) * sourcex; 617 | } 618 | } else { 619 | beta0 = (dIndex0 - *dets / 2) * *dDet + *binshift + ang; 620 | dPoint0 = sourcey + cos(beta0) / sin(beta0) * sourcex; 621 | if (dIndex < *dets) { 622 | double beta = (dIndex + 1 - *dets / 2) * *dDet + *binshift + ang; 623 | dPoint[tx] = sourcey + cos(beta) / sin(beta) * sourcex; 624 | } 625 | } 626 | __syncthreads(); 627 | if (tx == 0){ 628 | coef[tx] = dPoint[tx] - dPoint0; 629 | } else { 630 | coef[tx] = dPoint[tx] - dPoint[tx-1]; 631 | } 632 | __syncthreads(); 633 | for (int i = 0; i < ceil(*width / blockDim.x); i++){ 634 | int idxcol = i * blockDim.x + tx; 635 | if (idxcol < *width) { 636 | scalar_t i0x = (idxcol - *width / 2 + 0.5) * dImage; 637 | scalar_t i0y = - *height / 2 * dImage; 638 | int idx0row = floor(((i0x - sourcex) / (- sourcex) * 639 | (dPoint0 - sourcey) + sourcey - i0y) / dImage); 640 | idx0row = max(idx0row, 0); 641 | i0y += idx0row * dImage; 642 | scalar_t threadprj = 0; 643 | scalar_t prebound = map_y(sourcex, sourcey, i0x, i0y); 644 | scalar_t prepixbound = prebound; 645 | prebound = max(prebound, dPoint0); 646 | i0y += dImage; 647 | scalar_t pixbound = map_y(sourcex, sourcey, i0x, i0y); 648 | scalar_t detbound = dPoint[0]; 649 | int idxd = 0, idxi = idx0row; 650 | while (idxi < *height && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 651 | if (detbound <= prebound) { 652 | idxd ++; 653 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 654 | }else if (pixbound <= prebound) { 655 | idxi ++; 656 | i0y += dImage; 657 | prepixbound = pixbound; 658 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 659 | }else if (pixbound < detbound) { 660 | threadprj += (pixbound - prebound) * image[idxchannel][0][static_cast(*height)-1-idxi][idxcol] / (pixbound - prepixbound) / cweight(sourcex, sourcey, i0x, i0y - dImage / 2); 661 | prebound = pixbound; 662 | idxi ++; 663 | i0y += dImage; 664 | prepixbound = pixbound; 665 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 666 | } else { 667 | threadprj += (detbound - prebound) * image[idxchannel][0][static_cast(*height)-1-idxi][idxcol] / (pixbound - prepixbound) / cweight(sourcex, sourcey, i0x, i0y - dImage / 2); 668 | prebound = detbound; 669 | atomicAdd(prj+idxd, threadprj); 670 | threadprj = 0; 671 | idxd ++; 672 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 673 | } 674 | } 675 | if (threadprj != 0) atomicAdd(prj+idxd, threadprj); 676 | } 677 | } 678 | __syncthreads(); 679 | dPoint0 = abs(sourcex) / sqrt((dPoint0 - sourcey) * (dPoint0 - sourcey) + sourcex * sourcex); 680 | if (dIndex < *dets) { 681 | dPoint[tx] = abs(sourcex) / sqrt((dPoint[tx] - sourcey) * (dPoint[tx] - sourcey) + sourcex * sourcex); 682 | __syncthreads(); 683 | if (tx == 0){ 684 | coef[tx] = (dPoint[tx] + dPoint0) / 2; 685 | } else { 686 | coef[tx] = (dPoint[tx] + dPoint[tx-1]) / 2; 687 | } 688 | __syncthreads(); 689 | prj[tx] *= dImage; 690 | 
prj[tx] /= coef[tx]; 691 | if (ang_error >= 3 && ang_error < 7) { 692 | projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex] = prj[tx]; 693 | } else { 694 | projection[idxchannel][0][idxview][dIndex] = prj[tx]; 695 | } 696 | } 697 | } 698 | } 699 | 700 | template 701 | __global__ void bprj_fan_ea( 702 | torch::PackedTensorAccessor image, 703 | const torch::PackedTensorAccessor projection, 704 | const scalar_t* __restrict__ views, const scalar_t* __restrict__ dets, 705 | const scalar_t* __restrict__ width, const scalar_t* __restrict__ height, 706 | const scalar_t* __restrict__ dImg, const scalar_t* __restrict__ dDet, 707 | const scalar_t* __restrict__ Ang0, const scalar_t* __restrict__ dAng, 708 | const scalar_t* __restrict__ s2r, const scalar_t* __restrict__ d2r, 709 | const scalar_t* __restrict__ binshift) { 710 | 711 | __shared__ unsigned int nblocks; 712 | __shared__ unsigned int idxchannel; 713 | __shared__ unsigned int idxview; 714 | nblocks = ceil(*views / gridDim.y); 715 | idxchannel = blockIdx.x % nblocks; 716 | idxview = idxchannel * gridDim.y + blockIdx.y; 717 | if (idxview >= *views) return; 718 | idxchannel = blockIdx.x / nblocks; 719 | __shared__ scalar_t prj[BLOCK_DIM]; 720 | __shared__ scalar_t dPoint[BLOCK_DIM]; 721 | __shared__ scalar_t dImage; 722 | __shared__ scalar_t sourcex; 723 | __shared__ scalar_t sourcey; 724 | __shared__ scalar_t dPoint0; 725 | __shared__ double ang; 726 | __shared__ double beta0; 727 | __shared__ double PI; 728 | __shared__ double ang_error; 729 | __shared__ double cosval; 730 | __shared__ double sinval; 731 | __shared__ unsigned int dIndex0; 732 | 733 | PI = acos(-1.0); 734 | ang = idxview * *dAng + *Ang0; 735 | ang_error = abs(ang - round(ang / PI) * PI) * 4 / PI; 736 | cosval = cos(ang); 737 | sinval = sin(ang); 738 | sourcex = - sinval * *s2r; 739 | sourcey = cosval * *s2r; 740 | dImage = *dImg; 741 | dIndex0 = blockIdx.z * blockDim.x; 742 | unsigned int tx = threadIdx.x; 743 | unsigned int dIndex = dIndex0 + tx; 744 | __syncthreads(); 745 | if (ang_error <= 1) { 746 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 747 | if (ang_error >= 3 && ang_error < 7) { 748 | beta0 = (*dets / 2 - dIndex0) * *dDet + *binshift + ang; 749 | dPoint0 = sourcex + sin(beta0) / cos(beta0) * sourcey; 750 | if (dIndex < *dets) { 751 | double beta = (*dets / 2 - dIndex - 1) * *dDet + *binshift + ang; 752 | dPoint[tx] = sourcex + sin(beta) / cos(beta) * sourcey; 753 | } 754 | } else { 755 | beta0 = (dIndex0 - *dets / 2) * *dDet + *binshift + ang; 756 | dPoint0 = sourcex + sin(beta0) / cos(beta0) * sourcey; 757 | if (dIndex < *dets) { 758 | double beta = (dIndex + 1 - *dets / 2) * *dDet + *binshift + ang; 759 | dPoint[tx] = sourcex + sin(beta) / cos(beta) * sourcey; 760 | } 761 | } 762 | __syncthreads(); 763 | if (dIndex < *dets) { 764 | if (ang_error >= 3 && ang_error < 7) { 765 | prj[tx] = projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex]; 766 | } else { 767 | prj[tx] = projection[idxchannel][0][idxview][dIndex]; 768 | } 769 | } else { 770 | prj[tx] = 0; 771 | } 772 | __syncthreads(); 773 | for (int i = 0; i < ceil(*height / blockDim.x); i++){ 774 | int idxrow = i * blockDim.x + tx; 775 | if (idxrow < *height) { 776 | scalar_t i0y = (*height / 2 - idxrow - 0.5) * dImage; 777 | scalar_t i0x = - *width / 2 * dImage; 778 | int idx0col = floor(((i0y - sourcey) / (- sourcey) * 779 | (dPoint0- sourcex) + sourcex - i0x) / dImage); 780 | idx0col = max(idx0col, 0); 781 | i0x += idx0col * dImage; 782 | scalar_t threadprj = 0; 783 | 
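// distance-driven sweep over this row: advance whichever boundary (pixel edge or detector edge) comes first, spreading overlap-weighted projection values into the image pixels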
scalar_t prebound = map_x(sourcex, sourcey, i0x, i0y); 784 | scalar_t prepixbound = prebound; 785 | prebound = max(prebound, dPoint0); 786 | i0x += dImage; 787 | scalar_t pixbound = map_x(sourcex, sourcey, i0x, i0y); 788 | scalar_t detbound = dPoint[0]; 789 | int idxd = 0, idxi = idx0col; 790 | while (idxi < *width && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 791 | if (detbound <= prebound) { 792 | idxd ++; 793 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 794 | }else if (pixbound <= prebound){ 795 | idxi ++; 796 | i0x += dImage; 797 | prepixbound = pixbound; 798 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 799 | }else if (pixbound <= detbound) { 800 | threadprj += (pixbound - prebound) * prj[idxd] / (pixbound - prepixbound); 801 | prebound = pixbound; 802 | threadprj /= cweight(sourcex, sourcey, i0x - dImage / 2, i0y); 803 | atomicAdd(&(image[idxchannel][0][idxrow][idxi]), threadprj); 804 | threadprj = 0; 805 | idxi ++; 806 | i0x += dImage; 807 | prepixbound = pixbound; 808 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 809 | } else { 810 | threadprj += (detbound - prebound) * prj[idxd] / (pixbound - prepixbound); 811 | prebound = detbound; 812 | idxd ++; 813 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 814 | } 815 | } 816 | if (threadprj !=0 ) { 817 | threadprj /= cweight(sourcex, sourcey, i0x - dImage / 2, i0y); 818 | atomicAdd(&(image[idxchannel][0][idxrow][idxi]), threadprj); 819 | } 820 | } 821 | } 822 | } else { 823 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 824 | if (ang_error >= 3 && ang_error < 7) { 825 | beta0 = (*dets / 2 - dIndex0) * *dDet + *binshift + ang; 826 | dPoint0 = sourcey + cos(beta0) / sin(beta0) * sourcex; 827 | if (dIndex < *dets) { 828 | double beta = (*dets / 2 - dIndex - 1) * *dDet + *binshift + ang; 829 | dPoint[tx] = sourcey + cos(beta) / sin(beta) * sourcex; 830 | } 831 | } else { 832 | beta0 = (dIndex0 - *dets / 2) * *dDet + *binshift + ang; 833 | dPoint0 = sourcey + cos(beta0) / sin(beta0) * sourcex; 834 | if (dIndex < *dets) { 835 | double beta = (dIndex + 1 - *dets / 2) * *dDet + *binshift + ang; 836 | dPoint[tx] = sourcey + cos(beta) / sin(beta) * sourcex; 837 | } 838 | } 839 | __syncthreads(); 840 | if (dIndex < *dets) { 841 | if (ang_error >= 3 && ang_error < 7) { 842 | prj[tx] = projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex]; 843 | } else { 844 | prj[tx] = projection[idxchannel][0][idxview][dIndex]; 845 | } 846 | } else { 847 | prj[tx] = 0; 848 | } 849 | __syncthreads(); 850 | 851 | for (int i = 0; i < ceil(*width / blockDim.x); i++){ 852 | int idxcol = i * blockDim.x + tx; 853 | if (idxcol < *width) { 854 | scalar_t i0x = (idxcol - *width / 2 + 0.5) * dImage; 855 | scalar_t i0y = - *height / 2 * dImage; 856 | int idx0row = floor(((i0x - sourcex) / (- sourcex) * 857 | (dPoint0- sourcey) + sourcey - i0y) / dImage); 858 | idx0row = max(idx0row, 0); 859 | i0y += idx0row * dImage; 860 | scalar_t threadprj = 0; 861 | scalar_t prebound = map_y(sourcex, sourcey, i0x, i0y); 862 | scalar_t prepixbound = prebound; 863 | prebound = max(prebound, dPoint0); 864 | i0y += dImage; 865 | scalar_t pixbound = map_y(sourcex, sourcey, i0x, i0y); 866 | scalar_t detbound = dPoint[0]; 867 | int idxd = 0, idxi = idx0row; 868 | while (idxi < *height && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 869 | if (detbound <= prebound) { 870 | idxd ++; 871 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 872 | }else if (pixbound <= prebound) { 873 | idxi ++; 874 | i0y += dImage; 875 | prepixbound = pixbound; 876 | pixbound = 
map_y(sourcex, sourcey, i0x, i0y); 877 | }else if (pixbound <= detbound) { 878 | threadprj += (pixbound - prebound) * prj[idxd] / (pixbound - prepixbound); 879 | prebound = pixbound; 880 | threadprj /= cweight(sourcex, sourcey, i0x, i0y - dImage / 2); 881 | atomicAdd(&(image[idxchannel][0][static_cast(*height)-1-idxi][idxcol]), threadprj); 882 | threadprj = 0; 883 | idxi ++; 884 | i0y += dImage; 885 | prepixbound = pixbound; 886 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 887 | } else { 888 | threadprj += (detbound - prebound) * prj[idxd] / (pixbound - prepixbound); 889 | prebound = detbound; 890 | idxd ++; 891 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 892 | } 893 | } 894 | if (threadprj !=0 ) { 895 | threadprj /= cweight(sourcex, sourcey, i0x, i0y - dImage / 2); 896 | atomicAdd(&(image[idxchannel][0][static_cast(*height)-1-idxi][idxcol]), threadprj); 897 | } 898 | } 899 | } 900 | } 901 | } 902 | 903 | template 904 | __global__ void rlfilter(scalar_t* __restrict__ filter, 905 | const scalar_t* __restrict__ dets, const scalar_t* __restrict__ dDet) { 906 | unsigned xIndex = blockIdx.x * blockDim.x + threadIdx.x; 907 | __shared__ double PI; 908 | __shared__ scalar_t d; 909 | PI = acos(-1.0); 910 | d = *dDet; 911 | if (xIndex < (*dets * 2 - 1)) { 912 | int x = xIndex - *dets + 1; 913 | if ((abs(x) % 2) == 1) { 914 | filter[xIndex] = -1 / (PI * PI * x * x * d * d); 915 | } else if (x == 0) { 916 | filter[xIndex] = 1 / (4 * d * d); 917 | } else { 918 | filter[xIndex] = 0; 919 | } 920 | } 921 | } 922 | 923 | torch::Tensor prj_fan_ea_cuda(torch::Tensor image, torch::Tensor options) { 924 | cudaSetDevice(image.device().index()); 925 | auto views = options[0]; 926 | auto dets = options[1]; 927 | auto width = options[2]; 928 | auto height = options[3]; 929 | auto dImg = options[4]; 930 | auto dDet = options[5]; 931 | auto Ang0 = options[6]; 932 | auto dAng = options[7]; 933 | auto s2r = options[8]; 934 | auto d2r = options[9]; 935 | auto binshift = options[10]; 936 | const int channels = static_cast(image.size(0)); 937 | auto projection = torch::empty({channels, 1, views.item(), dets.item()}, image.options()); 938 | 939 | int nblocksx = ceil(views.item() / GRID_DIM) * channels; 940 | int nblocksy = min(views.item(), GRID_DIM); 941 | int nblocksz = ceil(dets.item() / BLOCK_DIM); 942 | const dim3 blocks(nblocksx, nblocksy, nblocksz); 943 | 944 | AT_DISPATCH_FLOATING_TYPES(image.type(), "fan_beam_equal_angle_projection", ([&] { 945 | prj_fan_ea<<>>( 946 | image.packed_accessor(), 947 | projection.packed_accessor(), 948 | views.data(), dets.data(), width.data(), 949 | height.data(), dImg.data(), dDet.data(), 950 | Ang0.data(), dAng.data(), s2r.data(), 951 | d2r.data(), binshift.data() 952 | ); 953 | })); 954 | return projection; 955 | } 956 | 957 | torch::Tensor prj_t_fan_ea_cuda(torch::Tensor projection, torch::Tensor options) { 958 | cudaSetDevice(projection.device().index()); 959 | auto views = options[0]; 960 | auto dets = options[1]; 961 | auto width = options[2]; 962 | auto height = options[3]; 963 | auto dImg = options[4]; 964 | auto dDet = options[5]; 965 | auto Ang0 = options[6]; 966 | auto dAng = options[7]; 967 | auto s2r = options[8]; 968 | auto d2r = options[9]; 969 | auto binshift = options[10]; 970 | const int channels = static_cast(projection.size(0)); 971 | auto image = torch::zeros({channels, 1, height.item(), width.item()}, projection.options()); 972 | 973 | int nblocksx = ceil(views.item() / GRID_DIM) * channels; 974 | int nblocksy = min(views.item(), GRID_DIM); 975 | int 
nblocksz = ceil(dets.item() / BLOCK_DIM); 976 | const dim3 blocks(nblocksx, nblocksy, nblocksz); 977 | 978 | AT_DISPATCH_FLOATING_TYPES(projection.type(), "fan_beam_equal_angle_backprojection", ([&] { 979 | prj_t_fan_ea<<>>( 980 | image.packed_accessor(), 981 | projection.packed_accessor(), 982 | views.data(), dets.data(), width.data(), 983 | height.data(), dImg.data(), dDet.data(), 984 | Ang0.data(), dAng.data(), s2r.data(), 985 | d2r.data(), binshift.data() 986 | ); 987 | })); 988 | return image; 989 | } 990 | 991 | torch::Tensor bprj_t_fan_ea_cuda(torch::Tensor image, torch::Tensor options) { 992 | cudaSetDevice(image.device().index()); 993 | auto views = options[0]; 994 | auto dets = options[1]; 995 | auto width = options[2]; 996 | auto height = options[3]; 997 | auto dImg = options[4]; 998 | auto dDet = options[5]; 999 | auto Ang0 = options[6]; 1000 | auto dAng = options[7]; 1001 | auto s2r = options[8]; 1002 | auto d2r = options[9]; 1003 | auto binshift = options[10]; 1004 | const int channels = static_cast(image.size(0)); 1005 | auto projection = torch::empty({channels, 1, views.item(), dets.item()}, image.options()); 1006 | 1007 | int nblocksx = ceil(views.item() / GRID_DIM) * channels; 1008 | int nblocksy = min(views.item(), GRID_DIM); 1009 | int nblocksz = ceil(dets.item() / BLOCK_DIM); 1010 | const dim3 blocks(nblocksx, nblocksy, nblocksz); 1011 | 1012 | AT_DISPATCH_FLOATING_TYPES(image.type(), "fan_beam_equal_angle_fbp_projection", ([&] { 1013 | bprj_t_fan_ea<<>>( 1014 | image.packed_accessor(), 1015 | projection.packed_accessor(), 1016 | views.data(), dets.data(), width.data(), 1017 | height.data(), dImg.data(), dDet.data(), 1018 | Ang0.data(), dAng.data(), s2r.data(), 1019 | d2r.data(), binshift.data() 1020 | ); 1021 | })); 1022 | return projection; 1023 | } 1024 | 1025 | torch::Tensor bprj_fan_ea_cuda(torch::Tensor projection, torch::Tensor options) { 1026 | cudaSetDevice(projection.device().index()); 1027 | auto views = options[0]; 1028 | auto dets = options[1]; 1029 | auto width = options[2]; 1030 | auto height = options[3]; 1031 | auto dImg = options[4]; 1032 | auto dDet = options[5]; 1033 | auto Ang0 = options[6]; 1034 | auto dAng = options[7]; 1035 | auto s2r = options[8]; 1036 | auto d2r = options[9]; 1037 | auto binshift = options[10]; 1038 | const int channels = static_cast(projection.size(0)); 1039 | auto image = torch::zeros({channels, 1, height.item(), width.item()}, projection.options()); 1040 | 1041 | int nblocksx = ceil(views.item() / GRID_DIM) * channels; 1042 | int nblocksy = min(views.item(), GRID_DIM); 1043 | int nblocksz = ceil(dets.item() / BLOCK_DIM); 1044 | const dim3 blocks(nblocksx, nblocksy, nblocksz); 1045 | 1046 | AT_DISPATCH_FLOATING_TYPES(projection.type(), "fan_beam_equal_angle_fbp_backprojection", ([&] { 1047 | bprj_fan_ea<<>>( 1048 | image.packed_accessor(), 1049 | projection.packed_accessor(), 1050 | views.data(), dets.data(), width.data(), 1051 | height.data(), dImg.data(), dDet.data(), 1052 | Ang0.data(), dAng.data(), s2r.data(), 1053 | d2r.data(), binshift.data() 1054 | ); 1055 | })); 1056 | return image; 1057 | } 1058 | 1059 | torch::Tensor fbp_fan_ea_cuda(torch::Tensor projection, torch::Tensor options) { 1060 | cudaSetDevice(projection.device().index()); 1061 | auto views = options[0]; 1062 | auto dets = options[1]; 1063 | auto width = options[2]; 1064 | auto height = options[3]; 1065 | auto dImg = options[4]; 1066 | auto dDet = options[5]; 1067 | auto Ang0 = options[6]; 1068 | auto dAng = options[7]; 1069 | auto s2r = options[8]; 1070 
| auto d2r = options[9]; 1071 | auto binshift = options[10]; 1072 | const int channels = static_cast(projection.size(0)); 1073 | auto image = torch::zeros({channels, 1, height.item(), width.item()}, projection.options()); 1074 | auto filter = torch::empty({1,1,1,dets.item()*2-1}, projection.options()); 1075 | auto rectweight = torch::arange((-dets.item()/2+0.5), dets.item()/2, 1, projection.options()); 1076 | rectweight = rectweight * dDet; 1077 | rectweight = torch::cos(rectweight); 1078 | rectweight = rectweight * s2r * dDet; 1079 | rectweight = rectweight.view({1, 1, 1, dets.item()}); 1080 | rectweight = projection * rectweight; 1081 | 1082 | int filterdim = ceil((dets.item()*2-1) / BLOCK_DIM); 1083 | int nblocksx = ceil(views.item() / GRID_DIM) * channels; 1084 | int nblocksy = min(views.item(), GRID_DIM); 1085 | int nblocksz = ceil(dets.item() / BLOCK_DIM); 1086 | const dim3 blocks(nblocksx, nblocksy, nblocksz); 1087 | 1088 | AT_DISPATCH_FLOATING_TYPES(projection.type(), "ramp_filter", ([&] { 1089 | rlfilter<<>>( 1090 | filter.data(), dets.data(), dDet.data()); 1091 | })); 1092 | 1093 | auto filtered_projection = torch::conv2d(rectweight, filter, {}, 1, torch::IntArrayRef({0, dets.item()-1})); 1094 | 1095 | AT_DISPATCH_FLOATING_TYPES(projection.type(), "fan_beam_equal_angle_fbp_backprojection", ([&] { 1096 | bprj_fan_ea<<>>( 1097 | image.packed_accessor(), 1098 | filtered_projection.packed_accessor(), 1099 | views.data(), dets.data(), width.data(), 1100 | height.data(), dImg.data(), dDet.data(), 1101 | Ang0.data(), dAng.data(), s2r.data(), 1102 | d2r.data(), binshift.data() 1103 | ); 1104 | })); 1105 | image = image * dAng / 2; 1106 | return image; 1107 | } -------------------------------------------------------------------------------- /src/fan_ed_cuda.h: -------------------------------------------------------------------------------- 1 | #ifndef FAN_ED_CUDA_H 2 | #define FAN_ED_CUDA_H 3 | 4 | #include 5 | 6 | torch::Tensor prj_fan_ed_cuda(torch::Tensor image, torch::Tensor options); 7 | torch::Tensor prj_t_fan_ed_cuda(torch::Tensor projection, torch::Tensor options); 8 | torch::Tensor bprj_t_fan_ed_cuda(torch::Tensor image, torch::Tensor options); 9 | torch::Tensor bprj_fan_ed_cuda(torch::Tensor image, torch::Tensor options); 10 | torch::Tensor bprj_sv_fan_ed_cuda(torch::Tensor image, torch::Tensor options); 11 | torch::Tensor fbp_fan_ed_cuda(torch::Tensor projection, torch::Tensor options); 12 | 13 | #endif 14 | -------------------------------------------------------------------------------- /src/fan_ed_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define BLOCK_DIM 256 5 | #define GRID_DIM 512 6 | 7 | template 8 | __device__ scalar_t map_x(scalar_t sourcex, scalar_t sourcey, scalar_t detx, scalar_t dety) { 9 | return (sourcex * dety - sourcey * detx) / (dety - sourcey); 10 | } 11 | 12 | template 13 | __device__ scalar_t map_y(scalar_t sourcex, scalar_t sourcey, scalar_t detx, scalar_t dety) { 14 | return (sourcey * detx - sourcex * dety) / (detx - sourcex); 15 | } 16 | 17 | template 18 | __device__ scalar_t cweight(scalar_t sourcex, scalar_t sourcey, scalar_t detx, scalar_t dety, scalar_t r) { 19 | scalar_t d = (sourcex * detx + sourcey * dety) / r; 20 | return r * r / ((r - d) * (r - d)); 21 | } 22 | 23 | template 24 | __global__ void prj_fan_ed( 25 | const torch::PackedTensorAccessor image, 26 | torch::PackedTensorAccessor projection, 27 | const scalar_t* __restrict__ views, const scalar_t* 
__restrict__ dets, 28 | const scalar_t* __restrict__ width, const scalar_t* __restrict__ height, 29 | const scalar_t* __restrict__ dImg, const scalar_t* __restrict__ dDet, 30 | const scalar_t* __restrict__ Ang0, const scalar_t* __restrict__ dAng, 31 | const scalar_t* __restrict__ s2r, const scalar_t* __restrict__ d2r, 32 | const scalar_t* __restrict__ binshift) { 33 | 34 | __shared__ unsigned int nblocks; 35 | __shared__ unsigned int idxchannel; 36 | __shared__ unsigned int idxview; 37 | nblocks = ceil(*views / gridDim.y); 38 | idxchannel = blockIdx.x % nblocks; 39 | idxview = idxchannel * gridDim.y + blockIdx.y; 40 | if (idxview >= *views) return; 41 | idxchannel = blockIdx.x / nblocks; 42 | __shared__ scalar_t prj[BLOCK_DIM]; 43 | __shared__ scalar_t dPoint[BLOCK_DIM]; 44 | __shared__ scalar_t coef[BLOCK_DIM]; 45 | __shared__ scalar_t dImage; 46 | __shared__ scalar_t sourcex; 47 | __shared__ scalar_t sourcey; 48 | __shared__ scalar_t d0x; 49 | __shared__ scalar_t d0y; 50 | __shared__ scalar_t dPoint0; 51 | __shared__ double ang; 52 | __shared__ double PI; 53 | __shared__ double ang_error; 54 | __shared__ double cosval; 55 | __shared__ double sinval; 56 | __shared__ double virdDet; 57 | __shared__ double virshift; 58 | __shared__ unsigned int dIndex0; 59 | 60 | PI = acos(-1.0); 61 | ang = idxview * *dAng + *Ang0; 62 | dImage = *dImg; 63 | ang_error = abs(ang - round(ang / PI) * PI) * 4 / PI; 64 | cosval = cos(ang); 65 | sinval = sin(ang); 66 | sourcex = - sinval * *s2r; 67 | sourcey = cosval * *s2r; 68 | virdDet = *s2r / (*s2r + *d2r) * *dDet; 69 | virshift = *s2r / (*s2r + *d2r) * *binshift; 70 | dIndex0 = blockIdx.z * blockDim.x; 71 | unsigned int tx = threadIdx.x; 72 | unsigned int dIndex = dIndex0 + tx; 73 | prj[tx] = 0; 74 | __syncthreads(); 75 | if (ang_error <= 1) { 76 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 77 | if (ang_error >= 3 && ang_error < 7) { 78 | d0x = ((*dets / 2 - dIndex0) * virdDet + virshift) * cosval; 79 | d0y = ((*dets / 2 - dIndex0) * virdDet + virshift) * sinval; 80 | dPoint0 = map_x(sourcex, sourcey, d0x, d0y); 81 | if (dIndex < *dets) { 82 | scalar_t detx = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * cosval; 83 | scalar_t dety = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * sinval; 84 | dPoint[tx] = map_x(sourcex, sourcey, detx, dety); 85 | } 86 | } else { 87 | d0x = ((dIndex0 - *dets / 2) * virdDet + virshift) * cosval; 88 | d0y = ((dIndex0 - *dets / 2) * virdDet + virshift) * sinval; 89 | dPoint0 = map_x(sourcex, sourcey, d0x, d0y); 90 | if (dIndex < *dets) { 91 | scalar_t detx = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * cosval; 92 | scalar_t dety = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * sinval; 93 | dPoint[tx] = map_x(sourcex, sourcey, detx, dety); 94 | } 95 | } 96 | __syncthreads(); 97 | if (tx == 0){ 98 | coef[tx] = dPoint[tx] - dPoint0; 99 | } else { 100 | coef[tx] = dPoint[tx] - dPoint[tx-1]; 101 | } 102 | __syncthreads(); 103 | for (int i = 0; i < ceil(*height / blockDim.x); i++){ 104 | int idxrow = i * blockDim.x + tx; 105 | if (idxrow < *height) { 106 | scalar_t i0y = (*height / 2 - idxrow - 0.5) * dImage; 107 | scalar_t i0x = - *width / 2 * dImage; 108 | int idx0col = floor(((i0y - sourcey) / (d0y - sourcey) * 109 | (d0x - sourcex) + sourcex - i0x) / dImage); 110 | idx0col = max(idx0col, 0); 111 | i0x += idx0col * dImage; 112 | scalar_t threadprj = 0; 113 | scalar_t prebound = map_x(sourcex, sourcey, i0x, i0y); 114 | prebound = max(prebound, dPoint0); 115 | i0x += dImage; 116 | scalar_t pixbound = 
map_x(sourcex, sourcey, i0x, i0y); 117 | scalar_t detbound = dPoint[0]; 118 | int idxd = 0, idxi = idx0col; 119 | while (idxi < *width && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 120 | if (detbound <= prebound) { 121 | idxd ++; 122 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 123 | }else if (pixbound <= prebound){ 124 | idxi ++; 125 | i0x += dImage; 126 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 127 | }else if (pixbound < detbound) { 128 | threadprj += (pixbound - prebound) * image[idxchannel][0][idxrow][idxi] / coef[idxd]; 129 | prebound = pixbound; 130 | idxi ++; 131 | i0x += dImage; 132 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 133 | } else { 134 | threadprj += (detbound - prebound) * image[idxchannel][0][idxrow][idxi] / coef[idxd]; 135 | prebound = detbound; 136 | atomicAdd(prj+idxd, threadprj); 137 | threadprj = 0; 138 | idxd ++; 139 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 140 | } 141 | } 142 | if (threadprj != 0) atomicAdd(prj+idxd, threadprj); 143 | } 144 | } 145 | __syncthreads(); 146 | dPoint0 = abs(sourcey) / sqrt((dPoint0 - sourcex) * (dPoint0 - sourcex) + sourcey * sourcey); 147 | if (dIndex < *dets) { 148 | dPoint[tx] = abs(sourcey) / sqrt((dPoint[tx] - sourcex) * (dPoint[tx] - sourcex) + sourcey * sourcey); 149 | __syncthreads(); 150 | if (tx == 0){ 151 | coef[tx] = (dPoint[tx] + dPoint0) / 2; 152 | } else { 153 | coef[tx] = (dPoint[tx] + dPoint[tx-1]) / 2; 154 | } 155 | __syncthreads(); 156 | prj[tx] *= dImage; 157 | prj[tx] /= coef[tx]; 158 | if (ang_error >= 3 && ang_error < 7) { 159 | projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex] = prj[tx]; 160 | } else { 161 | projection[idxchannel][0][idxview][dIndex] = prj[tx]; 162 | } 163 | } 164 | } else { 165 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 166 | if (ang_error >= 3 && ang_error < 7) { 167 | d0x = ((*dets / 2 - dIndex0) * virdDet + virshift) * cosval; 168 | d0y = ((*dets / 2 - dIndex0) * virdDet + virshift) * sinval; 169 | dPoint0 = map_y(sourcex, sourcey, d0x, d0y); 170 | if (dIndex < *dets) { 171 | scalar_t detx = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * cosval; 172 | scalar_t dety = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * sinval; 173 | dPoint[tx] = map_y(sourcex, sourcey, detx, dety); 174 | } 175 | } else { 176 | d0x = ((dIndex0 - *dets / 2) * virdDet + virshift) * cosval; 177 | d0y = ((dIndex0 - *dets / 2) * virdDet + virshift) * sinval; 178 | dPoint0 = map_y(sourcex, sourcey, d0x, d0y); 179 | if (dIndex < *dets) { 180 | scalar_t detx = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * cosval; 181 | scalar_t dety = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * sinval; 182 | dPoint[tx] = map_y(sourcex, sourcey, detx, dety); 183 | } 184 | } 185 | __syncthreads(); 186 | if (tx == 0){ 187 | coef[tx] = dPoint[tx] - dPoint0; 188 | } else { 189 | coef[tx] = dPoint[tx] - dPoint[tx-1]; 190 | } 191 | __syncthreads(); 192 | for (int i = 0; i < ceil(*width / blockDim.x); i++){ 193 | int idxcol = i * blockDim.x + tx; 194 | if (idxcol < *width) { 195 | scalar_t i0x = (idxcol - *width / 2 + 0.5) * dImage; 196 | scalar_t i0y = - *height / 2 * dImage; 197 | int idx0row = floor(((i0x - sourcex) / (d0x - sourcex) * 198 | (d0y - sourcey) + sourcey - i0y) / dImage); 199 | idx0row = max(idx0row, 0); 200 | i0y += idx0row * dImage; 201 | scalar_t threadprj = 0; 202 | scalar_t prebound = map_y(sourcex, sourcey, i0x, i0y); 203 | prebound = max(prebound, dPoint0); 204 | i0y += dImage; 205 | scalar_t pixbound = map_y(sourcex, sourcey, i0x, i0y); 206 | 
scalar_t detbound = dPoint[0]; 207 | int idxd = 0, idxi = idx0row; 208 | while (idxi < *height && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 209 | if (detbound <= prebound) { 210 | idxd ++; 211 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 212 | }else if (pixbound <= prebound) { 213 | idxi ++; 214 | i0y += dImage; 215 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 216 | }else if (pixbound < detbound) { 217 | threadprj += (pixbound - prebound) * image[idxchannel][0][static_cast(*height)-1-idxi][idxcol] / coef[idxd]; 218 | prebound = pixbound; 219 | idxi ++; 220 | i0y += dImage; 221 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 222 | } else { 223 | threadprj += (detbound - prebound) * image[idxchannel][0][static_cast(*height)-1-idxi][idxcol] / coef[idxd]; 224 | prebound = detbound; 225 | atomicAdd(prj+idxd, threadprj); 226 | threadprj = 0; 227 | idxd ++; 228 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 229 | } 230 | } 231 | if (threadprj != 0) atomicAdd(prj+idxd, threadprj); 232 | } 233 | } 234 | __syncthreads(); 235 | dPoint0 = abs(sourcex) / sqrt((dPoint0 - sourcey) * (dPoint0 - sourcey) + sourcex * sourcex); 236 | if (dIndex < *dets) { 237 | dPoint[tx] = abs(sourcex) / sqrt((dPoint[tx] - sourcey) * (dPoint[tx] - sourcey) + sourcex * sourcex); 238 | __syncthreads(); 239 | if (tx == 0){ 240 | coef[tx] = (dPoint[tx] + dPoint0) / 2; 241 | } else { 242 | coef[tx] = (dPoint[tx] + dPoint[tx-1]) / 2; 243 | } 244 | __syncthreads(); 245 | prj[tx] *= dImage; 246 | prj[tx] /= coef[tx]; 247 | if (ang_error >= 3 && ang_error < 7) { 248 | projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex] = prj[tx]; 249 | } else { 250 | projection[idxchannel][0][idxview][dIndex] = prj[tx]; 251 | } 252 | } 253 | } 254 | } 255 | 256 | template 257 | __global__ void prj_t_fan_ed( 258 | torch::PackedTensorAccessor image, 259 | const torch::PackedTensorAccessor projection, 260 | const scalar_t* __restrict__ views, const scalar_t* __restrict__ dets, 261 | const scalar_t* __restrict__ width, const scalar_t* __restrict__ height, 262 | const scalar_t* __restrict__ dImg, const scalar_t* __restrict__ dDet, 263 | const scalar_t* __restrict__ Ang0, const scalar_t* __restrict__ dAng, 264 | const scalar_t* __restrict__ s2r, const scalar_t* __restrict__ d2r, 265 | const scalar_t* __restrict__ binshift) { 266 | 267 | __shared__ unsigned int nblocks; 268 | __shared__ unsigned int idxchannel; 269 | __shared__ unsigned int idxview; 270 | nblocks = ceil(*views / gridDim.y); 271 | idxchannel = blockIdx.x % nblocks; 272 | idxview = idxchannel * gridDim.y + blockIdx.y; 273 | if (idxview >= *views) return; 274 | idxchannel = blockIdx.x / nblocks; 275 | __shared__ scalar_t prj[BLOCK_DIM]; 276 | __shared__ scalar_t dPoint[BLOCK_DIM]; 277 | __shared__ scalar_t coef[BLOCK_DIM]; 278 | __shared__ scalar_t dImage; 279 | __shared__ scalar_t sourcex; 280 | __shared__ scalar_t sourcey; 281 | __shared__ scalar_t d0x; 282 | __shared__ scalar_t d0y; 283 | __shared__ scalar_t dPoint0; 284 | __shared__ double ang; 285 | __shared__ double PI; 286 | __shared__ double ang_error; 287 | __shared__ double cosval; 288 | __shared__ double sinval; 289 | __shared__ double virdDet; 290 | __shared__ double virshift; 291 | __shared__ unsigned int dIndex0; 292 | 293 | PI = acos(-1.0); 294 | ang = idxview * *dAng + *Ang0; 295 | ang_error = abs(ang - round(ang / PI) * PI) * 4 / PI; 296 | cosval = cos(ang); 297 | sinval = sin(ang); 298 | sourcex = - sinval * *s2r; 299 | sourcey = cosval * *s2r; 300 | virdDet = *s2r / (*s2r + *d2r) * *dDet; 301 | 
virshift = *s2r / (*s2r + *d2r) * *binshift; 302 | dIndex0 = blockIdx.z * blockDim.x; 303 | unsigned int tx = threadIdx.x; 304 | unsigned int dIndex = dIndex0 + tx; 305 | __syncthreads(); 306 | if (ang_error <= 1) { 307 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 308 | if (ang_error >= 3 && ang_error < 7) { 309 | d0x = ((*dets / 2 - dIndex0) * virdDet + virshift) * cosval; 310 | d0y = ((*dets / 2 - dIndex0) * virdDet + virshift) * sinval; 311 | dPoint0 = map_x(sourcex, sourcey, d0x, d0y); 312 | if (dIndex < *dets) { 313 | scalar_t detx = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * cosval; 314 | scalar_t dety = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * sinval; 315 | dPoint[tx] = map_x(sourcex, sourcey, detx, dety); 316 | } 317 | } else { 318 | d0x = ((dIndex0 - *dets / 2) * virdDet + virshift) * cosval; 319 | d0y = ((dIndex0 - *dets / 2) * virdDet + virshift) * sinval; 320 | dPoint0 = map_x(sourcex, sourcey, d0x, d0y); 321 | if (dIndex < *dets) { 322 | scalar_t detx = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * cosval; 323 | scalar_t dety = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * sinval; 324 | dPoint[tx] = map_x(sourcex, sourcey, detx, dety); 325 | } 326 | } 327 | __syncthreads(); 328 | dImage = abs(sourcey) / sqrt((dPoint0 - sourcex) * (dPoint0 - sourcex) + sourcey * sourcey); 329 | if (dIndex < *dets) { 330 | prj[tx] = abs(sourcey) / sqrt((dPoint[tx] - sourcex) * (dPoint[tx] - sourcex) + sourcey * sourcey); 331 | } else { 332 | prj[tx] = 0; 333 | } 334 | __syncthreads(); 335 | if (dIndex < *dets) { 336 | if (tx == 0){ 337 | coef[tx] = (prj[tx] + dImage) / 2; 338 | } else { 339 | coef[tx] = (prj[tx] + prj[tx-1]) / 2; 340 | } 341 | } 342 | __syncthreads(); 343 | if (dIndex < *dets) { 344 | if (ang_error >= 3 && ang_error < 7) { 345 | prj[tx] = projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex]; 346 | } else { 347 | prj[tx] = projection[idxchannel][0][idxview][dIndex]; 348 | } 349 | prj[tx] *= *dImg; 350 | prj[tx] /= coef[tx]; 351 | if (tx == 0){ 352 | coef[tx] = dPoint[tx] - dPoint0; 353 | } else { 354 | coef[tx] = dPoint[tx] - dPoint[tx-1]; 355 | } 356 | prj[tx] /= coef[tx]; 357 | } 358 | __syncthreads(); 359 | dImage = *dImg; 360 | for (int i = 0; i < ceil(*height / blockDim.x); i++){ 361 | int idxrow = i * blockDim.x + tx; 362 | if (idxrow < *height) { 363 | scalar_t i0y = (*height / 2 - idxrow - 0.5) * dImage; 364 | scalar_t i0x = - *width / 2 * dImage; 365 | int idx0col = floor(((i0y - sourcey) / (d0y - sourcey) * 366 | (d0x - sourcex) + sourcex - i0x) / dImage); 367 | idx0col = max(idx0col, 0); 368 | i0x += idx0col * dImage; 369 | scalar_t threadprj = 0; 370 | scalar_t prebound = map_x(sourcex, sourcey, i0x, i0y); 371 | prebound = max(prebound, dPoint0); 372 | i0x += dImage; 373 | scalar_t pixbound = map_x(sourcex, sourcey, i0x, i0y); 374 | scalar_t detbound = dPoint[0]; 375 | int idxd = 0, idxi = idx0col; 376 | while (idxi < *width && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 377 | if (detbound <= prebound) { 378 | idxd ++; 379 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 380 | }else if (pixbound <= prebound){ 381 | idxi ++; 382 | i0x += dImage; 383 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 384 | }else if (pixbound <= detbound) { 385 | threadprj += (pixbound - prebound) * prj[idxd]; 386 | prebound = pixbound; 387 | atomicAdd(&(image[idxchannel][0][idxrow][idxi]), threadprj); 388 | threadprj = 0; 389 | idxi ++; 390 | i0x += dImage; 391 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 392 | } else { 393 | 
threadprj += (detbound - prebound) * prj[idxd]; 394 | prebound = detbound; 395 | idxd ++; 396 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 397 | } 398 | } 399 | if (threadprj !=0 ) atomicAdd(&(image[idxchannel][0][idxrow][idxi]), threadprj); 400 | } 401 | } 402 | } else { 403 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 404 | if (ang_error >= 3 && ang_error < 7) { 405 | d0x = ((*dets / 2 - dIndex0) * virdDet + virshift) * cosval; 406 | d0y = ((*dets / 2 - dIndex0) * virdDet + virshift) * sinval; 407 | dPoint0 = map_y(sourcex, sourcey, d0x, d0y); 408 | if (dIndex < *dets) { 409 | scalar_t detx = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * cosval; 410 | scalar_t dety = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * sinval; 411 | dPoint[tx] = map_y(sourcex, sourcey, detx, dety); 412 | } 413 | } else { 414 | d0x = ((dIndex0 - *dets / 2) * virdDet + virshift) * cosval; 415 | d0y = ((dIndex0 - *dets / 2) * virdDet + virshift) * sinval; 416 | dPoint0 = map_y(sourcex, sourcey, d0x, d0y); 417 | if (dIndex < *dets) { 418 | scalar_t detx = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * cosval; 419 | scalar_t dety = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * sinval; 420 | dPoint[tx] = map_y(sourcex, sourcey, detx, dety); 421 | } 422 | } 423 | __syncthreads(); 424 | dImage = abs(sourcex) / sqrt((dPoint0 - sourcey) * (dPoint0 - sourcey) + sourcex * sourcex); 425 | if (dIndex < *dets) { 426 | prj[tx] = abs(sourcex) / sqrt((dPoint[tx] - sourcey) * (dPoint[tx] - sourcey) + sourcex * sourcex); 427 | } else { 428 | prj[tx] = 0; 429 | } 430 | __syncthreads(); 431 | if (dIndex < *dets) { 432 | if (tx == 0){ 433 | coef[tx] = (prj[tx] + dImage) / 2; 434 | } else { 435 | coef[tx] = (prj[tx] + prj[tx-1]) / 2; 436 | } 437 | } 438 | __syncthreads(); 439 | if (dIndex < *dets) { 440 | if (ang_error >= 3 && ang_error < 7) { 441 | prj[tx] = projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex]; 442 | } else { 443 | prj[tx] = projection[idxchannel][0][idxview][dIndex]; 444 | } 445 | prj[tx] *= *dImg; 446 | prj[tx] /= coef[tx]; 447 | if (tx == 0){ 448 | coef[tx] = dPoint[tx] - dPoint0; 449 | } else { 450 | coef[tx] = dPoint[tx] - dPoint[tx-1]; 451 | } 452 | prj[tx] /= coef[tx]; 453 | } 454 | __syncthreads(); 455 | dImage = *dImg; 456 | for (int i = 0; i < ceil(*width / blockDim.x); i++){ 457 | int idxcol = i * blockDim.x + tx; 458 | if (idxcol < *width) { 459 | scalar_t i0x = (idxcol - *width / 2 + 0.5) * dImage; 460 | scalar_t i0y = - *height / 2 * dImage; 461 | int idx0row = floor(((i0x - sourcex) / (d0x - sourcex) * 462 | (d0y - sourcey) + sourcey - i0y) / dImage); 463 | idx0row = max(idx0row, 0); 464 | i0y += idx0row * dImage; 465 | scalar_t threadprj = 0; 466 | scalar_t prebound = map_y(sourcex, sourcey, i0x, i0y); 467 | prebound = max(prebound, dPoint0); 468 | i0y += dImage; 469 | scalar_t pixbound = map_y(sourcex, sourcey, i0x, i0y); 470 | scalar_t detbound = dPoint[0]; 471 | int idxd = 0, idxi = idx0row; 472 | while (idxi < *height && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 473 | if (detbound <= prebound) { 474 | idxd ++; 475 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 476 | }else if (pixbound <= prebound) { 477 | idxi ++; 478 | i0y += dImage; 479 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 480 | }else if (pixbound <= detbound) { 481 | threadprj += (pixbound - prebound) * prj[idxd]; 482 | prebound = pixbound; 483 | atomicAdd(&(image[idxchannel][0][static_cast(*height)-1-idxi][idxcol]), threadprj); 484 | threadprj = 0; 485 | idxi ++; 486 | i0y += 
dImage; 487 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 488 | } else { 489 | threadprj += (detbound - prebound) * prj[idxd]; 490 | prebound = detbound; 491 | idxd ++; 492 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 493 | } 494 | } 495 | if (threadprj !=0 ) atomicAdd(&(image[idxchannel][0][static_cast(*height)-1-idxi][idxcol]), threadprj); 496 | } 497 | } 498 | } 499 | } 500 | 501 | template 502 | __global__ void bprj_t_fan_ed( 503 | const torch::PackedTensorAccessor image, 504 | torch::PackedTensorAccessor projection, 505 | const scalar_t* __restrict__ views, const scalar_t* __restrict__ dets, 506 | const scalar_t* __restrict__ width, const scalar_t* __restrict__ height, 507 | const scalar_t* __restrict__ dImg, const scalar_t* __restrict__ dDet, 508 | const scalar_t* __restrict__ Ang0, const scalar_t* __restrict__ dAng, 509 | const scalar_t* __restrict__ s2r, const scalar_t* __restrict__ d2r, 510 | const scalar_t* __restrict__ binshift) { 511 | 512 | __shared__ unsigned int nblocks; 513 | __shared__ unsigned int idxchannel; 514 | __shared__ unsigned int idxview; 515 | nblocks = ceil(*views / gridDim.y); 516 | idxchannel = blockIdx.x % nblocks; 517 | idxview = idxchannel * gridDim.y + blockIdx.y; 518 | if (idxview >= *views) return; 519 | idxchannel = blockIdx.x / nblocks; 520 | __shared__ scalar_t prj[BLOCK_DIM]; 521 | __shared__ scalar_t dPoint[BLOCK_DIM]; 522 | __shared__ scalar_t dImage; 523 | __shared__ scalar_t sourcex; 524 | __shared__ scalar_t sourcey; 525 | __shared__ scalar_t d0x; 526 | __shared__ scalar_t d0y; 527 | __shared__ scalar_t s0; 528 | __shared__ scalar_t dPoint0; 529 | __shared__ double ang; 530 | __shared__ double PI; 531 | __shared__ double ang_error; 532 | __shared__ double cosval; 533 | __shared__ double sinval; 534 | __shared__ double virdDet; 535 | __shared__ double virshift; 536 | __shared__ unsigned int dIndex0; 537 | 538 | PI = acos(-1.0); 539 | ang = idxview * *dAng + *Ang0; 540 | dImage = *dImg; 541 | ang_error = abs(ang - round(ang / PI) * PI) * 4 / PI; 542 | cosval = cos(ang); 543 | sinval = sin(ang); 544 | sourcex = - sinval * *s2r; 545 | sourcey = cosval * *s2r; 546 | virdDet = *s2r / (*s2r + *d2r) * *dDet; 547 | virshift = *s2r / (*s2r + *d2r) * *binshift; 548 | s0 = *s2r; 549 | dIndex0 = blockIdx.z * blockDim.x; 550 | unsigned int tx = threadIdx.x; 551 | unsigned int dIndex = dIndex0 + tx; 552 | prj[tx] = 0; 553 | __syncthreads(); 554 | if (ang_error <= 1) { 555 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 556 | if (ang_error >= 3 && ang_error < 7) { 557 | d0x = ((*dets / 2 - dIndex0) * virdDet + virshift) * cosval; 558 | d0y = ((*dets / 2 - dIndex0) * virdDet + virshift) * sinval; 559 | dPoint0 = map_x(sourcex, sourcey, d0x, d0y); 560 | if (dIndex < *dets) { 561 | scalar_t detx = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * cosval; 562 | scalar_t dety = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * sinval; 563 | dPoint[tx] = map_x(sourcex, sourcey, detx, dety); 564 | } 565 | } else { 566 | d0x = ((dIndex0 - *dets / 2) * virdDet + virshift) * cosval; 567 | d0y = ((dIndex0 - *dets / 2) * virdDet + virshift) * sinval; 568 | dPoint0 = map_x(sourcex, sourcey, d0x, d0y); 569 | if (dIndex < *dets) { 570 | scalar_t detx = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * cosval; 571 | scalar_t dety = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * sinval; 572 | dPoint[tx] = map_x(sourcex, sourcey, detx, dety); 573 | } 574 | } 575 | __syncthreads(); 576 | for (int i = 0; i < ceil(*height / blockDim.x); i++){ 577 | int 
idxrow = i * blockDim.x + tx; 578 | if (idxrow < *height) { 579 | scalar_t i0y = (*height / 2 - idxrow - 0.5) * dImage; 580 | scalar_t i0x = - *width / 2 * dImage; 581 | int idx0col = floor(((i0y - sourcey) / (d0y - sourcey) * 582 | (d0x - sourcex) + sourcex - i0x) / dImage); 583 | idx0col = max(idx0col, 0); 584 | i0x += idx0col * dImage; 585 | scalar_t threadprj = 0; 586 | scalar_t prebound = map_x(sourcex, sourcey, i0x, i0y); 587 | scalar_t prepixbound = prebound; 588 | prebound = max(prebound, dPoint0); 589 | i0x += dImage; 590 | scalar_t pixbound = map_x(sourcex, sourcey, i0x, i0y); 591 | scalar_t detbound = dPoint[0]; 592 | int idxd = 0, idxi = idx0col; 593 | while (idxi < *width && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 594 | if (detbound <= prebound) { 595 | idxd ++; 596 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 597 | }else if (pixbound <= prebound){ 598 | idxi ++; 599 | i0x += dImage; 600 | prepixbound = pixbound; 601 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 602 | }else if (pixbound < detbound) { 603 | threadprj += (pixbound - prebound) * image[idxchannel][0][idxrow][idxi] / (pixbound - prepixbound) * cweight(sourcex, sourcey, i0x - dImage / 2, i0y, s0); 604 | prebound = pixbound; 605 | idxi ++; 606 | i0x += dImage; 607 | prepixbound = pixbound; 608 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 609 | } else { 610 | threadprj += (detbound - prebound) * image[idxchannel][0][idxrow][idxi] / (pixbound - prepixbound) * cweight(sourcex, sourcey, i0x - dImage / 2, i0y, s0); 611 | prebound = detbound; 612 | atomicAdd(prj+idxd, threadprj); 613 | threadprj = 0; 614 | idxd ++; 615 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 616 | } 617 | } 618 | if (threadprj != 0) atomicAdd(prj+idxd, threadprj); 619 | } 620 | } 621 | __syncthreads(); 622 | if (dIndex < *dets) { 623 | if (ang_error >= 3 && ang_error < 7) { 624 | projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex] = prj[tx]; 625 | } else { 626 | projection[idxchannel][0][idxview][dIndex] = prj[tx]; 627 | } 628 | } 629 | } else { 630 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 631 | if (ang_error >= 3 && ang_error < 7) { 632 | d0x = ((*dets / 2 - dIndex0) * virdDet + virshift) * cosval; 633 | d0y = ((*dets / 2 - dIndex0) * virdDet + virshift) * sinval; 634 | dPoint0 = map_y(sourcex, sourcey, d0x, d0y); 635 | if (dIndex < *dets) { 636 | scalar_t detx = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * cosval; 637 | scalar_t dety = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * sinval; 638 | dPoint[tx] = map_y(sourcex, sourcey, detx, dety); 639 | } 640 | } else { 641 | d0x = ((dIndex0 - *dets / 2) * virdDet + virshift) * cosval; 642 | d0y = ((dIndex0 - *dets / 2) * virdDet + virshift) * sinval; 643 | dPoint0 = map_y(sourcex, sourcey, d0x, d0y); 644 | if (dIndex < *dets) { 645 | scalar_t detx = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * cosval; 646 | scalar_t dety = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * sinval; 647 | dPoint[tx] = map_y(sourcex, sourcey, detx, dety); 648 | } 649 | } 650 | __syncthreads(); 651 | for (int i = 0; i < ceil(*width / blockDim.x); i++){ 652 | int idxcol = i * blockDim.x + tx; 653 | if (idxcol < *width) { 654 | scalar_t i0x = (idxcol - *width / 2 + 0.5) * dImage; 655 | scalar_t i0y = - *height / 2 * dImage; 656 | int idx0row = floor(((i0x - sourcex) / (d0x - sourcex) * 657 | (d0y - sourcey) + sourcey - i0y) / dImage); 658 | idx0row = max(idx0row, 0); 659 | i0y += idx0row * dImage; 660 | scalar_t threadprj = 0; 661 | scalar_t prebound = 
map_y(sourcex, sourcey, i0x, i0y); 662 | scalar_t prepixbound = prebound; 663 | prebound = max(prebound, dPoint0); 664 | i0y += dImage; 665 | scalar_t pixbound = map_y(sourcex, sourcey, i0x, i0y); 666 | scalar_t detbound = dPoint[0]; 667 | int idxd = 0, idxi = idx0row; 668 | while (idxi < *height && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 669 | if (detbound <= prebound) { 670 | idxd ++; 671 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 672 | }else if (pixbound <= prebound) { 673 | idxi ++; 674 | i0y += dImage; 675 | prepixbound = pixbound; 676 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 677 | }else if (pixbound < detbound) { 678 | threadprj += (pixbound - prebound) * image[idxchannel][0][static_cast(*height)-1-idxi][idxcol] / (pixbound - prepixbound) * cweight(sourcex, sourcey, i0x, i0y - dImage / 2, s0); 679 | prebound = pixbound; 680 | idxi ++; 681 | i0y += dImage; 682 | prepixbound = pixbound; 683 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 684 | } else { 685 | threadprj += (detbound - prebound) * image[idxchannel][0][static_cast(*height)-1-idxi][idxcol] / (pixbound - prepixbound) * cweight(sourcex, sourcey, i0x, i0y - dImage / 2, s0); 686 | prebound = detbound; 687 | atomicAdd(prj+idxd, threadprj); 688 | threadprj = 0; 689 | idxd ++; 690 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 691 | } 692 | } 693 | if (threadprj != 0) atomicAdd(prj+idxd, threadprj); 694 | } 695 | } 696 | __syncthreads(); 697 | if (dIndex < *dets) { 698 | if (ang_error >= 3 && ang_error < 7) { 699 | projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex] = prj[tx]; 700 | } else { 701 | projection[idxchannel][0][idxview][dIndex] = prj[tx]; 702 | } 703 | } 704 | } 705 | } 706 | 707 | template 708 | __global__ void bprj_fan_ed( 709 | torch::PackedTensorAccessor image, 710 | const torch::PackedTensorAccessor projection, 711 | const scalar_t* __restrict__ views, const scalar_t* __restrict__ dets, 712 | const scalar_t* __restrict__ width, const scalar_t* __restrict__ height, 713 | const scalar_t* __restrict__ dImg, const scalar_t* __restrict__ dDet, 714 | const scalar_t* __restrict__ Ang0, const scalar_t* __restrict__ dAng, 715 | const scalar_t* __restrict__ s2r, const scalar_t* __restrict__ d2r, 716 | const scalar_t* __restrict__ binshift) { 717 | 718 | __shared__ unsigned int nblocks; 719 | __shared__ unsigned int idxchannel; 720 | __shared__ unsigned int idxview; 721 | nblocks = ceil(*views / gridDim.y); 722 | idxchannel = blockIdx.x % nblocks; 723 | idxview = idxchannel * gridDim.y + blockIdx.y; 724 | if (idxview >= *views) return; 725 | idxchannel = blockIdx.x / nblocks; 726 | __shared__ scalar_t prj[BLOCK_DIM]; 727 | __shared__ scalar_t dPoint[BLOCK_DIM]; 728 | __shared__ scalar_t dImage; 729 | __shared__ scalar_t sourcex; 730 | __shared__ scalar_t sourcey; 731 | __shared__ scalar_t d0x; 732 | __shared__ scalar_t d0y; 733 | __shared__ scalar_t dPoint0; 734 | __shared__ scalar_t s0; 735 | __shared__ double ang; 736 | __shared__ double PI; 737 | __shared__ double ang_error; 738 | __shared__ double cosval; 739 | __shared__ double sinval; 740 | __shared__ double virdDet; 741 | __shared__ double virshift; 742 | __shared__ unsigned int dIndex0; 743 | 744 | PI = acos(-1.0); 745 | ang = idxview * *dAng + *Ang0; 746 | ang_error = abs(ang - round(ang / PI) * PI) * 4 / PI; 747 | cosval = cos(ang); 748 | sinval = sin(ang); 749 | sourcex = - sinval * *s2r; 750 | sourcey = cosval * *s2r; 751 | virdDet = *s2r / (*s2r + *d2r) * *dDet; 752 | virshift = *s2r / (*s2r + *d2r) * *binshift; 753 | s0 = 
*s2r; 754 | dImage = *dImg; 755 | dIndex0 = blockIdx.z * blockDim.x; 756 | unsigned int tx = threadIdx.x; 757 | unsigned int dIndex = dIndex0 + tx; 758 | __syncthreads(); 759 | if (ang_error <= 1) { 760 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 761 | if (ang_error >= 3 && ang_error < 7) { 762 | d0x = ((*dets / 2 - dIndex0) * virdDet + virshift) * cosval; 763 | d0y = ((*dets / 2 - dIndex0) * virdDet + virshift) * sinval; 764 | dPoint0 = map_x(sourcex, sourcey, d0x, d0y); 765 | if (dIndex < *dets) { 766 | scalar_t detx = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * cosval; 767 | scalar_t dety = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * sinval; 768 | dPoint[tx] = map_x(sourcex, sourcey, detx, dety); 769 | } 770 | } else { 771 | d0x = ((dIndex0 - *dets / 2) * virdDet + virshift) * cosval; 772 | d0y = ((dIndex0 - *dets / 2) * virdDet + virshift) * sinval; 773 | dPoint0 = map_x(sourcex, sourcey, d0x, d0y); 774 | if (dIndex < *dets) { 775 | scalar_t detx = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * cosval; 776 | scalar_t dety = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * sinval; 777 | dPoint[tx] = map_x(sourcex, sourcey, detx, dety); 778 | } 779 | } 780 | __syncthreads(); 781 | if (dIndex < *dets) { 782 | if (ang_error >= 3 && ang_error < 7) { 783 | prj[tx] = projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex]; 784 | } else { 785 | prj[tx] = projection[idxchannel][0][idxview][dIndex]; 786 | } 787 | } else { 788 | prj[tx] = 0; 789 | } 790 | __syncthreads(); 791 | for (int i = 0; i < ceil(*height / blockDim.x); i++){ 792 | int idxrow = i * blockDim.x + tx; 793 | if (idxrow < *height) { 794 | scalar_t i0y = (*height / 2 - idxrow - 0.5) * dImage; 795 | scalar_t i0x = - *width / 2 * dImage; 796 | int idx0col = floor(((i0y - sourcey) / (d0y - sourcey) * 797 | (d0x - sourcex) + sourcex - i0x) / dImage); 798 | idx0col = max(idx0col, 0); 799 | i0x += idx0col * dImage; 800 | scalar_t threadprj = 0; 801 | scalar_t prebound = map_x(sourcex, sourcey, i0x, i0y); 802 | scalar_t prepixbound = prebound; 803 | prebound = max(prebound, dPoint0); 804 | i0x += dImage; 805 | scalar_t pixbound = map_x(sourcex, sourcey, i0x, i0y); 806 | scalar_t detbound = dPoint[0]; 807 | int idxd = 0, idxi = idx0col; 808 | while (idxi < *width && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 809 | if (detbound <= prebound) { 810 | idxd ++; 811 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 812 | }else if (pixbound <= prebound){ 813 | idxi ++; 814 | i0x += dImage; 815 | prepixbound = pixbound; 816 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 817 | }else if (pixbound <= detbound) { 818 | threadprj += (pixbound - prebound) * prj[idxd] / (pixbound - prepixbound); 819 | prebound = pixbound; 820 | threadprj *= cweight(sourcex, sourcey, i0x - dImage / 2, i0y, s0); 821 | atomicAdd(&(image[idxchannel][0][idxrow][idxi]), threadprj); 822 | threadprj = 0; 823 | idxi ++; 824 | i0x += dImage; 825 | prepixbound = pixbound; 826 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 827 | } else { 828 | threadprj += (detbound - prebound) * prj[idxd] / (pixbound - prepixbound); 829 | prebound = detbound; 830 | idxd ++; 831 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 832 | } 833 | } 834 | if (threadprj !=0 ) { 835 | threadprj *= cweight(sourcex, sourcey, i0x - dImage / 2, i0y, s0); 836 | atomicAdd(&(image[idxchannel][0][idxrow][idxi]), threadprj); 837 | } 838 | } 839 | } 840 | } else { 841 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 842 | if (ang_error >= 3 && ang_error < 
7) { 843 | d0x = ((*dets / 2 - dIndex0) * virdDet + virshift) * cosval; 844 | d0y = ((*dets / 2 - dIndex0) * virdDet + virshift) * sinval; 845 | dPoint0 = map_y(sourcex, sourcey, d0x, d0y); 846 | if (dIndex < *dets) { 847 | scalar_t detx = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * cosval; 848 | scalar_t dety = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * sinval; 849 | dPoint[tx] = map_y(sourcex, sourcey, detx, dety); 850 | } 851 | } else { 852 | d0x = ((dIndex0 - *dets / 2) * virdDet + virshift) * cosval; 853 | d0y = ((dIndex0 - *dets / 2) * virdDet + virshift) * sinval; 854 | dPoint0 = map_y(sourcex, sourcey, d0x, d0y); 855 | if (dIndex < *dets) { 856 | scalar_t detx = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * cosval; 857 | scalar_t dety = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * sinval; 858 | dPoint[tx] = map_y(sourcex, sourcey, detx, dety); 859 | } 860 | } 861 | __syncthreads(); 862 | if (dIndex < *dets) { 863 | if (ang_error >= 3 && ang_error < 7) { 864 | prj[tx] = projection[idxchannel][0][idxview][static_cast<int>(*dets)-1-dIndex]; 865 | } else { 866 | prj[tx] = projection[idxchannel][0][idxview][dIndex]; 867 | } 868 | } else { 869 | prj[tx] = 0; 870 | } 871 | __syncthreads(); 872 | 873 | for (int i = 0; i < ceil(*width / blockDim.x); i++){ 874 | int idxcol = i * blockDim.x + tx; 875 | if (idxcol < *width) { 876 | scalar_t i0x = (idxcol - *width / 2 + 0.5) * dImage; 877 | scalar_t i0y = - *height / 2 * dImage; 878 | int idx0row = floor(((i0x - sourcex) / (d0x - sourcex) * 879 | (d0y - sourcey) + sourcey - i0y) / dImage); 880 | idx0row = max(idx0row, 0); 881 | i0y += idx0row * dImage; 882 | scalar_t threadprj = 0; 883 | scalar_t prebound = map_y(sourcex, sourcey, i0x, i0y); 884 | scalar_t prepixbound = prebound; 885 | prebound = max(prebound, dPoint0); 886 | i0y += dImage; 887 | scalar_t pixbound = map_y(sourcex, sourcey, i0x, i0y); 888 | scalar_t detbound = dPoint[0]; 889 | int idxd = 0, idxi = idx0row; 890 | while (idxi < *height && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 891 | if (detbound <= prebound) { 892 | idxd ++; 893 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 894 | } else if (pixbound <= prebound) { 895 | idxi ++; 896 | i0y += dImage; 897 | prepixbound = pixbound; 898 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 899 | } else if (pixbound <= detbound) { 900 | threadprj += (pixbound - prebound) * prj[idxd] / (pixbound - prepixbound); 901 | prebound = pixbound; 902 | threadprj *= cweight(sourcex, sourcey, i0x, i0y - dImage / 2, s0); 903 | atomicAdd(&(image[idxchannel][0][static_cast<int>(*height)-1-idxi][idxcol]), threadprj); 904 | threadprj = 0; 905 | idxi ++; 906 | i0y += dImage; 907 | prepixbound = pixbound; 908 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 909 | } else { 910 | threadprj += (detbound - prebound) * prj[idxd] / (pixbound - prepixbound); 911 | prebound = detbound; 912 | idxd ++; 913 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 914 | } 915 | } 916 | if (threadprj != 0) { 917 | threadprj *= cweight(sourcex, sourcey, i0x, i0y - dImage / 2, s0); 918 | atomicAdd(&(image[idxchannel][0][static_cast<int>(*height)-1-idxi][idxcol]), threadprj); 919 | } 920 | } 921 | } 922 | } 923 | } 924 | 925 | template <typename scalar_t> 926 | __global__ void rlfilter(scalar_t* __restrict__ filter, 927 | const scalar_t* __restrict__ dets, const scalar_t* __restrict__ dDet) { 928 | unsigned xIndex = blockIdx.x * blockDim.x + threadIdx.x; 929 | __shared__ double PI; 930 | __shared__ scalar_t d; 931 | PI = acos(-1.0); 932 | d = *dDet; 933 | if (xIndex < (*dets * 2 - 1))
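// The taps computed below are the discrete Ramachandran-Lakshminarayanan
// (ramp) filter sampled at the detector pitch d: h(0) = 1/(4*d*d),
// h(n) = -1/(pi*pi*n*n*d*d) for odd n, and h(n) = 0 for even n != 0,
// stored as a length 2*dets-1 kernel centered at n = 0. For d = 1 the central
// taps read [..., -1/(9*pi*pi), 0, -1/(pi*pi), 1/4, -1/(pi*pi), 0, -1/(9*pi*pi), ...].
// The host code convolves the weighted sinogram with this kernel via
// torch::conv2d before backprojecting.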
{ 934 | int x = xIndex - *dets + 1; 935 | if ((abs(x) % 2) == 1) { 936 | filter[xIndex] = -1 / (PI * PI * x * x * d * d); 937 | } else if (x == 0) { 938 | filter[xIndex] = 1 / (4 * d * d); 939 | } else { 940 | filter[xIndex] = 0; 941 | } 942 | } 943 | } 944 | 945 | template 946 | __global__ void bprj_sv_fan_ed( 947 | torch::PackedTensorAccessor image, 948 | const torch::PackedTensorAccessor projection, 949 | const scalar_t* __restrict__ views, const scalar_t* __restrict__ dets, 950 | const scalar_t* __restrict__ width, const scalar_t* __restrict__ height, 951 | const scalar_t* __restrict__ dImg, const scalar_t* __restrict__ dDet, 952 | const scalar_t* __restrict__ Ang0, const scalar_t* __restrict__ dAng, 953 | const scalar_t* __restrict__ s2r, const scalar_t* __restrict__ d2r, 954 | const scalar_t* __restrict__ binshift) { 955 | 956 | __shared__ unsigned int nblocks; 957 | __shared__ unsigned int idxchannel; 958 | __shared__ unsigned int idxview; 959 | nblocks = ceil(*views / gridDim.y); 960 | idxchannel = blockIdx.x % nblocks; 961 | idxview = idxchannel * gridDim.y + blockIdx.y; 962 | if (idxview >= *views) return; 963 | idxchannel = blockIdx.x / nblocks; 964 | __shared__ scalar_t prj[BLOCK_DIM]; 965 | __shared__ scalar_t dPoint[BLOCK_DIM]; 966 | __shared__ scalar_t dImage; 967 | __shared__ scalar_t sourcex; 968 | __shared__ scalar_t sourcey; 969 | __shared__ scalar_t d0x; 970 | __shared__ scalar_t d0y; 971 | __shared__ scalar_t dPoint0; 972 | __shared__ scalar_t s0; 973 | __shared__ double ang; 974 | __shared__ double PI; 975 | __shared__ double ang_error; 976 | __shared__ double cosval; 977 | __shared__ double sinval; 978 | __shared__ double virdDet; 979 | __shared__ double virshift; 980 | __shared__ unsigned int dIndex0; 981 | 982 | PI = acos(-1.0); 983 | ang = idxview * *dAng + *Ang0; 984 | ang_error = abs(ang - round(ang / PI) * PI) * 4 / PI; 985 | cosval = cos(ang); 986 | sinval = sin(ang); 987 | sourcex = - sinval * *s2r; 988 | sourcey = cosval * *s2r; 989 | virdDet = *s2r / (*s2r + *d2r) * *dDet; 990 | virshift = *s2r / (*s2r + *d2r) * *binshift; 991 | s0 = *s2r; 992 | dImage = *dImg; 993 | dIndex0 = blockIdx.z * blockDim.x; 994 | unsigned int tx = threadIdx.x; 995 | unsigned int dIndex = dIndex0 + tx; 996 | __syncthreads(); 997 | if (ang_error <= 1) { 998 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 999 | if (ang_error >= 3 && ang_error < 7) { 1000 | d0x = ((*dets / 2 - dIndex0) * virdDet + virshift) * cosval; 1001 | d0y = ((*dets / 2 - dIndex0) * virdDet + virshift) * sinval; 1002 | dPoint0 = map_x(sourcex, sourcey, d0x, d0y); 1003 | if (dIndex < *dets) { 1004 | scalar_t detx = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * cosval; 1005 | scalar_t dety = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * sinval; 1006 | dPoint[tx] = map_x(sourcex, sourcey, detx, dety); 1007 | } 1008 | } else { 1009 | d0x = ((dIndex0 - *dets / 2) * virdDet + virshift) * cosval; 1010 | d0y = ((dIndex0 - *dets / 2) * virdDet + virshift) * sinval; 1011 | dPoint0 = map_x(sourcex, sourcey, d0x, d0y); 1012 | if (dIndex < *dets) { 1013 | scalar_t detx = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * cosval; 1014 | scalar_t dety = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * sinval; 1015 | dPoint[tx] = map_x(sourcex, sourcey, detx, dety); 1016 | } 1017 | } 1018 | __syncthreads(); 1019 | if (dIndex < *dets) { 1020 | if (ang_error >= 3 && ang_error < 7) { 1021 | prj[tx] = projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex]; 1022 | } else { 1023 | prj[tx] = 
projection[idxchannel][0][idxview][dIndex]; 1024 | } 1025 | } else { 1026 | prj[tx] = 0; 1027 | } 1028 | __syncthreads(); 1029 | for (int i = 0; i < ceil(*height / blockDim.x); i++){ 1030 | int idxrow = i * blockDim.x + tx; 1031 | if (idxrow < *height) { 1032 | scalar_t i0y = (*height / 2 - idxrow - 0.5) * dImage; 1033 | scalar_t i0x = - *width / 2 * dImage; 1034 | int idx0col = floor(((i0y - sourcey) / (d0y - sourcey) * 1035 | (d0x - sourcex) + sourcex - i0x) / dImage); 1036 | idx0col = max(idx0col, 0); 1037 | i0x += idx0col * dImage; 1038 | scalar_t threadprj = 0; 1039 | scalar_t prebound = map_x(sourcex, sourcey, i0x, i0y); 1040 | scalar_t prepixbound = prebound; 1041 | prebound = max(prebound, dPoint0); 1042 | i0x += dImage; 1043 | scalar_t pixbound = map_x(sourcex, sourcey, i0x, i0y); 1044 | scalar_t detbound = dPoint[0]; 1045 | int idxd = 0, idxi = idx0col; 1046 | while (idxi < *width && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 1047 | if (detbound <= prebound) { 1048 | idxd ++; 1049 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 1050 | }else if (pixbound <= prebound){ 1051 | idxi ++; 1052 | i0x += dImage; 1053 | prepixbound = pixbound; 1054 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 1055 | }else if (pixbound <= detbound) { 1056 | threadprj += (pixbound - prebound) * prj[idxd] / (pixbound - prepixbound); 1057 | prebound = pixbound; 1058 | threadprj *= cweight(sourcex, sourcey, i0x - dImage / 2, i0y, s0); 1059 | atomicAdd(&(image[idxchannel][idxview][idxrow][idxi]), threadprj); 1060 | threadprj = 0; 1061 | idxi ++; 1062 | i0x += dImage; 1063 | prepixbound = pixbound; 1064 | pixbound = map_x(sourcex, sourcey, i0x, i0y); 1065 | } else { 1066 | threadprj += (detbound - prebound) * prj[idxd] / (pixbound - prepixbound); 1067 | prebound = detbound; 1068 | idxd ++; 1069 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 1070 | } 1071 | } 1072 | if (threadprj !=0 ) { 1073 | threadprj *= cweight(sourcex, sourcey, i0x - dImage / 2, i0y, s0); 1074 | atomicAdd(&(image[idxchannel][idxview][idxrow][idxi]), threadprj); 1075 | } 1076 | } 1077 | } 1078 | } else { 1079 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 1080 | if (ang_error >= 3 && ang_error < 7) { 1081 | d0x = ((*dets / 2 - dIndex0) * virdDet + virshift) * cosval; 1082 | d0y = ((*dets / 2 - dIndex0) * virdDet + virshift) * sinval; 1083 | dPoint0 = map_y(sourcex, sourcey, d0x, d0y); 1084 | if (dIndex < *dets) { 1085 | scalar_t detx = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * cosval; 1086 | scalar_t dety = ((*dets / 2 - dIndex - 1) * virdDet + virshift) * sinval; 1087 | dPoint[tx] = map_y(sourcex, sourcey, detx, dety); 1088 | } 1089 | } else { 1090 | d0x = ((dIndex0 - *dets / 2) * virdDet + virshift) * cosval; 1091 | d0y = ((dIndex0 - *dets / 2) * virdDet + virshift) * sinval; 1092 | dPoint0 = map_y(sourcex, sourcey, d0x, d0y); 1093 | if (dIndex < *dets) { 1094 | scalar_t detx = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * cosval; 1095 | scalar_t dety = ((dIndex + 1 - *dets / 2) * virdDet + virshift) * sinval; 1096 | dPoint[tx] = map_y(sourcex, sourcey, detx, dety); 1097 | } 1098 | } 1099 | __syncthreads(); 1100 | if (dIndex < *dets) { 1101 | if (ang_error >= 3 && ang_error < 7) { 1102 | prj[tx] = projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex]; 1103 | } else { 1104 | prj[tx] = projection[idxchannel][0][idxview][dIndex]; 1105 | } 1106 | } else { 1107 | prj[tx] = 0; 1108 | } 1109 | __syncthreads(); 1110 | 1111 | for (int i = 0; i < ceil(*width / blockDim.x); i++){ 1112 | int idxcol = i 
* blockDim.x + tx; 1113 | if (idxcol < *width) { 1114 | scalar_t i0x = (idxcol - *width / 2 + 0.5) * dImage; 1115 | scalar_t i0y = - *height / 2 * dImage; 1116 | int idx0row = floor(((i0x - sourcex) / (d0x - sourcex) * 1117 | (d0y - sourcey) + sourcey - i0y) / dImage); 1118 | idx0row = max(idx0row, 0); 1119 | i0y += idx0row * dImage; 1120 | scalar_t threadprj = 0; 1121 | scalar_t prebound = map_y(sourcex, sourcey, i0x, i0y); 1122 | scalar_t prepixbound = prebound; 1123 | prebound = max(prebound, dPoint0); 1124 | i0y += dImage; 1125 | scalar_t pixbound = map_y(sourcex, sourcey, i0x, i0y); 1126 | scalar_t detbound = dPoint[0]; 1127 | int idxd = 0, idxi = idx0row; 1128 | while (idxi < *height && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 1129 | if (detbound <= prebound) { 1130 | idxd ++; 1131 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 1132 | } else if (pixbound <= prebound) { 1133 | idxi ++; 1134 | i0y += dImage; 1135 | prepixbound = pixbound; 1136 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 1137 | } else if (pixbound <= detbound) { 1138 | threadprj += (pixbound - prebound) * prj[idxd] / (pixbound - prepixbound); 1139 | prebound = pixbound; 1140 | threadprj *= cweight(sourcex, sourcey, i0x, i0y - dImage / 2, s0); 1141 | atomicAdd(&(image[idxchannel][idxview][static_cast<int>(*height)-1-idxi][idxcol]), threadprj); 1142 | threadprj = 0; 1143 | idxi ++; 1144 | i0y += dImage; 1145 | prepixbound = pixbound; 1146 | pixbound = map_y(sourcex, sourcey, i0x, i0y); 1147 | } else { 1148 | threadprj += (detbound - prebound) * prj[idxd] / (pixbound - prepixbound); 1149 | prebound = detbound; 1150 | idxd ++; 1151 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 1152 | } 1153 | } 1154 | if (threadprj != 0) { 1155 | threadprj *= cweight(sourcex, sourcey, i0x, i0y - dImage / 2, s0); 1156 | atomicAdd(&(image[idxchannel][idxview][static_cast<int>(*height)-1-idxi][idxcol]), threadprj); 1157 | } 1158 | } 1159 | } 1160 | } 1161 | } 1162 | 1163 | torch::Tensor prj_fan_ed_cuda(torch::Tensor image, torch::Tensor options) { 1164 | cudaSetDevice(image.device().index()); 1165 | auto views = options[0]; 1166 | auto dets = options[1]; 1167 | auto width = options[2]; 1168 | auto height = options[3]; 1169 | auto dImg = options[4]; 1170 | auto dDet = options[5]; 1171 | auto Ang0 = options[6]; 1172 | auto dAng = options[7]; 1173 | auto s2r = options[8]; 1174 | auto d2r = options[9]; 1175 | auto binshift = options[10]; 1176 | const int channels = static_cast<int>(image.size(0)); 1177 | auto projection = torch::empty({channels, 1, views.item<int>(), dets.item<int>()}, image.options()); 1178 | 1179 | int nblocksx = ceil(views.item<float>() / GRID_DIM) * channels; 1180 | int nblocksy = min(views.item<int>(), GRID_DIM); 1181 | int nblocksz = ceil(dets.item<float>() / BLOCK_DIM); 1182 | const dim3 blocks(nblocksx, nblocksy, nblocksz); 1183 | 1184 | AT_DISPATCH_FLOATING_TYPES(image.type(), "fan_beam_equal_distance_projection", ([&] { 1185 | prj_fan_ed<scalar_t><<<blocks, BLOCK_DIM>>>( 1186 | image.packed_accessor<scalar_t,4>(), 1187 | projection.packed_accessor<scalar_t,4>(), 1188 | views.data<scalar_t>(), dets.data<scalar_t>(), width.data<scalar_t>(), 1189 | height.data<scalar_t>(), dImg.data<scalar_t>(), dDet.data<scalar_t>(), 1190 | Ang0.data<scalar_t>(), dAng.data<scalar_t>(), s2r.data<scalar_t>(), 1191 | d2r.data<scalar_t>(), binshift.data<scalar_t>() 1192 | ); 1193 | })); 1194 | return projection; 1195 | } 1196 |
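// Host-side usage sketch (illustrative only, not part of the library): the
// options tensor parsed above packs the geometry as
// [views, dets, width, height, dImg, dDet, Ang0, dAng, s2r, d2r, binshift];
// it must live on the same CUDA device and have the same dtype as the image.
// The numbers below are made-up example values, assuming a 256x256 image,
// 512 views and 736 detector bins:
//
//   auto opt = torch::dtype(torch::kFloat32).device(torch::kCUDA);
//   auto geometry = torch::tensor({512.f, 736.f, 256.f, 256.f,
//                                  0.0066f, 0.0013f, 0.f, 2.f * 3.14159265f / 512.f,
//                                  0.595f, 0.4905f, 0.f}, opt);
//   auto image = torch::rand({1, 1, 256, 256}, opt);
//   auto sino = prj_fan_ed_cuda(image, geometry);  // -> {1, 1, 512, 736}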
1197 | torch::Tensor prj_t_fan_ed_cuda(torch::Tensor projection, torch::Tensor options) { 1198 | cudaSetDevice(projection.device().index()); 1199 | auto views = options[0]; 1200 | auto dets = options[1]; 1201 | auto width = options[2]; 1202 | auto height = options[3]; 1203 | auto dImg = options[4]; 1204 | auto dDet = options[5]; 1205 | auto Ang0 = options[6]; 1206 | auto dAng = options[7]; 1207 | auto s2r = options[8]; 1208 | auto d2r = options[9]; 1209 | auto binshift = options[10]; 1210 | const int channels = static_cast<int>(projection.size(0)); 1211 | auto image = torch::zeros({channels, 1, height.item<int>(), width.item<int>()}, projection.options()); 1212 | 1213 | int nblocksx = ceil(views.item<float>() / GRID_DIM) * channels; 1214 | int nblocksy = min(views.item<int>(), GRID_DIM); 1215 | int nblocksz = ceil(dets.item<float>() / BLOCK_DIM); 1216 | const dim3 blocks(nblocksx, nblocksy, nblocksz); 1217 | 1218 | AT_DISPATCH_FLOATING_TYPES(projection.type(), "fan_beam_equal_distance_backprojection", ([&] { 1219 | prj_t_fan_ed<scalar_t><<<blocks, BLOCK_DIM>>>( 1220 | image.packed_accessor<scalar_t,4>(), 1221 | projection.packed_accessor<scalar_t,4>(), 1222 | views.data<scalar_t>(), dets.data<scalar_t>(), width.data<scalar_t>(), 1223 | height.data<scalar_t>(), dImg.data<scalar_t>(), dDet.data<scalar_t>(), 1224 | Ang0.data<scalar_t>(), dAng.data<scalar_t>(), s2r.data<scalar_t>(), 1225 | d2r.data<scalar_t>(), binshift.data<scalar_t>() 1226 | ); 1227 | })); 1228 | return image; 1229 | } 1230 | 1231 | torch::Tensor bprj_t_fan_ed_cuda(torch::Tensor image, torch::Tensor options) { 1232 | cudaSetDevice(image.device().index()); 1233 | auto views = options[0]; 1234 | auto dets = options[1]; 1235 | auto width = options[2]; 1236 | auto height = options[3]; 1237 | auto dImg = options[4]; 1238 | auto dDet = options[5]; 1239 | auto Ang0 = options[6]; 1240 | auto dAng = options[7]; 1241 | auto s2r = options[8]; 1242 | auto d2r = options[9]; 1243 | auto binshift = options[10]; 1244 | const int channels = static_cast<int>(image.size(0)); 1245 | auto projection = torch::empty({channels, 1, views.item<int>(), dets.item<int>()}, image.options()); 1246 | 1247 | int nblocksx = ceil(views.item<float>() / GRID_DIM) * channels; 1248 | int nblocksy = min(views.item<int>(), GRID_DIM); 1249 | int nblocksz = ceil(dets.item<float>() / BLOCK_DIM); 1250 | const dim3 blocks(nblocksx, nblocksy, nblocksz); 1251 | 1252 | AT_DISPATCH_FLOATING_TYPES(image.type(), "fan_beam_equal_distance_fbp_projection", ([&] { 1253 | bprj_t_fan_ed<scalar_t><<<blocks, BLOCK_DIM>>>( 1254 | image.packed_accessor<scalar_t,4>(), 1255 | projection.packed_accessor<scalar_t,4>(), 1256 | views.data<scalar_t>(), dets.data<scalar_t>(), width.data<scalar_t>(), 1257 | height.data<scalar_t>(), dImg.data<scalar_t>(), dDet.data<scalar_t>(), 1258 | Ang0.data<scalar_t>(), dAng.data<scalar_t>(), s2r.data<scalar_t>(), 1259 | d2r.data<scalar_t>(), binshift.data<scalar_t>() 1260 | ); 1261 | })); 1262 | return projection; 1263 | } 1264 |
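// Note the asymmetry among the host entry points: prj_t_fan_ed above is the
// exact transpose of prj_fan_ed (useful as the gradient of the projector),
// while bprj_fan_ed below applies the FBP distance weighting cweight and is
// therefore a weighted backprojector rather than the adjoint; bprj_t_fan_ed
// is its transpose. A quick adjointness check for the prj/prj_t pair
// (sketch, with opt and geometry as in the usage example above):
//
//   auto x = torch::rand({1, 1, 256, 256}, opt);
//   auto y = torch::rand({1, 1, 512, 736}, opt);
//   auto lhs = (prj_fan_ed_cuda(x, geometry) * y).sum();    // <Ax, y>
//   auto rhs = (x * prj_t_fan_ed_cuda(y, geometry)).sum();  // <x, A^T y>
//   // lhs and rhs should agree up to floating-point round-off.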
1265 | torch::Tensor bprj_fan_ed_cuda(torch::Tensor projection, torch::Tensor options) { 1266 | cudaSetDevice(projection.device().index()); 1267 | auto views = options[0]; 1268 | auto dets = options[1]; 1269 | auto width = options[2]; 1270 | auto height = options[3]; 1271 | auto dImg = options[4]; 1272 | auto dDet = options[5]; 1273 | auto Ang0 = options[6]; 1274 | auto dAng = options[7]; 1275 | auto s2r = options[8]; 1276 | auto d2r = options[9]; 1277 | auto binshift = options[10]; 1278 | const int channels = static_cast<int>(projection.size(0)); 1279 | auto image = torch::zeros({channels, 1, height.item<int>(), width.item<int>()}, projection.options()); 1280 | 1281 | int nblocksx = ceil(views.item<float>() / GRID_DIM) * channels; 1282 | int nblocksy = min(views.item<int>(), GRID_DIM); 1283 | int nblocksz = ceil(dets.item<float>() / BLOCK_DIM); 1284 | const dim3 blocks(nblocksx, nblocksy, nblocksz); 1285 | 1286 | AT_DISPATCH_FLOATING_TYPES(projection.type(), "fan_beam_equal_distance_fbp_backprojection", ([&] { 1287 | bprj_fan_ed<scalar_t><<<blocks, BLOCK_DIM>>>( 1288 | image.packed_accessor<scalar_t,4>(), 1289 | projection.packed_accessor<scalar_t,4>(), 1290 | views.data<scalar_t>(), dets.data<scalar_t>(), width.data<scalar_t>(), 1291 | height.data<scalar_t>(), dImg.data<scalar_t>(), dDet.data<scalar_t>(), 1292 | Ang0.data<scalar_t>(), dAng.data<scalar_t>(), s2r.data<scalar_t>(), 1293 | d2r.data<scalar_t>(), binshift.data<scalar_t>() 1294 | ); 1295 | })); 1296 | return image; 1297 | } 1298 | 1299 | torch::Tensor bprj_sv_fan_ed_cuda(torch::Tensor projection, torch::Tensor options) { 1300 | cudaSetDevice(projection.device().index()); 1301 | auto views = options[0]; 1302 | auto dets = options[1]; 1303 | auto width = options[2]; 1304 | auto height = options[3]; 1305 | auto dImg = options[4]; 1306 | auto dDet = options[5]; 1307 | auto Ang0 = options[6]; 1308 | auto dAng = options[7]; 1309 | auto s2r = options[8]; 1310 | auto d2r = options[9]; 1311 | auto binshift = options[10]; 1312 | const int channels = static_cast<int>(projection.size(0)); 1313 | auto image = torch::zeros({channels, views.item<int>(), height.item<int>(), width.item<int>()}, projection.options()); 1314 | 1315 | int nblocksx = ceil(views.item<float>() / GRID_DIM) * channels; 1316 | int nblocksy = min(views.item<int>(), GRID_DIM); 1317 | int nblocksz = ceil(dets.item<float>() / BLOCK_DIM); 1318 | const dim3 blocks(nblocksx, nblocksy, nblocksz); 1319 | 1320 | AT_DISPATCH_FLOATING_TYPES(projection.type(), "fan_beam_equal_distance_backprojection_single_view", ([&] { 1321 | bprj_sv_fan_ed<scalar_t><<<blocks, BLOCK_DIM>>>( 1322 | image.packed_accessor<scalar_t,4>(), 1323 | projection.packed_accessor<scalar_t,4>(), 1324 | views.data<scalar_t>(), dets.data<scalar_t>(), width.data<scalar_t>(), 1325 | height.data<scalar_t>(), dImg.data<scalar_t>(), dDet.data<scalar_t>(), 1326 | Ang0.data<scalar_t>(), dAng.data<scalar_t>(), s2r.data<scalar_t>(), 1327 | d2r.data<scalar_t>(), binshift.data<scalar_t>() 1328 | ); 1329 | })); 1330 | return image; 1331 | } 1332 |
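// bprj_sv_fan_ed_cuda above differs from bprj_fan_ed only in its output
// layout: each view's contribution is written to its own slice of a
// {B, views, H, W} tensor instead of being accumulated, so summing over
// dim 1 should match the plain weighted backprojection (sketch):
//
//   auto per_view = bprj_sv_fan_ed_cuda(sino, geometry);  // {1, 512, 256, 256}
//   auto full = per_view.sum(1, true);                    // ~ bprj_fan_ed_cuda(sino, geometry)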
1333 | torch::Tensor fbp_fan_ed_cuda(torch::Tensor projection, torch::Tensor options) { 1334 | cudaSetDevice(projection.device().index()); 1335 | auto views = options[0]; 1336 | auto dets = options[1]; 1337 | auto width = options[2]; 1338 | auto height = options[3]; 1339 | auto dImg = options[4]; 1340 | auto dDet = options[5]; 1341 | auto Ang0 = options[6]; 1342 | auto dAng = options[7]; 1343 | auto s2r = options[8]; 1344 | auto d2r = options[9]; 1345 | auto binshift = options[10]; 1346 | const int channels = static_cast<int>(projection.size(0)); 1347 | auto image = torch::zeros({channels, 1, height.item<int>(), width.item<int>()}, projection.options()); 1348 | auto filter = torch::empty({1,1,1,dets.item<int>()*2-1}, projection.options()); 1349 | auto virdDet = dDet * s2r / (s2r + d2r); 1350 | auto rectweight = torch::arange((-dets.item<float>()/2+0.5), dets.item<float>()/2, 1, projection.options()); 1351 | rectweight = rectweight * virdDet; 1352 | rectweight = torch::pow(rectweight, 2); 1353 | rectweight = rectweight.view({1, 1, 1, dets.item<int>()}); 1354 | rectweight = s2r / torch::sqrt(s2r * s2r + rectweight); 1355 | rectweight = projection * rectweight * virdDet; 1356 | 1357 | int filterdim = ceil((dets.item<float>()*2-1) / BLOCK_DIM); 1358 | int nblocksx = ceil(views.item<float>() / GRID_DIM) * channels; 1359 | int nblocksy = min(views.item<int>(), GRID_DIM); 1360 | int nblocksz = ceil(dets.item<float>() / BLOCK_DIM); 1361 | const dim3 blocks(nblocksx, nblocksy, nblocksz); 1362 | 1363 | AT_DISPATCH_FLOATING_TYPES(projection.type(), "ramp_filter", ([&] { 1364 | rlfilter<scalar_t><<<filterdim, BLOCK_DIM>>>( 1365 | filter.data<scalar_t>(), dets.data<scalar_t>(), virdDet.data<scalar_t>()); 1366 | })); 1367 | 1368 | auto filtered_projection = torch::conv2d(rectweight, filter, {}, 1, torch::IntArrayRef({0, dets.item<int>()-1})); 1369 | 1370 | AT_DISPATCH_FLOATING_TYPES(projection.type(), "fan_beam_equal_distance_fbp_backprojection", ([&] { 1371 | bprj_fan_ed<scalar_t><<<blocks, BLOCK_DIM>>>( 1372 | image.packed_accessor<scalar_t,4>(), 1373 | filtered_projection.packed_accessor<scalar_t,4>(), 1374 | views.data<scalar_t>(), dets.data<scalar_t>(), width.data<scalar_t>(), 1375 | height.data<scalar_t>(), dImg.data<scalar_t>(), dDet.data<scalar_t>(), 1376 | Ang0.data<scalar_t>(), dAng.data<scalar_t>(), s2r.data<scalar_t>(), 1377 | d2r.data<scalar_t>(), binshift.data<scalar_t>() 1378 | ); 1379 | })); 1380 | image = image * dAng / 2; 1381 | return image; 1382 | } 1383 | -------------------------------------------------------------------------------- /src/laplacian_cuda.h: -------------------------------------------------------------------------------- 1 | #ifndef LAPLACIAN_CUDA_H 2 | #define LAPLACIAN_CUDA_H 3 | 4 | #include <torch/extension.h> 5 | 6 | torch::Tensor laplacian_cuda_forward(torch::Tensor input, int k); 7 | 8 | #endif -------------------------------------------------------------------------------- /src/laplacian_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <cuda.h> 3 | #include <cuda_runtime.h> 4 | #include <cublas_v2.h> 5 | #include <vector> 6 | #define BLOCK_DIM_1 16 7 | #define BLOCK_DIM_2 256 8 | 9 | template <typename scalar_t> 10 | __global__ void compute_squared_norm( 11 | const torch::PackedTensorAccessor<scalar_t,2> array, 12 | int point_num, int dimension, scalar_t* __restrict__ norm){ 13 | unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x; 14 | if (xIndex < point_num){ 15 | scalar_t sum = 0; 16 | for (int i = 0; i < dimension; i++){ 17 | scalar_t val = array[xIndex][i]; 18 | sum += val * val; 19 | } 20 | norm[xIndex] = sum; 21 | } 22 | } 23 | 24 | template <typename scalar_t> 25 | __global__ void add_norm( 26 | torch::PackedTensorAccessor<scalar_t,2> array, 27 | int point_num, scalar_t* __restrict__ norm){ 28 | unsigned int tx = threadIdx.x; 29 | unsigned int ty = threadIdx.y; 30 | unsigned int xIndex = blockIdx.x * blockDim.x + tx; 31 | unsigned int yIndex = blockIdx.y * blockDim.y + ty; 32 | unsigned int yyIndex = blockIdx.x * blockDim.x + ty; 33 | __shared__ scalar_t shared_vec[2][BLOCK_DIM_1]; 34 | if (tx == 0 && yIndex < point_num) 35 | shared_vec[0][ty] = norm[yIndex]; 36 | else if (tx == 1 && yyIndex < point_num) 37 | shared_vec[1][ty] = norm[yyIndex]; 38 | __syncthreads(); 39 | if (xIndex < point_num && yIndex < point_num){ 40 | array[xIndex][yIndex] = shared_vec[0][ty]; 41 | array[xIndex][yIndex] += shared_vec[1][tx]; 42 | } 43 | } 44 | 45 | 46 | torch::Tensor laplacian_cuda_forward(torch::Tensor input, int k) { 47 | cudaSetDevice(input.device().index()); 48 | const int point_num = (int)input.size(0); 49 | const int dimension = (int)input.size(1); 50 | auto options = input.options(); 51 | auto norm2 = torch::empty({point_num}, options); 52 | auto dist = torch::empty({point_num, point_num}, options); 53 | int n_block1 = point_num / BLOCK_DIM_1; 54 | if (point_num % BLOCK_DIM_1 != 0) n_block1 += 1; 55 | const dim3 threads1(BLOCK_DIM_1, BLOCK_DIM_1); 56 | const dim3 blocks1(n_block1, n_block1); 57 | 58 | int n_block2 = point_num / BLOCK_DIM_2; 59 | if (point_num % BLOCK_DIM_2 != 0) n_block2 += 1; 60 | const dim3 threads2(BLOCK_DIM_2); 61 | const dim3 blocks2(n_block2); 62 | 63 | cublasStatus_t stat; 64 | cublasHandle_t handle; 65 | 66 | stat = cublasCreate(&handle); 67 | if (stat != CUBLAS_STATUS_SUCCESS) { 68 | printf("CUBLAS initialization failed\n"); 69 | } 70 | 71 | AT_DISPATCH_FLOATING_TYPES(input.type(), "compute_squared_norm", ([&] { 72 | compute_squared_norm<scalar_t><<<blocks2, threads2>>>( 73 | input.packed_accessor<scalar_t,2>(), 74 | point_num, dimension, norm2.data<scalar_t>() 75 | ); 76 | })); 77 | 78 | AT_DISPATCH_FLOATING_TYPES(input.type(), "add_norm", ([&] { 79 | add_norm<scalar_t><<<blocks1, threads1>>>( 80 | dist.packed_accessor<scalar_t,2>(), 81 | point_num, norm2.data<scalar_t>() 82 | ); 83 | })); 84 |
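// At this point dist[i][j] holds ||x_i||^2 + ||x_j||^2. The GEMM below adds
// -2 * X * X^T (alpha = -2, beta = 1), completing the identity
// ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>.
// cuBLAS sees the row-major {point_num, dimension} input as a column-major
// {dimension, point_num} matrix, so OP_T x OP_N produces the
// {point_num, point_num} Gram matrix of pairwise dot products in one call.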
85 | if (input.dtype() == torch::kFloat32){ 86 | float alpha = -2.0; 87 | float beta = 1.0; 88 | stat = cublasSgemm( 89 | handle, CUBLAS_OP_T, CUBLAS_OP_N, point_num, 90 | point_num, dimension, &alpha, 91 | input.data<float>(), dimension, 92 | input.data<float>(), dimension, 93 | &beta, 94 | dist.data<float>(), point_num 95 | ); 96 | } else { 97 | double alpha = -2.0; 98 | double beta = 1.0; 99 | stat = cublasDgemm( 100 | handle, CUBLAS_OP_T, CUBLAS_OP_N, point_num, 101 | point_num, dimension, &alpha, 102 | input.data<double>(), dimension, 103 | input.data<double>(), dimension, 104 | &beta, 105 | dist.data<double>(), point_num 106 | ); 107 | } 108 | if (stat != CUBLAS_STATUS_SUCCESS) { 109 | printf("CUBLAS computation failed\n"); 110 | } 111 | cublasDestroy(handle); 112 | auto topkRes = torch::topk(dist, k, 1, false); 113 | auto coef = std::get<0>(topkRes); 114 | auto indj = std::get<1>(topkRes); 115 | auto median = torch::median(coef); 116 | coef = coef / median; 117 | coef = torch::exp(-coef); 118 | auto coef_sum = coef.sum(1,true); 119 | coef = coef / coef_sum; 120 | coef = coef.view(-1); 121 | options = options.dtype(torch::kLong); 122 | auto indi = torch::arange(point_num, options).unsqueeze_(1).repeat({1, k}); 123 | auto ind1 = torch::stack({indi, indj}).view({2, -1}); 124 | options = options.dtype(input.dtype()); // match the dtype of coef; hard-coding kFloat32 here would break double inputs 125 | options = options.layout(torch::kSparse); 126 | auto W = torch::sparse_coo_tensor(ind1, coef, {point_num, point_num}, options); 127 | return W; 128 | } -------------------------------------------------------------------------------- /src/para_cuda.h: -------------------------------------------------------------------------------- 1 | #ifndef PARA_CUDA_H 2 | #define PARA_CUDA_H 3 | #include <torch/extension.h> 4 | 5 | torch::Tensor prj_para_cuda(torch::Tensor image, torch::Tensor options); 6 | torch::Tensor prj_t_para_cuda(torch::Tensor projection, torch::Tensor options); 7 | torch::Tensor bprj_t_para_cuda(torch::Tensor image, torch::Tensor options); 8 | torch::Tensor bprj_para_cuda(torch::Tensor projection, torch::Tensor options); 9 | torch::Tensor fbp_para_cuda(torch::Tensor projection, torch::Tensor options); 10 | 11 | #endif -------------------------------------------------------------------------------- /src/para_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <cuda.h> 3 | 4 | #define BLOCK_DIM 256 5 | #define GRID_DIM 512 6 | 7 | template <typename scalar_t> 8 | __global__ void prj_para( 9 | const torch::PackedTensorAccessor<scalar_t,4> image, 10 | torch::PackedTensorAccessor<scalar_t,4> projection, 11 | const scalar_t* __restrict__ views, const scalar_t* __restrict__ dets, 12 | const scalar_t* __restrict__ width, const scalar_t* __restrict__ height, 13 | const scalar_t* __restrict__ dImg, const scalar_t* __restrict__ dDet, 14 | const scalar_t* __restrict__ Ang0, const scalar_t* __restrict__ dAng, 15 | const scalar_t* __restrict__ binshift) { 16 | 17 | __shared__ unsigned int nblocks; 18 | __shared__ unsigned int idxchannel; 19 | __shared__ unsigned int idxview; 20 | nblocks = ceil(*views / gridDim.y); 21 | idxchannel = blockIdx.x % nblocks; 22 | idxview = idxchannel * gridDim.y + blockIdx.y; 23 | if (idxview >= *views) return; 24 | idxchannel = blockIdx.x / nblocks; 25 | __shared__ scalar_t prj[BLOCK_DIM]; 26 | __shared__ scalar_t dPoint[BLOCK_DIM]; 27 | __shared__ scalar_t dImage; 28 | __shared__ scalar_t dPoint0; 29 | __shared__ double ang; 30 | __shared__ double PI; 31 | __shared__ double ang_error; 32 | __shared__ double cosval; 33 | __shared__ double sinval; 34 | __shared__ unsigned int dIndex0; 35 | __shared__ scalar_t dinterval; 36 | 37 | PI = acos(-1.0); 38 | ang = idxview * *dAng + *Ang0; 39 | dImage = *dImg; 40 | ang_error = abs(ang - round(ang / PI) * PI) * 4 / PI; 41 | cosval = cos(ang); 42 | sinval = sin(ang); 43 | dinterval = dImage / *dDet;
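// Parallel-beam geometry: there is no source position or magnification, so the
// detector boundaries dPoint are obtained below by projecting bin edges
// directly onto the image axes. dinterval = dImg / dDet is the distance-driven
// weight that turns the pixel/bin overlap lengths, which the sweep measures in
// image coordinates, into normalized per-bin contributions. As in the fan-beam
// kernels, ang_error <= 1 picks the views whose rays run closer to the
// vertical axis (within 45 degrees); those are accumulated row by row, the
// remaining views column by column.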
44 | dIndex0 = blockIdx.z * blockDim.x; 45 | unsigned int tx = threadIdx.x; 46 | unsigned int dIndex = dIndex0 + tx; 47 | prj[tx] = 0; 48 | __syncthreads(); 49 | if (ang_error <= 1) { 50 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 51 | if (ang_error >= 3 && ang_error < 7) { 52 | dPoint0 = ((*dets / 2 - dIndex0) * *dDet + *binshift) / cosval; 53 | if (dIndex < *dets) { 54 | dPoint[tx] = ((*dets / 2 - dIndex - 1) * *dDet + *binshift) / cosval; 55 | } 56 | } else { 57 | dPoint0 = ((dIndex0 - *dets / 2) * *dDet + *binshift) / cosval; 58 | if (dIndex < *dets) { 59 | dPoint[tx] = ((dIndex + 1 - *dets / 2) * *dDet + *binshift) / cosval; 60 | } 61 | } 62 | __syncthreads(); 63 | for (int i = 0; i < ceil(*height / blockDim.x); i++){ 64 | int idxrow = i * blockDim.x + tx; 65 | if (idxrow < *height) { 66 | scalar_t i0y = (*height / 2 - idxrow - 0.5) * dImage; 67 | scalar_t i0x = - *width / 2 * dImage; 68 | scalar_t pixbound = sinval / cosval * i0y + i0x; 69 | int idx0col = floor((dPoint0 - pixbound) / dImage); 70 | idx0col = max(idx0col, 0); 71 | pixbound += idx0col * dImage; 72 | scalar_t threadprj = 0; 73 | scalar_t prebound = max(pixbound, dPoint0); 74 | pixbound += dImage; 75 | scalar_t detbound = dPoint[0]; 76 | int idxd = 0, idxi = idx0col; 77 | while (idxi < *width && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 78 | if (detbound <= prebound) { 79 | idxd ++; 80 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 81 | }else if (pixbound <= prebound){ 82 | idxi ++; 83 | pixbound += dImage; 84 | }else if (pixbound < detbound) { 85 | threadprj += (pixbound - prebound) * image[idxchannel][0][idxrow][idxi] * dinterval; 86 | prebound = pixbound; 87 | idxi ++; 88 | pixbound += dImage; 89 | } else { 90 | threadprj += (detbound - prebound) * image[idxchannel][0][idxrow][idxi] * dinterval; 91 | prebound = detbound; 92 | atomicAdd(prj+idxd, threadprj); 93 | threadprj = 0; 94 | idxd ++; 95 | if (idxd < blockDim.x) detbound = dPoint[idxd]; 96 | } 97 | } 98 | if (threadprj != 0) atomicAdd(prj+idxd, threadprj); 99 | } 100 | } 101 | __syncthreads(); 102 | if (dIndex < *dets) { 103 | if (ang_error >= 3 && ang_error < 7) { 104 | projection[idxchannel][0][idxview][static_cast(*dets)-1-dIndex] = prj[tx]; 105 | } else { 106 | projection[idxchannel][0][idxview][dIndex] = prj[tx]; 107 | } 108 | } 109 | } else { 110 | ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI; 111 | if (ang_error >= 3 && ang_error < 7) { 112 | dPoint0 = ((*dets / 2 - dIndex0) * *dDet + *binshift) / sinval; 113 | if (dIndex < *dets) { 114 | dPoint[tx] = ((*dets / 2 - dIndex - 1) * *dDet + *binshift) / sinval; 115 | } 116 | } else { 117 | dPoint0 = ((dIndex0 - *dets / 2) * *dDet + *binshift) / sinval; 118 | if (dIndex < *dets) { 119 | dPoint[tx] = ((dIndex + 1 - *dets / 2) * *dDet + *binshift) / sinval; 120 | } 121 | } 122 | __syncthreads(); 123 | for (int i = 0; i < ceil(*width / blockDim.x); i++){ 124 | int idxcol = i * blockDim.x + tx; 125 | if (idxcol < *width) { 126 | scalar_t i0x = (idxcol - *width / 2 + 0.5) * dImage; 127 | scalar_t i0y = - *height / 2 * dImage; 128 | scalar_t pixbound = i0y + cosval / sinval * i0x; 129 | int idx0row = floor((dPoint0 - pixbound) / dImage); 130 | idx0row = max(idx0row, 0); 131 | pixbound += idx0row * dImage; 132 | scalar_t threadprj = 0; 133 | scalar_t prebound = max(pixbound, dPoint0); 134 | pixbound += dImage; 135 | scalar_t detbound = dPoint[0]; 136 | int idxd = 0, idxi = idx0row; 137 | while (idxi < *height && (idxd + dIndex0) < *dets && idxd < blockDim.x) { 138 | if (detbound 
template <typename scalar_t>
__global__ void prj_t_para(
    torch::PackedTensorAccessor<scalar_t,4> image,
    const torch::PackedTensorAccessor<scalar_t,4> projection,
    const scalar_t* __restrict__ views, const scalar_t* __restrict__ dets,
    const scalar_t* __restrict__ width, const scalar_t* __restrict__ height,
    const scalar_t* __restrict__ dImg, const scalar_t* __restrict__ dDet,
    const scalar_t* __restrict__ Ang0, const scalar_t* __restrict__ dAng,
    const scalar_t* __restrict__ binshift) {

    __shared__ unsigned int nblocks;
    __shared__ unsigned int idxchannel;
    __shared__ unsigned int idxview;
    nblocks = ceil(*views / gridDim.y);
    idxchannel = blockIdx.x % nblocks;
    idxview = idxchannel * gridDim.y + blockIdx.y;
    if (idxview >= *views) return;
    idxchannel = blockIdx.x / nblocks;
    __shared__ scalar_t prj[BLOCK_DIM];
    __shared__ scalar_t dPoint[BLOCK_DIM];
    __shared__ scalar_t dImage;
    __shared__ scalar_t dPoint0;
    __shared__ double ang;
    __shared__ double PI;
    __shared__ double ang_error;
    __shared__ double cosval;
    __shared__ double sinval;
    __shared__ unsigned int dIndex0;
    __shared__ scalar_t dinterval;

    PI = acos(-1.0);
    ang = idxview * *dAng + *Ang0;
    dImage = *dImg;
    ang_error = abs(ang - round(ang / PI) * PI) * 4 / PI;
    cosval = cos(ang);
    sinval = sin(ang);
    dinterval = dImage / *dDet;
    dIndex0 = blockIdx.z * blockDim.x;
    unsigned int tx = threadIdx.x;
    unsigned int dIndex = dIndex0 + tx;
    __syncthreads();
    if (ang_error <= 1) {
        ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI;
        if (ang_error >= 3 && ang_error < 7) {
            dPoint0 = ((*dets / 2 - dIndex0) * *dDet + *binshift) / cosval;
            if (dIndex < *dets) {
                dPoint[tx] = ((*dets / 2 - dIndex - 1) * *dDet + *binshift) / cosval;
            }
        } else {
            dPoint0 = ((dIndex0 - *dets / 2) * *dDet + *binshift) / cosval;
            if (dIndex < *dets) {
                dPoint[tx] = ((dIndex + 1 - *dets / 2) * *dDet + *binshift) / cosval;
            }
        }
        if (dIndex < *dets) {
            if (ang_error >= 3 && ang_error < 7) {
                prj[tx] = projection[idxchannel][0][idxview][static_cast<int>(*dets) - 1 - dIndex];
            } else {
                prj[tx] = projection[idxchannel][0][idxview][dIndex];
            }
        } else {
            prj[tx] = 0;
        }
        __syncthreads();
        for (int i = 0; i < ceil(*height / blockDim.x); i++) {
            int idxrow = i * blockDim.x + tx;
            if (idxrow < *height) {
                scalar_t i0y = (*height / 2 - idxrow - 0.5) * dImage;
                scalar_t i0x = - *width / 2 * dImage;
                scalar_t pixbound = sinval / cosval * i0y + i0x;
                int idx0col = floor((dPoint0 - pixbound) / dImage);
                idx0col = max(idx0col, 0);
                pixbound += idx0col * dImage;
                scalar_t threadprj = 0;
                scalar_t prebound = max(pixbound, dPoint0);
                pixbound += dImage;
                scalar_t detbound = dPoint[0];
                int idxd = 0, idxi = idx0col;
                while (idxi < *width && (idxd + dIndex0) < *dets && idxd < blockDim.x) {
                    if (detbound <= prebound) {
                        idxd++;
                        if (idxd < blockDim.x) detbound = dPoint[idxd];
                    } else if (pixbound <= prebound) {
                        idxi++;
                        pixbound += dImage;
                    } else if (pixbound <= detbound) {
                        threadprj += (pixbound - prebound) * prj[idxd] * dinterval;
                        prebound = pixbound;
                        atomicAdd(&(image[idxchannel][0][idxrow][idxi]), threadprj);
                        threadprj = 0;
                        idxi++;
                        pixbound += dImage;
                    } else {
                        threadprj += (detbound - prebound) * prj[idxd] * dinterval;
                        prebound = detbound;
                        idxd++;
                        if (idxd < blockDim.x) detbound = dPoint[idxd];
                    }
                }
                if (threadprj != 0) atomicAdd(&(image[idxchannel][0][idxrow][idxi]), threadprj);
            }
        }
    } else {
        ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI;
        if (ang_error >= 3 && ang_error < 7) {
            dPoint0 = ((*dets / 2 - dIndex0) * *dDet + *binshift) / sinval;
            if (dIndex < *dets) {
                dPoint[tx] = ((*dets / 2 - dIndex - 1) * *dDet + *binshift) / sinval;
            }
        } else {
            dPoint0 = ((dIndex0 - *dets / 2) * *dDet + *binshift) / sinval;
            if (dIndex < *dets) {
                dPoint[tx] = ((dIndex + 1 - *dets / 2) * *dDet + *binshift) / sinval;
            }
        }
        if (dIndex < *dets) {
            if (ang_error >= 3 && ang_error < 7) {
                prj[tx] = projection[idxchannel][0][idxview][static_cast<int>(*dets) - 1 - dIndex];
            } else {
                prj[tx] = projection[idxchannel][0][idxview][dIndex];
            }
        } else {
            prj[tx] = 0;
        }
        __syncthreads();
        for (int i = 0; i < ceil(*width / blockDim.x); i++) {
            int idxcol = i * blockDim.x + tx;
            if (idxcol < *width) {
                scalar_t i0x = (idxcol - *width / 2 + 0.5) * dImage;
                scalar_t i0y = - *height / 2 * dImage;
                scalar_t pixbound = i0y + cosval / sinval * i0x;
                int idx0row = floor((dPoint0 - pixbound) / dImage);
                idx0row = max(idx0row, 0);
                pixbound += idx0row * dImage;
                scalar_t threadprj = 0;
                scalar_t prebound = max(pixbound, dPoint0);
                pixbound += dImage;
                scalar_t detbound = dPoint[0];
                int idxd = 0, idxi = idx0row;
                while (idxi < *height && (idxd + dIndex0) < *dets && idxd < blockDim.x) {
                    if (detbound <= prebound) {
                        idxd++;
                        if (idxd < blockDim.x) detbound = dPoint[idxd];
                    } else if (pixbound <= prebound) {
                        idxi++;
                        pixbound += dImage;
                    } else if (pixbound <= detbound) {
                        threadprj += (pixbound - prebound) * prj[idxd] * dinterval;
                        prebound = pixbound;
                        atomicAdd(&(image[idxchannel][0][static_cast<int>(*height) - 1 - idxi][idxcol]), threadprj);
                        threadprj = 0;
                        idxi++;
                        pixbound += dImage;
                    } else {
                        threadprj += (detbound - prebound) * prj[idxd] * dinterval;
                        prebound = detbound;
                        idxd++;
                        if (idxd < blockDim.x) detbound = dPoint[idxd];
                    }
                }
                if (threadprj != 0) atomicAdd(&(image[idxchannel][0][static_cast<int>(*height) - 1 - idxi][idxcol]), threadprj);
            }
        }
    }
}
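// bprj_t_para: transpose of the backprojector. It runs the same gather sweep
// as prj_para, but weights each overlap by 1/dImg instead of dImg/dDet,
// matching the normalization of bprj_para below.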
template <typename scalar_t>
__global__ void bprj_t_para(
    const torch::PackedTensorAccessor<scalar_t,4> image,
    torch::PackedTensorAccessor<scalar_t,4> projection,
    const scalar_t* __restrict__ views, const scalar_t* __restrict__ dets,
    const scalar_t* __restrict__ width, const scalar_t* __restrict__ height,
    const scalar_t* __restrict__ dImg, const scalar_t* __restrict__ dDet,
    const scalar_t* __restrict__ Ang0, const scalar_t* __restrict__ dAng,
    const scalar_t* __restrict__ binshift) {

    __shared__ unsigned int nblocks;
    __shared__ unsigned int idxchannel;
    __shared__ unsigned int idxview;
    nblocks = ceil(*views / gridDim.y);
    idxchannel = blockIdx.x % nblocks;
    idxview = idxchannel * gridDim.y + blockIdx.y;
    if (idxview >= *views) return;
    idxchannel = blockIdx.x / nblocks;
    __shared__ scalar_t prj[BLOCK_DIM];
    __shared__ scalar_t dPoint[BLOCK_DIM];
    __shared__ scalar_t dImage;
    __shared__ scalar_t dPoint0;
    __shared__ double ang;
    __shared__ double PI;
    __shared__ double ang_error;
    __shared__ double cosval;
    __shared__ double sinval;
    __shared__ unsigned int dIndex0;

    PI = acos(-1.0);
    ang = idxview * *dAng + *Ang0;
    dImage = *dImg;
    ang_error = abs(ang - round(ang / PI) * PI) * 4 / PI;
    cosval = cos(ang);
    sinval = sin(ang);
    dIndex0 = blockIdx.z * blockDim.x;
    unsigned int tx = threadIdx.x;
    unsigned int dIndex = dIndex0 + tx;
    prj[tx] = 0;
    __syncthreads();
    if (ang_error <= 1) {
        ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI;
        if (ang_error >= 3 && ang_error < 7) {
            dPoint0 = ((*dets / 2 - dIndex0) * *dDet + *binshift) / cosval;
            if (dIndex < *dets) {
                dPoint[tx] = ((*dets / 2 - dIndex - 1) * *dDet + *binshift) / cosval;
            }
        } else {
            dPoint0 = ((dIndex0 - *dets / 2) * *dDet + *binshift) / cosval;
            if (dIndex < *dets) {
                dPoint[tx] = ((dIndex + 1 - *dets / 2) * *dDet + *binshift) / cosval;
            }
        }
        __syncthreads();
        for (int i = 0; i < ceil(*height / blockDim.x); i++) {
            int idxrow = i * blockDim.x + tx;
            if (idxrow < *height) {
                scalar_t i0y = (*height / 2 - idxrow - 0.5) * dImage;
                scalar_t i0x = - *width / 2 * dImage;
                scalar_t pixbound = sinval / cosval * i0y + i0x;
                int idx0col = floor((dPoint0 - pixbound) / dImage);
                idx0col = max(idx0col, 0);
                pixbound += idx0col * dImage;
                scalar_t threadprj = 0;
                scalar_t prebound = max(pixbound, dPoint0);
                pixbound += dImage;
                scalar_t detbound = dPoint[0];
                int idxd = 0, idxi = idx0col;
                while (idxi < *width && (idxd + dIndex0) < *dets && idxd < blockDim.x) {
                    if (detbound <= prebound) {
                        idxd++;
                        if (idxd < blockDim.x) detbound = dPoint[idxd];
                    } else if (pixbound <= prebound) {
                        idxi++;
                        pixbound += dImage;
                    } else if (pixbound < detbound) {
                        threadprj += (pixbound - prebound) * image[idxchannel][0][idxrow][idxi] / dImage;
                        prebound = pixbound;
                        idxi++;
                        pixbound += dImage;
                    } else {
                        threadprj += (detbound - prebound) * image[idxchannel][0][idxrow][idxi] / dImage;
                        prebound = detbound;
                        atomicAdd(prj + idxd, threadprj);
                        threadprj = 0;
                        idxd++;
                        if (idxd < blockDim.x) detbound = dPoint[idxd];
                    }
                }
                if (threadprj != 0) atomicAdd(prj + idxd, threadprj);
            }
        }
        __syncthreads();
        if (dIndex < *dets) {
            if (ang_error >= 3 && ang_error < 7) {
                projection[idxchannel][0][idxview][static_cast<int>(*dets) - 1 - dIndex] = prj[tx];
            } else {
                projection[idxchannel][0][idxview][dIndex] = prj[tx];
            }
        }
    } else {
        ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI;
        if (ang_error >= 3 && ang_error < 7) {
            dPoint0 = ((*dets / 2 - dIndex0) * *dDet + *binshift) / sinval;
            if (dIndex < *dets) {
                dPoint[tx] = ((*dets / 2 - dIndex - 1) * *dDet + *binshift) / sinval;
            }
        } else {
            dPoint0 = ((dIndex0 - *dets / 2) * *dDet + *binshift) / sinval;
            if (dIndex < *dets) {
                dPoint[tx] = ((dIndex + 1 - *dets / 2) * *dDet + *binshift) / sinval;
            }
        }
        __syncthreads();
        for (int i = 0; i < ceil(*width / blockDim.x); i++) {
            int idxcol = i * blockDim.x + tx;
            if (idxcol < *width) {
                scalar_t i0x = (idxcol - *width / 2 + 0.5) * dImage;
                scalar_t i0y = - *height / 2 * dImage;
                scalar_t pixbound = i0y + cosval / sinval * i0x;
                int idx0row = floor((dPoint0 - pixbound) / dImage);
                idx0row = max(idx0row, 0);
                pixbound += idx0row * dImage;
                scalar_t threadprj = 0;
                scalar_t prebound = max(pixbound, dPoint0);
                pixbound += dImage;
                scalar_t detbound = dPoint[0];
                int idxd = 0, idxi = idx0row;
                while (idxi < *height && (idxd + dIndex0) < *dets && idxd < blockDim.x) {
                    if (detbound <= prebound) {
                        idxd++;
                        if (idxd < blockDim.x) detbound = dPoint[idxd];
                    } else if (pixbound <= prebound) {
                        idxi++;
                        pixbound += dImage;
                    } else if (pixbound < detbound) {
                        threadprj += (pixbound - prebound) * image[idxchannel][0][static_cast<int>(*height) - 1 - idxi][idxcol] / dImage;
                        prebound = pixbound;
                        idxi++;
                        pixbound += dImage;
                    } else {
                        threadprj += (detbound - prebound) * image[idxchannel][0][static_cast<int>(*height) - 1 - idxi][idxcol] / dImage;
                        prebound = detbound;
                        atomicAdd(prj + idxd, threadprj);
                        threadprj = 0;
                        idxd++;
                        if (idxd < blockDim.x) detbound = dPoint[idxd];
                    }
                }
                if (threadprj != 0) atomicAdd(prj + idxd, threadprj);
            }
        }
        __syncthreads();
        if (dIndex < *dets) {
            if (ang_error >= 3 && ang_error < 7) {
                projection[idxchannel][0][idxview][static_cast<int>(*dets) - 1 - dIndex] = prj[tx];
            } else {
                projection[idxchannel][0][idxview][dIndex] = prj[tx];
            }
        }
    }
}
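// bprj_para: distance-driven backprojector. It mirrors prj_t_para's scatter
// into the image but uses the 1/dImg weight; fbp_para_cuda reuses this kernel
// for the backprojection step of filtered backprojection.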
template <typename scalar_t>
__global__ void bprj_para(
    torch::PackedTensorAccessor<scalar_t,4> image,
    const torch::PackedTensorAccessor<scalar_t,4> projection,
    const scalar_t* __restrict__ views, const scalar_t* __restrict__ dets,
    const scalar_t* __restrict__ width, const scalar_t* __restrict__ height,
    const scalar_t* __restrict__ dImg, const scalar_t* __restrict__ dDet,
    const scalar_t* __restrict__ Ang0, const scalar_t* __restrict__ dAng,
    const scalar_t* __restrict__ binshift) {

    __shared__ unsigned int nblocks;
    __shared__ unsigned int idxchannel;
    __shared__ unsigned int idxview;
    nblocks = ceil(*views / gridDim.y);
    idxchannel = blockIdx.x % nblocks;
    idxview = idxchannel * gridDim.y + blockIdx.y;
    if (idxview >= *views) return;
    idxchannel = blockIdx.x / nblocks;
    __shared__ scalar_t prj[BLOCK_DIM];
    __shared__ scalar_t dPoint[BLOCK_DIM];
    __shared__ scalar_t dImage;
    __shared__ scalar_t dPoint0;
    __shared__ double ang;
    __shared__ double PI;
    __shared__ double ang_error;
    __shared__ double cosval;
    __shared__ double sinval;
    __shared__ unsigned int dIndex0;

    PI = acos(-1.0);
    ang = idxview * *dAng + *Ang0;
    dImage = *dImg;
    ang_error = abs(ang - round(ang / PI) * PI) * 4 / PI;
    cosval = cos(ang);
    sinval = sin(ang);
    dIndex0 = blockIdx.z * blockDim.x;
    unsigned int tx = threadIdx.x;
    unsigned int dIndex = dIndex0 + tx;
    __syncthreads();
    if (ang_error <= 1) {
        ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI;
        if (ang_error >= 3 && ang_error < 7) {
            dPoint0 = ((*dets / 2 - dIndex0) * *dDet + *binshift) / cosval;
            if (dIndex < *dets) {
                dPoint[tx] = ((*dets / 2 - dIndex - 1) * *dDet + *binshift) / cosval;
            }
        } else {
            dPoint0 = ((dIndex0 - *dets / 2) * *dDet + *binshift) / cosval;
            if (dIndex < *dets) {
                dPoint[tx] = ((dIndex + 1 - *dets / 2) * *dDet + *binshift) / cosval;
            }
        }
        if (dIndex < *dets) {
            if (ang_error >= 3 && ang_error < 7) {
                prj[tx] = projection[idxchannel][0][idxview][static_cast<int>(*dets) - 1 - dIndex];
            } else {
                prj[tx] = projection[idxchannel][0][idxview][dIndex];
            }
        } else {
            prj[tx] = 0;
        }
        __syncthreads();
        for (int i = 0; i < ceil(*height / blockDim.x); i++) {
            int idxrow = i * blockDim.x + tx;
            if (idxrow < *height) {
                scalar_t i0y = (*height / 2 - idxrow - 0.5) * dImage;
                scalar_t i0x = - *width / 2 * dImage;
                scalar_t pixbound = sinval / cosval * i0y + i0x;
                int idx0col = floor((dPoint0 - pixbound) / dImage);
                idx0col = max(idx0col, 0);
                pixbound += idx0col * dImage;
                scalar_t threadprj = 0;
                scalar_t prebound = max(pixbound, dPoint0);
                pixbound += dImage;
                scalar_t detbound = dPoint[0];
                int idxd = 0, idxi = idx0col;
                while (idxi < *width && (idxd + dIndex0) < *dets && idxd < blockDim.x) {
                    if (detbound <= prebound) {
                        idxd++;
                        if (idxd < blockDim.x) detbound = dPoint[idxd];
                    } else if (pixbound <= prebound) {
                        idxi++;
                        pixbound += dImage;
                    } else if (pixbound <= detbound) {
                        threadprj += (pixbound - prebound) * prj[idxd] / dImage;
                        prebound = pixbound;
                        atomicAdd(&(image[idxchannel][0][idxrow][idxi]), threadprj);
                        threadprj = 0;
                        idxi++;
                        pixbound += dImage;
                    } else {
                        threadprj += (detbound - prebound) * prj[idxd] / dImage;
                        prebound = detbound;
                        idxd++;
                        if (idxd < blockDim.x) detbound = dPoint[idxd];
                    }
                }
                if (threadprj != 0) atomicAdd(&(image[idxchannel][0][idxrow][idxi]), threadprj);
            }
        }
    } else {
        ang_error = (ang - floor(ang / 2 / PI) * 2 * PI) * 4 / PI;
        if (ang_error >= 3 && ang_error < 7) {
            dPoint0 = ((*dets / 2 - dIndex0) * *dDet + *binshift) / sinval;
            if (dIndex < *dets) {
                dPoint[tx] = ((*dets / 2 - dIndex - 1) * *dDet + *binshift) / sinval;
            }
        } else {
            dPoint0 = ((dIndex0 - *dets / 2) * *dDet + *binshift) / sinval;
            if (dIndex < *dets) {
                dPoint[tx] = ((dIndex + 1 - *dets / 2) * *dDet + *binshift) / sinval;
            }
        }
        if (dIndex < *dets) {
            if (ang_error >= 3 && ang_error < 7) {
                prj[tx] = projection[idxchannel][0][idxview][static_cast<int>(*dets) - 1 - dIndex];
            } else {
                prj[tx] = projection[idxchannel][0][idxview][dIndex];
            }
        } else {
            prj[tx] = 0;
        }
        __syncthreads();
        for (int i = 0; i < ceil(*width / blockDim.x); i++) {
            int idxcol = i * blockDim.x + tx;
            if (idxcol < *width) {
                scalar_t i0x = (idxcol - *width / 2 + 0.5) * dImage;
                scalar_t i0y = - *height / 2 * dImage;
                scalar_t pixbound = i0y + cosval / sinval * i0x;
                int idx0row = floor((dPoint0 - pixbound) / dImage);
                idx0row = max(idx0row, 0);
                pixbound += idx0row * dImage;
                scalar_t threadprj = 0;
                scalar_t prebound = max(pixbound, dPoint0);
                pixbound += dImage;
                scalar_t detbound = dPoint[0];
                int idxd = 0, idxi = idx0row;
                while (idxi < *height && (idxd + dIndex0) < *dets && idxd < blockDim.x) {
                    if (detbound <= prebound) {
                        idxd++;
                        if (idxd < blockDim.x) detbound = dPoint[idxd];
                    } else if (pixbound <= prebound) {
                        idxi++;
                        pixbound += dImage;
                    } else if (pixbound <= detbound) {
                        threadprj += (pixbound - prebound) * prj[idxd] / dImage;
                        prebound = pixbound;
                        atomicAdd(&(image[idxchannel][0][static_cast<int>(*height) - 1 - idxi][idxcol]), threadprj);
                        threadprj = 0;
                        idxi++;
                        pixbound += dImage;
                    } else {
                        threadprj += (detbound - prebound) * prj[idxd] / dImage;
                        prebound = detbound;
                        idxd++;
                        if (idxd < blockDim.x) detbound = dPoint[idxd];
                    }
                }
                if (threadprj != 0) atomicAdd(&(image[idxchannel][0][static_cast<int>(*height) - 1 - idxi][idxcol]), threadprj);
            }
        }
    }
}
template <typename scalar_t>
__global__ void rlfilter(scalar_t* __restrict__ filter,
    const scalar_t* __restrict__ dets, const scalar_t* __restrict__ dDet) {
    unsigned xIndex = blockIdx.x * blockDim.x + threadIdx.x;
    __shared__ double PI;
    __shared__ scalar_t d;
    PI = acos(-1.0);
    d = *dDet;
    if (xIndex < (*dets * 2 - 1)) {
        int x = xIndex - *dets + 1;
        if ((abs(x) % 2) == 1) {
            filter[xIndex] = -1 / (PI * PI * x * x * d * d);
        } else if (x == 0) {
            filter[xIndex] = 1 / (4 * d * d);
        } else {
            filter[xIndex] = 0;
        }
    }
}
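// Host wrappers. Every launch uses the same grid layout: gridDim.x packs
// channel-and-view blocks (ceil(views / GRID_DIM) per channel), gridDim.y
// indexes views within a block, and gridDim.z tiles the detector row in
// chunks of BLOCK_DIM threads; the kernels recover (idxchannel, idxview)
// from this packing in their first few lines.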
torch::Tensor prj_para_cuda(torch::Tensor image, torch::Tensor options) {
    cudaSetDevice(image.device().index());
    auto views = options[0];
    auto dets = options[1];
    auto width = options[2];
    auto height = options[3];
    auto dImg = options[4];
    auto dDet = options[5];
    auto Ang0 = options[6];
    auto dAng = options[7];
    auto binshift = options[8];
    const int channels = static_cast<int>(image.size(0));
    auto projection = torch::empty({channels, 1, views.item<int>(), dets.item<int>()}, image.options());

    int nblocksx = ceil(views.item<float>() / GRID_DIM) * channels;
    int nblocksy = min(views.item<int>(), GRID_DIM);
    int nblocksz = ceil(dets.item<float>() / BLOCK_DIM);
    const dim3 blocks(nblocksx, nblocksy, nblocksz);

    AT_DISPATCH_FLOATING_TYPES(image.type(), "para_projection", ([&] {
        prj_para<scalar_t><<<blocks, BLOCK_DIM>>>(
            image.packed_accessor<scalar_t,4>(),
            projection.packed_accessor<scalar_t,4>(),
            views.data<scalar_t>(), dets.data<scalar_t>(), width.data<scalar_t>(),
            height.data<scalar_t>(), dImg.data<scalar_t>(), dDet.data<scalar_t>(),
            Ang0.data<scalar_t>(), dAng.data<scalar_t>(), binshift.data<scalar_t>()
        );
    }));
    return projection;
}
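// A minimal host-side usage sketch. The geometry numbers below are made-up
// placeholders (360 views at 1-degree steps over a 256 x 256 image); only the
// option ordering -- views, dets, width, height, dImg, dDet, Ang0, dAng,
// binshift -- is fixed by the wrappers in this file:
//
//   auto image = torch::randn({1, 1, 256, 256},
//       torch::dtype(torch::kFloat32).device(torch::kCUDA));
//   auto options = torch::tensor(
//       {360.f, 512.f, 256.f, 256.f, 0.01f, 0.01f, 0.f, 0.0174533f, 0.f},
//       torch::dtype(torch::kFloat32).device(torch::kCUDA));
//   auto sino  = prj_para_cuda(image, options);   // 1 x 1 x 360 x 512
//   auto recon = fbp_para_cuda(sino, options);    // 1 x 1 x 256 x 256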
torch::Tensor prj_t_para_cuda(torch::Tensor projection, torch::Tensor options) {
    cudaSetDevice(projection.device().index());
    auto views = options[0];
    auto dets = options[1];
    auto width = options[2];
    auto height = options[3];
    auto dImg = options[4];
    auto dDet = options[5];
    auto Ang0 = options[6];
    auto dAng = options[7];
    auto binshift = options[8];
    const int channels = static_cast<int>(projection.size(0));
    auto image = torch::zeros({channels, 1, height.item<int>(), width.item<int>()}, projection.options());

    int nblocksx = ceil(views.item<float>() / GRID_DIM) * channels;
    int nblocksy = min(views.item<int>(), GRID_DIM);
    int nblocksz = ceil(dets.item<float>() / BLOCK_DIM);
    const dim3 blocks(nblocksx, nblocksy, nblocksz);

    AT_DISPATCH_FLOATING_TYPES(projection.type(), "para_backprojection", ([&] {
        prj_t_para<scalar_t><<<blocks, BLOCK_DIM>>>(
            image.packed_accessor<scalar_t,4>(),
            projection.packed_accessor<scalar_t,4>(),
            views.data<scalar_t>(), dets.data<scalar_t>(), width.data<scalar_t>(),
            height.data<scalar_t>(), dImg.data<scalar_t>(), dDet.data<scalar_t>(),
            Ang0.data<scalar_t>(), dAng.data<scalar_t>(), binshift.data<scalar_t>()
        );
    }));
    return image;
}
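// Note the allocation pattern the four wrappers share: outputs written exactly
// once per element (projections, via prj[tx]) can use torch::empty, while
// outputs accumulated with atomicAdd (images) must start from torch::zeros.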
torch::Tensor bprj_t_para_cuda(torch::Tensor image, torch::Tensor options) {
    cudaSetDevice(image.device().index());
    auto views = options[0];
    auto dets = options[1];
    auto width = options[2];
    auto height = options[3];
    auto dImg = options[4];
    auto dDet = options[5];
    auto Ang0 = options[6];
    auto dAng = options[7];
    auto binshift = options[8];
    const int channels = static_cast<int>(image.size(0));
    auto projection = torch::empty({channels, 1, views.item<int>(), dets.item<int>()}, image.options());

    int nblocksx = ceil(views.item<float>() / GRID_DIM) * channels;
    int nblocksy = min(views.item<int>(), GRID_DIM);
    int nblocksz = ceil(dets.item<float>() / BLOCK_DIM);
    const dim3 blocks(nblocksx, nblocksy, nblocksz);

    AT_DISPATCH_FLOATING_TYPES(image.type(), "para_projection", ([&] {
        bprj_t_para<scalar_t><<<blocks, BLOCK_DIM>>>(
            image.packed_accessor<scalar_t,4>(),
            projection.packed_accessor<scalar_t,4>(),
            views.data<scalar_t>(), dets.data<scalar_t>(), width.data<scalar_t>(),
            height.data<scalar_t>(), dImg.data<scalar_t>(), dDet.data<scalar_t>(),
            Ang0.data<scalar_t>(), dAng.data<scalar_t>(), binshift.data<scalar_t>()
        );
    }));
    return projection;
}

torch::Tensor bprj_para_cuda(torch::Tensor projection, torch::Tensor options) {
    cudaSetDevice(projection.device().index());
    auto views = options[0];
    auto dets = options[1];
    auto width = options[2];
    auto height = options[3];
    auto dImg = options[4];
    auto dDet = options[5];
    auto Ang0 = options[6];
    auto dAng = options[7];
    auto binshift = options[8];
    const int channels = static_cast<int>(projection.size(0));
    auto image = torch::zeros({channels, 1, height.item<int>(), width.item<int>()}, projection.options());

    int nblocksx = ceil(views.item<float>() / GRID_DIM) * channels;
    int nblocksy = min(views.item<int>(), GRID_DIM);
    int nblocksz = ceil(dets.item<float>() / BLOCK_DIM);
    const dim3 blocks(nblocksx, nblocksy, nblocksz);

    AT_DISPATCH_FLOATING_TYPES(projection.type(), "para_backprojection", ([&] {
        bprj_para<scalar_t><<<blocks, BLOCK_DIM>>>(
            image.packed_accessor<scalar_t,4>(),
            projection.packed_accessor<scalar_t,4>(),
            views.data<scalar_t>(), dets.data<scalar_t>(), width.data<scalar_t>(),
            height.data<scalar_t>(), dImg.data<scalar_t>(), dDet.data<scalar_t>(),
            Ang0.data<scalar_t>(), dAng.data<scalar_t>(), binshift.data<scalar_t>()
        );
    }));
    return image;
}
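// fbp_para_cuda: filtered backprojection with the RL filter. The sinogram is
// weighted by dDet, convolved along the detector axis with the ramp kernel
// (padding dets - 1 preserves all dets columns), backprojected by bprj_para,
// and scaled by dAng / 2, the quadrature weight for the view integral.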
torch::Tensor fbp_para_cuda(torch::Tensor projection, torch::Tensor options) {
    cudaSetDevice(projection.device().index());
    auto views = options[0];
    auto dets = options[1];
    auto width = options[2];
    auto height = options[3];
    auto dImg = options[4];
    auto dDet = options[5];
    auto Ang0 = options[6];
    auto dAng = options[7];
    auto binshift = options[8];
    const int channels = static_cast<int>(projection.size(0));
    auto image = torch::zeros({channels, 1, height.item<int>(), width.item<int>()}, projection.options());
    auto filter = torch::empty({1, 1, 1, dets.item<int>() * 2 - 1}, projection.options());
    auto rectweight = projection * dDet;

    int filterdim = ceil((dets.item<float>() * 2 - 1) / BLOCK_DIM);
    int nblocksx = ceil(views.item<float>() / GRID_DIM) * channels;
    int nblocksy = min(views.item<int>(), GRID_DIM);
    int nblocksz = ceil(dets.item<float>() / BLOCK_DIM);
    const dim3 blocks(nblocksx, nblocksy, nblocksz);

    AT_DISPATCH_FLOATING_TYPES(projection.type(), "ramp_filter", ([&] {
        rlfilter<scalar_t><<<filterdim, BLOCK_DIM>>>(
            filter.data<scalar_t>(), dets.data<scalar_t>(), dDet.data<scalar_t>());
    }));

    auto filtered_projection = torch::conv2d(rectweight, filter, {}, 1, torch::IntArrayRef({0, dets.item<int>() - 1}));

    AT_DISPATCH_FLOATING_TYPES(projection.type(), "para_w_backprojection", ([&] {
        bprj_para<scalar_t><<<blocks, BLOCK_DIM>>>(
            image.packed_accessor<scalar_t,4>(),
            filtered_projection.packed_accessor<scalar_t,4>(),
            views.data<scalar_t>(), dets.data<scalar_t>(), width.data<scalar_t>(),
            height.data<scalar_t>(), dImg.data<scalar_t>(), dDet.data<scalar_t>(),
            Ang0.data<scalar_t>(), dAng.data<scalar_t>(), binshift.data<scalar_t>()
        );
    }));
    image = image * dAng / 2;
    return image;
}
--------------------------------------------------------------------------------