Fixes from testing. Seems to be working

fix_120_custom_aggregates_distribute_multiarg
Philip Dubé 2019-10-16 04:30:10 +00:00
parent 070ffc785f
commit f5b038e2bc
4 changed files with 216 additions and 85 deletions

View File

@ -44,6 +44,7 @@
#include "distributed/multi_logical_optimizer.h"
#include "distributed/relation_access_tracking.h"
#include "distributed/worker_transaction.h"
#include "executor/spi.h"
#include "parser/parse_coerce.h"
#include "parser/parse_type.h"
#include "storage/lmgr.h"
@ -62,7 +63,7 @@ static char * GetFunctionDDLCommand(const RegProcedure funcOid);
static char * GetFunctionAlterOwnerCommand(const RegProcedure funcOid);
static void CreateAggregateHelper(const char *helperName, const char *helperPrefix,
Form_pg_proc proc, Form_pg_aggregate agg, int numargs,
Oid *argtypes);
Oid *argtypes, bool finalextra);
static int GetDistributionArgIndex(Oid functionOid, char *distributionArgumentName,
Oid *distributionArgumentOid);
static int GetFunctionColocationId(Oid functionOid, char *colocateWithName, Oid
@ -238,6 +239,7 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
int numargs = 0;
Oid *argtypes = NULL;
elog(WARNING, "prelude");
if (!HeapTupleIsValid(proctup))
{
@ -257,6 +259,8 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
}
agg = (Form_pg_aggregate) GETSTRUCT(aggtup);
elog(WARNING, "begins");
initStringInfo(&helperSuffix);
appendStringInfoAggregateHelperSuffix(&helperSuffix, proc, agg);
@ -264,26 +268,80 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS)
initStringInfo(&helperName);
appendStringInfo(&helperName, "%s%s", COORD_COMBINE_AGGREGATE_NAME,
helperSuffix.data);
helperOid = AggregateHelperOid(helperName.data, proctup, &numargs, &argtypes);
helperOid = CoordCombineAggOid(helperName.data);
elog(WARNING, "helperOid %d", helperOid);
if (helperOid == InvalidOid)
{
Oid coordArgTypes[2] = { BYTEAOID, ANYELEMENTOID };
CreateAggregateHelper(helperName.data, COORD_COMBINE_AGGREGATE_NAME, proc, agg,
numargs, argtypes);
2, coordArgTypes, true);
}
/* worker_partial_agg */
resetStringInfo(&helperName);
appendStringInfo(&helperName, "%s%s", WORKER_PARTIAL_AGGREGATE_NAME,
helperSuffix.data);
helperOid = AggregateHelperOid(helperName.data, proctup, &numargs, &argtypes);
helperOid = WorkerPartialAggOid(helperName.data, proctup, &numargs, &argtypes);
if (helperOid == InvalidOid)
{
/* Also check that we have a matching worker_partial_sfunc(internal, oid, ...) */
FuncCandidateList clist = FuncnameGetCandidates(list_make2(makeString("citus"),
makeString(
"worker_partial_agg_sfunc")),
numargs + 2,
NIL, false, false, true);
for (; clist; clist = clist->next)
{
if (clist->args[0] == INTERNALOID && clist->args[1] == OIDOID &&
memcmp(clist->args + 2, argtypes, numargs * sizeof(Oid)) == 0)
{
break;
return clist->oid;
}
}
if (clist == NULL)
{
int i;
StringInfoData command;
initStringInfo(&command);
appendStringInfoString(&command,
"CREATE FUNCTION citus.worker_partial_agg_sfunc(internal, oid");
for (i = 0; i < numargs; i++)
{
appendStringInfo(&command, ", %s", format_type_be_qualified(argtypes[i]));
}
appendStringInfoString(&command,
") RETURNS internal AS 'citus' LANGUAGE C PARALLEL SAFE");
SendCommandToWorkers(ALL_WORKERS, command.data);
/* TODO execute as CitusExtensionOwner */
if (SPI_connect() != SPI_OK_CONNECT)
{
elog(ERROR, "SPI_connect failed");
}
if (SPI_execute(command.data, false, 0) < 0)
{
elog(ERROR, "SPI_execute %s failed", command.data);
}
SPI_finish();
pfree(command.data);
}
CreateAggregateHelper(helperName.data, WORKER_PARTIAL_AGGREGATE_NAME, proc, agg,
numargs, argtypes);
numargs, argtypes, false);
}
elog(WARNING, "mark em");
/* set strategy column value */
UpdateDistObjectAggregationStrategy(funcOid, AGGREGATION_STRATEGY_COMBINE);
@ -302,10 +360,36 @@ early_exit:
/*
* AggregateHelperOid returns helper aggregate oid for given proc's HeapTuple
* CoordCombineAggOid returns coord_combine_agg oid with given name.
*/
Oid
AggregateHelperOid(char *helperName, HeapTuple proctup, int *numargs, Oid **argtypes)
CoordCombineAggOid(char *helperName)
{
FuncCandidateList clist;
clist = FuncnameGetCandidates(list_make2(makeString("citus"),
makeString(helperName)),
3, NIL, false, false, true);
for (; clist; clist = clist->next)
{
if (clist->args[0] == OIDOID &&
clist->args[1] == BYTEAOID &&
clist->args[2] == ANYELEMENTOID)
{
return clist->oid;
}
}
return InvalidOid;
}
/*
* WorkerPartialAggOid returns worker_partial_agg oid for given proc's HeapTuple.
*/
Oid
WorkerPartialAggOid(char *helperName, HeapTuple proctup, int *numargs, Oid **argtypes)
{
char **argnames = NULL;
char *argmodes = NULL;
@ -318,8 +402,8 @@ AggregateHelperOid(char *helperName, HeapTuple proctup, int *numargs, Oid **argt
for (; clist; clist = clist->next)
{
if (clist->args[0] == OIDOID && memcmp(clist->args + 1, argtypes, *numargs *
sizeof(Oid)) == 0)
if (clist->args[0] == OIDOID &&
memcmp(clist->args + 1, *argtypes, *numargs * sizeof(Oid)) == 0)
{
return clist->oid;
}
@ -332,7 +416,9 @@ AggregateHelperOid(char *helperName, HeapTuple proctup, int *numargs, Oid **argt
/*
* AggregateHelperName returns helper function name for a given aggregate.
*/
void appendStringInfoAggregateHelperSuffix(StringInfo helperSuffix, Form_pg_proc proc, Form_pg_aggregate agg)
void
appendStringInfoAggregateHelperSuffix(StringInfo helperSuffix, Form_pg_proc proc,
Form_pg_aggregate agg)
{
switch (proc->proparallel)
{
@ -377,11 +463,6 @@ void appendStringInfoAggregateHelperSuffix(StringInfo helperSuffix, Form_pg_proc
break;
}
}
if (agg->aggfinalextra)
{
appendStringInfoString(helperSuffix, "_fx");
}
}
if (agg->aggmfinalfn != InvalidOid)
@ -414,7 +495,6 @@ void appendStringInfoAggregateHelperSuffix(StringInfo helperSuffix, Form_pg_proc
appendStringInfoString(helperSuffix, "_fx");
}
}
}
@ -424,18 +504,23 @@ void appendStringInfoAggregateHelperSuffix(StringInfo helperSuffix, Form_pg_proc
static void
CreateAggregateHelper(const char *helperName, const char *helperPrefix,
Form_pg_proc proc, Form_pg_aggregate agg, int numargs,
Oid *argtypes)
Oid *argtypes, bool finalextra)
{
int i;
StringInfoData command;
initStringInfo(&command);
appendStringInfo(&command, "CREATE AGGREGATE %s("
"STYPE = internal, SFUNC = citus.%s_sfunc, FINALFUNC = %s_ffunc"
", COMBINEFUNC = citus.citus_stype_combine"
", SERIALFUNC = citus.citus_stype_serialize"
", DESERIALFUNC = citus.citus_stype_deserialize",
quote_qualified_identifier("citus", helperName),
appendStringInfo(&command, "CREATE AGGREGATE citus.%s(oid", helperName);
for (i = 0; i < numargs; i++)
{
appendStringInfo(&command, ", %s", format_type_be_qualified(argtypes[i]));
}
appendStringInfo(&command,
") (STYPE = internal, SFUNC = citus.%s_sfunc, FINALFUNC = citus.%s_ffunc"
", COMBINEFUNC = citus.citus_stype_combine"
", SERIALFUNC = citus.citus_stype_serialize"
", DESERIALFUNC = citus.citus_stype_deserialize",
helperPrefix, helperPrefix);
switch (proc->proparallel)
@ -482,45 +567,30 @@ CreateAggregateHelper(const char *helperName, const char *helperPrefix,
}
}
if (agg->aggfinalextra)
if (finalextra)
{
appendStringInfoString(&command, ", FINALFUNC_EXTRA");
}
}
if (agg->aggmfinalfn != InvalidOid)
{
switch (agg->aggmfinalmodify)
{
case AGGMODIFY_READ_ONLY:
{
appendStringInfoString(&command, ", MFINALFUNC_MODIFY = READ_ONLY");
break;
}
case AGGMODIFY_SHAREABLE:
{
appendStringInfoString(&command, ", MFINALFUNC_MODIFY = SHAREABLE");
break;
}
case AGGMODIFY_READ_WRITE:
{
appendStringInfoString(&command, ", MFINALFUNC_MODIFY = READ_WRITE");
break;
}
}
if (agg->aggmfinalextra)
{
appendStringInfoString(&command, ", MFINALFUNC_EXTRA");
}
}
appendStringInfoChar(&command, ')');
elog(WARNING, "SEND %s", command.data);
SendCommandToWorkers(ALL_WORKERS, command.data);
/* TODO execute as CitusExtensionOwner */
if (SPI_connect() != SPI_OK_CONNECT)
{
elog(ERROR, "SPI_connect failed");
}
if (SPI_execute(command.data, false, 0) < 0)
{
elog(ERROR, "SPI_execute %s failed", command.data);
}
SPI_finish();
elog(WARNING, "SENT");
pfree(command.data);
}
@ -805,10 +875,10 @@ UpdateDistObjectAggregationStrategy(Oid funcOid, int aggregationStrategy)
memset(replace, 0, sizeof(replace));
replace[Anum_pg_dist_object_distribution_argument_index - 1] = true;
values[Anum_pg_dist_object_distribution_argument_index - 1] = Int32GetDatum(
replace[Anum_pg_dist_object_aggregation_strategy - 1] = true;
values[Anum_pg_dist_object_aggregation_strategy - 1] = Int32GetDatum(
aggregationStrategy);
isnull[Anum_pg_dist_object_distribution_argument_index - 1] = false;
isnull[Anum_pg_dist_object_aggregation_strategy - 1] = false;
heapTuple = heap_modify_tuple(heapTuple, tupleDescriptor, values, isnull, replace);
@ -821,6 +891,8 @@ UpdateDistObjectAggregationStrategy(Oid funcOid, int aggregationStrategy)
systable_endscan(scanDescriptor);
heap_close(pgDistObjectRel, NoLock);
elog(WARNING, "marked %d %d", funcOid, aggregationStrategy);
}
@ -1148,7 +1220,7 @@ GetAggregateDDLCommand(const RegProcedure funcOid)
appendStringInfo(&buf, "%s ", quote_identifier(argname));
}
appendStringInfoString(&buf, format_type_be(argtype));
appendStringInfoString(&buf, format_type_be_qualified(argtype));
argsprinted++;

View File

@ -255,7 +255,9 @@ static List * WorkerAggregateExpressionList(Aggref *originalAggregate,
static AggregateType GetAggregateType(Oid aggFunctionId);
static Oid AggregateArgumentType(Aggref *aggregate);
static bool AggregateEnabledCustom(Oid aggregateOid);
static Oid AggregateFunctionHelperOid(const char *helperPrefix, Oid aggOid);
static HeapTuple AggregateFunctionHelperOidHelper(Oid aggOid, StringInfo helperName);
static Oid AggregateCoordCombineOid(Oid aggOid);
static Oid AggregateWorkerPartialOid(Oid aggOid);
static Oid AggregateFunctionOid(const char *functionName, Oid inputType);
static Oid TypeOid(Oid schemaId, const char *typeName);
static SortGroupClause * CreateSortGroupClause(Var *column);
@ -1852,6 +1854,7 @@ MasterAggregateExpression(Aggref *originalAggregate,
ObjectIdGetDatum(originalAggregate->aggfnoid));
if (!HeapTupleIsValid(aggTuple))
{
elog(WARNING, "!@#");
elog(WARNING, "citus cache lookup failed for aggregate %u",
originalAggregate->aggfnoid);
combine = InvalidOid;
@ -1869,13 +1872,18 @@ MasterAggregateExpression(Aggref *originalAggregate,
Var *column = NULL;
List *aggArguments = NIL;
Aggref *newMasterAggregate = NULL;
Oid coordCombineId = AggregateFunctionHelperOid(
COORD_COMBINE_AGGREGATE_NAME, originalAggregate->aggfnoid);
Oid coordCombineId = AggregateCoordCombineOid(originalAggregate->aggfnoid);
Oid workerReturnType = BYTEAOID;
int32 workerReturnTypeMod = -1;
Oid workerCollationId = InvalidOid;
if (coordCombineId == InvalidOid)
{
elog(ERROR,
"Could not find " COORD_COMBINE_AGGREGATE_NAME
" with correct signature.");
}
aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum(
originalAggregate->aggfnoid), false, true);
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
@ -1939,7 +1947,6 @@ MasterAggregateExpression(Aggref *originalAggregate,
newMasterExpression = (Expr *) newMasterAggregate;
}
/*
* Aggregate functions could have changed the return type. If so, we wrap
* the new expression with a conversion function to make it have the same
@ -2932,6 +2939,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
ObjectIdGetDatum(originalAggregate->aggfnoid));
if (!HeapTupleIsValid(aggTuple))
{
elog(WARNING, "!3434");
elog(WARNING, "citus cache lookup failed for aggregate %u",
originalAggregate->aggfnoid);
combine = InvalidOid;
@ -2949,8 +2957,14 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
Aggref *newWorkerAggregate = NULL;
List *aggArguments = NIL;
ListCell *originalAggArgCell;
Oid workerPartialId = AggregateFunctionHelperOid(
WORKER_PARTIAL_AGGREGATE_NAME, originalAggregate->aggfnoid);
Oid workerPartialId = AggregateWorkerPartialOid(originalAggregate->aggfnoid);
if (workerPartialId == InvalidOid)
{
elog(ERROR,
"Could not find " WORKER_PARTIAL_AGGREGATE_NAME
" with correct signature.");
}
aggparam = makeConst(REGPROCEDUREOID, -1, InvalidOid, sizeof(Oid),
ObjectIdGetDatum(originalAggregate->aggfnoid), false,
@ -3066,11 +3080,11 @@ AggregateArgumentType(Aggref *aggregate)
static bool
AggregateEnabledCustom(Oid aggregateOid)
{
DistObjectCacheEntry *cacheEntry = LookupDistObjectCacheEntry(AggregateRelationId,
DistObjectCacheEntry *cacheEntry = LookupDistObjectCacheEntry(ProcedureRelationId,
aggregateOid, 0);
return cacheEntry != NULL && cacheEntry->aggregationStrategy ==
AGGREGATION_STRATEGY_COMBINE;
return cacheEntry != NULL && cacheEntry->isDistributed &&
cacheEntry->aggregationStrategy == AGGREGATION_STRATEGY_COMBINE;
}
@ -3138,12 +3152,9 @@ AggregateFunctionOid(const char *functionName, Oid inputType)
/*
* AggregateFunctionHelperOid finds the aggregate helper for a given aggregate.
*/
static Oid
AggregateFunctionHelperOid(const char *helperPrefix, Oid aggOid)
static HeapTuple
AggregateFunctionHelperOidHelper(Oid aggOid, StringInfo helperName)
{
StringInfoData helperName;
int numargs;
Oid *argtypes;
HeapTuple proctup;
HeapTuple aggtup;
Form_pg_proc proc;
@ -3152,22 +3163,73 @@ AggregateFunctionHelperOid(const char *helperPrefix, Oid aggOid)
proctup = SearchSysCache1(PROCOID, aggOid);
if (!HeapTupleIsValid(proctup))
{
return InvalidOid;
return NULL;
}
proc = (Form_pg_proc) GETSTRUCT(proctup);
aggtup = SearchSysCache1(AGGFNOID, aggOid);
if (!HeapTupleIsValid(aggtup))
{
return InvalidOid;
ReleaseSysCache(proctup);
return NULL;
}
agg = (Form_pg_aggregate) GETSTRUCT(aggtup);
initStringInfo(&helperName);
appendStringInfoString(&helperName, helperPrefix);
appendStringInfoAggregateHelperSuffix(&helperName, proc, agg);
appendStringInfoAggregateHelperSuffix(helperName, proc, agg);
return AggregateHelperOid(helperName.data, proctup, &numargs, &argtypes);
ReleaseSysCache(aggtup);
return proctup;
}
/*
* AggregateFunctionHelperOid finds the aggregate helper for a given aggregate.
*/
static Oid
AggregateCoordCombineOid(Oid aggOid)
{
StringInfoData helperName;
HeapTuple proctup;
initStringInfo(&helperName);
appendStringInfoString(&helperName, COORD_COMBINE_AGGREGATE_NAME);
proctup = AggregateFunctionHelperOidHelper(aggOid, &helperName);
if (proctup == NULL)
{
elog(ERROR, "Failed to locate appropriate " COORD_COMBINE_AGGREGATE_NAME);
}
ReleaseSysCache(proctup);
return CoordCombineAggOid(helperName.data);
}
/*
* AggregateFunctionHelperOid finds the aggregate helper for a given aggregate.
*/
static Oid
AggregateWorkerPartialOid(Oid aggOid)
{
Oid result;
int numargs;
Oid *argtypes;
StringInfoData helperName;
HeapTuple proctup;
initStringInfo(&helperName);
appendStringInfoString(&helperName, WORKER_PARTIAL_AGGREGATE_NAME);
proctup = AggregateFunctionHelperOidHelper(aggOid, &helperName);
if (proctup == NULL)
{
elog(ERROR, "Failed to locate appropriate " WORKER_PARTIAL_AGGREGATE_NAME);
}
result = WorkerPartialAggOid(helperName.data, proctup, &numargs, &argtypes);
ReleaseSysCache(proctup);
return result;
}

View File

@ -1,6 +1,6 @@
SET search_path = 'pg_catalog';
CREATE FUNCTION mark_aggregate_for_distributed_execution(internal)
CREATE FUNCTION mark_aggregate_for_distributed_execution(regprocedure)
RETURNS void
AS 'MODULE_PATHNAME'
LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE;
@ -22,11 +22,6 @@ RETURNS internal
AS 'MODULE_PATHNAME'
LANGUAGE C PARALLEL SAFE;
CREATE FUNCTION worker_partial_agg_sfunc(internal, oid, anyelement)
RETURNS internal
AS 'MODULE_PATHNAME'
LANGUAGE C PARALLEL SAFE;
CREATE FUNCTION worker_partial_agg_ffunc(internal)
RETURNS bytea
AS 'MODULE_PATHNAME'

View File

@ -53,7 +53,9 @@ extern bool ConstraintIsAForeignKey(char *constraintName, Oid relationId);
extern void appendStringInfoAggregateHelperSuffix(StringInfo helperSuffix,
Form_pg_proc proc,
Form_pg_aggregate agg);
extern Oid AggregateHelperOid(char *helperName, HeapTuple proctup, int *numargs, Oid **argtypes);
extern Oid CoordCombineAggOid(char *helperName);
extern Oid WorkerPartialAggOid(char *helperName, HeapTuple proctup, int *numargs,
Oid **argtypes);
extern List * PlanCreateFunctionStmt(CreateFunctionStmt *stmt, const char *queryString);
extern List * ProcessCreateFunctionStmt(CreateFunctionStmt *stmt, const
char *queryString);