import os

import httpx
import weaviate

from langchain.embeddings.base import Embeddings
from langchain.llms import HuggingFaceEndpoint
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnableMap
from langchain.vectorstores import Weaviate

# Settings read from environment variables
LLM_ENDPOINT = os.environ.get('LLM_ENDPOINT', '')  # e.g. https://deploymodel.cheetah.svc.cluster.local
EMBEDDING_ENDPOINT = os.environ.get('EMBEDDING_ENDPOINT', '')  # e.g. https://embedding.cheetah.svc.cluster.local
VECTOR_DATABASE_ENDPOINT = os.environ.get('VECTOR_DATABASE_ENDPOINT', '')  # e.g. https://vector.cheetah.svc.cluster.local
VECTOR_DATABASE_APIKEY = os.environ.get('VECTOR_DATABASE_APIKEY', '')  # API key for the vector database

# Custom embedding class (uses the on-prem embedding server)
class CustomRemoteEmbeddings(Embeddings):
    def __init__(self, endpoint_url: str):
        self.endpoint_url = endpoint_url

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        # Assumes the server accepts {"inputs": [...]} and returns {"embeddings": [[...], ...]}
        response = httpx.post(
            self.endpoint_url,
            json={"inputs": texts}
        )
        response.raise_for_status()
        return response.json()["embeddings"]

    def embed_query(self, text: str) -> list[float]:
        return self.embed_documents([text])[0]
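
# A quick, hypothetical sanity check of the embedding server contract assumed above
# (not part of the original listing; run manually once the endpoint is reachable):
#
#   resp = httpx.post(EMBEDDING_ENDPOINT, json={"inputs": ["hello world"]})
#   resp.raise_for_status()
#   print(len(resp.json()["embeddings"][0]))  # embedding dimensionality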

# LLM served from a Triton Inference Server
llm = HuggingFaceEndpoint(
    endpoint_url=LLM_ENDPOINT,  # Triton inference REST endpoint
    task="text-generation",
    model_kwargs={"temperature": 0.7, "max_new_tokens": 512}
)

# Custom embedding instance
embedding_model = CustomRemoteEmbeddings(endpoint_url=EMBEDDING_ENDPOINT)

# Weaviate vector DB
# (the API key was read above but never used in the original; wired in here, assuming the v3 weaviate-client)
auth = weaviate.AuthApiKey(api_key=VECTOR_DATABASE_APIKEY) if VECTOR_DATABASE_APIKEY else None
client = weaviate.Client(url=VECTOR_DATABASE_ENDPOINT, auth_client_secret=auth)
vectorstore = Weaviate(client=client, index_name="LangChainIndex", text_key="text", embedding=embedding_model)
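
# The index above is only ever queried; a minimal, hypothetical sketch for populating it
# through the LangChain wrapper's add_texts() (not part of the original listing):
#
#   vectorstore.add_texts(
#       ["first document text", "second document text"],
#       metadatas=[{"source": "doc-1"}, {"source": "doc-2"}],
#   )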

# Build the RAG prompt + chain
# It would improve extensibility if the template used here could later be pulled from an
# external prompt source instead of being decided in code (see the sketch below).
prompt = PromptTemplate.from_template("Answer the question based on the following context:\n{context}\nQuestion: {question}")
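
# One possible way to make the template swappable without a code change (a hypothetical
# sketch, not part of the original setup): read it from an environment variable and fall
# back to the hard-coded string above, e.g.
#
#   PROMPT_TEMPLATE = os.environ.get(
#       "PROMPT_TEMPLATE",  # hypothetical variable name
#       "Answer the question based on the following context:\n{context}\nQuestion: {question}",
#   )
#   prompt = PromptTemplate.from_template(PROMPT_TEMPLATE)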

chain = RunnableMap({
    "context": lambda x: "\n".join(doc.page_content for doc in vectorstore.similarity_search(x["question"], k=3)),
    "question": lambda x: x["question"]
}) | prompt | llm
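
# Example invocation (assuming the standard LangChain Runnable interface; the question
# below is only a placeholder):
if __name__ == "__main__":
    answer = chain.invoke({"question": "What does this service do?"})
    print(answer)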