"""Probability, likelihood, and prior models for grid-based inference."""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple, List
import warnings
import numpy as np
from besta.logging import get_logger
logger = get_logger(__name__)
# Optional numba acceleration.  The rest of the module decorates functions
# with ``@njit(...)`` and loops with ``prange`` unconditionally, so both names
# must exist even when numba is unavailable; otherwise importing this module
# raises NameError and the NUMBA_OK fallbacks can never be reached.
try:
    from numba import njit, prange

    NUMBA_OK = True
except Exception:  # kept broad on purpose: a broken numba install may raise more than ImportError
    NUMBA_OK = False
    logger.warning("numba could not be imported")

    def njit(*args, **kwargs):
        """
        No-op stand-in for ``numba.njit`` used when numba is unavailable.

        Supports both the bare form (``@njit``) and the parametrised form
        (``@njit(parallel=True, ...)``); in either case the decorated
        function is returned unchanged and runs as plain Python.
        """
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]

        def _decorate(func):
            return func

        return _decorate

    # Without numba a parallel range is just an ordinary range.
    prange = range
# ------------------------------- utilities -------------------------------
[docs]
def _logsumexp(a: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
    """
    Numerically stable ``log(sum(exp(a)))``.

    The maximum along the reduction axis is subtracted before
    exponentiating so that large values do not overflow.

    Parameters
    ----------
    a : ndarray
        Input array.
    axis : int or None
        Axis over which to reduce. If None, reduces over all elements.

    Returns
    -------
    out : ndarray
        log(sum(exp(a))) along the given axis. A 0-d array when
        ``axis`` is None. The result is ``-inf`` when every term is
        ``-inf`` and ``+inf`` when any term is ``+inf``.
    """
    shift = np.nanmax(a, axis=axis, keepdims=True)
    # A non-finite shift cannot be subtracted safely: (-inf) - (-inf) and
    # (+inf) - (+inf) are both NaN.  Falling back to a zero shift lets the
    # plain exp/sum/log produce the mathematically correct -inf / +inf.
    shift[~np.isfinite(shift)] = 0.0
    with np.errstate(divide="ignore"):
        # log(0) = -inf is the intended answer when all terms are -inf.
        out = np.log(np.sum(np.exp(a - shift), axis=axis, keepdims=True)) + shift
    if axis is None:
        return out.reshape(())
    return np.squeeze(out, axis=axis)
[docs]
def _normal_logpdf(x: np.ndarray, mu: np.ndarray, sigma: np.ndarray) -> np.ndarray:
    """
    Element-wise log density of a univariate normal distribution.

    Parameters
    ----------
    x, mu, sigma : ndarray
        Mutually broadcastable evaluation points, means and standard
        deviations. ``sigma`` is expected to be positive; it is not
        validated here.

    Returns
    -------
    logp : ndarray
        ``log N(x | mu, sigma^2)`` with the broadcast shape of the inputs.
    """
    variance = np.square(sigma)
    residual = x - mu
    log_normalisation = np.log(2.0 * np.pi * variance)
    return -0.5 * (log_normalisation + residual**2 / variance)
[docs]
def _std_norm_cdf(x: np.ndarray) -> np.ndarray:
    """
    Standard normal cumulative distribution function.

    Uses ``scipy.special.ndtr`` rather than ``0.5 * (1 + erf(x / sqrt(2)))``:
    the erf form subtracts two nearly equal numbers in the lower tail and
    returns exactly 0 for x below roughly -8.3, which is harmful when the
    result is subsequently passed through a logarithm.

    Parameters
    ----------
    x : ndarray
        Points at which to evaluate the CDF.

    Returns
    -------
    cdf : ndarray
        Phi(x), element-wise.
    """
    # Imported locally so that SciPy is only required by code paths that
    # actually evaluate the CDF.
    from scipy.special import ndtr

    return ndtr(x)
# ================================ Priors ================================
[docs]
class Prior(ABC):
    """
    Interface for priors defined over the rows of a model grid.

    Concrete priors map the targets array of the grid to one log prior
    value per model row. An implementation may look at particular target
    columns (for example redshift) and, optionally, at further model
    metadata.

    Methods
    -------
    log_prob_for_models(targets)
        Return the log prior probability of every model row.
    """

    @abstractmethod
    def log_prob_for_models(self, targets: np.ndarray) -> np.ndarray:
        """
        Return the log prior of each model row.

        Parameters
        ----------
        targets : ndarray, shape (N, Q)
            Model targets array from ModelGrid.

        Returns
        -------
        logp : ndarray, shape (N,)
            Log prior probability per model.
        """
        raise NotImplementedError()
[docs]
@dataclass
class FlatPrior(Prior):
    """
    Uninformative prior: every model row receives log probability 0.
    """

    def log_prob_for_models(self, targets: np.ndarray) -> np.ndarray:
        """Return an array of zeros with one entry per row of `targets`."""
        n_models = targets.shape[0]
        return np.zeros(n_models, dtype=float)
[docs]
@dataclass
class GaussianPrior1D(Prior):
    """
    Gaussian prior on one target column.

    Parameters
    ----------
    target_col : int
        Index of the target column the prior acts on (e.g., redshift).
    mu : float
        Mean of the Gaussian.
    sigma : float
        Standard deviation of the Gaussian.
    """

    target_col: int
    mu: float
    sigma: float

    def log_prob_for_models(self, targets: np.ndarray) -> np.ndarray:
        """Return ``log N(t | mu, sigma^2)`` for the selected column of each row."""
        column = targets[:, self.target_col]
        return _normal_logpdf(x=column, mu=self.mu, sigma=self.sigma)
[docs]
@dataclass
class DeltaPrior1D(Prior):
    """
    Narrow top-hat approximation of a delta-function prior on one column.

    Parameters
    ----------
    target_col : int
        Index of the target column the prior acts on (e.g., redshift).
    value : float
        Location of the delta prior.
    tolerance : float, optional
        Half-width of the accepted interval: rows with
        ``|t - value| <= tolerance`` get the finite log density
        ``-log(2 * tolerance)``, all others get ``-inf`` (default 1e-6).
    """

    target_col: int
    value: float
    tolerance: float = 1e-6

    def log_prob_for_models(self, targets: np.ndarray) -> np.ndarray:
        """Return the top-hat log density of the selected column for each row."""
        column = targets[:, self.target_col]
        inside = np.abs(column - self.value) <= self.tolerance
        # Uniform density over an interval of total width 2 * tolerance.
        log_height = -np.log(2.0 * self.tolerance)
        log_density = np.full(column.shape, -np.inf)
        log_density[inside] = log_height
        return log_density
[docs]
@dataclass
class ExponentialPrior1D(Prior):
    """
    Exponential prior on one target column, supported on ``t >= 0``.

    Parameters
    ----------
    target_col : int
        Index of the target column the prior acts on (e.g., redshift).
    scale : float
        Scale parameter (mean) of the exponential distribution.
    """

    target_col: int
    scale: float

    def log_prob_for_models(self, targets: np.ndarray) -> np.ndarray:
        """Return ``-t/scale - log(scale)`` for ``t >= 0`` and ``-inf`` otherwise."""
        column = targets[:, self.target_col]
        supported = column >= 0
        log_density = -column / self.scale - np.log(self.scale)
        return np.where(supported, log_density, -np.inf)
[docs]
@dataclass
class ExponentialTruncatedPrior1D(Prior):
    """
    Exponential prior on one target column, truncated to ``[0, t_max]``.

    Parameters
    ----------
    target_col : int
        Index of the target column the prior acts on (e.g., redshift).
    scale : float
        Scale parameter of the underlying exponential distribution.
    t_max : float
        Upper bound of the support.
    """

    target_col: int
    scale: float
    t_max: float

    def log_prob_for_models(self, targets: np.ndarray) -> np.ndarray:
        """Return the truncated-exponential log density; ``-inf`` outside ``[0, t_max]``."""
        column = targets[:, self.target_col]
        # Probability mass the untruncated exponential places on [0, t_max].
        mass_in_support = 1.0 - np.exp(-self.t_max / self.scale)
        supported = (column >= 0) & (column <= self.t_max)
        log_density = -column / self.scale - np.log(self.scale * mass_in_support)
        return np.where(supported, log_density, -np.inf)
[docs]
@dataclass
class PowerLawPrior1D(Prior):
    """
    1-D power-law prior over a single target column.

    The density is ``p(t) = t**alpha / norm`` on ``[t_min, t_max]`` and zero
    elsewhere, where ``norm`` is the integral of ``t**alpha`` over the
    support.

    Parameters
    ----------
    target_col : int
        Index of the target column to build the prior on (e.g., redshift).
    alpha : float
        Power-law index (p(t) ~ t^alpha). ``alpha == -1`` is supported and
        uses the logarithmic normalisation.
    t_min : float
        Minimum value of the target (support lower bound).
    t_max : float
        Maximum value of the target (support upper bound).
    """

    target_col: int
    alpha: float
    t_min: float
    t_max: float

    def log_prob_for_models(self, targets: np.ndarray) -> np.ndarray:
        """
        Evaluate the power-law log density for each model row.

        Parameters
        ----------
        targets : ndarray, shape (N, Q)
            Model targets array.

        Returns
        -------
        logp : ndarray, shape (N,)
            ``alpha * log(t) - log(norm)`` inside ``[t_min, t_max]`` and
            ``-inf`` outside.
        """
        t = targets[:, self.target_col]
        exponent = self.alpha + 1.0
        if exponent == 0.0:
            # The integral of 1/t is log(t); the generic formula would be 0/0.
            norm = np.log(self.t_max / self.t_min)
        else:
            norm = (self.t_max**exponent - self.t_min**exponent) / exponent
        in_support = (t >= self.t_min) & (t <= self.t_max)
        logp = np.full(t.shape, -np.inf, dtype=float)
        # Take the logarithm only where the density is non-zero so that
        # out-of-support rows (e.g. t <= 0) do not trigger log warnings.
        logp[in_support] = self.alpha * np.log(t[in_support]) - np.log(norm)
        return logp
[docs]
@dataclass
class EmpiricalHistogramPrior1D(Prior):
    """
    Histogram-based empirical prior on one target column.

    The probability mass per bin is estimated from a set of targets
    (typically the model grid itself); every model falling in the same bin
    receives the same log prior.

    Parameters
    ----------
    target_col : int
        Index of the target column the prior acts on (e.g., redshift).
    edges : ndarray, shape (K+1,)
        Histogram bin edges. Values outside the edges are assigned to the
        nearest edge bin when the prior is evaluated.
    density_floor : float, optional
        Lower bound applied to the per-bin probability mass so that empty
        bins do not yield -inf (default 1e-12).
    """

    target_col: int
    edges: np.ndarray
    density_floor: float = 1e-12

    def fit_from_targets(
        self, targets: np.ndarray, weights: Optional[np.ndarray] = None
    ) -> "EmpiricalHistogramPrior1D":
        """
        Estimate the per-bin log mass from `targets`.

        Parameters
        ----------
        targets : ndarray, shape (N, Q)
        weights : ndarray, shape (N,), optional
            Per-row weights forwarded to ``np.histogram``.

        Returns
        -------
        self : EmpiricalHistogramPrior1D
        """
        column = targets[:, self.target_col]
        counts, _ = np.histogram(
            column, bins=self.edges, weights=weights, density=False
        )
        mass = counts.astype(float)
        total = np.sum(mass)
        if total > 0:
            mass = mass / total
        else:
            # Nothing landed in any bin: fall back to a uniform histogram.
            mass = np.full_like(mass, 1.0 / mass.size)
        floored = np.clip(mass, self.density_floor, None)
        self._logp_per_bin = np.log(floored)
        return self

    def log_prob_for_models(self, targets: np.ndarray) -> np.ndarray:
        """Look up the fitted per-bin log mass for the selected column of each row."""
        table = getattr(self, "_logp_per_bin", None)
        if table is None:
            raise RuntimeError("Prior not fitted. Call fit_from_targets first.")
        column = targets[:, self.target_col]
        # digitize is 1-based relative to the bins; clamp out-of-range values
        # into the first/last bin.
        bin_index = np.clip(np.digitize(column, self.edges) - 1, 0, table.size - 1)
        return table[bin_index]
[docs]
@dataclass
class EmpiricalFlatteningPriorND(Prior):
    """
    Empirical prior that counteracts non-uniform sampling of a model grid.

    The occupancy of the grid in a set of target columns is measured with
    histograms, and each model is given prior mass inversely proportional
    to the occupancy of the bin it falls in.

    Two modes are provided:

    - 'factorised': one 1-D histogram per column; the log prior of a model
      is the sum over columns of the per-bin log inverse mass. This
      approximately flattens the *marginal* distributions.
    - 'joint': a single N-D histogram over all selected columns with prior
      mass proportional to 1 / N_k per cell. This flattens the distribution
      over the joint grid of bins, but is more memory hungry and can be
      sparse.

    Parameters
    ----------
    target_cols : sequence of int
        Indices of the target columns to build the prior on.
    edges_list : sequence of ndarray
        Bin edges per column, same length as `target_cols`; element d has
        shape (K_d + 1,).
    mode : {'factorised', 'joint'}, optional
        Flattening strategy (default 'factorised').
    density_floor : float, optional
        Lower bound on the normalised prior mass per joint cell (default
        1e-12). Only used in 'joint' mode.
    count_floor : float, optional
        Lower bound on the effective count per bin, so that empty or
        nearly empty bins do not produce infinite weights (default 1e-3).

    Notes
    -----
    The prior is defined *over models*, not over a continuous parameter
    space. When evaluating (and when fitting in 'joint' mode) values
    outside the edges are clipped into the first/last bin; fitting in
    'factorised' mode uses ``np.histogram``, which does not count such
    values.
    """

    target_cols: Sequence[int]
    edges_list: Sequence[np.ndarray]
    mode: str = "factorised"
    density_floor: float = 1e-12
    count_floor: float = 1e-3

    def __post_init__(self):
        if len(self.target_cols) != len(self.edges_list):
            raise ValueError("target_cols and edges_list must have the same length.")
        if self.mode not in ("factorised", "joint"):
            raise ValueError("mode must be 'factorised' or 'joint'.")

    def _bin_indices(self, values: np.ndarray) -> List[np.ndarray]:
        """Return, for each column of `values`, the clipped bin index of every row."""
        per_dim = []
        for d in range(values.shape[1]):
            edges = self.edges_list[d]
            idx = np.digitize(values[:, d], edges) - 1
            # Clamp under/overflow into the first/last valid bin.
            per_dim.append(np.clip(idx, 0, edges.size - 2))
        return per_dim

    def fit_from_targets(
        self,
        targets: np.ndarray,
        weights: Optional[np.ndarray] = None,
    ) -> "EmpiricalFlatteningPriorND":
        """
        Fit the flattening prior from model grid targets.

        Parameters
        ----------
        targets : ndarray, shape (N, Q)
            Model grid targets.
        weights : ndarray, shape (N,), optional
            Optional model weights (e.g. importance weights) defining the
            effective occupancy per bin.

        Returns
        -------
        self : EmpiricalFlatteningPriorND

        Raises
        ------
        ValueError
            If `weights` does not have one entry per row of `targets`.
        RuntimeError
            If the total (weighted) count is not positive.
        """
        selected = targets[:, self.target_cols]  # (N, D)
        n_dims = selected.shape[1]
        if weights is not None and weights.shape[0] != selected.shape[0]:
            raise ValueError("weights must have shape (N,) if provided.")

        if self.mode == "factorised":
            tables = []
            for d in range(n_dims):
                counts, _ = np.histogram(
                    selected[:, d],
                    bins=self.edges_list[d],
                    weights=weights,
                    density=False,
                )
                counts = counts.astype(float)
                if np.sum(counts) <= 0:
                    raise RuntimeError(
                        f"No models in any bin for dimension {d}; cannot fit prior."
                    )
                # Mass per bin proportional to 1 / count, normalised to 1.
                inverse = 1.0 / np.clip(counts, self.count_floor, None)
                tables.append(np.log(inverse / np.sum(inverse)))
            self._log_inv_mass_list = tables
            return self

        # mode == 'joint': occupancy of the flattened N-D histogram.
        cell_index = np.stack(self._bin_indices(selected), axis=0)  # (D, N)
        shape = tuple(self.edges_list[d].size - 1 for d in range(n_dims))
        flat_index = np.ravel_multi_index(cell_index, dims=shape)  # (N,)
        occupancy = np.bincount(
            flat_index,
            weights=weights,
            minlength=int(np.prod(list(shape))),
        ).astype(float)
        if np.sum(occupancy) <= 0:
            raise RuntimeError("No models in any joint bin; cannot fit prior.")
        occupancy = np.clip(occupancy, self.count_floor, None)
        self.counts_flat = occupancy
        # Mass per cell proportional to 1 / count, normalised then floored.
        inverse = 1.0 / occupancy
        inverse = inverse / np.sum(inverse)
        self._log_mass_per_cell = np.log(np.clip(inverse, self.density_floor, None))
        self._bin_sizes = shape
        return self

    def log_prob_for_models(self, targets: np.ndarray) -> np.ndarray:
        """
        Compute the log prior of each model row.

        Parameters
        ----------
        targets : ndarray, shape (N, Q)
            Model grid targets.

        Returns
        -------
        logp : ndarray, shape (N,)
            Log prior for each target row.

        Raises
        ------
        RuntimeError
            If `fit_from_targets` has not been called for the active mode.
        """
        selected = targets[:, self.target_cols]  # (N, D)
        n_rows, n_dims = selected.shape
        logp = np.zeros(n_rows, dtype=float)

        if self.mode == "factorised":
            if not hasattr(self, "_log_inv_mass_list"):
                raise RuntimeError("Prior not fitted. Call fit_from_targets first.")
            for d, idx in enumerate(self._bin_indices(selected)):
                logp += self._log_inv_mass_list[d][idx]
        elif self.mode == "joint":
            if not hasattr(self, "_log_mass_per_cell"):
                raise RuntimeError("Prior not fitted. Call fit_from_targets first.")
            flat_index = np.ravel_multi_index(
                self._bin_indices(selected), dims=self._bin_sizes
            )
            logp += self._log_mass_per_cell[flat_index]
        return logp
[docs]
class ObservableDependentPrior(Prior):
    """
    Base class for priors whose evaluation also needs model observables.

    Subclasses evaluate ``log_prob_for_models(targets, observables=...)``
    instead of the target-only signature of `Prior`. Code in this module
    uses ``isinstance(prior, ObservableDependentPrior)`` to decide whether
    observables have to be supplied.
    """

    def fit_from_grid(self, *args, **kwargs):
        """
        Fit the prior from the model grid.

        The base implementation accepts (and ignores) any arguments so that
        calling it with the arguments of a concrete implementation raises
        NotImplementedError rather than a TypeError about the call
        signature.

        Raises
        ------
        NotImplementedError
            Always; subclasses that can be fitted must override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement fit_from_grid."
        )
[docs]
@dataclass
class MagDependentRedshiftPrior(ObservableDependentPrior):
    """
    Conditional redshift prior p(z | m) tabulated on a 2-D histogram.

    Parameters
    ----------
    z_col : int
        Index of redshift in targets.
    mag_observable_index : int
        Index of magnitude in observables (e.g., VIS magnitude column).
    z_edges : ndarray
        Bin edges in redshift.
    m_edges : ndarray
        Bin edges in magnitude.
    density_floor : float, optional
        Lower bound applied to the conditional probability of each
        (z bin, m bin) cell to avoid zeros.

    Notes
    -----
    `fit_from_grid` stores ``log p(z_k | m_bin)`` as a table of shape
    (Kz, Km). When evaluating, the magnitude of a model selects the column
    of the table and its redshift selects the row.
    """

    z_col: int
    mag_observable_index: int
    z_edges: np.ndarray
    m_edges: np.ndarray
    density_floor: float = 1e-12

    def fit_from_grid(
        self,
        observables: np.ndarray,
        targets: np.ndarray,
        weights: Optional[np.ndarray] = None,
    ) -> "MagDependentRedshiftPrior":
        """
        Build the conditional histogram from the model grid.

        Parameters
        ----------
        observables : ndarray, shape (N, P)
        targets : ndarray, shape (N, Q)
        weights : ndarray, shape (N,), optional
            Per-row weights forwarded to ``np.histogram2d``.

        Returns
        -------
        self : MagDependentRedshiftPrior
        """
        redshift = targets[:, self.z_col]
        magnitude = observables[:, self.mag_observable_index]
        joint = np.histogram2d(
            redshift, magnitude, bins=[self.z_edges, self.m_edges], weights=weights
        )[0]
        # Normalise every magnitude column so that it sums to 1 over z;
        # empty columns are divided by 1 and end up at the density floor.
        per_mag_total = joint.sum(axis=0, keepdims=True)
        per_mag_total[per_mag_total == 0] = 1.0
        conditional = np.clip(joint / per_mag_total, self.density_floor, None)
        self._logP_z_given_m = np.log(conditional)  # shape (Kz, Km)
        return self

    def log_prob_for_models(
        self, targets: np.ndarray, observables: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """
        Evaluate log p(z | m) for each model row.

        Parameters
        ----------
        targets : ndarray, shape (N, Q)
        observables : ndarray, shape (N, P), optional
            Supplies the magnitude column; must not be None.

        Returns
        -------
        logp : ndarray, shape (N,)

        Raises
        ------
        RuntimeError
            If the prior has not been fitted.
        ValueError
            If `observables` is None.
        """
        table = getattr(self, "_logP_z_given_m", None)
        if table is None:
            raise RuntimeError("Prior not fitted. Call fit_from_grid first.")
        if observables is None:
            raise ValueError("observables must be provided to evaluate p(z|m)")
        n_z_bins, n_m_bins = table.shape
        redshift = targets[:, self.z_col]
        magnitude = observables[:, self.mag_observable_index]
        # Out-of-range values are clamped into the first/last bin.
        row = np.clip(np.digitize(redshift, self.z_edges) - 1, 0, n_z_bins - 1)
        col = np.clip(np.digitize(magnitude, self.m_edges) - 1, 0, n_m_bins - 1)
        return table[row, col]
[docs]
class HierarchicalPrior(Prior):
    """
    Abstract base class for priors driven by learnable hyperparameters.

    The hyperparameters are kept in the ``hyperparams`` dict; the dict
    passed to the constructor is stored as-is (not copied).
    """

    def __init__(self, hyperparams: dict):
        self.hyperparams = hyperparams

    def update_hyperparams(self, new_values: dict) -> None:
        """Merge `new_values` into the stored hyperparameters in place."""
        current = self.hyperparams
        current.update(new_values)

    @abstractmethod
    def log_prob_for_models(self, targets: np.ndarray, **kwargs) -> np.ndarray:
        """Return the log prior per model row; to be implemented by subclasses."""

    @abstractmethod
    def fit_from_data(
        self,
        targets: np.ndarray,
        observables: np.ndarray,
        weights: Optional[np.ndarray] = None,
    ) -> None:
        """Fit the hyperparameters from data; to be implemented by subclasses."""
[docs]
@dataclass
class CompositePrior(Prior):
    """
    Weighted combination of several target-only priors.

    The composite log prior is

        log p_total(model) = sum_i w_i * log p_i(model)

    where none of the components p_i may depend on observables.

    Parameters
    ----------
    priors : sequence of Prior
        Components to combine. Each must implement
        `log_prob_for_models(targets)`; instances of
        ObservableDependentPrior are rejected.
    weights : sequence of float, optional
        Multiplicative weight per component, applied to the
        log-probabilities. None means a weight of 1.0 for every component.
        Components with weight zero are skipped. Negative weights are
        allowed but correspond to dividing by p_i(model) in linear space,
        so use them with care.
    """

    priors: Sequence[Prior]
    weights: Optional[Sequence[float]] = None

    def __post_init__(self):
        self.priors = list(self.priors)
        if len(self.priors) == 0:
            raise ValueError("CompositePrior requires at least one component prior.")
        for position, component in enumerate(self.priors):
            if isinstance(component, ObservableDependentPrior):
                raise TypeError(
                    f"CompositePrior component {position} is observable-dependent "
                    "(ObservableDependentPrior); use ObservableCompositePrior instead."
                )
        if self.weights is None:
            self._weights = None
            return
        if len(self.weights) != len(self.priors):
            raise ValueError(
                "weights must have the same length as priors "
                f"({len(self.weights)} != {len(self.priors)})"
            )
        self._weights = np.asarray(self.weights, dtype=float)

    def log_prob_for_models(self, targets: np.ndarray) -> np.ndarray:
        """
        Evaluate the composite log prior of each model row.

        Parameters
        ----------
        targets : ndarray, shape (N, Q)
            Model targets array.

        Returns
        -------
        logp : ndarray, shape (N,)
            Composite log prior per model. If every weight is zero a
            warning is issued and zeros are returned.

        Raises
        ------
        ValueError
            If a component returns an array whose shape is not (N,).
        """
        n_models = targets.shape[0]
        accumulated: Optional[np.ndarray] = None
        for position, component in enumerate(self.priors):
            if self._weights is None:
                weight = 1.0
            else:
                weight = float(self._weights[position])
            if weight == 0.0:
                # Disabled component: not even evaluated.
                continue
            term = component.log_prob_for_models(targets)
            if term.shape != (n_models,):
                raise ValueError(
                    f"Component prior {position} returned shape {term.shape}, expected ({n_models},)"
                )
            if weight != 1.0:
                term = weight * term
            if accumulated is None:
                accumulated = term
            else:
                accumulated = accumulated + term
        if accumulated is not None:
            return accumulated
        warnings.warn(
            "All weights were zero in CompositePrior; returning flat prior."
        )
        return np.zeros(n_models, dtype=float)
[docs]
@dataclass
class ObservableCompositePrior(ObservableDependentPrior):
    """
    Weighted combination of priors that may depend on observables.

    The composite log prior is

        log p_total(model) = sum_i w_i * log p_i(model)

    Components are evaluated according to their type:

    - plain Prior (target-only):
          p_i.log_prob_for_models(targets)
    - ObservableDependentPrior:
          p_i.log_prob_for_models(targets, observables=...)

    Parameters
    ----------
    priors : sequence of Prior
        Components to combine, each following the Prior or the
        ObservableDependentPrior API.
    weights : sequence of float, optional
        Multiplicative weight per component, applied to the
        log-probabilities. None means a weight of 1.0 for every component.
        Components with weight zero are skipped. Negative weights are
        allowed but correspond to dividing by p_i(model) in linear space,
        so use them with care.

    Notes
    -----
    - Because this class inherits from ObservableDependentPrior, external
      code (e.g. GridFitter / posterior_over_models) should treat it as
      requiring observables and pass them when available.
    - If the composite contains any ObservableDependentPrior and
      `observables` is None, a ValueError is raised.
    """

    priors: Sequence[Prior]
    weights: Optional[Sequence[float]] = None

    def __post_init__(self):
        self.priors = list(self.priors)
        if len(self.priors) == 0:
            raise ValueError(
                "ObservableCompositePrior requires at least one component prior."
            )
        if self.weights is None:
            self._weights = None
            return
        if len(self.weights) != len(self.priors):
            raise ValueError(
                "weights must have the same length as priors "
                f"({len(self.weights)} != {len(self.priors)})"
            )
        self._weights = np.asarray(self.weights, dtype=float)

    def log_prob_for_models(
        self,
        targets: np.ndarray,
        observables: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Evaluate the composite log prior of each model row.

        Parameters
        ----------
        targets : ndarray, shape (N, Q)
            Model targets array.
        observables : ndarray, shape (N, P), optional
            Model observables. Required whenever any component is an
            ObservableDependentPrior.

        Returns
        -------
        logp : ndarray, shape (N,)
            Composite log prior per model; zeros if every weight is zero.

        Raises
        ------
        ValueError
            If an observable-dependent component is present but
            `observables` is None, or if a component returns an array
            whose shape is not (N,).
        """
        n_models = targets.shape[0]
        wants_observables = [
            isinstance(component, ObservableDependentPrior)
            for component in self.priors
        ]
        if observables is None and any(wants_observables):
            raise ValueError(
                "ObservableCompositePrior includes observable-dependent components "
                "but 'observables' was not provided."
            )

        accumulated: Optional[np.ndarray] = None
        for position, component in enumerate(self.priors):
            if self._weights is None:
                weight = 1.0
            else:
                weight = float(self._weights[position])
            if weight == 0.0:
                # Disabled component: not even evaluated.
                continue
            if wants_observables[position]:
                term = component.log_prob_for_models(targets, observables=observables)
            else:
                term = component.log_prob_for_models(targets)
            if term.shape != (n_models,):
                raise ValueError(
                    f"Component prior {position} returned shape {term.shape}, expected ({n_models},)"
                )
            if weight != 1.0:
                term = weight * term
            if accumulated is None:
                accumulated = term
            else:
                accumulated = accumulated + term

        if accumulated is None:
            # Every weight was zero: behave like a flat prior.
            return np.zeros(n_models, dtype=float)
        return accumulated
# ============================== Likelihoods ==============================
[docs]
class Likelihood(ABC):
    """
    Interface for likelihoods p(x | model) evaluated against candidate models.

    Methods
    -------
    log_likelihood(x_native, sigma_native, X_models)
        Return the log likelihood of the query for every candidate model row.
    """

    @abstractmethod
    def log_likelihood(
        self, x_native: np.ndarray, sigma_native: np.ndarray, X_models: np.ndarray
    ) -> np.ndarray:
        """
        Return the log likelihood of the query under each model.

        Parameters
        ----------
        x_native : ndarray, shape (P,)
            Query observables.
        sigma_native : ndarray, shape (P,)
            Measurement uncertainties for the query.
        X_models : ndarray, shape (Nc, P)
            Candidate model observables.

        Returns
        -------
        logL : ndarray, shape (Nc,)
        """
        raise NotImplementedError()
[docs]
@dataclass
class GaussianProductLikelihood(Likelihood):
    """
    Product of independent Gaussians, one per observable dimension.

    The likelihood of model i is the product over j of
    N(x_j | X_ij, h_j^2), with bandwidth
    ``h_j = max(scale * sigma_native_j, bandwidth_floor)``.

    Parameters
    ----------
    bandwidth_floor : float, optional
        Minimum bandwidth per dimension in native units. Default 0.0.
    scale : float, optional
        Multiplicative scale applied to sigma_native. Default 1.0.
    """

    bandwidth_floor: float = 0.0
    scale: float = 1.0

    def log_likelihood(
        self, x_native: np.ndarray, sigma_native: np.ndarray, X_models: np.ndarray
    ) -> np.ndarray:
        """Return the summed per-dimension Gaussian log densities, shape (Nc,)."""
        bandwidth = np.maximum(self.scale * sigma_native, self.bandwidth_floor)
        variance = bandwidth[None, :] ** 2  # (1, P), broadcasts over models
        residual = X_models - x_native[None, :]  # (Nc, P)
        log_normalisation = np.sum(np.log(2.0 * np.pi * variance), axis=1)
        chi_square = np.sum(residual**2 / variance, axis=1)
        return -0.5 * (log_normalisation + chi_square)
[docs]
@dataclass
class CensoredSizeLikelihood(Likelihood):
    """
    Photometry-only Gaussian product with a left-censored size factor.

    This is useful when the apparent size is below a reliability floor
    (e.g., PSF or measurement threshold). The photometric part is a
    Gaussian product over selected photometry indices. The size part
    adds a log CDF factor log Phi((s_min - s_model) / h_s), where h_s is
    derived from sigma_native[size_index].

    Parameters
    ----------
    phot_indices : Sequence[int]
        Indices of observable columns to include in the Gaussian product
        (typically colours and anchor magnitude).
    size_index : int
        Index of the size observable column (e.g., log10(Re)).
    s_min : float
        Left-censoring threshold in the same units as the size observable.
    bandwidth_floor : float, optional
        Minimum bandwidth per dimension in native units. Default 0.0.
    scale : float, optional
        Multiplicative scale applied to sigma_native. Default 1.0.
    """

    phot_indices: Sequence[int]
    size_index: int
    s_min: float
    bandwidth_floor: float = 0.0
    scale: float = 1.0

    def log_likelihood(
        self, x_native: np.ndarray, sigma_native: np.ndarray, X_models: np.ndarray
    ) -> np.ndarray:
        """
        Evaluate the photometric Gaussian product plus the censored size term.

        Parameters
        ----------
        x_native : ndarray, shape (P,)
            Query observables.
        sigma_native : ndarray, shape (P,)
            Measurement uncertainties for the query.
        X_models : ndarray, shape (Nc, P)
            Candidate model observables.

        Returns
        -------
        logL : ndarray, shape (Nc,)
            Photometric log likelihood plus ``log Phi((s_min - s_model) / h_s)``,
            the latter bounded below by ``log(1e-300)``.
        """
        # Local import, mirroring how this module pulls in scipy.special
        # only where it is needed.
        from scipy.special import log_ndtr

        # Photometry part: independent Gaussians over the selected columns.
        phot_idx = np.asarray(self.phot_indices, dtype=int)
        x_ph = x_native[phot_idx]
        sig_ph = np.maximum(self.scale * sigma_native[phot_idx], self.bandwidth_floor)
        diff = X_models[:, phot_idx] - x_ph[None, :]
        var = sig_ph[None, :] ** 2
        logL_ph = -0.5 * (
            np.sum(np.log(2.0 * np.pi * var), axis=1) + np.sum(diff**2 / var, axis=1)
        )

        # Censored size factor: log Phi((s_min - s_model) / h_s).
        h_s = max(self.scale * sigma_native[self.size_index], self.bandwidth_floor)
        s_model = X_models[:, self.size_index]
        z = (self.s_min - s_model) / h_s
        # Evaluate the log CDF directly.  Going through an erf-based CDF
        # underflows to 0 near z ~ -8.3, which made the clipped log jump
        # from about -37 straight to the floor.  The floor itself is kept so
        # the term stays finite for arbitrarily oversized models.
        log_cdf_floor = np.log(1e-300)
        logL_sz = np.maximum(log_ndtr(z), log_cdf_floor)
        return logL_ph + logL_sz
[docs]
@dataclass
class CompositeLikelihood(Likelihood):
    """
    Combination of several likelihood terms whose log-likelihoods are added.

    Parameters
    ----------
    terms : list of Likelihood
        Components to add together.
    """

    terms: List[Likelihood]

    def log_likelihood(
        self, x_native: np.ndarray, sigma_native: np.ndarray, X_models: np.ndarray
    ) -> np.ndarray:
        """Return the sum of all component log-likelihoods (zeros if there are none)."""
        contributions = [
            term.log_likelihood(x_native, sigma_native, X_models)
            for term in self.terms
        ]
        if not contributions:
            return np.zeros(X_models.shape[0], dtype=float)
        combined = contributions[0]
        for extra in contributions[1:]:
            combined = combined + extra
        return combined
# =========================== posterior helper ===========================
[docs]
def posterior_over_models(
    x_native: np.ndarray,
    sigma_native: np.ndarray,
    X_models: np.ndarray,
    targets_models: np.ndarray,
    likelihood: Likelihood,
    prior: Prior,
    model_weights: Optional[np.ndarray] = None,
    prior_needs_observables: bool = False,
    observables_models: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Compute normalised posterior weights over models.

    The unnormalised log weight of a model is the sum of its log
    likelihood, its log prior and (optionally) the log of its sampling
    weight; the weights are then normalised to sum to one.

    Parameters
    ----------
    x_native : ndarray, shape (P,)
        Query observables.
    sigma_native : ndarray, shape (P,)
        Measurement uncertainties for the query.
    X_models : ndarray, shape (Nc, P)
        Candidate model observables.
    targets_models : ndarray, shape (Nc, Q)
        Candidate model targets.
    likelihood : Likelihood
        Likelihood instance to evaluate log p(x | model).
    prior : Prior
        Prior instance to evaluate log p(model).
    model_weights : ndarray, shape (Nc,), optional
        Optional sampling weights for models; multiplies the posterior.
        Values below 1e-300 are clipped to 1e-300 before the logarithm.
    prior_needs_observables : bool, optional
        If True, the prior is evaluated with observables (for p(z|m)).
        Priors that are instances of ObservableDependentPrior are given
        observables regardless of this flag.
    observables_models : ndarray, shape (Nc, P), optional
        Candidate observables to pass to the prior if needed. Defaults to
        `X_models` when the prior needs observables and this is None.

    Returns
    -------
    w : ndarray, shape (Nc,)
        Normalised posterior weights over candidate models.
    """
    logL = likelihood.log_likelihood(x_native, sigma_native, X_models)

    # Observable-dependent priors advertise themselves through their base
    # class, so the caller does not have to remember to set the flag.
    needs_observables = prior_needs_observables or isinstance(
        prior, ObservableDependentPrior
    )
    if needs_observables:
        prior_observables = (
            observables_models if observables_models is not None else X_models
        )
        logP = prior.log_prob_for_models(targets_models, observables=prior_observables)
    else:
        logP = prior.log_prob_for_models(targets_models)

    logw = logL + logP
    if model_weights is not None:
        # Multiplying in linear space is adding in log space; the clip keeps
        # zero or negative weights away from log(0).
        logw = logw + np.log(np.clip(model_weights, 1e-300, np.inf))

    # Normalise in log space to avoid under/overflow.
    logZ = _logsumexp(logw)
    return np.exp(logw - logZ)
# Numba-dedicated likelihood
@njit(parallel=True, fastmath=True, cache=True)
def _quadform_diag_parallel(X, x, h):
    """
    Per-row sum of squared, bandwidth-scaled residuals.

    For every row i of the 2-D array `X` this computes
    ``sum_j ((X[i, j] - x[j]) / h[j]) ** 2``.

    Parameters
    ----------
    X : ndarray, shape (N, P)
        Rows to evaluate.
    x : ndarray
        Reference vector, indexed by column j of `X`.
    h : ndarray
        Per-column scale, indexed by column j of `X`; its reciprocal is
        taken once up front.

    Returns
    -------
    out : ndarray, shape (N,), float64
        Squared distance of each row from `x` in units of `h`.
    """
    N, P = X.shape
    out = np.empty(N, dtype=np.float64)
    # Reciprocal computed once so the inner loop multiplies instead of divides.
    invh = 1.0 / h
    # Rows are independent; prange iterates over them.
    for i in prange(N):
        s = 0.0
        Xi = X[i]
        # unrolled-style simple loop lets numba vectorise well
        for j in range(P):
            d = (Xi[j] - x[j]) * invh[j]
            s += d * d
        out[i] = s
    return out  # squared Mahalanobis with diagonal covariance
@njit(parallel=True, fastmath=True, cache=True)
def _loglike_gaussprod_diag(X, x, h):
    """
    Unnormalised diagonal-Gaussian log likelihood per row of `X`.

    Returns ``-0.5 * sum_j ((X_ij - x_j) / h_j) ** 2``; additive constants
    that do not depend on the row are dropped.
    """
    return -0.5 * _quadform_diag_parallel(X, x, h)
[docs]
class NumbaGaussianProductLikelihood(GaussianProductLikelihood):
    """
    Diagonal Gaussian product likelihood accelerated with Numba.

    log L_i = -0.5 * sum_j ((X_ij - x_j)/h_j)^2

    with ``h_j = max(scale * sigma_native_j, bandwidth_floor)``. Unlike the
    parent class, the normalisation term that does not depend on the model
    is omitted. The same quantity is returned whether or not numba is
    available; without numba it is computed with vectorised NumPy.

    Parameters
    ----------
    bandwidth_floor : float, optional
        Minimum bandwidth per dimension in native units. Default 0.0.
    scale : float, optional
        Multiplicative scale applied to sigma_native. Default 1.0.
    prefer_batch : bool, optional
        Stored on the instance; not read by this class.
        NOTE(review): presumably consumed by calling code -- confirm.
    """

    def __init__(
        self,
        bandwidth_floor: float = 0.0,
        scale: float = 1.0,
        prefer_batch: bool = False,
    ):
        super().__init__(bandwidth_floor=bandwidth_floor, scale=scale)
        self.prefer_batch = prefer_batch

    def log_likelihood(self, x_native, sigma_native, X_models):
        """
        Evaluate the unnormalised log likelihood for each model.

        Parameters
        ----------
        x_native : ndarray, shape (P,)
            Query observables.
        sigma_native : ndarray, shape (P,)
            Measurement uncertainties for the query.
        X_models : ndarray, shape (Nc, P)
            Candidate model observables.

        Returns
        -------
        logL : ndarray, shape (Nc,)
        """
        # The compiled kernel expects C-contiguous float64 inputs.
        x = np.ascontiguousarray(x_native, dtype=np.float64)
        X = np.ascontiguousarray(X_models, dtype=np.float64)
        sigma = np.ascontiguousarray(sigma_native, dtype=np.float64)
        # Same bandwidth definition as the parent class.
        h = np.ascontiguousarray(
            np.maximum(self.scale * sigma, self.bandwidth_floor), dtype=np.float64
        )
        if not NUMBA_OK:
            # Pure NumPy evaluation of the same expression.  Delegating to
            # the parent here would add its normalisation term and make the
            # result depend on whether numba happens to be installed.
            z = (X - x[None, :]) / h[None, :]
            return -0.5 * np.sum(z * z, axis=1)
        return _loglike_gaussprod_diag(X, x, h)