# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s

from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes import pod_generator
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook, _load_body_to_dict
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import add_unique_suffix
from airflow.providers.cncf.kubernetes.operators.custom_object_launcher import CustomObjectLauncher
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN, PodGenerator
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
from airflow.utils.helpers import prune_dict

if TYPE_CHECKING:
    import jinja2

    try:
        from airflow.sdk.definitions.context import Context
    except ImportError:
        # TODO: Remove once provider drops support for Airflow 2
        from airflow.utils.context import Context


class SparkKubernetesOperator(KubernetesPodOperator):
    """
    Creates sparkApplication object in kubernetes cluster.

    .. seealso::
        For more detail about Spark Application Object have a look at the
        SparkApplication API reference of the Kubernetes Spark operator.

    :param image: Docker image you wish to launch. Defaults to hub.docker.com,
    :param code_path: path to the spark code in image,
    :param namespace: kubernetes namespace to put sparkApplication
    :param name: name of the pod in which the task will run, will be used (plus a random
        suffix if random_name_suffix is True) to generate a pod id (DNS-1123 subdomain,
        containing only [a-z0-9.-]).
    :param application_file: filepath to kubernetes custom_resource_definition of sparkApplication
    :param template_spec: kubernetes sparkApplication specification
    :param get_logs: get the stdout of the container as logs of the tasks.
    :param do_xcom_push: If True, the content of the file
        /airflow/xcom/return.json in the container will also be pushed to an
        XCom when the container completes.
    :param success_run_history_limit: Number of past successful runs of the application to keep.
    :param startup_timeout_seconds: timeout in seconds to startup the pod.
    :param log_events_on_failure: Log the pod's events if a failure occurs
    :param reattach_on_restart: if the scheduler dies while the pod is running, reattach and monitor
    :param delete_on_termination: What to do when the pod reaches its final
        state, or the execution is interrupted. If True (default), delete the
        pod; if False, leave the pod.
    :param kubernetes_conn_id: the connection to Kubernetes cluster
    :param random_name_suffix: If True, adds a random suffix to the pod name
    """

    template_fields = ["application_file", "namespace", "template_spec", "kubernetes_conn_id"]
    template_fields_renderers = {"template_spec": "py"}
    template_ext = ("yaml", "yml", "json")

    # Name of the driver container; base_container_name and container_logs are forced to it.
    BASE_CONTAINER_NAME = "spark-kubernetes-driver"

    def __init__(
        self,
        *,
        image: str | None = None,
        code_path: str | None = None,
        namespace: str = "default",
        name: str | None = None,
        application_file: str | None = None,
        template_spec=None,
        get_logs: bool = True,
        do_xcom_push: bool = False,
        success_run_history_limit: int = 1,
        startup_timeout_seconds=600,
        log_events_on_failure: bool = False,
        reattach_on_restart: bool = True,
        delete_on_termination: bool = True,
        kubernetes_conn_id: str = "kubernetes_default",
        random_name_suffix: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(name=name, **kwargs)
        # ``image`` must be stored explicitly: it is consumed by this signature (not forwarded
        # to the parent) and later read by ``get_or_create_spark_crd``.
        self.image = image
        self.code_path = code_path
        self.application_file = application_file
        self.template_spec = template_spec
        self.kubernetes_conn_id = kubernetes_conn_id
        self.startup_timeout_seconds = startup_timeout_seconds
        self.reattach_on_restart = reattach_on_restart
        self.delete_on_termination = delete_on_termination
        self.do_xcom_push = do_xcom_push
        self.namespace = namespace
        self.get_logs = get_logs
        self.log_events_on_failure = log_events_on_failure
        self.success_run_history_limit = success_run_history_limit
        self.random_name_suffix = random_name_suffix

        # fix mypy typing
        self.base_container_name: str
        self.container_logs: list[str]

        if self.base_container_name != self.BASE_CONTAINER_NAME:
            self.log.warning(
                "base_container_name is not supported and will be overridden to %s", self.BASE_CONTAINER_NAME
            )
            self.base_container_name = self.BASE_CONTAINER_NAME

        # Accept both the bare name and the single-element list form as "already correct";
        # comparing a list against the bare string alone would always report a mismatch.
        accepted_container_logs = (self.BASE_CONTAINER_NAME, [self.BASE_CONTAINER_NAME])
        if self.get_logs and self.container_logs not in accepted_container_logs:
            self.log.warning(
                "container_logs is not supported and will be overridden to %s", self.BASE_CONTAINER_NAME
            )
            self.container_logs = [self.BASE_CONTAINER_NAME]

    def _render_nested_template_fields(
        self,
        content: Any,
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set,
    ) -> None:
        """
        Render templated attributes nested inside ``content``.

        ``V1EnvVar`` objects get their ``value`` and ``name`` rendered here; every other
        object is delegated to the parent implementation.

        :param content: object whose nested fields should be rendered
        :param context: task context provided by airflow DAG
        :param jinja_env: jinja environment used for rendering
        :param seen_oids: ids of objects that were already rendered
        """
        if id(content) not in seen_oids and isinstance(content, k8s.V1EnvVar):
            # Record the object so it is not rendered a second time, and stop here so that
            # the parent renderer does not process the same object again.
            seen_oids.add(id(content))
            self._do_render_template_fields(content, ("value", "name"), context, jinja_env, seen_oids)
            return

        super()._render_nested_template_fields(content, context, jinja_env, seen_oids)

    def manage_template_specs(self):
        """
        Build the template body from ``application_file`` or ``template_spec``.

        ``application_file`` is first treated as a path; when it cannot be resolved to an
        existing file, its value is used as the body itself.

        :return: dict whose top level contains a ``spark`` key.
        :raises TypeError: if the ``application_file`` body does not load into a dict.
        :raises AirflowException: if neither ``application_file`` nor ``template_spec`` is set.
        """
        if self.application_file:
            try:
                filepath = Path(self.application_file.rstrip()).resolve(strict=True)
            except (FileNotFoundError, OSError, RuntimeError, ValueError):
                # Not a readable path: the field already holds the document content.
                application_file_body = self.application_file
            else:
                application_file_body = filepath.read_text()
            template_body = _load_body_to_dict(application_file_body)
            if not isinstance(template_body, dict):
                msg = f"application_file body can't transformed into the dictionary:\n{application_file_body}"
                raise TypeError(msg)
        elif self.template_spec:
            template_body = self.template_spec
        else:
            raise AirflowException("either application_file or template_spec should be passed")

        # Wrap a bare specification so consumers can always look under the "spark" key.
        if "spark" not in template_body:
            template_body = {"spark": template_body}
        return template_body

    def create_job_name(self):
        """
        Compute the job name and register it through ``_set_name``.

        The base name is the first truthy value among ``self.name``, the ``metadata.name``
        of the spark template and ``self.task_id``. It is either given a unique suffix or
        truncated so that it never exceeds ``MAX_LABEL_LEN``.

        :return: the value returned by ``_set_name``.
        """
        name = (
            self.name or self.template_body.get("spark", {}).get("metadata", {}).get("name") or self.task_id
        )

        if self.random_name_suffix:
            updated_name = add_unique_suffix(name=name, max_len=MAX_LABEL_LEN)
        else:
            # truncation is required to maintain the same behavior as before
            updated_name = name[:MAX_LABEL_LEN]

        return self._set_name(updated_name)

    @staticmethod
    def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool = True) -> dict[str, str]:
        """
        Generate labels for the pod to track the pod in case of Operator crash.

        :param include_try_number: add try number to labels
        :param context: task context provided by airflow DAG
        :return: dict.
        """
        if not context:
            return {}

        context_dict = cast(dict, context)
        ti = context_dict["ti"]
        run_id = context_dict["run_id"]

        labels = {
            "dag_id": ti.dag_id,
            "task_id": ti.task_id,
            "run_id": run_id,
            "spark_kubernetes_operator": "True",
        }

        # If running on Airflow 2.3+:
        map_index = getattr(ti, "map_index", -1)
        if map_index >= 0:
            labels["map_index"] = map_index

        if include_try_number:
            labels["try_number"] = ti.try_number

        # In the case of sub dags this is just useful
        # TODO: Remove this when the minimum version of Airflow is bumped to 3.0
        if getattr(context_dict["dag"], "is_subdag", False):
            labels["parent_dag_id"] = context_dict["dag"].parent_dag.dag_id

        # Ensure that label is valid for Kube,
        # and if not truncate/remove invalid chars and replace with short hash.
        return {
            label_id: pod_generator.make_safe_label_value(str(label)) for label_id, label in labels.items()
        }

    @cached_property
    def pod_manager(self) -> PodManager:
        """Pod manager bound to this operator's kubernetes client."""
        return PodManager(kube_client=self.client)

    @staticmethod
    def _try_numbers_match(context, pod) -> bool:
        """
        Tell whether the pod was created by the current try of the task instance.

        :param context: task context provided by airflow DAG
        :param pod: pod carrying a ``try_number`` label
        :return: True when both try numbers are equal.
        """
        # Labels are stored as strings, so compare textual forms on both sides.
        return str(pod.metadata.labels["try_number"]) == str(context["ti"].try_number)

    @property
    def template_body(self):
        """Templated body for CustomObjectLauncher."""
        return self.manage_template_specs()

    def find_spark_job(self, context, exclude_checked: bool = True):
        """
        Look up an existing driver pod matching this task instance.

        :param context: task context provided by airflow DAG
        :param exclude_checked: forwarded to ``_build_find_pod_label_selector``
        :return: the single matching pod, or None when there is none.
        :raises AirflowException: if more than one pod matches the label selector.
        """
        label_selector = (
            self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
            + ",spark-role=driver"
        )
        pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items

        if len(pod_list) > 1:
            raise AirflowException(f"More than one pod running with labels: {label_selector}")
        if not pod_list:
            return None

        pod = pod_list[0]
        self.log.info("Found matching driver pod %s with labels %s", pod.metadata.name, pod.metadata.labels)
        self.log.info("`try_number` of task_instance: %s", context["ti"].try_number)
        # Informational only: a missing label must not abort the lookup.
        self.log.info("`try_number` of pod: %s", (pod.metadata.labels or {}).get("try_number"))
        return pod

    def get_or_create_spark_crd(self, launcher: CustomObjectLauncher, context) -> k8s.V1Pod:
        """
        Return the driver pod, reattaching to an existing one when allowed.

        :param launcher: launcher used to start a new spark job when nothing is reused
        :param context: task context provided by airflow DAG
        :return: the reused or newly started driver pod.
        """
        if self.reattach_on_restart:
            driver_pod = self.find_spark_job(context)
            if driver_pod:
                return driver_pod

        # The launcher also returns the spark object spec, which is not needed here.
        driver_pod, _ = launcher.start_spark_job(
            image=self.image, code_path=self.code_path, startup_timeout=self.startup_timeout_seconds
        )
        return driver_pod

    def process_pod_deletion(self, pod, *, reraise=True):
        """
        Delete the spark job owning ``pod`` when ``delete_on_termination`` is set.

        :param pod: driver pod of the spark job; nothing happens when it is None
        :param reraise: accepted for interface compatibility; not used by this implementation
        """
        if pod is None:
            return

        if not self.delete_on_termination:
            self.log.info("skipping deleting spark job: %s", pod.metadata.name)
            return

        # Only strip a trailing "-driver"; removing every occurrence would corrupt job
        # names that contain "-driver" elsewhere.
        driver_suffix = "-driver"
        job_name = pod.metadata.name
        if job_name.endswith(driver_suffix):
            job_name = job_name[: -len(driver_suffix)]
        self.log.info("Deleting spark job: %s", job_name)
        self.launcher.delete_spark_job(job_name)

    @cached_property
    def hook(self) -> KubernetesHook:
        """
        Kubernetes hook for this operator.

        Operator attributes take precedence; values from the ``kubernetes`` section of the
        template body are used as fallbacks.
        """
        # Evaluate the template body once instead of once per setting.
        kube_settings = self.template_body.get("kubernetes", {})
        return KubernetesHook(
            conn_id=self.kubernetes_conn_id,
            in_cluster=self.in_cluster or kube_settings.get("in_cluster", False),
            config_file=self.config_file or kube_settings.get("kube_config_file", None),
            cluster_context=self.cluster_context or kube_settings.get("cluster_context", None),
        )

    @cached_property
    def client(self) -> CoreV1Api:
        """Core v1 API client obtained from the hook."""
        return self.hook.core_v1_client

    @cached_property
    def custom_obj_api(self) -> CustomObjectsApi:
        """Client for kubernetes custom objects."""
        return CustomObjectsApi()

    def execute(self, context: Context):
        """
        Create (or reattach to) the sparkApplication, then run the parent pod logic.

        :param context: task context provided by airflow DAG
        :return: whatever the parent ``execute`` returns.
        """
        self.name = self.create_job_name()

        self.log.info("Creating sparkApplication.")
        self.launcher = CustomObjectLauncher(
            name=self.name,
            namespace=self.namespace,
            kube_client=self.client,
            custom_obj_api=self.custom_obj_api,
            template_body=self.template_body,
        )
        self.pod = self.get_or_create_spark_crd(self.launcher, context)
        self.pod_request_obj = self.launcher.pod_spec

        return super().execute(context=context)

    def on_kill(self) -> None:
        """Delete the spark job when the task is killed, if a launcher was created."""
        # ``launcher`` only exists once ``execute`` has run; tolerate an earlier kill.
        launcher = getattr(self, "launcher", None)
        if launcher:
            self.log.debug("Deleting spark job for task %s", self.task_id)
            launcher.delete_spark_job()

    def patch_already_checked(self, pod: k8s.V1Pod, *, reraise=True):
        """
        Add an "already checked" label to ensure we don't reattach on retries.

        :param pod: pod to patch in the cluster
        :param reraise: accepted for interface compatibility; not used by this implementation
        """
        pod.metadata.labels["already_checked"] = "True"
        body = PodGenerator.serialize_pod(pod)
        self.client.patch_namespaced_pod(pod.metadata.name, pod.metadata.namespace, body)

    def dry_run(self) -> None:
        """Print out the spark job that would be created by this operator."""
        # NOTE(review): ``self.launcher`` is only assigned in ``execute`` — confirm callers
        # run this after the launcher exists.
        print(prune_dict(self.launcher.body, mode="strict"))