diff --git a/src/backend/distributed/planner/multi_router_planner.c b/src/backend/distributed/planner/multi_router_planner.c index 5a053f382..4cfb4c6ba 100644 --- a/src/backend/distributed/planner/multi_router_planner.c +++ b/src/backend/distributed/planner/multi_router_planner.c @@ -202,6 +202,7 @@ CreateSingleTaskRouterPlan(Query *originalQuery, Query *query, Task *task = NULL; List *placementList = NIL; MultiPlan *multiPlan = CitusMakeNode(MultiPlan); + bool updateFromQuery = UpdateFromQuery(query); multiPlan->operation = query->commandType; @@ -211,7 +212,20 @@ CreateSingleTaskRouterPlan(Query *originalQuery, Query *query, modifyTask = true; } - if (modifyTask) + if (updateFromQuery) + { + DeferredErrorMessage *planningError = NULL; + planningError = ModifyQuerySupported(query); + + if (planningError != NULL) + { + multiPlan->planningError = planningError; + return multiPlan; + } + + task = RouterSelectTask(originalQuery, restrictionContext, &placementList); + } + else if (modifyTask) { Oid distributedTableId = ExtractFirstDistributedTableId(originalQuery); ShardInterval *targetShardInterval = NULL; @@ -270,6 +284,33 @@ CreateSingleTaskRouterPlan(Query *originalQuery, Query *query, } +bool +UpdateFromQuery(Query *query) +{ + CmdType commandType = query->commandType; + List *rangeTableList = query->rtable; + bool hasSubquery = false; + ListCell *rangeTableCell = NULL; + + foreach(rangeTableCell, rangeTableList) + { + RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell); + if (rangeTableEntry->rtekind == RTE_SUBQUERY) + { + hasSubquery = true; + break; + } + } + + if (commandType == CMD_UPDATE && hasSubquery) + { + return true; + } + + return false; +} + + /* * Creates a router plan for INSERT ... SELECT queries which could consists of * multiple tasks. @@ -1233,7 +1274,6 @@ ModifyQuerySupported(Query *queryTree) /* * Reject subqueries which are in SELECT or WHERE clause. - * Queries which include subqueries in FROM clauses are rejected below. */ if (queryTree->hasSubLinks == true) { @@ -1298,6 +1338,10 @@ ModifyQuerySupported(Query *queryTree) { hasValuesScan = true; } + else if (UpdateFromQuery(queryTree)) + { + continue; + } else { /* @@ -1341,7 +1385,7 @@ ModifyQuerySupported(Query *queryTree) * Queries like "INSERT INTO table_name ON CONFLICT DO UPDATE (col) SET other_col = ''" * contains two range table entries, and we have to allow them. */ - if (commandType != CMD_INSERT && queryTableCount != 1) + if (commandType != CMD_INSERT && queryTableCount != 1 && !UpdateFromQuery(queryTree)) { return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED, "cannot perform distributed planning for the given" @@ -2262,7 +2306,6 @@ RouterSelectTask(Query *originalQuery, RelationRestrictionContext *restrictionCo placementList, &shardId, &relationShardList, replacePrunedQueryWithDummy); - if (!queryRoutable) { return NULL; @@ -2316,7 +2359,7 @@ RouterSelectQuery(Query *originalQuery, RelationRestrictionContext *restrictionC return false; } - Assert(commandType == CMD_SELECT); + Assert(commandType == CMD_SELECT || UpdateFromQuery(originalQuery)); foreach(prunedRelationShardListCell, prunedRelationShardList) { @@ -2429,7 +2472,7 @@ TargetShardIntervalsForSelect(Query *query, List *prunedRelationShardList = NIL; ListCell *restrictionCell = NULL; - Assert(query->commandType == CMD_SELECT); + Assert(query->commandType == CMD_SELECT || UpdateFromQuery(query)); Assert(restrictionContext != NULL); foreach(restrictionCell, restrictionContext->relationRestrictionList) diff --git a/src/backend/distributed/utils/ruleutils_96.c b/src/backend/distributed/utils/ruleutils_96.c index 3e4a38ddf..188377472 100644 --- a/src/backend/distributed/utils/ruleutils_96.c +++ b/src/backend/distributed/utils/ruleutils_96.c @@ -48,6 +48,7 @@ #include "common/keywords.h" #include "distributed/citus_nodefuncs.h" #include "distributed/citus_ruleutils.h" +#include "distributed/multi_router_planner.h" #include "executor/spi.h" #include "foreign/foreign.h" #include "funcapi.h" @@ -3152,17 +3153,40 @@ get_update_query_def(Query *query, deparse_context *context) * Start the query with UPDATE relname SET */ rte = rt_fetch(query->resultRelation, query->rtable); - Assert(rte->rtekind == RTE_RELATION); + + Assert(rte->rtekind == RTE_RELATION || UpdateFromQuery(query)); if (PRETTY_INDENT(context)) { appendStringInfoChar(buf, ' '); context->indentLevel += PRETTYINDENT_STD; } - appendStringInfo(buf, "UPDATE %s%s", - only_marker(rte), - generate_relation_or_shard_name(rte->relid, - context->distrelid, - context->shardid, NIL)); + + /* if it's a shard, do differently */ + if (GetRangeTblKind(rte) == CITUS_RTE_SHARD) + { + char *fragmentSchemaName = NULL; + char *fragmentTableName = NULL; + + ExtractRangeTblExtraData(rte, NULL, &fragmentSchemaName, &fragmentTableName, NULL); + + /* Use schema and table name from the remote alias */ + appendStringInfo(buf, "UPDATE %s%s", + only_marker(rte), + generate_fragment_name(fragmentSchemaName, fragmentTableName)); + + if(rte->eref != NULL) + appendStringInfo(buf, " %s", + quote_identifier(rte->eref->aliasname)); + } + else + { + appendStringInfo(buf, "UPDATE %s%s", + only_marker(rte), + generate_relation_or_shard_name(rte->relid, + context->distrelid, + context->shardid, NIL)); + } + if (rte->alias != NULL) appendStringInfo(buf, " %s", quote_identifier(rte->alias->aliasname)); diff --git a/src/include/distributed/multi_router_planner.h b/src/include/distributed/multi_router_planner.h index 0dd0f1e7b..da57d2ce6 100644 --- a/src/include/distributed/multi_router_planner.h +++ b/src/include/distributed/multi_router_planner.h @@ -36,6 +36,7 @@ extern bool RouterSelectQuery(Query *originalQuery, List **placementList, uint64 *anchorShardId, List **relationShardList, bool replacePrunedQueryWithDummy); extern DeferredErrorMessage * ModifyQuerySupported(Query *queryTree); +extern bool UpdateFromQuery(Query *query); extern Query * ReorderInsertSelectTargetLists(Query *originalQuery, RangeTblEntry *insertRte, RangeTblEntry *subqueryRte); diff --git a/src/test/regress/expected/multi_modifications.out b/src/test/regress/expected/multi_modifications.out index b8b304a89..b9a50dcf2 100644 --- a/src/test/regress/expected/multi_modifications.out +++ b/src/test/regress/expected/multi_modifications.out @@ -657,3 +657,39 @@ INSERT INTO app_analytics_events (app_id, name) VALUES (103, 'Mynt') RETURNING * (1 row) DROP TABLE app_analytics_events; +-- test UPDATE ... FROM +CREATE TABLE raw_table (id bigint, value bigint); +CREATE TABLE summary_table (id bigint, average_value numeric); +SELECT create_distributed_table('raw_table', 'id'); + create_distributed_table +-------------------------- + +(1 row) + +SELECT create_distributed_table('summary_table', 'id'); + create_distributed_table +-------------------------- + +(1 row) + +INSERT INTO raw_table VALUES (1, 100); +INSERT INTO raw_table VALUES (1, 200); +INSERT INTO summary_table VALUES (1, NULL); +SELECT * FROM summary_table WHERE id = 1; + id | average_value +----+--------------- + 1 | +(1 row) + +UPDATE summary_table SET average_value = average_query.average FROM ( + SELECT avg(value) AS average FROM raw_table WHERE id = 1 + ) average_query +WHERE id = 1; +SELECT * FROM summary_table WHERE id = 1; + id | average_value +----+---------------------- + 1 | 150.0000000000000000 +(1 row) + +DROP TABLE raw_table; +DROP TABLE summary_table; diff --git a/src/test/regress/sql/multi_modifications.sql b/src/test/regress/sql/multi_modifications.sql index 583301fe8..cbec5bd7a 100644 --- a/src/test/regress/sql/multi_modifications.sql +++ b/src/test/regress/sql/multi_modifications.sql @@ -429,3 +429,27 @@ INSERT INTO app_analytics_events (app_id, name) VALUES (102, 'Wayz') RETURNING i INSERT INTO app_analytics_events (app_id, name) VALUES (103, 'Mynt') RETURNING *; DROP TABLE app_analytics_events; + +-- test UPDATE ... FROM +CREATE TABLE raw_table (id bigint, value bigint); +CREATE TABLE summary_table (id bigint, average_value numeric); + +SELECT create_distributed_table('raw_table', 'id'); +SELECT create_distributed_table('summary_table', 'id'); + +INSERT INTO raw_table VALUES (1, 100); +INSERT INTO raw_table VALUES (1, 200); + +INSERT INTO summary_table VALUES (1, NULL); + +SELECT * FROM summary_table WHERE id = 1; + +UPDATE summary_table SET average_value = average_query.average FROM ( + SELECT avg(value) AS average FROM raw_table WHERE id = 1 + ) average_query +WHERE id = 1; + +SELECT * FROM summary_table WHERE id = 1; + +DROP TABLE raw_table; +DROP TABLE summary_table;