diff --git a/app/sample.py b/app/sample.py new file mode 100644 index 0000000..791e671 --- /dev/null +++ b/app/sample.py @@ -0,0 +1,342 @@ +""" +LangChain RAG 시스템 with WeaviateVectorStore, 커스텀 임베딩 서비스, Triton Inference Server LLM +langchain_weaviate 패키지 사용 +""" + +import os +from typing import List, Dict, Any + +import requests +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.language_models.llms import LLM +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.retrievers import BaseRetriever +from langchain_core.runnables import RunnablePassthrough +from langchain_weaviate import WeaviateVectorStore # 새로운 import 방식 +from langchain.schema import AIMessage, HumanMessage +from langchain_community.document_loaders import TextLoader, DirectoryLoader +from langchain_text_splitters import RecursiveCharacterTextSplitter +import weaviate +from pydantic import BaseModel, Field +from fastapi import FastAPI + +# LangServe 임포트 +from langserve import add_routes + + +# 커스텀 임베딩 모델 클래스 정의 +class CustomEmbeddingModel(Embeddings): + """커스텀 엔드포인트를 사용하는 임베딩 모델""" + + def __init__(self, api_url: str): + self.api_url = api_url + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """문서 리스트를 임베딩합니다.""" + embeddings = [] + for text in texts: + embeddings.append(self.embed_query(text)) + return embeddings + + def embed_query(self, text: str) -> List[float]: + """쿼리 텍스트를 임베딩합니다.""" + response = requests.post( + self.api_url, + json={"text": text} + ) + if response.status_code != 200: + raise ValueError(f"임베딩 API 호출 실패: {response.status_code}, {response.text}") + + return response.json()["embedding"] + + +# 커스텀 LLM 클래스 정의 (Triton Inference Server) +class TritonLLM(LLM): + """Triton Inference Server를 사용하는 LLM""" + + def __init__(self, api_url: str): + super().__init__() + self.api_url = api_url + + def _call(self, prompt: str, stop: List[str] = None, **kwargs) -> str: + """LLM에 프롬프트를 전송하고 응답을 받습니다.""" + payload = { + "prompt": prompt, + "max_tokens": kwargs.get("max_tokens", 1024), + "temperature": kwargs.get("temperature", 0.7), + } + + if stop: + payload["stop"] = stop + + response = requests.post(self.api_url, json=payload) + + if response.status_code != 200: + raise ValueError(f"LLM API 호출 실패: {response.status_code}, {response.text}") + + return response.json()["text"] + + @property + def _llm_type(self) -> str: + return "triton-custom-llm" + + +# 문서를 로드하고 청크로 분할하는 함수 +def load_and_split_documents(directory_path: str) -> List[Document]: + """디렉토리에서 문서를 로드하고 청크로 분할합니다.""" + loader = DirectoryLoader(directory_path, glob="**/*.txt", loader_cls=TextLoader) + documents = loader.load() + + splitter = RecursiveCharacterTextSplitter( + chunk_size=1000, + chunk_overlap=200, + ) + + return splitter.split_documents(documents) + + +# RAG 파이프라인 생성 함수 +def create_rag_pipeline( + weaviate_url: str, + embedding_url: str, + llm_url: str, + index_name: str = "Knowledge" +): + """LangChain RAG 파이프라인을 생성합니다.""" + + # 임베딩 모델 인스턴스 생성 + embeddings = CustomEmbeddingModel(api_url=embedding_url) + + # Weaviate 클라이언트 설정 + client = weaviate.Client( + url=weaviate_url + ) + + # WeaviateVectorStore 생성 - 새로운 방식 적용 + vectorstore = WeaviateVectorStore( + client=client, + index_name=index_name, + text_key="content", # 필요에 따라 텍스트 필드 키 설정 + embedding=embeddings, + metadatas_key="metadata", # 메타데이터 필드 키 설정 + attributes=["source", "page"] # 검색 시 포함할 메타데이터 속성 + ) + + # 리트리버 생성 + retriever = vectorstore.as_retriever( + search_kwargs={"k": 5, "score_threshold": 0.7} # 유사도 임계값 추가 + ) + + # LLM 인스턴스 생성 + llm = TritonLLM(api_url=llm_url) + + # RAG 프롬프트 템플릿 생성 + prompt_template = ChatPromptTemplate.from_messages([ + ("system", """다음 정보를 참고하여 사용자의 질문에 답변하세요: + +{context} + +답변할 수 없는 내용이나 주어진 컨텍스트에 없는 내용이면 솔직하게 모른다고 말하세요. +답변은 주어진 컨텍스트에 기반하여 구체적이고 간결하게 작성하세요."""), + ("human", "{question}") + ]) + + # RAG 파이프라인 조립 + def format_docs(docs): + return "\n\n".join(f"문서: {i+1}\n{doc.page_content}" for i, doc in enumerate(docs)) + + rag_chain = ( + {"context": retriever | format_docs, "question": RunnablePassthrough()} + | prompt_template + | llm + | StrOutputParser() + ) + + return rag_chain + + +# 문서 인덱싱 함수 - WeaviateVectorStore 사용 방식 +def index_documents( + documents: List[Document], + weaviate_url: str, + embedding_url: str, + index_name: str = "Knowledge" +): + """문서를 Weaviate에 인덱싱합니다.""" + embeddings = CustomEmbeddingModel(api_url=embedding_url) + + # Weaviate 클라이언트 생성 + client = weaviate.Client( + url=weaviate_url + ) + + # 클래스 존재 여부 확인 및 삭제 (필요한 경우) + try: + client.schema.get(index_name) + client.schema.delete_class(index_name) + print(f"기존 '{index_name}' 클래스를 삭제했습니다.") + except Exception as e: + # 클래스가 존재하지 않는 경우 무시 + pass + + # 클래스 생성 - WeaviateVectorStore에 맞는 방식으로 스키마 정의 + class_obj = { + "class": index_name, + "vectorizer": "none", # 외부 임베딩을 사용하므로 none + "properties": [ + { + "name": "content", + "dataType": ["text"] + }, + { + "name": "metadata", + "dataType": ["object"] + } + ] + } + + client.schema.create_class(class_obj) + print(f"'{index_name}' 클래스를 생성했습니다.") + + # WeaviateVectorStore 생성 및 문서 추가 + vectorstore = WeaviateVectorStore.from_documents( + documents=documents, + embedding=embeddings, + client=client, + index_name=index_name, + text_key="content", + metadatas_key="metadata" + ) + + print(f"{len(documents)}개의 문서가 성공적으로 인덱싱되었습니다.") + return vectorstore + + +# Pydantic 모델 (API 입력용) +class QueryInput(BaseModel): + """쿼리 입력 모델""" + question: str = Field(..., description="사용자 질문") + + +# 메인 애플리케이션 +def create_application(): + """FastAPI 애플리케이션을 생성하고 LangServe 라우트를 추가합니다.""" + app = FastAPI( + title="RAG API with LangChain and LangServe", + version="1.0", + description="문서 검색과 질문 답변을 위한 RAG API" + ) + + # 환경 변수 설정 (실제 사용 시 환경 변수나 설정 파일 사용 권장) + WEAVIATE_URL = "http://183.111.96.67:32401" + EMBEDDING_URL = "http://183.111.86.67:21212/embeding" # 원본 URL 유지 + LLM_URL = "http://183.111.85.21:121212" + INDEX_NAME = "Knowledge" + + # RAG 파이프라인 생성 + rag_chain = create_rag_pipeline( + weaviate_url=WEAVIATE_URL, + embedding_url=EMBEDDING_URL, + llm_url=LLM_URL, + index_name=INDEX_NAME + ) + + # LangServe 라우트 추가 + add_routes( + app, + rag_chain, + path="/rag", + input_type=QueryInput + ) + + # 상태 확인 엔드포인트 + @app.get("/") + def read_root(): + return {"status": "online", "service": "RAG API"} + + return app + + +# 문서 인덱싱을 위한 스크립트 (별도로 실행) +def index_documents_script(): + """문서를 인덱싱하기 위한 스크립트""" + import argparse + + parser = argparse.ArgumentParser(description="문서를 Weaviate에 인덱싱합니다.") + parser.add_argument("--dir", type=str, required=True, help="문서가 있는 디렉토리 경로") + args = parser.parse_args() + + # 환경 변수 설정 (실제 사용 시 환경 변수나 설정 파일 사용 권장) + WEAVIATE_URL = "http://183.111.96.67:32401" + EMBEDDING_URL = "http://183.111.86.67:21212/embeding" # 원본 URL 유지 + INDEX_NAME = "Knowledge" + + # 문서 로드 및 분할 + documents = load_and_split_documents(args.dir) + print(f"{len(documents)}개의 문서 청크가 로드되었습니다.") + + # 문서 인덱싱 + index_documents( + documents=documents, + weaviate_url=WEAVIATE_URL, + embedding_url=EMBEDDING_URL, + index_name=INDEX_NAME + ) + + +# 고급 기능: 메타데이터 필터링 기능이 있는 쿼리 함수 +def query_with_metadata_filter( + question: str, + filter_dict: Dict[str, Any], + weaviate_url: str, + embedding_url: str, + llm_url: str, + index_name: str = "Knowledge" +): + """메타데이터 필터링을 사용한 쿼리 함수""" + embeddings = CustomEmbeddingModel(api_url=embedding_url) + client = weaviate.Client(url=weaviate_url) + + # WeaviateVectorStore 생성 + vectorstore = WeaviateVectorStore( + client=client, + index_name=index_name, + text_key="content", + embedding=embeddings, + metadatas_key="metadata", + attributes=["source", "page"] + ) + + # 메타데이터 필터를 사용한 검색 + docs = vectorstore.similarity_search( + query=question, + k=5, + where_filter=filter_dict # 예: {"metadata.source": "특정파일.txt"} + ) + + # LLM으로 답변 생성 + llm = TritonLLM(api_url=llm_url) + + prompt_template = ChatPromptTemplate.from_messages([ + ("system", "다음 정보를 참고하여 사용자의 질문에 답변하세요:\n\n{context}"), + ("human", "{question}") + ]) + + context_text = "\n\n".join(f"문서: {i+1}\n{doc.page_content}" for i, doc in enumerate(docs)) + + response = prompt_template.format( + context=context_text, + question=question + ) + + return llm.invoke(response) + + +# LangServe 서버 실행을 위한 메인 함수 +if __name__ == "__main__": + import uvicorn + + app = create_application() + uvicorn.run(app, host="0.0.0.0", port=8000)