# Streamlit-cached building blocks for a retrieval-augmented LangGraph QA agent:
# a retriever, a Gemini chat model, the "rag_qa" node, and the compiled graph.
# code_agent.py
import os
from pathlib import Path
from typing import TypedDict, Annotated, List, Dict, Any
import streamlit as st
from langchain.schema import BaseMessage, SystemMessage, AIMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver
from retriever import LangGraphRetrievalSystem
# Constants
QA_LLM_MODEL = "gemini-2.0-flash-lite"
# Location of the base system prompt.  The original developer-machine path is
# kept as the default so existing setups behave the same; set the
# LANGGRAPH_SYSTEM_PROMPT_FILE environment variable to point somewhere else.
SYSTEM_PROMPT_FILE = os.environ.get(
    "LANGGRAPH_SYSTEM_PROMPT_FILE",
    "D:\\SNK\\langchain-master\\Langraph agent\\GLang\\langgraph_system_prompt.txt",
)
# Load base system prompt once
try:
    # Explicit encoding: the platform default differs between machines.
    base_system_prompt = Path(SYSTEM_PROMPT_FILE).read_text(encoding="utf-8")
except (OSError, UnicodeError):
    # Missing, unreadable or undecodable prompt file: use the built-in prompt.
    base_system_prompt = "You are a helpful AI assistant specializing in LangGraph."
class MessagesState(TypedDict):
    """State schema handed to ``StateGraph`` for the graph built in this module."""

    # Chat messages of the conversation.  The annotation attaches LangGraph's
    # ``add_messages`` function to this key.
    messages: Annotated[List[BaseMessage], add_messages]
    # Documents returned by the retriever; written by ``rag_qa_node``.  The
    # element type is left open as ``Any``.
    docs: List[Any]
@st.cache_resource
def get_retriever_system(k: int, api_key: str) -> LangGraphRetrievalSystem:
    """Create the retrieval system for ``k`` results and the given Google API key.

    The function is wrapped in ``st.cache_resource``.  Construction is tried at
    most twice: if the first attempt raises, the identical call is made once
    more, and an error from that second attempt propagates to the caller.
    """
    settings = {"k": k, "google_api_key": api_key}
    try:
        system = LangGraphRetrievalSystem(**settings)
    except Exception:
        # Second and final attempt with exactly the same arguments.
        system = LangGraphRetrievalSystem(**settings)
    return system
@st.cache_resource
def get_qa_llm(api_key: str) -> ChatGoogleGenerativeAI:
    """Create the Gemini chat model used for answering questions.

    The model name comes from ``QA_LLM_MODEL`` and the temperature is fixed at
    0.3.  The function is wrapped in ``st.cache_resource``.  Construction is
    tried at most twice: if the first attempt raises, the identical call is
    made once more, and an error from that second attempt propagates.
    """
    llm_kwargs = {
        "model": QA_LLM_MODEL,
        "temperature": 0.3,
        "google_api_key": api_key,
    }
    try:
        model = ChatGoogleGenerativeAI(**llm_kwargs)
    except Exception:
        # Second and final attempt with exactly the same arguments.
        model = ChatGoogleGenerativeAI(**llm_kwargs)
    return model
def _config_value(config: Dict[str, Any], key: str, default: Any = None) -> Any:
    """Look up ``key`` at the top level of ``config``, then under ``config["configurable"]``."""
    if not config:
        return default
    if key in config:
        return config[key]
    configurable = config.get("configurable") or {}
    return configurable.get(key, default)


def rag_qa_node(state: MessagesState, config: Dict[str, Any]) -> Dict[str, Any]:
    """Answer the latest message using retrieved documents as context.

    Args:
        state: Graph state; ``state["messages"]`` is the conversation so far.
            The content of the last message is used as the retrieval query.
        config: Run configuration.  ``api_key``, ``k_docs`` (default 5) and
            ``rerank`` (default False) are read from the top level or, if
            absent there, from ``config["configurable"]``.

    Returns:
        A state update with ``messages`` (a one-element list holding the
        generated ``AIMessage``) and ``docs`` (the retrieved documents, or an
        empty list when retrieval failed).

    Raises:
        ValueError: If no API key is found in ``config`` or in the
            ``GOOGLE_API_KEY`` environment variable.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Credentials must come from the caller or the environment; a key must
    # never be embedded in source code.
    api_key = _config_value(config, "api_key") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError(
            "No Google API key configured: pass 'api_key' in the run config "
            "or set the GOOGLE_API_KEY environment variable."
        )
    k_docs = _config_value(config, "k_docs", 5)
    rerank = _config_value(config, "rerank", False)

    retriever = get_retriever_system(k=k_docs, api_key=api_key)
    llm = get_qa_llm(api_key=api_key)

    history = state["messages"]
    question = history[-1].content if history else ""

    # Pre-bind both names so a retrieval failure cannot leave ``docs`` unbound
    # when the result is assembled below.
    docs: List[Any] = []
    context = ""
    try:
        docs = retriever.retrieve(question, k=k_docs, rerank_llm=rerank)
        context = "\n\n---\n\n".join(d.page_content for d in docs)
    except Exception:
        # Best effort: answer without context, but leave a trace of the failure.
        logger.warning(
            "Document retrieval failed; answering without context.",
            exc_info=True,
        )

    system_content = f"{base_system_prompt}\n\n```context\n{context}```"
    msgs = [SystemMessage(content=system_content)] + history

    # Collect the streamed answer; chunks without string content are skipped.
    pieces = []
    for chunk in llm.stream(msgs):
        text = getattr(chunk, "content", None)
        if isinstance(text, str):
            pieces.append(text)
    answer = "".join(pieces)

    logger.debug("docs: %s", docs)
    return {
        "messages": [AIMessage(content=answer or "Sorry, no response.")],
        "docs": docs,
    }
@st.cache_resource
def get_compiled_graph():
    """Build and compile the single-node RAG graph.

    The graph has one node, ``"rag_qa"`` (implemented by ``rag_qa_node``),
    wired START -> rag_qa -> END, and is compiled with a ``MemorySaver``
    checkpointer.  The function is wrapped in ``st.cache_resource``.
    """
    node_name = "rag_qa"
    graph = StateGraph(MessagesState)
    graph.add_node(node_name, rag_qa_node)
    # Linear topology: the node is both the entry point and the last step.
    graph.add_edge(START, node_name)
    graph.add_edge(node_name, END)
    return graph.compile(checkpointer=MemorySaver())