aboutsummaryrefslogtreecommitdiff
path: root/Source/FieldSolver/SpectralSolver/AnyFFT.H
blob: 79dbdf0e03d0b9977738d0b57fd875e2788f87ec (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
/* Copyright 2019-2020
 *
 * This file is part of WarpX.
 *
 * License: BSD-3-Clause-LBNL
 */

#ifndef ANYFFT_H_
#define ANYFFT_H_

#include <AMReX_Config.H>
#include <AMReX_LayoutData.H>

#if defined(AMREX_USE_CUDA)
#  include <cufft.h>
#elif defined(AMREX_USE_HIP)
#  if __has_include(<rocfft/rocfft.h>)  // ROCm 5.3+
#    include <rocfft/rocfft.h>
#  else
#    include <rocfft.h>
#  endif
#else
#  include <fftw3.h>
#endif

/**
 * Wrapper around FFT libraries. The header file defines the API and the base types
 * (Complex and VendorFFTPlan), and the implementation for the different FFT libraries
 * lives in separate cpp files. This wrapper depends only on the underlying FFT library
 * and on AMReX; it has no dependence on WarpX.
 */
namespace AnyFFT
{
    // First, define library-dependent types (complex, FFT plan)

    /** Complex type for FFT, depends on FFT library (and on the precision of amrex::Real) */
#if defined(AMREX_USE_CUDA)
#  ifdef AMREX_USE_FLOAT
    using Complex = cuComplex;
#  else
    using Complex = cuDoubleComplex;
#  endif
#elif defined(AMREX_USE_HIP)
#  ifdef AMREX_USE_FLOAT
    using Complex = float2;
#  else
    using Complex = double2;
#  endif
#else
#  ifdef AMREX_USE_FLOAT
    using Complex = fftwf_complex;
#  else
    using Complex = fftw_complex;
#  endif
#endif

    /** Library-dependent handle type for ONE vendor FFT plan:
     * cufftHandle (CUDA), rocfft_plan (HIP), or fftwf_plan/fftw_plan (CPU, FFTW).
     * Per-box collections of plans are provided by the FFTplans type below.
     */
#if defined(AMREX_USE_CUDA)
    using VendorFFTPlan = cufftHandle;
#elif defined(AMREX_USE_HIP)
    using VendorFFTPlan = rocfft_plan;
#else
#  ifdef AMREX_USE_FLOAT
    using VendorFFTPlan = fftwf_plan;
#  else
    using VendorFFTPlan = fftw_plan;
#  endif
#endif

    // Second, define library-independent API

    /** Direction in which the FFT is performed. */
    enum struct direction {R2C, C2R};

    /** This struct contains the vendor FFT plan and additional metadata.
     *
     * All members carry default initializers so that a default-constructed
     * FFTplan (e.g. an element of FFTplans before CreatePlan is called) holds
     * null pointers and a value-initialized vendor handle instead of
     * indeterminate values. The struct remains an aggregate (C++14+).
     */
    struct FFTplan
    {
        amrex::Real* m_real_array = nullptr; /**< pointer to real array */
        Complex* m_complex_array = nullptr; /**< pointer to complex array */
        VendorFFTPlan m_plan{}; /**< Vendor FFT plan (value-initialized until CreatePlan sets it) */
        direction m_dir = direction::R2C;  /**< direction (C2R or R2C) */
        int m_dim = 0; /**< Dimensionality of the FFT plan (0 until CreatePlan sets it) */
    };

    /** Collection of FFT plans, one FFTplan per box
     * (plans are only initialized for the boxes that are owned by the local MPI rank).
     */
    using FFTplans = amrex::LayoutData<FFTplan>;

    /** \brief create FFT plan for the backend FFT library.
     * \param[in] real_size Size of the real array, along each dimension.
     *                      Only the first dim elements are used.
     * \param[in,out] real_array Real array from/to where R2C/C2R FFT is performed
     *                           (input for R2C, output for C2R)
     * \param[in,out] complex_array Complex array to/from where R2C/C2R FFT is performed
     *                              (output for R2C, input for C2R)
     * \param[in] dir direction, either R2C or C2R
     * \param[in] dim number of dimensions of the arrays. Must be <= AMREX_SPACEDIM.
     * \return the plan; it wraps a vendor resource and must eventually be released
     *         with DestroyPlan, hence [[nodiscard]].
     * NOTE(review): FFTplan stores pointers to the arrays, so they presumably must
     * stay valid for as long as the plan is executed — confirm in the backend sources.
     */
    [[nodiscard]] FFTplan CreatePlan(const amrex::IntVect& real_size, amrex::Real* real_array,
                                     Complex* complex_array, direction dir, int dim);

    /** \brief Destroy library FFT plan.
     * \param[in,out] fft_plan plan to destroy
     */
    void DestroyPlan(FFTplan& fft_plan);

    /** \brief Perform FFT with backend library.
     * \param[in,out] fft_plan plan for which the FFT is performed; the output array
     *                referenced by the plan (complex for R2C, real for C2R) is written
     */
    void Execute(FFTplan& fft_plan);
}

#endif // ANYFFT_H_