import json
import random
import re
import sys
from typing import Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import pandas as pd
import skmap_bindings as sb
from skmap.misc import mmdd_to_doy
class DataCatalog:
    """In-memory catalog mapping group names (year strings, 'common', or
    'otf') to feature entries of the form {"path": str, "idx": int}.
    """

    def __init__(self, data, data_size) -> None:
        # data: dict of group -> {feature_name: {"path": ..., "idx": ...}}
        self.data = data
        # data_size: total number of indexed features across all groups
        self.data_size = data_size
    @classmethod
    def create_catalog(
        cls,
        catalog_def: Union[pd.DataFrame, str],
        years: List[int],
        base_path: Union[List[str], str],
        verbose: bool = True,
        replace_group_feat_name: bool = False,
    ):
        """Create a metadata catalog from a csv file or dataframe.

        :param catalog_def: Dataframe or path to a csv with layer definitions
        :type catalog_def: Union[pd.DataFrame, str]
        :param years: list of years in analysis
        :type years: List[int]
        :param base_path: base path (or list of base paths) substituted into
            each layer's path template
        :type base_path: Union[List[str], str]
        :param verbose: whether to enable verbose output, defaults to True
        :type verbose: bool, optional
        :param replace_group_feat_name: when True, year placeholders are also
            substituted into the layer (feature) names, defaults to False
        :type replace_group_feat_name: bool, optional
        :return: A parsed catalog from the input file
        :rtype: DataCatalog
        """
        # make the catalog
        catalog_dict = cls._create_dict_catalog(
            catalog_def, years, base_path, verbose, replace_group_feat_name
        )
        data = {}
        # NOTE(review): rebinds `years` from List[int] to list[str]
        years: list[str] = [str(year) for year in years]
        if "common" not in years:  # common is default
            years += ["common"]
        features_names = cls._get_feature_names(catalog_dict)
        data_size = 0
        # Assign a stable, incremental index to every (group, feature) pair.
        for k in years:
            for f in features_names:
                if f not in catalog_dict[k]:
                    continue
                if k not in data:
                    data[k] = {}
                data[k][f] = {"path": catalog_dict[k][f], "idx": data_size}
                data_size += 1
        # Pull in any extra layers that whale (on-the-fly) layers depend on.
        data, data_size = cls.expand_whales_dependencies(data, data, data_size)
        # calls __init__
        return cls(data, data_size)
    @staticmethod
    def _create_dict_catalog(
        catalog_def: Union[pd.DataFrame, str],
        years: List[int],
        base_path: Union[List[str], str],
        verbose: bool,
        replace_group_feat_name: bool,
    ):
        """Build a dict catalog {group: {layer_name: path}} from the layer
        definition table: expands {perc}/{start_month}/{end_month}
        placeholders into one row per value, then substitutes year and
        base-path placeholders.
        """
        if not isinstance(catalog_def, pd.DataFrame):
            covar = pd.read_csv(catalog_def)
        else:
            covar = catalog_def
        # Replace placeholders in `layer_name` and `path`
        if replace_group_feat_name:
            year_repl = "{year}"
            year_plus_one_repl = "{year_plus_one}"
            year_minus_one_repl = "{year_minus_one}"
        else:
            # Literal sentinels so layer names keep a fixed (year-free) form.
            year_repl = "YYYY"
            year_plus_one_repl = "YYPO"
            year_minus_one_repl = "YYMO"

        def replace_layer_name_placeholders(value):
            # Map year placeholders to their chosen replacement; keep the
            # month/percentile placeholders as-is and escape {dt} so it
            # survives the later str.format call.
            replacements = {
                "{year}": year_repl,
                "{year_plus_one}": year_plus_one_repl,
                "{year_minus_one}": year_minus_one_repl,
                "{start_month}": "{start_month}",
                "{end_month}": "{end_month}",
                "{perc}": "{perc}",
                "{dt}": "{{dt}}",
            }
            if isinstance(value, str):
                for old, new in replacements.items():
                    value = value.replace(old, new)
            return value

        covar["layer_name"] = covar["layer_name"].apply(replace_layer_name_placeholders)
        covar["path"] = covar["path"].apply(
            lambda x: replace_layer_name_placeholders(x) if "/whale/" in x else x
        )
        # Expand {perc}: one row per percentile value listed in 'perc'.
        perc_mask = covar["layer_name"].str.contains(r"\{perc\}") | covar[
            "path"
        ].str.contains(r"\{perc\}")
        perc_expanded_rows = []
        for _, row in covar[perc_mask].iterrows():
            # NOTE(review): if 'perc' is NaN here the list is [None] and
            # p.strip() would raise AttributeError — assumes 'perc' is always
            # populated when the placeholder is used; TODO confirm.
            perc_values = [
                p.strip()
                for p in (row["perc"].split(",") if pd.notna(row["perc"]) else [None])
            ]
            for perc in perc_values:
                new_row = row.copy()
                if perc:
                    new_row["layer_name"] = new_row["layer_name"].replace(
                        "{perc}", perc
                    )
                    new_row["path"] = new_row["path"].replace("{perc}", perc)
                perc_expanded_rows.append(new_row)
        perc_expanded_df = pd.DataFrame(perc_expanded_rows)
        # Combine expanded rows with the rest of the dataframe
        covar = pd.concat([covar[~perc_mask], perc_expanded_df], ignore_index=True)
        # Expand {start_month}/{end_month} pairs the same way.
        month_mask = (
            covar["layer_name"].str.contains(r"\{start_month\}")
            | covar["path"].str.contains(r"\{start_month\}")
            | covar["layer_name"].str.contains(r"\{end_month\}")
            | covar["path"].str.contains(r"\{end_month\}")
        )
        month_expanded_rows = []
        for _, row in covar[month_mask].iterrows():
            start_month_values = [
                sm.strip()
                for sm in (
                    row["start_month"].split(",")
                    if pd.notna(row["start_month"])
                    else [None]
                )
            ]
            end_month_values = [
                em.strip()
                for em in (
                    row["end_month"].split(",")
                    if pd.notna(row["end_month"])
                    else [None]
                )
            ]
            # NOTE(review): both lists are always non-empty, so the `or`
            # fallbacks below never fire, and zip() silently truncates to the
            # shorter list if the columns hold different numbers of values —
            # TODO confirm start_month/end_month lengths always match.
            max_len = max(len(start_month_values), len(end_month_values))
            start_month_values = start_month_values or [None] * max_len
            end_month_values = end_month_values or [None] * max_len
            for start_month, end_month in zip(start_month_values, end_month_values):
                new_row = row.copy()
                if start_month:
                    new_row["layer_name"] = new_row["layer_name"].replace(
                        "{start_month}", start_month
                    )
                    new_row["path"] = new_row["path"].replace(
                        "{start_month}", start_month
                    )
                if end_month:
                    new_row["layer_name"] = new_row["layer_name"].replace(
                        "{end_month}", end_month
                    )
                    new_row["path"] = new_row["path"].replace("{end_month}", end_month)
                month_expanded_rows.append(new_row)
        month_expanded_df = pd.DataFrame(month_expanded_rows)
        covar = pd.concat([covar[~month_mask], month_expanded_df], ignore_index=True)
        # Separate common and temporal data
        covar_comm = covar[covar["type"] == "common"].reset_index(drop=True)
        covar_temp = covar[covar["type"] == "temporal"].reset_index(drop=True)
        # Create the comm part of the catalog
        url_comm = covar_comm.set_index("layer_name")["path"].to_dict()
        comm = {"common": {layer_name: path for layer_name, path in url_comm.items()}}

        def calculate_year_placeholders(year, start_year, end_year, tmp_layer_name):
            # Resolve which concrete year a layer should use, propagating or
            # clamping when the requested year is unavailable. An empty/NaN
            # end_year means start_year holds an explicit list of years.
            if end_year == "" or (type(end_year) is float and ~np.isfinite(end_year)):
                years = start_year.split(
                    ","
                )  # years is a list of strings. Ex. 1996,2006,2007,2008,2009,2010,2015,2016,2017,2018,2019,2020
                if str(year) in years:
                    return {
                        "year": str(year),
                        "year_plus_one": str(year + 1),
                        "year_minus_one": str(year - 1),
                        "tile_id": "{tile_id}",
                        "base_path": "{base_path}",
                    }
                # if year is not in the list of years, we propagate the closest year that is available before the given year
                valid_year = int(years[0])
                for y in years:
                    if int(y) < year:
                        valid_year = int(y)
                    else:
                        break
                # `&` works here because both operands are plain bools.
                if (year != valid_year) & verbose:
                    print(
                        f"Year {year} not available for layer {tmp_layer_name}, propagating year {valid_year}"
                    )
                return {
                    "year": str(valid_year),
                    "year_plus_one": str(valid_year + 1),
                    "year_minus_one": str(valid_year - 1),
                    "tile_id": "{tile_id}",
                    "base_path": "{base_path}",
                }
            else:
                # Clamp the requested year into [start_year, end_year].
                valid_year = min(max(year, int(start_year)), int(end_year))
                if (year != valid_year) & verbose:
                    print(
                        f"Year {year} not available for layer {tmp_layer_name}, propagating year {valid_year}"
                    )
                return {
                    "year": str(valid_year),
                    "year_plus_one": str(valid_year + 1),
                    "year_minus_one": str(valid_year - 1),
                    "tile_id": "{tile_id}",
                    "base_path": "{base_path}",
                }

        def calculate_year_feat_name_placeholders(year):
            # Placeholders for the layer *name*: no propagation/clamping,
            # the requested year is used verbatim.
            return {
                "year": str(year),
                "year_plus_one": str(year + 1),
                "year_minus_one": str(year - 1),
                "tile_id": "{tile_id}",
                "base_path": "{base_path}",
            }

        # Create the temporal part of the catalog
        url_temp = covar_temp.set_index("layer_name")["path"].to_dict()
        temporal = {}
        for year in years:
            year_dict = {}
            # NOTE(review): relies on url_temp preserving covar_temp's row
            # order and on unique layer names; duplicates would collapse in
            # the dict and misalign `i` — TODO confirm.
            for i, (layer_name, path) in enumerate(url_temp.items()):
                year_placeholders = calculate_year_placeholders(
                    year,
                    covar_temp.loc[i, "start_year"],
                    covar_temp.loc[i, "end_year"],
                    layer_name,
                )
                if replace_group_feat_name:
                    year_feat_name_placeholders = calculate_year_feat_name_placeholders(
                        year
                    )
                    layer_name = layer_name.format(**year_feat_name_placeholders)
                year_dict[layer_name] = path.format(**year_placeholders)
            temporal[str(year)] = year_dict
        catalog_dict = {**comm, **temporal}
        # Substitute the base path (picked at random per layer when a list is
        # given) while keeping {tile_id} and {dt} as placeholders.
        for group, inner_dict in catalog_dict.items():
            for layer_name, path in inner_dict.items():
                if isinstance(base_path, str):
                    base_path_tmp = base_path
                elif isinstance(base_path, list):
                    base_path_tmp = random.choice(base_path)
                catalog_dict[group][layer_name] = path.format(
                    base_path=base_path_tmp, tile_id="{tile_id}", dt="{dt}"
                )
        return catalog_dict
def save_json(self, json_out_path: Optional[str] = None) -> None:
if json_out_path is not None:
with open(json_out_path, "w") as f:
json.dump(self.data, f, indent=4)
def get_groups(self) -> List[str]:
groups = sorted(
list(set(self.data.keys()).difference(["common"]).difference(["otf"]))
) # by default don't return 'common' nor 'otf' group
if not len(groups): # but, if there is only group 'common', return it.
groups = ["common"]
return groups
def copy(self):
return DataCatalog(self.data.copy(), int(self.data_size))
@staticmethod
def _get_feature_names(catalog_dict: Dict) -> List[str]:
"""Make a set of feature names
:param catalog_dict: dict-of-dicts
:type catalog_dict: Dict
:return: Set of unique feature names
:rtype: {str}
"""
return sorted(
{
layer_name
for _, inner_dict in catalog_dict.items()
for layer_name, _ in inner_dict.items()
}
)
def find_group_and_feature_by_index(
self, target_idx: int
) -> tuple[Optional[str], Optional[str]]:
for group_name, features in self.data.items():
for feature_name, feature_info in features.items():
if feature_info.get("idx") == target_idx:
return group_name, feature_name
return None, None
    def get_feature_names(self) -> List[str]:
        # Convenience wrapper: apply the static helper to this catalog's data.
        return self._get_feature_names(self.data)
def get_paths(self) -> Tuple[List[str], List[int], List[str]]:
paths, idx, names = [], [], []
for k in self.data:
if k == "otf":
continue
for f in self.data[k]:
if not self.data[k][f]["path"].startswith("/whale"):
paths += [self.data[k][f]["path"]]
idx += [self.data[k][f]["idx"]]
names += [f]
# modify paths of non VRT files\
paths = [
f"/vsicurl/{p}"
if p
and p.startswith("http")
and not p.endswith("vrt")
and not p.startswith("/vsicurl/")
else p
for p in paths
]
return paths, idx, names
def get_unrolled_catalog(self) -> Tuple[List[str], List[str], List[int]]:
names, paths, idx = [], [], []
for k in self.data:
for f in self.data[k]:
paths += [self.data[k][f]["path"]]
idx += [self.data[k][f]["idx"]]
names += [f]
return names, paths, idx
@staticmethod
def get_whales(data):
whale_paths, whale_keys, whale_layer_names = [], [], []
for k in data:
if k == "otf":
continue
for f in data[k]:
if data[k][f]["path"].startswith("/whale"):
whale_paths += [data[k][f]["path"]]
whale_layer_names += [f]
whale_keys += [k]
return whale_paths, whale_keys, whale_layer_names
    def _get_whales(self):
        # Instance-level shortcut for the static whale scan over self.data.
        return self.get_whales(self.data)
    def query(self, feature_names, groups: Optional[list[str]] = None) -> None:
        """Filter the catalog in place, keeping only the given features
        within the given groups (plus their whale dependencies), and
        reassigning compact indices.

        Features absent from the original catalog are registered as
        on-the-fly ('otf') features.

        :param feature_names: feature names to keep
        :param groups: groups to keep; defaults to all non-common groups,
            and 'common' is always included
        :raises ValueError: if `groups` is given but is not a list
        """
        if groups is None:
            groups = self.get_groups()
        if not isinstance(groups, list):
            raise ValueError("Invalid `groups` parameter. Expecting a list.")
        if "common" not in groups:  # include 'common' group by default
            groups = groups + ["common"]
        old_data = self.data.copy()
        # Rebuild data from scratch with fresh incremental indices.
        self.data = {}
        self.data_size = 0
        for k in groups:
            if k in old_data:
                for f in feature_names:
                    if f in old_data[k]:
                        if k not in self.data:
                            self.data[k] = {}
                        self.data[k][f] = {
                            "path": old_data[k][f]["path"],
                            "idx": self.data_size,
                        }
                        self.data_size += 1
            else:
                print(f"Group {k} is missing from original catalog, skipping it")
        # Re-attach whale dependencies, resolved against the unfiltered data.
        self._expand_whales_dependencies(old_data)
        missing_features_names = [
            feature
            for feature in feature_names
            if feature not in set(self.get_feature_names())
        ]
        for missing_feat_feature in missing_features_names:
            print(
                f"WARNING: Feature {missing_feat_feature} is missing in the original catalog, adding is in the otf (on the fly) common group"
            )
        if missing_features_names:
            self.add_otf_features(missing_features_names)
def add_otf_features(self, otf_features) -> None:
if "otf" not in self.data:
self.data["otf"] = {}
for otf_feature in otf_features:
self.data["otf"][otf_feature] = {"path": None, "idx": self.data_size}
self.data_size += 1
    @classmethod
    def expand_whales_dependencies(cls, reference_catalog_data, data, data_size):
        """Add every layer that a whale layer depends on into `data`, and
        stamp whales with their execution order.

        :param reference_catalog_data: unfiltered catalog data used to
            resolve dependency paths
        :param data: catalog data to expand (mutated and returned)
        :param data_size: current number of indexed features
        :return: (data, data_size) after expansion
        """
        whale_paths, groups, whale_layer_names = cls.get_whales(data)
        for i, (whale_path, whale_layer_name) in enumerate(
            zip(whale_paths, whale_layer_names)
        ):
            dep_names, dep_paths, dep_exec_orders, dep_keys = get_whale_dependencies(
                whale_path, groups[i], reference_catalog_data, whale_layer_names[i]
            )
            for dep_name, dep_path, dep_exec_order, dep_key in zip(
                dep_names, dep_paths, dep_exec_orders, dep_keys
            ):
                if dep_name not in data[dep_key]:
                    # New dependency: register it with the next free index.
                    data[dep_key][dep_name] = {
                        "path": dep_path,
                        "idx": data_size,
                        "exec_order": dep_exec_order,
                    }
                    data_size += 1
                elif dep_name == whale_layer_name:
                    # Whale already present: only record the execution order
                    # derived from its dependency depth.
                    data[dep_key][dep_name]["exec_order"] = dep_exec_order
        return data, data_size
    def _expand_whales_dependencies(self, reference_catalog_data) -> None:
        # In-place variant: expand this catalog's own data/data_size.
        self.data, self.data_size = self.expand_whales_dependencies(
            reference_catalog_data, self.data, self.data_size
        )
def get_otf_idx(self):
otf_idx = {}
if "otf" in self.data:
for f in self.data["otf"]:
if f not in otf_idx:
otf_idx[f] = []
otf_idx[f] += [self.data["otf"][f]["idx"]]
return otf_idx
def _get_covs_idx(self, covs_lst: List[str]):
groups = self.get_groups()
covs_idx = np.zeros((len(covs_lst), len(groups)), np.int32)
for j in range(len(groups)):
k = groups[j]
for i in range(len(covs_lst)):
c = covs_lst[i]
if "common" in self.data and c in self.data["common"]:
covs_idx[i, j] = self.data["common"][c]["idx"]
elif c in self.data[k]:
covs_idx[i, j] = self.data[k][c]["idx"]
else:
covs_idx[i, j] = self.data["otf"][c]["idx"]
return covs_idx
#
def print_catalog_statistics(catalog: DataCatalog) -> None:
    """Print a short summary of a catalog: its groups, the number of rasters
    to read, the number of whale layers, and the on-the-fly features.

    :param catalog: the catalog to summarize
    """
    groups = sorted(catalog.data.keys())
    print(f"catalog groups: {groups}")
    print(f"- rasters to read: {len(catalog.get_paths()[0])}")
    # _get_whales() returns a (paths, keys, names) tuple, so count the paths
    # list — len() of the tuple itself would always be 3.
    print(f"- whales: {len(catalog._get_whales()[0])}")
    # Compute the otf mapping once instead of three times.
    otf_idx = catalog.get_otf_idx()
    print(f"- on-the-fly features: {len(otf_idx)}")
    if len(otf_idx) > 0:
        otf_list = sorted(otf_idx.keys())
        print(f"- otf list: {otf_list}")
    print("")
#
def get_whale_dependencies(whale, key, main_catalog, whale_layer_name):
    """Resolve the direct and recursive dependencies of a whale layer.

    Returns parallel lists (names, paths, exec_orders, keys); the whale
    itself is appended last, with an exec_order one above the deepest of its
    dependencies.

    :param whale: whale path of the form '/whale/<func_name>?...'
    :param key: catalog group the whale belongs to
    :param main_catalog: full catalog dict used to look dependencies up
    :param whale_layer_name: layer name of the whale itself
    """
    func_name, params = parse_template_whale(whale)
    dep_tags, dep_names, dep_paths, dep_exec_orders, dep_keys = [], [], [], [], []
    # Collect the layer names (tags) each whale function consumes.
    if func_name == "percentileAggregation":
        tag = params["entry_template"]
        for dt in params["dt"]:
            dep_tags += [tag.format(dt=dt)]
    elif func_name == "computeNormalizedDifference":
        dep_tags += [params["idx_plus"]]
        dep_tags += [params["idx_minus"]]
    elif func_name == "computeSavi":
        dep_tags += [params["idx_red"]]
        dep_tags += [params["idx_nir"]]
    elif func_name == "computeGeometricTemperature":
        dep_tags += [params["idx_latitude"]]
        dep_tags += [params["idx_elevation"]]
    elif func_name == "extractIndicator":
        dep_tags += [params["idx_layer"]]
    for dep_tag in dep_tags:
        # Dependencies live either in the whale's own group or in 'common'.
        if dep_tag in main_catalog[key]:
            dep_path = main_catalog[key][dep_tag]["path"]
            dep_key = key
        else:
            dep_path = main_catalog["common"][dep_tag]["path"]
            dep_key = "common"
        if dep_path.startswith("/whale"):
            # A whale depending on another whale: recurse so the nested
            # whale's own dependencies come first.
            rec_dep_names, rec_dep_paths, rec_dep_exec_orders, rec_dep_keys = (
                get_whale_dependencies(dep_path, dep_key, main_catalog, dep_tag)
            )
            dep_paths += rec_dep_paths
            dep_names += rec_dep_names
            dep_exec_orders += rec_dep_exec_orders
            dep_keys += rec_dep_keys
        else:
            dep_paths += [dep_path]
            dep_names += [dep_tag]
            dep_exec_orders += [0]
            dep_keys += [dep_key]
    # The whale itself runs after all of its dependencies.
    dep_paths += [whale]
    dep_names += [whale_layer_name]
    if dep_exec_orders:
        dep_exec_orders += [max(dep_exec_orders) + 1]
    else:
        dep_exec_orders += [0]
    dep_keys += [key]
    return dep_names, dep_paths, dep_exec_orders, dep_keys
def parse_template_whale(whale):
    """Parse a whale URI of the form '/whale/<func_name>?k=v&k2=a,b'.

    Comma-separated values become lists of strings; other values stay plain
    strings.

    :param whale: whale path, e.g. '/whale/computeSavi?idx_red=r&clip=0,1'
    :return: (func_name, params); func_name is None when the path does not
        match, params is {} when there is no query string
    """
    func_name_match = re.search(r"/whale/([^?]+)", whale)
    func_name = func_name_match.group(1) if func_name_match else None
    # Everything after the first '?' is the query string. partition()
    # tolerates a missing '?' (the original split("?")[1] raised IndexError).
    _, _, query_params_string = whale.partition("?")
    params = {}
    if query_params_string != "":
        for param in query_params_string.split("&"):
            # Split on the first '=' only so values may themselves contain '='.
            key, value = param.split("=", 1)
            # Split by commas to handle lists
            if "," in value:
                params[key] = value.split(",")
            else:
                params[key] = value
    return func_name, params
def run_whales(catalog: DataCatalog, array, n_threads: int, lat_info=None) -> None:
    """Compute all whale (on-the-fly) layers into `array`, in dependency
    order.

    :param catalog: catalog whose whale entries carry 'idx' and 'exec_order'
    :param array: 2-D float32 data matrix (layers x pixels), updated in place
        at each whale's 'idx' row
    :param n_threads: number of threads passed to the skmap bindings
    :param lat_info: 1-D numpy vector of latitudes with length
        array.shape[1]; required only by the 'getLatitude' whale
    """
    # Computing on the fly covariates
    whale_paths, whale_keys, whale_names = catalog._get_whales()
    # Whales may depend on other whales, so run them by increasing exec_order.
    max_exec_order = 0
    for whale_key, whale_name in zip(whale_keys, whale_names):
        max_exec_order = max(
            max_exec_order, int(catalog.data[whale_key][whale_name]["exec_order"])
        )
    for exec_order in range(max_exec_order + 1):
        for whale_path, whale_key, whale_name in zip(
            whale_paths, whale_keys, whale_names
        ):
            if exec_order != int(catalog.data[whale_key][whale_name]["exec_order"]):
                continue
            func_name, params = parse_template_whale(whale_path)
            whale_data_idx = catalog.data[whale_key][whale_name]["idx"]
            if func_name == "percentileAggregation":
                tag = params["entry_template"]
                in_idxs = []
                for dt in params["dt"]:
                    in_idxs += [catalog.data[whale_key][tag.format(dt=dt)]["idx"]]
                # @FIXME with this setting we do the sorting for percentiles N times for each used percentile
                # @FIXME implement the percentiles without the need of transposition
                array_sb = np.empty((len(in_idxs), array.shape[1]), dtype=np.float32)
                array_sb_t = np.empty(
                    (array_sb.shape[1], array_sb.shape[0]), dtype=np.float32
                )
                array_pct_t = np.empty((array_sb.shape[1], 1), dtype=np.float32)
                sb.selArrayRows(array, n_threads, array_sb, in_idxs)
                sb.transposeArray(array_sb, n_threads, array_sb_t)
                sb.computePercentiles(
                    array_sb_t,
                    n_threads,
                    range(len(in_idxs)),
                    array_pct_t,
                    [0],
                    [float(params["percentile"])],
                )
                array[whale_data_idx, :] = array_pct_t[:, 0]
            elif func_name == "computeNormalizedDifference":
                sb.computeNormalizedDifference(
                    array,
                    n_threads,
                    [catalog.data[whale_key][params["idx_plus"]]["idx"]],
                    [catalog.data[whale_key][params["idx_minus"]]["idx"]],
                    [whale_data_idx],
                    float(params["scale_plus"]),
                    float(params["scale_minus"]),
                    float(params["scale_result"]),
                    float(params["offset_result"]),
                    [float(params["clip"][0]), float(params["clip"][1])],
                )
            elif func_name == "computeSavi":
                sb.computeSavi(
                    array,
                    n_threads,
                    [catalog.data[whale_key][params["idx_red"]]["idx"]],
                    [catalog.data[whale_key][params["idx_nir"]]["idx"]],
                    [whale_data_idx],
                    float(params["scale_red"]),
                    float(params["scale_nir"]),
                    float(params["scale_result"]),
                    float(params["offset_result"]),
                    [float(params["clip"][0]), float(params["clip"][1])],
                )
            elif func_name == "extractIndicator":
                # Binary indicator: 1.0 where the source layer equals `code`.
                array[whale_data_idx, :] = (
                    array[catalog.data[whale_key][params["idx_layer"]]["idx"], :]
                    == float(params["code"])
                ).astype(np.float32)
            elif func_name == "getLatitude":
                # Fixed validation: the original checked
                # `not isinstance(lat_info, np.ndarray) and len(lat_info.shape) == 1 ...`,
                # which dereferenced .shape on non-arrays (AttributeError) and
                # silently skipped malformed ndarrays. Validate first, then
                # assign; raise otherwise.
                if (
                    isinstance(lat_info, np.ndarray)
                    and len(lat_info.shape) == 1
                    and lat_info.shape[0] == array.shape[1]
                ):
                    array[whale_data_idx, :] = lat_info
                else:
                    raise Exception(
                        "Information on how to get the latitude needs to be a numpy vector"
                    )
            elif func_name == "computeGeometricTemperature":
                day_of_year = mmdd_to_doy(str(params["day_of_year_mmdd"]))
                # Latitude/elevation layers may live in the whale's own group
                # or fall back to 'common'.
                if params["idx_latitude"] in catalog.data[whale_key]:
                    latitude = array[
                        catalog.data[whale_key][params["idx_latitude"]]["idx"], :
                    ]
                else:
                    latitude = array[
                        catalog.data["common"][params["idx_latitude"]]["idx"], :
                    ]
                if params["idx_elevation"] in catalog.data[whale_key]:
                    elevation = array[
                        catalog.data[whale_key][params["idx_elevation"]]["idx"], :
                    ]
                else:
                    elevation = array[
                        catalog.data["common"][params["idx_elevation"]]["idx"], :
                    ]
                sb.computeGeometricTemperature(
                    array,
                    n_threads,
                    latitude,
                    elevation,
                    float(params["elevation_scaling"]),
                    float(params["a"]),
                    float(params["b"]),
                    float(params["result_scaling"]),
                    [whale_data_idx],
                    [day_of_year],
                )
            else:
                sys.exit(f"The whale function {func_name} is not available")