# Source code for airflow.providers.snowflake.hooks.snowflake

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, Optional, Tuple

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

# pylint: disable=no-name-in-module
from snowflake import connector
from snowflake.connector import SnowflakeConnection

from airflow.hooks.dbapi import DbApiHook


class SnowflakeHook(DbApiHook):
    """
    Interact with Snowflake.

    get_sqlalchemy_engine() depends on snowflake-sqlalchemy

    Keyword arguments accepted by the constructor (``account``, ``warehouse``,
    ``database``, ``region``, ``role``, ``schema``, ``authenticator`` and
    ``session_parameters``) take precedence over the values stored on the
    Airflow connection when they are truthy.
    """

    conn_name_attr = 'snowflake_conn_id'
    default_conn_name = 'snowflake_default'
    conn_type = 'snowflake'
    hook_name = 'Snowflake'
    supports_autocommit = True

    def __init__(self, *args, **kwargs) -> None:
        """Initialise the base hook and record per-instance connection overrides."""
        super().__init__(*args, **kwargs)
        # Each override defaults to None, meaning "use the value from the connection".
        self.account = kwargs.pop("account", None)
        self.warehouse = kwargs.pop("warehouse", None)
        self.database = kwargs.pop("database", None)
        self.region = kwargs.pop("region", None)
        self.role = kwargs.pop("role", None)
        self.schema = kwargs.pop("schema", None)
        self.authenticator = kwargs.pop("authenticator", None)
        self.session_parameters = kwargs.pop("session_parameters", None)

    def _get_conn_params(self) -> Dict[str, Optional[str]]:
        """
        One method to fetch connection params as a dict
        used in get_uri() and get_connection()

        Values come from the Airflow connection (login, password, schema and
        the ``extra`` JSON); a truthy hook-level override wins over the
        connection value.

        :return: connection parameters. When ``private_key_file`` is set in the
            connection extras, the dict carries ``private_key`` (DER-encoded
            PKCS8 bytes) and no ``password`` entry.
        """
        conn = self.get_connection(
            self.snowflake_conn_id  # type: ignore[attr-defined]  # pylint: disable=no-member
        )
        extras = conn.extra_dejson
        account = extras.get('account', '')
        warehouse = extras.get('warehouse', '')
        database = extras.get('database', '')
        region = extras.get("region", '')
        role = extras.get('role', '')
        schema = conn.schema or ''
        authenticator = extras.get('authenticator', 'snowflake')
        session_parameters = extras.get('session_parameters')

        conn_config = {
            "user": conn.login,
            "password": conn.password or '',
            "schema": self.schema or schema,
            "database": self.database or database,
            "account": self.account or account,
            "warehouse": self.warehouse or warehouse,
            "region": self.region or region,
            "role": self.role or role,
            "authenticator": self.authenticator or authenticator,
            "session_parameters": self.session_parameters or session_parameters,
        }

        # If private_key_file is specified in the extra json, load the contents of the file as a private
        # key and specify that in the connection configuration. The connection password then becomes the
        # passphrase for the private key. If your private key file is not encrypted (not recommended), then
        # leave the password empty.
        private_key_file = extras.get('private_key_file')
        if private_key_file:
            with open(private_key_file, "rb") as key:
                passphrase = None
                if conn.password:
                    passphrase = conn.password.strip().encode()

                p_key = serialization.load_pem_private_key(
                    key.read(), password=passphrase, backend=default_backend()
                )

            pkb = p_key.private_bytes(
                encoding=serialization.Encoding.DER,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            )

            conn_config['private_key'] = pkb
            # The password was only the key passphrase; do not send it as a credential.
            conn_config.pop('password', None)

        return conn_config

    def get_uri(self) -> str:
        """
        Override DbApiHook get_uri method for get_sqlalchemy_engine()

        :return: a ``snowflake://`` URI built from the connection parameters.
        """
        conn_config = self._get_conn_params()
        # _get_conn_params() removes "password" when key-pair authentication is
        # configured; fall back to an empty password so formatting the template
        # does not fail with KeyError.
        conn_config.setdefault('password', '')
        uri = (
            'snowflake://{user}:{password}@{account}/{database}/{schema}'
            '?warehouse={warehouse}&role={role}&authenticator={authenticator}'
        )
        return uri.format(**conn_config)

    def get_conn(self) -> SnowflakeConnection:
        """Returns a snowflake.connection object"""
        conn_config = self._get_conn_params()
        conn = connector.connect(**conn_config)
        return conn

    def _get_aws_credentials(self) -> Tuple[Optional[Any], Optional[Any]]:
        """
        Returns aws_access_key_id, aws_secret_access_key
        from extra

        intended to be used by external import and export statements

        :return: ``(aws_access_key_id, aws_secret_access_key)``; both are
            ``None`` when no connection id is set or the connection extras do
            not contain ``aws_secret_access_key``.
        """
        # Initialise up front so the return below never hits unbound locals.
        aws_access_key_id = None
        aws_secret_access_key = None
        if self.snowflake_conn_id:  # type: ignore[attr-defined]  # pylint: disable=no-member
            connection_object = self.get_connection(
                self.snowflake_conn_id  # type: ignore[attr-defined]  # pylint: disable=no-member
            )
            extras = connection_object.extra_dejson
            if 'aws_secret_access_key' in extras:
                aws_access_key_id = extras.get('aws_access_key_id')
                aws_secret_access_key = extras.get('aws_secret_access_key')
        return aws_access_key_id, aws_secret_access_key

    def set_autocommit(self, conn, autocommit: Any) -> None:
        """
        Apply the autocommit setting to ``conn``.

        The value is also stored on ``conn.autocommit_mode`` so that
        get_autocommit() can read it back.
        """
        conn.autocommit(autocommit)
        conn.autocommit_mode = autocommit

    def get_autocommit(self, conn):
        """Return the value stored by set_autocommit(), or False if none was stored."""
        return getattr(conn, 'autocommit_mode', False)

# Was this entry helpful?