Update llm.py

This commit is contained in:
groupuser 2025-04-23 01:58:37 +00:00
parent 266936190b
commit 98c9237aad

@@ -1,5 +1,12 @@
from langchain_core.language_models.llms import LLM
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
)
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
from typing import Any, Dict, List, Optional
import uuid
from typing import List
@@ -28,6 +35,12 @@ class TritonLLM(LLM):
trace_id: str = str(run_manager.parent_run_id)
span_id: str = str(run_manager.run_id)
print("prompt")
print(prompt)
max_tokens = kwargs.get("max_tokens", 512) # 기본값 설정
temperature = kwargs.get("temperature", 0.7) # 기본값 설정
payload = {
"id": trace_id,
"inputs": [
@@ -36,9 +49,28 @@ class TritonLLM(LLM):
"datatype": "BYTES", "datatype": "BYTES",
"shape": [1], "shape": [1],
"data": [prompt], "data": [prompt],
} },
{
"name": "max_length",
"shape": [1],
"datatype": "INT32",
"data": [max_tokens]
},
{
"name": "max_new_tokens",
"shape": [1],
"datatype": "INT32",
"data": [max_tokens]
},
{
"name": "temperature",
"shape": [1],
"datatype": "FP32",
"data": [temperature]
},
]
}
print(payload)
uuid_trace_id_hex = uuid.UUID(trace_id).hex
uuid_span_id_hex = uuid.UUID(span_id).hex
@@ -53,7 +85,7 @@ class TritonLLM(LLM):
ret = requests.post(
triton_host,
json=payload,
timeout=120,
headers=header_trace,
)