Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions gcm/schemas/slurm/squeue.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import logging

from dataclasses import dataclass, field, fields

from gcm.monitoring.clock import time_to_time_aware
Expand All @@ -14,6 +16,24 @@
from gcm.schemas.dataclass import parsed_field
from gcm.schemas.slurm.derived_cluster import DerivedCluster

# Module-level logger; _truncated_nodelist uses it to warn about truncation.
logger = logging.getLogger(__name__)

# Maximum number of parsed nodelist entries kept by _truncated_nodelist.
# Longer lists are cut to this many entries and a "..." marker is appended.
_MAX_NODELIST_ENTRIES = 1000


def _truncated_nodelist(s: str) -> list[str] | None:
    """Parse a Slurm nodelist string, capping the number of returned entries.

    Args:
        s: Raw nodelist string, passed unchanged to the ``nodelist()`` parser.

    Returns:
        ``None`` when the parser yields ``None``. Otherwise the parsed list;
        if it holds more than ``_MAX_NODELIST_ENTRIES`` entries, only the
        first ``_MAX_NODELIST_ENTRIES`` are kept and a literal ``"..."``
        element is appended to mark the truncation.
    """
    parsed, _ = nodelist()(s)
    if parsed is None:
        return None

    total = len(parsed)
    if total <= _MAX_NODELIST_ENTRIES:
        return parsed

    # Pass values as lazy %-style arguments so the message is only formatted
    # when a handler actually emits the record. The wording always says
    # "NODELIST", whichever field this parser is attached to.
    logger.warning(
        "Truncating NODELIST from %d to %d entries (first: %s, last: %s)",
        total,
        _MAX_NODELIST_ENTRIES,
        parsed[0],
        parsed[-1],
    )
    return parsed[:_MAX_NODELIST_ENTRIES] + ["..."]


@dataclass(kw_only=True)
class JobData(DerivedCluster):
Expand All @@ -39,10 +59,10 @@ class JobData(DerivedCluster):
NODES: int | None = parsed_field(parser=maybe_int, field_name="NUMNODES")
TIME_LEFT: str = parsed_field(parser=str, field_name="TIMELEFT")
TIME_USED: str = parsed_field(parser=str, field_name="TIMEUSED")
NODELIST: list[str] | None = parsed_field(parser=lambda s: nodelist()(s)[0])
NODELIST: list[str] | None = parsed_field(parser=_truncated_nodelist)
DEPENDENCY: str = parsed_field(parser=str)
EXC_NODES: list[str] | None = parsed_field(
parser=lambda s: nodelist()(s)[0], field_name="EXCNODES"
parser=_truncated_nodelist, field_name="EXCNODES"
)
START_TIME: str = parsed_field(parser=time_to_time_aware, field_name="STARTTIME")
SUBMIT_TIME: str = parsed_field(parser=time_to_time_aware, field_name="SUBMITTIME")
Expand Down Expand Up @@ -75,7 +95,7 @@ class JobData(DerivedCluster):
REQUEUE: str = parsed_field(parser=str)
FEATURE: str = parsed_field(parser=str)
RESTARTCNT: int = parsed_field(parser=int)
SCHEDNODES: list[str] | None = parsed_field(parser=lambda s: nodelist()(s)[0])
SCHEDNODES: list[str] | None = parsed_field(parser=_truncated_nodelist)
LAST_SCHED_EVAL: str = parsed_field(
parser=time_to_time_aware, field_name="LASTSCHEDEVAL"
)
Expand Down
39 changes: 38 additions & 1 deletion gcm/tests/test_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from gcm.schemas.slurm.sinfo import Sinfo
from gcm.schemas.slurm.sinfo_node import SinfoNode
from gcm.schemas.slurm.squeue import JobData
from gcm.schemas.slurm.squeue import _truncated_nodelist, JobData
from gcm.tests import data

TEST_CLUSTER = "test_cluster"
Expand Down Expand Up @@ -653,3 +653,40 @@ def test_parse_sdiag_json_with_explicit_null_timestamp_objects(
assert result.job_states_ts is None
assert result.bf_when_last_cycle is None
mock_reset.assert_called_once()


class TestTruncatedNodelist:
    """Exercise _truncated_nodelist below, at, and above its 1000-entry cap."""

    def test_normal_nodelist_unchanged(self) -> None:
        """A short nodelist is returned in full, without a marker."""
        # "node[001-010]" expands to ten hostnames.
        nodes = _truncated_nodelist("node[001-010]")
        assert nodes is not None
        assert len(nodes) == 10
        assert nodes[0] == "node001"
        assert nodes[-1] == "node010"
        # Nothing was dropped, so no truncation marker is expected.
        assert "..." not in nodes

    def test_large_nodelist_truncated(self) -> None:
        """An oversized nodelist keeps its first 1000 entries plus '...'."""
        # "h200-[0001-2000]" expands to 2000 hostnames.
        nodes = _truncated_nodelist("h200-[0001-2000]")
        assert nodes is not None
        # 1000 retained hostnames followed by the single marker element.
        assert len(nodes) == 1001
        assert nodes[0] == "h200-0001"
        assert nodes[999] == "h200-1000"
        assert nodes[-1] == "..."

    def test_exactly_1000_entries_unchanged(self) -> None:
        """A nodelist sitting exactly on the cap is left alone."""
        nodes = _truncated_nodelist("n[0001-1000]")
        assert nodes is not None
        assert len(nodes) == 1000
        assert "..." not in nodes

    def test_empty_nodelist_returns_none(self) -> None:
        """An empty input string produces None."""
        nodes = _truncated_nodelist("")
        assert nodes is None
Loading