diff --git a/app/rag_chain.py b/app/rag_chain.py
index 9147dab..99f6425 100644
--- a/app/rag_chain.py
+++ b/app/rag_chain.py
@@ -6,6 +6,8 @@ from langchain.embeddings import HuggingFaceEmbeddings # 필요 시 HuggingFace
 from langchain.vectorstores import Weaviate
 from langchain.chains import RetrievalQA
 from langchain.llms import HuggingFaceHub
+from langchain.chat_models import ChatOpenAI
+from langchain.embeddings import OpenAIEmbeddings
 
 def build_rag_chain():
     # 1. Weaviate 클라이언트
@@ -22,24 +24,16 @@ def build_rag_chain():
     # 2. 벡터스토어
     vectorstore = Weaviate(
         client=client,
-        index_name="Test",
+        index_name="LangDocs",
         text_key="text",
-        embedding= HuggingFaceEmbeddings(
-            model_name="sentence-transformers/all-MiniLM-L6-v2"
-        )
+        embedding=OpenAIEmbeddings()
     )
 
-    # 3. HuggingFace LLM (예: mistralai/Mistral-7B-Instruct-v0.2)
-    llm = HuggingFaceHub(
-        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
-        model_kwargs={
-            "temperature": 0.1,
-            "max_new_tokens": 512,
-            "top_p": 0.95,
-        }
-    )
+    # 3. OpenAI chat LLM (replaces the HuggingFaceHub Mistral model)
+    llm = ChatOpenAI(temperature=0)
     retriever = vectorstore.as_retriever()
 
     # 4. RetrievalQA chain 구성
-    return RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
+
+    return RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
 