Update chain.py
This commit is contained in:
parent
112b324445
commit
d77c7d90d3
23
app/chain.py
23
app/chain.py
@ -3,13 +3,23 @@ from langchain_core.runnables import RunnablePassthrough
|
|||||||
from langchain_core.output_parsers import StrOutputParser
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
from langchain_core.runnables import RunnableLambda
|
from langchain_core.runnables import RunnableLambda
|
||||||
from app.retriever import get_retriever
|
from retriever import get_retriever
|
||||||
from app.llm import TritonLLM
|
from llm import TritonLLM
|
||||||
|
import json
|
||||||
|
|
||||||
def build_chain():
|
def build_chain():
|
||||||
retriever = get_retriever()
|
retriever = get_retriever()
|
||||||
llm = TritonLLM()
|
llm = TritonLLM()
|
||||||
|
|
||||||
|
def build_chat_input(x):
|
||||||
|
context = format_docs(retriever.invoke(x["question"]))
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": f"다음 정보를 참고하여 사용자의 질문에 답변하세요:\n\n{x["context"]}\n\n답변할 수 없는 내용이면 솔직하게 모른다고 말하세요."},
|
||||||
|
{"role": "user", "content": x["question"]}
|
||||||
|
]
|
||||||
|
|
||||||
|
return json.dumps(messages, ensure_ascii=False)
|
||||||
|
|
||||||
# RAG 프롬프트 템플릿 생성
|
# RAG 프롬프트 템플릿 생성
|
||||||
prompt_template = ChatPromptTemplate.from_messages([
|
prompt_template = ChatPromptTemplate.from_messages([
|
||||||
("system", """다음 정보를 참고하여 사용자의 질문에 답변하세요:
|
("system", """다음 정보를 참고하여 사용자의 질문에 답변하세요:
|
||||||
@ -21,15 +31,20 @@ def build_chain():
|
|||||||
("human", "{question}")
|
("human", "{question}")
|
||||||
])
|
])
|
||||||
|
|
||||||
|
llm_with_params = RunnableLambda(
|
||||||
|
lambda x: llm.invoke(x["question"], temperature=x.get("temperature"), max_tokens=x.get("max_tokens"))
|
||||||
|
)
|
||||||
|
|
||||||
def format_docs(docs):
|
def format_docs(docs):
|
||||||
return "\n\n".join(f"문서: {i+1}\n{doc.page_content}" for i, doc in enumerate(docs))
|
return "\n\n".join(f"문서: {i+1}\n{doc.page_content}" for i, doc in enumerate(docs))
|
||||||
|
|
||||||
rag_chain = (
|
rag_chain = (
|
||||||
{
|
{
|
||||||
"context": RunnableLambda(lambda x: x["question"]) | retriever | format_docs,
|
"context": RunnableLambda(lambda x: x["question"]) | retriever | format_docs,
|
||||||
"question": RunnablePassthrough()
|
"question": lambda x: x["question"]
|
||||||
}
|
}
|
||||||
| prompt_template
|
#| prompt_template
|
||||||
|
| RunnableLambda(build_chat_input)
|
||||||
| llm
|
| llm
|
||||||
| StrOutputParser()
|
| StrOutputParser()
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user