31 #ifndef REMORA_KERNELS_CBLAS_TRSV_HPP 32 #define REMORA_KERNELS_CBLAS_TRSV_HPP 35 #include <type_traits> 39 namespace remora {
namespace bindings {
41 CBLAS_ORDER order, CBLAS_UPLO uplo,
42 CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
44 float const *A,
int lda,
float *b,
int strideX
46 cblas_strsv(order, uplo, transA, unit,n, A, lda, b, strideX);
50 CBLAS_ORDER order, CBLAS_UPLO uplo,
51 CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
53 double const *A,
int lda,
double *b,
int strideX
55 cblas_dtrsv(order, uplo, transA, unit,n, A, lda, b, strideX);
59 CBLAS_ORDER order, CBLAS_UPLO uplo,
60 CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
62 std::complex<float>
const *A,
int lda, std::complex<float> *b,
int strideX
64 cblas_ctrsv(order, uplo, transA, unit,n,
65 reinterpret_cast<cblas_float_complex_type const *>(A), lda,
66 reinterpret_cast<cblas_float_complex_type *>(b), strideX);
69 CBLAS_ORDER order, CBLAS_UPLO uplo,
70 CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
72 std::complex<double>
const *A,
int lda, std::complex<double> *b,
int strideX
74 cblas_ztrsv(order, uplo, transA, unit,n,
75 reinterpret_cast<cblas_double_complex_type const *>(A), lda,
76 reinterpret_cast<cblas_double_complex_type *>(b), strideX);
// Optimized (CBLAS-backed) triangular solve with A applied from the left:
// solves A x = b for a square triangular A by calling the trsv() CBLAS binding
// with CblasNoTrans; b's storage is passed as the in/out vector.
// Presumably this is the `left` overload targeted by the right-hand version
// (which calls trsv_impl<...>(transA, b, std::true_type(), left())) -- confirm.
// NOTE(review): this block is truncated in the reviewed view. The declarator,
// the tag parameter(s), the trsv() arguments after storageA.leading_dimension,
// and the closing brace are not visible. Confirm against the full file.
81 template <
class Triangular,
typename MatA,
typename V>
83 matrix_expression<MatA, cpu_tag>
const &A,
84 vector_expression<V, cpu_tag> &b,
// A must be square, and its dimension must match the length of b.
87 REMORA_SIZE_CHECK(A().size1() == A().size2());
88 REMORA_SIZE_CHECK(A().size1()== b().size());
// Translate the compile-time Triangular tag (unit diagonal? upper/lower?) and
// A's storage orientation into the corresponding CBLAS enum values.
89 CBLAS_DIAG cblasUnit = Triangular::is_unit?CblasUnit:CblasNonUnit;
90 CBLAS_ORDER
const storOrd= (CBLAS_ORDER)storage_order<typename MatA::orientation>::value;
91 CBLAS_UPLO uplo = Triangular::is_upper?CblasUpper:CblasLower;
// NOTE(review): size1() is narrowed to int for the CBLAS interface; presumably
// it returns an unsigned size type, so dimensions above INT_MAX would not fit.
94 int const n = A().size1();
// Raw dense storage views: base pointer plus leading dimension / stride.
95 auto storageA = A().raw_storage();
96 auto storageb = b().raw_storage();
97 trsv(storOrd, uplo, CblasNoTrans,cblasUnit, n,
99 storageA.leading_dimension,
106 template <
class Triangular,
typename MatA,
typename V>
108 matrix_expression<MatA, cpu_tag>
const &A,
109 vector_expression<V, cpu_tag> &b,
110 std::true_type, right
112 matrix_transpose<typename const_expression<MatA>::type> transA(A());
113 trsv_impl<typename Triangular::transposed_orientation>(transA, b, std::true_type(), left());
// Entry point for the optimized path. It forwards to the trsv_impl overload
// selected by the Side tag type (a default-constructed Side() is passed, e.g.
// left or right) and passes std::true_type() to pick the overloads in this file
// that take a std::true_type tag.
// Presumably that tag is the "optimized" flag computed by has_optimized_trsv
// -- confirm at the call site.
// NOTE(review): the declarator (function name) and at least one trailing
// parameter (the line after `& b,`) are not visible in the reviewed view, and
// the template parameter `Side` appears split across two lines. Confirm
// against the full file.
118 template <
class Triangular,
class S
ide,
typename MatA,
typename V>
120 matrix_expression<MatA, cpu_tag>
const& A,
121 vector_expression<V, cpu_tag> & b,
124 trsv_impl<Triangular>(A,b,std::true_type(), Side());
127 template<
class M,
class V>
128 struct has_optimized_trsv: std::integral_constant<bool,
129 allowed_cblas_type<typename M::value_type>::type::value
130 && std::is_same<typename M::value_type, typename V::value_type>::value
131 && std::is_base_of<dense_tag, typename M::storage_type::storage_tag>::value
132 && std::is_base_of<dense_tag, typename V::storage_type::storage_tag>::value