Upload chain.py

This commit is contained in:
localsoo 2025-04-18 01:39:46 +00:00
parent 2ac1c4a167
commit 052ffd587e

51
app/chain.py Normal file

@ -0,0 +1,51 @@
from langchain.schema.runnable import RunnableMap
from langchain.vectorstores import Weaviate
from langchain.llms import HuggingFaceEndpoint
from langchain.prompts import PromptTemplate
from langchain.embeddings.base import Embeddings
import httpx
import weaviate
import os
# 환경 변수에서 세팅
LLM_ENDPOINT = os.environ.get('LLM_ENDPOINT', '') # https://deploymodel.cheetah.svc.cluster.local
EMBEDDING_ENDPOINT = os.environ.get('EMBEDDING_ENDPOINT', '') # https://embedding.cheetah.svc.cluster.local
VECTOR_DATABSE_ENDPOINT = os.environ.get('VECTOR_DATABSE_ENDPOINT', '') # https://vector.cheetah.svc.cluster.local
VECTOR_DATABSE_APIKEY = os.environ.get('VECTOR_DATABSE_APIKEY', '') # https://vector.cheetah.svc.cluster.local
# 사용자 정의 임베딩 클래스 (온프레미스 embedding 서버 사용)
class CustomRemoteEmbeddings(Embeddings):
def __init__(self, endpoint_url: str):
self.endpoint_url = endpoint_url
def embed_documents(self, texts: list[str]) -> list[list[float]]:
response = httpx.post(
self.endpoint_url,
json={"inputs": texts}
)
return response.json()["embeddings"]
def embed_query(self, text: str) -> list[float]:
return self.embed_documents([text])[0]
# Triton Inference Server 기반 LLM
llm = HuggingFaceEndpoint(
endpoint_url=LLM_ENDPOINT, # Triton inference REST endpoint
task="text-generation",
model_kwargs={"temperature": 0.7, "max_new_tokens": 512}
)
# 사용자 정의 embedding 인스턴스
embedding_model = CustomRemoteEmbeddings(endpoint_url=EMBEDDING_ENDPOINT)
# Weaviate 벡터 DB
client = weaviate.Client(VECTOR_DATABSE_ENDPOINT)
vectorstore = Weaviate(client=client, index_name="LangChainIndex", text_key="text", embedding=embedding_model)
# RAG Prompt + Chain 구성
# 여기에 들어가는 prompt 템플릿을 어떻게 가져가야 하는지 향후 prompt에서 가져다가 쓸 수 있는 방법이 있다면 확장성 좋을 것 같음
prompt = PromptTemplate.from_template("Answer the question based on the following context:\n{context}\nQuestion: {question}")
chain = RunnableMap({
"context": lambda x: "\n".join(doc.page_content for doc in vectorstore.similarity_search(x["question"], k=3)),
"question": lambda x: x["question"]
}) | prompt | llm