AGGREGATE_STRATEGY_COMMUTE

fix_120_custom_aggregates_distribute_multiarg
Philip Dubé 2019-10-16 22:05:55 +00:00
parent a720929a8b
commit 9b3260f4df
5 changed files with 149 additions and 106 deletions

View File

@ -102,7 +102,7 @@ PG_FUNCTION_INFO_V1(mark_aggregate_for_distributed_execution);
Datum
create_distributed_function(PG_FUNCTION_ARGS)
{
RegProcedure funcOid = PG_GETARG_OID(0);
RegProcedure funcOid = PG_ARGISNULL(0) ? InvalidOid : PG_GETARG_OID(0);
text *distributionArgumentNameText = NULL; /* optional */
text *colocateWithText = NULL; /* optional */
@ -228,115 +228,146 @@ create_distributed_function(PG_FUNCTION_ARGS)
Datum
mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
{
RegProcedure funcOid = PG_GETARG_OID(0);
RegProcedure funcOid = PG_ARGISNULL(0) ? InvalidOid : PG_GETARG_OID(0);
StringInfoData helperName;
StringInfoData helperSuffix;
Form_pg_proc proc = NULL;
Form_pg_aggregate agg = NULL;
HeapTuple proctup = SearchSysCache1(PROCOID, funcOid);
HeapTuple proctup = NULL;
HeapTuple aggtup = NULL;
Oid helperOid = InvalidOid;
int numargs = 0;
Oid *argtypes = NULL;
int aggregationStrategy = -1;
if (!HeapTupleIsValid(proctup))
if (!PG_ARGISNULL(1))
{
goto early_exit;
}
proc = (Form_pg_proc) GETSTRUCT(proctup);
if (proc->prokind != PROKIND_AGGREGATE)
{
goto early_exit;
}
aggtup = SearchSysCache1(AGGFNOID, funcOid);
if (!HeapTupleIsValid(aggtup))
{
goto early_exit;
}
agg = (Form_pg_aggregate) GETSTRUCT(aggtup);
initStringInfo(&helperSuffix);
appendStringInfoAggregateHelperSuffix(&helperSuffix, proc, agg);
/* coordinator_combine_agg */
initStringInfo(&helperName);
appendStringInfo(&helperName, "%s%s", COORD_COMBINE_AGGREGATE_NAME,
helperSuffix.data);
helperOid = CoordCombineAggOid(helperName.data);
if (helperOid == InvalidOid)
{
Oid coordArgTypes[2] = { BYTEAOID, ANYELEMENTOID };
CreateAggregateHelper(helperName.data, COORD_COMBINE_AGGREGATE_NAME, proc, agg,
2, coordArgTypes, true);
}
/* worker_partial_agg */
resetStringInfo(&helperName);
appendStringInfo(&helperName, "%s%s", WORKER_PARTIAL_AGGREGATE_NAME,
helperSuffix.data);
helperOid = WorkerPartialAggOid(helperName.data, proctup, &numargs, &argtypes);
if (helperOid == InvalidOid)
{
/* Also check that we have a matching worker_partial_sfunc(internal, oid, ...) */
FuncCandidateList clist = FuncnameGetCandidates(list_make2(makeString("citus"),
makeString(
"worker_partial_agg_sfunc")),
numargs + 2,
NIL, false, false, true);
for (; clist; clist = clist->next)
char *strategyParam = TextDatumGetCString(PG_GETARG_TEXT_P(1));
if (strcmp(strategyParam, "none") == 0)
{
if (clist->args[0] == INTERNALOID && clist->args[1] == OIDOID &&
memcmp(clist->args + 2, argtypes, numargs * sizeof(Oid)) == 0)
{
break;
return clist->oid;
}
aggregationStrategy = AGGREGATION_STRATEGY_NONE;
}
else if (strcmp(strategyParam, "combine") == 0)
{
aggregationStrategy = AGGREGATION_STRATEGY_COMBINE;
}
else if (strcmp(strategyParam, "commute") == 0)
{
aggregationStrategy = AGGREGATION_STRATEGY_COMMUTE;
}
}
if (aggregationStrategy == -1)
{
elog(ERROR,
"mark_aggregate_for_distributed_execution expects a strategy that is one of: 'none', 'combine', 'commute'");
}
if (aggregationStrategy == AGGREGATION_STRATEGY_COMBINE)
{
proctup = SearchSysCache1(PROCOID, funcOid);
if (!HeapTupleIsValid(proctup))
{
goto early_exit;
}
proc = (Form_pg_proc) GETSTRUCT(proctup);
if (proc->prokind != PROKIND_AGGREGATE)
{
goto early_exit;
}
if (clist == NULL)
aggtup = SearchSysCache1(AGGFNOID, funcOid);
if (!HeapTupleIsValid(aggtup))
{
int i;
StringInfoData command;
goto early_exit;
}
agg = (Form_pg_aggregate) GETSTRUCT(aggtup);
initStringInfo(&command);
appendStringInfoString(&command,
"CREATE FUNCTION citus.worker_partial_agg_sfunc(internal, oid");
for (i = 0; i < numargs; i++)
{
appendStringInfo(&command, ", %s", format_type_be_qualified(argtypes[i]));
}
appendStringInfoString(&command,
") RETURNS internal AS 'citus' LANGUAGE C PARALLEL SAFE");
initStringInfo(&helperSuffix);
appendStringInfoAggregateHelperSuffix(&helperSuffix, proc, agg);
SendCommandToWorkers(ALL_WORKERS, command.data);
/* coordinator_combine_agg */
initStringInfo(&helperName);
appendStringInfo(&helperName, "%s%s", COORD_COMBINE_AGGREGATE_NAME,
helperSuffix.data);
helperOid = CoordCombineAggOid(helperName.data);
/* TODO execute as CitusExtensionOwner */
if (SPI_connect() != SPI_OK_CONNECT)
{
elog(ERROR, "SPI_connect failed");
}
if (SPI_execute(command.data, false, 0) < 0)
{
elog(ERROR, "SPI_execute %s failed", command.data);
}
SPI_finish();
if (helperOid == InvalidOid)
{
Oid coordArgTypes[2] = { BYTEAOID, ANYELEMENTOID };
pfree(command.data);
CreateAggregateHelper(helperName.data, COORD_COMBINE_AGGREGATE_NAME, proc,
agg,
2, coordArgTypes, true);
}
CreateAggregateHelper(helperName.data, WORKER_PARTIAL_AGGREGATE_NAME, proc, agg,
numargs, argtypes, false);
/* worker_partial_agg */
resetStringInfo(&helperName);
appendStringInfo(&helperName, "%s%s", WORKER_PARTIAL_AGGREGATE_NAME,
helperSuffix.data);
helperOid = WorkerPartialAggOid(helperName.data, proctup, &numargs, &argtypes);
if (helperOid == InvalidOid)
{
/* Also check that we have a matching worker_partial_sfunc(internal, oid, ...) */
FuncCandidateList clist = FuncnameGetCandidates(list_make2(makeString(
"citus"),
makeString(
"worker_partial_agg_sfunc")),
numargs + 2,
NIL, false, false, true);
for (; clist; clist = clist->next)
{
if (clist->args[0] == INTERNALOID && clist->args[1] == OIDOID &&
memcmp(clist->args + 2, argtypes, numargs * sizeof(Oid)) == 0)
{
break;
return clist->oid;
}
}
if (clist == NULL)
{
int i;
StringInfoData command;
initStringInfo(&command);
appendStringInfoString(&command,
"CREATE FUNCTION citus.worker_partial_agg_sfunc(internal, oid");
for (i = 0; i < numargs; i++)
{
appendStringInfo(&command, ", %s", format_type_be_qualified(
argtypes[i]));
}
appendStringInfoString(&command,
") RETURNS internal AS 'citus' LANGUAGE C PARALLEL SAFE");
SendCommandToWorkers(ALL_WORKERS, command.data);
/* TODO execute as CitusExtensionOwner */
if (SPI_connect() != SPI_OK_CONNECT)
{
elog(ERROR, "SPI_connect failed");
}
if (SPI_execute(command.data, false, 0) < 0)
{
elog(ERROR, "SPI_execute %s failed", command.data);
}
SPI_finish();
pfree(command.data);
}
CreateAggregateHelper(helperName.data, WORKER_PARTIAL_AGGREGATE_NAME, proc,
agg,
numargs, argtypes, false);
}
}
/* set strategy column value */
UpdateDistObjectAggregationStrategy(funcOid, AGGREGATION_STRATEGY_COMBINE);
UpdateDistObjectAggregationStrategy(funcOid, aggregationStrategy);
early_exit:
if (aggtup && HeapTupleIsValid(aggtup))

View File

@ -254,7 +254,7 @@ static List * WorkerAggregateExpressionList(Aggref *originalAggregate,
WorkerAggregateWalkerContext *walkerContextry);
static AggregateType GetAggregateType(Oid aggFunctionId);
static Oid AggregateArgumentType(Aggref *aggregate);
static bool AggregateEnabledCustom(Oid aggregateOid);
static int AggregateEnabledCustom(Oid aggregateOid);
static HeapTuple AggregateFunctionHelperOidHelper(Oid aggOid, StringInfo helperName);
static Oid AggregateCoordCombineOid(Oid aggOid);
static Oid AggregateWorkerPartialOid(Oid aggOid);
@ -1848,7 +1848,7 @@ MasterAggregateExpression(Aggref *originalAggregate,
newMasterExpression = (Expr *) unionAggregate;
}
else if (aggregateType == AGGREGATE_CUSTOM)
else if (aggregateType == AGGREGATE_CUSTOM_COMBINE)
{
aggTuple = SearchSysCache1(AGGFNOID,
ObjectIdGetDatum(originalAggregate->aggfnoid));
@ -1869,12 +1869,14 @@ MasterAggregateExpression(Aggref *originalAggregate,
{
Const *aggparam = NULL;
Var *column = NULL;
Const *nulltag = NULL;
List *aggArguments = NIL;
Aggref *newMasterAggregate = NULL;
Oid coordCombineId = AggregateCoordCombineOid(originalAggregate->aggfnoid);
Oid workerReturnType = BYTEAOID;
int32 workerReturnTypeMod = -1;
Oid workerCollationId = InvalidOid;
Oid resultType = exprType((Node *) originalAggregate);
if (coordCombineId == InvalidOid)
{
@ -1887,10 +1889,13 @@ MasterAggregateExpression(Aggref *originalAggregate,
originalAggregate->aggfnoid), false, true);
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp);
nulltag = makeNullConst(resultType, -1, InvalidOid);
aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false));
aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) column, 2, NULL,
false));
aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) nulltag, 3,
NULL, false));
/* coord_combine_agg(agg, workercol) */
newMasterAggregate = makeNode(Aggref);
@ -1900,8 +1905,8 @@ MasterAggregateExpression(Aggref *originalAggregate,
newMasterAggregate->aggkind = AGGKIND_NORMAL;
newMasterAggregate->aggfilter = originalAggregate->aggfilter;
newMasterAggregate->aggtranstype = INTERNALOID;
newMasterAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID),
list_make1_oid(BYTEAOID));
newMasterAggregate->aggargtypes = list_make3_oid(OIDOID, BYTEAOID,
resultType);
newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE;
newMasterExpression = (Expr *) newMasterAggregate;
@ -1925,14 +1930,8 @@ MasterAggregateExpression(Aggref *originalAggregate,
int32 workerReturnTypeMod = exprTypmod((Node *) originalAggregate);
Oid workerCollationId = exprCollation((Node *) originalAggregate);
const char *aggregateName = AggregateNames[aggregateType];
Oid aggregateFunctionId = AggregateFunctionOid(aggregateName, workerReturnType);
Oid masterReturnType = get_func_rettype(aggregateFunctionId);
Aggref *newMasterAggregate = copyObject(originalAggregate);
newMasterAggregate->aggdistinct = NULL;
newMasterAggregate->aggfnoid = aggregateFunctionId;
newMasterAggregate->aggtype = masterReturnType;
newMasterAggregate->aggfilter = NULL;
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
@ -2932,7 +2931,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
workerAggregateList = lappend(workerAggregateList, sumAggregate);
workerAggregateList = lappend(workerAggregateList, countAggregate);
}
else if (aggregateType == AGGREGATE_CUSTOM)
else if (aggregateType == AGGREGATE_CUSTOM_COMBINE)
{
aggTuple = SearchSysCache1(AGGFNOID,
ObjectIdGetDatum(originalAggregate->aggfnoid));
@ -2990,11 +2989,16 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
workerAggregateList = list_make1(newWorkerAggregate);
}
else
{
elog(ERROR, "Aggregate lacks COMBINEFUNC");
}
}
else
{
/*
* All other aggregates are sent as they are to the worker nodes.
* This includes AGGREGATE_CUSTOM_COMMUTE.
*/
Aggref *workerAggregate = copyObject(originalAggregate);
workerAggregateList = lappend(workerAggregateList, workerAggregate);
@ -3023,6 +3027,7 @@ GetAggregateType(Oid aggFunctionId)
uint32 aggregateCount = 0;
uint32 aggregateIndex = 0;
bool found = false;
int customAggregationStrategy;
/* look up the function name */
aggregateProcName = get_func_name(aggFunctionId);
@ -3032,9 +3037,10 @@ GetAggregateType(Oid aggFunctionId)
aggFunctionId)));
}
if (AggregateEnabledCustom(aggFunctionId))
customAggregationStrategy = AggregateEnabledCustom(aggFunctionId);
if (customAggregationStrategy != AGGREGATION_STRATEGY_NONE)
{
return AGGREGATE_CUSTOM;
return AGGREGATE_CUSTOM + customAggregationStrategy;
}
aggregateCount = lengthof(AggregateNames);
@ -3075,14 +3081,17 @@ AggregateArgumentType(Aggref *aggregate)
}
static bool
static int
AggregateEnabledCustom(Oid aggregateOid)
{
DistObjectCacheEntry *cacheEntry = LookupDistObjectCacheEntry(ProcedureRelationId,
aggregateOid, 0);
return cacheEntry != NULL && cacheEntry->isDistributed &&
cacheEntry->aggregationStrategy == AGGREGATION_STRATEGY_COMBINE;
if (cacheEntry == NULL || !cacheEntry->isDistributed)
{
return AGGREGATION_STRATEGY_NONE;
}
return cacheEntry->aggregationStrategy;
}

View File

@ -1,6 +1,6 @@
SET search_path = 'pg_catalog';
CREATE FUNCTION mark_aggregate_for_distributed_execution(regprocedure)
CREATE FUNCTION mark_aggregate_for_distributed_execution(regprocedure, strategy text)
RETURNS void
AS 'MODULE_PATHNAME'
LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE;

View File

@ -66,5 +66,6 @@ typedef FormData_pg_dist_object *Form_pg_dist_object;
*/
#define AGGREGATION_STRATEGY_NONE 0
#define AGGREGATION_STRATEGY_COMBINE 1
#define AGGREGATION_STRATEGY_COMMUTE 2
#endif /* PG_DIST_OBJECT_H */

View File

@ -79,8 +79,10 @@ typedef enum
AGGREGATE_TOPN_ADD_AGG = 18,
AGGREGATE_TOPN_UNION_AGG = 19,
/* AGGREGATE_CUSTOM must come last */
AGGREGATE_CUSTOM = 20
/* AGGREGATE_CUSTOM must come last. AGGREGATE_CUSTOM + AGGREGATION_STRATEGY = AGGREGATE_CUSTOM_STRATEGY */
AGGREGATE_CUSTOM = 20,
AGGREGATE_CUSTOM_COMBINE = 21,
AGGREGATE_CUSTOM_COMMUTE = 22
} AggregateType;
/*