diff --git a/1/model.py b/1/model.py index 22e6486..b1d317d 100644 --- a/1/model.py +++ b/1/model.py @@ -9,17 +9,21 @@ class TritonPythonModel: Triton Python Model 클래스. """ - def initialize(self, args): - """ - 모델이 로드될 때 호출됩니다. - """ - print("TritonPythonModel: initialize() called.") - self.model_config = json.loads(args['model_config']) - - output_config = pb_utils.get_output_config_by_name( - self.model_config, "OUTPUT") - self.output_dtype = pb_utils.triton_string_to_np_dtype( - output_config['data_type']) +def initialize(self, args): + print("TritonPythonModel: initialize() called.") + self.model_config = json.loads(args['model_config']) + + # 출력 설정에서 데이터 타입 정보를 가져옴 + output_config = pb_utils.get_output_config_by_name( + self.model_config, "OUTPUT") + + # Triton 데이터 타입 문자열을 NumPy 데이터 타입으로 직접 변환 + # 'BYTES'는 np.object_ 타입에 해당함 + if output_config['data_type'] == 'TYPE_STRING': + self.output_dtype = np.object_ + else: + # 다른 데이터 타입에 대한 처리 로직 추가 가능 + self.output_dtype = pb_utils.triton_string_to_np_dtype(output_config['data_type']) def execute(self, requests): """