Upload llm.py
This commit is contained in:
parent
bd618fd353
commit
0d14a58f0c
66
app/llm.py
Normal file
66
app/llm.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
from langchain_core.language_models.llms import LLM
|
||||||
|
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
import uuid
|
||||||
|
from typing import List
|
||||||
|
import requests
|
||||||
|
import os
|
||||||
|
|
||||||
|
# URL that inference requests are POSTed to. Read from the TRITON_LLM_ENDPOINT
# environment variable, falling back to the literal below when it is unset.
triton_host: str = os.environ.get(
    "TRITON_LLM_ENDPOINT",
    "http://183.111.96.67:8000/v2/models/llama-3-3b/versions/1/infer",
)

# API key for the model endpoint. Read from the MODEL_API_KEY environment
# variable, falling back to the literal below when it is unset.
# NOTE(review): the fallback looks like a real credential committed to source
# control -- consider rotating it and requiring the environment variable.
model_api_key: str = os.environ.get(
    "MODEL_API_KEY",
    "01jsdesqtp4w81b8hydysxa1k2",
)
|
||||||
|
|
||||||
|
class TritonLLM(LLM):
    """LLM that sends prompts to a Triton Inference Server over HTTP.

    Each call POSTs the prompt as a single ``text_input`` BYTES tensor to the
    module-level ``triton_host`` URL, authenticating with the module-level
    ``model_api_key``, and returns the first data element of the first output
    in the JSON response.
    """

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this LLM implementation."""
        return "tritonllm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to the Triton endpoint and return the generated text.

        Args:
            prompt: Text placed in the ``text_input`` tensor of the request.
            stop: Must be ``None``; stop sequences are not supported.
            run_manager: Optional callback manager. Its ``run_id`` becomes the
                span id and its ``parent_run_id`` the trace id of the
                ``traceparent`` header. When the run has no parent, the run's
                own id is used as the trace id; when no manager is given,
                fresh random ids are generated.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            ``response["outputs"][0]["data"][0]`` from the JSON response body.

        Raises:
            ValueError: If ``stop`` is not ``None``.
            requests.HTTPError: If the server answers with a 4xx/5xx status.
            KeyError, IndexError: If the response JSON does not contain
                ``outputs[0].data[0]``.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        # Resolve trace/span ids defensively: ``run_manager`` is optional and
        # a root run has ``parent_run_id`` set to None, either of which used
        # to crash before the request was even built.
        if run_manager is None:
            span_uuid = uuid.uuid4()
            trace_uuid = uuid.uuid4()
        else:
            span_uuid = uuid.UUID(str(run_manager.run_id))
            parent_run_id = run_manager.parent_run_id
            if parent_run_id is None:
                trace_uuid = span_uuid
            else:
                trace_uuid = uuid.UUID(str(parent_run_id))

        payload = {
            "id": str(trace_uuid),
            "inputs": [
                {
                    "name": "text_input",
                    "datatype": "BYTES",
                    "shape": [1],
                    "data": [prompt],
                }
            ],
        }

        # The traceparent header uses the
        # version-trace_id-parent_id-trace_flags format,
        # e.g. 00-80e1afed08e019fc1110464cfa66635c-00085853722dc6d2-00.
        # The parent_id field is 16 hex characters, so the 32-character span
        # UUID is truncated to 14 characters and prefixed with "00".
        headers = {
            "traceparent": f"00-{trace_uuid.hex}-00{span_uuid.hex[:14]}-00",
            "Content-Type": "application/json",
            # Use the configurable module-level key rather than a literal so
            # the MODEL_API_KEY environment override actually takes effect.
            "model-api-key": model_api_key,
        }

        response = requests.post(
            triton_host,
            json=payload,
            timeout=10,
            headers=headers,
        )
        # Fail with a clear HTTP error instead of a confusing KeyError or
        # JSON decoding error when the server rejects the request.
        response.raise_for_status()

        body = response.json()
        return body["outputs"][0]["data"][0]
|
||||||
Loading…
Reference in New Issue
Block a user