Source code for airflow.providers.cncf.kubernetes.operators.spark_kubernetes
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from collections.abc import Mapping
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any

from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s

from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes import pod_generator
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook, _load_body_to_dict
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import add_unique_suffix
from airflow.providers.cncf.kubernetes.operators.custom_object_launcher import CustomObjectLauncher
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN, PodGenerator
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
from airflow.utils.helpers import prune_dict

if TYPE_CHECKING:
    import jinja2

    from airflow.utils.context import Context
class SparkKubernetesOperator(KubernetesPodOperator):
    """
    Creates a SparkApplication object in a Kubernetes cluster.

    .. seealso::
        For more detail about the Spark Application Object have a look at the reference:
        https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/blob/v1beta2-1.3.3-3.1.1/docs/api-docs.md#sparkapplication

    :param image: Docker image you wish to launch. Defaults to hub.docker.com.
    :param code_path: path to the Spark code in the image.
    :param namespace: Kubernetes namespace to put the sparkApplication in.
    :param name: name of the pod in which the task will run. It will be used (plus a random
        suffix if random_name_suffix is True) to generate a pod id (DNS-1123 subdomain,
        containing only [a-z0-9.-]).
    :param application_file: filepath to the Kubernetes custom_resource_definition of the
        sparkApplication.
    :param template_spec: Kubernetes sparkApplication specification.
    :param get_logs: get the stdout of the container as logs of the task.
    :param do_xcom_push: If True, the content of the file /airflow/xcom/return.json in the
        container will also be pushed to an XCom when the container completes.
    :param success_run_history_limit: number of past successful runs of the application to keep.
    :param startup_timeout_seconds: timeout in seconds to start up the pod.
    :param log_events_on_failure: log the pod's events if a failure occurs.
    :param reattach_on_restart: if the scheduler dies while the pod is running, reattach and monitor.
    :param delete_on_termination: what to do when the pod reaches its final state, or when the
        execution is interrupted. If True (default), delete the pod; if False, leave the pod.
    :param kubernetes_conn_id: the connection to the Kubernetes cluster.
    :param random_name_suffix: if True, adds a random suffix to the pod name.
    """
    def __init__(
        self,
        *,
        image: str | None = None,
        code_path: str | None = None,
        namespace: str = "default",
        name: str | None = None,
        application_file: str | None = None,
        template_spec=None,
        get_logs: bool = True,
        do_xcom_push: bool = False,
        success_run_history_limit: int = 1,
        startup_timeout_seconds=600,
        log_events_on_failure: bool = False,
        reattach_on_restart: bool = True,
        delete_on_termination: bool = True,
        kubernetes_conn_id: str = "kubernetes_default",
        random_name_suffix: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(name=name, **kwargs)
        self.image = image
        self.code_path = code_path
        self.application_file = application_file
        self.template_spec = template_spec
        self.kubernetes_conn_id = kubernetes_conn_id
        self.startup_timeout_seconds = startup_timeout_seconds
        self.reattach_on_restart = reattach_on_restart
        self.delete_on_termination = delete_on_termination
        self.do_xcom_push = do_xcom_push
        self.namespace = namespace
        self.get_logs = get_logs
        self.log_events_on_failure = log_events_on_failure
        self.success_run_history_limit = success_run_history_limit
        self.random_name_suffix = random_name_suffix

        if self.base_container_name != self.BASE_CONTAINER_NAME:
            self.log.warning(
                "base_container_name is not supported and will be overridden to %s",
                self.BASE_CONTAINER_NAME,
            )
            self.base_container_name = self.BASE_CONTAINER_NAME

        if self.get_logs and self.container_logs != self.BASE_CONTAINER_NAME:
            self.log.warning(
                "container_logs is not supported and will be overridden to %s",
                self.BASE_CONTAINER_NAME,
            )
            self.container_logs = [self.BASE_CONTAINER_NAME]

    def _render_nested_template_fields(
        self,
        content: Any,
        context: Mapping[str, Any],
        jinja_env: jinja2.Environment,
        seen_oids: set,
    ) -> None:
        if id(content) not in seen_oids and isinstance(content, k8s.V1EnvVar):
            seen_oids.add(id(content))
            self._do_render_template_fields(content, ("value", "name"), context, jinja_env, seen_oids)
            return
        super()._render_nested_template_fields(content, context, jinja_env, seen_oids)
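A minimal usage sketch of the constructor above, assuming a DAG context; the namespace, manifest filename, and task id are illustrative placeholders, not values defined by this module:

from datetime import datetime

from airflow import DAG
from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator

with DAG(dag_id="spark_pi_example", schedule=None, start_date=datetime(2024, 1, 1)) as dag:
    submit = SparkKubernetesOperator(
        task_id="spark_pi_submit",
        namespace="spark-jobs",  # assumed namespace
        application_file="spark_pi.yaml",  # assumed SparkApplication manifest path
        kubernetes_conn_id="kubernetes_default",
        get_logs=True,
    )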
    def manage_template_specs(self):
        if self.application_file:
            try:
                filepath = Path(self.application_file.rstrip()).resolve(strict=True)
            except (FileNotFoundError, OSError, RuntimeError, ValueError):
                application_file_body = self.application_file
            else:
                application_file_body = filepath.read_text()
            template_body = _load_body_to_dict(application_file_body)
            if not isinstance(template_body, dict):
                msg = f"application_file body can't be transformed into a dictionary:\n{application_file_body}"
                raise TypeError(msg)
        elif self.template_spec:
            template_body = self.template_spec
        else:
            raise AirflowException("either application_file or template_spec should be passed")
        if "spark" not in template_body:
            template_body = {"spark": template_body}
        return template_body
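A small sketch of the wrapping step above: a bare SparkApplication manifest ends up nested under a top-level "spark" key. The YAML here is an assumed minimal example:

import textwrap

import yaml

manifest = yaml.safe_load(
    textwrap.dedent(
        """\
        apiVersion: sparkoperator.k8s.io/v1beta2
        kind: SparkApplication
        metadata:
          name: spark-pi
        """
    )
)
# Mirrors the final step of manage_template_specs().
template_body = manifest if "spark" in manifest else {"spark": manifest}
assert template_body["spark"]["metadata"]["name"] == "spark-pi"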
    def create_job_name(self):
        name = (
            self.name or self.template_body.get("spark", {}).get("metadata", {}).get("name") or self.task_id
        )

        if self.random_name_suffix:
            updated_name = add_unique_suffix(name=name, max_len=MAX_LABEL_LEN)
        else:
            # truncation is required to maintain the same behavior as before
            updated_name = name[:MAX_LABEL_LEN]

        return self._set_name(updated_name)
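Illustration of the naming behavior: add_unique_suffix truncates the name and appends a short random suffix so the result stays within MAX_LABEL_LEN, the 63-character Kubernetes label limit:

from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import add_unique_suffix
from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN

long_name = "spark-pi-" * 10  # deliberately longer than 63 characters
suffixed = add_unique_suffix(name=long_name, max_len=MAX_LABEL_LEN)
assert len(suffixed) <= MAX_LABEL_LEN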
    @staticmethod
    def create_labels_for_pod(context: dict | None = None, include_try_number: bool = True) -> dict:
        """
        Generate labels for the pod to track the pod in case of operator crash.

        :param include_try_number: add try number to labels
        :param context: task context provided by airflow DAG
        :return: dict
        """
        if not context:
            return {}

        ti = context["ti"]
        run_id = context["run_id"]

        labels = {
            "dag_id": ti.dag_id,
            "task_id": ti.task_id,
            "run_id": run_id,
            "spark_kubernetes_operator": "True",
            # 'execution_date': context['ts'],
            # 'try_number': context['ti'].try_number,
        }

        # If running on Airflow 2.3+:
        map_index = getattr(ti, "map_index", -1)
        if map_index >= 0:
            labels["map_index"] = map_index

        if include_try_number:
            labels.update(try_number=ti.try_number)

        # Only useful in the case of sub-DAGs.
        # TODO: Remove this when the minimum version of Airflow is bumped to 3.0
        if getattr(context["dag"], "is_subdag", False):
            labels["parent_dag_id"] = context["dag"].parent_dag.dag_id

        # Ensure that the label is valid for Kubernetes; if not, truncate it,
        # remove invalid chars, and replace them with a short hash.
        for label_id, label in labels.items():
            safe_label = pod_generator.make_safe_label_value(str(label))
            labels[label_id] = safe_label

        return labels
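A sketch of the labels this produces, using a hand-built stand-in for the runtime context (the SimpleNamespace objects substitute for the real TaskInstance and DAG, and the ids are placeholders):

from types import SimpleNamespace

ti = SimpleNamespace(dag_id="spark_pi_example", task_id="spark_pi_submit", map_index=-1, try_number=1)
context = {"ti": ti, "run_id": "manual__2024-01-01", "dag": SimpleNamespace(is_subdag=False)}

labels = SparkKubernetesOperator.create_labels_for_pod(context, include_try_number=True)
# -> {'dag_id': 'spark_pi_example', 'task_id': 'spark_pi_submit',
#     'run_id': 'manual__2024-01-01', 'spark_kubernetes_operator': 'True', 'try_number': '1'}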
    @cached_property
    def template_body(self):
        """Templated body for CustomObjectLauncher."""
        return self.manage_template_specs()
    def find_spark_job(self, context):
        labels = self.create_labels_for_pod(context, include_try_number=False)
        label_selector = self._get_pod_identifying_label_string(labels) + ",spark-role=driver"
        pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items

        pod = None
        if len(pod_list) > 1:  # and self.reattach_on_restart:
            raise AirflowException(f"More than one pod running with labels: {label_selector}")
        elif len(pod_list) == 1:
            pod = pod_list[0]
            self.log.info(
                "Found matching driver pod %s with labels %s", pod.metadata.name, pod.metadata.labels
            )
            self.log.info("`try_number` of task_instance: %s", context["ti"].try_number)
            self.log.info("`try_number` of pod: %s", pod.metadata.labels["try_number"])
        return pod
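The selector built above is a comma-joined key=value string. A hand-rolled equivalent with the plain Kubernetes client looks roughly like this; the kubeconfig, namespace, and label values are assumptions:

from kubernetes import client, config

config.load_kube_config()  # assumes a local kubeconfig
core_v1 = client.CoreV1Api()
selector = "dag_id=spark_pi_example,task_id=spark_pi_submit,spark-role=driver"
driver_pods = core_v1.list_namespaced_pod("spark-jobs", label_selector=selector).items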
    def on_kill(self) -> None:
        if self.launcher:
            self.log.debug("Deleting spark job for task %s", self.task_id)
            self.launcher.delete_spark_job()
    def patch_already_checked(self, pod: k8s.V1Pod, *, reraise=True):
        """Add an "already checked" label to ensure we don't reattach on retries."""
        pod.metadata.labels["already_checked"] = "True"
        body = PodGenerator.serialize_pod(pod)
        self.client.patch_namespaced_pod(pod.metadata.name, pod.metadata.namespace, body)
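For illustration, the serialized patch body resembles the output of this hand-built pod (names are placeholders):

pod = k8s.V1Pod(
    metadata=k8s.V1ObjectMeta(name="spark-pi-driver", namespace="spark-jobs", labels={})
)
pod.metadata.labels["already_checked"] = "True"
body = PodGenerator.serialize_pod(pod)  # camelCase dict accepted by patch_namespaced_pod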
    def dry_run(self) -> None:
        """Print out the spark job that would be created by this operator."""
        print(prune_dict(self.launcher.body, mode="strict"))
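prune_dict(mode="strict") drops None values (and containers that become empty after pruning) before printing. A small self-contained sketch of that filtering, with an assumed body:

from airflow.utils.helpers import prune_dict

body = {"spark": {"metadata": {"name": "spark-pi", "labels": {}}, "spec": None}}
print(prune_dict(body, mode="strict"))
# -> {'spark': {'metadata': {'name': 'spark-pi'}}}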