Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions plutus_agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def build_parser():
p = argparse.ArgumentParser(
prog="plutus", description=f"Plutus — {__tagline__}")
p.add_argument("--version", action="version", version=f"plutus v{__version__}")
p.add_argument("--db", help="database path (overrides PLUTUS_DB)")
sub = p.add_subparsers(dest="cmd")

pi = sub.add_parser("init", help="create config + database")
Expand Down Expand Up @@ -480,6 +481,9 @@ def main(argv=None):
_force_utf8()
parser = build_parser()
args = parser.parse_args(argv)
# Wire --db through to PLUTUS_DB env var so all db.connect() calls honor it
if hasattr(args, 'db') and args.db:
os.environ['PLUTUS_DB'] = args.db
if not getattr(args, "func", None):
parser.print_help()
return 0
Expand Down
55 changes: 52 additions & 3 deletions plutus_agent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,52 @@ def _minimal_yaml_read(path: Path) -> dict:
"""Tiny fallback reader (flat key: value, one level of nesting by indent).

Only used if PyYAML is missing. Good enough to read a config we wrote.
Supports simple YAML (key: value, one level of 2-space-indented nesting)
and JSON as a fallback.
"""
try:
import json
text = path.read_text(encoding="utf-8")
# we may have written JSON as a fallback
# Try JSON first if it looks like JSON
if text.lstrip().startswith("{"):
import json
return json.loads(text)

# Parse simple YAML: 'key: value' and one level of ' nested: value'
result = {}
current_section = None
for line in text.splitlines():
stripped = line.rstrip()
if not stripped or stripped.startswith("#"):
continue

# Top-level key (no leading spaces)
if not line.startswith(" ") and ":" in stripped:
key, _, val = stripped.partition(":")
key = key.strip()
val = val.strip()
if val:
# Simple value: try to parse as JSON literal, else string
try:
import json
result[key] = json.loads(val)
except (json.JSONDecodeError, ValueError):
result[key] = val
else:
# Empty value means nested section follows
result[key] = {}
current_section = key
# Nested key (2-space indent)
elif line.startswith(" ") and ":" in stripped and current_section:
key, _, val = stripped.strip().partition(":")
key = key.strip()
val = val.strip()
if val:
try:
import json
result[current_section][key] = json.loads(val)
except (json.JSONDecodeError, ValueError):
result[current_section][key] = val
return result
except Exception:
pass
return {}
Expand Down Expand Up @@ -231,8 +270,18 @@ def _strip_env_secrets(cfg: dict) -> dict:
def save(cfg: dict) -> Path:
"""Persist config to disk. Secrets that came from the environment are
stripped first (see :func:`_strip_env_secrets`) — ``config.yaml`` should
never hold a live key that was provided via env."""
never hold a live key that was provided via env.

Creates a timestamped backup of the existing config before overwriting.
"""
path = config_path()
# Create timestamped backup if file already exists
if path.exists():
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
backup_path = path.with_suffix(f".yaml.bak-{timestamp}")
import shutil
shutil.copy2(path, backup_path)
_dump_yaml(path, _strip_env_secrets(cfg))
return path

Expand Down
2 changes: 1 addition & 1 deletion plutus_agent/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
CREATE TABLE IF NOT EXISTS alerts_log (
id TEXT PRIMARY KEY,
org_id TEXT NOT NULL REFERENCES organizations(id) ON DELETE CASCADE,
workspace_id TEXT,
workspace_id TEXT REFERENCES workspaces(id) ON DELETE SET NULL,
kind TEXT NOT NULL, -- low_balance|budget_warn|budget_cap
message TEXT NOT NULL,
delivered INTEGER NOT NULL DEFAULT 0,
Expand Down
9 changes: 5 additions & 4 deletions plutus_agent/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from .. import __version__, bridge, config as cfgmod, db
from ..billing import StripeClient, BillingError, handle_webhook_event
from ..utils import strict_int
from . import api, views, auth as authmod

# Paths reachable without a session when auth is enabled.
Expand Down Expand Up @@ -317,10 +318,10 @@ def _ingest_usage(self, conn):
conn, org_id, provider=str(ev["provider"]),
model=ev.get("model"), task_type=ev.get("task_type", "general"),
workspace=ev.get("workspace"),
input_tokens=int(ev.get("input_tokens", 0) or 0),
output_tokens=int(ev.get("output_tokens", 0) or 0),
cache_read_tokens=int(ev.get("cache_read_tokens", 0) or 0),
reasoning_tokens=int(ev.get("reasoning_tokens", 0) or 0),
input_tokens=strict_int(ev.get("input_tokens", 0) or 0),
output_tokens=strict_int(ev.get("output_tokens", 0) or 0),
cache_read_tokens=strict_int(ev.get("cache_read_tokens", 0) or 0),
reasoning_tokens=strict_int(ev.get("reasoning_tokens", 0) or 0),
cost_usd=ev.get("cost_usd"),
source=ev.get("source", "api"),
pricing_overrides=cfg.get("pricing", {}).get("overrides"),
Expand Down
2 changes: 1 addition & 1 deletion plutus_agent/server/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _claims_from_id_token(id_token: str, cfg, nonce: str) -> dict:
raise AuthError("id_token nonce mismatch")
if not claims.get("email"):
raise AuthError("id_token has no email")
if claims.get("email_verified") in (False, "false"):
if claims.get("email_verified") not in (True, "true"):
raise AuthError("email is not verified")
return claims

Expand Down
26 changes: 26 additions & 0 deletions plutus_agent/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Utility functions for Plutus."""
from __future__ import annotations


def strict_int(value) -> int:
"""Parse an integer strictly, rejecting floats and non-integer strings.

Accepts:
- int: 5, 0, -10
- str with integer value: '5', '0', '-10'

Rejects:
- float: 1.9, 1.0
- str with float: '1.9', '1.0'

Raises:
ValueError: if value is a float or contains a decimal point
TypeError: if value cannot be converted to int
"""
if isinstance(value, float):
raise ValueError(f"strict_int: float {value} not allowed (would truncate)")
if isinstance(value, str):
if '.' in value:
raise ValueError(f"strict_int: string '{value}' contains decimal point")
return int(value)
return int(value)
177 changes: 176 additions & 1 deletion tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from plutus_agent import db, metering, pricing, demo, reports
from plutus_agent import db, metering, pricing, demo, reports, config as cfgmod
from plutus_agent.billing import handle_webhook_event
from plutus_agent.utils import strict_int
from plutus_agent.server import auth as authmod


_PATHS = {} # id(conn) -> file path, since sqlite3.Connection forbids attrs
Expand Down Expand Up @@ -399,5 +401,178 @@ def test_hook_meters_payload(self):
os.environ.pop("PLUTUS_ORG", None)


class TestPolishIssue37(unittest.TestCase):
"""Tests for the polish punch-list fixes (issue #37)."""

def test_strict_int_accepts_integers(self):
"""strict_int accepts int and integer-valued strings."""
self.assertEqual(strict_int(5), 5)
self.assertEqual(strict_int('5'), 5)
self.assertEqual(strict_int(0), 0)
self.assertEqual(strict_int('0'), 0)
self.assertEqual(strict_int(-10), -10)
self.assertEqual(strict_int('-10'), -10)

def test_strict_int_rejects_floats(self):
"""strict_int rejects float values and float-bearing strings."""
with self.assertRaises(ValueError):
strict_int(1.9)
with self.assertRaises(ValueError):
strict_int('1.9')
with self.assertRaises(ValueError):
strict_int(1.0)
with self.assertRaises(ValueError):
strict_int('1.0')

def test_db_wiring_honors_env(self):
"""db_path() honors PLUTUS_DB environment variable."""
orig = os.environ.get('PLUTUS_DB')
try:
test_path = '/tmp/test_plutus_custom.db'
os.environ['PLUTUS_DB'] = test_path
self.assertEqual(str(cfgmod.db_path()), test_path)
finally:
if orig is None:
os.environ.pop('PLUTUS_DB', None)
else:
os.environ['PLUTUS_DB'] = orig

def test_config_save_creates_backup(self):
"""config.save() creates a timestamped backup when file exists."""
import tempfile
import shutil
from pathlib import Path

home = tempfile.mkdtemp()
orig_home = os.environ.get('PLUTUS_HOME')
try:
os.environ['PLUTUS_HOME'] = home
cfg_path = cfgmod.config_path()

# First save: no backup should be created
initial_cfg = {'test': 'value1'}
cfgmod.save(initial_cfg)
self.assertTrue(cfg_path.exists())
backups = list(cfg_path.parent.glob('config.yaml.bak-*'))
self.assertEqual(len(backups), 0, "No backup on first save")

# Second save: backup should be created
import time
time.sleep(1.1) # Ensure different timestamp
updated_cfg = {'test': 'value2'}
cfgmod.save(updated_cfg)
backups = list(cfg_path.parent.glob('config.yaml.bak-*'))
self.assertGreaterEqual(len(backups), 1, "At least one backup after second save")

# Third save: another backup
time.sleep(1.1)
cfgmod.save({'test': 'value3'})
backups = list(cfg_path.parent.glob('config.yaml.bak-*'))
self.assertGreaterEqual(len(backups), 2, "At least two backups after third save")
finally:
if orig_home is None:
os.environ.pop('PLUTUS_HOME', None)
else:
os.environ['PLUTUS_HOME'] = orig_home
shutil.rmtree(home, ignore_errors=True)

def test_email_verified_requires_truthy(self):
"""email_verified must be explicitly True or 'true', missing/False rejected."""
cfg = {
'auth': {
'google_client_id': 'test-client-id',
'google_client_secret': 'test-secret',
'base_url': 'http://localhost:8420'
}
}

# Missing email_verified should raise
claims_missing = {
'aud': 'test-client-id',
'iss': 'https://accounts.google.com',
'exp': time.time() + 3600,
'nonce': 'test-nonce',
'email': 'test@example.com'
# email_verified missing
}
with self.assertRaises(authmod.AuthError) as ctx:
authmod._claims_from_id_token.__wrapped__ = authmod._claims_from_id_token
# We need to mock just the claims extraction part
# Since _claims_from_id_token expects a JWT, let's test directly
import base64
import json

def make_fake_token(claims):
# Create a fake JWT (we won't verify signature)
header = base64.urlsafe_b64encode(json.dumps({'alg': 'RS256'}).encode()).decode().rstrip('=')
payload = base64.urlsafe_b64encode(json.dumps(claims).encode()).decode().rstrip('=')
return f"{header}.{payload}.fake_signature"

fake_token = make_fake_token(claims_missing)
authmod._claims_from_id_token(fake_token, cfg, 'test-nonce')
self.assertIn('not verified', str(ctx.exception))

# False should raise
claims_false = claims_missing.copy()
claims_false['email_verified'] = False
with self.assertRaises(authmod.AuthError) as ctx:
fake_token = make_fake_token(claims_false)
authmod._claims_from_id_token(fake_token, cfg, 'test-nonce')
self.assertIn('not verified', str(ctx.exception))

# True should pass
claims_true = claims_missing.copy()
claims_true['email_verified'] = True
fake_token = make_fake_token(claims_true)
result = authmod._claims_from_id_token(fake_token, cfg, 'test-nonce')
self.assertEqual(result['email'], 'test@example.com')

# 'true' string should pass
claims_str = claims_missing.copy()
claims_str['email_verified'] = 'true'
fake_token = make_fake_token(claims_str)
result = authmod._claims_from_id_token(fake_token, cfg, 'test-nonce')
self.assertEqual(result['email'], 'test@example.com')

def test_minimal_yaml_read_roundtrip(self):
"""_minimal_yaml_read can parse simple YAML config."""
import tempfile
from pathlib import Path

fd, path_str = tempfile.mkstemp(suffix='.yaml')
os.close(fd)
path = Path(path_str)

try:
# Write a simple YAML config
yaml_content = """server:
host: 127.0.0.1
port: 8420
billing:
currency: usd
stripe_price_pro: price_123
alerts:
enabled: false
low_balance_usd: 10.0
"""
path.write_text(yaml_content, encoding='utf-8')

# Read it back with the minimal reader
result = cfgmod._minimal_yaml_read(path)

# Verify structure
self.assertIn('server', result)
self.assertIn('billing', result)
self.assertIn('alerts', result)
self.assertEqual(result['server']['host'], '127.0.0.1')
self.assertEqual(result['server']['port'], 8420)
self.assertEqual(result['billing']['currency'], 'usd')
self.assertEqual(result['billing']['stripe_price_pro'], 'price_123')
self.assertEqual(result['alerts']['enabled'], False)
self.assertEqual(result['alerts']['low_balance_usd'], 10.0)
finally:
path.unlink()


if __name__ == "__main__":
unittest.main()
Loading