Upload chain.py

This commit is contained in:
groupuser 2025-04-22 23:55:52 +00:00
parent 566feae028
commit 8dd70f3b53

37
app/chain.py Normal file

@ -0,0 +1,37 @@
from langchain.chains import RetrievalQA
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from retriever import get_retriever
from llm import TritonLLM
def build_chain():
retriever = get_retriever()
llm = TritonLLM()
# RAG 프롬프트 템플릿 생성
prompt_template = ChatPromptTemplate.from_messages([
("system", """다음 정보를 참고하여 사용자의 질문에 답변하세요:
{context}
답변할 없는 내용이나 주어진 컨텍스트에 없는 내용이면 솔직하게 모른다고 말하세요.
답변은 주어진 컨텍스트에 기반하여 구체적이고 간결하게 작성하세요."""),
("human", "{question}")
])
def format_docs(docs):
return "\n\n".join(f"문서: {i+1}\n{doc.page_content}" for i, doc in enumerate(docs))
rag_chain = (
{
"context": RunnableLambda(lambda x: x["question"]) | retriever | format_docs,
"question": RunnablePassthrough()
}
| prompt_template
| llm
| StrOutputParser()
)
return rag_chain