Source code for airflow.providers.samba.hooks.samba

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import posixpath
from functools import wraps
from shutil import copyfileobj
from typing import TYPE_CHECKING, Any

import smbclient

from airflow.hooks.base import BaseHook

if TYPE_CHECKING:
    import smbprotocol.connection


class SambaHook(BaseHook):
    """
    Allows for interaction with a Samba server.

    The hook should be used as a context manager in order to correctly
    set up a session and disconnect open connections upon exit.

    :param samba_conn_id: The connection id reference.
    :param share: An optional share name. If this is unset then the
        "schema" field of the connection is used in its place.
    """

    conn_name_attr = "samba_conn_id"
    default_conn_name = "samba_default"
    conn_type = "samba"
    hook_name = "Samba"

    def __init__(self, samba_conn_id: str = default_conn_name, share: str | None = None) -> None:
        super().__init__()
        conn = self.get_connection(samba_conn_id)

        # Missing credentials are only logged, not rejected.
        if not conn.login:
            self.log.info("Login not provided")

        if not conn.password:
            self.log.info("Password not provided")

        # The same dict object is handed to every ``smbclient`` call (through
        # ``_conn_kwargs``) and is iterated in ``__exit__`` to disconnect.
        connection_cache: dict[str, smbprotocol.connection.Connection] = {}

        self._host = conn.host
        self._share = share or conn.schema
        self._connection_cache = connection_cache
        self._conn_kwargs = {
            "username": conn.login,
            "password": conn.password,
            "port": conn.port or 445,
            "connection_cache": connection_cache,
        }

    def __enter__(self):
        # This immediately connects to the host (which can be
        # perceived as a benefit), but also helps work around an issue:
        #
        # https://github.com/jborean93/smbprotocol/issues/109.
        smbclient.register_session(self._host, **self._conn_kwargs)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Disconnect every connection recorded in the cache, then empty it.
        for host, connection in self._connection_cache.items():
            self.log.info("Disconnecting from %s", host)
            connection.disconnect()
        self._connection_cache.clear()

    def _join_path(self, path):
        # Build ``//<host>/<share>/<path>``; leading slashes are stripped from
        # ``path`` so that ``posixpath.join`` does not treat it as absolute
        # and discard the host and share components.
        return f"//{posixpath.join(self._host, self._share, path.lstrip('/'))}"

    # ------------------------------------------------------------------
    # The methods below mirror the ``smbclient`` functions of the same
    # name. Each one expands its share-relative path argument(s) with
    # ``_join_path`` and forwards the hook's connection keyword arguments.
    # ``functools.wraps`` copies the upstream name and docstring onto the
    # method, which is why they carry comments rather than docstrings.
    # ------------------------------------------------------------------

    @wraps(smbclient.link)
    def link(self, src, dst, follow_symlinks=True):
        return smbclient.link(
            self._join_path(src),
            self._join_path(dst),
            follow_symlinks=follow_symlinks,
            **self._conn_kwargs,
        )

    @wraps(smbclient.listdir)
    def listdir(self, path):
        return smbclient.listdir(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.lstat)
    def lstat(self, path):
        return smbclient.lstat(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.makedirs)
    def makedirs(self, path, exist_ok=False):
        return smbclient.makedirs(self._join_path(path), exist_ok=exist_ok, **self._conn_kwargs)

    @wraps(smbclient.mkdir)
    def mkdir(self, path):
        return smbclient.mkdir(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.open_file)
    def open_file(
        self,
        path,
        mode="r",
        buffering=-1,
        encoding=None,
        errors=None,
        newline=None,
        share_access=None,
        desired_access=None,
        file_attributes=None,
        file_type="file",
    ):
        return smbclient.open_file(
            self._join_path(path),
            mode=mode,
            buffering=buffering,
            encoding=encoding,
            errors=errors,
            newline=newline,
            share_access=share_access,
            desired_access=desired_access,
            file_attributes=file_attributes,
            file_type=file_type,
            **self._conn_kwargs,
        )

    @wraps(smbclient.readlink)
    def readlink(self, path):
        return smbclient.readlink(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.remove)
    def remove(self, path):
        return smbclient.remove(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.removedirs)
    def removedirs(self, path):
        return smbclient.removedirs(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.rename)
    def rename(self, src, dst):
        return smbclient.rename(self._join_path(src), self._join_path(dst), **self._conn_kwargs)

    @wraps(smbclient.replace)
    def replace(self, src, dst):
        return smbclient.replace(self._join_path(src), self._join_path(dst), **self._conn_kwargs)

    @wraps(smbclient.rmdir)
    def rmdir(self, path):
        return smbclient.rmdir(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.scandir)
    def scandir(self, path, search_pattern="*"):
        return smbclient.scandir(
            self._join_path(path),
            search_pattern=search_pattern,
            **self._conn_kwargs,
        )

    @wraps(smbclient.stat)
    def stat(self, path, follow_symlinks=True):
        return smbclient.stat(self._join_path(path), follow_symlinks=follow_symlinks, **self._conn_kwargs)

    @wraps(smbclient.stat_volume)
    def stat_volume(self, path):
        return smbclient.stat_volume(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.symlink)
    def symlink(self, src, dst, target_is_directory=False):
        return smbclient.symlink(
            self._join_path(src),
            self._join_path(dst),
            target_is_directory=target_is_directory,
            **self._conn_kwargs,
        )

    @wraps(smbclient.truncate)
    def truncate(self, path, length):
        return smbclient.truncate(self._join_path(path), length, **self._conn_kwargs)

    @wraps(smbclient.unlink)
    def unlink(self, path):
        return smbclient.unlink(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.utime)
    def utime(self, path, times=None, ns=None, follow_symlinks=True):
        return smbclient.utime(
            self._join_path(path),
            times=times,
            ns=ns,
            follow_symlinks=follow_symlinks,
            **self._conn_kwargs,
        )

    @wraps(smbclient.walk)
    def walk(self, path, topdown=True, onerror=None, follow_symlinks=False):
        return smbclient.walk(
            self._join_path(path),
            topdown=topdown,
            onerror=onerror,
            follow_symlinks=follow_symlinks,
            **self._conn_kwargs,
        )

    @wraps(smbclient.getxattr)
    def getxattr(self, path, attribute, follow_symlinks=True):
        return smbclient.getxattr(
            self._join_path(path), attribute, follow_symlinks=follow_symlinks, **self._conn_kwargs
        )

    @wraps(smbclient.listxattr)
    def listxattr(self, path, follow_symlinks=True):
        return smbclient.listxattr(
            self._join_path(path), follow_symlinks=follow_symlinks, **self._conn_kwargs
        )

    @wraps(smbclient.removexattr)
    def removexattr(self, path, attribute, follow_symlinks=True):
        return smbclient.removexattr(
            self._join_path(path), attribute, follow_symlinks=follow_symlinks, **self._conn_kwargs
        )

    @wraps(smbclient.setxattr)
    def setxattr(self, path, attribute, value, flags=0, follow_symlinks=True):
        return smbclient.setxattr(
            self._join_path(path),
            attribute,
            value,
            flags=flags,
            follow_symlinks=follow_symlinks,
            **self._conn_kwargs,
        )

    def push_from_local(self, destination_filepath: str, local_filepath: str, buffer_size: int | None = None):
        """
        Push local file to samba server.

        :param destination_filepath: the samba location to push to
        :param local_filepath: the file to push
        :param buffer_size: size in bytes of the individual chunks of file to send.
            Larger values may speed up large file transfers
        """
        # Only pass a chunk size when one was given; otherwise let
        # ``copyfileobj`` fall back to its own default.
        extra_args = (buffer_size,) if buffer_size else ()
        with open(local_filepath, "rb") as f, self.open_file(destination_filepath, mode="wb") as g:
            copyfileobj(f, g, *extra_args)

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return custom field behaviour."""
        return {
            "hidden_fields": ["extra"],
            "relabeling": {"schema": "Share"},
        }

Was this entry helpful?