32 #ifndef SHARK_DATA_BATCHINTERFACE_H 33 #define SHARK_DATA_BATCHINTERFACE_H 38 #include <boost/utility/enable_if.hpp> 39 #include <boost/mpl/if.hpp> 40 #include <type_traits> 47 template<
class BatchType>
50 typedef BatchType type;
52 typedef typename type::reference reference;
54 typedef typename type::const_reference const_reference;
58 typedef typename type::value_type value_type;
61 typedef typename type::iterator iterator;
63 typedef typename type::const_iterator const_iterator;
67 static type
createBatch(value_type
const& input, std::size_t size = 1){
68 return type(size,input);
72 template<
class Iterator>
73 static type createBatchFromRange(Iterator
const& begin, Iterator
const& end){
75 typename type::reference c=batch[0];
77 std::copy(begin,end,batch.begin());
82 static void resize(T& batch, std::size_t
batchSize, std::size_t elements){
83 batch.resize(batchSize);
94 static std::size_t size(T
const& batch){
return batch.size();}
97 static typename T::reference
get(T& batch, std::size_t i){
101 static typename T::const_reference
get(T
const& batch, std::size_t i){
105 static typename T::iterator begin(T& batch){
106 return batch.begin();
109 static typename T::const_iterator begin(T
const& batch){
110 return batch.begin();
113 static typename T::iterator end(T& batch){
117 static typename T::const_iterator end(T
const& batch){
126 template<
class Matrix>
127 class MatrixRowReference:
public blas::temporary_proxy<blas::matrix_row<Matrix> >{
129 typedef blas::temporary_proxy<blas::matrix_row<Matrix> > base_type;
131 typedef typename blas::vector_temporary<Matrix>::type Vector;
133 MatrixRowReference( Matrix& matrix, std::size_t i)
134 :base_type(blas::matrix_row<Matrix>(matrix,i)){}
136 MatrixRowReference(T
const& matrixrow)
137 :base_type(blas::matrix_row<Matrix>(matrixrow.expression(),matrixrow.index())){}
140 const MatrixRowReference& operator=(
const T& argument){
141 static_cast<base_type&
>(*this)=argument;
146 return Vector(*
this);
151 void swap(MatrixRowReference<M> ref1, MatrixRowReference<M> ref2){
152 swap_rows(ref1.expression().expression(),ref1.index(),ref2.expression().expression(),ref2.index());
155 template<
class M1,
class M2>
156 void swap(MatrixRowReference<M1> ref1, MatrixRowReference<M2> ref2){
157 swap_rows(ref1.expression().expression(),ref1.index(),ref2.expression().expression(),ref2.index());
160 template<
class Matrix>
163 typedef typename blas::matrix_temporary<Matrix>::type type;
166 typedef typename blas::vector_temporary<Matrix>::type value_type;
170 typedef detail::MatrixRowReference<Matrix> reference;
172 typedef detail::MatrixRowReference<const Matrix> const_reference;
176 typedef ProxyIterator<Matrix, value_type, reference > iterator;
178 typedef ProxyIterator<const Matrix, value_type, const_reference > const_iterator;
181 template<
class Element>
182 static type
createBatch(Element
const& input, std::size_t size = 1){
183 return type(size,input.size());
186 template<
class Iterator>
187 static type createBatchFromRange(Iterator
const& pos, Iterator
const& end){
188 type batch(end - pos,pos->size());
189 std::copy(pos,end,begin(batch));
194 static void resize(Matrix& batch, std::size_t batchSize, std::size_t elements){
195 ensure_size(batch,batchSize,elements);
/// \brief Returns the number of elements stored in the batch.
///
/// Each element of a matrix batch is one row of the matrix (get() wraps row i,
/// and end() is positioned at size1()), so the element count is the row count.
198 static std::size_t size(Matrix
const& batch){
return batch.size1();}
199 static reference
get( Matrix& batch, std::size_t i){
200 return reference(batch,i);
202 static const_reference
get( Matrix
const& batch, std::size_t i){
203 return const_reference(batch,i);
206 static iterator begin(Matrix& batch){
207 return iterator(batch,0);
209 static const_iterator begin(Matrix
const& batch){
210 return const_iterator(batch,0);
213 static iterator end(Matrix& batch){
214 return iterator(batch,batch.size1());
216 static const_iterator end(Matrix
const& batch){
217 return const_iterator(batch,batch.size1());
234 :
public std::conditional<
235 std::is_arithmetic<T>::value,
236 detail::SimpleBatch<blas::vector<T> >,
237 detail::SimpleBatch<std::vector<T> >
242 struct Batch<blas::vector<T> >:
public detail::VectorBatch<blas::matrix<T> >{};
248 typedef shark::blas::compressed_matrix<T> type;
251 typedef shark::blas::compressed_vector<T> value_type;
266 template<
class Element>
267 static type
createBatch(Element
const& input, std::size_t size = 1){
268 return type(size,input.size());
271 template<
class Iterator>
274 std::size_t nonzeros = 0;
275 for(Iterator pos = start; pos != end; ++pos){
276 nonzeros += pos->nnz();
279 type batch(end - start,start->size(),nonzeros);
280 std::copy(start,end,begin(batch));
285 static void resize(type& batch, std::size_t batchSize, std::size_t elements){
286 ensure_size(batch,batchSize,elements);
/// \brief Returns the number of elements stored in the sparse batch.
///
/// The batch is a compressed matrix with one element per row (a batch built
/// from a range has end - start rows), so the element count is size1().
289 static std::size_t
size(type
const& batch){
return batch.size1();}
290 static reference
get( type& batch, std::size_t i){
291 return reference(batch,i);
293 static const_reference
get( type
const& batch, std::size_t i){
294 return const_reference(batch,i);
298 return iterator(batch,0);
300 static const_iterator
begin(type
const& batch){
301 return const_iterator(batch,0);
304 static iterator
end(type& batch){
305 return iterator(batch,batch.size1());
307 static const_iterator
end(type
const& batch){
308 return const_iterator(batch,batch.size1());
313 struct Batch<detail::MatrixRowReference<M> >
314 :
public Batch<typename detail::MatrixRowReference<M>::Vector>{};
317 template<
class BatchType>
331 struct BatchTraits<blas::dense_matrix_adaptor<T, blas::row_major> >{
332 typedef detail::VectorBatch<blas::dense_matrix_adaptor<T, blas::row_major> > type;
337 struct batch_to_element{
341 struct batch_to_element<T&>{
346 struct batch_to_element<T const&>{
352 struct batch_to_reference{
356 struct batch_to_reference<T&>{
360 struct batch_to_reference<T const&>{
365 struct element_to_batch{
369 struct element_to_batch<T&>{
373 struct element_to_batch<T const&>{
377 struct element_to_batch<detail::MatrixRowReference<M> >{
381 struct element_to_batch<detail::MatrixRowReference<M const> >{
388 template<
class T,
class Range>
394 template<
class Range>
399 template<
class T,
class Iterator>
404 template<
class BatchT>
409 template<
class BatchT>
414 template<
class BatchT>
419 template<
class BatchT>
424 template<
class BatchT>
429 template<
class BatchT>
434 template<
class BatchT>