"""
Model grid container and fitting machinery.
This module defines the ModelGrid class, which stores a grid of models with
their parameters and provides methods for fitting and evaluating these models.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple, Mapping, Iterator, Union
from abc import ABC, abstractmethod
from itertools import product
import os
import json
import pickle
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from astropy.table import Table, Column
from astropy.io import fits
import h5py
from scipy.spatial import cKDTree
from scipy.stats import gaussian_kde
from besta.grid.prob import (
Prior,
FlatPrior,
ObservableDependentPrior,
Likelihood,
GaussianProductLikelihood,
posterior_over_models as posterior_over_models_fn,
)
from besta.postprocess import (
enclosed_fraction_map,
pit_from_discrete_posterior,
pdf_stats,
photoz_metrics,
weighted_quantiles,
)
from .transforms import LinearStandardiser
from besta.utils import available_memory_bytes
from besta.logging import get_logger
# Default the native math libraries to one thread each. ``setdefault`` never
# overrides a value the user has already exported.
# NOTE(review): numpy/scipy are imported above this point, so these defaults
# may be read too late to affect thread pools already initialised in this
# process (child processes still inherit them) - confirm the intended effect.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
def _guess_slices(n_objects, n_observables, n_jobs, tasks_per_worker=6):
# fewer, larger slices when P is large
base_tasks = n_jobs * tasks_per_worker
scale = max(1, n_observables // 8)
T = max(1, base_tasks // scale)
# turn T into contiguous slices
q, r = divmod(n_objects, T)
s = 0
out = []
for i in range(T):
size = q + (1 if i < r else 0)
out.append((s, s + size))
s += size
return out
def _truncate_posterior_mass(
    cand_idx: np.ndarray,
    w: np.ndarray,
    *,
    keep_mass: float,
    min_candidates: int = 0,
    max_candidates: Optional[int] = None,
    keep_ties: bool = False,
    tie_rtol: float = 0.0,
    tie_atol: float = 0.0,
) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
    """
    Truncate a discrete posterior ``w`` over candidates to the smallest set
    of highest-weight candidates whose cumulative weight reaches ``keep_mass``.

    Parameters
    ----------
    cand_idx : numpy.ndarray
        Array of candidate identifiers aligned with ``w``; it is indexed with
        the positions of the retained weights.
    w : numpy.ndarray
        Posterior weights, one per candidate. The weights are used as given:
        the cumulative sums of ``w`` are compared directly against
        ``keep_mass`` and ``mass_dropped`` is computed as ``1 - mass_kept``.
        NOTE(review): this only has the documented meaning when ``w`` already
        sums to one; no normalisation of the input is performed here.
    keep_mass : float
        Target cumulative mass to retain. Clipped to ``[0, 1]``.
    min_candidates : int, optional
        Lower bound on the number of retained candidates (negative values are
        treated as 0). It is itself capped by ``max_candidates`` and by the
        number of available candidates.
    max_candidates : int or None, optional
        Upper bound on the number of retained candidates (negative values are
        treated as 0). ``None`` means no bound.
    keep_ties : bool, optional
        If ``True`` **and** ``max_candidates`` is ``None``, every candidate
        whose weight is at least ``cut_weight - tol`` is retained, where
        ``tol = tie_atol + tie_rtol * abs(cut_weight)``. When
        ``max_candidates`` is set, tie expansion is not applied.
    tie_rtol : float, optional
        Relative tolerance entering ``tol`` above.
    tie_atol : float, optional
        Absolute tolerance entering ``tol`` above.

    Returns
    -------
    cand_kept : numpy.ndarray
        Entries of ``cand_idx`` for the retained candidates, ordered by
        descending weight (except on the early-return paths, which return
        ``cand_idx`` unchanged or an empty slice of it).
    w_kept : numpy.ndarray
        Weights of the retained candidates divided by their sum. If that sum
        is non-finite or not positive, uniform weights are returned and
        ``mass_kept`` is reported as 0. The early-return paths return ``w``
        unchanged (or an empty slice) without renormalising.
    meta : dict
        Keys ``n_before``, ``n_after``, ``keep_mass`` (after clipping),
        ``mass_kept``, ``mass_dropped``, ``cut_weight`` (``nan`` when no cut
        was made), ``ess_before`` and ``ess_after``, where the effective
        sample sizes are ``1 / sum(weights**2)``.
    """
    n_before = w.size
    # Nothing to truncate: report an empty result.
    if n_before == 0:
        meta = {
            "n_before": 0,
            "n_after": 0,
            "keep_mass": keep_mass,
            "mass_kept": 0.0,
            "mass_dropped": 1.0,
            "cut_weight": np.nan,
            "ess_before": 0.0,
            "ess_after": 0.0,
        }
        return cand_idx, w, meta
    # Effective sample size
    ess_before = 1.0 / np.sum(np.square(w))
    # Sanitise the controls: mass in [0, 1], non-negative candidate bounds.
    keep_mass = np.clip(keep_mass, 0.0, 1.0)
    min_candidates = max(0, min_candidates)
    max_candidates = None if max_candidates is None else max(0, max_candidates)
    # k_cap: most candidates we may keep; k_floor: fewest we must keep.
    k_cap = n_before if max_candidates is None else min(n_before, max_candidates)
    k_floor = min(k_cap, min_candidates)
    # Fast path: everything is requested, so return the inputs untouched.
    if k_cap == n_before and keep_mass >= 1.0 and not keep_ties:
        meta = {
            "n_before": n_before,
            "n_after": n_before,
            "keep_mass": keep_mass,
            "mass_kept": 1.0,
            "mass_dropped": 0.0,
            "cut_weight": np.nan,
            "ess_before": ess_before,
            "ess_after": ess_before,
        }
        return cand_idx, w, meta
    # A cap of zero means nothing can be kept.
    if k_cap == 0:
        meta = {
            "n_before": n_before,
            "n_after": 0,
            "keep_mass": keep_mass,
            "mass_kept": 0.0,
            "mass_dropped": 1.0,
            "cut_weight": np.nan,
            "ess_before": ess_before,
            "ess_after": 0.0,
        }
        return cand_idx[:0], w[:0], meta
    if keep_mass <= 0.0 and k_floor == 0:
        # No mass requested and no minimum count: keep nothing.
        keep_local = np.empty(0, dtype=np.int64)
        cut_weight = np.nan
    else:
        # Look only at the ``probe`` largest weights first, doubling the
        # window until it holds enough mass or reaches the cap. This avoids a
        # full sort when a few candidates dominate.
        probe = min(n_before, max(k_floor, 32))
        w_sorted = None
        ord_desc = None
        cs = None
        while True:
            if probe >= n_before:
                ord_desc = np.argsort(w)[::-1]
            else:
                # Indices of the ``probe`` largest weights (unordered), then
                # sorted by descending weight.
                ord_desc = np.argpartition(w, n_before - probe)[-probe:]
                ord_desc = ord_desc[np.argsort(w[ord_desc])[::-1]]
            w_sorted = w[ord_desc]
            cs = np.cumsum(w_sorted)
            enough_mass = cs[-1] >= keep_mass if keep_mass > 0.0 else True
            if enough_mass or probe >= k_cap:
                break
            probe = min(k_cap, max(probe * 2, k_floor, 1))
        # First position where the cumulative weight reaches keep_mass.
        k = 0 if keep_mass <= 0.0 else np.searchsorted(cs, keep_mass, side="left") + 1
        k = min(max(k, k_floor), k_cap, ord_desc.size)
        cut_weight = np.nan if k == 0 else w_sorted[k - 1]
        if keep_ties and k > 0 and max_candidates is None:
            # Keep everything tied with, or heavier than, the cut weight.
            tol = tie_atol + tie_rtol * abs(cut_weight)
            keep_mask = np.abs(w - cut_weight) <= tol
            keep_mask |= w > (cut_weight + tol)
            keep_local = np.flatnonzero(keep_mask)
            keep_local = keep_local[np.argsort(w[keep_local])[::-1]]
        else:
            keep_local = ord_desc[:k]
    cand_kept = cand_idx[keep_local]
    w_kept_raw = w[keep_local]
    mass_kept = np.sum(w_kept_raw)
    # Re-normalize weights
    if not np.isfinite(mass_kept) or mass_kept <= 0:
        # Degenerate mass: fall back to uniform weights over the kept set.
        w_kept = np.full_like(w_kept_raw, 1.0 / max(1, w_kept_raw.size), dtype=float)
        mass_kept = 0.0
    else:
        w_kept = w_kept_raw / mass_kept
    ess_after = 1.0 / np.sum(np.square(w_kept)) if w_kept.size > 0 else 0.0
    meta = {
        "n_before": n_before,
        "n_after": w_kept.size,
        "keep_mass": keep_mass,
        "mass_kept": mass_kept,
        "mass_dropped": max(0.0, 1.0 - mass_kept),
        "cut_weight": cut_weight,
        "ess_before": ess_before,
        "ess_after": ess_after,
    }
    return cand_kept, w_kept, meta
# Base model grid class
[docs]
@dataclass
class ModelGrid:
    """
    Base container for a grid of models.

    This class stores:

    1) observables: array of shape (N, P)
    2) targets: array of shape (N, Q) with continuous quantities
    3) names and metadata for both
    4) optional per-model weights

    Attributes
    ----------
    observables : ndarray, shape (N, P)
        Observable vectors per model, e.g. colours or fluxes.
    targets : ndarray, shape (N, Q)
        Continuous targets per model, e.g. SFR, logM, metallicity.
    observable_names : list of str
        Names of the P observables in order.
    target_names : list of str
        Names of the Q targets in order.
    weights : ndarray, shape (N,), optional
        Optional sampling weights for models. Defaults to uniform if None.
    meta : dict, optional
        Free-form metadata for provenance, settings, and units.
    check_boundaries : callable, optional
        Function to check whether target queries are within valid boundaries.
        If None, no boundary checks are performed. ``in_boundaries`` calls it
        once with the whole query array of shape (M, Q) and uses the result
        as a boolean mask of shape (M,). This is only used during
        interpolation.
    observable_standardiser : LinearStandardiser
        Standardiser for observable quantities.
    target_standardiser : LinearStandardiser
        Standardiser for target quantities.

    Examples
    --------
    >>> grid = ModelGrid(
    ...     observables=np.array([[1.0, 0.5], [0.8, 0.3], [1.2, 0.7]]),
    ...     targets=np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
    ...     observable_names=["obs1", "obs2"],
    ...     target_names=["target1", "target2"]
    ... )
    >>> print(grid.n_models)
    3
    >>> print(grid.n_observables)
    2
    """

    observables: np.ndarray
    targets: np.ndarray
    observable_names: List[str]
    target_names: List[str]
    weights: Optional[np.ndarray] = None
    meta: Dict[str, Any] = field(default_factory=dict)
    check_boundaries: Optional["Callable[[np.ndarray], bool]"] = field(
        default=None, init=True, repr=False
    )
    # Cached stats for standardisation
    observable_standardiser: LinearStandardiser = field(
        default_factory=LinearStandardiser, init=False, repr=False
    )
    target_standardiser: LinearStandardiser = field(
        default_factory=LinearStandardiser, init=False, repr=False
    )
    # Lazily computed standardised copies of ``observables`` / ``targets``.
    _observables_standardized: Optional[np.ndarray] = field(
        default=None, init=False, repr=False
    )
    _targets_standardized: Optional[np.ndarray] = field(
        default=None, init=False, repr=False
    )
    # Cached KDTree over the targets, and whether it was built on
    # standardised coordinates (None while no tree is cached).
    _kdtree: Optional["cKDTree"] = field(default=None, init=False, repr=False)
    _kdtree_standardized: Optional[bool] = field(default=None, init=False, repr=False)
def __post_init__(self) -> None:
if self.observables.ndim != 2:
raise ValueError("observables must be 2D (N, P)")
if self.targets.ndim != 2:
raise ValueError("targets must be 2D (N, Q)")
if self.observables.shape[0] != self.targets.shape[0]:
raise ValueError("observables and targets must have same N")
if len(self.observable_names) != self.observables.shape[1]:
raise ValueError("observable_names length must match P")
if len(self.target_names) != self.targets.shape[1]:
raise ValueError("target_names length must match Q")
if self.weights is not None:
if self.weights.shape != (self.observables.shape[0],):
raise ValueError("weights must have shape (N,)")
if self.check_boundaries is not None and not callable(self.check_boundaries):
raise ValueError("check_boundaries must be callable if provided")
# ------------ basic properties ------------
@property
def n_models(self) -> int:
"""Number of models N."""
return self.observables.shape[0]
@property
def n_observables(self) -> int:
"""Number of observables P."""
return self.observables.shape[1]
@property
def n_targets(self) -> int:
"""Number of targets Q."""
return self.targets.shape[1]
@property
def grid_int(self) -> type:
"""Adaptative int type that accounts for model dimensionality."""
if self.n_models > np.iinfo(np.int32).max:
return np.int64
elif self.n_models > np.iinfo(np.int16).max:
return np.int32
else:
return np.int16
@property
def grid_float(self) -> type:
"""Adaptative float type that accounts for model precision."""
return self.observables.dtype
# ------------ views and subsets ------------
[docs]
def get_target(self, key):
"""Return a target column given a key."""
col = self.target_names.index(key)
return self.targets[:, col]
[docs]
def select(
    self, idx: np.ndarray, observables=None, targets=None, standardisers=True
) -> "ModelGrid":
    """
    Return a new ModelGrid containing a subset of models.

    Parameters
    ----------
    idx : ndarray of int, shape (K,)
        Indices to select.
    observables : list of str or list of int, optional, default=None
        Names or indices of observables to include. ``None`` or an empty
        sequence selects every observable.
    targets : list of str or list of int, optional, default=None
        Names or indices of targets to include. ``None`` or an empty
        sequence selects every target.
    standardisers : bool, optional, default=True
        Whether to propagate fitted standardisers (and any cached
        standardised arrays) to the subset.

    Returns
    -------
    sub : ModelGrid
        Subset grid (arrays are copied). ``check_boundaries`` is not
        carried over to the subset.

    Notes
    -----
    Grids with zero observables or zero targets are supported: the
    "select all" case is resolved positionally, so no element of an empty
    name list is inspected.
    """

    def _columns(requested, names):
        # None / empty means "all columns", resolved by position so that a
        # grid with no columns of this kind does not fail.
        if requested is None or len(requested) == 0:
            return list(range(len(names)))
        if isinstance(requested[0], str):
            return [names.index(n) for n in requested]
        # A list keeps 1-D fancy indexing valid even if a tuple was passed.
        return list(requested)

    obs_cols = _columns(observables, self.observable_names)
    tgt_cols = _columns(targets, self.target_names)

    w = None if self.weights is None else self.weights[idx].copy()
    sub = ModelGrid(
        observables=self.observables[idx][:, obs_cols].copy(),
        targets=self.targets[idx][:, tgt_cols].copy(),
        observable_names=[self.observable_names[o] for o in obs_cols],
        target_names=[self.target_names[t] for t in tgt_cols],
        weights=w,
        meta=dict(self.meta),
    )
    if standardisers:
        if self.observable_standardiser.is_fit:
            sub.observable_standardiser.mean = self.observable_standardiser.mean[
                obs_cols
            ].copy()
            sub.observable_standardiser.sd = self.observable_standardiser.sd[
                obs_cols
            ].copy()
        if self._observables_standardized is not None:
            sub._observables_standardized = self._observables_standardized[idx][
                :, obs_cols
            ].copy()
        if self.target_standardiser.is_fit:
            sub.target_standardiser.mean = self.target_standardiser.mean[
                tgt_cols
            ].copy()
            sub.target_standardiser.sd = self.target_standardiser.sd[tgt_cols].copy()
        if self._targets_standardized is not None:
            sub._targets_standardized = self._targets_standardized[idx][
                :, tgt_cols
            ].copy()
    return sub
# ------------ standardisation ------------
[docs]
def fit_target_standardiser(self, mask: Optional[np.ndarray] = None) -> None:
    """Fit the target standardiser and refresh the dependent caches.

    Parameters
    ----------
    mask : ndarray of bool, shape (N,), optional
        If provided, only ``targets[mask]`` is used to fit the statistics;
        the transform is still applied to all targets.

    Notes
    -----
    The cached KDTree is invalidated afterwards.
    """
    if mask is None:
        sample = self.targets
    else:
        sample = self.targets[mask]
    scaler = self.target_standardiser
    scaler.fit(sample, ddof=0)
    self._targets_standardized = scaler.transform(self.targets)
    self.invalidate_knn_index()
[docs]
def fit_standardiser(self, mask: Optional[np.ndarray] = None) -> None:
    """
    Fit the observable standardiser and cache the standardised observables.

    Parameters
    ----------
    mask : ndarray of bool, shape (N,), optional
        If provided, only ``observables[mask]`` is used to fit the
        statistics; the transform is still applied to all observables.
    """
    if mask is None:
        sample = self.observables
    else:
        sample = self.observables[mask]
    scaler = self.observable_standardiser
    scaler.fit(sample, ddof=0)
    self._observables_standardized = scaler.transform(self.observables)
[docs]
def _ensure_standardized_observables(self) -> np.ndarray:
"""Return cached standardized observables, computing them lazily if needed."""
if not self.observable_standardiser.is_fit:
raise RuntimeError(
"fit_standardiser must be called before using standardized observables"
)
if self._observables_standardized is None:
self._observables_standardized = self.observable_standardiser.transform(
self.observables
)
return self._observables_standardized
[docs]
def _ensure_standardized_targets(self) -> np.ndarray:
"""Return cached standardized targets, computing them lazily if needed."""
if not self.target_standardiser.is_fit:
raise RuntimeError(
"fit_target_standardiser must be called before using standardized targets"
)
if self._targets_standardized is None:
self._targets_standardized = self.target_standardiser.transform(
self.targets
)
return self._targets_standardized
# ------------ KDTree for KNN and interpolation ------------
[docs]
def invalidate_knn_index(self) -> None:
"""Invalidate cached KDTree."""
self._kdtree = None
self._kdtree_standardized = None
[docs]
def get_kdtree(self, *, standardize: bool = True):
"""Get a KDTree either stored in cash or built on the fly."""
if self._kdtree is not None and self._kdtree_standardized == standardize:
return self._kdtree
X = np.asarray(self.targets, dtype=float)
if X.ndim != 2 or X.shape[1] != self.n_targets:
raise ValueError("targets must be 2D (N, Q)")
if standardize:
if not self.target_standardiser.is_fit:
raise RuntimeError(
"target_standardiser is not fit. Call fit_target_standardiser() "
"before get_kdtree(standardize=True)."
)
Xn = self._ensure_standardized_targets()
else:
Xn = X
self._kdtree = cKDTree(Xn)
self._kdtree_standardized = standardize
return self._kdtree
[docs]
def in_boundaries(self, targets_query: np.ndarray) -> np.ndarray:
"""
Check which target queries are within the model grid boundaries.
Parameters
----------
targets_query : ndarray, shape (M, Q)
Target query points.
Returns
-------
in_bounds : ndarray of bool, shape (M,)
Whether each query is within the grid boundaries.
"""
if self.check_boundaries is None:
return np.ones(targets_query.shape[0], dtype=bool)
# in_bounds = np.ones(tq.shape[0], dtype=bool)
# for j in range(self.n_targets):
# tmin = np.min(self.targets[:, j])
# tmax = np.max(self.targets[:, j])
# in_bounds &= (tq[:, j] >= tmin) & (tq[:, j] <= tmax)
else:
in_bounds = self.check_boundaries(targets_query)
return in_bounds
[docs]
def interpolate_observables(
self,
targets_query: np.ndarray,
*,
method: Optional[str] = None,
fill_value: float = np.nan,
k: int = 32,
p: float = 2.0,
eps: float = 1e-12,
standardize: bool = True,
mode: str = "local_linear", # "idw" | "local_linear"
ridge: float = 1e-8, # Tikhonov regularization for stability
) -> np.ndarray:
"""
Interpolate observables for arbitrary target values using cached KDTree.
mode="nearest" : nearest-neighbour lookup
mode="idw" : inverse-distance weighted KNN
mode="local_linear": weighted local affine fit (exact for linear functions)
"""
tq = np.atleast_2d(np.asarray(targets_query, dtype=float))
if tq.shape[1] != self.n_targets:
raise ValueError(
f"targets_query must have shape (M, {self.n_targets}); got {tq.shape}"
)
if method is not None:
mode = method
mode = str(mode).lower()
if mode not in {"nearest", "idw", "local_linear"}:
raise ValueError(
f"Unknown mode={mode!r} (use 'nearest', 'idw' or 'local_linear')."
)
bad_q = ~np.isfinite(tq).all(axis=1) | ~self.in_boundaries(tq)
out = np.full((tq.shape[0], self.n_observables), fill_value, dtype=float)
if bad_q.all():
return out
if standardize:
if not self.target_standardiser.is_fit:
self.fit_target_standardiser()
tqn = self.target_standardiser.transform(tq)
else:
tqn = tq
tree = self.get_kdtree(standardize=standardize)
k_eff = 1 if mode == "nearest" else max(1, min(k, self.n_models))
if mode == "local_linear" and k_eff < (self.n_targets + 1):
raise ValueError(
"local_linear interpolation requires at least n_targets + 1 neighbours "
f"(got {k_eff}, need {self.n_targets + 1})."
)
dists, idx = tree.query(tqn[~bad_q], k=k_eff, workers=-1)
if k_eff == 1:
dists = dists[:, None]
idx = idx[:, None]
neigh_obs = self.observables[idx] # (Mgood, k, P)
neigh_tgt = (
self._ensure_standardized_targets()[idx] if standardize else self.targets[idx]
)
obs_q = np.empty((dists.shape[0], self.n_observables), dtype=float)
zero = dists <= eps
any_zero = zero.any(axis=1)
# Exact matches: average exact neighbors
if np.any(any_zero):
zrows = np.where(any_zero)[0]
for r in zrows:
m = zero[r]
obs_q[r] = np.nanmean(neigh_obs[r, m, :], axis=0)
# Non-exact rows
rows = np.where(~any_zero)[0]
if rows.size == 0:
out[~bad_q] = obs_q
return out
if mode == "nearest":
obs_q[rows] = neigh_obs[rows, 0, :]
elif mode == "idw":
w = 1.0 / (np.power(dists[rows], p) + eps)
wsum = np.sum(w, axis=1, keepdims=True)
w = np.where(wsum > 0, w / wsum, 0.0)
obs_q[rows] = np.einsum("nk,nkp->np", w, neigh_obs[rows])
elif mode == "local_linear":
# Weighted local affine fit per query:
# y ≈ a + B (x - xq)
# Implemented as weighted least squares on design matrix [1, (x - xq)]
D = self.n_targets
P = self.n_observables
for rr in rows:
xq = tqn[~bad_q][rr] # (D,)
Xn = neigh_tgt[rr] # (k, D)
Yn = neigh_obs[rr] # (k, P)
dn = dists[rr] # (k,)
# weights (same spirit as IDW)
w = 1.0 / (np.power(dn, p) + eps) # (k,)
# build design matrix: (k, 1+D)
A = np.empty((k_eff, 1 + D), dtype=float)
A[:, 0] = 1.0
A[:, 1:] = Xn - xq[None, :]
# Apply weights via sqrt(w)
sw = np.sqrt(w)[:, None] # (k,1)
Aw = A * sw # (k,1+D)
Yw = Yn * sw # (k,P)
# Solve (Aw^T Aw + ridge I) beta = Aw^T Yw
G = Aw.T @ Aw
if ridge > 0:
G = G + ridge * np.eye(G.shape[0])
rhs = Aw.T @ Yw
try:
beta = np.linalg.solve(G, rhs) # (1+D, P)
obs_q[rr] = beta[0] # prediction at xq (delta=0)
except np.linalg.LinAlgError:
logger.warning("Local linear interpolation failed; falling back to IDW.")
# Fallback to IDW if ill-conditioned
ww = w / (w.sum() + eps)
obs_q[rr] = ww @ Yn
out[~bad_q] = obs_q
return out
# ------------ I/O ------------
[docs]
@staticmethod
def _std_to_state(std) -> Dict[str, Any]:
"""Serialisable state for LinearStandardiser."""
if std is None:
return {"is_fit": False, "mean": None, "sd": None}
return {
"is_fit": bool(getattr(std, "is_fit", False)),
"mean": None
if getattr(std, "mean", None) is None
else np.asarray(std.mean),
"sd": None if getattr(std, "sd", None) is None else np.asarray(std.sd),
}
[docs]
@staticmethod
def _state_to_std(state: Mapping[str, Any], std) -> None:
"""Restore LinearStandardiser from state into an existing instance."""
if not state:
return
is_fit = bool(state.get("is_fit", False))
if not is_fit:
return
mean = state.get("mean", None)
sd = state.get("sd", None)
if mean is None or sd is None:
return
std.mean = np.asarray(mean, dtype=float).copy()
std.sd = np.asarray(sd, dtype=float).copy()
[docs]
def to_dict(self) -> Dict[str, Any]:
"""
Serialise grid to a dictionary that is np.savez / JSON-friendly
(except numpy arrays, which are fine for np.savez).
Notes
-----
- check_boundaries is NOT serialised (it is generally not portable).
If you need it, store a string key in ``meta`` and resolve via a
registry.
"""
return {
"observables": np.asarray(self.observables),
"targets": np.asarray(self.targets),
"observable_names": list(self.observable_names),
"target_names": list(self.target_names),
"weights": None if self.weights is None else np.asarray(self.weights),
"meta": dict(self.meta or {}),
"standardisers": {
"observables": self._std_to_state(self.observable_standardiser),
"targets": self._std_to_state(self.target_standardiser),
},
# Explicitly exclude runtime-only/cache/callables:
"check_boundaries_key": self.meta.get("check_boundaries_key", None),
}
[docs]
@classmethod
def from_dict(
    cls,
    d: Mapping[str, Any],
    *,
    check_boundaries: Optional["Callable[[np.ndarray], np.ndarray]"] = None,
) -> "ModelGrid":
    """
    Rebuild a ModelGrid from a mapping as produced by to_dict().

    Parameters
    ----------
    d : mapping
        Mapping produced by to_dict().
    check_boundaries : callable, optional
        Optional boundary function to attach at load time.
        (Not serialised by default.)

    Returns
    -------
    ModelGrid
        New grid; fitted standardiser statistics found under
        ``d["standardisers"]`` are restored.
    """
    raw_weights = d.get("weights", None)
    weights = None if raw_weights is None else np.asarray(d["weights"])
    meta = dict(d.get("meta", {}) or {})

    grid = cls(
        observables=np.asarray(d["observables"]),
        targets=np.asarray(d["targets"]),
        observable_names=list(d["observable_names"]),
        target_names=list(d["target_names"]),
        weights=weights,
        meta=meta,
        check_boundaries=check_boundaries,
    )

    saved = d.get("standardisers", {}) or {}
    for key, scaler in (
        ("observables", grid.observable_standardiser),
        ("targets", grid.target_standardiser),
    ):
        cls._state_to_std(saved.get(key, {}), scaler)
    return grid
[docs]
@classmethod
def from_fits_table(
    cls,
    path,
    observable_cols=None,
    target_cols=None,
    weight_col=None,
    row_mask=None,
    table_hdu=1,
    memmap=False,
    meta_key="MODELGRID_META",
):
    """
    Build a ModelGrid from an Astropy FITS table.

    Priority order:

    1) Auto-load if table.meta contains OBSNAME and TGTNAME written by
       to_fits_table and no explicit mappings are given.
    2) If explicit mappings (observable_cols / target_cols) are provided,
       use them.
    3) Fallback: load all numeric 1-D columns as observables; no targets.

    Parameters
    ----------
    path : str
        Path to the FITS file.
    observable_cols : list[str] or dict[str, str], optional
        Explicit mapping for observables. A dict maps the name used in the
        grid to the FITS column name. If None, try auto-discovery.
    target_cols : list[str] or dict[str, str], optional
        Explicit mapping for targets. If None, try auto-discovery.
    weight_col : str, optional
        Column name for per-model weights. If None, will auto-use "weight"
        if present. On the auto-load path a missing ``weight_col`` also
        falls back to "weight"; on the explicit path it raises ``KeyError``.
    row_mask : array_like of bool, optional
        Boolean mask of length N to select rows after reading.
    table_hdu : int or str, optional
        FITS HDU index or name containing the table. Default 1.
    memmap : bool, optional
        astropy Table read memmap. Default False.
    meta_key : str, optional
        Table.meta key with JSON meta blob written by to_fits_table.
        Default "MODELGRID_META".

    Returns
    -------
    grid : ModelGrid

    Raises
    ------
    KeyError
        If a requested column is missing.
    ValueError
        If a column is not 1-D, no usable columns exist, only targets were
        mapped, or ``row_mask`` has the wrong length.
    """
    try:
        from astropy.table import Table
    except Exception as e:
        raise ImportError("astropy is required for from_fits_table") from e

    t = Table.read(path, hdu=table_hdu, memmap=memmap)

    def _stack_named(names):
        # Stack the named 1-D columns side by side into an (N, len(names)) array.
        arrs = []
        for n in names:
            if n not in t.colnames:
                raise KeyError(f"Missing FITS column '{n}'")
            a = np.asarray(t[n])
            if a.ndim > 1:
                # allow masked scalars, but require 1-D per column
                a = np.squeeze(a)
            if a.ndim != 1:
                raise ValueError(
                    f"FITS column '{n}' must be 1-D; got shape {a.shape}"
                )
            arrs.append(a)
        return np.vstack(arrs).T  # (N, P)

    def _resolve_mapping(cols):
        # Return (array, grid names) for a list or {grid name: column} mapping.
        if cols is None:
            return None, None
        if isinstance(cols, dict):
            names_dst = list(cols.keys())
            names_src = list(cols.values())
        else:
            names_dst = list(cols)
            names_src = list(cols)
        return _stack_named(names_src), names_dst

    obs_names_meta = t.meta.get("OBSNAME")
    tgt_names_meta = t.meta.get("TGTNAME")
    auto_loaded = (
        obs_names_meta is not None
        and tgt_names_meta is not None
        and observable_cols is None
        and target_cols is None
    )

    if auto_loaded:
        # --- 1) Auto-discovery from metadata written by to_fits_table ---
        obs_names = [s for s in str(obs_names_meta).split(",") if s]
        tgt_names = [s for s in str(tgt_names_meta).split(",") if s]
        X = _stack_named(obs_names) if obs_names else np.empty((len(t), 0))
        Y = _stack_named(tgt_names) if tgt_names else np.empty((len(t), 0))
        w = None
        if weight_col is not None and weight_col in t.colnames:
            w = np.asarray(t[weight_col])
        elif "weight" in t.colnames:
            w = np.asarray(t["weight"])
    else:
        # --- 2) Explicit mappings (if provided) ---
        X, obs_names = _resolve_mapping(observable_cols)
        Y, tgt_names = _resolve_mapping(target_cols)
        if X is None and Y is None:
            # --- 3) Fallback: all numeric 1-D columns become observables ---
            obs_names = []
            for name in t.colnames:
                if name == "weight":
                    continue
                a = np.asarray(t[name])
                if a.ndim == 1 and np.issubdtype(a.dtype, np.number):
                    obs_names.append(name)
            if not obs_names:
                raise ValueError("No numeric 1-D columns found to load")
            X = _stack_named(obs_names)
            tgt_names = []
            Y = np.empty((X.shape[0], 0))
        elif X is None or obs_names is None:
            raise ValueError(
                "observable_cols must be provided if auto-discovery is not available"
            )
        elif Y is None or tgt_names is None:
            # allow targets empty
            tgt_names = []
            Y = np.empty((X.shape[0], 0))
        # weights
        w = None
        if weight_col is not None:
            if weight_col not in t.colnames:
                raise KeyError(f"weight_col '{weight_col}' not found")
            w = np.asarray(t[weight_col])
        elif "weight" in t.colnames:
            w = np.asarray(t["weight"])

    # Metadata blob, parsed once for both paths with the module-level
    # ``json``. A function-local ``import json`` here would make the name
    # local to the whole function and break any earlier use of it.
    meta = {}
    if meta_blob_raw := t.meta.get(meta_key):
        try:
            meta = json.loads(meta_blob_raw)
        except (TypeError, ValueError):
            # Best effort: an unreadable blob yields empty metadata.
            meta = {}

    # optional row selection
    if row_mask is not None:
        m = np.asarray(row_mask, dtype=bool)
        if m.shape[0] != X.shape[0]:
            raise ValueError("row_mask length mismatch")
        X = X[m]
        Y = Y[m]
        if w is not None:
            w = w[m]

    return cls(
        observables=X,
        targets=Y,
        observable_names=obs_names,
        target_names=tgt_names,
        weights=w,
        meta=dict(meta or {}),
    )
[docs]
@classmethod
def from_hdf5(
    cls,
    path,
    group="/modelgrid",
    observable_dsets=None,
    target_dsets=None,
    weight_dset=None,
    row_slice=None,
):
    """
    Build a ModelGrid from an HDF5 file, with automatic discovery.

    Priority order:

    1) Auto-load if the group contains subgroups 'observables' and 'targets'
       as written by to_hdf5 (reads names and meta from attributes) and no
       explicit mappings are given.
    2) If explicit mappings (observable_dsets / target_dsets) are provided,
       use them.
    3) Fallback: load all 1-D numeric datasets directly under the group as
       observables; no targets.

    Parameters
    ----------
    path : str
        Path to the HDF5 file.
    group : str, optional
        Group path to read from. Default "/modelgrid".
    observable_dsets : list[str] or dict[str, str], optional
        Explicit dataset mapping for observables (relative to group). A dict
        maps the name used in the grid to the dataset path.
    target_dsets : list[str] or dict[str, str], optional
        Explicit dataset mapping for targets (relative to group).
    weight_dset : str, optional
        Dataset name for weights (relative to group). If None, will use
        "weights" if present in group. On the auto-load path a missing
        ``weight_dset`` also falls back to "weights"; on the explicit path
        it raises ``KeyError``.
    row_slice : slice or array_like of int or bool, optional
        Optional selection of rows after reading.

    Returns
    -------
    grid : ModelGrid

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.
    KeyError
        If the group, a mapped dataset or an explicit weight dataset is
        missing.
    ValueError
        If a dataset is not 1-D after squeezing, no usable datasets are
        found, or only targets were mapped.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(path)

    def _ensure_1d(d, key):
        # Coerce to ndarray, squeeze singleton axes, and insist on 1-D.
        a = np.asarray(d)
        if a.ndim > 1:
            a = np.squeeze(a)
        if a.ndim != 1:
            raise ValueError(
                f"HDF5 dataset '{key}' must be 1-D; got shape {a.shape}"
            )
        return a

    with h5py.File(path, "r") as f:
        if group not in f:
            raise KeyError(f"group '{group}' not found in file")
        g = f[group]
        # --- 1) Auto-discovery (preferred layout written by to_hdf5) ---
        auto_loaded = False
        if (
            "observables" in g
            and "targets" in g
            and observable_dsets is None
            and target_dsets is None
        ):
            gob = g["observables"]
            tgt = g["targets"]
            # names from attributes if present, else dataset keys order
            try:
                obs_names = json.loads(g.attrs.get("observable_names", "[]"))
                tgt_names = json.loads(g.attrs.get("target_names", "[]"))
            except Exception:
                obs_names = []
                tgt_names = []
            if not obs_names:
                obs_names = list(gob.keys())
            if not tgt_names:
                tgt_names = list(tgt.keys())
            # read columns
            cols = [
                _ensure_1d(gob[name][...], f"{group}/observables/{name}")
                for name in obs_names
            ]
            # With no observable columns, fall back to the stored row count.
            X = (
                np.vstack(cols).T
                if cols
                else np.empty((g.attrs.get("n_models", 0), 0))
            )
            tcols = [
                _ensure_1d(tgt[name][...], f"{group}/targets/{name}")
                for name in tgt_names
            ]
            Y = np.vstack(tcols).T if tcols else np.empty((X.shape[0], 0))
            # weights
            w = None
            if weight_dset is not None and weight_dset in g:
                w = _ensure_1d(g[weight_dset][...], f"{group}/{weight_dset}")
            elif "weights" in g:
                w = _ensure_1d(g["weights"][...], f"{group}/weights")
            # meta
            try:
                meta = json.loads(g.attrs.get("meta", "{}"))
            except Exception:
                meta = {}
            auto_loaded = True
        # --- 2) Explicit mappings ---
        if not auto_loaded:

            def _full(k):  # relative to group
                return f"{group.rstrip('/')}/{k.lstrip('/')}"

            def _load_map(dsets):
                # Return (array, grid names) for a list or
                # {grid name: dataset path} mapping; (None, None) if absent.
                if dsets is None:
                    return None, None
                if isinstance(dsets, dict):
                    dst = list(dsets.keys())
                    src = [_full(v) for v in dsets.values()]
                else:
                    dst = list(dsets)
                    src = [_full(v) for v in dsets]
                cols = []
                for key, src_key in zip(dst, src):
                    if src_key not in f:
                        raise KeyError(f"Missing dataset '{src_key}'")
                    cols.append(_ensure_1d(f[src_key][...], src_key))
                X = np.vstack(cols).T
                return X, dst

            X, obs_names = _load_map(observable_dsets)
            Y, tgt_names = _load_map(target_dsets)
            # Fallback: all 1-D numeric datasets directly under group as observables
            if X is None and Y is None:
                names = []
                cols = []
                for k, d in g.items():
                    if isinstance(d, h5py.Dataset):
                        a = np.asarray(d[...])
                        if (
                            a.ndim == 1
                            and np.issubdtype(a.dtype, np.number)
                            and k != "weights"
                        ):
                            names.append(k)
                            cols.append(a)
                if not cols:
                    raise ValueError(
                        "No 1-D numeric datasets found to load under the group"
                    )
                X = np.vstack(cols).T
                obs_names = names
                Y = np.empty((X.shape[0], 0))
                tgt_names = []
            elif X is None or obs_names is None:
                raise ValueError(
                    "observable_dsets must be provided if auto-discovery is not available"
                )
            elif Y is None or tgt_names is None:
                Y = np.empty((X.shape[0], 0))
                tgt_names = []
            # weights
            w = None
            if weight_dset is not None:
                key = _full(weight_dset)
                if key not in f:
                    raise KeyError(f"weight_dset '{key}' not found")
                w = _ensure_1d(f[key][...], key)
            elif "weights" in g:
                w = _ensure_1d(g["weights"][...], f"{group}/weights")
            # meta from attribute if present
            try:
                meta = json.loads(g.attrs.get("meta", "{}"))
            except Exception:
                meta = {}
    # Optional row selection
    if row_slice is not None:
        X = X[row_slice]
        Y = Y[row_slice]
        if w is not None:
            w = w[row_slice]
    return cls(
        observables=X,
        targets=Y,
        observable_names=obs_names,
        target_names=tgt_names,
        weights=w,
        meta=dict(meta or {}),
    )
def to_fits_table(
    self,
    path,
    *,
    hdu_name: str = "MODELGRID",
    overwrite: bool = False,
    include_meta: bool = True,
    meta_key: str = "MODELGRID_META",
    fill_value=np.nan,
):
    """
    Write the grid to a FITS binary table with one column per quantity.

    Every observable and every target becomes one table column named after
    it; model weights, when present, are stored in a column called
    ``"weight"``. Bookkeeping entries (``NMODELS``, ``NOBS``, ``NTGT``,
    ``OBSNAME``, ``TGTNAME``) are added to the table metadata.

    Parameters
    ----------
    path : str
        Output FITS path.
    hdu_name : str, optional
        Name given to the binary-table HDU. Default "MODELGRID".
    overwrite : bool, optional
        If True, or if ``path`` does not exist, a fresh file
        (primary HDU + table) is written. Otherwise the existing file is
        opened in update mode and the HDU called ``hdu_name`` is replaced,
        or the table is appended when no such HDU exists.
    include_meta : bool, optional
        If True, store ``self.meta`` as a JSON string in the table
        metadata under ``meta_key``. Meta that cannot be JSON-encoded is
        replaced by an empty JSON object.
    meta_key : str, optional
        Metadata key used for the JSON-encoded meta.
    fill_value : scalar, optional
        Value substituted for masked entries of masked arrays.

    Raises
    ------
    ImportError
        If astropy is not available.
    ValueError
        If observables and targets have a different number of rows.
    """
    try:
        from astropy.table import Table, Column
        from astropy.io import fits
    except Exception as e:
        raise ImportError("astropy is required for to_fits_table") from e

    N, P = self.observables.shape
    Nt, Q = self.targets.shape
    if N != Nt:
        raise ValueError("observables and targets must have same number of rows")

    def _plain(values):
        # np.ma.filled has to see the (possibly masked) array itself:
        # converting with np.asarray first silently drops the mask, so the
        # masked entries would never be replaced by fill_value.
        return np.asarray(np.ma.filled(values, fill_value))

    tab = Table()
    for j, name in enumerate(self.observable_names):
        tab.add_column(Column(_plain(self.observables[:, j]), name=str(name)))
    for j, name in enumerate(self.target_names):
        tab.add_column(Column(_plain(self.targets[:, j]), name=str(name)))
    if self.weights is not None:
        tab.add_column(Column(_plain(self.weights), name="weight"))

    tab.meta = dict(tab.meta or {})
    tab.meta.update(
        {
            "NMODELS": int(N),
            "NOBS": int(P),
            "NTGT": int(Q),
            "OBSNAME": ",".join(map(str, self.observable_names)),
            "TGTNAME": ",".join(map(str, self.target_names)),
        }
    )
    if include_meta:
        try:
            tab.meta[meta_key] = json.dumps(self.meta or {}, ensure_ascii=True)
        except (TypeError, ValueError):
            # Best effort: meta that is not JSON-serialisable is dropped
            # rather than aborting the whole export.
            tab.meta[meta_key] = json.dumps({}, ensure_ascii=True)

    # Build HDUList: Primary + BinTable
    primary = fits.PrimaryHDU()
    bintable = fits.BinTableHDU(tab, name=str(hdu_name))
    if overwrite or (not os.path.exists(path)):
        fits.HDUList([primary, bintable]).writeto(path, overwrite=overwrite)
        return

    # Replace if HDU exists by name; else append.
    # NOTE(review): astropy upper-cases extension names by default, so the
    # comparison is done case-insensitively; otherwise a lower-case
    # ``hdu_name`` would never match and duplicates would be appended.
    wanted = str(hdu_name).upper()
    with fits.open(path, mode="update") as hdul:
        idx = None
        for i, h in enumerate(hdul):
            if str(getattr(h, "name", "")).upper() == wanted:
                idx = i
                break
        if idx is None:
            hdul.append(bintable)
        else:
            hdul[idx] = bintable
        hdul.flush()
[docs]
def to_hdf5(
    self,
    path,
    group="/modelgrid",
    overwrite=False,
    compression="gzip",
    compression_opts=4,
    chunks=True,
    include_meta=True,
):
    """
    Save the grid as an HDF5 group with per-column datasets.

    Layout
    ------
    ``{group}/observables/<name>``
        One 1-D dataset per observable.
    ``{group}/targets/<name>``
        One 1-D dataset per target.
    ``{group}/weights``
        Optional 1-D dataset if model weights are present.

    Attributes on ``{group}``
        ``observable_names`` and ``target_names`` (JSON lists),
        ``n_models``, ``n_observables``, ``n_targets``, ``standardisers``
        (JSON with the fit state, mean and sd of both standardisers) and
        ``meta`` (JSON, only when ``include_meta=True``).

    Parameters
    ----------
    path : str
        Output HDF5 path. File will be created if it does not exist.
    group : str, optional
        Group path where to store the grid. Default "/modelgrid".
    overwrite : bool, optional
        If True and the group exists, it will be deleted and recreated.
        Default False.
    compression : str or None, optional
        Compression for datasets (e.g., "gzip"). Default "gzip".
    compression_opts : int, optional
        Compression level if applicable. Default 4. Only forwarded to h5py
        when a compression filter that takes options is selected.
    chunks : bool or tuple, optional
        Enable chunking (True) or provide explicit chunk shape. Default True.
    include_meta : bool, optional
        If True, write JSON-encoded meta as a group attribute. Default True.

    Raises
    ------
    ImportError
        If h5py is not available.
    ValueError
        If shapes are inconsistent, or if ``group`` already exists and
        ``overwrite`` is False.
    """
    try:
        import h5py
    except Exception as e:
        raise ImportError("h5py is required for to_hdf5") from e

    N, P = self.observables.shape
    Nt, Q = self.targets.shape
    if N != Nt:
        raise ValueError("observables and targets must have same number of rows")

    # Dataset creation options. compression_opts is only passed together
    # with a filter that accepts it: h5py refuses options without a
    # compression method and for "lzf", so forwarding the default level
    # unconditionally would make compression=None / "lzf" unusable.
    ds_kwargs = {"chunks": chunks}
    if compression is not None:
        ds_kwargs["compression"] = compression
        if compression_opts is not None and compression != "lzf":
            ds_kwargs["compression_opts"] = compression_opts

    def _standardiser_state(std):
        # JSON-friendly snapshot; mean/sd are only read once fitted.
        fitted = std.is_fit
        return {
            "is_fit": fitted,
            "mean": std.mean.tolist() if fitted else None,
            "sd": std.sd.tolist() if fitted else None,
        }

    # Mode "a" opens the file read/write and creates it when missing.
    with h5py.File(path, "a") as f:
        # Handle group
        if group in f:
            if overwrite:
                del f[group]
            else:
                raise ValueError(
                    f"group '{group}' already exists; use overwrite=True"
                )
        g = f.create_group(group)
        gob = g.create_group("observables")
        tgt = g.create_group("targets")

        # Observables datasets
        for j, name in enumerate(self.observable_names):
            gob.create_dataset(
                str(name), data=np.asarray(self.observables[:, j]), **ds_kwargs
            )
        # Targets datasets
        for j, name in enumerate(self.target_names):
            tgt.create_dataset(
                str(name), data=np.asarray(self.targets[:, j]), **ds_kwargs
            )
        # Weights
        if self.weights is not None:
            g.create_dataset("weights", data=np.asarray(self.weights), **ds_kwargs)

        # Attributes
        g.attrs["n_models"] = int(N)
        g.attrs["n_observables"] = int(P)
        g.attrs["n_targets"] = int(Q)
        # Store names as JSON to avoid fixed-length string issues
        g.attrs["observable_names"] = json.dumps(
            list(self.observable_names), ensure_ascii=True
        )
        g.attrs["target_names"] = json.dumps(
            list(self.target_names), ensure_ascii=True
        )
        if include_meta:
            try:
                g.attrs["meta"] = json.dumps(self.meta or {}, ensure_ascii=True)
            except (TypeError, ValueError):
                # Best effort: non-serialisable meta is stored as "{}".
                g.attrs["meta"] = json.dumps({}, ensure_ascii=True)
        g.attrs["standardisers"] = json.dumps(
            {
                "observables": _standardiser_state(self.observable_standardiser),
                "targets": _standardiser_state(self.target_standardiser),
            },
            ensure_ascii=True,
        )
[docs]
@classmethod
def from_pickle(cls, path):
    """
    Restore a ModelGrid previously stored with pickle.

    Parameters
    ----------
    path : str
        Location of the pickle file.

    Returns
    -------
    grid : ModelGrid
        The unpickled object, guaranteed to be an instance of ``cls``.

    Raises
    ------
    TypeError
        If the unpickled object is not an instance of ``cls``.

    Notes
    -----
    Unpickling executes arbitrary code embedded in the file; only load
    pickle files that come from a trusted source.
    """
    with open(path, "rb") as handle:
        loaded = pickle.load(handle)
    if isinstance(loaded, cls):
        return loaded
    raise TypeError(
        f"Pickle file does not contain a ModelGrid; got {type(loaded)}"
    )
[docs]
def to_pickle(self, path):
    """
    Serialise this ModelGrid into a pickle file.

    Parameters
    ----------
    path : str
        Destination of the pickle file; it is opened in binary write mode,
        so an existing file at that location is replaced.
    """
    with open(path, "wb") as handle:
        # Same stream as pickle.dump(self, handle), default protocol.
        pickle.Pickler(handle).dump(self)
[docs]
@classmethod
def load_auto(cls, path, **kwargs):
    """
    Load a ModelGrid, choosing the loader from the file extension.

    The extension is compared case-insensitively:

    ``.fits``, ``.fit``, ``.fts``
        FITS table, read with ``from_fits_table``.
    ``.hdf5``, ``.h5``
        HDF5 file, read with ``from_hdf5``.
    ``.pkl``, ``.pickle``
        Pickle file, read with ``from_pickle``.

    Parameters
    ----------
    path : str
        Path to the file.
    **kwargs
        Extra keyword arguments forwarded to the FITS or HDF5 loader.
        They are not passed to the pickle loader.

    Returns
    -------
    grid : ModelGrid

    Raises
    ------
    ValueError
        If the extension is not one of the supported ones.
    """
    suffix = os.path.splitext(path)[1].lower()
    if suffix in {".fits", ".fit", ".fts"}:
        return cls.from_fits_table(path, **kwargs)
    if suffix in {".hdf5", ".h5"}:
        return cls.from_hdf5(path, **kwargs)
    if suffix in {".pkl", ".pickle"}:
        return cls.from_pickle(path)
    raise ValueError(f"Unsupported file extension '{suffix}' for auto-loading")
# ---------------------------------------------------------------------
# ModelGrid fitter
# ---------------------------------------------------------------------
[docs]
class GridFitter:
    r"""
    Bayesian fitter over a ModelGrid using optional Prior and Likelihood.

    The fitter computes posterior weights over candidate models by combining
    a user-provided likelihood p(x | model) with a prior p(model). It can
    then histogram targets to obtain marginal posteriors.

    Parameters
    ----------
    grid : ModelGrid
        Training model grid that defines the observable and target spaces.
    likelihood : Likelihood, optional
        Likelihood instance. Default is GaussianProductLikelihood.
    prior : Prior, optional
        Prior instance. Default is FlatPrior.
    use_standardised : bool, optional
        If True and the ModelGrid supports standardisation, evaluation can be
        done in standardised space for internal transforms. The likelihood is
        still called with native units unless you adapt it. Default True.

    Notes
    -----
    Posterior over models is computed as:

    .. math::

        \log w_i = \log p(x \mid \mathrm{model}_i) +
                   \log p(\mathrm{model}_i) +
                   \log w_i^{\mathrm{model}}

    and then normalised to sum to one over the candidate set.
    """

    # The class docstring is a raw string: it contains LaTeX backslashes
    # (\log, \mid, \mathrm) that are invalid escape sequences otherwise.

    def __init__(
        self,
        grid,
        likelihood: Optional[Likelihood] = None,
        prior: Optional[Prior] = None,
        use_standardised: bool = True,
    ) -> None:
        """
        Store the grid and the probabilistic components of the fit.

        When ``use_standardised`` is True and the grid exposes
        ``fit_standardiser``, that method is called here so the grid is
        ready for evaluation in standardised space.
        """
        self.grid = grid
        self.likelihood = (
            likelihood if likelihood is not None else GaussianProductLikelihood()
        )
        self.prior = prior if prior is not None else FlatPrior()
        self.use_standardised = use_standardised
        if self.use_standardised and hasattr(self.grid, "fit_standardiser"):
            self.grid.fit_standardiser()
[docs]
def posterior_over_models(
    self,
    x_native: np.ndarray,
    sigma_native: np.ndarray,
    candidate_idx: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Compute posterior weights over candidate models for one query.

    The candidate observables, targets and (if present) per-model weights
    are gathered from the grid and handed to ``posterior_over_models_fn``
    together with the likelihood and prior. If ``use_standardised`` is True
    and the grid exposes ``transform_observables``, the query and the model
    observables are passed in the grid's standardised space and the input
    uncertainties are divided by the standardiser's ``sd``. Priors that
    depend on observables always receive native (non-standardised) model
    observables.

    Parameters
    ----------
    x_native : ndarray, shape (P,)
        Query observables in native units and ordering matching
        grid.observable_names.
    sigma_native : ndarray, shape (P,)
        Per-dimension measurement uncertainties for the query in native
        units.
    candidate_idx : ndarray of int, optional
        Indices of candidate models to consider. If None, all models in
        the grid are used.

    Returns
    -------
    w : ndarray, shape (Nc,)
        Posterior weights over the candidate set, as returned by
        ``posterior_over_models_fn``.

    Raises
    ------
    RuntimeError
        If standardised evaluation is requested and the grid's observable
        standardiser has not been fitted. The check happens before any
        transformation is attempted.

    See Also
    --------
    posterior_over_models_fn : Backend routine that combines likelihood and prior.
    """
    # Select model candidates
    idx = None if candidate_idx is None else np.asarray(candidate_idx)
    grid = self.grid
    x_models_native = grid.observables if idx is None else grid.observables[idx]
    Tc = grid.targets if idx is None else grid.targets[idx]
    weights = getattr(grid, "weights", None)
    if weights is None or idx is None:
        wc = weights
    else:
        wc = weights[idx]

    # Map to evaluation space if requested
    if self.use_standardised and hasattr(grid, "transform_observables"):
        # Fail fast with the documented error instead of whatever the
        # transform would raise on an unfitted standardiser.
        if not grid.observable_standardiser.is_fit:
            raise RuntimeError("standardiser is not fitted")
        x_eval = grid.transform_observables(x_native)
        X_std = grid._ensure_standardized_observables()
        x_models_eval = X_std if idx is None else X_std[idx]
        sigma_eval = sigma_native / grid.observable_standardiser.sd
    else:
        x_eval = x_native
        x_models_eval = x_models_native
        sigma_eval = sigma_native

    # Priors that depend on observables must see native observables,
    # whichever space the likelihood is evaluated in.
    prior_needs_obs = isinstance(self.prior, ObservableDependentPrior)
    return posterior_over_models_fn(
        x_native=x_eval,
        sigma_native=sigma_eval,
        X_models=x_models_eval,
        targets_models=Tc,
        likelihood=self.likelihood,
        prior=self.prior,
        model_weights=wc,
        prior_needs_observables=prior_needs_obs,
        observables_models=x_models_native,
    )
[docs]
def posterior_over_target(
    self,
    x_native: np.ndarray,
    sigma_native: np.ndarray,
    target_col: int | str,
    bins: np.ndarray,
    candidate_idx: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute a histogrammed posterior density over one target.

    Parameters
    ----------
    x_native : ndarray, shape (P,)
        Query observables in native units.
    sigma_native : ndarray, shape (P,)
        Per-dimension uncertainties in native units.
    target_col : int or str
        Column index or name in grid.targets.
    bins : array-like, shape (K+1,)
        Bin edges.
    candidate_idx : ndarray of int, optional
        Candidate model indices. If None, all models are used.

    Returns
    -------
    post : ndarray, shape (K,)
        Posterior density per bin, normalised so that
        ``sum(post * diff(bins)) == 1``. If no posterior weight falls
        inside the bins, a uniform density over ``[bins[0], bins[-1]]``
        is returned.
    centers : ndarray, shape (K,)
        Bin centers.

    Raises
    ------
    KeyError
        If ``target_col`` is a name that is not in grid.target_names.
    """
    if isinstance(target_col, str):
        try:
            j = self.grid.target_names.index(target_col)
        except ValueError as e:
            raise KeyError(f"Unknown target column '{target_col}'") from e
    else:
        j = target_col

    # candidate_idx=None already means "all models" downstream, so the
    # full grid is used directly instead of being copied via fancy indexing.
    w = self.posterior_over_models(
        x_native=x_native, sigma_native=sigma_native, candidate_idx=candidate_idx
    )
    if candidate_idx is None:
        y = self.grid.targets[:, j]
    else:
        y = self.grid.targets[np.asarray(candidate_idx), j]

    edges = np.asarray(bins, dtype=float)
    widths = np.diff(edges)
    # Accumulate raw weight per bin and normalise by hand: this avoids the
    # 0/0 (NaN + RuntimeWarning) of density=True when nothing is in range.
    mass, _ = np.histogram(y, bins=edges, weights=w)
    total = np.sum(mass)
    if total > 0 and np.isfinite(total):
        post = mass / (total * widths)
    else:
        # Uniform *density*, consistent with the normal branch.
        post = np.full(mass.shape, 1.0 / (edges[-1] - edges[0]), dtype=float)
    centers = 0.5 * (edges[:-1] + edges[1:])
    return post, centers
[docs]
def fit_batch(
    self,
    X_native: np.ndarray,
    SIG_native: np.ndarray,
    binner: Any | None = None,
    n_jobs: int = 1,
    backend: str = "thread",
    stats_for: Optional[Sequence[int | str]] = None,
    stats_bins: Optional[Sequence[np.ndarray]] = None,
    find_multimodal: bool = False,
    use_kde_for_stats: bool = False,
    return_pdf_for_stats: bool = False,
    verbose: bool = True,
    max_memory_gb: Optional[float] = 16.0,
    memcheck_sample: int = 256,
    safety_margin: float = 1.2,
    tasks_per_worker: Optional[int] = None,
    dry_run: bool = False,
    # Prune posterior samples
    posterior_keep_mass: Optional[float] = 0.99,
    posterior_keep_min_candidates: int = 128,
    posterior_keep_max_candidates: Optional[int] = None,
    posterior_keep_ties: bool = False,
    return_mode: str = "iter",  # "iter" | "list"
    # Save results into HDF5
    output_hdf5_path: Optional[str] = None,
    output_hdf5_group: str = "/fit_batch",
    output_hdf5_overwrite: bool = False,
    output_hdf5_compression: Optional[str] = "gzip",
    output_hdf5_compression_opts: int = 4,
    output_hdf5_flush_every: int = 256,
    output_hdf5_write_only: bool = False,
) -> Union[Iterator[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Evaluate model posteriors for multiple queries.

    Workflow
    --------
    The queries are split into contiguous slices (``_guess_slices``) and
    every slice is processed by one worker call which, for each query ``m``:

    1. selects candidate models with ``binner.candidates`` (the full grid
       is used when ``binner`` is None or returns no candidates);
    2. computes posterior weights over those candidates with
       ``posterior_over_models``;
    3. optionally truncates the posterior (``posterior_keep_*``) or, when
       truncation is off, renormalises the weights;
    4. optionally computes per-target statistics from a histogram (or KDE)
       of the kept posterior.

    Results are written to HDF5 (if requested) and yielded from the calling
    thread. With ``n_jobs > 1`` slices complete in arbitrary order, so the
    results are not guaranteed to be sorted by ``m``.

    Parameters
    ----------
    X_native : ndarray, shape (M, P)
        Query observables in native units. Order must match the grid.
    SIG_native : ndarray, shape (M, P)
        Per-dimension uncertainties in native units.
    binner : object, optional
        Must expose ``.candidates(y_native, sigmas_native)`` returning
        ``(indices, level)``.
    n_jobs : int, optional
        Number of workers. Values below 2 run the slices serially.
    backend : {"thread", "process"}, optional
        Executor type used when ``n_jobs > 1``. If the objects required by
        a process pool cannot be pickled, threads are used instead.
    stats_for : sequence of {int, str}, optional
        Target columns (indices or names) to summarise. If None, no
        statistics are computed.
    stats_bins : sequence of ndarray, optional
        Per-target bin edges. If None, or if its length differs from
        ``stats_for``, 200 linearly spaced edges between the grid minimum
        and maximum of each target are used instead.
    find_multimodal : bool, optional
        Forwarded to ``pdf_stats``; when True the number of reported modes
        is stored as ``"nmodes"`` (otherwise 1).
    use_kde_for_stats : bool, optional
        If True, evaluate a weighted Gaussian KDE at the bin centres
        instead of a histogram; falls back to the histogram when the KDE
        cannot be built.
    return_pdf_for_stats : bool, optional
        If True, include the per-target discrete posterior in each result.
    verbose : bool, optional
        Log a run summary and progress messages.
    max_memory_gb, memcheck_sample, safety_margin : optional
        Accepted for backward compatibility; the memory pre-check they
        configured is currently disabled and they have no effect.
    tasks_per_worker : int or None, optional
        Number of slices queued per worker (passed to ``_guess_slices``).
        None means 8.
    dry_run : bool, optional
        If True, validate the inputs, log the summary and return an empty
        iterator (``return_mode="iter"``) or an empty list without
        evaluating anything.
    posterior_keep_mass : float or None, optional
        Posterior mass to keep when truncating. Truncation is applied only
        if this is not None and either it is below 0.9999 or
        ``posterior_keep_max_candidates`` is set.
    posterior_keep_min_candidates, posterior_keep_max_candidates, posterior_keep_ties
        Forwarded to ``_truncate_posterior_mass``.
    return_mode : {"iter", "list"}, optional
        Return a lazy iterator or a fully evaluated list.
    output_hdf5_path : str, optional
        If given, results are also streamed to this file through
        ``GridFitHDF5Writer``; the remaining ``output_hdf5_*`` arguments
        are forwarded to it.
    output_hdf5_write_only : bool, optional
        If True, results are written but not yielded.

    Returns
    -------
    iterator or list of dict
        One dict per query with keys ``"m"``, ``"level"``, ``"candidates"``,
        ``"post_models"``, ``"truncation"`` and, when requested,
        ``"stats"`` and ``"posts_target"``.

    Raises
    ------
    ValueError
        If ``return_mode`` or ``backend`` is invalid, or if the inputs are
        not 2-D arrays of identical shape. These checks run before any
        work is done or any file is opened.

    Notes
    -----
    With ``return_mode="iter"`` nothing is evaluated until the iterator is
    consumed; the HDF5 writer is closed when the iteration finishes.
    If ``SIG_native`` contains zeros/near-zeros, pre-clip or configure a
    positive floor in your Likelihood to avoid degenerate bandwidths.
    """
    # Validate cheap options first so that a typo cannot trigger a full
    # run or leave an HDF5 writer open.
    if return_mode not in ("iter", "list"):
        raise ValueError("return_mode must be 'iter' or 'list'")
    if backend not in ("thread", "process"):
        raise ValueError("backend must be 'thread' or 'process'")

    X_native = np.asarray(X_native)
    SIG_native = np.asarray(SIG_native)
    if X_native.ndim != 2 or SIG_native.ndim != 2:
        raise ValueError("X_native and SIG_native must be 2D arrays (M, P)")
    if X_native.shape != SIG_native.shape:
        raise ValueError("X_native and SIG_native must have the same shape (M, P)")
    # M=queries, P=Observables
    M, P = X_native.shape
    n_workers = max(1, int(n_jobs))

    # OTF statistics: resolve target indices/names once
    stats_j: List[int] = []
    stats_key: List[str] = []
    if stats_for is not None:
        for tname in stats_for:
            if isinstance(tname, str):
                j = self.grid.target_names.index(tname)
                key = tname
            else:
                j = int(tname)
                key = self.grid.target_names[j]
            stats_j.append(j)
            stats_key.append(key)
        if stats_bins is None or len(stats_for) != len(stats_bins):
            if stats_bins is None:
                logger.info("stats_bins not provided; creating bins based on grid.")
            else:
                logger.warning(
                    "stats_bins has %d entries but stats_for has %d; "
                    "ignoring them and creating bins based on grid.",
                    len(stats_bins),
                    len(stats_for),
                )
            stats_bins = [
                np.linspace(
                    self.grid.targets[:, j].min(),
                    self.grid.targets[:, j].max(),
                    200,
                )
                for j in stats_j
            ]
        else:
            logger.info("Using user-provided stats_bins.")

    # Candidate selection helper
    def _select_candidates_helper(m: int) -> Tuple[np.ndarray, Optional[int]]:
        if binner is None:
            return np.arange(self.grid.n_models, dtype=self.grid.grid_int), None
        # NOTE(review): the full observable vector is passed to the binner;
        # confirm whether it expects only the ``binner.dims`` columns.
        idx, lev = binner.candidates(
            y_native=X_native[m],
            sigmas_native=SIG_native[m],
        )
        idx = np.asarray(idx)
        if idx.size == 0:
            # Empty candidate set: fall back to the whole grid.
            idx = np.arange(self.grid.n_models, dtype=self.grid.grid_int)
        return idx, lev

    should_truncate = posterior_keep_mass is not None and (
        posterior_keep_mass < 0.9999 or posterior_keep_max_candidates is not None
    )

    # Slice worker: evaluates queries s..e-1 and returns their result dicts
    def _slice_worker(s: int, e: int) -> List[Dict[str, Any]]:
        out = []
        for m in range(s, e):
            cand_idx, lev = _select_candidates_helper(m)
            w_full = self.posterior_over_models(
                x_native=X_native[m],
                sigma_native=SIG_native[m],
                candidate_idx=cand_idx,
            )
            if should_truncate:
                cand_kept, w_kept, tmeta = _truncate_posterior_mass(
                    cand_idx,
                    w_full,
                    keep_mass=posterior_keep_mass,
                    min_candidates=posterior_keep_min_candidates,
                    max_candidates=posterior_keep_max_candidates,
                    keep_ties=posterior_keep_ties,
                )
            else:
                cand_kept = cand_idx
                # ensure normalized
                # TODO: this check should be done by the master posterior_over_models
                ssum = np.sum(w_full)
                if not np.isfinite(ssum) or ssum <= 0:
                    # Degenerate weights: fall back to a uniform posterior.
                    w_kept = np.full_like(
                        w_full, 1.0 / max(1, w_full.size), dtype=float
                    )
                else:
                    w_kept = w_full / ssum
                ess = 1.0 / np.sum(np.square(w_kept)) if w_kept.size else 0.0
                tmeta = {
                    "n_before": w_full.size,
                    "n_after": w_kept.size,
                    "keep_mass": 1.0,
                    "mass_kept": 1.0,
                    "mass_dropped": 0.0,
                    "cut_weight": np.nan,
                    "ess_before": ess,
                    "ess_after": ess,
                }

            # compute stats on the kept posterior
            stats_out = None
            posts_target_out = None
            if stats_for is not None:
                stats_out = {}
                if return_pdf_for_stats:
                    posts_target_out = {}
                Y_cov = self.grid.targets[cand_kept][:, stats_j]
                if Y_cov.ndim == 1:
                    Y_cov = Y_cov[:, None]
                cov = np.cov(Y_cov.T, aweights=w_kept)
                if np.ndim(cov) == 0:
                    cov = np.array([[float(cov)]])
                elif np.ndim(cov) == 1:
                    cov = np.diag(cov)
                stats_out["cov"] = cov
                for key, j, bins in zip(stats_key, stats_j, stats_bins):
                    bins = np.asarray(bins, dtype=float)
                    y = self.grid.targets[cand_kept, j]
                    kde = None
                    if use_kde_for_stats:
                        try:
                            kde = gaussian_kde(y, weights=w_kept)
                        except (np.linalg.LinAlgError, ValueError):
                            # Singular covariance or too few points:
                            # fall back to the histogram below.
                            logger.warning(
                                "KDE failed for stats target '%s'; falling back to histogram.",
                                key,
                            )
                    if kde is None:
                        post, _ = np.histogram(
                            y, bins=bins, weights=w_kept, density=True
                        )
                    else:
                        post = kde.evaluate(0.5 * (bins[:-1] + bins[1:]))
                    st = pdf_stats(bins, post, find_multimodal=find_multimodal)
                    stats_out[key] = {
                        "mean": st["mean"],
                        "std": st["std"],
                        "map": st["map"],
                        "median": st["median"],
                        "lo68": st["lo68"],
                        "hi68": st["hi68"],
                        "nmodes": (
                            len(st.get("modes", [st["map"]]))
                            if find_multimodal
                            else 1
                        ),
                    }
                    if return_pdf_for_stats:
                        posts_target_out[key] = post

            # Results summary
            r = {
                "m": int(m),
                "level": lev,
                "candidates": cand_kept,
                "post_models": w_kept,
                "truncation": tmeta,
            }
            if stats_out is not None:
                r["stats"] = stats_out
            if posts_target_out is not None:
                r["posts_target"] = posts_target_out
            out.append(r)
        return out

    # Backend selection. A process pool has to pickle the submitted
    # callable; the slice worker is a local closure, which the standard
    # pickler cannot serialise, so this check is expected to fall back to
    # threads until the worker is made picklable.
    if backend == "process":
        try:
            import pickle as _p

            _p.dumps((self.likelihood, self.prior))
            _p.dumps(_slice_worker)
        except Exception as e:
            if verbose:
                logger.warning(
                    "[fit_batch] worker/prior/likelihood are not picklable; "
                    "falling back to thread backend. Reason: %r",
                    e,
                )
            backend = "thread"

    # Organise tasks per worker and slices
    tpw = 8 if (tasks_per_worker is None) else max(1, int(tasks_per_worker))
    slices = _guess_slices(M, P, n_workers, tasks_per_worker=tpw)

    # Pre-flight summary
    if verbose:
        logger.info(
            f"Starting batch fit: #queries={M}, #observables={P}, "
            f"batch backend={backend}, n_jobs={n_jobs}, job tasks={len(slices)}"
        )
        if posterior_keep_mass is not None:
            logger.info(
                f" - truncation: keep_mass={posterior_keep_mass}, "
                f"\n - min candidates={posterior_keep_min_candidates}"
                f"\n - max={posterior_keep_max_candidates}"
                f"\n - ties={posterior_keep_ties}"
            )
        if stats_for is not None:
            logger.info("OTF statistics")
            logger.info(
                f" - stats_for={stats_key} (return_pdf_for_stats={return_pdf_for_stats})"
            )
        if output_hdf5_path is not None:
            logger.info("Data output")
            logger.info(" - HDF5 directory: %s:%s", output_hdf5_path, output_hdf5_group)

    if dry_run:
        return iter(()) if return_mode == "iter" else []

    # (Optional) output to HDF5 file
    writer = None
    if output_hdf5_path is not None:
        stats_specs = None
        if stats_for is not None:
            stats_specs = [
                GridFitStatsSpec(key=k, bins=b)
                for k, b in zip(stats_key, stats_bins)
            ]
        writer = GridFitHDF5Writer(
            path=output_hdf5_path,
            group=output_hdf5_group,
            M=M,
            P=P,
            configuration={
                "M": int(M),
                "P": int(P),
                "backend": str(backend),
                "n_jobs": int(n_jobs),
                "tasks_per_worker": int(tpw),
                "posterior_keep_mass": None
                if posterior_keep_mass is None
                else float(posterior_keep_mass),
                "posterior_keep_min_candidates": int(posterior_keep_min_candidates),
                "posterior_keep_max_candidates": None
                if posterior_keep_max_candidates is None
                else int(posterior_keep_max_candidates),
                "posterior_keep_ties": bool(posterior_keep_ties),
                "find_multimodal": bool(find_multimodal),
                "return_pdf_for_stats": bool(return_pdf_for_stats),
                "stats_for": list(stats_key) if stats_for is not None else None,
            },
            stats_specs=stats_specs,
            return_pdf_for_stats=return_pdf_for_stats,
            overwrite=output_hdf5_overwrite,
            compression=output_hdf5_compression,
            compression_opts=output_hdf5_compression_opts,
            flush_every=output_hdf5_flush_every,
        )

    def _deliver(batch: List[Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
        # Write each result (if a writer exists) and hand it to the caller.
        for r in batch:
            if writer is not None:
                writer.write(r)
            if not output_hdf5_write_only:
                yield r

    # Iterator producer (writes HDF5 only from the main thread/process)
    def _iter_results() -> Iterator[Dict[str, Any]]:
        try:
            if n_workers == 1:
                for s, e in slices:
                    yield from _deliver(_slice_worker(s, e))
            else:
                Executor = (
                    ThreadPoolExecutor if backend == "thread" else ProcessPoolExecutor
                )
                with Executor(max_workers=n_workers) as ex:
                    futs = [ex.submit(_slice_worker, s, e) for (s, e) in slices]
                    for fut in as_completed(futs):
                        yield from _deliver(fut.result())
            if verbose:
                logger.info("fit_batch complete.")
        finally:
            if writer is not None:
                try:
                    writer.close()
                except Exception as err:
                    # Closing is best effort, but do not hide the failure.
                    logger.warning(
                        "[fit_batch] failed to close HDF5 writer: %r", err
                    )

    if return_mode == "iter":
        return _iter_results()
    return list(_iter_results())
[docs]
def corner_for_targets(
    self,
    x_native: np.ndarray,
    sigma_native: np.ndarray,
    target_cols: Sequence[int | str],
    true_target_vals: Optional[np.ndarray] = None,
    candidate_idx: Optional[np.ndarray] = None,
    bins: int | Sequence[int | np.ndarray] = 50,
    figsize: Optional[Tuple[float, float]] = None,
    suptitle: Optional[str] = None,
    color: Optional[str] = None,
    kappa_sigma_edges: float = 5.0,
    alpha_hist: float = 0.6,
    alpha_mesh: float = 1.0,
    quantiles: Tuple[float, float, float] = (0.16, 0.5, 0.84),
) -> Tuple[plt.Figure, np.ndarray, Dict[str, Any]]:
    """
    Corner plot for selected targets using model posterior weights.

    Diagonal panels show the weighted 1-D histogram of each target together
    with the requested quantiles, the highest-weight model and the weighted
    mean. Lower-triangle panels show filled contours of the enclosed
    fraction map of the weighted 2-D histogram.

    Parameters
    ----------
    x_native : ndarray, shape (P,)
        Query observables in native units.
    sigma_native : ndarray, shape (P,)
        Per-dimension uncertainties in native units.
    target_cols : sequence of int or str
        Target columns to include.
    true_target_vals : ndarray of float, optional
        True target values, indexed in the same order as ``target_cols``;
        drawn in red when given.
    candidate_idx : ndarray of int, optional
        Candidate model indices. If None, all models are used.
    bins : int or sequence
        Common bin count, or per-dimension specs (int or edges). A
        one-element sequence is applied to every dimension.
    figsize : tuple, optional
        Figure size; defaults to 2.2 inches per panel.
    suptitle : str, optional
        Figure title; defaults to the number of candidate models.
    color : str, optional
        Colour of the 1-D histograms.
    kappa_sigma_edges : float, optional
        For integer bin specs, the histogram range extends this many times
        the (median - 16%) and (84% - median) distances around the weighted
        median, clipped to the data range.
    alpha_hist : float, optional
        Opacity of the filled 1-D histograms.
    alpha_mesh : float, optional
        Currently unused.
    quantiles : tuple, optional
        Weighted quantiles marked on the diagonals and returned in the
        summary.

    Returns
    -------
    fig, axes, summary : Figure, Axes array, dict
        ``summary`` holds ``names``, ``q`` (quantiles per target), ``mean``,
        ``cov`` and ``corr`` of the weighted posterior.

    Raises
    ------
    ValueError
        If the number of bin specifications does not match the number of
        target columns.
    """
    idx = np.arange(self.grid.n_models) if candidate_idx is None else candidate_idx
    w = self.posterior_over_models(
        x_native=x_native, sigma_native=sigma_native, candidate_idx=idx
    )

    # Resolve columns and names
    cols, names = [], []
    for tc in target_cols:
        j = self.grid.target_names.index(tc) if isinstance(tc, str) else int(tc)
        cols.append(j)
        names.append(self.grid.target_names[j])
    Y = self.grid.targets[idx][:, cols]
    w = w / np.sum(w)
    best_y = Y[np.argmax(w)]
    mean_y = np.sum(Y * w[:, None], axis=0)
    D = len(cols)

    if isinstance(bins, (int, np.integer)):
        bins_list = [int(bins)] * D
    else:
        bins_list = list(bins)
        if len(bins_list) == 1:
            bins_list = bins_list * D
    if len(bins_list) != D:
        # Explicit error instead of an assert, which vanishes under -O.
        raise ValueError(
            f"bins must provide 1 or {D} specifications; got {len(bins_list)}"
        )

    if figsize is None:
        figsize = (2.2 * D, 2.2 * D)
    fig, axes = plt.subplots(D, D, figsize=figsize, squeeze=False, sharex="col")

    # Weighted summaries
    q = np.zeros((D, len(quantiles)))
    for d in range(D):
        q[d] = weighted_quantiles(Y[:, d], w, quantiles)
    mu = mean_y

    # Diagonals
    all_edges = []
    for i in range(D):
        ax = axes[i, i]
        yi = Y[:, i]
        bi = bins_list[i]
        if isinstance(bi, (int, np.integer)):
            lo, mid, hi = weighted_quantiles(yi, w, (0.16, 0.5, 0.84))
            lo = lo if np.isfinite(lo) else np.nanmin(yi)
            hi = hi if np.isfinite(hi) else np.nanmax(yi)
            min_v = max(mid - kappa_sigma_edges * (mid - lo), yi.min())
            max_v = min(mid + kappa_sigma_edges * (hi - mid), yi.max())
            if not max_v > min_v:
                # Degenerate range: pad the data range by 10 %. Using the
                # absolute value keeps the edges ordered for negative data,
                # and a fixed pad covers the all-zero case.
                min_v = yi.min() - 0.1 * abs(yi.min())
                max_v = yi.max() + 0.1 * abs(yi.max())
                if not max_v > min_v:
                    min_v, max_v = min_v - 0.5, max_v + 0.5
            edges = np.linspace(min_v, max_v, int(bi) + 1)
        else:
            edges = np.asarray(bi)
        all_edges.append(edges)

        hist, _ = np.histogram(yi, bins=edges, weights=w)
        width = np.diff(edges)
        total = np.sum(hist)
        dens = hist / (total * width if total > 0 else width)
        centers = 0.5 * (edges[:-1] + edges[1:])
        ax.fill_between(centers, 0, dens, step="mid", alpha=alpha_hist, color=color)
        ax.plot(centers, dens, lw=1.0, color=color)
        if true_target_vals is not None:
            ax.axvline(true_target_vals[i], color="r", label="True")
        for qv in q[i]:
            ax.axvline(qv, ls="--", lw=1.0, color="gold")
        ax.set_title(", ".join([f"{v:.3f}" for v in q[i]]))
        ax.axvline(best_y[i], color="fuchsia", lw=1.0, label="Max-like")
        ax.axvline(mean_y[i], color="lime", lw=1.0, label="Mean")
        ax.legend(fontsize=8)

    # Lower triangle
    for i in range(1, D):
        for j in range(i):
            ax = axes[i, j]
            xedges = all_edges[j]
            yedges = all_edges[i]
            H, _, _ = np.histogram2d(
                Y[:, j], Y[:, i], bins=[xedges, yedges], weights=w
            )
            xb = (xedges[:-1] + xedges[1:]) / 2
            yb = (yedges[:-1] + yedges[1:]) / 2
            frac = enclosed_fraction_map(H, xedges=xedges, yedges=yedges)
            ax.contourf(
                xb, yb, frac.T, cmap="Greys", levels=[0.01, 0.05, 0.32, 0.5, 1]
            )
            ax.scatter(best_y[j], best_y[i], ec="fuchsia", fc="none")
            ax.scatter(mean_y[j], mean_y[i], ec="lime", fc="none")
            if true_target_vals is not None:
                ax.scatter(
                    true_target_vals[j], true_target_vals[i], ec="r", fc="none"
                )
            if i == D - 1:
                ax.set_xlabel(names[j])
            if j == 0:
                ax.set_ylabel(names[i])
            else:
                ax.set_yticklabels([])

    # Hide upper triangle
    for i in range(D):
        for j in range(i + 1, D):
            axes[i, j].axis("off")

    if suptitle:
        fig.suptitle(suptitle)
    else:
        fig.suptitle(f"No. of\ncandidate models: {w.size}", fontsize="small")
    fig.tight_layout()

    # Return summary: weighted covariance, symmetrised against round-off
    cov = (Y - mu).T @ ((Y - mu) * w[:, None])
    cov = 0.5 * (cov + cov.T)
    sd = np.sqrt(np.clip(np.diag(cov), 0, np.inf))
    with np.errstate(invalid="ignore", divide="ignore"):
        corr = cov / np.outer(sd, sd)
    corr[np.isnan(corr)] = 0.0
    summary = {"names": names, "q": q, "mean": mu, "cov": cov, "corr": corr}
    return fig, axes, summary
@dataclass
class GridFitStatsSpec:
    """
    Describe a single stats target that the HDF5 writer stores.

    Parameters
    ----------
    key : str
        Name of the target; it doubles as the group name ``/stats/<key>``.
    bins : ndarray, shape (K+1,)
        Edges of the K bins over which the discrete posterior is built.
    """

    key: str
    bins: np.ndarray
[docs]
class GridFitHDF5Writer:
"""
HDF5 writer for GridFitter.fit_batch streaming outputs.
Layout written under ``<group>``::
/configuration attrs['json'] with a JSON blob
/index/m (M,) int64 [0..M-1]
/index/status (M,) int8 [0=not written, 1=ok, -1=failed]
/truncation/* (M,) scalar datasets (level, n_before, n_after, ...)
/candidates/candidates (M,) vlen int64
/candidates/post_models (M,) vlen float64
/stats/<key>/* per-target scalar arrays (M,) and optional posts_target (M,K)
Notes
-----
- Writes are done by index ``m`` in each result dict, so results may
arrive unsorted.
- Designed for use from the main thread/process: collect worker results,
then call ``writer.write(r)`` as they arrive.
"""
def __init__(
    self,
    path: str,
    group: str,
    *,
    M: int,
    P: int,
    configuration: Optional[Dict[str, Any]] = None,
    stats_specs: Optional[Sequence[GridFitStatsSpec]] = None,
    return_pdf_for_stats: bool = False,
    overwrite: bool = False,
    compression: Optional[str] = "gzip",
    compression_opts: int = 4,
    flush_every: int = 256,
    dtype_posts_target=np.float32,
    dtype_post_models=np.float64,
) -> None:
    """
    Store the writer settings and create the HDF5 layout right away.

    Parameters
    ----------
    path : str
        HDF5 file to write into.
    group : str
        Name of the group that receives every dataset of this writer.
    M : int
        Length of the first axis of all per-object datasets; valid
        result indices ``m`` are ``0..M-1``.
    P : int
        Recorded in the configuration JSON under ``"P"``; the writer
        itself does not use it for anything else.
    configuration : dict, optional
        Extra entries for the JSON blob stored at ``/configuration``.
    stats_specs : sequence of GridFitStatsSpec, optional
        Targets for which ``/stats/<key>`` datasets are created.
    return_pdf_for_stats : bool
        When true, ``posts_target`` datasets are created and written.
    overwrite : bool
        Replace an already existing ``group`` instead of raising.
    compression, compression_opts
        Passed on to the ``posts_target`` and ``cov`` datasets.
    flush_every : int
        Flush the file every this many ``write``/``mark_failed`` calls;
        ``0`` or ``None`` turns periodic flushing off.
    dtype_posts_target, dtype_post_models
        Storage dtypes of the ``posts_target`` and ``post_models`` data.
    """
    # --- destination and sizes -------------------------------------
    self.path = str(path)
    self.group = str(group)
    self.M = int(M)
    self.P = int(P)

    # --- behaviour switches and storage options --------------------
    self.return_pdf_for_stats = bool(return_pdf_for_stats)
    self.overwrite = bool(overwrite)
    self.compression = compression
    self.compression_opts = int(compression_opts)
    self.flush_every = 0 if flush_every is None else int(flush_every)
    self.dtype_posts_target = np.dtype(dtype_posts_target)
    self.dtype_post_models = np.dtype(dtype_post_models)

    # --- bookkeeping -----------------------------------------------
    self._written = 0
    # Stays True until _prepare has actually opened the file.
    self._is_closed = True
    self._stats_specs: Tuple[GridFitStatsSpec, ...] = (
        tuple(stats_specs) if stats_specs else ()
    )

    # --- file / dataset handles, all populated by _prepare ---------
    self.h5: Optional[h5py.File] = None
    self.g: Optional[h5py.Group] = None
    self.ds_status = None
    # /truncation/*
    self.ds_level = self.ds_n_before = self.ds_n_after = None
    self.ds_mass_kept = self.ds_mass_dropped = self.ds_cut_weight = None
    self.ds_ess_before = self.ds_ess_after = None
    # /candidates/*
    self.ds_candidates = self.ds_post_models = None
    # /stats/*
    self._stats_handles: Dict[Tuple[str, str], h5py.Dataset] = {}
    self.ds_cov = None
    self._cov_dim = 0
    self._cov_keys: Tuple[str, ...] = ()

    cfg = configuration if configuration else {}
    self._prepare(configuration=cfg)
# --------------------------
# Context manager support
# --------------------------
def __enter__(self) -> "GridFitHDF5Writer":
return self
def __exit__(self, exc_type, exc, tb) -> None:
# If an exception happened, we still close cleanly.
self.close()
# --------------------------
# Public API
# --------------------------
[docs]
def write(self, r: Dict[str, Any]) -> None:
"""
Write one result dict produced by the streaming fit_batch.
Required keys:
- 'm' (int)
- 'level' (int or None)
- 'candidates' (ndarray int)
- 'post_models' (ndarray float, sums to 1)
- 'truncation' (dict with n_before, n_after, mass_kept, mass_dropped, cut_weight, ess_before, ess_after)
Optional:
- 'stats' : dict keyed by target name
- 'posts_target' : dict keyed by target name with (K,) posterior arrays (if enabled)
"""
if self._is_closed:
raise RuntimeError("Writer is closed.")
m = int(r["m"])
if m < 0 or m >= self.M:
raise IndexError(f"m={m} out of bounds for M={self.M}")
# vlen candidates + post_models
self.ds_candidates[m] = np.asarray(r["candidates"], dtype=np.int64)
self.ds_post_models[m] = np.asarray(
r["post_models"], dtype=self.dtype_post_models
)
# truncation + level
lev = r.get("level", None)
self.ds_level[m] = -1 if lev is None else int(lev)
t = r.get("truncation", {}) or {}
self.ds_n_before[m] = int(t.get("n_before", -1))
self.ds_n_after[m] = int(t.get("n_after", -1))
self.ds_mass_kept[m] = np.float32(t.get("mass_kept", np.nan))
self.ds_mass_dropped[m] = np.float32(t.get("mass_dropped", np.nan))
self.ds_cut_weight[m] = np.float32(t.get("cut_weight", np.nan))
self.ds_ess_before[m] = np.float32(t.get("ess_before", np.nan))
self.ds_ess_after[m] = np.float32(t.get("ess_after", np.nan))
# stats (scalars)
stats = r.get("stats", None)
if stats is not None:
for key, st in stats.items():
if key == "cov":
self._write_stat_cov(m, st)
else:
self._write_stats_scalars(m, key, st)
# stats (posterior over bins)
if self.return_pdf_for_stats:
posts_target = r.get("posts_target", None)
if posts_target is not None:
for key, post in posts_target.items():
self._write_posts_target(m, key, post)
# mark status OK
self.ds_status[m] = 1
self._written += 1
if self.flush_every and (self._written % self.flush_every == 0):
self.flush()
[docs]
def mark_failed(self, m: int) -> None:
"""
Mark one object as failed (status=-1). Useful if you catch exceptions externally.
"""
if self._is_closed:
raise RuntimeError("Writer is closed.")
m = int(m)
if m < 0 or m >= self.M:
raise IndexError(f"m={m} out of bounds for M={self.M}")
self.ds_status[m] = -1
self._written += 1
if self.flush_every and (self._written % self.flush_every == 0):
self.flush()
def flush(self) -> None:
if self.h5 is not None and not self._is_closed:
self.h5.flush()
def close(self) -> None:
if self._is_closed:
return
try:
self.flush()
finally:
try:
if self.h5 is not None:
self.h5.close()
finally:
self.h5 = None
self.g = None
self._is_closed = True
def _prepare(self, configuration: Dict[str, Any]) -> None:
    """
    Open the HDF5 file and create the complete layout under ``self.group``.

    Parameters
    ----------
    configuration : dict
        Entries for the JSON blob stored in ``/configuration``; the
        writer adds ``M``, ``P``, ``group``, ``return_pdf_for_stats``
        and ``stats_keys`` unless they are already present.

    Raises
    ------
    ValueError
        If the group already exists and ``overwrite`` is false, or if a
        stats spec has bins that are not 1-D with at least two edges.

    Notes
    -----
    When anything fails after the file has been opened, the file is
    closed again before the exception propagates, so a failed
    construction does not leave an open handle behind.
    """
    # Mode "a" opens an existing file read/write and creates a missing one.
    self.h5 = h5py.File(self.path, "a")
    self._is_closed = False
    try:
        self._create_layout(configuration)
        # Make sure initial metadata is on disk
        self.flush()
    except BaseException:
        # The caller never gets a writer object it could close itself.
        self.close()
        raise

def _create_layout(self, configuration: Dict[str, Any]) -> None:
    """Create the group, configuration blob and all datasets in the open file."""

    def per_object(parent, name, dtype, fill):
        # One value per object: shape (M,), explicitly initialised to `fill`.
        dset = parent.create_dataset(name, shape=(self.M,), dtype=dtype)
        dset[...] = fill
        return dset

    grp_path = self.group
    if grp_path in self.h5:
        if not self.overwrite:
            raise ValueError(
                f"HDF5 group '{grp_path}' exists in '{self.path}'. "
                "Set overwrite=True to replace it."
            )
        del self.h5[grp_path]
    self.g = self.h5.create_group(grp_path)

    # /configuration
    gc = self.g.create_group("configuration")
    cfg = dict(configuration or {})
    cfg.setdefault("M", int(self.M))
    cfg.setdefault("P", int(self.P))
    cfg.setdefault("group", str(self.group))
    cfg.setdefault("return_pdf_for_stats", bool(self.return_pdf_for_stats))
    cfg.setdefault(
        "stats_keys",
        [s.key for s in self._stats_specs] if self._stats_specs else None,
    )
    gc.attrs["json"] = json.dumps(cfg, ensure_ascii=True)

    # /index
    gi = self.g.create_group("index")
    gi.create_dataset("m", data=np.arange(self.M, dtype=np.int64))
    # 0=not written, 1=ok, -1=failed
    self.ds_status = per_object(gi, "status", np.int8, 0)

    # /truncation
    gt = self.g.create_group("truncation")
    self.ds_level = per_object(gt, "level", np.int32, -1)
    self.ds_n_before = per_object(gt, "n_before", np.int32, -1)
    self.ds_n_after = per_object(gt, "n_after", np.int32, -1)
    self.ds_mass_kept = per_object(gt, "mass_kept", np.float32, np.nan)
    self.ds_mass_dropped = per_object(gt, "mass_dropped", np.float32, np.nan)
    self.ds_cut_weight = per_object(gt, "cut_weight", np.float32, np.nan)
    self.ds_ess_before = per_object(gt, "ess_before", np.float32, np.nan)
    self.ds_ess_after = per_object(gt, "ess_after", np.float32, np.nan)

    # /candidates (variable-length arrays)
    gcand = self.g.create_group("candidates")
    vlen_i64 = h5py.vlen_dtype(np.dtype("int64"))
    vlen_post = h5py.vlen_dtype(self.dtype_post_models)
    self.ds_candidates = gcand.create_dataset(
        "candidates", shape=(self.M,), dtype=vlen_i64
    )
    self.ds_post_models = gcand.create_dataset(
        "post_models", shape=(self.M,), dtype=vlen_post
    )

    # /stats
    if not self._stats_specs:
        return

    # h5py rejects compression_opts without a compression method and for
    # "lzf"; the integer level stored here is only meaningful for gzip.
    comp_kwargs: Dict[str, Any] = {}
    if self.compression is not None:
        comp_kwargs["compression"] = self.compression
        if self.compression == "gzip":
            comp_kwargs["compression_opts"] = self.compression_opts

    gs = self.g.create_group("stats")
    self._stats_handles = {}
    scalar_names = ("mean", "std", "map", "median", "lo68", "hi68")
    for spec in self._stats_specs:
        key = str(spec.key)
        bins = np.asarray(spec.bins, dtype=float)
        if bins.ndim != 1 or bins.size < 2:
            raise ValueError(f"Invalid bins for stats target '{key}'")
        tg = gs.create_group(key)
        K = int(bins.size - 1)
        tg.create_dataset("bins", data=bins)
        # Scalars per object
        for name in scalar_names:
            self._stats_handles[(key, name)] = per_object(
                tg, name, np.float32, np.nan
            )
        self._stats_handles[(key, "nmodes")] = per_object(
            tg, "nmodes", np.int16, 0
        )
        if self.return_pdf_for_stats:
            dpt = tg.create_dataset(
                "posts_target",
                shape=(self.M, K),
                dtype=self.dtype_posts_target,
                chunks=True,
                **comp_kwargs,
            )
            dpt[...] = np.nan
            self._stats_handles[(key, "posts_target")] = dpt

    # /stats/cov : per-object covariance, axes ordered like the stats specs
    self._cov_keys = tuple(str(s.key) for s in self._stats_specs)
    self._cov_dim = len(self._cov_keys)
    gcov = gs.create_group("cov")
    gcov.create_dataset("keys", data=np.array(self._cov_keys, dtype="S"))
    # Covariance per object: (M, D, D)
    # Chunk across objects; full matrix per chunk row is typically fine.
    self.ds_cov = gcov.create_dataset(
        "cov",
        shape=(self.M, self._cov_dim, self._cov_dim),
        dtype=np.float32,
        chunks=(min(self.M, 64), self._cov_dim, self._cov_dim),
        **comp_kwargs,
    )
    self.ds_cov[...] = np.nan
# --------------------------
# Internal: stats writing
# --------------------------
[docs]
def _write_stats_scalars(self, m: int, key: str, st: Dict[str, Any]) -> None:
"""
Write scalar stats for one target.
Missing targets are ignored (allows partial stats).
"""
key = str(key)
# Only write keys that exist in file
if (key, "mean") not in self._stats_handles:
return
# Support both the compact dict (as produced by our fit_batch) and
# extended dicts (ignoring unknown fields).
self._stats_handles[(key, "mean")][m] = np.float32(st.get("mean", np.nan))
self._stats_handles[(key, "std")][m] = np.float32(st.get("std", np.nan))
self._stats_handles[(key, "map")][m] = np.float32(st.get("map", np.nan))
self._stats_handles[(key, "median")][m] = np.float32(st.get("median", np.nan))
self._stats_handles[(key, "lo68")][m] = np.float32(st.get("lo68", np.nan))
self._stats_handles[(key, "hi68")][m] = np.float32(st.get("hi68", np.nan))
self._stats_handles[(key, "nmodes")][m] = np.int16(st.get("nmodes", 0))
[docs]
def _write_posts_target(self, m: int, key: str, post: Any) -> None:
"""
Write the discrete posterior over bins for one target (shape (K,)).
If the dataset doesn't exist or shapes mismatch, raises ValueError.
"""
key = str(key)
hkey = (key, "posts_target")
if hkey not in self._stats_handles:
return
dpt = self._stats_handles[hkey]
arr = np.asarray(post, dtype=self.dtype_posts_target)
if arr.ndim != 1 or arr.shape[0] != dpt.shape[1]:
raise ValueError(
f"posts_target shape mismatch for '{key}': got {arr.shape}, expected ({dpt.shape[1]},)"
)
dpt[m, :] = arr
[docs]
def _write_stat_cov(self, m: int, st: Any) -> None:
"""
Write the covariance matrix for object m.
Expected input formats
----------------------
st can be:
1) a dict with key "cov" holding an array-like (D, D), OR
2) directly an array-like (D, D)
Where D == len(self._stats_specs) and matches /stats/cov/keys ordering.
If the covariance dataset is not present (e.g. no stats_specs),
this is a no-op.
"""
if self.ds_cov is None:
return
# Accept dict or raw array
if isinstance(st, dict) and ("cov" in st):
cov = st["cov"]
else:
cov = st
arr = np.asarray(cov, dtype=np.float32)
D = int(self._cov_dim)
if arr.ndim != 2 or arr.shape != (D, D):
raise ValueError(
f"cov shape mismatch: got {arr.shape}, expected ({D}, {D}). "
f"Target order: {self._cov_keys}"
)
# Optional: enforce symmetry softly (avoid surprises)
# Comment out if you want to store exactly what you compute.
arr = 0.5 * (arr + arr.T)
self.ds_cov[m, :, :] = arr