trmm.hpp
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Bindings that forward the triangular matrix-matrix product (trmm) to CBLAS.
6  *
7  * \author O. Krause
8  * \date 2010
9  *
10  *
11  * \par Copyright 1995-2015 Shark Development Team
12  *
13  * <BR><HR>
14  * This file is part of Shark.
15  * <http://image.diku.dk/shark/>
16  *
17  * Shark is free software: you can redistribute it and/or modify
18  * it under the terms of the GNU Lesser General Public License as published
19  * by the Free Software Foundation, either version 3 of the License, or
20  * (at your option) any later version.
21  *
22  * Shark is distributed in the hope that it will be useful,
23  * but WITHOUT ANY WARRANTY; without even the implied warranty of
24  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
25  * GNU Lesser General Public License for more details.
26  *
27  * You should have received a copy of the GNU Lesser General Public License
28  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
29  *
30  */
31 //===========================================================================
32 #ifndef REMORA_KERNELS_CBLAS_TRMM_HPP
33 #define REMORA_KERNELS_CBLAS_TRMM_HPP
34 
35 #include "cblas_inc.hpp"
36 #include <type_traits>
37 
38 namespace remora{namespace bindings {
39 
40 inline void trmm(
41  CBLAS_ORDER const order,
42  CBLAS_SIDE const side,
43  CBLAS_UPLO const uplo,
44  CBLAS_TRANSPOSE const transA,
45  CBLAS_DIAG const unit,
46  int const M,
47  int const N,
48  float const *A, int const lda,
49  float* B, int const incB
50 ) {
51  cblas_strmm(order, side, uplo, transA, unit, M, N,
52  1.0,
53  A, lda,
54  B, incB
55  );
56 }
57 
58 inline void trmm(
59  CBLAS_ORDER const order,
60  CBLAS_SIDE const side,
61  CBLAS_UPLO const uplo,
62  CBLAS_TRANSPOSE const transA,
63  CBLAS_DIAG const unit,
64  int const M,
65  int const N,
66  double const *A, int const lda,
67  double* B, int const incB
68 ) {
69  cblas_dtrmm(order, side, uplo, transA, unit, M, N,
70  1.0,
71  A, lda,
72  B, incB
73  );
74 }
75 
76 
77 inline void trmm(
78  CBLAS_ORDER const order,
79  CBLAS_SIDE const side,
80  CBLAS_UPLO const uplo,
81  CBLAS_TRANSPOSE const transA,
82  CBLAS_DIAG const unit,
83  int const M,
84  int const N,
85  std::complex<float> const *A, int const lda,
86  std::complex<float>* B, int const incB
87 ) {
88  std::complex<float> alpha = 1.0;
89  cblas_ctrmm(order, side, uplo, transA, unit, M, N,
90  reinterpret_cast<cblas_float_complex_type const *>(&alpha),
91  reinterpret_cast<cblas_float_complex_type const *>(A), lda,
92  reinterpret_cast<cblas_float_complex_type *>(B), incB
93  );
94 }
95 
96 inline void trmm(
97  CBLAS_ORDER const order,
98  CBLAS_SIDE const side,
99  CBLAS_UPLO const uplo,
100  CBLAS_TRANSPOSE const transA,
101  CBLAS_DIAG const unit,
102  int const M,
103  int const N,
104  std::complex<double> const *A, int const lda,
105  std::complex<double>* B, int const incB
106 ) {
107  std::complex<double> alpha = 1.0;
108  cblas_ztrmm(order, side, uplo, transA, unit, M, N,
109  reinterpret_cast<cblas_double_complex_type const *>(&alpha),
110  reinterpret_cast<cblas_double_complex_type const *>(A), lda,
111  reinterpret_cast<cblas_double_complex_type *>(B), incB
112  );
113 }
114 
115 template <bool upper, bool unit, typename MatA, typename MatB>
116 void trmm(
117  matrix_expression<MatA, cpu_tag> const& A,
118  matrix_expression<MatB, cpu_tag>& B,
119  std::true_type
120 ){
121  REMORA_SIZE_CHECK(A().size1() == A().size2());
122  REMORA_SIZE_CHECK(A().size2() == B().size1());
123  std::size_t n = A().size1();
124  std::size_t m = B().size2();
125  CBLAS_DIAG cblasUnit = unit?CblasUnit:CblasNonUnit;
126  CBLAS_UPLO cblasUplo = upper?CblasUpper:CblasLower;
127  CBLAS_ORDER stor_ord= (CBLAS_ORDER)storage_order<typename MatA::orientation>::value;
128  CBLAS_TRANSPOSE trans=CblasNoTrans;
129 
130  //special case: MatA and MatB do not have same storage order. in this case compute as
131  //AB->B^TA^T where transpose of B is done implicitely by exchanging storage order
132  CBLAS_ORDER stor_ordB= (CBLAS_ORDER)storage_order<typename MatB::orientation>::value;
133  if(stor_ord != stor_ordB){
134  trans = CblasTrans;
135  cblasUplo= upper?CblasLower:CblasUpper;
136  }
137 
138  auto storageA = A().raw_storage();
139  auto storageB = B().raw_storage();
140  trmm(stor_ordB, CblasLeft, cblasUplo, trans, cblasUnit,
141  (int)n, int(m),
142  storageA.values,
143  storageA.leading_dimension,
144  storageB.values,
145  storageB.leading_dimension
146  );
147 }
148 
149 
150 template<class M1, class M2>
151 struct has_optimized_trmm: std::integral_constant<bool,
152  allowed_cblas_type<typename M1::value_type>::type::value
153  && std::is_same<typename M1::value_type, typename M2::value_type>::value
154  && std::is_base_of<dense_tag, typename M1::storage_type::storage_tag>::value
155  && std::is_base_of<dense_tag, typename M2::storage_type::storage_tag>::value
156 >{};
157 
158 }}
159 #endif