31 #ifndef REMORA_KERNELS_CBLAS_TRSM_HPP 32 #define REMORA_KERNELS_CBLAS_TRSM_HPP 35 #include "../../detail/matrix_proxy_classes.hpp" 36 #include <type_traits> 39 namespace remora{
namespace bindings {
// Single-precision triangular solve: forwards its arguments to cblas_strsm
// (the BLAS xTRSM routine) with the scaling factor fixed to 1.0.
//
// Parameters visible here:
//   order, uplo, transA, side, unit - CBLAS enum flags, passed through unchanged.
//       The call reorders them: this list is (order, uplo, transA, side, unit),
//       while cblas_strsm receives (order, side, uplo, transA, unit).
//   A, lda - read-only matrix data and its leading dimension.
//   B, ldb - writable matrix data and its leading dimension (per the BLAS
//            TRSM specification the result is written into B).
//
// NOTE(review): the line naming this function, the declarations of `n` and
// `nRHS` (both used in the call below) and the braces around the body are not
// visible in this listing. The leading numerals (41, 42, ...) look like line
// numbers carried over from a rendered listing - confirm against the original.
41 CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
42 CBLAS_SIDE side, CBLAS_DIAG unit,
44 float const *A,
int lda,
float *B,
int ldb
// alpha is hard-coded to 1.0, so B is not rescaled by the call.
46 cblas_strsm(order, side, uplo, transA, unit,n, nRHS, 1.0, A, lda, B, ldb);
// Double-precision triangular solve: forwards its arguments to cblas_dtrsm
// (the BLAS xTRSM routine) with the scaling factor fixed to 1.0.
//
// Parameters visible here:
//   order, uplo, transA, side, unit - CBLAS enum flags, passed through unchanged.
//       The call reorders them: this list is (order, uplo, transA, side, unit),
//       while cblas_dtrsm receives (order, side, uplo, transA, unit).
//   A, lda - read-only matrix data and its leading dimension.
//   B, ldb - writable matrix data and its leading dimension (per the BLAS
//            TRSM specification the result is written into B).
//
// NOTE(review): the line naming this function, the declarations of `n` and
// `nRHS` (both used in the call below) and the braces around the body are not
// visible in this listing. The leading numerals (50, 51, ...) look like line
// numbers carried over from a rendered listing - confirm against the original.
50 CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
51 CBLAS_SIDE side, CBLAS_DIAG unit,
53 double const *A,
int lda,
double *B,
int ldb
// alpha is hard-coded to 1.0, so B is not rescaled by the call.
55 cblas_dtrsm(order, side, uplo, transA, unit,n, nRHS, 1.0, A, lda, B, ldb);
// Single-precision complex triangular solve: forwards its arguments to
// cblas_ctrsm (the BLAS xTRSM routine) with alpha fixed to 1 + 0i.
//
// Parameters visible here:
//   order, uplo, transA, side, unit - CBLAS enum flags, passed through unchanged.
//       The call reorders them: this list is (order, uplo, transA, side, unit),
//       while cblas_ctrsm receives (order, side, uplo, transA, unit).
//   A, lda - read-only std::complex<float> data and its leading dimension.
//   B, ldb - writable std::complex<float> data and its leading dimension.
//
// NOTE(review): the line naming this function, the declarations of `n` and
// `nRHS` (both used in the call below) and the braces around the body are not
// visible in this listing. The leading numerals (59, 60, ...) look like line
// numbers carried over from a rendered listing - confirm against the original.
59 CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
60 CBLAS_SIDE side, CBLAS_DIAG unit,
62 std::complex<float>
const *A,
int lda, std::complex<float> *B,
int ldb
// Unlike the real-valued variants, the complex routine is handed alpha by
// address, so a local object is needed to hold the constant.
64 std::complex<float> alpha(1.0,0);
// alpha, A and B are all reinterpreted as cblas_float_complex_type pointers,
// the element type used in this call to the CBLAS routine.
65 cblas_ctrsm(order, side, uplo, transA, unit,n, nRHS,
66 reinterpret_cast<cblas_float_complex_type const *>(&alpha),
67 reinterpret_cast<cblas_float_complex_type const *>(A), lda,
68 reinterpret_cast<cblas_float_complex_type *>(B), ldb);
// Double-precision complex triangular solve: forwards its arguments to
// cblas_ztrsm (the BLAS xTRSM routine) with alpha fixed to 1 + 0i.
//
// Parameters visible here:
//   order, uplo, transA, side, unit - CBLAS enum flags, passed through unchanged.
//       The call reorders them: this list is (order, uplo, transA, side, unit),
//       while cblas_ztrsm receives (order, side, uplo, transA, unit).
//   A, lda - read-only std::complex<double> data and its leading dimension.
//   B, ldb - writable std::complex<double> data and its leading dimension.
//
// NOTE(review): the line naming this function, the declarations of `n` and
// `nRHS` (both used in the call below) and the braces around the body are not
// visible in this listing. The leading numerals (71, 72, ...) look like line
// numbers carried over from a rendered listing - confirm against the original.
71 CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
72 CBLAS_SIDE side, CBLAS_DIAG unit,
74 std::complex<double>
const *A,
int lda, std::complex<double> *B,
int ldb
// Unlike the real-valued variants, the complex routine is handed alpha by
// address, so a local object is needed to hold the constant.
76 std::complex<double> alpha(1.0,0);
// alpha, A and B are all reinterpreted as cblas_double_complex_type pointers,
// the element type used in this call to the CBLAS routine.
77 cblas_ztrsm(order, side, uplo, transA, unit,n, nRHS,
78 reinterpret_cast<cblas_double_complex_type const *>(&alpha),
79 reinterpret_cast<cblas_double_complex_type const *>(A), lda,
80 reinterpret_cast<cblas_double_complex_type *>(B), ldb);
// Left-sided triangular solve on CPU matrix expressions: derives the CBLAS
// flags from the expression types and hands the work to the `trsm` call at the
// end of the block (with side fixed to CblasLeft).
//
// Template parameters:
//   Triangular - supplies the compile-time flags `is_unit` and `is_upper`.
//   MatA, MatB - expression types; both are used through `orientation`,
//                `size1()`, `size2()` / `raw_storage()` as seen below.
// Parameters visible here:
//   A - read-only triangular system matrix.
//   B - right-hand sides, taken by non-const reference.
//
// NOTE(review): several pieces are not visible in this listing and should be
// confirmed against the original file: the line naming the function, any
// parameters after `B` (the list ends on a comma), the declaration of `m`
// (passed to `trsm` below), the two arguments that accompany the leading
// dimensions (presumably the data pointers of A and B), and the closing
// parenthesis/braces. The leading numerals (85, 87, ...) look like listing
// line numbers carried over by extraction.
85 template <
class Triangular,
typename MatA,
typename MatB>
87 matrix_expression<MatA, cpu_tag>
const &A,
88 matrix_expression<MatB, cpu_tag> &B,
// Size preconditions: A must be square and have as many rows as B.
91 REMORA_SIZE_CHECK(A().size1() == A().size2());
92 REMORA_SIZE_CHECK(A().size1() == B().size1());
// The storage order passed to BLAS is taken from B's orientation.
95 CBLAS_ORDER
const storOrd = (CBLAS_ORDER)storage_order<typename MatB::orientation>::value;
// A is flagged as transposed whenever its orientation differs from B's
// (presumably because the same memory read in B's orientation is A^T).
97 bool transposeA = !std::is_same<typename MatA::orientation,typename MatB::orientation>::value;
99 CBLAS_DIAG cblasUnit = Triangular::is_unit?CblasUnit:CblasNonUnit;
// The upper/lower flag is flipped when A is treated as transposed.
100 CBLAS_UPLO cblasUplo = (Triangular::is_upper != transposeA)?CblasUpper:CblasLower;
101 CBLAS_TRANSPOSE transA = transposeA?CblasTrans:CblasNoTrans;
// nrhs is B's second dimension; the size_t-to-int conversion is implicit.
104 int nrhs = B().size2();
105 auto storageA = A().raw_storage();
106 auto storageB = B().raw_storage();
107 trsm(storOrd, cblasUplo, transA, CblasLeft,cblasUnit, m, nrhs,
109 storageA.leading_dimension,
111 storageB.leading_dimension
// Right-sided variant, selected by the trailing tag parameters
// `std::true_type, right`. It does no BLAS work itself: it wraps both operands
// in transposing proxies and calls trsm_impl with the `left` tag and with
// Triangular replaced by Triangular::transposed_orientation.
// (Presumably this relies on X * A = B being equivalent to A^T * X^T = B^T -
// confirm against the library documentation.)
//
// Template parameters:
//   Triangular - must provide the nested type `transposed_orientation`.
//   MatA, MatB - expression types of the system matrix and the right-hand sides.
//
// NOTE(review): the line naming this function and the braces around the body
// are not visible in this listing; the leading numerals (115, 117, ...) look
// like listing line numbers carried over by extraction.
115 template <
class Triangular,
typename MatA,
typename MatB>
117 matrix_expression<MatA, cpu_tag>
const &A,
118 matrix_expression<MatB, cpu_tag> &B,
119 std::true_type, right
// A is wrapped through const_expression<MatA>::type because it is only read;
// B is wrapped as a plain matrix_transpose<MatB>.
121 matrix_transpose<typename const_expression<MatA>::type> transA(A());
122 matrix_transpose<MatB> transB(B());
123 trsm_impl<typename Triangular::transposed_orientation>(transA, transB, std::true_type(), left());
// Dispatcher: forwards A and B to trsm_impl<Triangular>, passing
// std::true_type() and a default-constructed Side() as tag arguments so that
// overload resolution picks the implementation matching Side.
//
// Template parameters:
//   Triangular - forwarded unchanged to trsm_impl.
//   Side       - tag type; only its default-constructed value is used.
//   MatA, MatB - expression types of the two operands.
//
// NOTE(review): the second template parameter appears split across two lines
// ("class S" / "ide,") by extraction; the body uses `Side()`, so it is
// presumably `class Side`. The line naming this function, any parameters after
// `B` (the list ends on a comma) and the braces are not visible in this
// listing; the leading numerals (126, 128, ...) look like listing line numbers.
126 template <
class Triangular,
class S
ide,
typename MatA,
typename MatB>
128 matrix_expression<MatA, cpu_tag>
const &A,
129 matrix_expression<MatB, cpu_tag> &B,
132 trsm_impl<Triangular>(A,B, std::true_type(), Side());
135 template<
class M1,
class M2>
136 struct has_optimized_trsm: std::integral_constant<bool,
137 allowed_cblas_type<typename M1::value_type>::type::value
138 && std::is_same<typename M1::value_type, typename M2::value_type>::value
139 && std::is_base_of<dense_tag, typename M1::storage_type::storage_tag>::value
140 && std::is_base_of<dense_tag, typename M2::storage_type::storage_tag>::value