Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/launcher/interactem/launcher/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ async def submit(msg: NATSMsg, js: JetStreamContext) -> None:
walltime=agent_event.duration,
reservation=reservation,
num_nodes=agent_event.num_nodes,
gpus_per_node=agent_event.gpus_per_node,
)
except ValidationError as e:
logger.error(f"Failed to parse job request: {e}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,10 @@ export HDF5_USE_FILE_LOCKING=FALSE
export UV_CACHE_DIR="${TMPDIR:-/tmp}/uv-cache-${SLURM_JOB_ID:-$$}"
mkdir -p "$UV_CACHE_DIR"
cd {{settings.ENV_FILE_DIR}}
srun --nodes={{job.num_nodes}} --ntasks-per-node=1 uv run --project {{settings.AGENT_PROJECT_DIR}} interactem-agent
srun \
--nodes={{ job.num_nodes }} \
--ntasks-per-node=1 \
{%- if job.constraint == "gpu" %}
--gpus-per-node={{ job.gpus_per_node }} \
{%- endif %}
uv run --project {{ settings.AGENT_PROJECT_DIR }} interactem-agent
8 changes: 6 additions & 2 deletions backend/launcher/tests/expected_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

#SBATCH --qos=normal
#SBATCH --constraint=gpu
#SBATCH --time=01:30:00
#SBATCH --time=01:30:00
#SBATCH --account=test_account
#SBATCH --nodes=2
#SBATCH --exclusive
Expand All @@ -18,4 +18,8 @@ export HDF5_USE_FILE_LOCKING=FALSE
export UV_CACHE_DIR="${TMPDIR:-/tmp}/uv-cache-${SLURM_JOB_ID:-$$}"
mkdir -p "$UV_CACHE_DIR"
cd /path/to/.env
srun --nodes=2 --ntasks-per-node=1 uv run --project /path/to/interactEM/backend/agent interactem-agent
srun \
--nodes=2 \
--ntasks-per-node=1 \
--gpus-per-node=4 \
uv run --project /path/to/interactEM/backend/agent interactem-agent
25 changes: 25 additions & 0 deletions backend/launcher/tests/test_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,28 @@ async def test_submit_rendering(expected_script: str):
)

assert script == expected_script


@pytest.mark.asyncio
async def test_submit_rendering_uses_configured_gpus_per_node():
job_req = JobSubmitEvent(
machine=Machine.perlmutter,
account="test_account",
qos="normal",
constraint="gpu",
walltime=timedelta(hours=1, minutes=30),
reservation=None,
num_nodes=2,
gpus_per_node=2,
)

jinja_env = Environment(
loader=PackageLoader("interactem.launcher"), enable_async=True
)
template = jinja_env.get_template(LAUNCH_AGENT_TEMPLATE)

script = await template.render_async(
job=job_req.model_dump(), settings=cfg.model_dump()
)

assert "--gpus-per-node=2" in script
6 changes: 4 additions & 2 deletions backend/sfapi_models/interactem/sfapi_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from typing import Any

from pydantic import BaseModel, model_validator
from pydantic import BaseModel, Field, model_validator
from sfapi_client._models import StatusValue

class Machine(str, Enum):
Expand All @@ -26,6 +26,7 @@ class AgentCreateEvent(BaseModel):
duration: datetime.timedelta
compute_type: ComputeType
num_nodes: int
gpus_per_node: int | None = Field(default=None, ge=1)
extra: dict[str, Any] | None = None


Expand All @@ -37,6 +38,7 @@ class JobSubmitEvent(BaseModel):
walltime: datetime.timedelta | str
reservation: str | None = None
num_nodes: int = 1
gpus_per_node: int | None = Field(default=None, ge=1)

@model_validator(mode="after")
def format_walltime(self) -> "JobSubmitEvent":
Expand All @@ -55,4 +57,4 @@ def format_walltime(self) -> "JobSubmitEvent":
hours, remainder = divmod(total_seconds, 3600)
minutes, seconds = divmod(remainder, 60)
self.walltime = f"{hours:02}:{minutes:02}:{seconds:02}"
return self
return self
7 changes: 7 additions & 0 deletions frontend/interactEM/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -1348,6 +1348,13 @@
},
"compute_type": { "$ref": "#/components/schemas/ComputeType" },
"num_nodes": { "type": "integer", "title": "Num Nodes" },
"gpus_per_node": {
"anyOf": [
{ "type": "integer", "minimum": 1.0 },
{ "type": "null" }
],
"title": "Gpus Per Node"
},
"extra": {
"anyOf": [
{ "additionalProperties": true, "type": "object" },
Expand Down
4 changes: 4 additions & 0 deletions frontend/interactEM/src/client/generated/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ export type AgentCreateEvent = {
* Num Nodes
*/
num_nodes: number
/**
* Gpus Per Node
*/
gpus_per_node?: number | null
/**
* Extra
*/
Expand Down
1 change: 1 addition & 0 deletions frontend/interactEM/src/client/generated/zod.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ export const zAgentCreateEvent = z.object({
duration: z.string(),
compute_type: zComputeType,
num_nodes: z.number().int(),
gpus_per_node: z.union([z.number().int().gte(1), z.null()]).optional(),
extra: z.union([z.record(z.unknown()), z.null()]).optional(),
})

Expand Down
1 change: 1 addition & 0 deletions frontend/interactEM/src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
export const zAgentCreateEvent = zAgentCreateEventTmp.extend({
duration: z.string().time(),
num_nodes: z.coerce.number().int(),
gpus_per_node: z.coerce.number().int().positive().optional().nullable(),
})

// This is a workaround for the generated data field being an object.
Expand Down
36 changes: 36 additions & 0 deletions frontend/interactEM/src/components/agents/launchbutton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,21 @@ export const LaunchAgentButton = () => {
const {
control,
handleSubmit,
watch,
formState: { errors },
} = useForm<AgentCreateEvent>({
resolver: zodResolver(zAgentCreateEvent),
shouldUnregister: true,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does this bad boy do?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It removes gpus_per_node from payload if the user selects cpu

defaultValues: {
machine: "perlmutter",
compute_type: "cpu",
duration: "01:00:00",
num_nodes: 1,
gpus_per_node: 4,
},
})
const computeType = watch("compute_type")

const onSubmit: SubmitHandler<AgentCreateEvent> = useCallback(
async (formData: AgentCreateEvent) => {
try {
Expand Down Expand Up @@ -145,6 +150,37 @@ export const LaunchAgentButton = () => {
/>
</FormControl>

{computeType === "gpu" && (
<FormControl fullWidth sx={{ marginBottom: "1rem" }}>
<Controller
name="gpus_per_node"
control={control}
render={({ field }) => (
<TextField
label="GPUs per Node"
variant="outlined"
type="number"
onChange={field.onChange}
onBlur={field.onBlur}
value={field.value ?? ""}
fullWidth
error={!!errors.gpus_per_node}
helperText={
errors.gpus_per_node
? "This should be a positive integer."
: ""
}
slotProps={{
htmlInput: {
min: 1,
},
}}
/>
)}
/>
</FormControl>
)}

<FormControl fullWidth sx={{ marginBottom: "1rem" }}>
<Controller
name="duration"
Expand Down
Loading