From af1fbd90f260b96303bfb00f1ef7506871114fa5 Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Thu, 11 Jun 2026 18:02:52 +0800 Subject: [PATCH 01/17] syncer(dm): add MariaDB AST DDL rewriter --- dm/pkg/ddl/rewriter/rewriter.go | 130 +++++++++ dm/pkg/ddl/rewriter/rewriter_test.go | 164 +++++++++++ dm/pkg/ddl/rewriter/rules.go | 419 +++++++++++++++++++++++++++ dm/syncer/ddl.go | 19 +- dm/syncer/ddl_test.go | 19 ++ dm/syncer/syncer.go | 11 +- 6 files changed, 758 insertions(+), 4 deletions(-) create mode 100644 dm/pkg/ddl/rewriter/rewriter.go create mode 100644 dm/pkg/ddl/rewriter/rewriter_test.go create mode 100644 dm/pkg/ddl/rewriter/rules.go diff --git a/dm/pkg/ddl/rewriter/rewriter.go b/dm/pkg/ddl/rewriter/rewriter.go new file mode 100644 index 0000000000..dda4bf5465 --- /dev/null +++ b/dm/pkg/ddl/rewriter/rewriter.go @@ -0,0 +1,130 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package rewriter + +import ( + "strings" + + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/format" + _ "github.com/pingcap/tidb/pkg/types/parser_driver" // register parser driver +) + +// Rule rewrites one AST node in place. +type Rule interface { + Name() string + Apply(ast.Node) (bool, error) +} + +// Rewriter applies a fixed, ordered set of AST rules. +type Rewriter struct { + rules []Rule +} + +// NewRewriter creates a rewriter from explicit rules. +func NewRewriter(rules ...Rule) *Rewriter { + return &Rewriter{rules: append([]Rule(nil), rules...)} +} + +// NewRewriterForFlavor returns a default rewriter only for MariaDB upstreams. +func NewRewriterForFlavor(flavor string) *Rewriter { + if !strings.EqualFold(flavor, "mariadb") { + return nil + } + return NewRewriter(defaultRules...) +} + +// RewriteStmt applies all rules to stmt in place. +func (r *Rewriter) RewriteStmt(stmt ast.StmtNode) (bool, error) { + if r == nil || len(r.rules) == 0 || stmt == nil { + return false, nil + } + visitor := &rewriteVisitor{rules: r.rules} + stmt.Accept(visitor) + return visitor.changed, visitor.err +} + +// RewriteSQL parses, rewrites, and restores SQL. It is mainly intended for unit tests +// and small call sites that do not already have a parsed AST. +func (r *Rewriter) RewriteSQL(sql string) (string, bool, error) { + p := parser.New() + stmts, _, err := p.Parse(sql, "", "") + if err != nil { + return "", false, err + } + + changed := false + for _, stmt := range stmts { + stmtChanged, err := r.RewriteStmt(stmt) + if err != nil { + return "", false, err + } + changed = changed || stmtChanged + } + if !changed { + return sql, false, nil + } + + out, err := restoreStatements(stmts) + if err != nil { + return "", false, err + } + return out, true, nil +} + +type rewriteVisitor struct { + rules []Rule + changed bool + err error +} + +func (v *rewriteVisitor) Enter(node ast.Node) (ast.Node, bool) { + if v.err != nil { + return node, true + } + return node, false +} + +func (v *rewriteVisitor) Leave(node ast.Node) (ast.Node, bool) { + if v.err != nil { + return node, false + } + for _, rule := range v.rules { + changed, err := rule.Apply(node) + if err != nil { + v.err = err + return node, false + } + v.changed = v.changed || changed + } + return node, true +} + +func restoreStatements(stmts []ast.StmtNode) (string, error) { + var out strings.Builder + for i, stmt := range stmts { + if i > 0 { + out.WriteString(";\n") + } + err := stmt.Restore(&format.RestoreCtx{ + Flags: format.DefaultRestoreFlags | format.RestoreTiDBSpecialComment | format.RestoreStringWithoutDefaultCharset, + In: &out, + }) + if err != nil { + return "", err + } + } + return out.String(), nil +} diff --git a/dm/pkg/ddl/rewriter/rewriter_test.go b/dm/pkg/ddl/rewriter/rewriter_test.go new file mode 100644 index 0000000000..11f4092a00 --- /dev/null +++ b/dm/pkg/ddl/rewriter/rewriter_test.go @@ -0,0 +1,164 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package rewriter + +import ( + "strings" + "testing" + + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/stretchr/testify/require" +) + +func TestRewriteSQLRemovesFunctionDefaultOnVarchar(t *testing.T) { + rewriter := newDefaultRewriterForTest() + + out, changed, err := rewriter.RewriteSQL("CREATE TABLE t(t VARCHAR(100) DEFAULT current_timestamp());") + require.NoError(t, err) + require.True(t, changed) + require.NotContains(t, strings.ToLower(out), "default") + + stmt := parseCreateTable(t, out) + col := findColumn(stmt, "t") + require.NotNil(t, col) + require.False(t, hasColumnOption(col, ast.ColumnOptionDefaultValue)) +} + +func TestRewriteSQLKeepsTimeFunctionDefaultOnTimeColumn(t *testing.T) { + rewriter := newDefaultRewriterForTest() + + out, changed, err := rewriter.RewriteSQL("CREATE TABLE t(ts TIMESTAMP DEFAULT current_timestamp());") + require.NoError(t, err) + require.False(t, changed) + require.Contains(t, strings.ToLower(out), "default current_timestamp") +} + +func TestRewriteSQLDefaultRules(t *testing.T) { + rewriter := newDefaultRewriterForTest() + input := `CREATE TABLE t ( + id INT(11), + txt TEXT DEFAULT 'x', + v VARCHAR(800), + j JSON, + g JSON GENERATED ALWAYS AS (JSON_EXTRACT(j, '$.a')) VIRTUAL, + zero_ts TIMESTAMP DEFAULT '0000-00-00 00:00:00', + CHECK (json_valid(j)), + KEY idx_txt (txt), + KEY idx_v (v) +) DEFAULT CHARSET=latin1 COLLATE=latin1_swedish_ci;` + + out, changed, err := rewriter.RewriteSQL(input) + require.NoError(t, err) + require.True(t, changed) + + stmt := parseCreateTable(t, out) + require.Equal(t, "utf8mb4", findTableOption(stmt, ast.TableOptionCharset)) + require.Equal(t, "utf8mb4_0900_ai_ci", findTableOption(stmt, ast.TableOptionCollate)) + require.Equal(t, -1, findColumn(stmt, "id").Tp.GetFlen()) + require.Equal(t, 768, findColumn(stmt, "v").Tp.GetFlen()) + require.False(t, hasColumnOption(findColumn(stmt, "txt"), ast.ColumnOptionDefaultValue)) + require.False(t, hasColumnOption(findColumn(stmt, "g"), ast.ColumnOptionGenerated)) + require.False(t, hasColumnOption(findColumn(stmt, "zero_ts"), ast.ColumnOptionDefaultValue)) + require.False(t, hasJSONValidCheck(stmt)) + + idxTxt := findConstraint(stmt, "idx_txt") + require.NotNil(t, idxTxt) + require.Equal(t, 255, idxTxt.Keys[0].Length) + idxV := findConstraint(stmt, "idx_v") + require.NotNil(t, idxV) + require.Equal(t, 768, idxV.Keys[0].Length) +} + +func TestRewriteSQLSkipsExpressionIndexPrefix(t *testing.T) { + rewriter := newDefaultRewriterForTest() + + _, _, err := rewriter.RewriteSQL("CREATE TABLE t(name VARCHAR(32), KEY idx_expr ((LOWER(name))));") + require.NoError(t, err) +} + +func TestRewriteSQLRemovesParenthesizedJSONGeneratedColumn(t *testing.T) { + rewriter := newDefaultRewriterForTest() + + out, changed, err := rewriter.RewriteSQL( + "CREATE TABLE t(j JSON, g JSON GENERATED ALWAYS AS ((JSON_EXTRACT(j, '$.a'))) VIRTUAL);", + ) + require.NoError(t, err) + require.True(t, changed) + require.False(t, hasColumnOption(findColumn(parseCreateTable(t, out), "g"), ast.ColumnOptionGenerated)) +} + +func TestNewRewriterForFlavor(t *testing.T) { + require.NotNil(t, NewRewriterForFlavor("mariadb")) + require.NotNil(t, NewRewriterForFlavor("MariaDB")) + require.Nil(t, NewRewriterForFlavor("mysql")) +} + +func newDefaultRewriterForTest() *Rewriter { + return NewRewriter(defaultRules...) +} + +func parseCreateTable(t *testing.T, sql string) *ast.CreateTableStmt { + t.Helper() + stmt, err := parser.New().ParseOneStmt(sql, "", "") + require.NoError(t, err) + create, ok := stmt.(*ast.CreateTableStmt) + require.True(t, ok) + return create +} + +func findColumn(stmt *ast.CreateTableStmt, name string) *ast.ColumnDef { + for _, col := range stmt.Cols { + if strings.EqualFold(col.Name.Name.O, name) { + return col + } + } + return nil +} + +func hasColumnOption(col *ast.ColumnDef, optionType ast.ColumnOptionType) bool { + for _, opt := range col.Options { + if opt.Tp == optionType { + return true + } + } + return false +} + +func hasJSONValidCheck(stmt *ast.CreateTableStmt) bool { + for _, cons := range stmt.Constraints { + if cons.Tp == ast.ConstraintCheck && isJSONValidExpr(cons.Expr) { + return true + } + } + return false +} + +func findConstraint(stmt *ast.CreateTableStmt, name string) *ast.Constraint { + for _, cons := range stmt.Constraints { + if strings.EqualFold(cons.Name, name) { + return cons + } + } + return nil +} + +func findTableOption(stmt *ast.CreateTableStmt, optionType ast.TableOptionType) string { + for _, opt := range stmt.Options { + if opt.Tp == optionType { + return strings.ToLower(opt.StrValue) + } + } + return "" +} diff --git a/dm/pkg/ddl/rewriter/rules.go b/dm/pkg/ddl/rewriter/rules.go new file mode 100644 index 0000000000..064becd4b4 --- /dev/null +++ b/dm/pkg/ddl/rewriter/rules.go @@ -0,0 +1,419 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package rewriter + +import ( + "strings" + + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/types" + tidbtypes "github.com/pingcap/tidb/pkg/types" +) + +const maxVarcharLen = 768 + +var defaultRules = []Rule{ + collationRule{}, + zeroTimestampRule{}, + keyLengthRule{}, + indexPrefixRule{}, + integerWidthRule{}, + textBlobDefaultRule{}, + jsonCheckRule{}, + functionDefaultRule{}, + jsonGeneratedRule{}, +} + +type collationRule struct{} + +func (r collationRule) Name() string { return "collation" } + +func (r collationRule) Apply(node ast.Node) (bool, error) { + switch n := node.(type) { + case *ast.CreateDatabaseStmt: + return rewriteDatabaseOptions(n.Options), nil + case *ast.CreateTableStmt: + return rewriteTableOptions(&n.Options), nil + case *ast.ColumnDef: + return rewriteColumnCollations(n), nil + default: + return false, nil + } +} + +type zeroTimestampRule struct{} + +func (r zeroTimestampRule) Name() string { return "zero-timestamp" } + +func (r zeroTimestampRule) Apply(node ast.Node) (bool, error) { + col, ok := node.(*ast.ColumnDef) + if !ok || !isTimeType(col.Tp.GetType()) { + return false, nil + } + return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { + return opt.Tp == ast.ColumnOptionDefaultValue && isZeroTimeDefault(opt.Expr) + }), nil +} + +type keyLengthRule struct{} + +func (r keyLengthRule) Name() string { return "key-length" } + +func (r keyLengthRule) Apply(node ast.Node) (bool, error) { + col, ok := node.(*ast.ColumnDef) + if !ok || col.Tp.GetFlen() <= maxVarcharLen { + return false, nil + } + switch col.Tp.GetType() { + case mysql.TypeVarchar, mysql.TypeVarString: + col.Tp.SetFlen(maxVarcharLen) + return true, nil + default: + return false, nil + } +} + +type indexPrefixRule struct{} + +func (r indexPrefixRule) Name() string { return "index-prefix" } + +func (r indexPrefixRule) Apply(node ast.Node) (bool, error) { + stmt, ok := node.(*ast.CreateTableStmt) + if !ok { + return false, nil + } + colMap := make(map[string]*ast.ColumnDef, len(stmt.Cols)) + for _, col := range stmt.Cols { + colMap[col.Name.Name.L] = col + } + + changed := false + for _, cons := range stmt.Constraints { + switch cons.Tp { + case ast.ConstraintPrimaryKey, ast.ConstraintKey, ast.ConstraintIndex, ast.ConstraintUniq: + default: + continue + } + for _, key := range cons.Keys { + if key.Length > 0 { + continue + } + if key.Column == nil { + continue + } + col := colMap[key.Column.Name.L] + if col == nil { + continue + } + switch { + case isTextOrBlob(col.Tp): + key.Length = 255 + changed = true + case isVarcharOrChar(col.Tp) && col.Tp.GetFlen() > 0: + key.Length = col.Tp.GetFlen() + changed = true + } + } + } + return changed, nil +} + +type integerWidthRule struct{} + +func (r integerWidthRule) Name() string { return "integer-width" } + +func (r integerWidthRule) Apply(node ast.Node) (bool, error) { + col, ok := node.(*ast.ColumnDef) + if !ok || !mysql.IsIntegerType(col.Tp.GetType()) { + return false, nil + } + if col.Tp.GetFlen() == types.UnspecifiedLength || col.Tp.GetFlen() <= 0 { + return false, nil + } + col.Tp.SetFlen(types.UnspecifiedLength) + return true, nil +} + +type textBlobDefaultRule struct{} + +func (r textBlobDefaultRule) Name() string { return "text-blob-default" } + +func (r textBlobDefaultRule) Apply(node ast.Node) (bool, error) { + col, ok := node.(*ast.ColumnDef) + if !ok || !isTextBlobOrJSON(col.Tp) { + return false, nil + } + return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { + return opt.Tp == ast.ColumnOptionDefaultValue + }), nil +} + +type jsonCheckRule struct{} + +func (r jsonCheckRule) Name() string { return "json-check" } + +func (r jsonCheckRule) Apply(node ast.Node) (bool, error) { + switch n := node.(type) { + case *ast.ColumnDef: + return filterColumnOptions(n, func(opt *ast.ColumnOption) bool { + return opt.Tp == ast.ColumnOptionCheck && isJSONValidExpr(opt.Expr) + }), nil + case *ast.CreateTableStmt: + constraints := n.Constraints[:0] + changed := false + for _, cons := range n.Constraints { + if cons.Tp == ast.ConstraintCheck && isJSONValidExpr(cons.Expr) { + changed = true + continue + } + constraints = append(constraints, cons) + } + n.Constraints = constraints + return changed, nil + default: + return false, nil + } +} + +type functionDefaultRule struct{} + +func (r functionDefaultRule) Name() string { return "function-default" } + +func (r functionDefaultRule) Apply(node ast.Node) (bool, error) { + col, ok := node.(*ast.ColumnDef) + if !ok { + return false, nil + } + return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { + return opt.Tp == ast.ColumnOptionDefaultValue && !keepDefaultExpr(col, opt.Expr) + }), nil +} + +type jsonGeneratedRule struct{} + +func (r jsonGeneratedRule) Name() string { return "json-generated" } + +func (r jsonGeneratedRule) Apply(node ast.Node) (bool, error) { + col, ok := node.(*ast.ColumnDef) + if !ok || !isJSONGenerated(col) { + return false, nil + } + return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { + return opt.Tp == ast.ColumnOptionGenerated + }), nil +} + +func rewriteDatabaseOptions(options []*ast.DatabaseOption) bool { + changed := false + for _, opt := range options { + switch opt.Tp { + case ast.DatabaseOptionCharset: + if strings.EqualFold(opt.Value, "latin1") { + opt.Value = "utf8mb4" + changed = true + } + case ast.DatabaseOptionCollate: + if collation, ok := mapCollation(opt.Value); ok { + opt.Value = collation + changed = true + } + } + } + return changed +} + +func rewriteTableOptions(options *[]*ast.TableOption) bool { + changed := false + needCollate := "" + hasCollate := false + for _, opt := range *options { + switch opt.Tp { + case ast.TableOptionCharset: + if strings.EqualFold(opt.StrValue, "latin1") { + opt.StrValue = "utf8mb4" + needCollate = "utf8mb4_0900_ai_ci" + changed = true + } + case ast.TableOptionCollate: + hasCollate = true + if collation, ok := mapCollation(opt.StrValue); ok { + opt.StrValue = collation + changed = true + } + } + } + if needCollate != "" && !hasCollate { + *options = append(*options, &ast.TableOption{Tp: ast.TableOptionCollate, StrValue: needCollate}) + changed = true + } + return changed +} + +func rewriteColumnCollations(col *ast.ColumnDef) bool { + changed := false + for _, opt := range col.Options { + if opt.Tp != ast.ColumnOptionCollate { + continue + } + if collation, ok := mapCollation(opt.StrValue); ok { + opt.StrValue = collation + changed = true + } + } + return changed +} + +func mapCollation(collation string) (string, bool) { + name := strings.ToLower(collation) + if name == "latin1_swedish_ci" { + return "utf8mb4_0900_ai_ci", true + } + if strings.HasPrefix(name, "utf8mb4_unicode_") { + return "utf8mb4_0900_ai_ci", true + } + return "", false +} + +func filterColumnOptions(col *ast.ColumnDef, drop func(*ast.ColumnOption) bool) bool { + options := col.Options[:0] + changed := false + for _, opt := range col.Options { + if drop(opt) { + changed = true + continue + } + options = append(options, opt) + } + col.Options = options + return changed +} + +func keepDefaultExpr(col *ast.ColumnDef, expr ast.ExprNode) bool { + expr = unwrapParentheses(expr) + if _, ok := expr.(ast.ValueExpr); ok { + return true + } + fn, ok := expr.(*ast.FuncCallExpr) + if !ok { + return false + } + return isTimeType(col.Tp.GetType()) && allowedTimeDefaultFuncs[fn.FnName.L] +} + +var allowedTimeDefaultFuncs = map[string]bool{ + "current_timestamp": true, + "current_date": true, + "current_time": true, + "now": true, + "localtime": true, + "localtimestamp": true, +} + +func isJSONValidExpr(expr ast.ExprNode) bool { + fn, ok := expr.(*ast.FuncCallExpr) + return ok && strings.EqualFold(fn.FnName.O, "json_valid") +} + +func isJSONGenerated(col *ast.ColumnDef) bool { + for _, opt := range col.Options { + if opt.Tp != ast.ColumnOptionGenerated { + continue + } + fn, ok := unwrapParentheses(opt.Expr).(*ast.FuncCallExpr) + if ok && strings.HasPrefix(fn.FnName.L, "json") { + return true + } + } + return false +} + +func unwrapParentheses(expr ast.ExprNode) ast.ExprNode { + for { + p, ok := expr.(*ast.ParenthesesExpr) + if !ok { + return expr + } + expr = p.Expr + } +} + +func isZeroTimeDefault(expr ast.ExprNode) bool { + valExpr, ok := expr.(ast.ValueExpr) + if !ok { + return false + } + switch v := valExpr.GetValue().(type) { + case tidbtypes.Time: + return v.IsZero() || v.InvalidZero() + case string: + return isZeroTimeString(v) + case []byte: + return isZeroTimeString(string(v)) + case int: + return v == 0 + case int64: + return v == 0 + case uint64: + return v == 0 + default: + return false + } +} + +func isZeroTimeString(value string) bool { + value = strings.TrimSpace(value) + if !strings.HasPrefix(value, "0000-00-00") { + return false + } + rest := strings.TrimSpace(strings.TrimPrefix(value, "0000-00-00")) + if rest == "" || rest == "00:00:00" { + return true + } + if !strings.HasPrefix(rest, "00:00:00.") { + return false + } + return strings.Trim(rest[len("00:00:00."):], "0") == "" +} + +func isTimeType(tp byte) bool { + switch tp { + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: + return true + default: + return false + } +} + +func isTextBlobOrJSON(ft *types.FieldType) bool { + return types.IsTypeBlob(ft.GetType()) || ft.GetType() == mysql.TypeJSON +} + +func isTextOrBlob(ft *types.FieldType) bool { + switch ft.GetType() { + case mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeBlob, mysql.TypeLongBlob: + return true + default: + return false + } +} + +func isVarcharOrChar(ft *types.FieldType) bool { + switch ft.GetType() { + case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString: + return true + default: + return false + } +} diff --git a/dm/syncer/ddl.go b/dm/syncer/ddl.go index 2731e8feb1..cae4dff7d5 100644 --- a/dm/syncer/ddl.go +++ b/dm/syncer/ddl.go @@ -37,6 +37,7 @@ import ( "github.com/pingcap/tiflow/dm/pkg/binlog/event" "github.com/pingcap/tiflow/dm/pkg/conn" tcontext "github.com/pingcap/tiflow/dm/pkg/context" + ddlrewriter "github.com/pingcap/tiflow/dm/pkg/ddl/rewriter" "github.com/pingcap/tiflow/dm/pkg/log" parserpkg "github.com/pingcap/tiflow/dm/pkg/parser" "github.com/pingcap/tiflow/dm/pkg/schema" @@ -81,6 +82,7 @@ type DDLWorker struct { idAndCollationMap map[int]string baList *filter.Filter foreignKeyChecksEnabled bool + ddlRewriter *ddlrewriter.Rewriter getTableInfo func(tctx *tcontext.Context, sourceTable, targetTable *filter.Table) (*model.TableInfo, error) getDBInfoFromDownstream func(tctx *tcontext.Context, sourceTable, targetTable *filter.Table) (*model.DBInfo, error) @@ -111,6 +113,7 @@ func NewDDLWorker(pLogger *log.Logger, syncer *Syncer) *DDLWorker { idAndCollationMap: syncer.idAndCollationMap, baList: syncer.baList, foreignKeyChecksEnabled: config.IsForeignKeyChecksEnabled(syncer.cfg.To.Session), + ddlRewriter: syncer.ddlRewriter, recordSkipSQLsLocation: syncer.recordSkipSQLsLocation, trackDDL: syncer.trackDDL, saveTablePoint: syncer.saveTablePoint, @@ -241,6 +244,7 @@ func (ddl *DDLWorker) HandleQueryEvent(ev *replication.QueryEvent, ec eventConte eventContext: &ec, ddlSchema: string(ev.Schema), originSQL: utils.TrimCtrlChars(originSQL), + ddlRewriter: ddl.ddlRewriter, splitDDLs: make([]string, 0), appliedDDLs: make([]string, 0), sourceTbls: make(map[string]map[string]struct{}), @@ -964,7 +968,20 @@ func parseOneStmt(qec *queryEventContext) (stmt ast.StmtNode, err error) { if len(stmts) == 0 { return nil, nil } - return stmts[0], nil + stmt = stmts[0] + if qec.ddlRewriter == nil { + return stmt, nil + } + changed, err := qec.ddlRewriter.RewriteStmt(stmt) + if err != nil { + return nil, terror.ErrRewriteSQL.Delegate(err, qec.originSQL) + } + if changed { + qec.tctx.L().Info("rewrite MariaDB DDL with AST compatibility rules", + zap.String("event", "query"), + zap.String("originSQL", qec.originSQL)) + } + return stmt, nil } // copy from https://github.com/pingcap/tidb/blob/fc4f8a1d8f5342cd01f78eb460e47d78d177ed20/ddl/column.go#L366 diff --git a/dm/syncer/ddl_test.go b/dm/syncer/ddl_test.go index 23d7490fa1..a3213b6ac0 100644 --- a/dm/syncer/ddl_test.go +++ b/dm/syncer/ddl_test.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tiflow/dm/config/dbconfig" "github.com/pingcap/tiflow/dm/pkg/conn" tcontext "github.com/pingcap/tiflow/dm/pkg/context" + ddlrewriter "github.com/pingcap/tiflow/dm/pkg/ddl/rewriter" "github.com/pingcap/tiflow/dm/pkg/log" parserpkg "github.com/pingcap/tiflow/dm/pkg/parser" "github.com/pingcap/tiflow/dm/pkg/terror" @@ -419,6 +420,24 @@ func (s *testDDLSuite) TestParseOneStmt(c *check.C) { } } +func TestParseOneStmtWithMariaDBASTRewriter(t *testing.T) { + tctx := tcontext.Background().WithLogger(log.With(zap.String("test", "TestParseOneStmtWithMariaDBASTRewriter"))) + qec := &queryEventContext{ + eventContext: &eventContext{tctx: tctx}, + ddlSchema: "test", + originSQL: "CREATE TABLE t(c VARCHAR(100) DEFAULT current_timestamp())", + p: parser.New(), + ddlRewriter: ddlrewriter.NewRewriterForFlavor("mariadb"), + } + + stmt, err := parseOneStmt(qec) + require.NoError(t, err) + sqls, err := parserpkg.SplitDDL(stmt, qec.ddlSchema) + require.NoError(t, err) + require.Len(t, sqls, 1) + require.NotContains(t, strings.ToLower(sqls[0]), "default") +} + func (s *testDDLSuite) TestResolveGeneratedColumnSQL(c *check.C) { testCases := []struct { sql string diff --git a/dm/syncer/syncer.go b/dm/syncer/syncer.go index 1013bf2a39..89a33d6f4b 100644 --- a/dm/syncer/syncer.go +++ b/dm/syncer/syncer.go @@ -50,6 +50,7 @@ import ( "github.com/pingcap/tiflow/dm/pkg/binlog/reader" "github.com/pingcap/tiflow/dm/pkg/conn" tcontext "github.com/pingcap/tiflow/dm/pkg/context" + ddlrewriter "github.com/pingcap/tiflow/dm/pkg/ddl/rewriter" fr "github.com/pingcap/tiflow/dm/pkg/func-rollback" "github.com/pingcap/tiflow/dm/pkg/gtid" "github.com/pingcap/tiflow/dm/pkg/ha" @@ -266,6 +267,7 @@ type Syncer struct { idAndCollationMap map[int]string ddlWorker *DDLWorker + ddlRewriter *ddlrewriter.Rewriter fetchBinlogLogger *zap.Logger unhandledEventLogger *zap.Logger } @@ -533,6 +535,7 @@ func (s *Syncer) Init(ctx context.Context) (err error) { } s.metricsProxies = metricProxies.CacheForOneTask(s.cfg.Name, s.cfg.WorkerName, s.cfg.SourceID) + s.ddlRewriter = ddlrewriter.NewRewriterForFlavor(s.cfg.Flavor) s.ddlWorker = NewDDLWorker(&s.tctx.Logger, s) return nil } @@ -2768,9 +2771,10 @@ func (s *Syncer) handleRowsEvent(ev *replication.RowsEvent, ec eventContext) (*f type queryEventContext struct { *eventContext - p *parser.Parser // used parser - ddlSchema string // used schema - originSQL string // before split + p *parser.Parser // used parser + ddlSchema string // used schema + originSQL string // before split + ddlRewriter *ddlrewriter.Rewriter // split multi-schema change DDL into multiple one schema change DDL due to TiDB's limitation splitDDLs []string // after split before online ddl appliedDDLs []string // after onlineDDL apply if onlineDDL != nil @@ -2959,6 +2963,7 @@ func (s *Syncer) trackOriginDDL(ev *replication.QueryEvent, ec eventContext) (ma eventContext: &ec, ddlSchema: string(ev.Schema), originSQL: utils.TrimCtrlChars(originSQL), + ddlRewriter: s.ddlRewriter, splitDDLs: make([]string, 0), appliedDDLs: make([]string, 0), sourceTbls: make(map[string]map[string]struct{}), From 9cc7e0188edb26ded5a662d1747eb602820565fe Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Tue, 16 Jun 2026 13:28:00 +0800 Subject: [PATCH 02/17] syncer(dm): simplify DDL rewriter API --- dm/pkg/ddl/rewriter/rewriter.go | 74 +++------------------------- dm/pkg/ddl/rewriter/rewriter_test.go | 56 +++++++-------------- dm/pkg/ddl/rewriter/rules.go | 20 +------- dm/syncer/ddl_test.go | 2 +- dm/syncer/syncer.go | 4 +- 5 files changed, 31 insertions(+), 125 deletions(-) diff --git a/dm/pkg/ddl/rewriter/rewriter.go b/dm/pkg/ddl/rewriter/rewriter.go index dda4bf5465..f8f0e18811 100644 --- a/dm/pkg/ddl/rewriter/rewriter.go +++ b/dm/pkg/ddl/rewriter/rewriter.go @@ -14,41 +14,26 @@ package rewriter import ( - "strings" - - "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/format" - _ "github.com/pingcap/tidb/pkg/types/parser_driver" // register parser driver ) -// Rule rewrites one AST node in place. -type Rule interface { - Name() string +type rule interface { Apply(ast.Node) (bool, error) } // Rewriter applies a fixed, ordered set of AST rules. type Rewriter struct { - rules []Rule + rules []rule } -// NewRewriter creates a rewriter from explicit rules. -func NewRewriter(rules ...Rule) *Rewriter { - return &Rewriter{rules: append([]Rule(nil), rules...)} -} - -// NewRewriterForFlavor returns a default rewriter only for MariaDB upstreams. -func NewRewriterForFlavor(flavor string) *Rewriter { - if !strings.EqualFold(flavor, "mariadb") { - return nil - } - return NewRewriter(defaultRules...) +// NewRewriter creates the default MariaDB compatibility AST rewriter. +func NewRewriter() *Rewriter { + return &Rewriter{rules: append([]rule(nil), defaultRules...)} } // RewriteStmt applies all rules to stmt in place. func (r *Rewriter) RewriteStmt(stmt ast.StmtNode) (bool, error) { - if r == nil || len(r.rules) == 0 || stmt == nil { + if len(r.rules) == 0 || stmt == nil { return false, nil } visitor := &rewriteVisitor{rules: r.rules} @@ -56,36 +41,8 @@ func (r *Rewriter) RewriteStmt(stmt ast.StmtNode) (bool, error) { return visitor.changed, visitor.err } -// RewriteSQL parses, rewrites, and restores SQL. It is mainly intended for unit tests -// and small call sites that do not already have a parsed AST. -func (r *Rewriter) RewriteSQL(sql string) (string, bool, error) { - p := parser.New() - stmts, _, err := p.Parse(sql, "", "") - if err != nil { - return "", false, err - } - - changed := false - for _, stmt := range stmts { - stmtChanged, err := r.RewriteStmt(stmt) - if err != nil { - return "", false, err - } - changed = changed || stmtChanged - } - if !changed { - return sql, false, nil - } - - out, err := restoreStatements(stmts) - if err != nil { - return "", false, err - } - return out, true, nil -} - type rewriteVisitor struct { - rules []Rule + rules []rule changed bool err error } @@ -111,20 +68,3 @@ func (v *rewriteVisitor) Leave(node ast.Node) (ast.Node, bool) { } return node, true } - -func restoreStatements(stmts []ast.StmtNode) (string, error) { - var out strings.Builder - for i, stmt := range stmts { - if i > 0 { - out.WriteString(";\n") - } - err := stmt.Restore(&format.RestoreCtx{ - Flags: format.DefaultRestoreFlags | format.RestoreTiDBSpecialComment | format.RestoreStringWithoutDefaultCharset, - In: &out, - }) - if err != nil { - return "", err - } - } - return out.String(), nil -} diff --git a/dm/pkg/ddl/rewriter/rewriter_test.go b/dm/pkg/ddl/rewriter/rewriter_test.go index 11f4092a00..5ed6360275 100644 --- a/dm/pkg/ddl/rewriter/rewriter_test.go +++ b/dm/pkg/ddl/rewriter/rewriter_test.go @@ -19,34 +19,26 @@ import ( "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" + _ "github.com/pingcap/tidb/pkg/types/parser_driver" // register parser driver "github.com/stretchr/testify/require" ) -func TestRewriteSQLRemovesFunctionDefaultOnVarchar(t *testing.T) { - rewriter := newDefaultRewriterForTest() - - out, changed, err := rewriter.RewriteSQL("CREATE TABLE t(t VARCHAR(100) DEFAULT current_timestamp());") - require.NoError(t, err) +func TestRewriteStmtRemovesFunctionDefaultOnVarchar(t *testing.T) { + stmt, changed := rewriteCreateTable(t, "CREATE TABLE t(t VARCHAR(100) DEFAULT current_timestamp());") require.True(t, changed) - require.NotContains(t, strings.ToLower(out), "default") - stmt := parseCreateTable(t, out) col := findColumn(stmt, "t") require.NotNil(t, col) require.False(t, hasColumnOption(col, ast.ColumnOptionDefaultValue)) } -func TestRewriteSQLKeepsTimeFunctionDefaultOnTimeColumn(t *testing.T) { - rewriter := newDefaultRewriterForTest() - - out, changed, err := rewriter.RewriteSQL("CREATE TABLE t(ts TIMESTAMP DEFAULT current_timestamp());") - require.NoError(t, err) +func TestRewriteStmtKeepsTimeFunctionDefaultOnTimeColumn(t *testing.T) { + stmt, changed := rewriteCreateTable(t, "CREATE TABLE t(ts TIMESTAMP DEFAULT current_timestamp());") require.False(t, changed) - require.Contains(t, strings.ToLower(out), "default current_timestamp") + require.True(t, hasColumnOption(findColumn(stmt, "ts"), ast.ColumnOptionDefaultValue)) } -func TestRewriteSQLDefaultRules(t *testing.T) { - rewriter := newDefaultRewriterForTest() +func TestRewriteStmtDefaultRules(t *testing.T) { input := `CREATE TABLE t ( id INT(11), txt TEXT DEFAULT 'x', @@ -59,11 +51,9 @@ func TestRewriteSQLDefaultRules(t *testing.T) { KEY idx_v (v) ) DEFAULT CHARSET=latin1 COLLATE=latin1_swedish_ci;` - out, changed, err := rewriter.RewriteSQL(input) - require.NoError(t, err) + stmt, changed := rewriteCreateTable(t, input) require.True(t, changed) - stmt := parseCreateTable(t, out) require.Equal(t, "utf8mb4", findTableOption(stmt, ast.TableOptionCharset)) require.Equal(t, "utf8mb4_0900_ai_ci", findTableOption(stmt, ast.TableOptionCollate)) require.Equal(t, -1, findColumn(stmt, "id").Tp.GetFlen()) @@ -81,32 +71,24 @@ func TestRewriteSQLDefaultRules(t *testing.T) { require.Equal(t, 768, idxV.Keys[0].Length) } -func TestRewriteSQLSkipsExpressionIndexPrefix(t *testing.T) { - rewriter := newDefaultRewriterForTest() - - _, _, err := rewriter.RewriteSQL("CREATE TABLE t(name VARCHAR(32), KEY idx_expr ((LOWER(name))));") - require.NoError(t, err) +func TestRewriteStmtSkipsExpressionIndexPrefix(t *testing.T) { + rewriteCreateTable(t, "CREATE TABLE t(name VARCHAR(32), KEY idx_expr ((LOWER(name))));") } -func TestRewriteSQLRemovesParenthesizedJSONGeneratedColumn(t *testing.T) { - rewriter := newDefaultRewriterForTest() - - out, changed, err := rewriter.RewriteSQL( +func TestRewriteStmtRemovesParenthesizedJSONGeneratedColumn(t *testing.T) { + stmt, changed := rewriteCreateTable(t, "CREATE TABLE t(j JSON, g JSON GENERATED ALWAYS AS ((JSON_EXTRACT(j, '$.a'))) VIRTUAL);", ) - require.NoError(t, err) require.True(t, changed) - require.False(t, hasColumnOption(findColumn(parseCreateTable(t, out), "g"), ast.ColumnOptionGenerated)) -} - -func TestNewRewriterForFlavor(t *testing.T) { - require.NotNil(t, NewRewriterForFlavor("mariadb")) - require.NotNil(t, NewRewriterForFlavor("MariaDB")) - require.Nil(t, NewRewriterForFlavor("mysql")) + require.False(t, hasColumnOption(findColumn(stmt, "g"), ast.ColumnOptionGenerated)) } -func newDefaultRewriterForTest() *Rewriter { - return NewRewriter(defaultRules...) +func rewriteCreateTable(t *testing.T, sql string) (*ast.CreateTableStmt, bool) { + t.Helper() + stmt := parseCreateTable(t, sql) + changed, err := NewRewriter().RewriteStmt(stmt) + require.NoError(t, err) + return stmt, changed } func parseCreateTable(t *testing.T, sql string) *ast.CreateTableStmt { diff --git a/dm/pkg/ddl/rewriter/rules.go b/dm/pkg/ddl/rewriter/rules.go index 064becd4b4..98e52bf13c 100644 --- a/dm/pkg/ddl/rewriter/rules.go +++ b/dm/pkg/ddl/rewriter/rules.go @@ -24,7 +24,7 @@ import ( const maxVarcharLen = 768 -var defaultRules = []Rule{ +var defaultRules = []rule{ collationRule{}, zeroTimestampRule{}, keyLengthRule{}, @@ -38,8 +38,6 @@ var defaultRules = []Rule{ type collationRule struct{} -func (r collationRule) Name() string { return "collation" } - func (r collationRule) Apply(node ast.Node) (bool, error) { switch n := node.(type) { case *ast.CreateDatabaseStmt: @@ -55,8 +53,6 @@ func (r collationRule) Apply(node ast.Node) (bool, error) { type zeroTimestampRule struct{} -func (r zeroTimestampRule) Name() string { return "zero-timestamp" } - func (r zeroTimestampRule) Apply(node ast.Node) (bool, error) { col, ok := node.(*ast.ColumnDef) if !ok || !isTimeType(col.Tp.GetType()) { @@ -69,8 +65,6 @@ func (r zeroTimestampRule) Apply(node ast.Node) (bool, error) { type keyLengthRule struct{} -func (r keyLengthRule) Name() string { return "key-length" } - func (r keyLengthRule) Apply(node ast.Node) (bool, error) { col, ok := node.(*ast.ColumnDef) if !ok || col.Tp.GetFlen() <= maxVarcharLen { @@ -87,8 +81,6 @@ func (r keyLengthRule) Apply(node ast.Node) (bool, error) { type indexPrefixRule struct{} -func (r indexPrefixRule) Name() string { return "index-prefix" } - func (r indexPrefixRule) Apply(node ast.Node) (bool, error) { stmt, ok := node.(*ast.CreateTableStmt) if !ok { @@ -132,8 +124,6 @@ func (r indexPrefixRule) Apply(node ast.Node) (bool, error) { type integerWidthRule struct{} -func (r integerWidthRule) Name() string { return "integer-width" } - func (r integerWidthRule) Apply(node ast.Node) (bool, error) { col, ok := node.(*ast.ColumnDef) if !ok || !mysql.IsIntegerType(col.Tp.GetType()) { @@ -148,8 +138,6 @@ func (r integerWidthRule) Apply(node ast.Node) (bool, error) { type textBlobDefaultRule struct{} -func (r textBlobDefaultRule) Name() string { return "text-blob-default" } - func (r textBlobDefaultRule) Apply(node ast.Node) (bool, error) { col, ok := node.(*ast.ColumnDef) if !ok || !isTextBlobOrJSON(col.Tp) { @@ -162,8 +150,6 @@ func (r textBlobDefaultRule) Apply(node ast.Node) (bool, error) { type jsonCheckRule struct{} -func (r jsonCheckRule) Name() string { return "json-check" } - func (r jsonCheckRule) Apply(node ast.Node) (bool, error) { switch n := node.(type) { case *ast.ColumnDef: @@ -189,8 +175,6 @@ func (r jsonCheckRule) Apply(node ast.Node) (bool, error) { type functionDefaultRule struct{} -func (r functionDefaultRule) Name() string { return "function-default" } - func (r functionDefaultRule) Apply(node ast.Node) (bool, error) { col, ok := node.(*ast.ColumnDef) if !ok { @@ -203,8 +187,6 @@ func (r functionDefaultRule) Apply(node ast.Node) (bool, error) { type jsonGeneratedRule struct{} -func (r jsonGeneratedRule) Name() string { return "json-generated" } - func (r jsonGeneratedRule) Apply(node ast.Node) (bool, error) { col, ok := node.(*ast.ColumnDef) if !ok || !isJSONGenerated(col) { diff --git a/dm/syncer/ddl_test.go b/dm/syncer/ddl_test.go index a3213b6ac0..79858bacb8 100644 --- a/dm/syncer/ddl_test.go +++ b/dm/syncer/ddl_test.go @@ -427,7 +427,7 @@ func TestParseOneStmtWithMariaDBASTRewriter(t *testing.T) { ddlSchema: "test", originSQL: "CREATE TABLE t(c VARCHAR(100) DEFAULT current_timestamp())", p: parser.New(), - ddlRewriter: ddlrewriter.NewRewriterForFlavor("mariadb"), + ddlRewriter: ddlrewriter.NewRewriter(), } stmt, err := parseOneStmt(qec) diff --git a/dm/syncer/syncer.go b/dm/syncer/syncer.go index 89a33d6f4b..454b66ed8b 100644 --- a/dm/syncer/syncer.go +++ b/dm/syncer/syncer.go @@ -535,7 +535,9 @@ func (s *Syncer) Init(ctx context.Context) (err error) { } s.metricsProxies = metricProxies.CacheForOneTask(s.cfg.Name, s.cfg.WorkerName, s.cfg.SourceID) - s.ddlRewriter = ddlrewriter.NewRewriterForFlavor(s.cfg.Flavor) + if strings.EqualFold(s.cfg.Flavor, mysql.MariaDBFlavor) { + s.ddlRewriter = ddlrewriter.NewRewriter() + } s.ddlWorker = NewDDLWorker(&s.tctx.Logger, s) return nil } From 54dda0bfa6b2c3f2a56efd096b59c05401e94389 Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Tue, 16 Jun 2026 13:38:13 +0800 Subject: [PATCH 03/17] syncer(dm): use fixed DDL rewrite rules --- dm/pkg/ddl/rewriter/rewriter.go | 16 +++----------- dm/pkg/ddl/rewriter/rewriter_test.go | 2 +- dm/pkg/ddl/rewriter/rules.go | 29 +++++--------------------- dm/syncer/ddl.go | 24 ++++++++++----------- dm/syncer/ddl_test.go | 11 +++++----- dm/syncer/syncer.go | 31 +++++++++++++--------------- 6 files changed, 40 insertions(+), 73 deletions(-) diff --git a/dm/pkg/ddl/rewriter/rewriter.go b/dm/pkg/ddl/rewriter/rewriter.go index f8f0e18811..a9e7de612a 100644 --- a/dm/pkg/ddl/rewriter/rewriter.go +++ b/dm/pkg/ddl/rewriter/rewriter.go @@ -21,22 +21,12 @@ type rule interface { Apply(ast.Node) (bool, error) } -// Rewriter applies a fixed, ordered set of AST rules. -type Rewriter struct { - rules []rule -} - -// NewRewriter creates the default MariaDB compatibility AST rewriter. -func NewRewriter() *Rewriter { - return &Rewriter{rules: append([]rule(nil), defaultRules...)} -} - // RewriteStmt applies all rules to stmt in place. -func (r *Rewriter) RewriteStmt(stmt ast.StmtNode) (bool, error) { - if len(r.rules) == 0 || stmt == nil { +func RewriteStmt(stmt ast.StmtNode) (bool, error) { + if stmt == nil { return false, nil } - visitor := &rewriteVisitor{rules: r.rules} + visitor := &rewriteVisitor{rules: defaultRules} stmt.Accept(visitor) return visitor.changed, visitor.err } diff --git a/dm/pkg/ddl/rewriter/rewriter_test.go b/dm/pkg/ddl/rewriter/rewriter_test.go index 5ed6360275..d179557150 100644 --- a/dm/pkg/ddl/rewriter/rewriter_test.go +++ b/dm/pkg/ddl/rewriter/rewriter_test.go @@ -86,7 +86,7 @@ func TestRewriteStmtRemovesParenthesizedJSONGeneratedColumn(t *testing.T) { func rewriteCreateTable(t *testing.T, sql string) (*ast.CreateTableStmt, bool) { t.Helper() stmt := parseCreateTable(t, sql) - changed, err := NewRewriter().RewriteStmt(stmt) + changed, err := RewriteStmt(stmt) require.NoError(t, err) return stmt, changed } diff --git a/dm/pkg/ddl/rewriter/rules.go b/dm/pkg/ddl/rewriter/rules.go index 98e52bf13c..82f04d6297 100644 --- a/dm/pkg/ddl/rewriter/rules.go +++ b/dm/pkg/ddl/rewriter/rules.go @@ -70,13 +70,11 @@ func (r keyLengthRule) Apply(node ast.Node) (bool, error) { if !ok || col.Tp.GetFlen() <= maxVarcharLen { return false, nil } - switch col.Tp.GetType() { - case mysql.TypeVarchar, mysql.TypeVarString: + if tidbtypes.IsTypeVarchar(col.Tp.GetType()) { col.Tp.SetFlen(maxVarcharLen) return true, nil - default: - return false, nil } + return false, nil } type indexPrefixRule struct{} @@ -110,10 +108,11 @@ func (r indexPrefixRule) Apply(node ast.Node) (bool, error) { continue } switch { - case isTextOrBlob(col.Tp): + case types.IsTypeBlob(col.Tp.GetType()): key.Length = 255 changed = true - case isVarcharOrChar(col.Tp) && col.Tp.GetFlen() > 0: + case (tidbtypes.IsTypeChar(col.Tp.GetType()) || tidbtypes.IsTypeVarchar(col.Tp.GetType())) && + col.Tp.GetFlen() > 0: key.Length = col.Tp.GetFlen() changed = true } @@ -381,21 +380,3 @@ func isTimeType(tp byte) bool { func isTextBlobOrJSON(ft *types.FieldType) bool { return types.IsTypeBlob(ft.GetType()) || ft.GetType() == mysql.TypeJSON } - -func isTextOrBlob(ft *types.FieldType) bool { - switch ft.GetType() { - case mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeBlob, mysql.TypeLongBlob: - return true - default: - return false - } -} - -func isVarcharOrChar(ft *types.FieldType) bool { - switch ft.GetType() { - case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString: - return true - default: - return false - } -} diff --git a/dm/syncer/ddl.go b/dm/syncer/ddl.go index cae4dff7d5..e4e5d30dac 100644 --- a/dm/syncer/ddl.go +++ b/dm/syncer/ddl.go @@ -82,7 +82,7 @@ type DDLWorker struct { idAndCollationMap map[int]string baList *filter.Filter foreignKeyChecksEnabled bool - ddlRewriter *ddlrewriter.Rewriter + enableDDLRewrite bool getTableInfo func(tctx *tcontext.Context, sourceTable, targetTable *filter.Table) (*model.TableInfo, error) getDBInfoFromDownstream func(tctx *tcontext.Context, sourceTable, targetTable *filter.Table) (*model.DBInfo, error) @@ -113,7 +113,7 @@ func NewDDLWorker(pLogger *log.Logger, syncer *Syncer) *DDLWorker { idAndCollationMap: syncer.idAndCollationMap, baList: syncer.baList, foreignKeyChecksEnabled: config.IsForeignKeyChecksEnabled(syncer.cfg.To.Session), - ddlRewriter: syncer.ddlRewriter, + enableDDLRewrite: syncer.enableDDLRewrite, recordSkipSQLsLocation: syncer.recordSkipSQLsLocation, trackDDL: syncer.trackDDL, saveTablePoint: syncer.saveTablePoint, @@ -241,14 +241,14 @@ func (ddl *DDLWorker) HandleQueryEvent(ev *replication.QueryEvent, ec eventConte } qec := &queryEventContext{ - eventContext: &ec, - ddlSchema: string(ev.Schema), - originSQL: utils.TrimCtrlChars(originSQL), - ddlRewriter: ddl.ddlRewriter, - splitDDLs: make([]string, 0), - appliedDDLs: make([]string, 0), - sourceTbls: make(map[string]map[string]struct{}), - eventStatusVars: ev.StatusVars, + eventContext: &ec, + ddlSchema: string(ev.Schema), + originSQL: utils.TrimCtrlChars(originSQL), + enableDDLRewrite: ddl.enableDDLRewrite, + splitDDLs: make([]string, 0), + appliedDDLs: make([]string, 0), + sourceTbls: make(map[string]map[string]struct{}), + eventStatusVars: ev.StatusVars, } defer func() { @@ -969,10 +969,10 @@ func parseOneStmt(qec *queryEventContext) (stmt ast.StmtNode, err error) { return nil, nil } stmt = stmts[0] - if qec.ddlRewriter == nil { + if !qec.enableDDLRewrite { return stmt, nil } - changed, err := qec.ddlRewriter.RewriteStmt(stmt) + changed, err := ddlrewriter.RewriteStmt(stmt) if err != nil { return nil, terror.ErrRewriteSQL.Delegate(err, qec.originSQL) } diff --git a/dm/syncer/ddl_test.go b/dm/syncer/ddl_test.go index 79858bacb8..5f771d232a 100644 --- a/dm/syncer/ddl_test.go +++ b/dm/syncer/ddl_test.go @@ -31,7 +31,6 @@ import ( "github.com/pingcap/tiflow/dm/config/dbconfig" "github.com/pingcap/tiflow/dm/pkg/conn" tcontext "github.com/pingcap/tiflow/dm/pkg/context" - ddlrewriter "github.com/pingcap/tiflow/dm/pkg/ddl/rewriter" "github.com/pingcap/tiflow/dm/pkg/log" parserpkg "github.com/pingcap/tiflow/dm/pkg/parser" "github.com/pingcap/tiflow/dm/pkg/terror" @@ -423,11 +422,11 @@ func (s *testDDLSuite) TestParseOneStmt(c *check.C) { func TestParseOneStmtWithMariaDBASTRewriter(t *testing.T) { tctx := tcontext.Background().WithLogger(log.With(zap.String("test", "TestParseOneStmtWithMariaDBASTRewriter"))) qec := &queryEventContext{ - eventContext: &eventContext{tctx: tctx}, - ddlSchema: "test", - originSQL: "CREATE TABLE t(c VARCHAR(100) DEFAULT current_timestamp())", - p: parser.New(), - ddlRewriter: ddlrewriter.NewRewriter(), + eventContext: &eventContext{tctx: tctx}, + ddlSchema: "test", + originSQL: "CREATE TABLE t(c VARCHAR(100) DEFAULT current_timestamp())", + p: parser.New(), + enableDDLRewrite: true, } stmt, err := parseOneStmt(qec) diff --git a/dm/syncer/syncer.go b/dm/syncer/syncer.go index 454b66ed8b..ce2aea2b07 100644 --- a/dm/syncer/syncer.go +++ b/dm/syncer/syncer.go @@ -50,7 +50,6 @@ import ( "github.com/pingcap/tiflow/dm/pkg/binlog/reader" "github.com/pingcap/tiflow/dm/pkg/conn" tcontext "github.com/pingcap/tiflow/dm/pkg/context" - ddlrewriter "github.com/pingcap/tiflow/dm/pkg/ddl/rewriter" fr "github.com/pingcap/tiflow/dm/pkg/func-rollback" "github.com/pingcap/tiflow/dm/pkg/gtid" "github.com/pingcap/tiflow/dm/pkg/ha" @@ -267,7 +266,7 @@ type Syncer struct { idAndCollationMap map[int]string ddlWorker *DDLWorker - ddlRewriter *ddlrewriter.Rewriter + enableDDLRewrite bool fetchBinlogLogger *zap.Logger unhandledEventLogger *zap.Logger } @@ -535,9 +534,7 @@ func (s *Syncer) Init(ctx context.Context) (err error) { } s.metricsProxies = metricProxies.CacheForOneTask(s.cfg.Name, s.cfg.WorkerName, s.cfg.SourceID) - if strings.EqualFold(s.cfg.Flavor, mysql.MariaDBFlavor) { - s.ddlRewriter = ddlrewriter.NewRewriter() - } + s.enableDDLRewrite = strings.EqualFold(s.cfg.Flavor, mysql.MariaDBFlavor) s.ddlWorker = NewDDLWorker(&s.tctx.Logger, s) return nil } @@ -2773,10 +2770,10 @@ func (s *Syncer) handleRowsEvent(ev *replication.RowsEvent, ec eventContext) (*f type queryEventContext struct { *eventContext - p *parser.Parser // used parser - ddlSchema string // used schema - originSQL string // before split - ddlRewriter *ddlrewriter.Rewriter + p *parser.Parser // used parser + ddlSchema string // used schema + originSQL string // before split + enableDDLRewrite bool // split multi-schema change DDL into multiple one schema change DDL due to TiDB's limitation splitDDLs []string // after split before online ddl appliedDDLs []string // after onlineDDL apply if onlineDDL != nil @@ -2962,14 +2959,14 @@ func (s *Syncer) trackOriginDDL(ev *replication.QueryEvent, ec eventContext) (ma } var err error qec := &queryEventContext{ - eventContext: &ec, - ddlSchema: string(ev.Schema), - originSQL: utils.TrimCtrlChars(originSQL), - ddlRewriter: s.ddlRewriter, - splitDDLs: make([]string, 0), - appliedDDLs: make([]string, 0), - sourceTbls: make(map[string]map[string]struct{}), - eventStatusVars: ev.StatusVars, + eventContext: &ec, + ddlSchema: string(ev.Schema), + originSQL: utils.TrimCtrlChars(originSQL), + enableDDLRewrite: s.enableDDLRewrite, + splitDDLs: make([]string, 0), + appliedDDLs: make([]string, 0), + sourceTbls: make(map[string]map[string]struct{}), + eventStatusVars: ev.StatusVars, } qec.p, err = event.GetParserForStatusVars(ev.StatusVars) if err != nil { From 039acfd8dbbe8b519d42ed62999269b452276f7f Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Tue, 16 Jun 2026 13:46:06 +0800 Subject: [PATCH 04/17] syncer(dm): document DDL rewrite rules --- dm/pkg/ddl/rewriter/rules.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/dm/pkg/ddl/rewriter/rules.go b/dm/pkg/ddl/rewriter/rules.go index 82f04d6297..5e3a0979f4 100644 --- a/dm/pkg/ddl/rewriter/rules.go +++ b/dm/pkg/ddl/rewriter/rules.go @@ -22,6 +22,9 @@ import ( tidbtypes "github.com/pingcap/tidb/pkg/types" ) +// maxVarcharLen keeps rewritten VARCHAR columns within TiDB's default maximum index length. +// TiDB limits an index to 3072 bytes, which is 768 characters with 4-byte UTF-8 encoding. +// See https://docs.pingcap.com/tidb/stable/tidb-limitations/#limitations-on-indexes. const maxVarcharLen = 768 var defaultRules = []rule{ @@ -36,6 +39,7 @@ var defaultRules = []rule{ jsonGeneratedRule{}, } +// collationRule rewrites common MariaDB-only or unsupported charsets/collations to TiDB-supported ones. type collationRule struct{} func (r collationRule) Apply(node ast.Node) (bool, error) { @@ -51,6 +55,7 @@ func (r collationRule) Apply(node ast.Node) (bool, error) { } } +// zeroTimestampRule removes zero date/time defaults that TiDB rejects. type zeroTimestampRule struct{} func (r zeroTimestampRule) Apply(node ast.Node) (bool, error) { @@ -63,6 +68,7 @@ func (r zeroTimestampRule) Apply(node ast.Node) (bool, error) { }), nil } +// keyLengthRule caps oversized VARCHAR columns so indexes on them can fit TiDB's index length limit. type keyLengthRule struct{} func (r keyLengthRule) Apply(node ast.Node) (bool, error) { @@ -77,6 +83,7 @@ func (r keyLengthRule) Apply(node ast.Node) (bool, error) { return false, nil } +// indexPrefixRule adds explicit prefix lengths for string and blob indexes that MariaDB allows to omit. type indexPrefixRule struct{} func (r indexPrefixRule) Apply(node ast.Node) (bool, error) { @@ -121,6 +128,7 @@ func (r indexPrefixRule) Apply(node ast.Node) (bool, error) { return changed, nil } +// integerWidthRule removes integer display widths that TiDB no longer preserves. type integerWidthRule struct{} func (r integerWidthRule) Apply(node ast.Node) (bool, error) { @@ -135,6 +143,7 @@ func (r integerWidthRule) Apply(node ast.Node) (bool, error) { return true, nil } +// textBlobDefaultRule removes defaults from TEXT, BLOB, and JSON columns that TiDB rejects. type textBlobDefaultRule struct{} func (r textBlobDefaultRule) Apply(node ast.Node) (bool, error) { @@ -147,6 +156,7 @@ func (r textBlobDefaultRule) Apply(node ast.Node) (bool, error) { }), nil } +// jsonCheckRule removes MariaDB JSON_VALID checks that duplicate JSON column validation. type jsonCheckRule struct{} func (r jsonCheckRule) Apply(node ast.Node) (bool, error) { @@ -172,6 +182,7 @@ func (r jsonCheckRule) Apply(node ast.Node) (bool, error) { } } +// functionDefaultRule removes function defaults from column types where TiDB does not allow them. type functionDefaultRule struct{} func (r functionDefaultRule) Apply(node ast.Node) (bool, error) { @@ -184,6 +195,7 @@ func (r functionDefaultRule) Apply(node ast.Node) (bool, error) { }), nil } +// jsonGeneratedRule removes generated expressions from JSON columns that TiDB rejects. type jsonGeneratedRule struct{} func (r jsonGeneratedRule) Apply(node ast.Node) (bool, error) { From 20192d55ccfef2752c541d1a81e79878d5c83a77 Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Tue, 16 Jun 2026 14:27:26 +0800 Subject: [PATCH 05/17] syncer(dm): move shared DDL rewrite helpers --- dm/pkg/ddl/rewriter/rules.go | 33 ----------------------- dm/pkg/ddl/rewriter/utils.go | 52 ++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 33 deletions(-) create mode 100644 dm/pkg/ddl/rewriter/utils.go diff --git a/dm/pkg/ddl/rewriter/rules.go b/dm/pkg/ddl/rewriter/rules.go index 5e3a0979f4..9501ebf02c 100644 --- a/dm/pkg/ddl/rewriter/rules.go +++ b/dm/pkg/ddl/rewriter/rules.go @@ -279,20 +279,6 @@ func mapCollation(collation string) (string, bool) { return "", false } -func filterColumnOptions(col *ast.ColumnDef, drop func(*ast.ColumnOption) bool) bool { - options := col.Options[:0] - changed := false - for _, opt := range col.Options { - if drop(opt) { - changed = true - continue - } - options = append(options, opt) - } - col.Options = options - return changed -} - func keepDefaultExpr(col *ast.ColumnDef, expr ast.ExprNode) bool { expr = unwrapParentheses(expr) if _, ok := expr.(ast.ValueExpr); ok { @@ -332,16 +318,6 @@ func isJSONGenerated(col *ast.ColumnDef) bool { return false } -func unwrapParentheses(expr ast.ExprNode) ast.ExprNode { - for { - p, ok := expr.(*ast.ParenthesesExpr) - if !ok { - return expr - } - expr = p.Expr - } -} - func isZeroTimeDefault(expr ast.ExprNode) bool { valExpr, ok := expr.(ast.ValueExpr) if !ok { @@ -380,15 +356,6 @@ func isZeroTimeString(value string) bool { return strings.Trim(rest[len("00:00:00."):], "0") == "" } -func isTimeType(tp byte) bool { - switch tp { - case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: - return true - default: - return false - } -} - func isTextBlobOrJSON(ft *types.FieldType) bool { return types.IsTypeBlob(ft.GetType()) || ft.GetType() == mysql.TypeJSON } diff --git a/dm/pkg/ddl/rewriter/utils.go b/dm/pkg/ddl/rewriter/utils.go new file mode 100644 index 0000000000..6086f0b9f3 --- /dev/null +++ b/dm/pkg/ddl/rewriter/utils.go @@ -0,0 +1,52 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package rewriter + +import ( + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/mysql" +) + +func filterColumnOptions(col *ast.ColumnDef, drop func(*ast.ColumnOption) bool) bool { + options := col.Options[:0] + changed := false + for _, opt := range col.Options { + if drop(opt) { + changed = true + continue + } + options = append(options, opt) + } + col.Options = options + return changed +} + +func isTimeType(tp byte) bool { + switch tp { + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: + return true + default: + return false + } +} + +func unwrapParentheses(expr ast.ExprNode) ast.ExprNode { + for { + p, ok := expr.(*ast.ParenthesesExpr) + if !ok { + return expr + } + expr = p.Expr + } +} From 1b1bd0061bc48186ae92900d9a8791c3eb8521ee Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Tue, 16 Jun 2026 16:01:27 +0800 Subject: [PATCH 06/17] syncer(dm): move strict collation rewrite rule --- dm/pkg/ddl/rewriter/rewriter.go | 34 +++- dm/pkg/ddl/rewriter/rewriter_test.go | 2 +- dm/pkg/ddl/rewriter/rules.go | 2 +- dm/pkg/ddl/rewriter/strict_collation_rule.go | 188 +++++++++++++++++++ dm/syncer/ddl.go | 115 ++---------- dm/syncer/ddl_test.go | 37 ++-- 6 files changed, 257 insertions(+), 121 deletions(-) create mode 100644 dm/pkg/ddl/rewriter/strict_collation_rule.go diff --git a/dm/pkg/ddl/rewriter/rewriter.go b/dm/pkg/ddl/rewriter/rewriter.go index a9e7de612a..b034ac6168 100644 --- a/dm/pkg/ddl/rewriter/rewriter.go +++ b/dm/pkg/ddl/rewriter/rewriter.go @@ -21,12 +21,38 @@ type rule interface { Apply(ast.Node) (bool, error) } -// RewriteStmt applies all rules to stmt in place. -func RewriteStmt(stmt ast.StmtNode) (bool, error) { - if stmt == nil { +type rewriteOptions struct { + rules []rule +} + +// Option configures the AST rules used by RewriteStmt. +type Option interface { + apply(*rewriteOptions) +} + +type optionFunc func(*rewriteOptions) + +func (f optionFunc) apply(options *rewriteOptions) { + f(options) +} + +// WithMariaDBCompatibility enables MariaDB compatibility AST rewrite rules. +func WithMariaDBCompatibility() Option { + return optionFunc(func(options *rewriteOptions) { + options.rules = append(options.rules, mariaDBCompatibilityRules...) + }) +} + +// RewriteStmt applies enabled rules to stmt in place. +func RewriteStmt(stmt ast.StmtNode, opts ...Option) (bool, error) { + options := rewriteOptions{} + for _, opt := range opts { + opt.apply(&options) + } + if stmt == nil || len(options.rules) == 0 { return false, nil } - visitor := &rewriteVisitor{rules: defaultRules} + visitor := &rewriteVisitor{rules: options.rules} stmt.Accept(visitor) return visitor.changed, visitor.err } diff --git a/dm/pkg/ddl/rewriter/rewriter_test.go b/dm/pkg/ddl/rewriter/rewriter_test.go index d179557150..1e606c314a 100644 --- a/dm/pkg/ddl/rewriter/rewriter_test.go +++ b/dm/pkg/ddl/rewriter/rewriter_test.go @@ -86,7 +86,7 @@ func TestRewriteStmtRemovesParenthesizedJSONGeneratedColumn(t *testing.T) { func rewriteCreateTable(t *testing.T, sql string) (*ast.CreateTableStmt, bool) { t.Helper() stmt := parseCreateTable(t, sql) - changed, err := RewriteStmt(stmt) + changed, err := RewriteStmt(stmt, WithMariaDBCompatibility()) require.NoError(t, err) return stmt, changed } diff --git a/dm/pkg/ddl/rewriter/rules.go b/dm/pkg/ddl/rewriter/rules.go index 9501ebf02c..de7f5390f5 100644 --- a/dm/pkg/ddl/rewriter/rules.go +++ b/dm/pkg/ddl/rewriter/rules.go @@ -27,7 +27,7 @@ import ( // See https://docs.pingcap.com/tidb/stable/tidb-limitations/#limitations-on-indexes. const maxVarcharLen = 768 -var defaultRules = []rule{ +var mariaDBCompatibilityRules = []rule{ collationRule{}, zeroTimestampRule{}, keyLengthRule{}, diff --git a/dm/pkg/ddl/rewriter/strict_collation_rule.go b/dm/pkg/ddl/rewriter/strict_collation_rule.go new file mode 100644 index 0000000000..b0fb53ad3a --- /dev/null +++ b/dm/pkg/ddl/rewriter/strict_collation_rule.go @@ -0,0 +1,188 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package rewriter + +import ( + "strings" + + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tiflow/dm/pkg/binlog/event" + "github.com/pingcap/tiflow/dm/pkg/log" + "go.uber.org/zap" +) + +// WithStrictCollation enables collation-compatible=strict AST rewrite rules. +func WithStrictCollation( + statusVars []byte, + charsetAndDefaultCollation map[string]string, + idAndCollationMap map[int]string, + originSQL string, + logger log.Logger, +) Option { + return optionFunc(func(options *rewriteOptions) { + options.rules = append(options.rules, strictCollationRule{ + statusVars: statusVars, + charsetAndDefaultCollation: charsetAndDefaultCollation, + idAndCollationMap: idAndCollationMap, + originSQL: originSQL, + logger: logger, + }) + }) +} + +// strictCollationRule adds explicit collations from upstream INFORMATION_SCHEMA.COLLATIONS. +type strictCollationRule struct { + statusVars []byte + charsetAndDefaultCollation map[string]string + idAndCollationMap map[int]string + originSQL string + logger log.Logger +} + +func (r strictCollationRule) Apply(node ast.Node) (bool, error) { + switch createStmt := node.(type) { + case *ast.CreateTableStmt: + return r.rewriteCreateTable(createStmt), nil + case *ast.CreateDatabaseStmt: + return r.rewriteCreateDatabase(createStmt), nil + default: + return false, nil + } +} + +func (r strictCollationRule) rewriteCreateTable(createStmt *ast.CreateTableStmt) bool { + if createStmt.ReferTable != nil { + return false + } + + changed := r.rewriteColumnCollations(createStmt) + var justCharset string + for _, tableOption := range createStmt.Options { + if tableOption.Tp == ast.TableOptionCollate { + return changed + } + if tableOption.Tp == ast.TableOptionCharset { + justCharset = tableOption.StrValue + } + } + if justCharset == "" { + r.warn("detect create table risk which use implicit charset and collation") + return changed + } + + collation, ok := r.charsetAndDefaultCollation[strings.ToLower(justCharset)] + if !ok { + r.warn("not found charset default collation.", zap.String("charset", strings.ToLower(justCharset))) + return changed + } + r.info( + "detect create table risk which use explicit charset and implicit collation, we will add collation by INFORMATION_SCHEMA.COLLATIONS", + zap.String("collation", collation), + ) + createStmt.Options = append(createStmt.Options, &ast.TableOption{Tp: ast.TableOptionCollate, StrValue: collation}) + return true +} + +func (r strictCollationRule) rewriteCreateDatabase(createStmt *ast.CreateDatabaseStmt) bool { + var justCharset string + for _, createOption := range createStmt.Options { + if createOption.Tp == ast.DatabaseOptionCollate { + return false + } + if createOption.Tp == ast.DatabaseOptionCharset { + justCharset = createOption.Value + } + } + + var collation string + if justCharset != "" { + var ok bool + collation, ok = r.charsetAndDefaultCollation[strings.ToLower(justCharset)] + if !ok { + r.warn("not found charset default collation.", zap.String("charset", strings.ToLower(justCharset))) + return false + } + r.info( + "detect create database risk which use explicit charset and implicit collation, we will add collation by INFORMATION_SCHEMA.COLLATIONS", + zap.String("collation", collation), + ) + } else { + var err error + collation, err = event.GetServerCollationByStatusVars(r.statusVars, r.idAndCollationMap) + if err != nil { + r.error("can not get charset server collation from binlog statusVars.", zap.Error(err)) + } + if collation == "" { + r.error("get server collation from binlog statusVars is nil.", zap.Error(err)) + return false + } + r.info( + "detect create database risk which use implicit charset and collation, we will add collation by binlog status_vars", + zap.String("collation", collation), + ) + } + createStmt.Options = append(createStmt.Options, &ast.DatabaseOption{Tp: ast.DatabaseOptionCollate, Value: collation}) + return true +} + +func (r strictCollationRule) rewriteColumnCollations(createStmt *ast.CreateTableStmt) bool { + changed := false +ColumnLoop: + for _, col := range createStmt.Cols { + for _, options := range col.Options { + if options.Tp == ast.ColumnOptionCollate { + continue ColumnLoop + } + } + fieldType := col.Tp + if fieldType.GetCollate() != "" || fieldType.GetCharset() == "" { + continue + } + + collation, ok := r.charsetAndDefaultCollation[strings.ToLower(fieldType.GetCharset())] + if !ok { + r.warn( + "not found charset default collation for column.", + zap.String("table", createStmt.Table.Name.String()), + zap.String("column", col.Name.String()), + zap.String("charset", strings.ToLower(fieldType.GetCharset())), + ) + continue + } + col.Options = append(col.Options, &ast.ColumnOption{Tp: ast.ColumnOptionCollate, StrValue: collation}) + changed = true + } + return changed +} + +func (r strictCollationRule) info(msg string, fields ...zap.Field) { + if r.logger.Logger == nil { + return + } + r.logger.Info(msg, append([]zap.Field{zap.String("originSQL", r.originSQL)}, fields...)...) +} + +func (r strictCollationRule) warn(msg string, fields ...zap.Field) { + if r.logger.Logger == nil { + return + } + r.logger.Warn(msg, append([]zap.Field{zap.String("originSQL", r.originSQL)}, fields...)...) +} + +func (r strictCollationRule) error(msg string, fields ...zap.Field) { + if r.logger.Logger == nil { + return + } + r.logger.Error(msg, append(fields, zap.String("originSQL", r.originSQL))...) +} diff --git a/dm/syncer/ddl.go b/dm/syncer/ddl.go index e4e5d30dac..77f3ef6de5 100644 --- a/dm/syncer/ddl.go +++ b/dm/syncer/ddl.go @@ -972,7 +972,7 @@ func parseOneStmt(qec *queryEventContext) (stmt ast.StmtNode, err error) { if !qec.enableDDLRewrite { return stmt, nil } - changed, err := ddlrewriter.RewriteStmt(stmt) + changed, err := ddlrewriter.RewriteStmt(stmt, ddlrewriter.WithMariaDBCompatibility()) if err != nil { return nil, terror.ErrRewriteSQL.Delegate(err, qec.originSQL) } @@ -1339,9 +1339,20 @@ func (ddl *DDLWorker) genDDLInfo(qec *queryEventContext, sql string) (*ddlInfo, targetTables: targetTables, } - // "strict" will adjust collation if ddl.collationCompatible == config.StrictCollationCompatible { - ddl.adjustCollation(ddlInfo, qec.eventStatusVars, ddl.charsetAndDefaultCollation, ddl.idAndCollationMap) + _, err := ddlrewriter.RewriteStmt( + ddlInfo.stmtCache, + ddlrewriter.WithStrictCollation( + qec.eventStatusVars, + ddl.charsetAndDefaultCollation, + ddl.idAndCollationMap, + ddlInfo.originDDL, + ddl.logger, + ), + ) + if err != nil { + return nil, terror.ErrRewriteSQL.Delegate(err, ddlInfo.originDDL) + } } routedDDL, err := parserpkg.RenameDDLTable(ddlInfo.stmtCache, ddlInfo.targetTables) @@ -1432,104 +1443,6 @@ func (ddl *Pessimist) clearOnlineDDL(tctx *tcontext.Context, targetTable *filter return nil } -// adjustCollation adds collation for create database and check create table. -func (ddl *DDLWorker) adjustCollation(ddlInfo *ddlInfo, statusVars []byte, charsetAndDefaultCollationMap map[string]string, idAndCollationMap map[int]string) { - switch createStmt := ddlInfo.stmtCache.(type) { - case *ast.CreateTableStmt: - if createStmt.ReferTable != nil { - return - } - ddl.adjustColumnsCollation(createStmt, charsetAndDefaultCollationMap) - var justCharset string - for _, tableOption := range createStmt.Options { - // already have 'Collation' - if tableOption.Tp == ast.TableOptionCollate { - return - } - if tableOption.Tp == ast.TableOptionCharset { - justCharset = tableOption.StrValue - } - } - if justCharset == "" { - ddl.logger.Warn("detect create table risk which use implicit charset and collation", zap.String("originSQL", ddlInfo.originDDL)) - return - } - // just has charset, can add collation by charset and default collation map - collation, ok := charsetAndDefaultCollationMap[strings.ToLower(justCharset)] - if !ok { - ddl.logger.Warn("not found charset default collation.", zap.String("originSQL", ddlInfo.originDDL), zap.String("charset", strings.ToLower(justCharset))) - return - } - ddl.logger.Info("detect create table risk which use explicit charset and implicit collation, we will add collation by INFORMATION_SCHEMA.COLLATIONS", zap.String("originSQL", ddlInfo.originDDL), zap.String("collation", collation)) - createStmt.Options = append(createStmt.Options, &ast.TableOption{Tp: ast.TableOptionCollate, StrValue: collation}) - - case *ast.CreateDatabaseStmt: - var justCharset, collation string - var ok bool - var err error - for _, createOption := range createStmt.Options { - // already have 'Collation' - if createOption.Tp == ast.DatabaseOptionCollate { - return - } - if createOption.Tp == ast.DatabaseOptionCharset { - justCharset = createOption.Value - } - } - - // just has charset, can add collation by charset and default collation map - if justCharset != "" { - collation, ok = charsetAndDefaultCollationMap[strings.ToLower(justCharset)] - if !ok { - ddl.logger.Warn("not found charset default collation.", zap.String("originSQL", ddlInfo.originDDL), zap.String("charset", strings.ToLower(justCharset))) - return - } - ddl.logger.Info("detect create database risk which use explicit charset and implicit collation, we will add collation by INFORMATION_SCHEMA.COLLATIONS", zap.String("originSQL", ddlInfo.originDDL), zap.String("collation", collation)) - } else { - // has no charset and collation - // add collation by server collation from binlog statusVars - collation, err = event.GetServerCollationByStatusVars(statusVars, idAndCollationMap) - if err != nil { - ddl.logger.Error("can not get charset server collation from binlog statusVars.", zap.Error(err), zap.String("originSQL", ddlInfo.originDDL)) - } - if collation == "" { - ddl.logger.Error("get server collation from binlog statusVars is nil.", zap.Error(err), zap.String("originSQL", ddlInfo.originDDL)) - return - } - // add collation - ddl.logger.Info("detect create database risk which use implicit charset and collation, we will add collation by binlog status_vars", zap.String("originSQL", ddlInfo.originDDL), zap.String("collation", collation)) - } - createStmt.Options = append(createStmt.Options, &ast.DatabaseOption{Tp: ast.DatabaseOptionCollate, Value: collation}) - } -} - -// adjustColumnsCollation adds column's collation. -func (ddl *DDLWorker) adjustColumnsCollation(createStmt *ast.CreateTableStmt, charsetAndDefaultCollationMap map[string]string) { -ColumnLoop: - for _, col := range createStmt.Cols { - for _, options := range col.Options { - // already have 'Collation' - if options.Tp == ast.ColumnOptionCollate { - continue ColumnLoop - } - } - fieldType := col.Tp - // already have 'Collation' - if fieldType.GetCollate() != "" { - continue - } - if fieldType.GetCharset() != "" { - // just have charset - collation, ok := charsetAndDefaultCollationMap[strings.ToLower(fieldType.GetCharset())] - if !ok { - ddl.logger.Warn("not found charset default collation for column.", zap.String("table", createStmt.Table.Name.String()), zap.String("column", col.Name.String()), zap.String("charset", strings.ToLower(fieldType.GetCharset()))) - continue - } - col.Options = append(col.Options, &ast.ColumnOption{Tp: ast.ColumnOptionCollate, StrValue: collation}) - } - } -} - type ddlInfo struct { originDDL string routedDDL string diff --git a/dm/syncer/ddl_test.go b/dm/syncer/ddl_test.go index 5f771d232a..7ced1f8bc8 100644 --- a/dm/syncer/ddl_test.go +++ b/dm/syncer/ddl_test.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tiflow/dm/config/dbconfig" "github.com/pingcap/tiflow/dm/pkg/conn" tcontext "github.com/pingcap/tiflow/dm/pkg/context" + ddlrewriter "github.com/pingcap/tiflow/dm/pkg/ddl/rewriter" "github.com/pingcap/tiflow/dm/pkg/log" parserpkg "github.com/pingcap/tiflow/dm/pkg/parser" "github.com/pingcap/tiflow/dm/pkg/terror" @@ -747,11 +748,6 @@ func (s *testDDLSuite) TestAdjustDatabaseCollation(c *check.C) { } tctx := tcontext.Background().WithLogger(log.With(zap.String("test", "TestAdjustTableCollation"))) - syncer := NewSyncer(&config.SubTaskConfig{ - Flavor: mysql.MySQLFlavor, - CollationCompatible: config.StrictCollationCompatible, - }, nil, nil) - syncer.tctx = tctx p := parser.New() tab := &filter.Table{ Schema: "test", @@ -760,7 +756,6 @@ func (s *testDDLSuite) TestAdjustDatabaseCollation(c *check.C) { charsetAndDefaultCollationMap := map[string]string{"utf8mb4": "utf8mb4_general_ci"} idAndCollationMap := map[int]string{46: "utf8mb4_bin", 277: "utf8mb4_vi_0900_ai_ci"} - ddlWorker := NewDDLWorker(&tctx.Logger, syncer) for i, statusVars := range statusVarsArray { for j, sql := range sqls { ddlInfo := &ddlInfo{ @@ -773,7 +768,17 @@ func (s *testDDLSuite) TestAdjustDatabaseCollation(c *check.C) { c.Assert(err, check.IsNil) c.Assert(stmt, check.NotNil) ddlInfo.stmtCache = stmt - ddlWorker.adjustCollation(ddlInfo, statusVars, charsetAndDefaultCollationMap, idAndCollationMap) + _, err = ddlrewriter.RewriteStmt( + ddlInfo.stmtCache, + ddlrewriter.WithStrictCollation( + statusVars, + charsetAndDefaultCollationMap, + idAndCollationMap, + ddlInfo.originDDL, + tctx.Logger, + ), + ) + c.Assert(err, check.IsNil) routedDDL, err := parserpkg.RenameDDLTable(ddlInfo.stmtCache, ddlInfo.targetTables) c.Assert(err, check.IsNil) c.Assert(routedDDL, check.Equals, expectedSQLs[i][j]) @@ -827,11 +832,6 @@ func TestAdjustCollation(t *testing.T) { } tctx := tcontext.Background().WithLogger(log.With(zap.String("test", "TestAdjustTableCollation"))) - syncer := NewSyncer(&config.SubTaskConfig{ - Flavor: mysql.MySQLFlavor, - CollationCompatible: config.StrictCollationCompatible, - }, nil, nil) - syncer.tctx = tctx p := parser.New() tab := &filter.Table{ Schema: "test", @@ -840,7 +840,6 @@ func TestAdjustCollation(t *testing.T) { statusVars := []byte{4, 0, 0, 0, 0, 46, 0} charsetAndDefaultCollationMap := map[string]string{"utf8mb4": "utf8mb4_general_ci", "latin1": "latin1_swedish_ci"} idAndCollationMap := map[int]string{46: "utf8mb4_bin"} - ddlWorker := NewDDLWorker(&tctx.Logger, syncer) for i, sql := range sqls { ddlInfo := &ddlInfo{ originDDL: sql, @@ -852,7 +851,17 @@ func TestAdjustCollation(t *testing.T) { require.NoError(t, err) require.NotNil(t, stmt) ddlInfo.stmtCache = stmt - ddlWorker.adjustCollation(ddlInfo, statusVars, charsetAndDefaultCollationMap, idAndCollationMap) + _, err = ddlrewriter.RewriteStmt( + ddlInfo.stmtCache, + ddlrewriter.WithStrictCollation( + statusVars, + charsetAndDefaultCollationMap, + idAndCollationMap, + ddlInfo.originDDL, + tctx.Logger, + ), + ) + require.NoError(t, err) routedDDL, err := parserpkg.RenameDDLTable(ddlInfo.stmtCache, ddlInfo.targetTables) require.NoError(t, err) require.Equal(t, expectedSQLs[i], routedDDL) From b4ad61ba4cf37593ec8e5b1a419fc023cf056c07 Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Tue, 16 Jun 2026 16:20:52 +0800 Subject: [PATCH 07/17] syncer(dm): simplify strict collation logging --- dm/pkg/ddl/rewriter/strict_collation_rule.go | 55 +++++++++----------- 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/dm/pkg/ddl/rewriter/strict_collation_rule.go b/dm/pkg/ddl/rewriter/strict_collation_rule.go index b0fb53ad3a..c55bc73dc5 100644 --- a/dm/pkg/ddl/rewriter/strict_collation_rule.go +++ b/dm/pkg/ddl/rewriter/strict_collation_rule.go @@ -31,6 +31,9 @@ func WithStrictCollation( logger log.Logger, ) Option { return optionFunc(func(options *rewriteOptions) { + if logger.Logger == nil { + logger = log.L() + } options.rules = append(options.rules, strictCollationRule{ statusVars: statusVars, charsetAndDefaultCollation: charsetAndDefaultCollation, @@ -77,17 +80,21 @@ func (r strictCollationRule) rewriteCreateTable(createStmt *ast.CreateTableStmt) } } if justCharset == "" { - r.warn("detect create table risk which use implicit charset and collation") + r.logger.Warn("detect create table risk which use implicit charset and collation", + zap.String("originSQL", r.originSQL)) return changed } collation, ok := r.charsetAndDefaultCollation[strings.ToLower(justCharset)] if !ok { - r.warn("not found charset default collation.", zap.String("charset", strings.ToLower(justCharset))) + r.logger.Warn("not found charset default collation.", + zap.String("originSQL", r.originSQL), + zap.String("charset", strings.ToLower(justCharset))) return changed } - r.info( + r.logger.Info( "detect create table risk which use explicit charset and implicit collation, we will add collation by INFORMATION_SCHEMA.COLLATIONS", + zap.String("originSQL", r.originSQL), zap.String("collation", collation), ) createStmt.Options = append(createStmt.Options, &ast.TableOption{Tp: ast.TableOptionCollate, StrValue: collation}) @@ -110,25 +117,33 @@ func (r strictCollationRule) rewriteCreateDatabase(createStmt *ast.CreateDatabas var ok bool collation, ok = r.charsetAndDefaultCollation[strings.ToLower(justCharset)] if !ok { - r.warn("not found charset default collation.", zap.String("charset", strings.ToLower(justCharset))) + r.logger.Warn("not found charset default collation.", + zap.String("originSQL", r.originSQL), + zap.String("charset", strings.ToLower(justCharset))) return false } - r.info( + r.logger.Info( "detect create database risk which use explicit charset and implicit collation, we will add collation by INFORMATION_SCHEMA.COLLATIONS", + zap.String("originSQL", r.originSQL), zap.String("collation", collation), ) } else { var err error collation, err = event.GetServerCollationByStatusVars(r.statusVars, r.idAndCollationMap) if err != nil { - r.error("can not get charset server collation from binlog statusVars.", zap.Error(err)) + r.logger.Error("can not get charset server collation from binlog statusVars.", + zap.Error(err), + zap.String("originSQL", r.originSQL)) } if collation == "" { - r.error("get server collation from binlog statusVars is nil.", zap.Error(err)) + r.logger.Error("get server collation from binlog statusVars is nil.", + zap.Error(err), + zap.String("originSQL", r.originSQL)) return false } - r.info( + r.logger.Info( "detect create database risk which use implicit charset and collation, we will add collation by binlog status_vars", + zap.String("originSQL", r.originSQL), zap.String("collation", collation), ) } @@ -152,8 +167,9 @@ ColumnLoop: collation, ok := r.charsetAndDefaultCollation[strings.ToLower(fieldType.GetCharset())] if !ok { - r.warn( + r.logger.Warn( "not found charset default collation for column.", + zap.String("originSQL", r.originSQL), zap.String("table", createStmt.Table.Name.String()), zap.String("column", col.Name.String()), zap.String("charset", strings.ToLower(fieldType.GetCharset())), @@ -165,24 +181,3 @@ ColumnLoop: } return changed } - -func (r strictCollationRule) info(msg string, fields ...zap.Field) { - if r.logger.Logger == nil { - return - } - r.logger.Info(msg, append([]zap.Field{zap.String("originSQL", r.originSQL)}, fields...)...) -} - -func (r strictCollationRule) warn(msg string, fields ...zap.Field) { - if r.logger.Logger == nil { - return - } - r.logger.Warn(msg, append([]zap.Field{zap.String("originSQL", r.originSQL)}, fields...)...) -} - -func (r strictCollationRule) error(msg string, fields ...zap.Field) { - if r.logger.Logger == nil { - return - } - r.logger.Error(msg, append(fields, zap.String("originSQL", r.originSQL))...) -} From e6449c85fa51904efdf84b857ddf7097e305d5cb Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Tue, 16 Jun 2026 16:24:11 +0800 Subject: [PATCH 08/17] syncer(dm): rename MariaDB DDL rewrite rules --- dm/pkg/ddl/rewriter/{rules.go => mariadb_rules.go} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename dm/pkg/ddl/rewriter/{rules.go => mariadb_rules.go} (100%) diff --git a/dm/pkg/ddl/rewriter/rules.go b/dm/pkg/ddl/rewriter/mariadb_rules.go similarity index 100% rename from dm/pkg/ddl/rewriter/rules.go rename to dm/pkg/ddl/rewriter/mariadb_rules.go From ea849091fe5834ae692265db49e12b132d85842e Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Tue, 16 Jun 2026 16:44:24 +0800 Subject: [PATCH 09/17] syncer(dm): document strict collation rewrite behavior --- dm/syncer/ddl.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dm/syncer/ddl.go b/dm/syncer/ddl.go index 77f3ef6de5..765874bc39 100644 --- a/dm/syncer/ddl.go +++ b/dm/syncer/ddl.go @@ -1340,6 +1340,9 @@ func (ddl *DDLWorker) genDDLInfo(qec *queryEventContext, sql string) (*ddlInfo, } if ddl.collationCompatible == config.StrictCollationCompatible { + // Keep the original best-effort strict collation behavior: the rule logs and skips when + // upstream collation metadata is incomplete. The error check is only for the common + // rewriter interface; the strict collation rule itself does not fail the DDL. _, err := ddlrewriter.RewriteStmt( ddlInfo.stmtCache, ddlrewriter.WithStrictCollation( From 8dc8c38cf62b83a8ccd2492c5d1d270f7563c872 Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Tue, 16 Jun 2026 16:49:39 +0800 Subject: [PATCH 10/17] syncer(dm): document DDL rewriter scope --- dm/pkg/ddl/rewriter/rewriter.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dm/pkg/ddl/rewriter/rewriter.go b/dm/pkg/ddl/rewriter/rewriter.go index b034ac6168..759f373dc0 100644 --- a/dm/pkg/ddl/rewriter/rewriter.go +++ b/dm/pkg/ddl/rewriter/rewriter.go @@ -44,6 +44,8 @@ func WithMariaDBCompatibility() Option { } // RewriteStmt applies enabled rules to stmt in place. +// It is a best-effort compatibility layer for parsed AST nodes; parser failures and +// unsupported DDL validation failures are intentionally left to the normal DM flow. func RewriteStmt(stmt ast.StmtNode, opts ...Option) (bool, error) { options := rewriteOptions{} for _, opt := range opts { From 721cc8e1c6da01db7534d0523892aeac3b1b4abb Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Tue, 16 Jun 2026 18:13:48 +0800 Subject: [PATCH 11/17] syncer(dm): narrow MariaDB DDL rewrite rules --- dm/pkg/ddl/rewriter/mariadb_rules.go | 279 +++++---------------------- dm/pkg/ddl/rewriter/rewriter_test.go | 53 +++-- 2 files changed, 86 insertions(+), 246 deletions(-) diff --git a/dm/pkg/ddl/rewriter/mariadb_rules.go b/dm/pkg/ddl/rewriter/mariadb_rules.go index de7f5390f5..65e56f6c72 100644 --- a/dm/pkg/ddl/rewriter/mariadb_rules.go +++ b/dm/pkg/ddl/rewriter/mariadb_rules.go @@ -14,76 +14,28 @@ package rewriter import ( - "strings" - "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/types" tidbtypes "github.com/pingcap/tidb/pkg/types" ) -// maxVarcharLen keeps rewritten VARCHAR columns within TiDB's default maximum index length. +// maxStringIndexPrefixLen keeps rewritten string index prefixes within TiDB's default maximum index length. // TiDB limits an index to 3072 bytes, which is 768 characters with 4-byte UTF-8 encoding. // See https://docs.pingcap.com/tidb/stable/tidb-limitations/#limitations-on-indexes. -const maxVarcharLen = 768 +const maxStringIndexPrefixLen = 768 + +// defaultBlobIndexPrefixLen follows a common MySQL/MariaDB prefix length for BLOB/TEXT indexes. +const defaultBlobIndexPrefixLen = 255 var mariaDBCompatibilityRules = []rule{ - collationRule{}, - zeroTimestampRule{}, - keyLengthRule{}, indexPrefixRule{}, - integerWidthRule{}, textBlobDefaultRule{}, - jsonCheckRule{}, functionDefaultRule{}, jsonGeneratedRule{}, } -// collationRule rewrites common MariaDB-only or unsupported charsets/collations to TiDB-supported ones. -type collationRule struct{} - -func (r collationRule) Apply(node ast.Node) (bool, error) { - switch n := node.(type) { - case *ast.CreateDatabaseStmt: - return rewriteDatabaseOptions(n.Options), nil - case *ast.CreateTableStmt: - return rewriteTableOptions(&n.Options), nil - case *ast.ColumnDef: - return rewriteColumnCollations(n), nil - default: - return false, nil - } -} - -// zeroTimestampRule removes zero date/time defaults that TiDB rejects. -type zeroTimestampRule struct{} - -func (r zeroTimestampRule) Apply(node ast.Node) (bool, error) { - col, ok := node.(*ast.ColumnDef) - if !ok || !isTimeType(col.Tp.GetType()) { - return false, nil - } - return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { - return opt.Tp == ast.ColumnOptionDefaultValue && isZeroTimeDefault(opt.Expr) - }), nil -} - -// keyLengthRule caps oversized VARCHAR columns so indexes on them can fit TiDB's index length limit. -type keyLengthRule struct{} - -func (r keyLengthRule) Apply(node ast.Node) (bool, error) { - col, ok := node.(*ast.ColumnDef) - if !ok || col.Tp.GetFlen() <= maxVarcharLen { - return false, nil - } - if tidbtypes.IsTypeVarchar(col.Tp.GetType()) { - col.Tp.SetFlen(maxVarcharLen) - return true, nil - } - return false, nil -} - -// indexPrefixRule adds explicit prefix lengths for string and blob indexes that MariaDB allows to omit. +// indexPrefixRule adds explicit prefix lengths for plain secondary indexes that TiDB rejects. type indexPrefixRule struct{} func (r indexPrefixRule) Apply(node ast.Node) (bool, error) { @@ -99,7 +51,7 @@ func (r indexPrefixRule) Apply(node ast.Node) (bool, error) { changed := false for _, cons := range stmt.Constraints { switch cons.Tp { - case ast.ConstraintPrimaryKey, ast.ConstraintKey, ast.ConstraintIndex, ast.ConstraintUniq: + case ast.ConstraintKey, ast.ConstraintIndex: default: continue } @@ -116,11 +68,11 @@ func (r indexPrefixRule) Apply(node ast.Node) (bool, error) { } switch { case types.IsTypeBlob(col.Tp.GetType()): - key.Length = 255 + key.Length = defaultBlobIndexPrefixLen changed = true case (tidbtypes.IsTypeChar(col.Tp.GetType()) || tidbtypes.IsTypeVarchar(col.Tp.GetType())) && - col.Tp.GetFlen() > 0: - key.Length = col.Tp.GetFlen() + col.Tp.GetFlen() > maxStringIndexPrefixLen: + key.Length = maxStringIndexPrefixLen changed = true } } @@ -128,22 +80,7 @@ func (r indexPrefixRule) Apply(node ast.Node) (bool, error) { return changed, nil } -// integerWidthRule removes integer display widths that TiDB no longer preserves. -type integerWidthRule struct{} - -func (r integerWidthRule) Apply(node ast.Node) (bool, error) { - col, ok := node.(*ast.ColumnDef) - if !ok || !mysql.IsIntegerType(col.Tp.GetType()) { - return false, nil - } - if col.Tp.GetFlen() == types.UnspecifiedLength || col.Tp.GetFlen() <= 0 { - return false, nil - } - col.Tp.SetFlen(types.UnspecifiedLength) - return true, nil -} - -// textBlobDefaultRule removes defaults from TEXT, BLOB, and JSON columns that TiDB rejects. +// textBlobDefaultRule removes non-NULL defaults from TEXT, BLOB, and JSON columns that TiDB rejects. type textBlobDefaultRule struct{} func (r textBlobDefaultRule) Apply(node ast.Node) (bool, error) { @@ -152,37 +89,11 @@ func (r textBlobDefaultRule) Apply(node ast.Node) (bool, error) { return false, nil } return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { - return opt.Tp == ast.ColumnOptionDefaultValue + return opt.Tp == ast.ColumnOptionDefaultValue && !isNullValueExpr(opt.Expr) }), nil } -// jsonCheckRule removes MariaDB JSON_VALID checks that duplicate JSON column validation. -type jsonCheckRule struct{} - -func (r jsonCheckRule) Apply(node ast.Node) (bool, error) { - switch n := node.(type) { - case *ast.ColumnDef: - return filterColumnOptions(n, func(opt *ast.ColumnOption) bool { - return opt.Tp == ast.ColumnOptionCheck && isJSONValidExpr(opt.Expr) - }), nil - case *ast.CreateTableStmt: - constraints := n.Constraints[:0] - changed := false - for _, cons := range n.Constraints { - if cons.Tp == ast.ConstraintCheck && isJSONValidExpr(cons.Expr) { - changed = true - continue - } - constraints = append(constraints, cons) - } - n.Constraints = constraints - return changed, nil - default: - return false, nil - } -} - -// functionDefaultRule removes function defaults from column types where TiDB does not allow them. +// functionDefaultRule removes time-function defaults from column types that TiDB rejects. type functionDefaultRule struct{} func (r functionDefaultRule) Apply(node ast.Node) (bool, error) { @@ -191,169 +102,71 @@ func (r functionDefaultRule) Apply(node ast.Node) (bool, error) { return false, nil } return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { - return opt.Tp == ast.ColumnOptionDefaultValue && !keepDefaultExpr(col, opt.Expr) + return opt.Tp == ast.ColumnOptionDefaultValue && isUnsupportedTimeDefault(col, opt.Expr) }), nil } -// jsonGeneratedRule removes generated expressions from JSON columns that TiDB rejects. +// jsonGeneratedRule rewrites MariaDB JSON_VALUE generated expressions to TiDB-supported JSON functions. type jsonGeneratedRule struct{} func (r jsonGeneratedRule) Apply(node ast.Node) (bool, error) { col, ok := node.(*ast.ColumnDef) - if !ok || !isJSONGenerated(col) { + if !ok { return false, nil } - return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { - return opt.Tp == ast.ColumnOptionGenerated - }), nil -} - -func rewriteDatabaseOptions(options []*ast.DatabaseOption) bool { - changed := false - for _, opt := range options { - switch opt.Tp { - case ast.DatabaseOptionCharset: - if strings.EqualFold(opt.Value, "latin1") { - opt.Value = "utf8mb4" - changed = true - } - case ast.DatabaseOptionCollate: - if collation, ok := mapCollation(opt.Value); ok { - opt.Value = collation - changed = true - } - } - } - return changed -} - -func rewriteTableOptions(options *[]*ast.TableOption) bool { - changed := false - needCollate := "" - hasCollate := false - for _, opt := range *options { - switch opt.Tp { - case ast.TableOptionCharset: - if strings.EqualFold(opt.StrValue, "latin1") { - opt.StrValue = "utf8mb4" - needCollate = "utf8mb4_0900_ai_ci" - changed = true - } - case ast.TableOptionCollate: - hasCollate = true - if collation, ok := mapCollation(opt.StrValue); ok { - opt.StrValue = collation - changed = true - } - } - } - if needCollate != "" && !hasCollate { - *options = append(*options, &ast.TableOption{Tp: ast.TableOptionCollate, StrValue: needCollate}) - changed = true - } - return changed -} - -func rewriteColumnCollations(col *ast.ColumnDef) bool { changed := false for _, opt := range col.Options { - if opt.Tp != ast.ColumnOptionCollate { + if opt.Tp != ast.ColumnOptionGenerated { continue } - if collation, ok := mapCollation(opt.StrValue); ok { - opt.StrValue = collation + if expr, ok := rewriteJSONValueExpr(opt.Expr); ok { + opt.Expr = expr changed = true } } - return changed -} - -func mapCollation(collation string) (string, bool) { - name := strings.ToLower(collation) - if name == "latin1_swedish_ci" { - return "utf8mb4_0900_ai_ci", true - } - if strings.HasPrefix(name, "utf8mb4_unicode_") { - return "utf8mb4_0900_ai_ci", true - } - return "", false + return changed, nil } -func keepDefaultExpr(col *ast.ColumnDef, expr ast.ExprNode) bool { +func isUnsupportedTimeDefault(col *ast.ColumnDef, expr ast.ExprNode) bool { expr = unwrapParentheses(expr) - if _, ok := expr.(ast.ValueExpr); ok { - return true - } - fn, ok := expr.(*ast.FuncCallExpr) - if !ok { - return false - } - return isTimeType(col.Tp.GetType()) && allowedTimeDefaultFuncs[fn.FnName.L] -} - -var allowedTimeDefaultFuncs = map[string]bool{ - "current_timestamp": true, - "current_date": true, - "current_time": true, - "now": true, - "localtime": true, - "localtimestamp": true, -} - -func isJSONValidExpr(expr ast.ExprNode) bool { fn, ok := expr.(*ast.FuncCallExpr) - return ok && strings.EqualFold(fn.FnName.O, "json_valid") -} - -func isJSONGenerated(col *ast.ColumnDef) bool { - for _, opt := range col.Options { - if opt.Tp != ast.ColumnOptionGenerated { - continue - } - fn, ok := unwrapParentheses(opt.Expr).(*ast.FuncCallExpr) - if ok && strings.HasPrefix(fn.FnName.L, "json") { - return true - } - } - return false -} - -func isZeroTimeDefault(expr ast.ExprNode) bool { - valExpr, ok := expr.(ast.ValueExpr) if !ok { return false } - switch v := valExpr.GetValue().(type) { - case tidbtypes.Time: - return v.IsZero() || v.InvalidZero() - case string: - return isZeroTimeString(v) - case []byte: - return isZeroTimeString(string(v)) - case int: - return v == 0 - case int64: - return v == 0 - case uint64: - return v == 0 + switch fn.FnName.L { + case ast.CurrentTimestamp, ast.Now, ast.LocalTime, ast.LocalTimestamp: + return col.Tp.GetType() != mysql.TypeTimestamp && col.Tp.GetType() != mysql.TypeDatetime + case ast.CurrentDate: + return col.Tp.GetType() != mysql.TypeDate && col.Tp.GetType() != mysql.TypeDatetime + case ast.CurrentTime: + return col.Tp.GetType() != mysql.TypeDuration default: return false } } -func isZeroTimeString(value string) bool { - value = strings.TrimSpace(value) - if !strings.HasPrefix(value, "0000-00-00") { - return false +func rewriteJSONValueExpr(expr ast.ExprNode) (ast.ExprNode, bool) { + fn, ok := unwrapParentheses(expr).(*ast.FuncCallExpr) + if !ok || fn.FnName.L != "json_value" || len(fn.Args) != 2 { + return expr, false } - rest := strings.TrimSpace(strings.TrimPrefix(value, "0000-00-00")) - if rest == "" || rest == "00:00:00" { - return true + jsonExtract := &ast.FuncCallExpr{ + FnName: ast.NewCIStr(ast.JSONExtract), + Args: fn.Args, } - if !strings.HasPrefix(rest, "00:00:00.") { + return &ast.FuncCallExpr{ + FnName: ast.NewCIStr(ast.JSONUnquote), + Args: []ast.ExprNode{jsonExtract}, + }, true +} + +func isNullValueExpr(expr ast.ExprNode) bool { + expr = unwrapParentheses(expr) + valExpr, ok := expr.(ast.ValueExpr) + if !ok { return false } - return strings.Trim(rest[len("00:00:00."):], "0") == "" + return valExpr.GetValue() == nil } func isTextBlobOrJSON(ft *types.FieldType) bool { diff --git a/dm/pkg/ddl/rewriter/rewriter_test.go b/dm/pkg/ddl/rewriter/rewriter_test.go index 1e606c314a..541aa6e4ea 100644 --- a/dm/pkg/ddl/rewriter/rewriter_test.go +++ b/dm/pkg/ddl/rewriter/rewriter_test.go @@ -19,6 +19,7 @@ import ( "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/format" _ "github.com/pingcap/tidb/pkg/types/parser_driver" // register parser driver "github.com/stretchr/testify/require" ) @@ -42,26 +43,29 @@ func TestRewriteStmtDefaultRules(t *testing.T) { input := `CREATE TABLE t ( id INT(11), txt TEXT DEFAULT 'x', + txt_null TEXT DEFAULT NULL, v VARCHAR(800), j JSON, g JSON GENERATED ALWAYS AS (JSON_EXTRACT(j, '$.a')) VIRTUAL, zero_ts TIMESTAMP DEFAULT '0000-00-00 00:00:00', CHECK (json_valid(j)), KEY idx_txt (txt), - KEY idx_v (v) + KEY idx_v (v), + UNIQUE KEY uk_txt (txt) ) DEFAULT CHARSET=latin1 COLLATE=latin1_swedish_ci;` stmt, changed := rewriteCreateTable(t, input) require.True(t, changed) - require.Equal(t, "utf8mb4", findTableOption(stmt, ast.TableOptionCharset)) - require.Equal(t, "utf8mb4_0900_ai_ci", findTableOption(stmt, ast.TableOptionCollate)) - require.Equal(t, -1, findColumn(stmt, "id").Tp.GetFlen()) - require.Equal(t, 768, findColumn(stmt, "v").Tp.GetFlen()) + require.Equal(t, "latin1", findTableOption(stmt, ast.TableOptionCharset)) + require.Equal(t, "latin1_swedish_ci", findTableOption(stmt, ast.TableOptionCollate)) + require.Equal(t, 11, findColumn(stmt, "id").Tp.GetFlen()) + require.Equal(t, 800, findColumn(stmt, "v").Tp.GetFlen()) require.False(t, hasColumnOption(findColumn(stmt, "txt"), ast.ColumnOptionDefaultValue)) - require.False(t, hasColumnOption(findColumn(stmt, "g"), ast.ColumnOptionGenerated)) - require.False(t, hasColumnOption(findColumn(stmt, "zero_ts"), ast.ColumnOptionDefaultValue)) - require.False(t, hasJSONValidCheck(stmt)) + require.True(t, hasColumnOption(findColumn(stmt, "txt_null"), ast.ColumnOptionDefaultValue)) + require.True(t, hasColumnOption(findColumn(stmt, "g"), ast.ColumnOptionGenerated)) + require.True(t, hasColumnOption(findColumn(stmt, "zero_ts"), ast.ColumnOptionDefaultValue)) + require.True(t, hasCheckConstraint(stmt)) idxTxt := findConstraint(stmt, "idx_txt") require.NotNil(t, idxTxt) @@ -69,18 +73,24 @@ func TestRewriteStmtDefaultRules(t *testing.T) { idxV := findConstraint(stmt, "idx_v") require.NotNil(t, idxV) require.Equal(t, 768, idxV.Keys[0].Length) + ukTxt := findConstraint(stmt, "uk_txt") + require.NotNil(t, ukTxt) + require.Equal(t, -1, ukTxt.Keys[0].Length) } func TestRewriteStmtSkipsExpressionIndexPrefix(t *testing.T) { rewriteCreateTable(t, "CREATE TABLE t(name VARCHAR(32), KEY idx_expr ((LOWER(name))));") } -func TestRewriteStmtRemovesParenthesizedJSONGeneratedColumn(t *testing.T) { +func TestRewriteStmtRewritesParenthesizedJSONValueGeneratedColumn(t *testing.T) { stmt, changed := rewriteCreateTable(t, - "CREATE TABLE t(j JSON, g JSON GENERATED ALWAYS AS ((JSON_EXTRACT(j, '$.a'))) VIRTUAL);", + "CREATE TABLE t(j JSON, g VARCHAR(64) GENERATED ALWAYS AS ((JSON_VALUE(j, '$.a'))) VIRTUAL);", ) require.True(t, changed) - require.False(t, hasColumnOption(findColumn(stmt, "g"), ast.ColumnOptionGenerated)) + expr := generatedExpr(findColumn(stmt, "g")) + require.NotNil(t, expr) + restored := strings.ToLower(restoreNode(t, expr)) + require.Contains(t, restored, "json_unquote(json_extract") } func rewriteCreateTable(t *testing.T, sql string) (*ast.CreateTableStmt, bool) { @@ -118,15 +128,24 @@ func hasColumnOption(col *ast.ColumnDef, optionType ast.ColumnOptionType) bool { return false } -func hasJSONValidCheck(stmt *ast.CreateTableStmt) bool { +func hasCheckConstraint(stmt *ast.CreateTableStmt) bool { for _, cons := range stmt.Constraints { - if cons.Tp == ast.ConstraintCheck && isJSONValidExpr(cons.Expr) { + if cons.Tp == ast.ConstraintCheck { return true } } return false } +func generatedExpr(col *ast.ColumnDef) ast.ExprNode { + for _, opt := range col.Options { + if opt.Tp == ast.ColumnOptionGenerated { + return opt.Expr + } + } + return nil +} + func findConstraint(stmt *ast.CreateTableStmt, name string) *ast.Constraint { for _, cons := range stmt.Constraints { if strings.EqualFold(cons.Name, name) { @@ -144,3 +163,11 @@ func findTableOption(stmt *ast.CreateTableStmt, optionType ast.TableOptionType) } return "" } + +func restoreNode(t *testing.T, node ast.Node) string { + t.Helper() + var sb strings.Builder + ctx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb) + require.NoError(t, node.Restore(ctx)) + return sb.String() +} From 5285dfc93ca0895ac9cf627bbb8e5d950820fc4a Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Wed, 17 Jun 2026 10:29:34 +0800 Subject: [PATCH 12/17] syncer(dm): refine MariaDB DDL rewrite coverage --- dm/pkg/ddl/rewriter/mariadb_rules.go | 34 +++++++++++++++++++++++----- dm/pkg/ddl/rewriter/rewriter_test.go | 25 ++++++++++++++++++-- 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/dm/pkg/ddl/rewriter/mariadb_rules.go b/dm/pkg/ddl/rewriter/mariadb_rules.go index 65e56f6c72..94a35b972a 100644 --- a/dm/pkg/ddl/rewriter/mariadb_rules.go +++ b/dm/pkg/ddl/rewriter/mariadb_rules.go @@ -80,7 +80,7 @@ func (r indexPrefixRule) Apply(node ast.Node) (bool, error) { return changed, nil } -// textBlobDefaultRule removes non-NULL defaults from TEXT, BLOB, and JSON columns that TiDB rejects. +// textBlobDefaultRule removes literal non-NULL defaults from TEXT, BLOB, and JSON columns that TiDB rejects. type textBlobDefaultRule struct{} func (r textBlobDefaultRule) Apply(node ast.Node) (bool, error) { @@ -89,7 +89,7 @@ func (r textBlobDefaultRule) Apply(node ast.Node) (bool, error) { return false, nil } return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { - return opt.Tp == ast.ColumnOptionDefaultValue && !isNullValueExpr(opt.Expr) + return opt.Tp == ast.ColumnOptionDefaultValue && isNonNullLiteralValueExpr(opt.Expr) }), nil } @@ -146,10 +146,32 @@ func isUnsupportedTimeDefault(col *ast.ColumnDef, expr ast.ExprNode) bool { } func rewriteJSONValueExpr(expr ast.ExprNode) (ast.ExprNode, bool) { - fn, ok := unwrapParentheses(expr).(*ast.FuncCallExpr) - if !ok || fn.FnName.L != "json_value" || len(fn.Args) != 2 { + visitor := &jsonValueExprRewriteVisitor{} + node, ok := expr.Accept(visitor) + if !ok { return expr, false } + newExpr, ok := node.(ast.ExprNode) + if !ok { + return expr, false + } + return newExpr, visitor.changed +} + +type jsonValueExprRewriteVisitor struct { + changed bool +} + +func (v *jsonValueExprRewriteVisitor) Enter(node ast.Node) (ast.Node, bool) { + return node, false +} + +func (v *jsonValueExprRewriteVisitor) Leave(node ast.Node) (ast.Node, bool) { + fn, ok := node.(*ast.FuncCallExpr) + if !ok || fn.FnName.L != "json_value" || len(fn.Args) != 2 { + return node, true + } + v.changed = true jsonExtract := &ast.FuncCallExpr{ FnName: ast.NewCIStr(ast.JSONExtract), Args: fn.Args, @@ -160,13 +182,13 @@ func rewriteJSONValueExpr(expr ast.ExprNode) (ast.ExprNode, bool) { }, true } -func isNullValueExpr(expr ast.ExprNode) bool { +func isNonNullLiteralValueExpr(expr ast.ExprNode) bool { expr = unwrapParentheses(expr) valExpr, ok := expr.(ast.ValueExpr) if !ok { return false } - return valExpr.GetValue() == nil + return valExpr.GetValue() != nil } func isTextBlobOrJSON(ft *types.FieldType) bool { diff --git a/dm/pkg/ddl/rewriter/rewriter_test.go b/dm/pkg/ddl/rewriter/rewriter_test.go index 541aa6e4ea..46b9f58678 100644 --- a/dm/pkg/ddl/rewriter/rewriter_test.go +++ b/dm/pkg/ddl/rewriter/rewriter_test.go @@ -39,13 +39,31 @@ func TestRewriteStmtKeepsTimeFunctionDefaultOnTimeColumn(t *testing.T) { require.True(t, hasColumnOption(findColumn(stmt, "ts"), ast.ColumnOptionDefaultValue)) } +func TestRewriteStmtTimeFunctionDefaultRules(t *testing.T) { + stmt, changed := rewriteCreateTable(t, `CREATE TABLE t( + dec_col DECIMAL(30,6) DEFAULT CURRENT_TIMESTAMP(6), + time_col TIME DEFAULT CURRENT_TIMESTAMP(), + date_col DATE DEFAULT CURRENT_TIMESTAMP(), + ok_date DATE DEFAULT CURRENT_DATE, + ok_datetime DATETIME DEFAULT CURRENT_DATE +)`) + require.True(t, changed) + + require.False(t, hasColumnOption(findColumn(stmt, "dec_col"), ast.ColumnOptionDefaultValue)) + require.False(t, hasColumnOption(findColumn(stmt, "time_col"), ast.ColumnOptionDefaultValue)) + require.False(t, hasColumnOption(findColumn(stmt, "date_col"), ast.ColumnOptionDefaultValue)) + require.True(t, hasColumnOption(findColumn(stmt, "ok_date"), ast.ColumnOptionDefaultValue)) + require.True(t, hasColumnOption(findColumn(stmt, "ok_datetime"), ast.ColumnOptionDefaultValue)) +} + func TestRewriteStmtDefaultRules(t *testing.T) { input := `CREATE TABLE t ( id INT(11), txt TEXT DEFAULT 'x', txt_null TEXT DEFAULT NULL, + txt_expr TEXT DEFAULT(uuid()), v VARCHAR(800), - j JSON, + j JSON DEFAULT(json_object('now', now())), g JSON GENERATED ALWAYS AS (JSON_EXTRACT(j, '$.a')) VIRTUAL, zero_ts TIMESTAMP DEFAULT '0000-00-00 00:00:00', CHECK (json_valid(j)), @@ -63,6 +81,8 @@ func TestRewriteStmtDefaultRules(t *testing.T) { require.Equal(t, 800, findColumn(stmt, "v").Tp.GetFlen()) require.False(t, hasColumnOption(findColumn(stmt, "txt"), ast.ColumnOptionDefaultValue)) require.True(t, hasColumnOption(findColumn(stmt, "txt_null"), ast.ColumnOptionDefaultValue)) + require.True(t, hasColumnOption(findColumn(stmt, "txt_expr"), ast.ColumnOptionDefaultValue)) + require.True(t, hasColumnOption(findColumn(stmt, "j"), ast.ColumnOptionDefaultValue)) require.True(t, hasColumnOption(findColumn(stmt, "g"), ast.ColumnOptionGenerated)) require.True(t, hasColumnOption(findColumn(stmt, "zero_ts"), ast.ColumnOptionDefaultValue)) require.True(t, hasCheckConstraint(stmt)) @@ -84,13 +104,14 @@ func TestRewriteStmtSkipsExpressionIndexPrefix(t *testing.T) { func TestRewriteStmtRewritesParenthesizedJSONValueGeneratedColumn(t *testing.T) { stmt, changed := rewriteCreateTable(t, - "CREATE TABLE t(j JSON, g VARCHAR(64) GENERATED ALWAYS AS ((JSON_VALUE(j, '$.a'))) VIRTUAL);", + "CREATE TABLE t(j JSON, g TIME GENERATED ALWAYS AS (CAST((JSON_VALUE(j, '$.a')) AS TIME)) VIRTUAL);", ) require.True(t, changed) expr := generatedExpr(findColumn(stmt, "g")) require.NotNil(t, expr) restored := strings.ToLower(restoreNode(t, expr)) require.Contains(t, restored, "json_unquote(json_extract") + require.NotContains(t, restored, "json_value") } func rewriteCreateTable(t *testing.T, sql string) (*ast.CreateTableStmt, bool) { From bec29d3146ebeeaf8acde2044bb5928c87807d2a Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Wed, 17 Jun 2026 16:03:14 +0800 Subject: [PATCH 13/17] syncer(dm): document SQL mode DDL rewrite boundary --- dm/pkg/ddl/rewriter/mariadb_rules.go | 3 +++ dm/pkg/ddl/rewriter/rewriter.go | 3 ++- dm/pkg/ddl/rewriter/rewriter_test.go | 13 +++++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/dm/pkg/ddl/rewriter/mariadb_rules.go b/dm/pkg/ddl/rewriter/mariadb_rules.go index 94a35b972a..6a44e9f217 100644 --- a/dm/pkg/ddl/rewriter/mariadb_rules.go +++ b/dm/pkg/ddl/rewriter/mariadb_rules.go @@ -29,6 +29,9 @@ const maxStringIndexPrefixLen = 768 const defaultBlobIndexPrefixLen = 255 var mariaDBCompatibilityRules = []rule{ + // Do not add rules for DDL failures that can be solved by downstream SQL mode. + // DM's AdjustSQLModeCompatible already disables strict zero-date checks by default, + // and users can override the downstream session SQL mode when needed. indexPrefixRule{}, textBlobDefaultRule{}, functionDefaultRule{}, diff --git a/dm/pkg/ddl/rewriter/rewriter.go b/dm/pkg/ddl/rewriter/rewriter.go index 759f373dc0..8e619ecab1 100644 --- a/dm/pkg/ddl/rewriter/rewriter.go +++ b/dm/pkg/ddl/rewriter/rewriter.go @@ -45,7 +45,8 @@ func WithMariaDBCompatibility() Option { // RewriteStmt applies enabled rules to stmt in place. // It is a best-effort compatibility layer for parsed AST nodes; parser failures and -// unsupported DDL validation failures are intentionally left to the normal DM flow. +// DDL failures that can be handled by downstream session settings, such as SQL mode, +// are intentionally left to the normal DM flow. func RewriteStmt(stmt ast.StmtNode, opts ...Option) (bool, error) { options := rewriteOptions{} for _, opt := range opts { diff --git a/dm/pkg/ddl/rewriter/rewriter_test.go b/dm/pkg/ddl/rewriter/rewriter_test.go index 46b9f58678..b27f48e3a5 100644 --- a/dm/pkg/ddl/rewriter/rewriter_test.go +++ b/dm/pkg/ddl/rewriter/rewriter_test.go @@ -39,6 +39,19 @@ func TestRewriteStmtKeepsTimeFunctionDefaultOnTimeColumn(t *testing.T) { require.True(t, hasColumnOption(findColumn(stmt, "ts"), ast.ColumnOptionDefaultValue)) } +func TestRewriteStmtKeepsZeroTimeDefaults(t *testing.T) { + stmt, changed := rewriteCreateTable(t, `CREATE TABLE t( + d DATE DEFAULT '0000-00-00', + dt DATETIME DEFAULT '0000-00-00 00:00:00', + ts TIMESTAMP DEFAULT '0000-00-00 00:00:00' +)`) + require.False(t, changed) + + require.True(t, hasColumnOption(findColumn(stmt, "d"), ast.ColumnOptionDefaultValue)) + require.True(t, hasColumnOption(findColumn(stmt, "dt"), ast.ColumnOptionDefaultValue)) + require.True(t, hasColumnOption(findColumn(stmt, "ts"), ast.ColumnOptionDefaultValue)) +} + func TestRewriteStmtTimeFunctionDefaultRules(t *testing.T) { stmt, changed := rewriteCreateTable(t, `CREATE TABLE t( dec_col DECIMAL(30,6) DEFAULT CURRENT_TIMESTAMP(6), From 75f7d44cce62983578eaa067bab41993bc08a3d6 Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Wed, 17 Jun 2026 16:20:16 +0800 Subject: [PATCH 14/17] syncer(dm): remove unused DDL rewriter helper --- dm/pkg/ddl/rewriter/utils.go | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/dm/pkg/ddl/rewriter/utils.go b/dm/pkg/ddl/rewriter/utils.go index 6086f0b9f3..70f6a50818 100644 --- a/dm/pkg/ddl/rewriter/utils.go +++ b/dm/pkg/ddl/rewriter/utils.go @@ -15,7 +15,6 @@ package rewriter import ( "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/mysql" ) func filterColumnOptions(col *ast.ColumnDef, drop func(*ast.ColumnOption) bool) bool { @@ -32,15 +31,6 @@ func filterColumnOptions(col *ast.ColumnDef, drop func(*ast.ColumnOption) bool) return changed } -func isTimeType(tp byte) bool { - switch tp { - case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: - return true - default: - return false - } -} - func unwrapParentheses(expr ast.ExprNode) ast.ExprNode { for { p, ok := expr.(*ast.ParenthesesExpr) From 2d821f31e99f8623d176508e9f71ea28c9092f9d Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Wed, 17 Jun 2026 05:17:16 -0400 Subject: [PATCH 15/17] Update Signed-off-by: Ruihao Chen --- dm/pkg/ddl/rewriter/mariadb_rules.go | 135 ++++++++++--------- dm/pkg/ddl/rewriter/rewriter.go | 22 +-- dm/pkg/ddl/rewriter/rewriter_test.go | 3 +- dm/pkg/ddl/rewriter/strict_collation_rule.go | 8 +- dm/pkg/ddl/rewriter/utils.go | 14 +- dm/syncer/ddl.go | 26 ++-- dm/syncer/ddl_test.go | 6 +- 7 files changed, 100 insertions(+), 114 deletions(-) diff --git a/dm/pkg/ddl/rewriter/mariadb_rules.go b/dm/pkg/ddl/rewriter/mariadb_rules.go index 6a44e9f217..c1b2c70271 100644 --- a/dm/pkg/ddl/rewriter/mariadb_rules.go +++ b/dm/pkg/ddl/rewriter/mariadb_rules.go @@ -20,8 +20,7 @@ import ( tidbtypes "github.com/pingcap/tidb/pkg/types" ) -// maxStringIndexPrefixLen keeps rewritten string index prefixes within TiDB's default maximum index length. -// TiDB limits an index to 3072 bytes, which is 768 characters with 4-byte UTF-8 encoding. +// maxStringIndexPrefixLen is TiDB's default maximum index length. // See https://docs.pingcap.com/tidb/stable/tidb-limitations/#limitations-on-indexes. const maxStringIndexPrefixLen = 768 @@ -29,22 +28,19 @@ const maxStringIndexPrefixLen = 768 const defaultBlobIndexPrefixLen = 255 var mariaDBCompatibilityRules = []rule{ - // Do not add rules for DDL failures that can be solved by downstream SQL mode. - // DM's AdjustSQLModeCompatible already disables strict zero-date checks by default, - // and users can override the downstream session SQL mode when needed. - indexPrefixRule{}, - textBlobDefaultRule{}, + secondaryIndexPrefixRule{}, + columnDefaultValueRule{}, functionDefaultRule{}, - jsonGeneratedRule{}, + jsonValueRule{}, } -// indexPrefixRule adds explicit prefix lengths for plain secondary indexes that TiDB rejects. -type indexPrefixRule struct{} +// secondaryIndexPrefixRule adds explicit prefix lengths for plain secondary indexes that TiDB rejects. +type secondaryIndexPrefixRule struct{} -func (r indexPrefixRule) Apply(node ast.Node) (bool, error) { +func (r secondaryIndexPrefixRule) Apply(node ast.Node) bool { stmt, ok := node.(*ast.CreateTableStmt) if !ok { - return false, nil + return false } colMap := make(map[string]*ast.ColumnDef, len(stmt.Cols)) for _, col := range stmt.Cols { @@ -59,10 +55,7 @@ func (r indexPrefixRule) Apply(node ast.Node) (bool, error) { continue } for _, key := range cons.Keys { - if key.Length > 0 { - continue - } - if key.Column == nil { + if key.Length > 0 || key.Column == nil { continue } col := colMap[key.Column.Name.L] @@ -73,64 +66,63 @@ func (r indexPrefixRule) Apply(node ast.Node) (bool, error) { case types.IsTypeBlob(col.Tp.GetType()): key.Length = defaultBlobIndexPrefixLen changed = true - case (tidbtypes.IsTypeChar(col.Tp.GetType()) || tidbtypes.IsTypeVarchar(col.Tp.GetType())) && - col.Tp.GetFlen() > maxStringIndexPrefixLen: - key.Length = maxStringIndexPrefixLen - changed = true + case tidbtypes.IsTypeChar(col.Tp.GetType()): + if col.Tp.GetFlen() > maxStringIndexPrefixLen { + key.Length = maxStringIndexPrefixLen + changed = true + } } } } - return changed, nil + return changed } -// textBlobDefaultRule removes literal non-NULL defaults from TEXT, BLOB, and JSON columns that TiDB rejects. -type textBlobDefaultRule struct{} +// columnDefaultValueRule removes literal non-NULL defaults from TEXT/BLOB/JSON columns. +type columnDefaultValueRule struct{} -func (r textBlobDefaultRule) Apply(node ast.Node) (bool, error) { +func (r columnDefaultValueRule) Apply(node ast.Node) bool { col, ok := node.(*ast.ColumnDef) if !ok || !isTextBlobOrJSON(col.Tp) { - return false, nil + return false } - return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { - return opt.Tp == ast.ColumnOptionDefaultValue && isNonNullLiteralValueExpr(opt.Expr) - }), nil + return filterColumnOptions( + col, + func(opt *ast.ColumnOption) (bool, bool) { + dropped := opt.Tp == ast.ColumnOptionDefaultValue && + isNonNullLiteralValueExpr(opt.Expr) + return dropped, dropped + }, + ) } -// functionDefaultRule removes time-function defaults from column types that TiDB rejects. -type functionDefaultRule struct{} - -func (r functionDefaultRule) Apply(node ast.Node) (bool, error) { - col, ok := node.(*ast.ColumnDef) +func isNonNullLiteralValueExpr(expr ast.ExprNode) bool { + expr = unwrapParentheses(expr) + valExpr, ok := expr.(ast.ValueExpr) if !ok { - return false, nil + return false } - return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { - return opt.Tp == ast.ColumnOptionDefaultValue && isUnsupportedTimeDefault(col, opt.Expr) - }), nil + return valExpr.GetValue() != nil } -// jsonGeneratedRule rewrites MariaDB JSON_VALUE generated expressions to TiDB-supported JSON functions. -type jsonGeneratedRule struct{} +// functionDefaultRule removes time-function defaults from column types that TiDB rejects. +type functionDefaultRule struct{} -func (r jsonGeneratedRule) Apply(node ast.Node) (bool, error) { +func (r functionDefaultRule) Apply(node ast.Node) bool { col, ok := node.(*ast.ColumnDef) if !ok { - return false, nil - } - changed := false - for _, opt := range col.Options { - if opt.Tp != ast.ColumnOptionGenerated { - continue - } - if expr, ok := rewriteJSONValueExpr(opt.Expr); ok { - opt.Expr = expr - changed = true - } + return false } - return changed, nil + return filterColumnOptions( + col, + func(opt *ast.ColumnOption) (bool, bool) { + dropped := opt.Tp == ast.ColumnOptionDefaultValue && + isUnsupportedTimeDefault(col.Tp.GetType(), opt.Expr) + return dropped, dropped + }, + ) } -func isUnsupportedTimeDefault(col *ast.ColumnDef, expr ast.ExprNode) bool { +func isUnsupportedTimeDefault(colType byte, expr ast.ExprNode) bool { expr = unwrapParentheses(expr) fn, ok := expr.(*ast.FuncCallExpr) if !ok { @@ -138,16 +130,38 @@ func isUnsupportedTimeDefault(col *ast.ColumnDef, expr ast.ExprNode) bool { } switch fn.FnName.L { case ast.CurrentTimestamp, ast.Now, ast.LocalTime, ast.LocalTimestamp: - return col.Tp.GetType() != mysql.TypeTimestamp && col.Tp.GetType() != mysql.TypeDatetime + return colType != mysql.TypeTimestamp && colType != mysql.TypeDatetime case ast.CurrentDate: - return col.Tp.GetType() != mysql.TypeDate && col.Tp.GetType() != mysql.TypeDatetime + return colType != mysql.TypeDate && colType != mysql.TypeDatetime case ast.CurrentTime: - return col.Tp.GetType() != mysql.TypeDuration + return colType != mysql.TypeDuration default: return false } } +// jsonValueRule rewrites MariaDB JSON_VALUE to supported JSON functions. +type jsonValueRule struct{} + +func (r jsonValueRule) Apply(node ast.Node) bool { + col, ok := node.(*ast.ColumnDef) + if !ok { + return false + } + return filterColumnOptions( + col, + func(opt *ast.ColumnOption) (bool, bool) { + if opt.Tp == ast.ColumnOptionGenerated { + if expr, ok := rewriteJSONValueExpr(opt.Expr); ok { + opt.Expr = expr + return true, false + } + } + return false, false + }, + ) +} + func rewriteJSONValueExpr(expr ast.ExprNode) (ast.ExprNode, bool) { visitor := &jsonValueExprRewriteVisitor{} node, ok := expr.Accept(visitor) @@ -185,15 +199,6 @@ func (v *jsonValueExprRewriteVisitor) Leave(node ast.Node) (ast.Node, bool) { }, true } -func isNonNullLiteralValueExpr(expr ast.ExprNode) bool { - expr = unwrapParentheses(expr) - valExpr, ok := expr.(ast.ValueExpr) - if !ok { - return false - } - return valExpr.GetValue() != nil -} - func isTextBlobOrJSON(ft *types.FieldType) bool { return types.IsTypeBlob(ft.GetType()) || ft.GetType() == mysql.TypeJSON } diff --git a/dm/pkg/ddl/rewriter/rewriter.go b/dm/pkg/ddl/rewriter/rewriter.go index 8e619ecab1..12dffe7abc 100644 --- a/dm/pkg/ddl/rewriter/rewriter.go +++ b/dm/pkg/ddl/rewriter/rewriter.go @@ -17,8 +17,9 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" ) +// rule defines a rule to apply to AST nodes with best effort. type rule interface { - Apply(ast.Node) (bool, error) + Apply(ast.Node) bool } type rewriteOptions struct { @@ -47,42 +48,31 @@ func WithMariaDBCompatibility() Option { // It is a best-effort compatibility layer for parsed AST nodes; parser failures and // DDL failures that can be handled by downstream session settings, such as SQL mode, // are intentionally left to the normal DM flow. -func RewriteStmt(stmt ast.StmtNode, opts ...Option) (bool, error) { +func RewriteStmt(stmt ast.StmtNode, opts ...Option) bool { options := rewriteOptions{} for _, opt := range opts { opt.apply(&options) } if stmt == nil || len(options.rules) == 0 { - return false, nil + return false } visitor := &rewriteVisitor{rules: options.rules} stmt.Accept(visitor) - return visitor.changed, visitor.err + return visitor.changed } type rewriteVisitor struct { rules []rule changed bool - err error } func (v *rewriteVisitor) Enter(node ast.Node) (ast.Node, bool) { - if v.err != nil { - return node, true - } return node, false } func (v *rewriteVisitor) Leave(node ast.Node) (ast.Node, bool) { - if v.err != nil { - return node, false - } for _, rule := range v.rules { - changed, err := rule.Apply(node) - if err != nil { - v.err = err - return node, false - } + changed := rule.Apply(node) v.changed = v.changed || changed } return node, true diff --git a/dm/pkg/ddl/rewriter/rewriter_test.go b/dm/pkg/ddl/rewriter/rewriter_test.go index b27f48e3a5..76c89d8e6f 100644 --- a/dm/pkg/ddl/rewriter/rewriter_test.go +++ b/dm/pkg/ddl/rewriter/rewriter_test.go @@ -130,8 +130,7 @@ func TestRewriteStmtRewritesParenthesizedJSONValueGeneratedColumn(t *testing.T) func rewriteCreateTable(t *testing.T, sql string) (*ast.CreateTableStmt, bool) { t.Helper() stmt := parseCreateTable(t, sql) - changed, err := RewriteStmt(stmt, WithMariaDBCompatibility()) - require.NoError(t, err) + changed := RewriteStmt(stmt, WithMariaDBCompatibility()) return stmt, changed } diff --git a/dm/pkg/ddl/rewriter/strict_collation_rule.go b/dm/pkg/ddl/rewriter/strict_collation_rule.go index c55bc73dc5..8cc75be0e6 100644 --- a/dm/pkg/ddl/rewriter/strict_collation_rule.go +++ b/dm/pkg/ddl/rewriter/strict_collation_rule.go @@ -53,14 +53,14 @@ type strictCollationRule struct { logger log.Logger } -func (r strictCollationRule) Apply(node ast.Node) (bool, error) { +func (r strictCollationRule) Apply(node ast.Node) bool { switch createStmt := node.(type) { case *ast.CreateTableStmt: - return r.rewriteCreateTable(createStmt), nil + return r.rewriteCreateTable(createStmt) case *ast.CreateDatabaseStmt: - return r.rewriteCreateDatabase(createStmt), nil + return r.rewriteCreateDatabase(createStmt) default: - return false, nil + return false } } diff --git a/dm/pkg/ddl/rewriter/utils.go b/dm/pkg/ddl/rewriter/utils.go index 70f6a50818..a91e917d39 100644 --- a/dm/pkg/ddl/rewriter/utils.go +++ b/dm/pkg/ddl/rewriter/utils.go @@ -17,15 +17,19 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" ) -func filterColumnOptions(col *ast.ColumnDef, drop func(*ast.ColumnOption) bool) bool { +func filterColumnOptions( + col *ast.ColumnDef, + filterFunc func(*ast.ColumnOption) (changed bool, drop bool), +) bool { options := col.Options[:0] changed := false for _, opt := range col.Options { - if drop(opt) { - changed = true - continue + c, drop := filterFunc(opt) + if !drop { + options = append(options, opt) } - options = append(options, opt) + changed = changed || c + } col.Options = options return changed diff --git a/dm/syncer/ddl.go b/dm/syncer/ddl.go index 765874bc39..d7a3155289 100644 --- a/dm/syncer/ddl.go +++ b/dm/syncer/ddl.go @@ -969,17 +969,13 @@ func parseOneStmt(qec *queryEventContext) (stmt ast.StmtNode, err error) { return nil, nil } stmt = stmts[0] - if !qec.enableDDLRewrite { - return stmt, nil - } - changed, err := ddlrewriter.RewriteStmt(stmt, ddlrewriter.WithMariaDBCompatibility()) - if err != nil { - return nil, terror.ErrRewriteSQL.Delegate(err, qec.originSQL) - } - if changed { - qec.tctx.L().Info("rewrite MariaDB DDL with AST compatibility rules", - zap.String("event", "query"), - zap.String("originSQL", qec.originSQL)) + if qec.enableDDLRewrite { + changed := ddlrewriter.RewriteStmt(stmt, ddlrewriter.WithMariaDBCompatibility()) + if changed { + qec.tctx.L().Info("rewrite MariaDB DDL with AST compatibility rules", + zap.String("event", "query"), + zap.String("originSQL", qec.originSQL)) + } } return stmt, nil } @@ -1340,10 +1336,7 @@ func (ddl *DDLWorker) genDDLInfo(qec *queryEventContext, sql string) (*ddlInfo, } if ddl.collationCompatible == config.StrictCollationCompatible { - // Keep the original best-effort strict collation behavior: the rule logs and skips when - // upstream collation metadata is incomplete. The error check is only for the common - // rewriter interface; the strict collation rule itself does not fail the DDL. - _, err := ddlrewriter.RewriteStmt( + ddlrewriter.RewriteStmt( ddlInfo.stmtCache, ddlrewriter.WithStrictCollation( qec.eventStatusVars, @@ -1353,9 +1346,6 @@ func (ddl *DDLWorker) genDDLInfo(qec *queryEventContext, sql string) (*ddlInfo, ddl.logger, ), ) - if err != nil { - return nil, terror.ErrRewriteSQL.Delegate(err, ddlInfo.originDDL) - } } routedDDL, err := parserpkg.RenameDDLTable(ddlInfo.stmtCache, ddlInfo.targetTables) diff --git a/dm/syncer/ddl_test.go b/dm/syncer/ddl_test.go index 7ced1f8bc8..8ed087d0fc 100644 --- a/dm/syncer/ddl_test.go +++ b/dm/syncer/ddl_test.go @@ -768,7 +768,7 @@ func (s *testDDLSuite) TestAdjustDatabaseCollation(c *check.C) { c.Assert(err, check.IsNil) c.Assert(stmt, check.NotNil) ddlInfo.stmtCache = stmt - _, err = ddlrewriter.RewriteStmt( + ddlrewriter.RewriteStmt( ddlInfo.stmtCache, ddlrewriter.WithStrictCollation( statusVars, @@ -778,7 +778,6 @@ func (s *testDDLSuite) TestAdjustDatabaseCollation(c *check.C) { tctx.Logger, ), ) - c.Assert(err, check.IsNil) routedDDL, err := parserpkg.RenameDDLTable(ddlInfo.stmtCache, ddlInfo.targetTables) c.Assert(err, check.IsNil) c.Assert(routedDDL, check.Equals, expectedSQLs[i][j]) @@ -851,7 +850,7 @@ func TestAdjustCollation(t *testing.T) { require.NoError(t, err) require.NotNil(t, stmt) ddlInfo.stmtCache = stmt - _, err = ddlrewriter.RewriteStmt( + ddlrewriter.RewriteStmt( ddlInfo.stmtCache, ddlrewriter.WithStrictCollation( statusVars, @@ -861,7 +860,6 @@ func TestAdjustCollation(t *testing.T) { tctx.Logger, ), ) - require.NoError(t, err) routedDDL, err := parserpkg.RenameDDLTable(ddlInfo.stmtCache, ddlInfo.targetTables) require.NoError(t, err) require.Equal(t, expectedSQLs[i], routedDDL) From 4578326d0c9c6791a488f73675011cd6fa17f53d Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Wed, 17 Jun 2026 17:21:13 +0800 Subject: [PATCH 16/17] syncer(dm): simplify DDL rewriter helpers --- dm/pkg/ddl/rewriter/mariadb_rules.go | 53 +++++++++++----------------- dm/pkg/ddl/rewriter/utils.go | 18 ++++------ 2 files changed, 27 insertions(+), 44 deletions(-) diff --git a/dm/pkg/ddl/rewriter/mariadb_rules.go b/dm/pkg/ddl/rewriter/mariadb_rules.go index c1b2c70271..8f428689bb 100644 --- a/dm/pkg/ddl/rewriter/mariadb_rules.go +++ b/dm/pkg/ddl/rewriter/mariadb_rules.go @@ -17,10 +17,10 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/types" - tidbtypes "github.com/pingcap/tidb/pkg/types" ) -// maxStringIndexPrefixLen is TiDB's default maximum index length. +// maxStringIndexPrefixLen keeps rewritten string index prefixes within TiDB's default maximum index length. +// TiDB limits an index to 3072 bytes, which is 768 characters with 4-byte UTF-8 encoding. // See https://docs.pingcap.com/tidb/stable/tidb-limitations/#limitations-on-indexes. const maxStringIndexPrefixLen = 768 @@ -66,7 +66,7 @@ func (r secondaryIndexPrefixRule) Apply(node ast.Node) bool { case types.IsTypeBlob(col.Tp.GetType()): key.Length = defaultBlobIndexPrefixLen changed = true - case tidbtypes.IsTypeChar(col.Tp.GetType()): + case types.IsTypeChar(col.Tp.GetType()): if col.Tp.GetFlen() > maxStringIndexPrefixLen { key.Length = maxStringIndexPrefixLen changed = true @@ -85,14 +85,9 @@ func (r columnDefaultValueRule) Apply(node ast.Node) bool { if !ok || !isTextBlobOrJSON(col.Tp) { return false } - return filterColumnOptions( - col, - func(opt *ast.ColumnOption) (bool, bool) { - dropped := opt.Tp == ast.ColumnOptionDefaultValue && - isNonNullLiteralValueExpr(opt.Expr) - return dropped, dropped - }, - ) + return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { + return opt.Tp == ast.ColumnOptionDefaultValue && isNonNullLiteralValueExpr(opt.Expr) + }) } func isNonNullLiteralValueExpr(expr ast.ExprNode) bool { @@ -112,14 +107,9 @@ func (r functionDefaultRule) Apply(node ast.Node) bool { if !ok { return false } - return filterColumnOptions( - col, - func(opt *ast.ColumnOption) (bool, bool) { - dropped := opt.Tp == ast.ColumnOptionDefaultValue && - isUnsupportedTimeDefault(col.Tp.GetType(), opt.Expr) - return dropped, dropped - }, - ) + return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { + return opt.Tp == ast.ColumnOptionDefaultValue && isUnsupportedTimeDefault(col.Tp.GetType(), opt.Expr) + }) } func isUnsupportedTimeDefault(colType byte, expr ast.ExprNode) bool { @@ -140,7 +130,7 @@ func isUnsupportedTimeDefault(colType byte, expr ast.ExprNode) bool { } } -// jsonValueRule rewrites MariaDB JSON_VALUE to supported JSON functions. +// jsonValueRule rewrites MariaDB JSON_VALUE in generated column expressions to supported JSON functions. type jsonValueRule struct{} func (r jsonValueRule) Apply(node ast.Node) bool { @@ -148,18 +138,17 @@ func (r jsonValueRule) Apply(node ast.Node) bool { if !ok { return false } - return filterColumnOptions( - col, - func(opt *ast.ColumnOption) (bool, bool) { - if opt.Tp == ast.ColumnOptionGenerated { - if expr, ok := rewriteJSONValueExpr(opt.Expr); ok { - opt.Expr = expr - return true, false - } - } - return false, false - }, - ) + changed := false + for _, opt := range col.Options { + if opt.Tp != ast.ColumnOptionGenerated { + continue + } + if expr, ok := rewriteJSONValueExpr(opt.Expr); ok { + opt.Expr = expr + changed = true + } + } + return changed } func rewriteJSONValueExpr(expr ast.ExprNode) (ast.ExprNode, bool) { diff --git a/dm/pkg/ddl/rewriter/utils.go b/dm/pkg/ddl/rewriter/utils.go index a91e917d39..f924f53fab 100644 --- a/dm/pkg/ddl/rewriter/utils.go +++ b/dm/pkg/ddl/rewriter/utils.go @@ -13,23 +13,17 @@ package rewriter -import ( - "github.com/pingcap/tidb/pkg/parser/ast" -) +import "github.com/pingcap/tidb/pkg/parser/ast" -func filterColumnOptions( - col *ast.ColumnDef, - filterFunc func(*ast.ColumnOption) (changed bool, drop bool), -) bool { +func filterColumnOptions(col *ast.ColumnDef, drop func(*ast.ColumnOption) bool) bool { options := col.Options[:0] changed := false for _, opt := range col.Options { - c, drop := filterFunc(opt) - if !drop { - options = append(options, opt) + if drop(opt) { + changed = true + continue } - changed = changed || c - + options = append(options, opt) } col.Options = options return changed From 5970724ec3bb523796ebc172ac90d9dfeb09b370 Mon Sep 17 00:00:00 2001 From: Ruihao Chen Date: Wed, 17 Jun 2026 05:25:20 -0400 Subject: [PATCH 17/17] Update Signed-off-by: Ruihao Chen --- dm/pkg/ddl/rewriter/mariadb_rules.go | 36 +++++++++++++++------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/dm/pkg/ddl/rewriter/mariadb_rules.go b/dm/pkg/ddl/rewriter/mariadb_rules.go index 8f428689bb..8d679bfb22 100644 --- a/dm/pkg/ddl/rewriter/mariadb_rules.go +++ b/dm/pkg/ddl/rewriter/mariadb_rules.go @@ -19,13 +19,15 @@ import ( "github.com/pingcap/tidb/pkg/parser/types" ) -// maxStringIndexPrefixLen keeps rewritten string index prefixes within TiDB's default maximum index length. -// TiDB limits an index to 3072 bytes, which is 768 characters with 4-byte UTF-8 encoding. -// See https://docs.pingcap.com/tidb/stable/tidb-limitations/#limitations-on-indexes. -const maxStringIndexPrefixLen = 768 - -// defaultBlobIndexPrefixLen follows a common MySQL/MariaDB prefix length for BLOB/TEXT indexes. -const defaultBlobIndexPrefixLen = 255 +const ( + // maxStringIndexPrefixLen keeps rewritten string index prefixes within TiDB's default maximum index length. + // TiDB limits an index to 3072 bytes, which is 768 characters with 4-byte UTF-8 encoding. + // See https://docs.pingcap.com/tidb/stable/tidb-limitations/#limitations-on-indexes. + maxStringIndexPrefixLen = 768 + + // defaultBlobIndexPrefixLen follows a common MySQL/MariaDB prefix length for BLOB/TEXT indexes. + defaultBlobIndexPrefixLen = 255 +) var mariaDBCompatibilityRules = []rule{ secondaryIndexPrefixRule{}, @@ -86,7 +88,8 @@ func (r columnDefaultValueRule) Apply(node ast.Node) bool { return false } return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { - return opt.Tp == ast.ColumnOptionDefaultValue && isNonNullLiteralValueExpr(opt.Expr) + return opt.Tp == ast.ColumnOptionDefaultValue && + isNonNullLiteralValueExpr(opt.Expr) }) } @@ -108,7 +111,8 @@ func (r functionDefaultRule) Apply(node ast.Node) bool { return false } return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { - return opt.Tp == ast.ColumnOptionDefaultValue && isUnsupportedTimeDefault(col.Tp.GetType(), opt.Expr) + return opt.Tp == ast.ColumnOptionDefaultValue && + isUnsupportedTimeDefault(col.Tp.GetType(), opt.Expr) }) } @@ -130,7 +134,8 @@ func isUnsupportedTimeDefault(colType byte, expr ast.ExprNode) bool { } } -// jsonValueRule rewrites MariaDB JSON_VALUE in generated column expressions to supported JSON functions. +// jsonValueRule rewrites MariaDB JSON_VALUE in generated column expressions +// to supported JSON functions. type jsonValueRule struct{} func (r jsonValueRule) Apply(node ast.Node) bool { @@ -140,12 +145,11 @@ func (r jsonValueRule) Apply(node ast.Node) bool { } changed := false for _, opt := range col.Options { - if opt.Tp != ast.ColumnOptionGenerated { - continue - } - if expr, ok := rewriteJSONValueExpr(opt.Expr); ok { - opt.Expr = expr - changed = true + if opt.Tp == ast.ColumnOptionGenerated { + if expr, ok := rewriteJSONValueExpr(opt.Expr); ok { + opt.Expr = expr + changed = true + } } } return changed