From e6c9c6dce90d631bcd85e8fd37de812304730742 Mon Sep 17 00:00:00 2001 From: Chris Harris Date: Tue, 12 May 2026 19:59:33 +0000 Subject: [PATCH] Allow number of GPUs to be selected --- .../launcher/interactem/launcher/launcher.py | 1 + .../launcher/templates/launch_agent.sh.j2 | 8 ++++- backend/launcher/tests/expected_script.sh | 8 +++-- backend/launcher/tests/test_rendering.py | 25 +++++++++++++ .../interactem/sfapi_models/__init__.py | 6 ++-- frontend/interactEM/openapi.json | 7 ++++ .../src/client/generated/types.gen.ts | 4 +++ .../src/client/generated/zod.gen.ts | 1 + frontend/interactEM/src/client/index.ts | 1 + .../src/components/agents/launchbutton.tsx | 36 +++++++++++++++++++ 10 files changed, 92 insertions(+), 5 deletions(-) diff --git a/backend/launcher/interactem/launcher/launcher.py b/backend/launcher/interactem/launcher/launcher.py index 1c8a3da6..8293a57f 100644 --- a/backend/launcher/interactem/launcher/launcher.py +++ b/backend/launcher/interactem/launcher/launcher.py @@ -113,6 +113,7 @@ async def submit(msg: NATSMsg, js: JetStreamContext) -> None: walltime=agent_event.duration, reservation=reservation, num_nodes=agent_event.num_nodes, + gpus_per_node=agent_event.gpus_per_node, ) except ValidationError as e: logger.error(f"Failed to parse job request: {e}") diff --git a/backend/launcher/interactem/launcher/templates/launch_agent.sh.j2 b/backend/launcher/interactem/launcher/templates/launch_agent.sh.j2 index 4cd9b3da..9ee4b3bf 100644 --- a/backend/launcher/interactem/launcher/templates/launch_agent.sh.j2 +++ b/backend/launcher/interactem/launcher/templates/launch_agent.sh.j2 @@ -4,4 +4,10 @@ export HDF5_USE_FILE_LOCKING=FALSE export UV_CACHE_DIR="${TMPDIR:-/tmp}/uv-cache-${SLURM_JOB_ID:-$$}" mkdir -p "$UV_CACHE_DIR" cd {{settings.ENV_FILE_DIR}} -srun --nodes={{job.num_nodes}} --ntasks-per-node=1 uv run --project {{settings.AGENT_PROJECT_DIR}} interactem-agent \ No newline at end of file +srun \ + --nodes={{ job.num_nodes }} \ + --ntasks-per-node=1 \ 
+ {%- if job.constraint == "gpu" %}{# gpus_per_node is optional; fall back to 4 so the flag never renders as "None" #} + --gpus-per-node={{ job.gpus_per_node | default(4, true) }} \ + {%- endif %} + uv run --project {{ settings.AGENT_PROJECT_DIR }} interactem-agent \ No newline at end of file diff --git a/backend/launcher/tests/expected_script.sh b/backend/launcher/tests/expected_script.sh index 8673b7dd..5281b4a8 100644 --- a/backend/launcher/tests/expected_script.sh +++ b/backend/launcher/tests/expected_script.sh @@ -9,7 +9,7 @@ #SBATCH --qos=normal #SBATCH --constraint=gpu -#SBATCH --time=01:30:00 +#SBATCH --time=01:30:00 #SBATCH --account=test_account #SBATCH --nodes=2 #SBATCH --exclusive @@ -18,4 +18,8 @@ export HDF5_USE_FILE_LOCKING=FALSE export UV_CACHE_DIR="${TMPDIR:-/tmp}/uv-cache-${SLURM_JOB_ID:-$$}" mkdir -p "$UV_CACHE_DIR" cd /path/to/.env -srun --nodes=2 --ntasks-per-node=1 uv run --project /path/to/interactEM/backend/agent interactem-agent \ No newline at end of file +srun \ + --nodes=2 \ + --ntasks-per-node=1 \ + --gpus-per-node=4 \ + uv run --project /path/to/interactEM/backend/agent interactem-agent \ No newline at end of file diff --git a/backend/launcher/tests/test_rendering.py b/backend/launcher/tests/test_rendering.py index c616d196..ab24e95a 100644 --- a/backend/launcher/tests/test_rendering.py +++ b/backend/launcher/tests/test_rendering.py @@ -40,3 +40,28 @@ async def test_submit_rendering(expected_script: str): ) assert script == expected_script + + +@pytest.mark.asyncio +async def test_submit_rendering_uses_configured_gpus_per_node(): + job_req = JobSubmitEvent( + machine=Machine.perlmutter, + account="test_account", + qos="normal", + constraint="gpu", + walltime=timedelta(hours=1, minutes=30), + reservation=None, + num_nodes=2, + gpus_per_node=2, + ) + + jinja_env = Environment( + loader=PackageLoader("interactem.launcher"), enable_async=True + ) + template = jinja_env.get_template(LAUNCH_AGENT_TEMPLATE) + + script = await template.render_async( + job=job_req.model_dump(), settings=cfg.model_dump() + ) + + assert "--gpus-per-node=2" in 
script diff --git a/backend/sfapi_models/interactem/sfapi_models/__init__.py b/backend/sfapi_models/interactem/sfapi_models/__init__.py index c2a78b66..e606d361 100644 --- a/backend/sfapi_models/interactem/sfapi_models/__init__.py +++ b/backend/sfapi_models/interactem/sfapi_models/__init__.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Any -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, Field, model_validator from sfapi_client._models import StatusValue class Machine(str, Enum): @@ -26,6 +26,7 @@ class AgentCreateEvent(BaseModel): duration: datetime.timedelta compute_type: ComputeType num_nodes: int + gpus_per_node: int | None = Field(default=None, ge=1) extra: dict[str, Any] | None = None @@ -37,6 +38,7 @@ class JobSubmitEvent(BaseModel): walltime: datetime.timedelta | str reservation: str | None = None num_nodes: int = 1 + gpus_per_node: int | None = Field(default=None, ge=1) @model_validator(mode="after") def format_walltime(self) -> "JobSubmitEvent": @@ -55,4 +57,4 @@ def format_walltime(self) -> "JobSubmitEvent": hours, remainder = divmod(total_seconds, 3600) minutes, seconds = divmod(remainder, 60) self.walltime = f"{hours:02}:{minutes:02}:{seconds:02}" - return self \ No newline at end of file + return self diff --git a/frontend/interactEM/openapi.json b/frontend/interactEM/openapi.json index 2bde9b81..0bf6dab8 100644 --- a/frontend/interactEM/openapi.json +++ b/frontend/interactEM/openapi.json @@ -1348,6 +1348,13 @@ }, "compute_type": { "$ref": "#/components/schemas/ComputeType" }, "num_nodes": { "type": "integer", "title": "Num Nodes" }, + "gpus_per_node": { + "anyOf": [ + { "type": "integer", "minimum": 1.0 }, + { "type": "null" } + ], + "title": "Gpus Per Node" + }, "extra": { "anyOf": [ { "additionalProperties": true, "type": "object" }, diff --git a/frontend/interactEM/src/client/generated/types.gen.ts b/frontend/interactEM/src/client/generated/types.gen.ts index 91124abd..384987d3 100644 --- 
a/frontend/interactEM/src/client/generated/types.gen.ts +++ b/frontend/interactEM/src/client/generated/types.gen.ts @@ -18,6 +18,10 @@ export type AgentCreateEvent = { * Num Nodes */ num_nodes: number + /** + * Gpus Per Node + */ + gpus_per_node?: number | null /** * Extra */ diff --git a/frontend/interactEM/src/client/generated/zod.gen.ts b/frontend/interactEM/src/client/generated/zod.gen.ts index 0dd33091..7f1d616a 100644 --- a/frontend/interactEM/src/client/generated/zod.gen.ts +++ b/frontend/interactEM/src/client/generated/zod.gen.ts @@ -51,6 +51,7 @@ export const zAgentCreateEvent = z.object({ duration: z.string(), compute_type: zComputeType, num_nodes: z.number().int(), + gpus_per_node: z.union([z.number().int().gte(1), z.null()]).optional(), extra: z.union([z.record(z.unknown()), z.null()]).optional(), }) diff --git a/frontend/interactEM/src/client/index.ts b/frontend/interactEM/src/client/index.ts index 53ad29f1..ef80e364 100644 --- a/frontend/interactEM/src/client/index.ts +++ b/frontend/interactEM/src/client/index.ts @@ -16,6 +16,7 @@ import { export const zAgentCreateEvent = zAgentCreateEventTmp.extend({ duration: z.string().time(), num_nodes: z.coerce.number().int(), + gpus_per_node: z.coerce.number().int().positive().optional().nullable(), }) // This is a workaround for the generated data field being an object. 
diff --git a/frontend/interactEM/src/components/agents/launchbutton.tsx b/frontend/interactEM/src/components/agents/launchbutton.tsx index 55c829c2..8720b7f4 100644 --- a/frontend/interactEM/src/components/agents/launchbutton.tsx +++ b/frontend/interactEM/src/components/agents/launchbutton.tsx @@ -36,16 +36,21 @@ export const LaunchAgentButton = () => { const { control, handleSubmit, + watch, formState: { errors }, } = useForm({ resolver: zodResolver(zAgentCreateEvent), + shouldUnregister: true, defaultValues: { machine: "perlmutter", compute_type: "cpu", duration: "01:00:00", num_nodes: 1, + gpus_per_node: 4, }, }) + const computeType = watch("compute_type") + const onSubmit: SubmitHandler = useCallback( async (formData: AgentCreateEvent) => { try { @@ -145,6 +150,37 @@ export const LaunchAgentButton = () => { /> + {computeType === "gpu" && ( + + ( + + )} + /> + + )} +