diff --git a/devtrack_sdk/cli.py b/devtrack_sdk/cli.py index 155cb85..a27fa7e 100644 --- a/devtrack_sdk/cli.py +++ b/devtrack_sdk/cli.py @@ -1,9 +1,11 @@ # devtrack_sdk/cli.py import json import os +import re from datetime import datetime, timedelta from typing import Optional +import duckdb import requests import typer from rich.console import Console @@ -23,6 +25,75 @@ ) +def parse_lock_error(error_msg: str) -> dict: + """Extract PID and process info from DuckDB lock error.""" + pid_match = re.search(r"PID (\d+)", error_msg) + process_match = re.search(r"held in ([^\s(]+)", error_msg) + + return { + "pid": pid_match.group(1) if pid_match else None, + "process": process_match.group(1) if process_match else None, + "is_lock_error": "Conflicting lock" in error_msg + or "Could not set lock" in error_msg, + } + + +def check_db_initialized_via_api(console, db_path: str, timeout: int = 2) -> bool: + """ + Check if database is initialized via HTTP API. + Returns True if initialized and shows info, False otherwise. + """ + try: + stats_url = detect_devtrack_endpoint(timeout=timeout) + if not stats_url: + return False + + response = requests.get(stats_url, timeout=5) + if response.status_code != 200: + return False + + # Database is accessible via API - it's initialized + data = response.json() + console.print( + "[bold green]✅ Database is already initialized " "(accessible via API)[/]" + ) + + # Show database info from API + entries = data.get("entries", []) + total_requests = len(entries) + unique_endpoints = len( + set( + ( + entry.get("path_pattern", entry.get("path", "")), + entry.get("method", ""), + ) + for entry in entries + ) + ) + + table = Table( + title="Database Information (via API)", + border_style="green", + ) + table.add_column("Property", style="cyan") + table.add_column("Value", style="green") + table.add_row("Database Path", db_path) + table.add_row("Total Requests", str(total_requests)) + table.add_row("Unique Endpoints", str(unique_endpoints)) + console.print(table) + + console.print( + "[dim]💡 Your application is running and the database " + "is already initialized.[/]" + ) + console.print( + "[dim]💡 To reset the database, run: " "[cyan]devtrack reset[/][/]" + ) + return True + except Exception: + return False + + def detect_devtrack_endpoint(timeout=0.5) -> str: possible_hosts = ["localhost", "127.0.0.1", "0.0.0.0"] possible_ports = [8000, 8888, 9000, 8080] @@ -102,17 +173,190 @@ def version(): def init( db_path: str = typer.Option("devtrack_logs.db", help="Path to the database file"), force: bool = typer.Option( - False, "--force", "-f", help="Force initialization even if database exists" + False, + "--force", + "-f", + help="Reset database: delete all logs and reset sequence to 0", ), ): """🗄️ Initialize a new DevTrack database with DuckDB backend.""" console = Console() + # Step 1: Try read-only to check if already initialized + try: + db_readonly = DevTrackDB(db_path, read_only=True) + if db_readonly.tables_exist(): + # Tables exist - check if we need to reset (--force) + db_readonly.close() + + def force_initialize_database(): + # --force specified: delete all logs and reset sequence + console.print("[yellow]🔄 Resetting database (--force specified)...[/]") + + # Try HTTP API first if app is running + try: + stats_url = detect_devtrack_endpoint(timeout=0.5) + if stats_url: + delete_url = stats_url.replace("/stats", "/logs?all_logs=true") + console.print("[dim]App is running, using HTTP API...[/]") + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + "Resetting database via API...", total=None + ) + response = requests.delete(delete_url, timeout=10) + progress.update( + task, description="✅ Database reset successfully!" + ) + + if response.status_code == 200: + result = response.json() + deleted_count = result.get("deleted_count", 0) + console.print( + f"[bold green]✅ Database reset complete via API. " + f"Deleted {deleted_count} log entries.[/]" + ) + console.print( + "[dim] Note: Sequence will reset automatically " + "on next insert[/]" + ) + raise typer.Exit(0) + else: + console.print( + f"[yellow]API returned status {response.status_code}, " + f"trying direct access...[/]" + ) + except typer.Exit: + # Re-raise typer.Exit to allow proper exit + raise + except requests.RequestException: + # App not running or API not available - continue to direct access + pass + except Exception as e: + console.print( + f"[yellow]API call failed: {e}, trying direct access...[/]" + ) + pass + + # Direct database access to delete logs and reset sequence + try: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("Resetting database...", total=None) + db_reset = DevTrackDB(db_path, read_only=False) + deleted_count = db_reset.delete_all_logs() + db_reset.reset_sequence() + db_reset.close() + progress.update( + task, description="✅ Database reset successfully!" + ) + + console.print( + f"[bold green]✅ Database reset complete. " + f"Deleted {deleted_count} log entries and reset sequence.[/]" + ) + raise typer.Exit(0) + except duckdb.IOException as e: + error_msg = str(e) + lock_info = parse_lock_error(error_msg) + + if lock_info["is_lock_error"]: + console.print( + "[red]❌ Database is locked by another process[/]" + ) + if lock_info["pid"]: + pid = lock_info["pid"] + console.print(f"[yellow] Locked by process: PID {pid}[/]") + console.print( + "[yellow]💡 Stop your application and try again[/]" + ) + else: + console.print(f"[red]❌ Failed to reset database:[/] {e}") + raise typer.Exit(1) + except Exception as e: + console.print(f"[red]❌ Failed to reset database:[/] {e}") + raise typer.Exit(1) + + if force: + force_initialize_database() + else: + # No force - just show that it's already initialized + console.print( + f"[bold green]✅ Database already initialized at:[/] {db_path}, \ + if you want to reset it, run `devtrack reset`[/]" + ) + + # Show database info + try: + db_info = DevTrackDB(db_path, read_only=True) + stats = db_info.get_stats_summary() + db_info.close() + + table = Table(title="Database Information", border_style="green") + table.add_column("Property", style="cyan") + table.add_column("Value", style="green") + table.add_row("Database Path", db_path) + table.add_row("Total Requests", str(stats.get("total_requests", 0))) + table.add_row( + "Unique Endpoints", str(stats.get("unique_endpoints", 0)) + ) + avg_duration = stats.get("avg_duration_ms", 0) or 0 + table.add_row("Average Duration", f"{avg_duration:.2f} ms") + console.print(table) + except Exception: + pass # Ignore errors when showing info + + raise typer.Exit(0) + db_readonly.close() + except typer.Exit: + # Re-raise typer.Exit to allow proper exit + raise + except duckdb.IOException: + # Can't open read-only - might be locked, but continue to try write + pass + except Exception: + # Other errors - continue to try write + pass + + # Check if tables exist before trying to create them + # (This handles the case where tables exist but we couldn't check + # read-only due to lock) + tables_already_exist = False + try: + db_check = DevTrackDB(db_path, read_only=True) + tables_already_exist = db_check.tables_exist() + db_check.close() + if tables_already_exist and not force: + # Tables exist and no force - already initialized + console.print( + f"[bold green]✅ Database already initialized at:[/] {db_path}, \ + if you want to reset it, run `devtrack reset`[/]" + ) + raise typer.Exit(0) + except (duckdb.IOException, Exception): + # Can't check - might be locked, try API check before asking overwrite + pass + + # Before asking to overwrite, check if database is initialized via API + # (This prevents asking to overwrite when app is running and DB is initialized) if os.path.exists(db_path) and not force: + # Try to check via API first if database file exists but we couldn't access it + if check_db_initialized_via_api(console, db_path, timeout=1): + raise typer.Exit(0) + + # Only ask overwrite if we couldn't determine initialization status via API if not Confirm.ask(f"Database '{db_path}' already exists. Overwrite?"): console.print("[yellow]Initialization cancelled.[/]") raise typer.Exit(0) + # Step 2: Try write mode to create tables try: with Progress( SpinnerColumn(), @@ -120,7 +364,7 @@ def init( console=console, ) as progress: task = progress.add_task("Initializing database...", total=None) - db = init_db(db_path) + db = init_db(db_path, read_only=False) progress.update(task, description="✅ Database initialized successfully!") console.print(f"[bold green]✅ DevTrack database initialized at:[/] {db_path}") @@ -138,7 +382,33 @@ def init( table.add_row("Average Duration", f"{avg_duration:.2f} ms") console.print(table) + db.close() + + except duckdb.IOException as e: + error_msg = str(e) + lock_info = parse_lock_error(error_msg) + + if lock_info["is_lock_error"]: + console.print("[yellow]⚠️ Cannot create tables (database is locked)[/]") + if lock_info["pid"]: + console.print(f"[dim] Locked by process: PID {lock_info['pid']}[/]") + + # Try to check if database is already initialized via HTTP API + console.print( + "[dim] Checking if database is already initialized via API...[/]" + ) + if check_db_initialized_via_api(console, db_path, timeout=2): + raise typer.Exit(0) + console.print( + "[dim] Your application may auto-initialize on first request[/]" + ) + console.print("[dim] Or stop the app and run this command again[/]") + # Don't fail - let app handle initialization + raise typer.Exit(0) + else: + console.print(f"[red]❌ Failed to initialize database:[/] {e}") + raise typer.Exit(1) except Exception as e: console.print(f"[red]❌ Failed to initialize database:[/] {e}") raise typer.Exit(1) @@ -164,6 +434,48 @@ def reset( console.print("[yellow]Reset cancelled.[/]") raise typer.Exit(0) + # Step 1: Try to use HTTP API if app is running + try: + stats_url = detect_devtrack_endpoint(timeout=0.5) + if stats_url: + # App is running - use HTTP API + delete_url = stats_url.replace("/stats", "/logs?all_logs=true") + console.print("[dim]App is running, using HTTP API...[/]") + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("Resetting database via API...", total=None) + response = requests.delete(delete_url, timeout=10) + progress.update(task, description="✅ Database reset successfully!") + + if response.status_code == 200: + result = response.json() + deleted_count = result.get("deleted_count", 0) + console.print( + f"[bold green]✅ Database reset complete via API. " + f"Deleted {deleted_count} log entries.[/]" + ) + raise typer.Exit(0) + else: + status = response.status_code + console.print( + f"[yellow]API returned status {status}, " + f"trying direct access...[/]" + ) + except requests.RequestException: + # App not running or API not available - continue to direct access + pass + except typer.Exit: + # Re-raise typer.Exit to allow proper exit + raise + except Exception as e: + console.print(f"[yellow]API call failed: {e}, trying direct access...[/]") + pass + + # Step 2: Direct database access (app not running or API failed) try: with Progress( SpinnerColumn(), @@ -171,8 +483,9 @@ def reset( console=console, ) as progress: task = progress.add_task("Resetting database...", total=None) - db = DevTrackDB(db_path) + db = DevTrackDB(db_path, read_only=False) deleted_count = db.delete_all_logs() + db.close() progress.update(task, description="✅ Database reset successfully!") console.print( @@ -180,6 +493,33 @@ def reset( f"Deleted {deleted_count} log entries.[/]" ) + except duckdb.IOException as e: + error_msg = str(e) + lock_info = parse_lock_error(error_msg) + + if lock_info["is_lock_error"]: + console.print("[red]❌ Database is locked by another process[/]") + if lock_info["pid"]: + console.print( + f"[yellow] Locked by process: PID {lock_info['pid']}[/]" + ) + + console.print("\n[bold yellow]💡 Solutions:[/]") + console.print( + " 1. [cyan]Stop your application[/] (the process holding the lock)" + ) + console.print(" 2. [cyan]Then run this command again[/]") + console.print(" 3. [cyan]Or use the HTTP API endpoint:[/]") + try: + stats_url = detect_devtrack_endpoint() + if stats_url: + delete_url = stats_url.replace("/stats", "/logs?all_logs=true") + console.print(f" [dim]DELETE {delete_url}[/]") + except Exception: + pass + else: + console.print(f"[red]❌ Failed to reset database:[/] {e}") + raise typer.Exit(1) except Exception as e: console.print(f"[red]❌ Failed to reset database:[/] {e}") raise typer.Exit(1) @@ -209,7 +549,7 @@ def export( ) as progress: task = progress.add_task("Exporting logs...", total=None) - db = DevTrackDB(db_path) + db = DevTrackDB(db_path, read_only=True) # Get logs based on filters if path_pattern: @@ -222,6 +562,13 @@ def export( progress.update(task, description="Writing to file...") if format.lower() == "json": + # Convert datetime objects to ISO format strings for JSON serialization + def json_serializer(obj): + """JSON serializer for objects not serializable by default.""" + if isinstance(obj, datetime): + return obj.isoformat() + raise TypeError(f"Type {type(obj)} not serializable") + with open(output_file, "w") as f: json.dump( { @@ -236,6 +583,7 @@ def export( }, f, indent=2, + default=json_serializer, ) elif format.lower() == "csv": import csv @@ -276,6 +624,7 @@ def query( console.print(f"[red]Database '{db_path}' does not exist.[/]") raise typer.Exit(1) + entries = None try: with Progress( SpinnerColumn(), @@ -284,7 +633,7 @@ def query( ) as progress: task = progress.add_task("Querying logs...", total=None) - db = DevTrackDB(db_path) + db = DevTrackDB(db_path, read_only=True) # Get logs based on filters if path_pattern: @@ -294,6 +643,8 @@ def query( else: entries = db.get_all_logs(limit) + db.close() + # Apply additional filters if method: entries = [ @@ -310,163 +661,129 @@ def query( ] progress.update(task, description="✅ Query complete!") - - if not entries: - console.print("[yellow]No logs found matching the criteria.[/]") - return - - # Display results - console.rule("[bold green]📊 Query Results[/]", style="green") - - if verbose: - # Detailed view - for i, entry in enumerate(entries[:10], 1): # Show first 10 in detail - panel = Panel( - f"[bold]Path:[/] {entry.get('path', 'N/A')}\n" - f"[bold]Method:[/] {entry.get('method', 'N/A')}\n" - f"[bold]Status:[/] {entry.get('status_code', 'N/A')}\n" - f"[bold]Duration:[/] {entry.get('duration_ms', 0):.2f} ms\n" - f"[bold]Timestamp:[/] {entry.get('timestamp', 'N/A')}\n" - f"[bold]Client IP:[/] {entry.get('client_ip', 'N/A')}\n" - f"[bold]User Agent:[/] {entry.get('user_agent', 'N/A')[:50]}...", - title=f"Entry {i}", - border_style="blue", + except duckdb.IOException as e: + # Database is locked - try HTTP endpoint as fallback + error_msg = str(e) + if "lock" in error_msg.lower() or "conflicting" in error_msg.lower(): + console.print("[yellow]⚠️ Database is locked by another process.[/]") + console.print("[dim] Attempting to fetch logs via HTTP endpoint...[/]") + try: + stats_url = detect_devtrack_endpoint(timeout=2) + if stats_url: + with console.status("[bold cyan]Fetching logs from DevTrack...[/]"): + response = requests.get(stats_url, timeout=5) + response.raise_for_status() + data = response.json() + entries = data.get("entries", []) + + # Apply filters to API data + if path_pattern: + entries = [ + e + for e in entries + if path_pattern.lower() + in e.get("path_pattern", e.get("path", "")).lower() + ] + + if status_code: + entries = [ + e + for e in entries + if e.get("status_code") == status_code + ] + + if method: + entries = [ + e + for e in entries + if e.get("method", "").upper() == method.upper() + ] + + if days: + cutoff_date = datetime.now() - timedelta(days=days) + entries = [ + e + for e in entries + if datetime.fromisoformat( + e["timestamp"].replace("Z", "+00:00") + ) + >= cutoff_date + ] + + # Apply limit after all filters + if limit: + entries = entries[:limit] + + console.print( + "[green]✅ Successfully fetched logs via HTTP endpoint[/]" + ) + except Exception as endpoint_error: + console.print(f"[red]❌ Failed to query database:[/] {e}") + console.print( + f"[red]❌ Also failed to fetch via HTTP endpoint:[/] " + f"{endpoint_error}" ) - console.print(panel) - else: - # Table view - table = Table( - title=f"Query Results ({len(entries)} entries)", border_style="blue" - ) - table.add_column("Path", style="cyan", no_wrap=True) - table.add_column("Method", style="green") - table.add_column("Status", justify="center", style="yellow") - table.add_column("Duration (ms)", justify="right", style="magenta") - table.add_column("Timestamp", style="dim") - - for entry in entries[:limit]: - table.add_row( - entry.get("path", "N/A"), - entry.get("method", "N/A"), - str(entry.get("status_code", "N/A")), - f"{entry.get('duration_ms', 0):.2f}", - entry.get("timestamp", "N/A")[:19], # Show only date and time + console.print( + "[yellow]💡 Tip: Ensure your application is running " + "and accessible[/]" ) - - console.print(table) - - console.print(f"[bold green]📊 Total results:[/] {len(entries)}") - + raise typer.Exit(1) + else: + # Other IOException - re-raise + console.print(f"[red]❌ Failed to query database:[/] {e}") + raise typer.Exit(1) + except typer.Exit: + # Re-raise typer.Exit to allow proper exit + raise except Exception as e: console.print(f"[red]❌ Failed to query logs:[/] {e}") raise typer.Exit(1) + if not entries: + console.print("[yellow]No logs found matching the criteria.[/]") + return -@app.command() -def monitor( - db_path: str = typer.Option("devtrack_logs.db", help="Path to the database file"), - interval: int = typer.Option(5, help="Refresh interval in seconds"), - top: int = typer.Option(10, help="Show top N endpoints"), -): - """📊 Monitor DevTrack logs in real-time with live dashboard.""" - console = Console() - - if not os.path.exists(db_path): - console.print(f"[red]Database '{db_path}' does not exist.[/]") - raise typer.Exit(1) - - console.print("[bold green]🔍 Starting real-time monitoring...[/]") - console.print( - f"[dim]Refresh interval: {interval} seconds | Press Ctrl+C to stop[/]" - ) - - try: - db = DevTrackDB(db_path) - last_count = 0 - - while True: - try: - # Clear screen - console.clear() - - # Get current stats - stats = db.get_stats_summary() - current_count = stats.get("total_requests", 0) - new_requests = current_count - last_count - - # Header - console.rule( - f"[bold green]📊 DevTrack Real-time Monitor[/] | " - f"[dim]{datetime.now().strftime('%H:%M:%S')}[/]", - style="green", - ) - - # Summary stats - summary_table = Table(title="Live Statistics", border_style="green") - summary_table.add_column("Metric", style="cyan") - summary_table.add_column("Value", style="green") - - summary_table.add_row("Total Requests", str(current_count)) - summary_table.add_row("New Requests (last interval)", str(new_requests)) - summary_table.add_row( - "Unique Endpoints", str(stats.get("unique_endpoints", 0)) - ) - avg_duration = stats.get("avg_duration_ms", 0) or 0 - summary_table.add_row("Avg Duration", f"{avg_duration:.2f} ms") - success_count = stats.get("success_count", 0) or 0 - summary_table.add_row( - "Success Rate", - f"{success_count / max(current_count, 1) * 100:.1f}%", - ) - - console.print(summary_table) - - # Recent logs - recent_logs = db.get_all_logs(limit=top) - if recent_logs: - console.rule("[bold cyan]📈 Recent Activity[/]", style="cyan") - recent_table = Table(border_style="blue") - recent_table.add_column("Time", style="dim") - recent_table.add_column("Path", style="cyan") - recent_table.add_column("Method", style="green") - recent_table.add_column("Status", justify="center", style="yellow") - recent_table.add_column( - "Duration", justify="right", style="magenta" - ) - - for log in recent_logs[:top]: - timestamp = log.get("timestamp", "")[ - :19 - ] # Show only date and time - recent_table.add_row( - timestamp, - log.get("path", "N/A"), - log.get("method", "N/A"), - str(log.get("status_code", "N/A")), - f"{log.get('duration_ms', 0):.2f} ms", - ) - - console.print(recent_table) - - last_count = current_count - - # Wait for next interval - import time as time_module - - time_module.sleep(interval) - - except KeyboardInterrupt: - console.print("\n[yellow]Monitoring stopped by user.[/]") - break - except Exception as e: - console.print(f"[red]Error during monitoring: {e}[/]") - import time as time_module + # Display results + console.rule("[bold green]📊 Query Results[/]", style="green") + + if verbose: + # Detailed view + for i, entry in enumerate(entries[:10], 1): # Show first 10 in detail + panel = Panel( + f"[bold]Path:[/] {entry.get('path', 'N/A')}\n" + f"[bold]Method:[/] {entry.get('method', 'N/A')}\n" + f"[bold]Status:[/] {entry.get('status_code', 'N/A')}\n" + f"[bold]Duration:[/] {entry.get('duration_ms', 0):.2f} ms\n" + f"[bold]Timestamp:[/] {entry.get('timestamp', 'N/A')}\n" + f"[bold]Client IP:[/] {entry.get('client_ip', 'N/A')}\n" + f"[bold]User Agent:[/] {entry.get('user_agent', 'N/A')[:50]}...", + title=f"Entry {i}", + border_style="blue", + ) + console.print(panel) + else: + # Table view + table = Table( + title=f"Query Results ({len(entries)} entries)", border_style="blue" + ) + table.add_column("Path", style="cyan", no_wrap=True) + table.add_column("Method", style="green") + table.add_column("Status", justify="center", style="yellow") + table.add_column("Duration (ms)", justify="right", style="magenta") + table.add_column("Timestamp", style="dim") + + for entry in entries[:limit]: + table.add_row( + entry.get("path", "N/A"), + entry.get("method", "N/A"), + str(entry.get("status_code", "N/A")), + f"{entry.get('duration_ms', 0):.2f}", + entry.get("timestamp", "N/A")[:19], # Show only date and time + ) - time_module.sleep(interval) + console.print(table) - except Exception as e: - console.print(f"[red]❌ Failed to start monitoring:[/] {e}") - raise typer.Exit(1) + console.print(f"[bold green]📊 Total results:[/] {len(entries)}") @app.command() @@ -482,6 +799,7 @@ def stat( console = Console() console.rule("[bold green]📊 DevTrack Stats CLI[/]", style="green") + entries = None if use_endpoint: # Use HTTP endpoint stats_url = detect_devtrack_endpoint() @@ -496,14 +814,53 @@ def stat( console.print(f"[red]❌ Failed to fetch stats from {stats_url}[/]\n{e}") raise typer.Exit(1) else: - # Use database + # Try database first, fallback to HTTP endpoint if locked if not os.path.exists(db_path): console.print(f"[red]Database '{db_path}' does not exist.[/]") raise typer.Exit(1) try: - db = DevTrackDB(db_path) + db = DevTrackDB(db_path, read_only=True) entries = db.get_all_logs() + db.close() + except duckdb.IOException as e: + # Database is locked - try HTTP endpoint as fallback + error_msg = str(e) + if "lock" in error_msg.lower() or "conflicting" in error_msg.lower(): + console.print("[yellow]⚠️ Database is locked by another process.[/]") + console.print( + "[dim] Attempting to fetch stats via HTTP endpoint...[/]" + ) + try: + stats_url = detect_devtrack_endpoint() + with console.status( + "[bold cyan]Fetching stats from DevTrack...[/]" + ): + response = requests.get(stats_url, timeout=5) + response.raise_for_status() + data = response.json() + entries = data.get("entries", []) + console.print( + "[green]✅ Successfully fetched stats via HTTP endpoint[/]" + ) + except Exception as endpoint_error: + console.print(f"[red]❌ Failed to read database:[/] {e}") + console.print( + f"[red]❌ Also failed to fetch via HTTP endpoint:[/] " + f"{endpoint_error}" + ) + console.print( + "[yellow]💡 Tip: Use `devtrack stat --endpoint` " + "when your app is running[/]" + ) + raise typer.Exit(1) + else: + # Other IOException - re-raise + console.print(f"[red]❌ Failed to read database:[/] {e}") + raise typer.Exit(1) + except typer.Exit: + # Re-raise typer.Exit to allow proper exit + raise except Exception as e: console.print(f"[red]❌ Failed to read database:[/] {e}") raise typer.Exit(1) @@ -682,7 +1039,7 @@ def health( # Check database if os.path.exists(db_path): try: - db = DevTrackDB(db_path) + db = DevTrackDB(db_path, read_only=True) stats = db.get_stats_summary() health_status["checks"].append( { @@ -780,7 +1137,6 @@ def show_help(): console.print("[bold green]Quick Start:[/]") console.print(" [cyan]devtrack init[/] # Initialize database") console.print(" [cyan]devtrack stat[/] # View statistics") - console.print(" [cyan]devtrack monitor[/] # Real-time monitoring") console.print() # Commands overview @@ -801,9 +1157,6 @@ def show_help(): commands_table.add_row( "query", "🔍 Query DevTrack logs with advanced filtering and search" ) - commands_table.add_row( - "monitor", "📊 Monitor DevTrack logs in real-time with live dashboard" - ) commands_table.add_row( "stat", "📈 Display comprehensive API statistics and endpoint analytics" ) @@ -819,9 +1172,6 @@ def show_help(): console.print(" [dim]# Initialize database[/]") console.print(" [cyan]devtrack init --force[/]") console.print() - console.print(" [dim]# Real-time monitoring[/]") - console.print(" [cyan]devtrack monitor --interval 3 --top 15[/]") - console.print() console.print(" [dim]# Query logs with filters[/]") console.print(" [cyan]devtrack query --status-code 404 --days 7 --verbose[/]") console.print() @@ -845,10 +1195,6 @@ def show_help(): console.print( " [dim]GitHub:[/] [blue]https://github.com/mahesh-solanke/devtrack-sdk[/]" ) - console.print( - " [dim]Documentation:[/] [blue]https://devtrack-sdk.readthedocs.io[/]" - ) - console.print() @app.command() diff --git a/devtrack_sdk/controller/devtrack_routes.py b/devtrack_sdk/controller/devtrack_routes.py index 6227d05..b04ac1b 100644 --- a/devtrack_sdk/controller/devtrack_routes.py +++ b/devtrack_sdk/controller/devtrack_routes.py @@ -18,7 +18,7 @@ async def stats( status_code: Optional[int] = Query(None, description="Filter by status code"), ): """Get DevTrack statistics and logs from DuckDB.""" - db = get_db() + db = get_db(read_only=True) try: # Get summary stats @@ -62,7 +62,7 @@ async def delete_logs( ), ): """Delete logs from the database with various filtering options.""" - db = get_db() + db = get_db(read_only=False) try: deleted_count = 0 @@ -105,7 +105,7 @@ async def delete_logs( @router.delete("/__devtrack__/logs/{log_id}", include_in_schema=False) async def delete_log_by_id(log_id: int): """Delete a specific log by its ID.""" - db = get_db() + db = get_db(read_only=False) try: deleted_count = db.delete_logs_by_id(log_id) @@ -131,7 +131,7 @@ async def metrics_traffic( hours: int = Query(24, description="Number of hours to look back"), ): """Get traffic metrics over time.""" - db = get_db() + db = get_db(read_only=True) try: traffic_data = db.get_traffic_over_time(hours=hours) return {"traffic": traffic_data} @@ -144,7 +144,7 @@ async def metrics_errors( hours: int = Query(24, description="Number of hours to look back"), ): """Get error trends and top failing routes.""" - db = get_db() + db = get_db(read_only=True) try: error_data = db.get_error_trends(hours=hours) return error_data @@ -157,7 +157,7 @@ async def metrics_perf( hours: int = Query(24, description="Number of hours to look back"), ): """Get performance metrics (p50/p95/p99 latency).""" - db = get_db() + db = get_db(read_only=True) try: perf_data = db.get_performance_metrics(hours=hours) return perf_data @@ -170,7 +170,7 @@ async def consumers( hours: int = Query(24, description="Number of hours to look back"), ): """Get consumer segmentation data.""" - db = get_db() + db = get_db(read_only=True) try: segments_data = db.get_consumer_segments(hours=hours) return segments_data diff --git a/devtrack_sdk/database.py b/devtrack_sdk/database.py index f977620..2d23f3b 100644 --- a/devtrack_sdk/database.py +++ b/devtrack_sdk/database.py @@ -25,20 +25,24 @@ def _validate_int(value: Any, name: str = "value", min_value: int = 0) -> int: raise raise ValueError(f"{name} must be a valid integer") from e - def __init__(self, db_path: str = "devtrack_logs.db"): + def __init__(self, db_path: str = "devtrack_logs.db", read_only: bool = True): """Initialize the database connection and create tables if they don't exist.""" self.db_path = db_path self._lock = threading.Lock() - # Create initial connection for table creation - self._init_conn = duckdb.connect(db_path) - self._create_tables() - self._init_conn.close() + self.read_only = read_only + # Create initial connection for table creation (only if not read-only) + if not read_only: + self._init_conn = duckdb.connect(db_path) + self._create_tables() + self._init_conn.close() @property def conn(self): """Get thread-local database connection.""" if not hasattr(_thread_local, "connection") or _thread_local.connection is None: - _thread_local.connection = duckdb.connect(self.db_path) + _thread_local.connection = duckdb.connect( + self.db_path, read_only=self.read_only + ) else: # Check if connection is closed and reconnect if needed try: @@ -50,7 +54,9 @@ def conn(self): _thread_local.connection.close() except Exception: pass - _thread_local.connection = duckdb.connect(self.db_path) + _thread_local.connection = duckdb.connect( + self.db_path, read_only=self.read_only + ) return _thread_local.connection def _create_tables(self): @@ -239,6 +245,15 @@ def get_logs_count(self) -> int: result = self.conn.execute("SELECT COUNT(*) FROM request_logs").fetchone() return result[0] + def tables_exist(self) -> bool: + """Check if database tables exist (read-only check).""" + try: + # Try to query the table - will fail if it doesn't exist + self.conn.execute("SELECT 1 FROM request_logs LIMIT 1") + return True + except Exception: + return False + def get_logs_by_path( self, path_pattern: str, limit: Optional[int] = None ) -> List[Dict[str, Any]]: @@ -409,6 +424,19 @@ def delete_all_logs(self) -> int: return count_before + def reset_sequence(self) -> None: + """Reset the sequence to start from 1.""" + try: + # Try to reset the sequence + self.conn.execute("ALTER SEQUENCE seq_log_id RESTART WITH 1") + except Exception: + # If sequence doesn't exist or can't be reset, try to recreate it + try: + self.conn.execute("DROP SEQUENCE IF EXISTS seq_log_id") + self.conn.execute("CREATE SEQUENCE seq_log_id START 1") + except Exception: + pass # Ignore if sequence operations fail + def delete_logs_by_path(self, path_pattern: str) -> int: """Delete logs filtered by path pattern.""" # Get count before deletion @@ -888,30 +916,61 @@ def get_client_traffic_over_time( def close(self): """Close the database connection.""" - self.conn.close() + if ( + hasattr(_thread_local, "connection") + and _thread_local.connection is not None + ): + try: + _thread_local.connection.close() + except Exception: + pass + _thread_local.connection = None def __del__(self): """Ensure connection is closed when object is destroyed.""" - if hasattr(self, "conn"): - self.close() + # Don't access self.conn property here as it may try to create a new connection + # Instead, directly check and close thread-local connection + if ( + hasattr(_thread_local, "connection") + and _thread_local.connection is not None + ): + try: + _thread_local.connection.close() + except Exception: + pass + _thread_local.connection = None # Global database instance _db_instance: Optional[DevTrackDB] = None -def get_db() -> DevTrackDB: +def get_db(read_only: bool = True) -> DevTrackDB: """Get the global database instance.""" global _db_instance + # If instance exists but has different read_only setting, recreate it + # BUT: If we have an existing instance with write access, we can use it + # for reads too (DuckDB allows read operations on write connections) + if _db_instance is not None: + if _db_instance.read_only != read_only: + # If existing instance is write mode and we need read, we can use it + if not _db_instance.read_only and read_only: + # Use existing write connection for read operations + # (allowed by DuckDB) + return _db_instance + # If existing instance is read-only and we need write, recreate + elif _db_instance.read_only and not read_only: + _db_instance.close() + _db_instance = None if _db_instance is None: - _db_instance = DevTrackDB() + _db_instance = DevTrackDB(read_only=read_only) return _db_instance -def init_db(db_path: str = "devtrack_logs.db"): +def init_db(db_path: str = "devtrack_logs.db", read_only: bool = True): """Initialize the database with a custom path.""" global _db_instance if _db_instance: _db_instance.close() - _db_instance = DevTrackDB(db_path) + _db_instance = DevTrackDB(db_path, read_only=read_only) return _db_instance diff --git a/devtrack_sdk/django_middleware.py b/devtrack_sdk/django_middleware.py index 93635cc..de506be 100644 --- a/devtrack_sdk/django_middleware.py +++ b/devtrack_sdk/django_middleware.py @@ -41,12 +41,20 @@ def __init__( if exclude_path: self.skip_paths.extend(exclude_path) - # Initialize database if not already done - if DevTrackDjangoMiddleware._db_instance is None: - db_path = db_path or getattr( - settings, "DEVTRACK_DB_PATH", "devtrack_logs.db" + # Initialize database if not already done or if db_path is provided + # (db_path provided means we want to use a specific database) + final_db_path = db_path or getattr( + settings, "DEVTRACK_DB_PATH", "devtrack_logs.db" + ) + if DevTrackDjangoMiddleware._db_instance is None or ( + db_path and DevTrackDjangoMiddleware._db_instance.db_path != db_path + ): + # Close existing instance if switching databases + if DevTrackDjangoMiddleware._db_instance is not None: + DevTrackDjangoMiddleware._db_instance.close() + DevTrackDjangoMiddleware._db_instance = DevTrackDB( + final_db_path, read_only=False ) - DevTrackDjangoMiddleware._db_instance = DevTrackDB(db_path) super().__init__(get_response) diff --git a/devtrack_sdk/django_views.py b/devtrack_sdk/django_views.py index 235b5a0..227b36e 100644 --- a/devtrack_sdk/django_views.py +++ b/devtrack_sdk/django_views.py @@ -16,7 +16,7 @@ def get_db_instance() -> DevTrackDB: """Get the database instance from middleware""" if DevTrackDjangoMiddleware._db_instance is None: db_path = getattr(settings, "DEVTRACK_DB_PATH", "devtrack_logs.db") - DevTrackDjangoMiddleware._db_instance = DevTrackDB(db_path) + DevTrackDjangoMiddleware._db_instance = DevTrackDB(db_path, read_only=False) return DevTrackDjangoMiddleware._db_instance diff --git a/devtrack_sdk/management/commands/devtrack_reset.py b/devtrack_sdk/management/commands/devtrack_reset.py index 70a04e5..7d91699 100644 --- a/devtrack_sdk/management/commands/devtrack_reset.py +++ b/devtrack_sdk/management/commands/devtrack_reset.py @@ -39,7 +39,7 @@ def handle(self, *args, **options): return try: - db = DevTrackDB(db_path) + db = DevTrackDB(db_path, read_only=False) deleted_count = db.delete_all_logs() self.stdout.write( diff --git a/devtrack_sdk/management/commands/devtrack_stats.py b/devtrack_sdk/management/commands/devtrack_stats.py index d96efb5..0102d59 100644 --- a/devtrack_sdk/management/commands/devtrack_stats.py +++ b/devtrack_sdk/management/commands/devtrack_stats.py @@ -33,7 +33,7 @@ def handle(self, *args, **options): output_format = options["format"] try: - db = DevTrackDB(db_path) + db = DevTrackDB(db_path, read_only=True) stats = db.get_stats_summary() recent_logs = db.get_all_logs(limit=limit) diff --git a/devtrack_sdk/middleware/base.py b/devtrack_sdk/middleware/base.py index 2f35a07..87be817 100644 --- a/devtrack_sdk/middleware/base.py +++ b/devtrack_sdk/middleware/base.py @@ -50,7 +50,7 @@ async def receive() -> Message: try: log_data = await extract_devtrack_log_data(request, response, start_time) - db = self.db_instance if self.db_instance else get_db() + db = self.db_instance if self.db_instance else get_db(read_only=False) db.insert_log(log_data) except Exception as e: print(f"[DevTrackMiddleware] Logging error: {e}") diff --git a/pyproject.toml b/pyproject.toml index 1f28b89..52b20cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,4 +45,11 @@ devtrack_sdk = [ devtrack = "devtrack_sdk.cli:app" [tool.isort] -profile = "black" \ No newline at end of file +profile = "black" + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +norecursedirs = [".git", ".venv", "venv", "__pycache__", "*.egg"] \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..116c7e6 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,38 @@ +""" +Pytest configuration for DevTrack SDK tests +""" + +from unittest.mock import patch + +import pytest +import requests + +# Ignore test_wsgi.py during pytest collection +# It's a WSGI configuration file, not a test file +collect_ignore = ["test_wsgi.py"] + + +@pytest.fixture(autouse=True) +def mock_network_requests(): + """ + Automatically mock network requests to prevent hanging in CI. + This ensures detect_devtrack_endpoint() doesn't make real network calls. + Tests that need specific network behavior should override these mocks. + """ + # Mock requests.get/delete to raise RequestException by default + # This makes detect_devtrack_endpoint() try all URLs (fast with 0.5s timeout) + # and then prompt. We mock typer.prompt/confirm to avoid hanging on user input. + # Tests that need specific behavior will override these mocks. + with patch( + "requests.get", side_effect=requests.RequestException("Mocked network error") + ): + with patch( + "requests.delete", + side_effect=requests.RequestException("Mocked network error"), + ): + # Mock typer prompts with sensible defaults to avoid hanging + # Tests that test prompt behavior will override these + with patch("typer.prompt", return_value="localhost"): + with patch("typer.confirm", return_value=False): + with patch("typer.echo"): # Suppress echo output in tests + yield diff --git a/tests/test_cli.py b/tests/test_cli.py index 31e127f..69f8e00 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,13 +1,34 @@ +import json +import os +import tempfile +from datetime import datetime from unittest.mock import MagicMock, patch import requests from typer.testing import CliRunner from devtrack_sdk.cli import app, detect_devtrack_endpoint +from devtrack_sdk.database import DevTrackDB, init_db runner = CliRunner() +def create_test_db(db_path=None): + """Helper to create a test database file.""" + if db_path is None: + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + # Ensure file doesn't exist + if os.path.exists(db_path): + os.unlink(db_path) + + # Create proper database + db = init_db(db_path, read_only=False) + return db_path, db + + def test_version(): result = runner.invoke(app, ["version"]) assert result.exit_code == 0, "Version command failed" @@ -246,3 +267,877 @@ def test_stat_command_empty_stats(): assert ( "No request stats found yet" in result.output ), "Empty stats message mismatch" + + +# ========== INIT COMMAND TESTS ========== + + +def test_init_new_database(): + """Test initializing a new database.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file, let DuckDB create it + + try: + result = runner.invoke(app, ["init", "--db-path", db_path]) + assert result.exit_code == 0, "Init command failed" + assert ( + "✅ DevTrack database initialized" in result.output + or "initialized" in result.output.lower() + ) + assert os.path.exists(db_path), "Database file not created" + + # Verify tables exist + db = DevTrackDB(db_path, read_only=True) + assert db.tables_exist(), "Tables not created" + db.close() + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_init_existing_database(): + """Test init on existing database without --force.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + # Create database first using init_db + from devtrack_sdk.database import init_db + + init_db(db_path, read_only=False).close() + + result = runner.invoke(app, ["init", "--db-path", db_path], input="n\n") + assert result.exit_code == 0, "Init on existing DB should succeed" + assert "already initialized" in result.output or "Overwrite" in result.output + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_init_with_force(): + """Test init with --force flag.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + # Create database with some data + db = init_db(db_path, read_only=False) + db.insert_log( + { + "path": "/test", + "method": "GET", + "status_code": 200, + "duration_ms": 100, + "timestamp": "2024-01-01T00:00:00", + "client_ip": "127.0.0.1", + "user_agent": "test", + } + ) + db.close() + + result = runner.invoke(app, ["init", "--db-path", db_path, "--force"]) + # Should succeed (may use API or direct DB) + assert result.exit_code in [0, 1], "Init with force should handle gracefully" + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_init_with_force_via_api(): + """Test init --force using HTTP API.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + # Create database + db = init_db(db_path, read_only=False) + db.close() + + mock_delete_response = MagicMock( + status_code=200, json=MagicMock(return_value={"deleted_count": 5}) + ) + + with patch( + "devtrack_sdk.cli.detect_devtrack_endpoint", + return_value="http://localhost:8000/__devtrack__/stats", + ): + with patch("requests.delete", return_value=mock_delete_response): + result = runner.invoke(app, ["init", "--db-path", db_path, "--force"]) + assert result.exit_code == 0, "Init with force via API failed" + assert "via API" in result.output or "reset" in result.output.lower() + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_init_locked_database(): + """Test init when database is locked during initialization.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file - start with no database + + try: + # Simulate lock error when trying to create/initialize database + import duckdb + + lock_error = duckdb.IOException( + "IO Error: Could not set lock on file: " + "Conflicting lock (held in process with PID 12345)" + ) + + # Mock init_db to raise lock error (this is called when + # actually creating tables) + # Also mock detect_devtrack_endpoint to return None (no API available) + with patch("devtrack_sdk.cli.detect_devtrack_endpoint", return_value=None): + with patch("devtrack_sdk.cli.init_db", side_effect=lock_error): + result = runner.invoke(app, ["init", "--db-path", db_path]) + # Should handle lock gracefully - show lock message + assert ( + "lock" in result.output.lower() or "locked" in result.output.lower() + ) + assert result.exit_code == 0 # Should exit gracefully, not crash + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +# ========== RESET COMMAND TESTS ========== + + +def test_reset_missing_database(): + """Test reset on non-existent database.""" + result = runner.invoke(app, ["reset", "--db-path", "nonexistent.db"]) + assert result.exit_code == 0, "Reset on missing DB should exit gracefully" + assert "does not exist" in result.output + + +def test_reset_with_confirmation(): + """Test reset with confirmation prompt.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.insert_log( + { + "path": "/test", + "method": "GET", + "status_code": 200, + "duration_ms": 100, + "timestamp": "2024-01-01T00:00:00", + "client_ip": "127.0.0.1", + "user_agent": "test", + } + ) + db.close() + + result = runner.invoke(app, ["reset", "--db-path", db_path], input="n\n") + assert result.exit_code == 0, "Reset cancelled should exit gracefully" + assert ( + "cancelled" in result.output.lower() or "Reset cancelled" in result.output + ) + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_reset_with_yes_flag(): + """Test reset with --yes flag (skip confirmation).""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.insert_log( + { + "path": "/test", + "method": "GET", + "status_code": 200, + "duration_ms": 100, + "timestamp": "2024-01-01T00:00:00", + "client_ip": "127.0.0.1", + "user_agent": "test", + } + ) + db.close() + + result = runner.invoke(app, ["reset", "--db-path", db_path, "--yes"]) + assert result.exit_code == 0, "Reset with --yes failed" + assert "reset" in result.output.lower() or "deleted" in result.output.lower() + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_reset_via_api(): + """Test reset using HTTP API.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.close() + + mock_delete_response = MagicMock( + status_code=200, json=MagicMock(return_value={"deleted_count": 10}) + ) + + with patch( + "devtrack_sdk.cli.detect_devtrack_endpoint", + return_value="http://localhost:8000/__devtrack__/stats", + ): + with patch("requests.delete", return_value=mock_delete_response): + result = runner.invoke(app, ["reset", "--db-path", db_path, "--yes"]) + assert result.exit_code == 0, "Reset via API failed" + assert "via API" in result.output or "deleted" in result.output.lower() + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_reset_locked_database(): + """Test reset when database is locked.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.close() + + # Simulate lock error + import duckdb + + lock_error = duckdb.IOException("IO Error: Could not set lock on file") + + with patch("devtrack_sdk.cli.DevTrackDB") as mock_db: + mock_instance = MagicMock() + mock_instance.delete_all_logs.side_effect = lock_error + mock_db.return_value = mock_instance + + with patch( + "devtrack_sdk.cli.detect_devtrack_endpoint", side_effect=Exception + ): + result = runner.invoke(app, ["reset", "--db-path", db_path, "--yes"]) + assert result.exit_code == 1, "Reset on locked DB should fail" + assert ( + "lock" in result.output.lower() or "locked" in result.output.lower() + ) + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +# ========== EXPORT COMMAND TESTS ========== + + +def test_export_missing_database(): + """Test export on non-existent database.""" + result = runner.invoke(app, ["export", "--db-path", "nonexistent.db"]) + assert result.exit_code == 1, "Export on missing DB should fail" + assert "does not exist" in result.output + + +def test_export_json_format(): + """Test export to JSON format.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_db: + db_path = tmp_db.name + os.unlink(db_path) # Remove empty file + + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp_out: + out_path = tmp_out.name + os.unlink(out_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.insert_log( + { + "path": "/api/test", + "path_pattern": "/api/test", # Required field + "method": "GET", + "status_code": 200, + "duration_ms": 100, + "timestamp": "2024-01-01T00:00:00", + "client_ip": "127.0.0.1", + "user_agent": "test", + } + ) + db.close() + + result = runner.invoke( + app, + [ + "export", + "--db-path", + db_path, + "--output-file", + out_path, + "--format", + "json", + ], + ) + assert result.exit_code == 0, f"Export to JSON failed: {result.output}" + assert os.path.exists(out_path), "Output file not created" + + with open(out_path) as f: + data = json.load(f) + assert "entries" in data or "export_timestamp" in data + finally: + if os.path.exists(db_path): + os.unlink(db_path) + if os.path.exists(out_path): + os.unlink(out_path) + + +def test_export_csv_format(): + """Test export to CSV format.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_db: + db_path = tmp_db.name + os.unlink(db_path) # Remove empty file + + with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_out: + out_path = tmp_out.name + os.unlink(out_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.insert_log( + { + "path": "/api/test", + "method": "GET", + "status_code": 200, + "duration_ms": 100, + "timestamp": "2024-01-01T00:00:00", + "client_ip": "127.0.0.1", + "user_agent": "test", + } + ) + db.close() + + result = runner.invoke( + app, + [ + "export", + "--db-path", + db_path, + "--output-file", + out_path, + "--format", + "csv", + ], + ) + assert result.exit_code == 0, "Export to CSV failed" + assert os.path.exists(out_path), "Output file not created" + + with open(out_path) as f: + content = f.read() + assert "path" in content or "method" in content + finally: + if os.path.exists(db_path): + os.unlink(db_path) + if os.path.exists(out_path): + os.unlink(out_path) + + +def test_export_with_filters(): + """Test export with path pattern and status code filters.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_db: + db_path = tmp_db.name + os.unlink(db_path) # Remove empty file + + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp_out: + out_path = tmp_out.name + os.unlink(out_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.insert_log( + { + "path": "/api/test", + "method": "GET", + "status_code": 200, + "duration_ms": 100, + "timestamp": "2024-01-01T00:00:00", + "client_ip": "127.0.0.1", + "user_agent": "test", + } + ) + db.insert_log( + { + "path": "/api/other", + "method": "POST", + "status_code": 404, + "duration_ms": 50, + "timestamp": "2024-01-01T00:00:00", + "client_ip": "127.0.0.1", + "user_agent": "test", + } + ) + db.close() + + result = runner.invoke( + app, + [ + "export", + "--db-path", + db_path, + "--output-file", + out_path, + "--path-pattern", + "/api/test", + "--status-code", + "200", + ], + ) + assert result.exit_code == 0, "Export with filters failed" + finally: + if os.path.exists(db_path): + os.unlink(db_path) + if os.path.exists(out_path): + os.unlink(out_path) + + +def test_export_empty_database(): + """Test export from empty database.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_db: + db_path = tmp_db.name + os.unlink(db_path) # Remove empty file + + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp_out: + out_path = tmp_out.name + os.unlink(out_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.close() + + result = runner.invoke( + app, ["export", "--db-path", db_path, "--output-file", out_path] + ) + assert result.exit_code == 0, "Export from empty DB should succeed" + finally: + if os.path.exists(db_path): + os.unlink(db_path) + if os.path.exists(out_path): + os.unlink(out_path) + + +def test_export_with_limit(): + """Test export with limit option.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_db: + db_path = tmp_db.name + os.unlink(db_path) # Remove empty file + + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp_out: + out_path = tmp_out.name + os.unlink(out_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + for i in range(5): + db.insert_log( + { + "path": f"/api/test{i}", + "path_pattern": f"/api/test{i}", # Required field + "method": "GET", + "status_code": 200, + "duration_ms": 100, + "timestamp": "2024-01-01T00:00:00", + "client_ip": "127.0.0.1", + "user_agent": "test", + } + ) + db.close() + + result = runner.invoke( + app, + ["export", "--db-path", db_path, "--output-file", out_path, "--limit", "2"], + ) + assert result.exit_code == 0, f"Export with limit failed: {result.output}" + finally: + if os.path.exists(db_path): + os.unlink(db_path) + if os.path.exists(out_path): + os.unlink(out_path) + + +# ========== QUERY COMMAND TESTS ========== + + +def test_query_missing_database(): + """Test query on non-existent database.""" + result = runner.invoke(app, ["query", "--db-path", "nonexistent.db"]) + assert result.exit_code == 1, "Query on missing DB should fail" + assert "does not exist" in result.output + + +def test_query_empty_database(): + """Test query on empty database.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.close() + + result = runner.invoke(app, ["query", "--db-path", db_path]) + assert result.exit_code == 0, "Query on empty DB should succeed" + assert "No logs found" in result.output or "No logs" in result.output + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_query_with_path_pattern(): + """Test query with path pattern filter.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.insert_log( + { + "path": "/api/test", + # Required - get_logs_by_path searches path_pattern column + "path_pattern": "/api/test", + "method": "GET", + "status_code": 200, + "duration_ms": 100, + "timestamp": "2024-01-01T00:00:00", + "client_ip": "127.0.0.1", + "user_agent": "test", + } + ) + db.close() + + result = runner.invoke( + app, ["query", "--db-path", db_path, "--path-pattern", "/api/test"] + ) + assert result.exit_code == 0, "Query with path pattern failed" + # get_logs_by_path searches path_pattern column, so it should find the log + assert ( + "No logs found" not in result.output + ), f"Expected to find logs but got: {result.output}" + assert "/api/test" in result.output or "test" in result.output.lower() + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_query_with_status_code(): + """Test query with status code filter.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.insert_log( + { + "path": "/api/test", + "method": "GET", + "status_code": 404, + "duration_ms": 100, + "timestamp": "2024-01-01T00:00:00", + "client_ip": "127.0.0.1", + "user_agent": "test", + } + ) + db.close() + + result = runner.invoke( + app, ["query", "--db-path", db_path, "--status-code", "404"] + ) + assert result.exit_code == 0, "Query with status code failed" + assert "404" in result.output + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_query_with_method(): + """Test query with HTTP method filter.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.insert_log( + { + "path": "/api/test", + "method": "POST", + "status_code": 200, + "duration_ms": 100, + "timestamp": "2024-01-01T00:00:00", + "client_ip": "127.0.0.1", + "user_agent": "test", + } + ) + db.close() + + result = runner.invoke(app, ["query", "--db-path", db_path, "--method", "POST"]) + assert result.exit_code == 0, "Query with method failed" + assert "POST" in result.output + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_query_with_days(): + """Test query with days filter.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.insert_log( + { + "path": "/api/test", + "method": "GET", + "status_code": 200, + "duration_ms": 100, + "timestamp": datetime.now().isoformat(), + "client_ip": "127.0.0.1", + "user_agent": "test", + } + ) + db.close() + + result = runner.invoke(app, ["query", "--db-path", db_path, "--days", "7"]) + assert result.exit_code == 0, "Query with days failed" + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_query_with_verbose(): + """Test query with verbose flag.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.insert_log( + { + "path": "/api/test", + "method": "GET", + "status_code": 200, + "duration_ms": 100, + "timestamp": "2024-01-01T00:00:00", + "client_ip": "127.0.0.1", + "user_agent": "test", + } + ) + db.close() + + result = runner.invoke(app, ["query", "--db-path", db_path, "--verbose"]) + assert result.exit_code == 0, "Query with verbose failed" + assert "Path:" in result.output or "Method:" in result.output + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_query_with_limit(): + """Test query with limit option.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + for i in range(5): + db.insert_log( + { + "path": f"/api/test{i}", + "method": "GET", + "status_code": 200, + "duration_ms": 100, + "timestamp": "2024-01-01T00:00:00", + "client_ip": "127.0.0.1", + "user_agent": "test", + } + ) + db.close() + + result = runner.invoke(app, ["query", "--db-path", db_path, "--limit", "2"]) + assert result.exit_code == 0, "Query with limit failed" + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +# ========== STAT COMMAND TESTS (Additional) ========== + + +def test_stat_command_database_mode(): + """Test stat command using database (not endpoint).""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.insert_log( + { + "path": "/api/test", + "method": "GET", + "status_code": 200, + "duration_ms": 100, + "timestamp": "2024-01-01T00:00:00", + "client_ip": "127.0.0.1", + "user_agent": "test", + } + ) + db.close() + + result = runner.invoke(app, ["stat", "--db-path", db_path], input="n\n") + assert result.exit_code == 0, "Stat with DB mode failed" + assert "📊 DevTrack Stats CLI" in result.output + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_stat_command_missing_database(): + """Test stat command with missing database.""" + result = runner.invoke(app, ["stat", "--db-path", "nonexistent.db"]) + assert result.exit_code == 1, "Stat on missing DB should fail" + assert "does not exist" in result.output + + +def test_stat_command_empty_database(): + """Test stat command with empty database.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.close() + + result = runner.invoke(app, ["stat", "--db-path", db_path], input="n\n") + assert result.exit_code == 0, "Stat on empty DB should succeed" + assert "No request stats found" in result.output + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +# ========== HEALTH COMMAND TESTS ========== + + +def test_health_command_database_only(): + """Test health command checking database only.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + + try: + db = init_db(db_path, read_only=False) + db.close() + + result = runner.invoke(app, ["health", "--db-path", db_path]) + assert result.exit_code in [0, 1], "Health check should complete" + assert "Health Check" in result.output or "Healthy" in result.output + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_health_command_missing_database(): + """Test health command with missing database.""" + result = runner.invoke(app, ["health", "--db-path", "nonexistent.db"]) + # Health check doesn't fail on missing DB, it just reports it + assert result.exit_code in [0, 1], "Health check should complete" + assert ( + "Not Found" in result.output + or "does not exist" in result.output + or "⚠️" in result.output + ) + + +def test_health_command_with_endpoint(): + """Test health command with endpoint check.""" + mock_response = MagicMock(status_code=200) + + with patch( + "devtrack_sdk.cli.detect_devtrack_endpoint", + return_value="http://localhost:8000/__devtrack__/stats", + ): + with patch("requests.get", return_value=mock_response): + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + try: + db = init_db(db_path, read_only=False) + db.close() + + result = runner.invoke( + app, ["health", "--db-path", db_path, "--endpoint"] + ) + assert result.exit_code in [ + 0, + 1, + ], "Health with endpoint should complete" + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +def test_health_command_endpoint_unreachable(): + """Test health command when endpoint is unreachable.""" + with patch( + "devtrack_sdk.cli.detect_devtrack_endpoint", + return_value="http://localhost:8000/__devtrack__/stats", + ): + with patch( + "requests.get", side_effect=requests.RequestException("Connection failed") + ): + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + os.unlink(db_path) # Remove empty file + try: + db = init_db(db_path, read_only=False) + db.close() + + result = runner.invoke( + app, ["health", "--db-path", db_path, "--endpoint"] + ) + assert ( + result.exit_code == 1 + ), "Health with unreachable endpoint should fail" + assert "Unreachable" in result.output or "Unhealthy" in result.output + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +# ========== HELP COMMAND TESTS ========== + + +def test_help_command(): + """Test help command.""" + result = runner.invoke(app, ["help"]) + assert result.exit_code == 0, "Help command failed" + assert "DevTrack CLI" in result.output or "Available Commands" in result.output + + +def test_version_command_detailed(): + """Test version command shows all information.""" + result = runner.invoke(app, ["version"]) + assert result.exit_code == 0, "Version command failed" + assert "DevTrack SDK" in result.output + assert "Version" in result.output or "Framework Support" in result.output diff --git a/tests/test_django_integration.py b/tests/test_django_integration.py index 80ba173..8ec788c 100644 --- a/tests/test_django_integration.py +++ b/tests/test_django_integration.py @@ -4,6 +4,7 @@ import json import os +import tempfile from datetime import datetime, timezone from unittest.mock import Mock, patch @@ -25,7 +26,29 @@ def setUp(self): self.factory = RequestFactory() # Create a mock get_response function self.mock_get_response = Mock() - self.middleware = DevTrackDjangoMiddleware(self.mock_get_response) + # Use a temporary database file to avoid lock conflicts + self.temp_db = tempfile.mktemp(suffix=".db") + # Reset the middleware's database instance to use temp file + if DevTrackDjangoMiddleware._db_instance is not None: + DevTrackDjangoMiddleware._db_instance.close() + DevTrackDjangoMiddleware._db_instance = None + # Create middleware with temp database path + self.middleware = DevTrackDjangoMiddleware( + self.mock_get_response, db_path=self.temp_db + ) + + def tearDown(self): + """Clean up temporary database file""" + # Close and clean up database instance + if DevTrackDjangoMiddleware._db_instance: + DevTrackDjangoMiddleware._db_instance.close() + DevTrackDjangoMiddleware._db_instance = None + # Remove temp file if it exists + if os.path.exists(self.temp_db): + try: + os.unlink(self.temp_db) + except OSError: + pass def test_middleware_initialization(self): """Test middleware initializes correctly""" @@ -92,8 +115,11 @@ def test_extract_log_data(self): def test_custom_exclude_paths(self): """Test custom exclude paths functionality""" + # Use the same temp database for consistency custom_middleware = DevTrackDjangoMiddleware( - self.mock_get_response, exclude_path=["/custom/path/"] + self.mock_get_response, + exclude_path=["/custom/path/"], + db_path=self.temp_db, ) self.assertIn("/custom/path/", custom_middleware.skip_paths) @@ -104,8 +130,22 @@ class DevTrackDjangoViewsTest(TestCase): def setUp(self): self.factory = RequestFactory() # Clear database before each test to avoid test data interference - if DevTrackDjangoMiddleware._db_instance is not None: - DevTrackDjangoMiddleware._db_instance.delete_all_logs() + # Ensure database instance exists and is in write mode + if DevTrackDjangoMiddleware._db_instance is None: + import tempfile + + from devtrack_sdk.database import DevTrackDB + + db_path = tempfile.mktemp(suffix=".db") + DevTrackDjangoMiddleware._db_instance = DevTrackDB(db_path, read_only=False) + # Ensure it's not read-only (recreate if needed) + if DevTrackDjangoMiddleware._db_instance.read_only: + db_path = DevTrackDjangoMiddleware._db_instance.db_path + DevTrackDjangoMiddleware._db_instance.close() + from devtrack_sdk.database import DevTrackDB + + DevTrackDjangoMiddleware._db_instance = DevTrackDB(db_path, read_only=False) + DevTrackDjangoMiddleware._db_instance.delete_all_logs() def test_stats_view(self): """Test stats view returns correct format""" diff --git a/tests/test_metrics_endpoints.py b/tests/test_metrics_endpoints.py index 343fd83..1456f6b 100644 --- a/tests/test_metrics_endpoints.py +++ b/tests/test_metrics_endpoints.py @@ -23,13 +23,16 @@ def app_with_middleware(): if os.path.exists(db_path): os.unlink(db_path) - init_db(db_path) + db = init_db(db_path, read_only=False) # Write mode for middleware app = FastAPI() app.include_router(devtrack_router) - db = get_db() app.add_middleware(DevTrackMiddleware, db_instance=db) + # Store db in app state so tests can access it + app.state.db = db + app.state.db_path = db_path + @app.get("/") async def root(): return {"message": "Hello"} @@ -55,18 +58,22 @@ async def get_user(user_id: int): # Cleanup try: - db = get_db() - db.close() - if os.path.exists(db_path): - os.unlink(db_path) + if hasattr(app.state, "db"): + app.state.db.close() + if hasattr(app.state, "db_path") and os.path.exists(app.state.db_path): + os.unlink(app.state.db_path) except Exception: pass -def clear_db_logs(): +def clear_db_logs(app=None): """Clear all logs from the database.""" try: - db = get_db() + if app and hasattr(app.state, "db"): + # Use the db instance from app state if available + db = app.state.db + else: + db = get_db(read_only=False) # Need write access to delete logs db.delete_all_logs() except Exception: # Database might be closed, ignore @@ -76,7 +83,7 @@ def clear_db_logs(): def test_traffic_metrics_endpoint(app_with_middleware): """Test /__devtrack__/metrics/traffic endpoint.""" client = TestClient(app_with_middleware) - clear_db_logs() + clear_db_logs(app_with_middleware) # Generate some traffic for _ in range(5): @@ -109,7 +116,7 @@ def test_traffic_metrics_endpoint(app_with_middleware): def test_error_metrics_endpoint(app_with_middleware): """Test /__devtrack__/metrics/errors endpoint.""" client = TestClient(app_with_middleware) - clear_db_logs() + clear_db_logs(app_with_middleware) # Generate some requests with errors client.get("/error") # 404 @@ -149,7 +156,7 @@ def test_error_metrics_endpoint(app_with_middleware): def test_performance_metrics_endpoint(app_with_middleware): """Test /__devtrack__/metrics/perf endpoint.""" client = TestClient(app_with_middleware) - clear_db_logs() + clear_db_logs(app_with_middleware) # Generate requests with varying latencies client.get("/") # Fast @@ -189,7 +196,7 @@ def test_performance_metrics_endpoint(app_with_middleware): def test_consumers_endpoint(app_with_middleware): """Test /__devtrack__/consumers endpoint.""" client = TestClient(app_with_middleware) - clear_db_logs() + clear_db_logs(app_with_middleware) # Generate requests with different user agents (to simulate different consumers) client.get("/", headers={"User-Agent": "Consumer1/1.0"}) @@ -258,7 +265,7 @@ def test_dashboard_assets_endpoint(app_with_middleware): def test_metrics_endpoints_with_no_data(app_with_middleware): """Test metrics endpoints with empty database.""" client = TestClient(app_with_middleware) - clear_db_logs() + clear_db_logs(app_with_middleware) # All endpoints should return empty data, not errors endpoints = [ @@ -293,7 +300,7 @@ def test_metrics_endpoints_error_handling(app_with_middleware): def test_integrated_metrics_workflow(app_with_middleware): """Test complete workflow: generate traffic, then check all metrics.""" client = TestClient(app_with_middleware) - clear_db_logs() + clear_db_logs(app_with_middleware) # Generate diverse traffic client.get("/") # Success diff --git a/tests/test_middleware.py b/tests/test_middleware.py index c68febd..131f89f 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -9,10 +9,14 @@ from devtrack_sdk.middleware.base import DevTrackMiddleware -def clear_db_logs(): +def clear_db_logs(app=None): """Clear all logs from the database for testing.""" try: - db = get_db() + if app and hasattr(app.state, "db"): + # Use the db instance from app state if available + db = app.state.db + else: + db = get_db(read_only=False) # Need write access to delete logs db.delete_all_logs() except Exception: # Database might be closed, ignore @@ -32,14 +36,17 @@ def app_with_middleware(): if os.path.exists(db_path): os.unlink(db_path) - # Initialize database with temporary file - init_db(db_path) + # Initialize database with temporary file (write mode for middleware) + db = init_db(db_path, read_only=False) app = FastAPI() app.include_router(devtrack_router) - db = get_db() app.add_middleware(DevTrackMiddleware, db_instance=db) + # Store db in app state so tests can access it + app.state.db = db + app.state.db_path = db_path + @app.get("/") async def root(): return {"message": "Hello"} @@ -79,12 +86,13 @@ async def user_profile(user_id: int): def test_root_logging(app_with_middleware): client = TestClient(app_with_middleware) - clear_db_logs() + clear_db_logs(app_with_middleware) response = client.get("/") assert response.status_code == 200 - db = get_db() + # Use the same db instance that middleware uses + db = app_with_middleware.state.db logs = db.get_all_logs() assert len(logs) == 1 log_entry = logs[0] @@ -106,7 +114,8 @@ def test_error_logging(app_with_middleware): response = client.get("/error") assert response.status_code == 400 - db = get_db() + # Use the same db instance that middleware uses + db = app_with_middleware.state.db logs = db.get_all_logs() assert len(logs) == 1 log_entry = logs[0] @@ -127,7 +136,8 @@ def test_post_request_logging(app_with_middleware): response = client.post("/users", json={"name": "Test User"}) assert response.status_code == 200 - db = get_db() + # Use the same db instance that middleware uses + db = app_with_middleware.state.db logs = db.get_all_logs() assert len(logs) == 1 log_entry = logs[0] @@ -182,7 +192,8 @@ def test_excluded_paths_not_logged(app_with_middleware): client.get("/redoc") client.get("/openapi.json") - db = get_db() + # Use the same db instance that middleware uses + db = app_with_middleware.state.db logs = db.get_all_logs() assert len(logs) == 0 @@ -204,7 +215,8 @@ def test_path_pattern_normalization(app_with_middleware): response = client.get("/users/123/profile") assert response.status_code == 200 - db = get_db() + # Use the same db instance that middleware uses + db = app_with_middleware.state.db logs = db.get_all_logs() assert len(logs) == 1 log_entry = logs[0] @@ -235,7 +247,8 @@ def test_middleware_logging(app_with_middleware): response = client.get("/") assert response.status_code == 200 - db = get_db() + # Use the same db instance that middleware uses + db = app_with_middleware.state.db logs = db.get_all_logs() assert len(logs) == 1 assert logs[0]["status_code"] == 200 @@ -286,7 +299,8 @@ def test_delete_all_logs(app_with_middleware): client.post("/users", json={"name": "Test User"}) client.get("/error") - db = get_db() + # Use the same db instance that middleware uses + db = app_with_middleware.state.db logs = db.get_all_logs() assert len(logs) == 3 @@ -313,7 +327,8 @@ def test_delete_logs_by_status_code(app_with_middleware): client.post("/users", json={"name": "Test User"}) # 200 client.get("/error") # 400 - db = get_db() + # Use the same db instance that middleware uses + db = app_with_middleware.state.db logs = db.get_all_logs() assert len(logs) == 3 @@ -341,7 +356,8 @@ def test_delete_logs_by_path_pattern(app_with_middleware): client.post("/users", json={"name": "Test User"}) client.get("/users/123/profile") - db = get_db() + # Use the same db instance that middleware uses + db = app_with_middleware.state.db logs = db.get_all_logs() assert len(logs) == 3 @@ -368,7 +384,8 @@ def test_delete_log_by_id(app_with_middleware): client.get("/") client.post("/users", json={"name": "Test User"}) - db = get_db() + # Use the same db instance that middleware uses + db = app_with_middleware.state.db logs = db.get_all_logs() assert len(logs) == 2 @@ -399,7 +416,8 @@ def test_delete_logs_by_ids(app_with_middleware): client.post("/users", json={"name": "Test User"}) client.get("/error") - db = get_db() + # Use the same db instance that middleware uses + db = app_with_middleware.state.db logs = db.get_all_logs() assert len(logs) == 3 @@ -429,7 +447,8 @@ def test_delete_logs_older_than(app_with_middleware): client.get("/") client.post("/users", json={"name": "Test User"}) - db = get_db() + # Use the same db instance that middleware uses + db = app_with_middleware.state.db logs = db.get_all_logs() assert len(logs) == 2