Source code for gaiaxpy.spectrum.sampled_spectrum

"""
sampled_spectrum.py
====================================
Module to represent a sampled spectrum.
"""

import math
import numpy as np
import matplotlib.pyplot as plt
from os import path
from pathlib import Path
from .generic_spectrum import Spectrum


[docs]class SampledSpectrum(Spectrum): """ A spectrum defined by a set of discrete measurements. Each measurement is defined by a position in wavelength (or pseudo-wavelength), a measured flux and an associated flux error. Specific implementations of this class will define the units in use for positions and fluxes. """ def __init__( self, source_id, sampling_grid): """ Initialise a sampled spectrum. Args: source_id (str): Source identifier. sampling_grid (ndarray): 1D array containing the positions of the samples (or discrete measurements) for this spectrum. """ Spectrum.__init__(self, source_id) self.n_samples = np.size(sampling_grid) self.pos = None self.flux = None self.error = None def _get_fluxes(self): """ Get the flux samples. Returns: ndarray: 1D array containing all flux samples. """ return self.flux def _get_flux_errors(self): """ Get the flux errors of each sample. Returns: ndarray: 1D array containing the error in flux for all samples. """ return self.error def _get_positions(self): """ Get the positions of all samples. Returns: ndarray: 1D array containing the position of all samples. """ return self.pos def _get_flux_label(self): """ Get the labels describing the flux measurements. Returns: str: Short description of the flux measurements, including the units. """ return "" def _get_position_label(self): """ Get the positions of the samples, including the units. Returns: str: Short description of the positions of the samples. """ return "" def _get_inputs(self, spectrum): return spectrum._get_positions(), spectrum._get_fluxes(), spectrum._get_flux_errors() def _save_figure(self, output_path, file_name, format): if output_path: Path(output_path).mkdir(parents=True, exist_ok=True) plt.savefig( path.join(output_path, f'{file_name}.{format}'), format=format, transparent=False) @staticmethod def _sample_flux(coefficients, design_matrix): """ Given a set of coefficients to be applied to a set of basis functions and a design matrix containing the evaluation of each basis function at the positions corresponding to the samples, this method computes the flux values for each sample. Args: coefficients (ndarray): 1D array containing the coefficients multiplying the basis functions in the continuous representation. design_matrix (ndarray): 2D array containing the evaluation of the basis functions on the desired sampling grid. Returns: ndarray: 1D array containing the flux values for all samples. """ return coefficients.dot(design_matrix) @staticmethod def _sample_error(covariance, design_matrix, standard_deviation): """ Given the covariance matrix and standard deviation of the least squares solution defining the continuous representation of the spectrum in terms of basis functions and a design matrix containing the evaluation of each basis function at the positions corresponding to the samples, this method computes the error associated to the flux value for each sample. Args: covariance (ndarray): 2D array containing the elements of the covariance matrix. design_matrix (ndarray): 2D array containing the evaluation of the basis functions on the desired sampling grid. standard_deviation (float): Standard deviation. Returns: ndarray: 1D array containing the errors in flux for all samples. """ n_samples = design_matrix.shape[1] error = np.zeros(n_samples) for i in range(n_samples): error[i] = math.sqrt(design_matrix[:, i].dot(covariance).dot( design_matrix[:, i])) * standard_deviation return error