// Conv2DModel: a 2D convolution layer (a bank of filters plus one offset per filter)
// followed by an element-wise activation function.
// NOTE(review): this excerpt is a lossy extract of the header. Original line numbers are
// fused into the code, and many lines (including the class declaration and closing
// braces) are missing, so it does not compile as shown.
32 #ifndef SHARK_MODELS_CONV2DModel_H 33 #define SHARK_MODELS_CONV2DModel_H 66 template <
/// \tparam VectorType  vector type the model operates on (default RealVector).
/// \tparam ActivationFunction  element-wise activation applied after the convolution
///         (default LinearNeuron). It is used through its State type, evalInPlace and
///         multiplyDerivative.
class VectorType = RealVector,
class ActivationFunction = LinearNeuron>
// Compile-time guard meant to reject vector types this layer does not support
// (per the message: sparse inputs).
// NOTE(review): as written, the condition is false only when storage_tag is *exactly*
// blas::dense_tag. It therefore rejects that one tag and accepts every other tag,
// including sparse ones, which contradicts the message. The intent looks like
// std::is_base_of<blas::dense_tag, storage_tag>::value. Confirm against the blas tag
// hierarchy before changing.
72 static_assert(!std::is_same<typename VectorType::storage_type::storage_tag, blas::dense_tag>::value,
"Conv2D not implemented for sparse inputs");
// Presumably constructor bodies (their signatures are not visible in this excerpt).
// Both declare that the model can compute first derivatives with respect to its
// parameters and with respect to its inputs.
80 base_type::m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
81 base_type::m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
// This variant also configures the layer geometry right away.
92 base_type::m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
93 base_type::m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
94 setStructure(imageShape, filterShape, type);
/// Name of the model type.
98 {
return "Conv2DModel"; }
// Presumably the body of inputShape(): (image height, image width, number of channels).
102 return {m_imageHeight, m_imageWidth, m_numChannels};
// Presumably the body of outputShape(): one output plane per filter.
// The first form keeps the input's spatial size. The second is the unpadded size
// (image - filter + 1 in each dimension). The condition that selects between them is
// on lines missing from this excerpt.
107 return {m_imageHeight, m_imageWidth, m_numFilters};
109 return {m_imageHeight - m_filterHeight + 1, m_imageWidth - m_filterWidth + 1, m_numFilters};
// Parameter vector: all filter weights, then the per-filter offsets.
// (Assumes `|` concatenates the two vectors. That matches the split done when
// parameters are set.)
125 return m_filters | m_offset;
// setParameterVector body: the first m_filters.size() entries are the filter weights
// and the remainder are the per-filter offsets. m_backpropFilters is derived from
// m_filters, so it is rebuilt after every update.
130 SIZE_CHECK(newParameters.size() == numberOfParameters());
131 noalias(m_filters) = subrange(newParameters,0,m_filters.size());
132 noalias(m_offset) = subrange(newParameters,m_filters.size(),newParameters.size());
133 updateBackpropFilters();
// numberOfParameters body: all filter weights plus one offset per filter.
138 return m_filters.size() + m_offset.size();
// setStructure body (signature not visible in this excerpt): configures the layer geometry.
// imageShape is read as (height, width, channels).
150 m_imageHeight = imageShape[0];
151 m_imageWidth = imageShape[1];
152 m_numChannels = imageShape[2];
// filterShape is read as (number of filters, filter height, filter width).
153 m_numFilters = filterShape[0];
154 m_filterHeight = filterShape[1];
155 m_filterWidth = filterShape[2];
// One filterHeight x filterWidth kernel per (filter, input channel) pair,
// plus one offset per filter.
156 m_filters.resize(m_filterHeight * m_filterWidth * m_numFilters * m_numChannels);
157 m_offset.resize(m_numFilters);
// Keep the derived back-propagation filter bank in sync with the new geometry.
158 updateBackpropFilters();
// createState body: the model's evaluation state is the activation function's state.
162 return boost::shared_ptr<State>(
new typename ActivationFunction::State());
// Keep the base-class eval overloads visible next to the batched overload declared here.
165 using base_type::eval;
/// \brief Evaluates the layer on a batch.
///
/// Each row of \a inputs is one flattened input image. Each row of \a outputs receives
/// the flattened responses of all filters, plus the per-filter offset, passed through
/// the activation function. \a state is handed to the activation function.
168 void eval(BatchInputType
const& inputs, BatchOutputType& outputs,
State& state)
const{
169 SIZE_CHECK(inputs.size2() == inputShape().numElements());
170 outputs.resize(inputs.size1(),outputShape().numElements());
// Number of output values produced by a single filter (one output plane).
172 std::size_t outputsForFilter = outputShape().numElements()/m_numFilters;
// NOTE(review): paddingHeight/paddingWidth are defined on lines missing from this excerpt.
175 for(std::size_t i = 0; i != inputs.size1(); ++i){
176 auto output = row(outputs,i);
// Convolve one image: m_numChannels input planes -> m_numFilters output planes.
177 blas::kernels::conv2d(row(inputs,i), m_filters, output,
178 m_numChannels, m_numFilters,
179 m_imageHeight, m_imageWidth,
180 m_filterHeight, m_filterWidth,
181 paddingHeight, paddingWidth
// View the output row as (filters x outputsForFilter). The right-hand side has the same
// shape, so filter f's offset is added to every value in filter f's plane.
184 noalias(to_matrix(output, m_numFilters, outputsForFilter) ) += trans(blas::repeat(m_offset,outputsForFilter));
// Apply the activation element-wise in place, passing the activation's own state object.
186 m_activation.evalInPlace(outputs, state.
toState<
typename ActivationFunction::State>());
// Parameter-gradient routine. Presumably weightedParameterDerivative; its name and one
// argument (line 194, likely the state) are missing from this excerpt.
// It back-propagates \a coefficients through the activation to get delta, then
// accumulates the gradient over the batch:
//  - filter weights go into the first m_filters.size() entries of \a gradient;
//  - per-filter offsets go into the remaining entries.
191 BatchInputType
const& inputs,
192 BatchOutputType
const& outputs,
193 BatchOutputType
const& coefficients,
195 ParameterVectorType& gradient
197 SIZE_CHECK(coefficients.size2()==outputShape().numElements());
198 SIZE_CHECK(coefficients.size1()==inputs.size1());
// delta = coefficients multiplied by the activation derivative (computed from outputs and state).
200 BatchOutputType delta = coefficients;
201 m_activation.multiplyDerivative(outputs,delta, state.
toState<
typename ActivationFunction::State>());
203 gradient.resize(numberOfParameters());
// NOTE(review): the updates below accumulate with +=, so gradient must be zero here.
// No clear() is visible in this excerpt (line 204 is missing); confirm.
// Views into gradient: filters as (filters x weights-per-filter), offsets as a vector.
205 auto weightGradient = to_matrix(subrange(gradient,0,m_filters.size()), m_numFilters, m_filters.size()/m_numFilters);
206 auto offsetGradient = subrange(gradient, m_filters.size(),gradient.size());
// Scratch matrix: one row per output position of a filter, one column per filter weight.
208 BatchInputType patches(outputShape().numElements()/m_numFilters, m_filters.size()/m_numFilters);
209 for(std::size_t i = 0; i != inputs.size1(); ++i){
// Unroll the input image into patches. Either the unpadded (im2mat) or the padded
// (im2mat_pad) variant is used; the selecting condition is on lines missing here.
211 blas::bindings::im2mat(
212 row(inputs,i),patches,
214 m_imageHeight, m_imageWidth,
215 m_filterHeight, m_filterWidth
218 blas::bindings::im2mat_pad(
219 row(inputs,i),patches,
221 m_imageHeight, m_imageWidth,
222 m_filterHeight, m_filterWidth,
223 m_filterHeight - 1, m_filterWidth - 1
// delta for sample i, viewed as (filters x output positions).
226 auto delta_mat = to_matrix(row(delta,i), m_numFilters, patches.size1());
// (filters x positions) combined with (positions x weights) gives the
// (filters x weights) shape of weightGradient.
227 noalias(weightGradient) += delta_mat % patches;
// Offset gradient: one value per filter, accumulated from that filter's row of delta.
228 noalias(offsetGradient) += sum_columns(delta_mat);
// Input-gradient routine. Presumably weightedInputDerivative; its name and one argument
// (line 238, likely the state) are missing from this excerpt.
// It back-propagates \a coefficients through the activation, then maps delta back to
// input space by convolving it with m_backpropFilters, the flipped, channel-regrouped
// kernels built by updateBackpropFilters: m_numFilters planes in, m_numChannels planes out.
235 BatchInputType
const & inputs,
236 BatchOutputType
const& outputs,
237 BatchOutputType
const & coefficients,
239 BatchInputType& derivatives
241 SIZE_CHECK(coefficients.size2() == outputShape().numElements());
242 SIZE_CHECK(coefficients.size1() == inputs.size1());
// delta = coefficients multiplied by the activation derivative (computed from outputs and state).
244 BatchOutputType delta = coefficients;
245 m_activation.multiplyDerivative(outputs,delta, state.
toState<
typename ActivationFunction::State>());
// The delta "image" has the output's spatial size. NOTE(review): shape is presumably
// passed as the image size on line 259, which is missing from this excerpt.
246 Shape shape = outputShape();
// Padding of (filter size - 1) per dimension. NOTE(review): lines 249-252 are missing
// and may adjust these values (e.g. per padding mode); confirm.
247 std::size_t paddingHeight = m_filterHeight - 1;
248 std::size_t paddingWidth = m_filterWidth - 1;
253 derivatives.resize(inputs.size1(),inputShape().numElements());
255 for(std::size_t i = 0; i != inputs.size1(); ++i){
256 auto derivative = row(derivatives,i);
// Channel counts are swapped relative to eval: filters -> input channels.
257 blas::kernels::conv2d(row(delta,i), m_backpropFilters, derivative,
258 m_numFilters, m_numChannels,
260 m_filterHeight, m_filterWidth,
261 paddingHeight, paddingWidth
// read body: restores the model in the same order write() emits it.
// NOTE(review): line 269 (presumably m_offset) and line 276 are missing from this excerpt.
268 archive >> m_filters;
270 archive >> m_imageHeight;
271 archive >> m_imageWidth;
272 archive >> m_filterHeight;
273 archive >> m_filterWidth;
274 archive >> m_numChannels;
275 archive >> m_numFilters;
// m_backpropFilters is recomputed from the loaded filters rather than read from the archive.
277 updateBackpropFilters();
// write body: stores the filters, then the geometry. The field order must match read().
// NOTE(review): line 282 (presumably m_offset) and line 289 are missing from this excerpt.
281 archive << m_filters;
283 archive << m_imageHeight;
284 archive << m_imageWidth;
285 archive << m_filterHeight;
286 archive << m_filterWidth;
287 archive << m_numChannels;
288 archive << m_numFilters;
/// \brief Rebuilds m_backpropFilters from m_filters.
///
/// m_filters is viewed as a (numFilters x numChannels*fh*fw) matrix: for each filter,
/// one fh*fw kernel per input channel, stored one after another. For every input
/// channel c, the c-th kernel of every filter is copied into block c of
/// m_backpropFilters, viewed as a (numFilters x fh*fw) matrix. The flattened kernel
/// index is reversed during the copy, which flips each kernel in both spatial
/// dimensions. The result is the filter bank used by the input-derivative convolution
/// (numFilters planes in, numChannels planes out).
/// Must be re-run whenever m_filters or the geometry changes.
299 void updateBackpropFilters(){
300 m_backpropFilters.resize(m_filters.size());
// Weights in one single-channel kernel.
302 std::size_t filterImSize = m_filterWidth * m_filterHeight;
// Weights of one forward filter (all input channels).
303 std::size_t filterSize = m_numChannels * m_filterWidth * m_filterHeight;
// Weights of one back-propagation block (all forward filters for one channel).
304 std::size_t bpFilterSize = m_numFilters * m_filterWidth * m_filterHeight;
307 for(std::size_t c = 0; c != m_numChannels; ++c){
// Column block c: the channel-c kernel of every forward filter.
308 auto channel_mat = subrange(
309 to_matrix(m_filters, m_numFilters, filterSize),
310 0, m_numFilters, c * filterImSize, (c+1) * filterImSize
// NOTE(review): lines 311-316 are missing from this excerpt.
317 auto target_vec = subrange(m_backpropFilters, c * bpFilterSize, (c+1) * bpFilterSize);
318 auto target_mat = to_matrix(target_vec,m_numFilters, m_filterWidth * m_filterHeight);
319 for(std::size_t f = 0; f != m_numFilters; ++f){
320 for(std::size_t i = 0; i != m_filterWidth * m_filterHeight; ++i){
// Reverse the flattened index: this flips the kernel vertically and horizontally.
321 target_mat(f,i) = channel_mat(f, m_filterWidth * m_filterHeight-i-1);
// Activation applied element-wise after the convolution and offset.
// NOTE(review): the declarations of m_filters, m_offset and m_backpropFilters are on
// lines missing from this excerpt.
329 ActivationFunction m_activation;
// Layer geometry. Set by setStructure, serialized by read/write.
331 std::size_t m_imageHeight;
332 std::size_t m_imageWidth;
333 std::size_t m_filterHeight;
334 std::size_t m_filterWidth;
335 std::size_t m_numChannels;
336 std::size_t m_numFilters;