diff --git a/app/chain.py b/app/chain.py
new file mode 100644
index 0000000..cd056ad
--- /dev/null
+++ b/app/chain.py
@@ -0,0 +1,53 @@
+from langchain.schema.runnable import RunnableMap
+from langchain.vectorstores import Weaviate
+from langchain.llms import HuggingFaceEndpoint
+from langchain.prompts import PromptTemplate
+from langchain.embeddings.base import Embeddings
+import httpx
+import weaviate
+import os
+
+# Configuration from environment variables
+LLM_ENDPOINT = os.environ.get('LLM_ENDPOINT', '')  # e.g. https://deploymodel.cheetah.svc.cluster.local
+EMBEDDING_ENDPOINT = os.environ.get('EMBEDDING_ENDPOINT', '')  # e.g. https://embedding.cheetah.svc.cluster.local
+VECTOR_DATABASE_ENDPOINT = os.environ.get('VECTOR_DATABASE_ENDPOINT', '')  # e.g. https://vector.cheetah.svc.cluster.local
+VECTOR_DATABASE_APIKEY = os.environ.get('VECTOR_DATABASE_APIKEY', '')  # API key for the vector DB
+
+# Custom embedding class (calls the on-premises embedding server)
+class CustomRemoteEmbeddings(Embeddings):
+    def __init__(self, endpoint_url: str):
+        self.endpoint_url = endpoint_url
+
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
+        response = httpx.post(
+            self.endpoint_url,
+            json={"inputs": texts}
+        )
+        response.raise_for_status()
+        return response.json()["embeddings"]
+
+    def embed_query(self, text: str) -> list[float]:
+        return self.embed_documents([text])[0]
+
+# LLM served behind the Triton Inference Server
+llm = HuggingFaceEndpoint(
+    endpoint_url=LLM_ENDPOINT,  # Triton inference REST endpoint
+    task="text-generation",
+    model_kwargs={"temperature": 0.7, "max_new_tokens": 512}
+)
+
+# Custom embedding instance
+embedding_model = CustomRemoteEmbeddings(endpoint_url=EMBEDDING_ENDPOINT)
+
+# Weaviate vector DB (pass the API key when one is provided)
+auth = weaviate.AuthApiKey(api_key=VECTOR_DATABASE_APIKEY) if VECTOR_DATABASE_APIKEY else None
+client = weaviate.Client(url=VECTOR_DATABASE_ENDPOINT, auth_client_secret=auth)
+vectorstore = Weaviate(client=client, index_name="LangChainIndex", text_key="text", embedding=embedding_model)
+
+# RAG prompt + chain composition
+# TODO: it would scale better if this prompt template could later be pulled from a shared prompt store instead of being hard-coded here
+prompt = PromptTemplate.from_template("Answer the question based on the following context:\n{context}\nQuestion: {question}")
+chain = RunnableMap({
+    "context": lambda x: "\n".join(doc.page_content for doc in vectorstore.similarity_search(x["question"], k=3)),
+    "question": lambda x: x["question"]
+}) | prompt | llm
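
Usage note (a sketch, not part of the patch): the composed chain is a standard LangChain Runnable, so a caller (for example, the API route that will serve it) can invoke it with a question dict. The question text below is illustrative only, and it assumes the three endpoints above are reachable from the pod:

    # hypothetical caller code; not included in app/chain.py
    answer = chain.invoke({"question": "How do I deploy a model to the cheetah namespace?"})
    print(answer)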