# Triton Inference Server LLM wrapper for LangChain.
from langchain_core.language_models.llms import LLM
|
|
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
|
from langchain_core.language_models.base import (
|
|
BaseLanguageModel,
|
|
LangSmithParams,
|
|
LanguageModelInput,
|
|
)
|
|
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
import uuid
|
|
from typing import List
|
|
import requests
|
|
import os
|
|
|
|
# Triton inference endpoint (KServe v2 infer URL); override via TRITON_LLM_ENDPOINT.
triton_host: str = os.getenv("TRITON_LLM_ENDPOINT", "http://183.111.96.67:8000/v2/models/llama-3-3b/versions/1/infer")

# SECURITY(review): a real-looking API key is hard-coded as the fallback default.
# Rotate this credential and remove the literal; require MODEL_API_KEY to be set.
model_api_key: str = os.getenv("MODEL_API_KEY", "01jsdesqtp4w81b8hydysxa1k2")
|
|
class TritonLLM(LLM):
    """LLM that generates text via a Triton Inference Server.

    The endpoint URL and API key come from the module-level ``triton_host``
    and ``model_api_key`` values (environment-driven). LangChain run ids are
    propagated to the server as a W3C ``traceparent`` header.
    """

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for serialization and telemetry."""
        return "tritonllm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to Triton and return the generated text.

        Args:
            prompt: Input text for the model.
            stop: Not supported by this backend; any value raises.
            run_manager: Optional callback manager; its run ids seed the
                ``traceparent`` header when available.
            **kwargs: ``max_tokens`` (default 512) and ``temperature``
                (default 0.7) are forwarded to the model.

        Returns:
            The first string in the first output tensor of the response.

        Raises:
            ValueError: If ``stop`` is provided.
            requests.HTTPError: If Triton returns a non-2xx status.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        # BUG FIX: run_manager is Optional but was dereferenced unconditionally,
        # and parent_run_id can be None even when a manager exists. Fall back to
        # fresh UUIDs so tracing never crashes the call.
        if run_manager is not None:
            trace_id: str = str(run_manager.parent_run_id or uuid.uuid4())
            span_id: str = str(run_manager.run_id or uuid.uuid4())
        else:
            trace_id = str(uuid.uuid4())
            span_id = str(uuid.uuid4())

        max_tokens = kwargs.get("max_tokens", 512)
        temperature = kwargs.get("temperature", 0.7)

        payload = {
            "id": trace_id,
            "inputs": [
                {
                    "name": "text_input",
                    "datatype": "BYTES",
                    "shape": [1],
                    "data": [prompt],
                },
                # NOTE(review): the token limit is sent under two names
                # (max_length and max_new_tokens) — presumably to cover
                # different model configs; confirm both are still needed.
                {
                    "name": "max_length",
                    "shape": [1],
                    "datatype": "INT32",
                    "data": [max_tokens],
                },
                {
                    "name": "max_new_tokens",
                    "shape": [1],
                    "datatype": "INT32",
                    "data": [max_tokens],
                },
                {
                    "name": "temperature",
                    "shape": [1],
                    "datatype": "FP32",
                    "data": [temperature],
                },
            ],
        }

        # W3C Trace Context format: version-trace_id-parent_id-trace_flags,
        # e.g. 00-80e1afed08e019fc1110464cfa66635c-00085853722dc6d2-00.
        uuid_trace_id_hex = uuid.UUID(trace_id).hex
        uuid_span_id_hex = uuid.UUID(span_id).hex
        header_trace = {
            "traceparent": f"00-{uuid_trace_id_hex}-00{uuid_span_id_hex[:14]}-00",
            "Content-Type": "application/json",
            # BUG FIX: use the configurable module-level key instead of the
            # hard-coded literal the original duplicated here.
            "model-api-key": model_api_key,
        }

        ret = requests.post(
            triton_host,
            json=payload,
            timeout=120,
            headers=header_trace,
        )
        # BUG FIX: fail loudly on HTTP errors instead of trying to parse an
        # error body as a successful inference response (which surfaced as a
        # confusing KeyError on "outputs").
        ret.raise_for_status()

        res = ret.json()
        return res["outputs"][0]["data"][0]