ViennaCL - The Vienna Computing Library  1.5.1
viennacl/generator/set_arguments_functor.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_GENERATOR_ENQUEUE_TREE_HPP
00002 #define VIENNACL_GENERATOR_ENQUEUE_TREE_HPP
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2014, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008    Portions of this software are copyright by UChicago Argonne, LLC.
00009 
00010                             -----------------
00011                   ViennaCL - The Vienna Computing Library
00012                             -----------------
00013 
00014    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00015 
00016    (A list of authors and contributors can be found in the PDF manual)
00017 
00018    License:         MIT (X11), see file LICENSE in the base directory
00019 ============================================================================= */
00020 
00021 
00026 #include <set>
00027 
00028 #include "viennacl/matrix.hpp"
00029 #include "viennacl/vector.hpp"
00030 
00031 #include "viennacl/forwards.h"
00032 #include "viennacl/scheduler/forwards.h"
00033 #include "viennacl/generator/forwards.h"
00034 
00035 #include "viennacl/meta/result_of.hpp"
00036 
00037 #include "viennacl/tools/shared_ptr.hpp"
00038 
00039 #include "viennacl/ocl/kernel.hpp"
00040 
00041 #include "viennacl/generator/helpers.hpp"
00042 #include "viennacl/generator/utils.hpp"
00043 #include "viennacl/generator/mapped_objects.hpp"
00044 
00045 
00046 namespace viennacl{
00047 
00048   namespace generator{
00049 
00050     namespace detail{
00051 
00053       class set_arguments_functor : public traversal_functor{
00054         public:
00055           typedef void result_type;
00056 
00057           set_arguments_functor(std::set<void *> & memory, unsigned int & current_arg, viennacl::ocl::kernel & kernel) : memory_(memory), current_arg_(current_arg), kernel_(kernel){ }
00058 
00059           template<class ScalarType>
00060           result_type operator()(ScalarType const & scal) const {
00061             typedef typename viennacl::result_of::cl_type<ScalarType>::type cl_scalartype;
00062             kernel_.arg(current_arg_++, cl_scalartype(scal));
00063           }
00064 
00066           template<class ScalarType>
00067           result_type operator()(scalar<ScalarType> const & scal) const {
00068             if(memory_.insert((void*)&scal).second)
00069               kernel_.arg(current_arg_++, scal.handle().opencl_handle());
00070           }
00071 
00073           template<class ScalarType>
00074           result_type operator()(vector_base<ScalarType> const & vec) const {
00075             if(memory_.insert((void*)&vec).second){
00076               kernel_.arg(current_arg_++, vec.handle().opencl_handle());
00077               if(viennacl::traits::start(vec)>0)
00078                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start(vec)));
00079               if(vec.stride()>1)
00080                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride(vec)));
00081             }
00082           }
00083 
00085           template<class ScalarType>
00086           result_type operator()(implicit_vector_base<ScalarType> const & vec) const {
00087             typedef typename viennacl::result_of::cl_type<ScalarType>::type cl_scalartype;
00088             if(memory_.insert((void*)&vec).second){
00089               if(vec.is_value_static()==false)
00090                 kernel_.arg(current_arg_++, cl_scalartype(vec.value()));
00091               if(vec.has_index())
00092                 kernel_.arg(current_arg_++, cl_uint(vec.index()));
00093             }
00094           }
00095 
00097           template<class ScalarType, class Layout>
00098           result_type operator()(matrix_base<ScalarType, Layout> const & mat) const {
00099             //typedef typename matrix_base<ScalarType, Layout>::size_type size_type;
00100             if(memory_.insert((void*)&mat).second){
00101               kernel_.arg(current_arg_++, mat.handle().opencl_handle());
00102               if(viennacl::traits::start1(mat)>0)
00103                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start1(mat)));
00104               if(viennacl::traits::stride1(mat)>1)
00105                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride1(mat)));
00106               if(viennacl::traits::start2(mat)>0)
00107                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start2(mat)));
00108               if(viennacl::traits::stride2(mat)>1)
00109                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride2(mat)));
00110             }
00111           }
00112 
00114           template<class ScalarType>
00115           result_type operator()(implicit_matrix_base<ScalarType> const & mat) const {
00116             if(mat.is_value_static()==false)
00117               kernel_.arg(current_arg_++, mat.value());
00118           }
00119 
00121           void operator()(scheduler::statement const * /*statement*/, scheduler::statement_node const * root_node, detail::node_type node_type) const {
00122             if(node_type==LHS_NODE_TYPE && root_node->lhs.type_family != scheduler::COMPOSITE_OPERATION_FAMILY)
00123               utils::call_on_element(root_node->lhs, *this);
00124             else if(node_type==RHS_NODE_TYPE && root_node->rhs.type_family != scheduler::COMPOSITE_OPERATION_FAMILY)
00125               utils::call_on_element(root_node->rhs, *this);
00126           }
00127 
00128         private:
00129           std::set<void *> & memory_;
00130           unsigned int & current_arg_;
00131           viennacl::ocl::kernel & kernel_;
00132       };
00133 
00134     }
00135 
00136   }
00137 
00138 }
00139 #endif