Source code for zapata.lib

# import ipywidgets as widgets
import os
from IPython.display import Javascript

from typing import Sequence, Mapping, Iterable, Dict, Optional, Callable
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp


import numpy as np
import scipy.linalg as sc
import pandas as pd
import xarray as xr

import sys
import platform
import pkg_resources

if os.environ.get("CONDA_DEFAULT_ENV") == 'BOOKOORT' or os.environ.get("CONDA_DEFAULT_ENV") == 'ERA5':   
    import ddsapi
else:
    print('ddsapi not available, not in BOOKOORT or ERA5 conda environment')

[docs] def name_notebook(newname):
    ''' Change the title of the JupyterLab browser tab to `newname` '''
    # Return the Javascript object so the notebook can display and execute it
    return Javascript('document.title="{}"'.format(newname))
[docs] def get_values_from_dict(input_dict, keys):
    '''
    Get values from dictionary `input_dict` for keys in `keys`

    Parameters
    ==========
    input_dict:
        Dictionary
    keys:
        List of keys

    Returns
    =======
    List of values
    '''
    return [input_dict[key] for key in keys if key in input_dict]
[docs] def remove_values_from_list(the_list, val):
    """ Remove value `val` from list `the_list` """
    return [value for value in the_list if value != val]
[docs] def makename(var, lev, yy, mm, dd):
    """ Utility to create names for ERA5 files. """
    return var + "_" + lev + "_" + str(yy) + "_" + str(mm) + "_" + str(dd) + ".grb"
[docs] def makemm(var, lev, yy, mm):
    """ Utility to create names for ERA5 numpy files """
    work1 = var + lev + '/'
    work2 = var + "_" + lev + "_" + str(yy) + "_" + str(mm) + '_MM' + ".npy"
    return work1 + work2
[docs] def makefilename(dir, var, lev, yy, mm, ext):
    """ Generalized file name creation """
    work1 = dir + '/'
    work2 = var + "_" + lev + "_" + str(yy) + "_" + str(mm) + "." + ext
    return work1 + work2
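# Illustrative sketch (not part of the library): how the three name helpers
# compose file names; the variable, level and dates below are made up.
def _example_filenames():
    print(makename('T2M', '128', 2000, 1, 15))                # 'T2M_128_2000_1_15.grb'
    print(makemm('T2M', '128', 2000, 1))                      # 'T2M128/T2M_128_2000_1_MM.npy'
    print(makefilename('data', 'T2M', '128', 2000, 1, 'nc'))  # 'data/T2M_128_2000_1.nc'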
[docs] def adddir(name, dir):
    """ Prepend directory `dir` to file `name`, replacing its extension with `.npy` """
    return dir + '/' + name.split('.')[0] + '.npy'
[docs] def makedir(fndir):
    """ Create directory `fndir` """
    try:
        # Create target directory
        os.mkdir(fndir)
        print("Directory ", fndir, " Created ")
    except FileExistsError:
        print("Directory ", fndir, " already exists")
[docs] def movefile(oldfile, newdir):
    """ Move file `oldfile` to directory `newdir` """
    # Move file `oldfile` to directory `newdir`, with error control
    try:
        command = 'mv ' + oldfile + ' ' + newdir
        print(command)
        os.system(command)
    except:
        print('Error in Moving Data Files... ', oldfile, ' to new directory .....', newdir)
[docs] def copyfile(oldfile, newdir):
    """ Copy file `oldfile` to directory `newdir` """
    # Copy file `oldfile` to directory `newdir`, with error control
    try:
        command = 'cp ' + oldfile + ' ' + newdir
        print(command)
        os.system(command)
    except:
        print('Error in Copying Data Files... ', oldfile, ' to new directory .....', newdir)
[docs] def chop(a, epsilon=1e-10):
    """ Eliminate a negligibly small imaginary part, converting to real """
    check = sc.norm(a.imag)
    if check < epsilon:
        out = a.real
    else:
        out = a
    return out
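# Illustrative sketch (not part of the library): `chop` drops a numerically
# negligible imaginary part but leaves genuinely complex values untouched.
def _example_chop():
    print(chop(np.array([1.0 + 1e-15j, 2.0])))   # returned as a real array [1., 2.]
    print(chop(np.array([1.0 + 0.5j])))          # returned unchanged (complex)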
[docs] def year2date(years, i):
    """
    Transform index `i` into a date string 'MON YYYY'.
    Rounding requires the small shift.
    Years are obtained from np.arange(1979, 2018, 1/12)
    """
    mon = ['JAN','FEB','MAR','APR','MAY','JUN','JUL','AUG','SEP','OCT','NOV','DEC']
    y = str(int(years[i] + 0.001))
    m = np.mod(int(round((years[i] - int(years[i])) * 12)), 12)
    date = mon[m] + ' ' + y
    return date
[docs] def date2year(years, date):
    """
    Transform a date string like 'JAN 1989' into index `i`.
    Years are from np.arange(1979, 2018, 1/12)
    """
    mon = ['JAN','FEB','MAR','APR','MAY','JUN','JUL','AUG','SEP','OCT','NOV','DEC']
    y = float(date[4:8])
    m = mon.index(str(date[0:3]))
    index = (y - 1979) * 12 + m
    return int(index)
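# Illustrative sketch (not part of the library): round-tripping between the
# monthly time axis and 'MON YYYY' labels with year2date/date2year.
def _example_dates():
    years = np.arange(1979, 2018, 1/12)
    label = year2date(years, 121)      # 'FEB 1989'
    index = date2year(years, label)    # 121
    return label, index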
[docs] def putna(left, right, xar, scalar=None):
    '''
    Set to NaN the values of `xar` lying in the interval (`left`, `right`)

    Parameters
    ==========
    left, right:
        Extremes of the interval where the values are set to NaN
    xar :
        Xarray
    scalar :
        If set, all entries outside the interval are set to `scalar`

    Returns
    =======
    Modified array
    '''
    # Use an explicit None check so that scalar = 0 is accepted
    if scalar is not None:
        out = scalar
    else:
        out = xar
    return xr.where((xar < right) & (xar > left), np.nan, out)
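# Illustrative sketch (not part of the library): masking a band of values
# with `putna`; the DataArray below is made up.
def _example_putna():
    da = xr.DataArray(np.array([0.5, 1.5, 2.5, 3.5]), dims='x')
    masked = putna(1.0, 3.0, da)                # 1.5 and 2.5 become NaN
    flagged = putna(1.0, 3.0, da, scalar=0.0)   # values outside (1, 3) become 0.0
    return masked, flagged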
[docs] def go_to(dir):
    '''
    Set working directory

    Parameters
    ==========
    dir:
        Target directory relative to the user's home directory

    Returns
    =======
    wkdir:
        The new working directory
    '''
    homedir = os.path.expanduser("~")
    print('Root Directory for Data ', homedir)
    # Working directory
    wkdir = homedir + '/' + dir
    os.chdir(wkdir)
    print(f'Changed working directory to {wkdir}')
    return wkdir
[docs] def long_string(lon, cent_lon=0):
    '''
    Get nicely formatted longitude string

    Parameters
    ==========
    lon:
        Longitude
    cent_lon:
        Central longitude of the projection used

    Returns
    =======
    String in nice format
    '''
    E = 'E'
    W = 'W'
    if cent_lon == 0:
        if lon < 0:
            out = str(-lon) + W
        elif lon > 0:
            out = str(lon) + E
        else:
            out = str(lon)
    elif cent_lon == 180:
        if lon > 0:
            out = str(lon) + W
        elif lon < 0:
            out = str(-lon) + E
        else:
            out = str(lon)
    else:
        # Raise the error instead of silently discarding it
        raise SystemError(f'Error in longitude string cent_lon {cent_lon}')
    return out
[docs] def lat_string(lat):
    '''
    Get nicely formatted latitude string

    Parameters
    ==========
    lat:
        Latitude

    Returns
    =======
    String in nice format
    '''
    if lat < 0:
        out = str(-lat) + 'S'
    elif lat > 0:
        out = str(lat) + 'N'
    else:
        out = 'Equator'
    return out
[docs] def get_environment_info(option):
    '''
    Get information about the Python environment

    Parameters
    ==========
    option: String
        Options are:
        'interpreter': Get the path of the Python interpreter
        'version':     Get the Python version
        'packages':    Get the list of installed packages
        'all':         Get all of the above

    Returns
    =======
    Information about the Python environment
    '''
    python_executable = sys.executable
    python_version = platform.python_version()
    installed_packages = sorted([(d.project_name, d.version) for d in pkg_resources.working_set])
    match option:
        case 'interpreter':
            return python_executable
        case 'version':
            return python_version
        case 'packages':
            return installed_packages
        case 'all':
            return python_executable, python_version, installed_packages
        case _:
            print('Choose an option: interpreter, version, packages, all')
            return
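# Illustrative sketch (not part of the library): querying the environment.
def _example_environment():
    print(get_environment_info('interpreter'))    # path of the Python interpreter
    print(get_environment_info('version'))        # e.g. '3.11.4'
    exe, ver, pkgs = get_environment_info('all')  # everything at once
    return exe, ver, pkgs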
# ─── a single executor shared by all calls ────────────────────────────────────
_dl_pool = ThreadPoolExecutor(max_workers=4)   # tweak as needed

# Pre-compute the static list so we don't re-allocate it every call
_VERTICAL_LEVELS = [
    0.5057600140571594, 1.5558552742004395, 2.6676816940307617, 3.8562798500061035, 5.140361309051514,
    6.543033599853516, 8.09251880645752, 9.822750091552734, 11.773679733276367, 13.99103832244873,
    16.52532196044922, 19.42980194091797, 22.75761604309082, 26.558300018310547, 30.874561309814453,
    35.740203857421875, 41.180023193359375, 47.21189498901367, 53.85063552856445, 61.11283874511719,
    69.02168273925781, 77.61116027832031, 86.92942810058594, 97.04131317138672, 108.03028106689453,
    120, 133.07582092285156, 147.40625, 163.16445922851562, 180.5499267578125,
    199.7899627685547, 221.14117431640625, 244.890625, 271.35638427734375, 300.88751220703125,
    333.8628234863281, 370.6884765625, 411.7938537597656, 457.6256103515625, 508.639892578125,
    565.2922973632812, 628.0260009765625, 697.2586669921875, 773.3682861328125, 856.678955078125,
    947.4478759765625, 1045.854248046875, 1151.9912109375, 1265.8614501953125, 1387.376953125,
    1516.3636474609375, 1652.5684814453125, 1795.6707763671875, 1945.2955322265625, 2101.026611328125,
    2262.421630859375, 2429.025146484375, 2600.38037109375, 2776.039306640625, 2955.5703125,
    3138.56494140625, 3324.640869140625, 3513.445556640625, 3704.65673828125, 3897.98193359375,
    4093.15869140625, 4289.95263671875, 4488.15478515625, 4687.5810546875, 4888.06982421875,
    5089.478515625, 5291.68310546875, 5494.5751953125, 5698.060546875, 5902.0576171875,
]
[docs] def get_ocean_GLORS(year: int | str, var: str, outfile: str, *,
                    pool: ThreadPoolExecutor = _dl_pool):
    """
    Asynchronously download GLORS reanalysis for `year` into `outfile`.

    Parameters
    ----------
    year : int or str
        Year to fetch (e.g. 1993). Converted to str internally.
    var : str
        Variable to fetch, e.g. 'sea_water_potential_temperature'.
    outfile : str
        Path to write the NetCDF file.
    pool : ThreadPoolExecutor, optional
        Executor that runs the blocking retrieval. Defaults to a
        module-level pool so every call shares the same threads.

    Returns
    -------
    concurrent.futures.Future
        Future whose `.result()` blocks until the file is fully written.
    """
    # Ensure `year` is a string for the DDS request
    year_str = str(year)

    # Function that does the real work (runs in a background thread)
    def _download():
        client = ddsapi.Client()
        client.retrieve(
            "cglorsv7",
            "monthly-regular",
            {
                "vertical": _VERTICAL_LEVELS,
                "variable": var,
                "time": {"year": [year_str], "month": [str(m) for m in range(1, 13)]},
                "format": "netcdf",
            },
            outfile,
        )
        print(f"Data for {year_str} saved to {outfile}")

    # Submit the job to the shared thread pool and return the Future
    return pool.submit(_download)
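# Illustrative sketch (not part of the library): launching two downloads and
# waiting for both. Requires the ddsapi client (see the import guard above);
# the file names and the variable below are made up.
def _example_get_ocean_GLORS():
    futures = [
        get_ocean_GLORS(y, 'sea_water_potential_temperature', f'GLORS_T_{y}.nc')
        for y in (1993, 1994)
    ]
    for fut in futures:
        fut.result()   # block until each file is fully written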
[docs] def write_netcdf(ds, file):
    """
    Write dataset `ds` to a NetCDF file with the given filename.

    Parameters
    ==========
    ds:
        xarray Dataset to be written.
    file:
        Name of the NetCDF file to be written.

    Returns
    =======
    None
    """
    # Define "maximum" loss-less compression settings
    compression_settings = {
        # Either syntax works; pick the one your h5netcdf/xarray combo supports
        # 1) NetCDF-style keywords (widely supported)
        "zlib": True,
        "complevel": 9,
        # 2) Direct HDF5 keywords (h5netcdf ≥ 1.0)
        # "compression": "gzip",
        # "compression_opts": 9,
    }

    # Apply the same settings to every variable (data + coords if you like)
    encoding = {var: compression_settings for var in ds.variables}

    # Write the file
    ds.to_netcdf(
        file,                 # destination
        mode='w',             # write mode
        engine="h5netcdf",    # use the pure-Python HDF5 backend
        encoding=encoding,    # per-variable compression map
    )
    return
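# Illustrative sketch (not part of the library): writing a small Dataset with
# maximum compression. The variable and file name are made up; the h5netcdf
# backend must be installed.
def _example_write_netcdf():
    ds = xr.Dataset(
        {'tas': ('time', np.arange(12.0))},
        coords={'time': pd.date_range('2000-01', periods=12, freq='MS')},
    )
    write_netcdf(ds, 'tas_2000_compressed.nc')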
# ──────────────────────────────────────────────────────────────────────────────
def _month_filter(
    ds: xr.Dataset,
    *,
    months: Iterable[int],
    time_dim: str,
    weighted: bool,
    avg_per_file: bool,
    user_pp: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
) -> xr.Dataset:
    """
    Internal helper used as `preprocess` by open_mfdataset.

    1. runs any user-supplied preprocessing,
    2. keeps only the requested months,
    3. *optionally* averages those months inside the single file
       (weighted by month length if requested).
    """
    # ------------- 1. user-supplied preprocessing ----------------------------
    if user_pp is not None:
        ds = user_pp(ds)

    # ------------- 2. sub-select months ---------------------------------------
    ds = ds.sel({time_dim: ds[time_dim].dt.month.isin(months)})

    # ------------- 3. optionally aggregate *inside* each file -----------------
    if avg_per_file:
        if ds[time_dim].size == 0:
            return ds.drop_vars(time_dim, errors="ignore")
        if weighted:
            w = ds[time_dim].dt.days_in_month
            return (ds * w).sum(time_dim) / w.sum(time_dim)
        return ds.mean(time_dim)
    return ds
[docs] def build_preprocess(months, time_dim, weighted, avg_per_file, user_pp):
    """ Build the `preprocess` callable passed to `xr.open_mfdataset`. """
    def _pre(ds):
        return _month_filter(
            ds,
            months=months,
            time_dim=time_dim,
            weighted=weighted,
            avg_per_file=avg_per_file,
            user_pp=user_pp,
        )
    return _pre
[docs] def seasonal_average(
    files: Sequence[str],
    seasons: Mapping[str, Iterable[int]],
    *,
    time_dim: str = "time",
    weighted: bool = True,
    avg_per_file: bool = False,
    user_preprocess: Callable[[xr.Dataset], xr.Dataset] | None = None,
    **open_kwargs,
) -> Dict[str, xr.Dataset]:
    """
    Compute seasonal averages for the given `seasons` from a list of NetCDF files.

    `seasons` maps a season name to its months, e.g. {'DJF': [12, 1, 2]}.
    Monthly means are weighted by the number of days in each month when
    `weighted` is True. With `avg_per_file=True` each file is averaged
    separately and the results are then averaged across files; DJF is always
    averaged after merging because it spans calendar years.
    """
    out: Dict[str, xr.Dataset] = {}

    for name, months in seasons.items():
        print(f"Processing season: {name} with months {months}...")

        is_djf = set(months) == {12, 1, 2}
        # Use a per-season copy so one season cannot change the setting for the others
        season_avg_per_file = avg_per_file
        if is_djf and season_avg_per_file:
            print("  NOTE: DJF spans calendar years and cannot be averaged per file correctly.")
            print("  Forcing avg_per_file=False to preserve seasonal continuity across years.")
            season_avg_per_file = False

        preprocess = build_preprocess(
            months=months,
            time_dim=time_dim,
            weighted=weighted,
            avg_per_file=season_avg_per_file,
            user_pp=user_preprocess,
        )

        local_open_kwargs = open_kwargs.copy()
        if season_avg_per_file:
            local_open_kwargs.update({"combine": "nested", "concat_dim": "file"})
        else:
            local_open_kwargs.setdefault("engine", "h5netcdf")
            local_open_kwargs.setdefault("lock", False)
            local_open_kwargs.setdefault("parallel", False)

        print("  Opening multiple files with open_mfdataset...")
        ds = xr.open_mfdataset(files, preprocess=preprocess, **local_open_kwargs)

        if not season_avg_per_file:
            if time_dim not in ds.dims or ds[time_dim].size == 0:
                raise ValueError(f"No data for season '{name}' in the provided files.")

        if season_avg_per_file:
            print("  Averaging across files...")
            out[name] = ds.mean("file")
        else:
            print("  Applying seasonal average logic...")
            time_vals = pd.to_datetime(ds[time_dim].values)
            ds = ds.assign_coords({time_dim: time_vals})

            if is_djf:
                ds = ds.sel({time_dim: ds[time_dim].dt.month.isin([12, 1, 2])})
                # December belongs to the DJF season of the following year
                month_of_time = ds[time_dim].dt.month
                djf_year = ds[time_dim].dt.year.where(
                    month_of_time != 12, ds[time_dim].dt.year + 1
                )
                ds_grouped = ds.groupby(djf_year.rename("djf_year"))

                valid_groups = []
                for year, group in ds_grouped:
                    # Keep only complete DJF seasons (three months)
                    if group[time_dim].size == 3:
                        if weighted:
                            w = group[time_dim].dt.days_in_month
                            season_avg = (group * w).sum(time_dim) / w.sum(time_dim)
                        else:
                            season_avg = group.mean(time_dim)
                        valid_groups.append(season_avg)

                if not valid_groups:
                    raise ValueError("No complete DJF season found.")

                out[name] = xr.concat(valid_groups, dim="djf_year").mean("djf_year")
            else:
                if weighted:
                    w = ds[time_dim].dt.days_in_month
                    out[name] = (ds * w).sum(time_dim) / w.sum(time_dim)
                else:
                    out[name] = ds.mean(time_dim)

        out[name].attrs.update(ds.attrs)
        out[name].attrs["season_months"] = list(months)
        out[name].attrs["season_weighted"] = "Weighted" if weighted else "Unweighted"
        out[name].attrs["season_avg_per_file"] = (
            "Average per file" if season_avg_per_file else "Average after merging"
        )
        out[name].attrs["season_name"] = name
        out[name].attrs["season_time_dim"] = time_dim

        print(f"  Finished processing season: {name}\n")

    return out
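# Illustrative sketch (not part of the library): computing DJF and JJA means
# from a list of yearly files; the file names are made up.
def _example_seasonal_average():
    files = [f'GLORS_T_{y}.nc' for y in range(1993, 2000)]
    seasons = {'DJF': [12, 1, 2], 'JJA': [6, 7, 8]}
    clim = seasonal_average(files, seasons, time_dim='time', weighted=True)
    return clim['DJF'], clim['JJA']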
[docs] def select_files_by_years(YEARS, file_template="GLORS_YEARS_T_44_{year}.nc", directory="."):
    """
    Select files for the years provided, check their existence in the
    specified directory, and return the list of selected file paths together
    with a sorted list of years for which files were found.

    Parameters
    ==========
    YEARS (list):
        List of years to check.
    file_template (str):
        Template for the file name, e.g., "GLORS_YEARS_T_44_{year}.nc".
    directory (str):
        Directory path where the files are stored.

    Returns
    =======
    tuple: A tuple containing:
        - selected_files (list): Sorted list of file paths found.
        - sorted_found_years (list): Sorted list of years for which the files were found.
    """
    selected_files = []
    found_years = []   # Years for which the file was found

    for year in YEARS:
        # Construct the filename based on the current year
        filename = file_template.format(year=year)

        # Check if the file exists in the directory
        file_path = os.path.join(directory, filename)
        print(f"Checking for file: {file_path}")
        if os.path.exists(file_path):
            selected_files.append(file_path)
            found_years.append(year)
        else:
            print(f"File not found: {filename}")

    # Sort the list of found years
    sorted_found_years = sorted(found_years)

    return sorted(selected_files), sorted_found_years
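# Illustrative sketch (not part of the library): collecting the files that
# actually exist for a range of years; the directory below is made up and the
# template is the function's default.
def _example_select_files():
    files, years = select_files_by_years(
        list(range(1993, 2000)),
        file_template="GLORS_YEARS_T_44_{year}.nc",
        directory="./data",
    )
    return files, years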
# Read a text (txt) file and return a numpy array
[docs] def read_txt_file(file_path: str) -> np.ndarray:
    """
    Read a text file and return its content as a numpy array.

    Parameters
    ==========
    file_path:
        Path to the text file to be read.

    Returns
    =======
    numpy.ndarray:
        Numpy array containing the data from the text file.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"The file {file_path} does not exist.")

    with open(file_path, "r") as f:
        data = f.readlines()

    return np.array([line.strip().split() for line in data], dtype='float')
# Write a numpy array to a text file
[docs] def write_txt_file(data: np.ndarray, file_path: str) -> None:
    """
    Write a numpy array to a text file.

    Parameters
    ==========
    data:
        Numpy array to be written to the file.
    file_path:
        Path where the text file will be saved.

    Returns
    =======
    None
    """
    np.savetxt(file_path, data, fmt='%s')
    return
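# Illustrative sketch (not part of the library): writing a numeric array to
# text and reading it back; the file name is made up.
def _example_txt_roundtrip():
    arr = np.arange(6.0).reshape(3, 2)
    write_txt_file(arr, 'small_grid.txt')
    back = read_txt_file('small_grid.txt')   # float array with shape (3, 2)
    return back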
[docs] def optimize_for_macos():
    """ Apply macOS-specific optimizations. """
    # Use all available cores efficiently
    os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count())

    # Optimize numpy for Apple Silicon if available
    if sys.platform == "darwin" and hasattr(np, "__config__"):
        # Enable accelerated BLAS on macOS
        os.environ["OPENBLAS_NUM_THREADS"] = str(mp.cpu_count())

    # Set xarray/dask to use multiple threads
    import dask
    dask.config.set(scheduler='threads')
    dask.config.set(num_workers=mp.cpu_count())