Update llm.py

This commit is contained in:
groupuser 2025-04-23 01:58:37 +00:00
parent 266936190b
commit 98c9237aad

@ -1,5 +1,12 @@
from langchain_core.language_models.llms import LLM
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
)
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
from typing import Any, Dict, List, Optional
import uuid
from typing import List
@ -28,6 +35,12 @@ class TritonLLM(LLM):
trace_id: str = str(run_manager.parent_run_id)
span_id: str = str(run_manager.run_id)
print("prompt")
print(prompt)
max_tokens = kwargs.get("max_tokens", 512) # default used when the caller passes no max_tokens
temperature = kwargs.get("temperature", 0.7) # default used when the caller passes no temperature
payload = {
"id": trace_id,
"inputs": [
@ -36,9 +49,28 @@ class TritonLLM(LLM):
"datatype": "BYTES",
"shape": [1],
"data": [prompt],
}
},
{
"name": "max_length",
"shape": [1],
"datatype": "INT32",
"data": [max_tokens]
},
{
"name": "max_new_tokens",
"shape": [1],
"datatype": "INT32",
"data": [max_tokens]
},
{
"name": "temperature",
"shape": [1],
"datatype": "FP32",
"data": [temperature]
},
]
}
print(payload)
uuid_trace_id_hex = uuid.UUID(trace_id).hex
uuid_span_id_hex = uuid.UUID(span_id).hex
@ -53,7 +85,7 @@ class TritonLLM(LLM):
ret = requests.post(
triton_host,
json=payload,
timeout=10,
timeout=120,
headers=header_trace,
)