# ----------------------------------------------
# Cheetah Pipeline Template
# ----------------------------------------------
# USER_EDIT_IMPORT_START
import os
import json
from datetime import datetime

from airflow.operators.python import PythonOperator
from airflow.providers.http.hooks.http import HttpHook

from operators.cheetah_dataset_operator import CheetahDatasetOperator
# USER_EDIT_IMPORT_END

from decorators.cheetah_pipeline_decorator import cheetah_pipeline
from airflow import DAG

@cheetah_pipeline(
    upload_outputs=True,
    push_to_dataset=False,
    s3_bucket=None,
    s3_conn_id=None
)
def dag_template():
    """Build the ``ca_kai-drm-gateway`` DAG.

    Pipeline: download dataset #409, then stream every downloaded file
    through the KAI DRM gateway's ``/drm/decrypt`` endpoint, writing the
    decrypted copies (plus a JSON report) to the Cheetah output directory.

    Returns:
        airflow.DAG: the constructed DAG object.
    """

    # USER_EDIT_DAG_META_START
    start_date = datetime(2026, 3, 3)
    schedule = None  # manual trigger only
    catchup = False
    tags = []

    def _decrypt_via_kai_gateway(**context):
        """PythonOperator callable: decrypt all files from the upstream task.

        Walks the upstream task's mounted output directory, POSTs each file
        to the DRM gateway, and mirrors the directory tree (decrypted) under
        ``/cheetah/outputs/decrypted_data``. A JSON report listing successes
        and failures is always written. Per-file failures do NOT fail the
        task unless strict mode is enabled (see below).

        Raises:
            FileNotFoundError: if the upstream input directory is missing.
            RuntimeError: in strict mode, when at least one file failed.
        """

        def get_input_path(upstream_task_id: str, sub_path: str = "") -> str:
            # Cheetah mounts each upstream task's output at a fixed location.
            base = f"/cheetah/inputs/{upstream_task_id}/output"
            return os.path.join(base, sub_path) if sub_path else base

        def get_output_path(filename: str) -> str:
            return os.path.join("/cheetah/outputs", filename)

        upstream_id = "download_dataset"
        input_dir = get_input_path(upstream_id, sub_path="raw_data")

        if not os.path.isdir(input_dir):
            raise FileNotFoundError(f"Input directory not found: {input_dir}")

        out_root = get_output_path("decrypted_data")
        os.makedirs(out_root, exist_ok=True)

        http = HttpHook(method="POST", http_conn_id="groupuser_kai_drm_gateway")

        successes = []
        failures = []

        for root, _, files in os.walk(input_dir):
            for name in files:
                in_path = os.path.join(root, name)
                # Mirror the input tree's relative layout under out_root.
                rel = os.path.relpath(in_path, input_dir)
                out_path = os.path.join(out_root, rel)
                os.makedirs(os.path.dirname(out_path), exist_ok=True)

                try:
                    with open(in_path, "rb") as f_in:
                        # stream=True so large files are never fully
                        # buffered in memory on the download side.
                        resp = http.run(
                            endpoint="/drm/decrypt",
                            files={"file": (name, f_in, "application/octet-stream")},
                            extra_options={"timeout": 600, "stream": True},
                        )
                    resp.raise_for_status()

                    with open(out_path, "wb") as f_out:
                        for chunk in resp.iter_content(chunk_size=1024 * 1024):
                            if chunk:
                                f_out.write(chunk)

                    successes.append({"input": in_path, "output": out_path})
                    print(f"[DRM] OK {in_path} -> {out_path}")
                except Exception as e:
                    # Best-effort: record and keep processing the remaining
                    # files; the strict check below decides task failure.
                    failures.append({"input": in_path, "error": str(e)})
                    print(f"[DRM] FAIL {in_path} ({e})")

        report_path = get_output_path("drm_decrypt_report.json")
        report = {
            "input_dir": input_dir,
            "output_dir": out_root,
            "success_count": len(successes),
            "failure_count": len(failures),
            "successes": successes,
            "failures": failures,
            # NOTE(review): naive local time, timezone-ambiguous in the
            # report — confirm whether UTC ISO-8601 is wanted here.
            "timestamp": str(datetime.now()),
        }
        with open(report_path, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2, ensure_ascii=False)

        print(f"[DRM] Report saved: {report_path}")

        # Strict mode defaults to off (original behavior: STRICT = False);
        # allow opting in via environment without editing the template.
        strict = os.environ.get("DRM_DECRYPT_STRICT", "").strip().lower() in ("1", "true", "yes")
        if strict and failures:
            raise RuntimeError(
                f"DRM decrypt failed for {len(failures)} file(s). See report: {report_path}"
            )
    # USER_EDIT_DAG_META_END

    with DAG(
        dag_id='ca_kai-drm-gateway',
        start_date=start_date,
        schedule=schedule,
        catchup=catchup,
        tags=tags,
        is_paused_upon_creation=False
    ) as dag:

        # USER_EDIT_TASK_START
        t1 = CheetahDatasetOperator(
            task_id="download_dataset",
            dataset_sn=409,
            output_path="raw_data",
        )

        t2 = PythonOperator(
            task_id="drm_decrypt",
            python_callable=_decrypt_via_kai_gateway,
        )

        t1 >> t2
        # USER_EDIT_TASK_END

    return dag
# Instantiate at import time so the Airflow scheduler discovers the DAG.
dag = dag_template()