31 #ifndef REMORA_KERNELS_DEFAULT_GEMM_HPP 32 #define REMORA_KERNELS_DEFAULT_GEMM_HPP 34 #include "../gemv.hpp" 35 #include "../../assignment.hpp" 36 #include "../../vector.hpp" 37 #include "../../detail/matrix_proxy_classes.hpp" 38 #include <type_traits> 40 namespace remora{
namespace bindings {
// Matrix-matrix product kernel overload that works one result row at a time.
// The three trailing unnamed parameters (row_major, row_major, Orientation) are
// type-only dispatch tags; presumably they describe the storage orientation of
// (m, e1, e2) — TODO confirm against the dispatcher.
// NOTE(review): the declarator line, any further tag parameters and the closing
// braces are not visible in this excerpt — confirm against the full header.
//
// e1, e2 : read-only operand expressions.
// m      : non-const result expression; its rows are handed to kernels::gemv
//          as the third argument.
// alpha  : scalar of type M::value_type, forwarded unchanged to kernels::gemv.
44 template <
class E1,
class E2,
class M,
class Orientation>
46 matrix_expression<E1, cpu_tag>
const& e1,
47 matrix_expression<E2, cpu_tag>
const& e2,
48 matrix_expression<M, cpu_tag>& m,
49 typename M::value_type alpha,
50 row_major, row_major, Orientation,
// One kernels::gemv call per row i of e1.
53 for (std::size_t i = 0; i != e1().size1(); ++i) {
// Proxy onto row i of the result.
54 matrix_row<M> row_m(m(),i);
// Proxy onto row i of the left operand.
55 matrix_row<typename const_expression<E1>::type> row_e1(e1(),i);
// Transposed view of the right operand.
// NOTE(review): this view does not depend on i, yet it is rebuilt every iteration.
56 matrix_transpose<typename const_expression<E2>::type> trans_e2(e2());
// Presumably accumulates alpha * trans_e2 * row_e1 into row_m — confirm in ../gemv.hpp.
57 kernels::gemv(trans_e2,row_e1,row_m,alpha);
// Matrix-matrix product kernel overload that works one result column at a time.
// Trailing unnamed dispatch tags: (row_major, column_major, column_major).
// NOTE(review): the declarator line, any further tag parameters and the closing
// braces are not visible in this excerpt — confirm against the full header.
//
// e1, e2 : read-only operand expressions.
// m      : non-const result expression.
// alpha  : scalar of type M::value_type, forwarded unchanged to kernels::gemv.
61 template <
class E1,
class E2,
class M>
63 matrix_expression<E1, cpu_tag>
const& e1,
64 matrix_expression<E2, cpu_tag>
const& e2,
65 matrix_expression<M, cpu_tag>& m,
66 typename M::value_type alpha,
67 row_major, column_major, column_major,
// Transposed view types: a "row" of a transposed view is used below as a column
// of the underlying matrix (hence the column_* variable names).
70 typedef matrix_transpose<M> Trans_M;
71 typedef matrix_transpose<typename const_expression<E2>::type> Trans_E2;
// NOTE(review): `trans_m` is used below but no declaration of it is visible in
// this excerpt; other overloads in this file construct it as
// `Trans_M trans_m(m());` — verify the line was not lost here.
73 Trans_E2 trans_e2(e2());
// One kernels::gemv call per column j of e2.
74 for (std::size_t j = 0; j != e2().size2(); ++j) {
// Row j of the transposed result view.
75 matrix_row<Trans_M> column_m(trans_m,j);
// Row j of the transposed right-operand view.
76 matrix_row<Trans_E2> column_e2(trans_e2,j);
// Presumably accumulates alpha * e1 * column_e2 into column_m — confirm in ../gemv.hpp.
77 kernels::gemv(e1,column_e2,column_m,alpha);
// Matrix-matrix product kernel overload built from scaled row updates:
// for every entry (i,k) of e1, row k of e2 is combined into row i of m.
// Trailing unnamed dispatch tags: (row_major, column_major, row_major).
// NOTE(review): the declarator line, any further tag parameters and the closing
// braces are not visible in this excerpt — confirm against the full header.
//
// e1, e2 : read-only operand expressions.
// m      : non-const result expression, accessed through row proxies.
// alpha  : scalar of type M::value_type, multiplied into each e1 entry.
81 template <
class E1,
class E2,
class M>
83 matrix_expression<E1, cpu_tag>
const& e1,
84 matrix_expression<E2, cpu_tag>
const& e2,
85 matrix_expression<M, cpu_tag>& m,
86 typename M::value_type alpha,
87 row_major, column_major, row_major,
// Outer loop over the columns k of e1, inner loop over its rows i.
90 for (std::size_t k = 0; k != e1().size2(); ++k) {
91 for(std::size_t i = 0; i != e1().size1(); ++i){
// Proxy onto row i of the result.
92 matrix_row<M> row_m(m(),i);
// Proxy onto row k of e2 (depends only on k, but is rebuilt for every i).
93 matrix_row<typename const_expression<E2>::type> row_e2(e2(),k);
// Presumably row_m += (alpha * e1(i,k)) * row_e2 — confirm plus_assign in assignment.hpp.
94 plus_assign(row_m,row_e2,alpha * e1()(i,k));
// Matrix-matrix product kernel overload for the (sparse_tag, dense_tag) case,
// working one result row at a time.
// Trailing unnamed dispatch tags: (row_major, row_major, Orientation) followed by
// (sparse_tag, dense_tag); which operand each of the last two tags refers to is
// not shown here — TODO confirm against the dispatcher.
// NOTE(review): the declarator line and the closing braces are not visible in
// this excerpt — confirm against the full header.
//
// e1, e2 : read-only operand expressions.
// m      : non-const result expression; its rows are handed to kernels::gemv
//          as the third argument.
// alpha  : scalar of type M::value_type, forwarded unchanged to kernels::gemv.
100 template <
class E1,
class E2,
class M,
class Orientation>
102 matrix_expression<E1, cpu_tag>
const& e1,
103 matrix_expression<E2, cpu_tag>
const& e2,
104 matrix_expression<M, cpu_tag>& m,
105 typename M::value_type alpha,
106 row_major, row_major, Orientation,
107 sparse_tag, dense_tag
// One kernels::gemv call per row i of e1.
109 for (std::size_t i = 0; i != e1().size1(); ++i) {
// Proxy onto row i of the result.
110 matrix_row<M> row_m(m(),i);
// NOTE(review): the proxies below are instantiated on E1 / E2 directly, whereas
// other overloads in this file use `typename const_expression<...>::type`
// for const operands — check whether the difference is intentional.
111 matrix_row<E1> row_e1(e1(),i);
// Transposed view of e2; independent of i but rebuilt every iteration.
112 matrix_transpose<E2> trans_e2(e2());
// Presumably accumulates alpha * trans_e2 * row_e1 into row_m — confirm in ../gemv.hpp.
113 kernels::gemv(trans_e2,row_e1,row_m,alpha);
// Matrix-matrix product kernel overload for the (sparse_tag, dense_tag) case,
// working one result column at a time.
// Trailing unnamed dispatch tags: (row_major, column_major, column_major) followed
// by (sparse_tag, dense_tag).
// NOTE(review): the declarator line and the closing braces are not visible in
// this excerpt — confirm against the full header.
//
// e1, e2 : read-only operand expressions.
// m      : non-const result expression, accessed through a transposed view.
// alpha  : scalar of type M::value_type, forwarded unchanged to kernels::gemv.
117 template <
class E1,
class E2,
class M>
119 matrix_expression<E1, cpu_tag>
const& e1,
120 matrix_expression<E2, cpu_tag>
const& e2,
121 matrix_expression<M, cpu_tag>& m,
122 typename M::value_type alpha,
123 row_major, column_major, column_major,
124 sparse_tag, dense_tag
// Transposed view types: a "row" of a transposed view is used below as a column
// of the underlying matrix (hence the column_* variable names).
126 typedef matrix_transpose<M> Trans_M;
127 typedef matrix_transpose<typename const_expression<E2>::type> Trans_E2;
// Both views are built once, outside the loop.
128 Trans_M trans_m(m());
129 Trans_E2 trans_e2(e2());
// One kernels::gemv call per column j of e2.
130 for (std::size_t j = 0; j != e2().size2(); ++j) {
// Row j of the transposed result view.
131 matrix_row<Trans_M> column_m(trans_m,j);
// Row j of the transposed right-operand view.
132 matrix_row<Trans_E2> column_e2(trans_e2,j);
// Presumably accumulates alpha * e1 * column_e2 into column_m — confirm in ../gemv.hpp.
133 kernels::gemv(e1,column_e2,column_m,alpha);
// Matrix-matrix product kernel overload for the (sparse_tag, dense_tag) case that
// walks e1 column by column through its column iterator and applies one scaled
// row update per visited entry.
// Trailing unnamed dispatch tags: (row_major, column_major, row_major) followed by
// (sparse_tag, dense_tag).
// NOTE(review): the declarator line and the closing braces are not visible in
// this excerpt — confirm against the full header.
//
// e1, e2 : read-only operand expressions; e1 must provide column_begin/column_end.
// m      : non-const result expression, accessed through row proxies.
// alpha  : scalar of type M::value_type, multiplied into each visited e1 entry.
137 template <
class E1,
class E2,
class M>
139 matrix_expression<E1, cpu_tag>
const& e1,
140 matrix_expression<E2, cpu_tag>
const& e2,
141 matrix_expression<M, cpu_tag>& m,
142 typename M::value_type alpha,
143 row_major, column_major, row_major,
144 sparse_tag, dense_tag
// Loop over the columns k of e1.
146 for (std::size_t k = 0; k != e1().size2(); ++k) {
// End iterator of column k, fetched once per column.
147 auto e1end = e1().column_end(k);
// Visit the entries that the column iterator exposes (presumably only the
// stored, i.e. non-zero, entries of a sparse e1 — TODO confirm).
148 for(
auto e1pos = e1().column_begin(k); e1pos != e1end; ++e1pos){
// Row index of the current entry within column k.
149 std::size_t i = e1pos.index();
// Proxy onto row i of the result.
150 matrix_row<M> row_m(m(),i);
// Proxy onto row k of e2.
151 matrix_row<typename const_expression<E2>::type> row_e2(e2(),k);
// Presumably row_m += (alpha * e1(i,k)) * row_e2 — confirm plus_assign in assignment.hpp.
152 plus_assign(row_m,row_e2,alpha * (*e1pos));
// Matrix-matrix product kernel overload for the (sparse_tag, sparse_tag) case with
// all three orientation tags row_major. Each result row is first computed into a
// scratch vector and then merged into row i of m element by element.
// NOTE(review): the declarator line, the body of the `while` loop, the `else`
// introducer before the set_element call, any statements after it and the closing
// braces are not visible in this excerpt — confirm against the full header.
//
// e1, e2 : read-only operand expressions.
// m      : non-const result expression; must provide row_begin/row_end/set_element.
// alpha  : scalar of type M::value_type, forwarded unchanged to kernels::gemv.
158 template<
class M,
class E1,
class E2>
160 matrix_expression<E1, cpu_tag>
const& e1,
161 matrix_expression<E2, cpu_tag>
const& e2,
162 matrix_expression<M, cpu_tag>& m,
163 typename M::value_type alpha,
164 row_major, row_major, row_major,
165 sparse_tag, sparse_tag
167 typedef typename M::value_type value_type;
// Value-initialised element used both as fill value and as the "is zero" reference.
168 value_type zero = value_type();
// Scratch vector with one slot per column of e2, filled with `zero`.
169 vector<value_type> temporary(e2().size2(), zero);
// Transposed view of e2, built once outside the row loop.
170 matrix_transpose<typename const_expression<E2>::type> e2trans(e2());
// One pass per row i of e1.
171 for (std::size_t i = 0; i != e1().size1(); ++i) {
// Proxy onto row i of the left operand.
172 matrix_row<typename const_expression<E1>::type> rowe1(e1(),i);
// Presumably accumulates alpha * e2trans * rowe1 into `temporary` — confirm in ../gemv.hpp.
173 kernels::gemv(e2trans,rowe1,temporary,alpha);
// Cursor into row i of m, carried across the j loop.
174 auto insert_pos = m().row_begin(i);
175 for (std::size_t j = 0; j != temporary.size(); ++ j) {
// Only entries that differ from `zero` are merged into m.
176 if (temporary(j) != zero) {
// row_end is re-queried for every merged entry.
178 auto row_end = m().row_end(i);
// Scan condition: stop at the end of the row or at the first stored index >= j.
// (The statement executed by this loop is not visible in this excerpt.)
179 while(insert_pos != row_end && insert_pos.index() < j)
// An element with index j already exists in row i: add to it in place.
182 if(insert_pos != row_end && insert_pos.index() == j){
183 *insert_pos += temporary(j);
// Otherwise (presumably the else branch) store the value at index j through
// set_element and continue from the position it returns.
185 insert_pos = m().set_element(insert_pos,j,temporary(j));
// Matrix-matrix product kernel overload for the (sparse_tag, sparse_tag) case,
// working one result column at a time.
// Trailing unnamed dispatch tags: (row_major, row_major, column_major) followed by
// (sparse_tag, sparse_tag).
// NOTE(review): the declarator line and the closing braces are not visible in
// this excerpt — confirm against the full header.
//
// e1, e2 : read-only operand expressions.
// m      : non-const result expression, accessed through a transposed view.
// alpha  : scalar of type M::value_type, forwarded unchanged to kernels::gemv.
194 template<
class M,
class E1,
class E2>
196 matrix_expression<E1, cpu_tag>
const& e1,
197 matrix_expression<E2, cpu_tag>
const& e2,
198 matrix_expression<M, cpu_tag>& m,
199 typename M::value_type alpha,
200 row_major, row_major, column_major,
201 sparse_tag, sparse_tag
// Transposed view types: a "row" of a transposed view is used below as a column
// of the underlying matrix (hence the column_* variable names).
203 typedef matrix_transpose<M> Trans_M;
204 typedef matrix_transpose<typename const_expression<E2>::type> Trans_E2;
// Both views are built once, outside the loop.
205 Trans_M trans_m(m());
206 Trans_E2 trans_e2(e2());
// One kernels::gemv call per column j of e2.
207 for (std::size_t j = 0; j != e2().size2(); ++j) {
// Row j of the transposed result view.
208 matrix_row<Trans_M> column_m(trans_m,j);
// Row j of the transposed right-operand view.
209 matrix_row<Trans_E2> column_e2(trans_e2,j);
// Presumably accumulates alpha * e1 * column_e2 into column_m — confirm in ../gemv.hpp.
210 kernels::gemv(e1,column_e2,column_m,alpha);
// Forwarding overload for the (sparse_tag, sparse_tag) case whose second
// orientation tag is column_major: it builds a temporary from e1 and re-dispatches
// to gemm with row_major() for the first two orientation tags.
// NOTE(review): the declarator line and the closing brace are not visible in this
// excerpt — confirm against the full header.
//
// e1, e2 : read-only operand expressions.
// m      : non-const result expression, passed through to the forwarded call.
// alpha  : scalar of type M::value_type, passed through unchanged.
// o, t1, t2 : tag objects passed through unchanged to the forwarded call.
214 template <
class E1,
class E2,
class M,
class Orientation>
216 matrix_expression<E1, cpu_tag>
const& e1,
217 matrix_expression<E2, cpu_tag>
const& e2,
218 matrix_expression<M, cpu_tag>& m,
219 typename M::value_type alpha,
220 row_major, column_major, Orientation o,
221 sparse_tag t1, sparse_tag t2
// Temporary constructed from e1; presumably a copy of e1 in the opposite
// (row-major) storage orientation — TODO confirm transposed_matrix_temporary.
225 typename transposed_matrix_temporary<E1>::type e1_trans(e1);
// Re-dispatch with the temporary in place of e1 and row_major() as the second tag.
226 gemm(e1_trans,e2,m,alpha,row_major(),row_major(),o,t1,t2);