ViennaCL - The Vienna Computing Library  1.5.1
viennacl/scheduler/execute.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_SCHEDULER_EXECUTE_HPP
00002 #define VIENNACL_SCHEDULER_EXECUTE_HPP
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2014, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008    Portions of this software are copyright by UChicago Argonne, LLC.
00009 
00010                             -----------------
00011                   ViennaCL - The Vienna Computing Library
00012                             -----------------
00013 
00014    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00015 
00016    (A list of authors and contributors can be found in the PDF manual)
00017 
00018    License:         MIT (X11), see file LICENSE in the base directory
00019 ============================================================================= */
00020 
00021 
00026 #include "viennacl/forwards.h"
00027 #include "viennacl/scheduler/forwards.h"
00028 
00029 #include "viennacl/scheduler/execute_scalar_assign.hpp"
00030 #include "viennacl/scheduler/execute_axbx.hpp"
00031 #include "viennacl/scheduler/execute_elementwise.hpp"
00032 #include "viennacl/scheduler/execute_matrix_prod.hpp"
00033 
00034 namespace viennacl
00035 {
00036   namespace scheduler
00037   {
00038     namespace detail
00039     {
00041       void execute_composite(statement const & s, statement_node const & root_node)
00042       {
00043         statement::container_type const & expr = s.array();
00044 
00045         statement_node const & leaf = expr[root_node.rhs.node_index];
00046 
00047         if (leaf.op.type  == OPERATION_BINARY_ADD_TYPE || leaf.op.type  == OPERATION_BINARY_SUB_TYPE) // x = (y) +- (z)  where y and z are either data objects or expressions
00048         {
00049           execute_axbx(s, root_node);
00050         }
00051         else if (leaf.op.type == OPERATION_BINARY_MULT_TYPE || leaf.op.type == OPERATION_BINARY_DIV_TYPE) // x = (y) * / alpha;
00052         {
00053           bool scalar_is_temporary = (leaf.rhs.type_family != SCALAR_TYPE_FAMILY);
00054 
00055           statement_node scalar_temp_node;
00056           if (scalar_is_temporary)
00057           {
00058             lhs_rhs_element temp;
00059             temp.type_family  = SCALAR_TYPE_FAMILY;
00060             temp.subtype      = DEVICE_SCALAR_TYPE;
00061             temp.numeric_type = root_node.lhs.numeric_type;
00062             detail::new_element(scalar_temp_node.lhs, temp);
00063 
00064             scalar_temp_node.op.type_family = OPERATION_BINARY_TYPE_FAMILY;
00065             scalar_temp_node.op.type        = OPERATION_BINARY_ASSIGN_TYPE;
00066 
00067             scalar_temp_node.rhs.type_family  = COMPOSITE_OPERATION_FAMILY;
00068             scalar_temp_node.rhs.subtype      = INVALID_SUBTYPE;
00069             scalar_temp_node.rhs.numeric_type = INVALID_NUMERIC_TYPE;
00070             scalar_temp_node.rhs.node_index   = leaf.rhs.node_index;
00071 
00072             // work on subexpression:
00073             // TODO: Catch exception, free temporary, then rethrow
00074             execute_composite(s, scalar_temp_node);
00075           }
00076 
00077           if (leaf.lhs.type_family == COMPOSITE_OPERATION_FAMILY)  //(y) is an expression, so introduce a temporary z = (y):
00078           {
00079             statement_node new_root_y;
00080 
00081             new_root_y.lhs.type_family  = root_node.lhs.type_family;
00082             new_root_y.lhs.subtype      = root_node.lhs.subtype;
00083             new_root_y.lhs.numeric_type = root_node.lhs.numeric_type;
00084             detail::new_element(new_root_y.lhs, root_node.lhs);
00085 
00086             new_root_y.op.type_family = OPERATION_BINARY_TYPE_FAMILY;
00087             new_root_y.op.type        = OPERATION_BINARY_ASSIGN_TYPE;
00088 
00089             new_root_y.rhs.type_family  = COMPOSITE_OPERATION_FAMILY;
00090             new_root_y.rhs.subtype      = INVALID_SUBTYPE;
00091             new_root_y.rhs.numeric_type = INVALID_NUMERIC_TYPE;
00092             new_root_y.rhs.node_index   = leaf.lhs.node_index;
00093 
00094             // work on subexpression:
00095             // TODO: Catch exception, free temporary, then rethrow
00096             execute_composite(s, new_root_y);
00097 
00098             // now compute x = z * / alpha:
00099             lhs_rhs_element u = root_node.lhs;
00100             lhs_rhs_element v = new_root_y.lhs;
00101             lhs_rhs_element alpha = scalar_is_temporary ? scalar_temp_node.lhs : leaf.rhs;
00102 
00103             bool is_division = (leaf.op.type  == OPERATION_BINARY_DIV_TYPE);
00104             switch (root_node.op.type)
00105             {
00106               case OPERATION_BINARY_ASSIGN_TYPE:
00107                 detail::ax(u,
00108                            v, alpha, 1, is_division, false);
00109                 break;
00110               case OPERATION_BINARY_INPLACE_ADD_TYPE:
00111                 detail::axbx(u,
00112                              u,   1.0, 1, false,       false,
00113                              v, alpha, 1, is_division, false);
00114                 break;
00115               case OPERATION_BINARY_INPLACE_SUB_TYPE:
00116                 detail::axbx(u,
00117                              u,   1.0, 1, false,       false,
00118                              v, alpha, 1, is_division, true);
00119                 break;
00120               default:
00121                 throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)");
00122             }
00123 
00124             detail::delete_element(new_root_y.lhs);
00125           }
00126           else if (leaf.lhs.type_family != COMPOSITE_OPERATION_FAMILY)
00127           {
00128             lhs_rhs_element u = root_node.lhs;
00129             lhs_rhs_element v = leaf.lhs;
00130             lhs_rhs_element alpha = scalar_is_temporary ? scalar_temp_node.lhs : leaf.rhs;
00131 
00132             bool is_division = (leaf.op.type  == OPERATION_BINARY_DIV_TYPE);
00133             switch (root_node.op.type)
00134             {
00135               case OPERATION_BINARY_ASSIGN_TYPE:
00136                 detail::ax(u,
00137                            v, alpha, 1, is_division, false);
00138                 break;
00139               case OPERATION_BINARY_INPLACE_ADD_TYPE:
00140                 detail::axbx(u,
00141                              u,   1.0, 1, false,       false,
00142                              v, alpha, 1, is_division, false);
00143                 break;
00144               case OPERATION_BINARY_INPLACE_SUB_TYPE:
00145                 detail::axbx(u,
00146                              u,   1.0, 1, false,       false,
00147                              v, alpha, 1, is_division, true);
00148                 break;
00149               default:
00150                 throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)");
00151             }
00152           }
00153           else
00154             throw statement_not_supported_exception("Unsupported binary operator for OPERATION_BINARY_MULT_TYPE || OPERATION_BINARY_DIV_TYPE on leaf node.");
00155 
00156           // clean up
00157           if (scalar_is_temporary)
00158             detail::delete_element(scalar_temp_node.lhs);
00159         }
00160         else if (   leaf.op.type == OPERATION_BINARY_INNER_PROD_TYPE
00161                  || leaf.op.type == OPERATION_UNARY_NORM_1_TYPE
00162                  || leaf.op.type == OPERATION_UNARY_NORM_2_TYPE
00163                  || leaf.op.type == OPERATION_UNARY_NORM_INF_TYPE)
00164         {
00165           execute_scalar_assign_composite(s, root_node);
00166         }
00167         else if (   (leaf.op.type_family == OPERATION_UNARY_TYPE_FAMILY && leaf.op.type != OPERATION_UNARY_TRANS_TYPE)
00168                  || leaf.op.type == OPERATION_BINARY_ELEMENT_PROD_TYPE
00169                  || leaf.op.type == OPERATION_BINARY_ELEMENT_DIV_TYPE) // element-wise operations
00170         {
00171           execute_element_composite(s, root_node);
00172         }
00173         else if (   leaf.op.type == OPERATION_BINARY_MAT_VEC_PROD_TYPE
00174                  || leaf.op.type == OPERATION_BINARY_MAT_MAT_PROD_TYPE)
00175         {
00176           execute_matrix_prod(s, root_node);
00177         }
00178         else
00179           throw statement_not_supported_exception("Unsupported binary operator");
00180       }
00181 
00182 
00184       inline void execute_single(statement const &, statement_node const & root_node)
00185       {
00186         lhs_rhs_element u = root_node.lhs;
00187         lhs_rhs_element v = root_node.rhs;
00188         switch (root_node.op.type)
00189         {
00190           case OPERATION_BINARY_ASSIGN_TYPE:
00191             detail::ax(u,
00192                        v, 1.0, 1, false, false);
00193             break;
00194           case OPERATION_BINARY_INPLACE_ADD_TYPE:
00195             detail::axbx(u,
00196                          u, 1.0, 1, false, false,
00197                          v, 1.0, 1, false, false);
00198             break;
00199           case OPERATION_BINARY_INPLACE_SUB_TYPE:
00200             detail::axbx(u,
00201                          u, 1.0, 1, false, false,
00202                          v, 1.0, 1, false, true);
00203             break;
00204           default:
00205             throw statement_not_supported_exception("Unsupported binary operator for operation in root note (should be =, +=, or -=)");
00206         }
00207 
00208       }
00209 
00210 
00211       inline void execute_impl(statement const & s, statement_node const & root_node)
00212       {
00213         if (   root_node.lhs.type_family != SCALAR_TYPE_FAMILY
00214             && root_node.lhs.type_family != VECTOR_TYPE_FAMILY
00215             && root_node.lhs.type_family != MATRIX_TYPE_FAMILY)
00216           throw statement_not_supported_exception("Unsupported lvalue encountered in head node.");
00217 
00218         switch (root_node.rhs.type_family)
00219         {
00220           case COMPOSITE_OPERATION_FAMILY:
00221             execute_composite(s, root_node);
00222             break;
00223           case SCALAR_TYPE_FAMILY:
00224           case VECTOR_TYPE_FAMILY:
00225           case MATRIX_TYPE_FAMILY:
00226             execute_single(s, root_node);
00227             break;
00228           default:
00229             throw statement_not_supported_exception("Invalid rvalue encountered in vector assignment");
00230         }
00231 
00232       }
00233     }
00234 
00235     inline void execute(statement const & s)
00236     {
00237       // simply start execution from the root node:
00238       detail::execute_impl(s, s.array()[s.root()]);
00239     }
00240 
00241 
00242   }
00243 
00244 } //namespace viennacl
00245 
00246 #endif
00247