aboutsummaryrefslogtreecommitdiff
path: root/Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp
blob: 66c9c6c5004102855f44ad7261867479651bb8a9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
/* Copyright 2019-2020
 *
 * This file is part of WarpX.
 *
 * License: BSD-3-Clause-LBNL
 */

#include "AnyFFT.H"

#include "Utils/TextMsg.H"

namespace AnyFFT
{

    // Forward declaration: the definition is at the end of this file. It is declared
    // here because assert_rocfft_status (below) uses it to build its error message.
    std::string rocfftErrorToString (const rocfft_status err);

    namespace {
        void assert_rocfft_status (std::string const& name, rocfft_status status)
        {
            if (status != rocfft_status_success) {
                amrex::Abort(Utils::TextMsg::Err(
                    name + " failed! Error: " + rocfftErrorToString(status)));
            }
        }
    }

    /** \brief Create a rocFFT plan for one out-of-place real <-> complex transform
     *
     * @param[in] real_size     extent of the real array along each direction,
     *                          forwarded to rocFFT as the transform lengths
     * @param[in] real_array    pointer to the real data, stored in the returned plan
     * @param[in] complex_array pointer to the complex data, stored in the returned plan
     * @param[in] dir           direction::R2C selects the real forward transform,
     *                          any other value selects the real inverse transform
     * @param[in] dim           number of dimensions passed to rocFFT
     * @return an FFTplan holding the vendor plan together with the arguments above
     */
    FFTplan CreatePlan (const amrex::IntVect& real_size, amrex::Real * const real_array,
                        Complex * const complex_array, const direction dir, const int dim)
    {
        // Floating-point precision follows the AMReX build configuration
#ifdef AMREX_USE_FLOAT
        const auto precision = rocfft_precision_single;
#else
        const auto precision = rocfft_precision_double;
#endif

        const auto transform_type = (dir == direction::R2C)
            ? rocfft_transform_type_real_forward
            : rocfft_transform_type_real_inverse;

        const std::size_t lengths[] = {AMREX_D_DECL(std::size_t(real_size[0]),
                                                    std::size_t(real_size[1]),
                                                    std::size_t(real_size[2]))};

        FFTplan fft_plan;

        // Initialize fft_plan.m_plan with the vendor fft plan.
        const rocfft_status result = rocfft_plan_create(&(fft_plan.m_plan),
                                                        rocfft_placement_notinplace,
                                                        transform_type,
                                                        precision,
                                                        dim, lengths,
                                                        1, // number of transforms
                                                        nullptr);
        assert_rocfft_status("rocfft_plan_create", result);

        // Store meta-data in fft_plan
        fft_plan.m_dim = dim;
        fft_plan.m_dir = dir;
        fft_plan.m_complex_array = complex_array;
        fft_plan.m_real_array = real_array;

        return fft_plan;
    }

    /** \brief Release the vendor plan held by \p fft_plan
     *
     * The status returned by rocfft_plan_destroy is intentionally discarded.
     *
     * @param[in,out] fft_plan plan whose m_plan handle is passed to rocfft_plan_destroy
     */
    void DestroyPlan (FFTplan& fft_plan)
    {
        static_cast<void>(rocfft_plan_destroy(fft_plan.m_plan));
    }

    /** \brief Run the transform described by \p fft_plan
     *
     * A rocFFT execution-info object is created for this call, bound to
     * amrex::Gpu::gpuStream(), and destroyed before returning. If the plan reports
     * a non-zero work-buffer size, a scratch buffer of that size is taken from
     * amrex::The_Arena() and handed to rocFFT; it is released only after the
     * stream has been synchronized.
     *
     * The run is aborted (amrex::Abort) if any rocFFT call does not return
     * rocfft_status_success, or if fft_plan.m_dir is neither direction::R2C nor
     * direction::C2R.
     *
     * @param[in,out] fft_plan plan providing m_plan, m_dir and the real/complex
     *                         array pointers used as input and output
     */
    void Execute (FFTplan& fft_plan)
    {
        rocfft_execution_info execinfo = nullptr;
        rocfft_status result = rocfft_execution_info_create(&execinfo);
        assert_rocfft_status("rocfft_execution_info_create", result);

        std::size_t buffersize = 0;
        result = rocfft_plan_get_work_buffer_size(fft_plan.m_plan, &buffersize);
        assert_rocfft_status("rocfft_plan_get_work_buffer_size", result);

        // Only allocate and register a work buffer when the plan asks for one:
        // a reported size of zero means rocFFT needs no scratch space.
        void* buffer = nullptr;
        if (buffersize > 0) {
            buffer = amrex::The_Arena()->alloc(buffersize);
            result = rocfft_execution_info_set_work_buffer(execinfo, buffer, buffersize);
            assert_rocfft_status("rocfft_execution_info_set_work_buffer", result);
        }

        result = rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream());
        assert_rocfft_status("rocfft_execution_info_set_stream", result);

        // rocfft_execute takes arrays of buffer pointers, hence the address of the
        // stored data pointers is passed.
        void** const real_buffers = reinterpret_cast<void**>(&(fft_plan.m_real_array));
        void** const complex_buffers = reinterpret_cast<void**>(&(fft_plan.m_complex_array));

        if (fft_plan.m_dir == direction::R2C) {
            result = rocfft_execute(fft_plan.m_plan,
                                    real_buffers,    // in
                                    complex_buffers, // out
                                    execinfo);
        } else if (fft_plan.m_dir == direction::C2R) {
            result = rocfft_execute(fft_plan.m_plan,
                                    complex_buffers, // in
                                    real_buffers,    // out
                                    execinfo);
        } else {
            amrex::Abort(Utils::TextMsg::Err(
                "direction must be AnyFFT::direction::R2C or AnyFFT::direction::C2R"));
        }

        assert_rocfft_status("rocfft_execute", result);

        // The transform runs on the GPU stream: wait for it before freeing the
        // scratch buffer and the execution info it uses.
        amrex::Gpu::streamSynchronize();

        if (buffer != nullptr) {
            amrex::The_Arena()->free(buffer);
        }

        result = rocfft_execution_info_destroy(execinfo);
        assert_rocfft_status("rocfft_execution_info_destroy", result);
    }

    /** \brief Translate a rocfft_status code into a readable string
     *
     * @param[in] err status code returned by a rocFFT call
     * @return the name of the matching rocfft_status enumerator, or the numeric
     *         value followed by " (unknown error code)" when none matches
     */
    std::string rocfftErrorToString (const rocfft_status err)
    {
        switch (err) {
            case rocfft_status_success:
                return std::string("rocfft_status_success");
            case rocfft_status_failure:
                return std::string("rocfft_status_failure");
            case rocfft_status_invalid_arg_value:
                return std::string("rocfft_status_invalid_arg_value");
            case rocfft_status_invalid_dimensions:
                return std::string("rocfft_status_invalid_dimensions");
            case rocfft_status_invalid_array_type:
                return std::string("rocfft_status_invalid_array_type");
            case rocfft_status_invalid_strides:
                return std::string("rocfft_status_invalid_strides");
            case rocfft_status_invalid_distance:
                return std::string("rocfft_status_invalid_distance");
            case rocfft_status_invalid_offset:
                return std::string("rocfft_status_invalid_offset");
            default:
                // Any status not listed above is reported by its numeric value
                return std::to_string(err) + " (unknown error code)";
        }
    }
}