/* ViennaCL - The Vienna Computing Library, version 1.5.1 */
00001 #ifndef VIENNACL_LINALG_PROD_HPP_ 00002 #define VIENNACL_LINALG_PROD_HPP_ 00003 00004 /* ========================================================================= 00005 Copyright (c) 2010-2014, Institute for Microelectronics, 00006 Institute for Analysis and Scientific Computing, 00007 TU Wien. 00008 Portions of this software are copyright by UChicago Argonne, LLC. 00009 00010 ----------------- 00011 ViennaCL - The Vienna Computing Library 00012 ----------------- 00013 00014 Project Head: Karl Rupp rupp@iue.tuwien.ac.at 00015 00016 (A list of authors and contributors can be found in the PDF manual) 00017 00018 License: MIT (X11), see file LICENSE in the base directory 00019 ============================================================================= */ 00020 00027 #include "viennacl/forwards.h" 00028 #include "viennacl/tools/tools.hpp" 00029 #include "viennacl/meta/enable_if.hpp" 00030 #include "viennacl/meta/tag_of.hpp" 00031 #include <vector> 00032 #include <map> 00033 00034 namespace viennacl 00035 { 00036 // 00037 // generic prod function 00038 // uses tag dispatch to identify which algorithm 00039 // should be called 00040 // 00041 namespace linalg 00042 { 00043 #ifdef VIENNACL_WITH_MTL4 00044 // ---------------------------------------------------- 00045 // mtl4 00046 // 00047 template< typename MatrixT, typename VectorT > 00048 typename viennacl::enable_if< viennacl::is_mtl4< typename viennacl::traits::tag_of< MatrixT >::type >::value, 00049 VectorT>::type 00050 prod(MatrixT const& matrix, VectorT const& vector) 00051 { 00052 return VectorT(matrix * vector); 00053 } 00054 #endif 00055 00056 #ifdef VIENNACL_WITH_EIGEN 00057 // ---------------------------------------------------- 00058 // Eigen 00059 // 00060 template< typename MatrixT, typename VectorT > 00061 typename viennacl::enable_if< viennacl::is_eigen< typename viennacl::traits::tag_of< MatrixT >::type >::value, 00062 VectorT>::type 00063 prod(MatrixT const& matrix, VectorT const& vector) 00064 { 
00065 return matrix * vector; 00066 } 00067 #endif 00068 00069 #ifdef VIENNACL_WITH_UBLAS 00070 // ---------------------------------------------------- 00071 // UBLAS 00072 // 00073 template< typename MatrixT, typename VectorT > 00074 typename viennacl::enable_if< viennacl::is_ublas< typename viennacl::traits::tag_of< MatrixT >::type >::value, 00075 VectorT>::type 00076 prod(MatrixT const& matrix, VectorT const& vector) 00077 { 00078 // std::cout << "ublas .. " << std::endl; 00079 return boost::numeric::ublas::prod(matrix, vector); 00080 } 00081 #endif 00082 00083 00084 // ---------------------------------------------------- 00085 // STL type 00086 // 00087 00088 // dense matrix-vector product: 00089 template< typename T, typename A1, typename A2, typename VectorT > 00090 VectorT 00091 prod(std::vector< std::vector<T, A1>, A2 > const & matrix, VectorT const& vector) 00092 { 00093 VectorT result(matrix.size()); 00094 for (typename std::vector<T, A1>::size_type i=0; i<matrix.size(); ++i) 00095 { 00096 result[i] = 0; //we will not assume that VectorT is initialized to zero 00097 for (typename std::vector<T, A1>::size_type j=0; j<matrix[i].size(); ++j) 00098 result[i] += matrix[i][j] * vector[j]; 00099 } 00100 return result; 00101 } 00102 00103 // sparse matrix-vector product: 00104 template< typename KEY, typename DATA, typename COMPARE, typename AMAP, typename AVEC, typename VectorT > 00105 VectorT 00106 prod(std::vector< std::map<KEY, DATA, COMPARE, AMAP>, AVEC > const& matrix, VectorT const& vector) 00107 { 00108 typedef std::vector< std::map<KEY, DATA, COMPARE, AMAP>, AVEC > MatrixType; 00109 00110 VectorT result(matrix.size()); 00111 for (typename MatrixType::size_type i=0; i<matrix.size(); ++i) 00112 { 00113 result[i] = 0; //we will not assume that VectorT is initialized to zero 00114 for (typename std::map<KEY, DATA, COMPARE, AMAP>::const_iterator row_entries = matrix[i].begin(); 00115 row_entries != matrix[i].end(); 00116 ++row_entries) 00117 result[i] += 
row_entries->second * vector[row_entries->first]; 00118 } 00119 return result; 00120 } 00121 00122 00123 /*template< typename MatrixT, typename VectorT > 00124 VectorT 00125 prod(MatrixT const& matrix, VectorT const& vector, 00126 typename viennacl::enable_if< viennacl::is_stl< typename viennacl::traits::tag_of< MatrixT >::type >::value 00127 >::type* dummy = 0) 00128 { 00129 // std::cout << "std .. " << std::endl; 00130 return prod_impl(matrix, vector); 00131 }*/ 00132 00133 // ---------------------------------------------------- 00134 // VIENNACL 00135 // 00136 00137 // standard product: 00138 template< typename NumericT, typename F1, typename F2> 00139 viennacl::matrix_expression< const viennacl::matrix_base<NumericT, F1>, 00140 const viennacl::matrix_base<NumericT, F2>, 00141 viennacl::op_mat_mat_prod > 00142 prod(viennacl::matrix_base<NumericT, F1> const & A, 00143 viennacl::matrix_base<NumericT, F2> const & B) 00144 { 00145 // std::cout << "viennacl .. " << std::endl; 00146 return viennacl::matrix_expression< const viennacl::matrix_base<NumericT, F1>, 00147 const viennacl::matrix_base<NumericT, F2>, 00148 viennacl::op_mat_mat_prod >(A, B); 00149 } 00150 00151 // right factor is transposed: 00152 template< typename NumericT, typename F1, typename F2> 00153 viennacl::matrix_expression< const viennacl::matrix_base<NumericT, F1>, 00154 const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F2>, 00155 const viennacl::matrix_base<NumericT, F2>, 00156 op_trans>, 00157 viennacl::op_mat_mat_prod > 00158 prod(viennacl::matrix_base<NumericT, F1> const & A, 00159 viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F2>, 00160 const viennacl::matrix_base<NumericT, F2>, 00161 op_trans> const & B) 00162 { 00163 // std::cout << "viennacl .. 
" << std::endl; 00164 return viennacl::matrix_expression< const viennacl::matrix_base<NumericT, F1>, 00165 const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F2>, 00166 const viennacl::matrix_base<NumericT, F2>, 00167 op_trans>, 00168 viennacl::op_mat_mat_prod >(A, B); 00169 } 00170 00171 // left factor transposed: 00172 template< typename NumericT, typename F1, typename F2> 00173 viennacl::matrix_expression< const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F1>, 00174 const viennacl::matrix_base<NumericT, F1>, 00175 op_trans>, 00176 const viennacl::matrix_base<NumericT, F2>, 00177 viennacl::op_mat_mat_prod > 00178 prod(viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F1>, 00179 const viennacl::matrix_base<NumericT, F1>, 00180 op_trans> const & A, 00181 viennacl::matrix_base<NumericT, F2> const & B) 00182 { 00183 // std::cout << "viennacl .. " << std::endl; 00184 return viennacl::matrix_expression< const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F1>, 00185 const viennacl::matrix_base<NumericT, F1>, 00186 op_trans>, 00187 const viennacl::matrix_base<NumericT, F2>, 00188 viennacl::op_mat_mat_prod >(A, B); 00189 } 00190 00191 00192 // both factors transposed: 00193 template< typename NumericT, typename F1, typename F2> 00194 viennacl::matrix_expression< const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F1>, 00195 const viennacl::matrix_base<NumericT, F1>, 00196 op_trans>, 00197 const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F2>, 00198 const viennacl::matrix_base<NumericT, F2>, 00199 op_trans>, 00200 viennacl::op_mat_mat_prod > 00201 prod(viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F1>, 00202 const viennacl::matrix_base<NumericT, F1>, 00203 op_trans> const & A, 00204 viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F2>, 00205 const viennacl::matrix_base<NumericT, F2>, 00206 op_trans> const & B) 00207 { 
00208 // std::cout << "viennacl .. " << std::endl; 00209 return viennacl::matrix_expression< const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F1>, 00210 const viennacl::matrix_base<NumericT, F1>, 00211 op_trans>, 00212 const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F2>, 00213 const viennacl::matrix_base<NumericT, F2>, 00214 op_trans>, 00215 viennacl::op_mat_mat_prod >(A, B); 00216 } 00217 00218 00219 00220 // matrix-vector product 00221 template< typename NumericT, typename F> 00222 viennacl::vector_expression< const viennacl::matrix_base<NumericT, F>, 00223 const viennacl::vector_base<NumericT>, 00224 viennacl::op_prod > 00225 prod(viennacl::matrix_base<NumericT, F> const & matrix, 00226 viennacl::vector_base<NumericT> const & vector) 00227 { 00228 // std::cout << "viennacl .. " << std::endl; 00229 return viennacl::vector_expression< const viennacl::matrix_base<NumericT, F>, 00230 const viennacl::vector_base<NumericT>, 00231 viennacl::op_prod >(matrix, vector); 00232 } 00233 00234 // transposed matrix-vector product 00235 template< typename NumericT, typename F> 00236 viennacl::vector_expression< const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F>, 00237 const viennacl::matrix_base<NumericT, F>, 00238 op_trans>, 00239 const viennacl::vector_base<NumericT>, 00240 viennacl::op_prod > 00241 prod(viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F>, 00242 const viennacl::matrix_base<NumericT, F>, 00243 op_trans> const & matrix, 00244 viennacl::vector_base<NumericT> const & vector) 00245 { 00246 // std::cout << "viennacl .. 
" << std::endl; 00247 return viennacl::vector_expression< const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F>, 00248 const viennacl::matrix_base<NumericT, F>, 00249 op_trans>, 00250 const viennacl::vector_base<NumericT>, 00251 viennacl::op_prod >(matrix, vector); 00252 } 00253 00254 00255 template<typename SparseMatrixType, class SCALARTYPE> 00256 typename viennacl::enable_if< viennacl::is_any_sparse_matrix<SparseMatrixType>::value, 00257 vector_expression<const SparseMatrixType, 00258 const vector_base<SCALARTYPE>, 00259 op_prod > 00260 >::type 00261 prod(const SparseMatrixType & mat, 00262 const vector_base<SCALARTYPE> & vec) 00263 { 00264 return vector_expression<const SparseMatrixType, 00265 const vector_base<SCALARTYPE>, 00266 op_prod >(mat, vec); 00267 } 00268 00269 template< typename SparseMatrixType, typename SCALARTYPE, typename F1> 00270 typename viennacl::enable_if< viennacl::is_any_sparse_matrix<SparseMatrixType>::value, 00271 viennacl::matrix_expression<const SparseMatrixType, 00272 const matrix_base < SCALARTYPE, F1 >, 00273 op_prod > 00274 >::type 00275 prod(const SparseMatrixType & sp_mat, 00276 const viennacl::matrix_base<SCALARTYPE, F1> & d_mat) 00277 { 00278 return viennacl::matrix_expression<const SparseMatrixType, 00279 const viennacl::matrix_base < SCALARTYPE, F1 >, 00280 op_prod >(sp_mat, d_mat); 00281 } 00282 00283 // right factor is transposed 00284 template< typename SparseMatrixType, typename SCALARTYPE, typename F1 > 00285 typename viennacl::enable_if< viennacl::is_any_sparse_matrix<SparseMatrixType>::value, 00286 viennacl::matrix_expression< const SparseMatrixType, 00287 const viennacl::matrix_expression<const viennacl::matrix_base<SCALARTYPE, F1>, 00288 const viennacl::matrix_base<SCALARTYPE, F1>, 00289 op_trans>, 00290 viennacl::op_prod > 00291 >::type 00292 prod(const SparseMatrixType & A, 00293 viennacl::matrix_expression<const viennacl::matrix_base < SCALARTYPE, F1 >, 00294 const viennacl::matrix_base < 
SCALARTYPE, F1 >, 00295 op_trans> const & B) 00296 { 00297 return viennacl::matrix_expression< const SparseMatrixType, 00298 const viennacl::matrix_expression<const viennacl::matrix_base < SCALARTYPE, F1 >, 00299 const viennacl::matrix_base < SCALARTYPE, F1 >, 00300 op_trans>, 00301 viennacl::op_prod >(A, B); 00302 } 00303 00304 template<typename StructuredMatrixType, class SCALARTYPE> 00305 typename viennacl::enable_if< viennacl::is_any_dense_structured_matrix<StructuredMatrixType>::value, 00306 vector_expression<const StructuredMatrixType, 00307 const vector_base<SCALARTYPE>, 00308 op_prod > 00309 >::type 00310 prod(const StructuredMatrixType & mat, 00311 const vector_base<SCALARTYPE> & vec) 00312 { 00313 return vector_expression<const StructuredMatrixType, 00314 const vector_base<SCALARTYPE>, 00315 op_prod >(mat, vec); 00316 } 00317 00318 } // end namespace linalg 00319 } // end namespace viennacl 00320 #endif 00321 00322 00323 00324 00325