Source code for airflow.providers.amazon.aws.auth_manager.router.login
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from typing import Any
import anyio
from fastapi import HTTPException, Request
from starlette import status
from starlette.responses import RedirectResponse
from airflow.api_fastapi.app import get_auth_manager
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.configuration import conf
from airflow.providers.amazon.aws.auth_manager.constants import CONF_SAML_METADATA_URL_KEY, CONF_SECTION_NAME
from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
try:
from onelogin.saml2.auth import OneLogin_Saml2_Auth
from onelogin.saml2.errors import OneLogin_Saml2_Error
from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser
except ImportError:
raise ImportError(
"AWS auth manager requires the python3-saml library but it is not installed by default. "
"Please install the python3-saml library by running: "
"pip install apache-airflow-providers-amazon[python3-saml]"
)
[docs]
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
[docs]
# Router on which the "/login" and "/login_callback" endpoints below are registered.
login_router = AirflowRouter(tags=["AWSAuthManagerLogin"])
@login_router.get("/login")
def login(request: Request):
    """Authenticate the user."""
    # NOTE(review): docstring kept verbatim -- FastAPI exposes endpoint docstrings
    # as the route description, so rewording it would change the published schema.
    #
    # Build the SAML client for this request, ask it for the login URL and send
    # the caller there.
    redirect_target = _init_saml_auth(request).login()
    return RedirectResponse(url=redirect_target)
@login_router.post("/login_callback")
def login_callback(request: Request):
    """Authenticate the user."""
    # NOTE(review): docstring kept verbatim -- FastAPI exposes endpoint docstrings
    # as the route description.
    #
    # Flow: process the SAML response posted to this endpoint, build an
    # AwsAuthManagerUser from the asserted attributes, then redirect to the web
    # app with a JWT token for that user in the query string.
    #
    # Raises HTTPException (500) when the SAML response cannot be processed or
    # the user is not authenticated. The "id" and "groups" attributes are
    # mandatory: a missing one propagates as KeyError/IndexError.
    saml_auth = _init_saml_auth(request)
    try:
        saml_auth.process_response()
    except OneLogin_Saml2_Error as e:
        # log.exception already records the active exception and its traceback;
        # give it a descriptive message instead of the exception object.
        log.exception("Failed to process the SAML response")
        # Chain the cause so the original SAML error is not lost in tracebacks.
        raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to authenticate") from e

    errors = saml_auth.get_errors()
    is_authenticated = saml_auth.is_authenticated()
    if not is_authenticated:
        error_reason = saml_auth.get_last_error_reason()
        log.error("Failed to authenticate")
        log.error("Errors: %s", errors)
        log.error("Error reason: %s", error_reason)
        raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, f"Failed to authenticate: {error_reason}")

    attributes = saml_auth.get_attributes()
    # Email is optional: treat both a missing attribute and an attribute with an
    # empty value list as "no email" rather than failing with IndexError.
    email_values = attributes.get("email")
    user = AwsAuthManagerUser(
        user_id=attributes["id"][0],
        groups=attributes["groups"],
        username=saml_auth.get_nameid(),
        email=email_values[0] if email_values else None,
    )
    # 303 turns the POST callback into a GET on the redirect target.
    return RedirectResponse(url=f"/webapp?token={get_auth_manager().get_jwt_token(user)}", status_code=303)
def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
    """
    Build the SAML client used to handle ``request``.

    The service-provider settings defined here are merged with the identity-provider
    data returned by ``_get_idp_data()`` and passed, together with the request data
    produced by ``_prepare_request()``, to ``OneLogin_Saml2_Auth``.
    """
    saml_request = _prepare_request(request)
    base_url = conf.get(section="fastapi", key="base_url")

    # Where the identity provider posts the SAML response back to.
    consumer_service = {
        "url": f"{base_url}/auth/login_callback",
        "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
    }
    sp_settings = {
        # Keep debug enabled on purpose: per the original author's note, error
        # reasons are only provided while this flag is on.
        "debug": True,
        "sp": {
            "entityId": "aws-auth-manager-saml-client",
            "assertionConsumerService": consumer_service,
        },
    }

    idp_settings = _get_idp_data()
    return OneLogin_Saml2_Auth(
        saml_request,
        OneLogin_Saml2_IdPMetadataParser.merge_settings(idp_settings, sp_settings),
    )
def _prepare_request(request: Request) -> dict:
    """
    Convert ``request`` into the plain dict handed to ``OneLogin_Saml2_Auth``.

    The dict carries the scheme flag, host, port, path and query parameters of the
    request. ``post_data`` holds the ``SAMLResponse`` and ``RelayState`` form fields
    when the request form contains them, and is empty otherwise.
    """
    # Prefer the Host header; otherwise use the client address, or "localhost"
    # when the request carries no client information.
    fallback_host = request.client.host if request.client else "localhost"
    url = request.url
    data: dict[str, Any] = {
        "https": "off" if url.scheme != "https" else "on",
        "http_host": request.headers.get("host", fallback_host),
        "server_port": url.port,
        "script_name": url.path,
        "get_data": request.query_params,
        "post_data": {},
    }

    # request.form is a coroutine function; anyio.from_thread.run executes it
    # from this synchronous code and returns its result.
    form_data = anyio.from_thread.run(request.form)
    for field in ("SAMLResponse", "RelayState"):
        if field in form_data:
            data["post_data"][field] = form_data[field]
    return data
def _get_idp_data() -> dict:
    """
    Return the identity-provider data parsed from the configured SAML metadata URL.

    The URL is a mandatory configuration value, read via ``conf.get_mandatory_value``.
    """
    return OneLogin_Saml2_IdPMetadataParser.parse_remote(
        conf.get_mandatory_value(CONF_SECTION_NAME, CONF_SAML_METADATA_URL_KEY)
    )