airflow example_mlengine 源码

  • 2022-10-20
  • 浏览 (239)

airflow example_mlengine 代码

文件路径:/airflow/providers/google/cloud/example_dags/example_mlengine.py

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example Airflow DAG for Google ML Engine service.
"""
from __future__ import annotations

import os
from datetime import datetime
from typing import Any

from airflow import models
from airflow.operators.bash import BashOperator
from airflow.providers.google.cloud.operators.mlengine import (
    MLEngineCreateModelOperator,
    MLEngineCreateVersionOperator,
    MLEngineDeleteModelOperator,
    MLEngineDeleteVersionOperator,
    MLEngineGetModelOperator,
    MLEngineListVersionsOperator,
    MLEngineSetDefaultVersionOperator,
    MLEngineStartBatchPredictionJobOperator,
    MLEngineStartTrainingJobOperator,
)
from airflow.providers.google.cloud.utils import mlengine_operator_utils

# Every setting below can be overridden through the named environment variable;
# the second argument is the fallback used when the variable is not set.

# Passed as `project_id` to every ML Engine operator in this DAG.
PROJECT_ID = os.getenv("GCP_PROJECT_ID", "example-project")

# Model that the DAG creates, versions, queries and finally deletes.
MODEL_NAME = os.getenv("GCP_MLENGINE_MODEL_NAME", "model_name")

# Cloud Storage locations. The "INVALID BUCKET NAME" fallbacks are placeholders
# and have to be replaced with real bucket paths via the environment.
SAVED_MODEL_PATH = os.getenv("GCP_MLENGINE_SAVED_MODEL_PATH", "gs://INVALID BUCKET NAME/saved-model/")
JOB_DIR = os.getenv("GCP_MLENGINE_JOB_DIR", "gs://INVALID BUCKET NAME/keras-job-dir")
PREDICTION_INPUT = os.getenv(
    "GCP_MLENGINE_PREDICTION_INPUT",
    "gs://INVALID BUCKET NAME/prediction_input.json",
)
PREDICTION_OUTPUT = os.getenv(
    "GCP_MLENGINE_PREDICTION_OUTPUT",
    "gs://INVALID BUCKET NAME/prediction_output",
)

# Trainer package archive and the Python module inside it used for training.
TRAINER_URI = os.getenv("GCP_MLENGINE_TRAINER_URI", "gs://INVALID BUCKET NAME/trainer.tar.gz")
TRAINER_PY_MODULE = os.getenv("GCP_MLENGINE_TRAINER_TRAINER_PY_MODULE", "trainer.task")

# Temp / staging locations forwarded as Dataflow options to the evaluation ops.
SUMMARY_TMP = os.getenv("GCP_MLENGINE_DATAFLOW_TMP", "gs://INVALID BUCKET NAME/tmp/")
SUMMARY_STAGING = os.getenv("GCP_MLENGINE_DATAFLOW_STAGING", "gs://INVALID BUCKET NAME/staging/")


with models.DAG(
    "example_gcp_mlengine",
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=['example'],
    params={"model_name": MODEL_NAME},
) as dag:
    # Hyperparameter-tuning spec handed to the training operator below. The
    # whole search space is declared in one literal: an integer range, a
    # discrete set and a categorical set.
    hyperparams: dict[str, Any] = {
        'goal': 'MAXIMIZE',
        'hyperparameterMetricTag': 'metric1',
        'maxTrials': 30,
        'maxParallelTrials': 1,
        'enableTrialEarlyStopping': True,
        'params': [
            {
                'parameterName': 'hidden1',
                'type': 'INTEGER',
                'minValue': 40,
                'maxValue': 400,
                'scaleType': 'UNIT_LINEAR_SCALE',
            },
            {
                'parameterName': 'numRnnCells',
                'type': 'DISCRETE',
                'discreteValues': [1, 2, 3, 4],
            },
            {
                'parameterName': 'rnnCellType',
                'type': 'CATEGORICAL',
                'categoricalValues': [
                    'BasicLSTMCell',
                    'BasicRNNCell',
                    'GRUCell',
                    'LSTMCell',
                    'LayerNormBasicLSTMCell',
                ],
            },
        ],
    }
    # Submit the training job. `job_id` is a Jinja template that combines the
    # `ts_nodash` template variable with the DAG-level `model_name` param, and
    # the tuning spec held in `hyperparams` is passed through `hyperparameters`.
    # Output is written under JOB_DIR, which the first model version reads from.
    # [START howto_operator_gcp_mlengine_training]
    training = MLEngineStartTrainingJobOperator(
        task_id="training",
        project_id=PROJECT_ID,
        region="us-central1",
        job_id="training-job-{{ ts_nodash }}-{{ params.model_name }}",
        runtime_version="1.15",
        python_version="3.7",
        job_dir=JOB_DIR,
        package_uris=[TRAINER_URI],
        training_python_module=TRAINER_PY_MODULE,
        training_args=[],
        labels={"job_type": "training"},
        hyperparameters=hyperparams,
    )
    # [END howto_operator_gcp_mlengine_training]

    # Create the model named MODEL_NAME in PROJECT_ID.
    # [START howto_operator_gcp_mlengine_create_model]
    create_model = MLEngineCreateModelOperator(
        task_id="create-model",
        project_id=PROJECT_ID,
        model={
            "name": MODEL_NAME,
        },
    )
    # [END howto_operator_gcp_mlengine_create_model]

    # Fetch the model that was just created; the result is echoed by the next task.
    # [START howto_operator_gcp_mlengine_get_model]
    get_model = MLEngineGetModelOperator(
        task_id="get-model",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
    )
    # [END howto_operator_gcp_mlengine_get_model]

    # Echo whatever `get_model.output` renders to inside the bash command
    # (presumably a templated XCom reference to the get-model result — confirm).
    # [START howto_operator_gcp_mlengine_print_model]
    get_model_result = BashOperator(
        bash_command=f"echo {get_model.output}",
        task_id="get-model-result",
    )
    # [END howto_operator_gcp_mlengine_print_model]

    # Create version "v1" from the `keras_export/` directory under JOB_DIR,
    # i.e. under the directory handed to the training job as `job_dir`.
    # NOTE(review): the version dict mixes snake_case ("deployment_uri",
    # "runtime_version") and camelCase ("machineType", "pythonVersion") keys —
    # confirm the API accepts both spellings.
    # [START howto_operator_gcp_mlengine_create_version1]
    create_version = MLEngineCreateVersionOperator(
        task_id="create-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version={
            "name": "v1",
            "description": "First-version",
            "deployment_uri": f'{JOB_DIR}/keras_export/',
            "runtime_version": "1.15",
            "machineType": "mls1-c1-m2",
            "framework": "TENSORFLOW",
            "pythonVersion": "3.7",
        },
    )
    # [END howto_operator_gcp_mlengine_create_version1]

    # Create version "v2" from the pre-existing export at SAVED_MODEL_PATH;
    # apart from name, description and source it is configured like "v1".
    # [START howto_operator_gcp_mlengine_create_version2]
    create_version_2 = MLEngineCreateVersionOperator(
        task_id="create-version-2",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version={
            "name": "v2",
            "description": "Second version",
            "deployment_uri": SAVED_MODEL_PATH,
            "runtime_version": "1.15",
            "machineType": "mls1-c1-m2",
            "framework": "TENSORFLOW",
            "pythonVersion": "3.7",
        },
    )
    # [END howto_operator_gcp_mlengine_create_version2]

    # Mark "v2" as the model's default version.
    # [START howto_operator_gcp_mlengine_default_version]
    set_defaults_version = MLEngineSetDefaultVersionOperator(
        task_id="set-default-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version_name="v2",
    )
    # [END howto_operator_gcp_mlengine_default_version]

    # List the versions of MODEL_NAME; the result is echoed by the next task.
    # [START howto_operator_gcp_mlengine_list_versions]
    list_version = MLEngineListVersionsOperator(
        task_id="list-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
    )
    # [END howto_operator_gcp_mlengine_list_versions]

    # Echo whatever `list_version.output` renders to inside the bash command.
    # [START howto_operator_gcp_mlengine_print_versions]
    list_version_result = BashOperator(
        bash_command=f"echo {list_version.output}",
        task_id="list-version-result",
    )
    # [END howto_operator_gcp_mlengine_print_versions]

    # Run a batch prediction over PREDICTION_INPUT and write the results to
    # PREDICTION_OUTPUT. Only `model_name` is given (no version), so presumably
    # the model's default version serves the job — confirm against the operator docs.
    # `job_id` is templated the same way as the training job's.
    # [START howto_operator_gcp_mlengine_get_prediction]
    prediction = MLEngineStartBatchPredictionJobOperator(
        task_id="prediction",
        project_id=PROJECT_ID,
        job_id="prediction-{{ ts_nodash }}-{{ params.model_name }}",
        region="us-central1",
        model_name=MODEL_NAME,
        data_format="TEXT",
        input_paths=[PREDICTION_INPUT],
        output_path=PREDICTION_OUTPUT,
        labels={"job_type": "prediction"},
    )
    # [END howto_operator_gcp_mlengine_get_prediction]

    # Clean-up, step 1: delete version "v1" of MODEL_NAME.
    # [START howto_operator_gcp_mlengine_delete_version]
    delete_version = MLEngineDeleteVersionOperator(
        task_id="delete-version", project_id=PROJECT_ID, model_name=MODEL_NAME, version_name="v1"
    )
    # [END howto_operator_gcp_mlengine_delete_version]

    # Clean-up, step 2: delete the model itself. `delete_contents=True` is
    # passed — presumably so remaining versions (e.g. "v2") are removed with
    # it; confirm against the operator docs.
    # [START howto_operator_gcp_mlengine_delete_model]
    delete_model = MLEngineDeleteModelOperator(
        task_id="delete-model", project_id=PROJECT_ID, model_name=MODEL_NAME, delete_contents=True
    )
    # [END howto_operator_gcp_mlengine_delete_model]

    # Task graph. Each edge is declared exactly once.

    # Both model versions are created only after training has finished.
    training >> [create_version, create_version_2]
    # The model is fetched after creation; the fetched result is echoed, and
    # the model is not deleted before it has been fetched.
    create_model >> get_model >> [get_model_result, delete_model]
    # Versions are created in order once the model exists, then v2 becomes the
    # default and the versions are listed.
    create_model >> create_version >> create_version_2 >> set_defaults_version >> list_version
    # Batch prediction waits for both versions.
    create_version >> prediction
    create_version_2 >> prediction
    # v1 is deleted only after prediction and listing are done, and the model
    # is deleted last.
    prediction >> delete_version
    list_version >> list_version_result
    list_version >> delete_version
    delete_version >> delete_model

    # [START howto_operator_gcp_mlengine_get_metric]
    def get_metric_fn_and_keys():
        """
        Build the per-instance metric function and its matching key list.

        Returns a ``(metric_fn, metric_keys)`` pair. ``metric_fn`` maps one
        prediction instance to a 1-tuple holding the first ``dense_4`` entry
        converted to ``float``; ``metric_keys`` names each tuple position.
        """
        # Position i of the tuple returned by the metric function is reported
        # under metric_keys[i], so the two must stay in the same order.
        metric_keys = ['val']

        def normalize_value(inst: dict):
            first_output = inst['dense_4'][0]
            return (float(first_output),)

        return normalize_value, metric_keys

    # [END howto_operator_gcp_mlengine_get_metric]

    # [START howto_operator_gcp_mlengine_validate_error]
    def validate_err_and_count(summary: dict) -> dict:
        """
        Validate the evaluation summary and return it unchanged.

        :param summary: mapping that holds a ``val`` entry and a ``count`` entry.
        :return: the same ``summary`` object when every check passes.
        :raises ValueError: if ``val`` is greater than 1 or lower than 0, or if
            ``count`` is not exactly 20.
        """
        if summary['val'] > 1:
            raise ValueError(f'Too high val>1; summary={summary}')
        if summary['val'] < 0:
            raise ValueError(f'Too low val<0; summary={summary}')
        if summary['count'] != 20:
            # Name the field that actually failed: this check is on 'count',
            # not on 'val'.
            raise ValueError(f'Invalid value count != 20; summary={summary}')
        return summary

    # [END howto_operator_gcp_mlengine_validate_error]

    # Build the evaluation tasks. `create_evaluate_ops` returns three objects,
    # bound here as the prediction, summary and validation steps. The metric
    # function/keys pair and the validation callback defined above are passed in,
    # and the evaluation targets version "v1" of MODEL_NAME.
    # [START howto_operator_gcp_mlengine_evaluate]
    evaluate_prediction, evaluate_summary, evaluate_validation = mlengine_operator_utils.create_evaluate_ops(
        task_prefix="evaluate-ops",
        data_format="TEXT",
        input_paths=[PREDICTION_INPUT],
        prediction_path=PREDICTION_OUTPUT,
        metric_fn_and_keys=get_metric_fn_and_keys(),
        validate_fn=validate_err_and_count,
        batch_prediction_job_id="evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}",
        project_id=PROJECT_ID,
        region="us-central1",
        dataflow_options={
            'project': PROJECT_ID,
            'tempLocation': SUMMARY_TMP,
            'stagingLocation': SUMMARY_STAGING,
        },
        model_name=MODEL_NAME,
        version_name="v1",
        py_interpreter="python3",
    )
    # [END howto_operator_gcp_mlengine_evaluate]

    # Evaluation uses version "v1", so it starts only after that version exists
    # (the `create_model >> create_version` edge is also declared earlier in
    # this DAG), and "v1" is deleted only once validation has finished.
    create_model >> create_version >> evaluate_prediction
    evaluate_validation >> delete_version

相关信息

airflow 源码目录

相关文章

airflow init 源码

airflow example_automl_nl_text_classification 源码

airflow example_automl_nl_text_sentiment 源码

airflow example_automl_tables 源码

airflow example_automl_translation 源码

airflow example_automl_video_intelligence_classification 源码

airflow example_automl_video_intelligence_tracking 源码

airflow example_automl_vision_object_detection 源码

airflow example_bigquery_dts 源码

airflow example_bigtable 源码

0  赞