38 lines
1.3 KiB
Python
38 lines
1.3 KiB
Python
from langchain.chains import RetrievalQA
|
|
from langchain_core.runnables import RunnablePassthrough
|
|
from langchain_core.output_parsers import StrOutputParser
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_core.runnables import RunnableLambda
|
|
from app.retriever import get_retriever
|
|
from app.llm import TritonLLM
|
|
|
|
def build_chain():
|
|
retriever = get_retriever()
|
|
llm = TritonLLM()
|
|
|
|
# RAG 프롬프트 템플릿 생성
|
|
prompt_template = ChatPromptTemplate.from_messages([
|
|
("system", """다음 정보를 참고하여 사용자의 질문에 답변하세요:
|
|
|
|
{context}
|
|
|
|
답변할 수 없는 내용이나 주어진 컨텍스트에 없는 내용이면 솔직하게 모른다고 말하세요.
|
|
답변은 주어진 컨텍스트에 기반하여 구체적이고 간결하게 작성하세요."""),
|
|
("human", "{question}")
|
|
])
|
|
|
|
def format_docs(docs):
|
|
return "\n\n".join(f"문서: {i+1}\n{doc.page_content}" for i, doc in enumerate(docs))
|
|
|
|
rag_chain = (
|
|
{
|
|
"context": RunnableLambda(lambda x: x["question"]) | retriever | format_docs,
|
|
"question": RunnablePassthrough()
|
|
}
|
|
| prompt_template
|
|
| llm
|
|
| StrOutputParser()
|
|
)
|
|
|
|
return rag_chain
|