gemv.hpp
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief CBLAS bindings for the dense matrix-vector product (gemv).
6  *
7  * \author O. Krause
8  * \date 2010
9  *
10  *
11  * \par Copyright 1995-2015 Shark Development Team
12  *
13  * <BR><HR>
14  * This file is part of Shark.
15  * <http://image.diku.dk/shark/>
16  *
17  * Shark is free software: you can redistribute it and/or modify
18  * it under the terms of the GNU Lesser General Public License as published
19  * by the Free Software Foundation, either version 3 of the License, or
20  * (at your option) any later version.
21  *
22  * Shark is distributed in the hope that it will be useful,
23  * but WITHOUT ANY WARRANTY; without even the implied warranty of
24  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
25  * GNU Lesser General Public License for more details.
26  *
27  * You should have received a copy of the GNU Lesser General Public License
28  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
29  *
30  */
31 //===========================================================================
32 #ifndef REMORA_KERNELS_CBLAS_GEMV_HPP
33 #define REMORA_KERNELS_CBLAS_GEMV_HPP
34 
35 #include "cblas_inc.hpp"
36 #include <type_traits>
37 
38 namespace remora{namespace bindings {
39 
40 inline void gemv(CBLAS_ORDER const Order,
41  CBLAS_TRANSPOSE const TransA, int const M, int const N,
42  double alpha, float const *A, int const lda,
43  float const *X, int const incX,
44  double beta, float *Y, int const incY
45 ) {
46  cblas_sgemv(Order, TransA, M, N, alpha, A, lda,
47  X, incX,
48  beta, Y, incY);
49 }
50 
51 inline void gemv(CBLAS_ORDER const Order,
52  CBLAS_TRANSPOSE const TransA, int const M, int const N,
53  double alpha, double const *A, int const lda,
54  double const *X, int const incX,
55  double beta, double *Y, int const incY
56 ) {
57  cblas_dgemv(Order, TransA, M, N, alpha, A, lda,
58  X, incX,
59  beta, Y, incY);
60 }
61 
62 inline void gemv(CBLAS_ORDER const Order,
63  CBLAS_TRANSPOSE const TransA, int const M, int const N,
64  double alpha,
65  std::complex<float> const *A, int const lda,
66  std::complex<float> const *X, int const incX,
67  double beta,
68  std::complex<float> *Y, int const incY
69 ) {
70  std::complex<float> alphaArg(alpha,0);
71  std::complex<float> betaArg(beta,0);
72  cblas_cgemv(Order, TransA, M, N,
73  reinterpret_cast<cblas_float_complex_type const *>(&alphaArg),
74  reinterpret_cast<cblas_float_complex_type const *>(A), lda,
75  reinterpret_cast<cblas_float_complex_type const *>(X), incX,
76  reinterpret_cast<cblas_float_complex_type const *>(&betaArg),
77  reinterpret_cast<cblas_float_complex_type *>(Y), incY);
78 }
79 
80 inline void gemv(CBLAS_ORDER const Order,
81  CBLAS_TRANSPOSE const TransA, int const M, int const N,
82  double alpha,
83  std::complex<double> const *A, int const lda,
84  std::complex<double> const *X, int const incX,
85  double beta,
86  std::complex<double> *Y, int const incY
87 ) {
88  std::complex<double> alphaArg(alpha,0);
89  std::complex<double> betaArg(beta,0);
90  cblas_zgemv(Order, TransA, M, N,
91  reinterpret_cast<cblas_double_complex_type const *>(&alphaArg),
92  reinterpret_cast<cblas_double_complex_type const *>(A), lda,
93  reinterpret_cast<cblas_double_complex_type const *>(X), incX,
94  reinterpret_cast<cblas_double_complex_type const *>(&betaArg),
95  reinterpret_cast<cblas_double_complex_type *>(Y), incY);
96 }
97 
98 
99 // y <- alpha * op (A) * x + beta * y
100 // op (A) == A || A^T || A^H
101 template <typename MatA, typename VectorX, typename VectorY>
102 void gemv(
103  matrix_expression<MatA, cpu_tag> const &A,
104  vector_expression<VectorX, cpu_tag> const &x,
105  vector_expression<VectorY, cpu_tag> &y,
106  typename VectorY::value_type alpha,
107  std::true_type
108 ){
109  std::size_t m = A().size1();
110  std::size_t n = A().size2();
111 
112  REMORA_SIZE_CHECK(x().size() == A().size2());
113  REMORA_SIZE_CHECK(y().size() == A().size1());
114 
115  CBLAS_ORDER const stor_ord= (CBLAS_ORDER)storage_order<typename MatA::orientation>::value;
116 
117  auto storageA = A().raw_storage();
118  auto storagex = x().raw_storage();
119  auto storagey = y().raw_storage();
120  gemv(stor_ord, CblasNoTrans, (int)m, (int)n, alpha,
121  storageA.values,
122  storageA.leading_dimension,
123  storagex.values,
124  storagex.stride,
125  typename VectorY::value_type(1),
126  storagey.values,
127  storagey.stride
128  );
129 }
130 
131 template<class M, class V1, class V2>
132 struct has_optimized_gemv: std::integral_constant<bool,
133  allowed_cblas_type<typename M::value_type>::type::value
134  && std::is_same<typename M::value_type, typename V1::value_type>::value
135  && std::is_same<typename V1::value_type, typename V2::value_type>::value
136  && std::is_base_of<dense_tag, typename M::storage_type::storage_tag>::value
137  && std::is_base_of<dense_tag, typename V1::storage_type::storage_tag>::value
138  && std::is_base_of<dense_tag, typename V2::storage_type::storage_tag>::value
139 >{};
140 
141 }}
142 #endif