import json

import numpy as np
import triton_python_backend_utils as pb_utils
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


class TritonPythonModel:
    """Triton Python-backend model that answers a question with a causal LM.

    Each request carries one UTF-8 encoded question. The question is wrapped
    in a "### 질문: ... ### 답변:" prompt, passed to ``generate``, and the text
    following the answer marker is returned as a UTF-8 bytes tensor.
    """

    # Prompt markers; they are also used to cut the answer out of the
    # decoded generation, so the prompt and the parser must stay in sync.
    _QUESTION_MARKER = "### 질문:"
    _ANSWER_MARKER = "### 답변:"

    # Generation settings.
    _MAX_NEW_TOKENS = 256
    _REPETITION_PENALTY = 1.1
    # NOTE(review): presumably the EOS/pad token id of the served model --
    # confirm against the tokenizer before switching checkpoints.
    _PAD_TOKEN_ID = 2

    def initialize(self, args):
        """Load the tokenizer and model; run by Triton when the model loads.

        Args:
            args: Dict supplied by Triton. Reads ``model_repository``,
                ``model_name`` and ``model_config`` (a JSON string).
        """
        self.logger = pb_utils.Logger

        model_repository = args["model_repository"]
        model_name = args["model_name"]
        # Weights are looked up under "<model_repository>/<model_name>".
        model_path = f"{model_repository}/{model_name}"

        self.model_config = json.loads(args["model_config"])
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            local_files_only=True,
            device_map="auto",
        )

        self.logger.log_info(f"\n### {model_name} 모델 초기화 완료")

    def execute(self, requests):
        """Run inference for a batch of requests.

        A failure while handling one request is logged and reported through
        that request's response (``TritonError``) instead of aborting the
        whole batch.

        Args:
            requests: Iterable of ``pb_utils.InferenceRequest``.

        Returns:
            A list with exactly one ``pb_utils.InferenceResponse`` per
            request, in request order.
        """
        # The first configured input/output is the text tensor.
        input_name = self.model_config["input"][0]["name"]
        output_name = self.model_config["output"][0]["name"]

        # Loop-invariant, so build it once per call rather than per request.
        generation_config = GenerationConfig(
            max_new_tokens=self._MAX_NEW_TOKENS,
        )

        responses = []
        for request in requests:
            try:
                output = self._answer_request(
                    request, input_name, generation_config
                )
                output_tensor = pb_utils.Tensor(
                    output_name,
                    np.array(output.encode("utf-8"), dtype=np.bytes_),
                )
                response = pb_utils.InferenceResponse(
                    output_tensors=[output_tensor]
                )
            except Exception as exc:  # per-request boundary: report, go on
                message = f"{type(exc).__name__}: {exc}"
                self.logger.log_error(f"### INFERENCE FAILED: {message}")
                response = pb_utils.InferenceResponse(
                    output_tensors=[],
                    error=pb_utils.TritonError(message),
                )
            responses.append(response)

        return responses

    def _answer_request(self, request, input_name, generation_config):
        """Return the post-processed answer text for a single request."""
        input_tensor = pb_utils.get_input_tensor_by_name(request, input_name)
        if input_tensor is None:
            raise ValueError(f"missing input tensor '{input_name}'")

        # Parse the Triton input: first element holds the UTF-8 question.
        input_text = input_tensor.as_numpy()[0].decode("utf-8")
        self.logger.log_info(f"### INPUT_TEXT: {input_text}")

        # Tokenize; ``.to`` moves every tensor of the encoding at once.
        prompt = f"{self._QUESTION_MARKER} {input_text}\n\n{self._ANSWER_MARKER}"
        inputs = self.tokenizer(prompt, return_tensors="pt").to(
            device=self.model.device
        )

        # Model inference.
        gened = self.model.generate(
            generation_config=generation_config,
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            pad_token_id=self._PAD_TOKEN_ID,
            repetition_penalty=self._REPETITION_PENALTY,
        )

        # Decode without special tokens so that EOS/pad text does not end up
        # in the answer returned to the client.
        answer = self.tokenizer.decode(gened[0], skip_special_tokens=True)
        self.logger.log_info(f"### MODEL ANSWER:\n{answer}")

        output = self.post_process(answer)
        self.logger.log_info(f"### OUTPUT_TEXT: {output}")
        return output

    def post_process(self, text):
        """Extract the answer from the decoded generation.

        Returns the stripped text between the first answer marker and the
        next answer or question marker (whichever comes first). If ``text``
        contains no answer marker it is returned unchanged.
        """
        _, marker, tail = text.partition(self._ANSWER_MARKER)
        if not marker:
            return text
        # Stop at a repeated answer marker, then at a follow-up question the
        # model may have started generating.
        answer = tail.split(self._ANSWER_MARKER, 1)[0]
        return answer.split(self._QUESTION_MARKER, 1)[0].strip()

    def finalize(self):
        """Clean up at server shutdown (optional); currently a no-op."""