From 364b33a22dcc3840d9bec25061d820b0be77b140 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Fri, 18 Oct 2019 18:45:22 +0000 Subject: [PATCH] mark_aggregate_for_distributed_execution: Add some parameter validation --- src/backend/distributed/commands/function.c | 49 +++++++++++++++++++-- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/src/backend/distributed/commands/function.c b/src/backend/distributed/commands/function.c index 48aab1541..8dc8302e3 100644 --- a/src/backend/distributed/commands/function.c +++ b/src/backend/distributed/commands/function.c @@ -240,6 +240,17 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS) Oid *argtypes = NULL; int aggregationStrategy = -1; + /* if called on NULL input, error out */ + if (funcOid == InvalidOid) + { + ereport(ERROR, (errmsg( + "the first parameter for mark_aggregate_for_distributed_execution() " + "should be a single a valid function or procedure name " + "followed by a list of parameters in parantheses"))); + } + + EnsureFunctionOwner(funcOid); + if (!PG_ARGISNULL(1)) { char *strategyParam = TextDatumGetCString(PG_GETARG_TEXT_P(1)); @@ -268,22 +279,27 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS) proctup = SearchSysCache1(PROCOID, funcOid); if (!HeapTupleIsValid(proctup)) { - goto early_exit; + elog(ERROR, "citus cache lookup failed for procedure %u", funcOid); } proc = (Form_pg_proc) GETSTRUCT(proctup); if (proc->prokind != PROKIND_AGGREGATE) { - goto early_exit; + elog(ERROR, "Procedure is not an aggregate"); } aggtup = SearchSysCache1(AGGFNOID, funcOid); if (!HeapTupleIsValid(aggtup)) { - goto early_exit; + elog(ERROR, "citus cache lookup failed for aggregate %u", funcOid); } agg = (Form_pg_aggregate) GETSTRUCT(aggtup); + if (agg->aggcombinefn == InvalidOid) + { + elog(ERROR, "Aggregate lacks a combine function."); + } + initStringInfo(&helperSuffix); appendStringInfoAggregateHelperSuffix(&helperSuffix, proc, agg); @@ -365,11 +381,36 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS) numargs, argtypes, false); } } + else if (aggregationStrategy == AGGREGATION_STRATEGY_COMMUTE) + { + proctup = SearchSysCache1(PROCOID, funcOid); + if (!HeapTupleIsValid(proctup)) + { + elog(ERROR, "citus cache lookup failed for procedure %u", funcOid); + } + proc = (Form_pg_proc) GETSTRUCT(proctup); + + if (proc->prokind != PROKIND_AGGREGATE) + { + elog(ERROR, "Procedure is not an aggregate"); + } + + if (proc->pronargs != 1) + { + elog(ERROR, + "Commute aggregation strategy only works for single argument aggregates."); + } + + if (proc->proargtypes.values[0] != proc->prorettype) + { + elog(ERROR, + "Commute aggregation strategy requires input type to match result type."); + } + } /* set strategy column value */ UpdateDistObjectAggregationStrategy(funcOid, aggregationStrategy); -early_exit: if (aggtup && HeapTupleIsValid(aggtup)) { ReleaseSysCache(aggtup);