Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pkg/tern/local_apply_grouped.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ func (c *LocalClient) executeGroupedApply(ctx context.Context, apply *storage.Ap
ctx, cancelApply := context.WithCancel(ctx)
defer cancelApply()
defer c.startApplyHeartbeat(ctx, apply, cancelApply)()
creds := c.credentials()
mode := groupedApplyMode(apply)
modeDescription := groupedApplyModeDescription(apply)

Expand Down Expand Up @@ -57,6 +56,11 @@ func (c *LocalClient) executeGroupedApply(ctx context.Context, apply *storage.Ap
fmt.Sprintf("MySQL applies support one namespace per apply, but plan has %d: %v", len(plan.Namespaces), names))
return
}
creds, err := c.credentialsForGroupedApply(plan)
if err != nil {
c.failApplyWithTasks(ctx, apply, tasks, err.Error())
return
}
changes := planNamespacesToChanges(plan.Namespaces)

// For Vitess: initialize the VitessApplyData row before the engine starts.
Expand Down
14 changes: 12 additions & 2 deletions pkg/tern/local_apply_sequential.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,16 @@ func (c *LocalClient) runEngineTask(ctx context.Context, apply *storage.Apply, t
} else if handled {
return taskStopped
}
taskCreds := creds
if c.config.Type == storage.DatabaseTypeMySQL {
var err error
taskCreds, err = c.credentialsForMySQLNamespace(task.Namespace)
if err != nil {
c.markTaskFailed(ctx, task, err.Error())
c.logger.Error("task failed to resolve namespace credentials", "error", err, "task_id", task.TaskIdentifier, "namespace", task.Namespace, "table", task.TableName)
return taskFailed
}
}

// Sequential mode: one DDL per engine call. Use the task identifier as
// MigrationContext so each table's schema change is tracked independently.
Expand All @@ -157,7 +167,7 @@ func (c *LocalClient) runEngineTask(ctx context.Context, apply *storage.Apply, t
}},
Options: options,
ResumeState: &engine.ResumeState{MigrationContext: task.TaskIdentifier},
Credentials: creds,
Credentials: taskCreds,
})

if err != nil {
Expand All @@ -182,7 +192,7 @@ func (c *LocalClient) runEngineTask(ctx context.Context, apply *storage.Apply, t
c.logger.Info("task running", "task_id", task.TaskIdentifier, "table", task.TableName)

// Poll to completion
pollAction := c.pollTaskToCompletion(ctx, apply, task, creds)
pollAction := c.pollTaskToCompletion(ctx, apply, task, taskCreds)
if pollAction == taskAbort || pollAction == taskStopped {
return pollAction
}
Expand Down
173 changes: 156 additions & 17 deletions pkg/tern/local_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ import (
"github.com/block/schemabot/pkg/mysqlconn"
ternv1 "github.com/block/schemabot/pkg/proto/ternv1"
"github.com/block/schemabot/pkg/psclient"
"github.com/block/schemabot/pkg/schema"
"github.com/block/schemabot/pkg/state"
"github.com/block/schemabot/pkg/storage"
)
Expand Down Expand Up @@ -253,6 +254,82 @@ func (c *LocalClient) credentials() *engine.Credentials {
}
}

func (c *LocalClient) credentialsForMySQLNamespace(namespace string) (*engine.Credentials, error) {
if c.config.Type != storage.DatabaseTypeMySQL {
return c.credentials(), nil
}
hasDatabase, err := mysqlDSNHasDatabase(c.config.TargetDSN)
if err != nil {
return nil, fmt.Errorf("inspect MySQL target DSN for namespace injection: %w", err)
}
// Transitional: a target DSN that already names a database is used as-is.
// The data-plane model is a namespace-free DSN with the schema injected per
// operation (below); existing static/local configs still carry the database
// in the DSN, and those keep working until they migrate to namespace-free.
if hasDatabase {
return c.credentials(), nil
}
// A namespace-free target DSN is the inventory/data-plane shape: the concrete
// namespace is the connection schema and must be injected per operation.
if namespace == "" {
return nil, fmt.Errorf("MySQL namespace is required for a namespace-free target DSN")
}
dsn, err := mysqlDSNWithDatabase(c.config.TargetDSN, namespace)
if err != nil {
return nil, err
}
return &engine.Credentials{
DSN: dsn,
Metadata: c.config.Metadata,
}, nil
}

func (c *LocalClient) credentialsForTask(task *storage.Task) (*engine.Credentials, error) {
if c.config.Type != storage.DatabaseTypeMySQL {
return c.credentials(), nil
}
if task == nil {
return nil, fmt.Errorf("task is required for MySQL credentials")
}
return c.credentialsForMySQLNamespace(task.Namespace)
}

// credentialsForGroupedApply resolves the single-namespace credentials for a
// grouped/atomic MySQL apply. A grouped apply runs one Spirit execution against
// one schema, so the plan must carry exactly one namespace. Fail closed rather
// than pick a namespace by map iteration order (or silently use a namespace-free
// DSN) if that invariant is ever violated.
func (c *LocalClient) credentialsForGroupedApply(plan *storage.Plan) (*engine.Credentials, error) {
if c.config.Type != storage.DatabaseTypeMySQL {
return c.credentials(), nil
}
if len(plan.Namespaces) != 1 {
return nil, fmt.Errorf("grouped MySQL apply requires exactly one namespace, plan has %d", len(plan.Namespaces))
}
var namespace string
for ns := range plan.Namespaces {
namespace = ns
}
return c.credentialsForMySQLNamespace(namespace)
}

func mysqlDSNWithDatabase(dsn, database string) (string, error) {
cfg, err := mysql.ParseDSN(dsn)
if err != nil {
return "", fmt.Errorf("parse MySQL DSN: %w", err)
}
cfg.DBName = database
return cfg.FormatDSN(), nil
}

func mysqlDSNHasDatabase(dsn string) (bool, error) {
cfg, err := mysql.ParseDSN(dsn)
if err != nil {
return false, fmt.Errorf("parse MySQL DSN: %w", err)
}
return cfg.DBName != "", nil
}

func (c *LocalClient) deferredCutoverSignalExists(ctx context.Context, apply *storage.Apply) (bool, bool, error) {
if apply == nil {
return false, false, fmt.Errorf("apply is required for deferred cutover signal lookup")
Expand Down Expand Up @@ -286,11 +363,20 @@ func (c *LocalClient) PullSchema(ctx context.Context, req *ternv1.PullSchemaRequ
return nil, fmt.Errorf("pull schema for database %s: request type %q does not match client type %q: %w", c.config.Database, req.Type, c.config.Type, ErrPullSchemaInvalidRequest)
}

targetDSN := c.config.TargetDSN
if c.config.Type == storage.DatabaseTypeMySQL {
creds, err := c.credentialsForMySQLNamespace(c.config.Database)
if err != nil {
return nil, fmt.Errorf("resolve database %s credentials for schema pull: %w", c.config.Database, err)
}
targetDSN = creds.DSN
}

attrs := []any{"database", c.config.Database}
attrs = append(attrs, dsnLogAttrs(c.config.TargetDSN)...)
attrs = append(attrs, dsnLogAttrs(targetDSN)...)
c.logger.Info("LocalClient.PullSchema: loading live schema", attrs...)

db, err := mysqlconn.Open(c.config.TargetDSN)
db, err := mysqlconn.Open(targetDSN)
if err != nil {
return nil, fmt.Errorf("open database %s for schema pull: %w", c.config.Database, err)
}
Expand Down Expand Up @@ -348,29 +434,15 @@ func (c *LocalClient) Plan(ctx context.Context, req *ternv1.PlanRequest) (*ternv
return nil, fmt.Errorf("type must be %q or %q", storage.DatabaseTypeMySQL, storage.DatabaseTypeVitess)
}

eng := c.getEngine()
if eng == nil {
return nil, fmt.Errorf("no engine configured for type: %s", c.config.Type)
}

// Convert schema files from proto to engine type
schemaFiles := protoToSchemaFiles(req.SchemaFiles)

creds := c.credentials()

planLogAttrs := []any{"database", c.config.Database}
planLogAttrs = append(planLogAttrs, dsnLogAttrs(c.config.TargetDSN)...)
planLogAttrs = append(planLogAttrs, "schema_file_count", len(schemaFiles))
c.logger.Info("LocalClient.Plan: calling engine", planLogAttrs...)

result, err := eng.Plan(ctx, &engine.PlanRequest{
Database: c.config.Database,
DatabaseType: c.config.Type,
SchemaFiles: schemaFiles,
Repository: req.Repository,
PullRequest: int(req.PullRequest),
Credentials: creds,
})
result, err := c.planWithEngine(ctx, req, c.config.Database, schemaFiles)
if err != nil {
c.logger.Error("plan failed", "error", err, "database", c.config.Database)
return nil, err // Error already has clear prefix (SQL syntax/usage error)
Expand Down Expand Up @@ -537,6 +609,73 @@ func (c *LocalClient) Plan(ctx context.Context, req *ternv1.PlanRequest) (*ternv
}, nil
}

func (c *LocalClient) planWithEngine(ctx context.Context, req *ternv1.PlanRequest, database string, schemaFiles schema.SchemaFiles) (*engine.PlanResult, error) {
eng := c.getEngine()
if eng == nil {
return nil, fmt.Errorf("no engine configured for type: %s", c.config.Type)
}
if c.config.Type != storage.DatabaseTypeMySQL {
return c.planNamespaceWithEngine(ctx, eng, req, database, schemaFiles, c.credentials())
}
hasDatabase, err := mysqlDSNHasDatabase(c.config.TargetDSN)
if err != nil {
return nil, err
}
if hasDatabase {
return c.planNamespaceWithEngine(ctx, eng, req, database, schemaFiles, c.credentials())
}
if len(schemaFiles) == 0 {
return nil, fmt.Errorf("schema files are required for namespace-free MySQL target DSN")
}
if len(schemaFiles) == 1 {
for namespace := range schemaFiles {
creds, err := c.credentialsForMySQLNamespace(namespace)
if err != nil {
return nil, err
}
return c.planNamespaceWithEngine(ctx, eng, req, namespace, schemaFiles, creds)
}
}
return c.planMySQLNamespacesWithEngine(ctx, eng, req, schemaFiles)
}

func (c *LocalClient) planNamespaceWithEngine(ctx context.Context, eng engine.Engine, req *ternv1.PlanRequest, database string, schemaFiles schema.SchemaFiles, creds *engine.Credentials) (*engine.PlanResult, error) {
return eng.Plan(ctx, &engine.PlanRequest{
Database: database,
DatabaseType: c.config.Type,
SchemaFiles: schemaFiles,
Repository: req.Repository,
PullRequest: int(req.PullRequest),
Credentials: creds,
})
}

func (c *LocalClient) planMySQLNamespacesWithEngine(ctx context.Context, eng engine.Engine, req *ternv1.PlanRequest, schemaFiles schema.SchemaFiles) (*engine.PlanResult, error) {
namespaces := make([]string, 0, len(schemaFiles))
for namespace := range schemaFiles {
namespaces = append(namespaces, namespace)
}
sort.Strings(namespaces)

result := &engine.PlanResult{PlanID: fmt.Sprintf("plan-%d", time.Now().UnixNano()), NoChanges: true}
for _, namespace := range namespaces {
creds, err := c.credentialsForMySQLNamespace(namespace)
if err != nil {
return nil, err
}
nsResult, err := c.planNamespaceWithEngine(ctx, eng, req, namespace, schema.SchemaFiles{namespace: schemaFiles[namespace]}, creds)
if err != nil {
return nil, fmt.Errorf("plan MySQL namespace %q: %w", namespace, err)
}
result.Changes = append(result.Changes, nsResult.Changes...)
result.LintViolations = append(result.LintViolations, nsResult.LintViolations...)
if !nsResult.NoChanges || len(nsResult.Changes) > 0 {
result.NoChanges = false
}
}
return result, nil
}

// Apply executes a previously generated plan.
// In local mode, Apply has additional conflict checking and polls for completion.
//
Expand Down
Loading
Loading