Source code for besta.pipeline_modules.base_module

"""
Base pipeline module. This module contains the base class for creating new
pipeline modules in BESTA.
"""
from __future__ import annotations

from abc import abstractmethod
import os
import pickle
import sys

import importlib.util
import pathlib
import types
from typing import Callable, Optional

from matplotlib import pyplot as plt
from besta.visualization import draw_dict_in_axes

import numpy as np
from scipy.stats import norm
from sklearn.decomposition import NMF
from astropy import units as u
from astropy.io import ascii

from cosmosis import ClassModule
from cosmosis import DataBlock
from cosmosis.datablock import SectionOptions, option_section

from pst.utils import flux_conserving_interpolation
from pst.observables import Filter, FilterList
from pst import SSP, dust, sed
from pst.galaxy import GalaxySED

from besta import spectrum
from besta import kinematics
from besta import sfh
from besta import io
from besta import utils
from besta.grid import ModelGrid
from besta.config import cosmology, memory
from besta.logging import get_logger, setup_logging

logger = get_logger(__name__)


def _log(*args):
    """Emit an INFO record built from ``str()`` of each argument, space-separated."""
    message = " ".join(map(str, args))
    logger.info(message)

class BaseModule(ClassModule):
    """BESTA Pipeline module base class.

    Concrete modules implement :meth:`make_observable`, :meth:`execute` and
    :meth:`plot_solution`. The ``prepare_*`` helpers read setup options and
    store the resulting objects in the :attr:`config` dictionary.
    """

    # Lower-case string spellings that ``_parse_bool`` interprets as True.
    _TRUTHY_STRINGS = frozenset({"t", "true", "y", "yes", "1", "on"})

    def __init__(self, options, *, alias=None):
        """
        Set up the CosmoSIS module.

        Parameters
        ----------
        options : dict or DataBlock
            Options from the startup configuration.
        alias : str, optional
            Module alias used to resolve section names. Defaults to
            ``self.name``.

        Notes
        -----
        Logging is configured from ``logging_level`` (default ``"INFO"``),
        ``logging_file`` (``"none"`` disables file output),
        ``logging_overwrite`` and ``logging_console`` (booleans or bool-like
        strings, default ``False``). The likelihood name is read from
        ``like_name``, with a ``"_like"`` suffix appended when missing. If
        ``like_name`` is absent, it defaults to ``self.name + "_like"``.
        """
        self.alias = self.name if alias is None else alias
        options = self.parse_options(options)

        logging_level = options.get_string("logging_level", default="INFO")
        logging_file = options.get_string("logging_file", default="none")
        # Accept either real booleans or bool-like strings
        logging_overwrite = (
            self._parse_bool(options["logging_overwrite"])
            if options.has_value("logging_overwrite") else False)
        logging_console = (
            self._parse_bool(options["logging_console"])
            if options.has_value("logging_console") else False)
        setup_logging(
            level=logging_level,
            log_file=None if logging_file.lower() == "none" else logging_file,
            overwrite=logging_overwrite,
            console=logging_console)

        self.config = {}
        # Likelihood name (must end with "_like")
        if options.has_value("like_name"):
            self.like_name = options["like_name"]
            if not self.like_name.endswith("_like"):
                self.like_name += "_like"
        else:
            self.like_name = self.name + "_like"
            _log("Setting module likelihood name to default: ", self.like_name)

    @classmethod
    def _parse_bool(cls, value):
        """Interpret a setup option value as a boolean.

        Strings are stripped and compared case-insensitively against
        ``_TRUTHY_STRINGS``. Any other string is ``False``. Non-string
        values are converted with :func:`bool`.
        """
        if isinstance(value, str):
            return value.strip().lower() in cls._TRUTHY_STRINGS
        return bool(value)

    @abstractmethod
    def make_observable(self, *args, **kwargs):
        """Create an observable from an input set of model parameters."""

    @abstractmethod
    def execute(self, block: DataBlock, *args, **kwargs):
        """Execute the pipeline."""

    @abstractmethod
    def plot_solution(self, *args, **kwargs):
        """Plot the fit results."""
        pass

    @classmethod
    def get_path(cls):
        """Get the path to the module file."""
        return sys.modules[cls.__module__].__file__

    def parse_options(self, options: dict | DataBlock):
        """Parse the input setup options.

        Convert the input options into a :class:`SectionOptions`. Every
        option stored under the section named ``self.alias`` is copied into
        ``option_section``.

        Parameters
        ----------
        options : dict or :class:`DataBlock`
            Module setup options. A dict is converted with
            :meth:`DataBlock.from_dict`.

        Returns
        -------
        options : :class:`SectionOptions`

        Raises
        ------
        ValueError
            If no options are found for ``self.alias``.
        """
        if isinstance(options, dict):
            logger.debug("Parsing input options from dict")
            logger.debug("Input options: %s", options)
            options = DataBlock.from_dict(options)
        # Rebuild option_section from scratch from the alias section
        if options.has_section(option_section):
            options._delete_section(option_section)
        keys = options.keys(self.alias)
        if not keys:
            raise ValueError(f"No options found for module alias {self.alias}")
        for section, name in keys:
            options[option_section, name] = options[section, name]
        return SectionOptions(options)

    def prepare_ssp_model(self, options, normalize=False, velocity_buffer=800.0):
        """Prepare the SSP data.

        Parameters
        ----------
        options : :class:`DataBlock`
            Input options to initialise the model. ``SSPModelFromPickle``
            loads a preconfigured model. Otherwise ``SSPModel`` names a class
            in :mod:`pst.SSP`, with optional ``SSPDir`` and ``SSPModelArgs``.
            When ``velscale`` is given, the SEDs are log-rebinned (and
            convolved with the LSF if one was loaded). ``SaveSSPModel``
            pickles the result.
        normalize : bool, optional
            Not used by this implementation; kept for backwards
            compatibility.
        velocity_buffer : float
            Buffer offset (in km/s) that sets how many extra wavelength
            elements are kept. This reduces the corruption of the spectra at
            the edges during convolution. The buffer is applied to both sides
            of the SSP spectra.
        """
        _log("\n-> Configuring SSP model")
        if options.has_value("SSPModelFromPickle"):
            _log("\n-> Loading preconfigured SSP model from pickle")
            pickle_path = os.path.expandvars(options["SSPModelFromPickle"])
            if not os.path.isfile(pickle_path):
                raise FileNotFoundError(
                    f"Input pickle file {options['SSPModelFromPickle']} not found")
            # NOTE: unpickling can execute arbitrary code; only load trusted files.
            ssp = SSP.SSPBase.from_pickle(pickle_path)
            self.config["ssp_model"] = ssp
            self.config["ssp_sed"] = ssp.L_lambda.value.reshape(
                (ssp.L_lambda.shape[0] * ssp.L_lambda.shape[1],
                 ssp.L_lambda.shape[2]))
            self.config["ssp_wl"] = ssp.wavelength.to_value("Angstrom")
            # Grid parameters (only meaningful for spectroscopic fits)
            if options.has_value("velscale"):
                velscale = options["velscale"]
                self.config["velscale"] = velscale
                # Same rounding as the non-pickle path used to build the model
                self.config["extra_pixels"] = int(np.ceil(velocity_buffer / velscale))
            _log("-> Configuration done.")
            return

        ssp_name = options["SSPModel"]
        ssp_dir = None
        if options.has_value("SSPDir"):
            ssp_dir = options["SSPDir"]
            # Exact match: a substring test would null paths containing "none"
            if ssp_dir.strip().lower() == "none":
                ssp_dir = None
            else:
                ssp_dir = os.path.expandvars(ssp_dir)

        ssp_args, ssp_kwargs = self._parse_ssp_model_args(options)
        ssp_kwargs["path"] = ssp_dir
        ssp = getattr(SSP, ssp_name)(*ssp_args, **ssp_kwargs)

        # For photometric analyses stop here
        # TODO: loose check
        if not options.has_value("velscale"):
            self.config["ssp_model"] = ssp
            return

        # Parameters to format the templates to the input spectra
        velscale = options["velscale"]
        dlnlam = velscale / spectrum.constants.c.to("km/s").value
        extra_offset_pixel = int(np.ceil(velocity_buffer / velscale))
        if "ln_wave" in self.config:
            ln_wl_edges = self.config["ln_wave"][[0, -1]]
        else:
            # No observed grid: keep the native SSP range, no edge buffer
            ln_wl_edges = np.log(ssp.wavelength[[0, -1]].to_value("angstrom"))
            extra_offset_pixel = 0
        _log(
            "Log-binning SSP spectra to velocity scale: ", velscale,
            " km/s per pixel",
            f"Keeping {extra_offset_pixel} extra pixels at both edges",
        )
        lnlam_bin_edges = np.arange(
            ln_wl_edges[0] - 0.5 * dlnlam - dlnlam * extra_offset_pixel,
            ln_wl_edges[-1] + dlnlam * (1 + extra_offset_pixel),
            dlnlam,
        )
        lnlam_bins = (lnlam_bin_edges[:-1] + lnlam_bin_edges[1:]) / 2
        # Resample the SED
        ssp.interpolate_sed(np.exp(lnlam_bins), method="binfrac")
        _log("SSP Model SED dimensions (met, age, lambda): ", ssp.L_lambda.shape)

        # Convolve with instrumental LSF
        if "lsf" in self.config:
            self._convolve_ssp_with_lsf(ssp, options, lnlam_bin_edges)

        self.config["ssp_model"] = ssp
        self.config["ssp_wl"] = ssp.wavelength.to_value("Angstrom")
        # Grid parameters
        self.config["velscale"] = velscale
        self.config["extra_pixels"] = extra_offset_pixel

        if options.has_value("SaveSSPModel"):
            _log("Saving SSP model to ", options["SaveSSPModel"])
            ssp.to_pickle(os.path.expandvars(options["SaveSSPModel"]))
        _log("-> Configuration done.")
        return

    @staticmethod
    def _parse_ssp_model_args(options):
        """Return ``(args, kwargs)`` parsed from the ``SSPModelArgs`` option.

        ``SSPModelArgs`` may be a comma-separated string or a list. String
        items of the form ``key=value`` become keyword arguments, and other
        string items become positional arguments. Both are parsed with
        ``io._parse_value``. Non-string list items are passed through as
        positional arguments unchanged.
        """
        ssp_args = []
        ssp_kwargs = {}
        if not options.has_value("SSPModelArgs"):
            return ssp_args, ssp_kwargs
        ssp_all_args = options["SSPModelArgs"]
        if isinstance(ssp_all_args, str):
            ssp_all_args = ssp_all_args.split(",")
        for arg in ssp_all_args:
            if not isinstance(arg, str):
                ssp_args.append(arg)
            elif "=" in arg:
                key, value = arg.split("=", 1)
                ssp_kwargs[key.strip()] = io._parse_value(value)
            else:
                ssp_args.append(io._parse_value(arg))
        _log("SSP Model extra arguments: ", ssp_args, ssp_kwargs)
        return ssp_args, ssp_kwargs

    def _convolve_ssp_with_lsf(self, ssp, options, lnlam_bin_edges):
        """Convolve the log-binned SSP SEDs, in place, with the effective LSF.

        Both the instrumental LSF (``config["lsf"]``) and the optional SSP
        resolution (``SSPLSF`` file) are treated as Gaussian FWHM, converted
        to sigma with a factor of 2.355, and subtracted in quadrature. The
        result is expressed in pixels of the ``lnlam_bin_edges`` grid.

        Raises
        ------
        ValueError
            If the SSP resolution is coarser than the instrumental one.
        """
        _log("Convolving SSP model with instrumental LSF")
        inst_lsf = np.interp(ssp.wavelength, self.config["wavelength"],
                             self.config["lsf"])
        if options.has_value("SSPLSF"):
            _log("Including SSP resolution")
            ssp_lsf_wl, ssp_lsf_fwhm = np.loadtxt(
                os.path.expandvars(options["SSPLSF"]), unpack=True,
                usecols=(0, 1))
            ssp_lsf_fwhm = np.interp(ssp.wavelength, ssp_lsf_wl << u.AA,
                                     ssp_lsf_fwhm)
        else:
            ssp_lsf_fwhm = np.zeros(ssp.wavelength.size, dtype=float)
        # Assume both LSF are Gaussian
        effective_lsf_disp = (inst_lsf / 2.355)**2 - (ssp_lsf_fwhm / 2.355)**2
        if (effective_lsf_disp < 0).any():
            raise ValueError("Effective SSP LSF cannot be negative! "
                             + "SSP models do not have enough resolution")
        effective_lsf = np.sqrt(effective_lsf_disp)
        # Convert to pixels
        lsf_sigma_pixels = effective_lsf / np.diff(np.exp(lnlam_bin_edges))
        _log("Starting convolution of SSP models with wavelength-dependent",
             f"LSF [min sigma={lsf_sigma_pixels.min():.2},"
             f" max sigma={lsf_sigma_pixels.max():.2} pix]")

        n_met, n_age, n_wl = ssp.L_lambda.shape[:3]
        dtype = ssp.L_lambda[0, 0, 0].dtype
        safety_margin = memory["ram_safety_margin"]
        try:
            utils.check_array_memory((n_met, n_age, n_wl, n_wl), dtype=dtype,
                                     unit='GB', safety_margin=safety_margin)
            _log("Convolving full SSP model at once")
            ssp.L_lambda = kinematics.convolve_variable_gaussian_kernel(
                ssp.L_lambda, lsf_sigma_pixels)
        except MemoryError:
            # Do a loop along metallicity axis to prevent memory overflows
            _log("Insufficient RAM memory for full SSP SED convolution")
            _log("Looping along metallicity axis")
            utils.check_array_memory((n_age, n_wl, n_wl), dtype=dtype,
                                     unit='GB', safety_margin=safety_margin)
            for ith in range(n_met):
                ssp.L_lambda[ith] = kinematics.convolve_variable_gaussian_kernel(
                    ssp.L_lambda[ith], lsf_sigma_pixels)

    def prepare_extinction_law(self, options):
        """Prepare a dust extinction model.

        Parameters
        ----------
        options : :class:`DataBlock`
            Input options to initialise the model. If ``ExtinctionLaw`` is
            missing, ``config["extinction_law"]`` is set to ``None``.
        """
        _log("\n -> Configuring Dust extinction model")
        if not options.has_value("ExtinctionLaw"):
            self.config["extinction_law"] = None
            return
        ext_law = options.get_string("ExtinctionLaw")
        _log("Extinction law: ", ext_law)
        # TODO: add more extinction laws
        self.config["extinction_law"] = dust.DustScreen(ext_law)
        _log("-> Configuration is done.")

    def prepare_sfh_model(self, options):
        """Prepare the SFH model.

        Parameters
        ----------
        options : :class:`DataBlock`
            Input options to initialise the model.

            - ``SFHModel`` names a class in :mod:`besta.sfh`.
            - ``SFHArgs`` (optional) provides extra arguments as a string, a
              list (of strings, dicts or plain values), or a single value.
            - ``use_transforms`` (optional, bool or bool-like string) enables
              parameter transforms.

        Notes
        -----
        The model is created as ``SFHModel(*args, **kwargs, **self.config)``,
        so the current :attr:`config` entries are also forwarded as keyword
        arguments.
        """
        _log("\n-> Configuring SFH model")
        sfh_model_name = options["SFHModel"]
        sfh_args = []
        sfh_kwargs = {}
        if options.has_value("SFHArgs"):
            sfh_all_args = options["SFHArgs"]
            if isinstance(sfh_all_args, str):
                sfh_args, sfh_kwargs = io.string_to_func_args(sfh_all_args)
            elif isinstance(sfh_all_args, list):
                for arg in sfh_all_args:
                    if isinstance(arg, str):
                        a, ka = io.string_to_func_args(arg)
                        sfh_args.extend(a)
                        sfh_kwargs.update(ka)
                    elif isinstance(arg, dict):
                        sfh_kwargs.update(arg)
                    else:
                        sfh_args.append(arg)
            else:
                sfh_args.append(sfh_all_args)
        logger.info("SFH Model extra arguments: %s", sfh_args)
        logger.info("SFH Model extra keyword arguments: %s", sfh_kwargs)

        # Optional: enable parameter transforms inside SFH models.
        # bool("False") is True, so bool-like strings are parsed explicitly.
        self.config["use_transforms"] = (
            self._parse_bool(options["use_transforms"])
            if options.has_value("use_transforms") else False)
        if self.config["use_transforms"]:
            _log("Enabling parameter transforms inside SFH model")

        _log("SFH model name: ", sfh_model_name)
        sfh_class = getattr(sfh, sfh_model_name)
        # NOTE(review): a key present in both sfh_kwargs and self.config
        # raises TypeError (duplicate keyword argument).
        self.config["sfh_model"] = sfh_class(*sfh_args, **sfh_kwargs,
                                             **self.config)
        _log("-> Configuration done")

    def log_like(self, data, model, var, weights=None, is_upper=None,
                 is_lower=None, include_norm=True):
        """Compute log-likelihood between data and model.

        Parameters
        ----------
        data : array_like
            For detections: measured values. For limits: the limit value
            (upper or lower).
        model : array_like
            Model prediction for each datum.
        var : array_like
            Data variance. Must match shape of data/model and be > 0.
        weights : array_like, optional
            Non-negative data weights. If provided, returns the weighted
            mean log-likelihood; otherwise the plain sum.
        is_upper : array_like of bool, optional
            Mask for upper limits (x < data).
        is_lower : array_like of bool, optional
            Mask for lower limits (x > data).
        include_norm : bool, optional, default=True
            If True, include the Gaussian normalization term
            ``-0.5*log(2*pi*var)`` for detections. Limit terms are always
            proper log-probabilities.

        Returns
        -------
        loglike : float
            Total (or weighted-mean) log-likelihood.

        Raises
        ------
        ValueError
            On shape mismatches, non-positive variances, points flagged as
            both upper and lower limits, negative weights, or all-zero
            weights.
        """
        data = np.asarray(data, dtype=float)
        model = np.asarray(model, dtype=float)
        var = np.asarray(var, dtype=float)
        if data.shape != model.shape or data.shape != var.shape:
            raise ValueError("data, model, var must have the same shape (var is per-datum variance).")
        if np.any(var <= 0):
            raise ValueError("All var entries must be > 0 (variance).")

        is_upper = (np.zeros_like(data, dtype=bool) if is_upper is None
                    else np.asarray(is_upper, dtype=bool))
        is_lower = (np.zeros_like(data, dtype=bool) if is_lower is None
                    else np.asarray(is_lower, dtype=bool))
        if is_upper.shape != data.shape or is_lower.shape != data.shape:
            raise ValueError("is_upper and is_lower must have the same shape as data/model.")
        if np.any(is_upper & is_lower):
            raise ValueError("A data point cannot be both an upper and a lower limit.")

        if weights is None:
            weights = np.ones_like(data, dtype=float)
            normalize = False
        else:
            weights = np.asarray(weights, dtype=float)
            if weights.shape != data.shape:
                raise ValueError("weights must have the same shape as data/model.")
            if np.any(weights < 0):
                raise ValueError("weights must be non-negative.")
            normalize = True

        # TODO: to avoid this step the pipeline should store sigma
        sigma = np.sqrt(var)
        logp = np.empty_like(data, dtype=float)
        det = ~(is_upper | is_lower)

        # For detections: Gaussian logpdf
        if np.any(det):
            if include_norm:
                logp[det] = norm.logpdf(data[det], loc=model[det], scale=sigma[det])
            else:
                z = (data[det] - model[det]) / sigma[det]
                logp[det] = -0.5 * z**2
        # Upper limits: P(x < L)
        if np.any(is_upper):
            z_u = (data[is_upper] - model[is_upper]) / sigma[is_upper]
            logp[is_upper] = norm.logcdf(z_u)
        # Lower limits: P(x > L) = 1 - P(x < L)
        if np.any(is_lower):
            z_l = (data[is_lower] - model[is_lower]) / sigma[is_lower]
            logp[is_lower] = norm.logsf(z_l)

        # User-provided weighted mean
        if normalize:
            wsum = np.sum(weights)
            if wsum <= 0:
                # Weights were validated as non-negative, so all are zero
                raise ValueError("All weights are zero.")
            return np.sum(logp * weights) / wsum
        return np.sum(logp * weights)
class SpectraFitModule(BaseModule):
    """Base class for spectral fitting modules in BESTA."""

    def prepare_observed_spectra(self, options: DataBlock, normalize=False):
        """Prepare the input spectra data.

        Reads a three-column text file (wavelength, flux, error) and converts
        units. Builds pixel weights from the input mask and from telluric,
        sky-line and emission-line masking. Shifts wavelengths to the rest
        frame and restricts them to ``wlRange``. Optionally log-rebins to
        ``velscale``. The results are stored in :attr:`config`.

        Parameters
        ----------
        options : :class:`DataBlock`
            Module setup options. ``inputSpectrum`` is required. Optional
            keys are:

            - ``wlUnits``, ``fluxUnits``: input units.
            - ``wlRange``: rest frame; defaults to the full rest-frame
              coverage.
            - ``wlNormRange``: deprecated.
            - ``redshift``, ``mask``, ``lsf``, ``velscale``.
            - ``mask_telluric`` / ``telluric_pad``.
            - ``mask_sky_lines`` / ``sky_line_pad``.
            - ``mask_emission_lines``.
        normalize : bool, optional
            If ``True``, divide the flux by its median within the
            normalization range (and the variance by its square).

        Raises
        ------
        ValueError
            If the mask size does not match the spectrum, no pixel lies in
            the fitting range, errors are non-positive, all weights are zero,
            or (with ``normalize``) no pixel lies in the normalization range.
        """
        _log("\n-> Configuring input observed spectra")
        filename = os.path.expandvars(options["inputSpectrum"])
        # Read wavelength and spectra
        _log("Loading observed spectra from input file: ", filename)
        wavelength, flux, error = np.loadtxt(filename, unpack=True)
        _log("Wavelength coverage: ", wavelength[[0, -1]])
        _log("Size: ", wavelength.size)

        # Convert units if needed
        if options.has_value("wlUnits"):
            _log("Converting wavelength units to Angstrom")
            wl_units = u.Unit(options["wlUnits"])
            wavelength = (wavelength << wl_units).to("Angstrom").value
        else:
            _log("Assuming input wavelength units are in Angstrom")
            wl_units = u.angstrom

        if options.has_value("fluxUnits"):
            _log("Converting flux units to 1e-16 erg/s/cm^2/Angstrom")
            input_flux_units = u.Unit(options["fluxUnits"])
            flux = (flux << input_flux_units).to(
                "1e-16 erg / (s cm2 Angstrom)").value
            error = (error << input_flux_units).to(
                "1e-16 erg / (s cm2 Angstrom)").value
        else:
            _log("Assuming input flux units are in 1e-16 erg/s/cm^2/Angstrom")
        # The flux array is now in these units either way
        flux_units = u.Unit("1e-16 erg / (s cm2 Angstrom)")

        # Wavelength range to include in the fit (rest frame). The default is
        # resolved after the redshift correction so both are in the same frame.
        if options.has_value("wlRange"):
            wl_range = (np.asarray(options["wlRange"]) << wl_units
                        ).to("Angstrom").value
        else:
            wl_range = None
        # Wavelength range to renormalize the spectra
        if options.has_value("wlNormRange"):
            logger.warning("Input option 'wlNormRange' is deprecated and will be removed in future versions.")
            wl_norm_range = (np.asarray(options["wlNormRange"]) << wl_units
                             ).to("Angstrom").value
        else:
            wl_norm_range = None

        # Input redshift (initial guess)
        if options.has_value("redshift"):
            redshift = options["redshift"]
        else:
            _log("No input redshift value provided (defaulting to 0)")
            redshift = 0.0

        # Load mask
        if options.has_value("mask"):
            weights = np.array(
                np.loadtxt(os.path.expandvars(options["mask"])), dtype=float)
        else:
            weights = np.ones_like(flux)
        if weights.size != flux.size:
            raise ValueError(
                "Input mask size does not match the input spectrum size.")

        # Load the instrumental LSF (first two columns: wavelength, FWHM)
        if options.has_value("lsf"):
            lsf_wl, lsf_fwhm = np.loadtxt(os.path.expandvars(options["lsf"]),
                                          unpack=True, usecols=(0, 1))
            instrumental_lsf = np.array(np.interp(wavelength, lsf_wl, lsf_fwhm),
                                        dtype=float)
        else:
            instrumental_lsf = np.zeros_like(wavelength)

        # NOTE(review): the masking helpers below receive `weights` and their
        # returned weights are multiplied in again. Confirm they do not
        # already include the input weights, or fractional weights get squared.
        # Optional masking of telluric regions (observed frame)
        if options.has_value("mask_telluric") and options["mask_telluric"]:
            telluric_pad = options.get_double("telluric_pad", default=0.0)
            telluric_pad = (telluric_pad << wl_units).to("Angstrom").value
            _log(f"Masking telluric regions with pad={telluric_pad} Angstrom")
            weights_tell, tell_mask, bands_used = spectrum.mask_telluric_regions(
                wavelength, weight=weights, redshift=0.0, pad=telluric_pad,
                return_mask=True)
            _log("Number of telluric-absorption masked pixels: ",
                 np.count_nonzero(tell_mask))
            weights *= weights_tell
            self.config["telluric_mask"] = tell_mask
            self.config["telluric_bands_used"] = bands_used

        # Optional masking of sky emission lines (observed frame)
        if options.has_value("mask_sky_lines") and options["mask_sky_lines"]:
            sky_line_pad = options.get_double("sky_line_pad", default=0.0)
            sky_line_pad = (sky_line_pad << wl_units).to("Angstrom").value
            _log(f"Masking sky line regions with pad={sky_line_pad} Angstrom")
            weights_sky, sky_mask, lines_used = spectrum.mask_sky_emission_lines(
                wavelength, flux=flux, uncertainty=error, weight=weights,
                redshift=0.0,  # Mask in observed frame
                pad=sky_line_pad, return_mask=True, return_lines_masked=True)
            _log("Number of sky-line masked pixels: ", np.count_nonzero(sky_mask))
            weights *= weights_sky
            self.config["sky_line_mask"] = sky_mask
            self.config["sky_lines_used"] = lines_used

        # Optional masking of emission lines
        if options.has_value("mask_emission_lines") and options["mask_emission_lines"]:
            # TODO(review): line_list / half_width are not configurable yet
            weights_el, line_mask, lines_used = spectrum.mask_strong_emission_lines(
                wavelength, flux, error, weights, redshift=redshift,
                return_mask=True, return_lines_masked=True)
            weights *= weights_el
            _log("Number of emission-line masked pixels: ",
                 np.count_nonzero(line_mask))
            self.config["emission_lines_mask"] = line_mask
            self.config["emission_lines_used"] = lines_used

        # Apply redshift
        _log(f"Setting wavelength array to restframe (redshift: {redshift})")
        wavelength /= 1.0 + redshift
        if wl_range is None:
            wl_range = wavelength[[0, -1]]
        if wl_norm_range is None:
            wl_norm_range = wavelength[[0, -1]]

        _log("Constraining fit to wavelength range: ", wl_range)
        good_idx = np.where(
            (wavelength >= wl_range[0]) & (wavelength <= wl_range[1]))[0]
        if len(good_idx) == 0:
            raise ValueError("No wavelength points found within the given "
                             "wavelength range.")
        wavelength = wavelength[good_idx]
        flux = flux[good_idx]
        cov = error[good_idx] ** 2
        weights = weights[good_idx]
        instrumental_lsf = instrumental_lsf[good_idx]

        # Check error
        if (cov <= 0).any():
            raise ValueError("Input flux error contains negative or null values.")
        if np.nansum(weights) <= 0:
            raise ValueError("All input weights are zero; cannot perform fit.")
        _log("Number of selected pixels within wavelength range: ",
             good_idx.size)

        velscale = options["velscale"] if options.has_value("velscale") else None
        _log("Log-binning spectra to velocity scale: ", velscale, " (km/s)")
        if velscale is not None:
            dlnlam = velscale / spectrum.constants.c.to("km/s").value
            ln_wave = np.arange(np.log(wl_range[0]),
                                np.log(wl_range[1]) + dlnlam, dlnlam)
            log_wl = np.log(wavelength)
            flux = flux_conserving_interpolation(ln_wave, log_wl, flux)
            cov = flux_conserving_interpolation(ln_wave, log_wl, cov)
            weights = np.interp(ln_wave, log_wl, weights)
            instrumental_lsf = np.interp(ln_wave, log_wl, instrumental_lsf)
            new_wavelength = np.exp(ln_wave)
            # Pixels outside the observed coverage get zero weight
            weights[(new_wavelength < wavelength[0])
                    | (new_wavelength > wavelength[-1])] = 0.0
            wavelength = new_wavelength
        else:
            ln_wave = np.log(wavelength)
        _log("Number of pixels after interpolation: ", wavelength.size)

        # Normalize spectra
        if normalize:
            _log("Spectra normalized using wavelength range: ", wl_norm_range)
            norm_mask = ((wavelength >= wl_norm_range[0])
                         & (wavelength <= wl_norm_range[1]))
            if not norm_mask.any():
                raise ValueError(
                    "No pixels found within the normalization wavelength range.")
            norm_flux = np.nanmedian(flux[norm_mask])
            flux /= norm_flux
            cov /= norm_flux**2
            flux_units = u.dimensionless_unscaled
        else:
            norm_flux = 1.0

        # 4 pi d_L^2 (1 + z); falls back to 10 pc for z <= 0
        if redshift > 0:
            dl_sq = cosmology.luminosity_distance(redshift).to("cm").value ** 2
            dl_sq = 4 * np.pi * dl_sq * (1 + redshift)
        else:
            dl_sq = (10 * u.pc).to("cm").value ** 2 * 4 * np.pi

        self.config["flux"] = flux
        self.config["var"] = cov
        self.config["redshift"] = redshift
        self.config["wlUnits"] = wl_units
        self.config["fluxUnits"] = flux_units
        self.config["dl_sq"] = dl_sq
        self.config["norm_flux"] = norm_flux
        self.config["wavelength"] = wavelength << u.angstrom
        self.config["ln_wave"] = ln_wave
        self.config["weights"] = weights
        if not (instrumental_lsf == 0).all():
            self.config["lsf"] = instrumental_lsf
        _log("-> Configuration done.")

    def prepare_galaxy(self, options):
        """Build and configure a :class:`pst.galaxy.GalaxySED` model.

        The SSP and SFH models are prepared first if they are not yet in
        :attr:`config`. The stellar component uses the SFH object's
        ``.model``. Optional dust attenuation (``DustAttenuation``,
        ``DustAttenuationModel``, ``ExtinctionLaw``) and dust emission
        (``DustEmission``, ``DustCalorimetric``) are read from ``options``.
        The galaxy model, its parameter index and the split
        ``(section, name)`` pairs are stored in ``self.config``.
        """
        # Stellar emissions
        if self.config.get("ssp_model") is None:
            self.prepare_ssp_model(options)
        ssp_model = self.config["ssp_model"]
        if self.config.get("sfh_model") is None:
            self.prepare_sfh_model(options)
        # Always unwrap, whether the SFH model was just built or already existed
        sfh_model = self.config["sfh_model"].model

        ## Create stellar emission model
        _log("Setting up stellar emission component")
        stars = sed.StellarComponent(ssp=ssp_model, sfh=sfh_model)

        # Dust extinction and emission
        dust_attenuation = None
        dust_emission = None
        if options.get_bool("DustAttenuation", False):
            _log("Setting up dust attenuation")
            att_name = options.get_string(
                "DustAttenuationModel", "DustScreenAttenuation")
            ext_law = options.get_string("ExtinctionLaw", "ccm89")
            dust_curve = dust.ExtinctionLibCurve(law=ext_law)
            _log("DustAttenuationModel: ", att_name)
            # TODO: implement interface for other models
            if att_name == "DustScreenAttenuation":
                dust_attenuation = dust.DustScreenAttenuation(curve=dust_curve)
            else:
                logger.warning("Unsupported DustAttenuationModel '%s'; "
                               "dust attenuation disabled", att_name)
        if options.get_bool("DustEmission", False):
            _log("Setting up dust emission model")
            dust_sed = dust.Casey2012DustComponent()
            if options.get_bool("DustCalorimetric", True):
                _log("Setting energy balance approximation")
                dust_emission = dust.CalorimetricDustComponent(
                    attenuation=dust_attenuation, dust_sed_component=dust_sed)
            else:
                logger.warning("DustEmission requested without "
                               "DustCalorimetric; no dust emission is added")

        _log("Setting up galaxy model")
        galaxy = GalaxySED(stellar_model=stars,
                           dust_attenuation_model=dust_attenuation,
                           dust_model=dust_emission,
                           target_wavelength=self.config["ssp_model"].wavelength,
                           redshift=self.config.get("redshift", 0.0),
                           cosmology=cosmology)
        params = galaxy.build_param_index(include_fixed=False, prefix="")
        sections = [s.rsplit(".", 1) for s in params]
        self.config["galaxy-params"] = params
        self.config["galaxy-sections"] = sections
        self.config["galaxy"] = galaxy

    def prepare_legendre_polynomials(self, options):
        """Prepare the set of Legendre polynomials used during the fit.

        Parameters
        ----------
        options : :class:`DataBlock`
            Input options. If ``legendre_deg`` is set, the polynomial array
            is stored in ``config["legendre_pol"]``. The optional
            ``legendre_bounds``, ``legendre_scale`` and
            ``legendre_clip_first_zero`` keys are forwarded as ``bounds``,
            ``scale`` and ``clip_first_zero``.
        """
        _log("\n-> Configuring multiplicative polynomial")
        if options.has_value("legendre_deg"):
            option_to_kwarg = {"legendre_bounds": "bounds",
                               "legendre_scale": "scale",
                               "legendre_clip_first_zero": "clip_first_zero"}
            kwargs = {kwarg: options[opt]
                      for opt, kwarg in option_to_kwarg.items()
                      if options.has_value(opt)}
            _log(f"Using Legendre polynomials up to degree {options['legendre_deg']}",
                 "\nAdditional arguments: ", kwargs)
            self.config["legendre_pol"] = spectrum.get_legendre_polynomial_array(
                self.config["wavelength"], options["legendre_deg"], **kwargs)
        else:
            _log("Not using multiplicative Legendre polynomials")
        _log("-> Configuration done")

    def measure_emission_lines(self, solution: DataBlock, **kwargs):
        """Measure emission line fluxes and EWs from the best-fit solution.

        Parameters
        ----------
        solution : :class:`DataBlock`
            Best-fit solution containing the model parameters.
        **kwargs
            Forwarded to :func:`besta.spectrum.find_emission_lines`.

        Returns
        -------
        line_table : :class:`astropy.table.Table`
            Table containing the measured emission line fluxes.
        line_segm_map : :class:`besta.spectrum.LineSegmentationMap`
            Map containing the segmentation information for the emission lines.
        """
        _log("Measuring emission line fluxes from input solution")
        wavelength = self.config["wavelength"].to_value("Angstrom")
        flux = self.config["flux"]
        flux_error = np.sqrt(self.config["var"])
        flux_model = self.make_observable(solution, parse=True)
        if isinstance(flux_model, tuple):
            flux_model = flux_model[0]
        # NOTE: the telluric/sky masks in config live on the native input
        # grid (before the wavelength cut and rebinning) and are not used here.
        line_table, line_segm_map = spectrum.find_emission_lines(
            wavelength, flux, flux_error, flux_model,
            continuum=flux_model, continuum_error=flux_model / 100, **kwargs)
        return line_table, line_segm_map

    def plot_solution(self, solution: DataBlock, figname=None, plot_lines=True):
        """Plot the observed spectrum, model, residuals and chi^2 diagnostics.

        Parameters
        ----------
        solution : :class:`DataBlock`
            Parameter values evaluated with :meth:`make_observable`.
        figname : str, optional
            If given, the figure is saved to this path (300 dpi).
        plot_lines : bool, optional
            If ``True``, overplot the masked emission lines, sky lines and
            telluric bands stored in :attr:`config`.

        Returns
        -------
        fig : :class:`matplotlib.figure.Figure`
            The figure (already closed with :func:`matplotlib.pyplot.close`).
        """
        model_output = self.make_observable(solution, parse=True)
        if isinstance(model_output, tuple):
            flux_model, weights = model_output[0], model_output[1]
        else:
            flux_model = model_output
            weights = np.ones_like(flux_model)
        # Include input weights (out of place: do not modify arrays returned
        # by make_observable)
        weights = weights * self.config["weights"]

        # Grab the solution values (visualization purpose only)
        sol_sections = {}
        for sec, name in solution.keys():
            sol_sections.setdefault(sec, {})[name] = solution[sec, name]

        fig, axs = plt.subplots(ncols=2, nrows=2, sharex="col", sharey="row",
                                constrained_layout=True, width_ratios=[4, 1],
                                height_ratios=[2, 1], figsize=(16, 9))
        plt.suptitle(f"Module: {self.name}")

        # Display the information
        ax = axs[0, 1]
        # Pixel masking information
        mask_info = {"Total pixels": self.config["flux"].size,
                     "Masked pixels (w=0)": np.sum(weights <= 0),
                     " - Telluric abs.": np.sum(
                         self.config.get("telluric_mask", 0)),
                     " - Sky lines": np.sum(
                         self.config.get("sky_line_mask", 0)),
                     " - Emission lines": np.sum(
                         self.config.get("emission_lines_mask", 0))}
        sol_sections["settings"] = {
            "use_transforms": self.config.get("use_transforms", False)}
        if sol_sections["settings"]["use_transforms"]:
            sfh_model = self.config["sfh_model"]
            sfh_params_latent = sfh_model.get_sfh_parameters_array(solution)
            sfh_params_phys = sfh_model.to_physical(sfh_params_latent)
            for key, value in zip(sfh_model.sfh_bin_keys, sfh_params_phys):
                sol_sections[sfh_model.sect_name][key] = value
        sections = list(sol_sections.items()) + [("Masking", mask_info)]
        draw_dict_in_axes(ax, sections, section_spacing=1,
                          title_style="underline")
        ax.axis("off")

        # Plot input spectra and best-fit model
        ax = axs[0, 0]
        wavelength = self.config["wavelength"]
        flux = self.config["flux"]
        sigma = self.config["var"] ** 0.5
        # SNR information
        snr = np.nanpercentile(flux / np.sqrt(self.config["var"]), (16, 50, 84))
        ax.annotate(f"SNR (16, 50, 84 percentiles): "
                    f"{snr[0]:.1f}, {snr[1]:.1f}, {snr[2]:.1f}",
                    xy=(0.02, 0.98), xycoords="axes fraction", va="top",
                    fontsize=8)
        ax.fill_between(wavelength.value, flux - sigma, flux + sigma,
                        color="k", alpha=0.5)
        ax.plot(wavelength, flux, c="k", label="Observed", lw=0.7)
        # Show masked pixels
        nan_mask = np.ones_like(flux)
        nan_mask[weights <= 0] = np.nan
        ax.plot(wavelength, flux * nan_mask, c="r", lw=0.7,
                label="Non-zero weights")
        # Plot model
        ax.plot(wavelength, flux_model, c="b", label="Model", lw=0.7)
        # Plot residuals
        residuals = flux_model - flux
        ax.plot(wavelength, residuals, c="grey", label="Residuals", lw=0.7)
        ax.plot(wavelength, residuals * nan_mask, c="orange",
                label="Residuals (w>0)", lw=0.7)
        ax.axhline(0, ls="--", color="k", alpha=0.2)
        ax.set_ylabel("Flux")
        ax.legend(bbox_to_anchor=(0.5, 1.01), loc="lower center", ncols=5,
                  fontsize=8)
        p5, p95 = np.nanpercentile(flux, [5, 95])
        p_residuals = np.nanpercentile(residuals, 5) * 0.95
        ax.set_ylim(np.min([p_residuals, p5 * 0.8]), p95 * 1.2)

        # Plot masked emission lines and telluric regions. Sky lines and
        # telluric bands are observed-frame, so shift them to the rest frame.
        if plot_lines:
            z_factor = 1 + self.config.get("redshift", 0.0)
            y_top = ax.get_ylim()[1]
            if "emission_lines_used" in self.config:
                for line in self.config["emission_lines_used"]:
                    ax.axvline(line.rest_wavelength, ls="--", lw=0.7,
                               color="red", alpha=0.5)
                    ax.annotate(line.name, xy=(line.rest_wavelength, y_top),
                                xytext=(0, -5), textcoords="offset points",
                                ha="center", va="top", fontsize=6,
                                color="red")
            if "sky_lines_used" in self.config:
                for line in self.config["sky_lines_used"]:
                    wl_rest = line.rest_wavelength / z_factor
                    ax.axvline(wl_rest, ls="--", lw=0.7, color="cyan",
                               alpha=0.5)
                    ax.annotate(line.name, xy=(wl_rest, y_top * 0.95),
                                xytext=(0, -5), textcoords="offset points",
                                ha="center", va="top", fontsize=6,
                                color="cyan")
            if "telluric_bands_used" in self.config:
                for band in self.config["telluric_bands_used"]:
                    ax.axvspan(band.wmin / z_factor, band.wmax / z_factor,
                               color="orange", alpha=0.2)
                    ax.annotate(band.name,
                                xy=((band.wmin + band.wmax) / 2 / z_factor,
                                    y_top * 0.9),
                                xytext=(0, -5), textcoords="offset points",
                                ha="center", va="top", fontsize=6,
                                color="orange")
        # Do not call ax.legend() again: it would replace the legend above.
        ax.set_xlim(wavelength.value[[0, -1]])

        # Plot chi2 (summary statistics use pixels with non-zero weight)
        good_pixels = weights > 0
        chi2 = (flux_model - flux) ** 2 / self.config["var"]
        good_chi2 = chi2[good_pixels]
        mean_chi2 = np.nanmean(good_chi2)
        median_chi2 = np.nanmedian(good_chi2)
        nmad_chi2 = 1.4826 * np.nanmedian(np.abs(good_chi2 - median_chi2))
        loglike = self.log_like(flux[good_pixels], flux_model[good_pixels],
                                self.config["var"][good_pixels],
                                weights=weights[good_pixels])

        ax = axs[1, 0]
        ax.plot(wavelength, chi2, c="k", lw=0.7)
        ax.grid(visible=True)
        ax.set_ylabel(r"$\chi^2$")
        ax.set_yscale("symlog", linthresh=1.0)
        ax.set_xlabel("Wavelength (AA)")

        ax = axs[1, 1]
        ax.hist(chi2, bins=np.geomspace(0.01, 100), orientation="horizontal",
                color="k", histtype="step")
        ax.annotate(f"Median chi2: {median_chi2:.1f}"
                    + f"\nMean chi2: {mean_chi2:.1f}"
                    + f"\nNMAD chi2: {nmad_chi2:.1f}"
                    + f"\nLog-likelihood: {loglike:.1f}",
                    xy=(0.05, 0.95), xycoords="axes fraction", va="top",
                    fontsize=8)
        ax.set_xlabel("No. pixels")
        ax.grid(visible=True)
        ax.tick_params(labelleft=False)

        if figname is not None:
            fig.savefig(figname, bbox_inches="tight", dpi=300)
            _log(f"Fit plot saved at: {figname}")
        plt.close(fig)
        return fig
class PhotometryFitModule(BaseModule):
    """Base class for photometry fitting modules in BESTA."""

    def prepare_observed_photometry(self, options: SectionOptions):
        """Prepare the Photometric Data.

        Reads an ASCII table whose first three columns are the filter name
        (or a path to a filter text file), the flux and the flux error. An
        optional fourth column flags measurements as ``"upper"`` or
        ``"lower"`` limits.

        Parameters
        ----------
        options : :class:`DataBlock`
            Module options. Reads ``inputPhotometry`` (required),
            ``fluxUnits`` (optional; if absent fluxes are assumed to be in
            uJy) and ``redshift`` (optional, default 0.0).

        Notes
        -----
        Populates ``self.config`` with ``photometry_flux``,
        ``photometry_flux_var``, ``photometry_flux_unit``,
        ``photometry_upper_limit``, ``photometry_lower_limit``,
        ``filters``, ``filter_list`` and ``redshift``.
        """
        _log("\n-> Configuring photometric data")
        photometry_file = os.path.expandvars(options["inputPhotometry"])
        # Read the data
        input_data = ascii.read(photometry_file)
        filter_names = input_data[input_data.colnames[0]].value.astype(str)
        flux = input_data[input_data.colnames[1]].value.astype(float)
        flux_err = input_data[input_data.colnames[2]].value.astype(float)
        # Handle upper and lower limits (optional fourth column)
        if len(input_data.colnames) > 3:
            measure_limits = input_data[input_data.colnames[3]].value.astype(str)
            is_upper = measure_limits == "upper"
            is_lower = measure_limits == "lower"
        else:
            is_upper = None
            is_lower = None
        # Unit conversion: everything downstream works in uJy
        if options.has_value("fluxUnits"):
            _log("Converting input flux units to uJy")
            flux_units = u.Unit(options["fluxUnits"])
            flux = (flux << flux_units).to("uJy").value
            flux_err = (flux_err << flux_units).to("uJy").value
        else:
            _log("Assuming input flux units are in uJy")
        flux_units = u.Unit("uJy")

        self.config["photometry_flux"] = flux
        self.config["photometry_flux_var"] = flux_err**2
        self.config["photometry_flux_unit"] = flux_units
        self.config["photometry_upper_limit"] = is_upper
        self.config["photometry_lower_limit"] = is_lower

        # Load the photometric filters: a local file if the (env-expanded)
        # name exists on disk, otherwise query it with Filter.from_svo.
        photometric_filters = []
        for filter_name in filter_names:
            _log(f"Loading photometric filter: {filter_name}")
            filter_path = os.path.expandvars(filter_name)
            if os.path.exists(filter_path):
                filt = Filter.from_text_file(filter_path)
            else:
                filt = Filter.from_svo(filter_name)
            photometric_filters.append(filt)
        # TODO: For now save both for backwards compatibility
        # TODO: Supersede by FilterList
        self.config["filters"] = photometric_filters
        self.config["filter_list"] = FilterList(photometric_filters)

        redshift = options.get_double("redshift", default=0.0)
        self.config["redshift"] = redshift
        _log("Source redshift: ", redshift)
        _log("-> Configuration done.")

    def prepare_galaxy(self, options):
        """Build and configure a :class:`pst.galaxy.GalaxySED` model for photometry.

        The method initializes stellar, attenuation, and optional dust
        emission components, prepares the target wavelength grid, and stores
        the galaxy model and parameter index in ``self.config``.

        The SSP model, SFH model and observed photometry are prepared on
        demand if no earlier step has stored them in ``self.config``.

        Parameters
        ----------
        options : :class:`DataBlock`
            Module options. Reads ``DustAttenuation``,
            ``DustAttenuationModel``, ``ExtinctionLaw``, ``DustEmission``,
            ``DustCalorimetric`` and ``logwave``.

        Notes
        -----
        Populates ``self.config`` with ``galaxy-params``,
        ``galaxy-sections`` and ``galaxy``.
        """
        # Stellar emissions
        if self.config.get("ssp_model") is None:
            self.prepare_ssp_model(options)
        ssp_model = self.config["ssp_model"]

        if self.config.get("sfh_model") is None:
            self.prepare_sfh_model(options)
        # ``config["sfh_model"]`` holds a wrapper; always unwrap ``.model``,
        # whether it was prepared here or by an earlier step.
        sfh_model = self.config["sfh_model"].model

        if self.config.get("filter_list") is None:
            self.prepare_observed_photometry(options)
        filters = self.config["filter_list"]

        ## Create stellar emission model
        _log("Setting up stellar emission component")
        stars = sed.StellarComponent(ssp=ssp_model, sfh=sfh_model)

        # Dust extinction and emission
        dust_attenuation = None
        dust_emission = None
        if options.get_bool("DustAttenuation", False):
            _log("Setting up dust attenuation")
            att_name = options.get_string(
                "DustAttenuationModel", "DustScreenAttenuation")
            ext_law = options.get_string("ExtinctionLaw", "ccm89")
            dust_curve = dust.ExtinctionLibCurve(law=ext_law)
            _log("DustAttenuationModel: ", att_name)
            # TODO: implement interface for other models
            if att_name == "DustScreenAttenuation":
                dust_attenuation = dust.DustScreenAttenuation(curve=dust_curve)
            else:
                logger.warning("Unsupported DustAttenuationModel %r; "
                               "dust attenuation will be ignored", att_name)

        if options.get_bool("DustEmission", False):
            _log("Setting up dust emission model")
            dust_sed = dust.Casey2012DustComponent()
            if options.get_bool("DustCalorimetric", True):
                _log("Setting energy balance approximation")
                dust_emission = dust.CalorimetricDustComponent(
                    attenuation=dust_attenuation, dust_sed_component=dust_sed
                )
            else:
                logger.warning("Only calorimetric dust emission is currently "
                               "supported; dust emission will be ignored")

        # Setup target wavelength range
        min_wl, max_wl = filters.wavelength_range()
        if dust_emission is not None:
            # Extend the grid to the SSP blue edge and to the IR range of
            # the dust SED component (if it exposes one).
            min_wl = np.min((
                min_wl.to_value(u.AA),
                stars.ssp.wavelength.min().to_value("AA"))) << u.AA
            ir_range = getattr(dust_emission.dust_sed_component, "ir_range",
                               [None, 1 << u.AA])
            max_wl = np.max((
                max_wl.to_value(u.AA),
                ir_range[1].to_value(u.AA))) << u.AA

        z_obs = self.config.get("redshift", 0.0)
        # The lower edge is divided by (1 + z) so the grid also reaches the
        # wavelengths that redshift into the bluest filter; presumably
        # ``target_wavelength`` is rest-frame -- TODO confirm.
        wl_start = min_wl.to_value("AA") / (1 + z_obs)
        wl_stop = max_wl.to_value("AA")
        if options.get_bool("logwave", False):
            target_wl = np.geomspace(wl_start, wl_stop, 3000) << u.AA
        else:
            target_wl = np.arange(wl_start, wl_stop, 10) << u.AA
        _log("Target wavelength range: ",
             f"{target_wl[0]:.1f} -- {target_wl[-1]:.1f} ({target_wl.size} pix)")

        _log("Setting up galaxy model")
        galaxy = GalaxySED(stellar_model=stars,
                           dust_attenuation_model=dust_attenuation,
                           dust_model=dust_emission,
                           target_wavelength=target_wl,
                           redshift=z_obs, cosmology=cosmology,
                           filters=filters)

        params = galaxy.build_param_index(include_fixed=False, prefix="")
        # Split "section.name" keys into [section, name] pairs.
        sections = [s.rsplit(".", 1) for s in params]
        self.config["galaxy-params"] = params
        self.config["galaxy-sections"] = sections
        self.config["galaxy"] = galaxy

    def plot_solution(self, solution: DataBlock, figname=None):
        """Plot the photometric fit.

        Panels: solution summary (top right), observed vs model photometry
        with the model spectrum and filter responses (top left), model
        spectrum in flux density per wavelength unit (middle left), and the
        per-band chi^2 (bottom left).

        Parameters
        ----------
        solution : :class:`DataBlock`
            Parameter values of the solution to plot.
        figname : str, optional
            If given, the figure is saved to this path.

        Returns
        -------
        fig : :class:`matplotlib.figure.Figure`
            The figure, already closed with :func:`matplotlib.pyplot.close`.
        """
        flux_model, full_spec = self.make_observable(solution, parse=True,
                                                     include_spec=True)
        phot_flux = self.config["photometry_flux"]
        phot_var = self.config["photometry_flux_var"]
        phot_err = np.sqrt(phot_var)
        target_wl = self.config["galaxy"].target_wavelength
        eff_wl = self.config["filter_list"].effective_wavelength

        fig, axs = plt.subplots(ncols=2, nrows=3, sharex="col", sharey="row",
                                constrained_layout=True,
                                width_ratios=[4, 1], height_ratios=[2, 1, 1],
                                figsize=(16, 9))
        fig.suptitle(f"Module: {self.name}")

        # Display the information: settings followed by the solution values
        # grouped by DataBlock section.
        ax = axs[0, 1]
        sol_sections = {
            "Settings": {
                "Redshift": self.config["redshift"],
                "No. bands": phot_flux.size,
                "use_transforms": self.config.get("use_transforms", False)}}
        for sec, name in solution.keys():
            sol_sections.setdefault(sec, {})[name] = solution[sec, name]

        if sol_sections["Settings"]["use_transforms"]:
            # Report the SFH parameters in physical (not latent) space.
            sfh_wrapper = self.config["sfh_model"]
            sfh_params_latent = sfh_wrapper.get_sfh_parameters_array(solution)
            sfh_params_phys = sfh_wrapper.to_physical(sfh_params_latent)
            sfh_section = sol_sections.setdefault(sfh_wrapper.sect_name, {})
            for key, value in zip(sfh_wrapper.sfh_bin_keys, sfh_params_phys):
                sfh_section[key] = value

        _ = draw_dict_in_axes(ax, list(sol_sections.items()),
                              section_spacing=1, title_style="underline")
        ax.axis("off")

        # Plot input photometry and model
        ax = axs[0, 0]
        snr = np.nanpercentile(phot_flux / phot_err, (16, 50, 84))
        ax.annotate(f"SNR (16, 50, 84 percentiles): "
                    f"{snr[0]:.1f}, {snr[1]:.1f}, {snr[2]:.1f}",
                    xy=(0.02, 0.98), xycoords="axes fraction",
                    va="top", fontsize=8)
        uplim = self.config["photometry_upper_limit"]
        uplim = uplim if uplim is not None else False
        lolim = self.config["photometry_lower_limit"]
        lolim = lolim if lolim is not None else False
        ax.errorbar(eff_wl.to_value("AA"), phot_flux, yerr=phot_err,
                    lolims=lolim, uplims=uplim, capsize=2,
                    label="Observed", fmt="s", mec="k", mfc="none",
                    ecolor="k")
        # Plot model
        ax.scatter(eff_wl.to_value("AA"), flux_model, edgecolors="r",
                   fc="none", lw=2, label="Model")
        ax.plot(target_wl.to_value("AA"), full_spec, color="k", alpha=0.4)
        ax.set_ylabel(f"Flux density ({self.config['photometry_flux_unit']})")
        ax.legend(bbox_to_anchor=(0.5, 1.01), loc="lower center", ncols=5,
                  fontsize=8)
        ax.set_ylim(min(flux_model.min(), phot_flux.min()) * 0.8,
                    max(flux_model.max(), phot_flux.max()) * 1.2)
        # If wavelength range is too large, use log scale
        wlmin, wlmax = target_wl[[0, -1]]
        if wlmax / wlmin > 10:
            ax.set_xscale("log")
            ax.set_yscale("log")

        # Normalised filter responses on a twin axis
        twax = ax.twinx()
        for filt in self.config["filter_list"].filters:
            twax.plot(filt.filter_wavelength,
                      filt.filter_resp / filt.filter_resp.max(),
                      label=filt.name)
        twax.legend(fontsize=6, loc="lower right")
        twax.set_ylabel("Filter response")
        min_wl, max_wl = self.config["filter_list"].wavelength_range()
        ax.set_xlim(min_wl.to_value(u.AA) * 0.8, max_wl.to_value(u.AA) * 1.2)

        # Flux density per wavelength unit (model spectrum assumed in uJy)
        ax = axs[1, 0]
        flam = (full_spec * u.Unit("uJy")).to(
            "1e-16 erg / (s cm**2 AA)", u.spectral_density(target_wl))
        ax.plot(target_wl.to_value("AA"), flam, color="k", alpha=0.4)
        ax.set_ylabel("Flux density (1e-16 erg / (s cm**2 AA))")

        # chi2 as function of wavelength
        chi2 = (flux_model - phot_flux) ** 2 / phot_var
        ax = axs[2, 0]
        ax.scatter(eff_wl.to_value("AA"), chi2, c="k")
        ax.grid(visible=True)
        ax.set_ylabel(r"$\chi^2$")
        ax.set_yscale("symlog", linthresh=1.0)
        ax.set_ylim(0, np.nanmax(chi2) * 2)
        ax.set_xlabel("Wavelength (AA)")

        axs[1, 1].axis("off")
        axs[2, 1].axis("off")

        if figname is not None:
            fig.savefig(figname, bbox_inches="tight", dpi=300)
            _log(f"Fit plot saved at: {figname}")
        # Close this specific figure (not merely the "current" one) so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
        return fig
class EquivalentWidthFitModule(BaseModule):
    """Base class for equivalent width fit modules in BESTA."""

    def plot_solution(self, solution: DataBlock, figname=None):
        """Plot an equivalent-width fit solution.

        This placeholder is implemented by concrete equivalent-width modules.
        As written here it does nothing and returns ``None``.
        """
        pass


class GridFitMixin:
    """Mixin class for grid-based fitting modules in BESTA."""

    def _load_callable_from_file(self,
                                 file_path: str | pathlib.Path,
                                 func_name: str,
                                 *,
                                 module_name: Optional[str] = None,
                                 ) -> Callable:
        """
        Load a callable named `func_name` from a Python source file at `file_path`.

        Parameters
        ----------
        file_path
            Path to the .py file (does not need to be importable / on PYTHONPATH).
            ``~`` is expanded and the path is resolved.
        func_name
            Name of the function (or other callable) defined in that file.
        module_name
            Optional module name to assign during loading. If None, a name is
            generated from the filename (``_user_boundary_<stem>``).

        Returns
        -------
        func
            The loaded callable object.

        Raises
        ------
        FileNotFoundError
            If the file does not exist.
        ImportError
            If the file is not a ``.py`` file or the module cannot be loaded.
        AttributeError
            If func_name is not found in the module.
        TypeError
            If the loaded attribute is not callable.

        Notes
        -----
        The file is executed as Python code: only load trusted files.
        """
        file_path = pathlib.Path(file_path).expanduser().resolve()
        if not file_path.exists():
            raise FileNotFoundError(str(file_path))
        if file_path.suffix != ".py":
            raise ImportError(f"Expected a .py file, got: {file_path}")

        # Give the module a deterministic-ish name to help debugging and caching
        if module_name is None:
            module_name = f"_user_boundary_{file_path.stem}"

        spec = importlib.util.spec_from_file_location(module_name, str(file_path))
        if spec is None or spec.loader is None:
            raise ImportError(f"Could not create import spec for: {file_path}")

        module = importlib.util.module_from_spec(spec)
        try:
            spec.loader.exec_module(module)  # type: ignore[attr-defined]
        except Exception as e:
            raise ImportError(f"Error importing {file_path}: {e}") from e

        obj = getattr(module, func_name)  # may raise AttributeError
        if not callable(obj):
            raise TypeError(f"{func_name!r} in {file_path} is not callable (got {type(obj)})")
        return obj

    def prepare_grid_model(self, options):
        """Prepare the model grid.

        Parameters
        ----------
        options : :class:`DataBlock`
            Input options to initialise the model. Reads ``modelGridFile``
            (required), ``boundaryFunctionFile`` (optional, syntax
            ``/path/to/file.py[function_name]``) and ``knn`` (optional,
            defaults to ``int(4 * model_grid.n_targets)``).

        Raises
        ------
        ValueError
            If ``modelGridFile`` is missing or ``boundaryFunctionFile`` does
            not follow the ``path[function]`` syntax.
        FileNotFoundError
            If the model grid file does not exist.
        """
        logger.info("-> Configuring model grid")
        if not options.has_value("modelGridFile"):
            raise ValueError("No input model grid file provided.")

        grid_file = os.path.expandvars(options["modelGridFile"])
        logger.info("Loading model grid from file: %s", grid_file)
        if not os.path.isfile(grid_file):
            raise FileNotFoundError(f"Input model grid file {grid_file} not found.")

        logger.info("Reading model grid... fluxes must be in microJansky / Msun")
        model_grid = ModelGrid.load_auto(grid_file)

        if options.has_value("boundaryFunctionFile"):
            boundary_spec = os.path.expandvars(
                options["boundaryFunctionFile"])
            logger.info("Loading boundary function from file: %s", boundary_spec)
            # Split "path/to/file.py[function_name]" into path and function
            # name. Partition on the *last* "[" so directory names that
            # contain brackets are still handled.
            spec = boundary_spec.strip()
            boundary_file, boundary_func_name = "", ""
            if spec.endswith("]") and "[" in spec:
                boundary_file, _, boundary_func_name = spec[:-1].rpartition("[")
                boundary_file = boundary_file.strip()
                boundary_func_name = boundary_func_name.strip()
            if not boundary_file or not boundary_func_name:
                raise ValueError(
                    "boundaryFunctionFile must have the form "
                    f"'/path/to/file.py[function_name]', got: {boundary_spec!r}")
            boundary_func = self._load_callable_from_file(
                boundary_file, boundary_func_name)
            model_grid.check_boundaries = boundary_func
            logger.info("Applied boundary function to model grid.")

        self.config["model_grid"] = model_grid
        if options.has_value("knn"):
            self.config["knn"] = options["knn"]
        else:
            self.config["knn"] = int(4 * model_grid.n_targets)
        logger.info("-> Configuration done.")


class EmulatorMixin:
    """Mixin class for modules using ML emulators in BESTA."""

    def prepare_emulator(self, options):
        """Prepare the ML emulator.

        Parameters
        ----------
        options : :class:`DataBlock`
            Input options to initialise the model. Reads ``emulatorFile``
            (required).

        Raises
        ------
        ImportError
            If ``joblib`` is not installed.
        ValueError
            If ``emulatorFile`` is not provided.
        FileNotFoundError
            If the emulator file does not exist.

        Notes
        -----
        The ML emulator is expected to be stored in a joblib (.joblib) file.
        ``joblib.load`` unpickles the file, which can execute arbitrary code:
        only load emulator files from trusted sources.
        """
        try:
            import joblib
        except ImportError as err:
            raise ImportError("joblib is required to load ML emulators. "
                              "Please install joblib and try again.") from err

        logger.info("-> Configuring ML emulator")
        if not options.has_value("emulatorFile"):
            raise ValueError("No input emulator file provided.")

        emulator_file = os.path.expandvars(options["emulatorFile"])
        logger.info("Loading ML emulator from file: %s", emulator_file)
        if not os.path.isfile(emulator_file):
            raise FileNotFoundError(f"Input emulator file {emulator_file} not found.")

        logger.info("Reading ML emulator...")
        # NOTE(review): pickle-based deserialisation of a user-supplied path.
        ml_emulator = joblib.load(emulator_file)
        self.config["ml_emulator"] = ml_emulator
        logger.info("-> Configuration done.")