diff --git a/src/backend/distributed/planner/insert_select_planner.c b/src/backend/distributed/planner/insert_select_planner.c index 3191615d663..92a7869f883 100644 --- a/src/backend/distributed/planner/insert_select_planner.c +++ b/src/backend/distributed/planner/insert_select_planner.c @@ -89,6 +89,9 @@ static DeferredErrorMessage * InsertPartitionColumnMatchesSelect(Query *query, subqueryRte, Oid * selectPartitionColumnTableId); +static bool InsertPartitionColumnIsBatchPassThrough(Query *query, + RangeTblEntry *insertRte, + RangeTblEntry *subqueryRte); static DistributedPlan * CreateNonPushableInsertSelectPlan(uint64 planId, Query *parse, ParamListInfo boundParams); static DeferredErrorMessage * NonPushableInsertSelectSupported(Query *insertSelectQuery); @@ -771,14 +774,16 @@ DistributedInsertSelectSupported(Query *queryTree, RangeTblEntry *insertRte, /* first apply toplevel pushdown checks to SELECT query */ error = DeferErrorIfUnsupportedSubqueryPushdown(subquery, plannerRestrictionContext, - true); + true, AllowUnsafeInsertSelectPushdown) + ; if (error) { return error; } /* then apply subquery pushdown checks to SELECT query */ - error = DeferErrorIfCannotPushdownSubquery(subquery, false); + error = DeferErrorIfCannotPushdownSubquery(subquery, false, + AllowUnsafeInsertSelectPushdown); if (error) { return error; @@ -821,11 +826,29 @@ DistributedInsertSelectSupported(Query *queryTree, RangeTblEntry *insertRte, "table", NULL, NULL); } + /* + * Ensure that INSERT's partition column comes from SELECT's partition + * column. Normally this requires a plain-Var partition-column match. + * With unsafe INSERT ... SELECT pushdown enabled we additionally accept + * a provably shard-local batch pass-through of the distribution column, + * i.e. unnest(array_agg(dist_key)); those values can only hash back into + * the current shard, so the co-location check further below keeps the + * batch shard-local. Any other derived distribution value is still + * rejected, because it could route rows that actually belong to a different + * shard. + */ if (HasDistributionKey(targetRelationId)) { - /* ensure that INSERT's partition column comes from SELECT's partition column */ - error = InsertPartitionColumnMatchesSelect(queryTree, insertRte, subqueryRte, + error = InsertPartitionColumnMatchesSelect(queryTree, insertRte, + subqueryRte, &selectPartitionColumnTableId); + if (error && AllowUnsafeInsertSelectPushdown && + InsertPartitionColumnIsBatchPassThrough(queryTree, insertRte, + subqueryRte)) + { + error = NULL; + } + if (error) { return error; @@ -1196,6 +1219,56 @@ ReorderInsertSelectTargetLists(Query *originalQuery, RangeTblEntry *insertRte, } +/* + * InsertPartitionColumnIsBatchPassThrough returns true if the SELECT target + * entry that feeds the INSERT's partition column is a provably shard-local + * unnest(array_agg()) batch pass-through. It mirrors the + * partition-column position mapping used by InsertPartitionColumnMatchesSelect + * and then defers the actual pattern check to IsBatchUnnestArrayAggPartitionColumn. + */ +static bool +InsertPartitionColumnIsBatchPassThrough(Query *query, RangeTblEntry *insertRte, + RangeTblEntry *subqueryRte) +{ + Oid insertRelationId = insertRte->relid; + Var *insertPartitionColumn = PartitionColumn(insertRelationId, 1); + Query *subquery = subqueryRte->subquery; + + ListCell *targetEntryCell = NULL; + foreach(targetEntryCell, query->targetList) + { + TargetEntry *targetEntry = (TargetEntry *) lfirst(targetEntryCell); + List *insertTargetEntryColumnList = pull_var_clause_default((Node *) targetEntry); + + if (list_length(insertTargetEntryColumnList) != 1) + { + continue; + } + + Var *insertVar = (Var *) linitial(insertTargetEntryColumnList); + + /* skip processing of target table non-partition columns */ + if (targetEntry->resno != insertPartitionColumn->varattno) + { + continue; + } + + if (insertVar->varattno > list_length(subquery->targetList)) + { + return false; + } + + TargetEntry *subqueryTargetEntry = list_nth(subquery->targetList, + insertVar->varattno - 1); + + return IsBatchUnnestArrayAggPartitionColumn(subqueryTargetEntry->expr, + subquery); + } + + return false; +} + + /* * InsertPartitionColumnMatchesSelect returns NULL the partition column in the * table targeted by INSERTed matches with the any of the SELECTed table's diff --git a/src/backend/distributed/planner/merge_planner.c b/src/backend/distributed/planner/merge_planner.c index a05eef493d8..95a5217ac29 100644 --- a/src/backend/distributed/planner/merge_planner.c +++ b/src/backend/distributed/planner/merge_planner.c @@ -1166,7 +1166,7 @@ DeferErrorIfRoutableMergeNotSupported(Query *query, List *rangeTableList, deferredError = DeferErrorIfUnsupportedSubqueryPushdown(query, plannerRestrictionContext, - true); + true, false); if (deferredError) { ereport(DEBUG1, (errmsg("Sub-query is not pushable, try repartitioning"))); diff --git a/src/backend/distributed/planner/multi_logical_optimizer.c b/src/backend/distributed/planner/multi_logical_optimizer.c index 3a4aef85042..2f8924d46f4 100644 --- a/src/backend/distributed/planner/multi_logical_optimizer.c +++ b/src/backend/distributed/planner/multi_logical_optimizer.c @@ -4763,6 +4763,141 @@ IsPartitionColumn(Expr *columnExpression, Query *query, bool skipOuterVars) } +/* + * IsBatchUnnestArrayAggPartitionColumn returns true if the given SELECT target + * expression is a provably shard-local "batch pass-through" of the distribution + * column, i.e. it has the shape + * + * unnest(array_agg()) + * + * (optionally projected through one or more plain-Var subquery indirections, and + * with array_agg allowed to carry ORDER BY / DISTINCT / FILTER modifiers). + * + * Such an expression only ever emits distribution-column values that were read + * from rows of the current shard, so - given source and target are colocated - + * every produced value hashes back into this shard's range. That makes it safe + * to push a colocated INSERT ... SELECT down to the shards even though the + * distribution value is technically a derived expression rather than a plain + * Var. Any intermediate transformation (e.g. unnest(array_agg(dist_key + 1)) or + * unnest(f(array_agg(dist_key)))) is rejected because it could produce values + * that would hash to a different shard. + */ +bool +IsBatchUnnestArrayAggPartitionColumn(Expr *expr, Query *query) +{ + Query *leafQuery = query; + + /* + * Peel plain-Var subquery projection indirection down to the underlying + * expression. We only follow the simple "SELECT FROM (subquery)" + * projection form; anything else makes us conservatively bail out. + */ + for (;;) + { + expr = (Expr *) strip_implicit_coercions((Node *) expr); + + if (!IsA(expr, Var)) + { + break; + } + + Var *var = (Var *) expr; + if (var->varlevelsup != 0 || var->varattno <= InvalidAttrNumber) + { + return false; + } + + if (var->varno <= 0 || var->varno > list_length(leafQuery->rtable)) + { + return false; + } + + RangeTblEntry *rte = rt_fetch(var->varno, leafQuery->rtable); + if (rte->rtekind != RTE_SUBQUERY) + { + return false; + } + + Query *subquery = rte->subquery; + if (var->varattno > list_length(subquery->targetList)) + { + return false; + } + + TargetEntry *targetEntry = list_nth(subquery->targetList, var->varattno - 1); + expr = targetEntry->expr; + leafQuery = subquery; + } + + /* the leaf expression must be unnest(...) over a single array argument */ + if (!IsA(expr, FuncExpr)) + { + return false; + } + + FuncExpr *unnestExpr = (FuncExpr *) expr; + if (unnestExpr->funcid != F_UNNEST_ANYARRAY || + list_length(unnestExpr->args) != 1) + { + return false; + } + + /* the unnest argument must be array_agg(...) with no wrapping transform */ + Expr *unnestArg = (Expr *) strip_implicit_coercions( + (Node *) linitial(unnestExpr->args)); + if (!IsA(unnestArg, Aggref)) + { + return false; + } + + Aggref *arrayAgg = (Aggref *) unnestArg; + if (arrayAgg->aggfnoid != F_ARRAY_AGG_ANYNONARRAY && + arrayAgg->aggfnoid != F_ARRAY_AGG_ANYARRAY) + { + return false; + } + + /* + * Locate the single aggregated value argument. ORDER BY keys that differ + * from the aggregated value appear as additional resjunk target entries; + * they only reorder the (shard-local) values and do not affect routing. + */ + TargetEntry *aggValueTargetEntry = NULL; + TargetEntry *aggArgTargetEntry = NULL; + foreach_declared_ptr(aggArgTargetEntry, arrayAgg->args) + { + if (aggArgTargetEntry->resjunk) + { + continue; + } + + if (aggValueTargetEntry != NULL) + { + /* + * array_agg (the OIDs we matched above) is a single-argument + * aggregate, so a second non-resjunk entry cannot occur for a + * well-formed tree. Assert to catch a broken invariant in debug + * builds, but still bail out gracefully in production: a false + * result only forgoes the optimization, it is never unsafe. + */ + Assert(false); + return false; + } + + aggValueTargetEntry = aggArgTargetEntry; + } + + if (aggValueTargetEntry == NULL) + { + return false; + } + + /* the aggregated value must be the untransformed source partition column */ + bool skipOuterVars = false; + return IsPartitionColumn(aggValueTargetEntry->expr, leafQuery, skipOuterVars); +} + + /* * FindReferencedTableColumn recursively traverses query tree to find actual relation * id, and column that columnExpression refers to. If columnExpression is a diff --git a/src/backend/distributed/planner/multi_router_planner.c b/src/backend/distributed/planner/multi_router_planner.c index 313d2674ddc..bea08f039b9 100644 --- a/src/backend/distributed/planner/multi_router_planner.c +++ b/src/backend/distributed/planner/multi_router_planner.c @@ -369,6 +369,27 @@ AddPartitionKeyNotNullFilterToSelect(Query *subqery) } } + /* + * Normally the SELECT projects the distribution column as a plain Var. With + * unsafe INSERT ... SELECT pushdown the distribution column may instead be a + * provably shard-local batch pass-through, i.e. unnest(array_agg(dist_col)). + * In that case there is no plain Var to attach a NOT NULL filter to, and the + * batch stays shard-local, so the filter is unnecessary; skip it. Any other + * derived distribution value is rejected earlier during planning, so if we + * reach here without a plain-Var partition column it must be this pattern. + */ + if (targetPartitionColumnVar == NULL && AllowUnsafeInsertSelectPushdown) + { + TargetEntry *batchTargetEntry = NULL; + foreach_declared_ptr(batchTargetEntry, targetList) + { + if (IsBatchUnnestArrayAggPartitionColumn(batchTargetEntry->expr, subqery)) + { + return; + } + } + } + /* we should have found target partition column */ Assert(targetPartitionColumnVar != NULL); @@ -1342,7 +1363,7 @@ MultiShardUpdateDeleteSupported(Query *originalQuery, errorMessage = DeferErrorIfUnsupportedSubqueryPushdown( originalQuery, plannerRestrictionContext, - true); + true, false); } return errorMessage; diff --git a/src/backend/distributed/planner/query_pushdown_planning.c b/src/backend/distributed/planner/query_pushdown_planning.c index 753643b1929..152c3e45682 100644 --- a/src/backend/distributed/planner/query_pushdown_planning.c +++ b/src/backend/distributed/planner/query_pushdown_planning.c @@ -80,6 +80,7 @@ typedef struct RelidsReferenceWalkerContext /* Config variable managed via guc.c */ bool SubqueryPushdown = false; /* is subquery pushdown enabled */ int ValuesMaterializationThreshold = 100; +bool AllowUnsafeInsertSelectPushdown = false; /* Local functions forward declarations */ static bool JoinTreeContainsSubqueryWalker(Node *joinTreeNode, void *context); @@ -92,7 +93,9 @@ static DeferredErrorMessage * DeferredErrorIfUnsupportedRecurringTuplesJoin( static DeferredErrorMessage * DeferErrorIfUnsupportedTableCombination(Query *queryTree); static DeferredErrorMessage * DeferErrorIfSubqueryRequiresMerge(Query *subqueryTree, bool lateral, - char *referencedThing); + char *referencedThing, + bool + allowUnsafeShardLocalGrouping); static bool ExtractSetOperationStatementWalker(Node *node, List **setOperationList); static RecurringTuplesType FetchFirstRecurType(PlannerInfo *plannerInfo, Relids relids); @@ -550,7 +553,7 @@ SubqueryMultiNodeTree(Query *originalQuery, Query *queryTree, DeferredErrorMessage *subqueryPushdownError = DeferErrorIfUnsupportedSubqueryPushdown( originalQuery, plannerRestrictionContext, - false); + false, false); if (subqueryPushdownError != NULL) { @@ -570,12 +573,22 @@ SubqueryMultiNodeTree(Query *originalQuery, Query *queryTree, * entry list and uses helper functions to check if we can push down subquery * to worker nodes. These helper functions returns a deferred error if we * cannot push down the subquery. + * + * allowUnsafeShardLocalGroupingForSubqueries is forwarded as-is to + * DeferErrorIfCannotPushdownSubquery for every subquery checked here. It must only + * be set for colocated INSERT ... SELECT under + * citus.allow_unsafe_insert_select_pushdown; when true, the GROUP BY / aggregate / + * window / DISTINCT merge-step requirements are skipped, trusting the caller that + * those grouping constructs stay shard-local. All other pushdown checks + * (co-location, joins on the distribution column, recurring tuples, LIMIT/OFFSET) + * remain enforced. */ DeferredErrorMessage * DeferErrorIfUnsupportedSubqueryPushdown(Query *originalQuery, PlannerRestrictionContext * plannerRestrictionContext, - bool plannerPhase) + bool plannerPhase, + bool allowUnsafeShardLocalGroupingForSubqueries) { bool outerMostQueryHasLimit = false; ListCell *subqueryCell = NULL; @@ -648,7 +661,8 @@ DeferErrorIfUnsupportedSubqueryPushdown(Query *originalQuery, { Query *subquery = lfirst(subqueryCell); error = DeferErrorIfCannotPushdownSubquery(subquery, - outerMostQueryHasLimit); + outerMostQueryHasLimit, + allowUnsafeShardLocalGroupingForSubqueries); if (error) { return error; @@ -970,8 +984,8 @@ DeferredErrorIfUnsupportedRecurringTuplesJoin(PlannerRestrictionContext * bool CanPushdownSubquery(Query *subqueryTree, bool outerMostQueryHasLimit) { - return DeferErrorIfCannotPushdownSubquery(subqueryTree, outerMostQueryHasLimit) == - NULL; + return DeferErrorIfCannotPushdownSubquery(subqueryTree, outerMostQueryHasLimit, + false) == NULL; } @@ -996,9 +1010,18 @@ CanPushdownSubquery(Query *subqueryTree, bool outerMostQueryHasLimit) * a subquery has a group by on another subquery which includes order by with * limit, we let this query to run, but results could be wrong depending on the * features of underlying tables. + * + * When allowUnsafeShardLocalGrouping is true, the caller asserts that any GROUP BY / + * aggregate / window / DISTINCT in the subquery stays within a single shard, so the + * partition-column requirements that would otherwise force a coordinator merge step + * are skipped (this flag is forwarded to DeferErrorIfSubqueryRequiresMerge). + * LIMIT/OFFSET handling is unaffected. It is only set for colocated + * INSERT ... SELECT under citus.allow_unsafe_insert_select_pushdown, where keeping + * batches shard-local becomes the user's responsibility. */ DeferredErrorMessage * -DeferErrorIfCannotPushdownSubquery(Query *subqueryTree, bool outerMostQueryHasLimit) +DeferErrorIfCannotPushdownSubquery(Query *subqueryTree, bool outerMostQueryHasLimit, + bool allowUnsafeShardLocalGrouping) { bool preconditionsSatisfied = true; char *errorDetail = NULL; @@ -1027,7 +1050,8 @@ DeferErrorIfCannotPushdownSubquery(Query *subqueryTree, bool outerMostQueryHasLi if (!ContainsReferencesToOuterQuery(subqueryTree)) { deferredError = DeferErrorIfSubqueryRequiresMerge(subqueryTree, false, - "another query"); + "another query", + allowUnsafeShardLocalGrouping); if (deferredError) { return deferredError; @@ -1125,10 +1149,18 @@ FlattenGroupExprs(Query *queryTree) * DeferErrorIfSubqueryRequiresMerge returns a deferred error if the subquery * requires a merge step on the coordinator (e.g. limit, group by non-distribution * column, etc.). + * + * When allowUnsafeShardLocalGrouping is true, the partition-column requirements for + * GROUP BY / aggregate / window / DISTINCT are skipped: the caller guarantees these + * grouping constructs are shard-local, so they do not need a coordinator merge. + * LIMIT/OFFSET are still rejected, because those require a merge regardless of the + * distribution column. The flag is only ever true for colocated INSERT ... SELECT + * under citus.allow_unsafe_insert_select_pushdown. */ static DeferredErrorMessage * DeferErrorIfSubqueryRequiresMerge(Query *subqueryTree, bool lateral, - char *referencedThing) + char *referencedThing, bool + allowUnsafeShardLocalGrouping) { bool preconditionsSatisfied = true; char *errorDetail = NULL; @@ -1152,68 +1184,81 @@ DeferErrorIfSubqueryRequiresMerge(Query *subqueryTree, bool lateral, referencedThing); } - /* group clause list must include partition column */ - if (subqueryTree->groupClause) + /* + * With unsafe INSERT ... SELECT pushdown the entire subquery executes on a + * single shard (colocation is enforced separately), so grouping / aggregation + * / window / distinct on non-distribution columns stays shard-local and does + * not require a merge step. Skip the partition-column requirements in that + * case. + */ + if (!allowUnsafeShardLocalGrouping) { - List *groupClauseList = subqueryTree->groupClause; - List *targetEntryList = subqueryTree->targetList; - List *groupTargetEntryList = GroupTargetEntryList(groupClauseList, - targetEntryList); - bool groupOnPartitionColumn = - TargetListOnPartitionColumn(subqueryTree, groupTargetEntryList); - if (!groupOnPartitionColumn) + /* group clause list must include partition column */ + if (subqueryTree->groupClause) { - preconditionsSatisfied = false; - errorDetail = psprintf("Group by list without partition column is currently " - "unsupported when a %ssubquery references a column " - "from %s", lateralString, referencedThing); + List *groupClauseList = subqueryTree->groupClause; + List *targetEntryList = subqueryTree->targetList; + List *groupTargetEntryList = GroupTargetEntryList(groupClauseList, + targetEntryList); + bool groupOnPartitionColumn = + TargetListOnPartitionColumn(subqueryTree, groupTargetEntryList); + if (!groupOnPartitionColumn) + { + preconditionsSatisfied = false; + errorDetail = psprintf( + "Group by list without partition column is currently " + "unsupported when a %ssubquery references a column " + "from %s", lateralString, referencedThing); + } } - } - /* we don't support aggregates without group by */ - if (subqueryTree->hasAggs && (subqueryTree->groupClause == NULL)) - { - preconditionsSatisfied = false; - errorDetail = psprintf("Aggregates without group by are currently unsupported " - "when a %ssubquery references a column from %s", - lateralString, referencedThing); - } - - /* having clause without group by on partition column is not supported */ - if (subqueryTree->havingQual && (subqueryTree->groupClause == NULL)) - { - preconditionsSatisfied = false; - errorDetail = psprintf("Having qual without group by on partition column is " - "currently unsupported when a %ssubquery references " - "a column from %s", lateralString, referencedThing); - } + /* we don't support aggregates without group by */ + if (subqueryTree->hasAggs && (subqueryTree->groupClause == NULL)) + { + preconditionsSatisfied = false; + errorDetail = psprintf( + "Aggregates without group by are currently unsupported " + "when a %ssubquery references a column from %s", + lateralString, referencedThing); + } - /* - * We support window functions when the window function - * is partitioned on distribution column. - */ - StringInfo errorInfo = NULL; - if (subqueryTree->hasWindowFuncs && !SafeToPushdownWindowFunction(subqueryTree, - &errorInfo)) - { - errorDetail = (char *) errorInfo->data; - preconditionsSatisfied = false; - } + /* having clause without group by on partition column is not supported */ + if (subqueryTree->havingQual && (subqueryTree->groupClause == NULL)) + { + preconditionsSatisfied = false; + errorDetail = psprintf( + "Having qual without group by on partition column is " + "currently unsupported when a %ssubquery references " + "a column from %s", lateralString, referencedThing); + } - /* distinct clause list must include partition column */ - if (subqueryTree->distinctClause) - { - List *distinctClauseList = subqueryTree->distinctClause; - List *targetEntryList = subqueryTree->targetList; - List *distinctTargetEntryList = GroupTargetEntryList(distinctClauseList, - targetEntryList); - bool distinctOnPartitionColumn = - TargetListOnPartitionColumn(subqueryTree, distinctTargetEntryList); - if (!distinctOnPartitionColumn) + /* + * We support window functions when the window function + * is partitioned on distribution column. + */ + StringInfo errorInfo = NULL; + if (subqueryTree->hasWindowFuncs && !SafeToPushdownWindowFunction(subqueryTree, + &errorInfo)) { + errorDetail = (char *) errorInfo->data; preconditionsSatisfied = false; - errorDetail = "Distinct on columns without partition column is " - "currently unsupported"; + } + + /* distinct clause list must include partition column */ + if (subqueryTree->distinctClause) + { + List *distinctClauseList = subqueryTree->distinctClause; + List *targetEntryList = subqueryTree->targetList; + List *distinctTargetEntryList = GroupTargetEntryList(distinctClauseList, + targetEntryList); + bool distinctOnPartitionColumn = + TargetListOnPartitionColumn(subqueryTree, distinctTargetEntryList); + if (!distinctOnPartitionColumn) + { + preconditionsSatisfied = false; + errorDetail = "Distinct on columns without partition column is " + "currently unsupported"; + } } } @@ -1763,7 +1808,7 @@ DeferredErrorIfUnsupportedLateralSubquery(PlannerInfo *plannerInfo, /* property number 3, has a merge step */ DeferredErrorMessage *deferredError = DeferErrorIfSubqueryRequiresMerge( - rangeTableEntry->subquery, true, recurTypeDescription); + rangeTableEntry->subquery, true, recurTypeDescription, false); if (deferredError) { return deferredError; diff --git a/src/backend/distributed/shared_library_init.c b/src/backend/distributed/shared_library_init.c index 9ea35038f8e..058fe764be3 100644 --- a/src/backend/distributed/shared_library_init.c +++ b/src/backend/distributed/shared_library_init.c @@ -1057,6 +1057,27 @@ RegisterCitusConfigVariables(void) GUC_NO_SHOW_ALL | GUC_NOT_IN_SAMPLE, NULL, NULL, NULL); + DefineCustomBoolVariable( + "citus.allow_unsafe_insert_select_pushdown", + gettext_noop("Allows pushdown of otherwise-unsafe colocated " + "INSERT ... SELECT queries."), + gettext_noop("When enabled, Citus relaxes safety checks (GROUP BY / window / " + "aggregate / DISTINCT constructs on non-distribution columns) " + "for colocated INSERT ... SELECT, so that batching and any batch " + "UDF call run on the shards instead of pulling data to the " + "coordinator. The INSERT's distribution column may then be a " + "provably shard-local unnest(array_agg()); " + "any other derived distribution value is still rejected, since it " + "could route rows that actually belong to a different shard. " + "Colocation is still enforced, but the user takes responsibility " + "for keeping batches order-preserving; otherwise results may be " + "silently incorrect."), + &AllowUnsafeInsertSelectPushdown, + false, + PGC_USERSET, + GUC_STANDARD, + NULL, NULL, NULL); + DefineCustomBoolVariable( "citus.allow_unsafe_locks_from_workers", gettext_noop("Enables acquiring a distributed lock from a worker " diff --git a/src/include/distributed/multi_logical_optimizer.h b/src/include/distributed/multi_logical_optimizer.h index 940cbc12358..d62bd7b0e7f 100644 --- a/src/include/distributed/multi_logical_optimizer.h +++ b/src/include/distributed/multi_logical_optimizer.h @@ -176,6 +176,7 @@ extern List * SubqueryMultiTableList(MultiNode *multiNode); extern List * GroupTargetEntryList(List *groupClauseList, List *targetEntryList); extern bool ExtractQueryWalker(Node *node, List **queryList); extern bool IsPartitionColumn(Expr *columnExpression, Query *query, bool skipOuterVars); +extern bool IsBatchUnnestArrayAggPartitionColumn(Expr *expr, Query *query); extern void FindReferencedTableColumn(Expr *columnExpression, List *parentQueryList, Query *query, Var **column, RangeTblEntry **rteContainingReferencedColumn, diff --git a/src/include/distributed/query_pushdown_planning.h b/src/include/distributed/query_pushdown_planning.h index 0b69d36c75f..fec709fbdc5 100644 --- a/src/include/distributed/query_pushdown_planning.h +++ b/src/include/distributed/query_pushdown_planning.h @@ -22,6 +22,7 @@ /* Config variables managed via guc.c */ extern bool SubqueryPushdown; extern int ValuesMaterializationThreshold; +extern bool AllowUnsafeInsertSelectPushdown; extern bool AllowAggregateWorkerCombineOnInternalTypes; @@ -44,10 +45,14 @@ extern DeferredErrorMessage * DeferErrorIfUnsupportedSubqueryPushdown(Query * PlannerRestrictionContext * plannerRestrictionContext, - bool plannerPhase); + bool plannerPhase, + bool + allowUnsafeShardLocalGroupingForSubqueries); extern DeferredErrorMessage * DeferErrorIfCannotPushdownSubquery(Query *subqueryTree, bool - outerMostQueryHasLimit); + outerMostQueryHasLimit, + bool + allowUnsafeShardLocalGrouping); extern DeferredErrorMessage * DeferErrorIfUnsupportedUnionQuery(Query *queryTree); extern bool IsJsonTableRTE(RangeTblEntry *rte); extern bool IsOuterJoinExpr(Node *node); diff --git a/src/test/regress/expected/allow_unsafe_insert_select_pushdown.out b/src/test/regress/expected/allow_unsafe_insert_select_pushdown.out new file mode 100644 index 00000000000..d79f779669a --- /dev/null +++ b/src/test/regress/expected/allow_unsafe_insert_select_pushdown.out @@ -0,0 +1,785 @@ +-- +-- ALLOW_UNSAFE_INSERT_SELECT_PUSHDOWN +-- +-- Tests citus.allow_unsafe_insert_select_pushdown, which lets a colocated +-- INSERT .. SELECT push GROUP BY / window / DISTINCT batching and a batch UDF +-- down to the shards instead of pulling rows to the coordinator. +-- +-- The distribution column of the target must still come from the source +-- distribution column unchanged: either as a plain Var, or as the provably +-- shard-local batch pass-through unnest(array_agg(dist_col)). Any other derived +-- distribution value (an arithmetic/function transform, or a transform wrapped +-- inside the array_agg) is rejected even with the GUC enabled, because it could +-- route a row to a different shard. +-- +CREATE SCHEMA allow_unsafe_insert_select_pushdown; +SET search_path = allow_unsafe_insert_select_pushdown; +SET citus.next_shard_id TO 14000000; +SET citus.shard_count = 4; +SET citus.shard_replication_factor = 1; +CREATE TABLE dist(text_id int, text_col text); +CREATE TABLE res(text_id int, val int); +SELECT create_distributed_table('dist', 'text_id'); + create_distributed_table +--------------------------------------------------------------------- + +(1 row) + +SELECT create_distributed_table('res', 'text_id'); + create_distributed_table +--------------------------------------------------------------------- + +(1 row) + +INSERT INTO dist SELECT g, 't' || g FROM generate_series(1, 500) g; +-- a batched UDF: returns one value per input, mimicking a batched API call. +-- immutable + parallel safe, like a real batch UDF. +CREATE FUNCTION batch_transform(t text[]) RETURNS int[] +LANGUAGE sql IMMUTABLE PARALLEL SAFE AS $$ SELECT array_agg(length(x)) FROM unnest(t) x $$; +SELECT create_distributed_function('batch_transform(text[])'); +NOTICE: procedure allow_unsafe_insert_select_pushdown.batch_transform is already distributed +DETAIL: Citus distributes procedures with CREATE [PROCEDURE|FUNCTION|AGGREGATE] commands + create_distributed_function +--------------------------------------------------------------------- + +(1 row) + +-- default off: batching is done after pulling rows to the coordinator +-- (explain_filter strips the PG18-only "Window:" line so the plan is +-- comparable across supported Postgres versions) +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT b, unnest(array_agg(text_id)) id, + unnest(batch_transform(array_agg(text_col))) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: pull to coordinator + -> Custom Scan (Citus Adaptive) + -> Distributed Subplan XXX_1 + -> WindowAgg + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Seq Scan on dist_14000000 dist + Task Count: 1 + Tasks Shown: All + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Subquery Scan on s + -> ProjectSet + -> HashAggregate + Group Key: intermediate_result.b + -> Function Scan on read_intermediate_result intermediate_result +(20 rows) + +SET citus.allow_unsafe_insert_select_pushdown TO on; +-- now the batching and UDF call run on the shards +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT b, unnest(array_agg(text_id)) id, + unnest(batch_transform(array_agg(text_col))) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Insert on res_14000004 citus_table_alias + -> Subquery Scan on s + -> ProjectSet + -> HashAggregate + Group Key: ((row_number() OVER (?) - 1) / 100) + -> WindowAgg + -> Seq Scan on dist_14000000 dist +(12 rows) + +INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT b, unnest(array_agg(text_id)) id, + unnest(batch_transform(array_agg(text_col))) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s; +-- every text_id should be matched to the right value +SELECT count(*), count(*) FILTER (WHERE val = length('t' || text_id)) AS ok +FROM res JOIN dist USING (text_id); + count | ok +--------------------------------------------------------------------- + 500 | 500 +(1 row) + +-- --------------------------------------------------------------------- +-- Positive per-branch coverage. Each construct below used to force a +-- coordinator merge (or was rejected outright). With the GUC enabled the whole +-- colocated INSERT .. SELECT is pushed to the shards because the distribution +-- column is either a plain partition-column Var or the unnest(array_agg(text_id)) +-- batch pass-through. The pushed-down plan is a Custom Scan (Citus Adaptive) +-- whose task runs the INSERT on a shard, with no Distributed Subplan / +-- intermediate results. explain_filter keeps the plan comparable across +-- Postgres versions. +-- --------------------------------------------------------------------- +-- the batched benchmark shape: bucket rows into fixed-size batches with +-- row_number()/batch_size, array_agg each batch (id and text in the same +-- order), call the batch UDF once per batch, then unnest back to one row per id. +-- This is the query the GUC is meant to push down to the shards. +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT + unnest(array_agg(text_id ORDER BY text_id)) id, + unnest(batch_transform(array_agg(text_col ORDER BY text_id))) val + FROM ( + SELECT text_id, text_col, (row_number() OVER () - 1) / 100 batch FROM dist + ) q + GROUP BY batch +) s +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Insert on res_14000004 citus_table_alias + -> Subquery Scan on s + -> ProjectSet + -> GroupAggregate + Group Key: q.batch + -> Sort + Sort Key: q.batch, q.text_id + -> Subquery Scan on q + -> WindowAgg + -> Seq Scan on dist_14000000 dist +(15 rows) + +-- branch: GROUP BY on a non-distribution column, distribution column projected +-- as the unnest(array_agg(text_id)) batch pass-through +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT unnest(array_agg(text_id)), unnest(array_agg(length(text_col))) +FROM dist GROUP BY text_id % 10 +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Insert on res_14000004 citus_table_alias + -> Subquery Scan on "*SELECT*" + -> ProjectSet + -> HashAggregate + Group Key: (dist.text_id % 10) + -> Seq Scan on dist_14000000 dist +(11 rows) + +-- branch: aggregates without GROUP BY +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT unnest(array_agg(text_id)), unnest(array_agg(length(text_col))) FROM dist +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Insert on res_14000004 citus_table_alias + -> ProjectSet + -> Aggregate + -> Seq Scan on dist_14000000 dist +(9 rows) + +-- branch: HAVING without GROUP BY on the distribution column +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT unnest(array_agg(text_id)), unnest(array_agg(length(text_col))) +FROM dist HAVING count(*) > 0 +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Insert on res_14000004 citus_table_alias + -> ProjectSet + -> Aggregate + Filter: (count(*) > 0) + -> Seq Scan on dist_14000000 dist +(10 rows) + +-- branch: window function not partitioned on the distribution column, with the +-- distribution column projected as a plain partition-column Var +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT text_id, row_number() OVER (ORDER BY text_col) FROM dist +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Insert on res_14000004 citus_table_alias + -> Subquery Scan on "*SELECT*" + -> WindowAgg + -> Sort + Sort Key: dist.text_col + -> Seq Scan on dist_14000000 dist + Filter: (text_id IS NOT NULL) +(12 rows) + +-- combination: GROUP BY + DISTINCT, distribution column as the batch pass-through +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT DISTINCT unnest(array_agg(text_id)), unnest(array_agg(length(text_col))) +FROM dist GROUP BY text_id % 10 +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Insert on res_14000004 citus_table_alias + -> Subquery Scan on "*SELECT*" + -> HashAggregate + Group Key: unnest((array_agg(dist.text_id))), unnest((array_agg(length(dist.text_col)))) + -> ProjectSet + -> HashAggregate + Group Key: (dist.text_id % 10) + -> Seq Scan on dist_14000000 dist +(13 rows) + +-- combination: window + GROUP BY, relaxed constructs living in nested subqueries +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, v FROM ( + SELECT unnest(array_agg(text_id)) id, unnest(array_agg(length(text_col))) v + FROM (SELECT text_id, text_col, row_number() OVER (ORDER BY text_col) rn FROM dist) q + GROUP BY rn % 5 +) s +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Insert on res_14000004 citus_table_alias + -> Subquery Scan on s + -> ProjectSet + -> HashAggregate + Group Key: (q.rn % '5'::bigint) + -> Subquery Scan on q + -> WindowAgg + -> Sort + Sort Key: dist.text_col + -> Seq Scan on dist_14000000 dist +(15 rows) + +-- --------------------------------------------------------------------- +-- Negative coverage: pattern requirement. With the GUC enabled the batching +-- relaxations still fire, but the distribution column is a *transformed* value +-- rather than the source partition column, so the query is not pushed down and +-- falls back to a coordinator merge -- identical to the GUC-disabled plan. +-- --------------------------------------------------------------------- +-- distribution column derived by arithmetic on the partition column +SET citus.allow_unsafe_insert_select_pushdown TO off; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT text_id + 0, length(text_col) FROM dist +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: repartition + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Seq Scan on dist_14000000 dist +(8 rows) + +SET citus.allow_unsafe_insert_select_pushdown TO on; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT text_id + 0, length(text_col) FROM dist +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: repartition + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Seq Scan on dist_14000000 dist +(8 rows) + +-- distribution column derived from a non-distribution column (DISTINCT) +SET citus.allow_unsafe_insert_select_pushdown TO off; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT DISTINCT length(text_col), text_id % 7 FROM dist +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: pull to coordinator + -> HashAggregate + Group Key: remote_scan.text_id, remote_scan.val + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> HashAggregate + Group Key: length(text_col), (text_id % 7) + -> Seq Scan on dist_14000000 dist +(12 rows) + +SET citus.allow_unsafe_insert_select_pushdown TO on; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT DISTINCT length(text_col), text_id % 7 FROM dist +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: pull to coordinator + -> HashAggregate + Group Key: remote_scan.text_id, remote_scan.val + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> HashAggregate + Group Key: length(text_col), (text_id % 7) + -> Seq Scan on dist_14000000 dist +(12 rows) + +-- distribution column derived from a non-distribution column, one subquery down +SET citus.allow_unsafe_insert_select_pushdown TO off; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT k, k FROM ( + SELECT DISTINCT length(text_col) k FROM dist +) s +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: pull to coordinator + -> Custom Scan (Citus Adaptive) + -> Distributed Subplan XXX_1 + -> HashAggregate + Group Key: remote_scan.k + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> HashAggregate + Group Key: length(text_col) + -> Seq Scan on dist_14000000 dist + Task Count: 1 + Tasks Shown: All + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Function Scan on read_intermediate_result intermediate_result +(19 rows) + +SET citus.allow_unsafe_insert_select_pushdown TO on; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT k, k FROM ( + SELECT DISTINCT length(text_col) k FROM dist +) s +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: pull to coordinator + -> Custom Scan (Citus Adaptive) + -> Distributed Subplan XXX_1 + -> HashAggregate + Group Key: remote_scan.k + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> HashAggregate + Group Key: length(text_col) + -> Seq Scan on dist_14000000 dist + Task Count: 1 + Tasks Shown: All + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Function Scan on read_intermediate_result intermediate_result +(19 rows) + +-- the batched benchmark shape, but with the distribution column transformed +-- *inside* array_agg (unnest(array_agg(text_id + 1))): the batch pass-through no +-- longer carries the untransformed partition column, so the same shape that +-- pushes down above is now rejected and falls back to a coordinator merge +SET citus.allow_unsafe_insert_select_pushdown TO off; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT unnest(array_agg(text_id + 1)) id, + unnest(batch_transform(array_agg(text_col))) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: pull to coordinator + -> Custom Scan (Citus Adaptive) + -> Distributed Subplan XXX_1 + -> WindowAgg + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Seq Scan on dist_14000000 dist + Task Count: 1 + Tasks Shown: All + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Subquery Scan on s + -> ProjectSet + -> HashAggregate + Group Key: intermediate_result.b + -> Function Scan on read_intermediate_result intermediate_result +(20 rows) + +SET citus.allow_unsafe_insert_select_pushdown TO on; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT unnest(array_agg(text_id + 1)) id, + unnest(batch_transform(array_agg(text_col))) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: pull to coordinator + -> Custom Scan (Citus Adaptive) + -> Distributed Subplan XXX_1 + -> WindowAgg + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Seq Scan on dist_14000000 dist + Task Count: 1 + Tasks Shown: All + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Subquery Scan on s + -> ProjectSet + -> HashAggregate + Group Key: intermediate_result.b + -> Function Scan on read_intermediate_result intermediate_result +(20 rows) + +-- the batched benchmark shape, but with a transform wrapping array_agg for the +-- distribution column (unnest(batch_transform(array_agg(text_col)))): the unnest +-- argument must be a bare array_agg of the partition column, so this is rejected +-- too and falls back to a coordinator merge +SET citus.allow_unsafe_insert_select_pushdown TO off; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT unnest(batch_transform(array_agg(text_col))) id, + unnest(array_agg(text_id)) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: pull to coordinator + -> Custom Scan (Citus Adaptive) + -> Distributed Subplan XXX_1 + -> WindowAgg + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Seq Scan on dist_14000000 dist + Task Count: 1 + Tasks Shown: All + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Subquery Scan on s + -> ProjectSet + -> HashAggregate + Group Key: intermediate_result.b + -> Function Scan on read_intermediate_result intermediate_result +(20 rows) + +SET citus.allow_unsafe_insert_select_pushdown TO on; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT unnest(batch_transform(array_agg(text_col))) id, + unnest(array_agg(text_id)) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: pull to coordinator + -> Custom Scan (Citus Adaptive) + -> Distributed Subplan XXX_1 + -> WindowAgg + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Seq Scan on dist_14000000 dist + Task Count: 1 + Tasks Shown: All + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Subquery Scan on s + -> ProjectSet + -> HashAggregate + Group Key: intermediate_result.b + -> Function Scan on read_intermediate_result intermediate_result +(20 rows) + +-- --------------------------------------------------------------------- +-- Correctness of the pushed-down batches. +-- --------------------------------------------------------------------- +-- correctness for a GROUP BY batch: per-shard aggregation keeps each text_id +-- matched to its own value +TRUNCATE res; +INSERT INTO res(text_id, val) +SELECT unnest(array_agg(text_id)), unnest(array_agg(length(text_col))) +FROM dist GROUP BY text_id % 10; +SELECT count(*), count(*) FILTER (WHERE val = length('t' || text_id)) AS ok +FROM res JOIN dist USING (text_id); + count | ok +--------------------------------------------------------------------- + 500 | 500 +(1 row) + +-- correctness for the batched benchmark shape: every text_id keeps its own value +-- after batching, the UDF call, and the unnest zip-back +TRUNCATE res; +INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT + unnest(array_agg(text_id ORDER BY text_id)) id, + unnest(batch_transform(array_agg(text_col ORDER BY text_id))) val + FROM ( + SELECT text_id, text_col, (row_number() OVER () - 1) / 100 batch FROM dist + ) q + GROUP BY batch +) s; +SELECT count(*), count(*) FILTER (WHERE val = length('t' || text_id)) AS ok +FROM res JOIN dist USING (text_id); + count | ok +--------------------------------------------------------------------- + 500 | 500 +(1 row) + +-- --------------------------------------------------------------------- +-- Negative coverage: constructs the GUC never relaxes. +-- --------------------------------------------------------------------- +-- reference table target: the GUC does not apply. The plan is a coordinator +-- merge whether the GUC is disabled or enabled (identical), unlike the +-- distributed-target cases above which push down once the GUC is enabled. +CREATE TABLE ref(text_id int, val int); +SELECT create_reference_table('ref'); + create_reference_table +--------------------------------------------------------------------- + +(1 row) + +SET citus.allow_unsafe_insert_select_pushdown TO off; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO ref(text_id, val) +SELECT text_id % 10, count(*)::int FROM dist GROUP BY text_id % 10 +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: pull to coordinator + -> HashAggregate + Group Key: remote_scan.text_id + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> HashAggregate + Group Key: (text_id % 10) + -> Seq Scan on dist_14000000 dist +(12 rows) + +SET citus.allow_unsafe_insert_select_pushdown TO on; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO ref(text_id, val) +SELECT text_id % 10, count(*)::int FROM dist GROUP BY text_id % 10 +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: pull to coordinator + -> HashAggregate + Group Key: remote_scan.text_id + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> HashAggregate + Group Key: (text_id % 10) + -> Seq Scan on dist_14000000 dist +(12 rows) + +-- volatile functions: the GUC relaxes only grouping / partition-column matching, +-- never volatility. A volatile function in the SELECT is still not pushed to the +-- shards with the GUC enabled -- it falls back to a coordinator plan, identical +-- to the GUC-disabled case. +CREATE FUNCTION volatile_transform(t text) RETURNS int +LANGUAGE sql VOLATILE AS $$ SELECT length(t) $$; +SET citus.allow_unsafe_insert_select_pushdown TO off; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT text_id, volatile_transform(text_col) FROM dist +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: repartition + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Seq Scan on dist_14000000 dist +(8 rows) + +SET citus.allow_unsafe_insert_select_pushdown TO on; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT text_id, volatile_transform(text_col) FROM dist +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: repartition + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Seq Scan on dist_14000000 dist +(8 rows) + +-- a LIMIT forces a coordinator merge even with the GUC enabled and an otherwise +-- pushdown-eligible plan (grouped on the distribution column); LIMIT/OFFSET are +-- never relaxed +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT text_id id, count(*)::int val FROM dist GROUP BY text_id LIMIT 100 +) s +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: pull to coordinator + -> Custom Scan (Citus Adaptive) + -> Distributed Subplan XXX_1 + -> Limit + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Limit + -> HashAggregate + Group Key: text_id + -> Seq Scan on dist_14000000 dist + Task Count: 1 + Tasks Shown: All + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Function Scan on read_intermediate_result intermediate_result +(19 rows) + +-- ... and likewise for OFFSET +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT text_id id, count(*)::int val FROM dist GROUP BY text_id OFFSET 5 +) s +$$, true); + explain_filter +--------------------------------------------------------------------- + Custom Scan (Citus INSERT ... SELECT) + INSERT/SELECT method: pull to coordinator + -> Custom Scan (Citus Adaptive) + -> Distributed Subplan XXX_1 + -> Limit + -> Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> HashAggregate + Group Key: text_id + -> Seq Scan on dist_14000000 dist + Task Count: 1 + Tasks Shown: All + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Function Scan on read_intermediate_result intermediate_result +(18 rows) + +SET client_min_messages TO WARNING; +DROP SCHEMA allow_unsafe_insert_select_pushdown CASCADE; diff --git a/src/test/regress/expected/pg18.out b/src/test/regress/expected/pg18.out index 861c12f4946..419342506c5 100644 --- a/src/test/regress/expected/pg18.out +++ b/src/test/regress/expected/pg18.out @@ -3165,6 +3165,79 @@ drop cascades to type product_rating drop cascades to table product_ratings drop cascades to table record_arg_t -- END: PG18: MIN/MAX aggregate OID resolution for ANYARRAY and RECORD +-- PG18: shard-local INSERT .. SELECT batching pushdown +-- (citus.allow_unsafe_insert_select_pushdown). This mirrors a case from +-- allow_unsafe_insert_select_pushdown.sql, which wraps EXPLAIN in +-- public.explain_filter so the plan is comparable across supported Postgres +-- versions. Here we keep the raw EXPLAIN output because this file only runs on +-- PG18+, so we exercise the real plan including the PG18-only WindowAgg +-- "Window:" line. +CREATE SCHEMA pg18_insert_select_pushdown; +SET search_path TO pg18_insert_select_pushdown; +SET citus.next_shard_id TO 14100000; +SET citus.shard_count = 4; +SET citus.shard_replication_factor = 1; +CREATE TABLE dist(text_id int, text_col text); +CREATE TABLE res(text_id int, val int); +SELECT create_distributed_table('dist', 'text_id'); + create_distributed_table +--------------------------------------------------------------------- + +(1 row) + +SELECT create_distributed_table('res', 'text_id'); + create_distributed_table +--------------------------------------------------------------------- + +(1 row) + +INSERT INTO dist SELECT g, 't' || g FROM generate_series(1, 500) g; +-- a batch UDF: returns one value per input, mimicking a batched API call. +-- immutable + parallel safe, like a real batch UDF. +CREATE FUNCTION batch_transform(t text[]) RETURNS int[] +LANGUAGE sql IMMUTABLE PARALLEL SAFE AS $$ SELECT array_agg(length(x)) FROM unnest(t) x $$; +SELECT create_distributed_function('batch_transform(text[])'); +NOTICE: procedure pg18_insert_select_pushdown.batch_transform is already distributed +DETAIL: Citus distributes procedures with CREATE [PROCEDURE|FUNCTION|AGGREGATE] commands + create_distributed_function +--------------------------------------------------------------------- + +(1 row) + +SET citus.allow_unsafe_insert_select_pushdown TO on; +-- the batching and batch UDF run on the shards; the raw PG18 plan +-- includes the WindowAgg "Window:" line +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT b, unnest(array_agg(text_id)) id, + unnest(batch_transform(array_agg(text_col))) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s; + QUERY PLAN +--------------------------------------------------------------------- + Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Insert on res_14100004 citus_table_alias + -> Subquery Scan on s + -> ProjectSet + -> HashAggregate + Group Key: ((row_number() OVER (?) - 1) / 100) + -> WindowAgg + Window: w1 AS (ROWS UNBOUNDED PRECEDING) + -> Seq Scan on dist_14100000 dist +(13 rows) + +RESET citus.allow_unsafe_insert_select_pushdown; +DROP SCHEMA pg18_insert_select_pushdown CASCADE; +NOTICE: drop cascades to 3 other objects +DETAIL: drop cascades to table dist +drop cascades to table res +drop cascades to function batch_transform(text[]) +SET search_path TO pg18_nn; -- cleanup with minimum verbosity SET client_min_messages TO ERROR; RESET search_path; diff --git a/src/test/regress/multi_schedule b/src/test/regress/multi_schedule index 49891fcdb37..0960831df83 100644 --- a/src/test/regress/multi_schedule +++ b/src/test/regress/multi_schedule @@ -19,6 +19,7 @@ test: multi_behavioral_analytics_basics multi_behavioral_analytics_single_shard_ # We don't parallelize the following test with the ones above because they're # not idempotent and hence causing flaky test detection check to fail. test: multi_insert_select_non_pushable_queries multi_insert_select +test: allow_unsafe_insert_select_pushdown test: multi_shard_update_delete recursive_dml_with_different_planners_executors test: insert_select_repartition window_functions dml_recursive multi_insert_select_window diff --git a/src/test/regress/sql/allow_unsafe_insert_select_pushdown.sql b/src/test/regress/sql/allow_unsafe_insert_select_pushdown.sql new file mode 100644 index 00000000000..4f154da7568 --- /dev/null +++ b/src/test/regress/sql/allow_unsafe_insert_select_pushdown.sql @@ -0,0 +1,333 @@ +-- +-- ALLOW_UNSAFE_INSERT_SELECT_PUSHDOWN +-- +-- Tests citus.allow_unsafe_insert_select_pushdown, which lets a colocated +-- INSERT .. SELECT push GROUP BY / window / DISTINCT batching and a batch UDF +-- down to the shards instead of pulling rows to the coordinator. +-- +-- The distribution column of the target must still come from the source +-- distribution column unchanged: either as a plain Var, or as the provably +-- shard-local batch pass-through unnest(array_agg(dist_col)). Any other derived +-- distribution value (an arithmetic/function transform, or a transform wrapped +-- inside the array_agg) is rejected even with the GUC enabled, because it could +-- route a row to a different shard. +-- +CREATE SCHEMA allow_unsafe_insert_select_pushdown; +SET search_path = allow_unsafe_insert_select_pushdown; +SET citus.next_shard_id TO 14000000; +SET citus.shard_count = 4; +SET citus.shard_replication_factor = 1; + +CREATE TABLE dist(text_id int, text_col text); +CREATE TABLE res(text_id int, val int); +SELECT create_distributed_table('dist', 'text_id'); +SELECT create_distributed_table('res', 'text_id'); + +INSERT INTO dist SELECT g, 't' || g FROM generate_series(1, 500) g; + +-- a batched UDF: returns one value per input, mimicking a batched API call. +-- immutable + parallel safe, like a real batch UDF. +CREATE FUNCTION batch_transform(t text[]) RETURNS int[] +LANGUAGE sql IMMUTABLE PARALLEL SAFE AS $$ SELECT array_agg(length(x)) FROM unnest(t) x $$; +SELECT create_distributed_function('batch_transform(text[])'); + +-- default off: batching is done after pulling rows to the coordinator +-- (explain_filter strips the PG18-only "Window:" line so the plan is +-- comparable across supported Postgres versions) +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT b, unnest(array_agg(text_id)) id, + unnest(batch_transform(array_agg(text_col))) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s +$$, true); + +SET citus.allow_unsafe_insert_select_pushdown TO on; + +-- now the batching and UDF call run on the shards +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT b, unnest(array_agg(text_id)) id, + unnest(batch_transform(array_agg(text_col))) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s +$$, true); + +INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT b, unnest(array_agg(text_id)) id, + unnest(batch_transform(array_agg(text_col))) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s; + +-- every text_id should be matched to the right value +SELECT count(*), count(*) FILTER (WHERE val = length('t' || text_id)) AS ok +FROM res JOIN dist USING (text_id); + +-- --------------------------------------------------------------------- +-- Positive per-branch coverage. Each construct below used to force a +-- coordinator merge (or was rejected outright). With the GUC enabled the whole +-- colocated INSERT .. SELECT is pushed to the shards because the distribution +-- column is either a plain partition-column Var or the unnest(array_agg(text_id)) +-- batch pass-through. The pushed-down plan is a Custom Scan (Citus Adaptive) +-- whose task runs the INSERT on a shard, with no Distributed Subplan / +-- intermediate results. explain_filter keeps the plan comparable across +-- Postgres versions. +-- --------------------------------------------------------------------- + +-- the batched benchmark shape: bucket rows into fixed-size batches with +-- row_number()/batch_size, array_agg each batch (id and text in the same +-- order), call the batch UDF once per batch, then unnest back to one row per id. +-- This is the query the GUC is meant to push down to the shards. +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT + unnest(array_agg(text_id ORDER BY text_id)) id, + unnest(batch_transform(array_agg(text_col ORDER BY text_id))) val + FROM ( + SELECT text_id, text_col, (row_number() OVER () - 1) / 100 batch FROM dist + ) q + GROUP BY batch +) s +$$, true); + +-- branch: GROUP BY on a non-distribution column, distribution column projected +-- as the unnest(array_agg(text_id)) batch pass-through +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT unnest(array_agg(text_id)), unnest(array_agg(length(text_col))) +FROM dist GROUP BY text_id % 10 +$$, true); + +-- branch: aggregates without GROUP BY +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT unnest(array_agg(text_id)), unnest(array_agg(length(text_col))) FROM dist +$$, true); + +-- branch: HAVING without GROUP BY on the distribution column +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT unnest(array_agg(text_id)), unnest(array_agg(length(text_col))) +FROM dist HAVING count(*) > 0 +$$, true); + +-- branch: window function not partitioned on the distribution column, with the +-- distribution column projected as a plain partition-column Var +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT text_id, row_number() OVER (ORDER BY text_col) FROM dist +$$, true); + +-- combination: GROUP BY + DISTINCT, distribution column as the batch pass-through +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT DISTINCT unnest(array_agg(text_id)), unnest(array_agg(length(text_col))) +FROM dist GROUP BY text_id % 10 +$$, true); + +-- combination: window + GROUP BY, relaxed constructs living in nested subqueries +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, v FROM ( + SELECT unnest(array_agg(text_id)) id, unnest(array_agg(length(text_col))) v + FROM (SELECT text_id, text_col, row_number() OVER (ORDER BY text_col) rn FROM dist) q + GROUP BY rn % 5 +) s +$$, true); + +-- --------------------------------------------------------------------- +-- Negative coverage: pattern requirement. With the GUC enabled the batching +-- relaxations still fire, but the distribution column is a *transformed* value +-- rather than the source partition column, so the query is not pushed down and +-- falls back to a coordinator merge -- identical to the GUC-disabled plan. +-- --------------------------------------------------------------------- + +-- distribution column derived by arithmetic on the partition column +SET citus.allow_unsafe_insert_select_pushdown TO off; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT text_id + 0, length(text_col) FROM dist +$$, true); +SET citus.allow_unsafe_insert_select_pushdown TO on; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT text_id + 0, length(text_col) FROM dist +$$, true); + +-- distribution column derived from a non-distribution column (DISTINCT) +SET citus.allow_unsafe_insert_select_pushdown TO off; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT DISTINCT length(text_col), text_id % 7 FROM dist +$$, true); +SET citus.allow_unsafe_insert_select_pushdown TO on; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT DISTINCT length(text_col), text_id % 7 FROM dist +$$, true); + +-- distribution column derived from a non-distribution column, one subquery down +SET citus.allow_unsafe_insert_select_pushdown TO off; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT k, k FROM ( + SELECT DISTINCT length(text_col) k FROM dist +) s +$$, true); +SET citus.allow_unsafe_insert_select_pushdown TO on; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT k, k FROM ( + SELECT DISTINCT length(text_col) k FROM dist +) s +$$, true); + +-- the batched benchmark shape, but with the distribution column transformed +-- *inside* array_agg (unnest(array_agg(text_id + 1))): the batch pass-through no +-- longer carries the untransformed partition column, so the same shape that +-- pushes down above is now rejected and falls back to a coordinator merge +SET citus.allow_unsafe_insert_select_pushdown TO off; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT unnest(array_agg(text_id + 1)) id, + unnest(batch_transform(array_agg(text_col))) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s +$$, true); +SET citus.allow_unsafe_insert_select_pushdown TO on; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT unnest(array_agg(text_id + 1)) id, + unnest(batch_transform(array_agg(text_col))) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s +$$, true); + +-- the batched benchmark shape, but with a transform wrapping array_agg for the +-- distribution column (unnest(batch_transform(array_agg(text_col)))): the unnest +-- argument must be a bare array_agg of the partition column, so this is rejected +-- too and falls back to a coordinator merge +SET citus.allow_unsafe_insert_select_pushdown TO off; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT unnest(batch_transform(array_agg(text_col))) id, + unnest(array_agg(text_id)) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s +$$, true); +SET citus.allow_unsafe_insert_select_pushdown TO on; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT unnest(batch_transform(array_agg(text_col))) id, + unnest(array_agg(text_id)) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s +$$, true); + +-- --------------------------------------------------------------------- +-- Correctness of the pushed-down batches. +-- --------------------------------------------------------------------- + +-- correctness for a GROUP BY batch: per-shard aggregation keeps each text_id +-- matched to its own value +TRUNCATE res; +INSERT INTO res(text_id, val) +SELECT unnest(array_agg(text_id)), unnest(array_agg(length(text_col))) +FROM dist GROUP BY text_id % 10; + +SELECT count(*), count(*) FILTER (WHERE val = length('t' || text_id)) AS ok +FROM res JOIN dist USING (text_id); + +-- correctness for the batched benchmark shape: every text_id keeps its own value +-- after batching, the UDF call, and the unnest zip-back +TRUNCATE res; +INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT + unnest(array_agg(text_id ORDER BY text_id)) id, + unnest(batch_transform(array_agg(text_col ORDER BY text_id))) val + FROM ( + SELECT text_id, text_col, (row_number() OVER () - 1) / 100 batch FROM dist + ) q + GROUP BY batch +) s; + +SELECT count(*), count(*) FILTER (WHERE val = length('t' || text_id)) AS ok +FROM res JOIN dist USING (text_id); + +-- --------------------------------------------------------------------- +-- Negative coverage: constructs the GUC never relaxes. +-- --------------------------------------------------------------------- + +-- reference table target: the GUC does not apply. The plan is a coordinator +-- merge whether the GUC is disabled or enabled (identical), unlike the +-- distributed-target cases above which push down once the GUC is enabled. +CREATE TABLE ref(text_id int, val int); +SELECT create_reference_table('ref'); + +SET citus.allow_unsafe_insert_select_pushdown TO off; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO ref(text_id, val) +SELECT text_id % 10, count(*)::int FROM dist GROUP BY text_id % 10 +$$, true); + +SET citus.allow_unsafe_insert_select_pushdown TO on; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO ref(text_id, val) +SELECT text_id % 10, count(*)::int FROM dist GROUP BY text_id % 10 +$$, true); + +-- volatile functions: the GUC relaxes only grouping / partition-column matching, +-- never volatility. A volatile function in the SELECT is still not pushed to the +-- shards with the GUC enabled -- it falls back to a coordinator plan, identical +-- to the GUC-disabled case. +CREATE FUNCTION volatile_transform(t text) RETURNS int +LANGUAGE sql VOLATILE AS $$ SELECT length(t) $$; + +SET citus.allow_unsafe_insert_select_pushdown TO off; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT text_id, volatile_transform(text_col) FROM dist +$$, true); + +SET citus.allow_unsafe_insert_select_pushdown TO on; +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT text_id, volatile_transform(text_col) FROM dist +$$, true); + +-- a LIMIT forces a coordinator merge even with the GUC enabled and an otherwise +-- pushdown-eligible plan (grouped on the distribution column); LIMIT/OFFSET are +-- never relaxed +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT text_id id, count(*)::int val FROM dist GROUP BY text_id LIMIT 100 +) s +$$, true); + +-- ... and likewise for OFFSET +SELECT public.explain_filter($$ +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT text_id id, count(*)::int val FROM dist GROUP BY text_id OFFSET 5 +) s +$$, true); + +SET client_min_messages TO WARNING; +DROP SCHEMA allow_unsafe_insert_select_pushdown CASCADE; diff --git a/src/test/regress/sql/pg18.sql b/src/test/regress/sql/pg18.sql index aedc786473d..8e6e5b505e0 100644 --- a/src/test/regress/sql/pg18.sql +++ b/src/test/regress/sql/pg18.sql @@ -2007,6 +2007,49 @@ DROP SCHEMA pg18_minmax CASCADE; -- END: PG18: MIN/MAX aggregate OID resolution for ANYARRAY and RECORD +-- PG18: shard-local INSERT .. SELECT batching pushdown +-- (citus.allow_unsafe_insert_select_pushdown). This mirrors a case from +-- allow_unsafe_insert_select_pushdown.sql, which wraps EXPLAIN in +-- public.explain_filter so the plan is comparable across supported Postgres +-- versions. Here we keep the raw EXPLAIN output because this file only runs on +-- PG18+, so we exercise the real plan including the PG18-only WindowAgg +-- "Window:" line. +CREATE SCHEMA pg18_insert_select_pushdown; +SET search_path TO pg18_insert_select_pushdown; +SET citus.next_shard_id TO 14100000; +SET citus.shard_count = 4; +SET citus.shard_replication_factor = 1; + +CREATE TABLE dist(text_id int, text_col text); +CREATE TABLE res(text_id int, val int); +SELECT create_distributed_table('dist', 'text_id'); +SELECT create_distributed_table('res', 'text_id'); + +INSERT INTO dist SELECT g, 't' || g FROM generate_series(1, 500) g; + +-- a batch UDF: returns one value per input, mimicking a batched API call. +-- immutable + parallel safe, like a real batch UDF. +CREATE FUNCTION batch_transform(t text[]) RETURNS int[] +LANGUAGE sql IMMUTABLE PARALLEL SAFE AS $$ SELECT array_agg(length(x)) FROM unnest(t) x $$; +SELECT create_distributed_function('batch_transform(text[])'); + +SET citus.allow_unsafe_insert_select_pushdown TO on; + +-- the batching and batch UDF run on the shards; the raw PG18 plan +-- includes the WindowAgg "Window:" line +EXPLAIN (COSTS OFF) INSERT INTO res(text_id, val) +SELECT id, val FROM ( + SELECT b, unnest(array_agg(text_id)) id, + unnest(batch_transform(array_agg(text_col))) val + FROM (SELECT text_id, text_col, (row_number() OVER () - 1) / 100 b FROM dist) q + GROUP BY b +) s; + +RESET citus.allow_unsafe_insert_select_pushdown; +DROP SCHEMA pg18_insert_select_pushdown CASCADE; +SET search_path TO pg18_nn; + + -- cleanup with minimum verbosity SET client_min_messages TO ERROR; RESET search_path;