dense_gemm.hpp
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief cblas binding for dense gemm
6  *
7  * \author O. Krause
8  * \date 2016
9  *
10  *
11  * \par Copyright 1995-2015 Shark Development Team
12  *
13  * <BR><HR>
14  * This file is part of Shark.
15  * <http://image.diku.dk/shark/>
16  *
17  * Shark is free software: you can redistribute it and/or modify
18  * it under the terms of the GNU Lesser General Public License as published
19  * by the Free Software Foundation, either version 3 of the License, or
20  * (at your option) any later version.
21  *
22  * Shark is distributed in the hope that it will be useful,
23  * but WITHOUT ANY WARRANTY; without even the implied warranty of
24  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
25  * GNU Lesser General Public License for more details.
26  *
27  * You should have received a copy of the GNU Lesser General Public License
28  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
29  *
30  */
31 //===========================================================================
32 #ifndef REMORA_KERNELS_CBLAS_DENSE_GEMM_HPP
33 #define REMORA_KERNELS_CBLAS_DENSE_GEMM_HPP
34 
#include "cblas_inc.hpp"
#include "../../detail/matrix_proxy_classes.hpp"
#include "../default/simd.hpp"

#include <algorithm>
#include <memory>
#include <type_traits>
39 namespace remora{ namespace bindings {
40 
41 inline void dense_gemm(
42  CBLAS_ORDER const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
43  int M, int N, int K,
44  float alpha, float const *A, int lda,
45  float const *B, int ldb,
46  float beta, float *C, int ldc
47 ){
48  cblas_sgemm(
49  Order, TransA, TransB,
50  M, N, K,
51  alpha, A, lda,
52  B, ldb,
53  beta, C, ldc
54  );
55 }
56 
57 inline void dense_gemm(
58  CBLAS_ORDER const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
59  int M, int N, int K,
60  double alpha, double const *A, int lda,
61  double const *B, int ldb,
62  double beta, double *C, int ldc
63 ){
64  cblas_dgemm(
65  Order, TransA, TransB,
66  M, N, K,
67  alpha,
68  A, lda,
69  B, ldb,
70  beta,
71  C, ldc
72  );
73 }
74 
75 inline void dense_gemm(
76  CBLAS_ORDER const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
77  int M, int N, int K,
78  float alpha,
79  std::complex<float> const *A, int lda,
80  std::complex<float> const *B, int ldb,
81  float beta,
82  std::complex<float>* C, int ldc
83 ) {
84  std::complex<float> alphaArg(alpha,0);
85  std::complex<float> betaArg(beta,0);
86  cblas_cgemm(
87  Order, TransA, TransB,
88  M, N, K,
89  reinterpret_cast<cblas_float_complex_type const *>(&alphaArg),
90  reinterpret_cast<cblas_float_complex_type const *>(A), lda,
91  reinterpret_cast<cblas_float_complex_type const *>(B), ldb,
92  reinterpret_cast<cblas_float_complex_type const *>(&betaArg),
93  reinterpret_cast<cblas_float_complex_type *>(C), ldc
94  );
95 }
96 
97 inline void dense_gemm(
98  CBLAS_ORDER const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
99  int M, int N, int K,
100  double alpha,
101  std::complex<double> const *A, int lda,
102  std::complex<double> const *B, int ldb,
103  double beta,
104  std::complex<double>* C, int ldc
105 ) {
106  std::complex<double> alphaArg(alpha,0);
107  std::complex<double> betaArg(beta,0);
108  cblas_zgemm(
109  Order, TransA, TransB,
110  M, N, K,
111  reinterpret_cast<cblas_double_complex_type const *>(&alphaArg),
112  reinterpret_cast<cblas_double_complex_type const *>(A), lda,
113  reinterpret_cast<cblas_double_complex_type const *>(B), ldb,
114  reinterpret_cast<cblas_double_complex_type const *>(&betaArg),
115  reinterpret_cast<cblas_double_complex_type *>(C), ldc
116  );
117 }
118 
119 //optimized cblas version
120 template <typename MatA, typename MatB, typename MatC>
121 void dense_gemm(
122  matrix_expression<MatA, cpu_tag> const& A,
123  matrix_expression<MatB, cpu_tag> const& B,
124  matrix_expression<MatC, cpu_tag>& C,
125  typename MatC::value_type alpha,
126  std::true_type
127 ){
128  static_assert(std::is_same<typename MatC::orientation,row_major>::value,"C must be row major");
129 
130  CBLAS_TRANSPOSE transA = std::is_same<typename MatA::orientation,typename MatC::orientation>::value?CblasNoTrans:CblasTrans;
131  CBLAS_TRANSPOSE transB = std::is_same<typename MatB::orientation,typename MatC::orientation>::value?CblasNoTrans:CblasTrans;
132  std::size_t m = C().size1();
133  std::size_t n = C().size2();
134  std::size_t k = A().size2();
135  CBLAS_ORDER stor_ord = (CBLAS_ORDER) storage_order<typename MatC::orientation >::value;
136 
137  auto storageA = A().raw_storage();
138  auto storageB = B().raw_storage();
139  auto storageC = C().raw_storage();
140  dense_gemm(stor_ord, transA, transB, (int)m, (int)n, (int)k, alpha,
141  storageA.values,
142  storageA.leading_dimension,
143  storageB.values,
144  storageB.leading_dimension,
145  typename MatC::value_type(1),
146  storageC.values,
147  storageC.leading_dimension
148  );
149 }
150 
151 template <typename MatA, typename MatB, typename MatC>
152 void dense_gemm(
153  matrix_expression<MatA, cpu_tag> const& A,
154  matrix_expression<MatB, cpu_tag> const& B,
155  matrix_expression<MatC, cpu_tag>& C,
156  typename MatC::value_type alpha,
157  std::false_type
158 ){
159  typedef typename MatC::value_type value_type;
160  std::size_t const tile_size = 512;
161  static const std::size_t align = 64;
162  std::size_t size1 = C().size1();
163  std::size_t size2 = C().size2();
164  std::size_t num_blocks = (A().size2()+tile_size-1)/tile_size;
165  boost::alignment::aligned_allocator<value_type,align> allocator;
166  value_type* A_pointer = allocator.allocate(size1 * tile_size);
167  value_type* B_pointer = allocator.allocate(size2 * tile_size);
168  for(std::size_t k = 0; k != num_blocks; ++k){
169  std::size_t start_k = k * tile_size;
170  std::size_t current_size = std::min(tile_size,A().size2() - start_k);
171  dense_matrix_adaptor<value_type,row_major> A_block(A_pointer, size1, current_size);
172  dense_matrix_adaptor<value_type,row_major> B_block(B_pointer, current_size, size2);
173  matrix_range<MatA const> A_range(A(), 0, size1, start_k, start_k + current_size);
174  matrix_range<MatB const> B_range(B(), start_k, start_k + current_size, 0, size2);
175  noalias(A_block) = A_range;
176  noalias(B_block) = B_range;
177  dense_gemm(A_block, B_block, C, alpha, std::true_type());
178  }
179  allocator.deallocate(A_pointer, size1 * tile_size);
180  allocator.deallocate(B_pointer, size1 * tile_size);
181 }
182 
183 
184 template<class M1, class M2, class M3>
185 struct has_optimized_gemm: std::integral_constant<bool,
186  allowed_cblas_type<typename M1::value_type>::type::value
187  && std::is_same<typename M1::value_type, typename M2::value_type>::value
188  && std::is_same<typename M1::value_type, typename M3::value_type>::value
189  && std::is_base_of<dense_tag, typename M1::storage_type::storage_tag>::value
190  && std::is_base_of<dense_tag, typename M2::storage_type::storage_tag>::value
191  && std::is_base_of<dense_tag, typename M3::storage_type::storage_tag>::value
192 >{};
193 
194 template <typename MatA, typename MatB, typename MatC>
195 void dense_gemm(
196  matrix_expression<MatA, cpu_tag> const& A,
197  matrix_expression<MatB, cpu_tag> const& B,
198  matrix_expression<MatC, cpu_tag>& C,
199  typename MatC::value_type alpha
200 ){
201  REMORA_SIZE_CHECK(A().size1() == C().size1());
202  REMORA_SIZE_CHECK(B().size2() == C().size2());
203  REMORA_SIZE_CHECK(A().size2()== B().size1());
204  dense_gemm(
205  A,B,C,alpha,
206  typename has_optimized_gemm<MatA,MatB,MatC>::type()
207  );
208 }
209 
210 }}
211 
212 #endif