Sketched out rest of how mark_aggregate_for_distributed_execution should look like

fix_120_custom_aggregates_distribute_multiarg
Philip Dubé 2019-10-15 02:13:32 +00:00
parent 020921f1eb
commit 1545cb312c
6 changed files with 399 additions and 38 deletions

View File

@ -41,6 +41,7 @@
#include "distributed/metadata/pg_dist_object.h"
#include "distributed/metadata_sync.h"
#include "distributed/multi_executor.h"
#include "distributed/multi_logical_optimizer.h"
#include "distributed/relation_access_tracking.h"
#include "distributed/worker_transaction.h"
#include "parser/parse_coerce.h"
@ -59,6 +60,9 @@
static char * GetAggregateDDLCommand(const RegProcedure funcOid);
static char * GetFunctionDDLCommand(const RegProcedure funcOid);
static char * GetFunctionAlterOwnerCommand(const RegProcedure funcOid);
static void CreateAggregateHelper(const char *helperName, const char *helperPrefix,
Form_pg_proc proc, Form_pg_aggregate agg, int numargs,
Oid *argtypes);
static int GetDistributionArgIndex(Oid functionOid, char *distributionArgumentName,
Oid *distributionArgumentOid);
static int GetFunctionColocationId(Oid functionOid, char *colocateWithName, Oid
@ -66,6 +70,7 @@ static int GetFunctionColocationId(Oid functionOid, char *colocateWithName, Oid
static void EnsureFunctionCanBeColocatedWithTable(Oid functionOid, Oid
distributionColumnType, Oid
sourceRelationId);
static void UpdateDistObjectAggregationStrategy(Oid funcOid, int aggregationStrategy);
static void UpdateFunctionDistributionInfo(const ObjectAddress *distAddress,
int *distribution_argument_index,
int *colocationId);
@ -82,6 +87,7 @@ static char * quote_qualified_func_name(Oid funcOid);
PG_FUNCTION_INFO_V1(create_distributed_function);
PG_FUNCTION_INFO_v1(mark_aggregate_for_distributed_execution);
#define AssertIsFunctionOrProcedure(objtype) \
Assert((objtype) == OBJECT_FUNCTION || (objtype) == OBJECT_PROCEDURE || (objtype) == \
@ -212,6 +218,305 @@ create_distributed_function(PG_FUNCTION_ARGS)
}
/*
* mark_aggregate_for_distributed_execution(regproc) signals the aggregate is safe to
* execute across worker nodes. This requires superuser because an aggregate
* which does not support distributed execution has undefined behavior.
* Also makes sure necessary helper functions exist across nodes.
*/
Datum
mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
{
RegProcedure funcOid = PG_GETARG_OID(0);
StringInfoData helperName;
StringInfoData helperSuffix;
Form_pg_proc proc = NULL;
Form_pg_aggregate agg = NULL;
HeapTuple proctup = SearchSysCache1(PROCOID, funcOid);
HeapTuple aggtup = NULL;
int numargs = 0;
Oid *argtypes = NULL;
char **argnames = NULL;
char *argmodes = NULL;
FuncCandidateList clist;
if (!HeapTupleIsValid(proctup))
{
goto early_exit;
}
proc = (Form_pg_proc) GETSTRUCT(proctup);
if (proc->prokind != PROKIND_AGGREGATE)
{
goto early_exit;
}
aggtup = SearchSysCache1(AGGFNOID, funcOid);
if (!HeapTupleIsValid(aggtup))
{
goto early_exit;
}
agg = (Form_pg_aggregate) GETSTRUCT(aggtup);
initStringInfo(&helperSuffix);
switch (proc->proparallel)
{
case PROPARALLEL_SAFE:
{
appendStringInfoString(&helperSuffix, "_ps");
break;
}
case PROPARALLEL_RESTRICTED:
{
appendStringInfoString(&helperSuffix, "_pr");
break;
}
case PROPARALLEL_UNSAFE:
{
appendStringInfoString(&helperSuffix, "_pu");
break;
}
}
if (agg->aggfinalfn != InvalidOid)
{
switch (agg->aggfinalmodify)
{
case AGGMODIFY_READ_ONLY:
{
appendStringInfoString(&helperSuffix, "_ro");
break;
}
case AGGMODIFY_SHAREABLE:
{
appendStringInfoString(&helperSuffix, "_rs");
break;
}
case AGGMODIFY_READ_WRITE:
{
appendStringInfoString(&helperSuffix, "_rw");
break;
}
}
if (agg->aggfinalextra)
{
appendStringInfoString(&helperSuffix, "_fx");
}
}
if (agg->aggmfinalfn != InvalidOid)
{
appendStringInfoString(&helperSuffix, "_m");
switch (agg->aggmfinalmodify)
{
case AGGMODIFY_READ_ONLY:
{
appendStringInfoString(&helperSuffix, "_ro");
break;
}
case AGGMODIFY_SHAREABLE:
{
appendStringInfoString(&helperSuffix, "_rs");
break;
}
case AGGMODIFY_READ_WRITE:
{
appendStringInfoString(&helperSuffix, "_rw");
break;
}
}
if (agg->aggmfinalextra)
{
appendStringInfoString(&helperSuffix, "_fx");
}
}
/* Parameters, borrows heavily from print_function_arguments in postgres */
/* coordinator_combine_agg */
initStringInfo(&helperName);
appendStringInfo(&helperName, "%s%s", COORD_COMBINE_AGGREGATE_NAME,
helperSuffix.data);
numargs = get_func_arg_info(proctup, &argtypes, &argnames, &argmodes);
clist = FuncnameGetCandidates(list_make2(makeString("citus"), makeString(
helperName.data)), proc->pronargs + 1,
NIL, false, false, true);
for (; clist; clist = clist->next)
{
if (clist->args[0] == OIDOID && memcmp(clist->args + 1, argtypes, numargs *
sizeof(numargs)) == 0)
{
break;
}
}
if (clist == NULL)
{
CreateAggregateHelper(helperName.data, COORD_COMBINE_AGGREGATE_NAME, proc, agg,
numargs, argtypes);
}
/* worker_partial_agg */
resetStringInfo(&helperName);
appendStringInfo(&helperName, "%s%s", WORKER_PARTIAL_AGGREGATE_NAME,
helperSuffix.data);
numargs = get_func_arg_info(proctup, &argtypes, &argnames, &argmodes);
clist = FuncnameGetCandidates(list_make2(makeString("citus"), makeString(
helperName.data)), proc->pronargs + 1,
NIL, false, false, true);
for (; clist; clist = clist->next)
{
if (clist->args[0] == OIDOID && memcmp(clist->args + 1, argtypes, numargs *
sizeof(numargs)) == 0)
{
break;
}
}
if (clist == NULL)
{
CreateAggregateHelper(helperName.data, WORKER_PARTIAL_AGGREGATE_NAME, proc, agg,
numargs, argtypes);
}
/* set strategy column value */
UpdateDistObjectAggregationStrategy(funcOid, AGGREGATION_STRATEGY_COMBINE);
early_exit:
if (aggtup && HeapTupleIsValid(aggtup))
{
ReleaseSysCache(aggtup);
}
if (proctup && HeapTupleIsValid(proctup))
{
ReleaseSysCache(proctup);
}
PG_RETURN_VOID();
}
/*
* CreateAggregateHelper creates helper aggregates across nodes
*/
static void
CreateAggregateHelper(const char *helperName, const char *helperPrefix,
Form_pg_proc proc, Form_pg_aggregate agg, int numargs,
Oid *argtypes)
{
StringInfoData command;
initStringInfo(&command);
appendStringInfo(&command, "CREATE AGGREGATE %s("
"STYPE = internal, SFUNC = citus.%s_sfunc, FINALFUNC = %s_ffunc"
", COMBINEFUNC = citus.citus_stype_combine"
", SERIALFUNC = citus.citus_stype_serialize"
", DESERIALFUNC = citus.citus_stype_deserialize",
quote_qualified_identifier("citus", helperName),
helperPrefix, helperPrefix);
switch (proc->proparallel)
{
case PROPARALLEL_SAFE:
{
appendStringInfoString(&command, ", PARALLEL = SAFE");
break;
}
case PROPARALLEL_RESTRICTED:
{
appendStringInfoString(&command, ", PARALLEL = RESTRICTED");
break;
}
case PROPARALLEL_UNSAFE:
{
appendStringInfoString(&command, ", PARALLEL = UNSAFE");
break;
}
}
if (agg->aggfinalfn != InvalidOid)
{
switch (agg->aggfinalmodify)
{
case AGGMODIFY_READ_ONLY:
{
appendStringInfoString(&command, ", FINALFUNC_MODIFY = READ_ONLY");
break;
}
case AGGMODIFY_SHAREABLE:
{
appendStringInfoString(&command, ", FINALFUNC_MODIFY = SHAREABLE");
break;
}
case AGGMODIFY_READ_WRITE:
{
appendStringInfoString(&command, ", FINALFUNC = READ_WRITE");
break;
}
}
if (agg->aggfinalextra)
{
appendStringInfoString(&command, ", FINALFUNC_EXTRA");
}
}
if (agg->aggmfinalfn != InvalidOid)
{
switch (agg->aggmfinalmodify)
{
case AGGMODIFY_READ_ONLY:
{
appendStringInfoString(&command, ", MFINALFUNC_MODIFY = READ_ONLY");
break;
}
case AGGMODIFY_SHAREABLE:
{
appendStringInfoString(&command, ", MFINALFUNC_MODIFY = SHAREABLE");
break;
}
case AGGMODIFY_READ_WRITE:
{
appendStringInfoString(&command, ", MFINALFUNC_MODIFY = READ_WRITE");
break;
}
}
if (agg->aggmfinalextra)
{
appendStringInfoString(&command, ", MFINALFUNC_EXTRA");
}
}
appendStringInfoChar(&command, ')');
SendCommandToWorkers(ALL_WORKERS, command.data);
pfree(command.data);
}
/*
* CreateFunctionDDLCommandsIdempotent returns a list of DDL statements (const char *) to be
* executed on a node to recreate the function addressed by the functionAddress.
@ -449,6 +754,68 @@ EnsureFunctionCanBeColocatedWithTable(Oid functionOid, Oid distributionColumnTyp
}
/*
* UpdateFunctionDistributionInfo gets object address of a function and
* updates its distribution_argument_index and colocationId in pg_dist_object.
*/
static void
UpdateDistObjectAggregationStrategy(Oid funcOid, int aggregationStrategy)
{
const bool indexOK = true;
Relation pgDistObjectRel = NULL;
TupleDesc tupleDescriptor = NULL;
ScanKeyData scanKey[3];
SysScanDesc scanDescriptor = NULL;
HeapTuple heapTuple = NULL;
Datum values[Natts_pg_dist_object];
bool isnull[Natts_pg_dist_object];
bool replace[Natts_pg_dist_object];
pgDistObjectRel = heap_open(DistObjectRelationId(), RowExclusiveLock);
tupleDescriptor = RelationGetDescr(pgDistObjectRel);
/* scan pg_dist_object for classid = $1 AND objid = $2 AND objsubid = $3 via index */
ScanKeyInit(&scanKey[0], Anum_pg_dist_object_classid, BTEqualStrategyNumber, F_OIDEQ,
ObjectIdGetDatum(ProcedureRelationId));
ScanKeyInit(&scanKey[1], Anum_pg_dist_object_objid, BTEqualStrategyNumber, F_OIDEQ,
ObjectIdGetDatum(funcOid));
ScanKeyInit(&scanKey[2], Anum_pg_dist_object_objsubid, BTEqualStrategyNumber,
F_INT4EQ, ObjectIdGetDatum(0));
scanDescriptor = systable_beginscan(pgDistObjectRel, DistObjectPrimaryKeyIndexId(),
indexOK,
NULL, 3, scanKey);
heapTuple = systable_getnext(scanDescriptor);
if (!HeapTupleIsValid(heapTuple))
{
ereport(ERROR, (errmsg("could not find valid entry for \"%d,%d,%d\" "
"in pg_dist_object", ProcedureRelationId,
funcOid, 0)));
}
memset(replace, 0, sizeof(replace));
replace[Anum_pg_dist_object_distribution_argument_index - 1] = true;
values[Anum_pg_dist_object_distribution_argument_index - 1] = Int32GetDatum(
aggregationStrategy);
isnull[Anum_pg_dist_object_distribution_argument_index - 1] = false;
heapTuple = heap_modify_tuple(heapTuple, tupleDescriptor, values, isnull, replace);
CatalogTupleUpdate(pgDistObjectRel, &heapTuple->t_self, heapTuple);
CitusInvalidateRelcacheByRelid(DistObjectRelationId());
CommandCounterIncrement();
systable_endscan(scanDescriptor);
heap_close(pgDistObjectRel, NoLock);
}
/*
* UpdateFunctionDistributionInfo gets object address of a function and
* updates its distribution_argument_index and colocationId in pg_dist_object.
@ -487,7 +854,7 @@ UpdateFunctionDistributionInfo(const ObjectAddress *distAddress,
heapTuple = systable_getnext(scanDescriptor);
if (!HeapTupleIsValid(heapTuple))
{
ereport(ERROR, (errmsg("could not find valid entry for node \"%d,%d,%d\" "
ereport(ERROR, (errmsg("could not find valid entry for \"%d,%d,%d\" "
"in pg_dist_object", distAddress->classId,
distAddress->objectId, distAddress->objectSubId)));
}

View File

@ -24,6 +24,7 @@
#include "catalog/pg_proc.h"
#include "catalog/pg_type.h"
#include "commands/extension.h"
#include "distributed/metadata/pg_dist_object.h"
#include "distributed/citus_nodes.h"
#include "distributed/citus_ruleutils.h"
#include "distributed/colocation_utils.h"
@ -253,7 +254,7 @@ static List * WorkerAggregateExpressionList(Aggref *originalAggregate,
static AggregateType GetAggregateType(Oid aggFunctionId);
static Oid AggregateArgumentType(Aggref *aggregate);
static bool AggregateEnabledCustom(Oid aggregateOid);
static Oid AggregateFunctionOidWithoutInput(const char *functionName);
static Oid AggregateFunctionHelperOid(const char *helperPrefix, Oid aggOid);
static Oid AggregateFunctionOid(const char *functionName, Oid inputType);
static Oid TypeOid(Oid schemaId, const char *typeName);
static SortGroupClause * CreateSortGroupClause(Var *column);
@ -1867,8 +1868,8 @@ MasterAggregateExpression(Aggref *originalAggregate,
Var *column = NULL;
List *aggArguments = NIL;
Aggref *newMasterAggregate = NULL;
Oid coordCombineId = AggregateFunctionOidWithoutInput(
COORD_COMBINE_AGGREGATE_NAME);
Oid coordCombineId = AggregateFunctionHelperOid(
COORD_COMBINE_AGGREGATE_NAME, originalAggregate->aggfnoid);
Oid workerReturnType = BYTEAOID;
int32 workerReturnTypeMod = -1;
@ -2947,8 +2948,8 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
Aggref *newWorkerAggregate = NULL;
List *aggArguments = NIL;
ListCell *originalAggArgCell;
Oid workerPartialId = AggregateFunctionOidWithoutInput(
WORKER_PARTIAL_AGGREGATE_NAME);
Oid workerPartialId = AggregateFunctionHelperOid(
WORKER_PARTIAL_AGGREGATE_NAME, originalAggregate->aggfnoid);
aggparam = makeConst(REGPROCEDUREOID, -1, InvalidOid, sizeof(Oid),
ObjectIdGetDatum(originalAggregate->aggfnoid), false,
@ -3067,7 +3068,8 @@ AggregateEnabledCustom(Oid aggregateOid)
DistObjectCacheEntry *cacheEntry = LookupDistObjectCacheEntry(AggregateRelationId,
aggregateOid, 0);
return cacheEntry != NULL;
return cacheEntry != NULL && cacheEntry->aggregationStrategy ==
AGGREGATION_STRATEGY_COMBINE;
}
@ -3133,12 +3135,10 @@ AggregateFunctionOid(const char *functionName, Oid inputType)
/*
* AggregateFunctionOid performs a reverse lookup on aggregate function name,
* and returns the corresponding aggregate function oid for the given function
* name and input type.
* AggregateFunctionHelperOid finds the aggregate helper for a given aggregate.
*/
static Oid
AggregateFunctionOidWithoutInput(const char *functionName)
AggregateFunctionHelperOid(const char *helperPrefix, Oid aggOid)
{
Oid functionOid = InvalidOid;
Relation procRelation = NULL;

View File

@ -1,5 +1,12 @@
SET search_path = 'pg_catalog';
CREATE FUNCTION mark_aggregate_for_distributed_execution(internal)
RETURNS void
AS 'MODULE_PATHNAME'
LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE;
SET search_path = 'citus';
CREATE FUNCTION citus_stype_serialize(internal)
RETURNS bytea
AS 'MODULE_PATHNAME'
@ -35,31 +42,6 @@ RETURNS anyelement
AS 'MODULE_PATHNAME'
LANGUAGE C PARALLEL SAFE;
-- select worker_partial_agg(agg, ...)
-- equivalent to
-- select serialize_stype(agg_without_ffunc(...))
CREATE AGGREGATE worker_partial_agg(oid, anyelement) (
STYPE = internal,
SFUNC = worker_partial_agg_sfunc,
FINALFUNC = worker_partial_agg_ffunc,
COMBINEFUNC = citus_stype_combine,
SERIALFUNC = citus_stype_serialize,
DESERIALFUNC = citus_stype_deserialize,
PARALLEL = SAFE
);
-- select coord_combine_agg(agg, col)
-- equivalent to
-- select agg_ffunc(agg_combine(col))
CREATE AGGREGATE coord_combine_agg(oid, bytea, anyelement) (
STYPE = internal,
SFUNC = coord_combine_agg_sfunc,
FINALFUNC = coord_combine_agg_ffunc,
FINALFUNC_EXTRA,
COMBINEFUNC = citus_stype_combine,
SERIALFUNC = citus_stype_serialize,
DESERIALFUNC = citus_stype_deserialize,
PARALLEL = SAFE
);
ALTER TABLE pg_dist_object ADD aggregation_strategy int;
RESET search_path;

View File

@ -998,6 +998,8 @@ LookupDistObjectCacheEntry(Oid classid, Oid objid, int32 objsubid)
1]);
cacheEntry->colocationId =
DatumGetInt32(datumArray[Anum_pg_dist_object_colocationid - 1]);
cacheEntry->aggregationStrategy =
DatumGetInt32(datumArray[Anum_pg_dist_object_aggregation_strategy - 1]);
}
else
{

View File

@ -35,6 +35,7 @@ typedef struct FormData_pg_dist_object
uint32 distribution_argument_index; /* only valid for distributed functions/procedures */
uint32 colocationid; /* only valid for distributed functions/procedures */
uint32 aggregation_strategy; /* only valid for distributed aggregates */
#endif
} FormData_pg_dist_object;
@ -58,5 +59,12 @@ typedef FormData_pg_dist_object *Form_pg_dist_object;
#define Anum_pg_dist_object_object_args 6
#define Anum_pg_dist_object_distribution_argument_index 7
#define Anum_pg_dist_object_colocationid 8
#define Anum_pg_dist_object_aggregation_strategy 9
/*
* Values for aggregation_strategy
*/
#define AGGREGATION_STRATEGY_NONE 0
#define AGGREGATION_STRATEGY_COMBINE 1
#endif /* PG_DIST_OBJECT_H */

View File

@ -114,6 +114,8 @@ typedef struct DistObjectCacheEntry
int distributionArgIndex;
int colocationId;
int aggregationStrategy;
} DistObjectCacheEntry;