ViennaCL - The Vienna Computing Library
1.5.1
|
00001 #ifndef VIENNACL_SCHEDULER_EXECUTE_MATRIX_PROD_HPP 00002 #define VIENNACL_SCHEDULER_EXECUTE_MATRIX_PROD_HPP 00003 00004 /* ========================================================================= 00005 Copyright (c) 2010-2014, Institute for Microelectronics, 00006 Institute for Analysis and Scientific Computing, 00007 TU Wien. 00008 Portions of this software are copyright by UChicago Argonne, LLC. 00009 00010 ----------------- 00011 ViennaCL - The Vienna Computing Library 00012 ----------------- 00013 00014 Project Head: Karl Rupp rupp@iue.tuwien.ac.at 00015 00016 (A list of authors and contributors can be found in the PDF manual) 00017 00018 License: MIT (X11), see file LICENSE in the base directory 00019 ============================================================================= */ 00020 00021 00026 #include "viennacl/forwards.h" 00027 #include "viennacl/scheduler/forwards.h" 00028 #include "viennacl/scheduler/execute_util.hpp" 00029 #include "viennacl/scheduler/execute_generic_dispatcher.hpp" 00030 #include "viennacl/linalg/vector_operations.hpp" 00031 #include "viennacl/linalg/matrix_operations.hpp" 00032 #include "viennacl/linalg/sparse_matrix_operations.hpp" 00033 #include "viennacl/compressed_matrix.hpp" 00034 #include "viennacl/coordinate_matrix.hpp" 00035 #include "viennacl/ell_matrix.hpp" 00036 #include "viennacl/hyb_matrix.hpp" 00037 00038 namespace viennacl 00039 { 00040 namespace scheduler 00041 { 00042 namespace detail 00043 { 00044 inline bool matrix_prod_temporary_required(statement const & s, lhs_rhs_element const & elem) 00045 { 00046 if (elem.type_family != COMPOSITE_OPERATION_FAMILY) 00047 return false; 00048 00049 // check composite node for being a transposed matrix proxy: 00050 statement_node const & leaf = s.array()[elem.node_index]; 00051 if ( leaf.op.type == OPERATION_UNARY_TRANS_TYPE && leaf.lhs.type_family == MATRIX_TYPE_FAMILY) 00052 return false; 00053 00054 return true; 00055 } 00056 00057 inline void matrix_matrix_prod(statement const & s, 00058 lhs_rhs_element result, 00059 lhs_rhs_element const & A, 00060 lhs_rhs_element const & B, 00061 double alpha, 00062 double beta) 00063 { 00064 if (A.type_family == MATRIX_TYPE_FAMILY && B.type_family == MATRIX_TYPE_FAMILY) // C = A * B 00065 { 00066 assert( A.numeric_type == B.numeric_type && bool("Numeric type not the same!")); 00067 assert( result.numeric_type == B.numeric_type && bool("Numeric type not the same!")); 00068 00069 #define VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(LAYOUTA, MEMBERA, LAYOUTB, MEMBERB, LAYOUTC, MEMBERC)\ 00070 if (A.subtype == LAYOUTA && B.subtype == LAYOUTB && result.subtype == LAYOUTC)\ 00071 {\ 00072 switch (result.numeric_type)\ 00073 {\ 00074 case FLOAT_TYPE:\ 00075 viennacl::linalg::prod_impl(*A.matrix_##MEMBERA##_float, *B.matrix_##MEMBERB##_float, *result.matrix_##MEMBERC##_float, static_cast<float>(alpha), static_cast<float>(beta)); break;\ 00076 case DOUBLE_TYPE:\ 00077 viennacl::linalg::prod_impl(*A.matrix_##MEMBERA##_double, *B.matrix_##MEMBERB##_double, *result.matrix_##MEMBERC##_double, alpha, beta); break;\ 00078 default:\ 00079 throw statement_not_supported_exception("Invalid numeric type in matrix-matrix multiplication");\ 00080 }\ 00081 } 00082 00083 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, DENSE_ROW_MATRIX_TYPE, row, DENSE_ROW_MATRIX_TYPE, row) 00084 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, DENSE_ROW_MATRIX_TYPE, row, DENSE_COL_MATRIX_TYPE, col) 00085 00086 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, DENSE_COL_MATRIX_TYPE, col, DENSE_ROW_MATRIX_TYPE, row) 00087 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, DENSE_COL_MATRIX_TYPE, col, DENSE_COL_MATRIX_TYPE, col) 00088 00089 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, DENSE_ROW_MATRIX_TYPE, row, DENSE_ROW_MATRIX_TYPE, row) 00090 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, DENSE_ROW_MATRIX_TYPE, row, DENSE_COL_MATRIX_TYPE, col) 00091 00092 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, DENSE_COL_MATRIX_TYPE, col, DENSE_ROW_MATRIX_TYPE, row) 00093 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, DENSE_COL_MATRIX_TYPE, col, DENSE_COL_MATRIX_TYPE, col) 00094 00095 #undef VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD 00096 } 00097 else if (A.type_family == MATRIX_TYPE_FAMILY && B.type_family == COMPOSITE_OPERATION_FAMILY) // C = A * B^T 00098 { 00099 statement_node const & leaf = s.array()[B.node_index]; 00100 00101 assert(leaf.lhs.type_family == MATRIX_TYPE_FAMILY && leaf.op.type == OPERATION_UNARY_TRANS_TYPE && bool("Logic error: Argument not a matrix transpose!")); 00102 assert(leaf.lhs.numeric_type == result.numeric_type && bool("Numeric type not the same!")); 00103 assert(result.numeric_type == A.numeric_type && bool("Numeric type not the same!")); 00104 00105 #define VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(LAYOUTA, MEMBERA, LAYOUTB, MEMBERB, MAJORB, LAYOUTC, MEMBERC)\ 00106 if (A.subtype == LAYOUTA && leaf.lhs.subtype == LAYOUTB && result.subtype == LAYOUTC)\ 00107 {\ 00108 switch (result.numeric_type)\ 00109 {\ 00110 case FLOAT_TYPE:\ 00111 viennacl::linalg::prod_impl(*A.matrix_##MEMBERA##_float, \ 00112 viennacl::matrix_expression< const matrix_base<float, MAJORB>,\ 00113 const matrix_base<float, MAJORB>,\ 00114 op_trans> (*(leaf.lhs.matrix_##MEMBERB##_float), *(leaf.lhs.matrix_##MEMBERB##_float)), \ 00115 *result.matrix_##MEMBERC##_float, static_cast<float>(alpha), static_cast<float>(beta)); break;\ 00116 case DOUBLE_TYPE:\ 00117 viennacl::linalg::prod_impl(*A.matrix_##MEMBERA##_double,\ 00118 viennacl::matrix_expression< const matrix_base<double, MAJORB>,\ 00119 const matrix_base<double, MAJORB>,\ 00120 op_trans>(*(leaf.lhs.matrix_##MEMBERB##_double), *(leaf.lhs.matrix_##MEMBERB##_double)), \ 00121 *result.matrix_##MEMBERC##_double, alpha, beta); break;\ 00122 default:\ 00123 throw statement_not_supported_exception("Invalid numeric type in matrix-matrix multiplication");\ 00124 }\ 00125 } 00126 00127 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_ROW_MATRIX_TYPE, row) 00128 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_COL_MATRIX_TYPE, col) 00129 00130 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_ROW_MATRIX_TYPE, row) 00131 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_COL_MATRIX_TYPE, col) 00132 00133 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_ROW_MATRIX_TYPE, row) 00134 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_COL_MATRIX_TYPE, col) 00135 00136 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_ROW_MATRIX_TYPE, row) 00137 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_COL_MATRIX_TYPE, col) 00138 00139 #undef VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD 00140 } 00141 else if (A.type_family == COMPOSITE_OPERATION_FAMILY && B.type_family == MATRIX_TYPE_FAMILY) // C = A^T * B 00142 { 00143 statement_node const & leaf = s.array()[A.node_index]; 00144 00145 assert(leaf.lhs.type_family == MATRIX_TYPE_FAMILY && leaf.op.type == OPERATION_UNARY_TRANS_TYPE && bool("Logic error: Argument not a matrix transpose!")); 00146 assert(leaf.lhs.numeric_type == result.numeric_type && bool("Numeric type not the same!")); 00147 assert(result.numeric_type == B.numeric_type && bool("Numeric type not the same!")); 00148 00149 #define VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(LAYOUTA, MEMBERA, MAJORA, LAYOUTB, MEMBERB, LAYOUTC, MEMBERC)\ 00150 if (leaf.lhs.subtype == LAYOUTA && B.subtype == LAYOUTB && result.subtype == LAYOUTC)\ 00151 {\ 00152 switch (result.numeric_type)\ 00153 {\ 00154 case FLOAT_TYPE:\ 00155 viennacl::linalg::prod_impl(viennacl::matrix_expression< const matrix_base<float, MAJORA>,\ 00156 const matrix_base<float, MAJORA>,\ 00157 op_trans>(*leaf.lhs.matrix_##MEMBERA##_float, *leaf.lhs.matrix_##MEMBERA##_float), \ 00158 *B.matrix_##MEMBERB##_float,\ 00159 *result.matrix_##MEMBERC##_float, static_cast<float>(alpha), static_cast<float>(beta)); break;\ 00160 case DOUBLE_TYPE:\ 00161 viennacl::linalg::prod_impl(viennacl::matrix_expression< const matrix_base<double, MAJORA>,\ 00162 const matrix_base<double, MAJORA>,\ 00163 op_trans>(*leaf.lhs.matrix_##MEMBERA##_double, *leaf.lhs.matrix_##MEMBERA##_double), \ 00164 *B.matrix_##MEMBERB##_double,\ 00165 *result.matrix_##MEMBERC##_double, alpha, beta); break;\ 00166 default:\ 00167 throw statement_not_supported_exception("Invalid numeric type in matrix-matrix multiplication");\ 00168 }\ 00169 } 00170 00171 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_ROW_MATRIX_TYPE, row, DENSE_ROW_MATRIX_TYPE, row) 00172 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_ROW_MATRIX_TYPE, row, DENSE_COL_MATRIX_TYPE, col) 00173 00174 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_COL_MATRIX_TYPE, col, DENSE_ROW_MATRIX_TYPE, row) 00175 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_COL_MATRIX_TYPE, col, DENSE_COL_MATRIX_TYPE, col) 00176 00177 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_ROW_MATRIX_TYPE, row, DENSE_ROW_MATRIX_TYPE, row) 00178 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_ROW_MATRIX_TYPE, row, DENSE_COL_MATRIX_TYPE, col) 00179 00180 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_COL_MATRIX_TYPE, col, DENSE_ROW_MATRIX_TYPE, row) 00181 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_COL_MATRIX_TYPE, col, DENSE_COL_MATRIX_TYPE, col) 00182 00183 #undef VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD 00184 } 00185 else if (A.type_family == COMPOSITE_OPERATION_FAMILY && B.type_family == COMPOSITE_OPERATION_FAMILY) // C = A^T * B^T 00186 { 00187 statement_node const & leafA = s.array()[A.node_index]; 00188 statement_node const & leafB = s.array()[B.node_index]; 00189 00190 assert(leafA.lhs.type_family == MATRIX_TYPE_FAMILY && leafA.op.type == OPERATION_UNARY_TRANS_TYPE && bool("Logic error: Argument not a matrix transpose!")); 00191 assert(leafB.lhs.type_family == MATRIX_TYPE_FAMILY && leafB.op.type == OPERATION_UNARY_TRANS_TYPE && bool("Logic error: Argument not a matrix transpose!")); 00192 assert(leafA.lhs.numeric_type == result.numeric_type && bool("Numeric type not the same!")); 00193 assert(leafB.lhs.numeric_type == result.numeric_type && bool("Numeric type not the same!")); 00194 00195 #define VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(LAYOUTA, MEMBERA, MAJORA, LAYOUTB, MEMBERB, MAJORB, LAYOUTC, MEMBERC)\ 00196 if (leafA.lhs.subtype == LAYOUTA && leafB.lhs.subtype == LAYOUTB && result.subtype == LAYOUTC)\ 00197 {\ 00198 switch (result.numeric_type)\ 00199 {\ 00200 case FLOAT_TYPE:\ 00201 viennacl::linalg::prod_impl(viennacl::matrix_expression< const matrix_base<float, MAJORA>,\ 00202 const matrix_base<float, MAJORA>,\ 00203 op_trans>(*leafA.lhs.matrix_##MEMBERA##_float, *leafA.lhs.matrix_##MEMBERA##_float), \ 00204 viennacl::matrix_expression< const matrix_base<float, MAJORB>,\ 00205 const matrix_base<float, MAJORB>,\ 00206 op_trans>(*leafB.lhs.matrix_##MEMBERB##_float, *leafB.lhs.matrix_##MEMBERB##_float), \ 00207 *result.matrix_##MEMBERC##_float, static_cast<float>(alpha), static_cast<float>(beta)); break;\ 00208 case DOUBLE_TYPE:\ 00209 viennacl::linalg::prod_impl(viennacl::matrix_expression< const matrix_base<double, MAJORA>,\ 00210 const matrix_base<double, MAJORA>,\ 00211 op_trans>(*leafA.lhs.matrix_##MEMBERA##_double, *leafA.lhs.matrix_##MEMBERA##_double), \ 00212 viennacl::matrix_expression< const matrix_base<double, MAJORB>,\ 00213 const matrix_base<double, MAJORB>,\ 00214 op_trans>(*leafB.lhs.matrix_##MEMBERB##_double, *leafB.lhs.matrix_##MEMBERB##_double), \ 00215 *result.matrix_##MEMBERC##_double, alpha, beta); break;\ 00216 default:\ 00217 throw statement_not_supported_exception("Invalid numeric type in matrix-matrix multiplication");\ 00218 }\ 00219 } 00220 00221 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_ROW_MATRIX_TYPE, row) 00222 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_COL_MATRIX_TYPE, col) 00223 00224 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_ROW_MATRIX_TYPE, row) 00225 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_COL_MATRIX_TYPE, col) 00226 00227 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_ROW_MATRIX_TYPE, row) 00228 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_ROW_MATRIX_TYPE, row, row_major, DENSE_COL_MATRIX_TYPE, col) 00229 00230 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_ROW_MATRIX_TYPE, row) 00231 VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD(DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_COL_MATRIX_TYPE, col, column_major, DENSE_COL_MATRIX_TYPE, col) 00232 00233 #undef VIENNACL_SCHEDULER_GENERATE_MATRIX_MATRIX_PROD 00234 } 00235 else 00236 throw statement_not_supported_exception("Matrix-matrix multiplication encountered operands being neither dense matrices nor transposed dense matrices"); 00237 } 00238 00239 inline void matrix_vector_prod(statement const & s, 00240 lhs_rhs_element result, 00241 lhs_rhs_element const & A, 00242 lhs_rhs_element const & x) 00243 { 00244 assert( result.numeric_type == x.numeric_type && bool("Numeric type not the same!")); 00245 assert( result.type_family == x.type_family && bool("Subtype not the same!")); 00246 assert( result.subtype == DENSE_VECTOR_TYPE && bool("Result node for matrix-vector product not a vector type!")); 00247 00248 // deal with transposed product first: 00249 // switch: trans for A 00250 if (A.type_family == COMPOSITE_OPERATION_FAMILY) // prod(trans(A), x) 00251 { 00252 statement_node const & leaf = s.array()[A.node_index]; 00253 00254 assert(leaf.lhs.type_family == MATRIX_TYPE_FAMILY && leaf.op.type == OPERATION_UNARY_TRANS_TYPE && bool("Logic error: Argument not a matrix transpose!")); 00255 assert(leaf.lhs.numeric_type == x.numeric_type && bool("Numeric type not the same!")); 00256 00257 if (leaf.lhs.subtype == DENSE_ROW_MATRIX_TYPE) 00258 { 00259 switch (leaf.lhs.numeric_type) 00260 { 00261 case FLOAT_TYPE: 00262 viennacl::linalg::prod_impl(viennacl::matrix_expression< const matrix_base<float, row_major>, 00263 const matrix_base<float, row_major>, 00264 op_trans>(*leaf.lhs.matrix_row_float, *leaf.lhs.matrix_row_float), 00265 *x.vector_float, 00266 *result.vector_float); break; 00267 case DOUBLE_TYPE: 00268 viennacl::linalg::prod_impl(viennacl::matrix_expression< const matrix_base<double, row_major>, 00269 const matrix_base<double, row_major>, 00270 op_trans>(*leaf.lhs.matrix_row_double, *leaf.lhs.matrix_row_double), 00271 *x.vector_double, 00272 *result.vector_double); break; 00273 default: 00274 throw statement_not_supported_exception("Invalid numeric type in matrix-{matrix,vector} multiplication"); 00275 } 00276 } 00277 else if (leaf.lhs.subtype == DENSE_COL_MATRIX_TYPE) 00278 { 00279 switch (leaf.lhs.numeric_type) 00280 { 00281 case FLOAT_TYPE: 00282 viennacl::linalg::prod_impl(viennacl::matrix_expression< const matrix_base<float, column_major>, 00283 const matrix_base<float, column_major>, 00284 op_trans>(*leaf.lhs.matrix_col_float, *leaf.lhs.matrix_col_float), 00285 *x.vector_float, 00286 *result.vector_float); break; 00287 case DOUBLE_TYPE: 00288 viennacl::linalg::prod_impl(viennacl::matrix_expression< const matrix_base<double, column_major>, 00289 const matrix_base<double, column_major>, 00290 op_trans>(*leaf.lhs.matrix_col_double, *leaf.lhs.matrix_col_double), 00291 *x.vector_double, 00292 *result.vector_double); break; 00293 default: 00294 throw statement_not_supported_exception("Invalid numeric type in matrix-{matrix,vector} multiplication"); 00295 } 00296 } 00297 else 00298 throw statement_not_supported_exception("Invalid matrix type for transposed matrix-vector product"); 00299 } 00300 else if (A.subtype == DENSE_ROW_MATRIX_TYPE) 00301 { 00302 switch (A.numeric_type) 00303 { 00304 case FLOAT_TYPE: 00305 viennacl::linalg::prod_impl(*A.matrix_row_float, *x.vector_float, *result.vector_float); 00306 break; 00307 case DOUBLE_TYPE: 00308 viennacl::linalg::prod_impl(*A.matrix_row_double, *x.vector_double, *result.vector_double); 00309 break; 00310 default: 00311 throw statement_not_supported_exception("Invalid numeric type in matrix-{matrix,vector} multiplication"); 00312 } 00313 } 00314 else if (A.subtype == DENSE_COL_MATRIX_TYPE) 00315 { 00316 switch (A.numeric_type) 00317 { 00318 case FLOAT_TYPE: 00319 viennacl::linalg::prod_impl(*A.matrix_col_float, *x.vector_float, *result.vector_float); 00320 break; 00321 case DOUBLE_TYPE: 00322 viennacl::linalg::prod_impl(*A.matrix_col_double, *x.vector_double, *result.vector_double); 00323 break; 00324 default: 00325 throw statement_not_supported_exception("Invalid numeric type in matrix-{matrix,vector} multiplication"); 00326 } 00327 } 00328 else if (A.subtype == COMPRESSED_MATRIX_TYPE) 00329 { 00330 switch (A.numeric_type) 00331 { 00332 case FLOAT_TYPE: 00333 viennacl::linalg::prod_impl(*A.compressed_matrix_float, *x.vector_float, *result.vector_float); 00334 break; 00335 case DOUBLE_TYPE: 00336 viennacl::linalg::prod_impl(*A.compressed_matrix_double, *x.vector_double, *result.vector_double); 00337 break; 00338 default: 00339 throw statement_not_supported_exception("Invalid numeric type in matrix-{matrix,vector} multiplication"); 00340 } 00341 } 00342 else if (A.subtype == COORDINATE_MATRIX_TYPE) 00343 { 00344 switch (A.numeric_type) 00345 { 00346 case FLOAT_TYPE: 00347 viennacl::linalg::prod_impl(*A.coordinate_matrix_float, *x.vector_float, *result.vector_float); 00348 break; 00349 case DOUBLE_TYPE: 00350 viennacl::linalg::prod_impl(*A.coordinate_matrix_double, *x.vector_double, *result.vector_double); 00351 break; 00352 default: 00353 throw statement_not_supported_exception("Invalid numeric type in matrix-{matrix,vector} multiplication"); 00354 } 00355 } 00356 else if (A.subtype == ELL_MATRIX_TYPE) 00357 { 00358 switch (A.numeric_type) 00359 { 00360 case FLOAT_TYPE: 00361 viennacl::linalg::prod_impl(*A.ell_matrix_float, *x.vector_float, *result.vector_float); 00362 break; 00363 case DOUBLE_TYPE: 00364 viennacl::linalg::prod_impl(*A.ell_matrix_double, *x.vector_double, *result.vector_double); 00365 break; 00366 default: 00367 throw statement_not_supported_exception("Invalid numeric type in matrix-{matrix,vector} multiplication"); 00368 } 00369 } 00370 else if (A.subtype == HYB_MATRIX_TYPE) 00371 { 00372 switch (A.numeric_type) 00373 { 00374 case FLOAT_TYPE: 00375 viennacl::linalg::prod_impl(*A.hyb_matrix_float, *x.vector_float, *result.vector_float); 00376 break; 00377 case DOUBLE_TYPE: 00378 viennacl::linalg::prod_impl(*A.hyb_matrix_double, *x.vector_double, *result.vector_double); 00379 break; 00380 default: 00381 throw statement_not_supported_exception("Invalid numeric type in matrix-{matrix,vector} multiplication"); 00382 } 00383 } 00384 else 00385 { 00386 std::cout << "A.subtype: " << A.subtype << std::endl; 00387 throw statement_not_supported_exception("Invalid matrix type for matrix-vector product"); 00388 } 00389 } 00390 00391 } // namespace detail 00392 00393 inline void execute_matrix_prod(statement const & s, statement_node const & root_node) 00394 { 00395 statement_node const & leaf = s.array()[root_node.rhs.node_index]; 00396 00397 // Part 1: Check whether temporaries are required // 00398 00399 statement_node new_root_lhs; 00400 statement_node new_root_rhs; 00401 00402 bool lhs_needs_temporary = detail::matrix_prod_temporary_required(s, leaf.lhs); 00403 bool rhs_needs_temporary = detail::matrix_prod_temporary_required(s, leaf.rhs); 00404 00405 // check for temporary on lhs: 00406 if (lhs_needs_temporary) 00407 { 00408 std::cout << "Temporary for LHS!" << std::endl; 00409 detail::new_element(new_root_lhs.lhs, root_node.lhs); 00410 00411 new_root_lhs.op.type_family = OPERATION_BINARY_TYPE_FAMILY; 00412 new_root_lhs.op.type = OPERATION_BINARY_ASSIGN_TYPE; 00413 00414 new_root_lhs.rhs.type_family = COMPOSITE_OPERATION_FAMILY; 00415 new_root_lhs.rhs.subtype = INVALID_SUBTYPE; 00416 new_root_lhs.rhs.numeric_type = INVALID_NUMERIC_TYPE; 00417 new_root_lhs.rhs.node_index = leaf.lhs.node_index; 00418 00419 // work on subexpression: 00420 // TODO: Catch exception, free temporary, then rethrow 00421 detail::execute_composite(s, new_root_lhs); 00422 } 00423 00424 // check for temporary on rhs: 00425 if (rhs_needs_temporary) 00426 { 00427 detail::new_element(new_root_rhs.lhs, root_node.lhs); 00428 00429 new_root_rhs.op.type_family = OPERATION_BINARY_TYPE_FAMILY; 00430 new_root_rhs.op.type = OPERATION_BINARY_ASSIGN_TYPE; 00431 00432 new_root_rhs.rhs.type_family = COMPOSITE_OPERATION_FAMILY; 00433 new_root_rhs.rhs.subtype = INVALID_SUBTYPE; 00434 new_root_rhs.rhs.numeric_type = INVALID_NUMERIC_TYPE; 00435 new_root_rhs.rhs.node_index = leaf.rhs.node_index; 00436 00437 // work on subexpression: 00438 // TODO: Catch exception, free temporary, then rethrow 00439 detail::execute_composite(s, new_root_rhs); 00440 } 00441 00442 // Part 2: Run the actual computations // 00443 00444 lhs_rhs_element x = lhs_needs_temporary ? new_root_lhs.lhs : leaf.lhs; 00445 lhs_rhs_element y = rhs_needs_temporary ? new_root_rhs.lhs : leaf.rhs; 00446 00447 if (root_node.lhs.type_family == VECTOR_TYPE_FAMILY) 00448 { 00449 if (root_node.op.type != OPERATION_BINARY_ASSIGN_TYPE) 00450 { 00451 //split y += A*x 00452 statement_node new_root_z; 00453 detail::new_element(new_root_z.lhs, root_node.lhs); 00454 00455 // compute z = A * x 00456 detail::matrix_vector_prod(s, new_root_z.lhs, x, y); 00457 00458 // assignment y = z 00459 double alpha = 0; 00460 if (root_node.op.type == OPERATION_BINARY_INPLACE_ADD_TYPE) 00461 alpha = 1.0; 00462 else if (root_node.op.type == OPERATION_BINARY_INPLACE_SUB_TYPE) 00463 alpha = -1.0; 00464 else 00465 throw statement_not_supported_exception("Invalid assignment type for matrix-vector product"); 00466 00467 lhs_rhs_element y = root_node.lhs; 00468 detail::axbx(y, 00469 y, 1.0, 1, false, false, 00470 new_root_z.lhs, alpha, 1, false, false); 00471 00472 detail::delete_element(new_root_z.lhs); 00473 } 00474 else 00475 detail::matrix_vector_prod(s, root_node.lhs, x, y); 00476 } 00477 else 00478 { 00479 double alpha = (root_node.op.type == OPERATION_BINARY_INPLACE_SUB_TYPE) ? -1.0 : 1.0; 00480 double beta = (root_node.op.type != OPERATION_BINARY_ASSIGN_TYPE) ? 1.0 : 0.0; 00481 00482 detail::matrix_matrix_prod(s, root_node.lhs, x, y, alpha, beta); 00483 } 00484 00485 // Part 3: Clean up // 00486 00487 if (lhs_needs_temporary) 00488 detail::delete_element(new_root_lhs.lhs); 00489 00490 if (rhs_needs_temporary) 00491 detail::delete_element(new_root_rhs.lhs); 00492 } 00493 00494 } // namespace scheduler 00495 } // namespace viennacl 00496 00497 #endif 00498