31 #ifndef REMORA_KERNELS_CLBLAS_TRSM_HPP 32 #define REMORA_KERNELS_CLBLAS_TRSM_HPP 34 #include "../../expression_types.hpp" 35 #include "../../detail/traits.hpp" 36 #include <boost/compute/kernel.hpp> 37 #include <boost/compute/detail/meta_kernel.hpp> 38 #include <boost/compute/functional/operator.hpp> 41 namespace remora {
namespace bindings {
// Bundle handed back by createTRSMDiagBlockKernel: the compiled kernel plus the
// argument slots (as returned by meta_kernel::add_arg) that the host must fill
// before every launch.
// NOTE(review): the enclosing `struct trsm_kernel {` line and a `K_index` member
// are not visible here, yet the factory returns six initializers
// {kernel, K_index, start_index, end_index, unit_index, upper_index} and the
// recursive driver reads `kernel.K_index` — confirm against the original header.
43 boost::compute::kernel kernel;
// Slot of the "start" argument (first row/column of the diagonal block of A).
45 std::size_t start_index;
// Slot of the "end" argument (one past the last row/column of the diagonal block).
46 std::size_t end_index;
// Slot of the "unit" flag; the generated code skips the diagonal division when it is non-zero.
47 std::size_t unit_index;
// Slot of the "upper" flag.
48 std::size_t upper_index;
// Generates and compiles the OpenCL kernel "blas_trsm" that solves one diagonal
// block of a triangular system in place:
//   1. copy the block A[start:end, start:end] and one column tile of
//      B[start:end, :] into local memory,
//   2. run a substitution sweep on the local tiles,
//   3. write the solved tile back to B.
// The kernel source relies on the macros TILE_SIZE and TILE_SIZE_K, which are
// expected to be supplied through `options` at compile time.
//
// A        triangular matrix expression; only read by the generated code.
// B        right-hand sides; overwritten with the solution. Its queue's context
//          is used for compilation.
// returns  the compiled kernel together with the argument slots of
//          K, start, end, unit and upper.
//
// NOTE(review): the `options` parameter used in the compile call, the opening
// brace of the function, the closing braces of the generated loops and the
// guards that select between the two substitution sweeps are not visible in
// this listing — confirm against the original header.
51 template<
class MatA,
class MatB>
52 trsm_kernel createTRSMDiagBlockKernel(
53 matrix_expression<MatA, gpu_tag>
const& A,
54 matrix_expression<MatB, gpu_tag> &B,
57 typedef typename MatA::value_type value_typeA;
58 typedef typename MatB::value_type value_typeB;
// Functor used below to emit the product expression Bsub[k][i] * Asub[j][i].
59 boost::compute::multiplies<value_typeB> prod;
61 boost::compute::detail::meta_kernel k(
"blas_trsm");
// Register the scalar kernel arguments and remember their slots so the host
// can set them before each launch.
// K: column count of B (the caller passes Bfull().size2()).
62 std::size_t K_index = k.add_arg<std::size_t>(
"K");
// start / end: half-open index range of the diagonal block inside A.
63 std::size_t start_index = k.add_arg<std::size_t>(
"start");
64 std::size_t end_index = k.add_arg<std::size_t>(
"end");
// unit: non-zero disables the division by the diagonal element.
65 std::size_t unit_index = k.add_arg<std::size_t>(
"unit");
66 std::size_t upper_index = k.add_arg<std::size_t>(
"upper");
// Local-memory tiles. Bsub is indexed [column of B][row of B], i.e. the tile of
// B is kept transposed. The second dimension is padded by 2 — presumably to
// reduce local-memory bank conflicts; TODO confirm.
70 k <<
"__local " <<k.decl<value_typeA>(
"Asub")<<
"[TILE_SIZE][TILE_SIZE+2];\n";
71 k <<
"__local " <<k.decl<value_typeB>(
"Bsub")<<
"[TILE_SIZE_K][TILE_SIZE+2];\n";
// The size of work-group dimension 0 is used as the stride in BOTH dimensions,
// so the generated loops assume a square work-group.
72 k <<
"const ulong numWorkers = get_local_size(0);\n";
// t selects which column tile of B this work-group handles.
74 k <<
"const ulong t = get_group_id(1);\n";
// Number of rows actually present in this diagonal block.
75 k <<
"const ulong curTileA = end-start;\n";
// Number of valid columns in this tile; the last tile of B may be partial.
76 k <<
"const ulong curTileK = min(TILE_SIZE_K, K - t*TILE_SIZE_K);\n";
// Load the tile of A. Indices are clamped to end-1 so that a block smaller than
// TILE_SIZE never reads outside [start, end).
79 k <<
"for(ulong i = get_local_id(0); i < TILE_SIZE; i += numWorkers){\n";
80 k <<
" for(ulong j = get_local_id(1); j < TILE_SIZE; j += numWorkers){\n";
81 k <<
" Asub[i][j] ="<< A()(k.expr<cl_ulong>(
"min(end-1, start + i)"),k.expr<cl_ulong>(
"min(end-1, start + j)"))<<
";\n";
// Load the tile of B (stored transposed), clamping the row to end-1 and the
// column to K-1 for the same reason.
87 k <<
"for(ulong i = get_local_id(0); i < TILE_SIZE; i += numWorkers){\n";
88 k <<
" for(ulong k = get_local_id(1); k < TILE_SIZE_K; k += numWorkers){\n";
89 k <<
" Bsub[k][i] ="<< B()(k.expr<cl_ulong>(
"min(end-1,start + i)"),k.expr<cl_ulong>(
"min(K-1,t * TILE_SIZE_K+k)"))<<
";\n";
// All work-items must have finished loading before the tiles are consumed.
93 k <<
"barrier(CLK_LOCAL_MEM_FENCE);\n";
// Forward sweep: for increasing i, scale entry i by the diagonal (unless unit)
// and eliminate it from all later rows j > i using Asub[j][i]. Columns of B are
// distributed over local dimension 1; only work-items with get_local_id(0) == 0
// execute the sweep.
// NOTE(review): presumably this sweep is the lower-triangular case selected via
// the `upper` argument; the guard is not visible here — confirm.
98 k <<
" for(ulong k = get_local_id(1); k < curTileK; k += numWorkers){\n";
99 k <<
" for(ulong i = 0; i < TILE_SIZE && get_local_id(0) == 0; ++i){\n";
100 k <<
" if(!unit){Bsub[k][i] /= Asub[i][i];}\n";
101 k <<
" for(ulong j = i+1; j < TILE_SIZE; ++j){\n";
102 k <<
" Bsub[k][j] -= "<< prod(k.expr<value_typeB>(
"Bsub[k][i]"), k.expr<value_typeA>(
"Asub[j][i]"))<<
";\n";
// Backward sweep: i runs from curTileA-1 down to 0 and is eliminated from all
// earlier rows j < i. The counter n = i+1 is used so that the unsigned loop
// variable never has to go below zero.
108 k <<
" for(ulong k = get_local_id(1); k < curTileK; k += numWorkers){\n";
109 k <<
" for(ulong n = curTileA; n > 0 && get_local_id(0) == 0; --n){\n";
110 k <<
" ulong i = n-1;\n";
111 k <<
" if(!unit ){Bsub[k][i] /= Asub[i][i];}\n";
112 k <<
" for(ulong j = 0; j < i; j ++){\n";
113 k <<
" Bsub[k][j] -= "<< prod(k.expr<value_typeB>(
"Bsub[k][i]"), k.expr<value_typeA>(
"Asub[j][i]"))<<
";\n";
// Make the solved tile visible to every work-item before it is written back.
119 k <<
"barrier(CLK_LOCAL_MEM_FENCE);\n";
// Store only the valid part of the tile (curTileA rows, curTileK columns) into B.
121 k <<
"for(ulong i = get_local_id(0); i < curTileA; i += numWorkers){\n";
122 k <<
" for(ulong k = get_local_id(1); k < curTileK; k += numWorkers){\n";
123 k << B()(k.expr<cl_ulong>(
"(start+i)"),k.expr<cl_ulong>(
"(t * TILE_SIZE_K+k)"))<<
" = Bsub[k][i];\n";
// Compile for the context of B's queue; `options` carries the tile-size macros.
127 boost::compute::kernel kernel = k.compile(B().queue().get_context(), options);
128 return {kernel,K_index,start_index,end_index,unit_index,upper_index};
// Recursive blocked triangular solve on the index range [start, end).
//
// Base case (block no larger than tileSizeA): sets the kernel arguments and
// launches the pre-compiled diagonal-block kernel on Bfull.
// Otherwise the range is split at a multiple of tileSizeA; one half is solved,
// its contribution is removed from the other half with kernels::gemm, and then
// the other half is solved. Which half goes first depends on
// Triangular::is_upper.
//
// Afull, Bfull  the complete matrices; the recursion only narrows start/end.
// tileSizeA     maximum block size handled directly by the kernel.
// tileSizeB     number of columns of B handled per work-group.
// numWorkers    work-group edge length (the local size is numWorkers x numWorkers).
//
// NOTE(review): the line carrying the function name/return type, the parameters
// `kernel`, `start`, `end` and `t`, the first entry of global_work_size, the
// `else` separating the two orderings and all closing braces are not visible in
// this listing — confirm against the original header.
131 template <
typename MatA,
typename MatB,
class Triangular>
133 matrix_expression<MatA, gpu_tag>
const& Afull,
134 matrix_expression<MatB, gpu_tag> & Bfull,
138 std::size_t tileSizeA,
139 std::size_t tileSizeB,
140 std::size_t numWorkers,
// Views on the current diagonal block of A and the matching rows of B; they
// are only needed for the gemm updates in the recursive case.
143 auto A = subrange(Afull,start,end,start,end);
144 auto B = rows(Bfull,start,end);
145 std::size_t size = A.size1();
// Base case: the block fits into one kernel tile.
147 if(size <= tileSizeA){
149 kernel.kernel.set_arg(kernel.K_index, Bfull().size2());
150 kernel.kernel.set_arg(kernel.start_index, start);
151 kernel.kernel.set_arg(kernel.end_index, end);
// The compile-time triangular traits are passed to the kernel as integer flags.
152 kernel.kernel.set_arg(kernel.unit_index, (std::size_t)Triangular::is_unit);
153 kernel.kernel.set_arg(kernel.upper_index, (std::size_t)Triangular::is_upper);
// Dimension 1 gets one work-group per tileSizeB-wide column tile of B (rounded up).
155 std::size_t global_work_size[2] = {
157 (Bfull().size2()+tileSizeB-1)/ tileSizeB * numWorkers
159 std::size_t local_work_size[2] = {numWorkers, numWorkers};
160 Bfull().queue().enqueue_nd_range_kernel(kernel.kernel, 2,
nullptr, global_work_size, local_work_size);
// Recursive case: split after half of the tiles (rounded down), so `split` is
// always a multiple of tileSizeA.
163 std::size_t numBlocks = (A.size1()+tileSizeA-1)/tileSizeA;
164 std::size_t split = numBlocks/2 * tileSizeA;
165 auto Bfront = rows(B,0,split);
166 auto Bback = rows(B,split,size);
// Upper triangular: solve the trailing block first, update the leading rows
// with the off-diagonal block A[0:split, split:size] scaled by -1, then solve
// the leading block.
// NOTE(review): presumably kernels::gemm accumulates into its third argument
// (C += alpha * A * B) — confirm.
169 if(Triangular::is_upper){
170 trsm_recursive(Afull, Bfull, kernel, start+split,end, tileSizeA,tileSizeB, numWorkers, t);
171 kernels::gemm(subrange(A,0,split,split,size), Bback, Bfront, -1.0);
172 trsm_recursive(Afull, Bfull, kernel, start,start+split, tileSizeA,tileSizeB, numWorkers, t);
// Opposite ordering: solve the leading block first, update the trailing rows
// with A[split:size, 0:split] scaled by -1, then solve the trailing block.
174 trsm_recursive(Afull, Bfull, kernel, start,start+split, tileSizeA,tileSizeB, numWorkers, t);
175 kernels::gemm(subrange(A,split,size,0,split), Bfront, Bback, -1.0);
176 trsm_recursive(Afull, Bfull, kernel, start+split,end, tileSizeA,tileSizeB, numWorkers, t);
// Solve driver: checks the dimensions, compiles the diagonal-block kernel once
// and starts the recursive blocked solve over the whole of A.
//
// A  square triangular matrix (size1 == size2 is checked).
// B  right-hand sides with as many rows as A has columns (checked).
//
// NOTE(review): the line carrying the function name/return type and the
// trailing tag parameters are not visible in this listing; the transposing
// overload calls this function as trsm_call(A, B, <triangular tag>, left()),
// so this is presumably the left-side overload — confirm.
180 template <
typename MatA,
typename MatB,
class Triangular>
182 matrix_expression<MatA, gpu_tag>
const& A,
183 matrix_expression<MatB, gpu_tag>& B,
187 REMORA_SIZE_CHECK(A().size1() == A().size2());
188 REMORA_SIZE_CHECK(A().size2() == B().size1());
// Host-side tile sizes; they have to agree with the TILE_SIZE / TILE_SIZE_K
// macros passed to the OpenCL compiler below (both 32).
189 std::size_t
const TileSizeA = 32;
190 std::size_t
const TileSizeB = 32;
// Work-group edge length handed to the recursion.
191 std::size_t
const numWorkers = 8;
192 char const* options =
"-DTILE_SIZE=32ul -DTILE_SIZE_K=32ul";
// Compile once; the same kernel object is reused by every base-case launch.
193 auto kernel = bindings::createTRSMDiagBlockKernel(A,B,options);
195 trsm_recursive(A,B,kernel,0,A().size1(), TileSizeA, TileSizeB, numWorkers,Triangular());
// Transposing overload: wraps A and B in transposed views and forwards to the
// left-side trsm_call, using the transposed orientation of the triangular tag.
//
// NOTE(review): the line carrying the function name/return type and the tag
// parameters are not visible in this listing; presumably this is the
// right-side overload, which reduces X*A = B to A^T * X^T = B^T — confirm.
198 template <
typename MatA,
typename MatB,
class Triangular>
200 matrix_expression<MatA, gpu_tag>
const& A,
201 matrix_expression<MatB, gpu_tag>& B,
// Both wrappers are constructed from the existing expressions A() and B().
205 matrix_transpose<typename const_expression<MatA>::type> transA(A());
206 matrix_transpose<MatB> transB(B());
207 trsm_call(transA,transB,
typename Triangular::transposed_orientation(),left());
// Entry point: selects the implementation by forwarding to bindings::trsm_call
// with default-constructed Triangular and Side tag objects.
//
// A  triangular system matrix.
// B  right-hand sides, passed by non-const reference to the implementation.
//
// NOTE(review): the line carrying the function name/return type is not visible
// in this listing, and the `Side` template parameter appears split across two
// lines ("class S" / "ide,") — confirm against the original header.
213 template <
class Triangular,
class S
ide,
typename MatA,
typename MatB>
215 matrix_expression<MatA, gpu_tag>
const& A,
216 matrix_expression<MatB, gpu_tag>& B
218 bindings::trsm_call(A,B,Triangular(), Side());