30 #ifndef REMORA_KERNELS_DEFAULT_GETRF_HPP 31 #define REMORA_KERNELS_DEFAULT_GETRF_HPP 34 #include "../trsm.hpp" 35 #include "../gemm.hpp" 36 #include "../../permutation.hpp" 39 namespace remora{
namespace bindings {
// Block kernel: in-place elimination on A, one column at a time, with row
// pivoting on the largest absolute value in the current column.
//
// A  matrix expression on the CPU, overwritten by the elimination result.
// P  vector expression on the CPU; entry j is read below as the index of the
//    row that gets swapped with row j.
//
// NOTE(review): the leading integers on these lines look like line numbers
// from a source listing rather than code, and the gaps in that numbering
// suggest lines are missing from this view (the function name, any further
// parameters, the writes to P, the condition guarding the throw, closing
// braces). Calls elsewhere in this file of the form
// getrf_block(colBlock, P, column_major()) suggest this is the column_major
// overload of getrf_block - confirm against the original header.
42 template<
class MatA,
class VecP>
44 matrix_expression<MatA, cpu_tag>& A,
45 vector_expression<VecP, cpu_tag>& P,
// Process the columns of A from left to right.
48 for(std::size_t j = 0; j != A().size2(); ++j){
// Start the pivot search with the diagonal entry.
// NOTE(review): pivot_value is a double regardless of the element type of A.
50 double pivot_value = A()(j,j);
// Scan the rows below the diagonal for an entry of larger absolute value.
52 for(std::size_t i = j+1; i != A().size1(); ++i){
53 if(std::abs(A()(i,j)) > std::abs(pivot_value)){
// NOTE(review): presumably the row index i is also stored in P()(j) next to
// this assignment; that statement is not visible here.
55 pivot_value = A()(i,j);
// Reports an unusable pivot to the caller. The condition guarding this throw
// is not visible in this view.
59 throw std::invalid_argument(
"[getrf] Matrix is rank deficient or numerically unstable");
// Swap the pivot row into position j unless it is already there.
61 if(std::size_t(P()(j)) != j){
62 A().swap_rows(j,P()(j));
// Divide the entries below the diagonal in column j by the pivot.
68 for(std::size_t i = j+1; i != A().size1(); ++i){
69 A()(i,j) /= pivot_value;
// Update the trailing part of A (rows and columns > j) by subtracting the
// product of the scaled column j and row j.
75 for(std::size_t k = j+1; k != A().size2(); ++k){
76 for(std::size_t i = j+1; i != A().size1(); ++i){
77 A()(i,k) -= A()(i,j) * A()(j,k);
// Block kernel variant that works on a temporary column_major copy of A:
// copy A into the temporary, run the column_major getrf_block on it, then
// copy the result back into A.
//
// A  matrix expression on the CPU, overwritten with the result.
// P  vector expression on the CPU, passed through to the column_major kernel.
//
// NOTE(review): the leading integers look like listing line numbers, and the
// function name, any tag parameter and the braces are not visible in this
// view. Presumably this is the row_major overload of getrf_block - confirm
// against the original header.
84 template<
class MatA,
class VecP>
86 matrix_expression<MatA, cpu_tag>& A,
87 vector_expression<VecP, cpu_tag>& P,
93 typedef typename MatA::value_type value_type;
// Scratch buffer with one element per entry of A.
94 std::vector<value_type> storage(A().size1() * A().size2());
// View the scratch buffer as a column_major matrix of the same dimensions as A.
95 dense_matrix_adaptor<value_type, column_major> colBlock(storage.data(), A().size1(), A().size2());
// Copy A into the temporary.
96 kernels::assign(colBlock, A);
// Run the column_major kernel on the temporary.
97 getrf_block(colBlock, P, column_major());
// Copy the result back into A.
98 kernels::assign(A, colBlock);
// Recursive driver: handles the column range [start, end) of A, either by
// calling getrf_block directly or by splitting the range in two and
// recursing on each part.
//
// A  matrix expression on the CPU, modified in place.
// P  vector expression on the CPU holding the pivot information.
//
// NOTE(review): the leading integers look like listing line numbers, and the
// gaps in that numbering suggest lines are missing from this view. `start`
// and `end` are used below but their declarations are not visible
// (presumably two further std::size_t parameters). A_2, P1 and P2 are created
// but never used in the visible lines; presumably the missing lines apply the
// row permutations with them - confirm against the original header.
128 template <
typename MatA,
typename VecP>
129 void getrf_recursive(
130 matrix_expression<MatA, cpu_tag>& A,
131 vector_expression<VecP, cpu_tag>& P,
// Column ranges of at most block_size columns go straight to getrf_block.
135 std::size_t block_size = 32;
136 std::size_t size = end-start;
// Row ranges below always extend to the last row of A.
137 std::size_t end1=A().size1();
// Base case: pass rows [start, end1) x columns [start, end) and the matching
// part of P to the block kernel selected by the orientation of MatA.
// NOTE(review): the return that presumably ends this branch is not visible.
140 if(size <= block_size){
141 auto Ablock = simple_subrange(A, start, end1, start, end);
142 auto Pblock = simple_subrange(P, start, end);
143 getrf_block(Ablock,Pblock,
typename MatA::orientation());
// Number of blocks, rounded up. The split point is placed after half of them
// (rounded down), i.e. at a multiple of block_size measured from start.
148 std::size_t numBlocks = (size + block_size - 1) / block_size;
149 std::size_t split = start + numBlocks/2 * block_size;
// A_2: rows [start, end1) of the right column range [split, end).
150 auto A_2 = simple_subrange(A, start, end1, split, end);
// 2x2 partition of the working range: the row split is at `split`, and the
// lower blocks extend to row end1.
151 auto A11 = simple_subrange(A, start, split, start, split);
152 auto A12 = simple_subrange(A, start, split, split, end);
153 auto A21 = simple_subrange(A, split, end1, start, split);
154 auto A22 = simple_subrange(A, split, end1, split, end);
// Matching halves of P.
155 auto P1 = simple_subrange(P, start, split);
156 auto P2 = simple_subrange(P, split, end);
// Process the left column range first.
160 getrf_recursive(A, P, start, split);
// Triangular solve with A11 treated as unit lower triangular, applied from
// the left to A12.
167 kernels::trsm<unit_lower, left>(A11, A12);
// Presumably A22 -= A21 * A12 (the last argument is the scaling factor -1) -
// confirm against the kernels::gemm contract.
170 kernels::gemm(A21, A12, A22, -1);
// Process the right column range.
173 getrf_recursive(A, P, split, end);
// Entry point: loops over every element of P, then starts the recursive
// driver on the full range [0, A().size1()).
//
// A  matrix expression on the CPU, modified in place by getrf_recursive.
// P  vector expression on the CPU holding the pivot information.
//
// NOTE(review): the leading integers look like listing line numbers, and the
// function name, the loop body and the closing braces are not visible in
// this view. Presumably this is `getrf` and the loop initialises each P()(i)
// - confirm against the original header.
182 template <
typename MatA,
typename VecP>
184 matrix_expression<MatA, cpu_tag>& A,
185 vector_expression<VecP, cpu_tag>& P
// Visits every index of P; the loop body is not visible here.
187 for(std::size_t i = 0; i != P().size(); ++i){
// Run the recursive driver on the range [0, A().size1()).
190 getrf_recursive(A, P, 0, A().size1());