aboutsummaryrefslogtreecommitdiff
path: root/Source/AcceleratorLattice/LatticeElementFinder.H
blob: dd9358b19afb0a66626fcfd36ac78cae84a43159 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
/* Copyright 2022 David Grote
 *
 * This file is part of WarpX.
 *
 * License: BSD-3-Clause-LBNL
 */
#ifndef WARPX_ACCELERATORLATTICE_LATTICEELEMENTS_LATTICEELEMENTFINDER_H_
#define WARPX_ACCELERATORLATTICE_LATTICEELEMENTS_LATTICEELEMENTFINDER_H_

#include "LatticeElements/HardEdgedQuadrupole.H"
#include "LatticeElements/HardEdgedPlasmaLens.H"
#include "Particles/Pusher/GetAndSetPosition.H"
#include "Particles/WarpXParticleContainer.H"

#include <AMReX_REAL.H>
#include <AMReX_GpuContainers.H>

#include <cmath>
#include <limits>

class AcceleratorLattice;
struct LatticeElementFinderDevice;

// Instances of the LatticeElementFinder class are saved in the AcceleratorLattice class
// as the objects in a LayoutData (one finder per grid).
// The LatticeElementFinder handles the lookup needed to find the lattice elements at
// particle locations: for each element type it keeps a table, indexed by the z cell of
// the grid, holding the index of the element assigned to that cell.

struct LatticeElementFinder
{

    /**
     * \brief Initialize the element finder at the level and grid
     *
     * @param[in] lev the refinement level
     * @param[in] a_mfi specifies the grid where the finder is defined
     * @param[in] accelerator_lattice a reference to the accelerator lattice at the refinement level
     */
    void InitElementFinder (int const lev, amrex::MFIter const& a_mfi,
                            AcceleratorLattice const& accelerator_lattice);

    /**
     * \brief Allocate the index lookup tables for each element type
     *
     * @param[in] accelerator_lattice a reference to the accelerator lattice at the refinement level
     */
    void AllocateIndices (AcceleratorLattice const& accelerator_lattice);

    /**
     * \brief Update the index lookup tables for each element type, filling in the values
     *
     * @param[in] lev the refinement level
     * @param[in] a_mfi specifies the grid where the finder is defined
     * @param[in] accelerator_lattice a reference to the accelerator lattice at the refinement level
     */
    void UpdateIndices (int const lev, amrex::MFIter const& a_mfi,
                        AcceleratorLattice const& accelerator_lattice);

    /* Define the location and size of the index lookup table */
    /* Use the type Real to be consistent with the way the main grid is defined */
    int m_nz = 0;
    amrex::Real m_zmin = 0;
    amrex::Real m_dz = 0;

    /* Parameters needed for the Lorentz transforms into and out of the boosted frame */
    /* The time for m_time is consistent with the main time variable */
    /* The defaults (gamma = 1, uz = 0) describe the lab frame, i.e. no boost */
    amrex::ParticleReal m_gamma_boost = 1;
    amrex::ParticleReal m_uz_boost = 0;
    amrex::Real m_time = 0;

    /**
     * \brief Get the device level instance associated with this instance
     *
     * @param[in] a_pti specifies the grid where the finder is defined
     * @param[in] a_offset particle index offset needed to access particle info
     * @param[in] accelerator_lattice a reference to the accelerator lattice at the refinement level
     * @return the trivially copyable finder used inside device kernels
     */
    LatticeElementFinderDevice GetFinderDeviceInstance (WarpXParIter const& a_pti, int const a_offset,
                                                        AcceleratorLattice const& accelerator_lattice);

    /* The index lookup tables for each lattice element type.
     * An entry of -1 means that no element is assigned to that z cell. */
    amrex::Gpu::DeviceVector<int> d_quad_indices;
    amrex::Gpu::DeviceVector<int> d_plasmalens_indices;

    /**
     * \brief Fill in the index lookup tables
     * This loops over the grid (in z) and, for each grid node, finds the lattice element
     * whose region contains it. Element ie owns the region between the midpoint of the
     * gap to element ie-1 and the midpoint of the gap to element ie+1; the first and last
     * elements extend to -infinity and +infinity respectively.
     * When the boosted frame is used, the grid node is transformed to the lab frame
     * (where zs and ze are given) before the search.
     *
     * @param[in] zs list of the starts of the lattice elements
     * @param[in] ze list of the ends of the lattice elements
     * @param[in,out] indices the index lookup table to be filled in (m_nz entries);
     *                entries with no matching element are set to -1
     */
    void setup_lattice_indices (amrex::Gpu::DeviceVector<amrex::ParticleReal> const & zs,
                                amrex::Gpu::DeviceVector<amrex::ParticleReal> const & ze,
                                amrex::Gpu::DeviceVector<int> & indices)
    {

        using namespace amrex::literals;

        const auto nelements = static_cast<int>(zs.size());

        // With no elements there is nothing to look up; the kernel would not assign
        // anything. Return early so the table (which may not have been sized in that
        // case) is never touched.
        if (nelements == 0) { return; }

        amrex::ParticleReal const * zs_arr = zs.data();
        amrex::ParticleReal const * ze_arr = ze.data();
        int * indices_arr = indices.data();

        amrex::Real const zmin = m_zmin;
        amrex::Real const dz = m_dz;

        amrex::ParticleReal const gamma_boost = m_gamma_boost;
        amrex::ParticleReal const uz_boost = m_uz_boost;
        amrex::Real const time = m_time;

        amrex::ParallelFor( m_nz,
            [=] AMREX_GPU_DEVICE (int iz) {

                // Get the location of the grid node
                amrex::Real z_node = zmin + iz*dz;

                if (gamma_boost > 1._prt) {
                    // Transform to lab frame
                    z_node = gamma_boost*z_node + uz_boost*time;
                }

                // Start from the "no element" value so that the entry is always defined,
                // even if no region below matches (e.g. z_node is NaN or out of range).
                // The device-side lookup only uses entries that are > -1.
                indices_arr[iz] = -1;

                // Find the index to the element that is closest to the grid cell.
                // For now, this assumes that there is no overlap among elements of the same type.
                for (int ie = 0 ; ie < nelements ; ie++) {
                    // Find the mid points between element ie and the ones before and after it.
                    // The first and last element need special handling.
                    const amrex::ParticleReal zcenter_left = (ie == 0)?
                        (std::numeric_limits<amrex::ParticleReal>::lowest()) : (0.5_prt*(ze_arr[ie-1] + zs_arr[ie]));
                    const amrex::ParticleReal zcenter_right = (ie < nelements - 1)?
                        (0.5_prt*(ze_arr[ie] + zs_arr[ie+1])) : (std::numeric_limits<amrex::ParticleReal>::max());
                    if (zcenter_left <= z_node && z_node < zcenter_right) {
                        indices_arr[iz] = ie;
                    }

                }
            }
        );
    }

};

/**
 * \brief The lattice element finder class that can be trivially copied to the device.
 * This only has simple data and pointers; every member has a default value so that a
 * default-constructed instance holds no indeterminate values.
 */
struct LatticeElementFinderDevice
{

    /**
     * \brief Initialize the data needed to do the lookups
     *
     * @param[in] a_pti specifies the grid where the finder is defined
     * @param[in] a_offset particle index offset needed to access particle info
     * @param[in] accelerator_lattice a reference to the accelerator lattice at the refinement level
     * @param[in] h_finder The host level instance of the element finder that this is associated with
     */
    void
    InitLatticeElementFinderDevice (WarpXParIter const& a_pti, int const a_offset,
                                    AcceleratorLattice const& accelerator_lattice,
                                    LatticeElementFinder const & h_finder);

    /* Location (lower z) and cell size of the index lookup table */
    amrex::Real m_zmin = 0;
    amrex::Real m_dz = 0;
    /* Time step, used to estimate the particle z position at the end of the step */
    amrex::Real m_dt = 0;

    /* Parameters needed for the Lorentz transforms into and out of the boosted frame */
    /* The defaults (gamma = 1, uz = 0) describe the lab frame, i.e. no boost */
    amrex::ParticleReal m_gamma_boost = 1;
    amrex::ParticleReal m_uz_boost = 0;
    amrex::Real m_time = 0;

    GetParticlePosition m_get_position;
    const amrex::ParticleReal* AMREX_RESTRICT m_ux = nullptr;
    const amrex::ParticleReal* AMREX_RESTRICT m_uy = nullptr;
    const amrex::ParticleReal* AMREX_RESTRICT m_uz = nullptr;

    /* Device level instances for each lattice element type */
    HardEdgedQuadrupoleDevice d_quad;
    HardEdgedPlasmaLensDevice d_plasmalens;

    /* Device level index lookup tables for each element type (-1 means no element) */
    int const* d_quad_indices_arr = nullptr;
    int const* d_plasmalens_indices_arr = nullptr;

    /**
     * \brief Gather the field for the particle from the lattice elements
     *
     * The particle position is located in the index lookup tables (which are in the
     * simulation frame), the fields of the assigned elements are evaluated in the lab
     * frame, summed, transformed back to the boosted frame if needed, and added to the
     * output fields.
     *
     * @param[in] i the particle index
     * @param[in,out] field_Ex,field_Ey,field_Ez,field_Bx,field_By,field_Bz the gathered E and B
     *                fields; the lattice contribution is added to the incoming values
     */
    AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    void operator () (const long i,
                      amrex::ParticleReal& field_Ex,
                      amrex::ParticleReal& field_Ey,
                      amrex::ParticleReal& field_Ez,
                      amrex::ParticleReal& field_Bx,
                      amrex::ParticleReal& field_By,
                      amrex::ParticleReal& field_Bz) const noexcept
    {

        using namespace amrex::literals;

        amrex::ParticleReal x, y, z;
        m_get_position(i, x, y, z);

        // Find location of particle in the indices grid
        // (which is in the boosted frame)
        const int iz = static_cast<int>((z - m_zmin)/m_dz);

        // Particle Lorentz factor from the momenta per unit mass (u = gamma*v)
        constexpr amrex::ParticleReal inv_c2 = 1._prt/(PhysConst::c*PhysConst::c);
        amrex::ParticleReal const gamma = std::sqrt(1._prt + (m_ux[i]*m_ux[i] + m_uy[i]*m_uy[i] + m_uz[i]*m_uz[i])*inv_c2);
        amrex::ParticleReal const vzp = m_uz[i]/gamma;

        // Estimated z position at the end of the time step
        amrex::ParticleReal zpvdt = z + vzp*m_dt;

        // The position passed to the get_field methods needs to be in the lab frame.
        if (m_gamma_boost > 1._prt) {
            z = m_gamma_boost*z + m_uz_boost*m_time;
            zpvdt = m_gamma_boost*zpvdt + m_uz_boost*(m_time + m_dt);
        }

        amrex::ParticleReal Ex_sum = 0._prt;
        amrex::ParticleReal Ey_sum = 0._prt;
        const amrex::ParticleReal Ez_sum = 0._prt;
        amrex::ParticleReal Bx_sum = 0._prt;
        amrex::ParticleReal By_sum = 0._prt;
        const amrex::ParticleReal Bz_sum = 0._prt;

        // The element count is checked first so that the table is only read when it exists.
        if (d_quad.nelements > 0 && d_quad_indices_arr[iz] > -1) {
            const auto ielement = d_quad_indices_arr[iz];
            amrex::ParticleReal Ex, Ey, Bx, By;
            d_quad.get_field(ielement, x, y, z, zpvdt, Ex, Ey, Bx, By);
            Ex_sum += Ex;
            Ey_sum += Ey;
            Bx_sum += Bx;
            By_sum += By;
        }

        if (d_plasmalens.nelements > 0 && d_plasmalens_indices_arr[iz] > -1) {
            const auto ielement = d_plasmalens_indices_arr[iz];
            amrex::ParticleReal Ex, Ey, Bx, By;
            d_plasmalens.get_field(ielement, x, y, z, zpvdt, Ex, Ey, Bx, By);
            Ex_sum += Ex;
            Ey_sum += Ey;
            Bx_sum += Bx;
            By_sum += By;
        }

        if (m_gamma_boost > 1._prt) {
            // The fields returned from get_field are in the lab frame.
            // Transform the fields to the boosted frame (m_uz_boost = gamma*v).
            const amrex::ParticleReal Ex_boost = m_gamma_boost*Ex_sum - m_uz_boost*By_sum;
            const amrex::ParticleReal Ey_boost = m_gamma_boost*Ey_sum + m_uz_boost*Bx_sum;
            const amrex::ParticleReal Bx_boost = m_gamma_boost*Bx_sum + m_uz_boost*Ey_sum*inv_c2;
            const amrex::ParticleReal By_boost = m_gamma_boost*By_sum - m_uz_boost*Ex_sum*inv_c2;
            Ex_sum = Ex_boost;
            Ey_sum = Ey_boost;
            Bx_sum = Bx_boost;
            By_sum = By_boost;
        }

        field_Ex += Ex_sum;
        field_Ey += Ey_sum;
        field_Ez += Ez_sum;
        field_Bx += Bx_sum;
        field_By += By_sum;
        field_Bz += Bz_sum;

    }

};

#endif // WARPX_ACCELERATORLATTICE_LATTICEELEMENTS_LATTICEELEMENTFINDER_H_