diff --git a/src/backend/distributed/planner/multi_router_planner.c b/src/backend/distributed/planner/multi_router_planner.c index 5a053f382..a8cab8b5d 100644 --- a/src/backend/distributed/planner/multi_router_planner.c +++ b/src/backend/distributed/planner/multi_router_planner.c @@ -202,6 +202,7 @@ CreateSingleTaskRouterPlan(Query *originalQuery, Query *query, Task *task = NULL; List *placementList = NIL; MultiPlan *multiPlan = CitusMakeNode(MultiPlan); + bool updateFromQuery = UpdateFromQuery(query); multiPlan->operation = query->commandType; @@ -211,7 +212,20 @@ CreateSingleTaskRouterPlan(Query *originalQuery, Query *query, modifyTask = true; } - if (modifyTask) + if (updateFromQuery) + { + DeferredErrorMessage *planningError = NULL; + planningError = ModifyQuerySupported(query); + + if (planningError != NULL) + { + multiPlan->planningError = planningError; + return multiPlan; + } + + task = RouterSelectTask(originalQuery, restrictionContext, &placementList); + } + else if (modifyTask) { Oid distributedTableId = ExtractFirstDistributedTableId(originalQuery); ShardInterval *targetShardInterval = NULL; @@ -270,6 +284,38 @@ CreateSingleTaskRouterPlan(Query *originalQuery, Query *query, } +bool +UpdateFromQuery(Query *query) +{ + CmdType commandType = query->commandType; + List *rangeTableList = query->rtable; + bool hasSubquery = false; + ListCell *rangeTableCell = NULL; + + if (query->hasSubLinks) + { + hasSubquery = true; + } + + foreach(rangeTableCell, rangeTableList) + { + RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell); + if (rangeTableEntry->rtekind == RTE_SUBQUERY) + { + hasSubquery = true; + break; + } + } + + if (commandType == CMD_UPDATE && hasSubquery) + { + return true; + } + + return false; +} + + /* * Creates a router plan for INSERT ... SELECT queries which could consists of * multiple tasks. @@ -1233,9 +1279,8 @@ ModifyQuerySupported(Query *queryTree) /* * Reject subqueries which are in SELECT or WHERE clause. 
- * Queries which include subqueries in FROM clauses are rejected below. */ - if (queryTree->hasSubLinks == true) + if (queryTree->hasSubLinks == true && !UpdateFromQuery(queryTree)) { return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED, "subqueries are not supported in distributed modifications", @@ -1298,6 +1343,10 @@ ModifyQuerySupported(Query *queryTree) { hasValuesScan = true; } + else if (UpdateFromQuery(queryTree)) + { + continue; + } else { /* @@ -1341,7 +1390,7 @@ ModifyQuerySupported(Query *queryTree) * Queries like "INSERT INTO table_name ON CONFLICT DO UPDATE (col) SET other_col = ''" * contains two range table entries, and we have to allow them. */ - if (commandType != CMD_INSERT && queryTableCount != 1) + if (commandType != CMD_INSERT && queryTableCount != 1 && !UpdateFromQuery(queryTree)) { return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED, "cannot perform distributed planning for the given" @@ -2262,7 +2311,6 @@ RouterSelectTask(Query *originalQuery, RelationRestrictionContext *restrictionCo placementList, &shardId, &relationShardList, replacePrunedQueryWithDummy); - if (!queryRoutable) { return NULL; @@ -2316,7 +2364,7 @@ RouterSelectQuery(Query *originalQuery, RelationRestrictionContext *restrictionC return false; } - Assert(commandType == CMD_SELECT); + Assert(commandType == CMD_SELECT || UpdateFromQuery(originalQuery)); foreach(prunedRelationShardListCell, prunedRelationShardList) { @@ -2429,7 +2477,7 @@ TargetShardIntervalsForSelect(Query *query, List *prunedRelationShardList = NIL; ListCell *restrictionCell = NULL; - Assert(query->commandType == CMD_SELECT); + Assert(query->commandType == CMD_SELECT || UpdateFromQuery(query)); Assert(restrictionContext != NULL); foreach(restrictionCell, restrictionContext->relationRestrictionList) diff --git a/src/backend/distributed/utils/ruleutils_95.c b/src/backend/distributed/utils/ruleutils_95.c index f8a60fe7d..8c4771886 100644 --- a/src/backend/distributed/utils/ruleutils_95.c +++ 
b/src/backend/distributed/utils/ruleutils_95.c @@ -35,6 +35,7 @@ #include "catalog/pg_type.h" #include "distributed/citus_nodefuncs.h" #include "distributed/citus_ruleutils.h" +#include "distributed/multi_router_planner.h" #include "commands/defrem.h" #include "commands/extension.h" #include "foreign/foreign.h" @@ -3144,17 +3145,39 @@ get_update_query_def(Query *query, deparse_context *context) * Start the query with UPDATE relname SET */ rte = rt_fetch(query->resultRelation, query->rtable); - Assert(rte->rtekind == RTE_RELATION); + Assert(rte->rtekind == RTE_RELATION || UpdateFromQuery(query)); if (PRETTY_INDENT(context)) { appendStringInfoChar(buf, ' '); context->indentLevel += PRETTYINDENT_STD; } - appendStringInfo(buf, "UPDATE %s%s", - only_marker(rte), - generate_relation_or_shard_name(rte->relid, - context->distrelid, - context->shardid, NIL)); + + /* shard RTE: name it via ExtractRangeTblExtraData(); emit the eref name only when there is no explicit alias, because rte->alias is appended after this if/else */ + if (GetRangeTblKind(rte) == CITUS_RTE_SHARD) + { + char *fragmentSchemaName = NULL; + char *fragmentTableName = NULL; + + ExtractRangeTblExtraData(rte, NULL, &fragmentSchemaName, &fragmentTableName, NULL); + + /* Use schema and table name from the remote alias */ + appendStringInfo(buf, "UPDATE %s%s", + only_marker(rte), + generate_fragment_name(fragmentSchemaName, fragmentTableName)); + + if (rte->alias == NULL && rte->eref != NULL) + appendStringInfo(buf, " %s", + quote_identifier(rte->eref->aliasname)); + } + else + { + appendStringInfo(buf, "UPDATE %s%s", + only_marker(rte), + generate_relation_or_shard_name(rte->relid, + context->distrelid, + context->shardid, NIL)); + } + if (rte->alias != NULL) appendStringInfo(buf, " %s", quote_identifier(rte->alias->aliasname)); diff --git a/src/backend/distributed/utils/ruleutils_96.c b/src/backend/distributed/utils/ruleutils_96.c index 3e4a38ddf..188377472 100644 --- a/src/backend/distributed/utils/ruleutils_96.c +++ b/src/backend/distributed/utils/ruleutils_96.c @@ -48,6 +48,7 @@ #include "common/keywords.h" #include "distributed/citus_nodefuncs.h" 
#include "distributed/citus_ruleutils.h" +#include "distributed/multi_router_planner.h" #include "executor/spi.h" #include "foreign/foreign.h" #include "funcapi.h" @@ -3152,17 +3153,40 @@ get_update_query_def(Query *query, deparse_context *context) * Start the query with UPDATE relname SET */ rte = rt_fetch(query->resultRelation, query->rtable); - Assert(rte->rtekind == RTE_RELATION); + + Assert(rte->rtekind == RTE_RELATION || UpdateFromQuery(query)); if (PRETTY_INDENT(context)) { appendStringInfoChar(buf, ' '); context->indentLevel += PRETTYINDENT_STD; } - appendStringInfo(buf, "UPDATE %s%s", - only_marker(rte), - generate_relation_or_shard_name(rte->relid, - context->distrelid, - context->shardid, NIL)); + + /* shard RTE: name it via ExtractRangeTblExtraData(); emit the eref name only when there is no explicit alias, because rte->alias is appended after this if/else */ + if (GetRangeTblKind(rte) == CITUS_RTE_SHARD) + { + char *fragmentSchemaName = NULL; + char *fragmentTableName = NULL; + + ExtractRangeTblExtraData(rte, NULL, &fragmentSchemaName, &fragmentTableName, NULL); + + /* Use schema and table name from the remote alias */ + appendStringInfo(buf, "UPDATE %s%s", + only_marker(rte), + generate_fragment_name(fragmentSchemaName, fragmentTableName)); + + if (rte->alias == NULL && rte->eref != NULL) + appendStringInfo(buf, " %s", + quote_identifier(rte->eref->aliasname)); + } + else + { + appendStringInfo(buf, "UPDATE %s%s", + only_marker(rte), + generate_relation_or_shard_name(rte->relid, + context->distrelid, + context->shardid, NIL)); + } + if (rte->alias != NULL) appendStringInfo(buf, " %s", quote_identifier(rte->alias->aliasname)); diff --git a/src/include/distributed/multi_router_planner.h b/src/include/distributed/multi_router_planner.h index 0dd0f1e7b..da57d2ce6 100644 --- a/src/include/distributed/multi_router_planner.h +++ b/src/include/distributed/multi_router_planner.h @@ -36,6 +36,7 @@ extern bool RouterSelectQuery(Query *originalQuery, List **placementList, uint64 *anchorShardId, List **relationShardList, bool replacePrunedQueryWithDummy); extern DeferredErrorMessage * 
ModifyQuerySupported(Query *queryTree); +extern bool UpdateFromQuery(Query *query); extern Query * ReorderInsertSelectTargetLists(Query *originalQuery, RangeTblEntry *insertRte, RangeTblEntry *subqueryRte); diff --git a/src/test/regress/expected/multi_modifications.out b/src/test/regress/expected/multi_modifications.out index b8b304a89..21b40b9ab 100644 --- a/src/test/regress/expected/multi_modifications.out +++ b/src/test/regress/expected/multi_modifications.out @@ -657,3 +657,47 @@ INSERT INTO app_analytics_events (app_id, name) VALUES (103, 'Mynt') RETURNING * (1 row) DROP TABLE app_analytics_events; +-- test UPDATE ... FROM +CREATE TABLE raw_table (id bigint, value bigint); +CREATE TABLE summary_table (id bigint, min_value numeric, average_value numeric); +SELECT create_distributed_table('raw_table', 'id'); + create_distributed_table +-------------------------- + +(1 row) + +SELECT create_distributed_table('summary_table', 'id'); + create_distributed_table +-------------------------- + +(1 row) + +INSERT INTO raw_table VALUES (1, 100); +INSERT INTO raw_table VALUES (1, 200); +INSERT INTO summary_table VALUES (1, NULL); +SELECT * FROM summary_table WHERE id = 1; + id | min_value | average_value +----+-----------+--------------- + 1 | | +(1 row) + +UPDATE summary_table SET average_value = average_query.average FROM ( + SELECT avg(value) AS average FROM raw_table WHERE id = 1 + ) average_query +WHERE id = 1; +SELECT * FROM summary_table WHERE id = 1; + id | min_value | average_value +----+-----------+---------------------- + 1 | | 150.0000000000000000 +(1 row) + +UPDATE summary_table SET min_value = 100 + WHERE id IN (SELECT id FROM raw_table WHERE id = 1 and value > 100) AND id = 1; +SELECT * FROM summary_table WHERE id = 1; + id | min_value | average_value +----+-----------+---------------------- + 1 | 100 | 150.0000000000000000 +(1 row) + +DROP TABLE raw_table; +DROP TABLE summary_table; diff --git a/src/test/regress/expected/multi_reference_table.out 
b/src/test/regress/expected/multi_reference_table.out index dc1b94891..4188d9916 100644 --- a/src/test/regress/expected/multi_reference_table.out +++ b/src/test/regress/expected/multi_reference_table.out @@ -39,7 +39,9 @@ SELECT FROM pg_dist_shard_placement WHERE - shardid IN (SELECT shardid FROM pg_dist_shard WHERE logicalrelid = 'reference_table_test'::regclass); + shardid IN (SELECT shardid FROM pg_dist_shard WHERE logicalrelid = 'reference_table_test'::regclass) +ORDER BY + nodeport; shardid | shardstate | nodename | nodeport ---------+------------+-----------+---------- 1250000 | 1 | localhost | 57637 @@ -822,7 +824,9 @@ SELECT FROM pg_dist_shard_placement WHERE - shardid IN (SELECT shardid FROM pg_dist_shard WHERE logicalrelid = 'reference_table_test_fourth'::regclass); + shardid IN (SELECT shardid FROM pg_dist_shard WHERE logicalrelid = 'reference_table_test_fourth'::regclass) +ORDER BY + nodeport; shardid | shardstate | nodename | nodeport ---------+------------+-----------+---------- 1250003 | 1 | localhost | 57637 diff --git a/src/test/regress/sql/multi_modifications.sql b/src/test/regress/sql/multi_modifications.sql index 583301fe8..14f3729eb 100644 --- a/src/test/regress/sql/multi_modifications.sql +++ b/src/test/regress/sql/multi_modifications.sql @@ -429,3 +429,32 @@ INSERT INTO app_analytics_events (app_id, name) VALUES (102, 'Wayz') RETURNING i INSERT INTO app_analytics_events (app_id, name) VALUES (103, 'Mynt') RETURNING *; DROP TABLE app_analytics_events; + +-- test UPDATE ... 
FROM +CREATE TABLE raw_table (id bigint, value bigint); +CREATE TABLE summary_table (id bigint, min_value numeric, average_value numeric); + +SELECT create_distributed_table('raw_table', 'id'); +SELECT create_distributed_table('summary_table', 'id'); + +INSERT INTO raw_table VALUES (1, 100); +INSERT INTO raw_table VALUES (1, 200); + +INSERT INTO summary_table VALUES (1, NULL); + +SELECT * FROM summary_table WHERE id = 1; + +UPDATE summary_table SET average_value = average_query.average FROM ( + SELECT avg(value) AS average FROM raw_table WHERE id = 1 + ) average_query +WHERE id = 1; + +SELECT * FROM summary_table WHERE id = 1; + +UPDATE summary_table SET min_value = 100 + WHERE id IN (SELECT id FROM raw_table WHERE id = 1 and value > 100) AND id = 1; + +SELECT * FROM summary_table WHERE id = 1; + +DROP TABLE raw_table; +DROP TABLE summary_table; diff --git a/src/test/regress/sql/multi_reference_table.sql b/src/test/regress/sql/multi_reference_table.sql index 46dcaa3bf..103c60bef 100644 --- a/src/test/regress/sql/multi_reference_table.sql +++ b/src/test/regress/sql/multi_reference_table.sql @@ -23,12 +23,15 @@ FROM pg_dist_shard WHERE logicalrelid = 'reference_table_test'::regclass; + SELECT shardid, shardstate, nodename, nodeport FROM pg_dist_shard_placement WHERE - shardid IN (SELECT shardid FROM pg_dist_shard WHERE logicalrelid = 'reference_table_test'::regclass); + shardid IN (SELECT shardid FROM pg_dist_shard WHERE logicalrelid = 'reference_table_test'::regclass) +ORDER BY + nodeport; -- check whether data was copied into distributed table SELECT * FROM reference_table_test; @@ -501,7 +504,9 @@ SELECT FROM pg_dist_shard_placement WHERE - shardid IN (SELECT shardid FROM pg_dist_shard WHERE logicalrelid = 'reference_table_test_fourth'::regclass); + shardid IN (SELECT shardid FROM pg_dist_shard WHERE logicalrelid = 'reference_table_test_fourth'::regclass) +ORDER BY + nodeport; -- let's not run some update/delete queries on arbitrary columns DELETE FROM