trsm.hpp
Go to the documentation of this file.
1 /*!
2  *
3  *
4  * \brief -
5  *
6  * \author O. Krause
7  * \date 2011
8  *
9  *
10  * \par Copyright 1995-2015 Shark Development Team
11  *
12  * <BR><HR>
13  * This file is part of Shark.
14  * <http://image.diku.dk/shark/>
15  *
16  * Shark is free software: you can redistribute it and/or modify
17  * it under the terms of the GNU Lesser General Public License as published
18  * by the Free Software Foundation, either version 3 of the License, or
19  * (at your option) any later version.
20  *
21  * Shark is distributed in the hope that it will be useful,
22  * but WITHOUT ANY WARRANTY; without even the implied warranty of
23  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24  * GNU Lesser General Public License for more details.
25  *
26  * You should have received a copy of the GNU Lesser General Public License
27  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
28  *
29  */
30 
31 #ifndef REMORA_KERNELS_CBLAS_TRSM_HPP
32 #define REMORA_KERNELS_CBLAS_TRSM_HPP
33 
34 #include "cblas_inc.hpp"
35 #include "../../detail/matrix_proxy_classes.hpp"
36 #include <type_traits>
37 ///solves systems of triangular matrices
38 
39 namespace remora{namespace bindings {
40 inline void trsm(
41  CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
42  CBLAS_SIDE side, CBLAS_DIAG unit,
43  int n, int nRHS,
44  float const *A, int lda, float *B, int ldb
45 ) {
46  cblas_strsm(order, side, uplo, transA, unit,n, nRHS, 1.0, A, lda, B, ldb);
47 }
48 
49 inline void trsm(
50  CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
51  CBLAS_SIDE side, CBLAS_DIAG unit,
52  int n, int nRHS,
53  double const *A, int lda, double *B, int ldb
54 ) {
55  cblas_dtrsm(order, side, uplo, transA, unit,n, nRHS, 1.0, A, lda, B, ldb);
56 }
57 
58 inline void trsm(
59  CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
60  CBLAS_SIDE side, CBLAS_DIAG unit,
61  int n, int nRHS,
62  std::complex<float> const *A, int lda, std::complex<float> *B, int ldb
63 ) {
64  std::complex<float> alpha(1.0,0);
65  cblas_ctrsm(order, side, uplo, transA, unit,n, nRHS,
66  reinterpret_cast<cblas_float_complex_type const *>(&alpha),
67  reinterpret_cast<cblas_float_complex_type const *>(A), lda,
68  reinterpret_cast<cblas_float_complex_type *>(B), ldb);
69 }
70 inline void trsm(
71  CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
72  CBLAS_SIDE side, CBLAS_DIAG unit,
73  int n, int nRHS,
74  std::complex<double> const *A, int lda, std::complex<double> *B, int ldb
75 ) {
76  std::complex<double> alpha(1.0,0);
77  cblas_ztrsm(order, side, uplo, transA, unit,n, nRHS,
78  reinterpret_cast<cblas_double_complex_type const *>(&alpha),
79  reinterpret_cast<cblas_double_complex_type const *>(A), lda,
80  reinterpret_cast<cblas_double_complex_type *>(B), ldb);
81 }
82 
83 // trsm(): solves A system of linear equations A * X = B
84 // when A is a triangular matrix
85 template <class Triangular, typename MatA, typename MatB>
86 void trsm_impl(
87  matrix_expression<MatA, cpu_tag> const &A,
88  matrix_expression<MatB, cpu_tag> &B,
89  std::true_type, left
90 ){
91  REMORA_SIZE_CHECK(A().size1() == A().size2());
92  REMORA_SIZE_CHECK(A().size1() == B().size1());
93 
94  //orientation is defined by the second argument
95  CBLAS_ORDER const storOrd = (CBLAS_ORDER)storage_order<typename MatB::orientation>::value;
96  //if orientations do not match, wecan interpret this as transposing A
97  bool transposeA = !std::is_same<typename MatA::orientation,typename MatB::orientation>::value;
98 
99  CBLAS_DIAG cblasUnit = Triangular::is_unit?CblasUnit:CblasNonUnit;
100  CBLAS_UPLO cblasUplo = (Triangular::is_upper != transposeA)?CblasUpper:CblasLower;
101  CBLAS_TRANSPOSE transA = transposeA?CblasTrans:CblasNoTrans;
102 
103  int m = B().size1();
104  int nrhs = B().size2();
105  auto storageA = A().raw_storage();
106  auto storageB = B().raw_storage();
107  trsm(storOrd, cblasUplo, transA, CblasLeft,cblasUnit, m, nrhs,
108  storageA.values,
109  storageA.leading_dimension,
110  storageB.values,
111  storageB.leading_dimension
112  );
113 }
114 
115 template <class Triangular, typename MatA, typename MatB>
116 void trsm_impl(
117  matrix_expression<MatA, cpu_tag> const &A,
118  matrix_expression<MatB, cpu_tag> &B,
119  std::true_type, right
120 ){
121  matrix_transpose<typename const_expression<MatA>::type> transA(A());
122  matrix_transpose<MatB> transB(B());
123  trsm_impl<typename Triangular::transposed_orientation>(transA, transB, std::true_type(), left());
124 }
125 
126 template <class Triangular, class Side, typename MatA, typename MatB>
127 void trsm(
128  matrix_expression<MatA, cpu_tag> const &A,
129  matrix_expression<MatB, cpu_tag> &B,
130  std::true_type
131 ){
132  trsm_impl<Triangular>(A,B, std::true_type(), Side());
133 }
134 
135 template<class M1, class M2>
136 struct has_optimized_trsm: std::integral_constant<bool,
137  allowed_cblas_type<typename M1::value_type>::type::value
138  && std::is_same<typename M1::value_type, typename M2::value_type>::value
139  && std::is_base_of<dense_tag, typename M1::storage_type::storage_tag>::value
140  && std::is_base_of<dense_tag, typename M2::storage_type::storage_tag>::value
141 >{};
142 
143 }}
144 #endif