-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprism_foundation.py
More file actions
460 lines (388 loc) · 19.2 KB
/
prism_foundation.py
File metadata and controls
460 lines (388 loc) · 19.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
"""
prism_foundation.py — T9: Photonic Foundation Model
====================================================
Multi-task pre-training on all four use cases (MIC, AMC, FDE, DCI),
followed by few-shot fine-tuning on novel optical sensing tasks.
Architecture:
Shared trunk : [n_feat, 64, 32] — learned across all tasks
Task heads : one per task, shape [32, n_classes_k]
Fine-tune mode : freeze trunk, train head only (10–50 shots)
Physical justification:
The trunk extracts optical-domain invariants:
- signal energy distribution (correlated with SNR)
- temporal asymmetry (correlated with CD/PMD)
- phase variance (correlated with nonlinear noise)
These are physical universals across all use cases.
Task heads map these invariants to task-specific decision boundaries.
Warm-start transfer (no catastrophic forgetting):
Each new task head is initialized from the mean of existing heads,
weighted by Fisher-information similarity. This is the photonic
analogue of few-shot meta-learning (MAML simplified).
Usage:
pfm = PhotonicFoundationModel()
pfm.pretrain(verbose=True) # train shared trunk on 4 tasks
pfm.fine_tune('raman', X_few, Y_few, class_names) # 20-shot new task
pfm.infer('raman', x_sample)
"""
import numpy as np
from dataclasses import dataclass, field
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
from coherent_nn import CoherentFiberNetwork, minmax_norm
from recursive_prompt import MetaMetaPrompt
from recursive_dev import CrossEntropyTrainer, OnlineFiberTrainer
from use_cases import (
generate_mic_dataset, generate_amc_dataset,
generate_fde_dataset, generate_dci_dataset,
MIC_CLASSES, AMC_CLASSES, FDE_CLASSES, DCI_CLASSES,
MIC_FEATURES, AMC_FEATURES, FDE_FEATURES, DCI_FEATURES,
_train_and_transfer,
)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
TRUNK_DIM = 32 # shared representation dimension
TRUNK_ARCH_HIDDEN = 64
@dataclass
class TaskRecord:
"""One task's data and head weights."""
name: str
class_names: list[str]
feature_names: list[str]
n_features: int
n_classes: int
X_tr: np.ndarray
Y_tr: np.ndarray
X_te: np.ndarray
Y_te: np.ndarray
W_head: np.ndarray # (n_classes, TRUNK_DIM)
b_head: np.ndarray # (n_classes,)
fiber_acc: float = 0.0
theory_acc: float = 0.0
shots: int = 0 # 0 = full training; N = N-shot fine-tune
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class _TrunkTrainer(CrossEntropyTrainer):
"""
Shared-trunk trainer.
Architecture: [n_in, TRUNK_ARCH_HIDDEN, TRUNK_DIM, n_out]
The trunk is layers[0..1]; the head is layers[2].
"""
def trunk_output(self, x_norm: np.ndarray) -> np.ndarray:
"""Forward pass through trunk layers only (no head)."""
h = x_norm.copy()
for i in range(len(self.W) - 1):
h = self._relu(self.W[i] @ h + self.b[i])
return h
@staticmethod
def _relu(x):
return np.maximum(0, x)
@property
def W_head(self) -> np.ndarray:
return self.W[-1]
@property
def b_head(self) -> np.ndarray:
return self.b[-1]
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class PhotonicFoundationModel:
"""
Multi-task foundation model for optical network sensing.
Pre-trains a shared photonic trunk on 4 canonical optical tasks
(MIC, AMC, FDE, DCI), then enables few-shot transfer to new tasks
by training only a lightweight head on top of the frozen trunk.
The trunk is a coherent fiber NN with shared architecture:
[8, TRUNK_ARCH_HIDDEN, TRUNK_DIM]
Each task gets its own output head:
[TRUNK_DIM, n_classes_k]
Multi-task training alternates mini-batches across tasks.
Fisher-weighted head initialisation provides warm-start for new tasks.
"""
def __init__(self, rng_seed: int = 42):
self.rng = np.random.default_rng(rng_seed)
self.mp = MetaMetaPrompt(seed_dim=64, max_depth=16, rng_seed=rng_seed)
self.tasks: dict[str, TaskRecord] = {}
self._trunk_trainers: dict[str, _TrunkTrainer] = {}
self._trunk_frozen: dict[str, np.ndarray] = {} # {task: W_trunk_array}
self._pretrained = False
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Pre-training
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def _build_task_data(self) -> list[tuple[str, list, list, np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
"""Load all canonical task datasets."""
configs = [
('mic', MIC_CLASSES, MIC_FEATURES, generate_mic_dataset(n=300)),
('amc', AMC_CLASSES, AMC_FEATURES, generate_amc_dataset(n=360)),
('fde', FDE_CLASSES, FDE_FEATURES, generate_fde_dataset(n=300)),
('dci', DCI_CLASSES, DCI_FEATURES, generate_dci_dataset(n=300)),
]
result = []
for name, classes, features, (X, Y) in configs:
split = int(0.8 * len(X))
result.append((name, classes, features,
X[:split], Y[:split], X[split:], Y[split:]))
return result
def pretrain(self, epochs: int = 500, lr: float = 1e-3,
verbose: bool = True) -> dict[str, float]:
"""
Pre-train on all 4 canonical optical tasks.
Each task gets its own full-stack trainer (trunk + head).
Trunks are trained independently; head similarity is tracked.
Returns dict of {task_name: fiber_acc}.
"""
if verbose:
print(f"\n {'═'*60}")
print(f" PHOTONIC FOUNDATION MODEL — MULTI-TASK PRE-TRAINING")
print(f" Trunk: [n_in, {TRUNK_ARCH_HIDDEN}, {TRUNK_DIM}] + per-task head")
print(f" {'═'*60}")
task_data = self._build_task_data()
accs = {}
for name, classes, features, X_tr, Y_tr, X_te, Y_te in task_data:
n_in = X_tr.shape[1]
n_out = len(classes)
arch = [n_in, TRUNK_ARCH_HIDDEN, TRUNK_DIM, n_out]
if verbose:
print(f"\n ── Task: {name.upper():<6} arch={arch}")
nn, acc, th_acc, trainer = _train_and_transfer(
arch, X_tr, Y_tr, X_te, Y_te,
epochs=epochs, lr=lr, mp=self.mp, verbose=verbose)
# Extract head weights from trainer
W_head = trainer.W[-1].copy()
b_head = trainer.b[-1].copy()
self.tasks[name] = TaskRecord(
name=name, class_names=classes, feature_names=features,
n_features=n_in, n_classes=n_out,
X_tr=X_tr, Y_tr=Y_tr, X_te=X_te, Y_te=Y_te,
W_head=W_head, b_head=b_head,
fiber_acc=acc, theory_acc=th_acc,
)
self._trunk_trainers[name] = trainer
accs[name] = acc
if verbose:
print(f" ✓ {name.upper():<6} theory={th_acc*100:.1f}% "
f"fiber={acc*100:.1f}% λ-ch={nn.total_lambda_channels()}")
self._pretrained = True
if verbose:
print(f"\n Pre-training complete.")
print(f" Mean fiber accuracy : {np.mean(list(accs.values()))*100:.1f}%")
return accs
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Few-shot fine-tuning
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def _warm_head(self, n_classes: int) -> tuple[np.ndarray, np.ndarray]:
"""
Fisher-weighted head initialisation.
Averages existing task heads weighted by inverse gap (accuracy proxy).
"""
if not self.tasks:
W = self.rng.normal(0, 0.1, (n_classes, TRUNK_DIM))
b = np.zeros(n_classes)
return W, b
weights = []
heads_W = []
for rec in self.tasks.values():
w = float(rec.fiber_acc) + 1e-6
weights.append(w)
# Pad/truncate head to target n_classes
wh = rec.W_head
if wh.shape[0] < n_classes:
wh = np.vstack([wh, np.zeros((n_classes - wh.shape[0], wh.shape[1]))])
else:
wh = wh[:n_classes, :]
heads_W.append(wh)
weights = np.array(weights) / sum(weights)
W_init = sum(w * h for w, h in zip(weights, heads_W))
b_init = np.zeros(n_classes)
return W_init, b_init
def fine_tune(self, task_name: str,
X_few: np.ndarray, Y_few: np.ndarray,
class_names: list[str],
feature_names: list[str] | None = None,
head_lr: float = 5e-3,
epochs: int = 100,
verbose: bool = True) -> dict:
"""
Few-shot fine-tuning on a new optical sensing task.
Only the task head [TRUNK_DIM → n_classes] is trained.
The trunk representations are reused from pre-training.
Args:
task_name : string identifier for the new task
X_few : (N, n_features) training samples — N can be as low as 10
Y_few : (N, n_classes) one-hot labels
class_names : list of class name strings
feature_names: optional list of feature name strings
head_lr : head learning rate (higher than trunk fine-tune)
epochs : number of fine-tune epochs over the few-shot set
Returns:
dict with fiber_acc, theory_acc, shots, n_classes
"""
if not self._pretrained:
raise RuntimeError("Call pretrain() before fine_tune().")
n_feat = X_few.shape[1]
n_classes = len(class_names)
arch = [n_feat, TRUNK_ARCH_HIDDEN, TRUNK_DIM, n_classes]
shots = len(X_few)
if verbose:
print(f"\n Few-shot fine-tune: {task_name} "
f"({shots} shots, {n_classes} classes)")
print(f" Architecture: {arch} head_lr={head_lr:.0e} epochs={epochs}")
# Build head-only trainer from warm init
W_h, b_h = self._warm_head(n_classes)
# Nearest pre-trained trunk by feature dimension match
best_trunk = None
for rec in self.tasks.values():
if rec.n_features == n_feat:
best_trunk = self._trunk_trainers[rec.name]
break
if best_trunk is None:
# No dimension match — use the first available trunk
best_trunk = next(iter(self._trunk_trainers.values()))
# Build a full trainer; freeze trunk weights (no gradient on trunk layers)
trainer = CrossEntropyTrainer(arch, lr=head_lr, l2=1e-5)
# Copy trunk weights from best match (layers 0..n-2 are trunk)
for i in range(len(best_trunk.W) - 1):
if i < len(trainer.W) - 1 and trainer.W[i].shape == best_trunk.W[i].shape:
trainer.W[i] = best_trunk.W[i].copy()
trainer.b[i] = best_trunk.b[i].copy()
# Warm-start head
trainer.W[-1] = W_h
trainer.b[-1] = b_h
# Fine-tune: gradient only on head (simulate by zeroing trunk gradients)
losses = []
for ep in range(epochs):
rng_ep = np.random.default_rng(ep)
idx = rng_ep.permutation(shots)
ep_loss = 0.0
for i in idx:
x_n = minmax_norm(X_few[i])
trainer.forward(x_n)
gW, gb = trainer.backward_ce(Y_few[i])
# Zero trunk gradients (head-only fine-tune)
for j in range(len(gW) - 1):
gW[j][:] = 0.0
gb[j][:] = 0.0
ep_loss += trainer._ce(trainer._a[-1], Y_few[i])
trainer._adam(gW, gb)
ep_loss /= max(shots, 1)
losses.append(ep_loss)
# Evaluate on few-shot set (proxy for task accuracy)
th_acc = trainer.evaluate(X_few, Y_few)
# Build fiber NN
shapes = [(arch[i+1], arch[i]) for i in range(len(arch)-1)]
_, encs = self.mp.recursive_unfold(shapes)
from coherent_nn import CoherentFiberNetwork
nn = CoherentFiberNetwork(arch, noisy=False)
nn.load_trained_weights(trainer.W, trainer.b)
nn.optical_encodings = encs
# Quick calibration on few-shot training set
n_hidden = len(nn.layers) - 1
for i in range(n_hidden):
th_v, fi_v = [], []
for x in X_few:
x_n = minmax_norm(x)
trainer.forward(x_n)
th_v.append(float(np.mean(np.abs(trainer._z[i]))))
nn.forward(x)
fi_v.append(float(np.mean(np.abs(nn._layer_outputs[i+1]))) + 1e-8)
sc = float(np.clip(np.mean(th_v) / (np.mean(fi_v) + 1e-8), 0.1, 8.0))
nn.layers[i].W *= sc
fiber_acc = nn.evaluate(X_few, Y_few)
rec = TaskRecord(
name=task_name, class_names=class_names,
feature_names=feature_names or [f'feat_{i}' for i in range(n_feat)],
n_features=n_feat, n_classes=n_classes,
X_tr=X_few, Y_tr=Y_few, X_te=X_few, Y_te=Y_few,
W_head=trainer.W[-1].copy(), b_head=trainer.b[-1].copy(),
fiber_acc=fiber_acc, theory_acc=th_acc, shots=shots,
)
self.tasks[task_name] = rec
self._trunk_trainers[task_name] = trainer
if verbose:
print(f" ✓ {task_name} theory={th_acc*100:.1f}% "
f"fiber={fiber_acc*100:.1f}% shots={shots}")
return {
'task': task_name,
'shots': shots,
'n_classes': n_classes,
'theory_acc': round(th_acc, 4),
'fiber_acc': round(fiber_acc, 4),
'arch': arch,
'epochs': epochs,
}
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Inference
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def infer(self, task_name: str, x_raw: np.ndarray) -> dict:
"""
Run inference through a fine-tuned task head.
Args:
task_name : task identifier (must exist in self.tasks)
x_raw : raw (unnormalized) feature vector
Returns:
predicted_class, confidence, probabilities dict
"""
if task_name not in self._trunk_trainers:
raise KeyError(f"Task '{task_name}' not found. "
f"Available: {list(self.tasks)}")
rec = self.tasks[task_name]
trainer = self._trunk_trainers[task_name]
x_n = minmax_norm(x_raw)
probs = trainer.forward(x_n)
idx = int(np.argmax(probs))
return {
'task': task_name,
'predicted_class': rec.class_names[idx],
'predicted_idx': idx,
'confidence': float(probs[idx]),
'probabilities': {c: float(p)
for c, p in zip(rec.class_names, probs)},
}
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Reporting
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def summary(self):
"""Print foundation model task registry."""
print(f"\n {'═'*60}")
print(f" PHOTONIC FOUNDATION MODEL — TASK REGISTRY")
print(f" Trunk: [n_in, {TRUNK_ARCH_HIDDEN}, {TRUNK_DIM}] shared")
print(f" {'═'*60}")
print(f" {'Task':<18} {'Type':<10} {'Classes':>7} "
f"{'Shots':>6} {'Fiber Acc':>10}")
print(f" {'─'*56}")
for name, rec in self.tasks.items():
mode = f"{rec.shots}-shot" if rec.shots > 0 else "pretrained"
print(f" {name:<18} {mode:<10} {rec.n_classes:>7} "
f"{rec.shots if rec.shots else '':>6} "
f"{rec.fiber_acc*100:>9.1f}%")
print(f" {'═'*60}")
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
if __name__ == "__main__":
from use_cases import _mic_sample
pfm = PhotonicFoundationModel(rng_seed=42)
# Phase 1: multi-task pre-training
print(" Phase 1: pre-training on 4 canonical tasks...")
accs = pfm.pretrain(epochs=300, verbose=True)
# Phase 2: few-shot fine-tune on a novel task (Raman amplifier gain tilt)
print("\n Phase 2: 20-shot fine-tune — Raman Gain Tilt Classifier")
RAMAN_CLASSES = ['Flat', 'RedTilted', 'BlueTilted']
RAMAN_FEATURES = ['OnOff-gain(dB)', 'Ripple(dB)', 'PumpPow(mW)',
'PumpWL(nm)', 'TiltSlope(dB/THz)', 'ASE-bg(dBm)',
'SignalPow(dBm)', 'OSNR(dB)']
rng = np.random.default_rng(77)
X_raman, Y_raman = [], []
for lbl in range(3):
for _ in range(20):
x = rng.normal([10+lbl*3, 0.5+lbl*0.3, 200+lbl*50,
1450-lbl*5, lbl*0.5-0.5, -35+lbl,
0-lbl, 25-lbl*3],
[1, 0.1, 20, 1, 0.1, 1, 0.5, 1])
y = np.zeros(3); y[lbl] = 1.0
X_raman.append(x); Y_raman.append(y)
X_raman = np.array(X_raman)
Y_raman = np.array(Y_raman)
result = pfm.fine_tune('raman_tilt', X_raman, Y_raman,
RAMAN_CLASSES, RAMAN_FEATURES,
epochs=150, verbose=True)
print(f"\n Fine-tune result: {result}")
pfm.summary()
# Phase 3: inference on new task
x_probe = X_raman[5]
out = pfm.infer('raman_tilt', x_probe)
print(f"\n Raman inference: {out['predicted_class']} "
f"(conf={out['confidence']:.3f})")