# Source code for redrock.external.boss

"""
redrock.external.boss
=====================

redrock wrapper tools for BOSS
"""
import os
import sys
import re
import warnings
import traceback

import argparse

import numpy as np
from scipy import sparse

from astropy.io import fits
from astropy.table import Table

import fitsio

import desispec.resolution
from desispec.resolution import Resolution

from ..utils import elapsed, get_mp, distribute_work

from ..targets import Spectrum, Target, DistTargetsCopy

from ..templates import load_dist_templates

from ..results import write_zscan

from ..zfind import zfind

from .._version import __version__

from ..archetypes import All_archetypes


def platemjdfiber2targetid(plate, mjd, fiber):
    """Pack a (plate, mjd, fiber) triple into a single integer target ID.

    Args:
        plate (int): plate number; stored above the nine lowest decimal
            digits.
        mjd (int): MJD of the observation; shifted left by four decimal
            digits.
        fiber (int): fiber number; stored in the four lowest decimal digits.

    Returns:
        int: ``plate*1000000000 + mjd*10000 + fiber``.
    """
    plate_term = plate * 1000000000
    mjd_term = mjd * 10000
    return plate_term + mjd_term + fiber


def targetid2platemjdfiber(targetid):
    """Split an integer target ID into its (plate, mjd, fiber) fields.

    Args:
        targetid (int): ID whose four lowest decimal digits are the fiber,
            whose next five decimal digits are the MJD, and whose remaining
            high digits are the plate.

    Returns:
        tuple: (plate, mjd, fiber).
    """
    # Peel the fields off from the least-significant end.
    upper_digits, fiber = divmod(targetid, 10000)
    plate, mjd = divmod(upper_digits, 100000)
    return (plate, mjd, fiber)


def write_zbest(outfile, zbest, template_version, archetype_version):
    """Write zbest Table to outfile

    The file is first written to ``outfile + '.tmp'`` and then moved onto
    ``outfile``, so the final name only ever refers to a completely written
    file.

    Args:
        outfile (str): output file; environment variables in the path are
            expanded.
        zbest (Table): the output best fit results.  Its
            ``meta['EXTNAME']`` is set to ``'ZBEST'`` as a side effect.
        template_version (dict): maps each template full type to its
            version; stored in the primary header as TEMNAMnn / TEMVERnn.
        archetype_version (dict or None): maps each archetype name to its
            version; stored as ARCNAMnn / ARCVERnn.  Nothing is stored when
            this is None.
    """
    header = fits.Header()
    header['RRVER'] = (__version__, 'Redrock version')

    for i, fulltype in enumerate(template_version.keys()):
        header['TEMNAM'+str(i).zfill(2)] = fulltype
        header['TEMVER'+str(i).zfill(2)] = template_version[fulltype]

    if archetype_version is not None:
        for i, fulltype in enumerate(archetype_version.keys()):
            header['ARCNAM'+str(i).zfill(2)] = fulltype
            header['ARCVER'+str(i).zfill(2)] = archetype_version[fulltype]

    zbest.meta['EXTNAME'] = 'ZBEST'

    hx = fits.HDUList()
    hx.append(fits.PrimaryHDU(header=header))
    hx.append(fits.convenience.table_to_hdu(zbest))

    outfile = os.path.expandvars(outfile)
    tmpfile = outfile + '.tmp'
    hx.writeto(tmpfile, overwrite=True)
    # os.replace overwrites an existing destination on every platform,
    # whereas os.rename refuses to do so on Windows.
    os.replace(tmpfile, outfile)
### @profile
def _resolution_from_wdisp(wdisp, di2, ndiag, nbins):
    """Build the sparse resolution matrix of one fiber from its dispersion.

    Each stored diagonal holds ``exp(-offset**2 / 2 / wdisp**2)``; the
    diagonals are then normalized so that they sum to one at every
    wavelength bin.

    Args:
        wdisp (array): 1-D dispersion of the fiber, one value per
            wavelength bin (presumably in pixel units -- TODO confirm).
        di2 (array): (nbins, nbins) array of squared index differences.
        ndiag (int): number of diagonals to fill.
        nbins (int): number of wavelength bins.

    Returns:
        scipy.sparse.dia_matrix: the (nbins, nbins) resolution matrix.
    """
    reso = np.zeros([ndiag, nbins])
    for idiag in range(ndiag):
        offset = ndiag//2 - idiag
        d = np.diagonal(di2, offset=offset)
        if offset < 0:
            reso[idiag, :len(d)] = np.exp(-d/2/wdisp[:len(d)]**2)
        else:
            reso[idiag, nbins-len(d):nbins] = \
                np.exp(-d/2/wdisp[nbins-len(d):nbins]**2)

    reso /= np.sum(reso, axis=0)
    offsets = ndiag//2 - np.arange(ndiag)
    nwave = reso.shape[1]
    return sparse.dia_matrix((reso, offsets), (nwave, nwave))


def read_spectra(spplates_name, targetids=None, use_frames=False,
    fiberid=None, coadd=False, cache_Rcsr=False, use_andmask=False):
    """Read targets from a list of spectra files

    Args:
        spplates_name (list or str): input spPlate files or pattern to match
            files.
        targetids (list): restrict targets to this subset.  Requested IDs
            that are not found in the files are reported and skipped.
        use_frames (bool): if True, use the spCFrame files listed in the
            spPlate header (looked up in the spPlate directory) instead of
            the spPlate itself.
        fiberid (int): Use this fiber ID.
        coadd (bool): if True, compute and use the coadds.
        cache_Rcsr (bool): pre-calculate and cache sparse CSR format of
            resolution matrix R
        use_andmask (bool): sets ivar = 0 to pixels with and_mask != 0

    Returns:
        tuple: (targets, meta) where targets is a list of Target objects and
        meta is a Table of metadata (currently only BRICKNAME).

    """
    # check the file list
    if isinstance(spplates_name, str):
        import glob
        spplates_name = glob.glob(spplates_name)

    ## read spplates
    # With several plates the PLATE/MJD/FIBER id is not unique to an object,
    # so THING_ID from the matching photoPosPlate file is used instead.
    useThingid = len(spplates_name) > 1
    if useThingid:
        print("DEBUG: Reading multiple observations: using THING_ID instead of PLATE*1000000000 + MJD*10000 + FIBERID")
    fiberid2thingid = {}

    plate = []
    mjd = []
    infiles = []
    for spplate_name in spplates_name:
        with fitsio.FITS(spplate_name) as spplate:
            header = spplate[0].read_header()
        plate.append(header["PLATEID"])
        mjd.append(header["MJD"])

        thing_ids = None
        if useThingid:
            photo_name = spplate_name.replace('spPlate', 'photoPosPlate')
            with fitsio.FITS(photo_name) as photoPlate:
                thing_ids = photoPlate[1]['THING_ID'][:]

        if use_frames:
            path = os.path.dirname(spplate_name)
            cameras = ['b1', 'r1', 'b2', 'r2']

            nexp_tot = 0
            for c in cameras:
                try:
                    nexp = header["NEXP_{}".format(c.upper())]
                except (KeyError, ValueError):
                    # NOTE(review): the original caught ValueError only;
                    # KeyError is also accepted in case the installed fitsio
                    # reports a missing keyword that way -- TODO confirm.
                    print("DEBUG: spplate {} has no exposures in camera {} ".format(spplate_name,c))
                    continue
                for _ in range(nexp):
                    # EXPIDnn keywords are numbered across all cameras.
                    nexp_tot += 1
                    expid = str(nexp_tot).zfill(2)
                    # os.path.join keeps the name relative when the spPlate
                    # path has no directory component.
                    exp = os.path.join(
                        path, "spCFrame-"+header["EXPID"+expid][:11]+".fits")
                    infiles.append(exp)
                    if useThingid:
                        fiberid2thingid[exp] = thing_ids
        else:
            infiles.append(spplate_name)
            if useThingid:
                fiberid2thingid[spplate_name] = thing_ids

    # A single common plate / MJD is kept; mixed inputs fall back to 0.
    if len(set(plate)) == 1:
        plate = plate[0]
    else:
        plate = 0
    if len(set(mjd)) == 1:
        mjd = mjd[0]
    else:
        mjd = 0

    # Membership is tested once per fiber, so build the lookup set once.
    wanted = None if targetids is None else set(targetids)

    bricknames = {}
    dic_spectra = {}

    for infile in infiles:
        with fitsio.FITS(infile) as h:
            header = h[0].read_header()
            if not useThingid:
                assert plate == header["PLATEID"]
            fs = h[5]["FIBERID"][:]
            if fiberid is not None:
                fs = fs[np.isin(fs, fiberid)]

            fl = h[0].read()
            iv = h[1].read()
            if use_andmask:
                iv *= 1.*(h[2].read()==0)
            wd = h[4].read()

            ## crop to lmin, lmax
            lmin = 3500.
            lmax = 10000.
            if use_frames:
                # HDU 3 is exponentiated to give the per-pixel wavelength.
                la = 10**h[3].read()
                if header["CAMERAS"][0] == "b":
                    lmax = 6000.
                else:
                    lmin = 5500.
            else:
                # One wavelength grid shared by all fibers, rebuilt from the
                # COEFF0 / COEFF1 header keywords.
                coeff0 = header["COEFF0"]
                coeff1 = header["COEFF1"]
                la = 10**(coeff0 + coeff1*np.arange(fl.shape[1]))
                la = np.broadcast_to(la, fl.shape)

        imin = abs(la-lmin).min(axis=0).argmin()
        imax = abs(la-lmax).min(axis=0).argmin()

        la = la[:, imin:imax]
        fl = fl[:, imin:imax]
        iv = iv[:, imin:imax]
        wd = wd[:, imin:imax]

        # Dispersions below 1e-5 are replaced by 2.
        wd[wd < 1e-5] = 2.
        ii = np.arange(la.shape[1])
        di = ii-ii[:, None]
        di2 = di**2
        ndiag = int(4*np.ceil(wd.max())+1)
        nbins = wd.shape[1]

        for f in fs:
            # Row index of this fiber in the image arrays.
            i = (f-1)
            if use_frames:
                i = i%500
            if useThingid:
                t = fiberid2thingid[infile][f-1]
            else:
                # Use a Python int so that adding the large plate term cannot
                # overflow a narrow NumPy integer type.
                t = platemjdfiber2targetid(plate, mjd, int(f))
            if wanted is not None and t not in wanted:
                continue

            if t not in dic_spectra:
                dic_spectra[t] = []
                bricknames[t] = '{}-{}'.format(plate, mjd)

            ## build resolution from wdisp
            R = _resolution_from_wdisp(wd[i], di2, ndiag, nbins)
            if cache_Rcsr:
                Rcsr = R.tocsr()
            else:
                Rcsr = None
            dic_spectra[t].append(Spectrum(la[i], fl[i], iv[i], R, Rcsr))

        print("DEBUG: read {} ".format(infile))

    if targetids is None:
        targetids = sorted(dic_spectra.keys())
    else:
        # A target requested more than once is still returned only once.
        targetids = sorted(set(targetids))

    source_names = ", ".join(os.path.basename(name) for name in infiles)

    targets = []
    found_ids = []
    for targetid in targetids:
        spectra = dic_spectra.get(targetid, [])
        if not spectra:
            print('ERROR: Target {} on {} has no good spectra'.format(targetid, source_names))
            continue
        # Add the brickname to the meta dictionary.  The keys of this
        # dictionary will end up as extra columns in the output ZBEST HDU.
        tmeta = dict()
        tmeta["BRICKNAME"] = bricknames[targetid]
        tmeta["BRICKNAME_datatype"] = "S8"
        targets.append(Target(targetid, spectra, coadd=coadd, meta=tmeta))
        found_ids.append(targetid)

    #- Create a metadata table in case we might want to add other columns
    #- in the future.  Only targets that were actually built are listed, so
    #- the table rows line up with the returned targets.
    metatable = Table()
    metatable['TARGETID'] = found_ids
    # NOTE(review): dtype 'S8' truncates "plate-mjd" strings longer than
    # eight characters; kept as in the original.
    bx = np.array([bricknames[t] for t in found_ids], dtype='S8')
    metatable['BRICKNAME'] = bx

    return targets, metatable
def rrboss(options=None, comm=None):
    """Estimate redshifts for BOSS targets.

    This loads targets serially and copies them into a DistTargets class.
    It then runs redshift fitting and writes the output to a catalog.

    Invalid option combinations are reported and stop the run: through
    ``comm.Abort()`` when a communicator is given, through ``sys.exit(1)``
    otherwise.

    Args:
        options (list): optional list of commandline options to parse.
        comm (mpi4py.Comm): MPI communicator to use.

    """
    global_start = elapsed(None, "", comm=comm)

    parser = argparse.ArgumentParser(description="Estimate redshifts from"
        " BOSS target spectra.")

    parser.add_argument("--spplate", type=str, default=None,
        required=True, help="input plate files", nargs='*')

    parser.add_argument("-t", "--templates", type=str, default=None,
        required=False, help="template file or directory")

    parser.add_argument("--archetypes", type=str, default=None,
        required=False,
        help="archetype file or directory for final redshift comparisons")

    parser.add_argument("-o", "--output", type=str, default=None,
        required=False, help="output file")

    parser.add_argument("--zbest", type=str, default=None,
        required=False, help="output zbest FITS file")

    parser.add_argument("--targetids", type=str, default=None,
        required=False, help="comma-separated list of target IDs")

    parser.add_argument("--mintarget", type=int, default=None,
        required=False, help="first target to process")

    parser.add_argument("--priors", type=str, default=None,
        required=False, help="optional redshift prior file")

    parser.add_argument("--chi2-scan", type=str, default=None,
        required=False, help="Load the chi2-scan from the input file")

    parser.add_argument("-n", "--ntargets", type=int,
        required=False, help="the number of targets to process")

    parser.add_argument("--nminima", type=int, default=3,
        required=False, help="the number of redshift minima to search")

    parser.add_argument("--allspec", default=False, action="store_true",
        required=False, help="use individual spectra instead of coadd")

    parser.add_argument("--mp", type=int, default=0,
        required=False, help="if not using MPI, the number of multiprocessing"
            " processes to use (defaults to half of the hardware threads)")

    parser.add_argument("--use-frames", default=False, action="store_true",
        required=False, help="use individual spcframes instead of spplate "
        "(the spCFrame files are expected to be in the same directory as "
        "the spPlate)")

    parser.add_argument("--use-andmask", default=False, action="store_true",
        required=False,
        help="uses and_mask values to set masked pixel's ivar to zero")

    parser.add_argument("--no-mpi-abort", default=False, action="store_true",
        required=False,
        help="Do not call MPI Abort upon failure of a single rank")

    parser.add_argument("--debug", default=False, action="store_true",
        required=False, help="debug with ipython (only if communicator has a "
        "single process)")

    if options is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(options)

    comm_size = 1
    comm_rank = 0
    if comm is not None:
        comm_size = comm.size
        comm_rank = comm.rank

    def _reject_arguments(message):
        # Report an unusable option combination and stop the whole job.
        print(message)
        sys.stdout.flush()
        if comm is not None:
            comm.Abort()
        else:
            sys.exit(1)

    # Check arguments- all processes have this, so just check on the first
    # process
    if comm_rank == 0:
        if args.debug and comm_size != 1:
            _reject_arguments(
                "--debug can only be used if the communicator has one "
                " process")

        if (args.output is None) and (args.zbest is None):
            _reject_arguments("--output or --zbest required")

        if (args.targetids is not None) and ((args.mintarget is not None)
                or (args.ntargets is not None)):
            _reject_arguments("cannot select targets by both ID and range")

    targetids = None
    if args.targetids is not None:
        targetids = [ int(x) for x in args.targetids.split(",") ]

    n_targets = None
    if args.ntargets is not None:
        n_targets = args.ntargets

    first_target = None
    if args.mintarget is not None:
        first_target = args.mintarget
    elif n_targets is not None:
        first_target = 0

    # Multiprocessing processes to use if MPI is disabled.
    mpprocs = 0
    if comm is None:
        mpprocs = get_mp(args.mp)
        print("Running with {} processes".format(mpprocs))
        if "OMP_NUM_THREADS" in os.environ:
            nthread = int(os.environ["OMP_NUM_THREADS"])
            if nthread != 1:
                print("WARNING:  {} multiprocesses running, each with "
                    "{} threads ({} total)".format(mpprocs, nthread,
                    mpprocs*nthread))
                print("WARNING:  Please ensure this is <= the number of "
                    "physical cores on the system")
        else:
            print("WARNING:  using multiprocessing, but the OMP_NUM_THREADS")
            print("WARNING:  environment variable is not set- your system may")
            print("WARNING:  be oversubscribed.")
        sys.stdout.flush()
    elif comm_rank == 0:
        print("Running with {} processes".format(comm_size))
        sys.stdout.flush()

    try:
        # Load and distribute the targets
        if comm_rank == 0:
            print("Loading targets...")
            sys.stdout.flush()

        start = elapsed(None, "", comm=comm)

        # Read the spectra.  This call is not guarded by rank, so every
        # process executes it.  Currently the "meta" Table returned here is
        # not propagated to the output zbest file.  However, that could be
        # changed to work like the DESI write_zbest() function.  Each target
        # contains metadata which is propagated to the output zbest table
        # though.
        targets, meta = read_spectra(args.spplate, targetids=targetids,
            use_frames=args.use_frames, coadd=(not args.allspec),
            cache_Rcsr=True, use_andmask=args.use_andmask)

        # Apply the requested target range.  --mintarget alone selects
        # everything from that index on.
        if first_target is not None:
            if n_targets is None:
                last_target = None
            else:
                last_target = first_target + n_targets
            targets = targets[first_target:last_target]
            meta = meta[first_target:last_target]

        elapsed(start, "Read of {} targets".format(len(targets)), comm=comm)

        # Distribute the targets.
        start = elapsed(None, "", comm=comm)

        dtargets = DistTargetsCopy(targets, comm=comm, root=0)

        # Get the dictionary of wavelength grids
        dwave = dtargets.wavegrids()

        elapsed(start, "Distribution of {} targets"
            .format(len(dtargets.all_target_ids)), comm=comm)

        # Read the template data
        dtemplates = load_dist_templates(dwave, templates=args.templates,
            comm=comm, mp_procs=mpprocs)

        # Compute the redshifts, including both the coarse scan and the
        # refinement.  This function only returns data on the rank 0 process.
        start = elapsed(None, "", comm=comm)

        scandata, zfit = zfind(dtargets, dtemplates, mpprocs,
            nminima=args.nminima, archetypes=args.archetypes,
            priors=args.priors, chi2_scan=args.chi2_scan)

        elapsed(start, "Computing redshifts took", comm=comm)

        # Write the outputs
        if args.output is not None:
            start = elapsed(None, "", comm=comm)
            if comm_rank == 0:
                write_zscan(args.output, scandata, zfit, clobber=True)
            elapsed(start, "Writing zscan data took", comm=comm)

        if args.zbest:
            start = elapsed(None, "", comm=comm)
            if comm_rank == 0:
                # Keep only the best fit (znum == 0) of every target.
                zbest = zfit[zfit['znum'] == 0]

                # Remove extra columns not needed for zbest
                zbest.remove_columns(['zz', 'zzchi2', 'znum'])

                # Change to upper case like DESI
                for colname in zbest.colnames:
                    if colname.islower():
                        zbest.rename_column(colname, colname.upper())

                template_version = {
                    t._template.full_type: t._template._version
                    for t in dtemplates
                }
                archetype_version = None
                if args.archetypes is not None:
                    archetypes = All_archetypes(
                        archetypes_dir=args.archetypes).archetypes
                    archetype_version = {
                        name: arch._version
                        for name, arch in archetypes.items()
                    }
                write_zbest(args.zbest, zbest, template_version,
                    archetype_version)

            elapsed(start, "Writing zbest data took", comm=comm)

    except Exception:
        # Print the traceback with the rank prefixed to every line before
        # either re-raising or taking the whole MPI job down.
        exc_type, exc_value, exc_traceback = sys.exc_info()
        lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        lines = [ "Proc {}: {}".format(comm_rank, x) for x in lines ]
        print("--- Process {} raised an exception ---".format(comm_rank))
        print("".join(lines))
        sys.stdout.flush()
        if comm is None or args.no_mpi_abort:
            raise
        else:
            comm.Abort()

    elapsed(global_start, "Total run time", comm=comm)

    if args.debug:
        import IPython
        IPython.embed()