Update rag_chain.py

This commit is contained in:
localsoo 2025-04-21 00:54:18 +00:00
parent ef877215ae
commit 901baee508

@ -6,6 +6,8 @@ from langchain.embeddings import HuggingFaceEmbeddings # 필요 시 HuggingFace
from langchain.vectorstores import Weaviate from langchain.vectorstores import Weaviate
from langchain.chains import RetrievalQA from langchain.chains import RetrievalQA
from langchain.llms import HuggingFaceHub from langchain.llms import HuggingFaceHub
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
def build_rag_chain(): def build_rag_chain():
# 1. Weaviate 클라이언트 # 1. Weaviate 클라이언트
@ -22,24 +24,16 @@ def build_rag_chain():
# 2. 벡터스토어 # 2. 벡터스토어
vectorstore = Weaviate( vectorstore = Weaviate(
client=client, client=client,
index_name="Test", index_name="LangDocs",
text_key="text", text_key="text",
embedding= HuggingFaceEmbeddings( embedding=OpenAIEmbeddings()
model_name="sentence-transformers/all-MiniLM-L6-v2"
)
) )
# 3. HuggingFace LLM (예: mistralai/Mistral-7B-Instruct-v0.2) # 3. HuggingFace LLM (예: mistralai/Mistral-7B-Instruct-v0.2)
llm = HuggingFaceHub( llm = ChatOpenAI(temperature=0)
repo_id="mistralai/Mistral-7B-Instruct-v0.2",
model_kwargs={
"temperature": 0.1,
"max_new_tokens": 512,
"top_p": 0.95,
}
)
retriever = vectorstore.as_retriever() retriever = vectorstore.as_retriever()
# 4. RetrievalQA chain 구성 # 4. RetrievalQA chain 구성
return RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
return RetrievalQA.from_chain_type(llm=llm, retriever=retriever)