00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef EIGEN_SPARSESPARSEPRODUCT_H
00026 #define EIGEN_SPARSESPARSEPRODUCT_H
00027
00028 namespace internal {
00029
00030 template<typename Lhs, typename Rhs, typename ResultType>
00031 static void sparse_product_impl2(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00032 {
00033 typedef typename remove_all<Lhs>::type::Scalar Scalar;
00034 typedef typename remove_all<Lhs>::type::Index Index;
00035
00036
00037 Index rows = lhs.innerSize();
00038 Index cols = rhs.outerSize();
00039 eigen_assert(lhs.outerSize() == rhs.innerSize());
00040
00041 std::vector<bool> mask(rows,false);
00042 Matrix<Scalar,Dynamic,1> values(rows);
00043 Matrix<Index,Dynamic,1> indices(rows);
00044
00045
00046 float ratioLhs = float(lhs.nonZeros())/(float(lhs.rows())*float(lhs.cols()));
00047 float avgNnzPerRhsColumn = float(rhs.nonZeros())/float(cols);
00048 float ratioRes = std::min(ratioLhs * avgNnzPerRhsColumn, 1.f);
00049
00050
00051
00052
00053 res.resize(rows, cols);
00054 res.reserve(Index(ratioRes*rows*cols));
00055
00056 for (Index j=0; j<cols; ++j)
00057 {
00058
00059 res.startVec(j);
00060 Index nnz = 0;
00061 for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
00062 {
00063 Scalar y = rhsIt.value();
00064 Index k = rhsIt.index();
00065 for (typename Lhs::InnerIterator lhsIt(lhs, k); lhsIt; ++lhsIt)
00066 {
00067 Index i = lhsIt.index();
00068 Scalar x = lhsIt.value();
00069 if(!mask[i])
00070 {
00071 mask[i] = true;
00072
00073
00074 ++nnz;
00075 }
00076 else
00077 values[i] += x * y;
00078 }
00079 }
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109 }
00110 res.finalize();
00111 }
00112
00113
00114 template<typename Lhs, typename Rhs, typename ResultType>
00115 static void sparse_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00116 {
00117
00118
00119 typedef typename remove_all<Lhs>::type::Scalar Scalar;
00120 typedef typename remove_all<Lhs>::type::Index Index;
00121
00122
00123 Index rows = lhs.innerSize();
00124 Index cols = rhs.outerSize();
00125
00126 eigen_assert(lhs.outerSize() == rhs.innerSize());
00127
00128
00129 AmbiVector<Scalar,Index> tempVector(rows);
00130
00131
00132 float ratioLhs = float(lhs.nonZeros())/(float(lhs.rows())*float(lhs.cols()));
00133 float avgNnzPerRhsColumn = float(rhs.nonZeros())/float(cols);
00134 float ratioRes = std::min(ratioLhs * avgNnzPerRhsColumn, 1.f);
00135
00136 res.resize(rows, cols);
00137 res.reserve(Index(ratioRes*rows*cols));
00138 for (Index j=0; j<cols; ++j)
00139 {
00140
00141
00142
00143 float ratioColRes = ratioRes;
00144 tempVector.init(ratioColRes);
00145 tempVector.setZero();
00146 for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
00147 {
00148
00149 tempVector.restart();
00150 Scalar x = rhsIt.value();
00151 for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
00152 {
00153 tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
00154 }
00155 }
00156 res.startVec(j);
00157 for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector); it; ++it)
00158 res.insertBackByOuterInner(j,it.index()) = it.value();
00159 }
00160 res.finalize();
00161 }
00162
00163 template<typename Lhs, typename Rhs, typename ResultType,
00164 int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
00165 int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
00166 int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
00167 struct sparse_product_selector;
00168
00169 template<typename Lhs, typename Rhs, typename ResultType>
00170 struct sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
00171 {
00172 typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
00173
00174 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00175 {
00176
00177 typename remove_all<ResultType>::type _res(res.rows(), res.cols());
00178 sparse_product_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res);
00179 res.swap(_res);
00180 }
00181 };
00182
00183 template<typename Lhs, typename Rhs, typename ResultType>
00184 struct sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
00185 {
00186 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00187 {
00188
00189
00190 typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
00191 SparseTemporaryType _res(res.rows(), res.cols());
00192 sparse_product_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res);
00193 res = _res;
00194 }
00195 };
00196
00197 template<typename Lhs, typename Rhs, typename ResultType>
00198 struct sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
00199 {
00200 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00201 {
00202
00203
00204 typename remove_all<ResultType>::type _res(res.rows(), res.cols());
00205 sparse_product_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res);
00206 res.swap(_res);
00207 }
00208 };
00209
00210 template<typename Lhs, typename Rhs, typename ResultType>
00211 struct sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
00212 {
00213 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00214 {
00215
00216 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00217 ColMajorMatrix colLhs(lhs);
00218 ColMajorMatrix colRhs(rhs);
00219
00220 sparse_product_impl<ColMajorMatrix,ColMajorMatrix,ResultType>(colLhs, colRhs, res);
00221
00222
00223
00224
00225
00226
00227
00228
00229 }
00230 };
00231
00232
00233
00234
00235 }
00236
00237
00238 template<typename Derived>
00239 template<typename Lhs, typename Rhs>
00240 inline Derived& SparseMatrixBase<Derived>::operator=(const SparseSparseProduct<Lhs,Rhs>& product)
00241 {
00242
00243 internal::sparse_product_selector<
00244 typename internal::remove_all<Lhs>::type,
00245 typename internal::remove_all<Rhs>::type,
00246 Derived>::run(product.lhs(),product.rhs(),derived());
00247 return derived();
00248 }
00249
00250 namespace internal {
00251
00252 template<typename Lhs, typename Rhs, typename ResultType,
00253 int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
00254 int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
00255 int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
00256 struct sparse_product_selector2;
00257
00258 template<typename Lhs, typename Rhs, typename ResultType>
00259 struct sparse_product_selector2<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
00260 {
00261 typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
00262
00263 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00264 {
00265 sparse_product_impl2<Lhs,Rhs,ResultType>(lhs, rhs, res);
00266 }
00267 };
00268
00269 template<typename Lhs, typename Rhs, typename ResultType>
00270 struct sparse_product_selector2<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
00271 {
00272 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00273 {
00274
00275 EIGEN_UNUSED_VARIABLE(lhs);
00276 EIGEN_UNUSED_VARIABLE(rhs);
00277 EIGEN_UNUSED_VARIABLE(res);
00278
00279
00280
00281
00282
00283
00284 }
00285 };
00286
00287 template<typename Lhs, typename Rhs, typename ResultType>
00288 struct sparse_product_selector2<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
00289 {
00290 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00291 {
00292 typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix;
00293 RowMajorMatrix lhsRow = lhs;
00294 RowMajorMatrix resRow(res.rows(), res.cols());
00295 sparse_product_impl2<Rhs,RowMajorMatrix,RowMajorMatrix>(rhs, lhsRow, resRow);
00296 res = resRow;
00297 }
00298 };
00299
00300 template<typename Lhs, typename Rhs, typename ResultType>
00301 struct sparse_product_selector2<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
00302 {
00303 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00304 {
00305 typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix;
00306 RowMajorMatrix resRow(res.rows(), res.cols());
00307 sparse_product_impl2<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow);
00308 res = resRow;
00309 }
00310 };
00311
00312
00313 template<typename Lhs, typename Rhs, typename ResultType>
00314 struct sparse_product_selector2<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
00315 {
00316 typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
00317
00318 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00319 {
00320 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00321 ColMajorMatrix resCol(res.rows(), res.cols());
00322 sparse_product_impl2<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol);
00323 res = resCol;
00324 }
00325 };
00326
00327 template<typename Lhs, typename Rhs, typename ResultType>
00328 struct sparse_product_selector2<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
00329 {
00330 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00331 {
00332 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00333 ColMajorMatrix lhsCol = lhs;
00334 ColMajorMatrix resCol(res.rows(), res.cols());
00335 sparse_product_impl2<ColMajorMatrix,Rhs,ColMajorMatrix>(lhsCol, rhs, resCol);
00336 res = resCol;
00337 }
00338 };
00339
00340 template<typename Lhs, typename Rhs, typename ResultType>
00341 struct sparse_product_selector2<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
00342 {
00343 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00344 {
00345 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00346 ColMajorMatrix rhsCol = rhs;
00347 ColMajorMatrix resCol(res.rows(), res.cols());
00348 sparse_product_impl2<Lhs,ColMajorMatrix,ColMajorMatrix>(lhs, rhsCol, resCol);
00349 res = resCol;
00350 }
00351 };
00352
00353 template<typename Lhs, typename Rhs, typename ResultType>
00354 struct sparse_product_selector2<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
00355 {
00356 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00357 {
00358 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00359
00360
00361
00362
00363
00364
00365 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00366 ColMajorMatrix lhsCol(lhs);
00367 ColMajorMatrix rhsCol(rhs);
00368 ColMajorMatrix resCol(res.rows(), res.cols());
00369 sparse_product_impl2<ColMajorMatrix,ColMajorMatrix,ColMajorMatrix>(lhsCol, rhsCol, resCol);
00370 res = resCol;
00371 }
00372 };
00373
00374 }
00375
00376 template<typename Derived>
00377 template<typename Lhs, typename Rhs>
00378 inline void SparseMatrixBase<Derived>::_experimentalNewProduct(const Lhs& lhs, const Rhs& rhs)
00379 {
00380
00381 internal::sparse_product_selector2<
00382 typename internal::remove_all<Lhs>::type,
00383 typename internal::remove_all<Rhs>::type,
00384 Derived>::run(lhs,rhs,derived());
00385 }
00386
00387
00388 template<typename Derived>
00389 template<typename OtherDerived>
00390 inline const typename SparseSparseProductReturnType<Derived,OtherDerived>::Type
00391 SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
00392 {
00393 return typename SparseSparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
00394 }
00395
00396 #endif // EIGEN_SPARSESPARSEPRODUCT_H