-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathormcp_tools.py
More file actions
416 lines (351 loc) · 17.3 KB
/
ormcp_tools.py
File metadata and controls
416 lines (351 loc) · 17.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
"""
ORMCP Tools Wrapper for LangChain
Provides tools for interacting with the ecommerce microservice via ORMCP server
"""
import json
import math
import os
from typing import Any, Dict, List, Optional, Union

import requests
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
class ORMCPClient:
    """Client for talking to an ORMCP server over MCP's HTTP transport.

    Every call is a JSON-RPC 2.0 message POSTed to ``base_url`` (normalised to
    end in ``/mcp/``). Replies may be plain JSON or a Server-Sent Events body.
    Any ``mcp-session-id`` response header is stored on the instance and sent
    back on subsequent requests.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8080",
        session_id: Optional[str] = None,
        timeout: float = 30,
    ):
        """Create a client.

        Args:
            base_url: Server root or MCP endpoint URL. "/mcp/" is appended
                unless the URL already ends with "/mcp" or "/mcp/".
            session_id: Optional pre-existing MCP session ID.
            timeout: Per-request timeout in seconds, passed to ``requests``.
        """
        # Normalise so the URL always ends with exactly "/mcp/".
        if not base_url.endswith('/mcp/'):
            if base_url.endswith('/mcp'):
                base_url += '/'
            elif base_url.endswith('/'):
                base_url += 'mcp/'
            else:
                base_url += '/mcp/'
        self.base_url = base_url
        self.session_id = session_id
        self.timeout = timeout
        # Next JSON-RPC request ID to hand out (see _next_id).
        self.request_id = 1

    def _next_id(self) -> int:
        """Return the next JSON-RPC request ID and advance the counter."""
        current_id = self.request_id
        self.request_id += 1
        return current_id

    def _headers(self, include_session: bool = True) -> Dict[str, str]:
        """Build the HTTP headers for an MCP POST.

        Args:
            include_session: When true and a session ID is known, add the
                ``mcp-session-id`` header.

        Returns:
            A fresh header dict.
        """
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json, text/event-stream",
        }
        if include_session and self.session_id:
            headers["mcp-session-id"] = self.session_id
        return headers

    def _parse_sse_response(self, response_text: str) -> Dict[str, Any]:
        """Decode a FastMCP response body that is either SSE or plain JSON.

        For an SSE body, the ``data`` field lines of the first event that
        carries data are joined with newlines (as the SSE spec prescribes)
        and decoded as JSON. Both ``data: x`` and ``data:x`` are accepted.
        A body without any ``data`` field is decoded as a plain JSON document.

        Raises:
            ValueError: If no valid JSON payload can be decoded.
        """
        # SSE lines end in CRLF, LF or CR. str.splitlines() is avoided on
        # purpose: it also splits on characters such as U+2028, which may
        # legally appear raw inside JSON strings.
        normalized = response_text.strip().replace('\r\n', '\n').replace('\r', '\n')
        data_lines: List[str] = []
        for line in normalized.split('\n'):
            if not line:
                # A blank line ends an event; stop after the first event
                # that carried data.
                if data_lines:
                    break
                continue
            if line.startswith('data:'):
                value = line[len('data:'):]
                # The SSE spec strips a single space after the colon.
                if value.startswith(' '):
                    value = value[1:]
                data_lines.append(value)

        if data_lines:
            data = '\n'.join(data_lines)
            try:
                return json.loads(data)
            except json.JSONDecodeError as e:
                raise ValueError(f"Failed to parse JSON data: {data}") from e

        # No SSE data field: treat the whole body as plain JSON.
        try:
            return json.loads(response_text.strip())
        except json.JSONDecodeError as e:
            raise ValueError(f"No valid data found in response: {response_text}") from e

    def _make_request(self, method: str, params: Dict[str, Any]) -> Any:
        """Send a JSON-RPC request to the ORMCP server and return its result.

        Args:
            method: JSON-RPC method name, e.g. ``"tools/call"``.
            params: JSON-RPC ``params`` object.

        Returns:
            The response's ``result`` member when present. Otherwise the whole
            decoded response, or ``{"success": True}`` for an empty body.

        Raises:
            Exception: For a JSON-RPC ``error`` response, and for HTTP,
                connection, timeout or other ``requests`` failures (the
                original exception is chained as ``__cause__``).
            ValueError: If a non-empty body cannot be decoded as JSON.
        """
        payload = {
            "jsonrpc": "2.0",
            "id": self._next_id(),
            "method": method,
            "params": params,
        }
        try:
            response = requests.post(
                self.base_url,
                json=payload,
                headers=self._headers(),
                timeout=self.timeout,
            )
            # Record the session ID before raise_for_status so it is kept
            # even when the server also reports an HTTP error.
            if 'mcp-session-id' in response.headers:
                self.session_id = response.headers['mcp-session-id']
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            if e.response is None:
                raise Exception(f"Request failed: {str(e)}") from e
            status_code = e.response.status_code
            try:
                error_msg = e.response.json().get('error', {}).get('message', str(e))
            except (ValueError, AttributeError):
                # Body is not JSON, or not shaped like {"error": {"message": ...}}.
                error_msg = str(e)
            if status_code == 400:
                raise Exception(f"Bad Request to ORMCP server. Please check if the server is running and the request format is correct. Details: {error_msg}") from e
            if status_code == 404:
                raise Exception(f"ORMCP server endpoint not found. Please check if the server is running at {self.base_url}") from e
            if status_code == 500:
                raise Exception(f"ORMCP server internal error. Details: {error_msg}") from e
            raise Exception(f"HTTP {status_code} error from ORMCP server: {error_msg}") from e
        except requests.exceptions.ConnectionError as e:
            raise Exception(f"Cannot connect to ORMCP server at {self.base_url}. Please ensure the server is running.") from e
        except requests.exceptions.Timeout as e:
            raise Exception(f"Request to ORMCP server timed out after {self.timeout} seconds. The server may be overloaded or unresponsive.") from e
        except requests.RequestException as e:
            raise Exception(f"Request failed: {str(e)}") from e

        if not response.text.strip():
            return {"success": True}
        result = self._parse_sse_response(response.text)
        # Only a JSON object can carry JSON-RPC "result"/"error" members.
        if isinstance(result, dict):
            if "result" in result:
                return result["result"]
            if "error" in result:
                raise Exception(f"MCP Error: {result['error']}")
        return result

    def initialize(self) -> str:
        """Perform the MCP initialize handshake.

        Sends ``initialize`` and stores any ``mcp-session-id`` the server
        returns. It then sends the ``notifications/initialized``
        notification, whose response is not checked.

        Returns:
            The session ID, or ``"initialized"`` when none is known.

        Raises:
            Exception: If the server returns a JSON-RPC error, or on any
                ``requests`` failure (chained as ``__cause__``).
            ValueError: If the initialize response body is not valid JSON.
        """
        payload = {
            "jsonrpc": "2.0",
            "id": self._next_id(),
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {
                    "name": "inventory-watchdog-agent",
                    "version": "1.0.0"
                }
            }
        }
        try:
            # The initialize request is sent without a session header; any
            # session ID the server assigns is read from the response.
            response = requests.post(
                self.base_url,
                json=payload,
                headers=self._headers(include_session=False),
                timeout=self.timeout,
            )
            if "mcp-session-id" in response.headers:
                self.session_id = response.headers["mcp-session-id"]
            response.raise_for_status()
            if response.text.strip():
                result = self._parse_sse_response(response.text)
                if isinstance(result, dict) and result.get("error"):
                    raise Exception(f"Initialization failed: {result['error']}")
            # The MCP lifecycle requires this notification after a successful
            # initialize, so it is sent even when no session ID was issued.
            # Its response is intentionally not inspected.
            requests.post(
                self.base_url,
                json={
                    "jsonrpc": "2.0",
                    "method": "notifications/initialized"
                },
                headers=self._headers(),
                timeout=self.timeout,
            )
        except requests.RequestException as e:
            raise Exception(f"Failed to initialize: {str(e)}") from e
        return self.session_id or "initialized"
# Tool Input Schemas
class QueryInput(BaseModel):
    # Argument schema for QueryTool. Field names match the argument names
    # QueryTool forwards to the ORMCP "query" tool.
    className: str = Field(
        ...,
        description="Full class name (e.g., com.acme.ecommerce.model.Product)",
    )
    # Passed through verbatim; how "" is interpreted is up to the server.
    filter: str = Field(
        description="SQL-like WHERE clause filter",
        default="",
    )
    maxObjects: int = Field(
        description="Maximum number of objects to retrieve (-1 for all)",
        default=-1,
    )
    deep: bool = Field(
        description="Include referenced objects",
        default=True,
    )
class GetAggregateInput(BaseModel):
    # Argument schema for GetAggregateTool (ORMCP "getAggregate").
    className: str = Field(..., description="Full class name")
    attributeName: str = Field(..., description="Attribute name to aggregate")
    # Free-form string: not validated here against the listed aggregate types.
    aggregateType: str = Field(
        ...,
        description="Aggregation type: COUNT, SUM, AVG, MIN, MAX",
    )
    filter: str = Field(
        description="SQL-like WHERE clause filter",
        default="",
    )
class UpdateInput(BaseModel):
    # Argument schema for UpdateTool (ORMCP "update").
    className: str = Field(..., description="Full class name")
    # Each dict is one object; the description asks callers to include the
    # primary key so the server can locate the row.
    jsonObjects: List[Dict[str, Any]] = Field(
        ...,
        description="List of objects to update (must include primary keys)",
    )
    deep: bool = Field(description="Update referenced objects", default=True)
class InsertInput(BaseModel):
    # Argument schema for InsertTool (ORMCP "insert").
    className: str = Field(..., description="Full class name")
    jsonObjects: List[Dict[str, Any]] = Field(
        ...,
        description="List of objects to insert",
    )
    deep: bool = Field(description="Save referenced objects", default=True)
class DeleteInput(BaseModel):
    # Argument schema for DeleteTool (ORMCP "delete").
    className: str = Field(
        ...,
        description="Full class name (e.g., com.acme.ecommerce.model.Product)",
    )
    jsonObjects: List[Dict[str, Any]] = Field(
        ...,
        description="List of objects to delete (only primary key required)",
    )
    # NOTE(review): deep defaults to True, so referenced objects are deleted
    # as well unless the caller opts out. Confirm this cascading default is
    # intended for agent-driven deletes.
    deep: bool = Field(
        description="Delete referenced objects as well",
        default=True,
    )
# LangChain Tools
class QueryTool(BaseTool):
    # LangChain tool that forwards to the ORMCP "query" tool.
    name: str = "query_products"
    description: str = (
        "Query products from the ecommerce database.\n"
        "Use this to monitor stock levels, find low-stock items, or search products by category.\n"
        "Example: Query products where stockQuantity < threshold"
    )
    args_schema: type[BaseModel] = QueryInput
    # Assigned in __init__; exclude=True keeps it out of model serialization.
    client: ORMCPClient = Field(exclude=True, default=None)

    def __init__(self, client: ORMCPClient, **kwargs):
        super().__init__(**kwargs)
        # Stored with object.__setattr__, bypassing the model's own __setattr__.
        object.__setattr__(self, 'client', client)

    def _run(self, className: str, filter: str = "", maxObjects: int = -1, deep: bool = True) -> str:
        """Call ORMCP ``query`` and return its result as indented JSON text."""
        arguments = {
            "className": className,
            "filter": filter,
            "maxObjects": maxObjects,
            "deep": deep,
        }
        response = self.client._make_request(
            "tools/call", {"name": "query", "arguments": arguments}
        )
        return json.dumps(response, indent=2)
class GetAggregateTool(BaseTool):
    # LangChain tool that forwards to the ORMCP "getAggregate" tool.
    name: str = "get_aggregate"
    description: str = (
        "Calculate aggregate values (COUNT, SUM, AVG, MIN, MAX) for product attributes.\n"
        "Use this to analyze sales trends, calculate average prices, count products, etc.\n"
        "Example: Get average sales velocity or total revenue"
    )
    args_schema: type[BaseModel] = GetAggregateInput
    # Assigned in __init__; exclude=True keeps it out of model serialization.
    client: ORMCPClient = Field(exclude=True, default=None)

    def __init__(self, client: ORMCPClient, **kwargs):
        super().__init__(**kwargs)
        # Stored with object.__setattr__, bypassing the model's own __setattr__.
        object.__setattr__(self, 'client', client)

    def _run(self, className: str, attributeName: str, aggregateType: str, filter: str = "") -> str:
        """Call ORMCP ``getAggregate`` and return its result as indented JSON.

        ``aggregateType`` is upper-cased before sending, so "avg" and "AVG"
        are equivalent.
        """
        normalized_type = aggregateType.upper()
        request = {
            "name": "getAggregate",
            "arguments": {
                "className": className,
                "attributeName": attributeName,
                "aggregateType": normalized_type,
                "filter": filter,
            },
        }
        return json.dumps(self.client._make_request("tools/call", request), indent=2)
class UpdateTool(BaseTool):
    # LangChain tool that forwards to the ORMCP "update" tool.
    name: str = "update_products"
    description: str = (
        "Update product information in the database.\n"
        "IMPORTANT: To update a product, you MUST first query it to get its full current data including the ID.\n"
        "Then provide the complete product object with the ID and only the fields you want to change.\n"
        "Example workflow:\n"
        '1. First query: query_products with filter "id = 1" to get the product\n'
        '2. Then update: update_products with className "com.acme.ecommerce.model.Product" and jsonObjects containing the full product object with ID and updated stockQuantity\n'
        "Required fields in jsonObjects: id (primary key), and the fields to update (e.g., stockQuantity).\n"
        "Use this to update stock quantities, mark items for reorder, or adjust thresholds."
    )
    args_schema: type[BaseModel] = UpdateInput
    # Assigned in __init__; exclude=True keeps it out of model serialization.
    client: ORMCPClient = Field(exclude=True, default=None)

    def __init__(self, client: ORMCPClient, **kwargs):
        super().__init__(**kwargs)
        # Stored with object.__setattr__, bypassing the model's own __setattr__.
        object.__setattr__(self, 'client', client)

    def _run(self, className: str, jsonObjects: List[Dict[str, Any]], deep: bool = True) -> str:
        """Call ORMCP ``update`` with the given objects; return indented JSON.

        The objects are forwarded unmodified.
        """
        request = {
            "name": "update",
            "arguments": {
                "className": className,
                "jsonObjects": jsonObjects,
                "deep": deep,
            },
        }
        outcome = self.client._make_request("tools/call", request)
        return json.dumps(outcome, indent=2)
def _coerce_bigdecimal_price(value: Any) -> Any:
    """Coerce a price so it is sent as a JSON number with a fractional part.

    Numeric values and numeric strings become floats. Whole numbers get
    ``0.0000001`` added so they are not sent as integer-looking values.
    Values that cannot be coerced are returned unchanged: non-numeric
    strings, bools, ``None``, NaN/infinity, ints too large for a float, and
    any other type.
    """
    # bool is a subclass of int, but a boolean is not a price to coerce.
    if isinstance(value, bool) or not isinstance(value, (str, int, float)):
        return value
    try:
        number = float(value)
    except (ValueError, OverflowError):
        # Non-numeric string, or an int outside float range.
        return value
    if not math.isfinite(number):
        # NaN/infinity have no decimal representation; pass them through
        # rather than crashing on the whole-number check below.
        return value
    if number.is_integer():
        # NOTE(review): Python's json already writes a float such as 1299.0
        # as "1299.0", so this epsilon presumably compensates for a
        # serializer further downstream ("Integer cannot be cast to
        # BigDecimal" on the Java side) - confirm. It does alter the stored
        # value if the target column keeps 7+ decimal places.
        return number + 0.0000001
    return number


class InsertTool(BaseTool):
    # LangChain tool that forwards to the ORMCP "insert" tool, normalising
    # top-level "price" fields first.
    name: str = "insert_objects"
    description: str = (
        "Insert new objects into the database.\n"
        "Use this to create reorder requests, inventory alerts, or new product entries.\n"
        "Example: Create reorder requests or inventory alerts"
    )
    args_schema: type[BaseModel] = InsertInput
    # Assigned in __init__; exclude=True keeps it out of model serialization.
    client: ORMCPClient = Field(default=None, exclude=True)

    def __init__(self, client: ORMCPClient, **kwargs):
        super().__init__(**kwargs)
        # Stored with object.__setattr__, bypassing the model's own __setattr__.
        object.__setattr__(self, 'client', client)

    def _run(self, className: str, jsonObjects: List[Dict[str, Any]], deep: bool = True) -> str:
        """Call ORMCP ``insert`` and return its result as indented JSON.

        Each object is shallow-copied, so the caller's dicts are not mutated.
        A top-level ``price`` field is normalised with
        ``_coerce_bigdecimal_price``; nested objects are not inspected.
        """
        processed_objects = []
        for obj in jsonObjects:
            processed_obj = obj.copy()
            if "price" in processed_obj:
                processed_obj["price"] = _coerce_bigdecimal_price(processed_obj["price"])
            processed_objects.append(processed_obj)

        result = self.client._make_request("tools/call", {
            "name": "insert",
            "arguments": {
                "className": className,
                "jsonObjects": processed_objects,
                "deep": deep
            }
        })
        return json.dumps(result, indent=2)
class DeleteTool(BaseTool):
    # LangChain tool that forwards to the ORMCP "delete" tool.
    name: str = "delete_objects"
    description: str = (
        "Delete objects from the database.\n"
        "Use this to remove products, orders, or other entities.\n"
        "Only the primary key (id) is required to identify objects for deletion."
    )
    args_schema: type[BaseModel] = DeleteInput
    # Assigned in __init__; exclude=True keeps it out of model serialization.
    client: ORMCPClient = Field(exclude=True, default=None)

    def __init__(self, client: ORMCPClient, **kwargs):
        super().__init__(**kwargs)
        # Stored with object.__setattr__, bypassing the model's own __setattr__.
        object.__setattr__(self, 'client', client)

    def _run(self, className: str, jsonObjects: List[Dict[str, Any]], deep: bool = True) -> str:
        """Call ORMCP ``delete`` for the given objects; return indented JSON.

        NOTE(review): ``deep`` defaults to True here as well, i.e. referenced
        objects are deleted too unless the caller passes False.
        """
        arguments = {
            "className": className,
            "jsonObjects": jsonObjects,
            "deep": deep,
        }
        outcome = self.client._make_request(
            "tools/call", {"name": "delete", "arguments": arguments}
        )
        return json.dumps(outcome, indent=2)
class GetObjectModelSummaryTool(BaseTool):
    # LangChain tool that forwards to the ORMCP "getObjectModelSummary" tool.
    # There is no args_schema; LangChain derives the (empty) input schema from
    # _run's signature, so that signature and docstring are kept as-is.
    name: str = "get_object_model_summary"
    description: str = (
        "Get information about the database schema, entities, attributes, and relationships.\n"
        "Use this to discover available classes and their structure."
    )
    # Assigned in __init__; exclude=True keeps it out of model serialization.
    client: ORMCPClient = Field(exclude=True, default=None)

    def __init__(self, client: ORMCPClient, **kwargs):
        super().__init__(**kwargs)
        # Stored with object.__setattr__, bypassing the model's own __setattr__.
        object.__setattr__(self, 'client', client)

    def _run(self) -> str:
        """Get object model summary"""
        # The server tool takes no arguments; send an empty object.
        request = {"name": "getObjectModelSummary", "arguments": {}}
        summary = self.client._make_request("tools/call", request)
        return json.dumps(summary, indent=2)