aboutsummaryrefslogtreecommitdiff
path: root/Tools/plot_parallel.py
blob: ec44b5af80f7f48697e91560a7e58217f5c1f31d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
import os
import glob
import matplotlib
import sys
import argparse
import yt
yt.funcs.mylog.setLevel(50)
import numpy as np
import matplotlib.pyplot as plt

'''
This script loops over all WarpX plotfiles in a directory and, for each
plotfile, saves an image showing the field and particles.

Requires yt>3.5 and Python3

It can be run in serial:

> python plot_parallel.py --path <path/to/plt/files> --serial

or in parallel:

> mpirun -np 32 python plot_parallel.py --path <path/to/plt/files>

When running parallel, the plotfiles are distributed as evenly as possible
between MPI ranks.

This script also provides an option to plot one quantity over all timesteps.
The data from all plotfiles are gathered on rank 0, and the evolution of the
quantity is plotted and saved to file. For illustration, this quantity is the
max of Ey. Use " --plot_evolution Ey" to activate this option.

To get help, run
> python plot_parallel.py --help
'''

# Parse command line for options.
# NOTE: numeric options must declare type=; otherwise a value given on the
# command line arrives as a str and breaks the arithmetic in plot_snapshot
# (e.g. xp.shape[0]/args.pjump, args.slicewidth/2).
parser = argparse.ArgumentParser()
parser.add_argument('--path', default=None,
                    help='path to plotfiles, defaults to diags/plotfiles. Plotfiles names must be plt?????')
parser.add_argument('--image_dir', default=None,
                    help='path where images are placed, defaults to diags/plotfiles or path if specified.')
parser.add_argument('--plotlib', default='yt',
                    choices=['yt','matplotlib'],
                    help='Plotting library to use')
parser.add_argument('--field', default='Ez',
                    help='Which field to plot, e.g., Ez, By, jx or rho. The central slice in y is plotted')
parser.add_argument('--pjump', type=int, default=20,
                    help='When plotlib=matplotlib, we plot every pjump particle')
parser.add_argument('--vmax', type=float, default=None,
                    help='If specified, the colormap will have bounds [-vmax, vmax]')
parser.add_argument('--slicewidth', type=float, default=10.e-6,
                    help='Only particles with -slicewidth/2<y<slicewidth/2 are plotted')
parser.add_argument('--serial', action='store_true', default=False,
                    help='Specifies running in serial, avoiding the import of MPI')
parser.add_argument('--species', dest='pslist', nargs='+', type=str, default=None,
                    help='Species to be plotted, e.g., " --species beam plasma_e ". By default, all species in the simulation are shown')
parser.add_argument('--plot_evolution', type=str, default=None,
                    help='Quantity to plot the evolution of across all data files')
args = parser.parse_args()

path = args.path
image_dir = args.image_dir
plotlib = args.plotlib
plot_evolution = args.plot_evolution
vmax = args.vmax

# Default locations: plotfiles are read from diags/plotfiles, and images are
# written next to the plotfiles unless --image_dir is given.
if path is None:
    path = 'diags/plotfiles'
if image_dir is None:
    image_dir = path

# Sanity check
if sys.version_info.major != 3:
    print('WARNING: Parallel analysis was only tested with Python3')

matplotlib.rcParams.update({'font.size': 14})
# Particle marker colors, indexed by a species' position in pslist.
pscolor = ['r','g','b','k','m','c','y','w']
# Marker size (and line width) for particle scatter / yt particle annotation.
pssize = 1.
# Slice axis passed to yt.SlicePlot, keyed by ds.dimensionality.
yt_slicedir = {2:2, 3:1}
# Aspect ratio passed to yt.SlicePlot, keyed by ds.dimensionality.
yt_aspect = {2:.05, 3:20}

# Get list of particle species.
def get_species(a_file_list):
    """Return the list of particle species to plot.

    If the user passed ``--species``, that list is returned as-is.
    Otherwise each plotfile in ``a_file_list`` is loaded with yt, and the
    union of their ``particle_types`` (skipping the ``'all'`` type) is
    returned sorted alphabetically.
    """
    if args.pslist is not None:
        return args.pslist
    species = set()
    for fname in a_file_list:
        species.update(ptype for ptype in yt.load( fname ).particle_types
                       if ptype != 'all')
    return sorted(species)

def plot_snapshot(filename):
    """Save one image of a plotfile: a field map overlaid with particles.

    The field ``args.field`` is drawn with the library selected by the global
    ``plotlib``:
      * 'matplotlib': the level-0 covering grid (central y slice in 3D) is
        shown with imshow, z horizontal and x vertical;
      * 'yt': yt.SlicePlot along ``yt_slicedir[dim]``.
    If ``vmax`` is set, the colormap is clamped to [-vmax, vmax].

    Each species of the global ``pslist`` that appears in this plotfile's
    field list is overlaid. With matplotlib, in 3D only particles with
    |y| < slicewidth/2 are kept, then a random subset of size
    n_particles // pjump (without replacement) is scattered. With yt,
    annotate_particles is used with the same slice width.

    The image is written to ``image_dir`` as
    ``plt_<field>_<plotlib>_<iteration>.png``; the iteration is parsed from
    the last 5 characters of ``filename`` (plotfiles are named plt?????).
    """
    print( filename )
    # Load plotfile
    ds = yt.load( filename )
    # Get number of dimension
    dim = ds.dimensionality
    # Coerce explicitly: if these options were parsed without type=, values
    # given on the command line are strings. Guard pjump against 0.
    pjump = max(1, int(args.pjump))
    slicewidth = float(args.slicewidth)

    # Plot field colormap
    if plotlib == 'matplotlib':
        plt.figure(figsize=(12,7))
        # Read field quantities from yt dataset
        all_data_level_0 = ds.covering_grid(level=0,left_edge=ds.domain_left_edge, dims=ds.domain_dimensions)
        F = all_data_level_0['boxlib', args.field].v.squeeze()
        if dim == 3:
            # Central slice along axis 1 (y).
            F = F[:, F.shape[1]//2, :]
        extent = [ds.domain_left_edge[dim-1], ds.domain_right_edge[dim-1],
                  ds.domain_left_edge[0], ds.domain_right_edge[0]]
        # Plot field quantities with matplotlib
        plt.imshow(F, aspect='auto', extent=extent, origin='lower')
        plt.colorbar()
        plt.xlim(ds.domain_left_edge[dim-1], ds.domain_right_edge[dim-1])
        plt.ylim(ds.domain_left_edge[0], ds.domain_right_edge[0])
        if vmax is not None:
            plt.clim(-vmax, vmax)
    if plotlib == 'yt':
        # Directly plot with yt
        sl = yt.SlicePlot(ds, yt_slicedir[dim], args.field, aspect=yt_aspect[dim])
        if vmax is not None:
            sl.set_zlim(-vmax, vmax)

    # Plot particle quantities.
    # Hoisted out of the species loop: the set of field types present in this
    # plotfile, and the all-data container (only needed for matplotlib).
    field_types = {field[0] for field in ds.field_list}
    ad = ds.all_data() if plotlib == 'matplotlib' else None
    for ispecies, pspecies in enumerate(pslist):
        if pspecies not in field_types:
            continue
        # Cycle through the palette so more species than colors do not
        # raise IndexError.
        color = pscolor[ispecies % len(pscolor)]
        if plotlib == 'matplotlib':
            # Read particle quantities from yt dataset
            xp = ad[pspecies, 'particle_position_x'].v
            if dim == 3:
                yp = ad[pspecies, 'particle_position_y'].v
                zp = ad[pspecies, 'particle_position_z'].v
                # Keep only particles within the slice |y| < slicewidth/2.
                select = yp**2 < (slicewidth/2)**2
                xp = xp[select] ; zp = zp[select]
            if dim == 2:
                # In 2D the second coordinate is stored as particle_position_y.
                zp = ad[pspecies, 'particle_position_y'].v
            # Randomly select about one every pjump particles. Sample without
            # replacement so no particle is drawn twice.
            random_indices = np.random.choice(xp.shape[0], xp.shape[0]//pjump, replace=False)
            xp = xp[random_indices] ; zp = zp[random_indices]
            plt.scatter(zp,xp,c=color,s=pssize, linewidth=pssize,marker=',')
        if plotlib == 'yt':
            # Directly plot particles with yt
            sl.annotate_particles(width=(slicewidth, 'm'), p_size=pssize,
                                  ptype=pspecies, col=color)
    # Add labels to plot and save
    iteration = int(filename[-5:])
    image_file_name = os.path.join(image_dir, 'plt_%s_%s_%05d.png'%(args.field, plotlib, iteration))
    if plotlib == 'matplotlib':
        plt.xlabel('z (m)')
        plt.ylabel('x (m)')
        plt.title('%s at iteration %d, time = %e s'%(args.field, iteration, ds.current_time))
        plt.savefig(image_file_name, bbox_inches='tight', dpi=300)
        plt.close()
    if plotlib == 'yt':
        sl.annotate_grids()
        sl.save(image_file_name)

# Compute the evolved quantity from plotfile filename
def get_evolution_quantity(filename, quantity_name):
    """Return ``(zwin, quantity)`` for one plotfile.

    zwin     -- midpoint of the domain along its last axis (the axis shown as
                'z (m)' in the plots), used as the window position.
    quantity -- maximum of field ``quantity_name`` over the level-0
                covering grid of the plotfile.
    """
    ds = yt.load( filename )
    last_axis = ds.dimensionality - 1
    level0 = ds.covering_grid(level=0, left_edge=ds.domain_left_edge,
                              dims=ds.domain_dimensions)
    field_values = level0['boxlib', quantity_name].v.squeeze()
    window_center = (ds.domain_left_edge[last_axis] + ds.domain_right_edge[last_axis])/2
    return window_center, np.amax(field_values)

def plot_evolved_quantity(zwin_arr, maxF_arr):
    """Plot the evolution of the ``--plot_evolution`` quantity and save it.

    zwin_arr -- window positions, one per plotfile (x axis, labelled z).
    maxF_arr -- value of the quantity for each plotfile (y axis).

    The figure is saved as ``<image_dir>/max_<plot_evolution>_evolution.pdf``
    and then closed, so it does not stay alive in pyplot's figure registry.
    """
    fig = plt.figure()
    plt.plot(zwin_arr, maxF_arr)
    plt.xlabel('z (m)')
    plt.ylabel('%s (S.I.)'%plot_evolution)
    plt.title('Field max evolution')
    plt.savefig(os.path.join(image_dir, 'max_%s_evolution.pdf'%plot_evolution), bbox_inches='tight')
    plt.close(fig)

### Analysis ###

# Get list of plotfiles
file_list = sorted(glob.glob(os.path.join(path, 'plt?????')))
nfiles = len(file_list)

# Set up MPI unless --serial was given. If mpi4py is unavailable we fall back
# to serial; warn, because every launched process would then handle every
# plotfile.
rank = 0
size = 1
comm_world = None
if not args.serial:
    try:
        from mpi4py import MPI
    except ImportError:
        print('WARNING: could not import mpi4py, running in serial')
    else:
        comm_world = MPI.COMM_WORLD
        rank = comm_world.Get_rank()
        size = comm_world.Get_size()

# Get list of particle species to plot. get_species may load every plotfile,
# so only rank 0 does it and the result is broadcast to the other ranks.
pslist = get_species(file_list) if rank == 0 else None
if size > 1:
    pslist = comm_world.bcast(pslist, root=0)

if rank == 0:
    print('number of MPI ranks: %d'%size)
    print('Number of plotfiles: %s'%nfiles)
    print('list of species: ', pslist)
    if nfiles == 0:
        print('WARNING: no plotfiles matching plt????? found in %s'%path)

if plot_evolution is not None:
    # Fill with a value less than any possible value, so that a MAX reduction
    # across ranks keeps only the entries each rank actually computed.
    zwin = np.full(nfiles, np.finfo(float).min)
    quantity = np.full(nfiles, np.finfo(float).min)

# Loop over files, splitting plotfile list among MPI ranks (round-robin)
# - plot field snapshot
# - store window position and field max in arrays
for count, filename in enumerate(file_list):
    if count%size != rank:
        continue

    plot_snapshot( filename )
    if plot_evolution is not None:
        zwin[count], quantity[count] = get_evolution_quantity( filename, plot_evolution )

if plot_evolution is not None:
    if size > 1:
        # size > 1 implies the mpi4py import above succeeded, so MPI is bound.
        global_zwin = np.empty_like(zwin)
        global_quantity = np.empty_like(quantity)
        comm_world.Reduce(zwin, global_zwin, op=MPI.MAX, root=0)
        comm_world.Reduce(quantity, global_quantity, op=MPI.MAX, root=0)
        zwin = global_zwin
        quantity = global_quantity
    if rank == 0 and nfiles > 0:
        plot_evolved_quantity(zwin, quantity)