diff --git a/app/chain.py b/app/chain.py
index de68d9c..0ce9a94 100644
--- a/app/chain.py
+++ b/app/chain.py
@@ -3,13 +3,23 @@
 from langchain_core.runnables import RunnablePassthrough
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.runnables import RunnableLambda
-from app.retriever import get_retriever
-from app.llm import TritonLLM
+from retriever import get_retriever
+from llm import TritonLLM
+import json

 def build_chain():
     retriever = get_retriever()
     llm = TritonLLM()

+    def build_chat_input(x):
+        # "context" has already been retrieved and formatted by the chain's
+        # map step below, so reuse it instead of hitting the retriever again.
+        context = x["context"]
+        messages = [
+            {"role": "system", "content": f"Answer the user's question by referring to the following information:\n\n{context}\n\nIf you cannot answer from this information, honestly say that you do not know."},
+            {"role": "user", "content": x["question"]}
+        ]
+
+        # Carry the serialized prompt and the per-request sampling
+        # parameters together so the downstream LLM step can read both.
+        return {
+            "prompt": json.dumps(messages, ensure_ascii=False),
+            "temperature": x.get("temperature"),
+            "max_tokens": x.get("max_tokens")
+        }
+
     # Create the RAG prompt template
     prompt_template = ChatPromptTemplate.from_messages([
         ("system", """Answer the user's question by referring to the following information:
@@ -21,15 +31,20 @@ def build_chain():
         ("human", "{question}")
     ])

+    # Forward the per-request sampling parameters to the model.
+    llm_with_params = RunnableLambda(
+        lambda x: llm.invoke(x["prompt"], temperature=x.get("temperature"), max_tokens=x.get("max_tokens"))
+    )
+
     def format_docs(docs):
         return "\n\n".join(f"Document {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs))

     rag_chain = (
         {
             "context": RunnableLambda(lambda x: x["question"]) | retriever | format_docs,
-            "question": RunnablePassthrough()
+            "question": lambda x: x["question"],
+            "temperature": lambda x: x.get("temperature"),
+            "max_tokens": lambda x: x.get("max_tokens")
         }
-        | prompt_template
-        | llm
+        #| prompt_template
+        | RunnableLambda(build_chat_input)
+        | llm_with_params
         | StrOutputParser()
     )
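
For reference, a minimal usage sketch of the reworked chain (not part of the diff). It assumes the module is importable as `chain`, consistent with the diff's new unprefixed import style, and that `TritonLLM.invoke` accepts the `temperature` and `max_tokens` keyword arguments the diff passes through; the question text is hypothetical.

```python
# Minimal usage sketch -- assumes the module is importable as `chain` and
# that TritonLLM.invoke accepts temperature/max_tokens keyword arguments.
from chain import build_chain

rag_chain = build_chain()

# "temperature" and "max_tokens" are optional: the map step reads them with
# .get(), so omitted keys simply arrive at the LLM step as None.
answer = rag_chain.invoke({
    "question": "What documents are indexed?",  # hypothetical question
    "temperature": 0.2,
    "max_tokens": 512,
})
print(answer)
```

Note the design choice this exercises: because the chain builds its chat input as a JSON string via `build_chat_input` rather than `prompt_template`, per-request parameters have to travel alongside the prompt in a dict, which is why the map step and `llm_with_params` both thread `temperature` and `max_tokens` through.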