AGGREGATE_STRATEGY_COMMUTE

fix_120_custom_aggregates_distribute_multiarg
Philip Dubé 2019-10-16 22:05:55 +00:00
parent a720929a8b
commit 9b3260f4df
5 changed files with 149 additions and 106 deletions

View File

@ -102,7 +102,7 @@ PG_FUNCTION_INFO_V1(mark_aggregate_for_distributed_execution);
Datum Datum
create_distributed_function(PG_FUNCTION_ARGS) create_distributed_function(PG_FUNCTION_ARGS)
{ {
RegProcedure funcOid = PG_GETARG_OID(0); RegProcedure funcOid = PG_ARGISNULL(0) ? InvalidOid : PG_GETARG_OID(0);
text *distributionArgumentNameText = NULL; /* optional */ text *distributionArgumentNameText = NULL; /* optional */
text *colocateWithText = NULL; /* optional */ text *colocateWithText = NULL; /* optional */
@ -228,18 +228,44 @@ create_distributed_function(PG_FUNCTION_ARGS)
Datum Datum
mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS) mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
{ {
RegProcedure funcOid = PG_GETARG_OID(0); RegProcedure funcOid = PG_ARGISNULL(0) ? InvalidOid : PG_GETARG_OID(0);
StringInfoData helperName; StringInfoData helperName;
StringInfoData helperSuffix; StringInfoData helperSuffix;
Form_pg_proc proc = NULL; Form_pg_proc proc = NULL;
Form_pg_aggregate agg = NULL; Form_pg_aggregate agg = NULL;
HeapTuple proctup = SearchSysCache1(PROCOID, funcOid); HeapTuple proctup = NULL;
HeapTuple aggtup = NULL; HeapTuple aggtup = NULL;
Oid helperOid = InvalidOid; Oid helperOid = InvalidOid;
int numargs = 0; int numargs = 0;
Oid *argtypes = NULL; Oid *argtypes = NULL;
int aggregationStrategy = -1;
if (!PG_ARGISNULL(1))
{
char *strategyParam = TextDatumGetCString(PG_GETARG_TEXT_P(1));
if (strcmp(strategyParam, "none") == 0)
{
aggregationStrategy = AGGREGATION_STRATEGY_NONE;
}
else if (strcmp(strategyParam, "combine") == 0)
{
aggregationStrategy = AGGREGATION_STRATEGY_COMBINE;
}
else if (strcmp(strategyParam, "commute") == 0)
{
aggregationStrategy = AGGREGATION_STRATEGY_COMMUTE;
}
}
if (aggregationStrategy == -1)
{
elog(ERROR,
"mark_aggregate_for_distributed_execution expects a strategy that is one of: 'none', 'combine', 'commute'");
}
if (aggregationStrategy == AGGREGATION_STRATEGY_COMBINE)
{
proctup = SearchSysCache1(PROCOID, funcOid);
if (!HeapTupleIsValid(proctup)) if (!HeapTupleIsValid(proctup))
{ {
goto early_exit; goto early_exit;
@ -271,7 +297,8 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
{ {
Oid coordArgTypes[2] = { BYTEAOID, ANYELEMENTOID }; Oid coordArgTypes[2] = { BYTEAOID, ANYELEMENTOID };
CreateAggregateHelper(helperName.data, COORD_COMBINE_AGGREGATE_NAME, proc, agg, CreateAggregateHelper(helperName.data, COORD_COMBINE_AGGREGATE_NAME, proc,
agg,
2, coordArgTypes, true); 2, coordArgTypes, true);
} }
@ -284,7 +311,8 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
if (helperOid == InvalidOid) if (helperOid == InvalidOid)
{ {
/* Also check that we have a matching worker_partial_sfunc(internal, oid, ...) */ /* Also check that we have a matching worker_partial_sfunc(internal, oid, ...) */
FuncCandidateList clist = FuncnameGetCandidates(list_make2(makeString("citus"), FuncCandidateList clist = FuncnameGetCandidates(list_make2(makeString(
"citus"),
makeString( makeString(
"worker_partial_agg_sfunc")), "worker_partial_agg_sfunc")),
numargs + 2, numargs + 2,
@ -310,7 +338,8 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
"CREATE FUNCTION citus.worker_partial_agg_sfunc(internal, oid"); "CREATE FUNCTION citus.worker_partial_agg_sfunc(internal, oid");
for (i = 0; i < numargs; i++) for (i = 0; i < numargs; i++)
{ {
appendStringInfo(&command, ", %s", format_type_be_qualified(argtypes[i])); appendStringInfo(&command, ", %s", format_type_be_qualified(
argtypes[i]));
} }
appendStringInfoString(&command, appendStringInfoString(&command,
") RETURNS internal AS 'citus' LANGUAGE C PARALLEL SAFE"); ") RETURNS internal AS 'citus' LANGUAGE C PARALLEL SAFE");
@ -331,12 +360,14 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
pfree(command.data); pfree(command.data);
} }
CreateAggregateHelper(helperName.data, WORKER_PARTIAL_AGGREGATE_NAME, proc, agg, CreateAggregateHelper(helperName.data, WORKER_PARTIAL_AGGREGATE_NAME, proc,
agg,
numargs, argtypes, false); numargs, argtypes, false);
} }
}
/* set strategy column value */ /* set strategy column value */
UpdateDistObjectAggregationStrategy(funcOid, AGGREGATION_STRATEGY_COMBINE); UpdateDistObjectAggregationStrategy(funcOid, aggregationStrategy);
early_exit: early_exit:
if (aggtup && HeapTupleIsValid(aggtup)) if (aggtup && HeapTupleIsValid(aggtup))

View File

@ -254,7 +254,7 @@ static List * WorkerAggregateExpressionList(Aggref *originalAggregate,
WorkerAggregateWalkerContext *walkerContextry); WorkerAggregateWalkerContext *walkerContextry);
static AggregateType GetAggregateType(Oid aggFunctionId); static AggregateType GetAggregateType(Oid aggFunctionId);
static Oid AggregateArgumentType(Aggref *aggregate); static Oid AggregateArgumentType(Aggref *aggregate);
static bool AggregateEnabledCustom(Oid aggregateOid); static int AggregateEnabledCustom(Oid aggregateOid);
static HeapTuple AggregateFunctionHelperOidHelper(Oid aggOid, StringInfo helperName); static HeapTuple AggregateFunctionHelperOidHelper(Oid aggOid, StringInfo helperName);
static Oid AggregateCoordCombineOid(Oid aggOid); static Oid AggregateCoordCombineOid(Oid aggOid);
static Oid AggregateWorkerPartialOid(Oid aggOid); static Oid AggregateWorkerPartialOid(Oid aggOid);
@ -1848,7 +1848,7 @@ MasterAggregateExpression(Aggref *originalAggregate,
newMasterExpression = (Expr *) unionAggregate; newMasterExpression = (Expr *) unionAggregate;
} }
else if (aggregateType == AGGREGATE_CUSTOM) else if (aggregateType == AGGREGATE_CUSTOM_COMBINE)
{ {
aggTuple = SearchSysCache1(AGGFNOID, aggTuple = SearchSysCache1(AGGFNOID,
ObjectIdGetDatum(originalAggregate->aggfnoid)); ObjectIdGetDatum(originalAggregate->aggfnoid));
@ -1869,12 +1869,14 @@ MasterAggregateExpression(Aggref *originalAggregate,
{ {
Const *aggparam = NULL; Const *aggparam = NULL;
Var *column = NULL; Var *column = NULL;
Const *nulltag = NULL;
List *aggArguments = NIL; List *aggArguments = NIL;
Aggref *newMasterAggregate = NULL; Aggref *newMasterAggregate = NULL;
Oid coordCombineId = AggregateCoordCombineOid(originalAggregate->aggfnoid); Oid coordCombineId = AggregateCoordCombineOid(originalAggregate->aggfnoid);
Oid workerReturnType = BYTEAOID; Oid workerReturnType = BYTEAOID;
int32 workerReturnTypeMod = -1; int32 workerReturnTypeMod = -1;
Oid workerCollationId = InvalidOid; Oid workerCollationId = InvalidOid;
Oid resultType = exprType((Node *) originalAggregate);
if (coordCombineId == InvalidOid) if (coordCombineId == InvalidOid)
{ {
@ -1887,10 +1889,13 @@ MasterAggregateExpression(Aggref *originalAggregate,
originalAggregate->aggfnoid), false, true); originalAggregate->aggfnoid), false, true);
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType, column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp); workerReturnTypeMod, workerCollationId, columnLevelsUp);
nulltag = makeNullConst(resultType, -1, InvalidOid);
aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false)); aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false));
aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) column, 2, NULL, aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) column, 2, NULL,
false)); false));
aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) nulltag, 3,
NULL, false));
/* coord_combine_agg(agg, workercol) */ /* coord_combine_agg(agg, workercol) */
newMasterAggregate = makeNode(Aggref); newMasterAggregate = makeNode(Aggref);
@ -1900,8 +1905,8 @@ MasterAggregateExpression(Aggref *originalAggregate,
newMasterAggregate->aggkind = AGGKIND_NORMAL; newMasterAggregate->aggkind = AGGKIND_NORMAL;
newMasterAggregate->aggfilter = originalAggregate->aggfilter; newMasterAggregate->aggfilter = originalAggregate->aggfilter;
newMasterAggregate->aggtranstype = INTERNALOID; newMasterAggregate->aggtranstype = INTERNALOID;
newMasterAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID), newMasterAggregate->aggargtypes = list_make3_oid(OIDOID, BYTEAOID,
list_make1_oid(BYTEAOID)); resultType);
newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE; newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE;
newMasterExpression = (Expr *) newMasterAggregate; newMasterExpression = (Expr *) newMasterAggregate;
@ -1925,14 +1930,8 @@ MasterAggregateExpression(Aggref *originalAggregate,
int32 workerReturnTypeMod = exprTypmod((Node *) originalAggregate); int32 workerReturnTypeMod = exprTypmod((Node *) originalAggregate);
Oid workerCollationId = exprCollation((Node *) originalAggregate); Oid workerCollationId = exprCollation((Node *) originalAggregate);
const char *aggregateName = AggregateNames[aggregateType];
Oid aggregateFunctionId = AggregateFunctionOid(aggregateName, workerReturnType);
Oid masterReturnType = get_func_rettype(aggregateFunctionId);
Aggref *newMasterAggregate = copyObject(originalAggregate); Aggref *newMasterAggregate = copyObject(originalAggregate);
newMasterAggregate->aggdistinct = NULL; newMasterAggregate->aggdistinct = NULL;
newMasterAggregate->aggfnoid = aggregateFunctionId;
newMasterAggregate->aggtype = masterReturnType;
newMasterAggregate->aggfilter = NULL; newMasterAggregate->aggfilter = NULL;
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType, column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
@ -2932,7 +2931,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
workerAggregateList = lappend(workerAggregateList, sumAggregate); workerAggregateList = lappend(workerAggregateList, sumAggregate);
workerAggregateList = lappend(workerAggregateList, countAggregate); workerAggregateList = lappend(workerAggregateList, countAggregate);
} }
else if (aggregateType == AGGREGATE_CUSTOM) else if (aggregateType == AGGREGATE_CUSTOM_COMBINE)
{ {
aggTuple = SearchSysCache1(AGGFNOID, aggTuple = SearchSysCache1(AGGFNOID,
ObjectIdGetDatum(originalAggregate->aggfnoid)); ObjectIdGetDatum(originalAggregate->aggfnoid));
@ -2990,11 +2989,16 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
workerAggregateList = list_make1(newWorkerAggregate); workerAggregateList = list_make1(newWorkerAggregate);
} }
else
{
elog(ERROR, "Aggregate lacks COMBINEFUNC");
}
} }
else else
{ {
/* /*
* All other aggregates are sent as they are to the worker nodes. * All other aggregates are sent as they are to the worker nodes.
* This includes AGGREGATE_CUSTOM_COMMUTE.
*/ */
Aggref *workerAggregate = copyObject(originalAggregate); Aggref *workerAggregate = copyObject(originalAggregate);
workerAggregateList = lappend(workerAggregateList, workerAggregate); workerAggregateList = lappend(workerAggregateList, workerAggregate);
@ -3023,6 +3027,7 @@ GetAggregateType(Oid aggFunctionId)
uint32 aggregateCount = 0; uint32 aggregateCount = 0;
uint32 aggregateIndex = 0; uint32 aggregateIndex = 0;
bool found = false; bool found = false;
int customAggregationStrategy;
/* look up the function name */ /* look up the function name */
aggregateProcName = get_func_name(aggFunctionId); aggregateProcName = get_func_name(aggFunctionId);
@ -3032,9 +3037,10 @@ GetAggregateType(Oid aggFunctionId)
aggFunctionId))); aggFunctionId)));
} }
if (AggregateEnabledCustom(aggFunctionId)) customAggregationStrategy = AggregateEnabledCustom(aggFunctionId);
if (customAggregationStrategy != AGGREGATION_STRATEGY_NONE)
{ {
return AGGREGATE_CUSTOM; return AGGREGATE_CUSTOM + customAggregationStrategy;
} }
aggregateCount = lengthof(AggregateNames); aggregateCount = lengthof(AggregateNames);
@ -3075,14 +3081,17 @@ AggregateArgumentType(Aggref *aggregate)
} }
static bool static int
AggregateEnabledCustom(Oid aggregateOid) AggregateEnabledCustom(Oid aggregateOid)
{ {
DistObjectCacheEntry *cacheEntry = LookupDistObjectCacheEntry(ProcedureRelationId, DistObjectCacheEntry *cacheEntry = LookupDistObjectCacheEntry(ProcedureRelationId,
aggregateOid, 0); aggregateOid, 0);
return cacheEntry != NULL && cacheEntry->isDistributed && if (cacheEntry == NULL || !cacheEntry->isDistributed)
cacheEntry->aggregationStrategy == AGGREGATION_STRATEGY_COMBINE; {
return AGGREGATION_STRATEGY_NONE;
}
return cacheEntry->aggregationStrategy;
} }

View File

@ -1,6 +1,6 @@
SET search_path = 'pg_catalog'; SET search_path = 'pg_catalog';
CREATE FUNCTION mark_aggregate_for_distributed_execution(regprocedure) CREATE FUNCTION mark_aggregate_for_distributed_execution(regprocedure, strategy text)
RETURNS void RETURNS void
AS 'MODULE_PATHNAME' AS 'MODULE_PATHNAME'
LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE; LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE;

View File

@ -66,5 +66,6 @@ typedef FormData_pg_dist_object *Form_pg_dist_object;
*/ */
#define AGGREGATION_STRATEGY_NONE 0 #define AGGREGATION_STRATEGY_NONE 0
#define AGGREGATION_STRATEGY_COMBINE 1 #define AGGREGATION_STRATEGY_COMBINE 1
#define AGGREGATION_STRATEGY_COMMUTE 2
#endif /* PG_DIST_OBJECT_H */ #endif /* PG_DIST_OBJECT_H */

View File

@ -79,8 +79,10 @@ typedef enum
AGGREGATE_TOPN_ADD_AGG = 18, AGGREGATE_TOPN_ADD_AGG = 18,
AGGREGATE_TOPN_UNION_AGG = 19, AGGREGATE_TOPN_UNION_AGG = 19,
/* AGGREGATE_CUSTOM must come last */ /* AGGREGATE_CUSTOM must come last. AGGREGATE_CUSTOM + AGGREGATION_STRATEGY = AGGREGATE_CUSTOM_STRATEGY */
AGGREGATE_CUSTOM = 20 AGGREGATE_CUSTOM = 20,
AGGREGATE_CUSTOM_COMBINE = 21,
AGGREGATE_CUSTOM_COMMUTE = 22
} AggregateType; } AggregateType;
/* /*