Source code for skmap.load_config
from types import SimpleNamespace
from typing import Any, Dict
import yaml
from skmap.modeler import (
Classifier,
Modeler,
Regressor,
RFClassifier,
RFRegressor,
RFRegressorTrees,
)
# Maps the ``model_type`` string used in config files to the model class that
# ``parse_config`` instantiates for each entry of ``models_params``.
MODEL_REGISTRY: Dict[str, type] = dict(
    RFRegressor=RFRegressor,
    RFRegressorTrees=RFRegressorTrees,
    Modeler=Modeler,
    Classifier=Classifier,
    RFClassifier=RFClassifier,
    Regressor=Regressor,
)
class _SafeDict(dict):
"""
A dictionary subclass that returns '{key}' for missing keys when used
with str.format_map(). This prevents KeyError for placeholders.
"""
def __missing__(self, key: str) -> str:
return f"{{{key}}}"
def _recursive_format(item: Any, context: dict) -> Any:
"""
Recursively traverses a nested structure (dicts, lists), formats string
values using the context, and converts any string 'None' to the
Python None object.
"""
if isinstance(item, dict):
return {k: _recursive_format(v, context) for k, v in item.items()}
elif isinstance(item, list):
return [_recursive_format(v, context) for v in item]
elif isinstance(item, str):
if item == "None":
return None
return item.format_map(context)
else:
return item
def _to_hybrid_namespace(d: Dict) -> SimpleNamespace:
"""
Converts only the top level of a dictionary to a SimpleNamespace.
Nested dictionaries and lists remain as they are.
"""
ns = SimpleNamespace()
for key, value in d.items():
setattr(ns, key, value)
return ns
[docs]
def parse_config(yaml_path: str, *, max_passes: int = 5) -> SimpleNamespace:
    """
    Parse a YAML configuration file into a hybrid namespace.

    Only the top-level keys become attributes. Nested dictionaries and lists
    stay plain Python containers. String values are treated as ``str.format``
    templates resolved against the config's own top-level keys. String values
    equal to ``'None'`` are converted to ``None``.

    Processing steps:

    1. The ``models_params`` list is separated from the rest (base) config.
    2. The base config is formatted against itself repeatedly, so templates
       that reference other templates get resolved. Placeholders with no
       matching key are left untouched.
    3. Each ``models_params`` entry is formatted once against the base config
       merged with that entry (entry keys win). A model instance is then
       stored under ``"model"``. Its class is looked up in ``MODEL_REGISTRY``
       by the base config's ``model_type`` and constructed with the entry's
       ``model_path_template``.
    4. ``s3_params["s3_addresses"]`` is filled with one address per integer in
       ``range(gaia_addr_range["start"], gaia_addr_range["end"])`` (end
       exclusive). Each address is rendered from ``gaia_addr_range["template"]``
       via its ``{gaia_ip}`` placeholder.

    Args:
        yaml_path: The path to the input YAML file.
        max_passes: Maximum number of self-referential formatting passes over
            the base config (step 2). Formatting stops earlier once a pass
            changes nothing.

    Returns:
        A SimpleNamespace whose attributes are the top-level config keys.

    Raises:
        ValueError: If the YAML document is empty or not a mapping.
        KeyError: If a required key (``model_type``, ``model_path_template``,
            ``gaia_addr_range`` and its fields, ``s3_params``) is missing, or
            ``model_type`` is not registered in ``MODEL_REGISTRY``.
    """
    # Binary mode lets PyYAML detect UTF-8/UTF-16 itself (per the YAML spec)
    # instead of depending on the platform's locale encoding.
    with open(yaml_path, "rb") as f:
        data = yaml.safe_load(f)
    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a YAML mapping at the top level of {yaml_path!r}, "
            f"got {type(data).__name__}"
        )

    # 1. Separate base config from the models list. `or []` also covers an
    #    explicit empty value (`models_params:`), which YAML loads as None.
    base_config = {k: v for k, v in data.items() if k != "models_params"}
    models_params_list = data.get("models_params") or []

    # 2. Iteratively format the base configuration until nothing changes.
    #    Formatting only maps str -> str/None, so an unchanged pass means every
    #    further pass would be unchanged too; stopping early is exact.
    for _ in range(max_passes):
        formatted = _recursive_format(base_config, _SafeDict(base_config))
        if formatted == base_config:
            break
        base_config = formatted

    # 3. Process each dictionary within the models_params list.
    processed_models = []
    for model_dict in models_params_list:
        full_context = _SafeDict({**base_config, **model_dict})
        formatted_model = _recursive_format(model_dict, full_context)
        # Looked up per model so `model_type` is only required when there
        # is at least one model to build.
        model_cls = _lookup_model_class(base_config["model_type"])
        formatted_model["model"] = model_cls(formatted_model["model_path_template"])
        processed_models.append(formatted_model)

    # 4. Recombine into the final configuration dictionary.
    final_config_dict = base_config
    final_config_dict["models_params"] = processed_models
    gaia_range = final_config_dict["gaia_addr_range"]
    final_config_dict["s3_params"]["s3_addresses"] = [
        gaia_range["template"].format(gaia_ip=gaia_ip)
        for gaia_ip in range(gaia_range["start"], gaia_range["end"])
    ]

    # 5. Convert only the top level to a SimpleNamespace.
    return _to_hybrid_namespace(final_config_dict)


def _lookup_model_class(model_type: str) -> type:
    """Return the class registered under ``model_type`` in MODEL_REGISTRY.

    Raises:
        KeyError: If ``model_type`` is not registered; the message lists the
            valid choices.
    """
    try:
        return MODEL_REGISTRY[model_type]
    except KeyError:
        raise KeyError(
            f"Unknown model_type {model_type!r}; "
            f"expected one of {sorted(MODEL_REGISTRY)}"
        ) from None