aboutsummaryrefslogtreecommitdiff
path: root/Tools/plot_parallel.py
diff options
context:
space:
mode:
Diffstat (limited to 'Tools/plot_parallel.py')
-rw-r--r--Tools/plot_parallel.py61
1 files changed, 32 insertions, 29 deletions
diff --git a/Tools/plot_parallel.py b/Tools/plot_parallel.py
index 56b243aca..d9efaa93a 100644
--- a/Tools/plot_parallel.py
+++ b/Tools/plot_parallel.py
@@ -1,8 +1,12 @@
-import os, glob, matplotlib, sys, argparse
-import yt ; yt.funcs.mylog.setLevel(50)
+import os
+import glob
+import matplotlib
+import sys
+import argparse
+import yt
+yt.funcs.mylog.setLevel(50)
import numpy as np
import matplotlib.pyplot as plt
-import scipy.constants as scc
'''
This script loops over all WarpX plotfiles in a directory and, for each
@@ -28,22 +32,20 @@ To get help, run
# Parse command line for options.
parser = argparse.ArgumentParser()
-parser.add_argument('--path', dest='path', default='.',
+parser.add_argument('--path', default='diags/plotfiles',
help='path to plotfiles. Plotfiles names must be plt?????')
-parser.add_argument('--plotlib', dest='plotlib', default='yt',
+parser.add_argument('--plotlib', default='yt',
choices=['yt','matplotlib'],
help='Plotting library to use')
-parser.add_argument('--field', dest='field', default='Ez',
+parser.add_argument('--field', default='Ez',
help='Which field to plot, e.g., Ez, By, jx or rho. The central slice in y is plotted')
-parser.add_argument('--pjump', dest='pjump', default=20,
+parser.add_argument('--pjump', default=20,
help='When plotlib=matplotlib, we plot every pjump particle')
-parser.add_argument('--use_vmax', dest='use_vmax', default=False,
- help='Whether to put bounds to field colormap')
-parser.add_argument('--vmax', dest='vmax', default=1.e12,
- help='If use_vmax=True, the colormab will have bounds [-vamx, vmax]')
-parser.add_argument('--slicewidth', dest='slicewidth', default=10.e-6,
+parser.add_argument('--vmax', type=float, default=None,
+ help='If specified, the colormap will have bounds [-vmax, vmax]')
+parser.add_argument('--slicewidth', default=10.e-6,
help='Only particles with -slicewidth/2<y<slicewidth/2 are plotted')
-parser.add_argument('--parallel', dest='parallel', action='store_true', default=False,
+parser.add_argument('--parallel', action='store_true', default=False,
help='whether or not to do the analysis in parallel (e.g., 1 plotfile per MPI rank)')
parser.add_argument('--species', dest='pslist', nargs='+', type=str, default=None,
help='Species to be plotted, e.g., " --species beam plasma_e ". By default, all species in the simulation are shown')
@@ -51,7 +53,10 @@ parser.add_argument('--plot_max_evolution', type=str, default=None,
help='Quantity to plot the max of across all data files')
args = parser.parse_args()
+plotlib = args.plotlib
plot_max_evolution = args.plot_max_evolution
+vmax = args.vmax
+
# Sanity check
if int(sys.version[0]) != 3:
print('WARNING: Parallel analysis was only tested with Python3')
@@ -68,14 +73,14 @@ def get_species(a_file_list):
if args.pslist is not None:
return args.pslist
# otherwise, loop over all plotfiles to get particle species list
- pslist = []
+ psset = set()
for filename in a_file_list:
ds = yt.load( filename )
- # get list of species in current plotfile
- pslist_plotfile = list( set( [x[0] for x in ds.field_list
- if x[1][:9]=='particle_'] ) )
- # append species in current plotfile to pslist, and uniquify
- pslist = list( set( pslist + pslist_plotfile ) )
+ for ps in ds.particle_types:
+ if ps == 'all':
+ continue
+ psset.add(ps)
+ pslist = list(psset)
pslist.sort()
return pslist
@@ -101,11 +106,13 @@ def plot_snapshot(filename):
plt.colorbar()
plt.xlim(ds.domain_left_edge[dim-1], ds.domain_right_edge[dim-1])
plt.ylim(ds.domain_left_edge[0], ds.domain_right_edge[0])
- if args.use_vmax:
- plt.clim(-args.vmax, args.vmax)
+ if vmax is not None:
+ plt.clim(-vmax, vmax)
if plotlib == 'yt':
# Directly plot with yt
sl = yt.SlicePlot(ds, yt_slicedir[dim], args.field, aspect=yt_aspect[dim])
+ if vmax is not None:
+ sl.set_zlim(-vmax, vmax)
# Plot particle quantities
for ispecies, pspecies in enumerate(pslist):
@@ -136,15 +143,12 @@ def plot_snapshot(filename):
if plotlib == 'matplotlib':
plt.xlabel('z (m)')
plt.ylabel('x (m)')
- plt.title(args.field + ' at iteration ' + str(iteration) +
- ', time = ' + str(ds.current_time))
- plt.savefig(args.path + '/plt_' + args.field + '_' + plotlib + '_' +
- str(iteration).zfill(5) + '.png', bbox_inches='tight', dpi=300)
+ plt.title('%s at iteration %d, time = %e s'%(args.field, iteration, ds.current_time))
+ plt.savefig('plt_%s_%s_%05d.png'%(args.field, plotlib, iteration), bbox_inches='tight', dpi=300)
plt.close()
if plotlib == 'yt':
sl.annotate_grids()
- sl.save(args.path + '/plt_' + args.field + '_' + plotlib + '_' +
- str(iteration).zfill(5) + '.png')
+ sl.save('plt_%s_%s_%05d.png'%(args.field, plotlib, iteration))
# Compute max of field a_field in plotfile filename
def get_field_max( filename, a_field ):
@@ -170,8 +174,7 @@ def plot_field_max(zwin_arr, maxF_arr):
### Analysis ###
# Get list of plotfiles
-plotlib = args.plotlib
-file_list = glob.glob(args.path + '/plt?????')
+file_list = glob.glob(os.path.join(args.path, 'plt?????'))
file_list.sort()
nfiles = len(file_list)