gemm.hpp
Go to the documentation of this file.
1 /*!
2  *
3  *
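4  * \brief Default CPU gemm kernels (m += alpha * e1 * e2) for products in which at least one operand is sparse.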
5  *
6  * \author O. Krause
7  * \date 2010
8  *
9  *
10  * \par Copyright 1995-2015 Shark Development Team
11  *
12  * <BR><HR>
13  * This file is part of Shark.
14  * <http://image.diku.dk/shark/>
15  *
16  * Shark is free software: you can redistribute it and/or modify
17  * it under the terms of the GNU Lesser General Public License as published
18  * by the Free Software Foundation, either version 3 of the License, or
19  * (at your option) any later version.
20  *
21  * Shark is distributed in the hope that it will be useful,
22  * but WITHOUT ANY WARRANTY; without even the implied warranty of
23  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24  * GNU Lesser General Public License for more details.
25  *
26  * You should have received a copy of the GNU Lesser General Public License
27  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
28  *
29  */
30 
31 #ifndef REMORA_KERNELS_DEFAULT_GEMM_HPP
32 #define REMORA_KERNELS_DEFAULT_GEMM_HPP
33 
34 #include "../gemv.hpp"//for dispatching to gemv
35 #include "../../assignment.hpp"//plus_assign
36 #include "../../vector.hpp"//sparse gemm needs temporary vector
37 #include "../../detail/matrix_proxy_classes.hpp"//matrix row,column,transpose,range
38 #include <type_traits> //std::false_type marker for unoptimized, std::common_type
39 
40 namespace remora{namespace bindings {
41 
42 
43 // Dense-Sparse gemm
44 template <class E1, class E2, class M, class Orientation>
45 void gemm(
46  matrix_expression<E1, cpu_tag> const& e1,
47  matrix_expression<E2, cpu_tag> const& e2,
48  matrix_expression<M, cpu_tag>& m,
49  typename M::value_type alpha,
50  row_major, row_major, Orientation,
51  dense_tag, sparse_tag
52 ){
53  for (std::size_t i = 0; i != e1().size1(); ++i) {
54  matrix_row<M> row_m(m(),i);
55  matrix_row<typename const_expression<E1>::type> row_e1(e1(),i);
56  matrix_transpose<typename const_expression<E2>::type> trans_e2(e2());
57  kernels::gemv(trans_e2,row_e1,row_m,alpha);
58  }
59 }
60 
61 template <class E1, class E2, class M>
62 void gemm(
63  matrix_expression<E1, cpu_tag> const& e1,
64  matrix_expression<E2, cpu_tag> const& e2,
65  matrix_expression<M, cpu_tag>& m,
66  typename M::value_type alpha,
67  row_major, column_major, column_major,
68  dense_tag, sparse_tag
69 ){
70  typedef matrix_transpose<M> Trans_M;
71  typedef matrix_transpose<typename const_expression<E2>::type> Trans_E2;
72  Trans_M trans_m(m());
73  Trans_E2 trans_e2(e2());
74  for (std::size_t j = 0; j != e2().size2(); ++j) {
75  matrix_row<Trans_M> column_m(trans_m,j);
76  matrix_row<Trans_E2> column_e2(trans_e2,j);
77  kernels::gemv(e1,column_e2,column_m,alpha);
78  }
79 }
80 
81 template <class E1, class E2, class M>
82 void gemm(
83  matrix_expression<E1, cpu_tag> const& e1,
84  matrix_expression<E2, cpu_tag> const& e2,
85  matrix_expression<M, cpu_tag>& m,
86  typename M::value_type alpha,
87  row_major, column_major, row_major,
88  dense_tag, sparse_tag
89 ){
90  for (std::size_t k = 0; k != e1().size2(); ++k) {
91  for(std::size_t i = 0; i != e1().size1(); ++i){
92  matrix_row<M> row_m(m(),i);
93  matrix_row<typename const_expression<E2>::type> row_e2(e2(),k);
94  plus_assign(row_m,row_e2,alpha * e1()(i,k));
95  }
96  }
97 }
98 
99 // Sparse-Dense gemm
100 template <class E1, class E2, class M, class Orientation>
101 void gemm(
102  matrix_expression<E1, cpu_tag> const& e1,
103  matrix_expression<E2, cpu_tag> const& e2,
104  matrix_expression<M, cpu_tag>& m,
105  typename M::value_type alpha,
106  row_major, row_major, Orientation,
107  sparse_tag, dense_tag
108 ){
109  for (std::size_t i = 0; i != e1().size1(); ++i) {
110  matrix_row<M> row_m(m(),i);
111  matrix_row<E1> row_e1(e1(),i);
112  matrix_transpose<E2> trans_e2(e2());
113  kernels::gemv(trans_e2,row_e1,row_m,alpha);
114  }
115 }
116 
117 template <class E1, class E2, class M>
118 void gemm(
119  matrix_expression<E1, cpu_tag> const& e1,
120  matrix_expression<E2, cpu_tag> const& e2,
121  matrix_expression<M, cpu_tag>& m,
122  typename M::value_type alpha,
123  row_major, column_major, column_major,
124  sparse_tag, dense_tag
125 ){
126  typedef matrix_transpose<M> Trans_M;
127  typedef matrix_transpose<typename const_expression<E2>::type> Trans_E2;
128  Trans_M trans_m(m());
129  Trans_E2 trans_e2(e2());
130  for (std::size_t j = 0; j != e2().size2(); ++j) {
131  matrix_row<Trans_M> column_m(trans_m,j);
132  matrix_row<Trans_E2> column_e2(trans_e2,j);
133  kernels::gemv(e1,column_e2,column_m,alpha);
134  }
135 }
136 
137 template <class E1, class E2, class M>
138 void gemm(
139  matrix_expression<E1, cpu_tag> const& e1,
140  matrix_expression<E2, cpu_tag> const& e2,
141  matrix_expression<M, cpu_tag>& m,
142  typename M::value_type alpha,
143  row_major, column_major, row_major,
144  sparse_tag, dense_tag
145 ){
146  for (std::size_t k = 0; k != e1().size2(); ++k) {
147  auto e1end = e1().column_end(k);
148  for(auto e1pos = e1().column_begin(k); e1pos != e1end; ++e1pos){
149  std::size_t i = e1pos.index();
150  matrix_row<M> row_m(m(),i);
151  matrix_row<typename const_expression<E2>::type> row_e2(e2(),k);
152  plus_assign(row_m,row_e2,alpha * (*e1pos));
153  }
154  }
155 }
156 
157 // Sparse-Sparse gemm
158 template<class M, class E1, class E2>
159 void gemm(
160  matrix_expression<E1, cpu_tag> const& e1,
161  matrix_expression<E2, cpu_tag> const& e2,
162  matrix_expression<M, cpu_tag>& m,
163  typename M::value_type alpha,
164  row_major, row_major, row_major,
165  sparse_tag, sparse_tag
166 ) {
167  typedef typename M::value_type value_type;
168  value_type zero = value_type();
169  vector<value_type> temporary(e2().size2(), zero);//dense vector for quick random access
170  matrix_transpose<typename const_expression<E2>::type> e2trans(e2());
171  for (std::size_t i = 0; i != e1().size1(); ++i) {
172  matrix_row<typename const_expression<E1>::type> rowe1(e1(),i);
173  kernels::gemv(e2trans,rowe1,temporary,alpha);
174  auto insert_pos = m().row_begin(i);
175  for (std::size_t j = 0; j != temporary.size(); ++ j) {
176  if (temporary(j) != zero) {
177  //find element with that index
178  auto row_end = m().row_end(i);
179  while(insert_pos != row_end && insert_pos.index() < j)
180  ++insert_pos;
181  //check if element exists
182  if(insert_pos != row_end && insert_pos.index() == j){
183  *insert_pos += temporary(j);
184  }else{//create new element
185  insert_pos = m().set_element(insert_pos,j,temporary(j));
186  }
187  //~ m()(i,j) += temporary(j);
188  temporary(j) = zero; // delete element
189  }
190  }
191  }
192 }
193 
194 template<class M, class E1, class E2>
195 void gemm(
196  matrix_expression<E1, cpu_tag> const& e1,
197  matrix_expression<E2, cpu_tag> const& e2,
198  matrix_expression<M, cpu_tag>& m,
199  typename M::value_type alpha,
200  row_major, row_major, column_major,
201  sparse_tag, sparse_tag
202 ) {
203  typedef matrix_transpose<M> Trans_M;
204  typedef matrix_transpose<typename const_expression<E2>::type> Trans_E2;
205  Trans_M trans_m(m());
206  Trans_E2 trans_e2(e2());
207  for (std::size_t j = 0; j != e2().size2(); ++j) {
208  matrix_row<Trans_M> column_m(trans_m,j);
209  matrix_row<Trans_E2> column_e2(trans_e2,j);
210  kernels::gemv(e1,column_e2,column_m,alpha);
211  }
212 }
213 
214 template <class E1, class E2, class M, class Orientation>
215 void gemm(
216  matrix_expression<E1, cpu_tag> const& e1,
217  matrix_expression<E2, cpu_tag> const& e2,
218  matrix_expression<M, cpu_tag>& m,
219  typename M::value_type alpha,
220  row_major, column_major, Orientation o,
221  sparse_tag t1, sparse_tag t2
222 ){
223  //best way to compute this is to transpose e1 in memory. alternative would be
224  // to compute outer products, which is a no-no.
225  typename transposed_matrix_temporary<E1>::type e1_trans(e1);
226  gemm(e1_trans,e2,m,alpha,row_major(),row_major(),o,t1,t2);
227 }
228 
229 }}
230 
231 #endif