# langgraph_tool_backend.py
# LangGraph chatbot backend: an LLM chat node with tool calling (web search,
# calculator, stock quotes) and SQLite-backed conversation checkpoints.
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Annotated
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage, BaseMessage
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph.message import add_messages # to add messages in the list used instead of reducer function operator.add
from dotenv import load_dotenv
import sqlite3
import os
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.tools import tool
import requests
# Project name written into the process environment for LangChain tooling.
# NOTE(review): presumably the LangSmith tracing project name - confirm.
# It is set before load_dotenv(); python-dotenv by default does not override
# variables that already exist, so this value should win over a .env entry - confirm.
os.environ['LANGCHAIN_PROJECT'] = 'Chatbot 3'
# Load remaining settings (e.g. API keys) from a local .env file, if present.
load_dotenv()
# Tools
# Web search tool backed by DuckDuckGo, configured with the "us-en" region
search_tool = DuckDuckGoSearchRun(region="us-en")
# Calculator tool
@tool
def calculator(first_num: float, second_num: float, operation: str) -> dict:
    """
    Perform a basic arithmetic operation on two numbers.
    Supported operations: add, sub, mul, div
    """
    # NOTE: the docstring above is kept word-for-word because langchain's
    # @tool uses the function docstring as the tool description.
    #
    # Success  -> {"first_num", "second_num", "operation", "result"}
    # Failure  -> {"error": <message>} (unknown operation, division by zero,
    #             or any exception raised while computing)

    def _success(value):
        # Echo the inputs back next to the computed value.
        return {"first_num": first_num, "second_num": second_num, "operation": operation, "result": value}

    try:
        if operation == "add":
            return _success(first_num + second_num)
        if operation == "sub":
            return _success(first_num - second_num)
        if operation == "mul":
            return _success(first_num * second_num)
        if operation != "div":
            return {"error": f"Unsupported operation '{operation}'"}
        # Only division is left; refuse a zero divisor explicitly.
        if second_num == 0:
            return {"error": "Division by zero is not allowed"}
        return _success(first_num / second_num)
    except Exception as exc:
        return {"error": str(exc)}
# Stock price tool
@tool
def get_stock_price(symbol: str) -> dict:
    """
    Fetch latest stock price for a given symbol (e.g. 'AAPL', 'TSLA')
    using Alpha Vantage.
    """
    # Maintainer notes (kept out of the docstring, which langchain's @tool
    # uses as the tool description):
    #   symbol  - ticker passed to Alpha Vantage's GLOBAL_QUOTE function.
    #   returns - the decoded JSON response, or {"error": <message>} when the
    #             request fails or the body is not valid JSON (same error
    #             shape as the calculator tool).
    #
    # SECURITY: the API key used to be embedded only in source. Prefer the
    # ALPHAVANTAGE_API_KEY environment variable (e.g. via .env); the literal
    # fallback is kept solely so existing setups keep working.
    # TODO(review): rotate this key and remove the fallback.
    api_key = os.environ.get("ALPHAVANTAGE_API_KEY") or "C9PE94QUEW9VWGFM"
    # Passing params lets requests URL-encode the symbol instead of pasting
    # raw user/model text into the query string.
    query = {"function": "GLOBAL_QUOTE", "symbol": symbol, "apikey": api_key}
    try:
        # Without a timeout a stalled connection would block the graph forever.
        response = requests.get("https://www.alphavantage.co/query", params=query, timeout=10)
        return response.json()
    except (requests.RequestException, ValueError) as exc:
        # ValueError covers JSON decoding failures on older requests versions.
        return {"error": str(exc)}
# Chat model used by the graph (Gemini 2.0 Flash via langchain_google_genai)
llm = ChatGoogleGenerativeAI(model='gemini-2.0-flash')
# Single list of tools, shared by the model binding here and by ToolNode further down
tools = [search_tool, calculator, get_stock_price]
# Model handle with the tools bound to it; this is what chat_node invokes
llm_with_tools = llm.bind_tools(tools)
class ChatState(TypedDict):
    """State schema passed between the graph's nodes."""

    # Conversation history. The add_messages annotation is the reducer used
    # for this key (in place of a plain operator.add on lists).
    messages: Annotated[list[BaseMessage], add_messages]
def chat_node(state: ChatState):
    """Run the tool-enabled model on the conversation so far.

    Reads ``state['messages']``, invokes ``llm_with_tools`` on it, and returns
    a state update whose ``'messages'`` list holds just the model's reply
    (which may be an answer or a request for a tool call).
    """
    return {'messages': [llm_with_tools.invoke(state['messages'])]}
# Graph node built from the same tool list that is bound to the model
tool_node = ToolNode(tools)
# SQLite database file holding the checkpoints (path is relative to the working directory).
# check_same_thread=False allows the connection to be used from threads other than the one that created it.
conn = sqlite3.connect(database='chatbot.db', check_same_thread= False)
# Checkpointer that stores graph state through the connection above
checkpointer = SqliteSaver(conn=conn)
# Graph definition using ChatState as its state schema
graph = StateGraph(ChatState)
# Nodes: the LLM step and the tool step
graph.add_node('chat_node', chat_node)
graph.add_node('tools', tool_node)
# Every run starts at the LLM step
graph.add_edge(START, 'chat_node')
# tools_condition chooses what follows chat_node.
# NOTE(review): per the LangGraph docs it routes to 'tools' when the latest
# model message requests a tool call and ends the run otherwise - confirm for the installed version.
graph.add_conditional_edges('chat_node', tools_condition)
# After the tool step, control returns to the LLM step
graph.add_edge('tools', 'chat_node')
# Compiled graph, with state persisted via the SQLite checkpointer
chatbot = graph.compile(checkpointer=checkpointer)
def retrieve_all_threads():
    """Return the distinct thread ids found in the checkpoint store.

    Every checkpoint yielded by ``checkpointer.list(None)`` is inspected and
    its ``config['configurable']['thread_id']`` is collected.

    Returns:
        list: each thread id exactly once, in the order it is first
        encountered while listing checkpoints.
    """
    # dict.fromkeys de-duplicates like a set but keeps first-seen order, so
    # the result no longer depends on set iteration order (which can differ
    # between processes for string ids).
    unique_ids = dict.fromkeys(
        checkpoint.config['configurable']['thread_id']
        for checkpoint in checkpointer.list(None)
    )
    return list(unique_ids)