Creating a Custom Task Plugin

Custom task plugins in Flyte allow you to extend the execution engine to support new task types. By implementing the Plugin interface, you can define how Flyte handles the lifecycle of a task, from its initial execution to completion or failure.

This guide walks you through creating a custom "Sleep" task plugin that uses persistent state to track its progress across multiple execution rounds.

Prerequisites

To build a custom plugin, you need the following packages from the Flyte codebase:

  • github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core: Contains the core plugin interfaces.
  • github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery: Provides the registry for plugin discovery.

Step 1: Define the Plugin and State Structures

First, define your plugin struct and a state struct that persists data between execution rounds. Do not keep per-execution data in an in-memory map on the plugin. A PluginState value is managed by Flyte's PluginStateManager and persisted in the execution engine's backend, so it is available in the next round.

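```go
package sleep

import (
	"context" // used by Handle, Abort, Finalize, and LoadPlugin below
	"time"

	"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core"
)

// sleepTaskType is both the plugin ID and the task type this plugin handles.
const sleepTaskType = "core-sleep"

// PluginState is persisted across Handle calls using Gob encoding.
// Gob ignores struct tags; the json tag only matters if you also marshal this struct as JSON.
type PluginState struct {
	// StartTime is nil until the first Handle call records it.
	StartTime *time.Time `json:"start_time"`
}

// Plugin holds no per-execution data; that lives in PluginState.
type Plugin struct{}

// GetID returns the unique identifier of this plugin.
func (p *Plugin) GetID() string {
	return sleepTaskType
}

// GetProperties returns the default plugin properties.
func (p *Plugin) GetProperties() core.PluginProperties {
	return core.PluginProperties{}
}
```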

PluginState stores the time at which the first Handle call ran, which the plugin treats as the task's start time. Flyte's PluginStateManager (found in executor/pkg/plugin/state_manager.go) serializes this state using Gob encoding, so only exported fields such as StartTime are persisted.

Step 2: Implement the Handle Method

The Handle method is the core of your plugin. It must return quickly without blocking (for example, it must never call time.Sleep) and it must be idempotent, because Flyte calls it repeatedly, once per evaluation round, for the same task execution.

```go
// Handle never sleeps: it records the start time on the first round and
// reports success on a later round once the duration has elapsed.
func (p *Plugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
	// 1. Read the state written in an earlier round (zero value on the first round).
	state := &PluginState{}
	if _, err := tCtx.PluginStateReader().Get(state); err != nil {
		return core.UnknownTransition, err
	}

	// 2. Resolve inputs (e.g., the sleep duration).
	// resolveSleepDuration is a helper that this guide does not show.
	sleepDuration, err := resolveSleepDuration(ctx, tCtx)
	if err != nil {
		// Invalid input fails the task through its phase; the error return stays nil.
		return core.DoTransition(core.PhaseInfoFailure("InvalidInput", err.Error(), nil)), nil
	}

	// 3. First round: record the start time.
	if state.StartTime == nil {
		now := time.Now()
		state.StartTime = &now

		// Persist the state for the next round.
		if err := tCtx.PluginStateWriter().Put(0, state); err != nil {
			return core.UnknownTransition, err
		}

		return core.DoTransition(core.PhaseInfoRunning(core.DefaultPhaseVersion, nil)), nil
	}

	// 4. Later rounds: succeed once the duration has elapsed.
	if time.Since(*state.StartTime) >= sleepDuration {
		return core.DoTransition(core.PhaseInfoSuccess(nil)), nil
	}

	// Still sleeping: report Running and wait for the next round.
	return core.DoTransition(core.PhaseInfoRunning(core.DefaultPhaseVersion, nil)), nil
}
```

In this implementation:

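  • tCtx.PluginStateReader().Get(state) retrieves the state written in an earlier round. The code relies on StartTime still being nil on the first round, before anything has been written.
  • tCtx.PluginStateWriter().Put(0, state) saves the state; the first argument is the state version, which Get returns alongside the state. Data written with Put is not visible until the next call to Handle.
  • core.DoTransition wraps a PhaseInfo, built with helpers such as PhaseInfoRunning, PhaseInfoSuccess, or PhaseInfoFailure, into the transition that moves the task to its next phase.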

Step 3: Implement Abort and Finalize

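The Abort method is called if the task is cancelled, while Finalize is always called after Handle or Abort to perform cleanup, so Finalize should be idempotent. The sleep plugin holds no external resources, so both methods simply return nil.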

```go
func (p *Plugin) Abort(ctx context.Context, tCtx core.TaskExecutionContext) error {
	// Clean up any external resources if necessary
	return nil
}

func (p *Plugin) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error {
	// Final cleanup logic
	return nil
}
```

Step 4: Register the Plugin

To make your plugin available to the Flyte execution engine, register it in an init() function using the PluginRegistry.

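```go
// Add this import to the file's existing import block: Go requires all
// imports to appear before any other declarations.
import (
	"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery"
)

// init runs when the package is loaded and registers the plugin.
func init() {
	pluginmachinery.PluginRegistry().RegisterCorePlugin(
		core.PluginEntry{
			// Unique plugin identifier.
			ID: sleepTaskType,
			// Task types routed to this plugin.
			RegisteredTaskTypes: []core.TaskType{sleepTaskType},
			// LoadPlugin constructs the plugin instance during setup.
			LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) {
				return &Plugin{}, nil
			},
			// This plugin is not registered as the default plugin.
			IsDefault: false,
		},
	)
}
```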

The PluginRegistry (which implements PluginRegistryIface in executor/pkg/plugin/registry.go) lets Flyte discover your plugin by its ID and route every task type listed in RegisteredTaskTypes to it.
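The lookup can be pictured as a map from task type to plugin entry. The sketch below is a hypothetical, self-contained model, not Flyte's PluginRegistryIface: PluginEntry keeps only the ID, RegisteredTaskTypes, and LoadPlugin fields, the loader takes no arguments (the real one receives a context and a core.SetupContext), and Resolve is an invented method name. Registering from init() works because Go initializes package-level variables before any init function runs, and runs all init functions before main.

```go
package main

import "fmt"

// TaskType is a stand-in for core.TaskType, modeled as a string.
type TaskType = string

// Plugin is a minimal stand-in for core.Plugin.
type Plugin interface {
	GetID() string
}

// PluginEntry keeps a subset of core.PluginEntry's fields.
type PluginEntry struct {
	ID                  string
	RegisteredTaskTypes []TaskType
	LoadPlugin          func() (Plugin, error) // simplified loader signature
}

// registry maps each registered task type to the entry that handles it.
type registry struct {
	byTaskType map[TaskType]PluginEntry
}

func newRegistry() *registry {
	return &registry{byTaskType: map[TaskType]PluginEntry{}}
}

func (r *registry) RegisterCorePlugin(e PluginEntry) {
	for _, tt := range e.RegisteredTaskTypes {
		r.byTaskType[tt] = e
	}
}

// Resolve finds the entry for a task type and loads its plugin.
func (r *registry) Resolve(tt TaskType) (Plugin, error) {
	e, ok := r.byTaskType[tt]
	if !ok {
		return nil, fmt.Errorf("no plugin registered for task type %q", tt)
	}
	return e.LoadPlugin()
}

type sleepPlugin struct{}

func (sleepPlugin) GetID() string { return "core-sleep" }

// globalRegistry is initialized before init() runs.
var globalRegistry = newRegistry()

func init() {
	globalRegistry.RegisterCorePlugin(PluginEntry{
		ID:                  "core-sleep",
		RegisteredTaskTypes: []TaskType{"core-sleep"},
		LoadPlugin:          func() (Plugin, error) { return sleepPlugin{}, nil },
	})
}

func main() {
	if p, err := globalRegistry.Resolve("core-sleep"); err == nil {
		fmt.Println("resolved plugin:", p.GetID())
	}
	if _, err := globalRegistry.Resolve("unknown-type"); err != nil {
		fmt.Println(err)
	}
}
```

main prints "resolved plugin: core-sleep" followed by the error for the unregistered type, showing that lookup is keyed by task type rather than by the order in which plugins were registered.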

How it Works in Flyte

When a task of type core-sleep is executed:

  1. Flyte's executor looks up the plugin in the PluginRegistry.
  2. It creates a taskExecutionContext (defined in executor/pkg/plugin/task_exec_context.go) which provides the plugin with access to inputs, outputs, and state management.
  3. The PluginStateManager loads any existing state from the task's status and provides it via the PluginStateReader.
  4. Your Handle method is called. If you call PluginStateWriter().Put(), the PluginStateManager captures the new state and persists it back to the Flyte database after Handle returns.