ViennaCL - The Vienna Computing Library
1.5.1
|
00001 #ifndef VIENNACL_SCHEDULER_EXECUTE_HPP 00002 #define VIENNACL_SCHEDULER_EXECUTE_HPP 00003 00004 /* ========================================================================= 00005 Copyright (c) 2010-2014, Institute for Microelectronics, 00006 Institute for Analysis and Scientific Computing, 00007 TU Wien. 00008 Portions of this software are copyright by UChicago Argonne, LLC. 00009 00010 ----------------- 00011 ViennaCL - The Vienna Computing Library 00012 ----------------- 00013 00014 Project Head: Karl Rupp rupp@iue.tuwien.ac.at 00015 00016 (A list of authors and contributors can be found in the PDF manual) 00017 00018 License: MIT (X11), see file LICENSE in the base directory 00019 ============================================================================= */ 00020 00021 00026 #include "viennacl/forwards.h" 00027 #include "viennacl/scheduler/forwards.h" 00028 00029 #include "viennacl/scheduler/execute_scalar_assign.hpp" 00030 #include "viennacl/scheduler/execute_axbx.hpp" 00031 #include "viennacl/scheduler/execute_elementwise.hpp" 00032 #include "viennacl/scheduler/execute_matrix_prod.hpp" 00033 00034 namespace viennacl 00035 { 00036 namespace scheduler 00037 { 00038 namespace detail 00039 { 00041 void execute_composite(statement const & s, statement_node const & root_node) 00042 { 00043 statement::container_type const & expr = s.array(); 00044 00045 statement_node const & leaf = expr[root_node.rhs.node_index]; 00046 00047 if (leaf.op.type == OPERATION_BINARY_ADD_TYPE || leaf.op.type == OPERATION_BINARY_SUB_TYPE) // x = (y) +- (z) where y and z are either data objects or expressions 00048 { 00049 execute_axbx(s, root_node); 00050 } 00051 else if (leaf.op.type == OPERATION_BINARY_MULT_TYPE || leaf.op.type == OPERATION_BINARY_DIV_TYPE) // x = (y) * / alpha; 00052 { 00053 bool scalar_is_temporary = (leaf.rhs.type_family != SCALAR_TYPE_FAMILY); 00054 00055 statement_node scalar_temp_node; 00056 if (scalar_is_temporary) 00057 { 00058 lhs_rhs_element temp; 00059 temp.type_family = SCALAR_TYPE_FAMILY; 00060 temp.subtype = DEVICE_SCALAR_TYPE; 00061 temp.numeric_type = root_node.lhs.numeric_type; 00062 detail::new_element(scalar_temp_node.lhs, temp); 00063 00064 scalar_temp_node.op.type_family = OPERATION_BINARY_TYPE_FAMILY; 00065 scalar_temp_node.op.type = OPERATION_BINARY_ASSIGN_TYPE; 00066 00067 scalar_temp_node.rhs.type_family = COMPOSITE_OPERATION_FAMILY; 00068 scalar_temp_node.rhs.subtype = INVALID_SUBTYPE; 00069 scalar_temp_node.rhs.numeric_type = INVALID_NUMERIC_TYPE; 00070 scalar_temp_node.rhs.node_index = leaf.rhs.node_index; 00071 00072 // work on subexpression: 00073 // TODO: Catch exception, free temporary, then rethrow 00074 execute_composite(s, scalar_temp_node); 00075 } 00076 00077 if (leaf.lhs.type_family == COMPOSITE_OPERATION_FAMILY) //(y) is an expression, so introduce a temporary z = (y): 00078 { 00079 statement_node new_root_y; 00080 00081 new_root_y.lhs.type_family = root_node.lhs.type_family; 00082 new_root_y.lhs.subtype = root_node.lhs.subtype; 00083 new_root_y.lhs.numeric_type = root_node.lhs.numeric_type; 00084 detail::new_element(new_root_y.lhs, root_node.lhs); 00085 00086 new_root_y.op.type_family = OPERATION_BINARY_TYPE_FAMILY; 00087 new_root_y.op.type = OPERATION_BINARY_ASSIGN_TYPE; 00088 00089 new_root_y.rhs.type_family = COMPOSITE_OPERATION_FAMILY; 00090 new_root_y.rhs.subtype = INVALID_SUBTYPE; 00091 new_root_y.rhs.numeric_type = INVALID_NUMERIC_TYPE; 00092 new_root_y.rhs.node_index = leaf.lhs.node_index; 00093 00094 // work on subexpression: 00095 // TODO: Catch exception, free temporary, then rethrow 00096 execute_composite(s, new_root_y); 00097 00098 // now compute x = z * / alpha: 00099 lhs_rhs_element u = root_node.lhs; 00100 lhs_rhs_element v = new_root_y.lhs; 00101 lhs_rhs_element alpha = scalar_is_temporary ? scalar_temp_node.lhs : leaf.rhs; 00102 00103 bool is_division = (leaf.op.type == OPERATION_BINARY_DIV_TYPE); 00104 switch (root_node.op.type) 00105 { 00106 case OPERATION_BINARY_ASSIGN_TYPE: 00107 detail::ax(u, 00108 v, alpha, 1, is_division, false); 00109 break; 00110 case OPERATION_BINARY_INPLACE_ADD_TYPE: 00111 detail::axbx(u, 00112 u, 1.0, 1, false, false, 00113 v, alpha, 1, is_division, false); 00114 break; 00115 case OPERATION_BINARY_INPLACE_SUB_TYPE: 00116 detail::axbx(u, 00117 u, 1.0, 1, false, false, 00118 v, alpha, 1, is_division, true); 00119 break; 00120 default: 00121 throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)"); 00122 } 00123 00124 detail::delete_element(new_root_y.lhs); 00125 } 00126 else if (leaf.lhs.type_family != COMPOSITE_OPERATION_FAMILY) 00127 { 00128 lhs_rhs_element u = root_node.lhs; 00129 lhs_rhs_element v = leaf.lhs; 00130 lhs_rhs_element alpha = scalar_is_temporary ? scalar_temp_node.lhs : leaf.rhs; 00131 00132 bool is_division = (leaf.op.type == OPERATION_BINARY_DIV_TYPE); 00133 switch (root_node.op.type) 00134 { 00135 case OPERATION_BINARY_ASSIGN_TYPE: 00136 detail::ax(u, 00137 v, alpha, 1, is_division, false); 00138 break; 00139 case OPERATION_BINARY_INPLACE_ADD_TYPE: 00140 detail::axbx(u, 00141 u, 1.0, 1, false, false, 00142 v, alpha, 1, is_division, false); 00143 break; 00144 case OPERATION_BINARY_INPLACE_SUB_TYPE: 00145 detail::axbx(u, 00146 u, 1.0, 1, false, false, 00147 v, alpha, 1, is_division, true); 00148 break; 00149 default: 00150 throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)"); 00151 } 00152 } 00153 else 00154 throw statement_not_supported_exception("Unsupported binary operator for OPERATION_BINARY_MULT_TYPE || OPERATION_BINARY_DIV_TYPE on leaf node."); 00155 00156 // clean up 00157 if (scalar_is_temporary) 00158 detail::delete_element(scalar_temp_node.lhs); 00159 } 00160 else if ( leaf.op.type == OPERATION_BINARY_INNER_PROD_TYPE 00161 || leaf.op.type == OPERATION_UNARY_NORM_1_TYPE 00162 || leaf.op.type == OPERATION_UNARY_NORM_2_TYPE 00163 || leaf.op.type == OPERATION_UNARY_NORM_INF_TYPE) 00164 { 00165 execute_scalar_assign_composite(s, root_node); 00166 } 00167 else if ( (leaf.op.type_family == OPERATION_UNARY_TYPE_FAMILY && leaf.op.type != OPERATION_UNARY_TRANS_TYPE) 00168 || leaf.op.type == OPERATION_BINARY_ELEMENT_PROD_TYPE 00169 || leaf.op.type == OPERATION_BINARY_ELEMENT_DIV_TYPE) // element-wise operations 00170 { 00171 execute_element_composite(s, root_node); 00172 } 00173 else if ( leaf.op.type == OPERATION_BINARY_MAT_VEC_PROD_TYPE 00174 || leaf.op.type == OPERATION_BINARY_MAT_MAT_PROD_TYPE) 00175 { 00176 execute_matrix_prod(s, root_node); 00177 } 00178 else 00179 throw statement_not_supported_exception("Unsupported binary operator"); 00180 } 00181 00182 00184 inline void execute_single(statement const &, statement_node const & root_node) 00185 { 00186 lhs_rhs_element u = root_node.lhs; 00187 lhs_rhs_element v = root_node.rhs; 00188 switch (root_node.op.type) 00189 { 00190 case OPERATION_BINARY_ASSIGN_TYPE: 00191 detail::ax(u, 00192 v, 1.0, 1, false, false); 00193 break; 00194 case OPERATION_BINARY_INPLACE_ADD_TYPE: 00195 detail::axbx(u, 00196 u, 1.0, 1, false, false, 00197 v, 1.0, 1, false, false); 00198 break; 00199 case OPERATION_BINARY_INPLACE_SUB_TYPE: 00200 detail::axbx(u, 00201 u, 1.0, 1, false, false, 00202 v, 1.0, 1, false, true); 00203 break; 00204 default: 00205 throw statement_not_supported_exception("Unsupported binary operator for operation in root note (should be =, +=, or -=)"); 00206 } 00207 00208 } 00209 00210 00211 inline void execute_impl(statement const & s, statement_node const & root_node) 00212 { 00213 if ( root_node.lhs.type_family != SCALAR_TYPE_FAMILY 00214 && root_node.lhs.type_family != VECTOR_TYPE_FAMILY 00215 && root_node.lhs.type_family != MATRIX_TYPE_FAMILY) 00216 throw statement_not_supported_exception("Unsupported lvalue encountered in head node."); 00217 00218 switch (root_node.rhs.type_family) 00219 { 00220 case COMPOSITE_OPERATION_FAMILY: 00221 execute_composite(s, root_node); 00222 break; 00223 case SCALAR_TYPE_FAMILY: 00224 case VECTOR_TYPE_FAMILY: 00225 case MATRIX_TYPE_FAMILY: 00226 execute_single(s, root_node); 00227 break; 00228 default: 00229 throw statement_not_supported_exception("Invalid rvalue encountered in vector assignment"); 00230 } 00231 00232 } 00233 } 00234 00235 inline void execute(statement const & s) 00236 { 00237 // simply start execution from the root node: 00238 detail::execute_impl(s, s.array()[s.root()]); 00239 } 00240 00241 00242 } 00243 00244 } //namespace viennacl 00245 00246 #endif 00247