最近の砂場活動その5: AWS Step Functionsで機械学習のワークフローの管理をする

はてなブログのHTTPS配信をやっていた同僚からAWS Step Functionsはいいぞ!というのを教えてもらいました(発表資料)。機械学習のワークフロー管理にもこれは便利そうだなーと思って、自分でも試してみました。やってる内容はN番煎じです...。

機械学習とワークフローの管理

状態を持つワークフローの管理、機械学習でも難しいので悩むところですね。例えば

  • データの取得
  • 前処理
  • 特徴量の生成
  • モデルの学習
  • 検証データに対する精度をトラッキングできるように記録
  • S3等に学習済みのモデルファイルを配置
  • 新しいデータに対して予測を行なう
  • 全てが終わったらslackに通知

などがぱっと上げられますが、ときどきどこかがエラーでこけます。エラーでこけていてもretryして解決するならretryして欲しいし、何度も失敗するようだったらexponential backoffしながらretryして欲しいです。あまりに何度も失敗するようなら系として失敗にして欲しいし、系として失敗したならslackで通知をしたりフローの中のどこでこけたのかが一目瞭然に分かって欲しいものです。この問題の解決方法としてフローのどこの処理を現在行なっていて、どのジョブは前回どういう状態だったか(成功したのか、何回失敗したのかなど)を記録しておくという方法が考えられます。しかし、これを自前でやるのは明らかに面倒です。こういったことを引き受けてくれるマネージドサービスの一つにAWS Step Functionsがあります。

機械学習のワークフローをStep Functionsで管理する

ML Newsで定期的に機械学習関係のワークフローを回しているので、これをStep Functionsで管理します。小さなアプリケーションなのであまり大したことはしていないですが、以下の2つのことをしています。

  • 1: Twitterで話題になっているURLをクローリング(特徴量で必要になる本文等を取得)
  • 2: 学習データから分類器を構築、新規のURLに対して分類*1、推薦リストを更新する

1が終わる前に2が始まっても仕方ないですし、1が何らかの理由で失敗していたら2は開始しないで欲しいです*2。こういった制御をStep Functionsにやらせます。Step Functionsでワークフローの管理をするには、ステートマシーンを書くだけです。見てもらったら分かると思うけど、ステートマシンはこういうやつです。

フローが図として分かるの最高だし、水色の進行中のところを見れば今どこの処理をやっているか分かります。赤の失敗のところを見ればどこでこけたかがログを見なくても一目で分かりますし、失敗したステートをクリックするとそのコンポーネントのエラーログを見ることができます。長大になっているワークフローのどこでこけているかCloudWatch Logsから根性で探すというのは地獄なので、かなり便利であることが分かりますね。

詳細は他の方が書いているエントリを見てくれ!!

参考までに今回ステートマシンを作ったCloudFormationのコードの断片を置いておきます。

ステートマシンを構築するCloudFormation

  MyStateMachine:
    Type: AWS::StepFunctions::StateMachine
    Properties:
      StateMachineName: HelloWorld-StateMachine
      RoleArn: 
        Fn::ImportValue:
          !Sub "${IAMStackName}:StepFunctionsRole"
      DefinitionString: !Sub |
        {
          "StartAt": "AddRecentUrls",
          "States": {
            "AddRecentUrls": {
              "Type": "Task",
              "Resource": "arn:aws:lambda:${AWS::Region}:${AWS::AccountId}:function:BatchJobTriggerResource",
              "Next": "WaitXSecondsForAddingRecentUrls"
            },
            "WaitXSecondsForAddingRecentUrls": {
              "Type": "Wait",
              "Seconds": 60,
              "Next": "GetJobStatusOfAddingRecentUrls"
            },
            "GetJobStatusOfAddingRecentUrls": {
              "Type": "Task",
              "Resource": "arn:aws:lambda:${AWS::Region}:${AWS::AccountId}:function:PollCheckJobFunction",
              "ResultPath": "$.status",
              "Next": "HasJobCompleteAddingRecentUrls"
            },
            "HasJobCompleteAddingRecentUrls": {
              "Type": "Choice",
              "Choices": [
                {
                  "Variable": "$.status",
                  "StringEquals": "FAILED",
                  "Next": "JobFailedAddingRecentUrls"
                },
                {
                  "Variable": "$.status",
                  "StringEquals": "SUCCEEDED",
                  "Next": "UpdateRecommendation"
                }
              ],
              "Default": "WaitXSecondsForAddingRecentUrls"
            },
            "JobFailedAddingRecentUrls": {
              "Type": "Fail",
              "Cause": "AWS Batch Job Failed",
              "Error": "DescribeJob returned FAILED"
            },
            "UpdateRecommendation": {
              "Type": "Task",
              "Resource": "arn:aws:lambda:${AWS::Region}:${AWS::AccountId}:function:UpdateRecommendationBatchJobTriggerResource",
              "Next": "WaitXSecondsForUpdatingRecommendation"
            },
            "WaitXSecondsForUpdatingRecommendation": {
              "Type": "Wait",
              "Seconds": 60,
              "Next": "GetJobStatusOfUpdatingRecommendation"
            },
            "GetJobStatusOfUpdatingRecommendation": {
              "Type": "Task",
              "Resource": "arn:aws:lambda:${AWS::Region}:${AWS::AccountId}:function:PollCheckJobFunction",
              "ResultPath": "$.status",
              "Next": "HasJobCompleteUpdatingRecommendation"
            },
            "HasJobCompleteUpdatingRecommendation": {
              "Type": "Choice",
              "Choices": [
                {
                  "Variable": "$.status",
                  "StringEquals": "FAILED",
                  "Next": "JobFailedUpdatingRecommendation"
                },
                {
                  "Variable": "$.status",
                  "StringEquals": "SUCCEEDED",
                  "Next": "JobSucceedUpdatingRecommendation"
                }
              ],
              "Default": "WaitXSecondsForUpdatingRecommendation"
            },
            "JobFailedUpdatingRecommendation": {
              "Type": "Fail",
              "Cause": "AWS Batch Job Failed",
              "Error": "DescribeJob returned FAILED"
            },
            "JobSucceedUpdatingRecommendation": {
              "Type": "Succeed"
            }
          }
        }

上記のStep Functionsを定期的に起動するLambdaとCloud Watch Events Rule

  StepFunctionsTrigger:
    Type: AWS::Lambda::Function
    Description: "データ取得→学習および推薦リストの作成をやるStepFunctionsをkickするLambda Function"
    Properties:
      FunctionName: StepFunctionsTrigger
      Role: 
        Fn::ImportValue: !Sub "${IAMStackName}:StepFunctionsRole"
      Handler: index.lambda_handler
      Runtime: python3.6
      MemorySize: 128
      Timeout: 10
      Environment:
        Variables:
          STATE_MACHINE_ARN: !Ref MyStateMachine
      Code:
        ZipFile: |
          import os
          import boto3
          from datetime import datetime as dt
          
          client = boto3.client('stepfunctions')
          def lambda_handler(event, context):
              try:
                  response = client.start_execution(
                      stateMachineArn=os.environ['STATE_MACHINE_ARN'],
                      name=dt.now().strftime('%Y_%m_%d_%H_%M_%S'),
                  )
              except Exception as e:
                  raise e
  StepFunctionsTriggerRule:
    Type: AWS::Events::Rule
    Properties:
      Name: StepFunctionsTriggerRule
      ScheduleExpression: rate(3 hours)
      Targets:
        - Id: StepFunctionsTrigger
          Arn: !GetAtt StepFunctionsTrigger.Arn
      State: "ENABLED"
  StepFunctionsPermissionForEventsToInvokeLambda:
    Type: AWS::Lambda::Permission
    Properties:
      FunctionName: !Ref StepFunctionsTrigger 
      SourceArn: !GetAtt StepFunctionsTriggerRule.Arn
      Action: lambda:InvokeFunction
      Principal: events.amazonaws.com

AWS Step Functionsのモニタリング

ステートマシンの状態遷移を見ているのは楽しいですが、監視はモニタリングツールに任せたいですね。Mackerelをお使いの方はプラグインを入れてもらうとすぐにモニタリングできます。

実行時間や失敗した数などがメトリックとして取れるので、適宜監視を仕込むと安心です。

f:id:mackerelio:20180420191310p:plain

AWS Step Functionsの類似ツール

ワークフロー管理はAWS Step Functionsが初めて出したわけではなく、同じようなものがすでにいくつかあります。

これらのツールを使ったことがあるわけではないので的を外しているかもしれませんが、AWS Step Functionsを使うと

  • AWS Batch/Lambda/ECSのタスク/EMRなど既存のAWSスタックとの連携がスムーズにできる
  • マネージトサービスなので、ワークフロー監視システム自体の管理をする必要がない

といったところが利点かなと思います。

Amazon Web Services 基礎からのネットワーク&サーバー構築 改訂版

Amazon Web Services 基礎からのネットワーク&サーバー構築 改訂版

*1:バッチでまとめてやっています

*2:これまではCloudWatch Eventで適当に回していて、これくらい立ったら1が終わってるから2をやればよかろうとやっていたり、1が失敗しても2は問答無用で走る形になっていました...