00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
00026 #define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
00027
00028 namespace internal {
00029
00030
00031
00032
00033
00034
00035
00036
00037
/** \internal
  * Forward declaration of the kernel that updates only the UpLo triangle of a
  * square diagonal block of the result. It is defined later in this file and is
  * used by general_matrix_matrix_triangular_product.
  */
template<typename LhsScalar,
         typename RhsScalar,
         typename Index,
         int mr,
         int nr,
         bool ConjLhs,
         bool ConjRhs,
         int UpLo>
struct tribb_kernel;
00040
00041
/** \internal
  * Primary template (declared only) for the product lhs * rhs restricted to the
  * UpLo triangle of the result. Only the partial specializations on
  * ResStorageOrder (RowMajor / ColMajor) below are defined.
  */
template <typename Index,
          typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
          typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
          int ResStorageOrder,
          int UpLo>
struct general_matrix_matrix_triangular_product;
00047
00048
00049 template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
00050 typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, int UpLo>
00051 struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,UpLo>
00052 {
00053 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
00054 static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* lhs, Index lhsStride,
00055 const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resStride, ResScalar alpha)
00056 {
00057 general_matrix_matrix_triangular_product<Index,
00058 RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
00059 LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
00060 ColMajor, UpLo==Lower?Upper:Lower>
00061 ::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha);
00062 }
00063 };
00064
00065 template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
00066 typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, int UpLo>
00067 struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,UpLo>
00068 {
00069 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
00070 static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* _lhs, Index lhsStride,
00071 const RhsScalar* _rhs, Index rhsStride, ResScalar* res, Index resStride, ResScalar alpha)
00072 {
00073 const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
00074 const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
00075
00076 typedef gebp_traits<LhsScalar,RhsScalar> Traits;
00077
00078 Index kc = depth;
00079 Index mc = size;
00080 Index nc = size;
00081 computeProductBlockingSizes<LhsScalar,RhsScalar>(kc, mc, nc);
00082
00083 if(mc > Traits::nr)
00084 mc = (mc/Traits::nr)*Traits::nr;
00085
00086 LhsScalar* blockA = ei_aligned_stack_new(LhsScalar, kc*mc);
00087 std::size_t sizeW = kc*Traits::WorkSpaceFactor;
00088 std::size_t sizeB = sizeW + kc*size;
00089 RhsScalar* allocatedBlockB = ei_aligned_stack_new(RhsScalar, sizeB);
00090 RhsScalar* blockB = allocatedBlockB + sizeW;
00091
00092 gemm_pack_lhs<LhsScalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
00093 gemm_pack_rhs<RhsScalar, Index, Traits::nr, RhsStorageOrder> pack_rhs;
00094 gebp_kernel <LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
00095 tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs, UpLo> sybb;
00096
00097 for(Index k2=0; k2<depth; k2+=kc)
00098 {
00099 const Index actual_kc = std::min(k2+kc,depth)-k2;
00100
00101
00102 pack_rhs(blockB, &rhs(k2,0), rhsStride, actual_kc, size);
00103
00104 for(Index i2=0; i2<size; i2+=mc)
00105 {
00106 const Index actual_mc = std::min(i2+mc,size)-i2;
00107
00108 pack_lhs(blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc);
00109
00110
00111
00112
00113
00114 if (UpLo==Lower)
00115 gebp(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, std::min(size,i2), alpha,
00116 -1, -1, 0, 0, allocatedBlockB);
00117
00118 sybb(res+resStride*i2 + i2, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, alpha, allocatedBlockB);
00119
00120 if (UpLo==Upper)
00121 {
00122 Index j2 = i2+actual_mc;
00123 gebp(res+resStride*j2+i2, resStride, blockA, blockB+actual_kc*j2, actual_mc, actual_kc, std::max(Index(0), size-j2), alpha,
00124 -1, -1, 0, 0, allocatedBlockB);
00125 }
00126 }
00127 }
00128 ei_aligned_stack_delete(LhsScalar, blockA, kc*mc);
00129 ei_aligned_stack_delete(RhsScalar, allocatedBlockB, sizeB);
00130 }
00131 };
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141
00142 template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int UpLo>
00143 struct tribb_kernel
00144 {
00145 typedef gebp_traits<LhsScalar,RhsScalar,ConjLhs,ConjRhs> Traits;
00146 typedef typename Traits::ResScalar ResScalar;
00147
00148 enum {
00149 BlockSize = EIGEN_PLAIN_ENUM_MAX(mr,nr)
00150 };
00151 void operator()(ResScalar* res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, ResScalar alpha, RhsScalar* workspace)
00152 {
00153 gebp_kernel<LhsScalar, RhsScalar, Index, mr, nr, ConjLhs, ConjRhs> gebp_kernel;
00154 Matrix<ResScalar,BlockSize,BlockSize,ColMajor> buffer;
00155
00156
00157
00158 for (Index j=0; j<size; j+=BlockSize)
00159 {
00160 Index actualBlockSize = std::min<Index>(BlockSize,size - j);
00161 const RhsScalar* actual_b = blockB+j*depth;
00162
00163 if(UpLo==Upper)
00164 gebp_kernel(res+j*resStride, resStride, blockA, actual_b, j, depth, actualBlockSize, alpha,
00165 -1, -1, 0, 0, workspace);
00166
00167
00168 {
00169 Index i = j;
00170 buffer.setZero();
00171
00172 gebp_kernel(buffer.data(), BlockSize, blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize, alpha,
00173 -1, -1, 0, 0, workspace);
00174
00175 for(Index j1=0; j1<actualBlockSize; ++j1)
00176 {
00177 ResScalar* r = res + (j+j1)*resStride + i;
00178 for(Index i1=UpLo==Lower ? j1 : 0;
00179 UpLo==Lower ? i1<actualBlockSize : i1<=j1; ++i1)
00180 r[i1] += buffer(i1,j1);
00181 }
00182 }
00183
00184 if(UpLo==Lower)
00185 {
00186 Index i = j+actualBlockSize;
00187 gebp_kernel(res+j*resStride+i, resStride, blockA+depth*i, actual_b, size-i, depth, actualBlockSize, alpha,
00188 -1, -1, 0, 0, workspace);
00189 }
00190 }
00191 }
00192 };
00193
00194 }
00195
00196
00197
00198 template<typename MatrixType, unsigned int UpLo>
00199 template<typename ProductDerived, typename _Lhs, typename _Rhs>
00200 TriangularView<MatrixType,UpLo>& TriangularView<MatrixType,UpLo>::assignProduct(const ProductBase<ProductDerived, _Lhs,_Rhs>& prod, const Scalar& alpha)
00201 {
00202 typedef typename internal::remove_all<typename ProductDerived::LhsNested>::type Lhs;
00203 typedef internal::blas_traits<Lhs> LhsBlasTraits;
00204 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
00205 typedef typename internal::remove_all<ActualLhs>::type _ActualLhs;
00206 const ActualLhs actualLhs = LhsBlasTraits::extract(prod.lhs());
00207
00208 typedef typename internal::remove_all<typename ProductDerived::RhsNested>::type Rhs;
00209 typedef internal::blas_traits<Rhs> RhsBlasTraits;
00210 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
00211 typedef typename internal::remove_all<ActualRhs>::type _ActualRhs;
00212 const ActualRhs actualRhs = RhsBlasTraits::extract(prod.rhs());
00213
00214 typename ProductDerived::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
00215
00216 internal::general_matrix_matrix_triangular_product<Index,
00217 typename Lhs::Scalar, _ActualLhs::Flags&RowMajorBit ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
00218 typename Rhs::Scalar, _ActualRhs::Flags&RowMajorBit ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
00219 MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo>
00220 ::run(m_matrix.cols(), actualLhs.cols(),
00221 &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(),
00222 const_cast<Scalar*>(m_matrix.data()), m_matrix.outerStride(), actualAlpha);
00223
00224 return *this;
00225 }
00226
00227 #endif // EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H