diff options
Diffstat (limited to 'Tools/plot_parallel.py')
-rw-r--r-- | Tools/plot_parallel.py | 252 |
1 files changed, 252 insertions, 0 deletions
diff --git a/Tools/plot_parallel.py b/Tools/plot_parallel.py new file mode 100644 index 000000000..2041e7935 --- /dev/null +++ b/Tools/plot_parallel.py @@ -0,0 +1,252 @@ +import os, glob, matplotlib, sys, argparse +import yt ; yt.funcs.mylog.setLevel(50) +import numpy as np +import matplotlib.pyplot as plt +import scipy.constants as scc + +''' +This script loops over all WarpX plotfiles in a directory and, for each +plotfile, saves an image showing the field and particles. + +Requires yt>3.5 and Python3 + +It can be run serial: +> python plot_parallel.py --path <path/to/plt/files> +or parallel +> mpirun -np 32 python plot_parallel.py --path <path/to/plt/files> --parallel +When running parallel, the plotfiles are distributed as evenly as possible +between MPI ranks. + +This script also proposes an option to plot one quantity over all timesteps. +The data of all plotfiles are gathered to rank 0, and the quantity evolution +is plotted and saved to file. For the illustration, this quantity is the max +of Ey. Use " --plot_Ey_max_evolution " to activate this option. + +To get help, run +> python plot_parallel --help +''' + +# Parse command line for options. +parser = argparse.ArgumentParser() +parser.add_argument('--path', dest='path', default='.', + help='path to plotfiles. Plotfiles names must be plt?????') +parser.add_argument('--plotlib', dest='plotlib', default='yt', + choices=['yt','matplotlib'], + help='Plotting library to use') +parser.add_argument('--field', dest='field', default='Ez', + help='Which field to plot, e.g., Ez, By, jx or rho. The central slice in y is plotted') +parser.add_argument('--pjump', dest='pjump', default=20, + help='When plotlib=matplotlib, we plot every pjump particle') +parser.add_argument('--use_vmax', dest='use_vmax', default=False, + help='Whether to put bounds to field colormap') +parser.add_argument('--vmax', dest='vmax', default=1.e12, + help='If use_vmax=True, the colormab will have bounds [-vamx, vmax]') +parser.add_argument('--slicewidth', dest='slicewidth', default=10.e-6, + help='Only particles with -slicewidth/2<y<slicewidth/2 are plotted') +parser.add_argument('--parallel', dest='parallel', action='store_true', default=False, + help='whether or not to do the analysis in parallel (e.g., 1 plotfile per MPI rank)') +parser.add_argument('--species', dest='pslist', nargs='+', type=str, default=None, + help='Species to be plotted, e.g., " --species beam plasma_e ". By default, all species in the simulation are shown') +parser.add_argument('--plot_Ey_max_evolution', dest='plot_Ey_max_evolution', action='store_true', default=False, + help='Whether to plot evolution of max(Ey) to illustrate how to plot one quantity across all plotfiles.') +args = parser.parse_args() + +# Sanity check +if int(sys.version[0]) != 3: + print('WARNING: Parallel analysis was only tested with Python3') + +matplotlib.rcParams.update({'font.size': 14}) +pscolor = ['r','g','b','k','m','c','y','w'] +pssize = 1. +yt_slicedir = {2:2, 3:1} +yt_aspect = {2:.05, 3:20} + +# Get list of particle species. +def get_species(a_file_list): + # if user-specified, just return the user list + if args.pslist is not None: + return args.pslist + # otherwise, loop over all plotfiles to get particle species list + pslist = [] + for filename in a_file_list: + ds = yt.load( filename ) + # get list of species in current plotfile + pslist_plotfile = list( set( [x[0] for x in ds.field_list + if x[1][:9]=='particle_'] ) ) + # append species in current plotfile to pslist, and uniquify + pslist = list( set( pslist + pslist_plotfile ) ) + pslist.sort() + return pslist + +def plot_snapshot(filename): + print( filename ) + # Load plotfile + ds = yt.load( filename ) + # Get number of dimension + dim = ds.dimensionality + + # Plot field colormap + if plotlib == 'matplotlib': + plt.figure(figsize=(12,7)) + # Read field quantities from yt dataset + all_data_level_0 = ds.covering_grid(level=0,left_edge=ds.domain_left_edge, dims=ds.domain_dimensions) + F = all_data_level_0['boxlib', args.field].v.squeeze() + if dim == 3: + F = F[:,int(F.shape[1]+.5)//2,:] + extent = [ds.domain_left_edge[dim-1], ds.domain_right_edge[dim-1], + ds.domain_left_edge[0], ds.domain_right_edge[0]] + # Plot field quantities with matplotlib + plt.imshow(F, aspect='auto', extent=extent, origin='lower') + plt.colorbar() + plt.xlim(ds.domain_left_edge[dim-1], ds.domain_right_edge[dim-1]) + plt.ylim(ds.domain_left_edge[0], ds.domain_right_edge[0]) + if args.use_vmax: + plt.clim(-args.vmax, args.vmax) + if plotlib == 'yt': + # Directly plot with yt + sl = yt.SlicePlot(ds, yt_slicedir[dim], args.field, aspect=yt_aspect[dim]) + + # Plot particle quantities + for ispecies, pspecies in enumerate(pslist): + if pspecies in [x[0] for x in ds.field_list]: + if plotlib == 'matplotlib': + # Read particle quantities from yt dataset + xp = ad[pspecies, 'particle_position_x'].v + if dim == 3: + yp = ad[pspecies, 'particle_position_y'].v + zp = ad[pspecies, 'particle_position_z'].v + select = yp**2<(args.slicewidth/2)**2 + xp = xp[select] ; yp = yp[select] ; zp = zp[select] + if dim == 2: + zp = ad[pspecies, 'particle_position_y'].v + # Select randomly one every pjump particles + random_indices = np.random.choice(xp.shape[0], int(xp.shape[0]/args.pjump)) + if dim == 2: + xp=xp[random_indices] ; zp=zp[random_indices] + if dim == 3: + xp=xp[random_indices] ; yp=yp[random_indices] ; zp=zp[random_indices] + plt.scatter(zp,xp,c=pscolor[ispecies],s=pssize, linewidth=pssize,marker=',') + if plotlib == 'yt': + # Directly plot particles with yt + sl.annotate_particles(width=(args.slicewidth, 'm'), p_size=pssize, + ptype=pspecies, col=pscolor[ispecies]) + # Add labels to plot and save + iteration = int(filename[-5:]) + if plotlib == 'matplotlib': + plt.xlabel('z (m)') + plt.ylabel('x (m)') + plt.title(args.field + ' at iteration ' + str(iteration) + + ', time = ' + str(ds.current_time)) + plt.savefig(args.path + '/plt_' + args.field + '_' + plotlib + '_' + + str(iteration).zfill(5) + '.png', bbox_inches='tight', dpi=300) + plt.close() + if plotlib == 'yt': + sl.annotate_grids() + sl.save(args.path + '/plt_' + args.field + '_' + plotlib + '_' + + str(iteration).zfill(5) + '.png') + +# Compute max of field a_field in plotfile filename +def get_field_max( filename, a_field ): + # Load plotfile + ds = yt.load( filename ) + # Get number of dimension + dim = ds.dimensionality + # Read field quantities from yt dataset + all_data_level_0 = ds.covering_grid(level=0,left_edge=ds.domain_left_edge, dims=ds.domain_dimensions) + F = all_data_level_0['boxlib', a_field].v.squeeze() + zwin = (ds.domain_left_edge[dim-1]+ds.domain_right_edge[dim-1])/2 + maxF = np.amax(F) + return zwin, maxF + +def plot_field_max(): + plt.figure() + plt.plot(zwin_arr, maxF_arr) + plt.xlabel('z (m)') + plt.ylabel('Field (S.I.)') + plt.title('Field max evolution') + plt.savefig('max_field_evolution.pdf', bbox_inches='tight') + +### Analysis ### + +# Get list of plotfiles +plotlib = args.plotlib +plot_Ey_max_evolution = args.plot_Ey_max_evolution +file_list = glob.glob(args.path + '/plt?????') +file_list.sort() +nfiles = len(file_list) +number_list = range(nfiles) + +if args.parallel: + ### Parallel analysis ### + # Split plotfile list among MPI ranks + from mpi4py import MPI + comm_world = MPI.COMM_WORLD + rank = comm_world.Get_rank() + size = comm_world.Get_size() + max_buf_size = nfiles//size+1 + if rank == 0: + print('Parallel analysis') + print('number of MPI ranks: ' + str(size)) + print('Number of plotfiles: %s' %nfiles) + # List of files processed by current MPI rank + my_list = file_list[ (rank*nfiles)//size : ((rank+1)*nfiles)//size ] + my_number_list = number_list[ (rank*nfiles)//size : ((rank+1)*nfiles)//size ] + my_nfiles = len( my_list ) + nfiles_list = None + nfiles_list = comm_world.gather(my_nfiles, root=0) + # Get list of particles to plot + pslist = get_species(file_list); + if rank == 0: + print('list of species: ', pslist) + if plot_Ey_max_evolution: + my_zwin = np.zeros( max_buf_size ) + my_maxF = np.zeros( max_buf_size ) + # Loop over files and + # - plot field snapshot + # - store window position and field max in arrays + for count, filename in enumerate(my_list): + plot_snapshot( filename ) + if plot_Ey_max_evolution: + my_zwin[count], my_maxF[count] = get_field_max( filename, 'Ey' ) + + if plot_Ey_max_evolution: + # Gather window position and field max arrays to rank 0 + zwin_rbuf = None + maxF_rbuf = None + if rank == 0: + zwin_rbuf = np.empty([size, max_buf_size], dtype='d') + maxF_rbuf = np.empty([size, max_buf_size], dtype='d') + comm_world.Gather(my_zwin, zwin_rbuf, root=0) + comm_world.Gather(my_maxF, maxF_rbuf, root=0) + # Re-format 2D arrays zwin_rbuf and maxF_rbuf on rank 0 + # into 1D arrays, and plot them + if rank == 0: + zwin_arr = np.zeros( nfiles ) + maxF_arr = np.zeros( nfiles ) + istart = 0 + for i in range(size): + nelem = nfiles_list[i] + zwin_arr[istart:istart+nelem] = zwin_rbuf[i,0:nelem] + maxF_arr[istart:istart+nelem] = maxF_rbuf[i,0:nelem] + istart += nelem + # Plot evolution of field max + plot_field_max() +else: + ### Serial analysis ### + print('Serial analysis') + print('Number of plotfiles: %s' %nfiles) + pslist = get_species(file_list); + print('list of species: ', pslist) + if plot_Ey_max_evolution: + zwin_arr = np.zeros( nfiles ) + maxF_arr = np.zeros( nfiles ) + # Loop over files and + # - plot field snapshot + # - store window position and field max in arrays + for count, filename in enumerate(file_list): + plot_snapshot( filename ) + if plot_Ey_max_evolution: + zwin_arr[count], maxF_arr[count] = get_field_max( filename, 'Ey' ) + # Plot evolution of field max + if plot_Ey_max_evolution: + plot_field_max() |