# Source code for skmap.catalog

import json
import random
import re
import sys
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd

import skmap_bindings as sb
from skmap.misc import mmdd_to_doy


class DataCatalog:
    """In-memory metadata catalog mapping group/feature names to raster paths.

    ``data`` is a dict-of-dicts: ``{group: {feature: {"path": str|None,
    "idx": int, ...}}}`` where ``group`` is a year as string, "common", or
    "otf" (on-the-fly). ``data_size`` is the total number of indexed features.
    """

    def __init__(self, data, data_size) -> None:
        # {group: {feature: {"path": ..., "idx": ..., optional "exec_order": ...}}}
        self.data = data
        # Total number of features across all groups (next free index).
        self.data_size = data_size

    @classmethod
    def create_catalog(
        cls,
        catalog_def: Union[pd.DataFrame, str],
        years: List[int],
        base_path: Union[List[str], str],
        verbose: bool = True,
        replace_group_feat_name: bool = False,
    ):
        """Create a metadata catalog from csv file

        :param catalog_def: Dataframe or csv with definitions
        :type catalog_def: Union[pd.DataFrame, str]
        :param years: list of years in analysis
        :type years: List[int]
        :param base_path: Base path (one string, or a list to pick from randomly)
        :type base_path: Union[List[str], str]
        :param verbose: whether to enable verbose output, defaults to True
        :type verbose: bool, optional
        :param replace_group_feat_name: keep ``{year}``-style placeholders in
            layer names instead of the YYYY/YYPO/YYMO markers, defaults to False
        :type replace_group_feat_name: bool, optional
        :return: A parsed catalog from the input file
        :rtype: DataCatalog
        """
        # make the catalog
        catalog_dict = cls._create_dict_catalog(
            catalog_def, years, base_path, verbose, replace_group_feat_name
        )
        data = {}
        years: list[str] = [str(year) for year in years]
        if "common" not in years:
            # common is default
            years += ["common"]
        features_names = cls._get_feature_names(catalog_dict)
        data_size = 0
        for k in years:
            for f in features_names:
                if f not in catalog_dict[k]:
                    continue
                if k not in data:
                    data[k] = {}
                data[k][f] = {"path": catalog_dict[k][f], "idx": data_size}
                data_size += 1
        # The fresh catalog is its own reference for whale dependency lookup.
        data, data_size = cls.expand_whales_dependencies(data, data, data_size)
        # calls __init__
        return cls(data, data_size)

    @staticmethod
    def _create_dict_catalog(
        catalog_def: Union[pd.DataFrame, str],
        years: List[int],
        base_path: Union[List[str], str],
        verbose: bool,
        replace_group_feat_name: bool,
    ):
        """Parse the catalog definition table into ``{group: {layer: path}}``."""
        if not isinstance(catalog_def, pd.DataFrame):
            covar = pd.read_csv(catalog_def)
        else:
            covar = catalog_def
        # Replace placeholders in `layer_name` and `path`
        if replace_group_feat_name:
            year_repl = "{year}"
            year_plus_one_repl = "{year_plus_one}"
            year_minus_one_repl = "{year_minus_one}"
        else:
            year_repl = "YYYY"
            year_plus_one_repl = "YYPO"
            year_minus_one_repl = "YYMO"

        def replace_layer_name_placeholders(value):
            # Note: "{dt}" is escaped to "{{dt}}" so it survives the
            # str.format calls below and stays a literal "{dt}" in the result.
            replacements = {
                "{year}": year_repl,
                "{year_plus_one}": year_plus_one_repl,
                "{year_minus_one}": year_minus_one_repl,
                "{start_month}": "{start_month}",
                "{end_month}": "{end_month}",
                "{perc}": "{perc}",
                "{dt}": "{{dt}}",
            }
            if isinstance(value, str):
                for old, new in replacements.items():
                    value = value.replace(old, new)
            return value

        covar["layer_name"] = covar["layer_name"].apply(replace_layer_name_placeholders)
        covar["path"] = covar["path"].apply(
            lambda x: replace_layer_name_placeholders(x) if "/whale/" in x else x
        )

        # Expand one row per value listed in the `perc` column.
        perc_mask = covar["layer_name"].str.contains(r"\{perc\}") | covar[
            "path"
        ].str.contains(r"\{perc\}")
        perc_expanded_rows = []
        for _, row in covar[perc_mask].iterrows():
            # FIX: the original called .strip() on [None] when the `perc`
            # cell was NaN, raising AttributeError.
            if pd.notna(row["perc"]):
                perc_values = [p.strip() for p in row["perc"].split(",")]
            else:
                perc_values = [None]
            for perc in perc_values:
                new_row = row.copy()
                if perc:
                    new_row["layer_name"] = new_row["layer_name"].replace(
                        "{perc}", perc
                    )
                    new_row["path"] = new_row["path"].replace("{perc}", perc)
                perc_expanded_rows.append(new_row)
        perc_expanded_df = pd.DataFrame(perc_expanded_rows)
        # Combine expanded rows with the rest of the dataframe
        covar = pd.concat([covar[~perc_mask], perc_expanded_df], ignore_index=True)

        # Expand one row per (start_month, end_month) pair.
        month_mask = (
            covar["layer_name"].str.contains(r"\{start_month\}")
            | covar["path"].str.contains(r"\{start_month\}")
            | covar["layer_name"].str.contains(r"\{end_month\}")
            | covar["path"].str.contains(r"\{end_month\}")
        )
        month_expanded_rows = []
        for _, row in covar[month_mask].iterrows():
            # FIX: same NaN-cell crash as the `perc` expansion above.
            if pd.notna(row["start_month"]):
                start_month_values = [
                    sm.strip() for sm in row["start_month"].split(",")
                ]
            else:
                start_month_values = [None]
            if pd.notna(row["end_month"]):
                end_month_values = [em.strip() for em in row["end_month"].split(",")]
            else:
                end_month_values = [None]
            max_len = max(len(start_month_values), len(end_month_values))
            start_month_values = start_month_values or [None] * max_len
            end_month_values = end_month_values or [None] * max_len
            for start_month, end_month in zip(start_month_values, end_month_values):
                new_row = row.copy()
                if start_month:
                    new_row["layer_name"] = new_row["layer_name"].replace(
                        "{start_month}", start_month
                    )
                    new_row["path"] = new_row["path"].replace(
                        "{start_month}", start_month
                    )
                if end_month:
                    new_row["layer_name"] = new_row["layer_name"].replace(
                        "{end_month}", end_month
                    )
                    new_row["path"] = new_row["path"].replace("{end_month}", end_month)
                month_expanded_rows.append(new_row)
        month_expanded_df = pd.DataFrame(month_expanded_rows)
        covar = pd.concat([covar[~month_mask], month_expanded_df], ignore_index=True)

        # Separate common and temporal data
        covar_comm = covar[covar["type"] == "common"].reset_index(drop=True)
        covar_temp = covar[covar["type"] == "temporal"].reset_index(drop=True)
        # Create the comm part of the catalog
        url_comm = covar_comm.set_index("layer_name")["path"].to_dict()
        comm = {"common": {layer_name: path for layer_name, path in url_comm.items()}}

        def calculate_year_placeholders(year, start_year, end_year, tmp_layer_name):
            # An empty/NaN end_year means start_year holds an explicit
            # comma-separated list of available years.
            # FIX: `type(end_year) is float` missed np.float64 NaN values
            # coming from numeric pandas columns (then crashed on int(nan));
            # isinstance covers both (np.float64 subclasses float).
            if end_year == "" or (
                isinstance(end_year, float) and not np.isfinite(end_year)
            ):
                # years is a list of strings, e.g. "1996,2006,...,2020"
                years = start_year.split(",")
                if str(year) in years:
                    return {
                        "year": str(year),
                        "year_plus_one": str(year + 1),
                        "year_minus_one": str(year - 1),
                        "tile_id": "{tile_id}",
                        "base_path": "{base_path}",
                    }
                # if year is not in the list of years, we propagate the closest
                # year that is available before the given year
                valid_year = int(years[0])
                for y in years:
                    if int(y) < year:
                        valid_year = int(y)
                    else:
                        break
                if year != valid_year and verbose:
                    print(
                        f"Year {year} not available for layer {tmp_layer_name}, propagating year {valid_year}"
                    )
                return {
                    "year": str(valid_year),
                    "year_plus_one": str(valid_year + 1),
                    "year_minus_one": str(valid_year - 1),
                    "tile_id": "{tile_id}",
                    "base_path": "{base_path}",
                }
            else:
                # Clamp the requested year into [start_year, end_year].
                valid_year = min(max(year, int(start_year)), int(end_year))
                if year != valid_year and verbose:
                    print(
                        f"Year {year} not available for layer {tmp_layer_name}, propagating year {valid_year}"
                    )
                return {
                    "year": str(valid_year),
                    "year_plus_one": str(valid_year + 1),
                    "year_minus_one": str(valid_year - 1),
                    "tile_id": "{tile_id}",
                    "base_path": "{base_path}",
                }

        def calculate_year_feat_name_placeholders(year):
            return {
                "year": str(year),
                "year_plus_one": str(year + 1),
                "year_minus_one": str(year - 1),
                "tile_id": "{tile_id}",
                "base_path": "{base_path}",
            }

        # Create the temporal part of the catalog
        url_temp = covar_temp.set_index("layer_name")["path"].to_dict()
        temporal = {}
        for year in years:
            year_dict = {}
            # NOTE(review): relies on url_temp preserving covar_temp row order
            # so that `i` aligns with .loc row labels — holds for unique
            # layer_names after reset_index.
            for i, (layer_name, path) in enumerate(url_temp.items()):
                year_placeholders = calculate_year_placeholders(
                    year,
                    covar_temp.loc[i, "start_year"],
                    covar_temp.loc[i, "end_year"],
                    layer_name,
                )
                if replace_group_feat_name:
                    year_feat_name_placeholders = (
                        calculate_year_feat_name_placeholders(year)
                    )
                    layer_name = layer_name.format(**year_feat_name_placeholders)
                year_dict[layer_name] = path.format(**year_placeholders)
            temporal[str(year)] = year_dict

        catalog_dict = {**comm, **temporal}
        # Inject the concrete base path (picked at random when a list is given).
        for group, inner_dict in catalog_dict.items():
            for layer_name, path in inner_dict.items():
                if isinstance(base_path, str):
                    base_path_tmp = base_path
                elif isinstance(base_path, list):
                    base_path_tmp = random.choice(base_path)
                catalog_dict[group][layer_name] = path.format(
                    base_path=base_path_tmp, tile_id="{tile_id}", dt="{dt}"
                )
        return catalog_dict

    def save_json(self, json_out_path: Optional[str] = None) -> None:
        """Dump the raw catalog dict to ``json_out_path`` (no-op when None)."""
        if json_out_path is not None:
            with open(json_out_path, "w") as f:
                json.dump(self.data, f, indent=4)

    def get_groups(self) -> List[str]:
        """Return the sorted group names, excluding 'common' and 'otf'."""
        groups = sorted(
            list(set(self.data.keys()).difference(["common"]).difference(["otf"]))
        )  # by default don't return 'common' nor 'otf' group
        if not len(groups):
            # but, if there is only group 'common', return it.
            groups = ["common"]
        return groups

    def copy(self):
        """Return a new DataCatalog (shallow copy of the group dict)."""
        return DataCatalog(self.data.copy(), int(self.data_size))

    @staticmethod
    def _get_feature_names(catalog_dict: Dict) -> List[str]:
        """Make a set of feature names

        :param catalog_dict: dict-of-dicts
        :type catalog_dict: Dict
        :return: Sorted list of unique feature names
        :rtype: List[str]
        """
        return sorted(
            {
                layer_name
                for _, inner_dict in catalog_dict.items()
                for layer_name, _ in inner_dict.items()
            }
        )

    def find_group_and_feature_by_index(
        self, target_idx: int
    ) -> tuple[Optional[str], Optional[str]]:
        """Return (group, feature) holding ``target_idx``, or (None, None)."""
        for group_name, features in self.data.items():
            for feature_name, feature_info in features.items():
                if feature_info.get("idx") == target_idx:
                    return group_name, feature_name
        return None, None

    def get_feature_names(self) -> List[str]:
        """Return the sorted unique feature names of this catalog."""
        return self._get_feature_names(self.data)

    def get_paths(self) -> Tuple[List[str], List[int], List[str]]:
        """Return (paths, idx, names) for all non-whale, non-otf features."""
        paths, idx, names = [], [], []
        for k in self.data:
            if k == "otf":
                continue
            for f in self.data[k]:
                if not self.data[k][f]["path"].startswith("/whale"):
                    paths += [self.data[k][f]["path"]]
                    idx += [self.data[k][f]["idx"]]
                    names += [f]
        # modify paths of non VRT files
        paths = [
            f"/vsicurl/{p}"
            if p
            and p.startswith("http")
            and not p.endswith("vrt")
            and not p.startswith("/vsicurl/")
            else p
            for p in paths
        ]
        return paths, idx, names

    def get_unrolled_catalog(self) -> Tuple[List[str], List[str], List[int]]:
        """Return (names, paths, idx) for every feature, including whales/otf."""
        names, paths, idx = [], [], []
        for k in self.data:
            for f in self.data[k]:
                paths += [self.data[k][f]["path"]]
                idx += [self.data[k][f]["idx"]]
                names += [f]
        return names, paths, idx

    @staticmethod
    def get_whales(data):
        """Return (paths, groups, layer_names) of all '/whale' entries."""
        whale_paths, whale_keys, whale_layer_names = [], [], []
        for k in data:
            if k == "otf":
                continue
            for f in data[k]:
                if data[k][f]["path"].startswith("/whale"):
                    whale_paths += [data[k][f]["path"]]
                    whale_layer_names += [f]
                    whale_keys += [k]
        return whale_paths, whale_keys, whale_layer_names

    def _get_whales(self):
        return self.get_whales(self.data)

    def query(self, feature_names, groups: Optional[list[str]] = None) -> None:
        """Shrink the catalog in place to ``feature_names`` (plus 'common').

        Features missing from the original catalog are added to the 'otf'
        (on-the-fly) group with a None path.
        """
        if groups is None:
            groups = self.get_groups()
        if not isinstance(groups, list):
            raise ValueError("Invalid `groups` parameter. Expecting a list.")
        if "common" not in groups:
            # include 'common' group by default
            groups = groups + ["common"]
        old_data = self.data.copy()
        self.data = {}
        self.data_size = 0
        for k in groups:
            if k in old_data:
                for f in feature_names:
                    if f in old_data[k]:
                        if k not in self.data:
                            self.data[k] = {}
                        self.data[k][f] = {
                            "path": old_data[k][f]["path"],
                            "idx": self.data_size,
                        }
                        self.data_size += 1
            else:
                print(f"Group {k} is missing from original catalog, skipping it")
        # Re-index whale dependencies against the full original catalog.
        self._expand_whales_dependencies(old_data)
        missing_features_names = [
            feature
            for feature in feature_names
            if feature not in set(self.get_feature_names())
        ]
        for missing_feat_feature in missing_features_names:
            print(
                f"WARNING: Feature {missing_feat_feature} is missing in the original catalog, adding is in the otf (on the fly) common group"
            )
        if missing_features_names:
            self.add_otf_features(missing_features_names)

    def add_otf_features(self, otf_features) -> None:
        """Register on-the-fly features (path None) in the 'otf' group."""
        if "otf" not in self.data:
            self.data["otf"] = {}
        for otf_feature in otf_features:
            self.data["otf"][otf_feature] = {"path": None, "idx": self.data_size}
            self.data_size += 1

    @classmethod
    def expand_whales_dependencies(cls, reference_catalog_data, data, data_size):
        """Add every layer a whale entry depends on (recursively) to ``data``.

        ``reference_catalog_data`` is the catalog used to resolve dependency
        paths; returns the updated ``(data, data_size)``.
        """
        whale_paths, groups, whale_layer_names = cls.get_whales(data)
        for i, (whale_path, whale_layer_name) in enumerate(
            zip(whale_paths, whale_layer_names)
        ):
            dep_names, dep_paths, dep_exec_orders, dep_keys = get_whale_dependencies(
                whale_path, groups[i], reference_catalog_data, whale_layer_names[i]
            )
            for dep_name, dep_path, dep_exec_order, dep_key in zip(
                dep_names, dep_paths, dep_exec_orders, dep_keys
            ):
                if dep_name not in data[dep_key]:
                    data[dep_key][dep_name] = {
                        "path": dep_path,
                        "idx": data_size,
                        "exec_order": dep_exec_order,
                    }
                    data_size += 1
                elif dep_name == whale_layer_name:
                    # The whale itself was already indexed: only record order.
                    data[dep_key][dep_name]["exec_order"] = dep_exec_order
        return data, data_size

    def _expand_whales_dependencies(self, reference_catalog_data) -> None:
        self.data, self.data_size = self.expand_whales_dependencies(
            reference_catalog_data, self.data, self.data_size
        )

    def get_otf_idx(self):
        """Return {otf_feature: [idx, ...]} for the 'otf' group."""
        otf_idx = {}
        if "otf" in self.data:
            for f in self.data["otf"]:
                if f not in otf_idx:
                    otf_idx[f] = []
                otf_idx[f] += [self.data["otf"][f]["idx"]]
        return otf_idx

    def _get_covs_idx(self, covs_lst: List[str]):
        """Return an int32 (len(covs_lst), n_groups) matrix of feature indices.

        Lookup order per covariate: 'common' first, then the group, then 'otf'.
        """
        groups = self.get_groups()
        covs_idx = np.zeros((len(covs_lst), len(groups)), np.int32)
        for j, k in enumerate(groups):
            for i, c in enumerate(covs_lst):
                if "common" in self.data and c in self.data["common"]:
                    covs_idx[i, j] = self.data["common"][c]["idx"]
                elif c in self.data[k]:
                    covs_idx[i, j] = self.data[k][c]["idx"]
                else:
                    covs_idx[i, j] = self.data["otf"][c]["idx"]
        return covs_idx
# #
def get_whale_dependencies(whale, key, main_catalog, whale_layer_name):
    """Resolve the input layers required by a whale expression, recursively.

    Returns four parallel lists ``(names, paths, exec_orders, keys)``; the
    whale itself is appended last, with an exec_order one above the deepest
    of its dependencies (or 0 when it has none).
    """
    func_name, params = parse_template_whale(whale)

    # Collect the catalog tags this whale function reads from.
    dep_tags = []
    if func_name == "percentileAggregation":
        template = params["entry_template"]
        dep_tags.extend(template.format(dt=dt) for dt in params["dt"])
    elif func_name == "computeNormalizedDifference":
        dep_tags.extend([params["idx_plus"], params["idx_minus"]])
    elif func_name == "computeSavi":
        dep_tags.extend([params["idx_red"], params["idx_nir"]])
    elif func_name == "computeGeometricTemperature":
        dep_tags.extend([params["idx_latitude"], params["idx_elevation"]])
    elif func_name == "extractIndicator":
        dep_tags.append(params["idx_layer"])

    dep_names, dep_paths, dep_exec_orders, dep_keys = [], [], [], []
    for tag in dep_tags:
        # Resolve in the whale's own group first, falling back to 'common'.
        resolved_key = key if tag in main_catalog[key] else "common"
        resolved_path = main_catalog[resolved_key][tag]["path"]
        if resolved_path.startswith("/whale"):
            # A whale depending on another whale: flatten its inputs first.
            sub_names, sub_paths, sub_orders, sub_keys = get_whale_dependencies(
                resolved_path, resolved_key, main_catalog, tag
            )
            dep_paths.extend(sub_paths)
            dep_names.extend(sub_names)
            dep_exec_orders.extend(sub_orders)
            dep_keys.extend(sub_keys)
        else:
            dep_paths.append(resolved_path)
            dep_names.append(tag)
            dep_exec_orders.append(0)
            dep_keys.append(resolved_key)

    # Finally register the whale itself, after all of its dependencies.
    dep_paths.append(whale)
    dep_names.append(whale_layer_name)
    dep_exec_orders.append(max(dep_exec_orders) + 1 if dep_exec_orders else 0)
    dep_keys.append(key)
    return dep_names, dep_paths, dep_exec_orders, dep_keys
def parse_template_whale(whale):
    """Split a ``/whale/<func>?k=v&...`` template into ``(func_name, params)``.

    Comma-separated values become lists of strings. ``func_name`` is None when
    the path does not match, and ``params`` is empty when there is no query
    string.
    """
    func_name_match = re.search(r"/whale/([^?]+)", whale)
    func_name = func_name_match.group(1) if func_name_match else None
    # FIX: the original did whale.split("?")[1], raising IndexError for
    # templates without any query string (e.g. "/whale/getLatitude").
    _, _, query_params_string = whale.partition("?")
    params = {}
    if query_params_string != "":
        for param in query_params_string.split("&"):
            # FIX: split only on the first '=' so values may contain '='.
            key, value = param.split("=", 1)
            # Split by commas to handle lists
            if "," in value:
                params[key] = value.split(",")
            else:
                params[key] = value
    return func_name, params
def run_whales(catalog: DataCatalog, array, n_threads: int, lat_info=None) -> None:
    """Compute all on-the-fly ("whale") covariates into ``array`` rows.

    Whales are executed in waves of increasing ``exec_order`` so that a whale
    depending on another whale runs after its dependency. ``array`` holds one
    feature per row at the indices recorded in ``catalog``; ``lat_info`` is a
    1-D latitude vector required only by the "getLatitude" whale.
    """
    # Computing on the fly covariates
    whale_paths, whale_keys, whale_names = catalog._get_whales()
    max_exec_order = 0
    for whale_key, whale_name in zip(whale_keys, whale_names):
        max_exec_order = max(
            max_exec_order, int(catalog.data[whale_key][whale_name]["exec_order"])
        )
    for exec_order in range(max_exec_order + 1):
        for whale_path, whale_key, whale_name in zip(
            whale_paths, whale_keys, whale_names
        ):
            if exec_order != int(catalog.data[whale_key][whale_name]["exec_order"]):
                continue
            func_name, params = parse_template_whale(whale_path)
            whale_data_idx = catalog.data[whale_key][whale_name]["idx"]
            if func_name == "percentileAggregation":
                tag = params["entry_template"]
                in_idxs = []
                for dt in params["dt"]:
                    in_idxs += [catalog.data[whale_key][tag.format(dt=dt)]["idx"]]
                # @FIXME with this setting we do the sorting for percentiles N times for each used percentile
                # @FIXME implement the percentiles without the need of transposition
                array_sb = np.empty((len(in_idxs), array.shape[1]), dtype=np.float32)
                array_sb_t = np.empty(
                    (array_sb.shape[1], array_sb.shape[0]), dtype=np.float32
                )
                array_pct_t = np.empty((array_sb.shape[1], 1), dtype=np.float32)
                sb.selArrayRows(array, n_threads, array_sb, in_idxs)
                sb.transposeArray(array_sb, n_threads, array_sb_t)
                sb.computePercentiles(
                    array_sb_t,
                    n_threads,
                    range(len(in_idxs)),
                    array_pct_t,
                    [0],
                    [float(params["percentile"])],
                )
                array[whale_data_idx, :] = array_pct_t[:, 0]
            elif func_name == "computeNormalizedDifference":
                sb.computeNormalizedDifference(
                    array,
                    n_threads,
                    [catalog.data[whale_key][params["idx_plus"]]["idx"]],
                    [catalog.data[whale_key][params["idx_minus"]]["idx"]],
                    [whale_data_idx],
                    float(params["scale_plus"]),
                    float(params["scale_minus"]),
                    float(params["scale_result"]),
                    float(params["offset_result"]),
                    [float(params["clip"][0]), float(params["clip"][1])],
                )
            elif func_name == "computeSavi":
                sb.computeSavi(
                    array,
                    n_threads,
                    [catalog.data[whale_key][params["idx_red"]]["idx"]],
                    [catalog.data[whale_key][params["idx_nir"]]["idx"]],
                    [whale_data_idx],
                    float(params["scale_red"]),
                    float(params["scale_nir"]),
                    float(params["scale_result"]),
                    float(params["offset_result"]),
                    [float(params["clip"][0]), float(params["clip"][1])],
                )
            elif func_name == "extractIndicator":
                # Binary mask: 1.0 where the source layer equals `code`.
                array[whale_data_idx, :] = (
                    array[catalog.data[whale_key][params["idx_layer"]]["idx"], :]
                    == float(params["code"])
                ).astype(np.float32)
            elif func_name == "getLatitude":
                # FIX: the original guard read `not isinstance(...) and
                # lat_info.shape...`, which raised AttributeError (instead of
                # the intended Exception) for non-array lat_info, and silently
                # skipped arrays of the wrong shape. Validate positively.
                if (
                    isinstance(lat_info, np.ndarray)
                    and lat_info.ndim == 1
                    and lat_info.shape[0] == array.shape[1]
                ):
                    array[whale_data_idx, :] = lat_info
                else:
                    raise Exception(
                        "Information on how to get the latitude needs to be a numpy vector"
                    )
            elif func_name == "computeGeometricTemperature":
                day_of_year = mmdd_to_doy(str(params["day_of_year_mmdd"]))
                # Latitude/elevation layers may live in the whale's own group
                # or in 'common'.
                if params["idx_latitude"] in catalog.data[whale_key]:
                    latitude = array[
                        catalog.data[whale_key][params["idx_latitude"]]["idx"], :
                    ]
                else:
                    latitude = array[
                        catalog.data["common"][params["idx_latitude"]]["idx"], :
                    ]
                if params["idx_elevation"] in catalog.data[whale_key]:
                    elevation = array[
                        catalog.data[whale_key][params["idx_elevation"]]["idx"], :
                    ]
                else:
                    elevation = array[
                        catalog.data["common"][params["idx_elevation"]]["idx"], :
                    ]
                sb.computeGeometricTemperature(
                    array,
                    n_threads,
                    latitude,
                    elevation,
                    float(params["elevation_scaling"]),
                    float(params["a"]),
                    float(params["b"]),
                    float(params["result_scaling"]),
                    [whale_data_idx],
                    [day_of_year],
                )
            else:
                sys.exit(f"The whale function {func_name} is not available")