#Import standard modules:
from pathlib import Path
import numpy as np
import xarray as xr
import warnings # use to warn user about missing files.
import matplotlib.pyplot as plt
import pandas as pd
#Format warning messages:
warnings.formatwarning = my_formatwarning
[docs]
def tem(adf):
"""
Plot the contents of the TEM dignostic ouput of 2-D latitude vs vertical pressure maps.
Steps:
- loop through TEM variables
- calculate all-time fields (from individual months)
- take difference, calculate statistics
- make plots
Notes:
- If any of the TEM cases are missing, the ADF skips this plotting script and moves on.
"""
#Notify user that script has started:
msg = "\n Generating TEM plots..."
print(f"{msg}\n {'-' * (len(msg)-3)}")
#Special ADF variable which contains the output paths for
#all generated plots and tables for each case:
plot_location = Path(adf.plot_location[0])
#Check if plot output directory exists, and if not, then create it:
if not plot_location.is_dir():
print(f" {plot_location} not found, making new directory")
plot_location.mkdir(parents=True)
#CAM simulation variables (this is always assumed to be a list):
case_names = adf.get_cam_info("cam_case_name", required=True)
res = adf.variable_defaults # will be dict of variable-specific plot preferences
#Extract test case years
syear_cases = adf.climo_yrs["syears"]
eyear_cases = adf.climo_yrs["eyears"]
#Extract baseline years (which may be empty strings if using Obs):
syear_baseline = adf.climo_yrs["syear_baseline"]
eyear_baseline = adf.climo_yrs["eyear_baseline"]
#Grab all case nickname(s)
test_nicknames = adf.case_nicknames["test_nicknames"]
base_nickname = adf.case_nicknames["base_nickname"]
case_nicknames = test_nicknames + [base_nickname]
#Set plot file type:
# -- this should be set in basic_info_dict, but is not required
# -- So check for it, and default to png
basic_info_dict = adf.read_config_var("diag_basic_info")
plot_type = basic_info_dict.get('plot_type', 'png')
print(f"\t NOTE: Plot type is set to {plot_type}")
# check if existing plots need to be redone
redo_plot = adf.get_basic_info('redo_plot')
print(f"\t NOTE: redo_plot is set to {redo_plot}")
#-----------------------------------------
#Initialize list of input TEM file locations
tem_locs = []
#Extract TEM file save locations
tem_case_locs = adf.get_cam_info("cam_tem_loc",required=True)
tem_base_loc = adf.get_baseline_info("cam_tem_loc")
#If path not specified, skip TEM calculation?
if tem_case_locs is None:
print("\t 'cam_tem_loc' not found for test case(s) in config file, so no TEM plots will be generated.")
return
else:
for tem_case_loc in tem_case_locs:
tem_case_loc = Path(tem_case_loc)
#Check if TEM directory exists, and if not, then create it:
if not tem_case_loc.is_dir():
print(f" {tem_case_loc} not found, making new directory")
tem_case_loc.mkdir(parents=True)
#End if
tem_locs.append(tem_case_loc)
#End for
#Set seasonal ranges:
seasons = {"ANN": np.arange(1,13,1),
"DJF": [12, 1, 2],
"JJA": [6, 7, 8],
"MAM": [3, 4, 5],
"SON": [9, 10, 11]
}
#Suggestion from Rolando, if QBO is being produced, add utendvtem and utendwtem?
if "qbo" in adf.plotting_scripts:
var_list = ['uzm','epfy','epfz','vtem','wtem',
'psitem','utendepfd','utendvtem','utendwtem']
#Otherwise keep it simple
else:
var_list = ['uzm','epfy','epfz','vtem','wtem','psitem','utendepfd']
#Check if comparing against obs
if adf.compare_obs:
obs = True
#Set TEM file for observations
base_file_name = 'Obs.TEMdiag.nc'
input_loc_idx = Path(tem_locs[0])
else:
base_name = adf.get_baseline_info("cam_case_name", required=True)
#If path not specified, skip TEM calculation?
if tem_base_loc is None:
print(f"\t 'cam_tem_loc' not found for '{base_name}' in config file, so no TEM plots will be generated.")
return
else:
obs = False
input_loc_idx = Path(tem_base_loc)
#Set TEM file for baseline
base_file_name = f'{base_name}.TEMdiag_{syear_baseline}-{eyear_baseline}.nc'
#Set full path for baseline/obs file
tem_base = input_loc_idx / base_file_name
#Check to see if baseline/obs TEM file exists
if tem_base.is_file():
ds_base = xr.open_dataset(tem_base)
else:
print(f"\t'{base_file_name}' does not exist. TEM plots will be skipped.")
return
#Setup TEM plots
nrows = len(var_list)
ncols = len(case_nicknames)+1
fig_width = 20
#try and dynamically create size of fig based off number of cases
fig_height = 15+(ncols*nrows)
#Loop over season dictionary:
for s in seasons:
#Location to save plots
plot_name = plot_location / f"{s}_TEM_Mean.png"
# Check redo_plot. If set to True: remove old plot, if it already exists:
if (not redo_plot) and plot_name.is_file():
#Add already-existing plot to website (if enabled):
adf.debug_log(f"'{plot_name}' exists and clobber is false.")
adf.add_website_data(plot_name, "TEM", None, season=s, multi_case=True)
#Continue to next iteration:
continue
elif (redo_plot) and plot_name.is_file():
plot_name.unlink()
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width,fig_height),
facecolor='w', edgecolor='k')
#Loop over model cases:
for idx,case_name in enumerate(case_names):
#Extract start and end year values:
start_year = syear_cases[idx]
end_year = eyear_cases[idx]
#Open the TEM file
output_loc_idx = tem_locs[idx]
case_file_name = f'{case_name}.TEMdiag_{start_year}-{end_year}.nc'
tem = output_loc_idx / case_file_name
#Grab the data for the TEM netCDF files
if tem.is_file():
ds = xr.open_dataset(tem)
else:
print(f"\t'{case_file_name}' does not exist. TEM plots will be skipped.")
return
climo_yrs = {"test":[syear_cases[idx], eyear_cases[idx]],
"base":[syear_baseline, eyear_baseline]}
#Setup and plot the sub-plots
tem_plot(ds, ds_base, case_nicknames, axs, s, var_list, res, obs, climo_yrs)
#Set figure title
plt.suptitle(f'TEM Diagnostics: {s}', fontsize=20, y=.928)
#Write the figure to provided workspace/file:
fig.savefig(plot_name, bbox_inches='tight', dpi=300)
#Add plot to website (if enabled):
adf.add_website_data(plot_name, "TEM", None, season=s, multi_case=True)
print(" ...TEM plots have been generated successfully.")
# Helper functions
##################
[docs]
def tem_plot(ds, ds_base, case_names, axs, s, var_list, res, obs, climo_yrs):
"""
TEM subplots
"""
#Set empty message for comparison of cases with different vertical levels
#TODO: Work towards getting the vertical and horizontal interpolations!! - JR
empty_message = "These have different vertical levels\nCan't compare cases currently"
props = {'boxstyle': 'round', 'facecolor': 'wheat', 'alpha': 0.9}
prop_x = 0.18
prop_y = 0.42
for var in var_list:
#Grab variable defaults for this variable
vres = res[var]
#Gather data for both cases
mdata = ds[var].squeeze()
odata = ds_base[var].squeeze()
#Create array to avoid weighting missing values:
md_ones = xr.where(mdata.isnull(), 0.0, 1.0)
od_ones = xr.where(odata.isnull(), 0.0, 1.0)
month_length = mdata.time.dt.days_in_month
weights = (month_length.groupby("time.season") / month_length.groupby("time.season").sum())
#Calculate monthly-weighted seasonal averages:
if s == 'ANN':
#Calculate annual weights (i.e. don't group by season):
weights_ann = month_length / month_length.sum()
mseasons = (mdata * weights_ann).sum(dim='time')
mseasons = mseasons / (md_ones*weights_ann).sum(dim='time')
#Calculate monthly weights based on number of days:
if obs:
month_length_obs = odata.time.dt.days_in_month
weights_ann_obs = month_length_obs / month_length_obs.sum()
oseasons = (odata * weights_ann_obs).sum(dim='time')
oseasons = oseasons / (od_ones*weights_ann_obs).sum(dim='time')
else:
month_length_base = odata.time.dt.days_in_month
weights_ann_base = month_length_base / month_length_base.sum()
oseasons = (odata * weights_ann_base).sum(dim='time')
oseasons = oseasons / (od_ones*weights_ann_base).sum(dim='time')
else:
#this is inefficient because we do same calc over and over
mseasons = (mdata * weights).groupby("time.season").sum(dim="time").sel(season=s)
wgt_denom = (md_ones*weights).groupby("time.season").sum(dim="time").sel(season=s)
mseasons = mseasons / wgt_denom
if obs:
month_length_obs = odata.time.dt.days_in_month
weights_obs = (month_length_obs.groupby("time.season") / month_length_obs.groupby("time.season").sum())
oseasons = (odata * weights_obs).groupby("time.season").sum(dim="time").sel(season=s)
wgt_denom = (od_ones*weights_obs).groupby("time.season").sum(dim="time").sel(season=s)
oseasons = oseasons / wgt_denom
else:
month_length_base = odata.time.dt.days_in_month
weights_base = (month_length_base.groupby("time.season") / month_length_base.groupby("time.season").sum())
oseasons = (odata * weights_base).groupby("time.season").sum(dim="time").sel(season=s)
wgt_denom_base = (od_ones*weights_base).groupby("time.season").sum(dim="time").sel(season=s)
oseasons = oseasons / wgt_denom_base
#difference: each entry should be (lat, lon)
dseasons = mseasons-oseasons
#Run through variables and plot each against the baseline on each row
#Each column will be a case, ie (test, base, difference)
# uzm
#------------------------------------------------------------------------------------------
if var == "uzm":
mseasons.plot(ax=axs[0,0], y='lev', yscale='log',ylim=[1e3,1],
cbar_kwargs={'label': ds[var].units})
oseasons.plot(ax=axs[0,1], y='lev', yscale='log',ylim=[1e3,1],
cbar_kwargs={'label': ds[var].units})
#Check if difference plot has contour levels, if not print notification
if len(dseasons.lev) == 0:
axs[0,2].text(prop_x, prop_y, empty_message, transform=axs[0,2].transAxes, bbox=props)
else:
dseasons.plot(ax=axs[0,2], y='lev', yscale='log', ylim=[1e3,1],cmap="BrBG",
cbar_kwargs={'label': ds[var].units})
# epfy
#------------------------------------------------------------------------------------------
if var == "epfy":
mseasons.plot(ax=axs[1,0], y='lev', yscale='log',vmax=1e6,ylim=[1e2,1],
cbar_kwargs={'label': ds[var].units})
oseasons.plot(ax=axs[1,1], y='lev', yscale='log',vmax=1e6,ylim=[1e2,1],
cbar_kwargs={'label': ds[var].units})
#Check if difference plot has contour levels, if not print notification
if len(dseasons.lev) == 0:
axs[1,2].text(prop_x, prop_y, empty_message, transform=axs[1,2].transAxes, bbox=props)
else:
dseasons.plot(ax=axs[1,2], y='lev', yscale='log', vmax=1e6,
ylim=[1e2,1],cmap="BrBG",
cbar_kwargs={'label': ds[var].units})
# epfz
#------------------------------------------------------------------------------------------
if var == "epfz":
mseasons.plot(ax=axs[2,0], y='lev', yscale='log',vmax=1e5,ylim=[1e2,1],
cbar_kwargs={'label': ds[var].units})
oseasons.plot(ax=axs[2,1], y='lev', yscale='log',vmax=1e5,ylim=[1e2,1],
cbar_kwargs={'label': ds[var].units})
#Check if difference plot has contour levels, if not print notification
if len(dseasons.lev) == 0:
axs[2,2].text(prop_x, prop_y, empty_message, transform=axs[2,2].transAxes, bbox=props)
else:
dseasons.plot(ax=axs[2,2], y='lev', yscale='log', vmax=1e5,
ylim=[1e2,1],cmap="BrBG",
cbar_kwargs={'label': ds[var].units})
# vtem
#------------------------------------------------------------------------------------------
if var == "vtem":
mseasons.plot.contourf(ax=axs[3,0], levels = 21, y='lev', yscale='log',
vmax=3,vmin=-3,ylim=[1e2,1], cmap='RdBu_r',
cbar_kwargs={'label': ds[var].units})
mseasons.plot.contour(ax=axs[3,0], levels = 11, y='lev', yscale='log',
vmax=3,vmin=-3,ylim=[1e2,1],
colors='black', linestyles=None)
oseasons.plot.contourf(ax=axs[3,1], levels = 21, y='lev', yscale='log',
vmax=3,vmin=-3,ylim=[1e2,1], cmap='RdBu_r',
cbar_kwargs={'label': ds[var].units})
oseasons.plot.contour(ax=axs[3,1], levels = 11, y='lev', yscale='log',
vmax=3,vmin=-3,ylim=[1e2,1],
colors='black', linestyles=None)
#Check if difference plot has contour levels, if not print notification
if len(dseasons.lev) == 0:
axs[3,2].text(prop_x, prop_y, empty_message, transform=axs[3,2].transAxes, bbox=props)
else:
dseasons.plot(ax=axs[3,2], y='lev', yscale='log', vmax=3,vmin=-3,
ylim=[1e2,1],cmap="BrBG",
cbar_kwargs={'label': ds[var].units})
# wtem
#------------------------------------------------------------------------------------------
if var == "wtem":
mseasons.plot.contourf(ax=axs[4,0], levels = 21, y='lev', yscale='log',
vmax=0.005, vmin=-0.005, ylim=[1e2,1], cmap='RdBu_r',
cbar_kwargs={'label': ds[var].units})
mseasons.plot.contour(ax=axs[4,0], levels = 7, y='lev', yscale='log',
vmax=0.03, vmin=-0.03, ylim=[1e2,1],
colors='black', linestyles=None)
oseasons.plot.contourf(ax=axs[4,1], levels = 21, y='lev', yscale='log',
vmax=0.005, vmin=-0.005, ylim=[1e2,1], cmap='RdBu_r',
cbar_kwargs={'label': ds[var].units})
oseasons.plot.contour(ax=axs[4,1], levels = 7, y='lev', yscale='log',
vmax=0.03, vmin=-0.03, ylim=[1e2,1],
colors='black', linestyles=None)
#Check if difference plot has contour levels, if not print notification
if len(dseasons.lev) == 0:
axs[4,2].text(prop_x, prop_y, empty_message, transform=axs[4,2].transAxes, bbox=props)
else:
dseasons.plot(ax=axs[4,2], y='lev', yscale='log',vmax=0.005, vmin=-0.005,
ylim=[1e2,1],cmap="BrBG",
cbar_kwargs={'label': ds[var].units})
# psitem
#------------------------------------------------------------------------------------------
if var == "psitem":
mseasons.plot.contourf(ax=axs[5,0], levels = 21, y='lev', yscale='log',
vmax=5e9, ylim=[1e2,2],
cbar_kwargs={'label': ds[var].units})
oseasons.plot.contourf(ax=axs[5,1], levels = 21, y='lev', yscale='log',
vmax=5e9, ylim=[1e2,2],
cbar_kwargs={'label': ds[var].units})
#Check if difference plot has contour levels, if not print notification
if len(dseasons.lev) == 0:
axs[5,2].text(prop_x, prop_y, empty_message, transform=axs[5,2].transAxes, bbox=props)
else:
dseasons.plot(ax=axs[5,2], y='lev', yscale='log',vmax=5e9,
ylim=[1e2,2],cmap="BrBG",
cbar_kwargs={'label': ds[var].units})
# utendepfd
#------------------------------------------------------------------------------------------
if var == "utendepfd":
mseasons.plot(ax=axs[6,0], y='lev', yscale='log',
vmax=0.0001, vmin=-0.0001, ylim=[1e2,2],
cbar_kwargs={'label': ds[var].units})
oseasons.plot(ax=axs[6,1], y='lev', yscale='log',
vmax=0.0001, vmin=-0.0001, ylim=[1e2,2],
cbar_kwargs={'label': ds[var].units})
#Check if difference plot has contour levels, if not print notification
if len(dseasons.lev) == 0:
axs[6,2].text(prop_x, prop_y, empty_message, transform=axs[6,2].transAxes, bbox=props)
else:
dseasons.plot(ax=axs[6,2], y='lev', yscale='log',vmax=0.0001, vmin=-0.0001,
ylim=[1e2,2],cmap="BrBG",
cbar_kwargs={'label': ds[var].units})
# utendvtem
#------------------------------------------------------------------------------------------
if var == "utendvtem":
mseasons.plot(ax=axs[7,0], y='lev', yscale='log',vmax=0.001, ylim=[1e3,1],
cbar_kwargs={'label': ds[var].units})
oseasons.plot(ax=axs[7,1], y='lev', yscale='log',vmax=0.001, ylim=[1e3,1],
cbar_kwargs={'label': ds[var].units})
#Check if difference plot has contour levels, if not print notification
if len(dseasons.lev) == 0:
axs[7,2].text(prop_x, prop_y, empty_message, transform=axs[7,2].transAxes, bbox=props)
else:
dseasons.plot(ax=axs[7,2], y='lev', yscale='log', vmax=0.001, ylim=[1e3,1],cmap="BrBG",
cbar_kwargs={'label': ds[var].units})
# utendwtem
#------------------------------------------------------------------------------------------
if var == "utendwtem":
mseasons.plot(ax=axs[8,0], y='lev', yscale='log',vmax=0.0001, ylim=[1e3,1],
cbar_kwargs={'label': ds[var].units})
oseasons.plot(ax=axs[8,1], y='lev', yscale='log',vmax=0.0001, ylim=[1e3,1],
cbar_kwargs={'label': ds[var].units})
#Check if difference plot has contour levels, if not print notification
if len(dseasons.lev) == 0:
axs[8,2].text(prop_x, prop_y, empty_message, transform=axs[8,2].transAxes, bbox=props)
else:
dseasons.plot(ax=axs[8,2], y='lev', yscale='log', vmax=0.0001, ylim=[1e3,1],cmap="BrBG",
cbar_kwargs={'label': ds[var].units})
# Set the ticks and ticklabels for all x-axes
#NOTE: This has to come after all subplots have been done,
#I am assuming this is because of the way xarray plots info automatically for labels and titles
#This is to change the default xarray labels for each instance of the xarray plot method
plt.setp(axs, xticks=np.arange(-80,81,20), xlabel='latitude', title="")
#Set titles of subplots
#Set case names in first subplot only
uzm = ds["uzm"].long_name.replace(" ", "\ ")
test_yrs = f"{climo_yrs['test'][0]}-{climo_yrs['test'][1]}"
axs[0,0].set_title(f"\n\n"+"$\mathbf{Test}$"+f" yrs: {test_yrs}\n"+f"{case_names[0]}\n\n\n",fontsize=14)
if obs:
obs_title = Path(vres["obs_name"]).stem
axs[0,1].set_title(f"\n\n"+"$\mathbf{Baseline}$\n"+f"{obs_title}\n\n"+"$\mathbf{"+uzm+"}$"+"\n",fontsize=14)
else:
base_yrs = f"{climo_yrs['base'][0]}-{climo_yrs['base'][1]}"
axs[0,1].set_title(f"\n\n"+"$\mathbf{Baseline}$"+f" yrs: {base_yrs}\n"+f"{case_names[1]}\n\n"+"$\mathbf{"+uzm+"}$"+"\n",fontsize=14)
#Set main title for difference plots column
axs[0,2].set_title("$\mathbf{Test} - \mathbf{Baseline}$"+"\n\n\n",fontsize=14)
#Set variable name on center plot (except first plot, see above)
for i in range(1,len(var_list)):
var_name = ds[var_list[i]].long_name.replace(" ", "\ ")
axs[i,1].set_title("$\mathbf{"+var_name+"}$"+"\n",fontsize=14)
#Adjust subplots
#May need to adjust hspace and wspace depending on if multi-case diagnostics ever happen for TEM diags
hspace = 0.4
plt.subplots_adjust(wspace=0.5, hspace=hspace)
return axs
##############
#END OF SCRIPT