Upload embedding.py
This commit is contained in:
parent
8dd70f3b53
commit
e29196e458
33
app/embedding.py
Normal file
33
app/embedding.py
Normal file
@ -0,0 +1,33 @@
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from typing import List, Dict, Any
|
||||
import requests
|
||||
import os
|
||||
|
||||
embedding_host = os.getenv("EMBEDDING_HOST", "http://183.111.96.67:30136")
|
||||
|
||||
class WeaviateCustomEmbeddings(Embeddings):
|
||||
"""커스텀 엔드포인트를 사용하는 임베딩 모델"""
|
||||
|
||||
def __init__(self):
|
||||
self.api_url = embedding_host.rstrip("/")
|
||||
print(self.api_url)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""문서 리스트를 임베딩합니다."""
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
embeddings.append(self.embed_query(text))
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""쿼리 텍스트를 임베딩합니다."""
|
||||
query_text = text.question if hasattr(text, 'question') else str(text)
|
||||
|
||||
response = requests.post(
|
||||
f"{self.api_url}/vectors",
|
||||
json={"text": query_text}
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"임베딩 API 호출 실패: {self.api_url}/vectors, {query_text}, {response.status_code}, {response.text}")
|
||||
|
||||
return response.json()["vector"]
|
||||
Loading…
Reference in New Issue
Block a user