00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef EIGEN_BLASUTIL_H
00026 #define EIGEN_BLASUTIL_H
00027
00028
00029
00030
00031 namespace internal {
00032
00033
// Forward declarations of the low-level product building blocks.
// Only the declarations (and their default template arguments) live here;
// the definitions are provided outside this file chunk.

// Block*panel kernel; both conjugation flags default to "no conjugation".
template<class LhsScalar, class RhsScalar, class Index,
         int mr, int nr,
         bool ConjugateLhs = false,
         bool ConjugateRhs = false>
struct gebp_kernel;

// Packing routine for the right-hand side operand.
template<class Scalar, class Index,
         int nr, int StorageOrder,
         bool Conjugate = false,
         bool PanelMode = false>
struct gemm_pack_rhs;

// Packing routine for the left-hand side operand.
template<class Scalar, class Index,
         int Pack1, int Pack2, int StorageOrder,
         bool Conjugate = false,
         bool PanelMode = false>
struct gemm_pack_lhs;

// Matrix * matrix product driver (no defaults: every parameter is explicit).
template<class Index,
         class LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
         class RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
         int ResStorageOrder>
struct general_matrix_matrix_product;

// Matrix * vector product driver.
template<class Index,
         class LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
         class RhsScalar, bool ConjugateRhs>
struct general_matrix_vector_product;
00052
00053
// conj_if<Conjugate> is a stateless functor: it conjugates its argument when
// Conjugate is true and is the identity otherwise. The call operators are
// const so the functor can also be invoked through a const object/reference.
template<bool Conjugate> struct conj_if;

// Conjugating variant: returns conj(x) by value. 'conj' is looked up
// unqualified (enclosing namespace or ADL on T).
template<> struct conj_if<true> {
  template<typename T>
  inline T operator()(const T& x) const { return conj(x); }
};

// Pass-through variant: returns a reference to the argument itself, so no
// copy is made.
template<> struct conj_if<false> {
  template<typename T>
  inline const T& operator()(const T& x) const { return x; }
};
00065
00066 template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false>
00067 {
00068 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); }
00069 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); }
00070 };
00071
00072 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true>
00073 {
00074 typedef std::complex<RealScalar> Scalar;
00075 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00076 { return c + pmul(x,y); }
00077
00078 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00079 { return Scalar(real(x)*real(y) + imag(x)*imag(y), imag(x)*real(y) - real(x)*imag(y)); }
00080 };
00081
00082 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false>
00083 {
00084 typedef std::complex<RealScalar> Scalar;
00085 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00086 { return c + pmul(x,y); }
00087
00088 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00089 { return Scalar(real(x)*real(y) + imag(x)*imag(y), real(x)*imag(y) - imag(x)*real(y)); }
00090 };
00091
00092 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true>
00093 {
00094 typedef std::complex<RealScalar> Scalar;
00095 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00096 { return c + pmul(x,y); }
00097
00098 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00099 { return Scalar(real(x)*real(y) - imag(x)*imag(y), - real(x)*imag(y) - imag(x)*real(y)); }
00100 };
00101
00102 template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
00103 {
00104 typedef std::complex<RealScalar> Scalar;
00105 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
00106 { return padd(c, pmul(x,y)); }
00107 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
00108 { return conj_if<Conj>()(x)*y; }
00109 };
00110
00111 template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
00112 {
00113 typedef std::complex<RealScalar> Scalar;
00114 EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
00115 { return padd(c, pmul(x,y)); }
00116 EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
00117 { return x*conj_if<Conj>()(y); }
00118 };
00119
00120 template<typename From,typename To> struct get_factor {
00121 EIGEN_STRONG_INLINE static To run(const From& x) { return x; }
00122 };
00123
00124 template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
00125 EIGEN_STRONG_INLINE static typename NumTraits<Scalar>::Real run(const Scalar& x) { return real(x); }
00126 };
00127
00128
00129
00130
00131 template<typename Scalar, typename Index, int StorageOrder>
00132 class blas_data_mapper
00133 {
00134 public:
00135 blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
00136 EIGEN_STRONG_INLINE Scalar& operator()(Index i, Index j)
00137 { return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; }
00138 protected:
00139 Scalar* EIGEN_RESTRICT m_data;
00140 Index m_stride;
00141 };
00142
00143
00144 template<typename Scalar, typename Index, int StorageOrder>
00145 class const_blas_data_mapper
00146 {
00147 public:
00148 const_blas_data_mapper(const Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
00149 EIGEN_STRONG_INLINE const Scalar& operator()(Index i, Index j) const
00150 { return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; }
00151 protected:
00152 const Scalar* EIGEN_RESTRICT m_data;
00153 Index m_stride;
00154 };
00155
00156
00157
00158
00159
00160 template<typename XprType> struct blas_traits
00161 {
00162 typedef typename traits<XprType>::Scalar Scalar;
00163 typedef const XprType& ExtractType;
00164 typedef XprType _ExtractType;
00165 enum {
00166 IsComplex = NumTraits<Scalar>::IsComplex,
00167 IsTransposed = false,
00168 NeedToConjugate = false,
00169 HasUsableDirectAccess = ( (int(XprType::Flags)&DirectAccessBit)
00170 && ( bool(XprType::IsVectorAtCompileTime)
00171 || int(inner_stride_at_compile_time<XprType>::ret) == 1)
00172 ) ? 1 : 0
00173 };
00174 typedef typename conditional<bool(HasUsableDirectAccess),
00175 ExtractType,
00176 typename _ExtractType::PlainObject
00177 >::type DirectLinearAccessType;
00178 static inline const ExtractType extract(const XprType& x) { return x; }
00179 static inline const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
00180 };
00181
00182
00183 template<typename Scalar, typename NestedXpr>
00184 struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
00185 : blas_traits<NestedXpr>
00186 {
00187 typedef blas_traits<NestedXpr> Base;
00188 typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
00189 typedef typename Base::ExtractType ExtractType;
00190
00191 enum {
00192 IsComplex = NumTraits<Scalar>::IsComplex,
00193 NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
00194 };
00195 static inline const ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00196 static inline Scalar extractScalarFactor(const XprType& x) { return conj(Base::extractScalarFactor(x.nestedExpression())); }
00197 };
00198
00199
00200 template<typename Scalar, typename NestedXpr>
00201 struct blas_traits<CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> >
00202 : blas_traits<NestedXpr>
00203 {
00204 typedef blas_traits<NestedXpr> Base;
00205 typedef CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> XprType;
00206 typedef typename Base::ExtractType ExtractType;
00207 static inline const ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00208 static inline Scalar extractScalarFactor(const XprType& x)
00209 { return x.functor().m_other * Base::extractScalarFactor(x.nestedExpression()); }
00210 };
00211
00212
00213 template<typename Scalar, typename NestedXpr>
00214 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
00215 : blas_traits<NestedXpr>
00216 {
00217 typedef blas_traits<NestedXpr> Base;
00218 typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
00219 typedef typename Base::ExtractType ExtractType;
00220 static inline const ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00221 static inline Scalar extractScalarFactor(const XprType& x)
00222 { return - Base::extractScalarFactor(x.nestedExpression()); }
00223 };
00224
00225
00226 template<typename NestedXpr>
00227 struct blas_traits<Transpose<NestedXpr> >
00228 : blas_traits<NestedXpr>
00229 {
00230 typedef typename NestedXpr::Scalar Scalar;
00231 typedef blas_traits<NestedXpr> Base;
00232 typedef Transpose<NestedXpr> XprType;
00233 typedef Transpose<const typename Base::_ExtractType> ExtractType;
00234 typedef Transpose<const typename Base::_ExtractType> _ExtractType;
00235 typedef typename conditional<bool(Base::HasUsableDirectAccess),
00236 ExtractType,
00237 typename ExtractType::PlainObject
00238 >::type DirectLinearAccessType;
00239 enum {
00240 IsTransposed = Base::IsTransposed ? 0 : 1
00241 };
00242 static inline const ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00243 static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
00244 };
00245
00246 template<typename T>
00247 struct blas_traits<const T>
00248 : blas_traits<T>
00249 {};
00250
00251 template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
00252 struct extract_data_selector {
00253 static const typename T::Scalar* run(const T& m)
00254 {
00255 return const_cast<typename T::Scalar*>(&blas_traits<T>::extract(m).coeffRef(0,0));
00256 }
00257 };
00258
00259 template<typename T>
00260 struct extract_data_selector<T,false> {
00261 static typename T::Scalar* run(const T&) { return 0; }
00262 };
00263
00264 template<typename T> const typename T::Scalar* extract_data(const T& m)
00265 {
00266 return extract_data_selector<T>::run(m);
00267 }
00268
00269 }
00270
00271 #endif // EIGEN_BLASUTIL_H