aboutsummaryrefslogtreecommitdiff
path: root/Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp
blob: 5f3131dec47048694dca3f6a0d7796e6a6b63207 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
/* Copyright 2019-2020
 *
 * This file is part of WarpX.
 *
 * License: BSD-3-Clause-LBNL
 */

#include "AnyFFT.H"

#include "Utils/TextMsg.H"

namespace AnyFFT
{

    std::string rocfftErrorToString (const rocfft_status err);

    namespace {
        /** \brief Abort with a descriptive message if a rocFFT call did not succeed.
         *
         * @param[in] name   name of the rocFFT function that produced `status`
         *                   (used only to build the error message)
         * @param[in] status status code returned by that function
         */
        void assert_rocfft_status (std::string const& name, rocfft_status status)
        {
            if (status != rocfft_status_success) {
                WARPX_ABORT_WITH_MESSAGE(
                    name + " failed! Error: " + rocfftErrorToString(status));
            }
        }
    }

    /** \brief Create a single out-of-place real<->complex rocFFT plan.
     *
     * The plan is a real forward transform for direction::R2C and a real inverse
     * transform otherwise. Its precision is single when AMREX_USE_FLOAT is defined
     * and double otherwise. A null plan description is passed, i.e. rocFFT's
     * default data layout is requested. Aborts if plan creation fails.
     *
     * @param[in] real_size     size of the real array; real_size[0] is passed to
     *                          rocFFT as the first transform length
     * @param[in] real_array    pointer to the real data (stored in the plan, used by Execute)
     * @param[in] complex_array pointer to the complex data (stored in the plan, used by Execute)
     * @param[in] dir           transform direction (R2C or C2R)
     * @param[in] dim           number of dimensions of the transform
     * @return the plan together with the stored pointers and meta-data
     */
    FFTplan CreatePlan (const amrex::IntVect& real_size, amrex::Real * const real_array,
                        Complex * const complex_array, const direction dir, const int dim)
    {
        FFTplan fft_plan;

        const std::size_t lengths[] = {AMREX_D_DECL(std::size_t(real_size[0]),
                                                    std::size_t(real_size[1]),
                                                    std::size_t(real_size[2]))};

        // Initialize fft_plan.m_plan with the vendor fft plan.
        rocfft_status result = rocfft_plan_create(&(fft_plan.m_plan),
                                                  rocfft_placement_notinplace,
                                                  (dir == direction::R2C)
                                                      ? rocfft_transform_type_real_forward
                                                      : rocfft_transform_type_real_inverse,
#ifdef AMREX_USE_FLOAT
                                                  rocfft_precision_single,
#else
                                                  rocfft_precision_double,
#endif
                                                  dim, lengths,
                                                  1, // number of transforms
                                                  nullptr);
        assert_rocfft_status("rocfft_plan_create", result);

        // Store meta-data in fft_plan
        fft_plan.m_real_array = real_array;
        fft_plan.m_complex_array = complex_array;
        fft_plan.m_dir = dir;
        fft_plan.m_dim = dim;

        return fft_plan;
    }

    /** \brief Destroy the rocFFT plan held by `fft_plan`.
     *
     * NOTE(review): the status returned by rocfft_plan_destroy is not checked here.
     */
    void DestroyPlan (FFTplan& fft_plan)
    {
        rocfft_plan_destroy( fft_plan.m_plan );
    }

    /** \brief Execute the transform described by `fft_plan` on the GPU stream
     * returned by amrex::Gpu::gpuStream(), then synchronize that stream.
     *
     * For R2C the real array is the input and the complex array the output;
     * for C2R it is the reverse. Any other direction aborts.
     * If rocFFT reports that the plan needs a work buffer, one is taken from
     * amrex::The_Arena() for the duration of the call. Any rocFFT failure aborts.
     *
     * @param[in,out] fft_plan plan created by CreatePlan
     */
    void Execute (FFTplan& fft_plan)
    {
        rocfft_execution_info execinfo = nullptr;
        rocfft_status result = rocfft_execution_info_create(&execinfo);
        assert_rocfft_status("rocfft_execution_info_create", result);

        std::size_t buffersize = 0;
        result = rocfft_plan_get_work_buffer_size(fft_plan.m_plan, &buffersize);
        assert_rocfft_status("rocfft_plan_get_work_buffer_size", result);

        // Only allocate and attach a work buffer when the plan actually needs one.
        void* buffer = nullptr;
        if (buffersize > 0) {
            buffer = amrex::The_Arena()->alloc(buffersize);
            result = rocfft_execution_info_set_work_buffer(execinfo, buffer, buffersize);
            assert_rocfft_status("rocfft_execution_info_set_work_buffer", result);
        }

        result = rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream());
        assert_rocfft_status("rocfft_execution_info_set_stream", result);

        // rocfft_execute expects arrays of `void*` buffer pointers. Build real
        // `void*` objects rather than reinterpreting the address of the typed
        // pointer members as `void**`, which would be a strict-aliasing violation.
        void* real_buffers[] = {fft_plan.m_real_array};
        void* complex_buffers[] = {fft_plan.m_complex_array};

        if (fft_plan.m_dir == direction::R2C) {
            result = rocfft_execute(fft_plan.m_plan,
                                    real_buffers, // in
                                    complex_buffers, // out
                                    execinfo);
        } else if (fft_plan.m_dir == direction::C2R) {
            result = rocfft_execute(fft_plan.m_plan,
                                    complex_buffers, // in
                                    real_buffers, // out
                                    execinfo);
        } else {
            WARPX_ABORT_WITH_MESSAGE(
                "direction must be AnyFFT::direction::R2C or AnyFFT::direction::C2R");
        }

        assert_rocfft_status("rocfft_execute", result);

        // Wait for the transform to finish before the work buffer is released.
        amrex::Gpu::streamSynchronize();

        if (buffer != nullptr) {
            amrex::The_Arena()->free(buffer);
        }

        result = rocfft_execution_info_destroy(execinfo);
        assert_rocfft_status("rocfft_execution_info_destroy", result);
    }

    /** \brief This method converts a rocfft_status
     * into the corresponding string
     *
     * @param[in] err a rocfft_status
     * @return an std::string with the enumerator name, or the numeric
     *         code followed by " (unknown error code)" for other values
     */
    std::string rocfftErrorToString (const rocfft_status err)
    {
        switch (err) {
            case rocfft_status_success:
                return std::string("rocfft_status_success");
            case rocfft_status_failure:
                return std::string("rocfft_status_failure");
            case rocfft_status_invalid_arg_value:
                return std::string("rocfft_status_invalid_arg_value");
            case rocfft_status_invalid_dimensions:
                return std::string("rocfft_status_invalid_dimensions");
            case rocfft_status_invalid_array_type:
                return std::string("rocfft_status_invalid_array_type");
            case rocfft_status_invalid_strides:
                return std::string("rocfft_status_invalid_strides");
            case rocfft_status_invalid_distance:
                return std::string("rocfft_status_invalid_distance");
            case rocfft_status_invalid_offset:
                return std::string("rocfft_status_invalid_offset");
            default:
                return std::to_string(static_cast<int>(err)) + " (unknown error code)";
        }
    }
}
'>-7/+11 2023-10-16Fix formattingGravatar Ashcon Partovi 2-5/+4 2023-10-16Fix `Response.statusText` (#6151)Gravatar Chris Toshok 10-238/+269 2023-10-16fix-subprocess-argument-missing (#6407)Gravatar Nicolae-Rares Ailincai 4-2/+40 2023-10-16Add type parameter to `expect` (#6128)Gravatar Voldemat 1-3/+3 2023-10-16fix(node:worker_threads): ensure threadId property is exposed on worker_threa...Gravatar Jérôme Benoit 6-15/+75 2023-10-16Fix use before define bug in sqliteGravatar Ashcon Partovi 2-5/+5 2023-10-16fix(jest): fix toStrictEqual on same URLs (#6528)Gravatar João Alisson 2-13/+16 2023-10-16Fix `toHaveBeenCalled` having wrong error signatureGravatar Ashcon Partovi 1-2/+2 2023-10-16Fix formattingGravatar Ashcon Partovi 1-2/+1 2023-10-16Add `reusePort` to `Bun.serve` typesGravatar Ashcon Partovi 1-0/+9 2023-10-16Fix `request.url` having incorrect portGravatar Ashcon Partovi 4-1/+92 2023-10-16Remove uWebSockets header from Bun.serve responsesGravatar Ashcon Partovi 1-6/+6 2023-10-16Rename some testsGravatar Ashcon Partovi 3-0/+0 2023-10-16Fix #6467Gravatar Ashcon Partovi 2-3/+10 2023-10-16Update InternalModuleRegistryConstants.hGravatar Dylan Conway 1-3/+3 2023-10-16Development -> Contributing (#6538)Gravatar Colin McDonnell 2-1/+1 2023-10-14fix(net/tls) fix pg hang on end + hanging on query (#6487)Gravatar Ciro Spaciari 3-8/+36 2023-10-13fix installing dependencies that match workspace versions (#6494)Gravatar Dylan Conway 4-2/+64 2023-10-13fix lockfile struct padding (#6495)Gravatar Dylan Conway 3-3/+18