diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..9bf0470 Binary files /dev/null and b/.DS_Store differ diff --git a/.gitignore b/.gitignore index 0a1d769..c205cbe 100644 --- a/.gitignore +++ b/.gitignore @@ -175,3 +175,5 @@ cython_debug/ # Node stuff node_modules/ + +.DS_Store \ No newline at end of file diff --git a/agent.py b/agent.py index f650381..fe4275f 100644 --- a/agent.py +++ b/agent.py @@ -27,8 +27,10 @@ import time from rich.console import Console from rich.table import Table +from pathlib import Path from computers import EnvState, Computer +from function_registry import FunctionRegistry MAX_RECENT_TURN_WITH_SCREENSHOTS = 3 PREDEFINED_COMPUTER_USE_FUNCTIONS = [ @@ -55,11 +57,6 @@ FunctionResponseT = Union[EnvState, dict] -def multiply_numbers(x: float, y: float) -> dict: - """Multiplies two numbers.""" - return {"result": x * y} - - class BrowserAgent: def __init__( self, @@ -79,6 +76,14 @@ def __init__( project=os.environ.get("VERTEXAI_PROJECT"), location=os.environ.get("VERTEXAI_LOCATION"), ) + config_path = os.environ.get( + "FUNCTION_CONFIG_PATH", + str(Path(__file__).parent / "config" / "functions.json"), + ) + self._function_registry = FunctionRegistry( + config_path=config_path, + client=self._client, + ) self._contents: list[Content] = [ Content( role="user", @@ -91,13 +96,7 @@ def __init__( # Exclude any predefined functions here. excluded_predefined_functions = [] - # Add your own custom functions here. - custom_functions = [ - # For example: - types.FunctionDeclaration.from_callable( - client=self._client, callable=multiply_numbers - ) - ] + custom_functions = self._function_registry.function_declarations() self._generate_content_config = GenerateContentConfig( temperature=1, @@ -190,9 +189,15 @@ def handle_action(self, action: types.FunctionCall) -> FunctionResponseT: destination_x=destination_x, destination_y=destination_y, ) - # Handle the custom function declarations here. - elif action.name == multiply_numbers.__name__: - return multiply_numbers(x=action.args["x"], y=action.args["y"]) + elif self._function_registry.has_function(action.name): + if not self._function_registry.is_whitelisted(action.name): + if not self._confirm_custom_function(action): + termcolor.cprint( + f"Custom function {action.name} denied by user.", + color="yellow", + ) + return {"status": "rejected", "reason": "user_denied"} + return self._function_registry.execute(action.name, action.args) else: raise ValueError(f"Unsupported function: {action}") @@ -389,6 +394,25 @@ def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]: return "CONTINUE" + def _confirm_custom_function( + self, action: types.FunctionCall + ) -> bool: + """Prompt user before executing non-whitelisted custom functions.""" + termcolor.cprint( + "Custom function requires confirmation!", + color="yellow", + attrs=["bold"], + ) + print(f"Function: {action.name}") + print(f"Args: {action.args}") + risk_note = self._function_registry.risk_note(action.name) + if risk_note: + print(f"Risk: {risk_note}") + decision = "" + while decision.lower() not in ("y", "n", "ye", "yes", "no"): + decision = input("Do you wish to execute? [Yes]/[No]\n") + return decision.lower() in ("y", "ye", "yes") + def _get_safety_confirmation( self, safety: dict[str, Any] ) -> Literal["CONTINUE", "TERMINATE"]: diff --git a/config/functions.json b/config/functions.json new file mode 100644 index 0000000..61db1be --- /dev/null +++ b/config/functions.json @@ -0,0 +1,13 @@ +{ + "functions": [ + { + "name": "multiply_numbers", + "module": "custom_functions.math", + "attribute": "multiply_numbers", + "description": "Multiply two numbers and return the product.", + "whitelist": true, + "risk_note": "Safe arithmetic operation." + } + ] +} + diff --git a/custom_functions/__init__.py b/custom_functions/__init__.py new file mode 100644 index 0000000..5f716d7 --- /dev/null +++ b/custom_functions/__init__.py @@ -0,0 +1,2 @@ +# Placeholder module for custom function plugins. + diff --git a/custom_functions/math.py b/custom_functions/math.py new file mode 100644 index 0000000..7294573 --- /dev/null +++ b/custom_functions/math.py @@ -0,0 +1,4 @@ +def multiply_numbers(x: float, y: float) -> dict: + """Multiplies two numbers.""" + return {"result": x * y} + diff --git a/function_registry.py b/function_registry.py new file mode 100644 index 0000000..20ab315 --- /dev/null +++ b/function_registry.py @@ -0,0 +1,152 @@ +import importlib +import inspect +import json +import os +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +import termcolor +from google.genai import types + + +@dataclass +class FunctionSpec: + """Configuration for a single custom function.""" + + name: str + module: str + attribute: str + description: Optional[str] + whitelist: bool + risk_note: Optional[str] + + +class FunctionRegistry: + """Loads custom functions from configuration and exposes declarations/execution.""" + + def __init__(self, config_path: str, client: Any): + self._config_path = config_path + self._client = client + self._specs: Dict[str, FunctionSpec] = {} + self._callables: Dict[str, Callable] = {} + self._load_config() + + def _load_config(self) -> None: + """Load function specs and resolve callables.""" + if not os.path.exists(self._config_path): + termcolor.cprint( + f"Function config not found at {self._config_path}; no custom functions loaded.", + color="yellow", + ) + return + + try: + with open(self._config_path, "r", encoding="utf-8") as config_file: + config = json.load(config_file) or {} + except Exception as exc: + termcolor.cprint( + f"Failed to read function config {self._config_path}: {exc}", + color="red", + ) + return + + for entry in config.get("functions", []): + try: + spec = FunctionSpec( + name=entry["name"], + module=entry["module"], + attribute=entry.get("attribute", entry["name"]), + description=entry.get("description"), + whitelist=bool(entry.get("whitelist", False)), + risk_note=entry.get("risk_note"), + ) + except KeyError as exc: + termcolor.cprint( + f"Invalid function config entry missing required key {exc}: {entry}", + color="red", + ) + continue + + resolved = self._import_callable(spec) + if resolved: + self._specs[spec.name] = spec + self._callables[spec.name] = resolved + + def _import_callable(self, spec: FunctionSpec) -> Optional[Callable]: + """Import callable from module according to spec.""" + try: + module = importlib.import_module(spec.module) + except Exception as exc: + termcolor.cprint( + f"Failed to import module {spec.module} for {spec.name}: {exc}", + color="red", + ) + return None + + try: + func = getattr(module, spec.attribute) + except AttributeError: + termcolor.cprint( + f"Attribute {spec.attribute} not found in module {spec.module}", + color="red", + ) + return None + + if not callable(func): + termcolor.cprint( + f"{spec.attribute} in module {spec.module} is not callable", + color="red", + ) + return None + + if spec.description and not (func.__doc__ and func.__doc__.strip()): + func.__doc__ = spec.description + return func + + def function_declarations(self) -> List[types.FunctionDeclaration]: + """Create function declarations for all loaded functions.""" + declarations: List[types.FunctionDeclaration] = [] + for name, func in self._callables.items(): + try: + declarations.append( + types.FunctionDeclaration.from_callable( + client=self._client, + callable=func, + ) + ) + except Exception as exc: + termcolor.cprint( + f"Failed to build declaration for {name}: {exc}", + color="red", + ) + return declarations + + def has_function(self, name: str) -> bool: + return name in self._callables + + def is_whitelisted(self, name: str) -> bool: + return bool(self._specs.get(name) and self._specs[name].whitelist) + + def risk_note(self, name: str) -> Optional[str]: + spec = self._specs.get(name) + if not spec: + return None + return spec.risk_note + + def execute(self, name: str, args: dict) -> dict: + if name not in self._callables: + raise ValueError(f"Function {name} is not registered.") + + func = self._callables[name] + signature = inspect.signature(func) + try: + bound_args = signature.bind(**args) + bound_args.apply_defaults() + except TypeError as exc: + termcolor.cprint( + f"Invalid arguments for {name}: {exc}", + color="red", + ) + raise + return func(*bound_args.args, **bound_args.kwargs) + diff --git a/test_agent.py b/test_agent.py index c001493..423b7b8 100644 --- a/test_agent.py +++ b/test_agent.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os +import tempfile import unittest from unittest.mock import MagicMock, patch from google.genai import types -from agent import BrowserAgent, multiply_numbers +from agent import BrowserAgent from computers import EnvState +from function_registry import FunctionRegistry +from custom_functions.math import multiply_numbers class TestBrowserAgent(unittest.TestCase): def setUp(self): @@ -62,6 +66,44 @@ def test_handle_action_navigate(self): self.agent.handle_action(action) self.mock_browser_computer.navigate.assert_called_once_with("https://example.com") + def test_function_registry_load_and_execute(self): + config_payload = { + "functions": [ + { + "name": "multiply_numbers", + "module": "custom_functions.math", + "attribute": "multiply_numbers", + "description": "Multiply two numbers.", + "whitelist": True, + } + ] + } + with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as temp_config: + json.dump(config_payload, temp_config) + temp_path = temp_config.name + registry = FunctionRegistry(config_path=temp_path, client=MagicMock()) + self.assertTrue(registry.has_function("multiply_numbers")) + self.assertTrue(registry.is_whitelisted("multiply_numbers")) + self.assertEqual( + registry.execute("multiply_numbers", {"x": 2, "y": 3}), + {"result": 6}, + ) + os.remove(temp_path) + + @patch("agent.input", return_value="yes") + def test_handle_action_custom_function_requires_confirmation(self, mock_input): + mock_registry = MagicMock() + mock_registry.has_function.return_value = True + mock_registry.is_whitelisted.return_value = False + mock_registry.risk_note.return_value = "Risky operation" + mock_registry.execute.return_value = {"status": "ok"} + self.agent._function_registry = mock_registry + action = types.FunctionCall(name="custom_fn", args={"x": 1}) + result = self.agent.handle_action(action) + self.assertEqual(result, {"status": "ok"}) + mock_registry.execute.assert_called_once_with("custom_fn", {"x": 1}) + mock_input.assert_called() + def test_handle_action_unknown_function(self): action = types.FunctionCall(name="unknown_function", args={}) with self.assertRaises(ValueError):