Source code for besta.spectrum

"""
This module contains classes and functions related
to dealing with spectra

"""
from __future__ import annotations

from dataclasses import dataclass, field, replace
from functools import wraps
from typing import Iterable, Sequence, Tuple, Optional

import numpy as np
import scipy
from scipy import ndimage
from scipy.special import legendre
from scipy.signal import find_peaks
from scipy.ndimage import median_filter, label, find_objects
from scipy.interpolate import make_smoothing_spline, interp1d
from scipy.optimize import curve_fit

from astropy.table import Table
from astropy import constants
from astropy import units as u

from besta.logging import get_logger

# Module-level logger obtained from the project's ``besta.logging`` helper.
logger = get_logger(__name__)


def get_legendre_polynomial_array(
    wavelength, order, bounds=None, scale=None, clip_first_zero=True
):
    """
    Compute an array of Legendre polynomials evaluated at normalized wavelengths.

    The wavelengths are mapped linearly onto [-1, 1] using ``bounds`` (values
    outside the bounds are clipped to the interval). The polynomials of degree
    ``0, min_order, min_order + 1, ..., min_order + order - 1`` are evaluated,
    where ``min_order`` is 1 unless ``scale`` is given.

    Parameters
    ----------
    wavelength : numpy.ndarray or astropy.units.Quantity
        Array of wavelength values.
    order : int
        Number of non-constant polynomials to compute. With the default
        ``scale=None`` this is also the highest polynomial degree.
    bounds : tuple, optional
        A tuple specifying the minimum and maximum bounds for normalization
        (bounds[0], bounds[1]). If None, the normalization is based on the
        minimum and maximum of the `wavelength` array.
    scale : float or astropy.units.Quantity, optional
        A maximum scale to probe by the polynomials. If provided, the set of
        polynomials only covers scales smaller than the input value (i.e.
        lower order polynomials are not included):
        ``min_order = round((bounds[1] - bounds[0]) / scale)``, at least 1.
        A bare number is interpreted in the same units as ``bounds``.
    clip_first_zero : bool, optional
        If ``True``, the values of each polynomial of degree >= 2 outside its
        first and last zero are set to 0. This prevents the edges from
        dominating when the order of the polynomial is high. Degree 1 is never
        clipped: its only zero is at the centre of the interval, so clipping
        would erase the whole polynomial.

    Returns
    -------
    numpy.ndarray
        A 2D array where each row corresponds to the values of a Legendre
        polynomial of a given degree, evaluated at the normalized wavelengths.
        The shape of the array is (order + 1, len(wavelength)).
    """
    if bounds is None:
        bounds = wavelength.min(), wavelength.max()
    norm_wl = 2 * (wavelength - bounds[0]) / (bounds[1] - bounds[0]) - 1
    norm_wl = norm_wl.clip(-1, 1)

    if scale is not None:
        span = bounds[1] - bounds[0]
        if isinstance(span, u.Quantity) and not isinstance(scale, u.Quantity):
            # A bare-number scale is expressed in the units of ``bounds``.
            span = span.value
        ratio = span / scale
        if isinstance(ratio, u.Quantity):
            ratio = ratio.decompose().value
        # Degree 0 is always added separately below; starting at >= 1 avoids
        # duplicating the constant term (a degenerate design matrix).
        min_order = max(int(np.round(ratio)), 1)
    else:
        min_order = 1

    if isinstance(norm_wl, u.Quantity):
        norm_wl = norm_wl.decompose().value

    degrees = [0, *range(min_order, min_order + int(order))]
    logger.debug("Pol order == %s", degrees)

    poly_set = []
    for deg in degrees:
        pol = legendre(deg)
        pol_wl = pol(norm_wl)
        # Clip the values on the edges to avoid extremes. Legendre roots are
        # symmetric, so the outermost roots are +/- first_zero.
        if clip_first_zero and deg > 1:
            first_zero = pol.roots.real.min()
            pol_wl[(norm_wl < first_zero) | (norm_wl > -first_zero)] = 0
        poly_set.append(pol_wl)
    return np.array(poly_set)
def legendre_decorator(make_observable_mthd):
    """Include multiplicative Legendre polynomials during a fit.

    The decorated method is expected to be called as ``method(self, block,
    ...)``. When ``self.config`` holds a ``"legendre_pol"`` array (one row per
    polynomial), the method output is multiplied by
    ``sum_i c_i * legendre_pol[i]``. Here ``c_0 = 1`` and
    ``c_i = block["legendre", "legendre_<i>"]`` for ``i >= 1``. If the method
    returns a tuple, only its first element is scaled, and the first two
    elements are returned.
    """

    @wraps(make_observable_mthd)
    def wrapper(*args, **kwargs):
        model_config = args[0].config
        if "legendre_pol" not in model_config:
            return make_observable_mthd(*args, **kwargs)

        pol_array = model_config["legendre_pol"]
        datablock = args[1]
        # The zeroth coefficient is fixed to unity; the rest come from the block.
        extra_coeffs = [
            datablock["legendre", f"legendre_{idx}"]
            for idx in range(1, pol_array.shape[0])
        ]
        coefficients = np.array([1.0] + extra_coeffs)

        result = make_observable_mthd(*args, **kwargs)
        correction = np.sum(pol_array * coefficients[:, np.newaxis], axis=0)
        if not isinstance(result, tuple):
            return result * correction
        return (result[0] * correction, result[1])

    return wrapper
### Masking # Telluric regions
@dataclass(frozen=True)
class TelluricBand:
    """Named wavelength interval affected by telluric absorption.

    Attributes
    ----------
    name : str
        Human-readable label of the band.
    wmin : float
        Lower edge of the band, in Angstrom.
    wmax : float
        Upper edge of the band, in Angstrom.
    """

    name: str
    wmin: float  # lower edge [Angstrom]
    wmax: float  # upper edge [Angstrom]
# Default telluric bands as (name, wmin, wmax) with edges in Angstrom.
# Entries left commented out are intentionally disabled.
DEFAULT_TELLURIC_BANDS_AA: Tuple[TelluricBand, ...] = tuple(
    TelluricBand(band_name, lower_edge, upper_edge)
    for band_name, lower_edge, upper_edge in (
        # --- Optical ---
        # ("O3 Chappuis (broad)", 5000.0, 7000.0),
        ("O2 gamma (~0.63 um)", 6270.0, 6340.0),
        # ("H2O (~0.65 um)", 6450.0, 6600.0),
        ("O2 B-band", 6860.0, 6950.0),
        ("H2O (~0.72 um)", 7160.0, 7400.0),
        ("O2 A-band", 7580.0, 7700.0),
        ("H2O (~0.82 um)", 8100.0, 8400.0),
        ("H2O (~0.93 um)", 9000.0, 9900.0),
        # --- Near-IR ---
        ("H2O (~1.13 um)", 11000.0, 11700.0),
        ("H2O (~1.40 um)", 13400.0, 15000.0),
        ("H2O/CO2 (~1.90 um)", 18000.0, 20000.0),
    )
)
def mask_telluric_regions(
    wavelength: np.ndarray,
    *,
    weight: Optional[np.ndarray] = None,
    bands: Optional[Sequence[TelluricBand | Tuple[float, float]]] = None,
    pad: float = 0.0,
    redshift: float = 0.0,
    return_mask: bool = False,
) -> np.ndarray | tuple[np.ndarray, np.ndarray, list[TelluricBand]]:
    """
    Mask regions affected by telluric absorption by setting weights to 0.

    Parameters
    ----------
    wavelength : ndarray
        1D array of wavelength values in Angstrom.
    weight : ndarray, optional
        1D array of pixel weights with the same length as ``wavelength``.
        If None, every pixel starts with a weight of 1.
    bands : sequence of TelluricBand or (wmin, wmax), optional
        Telluric band edges in the same units as ``wavelength``. Plain
        ``(wmin, wmax)`` pairs are accepted and wrapped into ``TelluricBand``
        objects. If None, uses ``DEFAULT_TELLURIC_BANDS_AA``.
    pad : float, optional
        Extra padding (same units as ``wavelength``) added to both sides of
        every band after the redshift shift:
        [wmin * (1 + redshift) - pad, wmax * (1 + redshift) + pad].
    redshift : float, optional
        Band edges are multiplied by ``1 + redshift``. This is the same
        convention the emission-line masking functions use for line centres.
        The default of 0 leaves the bands unshifted.
    return_mask : bool, optional
        If True, also return the boolean mask of where weights were set to 0
        and the list of bands that overlapped the wavelength array.

    Returns
    -------
    new_weight : ndarray
        Copy of input weight with telluric regions set to 0.
    mask : ndarray of bool, optional
        Only if ``return_mask`` is True: True where weights were set to 0.
    bands_used : list of TelluricBand, optional
        Only if ``return_mask`` is True: bands that masked at least one pixel.

    Raises
    ------
    ValueError
        If the inputs are not 1D or have different lengths.
    """
    logger.info("Masking telluric regions with pad=%.1f Angstrom", pad)
    if redshift != 0.0:
        logger.warning(
            "Redshift is non-zero (z=%.3f); telluric bands will be shifted accordingly. Make sure this is intended.",
            redshift,
        )
    w = np.asanyarray(wavelength)
    wt = np.ones_like(w) if weight is None else np.asanyarray(weight)
    if w.ndim != 1 or wt.ndim != 1:
        raise ValueError("All inputs must be 1D arrays.")
    if w.size != wt.size:
        raise ValueError("All inputs must have the same length.")

    # Band edges are taken to be in the same units as the wavelength array.
    if bands is None:
        logger.info(
            "Using default telluric bands with %d bands.",
            len(DEFAULT_TELLURIC_BANDS_AA),
        )
        bands = DEFAULT_TELLURIC_BANDS_AA

    shift = 1.0 + redshift
    tell_mask = np.zeros_like(w, dtype=bool)
    bands_used = []
    for band in bands:
        if not isinstance(band, TelluricBand):
            # Support the documented plain (wmin, wmax) pairs.
            wmin, wmax = band
            band = TelluricBand(f"custom {wmin}-{wmax}", float(wmin), float(wmax))
        lower = band.wmin * shift - pad
        upper = band.wmax * shift + pad
        overlap = (w >= lower) & (w <= upper)
        if overlap.any():
            tell_mask |= overlap
            bands_used.append(band)
    logger.info(
        "Masked %d pixels in %d telluric bands", tell_mask.sum(), len(bands_used)
    )

    new_weight = np.array(wt, copy=True)
    new_weight[tell_mask] = 0.0
    if return_mask:
        return new_weight, tell_mask, bands_used
    return new_weight
# Emission lines
@dataclass(frozen=True)
class EmissionLine:
    """Immutable definition of a single emission line.

    Attributes
    ----------
    name : str
        Label of the line (e.g. ``"Ha"``).
    rest_wavelength : float
        Rest-frame wavelength of the line centre.
    default_half_width : float
        Half-width used when masking the line; 10.0 unless given.
    flux, flux_error : float or None
        Measured line flux and its uncertainty, when available.
    rest_wavelength_error : float or None
        Uncertainty of the line centre, when available.
    flag : int
        Quality flag; 0 unless given.
    metadata : dict
        Free-form extra information; every instance gets its own dict.
    """

    name: str
    rest_wavelength: float
    default_half_width: float = field(default=10.0)
    flux: Optional[float] = field(default=None)
    flux_error: Optional[float] = field(default=None)
    rest_wavelength_error: Optional[float] = field(default=None)
    flag: int = field(default=0)
    metadata: dict = field(default_factory=dict)
class EmissionLineList:
    """Collection of emission lines with helper methods.

    Parameters
    ----------
    lines : iterable of EmissionLine
        Lines to store; the iterable is consumed into a list.
    """

    def __init__(self, lines: Iterable[EmissionLine]):
        self.lines = list(lines)
        # Number of lines, fixed at construction time (returned by ``__len__``).
        self.size = len(self.lines)

    def __len__(self):
        return self.size

    def __getitem__(self, item):
        return self.lines[item]

    def __iter__(self):
        """Allow iteration over the lines in the list."""
        return iter(self.lines)

    @property
    def rest_wavelengths(self) -> np.ndarray:
        """Array of rest wavelengths of the lines."""
        return np.array([line.rest_wavelength for line in self.lines])

    @property
    def names(self) -> list[str]:
        """List of line names."""
        return [line.name for line in self.lines]

    @property
    def default_half_widths(self) -> np.ndarray:
        """Array of default half-widths of the lines."""
        return np.array([line.default_half_width for line in self.lines])

    @property
    def fluxes(self) -> np.ndarray:
        """Array of measured line fluxes (NaN where unknown)."""
        return np.array(
            [np.nan if line.flux is None else line.flux for line in self.lines]
        )

    @property
    def flux_errors(self) -> np.ndarray:
        """Array of measured line-flux uncertainties (NaN where unknown)."""
        return np.array(
            [
                np.nan if line.flux_error is None else line.flux_error
                for line in self.lines
            ]
        )

    @property
    def rest_wavelength_errors(self) -> np.ndarray:
        """Array of line-center uncertainties (NaN where unknown)."""
        return np.array(
            [
                np.nan
                if line.rest_wavelength_error is None
                else line.rest_wavelength_error
                for line in self.lines
            ]
        )

    @property
    def flags(self) -> np.ndarray:
        """Array of line quality flags."""
        return np.array([line.flag for line in self.lines], dtype=int)

    def get_observed_wavelengths(self, redshift: float) -> np.ndarray:
        """Compute observed wavelengths of the lines given a redshift.

        Parameters
        ----------
        redshift : float
            Redshift to apply to the rest wavelengths.

        Returns
        -------
        np.ndarray
            ``rest_wavelength * (1 + redshift)`` for every line.
        """
        return np.array([line.rest_wavelength * (1 + redshift) for line in self.lines])

    def crossmatch_line_list(
        self, line_list: EmissionLineList, atol: float = 1e-3
    ) -> EmissionLineList:
        """Return the lines of this list that also appear in ``line_list``.

        Two lines match when ``np.isclose(a, b, atol=atol)`` holds for their
        rest wavelengths (NumPy's default relative tolerance also applies).
        The order of this list is preserved.

        Parameters
        ----------
        line_list : EmissionLineList
            Lines to match against.
        atol : float, optional
            Absolute tolerance passed to ``np.isclose`` (default 1e-3).
        """
        matched_lines = [
            line
            for line in self.lines
            if any(
                np.isclose(line.rest_wavelength, other_line.rest_wavelength, atol=atol)
                for other_line in line_list
            )
        ]
        return EmissionLineList(matched_lines)

    def with_names(self, names: Sequence[str]) -> EmissionLineList:
        """Return a copy of the line list with updated names.

        Raises
        ------
        ValueError
            If ``names`` does not have one entry per line.
        """
        if len(names) != len(self.lines):
            raise ValueError("Number of names must match number of lines.")
        return EmissionLineList(
            [replace(line, name=name) for line, name in zip(self.lines, names)]
        )

    def get_closest_line(
        self, wavelength: float, *, redshift: float = 0.0
    ) -> Optional[EmissionLine]:
        """Return the line with observed wavelength closest to the input value.

        Returns None when the list is empty.
        """
        if not self.lines:
            return None
        observed_wavelengths = self.get_observed_wavelengths(redshift)
        index = int(np.argmin(np.abs(observed_wavelengths - wavelength)))
        return self.lines[index]

    def get_line_by_observed_wavelength(
        self, wavelength: float, redshift: float, tol: float = 5.0
    ) -> Optional[EmissionLine]:
        """Match an input line to the list of emission lines.

        Among the lines whose observed wavelength lies within ``tol`` of
        ``wavelength``, the closest one is returned. This matters for close
        doublets such as [OII]3726/3729. On exact ties the earlier line wins.

        Parameters
        ----------
        wavelength : float
            Observed wavelength to match.
        redshift : float
            Redshift to apply to the rest wavelengths.
        tol : float, optional
            Tolerance for matching, by default 5.0.

        Returns
        -------
        Optional[EmissionLine]
            The matched emission line, or None if no match is found.
        """
        best_line = None
        best_offset = np.inf
        for line in self.lines:
            offset = abs(line.rest_wavelength * (1 + redshift) - wavelength)
            if offset <= tol and offset < best_offset:
                best_line, best_offset = line, offset
        return best_line

    def to_mask(self, wavelength, redshift=0.0) -> np.ndarray:
        """Create a boolean mask for the input wavelength array where the lines are located.

        Each line masks ``[obs - default_half_width, obs + default_half_width]``
        with ``obs = rest_wavelength * (1 + redshift)``.

        Parameters
        ----------
        wavelength : np.ndarray
            Array of wavelength values to create the mask for.
        redshift : float, optional
            Redshift to apply to the rest wavelengths, by default 0.0.

        Returns
        -------
        np.ndarray
            Boolean array of the same shape as `wavelength` where True
            indicates the presence of a line.
        """
        mask = np.zeros_like(wavelength, dtype=bool)
        for line in self.lines:
            obs_wl = line.rest_wavelength * (1 + redshift)
            mask |= (wavelength >= obs_wl - line.default_half_width) & (
                wavelength <= obs_wl + line.default_half_width
            )
        return mask

    def to_table(self, filename: Optional[str] = None, **table_kwargs) -> Table:
        """Convert the emission lines to a table, optionally saving it.

        Optional columns (flux, flux_error, rest_wavelength_error, flag) are
        only included when at least one line carries a value for them.

        Parameters
        ----------
        filename : str, optional
            If provided, the table is written to this file (overwriting it).
        **table_kwargs
            Extra keyword arguments for ``Table.write``.

        Returns
        -------
        astropy.table.Table
            Table containing the emission line information.
        """
        table = Table()
        table["name"] = self.names
        table["rest_wavelength"] = self.rest_wavelengths
        table["default_half_width"] = self.default_half_widths
        if np.isfinite(self.fluxes).any():
            table["flux"] = self.fluxes
        if np.isfinite(self.flux_errors).any():
            table["flux_error"] = self.flux_errors
        if np.isfinite(self.rest_wavelength_errors).any():
            table["rest_wavelength_error"] = self.rest_wavelength_errors
        if np.any(self.flags != 0):
            table["flag"] = self.flags
        if filename is not None:
            table.write(filename, overwrite=True, **table_kwargs)
        return table

    @classmethod
    def from_table(cls, table: Table) -> EmissionLineList:
        """Create an EmissionLineList from an astropy Table.

        Parameters
        ----------
        table : astropy.table.Table
            Table containing the emission line information. Columns are looked
            up by name (name, rest_wavelength, default_half_width); when a
            named column is missing, the first/second/third column is used
            instead. The half-width defaults to 10.0 if there is no third
            column. Optional columns: flux, flux_error, rest_wavelength_error,
            flag.

        Returns
        -------
        EmissionLineList
            An instance of EmissionLineList created from the input table.
        """
        names = table["name"] if "name" in table.colnames else table[table.colnames[0]]
        rest_wavelengths = (
            table["rest_wavelength"]
            if "rest_wavelength" in table.colnames
            else table[table.colnames[1]]
        )
        if "default_half_width" in table.colnames:
            default_half_widths = table["default_half_width"]
        elif len(table.colnames) >= 3:
            default_half_widths = table[table.colnames[2]]
        else:
            default_half_widths = np.full(len(rest_wavelengths), 10.0)
        fluxes = (
            table["flux"]
            if "flux" in table.colnames
            else np.full(len(rest_wavelengths), np.nan)
        )
        flux_errors = (
            table["flux_error"]
            if "flux_error" in table.colnames
            else np.full(len(rest_wavelengths), np.nan)
        )
        rest_wavelength_errors = (
            table["rest_wavelength_error"]
            if "rest_wavelength_error" in table.colnames
            else np.full(len(rest_wavelengths), np.nan)
        )
        flags = (
            table["flag"]
            if "flag" in table.colnames
            else np.zeros(len(rest_wavelengths), dtype=int)
        )

        lines = [
            EmissionLine(
                name=str(name),
                rest_wavelength=float(rest_wl),
                default_half_width=float(half_width),
                flux=None if not np.isfinite(flux) else float(flux),
                flux_error=None if not np.isfinite(flux_err) else float(flux_err),
                rest_wavelength_error=None
                if not np.isfinite(rest_err)
                else float(rest_err),
                flag=int(flag),
            )
            for name, rest_wl, half_width, flux, flux_err, rest_err, flag in zip(
                names,
                rest_wavelengths,
                default_half_widths,
                fluxes,
                flux_errors,
                rest_wavelength_errors,
                flags,
            )
        ]
        return cls(lines)

    @classmethod
    def from_file(cls, filename: str, **table_kwargs) -> EmissionLineList:
        """Load emission lines from a table file (see :meth:`from_table`).

        Parameters
        ----------
        filename : str
            Path to the file containing the emission line data.
        **table_kwargs
            Additional keyword arguments to pass to `astropy.table.Table.read`.

        Returns
        -------
        EmissionLineList
            An instance of `EmissionLineList` containing the loaded emission lines.
        """
        data = Table.read(filename, **table_kwargs)
        return cls.from_table(data)
def get_default_emission_lines():
    """Build the default list of optical emission lines.

    Returns
    -------
    EmissionLineList
        Common rest-frame galaxy emission lines. Each entry is
        (name, rest wavelength, nominal masking half-width) in Angstrom.
    """
    catalogue = (
        ("[OII]3726", 3726.03, 8.0),
        ("[OII]3729", 3728.82, 8.0),
        ("Hd", 4101.74, 10.0),
        ("Hg", 4340.47, 10.0),
        ("[OIII]4363", 4363.21, 8.0),
        ("Hb", 4861.33, 12.0),
        ("[OIII]4959", 4958.91, 10.0),
        ("[OIII]5007", 5006.84, 10.0),
        ("[NI]5200", 5199.0, 10.0),
        ("HeI5876", 5875.62, 12.0),
        ("[OI]6300", 6300.30, 10.0),
        ("[OI]6364", 6363.78, 10.0),
        ("[NII]6548", 6548.05, 12.0),
        ("Ha", 6562.80, 14.0),
        ("[NII]6583", 6583.45, 12.0),
        ("[SII]6716", 6716.44, 12.0),
        ("[SII]6731", 6730.82, 12.0),
        ("[ArIII]7136", 7135.79, 12.0),
    )
    return EmissionLineList(
        EmissionLine(label, rest_wl, half_width)
        for label, rest_wl, half_width in catalogue
    )
def get_default_sky_emission_lines():
    """Build the default list of sky emission lines.

    Returns
    -------
    EmissionLineList
        Common sky emission lines. Each entry is
        (name, wavelength, nominal masking half-width) in Angstrom.
    """
    catalogue = (
        ("NaI 5890", 5890.0, 10.0),
        ("NaI 5896", 5896.0, 10.0),
        ("OI 5577", 5577.34, 10.0),
        ("OI 6300", 6300.30, 10.0),
        ("OI 6364", 6363.78, 10.0),
    )
    return EmissionLineList(
        EmissionLine(label, wavelength, half_width)
        for label, wavelength, half_width in catalogue
    )
def _parse_lines_param_decorator(func):
    """Decorator that normalises the ``lines`` keyword argument to an EmissionLineList.

    Accepted values for ``lines`` (keyword only):

    - None or an EmissionLineList: passed through unchanged;
    - a filename (``str`` or ``os.PathLike``): loaded with
      ``EmissionLineList.from_file``;
    - any iterable of EmissionLine objects (lists, tuples, generators):
      wrapped into an EmissionLineList.

    Anything else raises ValueError.
    """
    import os

    @wraps(func)
    def wrapper(*args, **kwargs):
        lines_input = kwargs.get("lines")
        if lines_input is None or isinstance(lines_input, EmissionLineList):
            return func(*args, **kwargs)

        if isinstance(lines_input, (str, os.PathLike)):
            # Assume it is a filename; load the line list from the file
            kwargs["lines"] = EmissionLineList.from_file(lines_input)
        elif isinstance(lines_input, Iterable):
            # Materialise once: a generator would otherwise be exhausted by
            # the type check below and yield an empty line list.
            candidates = list(lines_input)
            if not all(isinstance(line, EmissionLine) for line in candidates):
                raise ValueError(
                    "Invalid format for 'lines' parameter. Must be a filename or an iterable of EmissionLine objects."
                )
            kwargs["lines"] = EmissionLineList(candidates)
        else:
            raise ValueError(
                "Invalid format for 'lines' parameter. Must be a filename or an iterable of EmissionLine objects."
            )
        return func(*args, **kwargs)

    return wrapper
@_parse_lines_param_decorator
def mask_strong_emission_lines(
    wavelength: np.ndarray,
    flux: np.ndarray,
    uncertainty: np.ndarray,
    weight: np.ndarray,
    *,
    redshift: float = 0.0,
    lines: Optional[Sequence[EmissionLine]] = None,
    # detection controls
    snr_threshold: float = 5.0,
    min_continuum_snr: float = 1.0,
    # local continuum estimation
    cont_half_window: int = 75,
    cont_sigma_clip: float = 4.0,
    # mask width controls
    half_width: Optional[float] = None,
    width_mode: str = "fixed",  # "fixed" | "fwhm_pixels"
    fwhm_pixels: float = 3.0,
    max_half_width: Optional[float] = None,
    pad: float = 0.0,
    # bookkeeping
    return_mask: bool = False,
    return_lines_masked: bool = False,
) -> (
    np.ndarray
    | tuple[np.ndarray, np.ndarray]
    | tuple[np.ndarray, np.ndarray, list[EmissionLine]]
):
    """
    Identify strong emission lines and mask them by setting weight=0.

    This is a robust, low-assumption approach:
    1) For each expected line center (rest -> observed using ``redshift``),
       estimate a local continuum as a sigma-clipped median in a pixel window
       around the line.
    2) Compute line "excess" = flux - continuum and its S/N using uncertainty.
    3) If the peak S/N within the line window reaches `snr_threshold`,
       mask the whole line window.

    Parameters
    ----------
    wavelength, flux, uncertainty, weight : ndarray
        1D arrays of the same length. They need not be sorted by wavelength:
        they are sorted internally and the outputs are returned in the
        original order.
    redshift : float, optional
        Redshift used to place the line centers:
        lambda_obs = lambda_rest * (1 + redshift).
    lines : EmissionLineList, iterable of EmissionLine or filename, optional
        Line list; non-EmissionLineList inputs are converted by
        ``_parse_lines_param_decorator``. If None, uses
        ``get_default_emission_lines()``.
    snr_threshold : float, optional
        Minimum peak (flux-continuum)/sigma within the line window to mask.
    min_continuum_snr : float, optional
        Intended minimum of median(abs(continuum)/sigma) in the continuum
        window. NOTE: this value is currently computed but NOT enforced; the
        masking decision depends on ``snr_threshold`` only.
    cont_half_window : int, optional
        Half-window size in pixels for local continuum estimation around each line.
    cont_sigma_clip : float, optional
        Pixels whose deviation from the window median exceeds
        ``cont_sigma_clip`` times their own uncertainty are rejected before
        taking the final median.
    half_width : float, optional
        Half-width of the masked region in wavelength units. If None, uses
        each line's `default_half_width`.
    width_mode : {"fixed","fwhm_pixels"}, optional
        - "fixed": uses `half_width` (or per-line default) in wavelength units.
        - "fwhm_pixels": half-width = 1.5 * ``fwhm_pixels`` * local dispersion.
          The dispersion is the median delta-lambda of the pixels around the
          line, falling back to the global median. If it cannot be estimated,
          the "fixed" width is used instead.
        Any other value raises ValueError once a line inside the wavelength
        coverage is processed.
    fwhm_pixels : float, optional
        Only used if width_mode="fwhm_pixels". Instrumental/profile FWHM in pixels.
    max_half_width : float, optional
        Cap on the half-width (in wavelength units), applied after ``pad``.
    pad : float, optional
        Extra padding added to half-width (same wavelength units).
    return_mask : bool, optional
        If True, also return boolean mask of where weights were set to 0.
    return_lines_masked : bool, optional
        If True *and* ``return_mask`` is True, also return the list of lines
        that were actually masked. Ignored when ``return_mask`` is False.

    Returns
    -------
    new_weight : ndarray
        Copy of weight with emission-line regions set to 0.
    mask : ndarray of bool, optional
        True where masked (only when ``return_mask`` is True).
    lines_masked : list of EmissionLine, optional
        Lines that triggered masking (only when both ``return_mask`` and
        ``return_lines_masked`` are True).

    Raises
    ------
    ValueError
        If the inputs are not 1D, have different lengths, or ``width_mode``
        is not supported.

    Notes
    -----
    - Inputs with fewer than 5 pixels are returned unmasked.
    - This targets *strong, narrow-ish* features near known lines. It will not
      detect arbitrary lines at unknown wavelengths unless you expand the line list.
    - If you already have a model continuum, you can replace the continuum
      estimation with that for even more robustness.
    """
    logger.info(
        "Masking strong emission lines with redshift z=%.3f and snr_threshold=%.1f",
        redshift,
        snr_threshold,
    )
    w = np.asanyarray(wavelength)
    f = np.asanyarray(flux)
    s = np.asanyarray(uncertainty)
    wt = np.asanyarray(weight)

    if any(arr.ndim != 1 for arr in (w, f, s, wt)):
        raise ValueError("All inputs must be 1D arrays.")
    if not (w.size == f.size == s.size == wt.size):
        raise ValueError("All inputs must have the same length.")
    if w.size < 5:
        logger.warning(
            "Input arrays have less than 5 pixels; skipping emission line masking."
        )
        new_weight = np.array(wt, copy=True)
        if return_mask and return_lines_masked:
            return new_weight, np.zeros_like(w, dtype=bool), []
        if return_mask:
            return new_weight, np.zeros_like(w, dtype=bool)
        return new_weight

    # Ensure sorted wavelength (common in spectra). If not sorted, do it safely.
    # Keep mapping so we still return an array matching the original order.
    order = np.argsort(w)
    inv_order = np.empty_like(order)
    inv_order[order] = np.arange(order.size)

    w_s = w[order]
    f_s = f[order]
    s_s = s[order]
    wt_s = wt[order]

    if lines is None:
        lines = get_default_emission_lines()
        logger.info("Using default emission line list with %d lines.", lines.size)

    # Precompute pixel dispersion (delta-lambda), robust median for local conversion
    dw = np.diff(w_s)
    # Replace non-positive steps (if duplicates) with nan for robust median
    dw = np.where(dw > 0, dw, np.nan)

    masked = np.zeros_like(w_s, dtype=bool)
    lines_masked: list[EmissionLine] = []

    # Helper: continuum estimate around a given index window
    def _local_continuum(i0: int, i1: int) -> float:
        # simple robust median with sigma-clip on residuals
        seg_f = f_s[i0:i1]
        seg_s = s_s[i0:i1]
        seg_wt = wt_s[i0:i1]
        good = (seg_wt > 0) & np.isfinite(seg_f) & np.isfinite(seg_s) & (seg_s > 0)
        if good.sum() < 5:
            return float(np.nanmedian(seg_f))  # fallback (ignores weights)
        x = seg_f[good]
        med = np.median(x)
        # clip using the per-pixel uncertainties (already > 0 thanks to ``good``)
        sig = seg_s[good]
        # avoid division by zero (defensive; ``good`` already excludes sig <= 0)
        sig = np.where(
            sig > 0, sig, np.nanmedian(sig[sig > 0]) if np.any(sig > 0) else 1.0
        )
        resid = x - med
        keep = np.abs(resid) <= cont_sigma_clip * sig
        if keep.sum() < 5:
            return float(med)
        return float(np.median(x[keep]))

    # Core loop over lines
    for line in lines:
        lam0 = line.rest_wavelength * (1.0 + redshift)

        # skip if out of wavelength coverage
        if lam0 < w_s[0] or lam0 > w_s[-1]:
            continue

        # Index of the first pixel with wavelength >= lam0, clamped to the array
        j = int(np.searchsorted(w_s, lam0))
        j = max(0, min(j, w_s.size - 1))

        # Define a pixel window for continuum around the line
        i0 = max(0, j - cont_half_window)
        i1 = min(w_s.size, j + cont_half_window + 1)

        cont = _local_continuum(i0, i1)
        if not np.isfinite(cont):
            continue

        # Require continuum not to be pure noise (helps avoid false triggers)
        # NOTE(review): this check currently has no effect (the branch is a
        # ``pass``), so ``min_continuum_snr`` is not enforced.
        seg_s = s_s[i0:i1]
        seg_wt = wt_s[i0:i1]
        good = (seg_wt > 0) & np.isfinite(seg_s) & (seg_s > 0)
        if good.sum() >= 5:
            cont_snr = np.median(np.abs(cont) / seg_s[good])
            if np.isfinite(cont_snr) and cont_snr < min_continuum_snr:
                # still allow masking if the line is extremely significant;
                # we keep going but require higher SNR implicitly via snr_threshold.
                pass

        # Determine mask half-width (in wavelength units)
        if width_mode.lower() == "fixed":
            hw = float(line.default_half_width if half_width is None else half_width)
        elif width_mode.lower() == "fwhm_pixels":
            # convert pixels -> wavelength using local dispersion near j
            k0 = max(0, j - 5)
            k1 = min(dw.size, j + 5)
            local_dw = np.nanmedian(dw[k0:k1])
            if not np.isfinite(local_dw) or local_dw <= 0:
                local_dw = np.nanmedian(dw)
            if not np.isfinite(local_dw) or local_dw <= 0:
                # cannot estimate dispersion; fallback to fixed default
                hw = float(
                    line.default_half_width if half_width is None else half_width
                )
            else:
                # mask ~ +/- 1.5*FWHM (covers core+wings for most cases)
                hw = 1.5 * float(fwhm_pixels) * float(local_dw)
        else:
            raise ValueError(
                f"Unsupported width_mode={width_mode!r}. Use 'fixed' or 'fwhm_pixels'."
            )

        # Padding is added first, then the optional cap, then clamp at zero.
        hw += float(pad)
        if max_half_width is not None:
            hw = min(hw, float(max_half_width))
        hw = max(hw, 0.0)

        # Evaluate peak SNR in the line window (use continuum-subtracted flux)
        # Build index window from wavelength bounds
        a = lam0 - hw
        b = lam0 + hw
        k0 = int(np.searchsorted(w_s, a, side="left"))
        k1 = int(np.searchsorted(w_s, b, side="right"))
        if k1 - k0 < 3:
            continue

        seg_f = f_s[k0:k1]
        seg_s = s_s[k0:k1]
        seg_wt = wt_s[k0:k1]
        good = (seg_wt > 0) & np.isfinite(seg_f) & np.isfinite(seg_s) & (seg_s > 0)
        if good.sum() < 3:
            continue

        excess = seg_f[good] - cont
        snr = excess / seg_s[good]
        peak_snr = float(np.nanmax(snr)) if snr.size else -np.inf

        if np.isfinite(peak_snr) and peak_snr >= snr_threshold:
            masked[k0:k1] = True
            lines_masked.append(line)

    new_wt_s = np.array(wt_s, copy=True)
    new_wt_s[masked] = 0.0

    # Restore original wavelength order
    new_weight = new_wt_s[inv_order]
    mask = masked[inv_order]
    logger.info("Masked %d pixels in %d lines", mask.sum(), len(lines_masked))
    logger.debug("Lines used: %s", ", ".join(line.name for line in lines_masked))
    if return_mask and return_lines_masked:
        return new_weight, mask, lines_masked
    if return_mask:
        return new_weight, mask
    return new_weight
def mask_sky_emission_lines(
    *args, **kwargs
) -> np.ndarray | tuple[np.ndarray, np.ndarray, list[EmissionLine]]:
    """
    Specialized wrapper around mask_strong_emission_lines to target sky lines.

    This is a robust, low-assumption approach:

    1. For each expected line center (rest to observed using redshift), estimate
       a local continuum with a running median in a window around the line.
    2. Compute line excess as ``flux - continuum`` and its S/N using uncertainty.
    3. If peak S/N within the line window exceeds ``snr_threshold``, mask the
       line region.

    Parameters
    ----------
    See :func:`mask_strong_emission_lines` for details. The only difference is
    that if ``lines`` is None, a default set of common sky emission lines is
    used. A warning is logged only when a non-zero ``redshift`` is explicitly
    passed, because it shifts the sky lines by ``1 + redshift``.

    Notes
    -----
    - This targets *strong, narrow-ish* features near known lines. It will not
      detect arbitrary lines at unknown wavelengths unless you expand the line list.
    - The default line list is focused on common sky lines in the optical/NIR, but
      you can provide your own list for other regimes.
    """
    if kwargs.get("lines") is None:
        lines = get_default_sky_emission_lines()
        logger.info("Using default sky emission line list with %d lines.", lines.size)
        kwargs["lines"] = lines
    # An omitted redshift means 0.0 (the default of mask_strong_emission_lines);
    # using kwargs.get("redshift") alone would yield None and a bogus warning.
    redshift = kwargs.get("redshift", 0.0)
    if redshift != 0.0:
        logger.warning(
            "Redshift is non-zero (z=%.3f); sky lines will be shifted accordingly. Make sure this is intended.",
            redshift,
        )
    return mask_strong_emission_lines(*args, **kwargs)
def estimate_continuum(
    wl, flux, err, weights=None, use_log=True, knot_spacing=100.0, sigma_clip=3.0
):
    """Estimate continuum using a robust spline fit to the (log-)flux.

    The wavelength range is split into contiguous segments of width
    ``knot_spacing``, with one knot at the centre of each segment. The last
    segment extends to the final pixel. For each segment:

    1. a median / NMAD estimate is used to sigma-clip outlying pixels;
    2. the inverse-variance weighted mean of the remaining pixels gives the
       knot value, and their weighted scatter gives the knot variance.

    The knot values are then fitted with a smoothing spline (5 or more good
    knots) or a weighted straight line (2-4 good knots). With fewer than
    2 good knots, a flat continuum at the median flux is returned.

    Parameters
    ----------
    wl : ndarray
        Wavelength array (1D), sorted in ascending order.
    flux : ndarray
        Flux array (1D, same length as wl).
    err : ndarray
        Uncertainty array (1D, same length as wl). Pixels with non-positive
        or non-finite uncertainty are ignored.
    weights : ndarray, optional
        Weights for each pixel (1D, same length as wl). If None, all pixels are
        equally weighted. They are multiplied by the inverse variance.
    use_log : bool, optional
        Whether to fit the log of the flux (True) or the flux itself (False).
        Fitting log-flux is more robust to outliers and multiplicative
        features; pixels with zero/negative flux are then ignored.
    knot_spacing : float, optional
        Spacing between spline knots in the same units as wl (e.g., Angstrom).
    sigma_clip : float, optional
        Sigma-clipping threshold (in units of the per-segment NMAD) for
        outlier rejection in the (log-)flux residuals.

    Returns
    -------
    continuum : ndarray
        Estimated continuum flux at each wavelength.
    continuum_err : ndarray
        Uncertainty of the continuum, derived from the weighted pixel scatter
        around each knot (propagated through ``exp`` when ``use_log``). In
        the flat fallback it is set equal to the continuum level.
    """
    if weights is None:
        weights = np.ones_like(wl)

    knots = np.arange(wl.min() + knot_spacing / 2, wl.max(), knot_spacing)
    if knots.size == 0:
        # Range narrower than half a knot spacing: use a single central knot.
        knots = np.array([0.5 * (wl.min() + wl.max())])

    # Segment i covers [knots[i] - knot_spacing/2, knots[i] + knot_spacing/2),
    # so each knot sits at the centre of the pixels it summarises. The first
    # segment starts at pixel 0 and the last one includes the final pixel.
    edges_idx = np.searchsorted(wl, knots - knot_spacing / 2)
    edges_idx[0] = 0
    edges_idx = np.append(edges_idx, len(wl))
    segments = [slice(lo, hi) for lo, hi in zip(edges_idx[:-1], edges_idx[1:])]

    with np.errstate(divide="ignore", invalid="ignore"):
        if use_log:
            log_flux = np.where(flux > 0, np.log(flux), np.nan)
            log_flux_err = np.where(flux > 0, err / flux, np.nan)
        else:
            log_flux = flux
            log_flux_err = err
        # Zero weight for unusable pixels; in particular err <= 0 would
        # otherwise produce infinite weights and NaN knot values.
        usable = (
            np.isfinite(log_flux) & np.isfinite(log_flux_err) & (log_flux_err > 0)
        )
        weights = np.where(usable, weights / log_flux_err**2, 0.0)

    # First pass: robust median / NMAD per segment, used for sigma clipping.
    knot_flux = np.full(knots.size, np.nan)
    knot_flux_nmad = np.full(knots.size, np.nan)
    for i, seg in enumerate(segments):
        good = weights[seg] > 0
        if np.any(good):
            values = log_flux[seg][good]
            knot_flux[i] = np.nanmedian(values)
            knot_flux_nmad[i] = 1.48 * np.nanmedian(np.abs(values - knot_flux[i]))

    log_flux_model = np.interp(wl, knots, knot_flux)
    log_flux_nmad_model = np.interp(wl, knots, knot_flux_nmad)
    residuals = log_flux - log_flux_model
    weights[np.abs(residuals) > sigma_clip * log_flux_nmad_model] = 0.0

    # Second pass: weighted mean and weighted scatter of the surviving pixels.
    knot_flux = np.full(knots.size, np.nan)
    knot_flux_var = np.full(knots.size, np.nan)
    knot_weights = np.zeros(knots.size)
    for i, seg in enumerate(segments):
        seg_weights = weights[seg]
        total = np.nansum(seg_weights)
        knot_weights[i] = total
        if np.any(seg_weights > 0):
            knot_flux[i] = np.nansum(log_flux[seg] * seg_weights) / total
            knot_flux_var[i] = (
                np.nansum((log_flux[seg] - knot_flux[i]) ** 2 * seg_weights) / total
            )

    good_knots = knot_weights > 0
    n_good = int(good_knots.sum())
    if n_good < 2:
        logger.warning(
            "Not enough good knots (%d) to fit continuum; returning median flux as flat continuum.",
            n_good,
        )
        median_flux = np.nanmedian(flux[weights > 0]) if np.any(weights > 0) else 1.0
        flat = np.full_like(wl, median_flux, dtype=float)
        return flat, flat.copy()

    knots_good = knots[good_knots]
    # Only interpolate over knots with data; bad knots carry NaN variance.
    var_model = np.interp(wl, knots_good, knot_flux_var[good_knots])

    # make_smoothing_spline needs at least 5 data points.
    if n_good < 5:
        logger.warning(
            "Only %d good knots to fit continuum; using linear interpolation between knots.",
            n_good,
        )
        # np.polyfit expects w = 1/sigma; knot_weights are inverse variances.
        pol_c = np.polyfit(
            knots_good,
            knot_flux[good_knots],
            w=np.sqrt(knot_weights[good_knots]),
            deg=1,
        )
        model = np.poly1d(pol_c)(wl)
    else:
        spline = make_smoothing_spline(
            knots_good, knot_flux[good_knots], w=knot_weights[good_knots]
        )
        model = spline(wl)

    if use_log:
        continuum = np.exp(model)
        # Var(exp(y)) ~= exp(y)**2 * Var(y)
        continuum_var = continuum**2 * var_model
    else:
        continuum = model
        continuum_var = var_model
    return continuum, np.sqrt(continuum_var)
def _gaussian(wl, line_flux, center, sigma): return ( line_flux / (sigma * np.sqrt(2 * np.pi)) * np.exp(-0.5 * ((wl - center) / sigma) ** 2) )
class LineSegmentationMap:
    """Container for detected emission-line segments and fitted line products.

    Parameters
    ----------
    line_segmentation : ndarray
        Integer segmentation map where 0 marks non-line pixels and positive
        values identify detected line regions.
    flux : ndarray
        Observed flux array.
    wavelength : ndarray
        Wavelength array corresponding to ``flux``.
    error : ndarray
        Flux uncertainty array.
    weights : ndarray
        Pixel weights used when fitting line profiles.
    continuum : ndarray
        Estimated continuum flux array.
    continuum_error : ndarray
        Uncertainty of the estimated continuum.

    Attributes
    ----------
    flux_cont_sub : ndarray
        Continuum-subtracted flux.
    flux_cont_sub_err : ndarray
        Uncertainty of the continuum-subtracted flux (flux and continuum
        errors added in quadrature).
    nlines : int
        Largest label in ``line_segmentation`` (0 for an empty map).
    lines : EmissionLineList or None
        Fitted line measurements after calling :meth:`fit_all_lines`.
    """

    def __init__(
        self,
        line_segmentation: np.ndarray,
        flux: np.ndarray,
        wavelength: np.ndarray,
        error: np.ndarray,
        weights: np.ndarray,
        continuum: np.ndarray,
        continuum_error: np.ndarray,
    ):
        self.line_segmentation = line_segmentation
        self.flux = flux
        self.wavelength = wavelength
        self.error = error
        self.weights = weights
        self.continuum = continuum
        self.continuum_error = continuum_error
        self.flux_cont_sub = flux - continuum
        self.flux_cont_sub_err = np.sqrt(error**2 + continuum_error**2)
        # ``.max()`` raises on empty arrays, so guard the empty case.
        self.nlines = int(line_segmentation.max()) if np.size(line_segmentation) else 0
        self.lines = None

    def get_line_mask(self, line_id: int) -> np.ndarray:
        """Return a boolean mask for the specified line_id."""
        return self.line_segmentation == line_id

    def _gaussian_fit(self, wl, flux, err, weights, line_id, line_mask):
        """Fit a single Gaussian (see ``_gaussian``) to the masked pixels.

        Only pixels inside ``line_mask`` with finite flux, finite positive
        error and positive weight are used.

        Returns
        -------
        dict
            Keys ``id``, ``line_flux`` (integrated flux, the amplitude
            parameter of ``_gaussian``), ``line_flux_err``, ``center``,
            ``sigma``, ``npixels`` and ``flag``. Flag values:

            - 0: successful fit;
            - 1: no valid pixels;
            - 2: fewer than 3 pixels, so the peak pixel value and a crude
              width are returned;
            - 3: the fit failed, so the initial estimates are returned.
        """
        mask = (
            line_mask & np.isfinite(flux) & np.isfinite(err) & (err > 0) & (weights > 0)
        )
        n_valid = np.count_nonzero(mask)
        if not n_valid:
            logger.warning(
                "No valid pixels to fit line_id=%s; returning NaN parameters.", line_id
            )
            return dict(
                id=line_id,
                line_flux=0.0,
                line_flux_err=0.0,
                center=np.nan,
                sigma=np.nan,
                npixels=0,
                flag=1,
            )

        # Smallest allowed width: a tenth of the first pixel step.
        min_sigma = (wl[1] - wl[0]) / 10
        wl_fit = wl[mask]
        flux_fit = flux[mask]
        err_fit = err[mask]

        if n_valid < 3:
            logger.warning(
                "Not enough valid pixels to fit line_id=%s; using peak flux and width instead.",
                line_id,
            )
            peak = np.argmax(flux_fit)
            # np.ptp: the ndarray.ptp method was removed in NumPy 2.0.
            sigma = np.clip(np.ptp(wl_fit) / 2.355, min_sigma, None)
            return dict(
                id=line_id,
                line_flux=flux_fit[peak],
                line_flux_err=0.0,
                center=wl_fit[peak],
                sigma=sigma,
                npixels=n_valid,
                flag=2,
            )

        max_sigma = 100
        mean = np.median(wl_fit)
        # Keep the initial width inside the bounds so curve_fit accepts p0.
        sigma = min(max(np.std(wl_fit), min_sigma), max_sigma)
        # _gaussian is parameterised by the *integrated* flux, so the peak
        # height is converted into an area for the initial guess and the
        # upper bound. Otherwise the bound would cap the width of the line.
        peak_flux = np.nanmax(flux_fit)
        line_flux = peak_flux * sigma * np.sqrt(2 * np.pi)
        try:
            popt, pcov = curve_fit(
                _gaussian,
                wl_fit,
                flux_fit,
                p0=[line_flux, mean, sigma],
                bounds=(
                    [0, wl_fit.min(), min_sigma],
                    [
                        10 * line_flux,
                        max(wl_fit.min() + wl[1] - wl[0], wl_fit.max()),
                        max_sigma,
                    ],
                ),
                sigma=err_fit,
                absolute_sigma=True,
                ftol=1e-9,
                maxfev=1000,
            )
            line_flux, mean, sigma = popt
        except Exception as e:
            # Best-effort fallback (e.g. non-positive peak, no convergence).
            logger.warning(
                "Gaussian fit failed (%s) for line_id=%s; using MLE estimates instead.",
                str(e),
                line_id,
            )
            return dict(
                id=line_id,
                line_flux=line_flux,
                line_flux_err=line_flux,
                center=mean,
                sigma=sigma,
                npixels=n_valid,
                flag=3,
            )
        return dict(
            id=line_id,
            line_flux=line_flux,
            line_flux_err=np.sqrt(np.diag(pcov)[0]),
            center=mean,
            sigma=sigma,
            npixels=n_valid,
            flag=0,
        )

    def fit_line(self, line_id: int) -> dict:
        """Fit a Gaussian to the specified line_id and return fit parameters.

        The fit uses the continuum-subtracted flux and its uncertainty; see
        ``_gaussian_fit`` for the returned keys and flags.
        """
        mask = self.get_line_mask(line_id)
        return self._gaussian_fit(
            wl=self.wavelength,
            flux=self.flux_cont_sub,
            err=self.flux_cont_sub_err,
            weights=self.weights,
            line_id=line_id,
            line_mask=mask,
        )
[docs] def fit_all_lines(self) -> Table: """Fit all lines in the segmentation map and return a table of results.""" output_table = Table( names=[ "id", "line_flux", "center", "line_flux_err", "sigma", "npixels", "flag", ], dtype=[int, float, float, float, float, int, int], meta={ "description": "Fitted parameters for each emission line segment", "flags": "0=good fit, 1=no valid pixels, 2=not enough pixels for fit, 3=fit failed, used MLE estimates", }, ) measured_lines = [] for line_id in range(1, self.nlines + 1): fit_params = self.fit_line(line_id) output_table.add_row(fit_params) measured_lines.append( EmissionLine( name=f"line_{line_id}", rest_wavelength=float(fit_params["center"]) if np.isfinite(fit_params["center"]) else np.nan, default_half_width=float(fit_params["sigma"]) if np.isfinite(fit_params["sigma"]) else 10.0, flux=float(fit_params["line_flux"]), flux_error=float(fit_params["line_flux_err"]), flag=int(fit_params["flag"]), metadata={ "id": int(fit_params["id"]), "npixels": int(fit_params["npixels"]), }, ) ) self.lines = EmissionLineList(measured_lines) # Build emission line spectra eline_flux = np.zeros_like(self.wavelength) for row in output_table: line = _gaussian( self.wavelength, row["line_flux"], row["center"], row["sigma"] ) eline_flux += line if np.isfinite(line).all() else 0.0 self.eline_flux = eline_flux return output_table
[docs] def _watershed_1d( signal: np.ndarray, markers: np.ndarray, mask: np.ndarray ) -> np.ndarray: """ 1D watershed segmentation for deblending overlapping emission lines. Starting from labeled seed regions (markers), flood outward following descending signal gradient until regions meet or the mask boundary is reached. Parameters ---------- signal : ndarray 1D array used to guide the flooding (typically SNR). Higher values are flooded first. markers : ndarray Integer label array with seed pixels (>0) already identified. 0 means unlabeled; will be filled by watershed. mask : ndarray of bool Only pixels where mask is True are eligible for flooding. Returns ------- labels : ndarray of int Label array after watershed. Pixels outside the mask remain 0. """ from heapq import heappush, heappop labels = np.array(markers, dtype=int, copy=True) in_queue = np.zeros(len(signal), dtype=bool) # Priority queue: (-signal_value, pixel_index, label) # Negative because heapq is a min-heap; we want to process highest signal first heap = [] # Seed the queue with all boundary pixels of each marker region for i in range(len(signal)): if labels[i] > 0 and mask[i]: for di in (-1, 1): nb = i + di if ( 0 <= nb < len(signal) and mask[nb] and labels[nb] == 0 and not in_queue[nb] ): heappush(heap, (-signal[nb], nb, labels[i])) in_queue[nb] = True while heap: neg_val, idx, lbl = heappop(heap) # Skip if already labeled by a previous (higher-priority) flood if labels[idx] != 0: continue labels[idx] = lbl # Expand to unlabeled neighbours for di in (-1, 1): nb = idx + di if ( 0 <= nb < len(signal) and mask[nb] and labels[nb] == 0 and not in_queue[nb] ): heappush(heap, (-signal[nb], nb, lbl)) in_queue[nb] = True return labels
def _seed_segment_peaks(clean_snr, line_ids):
    """Place one integer marker (1, 2, ...) on each local SNR maximum of every segment.

    Segments from ``scipy.ndimage.label`` are contiguous in 1D, so each
    segment is exactly its bounding slice. Segments without an interior local
    maximum get a single marker at their SNR maximum.

    Returns
    -------
    peak_markers : ndarray
        Array shaped like ``line_ids`` with the markers, 0 elsewhere.
    n_seeds : int
        Number of markers placed.
    """
    peak_markers = np.zeros_like(line_ids)
    n_seeds = 0
    for seg_slice in find_objects(line_ids):
        if seg_slice is None:
            continue
        seg_indices = np.arange(seg_slice[0].start, seg_slice[0].stop)
        seg_snr = clean_snr[seg_indices]
        local_peaks, _ = find_peaks(seg_snr)
        if len(local_peaks) == 0:
            # find_peaks ignores edge maxima; use the global maximum instead.
            local_peaks = [int(np.argmax(seg_snr))]
        for peak_pix in local_peaks:
            n_seeds += 1
            peak_markers[seg_indices[peak_pix]] = n_seeds
    return peak_markers, n_seeds


def _filter_and_relabel_segments(line_ids, min_npixels):
    """Drop segments with fewer than ``min_npixels`` pixels and relabel the rest 1..n.

    The relative order of the surviving labels is preserved.

    Returns
    -------
    relabelled : ndarray of int
        Segmentation map with consecutive labels.
    nlines : int
        Number of surviving segments.
    """
    counts = np.bincount(np.ravel(line_ids), minlength=1)
    counts[0] = 0  # background is never a line
    valid_ids = np.flatnonzero((counts > 0) & (counts >= min_npixels))
    lookup = np.zeros(counts.size, dtype=int)
    lookup[valid_ids] = np.arange(1, valid_ids.size + 1)
    return lookup[line_ids], int(valid_ids.size)


def _pad_line_segments(line_ids, pad):
    """Grow every segment by up to ``pad`` pixels on each side without overlaps.

    When two neighbouring segments are closer than ``2 * pad`` pixels, the
    gap between them is split at its midpoint, so neither segment's padding
    overwrites the other's.
    """
    padded = np.array(line_ids, dtype=int, copy=True)
    slices = find_objects(line_ids)
    npix = padded.size
    for seg_id, seg_slice in enumerate(slices, start=1):
        if seg_slice is None:
            continue
        start, stop = seg_slice[0].start, seg_slice[0].stop
        prev_slice = slices[seg_id - 2] if seg_id > 1 else None
        next_slice = slices[seg_id] if seg_id < len(slices) else None
        pad_left = pad
        if prev_slice is not None:
            gap = start - prev_slice[0].stop
            # The previous segment took floor(gap / 2); take the remainder.
            pad_left = min(pad, gap - gap // 2)
        pad_right = pad
        if next_slice is not None:
            gap = next_slice[0].start - stop
            pad_right = min(pad, gap // 2)
        padded[max(0, start - pad_left) : min(npix, stop + pad_right)] = seg_id
    return padded


def _match_line_names(output_table, lines, redshift):
    """Return, for each fitted row, the name of the closest line in ``lines``.

    A row is named "unknown" when its centre is not finite, when no line is
    returned, or when the observed-frame offset exceeds
    ``max(3 * sigma, default_half_width)``.
    """
    names = []
    for row in output_table:
        center = row["center"]
        if not np.isfinite(center):
            names.append("unknown")
            continue
        candidate = lines.get_closest_line(center, redshift=redshift)
        if candidate is None:
            names.append("unknown")
            continue
        tolerance = max(3 * row["sigma"], candidate.default_half_width)
        offset = np.abs(candidate.rest_wavelength * (1 + redshift) - center)
        names.append(candidate.name if offset < tolerance else "unknown")
    return names


@_parse_lines_param_decorator
def find_emission_lines(
    wl,
    flux,
    err,
    weights=None,
    continuum=None,
    continuum_error=None,
    lines=None,
    redshift=0.0,
    to_rest_frame=False,
    snr_threshold=3.0,
    min_continuum_snr=1.0,
    cont_sigma_clip=1.5,
    knot_spacing=100.0,
    min_npixels: int = 3,
    pad_detection_pix: int = 10,
    deblend_lines: bool = True,
):
    """Find and fit emission lines in a spectrum.

    Parameters
    ----------
    wl : ndarray
        Wavelength array (1D).
    flux : ndarray
        Flux array (1D, same length as wl).
    err : ndarray
        Uncertainty array (1D, same length as wl).
    weights : ndarray, optional
        Weights for each pixel (1D, same length as wl). If None, all pixels
        are equally weighted. Pixels with non-positive weight are not fitted.
    continuum : ndarray, optional
        Pre-computed continuum flux at each wavelength. If None, it is
        estimated with ``estimate_continuum``.
    continuum_error : ndarray, optional
        Uncertainty of the continuum at each wavelength. If None and
        ``continuum`` is also None, the estimated uncertainty is used. If
        ``continuum`` is None but this is given, the given array is kept.
    lines : EmissionLineList, optional
        Line list (possibly pre-processed by the decorator). When given, only
        pixels inside ``lines.to_mask(wl, redshift)`` can be detected, and
        detections are matched to line names. When None, all pixels are
        searched and no names are assigned.
    redshift : float, optional
        Redshift used to place the line-list windows and to match names.
    to_rest_frame : bool, optional
        If True, divide the fitted ``center`` and ``sigma`` by ``1 + redshift``.
    snr_threshold : float, optional
        Minimum per-pixel (flux - continuum) / sigma for a pixel to be part
        of a detection.
    min_continuum_snr : float, optional
        Pixels where ``continuum / continuum_error`` is below this value
        (including pixels with non-positive ``continuum_error``) are excluded
        from detection; this guards against detections where the continuum
        is poorly constrained.
    cont_sigma_clip : float, optional
        Sigma-clipping threshold for the continuum estimation.
    knot_spacing : float, optional
        Spacing between spline knots for continuum estimation, in the same
        units as wl.
    min_npixels : int, optional
        Minimum number of contiguous pixels above the SNR threshold to
        consider a valid line detection.
    pad_detection_pix : int, optional
        Number of pixels to pad on either side of detected line regions to
        capture the line wings. Gaps narrower than twice this value are split
        at their midpoint between the neighbouring lines.
    deblend_lines : bool, optional
        If True, apply 1D watershed deblending to separate overlapping line
        detections. Each local SNR maximum seeds its own region, and regions
        are grown by flooding in descending SNR order until they meet.
        Default is True.

    Returns
    -------
    output_table : astropy.table.Table
        Table with columns id, line_flux, center, line_flux_err, sigma,
        npixels and flag, plus line_name when ``lines`` is given.
    line_segm_map : LineSegmentationMap
        Segmentation map, continuum products and fitted lines.

    Raises
    ------
    ValueError
        If ``continuum`` is given without ``continuum_error``.
    """
    if weights is None:
        weights = np.ones_like(wl)

    # Estimate the continuum unless the caller provided it.
    if continuum is None:
        continuum, estimated_error = estimate_continuum(
            wl,
            flux,
            err,
            weights,
            knot_spacing=knot_spacing,
            sigma_clip=cont_sigma_clip,
        )
        if continuum_error is None:
            continuum_error = estimated_error
    elif continuum_error is None:
        raise ValueError(
            "`continuum_error` must be provided together with `continuum`."
        )

    if lines is not None:
        logger.info("Using provided line list with %d lines.", lines.size)
        lines_mask = lines.to_mask(wl, redshift=redshift).astype(float)
    else:
        logger.info("No line list provided; using all pixels for detection.")
        lines_mask = np.ones_like(wl)

    # Continuum-free flux and SNR. np.where evaluates both branches, so
    # silence the warnings from the masked-out zero divisions.
    clean_flux = (flux - continuum) * lines_mask
    clean_flux_err = np.sqrt(err**2 + continuum_error**2)
    with np.errstate(divide="ignore", invalid="ignore"):
        clean_snr = np.where(clean_flux_err > 0, clean_flux / clean_flux_err, 0.0)
        cont_snr = np.where(continuum_error > 0, continuum / continuum_error, 0.0)
    # Mask regions where the continuum is not well constrained.
    clean_snr[cont_snr < min_continuum_snr] = 0.0

    bright_pixels = clean_snr > snr_threshold
    line_ids, nlines = label(bright_pixels)
    logger.info(
        "Initial detection found %d candidate lines with SNR > %.1f",
        nlines,
        snr_threshold,
    )

    if deblend_lines:
        peak_markers, n_seeds = _seed_segment_peaks(clean_snr, line_ids)
        logger.info(
            "Watershed deblending: found %d peak seeds across %d initial regions.",
            n_seeds,
            nlines,
        )
        if n_seeds > nlines:
            # Only run the watershed when some region has several peaks.
            line_ids = _watershed_1d(clean_snr, peak_markers, bright_pixels)
            nlines = np.unique(line_ids[line_ids > 0]).size
            logger.info("After watershed deblending, found %d candidate lines.", nlines)
        else:
            logger.info("No additional peaks found; skipping watershed split.")

    line_ids, nlines = _filter_and_relabel_segments(line_ids, min_npixels)
    logger.info(
        "%d candidate lines have at least %d pixels above the SNR threshold.",
        nlines,
        min_npixels,
    )

    if pad_detection_pix > 0:
        line_ids = _pad_line_segments(line_ids, pad_detection_pix)

    logger.info(
        "Found %d candidate emission lines with SNR > %.1f", nlines, snr_threshold
    )
    line_segm_map = LineSegmentationMap(
        line_segmentation=line_ids,
        flux=flux,
        wavelength=wl,
        error=err,
        weights=weights,
        continuum=continuum,
        continuum_error=continuum_error,
    )
    output_table = line_segm_map.fit_all_lines()

    if lines is not None:
        matched_line_names = _match_line_names(output_table, lines, redshift)
        output_table["line_name"] = matched_line_names
        if line_segm_map.lines is not None:
            line_segm_map.lines = line_segm_map.lines.with_names(matched_line_names)

    if to_rest_frame:
        output_table["center"] = output_table["center"] / (1 + redshift)
        output_table["sigma"] = output_table["sigma"] / (1 + redshift)
        if line_segm_map.lines is not None:
            line_segm_map.lines = EmissionLineList(
                [
                    replace(
                        line,
                        rest_wavelength=line.rest_wavelength / (1 + redshift),
                        default_half_width=line.default_half_width / (1 + redshift),
                    )
                    for line in line_segm_map.lines
                ]
            )

    return output_table, line_segm_map