trsv.hpp
Go to the documentation of this file.
1 /*!
2  *
3  *
4  * \brief -
5  *
6  * \author O. Krause
7  * \date 2011
8  *
9  *
10  * \par Copyright 1995-2015 Shark Development Team
11  *
12  * <BR><HR>
13  * This file is part of Shark.
14  * <http://image.diku.dk/shark/>
15  *
16  * Shark is free software: you can redistribute it and/or modify
17  * it under the terms of the GNU Lesser General Public License as published
18  * by the Free Software Foundation, either version 3 of the License, or
19  * (at your option) any later version.
20  *
21  * Shark is distributed in the hope that it will be useful,
22  * but WITHOUT ANY WARRANTY; without even the implied warranty of
23  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24  * GNU Lesser General Public License for more details.
25  *
26  * You should have received a copy of the GNU Lesser General Public License
27  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
28  *
29  */
30 
31 #ifndef REMORA_KERNELS_CBLAS_TRSV_HPP
32 #define REMORA_KERNELS_CBLAS_TRSV_HPP
33 
34 #include "cblas_inc.hpp"
35 #include <type_traits>
36 
37 ///solves systems of triangular matrices
38 
39 namespace remora {namespace bindings {
40 inline void trsv(
41  CBLAS_ORDER order, CBLAS_UPLO uplo,
42  CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
43  int n,
44  float const *A, int lda, float *b, int strideX
45 ){
46  cblas_strsv(order, uplo, transA, unit,n, A, lda, b, strideX);
47 }
48 
49 inline void trsv(
50  CBLAS_ORDER order, CBLAS_UPLO uplo,
51  CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
52  int n,
53  double const *A, int lda, double *b, int strideX
54 ){
55  cblas_dtrsv(order, uplo, transA, unit,n, A, lda, b, strideX);
56 }
57 
58 inline void trsv(
59  CBLAS_ORDER order, CBLAS_UPLO uplo,
60  CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
61  int n,
62  std::complex<float> const *A, int lda, std::complex<float> *b, int strideX
63 ){
64  cblas_ctrsv(order, uplo, transA, unit,n,
65  reinterpret_cast<cblas_float_complex_type const *>(A), lda,
66  reinterpret_cast<cblas_float_complex_type *>(b), strideX);
67 }
68 inline void trsv(
69  CBLAS_ORDER order, CBLAS_UPLO uplo,
70  CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
71  int n,
72  std::complex<double> const *A, int lda, std::complex<double> *b, int strideX
73 ){
74  cblas_ztrsv(order, uplo, transA, unit,n,
75  reinterpret_cast<cblas_double_complex_type const *>(A), lda,
76  reinterpret_cast<cblas_double_complex_type *>(b), strideX);
77 }
78 
79 // trsv(): solves A system of linear equations A * x = b
80 // when A is A triangular matrix.
81 template <class Triangular,typename MatA, typename V>
82 void trsv_impl(
83  matrix_expression<MatA, cpu_tag> const &A,
84  vector_expression<V, cpu_tag> &b,
85  std::true_type, left
86 ){
87  REMORA_SIZE_CHECK(A().size1() == A().size2());
88  REMORA_SIZE_CHECK(A().size1()== b().size());
89  CBLAS_DIAG cblasUnit = Triangular::is_unit?CblasUnit:CblasNonUnit;
90  CBLAS_ORDER const storOrd= (CBLAS_ORDER)storage_order<typename MatA::orientation>::value;
91  CBLAS_UPLO uplo = Triangular::is_upper?CblasUpper:CblasLower;
92 
93 
94  int const n = A().size1();
95  auto storageA = A().raw_storage();
96  auto storageb = b().raw_storage();
97  trsv(storOrd, uplo, CblasNoTrans,cblasUnit, n,
98  storageA.values,
99  storageA.leading_dimension,
100  storageb.values,
101  storageb.stride
102  );
103 }
104 
105 //right is mapped onto left via transposing A
106 template <class Triangular,typename MatA, typename V>
107 void trsv_impl(
108  matrix_expression<MatA, cpu_tag> const &A,
109  vector_expression<V, cpu_tag> &b,
110  std::true_type, right
111 ){
112  matrix_transpose<typename const_expression<MatA>::type> transA(A());
113  trsv_impl<typename Triangular::transposed_orientation>(transA, b, std::true_type(), left());
114 }
115 
116 //dispatcher
117 
118 template <class Triangular, class Side,typename MatA, typename V>
119 void trsv(
120  matrix_expression<MatA, cpu_tag> const& A,
121  vector_expression<V, cpu_tag> & b,
122  std::true_type//optimized
123 ){
124  trsv_impl<Triangular>(A,b,std::true_type(), Side());
125 }
126 
127 template<class M, class V>
128 struct has_optimized_trsv: std::integral_constant<bool,
129  allowed_cblas_type<typename M::value_type>::type::value
130  && std::is_same<typename M::value_type, typename V::value_type>::value
131  && std::is_base_of<dense_tag, typename M::storage_type::storage_tag>::value
132  && std::is_base_of<dense_tag, typename V::storage_type::storage_tag>::value
133 >{};
134 
135 }}
136 #endif