# rag_chain.py
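"""Build a RetrievalQA chain over a Weaviate vector store.

Pipeline: Weaviate client -> Weaviate vector store (HuggingFace MiniLM
embeddings) -> HuggingFaceHub LLM (Mistral-7B-Instruct) -> RetrievalQA
"stuff" chain.
"""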
import os

import weaviate
from langchain.embeddings import HuggingFaceEmbeddings  # swap in another embedding class here if needed
from langchain.vectorstores import Weaviate
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFaceHub

def build_rag_chain():
    # 1. Weaviate client
    client = weaviate.Client(
        url=os.getenv("WEAVIATE_URL", "http://183.111.96.67:30846"),
        # Secrets come from the environment; never hardcode API keys.
        auth_client_secret=weaviate.AuthApiKey(os.environ["WEAVIATE_API_KEY"]),
        # HuggingFace token for Weaviate's HF integration; the same
        # variable is read by HuggingFaceHub below.
        additional_headers={
            "X-HuggingFace-Api-Key": os.environ["HUGGINGFACEHUB_API_TOKEN"]
        },
    )
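
    # Optional sanity check (assumes the instance is reachable);
    # client.is_ready() returns True once Weaviate accepts requests.
    # assert client.is_ready()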

    # 2. Vector store backed by the existing "LangDocs" index
    vectorstore = Weaviate(
        client=client,
        index_name="LangDocs",
        text_key="text",
        embedding=HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        ),
    )

    # 3. HuggingFace-hosted LLM (e.g., mistralai/Mistral-7B-Instruct-v0.2)
    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        model_kwargs={
            "temperature": 0.1,     # near-deterministic answers
            "max_new_tokens": 512,  # cap on generated length
            "top_p": 0.95,          # nucleus sampling threshold
        },
    )

    retriever = vectorstore.as_retriever()
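    # as_retriever() also accepts search parameters, e.g.
    # vectorstore.as_retriever(search_kwargs={"k": 4}) to control how
    # many chunks are fetched per query ("k": 4 is an illustrative value).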

    # 4. Assemble the RetrievalQA chain; "stuff" inserts all retrieved
    # documents into a single prompt.
    return RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
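

if __name__ == "__main__":
    # Minimal usage sketch: assumes the Weaviate instance above is
    # reachable, the "LangDocs" index is populated, and WEAVIATE_API_KEY /
    # HUGGINGFACEHUB_API_TOKEN are set in the environment. The question
    # string is illustrative only.
    chain = build_rag_chain()
    print(chain.run("What do the indexed documents cover?"))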