32 #ifndef REMORA_KERNELS_CBLAS_DENSE_GEMM_HPP 33 #define REMORA_KERNELS_CBLAS_DENSE_GEMM_HPP 36 #include "../../detail/matrix_proxy_classes.hpp" 37 #include "../default/simd.hpp" 38 #include <type_traits> 39 namespace remora{
namespace bindings {
// dense_gemm overload for single-precision real data: float scalars
// alpha/beta, float pointers A, B, C, each pointer paired with an int
// leading dimension (lda, ldb, ldc).
// NOTE(review): the leading integers on some lines appear to be listing line
// numbers carried over from an extraction; gaps in that numbering (43, 47-48,
// 50 onward) correspond to lines that are absent from this view, including
// the name of the routine that is called. Presumably this forwards to the
// single-precision CBLAS gemm routine -- TODO confirm against the full header.
41 inline void dense_gemm(
42 CBLAS_ORDER
const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
44 float alpha,
float const *A,
int lda,
45 float const *B,
int ldb,
46 float beta,
float *C,
int ldc
// First argument row of the forwarded call: storage order and both transpose flags.
49 Order, TransA, TransB,
// dense_gemm overload for double-precision real data: double scalars
// alpha/beta, double pointers A, B, C, each pointer paired with an int
// leading dimension (lda, ldb, ldc).
// NOTE(review): listing lines 59, 63-64 and 66 onward are absent from this
// view, including the name of the routine that is called. Presumably this
// forwards to the double-precision CBLAS gemm routine -- TODO confirm.
57 inline void dense_gemm(
58 CBLAS_ORDER
const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
60 double alpha,
double const *A,
int lda,
61 double const *B,
int ldb,
62 double beta,
double *C,
int ldc
// First argument row of the forwarded call: storage order and both transpose flags.
65 Order, TransA, TransB,
// dense_gemm overload for std::complex<float> matrices A, B and C.
// NOTE(review): the declarations of alpha and beta (listing lines 78 and 81)
// and the name of the called routine are absent from this view. Presumably
// the callee is the single-precision complex CBLAS gemm -- TODO confirm.
75 inline void dense_gemm(
76 CBLAS_ORDER
const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
79 std::complex<float>
const *A,
int lda,
80 std::complex<float>
const *B,
int ldb,
82 std::complex<float>* C,
int ldc
// alpha and beta are turned into complex values with a zero imaginary part
// so that they can be handed over by address below.
84 std::complex<float> alphaArg(alpha,0);
85 std::complex<float> betaArg(beta,0);
87 Order, TransA, TransB,
// Scalars and matrices are passed as pointers reinterpreted to the
// cblas_float_complex_type expected by the callee.
89 reinterpret_cast<cblas_float_complex_type const *>(&alphaArg),
90 reinterpret_cast<cblas_float_complex_type const *>(A), lda,
91 reinterpret_cast<cblas_float_complex_type const *>(B), ldb,
92 reinterpret_cast<cblas_float_complex_type const *>(&betaArg),
93 reinterpret_cast<cblas_float_complex_type *>(C), ldc
// dense_gemm overload for std::complex<double> matrices A, B and C.
// NOTE(review): the declarations of alpha and beta (listing lines 100 and
// 103) and the name of the called routine are absent from this view.
// Presumably the callee is the double-precision complex CBLAS gemm -- TODO
// confirm.
97 inline void dense_gemm(
98 CBLAS_ORDER
const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
101 std::complex<double>
const *A,
int lda,
102 std::complex<double>
const *B,
int ldb,
104 std::complex<double>* C,
int ldc
// alpha and beta are turned into complex values with a zero imaginary part
// so that they can be handed over by address below.
106 std::complex<double> alphaArg(alpha,0);
107 std::complex<double> betaArg(beta,0);
109 Order, TransA, TransB,
// Scalars and matrices are passed as pointers reinterpreted to the
// cblas_double_complex_type expected by the callee.
111 reinterpret_cast<cblas_double_complex_type const *>(&alphaArg),
112 reinterpret_cast<cblas_double_complex_type const *>(A), lda,
113 reinterpret_cast<cblas_double_complex_type const *>(B), ldb,
114 reinterpret_cast<cblas_double_complex_type const *>(&betaArg),
115 reinterpret_cast<cblas_double_complex_type *>(C), ldc
// Templated dense_gemm that reads the raw storage of A, B and C and calls a
// pointer-based dense_gemm overload with it.
// NOTE(review): the function name, the trailing parameter after alpha and
// three arguments of the inner call (listing lines 121, 126-127, 141, 143,
// 146) are absent from this view. Presumably this is the overload selected
// by a std::true_type tag -- TODO confirm against the full header.
120 template <
typename MatA,
typename MatB,
typename MatC>
122 matrix_expression<MatA, cpu_tag>
const& A,
123 matrix_expression<MatB, cpu_tag>
const& B,
124 matrix_expression<MatC, cpu_tag>& C,
125 typename MatC::value_type alpha,
// Only a row-major C is accepted; anything else fails at compile time.
128 static_assert(std::is_same<typename MatC::orientation,row_major>::value,
"C must be row major");
// An operand whose orientation differs from C's is flagged as transposed.
130 CBLAS_TRANSPOSE transA = std::is_same<typename MatA::orientation,typename MatC::orientation>::value?CblasNoTrans:CblasTrans;
131 CBLAS_TRANSPOSE transB = std::is_same<typename MatB::orientation,typename MatC::orientation>::value?CblasNoTrans:CblasTrans;
// m and n are taken from C, k from A().size2().
132 std::size_t m = C().size1();
133 std::size_t n = C().size2();
134 std::size_t k = A().size2();
// The CBLAS storage order is derived from C's orientation.
135 CBLAS_ORDER stor_ord = (CBLAS_ORDER) storage_order<typename MatC::orientation >::value;
137 auto storageA = A().raw_storage();
138 auto storageB = B().raw_storage();
139 auto storageC = C().raw_storage();
// The dimensions are narrowed from std::size_t to int for the call.
140 dense_gemm(stor_ord, transA, transB, (
int)m, (
int)n, (
int)k, alpha,
142 storageA.leading_dimension,
144 storageB.leading_dimension,
// A constant 1 of C's value type is passed here; judging by its position in
// the argument list this is presumably beta -- TODO confirm.
145 typename MatC::value_type(1),
147 storageC.leading_dimension
151 template <
typename MatA,
typename MatB,
typename MatC>
153 matrix_expression<MatA, cpu_tag>
const& A,
154 matrix_expression<MatB, cpu_tag>
const& B,
155 matrix_expression<MatC, cpu_tag>& C,
156 typename MatC::value_type alpha,
159 typedef typename MatC::value_type value_type;
160 std::size_t
const tile_size = 512;
161 static const std::size_t align = 64;
162 std::size_t size1 = C().size1();
163 std::size_t size2 = C().size2();
164 std::size_t num_blocks = (A().size2()+tile_size-1)/tile_size;
165 boost::alignment::aligned_allocator<value_type,align> allocator;
166 value_type* A_pointer = allocator.allocate(size1 * tile_size);
167 value_type* B_pointer = allocator.allocate(size2 * tile_size);
168 for(std::size_t k = 0; k != num_blocks; ++k){
169 std::size_t start_k = k * tile_size;
170 std::size_t current_size = std::min(tile_size,A().size2() - start_k);
171 dense_matrix_adaptor<value_type,row_major> A_block(A_pointer, size1, current_size);
172 dense_matrix_adaptor<value_type,row_major> B_block(B_pointer, current_size, size2);
173 matrix_range<MatA const> A_range(A(), 0, size1, start_k, start_k + current_size);
174 matrix_range<MatB const> B_range(B(), start_k, start_k + current_size, 0, size2);
175 noalias(A_block) = A_range;
176 noalias(B_block) = B_range;
177 dense_gemm(A_block, B_block, C, alpha, std::true_type());
179 allocator.deallocate(A_pointer, size1 * tile_size);
180 allocator.deallocate(B_pointer, size1 * tile_size);
// Compile-time predicate over three matrix types. The resulting
// integral_constant is true only when all of the following hold:
//  - allowed_cblas_type reports M1's value type as allowed,
//  - M2 and M3 have the same value type as M1,
//  - the storage tag of each of the three types derives from dense_tag.
// NOTE(review): the closing of the template argument list and of the struct
// (listing line 192) is absent from this view.
184 template<
class M1,
class M2,
class M3>
185 struct has_optimized_gemm: std::integral_constant<bool,
186 allowed_cblas_type<typename M1::value_type>::type::value
187 && std::is_same<typename M1::value_type, typename M2::value_type>::value
188 && std::is_same<typename M1::value_type, typename M3::value_type>::value
189 && std::is_base_of<dense_tag, typename M1::storage_type::storage_tag>::value
190 && std::is_base_of<dense_tag, typename M2::storage_type::storage_tag>::value
191 && std::is_base_of<dense_tag, typename M3::storage_type::storage_tag>::value
// Untagged entry point taking A, B, C and alpha: checks that the operand
// sizes agree and then builds a tag object from has_optimized_gemm.
// NOTE(review): the function name, the forwarding call that receives the tag
// (listing lines 195, 200, 204-205) and everything after line 206 are absent
// from this view. Presumably the tag selects between the tagged dense_gemm
// overloads -- TODO confirm against the full header.
194 template <
typename MatA,
typename MatB,
typename MatC>
196 matrix_expression<MatA, cpu_tag>
const& A,
197 matrix_expression<MatB, cpu_tag>
const& B,
198 matrix_expression<MatC, cpu_tag>& C,
199 typename MatC::value_type alpha
// size1 of A and size2 of B must match C; size2 of A must match size1 of B.
201 REMORA_SIZE_CHECK(A().size1() == C().size1());
202 REMORA_SIZE_CHECK(B().size2() == C().size2());
203 REMORA_SIZE_CHECK(A().size2()== B().size1());
// Default-constructed tag of the type exposed by has_optimized_gemm.
206 typename has_optimized_gemm<MatA,MatB,MatC>::type()