Source code for AIModels.AIClasses

'''
Classes for the AI project

Classes
=======

Field: 
    Class for fields

EarlyStopping: 
    Class for early stopping

TimeSeriesDataset: 
    Class for time series dataset

'''
# import os,sys
# import math
import numpy as np
import xarray as xr
import pandas as pd

import scipy.linalg as sc

import matplotlib.pyplot as plt
# import datetime
import time as tm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import transformers as tr
from sklearn.preprocessing import StandardScaler, MinMaxScaler

import zapata.computation as zcom
import zapata.data as zd
import zapata.lib as zlib
import zapata.mapping as zmap
import zapata.koopman as zkop

[docs] class Field(): ''' Class for fields Parameters ========== name: string Name of the field levels: string Level of the field area: string Area to be analyzed, possible values are * 'TROPIC': Tropics * 'GLOBAL': Global * 'PACTROPIC': Pacific Tropics * 'WORLD': World * 'EUROPE': Europe * 'NORTH_AMERICA': North America * 'NH-ML': Northern Hemisphere Mid-Latitudes mr: float Number of EOF retained dstart: string Start date for field dtend: string End date for field Attributes ========== name: string Name of the field levels: string Level of the field area: string Area of the field mr: float Number of EOF retained dstart: string Start date for field dtend: string End date for field ''' def __init__(self, name, levels,area, mr, dstart='1/1/1940', dtend='12/31/2022'): self.name = name self.levels = levels self.mr = mr self.area = area self.dstart = dstart self.dtend = dtend def __call__(self): return self.name,self.levels def __repr__(self): ''' Printing Information ''' print(f'Field: {self.name}, Levels: {self.levels}, Area: {self.area}') print(f'EOF Retained: {self.mr}, StartDate: {self.dstart}, EndDate: {self.dtend}') return '\n'
# Early stopping class
[docs] class EarlyStopping: ''' Class for early stopping Parameters ========== patience: int Number of epochs to wait before stopping verbose: boolean If True, print the epoch when stopping delta: float Minimum change in loss to be considered an improvement Attributes ========== patience: int Number of epochs to wait before stopping verbose: boolean If True, print the epoch when stopping delta: float Minimum change in loss to be considered an improvement counter: int Number of epochs since last improvement best_score: float Best loss score early_stop: boolean If True, stop the training ''' def __init__(self, patience=5, verbose=False, delta=0): self.patience = patience self.verbose = verbose self.counter = 0 self.best_score = None self.early_stop = False self.delta = delta def __call__(self, val_loss): if self.best_score is None: self.best_score = val_loss return False elif val_loss > self.best_score + self.delta: self.counter += 1 if self.counter >= self.patience: self.early_stop = True else: self.best_score = val_loss self.counter = 0 return self.early_stop
[docs] class TimeSeriesDataset(Dataset): ''' Class for time series dataset. Includes time feature for transformers Parameters ========== datasrc: numpy array Source data datatgt: numpy array Target data TIN: int Input time steps MIN: int Input variables size T: int Predictions time steps K: int Output variables size time_features: numpy array (optional) If not `None` contain Time Features shift: Overlap between source and target, for trasnformers `overlap = 0` for LSTM `overlap` should be TIN-T Attributes ========== datasrc: numpy array Source data datatgt: numpy array Target data time_features: numpy array Time features TIN: int Input time steps MIN: int Input variables T: int Output time steps K: int Output variables ''' def __init__(self, datasrc, datatgt, TIN, MIN, T, K, time_features=None): self.datasrc = datasrc self.datatgt = datatgt self.TIN = TIN self.MIN = MIN self.T = T self.K = K self.time_features = time_features def __len__(self): return len(self.datasrc) - self.TIN - self.T + 1 def __getitem__(self, idx): input_seq = self.datasrc[idx:idx+self.TIN, :self.MIN] target_seq = self.datatgt[idx+self.TIN:idx+self.TIN+self.T, :self.K] pasft = self.time_features[idx:idx+self.TIN,:] futft = self.time_features[idx+self.TIN:idx+self.TIN+self.T,:] return input_seq, target_seq, pasft, futft