diff --git a/dm/pkg/ddl/rewriter/mariadb_rules.go b/dm/pkg/ddl/rewriter/mariadb_rules.go new file mode 100644 index 0000000000..8d679bfb22 --- /dev/null +++ b/dm/pkg/ddl/rewriter/mariadb_rules.go @@ -0,0 +1,197 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package rewriter + +import ( + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/types" +) + +const ( + // maxStringIndexPrefixLen keeps rewritten string index prefixes within TiDB's default maximum index length. + // TiDB limits an index to 3072 bytes, which is 768 characters with 4-byte UTF-8 encoding. + // See https://docs.pingcap.com/tidb/stable/tidb-limitations/#limitations-on-indexes. + maxStringIndexPrefixLen = 768 + + // defaultBlobIndexPrefixLen follows a common MySQL/MariaDB prefix length for BLOB/TEXT indexes. + defaultBlobIndexPrefixLen = 255 +) + +var mariaDBCompatibilityRules = []rule{ + secondaryIndexPrefixRule{}, + columnDefaultValueRule{}, + functionDefaultRule{}, + jsonValueRule{}, +} + +// secondaryIndexPrefixRule adds explicit prefix lengths for plain secondary indexes that TiDB rejects. +type secondaryIndexPrefixRule struct{} + +func (r secondaryIndexPrefixRule) Apply(node ast.Node) bool { + stmt, ok := node.(*ast.CreateTableStmt) + if !ok { + return false + } + colMap := make(map[string]*ast.ColumnDef, len(stmt.Cols)) + for _, col := range stmt.Cols { + colMap[col.Name.Name.L] = col + } + + changed := false + for _, cons := range stmt.Constraints { + switch cons.Tp { + case ast.ConstraintKey, ast.ConstraintIndex: + default: + continue + } + for _, key := range cons.Keys { + if key.Length > 0 || key.Column == nil { + continue + } + col := colMap[key.Column.Name.L] + if col == nil { + continue + } + switch { + case types.IsTypeBlob(col.Tp.GetType()): + key.Length = defaultBlobIndexPrefixLen + changed = true + case types.IsTypeChar(col.Tp.GetType()): + if col.Tp.GetFlen() > maxStringIndexPrefixLen { + key.Length = maxStringIndexPrefixLen + changed = true + } + } + } + } + return changed +} + +// columnDefaultValueRule removes literal non-NULL defaults from TEXT/BLOB/JSON columns. +type columnDefaultValueRule struct{} + +func (r columnDefaultValueRule) Apply(node ast.Node) bool { + col, ok := node.(*ast.ColumnDef) + if !ok || !isTextBlobOrJSON(col.Tp) { + return false + } + return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { + return opt.Tp == ast.ColumnOptionDefaultValue && + isNonNullLiteralValueExpr(opt.Expr) + }) +} + +func isNonNullLiteralValueExpr(expr ast.ExprNode) bool { + expr = unwrapParentheses(expr) + valExpr, ok := expr.(ast.ValueExpr) + if !ok { + return false + } + return valExpr.GetValue() != nil +} + +// functionDefaultRule removes time-function defaults from column types that TiDB rejects. +type functionDefaultRule struct{} + +func (r functionDefaultRule) Apply(node ast.Node) bool { + col, ok := node.(*ast.ColumnDef) + if !ok { + return false + } + return filterColumnOptions(col, func(opt *ast.ColumnOption) bool { + return opt.Tp == ast.ColumnOptionDefaultValue && + isUnsupportedTimeDefault(col.Tp.GetType(), opt.Expr) + }) +} + +func isUnsupportedTimeDefault(colType byte, expr ast.ExprNode) bool { + expr = unwrapParentheses(expr) + fn, ok := expr.(*ast.FuncCallExpr) + if !ok { + return false + } + switch fn.FnName.L { + case ast.CurrentTimestamp, ast.Now, ast.LocalTime, ast.LocalTimestamp: + return colType != mysql.TypeTimestamp && colType != mysql.TypeDatetime + case ast.CurrentDate: + return colType != mysql.TypeDate && colType != mysql.TypeDatetime + case ast.CurrentTime: + return colType != mysql.TypeDuration + default: + return false + } +} + +// jsonValueRule rewrites MariaDB JSON_VALUE in generated column expressions +// to supported JSON functions. +type jsonValueRule struct{} + +func (r jsonValueRule) Apply(node ast.Node) bool { + col, ok := node.(*ast.ColumnDef) + if !ok { + return false + } + changed := false + for _, opt := range col.Options { + if opt.Tp == ast.ColumnOptionGenerated { + if expr, ok := rewriteJSONValueExpr(opt.Expr); ok { + opt.Expr = expr + changed = true + } + } + } + return changed +} + +func rewriteJSONValueExpr(expr ast.ExprNode) (ast.ExprNode, bool) { + visitor := &jsonValueExprRewriteVisitor{} + node, ok := expr.Accept(visitor) + if !ok { + return expr, false + } + newExpr, ok := node.(ast.ExprNode) + if !ok { + return expr, false + } + return newExpr, visitor.changed +} + +type jsonValueExprRewriteVisitor struct { + changed bool +} + +func (v *jsonValueExprRewriteVisitor) Enter(node ast.Node) (ast.Node, bool) { + return node, false +} + +func (v *jsonValueExprRewriteVisitor) Leave(node ast.Node) (ast.Node, bool) { + fn, ok := node.(*ast.FuncCallExpr) + if !ok || fn.FnName.L != "json_value" || len(fn.Args) != 2 { + return node, true + } + v.changed = true + jsonExtract := &ast.FuncCallExpr{ + FnName: ast.NewCIStr(ast.JSONExtract), + Args: fn.Args, + } + return &ast.FuncCallExpr{ + FnName: ast.NewCIStr(ast.JSONUnquote), + Args: []ast.ExprNode{jsonExtract}, + }, true +} + +func isTextBlobOrJSON(ft *types.FieldType) bool { + return types.IsTypeBlob(ft.GetType()) || ft.GetType() == mysql.TypeJSON +} diff --git a/dm/pkg/ddl/rewriter/rewriter.go b/dm/pkg/ddl/rewriter/rewriter.go new file mode 100644 index 0000000000..12dffe7abc --- /dev/null +++ b/dm/pkg/ddl/rewriter/rewriter.go @@ -0,0 +1,79 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package rewriter + +import ( + "github.com/pingcap/tidb/pkg/parser/ast" +) + +// rule defines a rule to apply to AST nodes with best effort. +type rule interface { + Apply(ast.Node) bool +} + +type rewriteOptions struct { + rules []rule +} + +// Option configures the AST rules used by RewriteStmt. +type Option interface { + apply(*rewriteOptions) +} + +type optionFunc func(*rewriteOptions) + +func (f optionFunc) apply(options *rewriteOptions) { + f(options) +} + +// WithMariaDBCompatibility enables MariaDB compatibility AST rewrite rules. +func WithMariaDBCompatibility() Option { + return optionFunc(func(options *rewriteOptions) { + options.rules = append(options.rules, mariaDBCompatibilityRules...) + }) +} + +// RewriteStmt applies enabled rules to stmt in place. +// It is a best-effort compatibility layer for parsed AST nodes; parser failures and +// DDL failures that can be handled by downstream session settings, such as SQL mode, +// are intentionally left to the normal DM flow. +func RewriteStmt(stmt ast.StmtNode, opts ...Option) bool { + options := rewriteOptions{} + for _, opt := range opts { + opt.apply(&options) + } + if stmt == nil || len(options.rules) == 0 { + return false + } + visitor := &rewriteVisitor{rules: options.rules} + stmt.Accept(visitor) + return visitor.changed +} + +type rewriteVisitor struct { + rules []rule + changed bool +} + +func (v *rewriteVisitor) Enter(node ast.Node) (ast.Node, bool) { + return node, false +} + +func (v *rewriteVisitor) Leave(node ast.Node) (ast.Node, bool) { + for _, rule := range v.rules { + changed := rule.Apply(node) + v.changed = v.changed || changed + } + return node, true +} diff --git a/dm/pkg/ddl/rewriter/rewriter_test.go b/dm/pkg/ddl/rewriter/rewriter_test.go new file mode 100644 index 0000000000..76c89d8e6f --- /dev/null +++ b/dm/pkg/ddl/rewriter/rewriter_test.go @@ -0,0 +1,206 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package rewriter + +import ( + "strings" + "testing" + + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/format" + _ "github.com/pingcap/tidb/pkg/types/parser_driver" // register parser driver + "github.com/stretchr/testify/require" +) + +func TestRewriteStmtRemovesFunctionDefaultOnVarchar(t *testing.T) { + stmt, changed := rewriteCreateTable(t, "CREATE TABLE t(t VARCHAR(100) DEFAULT current_timestamp());") + require.True(t, changed) + + col := findColumn(stmt, "t") + require.NotNil(t, col) + require.False(t, hasColumnOption(col, ast.ColumnOptionDefaultValue)) +} + +func TestRewriteStmtKeepsTimeFunctionDefaultOnTimeColumn(t *testing.T) { + stmt, changed := rewriteCreateTable(t, "CREATE TABLE t(ts TIMESTAMP DEFAULT current_timestamp());") + require.False(t, changed) + require.True(t, hasColumnOption(findColumn(stmt, "ts"), ast.ColumnOptionDefaultValue)) +} + +func TestRewriteStmtKeepsZeroTimeDefaults(t *testing.T) { + stmt, changed := rewriteCreateTable(t, `CREATE TABLE t( + d DATE DEFAULT '0000-00-00', + dt DATETIME DEFAULT '0000-00-00 00:00:00', + ts TIMESTAMP DEFAULT '0000-00-00 00:00:00' +)`) + require.False(t, changed) + + require.True(t, hasColumnOption(findColumn(stmt, "d"), ast.ColumnOptionDefaultValue)) + require.True(t, hasColumnOption(findColumn(stmt, "dt"), ast.ColumnOptionDefaultValue)) + require.True(t, hasColumnOption(findColumn(stmt, "ts"), ast.ColumnOptionDefaultValue)) +} + +func TestRewriteStmtTimeFunctionDefaultRules(t *testing.T) { + stmt, changed := rewriteCreateTable(t, `CREATE TABLE t( + dec_col DECIMAL(30,6) DEFAULT CURRENT_TIMESTAMP(6), + time_col TIME DEFAULT CURRENT_TIMESTAMP(), + date_col DATE DEFAULT CURRENT_TIMESTAMP(), + ok_date DATE DEFAULT CURRENT_DATE, + ok_datetime DATETIME DEFAULT CURRENT_DATE +)`) + require.True(t, changed) + + require.False(t, hasColumnOption(findColumn(stmt, "dec_col"), ast.ColumnOptionDefaultValue)) + require.False(t, hasColumnOption(findColumn(stmt, "time_col"), ast.ColumnOptionDefaultValue)) + require.False(t, hasColumnOption(findColumn(stmt, "date_col"), ast.ColumnOptionDefaultValue)) + require.True(t, hasColumnOption(findColumn(stmt, "ok_date"), ast.ColumnOptionDefaultValue)) + require.True(t, hasColumnOption(findColumn(stmt, "ok_datetime"), ast.ColumnOptionDefaultValue)) +} + +func TestRewriteStmtDefaultRules(t *testing.T) { + input := `CREATE TABLE t ( + id INT(11), + txt TEXT DEFAULT 'x', + txt_null TEXT DEFAULT NULL, + txt_expr TEXT DEFAULT(uuid()), + v VARCHAR(800), + j JSON DEFAULT(json_object('now', now())), + g JSON GENERATED ALWAYS AS (JSON_EXTRACT(j, '$.a')) VIRTUAL, + zero_ts TIMESTAMP DEFAULT '0000-00-00 00:00:00', + CHECK (json_valid(j)), + KEY idx_txt (txt), + KEY idx_v (v), + UNIQUE KEY uk_txt (txt) +) DEFAULT CHARSET=latin1 COLLATE=latin1_swedish_ci;` + + stmt, changed := rewriteCreateTable(t, input) + require.True(t, changed) + + require.Equal(t, "latin1", findTableOption(stmt, ast.TableOptionCharset)) + require.Equal(t, "latin1_swedish_ci", findTableOption(stmt, ast.TableOptionCollate)) + require.Equal(t, 11, findColumn(stmt, "id").Tp.GetFlen()) + require.Equal(t, 800, findColumn(stmt, "v").Tp.GetFlen()) + require.False(t, hasColumnOption(findColumn(stmt, "txt"), ast.ColumnOptionDefaultValue)) + require.True(t, hasColumnOption(findColumn(stmt, "txt_null"), ast.ColumnOptionDefaultValue)) + require.True(t, hasColumnOption(findColumn(stmt, "txt_expr"), ast.ColumnOptionDefaultValue)) + require.True(t, hasColumnOption(findColumn(stmt, "j"), ast.ColumnOptionDefaultValue)) + require.True(t, hasColumnOption(findColumn(stmt, "g"), ast.ColumnOptionGenerated)) + require.True(t, hasColumnOption(findColumn(stmt, "zero_ts"), ast.ColumnOptionDefaultValue)) + require.True(t, hasCheckConstraint(stmt)) + + idxTxt := findConstraint(stmt, "idx_txt") + require.NotNil(t, idxTxt) + require.Equal(t, 255, idxTxt.Keys[0].Length) + idxV := findConstraint(stmt, "idx_v") + require.NotNil(t, idxV) + require.Equal(t, 768, idxV.Keys[0].Length) + ukTxt := findConstraint(stmt, "uk_txt") + require.NotNil(t, ukTxt) + require.Equal(t, -1, ukTxt.Keys[0].Length) +} + +func TestRewriteStmtSkipsExpressionIndexPrefix(t *testing.T) { + rewriteCreateTable(t, "CREATE TABLE t(name VARCHAR(32), KEY idx_expr ((LOWER(name))));") +} + +func TestRewriteStmtRewritesParenthesizedJSONValueGeneratedColumn(t *testing.T) { + stmt, changed := rewriteCreateTable(t, + "CREATE TABLE t(j JSON, g TIME GENERATED ALWAYS AS (CAST((JSON_VALUE(j, '$.a')) AS TIME)) VIRTUAL);", + ) + require.True(t, changed) + expr := generatedExpr(findColumn(stmt, "g")) + require.NotNil(t, expr) + restored := strings.ToLower(restoreNode(t, expr)) + require.Contains(t, restored, "json_unquote(json_extract") + require.NotContains(t, restored, "json_value") +} + +func rewriteCreateTable(t *testing.T, sql string) (*ast.CreateTableStmt, bool) { + t.Helper() + stmt := parseCreateTable(t, sql) + changed := RewriteStmt(stmt, WithMariaDBCompatibility()) + return stmt, changed +} + +func parseCreateTable(t *testing.T, sql string) *ast.CreateTableStmt { + t.Helper() + stmt, err := parser.New().ParseOneStmt(sql, "", "") + require.NoError(t, err) + create, ok := stmt.(*ast.CreateTableStmt) + require.True(t, ok) + return create +} + +func findColumn(stmt *ast.CreateTableStmt, name string) *ast.ColumnDef { + for _, col := range stmt.Cols { + if strings.EqualFold(col.Name.Name.O, name) { + return col + } + } + return nil +} + +func hasColumnOption(col *ast.ColumnDef, optionType ast.ColumnOptionType) bool { + for _, opt := range col.Options { + if opt.Tp == optionType { + return true + } + } + return false +} + +func hasCheckConstraint(stmt *ast.CreateTableStmt) bool { + for _, cons := range stmt.Constraints { + if cons.Tp == ast.ConstraintCheck { + return true + } + } + return false +} + +func generatedExpr(col *ast.ColumnDef) ast.ExprNode { + for _, opt := range col.Options { + if opt.Tp == ast.ColumnOptionGenerated { + return opt.Expr + } + } + return nil +} + +func findConstraint(stmt *ast.CreateTableStmt, name string) *ast.Constraint { + for _, cons := range stmt.Constraints { + if strings.EqualFold(cons.Name, name) { + return cons + } + } + return nil +} + +func findTableOption(stmt *ast.CreateTableStmt, optionType ast.TableOptionType) string { + for _, opt := range stmt.Options { + if opt.Tp == optionType { + return strings.ToLower(opt.StrValue) + } + } + return "" +} + +func restoreNode(t *testing.T, node ast.Node) string { + t.Helper() + var sb strings.Builder + ctx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb) + require.NoError(t, node.Restore(ctx)) + return sb.String() +} diff --git a/dm/pkg/ddl/rewriter/strict_collation_rule.go b/dm/pkg/ddl/rewriter/strict_collation_rule.go new file mode 100644 index 0000000000..8cc75be0e6 --- /dev/null +++ b/dm/pkg/ddl/rewriter/strict_collation_rule.go @@ -0,0 +1,183 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package rewriter + +import ( + "strings" + + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tiflow/dm/pkg/binlog/event" + "github.com/pingcap/tiflow/dm/pkg/log" + "go.uber.org/zap" +) + +// WithStrictCollation enables collation-compatible=strict AST rewrite rules. +func WithStrictCollation( + statusVars []byte, + charsetAndDefaultCollation map[string]string, + idAndCollationMap map[int]string, + originSQL string, + logger log.Logger, +) Option { + return optionFunc(func(options *rewriteOptions) { + if logger.Logger == nil { + logger = log.L() + } + options.rules = append(options.rules, strictCollationRule{ + statusVars: statusVars, + charsetAndDefaultCollation: charsetAndDefaultCollation, + idAndCollationMap: idAndCollationMap, + originSQL: originSQL, + logger: logger, + }) + }) +} + +// strictCollationRule adds explicit collations from upstream INFORMATION_SCHEMA.COLLATIONS. +type strictCollationRule struct { + statusVars []byte + charsetAndDefaultCollation map[string]string + idAndCollationMap map[int]string + originSQL string + logger log.Logger +} + +func (r strictCollationRule) Apply(node ast.Node) bool { + switch createStmt := node.(type) { + case *ast.CreateTableStmt: + return r.rewriteCreateTable(createStmt) + case *ast.CreateDatabaseStmt: + return r.rewriteCreateDatabase(createStmt) + default: + return false + } +} + +func (r strictCollationRule) rewriteCreateTable(createStmt *ast.CreateTableStmt) bool { + if createStmt.ReferTable != nil { + return false + } + + changed := r.rewriteColumnCollations(createStmt) + var justCharset string + for _, tableOption := range createStmt.Options { + if tableOption.Tp == ast.TableOptionCollate { + return changed + } + if tableOption.Tp == ast.TableOptionCharset { + justCharset = tableOption.StrValue + } + } + if justCharset == "" { + r.logger.Warn("detect create table risk which use implicit charset and collation", + zap.String("originSQL", r.originSQL)) + return changed + } + + collation, ok := r.charsetAndDefaultCollation[strings.ToLower(justCharset)] + if !ok { + r.logger.Warn("not found charset default collation.", + zap.String("originSQL", r.originSQL), + zap.String("charset", strings.ToLower(justCharset))) + return changed + } + r.logger.Info( + "detect create table risk which use explicit charset and implicit collation, we will add collation by INFORMATION_SCHEMA.COLLATIONS", + zap.String("originSQL", r.originSQL), + zap.String("collation", collation), + ) + createStmt.Options = append(createStmt.Options, &ast.TableOption{Tp: ast.TableOptionCollate, StrValue: collation}) + return true +} + +func (r strictCollationRule) rewriteCreateDatabase(createStmt *ast.CreateDatabaseStmt) bool { + var justCharset string + for _, createOption := range createStmt.Options { + if createOption.Tp == ast.DatabaseOptionCollate { + return false + } + if createOption.Tp == ast.DatabaseOptionCharset { + justCharset = createOption.Value + } + } + + var collation string + if justCharset != "" { + var ok bool + collation, ok = r.charsetAndDefaultCollation[strings.ToLower(justCharset)] + if !ok { + r.logger.Warn("not found charset default collation.", + zap.String("originSQL", r.originSQL), + zap.String("charset", strings.ToLower(justCharset))) + return false + } + r.logger.Info( + "detect create database risk which use explicit charset and implicit collation, we will add collation by INFORMATION_SCHEMA.COLLATIONS", + zap.String("originSQL", r.originSQL), + zap.String("collation", collation), + ) + } else { + var err error + collation, err = event.GetServerCollationByStatusVars(r.statusVars, r.idAndCollationMap) + if err != nil { + r.logger.Error("can not get charset server collation from binlog statusVars.", + zap.Error(err), + zap.String("originSQL", r.originSQL)) + } + if collation == "" { + r.logger.Error("get server collation from binlog statusVars is nil.", + zap.Error(err), + zap.String("originSQL", r.originSQL)) + return false + } + r.logger.Info( + "detect create database risk which use implicit charset and collation, we will add collation by binlog status_vars", + zap.String("originSQL", r.originSQL), + zap.String("collation", collation), + ) + } + createStmt.Options = append(createStmt.Options, &ast.DatabaseOption{Tp: ast.DatabaseOptionCollate, Value: collation}) + return true +} + +func (r strictCollationRule) rewriteColumnCollations(createStmt *ast.CreateTableStmt) bool { + changed := false +ColumnLoop: + for _, col := range createStmt.Cols { + for _, options := range col.Options { + if options.Tp == ast.ColumnOptionCollate { + continue ColumnLoop + } + } + fieldType := col.Tp + if fieldType.GetCollate() != "" || fieldType.GetCharset() == "" { + continue + } + + collation, ok := r.charsetAndDefaultCollation[strings.ToLower(fieldType.GetCharset())] + if !ok { + r.logger.Warn( + "not found charset default collation for column.", + zap.String("originSQL", r.originSQL), + zap.String("table", createStmt.Table.Name.String()), + zap.String("column", col.Name.String()), + zap.String("charset", strings.ToLower(fieldType.GetCharset())), + ) + continue + } + col.Options = append(col.Options, &ast.ColumnOption{Tp: ast.ColumnOptionCollate, StrValue: collation}) + changed = true + } + return changed +} diff --git a/dm/pkg/ddl/rewriter/utils.go b/dm/pkg/ddl/rewriter/utils.go new file mode 100644 index 0000000000..f924f53fab --- /dev/null +++ b/dm/pkg/ddl/rewriter/utils.go @@ -0,0 +1,40 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package rewriter + +import "github.com/pingcap/tidb/pkg/parser/ast" + +func filterColumnOptions(col *ast.ColumnDef, drop func(*ast.ColumnOption) bool) bool { + options := col.Options[:0] + changed := false + for _, opt := range col.Options { + if drop(opt) { + changed = true + continue + } + options = append(options, opt) + } + col.Options = options + return changed +} + +func unwrapParentheses(expr ast.ExprNode) ast.ExprNode { + for { + p, ok := expr.(*ast.ParenthesesExpr) + if !ok { + return expr + } + expr = p.Expr + } +} diff --git a/dm/syncer/ddl.go b/dm/syncer/ddl.go index 2731e8feb1..d7a3155289 100644 --- a/dm/syncer/ddl.go +++ b/dm/syncer/ddl.go @@ -37,6 +37,7 @@ import ( "github.com/pingcap/tiflow/dm/pkg/binlog/event" "github.com/pingcap/tiflow/dm/pkg/conn" tcontext "github.com/pingcap/tiflow/dm/pkg/context" + ddlrewriter "github.com/pingcap/tiflow/dm/pkg/ddl/rewriter" "github.com/pingcap/tiflow/dm/pkg/log" parserpkg "github.com/pingcap/tiflow/dm/pkg/parser" "github.com/pingcap/tiflow/dm/pkg/schema" @@ -81,6 +82,7 @@ type DDLWorker struct { idAndCollationMap map[int]string baList *filter.Filter foreignKeyChecksEnabled bool + enableDDLRewrite bool getTableInfo func(tctx *tcontext.Context, sourceTable, targetTable *filter.Table) (*model.TableInfo, error) getDBInfoFromDownstream func(tctx *tcontext.Context, sourceTable, targetTable *filter.Table) (*model.DBInfo, error) @@ -111,6 +113,7 @@ func NewDDLWorker(pLogger *log.Logger, syncer *Syncer) *DDLWorker { idAndCollationMap: syncer.idAndCollationMap, baList: syncer.baList, foreignKeyChecksEnabled: config.IsForeignKeyChecksEnabled(syncer.cfg.To.Session), + enableDDLRewrite: syncer.enableDDLRewrite, recordSkipSQLsLocation: syncer.recordSkipSQLsLocation, trackDDL: syncer.trackDDL, saveTablePoint: syncer.saveTablePoint, @@ -238,13 +241,14 @@ func (ddl *DDLWorker) HandleQueryEvent(ev *replication.QueryEvent, ec eventConte } qec := &queryEventContext{ - eventContext: &ec, - ddlSchema: string(ev.Schema), - originSQL: utils.TrimCtrlChars(originSQL), - splitDDLs: make([]string, 0), - appliedDDLs: make([]string, 0), - sourceTbls: make(map[string]map[string]struct{}), - eventStatusVars: ev.StatusVars, + eventContext: &ec, + ddlSchema: string(ev.Schema), + originSQL: utils.TrimCtrlChars(originSQL), + enableDDLRewrite: ddl.enableDDLRewrite, + splitDDLs: make([]string, 0), + appliedDDLs: make([]string, 0), + sourceTbls: make(map[string]map[string]struct{}), + eventStatusVars: ev.StatusVars, } defer func() { @@ -964,7 +968,16 @@ func parseOneStmt(qec *queryEventContext) (stmt ast.StmtNode, err error) { if len(stmts) == 0 { return nil, nil } - return stmts[0], nil + stmt = stmts[0] + if qec.enableDDLRewrite { + changed := ddlrewriter.RewriteStmt(stmt, ddlrewriter.WithMariaDBCompatibility()) + if changed { + qec.tctx.L().Info("rewrite MariaDB DDL with AST compatibility rules", + zap.String("event", "query"), + zap.String("originSQL", qec.originSQL)) + } + } + return stmt, nil } // copy from https://github.com/pingcap/tidb/blob/fc4f8a1d8f5342cd01f78eb460e47d78d177ed20/ddl/column.go#L366 @@ -1322,9 +1335,17 @@ func (ddl *DDLWorker) genDDLInfo(qec *queryEventContext, sql string) (*ddlInfo, targetTables: targetTables, } - // "strict" will adjust collation if ddl.collationCompatible == config.StrictCollationCompatible { - ddl.adjustCollation(ddlInfo, qec.eventStatusVars, ddl.charsetAndDefaultCollation, ddl.idAndCollationMap) + ddlrewriter.RewriteStmt( + ddlInfo.stmtCache, + ddlrewriter.WithStrictCollation( + qec.eventStatusVars, + ddl.charsetAndDefaultCollation, + ddl.idAndCollationMap, + ddlInfo.originDDL, + ddl.logger, + ), + ) } routedDDL, err := parserpkg.RenameDDLTable(ddlInfo.stmtCache, ddlInfo.targetTables) @@ -1415,104 +1436,6 @@ func (ddl *Pessimist) clearOnlineDDL(tctx *tcontext.Context, targetTable *filter return nil } -// adjustCollation adds collation for create database and check create table. -func (ddl *DDLWorker) adjustCollation(ddlInfo *ddlInfo, statusVars []byte, charsetAndDefaultCollationMap map[string]string, idAndCollationMap map[int]string) { - switch createStmt := ddlInfo.stmtCache.(type) { - case *ast.CreateTableStmt: - if createStmt.ReferTable != nil { - return - } - ddl.adjustColumnsCollation(createStmt, charsetAndDefaultCollationMap) - var justCharset string - for _, tableOption := range createStmt.Options { - // already have 'Collation' - if tableOption.Tp == ast.TableOptionCollate { - return - } - if tableOption.Tp == ast.TableOptionCharset { - justCharset = tableOption.StrValue - } - } - if justCharset == "" { - ddl.logger.Warn("detect create table risk which use implicit charset and collation", zap.String("originSQL", ddlInfo.originDDL)) - return - } - // just has charset, can add collation by charset and default collation map - collation, ok := charsetAndDefaultCollationMap[strings.ToLower(justCharset)] - if !ok { - ddl.logger.Warn("not found charset default collation.", zap.String("originSQL", ddlInfo.originDDL), zap.String("charset", strings.ToLower(justCharset))) - return - } - ddl.logger.Info("detect create table risk which use explicit charset and implicit collation, we will add collation by INFORMATION_SCHEMA.COLLATIONS", zap.String("originSQL", ddlInfo.originDDL), zap.String("collation", collation)) - createStmt.Options = append(createStmt.Options, &ast.TableOption{Tp: ast.TableOptionCollate, StrValue: collation}) - - case *ast.CreateDatabaseStmt: - var justCharset, collation string - var ok bool - var err error - for _, createOption := range createStmt.Options { - // already have 'Collation' - if createOption.Tp == ast.DatabaseOptionCollate { - return - } - if createOption.Tp == ast.DatabaseOptionCharset { - justCharset = createOption.Value - } - } - - // just has charset, can add collation by charset and default collation map - if justCharset != "" { - collation, ok = charsetAndDefaultCollationMap[strings.ToLower(justCharset)] - if !ok { - ddl.logger.Warn("not found charset default collation.", zap.String("originSQL", ddlInfo.originDDL), zap.String("charset", strings.ToLower(justCharset))) - return - } - ddl.logger.Info("detect create database risk which use explicit charset and implicit collation, we will add collation by INFORMATION_SCHEMA.COLLATIONS", zap.String("originSQL", ddlInfo.originDDL), zap.String("collation", collation)) - } else { - // has no charset and collation - // add collation by server collation from binlog statusVars - collation, err = event.GetServerCollationByStatusVars(statusVars, idAndCollationMap) - if err != nil { - ddl.logger.Error("can not get charset server collation from binlog statusVars.", zap.Error(err), zap.String("originSQL", ddlInfo.originDDL)) - } - if collation == "" { - ddl.logger.Error("get server collation from binlog statusVars is nil.", zap.Error(err), zap.String("originSQL", ddlInfo.originDDL)) - return - } - // add collation - ddl.logger.Info("detect create database risk which use implicit charset and collation, we will add collation by binlog status_vars", zap.String("originSQL", ddlInfo.originDDL), zap.String("collation", collation)) - } - createStmt.Options = append(createStmt.Options, &ast.DatabaseOption{Tp: ast.DatabaseOptionCollate, Value: collation}) - } -} - -// adjustColumnsCollation adds column's collation. -func (ddl *DDLWorker) adjustColumnsCollation(createStmt *ast.CreateTableStmt, charsetAndDefaultCollationMap map[string]string) { -ColumnLoop: - for _, col := range createStmt.Cols { - for _, options := range col.Options { - // already have 'Collation' - if options.Tp == ast.ColumnOptionCollate { - continue ColumnLoop - } - } - fieldType := col.Tp - // already have 'Collation' - if fieldType.GetCollate() != "" { - continue - } - if fieldType.GetCharset() != "" { - // just have charset - collation, ok := charsetAndDefaultCollationMap[strings.ToLower(fieldType.GetCharset())] - if !ok { - ddl.logger.Warn("not found charset default collation for column.", zap.String("table", createStmt.Table.Name.String()), zap.String("column", col.Name.String()), zap.String("charset", strings.ToLower(fieldType.GetCharset()))) - continue - } - col.Options = append(col.Options, &ast.ColumnOption{Tp: ast.ColumnOptionCollate, StrValue: collation}) - } - } -} - type ddlInfo struct { originDDL string routedDDL string diff --git a/dm/syncer/ddl_test.go b/dm/syncer/ddl_test.go index 23d7490fa1..8ed087d0fc 100644 --- a/dm/syncer/ddl_test.go +++ b/dm/syncer/ddl_test.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tiflow/dm/config/dbconfig" "github.com/pingcap/tiflow/dm/pkg/conn" tcontext "github.com/pingcap/tiflow/dm/pkg/context" + ddlrewriter "github.com/pingcap/tiflow/dm/pkg/ddl/rewriter" "github.com/pingcap/tiflow/dm/pkg/log" parserpkg "github.com/pingcap/tiflow/dm/pkg/parser" "github.com/pingcap/tiflow/dm/pkg/terror" @@ -419,6 +420,24 @@ func (s *testDDLSuite) TestParseOneStmt(c *check.C) { } } +func TestParseOneStmtWithMariaDBASTRewriter(t *testing.T) { + tctx := tcontext.Background().WithLogger(log.With(zap.String("test", "TestParseOneStmtWithMariaDBASTRewriter"))) + qec := &queryEventContext{ + eventContext: &eventContext{tctx: tctx}, + ddlSchema: "test", + originSQL: "CREATE TABLE t(c VARCHAR(100) DEFAULT current_timestamp())", + p: parser.New(), + enableDDLRewrite: true, + } + + stmt, err := parseOneStmt(qec) + require.NoError(t, err) + sqls, err := parserpkg.SplitDDL(stmt, qec.ddlSchema) + require.NoError(t, err) + require.Len(t, sqls, 1) + require.NotContains(t, strings.ToLower(sqls[0]), "default") +} + func (s *testDDLSuite) TestResolveGeneratedColumnSQL(c *check.C) { testCases := []struct { sql string @@ -729,11 +748,6 @@ func (s *testDDLSuite) TestAdjustDatabaseCollation(c *check.C) { } tctx := tcontext.Background().WithLogger(log.With(zap.String("test", "TestAdjustTableCollation"))) - syncer := NewSyncer(&config.SubTaskConfig{ - Flavor: mysql.MySQLFlavor, - CollationCompatible: config.StrictCollationCompatible, - }, nil, nil) - syncer.tctx = tctx p := parser.New() tab := &filter.Table{ Schema: "test", @@ -742,7 +756,6 @@ func (s *testDDLSuite) TestAdjustDatabaseCollation(c *check.C) { charsetAndDefaultCollationMap := map[string]string{"utf8mb4": "utf8mb4_general_ci"} idAndCollationMap := map[int]string{46: "utf8mb4_bin", 277: "utf8mb4_vi_0900_ai_ci"} - ddlWorker := NewDDLWorker(&tctx.Logger, syncer) for i, statusVars := range statusVarsArray { for j, sql := range sqls { ddlInfo := &ddlInfo{ @@ -755,7 +768,16 @@ func (s *testDDLSuite) TestAdjustDatabaseCollation(c *check.C) { c.Assert(err, check.IsNil) c.Assert(stmt, check.NotNil) ddlInfo.stmtCache = stmt - ddlWorker.adjustCollation(ddlInfo, statusVars, charsetAndDefaultCollationMap, idAndCollationMap) + ddlrewriter.RewriteStmt( + ddlInfo.stmtCache, + ddlrewriter.WithStrictCollation( + statusVars, + charsetAndDefaultCollationMap, + idAndCollationMap, + ddlInfo.originDDL, + tctx.Logger, + ), + ) routedDDL, err := parserpkg.RenameDDLTable(ddlInfo.stmtCache, ddlInfo.targetTables) c.Assert(err, check.IsNil) c.Assert(routedDDL, check.Equals, expectedSQLs[i][j]) @@ -809,11 +831,6 @@ func TestAdjustCollation(t *testing.T) { } tctx := tcontext.Background().WithLogger(log.With(zap.String("test", "TestAdjustTableCollation"))) - syncer := NewSyncer(&config.SubTaskConfig{ - Flavor: mysql.MySQLFlavor, - CollationCompatible: config.StrictCollationCompatible, - }, nil, nil) - syncer.tctx = tctx p := parser.New() tab := &filter.Table{ Schema: "test", @@ -822,7 +839,6 @@ func TestAdjustCollation(t *testing.T) { statusVars := []byte{4, 0, 0, 0, 0, 46, 0} charsetAndDefaultCollationMap := map[string]string{"utf8mb4": "utf8mb4_general_ci", "latin1": "latin1_swedish_ci"} idAndCollationMap := map[int]string{46: "utf8mb4_bin"} - ddlWorker := NewDDLWorker(&tctx.Logger, syncer) for i, sql := range sqls { ddlInfo := &ddlInfo{ originDDL: sql, @@ -834,7 +850,16 @@ func TestAdjustCollation(t *testing.T) { require.NoError(t, err) require.NotNil(t, stmt) ddlInfo.stmtCache = stmt - ddlWorker.adjustCollation(ddlInfo, statusVars, charsetAndDefaultCollationMap, idAndCollationMap) + ddlrewriter.RewriteStmt( + ddlInfo.stmtCache, + ddlrewriter.WithStrictCollation( + statusVars, + charsetAndDefaultCollationMap, + idAndCollationMap, + ddlInfo.originDDL, + tctx.Logger, + ), + ) routedDDL, err := parserpkg.RenameDDLTable(ddlInfo.stmtCache, ddlInfo.targetTables) require.NoError(t, err) require.Equal(t, expectedSQLs[i], routedDDL) diff --git a/dm/syncer/syncer.go b/dm/syncer/syncer.go index 1013bf2a39..ce2aea2b07 100644 --- a/dm/syncer/syncer.go +++ b/dm/syncer/syncer.go @@ -266,6 +266,7 @@ type Syncer struct { idAndCollationMap map[int]string ddlWorker *DDLWorker + enableDDLRewrite bool fetchBinlogLogger *zap.Logger unhandledEventLogger *zap.Logger } @@ -533,6 +534,7 @@ func (s *Syncer) Init(ctx context.Context) (err error) { } s.metricsProxies = metricProxies.CacheForOneTask(s.cfg.Name, s.cfg.WorkerName, s.cfg.SourceID) + s.enableDDLRewrite = strings.EqualFold(s.cfg.Flavor, mysql.MariaDBFlavor) s.ddlWorker = NewDDLWorker(&s.tctx.Logger, s) return nil } @@ -2768,9 +2770,10 @@ func (s *Syncer) handleRowsEvent(ev *replication.RowsEvent, ec eventContext) (*f type queryEventContext struct { *eventContext - p *parser.Parser // used parser - ddlSchema string // used schema - originSQL string // before split + p *parser.Parser // used parser + ddlSchema string // used schema + originSQL string // before split + enableDDLRewrite bool // split multi-schema change DDL into multiple one schema change DDL due to TiDB's limitation splitDDLs []string // after split before online ddl appliedDDLs []string // after onlineDDL apply if onlineDDL != nil @@ -2956,13 +2959,14 @@ func (s *Syncer) trackOriginDDL(ev *replication.QueryEvent, ec eventContext) (ma } var err error qec := &queryEventContext{ - eventContext: &ec, - ddlSchema: string(ev.Schema), - originSQL: utils.TrimCtrlChars(originSQL), - splitDDLs: make([]string, 0), - appliedDDLs: make([]string, 0), - sourceTbls: make(map[string]map[string]struct{}), - eventStatusVars: ev.StatusVars, + eventContext: &ec, + ddlSchema: string(ev.Schema), + originSQL: utils.TrimCtrlChars(originSQL), + enableDDLRewrite: s.enableDDLRewrite, + splitDDLs: make([]string, 0), + appliedDDLs: make([]string, 0), + sourceTbls: make(map[string]map[string]struct{}), + eventStatusVars: ev.StatusVars, } qec.p, err = event.GetParserForStatusVars(ev.StatusVars) if err != nil {