31 #ifndef REMORA_KERNELS_CLBLAS_SUM_ROWS_HPP 32 #define REMORA_KERNELS_CLBLAS_SUM_ROWS_HPP 34 #include "../../expression_types.hpp" 35 #include "../../detail/traits.hpp" 36 #include <boost/compute/kernel.hpp> 37 #include <boost/compute/detail/meta_kernel.hpp> 38 #include <boost/compute/functional/operator.hpp> 40 namespace remora{
namespace bindings{
42 template<
class M,
class V,
class Orientation>
44 matrix_expression<M, gpu_tag>
const& A,
45 vector_expression<V, gpu_tag>& v,
46 typename V::value_type alpha,
47 Orientation, dense_tag, dense_tag
49 typedef typename V::value_type value_type;
50 boost::compute::detail::meta_kernel k(
"blas_sum_rows_row");
51 std::size_t alpha_index = k.add_arg<value_type>(
"alpha");
52 std::size_t size1_index = k.add_arg<std::size_t>(
"size1");
53 std::size_t size2_index = k.add_arg<std::size_t>(
"size2");
55 k <<
"__local " <<k.decl<value_type>(
"sums")<<
"[TILE_DIM][TILE_DIM+1];\n";
56 k <<
"uint colid = get_global_id(1);\n";
57 k <<
"sums[get_local_id(0)][get_local_id(1)] = 0.0;\n";
58 k <<
"for(uint i = get_local_id(0) ; i < size1 && colid < size2; i += TILE_DIM){\n";
59 auto exprRow = k.expr<cl_uint>(
"i");
60 auto exprCol = k.expr<cl_uint>(
"colid");
61 k<<
" sums[get_local_id(0)][get_local_id(1)] +=" << A()(exprRow,exprCol)<<
";\n";
63 k <<
"barrier(CLK_LOCAL_MEM_FENCE);\n";
65 k <<
"if(get_local_id(0) == 0 && colid < size2){\n";
66 k <<
" for(uint i = 1 ; i < TILE_DIM; ++i){\n";
67 k <<
" sums[0][get_local_id(1)] +=sums[i][get_local_id(1)];\n";
69 k << v()(exprCol) <<
"+= alpha * sums[0][get_local_id(1)];\n";
73 std::size_t TILE_DIM = 8;
74 char const* options =
"-DTILE_DIM=8";
75 boost::compute::kernel kernel = k.compile(v().queue().get_context(), options);
77 kernel.set_arg(alpha_index, alpha);
78 kernel.set_arg(size1_index, A().size1());
79 kernel.set_arg(size2_index, A().size2());
81 std::size_t global_work_size[2] = {
83 ((A().size2()+TILE_DIM-1)/TILE_DIM) * TILE_DIM
85 std::size_t local_work_size[2] = {TILE_DIM, TILE_DIM};
86 v().queue().enqueue_nd_range_kernel(kernel, 2,
nullptr, global_work_size, local_work_size);