-
Notifications
You must be signed in to change notification settings - Fork 28
Expand file tree
/
Copy pathtrace_model.py
More file actions
667 lines (570 loc) · 23.8 KB
/
Copy pathtrace_model.py
File metadata and controls
667 lines (570 loc) · 23.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
"""
Trace data model for PyReduce.
This module defines the Trace dataclass and I/O functions for storing
trace positions, curvature, and wavelength calibration in FITS format.
The Trace dataclass consolidates what was previously scattered across
separate files (traces.npz, curve.npz, wavecal.npz) into a single structure.
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass
from pathlib import Path
import astropy.io.fits as fits
import numpy as np
logger = logging.getLogger(__name__)
# Format version for backwards compatibility detection
# v2: Initial FITS format with FIBER column
# v3: Renamed FIBER→GROUP, added FIBER_IDX column
# v4: Added BUNDLE column (bundle id, independent of m)
FORMAT_VERSION = 4
@dataclass
class Trace:
    """Container for a single trace's geometry and calibration data.

    A trace represents a single spectral order (or fiber within an order)
    on the detector.

    Attributes
    ----------
    m : int | None
        Spectral order number (diffraction order). This is the physical order
        number from the grating equation, not a sequential index. In echelle
        spectrographs, higher order numbers correspond to shorter wavelengths.

        The order number is assigned in one of three ways:

        1. **From order_centers.yaml** (preferred): If the instrument provides
           an ``order_centers_{channel}.yaml`` file with known order positions,
           traces are matched to these centers during detection and assigned
           the corresponding order numbers immediately.
        2. **From wavelength calibration**: If no order_centers file exists,
           ``m`` is initially None. During wavelength calibration, the linelist
           file provides ``obase`` (the base order number). Each trace is then
           assigned ``m = obase + trace_index``.
        3. **Sequential fallback**: For legacy files or MOSAIC mode where order
           identity cannot be determined, ``m`` may remain None or be assigned
           sequentially from 0.

        Note that ``Trace.wlen()`` does **not** use ``m`` when evaluating a
        2D wavelength polynomial: the polynomial is fitted against trace
        indices (0, 1, 2, ...), so ``_wave_idx`` is used as the second
        coordinate instead.
    bundle : int | None
        Bundle identifier (1-indexed), independent of `m`. Used by
        instruments where fibers are organised into spatial bundles
        within each spectral order (MOSAIC: 90 bundles × 7 fibers; in
        principle also bundled echelle: m orders × bundles). For
        instruments without a bundle concept (ANDES groups, single-fiber
        echelles), bundle stays None. Drives the bundle group name
        ``f"bundle_{bundle}"``.
    group : str | int | None
        Group identifier, or None if trace is ungrouped. When set, indicates
        this trace is the result of grouping/merging fibers for this order.
        There should be exactly one trace per (m, bundle, group). String
        for named groups ('A', 'B', 'cal') or bundle merges
        ('bundle_45'); int kept for legacy compatibility.

        **Mutually exclusive with fiber_idx.** A trace has either:

        - group set (merged/grouped result) and fiber_idx=None, or
        - fiber_idx set (individual fiber) and group=None, or
        - both None (ungrouped single-fiber instrument)
    fiber_idx : int | None
        Physical fiber index (1-indexed) within (m, bundle). For
        multi-fiber instruments where fibers are tracked individually
        (not merged). Used for per-fiber wavelength calibration. There
        should be exactly one trace per (m, bundle, fiber_idx).

        **Mutually exclusive with group.** See group docstring for details.
    pos : np.ndarray
        y(x) trace position polynomial coefficients, shape (deg+1,).
        Coefficients in numpy.polyval order (highest power first).
    column_range : tuple[int, int]
        Valid x range [start, end) for this trace.
    height : float | None
        Extraction aperture height in pixels. None to use settings default.
    slit : np.ndarray | None
        Slit curvature coefficients, shape (deg_y+1, deg_x+1).
        Evaluates to x_offset = P(y) where P's coefficients vary with x.
        slit[i, :] are coefficients for the y^i term as a function of x,
        in numpy.polyval order (highest power of x first).
    slitdelta : np.ndarray | None
        Per-row slit correction, shape (height_pixels,).
        Residual offsets beyond polynomial fit.
    wave : np.ndarray | None
        Wavelength polynomial coefficients. Can be:

        - 1D array, shape (deg+1,): per-trace polynomial in numpy.polyval
          order, wavelength = polyval(wave, x).
        - 2D array: global 2D polynomial shared across traces, where
          wave[i, j] is the coefficient of x^i * idx^j. Wavelength =
          polyval2d(x, idx, wave) with idx = ``_wave_idx`` (the trace
          index used in the fit, not the order number ``m``).
    _wave_idx : int | None
        Trace index used as the second coordinate when evaluating a 2D
        ``wave`` polynomial. ``wlen()`` returns None for a 2D polynomial
        while this is unset.
    invalid : str | None
        Reason why the trace should be skipped, or None if it is usable.
    """

    # Identity
    m: int | None

    # Geometry (required)
    pos: np.ndarray
    column_range: tuple[int, int]

    # Optional fields (must come after required fields)
    bundle: int | None = None
    group: str | int | None = None
    fiber_idx: int | None = None
    height: float | None = None
    slit: np.ndarray | None = None
    slitdelta: np.ndarray | None = None
    wave: np.ndarray | None = None
    _wave_idx: int | None = None  # trace index for 2D polynomial evaluation
    invalid: str | None = None  # reason if trace should be skipped

    def slit_at_x(self, x: float | np.ndarray) -> np.ndarray | None:
        """Evaluate slit polynomial coefficients at position x.

        Parameters
        ----------
        x : float or np.ndarray
            Column position(s) to evaluate at.

        Returns
        -------
        np.ndarray or None
            Coefficients c0, c1, c2, ... of the slit offset polynomial
            c0 + c1*y + c2*y^2 + ... at the given column(s), one entry per
            row of ``slit``. Shape (deg_y+1,) for scalar x, or
            (deg_y+1, len(x)) for 1D array x.
            Returns None if no slit curvature is set.
        """
        if self.slit is None:
            return None
        # slit[i, :] = coefficients for y^i term as function of x
        # Evaluate each row's polynomial at x
        return np.array([np.polyval(c, x) for c in self.slit])

    def wlen(self, x: np.ndarray) -> np.ndarray | None:
        """Evaluate wavelength polynomial at column positions.

        Parameters
        ----------
        x : np.ndarray
            Column positions to evaluate at.

        Returns
        -------
        np.ndarray or None
            Wavelength values at each x position.
            Returns None if no wavelength calibration is set, or if the
            calibration is a 2D polynomial and ``_wave_idx`` is None (a
            warning is logged in that case).
        """
        if self.wave is None:
            return None

        if self.wave.ndim == 2:
            # 2D polynomial: wave[i,j] is coeff for x^i * idx^j
            # The polynomial is fitted with trace indices (0, 1, 2, ...),
            # so we must use _wave_idx (not physical order number m).
            idx = self._wave_idx
            if idx is None:
                logger.warning(
                    "Cannot evaluate 2D wavelength polynomial: trace._wave_idx is None."
                )
                return None
            # Broadcast the scalar index to the shape of x for polyval2d.
            m_arr = np.full_like(x, idx, dtype=float)
            return np.polynomial.polynomial.polyval2d(x, m_arr, self.wave)
        else:
            # 1D polynomial: standard polyval
            return np.polyval(self.wave, x)

    def y_at_x(self, x: np.ndarray) -> np.ndarray:
        """Evaluate trace y-position at column positions.

        Parameters
        ----------
        x : np.ndarray
            Column positions to evaluate at.

        Returns
        -------
        np.ndarray
            Y positions of the trace center at each x.
        """
        return np.polyval(self.pos, x)
def _validate_traces(traces: list[Trace], context: str = "") -> None:
    """Check the consistency rules that a list of traces must satisfy.

    The checks run in this order; the first violation of checks 1-4 raises:

    1. No trace has both ``group`` and ``fiber_idx`` set.
    2. (m, bundle, group) is unique among traces with a group.
    3. A group named ``bundle_<n>`` belongs to a trace with ``bundle == n``.
    4. (m, bundle, fiber_idx) is unique among traces with a fiber index.
    5. If no trace has a group, traces should be in ascending y order at
       the mid-column of the first trace. A violation of this check only
       logs a warning; it does not raise.

    An empty list passes trivially.

    Parameters
    ----------
    traces : list[Trace]
        Traces to validate.
    context : str
        Text appended to error and warning messages (e.g., file path).

    Raises
    ------
    ValueError
        If any of checks 1-4 fails.
    """
    if not traces:
        return

    # Pass 1: group and fiber_idx must not both be set.
    for i, t in enumerate(traces):
        if t.group is None or t.fiber_idx is None:
            continue
        raise ValueError(
            f"Trace {i} has both group={t.group} and fiber_idx={t.fiber_idx}. "
            f"These are mutually exclusive: group indicates merged fiber result, "
            f"fiber_idx indicates individual fiber{context}"
        )

    # Pass 2: grouped traces need unique (m, bundle, group) keys, and a
    # bundle merge must carry the same bundle id in both fields
    # (group="bundle_5" must mean bundle=5).
    bundle_merge_re = re.compile(r"^bundle_(\d+)$")
    grouped_keys = set()
    for t in traces:
        if t.group is None:
            continue
        group_key = (t.m, t.bundle, t.group)
        if group_key in grouped_keys:
            raise ValueError(
                f"Duplicate (m={t.m}, bundle={t.bundle}, group={t.group}) "
                f"in traces{context}"
            )
        grouped_keys.add(group_key)
        hit = bundle_merge_re.match(str(t.group))
        if hit is not None and t.bundle != int(hit.group(1)):
            raise ValueError(
                f"Trace group={t.group!r} does not match bundle={t.bundle} "
                f"(m={t.m}){context}"
            )

    # Pass 3: individual fibers need unique (m, bundle, fiber_idx) keys.
    fiber_keys = set()
    for t in traces:
        if t.fiber_idx is None:
            continue
        fiber_key = (t.m, t.bundle, t.fiber_idx)
        if fiber_key in fiber_keys:
            raise ValueError(
                f"Duplicate (m={t.m}, bundle={t.bundle}, "
                f"fiber_idx={t.fiber_idx}) in traces{context}"
            )
        fiber_keys.add(fiber_key)

    # Pass 4: y ordering. Skipped as soon as any trace carries a group,
    # because grouped traces are organized by (m, group) rather than by y.
    if any(t.group is not None for t in traces):
        return

    first_range = traces[0].column_range
    ref_x = (first_range[0] + first_range[1]) // 2
    y_positions = [t.y_at_x(ref_x) for t in traces]
    neighbours = zip(y_positions, y_positions[1:])
    for i, (y_prev, y_here) in enumerate(neighbours, start=1):
        if y_here < y_prev:
            logger.warning(
                "Traces not ordered by y-position at x=%d: trace %d (y=%.1f) < trace %d (y=%.1f)%s",
                ref_x,
                i,
                y_here,
                i - 1,
                y_prev,
                context,
            )
def save_traces(
    path: str | Path,
    traces: list[Trace],
    header: fits.Header | None = None,
    steps: list[str] | None = None,
) -> None:
    """Save traces to a FITS binary table.

    The primary HDU carries the header; the traces are written as rows of
    a binary table extension named ``TRACES``. An existing file at `path`
    is overwritten.

    Parameters
    ----------
    path : str | Path
        Output file path.
    traces : list[Trace]
        Traces to save.
    header : fits.Header, optional
        FITS header to include. A copy is written, so the caller's header
        is not modified. If None, a minimal header is created.
    steps : list[str], optional
        Pipeline steps that have been run (stored in E_STEPS header).

    Raises
    ------
    ValueError
        If `traces` is empty, or if the traces fail validation (a trace
        with both group and fiber_idx, duplicate (m, bundle, group) or
        (m, bundle, fiber_idx) keys, or a ``bundle_<n>`` group that does
        not match the trace's bundle).
    """
    if not traces:
        raise ValueError("Cannot save empty trace list")

    _validate_traces(traces, f" when saving to {path}")

    header = fits.Header() if header is None else header.copy()

    # Add format metadata
    header["E_FMTVER"] = (FORMAT_VERSION, "PyReduce format version")
    if steps:
        header["E_STEPS"] = (",".join(steps), "Pipeline steps run")

    # Determine array sizes
    max_pos_deg = max(len(t.pos) for t in traces)

    # Determine wave dimensions - can be 1D (per-trace) or 2D (global poly).
    # If any trace has a 2D polynomial, only 2D polynomials are stored.
    wave_shapes = [t.wave.shape if t.wave is not None else () for t in traces]
    wave_is_2d = any(len(s) == 2 for s in wave_shapes)
    if wave_is_2d:
        max_wave_x = max((s[0] if len(s) == 2 else 0) for s in wave_shapes)
        max_wave_m = max((s[1] if len(s) == 2 else 0) for s in wave_shapes)
        max_wave_deg = 0  # Not used for 2D
    else:
        max_wave_deg = max((s[0] if len(s) >= 1 else 0) for s in wave_shapes)
        max_wave_x = max_wave_m = 0

    max_slitdelta_len = max(
        (len(t.slitdelta) if t.slitdelta is not None else 0) for t in traces
    )

    # Determine slit dimensions (deg_y+1, deg_x+1)
    slit_shapes = [(t.slit.shape if t.slit is not None else (0, 0)) for t in traces]
    max_slit_y = max(s[0] for s in slit_shapes)
    max_slit_x = max(s[1] for s in slit_shapes)

    ntrace = len(traces)

    # Build arrays. Missing integer identifiers are stored as -1, a missing
    # group as an empty string and a missing height as NaN.
    m_arr = np.array([t.m if t.m is not None else -1 for t in traces], dtype=np.int16)
    bundle_arr = np.array(
        [t.bundle if t.bundle is not None else -1 for t in traces], dtype=np.int16
    )
    group_arr = np.array(
        [str(t.group) if t.group is not None else "" for t in traces], dtype="U16"
    )
    fiber_idx_arr = np.array(
        [t.fiber_idx if t.fiber_idx is not None else -1 for t in traces], dtype=np.int16
    )
    col_range_arr = np.array([t.column_range for t in traces], dtype=np.int32)
    height_arr = np.array(
        [t.height if t.height is not None else np.nan for t in traces], dtype=np.float32
    )

    # Pad short coefficient vectors with NaN, not zeros: coefficients are
    # stored highest power first, so a trailing zero would be read back as
    # an extra constant term and silently change the polynomial.
    # load_traces() strips NaN entries from POS when reading.
    pos_arr = np.full((ntrace, max_pos_deg), np.nan, dtype=np.float64)
    for i, t in enumerate(traces):
        pos_arr[i, : len(t.pos)] = t.pos

    wave_arr = None
    if wave_is_2d and max_wave_x > 0 and max_wave_m > 0:
        # 2D wavelength polynomial
        wave_arr = np.full((ntrace, max_wave_x, max_wave_m), np.nan, dtype=np.float64)
        for i, t in enumerate(traces):
            if t.wave is not None and t.wave.ndim == 2:
                wx, wm = t.wave.shape
                wave_arr[i, :wx, :wm] = t.wave
    elif max_wave_deg > 0:
        # 1D wavelength polynomial per trace
        wave_arr = np.full((ntrace, max_wave_deg), np.nan, dtype=np.float64)
        for i, t in enumerate(traces):
            if t.wave is not None and t.wave.ndim == 1:
                wave_arr[i, : len(t.wave)] = t.wave

    slit_arr = None
    if max_slit_y > 0 and max_slit_x > 0:
        slit_arr = np.full((ntrace, max_slit_y, max_slit_x), np.nan, dtype=np.float64)
        for i, t in enumerate(traces):
            if t.slit is not None:
                sy, sx = t.slit.shape
                slit_arr[i, :sy, :sx] = t.slit

    slitdelta_arr = None
    if max_slitdelta_len > 0:
        slitdelta_arr = np.full((ntrace, max_slitdelta_len), np.nan, dtype=np.float32)
        for i, t in enumerate(traces):
            if t.slitdelta is not None:
                slitdelta_arr[i, : len(t.slitdelta)] = t.slitdelta

    # Build FITS columns
    columns = [
        fits.Column(name="M", format="I", array=m_arr),
        fits.Column(name="BUNDLE", format="I", array=bundle_arr),
        fits.Column(name="GROUP", format="16A", array=group_arr),
        fits.Column(name="FIBER_IDX", format="I", array=fiber_idx_arr),
        fits.Column(name="POS", format=f"{max_pos_deg}D", array=pos_arr),
        fits.Column(name="COL_RANGE", format="2J", array=col_range_arr),
        fits.Column(name="HEIGHT", format="E", array=height_arr),
    ]

    if slit_arr is not None:
        # Multi-dimensional cells are stored flattened; the dim string and
        # the SLIT_Y/SLIT_X keywords record the shape for the reader.
        slit_flat = slit_arr.reshape(ntrace, -1)
        columns.append(
            fits.Column(
                name="SLIT",
                format=f"{slit_flat.shape[1]}D",
                array=slit_flat,
                dim=f"({max_slit_x},{max_slit_y})",
            )
        )
        header["SLIT_Y"] = (max_slit_y, "Slit polynomial y-degree + 1")
        header["SLIT_X"] = (max_slit_x, "Slit polynomial x-degree + 1")

    if slitdelta_arr is not None:
        columns.append(
            fits.Column(
                name="SLITDELTA", format=f"{max_slitdelta_len}E", array=slitdelta_arr
            )
        )

    if wave_arr is not None:
        if wave_is_2d:
            wave_flat = wave_arr.reshape(ntrace, -1)
            columns.append(
                fits.Column(
                    name="WAVE",
                    format=f"{wave_flat.shape[1]}D",
                    array=wave_flat,
                    dim=f"({max_wave_m},{max_wave_x})",
                )
            )
            header["WAVE_X"] = (max_wave_x, "Wave polynomial x-degree + 1")
            header["WAVE_M"] = (max_wave_m, "Wave polynomial m-degree + 1")
        else:
            columns.append(
                fits.Column(name="WAVE", format=f"{max_wave_deg}D", array=wave_arr)
            )

    # Create HDU list
    primary = fits.PrimaryHDU(header=header)
    table = fits.BinTableHDU.from_columns(columns, name="TRACES")
    hdulist = fits.HDUList([primary, table])
    hdulist.writeto(path, overwrite=True, output_verify="silentfix+ignore")
    logger.info("Saved %d traces to: %s", ntrace, path)
def load_traces(path: str | Path) -> tuple[list[Trace], fits.Header]:
    """Read a list of traces back from disk.

    Files with the suffix ``.npz`` are handed to the legacy NPZ loader;
    everything else is read as a FITS file with a ``TRACES`` table
    extension. The loaded traces are validated before being returned.

    Parameters
    ----------
    path : str | Path
        Input file path (.fits or .npz).

    Returns
    -------
    traces : list[Trace]
        Loaded traces.
    header : fits.Header
        Primary FITS header (empty for NPZ files).
    """
    path = Path(path)
    if path.suffix == ".npz":
        return _load_traces_npz(path)

    with fits.open(path, memmap=False) as hdu:
        header = hdu[0].header
        fmtver = header.get("E_FMTVER", 1)
        if fmtver < 2:
            logger.warning("Loading traces from old format (version %d)", fmtver)

        table = hdu["TRACES"].data

        m_arr = table["M"]
        bundle_arr = table["BUNDLE"] if "BUNDLE" in table.dtype.names else None
        # The group column is called GROUP in newer files and FIBER in v2 files.
        group_arr = table["GROUP"] if "GROUP" in table.dtype.names else table["FIBER"]
        fiber_idx_arr = (
            table["FIBER_IDX"] if "FIBER_IDX" in table.dtype.names else None
        )
        pos_arr = table["POS"]
        col_range_arr = table["COL_RANGE"]
        height_arr = table["HEIGHT"]
        slit_arr = table["SLIT"] if "SLIT" in table.dtype.names else None
        slitdelta_arr = (
            table["SLITDELTA"] if "SLITDELTA" in table.dtype.names else None
        )
        wave_arr = table["WAVE"] if "WAVE" in table.dtype.names else None

        # Restore the (y, x) shape of the slit coefficients when the header
        # records it.
        if slit_arr is not None:
            slit_dims = (header.get("SLIT_Y", 0), header.get("SLIT_X", 0))
            if slit_dims[0] > 0 and slit_dims[1] > 0:
                slit_arr = slit_arr.reshape(-1, *slit_dims)

        # The wave column holds a 2D polynomial only if both WAVE_X and
        # WAVE_M are present and positive.
        wave_is_2d = False
        if wave_arr is not None:
            wave_dims = (header.get("WAVE_X", 0), header.get("WAVE_M", 0))
            if wave_dims[0] > 0 and wave_dims[1] > 0:
                wave_arr = wave_arr.reshape(-1, *wave_dims)
                wave_is_2d = True

        traces = []
        for i, m_raw in enumerate(m_arr):
            # Negative identifiers are the on-disk marker for "not set".
            m = int(m_raw) if m_raw >= 0 else None

            # An empty label or "0" means no group (backward compat);
            # numeric labels come back as int, anything else stays a string.
            label = group_arr[i].strip()
            if label in ("", "0"):
                group = None
            else:
                try:
                    group = int(label)
                except ValueError:
                    group = label

            fiber_idx = None
            if fiber_idx_arr is not None and fiber_idx_arr[i] >= 0:
                fiber_idx = int(fiber_idx_arr[i])

            bundle = None
            if bundle_arr is not None and bundle_arr[i] >= 0:
                bundle = int(bundle_arr[i])

            # Drop NaN entries from the position coefficients; zeros are
            # kept as they are.
            pos = pos_arr[i]
            pos_is_nan = np.isnan(pos)
            if pos_is_nan.any():
                pos = pos[~pos_is_nan]

            column_range = (int(col_range_arr[i, 0]), int(col_range_arr[i, 1]))

            height = None
            if not np.isnan(height_arr[i]):
                height = float(height_arr[i])

            slit = None
            if slit_arr is not None:
                cell = slit_arr[i]
                cell_nan = np.isnan(cell)
                if not cell_nan.all():
                    # Keep only rows/columns that hold at least one value.
                    rows = ~cell_nan.all(axis=1)
                    cols = ~cell_nan.all(axis=0)
                    slit = cell[rows][:, cols]

            slitdelta = None
            if slitdelta_arr is not None:
                cell = slitdelta_arr[i]
                cell_nan = np.isnan(cell)
                if not cell_nan.all():
                    slitdelta = cell[~cell_nan]

            wave = None
            if wave_arr is not None:
                cell = wave_arr[i]
                cell_nan = np.isnan(cell)
                if cell_nan.all():
                    wave = None
                elif wave_is_2d:
                    # 2D polynomial - keep rows/columns with at least one value
                    rows = ~cell_nan.all(axis=1)
                    cols = ~cell_nan.all(axis=0)
                    wave = cell[rows][:, cols]
                else:
                    # 1D polynomial - drop the NaN entries
                    wave = cell[~cell_nan]

            traces.append(
                Trace(
                    m=m,
                    bundle=bundle,
                    group=group,
                    fiber_idx=fiber_idx,
                    pos=pos,
                    column_range=column_range,
                    height=height,
                    slit=slit,
                    slitdelta=slitdelta,
                    wave=wave,
                )
            )

        # _wave_idx is not stored in the file. Traces with a 2D wave
        # polynomial get a running index, counted separately per group,
        # in file order.
        if wave_is_2d:
            next_index: dict = {}
            for t in traces:
                if t.wave is None or t.wave.ndim != 2:
                    continue
                t._wave_idx = next_index.get(t.group, 0)
                next_index[t.group] = t._wave_idx + 1

    logger.info("Loaded %d traces from: %s", len(traces), path)
    _validate_traces(traces, f" loaded from {path}")
    return traces, header
def _load_traces_npz(path: Path) -> tuple[list[Trace], fits.Header]:
    """Load traces from legacy NPZ format.

    This handles the old format where traces, column_range, and heights
    were stored as separate arrays without order/fiber identity.

    Parameters
    ----------
    path : Path
        Input NPZ file path.

    Returns
    -------
    traces : list[Trace]
        Loaded traces. ``m`` is assigned sequentially from 0; bundle,
        group and fiber_idx are left unset.
    header : fits.Header
        Empty header.

    Raises
    ------
    KeyError
        If the archive lacks ``column_range`` or both of ``traces`` and
        ``orders``.
    """
    # NOTE(security): allow_pickle=True unpickles object arrays, which can
    # execute arbitrary code. Only load NPZ files from trusted sources.
    # Presumably it is needed because legacy files may store ``heights`` as
    # a pickled None (see the ndim == 0 check below) - confirm before
    # changing it.
    #
    # The context manager closes the archive's file handle; the arrays are
    # read into memory while it is still open.
    with np.load(path, allow_pickle=True) as data:
        # Handle old 'orders' key name
        if "orders" in data and "traces" not in data:
            trace_coeffs = data["orders"]
        else:
            trace_coeffs = data["traces"]
        column_range = data["column_range"]
        # Heights may or may not be present
        heights = data["heights"] if "heights" in data else None

    # A 0-d array holds a single scalar/object, not per-trace heights.
    if heights is not None and heights.ndim == 0:
        heights = None

    traces = []
    for i, coeffs in enumerate(trace_coeffs):
        height = None
        if heights is not None and not np.isnan(heights[i]):
            height = float(heights[i])
        traces.append(
            Trace(
                m=i,  # Sequential order number (no identity preserved)
                pos=coeffs,
                column_range=(int(column_range[i, 0]), int(column_range[i, 1])),
                height=height,
            )
        )

    logger.info("Loaded %d traces from legacy NPZ: %s", len(traces), path)
    _validate_traces(traces, f" loaded from {path}")
    return traces, fits.Header()