Source code for airflow.providers.google.cloud.triggers.vertex_ai
# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsfromcollections.abcimportAsyncIterator,Sequencefromfunctoolsimportcached_propertyfromtypingimportTYPE_CHECKING,Anyfromgoogle.cloud.aiplatform_v1import(BatchPredictionJob,HyperparameterTuningJob,JobState,PipelineState,types,)fromairflow.exceptionsimportAirflowExceptionfromairflow.providers.google.cloud.hooks.vertex_ai.batch_prediction_jobimportBatchPredictionJobAsyncHookfromairflow.providers.google.cloud.hooks.vertex_ai.custom_jobimportCustomJobAsyncHookfromairflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_jobimport(HyperparameterTuningJobAsyncHook,)fromairflow.providers.google.cloud.hooks.vertex_ai.pipeline_jobimportPipelineJobAsyncHookfromairflow.triggers.baseimportBaseTrigger,TriggerEventifTYPE_CHECKING:fromprotoimportMessage
class BaseVertexAIJobTrigger(BaseTrigger):
    """
    Base class for Vertex AI job triggers.

    This trigger polls the Vertex AI job and checks its status.

    In order to use it properly, you must:

    - implement the ``_wait_job()`` coroutine;
    - override the required ``job_type_verbose_name`` attribute to provide a
      meaningful message describing your job type;
    - override the required ``job_serializer_class`` attribute to provide the
      ``proto.Message`` class whose ``to_dict()`` class method is used to
      serialize your job.
    """

    # NOTE(review): ``statuses_success``, ``job_type_verbose_name`` and
    # ``job_serializer_class`` are read by the methods below but are not
    # assigned in this class body as shown; presumably they are supplied by
    # subclasses or by parts of the class not visible here — confirm.

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Wait for the Vertex AI job to finish and emit a single trigger event.

        If ``_wait_job()`` raises ``AirflowException``, an ``"error"`` event
        carrying the exception text is yielded and the generator stops; any
        other exception propagates to the caller.

        Otherwise the event's ``status`` is ``"success"`` when the job's state
        is in ``statuses_success`` and ``"error"`` otherwise. The event also
        carries a human-readable ``message`` and the serialized ``job``.
        """
        try:
            job = await self._wait_job()
        except AirflowException as ex:
            yield TriggerEvent(
                {
                    "status": "error",
                    "message": str(ex),
                }
            )
            return

        status = "success" if job.state in self.statuses_success else "error"
        # A space must separate the job-type name from the job name; without
        # it the two run together (e.g. "Custom Container Training Jobprojects/...").
        message = (
            f"{self.job_type_verbose_name} {job.name} "
            f"completed with status {job.state.name}"
        )
        yield TriggerEvent(
            {
                "status": status,
                "message": message,
                "job": self._serialize_job(job),
            }
        )

    async def _wait_job(self) -> Any:
        """
        Await a Vertex AI job instance for a status examination.

        Subclasses must override this; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    def _serialize_job(self, job: Any) -> Any:
        """Return ``job`` converted via ``self.job_serializer_class.to_dict``."""
        return self.job_serializer_class.to_dict(job)
class CreateHyperparameterTuningJobTrigger(BaseVertexAIJobTrigger):
    """
    Trigger for the hyperparameter tuning job create operation.

    It runs on the trigger worker.
    """
class CreateBatchPredictionJobTrigger(BaseVertexAIJobTrigger):
    """
    Trigger for the batch prediction job create operation.

    It runs on the trigger worker.
    """
class CustomTrainingJobTrigger(BaseVertexAIJobTrigger):
    """
    Check the state of a running custom training job via async Vertex AI calls.

    The job is returned once it enters a completed state.
    """
class CustomContainerTrainingJobTrigger(BaseVertexAIJobTrigger):
    """
    Check the state of a running custom container training job via async Vertex AI calls.

    The job is returned once it enters a completed state.
    """

    # Human-readable job type; the base class ``run`` puts it at the start of
    # the trigger event message.
    job_type_verbose_name = "Custom Container Training Job"
class CustomPythonPackageTrainingJobTrigger(BaseVertexAIJobTrigger):
    """
    Check the state of a running custom python package training job via async Vertex AI calls.

    The job is returned once it enters a completed state.
    """

    # Human-readable job type; the base class ``run`` puts it at the start of
    # the trigger event message.
    job_type_verbose_name = "Custom Python Package Training Job"