diff --git a/.agents/skills/tidb-test-guidelines/references/dxf-case-map.md b/.agents/skills/tidb-test-guidelines/references/dxf-case-map.md index 295a79f6213ef..dc06a6967e8ae 100644 --- a/.agents/skills/tidb-test-guidelines/references/dxf-case-map.md +++ b/.agents/skills/tidb-test-guidelines/references/dxf-case-map.md @@ -96,6 +96,7 @@ ## pkg/dxf/importinto ### Tests +- `pkg/dxf/importinto/clean_up_test.go` - dxf/importinto: Tests cleanup metering concurrency. - `pkg/dxf/importinto/collect_conflicts_test.go` - dxf/importinto: Tests collect conflicts step executor. - `pkg/dxf/importinto/conflict_resolution_test.go` - dxf/importinto: Tests conflict resolution step executor. - `pkg/dxf/importinto/encode_and_sort_operator_test.go` - dxf/importinto: Tests encode and sort operator. diff --git a/pkg/dxf/framework/integrationtests/bench_test.go b/pkg/dxf/framework/integrationtests/bench_test.go index dc87029dc7fe7..85b6d69292534 100644 --- a/pkg/dxf/framework/integrationtests/bench_test.go +++ b/pkg/dxf/framework/integrationtests/bench_test.go @@ -41,7 +41,7 @@ import ( ) var ( - maxConcurrentTask = flag.Int("max-concurrent-task", proto.MaxConcurrentTask, "max concurrent task") + maxConcurrentTask = flag.Int("max-concurrent-task", proto.GetMaxConcurrentTask(), "max concurrent task") waitDuration = flag.Duration("task-wait-duration", 2*time.Minute, "task wait duration") schedulerInterval = flag.Duration("scheduler-interval", scheduler.CheckTaskFinishedInterval, "scheduler interval") taskExecutorMgrInterval = flag.Duration("task-executor-mgr-interval", taskexecutor.TaskCheckInterval, "task executor mgr interval") @@ -65,43 +65,43 @@ func BenchmarkSchedulerOverhead(b *testing.B) { }() schIntervalBak := scheduler.CheckTaskFinishedInterval exeMgrIntervalBak := taskexecutor.TaskCheckInterval - bak := proto.MaxConcurrentTask + restoreMaxConcurrentTask := proto.SetMaxConcurrentTaskForTest(*maxConcurrentTask) b.Cleanup(func() { - proto.MaxConcurrentTask = bak + restoreMaxConcurrentTask() scheduler.CheckTaskFinishedInterval = schIntervalBak taskexecutor.TaskCheckInterval = exeMgrIntervalBak }) - proto.MaxConcurrentTask = *maxConcurrentTask scheduler.CheckTaskFinishedInterval = *schedulerInterval taskexecutor.TaskCheckInterval = *taskExecutorMgrInterval - b.Logf("max concurrent task: %d", proto.MaxConcurrentTask) + maxConcurrentTaskValue := proto.GetMaxConcurrentTask() + b.Logf("max concurrent task: %d", maxConcurrentTaskValue) b.Logf("taks wait duration: %s", *waitDuration) b.Logf("task meta size: %d", *taskMetaSize) b.Logf("scheduler interval: %s", scheduler.CheckTaskFinishedInterval) b.Logf("task executor mgr interval: %s", taskexecutor.TaskCheckInterval) prepareForBenchTest(b) - c := testutil.NewTestDXFContext(b, 1, 2*proto.MaxConcurrentTask, false) + c := testutil.NewTestDXFContext(b, 1, 2*maxConcurrentTaskValue, false) registerTaskTypeForBench(c) if *noTask { time.Sleep(*waitDuration) } else { - // in this test, we will start 4*proto.MaxConcurrentTask tasks, but only - // proto.MaxConcurrentTask will be scheduled at the same time, for other + // in this test, we will start 4*maxConcurrentTaskValue tasks, but only + // maxConcurrentTaskValue will be scheduled at the same time, for other // tasks will be in queue only to check the performance of querying them. - for i := range 4 * proto.MaxConcurrentTask { + for i := range 4 * maxConcurrentTaskValue { taskKey := fmt.Sprintf("task-%03d", i) taskMeta := make([]byte, *taskMetaSize) _, err := handle.SubmitTask(c.Ctx, taskKey, proto.TaskTypeExample, c.Store.GetKeyspace(), 1, "", 0, taskMeta) require.NoError(c.T, err) } // task has 2 steps, each step has 1 subtask,wait in serial to reduce WaitTask check overhead. - // only wait first proto.MaxConcurrentTask and exit + // only wait first maxConcurrentTaskValue and exit time.Sleep(2 * *waitDuration) - for i := range proto.MaxConcurrentTask { + for i := range maxConcurrentTaskValue { taskKey := fmt.Sprintf("task-%03d", i) testutil.WaitTaskDoneOrPaused(c.Ctx, c.T, taskKey) } diff --git a/pkg/dxf/framework/proto/BUILD.bazel b/pkg/dxf/framework/proto/BUILD.bazel index 857d7bd32eb0a..54dd9ff5cde07 100644 --- a/pkg/dxf/framework/proto/BUILD.bazel +++ b/pkg/dxf/framework/proto/BUILD.bazel @@ -26,6 +26,6 @@ go_test( ], embed = [":proto"], flaky = True, - shard_count = 9, + shard_count = 10, deps = ["@com_github_stretchr_testify//require"], ) diff --git a/pkg/dxf/framework/proto/task.go b/pkg/dxf/framework/proto/task.go index 4fd3c905ea6b0..91aa60b1f12d7 100644 --- a/pkg/dxf/framework/proto/task.go +++ b/pkg/dxf/framework/proto/task.go @@ -17,6 +17,7 @@ package proto import ( "cmp" "fmt" + "sync/atomic" "time" ) @@ -63,9 +64,44 @@ const ( NormalPriority = 512 ) -// MaxConcurrentTask is the max concurrency of task. -// TODO: remove this limit later. -var MaxConcurrentTask = 16 +const ( + // DefaultMaxConcurrentTask is the default max concurrency of task. + DefaultMaxConcurrentTask = 16 + // MinMaxConcurrentTask is the minimum allowed max concurrency of task. + MinMaxConcurrentTask = 16 + // MaxMaxConcurrentTask is the maximum allowed max concurrency of task. + MaxMaxConcurrentTask = 1000 +) + +var maxConcurrentTask atomic.Int64 + +func init() { + maxConcurrentTask.Store(DefaultMaxConcurrentTask) +} + +// GetMaxConcurrentTask returns the max concurrency of task. +func GetMaxConcurrentTask() int { + return int(maxConcurrentTask.Load()) +} + +// SetMaxConcurrentTask updates the max concurrency of task. +func SetMaxConcurrentTask(value int) error { + if value < MinMaxConcurrentTask || value > MaxMaxConcurrentTask { + return fmt.Errorf("max_concurrent_task %d is out of range [%d, %d]", + value, MinMaxConcurrentTask, MaxMaxConcurrentTask) + } + maxConcurrentTask.Store(int64(value)) + return nil +} + +// SetMaxConcurrentTaskForTest updates the max concurrency of task and returns a restore function. +func SetMaxConcurrentTaskForTest(value int) func() { + old := GetMaxConcurrentTask() + maxConcurrentTask.Store(int64(value)) + return func() { + maxConcurrentTask.Store(int64(old)) + } +} // ExtraParams is the extra params of task. // Note: only store params that's not used for filter or sort in this struct. diff --git a/pkg/dxf/framework/proto/task_test.go b/pkg/dxf/framework/proto/task_test.go index 5bbff6032ea0f..5bd7251394628 100644 --- a/pkg/dxf/framework/proto/task_test.go +++ b/pkg/dxf/framework/proto/task_test.go @@ -47,6 +47,23 @@ func TestTaskIsDone(t *testing.T) { } } +func TestMaxConcurrentTask(t *testing.T) { + restore := SetMaxConcurrentTaskForTest(DefaultMaxConcurrentTask) + defer restore() + + require.Equal(t, DefaultMaxConcurrentTask, GetMaxConcurrentTask()) + require.Equal(t, 1000, MaxMaxConcurrentTask) + for _, value := range []int{MinMaxConcurrentTask - 1, MaxMaxConcurrentTask + 1} { + require.Error(t, SetMaxConcurrentTask(value)) + require.Equal(t, DefaultMaxConcurrentTask, GetMaxConcurrentTask()) + } + + require.NoError(t, SetMaxConcurrentTask(128)) + require.Equal(t, 128, GetMaxConcurrentTask()) + require.NoError(t, SetMaxConcurrentTask(MaxMaxConcurrentTask)) + require.Equal(t, MaxMaxConcurrentTask, GetMaxConcurrentTask()) +} + func TestTaskCompare(t *testing.T) { taskA := Task{TaskBase: TaskBase{ ID: 100, diff --git a/pkg/dxf/framework/scheduler/BUILD.bazel b/pkg/dxf/framework/scheduler/BUILD.bazel index 062bf2953fa17..94e9f3cc1463e 100644 --- a/pkg/dxf/framework/scheduler/BUILD.bazel +++ b/pkg/dxf/framework/scheduler/BUILD.bazel @@ -64,7 +64,7 @@ go_test( embed = [":scheduler"], flaky = True, race = "off", - shard_count = 43, + shard_count = 45, deps = [ "//pkg/config", "//pkg/config/kerneltype", diff --git a/pkg/dxf/framework/scheduler/interface.go b/pkg/dxf/framework/scheduler/interface.go index 80ad11dc1a518..d08622e84787a 100644 --- a/pkg/dxf/framework/scheduler/interface.go +++ b/pkg/dxf/framework/scheduler/interface.go @@ -27,7 +27,7 @@ import ( // TaskManager defines the interface to access task table. type TaskManager interface { - // GetTopUnfinishedTasks returns unfinished tasks, limited by MaxConcurrentTask*2, + // GetTopUnfinishedTasks returns unfinished tasks, limited by GetMaxConcurrentTask()*2, // to make sure low ranking tasks can be scheduled if resource is enough. // The returned tasks are sorted by task order, see proto.Task. GetTopUnfinishedTasks(ctx context.Context) ([]*proto.TaskBase, error) diff --git a/pkg/dxf/framework/scheduler/main_test.go b/pkg/dxf/framework/scheduler/main_test.go index ec33136fbca84..b254de47a57a7 100644 --- a/pkg/dxf/framework/scheduler/main_test.go +++ b/pkg/dxf/framework/scheduler/main_test.go @@ -33,7 +33,7 @@ func (sm *Manager) DelRunningTask(id int64) { // DoCleanupRoutine implements Scheduler.DoCleanupRoutine interface. func (sm *Manager) DoCleanupRoutine() { - sm.doCleanupTask() + sm.doCleanupTasks() } func (s *BaseScheduler) Switch2NextStep() (err error) { diff --git a/pkg/dxf/framework/scheduler/scheduler.go b/pkg/dxf/framework/scheduler/scheduler.go index e0e1a8fde7532..f870dfa478237 100644 --- a/pkg/dxf/framework/scheduler/scheduler.go +++ b/pkg/dxf/framework/scheduler/scheduler.go @@ -503,7 +503,7 @@ func (s *BaseScheduler) switch2NextStep() error { // OnNextSubtasksBatch may use len(eligibleNodes) as a hint to // calculate the number of subtasks, so we need to do this before // filtering nodes by available slots in scheduleSubtask. - eligibleNodes = eligibleNodes[:task.MaxNodeCount] + eligibleNodes = s.randomSelectNodes(eligibleNodes, task.MaxNodeCount) } s.logger.Info("eligible instances", zap.Int("num", len(eligibleNodes))) @@ -527,6 +527,17 @@ func (s *BaseScheduler) switch2NextStep() error { return nil } +func (s *BaseScheduler) randomSelectNodes(nodes []string, maxNodeCount int) []string { + if maxNodeCount <= 0 || len(nodes) <= maxNodeCount { + return nodes + } + selectedNodes := append([]string(nil), nodes...) + s.rand.Shuffle(len(selectedNodes), func(i, j int) { + selectedNodes[i], selectedNodes[j] = selectedNodes[j], selectedNodes[i] + }) + return selectedNodes[:maxNodeCount] +} + func (s *BaseScheduler) scheduleSubTask( task *proto.Task, subtaskStep proto.Step, diff --git a/pkg/dxf/framework/scheduler/scheduler_manager.go b/pkg/dxf/framework/scheduler/scheduler_manager.go index c4d5fbefe91cd..b6660f80fab04 100644 --- a/pkg/dxf/framework/scheduler/scheduler_manager.go +++ b/pkg/dxf/framework/scheduler/scheduler_manager.go @@ -51,6 +51,12 @@ var ( defaultCollectMetricsInterval = 15 * time.Second ) +const maxCleanupTaskBatchSize = 100 + +type batchCleanUpRoutine interface { + CleanUpBatch(ctx context.Context, tasks []*proto.Task) error +} + func (sm *Manager) getSchedulerCount() int { sm.mu.RLock() defer sm.mu.RUnlock() @@ -155,7 +161,7 @@ func NewManager(ctx context.Context, store kv.Storage, taskMgr TaskManager, serv serverID: serverID, }), logger: logger, - finishCh: make(chan struct{}, proto.MaxConcurrentTask), + finishCh: make(chan struct{}, proto.MaxMaxConcurrentTask), nodeRes: nodeRes, } schedulerManager.mu.schedulerMap = make(map[int64]Scheduler) @@ -244,7 +250,8 @@ func (sm *Manager) getSchedulableTasks(ctx context.Context) ([]*proto.TaskBase, defer r.End() getTasksFn := sm.taskMgr.GetTopUnfinishedTasks taskCnt := sm.getSchedulerCount() - if taskCnt >= proto.MaxConcurrentTask { + maxConcurrentTask := proto.GetMaxConcurrentTask() + if taskCnt >= maxConcurrentTask { // when we have reached the limit of concurrent tasks, we only handle // tasks in states that don't need resources, e.g. reverting/cancelling/ // pausing/modifying. @@ -291,7 +298,7 @@ func (sm *Manager) startSchedulers(schedulableTasks []*proto.TaskBase) error { switch task.State { case proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateResuming: taskCnt := sm.getSchedulerCount() - if taskCnt >= proto.MaxConcurrentTask { + if taskCnt >= proto.GetMaxConcurrentTask() { continue } reservedExecID, ok = sm.slotMgr.canReserve(task) @@ -412,18 +419,24 @@ func (sm *Manager) cleanupTaskLoop() { sm.logger.Info("cleanup loop exits") return case <-sm.finishCh: - sm.doCleanupTask() + sm.doCleanupTasks() case <-ticker.C: - sm.doCleanupTask() + sm.doCleanupTasks() } } } -// doCleanupTask processes clean up routine defined by each type of tasks and cleanupMeta. +// doCleanupTasks keeps cleaning limited batches until there is no immediately cleanable task. +func (sm *Manager) doCleanupTasks() { + for sm.doCleanupTask() { + } +} + +// doCleanupTask processes one batch of cleanup routine defined by each type of tasks and cleanupMeta. // For example: // // tasks with global sort should clean up tmp files stored on S3. -func (sm *Manager) doCleanupTask() { +func (sm *Manager) doCleanupTask() bool { failpoint.InjectCall("doCleanupTask") tasks, err := sm.taskMgr.GetTasksInStates( sm.ctx, @@ -433,39 +446,53 @@ func (sm *Manager) doCleanupTask() { ) if err != nil { sm.logger.Warn("get task in states failed", zap.Error(err)) - return + return false } if len(tasks) == 0 { - return + return false } sm.logger.Info("cleanup routine start") err = sm.cleanupFinishedTasks(tasks) if err != nil { sm.logger.Warn("cleanup routine failed", zap.Error(err)) - return + return false } failpoint.InjectCall("WaitCleanUpFinished") sm.logger.Info("cleanup routine success") + return true } func (sm *Manager) cleanupFinishedTasks(tasks []*proto.Task) error { - cleanedTasks := make([]*proto.Task, 0) + cleanedTasks := make([]*proto.Task, 0, len(tasks)) var firstErr error + importIntoTasks := make([]*proto.Task, 0) + cleanUpImportIntoTasks := func() error { + if len(importIntoTasks) == 0 { + return nil + } + cleanedImportIntoTasks, err := sm.cleanupImportIntoTasks(importIntoTasks) + cleanedTasks = append(cleanedTasks, cleanedImportIntoTasks...) + importIntoTasks = importIntoTasks[:0] + return err + } for _, task := range tasks { sm.logger.Info("cleanup task", zap.Int64("task-id", task.ID), zap.String("task-key", task.Key)) - cleanupFactory := getSchedulerCleanUpFactory(task.Type) - if cleanupFactory != nil { - cleanup := cleanupFactory() - err := cleanup.CleanUp(sm.ctx, task) - if err != nil { - firstErr = err - break - } - cleanedTasks = append(cleanedTasks, task) - } else { - // if task doesn't register cleanup function, mark it as cleaned. - cleanedTasks = append(cleanedTasks, task) + if task.Type == proto.ImportInto { + importIntoTasks = append(importIntoTasks, task) + continue + } + if err := cleanUpImportIntoTasks(); err != nil { + firstErr = err + break } + if err := sm.cleanupSingleTask(task); err != nil { + firstErr = err + break + } + cleanedTasks = append(cleanedTasks, task) + } + if firstErr == nil { + firstErr = cleanUpImportIntoTasks() } if firstErr != nil { // normally ScheduleEventCounter requires a task ID, but since scheduler @@ -479,7 +506,45 @@ func (sm *Manager) cleanupFinishedTasks(tasks []*proto.Task) error { failpoint.Return(errors.New("transfer err")) }) - return sm.taskMgr.TransferTasks2History(sm.ctx, cleanedTasks) + if err := sm.taskMgr.TransferTasks2History(sm.ctx, cleanedTasks); err != nil { + return err + } + return firstErr +} + +func (sm *Manager) cleanupImportIntoTasks(tasks []*proto.Task) ([]*proto.Task, error) { + cleanupFactory := getSchedulerCleanUpFactory(proto.ImportInto) + if cleanupFactory == nil { + // if task doesn't register cleanup function, mark it as cleaned. + return tasks, nil + } + cleanup := cleanupFactory() + if batchCleanup, ok := cleanup.(batchCleanUpRoutine); ok { + if err := batchCleanup.CleanUpBatch(sm.ctx, tasks); err != nil { + return nil, err + } + return tasks, nil + } + + cleanedTasks := make([]*proto.Task, 0, len(tasks)) + for i, task := range tasks { + if i > 0 { + cleanup = cleanupFactory() + } + if err := cleanup.CleanUp(sm.ctx, task); err != nil { + return cleanedTasks, err + } + cleanedTasks = append(cleanedTasks, task) + } + return cleanedTasks, nil +} + +func (sm *Manager) cleanupSingleTask(task *proto.Task) error { + cleanupFactory := getSchedulerCleanUpFactory(task.Type) + if cleanupFactory == nil { + return nil + } + return cleanupFactory().CleanUp(sm.ctx, task) } func (sm *Manager) collectLoop() { diff --git a/pkg/dxf/framework/scheduler/scheduler_manager_nokit_test.go b/pkg/dxf/framework/scheduler/scheduler_manager_nokit_test.go index 1d5839bbb3e8b..0b81e8f38c9af 100644 --- a/pkg/dxf/framework/scheduler/scheduler_manager_nokit_test.go +++ b/pkg/dxf/framework/scheduler/scheduler_manager_nokit_test.go @@ -37,6 +37,34 @@ type storeWithKS struct { ks string } +type cleanUpCallRecorder struct { + cleanUpCalls []int64 + batchCalls [][]int64 +} + +func (r *cleanUpCallRecorder) CleanUp(_ context.Context, task *proto.Task) error { + r.cleanUpCalls = append(r.cleanUpCalls, task.ID) + return nil +} + +func (r *cleanUpCallRecorder) CleanUpBatch(_ context.Context, tasks []*proto.Task) error { + taskIDs := make([]int64, 0, len(tasks)) + for _, task := range tasks { + taskIDs = append(taskIDs, task.ID) + } + r.batchCalls = append(r.batchCalls, taskIDs) + return nil +} + +type singleCleanUpCallRecorder struct { + cleanUpCalls []int64 +} + +func (r *singleCleanUpCallRecorder) CleanUp(_ context.Context, task *proto.Task) error { + r.cleanUpCalls = append(r.cleanUpCalls, task.ID) + return nil +} + func (s *storeWithKS) GetKeyspace() string { return s.ks } @@ -128,6 +156,40 @@ func TestSchedulerCleanupTask(t *testing.T) { mgr.doCleanupTask() require.True(t, ctrl.Satisfied()) + manyTasks := make([]*proto.Task, maxCleanupTaskBatchSize+1) + for i := range manyTasks { + manyTasks[i] = &proto.Task{TaskBase: proto.TaskBase{ID: int64(i + 1)}} + } + taskMgr.EXPECT().GetTasksInStates( + mgr.ctx, + proto.TaskStateFailed, + proto.TaskStateReverted, + proto.TaskStateSucceed).Return(manyTasks[:maxCleanupTaskBatchSize], nil) + taskMgr.EXPECT().TransferTasks2History(mgr.ctx, manyTasks[:maxCleanupTaskBatchSize]).Return(nil) + mgr.doCleanupTask() + require.True(t, ctrl.Satisfied()) + + // wrapper cleans multiple limited batches in one tick. + taskMgr.EXPECT().GetTasksInStates( + mgr.ctx, + proto.TaskStateFailed, + proto.TaskStateReverted, + proto.TaskStateSucceed).Return(manyTasks[:maxCleanupTaskBatchSize], nil) + taskMgr.EXPECT().TransferTasks2History(mgr.ctx, manyTasks[:maxCleanupTaskBatchSize]).Return(nil) + taskMgr.EXPECT().GetTasksInStates( + mgr.ctx, + proto.TaskStateFailed, + proto.TaskStateReverted, + proto.TaskStateSucceed).Return(tasks, nil) + taskMgr.EXPECT().TransferTasks2History(mgr.ctx, tasks).Return(nil) + taskMgr.EXPECT().GetTasksInStates( + mgr.ctx, + proto.TaskStateFailed, + proto.TaskStateReverted, + proto.TaskStateSucceed).Return(nil, nil) + mgr.doCleanupTasks() + require.True(t, ctrl.Satisfied()) + // fail in transfer mockErr := errors.New("transfer err") taskMgr.EXPECT().GetTasksInStates( @@ -136,7 +198,7 @@ func TestSchedulerCleanupTask(t *testing.T) { proto.TaskStateReverted, proto.TaskStateSucceed).Return(tasks, nil) taskMgr.EXPECT().TransferTasks2History(mgr.ctx, tasks).Return(mockErr) - mgr.doCleanupTask() + mgr.doCleanupTasks() require.True(t, ctrl.Satisfied()) taskMgr.EXPECT().GetTasksInStates( @@ -149,6 +211,38 @@ func TestSchedulerCleanupTask(t *testing.T) { require.True(t, ctrl.Satisfied()) } +func TestSchedulerCleanupImportIntoTasksInBatch(t *testing.T) { + ClearSchedulerCleanUpFactory() + t.Cleanup(ClearSchedulerCleanUpFactory) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + taskMgr := mock.NewMockTaskManager(ctrl) + mgr := NewManager(context.Background(), nil, taskMgr, "1", proto.NodeResourceForTest) + importCleanUp := &cleanUpCallRecorder{} + exampleCleanUp := &singleCleanUpCallRecorder{} + RegisterSchedulerCleanUpFactory(proto.ImportInto, func() CleanUpRoutine { + return importCleanUp + }) + RegisterSchedulerCleanUpFactory(proto.TaskTypeExample, func() CleanUpRoutine { + return exampleCleanUp + }) + + tasks := []*proto.Task{ + {TaskBase: proto.TaskBase{ID: 1, Type: proto.ImportInto}}, + {TaskBase: proto.TaskBase{ID: 2, Type: proto.ImportInto}}, + {TaskBase: proto.TaskBase{ID: 3, Type: proto.TaskTypeExample}}, + {TaskBase: proto.TaskBase{ID: 4, Type: proto.ImportInto}}, + {TaskBase: proto.TaskBase{ID: 5, Type: proto.TaskType("NoCleanUp")}}, + } + taskMgr.EXPECT().TransferTasks2History(mgr.ctx, tasks).Return(nil) + + require.NoError(t, mgr.cleanupFinishedTasks(tasks)) + require.Equal(t, [][]int64{{1, 2}, {4}}, importCleanUp.batchCalls) + require.Empty(t, importCleanUp.cleanUpCalls) + require.Equal(t, []int64{3}, exampleCleanUp.cleanUpCalls) +} + func TestManagerSchedulerNotAllocateSlots(t *testing.T) { // the tests make sure allocatedSlots correct. require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/dxf/framework/scheduler/exitScheduler", "return()")) @@ -201,11 +295,7 @@ func TestManagerSchedulerNotAllocateSlots(t *testing.T) { } func TestFastRespondNoNeedResourceTaskWhenSchedulersReachLimit(t *testing.T) { - bak := proto.MaxConcurrentTask - t.Cleanup(func() { - proto.MaxConcurrentTask = bak - }) - proto.MaxConcurrentTask = 1 + t.Cleanup(proto.SetMaxConcurrentTaskForTest(1)) ctrl := gomock.NewController(t) defer ctrl.Finish() diff --git a/pkg/dxf/framework/scheduler/scheduler_nokit_test.go b/pkg/dxf/framework/scheduler/scheduler_nokit_test.go index 283acff3a4ef5..632569f397a83 100644 --- a/pkg/dxf/framework/scheduler/scheduler_nokit_test.go +++ b/pkg/dxf/framework/scheduler/scheduler_nokit_test.go @@ -17,6 +17,7 @@ package scheduler import ( "context" "fmt" + "math/rand" "testing" "time" @@ -148,6 +149,58 @@ func TestSchedulerOnNextStage(t *testing.T) { require.True(t, ctrl.Satisfied()) } +func TestSchedulerMaxNodeCountRandomlySelectsEligibleNodes(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + taskMgr := mock.NewMockTaskManager(ctrl) + schExt := schmock.NewMockExtension(ctrl) + task := proto.Task{ + TaskBase: proto.TaskBase{ + ID: 1, + State: proto.TaskStatePending, + Step: proto.StepInit, + MaxNodeCount: 1, + RequiredSlots: 1, + }, + } + sch := createScheduler(&task, true, taskMgr, ctrl) + sch.Extension = schExt + + eligibleNodes := []string{":4000", ":4001", ":4002"} + expectedNodes := append([]string(nil), eligibleNodes...) + seed := int64(1) + for { + r := rand.New(rand.NewSource(seed)) + copy(expectedNodes, eligibleNodes) + r.Shuffle(len(expectedNodes), func(i, j int) { + expectedNodes[i], expectedNodes[j] = expectedNodes[j], expectedNodes[i] + }) + if expectedNodes[0] != eligibleNodes[0] { + break + } + seed++ + } + sch.rand = rand.New(rand.NewSource(seed)) + + schExt.EXPECT().GetNextStep(gomock.Any()).Return(proto.StepOne) + schExt.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(eligibleNodes, nil) + schExt.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func( + _ context.Context, + _ storage.TaskHandle, + _ *proto.Task, + nodes []string, + _ proto.Step, + ) ([][]byte, error) { + require.Equal(t, expectedNodes[:task.MaxNodeCount], nodes) + return [][]byte{[]byte(`{"xx": "1"}`)}, nil + }) + taskMgr.EXPECT().GetUsedSlotsOnNodes(gomock.Any()).Return(nil, nil) + taskMgr.EXPECT().SwitchTaskStep(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + require.NoError(t, sch.Switch2NextStep()) + require.True(t, ctrl.Satisfied()) +} + func TestGetEligibleNodes(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() diff --git a/pkg/dxf/framework/scheduler/scheduler_test.go b/pkg/dxf/framework/scheduler/scheduler_test.go index afafd48736f1f..7005e6143e764 100644 --- a/pkg/dxf/framework/scheduler/scheduler_test.go +++ b/pkg/dxf/framework/scheduler/scheduler_test.go @@ -166,10 +166,9 @@ func checkSchedule(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, testfailpoint.Enable(t, "github.com/pingcap/tidb/pkg/domain/MockDisableDistTask", "return(true)") // test scheduleTaskLoop // test parallelism control - var originalConcurrency int + restoreMaxConcurrentTask := func() {} if taskCnt == 1 { - originalConcurrency = proto.MaxConcurrentTask - proto.MaxConcurrentTask = 1 + restoreMaxConcurrentTask = proto.SetMaxConcurrentTaskForTest(1) } store := testkit.CreateMockStore(t) @@ -190,10 +189,7 @@ func checkSchedule(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, sch.Start() defer func() { sch.Stop() - // make data race happy - if taskCnt == 1 { - proto.MaxConcurrentTask = originalConcurrency - } + restoreMaxConcurrentTask() }() // 3s diff --git a/pkg/dxf/framework/storage/BUILD.bazel b/pkg/dxf/framework/storage/BUILD.bazel index 320290856fbb8..041b7fd70d25e 100644 --- a/pkg/dxf/framework/storage/BUILD.bazel +++ b/pkg/dxf/framework/storage/BUILD.bazel @@ -48,7 +48,7 @@ go_test( ], embed = [":storage"], flaky = True, - shard_count = 28, + shard_count = 29, deps = [ "//pkg/config", "//pkg/dxf/framework/proto", @@ -56,13 +56,16 @@ go_test( "//pkg/dxf/framework/taskexecutor/execute", "//pkg/dxf/framework/testutil", "//pkg/kv", + "//pkg/session/sessionapi", "//pkg/sessionctx", "//pkg/sessionctx/vardef", + "//pkg/store/mockstore", "//pkg/testkit", "//pkg/testkit/testfailpoint", "//pkg/testkit/testsetup", "//pkg/util", "//pkg/util/sqlexec", + "@com_github_ngaut_pools//:pools", "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//require", "@com_github_tikv_client_go_v2//util", diff --git a/pkg/dxf/framework/storage/history.go b/pkg/dxf/framework/storage/history.go index 3d44c573b87d7..9833d609734fb 100644 --- a/pkg/dxf/framework/storage/history.go +++ b/pkg/dxf/framework/storage/history.go @@ -56,42 +56,56 @@ func (mgr *TaskManager) TransferTasks2History(ctx context.Context, tasks []*prot return nil } taskIDStrs := make([]string, 0, len(tasks)) + updateMetaArgs := make([]any, 0, len(tasks)*2) + var updateMetaSQL strings.Builder + updateMetaSQL.WriteString(` + update mysql.tidb_global_task + set meta = case id`) for _, task := range tasks { taskIDStrs = append(taskIDStrs, fmt.Sprintf("%d", task.ID)) + updateMetaSQL.WriteString(" when %? then %?") + updateMetaArgs = append(updateMetaArgs, task.ID, task.Meta) } + taskIDList := strings.Join(taskIDStrs, `, `) + updateMetaSQL.WriteString(` + end, state_update_time = CURRENT_TIMESTAMP() + where id in(` + taskIDList + `)`) if err := injectfailpoint.DXFRandomErrorWithOnePercent(); err != nil { return err } return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { // sensitive data in meta might be redacted, need update first. exec := se.GetSQLExecutor() - for _, t := range tasks { - _, err := sqlexec.ExecSQL(ctx, exec, ` - update mysql.tidb_global_task - set meta= %?, state_update_time = CURRENT_TIMESTAMP() - where id = %?`, t.Meta, t.ID) - if err != nil { - return err - } + _, err := sqlexec.ExecSQL(ctx, exec, updateMetaSQL.String(), updateMetaArgs...) + if err != nil { + return err } - _, err := sqlexec.ExecSQL(ctx, exec, ` + _, err = sqlexec.ExecSQL(ctx, exec, ` insert into mysql.tidb_global_task_history select * from mysql.tidb_global_task - where id in(`+strings.Join(taskIDStrs, `, `)+`)`) + where id in(`+taskIDList+`)`) if err != nil { return err } _, err = sqlexec.ExecSQL(ctx, exec, ` delete from mysql.tidb_global_task - where id in(`+strings.Join(taskIDStrs, `, `)+`)`) + where id in(`+taskIDList+`)`) + if err != nil { + return err + } - for _, t := range tasks { - err = mgr.TransferSubtasks2HistoryWithSession(ctx, se, t.ID) - if err != nil { - return err - } + _, err = sqlexec.ExecSQL(ctx, exec, ` + insert into mysql.tidb_background_subtask_history + select * from mysql.tidb_background_subtask + where task_key in(`+taskIDList+`)`) + if err != nil { + return err } + + _, err = sqlexec.ExecSQL(ctx, exec, ` + delete from mysql.tidb_background_subtask + where task_key in(`+taskIDList+`)`) return err }) } diff --git a/pkg/dxf/framework/storage/table_test.go b/pkg/dxf/framework/storage/table_test.go index 9039f912a63bf..a388d43d55cd1 100644 --- a/pkg/dxf/framework/storage/table_test.go +++ b/pkg/dxf/framework/storage/table_test.go @@ -21,9 +21,12 @@ import ( "fmt" "slices" "sort" + "strings" + "sync" "testing" "time" + "github.com/ngaut/pools" "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/dxf/framework/proto" "github.com/pingcap/tidb/pkg/dxf/framework/schstatus" @@ -31,8 +34,10 @@ import ( "github.com/pingcap/tidb/pkg/dxf/framework/taskexecutor/execute" "github.com/pingcap/tidb/pkg/dxf/framework/testutil" "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/session/sessionapi" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/vardef" + "github.com/pingcap/tidb/pkg/store/mockstore" "github.com/pingcap/tidb/pkg/testkit" "github.com/pingcap/tidb/pkg/testkit/testfailpoint" "github.com/pingcap/tidb/pkg/util/sqlexec" @@ -41,6 +46,70 @@ import ( "go.uber.org/atomic" ) +type recordingSQLExecutor struct { + sqlexec.SQLExecutor + recorder *sqlRecorder +} + +func (e *recordingSQLExecutor) ExecuteInternal(ctx context.Context, sql string, args ...any) (sqlexec.RecordSet, error) { + e.recorder.record(sql) + return e.SQLExecutor.ExecuteInternal(ctx, sql, args...) +} + +type recordingSession struct { + sessionapi.Session + exec *recordingSQLExecutor +} + +func (s *recordingSession) GetSQLExecutor() sqlexec.SQLExecutor { + return s.exec +} + +type sqlRecorder struct { + mu sync.Mutex + sqls []string +} + +func (r *sqlRecorder) record(sql string) { + r.mu.Lock() + defer r.mu.Unlock() + r.sqls = append(r.sqls, normalizeSQL(sql)) +} + +func (r *sqlRecorder) reset() { + r.mu.Lock() + defer r.mu.Unlock() + r.sqls = nil +} + +func (r *sqlRecorder) countContains(substr string) int { + r.mu.Lock() + defer r.mu.Unlock() + cnt := 0 + for _, sql := range r.sqls { + if strings.Contains(sql, substr) { + cnt++ + } + } + return cnt +} + +func (r *sqlRecorder) requireContains(t *testing.T, substr string) { + t.Helper() + r.mu.Lock() + defer r.mu.Unlock() + for _, sql := range r.sqls { + if strings.Contains(sql, substr) { + return + } + } + require.Failf(t, "expected recorded SQL", "missing %q in %v", substr, r.sqls) +} + +func normalizeSQL(sql string) string { + return strings.ToLower(strings.Join(strings.Fields(sql), " ")) +} + func checkTaskStateStep(t *testing.T, task *proto.Task, state proto.TaskState, step proto.Step) { require.Equal(t, state, task.State) require.Equal(t, step, task.Step) @@ -436,11 +505,7 @@ func TestSwitchTaskStepInBatch(t *testing.T) { func TestGetTopUnfinishedTasks(t *testing.T) { _, gm, ctx := testutil.InitTableTest(t) - bak := proto.MaxConcurrentTask - t.Cleanup(func() { - proto.MaxConcurrentTask = bak - }) - proto.MaxConcurrentTask = 4 + t.Cleanup(proto.SetMaxConcurrentTaskForTest(4)) require.NoError(t, gm.InitMeta(ctx, ":4000", "")) taskStates := []proto.TaskState{ proto.TaskStateSucceed, @@ -505,18 +570,18 @@ func TestGetTopUnfinishedTasks(t *testing.T) { require.Len(t, tasks, 8) require.Equal(t, []string{"key/6", "key/5", "key/1", "key/2", "key/3", "key/4", "key/8", "key/9"}, getTaskKeys(tasks)) - proto.MaxConcurrentTask = 6 + proto.SetMaxConcurrentTaskForTest(6) tasks, err = gm.GetTopUnfinishedTasks(ctx) require.NoError(t, err) require.Len(t, tasks, 11) require.Equal(t, []string{"key/6", "key/5", "key/1", "key/2", "key/3", "key/4", "key/8", "key/9", "key/10", "key/11", "key/12"}, getTaskKeys(tasks)) - proto.MaxConcurrentTask = 3 + proto.SetMaxConcurrentTaskForTest(3) tasks, err = gm.GetTopNoNeedResourceTasks(ctx) require.NoError(t, err) require.Equal(t, []string{"key/5", "key/3", "key/4", "key/12"}, getTaskKeys(tasks)) - proto.MaxConcurrentTask = 1 + proto.SetMaxConcurrentTaskForTest(1) tasks, err = gm.GetTopNoNeedResourceTasks(ctx) require.NoError(t, err) require.Equal(t, []string{"key/5", "key/3"}, getTaskKeys(tasks)) @@ -990,6 +1055,19 @@ func TestSubtaskHistoryTable(t *testing.T) { } func TestTaskHistoryTable(t *testing.T) { + t.Run("get tasks in states limit follows max concurrent task", func(t *testing.T) { + t.Cleanup(proto.SetMaxConcurrentTaskForTest(1)) + _, gm, ctx := testutil.InitTableTest(t) + require.NoError(t, gm.InitMeta(ctx, ":4000", "")) + for i := range 5 { + _, err := gm.CreateTask(ctx, fmt.Sprintf("limit-%d", i), proto.TaskTypeExample, "", 1, "", 0, proto.ExtraParams{}, nil) + require.NoError(t, err) + } + tasks, err := gm.GetTasksInStates(ctx, proto.TaskStatePending) + require.NoError(t, err) + require.Len(t, tasks, proto.GetMaxConcurrentTask()*3) + }) + _, gm, ctx := testutil.InitTableTest(t) require.NoError(t, gm.InitMeta(ctx, ":4000", "")) @@ -1132,6 +1210,92 @@ func TestTaskHistoryTable(t *testing.T) { _, err2 = gm.ListHistoryTasks(ctx, 201, 0, "") require.ErrorContains(t, err2, "page size should be within") }) + + t.Run("get tasks in states returns at most one batch", func(t *testing.T) { + taskQueryLimit := proto.GetMaxConcurrentTask() * 3 + for _, sql := range []string{ + "delete from mysql.tidb_background_subtask", + "delete from mysql.tidb_background_subtask_history", + "delete from mysql.tidb_global_task", + "delete from mysql.tidb_global_task_history", + } { + _, err = gm.ExecuteSQLWithNewSession(ctx, sql) + require.NoError(t, err) + } + + createdIDs := make([]int64, 0, taskQueryLimit+1) + for i := 0; i < taskQueryLimit+1; i++ { + taskID, err2 := gm.CreateTask(ctx, fmt.Sprintf("batch-task-%03d", i), proto.TaskTypeExample, "", 1, "", 0, proto.ExtraParams{}, nil) + require.NoError(t, err2) + createdIDs = append(createdIDs, taskID) + } + tasks, err = gm.GetTasksInStates(ctx, proto.TaskStatePending) + require.NoError(t, err) + require.Len(t, tasks, taskQueryLimit) + for i, task := range tasks { + require.Equal(t, createdIDs[i], task.ID) + } + }) +} + +func TestTransferTasks2HistoryUsesBatchStatements(t *testing.T) { + testfailpoint.Enable(t, "github.com/pingcap/tidb/pkg/domain/MockDisableDistTask", "return(true)") + testfailpoint.Enable(t, "github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(8)") + + store := testkit.CreateMockStore(t, mockstore.WithStoreType(mockstore.EmbedUnistore)) + recorder := &sqlRecorder{} + pool := pools.NewResourcePool(func() (pools.Resource, error) { + tk := testkit.NewTestKit(t, store) + se := tk.Session() + return &recordingSession{ + Session: se, + exec: &recordingSQLExecutor{ + SQLExecutor: se.GetSQLExecutor(), + recorder: recorder, + }, + }, nil + }, 10, 10, time.Second) + t.Cleanup(func() { + pool.Close() + }) + + gm := storage.NewTaskManager(pool) + storage.SetTaskManager(gm) + ctx := util.WithInternalSourceType(context.Background(), "table_test") + require.NoError(t, gm.InitMeta(ctx, ":4000", "")) + + tasksToTransfer := make([]*proto.Task, 0, 3) + for i := range 3 { + taskID, err := gm.CreateTask(ctx, fmt.Sprintf("batch-history-%d", i), proto.TaskTypeExample, "", 1, "", 0, proto.ExtraParams{}, []byte("original")) + require.NoError(t, err) + testutil.InsertSubtask(t, gm, taskID, proto.StepOne, "tidb1", proto.EmptyMeta, proto.SubtaskStateRunning, proto.TaskTypeExample, 1) + + task, err := gm.GetTaskByID(ctx, taskID) + require.NoError(t, err) + task.Meta = []byte(fmt.Sprintf("redacted-%d", i)) + tasksToTransfer = append(tasksToTransfer, task) + } + + recorder.reset() + require.NoError(t, gm.TransferTasks2History(ctx, tasksToTransfer)) + + require.Equal(t, 1, recorder.countContains("update mysql.tidb_global_task")) + recorder.requireContains(t, "set meta = case id") + require.Equal(t, 1, recorder.countContains("insert into mysql.tidb_background_subtask_history")) + require.Equal(t, 1, recorder.countContains("delete from mysql.tidb_background_subtask")) + + taskCnt, err := testutil.GetTasksFromHistory(ctx, gm) + require.NoError(t, err) + require.Equal(t, len(tasksToTransfer), taskCnt) + for i, task := range tasksToTransfer { + subtaskCnt, err := testutil.GetSubtasksFromHistoryByTaskID(ctx, gm, task.ID) + require.NoError(t, err) + require.Equal(t, 1, subtaskCnt) + + historyTask, err := gm.GetTaskByIDWithHistory(ctx, task.ID) + require.NoError(t, err) + require.Equal(t, []byte(fmt.Sprintf("redacted-%d", i)), historyTask.Meta) + } } func TestPauseAndResume(t *testing.T) { diff --git a/pkg/dxf/framework/storage/task_table.go b/pkg/dxf/framework/storage/task_table.go index a385189a37026..15ae6acefd696 100644 --- a/pkg/dxf/framework/storage/task_table.go +++ b/pkg/dxf/framework/storage/task_table.go @@ -360,7 +360,7 @@ func (mgr *TaskManager) getTopTasks(ctx context.Context, states ...proto.TaskSta for _, s := range states { args = append(args, s) } - args = append(args, proto.MaxConcurrentTask*2) + args = append(args, proto.GetMaxConcurrentTask()*2) rs, err := mgr.ExecuteSQLWithNewSession(ctx, sql, args...) if err != nil { return nil, err @@ -440,10 +440,13 @@ func (mgr *TaskManager) GetTasksInStates(ctx context.Context, states ...any) (ta if err := injectfailpoint.DXFRandomErrorWithOnePercent(); err != nil { return nil, err } + args := make([]any, 0, len(states)+1) + args = append(args, states...) + args = append(args, proto.GetMaxConcurrentTask()*3) rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task t "+ "where state in ("+strings.Repeat("%?,", len(states)-1)+"%?)"+ - " order by priority asc, create_time asc, id asc", states...) + " order by priority asc, create_time asc, id asc limit %?", args...) if err != nil { return task, err } diff --git a/pkg/dxf/importinto/BUILD.bazel b/pkg/dxf/importinto/BUILD.bazel index 43a8f97a2b8af..bed22e0fe9b03 100644 --- a/pkg/dxf/importinto/BUILD.bazel +++ b/pkg/dxf/importinto/BUILD.bazel @@ -94,6 +94,7 @@ go_test( name = "importinto_test", timeout = "short", srcs = [ + "clean_up_test.go", "collect_conflicts_test.go", "conflict_resolution_test.go", "encode_and_sort_operator_test.go", @@ -108,7 +109,7 @@ go_test( ], embed = [":importinto"], flaky = True, - shard_count = 27, + shard_count = 29, deps = [ "//pkg/config", "//pkg/config/kerneltype", @@ -160,6 +161,7 @@ go_test( "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", "@com_github_pingcap_kvproto//pkg/keyspacepb", + "@com_github_pingcap_log//:log", "@com_github_prometheus_client_golang//prometheus", "@com_github_stretchr_testify//require", "@com_github_stretchr_testify//suite", diff --git a/pkg/dxf/importinto/clean_up.go b/pkg/dxf/importinto/clean_up.go index 11723867590a0..4eb4e3c96f6a8 100644 --- a/pkg/dxf/importinto/clean_up.go +++ b/pkg/dxf/importinto/clean_up.go @@ -36,12 +36,15 @@ import ( "github.com/pingcap/tidb/pkg/lightning/verification" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/logutil" "go.uber.org/zap" ) var _ scheduler.CleanUpRoutine = (*ImportCleanUp)(nil) +const cleanUpMeteringConcurrency = 4 + // ImportCleanUp implements scheduler.CleanUpRoutine. type ImportCleanUp struct { } @@ -51,68 +54,148 @@ func newImportCleanUpS3() scheduler.CleanUpRoutine { } // CleanUp implements the CleanUpRoutine.CleanUp interface. -func (*ImportCleanUp) CleanUp(ctx context.Context, task *proto.Task) error { +func (c *ImportCleanUp) CleanUp(ctx context.Context, task *proto.Task) error { + return c.CleanUpBatch(ctx, []*proto.Task{task}) +} + +type cleanUpTaskInfo struct { + task *proto.Task + needFileCleanUp bool +} + +type cleanUpFileGroup struct { + cloudStorageURI string + nonPartitionedDirs []string + taskIDs []int64 +} + +type sendMeterOnCleanUpFunc func(context.Context, *proto.Task, *zap.Logger) error + +// CleanUpBatch cleans up multiple import tasks in batch. +func (*ImportCleanUp) CleanUpBatch(ctx context.Context, tasks []*proto.Task) error { + if len(tasks) == 0 { + return nil + } + // we can only clean up files after all write&ingest subtasks are finished, // since they might share the same file. - taskMeta := &TaskMeta{} - err := json.Unmarshal(task.Meta, taskMeta) - if err != nil { - return err - } - defer redactSensitiveInfo(task, taskMeta) + cleanUpTasks := make([]cleanUpTaskInfo, 0, len(tasks)) + fileGroupIdx := make(map[string]int) + fileGroups := make([]cleanUpFileGroup, 0, len(tasks)) + for _, task := range tasks { + taskMeta := &TaskMeta{} + err := json.Unmarshal(task.Meta, taskMeta) + if err != nil { + return err + } + defer redactSensitiveInfo(task, taskMeta) + + if err = cleanUpTableMode(ctx, taskMeta); err != nil { + return err + } - if kerneltype.IsClassic() { - taskManager, err := storage.GetTaskManager() + failpoint.InjectCall("mockCleanupError", &err) if err != nil { return err } - if err = taskManager.WithNewTxn(ctx, func(se sessionctx.Context) error { - return ddl.AlterTableMode(domain.GetDomain(se).DDLExecutor(), se, model.TableModeNormal, taskMeta.Plan.DBID, taskMeta.Plan.TableInfo.ID) - }); err != nil { - // If the table is not found, it means the table has been either - // dropped or truncated. In such cases, the table mode has already - // been reset to normal, so we can ignore this error. - if !goerrors.Is(err, infoschema.ErrTableNotExists) { - return err - } - logutil.BgLogger().Warn( - "table not found during import cleanup, skip altering table mode", - zap.Int64("tableID", taskMeta.Plan.TableInfo.ID), - ) + // Not use cloud storage, no need to cleanUp. + needFileCleanUp := taskMeta.Plan.CloudStorageURI != "" + cleanUpTasks = append(cleanUpTasks, cleanUpTaskInfo{ + task: task, + needFileCleanUp: needFileCleanUp, + }) + if !needFileCleanUp { + continue + } + idx, ok := fileGroupIdx[taskMeta.Plan.CloudStorageURI] + if !ok { + idx = len(fileGroups) + fileGroupIdx[taskMeta.Plan.CloudStorageURI] = idx + fileGroups = append(fileGroups, cleanUpFileGroup{ + cloudStorageURI: taskMeta.Plan.CloudStorageURI, + }) } + fileGroups[idx].nonPartitionedDirs = append(fileGroups[idx].nonPartitionedDirs, strconv.Itoa(int(task.ID))) + fileGroups[idx].taskIDs = append(fileGroups[idx].taskIDs, task.ID) } - failpoint.InjectCall("mockCleanupError", &err) + for _, fileGroup := range fileGroups { + if err := cleanUpExternalFiles(ctx, fileGroup); err != nil { + return err + } + } + + if kerneltype.IsNextGen() { + // send metering data for nextgen kernel, only for succeed tasks + if err := sendMeterOnCleanUpInParallel(ctx, cleanUpTasks, sendMeterOnCleanUp); err != nil { + return err + } + } + return nil +} + +func sendMeterOnCleanUpInParallel(ctx context.Context, cleanUpTasks []cleanUpTaskInfo, sendFn sendMeterOnCleanUpFunc) error { + eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(ctx) + eg.SetLimit(cleanUpMeteringConcurrency) + for _, cleanUpTask := range cleanUpTasks { + cleanUpTask := cleanUpTask + if !cleanUpTask.needFileCleanUp || cleanUpTask.task.State != proto.TaskStateSucceed { + continue + } + eg.Go(func() error { + logger := logutil.BgLogger().With(zap.Int64("task-id", cleanUpTask.task.ID)) + if err := sendFn(egCtx, cleanUpTask.task, logger); err != nil { + logger.Warn("failed to send metering data on cleanup", zap.Error(err)) + return err + } + return nil + }) + } + return eg.Wait() +} + +func cleanUpTableMode(ctx context.Context, taskMeta *TaskMeta) error { + if !kerneltype.IsClassic() { + return nil + } + taskManager, err := storage.GetTaskManager() if err != nil { return err } + if err = taskManager.WithNewTxn(ctx, func(se sessionctx.Context) error { + return ddl.AlterTableMode(domain.GetDomain(se).DDLExecutor(), se, model.TableModeNormal, taskMeta.Plan.DBID, taskMeta.Plan.TableInfo.ID) + }); err != nil { + // If the table is not found, it means the table has been either + // dropped or truncated. In such cases, the table mode has already + // been reset to normal, so we can ignore this error. + if !goerrors.Is(err, infoschema.ErrTableNotExists) { + return err + } - // Not use cloud storage, no need to cleanUp. - if taskMeta.Plan.CloudStorageURI == "" { - return nil + logutil.BgLogger().Warn( + "table not found during import cleanup, skip altering table mode", + zap.Int64("tableID", taskMeta.Plan.TableInfo.ID), + ) } - logger := logutil.BgLogger().With(zap.Int64("task-id", task.ID)) + return nil +} + +func cleanUpExternalFiles(ctx context.Context, fileGroup cleanUpFileGroup) error { + logger := logutil.BgLogger().With(zap.Int64s("task-ids", fileGroup.taskIDs)) callLog := log.BeginTask(logger, "cleanup global sorted data") defer callLog.End(zap.InfoLevel, nil) - store, err := importer.GetSortStore(ctx, taskMeta.Plan.CloudStorageURI) + store, err := importer.GetSortStore(ctx, fileGroup.cloudStorageURI) if err != nil { logger.Warn("failed to create store", zap.Error(err)) return err } defer store.Close() - if err = external.CleanUpFiles(ctx, store, strconv.Itoa(int(task.ID))); err != nil { - logger.Warn("failed to clean up files of task", zap.Error(err)) + if err = external.CleanUpFiles(ctx, store, fileGroup.nonPartitionedDirs...); err != nil { + logger.Warn("failed to clean up files of tasks", zap.Error(err)) return err } - // send metering data for nextgen kernel, only for succeed tasks - if kerneltype.IsNextGen() && task.State == proto.TaskStateSucceed { - if err = sendMeterOnCleanUp(ctx, task, logger); err != nil { - logger.Warn("failed to send metering data on cleanup", zap.Error(err)) - return err - } - } return nil } diff --git a/pkg/dxf/importinto/clean_up_test.go b/pkg/dxf/importinto/clean_up_test.go new file mode 100644 index 0000000000000..0a843b09149d7 --- /dev/null +++ b/pkg/dxf/importinto/clean_up_test.go @@ -0,0 +1,124 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importinto + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/dxf/framework/proto" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestSendMeterOnCleanUpInParallelLimitsConcurrency(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + cleanUpTasks := make([]cleanUpTaskInfo, cleanUpMeteringConcurrency*2) + for i := range cleanUpTasks { + cleanUpTasks[i] = cleanUpTaskInfo{ + task: &proto.Task{ + TaskBase: proto.TaskBase{ + ID: int64(i + 1), + State: proto.TaskStateSucceed, + }, + }, + needFileCleanUp: true, + } + } + + firstBatchStarted := make(chan struct{}) + release := make(chan struct{}) + overflow := make(chan struct{}) + done := make(chan error, 1) + var active, maxActive, started, overflowed int32 + sendFn := func(ctx context.Context, task *proto.Task, logger *zap.Logger) error { + current := atomic.AddInt32(&active, 1) + defer atomic.AddInt32(&active, -1) + if current > int32(cleanUpMeteringConcurrency) && atomic.CompareAndSwapInt32(&overflowed, 0, 1) { + close(overflow) + } + for { + max := atomic.LoadInt32(&maxActive) + if current <= max || atomic.CompareAndSwapInt32(&maxActive, max, current) { + break + } + } + if atomic.AddInt32(&started, 1) == int32(cleanUpMeteringConcurrency) { + close(firstBatchStarted) + } + select { + case <-release: + return nil + case <-ctx.Done(): + return ctx.Err() + } + } + + go func() { + done <- sendMeterOnCleanUpInParallel(ctx, cleanUpTasks, sendFn) + }() + + select { + case <-firstBatchStarted: + case <-ctx.Done(): + require.NoError(t, ctx.Err()) + } + select { + case <-overflow: + close(release) + require.FailNow(t, "metering cleanup exceeded concurrency limit") + case err := <-done: + require.NoError(t, err) + require.FailNow(t, "metering cleanup returned before workers were released") + case <-time.After(100 * time.Millisecond): + } + + close(release) + var err error + select { + case err = <-done: + case <-ctx.Done(): + require.NoError(t, ctx.Err()) + } + require.NoError(t, err) + require.Equal(t, int32(len(cleanUpTasks)), atomic.LoadInt32(&started)) + require.Equal(t, int32(cleanUpMeteringConcurrency), atomic.LoadInt32(&maxActive)) +} + +func TestSendMeterOnCleanUpInParallelRecoversPanic(t *testing.T) { + restoreLog := log.ReplaceGlobals(zap.NewNop(), &log.ZapProperties{Level: zap.NewAtomicLevelAt(zap.FatalLevel)}) + defer restoreLog() + + cleanUpTasks := []cleanUpTaskInfo{{ + task: &proto.Task{ + TaskBase: proto.TaskBase{ + ID: 1, + State: proto.TaskStateSucceed, + }, + }, + needFileCleanUp: true, + }} + sendFn := func(context.Context, *proto.Task, *zap.Logger) error { + panic("metering panic") + } + + err := sendMeterOnCleanUpInParallel(context.Background(), cleanUpTasks, sendFn) + require.ErrorContains(t, err, "metering panic") +} diff --git a/pkg/lightning/backend/external/util.go b/pkg/lightning/backend/external/util.go index 2574add2ddb31..2fc7ef620f062 100644 --- a/pkg/lightning/backend/external/util.go +++ b/pkg/lightning/backend/external/util.go @@ -145,12 +145,12 @@ func getReadRangeFromProps( return readRangesPerKey, nil } -// GetAllFileNames returns files with the same non-partitioned dir. +// GetAllFileNames returns files with the same non-partitioned dirs. // - for intermediate KV/stat files we store them with a partitioned way to mitigate // limitation on Cloud, see randPartitionedPrefix for how we partition the files. // - for meta files, we store them directly under the non-partitioned dir. // -// for example, if nonPartitionedDir is '30001', the files returned might be +// for example, if nonPartitionedDirs contains '30001', the files returned might be // - 30001/6/meta.json // - 30001/7/meta.json // - 30001/plan/ingest/1/meta.json @@ -160,8 +160,16 @@ func getReadRangeFromProps( func GetAllFileNames( ctx context.Context, store storeapi.Storage, - nonPartitionedDir string, + nonPartitionedDirs ...string, ) ([]string, error) { + if len(nonPartitionedDirs) == 0 { + return nil, nil + } + nonPartitionedDirSet := make(map[string]struct{}, len(nonPartitionedDirs)) + for _, dir := range nonPartitionedDirs { + nonPartitionedDirSet[dir] = struct{}{} + } + var data []string err := store.WalkDir(ctx, @@ -175,7 +183,7 @@ func GetAllFileNames( } firstDir := bs[:firstIdx] - if string(firstDir) == nonPartitionedDir { + if _, ok := nonPartitionedDirSet[string(firstDir)]; ok { data = append(data, path) return nil } @@ -189,7 +197,7 @@ func GetAllFileNames( } secondDir := path[firstIdx+1 : firstIdx+1+secondIdx] - if secondDir == nonPartitionedDir { + if _, ok := nonPartitionedDirSet[secondDir]; ok { data = append(data, path) } return nil @@ -202,13 +210,16 @@ func GetAllFileNames( return data, nil } -// CleanUpFiles delete all data and stat files under the same non-partitioned dir. +// CleanUpFiles delete all data and stat files under the same non-partitioned dirs. // see randPartitionedPrefix for how we partition the files. -func CleanUpFiles(ctx context.Context, store storeapi.Storage, nonPartitionedDir string) error { +func CleanUpFiles(ctx context.Context, store storeapi.Storage, nonPartitionedDirs ...string) error { failpoint.Inject("skipCleanUpFiles", func() { failpoint.Return(nil) }) - names, err := GetAllFileNames(ctx, store, nonPartitionedDir) + if len(nonPartitionedDirs) == 0 { + return nil + } + names, err := GetAllFileNames(ctx, store, nonPartitionedDirs...) if err != nil { return err } diff --git a/pkg/lightning/backend/external/util_test.go b/pkg/lightning/backend/external/util_test.go index 693c83e97b270..7b87380aed412 100644 --- a/pkg/lightning/backend/external/util_test.go +++ b/pkg/lightning/backend/external/util_test.go @@ -38,6 +38,20 @@ type blockingOpenMemStorage struct { max atomic.Int32 } +type walkCountingStorage struct { + storeapi.Storage + count atomic.Int32 +} + +func (s *walkCountingStorage) WalkDir( + ctx context.Context, + opt *storeapi.WalkOption, + fn func(path string, size int64) error, +) error { + s.count.Add(1) + return s.Storage.WalkDir(ctx, opt, fn) +} + func (s *blockingOpenMemStorage) Open( ctx context.Context, path string, @@ -301,15 +315,13 @@ func TestGetAllFileNames(t *testing.T) { }, filenames) } -func TestCleanUpFiles(t *testing.T) { - ctx := context.Background() - store := objstore.NewMemStorage() +func writeCleanupTestFiles(ctx context.Context, t *testing.T, store storeapi.Storage, dir string) { w := NewWriterBuilder(). SetMemorySizeLimit(10*(lengthBytes*2+2)). SetBlockSize(10*(lengthBytes*2+2)). SetPropSizeDistance(5). SetPropKeysDistance(3). - Build(store, "/subtask", "0") + Build(store, dir, "0") keys := make([][]byte, 0, 30) values := make([][]byte, 0, 30) for i := range 30 { @@ -320,22 +332,38 @@ func TestCleanUpFiles(t *testing.T) { err := w.WriteRow(ctx, key, values[i], nil) require.NoError(t, err) } - err := w.Close(ctx) - require.NoError(t, err) + require.NoError(t, w.Close(ctx)) +} - filenames, err := GetAllFileNames(ctx, store, "subtask") +func TestCleanUpFiles(t *testing.T) { + ctx := context.Background() + baseStore := objstore.NewMemStorage() + store := &walkCountingStorage{Storage: baseStore} + writeCleanupTestFiles(ctx, t, store, "/subtask") + writeCleanupTestFiles(ctx, t, store, "/subtask2") + writeCleanupTestFiles(ctx, t, store, "/kept") + + filenames, err := GetAllFileNames(ctx, store, "subtask", "subtask2") require.NoError(t, err) filenames = removePartitionPrefix(t, filenames) require.Equal(t, []string{ "/subtask/0/0", "/subtask/0/1", "/subtask/0/2", "/subtask/0_stat/0", "/subtask/0_stat/1", "/subtask/0_stat/2", + "/subtask2/0/0", "/subtask2/0/1", "/subtask2/0/2", + "/subtask2/0_stat/0", "/subtask2/0_stat/1", "/subtask2/0_stat/2", }, filenames) + require.Equal(t, int32(1), store.count.Load()) - require.NoError(t, CleanUpFiles(ctx, store, "subtask")) + store.count.Store(0) + require.NoError(t, CleanUpFiles(ctx, store, "subtask", "subtask2")) + require.Equal(t, int32(1), store.count.Load()) - filenames, err = GetAllFileNames(ctx, store, "subtask") + filenames, err = GetAllFileNames(ctx, baseStore, "subtask", "subtask2") require.NoError(t, err) require.Equal(t, []string(nil), filenames) + filenames, err = GetAllFileNames(ctx, baseStore, "kept") + require.NoError(t, err) + require.Len(t, filenames, 6) } func TestGetMaxOverlapping(t *testing.T) { diff --git a/pkg/server/handler/tests/dxf_test.go b/pkg/server/handler/tests/dxf_test.go index c7816db6989f2..dff32a796366a 100644 --- a/pkg/server/handler/tests/dxf_test.go +++ b/pkg/server/handler/tests/dxf_test.go @@ -146,6 +146,50 @@ func TestDXFAPI(t *testing.T) { require.EqualValues(t, 2, out.PerKeyspace["ks1"]) }) + t.Run("max concurrent task api", func(t *testing.T) { + restore := proto.SetMaxConcurrentTaskForTest(proto.DefaultMaxConcurrentTask) + defer restore() + + runAndCheckReqFn(t, http.StatusBadRequest, "This api only support GET and POST method", func() (*http.Response, error) { + req, err := http.NewRequest(http.MethodDelete, ts.StatusURL("/dxf/task/max_concurrent"), nil) + require.NoError(t, err) + return http.DefaultClient.Do(req) + }) + for _, c := range [][2]string{ + {"/dxf/task/max_concurrent", "invalid value "}, + {"/dxf/task/max_concurrent?value=aa", "invalid value "}, + {"/dxf/task/max_concurrent?value=15", "out of range"}, + {fmt.Sprintf("/dxf/task/max_concurrent?value=%d", proto.MaxMaxConcurrentTask+1), "out of range"}, + } { + path, errMsg := c[0], c[1] + runAndCheckReqFn(t, http.StatusBadRequest, errMsg, func() (*http.Response, error) { + return ts.PostStatus(path, "", bytes.NewBuffer([]byte(""))) + }) + } + + body := runAndCheckReqFn(t, http.StatusOK, "", func() (*http.Response, error) { + return ts.FetchStatus("/dxf/task/max_concurrent") + }) + out := struct { + MaxConcurrentTask int `json:"max_concurrent_task"` + }{} + require.NoError(t, json.Unmarshal(body, &out)) + require.Equal(t, proto.DefaultMaxConcurrentTask, out.MaxConcurrentTask) + + body = runAndCheckReqFn(t, http.StatusOK, "", func() (*http.Response, error) { + return ts.PostStatus("/dxf/task/max_concurrent?value=128", "", bytes.NewBuffer([]byte(""))) + }) + require.NoError(t, json.Unmarshal(body, &out)) + require.Equal(t, 128, out.MaxConcurrentTask) + require.Equal(t, 128, proto.GetMaxConcurrentTask()) + + body = runAndCheckReqFn(t, http.StatusOK, "", func() (*http.Response, error) { + return ts.FetchStatus("/dxf/task/max_concurrent") + }) + require.NoError(t, json.Unmarshal(body, &out)) + require.Equal(t, 128, out.MaxConcurrentTask) + }) + t.Run("task history api", func(t *testing.T) { seedHistoryTasks := func(t *testing.T) []int64 { t.Helper() diff --git a/pkg/server/handler/tikvhandler/dxf.go b/pkg/server/handler/tikvhandler/dxf.go index 44fc810307810..04e0783fa534a 100644 --- a/pkg/server/handler/tikvhandler/dxf.go +++ b/pkg/server/handler/tikvhandler/dxf.go @@ -344,6 +344,40 @@ func (h *DXFScheduleTuneHandler) ServeHTTP(w http.ResponseWriter, req *http.Requ } } +// DXFTaskMaxConcurrentHandler handles the in-memory DXF task concurrency limit. +type DXFTaskMaxConcurrentHandler struct{} + +// NewDXFTaskMaxConcurrentHandler creates a new DXFTaskMaxConcurrentHandler. +func NewDXFTaskMaxConcurrentHandler() *DXFTaskMaxConcurrentHandler { + return &DXFTaskMaxConcurrentHandler{} +} + +func (*DXFTaskMaxConcurrentHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + switch req.Method { + case http.MethodGet: + handler.WriteData(w, map[string]any{ + "max_concurrent_task": proto.GetMaxConcurrentTask(), + }) + case http.MethodPost: + valueStr := req.FormValue("value") + value, err := strconv.Atoi(valueStr) + if err != nil { + handler.WriteError(w, errors.Errorf("invalid value %s, error %v", valueStr, err)) + return + } + if err := proto.SetMaxConcurrentTask(value); err != nil { + handler.WriteError(w, err) + return + } + logutil.BgLogger().Info("set DXF max concurrent task", zap.Int("maxConcurrentTask", value)) + handler.WriteData(w, map[string]any{ + "max_concurrent_task": proto.GetMaxConcurrentTask(), + }) + default: + handler.WriteError(w, errors.Errorf("This api only support GET and POST method")) + } +} + // DXFTaskMaxRuntimeSlotsHandler handles changing max runtime slots of DXF task. type DXFTaskMaxRuntimeSlotsHandler struct{} diff --git a/pkg/server/http_status.go b/pkg/server/http_status.go index d96018dbcc46e..3b7abb76268a4 100644 --- a/pkg/server/http_status.go +++ b/pkg/server/http_status.go @@ -254,6 +254,7 @@ func (s *Server) startHTTPServer() { router.Handle("/dxf/schedule/tune", tikvhandler.NewDXFScheduleTuneHandler(tikvHandlerTool.Store.(kv.Storage))).Name("DXF_Schedule_Tune") router.Handle("/dxf/task/active", tikvhandler.NewDXFActiveTaskHandler()).Name("DXF_Task_Active") router.Handle("/dxf/task/history", tikvhandler.NewDXFTaskHistoryHandler()).Name("DXF_Task_History") + router.Handle("/dxf/task/max_concurrent", tikvhandler.NewDXFTaskMaxConcurrentHandler()).Name("DXF_Task_Max_Concurrent") router.Handle("/dxf/import-into/history/job/{keyspace}/{job_id}", tikvhandler.NewDXFImportIntoHistoryJobInfoHandler()).Name("DXF_Import_Into_History_Job_Info") router.Handle("/dxf/task/{taskID}/max_runtime_slots", tikvhandler.NewDXFTaskMaxRuntimeSlotsHandler()).Name("DXF_Task_Max_Runtime_Slots") } diff --git a/tests/realtikvtest/addindextest1/disttask_test.go b/tests/realtikvtest/addindextest1/disttask_test.go index d4c2bba9ad0d9..1f9b61ff4910d 100644 --- a/tests/realtikvtest/addindextest1/disttask_test.go +++ b/tests/realtikvtest/addindextest1/disttask_test.go @@ -507,7 +507,7 @@ func TestAddIndexScheduleAway(t *testing.T) { } func TestAddIndexDistCleanUpBlock(t *testing.T) { - proto.MaxConcurrentTask = 1 + t.Cleanup(proto.SetMaxConcurrentTaskForTest(1)) testfailpoint.Enable(t, "github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", `return(1)`) ch := make(chan struct{}) testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/dxf/framework/scheduler/doCleanupTask", func() {