88 lines
3.3 KiB
Python
88 lines
3.3 KiB
Python
# rag_chain.py
|
|
import os
|
|
import weaviate
|
|
from weaviate import Client
|
|
from weaviate.connect import ConnectionParams
|
|
from weaviate.auth import AuthApiKey
|
|
from weaviate.classes.init import Auth
|
|
from langchain.vectorstores import Weaviate
|
|
from langchain.chains import RetrievalQA
|
|
from langchain.chat_models import ChatOpenAI
|
|
from langchain.embeddings import OpenAIEmbeddings
|
|
|
|
def _connect_weaviate(openai_api_key):
    """Create a Weaviate client from environment settings and verify it is ready.

    Environment variables:
        WEAVIATE_URL: REST endpoint; defaults to ``http://183.111.96.67:32668``.
        WEAVIATE_API_KEY: optional API key; no auth is sent when it is unset.

    NOTE: langchain's legacy ``Weaviate`` vector store (imported at the top of
    this file) only works with the v3-style ``weaviate.Client`` (it checks
    ``isinstance(client, weaviate.Client)`` and queries via ``client.query``),
    so a v4 ``WeaviateClient`` cannot be passed to it. To move to the v4
    client, switch the vector store to ``langchain_weaviate.WeaviateVectorStore``.

    Args:
        openai_api_key: forwarded as the ``X-OpenAI-Api-Key`` header so that
            server-side OpenAI modules (used by text-based search) can work.

    Returns:
        A connected client that reported ready.

    Raises:
        RuntimeError: if the server does not report ready.
    """
    url = os.environ.get("WEAVIATE_URL", "http://183.111.96.67:32668")
    weaviate_api_key = os.environ.get("WEAVIATE_API_KEY")
    auth = AuthApiKey(api_key=weaviate_api_key) if weaviate_api_key else None

    client = Client(
        url=url,
        auth_client_secret=auth,
        additional_headers={"X-OpenAI-Api-Key": openai_api_key},
    )
    # Fail fast: a chain built on an unreachable server would only fail later,
    # on the first query, with a much less helpful error.
    if not client.is_ready():
        raise RuntimeError(f"연결 실패. 서버 상태를 확인하세요. ({url})")
    print("Weaviate 연결 성공!")
    return client


def build_rag_chain():
    """Build a RetrievalQA chain over the ``LangDocs`` class in Weaviate.

    Credentials are read from the environment instead of source code. The
    keys that were previously hard-coded in this file were committed in
    plain text and should be rotated.

    Environment variables:
        OPENAI_API_KEY: required; used for query embeddings and the chat
            model, and forwarded to Weaviate.
        WEAVIATE_URL, WEAVIATE_API_KEY: see ``_connect_weaviate``.

    Returns:
        The configured ``RetrievalQA`` chain. The Weaviate client behind its
        retriever is intentionally left open, because the retriever queries
        it every time the chain runs.

    Raises:
        RuntimeError: if ``OPENAI_API_KEY`` is unset or Weaviate is not ready.
    """
    openai_api_key = os.environ.get("OPENAI_API_KEY")
    if not openai_api_key:
        raise RuntimeError("OPENAI_API_KEY environment variable is not set.")

    # 1. Weaviate client
    client = _connect_weaviate(openai_api_key)

    # 2. Vector store over the existing "LangDocs" class; document text is
    #    stored in its "text" property.
    vectorstore = Weaviate(
        client=client,
        index_name="LangDocs",
        text_key="text",
        embedding=OpenAIEmbeddings(openai_api_key=openai_api_key),
    )

    # 3. OpenAI chat model; temperature=0 to minimize randomness in answers.
    llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)

    # 4. RetrievalQA chain. Do not close `client` here: the retriever uses it
    #    whenever the returned chain is invoked.
    return RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever())
|