29 #ifndef REMORA_KERNELS_DEFAULT_PSTRF_HPP 30 #define REMORA_KERNELS_DEFAULT_PSTRF_HPP 32 #include "../gemm.hpp" 33 #include "../gemv.hpp" 35 #include "../../vector.hpp" 37 namespace remora{
namespace bindings {
39 template<
class MatA,
class VecP>
41 matrix_expression<MatA, cpu_tag> &A,
42 vector_expression<VecP, cpu_tag>& P,
68 std::size_t block_size = 20;
71 size_t m = A().size1();
73 vector<typename MatA::value_type> pivotValues(m);
76 double max_diag = A()(0,0);
77 for(std::size_t i = 1; i < m; ++i)
78 max_diag = std::max(max_diag,std::abs(A()(i,i)));
79 double epsilon = m * m * std::numeric_limits<typename MatA::value_type>::epsilon() * max_diag;
81 for(std::size_t k = 0; k < m; k += block_size){
82 std::size_t currentSize = std::min(m-k,block_size);
84 auto Ak = simple_subrange(A,k,m,k,m);
85 auto pivots = simple_subrange(pivotValues,k,m);
88 for(std::size_t j = 0; j != currentSize; ++j){
92 for(std::size_t i = 0; i != m-k; ++i)
95 for(std::size_t i = j; i != m-k; ++i)
96 pivots(i) -= Ak(i,j-1) * Ak(i,j-1);
100 std::size_t pivot = std::max_element(pivots.begin()+j,pivots.end())-pivots.begin();
102 P()(k+j) = (
int)(pivot+k);
103 A().swap_rows(k+j,k+pivot);
104 A().swap_columns(k+j,k+pivot);
109 auto pivotValue = pivots(j);
110 if(pivotValue < epsilon){
112 simple_subrange(Ak,j,m-k,j,m-k).clear();
117 Ak(j,j) = std::sqrt(pivotValue);
120 auto curCol = simple_column(Ak,j);
121 auto colLowerPart = simple_subrange(curCol,j+1,m-k);
125 auto blockLL = simple_subrange(Ak,j+1,m-k,0,j);
126 auto curRow = simple_row(Ak,j);
127 auto rowLeftPart = simple_subrange(curRow,0,j);
137 kernels::gemv(blockLL,rowLeftPart,colLowerPart,-1);
139 colLowerPart /= Ak(j,j);
141 subrange(Ak,j,j+1,j+1,Ak.size2()).clear();
143 if(k+currentSize < m){
144 auto blockLL = simple_subrange(Ak, block_size, m-k, 0, block_size);
145 auto blockLR = simple_subrange(Ak, block_size, m-k, block_size, m-k);
146 kernels::gemm(blockLL,simple_trans(blockLL), blockLR, -1);
153 template<
class MatA,
class VecP>
155 matrix_expression<MatA, cpu_tag> &A,
156 vector_expression<VecP, cpu_tag>& P,
159 auto transA = simple_trans(A);
160 return pstrf(transA,P,lower());