# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import inspect
import re
import warnings
from itertools import chain
from textwrap import dedent
from typing import (
Any,
Callable,
ClassVar,
Collection,
Dict,
Generic,
Iterator,
Mapping,
Sequence,
TypeVar,
cast,
overload,
)
import attr
import typing_extensions
from sqlalchemy.orm import Session
from airflow import Dataset
from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY
from airflow.models.baseoperator import (
BaseOperator,
coerce_resources,
coerce_timedelta,
get_merged_defaults,
parse_retries,
)
from airflow.models.dag import DAG, DagContext
from airflow.models.expandinput import (
EXPAND_INPUT_EMPTY,
DictOfListsExpandInput,
ExpandInput,
ListOfDictsExpandInput,
OperatorExpandArgument,
OperatorExpandKwargsArgument,
is_mappable,
)
from airflow.models.mappedoperator import MappedOperator, ValidationSource, ensure_xcomarg_return_value
from airflow.models.pool import Pool
from airflow.models.xcom_arg import XComArg
from airflow.typing_compat import ParamSpec, Protocol
from airflow.utils import timezone
from airflow.utils.context import KNOWN_CONTEXT_KEYS, Context
from airflow.utils.decorators import remove_task_decorator
from airflow.utils.helpers import prevent_duplicates
from airflow.utils.task_group import TaskGroup, TaskGroupContext
from airflow.utils.types import NOTSET
class ExpandableFactory(Protocol):
    """Structural interface for objects that wrap a user-supplied function.

    Decorators such as ``@task`` and ``@task_group`` implement this so that the
    keyword arguments given to ``partial()``/``expand()`` can be checked against
    the wrapped function's signature (see ``_validate_arg_names``).

    :meta private:
    """

    function: Callable

    @cached_property
    def function_signature(self) -> inspect.Signature:
        """Signature of ``self.function``, computed on first access and cached."""
        return inspect.signature(self.function)

    @cached_property
    def _mappable_function_argument_names(self) -> set[str]:
        """Arguments that can be mapped against."""
        return {param_name for param_name in self.function_signature.parameters}

    def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]) -> None:
        """Ensure that all arguments passed to operator-mapping functions are accounted for.

        :param func: Name of the calling API (e.g. ``"expand"`` or ``"partial"``),
            used in error messages; value-type checks only apply to ``"expand"``.
        :param kwargs: Keyword arguments given to that API.
        :raises ValueError: if ``func`` is ``"expand"`` and a known argument's
            value is not mappable.
        :raises TypeError: if ``kwargs`` names arguments the function does not accept.
        """
        params = self.function_signature.parameters
        accepts_var_keyword = any(
            param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()
        )
        # A ``**kwargs`` catch-all accepts any name, so there is nothing to check.
        if accepts_var_keyword:
            return

        unknown = dict(kwargs)
        for name in self._mappable_function_argument_names:
            candidate = unknown.pop(name, NOTSET)
            if func == "expand" and candidate is not NOTSET and not is_mappable(candidate):
                type_name = type(candidate).__name__
                raise ValueError(
                    f"expand() got an unexpected type {type_name!r} for keyword argument {name!r}"
                )

        # Whatever was not consumed above does not correspond to any parameter.
        if not unknown:
            return
        if len(unknown) == 1:
            (only_name,) = unknown
            raise TypeError(f"{func}() got an unexpected keyword argument {only_name!r}")
        names = ", ".join(repr(n) for n in unknown)
        raise TypeError(f"{func}() got unexpected keyword arguments {names}")
def get_unique_task_id(
    task_id: str,
    dag: DAG | None = None,
    task_group: TaskGroup | None = None,
) -> str:
    """
    Generate unique task id given a DAG (or if run in a DAG context).

    IDs are generated by appending a unique number to the end of
    the original task id.

    Example:
      task_id
      task_id__1
      task_id__2
      ...
      task_id__20

    :param task_id: The requested task id (without any task-group prefix).
    :param dag: DAG to check for collisions; defaults to the current DAG context.
    :param task_group: Task group the task belongs to; defaults to the current
        task-group context of ``dag``.
    :return: ``task_id`` unchanged when there is no DAG or no collision; otherwise
        ``task_id`` with any trailing ``__<n>`` stripped and ``__<max + 1>``
        appended, where ``max`` is the largest existing suffix for that prefix
        (0 if none).
    """
    dag = dag or DagContext.get_current_dag()
    if not dag:
        return task_id

    # We need to check if we are in the context of TaskGroup as the task_id may
    # already be altered
    task_group = task_group or TaskGroupContext.get_current_task_group(dag)
    tg_task_id = task_group.child_id(task_id) if task_group else task_id

    if tg_task_id not in dag.task_ids:
        return task_id

    # The prefix must be escaped: task ids can contain "." (e.g. the task-group
    # separator), which as a raw regex would match any character and pick up
    # suffixes from unrelated task ids.
    prefix = re.split(r"__\d+$", tg_task_id)[0]
    suffix_pattern = re.compile(rf"^{re.escape(prefix)}__(\d+)$")
    existing_suffixes = (
        int(match.group(1))
        for match in map(suffix_pattern.match, dag.task_ids)
        if match is not None
    )

    core = re.split(r"__\d+$", task_id)[0]
    # default=0 covers the case where no task id carries a numeric suffix yet.
    return f"{core}__{max(existing_suffixes, default=0) + 1}"
class DecoratedOperator(BaseOperator):
    """
    Wraps a Python callable and captures args/kwargs when called for execution.

    :param python_callable: A reference to an object that is callable
    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
        in your function (templated)
    :param op_args: a list of positional arguments that will get unpacked when
        calling your callable (templated)
    :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
        multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
    :param kwargs_to_upstream: For certain operators, we might need to upstream certain arguments
        that would otherwise be absorbed by the DecoratedOperator (for example python_callable for the
        PythonOperator). This gives a user the option to upstream kwargs as needed.
    """

    template_fields: Sequence[str] = ("op_args", "op_kwargs")
    template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}

    # since we won't mutate the arguments, we should just do the shallow copy
    # there are some cases we can't deepcopy the objects (e.g protobuf).
    shallow_copy_attrs: Sequence[str] = ("python_callable",)

    def __init__(
        self,
        *,
        python_callable: Callable,
        task_id: str,
        op_args: Collection[Any] | None = None,
        op_kwargs: Mapping[str, Any] | None = None,
        multiple_outputs: bool = False,
        kwargs_to_upstream: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        # De-duplicate the task id against the DAG/task group it is being added to.
        task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
        self.python_callable = python_callable
        kwargs_to_upstream = kwargs_to_upstream or {}
        op_args = op_args or []
        op_kwargs = op_kwargs or {}

        # Check that arguments can be bound. There's a slight difference when
        # we do validation for task-mapping: Since there's no guarantee we can
        # receive enough arguments at parse time, we use bind_partial to simply
        # check all the arguments we know are valid. Whether these are enough
        # can only be known at execution time, when unmapping happens, and this
        # is called without the _airflow_mapped_validation_only flag.
        if kwargs.get("_airflow_mapped_validation_only"):
            inspect.signature(python_callable).bind_partial(*op_args, **op_kwargs)
        else:
            inspect.signature(python_callable).bind(*op_args, **op_kwargs)

        self.multiple_outputs = multiple_outputs
        self.op_args = op_args
        self.op_kwargs = op_kwargs
        super().__init__(task_id=task_id, **kwargs_to_upstream, **kwargs)

    def execute(self, context: Context):
        """Run the callable via the parent operator and post-process its result.

        Any ``Dataset`` passed as a positional or keyword argument is appended to
        ``self.inlets`` before running.
        """
        # todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators
        # as well
        for arg in chain(self.op_args, self.op_kwargs.values()):
            if isinstance(arg, Dataset):
                self.inlets.append(arg)
        return_value = super().execute(context)
        return self._handle_output(return_value=return_value, context=context, xcom_push=self.xcom_push)

    def _handle_output(self, return_value: Any, context: Context, xcom_push: Callable):
        """
        Handles logic for whether a decorator needs to push a single return value or multiple return values.

        It sets outlets if any datasets are found in the returned value(s).

        :param return_value: Value returned by the parent operator's ``execute``.
        :param context: Execution context, forwarded to ``xcom_push``.
        :param xcom_push: Callable invoked as ``xcom_push(context, key, value)`` for
            each dictionary entry when ``multiple_outputs`` is set.
        :return: ``return_value`` unchanged.
        :raises AirflowException: if ``multiple_outputs`` is set and ``return_value``
            is not a dict, or has a non-string key.
        """
        if isinstance(return_value, Dataset):
            self.outlets.append(return_value)
        if isinstance(return_value, list):
            for item in return_value:
                if isinstance(item, Dataset):
                    self.outlets.append(item)
        if not self.multiple_outputs:
            return return_value
        if isinstance(return_value, dict):
            # Validate every key before pushing anything, so a bad key does not
            # leave a partially-pushed set of XComs behind.
            for key in return_value:
                if not isinstance(key, str):
                    raise AirflowException(
                        "Returned dictionary keys must be strings when using "
                        f"multiple_outputs, found {key} ({type(key)}) instead"
                    )
            for key, value in return_value.items():
                if isinstance(value, Dataset):
                    self.outlets.append(value)
                xcom_push(context, key, value)
        else:
            raise AirflowException(
                f"Returned output was type {type(return_value)} expected dictionary for multiple_outputs"
            )
        return return_value

    def _hook_apply_defaults(self, *args, **kwargs):
        """Fill missing ``op_kwargs`` entries from ``default_args``.

        For every parameter of ``python_callable`` that is absent from
        ``op_kwargs`` but present in ``default_args``, the default value is
        added. Presumably invoked by the ``BaseOperator`` default-argument
        machinery before ``__init__`` — TODO confirm.

        :return: The ``(args, kwargs)`` pair, with ``kwargs["op_kwargs"]`` replaced
            by a new dict when ``python_callable`` is given.
        """
        if "python_callable" not in kwargs:
            return args, kwargs

        python_callable = kwargs["python_callable"]
        default_args = kwargs.get("default_args") or {}
        # Work on a copy: the caller's mapping may be shared (e.g. stored on a
        # partial() decorator) or read-only, and must not be mutated in place.
        op_kwargs = dict(kwargs.get("op_kwargs") or {})
        f_sig = inspect.signature(python_callable)
        for arg in f_sig.parameters:
            if arg not in op_kwargs and arg in default_args:
                op_kwargs[arg] = default_args[arg]
        kwargs["op_kwargs"] = op_kwargs
        return args, kwargs

    def get_python_source(self):
        """Return the dedented source of ``python_callable`` with its task decorator removed.

        ``custom_operator_name`` is not defined in this class; it is expected to be
        provided by subclasses.
        """
        raw_source = inspect.getsource(self.python_callable)
        res = dedent(raw_source)
        res = remove_task_decorator(res, self.custom_operator_name)
        return res
# Parameter specification of the decorated function; used as
# ``Callable[FParams, FReturn]`` to carry its call signature through the decorator.
FParams = ParamSpec("FParams")
# Return type of the decorated function.
FReturn = TypeVar("FReturn")
# Operator class a decorator instantiates (``_TaskDecorator.operator_class``);
# bound to BaseOperator via a string forward reference.
OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")
@attr.define(slots=False)
class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubclass]):
    """
    Helper class for providing dynamic task mapping to decorated functions.

    ``task_decorator_factory`` returns an instance of this, instead of just a plain wrapped function.

    :meta private:
    """

    function: Callable[FParams, FReturn] = attr.ib(validator=attr.validators.is_callable())
    operator_class: type[OperatorSubclass]
    multiple_outputs: bool = attr.ib()
    kwargs: dict[str, Any] = attr.ib(factory=dict)

    decorator_name: str = attr.ib(repr=False, default="task")

    _airflow_is_task_decorator: ClassVar[bool] = True
    _is_setup: ClassVar[bool] = False
    _is_teardown: ClassVar[bool] = False
    _on_failure_fail_dagrun: ClassVar[bool] = False

    @multiple_outputs.default
    def _infer_multiple_outputs(self):
        """Default for ``multiple_outputs``, inferred from the return annotation.

        Returns True only when the evaluated return annotation (or its
        ``__origin__``) is ``dict``/``Dict``. Returns False when there is no
        return annotation or it cannot be evaluated; a ``NameError`` from an
        unresolvable forward reference additionally emits a warning.
        """
        if "return" not in self.function.__annotations__:
            # No return type annotation, nothing to infer
            return False

        try:
            # We only care about the return annotation, not anything about the parameters
            def fake():
                ...

            fake.__annotations__ = {"return": self.function.__annotations__["return"]}

            return_type = typing_extensions.get_type_hints(fake, self.function.__globals__).get("return", Any)
        except NameError as e:
            warnings.warn(
                f"Cannot infer multiple_outputs for TaskFlow function {self.function.__name__!r} with forward"
                f" type references that are not imported. (Error was {e})",
                stacklevel=4,
            )
            return False
        except TypeError:  # Can't evaluate return type.
            return False
        ttype = getattr(return_type, "__origin__", return_type)
        return ttype == dict or ttype == Dict

    def __attrs_post_init__(self):
        """Reject methods and default ``task_id`` to the function's name."""
        if "self" in self.function_signature.parameters:
            raise TypeError(f"@{self.decorator_name} does not support methods")
        self.kwargs.setdefault("task_id", self.function.__name__)

    def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> XComArg:
        """Instantiate ``operator_class`` for this call and return its ``XComArg``.

        Positional/keyword call arguments become ``op_args``/``op_kwargs``. The
        function's docstring is used as ``doc_md`` when no ``doc*`` attribute is set.
        """
        op = self.operator_class(
            python_callable=self.function,
            op_args=args,
            op_kwargs=kwargs,
            multiple_outputs=self.multiple_outputs,
            **self.kwargs,
        )
        op._is_setup = self._is_setup
        op._is_teardown = self._is_teardown
        op._on_failure_fail_dagrun = self._on_failure_fail_dagrun
        op_doc_attrs = [op.doc, op.doc_json, op.doc_md, op.doc_rst, op.doc_yaml]
        # Set the task's doc_md to the function's docstring if it exists and no other doc* args are set.
        if self.function.__doc__ and not any(op_doc_attrs):
            op.doc_md = self.function.__doc__
        return XComArg(op)

    @property
    def __wrapped__(self) -> Callable[FParams, FReturn]:
        return self.function

    def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]):
        """Reject task-context variable names, then apply the signature checks of the base class.

        :raises ValueError: if any key of ``kwargs`` is a known task-context variable.
        """
        # Ensure that context variables are not shadowed.
        context_keys_being_mapped = KNOWN_CONTEXT_KEYS.intersection(kwargs)
        if len(context_keys_being_mapped) == 1:
            (name,) = context_keys_being_mapped
            raise ValueError(f"cannot call {func}() on task context variable {name!r}")
        elif context_keys_being_mapped:
            # Sort so the message does not depend on set iteration (hash) order.
            names = ", ".join(repr(n) for n in sorted(context_keys_being_mapped))
            raise ValueError(f"cannot call {func}() on task context variables {names}")

        super()._validate_arg_names(func, kwargs)

    def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg:
        """Map the task over the given per-argument values.

        :raises TypeError: if no arguments are given or an argument was already set via ``partial()``.
        """
        if not map_kwargs:
            raise TypeError("no arguments to expand against")
        self._validate_arg_names("expand", map_kwargs)
        prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
        """Map the task over a sequence of keyword-argument sets.

        :param kwargs: An ``XComArg``, or a sequence whose items are each an
            ``XComArg`` or a mapping.
        :param strict: Forwarded to the mapped operator as ``disallow_kwargs_override``.
        :raises TypeError: if ``kwargs`` (or one of its items) has an unsupported type.
        """
        if isinstance(kwargs, Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, Mapping)):
                    # Name the offending element's type; type(kwargs) alone would
                    # always read e.g. "list" and hide what was actually wrong.
                    raise TypeError(
                        "expected XComArg or list[dict], "
                        f"not {type(kwargs).__name__}[{type(item).__name__}]"
                    )
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
        """Build a ``DecoratedMappedOperator`` for ``expand_input`` and return its ``XComArg``.

        Mirrors the argument normalisation done for classic operators' ``partial()``.
        """
        ensure_xcomarg_return_value(expand_input.value)

        task_kwargs = self.kwargs.copy()
        dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
        task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag)

        partial_kwargs, partial_params = get_merged_defaults(
            dag=dag,
            task_group=task_group,
            task_params=task_kwargs.pop("params", None),
            task_default_args=task_kwargs.pop("default_args", None),
        )
        partial_kwargs.update(task_kwargs)

        task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
        if task_group:
            task_id = task_group.child_id(task_id)

        # Logic here should be kept in sync with BaseOperatorMeta.partial().
        if "task_concurrency" in partial_kwargs:
            raise TypeError("unexpected argument: task_concurrency")
        if partial_kwargs.get("wait_for_downstream"):
            partial_kwargs["depends_on_past"] = True
        start_date = timezone.convert_to_utc(partial_kwargs.pop("start_date", None))
        end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None))
        if partial_kwargs.get("pool") is None:
            partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
        partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES))
        partial_kwargs["retry_delay"] = coerce_timedelta(
            partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
            key="retry_delay",
        )
        max_retry_delay = partial_kwargs.get("max_retry_delay")
        partial_kwargs["max_retry_delay"] = (
            max_retry_delay
            if max_retry_delay is None
            else coerce_timedelta(max_retry_delay, key="max_retry_delay")
        )
        partial_kwargs["resources"] = coerce_resources(partial_kwargs.get("resources"))
        partial_kwargs.setdefault("executor_config", {})
        partial_kwargs.setdefault("op_args", [])
        partial_kwargs.setdefault("op_kwargs", {})

        # Mypy does not work well with a subclassed attrs class :(
        _MappedOperator = cast(Any, DecoratedMappedOperator)

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        operator = _MappedOperator(
            operator_class=self.operator_class,
            expand_input=EXPAND_INPUT_EMPTY,  # Don't use this; mapped values go to op_kwargs_expand_input.
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=partial_params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=False,
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            multiple_outputs=self.multiple_outputs,
            python_callable=self.function,
            op_kwargs_expand_input=expand_input,
            disallow_kwargs_override=strict,
            # Different from classic operators, kwargs passed to a taskflow
            # task's expand() contribute to the op_kwargs operator argument, not
            # the operator arguments themselves, and should expand against it.
            expand_input_attr="op_kwargs_expand_input",
        )
        return XComArg(operator=operator)

    def partial(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]:
        """Return a copy whose ``op_kwargs`` additionally contain ``kwargs``.

        :raises: whatever ``prevent_duplicates`` raises if a key was already given
            to an earlier ``partial()`` call.
        """
        self._validate_arg_names("partial", kwargs)
        old_kwargs = self.kwargs.get("op_kwargs", {})
        prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial")
        kwargs.update(old_kwargs)
        return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs})

    def override(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]:
        """Return a copy with ``kwargs`` merged over the operator kwargs.

        The setup/teardown/fail-dagrun flags are copied explicitly because they
        are class-level defaults that may have been overridden on this instance.
        """
        result = attr.evolve(self, kwargs={**self.kwargs, **kwargs})
        setattr(result, "_is_setup", self._is_setup)
        setattr(result, "_is_teardown", self._is_teardown)
        setattr(result, "_on_failure_fail_dagrun", self._on_failure_fail_dagrun)
        return result
@attr.define(kw_only=True, repr=False)
class DecoratedMappedOperator(MappedOperator):
    """MappedOperator implementation for @task-decorated task function."""

    # Both fields are passed by ``_TaskDecorator._expand`` and read below, so they
    # must be declared on this kw_only attrs class.
    multiple_outputs: bool
    python_callable: Callable

    # We can't save these in expand_input because op_kwargs need to be present
    # in partial_kwargs, and MappedOperator prevents duplication.
    op_kwargs_expand_input: ExpandInput

    def __hash__(self):
        # Identity-based hash, defined explicitly on this attrs class.
        return id(self)

    def __attrs_post_init__(self):
        """Run the parent post-init, then register upstreams referenced by the mapped op_kwargs."""
        # The magic super() doesn't work here, so we use the explicit form.
        # Not using super(..., self) to work around pyupgrade bug.
        super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
        XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)

    def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
        """Resolve the mapped values and nest them under the ``op_kwargs`` key.

        NOTE(review): the parent implementation presumably resolves from
        ``op_kwargs_expand_input`` (selected via ``expand_input_attr``) — confirm.
        """
        # We only use op_kwargs_expand_input so this must always be empty.
        assert self.expand_input is EXPAND_INPUT_EMPTY
        op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session)
        return {"op_kwargs": op_kwargs}, resolved_oids

    def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
        """Merge partial and mapped ``op_kwargs`` into the operator's constructor kwargs.

        Mapped values win over partial ones; with ``strict`` an overlap is rejected
        via ``prevent_duplicates`` instead.
        """
        partial_op_kwargs = self.partial_kwargs["op_kwargs"]
        mapped_op_kwargs = mapped_kwargs["op_kwargs"]

        if strict:
            prevent_duplicates(partial_op_kwargs, mapped_op_kwargs, fail_reason="mapping already partial")

        kwargs = {
            "multiple_outputs": self.multiple_outputs,
            "python_callable": self.python_callable,
            "op_kwargs": {**partial_op_kwargs, **mapped_op_kwargs},
        }
        return super()._get_unmap_kwargs(kwargs, strict=False)
class Task(Protocol, Generic[FParams, FReturn]):
    """Static-typing stand-in for a function decorated with ``@task``.

    Calling an instance mirrors the wrapped function's parameters. Strictly, the
    call produces an ``XComArg`` rather than the function's declared return type,
    which current typing features cannot express. Besides being callable, it
    exposes the task-mapping entry points ``partial``, ``expand`` and
    ``expand_kwargs``, plus ``override``.

    ``_TaskDecorator`` provides the runtime implementation.
    """

    __call__: Callable[FParams, XComArg]

    function: Callable[FParams, FReturn]

    @property
    def __wrapped__(self) -> Callable[FParams, FReturn]:
        """The original, undecorated function."""

    def partial(self, **kwargs: Any) -> Task[FParams, FReturn]:
        """Pin some function arguments ahead of mapping."""

    def expand(self, **kwargs: OperatorExpandArgument) -> XComArg:
        """Map the task over the given per-argument values."""

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
        """Map the task over a sequence of keyword-argument sets."""

    def override(self, **kwargs: Any) -> Task[FParams, FReturn]:
        """Return a variant with extra operator keyword arguments applied."""
class TaskDecorator(Protocol):
    """Static type of the object handed back by ``task_decorator_factory``."""

    @overload
    def __call__(  # type: ignore[misc]
        self,
        python_callable: Callable[FParams, FReturn],
    ) -> Task[FParams, FReturn]:
        """Bare usage: ``@task`` applied straight to a function."""

    @overload
    def __call__(
        self,
        *,
        multiple_outputs: bool | None = None,
        **kwargs: Any,
    ) -> Callable[[Callable[FParams, FReturn]], Task[FParams, FReturn]]:
        """Factory usage: ``@task(...)`` invoked with keyword arguments first."""

    def override(self, **kwargs: Any) -> Task[FParams, FReturn]:
        """Return a variant with the given keyword arguments overridden."""
def task_decorator_factory(
    python_callable: Callable | None = None,
    *,
    multiple_outputs: bool | None = None,
    decorated_operator_class: type[BaseOperator],
    **kwargs,
) -> TaskDecorator:
    """Generate a wrapper that wraps a function into an Airflow operator.

    Can be reused in a single DAG.

    :param python_callable: Function to decorate. When given, the decorator is
        applied directly (``@task`` usage); when omitted, a decorator factory is
        returned (``@task(...)`` usage).
    :param multiple_outputs: If set to True, the decorated function's return
        value will be unrolled to multiple XCom values. Dict will unroll to XCom
        values with its keys as XCom keys. If set to False (default), only at
        most one XCom value is pushed.
    :param decorated_operator_class: The operator that executes the logic needed
        to run the python function in the correct environment.
    :raises TypeError: if ``python_callable`` is given but is not callable, e.g.
        ``@task("my_id")`` instead of ``@task(task_id="my_id")``.

    Other kwargs are directly forwarded to the underlying operator class when
    it's instantiated.
    """
    if multiple_outputs is None:
        # attr.NOTHING makes attrs fall back to the field's default, i.e.
        # _TaskDecorator infers multiple_outputs from the return annotation.
        multiple_outputs = cast(bool, attr.NOTHING)

    if python_callable is not None:
        # Test explicitly instead of by truthiness: a truthy non-callable such as
        # a string must get this error, and a callable that happens to be falsy
        # must still be decorated.
        if not callable(python_callable):
            raise TypeError("No args allowed while using @task, use kwargs instead")
        decorator = _TaskDecorator(
            function=python_callable,
            multiple_outputs=multiple_outputs,
            operator_class=decorated_operator_class,
            kwargs=kwargs,
        )
        return cast(TaskDecorator, decorator)

    def decorator_factory(python_callable):
        return _TaskDecorator(
            function=python_callable,
            multiple_outputs=multiple_outputs,
            operator_class=decorated_operator_class,
            kwargs=kwargs,
        )

    return cast(TaskDecorator, decorator_factory)