343 lines
11 KiB
Python
343 lines
11 KiB
Python
"""
|
|
LangChain RAG 시스템 with WeaviateVectorStore, 커스텀 임베딩 서비스, Triton Inference Server LLM
|
|
langchain_weaviate 패키지 사용
|
|
"""
|
|
|
|
import os
|
|
from typing import List, Dict, Any
|
|
|
|
import requests
|
|
from langchain_core.documents import Document
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_core.language_models.llms import LLM
|
|
from langchain_core.output_parsers import StrOutputParser
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_core.retrievers import BaseRetriever
|
|
from langchain_core.runnables import RunnablePassthrough
|
|
from langchain_weaviate import WeaviateVectorStore # 새로운 import 방식
|
|
from langchain.schema import AIMessage, HumanMessage
|
|
from langchain_community.document_loaders import TextLoader, DirectoryLoader
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
import weaviate
|
|
from pydantic import BaseModel, Field
|
|
from fastapi import FastAPI
|
|
|
|
# LangServe 임포트
|
|
from langserve import add_routes
|
|
|
|
|
|
# 커스텀 임베딩 모델 클래스 정의
|
|
class CustomEmbeddingModel(Embeddings):
|
|
"""커스텀 엔드포인트를 사용하는 임베딩 모델"""
|
|
|
|
def __init__(self, api_url: str):
|
|
self.api_url = api_url
|
|
|
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
"""문서 리스트를 임베딩합니다."""
|
|
embeddings = []
|
|
for text in texts:
|
|
embeddings.append(self.embed_query(text))
|
|
return embeddings
|
|
|
|
def embed_query(self, text: str) -> List[float]:
|
|
"""쿼리 텍스트를 임베딩합니다."""
|
|
response = requests.post(
|
|
self.api_url,
|
|
json={"text": text}
|
|
)
|
|
if response.status_code != 200:
|
|
raise ValueError(f"임베딩 API 호출 실패: {response.status_code}, {response.text}")
|
|
|
|
return response.json()["embedding"]
|
|
|
|
|
|
# 커스텀 LLM 클래스 정의 (Triton Inference Server)
|
|
class TritonLLM(LLM):
|
|
"""Triton Inference Server를 사용하는 LLM"""
|
|
|
|
def __init__(self, api_url: str):
|
|
super().__init__()
|
|
self.api_url = api_url
|
|
|
|
def _call(self, prompt: str, stop: List[str] = None, **kwargs) -> str:
|
|
"""LLM에 프롬프트를 전송하고 응답을 받습니다."""
|
|
payload = {
|
|
"prompt": prompt,
|
|
"max_tokens": kwargs.get("max_tokens", 1024),
|
|
"temperature": kwargs.get("temperature", 0.7),
|
|
}
|
|
|
|
if stop:
|
|
payload["stop"] = stop
|
|
|
|
response = requests.post(self.api_url, json=payload)
|
|
|
|
if response.status_code != 200:
|
|
raise ValueError(f"LLM API 호출 실패: {response.status_code}, {response.text}")
|
|
|
|
return response.json()["text"]
|
|
|
|
@property
|
|
def _llm_type(self) -> str:
|
|
return "triton-custom-llm"
|
|
|
|
|
|
# 문서를 로드하고 청크로 분할하는 함수
|
|
def load_and_split_documents(directory_path: str) -> List[Document]:
|
|
"""디렉토리에서 문서를 로드하고 청크로 분할합니다."""
|
|
loader = DirectoryLoader(directory_path, glob="**/*.txt", loader_cls=TextLoader)
|
|
documents = loader.load()
|
|
|
|
splitter = RecursiveCharacterTextSplitter(
|
|
chunk_size=1000,
|
|
chunk_overlap=200,
|
|
)
|
|
|
|
return splitter.split_documents(documents)
|
|
|
|
|
|
# RAG 파이프라인 생성 함수
|
|
def create_rag_pipeline(
|
|
weaviate_url: str,
|
|
embedding_url: str,
|
|
llm_url: str,
|
|
index_name: str = "Knowledge"
|
|
):
|
|
"""LangChain RAG 파이프라인을 생성합니다."""
|
|
|
|
# 임베딩 모델 인스턴스 생성
|
|
embeddings = CustomEmbeddingModel(api_url=embedding_url)
|
|
|
|
# Weaviate 클라이언트 설정
|
|
client = weaviate.Client(
|
|
url=weaviate_url
|
|
)
|
|
|
|
# WeaviateVectorStore 생성 - 새로운 방식 적용
|
|
vectorstore = WeaviateVectorStore(
|
|
client=client,
|
|
index_name=index_name,
|
|
text_key="content", # 필요에 따라 텍스트 필드 키 설정
|
|
embedding=embeddings,
|
|
metadatas_key="metadata", # 메타데이터 필드 키 설정
|
|
attributes=["source", "page"] # 검색 시 포함할 메타데이터 속성
|
|
)
|
|
|
|
# 리트리버 생성
|
|
retriever = vectorstore.as_retriever(
|
|
search_kwargs={"k": 5, "score_threshold": 0.7} # 유사도 임계값 추가
|
|
)
|
|
|
|
# LLM 인스턴스 생성
|
|
llm = TritonLLM(api_url=llm_url)
|
|
|
|
# RAG 프롬프트 템플릿 생성
|
|
prompt_template = ChatPromptTemplate.from_messages([
|
|
("system", """다음 정보를 참고하여 사용자의 질문에 답변하세요:
|
|
|
|
{context}
|
|
|
|
답변할 수 없는 내용이나 주어진 컨텍스트에 없는 내용이면 솔직하게 모른다고 말하세요.
|
|
답변은 주어진 컨텍스트에 기반하여 구체적이고 간결하게 작성하세요."""),
|
|
("human", "{question}")
|
|
])
|
|
|
|
# RAG 파이프라인 조립
|
|
def format_docs(docs):
|
|
return "\n\n".join(f"문서: {i+1}\n{doc.page_content}" for i, doc in enumerate(docs))
|
|
|
|
rag_chain = (
|
|
{"context": retriever | format_docs, "question": RunnablePassthrough()}
|
|
| prompt_template
|
|
| llm
|
|
| StrOutputParser()
|
|
)
|
|
|
|
return rag_chain
|
|
|
|
|
|
# 문서 인덱싱 함수 - WeaviateVectorStore 사용 방식
|
|
def index_documents(
|
|
documents: List[Document],
|
|
weaviate_url: str,
|
|
embedding_url: str,
|
|
index_name: str = "Knowledge"
|
|
):
|
|
"""문서를 Weaviate에 인덱싱합니다."""
|
|
embeddings = CustomEmbeddingModel(api_url=embedding_url)
|
|
|
|
# Weaviate 클라이언트 생성
|
|
client = weaviate.Client(
|
|
url=weaviate_url
|
|
)
|
|
|
|
# 클래스 존재 여부 확인 및 삭제 (필요한 경우)
|
|
try:
|
|
client.schema.get(index_name)
|
|
client.schema.delete_class(index_name)
|
|
print(f"기존 '{index_name}' 클래스를 삭제했습니다.")
|
|
except Exception as e:
|
|
# 클래스가 존재하지 않는 경우 무시
|
|
pass
|
|
|
|
# 클래스 생성 - WeaviateVectorStore에 맞는 방식으로 스키마 정의
|
|
class_obj = {
|
|
"class": index_name,
|
|
"vectorizer": "none", # 외부 임베딩을 사용하므로 none
|
|
"properties": [
|
|
{
|
|
"name": "content",
|
|
"dataType": ["text"]
|
|
},
|
|
{
|
|
"name": "metadata",
|
|
"dataType": ["object"]
|
|
}
|
|
]
|
|
}
|
|
|
|
client.schema.create_class(class_obj)
|
|
print(f"'{index_name}' 클래스를 생성했습니다.")
|
|
|
|
# WeaviateVectorStore 생성 및 문서 추가
|
|
vectorstore = WeaviateVectorStore.from_documents(
|
|
documents=documents,
|
|
embedding=embeddings,
|
|
client=client,
|
|
index_name=index_name,
|
|
text_key="content",
|
|
metadatas_key="metadata"
|
|
)
|
|
|
|
print(f"{len(documents)}개의 문서가 성공적으로 인덱싱되었습니다.")
|
|
return vectorstore
|
|
|
|
|
|
# Pydantic 모델 (API 입력용)
|
|
class QueryInput(BaseModel):
|
|
"""쿼리 입력 모델"""
|
|
question: str = Field(..., description="사용자 질문")
|
|
|
|
|
|
# 메인 애플리케이션
|
|
def create_application():
|
|
"""FastAPI 애플리케이션을 생성하고 LangServe 라우트를 추가합니다."""
|
|
app = FastAPI(
|
|
title="RAG API with LangChain and LangServe",
|
|
version="1.0",
|
|
description="문서 검색과 질문 답변을 위한 RAG API"
|
|
)
|
|
|
|
# 환경 변수 설정 (실제 사용 시 환경 변수나 설정 파일 사용 권장)
|
|
WEAVIATE_URL = "http://183.111.96.67:32401"
|
|
EMBEDDING_URL = "http://183.111.86.67:21212/embeding" # 원본 URL 유지
|
|
LLM_URL = "http://183.111.85.21:121212"
|
|
INDEX_NAME = "Knowledge"
|
|
|
|
# RAG 파이프라인 생성
|
|
rag_chain = create_rag_pipeline(
|
|
weaviate_url=WEAVIATE_URL,
|
|
embedding_url=EMBEDDING_URL,
|
|
llm_url=LLM_URL,
|
|
index_name=INDEX_NAME
|
|
)
|
|
|
|
# LangServe 라우트 추가
|
|
add_routes(
|
|
app,
|
|
rag_chain,
|
|
path="/rag",
|
|
input_type=QueryInput
|
|
)
|
|
|
|
# 상태 확인 엔드포인트
|
|
@app.get("/")
|
|
def read_root():
|
|
return {"status": "online", "service": "RAG API"}
|
|
|
|
return app
|
|
|
|
|
|
# 문서 인덱싱을 위한 스크립트 (별도로 실행)
|
|
def index_documents_script():
|
|
"""문서를 인덱싱하기 위한 스크립트"""
|
|
import argparse
|
|
|
|
parser = argparse.ArgumentParser(description="문서를 Weaviate에 인덱싱합니다.")
|
|
parser.add_argument("--dir", type=str, required=True, help="문서가 있는 디렉토리 경로")
|
|
args = parser.parse_args()
|
|
|
|
# 환경 변수 설정 (실제 사용 시 환경 변수나 설정 파일 사용 권장)
|
|
WEAVIATE_URL = "http://183.111.96.67:32401"
|
|
EMBEDDING_URL = "http://183.111.86.67:21212/embeding" # 원본 URL 유지
|
|
INDEX_NAME = "Knowledge"
|
|
|
|
# 문서 로드 및 분할
|
|
documents = load_and_split_documents(args.dir)
|
|
print(f"{len(documents)}개의 문서 청크가 로드되었습니다.")
|
|
|
|
# 문서 인덱싱
|
|
index_documents(
|
|
documents=documents,
|
|
weaviate_url=WEAVIATE_URL,
|
|
embedding_url=EMBEDDING_URL,
|
|
index_name=INDEX_NAME
|
|
)
|
|
|
|
|
|
# 고급 기능: 메타데이터 필터링 기능이 있는 쿼리 함수
|
|
def query_with_metadata_filter(
|
|
question: str,
|
|
filter_dict: Dict[str, Any],
|
|
weaviate_url: str,
|
|
embedding_url: str,
|
|
llm_url: str,
|
|
index_name: str = "Knowledge"
|
|
):
|
|
"""메타데이터 필터링을 사용한 쿼리 함수"""
|
|
embeddings = CustomEmbeddingModel(api_url=embedding_url)
|
|
client = weaviate.Client(url=weaviate_url)
|
|
|
|
# WeaviateVectorStore 생성
|
|
vectorstore = WeaviateVectorStore(
|
|
client=client,
|
|
index_name=index_name,
|
|
text_key="content",
|
|
embedding=embeddings,
|
|
metadatas_key="metadata",
|
|
attributes=["source", "page"]
|
|
)
|
|
|
|
# 메타데이터 필터를 사용한 검색
|
|
docs = vectorstore.similarity_search(
|
|
query=question,
|
|
k=5,
|
|
where_filter=filter_dict # 예: {"metadata.source": "특정파일.txt"}
|
|
)
|
|
|
|
# LLM으로 답변 생성
|
|
llm = TritonLLM(api_url=llm_url)
|
|
|
|
prompt_template = ChatPromptTemplate.from_messages([
|
|
("system", "다음 정보를 참고하여 사용자의 질문에 답변하세요:\n\n{context}"),
|
|
("human", "{question}")
|
|
])
|
|
|
|
context_text = "\n\n".join(f"문서: {i+1}\n{doc.page_content}" for i, doc in enumerate(docs))
|
|
|
|
response = prompt_template.format(
|
|
context=context_text,
|
|
question=question
|
|
)
|
|
|
|
return llm.invoke(response)
|
|
|
|
|
|
# LangServe 서버 실행을 위한 메인 함수
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
|
|
app = create_application()
|
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|