mirror of https://github.com/citusdata/citus.git
AGGREGATE_STRATEGY_COMMUTE
parent
a720929a8b
commit
9b3260f4df
|
@ -102,7 +102,7 @@ PG_FUNCTION_INFO_V1(mark_aggregate_for_distributed_execution);
|
||||||
Datum
|
Datum
|
||||||
create_distributed_function(PG_FUNCTION_ARGS)
|
create_distributed_function(PG_FUNCTION_ARGS)
|
||||||
{
|
{
|
||||||
RegProcedure funcOid = PG_GETARG_OID(0);
|
RegProcedure funcOid = PG_ARGISNULL(0) ? InvalidOid : PG_GETARG_OID(0);
|
||||||
|
|
||||||
text *distributionArgumentNameText = NULL; /* optional */
|
text *distributionArgumentNameText = NULL; /* optional */
|
||||||
text *colocateWithText = NULL; /* optional */
|
text *colocateWithText = NULL; /* optional */
|
||||||
|
@ -228,18 +228,44 @@ create_distributed_function(PG_FUNCTION_ARGS)
|
||||||
Datum
|
Datum
|
||||||
mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
|
mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
|
||||||
{
|
{
|
||||||
RegProcedure funcOid = PG_GETARG_OID(0);
|
RegProcedure funcOid = PG_ARGISNULL(0) ? InvalidOid : PG_GETARG_OID(0);
|
||||||
StringInfoData helperName;
|
StringInfoData helperName;
|
||||||
StringInfoData helperSuffix;
|
StringInfoData helperSuffix;
|
||||||
Form_pg_proc proc = NULL;
|
Form_pg_proc proc = NULL;
|
||||||
Form_pg_aggregate agg = NULL;
|
Form_pg_aggregate agg = NULL;
|
||||||
HeapTuple proctup = SearchSysCache1(PROCOID, funcOid);
|
HeapTuple proctup = NULL;
|
||||||
HeapTuple aggtup = NULL;
|
HeapTuple aggtup = NULL;
|
||||||
Oid helperOid = InvalidOid;
|
Oid helperOid = InvalidOid;
|
||||||
int numargs = 0;
|
int numargs = 0;
|
||||||
Oid *argtypes = NULL;
|
Oid *argtypes = NULL;
|
||||||
|
int aggregationStrategy = -1;
|
||||||
|
|
||||||
|
if (!PG_ARGISNULL(1))
|
||||||
|
{
|
||||||
|
char *strategyParam = TextDatumGetCString(PG_GETARG_TEXT_P(1));
|
||||||
|
if (strcmp(strategyParam, "none") == 0)
|
||||||
|
{
|
||||||
|
aggregationStrategy = AGGREGATION_STRATEGY_NONE;
|
||||||
|
}
|
||||||
|
else if (strcmp(strategyParam, "combine") == 0)
|
||||||
|
{
|
||||||
|
aggregationStrategy = AGGREGATION_STRATEGY_COMBINE;
|
||||||
|
}
|
||||||
|
else if (strcmp(strategyParam, "commute") == 0)
|
||||||
|
{
|
||||||
|
aggregationStrategy = AGGREGATION_STRATEGY_COMMUTE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (aggregationStrategy == -1)
|
||||||
|
{
|
||||||
|
elog(ERROR,
|
||||||
|
"mark_aggregate_for_distributed_execution expects a strategy that is one of: 'none', 'combine', 'commute'");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (aggregationStrategy == AGGREGATION_STRATEGY_COMBINE)
|
||||||
|
{
|
||||||
|
proctup = SearchSysCache1(PROCOID, funcOid);
|
||||||
if (!HeapTupleIsValid(proctup))
|
if (!HeapTupleIsValid(proctup))
|
||||||
{
|
{
|
||||||
goto early_exit;
|
goto early_exit;
|
||||||
|
@ -271,7 +297,8 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
|
||||||
{
|
{
|
||||||
Oid coordArgTypes[2] = { BYTEAOID, ANYELEMENTOID };
|
Oid coordArgTypes[2] = { BYTEAOID, ANYELEMENTOID };
|
||||||
|
|
||||||
CreateAggregateHelper(helperName.data, COORD_COMBINE_AGGREGATE_NAME, proc, agg,
|
CreateAggregateHelper(helperName.data, COORD_COMBINE_AGGREGATE_NAME, proc,
|
||||||
|
agg,
|
||||||
2, coordArgTypes, true);
|
2, coordArgTypes, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -284,7 +311,8 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
|
||||||
if (helperOid == InvalidOid)
|
if (helperOid == InvalidOid)
|
||||||
{
|
{
|
||||||
/* Also check that we have a matching worker_partial_sfunc(internal, oid, ...) */
|
/* Also check that we have a matching worker_partial_sfunc(internal, oid, ...) */
|
||||||
FuncCandidateList clist = FuncnameGetCandidates(list_make2(makeString("citus"),
|
FuncCandidateList clist = FuncnameGetCandidates(list_make2(makeString(
|
||||||
|
"citus"),
|
||||||
makeString(
|
makeString(
|
||||||
"worker_partial_agg_sfunc")),
|
"worker_partial_agg_sfunc")),
|
||||||
numargs + 2,
|
numargs + 2,
|
||||||
|
@ -310,7 +338,8 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
|
||||||
"CREATE FUNCTION citus.worker_partial_agg_sfunc(internal, oid");
|
"CREATE FUNCTION citus.worker_partial_agg_sfunc(internal, oid");
|
||||||
for (i = 0; i < numargs; i++)
|
for (i = 0; i < numargs; i++)
|
||||||
{
|
{
|
||||||
appendStringInfo(&command, ", %s", format_type_be_qualified(argtypes[i]));
|
appendStringInfo(&command, ", %s", format_type_be_qualified(
|
||||||
|
argtypes[i]));
|
||||||
}
|
}
|
||||||
appendStringInfoString(&command,
|
appendStringInfoString(&command,
|
||||||
") RETURNS internal AS 'citus' LANGUAGE C PARALLEL SAFE");
|
") RETURNS internal AS 'citus' LANGUAGE C PARALLEL SAFE");
|
||||||
|
@ -331,12 +360,14 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
|
||||||
pfree(command.data);
|
pfree(command.data);
|
||||||
}
|
}
|
||||||
|
|
||||||
CreateAggregateHelper(helperName.data, WORKER_PARTIAL_AGGREGATE_NAME, proc, agg,
|
CreateAggregateHelper(helperName.data, WORKER_PARTIAL_AGGREGATE_NAME, proc,
|
||||||
|
agg,
|
||||||
numargs, argtypes, false);
|
numargs, argtypes, false);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* set strategy column value */
|
/* set strategy column value */
|
||||||
UpdateDistObjectAggregationStrategy(funcOid, AGGREGATION_STRATEGY_COMBINE);
|
UpdateDistObjectAggregationStrategy(funcOid, aggregationStrategy);
|
||||||
|
|
||||||
early_exit:
|
early_exit:
|
||||||
if (aggtup && HeapTupleIsValid(aggtup))
|
if (aggtup && HeapTupleIsValid(aggtup))
|
||||||
|
|
|
@ -254,7 +254,7 @@ static List * WorkerAggregateExpressionList(Aggref *originalAggregate,
|
||||||
WorkerAggregateWalkerContext *walkerContextry);
|
WorkerAggregateWalkerContext *walkerContextry);
|
||||||
static AggregateType GetAggregateType(Oid aggFunctionId);
|
static AggregateType GetAggregateType(Oid aggFunctionId);
|
||||||
static Oid AggregateArgumentType(Aggref *aggregate);
|
static Oid AggregateArgumentType(Aggref *aggregate);
|
||||||
static bool AggregateEnabledCustom(Oid aggregateOid);
|
static int AggregateEnabledCustom(Oid aggregateOid);
|
||||||
static HeapTuple AggregateFunctionHelperOidHelper(Oid aggOid, StringInfo helperName);
|
static HeapTuple AggregateFunctionHelperOidHelper(Oid aggOid, StringInfo helperName);
|
||||||
static Oid AggregateCoordCombineOid(Oid aggOid);
|
static Oid AggregateCoordCombineOid(Oid aggOid);
|
||||||
static Oid AggregateWorkerPartialOid(Oid aggOid);
|
static Oid AggregateWorkerPartialOid(Oid aggOid);
|
||||||
|
@ -1848,7 +1848,7 @@ MasterAggregateExpression(Aggref *originalAggregate,
|
||||||
|
|
||||||
newMasterExpression = (Expr *) unionAggregate;
|
newMasterExpression = (Expr *) unionAggregate;
|
||||||
}
|
}
|
||||||
else if (aggregateType == AGGREGATE_CUSTOM)
|
else if (aggregateType == AGGREGATE_CUSTOM_COMBINE)
|
||||||
{
|
{
|
||||||
aggTuple = SearchSysCache1(AGGFNOID,
|
aggTuple = SearchSysCache1(AGGFNOID,
|
||||||
ObjectIdGetDatum(originalAggregate->aggfnoid));
|
ObjectIdGetDatum(originalAggregate->aggfnoid));
|
||||||
|
@ -1869,12 +1869,14 @@ MasterAggregateExpression(Aggref *originalAggregate,
|
||||||
{
|
{
|
||||||
Const *aggparam = NULL;
|
Const *aggparam = NULL;
|
||||||
Var *column = NULL;
|
Var *column = NULL;
|
||||||
|
Const *nulltag = NULL;
|
||||||
List *aggArguments = NIL;
|
List *aggArguments = NIL;
|
||||||
Aggref *newMasterAggregate = NULL;
|
Aggref *newMasterAggregate = NULL;
|
||||||
Oid coordCombineId = AggregateCoordCombineOid(originalAggregate->aggfnoid);
|
Oid coordCombineId = AggregateCoordCombineOid(originalAggregate->aggfnoid);
|
||||||
Oid workerReturnType = BYTEAOID;
|
Oid workerReturnType = BYTEAOID;
|
||||||
int32 workerReturnTypeMod = -1;
|
int32 workerReturnTypeMod = -1;
|
||||||
Oid workerCollationId = InvalidOid;
|
Oid workerCollationId = InvalidOid;
|
||||||
|
Oid resultType = exprType((Node *) originalAggregate);
|
||||||
|
|
||||||
if (coordCombineId == InvalidOid)
|
if (coordCombineId == InvalidOid)
|
||||||
{
|
{
|
||||||
|
@ -1887,10 +1889,13 @@ MasterAggregateExpression(Aggref *originalAggregate,
|
||||||
originalAggregate->aggfnoid), false, true);
|
originalAggregate->aggfnoid), false, true);
|
||||||
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
|
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
|
||||||
workerReturnTypeMod, workerCollationId, columnLevelsUp);
|
workerReturnTypeMod, workerCollationId, columnLevelsUp);
|
||||||
|
nulltag = makeNullConst(resultType, -1, InvalidOid);
|
||||||
|
|
||||||
aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false));
|
aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false));
|
||||||
aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) column, 2, NULL,
|
aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) column, 2, NULL,
|
||||||
false));
|
false));
|
||||||
|
aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) nulltag, 3,
|
||||||
|
NULL, false));
|
||||||
|
|
||||||
/* coord_combine_agg(agg, workercol) */
|
/* coord_combine_agg(agg, workercol) */
|
||||||
newMasterAggregate = makeNode(Aggref);
|
newMasterAggregate = makeNode(Aggref);
|
||||||
|
@ -1900,8 +1905,8 @@ MasterAggregateExpression(Aggref *originalAggregate,
|
||||||
newMasterAggregate->aggkind = AGGKIND_NORMAL;
|
newMasterAggregate->aggkind = AGGKIND_NORMAL;
|
||||||
newMasterAggregate->aggfilter = originalAggregate->aggfilter;
|
newMasterAggregate->aggfilter = originalAggregate->aggfilter;
|
||||||
newMasterAggregate->aggtranstype = INTERNALOID;
|
newMasterAggregate->aggtranstype = INTERNALOID;
|
||||||
newMasterAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID),
|
newMasterAggregate->aggargtypes = list_make3_oid(OIDOID, BYTEAOID,
|
||||||
list_make1_oid(BYTEAOID));
|
resultType);
|
||||||
newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE;
|
newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE;
|
||||||
|
|
||||||
newMasterExpression = (Expr *) newMasterAggregate;
|
newMasterExpression = (Expr *) newMasterAggregate;
|
||||||
|
@ -1925,14 +1930,8 @@ MasterAggregateExpression(Aggref *originalAggregate,
|
||||||
int32 workerReturnTypeMod = exprTypmod((Node *) originalAggregate);
|
int32 workerReturnTypeMod = exprTypmod((Node *) originalAggregate);
|
||||||
Oid workerCollationId = exprCollation((Node *) originalAggregate);
|
Oid workerCollationId = exprCollation((Node *) originalAggregate);
|
||||||
|
|
||||||
const char *aggregateName = AggregateNames[aggregateType];
|
|
||||||
Oid aggregateFunctionId = AggregateFunctionOid(aggregateName, workerReturnType);
|
|
||||||
Oid masterReturnType = get_func_rettype(aggregateFunctionId);
|
|
||||||
|
|
||||||
Aggref *newMasterAggregate = copyObject(originalAggregate);
|
Aggref *newMasterAggregate = copyObject(originalAggregate);
|
||||||
newMasterAggregate->aggdistinct = NULL;
|
newMasterAggregate->aggdistinct = NULL;
|
||||||
newMasterAggregate->aggfnoid = aggregateFunctionId;
|
|
||||||
newMasterAggregate->aggtype = masterReturnType;
|
|
||||||
newMasterAggregate->aggfilter = NULL;
|
newMasterAggregate->aggfilter = NULL;
|
||||||
|
|
||||||
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
|
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
|
||||||
|
@ -2932,7 +2931,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
|
||||||
workerAggregateList = lappend(workerAggregateList, sumAggregate);
|
workerAggregateList = lappend(workerAggregateList, sumAggregate);
|
||||||
workerAggregateList = lappend(workerAggregateList, countAggregate);
|
workerAggregateList = lappend(workerAggregateList, countAggregate);
|
||||||
}
|
}
|
||||||
else if (aggregateType == AGGREGATE_CUSTOM)
|
else if (aggregateType == AGGREGATE_CUSTOM_COMBINE)
|
||||||
{
|
{
|
||||||
aggTuple = SearchSysCache1(AGGFNOID,
|
aggTuple = SearchSysCache1(AGGFNOID,
|
||||||
ObjectIdGetDatum(originalAggregate->aggfnoid));
|
ObjectIdGetDatum(originalAggregate->aggfnoid));
|
||||||
|
@ -2990,11 +2989,16 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
|
||||||
|
|
||||||
workerAggregateList = list_make1(newWorkerAggregate);
|
workerAggregateList = list_make1(newWorkerAggregate);
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
elog(ERROR, "Aggregate lacks COMBINEFUNC");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
/*
|
/*
|
||||||
* All other aggregates are sent as they are to the worker nodes.
|
* All other aggregates are sent as they are to the worker nodes.
|
||||||
|
* This includes AGGREGATE_CUSTOM_COMMUTE.
|
||||||
*/
|
*/
|
||||||
Aggref *workerAggregate = copyObject(originalAggregate);
|
Aggref *workerAggregate = copyObject(originalAggregate);
|
||||||
workerAggregateList = lappend(workerAggregateList, workerAggregate);
|
workerAggregateList = lappend(workerAggregateList, workerAggregate);
|
||||||
|
@ -3023,6 +3027,7 @@ GetAggregateType(Oid aggFunctionId)
|
||||||
uint32 aggregateCount = 0;
|
uint32 aggregateCount = 0;
|
||||||
uint32 aggregateIndex = 0;
|
uint32 aggregateIndex = 0;
|
||||||
bool found = false;
|
bool found = false;
|
||||||
|
int customAggregationStrategy;
|
||||||
|
|
||||||
/* look up the function name */
|
/* look up the function name */
|
||||||
aggregateProcName = get_func_name(aggFunctionId);
|
aggregateProcName = get_func_name(aggFunctionId);
|
||||||
|
@ -3032,9 +3037,10 @@ GetAggregateType(Oid aggFunctionId)
|
||||||
aggFunctionId)));
|
aggFunctionId)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (AggregateEnabledCustom(aggFunctionId))
|
customAggregationStrategy = AggregateEnabledCustom(aggFunctionId);
|
||||||
|
if (customAggregationStrategy != AGGREGATION_STRATEGY_NONE)
|
||||||
{
|
{
|
||||||
return AGGREGATE_CUSTOM;
|
return AGGREGATE_CUSTOM + customAggregationStrategy;
|
||||||
}
|
}
|
||||||
|
|
||||||
aggregateCount = lengthof(AggregateNames);
|
aggregateCount = lengthof(AggregateNames);
|
||||||
|
@ -3075,14 +3081,17 @@ AggregateArgumentType(Aggref *aggregate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static bool
|
static int
|
||||||
AggregateEnabledCustom(Oid aggregateOid)
|
AggregateEnabledCustom(Oid aggregateOid)
|
||||||
{
|
{
|
||||||
DistObjectCacheEntry *cacheEntry = LookupDistObjectCacheEntry(ProcedureRelationId,
|
DistObjectCacheEntry *cacheEntry = LookupDistObjectCacheEntry(ProcedureRelationId,
|
||||||
aggregateOid, 0);
|
aggregateOid, 0);
|
||||||
|
|
||||||
return cacheEntry != NULL && cacheEntry->isDistributed &&
|
if (cacheEntry == NULL || !cacheEntry->isDistributed)
|
||||||
cacheEntry->aggregationStrategy == AGGREGATION_STRATEGY_COMBINE;
|
{
|
||||||
|
return AGGREGATION_STRATEGY_NONE;
|
||||||
|
}
|
||||||
|
return cacheEntry->aggregationStrategy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
SET search_path = 'pg_catalog';
|
SET search_path = 'pg_catalog';
|
||||||
|
|
||||||
CREATE FUNCTION mark_aggregate_for_distributed_execution(regprocedure)
|
CREATE FUNCTION mark_aggregate_for_distributed_execution(regprocedure, strategy text)
|
||||||
RETURNS void
|
RETURNS void
|
||||||
AS 'MODULE_PATHNAME'
|
AS 'MODULE_PATHNAME'
|
||||||
LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE;
|
LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE;
|
||||||
|
|
|
@ -66,5 +66,6 @@ typedef FormData_pg_dist_object *Form_pg_dist_object;
|
||||||
*/
|
*/
|
||||||
#define AGGREGATION_STRATEGY_NONE 0
|
#define AGGREGATION_STRATEGY_NONE 0
|
||||||
#define AGGREGATION_STRATEGY_COMBINE 1
|
#define AGGREGATION_STRATEGY_COMBINE 1
|
||||||
|
#define AGGREGATION_STRATEGY_COMMUTE 2
|
||||||
|
|
||||||
#endif /* PG_DIST_OBJECT_H */
|
#endif /* PG_DIST_OBJECT_H */
|
||||||
|
|
|
@ -79,8 +79,10 @@ typedef enum
|
||||||
AGGREGATE_TOPN_ADD_AGG = 18,
|
AGGREGATE_TOPN_ADD_AGG = 18,
|
||||||
AGGREGATE_TOPN_UNION_AGG = 19,
|
AGGREGATE_TOPN_UNION_AGG = 19,
|
||||||
|
|
||||||
/* AGGREGATE_CUSTOM must come last */
|
/* AGGREGATE_CUSTOM must come last. AGGREGATE_CUSTOM + AGGREGATION_STRATEGY = AGGREGATE_CUSTOM_STRATEGY */
|
||||||
AGGREGATE_CUSTOM = 20
|
AGGREGATE_CUSTOM = 20,
|
||||||
|
AGGREGATE_CUSTOM_COMBINE = 21,
|
||||||
|
AGGREGATE_CUSTOM_COMMUTE = 22
|
||||||
} AggregateType;
|
} AggregateType;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
Loading…
Reference in New Issue