#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import time
from dataclasses import dataclass
from pprint import pformat
from typing import TYPE_CHECKING, Any, Iterable
from uuid import UUID
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.utils import trim_none_values

if TYPE_CHECKING:
    from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient  # noqa: F401
    from mypy_boto3_redshift_data.type_defs import DescribeStatementResponseTypeDef

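# Statement states below mirror the ``Status`` values returned by the Redshift
# Data API ``DescribeStatement`` call.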
FINISHED_STATE = "FINISHED"
FAILED_STATE = "FAILED"
ABORTED_STATE = "ABORTED"
FAILURE_STATES = {FAILED_STATE, ABORTED_STATE}
RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"}

@dataclass
class QueryExecutionOutput:
    """Describes the output of a query execution."""

    statement_id: str
    session_id: str | None

class RedshiftDataQueryFailedError(ValueError):
    """Raise an error that the Redshift Data query failed."""


class RedshiftDataQueryAbortedError(ValueError):
    """Raise an error that the Redshift Data query was aborted."""

class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
"""
Interact with Amazon Redshift Data API.
Provide thin wrapper around
:external+boto3:py:class:`boto3.client("redshift-data") <RedshiftDataAPIService.Client>`.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
- :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
- `Amazon Redshift Data API \
<https://docs.aws.amazon.com/redshift-data/latest/APIReference/Welcome.html>`__
"""
    def __init__(self, *args, **kwargs) -> None:
        kwargs["client_type"] = "redshift-data"
        super().__init__(*args, **kwargs)

    def execute_query(
        self,
        sql: str | list[str],
        database: str | None = None,
        cluster_identifier: str | None = None,
        db_user: str | None = None,
        parameters: Iterable | None = None,
        secret_arn: str | None = None,
        statement_name: str | None = None,
        with_event: bool = False,
        wait_for_completion: bool = True,
        poll_interval: int = 10,
        workgroup_name: str | None = None,
        session_id: str | None = None,
        session_keep_alive_seconds: int | None = None,
    ) -> QueryExecutionOutput:
"""
Execute a statement against Amazon Redshift.
:param sql: the SQL statement or list of SQL statement to run
:param database: the name of the database
:param cluster_identifier: unique identifier of a cluster
:param db_user: the database username
:param parameters: the parameters for the SQL statement
:param secret_arn: the name or ARN of the secret that enables db access
:param statement_name: the name of the SQL statement
:param with_event: whether to send an event to EventBridge
:param wait_for_completion: whether to wait for a result
:param poll_interval: how often in seconds to check the query status
:param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
`cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
:param session_id: the session identifier of the query
:param session_keep_alive_seconds: duration in seconds to keep the session alive after the query
finishes. The maximum time a session can keep alive is 24 hours
:returns statement_id: str, the UUID of the statement
"""
        kwargs: dict[str, Any] = {
            "ClusterIdentifier": cluster_identifier,
            "Database": database,
            "DbUser": db_user,
            "Parameters": parameters,
            "WithEvent": with_event,
            "SecretArn": secret_arn,
            "StatementName": statement_name,
            "WorkgroupName": workgroup_name,
            "SessionId": session_id,
            "SessionKeepAliveSeconds": session_keep_alive_seconds,
        }

        if sum(x is not None for x in (cluster_identifier, workgroup_name, session_id)) != 1:
            raise ValueError(
                "Exactly one of cluster_identifier, workgroup_name, or session_id must be provided"
            )
        if session_id is not None:
            msg = "session_id must be a valid UUID4"
            try:
                if UUID(session_id).version != 4:
                    raise ValueError(msg)
            except ValueError:
                raise ValueError(msg)
        # The keep-alive window is bounded to 0..86400 seconds (24 hours), as the
        # error message below states.
        if session_keep_alive_seconds is not None and not (
            0 <= session_keep_alive_seconds <= 24 * 60 * 60
        ):
            raise ValueError("Session keep alive duration must be between 0 and 86400 seconds.")

        if isinstance(sql, list):
            kwargs["Sqls"] = sql
            resp = self.conn.batch_execute_statement(**trim_none_values(kwargs))
        else:
            kwargs["Sql"] = sql
            resp = self.conn.execute_statement(**trim_none_values(kwargs))
        statement_id = resp["Id"]

        if wait_for_completion:
            self.wait_for_results(statement_id, poll_interval=poll_interval)

        return QueryExecutionOutput(statement_id=statement_id, session_id=resp.get("SessionId"))
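
    # Usage sketch (hypothetical connection id, cluster, and database names; not
    # part of this module): run a single statement against a provisioned cluster
    # and block until it finishes.
    #
    #     hook = RedshiftDataHook(aws_conn_id="aws_default")
    #     output = hook.execute_query(
    #         sql="SELECT 1",
    #         database="dev",
    #         cluster_identifier="my-cluster",
    #     )
    #     print(output.statement_id, output.session_id)
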
    def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
        """Poll the statement until it finishes, raising if it failed or was aborted."""
        while True:
            self.log.info("Polling statement %s", statement_id)
            is_finished = self.check_query_is_finished(statement_id)
            if is_finished:
                return FINISHED_STATE
            time.sleep(poll_interval)

    def check_query_is_finished(self, statement_id: str) -> bool:
        """Check whether the query finished; raise an exception if it failed."""
        resp = self.conn.describe_statement(Id=statement_id)
        return self.parse_statement_response(resp)

    def parse_statement_response(self, resp: DescribeStatementResponseTypeDef) -> bool:
        """Parse the response of describe_statement."""
        status = resp["Status"]
        if status == FINISHED_STATE:
            num_rows = resp.get("ResultRows")
            if num_rows is not None:
                self.log.info("Processed %s rows", num_rows)
            return True
        elif status in FAILURE_STATES:
            exception_cls = (
                RedshiftDataQueryFailedError if status == FAILED_STATE else RedshiftDataQueryAbortedError
            )
            raise exception_cls(
                f"Statement {resp['Id']} terminated with status {status}. "
                f"Response details: {pformat(resp)}"
            )

        self.log.info("Query status: %s", status)
        return False
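
    # Example (hypothetical statement id): a caller that treats a cancelled
    # query as non-fatal could wrap the check like this:
    #
    #     try:
    #         finished = hook.check_query_is_finished(statement_id)
    #     except RedshiftDataQueryAbortedError:
    #         finished = True  # the query was aborted; stop polling
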
    def get_table_primary_key(
        self,
        table: str,
        database: str,
        schema: str | None = "public",
        cluster_identifier: str | None = None,
        workgroup_name: str | None = None,
        db_user: str | None = None,
        secret_arn: str | None = None,
        statement_name: str | None = None,
        with_event: bool = False,
        wait_for_completion: bool = True,
        poll_interval: int = 10,
    ) -> list[str] | None:
"""
Return the table primary key.
Copied from ``RedshiftSQLHook.get_table_primary_key()``
:param table: Name of the target table
:param database: the name of the database
:param schema: Name of the target schema, public by default
:param cluster_identifier: unique identifier of a cluster
:param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
`cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
:param db_user: the database username
:param secret_arn: the name or ARN of the secret that enables db access
:param statement_name: the name of the SQL statement
:param with_event: indicates whether to send an event to EventBridge
:param wait_for_completion: indicates whether to wait for a result, if True wait, if False don't wait
:param poll_interval: how often in seconds to check the query status
:return: Primary key columns list
"""
        sql = f"""
            select kcu.column_name
            from information_schema.table_constraints tco
            join information_schema.key_column_usage kcu
                on kcu.constraint_name = tco.constraint_name
                and kcu.constraint_schema = tco.constraint_schema
            where tco.constraint_type = 'PRIMARY KEY'
                and kcu.table_schema = '{schema}'
                and kcu.table_name = '{table}'
        """
        stmt_id = self.execute_query(
            sql=sql,
            database=database,
            cluster_identifier=cluster_identifier,
            workgroup_name=workgroup_name,
            db_user=db_user,
            secret_arn=secret_arn,
            statement_name=statement_name,
            with_event=with_event,
            wait_for_completion=wait_for_completion,
            poll_interval=poll_interval,
        ).statement_id

        pk_columns = []
        token = ""
        while True:
            kwargs = {"Id": stmt_id}
            if token:
                kwargs["NextToken"] = token
            response = self.conn.get_statement_result(**kwargs)
            # we only select a single column (that is a string),
            # so safe to assume that there is only a single col in the record
            pk_columns += [y["stringValue"] for x in response["Records"] for y in x]
            if "NextToken" in response:
                token = response["NextToken"]
            else:
                break

        return pk_columns or None
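
    # Usage sketch (hypothetical table, database, and workgroup names): look up
    # the primary key of a table in a Redshift Serverless workgroup.
    #
    #     pk = hook.get_table_primary_key(
    #         table="orders",
    #         database="dev",
    #         workgroup_name="my-workgroup",
    #     )
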
    async def is_still_running(self, statement_id: str) -> bool:
        """
        Async function to check whether the query is still running.

        :param statement_id: the UUID of the statement
        """
        async with self.async_conn as client:
            desc = await client.describe_statement(Id=statement_id)
            return desc["Status"] in RUNNING_STATES

    async def check_query_is_finished_async(self, statement_id: str) -> bool:
        """
        Async function to check whether a statement is finished.

        Makes an async connection to the Redshift Data API, fetches the query status
        by statement_id, and parses the response.

        :param statement_id: the UUID of the statement
        """
        async with self.async_conn as client:
            resp = await client.describe_statement(Id=statement_id)
            return self.parse_statement_response(resp)
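
# Usage sketch for the async helpers (hypothetical statement id; typically
# driven from a trigger's event loop rather than called directly):
#
#     import asyncio
#
#     async def wait_for_statement(hook: RedshiftDataHook, statement_id: str) -> bool:
#         while await hook.is_still_running(statement_id):
#             await asyncio.sleep(10)
#         return await hook.check_query_is_finished_async(statement_id)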