Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1346,11 +1346,10 @@ The `graph` tool uses the same model provider environment variables as `use_agen
| Environment Variable | Description | Default |
|----------------------|-------------|---------|
| MONGODB_ATLAS_CLUSTER_URI | MongoDB Atlas connection string | None |
| MONGODB_DEFAULT_DATABASE | Default database name for MongoDB operations | memories |
| MONGODB_DEFAULT_COLLECTION | Default collection name for MongoDB operations | user_memories |
| MONGODB_DEFAULT_NAMESPACE | Default namespace for memory isolation | default |
| MONGODB_DEFAULT_MAX_RESULTS | Default maximum results for list operations | 50 |
| MONGODB_DEFAULT_MIN_SCORE | Default minimum relevance score for filtering results | 0.4 |
| MONGODB_DATABASE_NAME | Database name for MongoDB operations | strands_memory |
| MONGODB_COLLECTION_NAME | Collection name for MongoDB operations | memories |
| MONGODB_NAMESPACE | Namespace for memory isolation | default |
| MONGODB_EMBEDDING_MODEL | Amazon Bedrock model for embeddings | amazon.titan-embed-text-v2:0 |

**Note**: This tool requires AWS account credentials to generate embeddings using Amazon Bedrock Titan models.

Expand Down
6 changes: 5 additions & 1 deletion docs/mongodb_memory_tool.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ result = mongodb_memory(

### Environment Variables

You can also use environment variables for configuration:
You can use environment variables for configuration:

```bash
export MONGODB_ATLAS_CLUSTER_URI="mongodb+srv://user:password@cluster.mongodb.net/"
Expand All @@ -123,6 +123,10 @@ export MONGODB_EMBEDDING_MODEL="amazon.titan-embed-text-v2:0"
export AWS_REGION="us-west-2"
```

**Note:** Environment variables take precedence over tool parameters in the standalone function.
To let the agent control the connection target, use the class-based approach or leave the
environment variable unset.

Then use the tool with minimal parameters (environment variables will be used):

```python
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,12 @@ def write_files(self, action: WriteFilesAction) -> Dict[str, Any]:

logger.debug(f"Writing {len(action.content)} files to session '{session_name}'")

content_dicts = [{"path": fc.path, "text": fc.text} for fc in action.content]
content_dicts = []
for fc in action.content:
if fc.blob is not None:
content_dicts.append({"path": fc.path, "blob": fc.blob})
else:
content_dicts.append({"path": fc.path, "text": fc.text})
params = {"content": content_dicts}
response = self._sessions[session_name].client.invoke("writeFiles", params)

Expand Down
19 changes: 17 additions & 2 deletions src/strands_tools/code_interpreter/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from enum import Enum
from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator


class LanguageType(str, Enum):
Expand All @@ -24,7 +24,22 @@ class FileContent(BaseModel):
updating files during code execution sessions."""

path: str = Field(description="The file path where content should be written")
text: str = Field(description="Text content for the file")
text: Optional[str] = Field(
default=None,
description="Text content for the file",
)
blob: Optional[bytes] = Field(
default=None,
description="Base64-encoded binary content for the file",
)

@model_validator(mode="after")
def validate_content(self) -> "FileContent":
if self.text is None and self.blob is None:
raise ValueError("Either text or blob must be provided")
if self.text is not None and self.blob is not None:
raise ValueError("Only one of text or blob may be provided")
return self


# Action-specific Pydantic models using discriminated unions
Expand Down
31 changes: 9 additions & 22 deletions src/strands_tools/elasticsearch_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,14 @@
from strands import Agent
from strands_tools.elasticsearch_memory import elasticsearch_memory

# Create agent with direct tool usage
# Create agent with elasticsearch_memory tool (credentials via env vars)
agent = Agent(tools=[elasticsearch_memory])

# Store a memory with semantic embeddings
elasticsearch_memory(
action="record",
content="User prefers vegetarian pizza with extra cheese",
metadata={"category": "food_preferences", "type": "dietary"},
cloud_id="your-elasticsearch-cloud-id",
api_key="your-api-key",
index_name="memories",
namespace="user_123"
)
Expand All @@ -55,8 +53,6 @@
action="retrieve",
query="food preferences and dietary restrictions",
max_results=5,
cloud_id="your-elasticsearch-cloud-id",
api_key="your-api-key",
index_name="memories",
namespace="user_123"
)
Expand All @@ -65,8 +61,6 @@
elasticsearch_memory(
action="list",
max_results=10,
cloud_id="your-elasticsearch-cloud-id",
api_key="your-api-key",
index_name="memories",
namespace="user_123"
)
Expand All @@ -75,8 +69,6 @@
elasticsearch_memory(
action="get",
memory_id="mem_1234567890_abcd1234",
cloud_id="your-elasticsearch-cloud-id",
api_key="your-api-key",
index_name="memories",
namespace="user_123"
)
Expand All @@ -85,8 +77,6 @@
elasticsearch_memory(
action="delete",
memory_id="mem_1234567890_abcd1234",
cloud_id="your-elasticsearch-cloud-id",
api_key="your-api-key",
index_name="memories",
namespace="user_123"
)
Expand Down Expand Up @@ -599,9 +589,6 @@ def elasticsearch_memory(
max_results: Optional[int] = None,
next_token: Optional[str] = None,
metadata: Optional[Dict] = None,
cloud_id: Optional[str] = None,
api_key: Optional[str] = None,
es_url: Optional[str] = None,
index_name: Optional[str] = None,
namespace: Optional[str] = None,
embedding_model: Optional[str] = None,
Expand Down Expand Up @@ -639,6 +626,10 @@ def elasticsearch_memory(
- delete: Remove a specific memory
Use this to delete memories that are no longer needed.

Connection credentials are read from environment variables:
- ELASTICSEARCH_CLOUD_ID or ELASTICSEARCH_URL for connection
- ELASTICSEARCH_API_KEY for authentication

Args:
action: The memory operation to perform (one of: "record", "retrieve", "list", "get", "delete")
content: For record action: Text content to store as a memory
Expand All @@ -647,9 +638,6 @@ def elasticsearch_memory(
max_results: Maximum number of results to return (optional, default: 10)
next_token: Pagination token for list action (optional)
metadata: Additional metadata to store with the memory (optional)
cloud_id: Elasticsearch Cloud ID for connection (optional if es_url provided)
api_key: Elasticsearch API key for authentication
es_url: Elasticsearch URL for serverless connection (optional if cloud_id provided)
index_name: Name of the Elasticsearch index (defaults to 'strands_memory')
namespace: Namespace for memory operations (defaults to 'default')
embedding_model: Amazon Bedrock model for embeddings (defaults to Titan)
Expand All @@ -658,12 +646,11 @@ def elasticsearch_memory(
Returns:
Dict: Response containing the requested memory information or operation status
"""
try:
# Get values from environment variables if not provided
cloud_id = cloud_id or os.getenv("ELASTICSEARCH_CLOUD_ID")
es_url = es_url or os.getenv("ELASTICSEARCH_URL")
api_key = api_key or os.getenv("ELASTICSEARCH_API_KEY")
cloud_id = os.getenv("ELASTICSEARCH_CLOUD_ID")
es_url = os.getenv("ELASTICSEARCH_URL")
api_key = os.getenv("ELASTICSEARCH_API_KEY")

try:
# Validate required parameters
if not api_key:
return {"status": "error", "content": [{"text": "api_key is required"}]}
Expand Down
11 changes: 10 additions & 1 deletion src/strands_tools/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,16 @@


# Protected variables that can't be modified
PROTECTED_VARS = {"PATH", "PYTHONPATH", "STRANDS_HOME", "SHELL", "USER", "HOME", "BYPASS_TOOL_CONSENT"}
PROTECTED_VARS = {
"PATH",
"PYTHONPATH",
"STRANDS_HOME",
"SHELL",
"USER",
"HOME",
"BYPASS_TOOL_CONSENT",
"STRANDS_NON_INTERACTIVE",
}


def mask_sensitive_value(name: str, value: str) -> str:
Expand Down
4 changes: 3 additions & 1 deletion src/strands_tools/file_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,9 @@ def find_files(console: Console, pattern: str, recursive: bool = True) -> List[s
elif os.path.isdir(pattern):
matching_files = []

for root, _dirs, files in os.walk(pattern):
for root, dirs, files in os.walk(pattern):
dirs[:] = [d for d in dirs if not d.startswith(".")]

if not recursive and root != pattern:
continue

Expand Down
22 changes: 13 additions & 9 deletions src/strands_tools/mongodb_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,9 +1006,12 @@ def mongodb_memory(
max_results: Maximum number of results to return (optional, default: 10)
next_token: Pagination token for list action (optional)
metadata: Additional metadata to store with the memory (optional)
cluster_uri: MongoDB Atlas cluster URI (optional if set via environment)
database_name: Name of the MongoDB database (optional, defaults to 'strands_memory')
collection_name: Name of the MongoDB collection (optional, defaults to 'memories')
cluster_uri: MongoDB Atlas cluster URI. If the MONGODB_ATLAS_CLUSTER_URI environment variable
is set, it takes precedence and this parameter is ignored.
database_name: Name of the MongoDB database. If the MONGODB_DATABASE_NAME environment variable
is set, it takes precedence. Defaults to 'strands_memory'.
collection_name: Name of the MongoDB collection. If the MONGODB_COLLECTION_NAME environment
variable is set, it takes precedence. Defaults to 'memories'.
namespace: Namespace for memory operations (defaults to 'default')
embedding_model: Amazon Bedrock model for embeddings (defaults to Titan)
region: AWS region for Bedrock service (defaults to 'us-west-2')
Expand All @@ -1018,12 +1021,13 @@ def mongodb_memory(
Dict: Response containing the requested memory information or operation status
"""
try:
# Get values from environment variables if not provided
cluster_uri = cluster_uri or os.getenv("MONGODB_ATLAS_CLUSTER_URI")
database_name = database_name or os.getenv("MONGODB_DATABASE_NAME", DEFAULT_DATABASE_NAME)
collection_name = collection_name or os.getenv("MONGODB_COLLECTION_NAME", DEFAULT_COLLECTION_NAME)
embedding_model = embedding_model or os.getenv("MONGODB_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)
region = region or os.getenv("AWS_REGION", DEFAULT_AWS_REGION)
# Environment variables take precedence over agent-provided parameters to prevent
# the agent from redirecting connections to untrusted servers.
cluster_uri = os.getenv("MONGODB_ATLAS_CLUSTER_URI") or cluster_uri
database_name = os.getenv("MONGODB_DATABASE_NAME", database_name or DEFAULT_DATABASE_NAME)
collection_name = os.getenv("MONGODB_COLLECTION_NAME", collection_name or DEFAULT_COLLECTION_NAME)
embedding_model = os.getenv("MONGODB_EMBEDDING_MODEL", embedding_model or DEFAULT_EMBEDDING_MODEL)
region = os.getenv("AWS_REGION", region or DEFAULT_AWS_REGION)
vector_index_name = vector_index_name or DEFAULT_VECTOR_INDEX_NAME
if namespace is None:
namespace = os.getenv("MONGODB_NAMESPACE", DEFAULT_NAMESPACE)
Expand Down
25 changes: 19 additions & 6 deletions src/strands_tools/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,20 @@ def filter_results_by_score(results: List[Dict[str, Any]], min_score: float) ->
return [result for result in results if result.get("score", 0.0) >= min_score]


# Mapping of RetrievalResultLocation types to their document identifier fields.
# See: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_agent-runtime_RetrievalResultLocation.html
_LOCATION_FIELD_MAP = {
"customDocumentLocation": "id",
"s3Location": "uri",
"webLocation": "url",
"confluenceLocation": "url",
"salesforceLocation": "url",
"sharePointLocation": "url",
"kendraDocumentLocation": "uri",
"sqlLocation": "query",
}


def format_results_for_display(results: List[Dict[str, Any]], enable_metadata: bool = False) -> str:
"""
Format retrieval results for readable display.
Expand All @@ -226,14 +240,13 @@ def format_results_for_display(results: List[Dict[str, Any]], enable_metadata: b

formatted = []
for result in results:
# Extract document location - handle both s3Location and customDocumentLocation
# Extract document location - handle all RetrievalResultLocation types
location = result.get("location", {})
doc_id = "Unknown"
if "customDocumentLocation" in location:
doc_id = location["customDocumentLocation"].get("id", "Unknown")
elif "s3Location" in location:
# Extract meaningful part from S3 URI
doc_id = location["s3Location"].get("uri", "")
for loc_key, field in _LOCATION_FIELD_MAP.items():
if loc_key in location:
doc_id = location[loc_key].get(field, "Unknown")
break
score = result.get("score", 0.0)
formatted.append(f"\nScore: {score:.4f}")
formatted.append(f"Document ID: {doc_id}")
Expand Down
6 changes: 1 addition & 5 deletions src/strands_tools/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,6 @@ def shell(
ignore_errors: bool = False,
timeout: int = None,
work_dir: str = None,
non_interactive: bool = False,
) -> Dict[str, Any]:
"""Interactive shell with PTY support for real-time command execution and interaction. Features:

Expand Down Expand Up @@ -482,16 +481,13 @@ def shell(
ignore_errors: Continue execution even if some commands fail (default: False)
timeout: Timeout in seconds for each command (default: controlled by SHELL_DEFAULT_TIMEOUT environment variable)
work_dir: Working directory for command execution (default: current)
non_interactive: Run in non-interactive mode without user prompts (default: False)

Returns:
Dict containing status and response content
"""
console = console_util.create()

is_strands_non_interactive = os.environ.get("STRANDS_NON_INTERACTIVE", "").lower() == "true"
# Here we keep both doors open, but we only prompt env STRANDS_NON_INTERACTIVE in our doc.
non_interactive_mode = is_strands_non_interactive or non_interactive
non_interactive_mode = os.environ.get("STRANDS_NON_INTERACTIVE", "").lower() == "true"

# Validate command parameter
if command is None:
Expand Down
1 change: 1 addition & 0 deletions src/strands_tools/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def _create_task_agent(self, task: Dict) -> Agent:

# Create the task agent
task_agent = Agent(
name=task.get("task_id"),
model=selected_model,
system_prompt=system_prompt,
tools=filtered_tools,
Expand Down
31 changes: 31 additions & 0 deletions tests/code_interpreter/test_agent_core_code_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,37 @@ def test_write_files_success(interpreter, mock_client):
)


def test_write_files_action_with_blob(interpreter, mock_client):
"""Test successful file writing with base64 blob content."""
session_info = SessionInfo(session_id="test-session-id-123", description="Test session", client=mock_client)
interpreter._sessions["test-session"] = session_info

action = WriteFilesAction(
type="writeFiles",
session_name="test-session",
content=[
FileContent(path="image.png", blob=b"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ"),
FileContent(path="data.txt", text="Some data"),
],
)

result = interpreter.write_files(action)

assert result["status"] == "success"
mock_client.invoke.assert_called_once_with(
"writeFiles",
{
"content": [
{
"path": "image.png",
"blob": b"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ",
},
{"path": "data.txt", "text": "Some data"},
]
},
)


def test_create_tool_result_with_stream(interpreter):
"""Test _create_tool_result with stream response."""
response = {"stream": [{"result": {"content": "Test output"}}], "isError": False}
Expand Down
Loading
Loading