From 4c83d7b508fd5459104b521c4381b30e236ca693 Mon Sep 17 00:00:00 2001 From: Minh Vu Date: Sun, 28 Jun 2026 15:54:07 +0200 Subject: [PATCH] Fix Slurm noquote marker handling --- nemo_run/core/execution/slurm.py | 16 ++++++++++++---- test/core/execution/test_slurm_templates.py | 18 ++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/nemo_run/core/execution/slurm.py b/nemo_run/core/execution/slurm.py index 2a794e30..d98f8784 100644 --- a/nemo_run/core/execution/slurm.py +++ b/nemo_run/core/execution/slurm.py @@ -22,7 +22,7 @@ import warnings from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import Any, Dict, Optional, Type, TypeAlias, Union +from typing import Any, Dict, Optional, Type, Union import invoke from invoke.context import Context @@ -52,7 +52,13 @@ from nemo_run.devspace.base import DevSpace logger = logging.getLogger(__name__) -noquote: TypeAlias = str + + +class NoQuote(str): + """Marker for strings that must bypass shell quoting.""" + + +noquote = NoQuote @dataclass(kw_only=True) @@ -1011,10 +1017,12 @@ def get_container_flags( _srun_args = ["--wait=60", "--kill-on-bad-exit=1"] _srun_args.extend(resource_req.srun_args or []) else: - cmd_stdout = srun_stdout.replace(original_job_name, self.jobs[group_ind]) + cmd_stdout = noquote(srun_stdout.replace(original_job_name, self.jobs[group_ind])) cmd_stderr = stderr_flags.copy() if cmd_stderr: - cmd_stderr[-1] = cmd_stderr[-1].replace(original_job_name, self.jobs[group_ind]) + cmd_stderr[-1] = noquote( + cmd_stderr[-1].replace(original_job_name, self.jobs[group_ind]) + ) _container_flags = get_container_flags( base_mounts=self.executor.container_mounts, src_job_dir=os.path.join( diff --git a/test/core/execution/test_slurm_templates.py b/test/core/execution/test_slurm_templates.py index 5457563d..9192919c 100644 --- a/test/core/execution/test_slurm_templates.py +++ b/test/core/execution/test_slurm_templates.py @@ -593,6 +593,24 @@ def test_group_resource_req_batch_request_materialize( expected = Path(artifact).read_text() assert sbatch_script.strip() == expected.strip() + def test_group_batch_request_quotes_regular_srun_args_but_not_log_paths( + self, + group_slurm_request_with_artifact: tuple[SlurmBatchRequest, str], + ): + group_slurm_request, _ = group_slurm_request_with_artifact + executor = group_slurm_request.executor + group_slurm_request.executor = SlurmExecutor.merge([executor], num_tasks=2) + group_slurm_request.executor.srun_args = ["--comment=hello world"] + self.apply_macros(executor) + + sbatch_script = group_slurm_request.materialize() + + assert "'--comment=hello world'" in sbatch_script + assert ( + "--output /some/job/dir/sample_job/log-your_account-account.sample_job-0_%j_${SLURM_RESTART_COUNT:-0}.out " + "--container-image some-image" + ) in sbatch_script + def test_group_resource_req_request_custom_job_details( self, group_resource_req_slurm_request_with_artifact: tuple[SlurmBatchRequest, str],