Source code for airflow.providers.amazon.aws.hooks.datasync

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Interact with AWS DataSync, using the AWS ``boto3`` library."""
from __future__ import annotations

import time
from urllib.parse import urlsplit

from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowTaskTimeout
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

class DataSyncHook(AwsBaseHook):
    """
    Interact with AWS DataSync.

    Provide thick wrapper around
    :external+boto3:py:class:`boto3.client("datasync") <DataSync.Client>`.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    The hook keeps two in-memory caches, ``self.locations`` and ``self.tasks``,
    holding the results of the ``list_locations`` / ``list_tasks`` API calls.
    They are filled lazily and refreshed after ``create_location`` and
    ``create_task``.

    :param wait_interval_seconds: Time to wait between two
        consecutive calls to check TaskExecution status. Defaults to 30 seconds.
    :raises ValueError: If wait_interval_seconds is not between 0 and 15*60 seconds.

    .. seealso::
        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    """

    # NOTE(review): ``wait_for_task_execution`` reads these three attributes but
    # the visible source did not define them. The values below follow the
    # ``Status`` values of DataSync's DescribeTaskExecution API - confirm that
    # no base class is expected to provide them instead.
    TASK_EXECUTION_INTERMEDIATE_STATES = (
        "INITIALIZING",
        "QUEUED",
        "LAUNCHING",
        "PREPARING",
        "TRANSFERRING",
        "VERIFYING",
    )
    TASK_EXECUTION_FAILURE_STATES = ("ERROR",)
    TASK_EXECUTION_SUCCESS_STATES = ("SUCCESS",)

    # URI schemes for which a ``create_location_<scheme>`` client method is used.
    _SUPPORTED_LOCATION_SCHEMES = ("smb", "s3", "nfs", "efs")

    def __init__(self, wait_interval_seconds: int = 30, *args, **kwargs) -> None:
        super().__init__(*args, client_type="datasync", **kwargs)  # type: ignore[misc]
        self.locations: list = []
        self.tasks: list = []
        # wait_interval_seconds = 0 is used during unit tests
        if not 0 <= wait_interval_seconds <= 15 * 60:
            raise ValueError(f"Invalid wait_interval_seconds {wait_interval_seconds}")
        self.wait_interval_seconds = wait_interval_seconds

    @staticmethod
    def _collect_pages(list_operation, result_key: str) -> list:
        """Call a paginated ``list_*`` operation until no ``NextToken`` is returned.

        :param list_operation: Bound client method, e.g. ``client.list_tasks``.
        :param result_key: Key of the response holding the page items.
        :return: Items of all pages, in the order returned.
        """
        response = list_operation()
        items = list(response[result_key])
        while "NextToken" in response:
            response = list_operation(NextToken=response["NextToken"])
            items.extend(response[result_key])
        return items

    @staticmethod
    def _normalize_location_uri(uri: str, case_sensitive: bool, ignore_trailing_slash: bool) -> str:
        """Apply the case / trailing-slash matching options to a location URI."""
        if not case_sensitive:
            uri = uri.lower()
        # Only a single trailing slash is removed.
        if ignore_trailing_slash and uri.endswith("/"):
            uri = uri[:-1]
        return uri

    def create_location(self, location_uri: str, **create_location_kwargs) -> str:
        """
        Create a new location.

        .. seealso::
            - :external+boto3:py:meth:`DataSync.Client.create_location_s3`
            - :external+boto3:py:meth:`DataSync.Client.create_location_smb`
            - :external+boto3:py:meth:`DataSync.Client.create_location_nfs`
            - :external+boto3:py:meth:`DataSync.Client.create_location_efs`

        :param location_uri: Location URI used to determine the location type (S3, SMB, NFS, EFS).
        :param create_location_kwargs: Passed to ``DataSync.Client.create_location_*`` methods.
        :return: LocationArn of the created Location.
        :raises AirflowException: If location type (prefix from ``location_uri``) is invalid.
        """
        scheme = urlsplit(location_uri).scheme
        # Reject unsupported schemes before touching the client.
        if scheme not in self._SUPPORTED_LOCATION_SCHEMES:
            raise AirflowException(f"Invalid/Unsupported location type: {scheme}")
        create = getattr(self.get_conn(), f"create_location_{scheme}")
        location = create(**create_location_kwargs)
        self._refresh_locations()
        return location["LocationArn"]

    def get_location_arns(
        self, location_uri: str, case_sensitive: bool = False, ignore_trailing_slash: bool = True
    ) -> list[str]:
        """
        Return all LocationArns which match a LocationUri.

        :param location_uri: Location URI to search for, eg ``s3://mybucket/mypath``
        :param case_sensitive: Do a case sensitive search for location URI.
        :param ignore_trailing_slash: Ignore / at the end of URI when matching.
        :return: List of LocationArns.
        :raises AirflowBadRequest: if ``location_uri`` is empty
        """
        if not location_uri:
            raise AirflowBadRequest("location_uri not specified")
        if not self.locations:
            self._refresh_locations()

        wanted = self._normalize_location_uri(location_uri, case_sensitive, ignore_trailing_slash)
        return [
            location["LocationArn"]
            for location in self.locations
            if self._normalize_location_uri(location["LocationUri"], case_sensitive, ignore_trailing_slash)
            == wanted
        ]

    def _refresh_locations(self) -> None:
        """Refresh the local list of Locations."""
        self.locations = self._collect_pages(self.get_conn().list_locations, "Locations")

    def create_task(
        self, source_location_arn: str, destination_location_arn: str, **create_task_kwargs
    ) -> str:
        """Create a Task between the specified source and destination LocationArns.

        .. seealso::
            - :external+boto3:py:meth:`DataSync.Client.create_task`

        :param source_location_arn: Source LocationArn. Must exist already.
        :param destination_location_arn: Destination LocationArn. Must exist already.
        :param create_task_kwargs: Passed to ``boto.create_task()``. See AWS boto3 datasync documentation.
        :return: TaskArn of the created Task
        """
        task = self.get_conn().create_task(
            SourceLocationArn=source_location_arn,
            DestinationLocationArn=destination_location_arn,
            **create_task_kwargs,
        )
        self._refresh_tasks()
        return task["TaskArn"]

    def update_task(self, task_arn: str, **update_task_kwargs) -> None:
        """Update a Task.

        .. seealso::
            - :external+boto3:py:meth:`DataSync.Client.update_task`

        :param task_arn: The TaskArn to update.
        :param update_task_kwargs: Passed to ``boto.update_task()``, See AWS boto3 datasync documentation.
        """
        self.get_conn().update_task(TaskArn=task_arn, **update_task_kwargs)

    def delete_task(self, task_arn: str) -> None:
        """Delete a Task.

        The cached ``self.tasks`` list is not refreshed by this method.

        .. seealso::
            - :external+boto3:py:meth:`DataSync.Client.delete_task`

        :param task_arn: The TaskArn to delete.
        """
        self.get_conn().delete_task(TaskArn=task_arn)

    def _refresh_tasks(self) -> None:
        """Refresh the local list of Tasks."""
        self.tasks = self._collect_pages(self.get_conn().list_tasks, "Tasks")

    def get_task_arns_for_location_arns(
        self,
        source_location_arns: list,
        destination_location_arns: list,
    ) -> list:
        """
        Return list of TaskArns which use both a specified source and destination LocationArns.

        Every cached task is described individually (one ``describe_task`` call
        per task) to read its source and destination LocationArns.

        :param source_location_arns: List of source LocationArns.
        :param destination_location_arns: List of destination LocationArns.
        :return: TaskArns whose source is in ``source_location_arns`` and whose
            destination is in ``destination_location_arns``.
        :raises AirflowBadRequest: if ``source_location_arns`` or ``destination_location_arns`` are empty.
        """
        if not source_location_arns:
            raise AirflowBadRequest("source_location_arns not specified")
        if not destination_location_arns:
            raise AirflowBadRequest("destination_location_arns not specified")
        if not self.tasks:
            self._refresh_tasks()

        result = []
        for task in self.tasks:
            task_arn = task["TaskArn"]
            task_description = self.get_task_description(task_arn)
            if (
                task_description["SourceLocationArn"] in source_location_arns
                and task_description["DestinationLocationArn"] in destination_location_arns
            ):
                result.append(task_arn)
        return result

    def start_task_execution(self, task_arn: str, **kwargs) -> str:
        """
        Start a TaskExecution for the specified task_arn.

        Each task can have at most one TaskExecution.
        Additional keyword arguments send to ``start_task_execution`` boto3 method.

        .. seealso::
            - :external+boto3:py:meth:`DataSync.Client.start_task_execution`

        :param task_arn: TaskArn
        :return: TaskExecutionArn
        :raises ClientError: If a TaskExecution is already busy running for this ``task_arn``.
        :raises AirflowBadRequest: If ``task_arn`` is empty.
        """
        if not task_arn:
            raise AirflowBadRequest("task_arn not specified")
        task_execution = self.get_conn().start_task_execution(TaskArn=task_arn, **kwargs)
        return task_execution["TaskExecutionArn"]

    def cancel_task_execution(self, task_execution_arn: str) -> None:
        """
        Cancel a TaskExecution for the specified ``task_execution_arn``.

        .. seealso::
            - :external+boto3:py:meth:`DataSync.Client.cancel_task_execution`

        :param task_execution_arn: TaskExecutionArn.
        :raises AirflowBadRequest: If ``task_execution_arn`` is empty.
        """
        if not task_execution_arn:
            raise AirflowBadRequest("task_execution_arn not specified")
        self.get_conn().cancel_task_execution(TaskExecutionArn=task_execution_arn)

    def get_task_description(self, task_arn: str) -> dict:
        """
        Get description for the specified ``task_arn``.

        .. seealso::
            - :external+boto3:py:meth:`DataSync.Client.describe_task`

        :param task_arn: TaskArn
        :return: AWS metadata about a task.
        :raises AirflowBadRequest: If ``task_arn`` is empty.
        """
        if not task_arn:
            raise AirflowBadRequest("task_arn not specified")
        return self.get_conn().describe_task(TaskArn=task_arn)

    def describe_task_execution(self, task_execution_arn: str) -> dict:
        """
        Get description for the specified ``task_execution_arn``.

        .. seealso::
            - :external+boto3:py:meth:`DataSync.Client.describe_task_execution`

        :param task_execution_arn: TaskExecutionArn
        :return: AWS metadata about a task execution.
        :raises AirflowBadRequest: If ``task_execution_arn`` is empty.
        """
        # Enforce the documented contract, matching the sibling methods.
        if not task_execution_arn:
            raise AirflowBadRequest("task_execution_arn not specified")
        return self.get_conn().describe_task_execution(TaskExecutionArn=task_execution_arn)

    def get_current_task_execution_arn(self, task_arn: str) -> str | None:
        """
        Get current TaskExecutionArn (if one exists) for the specified ``task_arn``.

        :param task_arn: TaskArn
        :return: CurrentTaskExecutionArn for this ``task_arn`` or None.
        :raises AirflowBadRequest: if ``task_arn`` is empty.
        """
        if not task_arn:
            raise AirflowBadRequest("task_arn not specified")
        return self.get_task_description(task_arn).get("CurrentTaskExecutionArn")

    def wait_for_task_execution(self, task_execution_arn: str, max_iterations: int = 60) -> bool:
        """
        Wait for Task Execution status to be complete (SUCCESS/ERROR).

        The ``task_execution_arn`` must exist, or a boto3 ClientError will be raised.
        The status is polled at most ``max_iterations`` times, sleeping
        ``wait_interval_seconds`` once between two consecutive polls.

        :param task_execution_arn: TaskExecutionArn
        :param max_iterations: Maximum number of iterations before timing out.
        :return: Result of task execution.
        :raises AirflowTaskTimeout: If maximum iterations is exceeded.
        :raises AirflowBadRequest: If ``task_execution_arn`` is empty.
        :raises AirflowException: If the reported status is not a known state.
        """
        if not task_execution_arn:
            raise AirflowBadRequest("task_execution_arn not specified")

        for _ in range(max_iterations):
            task_execution = self.get_conn().describe_task_execution(TaskExecutionArn=task_execution_arn)
            status = task_execution["Status"]
            # NOTE(review): ``self.log`` is presumably provided by the base hook - confirm.
            self.log.info("status=%s", status)
            if status in self.TASK_EXECUTION_SUCCESS_STATES:
                return True
            if status in self.TASK_EXECUTION_FAILURE_STATES:
                return False
            if status is not None and status not in self.TASK_EXECUTION_INTERMEDIATE_STATES:
                raise AirflowException(f"Unknown status: {status}")  # Should never happen
            # Still in progress: sleep exactly once per poll. The previous code
            # slept both inside the intermediate-state branch and again here,
            # doubling the effective wait interval.
            time.sleep(self.wait_interval_seconds)
        raise AirflowTaskTimeout("Max iterations exceeded!")

Was this entry helpful?