Skip to content

Commit

Permalink
Merge pull request #354 from ROCmSoftwarePlatform/gemm_ex_rocblas_int
Browse files Browse the repository at this point in the history
Correct the gemm_ex API to use rocblas_int instead of int
  • Loading branch information
amcamd authored Sep 7, 2018
2 parents 31349b8 + 067fb71 commit dd05b47
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 57 deletions.
14 changes: 7 additions & 7 deletions library/include/rocblas-functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -1400,23 +1400,23 @@ ROCBLAS_EXPORT rocblas_status rocblas_dgeam(rocblas_handle handle,
ROCBLAS_EXPORT rocblas_status rocblas_gemm_ex(rocblas_handle handle,
rocblas_operation trans_a,
rocblas_operation trans_b,
int m,
int n,
int k,
rocblas_int m,
rocblas_int n,
rocblas_int k,
const void* alpha,
const void* a,
rocblas_datatype a_type,
int lda,
rocblas_int lda,
const void* b,
rocblas_datatype b_type,
int ldb,
rocblas_int ldb,
const void* beta,
const void* c,
rocblas_datatype c_type,
int ldc,
rocblas_int ldc,
void* d,
rocblas_datatype d_type,
int ldd,
rocblas_int ldd,
rocblas_datatype compute_type,
rocblas_gemm_algo algo,
uint32_t solution_index,
Expand Down
134 changes: 84 additions & 50 deletions library/src/blas_ex/rocblas_gemm_ex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,11 @@ TensileStatus tensile_Cijk_Alik_Bjlk_B<double,double>(double* dataC, const doubl
template <typename Td, typename Tc>
rocblas_status tensile_gemm_handle_transpose(rocblas_handle handle,
rocblas_operation trans_a, rocblas_operation trans_b,
int m, int n, int k, const Tc alpha,
const Td* a, int lda,
const Td* b, int ldb, const Tc beta,
const Td* c, int ldc,
Td* d, int ldd)
rocblas_int m, rocblas_int n, rocblas_int k, const Tc alpha,
const Td* a, rocblas_int lda,
const Td* b, rocblas_int ldb, const Tc beta,
const Td* c, rocblas_int ldc,
Td* d, rocblas_int ldd)
{
TensileStatus t_status;
rocblas_status rb_status;
Expand All @@ -291,42 +291,78 @@ rocblas_status tensile_gemm_handle_transpose(rocblas_handle handle,

if((trans_a == rocblas_operation_none) && (trans_b == rocblas_operation_none))
{
unsigned int const stride_a = lda * k;
unsigned int const stride_b = ldb * n;
unsigned int const stride_d = ldd * n;
t_status = tensile_Cijk_Ailk_Bljk_B<Td,Tc>(static_cast<Td*>(d), static_cast<const Td*>(a), static_cast<const Td*>(b),
alpha, beta, 0, 0, 0, ldd, stride_d, lda, stride_a, ldb, stride_b,
m, n, 1, k, handle->rocblas_stream);
unsigned int const stride_a = static_cast<unsigned int const>(lda * k);
unsigned int const stride_b = static_cast<unsigned int const>(ldb * n);
unsigned int const stride_d = static_cast<unsigned int const>(ldd * n);
t_status = tensile_Cijk_Ailk_Bljk_B<Td,Tc>(static_cast<Td*>(d),
static_cast<const Td*>(a),
static_cast<const Td*>(b),
alpha, beta, 0, 0, 0,
static_cast<unsigned int>(ldd), stride_d,
static_cast<unsigned int>(lda), stride_a,
static_cast<unsigned int>(ldb), stride_b,
static_cast<unsigned int>(m),
static_cast<unsigned int>(n),
static_cast<unsigned int>(1),
static_cast<unsigned int>(k),
handle->rocblas_stream);
}
else if((trans_a == rocblas_operation_none) &&
(trans_b == rocblas_operation_transpose || trans_b == rocblas_operation_conjugate_transpose))
{
unsigned int const stride_a = lda * k;
unsigned int const stride_b = ldb * k;
unsigned int const stride_d = ldd * n;
t_status = tensile_Cijk_Ailk_Bjlk_B<Td,Tc>(static_cast<Td*>(d), static_cast<const Td*>(a), static_cast<const Td*>(b),
alpha, beta, 0, 0, 0, ldd, stride_d, lda, stride_a, ldb, stride_b,
m, n, 1, k, handle->rocblas_stream);
unsigned int const stride_a = static_cast<unsigned int const>(lda * k);
unsigned int const stride_b = static_cast<unsigned int const>(ldb * k);
unsigned int const stride_d = static_cast<unsigned int const>(ldd * n);
t_status = tensile_Cijk_Ailk_Bjlk_B<Td,Tc>(static_cast<Td*>(d),
static_cast<const Td*>(a),
static_cast<const Td*>(b),
alpha, beta, 0, 0, 0,
static_cast<unsigned int>(ldd), stride_d,
static_cast<unsigned int>(lda), stride_a,
static_cast<unsigned int>(ldb), stride_b,
static_cast<unsigned int>(m),
static_cast<unsigned int>(n),
static_cast<unsigned int>(1),
static_cast<unsigned int>(k),
handle->rocblas_stream);
}
else if((trans_a == rocblas_operation_transpose || trans_a == rocblas_operation_conjugate_transpose) &&
(trans_b == rocblas_operation_none))
{
unsigned int const stride_a = lda * m;
unsigned int const stride_b = ldb * n;
unsigned int const stride_d = ldd * n;
t_status = tensile_Cijk_Alik_Bljk_B<Td,Tc>(static_cast<Td*>(d), static_cast<const Td*>(a), static_cast<const Td*>(b),
alpha, beta, 0, 0, 0, ldd, stride_d, lda, stride_a, ldb, stride_b,
m, n, 1, k, handle->rocblas_stream);
unsigned int const stride_a = static_cast<unsigned int const>(lda * m);
unsigned int const stride_b = static_cast<unsigned int const>(ldb * n);
unsigned int const stride_d = static_cast<unsigned int const>(ldd * n);
t_status = tensile_Cijk_Alik_Bljk_B<Td,Tc>(static_cast<Td*>(d),
static_cast<const Td*>(a),
static_cast<const Td*>(b),
alpha, beta, 0, 0, 0,
static_cast<unsigned int>(ldd), stride_d,
static_cast<unsigned int>(lda), stride_a,
static_cast<unsigned int>(ldb), stride_b,
static_cast<unsigned int>(m),
static_cast<unsigned int>(n),
static_cast<unsigned int>(1),
static_cast<unsigned int>(k),
handle->rocblas_stream);
}
else if((trans_a == rocblas_operation_transpose || trans_a == rocblas_operation_conjugate_transpose) &&
(trans_b == rocblas_operation_transpose || trans_b == rocblas_operation_conjugate_transpose))
{
unsigned int const stride_a = lda * m;
unsigned int const stride_b = ldb * k;
unsigned int const stride_d = ldd * n;
t_status = tensile_Cijk_Alik_Bjlk_B<Td,Tc>(static_cast<Td*>(d), static_cast<const Td*>(a), static_cast<const Td*>(b),
alpha, beta, 0, 0, 0, ldd, stride_d, lda, stride_a, ldb, stride_b,
m, n, 1, k, handle->rocblas_stream);
unsigned int const stride_a = static_cast<unsigned int const>(lda * m);
unsigned int const stride_b = static_cast<unsigned int const>(ldb * k);
unsigned int const stride_d = static_cast<unsigned int const>(ldd * n);
t_status = tensile_Cijk_Alik_Bjlk_B<Td,Tc>(static_cast<Td*>(d),
static_cast<const Td*>(a),
static_cast<const Td*>(b),
alpha, beta, 0, 0, 0,
static_cast<unsigned int>(ldd), stride_d,
static_cast<unsigned int>(lda), stride_a,
static_cast<unsigned int>(ldb), stride_b,
static_cast<unsigned int>(m),
static_cast<unsigned int>(n),
static_cast<unsigned int>(1),
static_cast<unsigned int>(k),
handle->rocblas_stream);
}
else
{
Expand All @@ -349,19 +385,19 @@ template <typename Td, typename Tc>
rocblas_status tensile_gemm_chunk(rocblas_handle handle,
rocblas_operation trans_a,
rocblas_operation trans_b,
int m,
int n,
int k,
rocblas_int m,
rocblas_int n,
rocblas_int k,
Tc alpha,
const Td* a,
int lda,
rocblas_int lda,
const Td* b,
int ldb,
rocblas_int ldb,
Tc beta,
const Td* c,
int ldc,
rocblas_int ldc,
Td* d,
int ldd)
rocblas_int ldd)
{
unsigned int int_limit = std::numeric_limits<int>::max() / sizeof(Td);
unsigned int m_chunk_size = m;
Expand Down Expand Up @@ -445,11 +481,11 @@ rocblas_status tensile_gemm_chunk(rocblas_handle handle,
template <typename Td, typename Tc>
rocblas_status tensile_gemm_typecasting(rocblas_handle handle,
rocblas_operation trans_a, rocblas_operation trans_b,
int m, int n, int k, const void* alpha,
const void* a, int lda,
const void* b, int ldb, const void* beta,
const void* c, int ldc,
void* d, int ldd)
rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
const void* a, rocblas_int lda,
const void* b, rocblas_int ldb, const void* beta,
const void* c, rocblas_int ldc,
void* d, rocblas_int ldd)
{
Tc h_alpha;
Tc h_beta;
Expand Down Expand Up @@ -484,7 +520,6 @@ rocblas_status tensile_gemm_typecasting(rocblas_handle handle,
ldd);
}


/*! \brief BLAS EX API
\details
Expand Down Expand Up @@ -583,29 +618,28 @@ rocblas_status tensile_gemm_typecasting(rocblas_handle handle,
workspace void*
workspace
********************************************************************/

extern "C" rocblas_status rocblas_gemm_ex(rocblas_handle handle,
rocblas_operation trans_a,
rocblas_operation trans_b,
int m,
int n,
int k,
rocblas_int m,
rocblas_int n,
rocblas_int k,
const void* alpha,
const void* a,
rocblas_datatype a_type,
int lda,
rocblas_int lda,
const void* b,
rocblas_datatype b_type,
int ldb,
rocblas_int ldb,
const void* beta,
const void* c,
rocblas_datatype c_type,
int ldc,
rocblas_int ldc,
void* d,
rocblas_datatype d_type,
int ldd,
rocblas_int ldd,
rocblas_datatype compute_type,
rocblas_gemm_algo algo,
uint32_t solution_index,
Expand Down

0 comments on commit dd05b47

Please sign in to comment.