From 0340f47edb55ebd8b4c8933d1327b1794afbde60 Mon Sep 17 00:00:00 2001 From: njx <3771829673@qq.com> Date: Thu, 16 Oct 2025 13:14:24 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E9=9C=80?= =?UTF-8?q?=E8=A6=81=E5=BF=BD=E7=95=A5=E7=9A=84=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .DS_Store | Bin 0 -> 10244 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..391e98c3d6f8555eb2c504cdbcb0157f12ceb938 GIT binary patch literal 10244 zcmeHMziSjh6n=AFewjoAK`@nvg@{Nnf`W}~k~=~m8Zl~Ve%*OFxVw${0jBcWG$MkX zg{6Xph?SinSV+Le+Df#@UtpnrZ)UPLZ*O;#N{r0F?Ax1pZ}eaEi<*>1;TXV<0xzA{X#iBg}=>78S%IA^UkK?kxoY}v+;yk7eX42-~>Ry##u*lpT%G*PwEf zLRzDWIzjPpRN#*8Qzjx>*-08jm!v_=RqPQ=NkY- z{Xjgg8AQ0bMPlD(>?xiu?qKB3LpzF6z1x`2Kf1a-Y*?mvOT5n+Zx__9LhELH4yY&S zbw`)uDcO{E<#SvHFCX9CIbkq4@$QI9_K-*y@B(}#Bcx`G$dN=mXJ0vUwhyBwqy}=a zsYlZqt;+~K_V$U8OkeYSPI#`%`0LN_Z!P0Q&vw4DJ%s}1+up4iHR&E!tZIx4XUpXx zd?ky^i|`fwF8*J`Klbp)eA>v Date: Thu, 16 Oct 2025 13:15:01 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E9=9C=80?= =?UTF-8?q?=E8=A6=81=E5=BF=BD=E7=95=A5=E7=9A=84=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 0a1d769..c205cbe 100644 --- a/.gitignore +++ b/.gitignore @@ -175,3 +175,5 @@ cython_debug/ # Node stuff node_modules/ + +.DS_Store \ No newline at end of file From b5395554919abf63700d8137f5b81522fa6ac5f9 Mon Sep 17 00:00:00 2001 From: njx <3771829673@qq.com> Date: Tue, 9 Dec 2025 22:12:32 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E8=A7=84=E8=8C=83=E4=BA=86=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E8=87=AA=E5=AE=9A=E4=B9=89=E5=87=BD=E6=95=B0=E7=9A=84?= =?UTF-8?q?=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .DS_Store | Bin 10244 -> 6148 bytes agent.py | 55 +++++++++---- config/functions.json | 13 +++ custom_functions/__init__.py | 2 + custom_functions/math.py | 4 + function_registry.py | 155 +++++++++++++++++++++++++++++++++++ test_agent.py | 44 +++++++++- 7 files changed, 257 insertions(+), 16 deletions(-) create mode 100644 config/functions.json create mode 100644 custom_functions/__init__.py create mode 100644 custom_functions/math.py create mode 100644 function_registry.py diff --git a/.DS_Store b/.DS_Store index 391e98c3d6f8555eb2c504cdbcb0157f12ceb938..9bf0470da4455594b99d1b671109be3daa6b3440 100644 GIT binary patch delta 266 zcmZn(XfcprU|?W$DortDU=RQ@Ie-{MGjUEV6q~50D98hn2Z`mR8wMxm=N4=%T+Jv8 z5@lh~V@PMnWGI2j0L6f6PYO>C5RhP#x*~ExuViwbfE0|`Cm_wltT6eefFqk>F<4ST zPyxXdnH(i3JlR=9c(Q>AA5+rF$qph)><89?6ds>!FCreaEi<*>1;TXV<0xzA{X#iBg}=>78S%IA^UkK?kxoY}v+;yk7eX42-~>Ry##u*lpT%G*PwEf zLRzDWIzjPpRN#*8Qzjx>*-08jm!v_=RqPQ=NkY- z{Xjgg8AQ0bMPlD(>?xiu?qKB3LpzF6z1x`2Kf1a-Y*?mvOT5n+Zx__9LhELH4yY&S zbw`)uDcO{E<#SvHFCX9CIbkq4@$QI9_K-*y@B(}#Bcx`G$dN=mXJ0vUwhyBwqy}=a zsYlZqt;+~K_V$U8OkeYSPI#`%`0LN_Z!P0Q&vw4DJ%s}1+up4iHR&E!tZIx4XUpXx zd?ky^i|`fwF8*J`Klbp)eA>v dict: - """Multiplies two numbers.""" - return {"result": x * y} - - class BrowserAgent: def __init__( self, @@ -79,6 +76,14 @@ def __init__( project=os.environ.get("VERTEXAI_PROJECT"), location=os.environ.get("VERTEXAI_LOCATION"), ) + config_path = os.environ.get( + "FUNCTION_CONFIG_PATH", + str(Path(__file__).parent / "config" / "functions.json"), + ) + self._function_registry = FunctionRegistry( + config_path=config_path, + client=self._client, + ) self._contents: list[Content] = [ Content( role="user", @@ -91,13 +96,7 @@ def __init__( # Exclude any predefined functions here. excluded_predefined_functions = [] - # Add your own custom functions here. - custom_functions = [ - # For example: - types.FunctionDeclaration.from_callable( - client=self._client, callable=multiply_numbers - ) - ] + custom_functions = self._function_registry.function_declarations() self._generate_content_config = GenerateContentConfig( temperature=1, @@ -187,9 +186,15 @@ def handle_action(self, action: types.FunctionCall) -> FunctionResponseT: destination_x=destination_x, destination_y=destination_y, ) - # Handle the custom function declarations here. - elif action.name == multiply_numbers.__name__: - return multiply_numbers(x=action.args["x"], y=action.args["y"]) + elif self._function_registry.has_function(action.name): + if not self._function_registry.is_whitelisted(action.name): + if not self._confirm_custom_function(action): + termcolor.cprint( + f"Custom function {action.name} denied by user.", + color="yellow", + ) + return {"status": "rejected", "reason": "user_denied"} + return self._function_registry.execute(action.name, action.args) else: raise ValueError(f"Unsupported function: {action}") @@ -386,6 +391,26 @@ def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]: return "CONTINUE" + def _confirm_custom_function( + self, action: types.FunctionCall + ) -> bool: + """Prompt user before executing non-whitelisted custom functions.""" + termcolor.cprint( + "Custom function requires confirmation!", + color="yellow", + attrs=["bold"], + ) + print(f"Function: {action.name}") + print(f"Args: {action.args}") + risk_note = self._function_registry.risk_note(action.name) + if risk_note: + print(f"Risk: {risk_note}") + decision = "" + while decision.lower() not in ("y", "n", "ye", "yes", "no"): + decision = input("Do you wish to execute? [Yes]/[No]\n") + print(f"debugging: user decision for {action.name} -> {decision}") + return decision.lower() in ("y", "yes") + def _get_safety_confirmation( self, safety: dict[str, Any] ) -> Literal["CONTINUE", "TERMINATE"]: diff --git a/config/functions.json b/config/functions.json new file mode 100644 index 0000000..61db1be --- /dev/null +++ b/config/functions.json @@ -0,0 +1,13 @@ +{ + "functions": [ + { + "name": "multiply_numbers", + "module": "custom_functions.math", + "attribute": "multiply_numbers", + "description": "Multiply two numbers and return the product.", + "whitelist": true, + "risk_note": "Safe arithmetic operation." + } + ] +} + diff --git a/custom_functions/__init__.py b/custom_functions/__init__.py new file mode 100644 index 0000000..5f716d7 --- /dev/null +++ b/custom_functions/__init__.py @@ -0,0 +1,2 @@ +# Placeholder module for custom function plugins. + diff --git a/custom_functions/math.py b/custom_functions/math.py new file mode 100644 index 0000000..7294573 --- /dev/null +++ b/custom_functions/math.py @@ -0,0 +1,4 @@ +def multiply_numbers(x: float, y: float) -> dict: + """Multiplies two numbers.""" + return {"result": x * y} + diff --git a/function_registry.py b/function_registry.py new file mode 100644 index 0000000..186ed2c --- /dev/null +++ b/function_registry.py @@ -0,0 +1,155 @@ +import importlib +import inspect +import json +import os +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +import termcolor +from google.genai import types + + +@dataclass +class FunctionSpec: + """Configuration for a single custom function.""" + + name: str + module: str + attribute: str + description: Optional[str] + whitelist: bool + risk_note: Optional[str] + + +class FunctionRegistry: + """Loads custom functions from configuration and exposes declarations/execution.""" + + def __init__(self, config_path: str, client: Any): + self._config_path = config_path + self._client = client + self._specs: Dict[str, FunctionSpec] = {} + self._callables: Dict[str, Callable] = {} + self._load_config() + + def _load_config(self) -> None: + """Load function specs and resolve callables.""" + if not os.path.exists(self._config_path): + termcolor.cprint( + f"Function config not found at {self._config_path}; no custom functions loaded.", + color="yellow", + ) + return + + try: + with open(self._config_path, "r", encoding="utf-8") as config_file: + config = json.load(config_file) or {} + except Exception as exc: + termcolor.cprint( + f"Failed to read function config {self._config_path}: {exc}", + color="red", + ) + return + + for entry in config.get("functions", []): + try: + spec = FunctionSpec( + name=entry["name"], + module=entry["module"], + attribute=entry.get("attribute", entry["name"]), + description=entry.get("description"), + whitelist=bool(entry.get("whitelist", False)), + risk_note=entry.get("risk_note"), + ) + except KeyError as exc: + termcolor.cprint( + f"Invalid function config entry missing required key {exc}: {entry}", + color="red", + ) + continue + + resolved = self._import_callable(spec) + if resolved: + self._specs[spec.name] = spec + self._callables[spec.name] = resolved + print(f"debugging: loaded custom functions {list(self._specs.keys())}") + + def _import_callable(self, spec: FunctionSpec) -> Optional[Callable]: + """Import callable from module according to spec.""" + try: + module = importlib.import_module(spec.module) + except Exception as exc: + termcolor.cprint( + f"Failed to import module {spec.module} for {spec.name}: {exc}", + color="red", + ) + return None + + try: + func = getattr(module, spec.attribute) + except AttributeError: + termcolor.cprint( + f"Attribute {spec.attribute} not found in module {spec.module}", + color="red", + ) + return None + + if not callable(func): + termcolor.cprint( + f"{spec.attribute} in module {spec.module} is not callable", + color="red", + ) + return None + + if spec.description and not (func.__doc__ and func.__doc__.strip()): + func.__doc__ = spec.description + return func + + def function_declarations(self) -> List[types.FunctionDeclaration]: + """Create function declarations for all loaded functions.""" + declarations: List[types.FunctionDeclaration] = [] + for name, func in self._callables.items(): + try: + declarations.append( + types.FunctionDeclaration.from_callable( + client=self._client, + callable=func, + ) + ) + except Exception as exc: + termcolor.cprint( + f"Failed to build declaration for {name}: {exc}", + color="red", + ) + print(f"debugging: built {len(declarations)} function declarations") + return declarations + + def has_function(self, name: str) -> bool: + return name in self._callables + + def is_whitelisted(self, name: str) -> bool: + return bool(self._specs.get(name) and self._specs[name].whitelist) + + def risk_note(self, name: str) -> Optional[str]: + spec = self._specs.get(name) + if not spec: + return None + return spec.risk_note + + def execute(self, name: str, args: dict) -> dict: + if name not in self._callables: + raise ValueError(f"Function {name} is not registered.") + + func = self._callables[name] + signature = inspect.signature(func) + try: + bound_args = signature.bind(**args) + bound_args.apply_defaults() + except TypeError as exc: + termcolor.cprint( + f"Invalid arguments for {name}: {exc}", + color="red", + ) + raise + print(f"debugging: executing custom function {name} with args {bound_args.arguments}") + return func(*bound_args.args, **bound_args.kwargs) + diff --git a/test_agent.py b/test_agent.py index c001493..423b7b8 100644 --- a/test_agent.py +++ b/test_agent.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os +import tempfile import unittest from unittest.mock import MagicMock, patch from google.genai import types -from agent import BrowserAgent, multiply_numbers +from agent import BrowserAgent from computers import EnvState +from function_registry import FunctionRegistry +from custom_functions.math import multiply_numbers class TestBrowserAgent(unittest.TestCase): def setUp(self): @@ -62,6 +66,44 @@ def test_handle_action_navigate(self): self.agent.handle_action(action) self.mock_browser_computer.navigate.assert_called_once_with("https://example.com") + def test_function_registry_load_and_execute(self): + config_payload = { + "functions": [ + { + "name": "multiply_numbers", + "module": "custom_functions.math", + "attribute": "multiply_numbers", + "description": "Multiply two numbers.", + "whitelist": True, + } + ] + } + with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as temp_config: + json.dump(config_payload, temp_config) + temp_path = temp_config.name + registry = FunctionRegistry(config_path=temp_path, client=MagicMock()) + self.assertTrue(registry.has_function("multiply_numbers")) + self.assertTrue(registry.is_whitelisted("multiply_numbers")) + self.assertEqual( + registry.execute("multiply_numbers", {"x": 2, "y": 3}), + {"result": 6}, + ) + os.remove(temp_path) + + @patch("agent.input", return_value="yes") + def test_handle_action_custom_function_requires_confirmation(self, mock_input): + mock_registry = MagicMock() + mock_registry.has_function.return_value = True + mock_registry.is_whitelisted.return_value = False + mock_registry.risk_note.return_value = "Risky operation" + mock_registry.execute.return_value = {"status": "ok"} + self.agent._function_registry = mock_registry + action = types.FunctionCall(name="custom_fn", args={"x": 1}) + result = self.agent.handle_action(action) + self.assertEqual(result, {"status": "ok"}) + mock_registry.execute.assert_called_once_with("custom_fn", {"x": 1}) + mock_input.assert_called() + def test_handle_action_unknown_function(self): action = types.FunctionCall(name="unknown_function", args={}) with self.assertRaises(ValueError): From 0c47057db599ddbc9e24539fd112dc3d30478fcc Mon Sep 17 00:00:00 2001 From: njx <3771829673@qq.com> Date: Tue, 9 Dec 2025 22:14:52 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E4=BF=AE=E6=94=B9bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent.py | 3 +-- function_registry.py | 3 --- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/agent.py b/agent.py index 3a66f8c..8809806 100644 --- a/agent.py +++ b/agent.py @@ -408,8 +408,7 @@ def _confirm_custom_function( decision = "" while decision.lower() not in ("y", "n", "ye", "yes", "no"): decision = input("Do you wish to execute? [Yes]/[No]\n") - print(f"debugging: user decision for {action.name} -> {decision}") - return decision.lower() in ("y", "yes") + return decision.lower() in ("y", "ye", "yes") def _get_safety_confirmation( self, safety: dict[str, Any] diff --git a/function_registry.py b/function_registry.py index 186ed2c..20ab315 100644 --- a/function_registry.py +++ b/function_registry.py @@ -71,7 +71,6 @@ def _load_config(self) -> None: if resolved: self._specs[spec.name] = spec self._callables[spec.name] = resolved - print(f"debugging: loaded custom functions {list(self._specs.keys())}") def _import_callable(self, spec: FunctionSpec) -> Optional[Callable]: """Import callable from module according to spec.""" @@ -120,7 +119,6 @@ def function_declarations(self) -> List[types.FunctionDeclaration]: f"Failed to build declaration for {name}: {exc}", color="red", ) - print(f"debugging: built {len(declarations)} function declarations") return declarations def has_function(self, name: str) -> bool: @@ -150,6 +148,5 @@ def execute(self, name: str, args: dict) -> dict: color="red", ) raise - print(f"debugging: executing custom function {name} with args {bound_args.arguments}") return func(*bound_args.args, **bound_args.kwargs)