
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""This module contains Google Vertex AI operators."""

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.models import Model
from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline

from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook
from airflow.providers.google.cloud.links.vertex_ai import (
    VertexAIModelLink,
    VertexAITrainingLink,
    VertexAITrainingPipelinesLink,
)
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator

if TYPE_CHECKING:
    from google.api_core.retry import Retry

    from airflow.utils.context import Context


class AutoMLTrainingJobBaseOperator(GoogleCloudBaseOperator):
    """The base class for operators that launch AutoML jobs on Vertex AI."""

    def __init__(
        self,
        *,
        project_id: str,
        region: str,
        display_name: str,
        labels: dict[str, str] | None = None,
        parent_model: str | None = None,
        is_default_version: bool | None = None,
        model_version_aliases: list[str] | None = None,
        model_version_description: str | None = None,
        training_encryption_spec_key_name: str | None = None,
        model_encryption_spec_key_name: str | None = None,
        # RUN
        training_fraction_split: float | None = None,
        test_fraction_split: float | None = None,
        model_display_name: str | None = None,
        model_labels: dict[str, str] | None = None,
        sync: bool = True,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.region = region
        self.display_name = display_name
        self.labels = labels
        self.parent_model = parent_model
        self.is_default_version = is_default_version
        self.model_version_aliases = model_version_aliases
        self.model_version_description = model_version_description
        self.training_encryption_spec_key_name = training_encryption_spec_key_name
        self.model_encryption_spec_key_name = model_encryption_spec_key_name
        # START Run param
        self.training_fraction_split = training_fraction_split
        self.test_fraction_split = test_fraction_split
        self.model_display_name = model_display_name
        self.model_labels = model_labels
        self.sync = sync
        # END Run param
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain
        self.hook: AutoMLHook | None = None

    def on_kill(self) -> None:
        """Cancel any running AutoML job when the operator is killed."""
        if self.hook:
            self.hook.cancel_auto_ml_job()


class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
    """Create an AutoML Forecasting Training job."""

    template_fields = (
        "parent_model",
        "dataset_id",
        "region",
        "impersonation_chain",
    )

    def __init__(
        self,
        *,
        dataset_id: str,
        target_column: str,
        time_column: str,
        time_series_identifier_column: str,
        unavailable_at_forecast_columns: list[str],
        available_at_forecast_columns: list[str],
        forecast_horizon: int,
        data_granularity_unit: str,
        data_granularity_count: int,
        optimization_objective: str | None = None,
        column_specs: dict[str, str] | None = None,
        column_transformations: list[dict[str, dict[str, str]]] | None = None,
        validation_fraction_split: float | None = None,
        predefined_split_column_name: str | None = None,
        weight_column: str | None = None,
        time_series_attribute_columns: list[str] | None = None,
        context_window: int | None = None,
        export_evaluated_data_items: bool = False,
        export_evaluated_data_items_bigquery_destination_uri: str | None = None,
        export_evaluated_data_items_override_destination: bool = False,
        quantiles: list[float] | None = None,
        validation_options: str | None = None,
        budget_milli_node_hours: int = 1000,
        region: str,
        impersonation_chain: str | Sequence[str] | None = None,
        parent_model: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            region=region, impersonation_chain=impersonation_chain, parent_model=parent_model, **kwargs
        )
        self.dataset_id = dataset_id
        self.target_column = target_column
        self.time_column = time_column
        self.time_series_identifier_column = time_series_identifier_column
        self.unavailable_at_forecast_columns = unavailable_at_forecast_columns
        self.available_at_forecast_columns = available_at_forecast_columns
        self.forecast_horizon = forecast_horizon
        self.data_granularity_unit = data_granularity_unit
        self.data_granularity_count = data_granularity_count
        self.optimization_objective = optimization_objective
        self.column_specs = column_specs
        self.column_transformations = column_transformations
        self.validation_fraction_split = validation_fraction_split
        self.predefined_split_column_name = predefined_split_column_name
        self.weight_column = weight_column
        self.time_series_attribute_columns = time_series_attribute_columns
        self.context_window = context_window
        self.export_evaluated_data_items = export_evaluated_data_items
        self.export_evaluated_data_items_bigquery_destination_uri = (
            export_evaluated_data_items_bigquery_destination_uri
        )
        self.export_evaluated_data_items_override_destination = (
            export_evaluated_data_items_override_destination
        )
        self.quantiles = quantiles
        self.validation_options = validation_options
        self.budget_milli_node_hours = budget_milli_node_hours

    def execute(self, context: Context):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
        model, training_id = self.hook.create_auto_ml_forecasting_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.TimeSeriesDataset(dataset_name=self.dataset_id),
            parent_model=self.parent_model,
            is_default_version=self.is_default_version,
            model_version_aliases=self.model_version_aliases,
            model_version_description=self.model_version_description,
            target_column=self.target_column,
            time_column=self.time_column,
            time_series_identifier_column=self.time_series_identifier_column,
            unavailable_at_forecast_columns=self.unavailable_at_forecast_columns,
            available_at_forecast_columns=self.available_at_forecast_columns,
            forecast_horizon=self.forecast_horizon,
            data_granularity_unit=self.data_granularity_unit,
            data_granularity_count=self.data_granularity_count,
            optimization_objective=self.optimization_objective,
            column_specs=self.column_specs,
            column_transformations=self.column_transformations,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            validation_fraction_split=self.validation_fraction_split,
            test_fraction_split=self.test_fraction_split,
            predefined_split_column_name=self.predefined_split_column_name,
            weight_column=self.weight_column,
            time_series_attribute_columns=self.time_series_attribute_columns,
            context_window=self.context_window,
            export_evaluated_data_items=self.export_evaluated_data_items,
            export_evaluated_data_items_bigquery_destination_uri=(
                self.export_evaluated_data_items_bigquery_destination_uri
            ),
            export_evaluated_data_items_override_destination=(
                self.export_evaluated_data_items_override_destination
            ),
            quantiles=self.quantiles,
            validation_options=self.validation_options,
            budget_milli_node_hours=self.budget_milli_node_hours,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            sync=self.sync,
        )

        if model:
            result = Model.to_dict(model)
            model_id = self.hook.extract_model_id(result)
            self.xcom_push(context, key="model_id", value=model_id)
            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
        else:
            result = model  # type: ignore
        self.xcom_push(context, key="training_id", value=training_id)
        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
        return result
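

# Illustrative usage sketch (not part of the module): a minimal DAG task that
# launches an AutoML forecasting job. The project, dataset ID, and column
# names below are hypothetical placeholders.
#
#   create_forecasting_job = CreateAutoMLForecastingTrainingJobOperator(
#       task_id="create_auto_ml_forecasting_job",
#       project_id="my-project",  # hypothetical
#       region="us-central1",
#       display_name="forecasting-job",
#       dataset_id="1234567890",  # hypothetical TimeSeriesDataset ID
#       target_column="sales",
#       time_column="date",
#       time_series_identifier_column="store_id",
#       unavailable_at_forecast_columns=["sales"],
#       available_at_forecast_columns=["date"],
#       forecast_horizon=30,
#       data_granularity_unit="day",
#       data_granularity_count=1,
#   )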


class CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
    """Create an AutoML Image Training job."""

    template_fields = (
        "parent_model",
        "dataset_id",
        "region",
        "impersonation_chain",
    )

    def __init__(
        self,
        *,
        dataset_id: str,
        prediction_type: str = "classification",
        multi_label: bool = False,
        model_type: str = "CLOUD",
        base_model: Model | None = None,
        validation_fraction_split: float | None = None,
        training_filter_split: str | None = None,
        validation_filter_split: str | None = None,
        test_filter_split: str | None = None,
        budget_milli_node_hours: int | None = None,
        disable_early_stopping: bool = False,
        region: str,
        impersonation_chain: str | Sequence[str] | None = None,
        parent_model: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            region=region, impersonation_chain=impersonation_chain, parent_model=parent_model, **kwargs
        )
        self.dataset_id = dataset_id
        self.prediction_type = prediction_type
        self.multi_label = multi_label
        self.model_type = model_type
        self.base_model = base_model
        self.validation_fraction_split = validation_fraction_split
        self.training_filter_split = training_filter_split
        self.validation_filter_split = validation_filter_split
        self.test_filter_split = test_filter_split
        self.budget_milli_node_hours = budget_milli_node_hours
        self.disable_early_stopping = disable_early_stopping

    def execute(self, context: Context):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
        model, training_id = self.hook.create_auto_ml_image_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.ImageDataset(dataset_name=self.dataset_id),
            parent_model=self.parent_model,
            is_default_version=self.is_default_version,
            model_version_aliases=self.model_version_aliases,
            model_version_description=self.model_version_description,
            prediction_type=self.prediction_type,
            multi_label=self.multi_label,
            model_type=self.model_type,
            base_model=self.base_model,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            validation_fraction_split=self.validation_fraction_split,
            test_fraction_split=self.test_fraction_split,
            training_filter_split=self.training_filter_split,
            validation_filter_split=self.validation_filter_split,
            test_filter_split=self.test_filter_split,
            budget_milli_node_hours=self.budget_milli_node_hours,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            disable_early_stopping=self.disable_early_stopping,
            sync=self.sync,
        )

        if model:
            result = Model.to_dict(model)
            model_id = self.hook.extract_model_id(result)
            self.xcom_push(context, key="model_id", value=model_id)
            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
        else:
            result = model  # type: ignore
        self.xcom_push(context, key="training_id", value=training_id)
        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
        return result
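

# Illustrative usage sketch (not part of the module): an image classification
# job. prediction_type and model_type take the AutoML enum strings; the
# project and dataset IDs below are hypothetical placeholders.
#
#   create_image_job = CreateAutoMLImageTrainingJobOperator(
#       task_id="create_auto_ml_image_job",
#       project_id="my-project",  # hypothetical
#       region="us-central1",
#       display_name="image-job",
#       dataset_id="1234567890",  # hypothetical ImageDataset ID
#       prediction_type="classification",
#       multi_label=False,
#       model_type="CLOUD",
#       budget_milli_node_hours=8000,
#   )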


class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
    """Create an AutoML Tabular Training job."""

    template_fields = (
        "parent_model",
        "dataset_id",
        "region",
        "impersonation_chain",
    )

    def __init__(
        self,
        *,
        dataset_id: str,
        target_column: str,
        optimization_prediction_type: str,
        optimization_objective: str | None = None,
        column_specs: dict[str, str] | None = None,
        column_transformations: list[dict[str, dict[str, str]]] | None = None,
        optimization_objective_recall_value: float | None = None,
        optimization_objective_precision_value: float | None = None,
        validation_fraction_split: float | None = None,
        predefined_split_column_name: str | None = None,
        timestamp_split_column_name: str | None = None,
        weight_column: str | None = None,
        budget_milli_node_hours: int = 1000,
        disable_early_stopping: bool = False,
        export_evaluated_data_items: bool = False,
        export_evaluated_data_items_bigquery_destination_uri: str | None = None,
        export_evaluated_data_items_override_destination: bool = False,
        region: str,
        impersonation_chain: str | Sequence[str] | None = None,
        parent_model: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            region=region, impersonation_chain=impersonation_chain, parent_model=parent_model, **kwargs
        )
        self.dataset_id = dataset_id
        self.target_column = target_column
        self.optimization_prediction_type = optimization_prediction_type
        self.optimization_objective = optimization_objective
        self.column_specs = column_specs
        self.column_transformations = column_transformations
        self.optimization_objective_recall_value = optimization_objective_recall_value
        self.optimization_objective_precision_value = optimization_objective_precision_value
        self.validation_fraction_split = validation_fraction_split
        self.predefined_split_column_name = predefined_split_column_name
        self.timestamp_split_column_name = timestamp_split_column_name
        self.weight_column = weight_column
        self.budget_milli_node_hours = budget_milli_node_hours
        self.disable_early_stopping = disable_early_stopping
        self.export_evaluated_data_items = export_evaluated_data_items
        self.export_evaluated_data_items_bigquery_destination_uri = (
            export_evaluated_data_items_bigquery_destination_uri
        )
        self.export_evaluated_data_items_override_destination = (
            export_evaluated_data_items_override_destination
        )

    def execute(self, context: Context):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        credentials, _ = self.hook.get_credentials_and_project_id()
        self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
        model, training_id = self.hook.create_auto_ml_tabular_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.TabularDataset(
                dataset_name=self.dataset_id,
                project=self.project_id,
                credentials=credentials,
            ),
            parent_model=self.parent_model,
            is_default_version=self.is_default_version,
            model_version_aliases=self.model_version_aliases,
            model_version_description=self.model_version_description,
            target_column=self.target_column,
            optimization_prediction_type=self.optimization_prediction_type,
            optimization_objective=self.optimization_objective,
            column_specs=self.column_specs,
            column_transformations=self.column_transformations,
            optimization_objective_recall_value=self.optimization_objective_recall_value,
            optimization_objective_precision_value=self.optimization_objective_precision_value,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            validation_fraction_split=self.validation_fraction_split,
            test_fraction_split=self.test_fraction_split,
            predefined_split_column_name=self.predefined_split_column_name,
            timestamp_split_column_name=self.timestamp_split_column_name,
            weight_column=self.weight_column,
            budget_milli_node_hours=self.budget_milli_node_hours,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            disable_early_stopping=self.disable_early_stopping,
            export_evaluated_data_items=self.export_evaluated_data_items,
            export_evaluated_data_items_bigquery_destination_uri=(
                self.export_evaluated_data_items_bigquery_destination_uri
            ),
            export_evaluated_data_items_override_destination=(
                self.export_evaluated_data_items_override_destination
            ),
            sync=self.sync,
        )

        if model:
            result = Model.to_dict(model)
            model_id = self.hook.extract_model_id(result)
            self.xcom_push(context, key="model_id", value=model_id)
            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
        else:
            result = model  # type: ignore
        self.xcom_push(context, key="training_id", value=training_id)
        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
        return result
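

# Illustrative usage sketch (not part of the module): a tabular classification
# job. column_specs maps column names to AutoML transformation types such as
# "auto", "numeric", or "categorical"; all IDs and column names below are
# hypothetical placeholders.
#
#   create_tabular_job = CreateAutoMLTabularTrainingJobOperator(
#       task_id="create_auto_ml_tabular_job",
#       project_id="my-project",  # hypothetical
#       region="us-central1",
#       display_name="tabular-job",
#       dataset_id="1234567890",  # hypothetical TabularDataset ID
#       target_column="churned",
#       optimization_prediction_type="classification",
#       column_specs={"age": "numeric", "plan": "categorical"},  # hypothetical
#   )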


class CreateAutoMLTextTrainingJobOperator(AutoMLTrainingJobBaseOperator):
    """Create an AutoML Text Training job."""

    template_fields = (
        "parent_model",
        "dataset_id",
        "region",
        "impersonation_chain",
    )

    def __init__(
        self,
        *,
        dataset_id: str,
        prediction_type: str,
        multi_label: bool = False,
        sentiment_max: int = 10,
        validation_fraction_split: float | None = None,
        training_filter_split: str | None = None,
        validation_filter_split: str | None = None,
        test_filter_split: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.dataset_id = dataset_id
        self.prediction_type = prediction_type
        self.multi_label = multi_label
        self.sentiment_max = sentiment_max
        self.validation_fraction_split = validation_fraction_split
        self.training_filter_split = training_filter_split
        self.validation_filter_split = validation_filter_split
        self.test_filter_split = test_filter_split

    def execute(self, context: Context):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
        model, training_id = self.hook.create_auto_ml_text_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.TextDataset(dataset_name=self.dataset_id),
            prediction_type=self.prediction_type,
            multi_label=self.multi_label,
            sentiment_max=self.sentiment_max,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            validation_fraction_split=self.validation_fraction_split,
            test_fraction_split=self.test_fraction_split,
            training_filter_split=self.training_filter_split,
            validation_filter_split=self.validation_filter_split,
            test_filter_split=self.test_filter_split,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            sync=self.sync,
            parent_model=self.parent_model,
            is_default_version=self.is_default_version,
            model_version_aliases=self.model_version_aliases,
            model_version_description=self.model_version_description,
        )

        if model:
            result = Model.to_dict(model)
            model_id = self.hook.extract_model_id(result)
            self.xcom_push(context, key="model_id", value=model_id)
            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
        else:
            result = model  # type: ignore
        self.xcom_push(context, key="training_id", value=training_id)
        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
        return result


class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
    """Create an AutoML Video Training job."""

    template_fields = (
        "parent_model",
        "dataset_id",
        "region",
        "impersonation_chain",
    )

    def __init__(
        self,
        *,
        dataset_id: str,
        prediction_type: str = "classification",
        model_type: str = "CLOUD",
        training_filter_split: str | None = None,
        test_filter_split: str | None = None,
        region: str,
        impersonation_chain: str | Sequence[str] | None = None,
        parent_model: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            region=region, impersonation_chain=impersonation_chain, parent_model=parent_model, **kwargs
        )
        self.dataset_id = dataset_id
        self.prediction_type = prediction_type
        self.model_type = model_type
        self.training_filter_split = training_filter_split
        self.test_filter_split = test_filter_split

    def execute(self, context: Context):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
        model, training_id = self.hook.create_auto_ml_video_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.VideoDataset(dataset_name=self.dataset_id),
            prediction_type=self.prediction_type,
            model_type=self.model_type,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            test_fraction_split=self.test_fraction_split,
            training_filter_split=self.training_filter_split,
            test_filter_split=self.test_filter_split,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            sync=self.sync,
            parent_model=self.parent_model,
            is_default_version=self.is_default_version,
            model_version_aliases=self.model_version_aliases,
            model_version_description=self.model_version_description,
        )

        if model:
            result = Model.to_dict(model)
            model_id = self.hook.extract_model_id(result)
            self.xcom_push(context, key="model_id", value=model_id)
            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
        else:
            result = model  # type: ignore
        self.xcom_push(context, key="training_id", value=training_id)
        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
        return result


class DeleteAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
    """
    Delete an AutoML training job.

    Can be used with AutoMLForecastingTrainingJob, AutoMLImageTrainingJob, AutoMLTabularTrainingJob,
    AutoMLTextTrainingJob, or AutoMLVideoTrainingJob.
    """

    template_fields = ("training_pipeline", "region", "project_id", "impersonation_chain")

    def __init__(
        self,
        *,
        training_pipeline_id: str,
        region: str,
        project_id: str,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.training_pipeline = training_pipeline_id
        self.region = region
        self.project_id = project_id
        self.retry = retry
        self.timeout = timeout
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        try:
            self.log.info("Deleting Auto ML training pipeline: %s", self.training_pipeline)
            training_pipeline_operation = hook.delete_training_pipeline(
                training_pipeline=self.training_pipeline,
                region=self.region,
                project_id=self.project_id,
                retry=self.retry,
                timeout=self.timeout,
                metadata=self.metadata,
            )
            hook.wait_for_operation(timeout=self.timeout, operation=training_pipeline_operation)
            self.log.info("Training pipeline was deleted.")
        except NotFound:
            self.log.info("The Training Pipeline ID %s does not exist.", self.training_pipeline)
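

# Illustrative usage sketch (not part of the module): deleting the pipeline
# created by an upstream task. The create operators push "training_id" to
# XCom, and "training_pipeline" is a templated field, so the pipeline ID can
# be pulled with Jinja; the task ID below is hypothetical.
#
#   delete_job = DeleteAutoMLTrainingJobOperator(
#       task_id="delete_auto_ml_training_job",
#       training_pipeline_id=(
#           "{{ task_instance.xcom_pull("
#           "task_ids='create_auto_ml_forecasting_job', key='training_id') }}"
#       ),
#       region="us-central1",
#       project_id="my-project",  # hypothetical
#   )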


class ListAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
    """
    List AutoML training jobs.

    Can be used with AutoMLForecastingTrainingJob, AutoMLImageTrainingJob, AutoMLTabularTrainingJob,
    AutoMLTextTrainingJob, or AutoMLVideoTrainingJob in a Location.
    """

    template_fields = (
        "region",
        "project_id",
        "impersonation_chain",
    )

    def __init__(
        self,
        *,
        region: str,
        project_id: str,
        page_size: int | None = None,
        page_token: str | None = None,
        filter: str | None = None,
        read_mask: str | None = None,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.region = region
        self.project_id = project_id
        self.page_size = page_size
        self.page_token = page_token
        self.filter = filter
        self.read_mask = read_mask
        self.retry = retry
        self.timeout = timeout
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        results = hook.list_training_pipelines(
            region=self.region,
            project_id=self.project_id,
            page_size=self.page_size,
            page_token=self.page_token,
            filter=self.filter,
            read_mask=self.read_mask,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )
        VertexAITrainingPipelinesLink.persist(context=context, task_instance=self)
        return [TrainingPipeline.to_dict(result) for result in results]
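

# Illustrative usage sketch (not part of the module): listing training
# pipelines in a location. The filter string follows the Vertex AI
# ListTrainingPipelines API; all values below are hypothetical placeholders.
#
#   list_jobs = ListAutoMLTrainingJobOperator(
#       task_id="list_auto_ml_training_jobs",
#       region="us-central1",
#       project_id="my-project",  # hypothetical
#       filter='display_name="forecasting-job"',  # hypothetical
#   )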
