"""
Post-processing utilities for BESTA results.
This module provides:
- I/O for CosmoSIS-like tabular results.
- Weighted posterior summaries (MAP, mean, covariance, correlation).
- Weighted quantiles and (optionally multi-modal) HDI intervals.
- 1D PDFs (histogram-derived + optional KDE).
- 2D PDFs (KDE fallback to histogram) + HPD enclosed-fraction maps.
- A structured ResultsSummary object with FITS + JSON export helpers.
Notes
-----
- This module assumes the posterior samples are stored in an Astropy Table
with a log-posterior column (default: "post") and parameter columns
named like "section--name" (default delimiter/prefix: "--").
"""
from __future__ import annotations
import json
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
import numpy as np
from astropy.io import fits
from astropy.table import Table, Column
from astropy import units as u
from matplotlib import pyplot as plt
from scipy import stats
from besta.logging import get_logger
logger = get_logger(__name__)
# -----------------------------------------------------------------------------
# Weighted statistics
# -----------------------------------------------------------------------------
[docs]
def _as_float_array(x) -> np.ndarray:
    """
    Return the numerical content of `x` as an array.

    Astropy ``Quantity`` and ``Column`` inputs are unwrapped through their
    ``.value`` attribute and returned as-is (no dtype coercion is applied on
    that path). Any other input is converted with
    ``np.asarray(x, dtype=float)``.
    """
    astropy_containers = (u.Quantity, Column)
    return x.value if isinstance(x, astropy_containers) else np.asarray(x, dtype=float)
[docs]
def normalize_weights(weights: np.ndarray, *, allow_all_zero: bool = False) -> np.ndarray:
    """
    Scale weights so that they sum to one.

    Non-finite entries (NaN, +/-inf) are replaced by zero before summing.

    Parameters
    ----------
    weights : array-like
        Raw weights.
    allow_all_zero : bool
        If True and the (cleaned) weights do not sum to a positive value,
        return the cleaned weights unscaled instead of raising.

    Returns
    -------
    ndarray
        Weights divided by their sum.

    Raises
    ------
    ValueError
        If the sum is not positive and ``allow_all_zero`` is False.
    """
    raw = _as_float_array(weights)
    cleaned = np.where(np.isfinite(raw), raw, 0.0)
    total = np.sum(cleaned)
    if total > 0:
        return cleaned / total
    if allow_all_zero:
        return cleaned
    raise ValueError("Sum of weights must be > 0.")
[docs]
def weighted_mean(x: np.ndarray, weights: np.ndarray, axis: int = -1) -> np.ndarray:
    """
    Weighted mean of `x` along `axis`.

    The weights are normalized to unit sum and reshaped so that they
    broadcast along `axis` of `x`; the result is ``sum(x * w)`` over that axis.
    """
    values = _as_float_array(x)
    norm_w = normalize_weights(weights)
    # Shape of ones everywhere except the reduced axis, e.g. (1, ..., -1).
    broadcast_shape = [1] * values.ndim
    broadcast_shape[axis] = -1
    return np.sum(values * norm_w.reshape(broadcast_shape), axis=axis)
[docs]
def weighted_covariance(
    x: np.ndarray,
    weights: np.ndarray,
    *,
    unbiased: bool = False,
) -> np.ndarray:
    """
    Weighted covariance for x with shape (D, N) and weights with shape (N,).

    Parameters
    ----------
    x : array-like, shape (D, N)
        D variables observed at N samples.
    weights : array-like, shape (N,)
        Sample weights; they are normalized to unit sum internally.
    unbiased : bool
        If True, applies the common correction for normalized weights:
        ``cov /= (1 - sum(w^2))``.

    Returns
    -------
    cov : ndarray, shape (D, D)

    Raises
    ------
    ValueError
        If `x` is not 2D, or if ``unbiased=True`` and ``1 - sum(w^2) <= 0``
        (e.g. all the weight sits on a single sample).
    """
    x = _as_float_array(x)
    if x.ndim != 2:
        raise ValueError("x must have shape (D, N).")
    w = normalize_weights(weights)  # sum=1
    mu = np.sum(x * w[None, :], axis=1)  # (D,)
    xm = x - mu[:, None]  # (D, N)
    # sum_n w_n * xm[i, n] * xm[j, n] written as a matrix product: this avoids
    # materializing a (D, D, N) temporary, which dominates memory for long chains.
    cov = (xm * w[None, :]) @ xm.T  # (D, D)
    if unbiased:
        denom = 1.0 - np.sum(w**2)
        if denom <= 0:
            raise ValueError("Unbiased covariance undefined for degenerate weights.")
        cov /= denom
    return cov
[docs]
def covariance_to_correlation(cov: np.ndarray) -> np.ndarray:
    """
    Turn a covariance matrix into the matching correlation matrix.

    Negative diagonal entries are clipped to zero before taking the square
    root. Entries that come out non-finite (e.g. division by a zero standard
    deviation) are set to 0, and the diagonal is forced to 1.
    """
    cov_arr = _as_float_array(cov)
    sigma = np.sqrt(np.clip(np.diag(cov_arr), 0.0, np.inf))
    scale = sigma[:, None] * sigma[None, :]
    # Zero variances give 0/0 or x/0 here; silence the warnings and patch below.
    with np.errstate(divide="ignore", invalid="ignore"):
        corr = cov_arr / scale
    bad = ~np.isfinite(corr)
    corr[bad] = 0.0
    np.fill_diagonal(corr, 1.0)
    return corr
[docs]
def weighted_quantile(x: np.ndarray, weights: np.ndarray, q: Sequence[float]) -> np.ndarray:
    """
    Weighted quantiles of a 1D sample.

    Samples where either the value or the weight is non-finite are dropped.
    The quantiles are obtained by linear interpolation of the sorted values
    against the cumulative normalized weights.

    Parameters
    ----------
    x : array-like, shape (N,)
    weights : array-like, shape (N,)
    q : sequence of quantiles in [0, 1]

    Returns
    -------
    quantiles : ndarray shape (len(q),)
        All-NaN if no finite sample remains.
    """
    values = _as_float_array(x).ravel()
    wts = _as_float_array(weights).ravel()
    finite = np.isfinite(values) & np.isfinite(wts)
    values, wts = values[finite], wts[finite]
    if values.size == 0:
        return np.full(len(q), np.nan, dtype=float)

    order = np.argsort(values)
    sorted_values = values[order]
    cumulative = np.cumsum(normalize_weights(wts)[order])
    # Pin the last point so round-off cannot leave the CDF short of 1.
    cumulative[-1] = 1.0
    targets = np.asarray(q, dtype=float)
    return np.interp(targets, cumulative, sorted_values)
[docs]
def weighted_hdi(
    x: np.ndarray,
    weights: np.ndarray,
    *,
    mass: float = 0.68,
    max_intervals: int = 2,
    merge_tol: float = 0.0,
) -> List[Tuple[float, float]]:
    """
    Compute weighted highest-density interval(s) from 1D samples.

    The output approximates the smallest region containing ``mass`` probability.

    Method
    ------
    1. Drop non-finite samples/weights, normalize the weights, sort by ``x``.
    2. Within each contiguous run of not-yet-used samples, slide a two-pointer
       window over the cumulative weights and keep the narrowest window whose
       enclosed mass is >= ``mass``.
    3. Repeat up to ``max_intervals`` times on the remaining samples, stopping
       as soon as the interval just found encloses >= ``mass``.

    NOTE(review): every interval returned by step 2 already encloses
    >= ``mass`` by construction, so (up to floating-point round-off) the stop
    test in step 3 fires after the first interval. In practice a single
    interval is returned and ``max_intervals`` / ``merge_tol`` have no effect.

    Parameters
    ----------
    x, weights : arrays of shape (N,)
    mass : target probability mass in (0,1]
    max_intervals : maximum number of disjoint intervals to return (approximate)
    merge_tol : merge intervals whose gap is <= merge_tol (in x units)

    Returns
    -------
    intervals : list of (low, high)
        ``[(nan, nan)]`` if no finite sample is available; the full sample
        range if no window reaching ``mass`` could be found.

    Raises
    ------
    ValueError
        If ``mass`` is outside (0, 1].
    """
    x = _as_float_array(x).ravel()
    w = _as_float_array(weights).ravel()
    mask = np.isfinite(x) & np.isfinite(w)
    x = x[mask]
    w = w[mask]
    if x.size == 0:
        return [(np.nan, np.nan)]
    if not (0 < mass <= 1):
        raise ValueError("mass must be in (0, 1].")

    # Normalize and sort
    w = normalize_weights(w)
    order = np.argsort(x)
    xs = x[order]
    ws = w[order]

    intervals: List[Tuple[float, float]] = []
    used = np.zeros(xs.size, dtype=bool)

    def _find_best_interval(available_mask: np.ndarray) -> Optional[Tuple[int, int]]:
        """Return inclusive (i, j) indices of the narrowest window with mass >= `mass`."""
        best = None
        best_width = np.inf
        # Identify contiguous runs of available samples as half-open [s, e).
        avail = available_mask.astype(int)
        starts = np.where(np.diff(np.r_[0, avail]) == 1)[0]
        # The -1 step marks the LAST index of a run (inclusive); add 1 so the
        # slice below also covers that final sample.
        stops = np.where(np.diff(np.r_[avail, 0]) == -1)[0] + 1
        for s, e in zip(starts, stops):
            sub_ws = ws[s:e]
            sub_cdf = np.cumsum(sub_ws)
            sub_cdf[-1] = np.sum(sub_ws)
            # Weights are normalized over ALL samples, so this is absolute mass;
            # a run holding less than `mass` cannot contain a valid window.
            if sub_cdf[-1] < mass:
                continue
            sub_cdf0 = np.concatenate([[0.0], sub_cdf])
            # Two-pointer scan: window covers samples s+i .. s+j-1 (inclusive).
            i = 0
            for j in range(1, sub_cdf0.size):
                while i < j and (sub_cdf0[j] - sub_cdf0[i]) >= mass:
                    width = xs[s + j - 1] - xs[s + i]
                    if width < best_width:
                        best_width = width
                        best = (s + i, s + j - 1)
                    i += 1
        return best

    for _ in range(max_intervals):
        best = _find_best_interval(~used)
        if best is None:
            break
        i, j = best
        intervals.append((float(xs[i]), float(xs[j])))
        used[i : j + 1] = True
        # Stop once the interval just found holds the requested mass
        # (ties in x at the interval ends are counted as inside).
        enclosed = np.sum(ws[(xs >= xs[i]) & (xs <= xs[j])])
        if enclosed >= mass:
            break

    # Merge close intervals if requested
    if merge_tol > 0 and len(intervals) > 1:
        intervals = sorted(intervals, key=lambda t: t[0])
        merged: List[Tuple[float, float]] = [intervals[0]]
        for lo, hi in intervals[1:]:
            plo, phi = merged[-1]
            if lo - phi <= merge_tol:
                merged[-1] = (plo, max(phi, hi))
            else:
                merged.append((lo, hi))
        intervals = merged

    if len(intervals) == 0:
        # Fallback to full range
        intervals = [(float(xs[0]), float(xs[-1]))]
    return intervals
# -----------------------------------------------------------------------------
# PDFs and HPD maps
# -----------------------------------------------------------------------------
[docs]
def enclosed_fraction_map(
    density: np.ndarray,
    *,
    xedges: Optional[np.ndarray] = None,
    yedges: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Build an HPD-style enclosed-fraction map from a 2D density grid.

    Pixels are ranked by density (highest first); each pixel of the output
    holds the normalized cumulative mass of all pixels ranked at or above it.

    Parameters
    ----------
    density : array-like, shape (ny, nx)
    xedges, yedges : array-like, optional
        Bin edges along x (columns) and y (rows). When BOTH are given, each
        pixel's mass is ``density * area``; otherwise the density itself is
        used as mass (constant pixel area).

    Returns
    -------
    ndarray, same shape as ``density``
        Enclosed fraction per pixel. Pixels with non-finite density/mass or
        negative mass are NaN; the whole map is NaN if no usable pixel exists
        or the total mass is not positive.

    Raises
    ------
    ValueError
        If ``density`` is not 2D or its shape disagrees with the edges.
    """
    dens = _as_float_array(density)
    if dens.ndim != 2:
        raise ValueError("density must be 2D.")

    if xedges is None or yedges is None:
        pixel_mass = dens
    else:
        widths = np.diff(_as_float_array(xedges))  # along x/columns
        heights = np.diff(_as_float_array(yedges))  # along y/rows
        if dens.shape != (heights.size, widths.size):
            raise ValueError(
                f"density shape {dens.shape} does not match edges "
                f"(ny,nx)=({heights.size},{widths.size})."
            )
        pixel_mass = dens * (heights[:, None] * widths[None, :])

    dens_flat = dens.ravel()
    mass_flat = pixel_mass.ravel()
    usable = np.isfinite(dens_flat) & np.isfinite(mass_flat) & (mass_flat >= 0)
    result = np.full_like(dens_flat, np.nan, dtype=float)

    if np.any(usable):
        ranking = np.argsort(dens_flat[usable])[::-1]  # descending density
        running = np.cumsum(mass_flat[usable][ranking])
        if running[-1] > 0:
            running /= running[-1]
            # Undo the sort so each usable pixel gets its own cumulative value.
            unsorted = np.empty_like(running)
            unsorted[ranking] = running
            result[usable] = unsorted

    return result.reshape(dens.shape)
[docs]
def histogram_pdf_1d(
    x: np.ndarray,
    weights: np.ndarray,
    *,
    bins: int = 100,
    range: Optional[Tuple[float, float]] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Estimate a 1D PDF from a weighted histogram.

    Samples with a non-finite value or weight are discarded. The remaining
    weights are normalized, histogrammed, divided by the bin widths and
    finally rescaled so that ``sum(pdf * width)`` equals 1.

    Returns
    -------
    centers : (bins,)
        Bin centers. If no finite sample exists this is ``linspace(0, 1, bins)``.
    pdf : (bins,)
        Density per bin (all-NaN if no finite sample exists).
    """
    samples = _as_float_array(x).ravel()
    wts = _as_float_array(weights).ravel()
    finite = np.isfinite(samples) & np.isfinite(wts)
    samples, wts = samples[finite], wts[finite]

    if samples.size == 0:
        centers = np.linspace(0, 1, bins)
        return centers, np.full_like(centers, np.nan)

    prob_per_bin, edges = np.histogram(
        samples,
        bins=bins,
        range=range,
        weights=normalize_weights(wts),
        density=False,
    )
    widths = np.diff(edges)
    # Probability per bin -> density; zero-width bins would divide by zero.
    with np.errstate(divide="ignore", invalid="ignore"):
        pdf = prob_per_bin / widths

    # Samples outside `range` lose mass, so renormalize numerically.
    total = np.nansum(pdf * widths)
    if total > 0:
        pdf /= total

    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers, pdf
[docs]
def kde_pdf_1d(
    x: np.ndarray,
    weights: np.ndarray,
    grid: np.ndarray,
) -> np.ndarray:
    """
    Evaluate a weighted Gaussian KDE of `x` on `grid`.

    Samples with a non-finite value or weight are discarded first. This is a
    best-effort helper: if no sample remains, or if weight normalization or
    the KDE itself raises, an all-NaN array shaped like the flattened grid is
    returned instead of propagating the error.
    """
    samples = _as_float_array(x).ravel()
    wts = _as_float_array(weights).ravel()
    grid_points = _as_float_array(grid).ravel()

    finite = np.isfinite(samples) & np.isfinite(wts)
    samples, wts = samples[finite], wts[finite]
    if samples.size == 0:
        return np.full_like(grid_points, np.nan)

    try:
        estimator = stats.gaussian_kde(samples, weights=normalize_weights(wts))
        return estimator(grid_points)
    except Exception:
        # Deliberate catch-all: callers treat NaNs as "KDE unavailable".
        return np.full_like(grid_points, np.nan)
[docs]
def kde_or_hist_pdf_2d(
    x: np.ndarray,
    y: np.ndarray,
    weights: np.ndarray,
    *,
    bins: int = 80,
    range: Optional[Tuple[Tuple[float, float], Tuple[float, float]]] = None,
    use_kde: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    2D PDF on a regular grid. Returns (x_centers, y_centers, pdf[y,x]).

    Uses a weighted Gaussian KDE if ``use_kde`` is True and the KDE succeeds;
    otherwise falls back to a weighted ``histogram2d``. In both cases the
    result is rescaled so that ``sum(pdf) * dx * dy`` equals 1 on the grid.

    Parameters
    ----------
    x, y, weights : array-like, shape (N,)
        Samples and weights; rows with any non-finite entry are dropped.
    bins : int
        Number of grid cells per axis.
    range : ((xmin, xmax), (ymin, ymax)), optional
        Grid limits; defaults to the min/max of the finite samples.
    use_kde : bool
        Try the KDE before the histogram.

    Returns
    -------
    x_centers, y_centers : ndarray, shape (bins,)
    pdf : ndarray, shape (bins, bins), indexed as [y, x]
        All-NaN (on a unit grid) if no finite sample exists.
    """
    x = _as_float_array(x).ravel()
    y = _as_float_array(y).ravel()
    w = _as_float_array(weights).ravel()
    mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(w)
    x = x[mask]
    y = y[mask]
    w = w[mask]
    if x.size == 0:
        xc = np.linspace(0, 1, bins)
        yc = np.linspace(0, 1, bins)
        return xc, yc, np.full((bins, bins), np.nan)

    w = normalize_weights(w)
    if range is None:
        xr = (np.min(x), np.max(x))
        yr = (np.min(y), np.max(y))
    else:
        xr, yr = range
    xedges = np.linspace(xr[0], xr[1], bins + 1)
    yedges = np.linspace(yr[0], yr[1], bins + 1)
    xc = 0.5 * (xedges[:-1] + xedges[1:])
    yc = 0.5 * (yedges[:-1] + yedges[1:])
    # The grid is uniform, so one cell width/height describes every cell.
    dx = np.diff(xedges)[0]
    dy = np.diff(yedges)[0]

    if use_kde:
        try:
            # KDE evaluated on the grid of cell centers, shapes (bins, bins).
            X, Y = np.meshgrid(xc, yc, indexing="xy")
            kde = stats.gaussian_kde(np.vstack([x, y]), weights=w)
            Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape((bins, bins))
            # Normalize numerically to ensure integral ~ 1
            integral = np.sum(Z) * dx * dy
            if integral > 0:
                Z /= integral
            return xc, yc, Z
        except Exception as exc:
            # Best-effort by design (e.g. singular covariance), but record why
            # the KDE was abandoned instead of swallowing the error silently.
            # Assumes the module-level `logger` follows the logging.Logger API.
            logger.debug("2D KDE failed (%s); falling back to weighted histogram.", exc)

    # Histogram fallback: numpy returns shape (bins_x, bins_y); we want (bins_y, bins_x).
    H, xe, ye = np.histogram2d(x, y, bins=[xedges, yedges], weights=w, density=False)
    # Convert probability per bin -> density
    dx = np.diff(xe)[0]
    dy = np.diff(ye)[0]
    Z = H / (dx * dy)
    # Transpose to (y,x) for meshgrid(indexing="xy") conventions
    Z = Z.T
    # Normalize numerically (samples outside `range` lose mass)
    integral = np.sum(Z) * dx * dy
    if integral > 0:
        Z /= integral
    return xc, yc, Z
# -----------------------------------------------------------------------------
# I/O helpers
# -----------------------------------------------------------------------------
[docs]
def read_results_file(path: str, *, delimiter: str = "\t") -> Table:
    """
    Read a CosmoSIS-style text results file:
    - First line is a commented header starting with '#'
    - Columns are delimiter-separated
    - Remaining lines numeric

    Parameters
    ----------
    path : str
        File to read.
    delimiter : str
        Column separator used in the header line. For a whitespace delimiter
        (the default tab) the numeric rows are parsed with numpy's default
        "any whitespace" splitting; any other delimiter is also forwarded to
        ``np.loadtxt`` so that e.g. comma-separated rows parse correctly.

    Returns
    -------
    astropy.table.Table
        Table with lowercase, stripped column names. An empty table is
        returned when the file holds at most one numeric value.

    Raises
    ------
    ValueError
        If the first line does not start with '#', or if the number of data
        columns differs from the number of header names.
    """
    with open(path, "r", encoding="utf-8") as f:
        header = f.readline()
    if not header.startswith("#"):
        raise ValueError("Expected first line header starting with '#'.")
    columns = header.strip("# \n").split(delimiter)

    # Previously `delimiter` was only applied to the header, so non-whitespace
    # separators could never be parsed. Whitespace delimiters keep the old
    # (delimiter=None) behaviour, which also accepts space-padded rows.
    data_delimiter = None if delimiter.isspace() else delimiter
    matrix = np.atleast_2d(np.loadtxt(path, delimiter=data_delimiter))

    tab = Table()
    # NOTE(review): `<= 1` also treats a genuine 1x1 data file as empty.
    if matrix.size <= 1:
        return tab
    if matrix.shape[1] != len(columns):
        raise ValueError(
            f"Data has {matrix.shape[1]} columns but header lists {len(columns)}."
        )
    for i, c in enumerate(columns):
        tab[c.strip().lower()] = matrix[:, i]
    return tab
def _select_parameter_keys(
    table: Table,
    *,
    parameter_prefix: str = "--",
    parameter_keys: Optional[Sequence[str]] = None,
) -> List[str]:
    """
    Pick the parameter columns to summarize.

    An explicit ``parameter_keys`` sequence wins and is returned as a list
    (it is not checked against the table here). Otherwise every column name
    of ``table`` containing ``parameter_prefix`` is selected.

    Raises
    ------
    ValueError
        If the resulting selection is empty.
    """
    if parameter_keys is None:
        selected = [name for name in table.colnames if parameter_prefix in name]
    else:
        selected = list(parameter_keys)
    if not selected:
        raise ValueError("No parameter keys found/selected.")
    return selected
[docs]
def _split_param_key(key: str, prefix: str = "--") -> Tuple[str, str]:
    """
    Split a parameter key into (section, name) using the delimiter/prefix.

    The split happens at the FIRST occurrence of ``prefix``; anything after
    it (including further delimiters) stays in ``name``.
    If the key cannot be split -- the delimiter is absent or empty --
    returns ("", key).
    """
    # An empty delimiter is "contained" in every string but cannot be used to
    # split (str.partition/str.split raise ValueError), so treat it as absent.
    if not prefix:
        return "", key
    section, found, name = key.partition(prefix)
    if found:
        return section, name
    return "", key
def _logsumexp_weighted(a: np.ndarray, w: np.ndarray) -> float:
    """
    Return ``log(sum_i w_i * exp(a_i))`` computed stably.

    The NaN-ignoring maximum of ``a`` is subtracted before exponentiating and
    added back afterwards; NaN terms are skipped in the sum.
    """
    log_terms = _as_float_array(a).ravel()
    coeffs = _as_float_array(w).ravel()
    shift = np.nanmax(log_terms)
    scaled_sum = np.nansum(coeffs * np.exp(log_terms - shift))
    return float(shift + np.log(scaled_sum))
[docs]
def effective_sample_size(weights: np.ndarray) -> float:
    """
    Importance-sampling effective sample size, ``1 / sum(w_i^2)``.

    The weights are normalized to unit sum first, so the result is N for
    N equal weights.
    """
    norm_w = normalize_weights(weights)
    sum_of_squares = np.sum(np.square(norm_w))
    return float(1.0 / sum_of_squares)
[docs]
def laplace_logz(loglike_map: float, logprior_map: float, cov: np.ndarray) -> EvidenceEstimate:
    """
    Laplace (Gaussian) approximation of the log-evidence around the MAP.

    Computes ``logL_map + logprior_map + (d/2) log(2 pi) + (1/2) log|cov|``
    with ``d = cov.shape[0]``.

    Returns
    -------
    EvidenceEstimate
        ``method="laplace"``. If the determinant of ``cov`` is not strictly
        positive (or its log is not finite), ``logz`` is NaN and ``details``
        explains why.
    """
    cov_arr = _as_float_array(cov)
    n_dim = cov_arr.shape[0]
    # slogdet avoids overflow/underflow of the raw determinant.
    sign, logdet = np.linalg.slogdet(cov_arr)

    unusable = sign <= 0 or not np.isfinite(logdet)
    if unusable:
        failure_details = {
            "reason": "covariance not positive definite",
            "sign": float(sign),
            "logdet": float(logdet),
        }
        return EvidenceEstimate(method="laplace", logz=float("nan"), details=failure_details)

    gaussian_norm = 0.5 * n_dim * np.log(2.0 * np.pi)
    logz = loglike_map + logprior_map + gaussian_norm + 0.5 * logdet
    success_details = {"d": int(n_dim), "logdet_cov": float(logdet)}
    return EvidenceEstimate(method="laplace", logz=float(logz), details=success_details)
[docs]
def harmonic_mean_logz(loglike: np.ndarray, weights: np.ndarray, *, trim_frac: float = 0.01) -> EvidenceEstimate:
    """
    Robust harmonic-mean estimator.

    Uses ``log Z = -log(E_posterior[exp(-logL)])`` where the expectation is a
    weighted average over the samples that survive masking and trimming.
    This estimator is generally unstable; trimming helps, but it should still
    be treated as a sanity check only.

    Parameters
    ----------
    loglike : array-like, shape (N,)
        Log-likelihood per sample.
    weights : array-like, shape (N,)
        Posterior weights (any positive scale).
    trim_frac : float
        If > 0, samples with ``loglike`` above the (unweighted)
        ``1 - trim_frac`` quantile are discarded.

    Returns
    -------
    EvidenceEstimate
        ``method="hme_robust"``; ``logz`` is NaN (with a ``reason`` in
        ``details``) if no usable sample or no positive weight remains.
    """
    ll = _as_float_array(loglike).ravel()
    w = normalize_weights(weights)
    mask = np.isfinite(ll) & np.isfinite(w)
    ll = ll[mask]
    w = w[mask]
    if ll.size == 0:
        return EvidenceEstimate(method="hme_robust", logz=float("nan"), details={"reason": "no finite loglike"})

    # Trim the top tail in loglike.
    # NOTE(review): exp(-logL) is largest for the LOWEST log-likelihoods, so
    # the high-variance terms sit in the bottom tail -- confirm that trimming
    # the top tail is what is intended here.
    if trim_frac > 0:
        cutoff = np.quantile(ll, 1.0 - trim_frac)
        keep = ll <= cutoff
        ll = ll[keep]
        w = w[keep]

    # Renormalize over the samples actually used. This must happen even when
    # no trimming is applied: dropping non-finite loglike rows above leaves
    # weights that no longer sum to one, which would bias the expectation.
    total = np.sum(w)
    if not total > 0:
        return EvidenceEstimate(
            method="hme_robust",
            logz=float("nan"),
            details={"reason": "no positive weight after masking/trimming"},
        )
    w = w / total

    log_term = _logsumexp_weighted(-ll, w)
    logz = -log_term
    return EvidenceEstimate(
        method="hme_robust",
        logz=float(logz),
        details={"n": int(ll.size), "ess": effective_sample_size(w), "trim_frac": float(trim_frac)},
    )
[docs]
@dataclass
class EvidenceEstimate:
    """
    Container for a scalar evidence estimate and method-specific metadata.

    Attributes
    ----------
    method : name of the estimator that produced the value (e.g. "laplace")
    logz : estimated log-evidence (NaN when the estimate failed)
    logz_err : optional uncertainty on ``logz``; None when not available
    details : optional method-specific diagnostics; None when not provided
    """

    method: str
    logz: float
    logz_err: Optional[float] = None
    # Annotated Optional because the default really is None (kept as None,
    # not {}, so existing consumers see the same default value).
    details: Optional[Dict[str, Any]] = None
# -----------------------------------------------------------------------------
# ResultsSummary dataclass
# -----------------------------------------------------------------------------
[docs]
@dataclass
class ResultsSummary:
    """
    Container for posterior summary products.

    Attributes
    ----------
    parameter_keys : list of full parameter keys (e.g. "sfh--tau")
    parameter_names : list of short names (e.g. "tau")
    parameter_sections : list of sections (e.g. "sfh")
    posterior_key : log-posterior column name used
    map_index : index of MAP (maximum log-posterior) sample among the *filtered* samples
    n_samples : number of samples used after filtering
    weights : normalized posterior weights, shape (n_samples,)
    logpost : log posterior values, shape (n_samples,)
    samples : array shape (n_params, n_samples)
    map : vector shape (n_params,)
    mean : vector shape (n_params,)
    covariance : matrix shape (n_params, n_params)
    correlation : matrix shape (n_params, n_params)
    percentiles : list of quantiles in [0,1]
    percentiles_values : array shape (n_params, n_percentiles)
    percentiles_logpost : array shape (n_params, n_percentiles)
    hdi_mass : mass used for HDI
    hdi_intervals : dict short_name -> list of (low, high)
    evidence : optional EvidenceEstimate, set by estimate_evidence_from_table
    pdf_1d : dict short_name -> dict with keys: grid, hist_pdf, kde_pdf
    pdf_2d : dict (short_i, short_j) -> dict with keys: xgrid, ygrid, pdf, enclosed_fraction
    extra_info : arbitrary metadata to include in exports
    """

    parameter_keys: List[str]
    posterior_key: str = "post"
    parameter_sections: List[str] = field(default_factory=list)
    parameter_names: List[str] = field(default_factory=list)
    n_samples: int = 0
    map_index: int = -1
    weights: np.ndarray = field(default_factory=lambda: np.array([], dtype=float))
    logpost: np.ndarray = field(default_factory=lambda: np.array([], dtype=float))
    samples: np.ndarray = field(default_factory=lambda: np.empty((0, 0), dtype=float))
    map: np.ndarray = field(default_factory=lambda: np.array([], dtype=float))
    mean: np.ndarray = field(default_factory=lambda: np.array([], dtype=float))
    covariance: np.ndarray = field(default_factory=lambda: np.empty((0, 0), dtype=float))
    correlation: np.ndarray = field(default_factory=lambda: np.empty((0, 0), dtype=float))
    percentiles: List[float] = field(default_factory=list)
    percentiles_values: np.ndarray = field(default_factory=lambda: np.empty((0, 0), dtype=float))
    percentiles_logpost: np.ndarray = field(default_factory=lambda: np.empty((0, 0), dtype=float))
    hdi_mass: float = 0.68
    hdi_intervals: Dict[str, List[Tuple[float, float]]] = field(default_factory=dict)
    evidence: Optional[EvidenceEstimate] = None
    pdf_1d: Dict[str, Dict[str, np.ndarray]] = field(default_factory=dict)
    pdf_2d: Dict[Tuple[str, str], Dict[str, np.ndarray]] = field(default_factory=dict)
    extra_info: Dict[str, Any] = field(default_factory=dict)

    def estimate_evidence_from_table(
        self,
        table: Table,
        *,
        logpost_key: Optional[str] = None,
        logprior_key: str = "prior",
        loglike_key: Optional[str] = None,
        method: str = "laplace",
        hme_trim_frac: float = 0.01,
        parameter_prefix: str = "--",
    ) -> EvidenceEstimate:
        """
        Estimate evidence logZ using information in the original results table.

        Supported methods are ``"laplace"`` (MAP loglike/logprior + covariance)
        and ``"hme_robust"`` (posterior expectation of ``1/L``, sanity-check;
        aliases ``"hme"``, ``"harmonic"``, ``"harmonic_mean"``).
        If ``loglike_key`` is not available, log-likelihood is reconstructed as
        ``loglike = post - prior``.

        Missing columns, an empty selection or an unknown method do not raise:
        they yield an estimate with ``logz = NaN`` and a ``reason`` in
        ``details``. An estimate produced by one of the methods (or by the
        unknown-method branch) is also stored on ``self.evidence``.

        ``parameter_prefix`` is accepted for interface compatibility but is
        not used by this method.
        """
        if logpost_key is None:
            logpost_key = self.posterior_key
        for required in (logpost_key, logprior_key):
            if required not in table.colnames:
                return EvidenceEstimate(
                    method=method, logz=float("nan"), details={"reason": f"missing '{required}'"}
                )

        logpost_all = _as_float_array(table[logpost_key])
        logprior_all = _as_float_array(table[logprior_key])
        if loglike_key is not None and loglike_key in table.colnames:
            loglike_all = _as_float_array(table[loglike_key])
        else:
            # Reconstruct loglike = post - prior
            loglike_all = logpost_all - logprior_all

        # Keep rows where posterior, prior, likelihood and every parameter are
        # finite. NOTE(review): meant to mirror the filtering done when the
        # summary was built -- confirm both masks agree, since the Laplace
        # branch combines this MAP row with self.covariance.
        mask = np.isfinite(logpost_all) & np.isfinite(logprior_all) & np.isfinite(loglike_all)
        for k in self.parameter_keys:
            if k not in table.colnames:
                return EvidenceEstimate(
                    method=method,
                    logz=float("nan"),
                    details={"reason": f"missing parameter column '{k}'"},
                )
            mask &= np.isfinite(_as_float_array(table[k]))
        if not np.any(mask):
            return EvidenceEstimate(
                method=method, logz=float("nan"), details={"reason": "no finite rows after masking"}
            )

        logpost = logpost_all[mask]
        logprior = logprior_all[mask]
        loglike = loglike_all[mask]

        # Weights from posterior (max subtracted to avoid overflow in exp)
        w = normalize_weights(np.exp(logpost - np.max(logpost)))
        map_idx = int(np.nanargmax(logpost))
        ll_map = float(loglike[map_idx])
        lp_map = float(logprior[map_idx])

        method_name = method.lower()
        if method_name == "laplace":
            est = laplace_logz(ll_map, lp_map, self.covariance)
        elif method_name in ("hme", "hme_robust", "harmonic", "harmonic_mean"):
            est = harmonic_mean_logz(loglike, w, trim_frac=hme_trim_frac)
        else:
            est = EvidenceEstimate(
                method=method, logz=float("nan"), details={"reason": f"unknown method '{method}'"}
            )

        # Store on the object for export
        self.evidence = est
        return est

    def to_json_dict(self) -> Dict[str, Any]:
        """
        Convert summary to a JSON-serializable dictionary (numpy arrays -> lists).

        2D PDF entries are keyed ``"<name0>__<name1>"`` because JSON object
        keys must be strings. ``extra_info`` and the evidence ``details`` are
        passed through unchanged, so they must already be serializable.
        """

        def _tolist(a):
            return a.tolist() if isinstance(a, np.ndarray) else a

        if self.evidence is None:
            evidence = None
        else:
            evidence = {
                "method": self.evidence.method,
                "logz": self.evidence.logz,
                "logz_err": self.evidence.logz_err,
                "details": self.evidence.details,
            }

        out: Dict[str, Any] = {
            "parameter_keys": self.parameter_keys,
            "parameter_sections": self.parameter_sections,
            "parameter_names": self.parameter_names,
            "posterior_key": self.posterior_key,
            "n_samples": int(self.n_samples),
            "map_index": int(self.map_index),
            "weights": _tolist(self.weights),
            "logpost": _tolist(self.logpost),
            "map": _tolist(self.map),
            "mean": _tolist(self.mean),
            "covariance": _tolist(self.covariance),
            "correlation": _tolist(self.correlation),
            "percentiles": _tolist(np.asarray(self.percentiles, dtype=float)),
            "percentiles_values": _tolist(self.percentiles_values),
            "percentiles_logpost": _tolist(self.percentiles_logpost),
            "hdi_mass": float(self.hdi_mass),
            "hdi_intervals": {k: [list(iv) for iv in v] for k, v in self.hdi_intervals.items()},
            "evidence": evidence,
            "pdf_1d": {
                k: {kk: _tolist(vv) for kk, vv in d.items()}
                for k, d in self.pdf_1d.items()
            },
            "pdf_2d": {
                f"{k0}__{k1}": {kk: _tolist(vv) for kk, vv in d.items()}
                for (k0, k1), d in self.pdf_2d.items()
            },
            "extra_info": self.extra_info,
        }
        return out

    def write_json(self, path: str, *, overwrite: bool = True, indent: int = 2) -> str:
        """
        Write a JSON summary file and return its path.

        Raises
        ------
        FileExistsError
            If ``overwrite`` is False and ``path`` already exists.
        """
        if (not overwrite) and os.path.exists(path):
            raise FileExistsError(path)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_json_dict(), f, indent=indent)
        return path

    def to_fits(self) -> fits.HDUList:
        """
        Build a FITS HDUList with:
        - PrimaryHDU: extra_info (+ LOGZ / LOGZMET if a finite evidence exists)
        - ImageHDU: COVARIANCE (+ header with per-parameter names/means/MAP)
        - ImageHDU: CORRELATION
        - BinTableHDU: PERCENTILES (+ header with HDI limits)
        - BinTableHDU: PDF1D (only if 1D PDFs are present)
        - ImageHDUs: PDF2D_<i>__<j> and ENCFRAC_<i>__<j>

        Non-finite header values are stored as the string "nan".
        """
        prim = fits.PrimaryHDU()
        # Store extra_info as header cards (best effort: keys are truncated to
        # 8 characters and values a FITS header cannot hold are skipped).
        for k, v in (self.extra_info or {}).items():
            try:
                prim.header[str(k)[:8]] = v
            except Exception:
                # Skip non-serializable header items
                continue

        hdr = fits.Header()
        hdr["N_SAMP"] = int(self.n_samples)
        hdr["MAP_IDX"] = int(self.map_index)
        hdr["POSTKEY"] = self.posterior_key

        # Per-parameter header cards (best effort: keep short names)
        per_parameter = zip(self.parameter_keys, self.parameter_sections, self.parameter_names)
        for i, (full_key, sect, name) in enumerate(per_parameter):
            # FITS keyword length limits: use e.g. P000NM, P000SC, P000MN, P000MP
            tag = f"P{i:03d}"
            hdr[f"{tag}NM"] = name[:68]
            hdr[f"{tag}SC"] = sect[:68]
            hdr[f"{tag}KY"] = full_key[:68]
            if i < self.mean.size:
                if np.isfinite(self.mean[i]):
                    hdr[f"{tag}MN"] = float(self.mean[i]), "mean"
                else:
                    hdr[f"{tag}MN"] = "nan", "mean not finite"
            if i < self.map.size:
                if np.isfinite(self.map[i]):
                    hdr[f"{tag}MP"] = float(self.map[i]), "MAP"
                else:
                    hdr[f"{tag}MP"] = "nan", "map not finite"

        if self.evidence is not None and np.isfinite(self.evidence.logz):
            prim.header["LOGZ"] = float(self.evidence.logz)
            prim.header["LOGZMET"] = self.evidence.method[:20]

        hdus: List[Any] = [prim]
        hdus.append(fits.ImageHDU(data=_as_float_array(self.covariance), header=hdr, name="COVARIANCE"))
        hdus.append(fits.ImageHDU(data=_as_float_array(self.correlation), name="CORRELATION"))

        # Percentiles table: one row per percentile, two columns per parameter.
        n_pct = len(self.percentiles)
        t_pct = Table()
        t_pct["percentile"] = np.asarray(self.percentiles, dtype=float)
        for i, name in enumerate(self.parameter_names):
            if self.percentiles_values.size:
                t_pct[f"{name}_val"] = _as_float_array(self.percentiles_values[i, :])
            else:
                t_pct[f"{name}_val"] = np.full(n_pct, np.nan)
            if self.percentiles_logpost.size:
                t_pct[f"{name}_logp"] = _as_float_array(self.percentiles_logpost[i, :])
            else:
                t_pct[f"{name}_logp"] = np.full(n_pct, np.nan)

        # Add HDI intervals as header cards on the percentiles HDU (compact)
        pct_hdr = fits.Header()
        pct_hdr["HDIMASS"] = float(self.hdi_mass) if np.isfinite(self.hdi_mass) else "nan", "HDI mass"
        for name, ivs in self.hdi_intervals.items():
            # Only the first 2 intervals are stored; keywords use the first
            # 6 characters of the parameter name.
            for j, (lo, hi) in enumerate(ivs[:2]):
                pct_hdr[f"{name[:6]}L{j}"] = float(lo) if np.isfinite(lo) else "nan", "lower limit"
                pct_hdr[f"{name[:6]}H{j}"] = float(hi) if np.isfinite(hi) else "nan", "upper limit"
        hdus.append(fits.BinTableHDU(t_pct, name="PERCENTILES", header=pct_hdr))

        # PDF1D table: store grid/pdf/kde per parameter as separate columns
        t_pdf1 = Table()
        for name, d in self.pdf_1d.items():
            t_pdf1[f"{name}_x"] = _as_float_array(d["grid"])
            t_pdf1[f"{name}_pdf"] = _as_float_array(d["hist_pdf"])
            t_pdf1[f"{name}_kde"] = _as_float_array(d.get("kde_pdf", np.full_like(d["grid"], np.nan)))
        if len(t_pdf1.colnames) > 0:
            hdus.append(fits.BinTableHDU(t_pdf1, name="PDF1D"))

        # PDF2D images
        for (n0, n1), d in self.pdf_2d.items():
            # pdf stored as (ny, nx) matching ygrid,xgrid
            hdr2 = fits.Header()
            hdr2["AX0"] = n0
            hdr2["AX1"] = n1
            hdus.append(
                fits.ImageHDU(data=_as_float_array(d["pdf"]), header=hdr2, name=f"PDF2D_{n0}__{n1}"[:68])
            )
            if "enclosed_fraction" in d:
                hdus.append(
                    fits.ImageHDU(
                        data=_as_float_array(d["enclosed_fraction"]),
                        header=hdr2,
                        name=f"ENCFRAC_{n0}__{n1}"[:68],
                    )
                )
        return fits.HDUList(hdus)

    def write_fits(self, path: str, *, overwrite: bool = True) -> str:
        """Write the FITS summary built by ``to_fits`` to ``path`` and return the path."""
        hdul = self.to_fits()
        hdul.writeto(path, overwrite=overwrite)
        return path

    # -------------------------------------------------------------------------
    # Plot helpers (optional)
    # -------------------------------------------------------------------------
    def plot_1d_pdfs(
        self,
        outdir: str,
        *,
        show: bool = False,
        dpi: int = 200,
    ) -> List[str]:
        """
        Save 1D PDF plots for all parameters that have an entry in ``pdf_1d``.

        Each figure is written to ``<outdir>/pdf1d_<name>.png`` (the directory
        is created if needed). Returns the list of written paths.
        """
        os.makedirs(outdir, exist_ok=True)
        paths: List[str] = []
        for i, name in enumerate(self.parameter_names):
            if name not in self.pdf_1d:
                continue
            d = self.pdf_1d[name]
            fig, ax = plt.subplots()
            ax.plot(d["grid"], d["hist_pdf"], label="hist")
            if "kde_pdf" in d and np.any(np.isfinite(d["kde_pdf"])):
                ax.plot(d["grid"], d["kde_pdf"], label="kde")
            if i < self.mean.size:
                ax.axvline(self.mean[i], label="mean")
            if i < self.map.size:
                ax.axvline(self.map[i], label="MAP")
            # Percentiles
            if self.percentiles_values.size:
                for v in self.percentiles_values[i, : len(self.percentiles)]:
                    ax.axvline(v, alpha=0.5)
            # HDI
            if name in self.hdi_intervals:
                for lo, hi in self.hdi_intervals[name]:
                    ax.axvspan(lo, hi, alpha=0.15)
            ax.set_title(name)
            ax.set_xlabel(name)
            ax.set_ylabel("PDF")
            ax.legend()
            fp = os.path.join(outdir, f"pdf1d_{name}.png")
            fig.savefig(fp, dpi=dpi, bbox_inches="tight")
            paths.append(fp)
            if show:
                plt.show()
            else:
                plt.close(fig)
        return paths

    def plot_2d_pdfs(
        self,
        outdir: str,
        *,
        levels: Sequence[float] = (0.68, 0.95),
        show: bool = False,
        dpi: int = 200,
    ) -> List[str]:
        """
        Save 2D PDF plots with HPD enclosed-fraction contours.

        One figure per ``pdf_2d`` entry is written to
        ``<outdir>/pdf2d_<n0>__<n1>.png``; contours at ``levels`` are drawn
        only when the entry has a finite ``enclosed_fraction`` map.
        Returns the list of written paths.
        """
        os.makedirs(outdir, exist_ok=True)
        paths: List[str] = []
        for (n0, n1), d in self.pdf_2d.items():
            xg = _as_float_array(d["xgrid"])
            yg = _as_float_array(d["ygrid"])
            pdf = _as_float_array(d["pdf"])  # (ny, nx)
            frac = _as_float_array(d.get("enclosed_fraction", np.full_like(pdf, np.nan)))
            fig, ax = plt.subplots()
            X, Y = np.meshgrid(xg, yg, indexing="xy")
            im = ax.pcolormesh(X, Y, pdf, shading="auto", cmap="Greys")
            fig.colorbar(im, ax=ax, label="PDF")
            if np.any(np.isfinite(frac)):
                cs = ax.contour(X, Y, frac, levels=list(levels), linewidths=1.0)
                ax.clabel(cs, inline=True, fontsize=8, fmt=lambda v: f"{v:.2f}")
            ax.set_xlabel(n0)
            ax.set_ylabel(n1)
            ax.set_title(f"{n0} vs {n1}")
            fp = os.path.join(outdir, f"pdf2d_{n0}__{n1}.png")
            fig.savefig(fp, dpi=dpi, bbox_inches="tight")
            paths.append(fp)
            if show:
                plt.show()
            else:
                plt.close(fig)
        return paths

    def corner_plot(
        self,
        outpath: str,
        *,
        bins: int = 50,
        max_points: int = 20000,
        dpi: int = 200,
        show: bool = False,
    ) -> str:
        """
        Lightweight corner plot without external dependencies.

        Diagonal: 1D hist PDFs (weighted), with mean and MAP marked.
        Off-diagonal (lower triangle): scatter of (subsampled) points.
        When there are more than ``max_points`` samples, points are drawn
        without replacement with probability proportional to their weight.

        For serious usage, consider adding an optional dependency later
        (corner/arviz), but this is a decent built-in baseline.

        Raises
        ------
        ValueError
            If there are no parameters or no samples.
        """
        npar = len(self.parameter_names)
        if npar == 0 or self.samples.size == 0:
            raise ValueError("No samples available to plot.")

        # Subsample points for scatter
        n_total = self.n_samples
        if n_total > max_points:
            n_nonzero = int(np.count_nonzero(self.weights))
            if 0 < n_nonzero <= max_points:
                # choice(replace=False, p=...) raises when asked for more draws
                # than there are non-zero probabilities; in that case every
                # sample that carries weight fits within the budget anyway.
                idx = np.flatnonzero(self.weights)
            else:
                # sample indices proportional to weights
                idx = np.random.choice(np.arange(n_total), size=max_points, replace=False, p=self.weights)
        else:
            idx = np.arange(n_total)
        S = self.samples[:, idx]

        # squeeze=False keeps `axes` 2D even for a single parameter, where
        # plt.subplots would otherwise return a bare Axes object.
        fig, axes = plt.subplots(
            npar,
            npar,
            figsize=(2.2 * npar, 2.2 * npar),
            constrained_layout=True,
            squeeze=False,
        )
        for i in range(npar):
            for j in range(npar):
                ax = axes[i, j]
                if i == j:
                    x = self.samples[i, :]
                    xc, pdf = histogram_pdf_1d(x, self.weights, bins=bins)
                    ax.plot(xc, pdf, lw=1.2)
                    # mark mean/MAP
                    ax.axvline(self.mean[i], lw=1.0, alpha=0.8)
                    ax.axvline(self.map[i], lw=1.0, alpha=0.8)
                    ax.set_yticks([])
                elif i > j:
                    ax.scatter(S[j, :], S[i, :], s=2, alpha=0.25)
                else:
                    ax.axis("off")
                # Axis labels only on the bottom row / left column.
                if i == npar - 1 and j <= i:
                    ax.set_xlabel(self.parameter_names[j])
                else:
                    ax.set_xticks([])
                if j == 0 and i >= j:
                    ax.set_ylabel(self.parameter_names[i])
                else:
                    ax.set_yticks([])

        fig.savefig(outpath, dpi=dpi, bbox_inches="tight")
        if show:
            plt.show()
        else:
            plt.close(fig)
        return outpath
# -----------------------------------------------------------------------------
# Main summarization function
# -----------------------------------------------------------------------------
[docs]
def summarize_results(
    table: Table,
    *,
    output_fits: Optional[str] = None,
    output_json: Optional[str] = None,
    parameter_prefix: str = "--",
    posterior_key: str = "post",
    parameter_keys: Optional[Sequence[str]] = None,
    percentiles: Sequence[float] = (0.05, 0.16, 0.5, 0.84, 0.95),
    hdi_mass: float = 0.68,
    compute_1d: bool = True,
    compute_2d: bool = False,
    parameter_key_pairs: Optional[Sequence[Tuple[str, str]]] = None,
    pdf_bins_1d: int = 100,
    pdf_bins_2d: int = 80,
    estimate_evidence: bool = False,
    evidence_method: str = "laplace",
    logprior_key: str = "prior",
    loglike_key: Optional[str] = None,
    hme_trim_frac: float = 0.01,
    kde_1d: bool = True,
    kde_2d: bool = True,
    extra_info: Optional[Dict[str, Any]] = None,
    verbose: bool = False,
) -> ResultsSummary:
    """
    Summarize posterior results from a samples table into a ResultsSummary.

    Rows with a non-finite log-posterior or a non-finite value in any
    selected parameter are discarded. Weights are then computed from the
    surviving log-posterior values as ``exp(logpost - max(logpost))`` and
    normalized, so the stored weights always sum to one.

    Parameters
    ----------
    table : astropy.table.Table
        Input results table.
    output_fits : str, optional
        If provided, writes a FITS file with summary products.
    output_json : str, optional
        If provided, writes a JSON file with summary products.
    parameter_prefix : str
        Delimiter/prefix used in parameter names (default: "--").
    posterior_key : str
        Name of the log-posterior column (default: "post").
    parameter_keys : list of str, optional
        Explicit list of parameter columns, forwarded to
        ``_select_parameter_keys`` together with ``parameter_prefix``.
    percentiles : sequence of float
        Quantiles in [0,1].
    hdi_mass : float
        Target mass for HDI intervals.
    compute_1d / compute_2d : bool
        Enable 1D/2D PDF products.
    parameter_key_pairs : list of (key1, key2), required if compute_2d=True
        Pairs of full parameter keys for which 2D PDFs are computed.
    pdf_bins_1d / pdf_bins_2d : int
        Number of bins/grid size for PDFs.
    estimate_evidence : bool
        If True, calls ``summary.estimate_evidence_from_table`` with
        ``logprior_key``, ``loglike_key``, ``evidence_method``,
        ``hme_trim_frac`` and ``parameter_prefix``.
    evidence_method, logprior_key, loglike_key, hme_trim_frac
        Forwarded unchanged to ``estimate_evidence_from_table``.
    kde_1d : bool
        If True, a ``"kde_pdf"`` entry (from ``kde_pdf_1d``) is stored next
        to the histogram PDF of every parameter.
    kde_2d : bool
        Forwarded as ``use_kde`` to ``kde_or_hist_pdf_2d``.
    extra_info : dict
        Arbitrary metadata stored on the summary (copied).
    verbose : bool
        Log progress.

    Returns
    -------
    ResultsSummary

    Raises
    ------
    KeyError
        If ``posterior_key`` is not a column of ``table`` or a requested
        pair is not among the selected parameter keys.
    ValueError
        If no sample survives the finiteness filters, if ``percentiles``
        fall outside [0, 1], or if ``compute_2d`` is True without
        ``parameter_key_pairs``.
    """
    if posterior_key not in table.colnames:
        raise KeyError(f"posterior_key='{posterior_key}' not in table.")
    keys = _select_parameter_keys(table, parameter_prefix=parameter_prefix, parameter_keys=parameter_keys)

    # --- Extract and filter samples -----------------------------------------
    logpost_all = _as_float_array(table[posterior_key])
    mask = np.isfinite(logpost_all)
    if not np.any(mask):
        raise ValueError("No finite samples after masking posterior/parameters.")
    logpost = logpost_all[mask]

    # Samples matrix (D, N), restricted to rows with a finite posterior.
    samples = np.vstack([_as_float_array(table[k])[mask] for k in keys])

    # Drop rows with a non-finite value in any selected parameter.
    finite_mask = np.isfinite(samples).all(axis=0)
    if not np.any(finite_mask):
        raise ValueError("No finite samples after masking posterior/parameters.")
    logpost = logpost[finite_mask]
    samples = samples[:, finite_mask]

    # Stabilized weights, computed AFTER all filtering so that they sum to
    # one over exactly the samples that are kept (the largest term is exp(0)).
    max_lp = np.max(logpost)
    w = normalize_weights(np.exp(logpost - max_lp))

    npar, nsamp = samples.shape
    map_idx = np.argmax(logpost)
    map_vec = samples[:, map_idx]
    mean_vec = weighted_mean(samples, w, axis=1)
    cov = weighted_covariance(samples, w, unbiased=False)
    corr = covariance_to_correlation(cov)

    # --- Names/sections -----------------------------------------------------
    sections = []
    names = []
    for k in keys:
        sect, nm = _split_param_key(k, prefix=parameter_prefix)
        sections.append(sect)
        names.append(nm)

    if verbose:
        logger.info("Summarizing: %s parameters, %s samples (filtered).", npar, nsamp)
        logger.info("Max logpost = %.6g", max_lp)

    # --- Percentiles --------------------------------------------------------
    pct = np.asarray(percentiles, dtype=float)
    if np.any((pct < 0) | (pct > 1)):
        raise ValueError("percentiles must be in [0,1].")
    pct_vals = np.full((npar, pct.size), np.nan, dtype=float)
    pct_lp = np.full((npar, pct.size), np.nan, dtype=float)
    for i in range(npar):
        x = samples[i, :]
        pct_vals[i, :] = weighted_quantile(x, w, pct)
        # Log-posterior at those percentiles: order the samples by x, build
        # the weighted CDF along that ordering and interpolate logpost on it.
        order = np.argsort(x)
        cdf = np.cumsum(w[order])
        cdf[-1] = 1.0  # guard against round-off in the last cumulative value
        pct_lp[i, :] = np.interp(pct, cdf, logpost[order])

    # --- HDI intervals ------------------------------------------------------
    hdi: Dict[str, List[Tuple[float, float]]] = {}
    for i, nm in enumerate(names):
        hdi[nm] = weighted_hdi(samples[i, :], w, mass=hdi_mass, max_intervals=2)

    # --- 1D PDFs ------------------------------------------------------------
    pdf1d: Dict[str, Dict[str, np.ndarray]] = {}
    if compute_1d:
        for i, nm in enumerate(names):
            grid, hist_pdf = histogram_pdf_1d(samples[i, :], w, bins=pdf_bins_1d)
            d = {"grid": grid, "hist_pdf": hist_pdf}
            if kde_1d:
                d["kde_pdf"] = kde_pdf_1d(samples[i, :], w, grid)
            pdf1d[nm] = d

    # --- 2D PDFs ------------------------------------------------------------
    pdf2d: Dict[Tuple[str, str], Dict[str, np.ndarray]] = {}
    if compute_2d:
        if parameter_key_pairs is None:
            raise ValueError("compute_2d=True requires parameter_key_pairs.")
        # Map full keys to row indices of `samples`.
        key_to_idx = {k: i for i, k in enumerate(keys)}
        for k0, k1 in parameter_key_pairs:
            if k0 not in key_to_idx or k1 not in key_to_idx:
                raise KeyError(f"Pair ({k0}, {k1}) not in selected parameter keys.")
            i0 = key_to_idx[k0]
            i1 = key_to_idx[k1]
            n0 = names[i0]
            n1 = names[i1]
            xg, yg, pdf = kde_or_hist_pdf_2d(
                samples[i0, :],
                samples[i1, :],
                w,
                bins=pdf_bins_2d,
                use_kde=kde_2d,
            )
            # pdf returned as (ny, nx) corresponding to yg, xg
            frac = enclosed_fraction_map(pdf)
            pdf2d[(n0, n1)] = {
                "xgrid": xg,
                "ygrid": yg,
                "pdf": pdf,
                "enclosed_fraction": frac,
            }

    summary = ResultsSummary(
        parameter_keys=list(keys),
        posterior_key=posterior_key,
        parameter_sections=sections,
        parameter_names=names,
        n_samples=int(nsamp),
        map_index=int(map_idx),
        weights=w,
        logpost=logpost,
        samples=samples,
        map=map_vec,
        mean=mean_vec,
        covariance=cov,
        correlation=corr,
        percentiles=list(map(float, pct.tolist())),
        percentiles_values=pct_vals,
        percentiles_logpost=pct_lp,
        hdi_mass=float(hdi_mass),
        hdi_intervals=hdi,
        pdf_1d=pdf1d,
        pdf_2d=pdf2d,
        extra_info=dict(extra_info) if extra_info is not None else {},
    )

    if estimate_evidence:
        summary.estimate_evidence_from_table(
            table,
            logprior_key=logprior_key,
            loglike_key=loglike_key,  # None -> will reconstruct from post-prior
            method=evidence_method,
            hme_trim_frac=hme_trim_frac,
            parameter_prefix=parameter_prefix,
        )

    if output_fits is not None:
        summary.write_fits(output_fits, overwrite=True)
    if output_json is not None:
        summary.write_json(output_json, overwrite=True)
    return summary
# -----------------------------------------------------------------------------
# Convenience wrappers for file-based workflows
# -----------------------------------------------------------------------------
[docs]
def summarize_results_file(
    results_path: str,
    *,
    output_fits: Optional[str] = None,
    output_json: Optional[str] = None,
    delimiter: str = "\t",
    **kwargs,
) -> ResultsSummary:
    """
    Load a results file and return its summary.

    The file is read with ``read_results_file(results_path, delimiter=...)``
    and the resulting table is handed to ``summarize_results`` together with
    ``output_fits``, ``output_json`` and any extra keyword arguments.
    """
    # Gather everything forwarded to summarize_results in one place.
    options = {"output_fits": output_fits, "output_json": output_json, **kwargs}
    return summarize_results(
        read_results_file(results_path, delimiter=delimiter),
        **options,
    )
[docs]
def compute_chain_percentiles(
    chain_results: Mapping[str, Any],
    *,
    pct: Sequence[float] = (0.16, 0.5, 0.84),
    weight_key: str = "weight",
    parameter_prefix: str = "--",
) -> Dict[str, np.ndarray]:
    """
    Weighted quantiles of every parameter in a chain-like mapping.

    ``chain_results[weight_key]`` provides the weights. Every other entry
    whose key contains ``parameter_prefix`` is flattened and passed to
    ``weighted_quantile`` with the requested ``pct``; the outputs are
    returned keyed by the original entry name.

    Raises
    ------
    KeyError
        If ``weight_key`` is missing from ``chain_results``.
    """
    if weight_key not in chain_results:
        raise KeyError(f"'{weight_key}' not present in chain_results.")
    weights = _as_float_array(chain_results[weight_key]).ravel()
    # Skip the weight entry itself and anything that is not a parameter.
    return {
        name: weighted_quantile(_as_float_array(column).ravel(), weights, pct)
        for name, column in chain_results.items()
        if name != weight_key and parameter_prefix in name
    }
[docs]
def make_plot_chains(
    chain_results: Mapping[str, Any],
    *,
    truth_values: Optional[Mapping[str, float]] = None,
    weight_key: str = "weight",
    parameter_prefix: str = "--",
    outdir: Optional[str] = None,
    show: bool = False,
    dpi: int = 200,
) -> List[str]:
    """
    Plot simple trace + weighted histogram per parameter.

    One figure is produced for every entry of ``chain_results`` whose key
    contains ``parameter_prefix`` (the weight entry is skipped). The main
    axes show the raw chain values against their index; an inset on the
    right shows the weighted histogram of the finite values.

    Parameters
    ----------
    chain_results : dict-like
        Contains arrays for parameters and `weight_key`.
    truth_values : dict, optional
        Mapping parameter name -> truth value (drawn as red lines).
    weight_key : str
        Key of the weights entry in ``chain_results``.
    parameter_prefix : str
        Substring a key must contain to be treated as a parameter.
    outdir : str, optional
        If provided, the directory is created if needed and one PNG per
        parameter is saved there.
    show : bool
        If True call ``plt.show()`` for each figure; otherwise each figure
        is closed once it has been (optionally) saved.
    dpi : int
        Resolution passed to ``Figure.savefig``.

    Returns
    -------
    list of str
        Paths of the saved PNG files (empty when ``outdir`` is None).

    Raises
    ------
    KeyError
        If ``weight_key`` is missing from ``chain_results``.
    """
    if weight_key not in chain_results:
        raise KeyError(f"'{weight_key}' not present in chain_results.")
    w = _as_float_array(chain_results[weight_key]).ravel()

    paths: List[str] = []
    if outdir is not None:
        os.makedirs(outdir, exist_ok=True)

    for par, vals in chain_results.items():
        if par == weight_key or parameter_prefix not in par:
            continue
        x = _as_float_array(vals).ravel()

        truth = np.nan
        if truth_values is not None and par in truth_values:
            truth = float(truth_values[par])
        has_truth = np.isfinite(truth)

        fig = plt.figure(constrained_layout=True, figsize=(7, 3))
        ax = fig.add_subplot(111)
        ax.plot(x, ",", c="k", alpha=0.6)
        if has_truth:
            ax.axhline(truth, c="r", lw=1.0)

        # Inset to the right of the trace: weighted histogram of finite values.
        finite = np.isfinite(x)
        inax = ax.inset_axes((1.02, 0.0, 0.45, 1.0))
        inax.hist(x[finite], weights=normalize_weights(w[finite]), bins=60)
        if has_truth:
            inax.axvline(truth, c="r", lw=1.0)
        inax.set_yticks([])

        ax.set_title(par)
        ax.set_xlabel("step")
        ax.set_ylabel(par)

        if outdir is not None:
            # Replace characters that are awkward in file names.
            safe = par.replace("/", "_").replace(" ", "_").replace(":", "_")
            fp = os.path.join(outdir, f"chain_{safe}.png")
            fig.savefig(fp, dpi=dpi, bbox_inches="tight")
            paths.append(fp)

        # Same show/close handling as the other plotting helpers in this
        # module; closing avoids accumulating one open figure per parameter.
        if show:
            plt.show()
        else:
            plt.close(fig)
    return paths
[docs]
def weighted_quantiles(x: np.ndarray, w: np.ndarray,
                       qs: Sequence[float]) -> np.ndarray:
    """
    Compute weighted quantiles for a one-dimensional sample.

    Entries with a non-finite value, a non-finite weight or a negative
    weight are ignored. The remaining values are sorted, the normalized
    cumulative weight is used as the CDF and the quantiles are obtained by
    linear interpolation of the sorted values on that CDF.

    Parameters
    ----------
    x : array-like
        Sample values (flattened).
    w : array-like
        Weights, one per sample (flattened).
    qs : sequence of float
        Requested quantiles, expressed as cumulative probabilities.

    Returns
    -------
    ndarray
        Quantile values with the shape of ``qs``; all NaN when no usable
        sample remains or when the usable weights sum to zero.
    """
    values = np.asarray(x, dtype=float).ravel()
    weights = np.asarray(w, dtype=float).ravel()
    usable = np.isfinite(values) & np.isfinite(weights) & (weights >= 0)
    values, weights = values[usable], weights[usable]
    if values.size == 0:
        return np.full(np.shape(qs), np.nan)

    order = np.argsort(values)
    values = values[order]
    cdf = np.cumsum(weights[order])
    total = cdf[-1]
    if total <= 0:
        # All usable weights are zero: the CDF is undefined, so report NaN
        # explicitly instead of dividing by zero.
        return np.full(np.shape(qs), np.nan)
    return np.interp(qs, cdf / total, values)
[docs]
def pdf_stats(edges: np.ndarray, pdf: np.ndarray,
              quantiles=None,
              find_multimodal: bool = False) -> dict:
    """
    Compute summary stats from a piecewise-constant pdf defined on bins.

    The input ``pdf`` is never modified; a normalized float copy is used
    internally.

    Parameters
    ----------
    edges : ndarray, shape (K+1,)
        Bin edges.
    pdf : ndarray, shape (K,)
        Density value in each of the K bins delimited by ``edges``.
    quantiles : sequence of float, optional
        Cumulative probabilities at which to report quantiles under the
        key ``"q"``. If None (default), ``"q"`` is None. The median and the
        16%/84% quantiles are always reported separately.
    find_multimodal : bool, optional
        If True, return rough modes by local-maximum search.

    Returns
    -------
    stats : dict
        Keys: mean, std, map, q, median, lo68, hi68 and, when
        ``find_multimodal`` is True, modes.
    """
    edges = np.asarray(edges, dtype=float)
    widths = np.diff(edges)
    # Work on a float copy: the caller's array must not be normalized in
    # place, and integer-valued input must be accepted.
    density = np.array(pdf, dtype=float)
    norm = np.sum(density * widths)
    if norm > 0:
        density = density / norm

    centers = 0.5 * (edges[:-1] + edges[1:])
    mass = density * widths  # probability contained in each bin

    # Moments with each bin's mass placed at its centre.
    mean = np.sum(mass * centers)
    var = np.sum(mass * (centers - mean) ** 2)
    std = var ** 0.5

    # MAP is the centre of the bin holding the largest probability mass.
    v_map = centers[np.argmax(mass)]

    # CDF evaluated at the bin edges (starts at 0 on the first edge).
    cdf = np.concatenate(([0.0], np.cumsum(mass)))

    def _quantile(q):
        return np.interp(q, cdf, edges, left=edges[0], right=edges[-1])

    qvals = _quantile(np.array(quantiles, float)) if quantiles is not None else None
    out = {"mean": mean, "std": std, "map": v_map, "q": qvals,
           "median": _quantile(0.5), "lo68": _quantile(0.16),
           "hi68": _quantile(0.84)}

    if find_multimodal:
        # Interior bins strictly higher than both neighbours.
        inner = density[1:-1]
        is_peak = (inner > density[:-2]) & (inner > density[2:])
        modes = centers[1:-1][is_peak]
        if modes.size == 0:
            modes = [v_map]
        out["modes"] = np.asarray(modes)
    return out
[docs]
def pit_from_discrete_posterior(z_true: np.ndarray,
                                posts: np.ndarray,
                                z_edges: np.ndarray) -> np.ndarray:
    """
    Probability Integral Transform for posteriors tabulated on bins.

    For every object the PIT is the posterior mass in the bins below the
    one holding the true value, plus the fraction of that bin's mass lying
    below the true value (uniform density is assumed inside a bin).

    Parameters
    ----------
    z_true : ndarray, shape (N,)
        True values.
    posts : ndarray, shape (N, K)
        Row-normalised posteriors over K bins.
    z_edges : ndarray, shape (K+1,)
        Bin edges.

    Returns
    -------
    pit : ndarray, shape (N,)
        PIT values clipped to [0, 1].
    """
    n_obj, n_bins = posts.shape
    assert n_bins == len(z_edges) - 1

    # Bin holding each true value; values outside the grid are assigned to
    # the first/last bin.
    bin_idx = np.clip(np.digitize(z_true, z_edges) - 1, 0, n_bins - 1)
    rows = np.arange(n_obj)

    # Posterior mass strictly below the selected bin (zero for the first bin).
    cumulative = np.cumsum(posts, axis=1)
    mass_below = np.where(bin_idx > 0, cumulative[rows, bin_idx - 1], 0.0)

    # Fractional position of the true value inside its bin.
    bin_widths = z_edges[1:] - z_edges[:-1]
    position = (z_true - z_edges[bin_idx]) / bin_widths[bin_idx]
    within = np.clip(position, 0.0, 1.0)

    return np.clip(mass_below + posts[rows, bin_idx] * within, 0.0, 1.0)
[docs]
def photoz_metrics(z_true: np.ndarray, z_est: np.ndarray, *,
                   outlier_threshold: float = 0.15) -> dict:
    """
    Standard photo-z metrics using delta z over 1+z.

    All metrics are computed on ``d = (z_est - z_true) / (1 + z_true)`` and
    ignore NaN entries of ``d``.

    Parameters
    ----------
    z_true : array-like
        True values.
    z_est : array-like
        Estimated values, same shape as ``z_true``.
    outlier_threshold : float, optional
        An entry is an outlier when ``|d|`` exceeds this value
        (default 0.15).

    Returns
    -------
    out : dict
        Keys: bias (median of d), nmad (1.48 times the median absolute
        deviation of d about its median), outlier (fraction of non-NaN
        entries with ``|d| > outlier_threshold``; NaN if there are none),
        rmse (root mean square of d).
    """
    z_true = np.asarray(z_true, dtype=float)
    z_est = np.asarray(z_est, dtype=float)
    d = (z_est - z_true) / (1.0 + z_true)

    med = np.nanmedian(d)
    nmad = 1.48 * np.nanmedian(np.abs(d - med))

    # Restrict the outlier fraction to non-NaN entries so that it uses the
    # same sample as the NaN-aware bias/nmad/rmse above and below.
    valid = ~np.isnan(d)
    if valid.any():
        outlier = float(np.mean(np.abs(d[valid]) > outlier_threshold))
    else:
        outlier = float("nan")

    rmse = float(np.sqrt(np.nanmean(d ** 2)))
    return {"bias": float(med), "nmad": float(nmad),
            "outlier": outlier, "rmse": rmse}
[docs]
def plot_chains(table, truth_values=None, output_dir=None, posterior_key="post"):
    """Make trace plots from an astropy Table containing chain results.

    One figure is created for every column whose name contains
    ``"parameters"``: the column values are scattered against the sample
    index and, when ``posterior_key`` is given, coloured by that column.

    Parameters
    ----------
    table : Table
        Astropy Table containing the chain results.
    truth_values : list of float or dict, optional
        True values for the parameters, drawn as red horizontal lines.
        A list is matched to the parameter columns by position (missing
        trailing entries are treated as unknown); a dict is looked up by
        full column name.
    output_dir : str, optional
        Directory to save the output plots (created if needed).
    posterior_key : str, optional
        Column used to colour the points. If None the points are drawn
        without colour coding and no colorbar is added.

    Returns
    -------
    all_figs : list of Figure
        List of all generated figures (left open for the caller).

    Raises
    ------
    ValueError
        If ``posterior_key`` is given but is not a column of ``table``.
    """
    parameters = [par for par in table.colnames if "parameters" in par]

    if posterior_key is not None and posterior_key not in table.colnames:
        raise ValueError(f"Posterior key '{posterior_key}' not found in table columns.")
    if not parameters:
        logger.warning("No parameter columns found in table; nothing to plot.")
        return []

    # One truth value per parameter; NaN means "no truth line".
    if truth_values is None:
        truths = [np.nan] * len(parameters)
    elif isinstance(truth_values, Mapping):
        truths = [truth_values.get(par, np.nan) for par in parameters]
    else:
        truths = list(truth_values)
        # Pad so that zip() below never drops trailing parameters.
        truths += [np.nan] * (len(parameters) - len(truths))

    # TODO: user-configurable posterior key limits
    colors = None
    norm = None
    if posterior_key is not None:
        logger.info("Using posterior key: %s", posterior_key)
        colors = table[posterior_key]
        maxpost = np.nanmax(colors)
        # Colour scale spans at most 3.4 below the maximum of the column.
        vmin = max(np.nanmin(colors), maxpost - 3.4)
        norm = plt.Normalize(vmin=vmin, vmax=maxpost)

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

    x_values = np.arange(len(table[parameters[0]]))
    all_figs = []
    for par, truth in zip(parameters, truths):
        label = par.replace("parameters--", "")
        fig = plt.figure(constrained_layout=True)
        ax = fig.add_subplot(111)
        if colors is not None:
            mappable = ax.scatter(x_values, table[par], s=0.5,
                                  c=colors, cmap="viridis",
                                  norm=norm)
            cbar = fig.colorbar(mappable, ax=ax)
            cbar.set_label("log-posterior")
        else:
            ax.scatter(x_values, table[par], s=0.5)
        ax.set_xlabel("Sample index")
        ax.set_ylabel(label)
        if truth is not None and np.isfinite(truth):
            ax.axhline(truth, c="r")
        if output_dir is not None:
            fig.savefig(
                os.path.join(output_dir, f"chain_plot_{label}.png"),
                dpi=200,
                bbox_inches="tight",
            )
        all_figs.append(fig)
    return all_figs