diff --git a/src/backend/distributed/executor/multi_executor.c b/src/backend/distributed/executor/multi_executor.c index d24d59008..62c35d769 100644 --- a/src/backend/distributed/executor/multi_executor.c +++ b/src/backend/distributed/executor/multi_executor.c @@ -66,7 +66,8 @@ static CustomExecMethods RouterSingleModifyCustomExecMethods = { .ExplainCustomScan = CitusExplainScan }; -static CustomExecMethods RouterMultiModifyCustomExecMethods = { +/* not static to enable swapping in multi-modify logic during router execution */ +CustomExecMethods RouterMultiModifyCustomExecMethods = { .CustomName = "RouterMultiModifyScan", .BeginCustomScan = CitusModifyBeginScan, .ExecCustomScan = RouterMultiModifyExecScan, diff --git a/src/backend/distributed/executor/multi_router_executor.c b/src/backend/distributed/executor/multi_router_executor.c index 7ba5b864d..3eada7eb7 100644 --- a/src/backend/distributed/executor/multi_router_executor.c +++ b/src/backend/distributed/executor/multi_router_executor.c @@ -419,6 +419,11 @@ CitusModifyBeginScan(CustomScanState *node, EState *estate, int eflags) RaiseDeferredError(planningError, ERROR); } + if (list_length(taskList) > 1) + { + node->methods = &RouterMultiModifyCustomExecMethods; + } + workerJob->taskList = taskList; } @@ -428,7 +433,7 @@ CitusModifyBeginScan(CustomScanState *node, EState *estate, int eflags) /* prevent concurrent placement changes */ AcquireMetadataLocks(taskList); - /* assign task placements */ + /* modify tasks are always assigned using first-replica policy */ workerJob->taskList = FirstReplicaAssignTaskList(taskList); } diff --git a/src/backend/distributed/planner/deparse_shard_query.c b/src/backend/distributed/planner/deparse_shard_query.c index 9ddcc0be1..07aa37729 100644 --- a/src/backend/distributed/planner/deparse_shard_query.c +++ b/src/backend/distributed/planner/deparse_shard_query.c @@ -26,11 +26,15 @@ #include "nodes/nodes.h" #include "nodes/parsenodes.h" #include "nodes/pg_list.h" +#include 
"parser/parsetree.h" #include "storage/lock.h" #include "utils/lsyscache.h" #include "utils/rel.h" - +static RangeTblEntry * ExtractDistributedInsertValuesRTE(Query *query, + Oid distributedTableId); +static void UpdateTaskQueryString(Query *query, Oid distributedTableId, + RangeTblEntry *valuesRTE, Task *task); static void ConvertRteToSubqueryWithEmptyResult(RangeTblEntry *rte); @@ -43,11 +47,12 @@ RebuildQueryStrings(Query *originalQuery, List *taskList) { ListCell *taskCell = NULL; Oid relationId = ((RangeTblEntry *) linitial(originalQuery->rtable))->relid; + RangeTblEntry *valuesRTE = ExtractDistributedInsertValuesRTE(originalQuery, + relationId); foreach(taskCell, taskList) { Task *task = (Task *) lfirst(taskCell); - StringInfo newQueryString = makeStringInfo(); Query *query = originalQuery; if (task->insertSelectQuery) @@ -90,31 +95,108 @@ RebuildQueryStrings(Query *originalQuery, List *taskList) } } - /* - * For INSERT queries, we only have one relation to update, so we can - * use deparse_shard_query(). For UPDATE and DELETE queries, we may have - * subqueries and joins, so we use relation shard list to update shard - * names and call pg_get_query_def() directly. - */ - if (query->commandType == CMD_INSERT) - { - deparse_shard_query(query, relationId, task->anchorShardId, newQueryString); - } - else - { - List *relationShardList = task->relationShardList; - UpdateRelationToShardNames((Node *) query, relationShardList); + ereport(DEBUG4, (errmsg("query before rebuilding: %s", task->queryString))); - pg_get_query_def(query, newQueryString); - } + UpdateTaskQueryString(query, relationId, valuesRTE, task); - ereport(DEBUG4, (errmsg("distributed statement: %s", newQueryString->data))); - - task->queryString = newQueryString->data; + ereport(DEBUG4, (errmsg("query after rebuilding: %s", task->queryString))); } } +/* + * ExtractDistributedInsertValuesRTE does precisely that. 
If the provided + * query is not an INSERT, or if the table is a reference table, or if the + * INSERT does not have a VALUES RTE (i.e. it is not a multi-row INSERT), this + * function returns NULL. If all those conditions are met, an RTE representing + * the multiple values of a multi-row INSERT is returned. + */ +static RangeTblEntry * +ExtractDistributedInsertValuesRTE(Query *query, Oid distributedTableId) +{ + RangeTblEntry *valuesRTE = NULL; + uint32 rangeTableId = 1; + Var *partitionColumn = NULL; + TargetEntry *targetEntry = NULL; + + if (query->commandType != CMD_INSERT) + { + return NULL; + } + + partitionColumn = PartitionColumn(distributedTableId, rangeTableId); + if (partitionColumn == NULL) + { + return NULL; + } + + targetEntry = get_tle_by_resno(query->targetList, partitionColumn->varattno); + Assert(targetEntry != NULL); + + if (IsA(targetEntry->expr, Var)) + { + Var *partitionVar = (Var *) targetEntry->expr; + + valuesRTE = rt_fetch(partitionVar->varno, query->rtable); + if (valuesRTE->rtekind != RTE_VALUES) + { + return NULL; + } + } + + return valuesRTE; +} + + +/* + * UpdateTaskQueryString updates the query string stored within the provided + * Task. If the Task has row values from a multi-row INSERT, those are injected + * into the provided query (using the provided valuesRTE, which must belong to + * the query) before deparse occurs (the query's full VALUES list will be + * restored before this function returns). + */ +static void +UpdateTaskQueryString(Query *query, Oid distributedTableId, RangeTblEntry *valuesRTE, + Task *task) +{ + StringInfo queryString = makeStringInfo(); + List *oldValuesLists = NIL; + + if (valuesRTE != NULL) + { + Assert(valuesRTE->rtekind == RTE_VALUES); + + oldValuesLists = valuesRTE->values_lists; + valuesRTE->values_lists = task->rowValuesLists; + } + + /* + * For INSERT queries, we only have one relation to update, so we can + * use deparse_shard_query(). 
For UPDATE and DELETE queries, we may have + * subqueries and joins, so we use relation shard list to update shard + * names and call pg_get_query_def() directly. + */ + if (query->commandType == CMD_INSERT) + { + deparse_shard_query(query, distributedTableId, task->anchorShardId, queryString); + } + else + { + List *relationShardList = task->relationShardList; + UpdateRelationToShardNames((Node *) query, relationShardList); + + pg_get_query_def(query, queryString); + } + + if (valuesRTE != NULL) + { + valuesRTE->values_lists = oldValuesLists; + } + + task->queryString = queryString->data; +} + + /* * UpdateRelationToShardNames walks over the query tree and appends shard ids to * relations. It uses unique identity value to establish connection between a diff --git a/src/backend/distributed/planner/multi_router_planner.c b/src/backend/distributed/planner/multi_router_planner.c index 2021379c3..c9854e1b2 100644 --- a/src/backend/distributed/planner/multi_router_planner.c +++ b/src/backend/distributed/planner/multi_router_planner.c @@ -71,6 +71,28 @@ #include "catalog/pg_proc.h" #include "optimizer/planmain.h" +/* intermediate value for INSERT processing */ +typedef struct InsertValues +{ + Expr *partitionValueExpr; /* partition value provided in INSERT row */ + List *rowValues; /* full values list of INSERT row, possibly NIL */ + int64 shardId; /* target shard for this row, possibly invalid */ +} InsertValues; + + +/* + * A ModifyRoute encapsulates the information needed to route modifications + * to the appropriate shard. For a single-shard modification, only one route + * is needed, but in the case of e.g. a multi-row INSERT, lists of these values + * will help divide the rows by their destination shards, permitting later + * shard-and-row-specific extension of the original SQL. 
+ */ +typedef struct ModifyRoute +{ + int64 shardId; /* identifier of target shard */ + List *rowValuesLists; /* for multi-row INSERTs, list of rows to be inserted */ +} ModifyRoute; + typedef struct WalkerState { @@ -99,9 +121,6 @@ static void ErrorIfNoShardsExist(DistTableCacheEntry *cacheEntry); static bool CanShardPrune(Oid distributedTableId, Query *query); static Job * CreateJob(Query *query); static Task * CreateTask(TaskType taskType); -static ShardInterval * FindShardForInsert(Query *query, DistTableCacheEntry *cacheEntry, - DeferredErrorMessage **planningError); -static Expr * ExtractInsertPartitionValue(Query *query, Var *partitionColumn); static Job * RouterJob(Query *originalQuery, RelationRestrictionContext *restrictionContext, DeferredErrorMessage **planningError); @@ -110,6 +129,9 @@ static List * TargetShardIntervalsForRouter(Query *query, RelationRestrictionContext *restrictionContext, bool *multiShardQuery); static List * WorkersContainingAllShards(List *prunedShardIntervalsList); +static List * BuildRoutesForInsert(Query *query, DeferredErrorMessage **planningError); +static List * GroupInsertValuesByShardId(List *insertValuesList); +static List * ExtractInsertValuesList(Query *query, Var *partitionColumn); static bool MultiRouterPlannableQuery(Query *query, RelationRestrictionContext *restrictionContext); static DeferredErrorMessage * ErrorIfQueryHasModifyingCTE(Query *queryTree); @@ -120,6 +142,8 @@ static bool SelectsFromDistributedTable(List *rangeTableList); #if (PG_VERSION_NUM >= 100000) static List * get_all_actual_clauses(List *restrictinfo_list); #endif +static int CompareInsertValuesByShardId(const void *leftElement, + const void *rightElement); /* @@ -471,7 +495,6 @@ ModifyQuerySupported(Query *queryTree, bool multiShardQuery) bool isCoordinator = IsCoordinator(); List *rangeTableList = NIL; ListCell *rangeTableCell = NULL; - bool hasValuesScan = false; uint32 queryTableCount = 0; bool specifiesPartitionValue = false; ListCell 
*setTargetCell = NULL; @@ -555,7 +578,7 @@ ModifyQuerySupported(Query *queryTree, bool multiShardQuery) } else if (rangeTableEntry->rtekind == RTE_VALUES) { - hasValuesScan = true; + /* do nothing, this type is supported */ } else { @@ -626,24 +649,6 @@ ModifyQuerySupported(Query *queryTree, bool multiShardQuery) } } - /* reject queries which involve multi-row inserts */ - if (hasValuesScan) - { - /* - * NB: If you remove this check you must also change the checks further in this - * method and ensure that VOLATILE function calls aren't allowed in INSERT - * statements. Currently they're allowed but the function call is replaced - * with a constant, and if you're inserting multiple rows at once the function - * should return a different value for each row. - */ - return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED, - "cannot perform distributed planning for the given" - " modification", - "Multi-row INSERTs to distributed tables are not " - "supported.", - NULL); - } - if (commandType == CMD_INSERT || commandType == CMD_UPDATE || commandType == CMD_DELETE) { @@ -1143,7 +1148,8 @@ CanShardPrune(Oid distributedTableId, Query *query) { uint32 rangeTableId = 1; Var *partitionColumn = NULL; - Expr *partitionValueExpr = NULL; + List *insertValuesList = NIL; + ListCell *insertValuesCell = NULL; if (query->commandType != CMD_INSERT) { @@ -1158,14 +1164,19 @@ CanShardPrune(Oid distributedTableId, Query *query) return true; } - partitionValueExpr = ExtractInsertPartitionValue(query, partitionColumn); - if (IsA(partitionValueExpr, Const)) + /* get full list of partition values and ensure they are all Consts */ + insertValuesList = ExtractInsertValuesList(query, partitionColumn); + foreach(insertValuesCell, insertValuesList) { - /* can do shard pruning if the partition column is constant */ - return true; + InsertValues *insertValues = (InsertValues *) lfirst(insertValuesCell); + if (!IsA(insertValues->partitionValueExpr, Const)) + { + /* can't do shard pruning if the partition 
column is not constant */ + return false; + } } - return false; + return true; } @@ -1198,8 +1209,9 @@ ErrorIfNoShardsExist(DistTableCacheEntry *cacheEntry) List * RouterInsertTaskList(Query *query, DeferredErrorMessage **planningError) { - ShardInterval *shardInterval = NULL; - Task *modifyTask = NULL; + List *insertTaskList = NIL; + List *modifyRouteList = NIL; + ListCell *modifyRouteCell = NULL; Oid distributedTableId = ExtractFirstDistributedTableId(query); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId); @@ -1208,26 +1220,30 @@ RouterInsertTaskList(Query *query, DeferredErrorMessage **planningError) Assert(query->commandType == CMD_INSERT); - shardInterval = FindShardForInsert(query, cacheEntry, planningError); - + modifyRouteList = BuildRoutesForInsert(query, planningError); if (*planningError != NULL) { return NIL; } - /* an INSERT always routes to exactly one shard */ - Assert(shardInterval != NULL); - - modifyTask = CreateTask(MODIFY_TASK); - modifyTask->anchorShardId = shardInterval->shardId; - modifyTask->replicationModel = cacheEntry->replicationModel; - - if (query->onConflict != NULL) + foreach(modifyRouteCell, modifyRouteList) { - modifyTask->upsertQuery = true; + ModifyRoute *modifyRoute = (ModifyRoute *) lfirst(modifyRouteCell); + + Task *modifyTask = CreateTask(MODIFY_TASK); + modifyTask->anchorShardId = modifyRoute->shardId; + modifyTask->replicationModel = cacheEntry->replicationModel; + modifyTask->rowValuesLists = modifyRoute->rowValuesLists; + + if (query->onConflict != NULL) + { + modifyTask->upsertQuery = true; + } + + insertTaskList = lappend(insertTaskList, modifyTask); } - return list_make1(modifyTask); + return insertTaskList; } @@ -1264,132 +1280,6 @@ CreateTask(TaskType taskType) } -/* - * FindShardForInsert returns the shard interval for an INSERT query or NULL if - * the partition column value is defined as an expression that still needs to be - * evaluated. 
If the partition column value falls within 0 or multiple - * (overlapping) shards, the planningError is set. - */ -static ShardInterval * -FindShardForInsert(Query *query, DistTableCacheEntry *cacheEntry, - DeferredErrorMessage **planningError) -{ - Oid distributedTableId = cacheEntry->relationId; - char partitionMethod = cacheEntry->partitionMethod; - uint32 rangeTableId = 1; - Var *partitionColumn = NULL; - Expr *partitionValueExpr = NULL; - Const *partitionValueConst = NULL; - int prunedShardCount = 0; - List *prunedShardList = NIL; - - Assert(query->commandType == CMD_INSERT); - - /* reference tables do not have a partition column, but can only have one shard */ - if (partitionMethod == DISTRIBUTE_BY_NONE) - { - int shardCount = cacheEntry->shardIntervalArrayLength; - if (shardCount != 1) - { - ereport(ERROR, (errmsg("reference table cannot have %d shards", shardCount))); - } - - return cacheEntry->sortedShardIntervalArray[0]; - } - - partitionColumn = PartitionColumn(distributedTableId, rangeTableId); - partitionValueExpr = ExtractInsertPartitionValue(query, partitionColumn); - - /* non-constants should have been caught by CanShardPrune */ - if (!IsA(partitionValueExpr, Const)) - { - ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), - errmsg("cannot perform an INSERT with a non-constant in the " - "partition column"))); - } - - partitionValueConst = (Const *) partitionValueExpr; - if (partitionValueConst->constisnull) - { - ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), - errmsg("cannot perform an INSERT with NULL in the partition " - "column"))); - } - - if (partitionMethod == DISTRIBUTE_BY_HASH || partitionMethod == DISTRIBUTE_BY_RANGE) - { - Datum partitionValue = partitionValueConst->constvalue; - ShardInterval *shardInterval = FindShardInterval(partitionValue, cacheEntry); - - if (shardInterval != NULL) - { - prunedShardList = list_make1(shardInterval); - } - } - else - { - List *restrictClauseList = NIL; - Index tableId = 1; - OpExpr 
*equalityExpr = MakeOpExpression(partitionColumn, BTEqualStrategyNumber); - Node *rightOp = get_rightop((Expr *) equalityExpr); - Const *rightConst = (Const *) rightOp; - - Assert(IsA(rightOp, Const)); - - rightConst->constvalue = partitionValueConst->constvalue; - rightConst->constisnull = partitionValueConst->constisnull; - rightConst->constbyval = partitionValueConst->constbyval; - - restrictClauseList = list_make1(equalityExpr); - - prunedShardList = PruneShards(distributedTableId, tableId, restrictClauseList); - } - - prunedShardCount = list_length(prunedShardList); - if (prunedShardCount != 1) - { - char *partitionKeyString = cacheEntry->partitionKeyString; - char *partitionColumnName = ColumnNameToColumn(distributedTableId, - partitionKeyString); - StringInfo errorMessage = makeStringInfo(); - StringInfo errorHint = makeStringInfo(); - const char *targetCountType = NULL; - - if (prunedShardCount == 0) - { - targetCountType = "no"; - } - else - { - targetCountType = "multiple"; - } - - if (prunedShardCount == 0) - { - appendStringInfo(errorHint, "Make sure you have created a shard which " - "can receive this partition column value."); - } - else - { - appendStringInfo(errorHint, "Make sure the value for partition column " - "\"%s\" falls into a single shard.", - partitionColumnName); - } - - appendStringInfo(errorMessage, "cannot run INSERT command which targets %s " - "shards", targetCountType); - - (*planningError) = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED, - errorMessage->data, NULL, - errorHint->data); - - return NULL; - } - - return (ShardInterval *) linitial(prunedShardList); -} - - /* * ExtractFirstDistributedTableId takes a given query, and finds the relationId * for the first distributed table in that query. If the function cannot find a @@ -1420,27 +1310,6 @@ ExtractFirstDistributedTableId(Query *query) } -/* - * ExtractPartitionValue extracts the partition column value from a the target - * of an INSERT command. 
If a partition value is missing altogether or is - * NULL, this function throws an error. - */ -static Expr * -ExtractInsertPartitionValue(Query *query, Var *partitionColumn) -{ - TargetEntry *targetEntry = get_tle_by_resno(query->targetList, - partitionColumn->varattno); - if (targetEntry == NULL) - { - ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), - errmsg("cannot perform an INSERT without a partition column " - "value"))); - } - - return targetEntry->expr; -} - - /* RouterJob builds a Job to represent a single shard select/update/delete query */ static Job * RouterJob(Query *originalQuery, RelationRestrictionContext *restrictionContext, @@ -1957,6 +1826,167 @@ WorkersContainingAllShards(List *prunedShardIntervalsList) } +/* + * BuildRoutesForInsert returns a list of ModifyRoute objects for an INSERT + * query or an empty list if the partition column value is defined as an ex- + * pression that still needs to be evaluated. If any partition column value + * falls within 0 or multiple (overlapping) shards, the planning error is set. + * + * Multi-row INSERTs are handled by grouping their rows by target shard. These + * groups are returned in ascending order by shard id, ready for later deparse + * to shard-specific SQL. 
+ */ +static List * +BuildRoutesForInsert(Query *query, DeferredErrorMessage **planningError) +{ + Oid distributedTableId = ExtractFirstDistributedTableId(query); + DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId); + char partitionMethod = cacheEntry->partitionMethod; + uint32 rangeTableId = 1; + Var *partitionColumn = NULL; + List *insertValuesList = NIL; + List *modifyRouteList = NIL; + ListCell *insertValuesCell = NULL; + + Assert(query->commandType == CMD_INSERT); + + /* reference tables can only have one shard */ + if (partitionMethod == DISTRIBUTE_BY_NONE) + { + int shardCount = 0; + List *shardIntervalList = LoadShardIntervalList(distributedTableId); + ShardInterval *shardInterval = NULL; + ModifyRoute *modifyRoute = NULL; + + shardCount = list_length(shardIntervalList); + if (shardCount != 1) + { + ereport(ERROR, (errmsg("reference table cannot have %d shards", shardCount))); + } + + shardInterval = linitial(shardIntervalList); + modifyRoute = palloc(sizeof(ModifyRoute)); + + modifyRoute->shardId = shardInterval->shardId; + modifyRoute->rowValuesLists = NIL; + + modifyRouteList = lappend(modifyRouteList, modifyRoute); + + return modifyRouteList; + } + + partitionColumn = PartitionColumn(distributedTableId, rangeTableId); + + /* get full list of insert values and iterate over them to prune */ + insertValuesList = ExtractInsertValuesList(query, partitionColumn); + + foreach(insertValuesCell, insertValuesList) + { + InsertValues *insertValues = (InsertValues *) lfirst(insertValuesCell); + Const *partitionValueConst = NULL; + List *prunedShardList = NIL; + int prunedShardCount = 0; + ShardInterval *targetShard = NULL; + + if (!IsA(insertValues->partitionValueExpr, Const)) + { + /* shard pruning not possible right now */ + return NIL; + } + + partitionValueConst = (Const *) insertValues->partitionValueExpr; + if (partitionValueConst->constisnull) + { + ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), + errmsg("cannot 
perform an INSERT with NULL in the partition " + "column"))); + } + + if (partitionMethod == DISTRIBUTE_BY_HASH || partitionMethod == + DISTRIBUTE_BY_RANGE) + { + Datum partitionValue = partitionValueConst->constvalue; + DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry( + distributedTableId); + ShardInterval *shardInterval = FindShardInterval(partitionValue, cacheEntry); + + if (shardInterval != NULL) + { + prunedShardList = list_make1(shardInterval); + } + } + else + { + List *restrictClauseList = NIL; + Index tableId = 1; + OpExpr *equalityExpr = MakeOpExpression(partitionColumn, + BTEqualStrategyNumber); + Node *rightOp = get_rightop((Expr *) equalityExpr); + Const *rightConst = (Const *) rightOp; + + Assert(IsA(rightOp, Const)); + + rightConst->constvalue = partitionValueConst->constvalue; + rightConst->constisnull = partitionValueConst->constisnull; + rightConst->constbyval = partitionValueConst->constbyval; + + restrictClauseList = list_make1(equalityExpr); + + prunedShardList = PruneShards(distributedTableId, tableId, + restrictClauseList); + } + + prunedShardCount = list_length(prunedShardList); + if (prunedShardCount != 1) + { + char *partitionKeyString = cacheEntry->partitionKeyString; + char *partitionColumnName = ColumnNameToColumn(distributedTableId, + partitionKeyString); + StringInfo errorMessage = makeStringInfo(); + StringInfo errorHint = makeStringInfo(); + const char *targetCountType = NULL; + + if (prunedShardCount == 0) + { + targetCountType = "no"; + } + else + { + targetCountType = "multiple"; + } + + if (prunedShardCount == 0) + { + appendStringInfo(errorHint, "Make sure you have created a shard which " + "can receive this partition column value."); + } + else + { + appendStringInfo(errorHint, "Make sure the value for partition column " + "\"%s\" falls into a single shard.", + partitionColumnName); + } + + appendStringInfo(errorMessage, "cannot run INSERT command which targets %s " + "shards", targetCountType); + + 
(*planningError) = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED, + errorMessage->data, NULL, + errorHint->data); + + return NIL; + } + + targetShard = (ShardInterval *) linitial(prunedShardList); + insertValues->shardId = targetShard->shardId; + } + + modifyRouteList = GroupInsertValuesByShardId(insertValuesList); + + return modifyRouteList; +} + + /* * IntersectPlacementList performs placement pruning based on matching on * nodeName:nodePort fields of shard placement data. We start pruning from all @@ -1993,6 +2023,129 @@ IntersectPlacementList(List *lhsPlacementList, List *rhsPlacementList) } +/* + * GroupInsertValuesByShardId takes care of grouping the rows from a multi-row + * INSERT by target shard. At this point, all pruning has taken place and we + * need only to build sets of rows for each destination. This is done by a + * simple sort (by shard identifier) and gather step. The sort has the side- + * effect of getting things in ascending order to avoid unnecessary deadlocks + * during Task execution. + */ +static List * +GroupInsertValuesByShardId(List *insertValuesList) +{ + ModifyRoute *route = NULL; + ListCell *insertValuesCell = NULL; + List *modifyRouteList = NIL; + + insertValuesList = SortList(insertValuesList, CompareInsertValuesByShardId); + foreach(insertValuesCell, insertValuesList) + { + InsertValues *insertValues = (InsertValues *) lfirst(insertValuesCell); + int64 shardId = insertValues->shardId; + bool foundSameShardId = false; + + if (route != NULL) + { + if (route->shardId == shardId) + { + foundSameShardId = true; + } + else + { + /* new shard id seen; current aggregation done; add to list */ + modifyRouteList = lappend(modifyRouteList, route); + } + } + + if (foundSameShardId) + { + /* + * Our current value has the same shard id as our aggregate object, + * so append the rowValues. 
+ */ + route->rowValuesLists = lappend(route->rowValuesLists, + insertValues->rowValues); + } + else + { + /* we encountered a new shard id; build a new aggregate object */ + route = (ModifyRoute *) palloc(sizeof(ModifyRoute)); + route->shardId = insertValues->shardId; + route->rowValuesLists = list_make1(insertValues->rowValues); + } + } + + /* left holding one final aggregate object; add to list */ + modifyRouteList = lappend(modifyRouteList, route); + + return modifyRouteList; +} + + +/* + * ExtractInsertValuesList extracts the partition column value for an INSERT + * command and returns it within an InsertValues struct. For single-row INSERTs + * this is simply a value extracted from the target list, but multi-row INSERTs + * will generate a List of InsertValues, each with full row values in addition + * to the partition value. If a partition value is NULL or missing altogether, + * this function errors. + */ +static List * +ExtractInsertValuesList(Query *query, Var *partitionColumn) +{ + List *insertValuesList = NIL; + TargetEntry *targetEntry = get_tle_by_resno(query->targetList, + partitionColumn->varattno); + + if (targetEntry == NULL) + { + ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), + errmsg("cannot perform an INSERT without a partition column " + "value"))); + } + + /* + * We've got a multi-row INSERT. PostgreSQL internally represents such + * commands by linking Vars in the target list to lists of values within + * a special VALUES range table entry. By extracting the right positional + * expression from each list within that RTE, we will extract the partition + * values for each row within the multi-row INSERT. 
+ */ + if (IsA(targetEntry->expr, Var)) + { + Var *partitionVar = (Var *) targetEntry->expr; + RangeTblEntry *referencedRTE = NULL; + ListCell *valuesListCell = NULL; + + referencedRTE = rt_fetch(partitionVar->varno, query->rtable); + foreach(valuesListCell, referencedRTE->values_lists) + { + InsertValues *insertValues = (InsertValues *) palloc(sizeof(InsertValues)); + insertValues->rowValues = (List *) lfirst(valuesListCell); + insertValues->partitionValueExpr = list_nth(insertValues->rowValues, + (partitionVar->varattno - 1)); + insertValues->shardId = INVALID_SHARD_ID; + + insertValuesList = lappend(insertValuesList, insertValues); + } + } + + /* nothing's been found yet; this is a simple single-row INSERT */ + if (insertValuesList == NIL) + { + InsertValues *insertValues = (InsertValues *) palloc(sizeof(InsertValues)); + insertValues->rowValues = NIL; + insertValues->partitionValueExpr = targetEntry->expr; + insertValues->shardId = INVALID_SHARD_ID; + + insertValuesList = lappend(insertValuesList, insertValues); + } + + return insertValuesList; +} + + /* * MultiRouterPlannableQuery returns true if given query can be router plannable. * The query is router plannable if it is a modify query, or if its is a select @@ -2170,3 +2323,31 @@ get_all_actual_clauses(List *restrictinfo_list) #endif + + +/* + * CompareInsertValuesByShardId does what it says in the name. Used for sorting + * InsertValues objects by their shard. 
+ */ +static int +CompareInsertValuesByShardId(const void *leftElement, const void *rightElement) +{ + InsertValues *leftValue = *((InsertValues **) leftElement); + InsertValues *rightValue = *((InsertValues **) rightElement); + int64 leftShardId = leftValue->shardId; + int64 rightShardId = rightValue->shardId; + + /* we compare 64-bit integers, instead of casting their difference to int */ + if (leftShardId > rightShardId) + { + return 1; + } + else if (leftShardId < rightShardId) + { + return -1; + } + else + { + return 0; + } +} diff --git a/src/backend/distributed/utils/citus_clauses.c b/src/backend/distributed/utils/citus_clauses.c index 979666b5c..79e5268b1 100644 --- a/src/backend/distributed/utils/citus_clauses.c +++ b/src/backend/distributed/utils/citus_clauses.c @@ -32,6 +32,7 @@ typedef struct FunctionEvaluationContext /* private function declarations */ +static void EvaluateValuesListsItems(List *valuesLists, PlanState *planState); static Node * EvaluateNodeIfReferencesFunction(Node *expression, PlanState *planState); static Node * PartiallyEvaluateExpressionMutator(Node *expression, FunctionEvaluationContext *context); @@ -63,14 +64,19 @@ RequiresMasterEvaluation(Query *query) { RangeTblEntry *rte = (RangeTblEntry *) lfirst(rteCell); - if (rte->rtekind != RTE_SUBQUERY) + if (rte->rtekind == RTE_SUBQUERY) { - continue; + if (RequiresMasterEvaluation(rte->subquery)) + { + return true; + } } - - if (RequiresMasterEvaluation(rte->subquery)) + else if (rte->rtekind == RTE_VALUES) { - return true; + if (contain_mutable_functions((Node *) rte->values_lists)) + { + return true; + } } } @@ -131,12 +137,14 @@ ExecuteMasterEvaluableFunctions(Query *query, PlanState *planState) { RangeTblEntry *rte = (RangeTblEntry *) lfirst(rteCell); - if (rte->rtekind != RTE_SUBQUERY) + if (rte->rtekind == RTE_SUBQUERY) { - continue; + ExecuteMasterEvaluableFunctions(rte->subquery, planState); + } + else if (rte->rtekind == RTE_VALUES) + { + 
EvaluateValuesListsItems(rte->values_lists, planState); } - - ExecuteMasterEvaluableFunctions(rte->subquery, planState); } foreach(cteCell, query->cteList) @@ -148,6 +156,35 @@ } +/* + * EvaluateValuesListsItems simply does the work of walking over each expression + * in each value list contained in a multi-row INSERT's VALUES RTE. Basically + * a nested for loop to perform an in-place replacement of expressions with + * their ultimate values, should evaluation be necessary. + */ +static void +EvaluateValuesListsItems(List *valuesLists, PlanState *planState) +{ + ListCell *exprListCell = NULL; + + foreach(exprListCell, valuesLists) + { + List *exprList = (List *) lfirst(exprListCell); + ListCell *exprCell = NULL; + + foreach(exprCell, exprList) + { + Expr *expr = (Expr *) lfirst(exprCell); + Node *modifiedNode = NULL; + + modifiedNode = PartiallyEvaluateExpression((Node *) expr, planState); + + exprCell->data.ptr_value = (void *) modifiedNode; + } + } +} + + /* * Walks the expression evaluating any node which invokes a function as long as a Var * doesn't show up in the parameter list. 
diff --git a/src/backend/distributed/utils/citus_copyfuncs.c b/src/backend/distributed/utils/citus_copyfuncs.c index b93f6b434..9a3ab06af 100644 --- a/src/backend/distributed/utils/citus_copyfuncs.c +++ b/src/backend/distributed/utils/citus_copyfuncs.c @@ -232,6 +232,7 @@ CopyNodeTask(COPYFUNC_ARGS) COPY_SCALAR_FIELD(replicationModel); COPY_SCALAR_FIELD(insertSelectQuery); COPY_NODE_FIELD(relationShardList); + COPY_NODE_FIELD(rowValuesLists); } diff --git a/src/backend/distributed/utils/citus_outfuncs.c b/src/backend/distributed/utils/citus_outfuncs.c index b4bb5bb30..e5fb43fab 100644 --- a/src/backend/distributed/utils/citus_outfuncs.c +++ b/src/backend/distributed/utils/citus_outfuncs.c @@ -441,6 +441,7 @@ OutTask(OUTFUNC_ARGS) WRITE_CHAR_FIELD(replicationModel); WRITE_BOOL_FIELD(insertSelectQuery); WRITE_NODE_FIELD(relationShardList); + WRITE_NODE_FIELD(rowValuesLists); } diff --git a/src/backend/distributed/utils/citus_readfuncs.c b/src/backend/distributed/utils/citus_readfuncs.c index 849c54ba1..6e597cb6e 100644 --- a/src/backend/distributed/utils/citus_readfuncs.c +++ b/src/backend/distributed/utils/citus_readfuncs.c @@ -351,6 +351,7 @@ ReadTask(READFUNC_ARGS) READ_CHAR_FIELD(replicationModel); READ_BOOL_FIELD(insertSelectQuery); READ_NODE_FIELD(relationShardList); + READ_NODE_FIELD(rowValuesLists); READ_DONE(); } diff --git a/src/include/distributed/multi_executor.h b/src/include/distributed/multi_executor.h index eac04ab92..a710a9bc7 100644 --- a/src/include/distributed/multi_executor.h +++ b/src/include/distributed/multi_executor.h @@ -28,6 +28,8 @@ typedef struct CitusScanState } CitusScanState; +extern CustomExecMethods RouterMultiModifyCustomExecMethods; + extern Node * RealTimeCreateScan(CustomScan *scan); extern Node * TaskTrackerCreateScan(CustomScan *scan); extern Node * RouterCreateScan(CustomScan *scan); diff --git a/src/include/distributed/multi_physical_planner.h b/src/include/distributed/multi_physical_planner.h index 1f43851f3..bbff3a83e 
100644 --- a/src/include/distributed/multi_physical_planner.h +++ b/src/include/distributed/multi_physical_planner.h @@ -187,6 +187,8 @@ typedef struct Task bool insertSelectQuery; List *relationShardList; + + List *rowValuesLists; /* rows to use when building multi-row INSERT */ } Task; diff --git a/src/test/regress/expected/multi_insert_select.out b/src/test/regress/expected/multi_insert_select.out index 91494ee9b..16d3c22e7 100644 --- a/src/test/regress/expected/multi_insert_select.out +++ b/src/test/regress/expected/multi_insert_select.out @@ -1657,6 +1657,12 @@ BEGIN; INSERT INTO raw_events_first SELECT * FROM raw_events_second WHERE user_id = 100; COPY raw_events_first (user_id, value_1) FROM STDIN DELIMITER ','; ROLLBACK; +-- Similarly, multi-row INSERTs will take part in transactions and reuse connections... +BEGIN; +INSERT INTO raw_events_first SELECT * FROM raw_events_second WHERE user_id = 100; +COPY raw_events_first (user_id, value_1) FROM STDIN DELIMITER ','; +INSERT INTO raw_events_first (user_id, value_1) VALUES (105, 105), (106, 106); +ROLLBACK; -- selecting from views works CREATE VIEW test_view AS SELECT * FROM raw_events_first; INSERT INTO raw_events_first (user_id, time, value_1, value_2, value_3, value_4) VALUES diff --git a/src/test/regress/expected/multi_modifications.out b/src/test/regress/expected/multi_modifications.out index c53e090fa..608c23609 100644 --- a/src/test/regress/expected/multi_modifications.out +++ b/src/test/regress/expected/multi_modifications.out @@ -190,7 +190,7 @@ CONTEXT: while executing command on localhost:57638 INSERT INTO limit_orders VALUES (34153, 'LEE', 5994, '2001-04-16 03:37:28', 'buy', 0.58) RETURNING id / 0; ERROR: could not modify any active placements SET client_min_messages TO DEFAULT; --- commands with non-constant partition values are unsupported +-- commands with non-constant partition values are supported INSERT INTO limit_orders VALUES (random() * 100, 'ORCL', 152, '2011-08-25 11:50:45', 'sell', 
0.58); -- values for other columns are totally fine @@ -201,10 +201,10 @@ ERROR: functions used in the WHERE clause of modification queries on distribute -- commands with mutable but non-volatile functions(ie: stable func.) in their quals -- (the cast to timestamp is because the timestamp_eq_timestamptz operator is stable) DELETE FROM limit_orders WHERE id = 246 AND placed_at = current_timestamp::timestamp; --- commands with multiple rows are unsupported -INSERT INTO limit_orders VALUES (DEFAULT), (DEFAULT); -ERROR: cannot perform distributed planning for the given modification -DETAIL: Multi-row INSERTs to distributed tables are not supported. +-- commands with multiple rows are supported +INSERT INTO limit_orders VALUES (2037, 'GOOG', 5634, now(), 'buy', random()), + (2038, 'GOOG', 5634, now(), 'buy', random()), + (2039, 'GOOG', 5634, now(), 'buy', random()); -- Who says that? :) -- INSERT ... SELECT ... FROM commands are unsupported -- INSERT INTO limit_orders SELECT * FROM limit_orders; diff --git a/src/test/regress/expected/multi_mx_modifications.out b/src/test/regress/expected/multi_mx_modifications.out index 56e7cd59d..b9f0b09b6 100644 --- a/src/test/regress/expected/multi_mx_modifications.out +++ b/src/test/regress/expected/multi_mx_modifications.out @@ -104,10 +104,10 @@ ERROR: functions used in the WHERE clause of modification queries on distribute -- commands with mutable but non-volatile functions(ie: stable func.) in their quals -- (the cast to timestamp is because the timestamp_eq_timestamptz operator is stable) DELETE FROM limit_orders_mx WHERE id = 246 AND placed_at = current_timestamp::timestamp; --- commands with multiple rows are unsupported -INSERT INTO limit_orders_mx VALUES (DEFAULT), (DEFAULT); -ERROR: cannot perform distributed planning for the given modification -DETAIL: Multi-row INSERTs to distributed tables are not supported. 
+-- commands with multiple rows are supported +INSERT INTO limit_orders_mx VALUES (2037, 'GOOG', 5634, now(), 'buy', random()), + (2038, 'GOOG', 5634, now(), 'buy', random()), + (2039, 'GOOG', 5634, now(), 'buy', random()); -- connect back to the other node \c - - - :worker_1_port -- commands containing a CTE are unsupported diff --git a/src/test/regress/sql/multi_insert_select.sql b/src/test/regress/sql/multi_insert_select.sql index dfc70455a..d1d54d739 100644 --- a/src/test/regress/sql/multi_insert_select.sql +++ b/src/test/regress/sql/multi_insert_select.sql @@ -1375,6 +1375,15 @@ COPY raw_events_first (user_id, value_1) FROM STDIN DELIMITER ','; \. ROLLBACK; +-- Similarly, multi-row INSERTs will take part in transactions and reuse connections... +BEGIN; +INSERT INTO raw_events_first SELECT * FROM raw_events_second WHERE user_id = 100; +COPY raw_events_first (user_id, value_1) FROM STDIN DELIMITER ','; +104,104 +\. +INSERT INTO raw_events_first (user_id, value_1) VALUES (105, 105), (106, 106); +ROLLBACK; + -- selecting from views works CREATE VIEW test_view AS SELECT * FROM raw_events_first; INSERT INTO raw_events_first (user_id, time, value_1, value_2, value_3, value_4) VALUES diff --git a/src/test/regress/sql/multi_modifications.sql b/src/test/regress/sql/multi_modifications.sql index 518c4f37c..c3c1becf7 100644 --- a/src/test/regress/sql/multi_modifications.sql +++ b/src/test/regress/sql/multi_modifications.sql @@ -132,7 +132,7 @@ INSERT INTO limit_orders VALUES (34153, 'LEE', 5994, '2001-04-16 03:37:28', 'buy SET client_min_messages TO DEFAULT; --- commands with non-constant partition values are unsupported +-- commands with non-constant partition values are supported INSERT INTO limit_orders VALUES (random() * 100, 'ORCL', 152, '2011-08-25 11:50:45', 'sell', 0.58); @@ -146,8 +146,10 @@ DELETE FROM limit_orders WHERE id = 246 AND bidder_id = (random() * 1000); -- (the cast to timestamp is because the timestamp_eq_timestamptz operator is stable) DELETE FROM 
limit_orders WHERE id = 246 AND placed_at = current_timestamp::timestamp; --- commands with multiple rows are unsupported -INSERT INTO limit_orders VALUES (DEFAULT), (DEFAULT); +-- commands with multiple rows are supported +INSERT INTO limit_orders VALUES (2037, 'GOOG', 5634, now(), 'buy', random()), + (2038, 'GOOG', 5634, now(), 'buy', random()), + (2039, 'GOOG', 5634, now(), 'buy', random()); -- Who says that? :) -- INSERT ... SELECT ... FROM commands are unsupported diff --git a/src/test/regress/sql/multi_mx_modifications.sql b/src/test/regress/sql/multi_mx_modifications.sql index c10e8fe38..fd3874c02 100644 --- a/src/test/regress/sql/multi_mx_modifications.sql +++ b/src/test/regress/sql/multi_mx_modifications.sql @@ -76,8 +76,10 @@ DELETE FROM limit_orders_mx WHERE id = 246 AND bidder_id = (random() * 1000); -- (the cast to timestamp is because the timestamp_eq_timestamptz operator is stable) DELETE FROM limit_orders_mx WHERE id = 246 AND placed_at = current_timestamp::timestamp; --- commands with multiple rows are unsupported -INSERT INTO limit_orders_mx VALUES (DEFAULT), (DEFAULT); +-- commands with multiple rows are supported +INSERT INTO limit_orders_mx VALUES (2037, 'GOOG', 5634, now(), 'buy', random()), + (2038, 'GOOG', 5634, now(), 'buy', random()), + (2039, 'GOOG', 5634, now(), 'buy', random()); -- connect back to the other node \c - - - :worker_1_port