From 98c9237aad90a02c9199121e0b9ec08d11a8cf8a Mon Sep 17 00:00:00 2001
From: groupuser
Date: Wed, 23 Apr 2025 01:58:37 +0000
Subject: [PATCH] Update llm.py

---
 app/llm.py | 36 ++++++++++++++++++++++++++++++++++--
 1 file changed, 34 insertions(+), 2 deletions(-)

diff --git a/app/llm.py b/app/llm.py
index 735aa1b..98563fc 100644
--- a/app/llm.py
+++ b/app/llm.py
@@ -1,5 +1,12 @@
 from langchain_core.language_models.llms import LLM
 from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.language_models.base import (
+    BaseLanguageModel,
+    LangSmithParams,
+    LanguageModelInput,
+)
+from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
+
 from typing import Any, Dict, List, Optional
 import uuid
 from typing import List
@@ -28,6 +35,12 @@ class TritonLLM(LLM):
         trace_id: str = str(run_manager.parent_run_id)
         span_id: str = str(run_manager.run_id)
 
+        print("prompt")
+        print(prompt)
+
+        max_tokens = kwargs.get("max_tokens", 512)  # default value
+        temperature = kwargs.get("temperature", 0.7)  # default value
+
         payload = {
             "id": trace_id,
             "inputs": [
@@ -36,9 +49,28 @@ class TritonLLM(LLM):
                     "datatype": "BYTES",
                     "shape": [1],
                     "data": [prompt],
-                }
+                },
+                {
+                    "name": "max_length",
+                    "shape": [1],
+                    "datatype": "INT32",
+                    "data": [max_tokens]
+                },
+                {
+                    "name": "max_new_tokens",
+                    "shape": [1],
+                    "datatype": "INT32",
+                    "data": [max_tokens]
+                },
+                {
+                    "name": "temperature",
+                    "shape": [1],
+                    "datatype": "FP32",
+                    "data": [temperature]
+                },
             ]
         }
+        print(payload)
 
         uuid_trace_id_hex = uuid.UUID(trace_id).hex
         uuid_span_id_hex = uuid.UUID(span_id).hex
@@ -53,7 +85,7 @@ class TritonLLM(LLM):
         ret = requests.post(
             triton_host,
             json=payload,
-            timeout=10,
+            timeout=120,
             headers=header_trace,
         )
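
For reference, a minimal usage sketch of the sampling parameters this patch introduces. The `TritonLLM` class name comes from the diff; the constructor call with no arguments and the prompt text are assumptions, not part of the patch. LangChain's base `LLM.invoke` forwards extra keyword arguments through `_generate` into `_call(**kwargs)`, which is how the values below would reach the new `max_length`, `max_new_tokens`, and `temperature` tensors in the Triton request.

    # Hypothetical caller-side sketch (not part of the patch).
    from app.llm import TritonLLM

    llm = TritonLLM()  # assumes the model can be constructed with defaults

    # Extra kwargs flow from invoke() into _call(), where the patched code
    # reads them via kwargs.get(...) and adds them to the Triton payload.
    text = llm.invoke(
        "Summarize the request payload in one sentence.",
        max_tokens=256,      # falls back to 512 if omitted
        temperature=0.2,     # falls back to 0.7 if omitted
    )
    print(text)

If either kwarg is omitted, the patched `_call` uses the defaults (512 tokens, temperature 0.7) shown in the diff.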