# Source code for airflow.providers.amazon.aws.hooks.athena

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains AWS Athena hook.

.. spelling::

    PageIterator
"""
from __future__ import annotations

import warnings
from time import sleep
from typing import Any

from botocore.paginate import PageIterator

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook


class AthenaHook(AwsBaseHook):
    """
    Interact with AWS Athena to run, poll queries and return query results.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`

    :param sleep_time: Time (in seconds) to wait between two consecutive calls
        to check query status on Athena.
    :param log_query: Whether to log athena query and other execution params
        when it's executed. Defaults to *True*.
    """

    # Query states Athena reports while a query has not finished yet.
    INTERMEDIATE_STATES = (
        "QUEUED",
        "RUNNING",
    )
    # Terminal states in which no results can be fetched.
    FAILURE_STATES = (
        "FAILED",
        "CANCELLED",
    )
    SUCCESS_STATES = ("SUCCEEDED",)
    # Any of these ends polling in ``poll_query_status``.
    TERMINAL_STATES = (
        "SUCCEEDED",
        "FAILED",
        "CANCELLED",
    )

    def __init__(self, *args: Any, sleep_time: int = 30, log_query: bool = True, **kwargs: Any) -> None:
        super().__init__(client_type="athena", *args, **kwargs)  # type: ignore
        self.sleep_time = sleep_time
        self.log_query = log_query

    def run_query(
        self,
        query: str,
        query_context: dict[str, str],
        result_configuration: dict[str, Any],
        client_request_token: str | None = None,
        workgroup: str = "primary",
    ) -> str:
        """
        Run Presto query on athena with provided config and return submitted query_execution_id.

        :param query: Presto query to run
        :param query_context: Context in which query need to be run
        :param result_configuration: Dict with path to store results in and
            config related to encryption
        :param client_request_token: Unique token created by user to avoid
            multiple executions of same query
        :param workgroup: Athena workgroup name, when not specified, will be 'primary'
        :return: the ``QueryExecutionId`` returned by ``start_query_execution``.
        """
        params: dict[str, Any] = {
            "QueryString": query,
            "QueryExecutionContext": query_context,
            "ResultConfiguration": result_configuration,
            "WorkGroup": workgroup,
        }
        if client_request_token:
            params["ClientRequestToken"] = client_request_token
        if self.log_query:
            self.log.info("Running Query with params: %s", params)
        response = self.get_conn().start_query_execution(**params)
        query_execution_id = response["QueryExecutionId"]
        self.log.info("Query execution id: %s", query_execution_id)
        return query_execution_id

    def check_query_status(self, query_execution_id: str) -> str | None:
        """
        Fetch the status of submitted athena query.

        :param query_execution_id: Id of submitted athena query
        :return: one of the valid query states, or None if the state could
            not be read from the response.
        """
        response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
        state = None
        try:
            state = response["QueryExecution"]["Status"]["State"]
        except Exception:
            self.log.exception(
                "Exception while getting query state. Query execution id: %s", query_execution_id
            )
        # The error is absorbed here (logged, None returned) and handled by the
        # caller, which allows retries. The return is deliberately NOT inside a
        # ``finally`` clause so that KeyboardInterrupt/SystemExit still propagate.
        return state

    def get_state_change_reason(self, query_execution_id: str) -> str | None:
        """
        Fetch the reason for a state change (e.g. error message).

        :param query_execution_id: Id of submitted athena query
        :return: the reason string, or None if it could not be read from the response.
        """
        response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
        reason = None
        try:
            reason = response["QueryExecution"]["Status"]["StateChangeReason"]
        except Exception:
            self.log.exception(
                "Exception while getting query state change reason. Query execution id: %s",
                query_execution_id,
            )
        # Absorbed error: see check_query_status. Returning outside ``finally``
        # keeps BaseException subclasses from being silently swallowed.
        return reason

    def get_query_results(
        self, query_execution_id: str, next_token_id: str | None = None, max_results: int = 1000
    ) -> dict | None:
        """
        Fetch submitted athena query results.

        Returns None if the query is in an intermediate state or a
        failed/cancelled state, else the dict of query output.

        :param query_execution_id: Id of submitted athena query
        :param next_token_id: The token that specifies where to start pagination.
        :param max_results: The maximum number of results (rows) to return in this request.
        :return: dict
        """
        query_state = self.check_query_status(query_execution_id)
        if query_state is None:
            self.log.error("Invalid Query state. Query execution id: %s", query_execution_id)
            return None
        if query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
            self.log.error(
                'Query is in "%s" state. Cannot fetch results. Query execution id: %s',
                query_state,
                query_execution_id,
            )
            return None
        result_params: dict[str, Any] = {"QueryExecutionId": query_execution_id, "MaxResults": max_results}
        if next_token_id:
            result_params["NextToken"] = next_token_id
        return self.get_conn().get_query_results(**result_params)

    def get_query_results_paginator(
        self,
        query_execution_id: str,
        max_items: int | None = None,
        page_size: int | None = None,
        starting_token: str | None = None,
    ) -> PageIterator | None:
        """
        Fetch submitted athena query results as a paginator.

        Returns None if the query is in an intermediate state or a
        failed/cancelled state, else a paginator to iterate through pages of
        results. If you wish to get all results at once, call
        build_full_result() on the returned PageIterator.

        :param query_execution_id: Id of submitted athena query
        :param max_items: The total number of items to return.
        :param page_size: The size of each page.
        :param starting_token: A token to specify where to start paginating.
        :return: PageIterator
        """
        query_state = self.check_query_status(query_execution_id)
        if query_state is None:
            self.log.error("Invalid Query state (null). Query execution id: %s", query_execution_id)
            return None
        if query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
            self.log.error(
                'Query is in "%s" state. Cannot fetch results, Query execution id: %s',
                query_state,
                query_execution_id,
            )
            return None
        result_params = {
            "QueryExecutionId": query_execution_id,
            "PaginationConfig": {
                "MaxItems": max_items,
                "PageSize": page_size,
                "StartingToken": starting_token,
            },
        }
        paginator = self.get_conn().get_paginator("get_query_results")
        return paginator.paginate(**result_params)

    def poll_query_status(
        self,
        query_execution_id: str,
        max_tries: int | None = None,
        max_polling_attempts: int | None = None,
    ) -> str | None:
        """
        Poll the status of submitted athena query until query state reaches final state.

        :param query_execution_id: Id of submitted athena query
        :param max_tries: Deprecated - Use max_polling_attempts instead
        :param max_polling_attempts: Number of times to poll for query state
            before function exits
        :return: the terminal state, or the last observed state (possibly None)
            if max_polling_attempts was reached first.
        :raises Exception: if both max_tries and max_polling_attempts are given
            with different values.
        """
        if max_tries:
            warnings.warn(
                f"Passing 'max_tries' to {self.__class__.__name__}.poll_query_status is deprecated "
                f"and will be removed in a future release. Please use 'max_polling_attempts' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            if max_polling_attempts and max_polling_attempts != max_tries:
                raise Exception("max_polling_attempts must be the same value as max_tries")
            else:
                max_polling_attempts = max_tries

        try_number = 1
        # Query state when query reaches final state or max_polling_attempts reached.
        final_query_state = None
        while True:
            query_state = self.check_query_status(query_execution_id)
            if query_state is None:
                self.log.info(
                    "Query execution id: %s, trial %s: Invalid query state. Retrying again",
                    query_execution_id,
                    try_number,
                )
            elif query_state in self.TERMINAL_STATES:
                self.log.info(
                    "Query execution id: %s, trial %s: Query execution completed. Final state is %s",
                    query_execution_id,
                    try_number,
                    query_state,
                )
                final_query_state = query_state
                break
            else:
                self.log.info(
                    "Query execution id: %s, trial %s: Query is still in non-terminal state - %s",
                    query_execution_id,
                    try_number,
                    query_state,
                )
            # Stop early (without sleeping) once the attempt budget is exhausted.
            if max_polling_attempts and try_number >= max_polling_attempts:
                final_query_state = query_state
                break
            try_number += 1
            sleep(self.sleep_time)
        return final_query_state

    def get_output_location(self, query_execution_id: str) -> str:
        """
        Get the output location of the query results in s3 uri format.

        :param query_execution_id: Id of submitted athena query
        :return: the ``OutputLocation`` of the query's result configuration.
        :raises ValueError: if query_execution_id is empty or the
            get_query_execution response is empty.
        :raises KeyError: if the response has no OutputLocation.
        """
        if not query_execution_id:
            raise ValueError(f"Invalid Query execution id. Query execution id: {query_execution_id}")

        response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
        if not response:
            # Previously a bare ``raise`` with no active exception (RuntimeError).
            raise ValueError(f"Unable to get query execution. Query execution id: {query_execution_id}")

        try:
            return response["QueryExecution"]["ResultConfiguration"]["OutputLocation"]
        except KeyError:
            self.log.error("Error retrieving OutputLocation. Query execution id: %s", query_execution_id)
            raise

    def stop_query(self, query_execution_id: str) -> dict:
        """
        Cancel the submitted athena query.

        :param query_execution_id: Id of submitted athena query
        :return: the raw ``stop_query_execution`` response dict.
        """
        self.log.info("Stopping Query with executionId - %s", query_execution_id)
        return self.get_conn().stop_query_execution(QueryExecutionId=query_execution_id)

# Was this entry helpful?