Move custom aggregates into a non-default path. Add pg_dist_enabled_custom_aggregates to hold the set of allowed aggregates, since an aggregate may not be compatible with this implementation.

fix_120_custom_aggregates_distribute_multiarg
Philip Dubé 2019-09-10 19:14:28 +00:00
parent 9d19f2d810
commit dd5916ef3a
7 changed files with 193 additions and 120 deletions

View File

@ -34,6 +34,7 @@
#include "distributed/multi_logical_planner.h" #include "distributed/multi_logical_planner.h"
#include "distributed/multi_physical_planner.h" #include "distributed/multi_physical_planner.h"
#include "distributed/pg_dist_partition.h" #include "distributed/pg_dist_partition.h"
#include "distributed/pg_dist_enabled_custom_aggregates.h"
#include "distributed/worker_protocol.h" #include "distributed/worker_protocol.h"
#include "distributed/version_compat.h" #include "distributed/version_compat.h"
#include "nodes/makefuncs.h" #include "nodes/makefuncs.h"
@ -252,6 +253,7 @@ static List * WorkerAggregateExpressionList(Aggref *originalAggregate,
WorkerAggregateWalkerContext *walkerContextry); WorkerAggregateWalkerContext *walkerContextry);
static AggregateType GetAggregateType(Oid aggFunctionId); static AggregateType GetAggregateType(Oid aggFunctionId);
static Oid AggregateArgumentType(Aggref *aggregate); static Oid AggregateArgumentType(Aggref *aggregate);
static bool AggregateEnabledCustom(const char *functionName);
static Oid AggregateFunctionOidWithoutInput(const char *functionName); static Oid AggregateFunctionOidWithoutInput(const char *functionName);
static Oid AggregateFunctionOid(const char *functionName, Oid inputType); static Oid AggregateFunctionOid(const char *functionName, Oid inputType);
static Oid TypeOid(Oid schemaId, const char *typeName); static Oid TypeOid(Oid schemaId, const char *typeName);
@ -1513,58 +1515,7 @@ MasterAggregateExpression(Aggref *originalAggregate,
Form_pg_aggregate aggform; Form_pg_aggregate aggform;
Oid combine; Oid combine;
aggTuple = SearchSysCache1(AGGFNOID, if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
ObjectIdGetDatum(originalAggregate->aggfnoid));
if (!HeapTupleIsValid(aggTuple))
{
elog(WARNING, "citus cache lookup failed for aggregate %u",
originalAggregate->aggfnoid);
combine = InvalidOid;
}
else
{
aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
combine = aggform->aggcombinefn;
ReleaseSysCache(aggTuple);
}
if (combine != InvalidOid)
{
Const *aggparam = NULL;
Var *column = NULL;
List *aggArguments = NIL;
Aggref *newMasterAggregate = NULL;
Oid coordCombineId = AggregateFunctionOidWithoutInput(
COORD_COMBINE_AGGREGATE_NAME);
Oid workerReturnType = BYTEAOID;
int32 workerReturnTypeMod = -1;
Oid workerCollationId = InvalidOid;
aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum(
originalAggregate->aggfnoid), false, true);
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp);
aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false));
aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) column, 2, NULL,
false));
/* coord_combine_agg(agg, workercol) */
newMasterAggregate = makeNode(Aggref);
newMasterAggregate->aggfnoid = coordCombineId;
newMasterAggregate->aggtype = originalAggregate->aggtype;
newMasterAggregate->args = aggArguments;
newMasterAggregate->aggkind = AGGKIND_NORMAL;
newMasterAggregate->aggfilter = originalAggregate->aggfilter;
newMasterAggregate->aggtranstype = INTERNALOID;
newMasterAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID),
list_make1_oid(BYTEAOID));
newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE;
newMasterExpression = (Expr *) newMasterAggregate;
}
else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION && CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION &&
walkerContext->pullDistinctColumns) walkerContext->pullDistinctColumns)
{ {
@ -1894,6 +1845,60 @@ MasterAggregateExpression(Aggref *originalAggregate,
newMasterExpression = (Expr *) unionAggregate; newMasterExpression = (Expr *) unionAggregate;
} }
else if (aggregateType == AGGREGATE_CUSTOM)
{
aggTuple = SearchSysCache1(AGGFNOID,
ObjectIdGetDatum(originalAggregate->aggfnoid));
if (!HeapTupleIsValid(aggTuple))
{
elog(WARNING, "citus cache lookup failed for aggregate %u",
originalAggregate->aggfnoid);
combine = InvalidOid;
}
else
{
aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
combine = aggform->aggcombinefn;
ReleaseSysCache(aggTuple);
}
if (combine != InvalidOid)
{
Const *aggparam = NULL;
Var *column = NULL;
List *aggArguments = NIL;
Aggref *newMasterAggregate = NULL;
Oid coordCombineId = AggregateFunctionOidWithoutInput(
COORD_COMBINE_AGGREGATE_NAME);
Oid workerReturnType = BYTEAOID;
int32 workerReturnTypeMod = -1;
Oid workerCollationId = InvalidOid;
aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum(
originalAggregate->aggfnoid), false, true);
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp);
aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false));
aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) column, 2, NULL,
false));
/* coord_combine_agg(agg, workercol) */
newMasterAggregate = makeNode(Aggref);
newMasterAggregate->aggfnoid = coordCombineId;
newMasterAggregate->aggtype = originalAggregate->aggtype;
newMasterAggregate->args = aggArguments;
newMasterAggregate->aggkind = AGGKIND_NORMAL;
newMasterAggregate->aggfilter = originalAggregate->aggfilter;
newMasterAggregate->aggtranstype = INTERNALOID;
newMasterAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID),
list_make1_oid(BYTEAOID));
newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE;
newMasterExpression = (Expr *) newMasterAggregate;
}
}
else else
{ {
/* /*
@ -1929,6 +1934,7 @@ MasterAggregateExpression(Aggref *originalAggregate,
newMasterExpression = (Expr *) newMasterAggregate; newMasterExpression = (Expr *) newMasterAggregate;
} }
/* /*
* Aggregate functions could have changed the return type. If so, we wrap * Aggregate functions could have changed the return type. If so, we wrap
* the new expression with a conversion function to make it have the same * the new expression with a conversion function to make it have the same
@ -2801,56 +2807,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
Form_pg_aggregate aggform; Form_pg_aggregate aggform;
Oid combine; Oid combine;
aggTuple = SearchSysCache1(AGGFNOID, if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
ObjectIdGetDatum(originalAggregate->aggfnoid));
if (!HeapTupleIsValid(aggTuple))
{
elog(WARNING, "citus cache lookup failed for aggregate %u",
originalAggregate->aggfnoid);
combine = InvalidOid;
}
else
{
aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
combine = aggform->aggcombinefn;
ReleaseSysCache(aggTuple);
}
if (combine != InvalidOid)
{
Const *aggparam = NULL;
Aggref *newWorkerAggregate = NULL;
List *aggArguments = NIL;
ListCell *originalAggArgCell;
Oid workerPartialId = AggregateFunctionOidWithoutInput(
WORKER_PARTIAL_AGGREGATE_NAME);
aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum(
originalAggregate->aggfnoid), false, true);
aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false));
foreach(originalAggArgCell, originalAggregate->args)
{
TargetEntry *arg = lfirst(originalAggArgCell);
TargetEntry *newArg = copyObject(arg);
newArg->resno++;
aggArguments = lappend(aggArguments, newArg);
}
/* worker_partial_agg(agg, ...args) */
newWorkerAggregate = makeNode(Aggref);
newWorkerAggregate->aggfnoid = workerPartialId;
newWorkerAggregate->aggtype = BYTEAOID;
newWorkerAggregate->args = aggArguments;
newWorkerAggregate->aggkind = AGGKIND_NORMAL;
newWorkerAggregate->aggfilter = originalAggregate->aggfilter;
newWorkerAggregate->aggtranstype = INTERNALOID;
newWorkerAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID),
originalAggregate->aggargtypes);
newWorkerAggregate->aggsplit = AGGSPLIT_SIMPLE;
workerAggregateList = list_make1(newWorkerAggregate);
}
else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION && CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION &&
walkerContext->pullDistinctColumns) walkerContext->pullDistinctColumns)
{ {
@ -2964,6 +2921,58 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
workerAggregateList = lappend(workerAggregateList, sumAggregate); workerAggregateList = lappend(workerAggregateList, sumAggregate);
workerAggregateList = lappend(workerAggregateList, countAggregate); workerAggregateList = lappend(workerAggregateList, countAggregate);
} }
else if (aggregateType == AGGREGATE_CUSTOM)
{
aggTuple = SearchSysCache1(AGGFNOID,
ObjectIdGetDatum(originalAggregate->aggfnoid));
if (!HeapTupleIsValid(aggTuple))
{
elog(WARNING, "citus cache lookup failed for aggregate %u",
originalAggregate->aggfnoid);
combine = InvalidOid;
}
else
{
aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
combine = aggform->aggcombinefn;
ReleaseSysCache(aggTuple);
}
if (combine != InvalidOid)
{
Const *aggparam = NULL;
Aggref *newWorkerAggregate = NULL;
List *aggArguments = NIL;
ListCell *originalAggArgCell;
Oid workerPartialId = AggregateFunctionOidWithoutInput(
WORKER_PARTIAL_AGGREGATE_NAME);
aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum(
originalAggregate->aggfnoid), false, true);
aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false));
foreach(originalAggArgCell, originalAggregate->args)
{
TargetEntry *arg = lfirst(originalAggArgCell);
TargetEntry *newArg = copyObject(arg);
newArg->resno++;
aggArguments = lappend(aggArguments, newArg);
}
/* worker_partial_agg(agg, ...args) */
newWorkerAggregate = makeNode(Aggref);
newWorkerAggregate->aggfnoid = workerPartialId;
newWorkerAggregate->aggtype = BYTEAOID;
newWorkerAggregate->args = aggArguments;
newWorkerAggregate->aggkind = AGGKIND_NORMAL;
newWorkerAggregate->aggfilter = originalAggregate->aggfilter;
newWorkerAggregate->aggtranstype = INTERNALOID;
newWorkerAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID),
originalAggregate->aggargtypes);
newWorkerAggregate->aggsplit = AGGSPLIT_SIMPLE;
workerAggregateList = list_make1(newWorkerAggregate);
}
}
else else
{ {
/* /*
@ -3005,9 +3014,15 @@ GetAggregateType(Oid aggFunctionId)
aggFunctionId))); aggFunctionId)));
} }
if (AggregateEnabledCustom(aggregateProcName))
{
return AGGREGATE_CUSTOM;
}
aggregateCount = lengthof(AggregateNames); aggregateCount = lengthof(AggregateNames);
Assert(AGGREGATE_INVALID_FIRST == 0); Assert(AGGREGATE_INVALID_FIRST == 0);
for (aggregateIndex = 1; aggregateIndex < aggregateCount; aggregateIndex++) for (aggregateIndex = 1; aggregateIndex < aggregateCount; aggregateIndex++)
{ {
const char *aggregateName = AggregateNames[aggregateIndex]; const char *aggregateName = AggregateNames[aggregateIndex];
@ -3042,6 +3057,35 @@ AggregateArgumentType(Aggref *aggregate)
} }
static bool
AggregateEnabledCustom(const char *functionName)
{
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1];
bool enabled = false;
HeapTuple heapTuple = NULL;
Relation pgDistEnabledCustomAggregates = NULL;
ScanKeyInit(&scanKey[0], Anum_pg_dist_enabled_custom_aggregates_name,
BTEqualStrategyNumber, F_NAMEEQ, CStringGetDatum(functionName));
pgDistEnabledCustomAggregates = heap_open(DistEnabledCustomAggregatesId(),
AccessShareLock);
scanDescriptor = systable_beginscan(pgDistEnabledCustomAggregates, InvalidOid, false,
NULL, 1, scanKey);
heapTuple = systable_getnext(scanDescriptor);
enabled = HeapTupleIsValid(heapTuple);
systable_endscan(scanDescriptor);
heap_close(pgDistEnabledCustomAggregates, AccessShareLock);
return enabled;
}
/* /*
* AggregateFunctionOid performs a reverse lookup on aggregate function name, * AggregateFunctionOid performs a reverse lookup on aggregate function name,
* and returns the corresponding aggregate function oid for the given function * and returns the corresponding aggregate function oid for the given function
@ -3160,8 +3204,8 @@ TypeOid(Oid schemaId, const char *typeName)
{ {
Oid typeOid; Oid typeOid;
typeOid = GetSysCacheOid2Compat(TYPENAMENSP, Anum_pg_type_oid, PointerGetDatum( typeOid = GetSysCacheOid2Compat(TYPENAMENSP, Anum_pg_type_oid,
typeName), PointerGetDatum(typeName),
ObjectIdGetDatum(schemaId)); ObjectIdGetDatum(schemaId));
return typeOid; return typeOid;

View File

@ -62,4 +62,10 @@ CREATE AGGREGATE coord_combine_agg(oid, bytea, anyelement) (
PARALLEL = SAFE PARALLEL = SAFE
); );
-- Names of the aggregates that are allowed to take the custom aggregate path.
-- The table is created in the citus schema and then moved into pg_catalog.
CREATE TABLE citus.pg_dist_enabled_custom_aggregates (
    name text not null primary key
);
ALTER TABLE citus.pg_dist_enabled_custom_aggregates SET SCHEMA pg_catalog;
-- grant read access on the table created above (not on pg_dist_node_metadata)
GRANT SELECT ON pg_catalog.pg_dist_enabled_custom_aggregates TO public;
RESET search_path; RESET search_path;

View File

@ -12,8 +12,6 @@
#include "fmgr.h" #include "fmgr.h"
#include "pg_config_manual.h" #include "pg_config_manual.h"
#include "utils/array.h"
PG_FUNCTION_INFO_V1(citus_stype_serialize); PG_FUNCTION_INFO_V1(citus_stype_serialize);
PG_FUNCTION_INFO_V1(citus_stype_deserialize); PG_FUNCTION_INFO_V1(citus_stype_deserialize);
PG_FUNCTION_INFO_V1(citus_stype_combine); PG_FUNCTION_INFO_V1(citus_stype_combine);
@ -265,8 +263,7 @@ citus_stype_deserialize(PG_FUNCTION_ARGS)
box->value_null = true; box->value_null = true;
PG_RETURN_POINTER(box); PG_RETURN_POINTER(box);
} }
else if (deserial != InvalidOid)
if (deserial != InvalidOid)
{ {
FmgrInfo fdeserialinfo; FmgrInfo fdeserialinfo;
LOCAL_FCINFO(fdeserial_callinfo, 2); LOCAL_FCINFO(fdeserial_callinfo, 2);
@ -282,12 +279,6 @@ citus_stype_deserialize(PG_FUNCTION_ARGS)
box->value = FunctionCallInvoke(fdeserial_callinfo); box->value = FunctionCallInvoke(fdeserial_callinfo);
box->value_null = fdeserial_callinfo->isnull; box->value_null = fdeserial_callinfo->isnull;
} }
/* TODO Correct null handling */
else if (value_null)
{
box->value = (Datum) 0;
box->value_null = true;
}
else else
{ {
transtypetuple = get_typeform(box->transtype, &transtypeform); transtypetuple = get_typeform(box->transtype, &transtypeform);

View File

@ -128,6 +128,7 @@ typedef struct MetadataCacheData
Oid distPlacementShardidIndexId; Oid distPlacementShardidIndexId;
Oid distPlacementPlacementidIndexId; Oid distPlacementPlacementidIndexId;
Oid distPlacementGroupidIndexId; Oid distPlacementGroupidIndexId;
Oid distEnabledCustomAggregatesId;
Oid distTransactionRelationId; Oid distTransactionRelationId;
Oid distTransactionGroupIndexId; Oid distTransactionGroupIndexId;
Oid distTransactionRecordIndexId; Oid distTransactionRecordIndexId;
@ -2113,6 +2114,17 @@ DistPlacementGroupidIndexId(void)
} }
/*
 * DistEnabledCustomAggregatesId returns the oid of the
 * pg_dist_enabled_custom_aggregates relation. The cache slot
 * MetadataCache.distEnabledCustomAggregatesId is handed to
 * CachedRelationLookup and its value is returned afterwards.
 */
Oid
DistEnabledCustomAggregatesId(void)
{
	Oid *cachedRelationId = &MetadataCache.distEnabledCustomAggregatesId;

	CachedRelationLookup("pg_dist_enabled_custom_aggregates", cachedRelationId);

	return *cachedRelationId;
}
/* return oid of the read_intermediate_result(text,citus_copy_format) function */ /* return oid of the read_intermediate_result(text,citus_copy_format) function */
Oid Oid
CitusReadIntermediateResultFuncId(void) CitusReadIntermediateResultFuncId(void)

View File

@ -166,6 +166,7 @@ extern Oid DistPlacementRelationId(void);
extern Oid DistNodeRelationId(void); extern Oid DistNodeRelationId(void);
extern Oid DistLocalGroupIdRelationId(void); extern Oid DistLocalGroupIdRelationId(void);
extern Oid DistObjectRelationId(void); extern Oid DistObjectRelationId(void);
extern Oid DistEnabledCustomAggregatesId(void);
/* index oids */ /* index oids */
extern Oid DistNodeNodeIdIndexId(void); extern Oid DistNodeNodeIdIndexId(void);

View File

@ -77,9 +77,11 @@ typedef enum
AGGREGATE_HLL_ADD = 16, AGGREGATE_HLL_ADD = 16,
AGGREGATE_HLL_UNION = 17, AGGREGATE_HLL_UNION = 17,
AGGREGATE_TOPN_ADD_AGG = 18, AGGREGATE_TOPN_ADD_AGG = 18,
AGGREGATE_TOPN_UNION_AGG = 19 AGGREGATE_TOPN_UNION_AGG = 19,
} AggregateType;
/* AGGREGATE_CUSTOM must come last */
AGGREGATE_CUSTOM = 20
} AggregateType;
/* /*
* PushDownStatus indicates whether a node can be pushed down below its child * PushDownStatus indicates whether a node can be pushed down below its child

View File

@ -0,0 +1,17 @@
/*-------------------------------------------------------------------------
 *
 * pg_dist_enabled_custom_aggregates.h
 *	  definition of the relation that lists which aggregates to treat as
 *	  custom aggregates.
 *
 * Copyright (c) 2012-2019, Citus Data, Inc.
 *
 *-------------------------------------------------------------------------
 */

#ifndef PG_DIST_ENABLED_CUSTOM_AGGREGATES_H
#define PG_DIST_ENABLED_CUSTOM_AGGREGATES_H

/*
 * Number of attributes in pg_dist_enabled_custom_aggregates. This was
 * previously spelled Natts_pg_dist_node, which is the name of the attribute
 * count macro for a different relation.
 */
#define Natts_pg_dist_enabled_custom_aggregates 1

/* attribute number of the name column */
#define Anum_pg_dist_enabled_custom_aggregates_name 1

#endif   /* PG_DIST_ENABLED_CUSTOM_AGGREGATES_H */