mirror of https://github.com/citusdata/citus.git
AGGREGATE_STRATEGY_COMMUTE
parent
a720929a8b
commit
9b3260f4df
|
@ -102,7 +102,7 @@ PG_FUNCTION_INFO_V1(mark_aggregate_for_distributed_execution);
|
|||
Datum
|
||||
create_distributed_function(PG_FUNCTION_ARGS)
|
||||
{
|
||||
RegProcedure funcOid = PG_GETARG_OID(0);
|
||||
RegProcedure funcOid = PG_ARGISNULL(0) ? InvalidOid : PG_GETARG_OID(0);
|
||||
|
||||
text *distributionArgumentNameText = NULL; /* optional */
|
||||
text *colocateWithText = NULL; /* optional */
|
||||
|
@ -228,18 +228,44 @@ create_distributed_function(PG_FUNCTION_ARGS)
|
|||
Datum
|
||||
mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
|
||||
{
|
||||
RegProcedure funcOid = PG_GETARG_OID(0);
|
||||
RegProcedure funcOid = PG_ARGISNULL(0) ? InvalidOid : PG_GETARG_OID(0);
|
||||
StringInfoData helperName;
|
||||
StringInfoData helperSuffix;
|
||||
Form_pg_proc proc = NULL;
|
||||
Form_pg_aggregate agg = NULL;
|
||||
HeapTuple proctup = SearchSysCache1(PROCOID, funcOid);
|
||||
HeapTuple proctup = NULL;
|
||||
HeapTuple aggtup = NULL;
|
||||
Oid helperOid = InvalidOid;
|
||||
int numargs = 0;
|
||||
Oid *argtypes = NULL;
|
||||
int aggregationStrategy = -1;
|
||||
|
||||
if (!PG_ARGISNULL(1))
|
||||
{
|
||||
char *strategyParam = TextDatumGetCString(PG_GETARG_TEXT_P(1));
|
||||
if (strcmp(strategyParam, "none") == 0)
|
||||
{
|
||||
aggregationStrategy = AGGREGATION_STRATEGY_NONE;
|
||||
}
|
||||
else if (strcmp(strategyParam, "combine") == 0)
|
||||
{
|
||||
aggregationStrategy = AGGREGATION_STRATEGY_COMBINE;
|
||||
}
|
||||
else if (strcmp(strategyParam, "commute") == 0)
|
||||
{
|
||||
aggregationStrategy = AGGREGATION_STRATEGY_COMMUTE;
|
||||
}
|
||||
}
|
||||
|
||||
if (aggregationStrategy == -1)
|
||||
{
|
||||
elog(ERROR,
|
||||
"mark_aggregate_for_distributed_execution expects a strategy that is one of: 'none', 'combine', 'commute'");
|
||||
}
|
||||
|
||||
if (aggregationStrategy == AGGREGATION_STRATEGY_COMBINE)
|
||||
{
|
||||
proctup = SearchSysCache1(PROCOID, funcOid);
|
||||
if (!HeapTupleIsValid(proctup))
|
||||
{
|
||||
goto early_exit;
|
||||
|
@ -271,7 +297,8 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
|
|||
{
|
||||
Oid coordArgTypes[2] = { BYTEAOID, ANYELEMENTOID };
|
||||
|
||||
CreateAggregateHelper(helperName.data, COORD_COMBINE_AGGREGATE_NAME, proc, agg,
|
||||
CreateAggregateHelper(helperName.data, COORD_COMBINE_AGGREGATE_NAME, proc,
|
||||
agg,
|
||||
2, coordArgTypes, true);
|
||||
}
|
||||
|
||||
|
@ -284,7 +311,8 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
|
|||
if (helperOid == InvalidOid)
|
||||
{
|
||||
/* Also check that we have a matching worker_partial_sfunc(internal, oid, ...) */
|
||||
FuncCandidateList clist = FuncnameGetCandidates(list_make2(makeString("citus"),
|
||||
FuncCandidateList clist = FuncnameGetCandidates(list_make2(makeString(
|
||||
"citus"),
|
||||
makeString(
|
||||
"worker_partial_agg_sfunc")),
|
||||
numargs + 2,
|
||||
|
@ -310,7 +338,8 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
|
|||
"CREATE FUNCTION citus.worker_partial_agg_sfunc(internal, oid");
|
||||
for (i = 0; i < numargs; i++)
|
||||
{
|
||||
appendStringInfo(&command, ", %s", format_type_be_qualified(argtypes[i]));
|
||||
appendStringInfo(&command, ", %s", format_type_be_qualified(
|
||||
argtypes[i]));
|
||||
}
|
||||
appendStringInfoString(&command,
|
||||
") RETURNS internal AS 'citus' LANGUAGE C PARALLEL SAFE");
|
||||
|
@ -331,12 +360,14 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
|
|||
pfree(command.data);
|
||||
}
|
||||
|
||||
CreateAggregateHelper(helperName.data, WORKER_PARTIAL_AGGREGATE_NAME, proc, agg,
|
||||
CreateAggregateHelper(helperName.data, WORKER_PARTIAL_AGGREGATE_NAME, proc,
|
||||
agg,
|
||||
numargs, argtypes, false);
|
||||
}
|
||||
}
|
||||
|
||||
/* set strategy column value */
|
||||
UpdateDistObjectAggregationStrategy(funcOid, AGGREGATION_STRATEGY_COMBINE);
|
||||
UpdateDistObjectAggregationStrategy(funcOid, aggregationStrategy);
|
||||
|
||||
early_exit:
|
||||
if (aggtup && HeapTupleIsValid(aggtup))
|
||||
|
|
|
@ -254,7 +254,7 @@ static List * WorkerAggregateExpressionList(Aggref *originalAggregate,
|
|||
WorkerAggregateWalkerContext *walkerContextry);
|
||||
static AggregateType GetAggregateType(Oid aggFunctionId);
|
||||
static Oid AggregateArgumentType(Aggref *aggregate);
|
||||
static bool AggregateEnabledCustom(Oid aggregateOid);
|
||||
static int AggregateEnabledCustom(Oid aggregateOid);
|
||||
static HeapTuple AggregateFunctionHelperOidHelper(Oid aggOid, StringInfo helperName);
|
||||
static Oid AggregateCoordCombineOid(Oid aggOid);
|
||||
static Oid AggregateWorkerPartialOid(Oid aggOid);
|
||||
|
@ -1848,7 +1848,7 @@ MasterAggregateExpression(Aggref *originalAggregate,
|
|||
|
||||
newMasterExpression = (Expr *) unionAggregate;
|
||||
}
|
||||
else if (aggregateType == AGGREGATE_CUSTOM)
|
||||
else if (aggregateType == AGGREGATE_CUSTOM_COMBINE)
|
||||
{
|
||||
aggTuple = SearchSysCache1(AGGFNOID,
|
||||
ObjectIdGetDatum(originalAggregate->aggfnoid));
|
||||
|
@ -1869,12 +1869,14 @@ MasterAggregateExpression(Aggref *originalAggregate,
|
|||
{
|
||||
Const *aggparam = NULL;
|
||||
Var *column = NULL;
|
||||
Const *nulltag = NULL;
|
||||
List *aggArguments = NIL;
|
||||
Aggref *newMasterAggregate = NULL;
|
||||
Oid coordCombineId = AggregateCoordCombineOid(originalAggregate->aggfnoid);
|
||||
Oid workerReturnType = BYTEAOID;
|
||||
int32 workerReturnTypeMod = -1;
|
||||
Oid workerCollationId = InvalidOid;
|
||||
Oid resultType = exprType((Node *) originalAggregate);
|
||||
|
||||
if (coordCombineId == InvalidOid)
|
||||
{
|
||||
|
@ -1887,10 +1889,13 @@ MasterAggregateExpression(Aggref *originalAggregate,
|
|||
originalAggregate->aggfnoid), false, true);
|
||||
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
|
||||
workerReturnTypeMod, workerCollationId, columnLevelsUp);
|
||||
nulltag = makeNullConst(resultType, -1, InvalidOid);
|
||||
|
||||
aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false));
|
||||
aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) column, 2, NULL,
|
||||
false));
|
||||
aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) nulltag, 3,
|
||||
NULL, false));
|
||||
|
||||
/* coord_combine_agg(agg, workercol) */
|
||||
newMasterAggregate = makeNode(Aggref);
|
||||
|
@ -1900,8 +1905,8 @@ MasterAggregateExpression(Aggref *originalAggregate,
|
|||
newMasterAggregate->aggkind = AGGKIND_NORMAL;
|
||||
newMasterAggregate->aggfilter = originalAggregate->aggfilter;
|
||||
newMasterAggregate->aggtranstype = INTERNALOID;
|
||||
newMasterAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID),
|
||||
list_make1_oid(BYTEAOID));
|
||||
newMasterAggregate->aggargtypes = list_make3_oid(OIDOID, BYTEAOID,
|
||||
resultType);
|
||||
newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE;
|
||||
|
||||
newMasterExpression = (Expr *) newMasterAggregate;
|
||||
|
@ -1925,14 +1930,8 @@ MasterAggregateExpression(Aggref *originalAggregate,
|
|||
int32 workerReturnTypeMod = exprTypmod((Node *) originalAggregate);
|
||||
Oid workerCollationId = exprCollation((Node *) originalAggregate);
|
||||
|
||||
const char *aggregateName = AggregateNames[aggregateType];
|
||||
Oid aggregateFunctionId = AggregateFunctionOid(aggregateName, workerReturnType);
|
||||
Oid masterReturnType = get_func_rettype(aggregateFunctionId);
|
||||
|
||||
Aggref *newMasterAggregate = copyObject(originalAggregate);
|
||||
newMasterAggregate->aggdistinct = NULL;
|
||||
newMasterAggregate->aggfnoid = aggregateFunctionId;
|
||||
newMasterAggregate->aggtype = masterReturnType;
|
||||
newMasterAggregate->aggfilter = NULL;
|
||||
|
||||
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
|
||||
|
@ -2932,7 +2931,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
|
|||
workerAggregateList = lappend(workerAggregateList, sumAggregate);
|
||||
workerAggregateList = lappend(workerAggregateList, countAggregate);
|
||||
}
|
||||
else if (aggregateType == AGGREGATE_CUSTOM)
|
||||
else if (aggregateType == AGGREGATE_CUSTOM_COMBINE)
|
||||
{
|
||||
aggTuple = SearchSysCache1(AGGFNOID,
|
||||
ObjectIdGetDatum(originalAggregate->aggfnoid));
|
||||
|
@ -2990,11 +2989,16 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
|
|||
|
||||
workerAggregateList = list_make1(newWorkerAggregate);
|
||||
}
|
||||
else
|
||||
{
|
||||
elog(ERROR, "Aggregate lacks COMBINEFUNC");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
/*
|
||||
* All other aggregates are sent as they are to the worker nodes.
|
||||
* This includes AGGREGATE_CUSTOM_COMMUTE.
|
||||
*/
|
||||
Aggref *workerAggregate = copyObject(originalAggregate);
|
||||
workerAggregateList = lappend(workerAggregateList, workerAggregate);
|
||||
|
@ -3023,6 +3027,7 @@ GetAggregateType(Oid aggFunctionId)
|
|||
uint32 aggregateCount = 0;
|
||||
uint32 aggregateIndex = 0;
|
||||
bool found = false;
|
||||
int customAggregationStrategy;
|
||||
|
||||
/* look up the function name */
|
||||
aggregateProcName = get_func_name(aggFunctionId);
|
||||
|
@ -3032,9 +3037,10 @@ GetAggregateType(Oid aggFunctionId)
|
|||
aggFunctionId)));
|
||||
}
|
||||
|
||||
if (AggregateEnabledCustom(aggFunctionId))
|
||||
customAggregationStrategy = AggregateEnabledCustom(aggFunctionId);
|
||||
if (customAggregationStrategy != AGGREGATION_STRATEGY_NONE)
|
||||
{
|
||||
return AGGREGATE_CUSTOM;
|
||||
return AGGREGATE_CUSTOM + customAggregationStrategy;
|
||||
}
|
||||
|
||||
aggregateCount = lengthof(AggregateNames);
|
||||
|
@ -3075,14 +3081,17 @@ AggregateArgumentType(Aggref *aggregate)
|
|||
}
|
||||
|
||||
|
||||
static bool
|
||||
static int
|
||||
AggregateEnabledCustom(Oid aggregateOid)
|
||||
{
|
||||
DistObjectCacheEntry *cacheEntry = LookupDistObjectCacheEntry(ProcedureRelationId,
|
||||
aggregateOid, 0);
|
||||
|
||||
return cacheEntry != NULL && cacheEntry->isDistributed &&
|
||||
cacheEntry->aggregationStrategy == AGGREGATION_STRATEGY_COMBINE;
|
||||
if (cacheEntry == NULL || !cacheEntry->isDistributed)
|
||||
{
|
||||
return AGGREGATION_STRATEGY_NONE;
|
||||
}
|
||||
return cacheEntry->aggregationStrategy;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
SET search_path = 'pg_catalog';
|
||||
|
||||
CREATE FUNCTION mark_aggregate_for_distributed_execution(regprocedure)
|
||||
CREATE FUNCTION mark_aggregate_for_distributed_execution(regprocedure, strategy text)
|
||||
RETURNS void
|
||||
AS 'MODULE_PATHNAME'
|
||||
LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE;
|
||||
|
|
|
@ -66,5 +66,6 @@ typedef FormData_pg_dist_object *Form_pg_dist_object;
|
|||
*/
|
||||
#define AGGREGATION_STRATEGY_NONE 0
|
||||
#define AGGREGATION_STRATEGY_COMBINE 1
|
||||
#define AGGREGATION_STRATEGY_COMMUTE 2
|
||||
|
||||
#endif /* PG_DIST_OBJECT_H */
|
||||
|
|
|
@ -79,8 +79,10 @@ typedef enum
|
|||
AGGREGATE_TOPN_ADD_AGG = 18,
|
||||
AGGREGATE_TOPN_UNION_AGG = 19,
|
||||
|
||||
/* AGGREGATE_CUSTOM must come last */
|
||||
AGGREGATE_CUSTOM = 20
|
||||
/* AGGREGATE_CUSTOM must come last. AGGREGATE_CUSTOM + AGGREGATION_STRATEGY = AGGREGATE_CUSTOM_STRATEGY */
|
||||
AGGREGATE_CUSTOM = 20,
|
||||
AGGREGATE_CUSTOM_COMBINE = 21,
|
||||
AGGREGATE_CUSTOM_COMMUTE = 22
|
||||
} AggregateType;
|
||||
|
||||
/*
|
||||
|
|
Loading…
Reference in New Issue