# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any

from botocore.exceptions import ClientError, WaiterError

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


class ClusterActiveTrigger(AwsBaseWaiterTrigger):
    """
    Polls the status of a cluster until it's active.

    :param cluster_arn: ARN of the cluster to watch.
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The number of times to ping for status.
        Will fail after that many unsuccessful attempts.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param region_name: The AWS region where the cluster is located.
    """

    def __init__(
        self,
        cluster_arn: str,
        waiter_delay: int,
        waiter_max_attempts: int,
        aws_conn_id: str | None,
        region_name: str | None = None,
        **kwargs,
    ):
        super().__init__(
            serialized_fields={"cluster_arn": cluster_arn},
            waiter_name="cluster_active",
            waiter_args={"clusters": [cluster_arn]},
            failure_message="Failure while waiting for cluster to be available",
            status_message="Cluster is not ready yet",
            status_queries=["clusters[].status", "failures"],
            return_key="arn",
            return_value=cluster_arn,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
            region_name=region_name,
            **kwargs,
        )

    def hook(self) -> AwsGenericHook:
        return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
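

# Illustrative sketch, not part of the original module: a deferrable operator can
# hand control to ClusterActiveTrigger through the standard BaseOperator.defer()
# API. The ARN, the timing values, and the "execute_complete" callback name below
# are assumptions made for the example.
def _example_defer_to_cluster_active(operator) -> None:
    operator.defer(
        trigger=ClusterActiveTrigger(
            cluster_arn="arn:aws:ecs:us-east-1:111122223333:cluster/example",
            waiter_delay=15,
            waiter_max_attempts=60,
            aws_conn_id="aws_default",
        ),
        # Airflow resumes the task by calling this method with the TriggerEvent
        # payload once the trigger fires.
        method_name="execute_complete",
    )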


class ClusterInactiveTrigger(AwsBaseWaiterTrigger):
    """
    Polls the status of a cluster until it's inactive.

    :param cluster_arn: ARN of the cluster to watch.
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The number of times to ping for status.
        Will fail after that many unsuccessful attempts.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param region_name: The AWS region where the cluster is located.
    """

    def __init__(
        self,
        cluster_arn: str,
        waiter_delay: int,
        waiter_max_attempts: int,
        aws_conn_id: str | None,
        region_name: str | None = None,
        **kwargs,
    ):
        super().__init__(
            serialized_fields={"cluster_arn": cluster_arn},
            waiter_name="cluster_inactive",
            waiter_args={"clusters": [cluster_arn]},
            failure_message="Failure while waiting for cluster to be deactivated",
            status_message="Cluster deactivation is not done yet",
            status_queries=["clusters[].status", "failures"],
            return_value=cluster_arn,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
            region_name=region_name,
            **kwargs,
        )

    def hook(self) -> AwsGenericHook:
        return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
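

# Illustrative sketch, not part of the original module: the success payload a
# deferred operator's callback receives. ClusterActiveTrigger overrides
# return_key to "arn", while ClusterInactiveTrigger keeps the base trigger's
# default key (assumed to be "value" in AwsBaseWaiterTrigger); both carry the
# cluster ARN as the value.
def _example_execute_complete(context, event: dict[str, Any]) -> str:
    if event.get("status") != "success":
        raise AirflowException(f"Trigger reported failure: {event}")
    # For ClusterActiveTrigger the ARN arrives under "arn"; use "value" (assumed
    # default) for triggers that do not override return_key.
    return event["arn"]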


class TaskDoneTrigger(BaseTrigger):
    """
    Waits for an ECS task to be done, optionally forwarding its CloudWatch logs while it runs.

    :param cluster: short name or full ARN of the cluster where the task is running.
    :param task_arn: ARN of the task to watch.
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The number of times to ping for status.
        Will fail after that many unsuccessful attempts.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param region: The AWS region where the cluster is located.
    :param log_group: The CloudWatch log group the task writes to; when set together
        with ``log_stream``, logs are forwarded to the task log between status polls.
    :param log_stream: The CloudWatch log stream to read from.
    """

    def __init__(
        self,
        cluster: str,
        task_arn: str,
        waiter_delay: int,
        waiter_max_attempts: int,
        aws_conn_id: str | None,
        region: str | None,
        log_group: str | None = None,
        log_stream: str | None = None,
    ):
        self.cluster = cluster
        self.task_arn = task_arn
        self.waiter_delay = waiter_delay
        self.waiter_max_attempts = waiter_max_attempts
        self.aws_conn_id = aws_conn_id
        self.region = region
        self.log_group = log_group
        self.log_stream = log_stream

    def serialize(self) -> tuple[str, dict[str, Any]]:
        return (
            self.__class__.__module__ + "." + self.__class__.__qualname__,
            {
                "cluster": self.cluster,
                "task_arn": self.task_arn,
                "waiter_delay": self.waiter_delay,
                "waiter_max_attempts": self.waiter_max_attempts,
                "aws_conn_id": self.aws_conn_id,
                "region": self.region,
                "log_group": self.log_group,
                "log_stream": self.log_stream,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        async with (
            EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as ecs_client,
            AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as logs_client,
        ):
            waiter = ecs_client.get_waiter("tasks_stopped")
            logs_token = None
            while self.waiter_max_attempts:
                self.waiter_max_attempts -= 1
                try:
                    # MaxAttempts=1 turns the boto waiter into a single poll, so this
                    # loop controls the retry pacing itself (see the sketch at the end
                    # of this module).
                    await waiter.wait(
                        cluster=self.cluster, tasks=[self.task_arn], WaiterConfig={"MaxAttempts": 1}
                    )
                    # We reach this point only if the waiter met a success criterion.
                    yield TriggerEvent(
                        {"status": "success", "task_arn": self.task_arn, "cluster": self.cluster}
                    )
                    return
                except WaiterError as error:
                    if "terminal failure" in str(error):
                        raise
                    self.log.info("Status of the task is %s", error.last_response["tasks"][0]["lastStatus"])
                    await asyncio.sleep(int(self.waiter_delay))
                finally:
                    # Forward any new log events between polls, whether or not the
                    # task has stopped yet.
                    if self.log_group and self.log_stream:
                        logs_token = await self._forward_logs(logs_client, logs_token)
            raise AirflowException("Waiter error: max attempts reached")

    async def _forward_logs(self, logs_client, next_token: str | None = None) -> str | None:
        """
        Read logs from the CloudWatch stream and forward them to the task logs.

        :return: the token to pass to the next iteration to resume where we stopped.
        """
        while True:
            if next_token is not None:
                token_arg: dict[str, str] = {"nextToken": next_token}
            else:
                token_arg = {}
            try:
                response = await logs_client.get_log_events(
                    logGroupName=self.log_group,
                    logStreamName=self.log_stream,
                    startFromHead=True,
                    **token_arg,
                )
            except ClientError as ce:
                if ce.response["Error"]["Code"] == "ResourceNotFoundException":
                    self.log.info(
                        "Tried to get logs from stream %s in group %s but it didn't exist (yet). "
                        "Will try again.",
                        self.log_stream,
                        self.log_group,
                    )
                    return None
                raise
            events = response["events"]
            for log_event in events:
                self.log.info(AwsTaskLogFetcher.event_to_str(log_event))
            # An empty page or an unchanged token means the end of the stream has
            # been reached for now.
            if len(events) == 0 or next_token == response["nextForwardToken"]:
                return response["nextForwardToken"]
            next_token = response["nextForwardToken"]
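

# Illustrative sketch, not part of the original module: the triggerer process
# recreates a trigger from the classpath and kwargs that serialize() returns,
# so the round trip below must yield an equivalent instance. All values are
# examples only.
def _example_serialize_round_trip() -> TaskDoneTrigger:
    trigger = TaskDoneTrigger(
        cluster="example-cluster",
        task_arn="arn:aws:ecs:us-east-1:111122223333:task/example-cluster/abc123",
        waiter_delay=10,
        waiter_max_attempts=30,
        aws_conn_id="aws_default",
        region="us-east-1",
        log_group="/ecs/example",
        log_stream="ecs/example/abc123",
    )
    classpath, kwargs = trigger.serialize()
    # Airflow imports the classpath and calls the constructor with the kwargs.
    return TaskDoneTrigger(**kwargs)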
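

# Illustrative sketch, not part of the original module: TaskDoneTrigger.run()
# passes WaiterConfig={"MaxAttempts": 1} so each wait() performs one
# DescribeTasks poll and raises WaiterError while the task is still running,
# letting the trigger own the sleep between polls instead of blocking inside
# botocore's retry loop. A distilled version of that pattern, with an assumed
# check() coroutine standing in for the single-attempt waiter call:
async def _example_single_attempt_poll(check, attempts: int, delay: int) -> None:
    for _ in range(attempts):
        try:
            await check()  # one poll; raises WaiterError while not done
            return
        except WaiterError as error:
            if "terminal failure" in str(error):
                raise
        await asyncio.sleep(delay)
    raise AirflowException("Waiter error: max attempts reached")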
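

# Illustrative sketch, not part of the original module: _forward_logs() relies
# on the GetLogEvents pagination contract; with startFromHead=True, CloudWatch
# returns the same nextForwardToken once the end of the stream is reached,
# which is why an unchanged token (or an empty page) ends the loop. A
# synchronous boto3 equivalent, assuming the group and stream exist:
def _example_tail_log_stream(logs_client, group: str, stream: str) -> None:
    token = None
    while True:
        kwargs = {"nextToken": token} if token else {}
        response = logs_client.get_log_events(
            logGroupName=group, logStreamName=stream, startFromHead=True, **kwargs
        )
        for event in response["events"]:
            print(AwsTaskLogFetcher.event_to_str(event))
        if not response["events"] or token == response["nextForwardToken"]:
            break
        token = response["nextForwardToken"]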