32 #ifndef REMORA_KERNELS_GPU_GEMM_HPP 33 #define REMORA_KERNELS_GPU_GEMM_HPP 35 #include "../../expression_types.hpp" 36 #include "../../detail/traits.hpp" 37 #include <boost/compute/kernel.hpp> 38 #include <boost/compute/detail/meta_kernel.hpp> 39 #include <boost/compute/functional/operator.hpp> 41 namespace remora{
namespace kernels{
// NOTE(review): this listing looks like an extraction in which the original
// file line numbers are embedded in the text and some lines are absent (the
// function's name line, the opening brace, and the closing braces of several
// generated loops are not visible here). The comments below describe only
// what the visible lines do.
//
// Builds an OpenCL kernel named "blas_gemm" through boost::compute's
// meta_kernel, compiles it, and enqueues it on C's queue. The generated
// kernel accumulates alpha * (A * B) into C (the final store is a `+=`).
//
// Parameters:
//   A     - left operand; read as A(row, k) by the kernel.
//   B     - right operand; read as B(k, column) by the kernel.
//   C     - result; updated in place, also supplies the queue and context.
//   alpha - scalar factor applied to the product before it is added to C.
44 template <
typename MatA,
typename MatB,
typename MatC>
46 matrix_expression<MatA, gpu_tag>
const& A,
47 matrix_expression<MatB, gpu_tag>
const& B,
48 matrix_expression<MatC, gpu_tag>& C,
49 typename MatC::value_type
const& alpha
// Dimension compatibility: A is (M x K), B is (K x N), C is (M x N).
51 REMORA_SIZE_CHECK(A().size1() == C().size1());
52 REMORA_SIZE_CHECK(B().size2() == C().size2());
53 REMORA_SIZE_CHECK(A().size2()== B().size1());
// Host-side tiling constants. NUM_WORKERS (= TILE_SIZE / BLOCK_SIZE) is the
// work-group edge length used when the kernel is enqueued below.
64 std::size_t BLOCK_SIZE = 4;
65 std::size_t TILE_SIZE = 32;
66 std::size_t NUM_WORKERS = TILE_SIZE / BLOCK_SIZE;
// The same values are repeated as preprocessor defines for the kernel source.
// NOTE(review): the literals here are not derived from the variables above,
// so the two must be kept in sync by hand. TILE_SIZE_K exists only here.
68 char const* options =
"-DTILE_SIZE=32ul -DBLOCK_SIZE=4ul -DTILE_SIZE_K=16ul";
69 typedef typename MatC::value_type value_type;
// Kernel object plus its scalar arguments; the returned indices are used
// further down to bind the actual values with set_arg.
71 boost::compute::detail::meta_kernel k(
"blas_gemm");
72 std::size_t M_index = k.add_arg<std::size_t>(
"M");
73 std::size_t N_index = k.add_arg<std::size_t>(
"N");
74 std::size_t K_index = k.add_arg<std::size_t>(
"K");
75 std::size_t alpha_index = k.add_arg<value_type>(
"alpha");
// Two __local tiles, both indexed [k][row-or-column within the tile].
// NOTE(review): the "+2" on the inner dimension is presumably padding
// (e.g. to avoid bank conflicts) - confirm; the visible code never indexes it.
78 k <<
"__local " <<k.decl<value_type>(
"Asub")<<
"[TILE_SIZE_K][TILE_SIZE+2];\n";
79 k <<
"__local " <<k.decl<value_type>(
"Bsub")<<
"[TILE_SIZE_K][TILE_SIZE+2];\n";
// numWorkers is taken from local dimension 0 only but is used as the stride
// for both dimensions below; this matches the square local_work_size
// {NUM_WORKERS, NUM_WORKERS} set at enqueue time.
80 k <<
" const ulong numWorkers = get_local_size(0);\n";
// Per-work-item accumulator of BLOCK_SIZE x BLOCK_SIZE partial results,
// zero-initialised. The literal is 0.0f regardless of value_type.
88 k << k.decl<value_type>(
"acc") <<
"[BLOCK_SIZE][BLOCK_SIZE];\n";
89 k <<
"for (ulong wm=0; wm<BLOCK_SIZE; wm++){\n";
90 k <<
" for (ulong wn=0; wn<BLOCK_SIZE; wn++){\n";
91 k <<
" acc[wm][wn] = 0.0f;\n";
// Walk the K dimension in chunks of TILE_SIZE_K (count rounded up).
97 k <<
"ulong numTiles = (K+TILE_SIZE_K-1)/TILE_SIZE_K;\n";
98 k <<
"for (ulong t=0; t<numTiles; t++){\n";
// The last chunk may be shorter than TILE_SIZE_K.
101 k <<
" const ulong curTileK = min(TILE_SIZE_K, K - t*TILE_SIZE_K);\n";
// Cooperative load of the current chunk into Asub/Bsub: each work-item
// covers indices i and k with stride numWorkers. The row index into A and the
// column index into B are clamped with min(M-1, ...) / min(N-1, ...), so a
// tile that extends past the matrix edge re-reads the last row/column
// instead of indexing beyond it.
// NOTE(review): M-1 / N-1 are unsigned expressions; if a dimension could be
// 0 they would wrap - confirm callers never pass empty matrices.
104 k <<
" for(ulong i = get_local_id(0); i < TILE_SIZE; i += numWorkers){\n";
105 k <<
" for(ulong k = get_local_id(1); k < curTileK; k += numWorkers){\n";
106 k <<
" ulong ktile = t * TILE_SIZE_K + k;\n";
107 k <<
" Asub[k][i] ="<< A()(k.expr<cl_ulong>(
"min(M-1,TILE_SIZE * get_group_id(0)+i)"),k.expr<cl_ulong>(
"ktile"))<<
";\n";
108 k <<
" Bsub[k][i] ="<< B()(k.expr<cl_ulong>(
"ktile"),k.expr<cl_ulong>(
"min(N-1,TILE_SIZE * get_group_id(1)+i)"))<<
";\n";
// Work-group barrier on local memory between filling the tiles and
// reading them.
113 k <<
" barrier(CLK_LOCAL_MEM_FENCE);\n";
// Multiply-accumulate over the loaded chunk. For every k, a work-item copies
// its BLOCK_SIZE values of Bsub into Breg, then for each of its BLOCK_SIZE
// rows reads one Asub value (Areg) and adds Areg * Breg[wn] into acc.
// The rows/columns owned by a work-item are local_id + w * numWorkers.
117 k <<
" for (ulong k=0; k<curTileK; k++){\n";
119 k << k.decl<value_type>(
"Breg")<<
"[BLOCK_SIZE];\n";
120 k <<
" for (ulong wn=0; wn<BLOCK_SIZE; wn++){\n";
121 k <<
" Breg[wn] = Bsub[k][get_local_id(1) + wn * numWorkers];\n";
125 k <<
" for (ulong wm = 0; wm<BLOCK_SIZE; wm++){\n";
126 k << k.decl<value_type>(
"Areg") <<
"= Asub[k][get_local_id(0) + wm * numWorkers];\n";
127 k <<
" for (ulong wn=0; wn<BLOCK_SIZE; wn++){\n";
128 k <<
" acc[wm][wn] += Areg * Breg[wn];\n";
// Second local-memory barrier.
// NOTE(review): presumably placed at the end of the tile loop so that no
// work-item overwrites Asub/Bsub while others still read them - the loop's
// closing braces are not visible in this listing.
134 k <<
" barrier(CLK_LOCAL_MEM_FENCE);\n";
// Write-back. maxCi/maxCj restrict the stores to the part of this tile that
// lies inside C; offTileCi/offTileCj are the tile's origin in C.
138 k <<
"const ulong maxCi = min(TILE_SIZE, M - get_group_id(0) * TILE_SIZE);\n";
139 k <<
"const ulong maxCj = min(TILE_SIZE, N - get_group_id(1) * TILE_SIZE);\n";
140 k <<
"const ulong offTileCi = TILE_SIZE * get_group_id(0);\n";
141 k <<
"const ulong offTileCj = TILE_SIZE * get_group_id(1);\n";
// i/j advance by numWorkers while wm/wn advance by one, mirroring the
// local_id + w * numWorkers mapping used when acc was filled.
142 k <<
"ulong wm = 0;\n";
143 k <<
"for (ulong i = get_local_id(0); i < maxCi; i += numWorkers, wm++){\n";
144 k <<
" ulong wn = 0;\n";
145 k <<
" for (ulong j =get_local_id(1); j < maxCj; j += numWorkers, wn++){\n";
146 k << C()(k.expr<cl_ulong>(
"(offTileCi + i)"), k.expr<cl_ulong>(
"(offTileCj + j)")) <<
"+= alpha * acc[wm][wn];\n";
// Compile the generated source for the context of C's queue, passing the
// tiling defines.
150 boost::compute::kernel kernel = k.compile(C().queue().get_context(), options);
// Bind the scalar arguments: M and N come from C, K from A's column count.
153 kernel.set_arg(M_index, C().size1());
154 kernel.set_arg(N_index, C().size2());
155 kernel.set_arg(K_index, A().size2());
156 kernel.set_arg(alpha_index, alpha);
// Launch geometry: one NUM_WORKERS x NUM_WORKERS work-group per
// TILE_SIZE x TILE_SIZE tile of C, with the tile count rounded up in each
// dimension.
158 std::size_t global_work_size[2] = {
159 (C().size1()+TILE_SIZE-1)/ TILE_SIZE * NUM_WORKERS,
160 (C().size2()+TILE_SIZE-1)/ TILE_SIZE * NUM_WORKERS
162 std::size_t local_work_size[2] = {NUM_WORKERS, NUM_WORKERS};
// Enqueue as a 2-D range with no global offset (nullptr).
163 C().queue().enqueue_nd_range_kernel(kernel, 2,
nullptr, global_work_size, local_work_size);