From 3a320df3aa78306aa83d0a62820c969ab383aa6e Mon Sep 17 00:00:00 2001
From: Chris Gwilliams <517923+encima@users.noreply.github.com>
Date: Sat, 17 Jan 2026 12:33:24 +0700
Subject: [PATCH] feat: add scenarios for different usage
Signed-off-by: Chris Gwilliams <517923+encima@users.noreply.github.com>
---
Makefile | 6 +-
api/handlers.go | 54 ++-
api/router.go | 1 +
config/config.go | 8 +-
db/postgres.go | 16 +
frontend/.gitignore | 1 +
frontend/src/App.jsx | 12 +-
frontend/src/api/client.js | 5 +
frontend/src/components/ControlPanel.jsx | 54 ++-
frontend/src/components/StatsPanel.jsx | 8 +-
frontend/src/utils/formatting.js | 11 +
go.mod | 4 +-
go.sum | 4 +
init.sql | 142 ++++++--
load/controller.go | 87 ++++-
load/reader.go | 47 +--
load/writer.go | 32 +-
main.go | 12 +-
metrics/types.go | 1 +
schema/builtin.go | 405 +++++++++++++++++++++++
schema/custom.go | 342 +++++++++++++++++++
schema/generator.go | 151 +++++++++
schema/scenario.go | 118 +++++++
23 files changed, 1439 insertions(+), 82 deletions(-)
create mode 100644 schema/builtin.go
create mode 100644 schema/custom.go
create mode 100644 schema/generator.go
create mode 100644 schema/scenario.go
diff --git a/Makefile b/Makefile
index 0fee460..f2ced1c 100644
--- a/Makefile
+++ b/Makefile
@@ -8,7 +8,7 @@ build: frontend backend
# Build frontend
frontend:
- cd frontend && npm install && npm run build
+ cd frontend && pnpm install && pnpm run build
# Build backend (includes embedded frontend)
backend:
@@ -41,12 +41,12 @@ test:
# Format code
fmt:
go fmt ./...
- cd frontend && npm run format 2>/dev/null || true
+ cd frontend && pnpm run format 2>/dev/null || true
# Tidy dependencies
tidy:
go mod tidy
- cd frontend && npm install
+ cd frontend && pnpm install
# Build Docker image
image: build
diff --git a/api/handlers.go b/api/handlers.go
index 343dfdb..bb25adf 100644
--- a/api/handlers.go
+++ b/api/handlers.go
@@ -6,6 +6,7 @@ import (
"supafirehose/load"
"supafirehose/metrics"
+ "supafirehose/schema"
)
// Handlers holds the HTTP handler dependencies
@@ -47,10 +48,12 @@ func (h *Handlers) HandleStatus(w http.ResponseWriter, r *http.Request) {
// ConfigRequest is the request body for POST /api/config
type ConfigRequest struct {
- Connections int `json:"connections"`
- ReadQPS int `json:"read_qps"`
- WriteQPS int `json:"write_qps"`
- ChurnRate int `json:"churn_rate"`
+ Connections int `json:"connections"`
+ ReadQPS int `json:"read_qps"`
+ WriteQPS int `json:"write_qps"`
+ ChurnRate int `json:"churn_rate"`
+ Scenario string `json:"scenario,omitempty"`
+ CustomTable string `json:"custom_table,omitempty"`
}
// ConfigResponse is the response for POST /api/config
@@ -72,11 +75,26 @@ func (h *Handlers) HandleConfig(w http.ResponseWriter, r *http.Request) {
return
}
+ // Get current config for defaults
+ currentConfig := h.controller.GetConfig()
+
+ // Use current values if not provided
+ scenario := req.Scenario
+ if scenario == "" {
+ scenario = currentConfig.Scenario
+ }
+ customTable := req.CustomTable
+ if customTable == "" && scenario == currentConfig.Scenario {
+ customTable = currentConfig.CustomTable
+ }
+
cfg := load.Config{
Connections: req.Connections,
ReadQPS: req.ReadQPS,
WriteQPS: req.WriteQPS,
ChurnRate: req.ChurnRate,
+ Scenario: scenario,
+ CustomTable: customTable,
}
h.controller.UpdateConfig(cfg)
@@ -150,3 +168,31 @@ func writeJSON(w http.ResponseWriter, data interface{}) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(data)
}
+
+// ScenariosResponse is the JSON body returned by GET /api/scenarios.
+type ScenariosResponse struct {
+ // Scenarios is serialized under the "scenarios" key. HandleScenarios fills
+ // it with the registry listing followed by a "custom" entry.
+ Scenarios []schema.ScenarioInfo `json:"scenarios"`
+}
+
+// HandleScenarios serves GET /api/scenarios. It responds with every scenario
+// listed by the controller's registry plus one extra "custom" entry, and
+// rejects any other HTTP method with 405.
+func (h *Handlers) HandleScenarios(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodGet {
+		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+		return
+	}
+
+	// The custom option is appended here rather than taken from the registry
+	// listing; its table name is left empty.
+	customOption := schema.ScenarioInfo{
+		Name:        "custom",
+		Description: "Custom table (optionally specify table name)",
+		TableName:   "",
+	}
+	available := append(h.controller.GetRegistry().List(), customOption)
+
+	writeJSON(w, ScenariosResponse{Scenarios: available})
+}
diff --git a/api/router.go b/api/router.go
index f17db3a..2886856 100644
--- a/api/router.go
+++ b/api/router.go
@@ -16,6 +16,7 @@ func NewRouter(handlers *Handlers, wsHub *WebSocketHub, staticFS fs.FS) http.Han
mux.HandleFunc("/api/start", handlers.HandleStart)
mux.HandleFunc("/api/stop", handlers.HandleStop)
mux.HandleFunc("/api/reset", handlers.HandleReset)
+ mux.HandleFunc("/api/scenarios", handlers.HandleScenarios)
// WebSocket route
mux.HandleFunc("/ws/metrics", wsHub.HandleWebSocket)
diff --git a/config/config.go b/config/config.go
index 61b258c..7396512 100644
--- a/config/config.go
+++ b/config/config.go
@@ -26,7 +26,10 @@ type Config struct {
// Metrics
MetricsInterval time.Duration
- MaxUserID int64
+
+ // Schema scenario
+ DefaultScenario string
+ CustomTable string
}
// Load reads configuration from environment variables with defaults
@@ -41,7 +44,8 @@ func Load() *Config {
MaxReadQPS: getEnvInt("MAX_READ_QPS", 500000),
MaxWriteQPS: getEnvInt("MAX_WRITE_QPS", 500000),
MetricsInterval: getEnvDuration("METRICS_INTERVAL", 100*time.Millisecond),
- MaxUserID: getEnvInt64("MAX_USER_ID", 100000),
+ DefaultScenario: getEnv("DEFAULT_SCENARIO", "simple"),
+ CustomTable: getEnv("CUSTOM_TABLE", ""),
}
}
diff --git a/db/postgres.go b/db/postgres.go
index 2cf87c3..6ffc6d4 100644
--- a/db/postgres.go
+++ b/db/postgres.go
@@ -62,3 +62,19 @@ func (cm *ConnectionManager) Ping(ctx context.Context) error {
defer conn.Close(ctx)
return conn.Ping(ctx)
}
+
+// GetDatabaseSize reports the size in bytes of the database the manager's
+// connection string points at. It opens a dedicated connection for the query
+// and closes it before returning.
+func (cm *ConnectionManager) GetDatabaseSize(ctx context.Context) (int64, error) {
+	const sizeQuery = "SELECT pg_database_size(current_database())"
+
+	conn, connErr := pgx.Connect(ctx, cm.connString)
+	if connErr != nil {
+		return 0, fmt.Errorf("failed to connect: %w", connErr)
+	}
+	defer conn.Close(ctx)
+
+	var sizeBytes int64
+	if queryErr := conn.QueryRow(ctx, sizeQuery).Scan(&sizeBytes); queryErr != nil {
+		return 0, fmt.Errorf("failed to get database size: %w", queryErr)
+	}
+	return sizeBytes, nil
+}
diff --git a/frontend/.gitignore b/frontend/.gitignore
index a547bf3..2d393d7 100644
--- a/frontend/.gitignore
+++ b/frontend/.gitignore
@@ -22,3 +22,4 @@ dist-ssr
*.njsproj
*.sln
*.sw?
+pnpm-lock.yaml
\ No newline at end of file
diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx
index c01c694..1db360b 100644
--- a/frontend/src/App.jsx
+++ b/frontend/src/App.jsx
@@ -1,7 +1,7 @@
import { useState, useEffect, useMemo, useRef } from 'react';
import { useWebSocket } from './hooks/useWebSocket';
import { useMetricsHistory } from './hooks/useMetricsHistory';
-import { getStatus, updateConfig, start, stop, reset } from './api/client';
+import { getStatus, getScenarios, updateConfig, start, stop, reset } from './api/client';
import { ConnectionStatus } from './components/ConnectionStatus';
import { ControlPanel } from './components/ControlPanel';
import { StatsPanel } from './components/StatsPanel';
@@ -15,7 +15,10 @@ function App() {
read_qps: 100,
write_qps: 10,
churn_rate: 0,
+ scenario: 'simple',
+ custom_table: '',
});
+ const [scenarios, setScenarios] = useState([]);
const [running, setRunning] = useState(false);
const [latestMetrics, setLatestMetrics] = useState(null);
@@ -28,12 +31,16 @@ function App() {
// Memoize display data to prevent unnecessary re-computations
const displayData = useMemo(() => getDisplayData(), [getDisplayData]);
- // Fetch initial status
+ // Fetch initial status and scenarios
useEffect(() => {
getStatus().then((status) => {
setRunning(status.running);
setConfig(status.config);
}).catch(console.error);
+
+ getScenarios().then((data) => {
+ setScenarios(data.scenarios || []);
+ }).catch(console.error);
}, []);
// Update metrics from WebSocket - use ref to track changes
@@ -116,6 +123,7 @@ function App() {
{
@@ -13,11 +13,63 @@ export function ControlPanel({ config, running, onConfigChange, onStart, onStop,
onConfigChange(newConfig);
};
+ // Switch scenario and submit immediately. Any scenario other than 'custom'
+ // also blanks custom_table so a stale table name is not sent along.
+ const handleScenarioChange = (scenario) => {
+   const tableReset = scenario === 'custom' ? {} : { custom_table: '' };
+   const newConfig = { ...localConfig, scenario, ...tableReset };
+   setLocalConfig(newConfig);
+   onConfigChange(newConfig);
+ };
+
+ // Record edits to the custom table name in local state only.
+ const handleCustomTableChange = (customTable) => {
+ const newConfig = { ...localConfig, custom_table: customTable };
+ setLocalConfig(newConfig);
+ // Deliberately no onConfigChange here: submitting is deferred until the
+ // input loses focus or Enter is pressed.
+ };
+
+ // Send the locally edited config (including custom_table) to the parent.
+ // Wired to the custom table input's onBlur and Enter key handlers.
+ const submitCustomTable = () => {
+ onConfigChange(localConfig);
+ };
+
return (
Control Panel
+ {/* Scenario Selector */}
+
+
+
+
+
+ {/* Custom Table Input */}
+ {localConfig.scenario === 'custom' && (
+
+
+ handleCustomTableChange(e.target.value)}
+ onBlur={submitCustomTable}
+ onKeyDown={(e) => e.key === 'Enter' && submitCustomTable()}
+ placeholder="schema.table_name"
+ className="w-full bg-slate-700 border border-slate-600 rounded-lg px-3 py-2 text-white text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 placeholder-slate-500"
+ />
+
+ )}
+
Statistics
-
);
diff --git a/frontend/src/utils/formatting.js b/frontend/src/utils/formatting.js
index 00ac80a..91ff626 100644
--- a/frontend/src/utils/formatting.js
+++ b/frontend/src/utils/formatting.js
@@ -25,3 +25,14 @@ export function formatLatency(ms) {
export function formatPercent(rate) {
return (rate * 100).toFixed(3) + '%';
}
+
+/**
+ * Format a byte count using binary (1024-based) units from B to TB.
+ *
+ * Missing, non-numeric, non-finite, zero, and negative inputs render as
+ * '0 B'. The unit index is clamped to the available units, so values below
+ * 1 stay in bytes and values of 1024 TB or more stay in TB instead of
+ * indexing past the unit table and printing "undefined".
+ */
+export function formatBytes(bytes) {
+  const amount = Number(bytes);
+  if (!Number.isFinite(amount) || amount <= 0) {
+    return '0 B';
+  }
+  const units = ['B', 'KB', 'MB', 'GB', 'TB'];
+  const k = 1024;
+  const rawIndex = Math.floor(Math.log(amount) / Math.log(k));
+  const i = Math.min(Math.max(rawIndex, 0), units.length - 1);
+  const value = amount / Math.pow(k, i);
+  return value.toFixed(i > 0 ? 2 : 0) + ' ' + units[i];
+}
diff --git a/go.mod b/go.mod
index e8146bc..280f1b3 100644
--- a/go.mod
+++ b/go.mod
@@ -3,15 +3,15 @@ module supafirehose
go 1.25.1
require (
+ github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/jackc/pgx/v5 v5.8.0
golang.org/x/time v0.14.0
)
require (
+ github.com/brianvoe/gofakeit/v6 v6.28.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
- github.com/jackc/puddle/v2 v2.2.2 // indirect
- golang.org/x/sync v0.17.0 // indirect
golang.org/x/text v0.29.0 // indirect
)
diff --git a/go.sum b/go.sum
index ec09583..8f3e1bc 100644
--- a/go.sum
+++ b/go.sum
@@ -1,6 +1,10 @@
+github.com/brianvoe/gofakeit/v6 v6.28.0 h1:Xib46XXuQfmlLS2EXRuJpqcw8St6qSZz75OUo0tgAW4=
+github.com/brianvoe/gofakeit/v6 v6.28.0/go.mod h1:Xj58BMSnFqcn/fAQeSK+/PLtC5kSb7FJIq4JyGa8vEs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
diff --git a/init.sql b/init.sql
index 03672b8..a23b1bd 100644
--- a/init.sql
+++ b/init.sql
@@ -1,24 +1,128 @@
-- Supafirehose Demo Database Schema
-- Run: psql -h localhost -U postgres -d pooler_demo -f init.sql
-
--- Create users table for read/write operations
+-- ============================================
+-- Simple Scenario: Basic users table
+-- ============================================
CREATE TABLE IF NOT EXISTS users (
- id BIGSERIAL PRIMARY KEY,
- username VARCHAR(255) NOT NULL,
- email VARCHAR(255) NOT NULL,
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+ id BIGSERIAL PRIMARY KEY,
+ username VARCHAR(255) NOT NULL,
+ email VARCHAR(255) NOT NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-
--- Create index on id (primary key handles this, but being explicit)
--- The primary key already creates a unique index
-
--- Seed with 100,000 users for read operations
-INSERT INTO users (username, email)
-SELECT
- 'user_' || i,
- 'user_' || i || '@example.com'
-FROM generate_series(1, 100000) AS i
-ON CONFLICT DO NOTHING;
-
--- Analyze table for query planner
ANALYZE users;
+-- ============================================
+-- JSONB Scenario: Table with JSONB payload
+-- ============================================
+-- This script creates the table and index only; it inserts no jsonb_data rows.
+CREATE TABLE IF NOT EXISTS jsonb_data (
+ id BIGSERIAL PRIMARY KEY,
+ payload JSONB NOT NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+-- Create GIN index on JSONB column for efficient queries
+CREATE INDEX IF NOT EXISTS idx_jsonb_data_payload ON jsonb_data USING GIN (payload);
+ANALYZE jsonb_data;
+-- ============================================
+-- Wide Scenario: Table with many columns
+-- ============================================
+-- 20 VARCHAR columns and 5 INTEGER columns besides id and created_at.
+-- None of the col_* / int_* columns is declared NOT NULL.
+CREATE TABLE IF NOT EXISTS wide_data (
+ id BIGSERIAL PRIMARY KEY,
+ col_01 VARCHAR(255),
+ col_02 VARCHAR(255),
+ col_03 VARCHAR(255),
+ col_04 VARCHAR(255),
+ col_05 VARCHAR(255),
+ col_06 VARCHAR(255),
+ col_07 VARCHAR(255),
+ col_08 VARCHAR(255),
+ col_09 VARCHAR(255),
+ col_10 VARCHAR(255),
+ col_11 VARCHAR(255),
+ col_12 VARCHAR(255),
+ col_13 VARCHAR(255),
+ col_14 VARCHAR(255),
+ col_15 VARCHAR(255),
+ col_16 VARCHAR(255),
+ col_17 VARCHAR(255),
+ col_18 VARCHAR(255),
+ col_19 VARCHAR(255),
+ col_20 VARCHAR(255),
+ int_01 INTEGER,
+ int_02 INTEGER,
+ int_03 INTEGER,
+ int_04 INTEGER,
+ int_05 INTEGER,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+-- Seed with 100,000 rows of wide data, but only when the table is empty.
+-- The insert supplies no unique-constrained column (id comes from BIGSERIAL),
+-- so ON CONFLICT DO NOTHING could never skip a row and every re-run of this
+-- script would add another 100,000 rows; the NOT EXISTS guard prevents that.
+INSERT INTO wide_data (
+    col_01,
+    col_02,
+    col_03,
+    col_04,
+    col_05,
+    col_06,
+    col_07,
+    col_08,
+    col_09,
+    col_10,
+    col_11,
+    col_12,
+    col_13,
+    col_14,
+    col_15,
+    col_16,
+    col_17,
+    col_18,
+    col_19,
+    col_20,
+    int_01,
+    int_02,
+    int_03,
+    int_04,
+    int_05
+  )
+SELECT 'col01_' || i,
+  'col02_' || i,
+  'col03_' || i,
+  'col04_' || i,
+  'col05_' || i,
+  'col06_' || i,
+  'col07_' || i,
+  'col08_' || i,
+  'col09_' || i,
+  'col10_' || i,
+  'col11_' || i,
+  'col12_' || i,
+  'col13_' || i,
+  'col14_' || i,
+  'col15_' || i,
+  'col16_' || i,
+  'col17_' || i,
+  'col18_' || i,
+  'col19_' || i,
+  'col20_' || i,
+  i % 1000,
+  i % 500,
+  i % 250,
+  i % 100,
+  i % 50
+FROM generate_series(1, 100000) AS i
+WHERE NOT EXISTS (SELECT 1 FROM wide_data);
+ANALYZE wide_data;
+-- ============================================
+-- FK Scenario: Tables with foreign key lookup
+-- ============================================
+-- NOTE(review): both tables are dropped (CASCADE) on every run, so existing
+-- categories/items rows are lost each time this script is executed and the
+-- IF NOT EXISTS on the two CREATE TABLE statements below never skips.
+-- Confirm this reset is intentional.
+DROP TABLE IF EXISTS categories CASCADE;
+DROP TABLE IF EXISTS items CASCADE;
+-- NOTE(review): no categories rows are inserted by this script, yet
+-- items.category_id is NOT NULL and references categories(id); presumably
+-- categories are populated elsewhere - verify.
+CREATE TABLE IF NOT EXISTS categories (
+ id BIGINT PRIMARY KEY,
+ name VARCHAR(255) NOT NULL
+);
+-- uuid-ossp provides uuid_generate_v4(), used as the items.id default.
+CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
+CREATE TABLE IF NOT EXISTS items (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ category_id BIGINT NOT NULL REFERENCES categories(id),
+ name VARCHAR(255) NOT NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+ANALYZE categories;
+ANALYZE items;
\ No newline at end of file
diff --git a/load/controller.go b/load/controller.go
index cf7f45d..4f68d99 100644
--- a/load/controller.go
+++ b/load/controller.go
@@ -6,16 +6,19 @@ import (
"supafirehose/db"
"supafirehose/metrics"
+ "supafirehose/schema"
"golang.org/x/time/rate"
)
// Config holds the load generator configuration
type Config struct {
- Connections int `json:"connections"`
- ReadQPS int `json:"read_qps"`
- WriteQPS int `json:"write_qps"`
- ChurnRate int `json:"churn_rate"` // Connections churned per second
+ Connections int `json:"connections"`
+ ReadQPS int `json:"read_qps"`
+ WriteQPS int `json:"write_qps"`
+ ChurnRate int `json:"churn_rate"` // Connections churned per second
+ Scenario string `json:"scenario"` // Scenario name (simple, jsonb, wide, fk, custom)
+ CustomTable string `json:"custom_table"` // Table name for custom scenario
}
// Controller manages the load generation workers
@@ -32,7 +35,10 @@ type Controller struct {
// Dependencies
connMgr *db.ConnectionManager
collector *metrics.Collector
- maxUserID int64
+ registry *schema.Registry
+
+ // Current scenario
+ scenario schema.Scenario
// Worker management
ctx context.Context
@@ -41,16 +47,58 @@ type Controller struct {
}
// NewController creates a new load controller
-func NewController(connMgr *db.ConnectionManager, collector *metrics.Collector, maxUserID int64) *Controller {
+func NewController(connMgr *db.ConnectionManager, collector *metrics.Collector) *Controller {
+ registry := schema.NewRegistry()
+ defaultScenario, _ := registry.Get("simple")
+
return &Controller{
connMgr: connMgr,
collector: collector,
- maxUserID: maxUserID,
+ registry: registry,
+ scenario: defaultScenario,
readLimiter: rate.NewLimiter(rate.Limit(100), 100),
writeLimiter: rate.NewLimiter(rate.Limit(10), 10),
}
}
+// GetRegistry returns the scenario registry created in NewController.
+// The pointer is returned without taking the controller's mutex.
+func (c *Controller) GetRegistry() *schema.Registry {
+ return c.registry
+}
+
+// GetScenario returns the scenario the controller currently holds, read
+// under the controller's read lock.
+func (c *Controller) GetScenario() schema.Scenario {
+	c.mu.RLock()
+	current := c.scenario
+	c.mu.RUnlock()
+	return current
+}
+
+// SetScenario switches the controller to the scenario called name and records
+// the selection in the controller's config.
+//
+// "custom" builds a custom scenario for customTable (which may be empty for
+// auto-discovery). Any other name is looked up in the registry; an unknown
+// name falls back to "simple". In that fallback case the config records
+// "simple" rather than the unknown name, so the reported configuration always
+// matches the scenario that is actually installed.
+//
+// The returned error is always nil; it is kept for interface compatibility.
+func (c *Controller) SetScenario(name string, customTable string) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	var newScenario schema.Scenario
+
+	if name == "custom" {
+		// customTable can be empty for auto-discovery
+		newScenario = c.registry.CreateCustomScenario(customTable)
+	} else {
+		s, ok := c.registry.Get(name)
+		if !ok {
+			// Default to simple if not found, and remember that we did so.
+			name = "simple"
+			s, _ = c.registry.Get(name)
+		}
+		newScenario = s
+	}
+
+	c.scenario = newScenario
+	c.config.Scenario = name
+	c.config.CustomTable = customTable
+
+	return nil
+}
+
// Start begins load generation with the current configuration
func (c *Controller) Start() {
c.mu.Lock()
@@ -81,12 +129,15 @@ func (c *Controller) Start() {
numWriters = 0
}
+ // Get current scenario
+ scenario := c.scenario
+
// Start read workers
for i := 0; i < numReaders; i++ {
c.wg.Add(1)
go func() {
defer c.wg.Done()
- worker := NewReadWorker(c.connMgr, c.readLimiter, c.collector, c.maxUserID, churnRate)
+ worker := NewReadWorker(c.connMgr, c.readLimiter, c.collector, scenario, churnRate)
worker.Run(c.ctx)
}()
}
@@ -96,7 +147,7 @@ func (c *Controller) Start() {
c.wg.Add(1)
go func() {
defer c.wg.Done()
- worker := NewWriteWorker(c.connMgr, c.writeLimiter, c.collector, churnRate)
+ worker := NewWriteWorker(c.connMgr, c.writeLimiter, c.collector, scenario, churnRate)
worker.Run(c.ctx)
}()
}
@@ -128,8 +179,22 @@ func (c *Controller) UpdateConfig(cfg Config) {
c.writeLimiter.SetLimit(rate.Limit(cfg.WriteQPS))
c.writeLimiter.SetBurst(max(cfg.WriteQPS, 1))
- // If running and connection count or churn changed, restart workers
- needsRestart := c.running && (oldConfig.Connections != cfg.Connections || oldConfig.ChurnRate != cfg.ChurnRate)
+ // Update scenario if changed
+ scenarioChanged := oldConfig.Scenario != cfg.Scenario || oldConfig.CustomTable != cfg.CustomTable
+ if scenarioChanged {
+ if cfg.Scenario == "custom" {
+ c.scenario = c.registry.CreateCustomScenario(cfg.CustomTable)
+ } else if cfg.Scenario != "" {
+ if s, ok := c.registry.Get(cfg.Scenario); ok {
+ c.scenario = s
+ }
+ }
+ }
+
+ // If running and connection count, churn, or scenario changed, restart workers
+ needsRestart := c.running && (oldConfig.Connections != cfg.Connections ||
+ oldConfig.ChurnRate != cfg.ChurnRate ||
+ scenarioChanged)
c.mu.Unlock()
if needsRestart {
diff --git a/load/reader.go b/load/reader.go
index dd835f4..61e0cdb 100644
--- a/load/reader.go
+++ b/load/reader.go
@@ -7,6 +7,7 @@ import (
"supafirehose/db"
"supafirehose/metrics"
+ "supafirehose/schema"
"github.com/jackc/pgx/v5"
"golang.org/x/time/rate"
@@ -14,34 +15,29 @@ import (
// ReadWorker executes read queries against the database
type ReadWorker struct {
- connMgr *db.ConnectionManager
- limiter *rate.Limiter
- collector *metrics.Collector
- maxID int64
- churnRate float64 // Probability of churning connection per second
+ connMgr *db.ConnectionManager
+ limiter *rate.Limiter
+ collector *metrics.Collector
+ scenario schema.Scenario
+ churnRate float64 // Probability of churning connection per second
}
// NewReadWorker creates a new read worker
-func NewReadWorker(connMgr *db.ConnectionManager, limiter *rate.Limiter, collector *metrics.Collector, maxID int64, churnRate float64) *ReadWorker {
+func NewReadWorker(connMgr *db.ConnectionManager, limiter *rate.Limiter, collector *metrics.Collector, scenario schema.Scenario, churnRate float64) *ReadWorker {
return &ReadWorker{
connMgr: connMgr,
limiter: limiter,
collector: collector,
- maxID: maxID,
+ scenario: scenario,
churnRate: churnRate,
}
}
-// User represents a row from the users table
-type User struct {
- ID int64
- Username string
- Email string
- CreatedAt time.Time
-}
-
// Run starts the read worker loop with its own connection
func (w *ReadWorker) Run(ctx context.Context) {
+ // Track if scenario has been initialized (for custom scenarios)
+ scenarioInitialized := false
+
for {
select {
case <-ctx.Done():
@@ -62,6 +58,18 @@ func (w *ReadWorker) Run(ctx context.Context) {
continue
}
+ // Initialize scenario if needed (for custom scenarios that need table introspection)
+ if !scenarioInitialized {
+ if err := w.scenario.Initialize(ctx, conn); err != nil {
+ w.collector.RecordRead(0, err)
+ conn.Close(context.Background())
+ w.connMgr.Release()
+ time.Sleep(100 * time.Millisecond)
+ continue
+ }
+ scenarioInitialized = true
+ }
+
// Run queries on this connection until churn or context done
w.runWithConnection(ctx, conn)
@@ -113,14 +121,7 @@ func (w *ReadWorker) runWithConnection(ctx context.Context, conn *pgx.Conn) {
func (w *ReadWorker) executeRead(ctx context.Context, conn *pgx.Conn) {
start := time.Now()
- // Random ID within the known range
- id := rand.Int63n(w.maxID) + 1
-
- var user User
- err := conn.QueryRow(ctx,
- "SELECT id, username, email, created_at FROM users WHERE id = $1",
- id,
- ).Scan(&user.ID, &user.Username, &user.Email, &user.CreatedAt)
+ err := w.scenario.ExecuteRead(ctx, conn)
latency := time.Since(start)
diff --git a/load/writer.go b/load/writer.go
index 86399eb..688c7e4 100644
--- a/load/writer.go
+++ b/load/writer.go
@@ -2,12 +2,12 @@ package load
import (
"context"
- "fmt"
"math/rand"
"time"
"supafirehose/db"
"supafirehose/metrics"
+ "supafirehose/schema"
"github.com/jackc/pgx/v5"
"golang.org/x/time/rate"
@@ -18,21 +18,26 @@ type WriteWorker struct {
connMgr *db.ConnectionManager
limiter *rate.Limiter
collector *metrics.Collector
+ scenario schema.Scenario
churnRate float64 // Probability of churning connection per second
}
// NewWriteWorker creates a new write worker
-func NewWriteWorker(connMgr *db.ConnectionManager, limiter *rate.Limiter, collector *metrics.Collector, churnRate float64) *WriteWorker {
+func NewWriteWorker(connMgr *db.ConnectionManager, limiter *rate.Limiter, collector *metrics.Collector, scenario schema.Scenario, churnRate float64) *WriteWorker {
return &WriteWorker{
connMgr: connMgr,
limiter: limiter,
collector: collector,
+ scenario: scenario,
churnRate: churnRate,
}
}
// Run starts the write worker loop with its own connection
func (w *WriteWorker) Run(ctx context.Context) {
+ // Track if scenario has been initialized (for custom scenarios)
+ scenarioInitialized := false
+
for {
select {
case <-ctx.Done():
@@ -53,6 +58,18 @@ func (w *WriteWorker) Run(ctx context.Context) {
continue
}
+ // Initialize scenario if needed (for custom scenarios that need table introspection)
+ if !scenarioInitialized {
+ if err := w.scenario.Initialize(ctx, conn); err != nil {
+ w.collector.RecordWrite(0, err)
+ conn.Close(context.Background())
+ w.connMgr.Release()
+ time.Sleep(100 * time.Millisecond)
+ continue
+ }
+ scenarioInitialized = true
+ }
+
// Run queries on this connection until churn or context done
w.runWithConnection(ctx, conn)
@@ -98,16 +115,7 @@ func (w *WriteWorker) runWithConnection(ctx context.Context, conn *pgx.Conn) {
func (w *WriteWorker) executeWrite(ctx context.Context, conn *pgx.Conn) {
start := time.Now()
- // Generate random user data
- randNum := rand.Int63()
- username := fmt.Sprintf("user_%d", randNum)
- email := fmt.Sprintf("user_%d@example.com", randNum)
-
- var newID int64
- err := conn.QueryRow(ctx,
- "INSERT INTO users (username, email) VALUES ($1, $2) RETURNING id",
- username, email,
- ).Scan(&newID)
+ err := w.scenario.ExecuteWrite(ctx, conn)
latency := time.Since(start)
diff --git a/main.go b/main.go
index 1e7cc06..6ce180f 100644
--- a/main.go
+++ b/main.go
@@ -47,20 +47,30 @@ func main() {
// Create metrics collector with connection stats function
collector := metrics.NewCollector(func() metrics.PoolStats {
+ // Get database size (ignore errors, just return 0 if it fails)
+ dbSize, _ := connMgr.GetDatabaseSize(ctx)
return metrics.PoolStats{
ActiveConnections: connMgr.ActiveConnections(),
IdleConnections: 0,
WaitingRequests: 0,
+ DatabaseSizeBytes: dbSize,
}
})
// Create load controller
- controller := load.NewController(connMgr, collector, cfg.MaxUserID)
+ controller := load.NewController(connMgr, collector)
+
+ // Set initial scenario
+ controller.SetScenario(cfg.DefaultScenario, cfg.CustomTable)
+
+ // Set initial config
controller.SetConfig(load.Config{
Connections: cfg.DefaultConnections,
ReadQPS: cfg.DefaultReadQPS,
WriteQPS: cfg.DefaultWriteQPS,
ChurnRate: 0,
+ Scenario: cfg.DefaultScenario,
+ CustomTable: cfg.CustomTable,
})
// Create API handlers
diff --git a/metrics/types.go b/metrics/types.go
index 4fef196..baebbf4 100644
--- a/metrics/types.go
+++ b/metrics/types.go
@@ -37,4 +37,5 @@ type PoolStats struct {
ActiveConnections int32 `json:"active_connections"`
IdleConnections int32 `json:"idle_connections"`
WaitingRequests int32 `json:"waiting_requests"`
+ DatabaseSizeBytes int64 `json:"database_size_bytes"`
}
diff --git a/schema/builtin.go b/schema/builtin.go
new file mode 100644
index 0000000..e33a541
--- /dev/null
+++ b/schema/builtin.go
@@ -0,0 +1,405 @@
+package schema
+
+import (
+ "context"
+ "fmt"
+ "math/rand"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/jackc/pgx/v5"
+)
+
+// SimpleScenario is the original users table scenario: point reads by id and
+// single-row inserts against the users table.
+type SimpleScenario struct {
+ // maxID bounds the random integer id ExecuteRead uses when no ids are cached.
+ maxID int64
+ // ids caches row ids as text; Initialize and ExecuteWrite add to it.
+ ids []string
+ // mu guards ids.
+ mu sync.RWMutex
+}
+
+// NewSimpleScenario returns a SimpleScenario with maxID set to 100000 and an
+// empty id cache pre-sized for 10000 entries.
+func NewSimpleScenario() *SimpleScenario {
+	scenario := &SimpleScenario{maxID: 100000}
+	scenario.ids = make([]string, 0, 10000)
+	return scenario
+}
+
+// Name returns the scenario identifier "simple".
+func (s *SimpleScenario) Name() string { return "simple" }
+// Description returns a short human-readable summary of the scenario.
+func (s *SimpleScenario) Description() string { return "Simple users table (username, email)" }
+// TableName returns the table this scenario queries, "users".
+func (s *SimpleScenario) TableName() string { return "users" }
+// MaxID returns the upper bound used for random integer id lookups.
+func (s *SimpleScenario) MaxID() int64 { return s.maxID }
+
+// Initialize loads up to 10000 existing user ids (as text, so UUID or string
+// keys also work) into the read cache.
+//
+// Every load worker calls Initialize on the scenario value the workers share,
+// so the freshly loaded ids replace the cache instead of being appended to
+// it; appending would add up to 10000 duplicates per worker and push the
+// cache past the 10000-entry cap that ExecuteWrite maintains.
+//
+// Loading is best-effort: if the query or the row iteration fails (for
+// example because the table does not exist) the cache is left untouched and
+// nil is returned. The lock is taken only for the final swap, not while rows
+// are being read from the connection.
+func (s *SimpleScenario) Initialize(ctx context.Context, conn *pgx.Conn) error {
+	rows, err := conn.Query(ctx, "SELECT id::text FROM users LIMIT 10000")
+	if err != nil {
+		return nil // Ignore error if table doesn't exist
+	}
+	defer rows.Close()
+
+	loaded := make([]string, 0, 10000)
+	for rows.Next() {
+		var id string
+		if err := rows.Scan(&id); err == nil {
+			loaded = append(loaded, id)
+		}
+	}
+	if rows.Err() != nil {
+		return nil // Partial result: keep whatever was cached before
+	}
+
+	s.mu.Lock()
+	s.ids = loaded
+	s.mu.Unlock()
+	return nil
+}
+
+// ExecuteRead fetches one users row by id. The id is drawn at random from the
+// cached ids when any are available (which also covers UUID keys); otherwise
+// a random integer in [1, maxID] is used. The scanned values are discarded;
+// only the query error is returned.
+func (s *SimpleScenario) ExecuteRead(ctx context.Context, conn *pgx.Conn) error {
+	const query = "SELECT id::text, username, email, created_at FROM users WHERE id = $1"
+
+	var lookupID interface{}
+	s.mu.RLock()
+	if cached := len(s.ids); cached == 0 {
+		lookupID = rand.Int63n(s.maxID) + 1
+	} else {
+		lookupID = s.ids[rand.Intn(cached)]
+	}
+	s.mu.RUnlock()
+
+	var (
+		userID, username, email string
+		createdAt               time.Time
+	)
+	row := conn.QueryRow(ctx, query, lookupID)
+	return row.Scan(&userID, &username, &email, &createdAt)
+}
+
+// ExecuteWrite inserts one user with a random numeric suffix and remembers
+// the new id for later reads. The cache holds at most 10000 ids: once full, a
+// random existing entry is overwritten instead of growing the slice.
+func (s *SimpleScenario) ExecuteWrite(ctx context.Context, conn *pgx.Conn) error {
+	suffix := rand.Int63()
+	username := fmt.Sprintf("user_%d", suffix)
+	email := fmt.Sprintf("user_%d@example.com", suffix)
+
+	var insertedID string
+	row := conn.QueryRow(ctx,
+		"INSERT INTO users (username, email) VALUES ($1, $2) RETURNING id::text",
+		username, email,
+	)
+	err := row.Scan(&insertedID)
+	if err != nil || insertedID == "" {
+		return err
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if len(s.ids) >= 10000 {
+		s.ids[rand.Intn(len(s.ids))] = insertedID
+		return nil
+	}
+	s.ids = append(s.ids, insertedID)
+	return nil
+}
+
+// JSONBScenario uses a table with a JSONB payload column (jsonb_data).
+type JSONBScenario struct {
+ // maxID bounds the random integer id ExecuteRead uses when no ids are cached.
+ maxID int64
+ // ids caches row ids as text; Initialize and ExecuteWrite add to it.
+ ids []string
+ // mu guards ids.
+ mu sync.RWMutex
+}
+
+// NewJSONBScenario returns a JSONBScenario with maxID set to 100000 and an
+// empty id cache pre-sized for 10000 entries.
+func NewJSONBScenario() *JSONBScenario {
+	scenario := &JSONBScenario{maxID: 100000}
+	scenario.ids = make([]string, 0, 10000)
+	return scenario
+}
+
+// Name returns the scenario identifier "jsonb".
+func (s *JSONBScenario) Name() string { return "jsonb" }
+// Description returns a short human-readable summary of the scenario.
+func (s *JSONBScenario) Description() string { return "Table with JSONB payload column" }
+// TableName returns the table this scenario queries, "jsonb_data".
+func (s *JSONBScenario) TableName() string { return "jsonb_data" }
+// MaxID returns the upper bound used for random integer id lookups.
+func (s *JSONBScenario) MaxID() int64 { return s.maxID }
+
+// Initialize loads up to 10000 existing jsonb_data ids (as text) into the
+// read cache.
+//
+// Every load worker calls Initialize on the scenario value the workers share,
+// so the freshly loaded ids replace the cache instead of being appended to
+// it; appending would add up to 10000 duplicates per worker and push the
+// cache past the 10000-entry cap that ExecuteWrite maintains.
+//
+// Loading is best-effort: if the query or the row iteration fails the cache
+// is left untouched and nil is returned. The lock is taken only for the final
+// swap, not while rows are being read from the connection.
+func (s *JSONBScenario) Initialize(ctx context.Context, conn *pgx.Conn) error {
+	rows, err := conn.Query(ctx, "SELECT id::text FROM jsonb_data LIMIT 10000")
+	if err != nil {
+		return nil // Best-effort: the table may not exist yet
+	}
+	defer rows.Close()
+
+	loaded := make([]string, 0, 10000)
+	for rows.Next() {
+		var id string
+		if err := rows.Scan(&id); err == nil {
+			loaded = append(loaded, id)
+		}
+	}
+	if rows.Err() != nil {
+		return nil // Partial result: keep whatever was cached before
+	}
+
+	s.mu.Lock()
+	s.ids = loaded
+	s.mu.Unlock()
+	return nil
+}
+
+// ExecuteRead fetches one jsonb_data row by id. The id is drawn at random
+// from the cached ids when any are available; otherwise a random integer in
+// [1, maxID] is used. The scanned values are discarded; only the query error
+// is returned.
+func (s *JSONBScenario) ExecuteRead(ctx context.Context, conn *pgx.Conn) error {
+	const query = "SELECT id, payload, created_at FROM jsonb_data WHERE id = $1"
+
+	var lookupID interface{}
+	s.mu.RLock()
+	if cached := len(s.ids); cached == 0 {
+		lookupID = rand.Int63n(s.maxID) + 1
+	} else {
+		lookupID = s.ids[rand.Intn(cached)]
+	}
+	s.mu.RUnlock()
+
+	var (
+		dataID    int64
+		payload   string
+		createdAt time.Time
+	)
+	row := conn.QueryRow(ctx, query, lookupID)
+	return row.Scan(&dataID, &payload, &createdAt)
+}
+
+// ExecuteWrite inserts one row whose payload comes from generateJSON and
+// remembers the new id for later reads. The cache holds at most 10000 ids:
+// once full, a random existing entry is overwritten instead of growing it.
+func (s *JSONBScenario) ExecuteWrite(ctx context.Context, conn *pgx.Conn) error {
+	document := generateJSON()
+
+	var insertedID string
+	row := conn.QueryRow(ctx,
+		"INSERT INTO jsonb_data (payload) VALUES ($1::jsonb) RETURNING id::text",
+		document,
+	)
+	err := row.Scan(&insertedID)
+	if err != nil || insertedID == "" {
+		return err
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if len(s.ids) >= 10000 {
+		s.ids[rand.Intn(len(s.ids))] = insertedID
+		return nil
+	}
+	s.ids = append(s.ids, insertedID)
+	return nil
+}
+
+// WideScenario uses a table with many columns (wide_data).
+type WideScenario struct {
+ // maxID bounds the random integer id ExecuteRead uses when no ids are cached.
+ maxID int64
+ // ids caches row ids as text; Initialize and ExecuteWrite add to it.
+ ids []string
+ // mu guards ids.
+ mu sync.RWMutex
+}
+
+// NewWideScenario returns a WideScenario with maxID set to 100000 and an
+// empty id cache pre-sized for 10000 entries.
+func NewWideScenario() *WideScenario {
+	scenario := &WideScenario{maxID: 100000}
+	scenario.ids = make([]string, 0, 10000)
+	return scenario
+}
+
+// Name returns the scenario identifier "wide".
+func (s *WideScenario) Name() string { return "wide" }
+// Description returns a short human-readable summary of the scenario.
+func (s *WideScenario) Description() string { return "Wide table with 20+ columns" }
+// TableName returns the table this scenario queries, "wide_data".
+func (s *WideScenario) TableName() string { return "wide_data" }
+// MaxID returns the upper bound used for random integer id lookups.
+func (s *WideScenario) MaxID() int64 { return s.maxID }
+
+// Initialize loads up to 10000 existing wide_data ids (as text) into the
+// read cache.
+//
+// Every load worker calls Initialize on the scenario value the workers share,
+// so the freshly loaded ids replace the cache instead of being appended to
+// it; appending would add up to 10000 duplicates per worker and push the
+// cache past the 10000-entry cap that ExecuteWrite maintains.
+//
+// Loading is best-effort: if the query or the row iteration fails the cache
+// is left untouched and nil is returned. The lock is taken only for the final
+// swap, not while rows are being read from the connection.
+func (s *WideScenario) Initialize(ctx context.Context, conn *pgx.Conn) error {
+	rows, err := conn.Query(ctx, "SELECT id::text FROM wide_data LIMIT 10000")
+	if err != nil {
+		return nil // Best-effort: the table may not exist yet
+	}
+	defer rows.Close()
+
+	loaded := make([]string, 0, 10000)
+	for rows.Next() {
+		var id string
+		if err := rows.Scan(&id); err == nil {
+			loaded = append(loaded, id)
+		}
+	}
+	if rows.Err() != nil {
+		return nil // Partial result: keep whatever was cached before
+	}
+
+	s.mu.Lock()
+	s.ids = loaded
+	s.mu.Unlock()
+	return nil
+}
+
+// ExecuteRead fetches every column of one wide_data row by id. The id is
+// drawn at random from the cached ids when any are available; otherwise a
+// random integer in [1, maxID] is used. The scanned values are discarded;
+// only the query error is returned.
+func (s *WideScenario) ExecuteRead(ctx context.Context, conn *pgx.Conn) error {
+	var lookupID interface{}
+	s.mu.RLock()
+	if cached := len(s.ids); cached == 0 {
+		lookupID = rand.Int63n(s.maxID) + 1
+	} else {
+		lookupID = s.ids[rand.Intn(cached)]
+	}
+	s.mu.RUnlock()
+
+	// Scan targets in SELECT order: id, 20 text columns, 5 int columns,
+	// created_at.
+	var (
+		dataID    int64
+		cols      [20]string
+		ints      [5]int32
+		createdAt time.Time
+	)
+	targets := make([]interface{}, 0, 1+len(cols)+len(ints)+1)
+	targets = append(targets, &dataID)
+	for i := range cols {
+		targets = append(targets, &cols[i])
+	}
+	for i := range ints {
+		targets = append(targets, &ints[i])
+	}
+	targets = append(targets, &createdAt)
+
+	row := conn.QueryRow(ctx,
+		`SELECT id,
+ col_01, col_02, col_03, col_04, col_05,
+ col_06, col_07, col_08, col_09, col_10,
+ col_11, col_12, col_13, col_14, col_15,
+ col_16, col_17, col_18, col_19, col_20,
+ int_01, int_02, int_03, int_04, int_05,
+ created_at
+ FROM wide_data WHERE id = $1`,
+		lookupID,
+	)
+	return row.Scan(targets...)
+}
+
+// ExecuteWrite inserts one wide_data row with random values: twenty strings
+// for col_01..col_20 followed by five int32 values for int_01..int_05.
+//
+// On success the returned ID is remembered for later reads. The cache grows
+// until it holds 10000 IDs; after that a randomly chosen entry is overwritten.
+func (s *WideScenario) ExecuteWrite(ctx context.Context, conn *pgx.Conn) error {
+	const textCount, intCount = 20, 5
+
+	values := make([]interface{}, 0, textCount+intCount)
+	for len(values) < textCount {
+		values = append(values, generateString(50))
+	}
+	for len(values) < textCount+intCount {
+		values = append(values, rand.Int31())
+	}
+
+	row := conn.QueryRow(ctx,
+		`INSERT INTO wide_data (
+ col_01, col_02, col_03, col_04, col_05,
+ col_06, col_07, col_08, col_09, col_10,
+ col_11, col_12, col_13, col_14, col_15,
+ col_16, col_17, col_18, col_19, col_20,
+ int_01, int_02, int_03, int_04, int_05
+ ) VALUES (
+ $1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
+ $11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
+ $21, $22, $23, $24, $25
+ ) RETURNING id::text`,
+		values...,
+	)
+
+	var insertedID string
+	if err := row.Scan(&insertedID); err != nil {
+		return err
+	}
+	if insertedID == "" {
+		return nil
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if len(s.ids) >= 10000 {
+		// Cache is full: evict a random entry.
+		s.ids[rand.Intn(len(s.ids))] = insertedID
+		return nil
+	}
+	s.ids = append(s.ids, insertedID)
+	return nil
+}
+
+// FKScenario uses tables with foreign key relationships.
+// Writes insert into items with a category_id; reads join items to categories.
+type FKScenario struct {
+ maxCategoryID int64 // writes pick category_id at random from [1, maxCategoryID]
+ ids []string // cached item IDs in text form, used as read targets
+ mu sync.RWMutex // guards ids
+}
+
+// NewFKScenario returns an FKScenario that picks categories from 1..100 and
+// starts with an empty ID cache pre-sized for 10000 entries.
+func NewFKScenario() *FKScenario {
+	scenario := new(FKScenario)
+	scenario.maxCategoryID = 100
+	scenario.ids = make([]string, 0, 10000)
+	return scenario
+}
+
+// Name returns the scenario identifier, "fk".
+func (s *FKScenario) Name() string {
+	return "fk"
+}
+
+// Description returns the human-readable summary of the scenario.
+func (s *FKScenario) Description() string {
+	return "Tables with foreign key lookup"
+}
+
+// TableName returns the primary table of the scenario.
+func (s *FKScenario) TableName() string {
+	return "items"
+}
+
+// MaxID always returns 0: reads use cached IDs only (not used for UUIDs).
+func (s *FKScenario) MaxID() int64 {
+	return 0
+}
+
+// Initialize seeds the ID cache with up to 10000 existing items IDs.
+//
+// The preload is best-effort: if the query fails (for example because the
+// table does not exist) the cache is left untouched and nil is returned.
+// Rows that fail to scan are skipped.
+//
+// The cache is replaced rather than appended to, so calling Initialize more
+// than once cannot push it past its 10000-entry bound or fill it with
+// duplicates.
+func (s *FKScenario) Initialize(ctx context.Context, conn *pgx.Conn) error {
+	rows, err := conn.Query(ctx, "SELECT id::text FROM items LIMIT 10000")
+	if err != nil {
+		return nil // Ignore error if table doesn't exist or is empty
+	}
+	defer rows.Close()
+
+	// Drain the result set before taking the lock so concurrent readers and
+	// writers are not blocked on network I/O.
+	loaded := make([]string, 0, 10000)
+	for rows.Next() {
+		var id string
+		if err := rows.Scan(&id); err == nil {
+			loaded = append(loaded, id)
+		}
+	}
+
+	s.mu.Lock()
+	s.ids = loaded
+	s.mu.Unlock()
+	return nil
+}
+
+// ExecuteRead loads one item together with its category through a join.
+//
+// The item ID is chosen at random from the cache. While the cache is empty
+// there is nothing to look up, so the call does nothing and returns nil.
+// Scanned values are discarded; only the query/scan error is returned.
+func (s *FKScenario) ExecuteRead(ctx context.Context, conn *pgx.Conn) error {
+	var target string
+
+	s.mu.RLock()
+	if n := len(s.ids); n > 0 {
+		target = s.ids[rand.Intn(n)]
+	}
+	s.mu.RUnlock()
+
+	if target == "" {
+		return nil
+	}
+
+	var (
+		itemID, itemName         string // IDs are read back as text
+		categoryID, categoryName string
+		createdAt                time.Time
+	)
+
+	row := conn.QueryRow(ctx,
+		`SELECT i.id::text, i.name, i.created_at, c.id::text, c.name
+ FROM items i
+ JOIN categories c ON i.category_id = c.id
+ WHERE i.id = $1`,
+		target,
+	)
+	return row.Scan(&itemID, &itemName, &createdAt, &categoryID, &categoryName)
+}
+
+// ExecuteWrite inserts one item that references a random category in
+// [1, maxCategoryID] and carries a random "item_<n>" name.
+//
+// On success the returned ID is remembered for later reads. The cache grows
+// until it holds 10000 IDs; after that a randomly chosen entry is overwritten.
+func (s *FKScenario) ExecuteWrite(ctx context.Context, conn *pgx.Conn) error {
+	category := rand.Int63n(s.maxCategoryID) + 1
+	itemName := fmt.Sprintf("item_%d", rand.Int63())
+
+	row := conn.QueryRow(ctx,
+		"INSERT INTO items (category_id, name) VALUES ($1, $2) RETURNING id::text",
+		category, itemName,
+	)
+
+	var insertedID string
+	if err := row.Scan(&insertedID); err != nil {
+		return err
+	}
+	if insertedID == "" {
+		return nil
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if len(s.ids) >= 10000 {
+		// Cache is full: overwrite a random slot so it keeps turning over.
+		s.ids[rand.Intn(len(s.ids))] = insertedID
+		return nil
+	}
+	s.ids = append(s.ids, insertedID)
+	return nil
+}
+
+// buildColumnList joins column names with ", ". When prefix is non-empty each
+// name is first qualified as "prefix.name".
+func buildColumnList(columns []string, prefix string) string {
+	names := columns
+	if prefix != "" {
+		names = make([]string, 0, len(columns))
+		for _, col := range columns {
+			names = append(names, prefix+"."+col)
+		}
+	}
+	return strings.Join(names, ", ")
+}
+
+// buildPlaceholders returns count positional parameters joined with ", ",
+// i.e. "$1, $2, ..., $count". A count of zero yields the empty string.
+func buildPlaceholders(count int) string {
+	parts := make([]string, 0, count)
+	for n := 1; n <= count; n++ {
+		parts = append(parts, "$"+fmt.Sprint(n))
+	}
+	return strings.Join(parts, ", ")
+}
diff --git a/schema/custom.go b/schema/custom.go
new file mode 100644
index 0000000..eb53223
--- /dev/null
+++ b/schema/custom.go
@@ -0,0 +1,342 @@
+package schema
+
+import (
+ "context"
+ "fmt"
+ "math/rand"
+ "strings"
+ "sync"
+
+ "github.com/jackc/pgx/v5"
+)
+
+// ColumnInfo holds metadata about a table column, as read from
+// information_schema.columns by CustomScenario.Initialize.
+type ColumnInfo struct {
+ Name string // column_name
+ DataType string // data_type as reported by information_schema
+ IsNullable bool // true when is_nullable = 'YES'
+ HasDefault bool // true when column_default is not NULL
+ IsSerial bool // SERIAL/BIGSERIAL columns (auto-generated)
+}
+
+// CustomScenario dynamically introspects a table and generates appropriate queries.
+// All discovered state is filled in by Initialize.
+type CustomScenario struct {
+ tableName string // table name without schema; empty means auto-discover in Initialize
+ schemaName string // schema name; defaults to "public"
+ initialized bool // set once Initialize has completed successfully
+ mu sync.RWMutex
+
+ // Discovered column information
+ columns []ColumnInfo
+ insertColumns []ColumnInfo // Columns we can insert into (excludes serials)
+ primaryKey string
+ primaryKeyType string // "uuid", "integer", etc.
+
+ // Pre-built queries
+ insertQuery string
+ selectQuery string
+
+ // For read operations
+ maxID int64 // ceiling for random integer IDs
+ ids []string // Cache for UUIDs or non-sequential IDs
+}
+
+// NewCustomScenario builds an uninitialized scenario for tableName.
+//
+// A name of the form "schema.table" is split at the first dot; any other name
+// (including the empty string) is taken as a table in the "public" schema.
+// The scenario starts with a random-read ceiling of 100000 and an empty ID
+// cache pre-sized for 10000 entries.
+func NewCustomScenario(tableName string) *CustomScenario {
+	scenario := &CustomScenario{
+		tableName:  tableName,
+		schemaName: "public",
+		maxID:      100000,
+		ids:        make([]string, 0, 10000),
+	}
+
+	if schema, table, qualified := strings.Cut(tableName, "."); qualified {
+		scenario.schemaName = schema
+		scenario.tableName = table
+	}
+
+	return scenario
+}
+
+// Name returns "custom:<table>", or "custom:auto" while no table name is set.
+//
+// The table name is read under the read lock because Initialize may assign it
+// (under the write lock) when it auto-discovers a table.
+func (s *CustomScenario) Name() string {
+	s.mu.RLock()
+	table := s.tableName
+	s.mu.RUnlock()
+
+	if table == "" {
+		return "custom:auto"
+	}
+	return "custom:" + table
+}
+
+// Description returns "Custom table: " followed by the table name (which is
+// empty until a table has been named or auto-discovered).
+//
+// The table name is read under the read lock because Initialize may assign it
+// (under the write lock) when it auto-discovers a table.
+func (s *CustomScenario) Description() string {
+	s.mu.RLock()
+	table := s.tableName
+	s.mu.RUnlock()
+
+	return "Custom table: " + table
+}
+
+// TableName returns the table name without its schema.
+//
+// The value is read under the read lock because Initialize may assign it
+// (under the write lock) when it auto-discovers a table.
+func (s *CustomScenario) TableName() string {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+	return s.tableName
+}
+
+// MaxID returns the current ceiling for random integer IDs, read under the
+// scenario's read lock.
+func (s *CustomScenario) MaxID() int64 {
+	s.mu.RLock()
+	ceiling := s.maxID
+	s.mu.RUnlock()
+	return ceiling
+}
+
+// Initialize introspects the target table and prepares the statements used by
+// ExecuteRead and ExecuteWrite. Once it has succeeded, later calls return nil
+// immediately.
+//
+// Steps:
+//  1. If no table was named, pick the first base table in the public schema
+//     (ordered by name).
+//  2. Load column metadata from information_schema.columns. Columns that are
+//     serial, identity, or generated are excluded from inserts because the
+//     database produces their values.
+//  3. Find the primary key; if none can be found the first column is used.
+//  4. Integer keys: read MAX(pk) as the ceiling for random reads. Any other
+//     key type: cache up to 10000 existing keys as text and set maxID to 0 so
+//     that reads never send a random integer to a non-integer key and
+//     ExecuteWrite starts filling the cache even when the table was empty.
+//  5. Build the INSERT and SELECT statements. Every identifier is quoted with
+//     pgx.Identifier.Sanitize so mixed-case or reserved names work and a
+//     user-supplied table name cannot inject SQL.
+//
+// It returns an error when auto-discovery fails, the column query fails, or
+// the table has no columns.
+func (s *CustomScenario) Initialize(ctx context.Context, conn *pgx.Conn) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.initialized {
+		return nil
+	}
+
+	// Auto-discover table if not specified
+	if s.tableName == "" {
+		err := conn.QueryRow(ctx, `
+ SELECT table_schema, table_name
+ FROM information_schema.tables
+ WHERE table_schema = 'public'
+ AND table_type = 'BASE TABLE'
+ ORDER BY table_name LIMIT 1
+ `).Scan(&s.schemaName, &s.tableName)
+
+		if err != nil {
+			return fmt.Errorf("failed to auto-discover table: %w", err)
+		}
+	}
+
+	quotedTable := pgx.Identifier{s.schemaName, s.tableName}.Sanitize()
+
+	// Get column information from information_schema.
+	// COALESCE keeps is_serial non-NULL when column_default is NULL.
+	rows, err := conn.Query(ctx, `
+ SELECT
+ c.column_name,
+ c.data_type,
+ c.is_nullable = 'YES' as is_nullable,
+ c.column_default IS NOT NULL as has_default,
+ COALESCE(c.column_default LIKE 'nextval%' OR c.is_identity = 'YES' OR c.is_generated = 'ALWAYS', false) as is_serial
+ FROM information_schema.columns c
+ WHERE c.table_schema = $1 AND c.table_name = $2
+ ORDER BY c.ordinal_position
+ `, s.schemaName, s.tableName)
+	if err != nil {
+		return fmt.Errorf("failed to query columns: %w", err)
+	}
+	defer rows.Close()
+
+	s.columns = nil
+	s.insertColumns = nil
+
+	for rows.Next() {
+		var col ColumnInfo
+		if err := rows.Scan(&col.Name, &col.DataType, &col.IsNullable, &col.HasDefault, &col.IsSerial); err != nil {
+			return fmt.Errorf("failed to scan column info: %w", err)
+		}
+		s.columns = append(s.columns, col)
+
+		// Only columns the database does not fill in itself are inserted.
+		if !col.IsSerial {
+			s.insertColumns = append(s.insertColumns, col)
+		}
+	}
+	if err := rows.Err(); err != nil {
+		return fmt.Errorf("failed to query columns: %w", err)
+	}
+
+	if len(s.columns) == 0 {
+		return fmt.Errorf("table %s.%s not found or has no columns", s.schemaName, s.tableName)
+	}
+
+	// Get primary key column and type. The quoted, schema-qualified name is
+	// valid regclass input and preserves case.
+	var pkType string
+	err = conn.QueryRow(ctx, `
+ SELECT a.attname, format_type(a.atttypid, a.atttypmod) as data_type
+ FROM pg_index i
+ JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
+ WHERE i.indrelid = $1::regclass AND i.indisprimary
+ LIMIT 1
+ `, quotedTable).Scan(&s.primaryKey, &pkType)
+
+	if err != nil {
+		// No primary key found, use first column
+		s.primaryKey = s.columns[0].Name
+		s.primaryKeyType = s.columns[0].DataType
+	} else {
+		s.primaryKeyType = pkType
+	}
+	quotedPK := pgx.Identifier{s.primaryKey}.Sanitize()
+
+	// Choose the read strategy from the key type. Only exact integer type
+	// names select the integer strategy; a substring test for "int" would
+	// also match unrelated types such as "interval" or "point". Everything
+	// else uses the ID cache, which copes with UUIDs, text and sparse keys.
+	isIntegerPK := false
+	switch strings.ToLower(s.primaryKeyType) {
+	case "smallint", "integer", "bigint", "int2", "int4", "int8":
+		isIntegerPK = true
+	}
+
+	if isIntegerPK {
+		// Integer strategy: Get max ID
+		var maxID *int64
+		err = conn.QueryRow(ctx, "SELECT MAX("+quotedPK+") FROM "+quotedTable).Scan(&maxID)
+		if err == nil && maxID != nil {
+			s.maxID = *maxID
+		}
+		if s.maxID < 1 {
+			s.maxID = 1 // At least 1
+		}
+	} else {
+		// Cache strategy: maxID == 0 marks "no random integer reads".
+		s.maxID = 0
+
+		// Cast to text so any key type scans into a string.
+		idRows, err := conn.Query(ctx, "SELECT "+quotedPK+"::text FROM "+quotedTable+" LIMIT 10000")
+		if err == nil {
+			for idRows.Next() {
+				var id string
+				if err := idRows.Scan(&id); err == nil {
+					s.ids = append(s.ids, id)
+				}
+			}
+			idRows.Close()
+		}
+	}
+
+	// Build INSERT query
+	if len(s.insertColumns) > 0 {
+		colNames := make([]string, len(s.insertColumns))
+		placeholders := make([]string, len(s.insertColumns))
+		for i, col := range s.insertColumns {
+			colNames[i] = pgx.Identifier{col.Name}.Sanitize()
+			placeholders[i] = fmt.Sprintf("$%d", i+1)
+		}
+		// The key is returned as text so it scans into a string for any type.
+		s.insertQuery = fmt.Sprintf(
+			"INSERT INTO %s (%s) VALUES (%s) RETURNING %s::text",
+			quotedTable,
+			strings.Join(colNames, ", "),
+			strings.Join(placeholders, ", "),
+			quotedPK,
+		)
+	}
+
+	// Build SELECT query
+	colNames := make([]string, len(s.columns))
+	for i, col := range s.columns {
+		colNames[i] = pgx.Identifier{col.Name}.Sanitize()
+	}
+	s.selectQuery = fmt.Sprintf(
+		"SELECT %s FROM %s WHERE %s = $1",
+		strings.Join(colNames, ", "),
+		quotedTable,
+		quotedPK,
+	)
+
+	s.initialized = true
+	return nil
+}
+
+// ExecuteRead selects one row of the custom table by primary key.
+//
+// The scenario is initialized on first use. The key comes from the ID cache
+// when it has entries, otherwise from a random integer in [1, maxID] when
+// maxID is positive. With neither available the call is a no-op returning
+// nil. All columns are scanned into throwaway values; only the query/scan
+// error is returned.
+func (s *CustomScenario) ExecuteRead(ctx context.Context, conn *pgx.Conn) error {
+	s.mu.RLock()
+	if !s.initialized {
+		s.mu.RUnlock()
+		if err := s.Initialize(ctx, conn); err != nil {
+			return err
+		}
+		s.mu.RLock()
+	}
+	stmt, ceiling, width := s.selectQuery, s.maxID, len(s.columns)
+
+	// Pick from the cache while the lock is still held.
+	var key interface{}
+	if stmt != "" && len(s.ids) > 0 {
+		key = s.ids[rand.Intn(len(s.ids))]
+	}
+	s.mu.RUnlock()
+
+	if stmt == "" {
+		return fmt.Errorf("custom scenario not initialized")
+	}
+
+	if key == nil && ceiling > 0 {
+		key = rand.Int63n(ceiling) + 1
+	}
+
+	// Empty table or failed init: nothing to read, behave like a no-op.
+	if key == nil {
+		return nil
+	}
+
+	// One untyped scan target per selected column.
+	scanned := make([]interface{}, width)
+	targets := make([]interface{}, 0, width)
+	for i := range scanned {
+		targets = append(targets, &scanned[i])
+	}
+
+	row := conn.QueryRow(ctx, stmt, key)
+	return row.Scan(targets...)
+}
+
+// ExecuteWrite inserts one row of generated values into the custom table.
+//
+// The scenario is initialized on first use. Each insertable column gets a
+// value from GenerateValue. On success the returned key is added to the ID
+// cache, but only when the cache is in use (it already has entries, or maxID
+// is 0). The cache holds at most 10000 keys; once full, a random entry is
+// overwritten.
+func (s *CustomScenario) ExecuteWrite(ctx context.Context, conn *pgx.Conn) error {
+	s.mu.RLock()
+	if !s.initialized {
+		s.mu.RUnlock()
+		if err := s.Initialize(ctx, conn); err != nil {
+			return err
+		}
+		s.mu.RLock()
+	}
+	stmt, cols := s.insertQuery, s.insertColumns
+	s.mu.RUnlock()
+
+	if stmt == "" || len(cols) == 0 {
+		return fmt.Errorf("custom scenario not initialized or no insertable columns")
+	}
+
+	params := make([]interface{}, 0, len(cols))
+	for i := range cols {
+		params = append(params, GenerateValue(cols[i].DataType, cols[i].Name))
+	}
+
+	var insertedID string
+	if err := conn.QueryRow(ctx, stmt, params...).Scan(&insertedID); err != nil {
+		return err
+	}
+	if insertedID == "" {
+		return nil
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// Use cache if we have IDs or maxID is 0 (meaning not int strategy)
+	usesCache := len(s.ids) > 0 || s.maxID == 0
+	switch {
+	case !usesCache:
+		// Integer strategy: nothing to remember.
+	case len(s.ids) < 10000:
+		s.ids = append(s.ids, insertedID)
+	default:
+		s.ids[rand.Intn(len(s.ids))] = insertedID
+	}
+	return nil
+}
+
+// IsInitialized reports whether Initialize has completed successfully.
+func (s *CustomScenario) IsInitialized() bool {
+	s.mu.RLock()
+	done := s.initialized
+	s.mu.RUnlock()
+	return done
+}
+
+// GetColumns returns a copy of the discovered columns (for debugging/info).
+// The caller may modify the returned slice without affecting the scenario.
+func (s *CustomScenario) GetColumns() []ColumnInfo {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	snapshot := make([]ColumnInfo, 0, len(s.columns))
+	snapshot = append(snapshot, s.columns...)
+	return snapshot
+}
diff --git a/schema/generator.go b/schema/generator.go
new file mode 100644
index 0000000..72a4e6f
--- /dev/null
+++ b/schema/generator.go
@@ -0,0 +1,151 @@
+package schema
+
+import (
+ "encoding/json"
+ "fmt"
+ "math/rand"
+ "strings"
+ "time"
+
+ "github.com/brianvoe/gofakeit/v6"
+ "github.com/google/uuid"
+)
+
+// GenerateValue returns a random value suitable for inserting into a
+// PostgreSQL column of type colType.
+//
+// Both arguments are matched case-insensitively. Columns with a recognised
+// non-text type (integers, boolean, floats, numeric, uuid, timestamps, date,
+// time, json, bytea, interval) always receive a value of that type. Only
+// text-like columns and columns of an unrecognised type consult the column
+// name for hints ("email", "city", ...), so that e.g. a bigint "org_id" or a
+// boolean "is_mobile" is never handed a company name or a phone number. When
+// no hint matches, such columns receive a short sentence.
+func GenerateValue(colType string, colName string) interface{} {
+	colType = strings.ToLower(colType)
+	colName = strings.ToLower(colName)
+
+	// Recognised non-text types: the type alone decides the value.
+	switch {
+	case colType == "integer", colType == "int", colType == "int4":
+		return int32(gofakeit.Number(0, 1000000))
+
+	case colType == "bigint", colType == "int8":
+		return int64(gofakeit.Number(0, 1000000000))
+
+	case colType == "smallint", colType == "int2":
+		return int16(gofakeit.Number(0, 32000))
+
+	case colType == "boolean", colType == "bool":
+		return gofakeit.Bool()
+
+	case colType == "real", colType == "float4":
+		return float32(gofakeit.Float64Range(0, 1000))
+
+	case colType == "double precision", colType == "float8":
+		return gofakeit.Float64Range(0, 10000)
+
+	case strings.HasPrefix(colType, "numeric"), strings.HasPrefix(colType, "decimal"):
+		return gofakeit.Float64Range(0, 100000)
+
+	case colType == "uuid":
+		return uuid.New().String()
+
+	case colType == "timestamp", colType == "timestamp without time zone":
+		return gofakeit.Date()
+
+	case colType == "timestamptz", colType == "timestamp with time zone":
+		return gofakeit.Date()
+
+	case colType == "date":
+		return gofakeit.Date().Format("2006-01-02")
+
+	case colType == "time", colType == "time without time zone":
+		return gofakeit.Date().Format("15:04:05")
+
+	case colType == "timetz", colType == "time with time zone":
+		return gofakeit.Date().Format("15:04:05-07:00")
+
+	case colType == "jsonb", colType == "json":
+		return generateJSON()
+
+	case colType == "bytea":
+		return []byte(gofakeit.Sentence(5))
+
+	case strings.HasPrefix(colType, "interval"):
+		return fmt.Sprintf("%d hours", gofakeit.Number(1, 24))
+	}
+
+	// Text-like or unrecognised type: use the column name for realistic data.
+	switch {
+	case strings.Contains(colName, "email"):
+		return gofakeit.Email()
+	case strings.Contains(colName, "username") || strings.Contains(colName, "user_name"):
+		return gofakeit.Username()
+	case strings.Contains(colName, "first_name") || strings.Contains(colName, "firstname"):
+		return gofakeit.FirstName()
+	case strings.Contains(colName, "last_name") || strings.Contains(colName, "lastname"):
+		return gofakeit.LastName()
+	case strings.Contains(colName, "full_name") || strings.Contains(colName, "fullname") || colName == "name":
+		return gofakeit.Name()
+	case strings.Contains(colName, "phone") || strings.Contains(colName, "cell") || strings.Contains(colName, "mobile"):
+		return gofakeit.Phone()
+	case strings.Contains(colName, "city"):
+		return gofakeit.City()
+	case strings.Contains(colName, "country"):
+		return gofakeit.Country()
+	case strings.Contains(colName, "state") || strings.Contains(colName, "province"):
+		return gofakeit.State()
+	case strings.Contains(colName, "zip") || strings.Contains(colName, "postal"):
+		return gofakeit.Zip()
+	case strings.Contains(colName, "address"):
+		return gofakeit.Address().Address
+	case strings.Contains(colName, "company") || strings.Contains(colName, "org"):
+		return gofakeit.Company()
+	case strings.Contains(colName, "job") || strings.Contains(colName, "title"):
+		return gofakeit.JobTitle()
+	case strings.Contains(colName, "bio") || strings.Contains(colName, "description"):
+		return gofakeit.Sentence(10)
+	case strings.Contains(colName, "url") || strings.Contains(colName, "link") || strings.Contains(colName, "website"):
+		return gofakeit.URL()
+	case strings.Contains(colName, "ipv4"):
+		return gofakeit.IPv4Address()
+	case strings.Contains(colName, "ipv6"):
+		return gofakeit.IPv6Address()
+	case strings.Contains(colName, "user_agent"):
+		return gofakeit.UserAgent()
+	}
+
+	// No name hint. A text type that carries an explicit length, such as
+	// "varchar(20)", gets a shorter sentence; the length itself is not parsed.
+	isText := colType == "text" ||
+		strings.HasPrefix(colType, "varchar") ||
+		strings.HasPrefix(colType, "char") // also covers "character varying"
+	if isText && strings.Contains(colType, "(") {
+		return gofakeit.Sentence(3)
+	}
+
+	// Plain text columns and unknown types default to a string.
+	return gofakeit.Sentence(5)
+}
+
+// generateJSON returns a JSON object, encoded as a string, describing a fake
+// person: id, timestamp, name, active, score, tags and a nested metadata
+// object. The marshal error is deliberately ignored.
+func generateJSON() string {
+	person := gofakeit.Person()
+
+	doc := make(map[string]interface{}, 7)
+	doc["id"] = gofakeit.UUID()
+	doc["timestamp"] = time.Now().Unix()
+	doc["name"] = person.FirstName + " " + person.LastName
+	doc["active"] = gofakeit.Bool()
+	doc["score"] = gofakeit.Float64Range(0, 100)
+	doc["tags"] = []string{gofakeit.Word(), gofakeit.Word()}
+	doc["metadata"] = map[string]interface{}{
+		"version": "1.0",
+		"source":  "generated",
+		"job":     person.Job.Title,
+	}
+
+	encoded, _ := json.Marshal(doc)
+	return string(encoded)
+}
+
+// generateBytes returns a slice of length bytes filled from math/rand.
+func generateBytes(length int) []byte {
+	buf := make([]byte, length)
+	_, _ = rand.Read(buf)
+	return buf
+}
+
+// generateString returns a random sentence sized for a column of about maxLen
+// characters. The word count is only an estimate, so when maxLen is positive
+// the sentence is cut to at most maxLen bytes (and any trailing space left by
+// the cut is removed) to keep the value within the requested length.
+// NOTE(review): the cut is byte-based and assumes gofakeit emits ASCII words.
+func generateString(maxLen int) string {
+	sentence := gofakeit.Sentence(maxLen/10 + 1) // Approx words to match length
+	if maxLen > 0 && len(sentence) > maxLen {
+		sentence = strings.TrimRight(sentence[:maxLen], " ")
+	}
+	return sentence
+}
diff --git a/schema/scenario.go b/schema/scenario.go
new file mode 100644
index 0000000..7ca02ba
--- /dev/null
+++ b/schema/scenario.go
@@ -0,0 +1,118 @@
+package schema
+
+import (
+	"context"
+	"sort"
+	"sync"
+
+	"github.com/jackc/pgx/v5"
+)
+
+// Scenario defines the interface for different database schema scenarios.
+// Each scenario encapsulates the read/write operations for a specific table structure.
+type Scenario interface {
+ // Name returns the unique identifier for this scenario.
+ // The Registry uses it as the lookup key.
+ Name() string
+
+ // Description returns a human-readable description.
+ Description() string
+
+ // TableName returns the primary table name used by this scenario.
+ TableName() string
+
+ // MaxID returns the maximum ID for read operations (for seeded data).
+ MaxID() int64
+
+ // ExecuteRead performs a read operation using this scenario's query.
+ ExecuteRead(ctx context.Context, conn *pgx.Conn) error
+
+ // ExecuteWrite performs a write operation using this scenario's query.
+ ExecuteWrite(ctx context.Context, conn *pgx.Conn) error
+
+ // Initialize performs any necessary setup (e.g., table introspection for custom scenarios).
+ Initialize(ctx context.Context, conn *pgx.Conn) error
+}
+
+// ScenarioInfo contains metadata about a scenario for API responses.
+// Registry.List fills each field from the matching Scenario method.
+type ScenarioInfo struct {
+ Name string `json:"name"` // Scenario.Name
+ Description string `json:"description"` // Scenario.Description
+ TableName string `json:"table_name"` // Scenario.TableName
+}
+
+// Registry holds all available scenarios, keyed by Scenario.Name.
+type Registry struct {
+ mu sync.RWMutex // guards scenarios
+ scenarios map[string]Scenario // registered scenarios by name
+}
+
+// NewRegistry returns a registry pre-populated with the builtin scenarios:
+// simple, JSONB, wide and FK.
+func NewRegistry() *Registry {
+	reg := &Registry{scenarios: map[string]Scenario{}}
+
+	builtins := []Scenario{
+		NewSimpleScenario(),
+		NewJSONBScenario(),
+		NewWideScenario(),
+		NewFKScenario(),
+	}
+	for _, sc := range builtins {
+		reg.Register(sc)
+	}
+
+	return reg
+}
+
+// Register stores s under its Name, replacing any scenario already
+// registered under the same name.
+func (r *Registry) Register(s Scenario) {
+	key := s.Name()
+
+	r.mu.Lock()
+	r.scenarios[key] = s
+	r.mu.Unlock()
+}
+
+// Get looks up a scenario by name. The boolean reports whether it was found.
+func (r *Registry) Get(name string) (Scenario, bool) {
+	r.mu.RLock()
+	scenario, found := r.scenarios[name]
+	r.mu.RUnlock()
+	return scenario, found
+}
+
+// List returns info about all registered scenarios in a stable order.
+//
+// The builtin scenarios come first, in the fixed order simple, jsonb, wide,
+// fk (any that are not registered are skipped). All other scenarios follow,
+// sorted by registry key, so repeated calls produce the same order instead of
+// depending on map iteration.
+func (r *Registry) List() []ScenarioInfo {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+
+	builtin := []string{"simple", "jsonb", "wide", "fk"}
+
+	names := make([]string, 0, len(r.scenarios))
+	isBuiltin := make(map[string]bool, len(builtin))
+	for _, name := range builtin {
+		isBuiltin[name] = true
+		if _, ok := r.scenarios[name]; ok {
+			names = append(names, name)
+		}
+	}
+
+	// Everything not in the predefined order, sorted for determinism.
+	var extras []string
+	for name := range r.scenarios {
+		if !isBuiltin[name] {
+			extras = append(extras, name)
+		}
+	}
+	sort.Strings(extras)
+	names = append(names, extras...)
+
+	list := make([]ScenarioInfo, 0, len(names))
+	for _, name := range names {
+		s := r.scenarios[name]
+		list = append(list, ScenarioInfo{
+			Name:        s.Name(),
+			Description: s.Description(),
+			TableName:   s.TableName(),
+		})
+	}
+	return list
+}
+
+// CreateCustomScenario builds a custom scenario for tableName by delegating
+// to NewCustomScenario. The scenario is returned to the caller and is not
+// added to the registry.
+func (r *Registry) CreateCustomScenario(tableName string) *CustomScenario {
+	scenario := NewCustomScenario(tableName)
+	return scenario
+}