Source code for airflow.providers.oracle.hooks.oracle

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import math
import warnings
from datetime import datetime

import oracledb

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.common.sql.hooks.sql import DbApiHook

[docs]PARAM_TYPES = {bool, float, int, str}
def _map_param(value): if value in PARAM_TYPES: # In this branch, value is a Python type; calling it produces # an instance of the type which is understood by the Oracle driver # in the out parameter mapping mechanism. value = value() return value def _get_bool(val): if isinstance(val, bool): return val if isinstance(val, str): val = val.lower().strip() if val == "true": return True if val == "false": return False return None def _get_first_bool(*vals): for val in vals: converted = _get_bool(val) if isinstance(converted, bool): return converted return None
class OracleHook(DbApiHook):
    """
    Interact with Oracle SQL.

    :param oracle_conn_id: The :ref:`Oracle connection id <howto/connection:oracle>`
        used for Oracle credentials.
    :param thick_mode: Specify whether to use python-oracledb in thick mode. Defaults to False.
        If set to True, you must have the Oracle Client libraries installed.
        See `oracledb docs
        <https://python-oracledb.readthedocs.io/en/latest/user_guide/initialization.html>`_
        for more info.
    :param thick_mode_lib_dir: Path to use to find the Oracle Client libraries when using thick mode.
        If not specified, defaults to the standard way of locating the Oracle Client library on the OS.
        See `oracledb docs
        <https://python-oracledb.readthedocs.io/en/latest/user_guide/initialization.html#setting-the-oracle-client-library-directory>`_
        for more info.
    :param thick_mode_config_dir: Path to use to find the Oracle Client library
        configuration files when using thick mode.
        If not specified, defaults to the standard way of locating the Oracle Client
        library configuration files on the OS.
        See `oracledb docs
        <https://python-oracledb.readthedocs.io/en/latest/user_guide/initialization.html#optional-oracle-net-configuration-files>`_
        for more info.
    :param fetch_decimals: Specify whether numbers should be fetched as ``decimal.Decimal`` values.
        See `defaults.fetch_decimals
        <https://python-oracledb.readthedocs.io/en/latest/api_manual/defaults.html#defaults.fetch_decimals>`_
        for more info.
    :param fetch_lobs: Specify whether to fetch strings/bytes for CLOBs or BLOBs instead of locators.
        See `defaults.fetch_lobs
        <https://python-oracledb.readthedocs.io/en/latest/api_manual/defaults.html#defaults.fetch_lobs>`_
        for more info.
    """

    conn_name_attr = "oracle_conn_id"
    default_conn_name = "oracle_default"
    conn_type = "oracle"
    hook_name = "Oracle"

    _test_connection_sql = "select 1 from dual"
    supports_autocommit = True

    def __init__(
        self,
        *args,
        thick_mode: bool | None = None,
        thick_mode_lib_dir: str | None = None,
        thick_mode_config_dir: str | None = None,
        fetch_decimals: bool | None = None,
        fetch_lobs: bool | None = None,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        self.thick_mode = thick_mode
        self.thick_mode_lib_dir = thick_mode_lib_dir
        self.thick_mode_config_dir = thick_mode_config_dir
        self.fetch_decimals = fetch_decimals
        self.fetch_lobs = fetch_lobs

    def get_conn(self) -> oracledb.Connection:
        """
        Get an Oracle connection object.

        Optional parameters for using a custom DSN connection
        (instead of using a server alias from tnsnames.ora)
        The dsn (data source name) is the TNS entry
        (from the Oracle names server or tnsnames.ora file),
        or is a string like the one returned from ``makedsn()``.

        :param dsn: the data source name for the Oracle server
        :param service_name: the db_unique_name of the database
              that you are connecting to (CONNECT_DATA part of TNS)
        :param sid: Oracle System ID that identifies a particular
              database on a system

        You can set these parameters in the extra fields of your connection
        as in

        .. code-block:: python

           {"dsn": ("(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=host)(PORT=1521))(CONNECT_DATA=(SID=sid)))")}

        see more param detail in `oracledb.connect
        <https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#oracledb.connect>`_

        :return: an open ``oracledb`` connection.
        :raises TypeError: if a thick-mode directory read from the connection extra is
            neither a string nor ``None``.
        """
        conn = self.get_connection(self.oracle_conn_id)  # type: ignore[attr-defined]
        extras = conn.extra_dejson
        conn_config = {"user": conn.login, "password": conn.password}
        sid = extras.get("sid")
        mod = extras.get("module")
        schema = conn.schema

        # Enable oracledb thick mode if thick_mode is set to True
        # Parameters take precedence over connection config extra
        # Defaults to use thin mode if not provided in params or connection config extra
        thick_mode = _get_first_bool(self.thick_mode, extras.get("thick_mode"))
        if thick_mode is True:
            if self.thick_mode_lib_dir is None:
                self.thick_mode_lib_dir = extras.get("thick_mode_lib_dir")
                if not isinstance(self.thick_mode_lib_dir, (str, type(None))):
                    raise TypeError(
                        f"thick_mode_lib_dir expected str or None, "
                        f"got {type(self.thick_mode_lib_dir).__name__}"
                    )
            if self.thick_mode_config_dir is None:
                self.thick_mode_config_dir = extras.get("thick_mode_config_dir")
                if not isinstance(self.thick_mode_config_dir, (str, type(None))):
                    raise TypeError(
                        f"thick_mode_config_dir expected str or None, "
                        f"got {type(self.thick_mode_config_dir).__name__}"
                    )
            oracledb.init_oracle_client(
                lib_dir=self.thick_mode_lib_dir, config_dir=self.thick_mode_config_dir
            )

        # Set oracledb Defaults Attributes if provided
        # (https://python-oracledb.readthedocs.io/en/latest/api_manual/defaults.html)
        fetch_decimals = _get_first_bool(self.fetch_decimals, extras.get("fetch_decimals"))
        if isinstance(fetch_decimals, bool):
            oracledb.defaults.fetch_decimals = fetch_decimals

        fetch_lobs = _get_first_bool(self.fetch_lobs, extras.get("fetch_lobs"))
        if isinstance(fetch_lobs, bool):
            oracledb.defaults.fetch_lobs = fetch_lobs

        # Set up DSN
        service_name = extras.get("service_name")
        port = conn.port if conn.port else 1521
        if conn.host and sid and not service_name:
            conn_config["dsn"] = oracledb.makedsn(conn.host, port, sid)
        elif conn.host and service_name and not sid:
            conn_config["dsn"] = oracledb.makedsn(conn.host, port, service_name=service_name)
        else:
            dsn = extras.get("dsn")
            if dsn is None:
                # Fall back to an "host[:port][/service]" style connect string.
                dsn = conn.host
                if conn.port is not None:
                    dsn += f":{conn.port}"
                if service_name:
                    dsn += f"/{service_name}"
                elif conn.schema:
                    warnings.warn(
                        """Using conn.schema to pass the Oracle Service Name is deprecated.
                        Please use conn.extra.service_name instead.""",
                        AirflowProviderDeprecationWarning,
                        stacklevel=2,
                    )
                    dsn += f"/{conn.schema}"
            conn_config["dsn"] = dsn

        if "events" in extras:
            conn_config["events"] = extras.get("events")

        # Translate the textual "mode"/"purity" extras into the driver constants.
        # The constant is looked up only for the selected value, so an unknown or
        # absent setting never touches the driver module.
        mode = extras.get("mode", "").lower()
        if mode in ("sysdba", "sysasm", "sysoper", "sysbkp", "sysdgd", "syskmt", "sysrac"):
            conn_config["mode"] = getattr(oracledb, f"AUTH_MODE_{mode.upper()}")

        purity = extras.get("purity", "").lower()
        if purity in ("new", "self", "default"):
            conn_config["purity"] = getattr(oracledb, f"PURITY_{purity.upper()}")

        oracle_conn = oracledb.connect(**conn_config)
        if mod is not None:
            oracle_conn.module = mod

        # if Connection.schema is defined, set schema after connecting successfully
        # cannot be part of conn_config
        # https://python-oracledb.readthedocs.io/en/latest/api_manual/connection.html?highlight=schema#Connection.current_schema
        # Only set schema when not using conn.schema as Service Name
        if schema and service_name:
            oracle_conn.current_schema = schema

        return oracle_conn

    @staticmethod
    def _to_sql_literal(cell, np=None) -> str:
        """
        Render one cell as the SQL literal text used by ``insert_rows``.

        :param cell: the value to render.
        :param np: the imported ``numpy`` module, or ``None`` when numpy is unavailable.
        :return: the literal as it is spliced into the INSERT statement.
        """
        if isinstance(cell, str):
            # Single quotes are escaped by doubling them.
            return "'" + str(cell).replace("'", "''") + "'"
        if cell is None or (isinstance(cell, float) and math.isnan(cell)):
            # coerce numpy NaN to NULL
            return "NULL"
        if np and isinstance(cell, np.datetime64):
            return f"'{cell}'"
        if isinstance(cell, datetime):
            return f"to_date('{cell:%Y-%m-%d %H:%M:%S}','YYYY-MM-DD HH24:MI:SS')"
        return str(cell)

    def insert_rows(
        self,
        table: str,
        rows: list[tuple],
        target_fields=None,
        commit_every: int = 1000,
        replace: bool | None = False,
        **kwargs,
    ) -> None:
        """
        Insert a collection of tuples into a table.

        All data to insert are treated as one transaction. Changes from standard
        DbApiHook implementation:

        - Oracle SQL queries can not be terminated with a semicolon (``;``).
        - Replace ``None`` and float NaN values with NULL (NaN is detected with
          ``math.isnan`` on float cells only, so strings never reach it).
        - Coerce datetime cells to Oracle DATETIME format during insert.

        The cursor and the connection are closed even when an insert fails.

        :param table: target Oracle table, use dot notation to target a
            specific database
        :param rows: the rows to insert into the table
        :param target_fields: the names of the columns to fill in the table
        :param commit_every: the maximum number of rows to insert in one transaction
            Default 1000, Set greater than 0.
            Set 1 to insert each row in each single transaction
        :param replace: Whether to replace instead of insert. Currently not implemented.
        """
        if replace:
            warnings.warn(
                "Using 'replace=True' does not implement any replace functionality currently.",
                category=UserWarning,
                stacklevel=2,
            )

        try:
            import numpy as np
        except ImportError:
            np = None  # type: ignore

        if target_fields:
            target_fields = ", ".join(target_fields)
            target_fields = f"({target_fields})"
        else:
            target_fields = ""

        conn = self.get_conn()
        try:
            if self.supports_autocommit:
                self.set_autocommit(conn, False)
            cur = conn.cursor()  # type: ignore[attr-defined]
            try:
                i = 0
                for i, row in enumerate(rows, 1):
                    values = ",".join(self._to_sql_literal(cell, np) for cell in row)
                    sql = f"INSERT /*+ APPEND */ INTO {table} {target_fields} VALUES ({values})"
                    cur.execute(sql)
                    if i % commit_every == 0:
                        conn.commit()  # type: ignore[attr-defined]
                        self.log.info("Loaded %s rows into %s so far", i, table)
                conn.commit()  # type: ignore[attr-defined]
            finally:
                cur.close()
        finally:
            conn.close()  # type: ignore[attr-defined]
        self.log.info("Done loading. Loaded a total of %s rows", i)

    def bulk_insert_rows(
        self,
        table: str,
        rows: list[tuple],
        target_fields: list[str] | None = None,
        commit_every: int = 5000,
        sequence_column: str | None = None,
        sequence_name: str | None = None,
    ):
        """
        Perform bulk inserts efficiently for Oracle DB.

        This uses prepared statements via `executemany()`.

        Arguments are validated and the statement is built before a connection is
        opened; the cursor and the connection are closed even when an insert fails.

        :param table: target Oracle table, use dot notation to target a
            specific database
        :param rows: the rows to insert into the table; must be non-empty and
            indexable, since the first row sizes the bind list when ``target_fields``
            is not given
        :param target_fields: the names of the columns to fill in the table, default None.
            If None, each rows should have some order as table columns name
        :param commit_every: the maximum number of rows to insert in one transaction
            Default 5000. Set greater than 0. Set 1 to insert each row in each transaction
        :param sequence_column: the column name to which the sequence will be applied, default None.
        :param sequence_name: the names of the sequence_name in the table, default None.
        :raises ValueError: if ``rows`` is empty, or if only one of ``sequence_column``
            and ``sequence_name`` is provided.
        """
        if not rows:
            raise ValueError("parameter rows could not be None or empty iterable")
        if bool(sequence_column) ^ bool(sequence_name):
            raise ValueError(
                "Parameters 'sequence_column' and 'sequence_name' must be provided together or not at all."
            )

        # One positional bind (":1", ":2", ...) per target column, or per item of
        # the first row when no columns were named.
        values_base = target_fields or rows[0]
        binds = [f":{i}" for i in range(1, len(values_base) + 1)]
        if sequence_column and sequence_name:
            # NOTE(review): without target_fields the column list names only the
            # sequence column while the VALUES list still carries one bind per row
            # item -- confirm callers always pass target_fields with a sequence.
            columns = (
                "({})".format(", ".join([sequence_column, *target_fields]))
                if target_fields
                else f"({sequence_column})"
            )
            values = ", ".join([f"{sequence_name}.NEXTVAL", *binds])
        else:
            columns = "({})".format(", ".join(target_fields)) if target_fields else ""
            values = ", ".join(binds)
        prepared_stm = f"insert into {table} {columns} values ({values})"

        conn = self.get_conn()
        try:
            if self.supports_autocommit:
                self.set_autocommit(conn, False)
            cursor = conn.cursor()  # type: ignore[attr-defined]
            try:
                row_count = 0
                # Chunk the rows
                row_chunk = []
                for row in rows:
                    row_chunk.append(row)
                    row_count += 1
                    if row_count % commit_every == 0:
                        cursor.prepare(prepared_stm)
                        cursor.executemany(None, row_chunk)
                        conn.commit()  # type: ignore[attr-defined]
                        self.log.info("[%s] inserted %s rows", table, row_count)
                        # Empty chunk
                        row_chunk = []
                # Commit the leftover chunk; nothing is left when the row count is
                # an exact multiple of commit_every, so no empty batch is sent.
                if row_chunk:
                    cursor.prepare(prepared_stm)
                    cursor.executemany(None, row_chunk)
                    conn.commit()  # type: ignore[attr-defined]
                    self.log.info("[%s] inserted %s rows", table, row_count)
            finally:
                cursor.close()
        finally:
            conn.close()  # type: ignore[attr-defined]

    def callproc(
        self,
        identifier: str,
        autocommit: bool = False,
        parameters: list | dict | None = None,
    ) -> list | dict | tuple | None:
        """
        Call the stored procedure identified by the provided string.

        Any OUT parameters must be provided with a value of either the
        expected Python type (e.g., `int`) or an instance of that type.

        The return value is a list or mapping that includes parameters in
        both directions; the actual return type depends on the type of the
        provided `parameters` argument.

        See
        https://python-oracledb.readthedocs.io/en/latest/api_manual/cursor.html#Cursor.var
        for further reference.

        :param identifier: name of the stored procedure; it is spliced into the
            ``BEGIN ...; END;`` block as-is.
        :param autocommit: forwarded to ``run``.
        :param parameters: positional (list) or named (dict) parameters; ``None`` means
            no parameters.
        :return: the bind-variable values collected after the call, as returned by ``run``.
        """
        if parameters is None:
            parameters = []

        # Named binds for a mapping, 1-based positional binds for a sequence.
        bind_names = parameters if isinstance(parameters, dict) else range(1, len(parameters) + 1)
        args = ",".join(f":{name}" for name in bind_names)
        sql = f"BEGIN {identifier}({args}); END;"

        def handler(cursor):
            """Collect the current values of the cursor's bind variables."""
            if cursor.bindvars is None:
                return

            if isinstance(cursor.bindvars, list):
                return [v.getvalue() for v in cursor.bindvars]

            if isinstance(cursor.bindvars, dict):
                return {n: v.getvalue() for (n, v) in cursor.bindvars.items()}

            raise TypeError(f"Unexpected bindvars: {cursor.bindvars!r}")

        if isinstance(parameters, dict):
            mapped = {name: _map_param(value) for (name, value) in parameters.items()}
        else:
            mapped = [_map_param(value) for value in parameters]

        result = self.run(
            sql,
            autocommit=autocommit,
            parameters=mapped,
            handler=handler,
        )

        return result

Was this entry helpful?