Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions src/ezmsg/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,11 @@

import os
import logging
from .logconfig import create_ezmsg_stderr_handler, create_ezmsg_stdout_handler

logger = logging.getLogger("ezmsg")
handler = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s.%(msecs)03d - pid: %(process)d - %(threadName)s "
+ "- %(levelname)s - %(funcName)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)

handler.setFormatter(formatter)
logger.addHandler(handler)
logger.addHandler(create_ezmsg_stdout_handler())
logger.addHandler(create_ezmsg_stderr_handler())

LOGLEVEL = os.environ.get("EZMSG_LOGLEVEL", "INFO").upper()
logger.setLevel(LOGLEVEL)
Expand Down
2 changes: 2 additions & 0 deletions src/ezmsg/core/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ async def run_command(
compact: int | None = None,
nobrowser: bool = False,
dashboard: int | bool | None = None,
log_file: str | None = None,
) -> None:
handlers = {
"dashboard": None,
Expand Down Expand Up @@ -106,6 +107,7 @@ async def run_command(
port=8000,
open_browser=False,
log_level="info",
log_file=log_file,
)
result = handlers[cmd](args)
if inspect.isawaitable(result):
Expand Down
74 changes: 74 additions & 0 deletions src/ezmsg/core/commands/common.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
import argparse
from collections.abc import Iterator
from contextlib import contextmanager
import logging
import os
from datetime import datetime
from pathlib import Path

from ..graphserver import GraphService
from ..logconfig import create_ezmsg_log_formatter
from ..netprotocol import Address


def add_address_argument(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--address", help="Address for GraphServer", default=None)


def add_log_file_argument(parser: argparse.ArgumentParser) -> None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to plumb this into run_command() too.

parser.add_argument(
"--log-file",
help="Path to the ezmsg service log file",
default=None,
)


def add_compact_argument(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"-c",
Expand All @@ -24,3 +39,62 @@ def graph_address_from_args(args: argparse.Namespace) -> Address:
if args.address is None:
return GraphService.default_address()
return Address.from_string(args.address)


def resolve_log_file(args: argparse.Namespace, address: Address) -> Path:
if args.log_file is not None:
return Path(args.log_file).expanduser()

env_log_file = os.environ.get("EZMSG_LOG_FILE")
if env_log_file is not None:
return Path(env_log_file).expanduser()

if os.name == "nt":
data_home = Path(
os.environ.get("LOCALAPPDATA", Path.home() / "AppData" / "Local")
)
else:
data_home = Path(
os.environ.get("XDG_DATA_HOME", Path.home() / ".local" / "share")
)

log_dir = data_home / "ezmsg" / "logs" / f"GraphServer-{address.port}"
timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
return log_dir / f"{timestamp}.log"


def _configure_managed_log_file(log_file: Path) -> tuple[Path, logging.FileHandler | None]:
log_path = log_file.expanduser().resolve()
log_path.parent.mkdir(parents=True, exist_ok=True)

logger = logging.getLogger("ezmsg")
if any(
isinstance(handler, logging.FileHandler)
and getattr(handler, "baseFilename", None) == str(log_path)
for handler in logger.handlers
):
return log_path, None

handler = logging.FileHandler(log_path, encoding="utf-8")
handler.setFormatter(create_ezmsg_log_formatter())
logger.addHandler(handler)

return log_path, handler


def configure_log_file(log_file: Path) -> Path:
log_path, _ = _configure_managed_log_file(log_file)

return log_path


@contextmanager
def managed_log_file(log_file: Path) -> Iterator[Path]:
log_path, handler = _configure_managed_log_file(log_file)
try:
yield log_path
finally:
if handler is not None:
logger = logging.getLogger("ezmsg")
logger.removeHandler(handler)
handler.close()
57 changes: 33 additions & 24 deletions src/ezmsg/core/commands/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
import logging

from ..graphserver import GraphService
from .common import add_address_argument, graph_address_from_args
from .common import (
add_address_argument,
add_log_file_argument,
graph_address_from_args,
managed_log_file,
resolve_log_file,
)
from .dashboard import (
DashboardDependencyError,
add_dashboard_argument,
Expand All @@ -15,33 +21,36 @@

async def handle_serve(args: argparse.Namespace) -> None:
graph_address = graph_address_from_args(args)
graph_service = GraphService(graph_address)

logger.info(f"GraphServer Address: {graph_address}")
graph_server = graph_service.create_server()
dashboard_server = None

try:
if args.dashboard is not None:
dashboard_port = args.dashboard if type(args.dashboard) is int else None
dashboard_server = start_dashboard(
graph_service.address, dashboard_port=dashboard_port
)
logger.info(f"Dashboard Address: {dashboard_server.url}")
logger.info("Servers running...")
await asyncio.to_thread(graph_server.join)
except (KeyboardInterrupt, asyncio.CancelledError):
logger.info("Interrupt detected; shutting down servers")
except DashboardDependencyError as exc:
logger.warning(str(exc))
finally:
if dashboard_server is not None:
dashboard_server.stop()
graph_server.stop()
with managed_log_file(resolve_log_file(args, graph_address)) as log_path:
graph_service = GraphService(graph_address)

logger.info(f"GraphServer Address: {graph_address}")
logger.info(f"GraphServer Log File: {log_path}")
graph_server = graph_service.create_server()
dashboard_server = None

try:
if args.dashboard is not None:
dashboard_port = args.dashboard if type(args.dashboard) is int else None
dashboard_server = start_dashboard(
graph_service.address, dashboard_port=dashboard_port
)
logger.info(f"Dashboard Address: {dashboard_server.url}")
logger.info("Servers running...")
await asyncio.to_thread(graph_server.join)
except (KeyboardInterrupt, asyncio.CancelledError):
logger.info("Interrupt detected; shutting down servers")
except DashboardDependencyError as exc:
logger.warning(str(exc))
finally:
if dashboard_server is not None:
dashboard_server.stop()
graph_server.stop()


def setup_serve_cmdline(subparsers: argparse._SubParsersAction) -> None:
parser = subparsers.add_parser("serve")
add_address_argument(parser)
add_log_file_argument(parser)
add_dashboard_argument(parser)
parser.set_defaults(_handler=handle_serve)
5 changes: 4 additions & 1 deletion src/ezmsg/core/commands/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..graphserver import GraphService
from ..netprotocol import close_stream_writer
from .common import add_address_argument, graph_address_from_args
from .common import add_address_argument, add_log_file_argument, graph_address_from_args
from .dashboard import (
DashboardDependencyError,
add_dashboard_argument,
Expand All @@ -20,6 +20,8 @@ async def handle_start(args: argparse.Namespace) -> None:
graph_address = graph_address_from_args(args)
graph_service = GraphService(graph_address)
cmd = [sys.executable, "-m", "ezmsg.core", "serve", f"--address={graph_address}"]
if args.log_file is not None:
cmd.append(f"--log-file={args.log_file}")
if args.dashboard is not None:
try:
require_dashboard_dependency()
Expand All @@ -46,5 +48,6 @@ async def handle_start(args: argparse.Namespace) -> None:
def setup_start_cmdline(subparsers: argparse._SubParsersAction) -> None:
parser = subparsers.add_parser("start")
add_address_argument(parser)
add_log_file_argument(parser)
add_dashboard_argument(parser)
parser.set_defaults(_handler=handle_start)
47 changes: 47 additions & 0 deletions src/ezmsg/core/logconfig.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import logging
import sys
from typing import TextIO


EZMSG_LOG_FORMAT = (
"%(asctime)s.%(msecs)03d - pid: %(process)d - %(threadName)s "
"- %(levelname)s - %(funcName)s: %(message)s"
)
EZMSG_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
EZMSG_STDERR_LOG_LEVEL = logging.WARNING


class BelowLevelFilter(logging.Filter):
def __init__(self, level: int):
super().__init__()
self.level = level

def filter(self, record: logging.LogRecord) -> bool:
return record.levelno < self.level


class AtOrAboveLevelFilter(logging.Filter):
def __init__(self, level: int):
super().__init__()
self.level = level

def filter(self, record: logging.LogRecord) -> bool:
return record.levelno >= self.level


def create_ezmsg_log_formatter() -> logging.Formatter:
return logging.Formatter(EZMSG_LOG_FORMAT, datefmt=EZMSG_LOG_DATE_FORMAT)


def create_ezmsg_stdout_handler(stream: TextIO | None = None) -> logging.StreamHandler:
handler = logging.StreamHandler(sys.stdout if stream is None else stream)
handler.setFormatter(create_ezmsg_log_formatter())
handler.addFilter(BelowLevelFilter(EZMSG_STDERR_LOG_LEVEL))
return handler


def create_ezmsg_stderr_handler(stream: TextIO | None = None) -> logging.StreamHandler:
handler = logging.StreamHandler(sys.stderr if stream is None else stream)
handler.setFormatter(create_ezmsg_log_formatter())
handler.addFilter(AtOrAboveLevelFilter(EZMSG_STDERR_LOG_LEVEL))
return handler
86 changes: 85 additions & 1 deletion tests/test_command.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import pytest
import argparse
import asyncio
import sys
from pathlib import Path

from ezmsg.core.command import build_parser, cmdline
from ezmsg.core.command import build_parser, cmdline, run_command
from ezmsg.core.netprotocol import Address
from ezmsg.core.commands.start import handle_start


def test_mermaid_subparser_accepts_mermaid_specific_args():
Expand Down Expand Up @@ -140,6 +144,15 @@ def test_start_subparser_accepts_dashboard_flag():
assert args.dashboard is True


def test_serve_subparser_accepts_log_file():
parser = build_parser()

args = parser.parse_args(["serve", "--log-file", "/tmp/ezmsg.log"])

assert args.command == "serve"
assert args.log_file == "/tmp/ezmsg.log"


def test_serve_subparser_accepts_dashboard_port():
parser = build_parser()

Expand Down Expand Up @@ -222,3 +235,74 @@ def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
args._handler(args)

assert "pip install ezmsg-dashboard" in caplog.text


def test_run_command_passes_log_file_to_handler(monkeypatch):
captured_args = []

async def fake_handle_start(args):
captured_args.append(args)

monkeypatch.setattr("ezmsg.core.command.handle_start", fake_handle_start)

asyncio.run(
run_command(
"start",
Address("127.0.0.1", 25978),
log_file="/tmp/ezmsg.log",
)
)

assert len(captured_args) == 1
assert captured_args[0].log_file == "/tmp/ezmsg.log"


def test_start_passes_log_file_to_serve(monkeypatch):
commands = []

class DummyPopen:
pid = 123

def __init__(self, cmd):
commands.append(cmd)

class DummyWriter:
def close(self):
pass

async def wait_closed(self):
pass

class DummyGraphService:
async def open_connection(self):
return object(), DummyWriter()

async def noop_close_stream_writer(writer):
return None

monkeypatch.setattr("ezmsg.core.commands.start.subprocess.Popen", DummyPopen)
monkeypatch.setattr(
"ezmsg.core.commands.start.GraphService", lambda address: DummyGraphService()
)
monkeypatch.setattr(
"ezmsg.core.commands.start.close_stream_writer", noop_close_stream_writer
)

args = argparse.Namespace(
address="127.0.0.1:25978",
dashboard=None,
log_file="/tmp/ezmsg.log",
)

asyncio.run(handle_start(args))

assert commands == [
[
sys.executable,
"-m",
"ezmsg.core",
"serve",
"--address=127.0.0.1:25978",
"--log-file=/tmp/ezmsg.log",
]
]
Loading
Loading