from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda

from retriever import get_retriever
from llm import TritonLLM


def build_chain():
    """Assemble the retrieval-augmented generation (RAG) chain.

    Pipeline::

        question -> retriever -> formatted context ┐
                                                   ├-> prompt -> LLM -> str
        question ──────────────────────────────────┘

    Input:
        ``{"question": "<user question>"}`` (a bare question string is
        also accepted, for convenience).
    Returns:
        A runnable whose ``invoke`` result is the model's answer as a
        plain string.
    """
    retriever = get_retriever()
    llm = TritonLLM()

    # RAG prompt template. The system-prompt text is user-facing Korean
    # and is preserved verbatim (it instructs the model to answer from
    # the supplied context and admit when the answer is not present).
    prompt_template = ChatPromptTemplate.from_messages([
        ("system", """다음 정보를 참고하여 사용자의 질문에 답변하세요:

{context}

답변할 수 없는 내용이나 주어진 컨텍스트에 없는 내용이면 솔직하게 모른다고 말하세요.
답변은 주어진 컨텍스트에 기반하여 구체적이고 간결하게 작성하세요."""),
        ("human", "{question}")
    ])

    def extract_question(x):
        # Accept either {"question": "..."} (the documented shape) or a
        # bare question string, so both invocation styles work.
        return x["question"] if isinstance(x, dict) else x

    def format_docs(docs):
        # Label each retrieved document ("문서: N") and separate them with
        # blank lines so the prompt's {context} slot reads as a numbered
        # list of passages.
        return "\n\n".join(
            f"문서: {i+1}\n{doc.page_content}" for i, doc in enumerate(docs)
        )

    rag_chain = (
        {
            "context": RunnableLambda(extract_question) | retriever | format_docs,
            # BUG FIX: the original used RunnablePassthrough() here, which
            # forwards the *whole* input dict, so the prompt's {question}
            # slot was rendered from {'question': '...'} (the dict's repr)
            # instead of the question string. Extract the string, exactly
            # as the "context" branch already does.
            "question": RunnableLambda(extract_question),
        }
        | prompt_template
        | llm
        | StrOutputParser()
    )

    return rag_chain