Upload llm.py
This commit is contained in:
parent
bd618fd353
commit
0d14a58f0c
66
app/llm.py
Normal file
66
app/llm.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
from typing import List
|
||||
import requests
|
||||
import os
|
||||
|
||||
# Inference URL of the Triton-served model; the TRITON_LLM_ENDPOINT environment
# variable takes precedence over the built-in fallback address.
triton_host: str = os.environ.get(
    "TRITON_LLM_ENDPOINT",
    "http://183.111.96.67:8000/v2/models/llama-3-3b/versions/1/infer",
)

# API key for the model endpoint; the MODEL_API_KEY environment variable takes
# precedence over the fallback.
# NOTE(review): the fallback value is a credential committed to source control;
# consider rotating it and requiring the environment variable instead.
model_api_key: str = os.environ.get("MODEL_API_KEY", "01jsdesqtp4w81b8hydysxa1k2")
|
||||
|
||||
class TritonLLM(LLM):
    """LangChain LLM that sends prompts to a Triton Inference Server.

    Each call POSTs the prompt as a single ``text_input`` BYTES tensor to the
    module-level ``triton_host`` URL and returns the first element of the
    first output tensor found in the JSON response.
    """

    @property
    def _llm_type(self) -> str:
        """Return the identifier string of this LLM implementation."""
        return "tritonllm"

    @staticmethod
    def _traceparent(trace_id: uuid.UUID, span_id: uuid.UUID) -> str:
        """Build the ``traceparent`` header value for a request.

        The header uses the ``version-trace_id-parent_id-trace_flags`` format,
        e.g. ``00-80e1afed08e019fc1110464cfa66635c-00085853722dc6d2-00``.
        The trace id is the 32 hex digits of ``trace_id``; the 16-digit parent
        id is ``00`` followed by the first 14 hex digits of ``span_id``.
        """
        return f"00-{trace_id.hex}-00{span_id.hex[:14]}-00"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run one inference request and return the generated text.

        Args:
            prompt: Text sent as the ``text_input`` tensor.
            stop: Must be ``None``; stop sequences are not supported.
            run_manager: Optional callback manager whose ``run_id`` and
                ``parent_run_id`` are used as span id and trace id. When it
                (or its parent id) is missing, a fallback id is used instead.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``outputs[0].data[0]`` of the JSON response.

        Raises:
            ValueError: If ``stop`` is given, or if the response JSON does not
                contain ``outputs[0].data[0]``.
            requests.HTTPError: If the server answers with an error status.
        """
        # Function-scope import so this block does not depend on a new
        # file-level import.
        import logging

        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        # run_manager is optional and a root run has no parent_run_id; fall
        # back instead of crashing on None / uuid.UUID("None").
        run_id = getattr(run_manager, "run_id", None)
        parent_run_id = getattr(run_manager, "parent_run_id", None)
        span_uuid = uuid.UUID(str(run_id)) if run_id is not None else uuid.uuid4()
        if parent_run_id is not None:
            trace_uuid = uuid.UUID(str(parent_run_id))
        else:
            # Without a parent, this run is the root of its own trace.
            trace_uuid = span_uuid

        payload = {
            "id": str(trace_uuid),
            "inputs": [
                {
                    "name": "text_input",
                    "datatype": "BYTES",
                    "shape": [1],
                    "data": [prompt],
                }
            ],
        }
        headers = {
            "traceparent": self._traceparent(trace_uuid, span_uuid),
            "Content-Type": "application/json",
            # Use the configurable module-level key rather than a duplicated
            # literal, so the MODEL_API_KEY environment override takes effect.
            "model-api-key": model_api_key,
        }

        response = requests.post(
            triton_host,
            json=payload,
            timeout=10,
            headers=headers,
        )
        logging.getLogger(__name__).debug(
            "Triton response %s: %s", response.status_code, response.text
        )
        # Surface HTTP failures explicitly instead of failing later while
        # parsing an error body.
        response.raise_for_status()

        body = response.json()
        try:
            return body["outputs"][0]["data"][0]
        except (KeyError, IndexError, TypeError) as err:
            raise ValueError(
                "Unexpected Triton response: missing outputs[0].data[0]"
            ) from err
|
||||
Loading…
Reference in New Issue
Block a user