# airflow/dags/groupuser/01kjrwehggnsj6191x3kabst01.py

# ----------------------------------------------
# Cheetah Pipeline Template
# ----------------------------------------------
# USER_EDIT_IMPORT_START
import os
import json
from datetime import datetime
from airflow.operators.python import PythonOperator
from airflow.providers.http.hooks.http import HttpHook
from operators.cheetah_dataset_operator import CheetahDatasetOperator
# USER_EDIT_IMPORT_END
from decorators.cheetah_pipeline_decorator import cheetah_pipeline
from airflow import DAG
@cheetah_pipeline(
    upload_outputs=True,
    push_to_dataset=False,
    s3_bucket=None,
    s3_conn_id=None
)
def dag_template():
    """Build the ``ca_kai-drm-gateway`` DAG.

    Downloads dataset 409 via ``CheetahDatasetOperator``, then sends every
    file found under the downloaded ``raw_data`` directory to the KAI DRM
    gateway for decryption, mirroring the directory layout under the output
    root and writing a JSON summary report.
    """
    # USER_EDIT_DAG_META_START
    start_date = datetime(2026, 3, 3)
    schedule = None
    catchup = False
    tags = []

    def _decrypt_via_kai_gateway(**context):
        """Stream each input file through the gateway's /drm/decrypt endpoint.

        Best-effort: per-file failures are recorded in the report instead of
        raising, so one bad file does not abort the batch. Set STRICT below
        to make any failure fail the task after the report is written.
        """

        def get_input_path(upstream_task_id: str, sub_path: str = "") -> str:
            # Cheetah mounts each upstream task's output under /cheetah/inputs.
            base = f"/cheetah/inputs/{upstream_task_id}/output"
            return os.path.join(base, sub_path) if sub_path else base

        def get_output_path(filename: str) -> str:
            return os.path.join("/cheetah/outputs", filename)

        upstream_id = "download_dataset"
        input_dir = get_input_path(upstream_id, sub_path="raw_data")
        if not os.path.isdir(input_dir):
            raise FileNotFoundError(f"Input directory not found: {input_dir}")

        out_root = get_output_path("decrypted_data")
        os.makedirs(out_root, exist_ok=True)

        http = HttpHook(method="POST", http_conn_id="groupuser_kai_drm_gateway")
        successes = []
        failures = []
        for root, _, files in os.walk(input_dir):
            for name in files:
                in_path = os.path.join(root, name)
                # Preserve the input tree's relative layout in the output tree.
                rel = os.path.relpath(in_path, input_dir)
                out_path = os.path.join(out_root, rel)
                os.makedirs(os.path.dirname(out_path), exist_ok=True)
                try:
                    # The upload streams from the open file handle, so the
                    # request must run while f_in is open; the response body
                    # is consumed afterwards via iter_content (stream=True).
                    with open(in_path, "rb") as f_in:
                        resp = http.run(
                            endpoint="/drm/decrypt",
                            files={"file": (name, f_in, "application/octet-stream")},
                            extra_options={"timeout": 600, "stream": True},
                            headers={},
                        )
                    resp.raise_for_status()
                    # Write in 1 MiB chunks to keep memory bounded for large files.
                    with open(out_path, "wb") as f_out:
                        for chunk in resp.iter_content(chunk_size=1024 * 1024):
                            if chunk:
                                f_out.write(chunk)
                    successes.append({"input": in_path, "output": out_path})
                    print(f"[DRM] OK {in_path} -> {out_path}")
                except Exception as e:
                    # Deliberate best-effort: record and move on; the report
                    # (and optional STRICT check) surfaces the problem.
                    failures.append({"input": in_path, "error": str(e)})
                    print(f"[DRM] FAIL {in_path} ({e})")

        # Always write the summary report, even when everything failed.
        report_path = get_output_path("drm_decrypt_report.json")
        report = {
            "input_dir": input_dir,
            "output_dir": out_root,
            "success_count": len(successes),
            "failure_count": len(failures),
            "successes": successes,
            "failures": failures,
            "timestamp": str(datetime.now()),
        }
        with open(report_path, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2, ensure_ascii=False)
        print(f"[DRM] Report saved: {report_path}")

        # Flip to True to fail the task when any file could not be decrypted.
        STRICT = False
        if STRICT and failures:
            raise RuntimeError(
                f"DRM decrypt failed for {len(failures)} file(s). See report: {report_path}"
            )
    # USER_EDIT_DAG_META_END

    with DAG(
        dag_id='ca_kai-drm-gateway',
        start_date=start_date,
        schedule=schedule,
        catchup=catchup,
        tags=tags,
        is_paused_upon_creation=False
    ) as dag:
        # USER_EDIT_TASK_START
        t1 = CheetahDatasetOperator(
            task_id="download_dataset",
            dataset_sn=409,
            output_path="raw_data",
        )
        t2 = PythonOperator(
            task_id="drm_decrypt",
            python_callable=_decrypt_via_kai_gateway,
        )
        t1 >> t2
        # USER_EDIT_TASK_END
    return dag


dag = dag_template()