#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains a Google ML Engine Hook."""
from __future__ import annotations
import contextlib
import logging
import random
import time
from typing import TYPE_CHECKING, Callable
from aiohttp import ClientSession
from gcloud.aio.auth import AioSession, Token
from googleapiclient.discovery import Resource, build
from googleapiclient.errors import HttpError
from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import (
PROVIDE_PROJECT_ID,
GoogleBaseAsyncHook,
GoogleBaseHook,
)
from airflow.version import version as airflow_version
if TYPE_CHECKING:
from httplib2 import Response
from requests import Session
# Module-level logger for the helpers defined outside the hook classes
# (the hook classes themselves log through ``self.log``).
log = logging.getLogger(__name__)

# Airflow version rewritten as "v<version>" with every "." and "+" replaced by "-".
# It is attached to created resources under the "airflow-version" label.
_AIRFLOW_VERSION = "v" + airflow_version.replace(".", "-").replace("+", "-")
def _poll_with_exponential_delay(
request, execute_num_retries, max_n, is_done_func, is_error_func
) -> Response:
"""
Execute request with exponential delay.
This method is intended to handle and retry in case of api-specific errors,
such as 429 "Too Many Requests", unlike the `request.execute` which handles
lower level errors like `ConnectionError`/`socket.timeout`/`ssl.SSLError`.
:param request: request to be executed.
:param execute_num_retries: num_retries for `request.execute` method.
:param max_n: number of times to retry request in this method.
:param is_done_func: callable to determine if operation is done.
:param is_error_func: callable to determine if operation is failed.
:return: response
"""
for i in range(0, max_n):
try:
response = request.execute(num_retries=execute_num_retries)
if is_error_func(response):
raise ValueError(f"The response contained an error: {response}")
if is_done_func(response):
log.info("Operation is done: %s", response)
return response
time.sleep((2**i) + random.random())
except HttpError as e:
if e.resp.status != 429:
log.info("Something went wrong. Not retrying: %s", format(e))
raise
else:
time.sleep((2**i) + random.random())
raise ValueError(f"Connection could not be established after {max_n} retries.")
class MLEngineHook(GoogleBaseHook):
    """
    Hook for Google ML Engine APIs.

    All the methods in the hook where project_id is used must be called with
    keyword arguments rather than positional.
    """

    def get_conn(self) -> Resource:
        """
        Retrieve the connection to MLEngine.

        :return: Google MLEngine services object.
        """
        authed_http = self._authorize()
        return build("ml", "v1", http=authed_http, cache_discovery=False)

    @GoogleBaseHook.fallback_to_default_project_id
    def create_job(self, job: dict, project_id: str, use_existing_job_fn: Callable | None = None) -> dict:
        """
        Launch a MLEngine job and wait for it to reach a terminal state.

        :param project_id: The Google Cloud project id within which MLEngine
            job will be launched. If set to None or missing, the default project_id from the Google Cloud
            connection is used.
        :param job: MLEngine Job object that should be provided to the MLEngine
            API, such as: ::

                {
                  'jobId': 'my_job_id',
                  'trainingInput': {
                    'scaleTier': 'STANDARD_1',
                    ...
                  }
                }

            The dict is modified in place: an ``airflow-version`` label is added to it.
        :param use_existing_job_fn: In case that a MLEngine job with the same
            job_id already exist, this method (if provided) will decide whether
            we should use this existing job, continue waiting for it to finish
            and returning the job object. It should accepts a MLEngine job
            object, and returns a boolean value indicating whether it is OK to
            reuse the existing job. If 'use_existing_job_fn' is not provided,
            we by default reuse the existing MLEngine job.
        :return: The MLEngine job object if the job successfully reach a
            terminal state (which might be FAILED or CANCELLED state).
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()

        self._append_label(job)
        self.log.info("Creating job.")

        request = hook.projects().jobs().create(parent=f"projects/{project_id}", body=job)
        job_id = job["jobId"]

        try:
            request.execute(num_retries=self.num_retries)
        except HttpError as e:
            # 409 means there is an existing job with the same job ID.
            if e.resp.status == 409:
                if use_existing_job_fn is not None:
                    existing_job = self.get_job(project_id, job_id)
                    if not use_existing_job_fn(existing_job):
                        self.log.error(
                            "Job with job_id %s already exist, but it does not match our expectation: %s",
                            job_id,
                            existing_job,
                        )
                        raise
                self.log.info("Job with job_id %s already exist. Will waiting for it to finish", job_id)
            else:
                self.log.error("Failed to create MLEngine job: %s", e)
                raise

        return self._wait_for_job_done(project_id, job_id)

    @GoogleBaseHook.fallback_to_default_project_id
    def create_job_without_waiting_result(
        self,
        body: dict,
        project_id: str,
    ):
        """
        Launch a MLEngine job and return immediately, without waiting for a terminal state.

        :param project_id: The Google Cloud project id within which MLEngine
            job will be launched. If set to None or missing, the default project_id from the Google Cloud
            connection is used.
        :param body: MLEngine Job object that should be provided to the MLEngine
            API, such as: ::

                {
                  'jobId': 'my_job_id',
                  'trainingInput': {
                    'scaleTier': 'STANDARD_1',
                    ...
                  }
                }

            The dict is modified in place: an ``airflow-version`` label is added to it.
        :return: The ``jobId`` taken from ``body``, returned once the create request
            has been executed.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()

        self._append_label(body)
        request = hook.projects().jobs().create(parent=f"projects/{project_id}", body=body)
        job_id = body["jobId"]
        request.execute(num_retries=self.num_retries)
        return job_id

    @GoogleBaseHook.fallback_to_default_project_id
    def cancel_job(
        self,
        job_id: str,
        project_id: str,
    ) -> dict:
        """
        Cancel a MLEngine job.

        :param project_id: The Google Cloud project id within which MLEngine
            job will be cancelled. If set to None or missing, the default project_id from the Google Cloud
            connection is used.
        :param job_id: A unique id for the want-to-be cancelled Google MLEngine training job.
        :return: The API response if cancelled successfully, or an empty dict when the
            API answers 400 (handled here as "job already complete").
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()

        request = hook.projects().jobs().cancel(name=f"projects/{project_id}/jobs/{job_id}")

        try:
            return request.execute(num_retries=self.num_retries)
        except HttpError as e:
            if e.resp.status == 404:
                self.log.error("Job with job_id %s does not exist. ", job_id)
                raise
            elif e.resp.status == 400:
                self.log.info("Job with job_id %s is already complete, cancellation aborted.", job_id)
                return {}
            else:
                self.log.error("Failed to cancel MLEngine job: %s", e)
                raise

    def get_job(self, project_id: str, job_id: str) -> dict:
        """
        Get a MLEngine job based on the job id.

        A 429 response is retried every 30 seconds for as long as it keeps occurring.

        :param project_id: The project in which the Job is located. This method does not
            fall back to the connection's default project, so a value must be given.
        :param job_id: A unique id for the Google MLEngine job. (templated)
        :return: MLEngine job object if succeed.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()
        job_name = f"projects/{project_id}/jobs/{job_id}"
        request = hook.projects().jobs().get(name=job_name)
        while True:
            try:
                return request.execute(num_retries=self.num_retries)
            except HttpError as e:
                if e.resp.status == 429:
                    # polling after 30 seconds when quota failure occurs
                    time.sleep(30)
                else:
                    self.log.error("Failed to get MLEngine job: %s", e)
                    raise

    def _wait_for_job_done(self, project_id: str, job_id: str, interval: int = 30):
        """
        Wait for the Job to reach a terminal state.

        This method will periodically check the job state until the job reach
        a terminal state.

        :param project_id: The project in which the Job is located. This method does not
            fall back to the connection's default project, so a value must be given.
        :param job_id: A unique id for the Google MLEngine job. (templated)
        :param interval: Time expressed in seconds after which the job status is checked again. (templated)
        :return: The job object once its state is SUCCEEDED, FAILED or CANCELLED.
        :raises ValueError: if ``interval`` is not positive.
        :raises: googleapiclient.errors.HttpError
        """
        self.log.info("Waiting for job. job_id=%s", job_id)

        if interval <= 0:
            raise ValueError("Interval must be > 0")
        while True:
            job = self.get_job(project_id, job_id)
            if job["state"] in ["SUCCEEDED", "FAILED", "CANCELLED"]:
                return job
            time.sleep(interval)

    @GoogleBaseHook.fallback_to_default_project_id
    def create_version(
        self,
        model_name: str,
        version_spec: dict,
        project_id: str,
    ) -> dict:
        """
        Create the Version on Google Cloud ML Engine.

        :param version_spec: A dictionary containing the information about the version. (templated)
            The dict is modified in place: an ``airflow-version`` label is added to it.
        :param model_name: The name of the Google Cloud ML Engine model that the version belongs to.
            (templated)
        :param project_id: The Google Cloud project name to which MLEngine model belongs.
            If set to None or missing, the default project_id from the Google Cloud connection is used.
            (templated)
        :return: If the version was created successfully, returns the operation.
            Otherwise raises an error .
        """
        hook = self.get_conn()
        parent_name = f"projects/{project_id}/models/{model_name}"

        self._append_label(version_spec)

        create_request = hook.projects().models().versions().create(parent=parent_name, body=version_spec)
        response = create_request.execute(num_retries=self.num_retries)
        # The create call returns an operation; poll that operation until it reports done.
        get_request = hook.projects().operations().get(name=response["name"])

        return _poll_with_exponential_delay(
            request=get_request,
            execute_num_retries=self.num_retries,
            max_n=9,
            is_done_func=lambda resp: resp.get("done", False),
            is_error_func=lambda resp: resp.get("error", None) is not None,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def set_default_version(
        self,
        model_name: str,
        version_name: str,
        project_id: str,
    ) -> dict:
        """
        Set a version to be the default. Blocks until finished.

        :param model_name: The name of the Google Cloud ML Engine model that the version belongs to.
            (templated)
        :param version_name: A name to use for the version being operated upon. (templated)
        :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None
            or missing, the default project_id from the Google Cloud connection is used. (templated)
        :return: If successful, return an instance of Version.
            Otherwise raises an error.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()
        full_version_name = f"projects/{project_id}/models/{model_name}/versions/{version_name}"

        request = hook.projects().models().versions().setDefault(name=full_version_name, body={})

        try:
            response = request.execute(num_retries=self.num_retries)
            self.log.info("Successfully set version: %s to default", response)
            return response
        except HttpError as e:
            self.log.error("Something went wrong: %s", e)
            raise

    @GoogleBaseHook.fallback_to_default_project_id
    def list_versions(
        self,
        model_name: str,
        project_id: str,
    ) -> list[dict]:
        """
        List all available versions of a model. Blocks until finished.

        Pages of up to 100 versions are fetched one after another, with a 5 second
        pause between consecutive page requests.

        :param model_name: The name of the Google Cloud ML Engine model that the version
            belongs to. (templated)
        :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None or
            missing, the default project_id from the Google Cloud connection is used. (templated)
        :return: return an list of instance of Version.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()
        result: list[dict] = []
        full_parent_name = f"projects/{project_id}/models/{model_name}"

        versions_api = hook.projects().models().versions()
        request = versions_api.list(parent=full_parent_name, pageSize=100)

        while request is not None:
            response = request.execute(num_retries=self.num_retries)
            result.extend(response.get("versions", []))

            request = versions_api.list_next(previous_request=request, previous_response=response)
            # Pause only when another page will actually be requested
            # (presumably to stay under API rate limits - TODO confirm).
            if request is not None:
                time.sleep(5)
        return result

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_version(
        self,
        model_name: str,
        version_name: str,
        project_id: str,
    ) -> dict:
        """
        Delete the given version of a model. Blocks until finished.

        :param model_name: The name of the Google Cloud ML Engine model that the version
            belongs to. (templated)
        :param project_id: The Google Cloud project name to which MLEngine
            model belongs.
        :param version_name: A name to use for the version being operated upon. (templated)
        :return: If the version was deleted successfully, returns the operation.
            Otherwise raises an error.
        """
        hook = self.get_conn()
        full_name = f"projects/{project_id}/models/{model_name}/versions/{version_name}"
        delete_request = hook.projects().models().versions().delete(name=full_name)
        response = delete_request.execute(num_retries=self.num_retries)
        # The delete call returns an operation; poll that operation until it reports done.
        get_request = hook.projects().operations().get(name=response["name"])

        return _poll_with_exponential_delay(
            request=get_request,
            execute_num_retries=self.num_retries,
            max_n=9,
            is_done_func=lambda resp: resp.get("done", False),
            is_error_func=lambda resp: resp.get("error", None) is not None,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def create_model(
        self,
        model: dict,
        project_id: str,
    ) -> dict:
        """
        Create a Model. Blocks until finished.

        If the API rejects the request with a 409 whose single error detail says that a
        model with the same name already exists, the existing model is fetched and
        returned instead of raising.

        :param model: A dictionary containing the information about the model.
            The dict is modified in place: an ``airflow-version`` label is added to it.
        :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None or
            missing, the default project_id from the Google Cloud connection is used. (templated)
        :return: If the version was created successfully, returns the instance of Model.
            Otherwise raises an error.
        :raises ValueError: if ``model`` has no non-empty ``name``.
        :raises: googleapiclient.errors.HttpError
        """
        # Validate before authorizing so that bad input never opens a connection.
        if "name" not in model or not model["name"]:
            raise ValueError("Model name must be provided and could not be an empty string")
        hook = self.get_conn()
        project = f"projects/{project_id}"

        self._append_label(model)
        try:
            request = hook.projects().models().create(parent=project, body=model)
            response = request.execute(num_retries=self.num_retries)
        except HttpError as e:
            if e.resp.status != 409:
                raise
            str(e)  # Fills in the error_details field
            if not e.error_details or len(e.error_details) != 1:
                raise

            error_detail = e.error_details[0]
            # ``.get`` lookups keep a malformed payload from masking the HttpError with a KeyError.
            if error_detail.get("@type") != "type.googleapis.com/google.rpc.BadRequest":
                raise

            field_violations = error_detail.get("fieldViolations")
            if not field_violations or len(field_violations) != 1:
                raise

            field_violation = field_violations[0]
            if (
                field_violation.get("field") != "model.name"
                or field_violation.get("description") != "A model with the same name already exists."
            ):
                raise
            response = self.get_model(model_name=model["name"], project_id=project_id)

        return response

    @GoogleBaseHook.fallback_to_default_project_id
    def get_model(
        self,
        model_name: str,
        project_id: str,
    ) -> dict | None:
        """
        Get a Model. Blocks until finished.

        :param model_name: The name of the model.
        :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None
            or missing, the default project_id from the Google Cloud connection is used. (templated)
        :return: If the model exists, returns the instance of Model.
            Otherwise return None.
        :raises ValueError: if ``model_name`` is empty.
        :raises: googleapiclient.errors.HttpError
        """
        # Validate before authorizing so that bad input never opens a connection.
        if not model_name:
            raise ValueError("Model name must be provided and it could not be an empty string")
        hook = self.get_conn()
        full_model_name = f"projects/{project_id}/models/{model_name}"
        request = hook.projects().models().get(name=full_model_name)
        try:
            return request.execute(num_retries=self.num_retries)
        except HttpError as e:
            if e.resp.status == 404:
                self.log.error("Model was not found: %s", e)
                return None
            raise

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_model(
        self,
        model_name: str,
        project_id: str,
        delete_contents: bool = False,
    ) -> None:
        """
        Delete a Model. Blocks until finished.

        A 404 from the API is logged and otherwise ignored.

        :param model_name: The name of the model.
        :param delete_contents: Whether to force the deletion even if the models is not empty.
            Will delete all version (if any) in the dataset if set to True.
            The default value is False.
        :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None
            or missing, the default project_id from the Google Cloud connection is used. (templated)
        :raises ValueError: if ``model_name`` is empty.
        :raises: googleapiclient.errors.HttpError
        """
        # Validate before authorizing so that bad input never opens a connection.
        if not model_name:
            raise ValueError("Model name must be provided and it could not be an empty string")
        hook = self.get_conn()
        model_path = f"projects/{project_id}/models/{model_name}"
        if delete_contents:
            self._delete_all_versions(model_name, project_id)
        request = hook.projects().models().delete(name=model_path)
        try:
            request.execute(num_retries=self.num_retries)
        except HttpError as e:
            if e.resp.status == 404:
                self.log.error("Model was not found: %s", e)
                return
            raise

    def _delete_all_versions(self, model_name: str, project_id: str):
        """Delete every version of the model, non-default versions first."""
        versions = self.list_versions(project_id=project_id, model_name=model_name)
        # The default version can only be deleted when it is the last one in the model,
        # so order the (stable) sort to put default versions at the end.
        ordered = sorted(versions, key=lambda version: bool(version.get("isDefault", False)))
        for version in ordered:
            # The short version name is the last path segment of the full resource name.
            _, _, version_name = version["name"].rpartition("/")
            self.delete_version(project_id=project_id, model_name=model_name, version_name=version_name)

    def _append_label(self, model: dict) -> None:
        """Add the ``airflow-version`` label to ``model`` in place, creating ``labels`` if needed."""
        labels = model.get("labels")
        if labels is None:
            # Covers both a missing key and an explicit ``None`` value.
            labels = model["labels"] = {}
        labels["airflow-version"] = _AIRFLOW_VERSION
class MLEngineAsyncHook(GoogleBaseAsyncHook):
    """Class to get asynchronous hook for MLEngine."""

    sync_hook_class = MLEngineHook
    # OAuth scopes requested for the bearer token used by the REST calls below.
    scopes = ["https://www.googleapis.com/auth/cloud-platform"]

    def _check_fileds(
        self,
        job_id: str,
        project_id: str = PROVIDE_PROJECT_ID,
    ):
        """
        Validate that both identifiers are non-empty.

        :raises AirflowException: if ``project_id`` or ``job_id`` is empty.
        """
        if not project_id:
            raise AirflowException("Google Cloud project id is required.")
        if not job_id:
            raise AirflowException("An unique job id is required for Google MLEngine training job.")

    async def _get_link(self, url: str, session: Session):
        """
        Issue an authorized GET request for ``url`` through ``session``.

        :return: The response object of the GET call, or None when an
            ``AirflowException`` was suppressed while performing it.
        """
        async with Token(scopes=self.scopes) as token:
            session_aio = AioSession(session)
            headers = {
                "Authorization": f"Bearer {await token.get()}",
            }
        # Pre-bind the result so the suppressed-exception path returns None
        # rather than failing on an unbound local.
        job = None
        with contextlib.suppress(AirflowException):
            # suppress AirflowException because we don't want to raise exception
            job = await session_aio.get(url=url, headers=headers)

        return job

    async def get_job(self, job_id: str, session: Session, project_id: str = PROVIDE_PROJECT_ID):
        """Get the specified job resource by job ID and project ID."""
        self._check_fileds(project_id=project_id, job_id=job_id)

        url = f"https://ml.googleapis.com/v1/projects/{project_id}/jobs/{job_id}"
        return await self._get_link(url=url, session=session)

    async def get_job_status(
        self,
        job_id: str,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> str | None:
        """
        Poll for job status asynchronously using gcloud-aio.

        :return: ``"success"`` once the job state is SUCCEEDED, FAILED or CANCELLED;
            ``"pending"`` for any other state or when an ``OSError`` occurs while
            querying; otherwise the string form of the exception that was raised.
        :raises AirflowException: if ``project_id`` or ``job_id`` is empty.
        """
        self._check_fileds(project_id=project_id, job_id=job_id)
        async with ClientSession() as session:
            try:
                job = await self.get_job(
                    project_id=project_id,
                    job_id=job_id,
                    session=session,  # type: ignore
                )
                job = await job.json(content_type=None)
                self.log.info("Retrieving json_response: %s", job)

                if job["state"] in ("SUCCEEDED", "FAILED", "CANCELLED"):
                    job_status = "success"
                else:
                    # Any non-terminal state (PREPARING, RUNNING, and others such as
                    # QUEUED or CANCELLING) means the job has not finished yet.
                    job_status = "pending"
            except OSError:
                job_status = "pending"
            except Exception as e:
                self.log.info("Query execution finished with errors...")
                job_status = str(e)
        return job_status