'''
Auxiliary Plotting routines
===========================
Plotting routines for the verification of forecasts: map panels of
forecast vs. verification fields, skill-score (ACC/RMS) plots and
box plots, and export of scores to CSV, Excel and LaTeX.
'''
import os,sys
import math
import numpy as np
import numpy.linalg as lin
import xarray as xr
import pandas as pd
import cartopy.crs as crs
import scipy.linalg as sc
from scipy import stats
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import datetime
import csv
import pickle
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import Dataset, DataLoader
# import torch.nn.functional as F
# import transformers as tr
# from sklearn.preprocessing import StandardScaler, MinMaxScaler
import zapata.computation as zcom
import zapata.data as zd
import zapata.lib as zlib
import zapata.mapping as zmap
import AIModels.AIutil as zai
[docs]
def Single_Forecast_plots( F, V, field, level, cont, timestart, world_projection='Atlantic', maxtime=3, stride = 1,colorbars=True,\
                    title='', mainlabel='', arealatz=None, arealonz=None, centlon=0,labfile=None):
    '''
    Make verification plots for the field `field` and level `level` for the dataset `F` and `V`.

    Plots are made in two columns: forecasts (left) and verification
    (right), one row per plotted time step.

    Parameters
    ==========
    F: xarray dataset
        Dataset with the forecast to be analyzed (dims include `time` and `lead`)
    V: xarray dataset
        Dataset with the verification field
    mainlabel: string
        Main label for the plot
    field: string
        Field to be analyzed
    level: string
        Level to be analyzed (may be None; then only `field` appears in titles)
    timestart: string
        Starting date for the plotting (must match an entry of F['time'])
    world_projection: string
        World projection for the maps (passed to `zmap.init_figure`)
    maxtime: int
        Maximum number of time steps to be plotted
        The number of rows will be determined by this number
    stride: int
        Stride for the plotting
    colorbars: bool
        If True, add a colorbar to each panel; otherwise a single
        horizontal colorbar at the bottom of the figure
    title: string
        Title for the plot
    cont :
        * [cmin,cmax,cinc] fixed increment from cmin to cmax step cinc
        * [ c1,c2, ..., cn] fixed contours at [ c1,c2, ..., cn]
        * n n contours
        * [] automatic choice
    arealatz: numpy array
        Latitude limits of the field
    arealonz: numpy array
        Longitude limits of the field
    centlon: float
        Central longitude
    labfile: string
        Name of the PDF file to be written (not saved when None)

    Returns
    =======
    None
    '''
    # Level is optional in the panel label
    label1 = mainlabel + field if level is None else mainlabel + field + level
    # Position of `timestart` along the time axis of F
    index = int(np.where(F['time'] == pd.Timestamp(timestart))[0])
    iv = index - 1
    print(f'Plotting for time {F.time[iv].data} a total of {maxtime} time steps')
    # One row per plotted step; a stride larger than maxtime collapses to one row
    if maxtime > stride:
        nrows = int(maxtime/stride)
    else:
        nrows = 1
        stride = maxtime
    # Empirical figure height scaling with the number of rows
    figy = 4*(nrows-3) + 8
    fig_index = np.arange(0,maxtime,stride,dtype=np.int32)
    fig,ax,pro=zmap.init_figure(nrows,2,world_projection, constrained_layout=False, figsize=(12,figy) )
    for i in range(len(fig_index)):
        # NOTE(review): the date label is taken at iv+fig_index[i] while the data
        # is selected at time=index with lead=fig_index[i] -- confirm the intended
        # alignment between the printed dates and the plotted leads
        label12 = str(F.time[iv+fig_index[i]].data)[0:10]
        label13 = f'Forecast Lead {fig_index[i]}'
        # Left column: forecast panel
        out=zmap.xmap(F.isel(time=index,lead=fig_index[i]).unstack(),cont, pro, ax=ax[i,0],refline=None, c_format='{:4.2f}',data_cent_lon=centlon,\
                    xlimit=(arealonz[0],arealonz[1]),ylimit=(arealatz[1],arealatz[0]), custom=True,
                    title={'maintitle':label1, 'lefttitle':label12,'righttitle':label13},cmap='seismic',contour=False)
        out.gl.right_labels = False
        if colorbars:
            zmap.add_colorbar(fig, out.handles['filled'], ax[i,0], label_size=10,edges=True)
        label13 = f'Verification Lead {fig_index[i]}'
        # Right column: verification panel
        out2=zmap.xmap(V.isel(time=index,lead=fig_index[i]).unstack(),cont, pro, ax=ax[i,1],refline=None, c_format='{:4.2f}',data_cent_lon=centlon,\
                    xlimit=(arealonz[0],arealonz[1]),ylimit=(arealatz[1],arealatz[0]), custom=True,
                    title={'maintitle':label1, 'lefttitle':label12,'righttitle':label13},cmap='seismic',contour=False)
        out2.gl.left_labels = False
        if colorbars:
            zmap.add_colorbar(fig, out2.handles['filled'], ax[i,1], label_size=10,edges=True)
    if not colorbars:
        # Create an axis on the bottom of the figure for the colorbar
        cax = fig.add_axes([0.25, 0.05, 0.5, 0.015])
        # Create a single colorbar based on the last forecast panel
        cbar = fig.colorbar(out.handles['filled'], cax=cax, orientation='horizontal')
        # cbar.set_label('Colorbar Label')
    plt.suptitle(title,fontsize=20)
    if labfile is not None:
        plt.savefig(labfile, orientation='landscape', format='pdf')
    fig.subplots_adjust(wspace=0.05,hspace=0.0001)
    plt.show()
    return
[docs]
def many_plots(F, field, level, cont, timestart, world_projection='Pacific', mainlabel='', mode='time', lead=0, ncols=2, nrows=3, arealatz=None, arealonz=None, centlon=0, labfile=None, colorbars=False, suptitle=''):
    '''
    Make many plots for the field `field` and level `level` for the dataset `F`.
    Plots are made in rows and columns according to `ncols` and `nrows`.

    Parameters
    ==========
    F: xarray dataset
        Dataset with the field to be analyzed
    mainlabel: string
        Main label for the plot
    field: string
        Field to be analyzed
    level: string
        Level to be analyzed (concatenated into the title; must not be None here)
    timestart: string
        Starting date for the plotting (must match an entry of F['time'])
    world_projection: string
        World projection for the maps (passed to `zmap.init_figure`)
    mode: string
        Mode for the plotting, either 'time' (advance the time index panel by
        panel at fixed lead) or 'lead' (advance the lead panel by panel)
    lead: int
        Lead time for the plotting (only meaningful for mode='time')
    ncols: int
        Number of columns for the plotting
    nrows: int
        Number of rows for the plotting
    cont :
        * [cmin,cmax,cinc] fixed increment from cmin to cmax step cinc
        * [ c1,c2, ..., cn] fixed contours at [ c1,c2, ..., cn]
        * n n contours
        * [] automatic choice
    arealatz: numpy array
        Latitude limits of the field
    arealonz: numpy array
        Longitude limits of the field
    centlon: float
        Central longitude
    labfile: string
        Name of the PDF file to be written (not saved when None)
    colorbars: bool
        Whether to add colorbars to each panel or a single colorbar at the bottom
    suptitle: string
        General title on top of the plot

    Returns
    =======
    None
    '''
    label1 = mainlabel + field + level
    # Position of `timestart` along the time axis of F
    index = int(np.where(F['time'] == pd.Timestamp(timestart))[0])
    iv = index - 1 + lead
    if mode == 'time':
        print(f'Plotting for time {F.time[iv].data} and lead {lead}')
    elif mode == 'lead':
        # Start at -1 so the first panel in the loop below uses lead 0
        lead = -1
        print(f'Plotting all leads for time {F.time[index].data}')
    else:
        print(f'No mode {mode} defined')
        return None
    fig, ax, pro = zmap.init_figure(nrows, ncols, world_projection, constrained_layout=False, figsize=(24, 12))
    for i in range(nrows):
        for j in range(ncols):
            # Advance either the time index or the lead, panel by panel (row-major)
            if mode == 'time':
                iv += 1
                label12 = str(F.time[iv].data)[0:10]
            elif mode == 'lead':
                lead += 1
                label12 = str(F.time[iv + lead].data)[0:10]
            else:
                return None
            if 'lead' in F.dims:
                label13 = f'Lead {lead}'
                # NOTE(review): the result is subscripted as `handle['filled']` below,
                # while other routines here use `out.handles['filled']` (with custom=True)
                # -- confirm which interface `zmap.xmap` exposes in this call form
                handle = zmap.xmap(F.isel(time=iv, lead=lead).unstack(), cont, pro, ax=ax[i, j], refline=None, c_format='{:4.2f}', data_cent_lon=centlon,
                                   xlimit=(arealonz[0], arealonz[1]), ylimit=(arealatz[1], arealatz[0]),
                                   title={'maintitle': label1, 'lefttitle': label12, 'righttitle': label13}, cmap='seismic', contour=False)
            else:
                # No lead dimension: plot successive times only
                label13 = f'Time {label12}'
                handle = zmap.xmap(F.isel(time=iv).unstack(), cont, pro, ax=ax[i, j], refline=None, c_format='{:4.2f}', data_cent_lon=centlon,
                                   xlimit=(arealonz[0], arealonz[1]), ylimit=(arealatz[1], arealatz[0]),
                                   title={'maintitle': label1, 'lefttitle': label12, 'righttitle': label13}, cmap='seismic', contour=False)
            if colorbars:
                zmap.add_colorbar(fig, handle['filled'], ax[i, j], label_size=10, edges=True)
    if not colorbars:
        # Create an axis on the bottom of the figure for the colorbar
        cax = fig.add_axes([0.25, 0.05, 0.5, 0.015])
        # Create a single colorbar based on the last drawn panel
        cbar = fig.colorbar(handle['filled'], cax=cax, orientation='horizontal')
        # cbar.set_label('Colorbar Label')
    plt.suptitle(suptitle, fontsize=20)
    # fig.subplots_adjust(wspace=0.1, hspace=0.1)
    if labfile is not None:
        plt.savefig(labfile, orientation='landscape', format='pdf')
    plt.show()
    return
def _fix_labels(out,side='right'):
'''
Fix the labels for the plot
Parameters
==========
out: object
Object with the plot
side: string
Side for the labels
Returns
=======
None
'''
out.gl.left_labels = False
out.gl.right_labels = False
match side:
case 'left':
out.gl.left_labels = True
case 'right':
out.gl.right_labels = True
return out
[docs]
def Forecast_plots(F, V, field, level, cont, world_projection='Atlantic', pictimes=None, colorbars=True,
                   title='', mainlabel='', arealatz=None, arealonz=None, centlon=0, labfile=None):
    '''
    Make verification plots for the field `field` and level `level` for the dataset `F` and `V`.

    Plots are made in two columns (forecast left, verification right),
    one row per entry of `pictimes`.

    Parameters
    ==========
    F: xarray dataset
        Dataset with the forecast to be analyzed
    V: xarray dataset
        Dataset with the verification field
    mainlabel: string
        Main label for the plot
    field: string
        Field to be analyzed
    level: string
        Level to be analyzed (may be None; then only `field` appears in titles)
    world_projection: string
        World projection for the maps (passed to `zmap.init_figure`)
    pictimes: list of int
        Time indices to be plotted; one row is drawn for each entry
    colorbars: bool
        If True, add a colorbar to each panel; otherwise a single
        horizontal colorbar at the bottom of the figure
    title: string
        Title for the plot
    cont :
        * [cmin,cmax,cinc] fixed increment from cmin to cmax step cinc
        * [ c1,c2, ..., cn] fixed contours at [ c1,c2, ..., cn]
        * n n contours
        * [] automatic choice
    arealatz: numpy array
        Latitude limits of the field
    arealonz: numpy array
        Longitude limits of the field
    centlon: float
        Central longitude
    labfile: string
        Name of the PDF file to be written (not saved when None)

    Returns
    =======
    None
    '''
    # Level is optional in the panel label
    label1 = mainlabel + field if level is None else mainlabel + field + level
    nrows = len(pictimes)
    # Empirical figure height scaling with the number of rows
    figy = 4 * (nrows - 3) + 8
    fig_index = pictimes
    fig, ax, pro = zmap.init_figure(nrows, 2, world_projection, constrained_layout=False, figsize=(12, figy))
    for i in range(len(fig_index)):
        label12 = str(F.time[fig_index[i]].data)[0:10]
        # NOTE(review): fig_index holds time indices, but the right-hand titles
        # call them "Lead" -- confirm the intended wording
        label13 = f'Forecast Lead {fig_index[i]}'
        # Left column: forecast panel
        out = zmap.xmap(F.isel(time=fig_index[i]).unstack(), cont, pro, ax=ax[i, 0], refline=None, c_format='{:4.2f}', data_cent_lon=centlon,
                        xlimit=(arealonz[0], arealonz[1]), ylimit=(arealatz[1], arealatz[0]), custom=True,
                        title={'maintitle': label1, 'lefttitle': label12, 'righttitle': label13}, cmap='seismic', contour=False)
        out.gl.right_labels = False
        if colorbars:
            zmap.add_colorbar(fig, out.handles['filled'], ax[i, 0], label_size=10, edges=True)
        label12 = str(V.time[fig_index[i]].data)[0:10]
        label13 = f'Verification Lead {fig_index[i]}'
        # Right column: verification panel
        out2 = zmap.xmap(V.isel(time=fig_index[i]).unstack(), cont, pro, ax=ax[i, 1], refline=None, c_format='{:4.2f}', data_cent_lon=centlon,
                         xlimit=(arealonz[0], arealonz[1]), ylimit=(arealatz[1], arealatz[0]), custom=True,
                         title={'maintitle': label1, 'lefttitle': label12, 'righttitle': label13}, cmap='seismic', contour=False)
        out2.gl.left_labels = False
        if colorbars:
            zmap.add_colorbar(fig, out2.handles['filled'], ax[i, 1], label_size=10, edges=True)
    if not colorbars:
        # Create an axis on the bottom of the figure for the colorbar
        cax = fig.add_axes([0.25, 0.05, 0.5, 0.015])
        # Create a single colorbar based on the last forecast panel
        cbar = fig.colorbar(out.handles['filled'], cax=cax, orientation='horizontal')
        # cbar.set_label('Colorbar Label')
    plt.suptitle(title, fontsize=20)
    if labfile is not None:
        plt.savefig(labfile, orientation='landscape', format='pdf')
    fig.subplots_adjust(wspace=0.05, hspace=0.0001)
    plt.show()
    return
[docs]
def Three_Forecast_plots( F, V, D, field, level, cont, timestart, pictimes,\
         figsize=(12,8), world_projection='Atlantic',\
         colorbars=True, maintitle='', picturelabels=None, arealatz=None, arealonz=None, centlon=0,labfile=None):
    '''
    Make verification plots for the field `field` and level `level` for the
    datasets `F` (forecast), `D` (GCM forecast) and `V` (observations).
    Plots are made in three columns (Forecast, GCM, Obs), one row per
    entry of `pictimes`.

    Parameters
    ==========
    F: xarray dataset
        Dataset with the forecast to be analyzed
    V: xarray dataset
        Dataset with the verification field
    D: xarray dataset
        Dataset with the deterministic (GCM) forecast
    field: string
        Field to be analyzed
    level: string
        Level to be analyzed (may be None; then only `field` is used in the title)
    cont :
        * [cmin,cmax,cinc] fixed increment from cmin to cmax step cinc
        * [ c1,c2, ..., cn] fixed contours at [ c1,c2, ..., cn]
        * n n contours
        * [] automatic choice
    timestart: string
        Starting date for the plotting (must match an entry of F['time'])
    pictimes: list
        List of time indices to be plotted (also sets the number of rows/columns)
    figsize: tuple
        Size of the figure
    world_projection: string
        World Projection for the plot
    colorbars: bool
        If True, one colorbar per panel; otherwise a single bottom colorbar
    maintitle: string
        Main label for the plot
    picturelabels: namedtuple / list of lists
        Labels for the pictures, indexed [column][row]; defaults when None
    arealatz: numpy array
        Latitude limits of the field
    arealonz: numpy array
        Longitude limits of the field
    centlon: float
        Central longitude
    labfile: string
        Name of the PDF file to be written (not saved when None)

    Returns
    =======
    ax: array of panel axes
    '''
    column = len(pictimes)
    maxV = len(V['time'])
    maxD = len(D['time'])
    maxF = len(F['time'])
    title = maintitle + field if level is None else maintitle + field + level
    # Validates that `timestart` exists in F['time'] (raises if not found)
    index = int(np.where(F['time'] == pd.Timestamp(timestart))[0])
    iv = index
    # Font styles for axis labels, panel titles and the figure title
    fstyle = {'fontsize': 11, 'fontfamily': 'futura', 'fontweight': 'bold'}
    tstyle = {'fontsize': 12, 'fontfamily': 'futura', 'fontweight': 'bold'}
    sup_style = {'fontsize': 24, 'fontfamily': 'futura', 'fontweight': 'bold'}
    fig,ax,pro=zmap.init_figure(3,column,world_projection, constrained_layout=False, figsize=figsize )
    label12 = ''  # fallback date label if the first forecast panel is skipped
    for k, i in enumerate(pictimes):
        # Panels whose index falls outside a dataset are skipped; keep track of
        # the drawn ones so the box styling below never touches unset handles
        out = out2 = out3 = None
        # Fixed off-by-one: these guards used `i <= max*`, which admitted the
        # out-of-range index equal to the dataset length
        if i < maxF:
            label12 = str(F.time[i].data)[0:10]
            label13 = f'{picturelabels[0][k]}' if picturelabels is not None else f'Lead Month {i}'
            out = zmap.xmap(F.isel(time=i).unstack(), cont, pro, ax=ax[k, 0], refline=None, c_format='{:4.2f}', data_cent_lon=centlon,
                            xlimit=(arealonz[0], arealonz[1]), ylimit=(arealatz[1], arealatz[0]),
                            custom=True, label_style=fstyle, title_style=tstyle,
                            title={'maintitle': 'Forecast', 'lefttitle': label12, 'righttitle': label13}, cmap='seismic', contour=False)
            out = _fix_labels(out, 'left')
            if colorbars:
                zmap.add_colorbar(fig, out.handles['filled'], ax[k, 0], label_size=10, edges=True)
        else:
            ax[k, 0].axis('off')
        if i < maxD:
            label13 = f'{picturelabels[1][k]}' if picturelabels is not None else f'Lead Month {i}'
            out2 = zmap.xmap(D.isel(time=i).unstack(), cont, pro, ax=ax[k, 1], refline=None, c_format='{:4.2f}', data_cent_lon=centlon,
                             xlimit=(arealonz[0], arealonz[1]), ylimit=(arealatz[1], arealatz[0]),
                             custom=True, label_style=fstyle, title_style=tstyle,
                             title={'maintitle': 'GCM', 'lefttitle': label12, 'righttitle': label13}, cmap='seismic', contour=False)
            out2 = _fix_labels(out2, None)  # hide labels on both sides for the middle column
            if colorbars:
                zmap.add_colorbar(fig, out2.handles['filled'], ax[k, 1], label_size=10, edges=True)
        else:
            ax[k, 1].axis('off')
        if i < maxV:
            label13 = f'{picturelabels[2][k]}' if picturelabels is not None else ''
            out3 = zmap.xmap(V.isel(time=i).unstack(), cont, pro, ax=ax[k, 2], refline=None, c_format='{:4.2f}', data_cent_lon=centlon,
                             xlimit=(arealonz[0], arealonz[1]), ylimit=(arealatz[1], arealatz[0]),
                             custom=True, label_style=fstyle, title_style=tstyle,
                             title={'maintitle': 'Obs', 'lefttitle': label12, 'righttitle': label13}, cmap='seismic', contour=False)
            out3 = _fix_labels(out3, 'right')
            if colorbars:
                zmap.add_colorbar(fig, out3.handles['filled'], ax[k, 2], label_size=10, edges=True)
        else:
            ax[k, 2].axis('off')
        # Fix properties of the bounding boxes of the panels actually drawn
        # (previously unconditional: NameError/UnboundLocalError when skipped)
        for panel in (out, out2, out3):
            if panel is not None:
                zmap.changebox(panel.ax,'all',linewidth=2,color='black',capstyle='round')
    if not colorbars:
        # Create an axis on the bottom of the figure for the colorbar
        cax = fig.add_axes([0.25, 0.05, 0.5, 0.015])
        # Create a colorbar based on the last drawn forecast panel
        cbar = fig.colorbar(out.handles['filled'], cax=cax, orientation='horizontal')
        # cbar.set_label('Colorbar Label')
    plt.suptitle(title,**sup_style)
    if labfile is not None:
        plt.savefig(labfile, orientation='landscape', format='pdf')
    # fig.subplots_adjust(wspace=0.05,hspace=0.0001)
    plt.show()
    return ax
# Routine to select the indices corresponding to certain months in an xarray of datetime64 values
[docs]
def select_months(datetimes,months):
    '''
    Select the indices corresponding to certain months in an xarray of datetime64.

    Parameters
    ==========
    datetimes: xarray (or pandas) datetime64 array
        Datetimes to be analyzed; must expose the ``.dt.month`` accessor
    months: string or iterable of int
        Either a season label ('JFM', 'AMJ', 'JAS', 'OND', 'DJF')
        or an iterable of month numbers (1-12)

    Returns
    =======
    indices: numpy array
        Indices of the entries falling in the selected months,
        grouped month by month in the order the months are given

    Raises
    ======
    ValueError
        If a string label is not one of the known seasons
    '''
    # Supported season labels and their month numbers
    seasons = {
        'JFM': [1, 2, 3],
        'AMJ': [4, 5, 6],
        'JAS': [7, 8, 9],
        'OND': [10, 11, 12],
        'DJF': [12, 1, 2],
    }
    if isinstance(months, str):
        try:
            months = seasons[months]
        except KeyError:
            # Previously an unknown label silently produced an empty result
            raise ValueError(f'Unknown season label {months!r}; '
                             f'expected one of {sorted(seasons)}')
    else:
        months = list(months)
    indices = np.array([],dtype=np.int32)
    for month in months:
        indices = np.concatenate((indices,np.where(datetimes.dt.month == month)[0]))
    return indices
[docs]
def plot_skill(corrresult,persistence,rmsres,rmsper,rmsdyn=None,corrdyn=None,labtit=None,
batch=False, savefig=False, data_dict=None,
skill='mean',numbers=True, printout=False, labfile='scores',figsize=(10,10)):
'''
Makes plot of the skill scores
for forecasts and persistence
Optionally add the skill scores for the dynamic forecast
'''
Tpredict = corrresult.shape[0]-1
tim = np.arange(0,Tpredict+1)
fig = plt.figure(figsize=figsize)
gs = fig.add_gridspec(2, 1, width_ratios=[0.5], height_ratios=2*[1], wspace=0.3, hspace=0.3)
# panels
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[1, 0])
# plot the skill scores and choose type of skill for correlation
match skill:
case 'mean':
sk = np.mean(corrresult,axis=1)
ps = np.mean(persistence,axis=1)
if corrdyn is not None:
dyn = np.mean(corrdyn,axis=1) # Start from month 1
case 'median':
sk = np.median(corrresult,axis=1)
ps = np.median(persistence,axis=1)
if corrdyn is not None:
dyn = np.median(corrdyn,axis=1) # Start from month 1
case _:
raise ValueError('Invalid skill choice for acc')
if printout:
write_skill_to_csv(labfile, skill, Tpredict, sk, ps, corrdyn=dyn, data_dict=data_dict)
if not batch:
ax1.plot(tim,sk)
ax1.plot(tim,ps,linestyle='dashed')
if corrdyn is not None:
ax1.plot(tim[0:7],dyn,linestyle='dashed',color='green')
ax1.axhline(0.6,linestyle='dashed')
ax1.set_title(f'ACC score {skill}',loc='left')
ax1.set_title(labtit,loc='right')
if numbers:
for ii in range(0,Tpredict+1):
ax1.text(ii,min(ax1.get_ylim())+0.1,f'{sk[ii]:4.2f}',horizontalalignment='center')
ax1.text(ii,min(ax1.get_ylim())+0.2,f'{ps[ii]:4.2f}',horizontalalignment='center',color='coral')
# plot the skill scores and choose type of skill for correlation
match skill:
case 'mean':
sk = np.mean(rmsres,axis=1)
ps = np.mean(rmsper,axis=1)
if rmsdyn is not None:
dyn = np.mean(rmsdyn,axis=1) # Start from month 1
case 'median':
sk = np.median(rmsres,axis=1)
ps = np.median(rmsper,axis=1)
if rmsdyn is not None:
dyn = np.median(rmsdyn,axis=1) # Start from month 1
case _:
raise ValueError('Invalid skill choice for rms')
if printout:
labrms = labfile.replace("ACC", "RMS")
write_skill_to_csv(labrms, skill, Tpredict, sk, ps, corrdyn=dyn, data_dict=data_dict)
if not batch:
ax2.plot(tim,sk)
ax2.plot(tim,ps,linestyle='dashed')
if corrdyn is not None:
ax2.plot(tim[0:7],dyn,linestyle='dashed',color='green')
ax2.set_title(f'RMS score {skill}',loc='left')
ax2.set_title(labtit,loc='right')
if numbers:
for ii in range(0,Tpredict+1):
ax2.text(ii,min(ax2.get_ylim())+0.01,f'{sk[ii]:4.2f}',horizontalalignment='center')
ax2.text(ii,min(ax2.get_ylim())+0.02,f'{ps[ii]:4.2f}',horizontalalignment='center',color='coral')
if savefig:
plt.savefig(f'{labfile}.pdf', format='pdf')
plt.show()
return
[docs]
def extract_and_merge_csv(input_folder, output_file, selected_rows):
    """
    Extract selected rows from multiple CSV files in a folder and merge them into a single CSV.

    Files are processed in sorted name order so the result is deterministic
    (``os.listdir`` returns entries in arbitrary order). The header is taken
    from the first CSV; the remaining files are assumed to share its columns.

    Parameters
    ----------
    input_folder : str
        Path to the folder containing CSV files.
    output_file : str
        Path to the output merged CSV file.
    selected_rows : list of int
        List of row indices (0-based, header excluded) to extract from each CSV.

    Returns
    -------
    None
        Saves the merged CSV file to the specified output path.
    """
    header_saved = False
    column_names = None
    # Collect the per-file selections and concatenate once at the end
    # (repeated pd.concat in the loop was quadratic)
    pieces = []
    for filename in sorted(os.listdir(input_folder)):
        if filename.endswith(".csv"):
            file_path = os.path.join(input_folder, filename)
            if not header_saved:
                # First file: keep its header as the reference column names
                df = pd.read_csv(file_path, header=0)
                column_names = df.columns
                header_saved = True
            else:
                # Subsequent files: skip their header, reuse the reference names
                df = pd.read_csv(file_path, header=None, skiprows=1, names=column_names)
            pieces.append(df.iloc[selected_rows])
    merged_df = pd.concat(pieces, ignore_index=True) if pieces else pd.DataFrame()
    # Save the merged data to a CSV file (index column included, as before)
    with open(output_file, 'w', encoding='utf-8') as f:
        merged_df.to_csv(f)
    print(f"Merged CSV saved as {output_file}")
[docs]
def write_skill_to_csv(labfile, skill, Tpredict, sk, ps, corrdyn=None, data_dict=None):
    """
    Write skill scores to a CSV file named ``{labfile}_{skill}.csv``.

    One header row plus one row each for the forecast (``AC-``), the
    persistence (``PE-``) and, optionally, the dynamic forecast (``DY-``).

    Parameters
    ----------
    labfile : str
        Base name for the output CSV file.
    skill : str
        Skill type (e.g., 'mean', 'median').
    Tpredict : int
        Number of prediction time steps.
    sk : array-like
        Skill scores for the forecast.
    ps : array-like
        Skill scores for the persistence.
    corrdyn : array-like, optional
        Skill scores for the dynamic forecast, by default None.
    data_dict : dict
        Experiment description; the 'subcase_var', 'subcase', 'input_fields',
        'params', 'eof' and 'area' entries are written as metadata columns.

    Returns
    -------
    None
    """
    # renamed from `vars`, which shadowed the builtin of the same name
    var_labels = zai.transform_strings(data_dict['input_fields'])
    csv_filename = f"{labfile}_{skill}.csv"
    print(f"Writing data to {csv_filename}")
    head = ['Subcase_var','Subcase','VARS','LAGS','HID','LAYERS','EOF','DISCOUNT','INSEQ','AREA']
    expvalue = [str(data_dict['subcase_var']), str(data_dict['subcase']), str(var_labels), str(data_dict['params']['lags']), str(data_dict['params']['D_DIM']), str(data_dict['params']['enc_dec_layers']), str(data_dict['eof']), str(data_dict['params']['discount']), str(data_dict['params']['TIN']), str(data_dict['area'])]
    print(expvalue)
    print(head)
    with open(csv_filename, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['Exp'] + head + [f'M{ii}' for ii in range(0, Tpredict + 1)])
        writer.writerow([f'AC-{labfile}'] + expvalue + [f"{val:.4f}" for val in sk])
        writer.writerow([f'PE-{labfile}'] + expvalue + [f"{val:.4f}" for val in ps])
        if corrdyn is not None:
            # Dynamic forecast carries no metadata: pad with blanks
            # (2-decimal formatting kept as in the original output)
            writer.writerow([f'DY-{labfile}'] + len(expvalue)*[' '] + [f"{val:.2f}" for val in corrdyn])
    return
[docs]
def plot_csv(file, figsize=(10, 6), colname='Exp', title='', xlab=None, ylab=None,
             row_indices=[0, 3, 6, 9, 12], per_indices=None, GCM_indices=None,
             sort_rows=False, savefile=None, input_axes=None, line06=True,
             lstyle=None,tstyle=None):
    '''
    Plot the data in the skill csv file

    Parameters
    ==========
    file: str
        Path to the CSV file
    figsize: tuple
        Size of the figure (only used when `input_axes` is None)
    colname: str
        Column name to be used as labels
    title: str
        General title for the plot
    xlab: str
        Label for the x-axis
    ylab: str
        Label for the y-axis
    row_indices: list
        Indices of the rows to be plotted
        (NOTE(review): a mutable default argument; it is only read here,
        never mutated, so the shared default is harmless)
    per_indices: list
        Indices of the persistence forecast
        (default: row 1 of the file is used)
    GCM_indices: list
        Indices of the GCM forecast
        (default: row 2 of the file, columns M0-M7 only, is used)
    sort_rows: bool
        Whether to sort the subset by the chosen colname before plotting
    savefile: str
        Name of the file to be written (default not saved)
    input_axes: obj (optional)
        Axes for the plot; when given, the plot is drawn into them and
        the axes are returned instead of shown/saved
    line06: bool
        Whether to plot the 0.6 line
    lstyle: dict
        Line style for the plot
    tstyle: dict
        Text style for the plot

    Returns
    =======
    The axes when `input_axes` is given, otherwise None
    '''
    df = pd.read_csv(file)
    # Scores are expected in columns M0..M12, one per forecast month
    subset = df.loc[row_indices, 'M0':'M12']
    labels = df.loc[row_indices, colname]
    # `define_defaults_values` is defined elsewhere in this module/package;
    # it fills in default line/text styles when None is passed
    lstyle, tstyle = define_defaults_values(lstyle, tstyle)
    if sort_rows:
        # Sort the selected rows by their label, numerically when possible
        combined = pd.concat([subset, labels.rename("Label")], axis=1)
        try:
            combined["Label"] = pd.to_numeric(combined["Label"])
        except ValueError:
            combined["Label"] = combined["Label"].astype(str)
        combined = combined.sort_values(by="Label")
        subset = combined.drop(columns="Label")
        labels = combined["Label"]
    if input_axes:
        ax = input_axes
    else:
        fig, ax = plt.subplots(figsize=figsize)
    for idx, row in subset.iterrows():
        ax.plot(row.index, row.values, label=labels.loc[idx])
    # Persistence reference curve(s): dashed black
    if per_indices is not None:
        for idx in per_indices:
            ax.plot(df.loc[idx, 'M0':'M12'], **{**lstyle, 'linestyle':'dashed'}, color='black')
    else:
        ax.plot(df.loc[1, 'M0':'M12'], label='Persistence', **{**lstyle, 'linestyle':'dashed'}, color='black')
    # GCM reference curve(s): dashed green (default row covers only M0-M7)
    if GCM_indices is not None:
        for idx in GCM_indices:
            ax.plot(df.loc[idx, 'M0':'M12'], label='GCM', **{**lstyle, 'linestyle':'dashed'}, color='green')
    else:
        ax.plot(df.loc[2, 'M0':'M7'], label='GCM', linestyle='dashed', color='green')
    if line06:
        # Conventional 0.6 usefulness threshold line
        ax.axhline(0.6, linestyle='dashed', color='blue')
    ax.set_ylim([0.5,1.0])
    ax.set_xlabel(xlab, fontname='Futura')
    ax.set_ylabel(ylab, fontname='Futura')
    ax.legend(loc='best', fontsize='x-small')
    ax.grid(alpha=0.3)
    ax.set_title(title, **tstyle)
    if not input_axes:
        if savefile:
            plt.savefig(savefile, format='pdf')
        plt.show()
    return ax if input_axes else None
[docs]
def boxplot(file, verify_dyn=False, input_axes=None, savefile=False, pltype='ACC'):
    '''
    Make a box plot of the data in a pickle file. It also
    writes a tex file with the description of the plot.

    Parameters
    ==========
    file: str
        Path to the pickle file
    verify_dyn: bool
        Whether to include the GCM-based forecast
    input_axes: obj (optional)
        Axes to draw into; a new figure is created when None
    savefile: bool
        Whether to write the LaTeX description (and, when a new figure
        was created, the PDF of the plot)
    pltype: str
        Type of plot to be made, 'ACC' (default) or 'RMS'

    Returns
    =======
    The axes when `input_axes` is given, otherwise None
    '''
    # Set latex directory
    texpath = os.path.expanduser("~")+'/CMCC Dropbox/Antonio Navarra/AI'
    texdir = zai.create_subdirectory(texpath, 'LATEX')
    # Read pickle file
    with open(file, "rb") as f:
        dat=pickle.load(f)
    # Subcase 203 forecasts start at month 3, the others at month 1
    if dat['subcase'] == 203:
        istart = 3
    else:
        istart = 1
    #dat is a dictionary with the following keys
    #dict_keys(['description', 'subcase_var', 'subcase', 'dataversion',
    # 'invar_dict', 'pred_dict', 'author', 'date', 'name', 'config', 'area', 'size_model', 'best_val_loss', 'model_config',
    # 'params', 'verification_field', 'input_fields', 'output_fields',
    # 'eof', 'GCMcorr', 'corr', 'pers'])
    # Create a box plot with customization
    if input_axes:
        ax = input_axes
    else:
        fig, ax = plt.subplots(figsize=(8, 6))
    # Extract relevant data from the dictionary
    if pltype == 'ACC':
        corrresult = dat['corr']
        persistence = dat['pers']
        if verify_dyn:
            corrdyn = dat['GCMcorr']
    elif pltype == 'RMS':
        corrresult = dat['rms']
        persistence = dat['rmsper']
        if verify_dyn:
            corrdyn = dat['GCMrms']
    else:
        raise ValueError("Invalid plot type. Choose 'ACC' or 'RMS'.")
    ntim, nens = corrresult.shape
    #Customizing box plot appearance
    # Create a box plot with different colors for each dataset
    box1 = ax.boxplot(corrresult[istart:,:].T, patch_artist=True,
                      boxprops=dict(facecolor='lightblue', color='blue'),
                      medianprops=dict(color='darkblue',linewidth=2),
                      whiskerprops=dict(color='blue'),
                      capprops=dict(color='blue'),
                      flierprops=dict(markeredgecolor='blue'),
                      )
    # change here the xlabel to np.range(istart,ntim)
    box2 = ax.boxplot(persistence[istart:,:].T, patch_artist=True,
                      boxprops=dict(facecolor='lightgreen', color='green'),
                      medianprops=dict(color='darkgreen',linewidth=2),
                      whiskerprops=dict(color='green'),
                      capprops=dict(color='green'),
                      flierprops=dict(markeredgecolor='green'),
                      )
    if verify_dyn:
        box3 = ax.boxplot(corrdyn[istart:,:].T, patch_artist=True, widths=0.25,
                          boxprops=dict(facecolor='lightpink', color='red'),
                          medianprops=dict(color='darkred',linewidth=2),
                          whiskerprops=dict(color='red'),
                          capprops=dict(color='red'),
                          flierprops=dict(markeredgecolor='red'),
                          )
    # Add a legend; only reference box3 when it was actually drawn
    # (previously this raised a NameError for verify_dyn=False)
    handles = [box1["boxes"][0], box2["boxes"][0]]
    legend_labels = ['DeepSeason', 'Persistence']
    if verify_dyn:
        handles.append(box3["boxes"][0])
        legend_labels.append('GCM')
    ax.legend(handles, legend_labels, loc='best', prop={'family': 'Futura'})
    # Define titleleft from the verification key in the dictionary
    titleleft = zai.transform_strings([dat['verification_field']])
    # Define titleright from the subcase key in the dictionary
    titleright = [str(dat['subcase'])]
    # Add title and labels
    ax.set_title(titleleft[0], loc='left', fontname='Futura')
    ax.set_xlabel('Forecast Month', fontname='Futura')
    ax.set_ylabel(pltype, fontname='Futura')
    ax.set_xlim([istart-0.5, ntim+0.5])
    if savefile:
        savefile = f"{texdir}/{file}_boxplot.tex"
        # Write to a tex file a latex piece describing the box plot
        tex_content = f"""
\\documentclass{{article}}
\\usepackage{{graphicx}}
\\begin{{document}}
\\section*{{Box Plot Description}}
Correlation box plot for the verification field {titleleft[0]} and for the {dat['area']} region.
Lightblue is used for the model forecast, lightgreen for the persistence forecast, and lightpink for the GCM-based forecast. Median and quartile values
are shown in dark blue, dark green, and dark red, respectively. The input fields used are {zai.transform_strings(dat['input_fields'])}.
In this case the EOF truncation used is {dat['eof']}, the size of the hidden space of the
model is {dat['params']['D_DIM']} and the discount parameter for the loss function is {dat['params']['discount']}.
\\begin{{figure}}[h!]
\\centering
\\includegraphics[width=0.8\\textwidth]{{{savefile.replace('.tex', '.pdf')}}}
\\caption{{Correlation box plot for the verification field {titleleft[0]} and for the {dat['area']} region.
Lightblue is used for the model forecast, lightgreen for the persistence forecast, and lightpink for the GCM-based forecast. Median and quartile values
are shown in dark blue, dark green, and dark red, respectively. The input fields used are {dat['input_fields']}
and the LAGS parameters are {dat['params']['lags']}.
In this case the EOF truncation used is {dat['eof']}, the size of the hidden space of the
model is {dat['params']['D_DIM']} and the discount parameter for the loss function is {dat['params']['discount']}.}}
\\end{{figure}}
\\end{{document}}
"""
        lab = file
        with open(f"{texdir}/{lab}_boxplot.tex", "w") as tex_file:
            tex_file.write(tex_content)
    # Show the plot
    if not input_axes:
        if savefile:
            # Fixed: the figure was saved into the .tex path, overwriting
            # nothing but producing a PDF with a .tex name; save the PDF
            # to the companion name referenced by the LaTeX description
            plt.savefig(savefile.replace('.tex', '.pdf'), format='pdf')
        plt.show()
    return ax if input_axes else None
[docs]
def write_var_excel(name='Y620', varname='best_val_loss', vcases=[1, 2, 3, 4, 5], subcases=[8], filename=None, datain='_corr_data', write_latex=False):
    """
    Read several pickle files, extract the requested variable `varname` from each,
    and write them into a DataFrame with columns = subcases and rows = vcases.
    Optionally write the DataFrame to an Excel file and also output a LaTeX table.

    Parameters
    ----------
    name : str
        Common prefix of the pickle file names (``{name}_V{vc}_{sc}{datain}``).
    varname : str
        Key to extract from each pickle dictionary.
    vcases : list of int
        Version cases (DataFrame rows).
    subcases : list of int
        Subcases (DataFrame columns, decoded to labels by `_decode_subcases`).
    filename : str, optional
        Excel output file name; nothing is written when None.
    datain : str
        Suffix of the pickle file names.
    write_latex : bool
        If True, also write a LaTeX table to the LATEX subdirectory.

    Returns
    -------
    None
    """
    texpath = os.path.expanduser("~") + '/CMCC Dropbox/Antonio Navarra/AI'
    texdir = zai.create_subdirectory(texpath, 'LATEX')
    # Decode subcase labels
    subcases_labels, _ = _decode_subcases(subcases)
    # Create a DataFrame with rows = vcases and columns = subcases_labels
    data_df = pd.DataFrame(index=vcases, columns=subcases_labels, dtype=float)
    # Fill DataFrame by reading each pickle file and extracting `varname`;
    # remember one extracted value to pick the LaTeX number format below
    # (fixed: `data_dict` was referenced after the loop and could be unbound)
    last_value = None
    for i, sc in enumerate(subcases):
        col_label = subcases_labels[i]
        for vc in vcases:
            file_path = f"{name}_V{vc}_{sc}{datain}"
            with open(file_path, "rb") as f:
                data_dict = pickle.load(f)
            last_value = data_dict[varname]
            data_df.loc[vc, col_label] = last_value
    # Write to Excel if requested
    if filename:
        data_df.to_excel(filename)
    # Optionally write LaTeX code for the table
    if write_latex:
        if isinstance(last_value, int):
            # do not round integers
            latex_code = data_df.to_latex(float_format='%.f')
        else:
            latex_code = data_df.to_latex(float_format='%.2f')
        latex_filename = filename.replace('.xlsx', '.tex') if filename else 'results.tex'
        with open(f"{texdir}/{latex_filename}", 'w') as tex_file:
            tex_file.write(latex_code)
    return
def _decode_subcases(subcases):
'''
Return a list of labels for the subcases'
'''
subcases_labels = []
for sc in subcases:
# make a list of if statements for subcases
if sc == 1:
subcases_labels.append('1')
elif sc == 2:
subcases_labels.append('DISCOUNT')
trial_labels = ['1.0', '0.9', '0.7', '0.3', '0.1']
elif sc == 3:
subcases_labels.append('HID')
trial_labels = ['1024', '512', '256', '128', '64']
elif sc == 4:
subcases_labels.append('LAGS')
trial_labels = ['[1]', '[1, 2]', '[1, 2, 3]', '[1, 2, 3, 4]', '[1, 2, 3, 4, 5, 6]']
elif sc == 5:
subcases_labels.append('5')
trial_labels = ['[10,25,15,25]', '[15,25,15,25]', '[25,25,25,25]', '[35,25,35,25]', '[45,25,45,25]']
elif sc == 6:
subcases_labels.append('LAYERS')
trial_labels = [1,2,4,8,16]
elif sc == 7:
subcases_labels.append('7')
elif sc == 8:
subcases_labels.append('VARS')
trial_labels = ['[SST]', '[SST, U850]', '[SST, U850,SP]', '[SST, U850,SP, OLR]', '[SST, OLR]']
elif sc == 9:
subcases_labels.append('HID')
trial_labels = ['1024', '512', '256', '128', '64']
elif sc == 10:
subcases_labels.append('LAGS')
trial_labels = ['[1]', '[1, 2]', '[1, 2, 3]', '[1, 2, 3, 4]', '[1, 2, 3, 4, 5, 6]']
elif sc == 11:
subcases_labels.append('LAYERS')
trial_labels = [1,2,4,8,16]
elif sc == 12:
subcases_labels.append('EOF')
trial_labels = ['[10,25]', '[15,25]', '[25,25]', '[45,25]', '[45,45]']
elif sc == 13:
subcases_labels.append('DISCOUNT')
trial_labels = ['[1.0]', '[0.9]', '[0.7]', '[0.5]', '[0.1]']
elif sc == 100:
subcases_labels.append('DISCOUNT')
trial_labels = ['EOF15']
elif sc == 21:
subcases_labels.append('VARS')
trial_labels = ['[T2M]', '[T850, T2M]', '[U850, T2M]', '[SP, T2M]', '[U850, V850, T850, SST, T2M]']
elif sc == 23:
subcases_labels.append('EOF')
trial_labels = ['[5,5]', '[6,6]', '[7,7]', '[8,8]', '[10,10]']
elif sc == 24:
subcases_labels.append('LAYERS')
trial_labels = [1,2,4,8,12]
elif sc == 25:
subcases_labels.append('DISCOUNT')
trial_labels = ['[1.0]', '[0.9]', '[0.7]', '[0.5]', '[0.1]']
else:
raise ValueError('Invalid subcase')
return subcases_labels, trial_labels
[docs]
def Two_Forecast_plots(F, V, field, level, cont, timestart, pictimes, figsize=(12, 8), world_projection='Atlantic', colorbars=True, maintitle='', picturelabels=None, arealatz=None, arealonz=None, centlon=0, labfile=None):
    '''
    Make verification plots for the field `field` and level `level` for the dataset `F` and `V`.

    Plots are made in two rows: forecasts (`F`) on top, verification (`V`) below,
    one column per entry of `pictimes`.

    Parameters
    ==========
    F: xarray dataset
        Dataset with the forecast to be analyzed
    V: xarray dataset
        Dataset with the verification field
    figsize: tuple
        Size of the figure
    world_projection: string
        World Projection for the plot
    maintitle: string
        Main label for the plot
    picturelabels: namedtuple
        List of labels for the pictures (row 0 for forecasts, row 1 for obs)
    field: string
        Field to be analyzed
    level: string
        Level to be analyzed
    timestart: string
        Starting date for the plotting (must be present in ``F['time']``)
    pictimes: list
        List of time indices to be plotted
    cont :
        * [cmin,cmax,cinc]       fixed increment from cmin to cmax step cinc
        * [ c1,c2, ..., cn]      fixed contours at [ c1,c2, ..., cn]
        * n                      n contours
        * []                     automatic choice
    colorbars: bool
        If True draw one colorbar per panel, otherwise a single shared one
    arealatz: numpy array
        Latitude limits of the field
    arealonz: numpy array
        Longitude limits of the field
    centlon: float
        Central longitude
    labfile: string
        Name of the file to be written

    Returns
    =======
    ax :
        Array of the axes of the figure

    Raises
    ======
    ValueError
        If `timestart` is not found in the forecast time axis.
    '''
    column = len(pictimes)
    maxV = len(V['time'])
    maxF = len(F['time'])
    title = maintitle + field if level is None else maintitle + field + level

    # Fail early with a clear message if the requested start date is absent
    # (the previous `int(np.where(...)[0])` raised an opaque TypeError/IndexError).
    match = np.where(F['time'] == pd.Timestamp(timestart))[0]
    if match.size == 0:
        raise ValueError(f'timestart {timestart} not found in the forecast time axis')

    fstyle = {'fontsize': 11, 'fontfamily': 'futura', 'fontweight': 'bold'}
    tstyle = {'fontsize': 12, 'fontfamily': 'futura', 'fontweight': 'bold'}
    sup_style = {'fontsize': 24, 'fontfamily': 'futura', 'fontweight': 'bold'}

    fig, ax, pro = zmap.init_figure(2, column, world_projection, constrained_layout=False, figsize=figsize)
    out = out2 = None
    for k, i in enumerate(pictimes):
        # --- Top row: forecast panels ---
        # `i < maxF` (not `<=`): valid time indices are 0 .. maxF-1.
        if i < maxF:
            label12 = str(F.time[i].data)[0:10]
            label13 = f'{picturelabels[0][k]}' if picturelabels is not None else f'Lead Month {i}'
            out = zmap.xmap(F.isel(time=i).unstack(), cont, pro, ax=ax[0, k], refline=None, c_format='{:4.2f}', data_cent_lon=centlon,
                            xlimit=(arealonz[0], arealonz[1]), ylimit=(arealatz[1], arealatz[0]),
                            custom=True, label_style=fstyle, title_style=tstyle,
                            title={'maintitle': 'Forecast', 'lefttitle': label12, 'righttitle': label13}, cmap='seismic', contour=False)
            out = _fix_labels(out, 'left')
            if colorbars:
                zmap.add_colorbar(fig, out.handles['filled'], ax[0, k], label_size=10, edges=True)
            # Fix properties of bounding box
            zmap.changebox(out.ax, 'all', linewidth=2, color='black', capstyle='round')
        else:
            ax[0, k].axis('off')
        # --- Bottom row: verification panels ---
        if i < maxV:
            # Derive the date label from V's own time axis; previously the label
            # from the forecast branch was reused and could be stale or unbound.
            obs_label = str(V.time[i].data)[0:10]
            label13 = f'{picturelabels[1][k]}' if picturelabels is not None else ''
            out2 = zmap.xmap(V.isel(time=i).unstack(), cont, pro, ax=ax[1, k], refline=None, c_format='{:4.2f}', data_cent_lon=centlon,
                             xlimit=(arealonz[0], arealonz[1]), ylimit=(arealatz[1], arealatz[0]),
                             custom=True, label_style=fstyle, title_style=tstyle,
                             title={'maintitle': 'Obs', 'lefttitle': obs_label, 'righttitle': label13}, cmap='seismic', contour=False)
            out2 = _fix_labels(out2, 'right')
            if colorbars:
                zmap.add_colorbar(fig, out2.handles['filled'], ax[1, k], label_size=10, edges=True)
            # Fix properties of bounding box
            zmap.changebox(out2.ax, 'all', linewidth=2, color='black', capstyle='round')
        else:
            ax[1, k].axis('off')
    if not colorbars and out is not None:
        # Single shared colorbar at the bottom of the figure, based on the last
        # forecast panel drawn (guarded: `out` may be None if nothing was plotted).
        cax = fig.add_axes([0.25, 0.05, 0.5, 0.015])
        cbar = fig.colorbar(out.handles['filled'], cax=cax, orientation='horizontal')
        # cbar.set_label('Colorbar Label')
    plt.suptitle(title, **sup_style)
    if labfile is not None:
        plt.savefig(labfile, orientation='landscape', format='pdf')
    plt.show()
    return ax
[docs]
def write_var_table(name='Y620', case='LAGS', subcases=[8,9], rows=[0,3,5,7,9,11], srwox=[1,2],
                    filename='TableTex', skill=False, stype='mean',mark='max'):
    '''
    Read several csv files, with "name" and several subcases, option for "skill cases",
    and write a LaTeX table from input "rows", add the single rows in "srwox"
    (special rows: persistence and GCM baselines) and write the table in "filename".

    Parameters
    ==========
    name: string
        Experiment name used in the csv file names
    case: string
        Name of the first (label) column in the csv files
    subcases: list
        Subcase numbers; one csv file is read per subcase
    rows: list
        Positional indices of the main rows to keep from each csv
    srwox: list
        Positional indices of the special rows (relabeled PERS / GCM)
    filename: string
        Name of the LaTeX file written inside the LATEX subdirectory
    skill: string
        Score-type prefix of the csv file names
    stype: string
        Statistic suffix of the csv file names (e.g. 'mean')
    mark: string
        'max' to highlight column maxima in red, anything else for minima
    '''
    texpath = os.path.expanduser("~")+'/CMCC Dropbox/Antonio Navarra/AI'
    texdir = zai.create_subdirectory(texpath, 'LATEX')
    # Read several csv files looping over subcases with the skill scores
    # and write the rows in the latex file.
    # Fix: `filename` was previously ignored and the output path hard-coded.
    with open(os.path.join(texdir, filename), 'w') as texfile:
        header_written = False
        # `sub` (not `sc`) to avoid shadowing the module alias `scipy.linalg as sc`
        for sub in subcases:
            # NOTE(review): `skill` is interpolated into the file name, so it is
            # expected to be a score-type string (e.g. 'ACC'), not a bool -- confirm.
            file = f'{skill}_Scores_{name}_{sub}_{stype}.csv'
            df = pd.read_csv(file)
            desired_headers = [case] + [f"M{i}" for i in range(13)]
            # Keep only the desired columns if they exist
            existing_cols = [c for c in desired_headers if c in df.columns]
            # Filter the rows we need
            df_filtered_main = df.iloc[rows][existing_cols]
            df_filtered_main = df_filtered_main.sort_values(by=case)
            df_filtered_special = df.iloc[srwox][existing_cols]
            # Relabel the special rows; guarded so shorter `srwox` lists don't crash
            for j, lab in enumerate(("PERS", "GCM")[:len(df_filtered_special)]):
                df_filtered_special.iloc[j, 0] = lab
            # Combine them
            df_combined = pd.concat([df_filtered_main, df_filtered_special])
            # Replace NaN with blank
            df_combined.fillna('', inplace=True)
            # Color the best value (max or min, per `mark`) in each M1..M12 column
            # in red and format all numeric entries with two decimals
            columns_to_color = [f"M{i}" for i in range(1,13) if f"M{i}" in df_combined.columns]
            for col in columns_to_color:
                data = pd.to_numeric(df_combined[col], errors='coerce')
                best_val = data.max(skipna=True) if mark == "max" else data.min(skipna=True)
                for idx in df_combined.index:
                    if pd.notnull(data.at[idx]) and data.at[idx] == best_val:
                        df_combined.at[idx, col] = f"\\textcolor{{red}}{{{data.at[idx]:.2f}}}"
                    elif pd.notnull(data.at[idx]):
                        df_combined.at[idx, col] = f"{data.at[idx]:.2f}"
            # Convert to LaTeX, ensuring we don't escape our LaTeX commands and center columns
            latex_table = df_combined.to_latex(float_format='%.2f',
                                               index=False,
                                               header=not header_written,
                                               na_rep='',
                                               escape=False,
                                               column_format='c' * len(df_combined.columns)
                                               )
            texfile.write(latex_table + "\n")
            header_written = True
    return
[docs]
def define_defaults_values(lstyle, tstyle):
    '''
    Merge user-supplied line and title style dictionaries with the defaults.

    User entries override the defaults; missing keys fall back to them.
    Passing ``None`` for either argument yields the pure default dict.
    '''
    # Library defaults for line drawing and title text
    base_line = {
        'linewidth': 1,
        'linestyle': '-',
    }
    base_title = {
        'fontsize': 20,
        'fontfamily': 'futura',
        'fontweight': 'bold',
    }
    merged_line = base_line if lstyle is None else {**base_line, **lstyle}
    merged_title = base_title if tstyle is None else {**base_title, **tstyle}
    return merged_line, merged_title
# Assuming Ft.time and Dt.time are arrays of datetime64, build an array containing only the dates
# present in both arrays, and also return the indices of those dates in each original array.
[docs]
def get_common_dates(Ft, Dt):
    """
    Get common dates between two xarray datasets and their indices.

    Parameters
    ----------
    Ft : xarray.DataArray
        First dataset with a time dimension.
    Dt : xarray.DataArray
        Second dataset with a time dimension.

    Returns
    -------
    common_dates : numpy.ndarray
        Sorted array of the dates present in both time axes.
    Ft_indices : numpy.ndarray
        Indices of the common dates in the first dataset.
    Dt_indices : numpy.ndarray
        Indices of the common dates in the second dataset.
    """
    times_f = Ft.time.values
    times_d = Dt.time.values
    # intersect1d returns the sorted unique values present in both axes
    shared = np.intersect1d(times_f, times_d)
    idx_f = np.flatnonzero(np.isin(times_f, shared))
    idx_d = np.flatnonzero(np.isin(times_d, shared))
    return shared, idx_f, idx_d
# Significance testing for correlation fields computed with xarray's .corr:
# the function below masks non-significant coefficients and also returns the p-values.
def calculate_significance(corr, n, alpha=0.05):
    """
    Calculate the significance of correlation coefficients
    using the Fisher Z-transformation and mask non-significant values.

    NOTE(review): a later definition with the same name in this module
    shadows this one at import time; kept for backward compatibility.

    Parameters
    ----------
    corr : xarray.DataArray
        Correlation coefficients.
    n : int
        Sample size (must be > 3 for the Fisher test).
    alpha : float, optional
        Significance level (default is 0.05).

    Returns
    -------
    significant_corr : xarray.DataArray
        Correlation coefficients with non-significant values masked as NaN.
    p_values : xarray.DataArray
        P-values for the correlation coefficients.

    Raises
    ------
    ValueError
        If the sample size is not greater than 3.
    """
    # se = 1/sqrt(n-3) is undefined for n <= 3; fail loudly instead of
    # silently producing NaNs.
    if np.any(np.asarray(n) <= 3):
        raise ValueError("Sample size n must be > 3 for a significance test")
    # Clip to the open interval (-1, 1) so the Fisher transform stays finite
    # when |corr| == 1 (the un-clipped log previously overflowed to inf).
    r = np.clip(corr, -0.999999, 0.999999)
    # Calculate Z-scores using Fisher Z-transformation
    z = 0.5 * np.log((1 + r) / (1 - r))
    # Calculate standard error
    se = 1 / np.sqrt(n - 3)
    # Calculate Z-scores for significance testing
    z_scores = z / se
    # Calculate p-values (two-tailed)
    p_values = 2 * (1 - stats.norm.cdf(np.abs(z_scores)))
    # Mask non-significant values (retain the ORIGINAL, un-clipped coefficients)
    significant_corr = corr.where(p_values < alpha)
    return significant_corr, p_values
import xarray as xr
import numpy as np
from scipy import stats
[docs]
def calculate_significance(corr: "xr.DataArray",
                           n,
                           alpha: float = 0.05,
                           two_tailed: bool = True,
                           method: str = "fisher"):
    """
    Significance test for correlation coefficients.

    Parameters
    ----------
    corr : xr.DataArray     Correlation coefficients (−1 <= r <= 1)
    n : int or xr.DataArray Sample size(s) (must be > 3)
    alpha: float            Significance level
    two_tailed : bool       Two‑tailed if True, else one‑tailed
    method : {"fisher","t"} Approximate (normal) or exact (t) test

    Returns
    -------
    sig_corr : same type as `corr`
        Original coefficients with non-significant values masked as NaN.
    p : array-like
        P-values for the correlation coefficients.

    Raises
    ------
    ValueError
        If any sample size is <= 3, or `method` is unknown.
    """
    # Works for scalar or array-valued n (clearer than the previous
    # xr.where(..., True, False).any() construction).
    if np.any(np.asarray(n) <= 3):
        raise ValueError("Sample size n must be > 3 for a significance test")
    # Clip r to avoid log/0 overflow in the Fisher transform
    r = corr.clip(-0.999_999, 0.999_999)
    if method == "fisher":
        z = np.arctanh(r)                       # same as 0.5*log((1+r)/(1-r))
        se = 1 / np.sqrt(n - 3)
        z_s = z / se
        p = stats.norm.sf(np.abs(z_s))
    elif method == "t":
        t = r * np.sqrt((n - 2) / (1 - r ** 2))
        p = stats.t.sf(np.abs(t), n - 2)
    else:
        raise ValueError("method must be 'fisher' or 't'")
    if two_tailed:
        p *= 2
    # Mask the ORIGINAL coefficients, not the clipped copy, so callers get
    # back exactly the values they passed in (previously r.where was returned).
    sig_corr = corr.where(p < alpha)
    return sig_corr, p