"""
continuous_spectra_data.py
====================================
Module to represent continuous spectra data.
"""
import warnings
from os.path import join
from pathlib import Path
import numpy as np
import pandas as pd
from astropy.io import fits
from astropy.io.votable.tree import Field, Resource, Table, VOTableFile
from astropy.units import UnitsWarning
from fastavro import parse_schema, writer
from fastavro.validation import validate_many
from gaiaxpy.core.satellite import BANDS
from .output_data import OutputData
from .utils import _add_ecsv_header, _build_ecsv_header, _generate_fits_header, _load_header_dict
# Executed at import time: ignore every astropy UnitsWarning from here on (the filter is global, not scoped
# to this module). Presumably needed because of the unit strings attached to the FITS/VOTable columns
# written below - TODO confirm.
warnings.filterwarnings('ignore', category=UnitsWarning)
# (Stray "[docs]" marker removed: residue of a Sphinx source-view page, not part of the module.)
class ContinuousSpectraData(OutputData):
    """
    Output container for continuous spectra, with writers for AVRO, CSV, ECSV, FITS and XML/VOTable.

    The ``_save_*`` methods use ``self.data`` as a pandas DataFrame holding one row per source. Besides
    ``source_id`` they know the per-band columns ``standard_deviation``, ``coefficients``,
    ``coefficient_correlations``, ``coefficient_errors``, ``n_parameters`` and ``basis_function_id``.
    ``_get_spectra_df`` instead iterates ``self.data`` as a collection of per-source mappings indexed by band.
    """

    def __init__(self, data):
        """
        Args:
            data: The spectra to store. Passed unchanged to ``OutputData.__init__`` together with ``None`` as
                its second argument.
        """
        super().__init__(data, None)

    @staticmethod
    def _make_output_file(output_path, output_file, extension):
        """
        Create the output directory if required and build the full name of the output file.

        Args:
            output_path (str): Directory where the file will be saved.
            output_file (str): Name of the output file, without extension.
            extension (str): File extension, without the leading dot.

        Returns:
            str: ``output_path`` joined with ``'<output_file>.<extension>'``.
        """
        Path(output_path).mkdir(parents=True, exist_ok=True)
        return join(output_path, f'{output_file}.{extension}')

    @staticmethod
    def _array_columns_as_str(spectra_df):
        """
        Return a copy of the input in which array-valued columns are replaced by strings.

        A column is considered array-valued when the value in its first row is a NumPy array. Every value of
        such a column is converted to a tuple and then to its string representation. The input DataFrame is
        left untouched so that the same data can still be saved in formats that need the arrays.

        Args:
            spectra_df (DataFrame): Spectra to convert.

        Returns:
            DataFrame: Converted copy of the input. An empty input is returned as an unchanged copy.
        """
        converted = spectra_df.copy()
        if converted.empty:
            # There is no first row to inspect, hence nothing to convert.
            return converted
        array_columns = [column for column in converted.columns
                         if isinstance(converted[column].iloc[0], np.ndarray)]
        if array_columns:
            converted[array_columns] = converted[array_columns].apply(lambda col: col.apply(tuple)).astype('str')
        return converted

    def _save_avro(self, output_path, output_file):
        """
        Save the output spectra in AVRO format.

        Array-valued fields are written as strings (the string representation of a tuple).

        Args:
            output_path (str): Path where to save the file.
            output_file (str): Name of the output file.

        Raises:
            IndexError: If the data contain no rows.
            KeyError: If the data contain a column with no known AVRO type.
        """
        def _generate_avro_schema(_spectra_dicts):
            """
            Generate the AVRO schema required to store the output.

            Args:
                _spectra_dicts (list): A list of dictionaries containing spectra. It is modified in place.

            Returns:
                dict: A dictionary containing the parsed schema that matches the input.
                list of dicts: A list of dictionaries with the modified input spectra
                    according to the valid AVRO types.
            """
            field_to_type = {
                'source_id': 'long',
                'bp_standard_deviation': 'float', 'rp_standard_deviation': 'float',
                'bp_coefficients': 'string', 'rp_coefficients': 'string',
                'bp_coefficient_correlations': 'string', 'rp_coefficient_correlations': 'string',
                'bp_coefficient_errors': 'string', 'rp_coefficient_errors': 'string',
                'bp_n_parameters': 'int', 'rp_n_parameters': 'int',
                'bp_basis_function_id': 'int', 'rp_basis_function_id': 'int'
            }
            # The schema only describes the fields that are present (taken from the first record)
            present_fields = list(_spectra_dicts[0].keys())
            schema = {'doc': 'Spectrum output.', 'name': 'Spectra', 'namespace': 'spectrum', 'type': 'record',
                      'fields': [{'name': key, 'type': field_to_type[key]} for key in present_fields]}
            # Spectrum fields to string. Only the present fields are converted, so data lacking one of the
            # known array columns do not fail with a KeyError.
            string_fields = [key for key in present_fields if field_to_type[key] == 'string']
            for spectrum in _spectra_dicts:
                for field in string_fields:
                    spectrum[field] = str(tuple(spectrum[field]))
            # Validate that records match the schema
            validate_many(_spectra_dicts, schema)
            return parse_schema(schema), _spectra_dicts

        # List with one dictionary per source. These are new objects, self.data is not modified.
        spectra_dicts = self.data.to_dict('records')
        parsed_schema, spectra_dicts = _generate_avro_schema(spectra_dicts)
        avro_file = self._make_output_file(output_path, output_file, 'avro')
        with open(avro_file, 'wb') as output:
            writer(output, parsed_schema, spectra_dicts)

    def _save_csv(self, output_path, output_file):
        """
        Save the output spectra in CSV format.

        Array-valued columns are written as strings. The conversion is done on a copy, ``self.data`` keeps its
        arrays.

        Args:
            output_path (str): Path where to save the file.
            output_file (str): Name of the output file.
        """
        spectra_df = self._array_columns_as_str(self.data)
        csv_file = self._make_output_file(output_path, output_file, 'csv')
        spectra_df.to_csv(csv_file, index=False)

    def _save_ecsv(self, output_path, output_file):
        """
        Save the output spectra in ECSV format.

        The table is first written as plain CSV with the '.ecsv' extension and then handed to
        ``_add_ecsv_header`` together with the header lines. Array-valued columns are written as strings;
        the conversion is done on a copy, ``self.data`` keeps its arrays.

        Args:
            output_path (str): Path where to save the file.
            output_file (str): Name of the output file.
        """
        data = self.data
        # The header is built from the original data, before the array columns become strings
        header_lines = _build_ecsv_header(data)
        spectra_df = self._array_columns_as_str(data)
        ecsv_file = self._make_output_file(output_path, output_file, 'ecsv')
        spectra_df.to_csv(ecsv_file, index=False)
        _add_ecsv_header(header_lines, output_path, output_file)

    def _save_fits(self, output_path, output_file):
        """
        Save the output data in FITS format.

        The file contains a primary HDU with an empty header followed by one binary table HDU with a column
        per field of the data. An existing file with the same name is overwritten.

        Args:
            output_path (str): Path where to save the file.
            output_file (str): Name of the output file.
        """
        data = self.data
        # Array lengths are taken from the first row of the BP columns and used for both bands
        coefficients_format = f"{len(data['bp_coefficients'].iloc[0])}D"  # D: double precision float
        correlations_format = f"{len(data['bp_coefficient_correlations'].iloc[0])}D"
        errors_format = f"{len(data['bp_coefficient_errors'].iloc[0])}E"  # E: single precision float
        # Define formats for each type according to FITS (K: 64-bit integer, I: 16-bit integer)
        column_formats = {
            'source_id': 'K',
            'bp_standard_deviation': 'D', 'rp_standard_deviation': 'D',
            'bp_coefficients': coefficients_format, 'rp_coefficients': coefficients_format,
            'bp_coefficient_correlations': correlations_format, 'rp_coefficient_correlations': correlations_format,
            'bp_coefficient_errors': errors_format, 'rp_coefficient_errors': errors_format,
            'bp_n_parameters': 'I', 'rp_n_parameters': 'I',
            'bp_basis_function_id': 'I', 'rp_basis_function_id': 'I'}
        # Primary HDU with an empty header
        primary_hdu = fits.PrimaryHDU(header=fits.Header())
        # Dictionary holding all the data, one list of values per column
        output_by_column_dict = data.reset_index().to_dict(orient='list')
        # Remove the column added by reset_index from the output dict
        output_by_column_dict.pop('index', None)
        units_dict = data.attrs['data_type'].get_units()
        columns = [fits.Column(name=key, array=np.array(values), format=column_formats[key],
                               unit=units_dict.get(key, '')) for key, values in output_by_column_dict.items()]
        header = _generate_fits_header(data, column_formats)
        table_hdu = fits.BinTableHDU.from_columns(columns, header=header)
        # Put all HDUs together
        hdul = fits.HDUList([primary_hdu, table_hdu])
        # Write the file and replace it if it already exists
        fits_file = self._make_output_file(output_path, output_file, 'fits')
        hdul.writeto(fits_file, overwrite=True)

    def _save_xml(self, output_path, output_file):
        """
        Save the output spectra in XML/VOTABLE format.

        Args:
            output_path (str): Path where to save the file.
            output_file (str): Name of the output file.
        """
        def _create_fields(_votable, _spectra_df):
            """
            Create one VOTable field per column of the input.

            Args:
                _votable (VOTableFile): VOTable the fields belong to.
                _spectra_df (DataFrame): Spectra to be saved. Its 'data_type' attribute provides the units.

            Returns:
                list: The fields, in column order, with datatype, UCD, unit and description set.
                    Array-valued fields have a variable array size ('*').
            """
            scalar_datatypes = {'standard_deviation': 'double', 'n_parameters': 'int', 'basis_function_id': 'int'}
            array_datatypes = {'coefficients': 'double', 'coefficient_correlations': 'double',
                               'coefficient_errors': 'float'}
            fields_datatypes = {'source_id': 'long'}
            array_fields = set()
            for band in (BANDS.bp, BANDS.rp):
                for suffix, datatype in scalar_datatypes.items():
                    fields_datatypes[f'{band}_{suffix}'] = datatype
                for suffix, datatype in array_datatypes.items():
                    fields_datatypes[f'{band}_{suffix}'] = datatype
                    array_fields.add(f'{band}_{suffix}')
            units_dict = _spectra_df.attrs['data_type'].get_units()
            header_dict = _load_header_dict()
            _fields = []
            for column in _spectra_df.columns:
                field_kwargs = {'name': column, 'datatype': fields_datatypes[column],
                                'ucd': header_dict.get(column, dict()).get('meta', ''),
                                'unit': units_dict.get(column, '')}
                if column in array_fields:
                    # The array size is only given for array-valued fields
                    field_kwargs['arraysize'] = '*'
                _field = Field(_votable, **field_kwargs)
                _field.description = header_dict.get(_field.name, dict()).get('description', '')
                _fields.append(_field)
            return _fields

        spectra_df = self.data
        # Create a new VOTable file with a single resource
        votable = VOTableFile()
        resource = Resource()
        votable.resources.append(resource)
        # Add a table for the spectra
        spectra_table = Table(votable)
        resource.tables.append(spectra_table)
        # Add spectrum fields
        spectra_table.fields.extend(_create_fields(votable, spectra_df))
        # Create the record arrays, with the given number of rows, and fill them row by row
        spectra_table.create_arrays(len(spectra_df))
        column_names = spectra_df.columns
        for index, row in enumerate(spectra_df.to_dict('records')):
            spectra_table.array[index] = tuple(row[column] for column in column_names)
        # Write to a file
        xml_file = self._make_output_file(output_path, output_file, 'xml')
        votable.to_xml(xml_file)

    def _get_spectra_df(self):
        """
        Build a single DataFrame, with one row per source, from the per-band spectra in the data.

        The BP and RP spectra of each element of ``self.data`` are converted through ``spectrum_to_dict`` and
        outer-merged on 'source_id'. Columns present in both bands receive the pandas merge suffixes '_x' (BP)
        and '_y' (RP); these are replaced by a band prefix. Columns whose name contains 'xp' are dropped.

        Returns:
            DataFrame: The merged spectra.
        """
        data = self.data
        spectra_bp_df = pd.DataFrame.from_records([spectrum[BANDS.bp].spectrum_to_dict() for spectrum in data])
        spectra_rp_df = pd.DataFrame.from_records([spectrum[BANDS.rp].spectrum_to_dict() for spectrum in data])
        spectra_df = spectra_bp_df.merge(spectra_rp_df, on='source_id', how='outer')
        band_by_suffix = {'_x': BANDS.bp, '_y': BANDS.rp}
        to_drop = []
        to_rename = dict()
        for col in spectra_df.columns:
            if 'xp' in col:
                to_drop.append(col)
                continue
            # Only a trailing merge suffix is replaced, '_x' or '_y' elsewhere in the name is kept
            band = band_by_suffix.get(col[-2:])
            if band is not None:
                to_rename[col] = f'{band}_{col[:-2]}'
        return spectra_df.drop(columns=to_drop).rename(columns=to_rename)