import logging
import os
import uuid
from typing import Any, Dict, List, Optional

import requests
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.base import (
    BaseLanguageModel,
    LangSmithParams,
    LanguageModelInput,
)
from langchain_core.language_models.llms import LLM
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list

logger = logging.getLogger(__name__)

# Triton inference endpoint; override with the TRITON_LLM_ENDPOINT environment variable.
triton_host: str = os.getenv(
    "TRITON_LLM_ENDPOINT",
    "http://183.111.96.67:8000/v2/models/llama-3-3b/versions/1/infer",
)
# Value sent in the "model-api-key" request header; override with MODEL_API_KEY.
# NOTE(review): a real-looking key is hard-coded as the fallback default. Secrets
# should not be committed to source control -- consider removing this default
# and rotating the key.
model_api_key: str = os.getenv("MODEL_API_KEY", "01jsdesqtp4w81b8hydysxa1k2")


def _resolve_run_uuid(*candidates: Optional[Any]) -> uuid.UUID:
    """Return the first non-None candidate as a UUID, or a fresh random UUID."""
    for candidate in candidates:
        if candidate is not None:
            # Accept either a uuid.UUID or its string form.
            return uuid.UUID(str(candidate))
    return uuid.uuid4()


class TritonLLM(LLM):
    """LLM that sends prompts to a Triton Inference Server HTTP endpoint.

    Each call POSTs a JSON inference request to ``triton_host`` and returns the
    first element of the first output tensor in the response.
    """

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return "tritonllm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send *prompt* to Triton and return the generated text.

        Args:
            prompt: Text placed in the ``text_input`` input tensor.
            stop: Stop sequences are not supported; must be ``None``.
            run_manager: LangChain callback manager. Its ``parent_run_id`` (or,
                for a root run, its own ``run_id``) becomes the request id and
                the W3C ``traceparent`` trace id; its ``run_id`` becomes the
                parent/span id. Random UUIDs are used for anything missing.
            **kwargs: ``max_tokens`` (default 512) and ``temperature``
                (default 0.7) are forwarded to the model as input tensors.

        Returns:
            ``outputs[0]["data"][0]`` from the JSON response.

        Raises:
            ValueError: if ``stop`` is provided.
            requests.HTTPError: if the server responds with a 4xx/5xx status.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        # run_manager is Optional, and parent_run_id is None for root runs; the
        # previous code turned None into the string "None" and crashed in uuid.UUID.
        parent_run_id = run_manager.parent_run_id if run_manager is not None else None
        run_id = run_manager.run_id if run_manager is not None else None
        trace_uuid = _resolve_run_uuid(parent_run_id, run_id)
        span_uuid = _resolve_run_uuid(run_id)
        trace_id = str(trace_uuid)

        logger.debug("prompt: %s", prompt)

        max_tokens = kwargs.get("max_tokens", 512)  # default when caller gives none
        temperature = kwargs.get("temperature", 0.7)  # default when caller gives none

        payload: Dict[str, Any] = {
            "id": trace_id,
            "inputs": [
                {
                    "name": "text_input",
                    "datatype": "BYTES",
                    "shape": [1],
                    "data": [prompt],
                },
                # Both length inputs receive max_tokens; presumably the model
                # config reads only one of them -- confirm against the model.
                {
                    "name": "max_length",
                    "shape": [1],
                    "datatype": "INT32",
                    "data": [max_tokens],
                },
                {
                    "name": "max_new_tokens",
                    "shape": [1],
                    "datatype": "INT32",
                    "data": [max_tokens],
                },
                {
                    "name": "temperature",
                    "shape": [1],
                    "datatype": "FP32",
                    "data": [temperature],
                },
            ],
        }
        logger.debug("payload: %s", payload)

        headers = {
            # W3C traceparent is version-trace_id-parent_id-trace_flags,
            # e.g. 00-80e1afed08e019fc1110464cfa66635c-00085853722dc6d2-00.
            # The 16-hex parent id is "00" + the first 14 hex chars of the span
            # UUID; flags "00" marks the trace as not sampled.
            "traceparent": f"00-{trace_uuid.hex}-00{span_uuid.hex[:14]}-00",
            "Content-Type": "application/json",
            # Use the configurable key instead of a duplicated literal so the
            # MODEL_API_KEY environment override actually takes effect.
            "model-api-key": model_api_key,
        }

        response = requests.post(
            triton_host,
            json=payload,
            timeout=120,
            headers=headers,
        )
        logger.debug("status=%s body=%s", response.status_code, response.text)
        # Fail with a clear HTTPError rather than a KeyError/JSON error later.
        response.raise_for_status()

        result = response.json()
        return result["outputs"][0]["data"][0]