├── .github
├── semantic.yaml
└── workflows
│ └── cicd-workflows.yaml
├── .gitignore
├── Dockerfile
├── README.md
├── benchmark.py
├── benchmark
├── 2080ti.png
├── 2080ti_ms.png
├── g4dn.png
├── jetson
├── smem_1080p.png
├── smem_4k.png
├── smem_8k.png
└── t4.png
├── deprecated
├── NHWC2NCHW.cu
├── NHWC2NCHW_free.cu
├── dockerfile.opencv
├── jetson_cuda_resize.py
├── resize_fixed_dim.py
├── resize_free_dim.py
├── resize_ker.cu
├── resize_multiple_frame_dim.py
└── resize_multiple_frame_dim_refactor.py
├── lerp.py
├── lib_cuResize.cu
├── lintrc
└── pylintrc
├── resize.py
├── resize_formated.py
├── resize_free.cu
├── rgba.png
├── tools
├── float3_example.py
└── stat.cu
└── trump.jpg
/.github/semantic.yaml:
--------------------------------------------------------------------------------
1 | titleOnly: true
2 | types:
3 | - feat
4 | - fix
5 | - docs
6 | - style
7 | - test
8 | - chore
9 | - revert
10 |
--------------------------------------------------------------------------------
/.github/workflows/cicd-workflows.yaml:
--------------------------------------------------------------------------------
name: CICD
env:
  # repo_name: ${{ github.event.repository.name }}
  repo_name: cuda_resize

on:
  pull_request:
  push:
    branches:
      - master
      - development
      - "feature/**"

jobs:
  # Gate job that the other jobs depend on. On master it always runs; on other
  # refs a head commit message containing "skip ci" skips the pipeline.
  commit_filter:
    name: Filter Commit
    runs-on: ubuntu-latest
    if: "contains(github.ref, 'master') || !contains(github.event.head_commit.message, 'skip ci')"
    steps:
      - name: Echo the greeting
        run: echo 'CI/CD triggered.'

  # Lint job: runs for non-push events (e.g. pull requests) inside the project image.
  check_code:
    name: Code Checking
    runs-on: ubuntu-latest
    if: github.event_name != 'push'
    needs: [commit_filter]
    steps:
      - uses: actions/checkout@v3
        with:
          submodules: true
          token: ${{ secrets.CICD_CREDENTIALS }}
      - name: Setup Docker build kit
        uses: docker/setup-buildx-action@v2
        with:
          version: latest
      - name: Build and test image
        id: build_image
        run: |
          # Build, test a docker container
          docker buildx build --load --tag linting_machine .
          docker run -t --rm --entrypoint bash linting_machine -c "pip install pylint==2.13.0 && pylint --rcfile=lintrc/pylintrc *.py"

  # Publish job: on pushes to master / development / feature branches, build the
  # image and push it to Docker Hub.
  build_image:
    name: Build & Push Container - Docker Hub
    needs: [commit_filter]
    if: github.event_name == 'push' && (contains(github.ref, 'master') || contains(github.ref, 'development') || contains(github.ref, 'feature'))
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          submodules: true
          token: ${{ secrets.CICD_CREDENTIALS }}


      # - name: Build the image (AMD64, ARM64)
      #   env:
      #     DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
      #     DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }}
      #   run: |
      #     docker buildx create --use --name multi-arch-builder && \
      #     echo "$DOCKER_PASSWORD" | docker login --username "$DOCKER_USERNAME" --password-stdin && \
      #     docker buildx build --push \
      #       --tag ${{ secrets.DOCKER_USERNAME }}/${{ env.repo_name }} \
      #       --platform linux/amd64,linux/arm64 .

      - name: Build the image (AMD64)
        # The password is piped to `docker login --password-stdin` instead of being
        # passed with `-p`, which would place it in the process arguments.
        env:
          DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
          DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }}
        run: |
          docker buildx create --use --name multi-arch-builder && \
          echo "$DOCKER_PASSWORD" | docker login --username "$DOCKER_USERNAME" --password-stdin && \
          docker buildx build --push \
            --cache-to ${{ secrets.DOCKER_USERNAME }}/${{ env.repo_name }}:build_cache \
            --cache-from ${{ secrets.DOCKER_USERNAME }}/${{ env.repo_name }}:build_cache \
            --tag ${{ secrets.DOCKER_USERNAME }}/${{ env.repo_name }}:cu12 \
            --tag ${{ secrets.DOCKER_USERNAME }}/${{ env.repo_name }}:latest .

      # - name: Docker Hub Description
      #   uses: peter-evans/dockerhub-description@v3
      #   with:
      #     username: ${{ secrets.DOCKER_USERNAME }}
      #     password: ${{ secrets.DOCKER_PASSWORD }}
      #     repository: ${{ secrets.DOCKER_USERNAME }}/${{ env.repo_name }}
      #     readme-filepath: ./README.md

      - if: success()
        name: Notify Deployment
        # Pinned to a release tag rather than the moving `master` branch.
        uses: rtCamp/action-slack-notify@v2
        env:
          SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
          SLACK_USERNAME: ${{ github.repository }}
          SLACK_ICON: https://github.com/royinx.png?size=48
          SLACK_TITLE: "New Version Deployed :rocket:"
          SLACK_MESSAGE: "Check out https://hub.docker.com/r/${{ secrets.DOCKER_USERNAME }}/${{ env.repo_name }}"

  # auto_merge_pr:
  #   name: Auto Merge Sync Pull Request
  #   runs-on: ubuntu-latest
  #   # needs: [check_code]
  #   if: "contains(github.event.pull_request.title, 'chore: auto sync master with development')"
  #   steps:
  #     - name: Auto Review
  #       uses: andrewmusgrave/automatic-pull-request-review@0.0.2
  #       with:
  #         repo-token: "${{ secrets.CICD_CREDENTIALS }}"
  #         event: APPROVE
  #         body: "Auto Review by Ultron"
  #     - name: Auto Merge Sync PR
  #       uses: "pascalgn/automerge-action@4536e8847eb62fe2f0ee52c8fa92d17aa97f932f"
  #       env:
  #         GITHUB_TOKEN: "${{ secrets.CICD_CREDENTIALS }}"
  #         MERGE_LABELS: ""
  #         MERGE_METHOD: "merge"
108 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | build
3 | *.o
4 | *.so
5 | val2017*
6 | *.npy
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM nvcr.io/nvidia/tensorrt:23.06-py3
# Keep apt from prompting during the non-interactive build.
ENV DEBIAN_FRONTEND=noninteractive

# Build tools
# libgl1-mesa-glx presumably supplies libGL.so.1 for opencv-python (cv2 is imported by benchmark.py).
# apt-get (stable CLI) is used instead of apt; package lists are removed to keep the layer small.
RUN apt-get update && \
    apt-get install -y libgl1-mesa-glx && \
    rm -rf /var/lib/apt/lists/*
# --no-cache-dir keeps pip's download cache out of the image.
RUN python3 -m pip install --no-cache-dir \
    opencv-python \
    line_profiler \
    cupy-cuda12x \
    pandas
WORKDIR /workspace
COPY . .
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CuPy / CUDA Bilinear Interpolation
2 |
3 | Ultra-fast bilinear interpolation for image resizing with CUDA.
4 |
5 |
6 | `lerp.py` : Concept and code base (*single thread, may take a while to run).
7 | `resize_free.cu` : CUDA test case in C (the older `resize_ker.cu` now lives in `deprecated/`).
8 | `resize.py` : Cupy example
9 |
10 | (*PyCUDA (deprecated) is no longer supported; use CuPy instead.)
11 |
12 | Requirements:
13 | >- GPU (compute capability: 3.0 or above; 5.0 or above when using the CUDA 12 Docker image. Testing platform: 7.5)
14 | >- CUDA driver
15 | >- Docker and nvidia docker
16 | ---
17 | Pros:
18 | - Supports batched images.
19 | - No shared-object (.so) or .dll binary files.
20 | - Just install CuPy and use it.
21 | - Compatible with the `NumPy` library.
22 | - pass the GPU array to TensorRT directly.
23 |
24 | Cons:
25 | - You still need to understand CUDA programming concepts.
26 | - The SourceModule has to be written in CUDA C, including all CUDA kernels and device code.
27 |
28 | ---
29 | ### Quick Start
30 |
31 | ```bash
32 | # Pull docker image
33 | docker run -it --runtime=nvidia royinx/cuda_resize bash
34 |
35 | # For Cupy implementation
36 | python3 resize.py
37 |
38 | # For concept
39 | python3 lerp.py
40 |
41 | # For CUDA kernel testing
42 | nvcc resize_free.cu -o resize_free.o && ./resize_free.o
43 |
44 | # For benchmarking
45 | wget http://images.cocodataset.org/zips/val2017.zip
46 | unzip val2017.zip
47 | python3 benchmark.py
48 | ```
49 |
50 | Build
51 |
52 | ```bash
53 | git clone https://github.com/royinx/CUDA_Resize.git
54 | cd CUDA_Resize
55 | docker build -t lerp_cuda .
56 | docker run -it --runtime=nvidia -v ${PWD}:/py -w /py lerp_cuda bash
57 | ```
58 |
59 |
60 | Advanced Metrics
61 |
62 | ```bash
63 | docker run -it --privileged --runtime=nvidia -p 20072:22 -v ${PWD}:/py -w /py lerp_cuda bash
64 | sh -c 'echo 1 >/proc/sys/kernel/perf_event_paranoid'
65 | nvcc resize_free.cu -o resize_free.o
66 | nsys profile ./resize_free.o
67 |
68 | ncu -o metrics /bin/python3 resize.py > profile_log
69 | ncu -o metrics /bin/python3 resize.py
70 | ```
71 | Remark: the development platform is defined in `deprecated/dockerfile.opencv`, which provides OpenCV in C for debugging.
72 |
73 | The functions work well in the PyCUDA container, so you don't need to build OpenCV.
74 |
75 |
76 | ---
77 |
78 | ### Benchmark
79 | #### 2080ti
80 | > ratio = 2080ti (ms) / Ryzen 2700x (ms)
81 |
82 | 
83 |
84 | > time (us/img)
85 |
86 | 
87 |
88 | shared memory
89 |
90 | 
91 | 
92 | 
93 |
94 |
95 |
96 | #### (Deprecated) [w/o smem] AWS g4dn.xlarge (Tesla T4)
97 | > ratio = T4 (ms) per img / Xeon Platinum 8259CL (ms) per img
98 | 
99 |
100 | > (ms) per img on T4
101 | 
102 |
--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=line-too-long, invalid-name, too-many-locals, c-extension-no-member, redefined-outer-name
2 |
3 | # built-in library
4 | import sys
5 | import os
6 | import time
7 |
8 | # third party library
9 | import cv2
10 | import cupy as cp
11 | import numpy as np
12 | import pandas as pd
13 | from resize import cuda_resize
14 |
def main(input_array: cp.ndarray, resize_shape: tuple):
    """Copy a batch of images into a fresh GPU buffer and resize it with ``cuda_resize``.

    Args:
        input_array: batch of images, either a CuPy array (copied device-to-device)
            or a NumPy array (copied host-to-device). The copy is a raw memcpy of
            ``nbytes`` bytes from the data pointer, so non-contiguous inputs are
            made contiguous first.
        resize_shape: target shape, forwarded unchanged to ``cuda_resize``.

    Returns:
        ``(output_array, [resize_scale, top_pad, left_pad])`` as produced by
        ``cuda_resize(..., pad=False)``.

    Raises:
        TypeError: if ``input_array`` is neither a CuPy nor a NumPy array.
            Otherwise the uninitialised GPU buffer would be resized silently.
    """
    if isinstance(input_array, cp.ndarray):  # DtoD
        src = cp.ascontiguousarray(input_array)
        src_ptr, kind = int(src.data), 3
    elif isinstance(input_array, np.ndarray):  # HtoD
        src = np.ascontiguousarray(input_array)
        src_ptr, kind = src.ctypes.data, 1
    else:
        raise TypeError(f"input_array must be a cupy.ndarray or numpy.ndarray, got {type(input_array).__name__}")

    input_array_gpu = cp.empty(shape=src.shape, dtype=src.dtype)
    # kind: 0: HtoH, 1: HtoD, 2: DtoH, 3: DtoD, 4: unified virtual addressing
    cp.cuda.runtime.memcpy(dst=int(input_array_gpu.data),  # dst_ptr
                           src=src_ptr,                    # src_ptr
                           size=src.nbytes,
                           kind=kind)

    resize_scale, top_pad, left_pad, output_array = cuda_resize(input_array_gpu,
                                                                resize_shape,
                                                                pad=False)

    return output_array, [resize_scale, top_pad, left_pad]
34 |
def warm_up(shape, batch=200, dst_shape=(128, 256)):
    """Run one throw-away ``cuda_resize`` call before timing starts.

    The intent is presumably to keep one-time GPU initialisation out of the
    measured runs.

    Args:
        shape: ``(width, height)`` of the dummy source frames.
        batch: number of dummy frames. The default of 200 keeps the previous
            behaviour; for 3840x2160 frames that is about 5 GB of uint8 data, so
            lower it on GPUs with less memory.
        dst_shape: target shape passed to ``cuda_resize``.
    """
    w, h = shape
    input_array_gpu = cp.ones(shape=(batch, h, w, 3), dtype=np.uint8)
    _, _, _, output_array = cuda_resize(input_array_gpu,
                                        dst_shape,
                                        pad=False)
    print("Warm up:", output_array.shape)
42 |
43 |
if __name__ == "__main__":
    # prepare data
    batch = 100
    size = [(3840,2160),(1920,1080), (960,540), (480,270), (240,135), (120,67), (60,33), (30,16)]
    warm_up(size[0])
    # Rows: destination shape. Columns: source shape.
    # Cells: average GPU resize time in microseconds per image.
    labels = [str(size_) for size_ in size]
    benchmark = pd.DataFrame(columns=labels, index=labels, dtype=float)

    for src_shape in size:
        cache_path = f"{src_shape}.npy"
        if os.path.exists(cache_path):
            imgs = np.load(cache_path)
        else:
            # sorted() makes the cached subset reproducible; os.listdir order is arbitrary.
            img_names = sorted(os.listdir("val2017"))[:1000]
            imgs = np.asarray([cv2.resize(cv2.imread(f"val2017/{img_name}"), src_shape) for img_name in img_names])
            np.save(cache_path, imgs)

        num_imgs = len(imgs)
        if num_imgs == 0:
            print(f"No images available for {src_shape}, skipping.")
            continue

        for dst_shape in size:
            # CPU benchmark
            cpu_metrics = []

            # for index in range(0, num_imgs, batch):
            #     start = time.perf_counter()
            #     cpu_output = [cv2.resize(img,(dst_shape))for img in imgs[index:index+batch]]
            #     cpu_metrics.append(time.perf_counter() - start)
            #     # cv2.imwrite(f"{index}_output_cpu.jpg", cpu_output[0])

            # CUDA benchmark
            cuda_metrics = []
            for index in range(0, num_imgs, batch):
                input_array = imgs[index:index+batch]
                input_array_gpu = cp.empty(shape=input_array.shape, dtype=input_array.dtype)
                cp.cuda.runtime.memcpy(dst=int(input_array_gpu.data),  # dst_ptr
                                       src=input_array.ctypes.data,    # src_ptr
                                       size=input_array.nbytes,
                                       kind=1)                         # 1: HtoD

                # Make sure no earlier GPU work is still pending when the clock starts.
                cp.cuda.Device().synchronize()
                start = time.perf_counter()
                _, _, _, output_array = cuda_resize(input_array_gpu,
                                                    dst_shape[::-1],
                                                    pad=False)
                # cuda_resize presumably only enqueues kernels; wait for them so the
                # interval measures the resize itself, not just the launch.
                cp.cuda.Device().synchronize()
                cuda_metrics.append(time.perf_counter() - start)
                # cv2.imwrite(f"{index}_output_cuda.jpg", cp.asnumpy(output_array[0]))
                del input_array_gpu
                cp.get_default_memory_pool().free_all_blocks()

            cpu_ = sum(cpu_metrics)
            gpu_ = sum(cuda_metrics)
            speedup = cpu_ / gpu_  # 0.0 while the CPU benchmark above is disabled
            # benchmark.loc[str(dst_shape), str(src_shape)] = speedup

            # Total seconds / number of images = seconds per image; * 1e6 = us per image.
            # .loc writes into the frame itself. Chained `benchmark[col][row] = ...`
            # may only update a temporary copy.
            benchmark.loc[str(dst_shape), str(src_shape)] = gpu_ / num_imgs * 1000 * 1000
            # print(f"{src_shape} -> {dst_shape}: \t CPU: {cpu_} \t | CUDA: {gpu_} \t | Speedup: {speedup}")
        del imgs
    print(benchmark)
--------------------------------------------------------------------------------
/benchmark/2080ti.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/royinx/CUDA_Resize/938da3fa4ce538befba7c336d3cb837f2296cd3f/benchmark/2080ti.png
--------------------------------------------------------------------------------
/benchmark/2080ti_ms.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/royinx/CUDA_Resize/938da3fa4ce538befba7c336d3cb837f2296cd3f/benchmark/2080ti_ms.png
--------------------------------------------------------------------------------
/benchmark/g4dn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/royinx/CUDA_Resize/938da3fa4ce538befba7c336d3cb837f2296cd3f/benchmark/g4dn.png
--------------------------------------------------------------------------------
/benchmark/jetson:
--------------------------------------------------------------------------------
1 | (1920, 1080) (960, 540) (480, 270) (240, 135) (120, 67) (60, 33) (30, 16)
2 | (1920, 1080) 2628.994772 3109.770425 3026.949618 2950.659376 3022.0138 2816.862353 2883.906551
3 | (960, 540) 890.719573 1142.379314 1199.411264 1144.996296 1171.316782 1183.976468 1186.506571
4 | (480, 270) 330.3115 434.45062 405.194254 466.806814 462.576296 444.262273 441.651127
5 | (240, 135) 148.463809 268.846699 176.247592 244.022628 229.015609 172.870492 194.332538
6 | (120, 67) 88.16277 121.218474 91.046449 133.075754 165.885802 104.027597 102.88192
7 | (60, 33) 74.635785 77.663792 94.81256 109.856651 91.6848 67.332144 83.526781
8 | (30, 16) 55.879294 77.381082 60.302332 126.001308 75.266281 46.970577 69.781059
9 |
10 |
--------------------------------------------------------------------------------
/benchmark/smem_1080p.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/royinx/CUDA_Resize/938da3fa4ce538befba7c336d3cb837f2296cd3f/benchmark/smem_1080p.png
--------------------------------------------------------------------------------
/benchmark/smem_4k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/royinx/CUDA_Resize/938da3fa4ce538befba7c336d3cb837f2296cd3f/benchmark/smem_4k.png
--------------------------------------------------------------------------------
/benchmark/smem_8k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/royinx/CUDA_Resize/938da3fa4ce538befba7c336d3cb837f2296cd3f/benchmark/smem_8k.png
--------------------------------------------------------------------------------
/benchmark/t4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/royinx/CUDA_Resize/938da3fa4ce538befba7c336d3cb837f2296cd3f/benchmark/t4.png
--------------------------------------------------------------------------------
/deprecated/NHWC2NCHW.cu:
--------------------------------------------------------------------------------
#include <stdio.h>

// Transpose a batch of 3-channel uint8 images from NHWC to NCHW.
// Launch layout (see main): the x/y thread dimensions enumerate image rows/columns,
// and blockIdx.z enumerates (image, channel) pairs as n * 3 + c.
// H and W are derived from the launch configuration, so the grid must cover the
// image exactly; there is no bounds check.
__global__ void transpose(unsigned char *odata, const unsigned char *idata)
{
    const int C = 3;                                       // channels per pixel (fixed)
    const int H = blockDim.x * gridDim.x;                  // image height = total threads in x
    const int W = blockDim.y * gridDim.y;                  // image width  = total threads in y
    const int h = blockDim.x * blockIdx.x + threadIdx.x;   // row handled by this thread
    const int w = blockDim.y * blockIdx.y + threadIdx.y;   // column handled by this thread
    const int c = blockIdx.z % C;                          // channel index
    const int n = blockIdx.z / C;                          // image index within the batch

    // Offsets are built in 64-bit so large batches cannot overflow int.
    const long src_idx = ((long)n * H * W + (long)h * W + w) * C + c;   // NHWC
    const long dst_idx = ((long)n * C + c) * H * W + (long)h * W + w;   // NCHW

    odata[dst_idx] = idata[src_idx];
}
25 |
int main(){
    // Test case: one 608x608 image with 3 channels.
    // dimBlock(32,32,1): 32 * 32 = 1024 threads, the per-block maximum.
    // dimGrid(19,19,3): 608 / 32 = 19 blocks in x and y; z = channel * batch_size (3 * 1).
    const size_t NBYTES = 608 * 608 * 3 * sizeof(unsigned char);
    dim3 dimBlock(32, 32, 1);
    dim3 dimGrid(19, 19, 3);

    // Host buffers are static so ~1.1 MB each does not depend on a large default
    // stack (e.g. the 1 MB default stack on Windows).
    static unsigned char host_src[608*608*3]; // N H W C
    static unsigned char host_dst[608*608*3]; // N C H W

    // init src image: each pixel holds its channel index (0, 1, 2)
    for (int i = 0; i < 608*608*3; i++) {
        host_src[i] = (i % 3);
    }

    // init device arrays
    unsigned char *device_src, *device_dst;
    cudaMalloc((void **)&device_src, NBYTES);
    cudaMalloc((void **)&device_dst, NBYTES);

    cudaMemcpy(device_src, host_src, NBYTES, cudaMemcpyHostToDevice);

    // run kernel and surface launch / execution errors
    transpose<<<dimGrid, dimBlock>>>(device_dst, device_src);
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess) {
        err = cudaDeviceSynchronize();
    }
    if (err != cudaSuccess) {
        fprintf(stderr, "transpose failed: %s\n", cudaGetErrorString(err));
        cudaFree(device_src);
        cudaFree(device_dst);
        return 1;
    }

    // take out output
    cudaMemcpy(host_dst, device_dst, NBYTES, cudaMemcpyDeviceToHost);

    // DEBUG : print first image in batch , first 30 pixel in 3 channels.
    for (int i = 0; i < 30*3; i += 3) { // N H W C
        printf("%d\n", host_src[i]);
    }
    printf("============================\n");

    for (int c = 0; c < 3*608*608; c += 608*608) { // N C H W
        for (int i = 0; i < 30; i++) {
            printf("%d %d %d\n", c+i, i, host_dst[c+i]);
        }
        printf("------------------------------\n");
    }

    // deinit GPU
    cudaFree(device_src);
    cudaFree(device_dst);

    return 0;
}
// clear && clear && nvcc NHWC2NCHW.cu -o trans.o && ./trans.o
79 |
--------------------------------------------------------------------------------
/deprecated/NHWC2NCHW_free.cu:
--------------------------------------------------------------------------------
#include <stdio.h>

// Transpose a batch of uint8 images of size H x W from NHWC to NCHW.
// Launch layout (see main): gridDim.y = batch size, gridDim.z = channel count.
// The x dimension (blockDim.x * gridDim.x threads) must provide at least H * W
// threads per image. Surplus threads wrap around through the modulo below and
// rewrite an element of the same image with the same value.
__global__ void transpose(unsigned char *odata, const unsigned char *idata,
                          int H, int W)
{
    const int n = blockIdx.y;   // image index within the batch
    const int C = gridDim.z;    // number of channels
    const int c = blockIdx.z;   // channel handled by this thread
    const long long threads_per_image = (long long)blockDim.x * gridDim.x;

    // Flattened element index built in 64-bit. The former 32-bit unsigned
    // product n * blockDim.x * gridDim.x * C wraps for large batches.
    const long long idx = ((long long)n * threads_per_image
                           + (long long)threadIdx.x * gridDim.x
                           + blockIdx.x) * C + c;

    const int img_coor = idx % ((long long)H * W * C);  // coordinate inside one image
    const int h = img_coor / (W * C);
    const int w = img_coor % (W * C) / C;

    const long long src_idx = (((long long)n * H + h) * W + w) * C + c;   // NHWC
    const long long dst_idx = (((long long)n * C + c) * H + h) * W + w;   // NCHW
    odata[dst_idx] = idata[src_idx];
}
29 |
30 | int main(){
31 | // dim3 dimBlock(32,32,1); << Max total is 1024 , so , x=32 ,y=32 , some one use 1024 to handle flatten tensor is fine.
32 | // dim3 dimGrid(19,19,3); << x = 608 / 32 = 19 , same on y , z = channel * batch_size, assume channel = 3.
33 |
34 | int BATCH = 10;
35 | int HEIGHT = 50;
36 | int WIDTH = 50;
37 | int C = 3;
38 | int SIZE = HEIGHT * WIDTH * C;
39 |
40 | cudaStream_t stream1;
41 | cudaStreamCreate ( &stream1) ;
42 |
43 | dim3 dimBlock(1024, 1, 1);
44 | dim3 dimGrid(int(SIZE/C/1024)+1,BATCH,C);
45 |
46 | // init host array
47 | unsigned char host[SIZE*BATCH];
48 |
49 | // init src image
50 | for(int i = 0; i < SIZE*BATCH; i++){
51 | // host_src[i] = i+1;
52 | host[i] = (i%C);
53 | }
54 |
55 | for(int i = 0; i < 30*3; i+=3){ // N H W C
56 | printf("%d\n",host[i]);
57 | }
58 | printf("============================\n");
59 |
60 | // init device array
61 | unsigned char *device_src, *device_dst;
62 | cudaMalloc((unsigned char **)&device_src, SIZE* BATCH* sizeof(unsigned char));
63 | cudaMalloc((unsigned char **)&device_dst, SIZE* BATCH* sizeof(unsigned char));
64 |
65 | cudaMemcpy(device_src , host , SIZE * BATCH * sizeof(unsigned char), cudaMemcpyHostToDevice);
66 |
67 | // run kernel
68 | transpose<<>>(device_dst, device_src, HEIGHT, WIDTH);
69 | cudaDeviceSynchronize();
70 |
71 | // take out output
72 | cudaMemcpy(host, device_dst, SIZE * BATCH * sizeof(unsigned char), cudaMemcpyDeviceToHost);
73 |
74 | // DEBUG : print first image in batch , first 30 pixel in 3 channels.
75 |
76 |
77 |
78 | for(int n = 0; na){
14 | return a + w*(b-a);
15 | }
16 | else{
17 | return b + w*(a-b);
18 | }
19 | }
20 |
__device__ float lerp2d(int f00, int f01, int f10, int f11,
                        float centroid_h, float centroid_w )
{
    // Turn each centroid into the blend weight handed to lerp1d.
    // NOTE(review): a centroid of 0.5 (first pixel of an identity resize) gives
    // (1 + 1 - 0.5) / 2 = 0.75 rather than 0 or 1 — confirm this weighting is intended.
    const float weight_w = (1 + lroundf(centroid_w) - centroid_w) / 2;
    const float weight_h = (1 + lroundf(centroid_h) - centroid_h) / 2;

    // Blend along the width on the upper and lower rows, then blend the two rows.
    const float upper = lerp1d(f00, f01, weight_w);
    const float lower = lerp1d(f10, f11, weight_w);
    return lerp1d(upper, lower, weight_h); //+ 0.00001
}
34 |
__global__ void Transpose(unsigned char *odata, const unsigned char *idata,
                          int H, int W)
{
    // Copy one element of a batch of H x W images from NHWC to NCHW layout.
    // Grid: y selects the image, z the channel (gridDim.z = channel count);
    // the x threads enumerate the pixels of one image.
    const int img   = blockIdx.y;
    const int nchan = gridDim.z;
    const int ch    = blockIdx.z;

    const long long linear = (img * blockDim.x * gridDim.x
                              + threadIdx.x * gridDim.x
                              + blockIdx.x) * nchan + ch;

    // position inside the image; surplus threads wrap back into the same image
    const int offset = linear % (H * W * nchan);
    const int row = offset / (W * nchan);
    const int col = offset % (W * nchan) / nchan;

    const long long nhwc = ((img * H + row) * W + col) * nchan + ch;
    const long long nchw = ((img * nchan + ch) * H + row) * W + col;
    odata[nchw] = idata[nhwc];
}
62 |
__global__ void cuResize(unsigned char* dst_img, unsigned char* src_img,
                         int src_h, int src_w,
                         int dst_h, int dst_w,
                         float stride_h, float stride_w)
{
    /*
    Bilinear resize of a batch of images.

    Input:
        src_img - NHWC, src_h x src_w
        channel C = gridDim.z (gpu_resize launches with 3)

    Output:
        dst_img - NHWC, dst_h x dst_w

    stride_h / stride_w : source-to-destination scale factors (src / dst).
    Launch layout: gridDim.y = batch, gridDim.z = channels, and
    blockDim.x * gridDim.x must be >= dst_h * dst_w; surplus threads return early.
    */
    const int n = blockIdx.y;   // image index within the batch
    const int C = gridDim.z;    // number of channels
    const int c = blockIdx.z;   // channel handled by this thread
    const long long threads_per_image = (long long)blockDim.x * gridDim.x;
    const long long idx = ((long long)n * threads_per_image
                           + (long long)threadIdx.x * gridDim.x
                           + blockIdx.x) * C + c;

    // surplus threads of each image have nothing to do
    if (idx % (threads_per_image * C) >= (long long)dst_h * dst_w * C) { return; }

    const int H = dst_h;
    const int W = dst_w;
    const int img_coor = idx % ((long long)dst_h * dst_w * C);  // coordinate inside one image
    const int h = img_coor / (W * C);
    const int w = img_coor % (W * C) / C;

    // sampling point of this output pixel, in source pixel units
    const float centroid_h = stride_h * (h + 0.5);
    const float centroid_w = stride_w * (w + 0.5);

    // top-left source neighbour, clamped at the first row / column
    int src_h_idx = lroundf(centroid_h) - 1;
    int src_w_idx = lroundf(centroid_w) - 1;
    if (src_h_idx < 0) { src_h_idx = 0; }
    if (src_w_idx < 0) { src_w_idx = 0; }

    // 64-bit offsets: n * src_h * src_w * C overflows int for large batches
    const long long img_base = (long long)n * src_h * src_w * C;
    long long f00 = img_base + ((long long)src_h_idx * src_w + src_w_idx) * C + c;
    long long f01 = f00 + C;                      // right neighbour
    long long f10 = f00 + (long long)src_w * C;   // neighbour below
    long long f11 = f10 + C;                      // diagonal neighbour

    // On the last column / row the right / lower neighbour lies outside the image
    // (the next image, or past the end of the buffer for the last image), so
    // reuse the edge pixel instead.
    if (src_w_idx + 1 >= src_w) { f01 = f00; f11 = f10; }
    if (src_h_idx + 1 >= src_h) { f10 = f00; f11 = f01; }

    const int rs = lroundf(lerp2d(src_img[f00], src_img[f01], src_img[f10], src_img[f11],
                                  centroid_h, centroid_w));

    const long long dst_idx = (((long long)n * H + h) * W + w) * C + c;
    dst_img[dst_idx] = (unsigned char)rs;
}
149 | """)
150 |
# Kernel handles for the two entry points compiled into `module`.
cuResizeKer, TransposeKer = (module.get_function(kernel_name)
                             for kernel_name in ("cuResize", "Transpose"))
153 |
def gpu_resize(input_img: np.ndarray, stream, dst_hw=(480, 640)):
    """
    Resize a batch of images on the GPU and convert the result from NHWC to NCHW.

    Application oriented.

    input_img : batch input, format NHWC, recommend RGB (same as the NN input format).
                Must have 3 channels, because the kernels are launched with gridDim.z = 3.
    stream    : stream the two kernels are launched on.
    dst_hw    : (height, width) of the output; defaults to the previously
                hard-coded (480, 640).
    returns   : managed uint8 array of shape (batch, channel, dst_h, dst_w), i.e. NCHW.
                No normalisation (division by 255) is applied here.

    Raises ValueError if input_img does not have exactly 3 channels.
    """
    batch, src_h, src_w, channel = input_img.shape
    if channel != 3:
        raise ValueError(f"gpu_resize expects 3-channel NHWC input, got {channel} channels")
    dst_h, dst_w = dst_hw
    # one thread per output pixel per channel; +1 block covers the remainder
    grid = (dst_h * dst_w // 1024 + 1, batch, 3)

    # Mem Allocation (managed memory, reachable from host and device)
    # input memory
    inp = cuda.managed_zeros(shape=(batch, src_h, src_w, channel),
                             dtype=np.uint8,
                             mem_flags=cuda.mem_attach_flags.GLOBAL)
    inp[...] = input_img

    # resized output, still NHWC
    out = cuda.managed_zeros(shape=(batch, dst_h, dst_w, channel),
                             dtype=np.uint8,
                             mem_flags=cuda.mem_attach_flags.GLOBAL)

    # Transpose target, NCHW
    trans = cuda.managed_zeros(shape=(batch, channel, dst_h, dst_w),
                               dtype=np.uint8,
                               mem_flags=cuda.mem_attach_flags.GLOBAL)

    cuResizeKer(out, inp,
                np.int32(src_h), np.int32(src_w),
                np.int32(dst_h), np.int32(dst_w),
                np.float32(src_h/dst_h), np.float32(src_w/dst_w),
                block=(1024, 1, 1),
                grid=grid,
                stream=stream)

    TransposeKer(trans, out,
                 np.int32(dst_h), np.int32(dst_w),
                 block=(1024, 1, 1),
                 grid=grid,
                 stream=stream)

    # Wait for kernel completion before host access
    # stream.synchronize()
    context.synchronize()

    return trans
210 |
211 |
if __name__ == "__main__":
    import cv2

    stream = cuda.Stream()
    batch = 32

    # Replicate one 1920x1080 frame `batch` times to form an NHWC batch.
    frame = cv2.resize(cv2.imread("debug_image/helmet.jpg"), (1920, 1080))
    img_batch = np.tile(frame, [batch, 1, 1, 1])

    # gpu_resize returns NCHW; move channels last again (NHWC) for inspection.
    pix = np.transpose(gpu_resize(img_batch, stream), [0, 2, 3, 1])
--------------------------------------------------------------------------------
/deprecated/resize_fixed_dim.py:
--------------------------------------------------------------------------------
1 | import pycuda.driver as cuda
2 | import pycuda.autoinit
3 | from pycuda.compiler import SourceModule
4 | from pycuda import gpuarray
5 | import numpy as np
6 | import cv2
7 | from line_profiler import LineProfiler
8 |
# Line-level profiler; methods decorated with @profile below are recorded by it.
profile = LineProfiler()

# Module-level 0/1 toggles.
# NOTE(review): none of these is read in the visible part of this file — confirm they are still used.
bl_Normalize, bl_Trans, pagelock = 0, 0, 0
14 |
15 | module = SourceModule("""
16 |
__device__ double lerp1d(float a, float b, float w)
{
    // Linear interpolation a + w * (b - a) as two fused multiply-adds:
    // fmaf(-w, a, a) = a - w*a, then fmaf(w, b, .) adds w*b.
    // The endpoints are float (they used to be int). lerp2d's second pass feeds
    // in the two already-interpolated row values, and int parameters truncated
    // them toward zero before blending. Integer arguments still convert
    // implicitly, exactly as the former (float) casts did.
    return fmaf(w, b, fmaf(-w, a, a));
}
21 |
__device__ float lerp2d(int f00, int f01, int f10, int f11,
                        float centroid_h, float centroid_w )
{
    // Weights derived from the sub-pixel position of each centroid.
    // NOTE(review): a centroid of 0.5 (first pixel of an identity resize) gives
    // (1 + 1 - 0.5) / 2 = 0.75 rather than 0 or 1 — confirm this weighting is intended.
    const float wx = (1 + lroundf(centroid_w) - centroid_w) / 2;
    const float wy = (1 + lroundf(centroid_h) - centroid_h) / 2;

    // Interpolate each row horizontally, then interpolate between the rows.
    const float row_top    = lerp1d(f00, f01, wx);
    const float row_bottom = lerp1d(f10, f11, wx);
    return lerp1d(row_top, row_bottom, wy); //+ 0.00001
}
35 |
__global__ void Transpose(unsigned char *odata, const unsigned char *idata,
                          int H, int W)
{
    // NHWC -> NCHW copy for a batch of H x W images.
    // blockIdx.y selects the image, blockIdx.z the channel (gridDim.z = channel count).
    const int batch_idx = blockIdx.y;
    const int channels  = gridDim.z;
    const int chan      = blockIdx.z;

    // Flattened element index. Threads past the end of an image wrap around
    // through the modulo below onto elements of the same image.
    const long long flat = (batch_idx * blockDim.x * gridDim.x
                            + threadIdx.x * gridDim.x
                            + blockIdx.x) * channels + chan;

    const int in_img = flat % (H * W * channels);   // offset inside one image
    const int row = in_img / (W * channels);
    const int col = in_img % (W * channels) / channels;

    const long long from = ((batch_idx * H + row) * W + col) * channels + chan;  // NHWC
    const long long to   = ((batch_idx * channels + chan) * H + row) * W + col;  // NCHW
    odata[to] = idata[from];
}
63 |
__global__ void Transpose_and_normalise(float *odata, const unsigned char *idata,
                                        int H, int W)
{
    // Convert one uint8 NHWC element into a float NCHW element scaled by 1/255.
    // blockIdx.y selects the image, blockIdx.z the channel (gridDim.z = channel count).
    const int image = blockIdx.y;
    const int depth = gridDim.z;
    const int layer = blockIdx.z;

    long long element;
    element = (image * blockDim.x * gridDim.x
               + threadIdx.x * gridDim.x
               + blockIdx.x) * depth + layer;

    // Locate the pixel within its image (surplus threads wrap into the same image).
    const int local = element % (H * W * depth);
    const int y = local / (W * depth);
    const int x = local % (W * depth) / depth;

    const long long read_at  = ((image * H + y) * W + x) * depth + layer;   // NHWC
    const long long write_at = ((image * depth + layer) * H + y) * W + x;   // NCHW
    odata[write_at] = idata[read_at] / 255.0;
}
91 |
__global__ void cuResize(unsigned char* src_img, unsigned char* dst_img,
                         int src_h, int src_w,
                         int dst_h, int dst_w,
                         float stride_h, float stride_w)
{
    /*
    Bilinear resize of a batch of images.

    Input:
        src_img - NHWC, src_h x src_w
        channel C = gridDim.z (the Python launchers pass 3)

    Output:
        dst_img - NHWC, dst_h x dst_w

    stride_h / stride_w : source-to-destination scale factors (src / dst).
    Launch layout: gridDim.y = batch, gridDim.z = channels, and
    blockDim.x * gridDim.x must be >= dst_h * dst_w; surplus threads return early.

    Flattened offsets are computed in 64-bit. With int arithmetic,
    n * src_h * src_w * C exceeds INT_MAX once a batch holds more than about
    87 frames of 3840x2160x3.
    */
    const int n = blockIdx.y;   // image index within the batch
    const int C = gridDim.z;    // number of channels
    const int c = blockIdx.z;   // channel handled by this thread
    const long long threads_per_image = (long long)blockDim.x * gridDim.x;
    const long long idx = ((long long)n * threads_per_image
                           + (long long)threadIdx.x * gridDim.x
                           + blockIdx.x) * C + c;

    // some overhead threads in each image process
    // when thread idx in one image exceed one image size return;
    if (idx % (threads_per_image * C) >= (long long)dst_h * dst_w * C) { return; }

    const int H = dst_h;
    const int W = dst_w;
    const int img_coor = idx % ((long long)dst_h * dst_w * C);  // coordinate inside one image
    const int h = img_coor / (W * C);
    const int w = img_coor % (W * C) / C;

    // sampling point of this output pixel, in source pixel units
    const float centroid_h = stride_h * (h + 0.5);
    const float centroid_w = stride_w * (w + 0.5);

    // top-left source neighbour, clamped at the first row / column
    int src_h_idx = lroundf(centroid_h) - 1;
    int src_w_idx = lroundf(centroid_w) - 1;
    if (src_h_idx < 0) { src_h_idx = 0; }
    if (src_w_idx < 0) { src_w_idx = 0; }

    const long long img_base = (long long)n * src_h * src_w * C;
    long long f00 = img_base + ((long long)src_h_idx * src_w + src_w_idx) * C + c;
    long long f01 = f00 + C;                      // right neighbour
    long long f10 = f00 + (long long)src_w * C;   // neighbour below
    long long f11 = f10 + C;                      // diagonal neighbour

    // reuse the edge pixel when the right / lower neighbour is outside the image
    if (src_w_idx + 1 >= src_w) { f01 = f00; f11 = f10; }
    if (src_h_idx + 1 >= src_h) { f10 = f00; f11 = f01; }

    const int rs = lroundf(lerp2d(src_img[f00], src_img[f01], src_img[f10], src_img[f11],
                                  centroid_h, centroid_w));

    const long long dst_idx = (((long long)n * H + h) * W + w) * C + c;
    dst_img[dst_idx] = (unsigned char)rs;
}
167 | """)
168 |
169 | # block = (32, 32, 1) blockDim | threadIdx
170 | # grid = (19,19,3)) gridDim | blockIdx
171 |
172 | cuResizeKer = module.get_function("cuResize")
173 | TransposeKer = module.get_function("Transpose")
174 | TransNorKer = module.get_function("Transpose_and_normalise")
175 |
176 |
177 |
178 | class cuResize():
179 | """docstring for ClassName"""
def __init__(self, shape=(1920,1080), batch=50, frame_w=1920, frame_h=1080):
    """Configure a fixed-size batch resizer, allocate its buffers and warm it up.

    shape            : (width, height) of the resized output frames.
    batch            : frames per call; by the author's estimate roughly
                       200 x 1080p or 50 x 4K frames fit.
    frame_w, frame_h : largest input frame size that __call__ accepts.
    """
    self.batch = batch
    self.channel = 3
    self.frame_w, self.frame_h = frame_w, frame_h
    self.dst_w, self.dst_h = shape[0], shape[1]
    self.DST_SIZE = self.dst_h * self.dst_w * 3

    # buffers are populated by allocate_memory()
    self.inp = None
    self.out = None
    # async stream used for every kernel launch
    self.stream = cuda.Stream()

    self.allocate_memory()
    self.warm_up()
199 |
200 |
201 | def allocate_memory(self):
202 | self.inp = {"host":cuda.pagelocked_zeros(shape=(self.batch,self.frame_h,self.frame_w,self.channel),
203 | dtype=np.uint8,
204 | mem_flags=cuda.host_alloc_flags.DEVICEMAP)}
205 | self.inp["device"] = cuda.mem_alloc(self.inp["host"].nbytes)
206 |
207 |
208 | self.out = {"host":cuda.pagelocked_zeros(shape=(self.batch,self.dst_h,self.dst_w,self.channel),
209 | dtype=np.uint8,
210 | mem_flags=cuda.host_alloc_flags.DEVICEMAP)}
211 | self.out["device"] = cuda.mem_alloc(self.out["host"].nbytes)
212 |
213 |
214 |
215 | def warm_up(self):
216 | cuResizeKer(self.inp["device"], self.out["device"],
217 | np.int32(self.dst_h), np.int32(self.dst_w),
218 | np.int32(self.dst_h), np.int32(self.dst_w),
219 | np.float32(1), np.float32(1),
220 | block=(1024, 1, 1),
221 | grid=(int(self.DST_SIZE/3//1024)+1,self.batch,3),
222 | stream=self.stream)
223 |
224 | @profile
225 | def __call__(self, input_img: np.ndarray):
226 | """
227 | Resize the batch image to (608,608)
228 | and Convert NHWC to NCHW
229 | pass the gpu array to normalize the pixel ( divide by 255)
230 | Application oriented
231 | input_img : batch input, format: NHWC , recommend RGB. *same as the NN input format
232 | input must be 3 channel, kernel set ChannelDim as 3.
233 | out : batch resized array, format: NCHW , same as intput channel
234 | """
235 | batch, src_h, src_w, channel = input_img.shape
236 | assert (src_h <= self.frame_h) & (src_w <= self.frame_w)
237 | self.inp["host"][:,:src_h,:src_w,:] = input_img
238 | cuda.memcpy_htod_async(self.inp["device"], self.inp["host"],self.stream)
239 |
240 | cuResizeKer(self.inp["device"], self.out["device"],
241 | np.int32(src_h), np.int32(src_w),
242 | np.int32(self.dst_h), np.int32(self.dst_w),
243 | np.float32(src_h/self.dst_h), np.float32(src_w/self.dst_w),
244 | block=(1024, 1, 1),
245 | grid=(int(self.DST_SIZE/3//1024)+1,self.batch,3),
246 | stream=self.stream)
247 |
248 | cuda.memcpy_dtoh_async(self.out["host"], self.out["device"],self.stream)
249 |
250 | self.stream.synchronize()
251 | # self.cleanup()
252 | return self.out["host"]
253 |
254 | def cleanup(self):
255 | self.inp["host"][:,:,:,:] = 0
256 |
257 | def print_stats(self):
258 | profile.print_stats()
259 |
260 | # def deallocate(self):
261 | # free(gpu_mem)
262 |
263 |
264 | if __name__ == "__main__":
265 | print("[ WARNING ] - pycuda is deprecated , recommend cupy instead")
266 | from time import perf_counter
267 | batch = 200
268 | img_batch = np.tile(cv2.resize(cv2.imread("trump.jpg"),(1920,1080)),[batch,1,1,1])
269 | resizer = cuResize(shape=(1920,1080), batch=200, frame_h=1080, frame_w=1920) # C backend hv to pre allocate input frame maximum dimension
270 |
271 | for _ in range(10):
272 | start = perf_counter()
273 | batch_result = resizer(img_batch)
274 | print("cuResize: ",perf_counter()- start,"s")
275 | print(batch_result.shape)
276 | resizer.print_stats()
277 |
278 | # batch_result = np.transpose(batch_result,[0,2,3,1])
279 |
280 | cv2.imwrite("output_1.jpg", batch_result[0])
281 | cv2.imwrite("output_50.jpg", batch_result[49])
282 | cv2.imwrite("output_102.jpg", batch_result[101])
283 | print(batch_result.shape)
--------------------------------------------------------------------------------
/deprecated/resize_free_dim.py:
--------------------------------------------------------------------------------
1 | import pycuda.driver as cuda
2 | import pycuda.autoinit
3 | from pycuda.compiler import SourceModule
4 | from pycuda import gpuarray
5 | import numpy as np
6 | import cv2
7 | from line_profiler import LineProfiler
8 |
9 | profile = LineProfiler()
10 | # Switches read by gpu_resize (bl_Normalize is checked before bl_Trans):
11 | bl_Normalize = 0 # 1: return NCHW float32 scaled by 1/255
12 | bl_Trans = 1 # 1: return NCHW uint8; with both 0 the NHWC uint8 resize result is returned
13 | pagelock = 1 # 1: stage host buffers in page-locked (pinned) memory
14 |
15 | module = SourceModule("""
16 | // lerp1d: returns a + w*(b - a) (w = weight of b), computed in float with fused multiply-adds.
17 | __device__ float lerp1d(float a, float b, float w)
18 | {
19 | return fmaf(w, b, fmaf(-w, a, a));
20 | }
21 | // lerp2d: blends the 2x2 neighbourhood (f00 f01 / f10 f11; first digit = row) using weights derived from the centroids.
22 | __device__ float lerp2d(int f00, int f01, int f10, int f11,
23 | float centroid_h, float centroid_w )
24 | {
25 | centroid_w = (1 + lroundf(centroid_w) - centroid_w)/2;
26 | centroid_h = (1 + lroundf(centroid_h) - centroid_h)/2;
27 | // NOTE(review): these weights are (1 - frac)/2, confined to [0.25, 0.75]; a textbook half-pixel bilinear weight would be centroid + 0.5 - lroundf(centroid). Confirm intent.
28 | float r0, r1, r;
29 | r0 = lerp1d(f00,f01,centroid_w);
30 | r1 = lerp1d(f10,f11,centroid_w);
31 | // r0/r1 stay float: lerp1d takes float, so the row results are not truncated before the vertical blend.
32 | r = lerp1d(r0, r1, centroid_h);
33 | return r;
34 | }
35 | // Transpose: copies an NHWC uint8 batch into NCHW. Launch: grid=(x, batch, C) with 1-D blocks; H, W are the image size.
36 | __global__ void Transpose(unsigned char *odata, const unsigned char *idata,
37 | int H, int W)
38 | {
39 | int n = blockIdx.y; // batch number
40 | int C = gridDim.z; // channel
41 | int c = blockIdx.z; // channel number
42 | long long idx = (long long)n * blockDim.x * gridDim.x * C +
43 | threadIdx.x * gridDim.x * C +
44 | blockIdx.x * C+
45 | c;
46 | int img_coor = idx % (H*W*C); //coordinate of one image, not idx of batch image
47 | int h = img_coor / (W*C); // dst idx
48 | int w = img_coor % (W*C)/C; // dst idx
49 | if (idx%(blockDim.x * gridDim.x * C) >= H * W * C){return;} // surplus threads of the padded grid; the rest still cover every pixel of image n
50 | long long src_idx = (long long)n * (H * W * C) +
51 | h * (W * C) +
52 | w * C +
53 | c;
54 |
55 | long long dst_idx = (long long)n * (C * H * W) +
56 | c * (H * W)+
57 | h * W+
58 | w;
59 |
60 | odata[dst_idx] = idata[src_idx];
61 | }
62 | // Transpose_and_normalise: like Transpose, but writes float NCHW scaled by 1/255. Launch: grid=(x, batch, C), 1-D blocks.
63 | __global__ void Transpose_and_normalise(float *odata, const unsigned char *idata,
64 | int H, int W)
65 | {
66 | int n = blockIdx.y; // batch number
67 | int C = gridDim.z; // channel
68 | int c = blockIdx.z; // channel number
69 | long long idx = (long long)n * blockDim.x * gridDim.x * C +
70 | threadIdx.x * gridDim.x * C +
71 | blockIdx.x * C+
72 | c;
73 | int img_coor = idx % (H*W*C); //coordinate of one image, not idx of batch image
74 | int h = img_coor / (W*C); // dst idx
75 | int w = img_coor % (W*C)/C; // dst idx
76 | if (idx%(blockDim.x * gridDim.x * C) >= H * W * C){return;} // surplus threads of the padded grid; the rest still cover every pixel of image n
77 | long long src_idx = (long long)n * (H * W * C) +
78 | h * (W * C) +
79 | w * C +
80 | c;
81 |
82 | long long dst_idx = (long long)n * (C * H * W) +
83 | c * (H * W)+
84 | h * W+
85 | w;
86 |
87 | odata[dst_idx] = idata[src_idx]/255.0;
88 | }
89 | // cuResize: bilinear resize of an NHWC uint8 batch; one thread per destination (h, w, c), launched with grid=(x, batch, C) and 1-D blocks.
90 | __global__ void cuResize(unsigned char* src_img, unsigned char* dst_img,
91 | const int src_h, const int src_w,
92 | const int dst_h, const int dst_w,
93 | const float scale_h, const float scale_w)
94 | {
95 | /*
96 | Input:
97 | src_img - NHWC
98 | channel C, default = 3
99 |
100 | Output:
101 | dst_img - NHWC
102 | */
103 | // (long long) casts keep the batch offsets in 64-bit; plain int products overflow on large batches.
104 | int n = blockIdx.y; // batch number
105 | int C = gridDim.z; // channel
106 | int c = blockIdx.z; // channel number
107 | long long idx = (long long)n * blockDim.x * gridDim.x * C +
108 | threadIdx.x * gridDim.x * C +
109 | blockIdx.x * C+
110 | c;
111 |
112 | // some overhead threads in each image process
113 | // when thread idx in one image exceed one image size return;
114 | if (idx%(blockDim.x * gridDim.x * C) >= dst_h* dst_w * C){return;}
115 | // the modulo below maps the remaining threads of image n onto each of its pixels exactly once
116 | int H = dst_h;
117 | int W = dst_w;
118 | int img_coor = idx % (dst_h*dst_w*C); //coordinate of one image, not idx of batch image
119 | int h = img_coor / (W*C);
120 | int w = img_coor % (W*C)/C;
121 | // centroid of destination pixel (h, w) in source coordinates (half-pixel centres)
122 | float centroid_h, centroid_w;
123 | centroid_h = scale_h * (h + 0.5); // h w c -> x, y, z : 1080 , 1920 , 3
124 | centroid_w = scale_w * (w + 0.5); //
125 |
126 | long long f00,f01,f10,f11;
127 | // top-left neighbour, clamped at 0; the +1 neighbours are clamped against src_h/src_w below
128 | int src_h_idx = lroundf(centroid_h)-1;
129 | int src_w_idx = lroundf(centroid_w)-1;
130 | if (src_h_idx<0){src_h_idx=0;}
131 | if (src_w_idx<0){src_w_idx=0;}
132 |
133 | f00 = (long long)n * src_h * src_w * C +
134 | src_h_idx * src_w * C +
135 | src_w_idx * C +
136 | c;
137 | f01 = (long long)n * src_h * src_w * C +
138 | src_h_idx * src_w * C +
139 | (src_w_idx+1) * C +
140 | c;
141 |
142 | f10 = (long long)n * src_h * src_w * C +
143 | (src_h_idx+1) * src_w * C +
144 | src_w_idx * C +
145 | c;
146 | f11 = (long long)n * src_h * src_w * C +
147 | (src_h_idx+1) * src_w * C +
148 | (src_w_idx+1) * C +
149 | c;
150 | // replicate the last row/column instead of reading past the source image
151 | if (src_w_idx+1>=src_w){f01 = f00; f11 = f10;}
152 | if (src_h_idx+1>=src_h){f10 = f00; f11 = f01;}
153 |
154 | int rs = lroundf(lerp2d(src_img[f00], src_img[f01], src_img[f10], src_img[f11],
155 | centroid_h, centroid_w));
156 |
157 | long long dst_idx = (long long)n * (H * W * C) +
158 | h * (W * C) +
159 | w * C +
160 | c;
161 |
162 | dst_img[dst_idx] = (unsigned char)rs;
163 | }
164 | """)
165 |
166 | # Kernel handles; gpu_resize launches cuResize and Transpose with block=(1024, 1, 1)
167 | # and grid=(dst_h*dst_w//1024 + 1, batch, 3).
168 |
169 | cuResizeKer = module.get_function("cuResize")
170 | TransposeKer = module.get_function("Transpose")
171 | TransNorKer = module.get_function("Transpose_and_normalise")
172 |
173 | @profile
174 | def gpu_resize(input_img: np.ndarray, shape=(608,608)):
175 | """
176 | Resize the batch image to (608,608)
177 | and Convert NHWC to NCHW
178 | pass the gpu array to normalize the pixel ( divide by 255)
179 |
180 | Application oriented
181 |
182 | input_img : batch input, format: NHWC , recommend RGB. *same as the NN input format
183 | input must be 3 channel, kernel set ChannelDim as 3.
184 | out : batch resized array, format: NCHW , same as intput channel
185 | """
186 | # ========= Init Params =========
187 | stream = cuda.Stream()
188 |
189 | # convert to array
190 | batch, src_h, src_w, channel = input_img.shape
191 | dst_h, dst_w = shape[0], shape[1]
192 | DST_SIZE = dst_h* dst_w* 3
193 | # Mem Allocation
194 | # input memory
195 |
196 | if pagelock: # = = = = = = Pagelock emory = = = = = =
197 | inp = {"host":cuda.pagelocked_zeros(shape=(batch,src_h,src_w,channel),
198 | dtype=np.uint8,
199 | mem_flags=cuda.host_alloc_flags.DEVICEMAP)}
200 | # inp = {"host":cuda.pagelocked_empty_like(input_img,
201 | # mem_flags=cuda.host_alloc_flags.DEVICEMAP)}
202 | # print(inp["host"].shape,input_img.shape)
203 | inp["host"][:,:src_h,:src_w,:] = input_img
204 | else: # = = = = = = Global memory = = = = = =
205 | inp = {"host":input_img}
206 |
207 | inp["device"] = cuda.mem_alloc(inp["host"].nbytes)
208 | cuda.memcpy_htod_async(inp["device"], inp["host"],stream)
209 |
210 |
211 |
212 |
213 | # output data
214 | if pagelock: # = = = = = = Pagelock emory = = = = = =
215 | out = {"host":cuda.pagelocked_zeros(shape=(batch,dst_h,dst_w,channel),
216 | dtype=np.uint8,
217 | mem_flags=cuda.host_alloc_flags.DEVICEMAP)}
218 | else: # = = = = = = Global memory = = = = = =
219 | out = {"host":np.zeros(shape=(batch,dst_h,dst_w,channel), dtype=np.uint8)} # N H W C
220 |
221 | out["device"] = cuda.mem_alloc(out["host"].nbytes)
222 | cuda.memcpy_htod_async(out["device"], out["host"],stream)
223 |
224 | import time
225 | time.sleep(5)
226 |
227 | #Transpose (and Normalize)
228 | if bl_Normalize or bl_Trans:
229 | if bl_Normalize:
230 | if pagelock:
231 | trans = {"host":cuda.pagelocked_zeros(shape=(batch,channel,dst_h,dst_w),
232 | dtype=np.float32,
233 | mem_flags=cuda.host_alloc_flags.DEVICEMAP)} # N C H W
234 | else:
235 | trans = {"host":np.zeros(shape=(batch,channel,dst_h,dst_w), dtype=np.float32)} # N C H W
236 | else:
237 | if pagelock:
238 | trans = {"host":cuda.pagelocked_zeros(shape=(batch,channel,dst_h,dst_w),
239 | dtype=np.uint8,
240 | mem_flags=cuda.host_alloc_flags.DEVICEMAP)}
241 | else:
242 | trans = {"host":np.zeros(shape=(batch,channel,dst_h,dst_w), dtype=np.uint8)} # N C H W
243 |
244 | trans["device"] = cuda.mem_alloc(trans["host"].nbytes)
245 | cuda.memcpy_htod_async(trans["device"], trans["host"],stream)
246 |
247 | # init resize , store kernel in cache
248 | cuResizeKer(inp["device"], out["device"],
249 | np.int32(src_h), np.int32(src_w),
250 | np.int32(dst_h), np.int32(dst_w),
251 | np.float32(src_h/dst_h), np.float32(src_w/dst_w),
252 | block=(1024, 1, 1),
253 | grid=(int(DST_SIZE/3//1024)+1,batch,3),
254 | stream=stream)
255 |
256 | # ========= Testing =========
257 |
258 | for _ in range(1):
259 | cuResizeKer(inp["device"], out["device"],
260 | np.int32(src_h), np.int32(src_w),
261 | np.int32(dst_h), np.int32(dst_w),
262 | np.float32(src_h/dst_h), np.float32(src_w/dst_w),
263 | block=(1024, 1, 1),
264 | grid=(int(DST_SIZE/3//1024)+1,batch,3))
265 |
266 | # ========= Copy out result =========
267 |
268 | if bl_Normalize:
269 | TransNorKer(trans["device"],out["device"],
270 | block=(32, 32, 1),
271 | grid=(19,19,3*batch))
272 | cuda.memcpy_dtoh_async(trans["host"], trans["device"],stream)
273 | stream.synchronize()
274 | return trans["host"]
275 | elif bl_Trans:
276 | TransposeKer(trans["device"],out["device"],
277 | np.int32(dst_h), np.int32(dst_w),
278 | block=(1024, 1, 1),
279 | grid=(int(DST_SIZE/3//1024)+1,batch,3))
280 | cuda.memcpy_dtoh_async(trans["host"], trans["device"],stream)
281 | stream.synchronize()
282 | return trans["host"]
283 | else:
284 | cuda.memcpy_dtoh_async(out["host"], out["device"],stream)
285 | stream.synchronize()
286 | return out["host"]
287 |
288 | if __name__ == "__main__":
289 | # img = cv2.resize(cv2.imread("trump.jpg"),(1920,1080))
290 | # img = cv2.imread("trump.jpg")
291 | # img = np.tile(img,[batch,1,1,1])
292 |
293 | # img = np.zeros(shape=(3,1080,1920,3),dtype = np.uint8)
294 | # img[0,:48,:64,:] = cv2.resize(cv2.imread("trump.jpg"),(64,48))
295 | # img[1,:480,:640,:] = cv2.resize(cv2.imread("trump.jpg"),(640,480))
296 | # img[2,:1080,:1920,:] = cv2.resize(cv2.imread("trump.jpg"),(1920,1080))
297 | # Demo: resize 50 copies of trump.jpg (scaled to 1920x1080) to 480x640; gpu_resize's shape is (dst_h, dst_w).
298 | batch = 50
299 | # img_batch_0 = np.tile(cv2.resize(cv2.imread("trump.jpg"),(20,20)),[batch,1,1,1])
300 | # img_batch_1 = np.tile(cv2.resize(cv2.imread("trump.jpg"),(320,240)),[batch,1,1,1])
301 | img_batch_2 = np.tile(cv2.resize(cv2.imread("trump.jpg"),(1920,1080)),[batch,1,1,1])
302 |
303 | # rgba_img = cv2.resize(cv2.imread("rgba.png"),(20,20))
304 | # img_batch_0[10] = rgba_img
305 | # img_batch_0[20] = rgba_img
306 | # img_batch_0[53] = rgba_img
307 | # gpu_resize returns NCHW when bl_Trans or bl_Normalize is set; it is transposed back to NHWC for cv2.imwrite.
308 | # pix_0 = gpu_resize(img_batch_0)
309 | # pix_1 = gpu_resize(img_batch_1)
310 | pix_2 = gpu_resize(img_batch_2,shape = (480,640))
311 | if bl_Normalize or bl_Trans:
312 | # print(1)
313 | # pix_0 = np.transpose(pix_0,[0,2,3,1])
314 | # pix_1 = np.transpose(pix_1,[0,2,3,1])
315 | pix_2 = np.transpose(pix_2,[0,2,3,1])
316 | # cv2.imwrite("trans0.jpg", pix_0[0])
317 | # cv2.imwrite("trans1.jpg", pix_1[0])
318 | cv2.imwrite("trans2.jpg", pix_2[0])
319 | print("Done")
320 |
321 | # print(pix_0[0])
322 | # print(pix_0[-1])
323 | # print(pix_0.shape)
324 |
325 | # imgs = pix_1
326 | # for idx,img in enumerate(list(imgs)):
327 | # print(idx)
328 | # assert np.array_equal(imgs[0],img)
329 |
330 | # profile.print_stats()
331 | # print(pix.shape)
332 | # cv2.imwrite("pycuda_outpuut.jpg", pix[0])
--------------------------------------------------------------------------------
/deprecated/resize_ker.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | // lerp1d: returns a + w*(b - a) (w = weight of b), computed in float with fused multiply-adds.
3 | __device__ float lerp1d(float a, float b, float w)
4 | {
5 | return fmaf(w, b, fmaf(-w, a, a));
6 | }
7 | // lerp2d: blends the 2x2 neighbourhood (f00 f01 / f10 f11; first digit = row) using weights derived from the centroids.
8 | __device__ float lerp2d(int f00, int f01, int f10, int f11,
9 | float centroid_h, float centroid_w )
10 | {
11 | centroid_w = (1 + lroundf(centroid_w) - centroid_w)/2;
12 | centroid_h = (1 + lroundf(centroid_h) - centroid_h)/2;
13 | // NOTE(review): these weights are (1 - frac)/2, confined to [0.25, 0.75]; a textbook half-pixel bilinear weight would be centroid + 0.5 - lroundf(centroid). Confirm intent.
14 | float r0, r1, r;
15 | r0 = lerp1d(f00,f01,centroid_w);
16 | r1 = lerp1d(f10,f11,centroid_w);
17 | // r0/r1 stay float: lerp1d takes float, so the row results are not truncated before the vertical blend.
18 | r = lerp1d(r0, r1, centroid_h);
19 | // printf("re: %f, %f | %f, %f | %f, %f | %f | %d, %d, %d, %d \n", centroid_x , centroid_y, centroid_x_re, centroid_y_re, r0, r1, r, f00, f01, f10, f11);
20 | return r;
21 | }
22 | // tester: standalone NHWC bilinear resize; one thread per (h, w) with blockIdx.z = n*3 + c. H and W come from the launch geometry.
23 | __global__ void tester(unsigned char* src_img, unsigned char* dst_img,
24 | int src_h, int src_w,
25 | float stride_h, float stride_w)
26 | {
27 | int H = blockDim.x * gridDim.x; // # dst_height
28 | int W = blockDim.y * gridDim.y; // # dst_width
29 | int h = blockDim.x * blockIdx.x + threadIdx.x; // 32 * bkIdx[0:18] + tdIdx; [0,607] # x / h-th row
30 | int w = blockDim.y * blockIdx.y + threadIdx.y; // 32 * bkIdx[0:18] + tdIdx; [0,607] # y / w-th col
31 | int C = 3; // # ChannelDim
32 | int c = blockIdx.z % 3 ; // [0,2] # ChannelIdx
33 | int n = blockIdx.z / 3 ; // [0 , Batch size-1], # BatchIdx
34 | int N = gridDim.z / 3 ;
35 |
36 | // printf("%d(%d), %d(%d), %d(%d), %d(%d) \n",n,N,c,C,h,H,w,W);
37 | // idx = NHWC = n*(HWC) + h*(WC) + w*C + c;
38 | int idx = n * (H * W * C) +
39 | h * (W * C) +
40 | w * C +
41 | c;
42 |
43 | // idx = NCHW = n*(CHW) + c*(HW) + h*W + w
44 | // int idx = n * (C * H * W) +
45 | // c * (H * W)+
46 | // h * W+
47 | // w;
48 |
49 | // int idx = x * blockDim.y * gridDim.y * gridDim.z + y * gridDim.z + z; // x * 608(width) * 3(channel) + y * 3(channel) + [0,2]
50 |
51 | float centroid_h, centroid_w;
52 | centroid_h = stride_h * (h + 0.5); // h w c -> x, y, z : 1080 , 1920 , 3
53 | centroid_w = stride_w * (w + 0.5); //
54 |
55 | int f00,f01,f10,f11;
56 |
57 | int src_h_idx = lroundf(centroid_h)-1;
58 | int src_w_idx = lroundf(centroid_w)-1;
59 | if (src_h_idx<0){src_h_idx=0;}
60 | if (src_w_idx<0){src_w_idx=0;}
61 | // printf("h:%d w:%d\n",src_h_idx,src_w_idx);
62 |
63 | // // idx = NHWC = n*(HWC) + h*(WC) + w*C + c;
64 | f00 = n * src_h * src_w * C +
65 | src_h_idx * src_w * C +
66 | src_w_idx * C +
67 | c;
68 | f01 = n * src_h * src_w * C +
69 | src_h_idx * src_w * C +
70 | (src_w_idx+1) * C +
71 | c;
72 | f10 = n * src_h * src_w * C +
73 | (src_h_idx+1) * src_w * C +
74 | src_w_idx * C +
75 | c;
76 | f11 = n * src_h * src_w * C +
77 | (src_h_idx+1) * src_w * C +
78 | (src_w_idx+1) * C +
79 | c;
80 | if (src_w_idx+1>=src_w){f01 = f00; f11 = f10;} // clamp the right neighbour to the last source column (hit when upscaling)
81 | // bool bl_a = (f01 == (f00 + 3));
82 | // bool bl_b = (f10 == (f00 + src_w * 3));
83 | // bool bl_c = (f11 == (f00 + src_w * 3 + 3));
84 | // printf("%d, %d, %d | %d, %d, %d | %d\n", bl_a,bl_b,bl_c, f01-f00, f10-f00,f11-f00, src_w);
85 | if (src_h_idx+1>=src_h){f10 = f00; f11 = f01;} // clamp the bottom neighbour to the last source row
86 |
87 |
88 | // printf("h: %d, w: %d | %d, %d, %d , %d | %d, %d | %d, %d, %d, %d \n", src_h_idx, src_w_idx, f00,f01,f10,f11, C, c, src_img[f00], src_img[f01], src_img[f10], src_img[f11]);
89 |
90 |
91 | // lerp2d(src_img[f00], src_img[f01], src_img[f10], src_img[f11],);
92 | // printf("%d, %d | %d, %d, %d, %d \n", src_h_idx, src_w_idx, src_img[f00], src_img[f01], src_img[f10], src_img[f11]);
93 |
94 | // float temp = lerp2d(src_img[f00], src_img[f01], src_img[f10], src_img[f11],
95 | // centroid_y, centroid_x);
96 | // printf("z: %d | %f, %f | %f | %d, %d, %d, %d \n", z, centroid_x, centroid_y, temp, src_img[f00], src_img[f01], src_img[f10], src_img[f11]);
97 | // printf("%f",temp);
98 |
99 |
100 | int rs = lroundf(lerp2d(src_img[f00], src_img[f01], src_img[f10], src_img[f11],
101 | centroid_h, centroid_w));
102 | // printf("rs: %d | centroid: h:%f, w:%f | h: %d, w: %d | %d, %d, %d , %d | %d, %d | %d, %d, %d, %d \n", rs, centroid_h, centroid_w, src_h_idx, src_w_idx, f00,f01,f10,f11, C, c, src_img[f00], src_img[f01], src_img[f10], src_img[f11]);
103 | // printf("rs: %d | stride h: %f , w: %f | centroid: h:%f, w:%f| h: %d, w: %d | %d, %d, %d , %d | %d, %d | %d, %d, %d, %d \n", rs, stride_h, stride_w, centroid_h, centroid_w, src_h_idx, src_w_idx, f00,f01,f10,f11, C, c, src_img[f00], src_img[f01], src_img[f10], src_img[f11]);
104 | // printf("z: %d | %f, %f | %d | %d, %d, %d, %d \n", z, centroid_x, centroid_y, rs, src_img[f00], src_img[f01], src_img[f10], src_img[f11]);
105 |
106 | dst_img[idx] = (unsigned char)rs;
107 | }
108 | // main: resizes one synthetic 1920x1080x3 frame to 608x608 with tester and prints the first source and result pixels.
109 | int main(){
110 | // dim3 dimBlock(32,32,1); << Max total is 1024 , so , x=32 ,y=32 , some one use 1024 to handle flatten tensor is fine.
111 | // dim3 dimGrid(19,19,3); << x = 608 / 32 = 19 , same on y , z = channel * batch_size, assume channel = 3.
112 | dim3 dimBlock(32,32,1);
113 | dim3 dimGrid(19,19,3);
114 |
115 | static unsigned char host_src[1920*1080*3]; // static: ~6 MB is too large for the stack
116 | // unsigned char host_dst[1108992];
117 | static unsigned char host_dst[608*608*3];
118 |
119 | // init src image
120 | for(int i = 0; i < 1920*1080*3; i++){
121 | host_src[i] = i+1;
122 | // host_src[i] = (i%3);
123 | }
124 |
125 | float stride_h = 1080.0 / 608;
126 | float stride_w = 1920.0 / 608;
127 |
128 | unsigned char *device_src, *device_dst;
129 | cudaMalloc((unsigned char **)&device_src, 1920*1080*3* sizeof(unsigned char));
130 | cudaMalloc((unsigned char **)&device_dst, 608*608*3* sizeof(unsigned char));
131 |
132 | cudaMemcpy(device_src , host_src , 1920*1080*3 * sizeof(unsigned char), cudaMemcpyHostToDevice);
133 |
134 | tester<<<dimGrid, dimBlock>>>(device_src, device_dst,
135 | 1080, 1920,
136 | stride_h, stride_w);
137 | cudaDeviceSynchronize();
138 | cudaError_t err = cudaGetLastError(); if (err != cudaSuccess){ printf("CUDA error: %s\n", cudaGetErrorString(err)); return 1; }
139 | cudaMemcpy(host_dst, device_dst, 608*608*3 * sizeof(unsigned char), cudaMemcpyDeviceToHost);
140 |
141 | // DEBUG : print first image in batch , first 30 pixel in 3 channels.
142 |
143 | for(int i = 0; i < 30*3; i+=3){ // NHWC
144 | printf("%d\n",host_src[i]);
145 | }
146 | printf("============================\n");
147 |
148 | // for(int c = 0; c<3*608*608 ; c+=608*608){ // if NCHW
149 | // for(int i = 0 ; i < 30; i++){
150 | // printf("%d %d %d\n", c+i, i, host_dst[c+i]);
151 | // }
152 | // printf("------------------------------\n");
153 | // }
154 | for(int c = 0; c<3; c++){ // NHWC
155 | for(int i = 0 ; i < 30; i++){
156 | int idx = i*3 +c;
157 | printf("%d %d %d\n", c+i, i, host_dst[idx]);
158 | }
159 | printf("------------------------------\n");
160 | }
161 |
162 |
163 |
164 | cudaFree(device_src);
165 | cudaFree(device_dst);
166 |
167 | return 0;
168 | }
--------------------------------------------------------------------------------
/deprecated/resize_multiple_frame_dim.py:
--------------------------------------------------------------------------------
1 | import pycuda.driver as cuda
2 | import pycuda.autoinit
3 | from pycuda.compiler import SourceModule
4 | from pycuda import gpuarray
5 | import numpy as np
6 | import cv2
7 | from line_profiler import LineProfiler
8 |
9 | profile = LineProfiler()
10 | # Switches read by gpu_resize (bl_Normalize is checked before bl_Trans):
11 | bl_Normalize = 0 # 1: return NCHW float32 scaled by 1/255
12 | bl_Trans = 1 # 1: return NCHW uint8; with both 0 the NHWC uint8 resize result is returned
13 | pagelock = 1 # 1: stage host buffers in page-locked (pinned) memory
14 |
15 | module = SourceModule("""
16 | // lerp1d: returns a + w*(b - a) (w = weight of b), computed in float with fused multiply-adds.
17 | __device__ float lerp1d(float a, float b, float w)
18 | {
19 | return fmaf(w, b, fmaf(-w, a, a));
20 | }
21 |
22 | // lerp2d: blends the 2x2 neighbourhood (f00 f01 / f10 f11; first digit = row) using weights derived from the centroids.
23 | __device__ float lerp2d(int f00, int f01, int f10, int f11,
24 | float centroid_h, float centroid_w )
25 | {
26 | centroid_w = (1 + lroundf(centroid_w) - centroid_w)/2;
27 | centroid_h = (1 + lroundf(centroid_h) - centroid_h)/2;
28 | // NOTE(review): these weights are (1 - frac)/2, confined to [0.25, 0.75]; a textbook half-pixel bilinear weight would be centroid + 0.5 - lroundf(centroid). Confirm intent.
29 | float r0, r1, r;
30 | r0 = lerp1d(f00,f01,centroid_w);
31 | r1 = lerp1d(f10,f11,centroid_w);
32 | // r0/r1 stay float: lerp1d takes float, so the row results are not truncated before the vertical blend.
33 | r = lerp1d(r0, r1, centroid_h);
34 | return r;
35 | }
36 | // Transpose: copies an NHWC uint8 batch to NCHW. H, W come from the launch geometry (blockDim * gridDim) and blockIdx.z = n*3 + c; there is no bounds check, so the grid must tile the image exactly.
37 | __global__ void Transpose(unsigned char *odata, const unsigned char *idata)
38 | {
39 | int H = blockDim.x * gridDim.x; // # dst_height
40 | int W = blockDim.y * gridDim.y; // # dst_width
41 | int h = blockDim.x * blockIdx.x + threadIdx.x; // 32 * bkIdx[0:18] + tdIdx; [0,607] # x / h-th row
42 | int w = blockDim.y * blockIdx.y + threadIdx.y; // 32 * bkIdx[0:18] + tdIdx; [0,607] # y / w-th col
43 | int C = 3; // # ChannelDim
44 | int c = blockIdx.z % 3 ; // [0,2] # ChannelIdx
45 | int n = blockIdx.z / 3 ; // [0 , Batch size-1], # BatchIdx
46 | // source element, NHWC order
47 | long src_idx = n * (H * W * C) +
48 | h * (W * C) +
49 | w * C +
50 | c;
51 | // destination element, NCHW order
52 | long dst_idx = n * (C * H * W) +
53 | c * (H * W)+
54 | h * W+
55 | w;
56 | // NOTE(review): the index products are evaluated in int before widening to long; very large batches could overflow.
57 | odata[dst_idx] = idata[src_idx];
58 | }
59 | // Transpose_and_normalise: like Transpose (same launch-geometry assumptions), but writes float NCHW scaled by 1/255.
60 | __global__ void Transpose_and_normalise(float *odata, const unsigned char *idata)
61 | {
62 | int H = blockDim.x * gridDim.x; // # dst_height
63 | int W = blockDim.y * gridDim.y; // # dst_width
64 | int h = blockDim.x * blockIdx.x + threadIdx.x; // 32 * bkIdx[0:18] + tdIdx; [0,607] # x / h-th row
65 | int w = blockDim.y * blockIdx.y + threadIdx.y; // 32 * bkIdx[0:18] + tdIdx; [0,607] # y / w-th col
66 | int C = 3; // # ChannelDim
67 | int c = blockIdx.z % 3 ; // [0,2] # ChannelIdx
68 | int n = blockIdx.z / 3 ; // [0 , Batch size-1], # BatchIdx
69 | // source element, NHWC order
70 | long src_idx = n * (H * W * C) +
71 | h * (W * C) +
72 | w * C +
73 | c;
74 | // destination element, NCHW order
75 | long dst_idx = n * (C * H * W) +
76 | c * (H * W)+
77 | h * W+
78 | w;
79 |
80 | odata[dst_idx] = idata[src_idx]/255.0;
81 | }
82 | // YoloResize: src_img holds frames padded to frame_h x frame_w (NHWC uint8), each with a valid src_h x src_w area. H, W come from the launch geometry (2-D blocks, blockIdx.z = n*3 + c).
83 | __global__ void YoloResize(unsigned char* src_img, unsigned char* dst_img,
84 | int src_h, int src_w,
85 | int frame_h, int frame_w,
86 | float stride_h, float stride_w)
87 | {
88 | int H = blockDim.x * gridDim.x; // # dst_height
89 | int W = blockDim.y * gridDim.y; // # dst_width
90 | int h = blockDim.x * blockIdx.x + threadIdx.x; // 32 * bkIdx[0:18] + tdIdx; [0,607] # x / h-th row
91 | int w = blockDim.y * blockIdx.y + threadIdx.y; // 32 * bkIdx[0:18] + tdIdx; [0,607] # y / w-th col
92 | int C = 3; // # ChannelDim
93 | int c = blockIdx.z % 3 ; // [0,2] # ChannelIdx
94 | int n = blockIdx.z / 3 ; // [0 , Batch size-1], # BatchIdx
95 |
96 | long long idx = (long long)n * (H * W * C) +
97 | h * (W * C) +
98 | w * C +
99 | c;
100 |
101 | float centroid_h, centroid_w;
102 | centroid_h = stride_h * (h + 0.5); // h w c -> x, y, z : 1080 , 1920 , 3
103 | centroid_w = stride_w * (w + 0.5); //
104 |
105 | long long f00,f01,f10,f11; // 64-bit: n * frame_h * frame_w * C overflows int for large batches of 4K frames
106 |
107 | int src_h_idx = lroundf(centroid_h)-1;
108 | int src_w_idx = lroundf(centroid_w)-1;
109 | if (src_h_idx<0){src_h_idx=0;}
110 | if (src_w_idx<0){src_w_idx=0;}
111 |
112 | f00 = (long long)n * frame_h * frame_w * C +
113 | src_h_idx * frame_w * C +
114 | src_w_idx * C +
115 | c;
116 | f01 = (long long)n * frame_h * frame_w * C +
117 | src_h_idx * frame_w * C +
118 | (src_w_idx+1) * C +
119 | c;
120 | f10 = (long long)n * frame_h * frame_w * C +
121 | (src_h_idx+1) * frame_w * C +
122 | src_w_idx * C +
123 | c;
124 | f11 = (long long)n * frame_h * frame_w * C +
125 | (src_h_idx+1) * frame_w * C +
126 | (src_w_idx+1) * C +
127 | c;
128 | if (src_w_idx+1>=src_w){f01 = f00; f11 = f10;} if (src_h_idx+1>=src_h){f10 = f00; f11 = f01;} // stay inside the valid src_h x src_w area instead of sampling frame padding
129 | int rs = lroundf(lerp2d(src_img[f00], src_img[f01], src_img[f10], src_img[f11],
130 | centroid_h, centroid_w));
131 |
132 | dst_img[idx] = (unsigned char)rs;
133 | }
134 | """)
135 |
136 | # Launch geometry used by gpu_resize for all three kernels: block = (32, 32, 1),
137 | # grid = (19, 19, 3 * batch) -> 608 x 608 output per image, blockIdx.z = n * 3 + c
138 |
139 | YoloResizeKer = module.get_function("YoloResize")
140 | TransposeKer = module.get_function("Transpose")
141 | TransNorKer = module.get_function("Transpose_and_normalise")
142 |
143 | @profile
144 | def gpu_resize(input_img: np.ndarray):
145 | """
146 | Resize the batch image to (608,608)
147 | and Convert NHWC to NCHW
148 | pass the gpu array to normalize the pixel ( divide by 255)
149 |
150 | Application oriented
151 |
152 | input_img : batch input, format: NHWC , recommend RGB. *same as the NN input format
153 | input must be 3 channel, kernel set ChannelDim as 3.
154 | out : batch resized array, format: NCHW , same as intput channel
155 | """
156 | # ========= Init Params =========
157 | stream = cuda.Stream()
158 |
159 | # convert to array
160 | batch, src_h, src_w, channel = input_img.shape
161 | dst_h, dst_w = 608, 608
162 | frame_h, frame_w = 1080*2, 1920*2
163 | assert (src_h <= frame_h) & (src_w <= frame_w)
164 | # Mem Allocation
165 | # input memory
166 |
167 | if pagelock: # = = = = = = Pagelock emory = = = = = =
168 | inp = {"host":cuda.pagelocked_zeros(shape=(batch,frame_h,frame_w,channel),
169 | dtype=np.uint8,
170 | mem_flags=cuda.host_alloc_flags.DEVICEMAP)}
171 | # inp = {"host":cuda.pagelocked_empty_like(input_img,
172 | # mem_flags=cuda.host_alloc_flags.DEVICEMAP)}
173 | # print(inp["host"].shape,input_img.shape)
174 | inp["host"][:,:src_h,:src_w,:] = input_img
175 | else: # = = = = = = Global memory = = = = = =
176 | inp = {"host":input_img}
177 |
178 | inp["device"] = cuda.mem_alloc(inp["host"].nbytes)
179 | cuda.memcpy_htod_async(inp["device"], inp["host"],stream)
180 |
181 |
182 |
183 |
184 | # output data
185 | if pagelock: # = = = = = = Pagelock emory = = = = = =
186 | out = {"host":cuda.pagelocked_zeros(shape=(batch,dst_h,dst_w,channel),
187 | dtype=np.uint8,
188 | mem_flags=cuda.host_alloc_flags.DEVICEMAP)}
189 | else: # = = = = = = Global memory = = = = = =
190 | out = {"host":np.zeros(shape=(batch,dst_h,dst_w,channel), dtype=np.uint8)} # N H W C
191 |
192 | out["device"] = cuda.mem_alloc(out["host"].nbytes)
193 | cuda.memcpy_htod_async(out["device"], out["host"],stream)
194 |
195 |
196 | #Transpose (and Normalize)
197 | if bl_Normalize or bl_Trans:
198 | if bl_Normalize:
199 | if pagelock:
200 | trans = {"host":cuda.pagelocked_zeros(shape=(batch,channel,dst_h,dst_w),
201 | dtype=np.float32,
202 | mem_flags=cuda.host_alloc_flags.DEVICEMAP)} # N C H W
203 | else:
204 | trans = {"host":np.zeros(shape=(batch,channel,dst_h,dst_w), dtype=np.float32)} # N C H W
205 | else:
206 | if pagelock:
207 | trans = {"host":cuda.pagelocked_zeros(shape=(batch,channel,dst_h,dst_w),
208 | dtype=np.uint8,
209 | mem_flags=cuda.host_alloc_flags.DEVICEMAP)}
210 | else:
211 | trans = {"host":np.zeros(shape=(batch,channel,dst_h,dst_w), dtype=np.uint8)} # N C H W
212 |
213 | trans["device"] = cuda.mem_alloc(trans["host"].nbytes)
214 | cuda.memcpy_htod_async(trans["device"], trans["host"],stream)
215 |
216 | # init resize , store kernel in cache
217 | YoloResizeKer(inp["device"], out["device"],
218 | np.int32(src_h), np.int32(src_w),
219 | np.int32(frame_h), np.int32(frame_w),
220 | np.float32(src_h/dst_h), np.float32(src_w/dst_w),
221 | block=(32, 32, 1),
222 | grid=(19,19,3*batch))
223 |
224 | # ========= Testing =========
225 |
226 | for _ in range(10):
227 | YoloResizeKer(inp["device"], out["device"],
228 | np.int32(src_h), np.int32(src_w),
229 | np.int32(frame_h), np.int32(frame_w),
230 | np.float32(src_h/dst_h), np.float32(src_w/dst_w),
231 | block=(32, 32, 1),
232 | grid=(19,19,3*batch))
233 |
234 | # ========= Copy out result =========
235 |
236 | if bl_Normalize:
237 | TransNorKer(trans["device"],out["device"],
238 | block=(32, 32, 1),
239 | grid=(19,19,3*batch))
240 | cuda.memcpy_dtoh_async(trans["host"], trans["device"],stream)
241 | stream.synchronize()
242 | return trans["host"]
243 | elif bl_Trans:
244 | TransposeKer(trans["device"],out["device"],
245 | block=(32, 32, 1),
246 | grid=(19,19,3*batch))
247 | cuda.memcpy_dtoh_async(trans["host"], trans["device"],stream)
248 | stream.synchronize()
249 | return trans["host"]
250 | else:
251 | cuda.memcpy_dtoh_async(out["host"], out["device"],stream)
252 | stream.synchronize()
253 | return out["host"]
254 |
255 | if __name__ == "__main__":
256 | grid = 19
257 | block = 32
258 | batch = 2
259 | # grid/block are not read by gpu_resize (its launch geometry is fixed) and batch is reassigned below.
260 | # img = cv2.resize(cv2.imread("trump.jpg"),(1920,1080))
261 | # img = cv2.imread("trump.jpg")
262 | # img = np.tile(img,[batch,1,1,1])
263 |
264 | # img = np.zeros(shape=(3,1080,1920,3),dtype = np.uint8)
265 | # img[0,:48,:64,:] = cv2.resize(cv2.imread("trump.jpg"),(64,48))
266 | # img[1,:480,:640,:] = cv2.resize(cv2.imread("trump.jpg"),(640,480))
267 | # img[2,:1080,:1920,:] = cv2.resize(cv2.imread("trump.jpg"),(1920,1080))
268 | # Three input sizes (64x48, 320x240, 1920x1080) all go through the same frame-sized staging buffer in gpu_resize.
269 | batch = 58
270 | img_batch_0 = np.tile(cv2.resize(cv2.imread("trump.jpg"),(64,48)),[batch,1,1,1])
271 | img_batch_1 = np.tile(cv2.resize(cv2.imread("trump.jpg"),(320,240)),[batch,1,1,1])
272 | img_batch_2 = np.tile(cv2.resize(cv2.imread("trump.jpg"),(1920,1080)),[batch,1,1,1])
273 | pix_0 = gpu_resize(img_batch_0)
274 | pix_1 = gpu_resize(img_batch_1)
275 | pix_2 = gpu_resize(img_batch_2)
276 | if bl_Normalize or bl_Trans:
277 | pix_0 = np.transpose(pix_0,[0,2,3,1])
278 | pix_1 = np.transpose(pix_1,[0,2,3,1])
279 | pix_2 = np.transpose(pix_2,[0,2,3,1])
280 | cv2.imwrite("trans0.jpg", pix_0[0])
281 | cv2.imwrite("trans1.jpg", pix_1[0])
282 | cv2.imwrite("trans2.jpg", pix_2[0])
283 |
284 | profile.print_stats()
285 | # print(pix.shape)
286 | # cv2.imwrite("pycuda_outpuut.jpg", pix[0])
--------------------------------------------------------------------------------
/deprecated/resize_multiple_frame_dim_refactor.py:
--------------------------------------------------------------------------------
1 | import pycuda.driver as cuda
2 | import pycuda.autoinit
3 | from pycuda.compiler import SourceModule
4 | from pycuda import gpuarray
5 | import numpy as np
6 | import cv2
7 | from line_profiler import LineProfiler
8 | # Module-level LineProfiler; presumably applied as the @profile decorator further down this file (not visible here), so verify.
9 | profile = LineProfiler()
10 |
11 | module = SourceModule("""
12 | // lerp1d: returns a + w*(b - a) (w = weight of b), computed in float with fused multiply-adds.
13 | __device__ float lerp1d(float a, float b, float w)
14 | {
15 | return fmaf(w, b, fmaf(-w, a, a));
16 | }
17 |
18 | // lerp2d: blends the 2x2 neighbourhood (f00 f01 / f10 f11; first digit = row) using weights derived from the centroids.
19 | __device__ float lerp2d(int f00, int f01, int f10, int f11,
20 | float centroid_h, float centroid_w )
21 | {
22 | centroid_w = (1 + lroundf(centroid_w) - centroid_w)/2;
23 | centroid_h = (1 + lroundf(centroid_h) - centroid_h)/2;
24 | // NOTE(review): these weights are (1 - frac)/2, confined to [0.25, 0.75]; a textbook half-pixel bilinear weight would be centroid + 0.5 - lroundf(centroid). Confirm intent.
25 | float r0, r1, r;
26 | r0 = lerp1d(f00,f01,centroid_w);
27 | r1 = lerp1d(f10,f11,centroid_w);
28 | // r0/r1 stay float: lerp1d takes float, so the row results are not truncated before the vertical blend.
29 | r = lerp1d(r0, r1, centroid_h);
30 | return r;
31 | }
32 |
33 | __global__ void Transpose(unsigned char *odata, const unsigned char *idata)
34 | {
35 | int H = blockDim.x * gridDim.x; // # dst_height
36 | int W = blockDim.y * gridDim.y; // # dst_width
37 | int h = blockDim.x * blockIdx.x + threadIdx.x; // 32 * bkIdx[0:18] + tdIdx; [0,607] # x / h-th row
38 | int w = blockDim.y * blockIdx.y + threadIdx.y; // 32 * bkIdx[0:18] + tdIdx; [0,607] # y / w-th col
39 | int C = 3; // # ChannelDim
40 | int c = blockIdx.z % 3 ; // [0,2] # ChannelIdx
41 | int n = blockIdx.z / 3 ; // [0 , Batch size-1], # BatchIdx
42 |
43 | long src_idx = n * (H * W * C) +
44 | h * (W * C) +
45 | w * C +
46 | c;
47 |
48 | long dst_idx = n * (C * H * W) +
49 | c * (H * W)+
50 | h * W+
51 | w;
52 |
53 | odata[dst_idx] = idata[src_idx];
54 | }
55 |
56 | __global__ void Transpose_and_normalise(float *odata, const unsigned char *idata)
57 | {
58 | int H = blockDim.x * gridDim.x; // # dst_height
59 | int W = blockDim.y * gridDim.y; // # dst_width
60 | int h = blockDim.x * blockIdx.x + threadIdx.x; // 32 * bkIdx[0:18] + tdIdx; [0,607] # x / h-th row
61 | int w = blockDim.y * blockIdx.y + threadIdx.y; // 32 * bkIdx[0:18] + tdIdx; [0,607] # y / w-th col
62 | int C = 3; // # ChannelDim
63 | int c = blockIdx.z % 3 ; // [0,2] # ChannelIdx
64 | int n = blockIdx.z / 3 ; // [0 , Batch size-1], # BatchIdx
65 |
66 | long src_idx = n * (H * W * C) +
67 | h * (W * C) +
68 | w * C +
69 | c;
70 |
71 | long dst_idx = n * (C * H * W) +
72 | c * (H * W)+
73 | h * W+
74 | w;
75 |
76 | odata[dst_idx] = idata[src_idx]/255.0;
77 | }
78 |
79 | __global__ void YoloResize(unsigned char* src_img, unsigned char* dst_img,
80 | int src_h, int src_w,
81 | int frame_h, int frame_w,
82 | float stride_h, float stride_w)
83 | {
84 | int H = blockDim.x * gridDim.x; // # dst_height
85 | int W = blockDim.y * gridDim.y; // # dst_width
86 | int h = blockDim.x * blockIdx.x + threadIdx.x; // 32 * bkIdx[0:18] + tdIdx; [0,607] # x / h-th row
87 | int w = blockDim.y * blockIdx.y + threadIdx.y; // 32 * bkIdx[0:18] + tdIdx; [0,607] # y / w-th col
88 | int C = 3; // # ChannelDim
89 | int c = blockIdx.z % 3 ; // [0,2] # ChannelIdx
90 | int n = blockIdx.z / 3 ; // [0 , Batch size-1], # BatchIdx
91 |
92 | int idx = n * (H * W * C) +
93 | h * (W * C) +
94 | w * C +
95 | c;
96 |
97 | float centroid_h, centroid_w;
98 | centroid_h = stride_h * (h + 0.5); // h w c -> x, y, z : 1080 , 1920 , 3
99 | centroid_w = stride_w * (w + 0.5); //
100 |
101 | int f00,f01,f10,f11;
102 |
103 | int src_h_idx = lroundf(centroid_h)-1;
104 | int src_w_idx = lroundf(centroid_w)-1;
105 | if (src_h_idx<0){src_h_idx=0;}
106 | if (src_w_idx<0){src_w_idx=0;}
107 |
108 | f00 = n * frame_h * frame_w * C +
109 | src_h_idx * frame_w * C +
110 | src_w_idx * C +
111 | c;
112 | f01 = n * frame_h * frame_w * C +
113 | src_h_idx * frame_w * C +
114 | (src_w_idx+1) * C +
115 | c;
116 | f10 = n * frame_h * frame_w * C +
117 | (src_h_idx+1) * frame_w * C +
118 | src_w_idx * C +
119 | c;
120 | f11 = n * frame_h * frame_w * C +
121 | (src_h_idx+1) * frame_w * C +
122 | (src_w_idx+1) * C +
123 | c;
124 |
125 | int rs = lroundf(lerp2d(src_img[f00], src_img[f01], src_img[f10], src_img[f11],
126 | centroid_h, centroid_w));
127 |
128 | dst_img[idx] = (unsigned char)rs;
129 | }
130 | """)
131 |
132 | # block = (32, 32, 1) blockDim | threadIdx
133 | # grid = (19,19,3)) gridDim | blockIdx
134 |
135 |
class GPU_RESIZE_PROCESSOR():
    """Batched GPU resize to 608x608 followed by an NHWC -> NCHW transpose.

    Page-locked host buffers and device buffers are allocated once, sized for
    ``batch`` frames of at most ``frame_h`` x ``frame_w`` x 3 uint8 pixels.
    Smaller frames are copied into the top-left corner of the input buffer and
    resized from there. Height and width are scaled independently
    (stride_h = src_h / dst_h, stride_w = src_w / dst_w), so the aspect ratio
    is not preserved.

    All host<->device copies and kernel launches are issued on one private
    CUDA stream owned by the instance.
    """

    # Threads per block along each spatial axis. dst_h and dst_w must be
    # multiples of this, because the kernels derive the output size from
    # blockDim * gridDim.
    BLOCK_DIM = 32

    def __init__(self, frame_h, frame_w, batch):
        # ========= Init Params =========
        # size of frame
        self.batch = batch
        self.channel = 3
        self.frame_h = frame_h  # 1080 / 1080*n
        self.frame_w = frame_w  # 1920 / 1920*n
        self.dst_h = 608
        self.dst_w = 608

        # memory: each is a dict with a page-locked "host" array and a "device" allocation
        self.inp = None
        self.out = None
        self.trans = None
        # async stream
        self.stream = cuda.Stream()

        # CUDA kernels
        self.YoloResizeKer = module.get_function("YoloResize")
        self.TransposeKer = module.get_function("Transpose")
        self.TransNorKer = module.get_function("Transpose_and_normalise")

        self.allocate_memory()
        self.warm_up()  # warm up

    def _launch_kwargs(self):
        """Return the block/grid/stream keyword arguments shared by every kernel launch.

        The grid covers dst_h x dst_w pixels in x/y and one z-slice per
        (image, channel) pair, matching how the kernels decode blockIdx.z.
        """
        return {
            "block": (self.BLOCK_DIM, self.BLOCK_DIM, 1),
            "grid": (self.dst_h // self.BLOCK_DIM,
                     self.dst_w // self.BLOCK_DIM,
                     self.channel * self.batch),
            "stream": self.stream,
        }

    def allocate_memory(self):
        """Allocate page-locked host buffers and matching device buffers.

        inp   : (batch, frame_h, frame_w, 3) uint8, NHWC input frames
        out   : (batch, dst_h, dst_w, 3) uint8, resized NHWC frames
        trans : (batch, 3, dst_h, dst_w) uint8, resized NCHW frames
        """
        self.inp = {"host": cuda.pagelocked_zeros(shape=(self.batch, self.frame_h, self.frame_w, self.channel),
                                                  dtype=np.uint8,
                                                  mem_flags=cuda.host_alloc_flags.DEVICEMAP)}
        self.inp["device"] = cuda.mem_alloc(self.inp["host"].nbytes)

        self.out = {"host": cuda.pagelocked_zeros(shape=(self.batch, self.dst_h, self.dst_w, self.channel),
                                                  dtype=np.uint8,
                                                  mem_flags=cuda.host_alloc_flags.DEVICEMAP)}
        self.out["device"] = cuda.mem_alloc(self.out["host"].nbytes)

        self.trans = {"host": cuda.pagelocked_zeros(shape=(self.batch, self.channel, self.dst_h, self.dst_w),
                                                    # dtype=np.float32,
                                                    dtype=np.uint8,
                                                    mem_flags=cuda.host_alloc_flags.DEVICEMAP)}  # N C H W
        self.trans["device"] = cuda.mem_alloc(self.trans["host"].nbytes)

    def warm_up(self):
        """Launch the resize and transpose kernels once; results are not copied back.

        Uses the full frame size with unit stride. Presumably this absorbs
        first-launch overhead before real calls to resize() (TODO confirm).
        """
        launch = self._launch_kwargs()
        self.YoloResizeKer(self.inp["device"], self.out["device"],
                           np.int32(self.frame_h), np.int32(self.frame_w),
                           np.int32(self.frame_h), np.int32(self.frame_w),
                           np.float32(1), np.float32(1),
                           **launch)
        # self.TransNorKer(self.trans["device"], self.out["device"], **launch)
        self.TransposeKer(self.trans["device"], self.out["device"], **launch)

    @profile
    def resize(self, input_img: np.ndarray):
        """
        Resize the batch image to (608,608) and convert NHWC to NCHW.

        The normalising kernel (Transpose_and_normalise, divide by 255) is
        currently disabled, so the output stays uint8 in [0, 255].

        input_img : batch input, format: NHWC, recommend RGB (same as the NN input format).
                    Must be 3 channel (the kernels hard-code ChannelDim = 3) and at most
                    frame_h x frame_w. Its batch dimension must equal self.batch, or be 1,
                    in which case numpy broadcasts it into every batch slot.
        returns   : the page-locked buffer self.trans["host"], shape
                    (self.batch, 3, 608, 608), NCHW, same channel order as the input.
                    It is reused and overwritten by the next call; copy it to keep it.

        Raises AssertionError if the input is not 3-channel or exceeds the frame size.
        """
        _, src_h, src_w, channel = input_img.shape
        assert channel == self.channel, "resize kernel only accepts 3-channel input"
        assert src_h <= self.frame_h and src_w <= self.frame_w
        # Only the top-left src_h x src_w region is refreshed; the kernel is
        # told the real source size through src_h / src_w.
        self.inp["host"][:, :src_h, :src_w, :] = input_img
        cuda.memcpy_htod_async(self.inp["device"], self.inp["host"], self.stream)

        # Launch on self.stream so the kernels are ordered with the async copies
        # explicitly rather than through legacy default-stream synchronisation.
        launch = self._launch_kwargs()
        self.YoloResizeKer(self.inp["device"], self.out["device"],
                           np.int32(src_h), np.int32(src_w),
                           np.int32(self.frame_h), np.int32(self.frame_w),
                           np.float32(src_h / self.dst_h), np.float32(src_w / self.dst_w),
                           **launch)
        # self.TransNorKer(self.trans["device"], self.out["device"], **launch)
        self.TransposeKer(self.trans["device"], self.out["device"], **launch)
        cuda.memcpy_dtoh_async(self.trans["host"], self.trans["device"], self.stream)

        self.stream.synchronize()
        # self.cleanup()
        return self.trans["host"]

    def cleanup(self):
        """Zero the whole page-locked input host buffer (not called by resize)."""
        self.inp["host"][:, :, :, :] = 0

    # TODO: explicit deallocation of device memory
    # def deallocate(self):
    #     free(gpu_mem)
@profile
def main():
    """Smoke test: resize three differently sized copies of trump.jpg on the GPU.

    For each source size a batch of identical images is built and pushed through
    GPU_RESIZE_PROCESSOR.resize. The NCHW result is transposed back to NHWC and
    the first image is written to trans{idx}.jpg. Line-profiler stats are
    printed at the end.

    Raises FileNotFoundError if trump.jpg cannot be read.
    """
    batch_size = 58
    CUDA_processor = GPU_RESIZE_PROCESSOR(frame_h=1080, frame_w=1920, batch=batch_size)

    # cv2.imread returns None (no exception) when the file is missing or unreadable.
    src = cv2.imread("trump.jpg")
    if src is None:
        raise FileNotFoundError("could not read trump.jpg")

    # cv2.resize takes dsize as (width, height).
    sizes = [(64, 48), (320, 240), (1920, 1080)]
    for idx, size in enumerate(sizes):
        img_batch = np.tile(cv2.resize(src, size), [batch_size, 1, 1, 1])
        pix = CUDA_processor.resize(img_batch)
        pix = np.transpose(pix, [0, 2, 3, 1])  # NCHW -> NHWC for cv2.imwrite
        cv2.imwrite(f"trans{idx}.jpg", pix[0])

    profile.print_stats()
    # print(pix.shape)
    # cv2.imwrite("pycuda_outpuut.jpg", pix[0])


if __name__ == "__main__":
    main()
257 |
--------------------------------------------------------------------------------
/lerp.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=line-too-long, invalid-name, multiple-statements, too-many-locals, too-many-arguments
2 |
3 | import numpy as np
4 | # import cv2
5 | from line_profiler import LineProfiler
6 |
# Module-level line profiler; it is used below as the @profile decorator.
profile = LineProfiler()
8 |
9 |
def lerp1d(a, b, w):
    """
    a + w*(b-a)

    Returns the linear interpolation of a and b based on weight w.

    a and b are either both scalars or both vectors of the same length.
    The weight w may be a scalar or a vector of the same length as a and b.
    w can be any value (so is not restricted to be between zero and one);
    if w has values outside the [0,1] range, it actually extrapolates.

    lerp returns a when w is zero and returns b when w is one.

    Computed as (1-w)*a + w*b. This form never subtracts a from b, so unsigned
    integer inputs (e.g. uint8 pixels) cannot wrap around. It also has no
    data-dependent branch, so it works element-wise on numpy arrays.
    """
    return (1 - w) * a + w * b
26 |
27 |
28 |
29 | @profile
30 | # def lerp2d(grid, centroid:np.ndarray):
31 | # """ Linear Interpolation
32 | # grid is a 2by2 matrix
33 | # centroid is the centroid of the 2x2 matrix, (row-y,col-x), range:[0,1]
34 | # -----r0-- ---------
35 | # |0,0 | |0,1 |
36 | # | | | |
37 | # | -px- x -+ - qx - -|
38 | # ------+--+---------
39 | # |1,0 | |1,1 |
40 | # | qy | |
41 | # | | | |
42 | # -----r1-- ---------
43 | # """
44 |
45 | # p = (1 - np.round(centroid)+centroid)/2
46 |
47 | # r0 = lerp1d(grid[0,0],grid[0,1],p[1])
48 | # r1 = lerp1d(grid[1,0],grid[1,1],p[1])
49 | # r = lerp1d(r0,r1,p[0]) +0.0001 # +0.0001 for np.round, sometimes 3.5 round down to 3. since computer science basis..
50 | # # if (grid=SRC_W){return;}
38 | const uchar3* src = (uchar3*)(src_img);
39 | uchar3* dst = (uchar3*)(dst_img);
40 |
41 | // coordinate dst pixel in src image
42 | int dst_row_idx = blockIdx.x;
43 | float centroid_h;
44 | centroid_h = scale_h * (dst_row_idx + 0.5);
45 | int src_h_idx = lroundf(centroid_h)-1;
46 | if (src_h_idx<0){src_h_idx=0;}
47 |
48 | int n = blockIdx.y; // batch number
49 | __shared__ uchar3 srcTile[2][MAX_WIDTH]; // cache `2 src rows` for `1 dst row` pixel
50 | int row_start;
51 | int pix_idx;
52 |
53 | for( int w = threadIdx.x ; w < SRC_W ; w+=blockDim.x){
54 | pix_idx = n * SRC_H * SRC_W + // move to the start of image in batch
55 | src_h_idx * SRC_W ; // move to the start of row index of src image
56 | // loop over 2 row image
57 | for (int row = 0; row < 2; row++){
58 | row_start = pix_idx + SRC_W * row; // jump to next row
59 | srcTile[row][w].x = src[row_start+w].x;
60 | srcTile[row][w].y = src[row_start+w].y;
61 | srcTile[row][w].z = src[row_start+w].z;
62 | }
63 | }
64 | __syncthreads();
65 |
66 | long long pixel_idx = n * DST_H * DST_W + // offset batch
67 | blockIdx.x * DST_W + // offset row(height)
68 | threadIdx.x; // offset col(width)
69 |
70 | uchar3 *f00, *f01, *f10, *f11;
71 | float centroid_w;
72 | for( int w = threadIdx.x ; w < DST_W ; w+=blockDim.x){
73 |
74 | centroid_w = scale_w * (w + 0.5);
75 | int src_w_idx = lroundf(centroid_w)-1;
76 | if (src_w_idx<0){src_w_idx=0;}
77 |
78 | // loop over 2 row image
79 |
80 | f00 = &srcTile[0][src_w_idx];
81 | f01 = &srcTile[0][src_w_idx+1];
82 | f10 = &srcTile[1][src_w_idx];
83 | f11 = &srcTile[1][src_w_idx+1];
84 |
85 | if (src_w_idx+1>=SRC_W){f01 = f00; f11 = f10;}
86 | if (src_h_idx+1>=SRC_H){f10 = f00; f11 = f01;}
87 |
88 | dst[pixel_idx].x = (unsigned char) lroundf(lerp2d((*f00).x, (*f01).x, (*f10).x, (*f11).x, centroid_h, centroid_w));
89 | dst[pixel_idx].y = (unsigned char) lroundf(lerp2d((*f00).y, (*f01).y, (*f10).y, (*f11).y, centroid_h, centroid_w));
90 | dst[pixel_idx].z = (unsigned char) lroundf(lerp2d((*f00).z, (*f01).z, (*f10).z, (*f11).z, centroid_h, centroid_w));
91 |
92 | pixel_idx += blockDim.x;
93 |
94 | }
95 | }
96 | }
--------------------------------------------------------------------------------
/lintrc/pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 |
3 | # A comma-separated list of package or module names from where C extensions may
4 | # be loaded. Extensions are loaded into the active Python interpreter and may
5 | # run arbitrary code.
6 | extension-pkg-whitelist=
7 |
8 | # Add files or directories to the blacklist. They should be base names, not
9 | # paths.
10 | ignore=CVS
11 |
12 | # Add files or directories matching the regex patterns to the blacklist. The
13 | # regex matches against base names, not paths.
14 | ignore-patterns=
15 |
16 | # Python code to execute, usually for sys.path manipulation such as
17 | # pygtk.require().
18 | #init-hook=
19 |
20 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
21 | # number of processors available to use.
22 | jobs=1
23 |
24 | # Control the amount of potential inferred values when inferring a single
25 | # object. This can help the performance when dealing with large functions or
26 | # complex, nested conditions.
27 | limit-inference-results=100
28 |
29 | # List of plugins (as comma separated values of python module names) to load,
30 | # usually to register additional checkers.
31 | load-plugins=
32 |
33 | # Pickle collected data for later comparisons.
34 | persistent=yes
35 |
36 | # Specify a configuration file.
37 | #rcfile=
38 |
39 | # When enabled, pylint would attempt to guess common misconfiguration and emit
40 | # user-friendly hints instead of false-positive error messages.
41 | suggestion-mode=yes
42 |
43 | # Allow loading of arbitrary C extensions. Extensions are imported into the
44 | # active Python interpreter and may run arbitrary code.
45 | unsafe-load-any-extension=no
46 |
47 |
48 | [MESSAGES CONTROL]
49 |
50 | # Only show warnings with the listed confidence levels. Leave empty to show
51 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
52 | confidence=
53 |
54 | # Disable the message, report, category or checker with the given id(s). You
55 | # can either give multiple identifiers separated by comma (,) or put this
56 | # option multiple times (only on the command line, not in the configuration
57 | # file where it should appear only once). You can also use "--disable=all" to
58 | # disable everything first and then reenable specific checks. For example, if
59 | # you want to run only the similarities checker, you can use "--disable=all
60 | # --enable=similarities". If you want to run only the classes checker, but have
61 | # no Warning level messages displayed, use "--disable=all --enable=classes
62 | # --disable=W".
63 | disable=print-statement,
64 | parameter-unpacking,
65 | unpacking-in-except,
66 | old-raise-syntax,
67 | backtick,
68 | long-suffix,
69 | old-ne-operator,
70 | old-octal-literal,
71 | import-star-module-level,
72 | non-ascii-bytes-literal,
73 | raw-checker-failed,
74 | bad-inline-option,
75 | locally-disabled,
76 | file-ignored,
77 | suppressed-message,
78 | useless-suppression,
79 | deprecated-pragma,
80 | use-symbolic-message-instead,
81 | apply-builtin,
82 | basestring-builtin,
83 | buffer-builtin,
84 | cmp-builtin,
85 | coerce-builtin,
86 | execfile-builtin,
87 | file-builtin,
88 | long-builtin,
89 | raw_input-builtin,
90 | reduce-builtin,
91 | standarderror-builtin,
92 | unicode-builtin,
93 | xrange-builtin,
94 | coerce-method,
95 | delslice-method,
96 | getslice-method,
97 | setslice-method,
98 | no-absolute-import,
99 | old-division,
100 | dict-iter-method,
101 | dict-view-method,
102 | next-method-called,
103 | metaclass-assignment,
104 | indexing-exception,
105 | raising-string,
106 | reload-builtin,
107 | oct-method,
108 | hex-method,
109 | nonzero-method,
110 | cmp-method,
111 | input-builtin,
112 | round-builtin,
113 | intern-builtin,
114 | unichr-builtin,
115 | map-builtin-not-iterating,
116 | zip-builtin-not-iterating,
117 | range-builtin-not-iterating,
118 | filter-builtin-not-iterating,
119 | using-cmp-argument,
120 | eq-without-hash,
121 | div-method,
122 | idiv-method,
123 | rdiv-method,
124 | exception-message-attribute,
125 | invalid-str-codec,
126 | sys-max-int,
127 | bad-python3-import,
128 | deprecated-string-function,
129 | deprecated-str-translate-call,
130 | deprecated-itertools-function,
131 | deprecated-types-field,
132 | next-method-defined,
133 | dict-items-not-iterating,
134 | dict-keys-not-iterating,
135 | dict-values-not-iterating,
136 | deprecated-operator-function,
137 | deprecated-urllib-function,
138 | xreadlines-attribute,
139 | deprecated-sys-function,
140 | exception-escape,
141 | comprehension-escape,
142 | missing-function-docstring,
143 | missing-class-docstring,
144 | missing-module-docstring,
145 | missing-final-newline,
146 | broad-except,
147 | arguments-differ,
148 | duplicate-code
149 |
150 | # Enable the message, report, category or checker with the given id(s). You can
151 | # either give multiple identifiers separated by comma (,) or put this option
152 | # multiple times (only on the command line, not in the configuration file where
153 | # it should appear only once). See also the "--disable" option for examples.
154 | enable=c-extension-no-member
155 |
156 |
157 | [REPORTS]
158 |
159 | # Python expression which should return a score less than or equal to 10. You
160 | # have access to the variables 'error', 'warning', 'refactor', and 'convention'
161 | # which contain the number of messages in each category, as well as 'statement'
162 | # which is the total number of statements analyzed. This score is used by the
163 | # global evaluation report (RP0004).
164 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
165 |
166 | # Template used to display messages. This is a python new-style format string
167 | # used to format the message information. See doc for all details.
168 | #msg-template=
169 |
170 | # Set the output format. Available formats are text, parseable, colorized, json
171 | # and msvs (visual studio). You can also give a reporter class, e.g.
172 | # mypackage.mymodule.MyReporterClass.
173 | output-format=text
174 |
175 | # Tells whether to display a full report or only the messages.
176 | reports=no
177 |
178 | # Activate the evaluation score.
179 | score=yes
180 |
181 |
182 | [REFACTORING]
183 |
184 | # Maximum number of nested blocks for function / method body
185 | max-nested-blocks=5
186 |
187 | # Complete names of functions that never return. When checking for
188 | # inconsistent-return-statements if a never returning function is called then
189 | # it will be considered as an explicit return statement and no message will be
190 | # printed.
191 | never-returning-functions=sys.exit
192 |
193 |
194 | [FORMAT]
195 |
196 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
197 | expected-line-ending-format=
198 |
199 | # Regexp for a line that is allowed to be longer than the limit.
200 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$
201 |
202 | # Number of spaces of indent required inside a hanging or continued line.
203 | indent-after-paren=4
204 |
205 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
206 | # tab).
207 | indent-string=' '
208 |
209 | # Maximum number of characters on a single line.
210 | max-line-length=100
211 |
212 | # Maximum number of lines in a module.
213 | max-module-lines=1000
214 |
215 | # List of optional constructs for which whitespace checking is disabled. `dict-
216 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
217 | # `trailing-comma` allows a space between comma and closing bracket: (a, ).
218 | # `empty-line` allows space-only lines.
219 | no-space-check=trailing-comma,
220 | dict-separator
221 |
222 | # Allow the body of a class to be on the same line as the declaration if body
223 | # contains single statement.
224 | single-line-class-stmt=no
225 |
226 | # Allow the body of an if to be on the same line as the test if there is no
227 | # else.
228 | single-line-if-stmt=no
229 |
230 |
231 | [SPELLING]
232 |
233 | # Limits count of emitted suggestions for spelling mistakes.
234 | max-spelling-suggestions=4
235 |
236 | # Spelling dictionary name. Available dictionaries: none. To make it work,
237 | # install the python-enchant package.
238 | spelling-dict=
239 |
240 | # List of comma separated words that should not be checked.
241 | spelling-ignore-words=
242 |
243 | # A path to a file that contains the private dictionary; one word per line.
244 | spelling-private-dict-file=
245 |
246 | # Tells whether to store unknown words to the private dictionary (see the
247 | # --spelling-private-dict-file option) instead of raising a message.
248 | spelling-store-unknown-words=no
249 |
250 |
251 | [LOGGING]
252 |
253 | # Format style used to check logging format string. `old` means using %
254 | # formatting, `new` is for `{}` formatting,and `fstr` is for f-strings.
255 | logging-format-style=old
256 |
257 | # Logging modules to check that the string format arguments are in logging
258 | # function parameter format.
259 | logging-modules=logging
260 |
261 |
262 | [SIMILARITIES]
263 |
264 | # Ignore comments when computing similarities.
265 | ignore-comments=yes
266 |
267 | # Ignore docstrings when computing similarities.
268 | ignore-docstrings=yes
269 |
270 | # Ignore imports when computing similarities.
271 | ignore-imports=no
272 |
273 | # Minimum lines number of a similarity.
274 | min-similarity-lines=6
275 |
276 |
277 | [TYPECHECK]
278 |
279 | # List of decorators that produce context managers, such as
280 | # contextlib.contextmanager. Add to this list to register other decorators that
281 | # produce valid context managers.
282 | contextmanager-decorators=contextlib.contextmanager
283 |
284 | # List of members which are set dynamically and missed by pylint inference
285 | # system, and so shouldn't trigger E1101 when accessed. Python regular
286 | # expressions are accepted.
287 | generated-members=
288 |
289 | # Tells whether missing members accessed in mixin class should be ignored. A
290 | # mixin class is detected if its name ends with "mixin" (case insensitive).
291 | ignore-mixin-members=yes
292 |
293 | # Tells whether to warn about missing members when the owner of the attribute
294 | # is inferred to be None.
295 | ignore-none=yes
296 |
297 | # This flag controls whether pylint should warn about no-member and similar
298 | # checks whenever an opaque object is returned when inferring. The inference
299 | # can return multiple potential results while evaluating a Python object, but
300 | # some branches might not be evaluated, which results in partial inference. In
301 | # that case, it might be useful to still emit no-member and other checks for
302 | # the rest of the inferred objects.
303 | ignore-on-opaque-inference=yes
304 |
305 | # List of class names for which member attributes should not be checked (useful
306 | # for classes with dynamically set attributes). This supports the use of
307 | # qualified names.
308 | ignored-classes=optparse.Values,thread._local,_thread._local
309 |
310 | # List of module names for which member attributes should not be checked
311 | # (useful for modules/projects where namespaces are manipulated during runtime
312 | # and thus existing member attributes cannot be deduced by static analysis). It
313 | # supports qualified module names, as well as Unix pattern matching.
314 | ignored-modules=cv2,asyncio,boto3,torch
315 |
316 | # Show a hint with possible names when a member name was not found. The aspect
317 | # of finding the hint is based on edit distance.
318 | missing-member-hint=yes
319 |
320 | # The minimum edit distance a name should have in order to be considered a
321 | # similar match for a missing member name.
322 | missing-member-hint-distance=1
323 |
324 | # The total number of similar names that should be taken into consideration when
325 | # showing a hint for a missing member.
326 | missing-member-max-choices=1
327 |
328 | # List of decorators that change the signature of a decorated function.
329 | signature-mutators=
330 |
331 |
332 | [MISCELLANEOUS]
333 |
334 | # List of note tags to take into consideration, separated by a comma.
335 | notes=FIXME,
336 | XXX,
337 | TODO
338 |
339 |
340 | [STRING]
341 |
342 | # This flag controls whether the implicit-str-concat-in-sequence should
343 | # generate a warning on implicit string concatenation in sequences defined over
344 | # several lines.
345 | check-str-concat-over-line-jumps=no
346 |
347 |
348 | [BASIC]
349 |
350 | # Naming style matching correct argument names.
351 | argument-naming-style=snake_case
352 |
353 | # Regular expression matching correct argument names. Overrides argument-
354 | # naming-style.
355 | #argument-rgx=
356 |
357 | # Naming style matching correct attribute names.
358 | attr-naming-style=snake_case
359 |
360 | # Regular expression matching correct attribute names. Overrides attr-naming-
361 | # style.
362 | #attr-rgx=
363 |
364 | # Bad variable names which should always be refused, separated by a comma.
365 | bad-names=foo,
366 | bar,
367 | baz,
368 | toto,
369 | tutu,
370 | tata
371 |
372 | # Naming style matching correct class attribute names.
373 | class-attribute-naming-style=any
374 |
375 | # Regular expression matching correct class attribute names. Overrides class-
376 | # attribute-naming-style.
377 | #class-attribute-rgx=
378 |
379 | # Naming style matching correct class names.
380 | class-naming-style=PascalCase
381 |
382 | # Regular expression matching correct class names. Overrides class-naming-
383 | # style.
384 | #class-rgx=
385 |
386 | # Naming style matching correct constant names.
387 | const-naming-style=UPPER_CASE
388 |
389 | # Regular expression matching correct constant names. Overrides const-naming-
390 | # style.
391 | #const-rgx=
392 |
393 | # Minimum line length for functions/classes that require docstrings, shorter
394 | # ones are exempt.
395 | docstring-min-length=-1
396 |
397 | # Naming style matching correct function names.
398 | function-naming-style=snake_case
399 |
400 | # Regular expression matching correct function names. Overrides function-
401 | # naming-style.
402 | #function-rgx=
403 |
404 | # Good variable names which should always be accepted, separated by a comma.
405 | good-names=i,
406 | j,
407 | k,
408 | w,
409 | h,
410 | x,
411 | y,
412 | ex,
413 | Run,
414 | _,
415 | app,
416 | routes,
417 | util,
418 | common_util,
419 | logger,
420 | loop
421 |
422 | # Include a hint for the correct naming format with invalid-name.
423 | include-naming-hint=no
424 |
425 | # Naming style matching correct inline iteration names.
426 | inlinevar-naming-style=any
427 |
428 | # Regular expression matching correct inline iteration names. Overrides
429 | # inlinevar-naming-style.
430 | #inlinevar-rgx=
431 |
432 | # Naming style matching correct method names.
433 | method-naming-style=snake_case
434 |
435 | # Regular expression matching correct method names. Overrides method-naming-
436 | # style.
437 | #method-rgx=
438 |
439 | # Naming style matching correct module names.
440 | module-naming-style=snake_case
441 |
442 | # Regular expression matching correct module names. Overrides module-naming-
443 | # style.
444 | #module-rgx=
445 |
446 | # Colon-delimited sets of names that determine each other's naming style when
447 | # the name regexes allow several styles.
448 | name-group=
449 |
450 | # Regular expression which should only match function or class names that do
451 | # not require a docstring.
452 | no-docstring-rgx=^_
453 |
454 | # List of decorators that produce properties, such as abc.abstractproperty. Add
455 | # to this list to register other decorators that produce valid properties.
456 | # These decorators are taken in consideration only for invalid-name.
457 | property-classes=abc.abstractproperty
458 |
459 | # Naming style matching correct variable names.
460 | variable-naming-style=snake_case
461 |
462 | # Regular expression matching correct variable names. Overrides variable-
463 | # naming-style.
464 | #variable-rgx=
465 |
466 |
467 | [VARIABLES]
468 |
469 | # List of additional names supposed to be defined in builtins. Remember that
470 | # you should avoid defining new builtins when possible.
471 | additional-builtins=
472 |
473 | # Tells whether unused global variables should be treated as a violation.
474 | allow-global-unused-variables=yes
475 |
476 | # List of strings which can identify a callback function by name. A callback
477 | # name must start or end with one of those strings.
478 | callbacks=cb_,
479 | _cb
480 |
481 | # A regular expression matching the name of dummy variables (i.e. expected to
482 | # not be used).
483 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
484 |
485 | # Argument names that match this expression will be ignored. Defaults to names
486 | # with a leading underscore.
487 | ignored-argument-names=_.*|^ignored_|^unused_
488 |
489 | # Tells whether we should check for unused import in __init__ files.
490 | init-import=no
491 |
492 | # List of qualified module names which can have objects that can redefine
493 | # builtins.
494 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
495 |
496 |
497 | [CLASSES]
498 |
499 | # List of method names used to declare (i.e. assign) instance attributes.
500 | defining-attr-methods=__init__,
501 | __new__,
502 | setUp,
503 | __post_init__
504 |
505 | # List of member names, which should be excluded from the protected access
506 | # warning.
507 | exclude-protected=_asdict,
508 | _fields,
509 | _replace,
510 | _source,
511 | _make
512 |
513 | # List of valid names for the first argument in a class method.
514 | valid-classmethod-first-arg=cls
515 |
516 | # List of valid names for the first argument in a metaclass class method.
517 | valid-metaclass-classmethod-first-arg=cls
518 |
519 |
520 | [IMPORTS]
521 |
522 | # List of modules that can be imported at any level, not just the top level
523 | # one.
524 | allow-any-import-level=
525 |
526 | # Allow wildcard imports from modules that define __all__.
527 | allow-wildcard-with-all=no
528 |
529 | # Analyse import fallback blocks. This can be used to support both Python 2 and
530 | # 3 compatible code, which means that the block might have code that exists
531 | # only in one or another interpreter, leading to false positives when analysed.
532 | analyse-fallback-blocks=no
533 |
534 | # Deprecated modules which should not be used, separated by a comma.
535 | deprecated-modules=optparse,tkinter.tix
536 |
537 | # Create a graph of external dependencies in the given file (report RP0402 must
538 | # not be disabled).
539 | ext-import-graph=
540 |
541 | # Create a graph of every (i.e. internal and external) dependencies in the
542 | # given file (report RP0402 must not be disabled).
543 | import-graph=
544 |
545 | # Create a graph of internal dependencies in the given file (report RP0402 must
546 | # not be disabled).
547 | int-import-graph=
548 |
549 | # Force import order to recognize a module as part of the standard
550 | # compatibility libraries.
551 | known-standard-library=
552 |
553 | # Force import order to recognize a module as part of a third party library.
554 | known-third-party=enchant
555 |
556 | # Couples of modules and preferred modules, separated by a comma.
557 | preferred-modules=
558 |
559 |
560 | [DESIGN]
561 |
562 | # Maximum number of arguments for function / method.
563 | max-args=5
564 |
565 | # Maximum number of attributes for a class (see R0902).
566 | max-attributes=30
567 |
568 | # Maximum number of boolean expressions in an if statement (see R0916).
569 | max-bool-expr=5
570 |
571 | # Maximum number of branch for function / method body.
572 | max-branches=12
573 |
574 | # Maximum number of locals for function / method body.
575 | max-locals=15
576 |
577 | # Maximum number of parents for a class (see R0901).
578 | max-parents=7
579 |
580 | # Maximum number of public methods for a class (see R0904).
581 | max-public-methods=20
582 |
583 | # Maximum number of return / yield statements in a function / method body.
584 | max-returns=6
585 |
586 | # Maximum number of statements in function / method body.
587 | max-statements=50
588 |
589 | # Minimum number of public methods for a class (see R0903).
590 | min-public-methods=1
591 |
592 |
593 | [EXCEPTIONS]
594 |
595 | # Exceptions that will emit a warning when caught. Defaults to
596 | # "BaseException, Exception" (recent pylint releases deprecate unqualified names here in favour of e.g. builtins.Exception).
597 | overgeneral-exceptions=BaseException,
598 | Exception
599 |
--------------------------------------------------------------------------------
/resize.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=line-too-long, invalid-name, too-many-locals, raising-bad-type, c-extension-no-member, redefined-outer-name
2 | import cv2
3 | import cupy as cp
4 | import numpy as np
5 | from line_profiler import LineProfiler
6 |
# Line profiler shared by every function decorated with @profile below.
profile = LineProfiler()

# Compile the CUDA source that sits next to this script (path is relative to
# the current working directory) and grab the batched resize kernel from it.
with open('lib_cuResize.cu', 'r', encoding="utf-8") as reader:
    module = cp.RawModule(code=reader.read())
cuResizeKer = module.get_function("cuResize")
12 |
@profile
def cuda_resize(inputs: cp.ndarray,      # src: (N,H,W,C)
                shape: tuple,            # (dst_h, dst_w)
                out: cp.ndarray = None,  # kernel output: (N,ker_h,ker_w,C)
                pad: bool = True,
                verbose: bool = False):
    """Resize a batch of NHWC images on the GPU, optionally letterboxing them.

    Launch configuration: block = (1024,) threads and grid = (ker_h, N), i.e.
    one block per resized row per image (per the original author's note the
    block's threads loop over the destination row, MAX_WIDTH 7680 / 8K).
    TODO (original note): optimise with shared memory.

    Args:
        inputs: source batch of shape (N, H, W, 3). It is made C-contiguous
            before the launch because the kernel receives a raw device pointer.
            NOTE(review): presumably uint8 -- the kernel source is not visible
            here, confirm against lib_cuResize.cu.
        shape: target size as (dst_h, dst_w).
        out: optional preallocated uint8 buffer for the raw kernel output. It
            must have shape (N, ker_h, ker_w, 3), where ker_h/ker_w are the
            resized (unpadded) dimensions; this equals (N, dst_h, dst_w, 3)
            when pad=False.
        pad: if True, keep the aspect ratio and centre the resized images in a
            zero-filled (N, dst_h, dst_w, 3) canvas.
        verbose: print the launch parameters (debug aid).

    Returns:
        (resize_scale, top_pad, left_pad, batch). batch is the padded canvas
        when pad=True, otherwise the kernel output.

    Raises:
        ValueError: if `shape` is not a (h, w) pair.
        AssertionError: if the input is not 3-channel or `out` does not match
            the expected dtype/shape/contiguity.
    """
    out_dtype = cp.uint8

    # Validate before unpacking; raising a bare string is a TypeError in Py3.
    if len(shape) != 2:
        raise ValueError("cuda resize target shape must be (h,w)")
    dst_h, dst_w = shape

    N, src_h, src_w, C = inputs.shape
    assert C == 3  # resize kernel only accept 3 channel tensors.
    inputs = cp.ascontiguousarray(inputs)

    resize_scale = 1
    left_pad = 0
    top_pad = 0
    # Source relatively taller than the target: height limits the resize.
    fit_height = src_h / src_w > dst_h / dst_w
    if pad:
        if fit_height:
            resize_scale = dst_h / src_h
            ker_h = dst_h
            # max(1, ...) keeps the launch valid for extreme aspect ratios.
            ker_w = max(1, int(src_w * resize_scale))
            left_pad = (dst_w - ker_w) // 2
        else:
            resize_scale = dst_w / src_w
            ker_h = max(1, int(src_h * resize_scale))
            ker_w = dst_w
            top_pad = (dst_h - ker_h) // 2
    else:
        ker_h = dst_h
        ker_w = dst_w

    kernel_shape = (N, ker_h, ker_w, C)
    # `is None` rather than truthiness: bool() of a multi-element array raises.
    if out is None:
        out = cp.empty(kernel_shape, dtype=out_dtype)
    else:
        # The kernel writes the unpadded result straight into `out`.
        assert out.dtype == out_dtype
        assert out.shape == kernel_shape
        assert out.flags.c_contiguous

    if pad:
        padded_batch = cp.zeros((N, dst_h, dst_w, C), dtype=out_dtype)

    # define kernel configs
    block = (1024, )
    grid = (ker_h, N)
    scale_h = cp.float32(src_h / ker_h)
    scale_w = cp.float32(src_w / ker_w)
    with cp.cuda.stream.Stream() as stream:
        if verbose:
            print(inputs.dtype, out.dtype,
                  inputs.shape, out.shape,
                  src_h, src_w,
                  ker_h, ker_w,
                  scale_h, scale_w)

        cuResizeKer(grid, block,
                    (inputs, out,
                     cp.int32(src_h), cp.int32(src_w),
                     cp.int32(ker_h), cp.int32(ker_w),
                     scale_h, scale_w))

        if pad:
            # Centre the resized images in the zero canvas along the padded axis.
            if fit_height:
                padded_batch[:, :, left_pad:left_pad + ker_w, :] = out
            else:
                padded_batch[:, top_pad:top_pad + ker_h, :, :] = out
        stream.synchronize()

    if pad:
        return resize_scale, top_pad, left_pad, padded_batch
    return resize_scale, top_pad, left_pad, out
87 |
88 |
89 |
def main(input_array: cp.ndarray, resize_shape: tuple):
    """Copy a batch into a fresh GPU buffer and letterbox-resize it.

    Args:
        input_array: batch of shape (N, H, W, 3), either a CuPy (device) or a
            NumPy (host) array. Non-contiguous inputs are compacted first,
            because the raw memcpy copies `nbytes` from the data pointer.
        resize_shape: target size as (h, w).

    Returns:
        (output_array, [resize_scale, top_pad, left_pad]), where output_array
        is the padded batch returned by cuda_resize(..., pad=True).

    Raises:
        TypeError: if input_array is neither a CuPy nor a NumPy array.
    """
    if isinstance(input_array, cp.ndarray):
        src = cp.ascontiguousarray(input_array)
        src_ptr = int(src.data)
        kind = 3  # DtoD
    elif isinstance(input_array, np.ndarray):
        src = np.ascontiguousarray(input_array)
        src_ptr = src.ctypes.data
        kind = 1  # HtoD
    else:
        raise TypeError(f"unsupported input type: {type(input_array).__name__}")

    input_array_gpu = cp.empty(shape=src.shape, dtype=src.dtype)
    # kind: 0: HtoH, 1: HtoD, 2: DtoH, 3: DtoD, 4: unified virtual addressing
    cp.cuda.runtime.memcpy(dst=int(input_array_gpu.data),  # dst_ptr
                           src=src_ptr,                    # src_ptr
                           size=src.nbytes,
                           kind=kind)

    resize_scale, top_pad, left_pad, output_array = cuda_resize(input_array_gpu,
                                                                resize_shape,
                                                                pad=True)  # N,H,W,C

    return output_array, [resize_scale, top_pad, left_pad]
109 |
if __name__ == "__main__":
    def _read_resized(path, size):
        """Load `path` with cv2.imread and resize it to `size` = (width, height).

        cv2.imread returns None for a missing/unreadable file; fail clearly
        here instead of inside cv2.resize.
        """
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(path)
        return cv2.resize(image, size)

    # prepare data: `batch` copies of one 1920x1080 frame, the last one replaced
    batch = 50
    img_batch = np.tile(_read_resized("trump.jpg", (1920, 1080)), [batch, 1, 1, 1])
    img_batch[-1] = _read_resized("rgba.png", (1920, 1080))
    output_array, _ = main(img_batch, (320, 640))  # target (h, w)
    print(output_array)

    for idx, img in enumerate(cp.asnumpy(output_array)):
        cv2.imwrite(f"output_{idx}.jpg", img)
122 |
--------------------------------------------------------------------------------
/resize_formated.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=line-too-long, invalid-name, too-many-locals, raising-bad-type, c-extension-no-member, redefined-outer-name
2 | import cv2
3 | import cupy as cp
4 | import numpy as np
5 | from line_profiler import LineProfiler
6 |
# Line profiler shared by every function decorated with @profile below.
profile = LineProfiler()

# Compile the CUDA source that sits next to this script (path is relative to
# the current working directory) and grab the xyz resize kernel variant.
with open('lib_cuResize.cu', 'r', encoding="utf-8") as reader:
    module = cp.RawModule(code=reader.read())
cuResizeKer = module.get_function("cuResize_xyz")
12 |
@profile
def cuda_resize(inputs: cp.ndarray,      # src: (N,H,W,C)
                shape: tuple,            # (dst_h, dst_w)
                out: cp.ndarray = None,  # kernel output: (N,ker_h,ker_w,C)
                pad: bool = True,
                verbose: bool = False):
    """Resize a batch of NHWC images on the GPU, optionally letterboxing them.

    Uses the module-level `cuResizeKer` kernel. Launch configuration:
    block = (1024,) threads and grid = (ker_h, N), i.e. one block per resized
    row per image (per the original author's note the block's threads loop
    over the destination row, MAX_WIDTH 7680 / 8K).
    TODO (original note): optimise with shared memory.

    Args:
        inputs: source batch of shape (N, H, W, 3). It is made C-contiguous
            before the launch because the kernel receives a raw device pointer.
            NOTE(review): presumably uint8 -- the kernel source is not visible
            here, confirm against lib_cuResize.cu.
        shape: target size as (dst_h, dst_w).
        out: optional preallocated uint8 buffer for the raw kernel output. It
            must have shape (N, ker_h, ker_w, 3), where ker_h/ker_w are the
            resized (unpadded) dimensions; this equals (N, dst_h, dst_w, 3)
            when pad=False.
        pad: if True, keep the aspect ratio and centre the resized images in a
            zero-filled (N, dst_h, dst_w, 3) canvas.
        verbose: print the launch parameters (debug aid).

    Returns:
        (resize_scale, top_pad, left_pad, batch). batch is the padded canvas
        when pad=True, otherwise the kernel output.

    Raises:
        ValueError: if `shape` is not a (h, w) pair.
        AssertionError: if the input is not 3-channel or `out` does not match
            the expected dtype/shape/contiguity.
    """
    out_dtype = cp.uint8

    # Validate before unpacking; raising a bare string is a TypeError in Py3.
    if len(shape) != 2:
        raise ValueError("cuda resize target shape must be (h,w)")
    dst_h, dst_w = shape

    N, src_h, src_w, C = inputs.shape
    assert C == 3  # resize kernel only accept 3 channel tensors.
    inputs = cp.ascontiguousarray(inputs)

    resize_scale = 1
    left_pad = 0
    top_pad = 0
    # Source relatively taller than the target: height limits the resize.
    fit_height = src_h / src_w > dst_h / dst_w
    if pad:
        if fit_height:
            resize_scale = dst_h / src_h
            ker_h = dst_h
            # max(1, ...) keeps the launch valid for extreme aspect ratios.
            ker_w = max(1, int(src_w * resize_scale))
            left_pad = (dst_w - ker_w) // 2
        else:
            resize_scale = dst_w / src_w
            ker_h = max(1, int(src_h * resize_scale))
            ker_w = dst_w
            top_pad = (dst_h - ker_h) // 2
    else:
        ker_h = dst_h
        ker_w = dst_w

    kernel_shape = (N, ker_h, ker_w, C)
    # `is None` rather than truthiness: bool() of a multi-element array raises.
    if out is None:
        out = cp.empty(kernel_shape, dtype=out_dtype)
    else:
        # The kernel writes the unpadded result straight into `out`.
        assert out.dtype == out_dtype
        assert out.shape == kernel_shape
        assert out.flags.c_contiguous

    if pad:
        padded_batch = cp.zeros((N, dst_h, dst_w, C), dtype=out_dtype)

    # define kernel configs
    block = (1024, )
    grid = (ker_h, N)
    scale_h = cp.float32(src_h / ker_h)
    scale_w = cp.float32(src_w / ker_w)
    with cp.cuda.stream.Stream() as stream:
        if verbose:
            print(inputs.dtype, out.dtype,
                  inputs.shape, out.shape,
                  src_h, src_w,
                  ker_h, ker_w,
                  scale_h, scale_w)

        cuResizeKer(grid, block,
                    (inputs, out,
                     cp.int32(src_h), cp.int32(src_w),
                     cp.int32(ker_h), cp.int32(ker_w),
                     scale_h, scale_w))

        if pad:
            # Centre the resized images in the zero canvas along the padded axis.
            if fit_height:
                padded_batch[:, :, left_pad:left_pad + ker_w, :] = out
            else:
                padded_batch[:, top_pad:top_pad + ker_h, :, :] = out
        stream.synchronize()

    if pad:
        return resize_scale, top_pad, left_pad, padded_batch
    return resize_scale, top_pad, left_pad, out
87 |
88 |
89 |
def main(input_array: cp.ndarray, resize_shape: tuple):
    """Copy a batch into a fresh GPU buffer and letterbox-resize it.

    Args:
        input_array: batch of shape (N, H, W, 3), either a CuPy (device) or a
            NumPy (host) array. Non-contiguous inputs are compacted first,
            because the raw memcpy copies `nbytes` from the data pointer.
        resize_shape: target size as (h, w).

    Returns:
        (output_array, [resize_scale, top_pad, left_pad]), where output_array
        is the padded batch returned by cuda_resize(..., pad=True).

    Raises:
        TypeError: if input_array is neither a CuPy nor a NumPy array.
    """
    if isinstance(input_array, cp.ndarray):
        src = cp.ascontiguousarray(input_array)
        src_ptr = int(src.data)
        kind = 3  # DtoD
    elif isinstance(input_array, np.ndarray):
        src = np.ascontiguousarray(input_array)
        src_ptr = src.ctypes.data
        kind = 1  # HtoD
    else:
        raise TypeError(f"unsupported input type: {type(input_array).__name__}")

    input_array_gpu = cp.empty(shape=src.shape, dtype=src.dtype)
    # kind: 0: HtoH, 1: HtoD, 2: DtoH, 3: DtoD, 4: unified virtual addressing
    cp.cuda.runtime.memcpy(dst=int(input_array_gpu.data),  # dst_ptr
                           src=src_ptr,                    # src_ptr
                           size=src.nbytes,
                           kind=kind)

    resize_scale, top_pad, left_pad, output_array = cuda_resize(input_array_gpu,
                                                                resize_shape,
                                                                pad=True)  # N,H,W,C

    return output_array, [resize_scale, top_pad, left_pad]
109 |
if __name__ == "__main__":
    def _read_resized(path, size):
        """Load `path` with cv2.imread and resize it to `size` = (width, height).

        cv2.imread returns None for a missing/unreadable file; fail clearly
        here instead of inside cv2.resize.
        """
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(path)
        return cv2.resize(image, size)

    # prepare data: `batch` copies of one 2560x1080 frame, the last one replaced
    batch = 50
    img_batch = np.tile(_read_resized("trump.jpg", (2560, 1080)), [batch, 1, 1, 1])
    img_batch[-1] = _read_resized("rgba.png", (2560, 1080))
    cv2.imwrite("input.jpg", img_batch[0])  # img_batch is already a NumPy array
    output_array, _ = main(img_batch, (192, 200))  # target (h, w)

    # Small synthetic inputs kept for debugging the kernel:
    # img_batch = cp.arange(1*2*1024*3, dtype=cp.uint8).reshape((1,2,1024,3))
    # output_array, _ = main(img_batch, (1 ,1024))

    # img_batch = cp.arange(1*4*4*3, dtype=cp.uint8).reshape((1,4,4,3))
    # img_batch = cp.tile(img_batch, (1,1,1,1))
    # output_array, _ = main(img_batch, (200,200))
    print(output_array)

    for idx, img in enumerate(cp.asnumpy(output_array)):
        cv2.imwrite(f"output_{idx}.jpg", img)
135 |
--------------------------------------------------------------------------------
/resize_free.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)

// Report a failed CUDA runtime call to stderr.
//   err  - the call's return value (compared against cudaSuccess)
//   func - the stringified call expression
//   file, line - the call site
// All four are filled in by CHECK_CUDA_ERROR. Failures are deliberately not
// fatal in this example, so execution continues after an error.
template <typename T>
void check(T err, const char* const func, const char* const file,
           const int line)
{
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
        std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
        // We don't exit when we encounter CUDA errors in this example.
        // std::exit(EXIT_FAILURE);
    }
}
16 |
#define CHECK_LAST_CUDA_ERROR() checkLast(__FILE__, __LINE__)

// Print (without aborting) the error reported by cudaGetLastError(), e.g. one
// left behind by a preceding kernel launch. `file` and `line` identify the
// call site and are supplied by CHECK_LAST_CUDA_ERROR().
void checkLast(const char* const file, const int line)
{
    const cudaError_t status = cudaGetLastError();
    if (status == cudaSuccess)
    {
        return;
    }
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line
              << std::endl;
    std::cerr << cudaGetErrorString(status) << std::endl;
    // Intentionally non-fatal; enable the line below to abort instead.
    // std::exit(EXIT_FAILURE);
}
30 |
31 | // __device__ float lerp1d(int a, int b, float w)
32 | // {
33 | // if(b>a){
34 | // return a + w*(b-a);
35 | // }
36 | // else{
37 | // return b + w*(a-b);
38 | // }
39 | // }
40 |
// Linear interpolation a + w * (b - a), evaluated as fma(w, b, a - w * a).
// The parameters are float so callers can interpolate non-integer
// intermediates without truncation; integer pixel values still convert
// implicitly, giving the same result as before for them.
__device__ float lerp1d(float a, float b, float w)
{
    return fmaf(w, b, fmaf(-w, a, a));
}
45 |
// Bilinear blend of four neighbouring pixel values. As passed by cuRESIZE:
//   f00 f01   <- source row src_h_idx,     columns src_w_idx, src_w_idx + 1
//   f10 f11   <- source row src_h_idx + 1, same columns
// centroid_h / centroid_w are the sample position in source pixel units; the
// blend weights are derived from each centroid's distance to its nearest
// integer.
__device__ float lerp2d(int f00, int f01, int f10, int f11,
                        float centroid_h, float centroid_w)
{
    // NOTE(review): for the neighbour pair (lroundf(c) - 1, lroundf(c)) the
    // textbook bilinear weight of the second neighbour is c - lroundf(c) + 0.5;
    // the formula below differs. Kept unchanged -- confirm against the
    // reference implementation before relying on exact values.
    const float wt_w = (1 + lroundf(centroid_w) - centroid_w) / 2;
    const float wt_h = (1 + lroundf(centroid_h) - centroid_h) / 2;

    // Horizontal pass on both rows.
    const float r0 = lerp1d(f00, f01, wt_w);
    const float r1 = lerp1d(f10, f11, wt_w);

    // Vertical pass computed inline in float: routing r0/r1 through an
    // int-parameter lerp1d would truncate them before the final blend.
    return fmaf(wt_h, r1, fmaf(-wt_h, r0, r0));
}
60 |
// Smoke-test kernel: every launched thread prints a fixed banner.
__global__ void GPU_validation(void)
{
    const char* const banner = "GPU has been activated \n";
    printf("%s", banner);
}
65 |
__global__ void cuRESIZE(unsigned char* src_img, unsigned char* dst_img,
                         const int src_h, const int src_w,
                         const int dst_h, const int dst_w,
                         const float scale_h, const float scale_w)
{
    /*
    Resize a batch of interleaved images with lerp2d interpolation.

    Input:
        src_img - NHWC, N = gridDim.y, C = gridDim.z (the host launch uses 3)
    Output:
        dst_img - NHWC with the same N and C

    Thread mapping: one thread per (destination pixel, image, channel).
    Within an image the destination pixel index is
    threadIdx.x * gridDim.x + blockIdx.x. blockIdx.y selects the image and
    blockIdx.z the channel. Threads past dst_h * dst_w exit early; this is the
    launch overhead of rounding the pixel count up to whole blocks.

    scale_h / scale_w: source-to-destination size ratios (main passes
    SRC / DST).
    */
    const int n = blockIdx.y;  // batch number
    const int C = gridDim.z;   // channel count
    const int c = blockIdx.z;  // channel number

    // Pixel index inside image n only. A batch-wide linear index reduced
    // modulo dst_h*dst_w*C would mix in the image offset and scramble every
    // image after the first.
    const long long pixel = (long long)threadIdx.x * gridDim.x + blockIdx.x;
    if (pixel >= (long long)dst_h * dst_w) { return; }

    const int h = (int)(pixel / dst_w);  // dst row
    const int w = (int)(pixel % dst_w);  // dst column

    // Destination pixel centre mapped into source coordinates.
    const float centroid_h = scale_h * (h + 0.5f);
    const float centroid_w = scale_w * (w + 0.5f);

    // First source neighbour, clamped to the image (handle boundary pixels).
    int src_h_idx = lroundf(centroid_h) - 1;
    int src_w_idx = lroundf(centroid_w) - 1;
    src_h_idx = max(0, min(src_h_idx, src_h - 1));
    src_w_idx = max(0, min(src_w_idx, src_w - 1));

    // Second neighbour on each axis; replicate the edge at the bottom/right.
    const int src_h_next = (src_h_idx + 1 < src_h) ? src_h_idx + 1 : src_h_idx;
    const int src_w_next = (src_w_idx + 1 < src_w) ? src_w_idx + 1 : src_w_idx;

    // NHWC offsets in 64-bit: n * src_h * src_w * C overflows int for large
    // batches (e.g. more than ~345 1080p RGB images).
    const long long img_base = (long long)n * src_h * src_w * C;
    const long long row0 = img_base + (long long)src_h_idx * src_w * C;
    const long long row1 = img_base + (long long)src_h_next * src_w * C;
    const long long f00 = row0 + (long long)src_w_idx * C + c;
    const long long f01 = row0 + (long long)src_w_next * C + c;
    const long long f10 = row1 + (long long)src_w_idx * C + c;
    const long long f11 = row1 + (long long)src_w_next * C + c;

    // The clamping above keeps all four taps inside image n, so the 2-D
    // interpolation applies everywhere (edges simply repeat pixels).
    const int rs = lroundf(lerp2d(src_img[f00], src_img[f01], src_img[f10], src_img[f11],
                                  centroid_h, centroid_w));

    const long long dst_idx = (long long)n * dst_h * dst_w * C +
                              (long long)h * dst_w * C +
                              (long long)w * C +
                              c;

    dst_img[dst_idx] = (unsigned char)rs;
}
161 |
162 | int main(){
163 | int SRC_HEIGHT = 20;
164 | int SRC_WIDTH = 20;
165 | int SRC_SIZE = SRC_HEIGHT * SRC_WIDTH * 3;
166 |
167 | int DST_HEIGHT = 40;
168 | int DST_WIDTH = 40;
169 | int DST_SIZE = DST_HEIGHT * DST_WIDTH * 3;
170 |
171 | int batch = 1;
172 |
173 |
174 | // cudaStream_t stream1, stream2, stream3, stream4 ;
175 | cudaStream_t stream1;
176 | cudaStreamCreate ( &stream1) ;
177 |
178 | dim3 dimBlock(1024, 1,1); // maximum threads: 1024
179 | dim3 dimGrid(int(DST_SIZE/3/1024)+1,batch,3);
180 |
181 | unsigned char host_src[SRC_SIZE];
182 | // unsigned char host_dst[1108992];
183 | unsigned char host_dst[DST_SIZE];
184 |
185 | // init src image
186 | for(int i = 0; i < SRC_SIZE; i++){
187 | host_src[i] = i+1;
188 | // host_src[i] = (i%3);
189 | }
190 |
191 | float scale_h = (float)SRC_HEIGHT / DST_HEIGHT;
192 | float scale_w = (float)SRC_WIDTH / DST_WIDTH;
193 |
194 | unsigned char *device_src, *device_dst;
195 | CHECK_CUDA_ERROR(cudaMalloc((unsigned char **)&device_src, SRC_SIZE* sizeof(unsigned char)));
196 | CHECK_CUDA_ERROR(cudaMalloc((unsigned char **)&device_dst, DST_SIZE* sizeof(unsigned char)));
197 |
198 | CHECK_CUDA_ERROR(cudaMemcpy(device_src , host_src , SRC_SIZE * sizeof(unsigned char), cudaMemcpyHostToDevice));
199 |
200 | GPU_validation<<<1,1>>>();
201 | CHECK_CUDA_ERROR(cudaDeviceSynchronize());
202 |
203 |
204 | cuRESIZE<<>>(device_src, device_dst,
205 | SRC_HEIGHT, SRC_WIDTH,
206 | DST_HEIGHT, DST_WIDTH,
207 | scale_h, scale_w);
208 |
209 | CHECK_CUDA_ERROR(cudaDeviceSynchronize());
210 |
211 | // for(int i = 0; i<10; i++){
212 | // tester<<>>(device_src, device_dst,
213 | // SRC_HEIGHT, SRC_WIDTH,
214 | // scale_h, scale_w);
215 | // cudaDeviceSynchronize();
216 | // }
217 |
218 | cudaMemcpy(host_dst, device_dst, DST_SIZE * sizeof(unsigned char), cudaMemcpyDeviceToHost);
219 |
220 | // DEBUG : print first image in batch , first 30 pixel in 3 channels.
221 |
222 | // for(int i = 0; i < 30*3; i+=3){ // NHWC
223 | // printf("%d\n",host_src[i]);
224 | // }
225 | printf("============================\n");
226 |
227 | for(int c = 0; c<3*DST_HEIGHT*DST_WIDTH ; c+=DST_HEIGHT*DST_WIDTH){ // if NCHW
228 | for(int i = 0 ; i < 30; i++){
229 | printf("%d %d %d\n", c+i, i, host_dst[c+i]);
230 | }
231 | printf("------------------------------\n");
232 | }
233 |
234 | // print first 30 elements from each chanel
235 | // for(int c = 0; c<3; c++){ // NHWC
236 | // for(int i = 0 ; i < 30; i++){
237 | // int idx = i*3 +c;
238 | // printf("%d %d %d\n", c+i*3, i, host_dst[idx]);
239 | // }
240 | // printf("------------------------------\n");
241 | // }
242 |
243 | // int count_0=0;
244 | // int count_1=0;
245 | // int count_2=0;
246 | // for(int idx = 0; idx')
36 | a_array = cp.arange(12, dtype=cp.float32).reshape((2,2,3))
37 | b_array = cp.arange(12, 24, dtype=cp.float32).reshape((2,2,3))
38 | result = cp.zeros((2,2,3), dtype=cp.float32)
39 |
40 | ker((1,), (2,2,), (a_array, b_array, result, result.size//3))
41 | print("A\n", a_array)
42 | print("\nB\n", b_array)
43 | print("\nresult\n",result)
44 |
45 | print("cpu", a_array + b_array)
46 | # assert cp.allclose(result, 5*(2*x)+3*n) # note that we've multiplied by 2 earlier
--------------------------------------------------------------------------------
/tools/stat.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)

// Report a failed CUDA runtime call to stderr.
//   err  - the call's return value (compared against cudaSuccess)
//   func - the stringified call expression
//   file, line - the call site
// All four are filled in by CHECK_CUDA_ERROR. Failures are deliberately not
// fatal, so execution continues after an error.
template <typename T>
void check(T err, const char* const func, const char* const file,
           const int line)
{
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
        std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
        // We don't exit when we encounter CUDA errors in this example.
        // std::exit(EXIT_FAILURE);
    }
}
17 |
#define MAX_WIDTH 7680 // widest cached row; alternatives: 7680 3840 1920

// Shared-memory probe: each block declares a 2 x MAX_WIDTH uchar3 tile
// (2 * MAX_WIDTH * sizeof(uchar3) bytes -- two cached source rows) and fills
// it with constant values. `device_src` is currently unused.
__global__ void tile_check(unsigned char* device_src)
{
    __shared__ uchar3 srcTile[2][MAX_WIDTH]; // cache 2 rows for 1 dst pixel

    // Each thread starts at its own column and advances by blockDim.x, so the
    // block covers all MAX_WIDTH columns for any block size.
    for (int w = threadIdx.x; w < MAX_WIDTH; w += blockDim.x) {
        for (int row = 0; row < 2; row++) {
            srcTile[row][w].x = 2;
            srcTile[row][w].y = 3;
            srcTile[row][w].z = 4;
        }
    }
    __syncthreads();
    // Debug output, enable when needed:
    // printf("x: %d\n", srcTile[0][1].x);
    // printf("sizeof(srcTile): %ld, %ld , %ld , %ld, %ld\n", sizeof(srcTile) , sizeof(srcTile[0]) , sizeof(srcTile[0][0]), sizeof(uchar3), sizeof(unsigned char));
}
35 |
// Allocate a 50-frame 1080p source batch, upload it, launch tile_check over a
// 1920 x 50 grid of 1024-thread blocks and report any CUDA error.
int main() {
    int nDevices = 0;
    CHECK_CUDA_ERROR(cudaGetDeviceCount(&nDevices));
    if (nDevices == 0) {
        std::cerr << "No CUDA device available" << std::endl;
        return 1;
    }
    // Device property dump, enable when needed:
    // for (int i = 0; i < nDevices; i++) {
    //     cudaDeviceProp prop;
    //     cudaGetDeviceProperties(&prop, i);
    //     printf("Device Number: %d\n", i);
    //     printf("  Device name: %s\n", prop.name);
    //     printf("  Memory Clock Rate (KHz): %d\n",
    //            prop.memoryClockRate);
    //     printf("  Memory Bus Width (bits): %d\n",
    //            prop.memoryBusWidth);
    //     printf("  Peak Memory Bandwidth (GB/s): %f\n\n",
    //            2.0*prop.memoryClockRate*(prop.memoryBusWidth/8)/1.0e6);
    //     printf("  Max Threads per Block: %d\n", prop.maxThreadsPerBlock);
    //     printf("  Max Threads per Multiprocessor: %d\n", prop.maxThreadsPerMultiProcessor);
    //     printf("  Max Registers per Block: %d\n", prop.regsPerBlock);
    //     printf("  Shared Memory per Block: %ld\n", prop.sharedMemPerBlock);
    //     printf("  Total Constant Memory: %ld\n", prop.totalConstMem);
    //     printf("  Memory Pitch: %ld\n", prop.memPitch);
    // }

    dim3 dimBlock(1024, 1, 1); // maximum threads: 1024
    dim3 dimGrid(1920, 50, 1);
    const int SRC_SIZE = 1920 * 1080 * 50;
    const int DST_SIZE = 20 * 20 * 50;

    unsigned char *host_src = (unsigned char *) malloc(sizeof(unsigned char) * SRC_SIZE);
    unsigned char *host_dst = (unsigned char *) malloc(sizeof(unsigned char) * DST_SIZE);
    if (host_src == NULL || host_dst == NULL) {
        std::cerr << "Host allocation failed" << std::endl;
        free(host_src);
        free(host_dst);
        return 1;
    }

    // init src image
    for (int i = 0; i < SRC_SIZE; i++) {
        host_src[i] = 1;
    }

    // Null-initialised so cudaFree is harmless if an allocation fails.
    unsigned char *device_src = NULL;
    unsigned char *device_dst = NULL;
    CHECK_CUDA_ERROR(cudaMalloc((unsigned char **)&device_src, SRC_SIZE * sizeof(unsigned char)));
    CHECK_CUDA_ERROR(cudaMalloc((unsigned char **)&device_dst, DST_SIZE * sizeof(unsigned char)));

    CHECK_CUDA_ERROR(cudaMemcpy(device_src, host_src, SRC_SIZE * sizeof(unsigned char), cudaMemcpyHostToDevice));

    tile_check<<<dimGrid, dimBlock>>>(device_src);
    // Report launch-configuration errors (e.g. too much shared memory) and
    // execution errors before tearing down.
    CHECK_CUDA_ERROR(cudaGetLastError());
    CHECK_CUDA_ERROR(cudaDeviceSynchronize());

    free(host_src);
    free(host_dst);
    CHECK_CUDA_ERROR(cudaFree(device_src));
    CHECK_CUDA_ERROR(cudaFree(device_dst));
    return 0;
}
86 | }
87 |
88 |
89 |
90 | // struct cudaDeviceProp {
91 | // char name[256];
92 | // size_t totalGlobalMem;
93 | // size_t sharedMemPerBlock;
94 | // int regsPerBlock;
95 | // int warpSize;
96 | // size_t memPitch;
97 | // int maxThreadsPerBlock;
98 | // int maxThreadsDim[3];
99 | // int maxGridSize[3];
100 | // size_t totalConstMem;
101 | // int major;
102 | // int minor;
103 | // int clockRate;
104 | // size_t textureAlignment;
105 | // int deviceOverlap;
106 | // int multiProcessorCount;
107 | // int kernelExecTimeoutEnabled;
108 | // int integrated;
109 | // int canMapHostMemory;
110 | // int computeMode;
111 | // int concurrentKernels;
112 | // int ECCEnabled;
113 | // int pciBusID;
114 | // int pciDeviceID;
115 | // int tccDriver;
116 | // }
117 |
118 | // Build and run: nvcc stat.cu -o stat && ./stat
--------------------------------------------------------------------------------
/trump.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/royinx/CUDA_Resize/938da3fa4ce538befba7c336d3cb837f2296cd3f/trump.jpg
--------------------------------------------------------------------------------