32 #ifndef REMORA_KERNELS_CBLAS_TRMM_HPP 33 #define REMORA_KERNELS_CBLAS_TRMM_HPP 36 #include <type_traits> 38 namespace remora{
namespace bindings {
// Single-precision triangular matrix-matrix product binding: forwards the
// CBLAS layout/side/triangle/transpose/diagonal flags and the A/B buffers
// unchanged to cblas_strmm. Per the BLAS ?trmm contract, B is overwritten in
// place with the triangular product (alpha*op(A)*B or alpha*B*op(A)).
41 CBLAS_ORDER
const order,   // storage layout of A and B (CblasRowMajor / CblasColMajor)
42 CBLAS_SIDE
const side,    // whether the triangular A multiplies B from the left or right
43 CBLAS_UPLO
const uplo,    // which triangle of A is referenced
44 CBLAS_TRANSPOSE
const transA,  // op(A): A, A^T (or A^H)
45 CBLAS_DIAG
const unit,    // CblasUnit: diagonal of A is assumed to be all ones
48 float const *A,
int const lda,   // leading dimension of A
49 float* B,        // input matrix, overwritten with the result
int const incB   // presumably B's leading dimension (cblas ldb) — confirm; ?trmm has no vector increment despite the "inc" name
51 cblas_strmm(order, side, uplo, transA, unit, M, N,
// Double-precision triangular matrix-matrix product binding: forwards the
// CBLAS layout/side/triangle/transpose/diagonal flags and the A/B buffers
// unchanged to cblas_dtrmm. Per the BLAS ?trmm contract, B is overwritten in
// place with the triangular product (alpha*op(A)*B or alpha*B*op(A)).
59 CBLAS_ORDER
const order,   // storage layout of A and B (CblasRowMajor / CblasColMajor)
60 CBLAS_SIDE
const side,    // whether the triangular A multiplies B from the left or right
61 CBLAS_UPLO
const uplo,    // which triangle of A is referenced
62 CBLAS_TRANSPOSE
const transA,  // op(A): A, A^T (or A^H)
63 CBLAS_DIAG
const unit,    // CblasUnit: diagonal of A is assumed to be all ones
66 double const *A,
int const lda,   // leading dimension of A
67 double* B,       // input matrix, overwritten with the result
int const incB   // presumably B's leading dimension (cblas ldb) — confirm; ?trmm has no vector increment despite the "inc" name
69 cblas_dtrmm(order, side, uplo, transA, unit, M, N,
// Single-precision complex triangular matrix-matrix product binding: forwards
// to cblas_ctrmm with alpha fixed to 1, so B is overwritten in place with the
// plain triangular product op(A)*B (or B*op(A), depending on `side`).
78 CBLAS_ORDER
const order,   // storage layout of A and B (CblasRowMajor / CblasColMajor)
79 CBLAS_SIDE
const side,    // whether the triangular A multiplies B from the left or right
80 CBLAS_UPLO
const uplo,    // which triangle of A is referenced
81 CBLAS_TRANSPOSE
const transA,  // op(A): A, A^T or A^H
82 CBLAS_DIAG
const unit,    // CblasUnit: diagonal of A is assumed to be all ones
85 std::complex<float>
const *A,
int const lda,   // leading dimension of A
86 std::complex<float>* B,  // input matrix, overwritten with the result
int const incB   // leading dimension of B (passed in cblas_ctrmm's ldb slot)
// The complex CBLAS routines take alpha by pointer, so a named local is needed.
88 std::complex<float> alpha = 1.0;
89 cblas_ctrmm(order, side, uplo, transA, unit, M, N,
// The casts rely on std::complex<float> being layout-compatible with an array
// of two floats, which is what the CBLAS complex type is expected to match.
90 reinterpret_cast<cblas_float_complex_type const *>(&alpha),
91 reinterpret_cast<cblas_float_complex_type const *>(A), lda,
92 reinterpret_cast<cblas_float_complex_type *>(B), incB
// Double-precision complex triangular matrix-matrix product binding: forwards
// to cblas_ztrmm with alpha fixed to 1, so B is overwritten in place with the
// plain triangular product op(A)*B (or B*op(A), depending on `side`).
97 CBLAS_ORDER
const order,   // storage layout of A and B (CblasRowMajor / CblasColMajor)
98 CBLAS_SIDE
const side,    // whether the triangular A multiplies B from the left or right
99 CBLAS_UPLO
const uplo,    // which triangle of A is referenced
100 CBLAS_TRANSPOSE
const transA,  // op(A): A, A^T or A^H
101 CBLAS_DIAG
const unit,    // CblasUnit: diagonal of A is assumed to be all ones
104 std::complex<double>
const *A,
int const lda,   // leading dimension of A
105 std::complex<double>* B,  // input matrix, overwritten with the result
int const incB   // leading dimension of B (passed in cblas_ztrmm's ldb slot)
// The complex CBLAS routines take alpha by pointer, so a named local is needed.
107 std::complex<double> alpha = 1.0;
108 cblas_ztrmm(order, side, uplo, transA, unit, M, N,
// The casts rely on std::complex<double> being layout-compatible with an array
// of two doubles, which is what the CBLAS complex type is expected to match.
109 reinterpret_cast<cblas_double_complex_type const *>(&alpha),
110 reinterpret_cast<cblas_double_complex_type const *>(A), lda,
111 reinterpret_cast<cblas_double_complex_type *>(B), incB
// Dense CPU front end: computes B <- T(A) * B in place, where T(A) is the
// upper (upper == true) or lower triangle of the square matrix A, with an
// implicit unit diagonal when unit == true. The work is delegated to the raw
// CBLAS trmm binding with side = CblasLeft.
115 template <
bool upper,
bool unit,
typename MatA,
typename MatB>
117 matrix_expression<MatA, cpu_tag>
const& A,
118 matrix_expression<MatB, cpu_tag>& B,
// A must be square and its column count must match B's row count.
121 REMORA_SIZE_CHECK(A().size1() == A().size2());
122 REMORA_SIZE_CHECK(A().size2() == B().size1());
123 std::size_t n = A().size1();
124 std::size_t m = B().size2();
// Map the compile-time flags onto the CBLAS enums.
125 CBLAS_DIAG cblasUnit = unit?CblasUnit:CblasNonUnit;
126 CBLAS_UPLO cblasUplo = upper?CblasUpper:CblasLower;
127 CBLAS_ORDER stor_ord= (CBLAS_ORDER)storage_order<typename MatA::orientation>::value;
128 CBLAS_TRANSPOSE trans=CblasNoTrans;
// The CBLAS call is issued in B's storage order. If A is stored in the other
// order, reading A's buffer in B's layout yields A^T, whose stored triangle is
// the opposite one, so the uplo flag is mirrored.
// NOTE(review): for the product to still be A*B (not A^T*B), `trans` must also
// be switched to CblasTrans in this branch; only the uplo change is visible
// here. Confirm that assignment exists.
132 CBLAS_ORDER stor_ordB= (CBLAS_ORDER)storage_order<typename MatB::orientation>::value;
133 if(stor_ord != stor_ordB){
135 cblasUplo= upper?CblasLower:CblasUpper;
// Raw dense storage supplies the buffer pointers and leading dimensions.
138 auto storageA = A().raw_storage();
139 auto storageB = B().raw_storage();
140 trmm(stor_ordB, CblasLeft, cblasUplo, trans, cblasUnit,
143 storageA.leading_dimension,
145 storageB.leading_dimension
150 template<
class M1,
class M2>
151 struct has_optimized_trmm: std::integral_constant<bool,
152 allowed_cblas_type<typename M1::value_type>::type::value
153 && std::is_same<typename M1::value_type, typename M2::value_type>::value
154 && std::is_base_of<dense_tag, typename M1::storage_type::storage_tag>::value
155 && std::is_base_of<dense_tag, typename M2::storage_type::storage_tag>::value