// ViennaCL - The Vienna Computing Library
// Version 1.5.1
00001 #ifndef VIENNACL_LINALG_OPENCL_DIRECT_SOLVE_HPP 00002 #define VIENNACL_LINALG_OPENCL_DIRECT_SOLVE_HPP 00003 00004 /* ========================================================================= 00005 Copyright (c) 2010-2014, Institute for Microelectronics, 00006 Institute for Analysis and Scientific Computing, 00007 TU Wien. 00008 Portions of this software are copyright by UChicago Argonne, LLC. 00009 00010 ----------------- 00011 ViennaCL - The Vienna Computing Library 00012 ----------------- 00013 00014 Project Head: Karl Rupp rupp@iue.tuwien.ac.at 00015 00016 (A list of authors and contributors can be found in the PDF manual) 00017 00018 License: MIT (X11), see file LICENSE in the base directory 00019 ============================================================================= */ 00020 00025 #include "viennacl/vector.hpp" 00026 #include "viennacl/matrix.hpp" 00027 #include "viennacl/ocl/kernel.hpp" 00028 #include "viennacl/ocl/device.hpp" 00029 #include "viennacl/ocl/handle.hpp" 00030 #include "viennacl/linalg/opencl/kernels/matrix_solve.hpp" 00031 00032 namespace viennacl 00033 { 00034 namespace linalg 00035 { 00036 namespace opencl 00037 { 00038 namespace detail 00039 { 00040 inline cl_uint get_option_for_solver_tag(viennacl::linalg::upper_tag) { return 0; } 00041 inline cl_uint get_option_for_solver_tag(viennacl::linalg::unit_upper_tag) { return (1 << 0); } 00042 inline cl_uint get_option_for_solver_tag(viennacl::linalg::lower_tag) { return (1 << 2); } 00043 inline cl_uint get_option_for_solver_tag(viennacl::linalg::unit_lower_tag) { return (1 << 2) | (1 << 0); } 00044 00045 template <typename M1, typename M2, typename KernelType> 00046 void inplace_solve_impl(M1 const & A, M2 & B, KernelType & k) 00047 { 00048 viennacl::ocl::enqueue(k(viennacl::traits::opencl_handle(A), 00049 cl_uint(viennacl::traits::start1(A)), cl_uint(viennacl::traits::start2(A)), 00050 cl_uint(viennacl::traits::stride1(A)), cl_uint(viennacl::traits::stride2(A)), 00051 
cl_uint(viennacl::traits::size1(A)), cl_uint(viennacl::traits::size2(A)), 00052 cl_uint(viennacl::traits::internal_size1(A)), cl_uint(viennacl::traits::internal_size2(A)), 00053 viennacl::traits::opencl_handle(B), 00054 cl_uint(viennacl::traits::start1(B)), cl_uint(viennacl::traits::start2(B)), 00055 cl_uint(viennacl::traits::stride1(B)), cl_uint(viennacl::traits::stride2(B)), 00056 cl_uint(viennacl::traits::size1(B)), cl_uint(viennacl::traits::size2(B)), 00057 cl_uint(viennacl::traits::internal_size1(B)), cl_uint(viennacl::traits::internal_size2(B)) 00058 ) 00059 ); 00060 } 00061 } 00062 00063 00064 // 00065 // Note: By convention, all size checks are performed in the calling frontend. No need to double-check here. 00066 // 00067 00069 00074 template <typename NumericT, typename F1, typename F2, typename SOLVERTAG> 00075 void inplace_solve(const matrix_base<NumericT, F1> & A, matrix_base<NumericT, F2> & B, SOLVERTAG) 00076 { 00077 viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(A).context()); 00078 00079 typedef viennacl::linalg::opencl::kernels::matrix_solve<NumericT, F1, F2> KernelClass; 00080 KernelClass::init(ctx); 00081 00082 std::stringstream ss; 00083 ss << SOLVERTAG::name() << "_solve"; 00084 viennacl::ocl::kernel & k = ctx.get_kernel(KernelClass::program_name(), ss.str()); 00085 00086 k.global_work_size(0, B.size2() * k.local_work_size()); 00087 detail::inplace_solve_impl(A, B, k); 00088 } 00089 00095 template <typename NumericT, typename F1, typename F2, typename SOLVERTAG> 00096 void inplace_solve(const matrix_base<NumericT, F1> & A, 00097 matrix_expression< const matrix_base<NumericT, F2>, const matrix_base<NumericT, F2>, op_trans> proxy_B, 00098 SOLVERTAG) 00099 { 00100 viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(A).context()); 00101 00102 typedef viennacl::linalg::opencl::kernels::matrix_solve<NumericT, F1, F2> KernelClass; 00103 
KernelClass::init(ctx); 00104 00105 std::stringstream ss; 00106 ss << SOLVERTAG::name() << "_trans_solve"; 00107 viennacl::ocl::kernel & k = ctx.get_kernel(KernelClass::program_name(), ss.str()); 00108 00109 k.global_work_size(0, proxy_B.lhs().size1() * k.local_work_size()); 00110 detail::inplace_solve_impl(A, proxy_B.lhs(), k); 00111 } 00112 00113 //upper triangular solver for transposed lower triangular matrices 00119 template <typename NumericT, typename F1, typename F2, typename SOLVERTAG> 00120 void inplace_solve(const matrix_expression< const matrix_base<NumericT, F1>, const matrix_base<NumericT, F1>, op_trans> & proxy_A, 00121 matrix_base<NumericT, F2> & B, 00122 SOLVERTAG) 00123 { 00124 viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(B).context()); 00125 00126 typedef viennacl::linalg::opencl::kernels::matrix_solve<NumericT, F1, F2> KernelClass; 00127 KernelClass::init(ctx); 00128 00129 std::stringstream ss; 00130 ss << "trans_" << SOLVERTAG::name() << "_solve"; 00131 viennacl::ocl::kernel & k = ctx.get_kernel(KernelClass::program_name(), ss.str()); 00132 00133 k.global_work_size(0, B.size2() * k.local_work_size()); 00134 detail::inplace_solve_impl(proxy_A.lhs(), B, k); 00135 } 00136 00142 template <typename NumericT, typename F1, typename F2, typename SOLVERTAG> 00143 void inplace_solve(const matrix_expression< const matrix_base<NumericT, F1>, const matrix_base<NumericT, F1>, op_trans> & proxy_A, 00144 matrix_expression< const matrix_base<NumericT, F2>, const matrix_base<NumericT, F2>, op_trans> proxy_B, 00145 SOLVERTAG) 00146 { 00147 viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(proxy_A.lhs()).context()); 00148 00149 typedef viennacl::linalg::opencl::kernels::matrix_solve<NumericT, F1, F2> KernelClass; 00150 KernelClass::init(ctx); 00151 00152 std::stringstream ss; 00153 ss << "trans_" << SOLVERTAG::name() << "_trans_solve"; 00154 
viennacl::ocl::kernel & k = ctx.get_kernel(KernelClass::program_name(), ss.str()); 00155 00156 k.global_work_size(0, proxy_B.lhs().size1() * k.local_work_size()); 00157 detail::inplace_solve_impl(proxy_A.lhs(), proxy_B.lhs(), k); 00158 } 00159 00160 00161 00162 // 00163 // Solve on vector 00164 // 00165 00166 template <typename NumericT, typename F, typename SOLVERTAG> 00167 void inplace_solve(const matrix_base<NumericT, F> & mat, 00168 vector_base<NumericT> & vec, 00169 SOLVERTAG) 00170 { 00171 viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(mat).context()); 00172 00173 typedef viennacl::linalg::opencl::kernels::matrix<NumericT, F> KernelClass; 00174 KernelClass::init(ctx); 00175 00176 cl_uint options = detail::get_option_for_solver_tag(SOLVERTAG()); 00177 viennacl::ocl::kernel & k = ctx.get_kernel(KernelClass::program_name(), "triangular_substitute_inplace"); 00178 00179 k.global_work_size(0, k.local_work_size()); 00180 viennacl::ocl::enqueue(k(viennacl::traits::opencl_handle(mat), 00181 cl_uint(viennacl::traits::start1(mat)), cl_uint(viennacl::traits::start2(mat)), 00182 cl_uint(viennacl::traits::stride1(mat)), cl_uint(viennacl::traits::stride2(mat)), 00183 cl_uint(viennacl::traits::size1(mat)), cl_uint(viennacl::traits::size2(mat)), 00184 cl_uint(viennacl::traits::internal_size1(mat)), cl_uint(viennacl::traits::internal_size2(mat)), 00185 viennacl::traits::opencl_handle(vec), 00186 cl_uint(viennacl::traits::start(vec)), 00187 cl_uint(viennacl::traits::stride(vec)), 00188 cl_uint(viennacl::traits::size(vec)), 00189 options 00190 ) 00191 ); 00192 } 00193 00199 template <typename NumericT, typename F, typename SOLVERTAG> 00200 void inplace_solve(const matrix_expression< const matrix_base<NumericT, F>, const matrix_base<NumericT, F>, op_trans> & proxy, 00201 vector_base<NumericT> & vec, 00202 SOLVERTAG) 00203 { 00204 viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context 
&>(viennacl::traits::opencl_handle(vec).context()); 00205 00206 typedef viennacl::linalg::opencl::kernels::matrix<NumericT, F> KernelClass; 00207 KernelClass::init(ctx); 00208 00209 cl_uint options = detail::get_option_for_solver_tag(SOLVERTAG()) | 0x02; //add transpose-flag 00210 viennacl::ocl::kernel & k = ctx.get_kernel(KernelClass::program_name(), "triangular_substitute_inplace"); 00211 00212 k.global_work_size(0, k.local_work_size()); 00213 viennacl::ocl::enqueue(k(viennacl::traits::opencl_handle(proxy.lhs()), 00214 cl_uint(viennacl::traits::start1(proxy.lhs())), cl_uint(viennacl::traits::start2(proxy.lhs())), 00215 cl_uint(viennacl::traits::stride1(proxy.lhs())), cl_uint(viennacl::traits::stride2(proxy.lhs())), 00216 cl_uint(viennacl::traits::size1(proxy.lhs())), cl_uint(viennacl::traits::size2(proxy.lhs())), 00217 cl_uint(viennacl::traits::internal_size1(proxy.lhs())), cl_uint(viennacl::traits::internal_size2(proxy.lhs())), 00218 viennacl::traits::opencl_handle(vec), 00219 cl_uint(viennacl::traits::start(vec)), 00220 cl_uint(viennacl::traits::stride(vec)), 00221 cl_uint(viennacl::traits::size(vec)), 00222 options 00223 ) 00224 ); 00225 } 00226 00227 00228 } 00229 } 00230 } 00231 00232 #endif