"""Model-grid emulator utilities."""
# besta/grid/emulator.py
# -*- coding: utf-8 -*-
from __future__ import annotations
import os
import json
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
import numpy as np
from joblib import dump, load
from besta.utils import mkdir
from besta.grid.transforms import LinearStandardiser
# -----------------------------------------------------------------------------
# utils
# -----------------------------------------------------------------------------
def _ensure_2d(a: np.ndarray) -> np.ndarray:
a = np.asarray(a)
if a.ndim == 1:
return a[:, None]
return a
def _as_f32(a: np.ndarray) -> np.ndarray:
a = np.asarray(a)
return a.astype(np.float32, copy=False) if a.dtype != np.float32 else a
# -----------------------------------------------------------------------------
# TransformPack
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# EmulatorConfig
# -----------------------------------------------------------------------------
[docs]
@dataclass
class EmulatorConfig:
    """Settings that control how an :class:`Emulator` predicts values and errors.

    Attributes
    ----------
    y_space : str
        Label for the output space (e.g. ``"mag"``, ``"flux"``, ``"ew"``);
        stored as metadata only.
    standardize_y : bool
        When true, ``Emulator.predict`` maps model-space outputs to
        ``y_space`` through the emulator's transforms.
    predict_batch_size : int
        Maximum number of rows per ``model.predict`` call; a value ``<= 0``
        (or a falsy value) disables batching.
    version : str
        Version tag written alongside saved emulators.
    error_model : str
        Error model name, ``"zero"`` or ``"ensemble_std"``. Not consulted by
        ``Emulator.predict_error``, which always returns a constant error.
    error_floor : float
        Constant error floor, in ``y_space`` units.
    """

    # --- prediction settings ---
    y_space: str = "mag"  # metadata label only
    standardize_y: bool = True  # map model space -> y_space on predict
    predict_batch_size: int = 131072  # rows per model.predict call
    version: str = "EmulatorV5"  # persisted format tag

    # --- error settings ---
    error_model: str = "zero"  # "zero" | "ensemble_std"
    error_floor: float = 0.0  # constant floor, y_space units
# -----------------------------------------------------------------------------
# Emulator
# -----------------------------------------------------------------------------
[docs]
class Emulator:
    """Generic emulator wrapping a trained regressor and its transforms.

    Parameters
    ----------
    model : Any
        Trained regression model with a ``.predict(X)`` method.
    transforms : TransformPack
        Transforms for targets -> X and model_space -> y_space.
    config : EmulatorConfig, optional
        Emulator configuration; a default :class:`EmulatorConfig` is used
        when omitted.

    Methods
    -------
    predict(targets, return_model_space=False) -> np.ndarray
        Predict observables in y_space for given targets.
    predict_error(targets) -> np.ndarray
        Predict errors in y_space for given targets.
    predict_with_error(targets) -> Tuple[np.ndarray, np.ndarray]
        Predict observables and errors in y_space for given targets.
    save(outdir, name="emulator", compress=("lz4", 3)) -> Dict[str, str]
        Save the emulator to disk.
    load(outdir, name="emulator") -> Emulator
        Load an emulator from disk.
    """

    def __init__(
        self,
        model: Any,
        transforms: TransformPack,
        config: Optional[EmulatorConfig] = None,
    ):
        self.model = model
        self.transforms = transforms
        self.config = config or EmulatorConfig()

    def _predict_model_space(self, X: np.ndarray) -> np.ndarray:
        """Run the wrapped model on one batch; return a 2-D float32 array."""
        y = self.model.predict(X)
        return _as_f32(_ensure_2d(y))

    def predict(
        self, targets: np.ndarray, *, return_model_space: bool = False
    ) -> np.ndarray:
        """Predict observables for ``targets``.

        Parameters
        ----------
        targets : array-like
            Target parameters; a 1-D input is treated as a single column.
        return_model_space : bool, optional
            If true, return the raw (2-D, float32) model-space output instead
            of mapping it to y_space.

        Returns
        -------
        np.ndarray
            Predictions in y_space when ``config.standardize_y`` is true,
            otherwise the model-space output.
        """
        targets = _ensure_2d(targets)
        X_all = _as_f32(self.transforms.X_from_targets(targets))
        n_rows = X_all.shape[0]
        bs = (
            int(self.config.predict_batch_size) if self.config.predict_batch_size else 0
        )
        if bs <= 0 or n_rows <= bs:
            y_model = self._predict_model_space(X_all)
        else:
            # Evaluate the model on chunks of at most ``bs`` rows, then
            # reassemble the rows in their original order.
            y_model = np.vstack(
                [
                    self._predict_model_space(X_all[i : i + bs])
                    for i in range(0, n_rows, bs)
                ]
            )
        # Without standardisation, model space and y_space coincide.
        if return_model_space or not self.config.standardize_y:
            return y_model
        return self.transforms.y_from_model_space(y_model)

    def _error_like(self, y: np.ndarray) -> np.ndarray:
        """Return the constant float32 error array matching ``y``'s shape.

        Zeros, plus ``config.error_floor`` when it is positive. Only the
        "zero" error model is implemented here; ``config.error_model`` is not
        consulted.
        """
        err = np.zeros_like(y, dtype=np.float32)
        floor = self.config.error_floor
        if floor and floor > 0:
            err = err + np.float32(floor)
        return err

    def predict_error(self, targets: np.ndarray) -> np.ndarray:
        """Predict errors in y_space for ``targets``.

        The model is still evaluated, because the error array takes the
        shape of the prediction.
        """
        return self._error_like(self.predict(targets))

    def predict_with_error(self, targets: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Predict observables and errors in y_space.

        The model is evaluated once; the error is derived from that single
        prediction.
        """
        y = self.predict(targets)
        return y, self._error_like(y)

    # persistence
    def save(
        self,
        outdir: str,
        name: str = "emulator",
        compress: Union[Tuple[str, int], None] = ("lz4", 3),
    ) -> Dict[str, str]:
        """Save the model (joblib) and its metadata (JSON) under ``outdir``.

        Parameters
        ----------
        outdir : str
            Output directory, created via ``mkdir`` if needed.
        name : str, optional
            Base filename; writes ``{name}.joblib`` and ``{name}.json``.
        compress : tuple or None, optional
            Passed straight to ``joblib.dump``. NOTE(review): lz4 compression
            needs the optional ``lz4`` package — confirm it is a dependency.

        Returns
        -------
        dict
            ``{"model": <model path>, "meta": <metadata path>}``.
        """
        mkdir(outdir)
        model_path = os.path.join(outdir, f"{name}.joblib")
        meta_path = os.path.join(outdir, f"{name}.json")
        dump(self.model, model_path, compress=compress)
        meta = {
            "type": "Emulator",
            "config": asdict(self.config),
            "transforms": self.transforms.to_dict(),
        }
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2, sort_keys=True)
        return {"model": model_path, "meta": meta_path}

    @classmethod
    def load(cls, outdir: str, name: str = "emulator") -> "Emulator":
        """Load an emulator previously written by :meth:`save`.

        Reads ``{name}.json`` for the config and transforms, and
        ``{name}.joblib`` for the model. The model file is unpickled by
        joblib, so only load files from trusted sources.
        """
        model_path = os.path.join(outdir, f"{name}.joblib")
        meta_path = os.path.join(outdir, f"{name}.json")
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)
        # ``load`` here is the module-level joblib function, not this
        # classmethod: class scope is not visible inside method bodies.
        model = load(model_path)
        transforms = TransformPack.from_dict(meta["transforms"])
        config = EmulatorConfig(**meta["config"])
        return cls(model=model, transforms=transforms, config=config)
# -----------------------------------------------------------------------------
# EnsembleEmulator
# -----------------------------------------------------------------------------
[docs]
class EnsembleEmulator:
    """Ensemble wrapper around several :class:`Emulator` members.

    - ``predict()`` -> mean prediction across members in y_space
    - ``predict_error()`` -> std across members in y_space, combined in
      quadrature with an optional constant floor

    Parameters
    ----------
    members : list of Emulator
        Non-empty collection of members sharing identical transforms and
        ``y_space``. It is copied into a new list.
    error_floor : float, optional
        Constant error floor (y_space units), added in quadrature to the
        member spread.

    Raises
    ------
    ValueError
        If ``members`` is empty, or members disagree on transforms or y_space.
    """

    def __init__(self, members: List[Emulator], error_floor: float = 0.0):
        # Copy so later mutation of the caller's sequence cannot alter the
        # ensemble; this also makes the emptiness check valid for any iterable.
        members = list(members)
        if not members:
            raise ValueError("EnsembleEmulator requires at least one member.")
        self.members = members
        self.error_floor = float(error_floor)
        # sanity checks: identical transforms and y_space
        ref_t = members[0].transforms.to_dict()
        ref_space = members[0].config.y_space
        for m in members[1:]:
            if m.transforms.to_dict() != ref_t:
                raise ValueError(
                    "All ensemble members must share identical transforms."
                )
            if m.config.y_space != ref_space:
                raise ValueError("All ensemble members must share y_space.")
        self.transforms = members[0].transforms
        self.y_space = ref_space

    def _stack_predictions(self, targets: np.ndarray) -> np.ndarray:
        """Stack member predictions along a new leading axis (one per member)."""
        return np.stack([m.predict(targets) for m in self.members], axis=0)

    def _spread(self, P: np.ndarray, ddof: int) -> np.ndarray:
        """Return the float32 std over members (axis 0), floored in quadrature."""
        err = np.std(P, axis=0, ddof=ddof).astype(np.float32)
        if self.error_floor > 0:
            err = np.sqrt(err**2 + self.error_floor**2).astype(np.float32)
        return err

    def predict(self, targets: np.ndarray) -> np.ndarray:
        """Return the float32 mean of member predictions."""
        return np.mean(self._stack_predictions(targets), axis=0).astype(np.float32)

    def predict_error(self, targets: np.ndarray, ddof: int = 0) -> np.ndarray:
        """Return the member spread (``np.std`` with ``ddof``), plus floor."""
        return self._spread(self._stack_predictions(targets), ddof)

    def predict_with_error(
        self, targets: np.ndarray, ddof: int = 0
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(mean, error)`` from a single evaluation of every member."""
        P = self._stack_predictions(targets)
        mean = np.mean(P, axis=0).astype(np.float32)
        return mean, self._spread(P, ddof)

    def save(self, outdir: str, name: str = "ensemble") -> Dict[str, Any]:
        """Save each member to ``{outdir}/{name}_memberKK`` and write an index.

        The index ``{name}.json`` records member sub-directory names relative
        to ``outdir`` (``member_names``), so the saved directory can be moved
        or loaded from a different working directory. It also records the
        legacy ``member_dirs`` paths for older readers.

        Returns
        -------
        dict
            ``{"index": <index path>, "member_dirs": <list of member dirs>}``.
        """
        mkdir(outdir)
        member_names = [f"{name}_member{k:02d}" for k in range(len(self.members))]
        member_dirs = []
        for mem, sub in zip(self.members, member_names):
            d = os.path.join(outdir, sub)
            mem.save(d, name="emulator")
            member_dirs.append(d)
        index = {
            "type": "EnsembleEmulator",
            "name": name,
            # Paths as built at save time (relative to the caller's cwd when
            # ``outdir`` is relative); kept for backward compatibility.
            "member_dirs": member_dirs,
            # Sub-directory names relative to ``outdir``; preferred by load().
            "member_names": member_names,
            "error_floor": self.error_floor,
        }
        index_path = os.path.join(outdir, f"{name}.json")
        with open(index_path, "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)
        return {"index": index_path, "member_dirs": member_dirs}

    @classmethod
    def load(cls, outdir: str, name: str = "ensemble") -> "EnsembleEmulator":
        """Load an ensemble written by :meth:`save` from ``outdir``.

        Member directories are resolved relative to ``outdir`` when the index
        has ``member_names``; older indices fall back to the stored
        ``member_dirs`` paths verbatim.
        """
        index_path = os.path.join(outdir, f"{name}.json")
        with open(index_path, "r", encoding="utf-8") as f:
            index = json.load(f)
        names = index.get("member_names")
        if names is not None:
            member_dirs = [os.path.join(outdir, n) for n in names]
        else:
            # Legacy index: paths were stored exactly as built at save time.
            member_dirs = list(index["member_dirs"])
        members = [Emulator.load(d, name="emulator") for d in member_dirs]
        return cls(members=members, error_floor=float(index.get("error_floor", 0.0)))