'''
Utility Treatment Routines for ML Prediction Models
====================================================
This module contains a set of routines for preparing data to be fed into transformer models
used for the prediction of multivariate time series.
The companion classes are contained in the file `AIClasses.py` and are imported by this module.
Utilities
---------
'''
import os,sys,re,gc
import math
import numpy as np
# import numpy.linalg as lin
import xarray as xr
import pandas as pd
import scipy.linalg as sc
# import matplotlib.pyplot as plt
# import datetime
import torch
import torch.nn as nn
# import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
# import transformers as tr
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import zapata.computation as zcom
import zapata.data as zd
# import zapata.lib as zlib
import zapata.mapping as zmap
# import zapata.koopman as zkop
import AIModels.AIClasses as zaic
def func_name():
"""
:return: name of caller
"""
return sys._getframe(1).f_code.co_name
def copy_dict(data, strip_values=False, remove_keys=[]):
'''
Copy dictionary
Parameters
==========
data: dict
Dictionary to be copied
strip_values: boolean
If True, replace leaf (non-dict) values with empty lists
remove_keys: list
List of keys to be removed
Returns
=======
out: dict
Copied dictionary without the keys in `remove_keys`
'''
if isinstance(data, dict):
out = {}
for key, value in data.items():
if key not in remove_keys:
out[key] = copy_dict(value, strip_values=strip_values, remove_keys=remove_keys)
return out
else:
return [] if strip_values else data
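# Example (illustrative, not part of the module API): remove a key anywhere in a nested
# dictionary. The dictionary below is made up for the sake of the example.
#
#   cfg = {'model': {'layers': 4, 'debug': True}, 'debug': True}
#   copy_dict(cfg, remove_keys=['debug'])
#   # -> {'model': {'layers': 4}}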
def get_arealat_arealon(area):
'''
Obtain lat lon extent for various choices of areas
Parameters
==========
area: string
Area to be analyzed, possible values are
* 'TROPIC': Tropics (-35,35), [0,360]
* 'GLOBAL': Global (-60,60), [120,290]
* 'PACTROPIC': Pacific Tropics (-25,25), [180,290]
* 'WORLD': World (-70,70), [0,360]
* 'EUROPE': Europe (30,70), [-15,50]
* 'NORTH_AMERICA': North America (25,70), [200,310]
* 'NH-ML': Northern Hemisphere Mid-Latitudes (20,90), [0,360]
'''
# central pacific coordinates
if area == 'TROPIC':
arealat=(35,-35)
arealon=np.array([0, 360.])
elif area == 'GLOBAL':
arealat=(60,-60)
arealon=np.array([120,290.] )
elif area == 'PACTROPIC':
arealat=(25,-25)
arealon=np.array([180,290.])
elif area == 'WORLD':
arealat=(70,-70)
arealon=[0,360.]
elif area == 'NH-ML':
arealat=(90,20)
arealon=[0,360.]
elif area == 'EUROPE':
# Require central longitude to 0 lon
shift_Centlon = 0
arealat=(70,30)
arealon=[-15, 50.]
elif area == 'NORTH_AMERICA':
arealat=(70,25)
arealon=[200, 310.]
else:
    raise ValueError(f'Area {area} not defined')
arealat = np.array(arealat)
arealon = np.array(arealon)
return arealat, arealon
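# Example (illustrative): the returned latitude pair is ordered (north, south), matching the
# descending-latitude slicing used elsewhere in this module.
#
#   arealat, arealon = get_arealat_arealon('EUROPE')
#   # arealat -> array([70, 30]), arealon -> array([-15., 50.])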
def select_field(INX,outfield,verbose=False):
'''
Select field for `outfield` in dict
PARAMETERS
==========
INX: dict
Dictionary with the fields to be analyzed
outfield: string
Field to be analyzed
verbose: boolean
If True, print the field information
RETURNS
=======
out: numpy array
Field data
arealat: numpy array
Latitude limits of the field
arealon: numpy array
Longitude limits of the field
centlon: numpy array
Central longitude of the field
'''
if outfield not in INX:
    raise KeyError(f'Field {outfield} not found in INX')
if verbose:
    print(INX[outfield])
out = INX[outfield]['X']
arealat = np.array(INX[outfield]['arealat'])
arealon = np.array(INX[outfield]['arealon'])
centlon = np.array(INX[outfield]['centlon'])
return out,arealat,arealon,centlon
def select_field_key(INX,outfield,dataname):
'''
Select field for `outfield` in dict according to `dataname`
PARAMETERS
==========
INX: dict
Dictionary with the fields to be analyzed
outfield: string
Field to be analyzed
dataname: string
Key to be analyzed
RETURNS
=======
out: numpy array
Field data
'''
for i in INX.keys():
if i == outfield:
out = INX[i][dataname]
break
return out
def select_field_eof(INX,outfield):
'''
Select eof-related fields for `outfield` in dict
PARAMETERS
==========
INX: dict
Dictionary with the fields to be analyzed
outfield: string
Field to be analyzed
RETURNS
=======
U: numpy array
EOF modes
S: numpy array
Singular values
V: numpy array
EOF coefficients
'''
for i in INX.keys():
if i == outfield:
print(f'{func_name()} ---> Extracting EOF data for {INX[i]["field"]}')
U = INX[i]['udat']
V = INX[i]['vdat']
S = INX[i]['sdat']
break
return U,S,V
def select_area(area,ZZ):
'''
Select area for dataset ZZ
Parameters
==========
area: string
Area to be analyzed; EUROPE uses Greenwich-centered coordinates,
all other areas use Pacific-centered coordinates. Possible values are
* 'TROPIC': Tropics
* 'GLOBAL': Global
* 'PACTROPIC': Pacific Tropics
* 'WORLD': World
* 'EUROPE': Europe
* 'NORTH_AMERICA': North America
* 'NH-ML': Northern Hemisphere Mid-Latitudes
ZZ: xarray dataset
Returns
=======
Z: xarray dataset
'''
shift_Centlon = 180
arealat, arealon = get_arealat_arealon(area)
if area =='EUROPE':
print('Use Greenwich centered coordinates with centlon=0')
ZZ1 = ZZ.sel(lat=slice(70,30),lon=slice(340,360))
ZZ1 = ZZ1.assign_coords({'lon':ZZ1.lon -360})
Z = xr.concat((ZZ1, ZZ.sel(lat=slice(70,30),lon=slice(0,50))), dim='lon')
del ZZ1
else:
print('Use Pacific centered coordinates with centlon=180')
Z = ZZ.sel(lat=slice(arealat[0],arealat[1]), lon=slice(arealon[0],arealon[1]))
Z = Z.assign_coords(lon = Z.lon-180.)
arealon = arealon - 180
return Z
def make_matrix(S,SMOOTH,normalization,shift,**kwargs):
'''
Make matrix for SVD analysis
Parameters
==========
S: xarray dataset
Dataset with the field to be analyzed
SMOOTH: boolean
If True, smooth the data
normalization: string
Type of normalization to be applied
shift: string
Type of shift to be applied to select the data
dropnan: boolean (keyword argument)
Option for dropping NaN values from Xmat
True -- use the indexed `DropNaN` option of Xmat
False -- use the standard xarray `dropna` on Xmat
detrend: boolean (keyword argument)
If True, detrend the data
Returns
=======
X: Xmat object
Data matrix (space x time) prepared for EOF analysis
'''
# Default values
default_values = {'dropnan': False, 'detrend' : False}
mv = {**default_values, **kwargs}
print(mv)
#Weight cosine of latitude
WS=np.cos(S.lat*math.pi/180)
# Create Xmat Matrix
if mv['dropnan']:
X=zcom.Xmat(S*WS,dims=('lat','lon'),option='DropNaN')
print('Option DropNaN -- Shape of Xmat',X.A.shape)
else:
X=zcom.Xmat(S*WS,dims=('lat','lon'))
X.A = X.A.dropna(dim='z')
print('Option DropNaN False -- Shape of Xmat',X.A.shape)
# detrend data
if mv['detrend']:
print('make_matrix: -- Detrending data')
X.detrend(axis=1)
# Smooth data
# Put label on the last month of the window
if SMOOTH:
X.A = X.A.rolling(time=3,center=False,min_periods=3).mean().dropna(dim='time')
# Create anomalies
X.anom(option=normalization)
if shift == 'SCRAM':
X.A = X.A[:,np.random.permutation(np.arange(len(X.A.time)))].assign_coords(time=S.time[:-2])
return X
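# Hypothetical usage sketch (values are illustrative): `S` is assumed to be a (time, lat, lon)
# field as returned by `make_field`; the keyword options mirror the handling above.
#
#   X = make_matrix(S, SMOOTH=True, normalization='STD', shift='ERA5',
#                   dropnan=False, detrend=True)
#   X.A.shape   # -> (gridpoints, time) data matrix consumed by make_eof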
def make_eof(X, mr,eof_interval=None):
'''
Compute EOF analysis on data matrix
Parameters
==========
X: Xmat object
Data matrix in `Xmat` format
mr: int or float
Number of modes to be retained. If a float in (0,1), it is interpreted as the
target fraction of variance to retain; set to `math.inf` to keep all modes
eof_interval: list
List with the starting and ending dates for the EOF analysis
Returns
=======
mr: int
Number of modes actually retained, limited by the rank of the matrix
var_retained: float
Percentage of variance retained
udat: numpy array
Matrix of EOF modes
vdat: numpy array
Matrix of EOF coefficients
sdat: numpy array
Vector of singular values
'''
xdat = X.A.data
#Check eof interval
if eof_interval is not None:
eofstart = X.A.time.to_index().get_loc(eof_interval[0])
eofend = X.A.time.to_index().get_loc(eof_interval[1])
print(f'make_eof: -- EOF interval defined -- Using data from {eof_interval[0]} to {eof_interval[1]}')
print(f'make_eof: -- EOF interval defined -- Using data from {eofstart} to {eofend}')
udatx, sdatx, vdatx = sc.svd(xdat[:,eofstart:eofend], full_matrices=False,lapack_driver='gesvd')
else:
udatx, sdatx, vdatx = sc.svd(xdat[:,:], full_matrices=False)
# lmr = lin.matrix_rank(xdat)
lmr = matrix_rank_light(xdat,sdatx)
var = sdatx**2
if isinstance(mr, float) and 0 < mr < 1.0:
print(f'Target Variance {mr}')
total_variance = var.sum()
csum = np.cumsum(var)
new_mr = np.searchsorted(csum, mr * total_variance) + 1
mr = int(min(new_mr, lmr))
var_retained = csum[mr - 1] / total_variance
print(f' Number of SVD modes retained {mr}, rank of matrix {lmr}')
print(f' Variance Retained {var_retained:.2f} out of possible {len(var)} modes')
else:
mr = min(mr,lmr)
print(f' Number of SVD modes retained {mr}, rank of matrix {lmr}')
var_retained = sum(var[0:mr])/sum(var)
print(f'Variance Retained {var_retained:.2f} out of possible {len(var)} modes')
print(f' Condition number {max(sdatx)/min(sdatx)}')
#keep only mr modes
# S Field
udat=udatx[:,0:mr]
sdat=sdatx[0:mr]
# the columns of vdatx contain the coefficients of the field (standardized to unit variance)
# vhdat=vdatx[0:mr,:]
# coefficients, non-standardized, obtained directly from projection on the EOFs
print('Using non-standardized coefficients obtained directly from projection on the EOFs')
vdat=udat.T @ xdat
return mr, var_retained, udat,vdat, sdat
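# Illustrative sketch (plain numpy, not part of the module API) of how the variance-target
# branch above chooses the number of modes; the singular values are made up.
#
#   sdatx = np.array([10.0, 5.0, 2.0, 1.0])
#   var = sdatx**2                                           # [100, 25, 4, 1]
#   csum = np.cumsum(var)                                    # [100, 125, 129, 130]
#   mr = int(np.searchsorted(csum, 0.9 * var.sum()) + 1)     # -> 2 modes
#   var_retained = csum[mr - 1] / var.sum()                  # -> ~0.96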
def make_field(*args,**kwargs):
'''
Make field for analysis
Parameters
==========
area: string (positional argument)
Area to be analyzed, possible values are
* 'TROPIC': Tropics
* 'GLOBAL': Global
* 'PACTROPIC': Pacific Tropics
* 'WORLD': World
* 'EUROPE': Europe
* 'NORTH_AMERICA': North America
* 'NH-ML': Northern Hemisphere Mid-Latitudes
var: string (keyword argument)
Variable to be analyzed
level: string (keyword argument)
Level to be analyzed
period: string (keyword argument)
Period to be analyzed
version: string (keyword argument)
Version of the dataset
loc: string (keyword argument)
Location of the dataset
Returns
=======
Z: xarray dataset
Field to be analyzed
arealat: numpy array
Latitude limits of the field
arealon: numpy array
Longitude limits of the field
'''
# Default values
default_values = {'var':'SST','level':'SST','period':'ANN','version':'V5','loc':None}
merged_values = {**default_values, **kwargs}
verdata = merged_values['version']
match verdata:
case 'V5':
return make_field_V5(*args,**merged_values)
case 'HAD':
return make_field_HAD(*args,**merged_values)
case _:
print(f'Version {verdata} not defined')
return None
def make_field_V5(area,**kwargs):
'''
Make field for analysis
Parameters
==========
area: string
Area to be analyzed, possible values are
* 'TROPIC': Tropics
* 'GLOBAL': Global
* 'PACTROPIC': Pacific Tropics
* 'WORLD': World
* 'EUROPE': Europe
* 'NORTH_AMERICA': North America
* 'NH-ML': Northern Hemisphere Mid-Latitudes
var: string
Variable to be analyzed
level: string
Level to be analyzed
period: string
Period to be analyzed
version: string
Version of the dataset
loc: string
Location of the dataset
Returns
=======
Z: xarray dataset
Field to be analyzed
arealat: numpy array
Latitude limits of the field
arealon: numpy array
Longitude limits of the field
'''
period = kwargs['period']
version = kwargs['version']
var = kwargs['var']
level = kwargs['level']
loc = kwargs['loc']
shift_Centlon = 180
arealat, arealon = get_arealat_arealon(area)
dd=zd.in_data(var,level,period=period,epoch=version, loc = loc,averaging=False,verbose=True)
ZZ = dd[var.lower()]
if area =='EUROPE':
print('Use Greenwich centered coordinates with centlon=0')
ZZ1 = ZZ.sel(lat=slice(70,30),lon=slice(340,360))
ZZ1 = ZZ1.assign_coords({'lon':ZZ1.lon -360})
Z = xr.concat((ZZ1, ZZ.sel(lat=slice(70,30),lon=slice(0,50))), dim='lon')
shift_Centlon = 0
del ZZ1
else:
print('Use Pacific centered coordinates with centlon=180')
Z = ZZ.sel(lat=slice(arealat[0],arealat[1]), lon=slice(arealon[0],arealon[1]))
Z = Z.assign_coords(lon = Z.lon-180.)
arealon = arealon - 180
del dd,ZZ
print(f'Selecting field {var} for level {level} and area {area}')
return Z, arealat, arealon, shift_Centlon
def make_field_HAD(area,**kwargs):
'''
Make field for analysis for HADSST
Parameters
==========
area: string
Area to be analyzed, possible values are
* 'TROPIC': Tropics
* 'GLOBAL': Global
* 'PACTROPIC': Pacific Tropics
* 'WORLD': World
* 'EUROPE': Europe
* 'NORTH_AMERICA': North America
* 'NH-ML': Northern Hemisphere Mid-Latitudes
var: string
Variable to be analyzed
level: string
Level to be analyzed
period: string
Period to be analyzed
version: string
Version of the dataset
loc: string
Location of the dataset
Returns
=======
Z: xarray dataset
Field to be analyzed
arealat: numpy array
Latitude limits of the field
arealon: numpy array
Longitude limits of the field
'''
period = kwargs['period']
version = kwargs['version']
var = kwargs['var']
level = kwargs['level']
loc = kwargs['loc']
shift_Centlon = 180
arealat, arealon = get_arealat_arealon(area)
dstart = '1/1/1870'
dtend = '12/30/2020'
data_time = pd.date_range(start=dstart, end=dtend, freq='1MS')
# Hist1950
datadir = loc + '/DATA/HADSST/'
files = 'HadISST_sst.nc'
print(f'{datadir} \n SST file --> \t {files} \n')
ds = xr.open_dataset(datadir + files,use_cftime=None).drop_dims('nv').rename({'latitude': 'lat', 'longitude': 'lon'})
sst_Pac = zmap.adjust_data_centlon(ds.sst)
sst_Pac = sst_Pac.assign_coords({'lon':sst_Pac.lon+180.})
#
#Fix discontinuity at dateline
sst_Pac.loc[dict(lon=180.5)] = (sst_Pac.sel(lon=180-1.5)+sst_Pac.sel(lon=182.5))/2
sst_Pac.loc[dict(lon=180+1.5)] = (sst_Pac.sel(lon=180-1.5)+sst_Pac.sel(lon=182.5))/2
ZZ = sst_Pac.sel(lat=slice(arealat[0],arealat[1]),lon=slice(arealon[0],arealon[1])).assign_coords({'time': data_time})
ZZ.data[abs(ZZ.data) > 100] = np.nan
if area =='EUROPE':
print('Use Greenwich centered coordinates with centlon=0')
ZZ1 = ZZ.sel(lat=slice(70,30),lon=slice(340,360))
ZZ1 = ZZ1.assign_coords({'lon':ZZ1.lon -360})
Z = xr.concat((ZZ1, ZZ.sel(lat=slice(70,30),lon=slice(0,50))), dim='lon')
del ZZ1
else:
print('Use Pacific centered coordinates with centlon=180')
Z = ZZ.sel(lat=slice(arealat[0],arealat[1]), lon=slice(arealon[0],arealon[1]))
Z = Z.assign_coords(lon = Z.lon-180.)
arealon = arealon - 180
del ZZ
print(f'Selecting field {var} for level {level} and area {area}')
return Z, arealat, arealon, shift_Centlon
def normalize_training_data(params, vdat, period='train',scaler=None,feature_scale=1):
'''
Normalize training data
PARAMETERS
==========
params: dict
Dictionary with the parameters for the analysis
vdat: numpy array
Data to be normalized
period: string
Period to be analyzed
scaler: object
Pre-fitted scaler; if provided, it is used to transform the data instead of fitting a new one
feature_scale: float or array
Per-feature scale factors used by the `SymScaler` option
RETURNS
=======
datatr: numpy array
Normalized data
sstr: object
Scaler object
'''
ps = period + '_period_start'
pe = period + '_period_end'
if scaler is None:
if params['scaling'] == 'MaxMin':
sstr = MinMaxScaler(feature_range=(-1 ,1))
datatr = sstr.fit_transform(vdat.T[params[ps]:params[pe] + 1, :])
elif params['scaling'] == 'Standard':
sstr = StandardScaler()
datatr = sstr.fit_transform(vdat.T[params[ps]:params[pe] + 1, :])
elif params['scaling'] == 'Identity':
sstr = zaic.IdentityScaler()
datatr = sstr.fit_transform(vdat.T[params[ps]:params[pe] + 1, :])
elif params['scaling'] == 'SymScaler':
sstr = zaic.SymmetricFeatureScaler(feature_scales=feature_scale)
datatr = sstr.fit_transform(vdat.T[params[ps]:params[pe] + 1, :])
else:
raise ValueError(f'Wrong scaling defined for {params["scaling"]}')
else:
datatr = scaler.transform(vdat.T[params[ps]:params[pe]+1,:])
sstr = scaler
return datatr, sstr
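# Hypothetical usage sketch: `params` below is a minimal made-up parameter dictionary and
# `vdat` is assumed to be a (n_features, n_times) array with at least 120 time steps, as
# produced by `make_eof`. The training period fits a scaler, which is then reused for the
# test period.
#
#   params = {'scaling': 'Standard',
#             'train_period_start': 0, 'train_period_end': 99,
#             'test_period_start': 100, 'test_period_end': 119}
#   datatr, scaler = normalize_training_data(params, vdat, period='train')
#   datate, _ = normalize_training_data(params, vdat, period='test', scaler=scaler)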
def make_data(INX,params):
'''
Prepare data for analysis and concatenate as needed. Modify input `INX` dictionary
by adding values for `scaler` and `index` for each field. `scaler` is the scaler used, `index` is the
index of the data in the concatenated matrix INX.
The convention for indices is that they point to the actual date.
If Python ranges need to be defined, they must take into account the extra 1
at the end of the range.
PARAMETERS
==========
INX: dict
Dictionary with the fields to be analyzed
params: dict
Dictionary with the parameters for the analysis
RETURNS
=======
datain: numpy array
Matrix with the data to be analyzed
INX: dict
Input dictionary updated with the information from the data analysis
'''
indstart = 0
for num,i in enumerate(INX.keys()):
print(f'\nProcessing field {INX[i]["field"]} that is {INX[i]["datatype"]}')
_,_, vdat = select_field_eof(INX,i)
indend = indstart + INX[i]['mr']
print('vdat',vdat.shape)
#Choose scaling
if params['scaling'] == 'SymScaler':
print(f'Using Symmetric Scaler')
# Use the feature scale in the Symmetric Scaler
var12 = INX[i]['var_retained']
ss = INX[i]['sdat']
vartot = sum(ss**2)/var12
feat_scale = ss**2/vartot
else:
print(f'Using {params["scaling"]} scaling')
feat_scale = 1
#Normalize Training Data
tmp, sstr = normalize_training_data(params, vdat, period='train',feature_scale=feat_scale)
if num == 0:
datatr = tmp
else:
datatr = np.concatenate((datatr,tmp),axis=1)
INX[i]['scaler_tr'] = sstr
#Normalize Validation Data
tmp, ssva = normalize_training_data(params, vdat, period='val', feature_scale=feat_scale)
if num == 0:
datava = tmp
else:
datava = np.concatenate((datava,tmp),axis=1)
INX[i]['scaler_va'] = ssva
#Normalize Test Data
# Use the scaling of the training data
tmp, _ = normalize_training_data(params, vdat, period='test',scaler=sstr, feature_scale=feat_scale)
if num == 0:
datate = tmp
else:
datate = np.concatenate((datate,tmp),axis=1)
INX[i]['scaler_te'] = sstr
INX[i]['index'] = np.arange(indstart, indend)
print(f'Added field {i} to feature input data')
print(f'Index for field {INX[i]["field"]} are {indstart} and {indend}\n')
print(f'Using {params["scaling"]} scaling')
print(f'Using {INX[i]["mr"]} EOFs for {INX[i]["var_retained"]} variance retained')
indstart = indend
# Transform to Tensor
datatr = torch.tensor(datatr, device=params['device'], dtype=params['t_type'])
datava = torch.tensor(datava, device=params['device'], dtype=params['t_type'])
datate = torch.tensor(datate, device=params['device'], dtype=params['t_type'])
print(f'Training data shape {datatr.shape}')
print(f'Validation data shape {datava.shape}')
print(f'Testing data shape {datate.shape}')
return datatr, datava, datate, INX
def make_features(INX):
'''
Prepare features for analysis and compute features boundaries
PARAMETERS
==========
INX: dict
Dictionary with the fields to be analyzed
RETURNS
=======
num_features: int
Number of features
m_limit: list
List with the boundaries of the features
'''
num_features = 0
m_limit = []
for num,ii in enumerate(INX.keys()):
print(f'\nProcessing field {INX[ii]["field"]} that is {INX[ii]["datatype"]}')
num_features += INX[ii]['mr']
m_limit.append(INX[ii]['mr'])
print(f'Limit for {INX[ii]["field"]} is {INX[ii]["mr"]}')
print(f'Total number of features {num_features}')
return num_features, m_limit
def make_data_base(InputVars, period='ANN', version='V5', SMOOTH=False, normalization='STD', \
eof_interval = None, detrend=False, \
shift='ERA5', case=None, datatype='Source_data', location='DDIR'):
'''
Organize data variables in data base `INX`
PARAMETERS
==========
InputVars: list
List of variables to be analyzed
period: string
Period to be analyzed
version: string
Version of the dataset
SMOOTH: boolean
If True, smooth the data
normalization: string
Type of normalization to be applied
eof_interval: list
List with the starting and ending date for EOF analysis
detrend: boolean
If True, detrend the data
shift: string
Type of shift to be applied to select the data
case: string
Case to be analyzed
datatype: string
Type of data to be analyzed, either `Source_data` or `Target_data`
location: string
Data directory
RETURNS
=======
INX: dict
Dictionary with the fields to be analyzed. The dictionary is organized as follows:
INX = {'id1': {'case': case, 'area': invar.area, 'datatype': datatype, 'field': invar.name, 'level': inlevel,
'centlon': centlon, 'arealat': arealat, 'arealon': arealon,
'X': X, 'mr': mr, 'var_retained': varr, 'udat': udat, 'vdat': vdat, 'sdat': sdat}}
'''
INX = {}
for invar in InputVars:
for inlevel in invar.levels:
S,arealat,arealon, centlon = make_field(invar.area,var=invar.name,level=inlevel,period=period,version=version,loc=location)
X = make_matrix(S,SMOOTH,normalization,shift,dropnan=False,detrend=detrend)
if eof_interval is None:
print(f'No EOF interval defined -- Using all data, using {invar.mr} modes')
mr, varr, udat, vdat,sdat = make_eof(X,invar.mr)
else:
print(f'EOF interval defined -- Using data from {eof_interval[0]} to {eof_interval[1]}')
mr, varr, udat, vdat,sdat = make_eof(X,invar.mr,eof_interval=eof_interval)
dv = {'case':case,'area':invar.area, 'datatype': datatype,'field':invar.name,'level':inlevel, 'centlon':centlon,\
'arealat':arealat, 'arealon':arealon, 'X':X,'mr':mr,'var_retained':varr,'udat':udat,'vdat':vdat,'sdat':sdat}
if invar.dropX:
print(f'Dropping X matrix')
del dv['X']
gc.collect()
id = (invar.name+inlevel).upper()
INX.update( {id:dv})
print(f'Added field `{invar.name}` with identification `{id}` to data base')
del S
return INX
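# Hypothetical usage sketch: `make_data_base` expects objects exposing the attributes used
# above (`name`, `levels`, `area`, `mr`, `dropX`). A minimal stand-in could be a dataclass
# like the one below; names, values and the location string are illustrative only.
#
#   from dataclasses import dataclass
#
#   @dataclass
#   class InputVar:
#       name: str
#       levels: list
#       area: str
#       mr: float
#       dropX: bool = False
#
#   InputVars = [InputVar(name='SST', levels=['SST'], area='TROPIC', mr=0.9)]
#   INX = make_data_base(InputVars, period='ANN', version='V5', location='<data_root>')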
def init_weights(m):
''' Initialize weights with a uniform distribution'''
for name, param in m.named_parameters():
nn.init.uniform_(param.data, -0.08, 0.08)
def count_parameters(model):
''' Count Model Parameters'''
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def epoch_time(start_time, end_time):
'''Compute Execution time per epoch'''
elapsed_time = end_time - start_time
elapsed_mins = int(elapsed_time / 60)
elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
return elapsed_mins, elapsed_secs
def create_time_features(data_time, start,device):
'''
Create the past features for monthly means
PARAMETERS
==========
data_time: xarray dataset
Time data
start: datetime
Starting date
device: string
Device to be used for computation
RETURNS
=======
pasft: torch tensor
Tensor with the past features
Three levels:
* The first level is the month sequence
* The second level is the seasonal cycle
* The third is the year
'''
pasft = torch.zeros(len(data_time),3,device=device,dtype=torch.float32)
tt = data_time.year
#Months
pasft[:,0] = torch.tensor((data_time.month - 1) / 11.0 - 0.5)
#Seasons
pasft[:,1] = torch.tensor((data_time.quarter-2.5)/4)
#Years
# pasft[:,2] = torch.tensor((data_time.year-data_time.year[0])/(data_time.year[-1]-data_time.year[0]))
lent = data_time.shape[0]
pasft[:,2] = torch.tensor(np.arange(lent)/lent)
return pasft
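# Illustrative sketch, assuming `data_time` is a pandas DatetimeIndex of monthly dates;
# `device='cpu'` is used only for the example.
#
#   times = pd.date_range('2000-01-01', periods=6, freq='MS')
#   feats = create_time_features(times, times[0], 'cpu')
#   # feats[:, 0] -> month of year mapped to [-0.5, 0.5]
#   # feats[:, 1] -> quarter (seasonal cycle), in [-0.375, 0.375]
#   # feats[:, 2] -> fractional position within the record, in [0, 1)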
def rescale (params, PDX, out_train0, out_val0, out_test0,verbose=True):
'''
Rescale data to original values, according to the scaling choice stored in `PDX`
PARAMETERS
==========
params: dict
Dictionary with the parameters for the analysis
PDX: dict
Dictionary with the information for the fields to be analyzed
out_train0: numpy array
Training data
out_val0: numpy array
Validation data
out_test0: numpy array
Test data
verbose: boolean
If True, print diagnostic information
RETURNS
=======
out_train: numpy array
Rescaled training data
out_val: numpy array
Rescaled validation data
out_test: numpy array
Rescaled test data
true: numpy array
Original data
'''
Tpredict = params['Tpredict']
out_train = np.ones(out_train0.shape)
out_val = np.ones(out_val0.shape)
out_test = np.ones(out_test0.shape)
for num,i in enumerate(PDX.keys()):
if verbose:
print(f'\nProcessing field {PDX[i]["field"]} that is {PDX[i]["datatype"]}')
print(f'Number of modes retained {PDX[i]["mr"]}')
print(PDX[i]['vdat'].T.shape)
indloc = PDX[i]['index']
if num == 0:
for t in range(Tpredict):
out_train[:,t,indloc] = PDX[i]['scaler_tr'].inverse_transform(out_train0[:,t,indloc])
out_val[:,t,indloc] = PDX[i]['scaler_va'].inverse_transform(out_val0[:,t,indloc])
out_test[:,t,indloc] = PDX[i]['scaler_te'].inverse_transform(out_test0[:,t,indloc])
true = PDX[i]['vdat'].T
else:
if verbose:
print(out_train.shape)
for t in range(Tpredict):
out_train[:,t,indloc] = PDX[i]['scaler_tr'].inverse_transform(out_train0[:,t,indloc])
out_val[:,t,indloc] = PDX[i]['scaler_va'].inverse_transform(out_val0[:,t,indloc])
out_test[:,t,indloc] = PDX[i]['scaler_te'].inverse_transform(out_test0[:,t,indloc])
true = np.concatenate((true,PDX[i]['vdat'].T),axis=1)
return out_train,out_val,out_test,true
# def make_dyn_verification(ver_f, area, dyn_cases, dynddr, times, dyn_startdate, dyn_enddate, filever):
# '''
# Make dynamic verification data
# PARAMETERS
# ==========
# ver_f: numpy array
# Verification data
# area: string
# Area to be analyzed
# dyn_cases: list
# List of cases to be analyzed
# dynaddr: string
# Address of the dynamic data
# times: numpy array
# Time data
# dyn_startdate: string
# Starting date for the data
# dyn_enddate: string
# Ending date for the data
# filever: string
# Name of the file to be written
# RETURNS
# =======
# ver_d: numpy array
# Verification data in numpy format
# '''
# endindex = np.where(times == dyn_enddate)
# startindex = np.where(times == dyn_startdate)
# ndyn = int(endindex[0]-startindex[0]+1)
# # ngrid = INX['T2MT2M']['X'].A.shape[0]
# # dynddr = homedir + '/Dropbox (CMCC)/ERA5/SEASONAL_'+ ver_f
# print(f'Starting date for verification of dynamic data {dyn_startdate} and ending date {dyn_enddate}')
# print(f'Number of months for verification {ndyn}')
# print(f'Verification field {ver_f}')
# lead_length = 6
# match ver_f:
# case 'SST':
# vername_f = 'ssta'
# case 'T2M':
# vername_f = 't2a'
# case _ :
# raise ValueError(f'Verification field {ver_f} not defined')
# try:
# out = xr.open_dataset(filever).stack(z=('lat','lon'))[vername_f]
# print(f'Verification data loaded from file {filever}')
# print(f'Verification data shape {out.shape}, times {out.time[0].data} to {out.time[-1].data}')
# except:
# print(f'Verification data not found -- Creating file {filever}')
# # ver_d = np.ones((ndyn,6,ngrid))
# for cen in dyn_cases:
# for lead in range(lead_length):
# vfiles = [f"{dynddr}/{cen[ver_f]}_{cen['center']}_{k}_{lead+1}.nc" for k in cen['system']]
# for index, value in enumerate(vfiles):
# print(lead,index, value)
# ddd = xr.open_dataset(value).rename({'longitude':'lon','latitude':'lat'}).sel(time=slice(times[startindex[0]+lead][0],times[endindex[0]-6+lead][0])).mean(dim='number')#.persist()
# print(f'Variable selected ---> {vername_f}, size {ddd[vername_f].shape}')
# out_field = select_area(area,ddd).stack(z=('lat','lon')).transpose('time','z')[vername_f]
# if index == 0:
# result = out_field
# else:
# result = xr.concat([result,out_field],dim='time')
# del ddd,out_field
# print(f'\nAt lead {lead}, Result shape {result.shape}, times {result.time[0].data} to {result.time[-1].data}')
# if lead == 0:
# out_time = result.time
# out = xr.full_like(result, 1).expand_dims(dim={"lead":range(1,lead_length+1)}, axis=1).assign_coords(time=out_time).copy()
# out[:,lead,:] = result.assign_coords(time=out_time)
# print(f'\n Lead 0 --> {out.shape},{result.shape}')
# else:
# out[:,lead,:] = result.assign_coords(time=out_time)
# print(f'Finished analysis for lead {lead_length} {out.shape}')
# print(f'Verification data shape {out.shape}, times {out.time[0].data} to {out.time[-1].data}')
# out.unstack().to_netcdf(filever,engine='h5netcdf')
# return out
def matrix_rank_light(X,S):
'''
Compute the rank of a matrix using the singular values
PARAMETERS
==========
X: numpy array
Matrix to be analyzed
S: numpy array
Singular values of the matrix
RETURNS
=======
rank: int
Rank of the matrix
'''
rtol = max(X.shape[-2:]) * np.finfo(S.dtype).eps
tol = S.max(axis=-1, keepdims=True) * rtol
return np.count_nonzero(S > tol, axis=-1)
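# Illustrative sketch: the light rank estimate reuses singular values that are already
# available, so it agrees with numpy's estimate without a second SVD. The matrix is random.
#
#   A = np.random.default_rng(0).standard_normal((20, 5))
#   s = np.linalg.svd(A, compute_uv=False)
#   matrix_rank_light(A, s) == np.linalg.matrix_rank(A)   # -> True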
def make_fcst_array(startdate, enddate, leads, data):
'''
Make forecast array for verification.
The input uses the xarray DataArray format
of dimension (ntim, lead, z) where z is a stacked coordinate.
The output is an xarray DataArray with the time, lead, and z dimensions,
and the valid_time coordinate as a 2D array.
Lead 0 corresponds to the last element of the input sequence (the initial
condition); leads run from 0 to leads-1.
PARAMETERS
==========
startdate: string
Starting date for the forecast
enddate: string
Ending date for the forecast
leads: int
Number of leads, including the IC
data: xarray DataArray
DataArray with the forecast data
RETURNS
=======
out: xarray DataArray
Forecast array for verification
'''
if len(data.shape) == 4:
ntim, member, nlead, n = data.shape
print(f'Number of times {ntim}, number of members {member}, number of leads {nlead}, number of features {n}')
else:
ntim, nlead, n = data.shape
print(f'Number of times {ntim}, number of leads {nlead}, number of features {n}')
time = pd.date_range(startdate, enddate, freq="MS") # Start-of-month freq
print(time)
# Validation checks
if ntim != len(time):
raise ValueError(f"Number of times {ntim} does not match length of time {len(time)}")
if nlead != leads:
raise ValueError(f"Number of leads {nlead} does not match expected leads {leads}")
if not isinstance(data, xr.DataArray):
raise TypeError(f"Data must be an xarray DataArray, but got {type(data)}")
# Create the valid_time coordinate
valid_time = xr.DataArray(
np.array([
[t + pd.DateOffset(months=int(l)) for l in np.arange(nlead)]
for t in time
]),
dims=["time", "lead"],
coords={"time": time, "lead": np.arange(nlead)},
name="valid_time"
)
# Assign the valid_time coordinate to the data
data = data.assign_coords(valid_time=valid_time)
return data
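# Hypothetical usage sketch: attach a valid_time coordinate to a small forecast array.
# Dimension sizes and dates are illustrative only.
#
#   ntim, nlead, nfeat = 3, 4, 10
#   times = pd.date_range('2000-01-01', periods=ntim, freq='MS')
#   da = xr.DataArray(np.zeros((ntim, nlead, nfeat)),
#                     dims=['time', 'lead', 'feature'],
#                     coords={'time': times, 'lead': np.arange(nlead)})
#   fc = make_fcst_array('2000-01-01', '2000-03-01', nlead, da)
#   fc.valid_time.shape   # -> (3, 4), one valid date per (initial time, lead)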
def eof_to_grid(field, forecasts, startdate, enddate, params=None, INX=None, truncation=None):
'''
Transform from the EOF representation to the grid representation,
putting data in the format (Np, Tpredict, gridpoints) in stacked format.
Np is the total number of cases, given by Np = L - TIN - Tpredict + 1,
where L is the total length of the test period, TIN is the length of the input sequence,
Tpredict is the number of lead times, and gridpoints is the number of grid points in the field.
All fields start at month 1 of the prediction.
PARAMETERS
==========
field: string
Field to be analyzed
forecasts: numpy array
Forecasts data from the network
startdate: int
Starting date for the forecast (IC)
enddate: int
Ending date for the forecast
params: dict
Dictionary with the parameters for the analysis
INX: dict
Dictionary with the information for the fields to be analyzed
truncation: int
Number of modes to be retained in the observations
RETURNS
=======
The routines returns arrays in the format of xarray datasets
as (Np,lead,grid points) in stacked format
Fcst: xarray dataset
Forecast data in grid representation with dims (Np,lead,grid points)
Per: xarray dataset
Persistence data in grid representation with dims (Np,lead,grid points)
'''
if INX is None or params is None:
raise ValueError("INX or params not defined")
# Get the pattern EOF for the requested field
zudat = INX[field]['udat']
# Get the `X` matrix for the field as a template -- it contains the time information
XZ = INX[field]['X'].A
# Length of the prediction period
Tpredict = params['Tpredict']
# Length of input sequence
TIN = params['TIN']
# Length of prediction
T = params['T']
# Transform the observation to gridpoint
# if truncation is not None truncate to `truncation`
if truncation is not None:
xpro = XZ.data.T@zudat[:,:truncation]
Obs = XZ.copy(data=zudat[:,:truncation]@xpro.T)
print(f'Verification using truncated EOFs with {truncation} modes')
else:
Obs = XZ
print(f'Verification using entire field {Obs.shape}')
print(f'Obs from {Obs.time[0].data} to {Obs.time[-1].data}\n')
# Define number of cases (forecasts) to be analyzed
# Calculate the number of months
Np = (enddate.year - startdate.year) * 12 + (enddate.month - startdate.month) + 1
print(f'Number of cases {Np}')
tclock = Obs.time.data
# Align the verification data with the first forecast date
print(f'Obs from {startdate} to {enddate}\n')
print(f'Forecasts for IC at {startdate} to {enddate} and prediction time of {Tpredict} months')
flead = np.arange(0,Tpredict+1)
# Assemble arrays for verification on grid
# The Observation IC start one time step before the forecast
# because the forecast contains only from lead 1 to lead Tpredict
# The IC for the forecast (lead=0) is empty
Tmp = Obs.sel(time=slice(startdate.strftime('%Y-%m-%d'),enddate.strftime('%Y-%m-%d'))).expand_dims({"lead": Tpredict+1}).drop_vars('month').assign_coords({"lead": flead}).transpose('time','lead','z')
# Compute forecast and consider only the `Np` forecast that fit with the Tpredict period
hhh = np.einsum('ijk,lk',forecasts,INX[field]['udat'])
# Keep only the `Np` cases that are within the Tpredict period
kzz = np.ones((Np,Tpredict+1,len(Tmp.z)) )
kzz[:,0,:] = Tmp[:Np,0,:].data # kz[:,0,:]
for i in range(1,Tpredict+1):
kzz[:,i,:] = hhh[:Np,i-1,:]
Fcst = Tmp.copy(data=kzz,deep=True)
print(f'Forecast shape {Fcst.shape}, {Fcst.time.data[0]} to {Fcst.time.data[-1]}')
# Compute Persistence
# Start from the previous time step with respect to the first forecast
kzz = np.ones((Np,Tpredict+1,len(Tmp.z)) )
for i in range(Np):
for j in range(Tpredict+1):
kzz[i,j,:] = Tmp.data[i,0,:]
Per = Tmp.copy(data=kzz,deep=True)
print(f'Persistence shape {Per.shape}, {Per.time.data[0]} to {Per.time.data[-1]}')
return Fcst, Per, Obs.transpose('time','z')
# routine to advance a Timestamp of one month
def advance_months(ts, n=1):
'''
Advance or reduce a Timestamp by n months
PARAMETERS
==========
ts: pd.Timestamp or str
Timestamp to be modified. If str, it should be in the YYYY-MM-DD format.
n: int
Number of months to advance (positive) or reduce (negative)
RETURNS
=======
pd.Timestamp or str
Modified timestamp. If input was a string, output will be a string in the YYYY-MM-DD format.
'''
if isinstance(ts, str):
ts = pd.Timestamp(ts)
return (ts + pd.DateOffset(months=n)).strftime('%Y-%m-%d')
return ts + pd.DateOffset(months=n)
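# A couple of illustrative calls; string input gives string output in YYYY-MM-DD format:
#
#   advance_months('2000-12-01', 2)                 # -> '2001-02-01'
#   advance_months(pd.Timestamp('2000-01-15'), -1)  # -> Timestamp('1999-12-15')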
def project_dyn(data, INX, field, truncation=None):
'''
Project the dynamic data into the EOF space
PARAMETERS
==========
data: xarray dataset
Dynamic data
INX: dict
Dictionary with the information for the fields to be analyzed
field: string
Field to be analyzed
truncation: int
Number of modes to be retained in the observations
RETURNS
=======
out: xarray dataset
Projected data, in stacked format with NaN values
'''
# Get the pattern EOF for the requested field
zudat = INX[field]['udat']
# Get the `X` matrix for the field as a template -- it contains the time information
XZ = INX[field]['X'].A
# Transform the observation to gridpoint
# if truncation is not None truncate to `truncation`
if truncation is not None:
xzudat = XZ.isel(time=np.arange(truncation)).copy(data=zudat[:,:truncation]).unstack().stack(z=('lat','lon'))
xNEW = data.unstack().stack(z=('lat','lon'))
projNEW = np.einsum('ijk,lk',np.where(np.isnan(xNEW), 0, xNEW),np.where(np.isnan(xzudat), 0, xzudat))
tmpdata = np.einsum('ijk,kl',projNEW,xzudat)
recon_DATA = xNEW.copy(data=tmpdata)
print(f'Verification using truncated EOFs with {truncation} modes')
else:
recon_DATA = data.unstack().stack(z=('lat','lon'))
print(f'Verification using entire field {data.shape}')
return recon_DATA
def make_dyn_verification_new(ver_f, area, dyn_cases, dynddr, times, filever):
'''
Make dynamic verification data. Read all time levels for the GCM data
PARAMETERS
==========
ver_f: string
Name of the verification field ('SST', 'T2M' or 'Z500')
area: string
Area to be analyzed
dyn_cases: list
List of cases to be analyzed
dynddr: string
Directory containing the dynamic (GCM) forecast data
times: numpy array
Time data
filever: string
Name of the file to be written
RETURNS
=======
out: xarray DataArray
Verification data in stacked (time, lead, z) format
'''
lead_length = 6
match ver_f:
case 'SST':
vername_f = 'ssta'
case 'T2M':
vername_f = 't2a'
case 'Z500':
vername_f = 'za'
case _ :
raise ValueError(f'Verification field {ver_f} not defined')
try:
out = xr.open_dataset(filever).stack(z=('lat','lon'))[vername_f]
print(f'Verification data loaded from file {filever}')
print(f'Verification data shape {out.shape}, times {out.time[0].data} to {out.time[-1].data}')
except:
print(f'Verification data not found -- Creating file {filever}')
#
for cen in dyn_cases:
for lead in range(lead_length):
vfiles = [f"{dynddr}/{cen[ver_f]}_{cen['center']}_{k}_{lead+1}.nc" for k in cen['system']]
for index, value in enumerate(vfiles):
print(lead,index, value)
match ver_f:
case 'SST':
ddd = xr.open_dataset(value).rename({'longitude':'lon','latitude':'lat'}).mean(dim='number')
case 'T2M':
ddd = xr.open_dataset(value).rename({'longitude':'lon','latitude':'lat'}).mean(dim='number')
case 'Z500':
ddd = xr.open_dataset(value).rename({'longitude':'lon','latitude':'lat','forecast_reference_time':'time'}).mean(dim='number')
# Define new latitude and longitude coordinates
new_lat = np.arange(-90, 90.22, 0.25) # Adjust based on dataset bounds
new_lon = np.arange(0, 360, 0.25)
# Perform interpolation
# ddd = ddd.interp(lat=new_lat, lon=new_lon, method="linear").isel(forecastMonth=0,pressure_level=0).drop_vars({'forecastMonth','pressure_level'})
ddd = ddd.sel(lat=slice(None, None, -1)).interp(lat=new_lat, lon=new_lon, method="cubic",kwargs={"fill_value": None}).sel(lat=slice(None, None, -1)).drop_vars({'forecastMonth','pressure_level'}).squeeze()
case _ :
raise ValueError(f'Verification field {ver_f} not defined')
print(f'Variable selected ---> {vername_f}, size {ddd[vername_f].shape}')
out_field = select_area(area,ddd).stack(z=('lat','lon')).transpose('time','z')[vername_f]
if index == 0:
result = out_field
else:
result = xr.concat([result,out_field],dim='time')
del ddd,out_field
print(f'\nAt lead {lead}, Result shape {result.shape}, times {result.time[0].data} to {result.time[-1].data}')
if lead == 0:
out_time = result.time
out = xr.full_like(result, 1).expand_dims(dim={"lead":range(1,lead_length+1)}, axis=1).assign_coords(time=out_time).copy()
out[:,lead,:] = result.assign_coords(time=out_time)
print(f'Lead 0 --> {out.shape},{result.shape}')
else:
out[:,lead,:] = result.assign_coords(time=out_time)
print(f'Finished analysis for lead {lead_length} {out.shape}')
print(f'Verification data shape {out.shape}, times {out.time[0].data} to {out.time[-1].data}')
out.unstack().to_netcdf(filever,engine='h5netcdf')
return out
def compute_increments(tensor, axis=0):
'''
Take the difference of the torch tensor along the specified axis and output the initial value
PARAMETERS
==========
tensor: torch tensor
Tensor to be analyzed
axis: int
Axis along which to compute the differences
RETURNS
=======
diff: torch tensor
Tensor with the differences along the specified axis
init_value: torch tensor
Initial value along the specified axis
'''
print(f'Computing differences for input tensor along axis {axis}')
# Initialize the difference tensor with the same shape as the input tensor
diff = torch.zeros_like(tensor)
# Compute the differences along the specified axis
diff = torch.diff(tensor, dim=axis, prepend=tensor.select(axis, 0).unsqueeze(axis))
# Extract the initial value along the specified axis
init_value = tensor.select(axis, 0)
return diff, init_value
def cumsum_with_init(differences, init_value):
'''
Compute cumulative sum with initial values for a PyTorch tensor.
This function calculates the cumulative sum of a tensor containing differences,
while incorporating specific initial values. The input tensor `differences`
has dimensions (time, nlead, neof), and the initial values tensor `init_value`
has dimensions (neof). The initial values are added to the first time step
of each lead slice, and the cumulative sum is computed along the time dimension.
Parameters
----------
differences : torch.Tensor
Tensor containing the differences with dimensions (time, nlead, neof).
init_value : torch.Tensor
1D tensor representing the initial values, with size matching the last
dimension (neof) of `differences`.
Returns
-------
torch.Tensor
Tensor after computing the cumulative sum with initial values,
having the same shape as the input `differences`.
Raises
------
ValueError
If the size of `init_value` does not match the last dimension of `differences`.
'''
# Ensure init_value matches the last dimension of differences
if init_value.size(0) != differences.size(-1):
raise ValueError("init_value must match the size of the last dimension in the input tensor.")
# Create a copy of the differences tensor to avoid modifying the input
result = differences.clone()
# Add init_value to the first time level of each lead slice
result[0] += init_value
# Compute the cumulative sum along the time dimension (axis=0)
result = torch.cumsum(result, dim=0)
return result
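# Illustrative sketch: with zero increments, every time step of the result equals the
# initial values, since only `init_value` is added at the first step before accumulation.
# Shapes follow the docstring convention (time, nlead, neof) and (neof,).
#
#   diffs = torch.zeros(4, 2, 3)                 # (time, nlead, neof) increments
#   init = torch.tensor([1.0, 2.0, 3.0])         # one initial value per EOF
#   out = cumsum_with_init(diffs, init)
#   # out[t, l, :] == [1., 2., 3.] for every t and l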
def select_fcst(IC, my_data):
'''
Select the forecast data for the given initial condition IC from the xarray my_data dataset.
PARAMETERS
==========
IC: string
Initial condition for the forecast
my_data: xarray dataset
Forecast data
RETURNS
=======
ds_single_init: xarray dataset
Forecast data for verification with coordinate "time" as the valid_time
'''
init_time_sel = IC
ds_single_init = my_data.sel(time=init_time_sel, drop=False)
# Turn valid_time into a 1D coordinate named "time"
ds_single_init = ds_single_init.assign_coords(time=ds_single_init.valid_time)
ds_single_init = ds_single_init.swap_dims({"lead": "time"})
ds_single_init = ds_single_init.drop_vars("valid_time")
# Debugging: Check coordinate presence
# print("Coordinates in my_data:", list(my_data.coords))
# print("Coordinates in ds_single_init before assignment:", list(ds_single_init.coords))
# Check if 'z' is a MultiIndex in my_data
z_index = my_data.indexes["z"] # This is the actual pd.MultiIndex
if isinstance(z_index, pd.MultiIndex):
# Drop existing lat, lon if they exist
ds_single_init = ds_single_init.drop_vars(["lat", "lon"], errors="ignore")
# Extract lat and lon from the MultiIndex, then assign as coordinates
lat_vals = z_index.get_level_values("lat")
lon_vals = z_index.get_level_values("lon")
ds_single_init = ds_single_init.assign_coords(
lat=("z", lat_vals),
lon=("z", lon_vals)
)
# Rebuild the MultiIndex for 'z' with original level names
ds_single_init = ds_single_init.set_index(z=z_index.names)
return ds_single_init.drop_vars("lead", errors="ignore")
def variance_features(INX):
'''
Return the variance of the features
PARAMETERS
==========
INX: dict
Dictionary with the information for the fields to be analyzed
RETURNS
=======
ssvar: numpy array
Variance of the features
'''
ssvar = []
for i, feature in enumerate(INX.keys()):
tmp = INX[feature]['sdat']
print(f'Processing field {feature} with shape {tmp.shape}')
ss = tmp**2 / sum(tmp**2)
ssvar.append(ss)
ssvar = np.concatenate(ssvar, axis=0)
return ssvar
def create_subdirectory(parent_dir, subdirectory_name):
"""
Create a subdirectory within the specified parent directory if it does not exist.
Parameters
----------
parent_dir : str
The path to the parent directory.
subdirectory_name : str
The name of the subdirectory to create.
Returns
-------
str
The full path of the created or existing subdirectory.
"""
subdirectory_path = os.path.join(parent_dir, subdirectory_name)
# Check if the directory exists, if not, create it
if not os.path.exists(subdirectory_path):
os.makedirs(subdirectory_path)
print(f"Directory created: {subdirectory_path}")
else:
print(f"Directory already exists: {subdirectory_path}")
return subdirectory_path
def _set_directory(file_path):
'''
Set the directory for `file_path` and create it if it does not exist.
Parameters
==========
file_path: string
Path to the data directory, relative to the user home directory
Returns
=======
target: string
Full path to the created (or existing) directory
'''
homedir = os.path.expanduser("~")
target = homedir+file_path
try:
os.makedirs(target)
print(f'Creating directory {target}')
except FileExistsError:
print(f'Directory {target} already exists')
return target
def eof_to_grid_new(field, forecasts, startdate, enddate, params=None, INX=None, truncation=None):
'''
Transform from the EOF representation to the grid representation.
The routine now supports forecasts arrays with an extra ensemble dimension.
For the original case, forecasts has shape (Np, T, n_eofs) and the
output forecast dataset has dims (Np, lead, gridpoints). Now, if forecasts
has shape (Np, K, T, n_eofs), the output forecast dataset will have dims
(Np, member, lead, gridpoints).
PARAMETERS
==========
field: string
Field to be analyzed
forecasts: numpy array
Forecast data from the network. Expected shape is either (Np, T, n_eofs)
or (Np, K, T, n_eofs), where K is the ensemble size.
startdate: datetime-like
Starting date for the forecast (initial condition)
enddate: datetime-like
Ending date for the forecast
params: dict
Dictionary with the parameters for the analysis. Must include 'Tpredict', 'TIN', and 'T'
INX: dict
Dictionary with information for the fields to be analyzed (including the EOF patterns)
truncation: int, optional
Number of modes to be retained in the observations
RETURNS
========
Fcst: xarray DataArray
Forecast data in grid representation.
* If forecasts is 3D, dims are (time, lead, z).
* If forecasts is 4D, dims are (time, member, lead, z).
Per: xarray DataArray
Persistence data in grid representation, with the same dims as Fcst.
Obs: xarray DataArray
Observation data in grid representation, with dims (time, z)
'''
if INX is None or params is None:
raise ValueError("INX or params not defined")
# Get the EOF patterns and template X matrix
zudat = INX[field]['udat']
XZ = INX[field]['X'].A
Tpredict = params['Tpredict']
TIN = params['TIN']
T = params['T']
# Transform the observation to gridpoint
if truncation is not None:
xpro = XZ.data.T @ zudat[:, :truncation]
Obs = XZ.copy(data=zudat[:, :truncation] @ xpro.T)
print(f'Verification using truncated EOFs with {truncation} modes')
else:
Obs = XZ
print(f'Verification using entire field {Obs.shape}')
print(f'Obs from {Obs.time[0].data} to {Obs.time[-1].data}\n')
# Determine number of forecast cases
Np = (enddate.year - startdate.year) * 12 + (enddate.month - startdate.month) + 1
print(f'Number of cases {Np}')
print(f'Obs from {startdate} to {enddate}\n')
print(f'Forecasts for IC at {startdate} to {enddate} and prediction time of {Tpredict} months')
flead = np.arange(0, Tpredict+1)
# Build the observation template (no ensemble dimension)
Tmp = Obs.sel(time=slice(startdate.strftime('%Y-%m-%d'),
enddate.strftime('%Y-%m-%d'))
).expand_dims({"lead": Tpredict+1}).drop_vars('month'
).assign_coords({"lead": flead}).transpose('time', 'lead', 'z')
# Check the forecasts dimensions and process accordingly
if forecasts.ndim == 4:
# forecasts shape: (Np, K, T, n_eofs)
ensemble_size = forecasts.shape[1]
# Reconstruct grid forecasts for each ensemble member:
hhh = np.einsum('nktj,lj->nktl', forecasts, INX[field]['udat'])
print('Allocate array for forecasts with shape (Np, member, Tpredict+1, gridpoints)')
kzz = np.ones((Np, ensemble_size, Tpredict+1, len(Tmp.z)))
# For lead 0, use the observation for all ensemble members
obs_lead0 = Tmp[:Np, 0, :].data # shape (Np, gridpoints)
kzz[:, :, 0, :] = np.repeat(obs_lead0[:, np.newaxis, :], ensemble_size, axis=1)
# For leads 1 to Tpredict, fill with reconstructed forecasts
for lead in range(1, Tpredict+1):
kzz[:, :, lead, :] = hhh[:Np, :, lead-1, :]
# Create an xarray DataArray with dims (time, member, lead, z)
Fcst = xr.DataArray(
kzz,
dims=["time", "member", "lead", "z"],
coords={"time": Tmp.time.data[:Np],
"member": np.arange(ensemble_size),
"lead": flead,
"z": Tmp.z.data}
)
elif forecasts.ndim == 3:
# Original behavior for forecasts shape: (Np, T, n_eofs)
hhh = np.einsum('ijk,lk->ijl', forecasts, INX[field]['udat'])
kzz = np.ones((Np, Tpredict+1, len(Tmp.z)))
kzz[:, 0, :] = Tmp[:Np, 0, :].data
for lead in range(1, Tpredict+1):
kzz[:, lead, :] = hhh[:Np, lead-1, :]
Fcst = xr.DataArray(
kzz,
dims=["time", "lead", "z"],
coords={"time": Tmp.time.data[:Np],
"lead": flead,
"z": Tmp.z.data}
)
else:
raise ValueError("Forecasts array must be 3D or 4D.")
print(f'Forecast shape {Fcst.shape}, {Fcst.time.data[0]} to {Fcst.time.data[-1]}')
# Compute persistence by repeating the initial condition (lead=0) across all leads
if Fcst.ndim == 4:
persistence_data = np.repeat(Fcst.data[:, :, 0:1, :], Tpredict+1, axis=2)
Per = xr.DataArray(
persistence_data,
dims=Fcst.dims,
coords=Fcst.coords
)
else:
persistence_data = np.repeat(Fcst.data[:, 0:1, :], Tpredict+1, axis=1)
Per = xr.DataArray(
persistence_data,
dims=Fcst.dims,
coords=Fcst.coords
)
return Fcst, Per, Obs.transpose('time', 'z')
# Assuming Ft.time and Dt.time are arrays of datetime64, build an array containing only the
# dates present in both, and also return the indices of those dates in the original arrays
def get_common_dates(Ft, Dt):
"""
Get common dates between two xarray datasets and their indices.
Parameters
----------
Ft : xarray.DataArray
First dataset with a time dimension.
Dt : xarray.DataArray
Second dataset with a time dimension.
Returns
-------
common_dates : numpy.ndarray
Array of common dates.
Ft_indices : numpy.ndarray
Indices of common dates in the first dataset.
Dt_indices : numpy.ndarray
Indices of common dates in the second dataset.
"""
common_dates = np.intersect1d(Ft.time.values, Dt.time.values)
Ft_indices = np.where(np.isin(Ft.time.values, common_dates))[0]
Dt_indices = np.where(np.isin(Dt.time.values, common_dates))[0]
return common_dates, Ft_indices, Dt_indices
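# Illustrative sketch: with two DataArrays whose monthly time axes overlap, the common
# dates and their positional indices in each array are returned.
#
#   t1 = pd.date_range('2000-01-01', periods=6, freq='MS')
#   t2 = pd.date_range('2000-04-01', periods=6, freq='MS')
#   A = xr.DataArray(np.zeros(6), dims='time', coords={'time': t1})
#   B = xr.DataArray(np.zeros(6), dims='time', coords={'time': t2})
#   dates, ia, ib = get_common_dates(A, B)
#   # dates -> 2000-04 .. 2000-06, ia -> [3, 4, 5], ib -> [0, 1, 2]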