Source code for tests.system.providers.amazon.aws.example_sagemaker_endpoint

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
from datetime import datetime

import boto3

from airflow import DAG
from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.providers.amazon.aws.operators.s3 import (
    S3CreateBucketOperator,
    S3CreateObjectOperator,
    S3DeleteBucketOperator,
)
from airflow.providers.amazon.aws.operators.sagemaker import (
    SageMakerDeleteModelOperator,
    SageMakerEndpointConfigOperator,
    SageMakerEndpointOperator,
    SageMakerModelOperator,
    SageMakerTrainingOperator,
)
from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerEndpointSensor
from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder, purge_logs

# Identifier under which this example DAG is registered.
DAG_ID = "example_sagemaker_endpoint"

# Externally fetched variables:
ROLE_ARN_KEY = "ROLE_ARN"

# Callable used inside the DAG to produce the system-test context; the DAG body
# indexes its result with ENV_ID_KEY and ROLE_ARN_KEY.
sys_test_context_task = (
    SystemTestContextBuilder()
    .add_variable(ROLE_ARN_KEY)
    .build()
)
# The URI of a Docker image for handling KNN model training.
# To find the URI of a free Amazon-provided image that can be used, substitute your
# desired region in the following link and find the URI under "Registry Path".
# https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-east-1.html#knn-us-east-1.title
# This URI should be in the format of {12-digits}.dkr.ecr.{region}.amazonaws.com/knn
# (the entries below additionally pin the ":1" image tag).
# To support a new region, add its registry account id to the table below.
_KNN_REGISTRY_ACCOUNT_BY_REGION = {
    "us-east-1": "382416733822",
    "us-west-2": "174872318107",
}
KNN_IMAGES_BY_REGION = {
    region: f"{account}.dkr.ecr.{region}.amazonaws.com/knn:1"
    for region, account in _KNN_REGISTRY_ACCOUNT_BY_REGION.items()
}

# For an example of how to obtain the following train and test data, please see
# https://github.com/apache/airflow/blob/main/airflow/providers/amazon/aws/example_dags/example_sagemaker.py
# One newline-terminated CSV row per line. Training rows have five columns while
# SAMPLE_TEST_DATA has four, so the first column is presumably the class label
# (set_up configures feature_dim=4) -- confirm against the KNN input format docs.
TRAIN_DATA = (
    "0,4.9,2.5,4.5,1.7\n"
    "1,7.0,3.2,4.7,1.4\n"
    "0,7.3,2.9,6.3,1.8\n"
    "2,5.1,3.5,1.4,0.2\n"
)
# Single unlabeled row sent to the deployed endpoint by call_endpoint.
SAMPLE_TEST_DATA = "6.4,3.2,4.5,1.5"
@task
def call_endpoint(endpoint_name):
    """Send SAMPLE_TEST_DATA to the deployed endpoint and return its predictions.

    :param endpoint_name: name of the SageMaker endpoint to invoke.
    :return: the ``"predictions"`` entry of the endpoint's JSON response body.
    """
    runtime_client = boto3.Session().client("sagemaker-runtime")
    reply = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="text/csv",
        Body=SAMPLE_TEST_DATA,
    )
    # The response body is a stream of bytes; decode before parsing as JSON.
    raw_payload = reply["Body"].read().decode()
    parsed = json.loads(raw_payload)
    return parsed["predictions"]
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_endpoint_config(endpoint_config_job_name):
    """Teardown: delete the SageMaker endpoint configuration created by the test.

    Uses the ALL_DONE trigger rule so it runs once upstream tasks finish,
    whatever their final state.

    :param endpoint_config_job_name: name of the endpoint configuration to delete.
    """
    sagemaker_client = boto3.client("sagemaker")
    sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_job_name)
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_endpoint(endpoint_name):
    """Teardown: delete the SageMaker endpoint deployed by the test.

    Uses the ALL_DONE trigger rule so it runs once upstream tasks finish,
    whatever their final state.

    :param endpoint_name: name of the endpoint to delete.
    """
    sagemaker_client = boto3.client("sagemaker")
    sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_logs(env_id, endpoint_name):
    """Teardown: purge CloudWatch logs produced by the training job and the endpoint.

    :param env_id: test-run id; used as the log stream prefix for training job logs.
    :param endpoint_name: endpoint whose dedicated log group is removed.
    """
    training_log_targets = [
        # Format: ('log group name', 'log stream prefix')
        ("/aws/sagemaker/TrainingJobs", env_id),
    ]
    purge_logs(training_log_targets)

    # The endpoint has its own log group; pass no stream prefix and force its deletion.
    endpoint_log_group = f"/aws/sagemaker/Endpoints/{endpoint_name}"
    purge_logs(test_logs=[(endpoint_log_group, None)], force_delete=True)
def _knn_image_uri(region):
    """Return the Amazon-provided KNN training image URI for ``region``.

    :param region: AWS region name from the boto3 session (None if boto3 has no
        region configured, which then fails the lookup below).
    :return: the matching entry of ``KNN_IMAGES_BY_REGION``.
    :raises KeyError: when ``region`` has no entry in ``KNN_IMAGES_BY_REGION``.
    """
    try:
        return KNN_IMAGES_BY_REGION[region]
    except KeyError:
        # The lookup's own KeyError only carries the region name, which this message
        # already includes; suppress the implicit chaining so the task log shows a
        # single clear error instead of "During handling of the above exception...".
        raise KeyError(
            f"Region name {region} does not have a known KNN "
            f"Image URI. Please add the region and URI following "
            f"the directions at the top of the system testfile "
        ) from None


@task
def set_up(env_id, role_arn, ti=None):
    """Derive resource names and SageMaker request configs, and push them to XCom.

    Downstream tasks read each pushed value as ``test_setup["<key>"]``.

    :param env_id: id of this test run; prefixes every resource name.
    :param role_arn: IAM role ARN used by the training job and the model.
    :param ti: task instance supplied by Airflow from the task context; used for
        ``xcom_push``.
    :raises KeyError: when the session's region has no entry in ``KNN_IMAGES_BY_REGION``.
    """
    bucket_name = f"{env_id}-sagemaker"
    input_data_s3_key = f"{env_id}/input-data"
    training_output_s3_key = f"{env_id}/results"
    endpoint_config_job_name = f"{env_id}-endpoint-config"
    endpoint_name = f"{env_id}-endpoint"
    model_name = f"{env_id}-KNN-model"
    training_job_name = f"{env_id}-train"

    region = boto3.session.Session().region_name
    knn_image_uri = _knn_image_uri(region)

    training_config = {
        "TrainingJobName": training_job_name,
        "RoleArn": role_arn,
        "AlgorithmSpecification": {
            "TrainingImage": knn_image_uri,
            "TrainingInputMode": "File",
        },
        "HyperParameters": {
            "predictor_type": "classifier",
            "feature_dim": "4",
            "k": "3",
            # TRAIN_DATA has one newline-terminated row per line, so count("\n") is
            # the row count. NOTE(review): the -1 makes sample_size one less than the
            # number of rows -- confirm this is intended.
            "sample_size": str(TRAIN_DATA.count("\n") - 1),
        },
        "InputDataConfig": [
            {
                "ChannelName": "train",
                "CompressionType": "None",
                "ContentType": "text/csv",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataDistributionType": "FullyReplicated",
                        "S3DataType": "S3Prefix",
                        "S3Uri": f"s3://{bucket_name}/{input_data_s3_key}/train.csv",
                    }
                },
            }
        ],
        "OutputDataConfig": {"S3OutputPath": f"s3://{bucket_name}/{training_output_s3_key}/"},
        "ResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "ml.m5.large",
            "VolumeSizeInGB": 1,
        },
        "StoppingCondition": {"MaxRuntimeInSeconds": 6 * 60},
    }

    model_config = {
        "ModelName": model_name,
        "ExecutionRoleArn": role_arn,
        "PrimaryContainer": {
            "Mode": "SingleModel",
            "Image": knn_image_uri,
            # Points at the artifact of the training job above:
            # <S3OutputPath><TrainingJobName>/output/model.tar.gz
            "ModelDataUrl": f"s3://{bucket_name}/{training_output_s3_key}/{training_job_name}/output/model.tar.gz",  # noqa: E501
        },
    }

    endpoint_config_config = {
        "EndpointConfigName": endpoint_config_job_name,
        "ProductionVariants": [
            {
                "VariantName": f"{env_id}-demo",
                "ModelName": model_name,
                "InstanceType": "ml.t2.medium",
                "InitialInstanceCount": 1,
            },
        ],
    }

    deploy_endpoint_config = {
        "EndpointName": endpoint_name,
        "EndpointConfigName": endpoint_config_job_name,
    }

    # Insertion order is preserved, so values are pushed in the same order as before.
    xcom_values = {
        "bucket_name": bucket_name,
        "input_data_s3_key": input_data_s3_key,
        "model_name": model_name,
        "endpoint_name": endpoint_name,
        "endpoint_config_job_name": endpoint_config_job_name,
        "training_config": training_config,
        "model_config": model_config,
        "endpoint_config_config": endpoint_config_config,
        "deploy_endpoint_config": deploy_endpoint_config,
    }
    for key, value in xcom_values.items():
        ti.xcom_push(key=key, value=value)
# Example/system-test DAG: train a KNN model on SageMaker, deploy it to an endpoint,
# invoke the endpoint once, then tear every created resource down.
with DAG(
    dag_id=DAG_ID,
    schedule="@once",
    start_date=datetime(2021, 1, 1),
    tags=["example"],
    catchup=False,
) as dag:
    test_context = sys_test_context_task()

    # Computes resource names/configs and pushes them to XCom; the subscripted
    # test_setup["..."] values below are read from those pushes.
    test_setup = set_up(
        env_id=test_context[ENV_ID_KEY],
        role_arn=test_context[ROLE_ARN_KEY],
    )

    create_bucket = S3CreateBucketOperator(
        task_id="create_bucket",
        bucket_name=test_setup["bucket_name"],
    )

    # Uploads to the same S3 prefix that training_config's S3Uri reads from.
    upload_data = S3CreateObjectOperator(
        task_id="upload_data",
        s3_bucket=test_setup["bucket_name"],
        s3_key=f'{test_setup["input_data_s3_key"]}/train.csv',
        data=TRAIN_DATA,
    )

    train_model = SageMakerTrainingOperator(
        task_id="train_model",
        config=test_setup["training_config"],
    )

    create_model = SageMakerModelOperator(
        task_id="create_model",
        config=test_setup["model_config"],
    )

    # [START howto_operator_sagemaker_endpoint_config]
    configure_endpoint = SageMakerEndpointConfigOperator(
        task_id="configure_endpoint",
        config=test_setup["endpoint_config_config"],
    )
    # [END howto_operator_sagemaker_endpoint_config]

    # [START howto_operator_sagemaker_endpoint]
    deploy_endpoint = SageMakerEndpointOperator(
        task_id="deploy_endpoint",
        config=test_setup["deploy_endpoint_config"],
    )
    # [END howto_operator_sagemaker_endpoint]

    # SageMakerEndpointOperator waits by default, setting as False to test the Sensor below.
    # (Set as an attribute after construction so the documented snippet above stays minimal.)
    deploy_endpoint.wait_for_completion = False

    # [START howto_sensor_sagemaker_endpoint]
    await_endpoint = SageMakerEndpointSensor(
        task_id="await_endpoint",
        endpoint_name=test_setup["endpoint_name"],
    )
    # [END howto_sensor_sagemaker_endpoint]

    # Teardown operators use TriggerRule.ALL_DONE so cleanup still runs when the
    # test body fails.
    delete_model = SageMakerDeleteModelOperator(
        task_id="delete_model",
        trigger_rule=TriggerRule.ALL_DONE,
        config={"ModelName": test_setup["model_name"]},
    )

    delete_bucket = S3DeleteBucketOperator(
        task_id="delete_bucket",
        trigger_rule=TriggerRule.ALL_DONE,
        bucket_name=test_setup["bucket_name"],
        force_delete=True,
    )

    # chain() makes each entry downstream of the previous one, so the tasks run
    # strictly in the order listed.
    chain(
        # TEST SETUP
        test_context,
        test_setup,
        create_bucket,
        upload_data,
        # TEST BODY
        train_model,
        create_model,
        configure_endpoint,
        deploy_endpoint,
        await_endpoint,
        call_endpoint(test_setup["endpoint_name"]),
        # TEST TEARDOWN
        delete_endpoint_config(test_setup["endpoint_config_job_name"]),
        delete_endpoint(test_setup["endpoint_name"]),
        delete_model,
        delete_bucket,
        delete_logs(test_context[ENV_ID_KEY], test_setup["endpoint_name"]),
    )

    from tests.system.utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()


from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)

Was this entry helpful?