# Source code for airflow.providers.amazon.aws.hooks.athena_sql

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
from functools import cached_property
from typing import TYPE_CHECKING, Any

import pyathena
from sqlalchemy.engine.url import URL

from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
    from pyathena.connection import Connection as AthenaConnection


class AthenaSQLHook(AwsBaseHook, DbApiHook):
    """
    Interact with Amazon Athena.

    Provide wrapper around PyAthena library.

    :param athena_conn_id: :ref:`Amazon Athena Connection <howto/connection:athena>`.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    You can specify ``driver`` in ``extra`` of your connection in order to use
    a different driver than the default ``rest``.

    Also, ``aws_domain`` could be specified in ``extra`` of your connection.

    PyAthena and AWS Authentication parameters could be passed in extra field of
    ``athena_conn_id`` connection.

    When both connections exist, the login (AWS Access Key ID), password
    (AWS Secret Access Key) and schema of ``athena_conn_id`` replace those of
    ``aws_conn_id``. The ``extra`` fields of both connections are merged; when the
    same key appears in both, the value from ``aws_conn_id`` is kept.

    If ``aws_conn_id`` is empty or cannot be found, the ``athena_conn_id``
    connection alone is used.

    .. seealso::
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`

    .. note::
        get_uri() depends on SQLAlchemy and PyAthena.
    """

    conn_name_attr = "athena_conn_id"
    default_conn_name = "athena_default"
    conn_type = "athena"
    hook_name = "Amazon Athena"
    supports_autocommit = True

    def __init__(self, athena_conn_id: str = default_conn_name, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.athena_conn_id = athena_conn_id

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return custom UI field behaviour for AWS Athena Connection."""
        return {
            "hidden_fields": ["host", "port"],
            "relabeling": {
                "login": "AWS Access Key ID",
                "password": "AWS Secret Access Key",
            },
            "placeholders": {
                "login": "AKIAIOSFODNN7EXAMPLE",
                "password": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
                "extra": json.dumps(
                    {
                        "aws_domain": "amazonaws.com",
                        "driver": "rest",
                        "s3_staging_dir": "s3://bucket_name/staging/",
                        "work_group": "primary",
                        "region_name": "us-east-1",
                        "session_kwargs": {"profile_name": "default"},
                        "config_kwargs": {"retries": {"mode": "standard", "max_attempts": 10}},
                        "role_arn": "arn:aws:iam::123456789098:role/role-name",
                        "assume_role_method": "assume_role",
                        "assume_role_kwargs": {"RoleSessionName": "airflow"},
                        "aws_session_token": "AQoDYXdzEJr...EXAMPLETOKEN",
                        "endpoint_url": "http://localhost:4566",
                    },
                    indent=2,
                ),
            },
        }

    @cached_property
    def conn_config(self) -> AwsConnectionWrapper:
        """
        Get the Airflow Connection object and wrap it in helper (cached).

        If ``aws_conn_id`` is set and found, that connection is used with the
        Athena connection's login, password and schema copied onto it, and the
        extras of both merged. Otherwise the Athena connection itself is used,
        re-typed as ``aws``.
        """
        athena_conn = self.get_connection(self.athena_conn_id)
        connection = None

        if self.aws_conn_id:
            try:
                connection = self.get_connection(self.aws_conn_id)
            except AirflowNotFoundException:
                self.log.warning(
                    "Unable to find AWS Connection ID '%s', switching to empty.", self.aws_conn_id
                )
            else:
                connection.login = athena_conn.login
                connection.password = athena_conn.password
                connection.schema = athena_conn.schema
                # Later keys win: on a collision the AWS connection's extra is kept.
                # NOTE(review): confirm this precedence is intended for auth-related
                # extras such as role_arn / aws_session_token.
                connection.set_extra(json.dumps({**athena_conn.extra_dejson, **connection.extra_dejson}))

        if connection is None:
            # Either no aws_conn_id was given or it could not be found. Fall back to the
            # Athena connection. Previously the "no aws_conn_id" case left `connection`
            # unbound and raised UnboundLocalError here.
            connection = athena_conn
            connection.conn_type = "aws"

        return AwsConnectionWrapper(
            conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify
        )

    @property
    def conn(self) -> AwsConnectionWrapper:
        """Get Aws Connection Wrapper object."""
        return self.conn_config

    def _get_conn_params(self) -> dict[str, str | None]:
        """
        Retrieve connection parameters.

        :raises AirflowException: if no region name is configured on the connection.
        """
        if not self.conn.region_name:
            raise AirflowException("region_name must be specified in the connection's extra")

        return dict(
            driver=self.conn.extra_dejson.get("driver", "rest"),
            schema_name=self.conn.schema,
            region_name=self.conn.region_name,
            aws_domain=self.conn.extra_dejson.get("aws_domain", "amazonaws.com"),
        )

    def get_uri(self) -> str:
        """
        Overridden to use the Athena dialect as driver name.

        The URL has the form
        ``awsathena+<driver>://<access_key>:<secret_key>@athena.<region>.<aws_domain>:443/<schema>``.
        Its query holds the session token (only when the credentials carry one),
        followed by every key of the connection ``extra``.
        """
        conn_params = self._get_conn_params()
        creds = self.get_credentials(region_name=conn_params["region_name"])

        query: dict[str, Any] = {}
        # SQLAlchemy's URL.create() raises TypeError for non-string query values, and
        # static access keys have no session token, so only add it when present.
        if creds.token is not None:
            query["aws_session_token"] = creds.token
        # Extras are applied last so they can still override the session token.
        query.update(self.conn.extra_dejson)

        # NOTE(review): URL.create() returns a sqlalchemy URL object, not a str, despite
        # the annotation; create_engine() accepts either form.
        return URL.create(
            f'awsathena+{conn_params["driver"]}',
            username=creds.access_key,
            password=creds.secret_key,
            host=f'athena.{conn_params["region_name"]}.{conn_params["aws_domain"]}',
            port=443,
            database=conn_params["schema_name"],
            query=query,
        )

    def get_conn(self) -> AthenaConnection:
        """
        Get a ``pyathena.Connection`` object.

        Connection ``extra`` keys are spread last, so they take precedence over the
        computed ``schema_name``, ``region_name`` and ``session`` arguments.
        """
        conn_params = self._get_conn_params()
        conn_kwargs: dict = {
            "schema_name": conn_params["schema_name"],
            "region_name": conn_params["region_name"],
            "session": self.get_session(region_name=conn_params["region_name"]),
            **self.conn.extra_dejson,
        }
        return pyathena.connect(**conn_kwargs)

# Was this entry helpful?