From 06461ca55f31491bc7dd49cd5b59215b08aadcca Mon Sep 17 00:00:00 2001 From: Onder Kalaci Date: Tue, 9 Jun 2020 11:20:38 +0200 Subject: [PATCH] Coerce types properly for INSERT Also, unify similar code-paths to rely on more accurate function. --- src/backend/distributed/commands/call.c | 13 +- .../planner/function_call_delegation.c | 10 +- .../planner/multi_router_planner.c | 27 ++- .../distributed/planner/shard_pruning.c | 44 ++++- src/include/distributed/shard_pruning.h | 4 +- .../expected/multi_prune_shard_list.out | 183 +++++++++++++++++- .../regress/sql/multi_prune_shard_list.sql | 41 ++++ 7 files changed, 301 insertions(+), 21 deletions(-) diff --git a/src/backend/distributed/commands/call.c b/src/backend/distributed/commands/call.c index 4ec2ad4a2..477d13caf 100644 --- a/src/backend/distributed/commands/call.c +++ b/src/backend/distributed/commands/call.c @@ -116,19 +116,18 @@ CallFuncExprRemotely(CallStmt *callStmt, DistObjectCacheEntry *procedure, ereport(DEBUG1, (errmsg("distribution argument value must be a constant"))); return false; } - Const *partitionValue = (Const *) partitionValueNode; - Datum partitionValueDatum = partitionValue->constvalue; + Const *partitionValue = (Const *) partitionValueNode; if (partitionValue->consttype != partitionColumn->vartype) { - CopyCoercionData coercionData; + bool missingOk = false; - ConversionPathForTypes(partitionValue->consttype, partitionColumn->vartype, - &coercionData); - - partitionValueDatum = CoerceColumnValue(partitionValueDatum, &coercionData); + partitionValue = + TransformPartitionRestrictionValue(partitionColumn, partitionValue, + missingOk); } + Datum partitionValueDatum = partitionValue->constvalue; ShardInterval *shardInterval = FindShardInterval(partitionValueDatum, distTable); if (shardInterval == NULL) { diff --git a/src/backend/distributed/planner/function_call_delegation.c b/src/backend/distributed/planner/function_call_delegation.c index 0fdb06043..dbd192109 100644 --- a/src/backend/distributed/planner/function_call_delegation.c +++ b/src/backend/distributed/planner/function_call_delegation.c @@ -314,12 +314,10 @@ TryToDelegateFunctionCall(DistributedPlanningContext *planContext) if (partitionValue->consttype != partitionColumn->vartype) { - CopyCoercionData coercionData; - - ConversionPathForTypes(partitionValue->consttype, partitionColumn->vartype, - &coercionData); - - partitionValueDatum = CoerceColumnValue(partitionValueDatum, &coercionData); + bool missingOk = false; + partitionValue = + TransformPartitionRestrictionValue(partitionColumn, partitionValue, + missingOk); } shardInterval = FindShardInterval(partitionValueDatum, distTable); diff --git a/src/backend/distributed/planner/multi_router_planner.c b/src/backend/distributed/planner/multi_router_planner.c index 39f831406..55c2747ef 100644 --- a/src/backend/distributed/planner/multi_router_planner.c +++ b/src/backend/distributed/planner/multi_router_planner.c @@ -2282,6 +2282,20 @@ TargetShardIntervalForFastPathQuery(Query *query, bool *isMultiShardQuery, if (inputDistributionKeyValue && !inputDistributionKeyValue->constisnull) { CitusTableCacheEntry *cache = GetCitusTableCacheEntry(relationId); + Var *distributionKey = cache->partitionColumn; + + /* + * We currently don't allow implicitly coerced values to be handled by fast- + * path planner. Still, let's be defensive for any future changes.. + */ + if (inputDistributionKeyValue->consttype != distributionKey->vartype) + { + bool missingOk = false; + inputDistributionKeyValue = + TransformPartitionRestrictionValue(distributionKey, + inputDistributionKeyValue, missingOk); + } + ShardInterval *cachedShardInterval = FindShardInterval(inputDistributionKeyValue->constvalue, cache); if (cachedShardInterval == NULL) @@ -2603,9 +2617,20 @@ BuildRoutesForInsert(Query *query, DeferredErrorMessage **planningError) if (partitionMethod == DISTRIBUTE_BY_HASH || partitionMethod == DISTRIBUTE_BY_RANGE) { + Var *distributionKey = cacheEntry->partitionColumn; + + /* handle coercions, if fails throw an error */ + if (partitionValueConst->consttype != distributionKey->vartype) + { + bool missingOk = false; + partitionValueConst = + TransformPartitionRestrictionValue(distributionKey, + partitionValueConst, + missingOk); + } + Datum partitionValue = partitionValueConst->constvalue; - cacheEntry = GetCitusTableCacheEntry(distributedTableId); ShardInterval *shardInterval = FindShardInterval(partitionValue, cacheEntry); if (shardInterval != NULL) { diff --git a/src/backend/distributed/planner/shard_pruning.c b/src/backend/distributed/planner/shard_pruning.c index 69191f3d3..ece01419f 100644 --- a/src/backend/distributed/planner/shard_pruning.c +++ b/src/backend/distributed/planner/shard_pruning.c @@ -93,6 +93,7 @@ #include "parser/parse_coerce.h" #include "utils/arrayaccess.h" #include "utils/catcache.h" +#include "utils/fmgrprotos.h" #include "utils/lsyscache.h" #include "utils/memutils.h" #include "utils/ruleutils.h" @@ -255,14 +256,14 @@ static void AddPartitionKeyRestrictionToInstance(ClauseWalkerContext *context, Const *constantClause); static bool VarConstOpExprClause(OpExpr *opClause, Var *partitionColumn, Var **varClause, Const **constantClause); -static Const * TransformPartitionRestrictionValue(Var *partitionColumn, - Const *restrictionValue); static void AddSAOPartitionKeyRestrictionToInstance(ClauseWalkerContext *context, ScalarArrayOpExpr * arrayOperatorExpression); static bool SAORestrictions(ScalarArrayOpExpr *arrayOperatorExpression, Var *partitionColumn, List **requestedRestrictions); +static void ErrorTypesDontMatch(Oid firstType, Oid firstCollId, Oid secondType, + Oid secondCollId); static bool IsValidHashRestriction(OpExpr *opClause); static void AddHashRestrictionToInstance(ClauseWalkerContext *context, OpExpr *opClause, Var *varClause, Const *constantClause); @@ -1111,7 +1112,7 @@ AddPartitionKeyRestrictionToInstance(ClauseWalkerContext *context, OpExpr *opCla { /* we want our restriction value in terms of the type of the partition column */ constantClause = TransformPartitionRestrictionValue(partitionColumn, - constantClause); + constantClause, true); if (constantClause == NULL) { /* couldn't coerce value, its invalid restriction */ @@ -1223,8 +1224,9 @@ AddPartitionKeyRestrictionToInstance(ClauseWalkerContext *context, OpExpr *opCla * It is conceivable that in some instances this may not be possible, * in those cases we will simply fail to prune partitions based on this clause. */ -static Const * -TransformPartitionRestrictionValue(Var *partitionColumn, Const *restrictionValue) +Const * +TransformPartitionRestrictionValue(Var *partitionColumn, Const *restrictionValue, + bool missingOk) { Node *transformedValue = coerce_to_target_type(NULL, (Node *) restrictionValue, restrictionValue->consttype, @@ -1236,6 +1238,13 @@ TransformPartitionRestrictionValue(Var *partitionColumn, Const *restrictionValue /* if NULL, no implicit coercion is possible between the types */ if (transformedValue == NULL) { + if (!missingOk) + { + ErrorTypesDontMatch(partitionColumn->vartype, partitionColumn->varcollid, + restrictionValue->consttype, + restrictionValue->constcollid); + } + return NULL; } @@ -1248,6 +1257,13 @@ TransformPartitionRestrictionValue(Var *partitionColumn, Const *restrictionValue /* if still not a constant, no immutable coercion matched */ if (!IsA(transformedValue, Const)) { + if (!missingOk) + { + ErrorTypesDontMatch(partitionColumn->vartype, partitionColumn->varcollid, + restrictionValue->consttype, + restrictionValue->constcollid); + } + return NULL; } @@ -1255,6 +1271,24 @@ TransformPartitionRestrictionValue(Var *partitionColumn, Const *restrictionValue } +/* + * ErrorTypesDontMatch throws an error explicitly printing the type names. + */ +static void +ErrorTypesDontMatch(Oid firstType, Oid firstCollId, Oid secondType, Oid secondCollId) +{ + Datum firstTypename = + DirectFunctionCall1Coll(regtypeout, firstCollId, ObjectIdGetDatum(firstType)); + + Datum secondTypename = + DirectFunctionCall1Coll(regtypeout, secondCollId, ObjectIdGetDatum(secondType)); + + ereport(ERROR, (errmsg("Cannot coerce %s to %s", + DatumGetCString(secondTypename), + DatumGetCString(firstTypename)))); +} + + /* * IsValidHashRestriction checks whether an operator clause is a valid restriction for hashed column. */ diff --git a/src/include/distributed/shard_pruning.h b/src/include/distributed/shard_pruning.h index bcbe60655..a780a7336 100644 --- a/src/include/distributed/shard_pruning.h +++ b/src/include/distributed/shard_pruning.h @@ -20,5 +20,7 @@ extern List * PruneShards(Oid relationId, Index rangeTableId, List *whereClauseList, Const **partitionValueConst); extern bool ContainsFalseClause(List *whereClauseList); - +extern Const * TransformPartitionRestrictionValue(Var *partitionColumn, + Const *restrictionValue, + bool missingOk); #endif /* SHARD_PRUNING_H_ */ diff --git a/src/test/regress/expected/multi_prune_shard_list.out b/src/test/regress/expected/multi_prune_shard_list.out index c0fdb8324..3aa44e766 100644 --- a/src/test/regress/expected/multi_prune_shard_list.out +++ b/src/test/regress/expected/multi_prune_shard_list.out @@ -363,9 +363,189 @@ EXECUTE coerce_numeric_2(1); 1 | test value (1 row) +-- Test that we can insert an integer literal into a numeric column as well +CREATE TABLE numeric_test (id numeric(6, 1), val int); +SELECT create_distributed_table('numeric_test', 'id'); + create_distributed_table +--------------------------------------------------------------------- + +(1 row) + +INSERT INTO numeric_test VALUES (21, 87) RETURNING *; + id | val +--------------------------------------------------------------------- + 21.0 | 87 +(1 row) + +SELECT * FROM numeric_test WHERE id = 21; + id | val +--------------------------------------------------------------------- + 21.0 | 87 +(1 row) + +SELECT * FROM numeric_test WHERE id = 21::int; + id | val +--------------------------------------------------------------------- + 21.0 | 87 +(1 row) + +SELECT * FROM numeric_test WHERE id = 21::bigint; + id | val +--------------------------------------------------------------------- + 21.0 | 87 +(1 row) + +SELECT * FROM numeric_test WHERE id = 21.0; + id | val +--------------------------------------------------------------------- + 21.0 | 87 +(1 row) + +SELECT * FROM numeric_test WHERE id = 21.0::numeric; + id | val +--------------------------------------------------------------------- + 21.0 | 87 +(1 row) + +PREPARE insert_p(int) AS INSERT INTO numeric_test VALUES ($1, 87) RETURNING *; +EXECUTE insert_p(1); + id | val +--------------------------------------------------------------------- + 1.0 | 87 +(1 row) + +EXECUTE insert_p(2); + id | val +--------------------------------------------------------------------- + 2.0 | 87 +(1 row) + +EXECUTE insert_p(3); + id | val +--------------------------------------------------------------------- + 3.0 | 87 +(1 row) + +EXECUTE insert_p(4); + id | val +--------------------------------------------------------------------- + 4.0 | 87 +(1 row) + +EXECUTE insert_p(5); + id | val +--------------------------------------------------------------------- + 5.0 | 87 +(1 row) + +EXECUTE insert_p(6); + id | val +--------------------------------------------------------------------- + 6.0 | 87 +(1 row) + +PREPARE select_p(int) AS SELECT * FROM numeric_test WHERE id=$1; +EXECUTE select_p(1); + id | val +--------------------------------------------------------------------- + 1.0 | 87 +(1 row) + +EXECUTE select_p(2); + id | val +--------------------------------------------------------------------- + 2.0 | 87 +(1 row) + +EXECUTE select_p(3); + id | val +--------------------------------------------------------------------- + 3.0 | 87 +(1 row) + +EXECUTE select_p(4); + id | val +--------------------------------------------------------------------- + 4.0 | 87 +(1 row) + +EXECUTE select_p(5); + id | val +--------------------------------------------------------------------- + 5.0 | 87 +(1 row) + +EXECUTE select_p(6); + id | val +--------------------------------------------------------------------- + 6.0 | 87 +(1 row) + +SET citus.enable_fast_path_router_planner TO false; +EXECUTE select_p(1); + id | val +--------------------------------------------------------------------- + 1.0 | 87 +(1 row) + +EXECUTE select_p(2); + id | val +--------------------------------------------------------------------- + 2.0 | 87 +(1 row) + +EXECUTE select_p(3); + id | val +--------------------------------------------------------------------- + 3.0 | 87 +(1 row) + +EXECUTE select_p(4); + id | val +--------------------------------------------------------------------- + 4.0 | 87 +(1 row) + +EXECUTE select_p(5); + id | val +--------------------------------------------------------------------- + 5.0 | 87 +(1 row) + +EXECUTE select_p(6); + id | val +--------------------------------------------------------------------- + 6.0 | 87 +(1 row) + +-- make sure that we don't return wrong resuls +INSERT INTO numeric_test VALUES (21.1, 87) RETURNING *; + id | val +--------------------------------------------------------------------- + 21.1 | 87 +(1 row) + +SELECT * FROM numeric_test WHERE id = 21; + id | val +--------------------------------------------------------------------- + 21.0 | 87 +(1 row) + +SELECT * FROM numeric_test WHERE id = 21::numeric; + id | val +--------------------------------------------------------------------- + 21.0 | 87 +(1 row) + +SELECT * FROM numeric_test WHERE id = 21.1::numeric; + id | val +--------------------------------------------------------------------- + 21.1 | 87 +(1 row) + SET search_path TO public; DROP SCHEMA prune_shard_list CASCADE; -NOTICE: drop cascades to 9 other objects +NOTICE: drop cascades to 10 other objects DETAIL: drop cascades to function prune_shard_list.prune_using_no_values(regclass) drop cascades to function prune_shard_list.prune_using_single_value(regclass,text) drop cascades to function prune_shard_list.prune_using_either_value(regclass,text,text) @@ -375,3 +555,4 @@ drop cascades to function prune_shard_list.print_sorted_shard_intervals(regclass drop cascades to table prune_shard_list.pruning drop cascades to table prune_shard_list.pruning_range drop cascades to table prune_shard_list.coerce_hash +drop cascades to table prune_shard_list.numeric_test diff --git a/src/test/regress/sql/multi_prune_shard_list.sql b/src/test/regress/sql/multi_prune_shard_list.sql index 192768bfe..224a759cf 100644 --- a/src/test/regress/sql/multi_prune_shard_list.sql +++ b/src/test/regress/sql/multi_prune_shard_list.sql @@ -177,5 +177,46 @@ EXECUTE coerce_numeric_2(1); EXECUTE coerce_numeric_2(1); +-- Test that we can insert an integer literal into a numeric column as well +CREATE TABLE numeric_test (id numeric(6, 1), val int); +SELECT create_distributed_table('numeric_test', 'id'); + +INSERT INTO numeric_test VALUES (21, 87) RETURNING *; +SELECT * FROM numeric_test WHERE id = 21; +SELECT * FROM numeric_test WHERE id = 21::int; +SELECT * FROM numeric_test WHERE id = 21::bigint; +SELECT * FROM numeric_test WHERE id = 21.0; +SELECT * FROM numeric_test WHERE id = 21.0::numeric; + +PREPARE insert_p(int) AS INSERT INTO numeric_test VALUES ($1, 87) RETURNING *; +EXECUTE insert_p(1); +EXECUTE insert_p(2); +EXECUTE insert_p(3); +EXECUTE insert_p(4); +EXECUTE insert_p(5); +EXECUTE insert_p(6); + +PREPARE select_p(int) AS SELECT * FROM numeric_test WHERE id=$1; +EXECUTE select_p(1); +EXECUTE select_p(2); +EXECUTE select_p(3); +EXECUTE select_p(4); +EXECUTE select_p(5); +EXECUTE select_p(6); + +SET citus.enable_fast_path_router_planner TO false; +EXECUTE select_p(1); +EXECUTE select_p(2); +EXECUTE select_p(3); +EXECUTE select_p(4); +EXECUTE select_p(5); +EXECUTE select_p(6); + +-- make sure that we don't return wrong resuls +INSERT INTO numeric_test VALUES (21.1, 87) RETURNING *; +SELECT * FROM numeric_test WHERE id = 21; +SELECT * FROM numeric_test WHERE id = 21::numeric; +SELECT * FROM numeric_test WHERE id = 21.1::numeric; + SET search_path TO public; DROP SCHEMA prune_shard_list CASCADE;