"""rag-chain-agent/app/llm.py — Triton Inference Server LLM wrapper for LangChain."""
from langchain_core.language_models.llms import LLM
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
)
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
from typing import Any, Dict, List, Optional
import uuid
from typing import List
import requests
import os
triton_host: str = os.getenv("TRITON_LLM_ENDPOINT", "http://183.111.96.67:8000/v2/models/llama-3-3b/versions/1/infer")
model_api_key: str = os.getenv("MODEL_API_KEY", "01jsdesqtp4w81b8hydysxa1k2")
class TritonLLM(LLM):
    """LLM backed by a Triton Inference Server HTTP endpoint.

    Sends the prompt to the KServe-v2 ``infer`` endpoint configured by the
    module-level ``triton_host`` URL and returns the first text output.
    LangChain run ids (when available) are propagated to the server as a
    W3C ``traceparent`` header for distributed tracing.
    """

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses to label this LLM implementation."""
        return "tritonllm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to Triton and return the generated text.

        Args:
            prompt: The text prompt to send to the model.
            stop: Not supported by this backend; a non-None value raises.
            run_manager: Optional LangChain callback manager. Its
                ``parent_run_id``/``run_id`` are used as trace/span ids when
                present; fresh UUIDs are generated otherwise.
            **kwargs: ``max_tokens`` (default 512) and ``temperature``
                (default 0.7) are forwarded to the model.

        Returns:
            The first text output produced by the model.

        Raises:
            ValueError: If ``stop`` is provided.
            requests.HTTPError: If the server responds with an error status.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        # Bug fix: run_manager is Optional and parent_run_id may itself be
        # None (e.g. a top-level run); the original unconditionally did
        # str(run_manager.parent_run_id), crashing or producing the literal
        # "None" which uuid.UUID() rejects. Fall back to fresh UUIDs so the
        # tracing headers stay well-formed.
        if run_manager is not None and run_manager.parent_run_id is not None:
            trace_id: str = str(run_manager.parent_run_id)
        else:
            trace_id = str(uuid.uuid4())
        if run_manager is not None and run_manager.run_id is not None:
            span_id: str = str(run_manager.run_id)
        else:
            span_id = str(uuid.uuid4())

        max_tokens = kwargs.get("max_tokens", 512)    # generation length cap
        temperature = kwargs.get("temperature", 0.7)  # sampling temperature
        payload = {
            "id": trace_id,
            "inputs": [
                {
                    "name": "text_input",
                    "datatype": "BYTES",
                    "shape": [1],
                    "data": [prompt],
                },
                {
                    "name": "max_length",
                    "shape": [1],
                    "datatype": "INT32",
                    "data": [max_tokens],
                },
                {
                    "name": "max_new_tokens",
                    "shape": [1],
                    "datatype": "INT32",
                    "data": [max_tokens],
                },
                {
                    "name": "temperature",
                    "shape": [1],
                    "datatype": "FP32",
                    "data": [temperature],
                },
            ],
        }

        # W3C traceparent: version-trace_id-parent_id-trace_flags.
        # parent_id must be exactly 16 hex chars, so "00" pads the 14-char
        # slice of the span UUID's 32-char hex form.
        uuid_trace_id_hex = uuid.UUID(trace_id).hex
        uuid_span_id_hex = uuid.UUID(span_id).hex
        headers = {
            "traceparent": f"00-{uuid_trace_id_hex}-00{uuid_span_id_hex[:14]}-00",
            "Content-Type": "application/json",
            # Bug fix: use the configured key (MODEL_API_KEY env var) instead
            # of the hard-coded literal the original duplicated here.
            "model-api-key": model_api_key,
        }

        response = requests.post(
            triton_host,
            json=payload,
            timeout=120,
            headers=headers,
        )
        # Fail loudly on HTTP errors instead of raising an opaque KeyError
        # when the error body lacks the "outputs" field.
        response.raise_for_status()
        result = response.json()
        return result["outputs"][0]["data"][0]