项目中发现很多cuda代码很冗余,主要是有一些cuda内存相关的操作,比如cudaMemcpy之后,要进行错误检测,所以修改了一版,实现接口内部自己检测,这样代码看起来不会那么乱。
为什么使用宏定义,而不是函数定义,是因为函数定义的话,如果希望在出错误的时候直接输出出错位置的文件和行号,就需要在使用接口的时候,将__FILE__和__LINE__传入进去,有点麻烦,因此使用了宏定义的方式
下面是具体代码:
#pragma once
#include <assert.h>
#include <cuda_runtime.h>
#include <string>
// 向上取整
#define iDIV_UP(a, b) ((a + b - 1) / b)
#define ALG_MAX(x,y) ((x)>(y)?(x):(y))
#define ALG_MIN(x,y) ((x)<(y)?(x):(y))
#define FLOAT_EPS 1e-6
#define FLOAT_EQUAL(v1, v2) ((fabs((v1)-(v2))) < (FLOAT_EPS))
// cblas错误检查
#define CheckCBlasError(ErrorID) \
{ \
if(CUBLAS_STATUS_SUCCESS != ErrorID)\
{ \
printf("=====Imaging Error CBlas: %s, line: %d of file: %s\n", cublasGetStatusString(ErrorID), __LINE__, __FILE__);\
assert(false); \
} \
}
#define CheckCudaError(ErrorId){\
if (cudaSuccess != ErrorId)\
{ \
printf("=====Imaging Error Cuda: %s, file: %s : %d\n", cudaGetErrorString(ErrorId), __FILE__, __LINE__);assert(false);\
}\
}
// Cuda显存释放
#define Cuda_Free(pData){ \
if (nullptr != pData){\
cudaError_t error_id = cudaFree(pData);\
pData = nullptr;\
CheckCudaError(error_id);\
}\
}
// Cuda显存设置值
#define Cuda_Memset(devPtr, iValue, iSize){\
void** ptr = (void**)&devPtr;\
if (ptr != nullptr){\
auto error_id = cudaMemset(*ptr, iValue, iSize);\
CheckCudaError(error_id);\
}\
}
// 显存申请
#define Cuda_Malloc(pData, iSize){\
Cuda_Free(*pData);\
cudaError_t error_id = cudaMalloc(pData, iSize);\
CheckCudaError(error_id);\
}
#define Cuda_Memcpy(pDst, pSrc, iSize, cpyKind){\
auto error_id = cudaMemcpy(pDst, pSrc, iSize, cpyKind);\
CheckCudaError(error_id);\
}