Update llm.py
This commit is contained in:
parent
266936190b
commit
98c9237aad
36
app/llm.py
36
app/llm.py
@@ -1,5 +1,12 @@
|
|||||||
from langchain_core.language_models.llms import LLM
|
from langchain_core.language_models.llms import LLM
|
||||||
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain_core.language_models.base import (
|
||||||
|
BaseLanguageModel,
|
||||||
|
LangSmithParams,
|
||||||
|
LanguageModelInput,
|
||||||
|
)
|
||||||
|
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List
|
from typing import List
|
||||||
@@ -28,6 +35,12 @@ class TritonLLM(LLM):
|
|||||||
trace_id: str = str(run_manager.parent_run_id)
|
trace_id: str = str(run_manager.parent_run_id)
|
||||||
span_id: str = str(run_manager.run_id)
|
span_id: str = str(run_manager.run_id)
|
||||||
|
|
||||||
|
print("prompt")
|
||||||
|
print(prompt)
|
||||||
|
|
||||||
|
max_tokens = kwargs.get("max_tokens", 512) # 기본값 설정
|
||||||
|
temperature = kwargs.get("temperature", 0.7) # 기본값 설정
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"id": trace_id,
|
"id": trace_id,
|
||||||
"inputs": [
|
"inputs": [
|
||||||
@@ -36,9 +49,28 @@ class TritonLLM(LLM):
|
|||||||
"datatype": "BYTES",
|
"datatype": "BYTES",
|
||||||
"shape": [1],
|
"shape": [1],
|
||||||
"data": [prompt],
|
"data": [prompt],
|
||||||
}
|
},
|
||||||
|
{
|
||||||
|
"name": "max_length",
|
||||||
|
"shape": [1],
|
||||||
|
"datatype": "INT32",
|
||||||
|
"data": [max_tokens]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "max_new_tokens",
|
||||||
|
"shape": [1],
|
||||||
|
"datatype": "INT32",
|
||||||
|
"data": [max_tokens]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "temperature",
|
||||||
|
"shape": [1],
|
||||||
|
"datatype": "FP32",
|
||||||
|
"data": [temperature]
|
||||||
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
print(payload)
|
||||||
uuid_trace_id_hex = uuid.UUID(trace_id).hex
|
uuid_trace_id_hex = uuid.UUID(trace_id).hex
|
||||||
uuid_span_id_hex = uuid.UUID(span_id).hex
|
uuid_span_id_hex = uuid.UUID(span_id).hex
|
||||||
|
|
||||||
@@ -53,7 +85,7 @@ class TritonLLM(LLM):
|
|||||||
ret = requests.post(
|
ret = requests.post(
|
||||||
triton_host,
|
triton_host,
|
||||||
json=payload,
|
json=payload,
|
||||||
timeout=10,
|
timeout=120,
|
||||||
headers=header_trace,
|
headers=header_trace,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user