ViennaCL - The Vienna Computing Library  1.5.1
viennacl/fft.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_FFT_HPP
00002 #define VIENNACL_FFT_HPP
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2014, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008    Portions of this software are copyright by UChicago Argonne, LLC.
00009 
00010                             -----------------
00011                   ViennaCL - The Vienna Computing Library
00012                             -----------------
00013 
00014    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00015 
00016    (A list of authors and contributors can be found in the PDF manual)
00017 
00018    License:         MIT (X11), see file LICENSE in the base directory
00019 ============================================================================= */
00020 
00025 #include <viennacl/vector.hpp>
00026 #include <viennacl/matrix.hpp>
00027 
00028 #include "viennacl/linalg/opencl/kernels/fft.hpp"
00029 
00030 #include <cmath>
00031 
00032 #include <stdexcept>
00033 
00034 namespace viennacl
00035 {
00036   namespace detail
00037   {
00038     namespace fft
00039     {
00040         const vcl_size_t MAX_LOCAL_POINTS_NUM = 512;
00041 
00042         namespace FFT_DATA_ORDER {
00043             enum DATA_ORDER {
00044                 ROW_MAJOR,
00045                 COL_MAJOR
00046             };
00047         }
00048     }
00049   }
00050 }
00051 
00053 namespace viennacl
00054 {
00055   namespace detail
00056   {
00057     namespace fft
00058     {
00059 
00060         inline bool is_radix2(vcl_size_t data_size) {
00061             return !((data_size > 2) && (data_size & (data_size - 1)));
00062 
00063         }
00064 
00065         inline vcl_size_t next_power_2(vcl_size_t n) {
00066             n = n - 1;
00067 
00068             vcl_size_t power = 1;
00069 
00070             while(power < sizeof(vcl_size_t) * 8) {
00071                 n = n | (n >> power);
00072                 power *= 2;
00073             }
00074 
00075             return n + 1;
00076         }
00077 
00078         inline vcl_size_t num_bits(vcl_size_t size)
00079         {
00080             vcl_size_t bits_datasize = 0;
00081             vcl_size_t ds = 1;
00082 
00083             while(ds < size)
00084             {
00085                 ds = ds << 1;
00086                 bits_datasize++;
00087             }
00088 
00089             return bits_datasize;
00090         }
00091 
00092 
00099         template<class SCALARTYPE>
00100         void direct(const viennacl::ocl::handle<cl_mem>& in,
00101                     const viennacl::ocl::handle<cl_mem>& out,
00102                     vcl_size_t size,
00103                     vcl_size_t stride,
00104                     vcl_size_t batch_num,
00105                     SCALARTYPE sign = -1.0f,
00106                     FFT_DATA_ORDER::DATA_ORDER data_order = FFT_DATA_ORDER::ROW_MAJOR
00107                     )
00108         {
00109           viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(in.context());
00110           viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::init(ctx);
00111 
00112           std::string program_string = viennacl::linalg::opencl::kernels::matrix<SCALARTYPE, row_major>::program_name();
00113           if (data_order == FFT_DATA_ORDER::COL_MAJOR)
00114           {
00115             viennacl::linalg::opencl::kernels::matrix<SCALARTYPE, column_major>::init(ctx);
00116             program_string = viennacl::linalg::opencl::kernels::matrix<SCALARTYPE, column_major>::program_name();
00117           }
00118           else
00119             viennacl::linalg::opencl::kernels::matrix<SCALARTYPE, row_major>::init(ctx);
00120           viennacl::ocl::kernel& kernel = ctx.get_kernel(program_string, "fft_direct");
00121           viennacl::ocl::enqueue(kernel(in, out, static_cast<cl_uint>(size), static_cast<cl_uint>(stride), static_cast<cl_uint>(batch_num), sign));
00122         }
00123 
00124         /*
00125         * This function performs reorder of input data. Indexes are sorted in bit-reversal order.
00126         * Such reordering should be done before in-place FFT.
00127         */
00128         template <typename SCALARTYPE>
00129         void reorder(const viennacl::ocl::handle<cl_mem>& in,
00130                      vcl_size_t size,
00131                      vcl_size_t stride,
00132                      vcl_size_t bits_datasize,
00133                      vcl_size_t batch_num,
00134                      FFT_DATA_ORDER::DATA_ORDER data_order = FFT_DATA_ORDER::ROW_MAJOR
00135                      )
00136         {
00137           viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(in.context());
00138           viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::init(ctx);
00139 
00140           std::string program_string = viennacl::linalg::opencl::kernels::matrix<SCALARTYPE, row_major>::program_name();
00141           if (data_order == FFT_DATA_ORDER::COL_MAJOR)
00142           {
00143             viennacl::linalg::opencl::kernels::matrix<SCALARTYPE, column_major>::init(ctx);
00144             program_string = viennacl::linalg::opencl::kernels::matrix<SCALARTYPE, column_major>::program_name();
00145           }
00146           else
00147             viennacl::linalg::opencl::kernels::matrix<SCALARTYPE, row_major>::init(ctx);
00148 
00149           viennacl::ocl::kernel& kernel = ctx.get_kernel(program_string, "fft_reorder");
00150           viennacl::ocl::enqueue(kernel(in,
00151                                         static_cast<cl_uint>(bits_datasize),
00152                                         static_cast<cl_uint>(size),
00153                                         static_cast<cl_uint>(stride),
00154                                         static_cast<cl_uint>(batch_num)
00155                                        )
00156                                 );
00157         }
00158 
00166         template<class SCALARTYPE>
00167         void radix2(const viennacl::ocl::handle<cl_mem>& in,
00168                     vcl_size_t size,
00169                     vcl_size_t stride,
00170                     vcl_size_t batch_num,
00171                     SCALARTYPE sign = -1.0f,
00172                     FFT_DATA_ORDER::DATA_ORDER data_order = FFT_DATA_ORDER::ROW_MAJOR
00173                     )
00174         {
00175           viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(in.context());
00176           viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::init(ctx);
00177 
00178             assert(batch_num != 0);
00179             assert(is_radix2(size));
00180 
00181             std::string program_string = viennacl::linalg::opencl::kernels::matrix<SCALARTYPE, row_major>::program_name();
00182             if (data_order == FFT_DATA_ORDER::COL_MAJOR)
00183             {
00184               viennacl::linalg::opencl::kernels::matrix<SCALARTYPE, column_major>::init(ctx);
00185               program_string = viennacl::linalg::opencl::kernels::matrix<SCALARTYPE, column_major>::program_name();
00186             }
00187             else
00188               viennacl::linalg::opencl::kernels::matrix<SCALARTYPE, row_major>::init(ctx);
00189 
00190             vcl_size_t bits_datasize = num_bits(size);
00191 
00192             if(size <= MAX_LOCAL_POINTS_NUM)
00193             {
00194                 viennacl::ocl::kernel& kernel = ctx.get_kernel(program_string, "fft_radix2_local");
00195                 viennacl::ocl::enqueue(kernel(in,
00196                                               viennacl::ocl::local_mem((size * 4) * sizeof(SCALARTYPE)),
00197                                               static_cast<cl_uint>(bits_datasize),
00198                                               static_cast<cl_uint>(size),
00199                                               static_cast<cl_uint>(stride),
00200                                               static_cast<cl_uint>(batch_num),
00201                                               sign));
00202             }
00203             else
00204             {
00205                 reorder<SCALARTYPE>(in, size, stride, bits_datasize, batch_num);
00206 
00207                 for(vcl_size_t step = 0; step < bits_datasize; step++)
00208                 {
00209                     viennacl::ocl::kernel& kernel = ctx.get_kernel(program_string, "fft_radix2");
00210                     viennacl::ocl::enqueue(kernel(in,
00211                                                   static_cast<cl_uint>(step),
00212                                                   static_cast<cl_uint>(bits_datasize),
00213                                                   static_cast<cl_uint>(size),
00214                                                   static_cast<cl_uint>(stride),
00215                                                   static_cast<cl_uint>(batch_num),
00216                                                   sign));
00217                 }
00218 
00219             }
00220         }
00221 
00229         template<class SCALARTYPE, unsigned int ALIGNMENT>
00230         void bluestein(viennacl::vector<SCALARTYPE, ALIGNMENT>& in,
00231                        viennacl::vector<SCALARTYPE, ALIGNMENT>& out,
00232                        vcl_size_t /*batch_num*/)
00233         {
00234           viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(in).context());
00235           viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::init(ctx);
00236 
00237           vcl_size_t size = in.size() >> 1;
00238           vcl_size_t ext_size = next_power_2(2 * size - 1);
00239 
00240           viennacl::vector<SCALARTYPE, ALIGNMENT> A(ext_size << 1);
00241           viennacl::vector<SCALARTYPE, ALIGNMENT> B(ext_size << 1);
00242 
00243           viennacl::vector<SCALARTYPE, ALIGNMENT> Z(ext_size << 1);
00244 
00245             {
00246                 viennacl::ocl::kernel& kernel = ctx.get_kernel(viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::program_name(), "zero2");
00247                 viennacl::ocl::enqueue(kernel(
00248                                             A,
00249                                             B,
00250                                             static_cast<cl_uint>(ext_size)
00251                                             ));
00252 
00253             }
00254             {
00255                 viennacl::ocl::kernel& kernel = ctx.get_kernel(viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::program_name(), "bluestein_pre");
00256                 viennacl::ocl::enqueue(kernel(
00257                                            in,
00258                                            A,
00259                                            B,
00260                                            static_cast<cl_uint>(size),
00261                                            static_cast<cl_uint>(ext_size)
00262                                        ));
00263             }
00264 
00265             viennacl::linalg::convolve_i(A, B, Z);
00266 
00267             {
00268                 viennacl::ocl::kernel& kernel = ctx.get_kernel(viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::program_name(), "bluestein_post");
00269                 viennacl::ocl::enqueue(kernel(
00270                                             Z,
00271                                             out,
00272                                             static_cast<cl_uint>(size)
00273                                             ));
00274             }
00275         }
00276 
00277         template<class SCALARTYPE, unsigned int ALIGNMENT>
00278         void multiply(viennacl::vector<SCALARTYPE, ALIGNMENT> const & input1,
00279                       viennacl::vector<SCALARTYPE, ALIGNMENT> const & input2,
00280                       viennacl::vector<SCALARTYPE, ALIGNMENT> & output)
00281         {
00282           viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(input1).context());
00283           viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::init(ctx);
00284           vcl_size_t size = input1.size() >> 1;
00285           viennacl::ocl::kernel& kernel = ctx.get_kernel(viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::program_name(), "fft_mult_vec");
00286           viennacl::ocl::enqueue(kernel(input1, input2, output, static_cast<cl_uint>(size)));
00287         }
00288 
00289         template<class SCALARTYPE, unsigned int ALIGNMENT>
00290         void normalize(viennacl::vector<SCALARTYPE, ALIGNMENT> & input)
00291         {
00292           viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(input).context());
00293           viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::init(ctx);
00294 
00295           viennacl::ocl::kernel& kernel = ctx.get_kernel(viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::program_name(), "fft_div_vec_scalar");
00296           vcl_size_t size = input.size() >> 1;
00297           SCALARTYPE norm_factor = static_cast<SCALARTYPE>(size);
00298           viennacl::ocl::enqueue(kernel(input, static_cast<cl_uint>(size), norm_factor));
00299         }
00300 
00301         template<class SCALARTYPE, unsigned int ALIGNMENT>
00302         void transpose(viennacl::matrix<SCALARTYPE, viennacl::row_major, ALIGNMENT> & input)
00303         {
00304           viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(input).context());
00305           viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::init(ctx);
00306 
00307           viennacl::ocl::kernel& kernel = ctx.get_kernel(viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::program_name(), "transpose_inplace");
00308           viennacl::ocl::enqueue(kernel(input,
00309                                         static_cast<cl_uint>(input.internal_size1()),
00310                                         static_cast<cl_uint>(input.internal_size2()) >> 1));
00311         }
00312 
00313         template<class SCALARTYPE, unsigned int ALIGNMENT>
00314         void transpose(viennacl::matrix<SCALARTYPE, viennacl::row_major, ALIGNMENT> const & input,
00315                        viennacl::matrix<SCALARTYPE, viennacl::row_major, ALIGNMENT> & output)
00316         {
00317           viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(input).context());
00318           viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::init(ctx);
00319 
00320           viennacl::ocl::kernel& kernel = ctx.get_kernel(viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::program_name(), "transpose");
00321           viennacl::ocl::enqueue(kernel(input,
00322                                         output,
00323                                         static_cast<cl_uint>(input.internal_size1()),
00324                                         static_cast<cl_uint>(input.internal_size2() >> 1))
00325                                 );
00326         }
00327 
00328         template<class SCALARTYPE>
00329         void real_to_complex(viennacl::vector_base<SCALARTYPE> const & in,
00330                              viennacl::vector_base<SCALARTYPE> & out,
00331                              vcl_size_t size)
00332         {
00333           viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(in).context());
00334           viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::init(ctx);
00335           viennacl::ocl::kernel & kernel = ctx.get_kernel(viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::program_name(), "real_to_complex");
00336           viennacl::ocl::enqueue(kernel(in, out, static_cast<cl_uint>(size)));
00337         }
00338 
00339         template<class SCALARTYPE>
00340         void complex_to_real(viennacl::vector_base<SCALARTYPE> const & in,
00341                              viennacl::vector_base<SCALARTYPE>& out,
00342                              vcl_size_t size)
00343         {
00344           viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(in).context());
00345           viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::init(ctx);
00346           viennacl::ocl::kernel& kernel = ctx.get_kernel(viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::program_name(), "complex_to_real");
00347           viennacl::ocl::enqueue(kernel(in, out, static_cast<cl_uint>(size)));
00348         }
00349 
00350         template<class SCALARTYPE>
00351         void reverse(viennacl::vector_base<SCALARTYPE>& in)
00352         {
00353           viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(in).context());
00354           viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::init(ctx);
00355           vcl_size_t size = in.size();
00356           viennacl::ocl::kernel& kernel = ctx.get_kernel(viennacl::linalg::opencl::kernels::fft<SCALARTYPE>::program_name(), "reverse_inplace");
00357           viennacl::ocl::enqueue(kernel(in, static_cast<cl_uint>(size)));
00358         }
00359 
00360 
00361     } //namespace fft
00362   } //namespace detail
00363 
00371   template<class SCALARTYPE, unsigned int ALIGNMENT>
00372   void inplace_fft(viennacl::vector<SCALARTYPE, ALIGNMENT>& input,
00373             vcl_size_t batch_num = 1,
00374             SCALARTYPE sign = -1.0)
00375   {
00376       vcl_size_t size = (input.size() >> 1) / batch_num;
00377 
00378       if(!viennacl::detail::fft::is_radix2(size))
00379       {
00380           viennacl::vector<SCALARTYPE, ALIGNMENT> output(input.size());
00381           viennacl::detail::fft::direct(viennacl::traits::opencl_handle(input),
00382                                         viennacl::traits::opencl_handle(output),
00383                                         size,
00384                                         size,
00385                                         batch_num,
00386                                         sign);
00387 
00388           viennacl::copy(output, input);
00389       } else {
00390           viennacl::detail::fft::radix2(viennacl::traits::opencl_handle(input), size, size, batch_num, sign);
00391       }
00392   }
00393 
00402   template<class SCALARTYPE, unsigned int ALIGNMENT>
00403   void fft(viennacl::vector<SCALARTYPE, ALIGNMENT>& input,
00404             viennacl::vector<SCALARTYPE, ALIGNMENT>& output,
00405             vcl_size_t batch_num = 1,
00406             SCALARTYPE sign = -1.0
00407             )
00408   {
00409       vcl_size_t size = (input.size() >> 1) / batch_num;
00410 
00411       if(viennacl::detail::fft::is_radix2(size))
00412       {
00413           viennacl::copy(input, output);
00414           viennacl::detail::fft::radix2(viennacl::traits::opencl_handle(output), size, size, batch_num, sign);
00415       } else {
00416           viennacl::detail::fft::direct(viennacl::traits::opencl_handle(input),
00417                                         viennacl::traits::opencl_handle(output),
00418                                         size,
00419                                         size,
00420                                         batch_num,
00421                                         sign);
00422       }
00423   }
00424 
00431   template<class SCALARTYPE, unsigned int ALIGNMENT>
00432   void inplace_fft(viennacl::matrix<SCALARTYPE, viennacl::row_major, ALIGNMENT>& input,
00433             SCALARTYPE sign = -1.0)
00434   {
00435       vcl_size_t rows_num = input.size1();
00436       vcl_size_t cols_num = input.size2() >> 1;
00437 
00438       vcl_size_t cols_int = input.internal_size2() >> 1;
00439 
00440       // batch with rows
00441       if(viennacl::detail::fft::is_radix2(cols_num))
00442       {
00443           viennacl::detail::fft::radix2(viennacl::traits::opencl_handle(input), cols_num, cols_int, rows_num, sign, viennacl::detail::fft::FFT_DATA_ORDER::ROW_MAJOR);
00444       }
00445       else
00446       {
00447           viennacl::matrix<SCALARTYPE, viennacl::row_major, ALIGNMENT> output(input.size1(), input.size2());
00448 
00449           viennacl::detail::fft::direct(viennacl::traits::opencl_handle(input),
00450                                         viennacl::traits::opencl_handle(output),
00451                                         cols_num,
00452                                         cols_int,
00453                                         rows_num,
00454                                         sign,
00455                                         viennacl::detail::fft::FFT_DATA_ORDER::ROW_MAJOR
00456                                         );
00457 
00458           input = output;
00459       }
00460 
00461       // batch with cols
00462       if (viennacl::detail::fft::is_radix2(rows_num)) {
00463           viennacl::detail::fft::radix2(viennacl::traits::opencl_handle(input), rows_num, cols_int, cols_num, sign, viennacl::detail::fft::FFT_DATA_ORDER::COL_MAJOR);
00464       } else {
00465           viennacl::matrix<SCALARTYPE, viennacl::row_major, ALIGNMENT> output(input.size1(), input.size2());
00466 
00467           viennacl::detail::fft::direct(viennacl::traits::opencl_handle(input),
00468                                         viennacl::traits::opencl_handle(output),
00469                                         rows_num,
00470                                         cols_int,
00471                                         cols_num,
00472                                         sign,
00473                                         viennacl::detail::fft::FFT_DATA_ORDER::COL_MAJOR);
00474 
00475           input = output;
00476       }
00477 
00478   }
00479 
00487   template<class SCALARTYPE, unsigned int ALIGNMENT>
00488   void fft(viennacl::matrix<SCALARTYPE, viennacl::row_major, ALIGNMENT>& input,
00489             viennacl::matrix<SCALARTYPE, viennacl::row_major, ALIGNMENT>& output,
00490             SCALARTYPE sign = -1.0)
00491   {
00492       vcl_size_t rows_num = input.size1();
00493       vcl_size_t cols_num = input.size2() >> 1;
00494 
00495       vcl_size_t cols_int = input.internal_size2() >> 1;
00496 
00497       // batch with rows
00498       if(viennacl::detail::fft::is_radix2(cols_num))
00499       {
00500           output = input;
00501           viennacl::detail::fft::radix2(viennacl::traits::opencl_handle(output), cols_num, cols_int, rows_num, sign, viennacl::detail::fft::FFT_DATA_ORDER::ROW_MAJOR);
00502       }
00503       else
00504       {
00505           viennacl::detail::fft::direct(viennacl::traits::opencl_handle(input),
00506                                         viennacl::traits::opencl_handle(output),
00507                                         cols_num,
00508                                         cols_int,
00509                                         rows_num,
00510                                         sign,
00511                                         viennacl::detail::fft::FFT_DATA_ORDER::ROW_MAJOR
00512                                         );
00513       }
00514 
00515       // batch with cols
00516       if(viennacl::detail::fft::is_radix2(rows_num))
00517       {
00518           viennacl::detail::fft::radix2(viennacl::traits::opencl_handle(output), rows_num, cols_int, cols_num, sign, viennacl::detail::fft::FFT_DATA_ORDER::COL_MAJOR);
00519       }
00520       else
00521       {
00522           viennacl::matrix<SCALARTYPE, viennacl::row_major, ALIGNMENT> tmp(output.size1(), output.size2());
00523           tmp = output;
00524 
00525           viennacl::detail::fft::direct(viennacl::traits::opencl_handle(tmp),
00526                               viennacl::traits::opencl_handle(output),
00527                               rows_num,
00528                               cols_int,
00529                               cols_num,
00530                               sign,
00531                               viennacl::detail::fft::FFT_DATA_ORDER::COL_MAJOR);
00532       }
00533   }
00534 
00544   template<class SCALARTYPE, unsigned int ALIGNMENT>
00545   void inplace_ifft(viennacl::vector<SCALARTYPE, ALIGNMENT>& input,
00546             vcl_size_t batch_num = 1)
00547   {
00548       viennacl::inplace_fft(input, batch_num, SCALARTYPE(1.0));
00549       viennacl::detail::fft::normalize(input);
00550   }
00551 
00562   template<class SCALARTYPE, unsigned int ALIGNMENT>
00563   void ifft(viennacl::vector<SCALARTYPE, ALIGNMENT>& input,
00564             viennacl::vector<SCALARTYPE, ALIGNMENT>& output,
00565             vcl_size_t batch_num = 1
00566             )
00567   {
00568       viennacl::fft(input, output, batch_num, SCALARTYPE(1.0));
00569       viennacl::detail::fft::normalize(output);
00570   }
00571 
00572   namespace linalg
00573   {
00583     template<class SCALARTYPE, unsigned int ALIGNMENT>
00584     void convolve(viennacl::vector<SCALARTYPE, ALIGNMENT>& input1,
00585                   viennacl::vector<SCALARTYPE, ALIGNMENT>& input2,
00586                   viennacl::vector<SCALARTYPE, ALIGNMENT>& output
00587                   )
00588     {
00589         assert(input1.size() == input2.size());
00590         assert(input1.size() == output.size());
00591         //temporal arrays
00592         viennacl::vector<SCALARTYPE, ALIGNMENT> tmp1(input1.size());
00593         viennacl::vector<SCALARTYPE, ALIGNMENT> tmp2(input2.size());
00594         viennacl::vector<SCALARTYPE, ALIGNMENT> tmp3(output.size());
00595 
00596         // align input arrays to equal size
00597         // FFT of input data
00598         viennacl::fft(input1, tmp1);
00599         viennacl::fft(input2, tmp2);
00600 
00601         // multiplication of input data
00602         viennacl::detail::fft::multiply(tmp1, tmp2, tmp3);
00603         // inverse FFT of input data
00604         viennacl::ifft(tmp3, output);
00605     }
00606 
00616     template<class SCALARTYPE, unsigned int ALIGNMENT>
00617     void convolve_i(viennacl::vector<SCALARTYPE, ALIGNMENT>& input1,
00618                     viennacl::vector<SCALARTYPE, ALIGNMENT>& input2,
00619                     viennacl::vector<SCALARTYPE, ALIGNMENT>& output
00620                     )
00621     {
00622         assert(input1.size() == input2.size());
00623         assert(input1.size() == output.size());
00624 
00625         viennacl::inplace_fft(input1);
00626         viennacl::inplace_fft(input2);
00627 
00628         viennacl::detail::fft::multiply(input1, input2, output);
00629 
00630         viennacl::inplace_ifft(output);
00631     }
00632   }
} //namespace viennacl
00634 
00636 #endif