VScode 里面使用 python 去直接调用 CUDA

上一个 帖子主要分享了如何 去将 C++ 程序 打包成一个package。 我们最后的 目的实际上是想把 CUDA 的程序 打包成 一个 Package , C++ 程序只是起到了桥梁的作用

首先:CUDA 程序 和 C++ 的程序一样, 都有一个 .cu 的源文件和 一个 .h 的头文件

我们的文件 包含 Cpp 文件组成,负责当作 CUDA 和 Python 的桥梁。 还有 对应的 CUDA 的源代码文件和 头文件。将这个cpp 文件命名成 ext.cpp.

#include <torch/extension.h>
#include "interpolation_kernel.h"  ## CUDA 的头文件

PYBIND11_MODULE(TORCH_EXTENSION_NAME,m){
    m.def("trilinear_interpolation",&trilinear_interpolation);
}

cpp_properities.json 配置文件

{
    "configurations": [
        {
            "name": "Linux",
            "includePath": [
                "${workspaceFolder}/**",
                "/home/smiao/anaconda3/envs/Gen_3DGS/lib/python3.8",
                "/home/smiao/anaconda3/envs/Gen_3DGS/lib/python3.8/site-packages/torch/include/",
                "/home/smiao/anaconda3/envs/Gen_3DGS/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/"
            ],
            "defines": [],
            "compilerPath": "/usr/bin/gcc",
            "cStandard": "c17",
            "cppStandard": "gnu++14",
            "intelliSenseMode": "linux-gcc-x64"
        }
    ],
    "version": 4

CUDA 部分:

CUDA 的头文件 *** interpolation_kernel.h ***

#include <torch/extension.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor trilinear_interpolation(torch::Tensor feats, torch::Tensor point);

对应的 源代码 文件*** interpolation_kernel.cu ***

include 的 头文件 和源代码文件 尽量放在同一级的 目录

#include <torch/extension.h>
#include "interpolation_kernel.h"
template <typename scalar_t>
__global__ void trilinear_fw_kernel(
            const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> feats,
            const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> points,
            torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> feat_interp
    const int n = blockIdx.x * blockDim.x + threadIdx.x; // 每个 Thread 都有一个 独特的Id
    const int f = blockIdx.y * blockDim.y + threadIdx.y;

    // 不参与计算的 thread 之间返回即可
    if(n >= feats.size(0) && f >= feats.size(2))
        return;
    
    const scalar_t u = (points[n][0]+1) / 2;
    const scalar_t v = (points[n][1]+1) / 2;
    const scalar_t w = (points[n][2]+1) / 2;

    const scalar_t a = (1-v) * (1-w);
    const scalar_t b = (1-v) * w;
    const scalar_t c = v * (1-w);
    const scalar_t d = 1-a-b-c;

    feat_interp[n][f] = (1-u) * ( a*feats[n][0][f]+
                               b*feats[n][1][f]+
                               c*feats[n][2][f]+
                               d*feats[n][3][f]) + 
                            u*(a*feats[n][4][f]+
                               b*feats[n][5][f]+
                               c*feats[n][6][f]+
                               d*feats[n][7][f]);
}

// 编写启动函数 如下:
torch::Tensor trilinear_interpolation(torch::Tensor feats, torch::Tensor points){
    CHECK_CUDA(feats);
    CHECK_CUDA(points);
	const int N = feats.size(0), F = feats.size(2);
    torch::Tensor feat_interp = torch::zeros({N,F}, feats.options());

    // 定义 几个 Block 和 Thread 的数量 
    const dim3 threads(16,16);
    const dim3 blocks((N + threads.x-1)/threads.x,(F + threads.y-1)/threads.y);

    //  启动 Kernel 函数, Kernel 的 函数类型一定是 Void  不会包含任何返回类型
    AT_DISPATCH_FLOATING_TYPES(feats.type(), "trilinear_interpolation", 
    ([&] {
        trilinear_fw_kernel<scalar_t><<<blocks, threads>>>(
            feats.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
            points.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
            feat_interp.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>()
        );
    }));
    return feats;
}

配置文件 setup.py 部分:

配置文件的 包含 ** *.cpp 文件 和 *.cu 文件 **
其他的部分应该 尽量不去改变。

from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
import os
import glob

os.path.dirname(os.path.abspath(__file__))

setup(
    name="cuda_tutorial",
    version='1.0',
    ext_modules=[
        CUDAExtension(
            name='cuda_tutorial',
            sources=["interpolation_kernel.cu","ext.cpp"], 
              extra_compile_args={'cxx': ['-O2'],
                                'nvcc': ['-O2']}
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)

最后是安装

pip install .

编写 test.py 的 python 代码对于 cuda 代码进行调用:

if __name__ == '__main__':
    N,F = 65536,256
    feats = torch.rand(N,8,F,device='cuda')
    points = torch.rand(N,3,device='cuda')*2-1
	
	## C++ 没有 feat=feat 这样传递 参数的形式
    t1 = time.time() 
    out_cuda = cuda_tutorial.trilinear_interpolation(feats,points)
    print(f"cuda time {time.time()-t1} s")

    t2 = time.time() 
    out_py = trilinear_interpolation_py(feats=feats,points=points)
    print(f"cuda time {time.time()-t2} s")
    print(torch.allclose(out_cuda,out_py))
    print(out_cuda.shape)

相关推荐

  1. VScode 里面使用 python 直接调用 CUDA

    2024-04-23 13:42:04       36 阅读
  2. Python4Delphi: 使用Delphi代码调用Python代码里面的类

    2024-04-23 13:42:04       55 阅读
  3. yolov7直接调用zed相机实现三维测距(python

    2024-04-23 13:42:04       35 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-23 13:42:04       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-23 13:42:04       100 阅读
  3. 在Django里面运行非项目文件

    2024-04-23 13:42:04       82 阅读
  4. Python语言-面向对象

    2024-04-23 13:42:04       91 阅读

热门阅读

  1. 从零开始精通RTSP之深入理解RTCP协议

    2024-04-23 13:42:04       43 阅读
  2. 基于spring boot开发的快递管理系统开题报告

    2024-04-23 13:42:04       32 阅读
  3. 使用selenium调用firefox提示Profile Missing的问题解决

    2024-04-23 13:42:04       33 阅读
  4. 【前端】vue.config.js打包时不编译

    2024-04-23 13:42:04       34 阅读
  5. vue中如何控制一个全局接口的调用频率

    2024-04-23 13:42:04       37 阅读
  6. ui_admin_vue3启动

    2024-04-23 13:42:04       31 阅读
  7. 图片 组件 vue2+element

    2024-04-23 13:42:04       35 阅读
  8. 谈谈 vue 生命周期

    2024-04-23 13:42:04       35 阅读
  9. python输入输出特殊处理

    2024-04-23 13:42:04       40 阅读
  10. 单链表(详解)

    2024-04-23 13:42:04       29 阅读
  11. 腾讯云开通幻兽帕鲁服务器需要多少钱?30元

    2024-04-23 13:42:04       38 阅读