# Source code for tests.system.providers.amazon.aws.example_sagemaker

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import base64
import json
import logging
import subprocess
from datetime import datetime
from tempfile import NamedTemporaryFile

import boto3

from airflow import DAG
from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.operators.python import get_current_context
from airflow.providers.amazon.aws.operators.s3 import (
    S3CreateBucketOperator,
    S3CreateObjectOperator,
    S3DeleteBucketOperator,
)
from airflow.providers.amazon.aws.operators.sagemaker import (
    SageMakerAutoMLOperator,
    SageMakerCreateExperimentOperator,
    SageMakerDeleteModelOperator,
    SageMakerModelOperator,
    SageMakerProcessingOperator,
    SageMakerRegisterModelVersionOperator,
    SageMakerStartPipelineOperator,
    SageMakerStopPipelineOperator,
    SageMakerTrainingOperator,
    SageMakerTransformOperator,
    SageMakerTuningOperator,
)
from airflow.providers.amazon.aws.sensors.sagemaker import (
    SageMakerAutoMLSensor,
    SageMakerPipelineSensor,
    SageMakerTrainingSensor,
    SageMakerTransformSensor,
    SageMakerTuningSensor,
)
from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder, purge_logs

# Identifier of the DAG defined at the bottom of this module.
DAG_ID = "example_sagemaker"

# Externally fetched variables:
# Name of the variable holding the IAM role ARN that is handed to the SageMaker jobs.
ROLE_ARN_KEY = "ROLE_ARN"

# Task factory that resolves the system-test context, including the ROLE_ARN variable.
sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
# The URI of a Docker image for handling KNN model training.
# To find the URI of a free Amazon-provided image that can be used, substitute your
# desired region in the following link and find the URI under "Registry Path".
# https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-east-1.html#knn-us-east-1.title
# This URI should be in the format of {12-digits}.dkr.ecr.{region}.amazonaws.com/knn
KNN_IMAGES_BY_REGION = {
    "us-east-1": "382416733822.dkr.ecr.us-east-1.amazonaws.com/knn:1",
    "us-west-2": "174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1",
}

# Number of rows in the generated dataset; also used as a hyperparameter bound below.
SAMPLE_SIZE = 600

# This script will be the entrypoint for the docker image which will handle preprocessing the raw data
# NOTE: The following string must remain dedented as it is being written to a file.
# It is rendered with str.format(input_path=..., output_path=...), so those two
# placeholders must be the only braces in the text.
PREPROCESS_SCRIPT_TEMPLATE = """
import boto3
import numpy as np
import pandas as pd

def main():
    # Load the dataset from {input_path}/input.csv, split it into train/test
    # subsets, and write them to {output_path}/ for the Processing Operator.

    data = pd.read_csv('{input_path}/input.csv')

    # Split into test and train data
    data_train, data_test = np.split(
        data.sample(frac=1, random_state=np.random.RandomState()), [int(0.7 * len(data))]
    )

    # Remove the "answers" from the test set
    data_test.drop(['class'], axis=1, inplace=True)

    # Write the splits to disk
    data_train.to_csv('{output_path}/train.csv', index=False, header=False)
    data_test.to_csv('{output_path}/test.csv', index=False, header=False)

    print('Preprocessing Done.')

if __name__ == "__main__":
    main()
"""
def _create_ecr_repository(repo_name):
    """
    Create an ECR repository and grant the calling identity full access to it.

    :param repo_name: Name of the ECR repository to create.
    :return: The URI of the newly created repository.
    """
    # The identity running the test is the principal allowed to use the repository.
    execution_role_arn = boto3.client("sts").get_caller_identity()["Arn"]
    access_policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Sid": "Allow access to the system test execution role",
                "Effect": "Allow",
                "Principal": {"AWS": execution_role_arn},
                "Action": "ecr:*",
            }
        ],
    }

    client = boto3.client("ecr")
    repo = client.create_repository(repositoryName=repo_name)["repository"]
    client.set_repository_policy(repositoryName=repo["repositoryName"], policyText=json.dumps(access_policy))

    return repo["repositoryUri"]


def _build_and_upload_docker_image(preprocess_script, repository_uri):
    """
    Build the preprocessing Docker image and push it to the given ECR repository.

    We need a Docker image with the following requirements:
      - Has numpy, pandas, requests, and boto3 installed
      - Has our data preprocessing script mounted and set as the entry point

    :param preprocess_script: Source code of the script to install as /preprocessing.py.
    :param repository_uri: URI of the target ECR repository, in the form
        {account}.dkr.ecr.{region}.amazonaws.com/{name}.
    :raises RuntimeError: If the shell pipeline that builds and pushes the image fails.
    """
    # Directory used as the docker build context. The temporary files are created
    # here explicitly so the Dockerfile's relative COPY resolves no matter what the
    # platform's default temporary directory is.
    build_context = "/tmp"

    # The region is the fourth dot-separated component of the repository URI.
    ecr_region = repository_uri.split(".")[3]

    # Fetch and parse ECR Token to be used for the docker push
    token = boto3.client("ecr", region_name=ecr_region).get_authorization_token()
    credentials = (base64.b64decode(token["authorizationData"][0]["authorizationToken"])).decode("utf-8")
    # Split on the first colon only, so a colon inside the password cannot break unpacking.
    username, password = credentials.split(":", 1)

    with NamedTemporaryFile(mode="w+t", dir=build_context) as preprocessing_script, NamedTemporaryFile(
        mode="w+t", dir=build_context
    ) as dockerfile:
        preprocessing_script.write(preprocess_script)
        preprocessing_script.flush()

        # COPY paths are relative to the build context, so only the file name is used.
        script_file_name = preprocessing_script.name.rsplit("/", 1)[-1]

        dockerfile.write(
            f"""
            FROM amazonlinux
            COPY {script_file_name} /preprocessing.py
            ADD credentials /credentials
            ENV AWS_SHARED_CREDENTIALS_FILE=/credentials
            RUN yum install python3 pip -y
            RUN pip3 install boto3 pandas requests
            CMD [ "python3", "/preprocessing.py"]
            """
        )
        dockerfile.flush()

        # The AWS credentials file is copied into the build context only for the
        # duration of the build, because the Dockerfile ADDs it into the image.
        # NOTE(review): `docker login` receives the password via --password, so the
        # piped output of `aws ecr get-login-password` does not appear to be consumed;
        # consider switching to --password-stdin.
        docker_build_and_push_commands = f"""
            cp /root/.aws/credentials {build_context}/credentials &&
            docker build --platform=linux/amd64 -f {dockerfile.name} -t {repository_uri} {build_context} &&
            rm {build_context}/credentials &&
            aws ecr get-login-password --region {ecr_region} |
            docker login --username {username} --password {password} {repository_uri} &&
            docker push {repository_uri}
            """
        docker_build = subprocess.run(
            docker_build_and_push_commands,
            shell=True,
            capture_output=True,
        )

        if docker_build.returncode != 0:
            # Decode so the message shows readable text instead of a bytes repr.
            stderr = docker_build.stderr.decode("utf-8", errors="replace")
            raise RuntimeError(
                "Failed to push docker image to the repository.  The following error "
                f"message may be useful, but can occasionally be misleading: {stderr}"
            )
def generate_data() -> str:
    """
    Generate a very simple CSV dataset with headers.

    One row is produced for every ``i`` in ``range(SAMPLE_SIZE)``:
    ``class`` is ``i % 100``, ``x`` is ``i`` and ``y`` is ``SAMPLE_SIZE - i``.

    :return: The CSV content, every line (header included) terminated by a newline.
    """
    header = "class,x,y\n"
    # Join once instead of growing the string with += on every iteration.
    rows = (f"{i % 100},{i},{SAMPLE_SIZE - i}\n" for i in range(SAMPLE_SIZE))
    return header + "".join(rows)
def _build_dummy_pipeline_definition(nested_steps: int = 30) -> str:
    """
    Build the JSON definition of a dummy SageMaker pipeline.

    The pipeline is a chain of nested "Condition" steps, each checking that 3 <= 6:
    an innermost step named "DummyCond" wrapped by "DummyCond0" up to
    "DummyCond{nested_steps - 1}". Each step takes roughly 1 second to execute, so
    the default runtime is ~30 seconds, which should be enough to test stopping and
    awaiting without race conditions.

    The layout mirrors what the sagemaker sdk produces via pipeline.definition().

    :param nested_steps: Number of numbered steps wrapped around the innermost one.
    :return: The pipeline definition serialized as a JSON string.
    """
    always_true = {"Type": "LessThanOrEqualTo", "LeftValue": 3.0, "RightValue": 6.0}

    def _condition_step(name, if_steps):
        return {
            "Name": name,
            "Type": "Condition",
            "Arguments": {"Conditions": [always_true], "IfSteps": if_steps, "ElseSteps": []},
        }

    step = _condition_step("DummyCond", [])
    for index in range(nested_steps):
        step = _condition_step(f"DummyCond{index}", [step])

    return json.dumps(
        {
            "Version": "2020-12-01",
            "Metadata": {},
            "Parameters": [],
            "PipelineExperimentConfig": {
                "ExperimentName": {"Get": "Execution.PipelineName"},
                "TrialName": {"Get": "Execution.PipelineExecutionId"},
            },
            "Steps": [step],
        }
    )


@task
def set_up(env_id, role_arn):
    """
    Create the resources and configurations used by the rest of the DAG.

    Creates an ECR repository and a dummy SageMaker pipeline, builds and pushes the
    preprocessing docker image, then publishes every name and job configuration
    as an XCom so that downstream tasks can look them up by key.

    :param env_id: Unique identifier of the test run; used to prefix resource names.
    :param role_arn: ARN of the IAM role passed to the pipeline and the SageMaker jobs.
    :raises KeyError: If no KNN image URI is known for the current region.
    """
    bucket_name = f"{env_id}-sagemaker-example"
    ecr_repository_name = f"{env_id}-repo"
    model_name = f"{env_id}-KNN-model"
    processing_job_name = f"{env_id}-processing"
    training_job_name = f"{env_id}-train"
    transform_job_name = f"{env_id}-transform"
    tuning_job_name = f"{env_id}-tune"
    model_package_group_name = f"{env_id}-group"
    pipeline_name = f"{env_id}-pipe"
    auto_ml_job_name = f"{env_id}-automl"
    experiment_name = f"{env_id}-experiment"

    input_data_S3_key = f"{env_id}/processed-input-data"
    prediction_output_s3_key = f"{env_id}/transform"
    processing_local_input_path = "/opt/ml/processing/input"
    processing_local_output_path = "/opt/ml/processing/output"
    raw_data_s3_key = f"{env_id}/preprocessing/input.csv"
    training_output_s3_key = f"{env_id}/results"

    # Validate the region before creating anything, so that an unsupported region
    # does not leave behind an ECR repository whose name was never published.
    region = boto3.session.Session().region_name
    try:
        knn_image_uri = KNN_IMAGES_BY_REGION[region]
    except KeyError:
        raise KeyError(
            f"Region name {region} does not have a known KNN "
            f"Image URI.  Please add the region and URI following "
            f"the directions at the top of the system testfile "
        ) from None

    ecr_repository_uri = _create_ecr_repository(ecr_repository_name)

    sgmk_client = boto3.client("sagemaker")
    sgmk_client.create_pipeline(
        PipelineName=pipeline_name,
        PipelineDefinition=_build_dummy_pipeline_definition(),
        RoleArn=role_arn,
    )

    # Shared by the processing, training and tuning jobs.
    resource_config = {
        "InstanceCount": 1,
        "InstanceType": "ml.m5.large",
        "VolumeSizeInGB": 1,
    }
    input_data_uri = f"s3://{bucket_name}/{raw_data_s3_key}"
    processing_config = {
        "ProcessingJobName": processing_job_name,
        "ProcessingInputs": [
            {
                "InputName": "input",
                "AppManaged": False,
                "S3Input": {
                    "S3Uri": input_data_uri,
                    "LocalPath": processing_local_input_path,
                    "S3DataType": "S3Prefix",
                    "S3InputMode": "File",
                    "S3DataDistributionType": "FullyReplicated",
                    "S3CompressionType": "None",
                },
            },
        ],
        "ProcessingOutputConfig": {
            "Outputs": [
                {
                    "OutputName": "output",
                    "S3Output": {
                        "S3Uri": f"s3://{bucket_name}/{input_data_S3_key}",
                        "LocalPath": processing_local_output_path,
                        "S3UploadMode": "EndOfJob",
                    },
                    "AppManaged": False,
                }
            ]
        },
        "ProcessingResources": {
            "ClusterConfig": resource_config,
        },
        "StoppingCondition": {"MaxRuntimeInSeconds": 60},
        "AppSpecification": {
            "ImageUri": ecr_repository_uri,
        },
        "RoleArn": role_arn,
    }

    # The training split written by the processing job feeds both training and tuning.
    training_data_source = {
        "CompressionType": "None",
        "ContentType": "text/csv",
        "DataSource": {
            "S3DataSource": {
                "S3DataDistributionType": "FullyReplicated",
                "S3DataType": "S3Prefix",
                "S3Uri": f"s3://{bucket_name}/{input_data_S3_key}/train.csv",
            }
        },
    }
    training_config = {
        "AlgorithmSpecification": {
            "TrainingImage": knn_image_uri,
            "TrainingInputMode": "File",
        },
        "HyperParameters": {
            "predictor_type": "classifier",
            "feature_dim": "2",
            "k": "3",
            "sample_size": str(SAMPLE_SIZE),
        },
        "InputDataConfig": [
            {
                "ChannelName": "train",
                **training_data_source,
            }
        ],
        "OutputDataConfig": {"S3OutputPath": f"s3://{bucket_name}/{training_output_s3_key}/"},
        "ExperimentConfig": {"ExperimentName": experiment_name},
        "ResourceConfig": resource_config,
        "RoleArn": role_arn,
        "StoppingCondition": {"MaxRuntimeInSeconds": 60},
        "TrainingJobName": training_job_name,
    }
    model_trained_weights = (
        f"s3://{bucket_name}/{training_output_s3_key}/{training_job_name}/output/model.tar.gz"
    )
    model_config = {
        "ExecutionRoleArn": role_arn,
        "ModelName": model_name,
        "PrimaryContainer": {
            "Mode": "SingleModel",
            "Image": knn_image_uri,
            "ModelDataUrl": model_trained_weights,
        },
    }
    tuning_config = {
        "HyperParameterTuningJobName": tuning_job_name,
        "HyperParameterTuningJobConfig": {
            "Strategy": "Bayesian",
            "HyperParameterTuningJobObjective": {
                "MetricName": "test:accuracy",
                "Type": "Maximize",
            },
            "ResourceLimits": {
                "MaxNumberOfTrainingJobs": 10,
                "MaxParallelTrainingJobs": 10,
            },
            "ParameterRanges": {
                "CategoricalParameterRanges": [],
                "IntegerParameterRanges": [
                    # Set the min and max values of the hyperparameters you want to tune.
                    {
                        "Name": "k",
                        "MinValue": "1",
                        "MaxValue": str(SAMPLE_SIZE),
                    },
                    {
                        "Name": "sample_size",
                        "MinValue": "1",
                        "MaxValue": str(SAMPLE_SIZE),
                    },
                ],
            },
        },
        "TrainingJobDefinition": {
            "StaticHyperParameters": {
                "predictor_type": "classifier",
                "feature_dim": "2",
            },
            "AlgorithmSpecification": {"TrainingImage": knn_image_uri, "TrainingInputMode": "File"},
            "InputDataConfig": [
                {
                    "ChannelName": "train",
                    **training_data_source,
                },
                {
                    "ChannelName": "test",
                    **training_data_source,
                },
            ],
            "OutputDataConfig": {"S3OutputPath": f"s3://{bucket_name}/{training_output_s3_key}"},
            "ResourceConfig": resource_config,
            "RoleArn": role_arn,
            "StoppingCondition": {"MaxRuntimeInSeconds": 60},
        },
    }
    transform_config = {
        "TransformJobName": transform_job_name,
        "TransformInput": {
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": f"s3://{bucket_name}/{input_data_S3_key}/test.csv",
                }
            },
            "SplitType": "Line",
            "ContentType": "text/csv",
        },
        "TransformOutput": {"S3OutputPath": f"s3://{bucket_name}/{prediction_output_s3_key}"},
        "TransformResources": {
            "InstanceCount": 1,
            "InstanceType": "ml.m5.large",
        },
        "ModelName": model_name,
    }

    preprocess_script = PREPROCESS_SCRIPT_TEMPLATE.format(
        input_path=processing_local_input_path, output_path=processing_local_output_path
    )
    _build_and_upload_docker_image(preprocess_script, ecr_repository_uri)

    # Values published for the downstream tasks, which read them by key.
    xcom_values = {
        "docker_image": ecr_repository_uri,
        "bucket_name": bucket_name,
        "raw_data_s3_key": raw_data_s3_key,
        "ecr_repository_name": ecr_repository_name,
        "processing_config": processing_config,
        "input_data_uri": input_data_uri,
        "output_data_uri": f"s3://{bucket_name}/{training_output_s3_key}",
        "training_config": training_config,
        "training_job_name": training_job_name,
        "model_package_group_name": model_package_group_name,
        "pipeline_name": pipeline_name,
        "auto_ml_job_name": auto_ml_job_name,
        "experiment_name": experiment_name,
        "model_config": model_config,
        "model_name": model_name,
        "inference_code_image": knn_image_uri,
        "model_trained_weights": model_trained_weights,
        "tuning_config": tuning_config,
        "tuning_job_name": tuning_job_name,
        "transform_config": transform_config,
        "transform_job_name": transform_job_name,
    }
    ti = get_current_context()["ti"]
    for key, value in xcom_values.items():
        ti.xcom_push(key=key, value=value)
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_ecr_repository(repository_name):
    """
    Delete an ECR repository along with every image it contains.

    :param repository_name: Name of the ECR repository to delete.
    """
    client = boto3.client("ecr")

    # All images must be removed from the repo before it can be deleted.
    image_ids = client.list_images(repositoryName=repository_name)["imageIds"]
    # One entry per distinct digest: several tags may point at the same image.
    digests = list(dict.fromkeys(image["imageDigest"] for image in image_ids))
    if digests:
        # Skipped for an empty repository, so that no empty id list is sent.
        client.batch_delete_image(
            repositoryName=repository_name,
            imageIds=[{"imageDigest": digest} for digest in digests],
        )
    client.delete_repository(repositoryName=repository_name)
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_logs(env_id):
    """
    Purge the log streams generated by the SageMaker jobs of this test run.

    :param env_id: Identifier of the test run, used as the log stream prefix.
    """
    log_group_names = (
        "/aws/sagemaker/ProcessingJobs",
        "/aws/sagemaker/TrainingJobs",
        "/aws/sagemaker/TransformJobs",
    )
    # Format: ('log group name', 'log stream prefix')
    purge_logs([(log_group_name, env_id) for log_group_name in log_group_names])
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_model_group(group_name, model_version_arn):
    """
    Delete a registered model version, then the model package group holding it.

    :param group_name: Name of the model package group to delete.
    :param model_version_arn: Identifier of the model package registered in the group.
    """
    sagemaker = boto3.client("sagemaker")

    # need to destroy model registered in group first
    sagemaker.delete_model_package(ModelPackageName=model_version_arn)

    sagemaker.delete_model_package_group(ModelPackageGroupName=group_name)
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_pipeline(pipeline_name):
    """
    Delete the SageMaker pipeline with the given name.

    :param pipeline_name: Name of the pipeline to delete.
    """
    boto3.client("sagemaker").delete_pipeline(PipelineName=pipeline_name)
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_experiment(name):
    """
    Delete a SageMaker experiment together with its trials and trial components.

    Each trial component is disassociated from its trial and deleted, then the
    trial itself is deleted, and finally the experiment.

    :param name: Name of the experiment to delete.
    """
    sagemaker = boto3.client("sagemaker")

    trial_summaries = sagemaker.list_trials(ExperimentName=name)["TrialSummaries"]
    for trial_name in [summary["TrialName"] for summary in trial_summaries]:
        component_summaries = sagemaker.list_trial_components(TrialName=trial_name)[
            "TrialComponentSummaries"
        ]
        for component_name in [summary["TrialComponentName"] for summary in component_summaries]:
            # A component has to be detached from the trial before it is deleted.
            sagemaker.disassociate_trial_component(TrialComponentName=component_name, TrialName=trial_name)
            sagemaker.delete_trial_component(TrialComponentName=component_name)
        sagemaker.delete_trial(TrialName=trial_name)

    sagemaker.delete_experiment(ExperimentName=name)
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_docker_image(image_name):
    """
    Remove the local docker image built during set up.

    This is best effort: a failure is logged instead of raised.

    :param image_name: Name of the local docker image to remove.
    """
    docker_rmi = subprocess.run(
        f"docker rmi {image_name}",
        shell=True,
        capture_output=True,
    )
    if docker_rmi.returncode != 0:
        # Decode so the log shows readable text instead of a bytes repr.
        stderr = docker_rmi.stderr.decode("utf-8", errors="replace")
        logging.error(
            "Failed to delete local docker image. "
            "Run 'docker images' to see if you need to clean it yourself.\n"
            f"error message: {stderr}"
        )


with DAG(
    dag_id=DAG_ID,
    schedule="@once",
    start_date=datetime(2021, 1, 1),
    tags=["example"],
    catchup=False,
) as dag:
    test_context = sys_test_context_task()

    test_setup = set_up(
        env_id=test_context[ENV_ID_KEY],
        role_arn=test_context[ROLE_ARN_KEY],
    )

    create_bucket = S3CreateBucketOperator(
        task_id="create_bucket",
        bucket_name=test_setup["bucket_name"],
    )

    upload_dataset = S3CreateObjectOperator(
        task_id="upload_dataset",
        s3_bucket=test_setup["bucket_name"],
        s3_key=test_setup["raw_data_s3_key"],
        data=generate_data(),
        replace=True,
    )

    # [START howto_operator_sagemaker_auto_ml]
    automl = SageMakerAutoMLOperator(
        task_id="auto_ML",
        job_name=test_setup["auto_ml_job_name"],
        s3_input=test_setup["input_data_uri"],
        target_attribute="class",
        s3_output=test_setup["output_data_uri"],
        role_arn=test_context[ROLE_ARN_KEY],
        time_limit=30,  # will stop the job before it can do anything, but it's not the point here
    )
    # [END howto_operator_sagemaker_auto_ml]
    automl.wait_for_completion = False  # just to be able to test the sensor next

    # [START howto_sensor_sagemaker_auto_ml]
    await_automl = SageMakerAutoMLSensor(job_name=test_setup["auto_ml_job_name"], task_id="await_auto_ML")
    # [END howto_sensor_sagemaker_auto_ml]

    # [START howto_operator_sagemaker_start_pipeline]
    start_pipeline1 = SageMakerStartPipelineOperator(
        task_id="start_pipeline1",
        pipeline_name=test_setup["pipeline_name"],
    )
    # [END howto_operator_sagemaker_start_pipeline]

    # [START howto_operator_sagemaker_stop_pipeline]
    stop_pipeline1 = SageMakerStopPipelineOperator(
        task_id="stop_pipeline1",
        pipeline_exec_arn=start_pipeline1.output,
    )
    # [END howto_operator_sagemaker_stop_pipeline]

    start_pipeline2 = SageMakerStartPipelineOperator(
        task_id="start_pipeline2",
        pipeline_name=test_setup["pipeline_name"],
    )

    # [START howto_sensor_sagemaker_pipeline]
    await_pipeline2 = SageMakerPipelineSensor(
        task_id="await_pipeline2",
        pipeline_exec_arn=start_pipeline2.output,
    )
    # [END howto_sensor_sagemaker_pipeline]

    # [START howto_operator_sagemaker_experiment]
    create_experiment = SageMakerCreateExperimentOperator(
        task_id="create_experiment", name=test_setup["experiment_name"]
    )
    # [END howto_operator_sagemaker_experiment]

    # [START howto_operator_sagemaker_processing]
    preprocess_raw_data = SageMakerProcessingOperator(
        task_id="preprocess_raw_data",
        config=test_setup["processing_config"],
    )
    # [END howto_operator_sagemaker_processing]

    # [START howto_operator_sagemaker_training]
    train_model = SageMakerTrainingOperator(
        task_id="train_model",
        config=test_setup["training_config"],
    )
    # [END howto_operator_sagemaker_training]

    # SageMakerTrainingOperator waits by default, setting as False to test the Sensor below.
    train_model.wait_for_completion = False

    # [START howto_sensor_sagemaker_training]
    await_training = SageMakerTrainingSensor(
        task_id="await_training",
        job_name=test_setup["training_job_name"],
    )
    # [END howto_sensor_sagemaker_training]

    # [START howto_operator_sagemaker_model]
    create_model = SageMakerModelOperator(
        task_id="create_model",
        config=test_setup["model_config"],
    )
    # [END howto_operator_sagemaker_model]

    # [START howto_operator_sagemaker_register]
    register_model = SageMakerRegisterModelVersionOperator(
        task_id="register_model",
        image_uri=test_setup["inference_code_image"],
        model_url=test_setup["model_trained_weights"],
        package_group_name=test_setup["model_package_group_name"],
    )
    # [END howto_operator_sagemaker_register]

    # [START howto_operator_sagemaker_tuning]
    tune_model = SageMakerTuningOperator(
        task_id="tune_model",
        config=test_setup["tuning_config"],
    )
    # [END howto_operator_sagemaker_tuning]

    # SageMakerTuningOperator waits by default, setting as False to test the Sensor below.
    tune_model.wait_for_completion = False

    # [START howto_sensor_sagemaker_tuning]
    await_tuning = SageMakerTuningSensor(
        task_id="await_tuning",
        job_name=test_setup["tuning_job_name"],
    )
    # [END howto_sensor_sagemaker_tuning]

    # [START howto_operator_sagemaker_transform]
    test_model = SageMakerTransformOperator(
        task_id="test_model",
        config=test_setup["transform_config"],
    )
    # [END howto_operator_sagemaker_transform]

    # SageMakerTransformOperator waits by default, setting as False to test the Sensor below.
    test_model.wait_for_completion = False

    # [START howto_sensor_sagemaker_transform]
    await_transform = SageMakerTransformSensor(
        task_id="await_transform",
        job_name=test_setup["transform_job_name"],
    )
    # [END howto_sensor_sagemaker_transform]

    # [START howto_operator_sagemaker_delete_model]
    delete_model = SageMakerDeleteModelOperator(
        task_id="delete_model",
        config={"ModelName": test_setup["model_name"]},
    )
    # [END howto_operator_sagemaker_delete_model]
    # Teardown must run even when an earlier task failed.
    delete_model.trigger_rule = TriggerRule.ALL_DONE

    delete_bucket = S3DeleteBucketOperator(
        task_id="delete_bucket",
        trigger_rule=TriggerRule.ALL_DONE,
        bucket_name=test_setup["bucket_name"],
        force_delete=True,
    )

    chain(
        # TEST SETUP
        test_context,
        test_setup,
        create_bucket,
        upload_dataset,
        # TEST BODY
        automl,
        await_automl,
        start_pipeline1,
        start_pipeline2,
        stop_pipeline1,
        await_pipeline2,
        create_experiment,
        preprocess_raw_data,
        train_model,
        await_training,
        create_model,
        register_model,
        tune_model,
        await_tuning,
        test_model,
        await_transform,
        # TEST TEARDOWN
        delete_ecr_repository(test_setup["ecr_repository_name"]),
        delete_model_group(test_setup["model_package_group_name"], register_model.output),
        delete_model,
        delete_bucket,
        delete_experiment(test_setup["experiment_name"]),
        delete_pipeline(test_setup["pipeline_name"]),
        delete_logs(test_context[ENV_ID_KEY]),
        delete_docker_image(test_setup["docker_image"]),
    )

    from tests.system.utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()


from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)

# Was this entry helpful?