# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import typing
from jupyter_client import AsyncKernelManager
from papermill.clientwrap import PapermillNotebookClient
from papermill.engines import NBClientEngine
from papermill.utils import merge_kwargs, remove_args
from traitlets import Unicode
from airflow.hooks.base import BaseHook
# Fallback ports for the kernel channels; ``KernelHook.get_conn`` uses these
# when the connection's ``extra`` field does not provide the matching key.
JUPYTER_KERNEL_SHELL_PORT = 60316
JUPYTER_KERNEL_IOPUB_PORT = 60317
JUPYTER_KERNEL_STDIN_PORT = 60318
JUPYTER_KERNEL_CONTROL_PORT = 60319
JUPYTER_KERNEL_HB_PORT = 60320

# Name under which ``RemoteKernelEngine`` is registered with papermill.
REMOTE_KERNEL_ENGINE = "remote_kernel_engine"
class KernelConnection:
    """
    Class to represent kernel connection object.

    Instances are created empty; ``KernelHook.get_conn`` assigns every attribute
    declared below. The declarations are annotation-only, so they create no class
    attributes and do not change runtime behaviour.
    """

    # Host or IP address of the kernel, taken from the Airflow connection's host.
    ip: str
    # Kernel channel ports, taken from the connection ``extra`` field or the
    # module-level ``JUPYTER_KERNEL_*_PORT`` defaults.
    # NOTE(review): values read from ``extra`` are not validated or cast to int.
    shell_port: int
    iopub_port: int
    stdin_port: int
    control_port: int
    hb_port: int
    # Session key from the connection ``extra`` field; empty string when absent.
    session_key: str
class KernelHook(BaseHook):
    """
    The KernelHook can be used to interact with remote jupyter kernel.

    Takes kernel host/ip from connection and refers to jupyter kernel ports and session_key
    from ``extra`` field.

    :param kernel_conn_id: connection that has kernel host/ip
    """

    conn_name_attr = "kernel_conn_id"
    default_conn_name = "jupyter_kernel_default"
    conn_type = "jupyter_kernel"
    hook_name = "Jupyter Kernel"

    def __init__(self, kernel_conn_id: str = default_conn_name, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.kernel_conn = self.get_connection(kernel_conn_id)
        # Make the remote-kernel papermill engine available as soon as a hook exists.
        register_remote_kernel_engine()

    def get_conn(self) -> KernelConnection:
        """
        Build a ``KernelConnection`` from the Airflow connection.

        The host becomes the kernel ip. Ports and ``session_key`` are read from the
        connection ``extra`` field; a missing port falls back to the matching
        ``JUPYTER_KERNEL_*_PORT`` constant and a missing session key to ``""``.

        :return: a populated ``KernelConnection`` object
        """
        # Read ``extra_dejson`` once instead of evaluating the property for every key.
        extra = self.kernel_conn.extra_dejson

        kernel_connection = KernelConnection()
        kernel_connection.ip = self.kernel_conn.host
        kernel_connection.shell_port = extra.get("shell_port", JUPYTER_KERNEL_SHELL_PORT)
        kernel_connection.iopub_port = extra.get("iopub_port", JUPYTER_KERNEL_IOPUB_PORT)
        kernel_connection.stdin_port = extra.get("stdin_port", JUPYTER_KERNEL_STDIN_PORT)
        kernel_connection.control_port = extra.get("control_port", JUPYTER_KERNEL_CONTROL_PORT)
        kernel_connection.hb_port = extra.get("hb_port", JUPYTER_KERNEL_HB_PORT)
        kernel_connection.session_key = extra.get("session_key", "")
        return kernel_connection
def register_remote_kernel_engine():
    """Register ``RemoteKernelEngine`` with papermill under the ``REMOTE_KERNEL_ENGINE`` name."""
    from papermill import engines as pm_engines

    engine_registry = pm_engines.papermill_engines
    engine_registry.register(REMOTE_KERNEL_ENGINE, RemoteKernelEngine)
class RemoteKernelManager(AsyncKernelManager):
    """Jupyter kernel manager that connects to a remote kernel."""

    # Configurable trait holding the key passed to the client as ``key``.
    session_key = Unicode("", config=True, help="Session key to connect to remote kernel")

    @property
    def has_kernel(self) -> bool:
        """Unconditionally report that a kernel is available."""
        return True

    async def _async_is_alive(self) -> bool:
        """Unconditionally report the kernel as alive."""
        return True

    def shutdown_kernel(self, now: bool = False, restart: bool = False):
        """Do nothing; both arguments are accepted and ignored."""

    def client(self, **kwargs: typing.Any):
        """Create a client configured to connect to our kernel."""
        remote_client = super().client(**kwargs)

        # Load the connection info explicitly so the client picks up session_key.
        connection_info: dict[str, int | str | bytes] = {
            "ip": self.ip,
            "shell_port": self.shell_port,
            "iopub_port": self.iopub_port,
            "stdin_port": self.stdin_port,
            "control_port": self.control_port,
            "hb_port": self.hb_port,
            "key": self.session_key,
            "transport": "tcp",
            "signature_scheme": "hmac-sha256",
        }
        remote_client.load_connection_info(connection_info)
        return remote_client
class RemoteKernelEngine(NBClientEngine):
    """Papermill engine to use ``RemoteKernelManager`` to connect to remote kernel and execute notebook."""

    @classmethod
    def execute_managed_notebook(
        cls,
        nb_man,
        kernel_name,
        log_output=False,
        stdout_file=None,
        stderr_file=None,
        start_timeout=60,
        execution_timeout=None,
        **kwargs,
    ):
        """
        Execute the parameterized notebook through a ``RemoteKernelManager``.

        :param nb_man: notebook manager object handed to ``PapermillNotebookClient``
        :param kernel_name: accepted for interface compatibility; not used by this engine
        :param log_output: forwarded to ``PapermillNotebookClient`` as ``log_output``
        :param stdout_file: forwarded to ``PapermillNotebookClient`` as ``stdout_file``
        :param stderr_file: forwarded to ``PapermillNotebookClient`` as ``stderr_file``
        :param start_timeout: forwarded to the client as ``startup_timeout``
        :param execution_timeout: forwarded to the client as ``timeout``; when falsy,
            ``kwargs["timeout"]`` (if any) is used instead
        :param kwargs: must contain ``kernel_ip``, ``kernel_shell_port``,
            ``kernel_iopub_port``, ``kernel_stdin_port``, ``kernel_control_port``,
            ``kernel_hb_port`` and ``kernel_session_key``; the remaining entries are
            forwarded to ``PapermillNotebookClient``
        :raises KeyError: if any of the required ``kernel_*`` entries is missing
        :return: the result of ``PapermillNotebookClient.execute()``
        """
        km = RemoteKernelManager()
        km.ip = kwargs["kernel_ip"]
        km.shell_port = kwargs["kernel_shell_port"]
        km.iopub_port = kwargs["kernel_iopub_port"]
        km.stdin_port = kwargs["kernel_stdin_port"]
        km.control_port = kwargs["kernel_control_port"]
        km.hb_port = kwargs["kernel_hb_port"]
        km.session_key = kwargs["kernel_session_key"]

        # Exclude parameters that are named differently downstream
        safe_kwargs = remove_args(["timeout", "startup_timeout"], **kwargs)

        final_kwargs = merge_kwargs(
            safe_kwargs,
            timeout=execution_timeout if execution_timeout else kwargs.get("timeout"),
            startup_timeout=start_timeout,
            # Honour the caller's choice instead of always disabling output logging.
            log_output=log_output,
            stdout_file=stdout_file,
            stderr_file=stderr_file,
        )

        return PapermillNotebookClient(nb_man, km=km, **final_kwargs).execute()