From 0d14a58f0cbc606dcee46a0d417b32834caf2e8a Mon Sep 17 00:00:00 2001 From: groupuser Date: Tue, 22 Apr 2025 23:57:11 +0000 Subject: [PATCH] Upload llm.py --- app/llm.py | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 app/llm.py diff --git a/app/llm.py b/app/llm.py new file mode 100644 index 0000000..735aa1b --- /dev/null +++ b/app/llm.py @@ -0,0 +1,66 @@ +from langchain_core.language_models.llms import LLM +from langchain_core.callbacks.manager import CallbackManagerForLLMRun +from typing import Any, Dict, List, Optional +import uuid +from typing import List +import requests +import os + +triton_host: str = os.getenv("TRITON_LLM_ENDPOINT", "http://183.111.96.67:8000/v2/models/llama-3-3b/versions/1/infer") +model_api_key: str = os.getenv("MODEL_API_KEY", "01jsdesqtp4w81b8hydysxa1k2") + +class TritonLLM(LLM): + """Triton Inference Server를 사용하는 LLM""" + + @property + def _llm_type(self) -> str: + return "tritonllm" + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + if stop is not None: + raise ValueError("stop kwargs are not permitted.") + trace_id: str = str(run_manager.parent_run_id) + span_id: str = str(run_manager.run_id) + + payload = { + "id": trace_id, + "inputs": [ + { + "name": "text_input", + "datatype": "BYTES", + "shape": [1], + "data": [prompt], + } + ] + } + uuid_trace_id_hex = uuid.UUID(trace_id).hex + uuid_span_id_hex = uuid.UUID(span_id).hex + + # eg. 00-80e1afed08e019fc1110464cfa66635c-00085853722dc6d2-00 + # The traceparent header uses the version-trace_id-parent_id-trace_flags format + header_trace = { + "traceparent": f"00-{uuid_trace_id_hex}-00{uuid_span_id_hex[:14]}-00", + "Content-Type": "application/json", + "model-api-key": "01jsdesqtp4w81b8hydysxa1k2" + } + + ret = requests.post( + triton_host, + json=payload, + timeout=10, + headers=header_trace, + ) + + print(ret.status_code) + print(ret.text) + + res = ret.json() + query_response = res["outputs"][0]["data"][0] + + return query_response