Skip to content

Commit 26715cb

Browse files
[Alctz] Various updates to alcatraz comp interface (#87)
1 parent 9b89ef5 commit 26715cb

5 files changed

Lines changed: 191 additions & 110 deletions

File tree

project/common/nanoeval_alcatraz/nanoeval_alcatraz/alcatraz_computer_interface.py

Lines changed: 106 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
1+
import os
12
from abc import ABC, abstractmethod
23
from contextlib import asynccontextmanager
3-
from typing import Any, AsyncContextManager, AsyncGenerator, Literal, cast
4+
from typing import Any, AsyncGenerator, Literal, cast
45

56
import async_lru
67
import structlog.stdlib
@@ -9,18 +10,24 @@
910

1011
import chz
1112
from alcatraz.clusters.local import BaseAlcatrazCluster, ClusterConfig, LocalConfig
12-
from nanoeval_alcatraz.task_to_alcatraz_config import task_to_alcatraz_config
1313
from alcatraz.utils.cmds import run_command_streaming
1414
from nanoeval.solvers.computer_tasks.code_execution_interface import (
1515
ComputerConfiguration,
16+
ComputerInterface,
1617
ComputerRuntime,
1718
ExecutionResult,
1819
JupyterComputerInterface,
1920
JupyterExecutionResult,
2021
)
22+
from nanoeval_alcatraz.task_to_alcatraz_config import (
23+
SUPPORTED_NETWORK_MODES,
24+
task_to_alcatraz_config,
25+
)
2126

2227
logger = structlog.stdlib.get_logger(component=__name__)
2328

29+
ALCATRAZ_TIMEOUT = int(os.getenv("ALCATRAZ_TIMEOUT", 120))
30+
2431

2532
class Python3ExceptionDict(BaseModel):
2633
"""A pydantic model for serializing a Python 3.x exception.
@@ -38,10 +45,14 @@ class Python3ExceptionDict(BaseModel):
3845
notes: list[str]
3946

4047

41-
class BaseAlcatrazComputerInterface(JupyterComputerInterface, ABC):
48+
class BaseAlcatrazComputerInterfaceNoJupyter(ComputerInterface, ABC):
49+
"""
50+
Override this class to create a custom AlcatrazComputerInterface without Jupyter support.
51+
"""
52+
4253
@property
4354
@abstractmethod
44-
def cluster(self) -> BaseAlcatrazCluster:
55+
def _cluster(self) -> BaseAlcatrazCluster:
4556
pass
4657

4758
async def disable_internet(self) -> None:
@@ -52,29 +63,44 @@ async def disable_internet(self) -> None:
5263

5364
# Verify
5465
logger.info("Post-setup network access disabled")
55-
logger.info("Verified network access successfully disabled")
66+
try:
67+
from alcatraz.utils.network import assert_internet_disabled # type: ignore
68+
69+
await assert_internet_disabled(self._cluster)
70+
logger.info("Verified network access successfully disabled")
71+
except ImportError:
72+
pass
5673

5774
async def upload(self, file: bytes, destination: str) -> None:
58-
return await self.cluster.upload(file, destination)
75+
return await self._cluster.upload(file, destination)
5976

6077
async def download(self, file: str) -> bytes:
61-
return await self.cluster.download(file)
78+
return await self._cluster.download(file)
6279

63-
async def send_shell_command(self, cmd: str, idempotent: bool = False) -> ExecutionResult:
64-
res = await run_command_streaming(self.cluster, cmd)
80+
async def send_shell_command(self, cmd: str, *, idempotent: bool = False) -> ExecutionResult:
81+
res = await run_command_streaming(self._cluster, cmd)
6582
return ExecutionResult(output=res["result"], exit_code=res["exit_code"])
6683

6784
async def fetch_container_names(self) -> list[str]:
68-
return await self.cluster.fetch_container_names()
85+
return await self._cluster.fetch_container_names()
6986

7087
async def stop(self) -> None:
71-
await self.cluster._stop()
88+
await self._cluster._stop()
89+
90+
91+
@chz.chz
92+
class BaseAlcatrazComputerInterface(
93+
BaseAlcatrazComputerInterfaceNoJupyter, JupyterComputerInterface
94+
):
95+
"""
96+
Override this class to add Jupyter-specific functionality to the AlcatrazComputerInterface.
97+
"""
7298

7399
@override
74-
async def execute(self, code: str, timeout: int = 120) -> JupyterExecutionResult:
75-
await self._start_cluster_once()
100+
async def execute(self, code: str, timeout: int = ALCATRAZ_TIMEOUT) -> JupyterExecutionResult:
101+
await self._start_jupyter_kernel_once()
76102

77-
messages = await self.cluster.send_kernel_command(code, timeout=timeout)
103+
messages = await self._cluster.send_kernel_command(code, timeout=timeout)
78104

79105
# Parse the messages into a final execution result
80106
# TODO(kevinliu) - this may not be a perfect parsing, but it should only really be used for setup and grade so hopefully it's good enough
@@ -108,42 +134,89 @@ async def execute(self, code: str, timeout: int = 120) -> JupyterExecutionResult
108134
)
109135

110136
@async_lru.alru_cache(maxsize=1)
111-
async def _start_cluster_once(self) -> None:
112-
if not await self.cluster.is_kernel_started():
113-
await self.cluster.create_kernel_on_machine()
137+
async def _start_jupyter_kernel_once(self) -> None:
138+
if not await self._cluster.is_kernel_started():
139+
await self._cluster.create_kernel_on_machine()
114140

115141

116142
@chz.chz
117-
class AlcatrazComputerInterface(BaseAlcatrazComputerInterface):
143+
class AlcatrazComputerInterfaceNoJupyter(BaseAlcatrazComputerInterfaceNoJupyter):
118144
cluster_value: BaseAlcatrazCluster
119145

120146
@property
121-
def cluster(self) -> BaseAlcatrazCluster:
147+
def _cluster(self) -> BaseAlcatrazCluster:
122148
return self.cluster_value
123149

124150

125151
@chz.chz
126-
class AlcatrazComputerRuntime(ComputerRuntime):
152+
class AlcatrazComputerInterface(AlcatrazComputerInterfaceNoJupyter, BaseAlcatrazComputerInterface):
153+
pass
154+
155+
156+
@chz.chz
157+
class AlcatrazComputerRuntimeNoJupyter(ComputerRuntime):
127158
env: ClusterConfig = chz.field(default_factory=LocalConfig)
128159

160+
async def _do_runtime_setup(
161+
self, task: ComputerConfiguration, computer: ComputerInterface
162+
) -> None:
163+
assert isinstance(computer, BaseAlcatrazComputerInterfaceNoJupyter)
164+
# Alcatraz must do this dynamically since it doesn't have the ability to remove
165+
# a container's internet access at creation time.
166+
# Other runtimes should do this by configuring the container object itself.
167+
if not task.allow_internet:
168+
logger.info("Disabling internet (since allow_internet is False)")
169+
await computer.disable_internet()
170+
171+
if task.network_mode not in SUPPORTED_NETWORK_MODES:
172+
raise ValueError(
173+
f"The {task.network_mode} network mode is not supported on Alcatraz, and will never be. Please change it, or stop using Alcatraz."
174+
)
175+
129176
@asynccontextmanager
130-
async def run(
177+
async def _start_computer(
131178
self, task: ComputerConfiguration
132-
) -> AsyncGenerator[AlcatrazComputerInterface, None]:
133-
async with task_to_alcatraz_config(task, self.env).build() as cluster:
134-
computer = AlcatrazComputerInterface(cluster_value=cluster)
179+
) -> AsyncGenerator[AlcatrazComputerInterfaceNoJupyter, None]:
180+
async with task_to_alcatraz_config(task, self.env).build() as _cluster:
181+
computer = AlcatrazComputerInterfaceNoJupyter(cluster_value=_cluster)
135182
yield computer
136183

137-
@override
184+
185+
@chz.chz
186+
class AlcatrazComputerRuntime(ComputerRuntime):
187+
env: ClusterConfig = chz.field(default_factory=LocalConfig)
188+
138189
async def _do_runtime_setup(
139-
self, task: ComputerConfiguration, computer: AlcatrazComputerInterface
190+
self, task: ComputerConfiguration, computer: ComputerInterface
140191
) -> None:
141-
"""No-op, we don't use this but need to implement the abstract method."""
142-
return
192+
assert isinstance(computer, BaseAlcatrazComputerInterface)
193+
194+
# Run a jupyter command; this will force Jupyter to start up and be installed.
195+
# This is an Alcatraz-only feature. It must be done before Internet access is disabled, because
196+
# Alcatraz supports live-installing Jupyter.
197+
# TODO(kevinliu) we should catalog which evals rely on this and remove it.
198+
logger.info("Running initial Jupyter command to ensure Jupyter is installed")
199+
await computer.execute("print('hi')")
200+
logger.info("Jupyter working!")
201+
202+
# Alcatraz must do this dynamically since it doesn't have the ability to remove
203+
# a container's internet access at creation time.
204+
# Other runtimes should do this by configuring the container object itself.
205+
if not task.allow_internet:
206+
logger.info("Disabling internet (since allow_internet is False)")
207+
await computer.disable_internet()
208+
209+
if task.network_mode not in SUPPORTED_NETWORK_MODES:
210+
raise ValueError(
211+
f"The {task.network_mode} network mode is not supported on Alcatraz, and will never be. Please change it, or stop using Alcatraz."
212+
)
143213

144-
@override
145-
def _start_computer(
214+
@asynccontextmanager
215+
async def _start_computer(
146216
self, task: ComputerConfiguration
147-
) -> AsyncContextManager[AlcatrazComputerInterface]:
148-
"""No-op, we don't use this but need to implement the abstract method."""
149-
return
217+
) -> AsyncGenerator[AlcatrazComputerInterface, None]:
218+
async with task_to_alcatraz_config(task, self.env).build() as _cluster:
219+
computer = AlcatrazComputerInterface(cluster_value=_cluster)
220+
yield computer
221+
222+

project/common/nanoeval_alcatraz/nanoeval_alcatraz/task_to_alcatraz_config.py

Lines changed: 25 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,14 +3,24 @@
33
import logging
44

55
from alcatraz.clusters.local import ClusterConfig
6-
from nanoeval.solvers.computer_tasks.code_execution_interface import ComputerConfiguration
6+
from nanoeval.solvers.computer_tasks.code_execution_interface import (
7+
ComputerConfiguration,
8+
NetworkMode,
9+
)
710

811
logger = logging.getLogger(__name__)
912

13+
SUPPORTED_NETWORK_MODES = [NetworkMode.NONE, NetworkMode.UNPROXIED]
14+
1015

1116
def task_to_alcatraz_config(task: ComputerConfiguration, config: ClusterConfig) -> ClusterConfig:
1217
# TODO, we should really just have a ClusterConfig as part of the task itself
1318

19+
if task.network_mode not in SUPPORTED_NETWORK_MODES:
20+
raise ValueError(
21+
f"The {task.network_mode} network mode is not supported on Alcatraz, and will never be. Please change it, or stop using Alcatraz."
22+
)
23+
1424
if task.azure_vm_sku:
1525
logger.info("Using custom azure_vm_sku: %s", task.azure_vm_sku)
1626
config = config.model_copy(update={"azure_vm_sku": task.azure_vm_sku})
@@ -35,11 +45,22 @@ def task_to_alcatraz_config(task: ComputerConfiguration, config: ClusterConfig)
3545
if task.alcatraz_limits:
3646
logger.info("Using custom limits: %s", task.alcatraz_limits)
3747
config = config.model_copy(update={"limits": task.alcatraz_limits})
38-
48+
merged_volumes_config = {}
3949
if task.volumes_config:
4050
logger.info("Using custom volumes_config: %s", task.volumes_config)
41-
config = config.model_copy(update={"volumes_config": task.volumes_config})
42-
51+
merged_volumes_config.update(task.volumes_config)
52+
if task.volume_mounts:
53+
volumes_config = {
54+
str(i): {
55+
"bind_source": volume_mount.host_path,
56+
"bind_dest": volume_mount.container_path,
57+
}
58+
for i, volume_mount in enumerate(task.volume_mounts)
59+
}
60+
logger.info("Using custom volumes_mount: %s", volumes_config)
61+
merged_volumes_config.update(volumes_config)
62+
if merged_volumes_config:
63+
config = config.model_copy(update={"volumes_config": merged_volumes_config})
4364
if task.shm_size:
4465
logger.info("Using custom shm size: %s", task.shm_size)
4566
config = config.model_copy(update={"shm_size": task.shm_size})

project/swelancer/swelancer/eval.py

Lines changed: 55 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -122,54 +122,63 @@ async def _setup(self, computer: ComputerInterface, runtime_config: RuntimeConfi
122122
ctx_logger = logger.bind(
123123
run_group_id=self.run_group_id, runs_dir=self.runs_dir, run_id=self.run_id
124124
)
125+
async with asyncio.timeout(2400):
126+
try:
127+
# Assert various things about the environment
128+
ctx_logger.info("SETUP", destinations=["run"])
129+
if isinstance(computer, JupyterComputerInterface):
130+
await computer.check_execute(
131+
"""import os; assert os.environ.get('CONDA_DEFAULT_ENV') == 'testbed', os.environ.get('CONDA_DEFAULT_ENV')""",
132+
)
133+
ctx_logger.info("Correct Jupyter environment is loaded.", destinations=["run"])
125134

126-
# Assert various things about the environment
127-
ctx_logger.info("SETUP", destinations=["run"])
128-
if isinstance(computer, JupyterComputerInterface):
129-
await computer.check_execute(
130-
"""import os; assert os.environ.get('CONDA_DEFAULT_ENV') == 'testbed', os.environ.get('CONDA_DEFAULT_ENV')""",
131-
)
132-
ctx_logger.info("Correct Jupyter environment is loaded.", destinations=["run"])
133-
134-
if self.variant == "swe_manager":
135-
# reset to the broken commit
136-
await computer.check_shell_command(f"bash -c 'git reset --hard {self.manager_commit}^'")
137-
else:
138-
# instead of a time-based check, poll /setup_done.txt until it contains "done"
139-
ctx_logger.info("Waiting for setup to be done", destinations=["run"])
140-
while True:
141-
res = await computer.send_shell_command("cat /setup_done.txt")
142-
if res.output.decode("utf-8").strip() == "done":
143-
ctx_logger.info("Setup is done", destinations=["run"])
144-
break
145-
await asyncio.sleep(15)
146-
147-
# power user flag: if using a stale image where zip and unzip are not installed
148-
if os.getenv("INSTALL_ZIP_UNZIP", "false").lower() == "true":
149-
await computer.check_shell_command("apt-get update && apt-get install -y zip unzip")
150-
151-
# Zip the tests directory with a password
152-
ctx_logger.info("Zipping tests", destinations=["run"])
153-
await computer.check_shell_command(
154-
"cd /app/tests && zip -P 'secretpasswordhere' -r /app/tmp.zip . && mv /app/tmp.zip /app/tests.zip"
155-
)
156-
# Clear tests directory
157-
await computer.send_shell_command("rm -rf /app/tests")
158-
159-
# Remove existing git repo and create a new one
160-
await computer.check_shell_command("rm -rf .git")
161-
await computer.check_shell_command("git init")
162-
163-
# Create a temp commit to mark where the model's changes begin
164-
if self.variant == "ic_swe":
165-
await computer.check_shell_command("bash -c 'git add .'")
166-
await computer.send_shell_command(
167-
'bash -c \'git -c user.name="temp" -c user.email="temp@example.com" commit -m "temp"\''
168-
)
169-
ctx_logger.info("Temp commit created", destinations=["run"])
135+
if self.variant == "swe_manager":
136+
# reset to the broken commit
137+
await computer.check_shell_command(
138+
f"bash -c 'git reset --hard {self.manager_commit}^'"
139+
)
140+
else:
141+
# instead of a time-based check, poll /setup_done.txt until it contains "done"
142+
ctx_logger.info("Waiting for setup to be done", destinations=["run"])
143+
while True:
144+
res = await computer.send_shell_command("cat /setup_done.txt")
145+
if res.output.decode("utf-8").strip() == "done":
146+
ctx_logger.info("Setup is done", destinations=["run"])
147+
break
148+
await asyncio.sleep(15)
149+
150+
# power user flag: if using a stale image where zip and unzip are not installed
151+
if os.getenv("INSTALL_ZIP_UNZIP", "false").lower() == "true":
152+
await computer.check_shell_command(
153+
"apt-get update && apt-get install -y zip unzip"
154+
)
170155

171-
if self.disable_internet:
172-
await computer.disable_internet()
156+
# Zip the tests directory with a password
157+
ctx_logger.info("Zipping tests", destinations=["run"])
158+
await computer.check_shell_command(
159+
"cd /app/tests && zip -P 'secretpasswordhere' -r /app/tmp.zip . && mv /app/tmp.zip /app/tests.zip"
160+
)
161+
# Clear tests directory
162+
await computer.send_shell_command("rm -rf /app/tests")
163+
164+
# Remove existing git repo and create a new one
165+
await computer.check_shell_command("rm -rf .git")
166+
await computer.check_shell_command("git init")
167+
168+
# Create a temp commit to mark where the model's changes begin
169+
if self.variant == "ic_swe":
170+
await computer.check_shell_command("bash -c 'git add .'")
171+
await computer.send_shell_command(
172+
'bash -c \'git -c user.name="temp" -c user.email="temp@example.com" commit -m "temp"\''
173+
)
174+
ctx_logger.info("Temp commit created", destinations=["run"])
175+
176+
if self.disable_internet:
177+
await computer.disable_internet()
178+
except Exception as e:
179+
ctx_logger.exception("An error occurred during setup", destinations=["run"])
180+
raise RolloutSystemError(f"An error occurred during setup: {e}") from e
181+
ctx_logger.info("Setup complete", destinations=["run"])
173182

174183
@override
175184
async def grade(

0 commit comments

Comments
 (0)