// ViennaCL - The Vienna Computing Library, version 1.5.1
00001 #ifndef VIENNACL_LINALG_OPENCL_KERNELS_COORDINATE_MATRIX_HPP 00002 #define VIENNACL_LINALG_OPENCL_KERNELS_COORDINATE_MATRIX_HPP 00003 00004 #include "viennacl/tools/tools.hpp" 00005 #include "viennacl/ocl/kernel.hpp" 00006 #include "viennacl/ocl/platform.hpp" 00007 #include "viennacl/ocl/utils.hpp" 00008 00009 #include "viennacl/linalg/opencl/common.hpp" 00010 00013 namespace viennacl 00014 { 00015 namespace linalg 00016 { 00017 namespace opencl 00018 { 00019 namespace kernels 00020 { 00021 00023 00024 template <typename StringType> 00025 void generate_coordinate_matrix_vec_mul(StringType & source, std::string const & numeric_string) 00026 { 00027 source.append("__kernel void vec_mul( \n"); 00028 source.append(" __global const uint2 * coords, \n");//(row_index, column_index) 00029 source.append(" __global const "); source.append(numeric_string); source.append(" * elements, \n"); 00030 source.append(" __global const uint * group_boundaries, \n"); 00031 source.append(" __global const "); source.append(numeric_string); source.append(" * x, \n"); 00032 source.append(" uint4 layout_x, \n"); 00033 source.append(" __global "); source.append(numeric_string); source.append(" * result, \n"); 00034 source.append(" uint4 layout_result, \n"); 00035 source.append(" __local unsigned int * shared_rows, \n"); 00036 source.append(" __local "); source.append(numeric_string); source.append(" * inter_results) \n"); 00037 source.append("{ \n"); 00038 source.append(" uint2 tmp; \n"); 00039 source.append(" "); source.append(numeric_string); source.append(" val; \n"); 00040 source.append(" uint group_start = group_boundaries[get_group_id(0)]; \n"); 00041 source.append(" uint group_end = group_boundaries[get_group_id(0) + 1]; \n"); 00042 source.append(" uint k_end = (group_end > group_start) ? 
1 + (group_end - group_start - 1) / get_local_size(0) : 0; \n"); // -1 in order to have correct behavior if group_end - group_start == j * get_local_size(0) 00043 00044 source.append(" uint local_index = 0; \n"); 00045 00046 source.append(" for (uint k = 0; k < k_end; ++k) { \n"); 00047 source.append(" local_index = group_start + k * get_local_size(0) + get_local_id(0); \n"); 00048 00049 source.append(" tmp = (local_index < group_end) ? coords[local_index] : (uint2) 0; \n"); 00050 source.append(" val = (local_index < group_end) ? elements[local_index] * x[tmp.y * layout_x.y + layout_x.x] : 0; \n"); 00051 00052 //check for carry from previous loop run: 00053 source.append(" if (get_local_id(0) == 0 && k > 0) { \n"); 00054 source.append(" if (tmp.x == shared_rows[get_local_size(0)-1]) \n"); 00055 source.append(" val += inter_results[get_local_size(0)-1]; \n"); 00056 source.append(" else \n"); 00057 source.append(" result[shared_rows[get_local_size(0)-1] * layout_result.y + layout_result.x] = inter_results[get_local_size(0)-1]; \n"); 00058 source.append(" } \n"); 00059 00060 //segmented parallel reduction begin 00061 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00062 source.append(" shared_rows[get_local_id(0)] = tmp.x; \n"); 00063 source.append(" inter_results[get_local_id(0)] = val; \n"); 00064 source.append(" "); source.append(numeric_string); source.append(" left = 0; \n"); 00065 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00066 00067 source.append(" for (unsigned int stride = 1; stride < get_local_size(0); stride *= 2) { \n"); 00068 source.append(" left = (get_local_id(0) >= stride && tmp.x == shared_rows[get_local_id(0) - stride]) ? 
inter_results[get_local_id(0) - stride] : 0; \n"); 00069 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00070 source.append(" inter_results[get_local_id(0)] += left; \n"); 00071 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00072 source.append(" } \n"); 00073 //segmented parallel reduction end 00074 00075 source.append(" if (local_index < group_end && get_local_id(0) < get_local_size(0) - 1 && \n"); 00076 source.append(" shared_rows[get_local_id(0)] != shared_rows[get_local_id(0) + 1]) { \n"); 00077 source.append(" result[tmp.x * layout_result.y + layout_result.x] = inter_results[get_local_id(0)]; \n"); 00078 source.append(" } \n"); 00079 00080 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00081 source.append(" } \n"); //for k 00082 00083 source.append(" if (local_index + 1 == group_end) \n"); //write results of last active entry (this may not necessarily be the case already) 00084 source.append(" result[tmp.x * layout_result.y + layout_result.x] = inter_results[get_local_id(0)]; \n"); 00085 source.append("} \n"); 00086 00087 } 00088 00089 namespace detail 00090 { 00092 template <typename StringType> 00093 void generate_coordinate_matrix_dense_matrix_mul(StringType & source, std::string const & numeric_string, 00094 bool B_transposed, bool B_row_major, bool C_row_major) 00095 { 00096 source.append("__kernel void "); 00097 source.append(viennacl::linalg::opencl::detail::sparse_dense_matmult_kernel_name(B_transposed, B_row_major, C_row_major)); 00098 source.append("( \n"); 00099 source.append(" __global const uint2 * coords, \n");//(row_index, column_index) 00100 source.append(" __global const "); source.append(numeric_string); source.append(" * elements, \n"); 00101 source.append(" __global const uint * group_boundaries, \n"); 00102 source.append(" __global const "); source.append(numeric_string); source.append(" * d_mat, \n"); 00103 source.append(" unsigned int d_mat_row_start, \n"); 00104 source.append(" unsigned int d_mat_col_start, \n"); 00105 
source.append(" unsigned int d_mat_row_inc, \n"); 00106 source.append(" unsigned int d_mat_col_inc, \n"); 00107 source.append(" unsigned int d_mat_row_size, \n"); 00108 source.append(" unsigned int d_mat_col_size, \n"); 00109 source.append(" unsigned int d_mat_internal_rows, \n"); 00110 source.append(" unsigned int d_mat_internal_cols, \n"); 00111 source.append(" __global "); source.append(numeric_string); source.append(" * result, \n"); 00112 source.append(" unsigned int result_row_start, \n"); 00113 source.append(" unsigned int result_col_start, \n"); 00114 source.append(" unsigned int result_row_inc, \n"); 00115 source.append(" unsigned int result_col_inc, \n"); 00116 source.append(" unsigned int result_row_size, \n"); 00117 source.append(" unsigned int result_col_size, \n"); 00118 source.append(" unsigned int result_internal_rows, \n"); 00119 source.append(" unsigned int result_internal_cols, \n"); 00120 source.append(" __local unsigned int * shared_rows, \n"); 00121 source.append(" __local "); source.append(numeric_string); source.append(" * inter_results) \n"); 00122 source.append("{ \n"); 00123 source.append(" uint2 tmp; \n"); 00124 source.append(" "); source.append(numeric_string); source.append(" val; \n"); 00125 source.append(" uint group_start = group_boundaries[get_group_id(0)]; \n"); 00126 source.append(" uint group_end = group_boundaries[get_group_id(0) + 1]; \n"); 00127 source.append(" uint k_end = (group_end > group_start) ? 
1 + (group_end - group_start - 1) / get_local_size(0) : 0; \n"); // -1 in order to have correct behavior if group_end - group_start == j * get_local_size(0) 00128 00129 source.append(" uint local_index = 0; \n"); 00130 00131 source.append(" for (uint result_col = 0; result_col < result_col_size; ++result_col) { \n"); 00132 source.append(" for (uint k = 0; k < k_end; ++k) { \n"); 00133 source.append(" local_index = group_start + k * get_local_size(0) + get_local_id(0); \n"); 00134 00135 source.append(" tmp = (local_index < group_end) ? coords[local_index] : (uint2) 0; \n"); 00136 if (B_transposed && B_row_major) 00137 source.append(" val = (local_index < group_end) ? elements[local_index] * d_mat[ (d_mat_row_start + result_col * d_mat_row_inc) * d_mat_internal_cols + d_mat_col_start + tmp.y * d_mat_col_inc ] : 0; \n"); 00138 if (B_transposed && !B_row_major) 00139 source.append(" val = (local_index < group_end) ? elements[local_index] * d_mat[ (d_mat_row_start + result_col * d_mat_row_inc) + (d_mat_col_start + tmp.y * d_mat_col_inc) * d_mat_internal_rows ] : 0; \n"); 00140 else if (!B_transposed && B_row_major) 00141 source.append(" val = (local_index < group_end) ? elements[local_index] * d_mat[ (d_mat_row_start + tmp.y * d_mat_row_inc) * d_mat_internal_cols + d_mat_col_start + result_col * d_mat_col_inc ] : 0; \n"); 00142 else 00143 source.append(" val = (local_index < group_end) ? 
elements[local_index] * d_mat[ (d_mat_row_start + tmp.y * d_mat_row_inc) + (d_mat_col_start + result_col * d_mat_col_inc) * d_mat_internal_rows ] : 0; \n"); 00144 00145 //check for carry from previous loop run: 00146 source.append(" if (get_local_id(0) == 0 && k > 0) { \n"); 00147 source.append(" if (tmp.x == shared_rows[get_local_size(0)-1]) \n"); 00148 source.append(" val += inter_results[get_local_size(0)-1]; \n"); 00149 source.append(" else \n"); 00150 if (C_row_major) 00151 source.append(" result[(shared_rows[get_local_size(0)-1] * result_row_inc + result_row_start) * result_internal_cols + result_col_start + result_col * result_col_inc ] = inter_results[get_local_size(0)-1]; \n"); 00152 else 00153 source.append(" result[(shared_rows[get_local_size(0)-1] * result_row_inc + result_row_start) + (result_col_start + result_col * result_col_inc) * result_internal_rows ] = inter_results[get_local_size(0)-1]; \n"); 00154 source.append(" } \n"); 00155 00156 //segmented parallel reduction begin 00157 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00158 source.append(" shared_rows[get_local_id(0)] = tmp.x; \n"); 00159 source.append(" inter_results[get_local_id(0)] = val; \n"); 00160 source.append(" "); source.append(numeric_string); source.append(" left = 0; \n"); 00161 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00162 00163 source.append(" for (unsigned int stride = 1; stride < get_local_size(0); stride *= 2) { \n"); 00164 source.append(" left = (get_local_id(0) >= stride && tmp.x == shared_rows[get_local_id(0) - stride]) ? 
inter_results[get_local_id(0) - stride] : 0; \n"); 00165 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00166 source.append(" inter_results[get_local_id(0)] += left; \n"); 00167 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00168 source.append(" } \n"); 00169 //segmented parallel reduction end 00170 00171 source.append(" if (local_index < group_end && get_local_id(0) < get_local_size(0) - 1 && \n"); 00172 source.append(" shared_rows[get_local_id(0)] != shared_rows[get_local_id(0) + 1]) { \n"); 00173 if (C_row_major) 00174 source.append(" result[(tmp.x * result_row_inc + result_row_start) * result_internal_cols + result_col_start + result_col * result_col_inc ] = inter_results[get_local_id(0)]; \n"); 00175 else 00176 source.append(" result[(tmp.x * result_row_inc + result_row_start) + (result_col_start + result_col * result_col_inc) * result_internal_rows ] = inter_results[get_local_id(0)]; \n"); 00177 source.append(" } \n"); 00178 00179 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00180 source.append(" } \n"); //for k 00181 00182 source.append(" if (local_index + 1 == group_end) \n"); //write results of last active entry (this may not necessarily be the case already) 00183 if (C_row_major) 00184 source.append(" result[(tmp.x * result_row_inc + result_row_start) * result_internal_cols + result_col_start + result_col * result_col_inc ] = inter_results[get_local_id(0)]; \n"); 00185 else 00186 source.append(" result[(tmp.x * result_row_inc + result_row_start) + (result_col_start + result_col * result_col_inc) * result_internal_rows ] = inter_results[get_local_id(0)]; \n"); 00187 source.append(" } \n"); //for result_col 00188 source.append("} \n"); 00189 00190 } 00191 } 00192 00193 template <typename StringType> 00194 void generate_coordinate_matrix_dense_matrix_multiplication(StringType & source, std::string const & numeric_string) 00195 { 00196 detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, false, false, false); 00197 
detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, false, false, true); 00198 detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, false, true, false); 00199 detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, false, true, true); 00200 00201 detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, true, false, false); 00202 detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, true, false, true); 00203 detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, true, true, false); 00204 detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, true, true, true); 00205 } 00206 00207 template <typename StringType> 00208 void generate_coordinate_matrix_row_info_extractor(StringType & source, std::string const & numeric_string) 00209 { 00210 source.append("__kernel void row_info_extractor( \n"); 00211 source.append(" __global const uint2 * coords, \n");//(row_index, column_index) 00212 source.append(" __global const "); source.append(numeric_string); source.append(" * elements, \n"); 00213 source.append(" __global const uint * group_boundaries, \n"); 00214 source.append(" __global "); source.append(numeric_string); source.append(" * result, \n"); 00215 source.append(" unsigned int option, \n"); 00216 source.append(" __local unsigned int * shared_rows, \n"); 00217 source.append(" __local "); source.append(numeric_string); source.append(" * inter_results) \n"); 00218 source.append("{ \n"); 00219 source.append(" uint2 tmp; \n"); 00220 source.append(" "); source.append(numeric_string); source.append(" val; \n"); 00221 source.append(" uint last_index = get_local_size(0) - 1; \n"); 00222 source.append(" uint group_start = group_boundaries[get_group_id(0)]; \n"); 00223 source.append(" uint group_end = group_boundaries[get_group_id(0) + 1]; \n"); 00224 source.append(" uint k_end = (group_end > group_start) ? 
1 + (group_end - group_start - 1) / get_local_size(0) : ("); source.append(numeric_string); source.append(")0; \n"); // -1 in order to have correct behavior if group_end - group_start == j * get_local_size(0) 00225 00226 source.append(" uint local_index = 0; \n"); 00227 00228 source.append(" for (uint k = 0; k < k_end; ++k) \n"); 00229 source.append(" { \n"); 00230 source.append(" local_index = group_start + k * get_local_size(0) + get_local_id(0); \n"); 00231 00232 source.append(" tmp = (local_index < group_end) ? coords[local_index] : (uint2) 0; \n"); 00233 source.append(" val = (local_index < group_end && (option != 3 || tmp.x == tmp.y) ) ? elements[local_index] : 0; \n"); 00234 00235 //check for carry from previous loop run: 00236 source.append(" if (get_local_id(0) == 0 && k > 0) \n"); 00237 source.append(" { \n"); 00238 source.append(" if (tmp.x == shared_rows[last_index]) \n"); 00239 source.append(" { \n"); 00240 source.append(" switch (option) \n"); 00241 source.append(" { \n"); 00242 source.append(" case 0: \n"); //inf-norm 00243 source.append(" case 3: \n"); //diagonal entry 00244 source.append(" val = max(val, fabs(inter_results[last_index])); \n"); 00245 source.append(" break; \n"); 00246 00247 source.append(" case 1: \n"); //1-norm 00248 source.append(" val = fabs(val) + inter_results[last_index]; \n"); 00249 source.append(" break; \n"); 00250 00251 source.append(" case 2: \n"); //2-norm 00252 source.append(" val = sqrt(val * val + inter_results[last_index]); \n"); 00253 source.append(" break; \n"); 00254 00255 source.append(" default: \n"); 00256 source.append(" break; \n"); 00257 source.append(" } \n"); 00258 source.append(" } \n"); 00259 source.append(" else \n"); 00260 source.append(" { \n"); 00261 source.append(" switch (option) \n"); 00262 source.append(" { \n"); 00263 source.append(" case 0: \n"); //inf-norm 00264 source.append(" case 1: \n"); //1-norm 00265 source.append(" case 3: \n"); //diagonal entry 00266 source.append(" 
result[shared_rows[last_index]] = inter_results[last_index]; \n"); 00267 source.append(" break; \n"); 00268 00269 source.append(" case 2: \n"); //2-norm 00270 source.append(" result[shared_rows[last_index]] = sqrt(inter_results[last_index]); \n"); 00271 source.append(" default: \n"); 00272 source.append(" break; \n"); 00273 source.append(" } \n"); 00274 source.append(" } \n"); 00275 source.append(" } \n"); 00276 00277 //segmented parallel reduction begin 00278 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00279 source.append(" shared_rows[get_local_id(0)] = tmp.x; \n"); 00280 source.append(" switch (option) \n"); 00281 source.append(" { \n"); 00282 source.append(" case 0: \n"); 00283 source.append(" case 3: \n"); 00284 source.append(" inter_results[get_local_id(0)] = val; \n"); 00285 source.append(" break; \n"); 00286 source.append(" case 1: \n"); 00287 source.append(" inter_results[get_local_id(0)] = fabs(val); \n"); 00288 source.append(" break; \n"); 00289 source.append(" case 2: \n"); 00290 source.append(" inter_results[get_local_id(0)] = val * val; \n"); 00291 source.append(" default: \n"); 00292 source.append(" break; \n"); 00293 source.append(" } \n"); 00294 source.append(" "); source.append(numeric_string); source.append(" left = 0; \n"); 00295 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00296 00297 source.append(" for (unsigned int stride = 1; stride < get_local_size(0); stride *= 2) \n"); 00298 source.append(" { \n"); 00299 source.append(" left = (get_local_id(0) >= stride && tmp.x == shared_rows[get_local_id(0) - stride]) ? 
inter_results[get_local_id(0) - stride] : ("); source.append(numeric_string); source.append(")0; \n"); 00300 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00301 source.append(" switch (option) \n"); 00302 source.append(" { \n"); 00303 source.append(" case 0: \n"); //inf-norm 00304 source.append(" case 3: \n"); //diagonal entry 00305 source.append(" inter_results[get_local_id(0)] = max(inter_results[get_local_id(0)], left); \n"); 00306 source.append(" break; \n"); 00307 00308 source.append(" case 1: \n"); //1-norm 00309 source.append(" inter_results[get_local_id(0)] += left; \n"); 00310 source.append(" break; \n"); 00311 00312 source.append(" case 2: \n"); //2-norm 00313 source.append(" inter_results[get_local_id(0)] += left; \n"); 00314 source.append(" break; \n"); 00315 00316 source.append(" default: \n"); 00317 source.append(" break; \n"); 00318 source.append(" } \n"); 00319 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00320 source.append(" } \n"); 00321 //segmented parallel reduction end 00322 00323 source.append(" if (get_local_id(0) != last_index && \n"); 00324 source.append(" shared_rows[get_local_id(0)] != shared_rows[get_local_id(0) + 1] && \n"); 00325 source.append(" inter_results[get_local_id(0)] != 0) \n"); 00326 source.append(" { \n"); 00327 source.append(" result[tmp.x] = (option == 2) ? sqrt(inter_results[get_local_id(0)]) : inter_results[get_local_id(0)]; \n"); 00328 source.append(" } \n"); 00329 00330 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00331 source.append(" } \n"); //for k 00332 00333 source.append(" if (get_local_id(0) == last_index && inter_results[last_index] != 0) \n"); 00334 source.append(" result[tmp.x] = (option == 2) ? 
sqrt(inter_results[last_index]) : inter_results[last_index]; \n"); 00335 source.append("} \n"); 00336 } 00337 00339 00340 // main kernel class 00342 template <typename NumericT> 00343 struct coordinate_matrix 00344 { 00345 static std::string program_name() 00346 { 00347 return viennacl::ocl::type_to_string<NumericT>::apply() + "_coordinate_matrix"; 00348 } 00349 00350 static void init(viennacl::ocl::context & ctx) 00351 { 00352 viennacl::ocl::DOUBLE_PRECISION_CHECKER<NumericT>::apply(ctx); 00353 std::string numeric_string = viennacl::ocl::type_to_string<NumericT>::apply(); 00354 00355 static std::map<cl_context, bool> init_done; 00356 if (!init_done[ctx.handle().get()]) 00357 { 00358 std::string source; 00359 source.reserve(1024); 00360 00361 viennacl::ocl::append_double_precision_pragma<NumericT>(ctx, source); 00362 00363 generate_coordinate_matrix_vec_mul(source, numeric_string); 00364 generate_coordinate_matrix_dense_matrix_multiplication(source, numeric_string); 00365 generate_coordinate_matrix_row_info_extractor(source, numeric_string); 00366 00367 std::string prog_name = program_name(); 00368 #ifdef VIENNACL_BUILD_INFO 00369 std::cout << "Creating program " << prog_name << std::endl; 00370 #endif 00371 ctx.add_program(source, prog_name); 00372 init_done[ctx.handle().get()] = true; 00373 } //if 00374 } //init 00375 }; 00376 00377 } // namespace kernels 00378 } // namespace opencl 00379 } // namespace linalg 00380 } // namespace viennacl 00381 #endif 00382