WeightsLab is a powerful tool for editing and inspecting data & AI models.
WeightsLab addresses critical AI research challenges:
- Dataset insights & optimization
- Overfitting and training plateau
- Over/Under parameterization
The granular statistics and interactive paradigm enable powerful workflows:
- Monitor granular insights on data samples, signals, and weight parameters
- Use the AI agent to:
  - Create slices of data and discard them for the next training iteration
  - Discard low-quality samples from training data
- Iterative pruning or growing of the architectures (upcoming feature)
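To make the "discard low-quality samples" workflow concrete, here is a minimal, framework-independent sketch. It is not the WeightsLab API: the function name `select_samples_to_discard`, the sample ids, and the loss values are all hypothetical, and it treats a high per-sample loss as a stand-in for "low quality", which is only one possible criterion.

```python
from typing import Dict, List


def select_samples_to_discard(per_sample_loss: Dict[int, float], fraction: float) -> List[int]:
    """Return ids of the highest-loss samples, covering `fraction` of the dataset.

    Hypothetical helper for illustration only; it does not call WeightsLab.
    """
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must be between 0 and 1")
    n_discard = int(len(per_sample_loss) * fraction)
    # Sort ids by loss, highest first; ties are broken by sample id for determinism.
    ranked = sorted(per_sample_loss, key=lambda sid: (-per_sample_loss[sid], sid))
    return sorted(ranked[:n_discard])


# Made-up per-sample losses keyed by sample id.
losses = {0: 0.12, 1: 2.75, 2: 0.08, 3: 0.31, 4: 3.10,
          5: 0.22, 6: 0.15, 7: 0.09, 8: 0.40, 9: 0.18}
discarded = select_samples_to_discard(losses, fraction=0.2)
kept = [sid for sid in losses if sid not in discarded]
print("discard:", discarded)
print("keep:", kept)
```

In practice the slice would be created and discarded from the UI or by the agent; the sketch only shows the kind of per-sample ranking that such a decision can rest on.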
- Docker Desktop v4.77 or newer: required to deploy the Weights Studio UI (weightslab ui launch).
- Docker Compose v2 (the docker compose CLI plugin, bundled with Docker Desktop): recommended. The legacy v1 standalone binary (docker-compose, ≥ 1.27) also works: weightslab ui launch auto-detects whichever is installed and uses it. Compose v1 below 1.27 is not supported.
- Python >=3.10, <3.15: to install and run the weightslab framework.
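The requirements above can be checked from a short script before installing. This is a rough sketch, not the detection logic that `weightslab ui launch` uses internally: the function names are hypothetical, and the legacy-binary branch does not verify the 1.27 minimum.

```python
import shutil
import subprocess
import sys
from typing import Optional, Tuple


def python_version_supported(version: Tuple[int, int]) -> bool:
    """True when (major, minor) satisfies the documented range >=3.10, <3.15."""
    return (3, 10) <= version < (3, 15)


def detect_compose() -> Optional[str]:
    """Return 'v2', 'v1', or None depending on which Compose CLI responds."""
    docker = shutil.which("docker")
    if docker is not None:
        try:
            # Compose v2 is a Docker CLI plugin invoked as `docker compose`.
            result = subprocess.run(
                [docker, "compose", "version"],
                capture_output=True, text=True, timeout=10,
            )
            if result.returncode == 0:
                return "v2"
        except (OSError, subprocess.TimeoutExpired):
            pass
    if shutil.which("docker-compose") is not None:
        # Legacy standalone binary; this sketch does not check the >= 1.27 minimum.
        return "v1"
    return None


if __name__ == "__main__":
    print("Python OK:", python_version_supported(sys.version_info[:2]))
    print("Compose:", detect_compose())
```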
Install directly on your machine.
Tip
Set up a clean Python environment:
python -m venv weightslab_venv
./weightslab_venv/Scripts/activate # Windows; on Linux/macOS use: source weightslab_venv/bin/activate
Install our framework:
pip install weightslab
Deploy our interface:
weightslab ui launch
The command weightslab ui launch removes any stale weightslab/weights_studio Docker resources that could break the launch, then starts the UI stack. By default, it runs unsecured (HTTP, no gRPC auth) and generates no certificates, so communication is neither encrypted nor authenticated.
Tip
To secure communication, pass the --certs flag:
weightslab ui launch --certs # generates TLS certs + a gRPC auth token if missing, then launches secured
When using certs, set WEIGHTSLAB_CERTS_DIR so the training backend and any new terminal use the same certificates (it is the single source of truth). Both weightslab se and weightslab ui launch --certs print the exact export/setx command for your shell. You can also generate certs up front with weightslab se.
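Because WEIGHTSLAB_CERTS_DIR has to be visible to every terminal involved, a training script can check it early. The sketch below only tests that the variable is set and points to an existing directory; `resolve_certs_dir` is a hypothetical helper, and no certificate file names are assumed.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def resolve_certs_dir(env: Mapping[str, str] = os.environ) -> Optional[Path]:
    """Return the directory named by WEIGHTSLAB_CERTS_DIR, or None if unset or missing."""
    raw = env.get("WEIGHTSLAB_CERTS_DIR", "").strip()
    if not raw:
        return None
    path = Path(raw).expanduser()
    return path if path.is_dir() else None


certs_dir = resolve_certs_dir()
if certs_dir is None:
    print("WEIGHTSLAB_CERTS_DIR is not set or does not point to a directory")
else:
    print(f"Using certificates from {certs_dir}")
```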
Important
For a detailed installation guide and more advanced features, please see the Installation Documentation.
- Add the import at the top of your script:
  import weightslab as wl # Include our SDK in your experiment
- Wrap your parameters with WeightsLab tracking:
  parameters = wl.watch_or_edit(parameters, flag='hp', ...) # WeightsLab now monitors your parameters and lets you update them from the UI
- Wrap your model with WeightsLab tracking:
  model = wl.watch_or_edit(SimpleModel(...), flag='model', ...) # WeightsLab now monitors your model state
- Wrap your optimizer with WeightsLab tracking:
  optimizer = wl.watch_or_edit(optim.Adam(...), flag='opt', ...) # Tracks optimizer state and lets you update the learning rate from the UI
- Wrap your signal with WeightsLab tracking:
  train_criterion = wl.watch_or_edit(nn.CrossEntropyLoss(reduction="none"), flag='signal', name="train_loss/sample", per_sample=True, log=True) # Tracks this signal and others (metrics, etc.) from your UI
  test_criterion = wl.watch_or_edit(nn.CrossEntropyLoss(reduction="none"), flag='signal', name="test_loss/sample", per_sample=True, log=False) # Same tracking, but plotting is disabled: only per-sample values are stored
- Wrap your dataset with WeightsLab tracking:
  train_loader = wl.watch_or_edit(train_dataset, flag='data', loader_name="train_loader", ...) # Tracks this dataset and others (validation, test) from your UI
  val_loader = wl.watch_or_edit(val_dataset, flag='data', loader_name="val_loader", ...) # Tracks this dataset and others (validation, test) from your UI
- Run your training script as usual:
  python train.py
- Launch the UI in another terminal:
  weightslab ui launch
- Open your browser to https://localhost:5173 to track experiment evolution and results!
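Before opening the browser, it can help to confirm that something is actually listening on the UI port. The following sketch uses only the standard library and the port 5173 mentioned above; `ui_is_listening` is a hypothetical helper, and a successful TCP connection shows only that the port is open, not that the UI is healthy.

```python
import socket


def ui_is_listening(host: str = "localhost", port: int = 5173, timeout: float = 1.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers connection refused, timeouts, and name resolution failures.
        return False


if __name__ == "__main__":
    if ui_is_listening():
        print("Something is listening on localhost:5173; open it in your browser.")
    else:
        print("Nothing is listening on localhost:5173 yet; is `weightslab ui launch` running?")
```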
Here's a complete example showing how to integrate WeightsLab into a basic PyTorch training script:
#!/usr/bin/env python3
"""
Basic PyTorch training script with WeightsLab integration
"""
import os

import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
import yaml
from torch.utils.data import TensorDataset

import weightslab as wl  # Import WeightsLab (uses TLS certs from WEIGHTSLAB_CERTS_DIR if present)


# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self, input_shape=10, output_shape=1):
        super().__init__()
        self.linear = nn.Linear(input_shape, output_shape)

    def forward(self, x):
        return self.linear(x)


# Create synthetic regression data
def create_data(n_samples=1000):
    X = torch.randn(n_samples, 10)
    y = X.sum(dim=1, keepdim=True) + 0.1 * torch.randn(n_samples, 1)
    return TensorDataset(X, y)


# Main training function
def main():
    print("Initializing WeightsLab...")

    # Load hyperparameters (from YAML if present)
    parameters = {}
    config_path = os.path.join(os.path.dirname(__file__), "config.yaml")
    if os.path.exists(config_path):
        with open(config_path, "r") as fh:
            parameters = yaml.safe_load(fh) or {}
    parameters = wl.watch_or_edit(
        parameters,
        flag="hyperparameters",
        defaults=parameters,
        poll_interval=1.0,
    ) or {}  # Wrap the hyperparameters

    # Wrap your model and optimizer with WeightsLab
    model = wl.watch_or_edit(
        SimpleModel(
            input_shape=parameters.get('model', {}).get('input_shape', 10),
            output_shape=parameters.get('model', {}).get('output_shape', 1)
        ),
        flag='model'
    )  # WeightsLab tracks your model
    optimizer = wl.watch_or_edit(
        optim.Adam(model.parameters(), lr=parameters.get('model', {}).get('optimizer', {}).get('lr', 0.01)),
        flag='optimizer'
    )  # WeightsLab tracks optimizer

    # Create and wrap criterion (the targets are continuous, so use MSE rather than cross-entropy)
    criterion = wl.watch_or_edit(
        nn.MSELoss(reduction="none"),
        flag="loss",
        signal_name="train-loss-MSE",
        log=True  # If log is False, only save per-sample values, do not plot the criterion
    )

    # Create data and dataloader
    dataset = create_data()
    train_loader = wl.watch_or_edit(
        dataset,
        flag="data",
        loader_name="loader",
        batch_size=parameters.get('data', {}).get('train_loader', {}).get('batch_size', 8),
        shuffle=parameters.get('data', {}).get('train_loader', {}).get('shuffle', False),
        is_training=True,  # Is it the training dataloader?
        compute_hash=parameters.get('data', {}).get('train_loader', {}).get('compute_hash', True),  # Compute hash for train loader to allow dynamic augmentations and dataset sanity check
        preload_labels=parameters.get('data', {}).get('train_loader', {}).get('preload_labels', True),
        preload_metadata=parameters.get('data', {}).get('train_loader', {}).get('preload_metadata', True),
        enable_h5_persistence=parameters.get('data', {}).get('train_loader', {}).get('enable_h5_persistence', True),
        num_workers=parameters.get('data', {}).get('train_loader', {}).get('num_workers', 4)
    )

    # Training loop
    print("Starting training...")
    print("Launch the UI with: weightslab ui launch")
    print("Open browser to: https://localhost:5173")
    n_epochs = parameters.get('n_epochs', 5)  # Train for 5 epochs unless configured otherwise
    pbar = tqdm.tqdm(range(n_epochs), desc='Training..') if parameters.get('tqdm_display', False) else range(n_epochs)
    for epoch in pbar:
        total_loss = 0
        for batch_X, batch_y in train_loader:
            # Forward pass; reduction="none" yields per-sample losses, so average before backprop
            predictions = model(batch_X)
            loss = criterion(predictions, batch_y).mean()
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{n_epochs} - Loss: {avg_loss:.4f}")
    print("Training complete!")


if __name__ == "__main__":
    main()

Migrating from wandb? See the diff:
--- train_baseline.py
+++ train_wl.py
@@ -1,11 +1,12 @@
 import argparse
 import torch
 import torch.nn as nn
-from torch.utils.data import DataLoader
 from torchvision import datasets, transforms, models
 from torchmetrics.classification import MulticlassAccuracy
-import wandb
+import weightslab as wl
+from weightslab.components.global_monitoring import (
+    guard_training_context, guard_testing_context)
+
+@wl.signal(name="byte_adjusted_loss", subscribe_to="loss/CE")
+def byte_adjusted_loss(ctx): return ctx.subscribed_value / ctx.image_bytes  # chains on image_bytes
+
 def main():
@@ -15,29 +16,38 @@
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     parameters = {"batch_size": 128, "lr": 1e-3}
-    wandb.init(project="cifar10")
-
     transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
     ])
     train_set = datasets.CIFAR10("./data", train=True, download=True, transform=transform)
     test_set = datasets.CIFAR10("./data", train=False, download=True, transform=transform)
-    train_loader = DataLoader(train_set, batch_size=parameters["batch_size"], shuffle=True, num_workers=2)
-    test_loader = DataLoader(test_set, batch_size=256, num_workers=2)
+    wl.watch_or_edit(parameters, flag="hyperparameters")  # live-editable in UI
+
+    train_loader = wl.watch_or_edit(
+        train_set, flag="data", loader_name="train_loader",
+        batch_size=parameters["batch_size"], shuffle=True, is_training=True)
+    test_loader = wl.watch_or_edit(
+        test_set, flag="data", loader_name="test_loader",
+        batch_size=256, shuffle=False, is_training=False)
     model = models.resnet18(weights=None)
     model.fc = nn.Linear(model.fc.in_features, 10)
     model = model.to(device)
     optimizer = torch.optim.Adam(model.parameters(), lr=parameters["lr"])
-    criterion = nn.CrossEntropyLoss()
-    accuracy = MulticlassAccuracy(num_classes=10).to(device)
+    criterion = wl.watch_or_edit(
+        nn.CrossEntropyLoss(), flag="loss", signal_name="loss/CE")
+    accuracy = wl.watch_or_edit(
+        MulticlassAccuracy(num_classes=10).to(device),
+        flag="metric", signal_name="acc")
+
+    wl.serve(serving_grpc=True)
     for epoch in range(1, args.epochs + 1):
         model.train()
         accuracy.reset()
         for x, y in train_loader:
+            with guard_training_context:
                 x, y = x.to(device), y.to(device)
                 logits = model(x)
                 loss = criterion(logits, y)
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
                 accuracy.update(logits, y)
-            wandb.log({"train/loss": loss.item()})
-        wandb.log({"train/acc": accuracy.compute().item(), "epoch": epoch})
+                wl.save_signals(preds_raw=logits, targets=y,
+                                signals={"metric/accuracy": accuracy.compute().item()})
         model.eval()
         accuracy.reset()
         with torch.no_grad():
             for x, y in test_loader:
+                with guard_testing_context:
                     x, y = x.to(device), y.to(device)
-                accuracy.update(model(x), y)
+                    logits = model(x)
+                    accuracy.update(logits, y)
-        wandb.log({"test/acc": accuracy.compute().item(), "epoch": epoch})
+                    wl.save_signals(preds_raw=logits, targets=y,
+                                    signals={"metric/accuracy": accuracy.compute().item()})
-    wandb.finish()
+    wl.keep_serving()

- Experiment tracking for reproducibility
- Live metrics and visualization in the web UI
- Data supervision during training and experiment hyperparameter tuning through the UI
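The complete PyTorch example earlier reads every setting through chained `.get(..., {})` calls. A small lookup helper can shorten those reads. This is a sketch under assumptions: `get_nested` is a hypothetical helper, not part of WeightsLab, it assumes the parameters behave like a plain nested dict, and the values below are made up while the keys mirror the ones that example reads.

```python
from collections.abc import Mapping
from typing import Any


def get_nested(params: Mapping, dotted_key: str, default: Any = None) -> Any:
    """Look up 'a.b.c' in nested mappings, returning `default` if any level is missing."""
    current: Any = params
    for part in dotted_key.split("."):
        if not isinstance(current, Mapping) or part not in current:
            return default
        current = current[part]
    return current


# Example settings using the same keys the training example reads; values are made up.
parameters = {
    "n_epochs": 5,
    "model": {"input_shape": 10, "output_shape": 1, "optimizer": {"lr": 0.01}},
    "data": {"train_loader": {"batch_size": 8, "shuffle": False}},
}

print(get_nested(parameters, "data.train_loader.batch_size", 8))
print(get_nested(parameters, "data.train_loader.num_workers", 4))  # missing key, falls back to 4
print(get_nested(parameters, "model.optimizer.lr", 0.01))
```

Falling back at the first missing level keeps the behavior of the chained `.get` calls, so a partial config.yaml still works.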
After starting the UI, launch a local experiment with the command:
weightslab start example # classification (default)
# weightslab start example --cls # classification
# weightslab start example --seg # segmentation
# weightslab start example --det # detection
# weightslab start example --clus # clustering
# weightslab start example --gen # generation
Find our sandbox online. The password is graybx.
Find our documentation online.
New here (human or AI coding agent)? Start with AGENTS.md: it
captures the cross-repo architecture (weightslab backend ↔ weights_studio
frontend via the shared proto), the module maps, the wl.watch_or_edit
integration pattern, where tests live, and the gotchas that aren't obvious from
any single file. It's the fastest way to orient before a first change.
Graybx is building a wonderful community of AI researchers and engineers. Are you interested in joining our project? Contact us at hello [at] graybx [dot] com
