00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_H
00026 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_H
00027
00028 namespace internal {
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053 // Kernel for products in which one operand is triangular (its triangular flags are carried by Mode)
00054 // and the other one is a general dense matrix. Declaration only: the partial specializations on
00055 // ResStorageOrder (RowMajor / ColMajor) and on LhsIsTriangular provide the static run() entry point.
00056 // LhsIsTriangular tells which of the two operands is the triangular one.
00057 template <typename Scalar, typename Index,
00058 int Mode, bool LhsIsTriangular,
00059 int LhsStorageOrder, bool ConjugateLhs,
00060 int RhsStorageOrder, bool ConjugateRhs,
00061 int ResStorageOrder>
00062 struct product_triangular_matrix_matrix;
00063 // Row-major result: forwards to the column-major kernel with the two operands exchanged.
00064 template <typename Scalar, typename Index, int Mode, bool LhsIsTriangular,
00065           int LhsStorageOrder, bool ConjugateLhs,
00066           int RhsStorageOrder, bool ConjugateRhs>
00067 struct product_triangular_matrix_matrix<Scalar, Index, Mode, LhsIsTriangular,
00068                                         LhsStorageOrder, ConjugateLhs,
00069                                         RhsStorageOrder, ConjugateRhs, RowMajor>
00070 {
00071   // Calls the ColMajor-result run() with rows/cols and lhs/rhs (pointer and stride) swapped.
00072   static EIGEN_STRONG_INLINE void run(Index rows, Index cols, Index depth,
00073                                       const Scalar* lhs, Index lhsStride,
00074                                       const Scalar* rhs, Index rhsStride,
00075                                       Scalar* res, Index resStride, Scalar alpha)
00076   {
00077     enum {
00078       // same diagonal flags, Upper and Lower exchanged
00079       SwappedMode = (Mode&(UnitDiag|ZeroDiag)) | ((Mode&Upper) ? Lower : Upper),
00080       // storage order of each operand flipped
00081       FlippedLhsOrder = LhsStorageOrder==RowMajor ? ColMajor : RowMajor,
00082       FlippedRhsOrder = RhsStorageOrder==RowMajor ? ColMajor : RowMajor
00083     };
00084     typedef product_triangular_matrix_matrix<Scalar, Index, SwappedMode, (!LhsIsTriangular),
00085               FlippedRhsOrder, ConjugateRhs, FlippedLhsOrder, ConjugateLhs, ColMajor> SwappedKernel;
00086
00087     SwappedKernel::run(cols, rows, depth, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha);
00088   }
00089 };
00090 // Specialization for a triangular lhs (LhsIsTriangular == true) and a column-major result.
00091 // NOTE(review): presumably accumulates alpha * triangular(lhs) * rhs into res; the accumulation happens inside gebp_kernel, defined elsewhere - confirm.
00092 template <typename Scalar, typename Index, int Mode,
00093 int LhsStorageOrder, bool ConjugateLhs,
00094 int RhsStorageOrder, bool ConjugateRhs>
00095 struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
00096 LhsStorageOrder,ConjugateLhs,
00097 RhsStorageOrder,ConjugateRhs,ColMajor>
00098 {
00099 // lhs is indexed as rows x depth (_lhs, lhsStride), rhs as depth x cols (_rhs, rhsStride); res/resStride is the destination, alpha the scaling factor.
00100 static EIGEN_DONT_INLINE void run(
00101 Index rows, Index cols, Index depth,
00102 const Scalar* _lhs, Index lhsStride,
00103 const Scalar* _rhs, Index rhsStride,
00104 Scalar* res, Index resStride,
00105 Scalar alpha)
00106 {
00107 const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
00108 const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
00109
00110 typedef gebp_traits<Scalar,Scalar> Traits;
00111 enum {
00112 SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr), // edge length of the micro triangular blocks
00113 IsLower = (Mode&Lower) == Lower,
00114 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1 // 1 when the diagonal has to be read from lhs
00115 };
00116 // Block sizes start at the full problem dimensions; presumably adjusted in place by computeProductBlockingSizes (defined elsewhere).
00117 Index kc = depth;
00118 Index mc = rows;
00119 Index nc = cols;
00120 computeProductBlockingSizes<Scalar,Scalar,4>(kc, mc, nc);
00121 // Work buffers: blockA holds kc*mc scalars; blockB starts sizeW scalars into an allocation of sizeW + kc*cols scalars.
00122 Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
00123 std::size_t sizeW = kc*Traits::WorkSpaceFactor;
00124 std::size_t sizeB = sizeW + kc*cols;
00125 Scalar* allocatedBlockB = ei_aligned_stack_new(Scalar, sizeB);
00126 Scalar* blockB = allocatedBlockB + sizeW;
00127 // Scratch square block: zero everywhere, with a diagonal of zeros (ZeroDiag) or ones (otherwise).
00128 Matrix<Scalar,SmallPanelWidth,SmallPanelWidth,LhsStorageOrder> triangularBuffer;
00129 triangularBuffer.setZero();
00130 if((Mode&ZeroDiag)==ZeroDiag)
00131 triangularBuffer.diagonal().setZero();
00132 else
00133 triangularBuffer.diagonal().setOnes();
00134 // Kernel and packing functors.
00135 gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
00136 gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
00137 gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs;
00138 // Walk the depth dimension in chunks of kc: from depth downwards if IsLower, from 0 upwards otherwise.
00139 for(Index k2=IsLower ? depth : 0;
00140 IsLower ? k2>0 : k2<depth;
00141 IsLower ? k2-=kc : k2+=kc)
00142 {
00143 Index actual_kc = std::min(IsLower ? k2 : depth-k2, kc);
00144 Index actual_k2 = IsLower ? k2-actual_kc : k2;
00145 // Upper case only: if the chunk straddles index `rows`, shorten it so that it ends exactly at `rows`,
00146 // and shift k2 so that the loop increment k2+=kc resumes at `rows`.
00147 if((!IsLower)&&(k2<rows)&&(k2+actual_kc>rows))
00148 {
00149 actual_kc = rows-k2;
00150 k2 = k2+actual_kc-kc;
00151 }
00152 // Pack the actual_kc x cols rhs panel that starts at row actual_k2 into blockB.
00153 pack_rhs(blockB, &rhs(actual_k2,0), rhsStride, actual_kc, cols);
00154 // The lhs panel of this chunk is processed in two parts:
00155 //  - the diagonal part, handled below in micro panels of at most SmallPanelWidth columns;
00156 //  - the dense part (rows k2..rows if IsLower, rows 0..min(actual_k2,rows) otherwise), handled last.
00157 // Rows outside those ranges are not passed to the kernel for this chunk.
00158
00159 // Diagonal part (skipped in the upper case once actual_k2 is at or beyond `rows`):
00160 if(IsLower || actual_k2<rows)
00161 {
00162 // for each micro panel of at most SmallPanelWidth columns
00163 for (Index k1=0; k1<actual_kc; k1+=SmallPanelWidth)
00164 {
00165 Index actualPanelWidth = std::min<Index>(actual_kc-k1, SmallPanelWidth);
00166 Index lengthTarget = IsLower ? actual_kc-k1-actualPanelWidth : k1;
00167 Index startBlock = actual_k2+k1;
00168 Index blockBOffset = k1;
00169
00170 // Copy the triangular half of the micro block (plus its diagonal if SetDiag) into triangularBuffer.
00171 // The other half is never written here, so it keeps the zeros from setZero();
00172 // the buffer is then packed like a dense actualPanelWidth x actualPanelWidth block.
00173 for (Index k=0;k<actualPanelWidth;++k)
00174 {
00175 if (SetDiag)
00176 triangularBuffer.coeffRef(k,k) = lhs(startBlock+k,startBlock+k);
00177 for (Index i=IsLower ? k+1 : 0; IsLower ? i<actualPanelWidth : i<k; ++i)
00178 triangularBuffer.coeffRef(i,k) = lhs(startBlock+i,startBlock+k);
00179 }
00180 pack_lhs(blockA, triangularBuffer.data(), triangularBuffer.outerStride(), actualPanelWidth, actualPanelWidth);
00181
00182 gebp_kernel(res+startBlock, resStride, blockA, blockB, actualPanelWidth, actualPanelWidth, cols, alpha,
00183 actualPanelWidth, actual_kc, 0, blockBOffset);
00184
00185 // Dense rows of the same micro panel inside this chunk: below the micro block if IsLower, above it otherwise.
00186 if (lengthTarget>0)
00187 {
00188 Index startTarget = IsLower ? actual_k2+k1+actualPanelWidth : actual_k2;
00189
00190 pack_lhs(blockA, &lhs(startTarget,startBlock), lhsStride, actualPanelWidth, lengthTarget);
00191
00192 gebp_kernel(res+startTarget, resStride, blockA, blockB, lengthTarget, actualPanelWidth, cols, alpha,
00193 actualPanelWidth, actual_kc, 0, blockBOffset);
00194 }
00195 }
00196 }
00197 // Dense part of the lhs panel, processed in blocks of at most mc rows.
00198 {
00199 Index start = IsLower ? k2 : 0;
00200 Index end = IsLower ? rows : std::min(actual_k2,rows);
00201 for(Index i2=start; i2<end; i2+=mc)
00202 {
00203 const Index actual_mc = std::min(i2+mc,end)-i2;
00204 gemm_pack_lhs<Scalar, Index, Traits::mr,Traits::LhsProgress, LhsStorageOrder,false>()
00205 (blockA, &lhs(i2, actual_k2), lhsStride, actual_kc, actual_mc);
00206
00207 gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha);
00208 }
00209 }
00210 }
00211 // Release the work buffers.
00212 ei_aligned_stack_delete(Scalar, blockA, kc*mc);
00213 ei_aligned_stack_delete(Scalar, allocatedBlockB, sizeB);
00214
00215 }
00216 };
00217 // Specialization for a triangular rhs (LhsIsTriangular == false) and a column-major result.
00218 // NOTE(review): presumably accumulates alpha * lhs * triangular(rhs) into res; the accumulation happens inside gebp_kernel, defined elsewhere - confirm.
00219 template <typename Scalar, typename Index, int Mode,
00220 int LhsStorageOrder, bool ConjugateLhs,
00221 int RhsStorageOrder, bool ConjugateRhs>
00222 struct product_triangular_matrix_matrix<Scalar,Index,Mode,false,
00223 LhsStorageOrder,ConjugateLhs,
00224 RhsStorageOrder,ConjugateRhs,ColMajor>
00225 {
00226 // lhs is indexed as rows x depth (_lhs, lhsStride), rhs as depth x cols (_rhs, rhsStride); res/resStride is the destination, alpha the scaling factor.
00227 static EIGEN_DONT_INLINE void run(
00228 Index rows, Index cols, Index depth,
00229 const Scalar* _lhs, Index lhsStride,
00230 const Scalar* _rhs, Index rhsStride,
00231 Scalar* res, Index resStride,
00232 Scalar alpha)
00233 {
00234 const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
00235 const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
00236
00237 typedef gebp_traits<Scalar,Scalar> Traits;
00238 enum {
00239 SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr), // edge length of the micro triangular blocks
00240 IsLower = (Mode&Lower) == Lower,
00241 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1 // 1 when the diagonal has to be read from rhs
00242 };
00243 // Block sizes start at the full problem dimensions; presumably adjusted in place by computeProductBlockingSizes (defined elsewhere).
00244 Index kc = depth;
00245 Index mc = rows;
00246 Index nc = cols;
00247 computeProductBlockingSizes<Scalar,Scalar,4>(kc, mc, nc);
00248 // Work buffers: blockA holds kc*mc scalars; blockB starts sizeW scalars into an allocation of sizeW + kc*cols scalars.
00249 Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
00250 std::size_t sizeW = kc*Traits::WorkSpaceFactor;
00251 std::size_t sizeB = sizeW + kc*cols;
00252 Scalar* allocatedBlockB = ei_aligned_stack_new(Scalar,sizeB);
00253 Scalar* blockB = allocatedBlockB + sizeW;
00254 // Scratch square block: zero everywhere, with a diagonal of zeros (ZeroDiag) or ones (otherwise).
00255 Matrix<Scalar,SmallPanelWidth,SmallPanelWidth,RhsStorageOrder> triangularBuffer;
00256 triangularBuffer.setZero();
00257 if((Mode&ZeroDiag)==ZeroDiag)
00258 triangularBuffer.diagonal().setZero();
00259 else
00260 triangularBuffer.diagonal().setOnes();
00261 // Kernel and packing functors; pack_rhs_panel is the packer used below with an explicit stride and offset.
00262 gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
00263 gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
00264 gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs;
00265 gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder,false,true> pack_rhs_panel;
00266 // Walk the depth dimension in chunks of kc: from 0 upwards if IsLower, from depth downwards otherwise.
00267 for(Index k2=IsLower ? 0 : depth;
00268 IsLower ? k2<depth : k2>0;
00269 IsLower ? k2+=kc : k2-=kc)
00270 {
00271 Index actual_kc = std::min(IsLower ? depth-k2 : k2, kc);
00272 Index actual_k2 = IsLower ? k2 : k2-actual_kc;
00273 // Lower case only: if the chunk straddles index `cols`, shorten it so that it ends exactly at `cols`,
00274 // and shift k2 so that the loop increment k2+=kc resumes at `cols`.
00275 if(IsLower && (k2<cols) && (actual_k2+actual_kc>cols))
00276 {
00277 actual_kc = cols-k2;
00278 k2 = actual_k2 + actual_kc - kc;
00279 }
00280
00281 // rs: number of columns of the dense part: columns [0, min(cols,actual_k2)) if IsLower, columns [k2, cols) otherwise.
00282 Index rs = IsLower ? std::min(cols,actual_k2) : cols - k2;
00283 // ts: size of the triangular part of this chunk (0 when a lower chunk starts at or beyond `cols`).
00284 Index ts = (IsLower && actual_k2>=cols) ? 0 : actual_kc;
00285 // The dense part is packed after the first ts*ts scalars of blockB, which are used for the triangular part.
00286 Scalar* geb = blockB+ts*ts;
00287
00288 pack_rhs(geb, &rhs(actual_k2,IsLower ? 0 : k2), rhsStride, actual_kc, rs);
00289
00290 // Pack the triangular part, one panel of at most SmallPanelWidth columns at a time.
00291 if(ts>0)
00292 {
00293 for (Index j2=0; j2<actual_kc; j2+=SmallPanelWidth)
00294 {
00295 Index actualPanelWidth = std::min<Index>(actual_kc-j2, SmallPanelWidth);
00296 Index actual_j2 = actual_k2 + j2;
00297 Index panelOffset = IsLower ? j2+actualPanelWidth : 0;
00298 Index panelLength = IsLower ? actual_kc-j2-actualPanelWidth : j2;
00299 // dense rows of the panel: below the micro block if IsLower, above it otherwise
00300 pack_rhs_panel(blockB+j2*actual_kc,
00301 &rhs(actual_k2+panelOffset, actual_j2), rhsStride,
00302 panelLength, actualPanelWidth,
00303 actual_kc, panelOffset);
00304
00305 // Copy the triangular half of the micro block (plus its diagonal if SetDiag) into triangularBuffer; the other half keeps its zeros.
00306 for (Index j=0;j<actualPanelWidth;++j)
00307 {
00308 if (SetDiag)
00309 triangularBuffer.coeffRef(j,j) = rhs(actual_j2+j,actual_j2+j);
00310 for (Index k=IsLower ? j+1 : 0; IsLower ? k<actualPanelWidth : k<j; ++k)
00311 triangularBuffer.coeffRef(k,j) = rhs(actual_j2+k,actual_j2+j);
00312 }
00313 // pack the micro block into the same packed panel, at offset j2
00314 pack_rhs_panel(blockB+j2*actual_kc,
00315 triangularBuffer.data(), triangularBuffer.outerStride(),
00316 actualPanelWidth, actualPanelWidth,
00317 actual_kc, j2);
00318 }
00319 }
00320 // For each block of at most mc lhs rows:
00321 for (Index i2=0; i2<rows; i2+=mc)
00322 {
00323 const Index actual_mc = std::min(mc,rows-i2);
00324 pack_lhs(blockA, &lhs(i2, actual_k2), lhsStride, actual_kc, actual_mc);
00325
00326 // product with the triangular part, one packed panel (result columns actual_k2+j2 onwards) at a time
00327 if(ts>0)
00328 {
00329 for (Index j2=0; j2<actual_kc; j2+=SmallPanelWidth)
00330 {
00331 Index actualPanelWidth = std::min<Index>(actual_kc-j2, SmallPanelWidth);
00332 Index panelLength = IsLower ? actual_kc-j2 : j2+actualPanelWidth;
00333 Index blockOffset = IsLower ? j2 : 0;
00334
00335 gebp_kernel(res+i2+(actual_k2+j2)*resStride, resStride,
00336 blockA, blockB+j2*actual_kc,
00337 actual_mc, panelLength, actualPanelWidth,
00338 alpha,
00339 actual_kc, actual_kc,
00340 blockOffset, blockOffset,
00341 allocatedBlockB);
00342 }
00343 }
00344 gebp_kernel(res+i2+(IsLower ? 0 : k2)*resStride, resStride, // product with the dense part packed at geb
00345 blockA, geb, actual_mc, actual_kc, rs,
00346 alpha,
00347 -1, -1, 0, 0, allocatedBlockB);
00348 }
00349 }
00350 // Release the work buffers.
00351 ei_aligned_stack_delete(Scalar, blockA, kc*mc);
00352 ei_aligned_stack_delete(Scalar, allocatedBlockB, sizeB);
00353 }
00354 };
00355
00356
00357
00358 // traits of the TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> specialization:
00359 // inherited unchanged from the traits of its ProductBase base expression.
00360 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
00361 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> >
00362 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>, Lhs, Rhs> >
00363 {};
00364
00365 }
00366 // Product expression wrapping internal::product_triangular_matrix_matrix; NOTE(review): the two `false` flags presumably mean neither operand is a vector - confirm.
00367 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
00368 struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
00369   : public ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>, Lhs, Rhs >
00370 {
00371   EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
00372
00373   TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
00374
00375   // Passes the extracted operands, dst, and alpha times both extracted scalar factors to the kernel's run().
00376   template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
00377   {
00378     enum {  // storage order of each matrix, read from the RowMajorBit of its traits flags
00379       LhsOrder = (internal::traits<_ActualLhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor,
00380       RhsOrder = (internal::traits<_ActualRhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor,
00381       DstOrder = (internal::traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor
00382     };
00383     typedef internal::product_triangular_matrix_matrix<Scalar, Index, Mode, LhsIsTriangular,
00384               LhsOrder, LhsBlasTraits::NeedToConjugate,
00385               RhsOrder, RhsBlasTraits::NeedToConjugate, DstOrder> Kernel;
00386     const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
00387     const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
00388     Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
00389                                * RhsBlasTraits::extractScalarFactor(m_rhs);
00390
00391     Kernel::run(lhs.rows(), rhs.cols(), lhs.cols(),
00392                 &lhs.coeffRef(0,0), lhs.outerStride(),
00393                 &rhs.coeffRef(0,0), rhs.outerStride(),
00394                 &dst.coeffRef(0,0), dst.outerStride(), actualAlpha);
00395   }
00396 };
00397
00398
00399 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_H