diff --git a/plutus_agent/cli.py b/plutus_agent/cli.py index 885a6e1..7d76e8f 100644 --- a/plutus_agent/cli.py +++ b/plutus_agent/cli.py @@ -379,6 +379,7 @@ def build_parser(): p = argparse.ArgumentParser( prog="plutus", description=f"Plutus — {__tagline__}") p.add_argument("--version", action="version", version=f"plutus v{__version__}") + p.add_argument("--db", help="database path (overrides PLUTUS_DB)") sub = p.add_subparsers(dest="cmd") pi = sub.add_parser("init", help="create config + database") @@ -480,6 +481,9 @@ def main(argv=None): _force_utf8() parser = build_parser() args = parser.parse_args(argv) + # Wire --db through to PLUTUS_DB env var so all db.connect() calls honor it + if hasattr(args, 'db') and args.db: + os.environ['PLUTUS_DB'] = args.db if not getattr(args, "func", None): parser.print_help() return 0 diff --git a/plutus_agent/config.py b/plutus_agent/config.py index 120b801..1bd163d 100644 --- a/plutus_agent/config.py +++ b/plutus_agent/config.py @@ -127,13 +127,52 @@ def _minimal_yaml_read(path: Path) -> dict: """Tiny fallback reader (flat key: value, one level of nesting by indent). Only used if PyYAML is missing. Good enough to read a config we wrote. + Supports simple YAML (key: value, one level of 2-space-indented nesting) + and JSON as a fallback. """ try: - import json text = path.read_text(encoding="utf-8") - # we may have written JSON as a fallback + # Try JSON first if it looks like JSON if text.lstrip().startswith("{"): + import json return json.loads(text) + + # Parse simple YAML: 'key: value' and one level of ' nested: value' + result = {} + current_section = None + for line in text.splitlines(): + stripped = line.rstrip() + if not stripped or stripped.startswith("#"): + continue + + # Top-level key (no leading spaces) + if not line.startswith(" ") and ":" in stripped: + key, _, val = stripped.partition(":") + key = key.strip() + val = val.strip() + if val: + # Simple value: try to parse as JSON literal, else string + try: + import json + result[key] = json.loads(val) + except (json.JSONDecodeError, ValueError): + result[key] = val + else: + # Empty value means nested section follows + result[key] = {} + current_section = key + # Nested key (2-space indent) + elif line.startswith(" ") and ":" in stripped and current_section: + key, _, val = stripped.strip().partition(":") + key = key.strip() + val = val.strip() + if val: + try: + import json + result[current_section][key] = json.loads(val) + except (json.JSONDecodeError, ValueError): + result[current_section][key] = val + return result except Exception: pass return {} @@ -231,8 +270,18 @@ def _strip_env_secrets(cfg: dict) -> dict: def save(cfg: dict) -> Path: """Persist config to disk. Secrets that came from the environment are stripped first (see :func:`_strip_env_secrets`) — ``config.yaml`` should - never hold a live key that was provided via env.""" + never hold a live key that was provided via env. + + Creates a timestamped backup of the existing config before overwriting. + """ path = config_path() + # Create timestamped backup if file already exists + if path.exists(): + from datetime import datetime + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + backup_path = path.with_suffix(f".yaml.bak-{timestamp}") + import shutil + shutil.copy2(path, backup_path) _dump_yaml(path, _strip_env_secrets(cfg)) return path diff --git a/plutus_agent/db.py b/plutus_agent/db.py index 5a51514..41e5002 100644 --- a/plutus_agent/db.py +++ b/plutus_agent/db.py @@ -88,7 +88,7 @@ CREATE TABLE IF NOT EXISTS alerts_log ( id TEXT PRIMARY KEY, org_id TEXT NOT NULL REFERENCES organizations(id) ON DELETE CASCADE, - workspace_id TEXT, + workspace_id TEXT REFERENCES workspaces(id) ON DELETE SET NULL, kind TEXT NOT NULL, -- low_balance|budget_warn|budget_cap message TEXT NOT NULL, delivered INTEGER NOT NULL DEFAULT 0, diff --git a/plutus_agent/server/app.py b/plutus_agent/server/app.py index f202c29..786f362 100644 --- a/plutus_agent/server/app.py +++ b/plutus_agent/server/app.py @@ -15,6 +15,7 @@ from .. import __version__, bridge, config as cfgmod, db from ..billing import StripeClient, BillingError, handle_webhook_event +from ..utils import strict_int from . import api, views, auth as authmod # Paths reachable without a session when auth is enabled. @@ -317,10 +318,10 @@ def _ingest_usage(self, conn): conn, org_id, provider=str(ev["provider"]), model=ev.get("model"), task_type=ev.get("task_type", "general"), workspace=ev.get("workspace"), - input_tokens=int(ev.get("input_tokens", 0) or 0), - output_tokens=int(ev.get("output_tokens", 0) or 0), - cache_read_tokens=int(ev.get("cache_read_tokens", 0) or 0), - reasoning_tokens=int(ev.get("reasoning_tokens", 0) or 0), + input_tokens=strict_int(ev.get("input_tokens", 0) or 0), + output_tokens=strict_int(ev.get("output_tokens", 0) or 0), + cache_read_tokens=strict_int(ev.get("cache_read_tokens", 0) or 0), + reasoning_tokens=strict_int(ev.get("reasoning_tokens", 0) or 0), cost_usd=ev.get("cost_usd"), source=ev.get("source", "api"), pricing_overrides=cfg.get("pricing", {}).get("overrides"), diff --git a/plutus_agent/server/auth.py b/plutus_agent/server/auth.py index b6072b8..e95083d 100644 --- a/plutus_agent/server/auth.py +++ b/plutus_agent/server/auth.py @@ -135,7 +135,7 @@ def _claims_from_id_token(id_token: str, cfg, nonce: str) -> dict: raise AuthError("id_token nonce mismatch") if not claims.get("email"): raise AuthError("id_token has no email") - if claims.get("email_verified") in (False, "false"): + if claims.get("email_verified") not in (True, "true"): raise AuthError("email is not verified") return claims diff --git a/plutus_agent/utils.py b/plutus_agent/utils.py new file mode 100644 index 0000000..3af2234 --- /dev/null +++ b/plutus_agent/utils.py @@ -0,0 +1,26 @@ +"""Utility functions for Plutus.""" +from __future__ import annotations + + +def strict_int(value) -> int: + """Parse an integer strictly, rejecting floats and non-integer strings. + + Accepts: + - int: 5, 0, -10 + - str with integer value: '5', '0', '-10' + + Rejects: + - float: 1.9, 1.0 + - str with float: '1.9', '1.0' + + Raises: + ValueError: if value is a float or contains a decimal point + TypeError: if value cannot be converted to int + """ + if isinstance(value, float): + raise ValueError(f"strict_int: float {value} not allowed (would truncate)") + if isinstance(value, str): + if '.' in value: + raise ValueError(f"strict_int: string '{value}' contains decimal point") + return int(value) + return int(value) diff --git a/tests/test_engine.py b/tests/test_engine.py index fb71059..eb49342 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -8,8 +8,10 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from plutus_agent import db, metering, pricing, demo, reports +from plutus_agent import db, metering, pricing, demo, reports, config as cfgmod from plutus_agent.billing import handle_webhook_event +from plutus_agent.utils import strict_int +from plutus_agent.server import auth as authmod _PATHS = {} # id(conn) -> file path, since sqlite3.Connection forbids attrs @@ -399,5 +401,178 @@ def test_hook_meters_payload(self): os.environ.pop("PLUTUS_ORG", None) +class TestPolishIssue37(unittest.TestCase): + """Tests for the polish punch-list fixes (issue #37).""" + + def test_strict_int_accepts_integers(self): + """strict_int accepts int and integer-valued strings.""" + self.assertEqual(strict_int(5), 5) + self.assertEqual(strict_int('5'), 5) + self.assertEqual(strict_int(0), 0) + self.assertEqual(strict_int('0'), 0) + self.assertEqual(strict_int(-10), -10) + self.assertEqual(strict_int('-10'), -10) + + def test_strict_int_rejects_floats(self): + """strict_int rejects float values and float-bearing strings.""" + with self.assertRaises(ValueError): + strict_int(1.9) + with self.assertRaises(ValueError): + strict_int('1.9') + with self.assertRaises(ValueError): + strict_int(1.0) + with self.assertRaises(ValueError): + strict_int('1.0') + + def test_db_wiring_honors_env(self): + """db_path() honors PLUTUS_DB environment variable.""" + orig = os.environ.get('PLUTUS_DB') + try: + test_path = '/tmp/test_plutus_custom.db' + os.environ['PLUTUS_DB'] = test_path + self.assertEqual(str(cfgmod.db_path()), test_path) + finally: + if orig is None: + os.environ.pop('PLUTUS_DB', None) + else: + os.environ['PLUTUS_DB'] = orig + + def test_config_save_creates_backup(self): + """config.save() creates a timestamped backup when file exists.""" + import tempfile + import shutil + from pathlib import Path + + home = tempfile.mkdtemp() + orig_home = os.environ.get('PLUTUS_HOME') + try: + os.environ['PLUTUS_HOME'] = home + cfg_path = cfgmod.config_path() + + # First save: no backup should be created + initial_cfg = {'test': 'value1'} + cfgmod.save(initial_cfg) + self.assertTrue(cfg_path.exists()) + backups = list(cfg_path.parent.glob('config.yaml.bak-*')) + self.assertEqual(len(backups), 0, "No backup on first save") + + # Second save: backup should be created + import time + time.sleep(1.1) # Ensure different timestamp + updated_cfg = {'test': 'value2'} + cfgmod.save(updated_cfg) + backups = list(cfg_path.parent.glob('config.yaml.bak-*')) + self.assertGreaterEqual(len(backups), 1, "At least one backup after second save") + + # Third save: another backup + time.sleep(1.1) + cfgmod.save({'test': 'value3'}) + backups = list(cfg_path.parent.glob('config.yaml.bak-*')) + self.assertGreaterEqual(len(backups), 2, "At least two backups after third save") + finally: + if orig_home is None: + os.environ.pop('PLUTUS_HOME', None) + else: + os.environ['PLUTUS_HOME'] = orig_home + shutil.rmtree(home, ignore_errors=True) + + def test_email_verified_requires_truthy(self): + """email_verified must be explicitly True or 'true', missing/False rejected.""" + cfg = { + 'auth': { + 'google_client_id': 'test-client-id', + 'google_client_secret': 'test-secret', + 'base_url': 'http://localhost:8420' + } + } + + # Missing email_verified should raise + claims_missing = { + 'aud': 'test-client-id', + 'iss': 'https://accounts.google.com', + 'exp': time.time() + 3600, + 'nonce': 'test-nonce', + 'email': 'test@example.com' + # email_verified missing + } + with self.assertRaises(authmod.AuthError) as ctx: + authmod._claims_from_id_token.__wrapped__ = authmod._claims_from_id_token + # We need to mock just the claims extraction part + # Since _claims_from_id_token expects a JWT, let's test directly + import base64 + import json + + def make_fake_token(claims): + # Create a fake JWT (we won't verify signature) + header = base64.urlsafe_b64encode(json.dumps({'alg': 'RS256'}).encode()).decode().rstrip('=') + payload = base64.urlsafe_b64encode(json.dumps(claims).encode()).decode().rstrip('=') + return f"{header}.{payload}.fake_signature" + + fake_token = make_fake_token(claims_missing) + authmod._claims_from_id_token(fake_token, cfg, 'test-nonce') + self.assertIn('not verified', str(ctx.exception)) + + # False should raise + claims_false = claims_missing.copy() + claims_false['email_verified'] = False + with self.assertRaises(authmod.AuthError) as ctx: + fake_token = make_fake_token(claims_false) + authmod._claims_from_id_token(fake_token, cfg, 'test-nonce') + self.assertIn('not verified', str(ctx.exception)) + + # True should pass + claims_true = claims_missing.copy() + claims_true['email_verified'] = True + fake_token = make_fake_token(claims_true) + result = authmod._claims_from_id_token(fake_token, cfg, 'test-nonce') + self.assertEqual(result['email'], 'test@example.com') + + # 'true' string should pass + claims_str = claims_missing.copy() + claims_str['email_verified'] = 'true' + fake_token = make_fake_token(claims_str) + result = authmod._claims_from_id_token(fake_token, cfg, 'test-nonce') + self.assertEqual(result['email'], 'test@example.com') + + def test_minimal_yaml_read_roundtrip(self): + """_minimal_yaml_read can parse simple YAML config.""" + import tempfile + from pathlib import Path + + fd, path_str = tempfile.mkstemp(suffix='.yaml') + os.close(fd) + path = Path(path_str) + + try: + # Write a simple YAML config + yaml_content = """server: + host: 127.0.0.1 + port: 8420 +billing: + currency: usd + stripe_price_pro: price_123 +alerts: + enabled: false + low_balance_usd: 10.0 +""" + path.write_text(yaml_content, encoding='utf-8') + + # Read it back with the minimal reader + result = cfgmod._minimal_yaml_read(path) + + # Verify structure + self.assertIn('server', result) + self.assertIn('billing', result) + self.assertIn('alerts', result) + self.assertEqual(result['server']['host'], '127.0.0.1') + self.assertEqual(result['server']['port'], 8420) + self.assertEqual(result['billing']['currency'], 'usd') + self.assertEqual(result['billing']['stripe_price_pro'], 'price_123') + self.assertEqual(result['alerts']['enabled'], False) + self.assertEqual(result['alerts']['low_balance_usd'], 10.0) + finally: + path.unlink() + + if __name__ == "__main__": unittest.main()