# rag-agent-soo/app/rag_chain.py
import os

import weaviate
from langchain.chains import RetrievalQA
from langchain_community.embeddings import HuggingFaceEmbeddings  # swap in another embedding model here if needed
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import Weaviate


def build_rag_chain():
    # 1. Weaviate client (weaviate-client v3 API; the langchain `Weaviate`
    #    vectorstore below expects a v3 client). Secrets are read from the
    #    environment instead of being hard-coded in source.
    client = weaviate.Client(
        url=os.environ.get("WEAVIATE_URL", "http://183.111.96.67:32668"),
        auth_client_secret=weaviate.AuthApiKey(api_key=os.environ["WEAVIATE_API_KEY"]),
        additional_headers={
            "X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY", ""),  # only needed for OpenAI-backed modules
        },
    )
    # 2. Vector store over the "Test" class, embedding queries with MiniLM
    vectorstore = Weaviate(
        client=client,
        index_name="Test",
        text_key="text",
        embedding=HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        ),
    )
    # 3. HuggingFace LLM (e.g. mistralai/Mistral-7B-Instruct-v0.2);
    #    HuggingFaceHub reads HUGGINGFACEHUB_API_TOKEN from the environment
    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        model_kwargs={
            "temperature": 0.1,
            "max_new_tokens": 512,
            "top_p": 0.95,
        },
    )

    retriever = vectorstore.as_retriever()

    # 4. Assemble the RetrievalQA chain ("stuff" packs all retrieved docs into one prompt)
    return RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
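

# Minimal usage sketch, not part of the original module: assumes WEAVIATE_API_KEY
# (and optionally WEAVIATE_URL / OPENAI_API_KEY) plus HUGGINGFACEHUB_API_TOKEN are
# set, and that the "Test" class already holds documents with a "text" property.
# The sample question below is hypothetical.
if __name__ == "__main__":
    qa_chain = build_rag_chain()
    result = qa_chain.invoke({"query": "Summarize the indexed documents."})
    print(result["result"])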