-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmlp.py
More file actions
30 lines (26 loc) · 927 Bytes
/
mlp.py
File metadata and controls
30 lines (26 loc) · 927 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch.nn as nn
class FeedForward(nn.Module):
    """
    Two-layer feedforward network with configurable activation and dropout.

    Computes ``Dropout(Linear(act(Linear(x))))``. The hidden width is
    ``4 * n_embd`` and the output width equals the input width ``n_embd``.

    Accepts a config object with attributes:
      - n_embd: input/output feature size (last dimension of ``x``).
      - dropout: dropout probability applied to the network output.
      - bias: whether both Linear layers use a bias term.
      - activation (optional): 'gelu' [default] or 'relu', case-insensitive.
        A missing attribute or a value of ``None`` selects 'gelu'.

    Raises:
        ValueError: if ``activation`` is not one of the supported names
            (including non-string values).
    """

    # Map names to activation *classes* so only the selected module is built.
    _ACTIVATIONS = {
        'gelu': nn.GELU,
        'relu': nn.ReLU,
    }

    def __init__(self, config):
        super().__init__()
        act_name = getattr(config, 'activation', None)
        if act_name is None:
            # Treat an explicit ``activation=None`` the same as an absent one,
            # instead of crashing on ``None.lower()``.
            act_name = 'gelu'
        act_cls = (
            self._ACTIVATIONS.get(act_name.lower())
            if isinstance(act_name, str)
            else None
        )
        if act_cls is None:
            raise ValueError(f"Unsupported activation: {act_name}")

        hidden = 4 * config.n_embd
        # Layer order (net.0 .. net.3) is kept stable so state_dict keys match
        # checkpoints saved with earlier versions of this module.
        self.net = nn.Sequential(
            nn.Linear(config.n_embd, hidden, bias=config.bias),
            act_cls(),
            nn.Linear(hidden, config.n_embd, bias=config.bias),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        """Apply the feedforward network to ``x`` (last dim must be ``n_embd``)."""
        return self.net(x)