Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,9 @@ LANGSMITH_API_KEY=lsv2.... # LangSmith API key (must be replaced with real key)
# You can get it from the OpenAI website (https://platform.openai.com/).
OPENAI_API_KEY=sk...

## Groq
# Groq API Key - used to access Groq LLMs such as Mixtral or LLaMA models.
# Sign up and get your key from https://console.groq.com/keys
GROQ_API_KEY=grq...

# Others...
101 changes: 98 additions & 3 deletions agents/text/modules/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,21 @@

"""

from langchain.schema.runnable import RunnablePassthrough, RunnableSerializable
from langchain.schema.runnable import (
RunnableLambda,
RunnableMap,
RunnablePassthrough,
RunnableSerializable,
)
from langchain_core.output_parsers import StrOutputParser

from agents.text.modules.models import get_openai_model
from agents.text.modules.models import get_groq_model, get_openai_model
from agents.text.modules.persona import PERSONA
from agents.text.modules.prompts import get_extraction_prompt, get_instagram_text_prompt
from agents.text.modules.prompts import (
get_extraction_prompt,
get_instagram_text_prompt,
get_persona_match_prompt,
)


def set_extraction_chain() -> RunnableSerializable:
Expand Down Expand Up @@ -80,3 +89,89 @@ def set_instagram_text_chain() -> RunnableSerializable:
| model # LLM 모델 호출
| StrOutputParser() # 결과를 문자열로 변환
)


def set_instagram_text_format_check_chain() -> RunnableLambda:
"""
인스타그램 포맷(2200자 이하) 검사를 위한 체인을 반환합니다.

Returns:
RunnableLambda: 텍스트 길이를 검사하는 실행 체인
"""
return RunnableLambda(lambda x: len(x["text"]) <= 2200)


def set_sensitive_text_check_chain() -> RunnableLambda:
def is_text_safe(x):
model = get_groq_model("meta-llama/llama-guard-4-12b")
try:
response = model.invoke(x["text"])
return "safe" in response.content.lower()
except Exception as e:
print(f"[ERROR] llama-guard request failed: {e}")
return False

return RunnableLambda(is_text_safe)


def set_text_persona_match_check_chain() -> RunnableLambda:
def check_persona_match(x):
model = get_openai_model()

text = x["text"]
persona = x.get("persona", {})

# 다양한 타입의 persona_description 처리: dict, str, list
if isinstance(persona, dict):
persona_description = "\n".join([f"{k}: {v}" for k, v in persona.items()])
elif isinstance(persona, list):
persona_description = "\n".join([str(p) for p in persona])
else:
persona_description = str(persona)

prompt_template = get_persona_match_prompt()
prompt = prompt_template.format(
persona_description=persona_description, text=text
)

try:
response = model.invoke(prompt).content.strip().upper()
return "YES" in response
except Exception as e:
print(f"[ERROR] Persona check failed: {e}")
return False

return RunnableLambda(check_persona_match)


def set_text_content_check_chain() -> RunnableSerializable:
return (
RunnablePassthrough.assign(
text=lambda x: x if isinstance(x, str) else x.get("instagram_text", ""),
persona=lambda x: (
x.get("persona_extracted", {}) if isinstance(x, dict) else {}
),
)
| RunnableMap(
{
"format_check_passed": set_instagram_text_format_check_chain(),
"safety_check_passed": set_sensitive_text_check_chain(),
"persona_check_passed": set_text_persona_match_check_chain(),
}
)
| RunnableLambda(
lambda results: {
"text_content_checker_result": {
"success": all(results.values()),
"reason": [k for k, v in results.items() if not v],
"content_check_passed": all(results.values()),
**results,
"message": (
"Text content is valid."
if all(results.values())
else "Text content failed validation checks."
),
}
}
)
)
13 changes: 13 additions & 0 deletions agents/text/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
기본적으로 사용할 모델 인스턴스를 설정하고 생성하고 반환시킵니다.
"""

from langchain_groq import ChatGroq
Comment thread
jeongHwarr marked this conversation as resolved.
from langchain_openai import ChatOpenAI


Expand All @@ -17,3 +18,15 @@ def get_openai_model(temperature=0.7, top_p=0.9):
"""
# OpenAI 모델 초기화 및 반환
return ChatOpenAI(model="gpt-4o-mini", temperature=temperature, top_p=top_p)


def get_groq_model(model_name="llama3-8b-8192", temperature=0.7, top_p=0.9):
"""
Groq API를 사용하는 Llama3 기반 모델을 LangChain에서 가져옵니다.
사용 가능한 모델 예: "llama3-8b-8192", "llama3-70b-8192", "mixtral-8x7b-32768"
"""
return ChatGroq(
model_name=model_name,
temperature=temperature,
top_p=top_p,
)
39 changes: 38 additions & 1 deletion agents/text/modules/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
"""

from agents.base_node import BaseNode
from agents.text.modules.chains import set_extraction_chain, set_instagram_text_chain
from agents.text.modules.chains import (
set_extraction_chain,
set_instagram_text_chain,
set_text_content_check_chain,
)
from agents.text.modules.persona import PERSONA
from agents.text.modules.state import TextState

Expand Down Expand Up @@ -62,3 +66,36 @@ def execute(self, state: TextState) -> dict:
return {
"instagram_text": instagram_text,
}


class TextContentCheckNode(BaseNode):
Comment thread
jeongHwarr marked this conversation as resolved.
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.chain = set_text_content_check_chain()

def execute(self, state: TextState) -> dict:
instagram_text = state.get("instagram_text", "")

# instagram_text가 빈 칸(또는 공백)일 때는 모든 체크를 스킵하고 성공 결과 리턴
if not instagram_text or not instagram_text.strip():
result = {
"text_content_checker_result": {
"success": True,
"reason": [],
"content_check_passed": True,
"format_check_passed": True,
"safety_check_passed": True,
"persona_check_passed": True,
"message": "Skipped checks because text_content is empty.",
}
}
state.update(result)
return result

input_data = {
"response": state.get("response", [""]),
"persona_extracted": state.get("persona_extracted", {}),
}
result = self.chain.invoke(input_data)
state.update(result)
return result
18 changes: 18 additions & 0 deletions agents/text/modules/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,21 @@ def get_instagram_text_prompt():
"persona_extracted",
],
)


def get_persona_match_prompt() -> PromptTemplate:
"""
Returns a prompt template to evaluate if a given text aligns with a provided persona.

The model must respond only with "YES" or "NO".
"""
template = (
"The following is an Instagram text content. Please determine whether it aligns with the provided persona. "
"Reply only with 'YES' if it matches well, or 'NO' if it doesn't.\n\n"
"[Persona]\n{persona_description}\n\n"
"[Text]\n{text}"
)

return PromptTemplate(
template=template, input_variables=["persona_description", "text"]
)
1 change: 1 addition & 0 deletions agents/text/modules/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ class TextState(TypedDict):
response: Annotated[
list, add_messages
] # 응답 메시지 목록 (add_messages로 주석되어 메시지 추가 기능 제공)
text_content_checker_result: (dict) # 텍스트 컨텐츠 검사 결과 전체를 담는 구조화된 필드
1 change: 1 addition & 0 deletions agents/text/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ description = "텍스트 기반 콘텐츠 생성을 위한 LangGraph Workflow
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"langchain-groq>=0.3.2",
"langchain-openai>=0.3.12",
]
15 changes: 12 additions & 3 deletions agents/text/workflow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from langgraph.graph import StateGraph

from agents.base_workflow import BaseWorkflow
from agents.text.modules.nodes import GenTextNode, PersonaExtractionNode
from agents.text.modules.nodes import (
GenTextNode,
PersonaExtractionNode,
TextContentCheckNode,
)
from agents.text.modules.state import TextState


Expand Down Expand Up @@ -35,12 +39,17 @@ def build(self):
# 텍스트 생성 노드 추가
builder.add_node("text_generation", GenTextNode())

# 텍스트 컨텐츠 체커 노드 추가
builder.add_node("text_content_check", TextContentCheckNode())

# 시작 노드에서 페르소나 추출 노드로 연결
builder.add_edge("__start__", "persona_extraction")
# 페르소나 추출 노드에서 텍스트 생성 노드로 연결
builder.add_edge("persona_extraction", "text_generation")
# 텍스트 생성 노드에서 종료 노드로 연결
builder.add_edge("text_generation", "__end__")
# 텍스트 생성 노드에서 텍스트 컨텐츠 체커 노드로 연결
builder.add_edge("text_generation", "text_content_check")
# 텍스트 컨텐츠 체커 노드에서 종료 노드로 연결
builder.add_edge("text_content_check", "__end__")

# 조건부 에지 추가 예시
# builder.add_conditional_edges(
Expand Down