Automatically convert useless declarations using regex replace (#3181)

* Add declaration removal to CI

* Convert declarations
pull/3181/merge
Jelte Fennema 2019-11-21 13:47:29 +01:00 committed by GitHub
parent 9961297d7b
commit 1d8dde232f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
130 changed files with 3487 additions and 5950 deletions

View File

@ -22,6 +22,12 @@ jobs:
- run: - run:
name: 'Check Style' name: 'Check Style'
command: citus_indent --check command: citus_indent --check
- run:
name: 'Remove useless declarations'
command: ci/remove_useless_declarations.sh
- run:
name: 'Check if changed'
command: git diff --cached --exit-code
check-sql-snapshots: check-sql-snapshots:
docker: docker:
- image: 'citus/extbuilder:latest' - image: 'citus/extbuilder:latest'

View File

@ -0,0 +1,23 @@
#!/bin/sh
set -eu
files=$(find src -iname '*.c' | git check-attr --stdin citus-style | grep -v ': unset$' | sed 's/: citus-style: set$//')
while true; do
# shellcheck disable=SC2086
perl -i -p0e 's/\n\t(?!return )(?P<type>(\w+ )+\**)(?>(?P<variable>\w+)( = *[\w>\s\n-]*?)?;\n(?P<code_between>(?>(?P<comment_or_string_or_not_preprocessor>\/\*.*?\*\/|"(?>\\"|.)*?"|[^#]))*?)(\t)?(?=\b(?P=variable)\b))(?<=\n\t)(?P=variable) =(?![^;]*?[^>_]\b(?P=variable)\b[^_])/\n$+{code_between}\t$+{type}$+{variable} =/sg' $files
# The following are simply the same regex, but repeated for different tab sizes
# (this is needed because variable sized backtracking is not supported in perl)
# shellcheck disable=SC2086
perl -i -p0e 's/\n\t\t(?!return )(?P<type>(\w+ )+\**)(?>(?P<variable>\w+)( = *[\w>\s\n-]*?)?;\n(?P<code_between>(?>(?P<comment_or_string_or_not_preprocessor>\/\*.*?\*\/|"(?>\\"|.)*?"|[^#]))*?)(\t\t)?(?=\b(?P=variable)\b))(?<=\n\t\t)(?P=variable) =(?![^;]*?[^>_]\b(?P=variable)\b[^_])/\n$+{code_between}\t\t$+{type}$+{variable} =/sg' $files
# shellcheck disable=SC2086
perl -i -p0e 's/\n\t\t\t(?!return )(?P<type>(\w+ )+\**)(?>(?P<variable>\w+)( = *[\w>\s\n-]*?)?;\n(?P<code_between>(?>(?P<comment_or_string_or_not_preprocessor>\/\*.*?\*\/|"(?>\\"|.)*?"|[^#]))*?)(\t\t\t)?(?=\b(?P=variable)\b))(?<=\n\t\t\t)(?P=variable) =(?![^;]*?[^>_]\b(?P=variable)\b[^_])/\n$+{code_between}\t\t\t$+{type}$+{variable} =/sg' $files
# shellcheck disable=SC2086
perl -i -p0e 's/\n\t\t\t\t(?!return )(?P<type>(\w+ )+\**)(?>(?P<variable>\w+)( = *[\w>\s\n-]*?)?;\n(?P<code_between>(?>(?P<comment_or_string_or_not_preprocessor>\/\*.*?\*\/|"(?>\\"|.)*?"|[^#]))*?)(\t\t\t\t)?(?=\b(?P=variable)\b))(?<=\n\t\t\t\t)(?P=variable) =(?![^;]*?[^>_]\b(?P=variable)\b[^_])/\n$+{code_between}\t\t\t\t$+{type}$+{variable} =/sg' $files
# shellcheck disable=SC2086
perl -i -p0e 's/\n\t\t\t\t\t(?!return )(?P<type>(\w+ )+\**)(?>(?P<variable>\w+)( = *[\w>\s\n-]*?)?;\n(?P<code_between>(?>(?P<comment_or_string_or_not_preprocessor>\/\*.*?\*\/|"(?>\\"|.)*?"|[^#]))*?)(\t\t\t\t\t)?(?=\b(?P=variable)\b))(?<=\n\t\t\t\t\t)(?P=variable) =(?![^;]*?[^>_]\b(?P=variable)\b[^_])/\n$+{code_between}\t\t\t\t\t$+{type}$+{variable} =/sg' $files
# shellcheck disable=SC2086
perl -i -p0e 's/\n\t\t\t\t\t\t(?!return )(?P<type>(\w+ )+\**)(?>(?P<variable>\w+)( = *[\w>\s\n-]*?)?;\n(?P<code_between>(?>(?P<comment_or_string_or_not_preprocessor>\/\*.*?\*\/|"(?>\\"|.)*?"|[^#]))*?)(\t\t\t\t\t\t)?(?=\b(?P=variable)\b))(?<=\n\t\t\t\t\t\t)(?P=variable) =(?![^;]*?[^>_]\b(?P=variable)\b[^_])/\n$+{code_between}\t\t\t\t\t\t$+{type}$+{variable} =/sg' $files
git diff --quiet && break;
git add .;
done

View File

@ -47,11 +47,11 @@ static bool CallFuncExprRemotely(CallStmt *callStmt,
bool bool
CallDistributedProcedureRemotely(CallStmt *callStmt, DestReceiver *dest) CallDistributedProcedureRemotely(CallStmt *callStmt, DestReceiver *dest)
{ {
DistObjectCacheEntry *procedure = NULL;
FuncExpr *funcExpr = callStmt->funcexpr; FuncExpr *funcExpr = callStmt->funcexpr;
Oid functionId = funcExpr->funcid; Oid functionId = funcExpr->funcid;
procedure = LookupDistObjectCacheEntry(ProcedureRelationId, functionId, 0); DistObjectCacheEntry *procedure = LookupDistObjectCacheEntry(ProcedureRelationId,
functionId, 0);
if (procedure == NULL || !procedure->isDistributed) if (procedure == NULL || !procedure->isDistributed)
{ {
return false; return false;
@ -68,25 +68,13 @@ static bool
CallFuncExprRemotely(CallStmt *callStmt, DistObjectCacheEntry *procedure, CallFuncExprRemotely(CallStmt *callStmt, DistObjectCacheEntry *procedure,
FuncExpr *funcExpr, DestReceiver *dest) FuncExpr *funcExpr, DestReceiver *dest)
{ {
Oid colocatedRelationId = InvalidOid;
Node *partitionValueNode = NULL;
Const *partitionValue = NULL;
Datum partitionValueDatum = 0;
ShardInterval *shardInterval = NULL;
List *placementList = NIL;
DistTableCacheEntry *distTable = NULL;
Var *partitionColumn = NULL;
ShardPlacement *placement = NULL;
WorkerNode *workerNode = NULL;
StringInfo callCommand = NULL;
if (IsMultiStatementTransaction()) if (IsMultiStatementTransaction())
{ {
ereport(DEBUG1, (errmsg("cannot push down CALL in multi-statement transaction"))); ereport(DEBUG1, (errmsg("cannot push down CALL in multi-statement transaction")));
return false; return false;
} }
colocatedRelationId = ColocatedTableId(procedure->colocationId); Oid colocatedRelationId = ColocatedTableId(procedure->colocationId);
if (colocatedRelationId == InvalidOid) if (colocatedRelationId == InvalidOid)
{ {
ereport(DEBUG1, (errmsg("stored procedure does not have co-located tables"))); ereport(DEBUG1, (errmsg("stored procedure does not have co-located tables")));
@ -107,8 +95,8 @@ CallFuncExprRemotely(CallStmt *callStmt, DistObjectCacheEntry *procedure,
return false; return false;
} }
distTable = DistributedTableCacheEntry(colocatedRelationId); DistTableCacheEntry *distTable = DistributedTableCacheEntry(colocatedRelationId);
partitionColumn = distTable->partitionColumn; Var *partitionColumn = distTable->partitionColumn;
if (partitionColumn == NULL) if (partitionColumn == NULL)
{ {
/* This can happen if colocated with a reference table. Punt for now. */ /* This can happen if colocated with a reference table. Punt for now. */
@ -117,7 +105,7 @@ CallFuncExprRemotely(CallStmt *callStmt, DistObjectCacheEntry *procedure,
return false; return false;
} }
partitionValueNode = (Node *) list_nth(funcExpr->args, Node *partitionValueNode = (Node *) list_nth(funcExpr->args,
procedure->distributionArgIndex); procedure->distributionArgIndex);
partitionValueNode = strip_implicit_coercions(partitionValueNode); partitionValueNode = strip_implicit_coercions(partitionValueNode);
if (!IsA(partitionValueNode, Const)) if (!IsA(partitionValueNode, Const))
@ -125,9 +113,9 @@ CallFuncExprRemotely(CallStmt *callStmt, DistObjectCacheEntry *procedure,
ereport(DEBUG1, (errmsg("distribution argument value must be a constant"))); ereport(DEBUG1, (errmsg("distribution argument value must be a constant")));
return false; return false;
} }
partitionValue = (Const *) partitionValueNode; Const *partitionValue = (Const *) partitionValueNode;
partitionValueDatum = partitionValue->constvalue; Datum partitionValueDatum = partitionValue->constvalue;
if (partitionValue->consttype != partitionColumn->vartype) if (partitionValue->consttype != partitionColumn->vartype)
{ {
CopyCoercionData coercionData; CopyCoercionData coercionData;
@ -138,14 +126,14 @@ CallFuncExprRemotely(CallStmt *callStmt, DistObjectCacheEntry *procedure,
partitionValueDatum = CoerceColumnValue(partitionValueDatum, &coercionData); partitionValueDatum = CoerceColumnValue(partitionValueDatum, &coercionData);
} }
shardInterval = FindShardInterval(partitionValueDatum, distTable); ShardInterval *shardInterval = FindShardInterval(partitionValueDatum, distTable);
if (shardInterval == NULL) if (shardInterval == NULL)
{ {
ereport(DEBUG1, (errmsg("cannot push down call, failed to find shard interval"))); ereport(DEBUG1, (errmsg("cannot push down call, failed to find shard interval")));
return false; return false;
} }
placementList = FinalizedShardPlacementList(shardInterval->shardId); List *placementList = FinalizedShardPlacementList(shardInterval->shardId);
if (list_length(placementList) != 1) if (list_length(placementList) != 1)
{ {
/* punt on this for now */ /* punt on this for now */
@ -154,8 +142,8 @@ CallFuncExprRemotely(CallStmt *callStmt, DistObjectCacheEntry *procedure,
return false; return false;
} }
placement = (ShardPlacement *) linitial(placementList); ShardPlacement *placement = (ShardPlacement *) linitial(placementList);
workerNode = FindWorkerNode(placement->nodeName, placement->nodePort); WorkerNode *workerNode = FindWorkerNode(placement->nodeName, placement->nodePort);
if (workerNode == NULL || !workerNode->hasMetadata || !workerNode->metadataSynced) if (workerNode == NULL || !workerNode->hasMetadata || !workerNode->metadataSynced)
{ {
ereport(DEBUG1, (errmsg("there is no worker node with metadata"))); ereport(DEBUG1, (errmsg("there is no worker node with metadata")));
@ -165,7 +153,7 @@ CallFuncExprRemotely(CallStmt *callStmt, DistObjectCacheEntry *procedure,
ereport(DEBUG1, (errmsg("pushing down the procedure"))); ereport(DEBUG1, (errmsg("pushing down the procedure")));
/* build remote command with fully qualified names */ /* build remote command with fully qualified names */
callCommand = makeStringInfo(); StringInfo callCommand = makeStringInfo();
appendStringInfo(callCommand, "CALL %s", pg_get_rule_expr((Node *) funcExpr)); appendStringInfo(callCommand, "CALL %s", pg_get_rule_expr((Node *) funcExpr));
{ {

View File

@ -28,10 +28,9 @@ PlanClusterStmt(ClusterStmt *clusterStmt, const char *clusterCommand)
} }
else else
{ {
Oid relationId = InvalidOid;
bool missingOK = false; bool missingOK = false;
relationId = RangeVarGetRelid(clusterStmt->relation, AccessShareLock, Oid relationId = RangeVarGetRelid(clusterStmt->relation, AccessShareLock,
missingOK); missingOK);
if (OidIsValid(relationId)) if (OidIsValid(relationId))

View File

@ -126,14 +126,10 @@ master_create_distributed_table(PG_FUNCTION_ARGS)
text *distributionColumnText = PG_GETARG_TEXT_P(1); text *distributionColumnText = PG_GETARG_TEXT_P(1);
Oid distributionMethodOid = PG_GETARG_OID(2); Oid distributionMethodOid = PG_GETARG_OID(2);
char *distributionColumnName = NULL;
Var *distributionColumn = NULL;
char distributionMethod = 0;
char *colocateWithTableName = NULL; char *colocateWithTableName = NULL;
bool viaDeprecatedAPI = true; bool viaDeprecatedAPI = true;
ObjectAddress tableAddress = { 0 }; ObjectAddress tableAddress = { 0 };
Relation relation = NULL;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
EnsureCoordinator(); EnsureCoordinator();
@ -153,7 +149,7 @@ master_create_distributed_table(PG_FUNCTION_ARGS)
* sense of this table until we've committed, and we don't want multiple * sense of this table until we've committed, and we don't want multiple
* backends manipulating this relation. * backends manipulating this relation.
*/ */
relation = try_relation_open(relationId, ExclusiveLock); Relation relation = try_relation_open(relationId, ExclusiveLock);
if (relation == NULL) if (relation == NULL)
{ {
@ -168,10 +164,10 @@ master_create_distributed_table(PG_FUNCTION_ARGS)
*/ */
EnsureRelationKindSupported(relationId); EnsureRelationKindSupported(relationId);
distributionColumnName = text_to_cstring(distributionColumnText); char *distributionColumnName = text_to_cstring(distributionColumnText);
distributionColumn = BuildDistributionKeyFromColumnName(relation, Var *distributionColumn = BuildDistributionKeyFromColumnName(relation,
distributionColumnName); distributionColumnName);
distributionMethod = LookupDistributionMethod(distributionMethodOid); char distributionMethod = LookupDistributionMethod(distributionMethodOid);
CreateDistributedTable(relationId, distributionColumn, distributionMethod, CreateDistributedTable(relationId, distributionColumn, distributionMethod,
colocateWithTableName, viaDeprecatedAPI); colocateWithTableName, viaDeprecatedAPI);
@ -190,28 +186,18 @@ master_create_distributed_table(PG_FUNCTION_ARGS)
Datum Datum
create_distributed_table(PG_FUNCTION_ARGS) create_distributed_table(PG_FUNCTION_ARGS)
{ {
Oid relationId = InvalidOid;
text *distributionColumnText = NULL;
Oid distributionMethodOid = InvalidOid;
text *colocateWithTableNameText = NULL;
ObjectAddress tableAddress = { 0 }; ObjectAddress tableAddress = { 0 };
Relation relation = NULL;
char *distributionColumnName = NULL;
Var *distributionColumn = NULL;
char distributionMethod = 0;
char *colocateWithTableName = NULL;
bool viaDeprecatedAPI = false; bool viaDeprecatedAPI = false;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
EnsureCoordinator(); EnsureCoordinator();
relationId = PG_GETARG_OID(0); Oid relationId = PG_GETARG_OID(0);
distributionColumnText = PG_GETARG_TEXT_P(1); text *distributionColumnText = PG_GETARG_TEXT_P(1);
distributionMethodOid = PG_GETARG_OID(2); Oid distributionMethodOid = PG_GETARG_OID(2);
colocateWithTableNameText = PG_GETARG_TEXT_P(3); text *colocateWithTableNameText = PG_GETARG_TEXT_P(3);
EnsureTableOwner(relationId); EnsureTableOwner(relationId);
@ -229,7 +215,7 @@ create_distributed_table(PG_FUNCTION_ARGS)
* sense of this table until we've committed, and we don't want multiple * sense of this table until we've committed, and we don't want multiple
* backends manipulating this relation. * backends manipulating this relation.
*/ */
relation = try_relation_open(relationId, ExclusiveLock); Relation relation = try_relation_open(relationId, ExclusiveLock);
if (relation == NULL) if (relation == NULL)
{ {
@ -244,12 +230,12 @@ create_distributed_table(PG_FUNCTION_ARGS)
*/ */
EnsureRelationKindSupported(relationId); EnsureRelationKindSupported(relationId);
distributionColumnName = text_to_cstring(distributionColumnText); char *distributionColumnName = text_to_cstring(distributionColumnText);
distributionColumn = BuildDistributionKeyFromColumnName(relation, Var *distributionColumn = BuildDistributionKeyFromColumnName(relation,
distributionColumnName); distributionColumnName);
distributionMethod = LookupDistributionMethod(distributionMethodOid); char distributionMethod = LookupDistributionMethod(distributionMethodOid);
colocateWithTableName = text_to_cstring(colocateWithTableNameText); char *colocateWithTableName = text_to_cstring(colocateWithTableNameText);
CreateDistributedTable(relationId, distributionColumn, distributionMethod, CreateDistributedTable(relationId, distributionColumn, distributionMethod,
colocateWithTableName, viaDeprecatedAPI); colocateWithTableName, viaDeprecatedAPI);
@ -270,10 +256,7 @@ create_reference_table(PG_FUNCTION_ARGS)
{ {
Oid relationId = PG_GETARG_OID(0); Oid relationId = PG_GETARG_OID(0);
Relation relation = NULL;
char *colocateWithTableName = NULL; char *colocateWithTableName = NULL;
List *workerNodeList = NIL;
int workerCount = 0;
Var *distributionColumn = NULL; Var *distributionColumn = NULL;
ObjectAddress tableAddress = { 0 }; ObjectAddress tableAddress = { 0 };
@ -297,7 +280,7 @@ create_reference_table(PG_FUNCTION_ARGS)
* sense of this table until we've committed, and we don't want multiple * sense of this table until we've committed, and we don't want multiple
* backends manipulating this relation. * backends manipulating this relation.
*/ */
relation = relation_open(relationId, ExclusiveLock); Relation relation = relation_open(relationId, ExclusiveLock);
/* /*
* We should do this check here since the codes in the following lines rely * We should do this check here since the codes in the following lines rely
@ -306,8 +289,8 @@ create_reference_table(PG_FUNCTION_ARGS)
*/ */
EnsureRelationKindSupported(relationId); EnsureRelationKindSupported(relationId);
workerNodeList = ActivePrimaryNodeList(ShareLock); List *workerNodeList = ActivePrimaryNodeList(ShareLock);
workerCount = list_length(workerNodeList); int workerCount = list_length(workerNodeList);
/* if there are no workers, error out */ /* if there are no workers, error out */
if (workerCount == 0) if (workerCount == 0)
@ -344,27 +327,24 @@ void
CreateDistributedTable(Oid relationId, Var *distributionColumn, char distributionMethod, CreateDistributedTable(Oid relationId, Var *distributionColumn, char distributionMethod,
char *colocateWithTableName, bool viaDeprecatedAPI) char *colocateWithTableName, bool viaDeprecatedAPI)
{ {
char replicationModel = REPLICATION_MODEL_INVALID; char replicationModel = AppropriateReplicationModel(distributionMethod,
uint32 colocationId = INVALID_COLOCATION_ID; viaDeprecatedAPI);
Oid colocatedTableId = InvalidOid;
bool localTableEmpty = false;
replicationModel = AppropriateReplicationModel(distributionMethod, viaDeprecatedAPI);
/* /*
* ColocationIdForNewTable assumes caller acquires lock on relationId. In our case, * ColocationIdForNewTable assumes caller acquires lock on relationId. In our case,
* our caller already acquired lock on relationId. * our caller already acquired lock on relationId.
*/ */
colocationId = ColocationIdForNewTable(relationId, distributionColumn, uint32 colocationId = ColocationIdForNewTable(relationId, distributionColumn,
distributionMethod, replicationModel, distributionMethod, replicationModel,
colocateWithTableName, viaDeprecatedAPI); colocateWithTableName,
viaDeprecatedAPI);
EnsureRelationCanBeDistributed(relationId, distributionColumn, distributionMethod, EnsureRelationCanBeDistributed(relationId, distributionColumn, distributionMethod,
colocationId, replicationModel, viaDeprecatedAPI); colocationId, replicationModel, viaDeprecatedAPI);
/* we need to calculate these variables before creating distributed metadata */ /* we need to calculate these variables before creating distributed metadata */
localTableEmpty = LocalTableEmpty(relationId); bool localTableEmpty = LocalTableEmpty(relationId);
colocatedTableId = ColocatedTableId(colocationId); Oid colocatedTableId = ColocatedTableId(colocationId);
/* create an entry for distributed table in pg_dist_partition */ /* create an entry for distributed table in pg_dist_partition */
InsertIntoPgDistPartition(relationId, distributionMethod, distributionColumn, InsertIntoPgDistPartition(relationId, distributionMethod, distributionColumn,
@ -642,9 +622,6 @@ EnsureRelationCanBeDistributed(Oid relationId, Var *distributionColumn,
char distributionMethod, uint32 colocationId, char distributionMethod, uint32 colocationId,
char replicationModel, bool viaDeprecatedAPI) char replicationModel, bool viaDeprecatedAPI)
{ {
Relation relation = NULL;
TupleDesc relationDesc = NULL;
char *relationName = NULL;
Oid parentRelationId = InvalidOid; Oid parentRelationId = InvalidOid;
EnsureTableNotDistributed(relationId); EnsureTableNotDistributed(relationId);
@ -652,9 +629,9 @@ EnsureRelationCanBeDistributed(Oid relationId, Var *distributionColumn,
EnsureReplicationSettings(InvalidOid, replicationModel); EnsureReplicationSettings(InvalidOid, replicationModel);
/* we assume callers took necessary locks */ /* we assume callers took necessary locks */
relation = relation_open(relationId, NoLock); Relation relation = relation_open(relationId, NoLock);
relationDesc = RelationGetDescr(relation); TupleDesc relationDesc = RelationGetDescr(relation);
relationName = RelationGetRelationName(relation); char *relationName = RelationGetRelationName(relation);
if (!RelationUsesHeapAccessMethodOrNone(relation)) if (!RelationUsesHeapAccessMethodOrNone(relation))
{ {
@ -805,7 +782,6 @@ EnsureTableCanBeColocatedWith(Oid relationId, char replicationModel,
char sourceDistributionMethod = sourceTableEntry->partitionMethod; char sourceDistributionMethod = sourceTableEntry->partitionMethod;
char sourceReplicationModel = sourceTableEntry->replicationModel; char sourceReplicationModel = sourceTableEntry->replicationModel;
Var *sourceDistributionColumn = DistPartitionKey(sourceRelationId); Var *sourceDistributionColumn = DistPartitionKey(sourceRelationId);
Oid sourceDistributionColumnType = InvalidOid;
if (sourceDistributionMethod != DISTRIBUTE_BY_HASH) if (sourceDistributionMethod != DISTRIBUTE_BY_HASH)
{ {
@ -826,7 +802,7 @@ EnsureTableCanBeColocatedWith(Oid relationId, char replicationModel,
sourceRelationName, relationName))); sourceRelationName, relationName)));
} }
sourceDistributionColumnType = sourceDistributionColumn->vartype; Oid sourceDistributionColumnType = sourceDistributionColumn->vartype;
if (sourceDistributionColumnType != distributionColumnType) if (sourceDistributionColumnType != distributionColumnType)
{ {
char *relationName = get_rel_name(relationId); char *relationName = get_rel_name(relationId);
@ -898,9 +874,8 @@ static void
EnsureTableNotDistributed(Oid relationId) EnsureTableNotDistributed(Oid relationId)
{ {
char *relationName = get_rel_name(relationId); char *relationName = get_rel_name(relationId);
bool isDistributedTable = false;
isDistributedTable = IsDistributedTable(relationId); bool isDistributedTable = IsDistributedTable(relationId);
if (isDistributedTable) if (isDistributedTable)
{ {
@ -949,20 +924,18 @@ EnsureReplicationSettings(Oid relationId, char replicationModel)
static char static char
LookupDistributionMethod(Oid distributionMethodOid) LookupDistributionMethod(Oid distributionMethodOid)
{ {
HeapTuple enumTuple = NULL;
Form_pg_enum enumForm = NULL;
char distributionMethod = 0; char distributionMethod = 0;
const char *enumLabel = NULL;
enumTuple = SearchSysCache1(ENUMOID, ObjectIdGetDatum(distributionMethodOid)); HeapTuple enumTuple = SearchSysCache1(ENUMOID, ObjectIdGetDatum(
distributionMethodOid));
if (!HeapTupleIsValid(enumTuple)) if (!HeapTupleIsValid(enumTuple))
{ {
ereport(ERROR, (errmsg("invalid internal value for enum: %u", ereport(ERROR, (errmsg("invalid internal value for enum: %u",
distributionMethodOid))); distributionMethodOid)));
} }
enumForm = (Form_pg_enum) GETSTRUCT(enumTuple); Form_pg_enum enumForm = (Form_pg_enum) GETSTRUCT(enumTuple);
enumLabel = NameStr(enumForm->enumlabel); const char *enumLabel = NameStr(enumForm->enumlabel);
if (strncmp(enumLabel, "append", NAMEDATALEN) == 0) if (strncmp(enumLabel, "append", NAMEDATALEN) == 0)
{ {
@ -997,9 +970,6 @@ static Oid
SupportFunctionForColumn(Var *partitionColumn, Oid accessMethodId, SupportFunctionForColumn(Var *partitionColumn, Oid accessMethodId,
int16 supportFunctionNumber) int16 supportFunctionNumber)
{ {
Oid operatorFamilyId = InvalidOid;
Oid supportFunctionOid = InvalidOid;
Oid operatorClassInputType = InvalidOid;
Oid columnOid = partitionColumn->vartype; Oid columnOid = partitionColumn->vartype;
Oid operatorClassId = GetDefaultOpClass(columnOid, accessMethodId); Oid operatorClassId = GetDefaultOpClass(columnOid, accessMethodId);
@ -1014,9 +984,9 @@ SupportFunctionForColumn(Var *partitionColumn, Oid accessMethodId,
" class defined."))); " class defined.")));
} }
operatorFamilyId = get_opclass_family(operatorClassId); Oid operatorFamilyId = get_opclass_family(operatorClassId);
operatorClassInputType = get_opclass_input_type(operatorClassId); Oid operatorClassInputType = get_opclass_input_type(operatorClassId);
supportFunctionOid = get_opfamily_proc(operatorFamilyId, operatorClassInputType, Oid supportFunctionOid = get_opfamily_proc(operatorFamilyId, operatorClassInputType,
operatorClassInputType, operatorClassInputType,
supportFunctionNumber); supportFunctionNumber);
@ -1037,13 +1007,8 @@ LocalTableEmpty(Oid tableId)
char *tableName = get_rel_name(tableId); char *tableName = get_rel_name(tableId);
char *tableQualifiedName = quote_qualified_identifier(schemaName, tableName); char *tableQualifiedName = quote_qualified_identifier(schemaName, tableName);
int spiConnectionResult = 0;
int spiQueryResult = 0;
StringInfo selectExistQueryString = makeStringInfo(); StringInfo selectExistQueryString = makeStringInfo();
HeapTuple tuple = NULL;
Datum hasDataDatum = 0;
bool localTableEmpty = false;
bool columnNull = false; bool columnNull = false;
bool readOnly = true; bool readOnly = true;
@ -1052,7 +1017,7 @@ LocalTableEmpty(Oid tableId)
AssertArg(!IsDistributedTable(tableId)); AssertArg(!IsDistributedTable(tableId));
spiConnectionResult = SPI_connect(); int spiConnectionResult = SPI_connect();
if (spiConnectionResult != SPI_OK_CONNECT) if (spiConnectionResult != SPI_OK_CONNECT)
{ {
ereport(ERROR, (errmsg("could not connect to SPI manager"))); ereport(ERROR, (errmsg("could not connect to SPI manager")));
@ -1060,7 +1025,7 @@ LocalTableEmpty(Oid tableId)
appendStringInfo(selectExistQueryString, SELECT_EXIST_QUERY, tableQualifiedName); appendStringInfo(selectExistQueryString, SELECT_EXIST_QUERY, tableQualifiedName);
spiQueryResult = SPI_execute(selectExistQueryString->data, readOnly, 0); int spiQueryResult = SPI_execute(selectExistQueryString->data, readOnly, 0);
if (spiQueryResult != SPI_OK_SELECT) if (spiQueryResult != SPI_OK_SELECT)
{ {
ereport(ERROR, (errmsg("execution was not successful \"%s\"", ereport(ERROR, (errmsg("execution was not successful \"%s\"",
@ -1070,9 +1035,10 @@ LocalTableEmpty(Oid tableId)
/* we expect that SELECT EXISTS query will return single value in a single row */ /* we expect that SELECT EXISTS query will return single value in a single row */
Assert(SPI_processed == 1); Assert(SPI_processed == 1);
tuple = SPI_tuptable->vals[rowId]; HeapTuple tuple = SPI_tuptable->vals[rowId];
hasDataDatum = SPI_getbinval(tuple, SPI_tuptable->tupdesc, attributeId, &columnNull); Datum hasDataDatum = SPI_getbinval(tuple, SPI_tuptable->tupdesc, attributeId,
localTableEmpty = !DatumGetBool(hasDataDatum); &columnNull);
bool localTableEmpty = !DatumGetBool(hasDataDatum);
SPI_finish(); SPI_finish();
@ -1145,13 +1111,12 @@ CanUseExclusiveConnections(Oid relationId, bool localTableEmpty)
void void
CreateTruncateTrigger(Oid relationId) CreateTruncateTrigger(Oid relationId)
{ {
CreateTrigStmt *trigger = NULL;
StringInfo triggerName = makeStringInfo(); StringInfo triggerName = makeStringInfo();
bool internal = true; bool internal = true;
appendStringInfo(triggerName, "truncate_trigger"); appendStringInfo(triggerName, "truncate_trigger");
trigger = makeNode(CreateTrigStmt); CreateTrigStmt *trigger = makeNode(CreateTrigStmt);
trigger->trigname = triggerName->data; trigger->trigname = triggerName->data;
trigger->relation = NULL; trigger->relation = NULL;
trigger->funcname = SystemFuncName("citus_truncate_trigger"); trigger->funcname = SystemFuncName("citus_truncate_trigger");
@ -1232,9 +1197,7 @@ CopyLocalDataIntoShards(Oid distributedRelationId)
HeapScanDesc scan = NULL; HeapScanDesc scan = NULL;
#endif #endif
HeapTuple tuple = NULL; HeapTuple tuple = NULL;
ExprContext *econtext = NULL;
MemoryContext oldContext = NULL; MemoryContext oldContext = NULL;
TupleTableSlot *slot = NULL;
uint64 rowsCopied = 0; uint64 rowsCopied = 0;
/* take an ExclusiveLock to block all operations except SELECT */ /* take an ExclusiveLock to block all operations except SELECT */
@ -1264,7 +1227,8 @@ CopyLocalDataIntoShards(Oid distributedRelationId)
/* get the table columns */ /* get the table columns */
tupleDescriptor = RelationGetDescr(distributedRelation); tupleDescriptor = RelationGetDescr(distributedRelation);
slot = MakeSingleTupleTableSlotCompat(tupleDescriptor, &TTSOpsHeapTuple); TupleTableSlot *slot = MakeSingleTupleTableSlotCompat(tupleDescriptor,
&TTSOpsHeapTuple);
columnNameList = TupleDescColumnNameList(tupleDescriptor); columnNameList = TupleDescColumnNameList(tupleDescriptor);
/* determine the partition column in the tuple descriptor */ /* determine the partition column in the tuple descriptor */
@ -1276,7 +1240,7 @@ CopyLocalDataIntoShards(Oid distributedRelationId)
/* initialise per-tuple memory context */ /* initialise per-tuple memory context */
estate = CreateExecutorState(); estate = CreateExecutorState();
econtext = GetPerTupleExprContext(estate); ExprContext *econtext = GetPerTupleExprContext(estate);
econtext->ecxt_scantuple = slot; econtext->ecxt_scantuple = slot;
copyDest = copyDest =
@ -1362,9 +1326,8 @@ static List *
TupleDescColumnNameList(TupleDesc tupleDescriptor) TupleDescColumnNameList(TupleDesc tupleDescriptor)
{ {
List *columnNameList = NIL; List *columnNameList = NIL;
int columnIndex = 0;
for (columnIndex = 0; columnIndex < tupleDescriptor->natts; columnIndex++) for (int columnIndex = 0; columnIndex < tupleDescriptor->natts; columnIndex++)
{ {
Form_pg_attribute currentColumn = TupleDescAttr(tupleDescriptor, columnIndex); Form_pg_attribute currentColumn = TupleDescAttr(tupleDescriptor, columnIndex);
char *columnName = NameStr(currentColumn->attname); char *columnName = NameStr(currentColumn->attname);
@ -1392,9 +1355,7 @@ TupleDescColumnNameList(TupleDesc tupleDescriptor)
static bool static bool
RelationUsesIdentityColumns(TupleDesc relationDesc) RelationUsesIdentityColumns(TupleDesc relationDesc)
{ {
int attributeIndex = 0; for (int attributeIndex = 0; attributeIndex < relationDesc->natts; attributeIndex++)
for (attributeIndex = 0; attributeIndex < relationDesc->natts; attributeIndex++)
{ {
Form_pg_attribute attributeForm = TupleDescAttr(relationDesc, attributeIndex); Form_pg_attribute attributeForm = TupleDescAttr(relationDesc, attributeIndex);

View File

@ -50,7 +50,6 @@ void
EnsureDependenciesExistsOnAllNodes(const ObjectAddress *target) EnsureDependenciesExistsOnAllNodes(const ObjectAddress *target)
{ {
/* local variables to work with dependencies */ /* local variables to work with dependencies */
List *dependencies = NIL;
List *dependenciesWithCommands = NIL; List *dependenciesWithCommands = NIL;
ListCell *dependencyCell = NULL; ListCell *dependencyCell = NULL;
@ -58,13 +57,12 @@ EnsureDependenciesExistsOnAllNodes(const ObjectAddress *target)
List *ddlCommands = NULL; List *ddlCommands = NULL;
/* local variables to work with worker nodes */ /* local variables to work with worker nodes */
List *workerNodeList = NULL;
ListCell *workerNodeCell = NULL; ListCell *workerNodeCell = NULL;
/* /*
* collect all dependencies in creation order and get their ddl commands * collect all dependencies in creation order and get their ddl commands
*/ */
dependencies = GetDependenciesForObject(target); List *dependencies = GetDependenciesForObject(target);
foreach(dependencyCell, dependencies) foreach(dependencyCell, dependencies)
{ {
ObjectAddress *dependency = (ObjectAddress *) lfirst(dependencyCell); ObjectAddress *dependency = (ObjectAddress *) lfirst(dependencyCell);
@ -94,7 +92,7 @@ EnsureDependenciesExistsOnAllNodes(const ObjectAddress *target)
* either get it now, or get it in master_add_node after this transaction finishes and * either get it now, or get it in master_add_node after this transaction finishes and
* the pg_dist_object record becomes visible. * the pg_dist_object record becomes visible.
*/ */
workerNodeList = ActivePrimaryWorkerNodeList(RowShareLock); List *workerNodeList = ActivePrimaryWorkerNodeList(RowShareLock);
/* /*
* right after we acquired the lock we mark our objects as distributed, these changes * right after we acquired the lock we mark our objects as distributed, these changes
@ -216,13 +214,12 @@ void
ReplicateAllDependenciesToNode(const char *nodeName, int nodePort) ReplicateAllDependenciesToNode(const char *nodeName, int nodePort)
{ {
ListCell *dependencyCell = NULL; ListCell *dependencyCell = NULL;
List *dependencies = NIL;
List *ddlCommands = NIL; List *ddlCommands = NIL;
/* /*
* collect all dependencies in creation order and get their ddl commands * collect all dependencies in creation order and get their ddl commands
*/ */
dependencies = GetDistributedObjectAddressList(); List *dependencies = GetDistributedObjectAddressList();
/* /*
* Depending on changes in the environment, such as the enable_object_propagation guc * Depending on changes in the environment, such as the enable_object_propagation guc

View File

@ -126,8 +126,6 @@ static void
MasterRemoveDistributedTableMetadataFromWorkers(Oid relationId, char *schemaName, MasterRemoveDistributedTableMetadataFromWorkers(Oid relationId, char *schemaName,
char *tableName) char *tableName)
{ {
char *deleteDistributionCommand = NULL;
/* /*
* The SQL_DROP trigger calls this function even for tables that are * The SQL_DROP trigger calls this function even for tables that are
* not distributed. In that case, silently ignore. This is not very * not distributed. In that case, silently ignore. This is not very
@ -147,6 +145,6 @@ MasterRemoveDistributedTableMetadataFromWorkers(Oid relationId, char *schemaName
} }
/* drop the distributed table metadata on the workers */ /* drop the distributed table metadata on the workers */
deleteDistributionCommand = DistributionDeleteCommand(schemaName, tableName); char *deleteDistributionCommand = DistributionDeleteCommand(schemaName, tableName);
SendCommandToWorkers(WORKERS_WITH_METADATA, deleteDistributionCommand); SendCommandToWorkers(WORKERS_WITH_METADATA, deleteDistributionCommand);
} }

View File

@ -82,8 +82,6 @@ ErrorIfUnstableCreateOrAlterExtensionStmt(Node *parseTree)
static char * static char *
ExtractNewExtensionVersion(Node *parseTree) ExtractNewExtensionVersion(Node *parseTree)
{ {
Value *newVersionValue = NULL;
List *optionsList = NIL; List *optionsList = NIL;
if (IsA(parseTree, CreateExtensionStmt)) if (IsA(parseTree, CreateExtensionStmt))
@ -100,7 +98,7 @@ ExtractNewExtensionVersion(Node *parseTree)
Assert(false); Assert(false);
} }
newVersionValue = GetExtensionOption(optionsList, "new_version"); Value *newVersionValue = GetExtensionOption(optionsList, "new_version");
/* return target string safely */ /* return target string safely */
if (newVersionValue) if (newVersionValue)
@ -126,9 +124,6 @@ ExtractNewExtensionVersion(Node *parseTree)
List * List *
PlanCreateExtensionStmt(CreateExtensionStmt *createExtensionStmt, const char *queryString) PlanCreateExtensionStmt(CreateExtensionStmt *createExtensionStmt, const char *queryString)
{ {
List *commands = NIL;
const char *createExtensionStmtSql = NULL;
if (!ShouldPropagateExtensionCommand((Node *) createExtensionStmt)) if (!ShouldPropagateExtensionCommand((Node *) createExtensionStmt))
{ {
return NIL; return NIL;
@ -168,13 +163,13 @@ PlanCreateExtensionStmt(CreateExtensionStmt *createExtensionStmt, const char *qu
*/ */
AddSchemaFieldIfMissing(createExtensionStmt); AddSchemaFieldIfMissing(createExtensionStmt);
createExtensionStmtSql = DeparseTreeNode((Node *) createExtensionStmt); const char *createExtensionStmtSql = DeparseTreeNode((Node *) createExtensionStmt);
/* /*
* To prevent recursive propagation in mx architecture, we disable ddl * To prevent recursive propagation in mx architecture, we disable ddl
* propagation before sending the command to workers. * propagation before sending the command to workers.
*/ */
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) createExtensionStmtSql, (void *) createExtensionStmtSql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -229,8 +224,6 @@ void
ProcessCreateExtensionStmt(CreateExtensionStmt *createExtensionStmt, const ProcessCreateExtensionStmt(CreateExtensionStmt *createExtensionStmt, const
char *queryString) char *queryString)
{ {
const ObjectAddress *extensionAddress = NULL;
if (!ShouldPropagateExtensionCommand((Node *) createExtensionStmt)) if (!ShouldPropagateExtensionCommand((Node *) createExtensionStmt))
{ {
return; return;
@ -246,7 +239,8 @@ ProcessCreateExtensionStmt(CreateExtensionStmt *createExtensionStmt, const
return; return;
} }
extensionAddress = GetObjectAddressFromParseTree((Node *) createExtensionStmt, false); const ObjectAddress *extensionAddress = GetObjectAddressFromParseTree(
(Node *) createExtensionStmt, false);
EnsureDependenciesExistsOnAllNodes(extensionAddress); EnsureDependenciesExistsOnAllNodes(extensionAddress);
@ -267,11 +261,6 @@ PlanDropExtensionStmt(DropStmt *dropStmt, const char *queryString)
{ {
List *allDroppedExtensions = dropStmt->objects; List *allDroppedExtensions = dropStmt->objects;
List *distributedExtensions = NIL;
List *distributedExtensionAddresses = NIL;
List *commands = NIL;
const char *deparsedStmt = NULL;
ListCell *addressCell = NULL; ListCell *addressCell = NULL;
@ -281,7 +270,7 @@ PlanDropExtensionStmt(DropStmt *dropStmt, const char *queryString)
} }
/* get distributed extensions to be dropped in worker nodes as well */ /* get distributed extensions to be dropped in worker nodes as well */
distributedExtensions = FilterDistributedExtensions(allDroppedExtensions); List *distributedExtensions = FilterDistributedExtensions(allDroppedExtensions);
if (list_length(distributedExtensions) <= 0) if (list_length(distributedExtensions) <= 0)
{ {
@ -308,7 +297,7 @@ PlanDropExtensionStmt(DropStmt *dropStmt, const char *queryString)
*/ */
EnsureSequentialModeForExtensionDDL(); EnsureSequentialModeForExtensionDDL();
distributedExtensionAddresses = ExtensionNameListToObjectAddressList( List *distributedExtensionAddresses = ExtensionNameListToObjectAddressList(
distributedExtensions); distributedExtensions);
/* unmark each distributed extension */ /* unmark each distributed extension */
@ -326,7 +315,7 @@ PlanDropExtensionStmt(DropStmt *dropStmt, const char *queryString)
* its execution. * its execution.
*/ */
dropStmt->objects = distributedExtensions; dropStmt->objects = distributedExtensions;
deparsedStmt = DeparseTreeNode((Node *) dropStmt); const char *deparsedStmt = DeparseTreeNode((Node *) dropStmt);
dropStmt->objects = allDroppedExtensions; dropStmt->objects = allDroppedExtensions;
@ -334,7 +323,7 @@ PlanDropExtensionStmt(DropStmt *dropStmt, const char *queryString)
* To prevent recursive propagation in mx architecture, we disable ddl * To prevent recursive propagation in mx architecture, we disable ddl
* propagation before sending the command to workers. * propagation before sending the command to workers.
*/ */
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) deparsedStmt, (void *) deparsedStmt,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -425,9 +414,6 @@ List *
PlanAlterExtensionSchemaStmt(AlterObjectSchemaStmt *alterExtensionStmt, const PlanAlterExtensionSchemaStmt(AlterObjectSchemaStmt *alterExtensionStmt, const
char *queryString) char *queryString)
{ {
const char *alterExtensionStmtSql = NULL;
List *commands = NIL;
if (!ShouldPropagateExtensionCommand((Node *) alterExtensionStmt)) if (!ShouldPropagateExtensionCommand((Node *) alterExtensionStmt))
{ {
return NIL; return NIL;
@ -451,13 +437,13 @@ PlanAlterExtensionSchemaStmt(AlterObjectSchemaStmt *alterExtensionStmt, const
*/ */
EnsureSequentialModeForExtensionDDL(); EnsureSequentialModeForExtensionDDL();
alterExtensionStmtSql = DeparseTreeNode((Node *) alterExtensionStmt); const char *alterExtensionStmtSql = DeparseTreeNode((Node *) alterExtensionStmt);
/* /*
* To prevent recursive propagation in mx architecture, we disable ddl * To prevent recursive propagation in mx architecture, we disable ddl
* propagation before sending the command to workers. * propagation before sending the command to workers.
*/ */
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) alterExtensionStmtSql, (void *) alterExtensionStmtSql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -474,9 +460,8 @@ void
ProcessAlterExtensionSchemaStmt(AlterObjectSchemaStmt *alterExtensionStmt, const ProcessAlterExtensionSchemaStmt(AlterObjectSchemaStmt *alterExtensionStmt, const
char *queryString) char *queryString)
{ {
const ObjectAddress *extensionAddress = NULL; const ObjectAddress *extensionAddress = GetObjectAddressFromParseTree(
(Node *) alterExtensionStmt, false);
extensionAddress = GetObjectAddressFromParseTree((Node *) alterExtensionStmt, false);
if (!ShouldPropagateExtensionCommand((Node *) alterExtensionStmt)) if (!ShouldPropagateExtensionCommand((Node *) alterExtensionStmt))
{ {
@ -495,9 +480,6 @@ List *
PlanAlterExtensionUpdateStmt(AlterExtensionStmt *alterExtensionStmt, const PlanAlterExtensionUpdateStmt(AlterExtensionStmt *alterExtensionStmt, const
char *queryString) char *queryString)
{ {
const char *alterExtensionStmtSql = NULL;
List *commands = NIL;
if (!ShouldPropagateExtensionCommand((Node *) alterExtensionStmt)) if (!ShouldPropagateExtensionCommand((Node *) alterExtensionStmt))
{ {
return NIL; return NIL;
@ -522,13 +504,13 @@ PlanAlterExtensionUpdateStmt(AlterExtensionStmt *alterExtensionStmt, const
*/ */
EnsureSequentialModeForExtensionDDL(); EnsureSequentialModeForExtensionDDL();
alterExtensionStmtSql = DeparseTreeNode((Node *) alterExtensionStmt); const char *alterExtensionStmtSql = DeparseTreeNode((Node *) alterExtensionStmt);
/* /*
* To prevent recursive propagation in mx architecture, we disable ddl * To prevent recursive propagation in mx architecture, we disable ddl
* propagation before sending the command to workers. * propagation before sending the command to workers.
*/ */
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) alterExtensionStmtSql, (void *) alterExtensionStmtSql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -711,18 +693,13 @@ IsAlterExtensionSetSchemaCitus(Node *parseTree)
List * List *
CreateExtensionDDLCommand(const ObjectAddress *extensionAddress) CreateExtensionDDLCommand(const ObjectAddress *extensionAddress)
{ {
List *ddlCommands = NIL;
const char *ddlCommand = NULL;
Node *stmt = NULL;
/* generate a statement for creation of the extension in "if not exists" construct */ /* generate a statement for creation of the extension in "if not exists" construct */
stmt = RecreateExtensionStmt(extensionAddress->objectId); Node *stmt = RecreateExtensionStmt(extensionAddress->objectId);
/* capture ddl command for the create statement */ /* capture ddl command for the create statement */
ddlCommand = DeparseTreeNode(stmt); const char *ddlCommand = DeparseTreeNode(stmt);
ddlCommands = list_make1((void *) ddlCommand); List *ddlCommands = list_make1((void *) ddlCommand);
return ddlCommands; return ddlCommands;
} }
@ -747,26 +724,22 @@ RecreateExtensionStmt(Oid extensionOid)
} }
/* schema DefElement related variables */ /* schema DefElement related variables */
Oid extensionSchemaOid = InvalidOid;
char *extensionSchemaName = NULL;
Node *schemaNameArg = NULL;
/* set location to -1 as it is unknown */ /* set location to -1 as it is unknown */
int location = -1; int location = -1;
DefElem *schemaDefElement = NULL;
/* set extension name and if_not_exists fields */ /* set extension name and if_not_exists fields */
createExtensionStmt->extname = extensionName; createExtensionStmt->extname = extensionName;
createExtensionStmt->if_not_exists = true; createExtensionStmt->if_not_exists = true;
/* get schema name that extension was created on */ /* get schema name that extension was created on */
extensionSchemaOid = get_extension_schema(extensionOid); Oid extensionSchemaOid = get_extension_schema(extensionOid);
extensionSchemaName = get_namespace_name(extensionSchemaOid); char *extensionSchemaName = get_namespace_name(extensionSchemaOid);
/* make DefEleme for extensionSchemaName */ /* make DefEleme for extensionSchemaName */
schemaNameArg = (Node *) makeString(extensionSchemaName); Node *schemaNameArg = (Node *) makeString(extensionSchemaName);
schemaDefElement = makeDefElem("schema", schemaNameArg, location); DefElem *schemaDefElement = makeDefElem("schema", schemaNameArg, location);
/* append the schema name DefElem finally */ /* append the schema name DefElem finally */
createExtensionStmt->options = lappend(createExtensionStmt->options, createExtensionStmt->options = lappend(createExtensionStmt->options,
@ -784,15 +757,11 @@ ObjectAddress *
AlterExtensionSchemaStmtObjectAddress(AlterObjectSchemaStmt *alterExtensionSchemaStmt, AlterExtensionSchemaStmtObjectAddress(AlterObjectSchemaStmt *alterExtensionSchemaStmt,
bool missing_ok) bool missing_ok)
{ {
ObjectAddress *extensionAddress = NULL;
Oid extensionOid = InvalidOid;
const char *extensionName = NULL;
Assert(alterExtensionSchemaStmt->objectType == OBJECT_EXTENSION); Assert(alterExtensionSchemaStmt->objectType == OBJECT_EXTENSION);
extensionName = strVal(alterExtensionSchemaStmt->object); const char *extensionName = strVal(alterExtensionSchemaStmt->object);
extensionOid = get_extension_oid(extensionName, missing_ok); Oid extensionOid = get_extension_oid(extensionName, missing_ok);
if (extensionOid == InvalidOid) if (extensionOid == InvalidOid)
{ {
@ -801,7 +770,7 @@ AlterExtensionSchemaStmtObjectAddress(AlterObjectSchemaStmt *alterExtensionSchem
extensionName))); extensionName)));
} }
extensionAddress = palloc0(sizeof(ObjectAddress)); ObjectAddress *extensionAddress = palloc0(sizeof(ObjectAddress));
ObjectAddressSet(*extensionAddress, ExtensionRelationId, extensionOid); ObjectAddressSet(*extensionAddress, ExtensionRelationId, extensionOid);
return extensionAddress; return extensionAddress;
@ -816,13 +785,9 @@ ObjectAddress *
AlterExtensionUpdateStmtObjectAddress(AlterExtensionStmt *alterExtensionStmt, AlterExtensionUpdateStmtObjectAddress(AlterExtensionStmt *alterExtensionStmt,
bool missing_ok) bool missing_ok)
{ {
ObjectAddress *extensionAddress = NULL; const char *extensionName = alterExtensionStmt->extname;
Oid extensionOid = InvalidOid;
const char *extensionName = NULL;
extensionName = alterExtensionStmt->extname; Oid extensionOid = get_extension_oid(extensionName, missing_ok);
extensionOid = get_extension_oid(extensionName, missing_ok);
if (extensionOid == InvalidOid) if (extensionOid == InvalidOid)
{ {
@ -831,7 +796,7 @@ AlterExtensionUpdateStmtObjectAddress(AlterExtensionStmt *alterExtensionStmt,
extensionName))); extensionName)));
} }
extensionAddress = palloc0(sizeof(ObjectAddress)); ObjectAddress *extensionAddress = palloc0(sizeof(ObjectAddress));
ObjectAddressSet(*extensionAddress, ExtensionRelationId, extensionOid); ObjectAddressSet(*extensionAddress, ExtensionRelationId, extensionOid);
return extensionAddress; return extensionAddress;

View File

@ -49,25 +49,21 @@ static void ForeignConstraintFindDistKeys(HeapTuple pgConstraintTuple,
bool bool
ConstraintIsAForeignKeyToReferenceTable(char *constraintName, Oid relationId) ConstraintIsAForeignKeyToReferenceTable(char *constraintName, Oid relationId)
{ {
Relation pgConstraint = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
HeapTuple heapTuple = NULL;
bool foreignKeyToReferenceTable = false; bool foreignKeyToReferenceTable = false;
pgConstraint = heap_open(ConstraintRelationId, AccessShareLock); Relation pgConstraint = heap_open(ConstraintRelationId, AccessShareLock);
ScanKeyInit(&scanKey[0], Anum_pg_constraint_contype, BTEqualStrategyNumber, F_CHAREQ, ScanKeyInit(&scanKey[0], Anum_pg_constraint_contype, BTEqualStrategyNumber, F_CHAREQ,
CharGetDatum(CONSTRAINT_FOREIGN)); CharGetDatum(CONSTRAINT_FOREIGN));
scanDescriptor = systable_beginscan(pgConstraint, InvalidOid, false, SysScanDesc scanDescriptor = systable_beginscan(pgConstraint, InvalidOid, false,
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
while (HeapTupleIsValid(heapTuple)) while (HeapTupleIsValid(heapTuple))
{ {
Oid referencedTableId = InvalidOid;
Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple); Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple);
char *tupleConstraintName = (constraintForm->conname).data; char *tupleConstraintName = (constraintForm->conname).data;
@ -78,7 +74,7 @@ ConstraintIsAForeignKeyToReferenceTable(char *constraintName, Oid relationId)
continue; continue;
} }
referencedTableId = constraintForm->confrelid; Oid referencedTableId = constraintForm->confrelid;
Assert(IsDistributedTable(referencedTableId)); Assert(IsDistributedTable(referencedTableId));
@ -122,11 +118,8 @@ ErrorIfUnsupportedForeignConstraintExists(Relation relation, char referencingDis
Var *referencingDistKey, Var *referencingDistKey,
uint32 referencingColocationId) uint32 referencingColocationId)
{ {
Relation pgConstraint = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
HeapTuple heapTuple = NULL;
Oid referencingTableId = relation->rd_id; Oid referencingTableId = relation->rd_id;
Oid referencedTableId = InvalidOid; Oid referencedTableId = InvalidOid;
@ -145,26 +138,22 @@ ErrorIfUnsupportedForeignConstraintExists(Relation relation, char referencingDis
referencingNotReplicated = (ShardReplicationFactor == 1); referencingNotReplicated = (ShardReplicationFactor == 1);
} }
pgConstraint = heap_open(ConstraintRelationId, AccessShareLock); Relation pgConstraint = heap_open(ConstraintRelationId, AccessShareLock);
ScanKeyInit(&scanKey[0], Anum_pg_constraint_conrelid, BTEqualStrategyNumber, F_OIDEQ, ScanKeyInit(&scanKey[0], Anum_pg_constraint_conrelid, BTEqualStrategyNumber, F_OIDEQ,
relation->rd_id); relation->rd_id);
scanDescriptor = systable_beginscan(pgConstraint, ConstraintRelidTypidNameIndexId, SysScanDesc scanDescriptor = systable_beginscan(pgConstraint,
ConstraintRelidTypidNameIndexId,
true, NULL, true, NULL,
scanKeyCount, scanKey); scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
while (HeapTupleIsValid(heapTuple)) while (HeapTupleIsValid(heapTuple))
{ {
Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple); Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple);
bool referencedIsDistributed = false;
char referencedDistMethod = 0; char referencedDistMethod = 0;
Var *referencedDistKey = NULL; Var *referencedDistKey = NULL;
bool referencingIsReferenceTable = false;
bool referencedIsReferenceTable = false;
int referencingAttrIndex = -1; int referencingAttrIndex = -1;
int referencedAttrIndex = -1; int referencedAttrIndex = -1;
bool referencingColumnsIncludeDistKey = false;
bool foreignConstraintOnDistKey = false;
if (constraintForm->contype != CONSTRAINT_FOREIGN) if (constraintForm->contype != CONSTRAINT_FOREIGN)
{ {
@ -175,7 +164,7 @@ ErrorIfUnsupportedForeignConstraintExists(Relation relation, char referencingDis
referencedTableId = constraintForm->confrelid; referencedTableId = constraintForm->confrelid;
selfReferencingTable = (referencingTableId == referencedTableId); selfReferencingTable = (referencingTableId == referencedTableId);
referencedIsDistributed = IsDistributedTable(referencedTableId); bool referencedIsDistributed = IsDistributedTable(referencedTableId);
if (!referencedIsDistributed && !selfReferencingTable) if (!referencedIsDistributed && !selfReferencingTable)
{ {
ereport(ERROR, (errcode(ERRCODE_INVALID_TABLE_DEFINITION), ereport(ERROR, (errcode(ERRCODE_INVALID_TABLE_DEFINITION),
@ -199,8 +188,8 @@ ErrorIfUnsupportedForeignConstraintExists(Relation relation, char referencingDis
referencedColocationId = referencingColocationId; referencedColocationId = referencingColocationId;
} }
referencingIsReferenceTable = (referencingDistMethod == DISTRIBUTE_BY_NONE); bool referencingIsReferenceTable = (referencingDistMethod == DISTRIBUTE_BY_NONE);
referencedIsReferenceTable = (referencedDistMethod == DISTRIBUTE_BY_NONE); bool referencedIsReferenceTable = (referencedDistMethod == DISTRIBUTE_BY_NONE);
/* /*
@ -250,8 +239,8 @@ ErrorIfUnsupportedForeignConstraintExists(Relation relation, char referencingDis
referencedDistKey, referencedDistKey,
&referencingAttrIndex, &referencingAttrIndex,
&referencedAttrIndex); &referencedAttrIndex);
referencingColumnsIncludeDistKey = (referencingAttrIndex != -1); bool referencingColumnsIncludeDistKey = (referencingAttrIndex != -1);
foreignConstraintOnDistKey = bool foreignConstraintOnDistKey =
(referencingColumnsIncludeDistKey && referencingAttrIndex == (referencingColumnsIncludeDistKey && referencingAttrIndex ==
referencedAttrIndex); referencedAttrIndex);
@ -353,14 +342,11 @@ ForeignConstraintFindDistKeys(HeapTuple pgConstraintTuple,
int *referencingAttrIndex, int *referencingAttrIndex,
int *referencedAttrIndex) int *referencedAttrIndex)
{ {
Datum referencingColumnsDatum = 0;
Datum *referencingColumnArray = NULL; Datum *referencingColumnArray = NULL;
int referencingColumnCount = 0; int referencingColumnCount = 0;
Datum referencedColumnsDatum = 0;
Datum *referencedColumnArray = NULL; Datum *referencedColumnArray = NULL;
int referencedColumnCount = 0; int referencedColumnCount = 0;
bool isNull = false; bool isNull = false;
int attrIdx = 0;
*referencedAttrIndex = -1; *referencedAttrIndex = -1;
*referencedAttrIndex = -1; *referencedAttrIndex = -1;
@ -371,9 +357,9 @@ ForeignConstraintFindDistKeys(HeapTuple pgConstraintTuple,
* attributes together because partition column must be at the same place in both * attributes together because partition column must be at the same place in both
* referencing and referenced side of the foreign key constraint. * referencing and referenced side of the foreign key constraint.
*/ */
referencingColumnsDatum = SysCacheGetAttr(CONSTROID, pgConstraintTuple, Datum referencingColumnsDatum = SysCacheGetAttr(CONSTROID, pgConstraintTuple,
Anum_pg_constraint_conkey, &isNull); Anum_pg_constraint_conkey, &isNull);
referencedColumnsDatum = SysCacheGetAttr(CONSTROID, pgConstraintTuple, Datum referencedColumnsDatum = SysCacheGetAttr(CONSTROID, pgConstraintTuple,
Anum_pg_constraint_confkey, &isNull); Anum_pg_constraint_confkey, &isNull);
deconstruct_array(DatumGetArrayTypeP(referencingColumnsDatum), INT2OID, 2, true, deconstruct_array(DatumGetArrayTypeP(referencingColumnsDatum), INT2OID, 2, true,
@ -383,7 +369,7 @@ ForeignConstraintFindDistKeys(HeapTuple pgConstraintTuple,
Assert(referencingColumnCount == referencedColumnCount); Assert(referencingColumnCount == referencedColumnCount);
for (attrIdx = 0; attrIdx < referencingColumnCount; ++attrIdx) for (int attrIdx = 0; attrIdx < referencingColumnCount; ++attrIdx)
{ {
AttrNumber referencingAttrNo = DatumGetInt16(referencingColumnArray[attrIdx]); AttrNumber referencingAttrNo = DatumGetInt16(referencingColumnArray[attrIdx]);
AttrNumber referencedAttrNo = DatumGetInt16(referencedColumnArray[attrIdx]); AttrNumber referencedAttrNo = DatumGetInt16(referencedColumnArray[attrIdx]);
@ -412,31 +398,26 @@ ForeignConstraintFindDistKeys(HeapTuple pgConstraintTuple,
bool bool
ColumnAppearsInForeignKeyToReferenceTable(char *columnName, Oid relationId) ColumnAppearsInForeignKeyToReferenceTable(char *columnName, Oid relationId)
{ {
Relation pgConstraint = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
HeapTuple heapTuple = NULL;
bool foreignKeyToReferenceTableIncludesGivenColumn = false; bool foreignKeyToReferenceTableIncludesGivenColumn = false;
pgConstraint = heap_open(ConstraintRelationId, AccessShareLock); Relation pgConstraint = heap_open(ConstraintRelationId, AccessShareLock);
ScanKeyInit(&scanKey[0], Anum_pg_constraint_contype, BTEqualStrategyNumber, F_CHAREQ, ScanKeyInit(&scanKey[0], Anum_pg_constraint_contype, BTEqualStrategyNumber, F_CHAREQ,
CharGetDatum(CONSTRAINT_FOREIGN)); CharGetDatum(CONSTRAINT_FOREIGN));
scanDescriptor = systable_beginscan(pgConstraint, InvalidOid, false, SysScanDesc scanDescriptor = systable_beginscan(pgConstraint, InvalidOid, false,
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
while (HeapTupleIsValid(heapTuple)) while (HeapTupleIsValid(heapTuple))
{ {
Oid referencedTableId = InvalidOid;
Oid referencingTableId = InvalidOid;
int pgConstraintKey = 0; int pgConstraintKey = 0;
Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple); Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple);
referencedTableId = constraintForm->confrelid; Oid referencedTableId = constraintForm->confrelid;
referencingTableId = constraintForm->conrelid; Oid referencingTableId = constraintForm->conrelid;
if (referencedTableId == relationId) if (referencedTableId == relationId)
{ {
@ -493,11 +474,8 @@ GetTableForeignConstraintCommands(Oid relationId)
{ {
List *tableForeignConstraints = NIL; List *tableForeignConstraints = NIL;
Relation pgConstraint = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
HeapTuple heapTuple = NULL;
/* /*
* Set search_path to NIL so that all objects outside of pg_catalog will be * Set search_path to NIL so that all objects outside of pg_catalog will be
@ -510,14 +488,15 @@ GetTableForeignConstraintCommands(Oid relationId)
PushOverrideSearchPath(overridePath); PushOverrideSearchPath(overridePath);
/* open system catalog and scan all constraints that belong to this table */ /* open system catalog and scan all constraints that belong to this table */
pgConstraint = heap_open(ConstraintRelationId, AccessShareLock); Relation pgConstraint = heap_open(ConstraintRelationId, AccessShareLock);
ScanKeyInit(&scanKey[0], Anum_pg_constraint_conrelid, BTEqualStrategyNumber, F_OIDEQ, ScanKeyInit(&scanKey[0], Anum_pg_constraint_conrelid, BTEqualStrategyNumber, F_OIDEQ,
relationId); relationId);
scanDescriptor = systable_beginscan(pgConstraint, ConstraintRelidTypidNameIndexId, SysScanDesc scanDescriptor = systable_beginscan(pgConstraint,
ConstraintRelidTypidNameIndexId,
true, NULL, true, NULL,
scanKeyCount, scanKey); scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
while (HeapTupleIsValid(heapTuple)) while (HeapTupleIsValid(heapTuple))
{ {
Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple); Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple);
@ -556,24 +535,21 @@ GetTableForeignConstraintCommands(Oid relationId)
bool bool
HasForeignKeyToReferenceTable(Oid relationId) HasForeignKeyToReferenceTable(Oid relationId)
{ {
Relation pgConstraint = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
HeapTuple heapTuple = NULL;
bool hasForeignKeyToReferenceTable = false; bool hasForeignKeyToReferenceTable = false;
pgConstraint = heap_open(ConstraintRelationId, AccessShareLock); Relation pgConstraint = heap_open(ConstraintRelationId, AccessShareLock);
ScanKeyInit(&scanKey[0], Anum_pg_constraint_conrelid, BTEqualStrategyNumber, F_OIDEQ, ScanKeyInit(&scanKey[0], Anum_pg_constraint_conrelid, BTEqualStrategyNumber, F_OIDEQ,
relationId); relationId);
scanDescriptor = systable_beginscan(pgConstraint, ConstraintRelidTypidNameIndexId, SysScanDesc scanDescriptor = systable_beginscan(pgConstraint,
ConstraintRelidTypidNameIndexId,
true, NULL, true, NULL,
scanKeyCount, scanKey); scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
while (HeapTupleIsValid(heapTuple)) while (HeapTupleIsValid(heapTuple))
{ {
Oid referencedTableId = InvalidOid;
Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple); Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple);
if (constraintForm->contype != CONSTRAINT_FOREIGN) if (constraintForm->contype != CONSTRAINT_FOREIGN)
@ -582,7 +558,7 @@ HasForeignKeyToReferenceTable(Oid relationId)
continue; continue;
} }
referencedTableId = constraintForm->confrelid; Oid referencedTableId = constraintForm->confrelid;
if (!IsDistributedTable(referencedTableId)) if (!IsDistributedTable(referencedTableId))
{ {
@ -615,22 +591,20 @@ HasForeignKeyToReferenceTable(Oid relationId)
bool bool
TableReferenced(Oid relationId) TableReferenced(Oid relationId)
{ {
Relation pgConstraint = NULL;
HeapTuple heapTuple = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
Oid scanIndexId = InvalidOid; Oid scanIndexId = InvalidOid;
bool useIndex = false; bool useIndex = false;
pgConstraint = heap_open(ConstraintRelationId, AccessShareLock); Relation pgConstraint = heap_open(ConstraintRelationId, AccessShareLock);
ScanKeyInit(&scanKey[0], Anum_pg_constraint_confrelid, BTEqualStrategyNumber, F_OIDEQ, ScanKeyInit(&scanKey[0], Anum_pg_constraint_confrelid, BTEqualStrategyNumber, F_OIDEQ,
relationId); relationId);
scanDescriptor = systable_beginscan(pgConstraint, scanIndexId, useIndex, NULL, SysScanDesc scanDescriptor = systable_beginscan(pgConstraint, scanIndexId, useIndex,
NULL,
scanKeyCount, scanKey); scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
while (HeapTupleIsValid(heapTuple)) while (HeapTupleIsValid(heapTuple))
{ {
Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple); Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple);
@ -661,17 +635,15 @@ static bool
HeapTupleOfForeignConstraintIncludesColumn(HeapTuple heapTuple, Oid relationId, HeapTupleOfForeignConstraintIncludesColumn(HeapTuple heapTuple, Oid relationId,
int pgConstraintKey, char *columnName) int pgConstraintKey, char *columnName)
{ {
Datum columnsDatum = 0;
Datum *columnArray = NULL; Datum *columnArray = NULL;
int columnCount = 0; int columnCount = 0;
int attrIdx = 0;
bool isNull = false; bool isNull = false;
columnsDatum = SysCacheGetAttr(CONSTROID, heapTuple, pgConstraintKey, &isNull); Datum columnsDatum = SysCacheGetAttr(CONSTROID, heapTuple, pgConstraintKey, &isNull);
deconstruct_array(DatumGetArrayTypeP(columnsDatum), INT2OID, 2, true, deconstruct_array(DatumGetArrayTypeP(columnsDatum), INT2OID, 2, true,
's', &columnArray, NULL, &columnCount); 's', &columnArray, NULL, &columnCount);
for (attrIdx = 0; attrIdx < columnCount; ++attrIdx) for (int attrIdx = 0; attrIdx < columnCount; ++attrIdx)
{ {
AttrNumber attrNo = DatumGetInt16(columnArray[attrIdx]); AttrNumber attrNo = DatumGetInt16(columnArray[attrIdx]);
@ -696,22 +668,20 @@ HeapTupleOfForeignConstraintIncludesColumn(HeapTuple heapTuple, Oid relationId,
bool bool
TableReferencing(Oid relationId) TableReferencing(Oid relationId)
{ {
Relation pgConstraint = NULL;
HeapTuple heapTuple = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
Oid scanIndexId = InvalidOid; Oid scanIndexId = InvalidOid;
bool useIndex = false; bool useIndex = false;
pgConstraint = heap_open(ConstraintRelationId, AccessShareLock); Relation pgConstraint = heap_open(ConstraintRelationId, AccessShareLock);
ScanKeyInit(&scanKey[0], Anum_pg_constraint_conrelid, BTEqualStrategyNumber, F_OIDEQ, ScanKeyInit(&scanKey[0], Anum_pg_constraint_conrelid, BTEqualStrategyNumber, F_OIDEQ,
relationId); relationId);
scanDescriptor = systable_beginscan(pgConstraint, scanIndexId, useIndex, NULL, SysScanDesc scanDescriptor = systable_beginscan(pgConstraint, scanIndexId, useIndex,
NULL,
scanKeyCount, scanKey); scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
while (HeapTupleIsValid(heapTuple)) while (HeapTupleIsValid(heapTuple))
{ {
Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple); Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple);
@ -741,20 +711,17 @@ TableReferencing(Oid relationId)
bool bool
ConstraintIsAForeignKey(char *constraintNameInput, Oid relationId) ConstraintIsAForeignKey(char *constraintNameInput, Oid relationId)
{ {
Relation pgConstraint = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
HeapTuple heapTuple = NULL;
pgConstraint = heap_open(ConstraintRelationId, AccessShareLock); Relation pgConstraint = heap_open(ConstraintRelationId, AccessShareLock);
ScanKeyInit(&scanKey[0], Anum_pg_constraint_contype, BTEqualStrategyNumber, F_CHAREQ, ScanKeyInit(&scanKey[0], Anum_pg_constraint_contype, BTEqualStrategyNumber, F_CHAREQ,
CharGetDatum(CONSTRAINT_FOREIGN)); CharGetDatum(CONSTRAINT_FOREIGN));
scanDescriptor = systable_beginscan(pgConstraint, InvalidOid, false, SysScanDesc scanDescriptor = systable_beginscan(pgConstraint, InvalidOid, false,
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
while (HeapTupleIsValid(heapTuple)) while (HeapTupleIsValid(heapTuple))
{ {
Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple); Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple);

View File

@ -99,8 +99,6 @@ create_distributed_function(PG_FUNCTION_ARGS)
text *colocateWithText = NULL; /* optional */ text *colocateWithText = NULL; /* optional */
StringInfoData ddlCommand = { 0 }; StringInfoData ddlCommand = { 0 };
const char *createFunctionSQL = NULL;
const char *alterFunctionOwnerSQL = NULL;
ObjectAddress functionAddress = { 0 }; ObjectAddress functionAddress = { 0 };
int distributionArgumentIndex = -1; int distributionArgumentIndex = -1;
@ -159,8 +157,8 @@ create_distributed_function(PG_FUNCTION_ARGS)
EnsureDependenciesExistsOnAllNodes(&functionAddress); EnsureDependenciesExistsOnAllNodes(&functionAddress);
createFunctionSQL = GetFunctionDDLCommand(funcOid, true); const char *createFunctionSQL = GetFunctionDDLCommand(funcOid, true);
alterFunctionOwnerSQL = GetFunctionAlterOwnerCommand(funcOid); const char *alterFunctionOwnerSQL = GetFunctionAlterOwnerCommand(funcOid);
initStringInfo(&ddlCommand); initStringInfo(&ddlCommand);
appendStringInfo(&ddlCommand, "%s;%s", createFunctionSQL, alterFunctionOwnerSQL); appendStringInfo(&ddlCommand, "%s;%s", createFunctionSQL, alterFunctionOwnerSQL);
SendCommandToWorkersAsUser(ALL_WORKERS, CurrentUserName(), ddlCommand.data); SendCommandToWorkersAsUser(ALL_WORKERS, CurrentUserName(), ddlCommand.data);
@ -221,13 +219,10 @@ create_distributed_function(PG_FUNCTION_ARGS)
List * List *
CreateFunctionDDLCommandsIdempotent(const ObjectAddress *functionAddress) CreateFunctionDDLCommandsIdempotent(const ObjectAddress *functionAddress)
{ {
char *ddlCommand = NULL;
char *alterFunctionOwnerSQL = NULL;
Assert(functionAddress->classId == ProcedureRelationId); Assert(functionAddress->classId == ProcedureRelationId);
ddlCommand = GetFunctionDDLCommand(functionAddress->objectId, true); char *ddlCommand = GetFunctionDDLCommand(functionAddress->objectId, true);
alterFunctionOwnerSQL = GetFunctionAlterOwnerCommand(functionAddress->objectId); char *alterFunctionOwnerSQL = GetFunctionAlterOwnerCommand(functionAddress->objectId);
return list_make2(ddlCommand, alterFunctionOwnerSQL); return list_make2(ddlCommand, alterFunctionOwnerSQL);
} }
@ -243,23 +238,20 @@ GetDistributionArgIndex(Oid functionOid, char *distributionArgumentName,
{ {
int distributionArgumentIndex = -1; int distributionArgumentIndex = -1;
int numberOfArgs = 0;
int argIndex = 0;
Oid *argTypes = NULL; Oid *argTypes = NULL;
char **argNames = NULL; char **argNames = NULL;
char *argModes = NULL; char *argModes = NULL;
HeapTuple proctup = NULL;
*distributionArgumentOid = InvalidOid; *distributionArgumentOid = InvalidOid;
proctup = SearchSysCache1(PROCOID, ObjectIdGetDatum(functionOid)); HeapTuple proctup = SearchSysCache1(PROCOID, ObjectIdGetDatum(functionOid));
if (!HeapTupleIsValid(proctup)) if (!HeapTupleIsValid(proctup))
{ {
elog(ERROR, "cache lookup failed for function %u", functionOid); elog(ERROR, "cache lookup failed for function %u", functionOid);
} }
numberOfArgs = get_func_arg_info(proctup, &argTypes, &argNames, &argModes); int numberOfArgs = get_func_arg_info(proctup, &argTypes, &argNames, &argModes);
if (argumentStartsWith(distributionArgumentName, "$")) if (argumentStartsWith(distributionArgumentName, "$"))
{ {
@ -301,7 +293,7 @@ GetDistributionArgIndex(Oid functionOid, char *distributionArgumentName,
* So, loop over the arguments and try to find the argument name that matches * So, loop over the arguments and try to find the argument name that matches
* the parameter that user provided. * the parameter that user provided.
*/ */
for (argIndex = 0; argIndex < numberOfArgs; ++argIndex) for (int argIndex = 0; argIndex < numberOfArgs; ++argIndex)
{ {
char *argNameOnIndex = argNames != NULL ? argNames[argIndex] : NULL; char *argNameOnIndex = argNames != NULL ? argNames[argIndex] : NULL;
@ -352,8 +344,6 @@ GetFunctionColocationId(Oid functionOid, char *colocateWithTableName,
if (pg_strncasecmp(colocateWithTableName, "default", NAMEDATALEN) == 0) if (pg_strncasecmp(colocateWithTableName, "default", NAMEDATALEN) == 0)
{ {
Oid colocatedTableId = InvalidOid;
/* check for default colocation group */ /* check for default colocation group */
colocationId = ColocationId(ShardCount, ShardReplicationFactor, colocationId = ColocationId(ShardCount, ShardReplicationFactor,
distributionArgumentOid); distributionArgumentOid);
@ -369,7 +359,7 @@ GetFunctionColocationId(Oid functionOid, char *colocateWithTableName,
"option to create_distributed_function()"))); "option to create_distributed_function()")));
} }
colocatedTableId = ColocatedTableId(colocationId); Oid colocatedTableId = ColocatedTableId(colocationId);
if (colocatedTableId != InvalidOid) if (colocatedTableId != InvalidOid)
{ {
EnsureFunctionCanBeColocatedWithTable(functionOid, distributionArgumentOid, EnsureFunctionCanBeColocatedWithTable(functionOid, distributionArgumentOid,
@ -415,7 +405,6 @@ EnsureFunctionCanBeColocatedWithTable(Oid functionOid, Oid distributionColumnTyp
char sourceDistributionMethod = sourceTableEntry->partitionMethod; char sourceDistributionMethod = sourceTableEntry->partitionMethod;
char sourceReplicationModel = sourceTableEntry->replicationModel; char sourceReplicationModel = sourceTableEntry->replicationModel;
Var *sourceDistributionColumn = DistPartitionKey(sourceRelationId); Var *sourceDistributionColumn = DistPartitionKey(sourceRelationId);
Oid sourceDistributionColumnType = InvalidOid;
if (sourceDistributionMethod != DISTRIBUTE_BY_HASH) if (sourceDistributionMethod != DISTRIBUTE_BY_HASH)
{ {
@ -447,13 +436,12 @@ EnsureFunctionCanBeColocatedWithTable(Oid functionOid, Oid distributionColumnTyp
* If the types are the same, we're good. If not, we still check if there * If the types are the same, we're good. If not, we still check if there
* is any coercion path between the types. * is any coercion path between the types.
*/ */
sourceDistributionColumnType = sourceDistributionColumn->vartype; Oid sourceDistributionColumnType = sourceDistributionColumn->vartype;
if (sourceDistributionColumnType != distributionColumnType) if (sourceDistributionColumnType != distributionColumnType)
{ {
Oid coercionFuncId = InvalidOid; Oid coercionFuncId = InvalidOid;
CoercionPathType coercionType = COERCION_PATH_NONE;
coercionType = CoercionPathType coercionType =
find_coercion_pathway(distributionColumnType, sourceDistributionColumnType, find_coercion_pathway(distributionColumnType, sourceDistributionColumnType,
COERCION_EXPLICIT, &coercionFuncId); COERCION_EXPLICIT, &coercionFuncId);
@ -483,17 +471,13 @@ UpdateFunctionDistributionInfo(const ObjectAddress *distAddress,
{ {
const bool indexOK = true; const bool indexOK = true;
Relation pgDistObjectRel = NULL;
TupleDesc tupleDescriptor = NULL;
ScanKeyData scanKey[3]; ScanKeyData scanKey[3];
SysScanDesc scanDescriptor = NULL;
HeapTuple heapTuple = NULL;
Datum values[Natts_pg_dist_object]; Datum values[Natts_pg_dist_object];
bool isnull[Natts_pg_dist_object]; bool isnull[Natts_pg_dist_object];
bool replace[Natts_pg_dist_object]; bool replace[Natts_pg_dist_object];
pgDistObjectRel = heap_open(DistObjectRelationId(), RowExclusiveLock); Relation pgDistObjectRel = heap_open(DistObjectRelationId(), RowExclusiveLock);
tupleDescriptor = RelationGetDescr(pgDistObjectRel); TupleDesc tupleDescriptor = RelationGetDescr(pgDistObjectRel);
/* scan pg_dist_object for classid = $1 AND objid = $2 AND objsubid = $3 via index */ /* scan pg_dist_object for classid = $1 AND objid = $2 AND objsubid = $3 via index */
ScanKeyInit(&scanKey[0], Anum_pg_dist_object_classid, BTEqualStrategyNumber, F_OIDEQ, ScanKeyInit(&scanKey[0], Anum_pg_dist_object_classid, BTEqualStrategyNumber, F_OIDEQ,
@ -503,11 +487,12 @@ UpdateFunctionDistributionInfo(const ObjectAddress *distAddress,
ScanKeyInit(&scanKey[2], Anum_pg_dist_object_objsubid, BTEqualStrategyNumber, ScanKeyInit(&scanKey[2], Anum_pg_dist_object_objsubid, BTEqualStrategyNumber,
F_INT4EQ, ObjectIdGetDatum(distAddress->objectSubId)); F_INT4EQ, ObjectIdGetDatum(distAddress->objectSubId));
scanDescriptor = systable_beginscan(pgDistObjectRel, DistObjectPrimaryKeyIndexId(), SysScanDesc scanDescriptor = systable_beginscan(pgDistObjectRel,
DistObjectPrimaryKeyIndexId(),
indexOK, indexOK,
NULL, 3, scanKey); NULL, 3, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
if (!HeapTupleIsValid(heapTuple)) if (!HeapTupleIsValid(heapTuple))
{ {
ereport(ERROR, (errmsg("could not find valid entry for node \"%d,%d,%d\" " ereport(ERROR, (errmsg("could not find valid entry for node \"%d,%d,%d\" "
@ -609,17 +594,10 @@ GetFunctionAlterOwnerCommand(const RegProcedure funcOid)
char *kindString = "FUNCTION"; char *kindString = "FUNCTION";
Oid procOwner = InvalidOid; Oid procOwner = InvalidOid;
char *functionSignature = NULL;
char *functionOwner = NULL;
OverrideSearchPath *overridePath = NULL;
Datum functionSignatureDatum = 0;
if (HeapTupleIsValid(proctup)) if (HeapTupleIsValid(proctup))
{ {
Form_pg_proc procform; Form_pg_proc procform = (Form_pg_proc) GETSTRUCT(proctup);
procform = (Form_pg_proc) GETSTRUCT(proctup);
procOwner = procform->proowner; procOwner = procform->proowner;
@ -644,7 +622,7 @@ GetFunctionAlterOwnerCommand(const RegProcedure funcOid)
* schema-prefixed. pg_catalog will be added automatically when we call * schema-prefixed. pg_catalog will be added automatically when we call
* PushOverrideSearchPath(), since we set addCatalog to true; * PushOverrideSearchPath(), since we set addCatalog to true;
*/ */
overridePath = GetOverrideSearchPath(CurrentMemoryContext); OverrideSearchPath *overridePath = GetOverrideSearchPath(CurrentMemoryContext);
overridePath->schemas = NIL; overridePath->schemas = NIL;
overridePath->addCatalog = true; overridePath->addCatalog = true;
@ -654,16 +632,16 @@ GetFunctionAlterOwnerCommand(const RegProcedure funcOid)
* If the function exists we want to use pg_get_function_identity_arguments to * If the function exists we want to use pg_get_function_identity_arguments to
* serialize its canonical arguments * serialize its canonical arguments
*/ */
functionSignatureDatum = Datum functionSignatureDatum =
DirectFunctionCall1(regprocedureout, ObjectIdGetDatum(funcOid)); DirectFunctionCall1(regprocedureout, ObjectIdGetDatum(funcOid));
/* revert back to original search_path */ /* revert back to original search_path */
PopOverrideSearchPath(); PopOverrideSearchPath();
/* regprocedureout returns cstring */ /* regprocedureout returns cstring */
functionSignature = DatumGetCString(functionSignatureDatum); char *functionSignature = DatumGetCString(functionSignatureDatum);
functionOwner = GetUserNameFromId(procOwner, false); char *functionOwner = GetUserNameFromId(procOwner, false);
appendStringInfo(alterCommand, "ALTER %s %s OWNER TO %s;", appendStringInfo(alterCommand, "ALTER %s %s OWNER TO %s;",
kindString, kindString,
@ -686,12 +664,8 @@ static char *
GetAggregateDDLCommand(const RegProcedure funcOid, bool useCreateOrReplace) GetAggregateDDLCommand(const RegProcedure funcOid, bool useCreateOrReplace)
{ {
StringInfoData buf = { 0 }; StringInfoData buf = { 0 };
HeapTuple proctup = NULL;
Form_pg_proc proc = NULL;
HeapTuple aggtup = NULL; HeapTuple aggtup = NULL;
Form_pg_aggregate agg = NULL; Form_pg_aggregate agg = NULL;
const char *name = NULL;
const char *nsp = NULL;
int numargs = 0; int numargs = 0;
int i = 0; int i = 0;
Oid *argtypes = NULL; Oid *argtypes = NULL;
@ -701,20 +675,20 @@ GetAggregateDDLCommand(const RegProcedure funcOid, bool useCreateOrReplace)
int argsprinted = 0; int argsprinted = 0;
int inputargno = 0; int inputargno = 0;
proctup = SearchSysCache1(PROCOID, funcOid); HeapTuple proctup = SearchSysCache1(PROCOID, funcOid);
if (!HeapTupleIsValid(proctup)) if (!HeapTupleIsValid(proctup))
{ {
elog(ERROR, "cache lookup failed for %d", funcOid); elog(ERROR, "cache lookup failed for %d", funcOid);
} }
proc = (Form_pg_proc) GETSTRUCT(proctup); Form_pg_proc proc = (Form_pg_proc) GETSTRUCT(proctup);
Assert(proc->prokind == PROKIND_AGGREGATE); Assert(proc->prokind == PROKIND_AGGREGATE);
initStringInfo(&buf); initStringInfo(&buf);
name = NameStr(proc->proname); const char *name = NameStr(proc->proname);
nsp = get_namespace_name(proc->pronamespace); const char *nsp = get_namespace_name(proc->pronamespace);
#if PG_VERSION_NUM >= 120000 #if PG_VERSION_NUM >= 120000
if (useCreateOrReplace) if (useCreateOrReplace)
@ -1112,8 +1086,6 @@ TriggerSyncMetadataToPrimaryNodes(void)
static bool static bool
ShouldPropagateCreateFunction(CreateFunctionStmt *stmt) ShouldPropagateCreateFunction(CreateFunctionStmt *stmt)
{ {
const ObjectAddress *address = NULL;
if (creating_extension) if (creating_extension)
{ {
/* /*
@ -1144,7 +1116,7 @@ ShouldPropagateCreateFunction(CreateFunctionStmt *stmt)
* Even though its a replace we should accept an non-existing function, it will just * Even though its a replace we should accept an non-existing function, it will just
* not be distributed * not be distributed
*/ */
address = GetObjectAddressFromParseTree((Node *) stmt, true); const ObjectAddress *address = GetObjectAddressFromParseTree((Node *) stmt, true);
if (!IsObjectDistributed(address)) if (!IsObjectDistributed(address))
{ {
/* do not propagate alter function for non-distributed functions */ /* do not propagate alter function for non-distributed functions */
@ -1231,18 +1203,15 @@ PlanCreateFunctionStmt(CreateFunctionStmt *stmt, const char *queryString)
List * List *
ProcessCreateFunctionStmt(CreateFunctionStmt *stmt, const char *queryString) ProcessCreateFunctionStmt(CreateFunctionStmt *stmt, const char *queryString)
{ {
const ObjectAddress *address = NULL;
List *commands = NIL;
if (!ShouldPropagateCreateFunction(stmt)) if (!ShouldPropagateCreateFunction(stmt))
{ {
return NIL; return NIL;
} }
address = GetObjectAddressFromParseTree((Node *) stmt, false); const ObjectAddress *address = GetObjectAddressFromParseTree((Node *) stmt, false);
EnsureDependenciesExistsOnAllNodes(address); EnsureDependenciesExistsOnAllNodes(address);
commands = list_make4(DISABLE_DDL_PROPAGATION, List *commands = list_make4(DISABLE_DDL_PROPAGATION,
GetFunctionDDLCommand(address->objectId, true), GetFunctionDDLCommand(address->objectId, true),
GetFunctionAlterOwnerCommand(address->objectId), GetFunctionAlterOwnerCommand(address->objectId),
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -1260,7 +1229,6 @@ ObjectAddress *
CreateFunctionStmtObjectAddress(CreateFunctionStmt *stmt, bool missing_ok) CreateFunctionStmtObjectAddress(CreateFunctionStmt *stmt, bool missing_ok)
{ {
ObjectType objectType = OBJECT_FUNCTION; ObjectType objectType = OBJECT_FUNCTION;
ObjectWithArgs *objectWithArgs = NULL;
ListCell *parameterCell = NULL; ListCell *parameterCell = NULL;
if (stmt->is_procedure) if (stmt->is_procedure)
@ -1268,7 +1236,7 @@ CreateFunctionStmtObjectAddress(CreateFunctionStmt *stmt, bool missing_ok)
objectType = OBJECT_PROCEDURE; objectType = OBJECT_PROCEDURE;
} }
objectWithArgs = makeNode(ObjectWithArgs); ObjectWithArgs *objectWithArgs = makeNode(ObjectWithArgs);
objectWithArgs->objname = stmt->funcname; objectWithArgs->objname = stmt->funcname;
foreach(parameterCell, stmt->parameters) foreach(parameterCell, stmt->parameters)
@ -1292,12 +1260,11 @@ CreateFunctionStmtObjectAddress(CreateFunctionStmt *stmt, bool missing_ok)
ObjectAddress * ObjectAddress *
DefineAggregateStmtObjectAddress(DefineStmt *stmt, bool missing_ok) DefineAggregateStmtObjectAddress(DefineStmt *stmt, bool missing_ok)
{ {
ObjectWithArgs *objectWithArgs = NULL;
ListCell *parameterCell = NULL; ListCell *parameterCell = NULL;
Assert(stmt->kind == OBJECT_AGGREGATE); Assert(stmt->kind == OBJECT_AGGREGATE);
objectWithArgs = makeNode(ObjectWithArgs); ObjectWithArgs *objectWithArgs = makeNode(ObjectWithArgs);
objectWithArgs->objname = stmt->defnames; objectWithArgs->objname = stmt->defnames;
foreach(parameterCell, linitial(stmt->args)) foreach(parameterCell, linitial(stmt->args))
@ -1318,13 +1285,9 @@ DefineAggregateStmtObjectAddress(DefineStmt *stmt, bool missing_ok)
List * List *
PlanAlterFunctionStmt(AlterFunctionStmt *stmt, const char *queryString) PlanAlterFunctionStmt(AlterFunctionStmt *stmt, const char *queryString)
{ {
const char *sql = NULL;
const ObjectAddress *address = NULL;
List *commands = NIL;
AssertObjectTypeIsFunctional(stmt->objtype); AssertObjectTypeIsFunctional(stmt->objtype);
address = GetObjectAddressFromParseTree((Node *) stmt, false); const ObjectAddress *address = GetObjectAddressFromParseTree((Node *) stmt, false);
if (!ShouldPropagateAlterFunction(address)) if (!ShouldPropagateAlterFunction(address))
{ {
return NIL; return NIL;
@ -1334,9 +1297,9 @@ PlanAlterFunctionStmt(AlterFunctionStmt *stmt, const char *queryString)
ErrorIfUnsupportedAlterFunctionStmt(stmt); ErrorIfUnsupportedAlterFunctionStmt(stmt);
EnsureSequentialModeForFunctionDDL(); EnsureSequentialModeForFunctionDDL();
QualifyTreeNode((Node *) stmt); QualifyTreeNode((Node *) stmt);
sql = DeparseTreeNode((Node *) stmt); const char *sql = DeparseTreeNode((Node *) stmt);
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) sql, (void *) sql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -1355,13 +1318,9 @@ PlanAlterFunctionStmt(AlterFunctionStmt *stmt, const char *queryString)
List * List *
PlanRenameFunctionStmt(RenameStmt *stmt, const char *queryString) PlanRenameFunctionStmt(RenameStmt *stmt, const char *queryString)
{ {
const char *sql = NULL;
const ObjectAddress *address = NULL;
List *commands = NIL;
AssertObjectTypeIsFunctional(stmt->renameType); AssertObjectTypeIsFunctional(stmt->renameType);
address = GetObjectAddressFromParseTree((Node *) stmt, false); const ObjectAddress *address = GetObjectAddressFromParseTree((Node *) stmt, false);
if (!ShouldPropagateAlterFunction(address)) if (!ShouldPropagateAlterFunction(address))
{ {
return NIL; return NIL;
@ -1370,9 +1329,9 @@ PlanRenameFunctionStmt(RenameStmt *stmt, const char *queryString)
EnsureCoordinator(); EnsureCoordinator();
EnsureSequentialModeForFunctionDDL(); EnsureSequentialModeForFunctionDDL();
QualifyTreeNode((Node *) stmt); QualifyTreeNode((Node *) stmt);
sql = DeparseTreeNode((Node *) stmt); const char *sql = DeparseTreeNode((Node *) stmt);
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) sql, (void *) sql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -1389,13 +1348,9 @@ PlanRenameFunctionStmt(RenameStmt *stmt, const char *queryString)
List * List *
PlanAlterFunctionSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString) PlanAlterFunctionSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString)
{ {
const char *sql = NULL;
const ObjectAddress *address = NULL;
List *commands = NIL;
AssertObjectTypeIsFunctional(stmt->objectType); AssertObjectTypeIsFunctional(stmt->objectType);
address = GetObjectAddressFromParseTree((Node *) stmt, false); const ObjectAddress *address = GetObjectAddressFromParseTree((Node *) stmt, false);
if (!ShouldPropagateAlterFunction(address)) if (!ShouldPropagateAlterFunction(address))
{ {
return NIL; return NIL;
@ -1404,9 +1359,9 @@ PlanAlterFunctionSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString
EnsureCoordinator(); EnsureCoordinator();
EnsureSequentialModeForFunctionDDL(); EnsureSequentialModeForFunctionDDL();
QualifyTreeNode((Node *) stmt); QualifyTreeNode((Node *) stmt);
sql = DeparseTreeNode((Node *) stmt); const char *sql = DeparseTreeNode((Node *) stmt);
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) sql, (void *) sql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -1424,13 +1379,9 @@ PlanAlterFunctionSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString
List * List *
PlanAlterFunctionOwnerStmt(AlterOwnerStmt *stmt, const char *queryString) PlanAlterFunctionOwnerStmt(AlterOwnerStmt *stmt, const char *queryString)
{ {
const ObjectAddress *address = NULL;
const char *sql = NULL;
List *commands = NULL;
AssertObjectTypeIsFunctional(stmt->objectType); AssertObjectTypeIsFunctional(stmt->objectType);
address = GetObjectAddressFromParseTree((Node *) stmt, false); const ObjectAddress *address = GetObjectAddressFromParseTree((Node *) stmt, false);
if (!ShouldPropagateAlterFunction(address)) if (!ShouldPropagateAlterFunction(address))
{ {
return NIL; return NIL;
@ -1439,9 +1390,9 @@ PlanAlterFunctionOwnerStmt(AlterOwnerStmt *stmt, const char *queryString)
EnsureCoordinator(); EnsureCoordinator();
EnsureSequentialModeForFunctionDDL(); EnsureSequentialModeForFunctionDDL();
QualifyTreeNode((Node *) stmt); QualifyTreeNode((Node *) stmt);
sql = DeparseTreeNode((Node *) stmt); const char *sql = DeparseTreeNode((Node *) stmt);
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) sql, (void *) sql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -1465,10 +1416,7 @@ PlanDropFunctionStmt(DropStmt *stmt, const char *queryString)
List *distributedObjectWithArgsList = NIL; List *distributedObjectWithArgsList = NIL;
List *distributedFunctionAddresses = NIL; List *distributedFunctionAddresses = NIL;
ListCell *addressCell = NULL; ListCell *addressCell = NULL;
const char *dropStmtSql = NULL;
List *commands = NULL;
ListCell *objectWithArgsListCell = NULL; ListCell *objectWithArgsListCell = NULL;
DropStmt *stmtCopy = NULL;
AssertObjectTypeIsFunctional(stmt->removeType); AssertObjectTypeIsFunctional(stmt->removeType);
@ -1502,11 +1450,9 @@ PlanDropFunctionStmt(DropStmt *stmt, const char *queryString)
*/ */
foreach(objectWithArgsListCell, deletingObjectWithArgsList) foreach(objectWithArgsListCell, deletingObjectWithArgsList)
{ {
ObjectWithArgs *func = NULL; ObjectWithArgs *func = castNode(ObjectWithArgs, lfirst(objectWithArgsListCell));
ObjectAddress *address = NULL; ObjectAddress *address = FunctionToObjectAddress(stmt->removeType, func,
stmt->missing_ok);
func = castNode(ObjectWithArgs, lfirst(objectWithArgsListCell));
address = FunctionToObjectAddress(stmt->removeType, func, stmt->missing_ok);
if (!IsObjectDistributed(address)) if (!IsObjectDistributed(address))
{ {
@ -1543,11 +1489,11 @@ PlanDropFunctionStmt(DropStmt *stmt, const char *queryString)
* Swap the list of objects before deparsing and restore the old list after. This * Swap the list of objects before deparsing and restore the old list after. This
* ensures we only have distributed functions in the deparsed drop statement. * ensures we only have distributed functions in the deparsed drop statement.
*/ */
stmtCopy = copyObject(stmt); DropStmt *stmtCopy = copyObject(stmt);
stmtCopy->objects = distributedObjectWithArgsList; stmtCopy->objects = distributedObjectWithArgsList;
dropStmtSql = DeparseTreeNode((Node *) stmtCopy); const char *dropStmtSql = DeparseTreeNode((Node *) stmtCopy);
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) dropStmtSql, (void *) dropStmtSql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -1569,9 +1515,6 @@ PlanDropFunctionStmt(DropStmt *stmt, const char *queryString)
List * List *
PlanAlterFunctionDependsStmt(AlterObjectDependsStmt *stmt, const char *queryString) PlanAlterFunctionDependsStmt(AlterObjectDependsStmt *stmt, const char *queryString)
{ {
const ObjectAddress *address = NULL;
const char *functionName = NULL;
AssertObjectTypeIsFunctional(stmt->objectType); AssertObjectTypeIsFunctional(stmt->objectType);
if (creating_extension) if (creating_extension)
@ -1591,7 +1534,7 @@ PlanAlterFunctionDependsStmt(AlterObjectDependsStmt *stmt, const char *queryStri
return NIL; return NIL;
} }
address = GetObjectAddressFromParseTree((Node *) stmt, true); const ObjectAddress *address = GetObjectAddressFromParseTree((Node *) stmt, true);
if (!IsObjectDistributed(address)) if (!IsObjectDistributed(address))
{ {
return NIL; return NIL;
@ -1603,7 +1546,7 @@ PlanAlterFunctionDependsStmt(AlterObjectDependsStmt *stmt, const char *queryStri
* workers * workers
*/ */
functionName = getObjectIdentity(address); const char *functionName = getObjectIdentity(address);
ereport(ERROR, (errmsg("distrtibuted functions are not allowed to depend on an " ereport(ERROR, (errmsg("distrtibuted functions are not allowed to depend on an "
"extension"), "extension"),
errdetail("Function \"%s\" is already distributed. Functions from " errdetail("Function \"%s\" is already distributed. Functions from "
@ -1635,11 +1578,9 @@ AlterFunctionDependsStmtObjectAddress(AlterObjectDependsStmt *stmt, bool missing
void void
ProcessAlterFunctionSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString) ProcessAlterFunctionSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString)
{ {
const ObjectAddress *address = NULL;
AssertObjectTypeIsFunctional(stmt->objectType); AssertObjectTypeIsFunctional(stmt->objectType);
address = GetObjectAddressFromParseTree((Node *) stmt, false); const ObjectAddress *address = GetObjectAddressFromParseTree((Node *) stmt, false);
if (!ShouldPropagateAlterFunction(address)) if (!ShouldPropagateAlterFunction(address))
{ {
return; return;
@ -1698,16 +1639,11 @@ AlterFunctionOwnerObjectAddress(AlterOwnerStmt *stmt, bool missing_ok)
ObjectAddress * ObjectAddress *
AlterFunctionSchemaStmtObjectAddress(AlterObjectSchemaStmt *stmt, bool missing_ok) AlterFunctionSchemaStmtObjectAddress(AlterObjectSchemaStmt *stmt, bool missing_ok)
{ {
ObjectWithArgs *objectWithArgs = NULL;
Oid funcOid = InvalidOid;
List *names = NIL;
ObjectAddress *address = NULL;
AssertObjectTypeIsFunctional(stmt->objectType); AssertObjectTypeIsFunctional(stmt->objectType);
objectWithArgs = castNode(ObjectWithArgs, stmt->object); ObjectWithArgs *objectWithArgs = castNode(ObjectWithArgs, stmt->object);
funcOid = LookupFuncWithArgs(stmt->objectType, objectWithArgs, true); Oid funcOid = LookupFuncWithArgs(stmt->objectType, objectWithArgs, true);
names = objectWithArgs->objname; List *names = objectWithArgs->objname;
if (funcOid == InvalidOid) if (funcOid == InvalidOid)
{ {
@ -1744,7 +1680,7 @@ AlterFunctionSchemaStmtObjectAddress(AlterObjectSchemaStmt *stmt, bool missing_o
} }
} }
address = palloc0(sizeof(ObjectAddress)); ObjectAddress *address = palloc0(sizeof(ObjectAddress));
ObjectAddressSet(*address, ProcedureRelationId, funcOid); ObjectAddressSet(*address, ProcedureRelationId, funcOid);
return address; return address;
@ -1766,7 +1702,6 @@ GenerateBackupNameForProcCollision(const ObjectAddress *address)
address->objectId))); address->objectId)));
char *baseName = get_func_name(address->objectId); char *baseName = get_func_name(address->objectId);
int baseLength = strlen(baseName); int baseLength = strlen(baseName);
int numargs = 0;
Oid *argtypes = NULL; Oid *argtypes = NULL;
char **argnames = NULL; char **argnames = NULL;
char *argmodes = NULL; char *argmodes = NULL;
@ -1777,15 +1712,13 @@ GenerateBackupNameForProcCollision(const ObjectAddress *address)
elog(ERROR, "citus cache lookup failed."); elog(ERROR, "citus cache lookup failed.");
} }
numargs = get_func_arg_info(proctup, &argtypes, &argnames, &argmodes); int numargs = get_func_arg_info(proctup, &argtypes, &argnames, &argmodes);
ReleaseSysCache(proctup); ReleaseSysCache(proctup);
while (true) while (true)
{ {
int suffixLength = snprintf(suffix, NAMEDATALEN - 1, "(citus_backup_%d)", int suffixLength = snprintf(suffix, NAMEDATALEN - 1, "(citus_backup_%d)",
count); count);
List *newProcName = NIL;
FuncCandidateList clist = NULL;
/* trim the base name at the end to leave space for the suffix and trailing \0 */ /* trim the base name at the end to leave space for the suffix and trailing \0 */
baseLength = Min(baseLength, NAMEDATALEN - suffixLength - 1); baseLength = Min(baseLength, NAMEDATALEN - suffixLength - 1);
@ -1795,10 +1728,11 @@ GenerateBackupNameForProcCollision(const ObjectAddress *address)
strncpy(newName, baseName, baseLength); strncpy(newName, baseName, baseLength);
strncpy(newName + baseLength, suffix, suffixLength); strncpy(newName + baseLength, suffix, suffixLength);
newProcName = list_make2(namespace, makeString(newName)); List *newProcName = list_make2(namespace, makeString(newName));
/* don't need to rename if the input arguments don't match */ /* don't need to rename if the input arguments don't match */
clist = FuncnameGetCandidates(newProcName, numargs, NIL, false, false, true); FuncCandidateList clist = FuncnameGetCandidates(newProcName, numargs, NIL, false,
false, true);
for (; clist; clist = clist->next) for (; clist; clist = clist->next)
{ {
if (memcmp(clist->args, argtypes, sizeof(Oid) * numargs) == 0) if (memcmp(clist->args, argtypes, sizeof(Oid) * numargs) == 0)
@ -1828,8 +1762,6 @@ ObjectWithArgsFromOid(Oid funcOid)
Oid *argTypes = NULL; Oid *argTypes = NULL;
char **argNames = NULL; char **argNames = NULL;
char *argModes = NULL; char *argModes = NULL;
int numargs = 0;
int i = 0;
HeapTuple proctup = SearchSysCache1(PROCOID, funcOid); HeapTuple proctup = SearchSysCache1(PROCOID, funcOid);
if (!HeapTupleIsValid(proctup)) if (!HeapTupleIsValid(proctup))
@ -1837,14 +1769,14 @@ ObjectWithArgsFromOid(Oid funcOid)
elog(ERROR, "citus cache lookup failed."); elog(ERROR, "citus cache lookup failed.");
} }
numargs = get_func_arg_info(proctup, &argTypes, &argNames, &argModes); int numargs = get_func_arg_info(proctup, &argTypes, &argNames, &argModes);
objectWithArgs->objname = list_make2( objectWithArgs->objname = list_make2(
makeString(get_namespace_name(get_func_namespace(funcOid))), makeString(get_namespace_name(get_func_namespace(funcOid))),
makeString(get_func_name(funcOid)) makeString(get_func_name(funcOid))
); );
for (i = 0; i < numargs; i++) for (int i = 0; i < numargs; i++)
{ {
if (argModes == NULL || if (argModes == NULL ||
argModes[i] != PROARGMODE_OUT || argModes[i] != PROARGMODE_TABLE) argModes[i] != PROARGMODE_OUT || argModes[i] != PROARGMODE_TABLE)
@ -1870,13 +1802,10 @@ static ObjectAddress *
FunctionToObjectAddress(ObjectType objectType, ObjectWithArgs *objectWithArgs, FunctionToObjectAddress(ObjectType objectType, ObjectWithArgs *objectWithArgs,
bool missing_ok) bool missing_ok)
{ {
Oid funcOid = InvalidOid;
ObjectAddress *address = NULL;
AssertObjectTypeIsFunctional(objectType); AssertObjectTypeIsFunctional(objectType);
funcOid = LookupFuncWithArgs(objectType, objectWithArgs, missing_ok); Oid funcOid = LookupFuncWithArgs(objectType, objectWithArgs, missing_ok);
address = palloc0(sizeof(ObjectAddress)); ObjectAddress *address = palloc0(sizeof(ObjectAddress));
ObjectAddressSet(*address, ProcedureRelationId, funcOid); ObjectAddressSet(*address, ProcedureRelationId, funcOid);
return address; return address;

View File

@ -115,9 +115,6 @@ PlanIndexStmt(IndexStmt *createIndexStatement, const char *createIndexCommand)
*/ */
if (createIndexStatement->relation != NULL) if (createIndexStatement->relation != NULL)
{ {
Relation relation = NULL;
Oid relationId = InvalidOid;
bool isDistributedRelation = false;
LOCKMODE lockmode = ShareLock; LOCKMODE lockmode = ShareLock;
MemoryContext relationContext = NULL; MemoryContext relationContext = NULL;
@ -137,10 +134,10 @@ PlanIndexStmt(IndexStmt *createIndexStatement, const char *createIndexCommand)
* checked permissions, and will only fail when executing the actual * checked permissions, and will only fail when executing the actual
* index statements. * index statements.
*/ */
relation = heap_openrv(createIndexStatement->relation, lockmode); Relation relation = heap_openrv(createIndexStatement->relation, lockmode);
relationId = RelationGetRelid(relation); Oid relationId = RelationGetRelid(relation);
isDistributedRelation = IsDistributedTable(relationId); bool isDistributedRelation = IsDistributedTable(relationId);
if (createIndexStatement->relation->schemaname == NULL) if (createIndexStatement->relation->schemaname == NULL)
{ {
@ -163,15 +160,13 @@ PlanIndexStmt(IndexStmt *createIndexStatement, const char *createIndexCommand)
if (isDistributedRelation) if (isDistributedRelation)
{ {
Oid namespaceId = InvalidOid;
Oid indexRelationId = InvalidOid;
char *indexName = createIndexStatement->idxname; char *indexName = createIndexStatement->idxname;
char *namespaceName = createIndexStatement->relation->schemaname; char *namespaceName = createIndexStatement->relation->schemaname;
ErrorIfUnsupportedIndexStmt(createIndexStatement); ErrorIfUnsupportedIndexStmt(createIndexStatement);
namespaceId = get_namespace_oid(namespaceName, false); Oid namespaceId = get_namespace_oid(namespaceName, false);
indexRelationId = get_relname_relid(indexName, namespaceId); Oid indexRelationId = get_relname_relid(indexName, namespaceId);
/* if index does not exist, send the command to workers */ /* if index does not exist, send the command to workers */
if (!OidIsValid(indexRelationId)) if (!OidIsValid(indexRelationId))
@ -319,9 +314,6 @@ PlanDropIndexStmt(DropStmt *dropIndexStatement, const char *dropIndexCommand)
/* check if any of the indexes being dropped belong to a distributed table */ /* check if any of the indexes being dropped belong to a distributed table */
foreach(dropObjectCell, dropIndexStatement->objects) foreach(dropObjectCell, dropIndexStatement->objects)
{ {
Oid indexId = InvalidOid;
Oid relationId = InvalidOid;
bool isDistributedRelation = false;
struct DropRelationCallbackState state; struct DropRelationCallbackState state;
uint32 rvrFlags = RVR_MISSING_OK; uint32 rvrFlags = RVR_MISSING_OK;
LOCKMODE lockmode = AccessExclusiveLock; LOCKMODE lockmode = AccessExclusiveLock;
@ -349,7 +341,7 @@ PlanDropIndexStmt(DropStmt *dropIndexStatement, const char *dropIndexCommand)
state.heapOid = InvalidOid; state.heapOid = InvalidOid;
state.concurrent = dropIndexStatement->concurrent; state.concurrent = dropIndexStatement->concurrent;
indexId = RangeVarGetRelidExtended(rangeVar, lockmode, rvrFlags, Oid indexId = RangeVarGetRelidExtended(rangeVar, lockmode, rvrFlags,
RangeVarCallbackForDropIndex, RangeVarCallbackForDropIndex,
(void *) &state); (void *) &state);
@ -362,8 +354,8 @@ PlanDropIndexStmt(DropStmt *dropIndexStatement, const char *dropIndexCommand)
continue; continue;
} }
relationId = IndexGetRelation(indexId, false); Oid relationId = IndexGetRelation(indexId, false);
isDistributedRelation = IsDistributedTable(relationId); bool isDistributedRelation = IsDistributedTable(relationId);
if (isDistributedRelation) if (isDistributedRelation)
{ {
distributedIndexId = indexId; distributedIndexId = indexId;
@ -400,13 +392,6 @@ PlanDropIndexStmt(DropStmt *dropIndexStatement, const char *dropIndexCommand)
void void
PostProcessIndexStmt(IndexStmt *indexStmt) PostProcessIndexStmt(IndexStmt *indexStmt)
{ {
Relation relation = NULL;
Oid indexRelationId = InvalidOid;
Relation indexRelation = NULL;
Relation pg_index = NULL;
HeapTuple indexTuple = NULL;
Form_pg_index indexForm = NULL;
/* we are only processing CONCURRENT index statements */ /* we are only processing CONCURRENT index statements */
if (!indexStmt->concurrent) if (!indexStmt->concurrent)
{ {
@ -424,10 +409,10 @@ PostProcessIndexStmt(IndexStmt *indexStmt)
StartTransactionCommand(); StartTransactionCommand();
/* get the affected relation and index */ /* get the affected relation and index */
relation = heap_openrv(indexStmt->relation, ShareUpdateExclusiveLock); Relation relation = heap_openrv(indexStmt->relation, ShareUpdateExclusiveLock);
indexRelationId = get_relname_relid(indexStmt->idxname, Oid indexRelationId = get_relname_relid(indexStmt->idxname,
RelationGetNamespace(relation)); RelationGetNamespace(relation));
indexRelation = index_open(indexRelationId, RowExclusiveLock); Relation indexRelation = index_open(indexRelationId, RowExclusiveLock);
/* close relations but retain locks */ /* close relations but retain locks */
heap_close(relation, NoLock); heap_close(relation, NoLock);
@ -441,13 +426,14 @@ PostProcessIndexStmt(IndexStmt *indexStmt)
StartTransactionCommand(); StartTransactionCommand();
/* now, update index's validity in a way that can roll back */ /* now, update index's validity in a way that can roll back */
pg_index = heap_open(IndexRelationId, RowExclusiveLock); Relation pg_index = heap_open(IndexRelationId, RowExclusiveLock);
indexTuple = SearchSysCacheCopy1(INDEXRELID, ObjectIdGetDatum(indexRelationId)); HeapTuple indexTuple = SearchSysCacheCopy1(INDEXRELID, ObjectIdGetDatum(
indexRelationId));
Assert(HeapTupleIsValid(indexTuple)); /* better be present, we have lock! */ Assert(HeapTupleIsValid(indexTuple)); /* better be present, we have lock! */
/* mark as valid, save, and update pg_index indexes */ /* mark as valid, save, and update pg_index indexes */
indexForm = (Form_pg_index) GETSTRUCT(indexTuple); Form_pg_index indexForm = (Form_pg_index) GETSTRUCT(indexTuple);
indexForm->indisvalid = true; indexForm->indisvalid = true;
CatalogTupleUpdate(pg_index, &indexTuple->t_self, indexTuple); CatalogTupleUpdate(pg_index, &indexTuple->t_self, indexTuple);
@ -528,11 +514,10 @@ CreateIndexTaskList(Oid relationId, IndexStmt *indexStmt)
{ {
ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell); ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell);
uint64 shardId = shardInterval->shardId; uint64 shardId = shardInterval->shardId;
Task *task = NULL;
deparse_shard_index_statement(indexStmt, relationId, shardId, &ddlString); deparse_shard_index_statement(indexStmt, relationId, shardId, &ddlString);
task = CitusMakeNode(Task); Task *task = CitusMakeNode(Task);
task->jobId = jobId; task->jobId = jobId;
task->taskId = taskId++; task->taskId = taskId++;
task->taskType = DDL_TASK; task->taskType = DDL_TASK;
@ -574,11 +559,10 @@ CreateReindexTaskList(Oid relationId, ReindexStmt *reindexStmt)
{ {
ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell); ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell);
uint64 shardId = shardInterval->shardId; uint64 shardId = shardInterval->shardId;
Task *task = NULL;
deparse_shard_reindex_statement(reindexStmt, relationId, shardId, &ddlString); deparse_shard_reindex_statement(reindexStmt, relationId, shardId, &ddlString);
task = CitusMakeNode(Task); Task *task = CitusMakeNode(Task);
task->jobId = jobId; task->jobId = jobId;
task->taskId = taskId++; task->taskId = taskId++;
task->taskType = DDL_TASK; task->taskType = DDL_TASK;
@ -612,13 +596,11 @@ RangeVarCallbackForDropIndex(const RangeVar *rel, Oid relOid, Oid oldRelOid, voi
{ {
/* *INDENT-OFF* */ /* *INDENT-OFF* */
HeapTuple tuple; HeapTuple tuple;
struct DropRelationCallbackState *state;
char relkind; char relkind;
char expected_relkind; char expected_relkind;
Form_pg_class classform;
LOCKMODE heap_lockmode; LOCKMODE heap_lockmode;
state = (struct DropRelationCallbackState *) arg; struct DropRelationCallbackState *state = (struct DropRelationCallbackState *) arg;
relkind = state->relkind; relkind = state->relkind;
heap_lockmode = state->concurrent ? heap_lockmode = state->concurrent ?
ShareUpdateExclusiveLock : AccessExclusiveLock; ShareUpdateExclusiveLock : AccessExclusiveLock;
@ -643,7 +625,7 @@ RangeVarCallbackForDropIndex(const RangeVar *rel, Oid relOid, Oid oldRelOid, voi
tuple = SearchSysCache1(RELOID, ObjectIdGetDatum(relOid)); tuple = SearchSysCache1(RELOID, ObjectIdGetDatum(relOid));
if (!HeapTupleIsValid(tuple)) if (!HeapTupleIsValid(tuple))
return; /* concurrently dropped, so nothing to do */ return; /* concurrently dropped, so nothing to do */
classform = (Form_pg_class) GETSTRUCT(tuple); Form_pg_class classform = (Form_pg_class) GETSTRUCT(tuple);
/* /*
* PG 11 sends relkind as partitioned index for an index * PG 11 sends relkind as partitioned index for an index
@ -805,7 +787,6 @@ ErrorIfUnsupportedIndexStmt(IndexStmt *createIndexStatement)
Oid relationId = RangeVarGetRelid(relation, lockMode, missingOk); Oid relationId = RangeVarGetRelid(relation, lockMode, missingOk);
Var *partitionKey = DistPartitionKey(relationId); Var *partitionKey = DistPartitionKey(relationId);
char partitionMethod = PartitionMethod(relationId); char partitionMethod = PartitionMethod(relationId);
List *indexParameterList = NIL;
ListCell *indexParameterCell = NULL; ListCell *indexParameterCell = NULL;
bool indexContainsPartitionColumn = false; bool indexContainsPartitionColumn = false;
@ -825,12 +806,11 @@ ErrorIfUnsupportedIndexStmt(IndexStmt *createIndexStatement)
"is currently unsupported"))); "is currently unsupported")));
} }
indexParameterList = createIndexStatement->indexParams; List *indexParameterList = createIndexStatement->indexParams;
foreach(indexParameterCell, indexParameterList) foreach(indexParameterCell, indexParameterList)
{ {
IndexElem *indexElement = (IndexElem *) lfirst(indexParameterCell); IndexElem *indexElement = (IndexElem *) lfirst(indexParameterCell);
char *columnName = indexElement->name; char *columnName = indexElement->name;
AttrNumber attributeNumber = InvalidAttrNumber;
/* column name is null for index expressions, skip it */ /* column name is null for index expressions, skip it */
if (columnName == NULL) if (columnName == NULL)
@ -838,7 +818,7 @@ ErrorIfUnsupportedIndexStmt(IndexStmt *createIndexStatement)
continue; continue;
} }
attributeNumber = get_attnum(relationId, columnName); AttrNumber attributeNumber = get_attnum(relationId, columnName);
if (attributeNumber == partitionKey->varattno) if (attributeNumber == partitionKey->varattno)
{ {
indexContainsPartitionColumn = true; indexContainsPartitionColumn = true;
@ -902,7 +882,6 @@ DropIndexTaskList(Oid relationId, Oid indexId, DropStmt *dropStmt)
ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell); ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell);
uint64 shardId = shardInterval->shardId; uint64 shardId = shardInterval->shardId;
char *shardIndexName = pstrdup(indexName); char *shardIndexName = pstrdup(indexName);
Task *task = NULL;
AppendShardIdToName(&shardIndexName, shardId); AppendShardIdToName(&shardIndexName, shardId);
@ -913,7 +892,7 @@ DropIndexTaskList(Oid relationId, Oid indexId, DropStmt *dropStmt)
quote_qualified_identifier(schemaName, shardIndexName), quote_qualified_identifier(schemaName, shardIndexName),
(dropStmt->behavior == DROP_RESTRICT ? "RESTRICT" : "CASCADE")); (dropStmt->behavior == DROP_RESTRICT ? "RESTRICT" : "CASCADE"));
task = CitusMakeNode(Task); Task *task = CitusMakeNode(Task);
task->jobId = jobId; task->jobId = jobId;
task->taskId = taskId++; task->taskId = taskId++;
task->taskType = DDL_TASK; task->taskType = DDL_TASK;

View File

@ -298,8 +298,6 @@ PG_FUNCTION_INFO_V1(citus_text_send_as_jsonb);
static void static void
CitusCopyFrom(CopyStmt *copyStatement, char *completionTag) CitusCopyFrom(CopyStmt *copyStatement, char *completionTag)
{ {
bool isCopyFromWorker = false;
BeginOrContinueCoordinatedTransaction(); BeginOrContinueCoordinatedTransaction();
/* disallow COPY to/from file or program except for superusers */ /* disallow COPY to/from file or program except for superusers */
@ -324,7 +322,7 @@ CitusCopyFrom(CopyStmt *copyStatement, char *completionTag)
} }
masterConnection = NULL; /* reset, might still be set after error */ masterConnection = NULL; /* reset, might still be set after error */
isCopyFromWorker = IsCopyFromWorker(copyStatement); bool isCopyFromWorker = IsCopyFromWorker(copyStatement);
if (isCopyFromWorker) if (isCopyFromWorker)
{ {
CopyFromWorkerNode(copyStatement, completionTag); CopyFromWorkerNode(copyStatement, completionTag);
@ -387,9 +385,6 @@ CopyFromWorkerNode(CopyStmt *copyStatement, char *completionTag)
NodeAddress *masterNodeAddress = MasterNodeAddress(copyStatement); NodeAddress *masterNodeAddress = MasterNodeAddress(copyStatement);
char *nodeName = masterNodeAddress->nodeName; char *nodeName = masterNodeAddress->nodeName;
int32 nodePort = masterNodeAddress->nodePort; int32 nodePort = masterNodeAddress->nodePort;
Oid relationId = InvalidOid;
char partitionMethod = 0;
char *schemaName = NULL;
uint32 connectionFlags = FOR_DML; uint32 connectionFlags = FOR_DML;
masterConnection = GetNodeConnection(connectionFlags, nodeName, nodePort); masterConnection = GetNodeConnection(connectionFlags, nodeName, nodePort);
@ -399,14 +394,14 @@ CopyFromWorkerNode(CopyStmt *copyStatement, char *completionTag)
RemoteTransactionBeginIfNecessary(masterConnection); RemoteTransactionBeginIfNecessary(masterConnection);
/* strip schema name for local reference */ /* strip schema name for local reference */
schemaName = copyStatement->relation->schemaname; char *schemaName = copyStatement->relation->schemaname;
copyStatement->relation->schemaname = NULL; copyStatement->relation->schemaname = NULL;
relationId = RangeVarGetRelid(copyStatement->relation, NoLock, false); Oid relationId = RangeVarGetRelid(copyStatement->relation, NoLock, false);
/* put schema name back */ /* put schema name back */
copyStatement->relation->schemaname = schemaName; copyStatement->relation->schemaname = schemaName;
partitionMethod = MasterPartitionMethod(copyStatement->relation); char partitionMethod = MasterPartitionMethod(copyStatement->relation);
if (partitionMethod != DISTRIBUTE_BY_APPEND) if (partitionMethod != DISTRIBUTE_BY_APPEND)
{ {
ereport(ERROR, (errmsg("copy from worker nodes is only supported " ereport(ERROR, (errmsg("copy from worker nodes is only supported "
@ -439,18 +434,10 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag)
CitusCopyDestReceiver *copyDest = NULL; CitusCopyDestReceiver *copyDest = NULL;
DestReceiver *dest = NULL; DestReceiver *dest = NULL;
Relation distributedRelation = NULL;
Relation copiedDistributedRelation = NULL; Relation copiedDistributedRelation = NULL;
Form_pg_class copiedDistributedRelationTuple = NULL; Form_pg_class copiedDistributedRelationTuple = NULL;
TupleDesc tupleDescriptor = NULL;
uint32 columnCount = 0;
Datum *columnValues = NULL;
bool *columnNulls = NULL;
int columnIndex = 0;
List *columnNameList = NIL; List *columnNameList = NIL;
Var *partitionColumn = NULL;
int partitionColumnIndex = INVALID_PARTITION_COLUMN_INDEX; int partitionColumnIndex = INVALID_PARTITION_COLUMN_INDEX;
TupleTableSlot *tupleTableSlot = NULL;
EState *executorState = NULL; EState *executorState = NULL;
MemoryContext executorTupleContext = NULL; MemoryContext executorTupleContext = NULL;
@ -465,27 +452,28 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag)
ErrorContextCallback errorCallback; ErrorContextCallback errorCallback;
/* allocate column values and nulls arrays */ /* allocate column values and nulls arrays */
distributedRelation = heap_open(tableId, RowExclusiveLock); Relation distributedRelation = heap_open(tableId, RowExclusiveLock);
tupleDescriptor = RelationGetDescr(distributedRelation); TupleDesc tupleDescriptor = RelationGetDescr(distributedRelation);
columnCount = tupleDescriptor->natts; uint32 columnCount = tupleDescriptor->natts;
columnValues = palloc0(columnCount * sizeof(Datum)); Datum *columnValues = palloc0(columnCount * sizeof(Datum));
columnNulls = palloc0(columnCount * sizeof(bool)); bool *columnNulls = palloc0(columnCount * sizeof(bool));
/* set up a virtual tuple table slot */ /* set up a virtual tuple table slot */
tupleTableSlot = MakeSingleTupleTableSlotCompat(tupleDescriptor, &TTSOpsVirtual); TupleTableSlot *tupleTableSlot = MakeSingleTupleTableSlotCompat(tupleDescriptor,
&TTSOpsVirtual);
tupleTableSlot->tts_nvalid = columnCount; tupleTableSlot->tts_nvalid = columnCount;
tupleTableSlot->tts_values = columnValues; tupleTableSlot->tts_values = columnValues;
tupleTableSlot->tts_isnull = columnNulls; tupleTableSlot->tts_isnull = columnNulls;
/* determine the partition column index in the tuple descriptor */ /* determine the partition column index in the tuple descriptor */
partitionColumn = PartitionColumn(tableId, 0); Var *partitionColumn = PartitionColumn(tableId, 0);
if (partitionColumn != NULL) if (partitionColumn != NULL)
{ {
partitionColumnIndex = partitionColumn->varattno - 1; partitionColumnIndex = partitionColumn->varattno - 1;
} }
/* build the list of column names for remote COPY statements */ /* build the list of column names for remote COPY statements */
for (columnIndex = 0; columnIndex < columnCount; columnIndex++) for (int columnIndex = 0; columnIndex < columnCount; columnIndex++)
{ {
Form_pg_attribute currentColumn = TupleDescAttr(tupleDescriptor, columnIndex); Form_pg_attribute currentColumn = TupleDescAttr(tupleDescriptor, columnIndex);
char *columnName = NameStr(currentColumn->attname); char *columnName = NameStr(currentColumn->attname);
@ -566,15 +554,12 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag)
while (true) while (true)
{ {
bool nextRowFound = false;
MemoryContext oldContext = NULL;
ResetPerTupleExprContext(executorState); ResetPerTupleExprContext(executorState);
oldContext = MemoryContextSwitchTo(executorTupleContext); MemoryContext oldContext = MemoryContextSwitchTo(executorTupleContext);
/* parse a row from the input */ /* parse a row from the input */
nextRowFound = NextCopyFromCompat(copyState, executorExpressionContext, bool nextRowFound = NextCopyFromCompat(copyState, executorExpressionContext,
columnValues, columnNulls); columnValues, columnNulls);
if (!nextRowFound) if (!nextRowFound)
@ -625,8 +610,6 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag)
static void static void
CopyToNewShards(CopyStmt *copyStatement, char *completionTag, Oid relationId) CopyToNewShards(CopyStmt *copyStatement, char *completionTag, Oid relationId)
{ {
FmgrInfo *columnOutputFunctions = NULL;
/* allocate column values and nulls arrays */ /* allocate column values and nulls arrays */
Relation distributedRelation = heap_open(relationId, RowExclusiveLock); Relation distributedRelation = heap_open(relationId, RowExclusiveLock);
TupleDesc tupleDescriptor = RelationGetDescr(distributedRelation); TupleDesc tupleDescriptor = RelationGetDescr(distributedRelation);
@ -668,7 +651,8 @@ CopyToNewShards(CopyStmt *copyStatement, char *completionTag, Oid relationId)
copyOutState->fe_msgbuf = makeStringInfo(); copyOutState->fe_msgbuf = makeStringInfo();
copyOutState->rowcontext = executorTupleContext; copyOutState->rowcontext = executorTupleContext;
columnOutputFunctions = ColumnOutputFunctions(tupleDescriptor, copyOutState->binary); FmgrInfo *columnOutputFunctions = ColumnOutputFunctions(tupleDescriptor,
copyOutState->binary);
/* set up callback to identify error line number */ /* set up callback to identify error line number */
errorCallback.callback = CopyFromErrorCallback; errorCallback.callback = CopyFromErrorCallback;
@ -684,18 +668,14 @@ CopyToNewShards(CopyStmt *copyStatement, char *completionTag, Oid relationId)
while (true) while (true)
{ {
bool nextRowFound = false;
MemoryContext oldContext = NULL;
uint64 messageBufferSize = 0;
ResetPerTupleExprContext(executorState); ResetPerTupleExprContext(executorState);
/* switch to tuple memory context and start showing line number in errors */ /* switch to tuple memory context and start showing line number in errors */
error_context_stack = &errorCallback; error_context_stack = &errorCallback;
oldContext = MemoryContextSwitchTo(executorTupleContext); MemoryContext oldContext = MemoryContextSwitchTo(executorTupleContext);
/* parse a row from the input */ /* parse a row from the input */
nextRowFound = NextCopyFromCompat(copyState, executorExpressionContext, bool nextRowFound = NextCopyFromCompat(copyState, executorExpressionContext,
columnValues, columnNulls); columnValues, columnNulls);
if (!nextRowFound) if (!nextRowFound)
@ -739,7 +719,7 @@ CopyToNewShards(CopyStmt *copyStatement, char *completionTag, Oid relationId)
SendCopyDataToAll(copyOutState->fe_msgbuf, currentShardId, SendCopyDataToAll(copyOutState->fe_msgbuf, currentShardId,
shardConnections->connectionList); shardConnections->connectionList);
messageBufferSize = copyOutState->fe_msgbuf->len; uint64 messageBufferSize = copyOutState->fe_msgbuf->len;
copiedDataSizeInBytes = copiedDataSizeInBytes + messageBufferSize; copiedDataSizeInBytes = copiedDataSizeInBytes + messageBufferSize;
/* /*
@ -841,7 +821,6 @@ static char
MasterPartitionMethod(RangeVar *relation) MasterPartitionMethod(RangeVar *relation)
{ {
char partitionMethod = '\0'; char partitionMethod = '\0';
PGresult *queryResult = NULL;
bool raiseInterrupts = true; bool raiseInterrupts = true;
char *relationName = relation->relname; char *relationName = relation->relname;
@ -855,7 +834,7 @@ MasterPartitionMethod(RangeVar *relation)
{ {
ReportConnectionError(masterConnection, ERROR); ReportConnectionError(masterConnection, ERROR);
} }
queryResult = GetRemoteCommandResult(masterConnection, raiseInterrupts); PGresult *queryResult = GetRemoteCommandResult(masterConnection, raiseInterrupts);
if (PQresultStatus(queryResult) == PGRES_TUPLES_OK) if (PQresultStatus(queryResult) == PGRES_TUPLES_OK)
{ {
char *partitionMethodString = PQgetvalue((PGresult *) queryResult, 0, 0); char *partitionMethodString = PQgetvalue((PGresult *) queryResult, 0, 0);
@ -923,7 +902,6 @@ OpenCopyConnectionsForNewShards(CopyStmt *copyStatement,
ShardConnections *shardConnections, ShardConnections *shardConnections,
bool stopOnFailure, bool useBinaryCopyFormat) bool stopOnFailure, bool useBinaryCopyFormat)
{ {
List *finalizedPlacementList = NIL;
int failedPlacementCount = 0; int failedPlacementCount = 0;
ListCell *placementCell = NULL; ListCell *placementCell = NULL;
List *connectionList = NULL; List *connectionList = NULL;
@ -940,7 +918,7 @@ OpenCopyConnectionsForNewShards(CopyStmt *copyStatement,
/* release finalized placement list at the end of this function */ /* release finalized placement list at the end of this function */
MemoryContext oldContext = MemoryContextSwitchTo(localContext); MemoryContext oldContext = MemoryContextSwitchTo(localContext);
finalizedPlacementList = MasterShardPlacementList(shardId); List *finalizedPlacementList = MasterShardPlacementList(shardId);
MemoryContextSwitchTo(oldContext); MemoryContextSwitchTo(oldContext);
@ -948,10 +926,7 @@ OpenCopyConnectionsForNewShards(CopyStmt *copyStatement,
{ {
ShardPlacement *placement = (ShardPlacement *) lfirst(placementCell); ShardPlacement *placement = (ShardPlacement *) lfirst(placementCell);
char *nodeUser = CurrentUserName(); char *nodeUser = CurrentUserName();
MultiConnection *connection = NULL;
uint32 connectionFlags = FOR_DML; uint32 connectionFlags = FOR_DML;
StringInfo copyCommand = NULL;
PGresult *result = NULL;
/* /*
* For hash partitioned tables, connection establishment happens in * For hash partitioned tables, connection establishment happens in
@ -959,7 +934,8 @@ OpenCopyConnectionsForNewShards(CopyStmt *copyStatement,
*/ */
Assert(placement->partitionMethod != DISTRIBUTE_BY_HASH); Assert(placement->partitionMethod != DISTRIBUTE_BY_HASH);
connection = GetPlacementConnection(connectionFlags, placement, nodeUser); MultiConnection *connection = GetPlacementConnection(connectionFlags, placement,
nodeUser);
if (PQstatus(connection->pgConn) != CONNECTION_OK) if (PQstatus(connection->pgConn) != CONNECTION_OK)
{ {
@ -987,14 +963,15 @@ OpenCopyConnectionsForNewShards(CopyStmt *copyStatement,
ClaimConnectionExclusively(connection); ClaimConnectionExclusively(connection);
RemoteTransactionBeginIfNecessary(connection); RemoteTransactionBeginIfNecessary(connection);
copyCommand = ConstructCopyStatement(copyStatement, shardConnections->shardId, StringInfo copyCommand = ConstructCopyStatement(copyStatement,
shardConnections->shardId,
useBinaryCopyFormat); useBinaryCopyFormat);
if (!SendRemoteCommand(connection, copyCommand->data)) if (!SendRemoteCommand(connection, copyCommand->data))
{ {
ReportConnectionError(connection, ERROR); ReportConnectionError(connection, ERROR);
} }
result = GetRemoteCommandResult(connection, raiseInterrupts); PGresult *result = GetRemoteCommandResult(connection, raiseInterrupts);
if (PQresultStatus(result) != PGRES_COPY_IN) if (PQresultStatus(result) != PGRES_COPY_IN)
{ {
ReportResultError(connection, result, ERROR); ReportResultError(connection, result, ERROR);
@ -1035,9 +1012,8 @@ CanUseBinaryCopyFormat(TupleDesc tupleDescription)
{ {
bool useBinaryCopyFormat = true; bool useBinaryCopyFormat = true;
int totalColumnCount = tupleDescription->natts; int totalColumnCount = tupleDescription->natts;
int columnIndex = 0;
for (columnIndex = 0; columnIndex < totalColumnCount; columnIndex++) for (int columnIndex = 0; columnIndex < totalColumnCount; columnIndex++)
{ {
Form_pg_attribute currentColumn = TupleDescAttr(tupleDescription, columnIndex); Form_pg_attribute currentColumn = TupleDescAttr(tupleDescription, columnIndex);
Oid typeId = InvalidOid; Oid typeId = InvalidOid;
@ -1149,7 +1125,6 @@ static List *
RemoteFinalizedShardPlacementList(uint64 shardId) RemoteFinalizedShardPlacementList(uint64 shardId)
{ {
List *finalizedPlacementList = NIL; List *finalizedPlacementList = NIL;
PGresult *queryResult = NULL;
bool raiseInterrupts = true; bool raiseInterrupts = true;
StringInfo shardPlacementsCommand = makeStringInfo(); StringInfo shardPlacementsCommand = makeStringInfo();
@ -1159,13 +1134,12 @@ RemoteFinalizedShardPlacementList(uint64 shardId)
{ {
ReportConnectionError(masterConnection, ERROR); ReportConnectionError(masterConnection, ERROR);
} }
queryResult = GetRemoteCommandResult(masterConnection, raiseInterrupts); PGresult *queryResult = GetRemoteCommandResult(masterConnection, raiseInterrupts);
if (PQresultStatus(queryResult) == PGRES_TUPLES_OK) if (PQresultStatus(queryResult) == PGRES_TUPLES_OK)
{ {
int rowCount = PQntuples(queryResult); int rowCount = PQntuples(queryResult);
int rowIndex = 0;
for (rowIndex = 0; rowIndex < rowCount; rowIndex++) for (int rowIndex = 0; rowIndex < rowCount; rowIndex++)
{ {
char *placementIdString = PQgetvalue(queryResult, rowIndex, 0); char *placementIdString = PQgetvalue(queryResult, rowIndex, 0);
char *nodeName = pstrdup(PQgetvalue(queryResult, rowIndex, 1)); char *nodeName = pstrdup(PQgetvalue(queryResult, rowIndex, 1));
@ -1236,11 +1210,10 @@ ConstructCopyStatement(CopyStmt *copyStatement, int64 shardId, bool useBinaryCop
char *relationName = copyStatement->relation->relname; char *relationName = copyStatement->relation->relname;
char *shardName = pstrdup(relationName); char *shardName = pstrdup(relationName);
char *shardQualifiedName = NULL;
AppendShardIdToName(&shardName, shardId); AppendShardIdToName(&shardName, shardId);
shardQualifiedName = quote_qualified_identifier(schemaName, shardName); char *shardQualifiedName = quote_qualified_identifier(schemaName, shardName);
appendStringInfo(command, "COPY %s ", shardQualifiedName); appendStringInfo(command, "COPY %s ", shardQualifiedName);
@ -1331,7 +1304,6 @@ EndRemoteCopy(int64 shardId, List *connectionList)
foreach(connectionCell, connectionList) foreach(connectionCell, connectionList)
{ {
MultiConnection *connection = (MultiConnection *) lfirst(connectionCell); MultiConnection *connection = (MultiConnection *) lfirst(connectionCell);
PGresult *result = NULL;
bool raiseInterrupts = true; bool raiseInterrupts = true;
/* end the COPY input */ /* end the COPY input */
@ -1343,7 +1315,7 @@ EndRemoteCopy(int64 shardId, List *connectionList)
} }
/* check whether there were any COPY errors */ /* check whether there were any COPY errors */
result = GetRemoteCommandResult(connection, raiseInterrupts); PGresult *result = GetRemoteCommandResult(connection, raiseInterrupts);
if (PQresultStatus(result) != PGRES_COMMAND_OK) if (PQresultStatus(result) != PGRES_COMMAND_OK)
{ {
ReportCopyError(connection, result); ReportCopyError(connection, result);
@ -1487,14 +1459,13 @@ static Oid
TypeForColumnName(Oid relationId, TupleDesc tupleDescriptor, char *columnName) TypeForColumnName(Oid relationId, TupleDesc tupleDescriptor, char *columnName)
{ {
AttrNumber destAttrNumber = get_attnum(relationId, columnName); AttrNumber destAttrNumber = get_attnum(relationId, columnName);
Form_pg_attribute attr = NULL;
if (destAttrNumber == InvalidAttrNumber) if (destAttrNumber == InvalidAttrNumber)
{ {
ereport(ERROR, (errmsg("invalid attr? %s", columnName))); ereport(ERROR, (errmsg("invalid attr? %s", columnName)));
} }
attr = TupleDescAttr(tupleDescriptor, destAttrNumber - 1); Form_pg_attribute attr = TupleDescAttr(tupleDescriptor, destAttrNumber - 1);
return attr->atttypid; return attr->atttypid;
} }
@ -1508,9 +1479,8 @@ TypeArrayFromTupleDescriptor(TupleDesc tupleDescriptor)
{ {
int columnCount = tupleDescriptor->natts; int columnCount = tupleDescriptor->natts;
Oid *typeArray = palloc0(columnCount * sizeof(Oid)); Oid *typeArray = palloc0(columnCount * sizeof(Oid));
int columnIndex = 0;
for (columnIndex = 0; columnIndex < columnCount; columnIndex++) for (int columnIndex = 0; columnIndex < columnCount; columnIndex++)
{ {
Form_pg_attribute attr = TupleDescAttr(tupleDescriptor, columnIndex); Form_pg_attribute attr = TupleDescAttr(tupleDescriptor, columnIndex);
if (attr->attisdropped) if (attr->attisdropped)
@ -1537,15 +1507,13 @@ ColumnCoercionPaths(TupleDesc destTupleDescriptor, TupleDesc inputTupleDescripto
Oid destRelId, List *columnNameList, Oid destRelId, List *columnNameList,
Oid *finalColumnTypeArray) Oid *finalColumnTypeArray)
{ {
int columnIndex = 0;
int columnCount = inputTupleDescriptor->natts; int columnCount = inputTupleDescriptor->natts;
CopyCoercionData *coercePaths = palloc0(columnCount * sizeof(CopyCoercionData)); CopyCoercionData *coercePaths = palloc0(columnCount * sizeof(CopyCoercionData));
Oid *inputTupleTypes = TypeArrayFromTupleDescriptor(inputTupleDescriptor); Oid *inputTupleTypes = TypeArrayFromTupleDescriptor(inputTupleDescriptor);
ListCell *currentColumnName = list_head(columnNameList); ListCell *currentColumnName = list_head(columnNameList);
for (columnIndex = 0; columnIndex < columnCount; columnIndex++) for (int columnIndex = 0; columnIndex < columnCount; columnIndex++)
{ {
Oid destTupleType = InvalidOid;
Oid inputTupleType = inputTupleTypes[columnIndex]; Oid inputTupleType = inputTupleTypes[columnIndex];
char *columnName = lfirst(currentColumnName); char *columnName = lfirst(currentColumnName);
@ -1555,7 +1523,7 @@ ColumnCoercionPaths(TupleDesc destTupleDescriptor, TupleDesc inputTupleDescripto
continue; continue;
} }
destTupleType = TypeForColumnName(destRelId, destTupleDescriptor, columnName); Oid destTupleType = TypeForColumnName(destRelId, destTupleDescriptor, columnName);
finalColumnTypeArray[columnIndex] = destTupleType; finalColumnTypeArray[columnIndex] = destTupleType;
@ -1584,8 +1552,7 @@ TypeOutputFunctions(uint32 columnCount, Oid *typeIdArray, bool binaryFormat)
{ {
FmgrInfo *columnOutputFunctions = palloc0(columnCount * sizeof(FmgrInfo)); FmgrInfo *columnOutputFunctions = palloc0(columnCount * sizeof(FmgrInfo));
uint32 columnIndex = 0; for (uint32 columnIndex = 0; columnIndex < columnCount; columnIndex++)
for (columnIndex = 0; columnIndex < columnCount; columnIndex++)
{ {
FmgrInfo *currentOutputFunction = &columnOutputFunctions[columnIndex]; FmgrInfo *currentOutputFunction = &columnOutputFunctions[columnIndex];
Oid columnTypeId = typeIdArray[columnIndex]; Oid columnTypeId = typeIdArray[columnIndex];
@ -1665,7 +1632,6 @@ AppendCopyRowData(Datum *valueArray, bool *isNullArray, TupleDesc rowDescriptor,
uint32 totalColumnCount = (uint32) rowDescriptor->natts; uint32 totalColumnCount = (uint32) rowDescriptor->natts;
uint32 availableColumnCount = AvailableColumnCount(rowDescriptor); uint32 availableColumnCount = AvailableColumnCount(rowDescriptor);
uint32 appendedColumnCount = 0; uint32 appendedColumnCount = 0;
uint32 columnIndex = 0;
MemoryContext oldContext = MemoryContextSwitchTo(rowOutputState->rowcontext); MemoryContext oldContext = MemoryContextSwitchTo(rowOutputState->rowcontext);
@ -1673,7 +1639,7 @@ AppendCopyRowData(Datum *valueArray, bool *isNullArray, TupleDesc rowDescriptor,
{ {
CopySendInt16(rowOutputState, availableColumnCount); CopySendInt16(rowOutputState, availableColumnCount);
} }
for (columnIndex = 0; columnIndex < totalColumnCount; columnIndex++) for (uint32 columnIndex = 0; columnIndex < totalColumnCount; columnIndex++)
{ {
Form_pg_attribute currentColumn = TupleDescAttr(rowDescriptor, columnIndex); Form_pg_attribute currentColumn = TupleDescAttr(rowDescriptor, columnIndex);
Datum value = valueArray[columnIndex]; Datum value = valueArray[columnIndex];
@ -1803,9 +1769,8 @@ static uint32
AvailableColumnCount(TupleDesc tupleDescriptor) AvailableColumnCount(TupleDesc tupleDescriptor)
{ {
uint32 columnCount = 0; uint32 columnCount = 0;
uint32 columnIndex = 0;
for (columnIndex = 0; columnIndex < tupleDescriptor->natts; columnIndex++) for (uint32 columnIndex = 0; columnIndex < tupleDescriptor->natts; columnIndex++)
{ {
Form_pg_attribute currentColumn = TupleDescAttr(tupleDescriptor, columnIndex); Form_pg_attribute currentColumn = TupleDescAttr(tupleDescriptor, columnIndex);
@ -1916,13 +1881,11 @@ MasterCreateEmptyShard(char *relationName)
static int64 static int64
CreateEmptyShard(char *relationName) CreateEmptyShard(char *relationName)
{ {
int64 shardId = 0;
text *relationNameText = cstring_to_text(relationName); text *relationNameText = cstring_to_text(relationName);
Datum relationNameDatum = PointerGetDatum(relationNameText); Datum relationNameDatum = PointerGetDatum(relationNameText);
Datum shardIdDatum = DirectFunctionCall1(master_create_empty_shard, Datum shardIdDatum = DirectFunctionCall1(master_create_empty_shard,
relationNameDatum); relationNameDatum);
shardId = DatumGetInt64(shardIdDatum); int64 shardId = DatumGetInt64(shardIdDatum);
return shardId; return shardId;
} }
@ -1936,7 +1899,6 @@ static int64
RemoteCreateEmptyShard(char *relationName) RemoteCreateEmptyShard(char *relationName)
{ {
int64 shardId = 0; int64 shardId = 0;
PGresult *queryResult = NULL;
bool raiseInterrupts = true; bool raiseInterrupts = true;
StringInfo createEmptyShardCommand = makeStringInfo(); StringInfo createEmptyShardCommand = makeStringInfo();
@ -1946,7 +1908,7 @@ RemoteCreateEmptyShard(char *relationName)
{ {
ReportConnectionError(masterConnection, ERROR); ReportConnectionError(masterConnection, ERROR);
} }
queryResult = GetRemoteCommandResult(masterConnection, raiseInterrupts); PGresult *queryResult = GetRemoteCommandResult(masterConnection, raiseInterrupts);
if (PQresultStatus(queryResult) == PGRES_TUPLES_OK) if (PQresultStatus(queryResult) == PGRES_TUPLES_OK)
{ {
char *shardIdString = PQgetvalue((PGresult *) queryResult, 0, 0); char *shardIdString = PQgetvalue((PGresult *) queryResult, 0, 0);
@ -1991,7 +1953,6 @@ MasterUpdateShardStatistics(uint64 shardId)
static void static void
RemoteUpdateShardStatistics(uint64 shardId) RemoteUpdateShardStatistics(uint64 shardId)
{ {
PGresult *queryResult = NULL;
bool raiseInterrupts = true; bool raiseInterrupts = true;
StringInfo updateShardStatisticsCommand = makeStringInfo(); StringInfo updateShardStatisticsCommand = makeStringInfo();
@ -2002,7 +1963,7 @@ RemoteUpdateShardStatistics(uint64 shardId)
{ {
ReportConnectionError(masterConnection, ERROR); ReportConnectionError(masterConnection, ERROR);
} }
queryResult = GetRemoteCommandResult(masterConnection, raiseInterrupts); PGresult *queryResult = GetRemoteCommandResult(masterConnection, raiseInterrupts);
if (PQresultStatus(queryResult) != PGRES_TUPLES_OK) if (PQresultStatus(queryResult) != PGRES_TUPLES_OK)
{ {
ereport(ERROR, (errmsg("could not update shard statistics"))); ereport(ERROR, (errmsg("could not update shard statistics")));
@ -2067,7 +2028,6 @@ static void
CopyAttributeOutText(CopyOutState cstate, char *string) CopyAttributeOutText(CopyOutState cstate, char *string)
{ {
char *pointer = NULL; char *pointer = NULL;
char *start = NULL;
char c = '\0'; char c = '\0';
char delimc = cstate->delim[0]; char delimc = cstate->delim[0];
@ -2092,7 +2052,7 @@ CopyAttributeOutText(CopyOutState cstate, char *string)
* skip doing pg_encoding_mblen(), because in valid backend encodings, * skip doing pg_encoding_mblen(), because in valid backend encodings,
* extra bytes of a multibyte character never look like ASCII. * extra bytes of a multibyte character never look like ASCII.
*/ */
start = pointer; char *start = pointer;
while ((c = *pointer) != '\0') while ((c = *pointer) != '\0')
{ {
if ((unsigned char) c < (unsigned char) 0x20) if ((unsigned char) c < (unsigned char) 0x20)
@ -2184,9 +2144,8 @@ CreateCitusCopyDestReceiver(Oid tableId, List *columnNameList, int partitionColu
EState *executorState, bool stopOnFailure, EState *executorState, bool stopOnFailure,
char *intermediateResultIdPrefix) char *intermediateResultIdPrefix)
{ {
CitusCopyDestReceiver *copyDest = NULL; CitusCopyDestReceiver *copyDest = (CitusCopyDestReceiver *) palloc0(
sizeof(CitusCopyDestReceiver));
copyDest = (CitusCopyDestReceiver *) palloc0(sizeof(CitusCopyDestReceiver));
/* set up the DestReceiver function pointers */ /* set up the DestReceiver function pointers */
copyDest->pub.receiveSlot = CitusCopyDestReceiverReceive; copyDest->pub.receiveSlot = CitusCopyDestReceiverReceive;
@ -2225,20 +2184,14 @@ CitusCopyDestReceiverStartup(DestReceiver *dest, int operation,
Oid schemaOid = get_rel_namespace(tableId); Oid schemaOid = get_rel_namespace(tableId);
char *schemaName = get_namespace_name(schemaOid); char *schemaName = get_namespace_name(schemaOid);
Relation distributedRelation = NULL;
List *columnNameList = copyDest->columnNameList; List *columnNameList = copyDest->columnNameList;
List *quotedColumnNameList = NIL; List *quotedColumnNameList = NIL;
ListCell *columnNameCell = NULL; ListCell *columnNameCell = NULL;
char partitionMethod = '\0'; char partitionMethod = '\0';
DistTableCacheEntry *cacheEntry = NULL;
CopyStmt *copyStatement = NULL;
List *shardIntervalList = NULL;
CopyOutState copyOutState = NULL;
const char *delimiterCharacter = "\t"; const char *delimiterCharacter = "\t";
const char *nullPrintCharacter = "\\N"; const char *nullPrintCharacter = "\\N";
@ -2246,15 +2199,15 @@ CitusCopyDestReceiverStartup(DestReceiver *dest, int operation,
ErrorIfLocalExecutionHappened(); ErrorIfLocalExecutionHappened();
/* look up table properties */ /* look up table properties */
distributedRelation = heap_open(tableId, RowExclusiveLock); Relation distributedRelation = heap_open(tableId, RowExclusiveLock);
cacheEntry = DistributedTableCacheEntry(tableId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(tableId);
partitionMethod = cacheEntry->partitionMethod; partitionMethod = cacheEntry->partitionMethod;
copyDest->distributedRelation = distributedRelation; copyDest->distributedRelation = distributedRelation;
copyDest->tupleDescriptor = inputTupleDescriptor; copyDest->tupleDescriptor = inputTupleDescriptor;
/* load the list of shards and verify that we have shards to copy into */ /* load the list of shards and verify that we have shards to copy into */
shardIntervalList = LoadShardIntervalList(tableId); List *shardIntervalList = LoadShardIntervalList(tableId);
if (shardIntervalList == NIL) if (shardIntervalList == NIL)
{ {
if (partitionMethod == DISTRIBUTE_BY_HASH) if (partitionMethod == DISTRIBUTE_BY_HASH)
@ -2307,7 +2260,7 @@ CitusCopyDestReceiverStartup(DestReceiver *dest, int operation,
} }
/* define how tuples will be serialised */ /* define how tuples will be serialised */
copyOutState = (CopyOutState) palloc0(sizeof(CopyOutStateData)); CopyOutState copyOutState = (CopyOutState) palloc0(sizeof(CopyOutStateData));
copyOutState->delim = (char *) delimiterCharacter; copyOutState->delim = (char *) delimiterCharacter;
copyOutState->null_print = (char *) nullPrintCharacter; copyOutState->null_print = (char *) nullPrintCharacter;
copyOutState->null_print_client = (char *) nullPrintCharacter; copyOutState->null_print_client = (char *) nullPrintCharacter;
@ -2349,15 +2302,15 @@ CitusCopyDestReceiverStartup(DestReceiver *dest, int operation,
} }
/* define the template for the COPY statement that is sent to workers */ /* define the template for the COPY statement that is sent to workers */
copyStatement = makeNode(CopyStmt); CopyStmt *copyStatement = makeNode(CopyStmt);
if (copyDest->intermediateResultIdPrefix != NULL) if (copyDest->intermediateResultIdPrefix != NULL)
{ {
DefElem *formatResultOption = NULL;
copyStatement->relation = makeRangeVar(NULL, copyDest->intermediateResultIdPrefix, copyStatement->relation = makeRangeVar(NULL, copyDest->intermediateResultIdPrefix,
-1); -1);
formatResultOption = makeDefElem("format", (Node *) makeString("result"), -1); DefElem *formatResultOption = makeDefElem("format", (Node *) makeString("result"),
-1);
copyStatement->options = list_make1(formatResultOption); copyStatement->options = list_make1(formatResultOption);
} }
else else
@ -2422,7 +2375,6 @@ CitusSendTupleToPlacements(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest
TupleDesc tupleDescriptor = copyDest->tupleDescriptor; TupleDesc tupleDescriptor = copyDest->tupleDescriptor;
CopyStmt *copyStatement = copyDest->copyStatement; CopyStmt *copyStatement = copyDest->copyStatement;
CopyShardState *shardState = NULL;
CopyOutState copyOutState = copyDest->copyOutState; CopyOutState copyOutState = copyDest->copyOutState;
FmgrInfo *columnOutputFunctions = copyDest->columnOutputFunctions; FmgrInfo *columnOutputFunctions = copyDest->columnOutputFunctions;
CopyCoercionData *columnCoercionPaths = copyDest->columnCoercionPaths; CopyCoercionData *columnCoercionPaths = copyDest->columnCoercionPaths;
@ -2432,10 +2384,6 @@ CitusSendTupleToPlacements(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest
bool stopOnFailure = copyDest->stopOnFailure; bool stopOnFailure = copyDest->stopOnFailure;
Datum *columnValues = NULL;
bool *columnNulls = NULL;
int64 shardId = 0;
EState *executorState = copyDest->executorState; EState *executorState = copyDest->executorState;
MemoryContext executorTupleContext = GetPerTupleMemoryContext(executorState); MemoryContext executorTupleContext = GetPerTupleMemoryContext(executorState);
@ -2443,16 +2391,17 @@ CitusSendTupleToPlacements(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest
slot_getallattrs(slot); slot_getallattrs(slot);
columnValues = slot->tts_values; Datum *columnValues = slot->tts_values;
columnNulls = slot->tts_isnull; bool *columnNulls = slot->tts_isnull;
shardId = ShardIdForTuple(copyDest, columnValues, columnNulls); int64 shardId = ShardIdForTuple(copyDest, columnValues, columnNulls);
/* connections hash is kept in memory context */ /* connections hash is kept in memory context */
MemoryContextSwitchTo(copyDest->memoryContext); MemoryContextSwitchTo(copyDest->memoryContext);
shardState = GetShardState(shardId, copyDest->shardStateHash, CopyShardState *shardState = GetShardState(shardId, copyDest->shardStateHash,
copyDest->connectionStateHash, stopOnFailure, copyDest->connectionStateHash,
stopOnFailure,
&cachedShardStateFound); &cachedShardStateFound);
if (!cachedShardStateFound) if (!cachedShardStateFound)
{ {
@ -2564,7 +2513,6 @@ ShardIdForTuple(CitusCopyDestReceiver *copyDest, Datum *columnValues, bool *colu
int partitionColumnIndex = copyDest->partitionColumnIndex; int partitionColumnIndex = copyDest->partitionColumnIndex;
Datum partitionColumnValue = 0; Datum partitionColumnValue = 0;
CopyCoercionData *columnCoercionPaths = copyDest->columnCoercionPaths; CopyCoercionData *columnCoercionPaths = copyDest->columnCoercionPaths;
ShardInterval *shardInterval = NULL;
/* /*
* Find the partition column value and corresponding shard interval * Find the partition column value and corresponding shard interval
@ -2605,7 +2553,8 @@ ShardIdForTuple(CitusCopyDestReceiver *copyDest, Datum *columnValues, bool *colu
* For reference table, this function blindly returns the tables single * For reference table, this function blindly returns the tables single
* shard. * shard.
*/ */
shardInterval = FindShardInterval(partitionColumnValue, copyDest->tableMetadata); ShardInterval *shardInterval = FindShardInterval(partitionColumnValue,
copyDest->tableMetadata);
if (shardInterval == NULL) if (shardInterval == NULL)
{ {
ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
@ -2628,11 +2577,10 @@ CitusCopyDestReceiverShutdown(DestReceiver *destReceiver)
CitusCopyDestReceiver *copyDest = (CitusCopyDestReceiver *) destReceiver; CitusCopyDestReceiver *copyDest = (CitusCopyDestReceiver *) destReceiver;
HTAB *connectionStateHash = copyDest->connectionStateHash; HTAB *connectionStateHash = copyDest->connectionStateHash;
List *connectionStateList = NIL;
ListCell *connectionStateCell = NULL; ListCell *connectionStateCell = NULL;
Relation distributedRelation = copyDest->distributedRelation; Relation distributedRelation = copyDest->distributedRelation;
connectionStateList = ConnectionStateList(connectionStateHash); List *connectionStateList = ConnectionStateList(connectionStateHash);
PG_TRY(); PG_TRY();
{ {
@ -2820,21 +2768,20 @@ ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag, const char *queryS
else else
{ {
bool isFrom = copyStatement->is_from; bool isFrom = copyStatement->is_from;
Relation copiedRelation = NULL;
char *schemaName = NULL;
MemoryContext relationContext = NULL;
/* consider using RangeVarGetRelidExtended to check perms before locking */ /* consider using RangeVarGetRelidExtended to check perms before locking */
copiedRelation = heap_openrv(copyStatement->relation, Relation copiedRelation = heap_openrv(copyStatement->relation,
isFrom ? RowExclusiveLock : AccessShareLock); isFrom ? RowExclusiveLock :
AccessShareLock);
isDistributedRelation = IsDistributedTable(RelationGetRelid(copiedRelation)); isDistributedRelation = IsDistributedTable(RelationGetRelid(copiedRelation));
/* ensure future lookups hit the same relation */ /* ensure future lookups hit the same relation */
schemaName = get_namespace_name(RelationGetNamespace(copiedRelation)); char *schemaName = get_namespace_name(RelationGetNamespace(copiedRelation));
/* ensure we copy string into proper context */ /* ensure we copy string into proper context */
relationContext = GetMemoryChunkContext(copyStatement->relation); MemoryContext relationContext = GetMemoryChunkContext(
copyStatement->relation);
schemaName = MemoryContextStrdup(relationContext, schemaName); schemaName = MemoryContextStrdup(relationContext, schemaName);
copyStatement->relation->schemaname = schemaName; copyStatement->relation->schemaname = schemaName;
@ -2906,16 +2853,15 @@ ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag, const char *queryS
!copyStatement->is_from && !is_absolute_path(filename)) !copyStatement->is_from && !is_absolute_path(filename))
{ {
bool binaryCopyFormat = CopyStatementHasFormat(copyStatement, "binary"); bool binaryCopyFormat = CopyStatementHasFormat(copyStatement, "binary");
int64 tuplesSent = 0;
Query *query = NULL; Query *query = NULL;
Node *queryNode = copyStatement->query; Node *queryNode = copyStatement->query;
List *queryTreeList = NIL;
StringInfo userFilePath = makeStringInfo(); StringInfo userFilePath = makeStringInfo();
RawStmt *rawStmt = makeNode(RawStmt); RawStmt *rawStmt = makeNode(RawStmt);
rawStmt->stmt = queryNode; rawStmt->stmt = queryNode;
queryTreeList = pg_analyze_and_rewrite(rawStmt, queryString, NULL, 0, NULL); List *queryTreeList = pg_analyze_and_rewrite(rawStmt, queryString, NULL, 0,
NULL);
if (list_length(queryTreeList) != 1) if (list_length(queryTreeList) != 1)
{ {
@ -2931,7 +2877,7 @@ ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag, const char *queryS
*/ */
appendStringInfo(userFilePath, "%s.%u", filename, GetUserId()); appendStringInfo(userFilePath, "%s.%u", filename, GetUserId());
tuplesSent = WorkerExecuteSqlTask(query, filename, binaryCopyFormat); int64 tuplesSent = WorkerExecuteSqlTask(query, filename, binaryCopyFormat);
snprintf(completionTag, COMPLETION_TAG_BUFSIZE, snprintf(completionTag, COMPLETION_TAG_BUFSIZE,
"COPY " UINT64_FORMAT, tuplesSent); "COPY " UINT64_FORMAT, tuplesSent);
@ -2952,7 +2898,6 @@ ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag, const char *queryS
static void static void
CreateLocalTable(RangeVar *relation, char *nodeName, int32 nodePort) CreateLocalTable(RangeVar *relation, char *nodeName, int32 nodePort)
{ {
List *ddlCommandList = NIL;
ListCell *ddlCommandCell = NULL; ListCell *ddlCommandCell = NULL;
char *relationName = relation->relname; char *relationName = relation->relname;
@ -2964,7 +2909,7 @@ CreateLocalTable(RangeVar *relation, char *nodeName, int32 nodePort)
* enough; therefore, we just throw an error which says that we could not * enough; therefore, we just throw an error which says that we could not
* run the copy operation. * run the copy operation.
*/ */
ddlCommandList = TableDDLCommandList(nodeName, nodePort, qualifiedRelationName); List *ddlCommandList = TableDDLCommandList(nodeName, nodePort, qualifiedRelationName);
if (ddlCommandList == NIL) if (ddlCommandList == NIL)
{ {
ereport(ERROR, (errmsg("could not run copy from the worker node"))); ereport(ERROR, (errmsg("could not run copy from the worker node")));
@ -3045,14 +2990,13 @@ CheckCopyPermissions(CopyStmt *copyStatement)
AclMode required_access = (is_from ? ACL_INSERT : ACL_SELECT); AclMode required_access = (is_from ? ACL_INSERT : ACL_SELECT);
List *attnums; List *attnums;
ListCell *cur; ListCell *cur;
RangeTblEntry *rte;
rel = heap_openrv(copyStatement->relation, rel = heap_openrv(copyStatement->relation,
is_from ? RowExclusiveLock : AccessShareLock); is_from ? RowExclusiveLock : AccessShareLock);
relid = RelationGetRelid(rel); relid = RelationGetRelid(rel);
rte = makeNode(RangeTblEntry); RangeTblEntry *rte = makeNode(RangeTblEntry);
rte->rtekind = RTE_RELATION; rte->rtekind = RTE_RELATION;
rte->relid = relid; rte->relid = relid;
rte->relkind = rel->rd_rel->relkind; rte->relkind = rel->rd_rel->relkind;
@ -3166,17 +3110,15 @@ CopyGetAttnums(TupleDesc tupDesc, Relation rel, List *attnamelist)
static HTAB * static HTAB *
CreateConnectionStateHash(MemoryContext memoryContext) CreateConnectionStateHash(MemoryContext memoryContext)
{ {
HTAB *connectionStateHash = NULL;
int hashFlags = 0;
HASHCTL info; HASHCTL info;
memset(&info, 0, sizeof(info)); memset(&info, 0, sizeof(info));
info.keysize = sizeof(int); info.keysize = sizeof(int);
info.entrysize = sizeof(CopyConnectionState); info.entrysize = sizeof(CopyConnectionState);
info.hcxt = memoryContext; info.hcxt = memoryContext;
hashFlags = (HASH_ELEM | HASH_CONTEXT | HASH_BLOBS); int hashFlags = (HASH_ELEM | HASH_CONTEXT | HASH_BLOBS);
connectionStateHash = hash_create("Copy Connection State Hash", 128, &info, HTAB *connectionStateHash = hash_create("Copy Connection State Hash", 128, &info,
hashFlags); hashFlags);
return connectionStateHash; return connectionStateHash;
@ -3191,17 +3133,15 @@ CreateConnectionStateHash(MemoryContext memoryContext)
static HTAB * static HTAB *
CreateShardStateHash(MemoryContext memoryContext) CreateShardStateHash(MemoryContext memoryContext)
{ {
HTAB *shardStateHash = NULL;
int hashFlags = 0;
HASHCTL info; HASHCTL info;
memset(&info, 0, sizeof(info)); memset(&info, 0, sizeof(info));
info.keysize = sizeof(uint64); info.keysize = sizeof(uint64);
info.entrysize = sizeof(CopyShardState); info.entrysize = sizeof(CopyShardState);
info.hcxt = memoryContext; info.hcxt = memoryContext;
hashFlags = (HASH_ELEM | HASH_CONTEXT | HASH_BLOBS); int hashFlags = (HASH_ELEM | HASH_CONTEXT | HASH_BLOBS);
shardStateHash = hash_create("Copy Shard State Hash", 128, &info, hashFlags); HTAB *shardStateHash = hash_create("Copy Shard State Hash", 128, &info, hashFlags);
return shardStateHash; return shardStateHash;
} }
@ -3214,14 +3154,15 @@ CreateShardStateHash(MemoryContext memoryContext)
static CopyConnectionState * static CopyConnectionState *
GetConnectionState(HTAB *connectionStateHash, MultiConnection *connection) GetConnectionState(HTAB *connectionStateHash, MultiConnection *connection)
{ {
CopyConnectionState *connectionState = NULL;
bool found = false; bool found = false;
int sock = PQsocket(connection->pgConn); int sock = PQsocket(connection->pgConn);
Assert(sock != -1); Assert(sock != -1);
connectionState = (CopyConnectionState *) hash_search(connectionStateHash, &sock, CopyConnectionState *connectionState = (CopyConnectionState *) hash_search(
HASH_ENTER, &found); connectionStateHash, &sock,
HASH_ENTER,
&found);
if (!found) if (!found)
{ {
connectionState->socket = sock; connectionState->socket = sock;
@ -3243,11 +3184,11 @@ ConnectionStateList(HTAB *connectionStateHash)
{ {
List *connectionStateList = NIL; List *connectionStateList = NIL;
HASH_SEQ_STATUS status; HASH_SEQ_STATUS status;
CopyConnectionState *connectionState = NULL;
hash_seq_init(&status, connectionStateHash); hash_seq_init(&status, connectionStateHash);
connectionState = (CopyConnectionState *) hash_seq_search(&status); CopyConnectionState *connectionState = (CopyConnectionState *) hash_seq_search(
&status);
while (connectionState != NULL) while (connectionState != NULL)
{ {
connectionStateList = lappend(connectionStateList, connectionState); connectionStateList = lappend(connectionStateList, connectionState);
@ -3268,9 +3209,7 @@ static CopyShardState *
GetShardState(uint64 shardId, HTAB *shardStateHash, GetShardState(uint64 shardId, HTAB *shardStateHash,
HTAB *connectionStateHash, bool stopOnFailure, bool *found) HTAB *connectionStateHash, bool stopOnFailure, bool *found)
{ {
CopyShardState *shardState = NULL; CopyShardState *shardState = (CopyShardState *) hash_search(shardStateHash, &shardId,
shardState = (CopyShardState *) hash_search(shardStateHash, &shardId,
HASH_ENTER, found); HASH_ENTER, found);
if (!*found) if (!*found)
{ {
@ -3292,7 +3231,6 @@ InitializeCopyShardState(CopyShardState *shardState,
HTAB *connectionStateHash, uint64 shardId, HTAB *connectionStateHash, uint64 shardId,
bool stopOnFailure) bool stopOnFailure)
{ {
List *finalizedPlacementList = NIL;
ListCell *placementCell = NULL; ListCell *placementCell = NULL;
int failedPlacementCount = 0; int failedPlacementCount = 0;
@ -3306,7 +3244,7 @@ InitializeCopyShardState(CopyShardState *shardState,
/* release finalized placement list at the end of this function */ /* release finalized placement list at the end of this function */
MemoryContext oldContext = MemoryContextSwitchTo(localContext); MemoryContext oldContext = MemoryContextSwitchTo(localContext);
finalizedPlacementList = MasterShardPlacementList(shardId); List *finalizedPlacementList = MasterShardPlacementList(shardId);
MemoryContextSwitchTo(oldContext); MemoryContextSwitchTo(oldContext);
@ -3316,8 +3254,6 @@ InitializeCopyShardState(CopyShardState *shardState,
foreach(placementCell, finalizedPlacementList) foreach(placementCell, finalizedPlacementList)
{ {
ShardPlacement *placement = (ShardPlacement *) lfirst(placementCell); ShardPlacement *placement = (ShardPlacement *) lfirst(placementCell);
CopyConnectionState *connectionState = NULL;
CopyPlacementState *placementState = NULL;
MultiConnection *connection = MultiConnection *connection =
CopyGetPlacementConnection(placement, stopOnFailure); CopyGetPlacementConnection(placement, stopOnFailure);
@ -3327,7 +3263,8 @@ InitializeCopyShardState(CopyShardState *shardState,
continue; continue;
} }
connectionState = GetConnectionState(connectionStateHash, connection); CopyConnectionState *connectionState = GetConnectionState(connectionStateHash,
connection);
/* /*
* If this is the first time we are using this connection for copying a * If this is the first time we are using this connection for copying a
@ -3338,7 +3275,7 @@ InitializeCopyShardState(CopyShardState *shardState,
RemoteTransactionBeginIfNecessary(connection); RemoteTransactionBeginIfNecessary(connection);
} }
placementState = palloc0(sizeof(CopyPlacementState)); CopyPlacementState *placementState = palloc0(sizeof(CopyPlacementState));
placementState->shardState = shardState; placementState->shardState = shardState;
placementState->data = makeStringInfo(); placementState->data = makeStringInfo();
placementState->connectionState = connectionState; placementState->connectionState = connectionState;
@ -3380,18 +3317,18 @@ InitializeCopyShardState(CopyShardState *shardState,
static MultiConnection * static MultiConnection *
CopyGetPlacementConnection(ShardPlacement *placement, bool stopOnFailure) CopyGetPlacementConnection(ShardPlacement *placement, bool stopOnFailure)
{ {
MultiConnection *connection = NULL;
uint32 connectionFlags = FOR_DML; uint32 connectionFlags = FOR_DML;
char *nodeUser = CurrentUserName(); char *nodeUser = CurrentUserName();
ShardPlacementAccess *placementAccess = NULL;
/* /*
* Determine whether the task has to be assigned to a particular connection * Determine whether the task has to be assigned to a particular connection
* due to a preceding access to the placement in the same transaction. * due to a preceding access to the placement in the same transaction.
*/ */
placementAccess = CreatePlacementAccess(placement, PLACEMENT_ACCESS_DML); ShardPlacementAccess *placementAccess = CreatePlacementAccess(placement,
connection = GetConnectionIfPlacementAccessedInXact(connectionFlags, PLACEMENT_ACCESS_DML);
list_make1(placementAccess), MultiConnection *connection = GetConnectionIfPlacementAccessedInXact(connectionFlags,
list_make1(
placementAccess),
NULL); NULL);
if (connection != NULL) if (connection != NULL)
{ {
@ -3451,21 +3388,19 @@ static void
StartPlacementStateCopyCommand(CopyPlacementState *placementState, StartPlacementStateCopyCommand(CopyPlacementState *placementState,
CopyStmt *copyStatement, CopyOutState copyOutState) CopyStmt *copyStatement, CopyOutState copyOutState)
{ {
StringInfo copyCommand = NULL;
PGresult *result = NULL;
MultiConnection *connection = placementState->connectionState->connection; MultiConnection *connection = placementState->connectionState->connection;
uint64 shardId = placementState->shardState->shardId; uint64 shardId = placementState->shardState->shardId;
bool raiseInterrupts = true; bool raiseInterrupts = true;
bool binaryCopy = copyOutState->binary; bool binaryCopy = copyOutState->binary;
copyCommand = ConstructCopyStatement(copyStatement, shardId, binaryCopy); StringInfo copyCommand = ConstructCopyStatement(copyStatement, shardId, binaryCopy);
if (!SendRemoteCommand(connection, copyCommand->data)) if (!SendRemoteCommand(connection, copyCommand->data))
{ {
ReportConnectionError(connection, ERROR); ReportConnectionError(connection, ERROR);
} }
result = GetRemoteCommandResult(connection, raiseInterrupts); PGresult *result = GetRemoteCommandResult(connection, raiseInterrupts);
if (PQresultStatus(result) != PGRES_COPY_IN) if (PQresultStatus(result) != PGRES_COPY_IN)
{ {
ReportResultError(connection, result, ERROR); ReportResultError(connection, result, ERROR);

View File

@ -30,8 +30,6 @@ PlanRenameStmt(RenameStmt *renameStmt, const char *renameCommand)
{ {
Oid objectRelationId = InvalidOid; /* SQL Object OID */ Oid objectRelationId = InvalidOid; /* SQL Object OID */
Oid tableRelationId = InvalidOid; /* Relation OID, maybe not the same. */ Oid tableRelationId = InvalidOid; /* Relation OID, maybe not the same. */
bool isDistributedRelation = false;
DDLJob *ddlJob = NULL;
/* /*
* We only support some of the PostgreSQL supported RENAME statements, and * We only support some of the PostgreSQL supported RENAME statements, and
@ -97,7 +95,7 @@ PlanRenameStmt(RenameStmt *renameStmt, const char *renameCommand)
return NIL; return NIL;
} }
isDistributedRelation = IsDistributedTable(tableRelationId); bool isDistributedRelation = IsDistributedTable(tableRelationId);
if (!isDistributedRelation) if (!isDistributedRelation)
{ {
return NIL; return NIL;
@ -110,7 +108,7 @@ PlanRenameStmt(RenameStmt *renameStmt, const char *renameCommand)
*/ */
ErrorIfUnsupportedRenameStmt(renameStmt); ErrorIfUnsupportedRenameStmt(renameStmt);
ddlJob = palloc0(sizeof(DDLJob)); DDLJob *ddlJob = palloc0(sizeof(DDLJob));
ddlJob->targetRelationId = tableRelationId; ddlJob->targetRelationId = tableRelationId;
ddlJob->concurrentIndexCmd = false; ddlJob->concurrentIndexCmd = false;
ddlJob->commandString = renameCommand; ddlJob->commandString = renameCommand;

View File

@ -46,7 +46,6 @@ List *
ProcessAlterRoleStmt(AlterRoleStmt *stmt, const char *queryString) ProcessAlterRoleStmt(AlterRoleStmt *stmt, const char *queryString)
{ {
ListCell *optionCell = NULL; ListCell *optionCell = NULL;
List *commands = NIL;
if (!EnableAlterRolePropagation || !IsCoordinator()) if (!EnableAlterRolePropagation || !IsCoordinator())
{ {
@ -82,7 +81,7 @@ ProcessAlterRoleStmt(AlterRoleStmt *stmt, const char *queryString)
break; break;
} }
} }
commands = list_make1((void *) CreateAlterRoleIfExistsCommand(stmt)); List *commands = list_make1((void *) CreateAlterRoleIfExistsCommand(stmt));
return NodeDDLTaskList(ALL_WORKERS, commands); return NodeDDLTaskList(ALL_WORKERS, commands);
} }
@ -120,14 +119,13 @@ ExtractEncryptedPassword(Oid roleOid)
TupleDesc pgAuthIdDescription = RelationGetDescr(pgAuthId); TupleDesc pgAuthIdDescription = RelationGetDescr(pgAuthId);
HeapTuple tuple = SearchSysCache1(AUTHOID, roleOid); HeapTuple tuple = SearchSysCache1(AUTHOID, roleOid);
bool isNull = true; bool isNull = true;
Datum passwordDatum;
if (!HeapTupleIsValid(tuple)) if (!HeapTupleIsValid(tuple))
{ {
return NULL; return NULL;
} }
passwordDatum = heap_getattr(tuple, Anum_pg_authid_rolpassword, Datum passwordDatum = heap_getattr(tuple, Anum_pg_authid_rolpassword,
pgAuthIdDescription, &isNull); pgAuthIdDescription, &isNull);
heap_close(pgAuthId, AccessShareLock); heap_close(pgAuthId, AccessShareLock);
@ -151,8 +149,6 @@ GenerateAlterRoleIfExistsCommand(HeapTuple tuple, TupleDesc pgAuthIdDescription)
{ {
char *rolPassword = ""; char *rolPassword = "";
char *rolValidUntil = "infinity"; char *rolValidUntil = "infinity";
Datum rolValidUntilDatum;
Datum rolPasswordDatum;
bool isNull = true; bool isNull = true;
Form_pg_authid role = ((Form_pg_authid) GETSTRUCT(tuple)); Form_pg_authid role = ((Form_pg_authid) GETSTRUCT(tuple));
AlterRoleStmt *stmt = makeNode(AlterRoleStmt); AlterRoleStmt *stmt = makeNode(AlterRoleStmt);
@ -199,7 +195,7 @@ GenerateAlterRoleIfExistsCommand(HeapTuple tuple, TupleDesc pgAuthIdDescription)
makeDefElemInt("connectionlimit", role->rolconnlimit)); makeDefElemInt("connectionlimit", role->rolconnlimit));
rolPasswordDatum = heap_getattr(tuple, Anum_pg_authid_rolpassword, Datum rolPasswordDatum = heap_getattr(tuple, Anum_pg_authid_rolpassword,
pgAuthIdDescription, &isNull); pgAuthIdDescription, &isNull);
if (!isNull) if (!isNull)
{ {
@ -214,7 +210,7 @@ GenerateAlterRoleIfExistsCommand(HeapTuple tuple, TupleDesc pgAuthIdDescription)
stmt->options = lappend(stmt->options, makeDefElem("password", NULL, -1)); stmt->options = lappend(stmt->options, makeDefElem("password", NULL, -1));
} }
rolValidUntilDatum = heap_getattr(tuple, Anum_pg_authid_rolvaliduntil, Datum rolValidUntilDatum = heap_getattr(tuple, Anum_pg_authid_rolvaliduntil,
pgAuthIdDescription, &isNull); pgAuthIdDescription, &isNull);
if (!isNull) if (!isNull)
{ {

View File

@ -158,14 +158,12 @@ PlanAlterObjectSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString)
List * List *
PlanAlterTableSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString) PlanAlterTableSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString)
{ {
Oid relationId = InvalidOid;
if (stmt->relation == NULL) if (stmt->relation == NULL)
{ {
return NIL; return NIL;
} }
relationId = RangeVarGetRelid(stmt->relation, Oid relationId = RangeVarGetRelid(stmt->relation,
AccessExclusiveLock, AccessExclusiveLock,
stmt->missing_ok); stmt->missing_ok);

View File

@ -56,7 +56,6 @@ ErrorIfDistributedAlterSeqOwnedBy(AlterSeqStmt *alterSeqStmt)
{ {
Oid sequenceId = RangeVarGetRelid(alterSeqStmt->sequence, AccessShareLock, Oid sequenceId = RangeVarGetRelid(alterSeqStmt->sequence, AccessShareLock,
alterSeqStmt->missing_ok); alterSeqStmt->missing_ok);
bool sequenceOwned = false;
Oid ownedByTableId = InvalidOid; Oid ownedByTableId = InvalidOid;
Oid newOwnedByTableId = InvalidOid; Oid newOwnedByTableId = InvalidOid;
int32 ownedByColumnId = 0; int32 ownedByColumnId = 0;
@ -68,7 +67,7 @@ ErrorIfDistributedAlterSeqOwnedBy(AlterSeqStmt *alterSeqStmt)
return; return;
} }
sequenceOwned = sequenceIsOwned(sequenceId, DEPENDENCY_AUTO, &ownedByTableId, bool sequenceOwned = sequenceIsOwned(sequenceId, DEPENDENCY_AUTO, &ownedByTableId,
&ownedByColumnId); &ownedByColumnId);
if (!sequenceOwned) if (!sequenceOwned)
{ {

View File

@ -73,7 +73,6 @@ ProcessDropTableStmt(DropStmt *dropTableStatement)
List *tableNameList = (List *) lfirst(dropTableCell); List *tableNameList = (List *) lfirst(dropTableCell);
RangeVar *tableRangeVar = makeRangeVarFromNameList(tableNameList); RangeVar *tableRangeVar = makeRangeVarFromNameList(tableNameList);
bool missingOK = true; bool missingOK = true;
List *partitionList = NIL;
ListCell *partitionCell = NULL; ListCell *partitionCell = NULL;
Oid relationId = RangeVarGetRelid(tableRangeVar, AccessShareLock, missingOK); Oid relationId = RangeVarGetRelid(tableRangeVar, AccessShareLock, missingOK);
@ -98,7 +97,7 @@ ProcessDropTableStmt(DropStmt *dropTableStatement)
EnsureCoordinator(); EnsureCoordinator();
partitionList = PartitionList(relationId); List *partitionList = PartitionList(relationId);
if (list_length(partitionList) == 0) if (list_length(partitionList) == 0)
{ {
continue; continue;
@ -254,14 +253,7 @@ ProcessAlterTableStmtAttachPartition(AlterTableStmt *alterTableStatement)
List * List *
PlanAlterTableStmt(AlterTableStmt *alterTableStatement, const char *alterTableCommand) PlanAlterTableStmt(AlterTableStmt *alterTableStatement, const char *alterTableCommand)
{ {
List *ddlJobs = NIL;
DDLJob *ddlJob = NULL;
LOCKMODE lockmode = 0;
Oid leftRelationId = InvalidOid;
Oid rightRelationId = InvalidOid; Oid rightRelationId = InvalidOid;
char leftRelationKind;
bool isDistributedRelation = false;
List *commandList = NIL;
ListCell *commandCell = NULL; ListCell *commandCell = NULL;
bool executeSequentially = false; bool executeSequentially = false;
@ -271,8 +263,8 @@ PlanAlterTableStmt(AlterTableStmt *alterTableStatement, const char *alterTableCo
return NIL; return NIL;
} }
lockmode = AlterTableGetLockLevel(alterTableStatement->cmds); LOCKMODE lockmode = AlterTableGetLockLevel(alterTableStatement->cmds);
leftRelationId = AlterTableLookupRelation(alterTableStatement, lockmode); Oid leftRelationId = AlterTableLookupRelation(alterTableStatement, lockmode);
if (!OidIsValid(leftRelationId)) if (!OidIsValid(leftRelationId))
{ {
return NIL; return NIL;
@ -283,13 +275,13 @@ PlanAlterTableStmt(AlterTableStmt *alterTableStatement, const char *alterTableCo
* SET/SET storage parameters in Citus, so we might have to check for * SET/SET storage parameters in Citus, so we might have to check for
* another relation here. * another relation here.
*/ */
leftRelationKind = get_rel_relkind(leftRelationId); char leftRelationKind = get_rel_relkind(leftRelationId);
if (leftRelationKind == RELKIND_INDEX) if (leftRelationKind == RELKIND_INDEX)
{ {
leftRelationId = IndexGetRelation(leftRelationId, false); leftRelationId = IndexGetRelation(leftRelationId, false);
} }
isDistributedRelation = IsDistributedTable(leftRelationId); bool isDistributedRelation = IsDistributedTable(leftRelationId);
if (!isDistributedRelation) if (!isDistributedRelation)
{ {
return NIL; return NIL;
@ -317,7 +309,7 @@ PlanAlterTableStmt(AlterTableStmt *alterTableStatement, const char *alterTableCo
* set skip_validation to true to prevent PostgreSQL to verify validity of the * set skip_validation to true to prevent PostgreSQL to verify validity of the
* foreign constraint in master. Validity will be checked in workers anyway. * foreign constraint in master. Validity will be checked in workers anyway.
*/ */
commandList = alterTableStatement->cmds; List *commandList = alterTableStatement->cmds;
foreach(commandCell, commandList) foreach(commandCell, commandList)
{ {
@ -426,7 +418,7 @@ PlanAlterTableStmt(AlterTableStmt *alterTableStatement, const char *alterTableCo
SetLocalMultiShardModifyModeToSequential(); SetLocalMultiShardModifyModeToSequential();
} }
ddlJob = palloc0(sizeof(DDLJob)); DDLJob *ddlJob = palloc0(sizeof(DDLJob));
ddlJob->targetRelationId = leftRelationId; ddlJob->targetRelationId = leftRelationId;
ddlJob->concurrentIndexCmd = false; ddlJob->concurrentIndexCmd = false;
ddlJob->commandString = alterTableCommand; ddlJob->commandString = alterTableCommand;
@ -450,7 +442,7 @@ PlanAlterTableStmt(AlterTableStmt *alterTableStatement, const char *alterTableCo
ddlJob->taskList = DDLTaskList(leftRelationId, alterTableCommand); ddlJob->taskList = DDLTaskList(leftRelationId, alterTableCommand);
} }
ddlJobs = list_make1(ddlJob); List *ddlJobs = list_make1(ddlJob);
return ddlJobs; return ddlJobs;
} }
@ -465,10 +457,6 @@ Node *
WorkerProcessAlterTableStmt(AlterTableStmt *alterTableStatement, WorkerProcessAlterTableStmt(AlterTableStmt *alterTableStatement,
const char *alterTableCommand) const char *alterTableCommand)
{ {
LOCKMODE lockmode = 0;
Oid leftRelationId = InvalidOid;
bool isDistributedRelation = false;
List *commandList = NIL;
ListCell *commandCell = NULL; ListCell *commandCell = NULL;
/* first check whether a distributed relation is affected */ /* first check whether a distributed relation is affected */
@ -477,14 +465,14 @@ WorkerProcessAlterTableStmt(AlterTableStmt *alterTableStatement,
return (Node *) alterTableStatement; return (Node *) alterTableStatement;
} }
lockmode = AlterTableGetLockLevel(alterTableStatement->cmds); LOCKMODE lockmode = AlterTableGetLockLevel(alterTableStatement->cmds);
leftRelationId = AlterTableLookupRelation(alterTableStatement, lockmode); Oid leftRelationId = AlterTableLookupRelation(alterTableStatement, lockmode);
if (!OidIsValid(leftRelationId)) if (!OidIsValid(leftRelationId))
{ {
return (Node *) alterTableStatement; return (Node *) alterTableStatement;
} }
isDistributedRelation = IsDistributedTable(leftRelationId); bool isDistributedRelation = IsDistributedTable(leftRelationId);
if (!isDistributedRelation) if (!isDistributedRelation)
{ {
return (Node *) alterTableStatement; return (Node *) alterTableStatement;
@ -496,7 +484,7 @@ WorkerProcessAlterTableStmt(AlterTableStmt *alterTableStatement,
* set skip_validation to true to prevent PostgreSQL to verify validity of the * set skip_validation to true to prevent PostgreSQL to verify validity of the
* foreign constraint in master. Validity will be checked in workers anyway. * foreign constraint in master. Validity will be checked in workers anyway.
*/ */
commandList = alterTableStatement->cmds; List *commandList = alterTableStatement->cmds;
foreach(commandCell, commandList) foreach(commandCell, commandList)
{ {
@ -559,9 +547,6 @@ IsAlterTableRenameStmt(RenameStmt *renameStmt)
void void
ErrorIfAlterDropsPartitionColumn(AlterTableStmt *alterTableStatement) ErrorIfAlterDropsPartitionColumn(AlterTableStmt *alterTableStatement)
{ {
LOCKMODE lockmode = 0;
Oid leftRelationId = InvalidOid;
bool isDistributedRelation = false;
List *commandList = alterTableStatement->cmds; List *commandList = alterTableStatement->cmds;
ListCell *commandCell = NULL; ListCell *commandCell = NULL;
@ -571,14 +556,14 @@ ErrorIfAlterDropsPartitionColumn(AlterTableStmt *alterTableStatement)
return; return;
} }
lockmode = AlterTableGetLockLevel(alterTableStatement->cmds); LOCKMODE lockmode = AlterTableGetLockLevel(alterTableStatement->cmds);
leftRelationId = AlterTableLookupRelation(alterTableStatement, lockmode); Oid leftRelationId = AlterTableLookupRelation(alterTableStatement, lockmode);
if (!OidIsValid(leftRelationId)) if (!OidIsValid(leftRelationId))
{ {
return; return;
} }
isDistributedRelation = IsDistributedTable(leftRelationId); bool isDistributedRelation = IsDistributedTable(leftRelationId);
if (!isDistributedRelation) if (!isDistributedRelation)
{ {
return; return;
@ -613,11 +598,9 @@ PostProcessAlterTableStmt(AlterTableStmt *alterTableStatement)
{ {
List *commandList = alterTableStatement->cmds; List *commandList = alterTableStatement->cmds;
ListCell *commandCell = NULL; ListCell *commandCell = NULL;
LOCKMODE lockmode = NoLock;
Oid relationId = InvalidOid;
lockmode = AlterTableGetLockLevel(alterTableStatement->cmds); LOCKMODE lockmode = AlterTableGetLockLevel(alterTableStatement->cmds);
relationId = AlterTableLookupRelation(alterTableStatement, lockmode); Oid relationId = AlterTableLookupRelation(alterTableStatement, lockmode);
if (relationId != InvalidOid) if (relationId != InvalidOid)
{ {
@ -634,8 +617,6 @@ PostProcessAlterTableStmt(AlterTableStmt *alterTableStatement)
if (alterTableType == AT_AddConstraint) if (alterTableType == AT_AddConstraint)
{ {
Constraint *constraint = NULL;
Assert(list_length(commandList) == 1); Assert(list_length(commandList) == 1);
ErrorIfUnsupportedAlterAddConstraintStmt(alterTableStatement); ErrorIfUnsupportedAlterAddConstraintStmt(alterTableStatement);
@ -645,7 +626,7 @@ PostProcessAlterTableStmt(AlterTableStmt *alterTableStatement)
continue; continue;
} }
constraint = (Constraint *) command->def; Constraint *constraint = (Constraint *) command->def;
if (constraint->contype == CONSTR_FOREIGN) if (constraint->contype == CONSTR_FOREIGN)
{ {
InvalidateForeignKeyGraph(); InvalidateForeignKeyGraph();
@ -653,11 +634,10 @@ PostProcessAlterTableStmt(AlterTableStmt *alterTableStatement)
} }
else if (alterTableType == AT_AddColumn) else if (alterTableType == AT_AddColumn)
{ {
List *columnConstraints = NIL;
ListCell *columnConstraint = NULL; ListCell *columnConstraint = NULL;
ColumnDef *columnDefinition = (ColumnDef *) command->def; ColumnDef *columnDefinition = (ColumnDef *) command->def;
columnConstraints = columnDefinition->constraints; List *columnConstraints = columnDefinition->constraints;
if (columnConstraints) if (columnConstraints)
{ {
ErrorIfUnsupportedAlterAddConstraintStmt(alterTableStatement); ErrorIfUnsupportedAlterAddConstraintStmt(alterTableStatement);
@ -792,8 +772,6 @@ void
ErrorIfUnsupportedConstraint(Relation relation, char distributionMethod, ErrorIfUnsupportedConstraint(Relation relation, char distributionMethod,
Var *distributionColumn, uint32 colocationId) Var *distributionColumn, uint32 colocationId)
{ {
char *relationName = NULL;
List *indexOidList = NULL;
ListCell *indexOidCell = NULL; ListCell *indexOidCell = NULL;
/* /*
@ -817,21 +795,17 @@ ErrorIfUnsupportedConstraint(Relation relation, char distributionMethod,
return; return;
} }
relationName = RelationGetRelationName(relation); char *relationName = RelationGetRelationName(relation);
indexOidList = RelationGetIndexList(relation); List *indexOidList = RelationGetIndexList(relation);
foreach(indexOidCell, indexOidList) foreach(indexOidCell, indexOidList)
{ {
Oid indexOid = lfirst_oid(indexOidCell); Oid indexOid = lfirst_oid(indexOidCell);
Relation indexDesc = index_open(indexOid, RowExclusiveLock); Relation indexDesc = index_open(indexOid, RowExclusiveLock);
IndexInfo *indexInfo = NULL;
AttrNumber *attributeNumberArray = NULL;
bool hasDistributionColumn = false; bool hasDistributionColumn = false;
int attributeCount = 0;
int attributeIndex = 0;
/* extract index key information from the index's pg_index info */ /* extract index key information from the index's pg_index info */
indexInfo = BuildIndexInfo(indexDesc); IndexInfo *indexInfo = BuildIndexInfo(indexDesc);
/* only check unique indexes and exclusion constraints. */ /* only check unique indexes and exclusion constraints. */
if (indexInfo->ii_Unique == false && indexInfo->ii_ExclusionOps == NULL) if (indexInfo->ii_Unique == false && indexInfo->ii_ExclusionOps == NULL)
@ -856,22 +830,20 @@ ErrorIfUnsupportedConstraint(Relation relation, char distributionMethod,
errhint("Consider using hash partitioning."))); errhint("Consider using hash partitioning.")));
} }
attributeCount = indexInfo->ii_NumIndexAttrs; int attributeCount = indexInfo->ii_NumIndexAttrs;
attributeNumberArray = indexInfo->ii_IndexAttrNumbers; AttrNumber *attributeNumberArray = indexInfo->ii_IndexAttrNumbers;
for (attributeIndex = 0; attributeIndex < attributeCount; attributeIndex++) for (int attributeIndex = 0; attributeIndex < attributeCount; attributeIndex++)
{ {
AttrNumber attributeNumber = attributeNumberArray[attributeIndex]; AttrNumber attributeNumber = attributeNumberArray[attributeIndex];
bool uniqueConstraint = false;
bool exclusionConstraintWithEquality = false;
if (distributionColumn->varattno != attributeNumber) if (distributionColumn->varattno != attributeNumber)
{ {
continue; continue;
} }
uniqueConstraint = indexInfo->ii_Unique; bool uniqueConstraint = indexInfo->ii_Unique;
exclusionConstraintWithEquality = (indexInfo->ii_ExclusionOps != NULL && bool exclusionConstraintWithEquality = (indexInfo->ii_ExclusionOps != NULL &&
OperatorImplementsEquality( OperatorImplementsEquality(
indexInfo->ii_ExclusionOps[ indexInfo->ii_ExclusionOps[
attributeIndex])); attributeIndex]));
@ -1278,15 +1250,13 @@ InterShardDDLTaskList(Oid leftRelationId, Oid rightRelationId,
*/ */
if (rightPartitionMethod == DISTRIBUTE_BY_NONE) if (rightPartitionMethod == DISTRIBUTE_BY_NONE)
{ {
ShardInterval *rightShardInterval = NULL;
int rightShardCount = list_length(rightShardList); int rightShardCount = list_length(rightShardList);
int leftShardCount = list_length(leftShardList); int leftShardCount = list_length(leftShardList);
int shardCounter = 0;
Assert(rightShardCount == 1); Assert(rightShardCount == 1);
rightShardInterval = (ShardInterval *) linitial(rightShardList); ShardInterval *rightShardInterval = (ShardInterval *) linitial(rightShardList);
for (shardCounter = rightShardCount; shardCounter < leftShardCount; for (int shardCounter = rightShardCount; shardCounter < leftShardCount;
shardCounter++) shardCounter++)
{ {
rightShardList = lappend(rightShardList, rightShardInterval); rightShardList = lappend(rightShardList, rightShardInterval);
@ -1301,7 +1271,6 @@ InterShardDDLTaskList(Oid leftRelationId, Oid rightRelationId,
ShardInterval *leftShardInterval = (ShardInterval *) lfirst(leftShardCell); ShardInterval *leftShardInterval = (ShardInterval *) lfirst(leftShardCell);
uint64 leftShardId = leftShardInterval->shardId; uint64 leftShardId = leftShardInterval->shardId;
StringInfo applyCommand = makeStringInfo(); StringInfo applyCommand = makeStringInfo();
Task *task = NULL;
RelationShard *leftRelationShard = CitusMakeNode(RelationShard); RelationShard *leftRelationShard = CitusMakeNode(RelationShard);
RelationShard *rightRelationShard = CitusMakeNode(RelationShard); RelationShard *rightRelationShard = CitusMakeNode(RelationShard);
@ -1318,7 +1287,7 @@ InterShardDDLTaskList(Oid leftRelationId, Oid rightRelationId,
leftShardId, escapedLeftSchemaName, rightShardId, leftShardId, escapedLeftSchemaName, rightShardId,
escapedRightSchemaName, escapedCommandString); escapedRightSchemaName, escapedCommandString);
task = CitusMakeNode(Task); Task *task = CitusMakeNode(Task);
task->jobId = jobId; task->jobId = jobId;
task->taskId = taskId++; task->taskId = taskId++;
task->taskType = DDL_TASK; task->taskType = DDL_TASK;
@ -1345,8 +1314,6 @@ AlterInvolvesPartitionColumn(AlterTableStmt *alterTableStatement,
AlterTableCmd *command) AlterTableCmd *command)
{ {
bool involvesPartitionColumn = false; bool involvesPartitionColumn = false;
Var *partitionColumn = NULL;
HeapTuple tuple = NULL;
char *alterColumnName = command->name; char *alterColumnName = command->name;
LOCKMODE lockmode = AlterTableGetLockLevel(alterTableStatement->cmds); LOCKMODE lockmode = AlterTableGetLockLevel(alterTableStatement->cmds);
@ -1356,9 +1323,9 @@ AlterInvolvesPartitionColumn(AlterTableStmt *alterTableStatement,
return false; return false;
} }
partitionColumn = DistPartitionKey(relationId); Var *partitionColumn = DistPartitionKey(relationId);
tuple = SearchSysCacheAttName(relationId, alterColumnName); HeapTuple tuple = SearchSysCacheAttName(relationId, alterColumnName);
if (HeapTupleIsValid(tuple)) if (HeapTupleIsValid(tuple))
{ {
Form_pg_attribute targetAttr = (Form_pg_attribute) GETSTRUCT(tuple); Form_pg_attribute targetAttr = (Form_pg_attribute) GETSTRUCT(tuple);

View File

@ -42,7 +42,6 @@ void
RedirectCopyDataToRegularFile(const char *filename) RedirectCopyDataToRegularFile(const char *filename)
{ {
StringInfo copyData = makeStringInfo(); StringInfo copyData = makeStringInfo();
bool copyDone = false;
const int fileFlags = (O_APPEND | O_CREAT | O_RDWR | O_TRUNC | PG_BINARY); const int fileFlags = (O_APPEND | O_CREAT | O_RDWR | O_TRUNC | PG_BINARY);
const int fileMode = (S_IRUSR | S_IWUSR); const int fileMode = (S_IRUSR | S_IWUSR);
File fileDesc = FileOpenForTransmit(filename, fileFlags, fileMode); File fileDesc = FileOpenForTransmit(filename, fileFlags, fileMode);
@ -50,7 +49,7 @@ RedirectCopyDataToRegularFile(const char *filename)
SendCopyInStart(); SendCopyInStart();
copyDone = ReceiveCopyData(copyData); bool copyDone = ReceiveCopyData(copyData);
while (!copyDone) while (!copyDone)
{ {
/* if received data has contents, append to regular file */ /* if received data has contents, append to regular file */
@ -83,8 +82,6 @@ RedirectCopyDataToRegularFile(const char *filename)
void void
SendRegularFile(const char *filename) SendRegularFile(const char *filename)
{ {
StringInfo fileBuffer = NULL;
int readBytes = -1;
const uint32 fileBufferSize = 32768; /* 32 KB */ const uint32 fileBufferSize = 32768; /* 32 KB */
const int fileFlags = (O_RDONLY | PG_BINARY); const int fileFlags = (O_RDONLY | PG_BINARY);
const int fileMode = 0; const int fileMode = 0;
@ -97,12 +94,12 @@ SendRegularFile(const char *filename)
* We read file's contents into buffers of 32 KB. This buffer size is twice * We read file's contents into buffers of 32 KB. This buffer size is twice
* as large as Hadoop's default buffer size, and may later be configurable. * as large as Hadoop's default buffer size, and may later be configurable.
*/ */
fileBuffer = makeStringInfo(); StringInfo fileBuffer = makeStringInfo();
enlargeStringInfo(fileBuffer, fileBufferSize); enlargeStringInfo(fileBuffer, fileBufferSize);
SendCopyOutStart(); SendCopyOutStart();
readBytes = FileReadCompat(&fileCompat, fileBuffer->data, fileBufferSize, int readBytes = FileReadCompat(&fileCompat, fileBuffer->data, fileBufferSize,
PG_WAIT_IO); PG_WAIT_IO);
while (readBytes > 0) while (readBytes > 0)
{ {
@ -141,11 +138,9 @@ FreeStringInfo(StringInfo stringInfo)
File File
FileOpenForTransmit(const char *filename, int fileFlags, int fileMode) FileOpenForTransmit(const char *filename, int fileFlags, int fileMode)
{ {
File fileDesc = -1;
int fileStated = -1;
struct stat fileStat; struct stat fileStat;
fileStated = stat(filename, &fileStat); int fileStated = stat(filename, &fileStat);
if (fileStated >= 0) if (fileStated >= 0)
{ {
if (S_ISDIR(fileStat.st_mode)) if (S_ISDIR(fileStat.st_mode))
@ -155,7 +150,7 @@ FileOpenForTransmit(const char *filename, int fileFlags, int fileMode)
} }
} }
fileDesc = PathNameOpenFilePerm((char *) filename, fileFlags, fileMode); File fileDesc = PathNameOpenFilePerm((char *) filename, fileFlags, fileMode);
if (fileDesc < 0) if (fileDesc < 0)
{ {
ereport(ERROR, (errcode_for_file_access(), ereport(ERROR, (errcode_for_file_access(),
@ -175,7 +170,6 @@ SendCopyInStart(void)
{ {
StringInfoData copyInStart = { NULL, 0, 0, 0 }; StringInfoData copyInStart = { NULL, 0, 0, 0 };
const char copyFormat = 1; /* binary copy format */ const char copyFormat = 1; /* binary copy format */
int flushed = 0;
pq_beginmessage(&copyInStart, 'G'); pq_beginmessage(&copyInStart, 'G');
pq_sendbyte(&copyInStart, copyFormat); pq_sendbyte(&copyInStart, copyFormat);
@ -183,7 +177,7 @@ SendCopyInStart(void)
pq_endmessage(&copyInStart); pq_endmessage(&copyInStart);
/* flush here to ensure that FE knows it can send data */ /* flush here to ensure that FE knows it can send data */
flushed = pq_flush(); int flushed = pq_flush();
if (flushed != 0) if (flushed != 0)
{ {
ereport(WARNING, (errmsg("could not flush copy start data"))); ereport(WARNING, (errmsg("could not flush copy start data")));
@ -213,13 +207,12 @@ static void
SendCopyDone(void) SendCopyDone(void)
{ {
StringInfoData copyDone = { NULL, 0, 0, 0 }; StringInfoData copyDone = { NULL, 0, 0, 0 };
int flushed = 0;
pq_beginmessage(&copyDone, 'c'); pq_beginmessage(&copyDone, 'c');
pq_endmessage(&copyDone); pq_endmessage(&copyDone);
/* flush here to signal to FE that we are done */ /* flush here to signal to FE that we are done */
flushed = pq_flush(); int flushed = pq_flush();
if (flushed != 0) if (flushed != 0)
{ {
ereport(WARNING, (errmsg("could not flush copy start data"))); ereport(WARNING, (errmsg("could not flush copy start data")));
@ -250,14 +243,12 @@ SendCopyData(StringInfo fileBuffer)
static bool static bool
ReceiveCopyData(StringInfo copyData) ReceiveCopyData(StringInfo copyData)
{ {
int messageType = 0;
int messageCopied = 0;
bool copyDone = true; bool copyDone = true;
const int unlimitedSize = 0; const int unlimitedSize = 0;
HOLD_CANCEL_INTERRUPTS(); HOLD_CANCEL_INTERRUPTS();
pq_startmsgread(); pq_startmsgread();
messageType = pq_getbyte(); int messageType = pq_getbyte();
if (messageType == EOF) if (messageType == EOF)
{ {
ereport(ERROR, (errcode(ERRCODE_CONNECTION_FAILURE), ereport(ERROR, (errcode(ERRCODE_CONNECTION_FAILURE),
@ -265,7 +256,7 @@ ReceiveCopyData(StringInfo copyData)
} }
/* consume the rest of message before checking for message type */ /* consume the rest of message before checking for message type */
messageCopied = pq_getmessage(copyData, unlimitedSize); int messageCopied = pq_getmessage(copyData, unlimitedSize);
if (messageCopied == EOF) if (messageCopied == EOF)
{ {
ereport(ERROR, (errcode(ERRCODE_CONNECTION_FAILURE), ereport(ERROR, (errcode(ERRCODE_CONNECTION_FAILURE),
@ -382,8 +373,6 @@ TransmitStatementUser(CopyStmt *copyStatement)
void void
VerifyTransmitStmt(CopyStmt *copyStatement) VerifyTransmitStmt(CopyStmt *copyStatement)
{ {
char *fileName = NULL;
EnsureSuperUser(); EnsureSuperUser();
/* do some minimal option verification */ /* do some minimal option verification */
@ -394,7 +383,7 @@ VerifyTransmitStmt(CopyStmt *copyStatement)
errmsg("FORMAT 'transmit' requires a target file"))); errmsg("FORMAT 'transmit' requires a target file")));
} }
fileName = copyStatement->relation->relname; char *fileName = copyStatement->relation->relname;
if (is_absolute_path(fileName)) if (is_absolute_path(fileName))
{ {

View File

@ -180,8 +180,6 @@ LockTruncatedRelationMetadataInWorkers(TruncateStmt *truncateStatement)
{ {
RangeVar *rangeVar = (RangeVar *) lfirst(relationCell); RangeVar *rangeVar = (RangeVar *) lfirst(relationCell);
Oid relationId = RangeVarGetRelid(rangeVar, NoLock, false); Oid relationId = RangeVarGetRelid(rangeVar, NoLock, false);
DistTableCacheEntry *cacheEntry = NULL;
List *referencingTableList = NIL;
Oid referencingRelationId = InvalidOid; Oid referencingRelationId = InvalidOid;
if (!IsDistributedTable(relationId)) if (!IsDistributedTable(relationId))
@ -196,10 +194,10 @@ LockTruncatedRelationMetadataInWorkers(TruncateStmt *truncateStatement)
distributedRelationList = lappend_oid(distributedRelationList, relationId); distributedRelationList = lappend_oid(distributedRelationList, relationId);
cacheEntry = DistributedTableCacheEntry(relationId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId);
Assert(cacheEntry != NULL); Assert(cacheEntry != NULL);
referencingTableList = cacheEntry->referencingRelationsViaForeignKey; List *referencingTableList = cacheEntry->referencingRelationsViaForeignKey;
foreach_oid(referencingRelationId, referencingTableList) foreach_oid(referencingRelationId, referencingTableList)
{ {
distributedRelationList = list_append_unique_oid(distributedRelationList, distributedRelationList = list_append_unique_oid(distributedRelationList,

View File

@ -114,9 +114,6 @@ static bool ShouldPropagateTypeCreate(void);
List * List *
PlanCompositeTypeStmt(CompositeTypeStmt *stmt, const char *queryString) PlanCompositeTypeStmt(CompositeTypeStmt *stmt, const char *queryString)
{ {
const char *compositeTypeStmtSql = NULL;
List *commands = NIL;
if (!ShouldPropagateTypeCreate()) if (!ShouldPropagateTypeCreate())
{ {
return NIL; return NIL;
@ -149,7 +146,7 @@ PlanCompositeTypeStmt(CompositeTypeStmt *stmt, const char *queryString)
* type previously has been attempted to be created in a transaction which did not * type previously has been attempted to be created in a transaction which did not
* commit on the coordinator. * commit on the coordinator.
*/ */
compositeTypeStmtSql = DeparseCompositeTypeStmt(stmt); const char *compositeTypeStmtSql = DeparseCompositeTypeStmt(stmt);
compositeTypeStmtSql = WrapCreateOrReplace(compositeTypeStmtSql); compositeTypeStmtSql = WrapCreateOrReplace(compositeTypeStmtSql);
/* /*
@ -158,7 +155,7 @@ PlanCompositeTypeStmt(CompositeTypeStmt *stmt, const char *queryString)
*/ */
EnsureSequentialModeForTypeDDL(); EnsureSequentialModeForTypeDDL();
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) compositeTypeStmtSql, (void *) compositeTypeStmtSql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -174,8 +171,6 @@ PlanCompositeTypeStmt(CompositeTypeStmt *stmt, const char *queryString)
void void
ProcessCompositeTypeStmt(CompositeTypeStmt *stmt, const char *queryString) ProcessCompositeTypeStmt(CompositeTypeStmt *stmt, const char *queryString)
{ {
const ObjectAddress *typeAddress = NULL;
/* same check we perform during planning of the statement */ /* same check we perform during planning of the statement */
if (!ShouldPropagateTypeCreate()) if (!ShouldPropagateTypeCreate())
{ {
@ -186,7 +181,8 @@ ProcessCompositeTypeStmt(CompositeTypeStmt *stmt, const char *queryString)
* find object address of the just created object, because the type has been created * find object address of the just created object, because the type has been created
* locally it can't be missing * locally it can't be missing
*/ */
typeAddress = GetObjectAddressFromParseTree((Node *) stmt, false); const ObjectAddress *typeAddress = GetObjectAddressFromParseTree((Node *) stmt,
false);
EnsureDependenciesExistsOnAllNodes(typeAddress); EnsureDependenciesExistsOnAllNodes(typeAddress);
MarkObjectDistributed(typeAddress); MarkObjectDistributed(typeAddress);
@ -202,13 +198,10 @@ ProcessCompositeTypeStmt(CompositeTypeStmt *stmt, const char *queryString)
List * List *
PlanAlterTypeStmt(AlterTableStmt *stmt, const char *queryString) PlanAlterTypeStmt(AlterTableStmt *stmt, const char *queryString)
{ {
const char *alterTypeStmtSql = NULL;
const ObjectAddress *typeAddress = NULL;
List *commands = NIL;
Assert(stmt->relkind == OBJECT_TYPE); Assert(stmt->relkind == OBJECT_TYPE);
typeAddress = GetObjectAddressFromParseTree((Node *) stmt, false); const ObjectAddress *typeAddress = GetObjectAddressFromParseTree((Node *) stmt,
false);
if (!ShouldPropagateObject(typeAddress)) if (!ShouldPropagateObject(typeAddress))
{ {
return NIL; return NIL;
@ -218,7 +211,7 @@ PlanAlterTypeStmt(AlterTableStmt *stmt, const char *queryString)
/* reconstruct alter statement in a portable fashion */ /* reconstruct alter statement in a portable fashion */
QualifyTreeNode((Node *) stmt); QualifyTreeNode((Node *) stmt);
alterTypeStmtSql = DeparseTreeNode((Node *) stmt); const char *alterTypeStmtSql = DeparseTreeNode((Node *) stmt);
/* /*
* all types that are distributed will need their alter statements propagated * all types that are distributed will need their alter statements propagated
@ -227,7 +220,7 @@ PlanAlterTypeStmt(AlterTableStmt *stmt, const char *queryString)
*/ */
EnsureSequentialModeForTypeDDL(); EnsureSequentialModeForTypeDDL();
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) alterTypeStmtSql, (void *) alterTypeStmtSql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -248,9 +241,6 @@ PlanAlterTypeStmt(AlterTableStmt *stmt, const char *queryString)
List * List *
PlanCreateEnumStmt(CreateEnumStmt *stmt, const char *queryString) PlanCreateEnumStmt(CreateEnumStmt *stmt, const char *queryString)
{ {
const char *createEnumStmtSql = NULL;
List *commands = NIL;
if (!ShouldPropagateTypeCreate()) if (!ShouldPropagateTypeCreate())
{ {
return NIL; return NIL;
@ -266,7 +256,7 @@ PlanCreateEnumStmt(CreateEnumStmt *stmt, const char *queryString)
QualifyTreeNode((Node *) stmt); QualifyTreeNode((Node *) stmt);
/* reconstruct creation statement in a portable fashion */ /* reconstruct creation statement in a portable fashion */
createEnumStmtSql = DeparseCreateEnumStmt(stmt); const char *createEnumStmtSql = DeparseCreateEnumStmt(stmt);
createEnumStmtSql = WrapCreateOrReplace(createEnumStmtSql); createEnumStmtSql = WrapCreateOrReplace(createEnumStmtSql);
/* /*
@ -276,7 +266,7 @@ PlanCreateEnumStmt(CreateEnumStmt *stmt, const char *queryString)
EnsureSequentialModeForTypeDDL(); EnsureSequentialModeForTypeDDL();
/* to prevent recursion with mx we disable ddl propagation */ /* to prevent recursion with mx we disable ddl propagation */
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) createEnumStmtSql, (void *) createEnumStmtSql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -295,15 +285,14 @@ PlanCreateEnumStmt(CreateEnumStmt *stmt, const char *queryString)
void void
ProcessCreateEnumStmt(CreateEnumStmt *stmt, const char *queryString) ProcessCreateEnumStmt(CreateEnumStmt *stmt, const char *queryString)
{ {
const ObjectAddress *typeAddress = NULL;
if (!ShouldPropagateTypeCreate()) if (!ShouldPropagateTypeCreate())
{ {
return; return;
} }
/* lookup type address of just created type */ /* lookup type address of just created type */
typeAddress = GetObjectAddressFromParseTree((Node *) stmt, false); const ObjectAddress *typeAddress = GetObjectAddressFromParseTree((Node *) stmt,
false);
EnsureDependenciesExistsOnAllNodes(typeAddress); EnsureDependenciesExistsOnAllNodes(typeAddress);
/* /*
@ -326,11 +315,10 @@ ProcessCreateEnumStmt(CreateEnumStmt *stmt, const char *queryString)
List * List *
PlanAlterEnumStmt(AlterEnumStmt *stmt, const char *queryString) PlanAlterEnumStmt(AlterEnumStmt *stmt, const char *queryString)
{ {
const char *alterEnumStmtSql = NULL;
const ObjectAddress *typeAddress = NULL;
List *commands = NIL; List *commands = NIL;
typeAddress = GetObjectAddressFromParseTree((Node *) stmt, false); const ObjectAddress *typeAddress = GetObjectAddressFromParseTree((Node *) stmt,
false);
if (!ShouldPropagateObject(typeAddress)) if (!ShouldPropagateObject(typeAddress))
{ {
return NIL; return NIL;
@ -351,7 +339,7 @@ PlanAlterEnumStmt(AlterEnumStmt *stmt, const char *queryString)
EnsureCoordinator(); EnsureCoordinator();
QualifyTreeNode((Node *) stmt); QualifyTreeNode((Node *) stmt);
alterEnumStmtSql = DeparseTreeNode((Node *) stmt); const char *alterEnumStmtSql = DeparseTreeNode((Node *) stmt);
/* /*
* Before pg12 ALTER ENUM ... ADD VALUE could not be within a xact block. Instead of * Before pg12 ALTER ENUM ... ADD VALUE could not be within a xact block. Instead of
@ -396,9 +384,8 @@ PlanAlterEnumStmt(AlterEnumStmt *stmt, const char *queryString)
void void
ProcessAlterEnumStmt(AlterEnumStmt *stmt, const char *queryString) ProcessAlterEnumStmt(AlterEnumStmt *stmt, const char *queryString)
{ {
const ObjectAddress *typeAddress = NULL; const ObjectAddress *typeAddress = GetObjectAddressFromParseTree((Node *) stmt,
false);
typeAddress = GetObjectAddressFromParseTree((Node *) stmt, false);
if (!ShouldPropagateObject(typeAddress)) if (!ShouldPropagateObject(typeAddress))
{ {
return; return;
@ -422,25 +409,22 @@ ProcessAlterEnumStmt(AlterEnumStmt *stmt, const char *queryString)
* might already be added to some nodes, but not all. * might already be added to some nodes, but not all.
*/ */
int result = 0;
List *commands = NIL;
const char *alterEnumStmtSql = NULL;
/* qualification of the stmt happened during planning */ /* qualification of the stmt happened during planning */
alterEnumStmtSql = DeparseTreeNode((Node *) stmt); const char *alterEnumStmtSql = DeparseTreeNode((Node *) stmt);
commands = list_make2(DISABLE_DDL_PROPAGATION, (void *) alterEnumStmtSql); List *commands = list_make2(DISABLE_DDL_PROPAGATION, (void *) alterEnumStmtSql);
result = SendBareOptionalCommandListToWorkersAsUser(ALL_WORKERS, commands, NULL); int result = SendBareOptionalCommandListToWorkersAsUser(ALL_WORKERS, commands,
NULL);
if (result != RESPONSE_OKAY) if (result != RESPONSE_OKAY)
{ {
const char *alterEnumStmtIfNotExistsSql = NULL;
bool oldSkipIfNewValueExists = stmt->skipIfNewValExists; bool oldSkipIfNewValueExists = stmt->skipIfNewValExists;
/* deparse the query with IF NOT EXISTS */ /* deparse the query with IF NOT EXISTS */
stmt->skipIfNewValExists = true; stmt->skipIfNewValExists = true;
alterEnumStmtIfNotExistsSql = DeparseTreeNode((Node *) stmt); const char *alterEnumStmtIfNotExistsSql = DeparseTreeNode((Node *) stmt);
stmt->skipIfNewValExists = oldSkipIfNewValueExists; stmt->skipIfNewValExists = oldSkipIfNewValueExists;
ereport(WARNING, (errmsg("not all workers applied change to enum"), ereport(WARNING, (errmsg("not all workers applied change to enum"),
@ -466,18 +450,15 @@ PlanDropTypeStmt(DropStmt *stmt, const char *queryString)
* the old list to put back * the old list to put back
*/ */
List *oldTypes = stmt->objects; List *oldTypes = stmt->objects;
List *distributedTypes = NIL;
const char *dropStmtSql = NULL;
ListCell *addressCell = NULL; ListCell *addressCell = NULL;
List *distributedTypeAddresses = NIL;
List *commands = NIL;
if (!ShouldPropagate()) if (!ShouldPropagate())
{ {
return NIL; return NIL;
} }
distributedTypes = FilterNameListForDistributedTypes(oldTypes, stmt->missing_ok); List *distributedTypes = FilterNameListForDistributedTypes(oldTypes,
stmt->missing_ok);
if (list_length(distributedTypes) <= 0) if (list_length(distributedTypes) <= 0)
{ {
/* no distributed types to drop */ /* no distributed types to drop */
@ -494,7 +475,7 @@ PlanDropTypeStmt(DropStmt *stmt, const char *queryString)
/* /*
* remove the entries for the distributed objects on dropping * remove the entries for the distributed objects on dropping
*/ */
distributedTypeAddresses = TypeNameListToObjectAddresses(distributedTypes); List *distributedTypeAddresses = TypeNameListToObjectAddresses(distributedTypes);
foreach(addressCell, distributedTypeAddresses) foreach(addressCell, distributedTypeAddresses)
{ {
ObjectAddress *address = (ObjectAddress *) lfirst(addressCell); ObjectAddress *address = (ObjectAddress *) lfirst(addressCell);
@ -506,13 +487,13 @@ PlanDropTypeStmt(DropStmt *stmt, const char *queryString)
* deparse to an executable sql statement for the workers * deparse to an executable sql statement for the workers
*/ */
stmt->objects = distributedTypes; stmt->objects = distributedTypes;
dropStmtSql = DeparseTreeNode((Node *) stmt); const char *dropStmtSql = DeparseTreeNode((Node *) stmt);
stmt->objects = oldTypes; stmt->objects = oldTypes;
/* to prevent recursion with mx we disable ddl propagation */ /* to prevent recursion with mx we disable ddl propagation */
EnsureSequentialModeForTypeDDL(); EnsureSequentialModeForTypeDDL();
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) dropStmtSql, (void *) dropStmtSql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -531,11 +512,8 @@ PlanDropTypeStmt(DropStmt *stmt, const char *queryString)
List * List *
PlanRenameTypeStmt(RenameStmt *stmt, const char *queryString) PlanRenameTypeStmt(RenameStmt *stmt, const char *queryString)
{ {
const char *renameStmtSql = NULL; const ObjectAddress *typeAddress = GetObjectAddressFromParseTree((Node *) stmt,
const ObjectAddress *typeAddress = NULL; false);
List *commands = NIL;
typeAddress = GetObjectAddressFromParseTree((Node *) stmt, false);
if (!ShouldPropagateObject(typeAddress)) if (!ShouldPropagateObject(typeAddress))
{ {
return NIL; return NIL;
@ -545,12 +523,12 @@ PlanRenameTypeStmt(RenameStmt *stmt, const char *queryString)
QualifyTreeNode((Node *) stmt); QualifyTreeNode((Node *) stmt);
/* deparse sql*/ /* deparse sql*/
renameStmtSql = DeparseTreeNode((Node *) stmt); const char *renameStmtSql = DeparseTreeNode((Node *) stmt);
/* to prevent recursion with mx we disable ddl propagation */ /* to prevent recursion with mx we disable ddl propagation */
EnsureSequentialModeForTypeDDL(); EnsureSequentialModeForTypeDDL();
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) renameStmtSql, (void *) renameStmtSql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -568,14 +546,11 @@ PlanRenameTypeStmt(RenameStmt *stmt, const char *queryString)
List * List *
PlanRenameTypeAttributeStmt(RenameStmt *stmt, const char *queryString) PlanRenameTypeAttributeStmt(RenameStmt *stmt, const char *queryString)
{ {
const char *sql = NULL;
const ObjectAddress *typeAddress = NULL;
List *commands = NIL;
Assert(stmt->renameType == OBJECT_ATTRIBUTE); Assert(stmt->renameType == OBJECT_ATTRIBUTE);
Assert(stmt->relationType == OBJECT_TYPE); Assert(stmt->relationType == OBJECT_TYPE);
typeAddress = GetObjectAddressFromParseTree((Node *) stmt, false); const ObjectAddress *typeAddress = GetObjectAddressFromParseTree((Node *) stmt,
false);
if (!ShouldPropagateObject(typeAddress)) if (!ShouldPropagateObject(typeAddress))
{ {
return NIL; return NIL;
@ -583,10 +558,10 @@ PlanRenameTypeAttributeStmt(RenameStmt *stmt, const char *queryString)
QualifyTreeNode((Node *) stmt); QualifyTreeNode((Node *) stmt);
sql = DeparseTreeNode((Node *) stmt); const char *sql = DeparseTreeNode((Node *) stmt);
EnsureSequentialModeForTypeDDL(); EnsureSequentialModeForTypeDDL();
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) sql, (void *) sql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -603,13 +578,10 @@ PlanRenameTypeAttributeStmt(RenameStmt *stmt, const char *queryString)
List * List *
PlanAlterTypeSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString) PlanAlterTypeSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString)
{ {
const char *sql = NULL;
const ObjectAddress *typeAddress = NULL;
List *commands = NIL;
Assert(stmt->objectType == OBJECT_TYPE); Assert(stmt->objectType == OBJECT_TYPE);
typeAddress = GetObjectAddressFromParseTree((Node *) stmt, false); const ObjectAddress *typeAddress = GetObjectAddressFromParseTree((Node *) stmt,
false);
if (!ShouldPropagateObject(typeAddress)) if (!ShouldPropagateObject(typeAddress))
{ {
return NIL; return NIL;
@ -618,11 +590,11 @@ PlanAlterTypeSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString)
EnsureCoordinator(); EnsureCoordinator();
QualifyTreeNode((Node *) stmt); QualifyTreeNode((Node *) stmt);
sql = DeparseTreeNode((Node *) stmt); const char *sql = DeparseTreeNode((Node *) stmt);
EnsureSequentialModeForTypeDDL(); EnsureSequentialModeForTypeDDL();
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) sql, (void *) sql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -638,11 +610,10 @@ PlanAlterTypeSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString)
void void
ProcessAlterTypeSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString) ProcessAlterTypeSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString)
{ {
const ObjectAddress *typeAddress = NULL;
Assert(stmt->objectType == OBJECT_TYPE); Assert(stmt->objectType == OBJECT_TYPE);
typeAddress = GetObjectAddressFromParseTree((Node *) stmt, false); const ObjectAddress *typeAddress = GetObjectAddressFromParseTree((Node *) stmt,
false);
if (!ShouldPropagateObject(typeAddress)) if (!ShouldPropagateObject(typeAddress))
{ {
return; return;
@ -663,13 +634,10 @@ ProcessAlterTypeSchemaStmt(AlterObjectSchemaStmt *stmt, const char *queryString)
List * List *
PlanAlterTypeOwnerStmt(AlterOwnerStmt *stmt, const char *queryString) PlanAlterTypeOwnerStmt(AlterOwnerStmt *stmt, const char *queryString)
{ {
const ObjectAddress *typeAddress = NULL;
const char *sql = NULL;
List *commands = NULL;
Assert(stmt->objectType == OBJECT_TYPE); Assert(stmt->objectType == OBJECT_TYPE);
typeAddress = GetObjectAddressFromParseTree((Node *) stmt, false); const ObjectAddress *typeAddress = GetObjectAddressFromParseTree((Node *) stmt,
false);
if (!ShouldPropagateObject(typeAddress)) if (!ShouldPropagateObject(typeAddress))
{ {
return NIL; return NIL;
@ -678,10 +646,10 @@ PlanAlterTypeOwnerStmt(AlterOwnerStmt *stmt, const char *queryString)
EnsureCoordinator(); EnsureCoordinator();
QualifyTreeNode((Node *) stmt); QualifyTreeNode((Node *) stmt);
sql = DeparseTreeNode((Node *) stmt); const char *sql = DeparseTreeNode((Node *) stmt);
EnsureSequentialModeForTypeDDL(); EnsureSequentialModeForTypeDDL();
commands = list_make3(DISABLE_DDL_PROPAGATION, List *commands = list_make3(DISABLE_DDL_PROPAGATION,
(void *) sql, (void *) sql,
ENABLE_DDL_PROPAGATION); ENABLE_DDL_PROPAGATION);
@ -726,13 +694,10 @@ CreateTypeStmtByObjectAddress(const ObjectAddress *address)
static CompositeTypeStmt * static CompositeTypeStmt *
RecreateCompositeTypeStmt(Oid typeOid) RecreateCompositeTypeStmt(Oid typeOid)
{ {
CompositeTypeStmt *stmt = NULL;
List *names = NIL;
Assert(get_typtype(typeOid) == TYPTYPE_COMPOSITE); Assert(get_typtype(typeOid) == TYPTYPE_COMPOSITE);
stmt = makeNode(CompositeTypeStmt); CompositeTypeStmt *stmt = makeNode(CompositeTypeStmt);
names = stringToQualifiedNameList(format_type_be_qualified(typeOid)); List *names = stringToQualifiedNameList(format_type_be_qualified(typeOid));
stmt->typevar = makeRangeVarFromNameList(names); stmt->typevar = makeRangeVarFromNameList(names);
stmt->coldeflist = CompositeTypeColumnDefList(typeOid); stmt->coldeflist = CompositeTypeColumnDefList(typeOid);
@ -763,17 +728,14 @@ attributeFormToColumnDef(Form_pg_attribute attributeForm)
static List * static List *
CompositeTypeColumnDefList(Oid typeOid) CompositeTypeColumnDefList(Oid typeOid)
{ {
Relation relation = NULL;
Oid relationId = InvalidOid;
TupleDesc tupleDescriptor = NULL;
int attributeIndex = 0;
List *columnDefs = NIL; List *columnDefs = NIL;
relationId = typeidTypeRelid(typeOid); Oid relationId = typeidTypeRelid(typeOid);
relation = relation_open(relationId, AccessShareLock); Relation relation = relation_open(relationId, AccessShareLock);
tupleDescriptor = RelationGetDescr(relation); TupleDesc tupleDescriptor = RelationGetDescr(relation);
for (attributeIndex = 0; attributeIndex < tupleDescriptor->natts; attributeIndex++) for (int attributeIndex = 0; attributeIndex < tupleDescriptor->natts;
attributeIndex++)
{ {
Form_pg_attribute attributeForm = TupleDescAttr(tupleDescriptor, attributeIndex); Form_pg_attribute attributeForm = TupleDescAttr(tupleDescriptor, attributeIndex);
@ -799,11 +761,9 @@ CompositeTypeColumnDefList(Oid typeOid)
static CreateEnumStmt * static CreateEnumStmt *
RecreateEnumStmt(Oid typeOid) RecreateEnumStmt(Oid typeOid)
{ {
CreateEnumStmt *stmt = NULL;
Assert(get_typtype(typeOid) == TYPTYPE_ENUM); Assert(get_typtype(typeOid) == TYPTYPE_ENUM);
stmt = makeNode(CreateEnumStmt); CreateEnumStmt *stmt = makeNode(CreateEnumStmt);
stmt->typeName = stringToQualifiedNameList(format_type_be_qualified(typeOid)); stmt->typeName = stringToQualifiedNameList(format_type_be_qualified(typeOid));
stmt->vals = EnumValsList(typeOid); stmt->vals = EnumValsList(typeOid);
@ -818,8 +778,6 @@ RecreateEnumStmt(Oid typeOid)
static List * static List *
EnumValsList(Oid typeOid) EnumValsList(Oid typeOid)
{ {
Relation enum_rel = NULL;
SysScanDesc enum_scan = NULL;
HeapTuple enum_tuple = NULL; HeapTuple enum_tuple = NULL;
ScanKeyData skey = { 0 }; ScanKeyData skey = { 0 };
@ -831,8 +789,8 @@ EnumValsList(Oid typeOid)
BTEqualStrategyNumber, F_OIDEQ, BTEqualStrategyNumber, F_OIDEQ,
ObjectIdGetDatum(typeOid)); ObjectIdGetDatum(typeOid));
enum_rel = heap_open(EnumRelationId, AccessShareLock); Relation enum_rel = heap_open(EnumRelationId, AccessShareLock);
enum_scan = systable_beginscan(enum_rel, SysScanDesc enum_scan = systable_beginscan(enum_rel,
EnumTypIdSortOrderIndexId, EnumTypIdSortOrderIndexId,
true, NULL, true, NULL,
1, &skey); 1, &skey);
@ -861,13 +819,9 @@ EnumValsList(Oid typeOid)
ObjectAddress * ObjectAddress *
CompositeTypeStmtObjectAddress(CompositeTypeStmt *stmt, bool missing_ok) CompositeTypeStmtObjectAddress(CompositeTypeStmt *stmt, bool missing_ok)
{ {
TypeName *typeName = NULL; TypeName *typeName = MakeTypeNameFromRangeVar(stmt->typevar);
Oid typeOid = InvalidOid; Oid typeOid = LookupTypeNameOid(NULL, typeName, missing_ok);
ObjectAddress *address = NULL; ObjectAddress *address = palloc0(sizeof(ObjectAddress));
typeName = MakeTypeNameFromRangeVar(stmt->typevar);
typeOid = LookupTypeNameOid(NULL, typeName, missing_ok);
address = palloc0(sizeof(ObjectAddress));
ObjectAddressSet(*address, TypeRelationId, typeOid); ObjectAddressSet(*address, TypeRelationId, typeOid);
return address; return address;
@ -885,13 +839,9 @@ CompositeTypeStmtObjectAddress(CompositeTypeStmt *stmt, bool missing_ok)
ObjectAddress * ObjectAddress *
CreateEnumStmtObjectAddress(CreateEnumStmt *stmt, bool missing_ok) CreateEnumStmtObjectAddress(CreateEnumStmt *stmt, bool missing_ok)
{ {
TypeName *typeName = NULL; TypeName *typeName = makeTypeNameFromNameList(stmt->typeName);
Oid typeOid = InvalidOid; Oid typeOid = LookupTypeNameOid(NULL, typeName, missing_ok);
ObjectAddress *address = NULL; ObjectAddress *address = palloc0(sizeof(ObjectAddress));
typeName = makeTypeNameFromNameList(stmt->typeName);
typeOid = LookupTypeNameOid(NULL, typeName, missing_ok);
address = palloc0(sizeof(ObjectAddress));
ObjectAddressSet(*address, TypeRelationId, typeOid); ObjectAddressSet(*address, TypeRelationId, typeOid);
return address; return address;
@ -909,15 +859,11 @@ CreateEnumStmtObjectAddress(CreateEnumStmt *stmt, bool missing_ok)
ObjectAddress * ObjectAddress *
AlterTypeStmtObjectAddress(AlterTableStmt *stmt, bool missing_ok) AlterTypeStmtObjectAddress(AlterTableStmt *stmt, bool missing_ok)
{ {
TypeName *typeName = NULL;
Oid typeOid = InvalidOid;
ObjectAddress *address = NULL;
Assert(stmt->relkind == OBJECT_TYPE); Assert(stmt->relkind == OBJECT_TYPE);
typeName = MakeTypeNameFromRangeVar(stmt->relation); TypeName *typeName = MakeTypeNameFromRangeVar(stmt->relation);
typeOid = LookupTypeNameOid(NULL, typeName, missing_ok); Oid typeOid = LookupTypeNameOid(NULL, typeName, missing_ok);
address = palloc0(sizeof(ObjectAddress)); ObjectAddress *address = palloc0(sizeof(ObjectAddress));
ObjectAddressSet(*address, TypeRelationId, typeOid); ObjectAddressSet(*address, TypeRelationId, typeOid);
return address; return address;
@ -931,13 +877,9 @@ AlterTypeStmtObjectAddress(AlterTableStmt *stmt, bool missing_ok)
ObjectAddress * ObjectAddress *
AlterEnumStmtObjectAddress(AlterEnumStmt *stmt, bool missing_ok) AlterEnumStmtObjectAddress(AlterEnumStmt *stmt, bool missing_ok)
{ {
TypeName *typeName = NULL; TypeName *typeName = makeTypeNameFromNameList(stmt->typeName);
Oid typeOid = InvalidOid; Oid typeOid = LookupTypeNameOid(NULL, typeName, missing_ok);
ObjectAddress *address = NULL; ObjectAddress *address = palloc0(sizeof(ObjectAddress));
typeName = makeTypeNameFromNameList(stmt->typeName);
typeOid = LookupTypeNameOid(NULL, typeName, missing_ok);
address = palloc0(sizeof(ObjectAddress));
ObjectAddressSet(*address, TypeRelationId, typeOid); ObjectAddressSet(*address, TypeRelationId, typeOid);
return address; return address;
@ -951,15 +893,11 @@ AlterEnumStmtObjectAddress(AlterEnumStmt *stmt, bool missing_ok)
ObjectAddress * ObjectAddress *
RenameTypeStmtObjectAddress(RenameStmt *stmt, bool missing_ok) RenameTypeStmtObjectAddress(RenameStmt *stmt, bool missing_ok)
{ {
TypeName *typeName = NULL;
Oid typeOid = InvalidOid;
ObjectAddress *address = NULL;
Assert(stmt->renameType == OBJECT_TYPE); Assert(stmt->renameType == OBJECT_TYPE);
typeName = makeTypeNameFromNameList((List *) stmt->object); TypeName *typeName = makeTypeNameFromNameList((List *) stmt->object);
typeOid = LookupTypeNameOid(NULL, typeName, missing_ok); Oid typeOid = LookupTypeNameOid(NULL, typeName, missing_ok);
address = palloc0(sizeof(ObjectAddress)); ObjectAddress *address = palloc0(sizeof(ObjectAddress));
ObjectAddressSet(*address, TypeRelationId, typeOid); ObjectAddressSet(*address, TypeRelationId, typeOid);
return address; return address;
@ -978,21 +916,16 @@ RenameTypeStmtObjectAddress(RenameStmt *stmt, bool missing_ok)
ObjectAddress * ObjectAddress *
AlterTypeSchemaStmtObjectAddress(AlterObjectSchemaStmt *stmt, bool missing_ok) AlterTypeSchemaStmtObjectAddress(AlterObjectSchemaStmt *stmt, bool missing_ok)
{ {
ObjectAddress *address = NULL;
TypeName *typeName = NULL;
Oid typeOid = InvalidOid;
List *names = NIL;
Assert(stmt->objectType == OBJECT_TYPE); Assert(stmt->objectType == OBJECT_TYPE);
names = (List *) stmt->object; List *names = (List *) stmt->object;
/* /*
* we hardcode missing_ok here during LookupTypeNameOid because if we can't find it it * we hardcode missing_ok here during LookupTypeNameOid because if we can't find it it
* might have already been moved in this transaction. * might have already been moved in this transaction.
*/ */
typeName = makeTypeNameFromNameList(names); TypeName *typeName = makeTypeNameFromNameList(names);
typeOid = LookupTypeNameOid(NULL, typeName, true); Oid typeOid = LookupTypeNameOid(NULL, typeName, true);
if (typeOid == InvalidOid) if (typeOid == InvalidOid)
{ {
@ -1024,7 +957,7 @@ AlterTypeSchemaStmtObjectAddress(AlterObjectSchemaStmt *stmt, bool missing_ok)
} }
} }
address = palloc0(sizeof(ObjectAddress)); ObjectAddress *address = palloc0(sizeof(ObjectAddress));
ObjectAddressSet(*address, TypeRelationId, typeOid); ObjectAddressSet(*address, TypeRelationId, typeOid);
return address; return address;
@ -1042,16 +975,12 @@ AlterTypeSchemaStmtObjectAddress(AlterObjectSchemaStmt *stmt, bool missing_ok)
ObjectAddress * ObjectAddress *
RenameTypeAttributeStmtObjectAddress(RenameStmt *stmt, bool missing_ok) RenameTypeAttributeStmtObjectAddress(RenameStmt *stmt, bool missing_ok)
{ {
TypeName *typeName = NULL;
Oid typeOid = InvalidOid;
ObjectAddress *address = NULL;
Assert(stmt->renameType == OBJECT_ATTRIBUTE); Assert(stmt->renameType == OBJECT_ATTRIBUTE);
Assert(stmt->relationType == OBJECT_TYPE); Assert(stmt->relationType == OBJECT_TYPE);
typeName = MakeTypeNameFromRangeVar(stmt->relation); TypeName *typeName = MakeTypeNameFromRangeVar(stmt->relation);
typeOid = LookupTypeNameOid(NULL, typeName, missing_ok); Oid typeOid = LookupTypeNameOid(NULL, typeName, missing_ok);
address = palloc0(sizeof(ObjectAddress)); ObjectAddress *address = palloc0(sizeof(ObjectAddress));
ObjectAddressSet(*address, TypeRelationId, typeOid); ObjectAddressSet(*address, TypeRelationId, typeOid);
return address; return address;
@ -1065,15 +994,11 @@ RenameTypeAttributeStmtObjectAddress(RenameStmt *stmt, bool missing_ok)
ObjectAddress * ObjectAddress *
AlterTypeOwnerObjectAddress(AlterOwnerStmt *stmt, bool missing_ok) AlterTypeOwnerObjectAddress(AlterOwnerStmt *stmt, bool missing_ok)
{ {
TypeName *typeName = NULL;
Oid typeOid = InvalidOid;
ObjectAddress *address = NULL;
Assert(stmt->objectType == OBJECT_TYPE); Assert(stmt->objectType == OBJECT_TYPE);
typeName = makeTypeNameFromNameList((List *) stmt->object); TypeName *typeName = makeTypeNameFromNameList((List *) stmt->object);
typeOid = LookupTypeNameOid(NULL, typeName, missing_ok); Oid typeOid = LookupTypeNameOid(NULL, typeName, missing_ok);
address = palloc0(sizeof(ObjectAddress)); ObjectAddress *address = palloc0(sizeof(ObjectAddress));
ObjectAddressSet(*address, TypeRelationId, typeOid); ObjectAddressSet(*address, TypeRelationId, typeOid);
return address; return address;
@ -1088,10 +1013,7 @@ List *
CreateTypeDDLCommandsIdempotent(const ObjectAddress *typeAddress) CreateTypeDDLCommandsIdempotent(const ObjectAddress *typeAddress)
{ {
List *ddlCommands = NIL; List *ddlCommands = NIL;
const char *ddlCommand = NULL;
Node *stmt = NULL;
StringInfoData buf = { 0 }; StringInfoData buf = { 0 };
const char *username = NULL;
Assert(typeAddress->classId == TypeRelationId); Assert(typeAddress->classId == TypeRelationId);
@ -1106,15 +1028,15 @@ CreateTypeDDLCommandsIdempotent(const ObjectAddress *typeAddress)
return NIL; return NIL;
} }
stmt = CreateTypeStmtByObjectAddress(typeAddress); Node *stmt = CreateTypeStmtByObjectAddress(typeAddress);
/* capture ddl command for recreation and wrap in create if not exists construct */ /* capture ddl command for recreation and wrap in create if not exists construct */
ddlCommand = DeparseTreeNode(stmt); const char *ddlCommand = DeparseTreeNode(stmt);
ddlCommand = WrapCreateOrReplace(ddlCommand); ddlCommand = WrapCreateOrReplace(ddlCommand);
ddlCommands = lappend(ddlCommands, (void *) ddlCommand); ddlCommands = lappend(ddlCommands, (void *) ddlCommand);
/* add owner ship change so the creation command can be run as a different user */ /* add owner ship change so the creation command can be run as a different user */
username = GetUserNameFromId(GetTypeOwner(typeAddress->objectId), false); const char *username = GetUserNameFromId(GetTypeOwner(typeAddress->objectId), false);
initStringInfo(&buf); initStringInfo(&buf);
appendStringInfo(&buf, ALTER_TYPE_OWNER_COMMAND, getObjectIdentity(typeAddress), appendStringInfo(&buf, ALTER_TYPE_OWNER_COMMAND, getObjectIdentity(typeAddress),
quote_identifier(username)); quote_identifier(username));
@ -1145,8 +1067,6 @@ GenerateBackupNameForTypeCollision(const ObjectAddress *address)
{ {
int suffixLength = snprintf(suffix, NAMEDATALEN - 1, "(citus_backup_%d)", int suffixLength = snprintf(suffix, NAMEDATALEN - 1, "(citus_backup_%d)",
count); count);
TypeName *newTypeName = NULL;
Oid typeOid = InvalidOid;
/* trim the base name at the end to leave space for the suffix and trailing \0 */ /* trim the base name at the end to leave space for the suffix and trailing \0 */
baseLength = Min(baseLength, NAMEDATALEN - suffixLength - 1); baseLength = Min(baseLength, NAMEDATALEN - suffixLength - 1);
@ -1157,9 +1077,9 @@ GenerateBackupNameForTypeCollision(const ObjectAddress *address)
strncpy(newName + baseLength, suffix, suffixLength); strncpy(newName + baseLength, suffix, suffixLength);
rel->relname = newName; rel->relname = newName;
newTypeName = makeTypeNameFromNameList(MakeNameListFromRangeVar(rel)); TypeName *newTypeName = makeTypeNameFromNameList(MakeNameListFromRangeVar(rel));
typeOid = LookupTypeNameOid(NULL, newTypeName, true); Oid typeOid = LookupTypeNameOid(NULL, newTypeName, true);
if (typeOid == InvalidOid) if (typeOid == InvalidOid)
{ {
return newName; return newName;
@ -1235,9 +1155,8 @@ static Oid
GetTypeOwner(Oid typeOid) GetTypeOwner(Oid typeOid)
{ {
Oid result = InvalidOid; Oid result = InvalidOid;
HeapTuple tp = NULL;
tp = SearchSysCache1(TYPEOID, ObjectIdGetDatum(typeOid)); HeapTuple tp = SearchSysCache1(TYPEOID, ObjectIdGetDatum(typeOid));
if (HeapTupleIsValid(tp)) if (HeapTupleIsValid(tp))
{ {
Form_pg_type typtup = (Form_pg_type) GETSTRUCT(tp); Form_pg_type typtup = (Form_pg_type) GETSTRUCT(tp);

View File

@ -117,7 +117,6 @@ multi_ProcessUtility(PlannedStmt *pstmt,
{ {
Node *parsetree = pstmt->utilityStmt; Node *parsetree = pstmt->utilityStmt;
List *ddlJobs = NIL; List *ddlJobs = NIL;
bool checkCreateAlterExtensionVersion = false;
if (IsA(parsetree, TransactionStmt) || if (IsA(parsetree, TransactionStmt) ||
IsA(parsetree, LockStmt) || IsA(parsetree, LockStmt) ||
@ -143,7 +142,8 @@ multi_ProcessUtility(PlannedStmt *pstmt,
return; return;
} }
checkCreateAlterExtensionVersion = IsCreateAlterExtensionUpdateCitusStmt(parsetree); bool checkCreateAlterExtensionVersion = IsCreateAlterExtensionUpdateCitusStmt(
parsetree);
if (EnableVersionChecks && checkCreateAlterExtensionVersion) if (EnableVersionChecks && checkCreateAlterExtensionVersion)
{ {
ErrorIfUnstableCreateOrAlterExtensionStmt(parsetree); ErrorIfUnstableCreateOrAlterExtensionStmt(parsetree);
@ -332,12 +332,11 @@ multi_ProcessUtility(PlannedStmt *pstmt,
if (IsA(parsetree, CopyStmt)) if (IsA(parsetree, CopyStmt))
{ {
MemoryContext planContext = GetMemoryChunkContext(parsetree); MemoryContext planContext = GetMemoryChunkContext(parsetree);
MemoryContext previousContext;
parsetree = copyObject(parsetree); parsetree = copyObject(parsetree);
parsetree = ProcessCopyStmt((CopyStmt *) parsetree, completionTag, queryString); parsetree = ProcessCopyStmt((CopyStmt *) parsetree, completionTag, queryString);
previousContext = MemoryContextSwitchTo(planContext); MemoryContext previousContext = MemoryContextSwitchTo(planContext);
parsetree = copyObject(parsetree); parsetree = copyObject(parsetree);
MemoryContextSwitchTo(previousContext); MemoryContextSwitchTo(previousContext);
@ -886,14 +885,12 @@ multi_ProcessUtility(PlannedStmt *pstmt,
static bool static bool
IsDropSchemaOrDB(Node *parsetree) IsDropSchemaOrDB(Node *parsetree)
{ {
DropStmt *dropStatement = NULL;
if (!IsA(parsetree, DropStmt)) if (!IsA(parsetree, DropStmt))
{ {
return false; return false;
} }
dropStatement = (DropStmt *) parsetree; DropStmt *dropStatement = (DropStmt *) parsetree;
return (dropStatement->removeType == OBJECT_SCHEMA) || return (dropStatement->removeType == OBJECT_SCHEMA) ||
(dropStatement->removeType == OBJECT_DATABASE); (dropStatement->removeType == OBJECT_DATABASE);
} }
@ -1091,7 +1088,6 @@ ExecuteDistributedDDLJob(DDLJob *ddlJob)
static char * static char *
SetSearchPathToCurrentSearchPathCommand(void) SetSearchPathToCurrentSearchPathCommand(void)
{ {
StringInfo setCommand = NULL;
char *currentSearchPath = CurrentSearchPath(); char *currentSearchPath = CurrentSearchPath();
if (currentSearchPath == NULL) if (currentSearchPath == NULL)
@ -1099,7 +1095,7 @@ SetSearchPathToCurrentSearchPathCommand(void)
return NULL; return NULL;
} }
setCommand = makeStringInfo(); StringInfo setCommand = makeStringInfo();
appendStringInfo(setCommand, "SET search_path TO %s;", currentSearchPath); appendStringInfo(setCommand, "SET search_path TO %s;", currentSearchPath);
return setCommand->data; return setCommand->data;
@ -1217,7 +1213,6 @@ DDLTaskList(Oid relationId, const char *commandString)
ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell); ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell);
uint64 shardId = shardInterval->shardId; uint64 shardId = shardInterval->shardId;
StringInfo applyCommand = makeStringInfo(); StringInfo applyCommand = makeStringInfo();
Task *task = NULL;
/* /*
* If rightRelationId is not InvalidOid, instead of worker_apply_shard_ddl_command * If rightRelationId is not InvalidOid, instead of worker_apply_shard_ddl_command
@ -1226,7 +1221,7 @@ DDLTaskList(Oid relationId, const char *commandString)
appendStringInfo(applyCommand, WORKER_APPLY_SHARD_DDL_COMMAND, shardId, appendStringInfo(applyCommand, WORKER_APPLY_SHARD_DDL_COMMAND, shardId,
escapedSchemaName, escapedCommandString); escapedSchemaName, escapedCommandString);
task = CitusMakeNode(Task); Task *task = CitusMakeNode(Task);
task->jobId = jobId; task->jobId = jobId;
task->taskId = taskId++; task->taskId = taskId++;
task->taskType = DDL_TASK; task->taskType = DDL_TASK;
@ -1252,9 +1247,7 @@ NodeDDLTaskList(TargetWorkerSet targets, List *commands)
{ {
List *workerNodes = TargetWorkerSetNodeList(targets, NoLock); List *workerNodes = TargetWorkerSetNodeList(targets, NoLock);
char *concatenatedCommands = StringJoin(commands, ';'); char *concatenatedCommands = StringJoin(commands, ';');
DDLJob *ddlJob = NULL;
ListCell *workerNodeCell = NULL; ListCell *workerNodeCell = NULL;
Task *task = NULL;
if (list_length(workerNodes) <= 0) if (list_length(workerNodes) <= 0)
{ {
@ -1265,16 +1258,15 @@ NodeDDLTaskList(TargetWorkerSet targets, List *commands)
return NIL; return NIL;
} }
task = CitusMakeNode(Task); Task *task = CitusMakeNode(Task);
task->taskType = DDL_TASK; task->taskType = DDL_TASK;
task->queryString = concatenatedCommands; task->queryString = concatenatedCommands;
foreach(workerNodeCell, workerNodes) foreach(workerNodeCell, workerNodes)
{ {
WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell); WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell);
ShardPlacement *targetPlacement = NULL;
targetPlacement = CitusMakeNode(ShardPlacement); ShardPlacement *targetPlacement = CitusMakeNode(ShardPlacement);
targetPlacement->nodeName = workerNode->workerName; targetPlacement->nodeName = workerNode->workerName;
targetPlacement->nodePort = workerNode->workerPort; targetPlacement->nodePort = workerNode->workerPort;
targetPlacement->groupId = workerNode->groupId; targetPlacement->groupId = workerNode->groupId;
@ -1282,7 +1274,7 @@ NodeDDLTaskList(TargetWorkerSet targets, List *commands)
task->taskPlacementList = lappend(task->taskPlacementList, targetPlacement); task->taskPlacementList = lappend(task->taskPlacementList, targetPlacement);
} }
ddlJob = palloc0(sizeof(DDLJob)); DDLJob *ddlJob = palloc0(sizeof(DDLJob));
ddlJob->targetRelationId = InvalidOid; ddlJob->targetRelationId = InvalidOid;
ddlJob->concurrentIndexCmd = false; ddlJob->concurrentIndexCmd = false;
ddlJob->commandString = NULL; ddlJob->commandString = NULL;

View File

@ -62,7 +62,6 @@ void
ProcessVacuumStmt(VacuumStmt *vacuumStmt, const char *vacuumCommand) ProcessVacuumStmt(VacuumStmt *vacuumStmt, const char *vacuumCommand)
{ {
int relationIndex = 0; int relationIndex = 0;
bool distributedVacuumStmt = false;
List *vacuumRelationList = ExtractVacuumTargetRels(vacuumStmt); List *vacuumRelationList = ExtractVacuumTargetRels(vacuumStmt);
ListCell *vacuumRelationCell = NULL; ListCell *vacuumRelationCell = NULL;
List *relationIdList = NIL; List *relationIdList = NIL;
@ -79,7 +78,8 @@ ProcessVacuumStmt(VacuumStmt *vacuumStmt, const char *vacuumCommand)
relationIdList = lappend_oid(relationIdList, relationId); relationIdList = lappend_oid(relationIdList, relationId);
} }
distributedVacuumStmt = IsDistributedVacuumStmt(vacuumParams.options, relationIdList); bool distributedVacuumStmt = IsDistributedVacuumStmt(vacuumParams.options,
relationIdList);
if (!distributedVacuumStmt) if (!distributedVacuumStmt)
{ {
return; return;
@ -91,9 +91,6 @@ ProcessVacuumStmt(VacuumStmt *vacuumStmt, const char *vacuumCommand)
Oid relationId = lfirst_oid(relationIdCell); Oid relationId = lfirst_oid(relationIdCell);
if (IsDistributedTable(relationId)) if (IsDistributedTable(relationId))
{ {
List *vacuumColumnList = NIL;
List *taskList = NIL;
/* /*
* VACUUM commands cannot run inside a transaction block, so we use * VACUUM commands cannot run inside a transaction block, so we use
* the "bare" commit protocol without BEGIN/COMMIT. However, ANALYZE * the "bare" commit protocol without BEGIN/COMMIT. However, ANALYZE
@ -108,8 +105,8 @@ ProcessVacuumStmt(VacuumStmt *vacuumStmt, const char *vacuumCommand)
MultiShardCommitProtocol = COMMIT_PROTOCOL_BARE; MultiShardCommitProtocol = COMMIT_PROTOCOL_BARE;
} }
vacuumColumnList = VacuumColumnList(vacuumStmt, relationIndex); List *vacuumColumnList = VacuumColumnList(vacuumStmt, relationIndex);
taskList = VacuumTaskList(relationId, vacuumParams, vacuumColumnList); List *taskList = VacuumTaskList(relationId, vacuumParams, vacuumColumnList);
/* use adaptive executor when enabled */ /* use adaptive executor when enabled */
ExecuteUtilityTaskListWithoutResults(taskList); ExecuteUtilityTaskListWithoutResults(taskList);
@ -135,13 +132,12 @@ IsDistributedVacuumStmt(int vacuumOptions, List *vacuumRelationIdList)
bool distributeStmt = false; bool distributeStmt = false;
ListCell *relationIdCell = NULL; ListCell *relationIdCell = NULL;
int distributedRelationCount = 0; int distributedRelationCount = 0;
int vacuumedRelationCount = 0;
/* /*
* No table in the vacuum statement means vacuuming all relations * No table in the vacuum statement means vacuuming all relations
* which is not supported by citus. * which is not supported by citus.
*/ */
vacuumedRelationCount = list_length(vacuumRelationIdList); int vacuumedRelationCount = list_length(vacuumRelationIdList);
if (vacuumedRelationCount == 0) if (vacuumedRelationCount == 0)
{ {
/* WARN for unqualified VACUUM commands */ /* WARN for unqualified VACUUM commands */
@ -188,18 +184,16 @@ static List *
VacuumTaskList(Oid relationId, CitusVacuumParams vacuumParams, List *vacuumColumnList) VacuumTaskList(Oid relationId, CitusVacuumParams vacuumParams, List *vacuumColumnList)
{ {
List *taskList = NIL; List *taskList = NIL;
List *shardIntervalList = NIL;
ListCell *shardIntervalCell = NULL; ListCell *shardIntervalCell = NULL;
uint64 jobId = INVALID_JOB_ID; uint64 jobId = INVALID_JOB_ID;
int taskId = 1; int taskId = 1;
StringInfo vacuumString = DeparseVacuumStmtPrefix(vacuumParams); StringInfo vacuumString = DeparseVacuumStmtPrefix(vacuumParams);
const char *columnNames = NULL;
const int vacuumPrefixLen = vacuumString->len; const int vacuumPrefixLen = vacuumString->len;
Oid schemaId = get_rel_namespace(relationId); Oid schemaId = get_rel_namespace(relationId);
char *schemaName = get_namespace_name(schemaId); char *schemaName = get_namespace_name(schemaId);
char *tableName = get_rel_name(relationId); char *tableName = get_rel_name(relationId);
columnNames = DeparseVacuumColumnNames(vacuumColumnList); const char *columnNames = DeparseVacuumColumnNames(vacuumColumnList);
/* /*
* We obtain ShareUpdateExclusiveLock here to not conflict with INSERT's * We obtain ShareUpdateExclusiveLock here to not conflict with INSERT's
@ -209,7 +203,7 @@ VacuumTaskList(Oid relationId, CitusVacuumParams vacuumParams, List *vacuumColum
*/ */
LockRelationOid(relationId, ShareUpdateExclusiveLock); LockRelationOid(relationId, ShareUpdateExclusiveLock);
shardIntervalList = LoadShardIntervalList(relationId); List *shardIntervalList = LoadShardIntervalList(relationId);
/* grab shard lock before getting placement list */ /* grab shard lock before getting placement list */
LockShardListMetadata(shardIntervalList, ShareLock); LockShardListMetadata(shardIntervalList, ShareLock);
@ -218,7 +212,6 @@ VacuumTaskList(Oid relationId, CitusVacuumParams vacuumParams, List *vacuumColum
{ {
ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell); ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell);
uint64 shardId = shardInterval->shardId; uint64 shardId = shardInterval->shardId;
Task *task = NULL;
char *shardName = pstrdup(tableName); char *shardName = pstrdup(tableName);
AppendShardIdToName(&shardName, shardInterval->shardId); AppendShardIdToName(&shardName, shardInterval->shardId);
@ -228,7 +221,7 @@ VacuumTaskList(Oid relationId, CitusVacuumParams vacuumParams, List *vacuumColum
appendStringInfoString(vacuumString, shardName); appendStringInfoString(vacuumString, shardName);
appendStringInfoString(vacuumString, columnNames); appendStringInfoString(vacuumString, columnNames);
task = CitusMakeNode(Task); Task *task = CitusMakeNode(Task);
task->jobId = jobId; task->jobId = jobId;
task->taskId = taskId++; task->taskId = taskId++;
task->taskType = VACUUM_ANALYZE_TASK; task->taskType = VACUUM_ANALYZE_TASK;

View File

@ -96,9 +96,8 @@ IsSettingSafeToPropagate(char *name)
"exit_on_error", "exit_on_error",
"max_stack_depth" "max_stack_depth"
}; };
Index settingIndex = 0;
for (settingIndex = 0; settingIndex < lengthof(skipSettings); settingIndex++) for (Index settingIndex = 0; settingIndex < lengthof(skipSettings); settingIndex++)
{ {
if (pg_strcasecmp(skipSettings[settingIndex], name) == 0) if (pg_strcasecmp(skipSettings[settingIndex], name) == 0)
{ {
@ -138,9 +137,8 @@ ProcessVariableSetStmt(VariableSetStmt *setStmt, const char *setStmtString)
{ {
MultiConnection *connection = dlist_container(MultiConnection, transactionNode, MultiConnection *connection = dlist_container(MultiConnection, transactionNode,
iter.cur); iter.cur);
RemoteTransaction *transaction = NULL;
transaction = &connection->remoteTransaction; RemoteTransaction *transaction = &connection->remoteTransaction;
if (transaction->transactionFailed) if (transaction->transactionFailed)
{ {
continue; continue;
@ -162,10 +160,9 @@ ProcessVariableSetStmt(VariableSetStmt *setStmt, const char *setStmtString)
{ {
MultiConnection *connection = dlist_container(MultiConnection, transactionNode, MultiConnection *connection = dlist_container(MultiConnection, transactionNode,
iter.cur); iter.cur);
RemoteTransaction *transaction = NULL;
const bool raiseErrors = true; const bool raiseErrors = true;
transaction = &connection->remoteTransaction; RemoteTransaction *transaction = &connection->remoteTransaction;
if (transaction->transactionFailed) if (transaction->transactionFailed)
{ {
continue; continue;

View File

@ -76,8 +76,7 @@ InitConnParams()
void void
ResetConnParams() ResetConnParams()
{ {
Index paramIdx = 0; for (Index paramIdx = 0; paramIdx < ConnParams.size; paramIdx++)
for (paramIdx = 0; paramIdx < ConnParams.size; paramIdx++)
{ {
free((void *) ConnParams.keywords[paramIdx]); free((void *) ConnParams.keywords[paramIdx]);
free((void *) ConnParams.values[paramIdx]); free((void *) ConnParams.values[paramIdx]);
@ -135,7 +134,6 @@ bool
CheckConninfo(const char *conninfo, const char **whitelist, CheckConninfo(const char *conninfo, const char **whitelist,
Size whitelistLength, char **errorMsg) Size whitelistLength, char **errorMsg)
{ {
PQconninfoOption *optionArray = NULL;
PQconninfoOption *option = NULL; PQconninfoOption *option = NULL;
Index whitelistIdx PG_USED_FOR_ASSERTS_ONLY = 0; Index whitelistIdx PG_USED_FOR_ASSERTS_ONLY = 0;
char *errorMsgString = NULL; char *errorMsgString = NULL;
@ -165,7 +163,7 @@ CheckConninfo(const char *conninfo, const char **whitelist,
} }
/* this should at least parse */ /* this should at least parse */
optionArray = PQconninfoParse(conninfo, NULL); PQconninfoOption *optionArray = PQconninfoParse(conninfo, NULL);
if (optionArray == NULL) if (optionArray == NULL)
{ {
*errorMsg = "Provided string is not a valid libpq connection info string"; *errorMsg = "Provided string is not a valid libpq connection info string";
@ -187,14 +185,12 @@ CheckConninfo(const char *conninfo, const char **whitelist,
for (option = optionArray; option->keyword != NULL; option++) for (option = optionArray; option->keyword != NULL; option++)
{ {
void *matchingKeyword = NULL;
if (option->val == NULL || option->val[0] == '\0') if (option->val == NULL || option->val[0] == '\0')
{ {
continue; continue;
} }
matchingKeyword = bsearch(&option->keyword, whitelist, whitelistLength, void *matchingKeyword = bsearch(&option->keyword, whitelist, whitelistLength,
sizeof(char *), pg_qsort_strcmp); sizeof(char *), pg_qsort_strcmp);
if (matchingKeyword == NULL) if (matchingKeyword == NULL)
{ {
@ -283,8 +279,6 @@ GetConnParams(ConnectionHashKey *key, char ***keywords, char ***values,
/* auth keywords will begin after global and runtime ones are appended */ /* auth keywords will begin after global and runtime ones are appended */
Index authParamsIdx = ConnParams.size + lengthof(runtimeKeywords); Index authParamsIdx = ConnParams.size + lengthof(runtimeKeywords);
Index paramIndex = 0;
Index runtimeParamIndex = 0;
if (ConnParams.size + lengthof(runtimeKeywords) >= ConnParams.maxSize) if (ConnParams.size + lengthof(runtimeKeywords) >= ConnParams.maxSize)
{ {
@ -296,7 +290,7 @@ GetConnParams(ConnectionHashKey *key, char ***keywords, char ***values,
pg_ltoa(key->port, nodePortString); /* populate node port string with port */ pg_ltoa(key->port, nodePortString); /* populate node port string with port */
/* first step: copy global parameters to beginning of array */ /* first step: copy global parameters to beginning of array */
for (paramIndex = 0; paramIndex < ConnParams.size; paramIndex++) for (Index paramIndex = 0; paramIndex < ConnParams.size; paramIndex++)
{ {
/* copy the keyword&value pointers to the new array */ /* copy the keyword&value pointers to the new array */
connKeywords[paramIndex] = ConnParams.keywords[paramIndex]; connKeywords[paramIndex] = ConnParams.keywords[paramIndex];
@ -311,7 +305,7 @@ GetConnParams(ConnectionHashKey *key, char ***keywords, char ***values,
*runtimeParamStart = ConnParams.size; *runtimeParamStart = ConnParams.size;
/* second step: begin after global params and copy runtime params into our context */ /* second step: begin after global params and copy runtime params into our context */
for (runtimeParamIndex = 0; for (Index runtimeParamIndex = 0;
runtimeParamIndex < lengthof(runtimeKeywords); runtimeParamIndex < lengthof(runtimeKeywords);
runtimeParamIndex++) runtimeParamIndex++)
{ {
@ -334,9 +328,7 @@ GetConnParams(ConnectionHashKey *key, char ***keywords, char ***values,
const char * const char *
GetConnParam(const char *keyword) GetConnParam(const char *keyword)
{ {
Index i = 0; for (Index i = 0; i < ConnParams.size; i++)
for (i = 0; i < ConnParams.size; i++)
{ {
if (strcmp(keyword, ConnParams.keywords[i]) == 0) if (strcmp(keyword, ConnParams.keywords[i]) == 0)
{ {
@ -357,10 +349,9 @@ static Size
CalculateMaxSize() CalculateMaxSize()
{ {
PQconninfoOption *defaults = PQconndefaults(); PQconninfoOption *defaults = PQconndefaults();
PQconninfoOption *option = NULL;
Size maxSize = 0; Size maxSize = 0;
for (option = defaults; for (PQconninfoOption *option = defaults;
option->keyword != NULL; option->keyword != NULL;
option++, maxSize++) option++, maxSize++)
{ {

View File

@ -85,7 +85,6 @@ void
InitializeConnectionManagement(void) InitializeConnectionManagement(void)
{ {
HASHCTL info, connParamsInfo; HASHCTL info, connParamsInfo;
uint32 hashFlags = 0;
/* /*
* Create a single context for connection and transaction related memory * Create a single context for connection and transaction related memory
@ -105,7 +104,7 @@ InitializeConnectionManagement(void)
info.hash = ConnectionHashHash; info.hash = ConnectionHashHash;
info.match = ConnectionHashCompare; info.match = ConnectionHashCompare;
info.hcxt = ConnectionContext; info.hcxt = ConnectionContext;
hashFlags = (HASH_ELEM | HASH_FUNCTION | HASH_CONTEXT | HASH_COMPARE); uint32 hashFlags = (HASH_ELEM | HASH_FUNCTION | HASH_CONTEXT | HASH_COMPARE);
memcpy(&connParamsInfo, &info, sizeof(HASHCTL)); memcpy(&connParamsInfo, &info, sizeof(HASHCTL));
connParamsInfo.entrysize = sizeof(ConnParamsHashEntry); connParamsInfo.entrysize = sizeof(ConnParamsHashEntry);
@ -187,9 +186,7 @@ GetNodeConnection(uint32 flags, const char *hostname, int32 port)
MultiConnection * MultiConnection *
GetNonDataAccessConnection(const char *hostname, int32 port) GetNonDataAccessConnection(const char *hostname, int32 port)
{ {
MultiConnection *connection; MultiConnection *connection = StartNonDataAccessConnection(hostname, port);
connection = StartNonDataAccessConnection(hostname, port);
FinishConnectionEstablishment(connection); FinishConnectionEstablishment(connection);
@ -243,9 +240,8 @@ MultiConnection *
GetNodeUserDatabaseConnection(uint32 flags, const char *hostname, int32 port, const GetNodeUserDatabaseConnection(uint32 flags, const char *hostname, int32 port, const
char *user, const char *database) char *user, const char *database)
{ {
MultiConnection *connection; MultiConnection *connection = StartNodeUserDatabaseConnection(flags, hostname, port,
user, database);
connection = StartNodeUserDatabaseConnection(flags, hostname, port, user, database);
FinishConnectionEstablishment(connection); FinishConnectionEstablishment(connection);
@ -269,10 +265,10 @@ StartWorkerListConnections(List *workerNodeList, uint32 flags, const char *user,
WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell); WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell);
char *nodeName = workerNode->workerName; char *nodeName = workerNode->workerName;
int nodePort = workerNode->workerPort; int nodePort = workerNode->workerPort;
MultiConnection *connection = NULL;
int connectionFlags = 0; int connectionFlags = 0;
connection = StartNodeUserDatabaseConnection(connectionFlags, nodeName, nodePort, MultiConnection *connection = StartNodeUserDatabaseConnection(connectionFlags,
nodeName, nodePort,
user, database); user, database);
connectionList = lappend(connectionList, connection); connectionList = lappend(connectionList, connection);
@ -298,7 +294,6 @@ StartNodeUserDatabaseConnection(uint32 flags, const char *hostname, int32 port,
char *user, const char *database) char *user, const char *database)
{ {
ConnectionHashKey key; ConnectionHashKey key;
ConnectionHashEntry *entry = NULL;
MultiConnection *connection; MultiConnection *connection;
bool found; bool found;
@ -340,7 +335,7 @@ StartNodeUserDatabaseConnection(uint32 flags, const char *hostname, int32 port,
* connection list empty. * connection list empty.
*/ */
entry = hash_search(ConnectionHash, &key, HASH_ENTER, &found); ConnectionHashEntry *entry = hash_search(ConnectionHash, &key, HASH_ENTER, &found);
if (!found) if (!found)
{ {
entry->connections = MemoryContextAlloc(ConnectionContext, entry->connections = MemoryContextAlloc(ConnectionContext,
@ -412,14 +407,13 @@ CloseNodeConnectionsAfterTransaction(char *nodeName, int nodePort)
while ((entry = (ConnectionHashEntry *) hash_seq_search(&status)) != 0) while ((entry = (ConnectionHashEntry *) hash_seq_search(&status)) != 0)
{ {
dlist_iter iter; dlist_iter iter;
dlist_head *connections = NULL;
if (strcmp(entry->key.hostname, nodeName) != 0 || entry->key.port != nodePort) if (strcmp(entry->key.hostname, nodeName) != 0 || entry->key.port != nodePort)
{ {
continue; continue;
} }
connections = entry->connections; dlist_head *connections = entry->connections;
dlist_foreach(iter, connections) dlist_foreach(iter, connections)
{ {
MultiConnection *connection = MultiConnection *connection =
@ -575,7 +569,6 @@ EventSetSizeForConnectionList(List *connections)
static WaitEventSet * static WaitEventSet *
WaitEventSetFromMultiConnectionStates(List *connections, int *waitCount) WaitEventSetFromMultiConnectionStates(List *connections, int *waitCount)
{ {
WaitEventSet *waitEventSet = NULL;
ListCell *connectionCell = NULL; ListCell *connectionCell = NULL;
const int eventSetSize = EventSetSizeForConnectionList(connections); const int eventSetSize = EventSetSizeForConnectionList(connections);
@ -586,7 +579,7 @@ WaitEventSetFromMultiConnectionStates(List *connections, int *waitCount)
*waitCount = 0; *waitCount = 0;
} }
waitEventSet = CreateWaitEventSet(CurrentMemoryContext, eventSetSize); WaitEventSet *waitEventSet = CreateWaitEventSet(CurrentMemoryContext, eventSetSize);
EnsureReleaseResource((MemoryContextCallbackFunction) (&FreeWaitEventSet), EnsureReleaseResource((MemoryContextCallbackFunction) (&FreeWaitEventSet),
waitEventSet); waitEventSet);
@ -602,8 +595,6 @@ WaitEventSetFromMultiConnectionStates(List *connections, int *waitCount)
{ {
MultiConnectionPollState *connectionState = (MultiConnectionPollState *) lfirst( MultiConnectionPollState *connectionState = (MultiConnectionPollState *) lfirst(
connectionCell); connectionCell);
int sock = 0;
int eventMask = 0;
if (numEventsAdded >= eventSetSize) if (numEventsAdded >= eventSetSize)
{ {
@ -617,9 +608,9 @@ WaitEventSetFromMultiConnectionStates(List *connections, int *waitCount)
continue; continue;
} }
sock = PQsocket(connectionState->connection->pgConn); int sock = PQsocket(connectionState->connection->pgConn);
eventMask = MultiConnectionStateEventMask(connectionState); int eventMask = MultiConnectionStateEventMask(connectionState);
AddWaitEventToSet(waitEventSet, eventMask, sock, NULL, connectionState); AddWaitEventToSet(waitEventSet, eventMask, sock, NULL, connectionState);
numEventsAdded++; numEventsAdded++;
@ -672,8 +663,6 @@ FinishConnectionListEstablishment(List *multiConnectionList)
WaitEventSet *waitEventSet = NULL; WaitEventSet *waitEventSet = NULL;
bool waitEventSetRebuild = true; bool waitEventSetRebuild = true;
int waitCount = 0; int waitCount = 0;
WaitEvent *events = NULL;
MemoryContext oldContext = NULL;
foreach(multiConnectionCell, multiConnectionList) foreach(multiConnectionCell, multiConnectionList)
{ {
@ -699,7 +688,8 @@ FinishConnectionListEstablishment(List *multiConnectionList)
} }
/* prepare space for socket events */ /* prepare space for socket events */
events = (WaitEvent *) palloc0(EventSetSizeForConnectionList(connectionStates) * WaitEvent *events = (WaitEvent *) palloc0(EventSetSizeForConnectionList(
connectionStates) *
sizeof(WaitEvent)); sizeof(WaitEvent));
/* /*
@ -707,15 +697,13 @@ FinishConnectionListEstablishment(List *multiConnectionList)
* of (big) waitsets that we'd like to clean right after we have used them. To do this * of (big) waitsets that we'd like to clean right after we have used them. To do this
* we switch to a temporary memory context for this loop which gets reset at the end * we switch to a temporary memory context for this loop which gets reset at the end
*/ */
oldContext = MemoryContextSwitchTo( MemoryContext oldContext = MemoryContextSwitchTo(
AllocSetContextCreate(CurrentMemoryContext, AllocSetContextCreate(CurrentMemoryContext,
"connection establishment temporary context", "connection establishment temporary context",
ALLOCSET_DEFAULT_SIZES)); ALLOCSET_DEFAULT_SIZES));
while (waitCount > 0) while (waitCount > 0)
{ {
long timeout = DeadlineTimestampTzToTimeout(deadline); long timeout = DeadlineTimestampTzToTimeout(deadline);
int eventCount = 0;
int eventIndex = 0;
if (waitEventSetRebuild) if (waitEventSetRebuild)
{ {
@ -730,13 +718,12 @@ FinishConnectionListEstablishment(List *multiConnectionList)
} }
} }
eventCount = WaitEventSetWait(waitEventSet, timeout, events, waitCount, int eventCount = WaitEventSetWait(waitEventSet, timeout, events, waitCount,
WAIT_EVENT_CLIENT_READ); WAIT_EVENT_CLIENT_READ);
for (eventIndex = 0; eventIndex < eventCount; eventIndex++) for (int eventIndex = 0; eventIndex < eventCount; eventIndex++)
{ {
WaitEvent *event = &events[eventIndex]; WaitEvent *event = &events[eventIndex];
bool connectionStateChanged = false;
MultiConnectionPollState *connectionState = MultiConnectionPollState *connectionState =
(MultiConnectionPollState *) event->user_data; (MultiConnectionPollState *) event->user_data;
@ -764,7 +751,7 @@ FinishConnectionListEstablishment(List *multiConnectionList)
continue; continue;
} }
connectionStateChanged = MultiConnectionStatePoll(connectionState); bool connectionStateChanged = MultiConnectionStatePoll(connectionState);
if (connectionStateChanged) if (connectionStateChanged)
{ {
if (connectionState->phase != MULTI_CONNECTION_PHASE_CONNECTING) if (connectionState->phase != MULTI_CONNECTION_PHASE_CONNECTING)
@ -909,9 +896,8 @@ static uint32
ConnectionHashHash(const void *key, Size keysize) ConnectionHashHash(const void *key, Size keysize)
{ {
ConnectionHashKey *entry = (ConnectionHashKey *) key; ConnectionHashKey *entry = (ConnectionHashKey *) key;
uint32 hash = 0;
hash = string_hash(entry->hostname, NAMEDATALEN); uint32 hash = string_hash(entry->hostname, NAMEDATALEN);
hash = hash_combine(hash, hash_uint32(entry->port)); hash = hash_combine(hash, hash_uint32(entry->port));
hash = hash_combine(hash, string_hash(entry->user, NAMEDATALEN)); hash = hash_combine(hash, string_hash(entry->user, NAMEDATALEN));
hash = hash_combine(hash, string_hash(entry->database, NAMEDATALEN)); hash = hash_combine(hash, string_hash(entry->database, NAMEDATALEN));
@ -948,11 +934,9 @@ static MultiConnection *
StartConnectionEstablishment(ConnectionHashKey *key) StartConnectionEstablishment(ConnectionHashKey *key)
{ {
bool found = false; bool found = false;
MultiConnection *connection = NULL;
ConnParamsHashEntry *entry = NULL;
/* search our cache for precomputed connection settings */ /* search our cache for precomputed connection settings */
entry = hash_search(ConnParamsHash, key, HASH_ENTER, &found); ConnParamsHashEntry *entry = hash_search(ConnParamsHash, key, HASH_ENTER, &found);
if (!found || !entry->isValid) if (!found || !entry->isValid)
{ {
/* avoid leaking memory in the keys and values arrays */ /* avoid leaking memory in the keys and values arrays */
@ -968,7 +952,8 @@ StartConnectionEstablishment(ConnectionHashKey *key)
entry->isValid = true; entry->isValid = true;
} }
connection = MemoryContextAllocZero(ConnectionContext, sizeof(MultiConnection)); MultiConnection *connection = MemoryContextAllocZero(ConnectionContext,
sizeof(MultiConnection));
strlcpy(connection->hostname, key->hostname, MAX_NODE_LENGTH); strlcpy(connection->hostname, key->hostname, MAX_NODE_LENGTH);
connection->port = key->port; connection->port = key->port;
@ -1218,9 +1203,8 @@ char *
TrimLogLevel(const char *message) TrimLogLevel(const char *message)
{ {
char *chompedMessage = pchomp(message); char *chompedMessage = pchomp(message);
size_t n;
n = 0; size_t n = 0;
while (n < strlen(chompedMessage) && chompedMessage[n] != ':') while (n < strlen(chompedMessage) && chompedMessage[n] != ':')
{ {
n++; n++;

View File

@ -267,14 +267,15 @@ StartPlacementListConnection(uint32 flags, List *placementAccessList,
const char *userName) const char *userName)
{ {
char *freeUserName = NULL; char *freeUserName = NULL;
MultiConnection *chosenConnection = NULL;
if (userName == NULL) if (userName == NULL)
{ {
userName = freeUserName = CurrentUserName(); userName = freeUserName = CurrentUserName();
} }
chosenConnection = FindPlacementListConnection(flags, placementAccessList, userName); MultiConnection *chosenConnection = FindPlacementListConnection(flags,
placementAccessList,
userName);
if (chosenConnection == NULL) if (chosenConnection == NULL)
{ {
/* use the first placement from the list to extract nodename and nodeport */ /* use the first placement from the list to extract nodename and nodeport */
@ -346,10 +347,6 @@ AssignPlacementListToConnection(List *placementAccessList, MultiConnection *conn
ShardPlacement *placement = placementAccess->placement; ShardPlacement *placement = placementAccess->placement;
ShardPlacementAccessType accessType = placementAccess->accessType; ShardPlacementAccessType accessType = placementAccess->accessType;
ConnectionPlacementHashEntry *placementEntry = NULL;
ConnectionReference *placementConnection = NULL;
Oid relationId = InvalidOid;
if (placement->shardId == INVALID_SHARD_ID) if (placement->shardId == INVALID_SHARD_ID)
{ {
@ -363,8 +360,9 @@ AssignPlacementListToConnection(List *placementAccessList, MultiConnection *conn
continue; continue;
} }
placementEntry = FindOrCreatePlacementEntry(placement); ConnectionPlacementHashEntry *placementEntry = FindOrCreatePlacementEntry(
placementConnection = placementEntry->primaryConnection; placement);
ConnectionReference *placementConnection = placementEntry->primaryConnection;
if (placementConnection->connection == connection) if (placementConnection->connection == connection)
{ {
@ -438,7 +436,7 @@ AssignPlacementListToConnection(List *placementAccessList, MultiConnection *conn
} }
/* record the relation access */ /* record the relation access */
relationId = RelationIdForShard(placement->shardId); Oid relationId = RelationIdForShard(placement->shardId);
RecordRelationAccessIfReferenceTable(relationId, accessType); RecordRelationAccessIfReferenceTable(relationId, accessType);
} }
} }
@ -453,7 +451,6 @@ MultiConnection *
GetConnectionIfPlacementAccessedInXact(int flags, List *placementAccessList, GetConnectionIfPlacementAccessedInXact(int flags, List *placementAccessList,
const char *userName) const char *userName)
{ {
MultiConnection *connection = NULL;
char *freeUserName = NULL; char *freeUserName = NULL;
if (userName == NULL) if (userName == NULL)
@ -461,7 +458,7 @@ GetConnectionIfPlacementAccessedInXact(int flags, List *placementAccessList,
userName = freeUserName = CurrentUserName(); userName = freeUserName = CurrentUserName();
} }
connection = FindPlacementListConnection(flags, placementAccessList, MultiConnection *connection = FindPlacementListConnection(flags, placementAccessList,
userName); userName);
if (freeUserName != NULL) if (freeUserName != NULL)
@ -515,9 +512,6 @@ FindPlacementListConnection(int flags, List *placementAccessList, const char *us
ShardPlacement *placement = placementAccess->placement; ShardPlacement *placement = placementAccess->placement;
ShardPlacementAccessType accessType = placementAccess->accessType; ShardPlacementAccessType accessType = placementAccess->accessType;
ConnectionPlacementHashEntry *placementEntry = NULL;
ColocatedPlacementsHashEntry *colocatedEntry = NULL;
ConnectionReference *placementConnection = NULL;
if (placement->shardId == INVALID_SHARD_ID) if (placement->shardId == INVALID_SHARD_ID)
{ {
@ -530,9 +524,10 @@ FindPlacementListConnection(int flags, List *placementAccessList, const char *us
continue; continue;
} }
placementEntry = FindOrCreatePlacementEntry(placement); ConnectionPlacementHashEntry *placementEntry = FindOrCreatePlacementEntry(
colocatedEntry = placementEntry->colocatedEntry; placement);
placementConnection = placementEntry->primaryConnection; ColocatedPlacementsHashEntry *colocatedEntry = placementEntry->colocatedEntry;
ConnectionReference *placementConnection = placementEntry->primaryConnection;
/* note: the Asserts below are primarily for clarifying the conditions */ /* note: the Asserts below are primarily for clarifying the conditions */
@ -628,12 +623,13 @@ static ConnectionPlacementHashEntry *
FindOrCreatePlacementEntry(ShardPlacement *placement) FindOrCreatePlacementEntry(ShardPlacement *placement)
{ {
ConnectionPlacementHashKey connKey; ConnectionPlacementHashKey connKey;
ConnectionPlacementHashEntry *placementEntry = NULL;
bool found = false; bool found = false;
connKey.placementId = placement->placementId; connKey.placementId = placement->placementId;
placementEntry = hash_search(ConnectionPlacementHash, &connKey, HASH_ENTER, &found); ConnectionPlacementHashEntry *placementEntry = hash_search(ConnectionPlacementHash,
&connKey, HASH_ENTER,
&found);
if (!found) if (!found)
{ {
/* no connection has been chosen for this placement */ /* no connection has been chosen for this placement */
@ -646,14 +642,14 @@ FindOrCreatePlacementEntry(ShardPlacement *placement)
placement->partitionMethod == DISTRIBUTE_BY_NONE) placement->partitionMethod == DISTRIBUTE_BY_NONE)
{ {
ColocatedPlacementsHashKey coloKey; ColocatedPlacementsHashKey coloKey;
ColocatedPlacementsHashEntry *colocatedEntry = NULL;
coloKey.nodeId = placement->nodeId; coloKey.nodeId = placement->nodeId;
coloKey.colocationGroupId = placement->colocationGroupId; coloKey.colocationGroupId = placement->colocationGroupId;
coloKey.representativeValue = placement->representativeValue; coloKey.representativeValue = placement->representativeValue;
/* look for a connection assigned to co-located placements */ /* look for a connection assigned to co-located placements */
colocatedEntry = hash_search(ColocatedPlacementsHash, &coloKey, HASH_ENTER, ColocatedPlacementsHashEntry *colocatedEntry = hash_search(
ColocatedPlacementsHash, &coloKey, HASH_ENTER,
&found); &found);
if (!found) if (!found)
{ {
@ -835,12 +831,12 @@ AssociatePlacementWithShard(ConnectionPlacementHashEntry *placementEntry,
ShardPlacement *placement) ShardPlacement *placement)
{ {
ConnectionShardHashKey shardKey; ConnectionShardHashKey shardKey;
ConnectionShardHashEntry *shardEntry = NULL;
bool found = false; bool found = false;
dlist_iter placementIter; dlist_iter placementIter;
shardKey.shardId = placement->shardId; shardKey.shardId = placement->shardId;
shardEntry = hash_search(ConnectionShardHash, &shardKey, HASH_ENTER, &found); ConnectionShardHashEntry *shardEntry = hash_search(ConnectionShardHash, &shardKey,
HASH_ENTER, &found);
if (!found) if (!found)
{ {
dlist_init(&shardEntry->placementConnections); dlist_init(&shardEntry->placementConnections);
@ -1033,7 +1029,6 @@ CheckShardPlacements(ConnectionShardHashEntry *shardEntry)
ConnectionPlacementHashEntry *placementEntry = ConnectionPlacementHashEntry *placementEntry =
dlist_container(ConnectionPlacementHashEntry, shardNode, placementIter.cur); dlist_container(ConnectionPlacementHashEntry, shardNode, placementIter.cur);
ConnectionReference *primaryConnection = placementEntry->primaryConnection; ConnectionReference *primaryConnection = placementEntry->primaryConnection;
MultiConnection *connection = NULL;
/* we only consider shards that are modified */ /* we only consider shards that are modified */
if (primaryConnection == NULL || if (primaryConnection == NULL ||
@ -1042,7 +1037,7 @@ CheckShardPlacements(ConnectionShardHashEntry *shardEntry)
continue; continue;
} }
connection = primaryConnection->connection; MultiConnection *connection = primaryConnection->connection;
if (!connection || connection->remoteTransaction.transactionFailed) if (!connection || connection->remoteTransaction.transactionFailed)
{ {
@ -1096,7 +1091,6 @@ void
InitPlacementConnectionManagement(void) InitPlacementConnectionManagement(void)
{ {
HASHCTL info; HASHCTL info;
uint32 hashFlags = 0;
/* create (placementId) -> [ConnectionReference] hash */ /* create (placementId) -> [ConnectionReference] hash */
memset(&info, 0, sizeof(info)); memset(&info, 0, sizeof(info));
@ -1104,7 +1098,7 @@ InitPlacementConnectionManagement(void)
info.entrysize = sizeof(ConnectionPlacementHashEntry); info.entrysize = sizeof(ConnectionPlacementHashEntry);
info.hash = tag_hash; info.hash = tag_hash;
info.hcxt = ConnectionContext; info.hcxt = ConnectionContext;
hashFlags = (HASH_ELEM | HASH_BLOBS | HASH_CONTEXT); uint32 hashFlags = (HASH_ELEM | HASH_BLOBS | HASH_CONTEXT);
ConnectionPlacementHash = hash_create("citus connection cache (placementid)", ConnectionPlacementHash = hash_create("citus connection cache (placementid)",
64, &info, hashFlags); 64, &info, hashFlags);
@ -1141,9 +1135,8 @@ static uint32
ColocatedPlacementsHashHash(const void *key, Size keysize) ColocatedPlacementsHashHash(const void *key, Size keysize)
{ {
ColocatedPlacementsHashKey *entry = (ColocatedPlacementsHashKey *) key; ColocatedPlacementsHashKey *entry = (ColocatedPlacementsHashKey *) key;
uint32 hash = 0;
hash = hash_uint32(entry->nodeId); uint32 hash = hash_uint32(entry->nodeId);
hash = hash_combine(hash, hash_uint32(entry->colocationGroupId)); hash = hash_combine(hash, hash_uint32(entry->colocationGroupId));
hash = hash_combine(hash, hash_uint32(entry->representativeValue)); hash = hash_combine(hash, hash_uint32(entry->representativeValue));

View File

@ -171,9 +171,6 @@ ClearResultsIfReady(MultiConnection *connection)
while (true) while (true)
{ {
PGresult *result = NULL;
ExecStatusType resultStatus;
/* /*
* If busy, there might still be results already received and buffered * If busy, there might still be results already received and buffered
* by the OS. As connection is in non-blocking mode, we can check for * by the OS. As connection is in non-blocking mode, we can check for
@ -199,14 +196,14 @@ ClearResultsIfReady(MultiConnection *connection)
return false; return false;
} }
result = PQgetResult(pgConn); PGresult *result = PQgetResult(pgConn);
if (result == NULL) if (result == NULL)
{ {
/* no more results available */ /* no more results available */
return true; return true;
} }
resultStatus = PQresultStatus(result); ExecStatusType resultStatus = PQresultStatus(result);
/* only care about the status, can clear now */ /* only care about the status, can clear now */
PQclear(result); PQclear(result);
@ -241,18 +238,16 @@ bool
SqlStateMatchesCategory(char *sqlStateString, int category) SqlStateMatchesCategory(char *sqlStateString, int category)
{ {
bool sqlStateMatchesCategory = false; bool sqlStateMatchesCategory = false;
int sqlState = 0;
int sqlStateCategory = 0;
if (sqlStateString == NULL) if (sqlStateString == NULL)
{ {
return false; return false;
} }
sqlState = MAKE_SQLSTATE(sqlStateString[0], sqlStateString[1], sqlStateString[2], int sqlState = MAKE_SQLSTATE(sqlStateString[0], sqlStateString[1], sqlStateString[2],
sqlStateString[3], sqlStateString[4]); sqlStateString[3], sqlStateString[4]);
sqlStateCategory = ERRCODE_TO_CATEGORY(sqlState); int sqlStateCategory = ERRCODE_TO_CATEGORY(sqlState);
if (sqlStateCategory == category) if (sqlStateCategory == category)
{ {
sqlStateMatchesCategory = true; sqlStateMatchesCategory = true;
@ -390,17 +385,15 @@ ExecuteCriticalRemoteCommandList(MultiConnection *connection, List *commandList)
void void
ExecuteCriticalRemoteCommand(MultiConnection *connection, const char *command) ExecuteCriticalRemoteCommand(MultiConnection *connection, const char *command)
{ {
int querySent = 0;
PGresult *result = NULL;
bool raiseInterrupts = true; bool raiseInterrupts = true;
querySent = SendRemoteCommand(connection, command); int querySent = SendRemoteCommand(connection, command);
if (querySent == 0) if (querySent == 0)
{ {
ReportConnectionError(connection, ERROR); ReportConnectionError(connection, ERROR);
} }
result = GetRemoteCommandResult(connection, raiseInterrupts); PGresult *result = GetRemoteCommandResult(connection, raiseInterrupts);
if (!IsResponseOK(result)) if (!IsResponseOK(result))
{ {
ReportResultError(connection, result, ERROR); ReportResultError(connection, result, ERROR);
@ -422,18 +415,16 @@ int
ExecuteOptionalRemoteCommand(MultiConnection *connection, const char *command, ExecuteOptionalRemoteCommand(MultiConnection *connection, const char *command,
PGresult **result) PGresult **result)
{ {
int querySent = 0;
PGresult *localResult = NULL;
bool raiseInterrupts = true; bool raiseInterrupts = true;
querySent = SendRemoteCommand(connection, command); int querySent = SendRemoteCommand(connection, command);
if (querySent == 0) if (querySent == 0)
{ {
ReportConnectionError(connection, WARNING); ReportConnectionError(connection, WARNING);
return QUERY_SEND_FAILED; return QUERY_SEND_FAILED;
} }
localResult = GetRemoteCommandResult(connection, raiseInterrupts); PGresult *localResult = GetRemoteCommandResult(connection, raiseInterrupts);
if (!IsResponseOK(localResult)) if (!IsResponseOK(localResult))
{ {
ReportResultError(connection, localResult, WARNING); ReportResultError(connection, localResult, WARNING);
@ -473,7 +464,6 @@ SendRemoteCommandParams(MultiConnection *connection, const char *command,
const char *const *parameterValues) const char *const *parameterValues)
{ {
PGconn *pgConn = connection->pgConn; PGconn *pgConn = connection->pgConn;
int rc = 0;
LogRemoteCommand(connection, command); LogRemoteCommand(connection, command);
@ -488,7 +478,7 @@ SendRemoteCommandParams(MultiConnection *connection, const char *command,
Assert(PQisnonblocking(pgConn)); Assert(PQisnonblocking(pgConn));
rc = PQsendQueryParams(pgConn, command, parameterCount, parameterTypes, int rc = PQsendQueryParams(pgConn, command, parameterCount, parameterTypes,
parameterValues, NULL, NULL, 0); parameterValues, NULL, NULL, 0);
return rc; return rc;
@ -506,7 +496,6 @@ int
SendRemoteCommand(MultiConnection *connection, const char *command) SendRemoteCommand(MultiConnection *connection, const char *command)
{ {
PGconn *pgConn = connection->pgConn; PGconn *pgConn = connection->pgConn;
int rc = 0;
LogRemoteCommand(connection, command); LogRemoteCommand(connection, command);
@ -521,7 +510,7 @@ SendRemoteCommand(MultiConnection *connection, const char *command)
Assert(PQisnonblocking(pgConn)); Assert(PQisnonblocking(pgConn));
rc = PQsendQuery(pgConn, command); int rc = PQsendQuery(pgConn, command);
return rc; return rc;
} }
@ -536,7 +525,6 @@ ReadFirstColumnAsText(PGresult *queryResult)
{ {
List *resultRowList = NIL; List *resultRowList = NIL;
const int columnIndex = 0; const int columnIndex = 0;
int64 rowIndex = 0;
int64 rowCount = 0; int64 rowCount = 0;
ExecStatusType status = PQresultStatus(queryResult); ExecStatusType status = PQresultStatus(queryResult);
@ -545,7 +533,7 @@ ReadFirstColumnAsText(PGresult *queryResult)
rowCount = PQntuples(queryResult); rowCount = PQntuples(queryResult);
} }
for (rowIndex = 0; rowIndex < rowCount; rowIndex++) for (int64 rowIndex = 0; rowIndex < rowCount; rowIndex++)
{ {
char *rowValue = PQgetvalue(queryResult, rowIndex, columnIndex); char *rowValue = PQgetvalue(queryResult, rowIndex, columnIndex);
@ -579,7 +567,6 @@ PGresult *
GetRemoteCommandResult(MultiConnection *connection, bool raiseInterrupts) GetRemoteCommandResult(MultiConnection *connection, bool raiseInterrupts)
{ {
PGconn *pgConn = connection->pgConn; PGconn *pgConn = connection->pgConn;
PGresult *result = NULL;
/* /*
* Short circuit tests around the more expensive parts of this * Short circuit tests around the more expensive parts of this
@ -605,7 +592,7 @@ GetRemoteCommandResult(MultiConnection *connection, bool raiseInterrupts)
/* no IO should be necessary to get result */ /* no IO should be necessary to get result */
Assert(!PQisBusy(pgConn)); Assert(!PQisBusy(pgConn));
result = PQgetResult(connection->pgConn); PGresult *result = PQgetResult(connection->pgConn);
return result; return result;
} }
@ -621,7 +608,6 @@ bool
PutRemoteCopyData(MultiConnection *connection, const char *buffer, int nbytes) PutRemoteCopyData(MultiConnection *connection, const char *buffer, int nbytes)
{ {
PGconn *pgConn = connection->pgConn; PGconn *pgConn = connection->pgConn;
int copyState = 0;
bool allowInterrupts = true; bool allowInterrupts = true;
if (PQstatus(pgConn) != CONNECTION_OK) if (PQstatus(pgConn) != CONNECTION_OK)
@ -631,7 +617,7 @@ PutRemoteCopyData(MultiConnection *connection, const char *buffer, int nbytes)
Assert(PQisnonblocking(pgConn)); Assert(PQisnonblocking(pgConn));
copyState = PQputCopyData(pgConn, buffer, nbytes); int copyState = PQputCopyData(pgConn, buffer, nbytes);
if (copyState == -1) if (copyState == -1)
{ {
return false; return false;
@ -670,7 +656,6 @@ bool
PutRemoteCopyEnd(MultiConnection *connection, const char *errormsg) PutRemoteCopyEnd(MultiConnection *connection, const char *errormsg)
{ {
PGconn *pgConn = connection->pgConn; PGconn *pgConn = connection->pgConn;
int copyState = 0;
bool allowInterrupts = true; bool allowInterrupts = true;
if (PQstatus(pgConn) != CONNECTION_OK) if (PQstatus(pgConn) != CONNECTION_OK)
@ -680,7 +665,7 @@ PutRemoteCopyEnd(MultiConnection *connection, const char *errormsg)
Assert(PQisnonblocking(pgConn)); Assert(PQisnonblocking(pgConn));
copyState = PQputCopyEnd(pgConn, errormsg); int copyState = PQputCopyEnd(pgConn, errormsg);
if (copyState == -1) if (copyState == -1)
{ {
return false; return false;
@ -720,12 +705,10 @@ FinishConnectionIO(MultiConnection *connection, bool raiseInterrupts)
/* perform the necessary IO */ /* perform the necessary IO */
while (true) while (true)
{ {
int sendStatus = 0;
int rc = 0;
int waitFlags = WL_POSTMASTER_DEATH | WL_LATCH_SET; int waitFlags = WL_POSTMASTER_DEATH | WL_LATCH_SET;
/* try to send all pending data */ /* try to send all pending data */
sendStatus = PQflush(pgConn); int sendStatus = PQflush(pgConn);
/* if sending failed, there's nothing more we can do */ /* if sending failed, there's nothing more we can do */
if (sendStatus == -1) if (sendStatus == -1)
@ -753,7 +736,7 @@ FinishConnectionIO(MultiConnection *connection, bool raiseInterrupts)
return true; return true;
} }
rc = WaitLatchOrSocket(MyLatch, waitFlags, sock, 0, PG_WAIT_EXTENSION); int rc = WaitLatchOrSocket(MyLatch, waitFlags, sock, 0, PG_WAIT_EXTENSION);
if (rc & WL_POSTMASTER_DEATH) if (rc & WL_POSTMASTER_DEATH)
{ {
ereport(ERROR, (errmsg("postmaster was shut down, exiting"))); ereport(ERROR, (errmsg("postmaster was shut down, exiting")));
@ -837,7 +820,6 @@ WaitForAllConnections(List *connectionList, bool raiseInterrupts)
{ {
bool cancellationReceived = false; bool cancellationReceived = false;
int eventIndex = 0; int eventIndex = 0;
int eventCount = 0;
long timeout = -1; long timeout = -1;
int pendingConnectionCount = totalConnectionCount - int pendingConnectionCount = totalConnectionCount -
pendingConnectionsStartIndex; pendingConnectionsStartIndex;
@ -857,14 +839,14 @@ WaitForAllConnections(List *connectionList, bool raiseInterrupts)
} }
/* wait for I/O events */ /* wait for I/O events */
eventCount = WaitEventSetWait(waitEventSet, timeout, events, int eventCount = WaitEventSetWait(waitEventSet, timeout, events,
pendingConnectionCount, WAIT_EVENT_CLIENT_READ); pendingConnectionCount,
WAIT_EVENT_CLIENT_READ);
/* process I/O events */ /* process I/O events */
for (; eventIndex < eventCount; eventIndex++) for (; eventIndex < eventCount; eventIndex++)
{ {
WaitEvent *event = &events[eventIndex]; WaitEvent *event = &events[eventIndex];
MultiConnection *connection = NULL;
bool connectionIsReady = false; bool connectionIsReady = false;
if (event->events & WL_POSTMASTER_DEATH) if (event->events & WL_POSTMASTER_DEATH)
@ -896,7 +878,7 @@ WaitForAllConnections(List *connectionList, bool raiseInterrupts)
continue; continue;
} }
connection = (MultiConnection *) event->user_data; MultiConnection *connection = (MultiConnection *) event->user_data;
if (event->events & WL_SOCKET_WRITEABLE) if (event->events & WL_SOCKET_WRITEABLE)
{ {
@ -1028,8 +1010,6 @@ BuildWaitEventSet(MultiConnection **allConnections, int totalConnectionCount,
int pendingConnectionsStartIndex) int pendingConnectionsStartIndex)
{ {
int pendingConnectionCount = totalConnectionCount - pendingConnectionsStartIndex; int pendingConnectionCount = totalConnectionCount - pendingConnectionsStartIndex;
WaitEventSet *waitEventSet = NULL;
int connectionIndex = 0;
/* /*
* subtract 3 to make room for WL_POSTMASTER_DEATH, WL_LATCH_SET, and * subtract 3 to make room for WL_POSTMASTER_DEATH, WL_LATCH_SET, and
@ -1042,9 +1022,11 @@ BuildWaitEventSet(MultiConnection **allConnections, int totalConnectionCount,
/* allocate pending connections + 2 for the signal latch and postmaster death */ /* allocate pending connections + 2 for the signal latch and postmaster death */
/* (CreateWaitEventSet makes room for pgwin32_signal_event automatically) */ /* (CreateWaitEventSet makes room for pgwin32_signal_event automatically) */
waitEventSet = CreateWaitEventSet(CurrentMemoryContext, pendingConnectionCount + 2); WaitEventSet *waitEventSet = CreateWaitEventSet(CurrentMemoryContext,
pendingConnectionCount + 2);
for (connectionIndex = 0; connectionIndex < pendingConnectionCount; connectionIndex++) for (int connectionIndex = 0; connectionIndex < pendingConnectionCount;
connectionIndex++)
{ {
MultiConnection *connection = allConnections[pendingConnectionsStartIndex + MultiConnection *connection = allConnections[pendingConnectionsStartIndex +
connectionIndex]; connectionIndex];
@ -1078,7 +1060,6 @@ bool
SendCancelationRequest(MultiConnection *connection) SendCancelationRequest(MultiConnection *connection)
{ {
char errorBuffer[ERROR_BUFFER_SIZE] = { 0 }; char errorBuffer[ERROR_BUFFER_SIZE] = { 0 };
bool cancelSent = false;
PGcancel *cancelObject = PQgetCancel(connection->pgConn); PGcancel *cancelObject = PQgetCancel(connection->pgConn);
if (cancelObject == NULL) if (cancelObject == NULL)
@ -1087,7 +1068,7 @@ SendCancelationRequest(MultiConnection *connection)
return false; return false;
} }
cancelSent = PQcancel(cancelObject, errorBuffer, sizeof(errorBuffer)); bool cancelSent = PQcancel(cancelObject, errorBuffer, sizeof(errorBuffer));
if (!cancelSent) if (!cancelSent)
{ {
ereport(WARNING, (errmsg("could not issue cancel request"), ereport(WARNING, (errmsg("could not issue cancel request"),

View File

@ -200,16 +200,12 @@ pg_get_serverdef_string(Oid tableRelationId)
char * char *
pg_get_sequencedef_string(Oid sequenceRelationId) pg_get_sequencedef_string(Oid sequenceRelationId)
{ {
char *qualifiedSequenceName = NULL; Form_pg_sequence pgSequenceForm = pg_get_sequencedef(sequenceRelationId);
char *sequenceDef = NULL;
Form_pg_sequence pgSequenceForm = NULL;
pgSequenceForm = pg_get_sequencedef(sequenceRelationId);
/* build our DDL command */ /* build our DDL command */
qualifiedSequenceName = generate_qualified_relation_name(sequenceRelationId); char *qualifiedSequenceName = generate_qualified_relation_name(sequenceRelationId);
sequenceDef = psprintf(CREATE_SEQUENCE_COMMAND, qualifiedSequenceName, char *sequenceDef = psprintf(CREATE_SEQUENCE_COMMAND, qualifiedSequenceName,
pgSequenceForm->seqincrement, pgSequenceForm->seqmin, pgSequenceForm->seqincrement, pgSequenceForm->seqmin,
pgSequenceForm->seqmax, pgSequenceForm->seqstart, pgSequenceForm->seqmax, pgSequenceForm->seqstart,
pgSequenceForm->seqcycle ? "" : "NO "); pgSequenceForm->seqcycle ? "" : "NO ");
@ -225,16 +221,13 @@ pg_get_sequencedef_string(Oid sequenceRelationId)
Form_pg_sequence Form_pg_sequence
pg_get_sequencedef(Oid sequenceRelationId) pg_get_sequencedef(Oid sequenceRelationId)
{ {
Form_pg_sequence pgSequenceForm = NULL; HeapTuple heapTuple = SearchSysCache1(SEQRELID, sequenceRelationId);
HeapTuple heapTuple = NULL;
heapTuple = SearchSysCache1(SEQRELID, sequenceRelationId);
if (!HeapTupleIsValid(heapTuple)) if (!HeapTupleIsValid(heapTuple))
{ {
elog(ERROR, "cache lookup failed for sequence %u", sequenceRelationId); elog(ERROR, "cache lookup failed for sequence %u", sequenceRelationId);
} }
pgSequenceForm = (Form_pg_sequence) GETSTRUCT(heapTuple); Form_pg_sequence pgSequenceForm = (Form_pg_sequence) GETSTRUCT(heapTuple);
ReleaseSysCache(heapTuple); ReleaseSysCache(heapTuple);
@ -253,12 +246,7 @@ pg_get_sequencedef(Oid sequenceRelationId)
char * char *
pg_get_tableschemadef_string(Oid tableRelationId, bool includeSequenceDefaults) pg_get_tableschemadef_string(Oid tableRelationId, bool includeSequenceDefaults)
{ {
Relation relation = NULL;
char *relationName = NULL;
char relationKind = 0; char relationKind = 0;
TupleDesc tupleDescriptor = NULL;
TupleConstr *tupleConstraints = NULL;
int attributeIndex = 0;
bool firstAttributePrinted = false; bool firstAttributePrinted = false;
AttrNumber defaultValueIndex = 0; AttrNumber defaultValueIndex = 0;
AttrNumber constraintIndex = 0; AttrNumber constraintIndex = 0;
@ -273,8 +261,8 @@ pg_get_tableschemadef_string(Oid tableRelationId, bool includeSequenceDefaults)
* pg_attribute, pg_constraint, and pg_class; and therefore using the * pg_attribute, pg_constraint, and pg_class; and therefore using the
* descriptor saves us from a lot of additional work. * descriptor saves us from a lot of additional work.
*/ */
relation = relation_open(tableRelationId, AccessShareLock); Relation relation = relation_open(tableRelationId, AccessShareLock);
relationName = generate_relation_name(tableRelationId, NIL); char *relationName = generate_relation_name(tableRelationId, NIL);
EnsureRelationKindSupported(tableRelationId); EnsureRelationKindSupported(tableRelationId);
@ -301,10 +289,11 @@ pg_get_tableschemadef_string(Oid tableRelationId, bool includeSequenceDefaults)
* and is not inherited from another table, print the column's name and its * and is not inherited from another table, print the column's name and its
* formatted type. * formatted type.
*/ */
tupleDescriptor = RelationGetDescr(relation); TupleDesc tupleDescriptor = RelationGetDescr(relation);
tupleConstraints = tupleDescriptor->constr; TupleConstr *tupleConstraints = tupleDescriptor->constr;
for (attributeIndex = 0; attributeIndex < tupleDescriptor->natts; attributeIndex++) for (int attributeIndex = 0; attributeIndex < tupleDescriptor->natts;
attributeIndex++)
{ {
Form_pg_attribute attributeForm = TupleDescAttr(tupleDescriptor, attributeIndex); Form_pg_attribute attributeForm = TupleDescAttr(tupleDescriptor, attributeIndex);
@ -318,45 +307,40 @@ pg_get_tableschemadef_string(Oid tableRelationId, bool includeSequenceDefaults)
*/ */
if (!attributeForm->attisdropped) if (!attributeForm->attisdropped)
{ {
const char *attributeName = NULL;
const char *attributeTypeName = NULL;
if (firstAttributePrinted) if (firstAttributePrinted)
{ {
appendStringInfoString(&buffer, ", "); appendStringInfoString(&buffer, ", ");
} }
firstAttributePrinted = true; firstAttributePrinted = true;
attributeName = NameStr(attributeForm->attname); const char *attributeName = NameStr(attributeForm->attname);
appendStringInfo(&buffer, "%s ", quote_identifier(attributeName)); appendStringInfo(&buffer, "%s ", quote_identifier(attributeName));
attributeTypeName = format_type_with_typemod(attributeForm->atttypid, const char *attributeTypeName = format_type_with_typemod(
attributeForm->atttypmod); attributeForm->atttypid,
attributeForm->
atttypmod);
appendStringInfoString(&buffer, attributeTypeName); appendStringInfoString(&buffer, attributeTypeName);
/* if this column has a default value, append the default value */ /* if this column has a default value, append the default value */
if (attributeForm->atthasdef) if (attributeForm->atthasdef)
{ {
AttrDefault *defaultValueList = NULL;
AttrDefault *defaultValue = NULL;
Node *defaultNode = NULL;
List *defaultContext = NULL; List *defaultContext = NULL;
char *defaultString = NULL; char *defaultString = NULL;
Assert(tupleConstraints != NULL); Assert(tupleConstraints != NULL);
defaultValueList = tupleConstraints->defval; AttrDefault *defaultValueList = tupleConstraints->defval;
Assert(defaultValueList != NULL); Assert(defaultValueList != NULL);
defaultValue = &(defaultValueList[defaultValueIndex]); AttrDefault *defaultValue = &(defaultValueList[defaultValueIndex]);
defaultValueIndex++; defaultValueIndex++;
Assert(defaultValue->adnum == (attributeIndex + 1)); Assert(defaultValue->adnum == (attributeIndex + 1));
Assert(defaultValueIndex <= tupleConstraints->num_defval); Assert(defaultValueIndex <= tupleConstraints->num_defval);
/* convert expression to node tree, and prepare deparse context */ /* convert expression to node tree, and prepare deparse context */
defaultNode = (Node *) stringToNode(defaultValue->adbin); Node *defaultNode = (Node *) stringToNode(defaultValue->adbin);
/* /*
* if column default value is explicitly requested, or it is * if column default value is explicitly requested, or it is
@ -418,9 +402,6 @@ pg_get_tableschemadef_string(Oid tableRelationId, bool includeSequenceDefaults)
ConstrCheck *checkConstraintList = tupleConstraints->check; ConstrCheck *checkConstraintList = tupleConstraints->check;
ConstrCheck *checkConstraint = &(checkConstraintList[constraintIndex]); ConstrCheck *checkConstraint = &(checkConstraintList[constraintIndex]);
Node *checkNode = NULL;
List *checkContext = NULL;
char *checkString = NULL;
/* if an attribute or constraint has been printed, format properly */ /* if an attribute or constraint has been printed, format properly */
if (firstAttributePrinted || constraintIndex > 0) if (firstAttributePrinted || constraintIndex > 0)
@ -432,11 +413,11 @@ pg_get_tableschemadef_string(Oid tableRelationId, bool includeSequenceDefaults)
quote_identifier(checkConstraint->ccname)); quote_identifier(checkConstraint->ccname));
/* convert expression to node tree, and prepare deparse context */ /* convert expression to node tree, and prepare deparse context */
checkNode = (Node *) stringToNode(checkConstraint->ccbin); Node *checkNode = (Node *) stringToNode(checkConstraint->ccbin);
checkContext = deparse_context_for(relationName, tableRelationId); List *checkContext = deparse_context_for(relationName, tableRelationId);
/* deparse check constraint string */ /* deparse check constraint string */
checkString = deparse_expression(checkNode, checkContext, false, false); char *checkString = deparse_expression(checkNode, checkContext, false, false);
appendStringInfoString(&buffer, checkString); appendStringInfoString(&buffer, checkString);
} }
@ -491,9 +472,8 @@ void
EnsureRelationKindSupported(Oid relationId) EnsureRelationKindSupported(Oid relationId)
{ {
char relationKind = get_rel_relkind(relationId); char relationKind = get_rel_relkind(relationId);
bool supportedRelationKind = false;
supportedRelationKind = RegularTable(relationId) || bool supportedRelationKind = RegularTable(relationId) ||
relationKind == RELKIND_FOREIGN_TABLE; relationKind == RELKIND_FOREIGN_TABLE;
/* /*
@ -523,9 +503,6 @@ EnsureRelationKindSupported(Oid relationId)
char * char *
pg_get_tablecolumnoptionsdef_string(Oid tableRelationId) pg_get_tablecolumnoptionsdef_string(Oid tableRelationId)
{ {
Relation relation = NULL;
TupleDesc tupleDescriptor = NULL;
AttrNumber attributeIndex = 0;
List *columnOptionList = NIL; List *columnOptionList = NIL;
ListCell *columnOptionCell = NULL; ListCell *columnOptionCell = NULL;
bool firstOptionPrinted = false; bool firstOptionPrinted = false;
@ -536,7 +513,7 @@ pg_get_tablecolumnoptionsdef_string(Oid tableRelationId)
* and use the relation's tuple descriptor to access attribute information. * and use the relation's tuple descriptor to access attribute information.
* This is primarily to maintain symmetry with pg_get_tableschemadef. * This is primarily to maintain symmetry with pg_get_tableschemadef.
*/ */
relation = relation_open(tableRelationId, AccessShareLock); Relation relation = relation_open(tableRelationId, AccessShareLock);
EnsureRelationKindSupported(tableRelationId); EnsureRelationKindSupported(tableRelationId);
@ -545,9 +522,10 @@ pg_get_tablecolumnoptionsdef_string(Oid tableRelationId)
* and is not inherited from another table, check if column storage or * and is not inherited from another table, check if column storage or
* statistics statements need to be printed. * statistics statements need to be printed.
*/ */
tupleDescriptor = RelationGetDescr(relation); TupleDesc tupleDescriptor = RelationGetDescr(relation);
for (attributeIndex = 0; attributeIndex < tupleDescriptor->natts; attributeIndex++) for (AttrNumber attributeIndex = 0; attributeIndex < tupleDescriptor->natts;
attributeIndex++)
{ {
Form_pg_attribute attributeForm = TupleDescAttr(tupleDescriptor, attributeIndex); Form_pg_attribute attributeForm = TupleDescAttr(tupleDescriptor, attributeIndex);
char *attributeName = NameStr(attributeForm->attname); char *attributeName = NameStr(attributeForm->attname);
@ -631,8 +609,6 @@ pg_get_tablecolumnoptionsdef_string(Oid tableRelationId)
*/ */
foreach(columnOptionCell, columnOptionList) foreach(columnOptionCell, columnOptionList)
{ {
char *columnOptionStatement = NULL;
if (!firstOptionPrinted) if (!firstOptionPrinted)
{ {
initStringInfo(&buffer); initStringInfo(&buffer);
@ -645,7 +621,7 @@ pg_get_tablecolumnoptionsdef_string(Oid tableRelationId)
} }
firstOptionPrinted = true; firstOptionPrinted = true;
columnOptionStatement = (char *) lfirst(columnOptionCell); char *columnOptionStatement = (char *) lfirst(columnOptionCell);
appendStringInfoString(&buffer, columnOptionStatement); appendStringInfoString(&buffer, columnOptionStatement);
pfree(columnOptionStatement); pfree(columnOptionStatement);
@ -670,14 +646,13 @@ deparse_shard_index_statement(IndexStmt *origStmt, Oid distrelid, int64 shardid,
IndexStmt *indexStmt = copyObject(origStmt); /* copy to avoid modifications */ IndexStmt *indexStmt = copyObject(origStmt); /* copy to avoid modifications */
char *relationName = indexStmt->relation->relname; char *relationName = indexStmt->relation->relname;
char *indexName = indexStmt->idxname; char *indexName = indexStmt->idxname;
List *deparseContext = NULL;
/* extend relation and index name using shard identifier */ /* extend relation and index name using shard identifier */
AppendShardIdToName(&relationName, shardid); AppendShardIdToName(&relationName, shardid);
AppendShardIdToName(&indexName, shardid); AppendShardIdToName(&indexName, shardid);
/* use extended shard name and transformed stmt for deparsing */ /* use extended shard name and transformed stmt for deparsing */
deparseContext = deparse_context_for(relationName, distrelid); List *deparseContext = deparse_context_for(relationName, distrelid);
indexStmt = transformIndexStmt(distrelid, indexStmt, NULL); indexStmt = transformIndexStmt(distrelid, indexStmt, NULL);
appendStringInfo(buffer, "CREATE %s INDEX %s %s %s ON %s USING %s ", appendStringInfo(buffer, "CREATE %s INDEX %s %s %s ON %s USING %s ",
@ -850,19 +825,17 @@ deparse_index_columns(StringInfo buffer, List *indexParameterList, List *deparse
char * char *
pg_get_indexclusterdef_string(Oid indexRelationId) pg_get_indexclusterdef_string(Oid indexRelationId)
{ {
HeapTuple indexTuple = NULL;
Form_pg_index indexForm = NULL;
Oid tableRelationId = InvalidOid;
StringInfoData buffer = { NULL, 0, 0, 0 }; StringInfoData buffer = { NULL, 0, 0, 0 };
indexTuple = SearchSysCache(INDEXRELID, ObjectIdGetDatum(indexRelationId), 0, 0, 0); HeapTuple indexTuple = SearchSysCache(INDEXRELID, ObjectIdGetDatum(indexRelationId),
0, 0, 0);
if (!HeapTupleIsValid(indexTuple)) if (!HeapTupleIsValid(indexTuple))
{ {
ereport(ERROR, (errmsg("cache lookup failed for index %u", indexRelationId))); ereport(ERROR, (errmsg("cache lookup failed for index %u", indexRelationId)));
} }
indexForm = (Form_pg_index) GETSTRUCT(indexTuple); Form_pg_index indexForm = (Form_pg_index) GETSTRUCT(indexTuple);
tableRelationId = indexForm->indrelid; Oid tableRelationId = indexForm->indrelid;
/* check if the table is clustered on this index */ /* check if the table is clustered on this index */
if (indexForm->indisclustered) if (indexForm->indisclustered)
@ -892,20 +865,16 @@ pg_get_table_grants(Oid relationId)
{ {
/* *INDENT-OFF* */ /* *INDENT-OFF* */
StringInfoData buffer; StringInfoData buffer;
Relation relation = NULL;
char *relationName = NULL;
List *defs = NIL; List *defs = NIL;
HeapTuple classTuple = NULL;
Datum aclDatum = 0;
bool isNull = false; bool isNull = false;
relation = relation_open(relationId, AccessShareLock); Relation relation = relation_open(relationId, AccessShareLock);
relationName = generate_relation_name(relationId, NIL); char *relationName = generate_relation_name(relationId, NIL);
initStringInfo(&buffer); initStringInfo(&buffer);
/* lookup all table level grants */ /* lookup all table level grants */
classTuple = SearchSysCache1(RELOID, ObjectIdGetDatum(relationId)); HeapTuple classTuple = SearchSysCache1(RELOID, ObjectIdGetDatum(relationId));
if (!HeapTupleIsValid(classTuple)) if (!HeapTupleIsValid(classTuple))
{ {
ereport(ERROR, ereport(ERROR,
@ -914,17 +883,13 @@ pg_get_table_grants(Oid relationId)
relationId))); relationId)));
} }
aclDatum = SysCacheGetAttr(RELOID, classTuple, Anum_pg_class_relacl, Datum aclDatum = SysCacheGetAttr(RELOID, classTuple, Anum_pg_class_relacl,
&isNull); &isNull);
ReleaseSysCache(classTuple); ReleaseSysCache(classTuple);
if (!isNull) if (!isNull)
{ {
int i = 0;
AclItem *aidat = NULL;
Acl *acl = NULL;
int offtype = 0;
/* /*
* First revoke all default permissions, so we can start adding the * First revoke all default permissions, so we can start adding the
@ -943,11 +908,11 @@ pg_get_table_grants(Oid relationId)
/* iterate through the acl datastructure, emit GRANTs */ /* iterate through the acl datastructure, emit GRANTs */
acl = DatumGetAclP(aclDatum); Acl *acl = DatumGetAclP(aclDatum);
aidat = ACL_DAT(acl); AclItem *aidat = ACL_DAT(acl);
offtype = -1; int offtype = -1;
i = 0; int i = 0;
while (i < ACL_NUM(acl)) while (i < ACL_NUM(acl))
{ {
AclItem *aidata = NULL; AclItem *aidata = NULL;
@ -975,9 +940,8 @@ pg_get_table_grants(Oid relationId)
if (aidata->ai_grantee != 0) if (aidata->ai_grantee != 0)
{ {
HeapTuple htup;
htup = SearchSysCache1(AUTHOID, ObjectIdGetDatum(aidata->ai_grantee)); HeapTuple htup = SearchSysCache1(AUTHOID, ObjectIdGetDatum(aidata->ai_grantee));
if (HeapTupleIsValid(htup)) if (HeapTupleIsValid(htup))
{ {
Form_pg_authid authForm = ((Form_pg_authid) GETSTRUCT(htup)); Form_pg_authid authForm = ((Form_pg_authid) GETSTRUCT(htup));
@ -1029,28 +993,22 @@ pg_get_table_grants(Oid relationId)
char * char *
generate_qualified_relation_name(Oid relid) generate_qualified_relation_name(Oid relid)
{ {
HeapTuple tp; HeapTuple tp = SearchSysCache1(RELOID, ObjectIdGetDatum(relid));
Form_pg_class reltup;
char *relname;
char *nspname;
char *result;
tp = SearchSysCache1(RELOID, ObjectIdGetDatum(relid));
if (!HeapTupleIsValid(tp)) if (!HeapTupleIsValid(tp))
{ {
elog(ERROR, "cache lookup failed for relation %u", relid); elog(ERROR, "cache lookup failed for relation %u", relid);
} }
reltup = (Form_pg_class) GETSTRUCT(tp); Form_pg_class reltup = (Form_pg_class) GETSTRUCT(tp);
relname = NameStr(reltup->relname); char *relname = NameStr(reltup->relname);
nspname = get_namespace_name(reltup->relnamespace); char *nspname = get_namespace_name(reltup->relnamespace);
if (!nspname) if (!nspname)
{ {
elog(ERROR, "cache lookup failed for namespace %u", elog(ERROR, "cache lookup failed for namespace %u",
reltup->relnamespace); reltup->relnamespace);
} }
result = quote_qualified_identifier(nspname, relname); char *result = quote_qualified_identifier(nspname, relname);
ReleaseSysCache(tp); ReleaseSysCache(tp);
@ -1202,16 +1160,13 @@ contain_nextval_expression_walker(Node *node, void *context)
char * char *
pg_get_replica_identity_command(Oid tableRelationId) pg_get_replica_identity_command(Oid tableRelationId)
{ {
Relation relation = NULL;
StringInfo buf = makeStringInfo(); StringInfo buf = makeStringInfo();
char *relationName = NULL;
char replicaIdentity = 0;
relation = heap_open(tableRelationId, AccessShareLock); Relation relation = heap_open(tableRelationId, AccessShareLock);
replicaIdentity = relation->rd_rel->relreplident; char replicaIdentity = relation->rd_rel->relreplident;
relationName = generate_qualified_relation_name(tableRelationId); char *relationName = generate_qualified_relation_name(tableRelationId);
if (replicaIdentity == REPLICA_IDENTITY_INDEX) if (replicaIdentity == REPLICA_IDENTITY_INDEX)
{ {
@ -1251,17 +1206,15 @@ static char *
flatten_reloptions(Oid relid) flatten_reloptions(Oid relid)
{ {
char *result = NULL; char *result = NULL;
HeapTuple tuple;
Datum reloptions;
bool isnull; bool isnull;
tuple = SearchSysCache1(RELOID, ObjectIdGetDatum(relid)); HeapTuple tuple = SearchSysCache1(RELOID, ObjectIdGetDatum(relid));
if (!HeapTupleIsValid(tuple)) if (!HeapTupleIsValid(tuple))
{ {
elog(ERROR, "cache lookup failed for relation %u", relid); elog(ERROR, "cache lookup failed for relation %u", relid);
} }
reloptions = SysCacheGetAttr(RELOID, tuple, Datum reloptions = SysCacheGetAttr(RELOID, tuple,
Anum_pg_class_reloptions, &isnull); Anum_pg_class_reloptions, &isnull);
if (!isnull) if (!isnull)
{ {
@ -1279,16 +1232,14 @@ flatten_reloptions(Oid relid)
for (i = 0; i < noptions; i++) for (i = 0; i < noptions; i++)
{ {
char *option = TextDatumGetCString(options[i]); char *option = TextDatumGetCString(options[i]);
char *name;
char *separator;
char *value; char *value;
/* /*
* Each array element should have the form name=value. If the "=" * Each array element should have the form name=value. If the "="
* is missing for some reason, treat it like an empty value. * is missing for some reason, treat it like an empty value.
*/ */
name = option; char *name = option;
separator = strchr(option, '='); char *separator = strchr(option, '=');
if (separator) if (separator)
{ {
*separator = '\0'; *separator = '\0';
@ -1343,15 +1294,13 @@ flatten_reloptions(Oid relid)
static void static void
simple_quote_literal(StringInfo buf, const char *val) simple_quote_literal(StringInfo buf, const char *val)
{ {
const char *valptr;
/* /*
* We form the string literal according to the prevailing setting of * We form the string literal according to the prevailing setting of
* standard_conforming_strings; we never use E''. User is responsible for * standard_conforming_strings; we never use E''. User is responsible for
* making sure result is used correctly. * making sure result is used correctly.
*/ */
appendStringInfoChar(buf, '\''); appendStringInfoChar(buf, '\'');
for (valptr = val; *valptr; valptr++) for (const char *valptr = val; *valptr; valptr++)
{ {
char ch = *valptr; char ch = *valptr;

View File

@ -270,11 +270,9 @@ static void
AppendAlterExtensionSchemaStmt(StringInfo buf, AppendAlterExtensionSchemaStmt(StringInfo buf,
AlterObjectSchemaStmt *alterExtensionSchemaStmt) AlterObjectSchemaStmt *alterExtensionSchemaStmt)
{ {
const char *extensionName = NULL;
Assert(alterExtensionSchemaStmt->objectType == OBJECT_EXTENSION); Assert(alterExtensionSchemaStmt->objectType == OBJECT_EXTENSION);
extensionName = strVal(alterExtensionSchemaStmt->object); const char *extensionName = strVal(alterExtensionSchemaStmt->object);
appendStringInfo(buf, "ALTER EXTENSION %s SET SCHEMA %s;", extensionName, appendStringInfo(buf, "ALTER EXTENSION %s SET SCHEMA %s;", extensionName,
quote_identifier(alterExtensionSchemaStmt->newschema)); quote_identifier(alterExtensionSchemaStmt->newschema));
} }

View File

@ -488,14 +488,13 @@ AppendFunctionNameList(StringInfo buf, List *objects, ObjectType objtype)
foreach(objectCell, objects) foreach(objectCell, objects)
{ {
Node *object = lfirst(objectCell); Node *object = lfirst(objectCell);
ObjectWithArgs *func = NULL;
if (objectCell != list_head(objects)) if (objectCell != list_head(objects))
{ {
appendStringInfo(buf, ", "); appendStringInfo(buf, ", ");
} }
func = castNode(ObjectWithArgs, object); ObjectWithArgs *func = castNode(ObjectWithArgs, object);
AppendFunctionName(buf, func, objtype); AppendFunctionName(buf, func, objtype);
} }
@ -508,14 +507,11 @@ AppendFunctionNameList(StringInfo buf, List *objects, ObjectType objtype)
static void static void
AppendFunctionName(StringInfo buf, ObjectWithArgs *func, ObjectType objtype) AppendFunctionName(StringInfo buf, ObjectWithArgs *func, ObjectType objtype)
{ {
Oid funcid = InvalidOid;
HeapTuple proctup;
char *functionName = NULL; char *functionName = NULL;
char *schemaName = NULL; char *schemaName = NULL;
char *qualifiedFunctionName;
funcid = LookupFuncWithArgs(objtype, func, true); Oid funcid = LookupFuncWithArgs(objtype, func, true);
proctup = SearchSysCache1(PROCOID, ObjectIdGetDatum(funcid)); HeapTuple proctup = SearchSysCache1(PROCOID, ObjectIdGetDatum(funcid));
if (!HeapTupleIsValid(proctup)) if (!HeapTupleIsValid(proctup))
{ {
@ -529,9 +525,7 @@ AppendFunctionName(StringInfo buf, ObjectWithArgs *func, ObjectType objtype)
} }
else else
{ {
Form_pg_proc procform; Form_pg_proc procform = (Form_pg_proc) GETSTRUCT(proctup);
procform = (Form_pg_proc) GETSTRUCT(proctup);
functionName = NameStr(procform->proname); functionName = NameStr(procform->proname);
functionName = pstrdup(functionName); /* we release the tuple before used */ functionName = pstrdup(functionName); /* we release the tuple before used */
schemaName = get_namespace_name(procform->pronamespace); schemaName = get_namespace_name(procform->pronamespace);
@ -539,7 +533,7 @@ AppendFunctionName(StringInfo buf, ObjectWithArgs *func, ObjectType objtype)
ReleaseSysCache(proctup); ReleaseSysCache(proctup);
} }
qualifiedFunctionName = quote_qualified_identifier(schemaName, functionName); char *qualifiedFunctionName = quote_qualified_identifier(schemaName, functionName);
appendStringInfoString(buf, qualifiedFunctionName); appendStringInfoString(buf, qualifiedFunctionName);
if (OidIsValid(funcid)) if (OidIsValid(funcid))
@ -548,28 +542,25 @@ AppendFunctionName(StringInfo buf, ObjectWithArgs *func, ObjectType objtype)
* If the function exists we want to use pg_get_function_identity_arguments to * If the function exists we want to use pg_get_function_identity_arguments to
* serialize its canonical arguments * serialize its canonical arguments
*/ */
OverrideSearchPath *overridePath = NULL;
Datum sqlTextDatum = 0;
const char *args = NULL;
/* /*
* Set search_path to NIL so that all objects outside of pg_catalog will be * Set search_path to NIL so that all objects outside of pg_catalog will be
* schema-prefixed. pg_catalog will be added automatically when we call * schema-prefixed. pg_catalog will be added automatically when we call
* PushOverrideSearchPath(), since we set addCatalog to true; * PushOverrideSearchPath(), since we set addCatalog to true;
*/ */
overridePath = GetOverrideSearchPath(CurrentMemoryContext); OverrideSearchPath *overridePath = GetOverrideSearchPath(CurrentMemoryContext);
overridePath->schemas = NIL; overridePath->schemas = NIL;
overridePath->addCatalog = true; overridePath->addCatalog = true;
PushOverrideSearchPath(overridePath); PushOverrideSearchPath(overridePath);
sqlTextDatum = DirectFunctionCall1(pg_get_function_identity_arguments, Datum sqlTextDatum = DirectFunctionCall1(pg_get_function_identity_arguments,
ObjectIdGetDatum(funcid)); ObjectIdGetDatum(funcid));
/* revert back to original search_path */ /* revert back to original search_path */
PopOverrideSearchPath(); PopOverrideSearchPath();
args = TextDatumGetCString(sqlTextDatum); const char *args = TextDatumGetCString(sqlTextDatum);
appendStringInfo(buf, "(%s)", args); appendStringInfo(buf, "(%s)", args);
} }
else if (!func->args_unspecified) else if (!func->args_unspecified)
@ -580,9 +571,8 @@ AppendFunctionName(StringInfo buf, ObjectWithArgs *func, ObjectType objtype)
* postgres' TypeNameListToString. For now the best we can do until we understand * postgres' TypeNameListToString. For now the best we can do until we understand
* the underlying cause better. * the underlying cause better.
*/ */
const char *args = NULL;
args = TypeNameListToString(func->objargs); const char *args = TypeNameListToString(func->objargs);
appendStringInfo(buf, "(%s)", args); appendStringInfo(buf, "(%s)", args);
} }

View File

@ -137,14 +137,12 @@ AppendAlterTypeStmt(StringInfo buf, AlterTableStmt *stmt)
appendStringInfo(buf, "ALTER TYPE %s", identifier); appendStringInfo(buf, "ALTER TYPE %s", identifier);
foreach(cmdCell, stmt->cmds) foreach(cmdCell, stmt->cmds)
{ {
AlterTableCmd *alterTableCmd = NULL;
if (cmdCell != list_head(stmt->cmds)) if (cmdCell != list_head(stmt->cmds))
{ {
appendStringInfoString(buf, ", "); appendStringInfoString(buf, ", ");
} }
alterTableCmd = castNode(AlterTableCmd, lfirst(cmdCell)); AlterTableCmd *alterTableCmd = castNode(AlterTableCmd, lfirst(cmdCell));
AppendAlterTypeCmd(buf, alterTableCmd); AppendAlterTypeCmd(buf, alterTableCmd);
} }
@ -317,13 +315,11 @@ AppendCompositeTypeStmt(StringInfo str, CompositeTypeStmt *stmt)
static void static void
AppendCreateEnumStmt(StringInfo str, CreateEnumStmt *stmt) AppendCreateEnumStmt(StringInfo str, CreateEnumStmt *stmt)
{ {
RangeVar *typevar = NULL; RangeVar *typevar = makeRangeVarFromNameList(stmt->typeName);
const char *identifier = NULL;
typevar = makeRangeVarFromNameList(stmt->typeName);
/* create the identifier from the fully qualified rangevar */ /* create the identifier from the fully qualified rangevar */
identifier = quote_qualified_identifier(typevar->schemaname, typevar->relname); const char *identifier = quote_qualified_identifier(typevar->schemaname,
typevar->relname);
appendStringInfo(str, "CREATE TYPE %s AS ENUM (", identifier); appendStringInfo(str, "CREATE TYPE %s AS ENUM (", identifier);
AppendStringList(str, stmt->vals); AppendStringList(str, stmt->vals);
@ -472,11 +468,9 @@ DeparseAlterTypeSchemaStmt(AlterObjectSchemaStmt *stmt)
static void static void
AppendAlterTypeSchemaStmt(StringInfo buf, AlterObjectSchemaStmt *stmt) AppendAlterTypeSchemaStmt(StringInfo buf, AlterObjectSchemaStmt *stmt)
{ {
List *names = NIL;
Assert(stmt->objectType == OBJECT_TYPE); Assert(stmt->objectType == OBJECT_TYPE);
names = (List *) stmt->object; List *names = (List *) stmt->object;
appendStringInfo(buf, "ALTER TYPE %s SET SCHEMA %s;", NameListToQuotedString(names), appendStringInfo(buf, "ALTER TYPE %s SET SCHEMA %s;", NameListToQuotedString(names),
quote_identifier(stmt->newschema)); quote_identifier(stmt->newschema));
} }
@ -499,11 +493,9 @@ DeparseAlterTypeOwnerStmt(AlterOwnerStmt *stmt)
static void static void
AppendAlterTypeOwnerStmt(StringInfo buf, AlterOwnerStmt *stmt) AppendAlterTypeOwnerStmt(StringInfo buf, AlterOwnerStmt *stmt)
{ {
List *names = NIL;
Assert(stmt->objectType == OBJECT_TYPE); Assert(stmt->objectType == OBJECT_TYPE);
names = (List *) stmt->object; List *names = (List *) stmt->object;
appendStringInfo(buf, "ALTER TYPE %s OWNER TO %s;", NameListToQuotedString(names), appendStringInfo(buf, "ALTER TYPE %s OWNER TO %s;", NameListToQuotedString(names),
RoleSpecString(stmt->newowner, true)); RoleSpecString(stmt->newowner, true));
} }

View File

@ -60,18 +60,14 @@ FormatCollateBEQualified(Oid collate_oid)
char * char *
FormatCollateExtended(Oid collid, bits16 flags) FormatCollateExtended(Oid collid, bits16 flags)
{ {
HeapTuple tuple = NULL;
Form_pg_collation collform = NULL;
char *buf = NULL;
char *nspname = NULL; char *nspname = NULL;
char *typname = NULL;
if (collid == InvalidOid && (flags & FORMAT_COLLATE_ALLOW_INVALID) != 0) if (collid == InvalidOid && (flags & FORMAT_COLLATE_ALLOW_INVALID) != 0)
{ {
return pstrdup("-"); return pstrdup("-");
} }
tuple = SearchSysCache1(COLLOID, ObjectIdGetDatum(collid)); HeapTuple tuple = SearchSysCache1(COLLOID, ObjectIdGetDatum(collid));
if (!HeapTupleIsValid(tuple)) if (!HeapTupleIsValid(tuple))
{ {
if ((flags & FORMAT_COLLATE_ALLOW_INVALID) != 0) if ((flags & FORMAT_COLLATE_ALLOW_INVALID) != 0)
@ -83,7 +79,7 @@ FormatCollateExtended(Oid collid, bits16 flags)
elog(ERROR, "cache lookup failed for collate %u", collid); elog(ERROR, "cache lookup failed for collate %u", collid);
} }
} }
collform = (Form_pg_collation) GETSTRUCT(tuple); Form_pg_collation collform = (Form_pg_collation) GETSTRUCT(tuple);
if ((flags & FORMAT_COLLATE_FORCE_QUALIFY) == 0 && CollationIsVisible(collid)) if ((flags & FORMAT_COLLATE_FORCE_QUALIFY) == 0 && CollationIsVisible(collid))
{ {
@ -94,9 +90,9 @@ FormatCollateExtended(Oid collid, bits16 flags)
nspname = get_namespace_name_or_temp(collform->collnamespace); nspname = get_namespace_name_or_temp(collform->collnamespace);
} }
typname = NameStr(collform->collname); char *typname = NameStr(collform->collname);
buf = quote_qualified_identifier(nspname, typname); char *buf = quote_qualified_identifier(nspname, typname);
ReleaseSysCache(tuple); ReleaseSysCache(tuple);

View File

@ -143,11 +143,9 @@ QualifyFunctionSchemaName(ObjectWithArgs *func, ObjectType type)
{ {
char *schemaName = NULL; char *schemaName = NULL;
char *functionName = NULL; char *functionName = NULL;
Oid funcid = InvalidOid;
HeapTuple proctup;
funcid = LookupFuncWithArgs(type, func, true); Oid funcid = LookupFuncWithArgs(type, func, true);
proctup = SearchSysCache1(PROCOID, ObjectIdGetDatum(funcid)); HeapTuple proctup = SearchSysCache1(PROCOID, ObjectIdGetDatum(funcid));
/* /*
* We can not qualify the function if the catalogs do not have any records. * We can not qualify the function if the catalogs do not have any records.
@ -156,9 +154,7 @@ QualifyFunctionSchemaName(ObjectWithArgs *func, ObjectType type)
*/ */
if (HeapTupleIsValid(proctup)) if (HeapTupleIsValid(proctup))
{ {
Form_pg_proc procform; Form_pg_proc procform = (Form_pg_proc) GETSTRUCT(proctup);
procform = (Form_pg_proc) GETSTRUCT(proctup);
schemaName = get_namespace_name(procform->pronamespace); schemaName = get_namespace_name(procform->pronamespace);
functionName = NameStr(procform->proname); functionName = NameStr(procform->proname);
functionName = pstrdup(functionName); /* we release the tuple before used */ functionName = pstrdup(functionName); /* we release the tuple before used */

View File

@ -53,17 +53,15 @@ GetTypeNamespaceNameByNameList(List *names)
static Oid static Oid
TypeOidGetNamespaceOid(Oid typeOid) TypeOidGetNamespaceOid(Oid typeOid)
{ {
Form_pg_type typeData = NULL;
HeapTuple typeTuple = SearchSysCache1(TYPEOID, typeOid); HeapTuple typeTuple = SearchSysCache1(TYPEOID, typeOid);
Oid typnamespace = InvalidOid;
if (!HeapTupleIsValid(typeTuple)) if (!HeapTupleIsValid(typeTuple))
{ {
elog(ERROR, "citus cache lookup failed"); elog(ERROR, "citus cache lookup failed");
return InvalidOid; return InvalidOid;
} }
typeData = (Form_pg_type) GETSTRUCT(typeTuple); Form_pg_type typeData = (Form_pg_type) GETSTRUCT(typeTuple);
typnamespace = typeData->typnamespace; Oid typnamespace = typeData->typnamespace;
ReleaseSysCache(typeTuple); ReleaseSysCache(typeTuple);
@ -161,11 +159,9 @@ QualifyCreateEnumStmt(CreateEnumStmt *stmt)
void void
QualifyAlterTypeSchemaStmt(AlterObjectSchemaStmt *stmt) QualifyAlterTypeSchemaStmt(AlterObjectSchemaStmt *stmt)
{ {
List *names = NIL;
Assert(stmt->objectType == OBJECT_TYPE); Assert(stmt->objectType == OBJECT_TYPE);
names = (List *) stmt->object; List *names = (List *) stmt->object;
if (list_length(names) == 1) if (list_length(names) == 1)
{ {
/* not qualified with schema, lookup type and its schema s*/ /* not qualified with schema, lookup type and its schema s*/
@ -179,11 +175,9 @@ QualifyAlterTypeSchemaStmt(AlterObjectSchemaStmt *stmt)
void void
QualifyAlterTypeOwnerStmt(AlterOwnerStmt *stmt) QualifyAlterTypeOwnerStmt(AlterOwnerStmt *stmt)
{ {
List *names = NIL;
Assert(stmt->objectType == OBJECT_TYPE); Assert(stmt->objectType == OBJECT_TYPE);
names = (List *) stmt->object; List *names = (List *) stmt->object;
if (list_length(names) == 1) if (list_length(names) == 1)
{ {
/* not qualified with schema, lookup type and its schema s*/ /* not qualified with schema, lookup type and its schema s*/

View File

@ -612,7 +612,6 @@ AdaptiveExecutor(CitusScanState *scanState)
TupleTableSlot *resultSlot = NULL; TupleTableSlot *resultSlot = NULL;
DistributedPlan *distributedPlan = scanState->distributedPlan; DistributedPlan *distributedPlan = scanState->distributedPlan;
DistributedExecution *execution = NULL;
EState *executorState = ScanStateGetExecutorState(scanState); EState *executorState = ScanStateGetExecutorState(scanState);
ParamListInfo paramListInfo = executorState->es_param_list_info; ParamListInfo paramListInfo = executorState->es_param_list_info;
TupleDesc tupleDescriptor = ScanStateGetTupleDescriptor(scanState); TupleDesc tupleDescriptor = ScanStateGetTupleDescriptor(scanState);
@ -645,10 +644,13 @@ AdaptiveExecutor(CitusScanState *scanState)
scanState->tuplestorestate = scanState->tuplestorestate =
tuplestore_begin_heap(randomAccess, interTransactions, work_mem); tuplestore_begin_heap(randomAccess, interTransactions, work_mem);
execution = CreateDistributedExecution(distributedPlan->modLevel, taskList, DistributedExecution *execution = CreateDistributedExecution(
distributedPlan->hasReturning, paramListInfo, distributedPlan->modLevel, taskList,
distributedPlan->
hasReturning, paramListInfo,
tupleDescriptor, tupleDescriptor,
scanState->tuplestorestate, targetPoolSize); scanState->
tuplestorestate, targetPoolSize);
/* /*
* Make sure that we acquire the appropriate locks even if the local tasks * Make sure that we acquire the appropriate locks even if the local tasks
@ -715,7 +717,6 @@ static void
RunLocalExecution(CitusScanState *scanState, DistributedExecution *execution) RunLocalExecution(CitusScanState *scanState, DistributedExecution *execution)
{ {
uint64 rowsProcessed = ExecuteLocalTaskList(scanState, execution->localTaskList); uint64 rowsProcessed = ExecuteLocalTaskList(scanState, execution->localTaskList);
EState *executorState = NULL;
LocalExecutionHappened = true; LocalExecutionHappened = true;
@ -725,7 +726,7 @@ RunLocalExecution(CitusScanState *scanState, DistributedExecution *execution)
* and in AdaptiveExecutor. Instead, we set executorState here and skip updating it * and in AdaptiveExecutor. Instead, we set executorState here and skip updating it
* for reference table modifications in AdaptiveExecutor. * for reference table modifications in AdaptiveExecutor.
*/ */
executorState = ScanStateGetExecutorState(scanState); EState *executorState = ScanStateGetExecutorState(scanState);
executorState->es_processed = rowsProcessed; executorState->es_processed = rowsProcessed;
} }
@ -782,7 +783,6 @@ ExecuteTaskListExtended(RowModifyLevel modLevel, List *taskList,
TupleDesc tupleDescriptor, Tuplestorestate *tupleStore, TupleDesc tupleDescriptor, Tuplestorestate *tupleStore,
bool hasReturning, int targetPoolSize) bool hasReturning, int targetPoolSize)
{ {
DistributedExecution *execution = NULL;
ParamListInfo paramListInfo = NULL; ParamListInfo paramListInfo = NULL;
/* /*
@ -796,7 +796,7 @@ ExecuteTaskListExtended(RowModifyLevel modLevel, List *taskList,
targetPoolSize = 1; targetPoolSize = 1;
} }
execution = DistributedExecution *execution =
CreateDistributedExecution(modLevel, taskList, hasReturning, paramListInfo, CreateDistributedExecution(modLevel, taskList, hasReturning, paramListInfo,
tupleDescriptor, tupleStore, targetPoolSize); tupleDescriptor, tupleStore, targetPoolSize);
@ -993,8 +993,6 @@ DistributedPlanModifiesDatabase(DistributedPlan *plan)
static bool static bool
TaskListModifiesDatabase(RowModifyLevel modLevel, List *taskList) TaskListModifiesDatabase(RowModifyLevel modLevel, List *taskList)
{ {
Task *firstTask = NULL;
if (modLevel > ROW_MODIFY_READONLY) if (modLevel > ROW_MODIFY_READONLY)
{ {
return true; return true;
@ -1010,7 +1008,7 @@ TaskListModifiesDatabase(RowModifyLevel modLevel, List *taskList)
return false; return false;
} }
firstTask = (Task *) linitial(taskList); Task *firstTask = (Task *) linitial(taskList);
return !ReadOnlyTask(firstTask->taskType); return !ReadOnlyTask(firstTask->taskType);
} }
@ -1027,8 +1025,6 @@ DistributedExecutionRequiresRollback(DistributedExecution *execution)
{ {
List *taskList = execution->tasksToExecute; List *taskList = execution->tasksToExecute;
int taskCount = list_length(taskList); int taskCount = list_length(taskList);
Task *task = NULL;
bool selectForUpdate = false;
if (MultiShardCommitProtocol == COMMIT_PROTOCOL_BARE) if (MultiShardCommitProtocol == COMMIT_PROTOCOL_BARE)
{ {
@ -1040,9 +1036,9 @@ DistributedExecutionRequiresRollback(DistributedExecution *execution)
return false; return false;
} }
task = (Task *) linitial(taskList); Task *task = (Task *) linitial(taskList);
selectForUpdate = task->relationRowLockList != NIL; bool selectForUpdate = task->relationRowLockList != NIL;
if (selectForUpdate) if (selectForUpdate)
{ {
/* /*
@ -1114,16 +1110,12 @@ DistributedExecutionRequiresRollback(DistributedExecution *execution)
static bool static bool
TaskListRequires2PC(List *taskList) TaskListRequires2PC(List *taskList)
{ {
Task *task = NULL;
bool multipleTasks = false;
uint64 anchorShardId = INVALID_SHARD_ID;
if (taskList == NIL) if (taskList == NIL)
{ {
return false; return false;
} }
task = (Task *) linitial(taskList); Task *task = (Task *) linitial(taskList);
if (task->replicationModel == REPLICATION_MODEL_2PC) if (task->replicationModel == REPLICATION_MODEL_2PC)
{ {
return true; return true;
@ -1136,13 +1128,13 @@ TaskListRequires2PC(List *taskList)
* TODO: Do we ever need replicationModel in the Task structure? * TODO: Do we ever need replicationModel in the Task structure?
* Can't we always rely on anchorShardId? * Can't we always rely on anchorShardId?
*/ */
anchorShardId = task->anchorShardId; uint64 anchorShardId = task->anchorShardId;
if (anchorShardId != INVALID_SHARD_ID && ReferenceTableShardId(anchorShardId)) if (anchorShardId != INVALID_SHARD_ID && ReferenceTableShardId(anchorShardId))
{ {
return true; return true;
} }
multipleTasks = list_length(taskList) > 1; bool multipleTasks = list_length(taskList) > 1;
if (!ReadOnlyTask(task->taskType) && if (!ReadOnlyTask(task->taskType) &&
multipleTasks && MultiShardCommitProtocol == COMMIT_PROTOCOL_2PC) multipleTasks && MultiShardCommitProtocol == COMMIT_PROTOCOL_2PC)
{ {
@ -1190,7 +1182,6 @@ ReadOnlyTask(TaskType taskType)
static bool static bool
SelectForUpdateOnReferenceTable(RowModifyLevel modLevel, List *taskList) SelectForUpdateOnReferenceTable(RowModifyLevel modLevel, List *taskList)
{ {
Task *task = NULL;
ListCell *rtiLockCell = NULL; ListCell *rtiLockCell = NULL;
if (modLevel != ROW_MODIFY_READONLY) if (modLevel != ROW_MODIFY_READONLY)
@ -1204,7 +1195,7 @@ SelectForUpdateOnReferenceTable(RowModifyLevel modLevel, List *taskList)
return false; return false;
} }
task = (Task *) linitial(taskList); Task *task = (Task *) linitial(taskList);
foreach(rtiLockCell, task->relationRowLockList) foreach(rtiLockCell, task->relationRowLockList)
{ {
RelationRowLock *relationRowLock = (RelationRowLock *) lfirst(rtiLockCell); RelationRowLock *relationRowLock = (RelationRowLock *) lfirst(rtiLockCell);
@ -1441,7 +1432,6 @@ AssignTasksToConnections(DistributedExecution *execution)
foreach(taskCell, taskList) foreach(taskCell, taskList)
{ {
Task *task = (Task *) lfirst(taskCell); Task *task = (Task *) lfirst(taskCell);
ShardCommandExecution *shardCommandExecution = NULL;
ListCell *taskPlacementCell = NULL; ListCell *taskPlacementCell = NULL;
bool placementExecutionReady = true; bool placementExecutionReady = true;
int placementExecutionIndex = 0; int placementExecutionIndex = 0;
@ -1450,7 +1440,7 @@ AssignTasksToConnections(DistributedExecution *execution)
/* /*
* Execution of a command on a shard, which may have multiple replicas. * Execution of a command on a shard, which may have multiple replicas.
*/ */
shardCommandExecution = ShardCommandExecution *shardCommandExecution =
(ShardCommandExecution *) palloc0(sizeof(ShardCommandExecution)); (ShardCommandExecution *) palloc0(sizeof(ShardCommandExecution));
shardCommandExecution->task = task; shardCommandExecution->task = task;
shardCommandExecution->executionOrder = ExecutionOrderForTask(modLevel, task); shardCommandExecution->executionOrder = ExecutionOrderForTask(modLevel, task);
@ -1467,10 +1457,7 @@ AssignTasksToConnections(DistributedExecution *execution)
foreach(taskPlacementCell, task->taskPlacementList) foreach(taskPlacementCell, task->taskPlacementList)
{ {
ShardPlacement *taskPlacement = (ShardPlacement *) lfirst(taskPlacementCell); ShardPlacement *taskPlacement = (ShardPlacement *) lfirst(taskPlacementCell);
List *placementAccessList = NULL;
MultiConnection *connection = NULL;
int connectionFlags = 0; int connectionFlags = 0;
TaskPlacementExecution *placementExecution = NULL;
char *nodeName = taskPlacement->nodeName; char *nodeName = taskPlacement->nodeName;
int nodePort = taskPlacement->nodePort; int nodePort = taskPlacement->nodePort;
WorkerPool *workerPool = FindOrCreateWorkerPool(execution, nodeName, WorkerPool *workerPool = FindOrCreateWorkerPool(execution, nodeName,
@ -1480,7 +1467,7 @@ AssignTasksToConnections(DistributedExecution *execution)
* Execution of a command on a shard placement, which may not always * Execution of a command on a shard placement, which may not always
* happen if the query is read-only and the shard has multiple placements. * happen if the query is read-only and the shard has multiple placements.
*/ */
placementExecution = TaskPlacementExecution *placementExecution =
(TaskPlacementExecution *) palloc0(sizeof(TaskPlacementExecution)); (TaskPlacementExecution *) palloc0(sizeof(TaskPlacementExecution));
placementExecution->shardCommandExecution = shardCommandExecution; placementExecution->shardCommandExecution = shardCommandExecution;
placementExecution->shardPlacement = taskPlacement; placementExecution->shardPlacement = taskPlacement;
@ -1501,13 +1488,14 @@ AssignTasksToConnections(DistributedExecution *execution)
placementExecutionIndex++; placementExecutionIndex++;
placementAccessList = PlacementAccessListForTask(task, taskPlacement); List *placementAccessList = PlacementAccessListForTask(task, taskPlacement);
/* /*
* Determine whether the task has to be assigned to a particular connection * Determine whether the task has to be assigned to a particular connection
* due to a preceding access to the placement in the same transaction. * due to a preceding access to the placement in the same transaction.
*/ */
connection = GetConnectionIfPlacementAccessedInXact(connectionFlags, MultiConnection *connection = GetConnectionIfPlacementAccessedInXact(
connectionFlags,
placementAccessList, placementAccessList,
NULL); NULL);
if (connection != NULL) if (connection != NULL)
@ -1670,7 +1658,6 @@ FindOrCreateWorkerPool(DistributedExecution *execution, char *nodeName, int node
{ {
WorkerPool *workerPool = NULL; WorkerPool *workerPool = NULL;
ListCell *workerCell = NULL; ListCell *workerCell = NULL;
int nodeConnectionCount = 0;
foreach(workerCell, execution->workerList) foreach(workerCell, execution->workerList)
{ {
@ -1690,7 +1677,7 @@ FindOrCreateWorkerPool(DistributedExecution *execution, char *nodeName, int node
workerPool->distributedExecution = execution; workerPool->distributedExecution = execution;
/* "open" connections aggressively when there are cached connections */ /* "open" connections aggressively when there are cached connections */
nodeConnectionCount = MaxCachedConnectionsPerWorker; int nodeConnectionCount = MaxCachedConnectionsPerWorker;
workerPool->maxNewConnectionsPerCycle = Max(1, nodeConnectionCount); workerPool->maxNewConnectionsPerCycle = Max(1, nodeConnectionCount);
dlist_init(&workerPool->pendingTaskQueue); dlist_init(&workerPool->pendingTaskQueue);
@ -1775,8 +1762,6 @@ FindOrCreateWorkerSession(WorkerPool *workerPool, MultiConnection *connection)
static bool static bool
ShouldRunTasksSequentially(List *taskList) ShouldRunTasksSequentially(List *taskList)
{ {
Task *initialTask = NULL;
if (list_length(taskList) < 2) if (list_length(taskList) < 2)
{ {
/* single task plans are already qualified as sequential by definition */ /* single task plans are already qualified as sequential by definition */
@ -1784,7 +1769,7 @@ ShouldRunTasksSequentially(List *taskList)
} }
/* all the tasks are the same, so we only look one */ /* all the tasks are the same, so we only look one */
initialTask = (Task *) linitial(taskList); Task *initialTask = (Task *) linitial(taskList);
if (initialTask->rowValuesLists != NIL) if (initialTask->rowValuesLists != NIL)
{ {
/* found a multi-row INSERT */ /* found a multi-row INSERT */
@ -1860,7 +1845,6 @@ RunDistributedExecution(DistributedExecution *execution)
while (execution->unfinishedTaskCount > 0 && !cancellationReceived) while (execution->unfinishedTaskCount > 0 && !cancellationReceived)
{ {
int eventCount = 0;
int eventIndex = 0; int eventIndex = 0;
ListCell *workerCell = NULL; ListCell *workerCell = NULL;
long timeout = NextEventTimeout(execution); long timeout = NextEventTimeout(execution);
@ -1906,14 +1890,13 @@ RunDistributedExecution(DistributedExecution *execution)
} }
/* wait for I/O events */ /* wait for I/O events */
eventCount = WaitEventSetWait(execution->waitEventSet, timeout, events, int eventCount = WaitEventSetWait(execution->waitEventSet, timeout, events,
eventSetSize, WAIT_EVENT_CLIENT_READ); eventSetSize, WAIT_EVENT_CLIENT_READ);
/* process I/O events */ /* process I/O events */
for (; eventIndex < eventCount; eventIndex++) for (; eventIndex < eventCount; eventIndex++)
{ {
WaitEvent *event = &events[eventIndex]; WaitEvent *event = &events[eventIndex];
WorkerSession *session = NULL;
if (event->events & WL_POSTMASTER_DEATH) if (event->events & WL_POSTMASTER_DEATH)
{ {
@ -1944,7 +1927,7 @@ RunDistributedExecution(DistributedExecution *execution)
continue; continue;
} }
session = (WorkerSession *) event->user_data; WorkerSession *session = (WorkerSession *) event->user_data;
session->latestUnconsumedWaitEvents = event->events; session->latestUnconsumedWaitEvents = event->events;
ConnectionStateMachine(session); ConnectionStateMachine(session);
@ -2001,7 +1984,6 @@ ManageWorkerPool(WorkerPool *workerPool)
int failedConnectionCount = workerPool->failedConnectionCount; int failedConnectionCount = workerPool->failedConnectionCount;
int readyTaskCount = workerPool->readyTaskCount; int readyTaskCount = workerPool->readyTaskCount;
int newConnectionCount = 0; int newConnectionCount = 0;
int connectionIndex = 0;
/* we should always have more (or equal) active connections than idle connections */ /* we should always have more (or equal) active connections than idle connections */
Assert(activeConnectionCount >= idleConnectionCount); Assert(activeConnectionCount >= idleConnectionCount);
@ -2091,16 +2073,13 @@ ManageWorkerPool(WorkerPool *workerPool)
ereport(DEBUG4, (errmsg("opening %d new connections to %s:%d", newConnectionCount, ereport(DEBUG4, (errmsg("opening %d new connections to %s:%d", newConnectionCount,
workerPool->nodeName, workerPool->nodePort))); workerPool->nodeName, workerPool->nodePort)));
for (connectionIndex = 0; connectionIndex < newConnectionCount; connectionIndex++) for (int connectionIndex = 0; connectionIndex < newConnectionCount; connectionIndex++)
{ {
MultiConnection *connection = NULL;
WorkerSession *session = NULL;
/* experimental: just to see the perf benefits of caching connections */ /* experimental: just to see the perf benefits of caching connections */
int connectionFlags = 0; int connectionFlags = 0;
/* open a new connection to the worker */ /* open a new connection to the worker */
connection = StartNodeUserDatabaseConnection(connectionFlags, MultiConnection *connection = StartNodeUserDatabaseConnection(connectionFlags,
workerPool->nodeName, workerPool->nodeName,
workerPool->nodePort, workerPool->nodePort,
NULL, NULL); NULL, NULL);
@ -2119,7 +2098,7 @@ ManageWorkerPool(WorkerPool *workerPool)
connection->claimedExclusively = true; connection->claimedExclusively = true;
/* create a session for the connection */ /* create a session for the connection */
session = FindOrCreateWorkerSession(workerPool, connection); WorkerSession *session = FindOrCreateWorkerSession(workerPool, connection);
/* always poll the connection in the first round */ /* always poll the connection in the first round */
UpdateConnectionWaitFlags(session, WL_SOCKET_READABLE | WL_SOCKET_WRITEABLE); UpdateConnectionWaitFlags(session, WL_SOCKET_READABLE | WL_SOCKET_WRITEABLE);
@ -2250,7 +2229,6 @@ NextEventTimeout(DistributedExecution *execution)
foreach(workerCell, execution->workerList) foreach(workerCell, execution->workerList)
{ {
WorkerPool *workerPool = (WorkerPool *) lfirst(workerCell); WorkerPool *workerPool = (WorkerPool *) lfirst(workerCell);
int initiatedConnectionCount = 0;
if (workerPool->failed) if (workerPool->failed)
{ {
@ -2278,7 +2256,7 @@ NextEventTimeout(DistributedExecution *execution)
} }
} }
initiatedConnectionCount = list_length(workerPool->sessionList); int initiatedConnectionCount = list_length(workerPool->sessionList);
/* /*
* If there are connections to open we wait at most up to the end of the * If there are connections to open we wait at most up to the end of the
@ -2347,8 +2325,6 @@ ConnectionStateMachine(WorkerSession *session)
case MULTI_CONNECTION_CONNECTING: case MULTI_CONNECTION_CONNECTING:
{ {
PostgresPollingStatusType pollMode;
ConnStatusType status = PQstatus(connection->pgConn); ConnStatusType status = PQstatus(connection->pgConn);
if (status == CONNECTION_OK) if (status == CONNECTION_OK)
{ {
@ -2372,7 +2348,7 @@ ConnectionStateMachine(WorkerSession *session)
break; break;
} }
pollMode = PQconnectPoll(connection->pgConn); PostgresPollingStatusType pollMode = PQconnectPoll(connection->pgConn);
if (pollMode == PGRES_POLLING_FAILED) if (pollMode == PGRES_POLLING_FAILED)
{ {
connection->connectionState = MULTI_CONNECTION_FAILED; connection->connectionState = MULTI_CONNECTION_FAILED;
@ -2543,15 +2519,13 @@ ConnectionStateMachine(WorkerSession *session)
static void static void
Activate2PCIfModifyingTransactionExpandsToNewNode(WorkerSession *session) Activate2PCIfModifyingTransactionExpandsToNewNode(WorkerSession *session)
{ {
DistributedExecution *execution = NULL;
if (MultiShardCommitProtocol != COMMIT_PROTOCOL_2PC) if (MultiShardCommitProtocol != COMMIT_PROTOCOL_2PC)
{ {
/* we don't need 2PC, so no need to continue */ /* we don't need 2PC, so no need to continue */
return; return;
} }
execution = session->workerPool->distributedExecution; DistributedExecution *execution = session->workerPool->distributedExecution;
if (TransactionModifiedDistributedTable(execution) && if (TransactionModifiedDistributedTable(execution) &&
DistributedExecutionModifiesDatabase(execution) && DistributedExecutionModifiesDatabase(execution) &&
!ConnectionModifiedPlacement(session->connection)) !ConnectionModifiedPlacement(session->connection))
@ -2622,10 +2596,8 @@ TransactionStateMachine(WorkerSession *session)
} }
else else
{ {
TaskPlacementExecution *placementExecution = NULL; TaskPlacementExecution *placementExecution = PopPlacementExecution(
bool placementExecutionStarted = false; session);
placementExecution = PopPlacementExecution(session);
if (placementExecution == NULL) if (placementExecution == NULL)
{ {
/* /*
@ -2637,7 +2609,7 @@ TransactionStateMachine(WorkerSession *session)
break; break;
} }
placementExecutionStarted = bool placementExecutionStarted =
StartPlacementExecutionOnSession(placementExecution, session); StartPlacementExecutionOnSession(placementExecution, session);
if (!placementExecutionStarted) if (!placementExecutionStarted)
{ {
@ -2659,9 +2631,7 @@ TransactionStateMachine(WorkerSession *session)
case REMOTE_TRANS_SENT_BEGIN: case REMOTE_TRANS_SENT_BEGIN:
case REMOTE_TRANS_CLEARING_RESULTS: case REMOTE_TRANS_CLEARING_RESULTS:
{ {
PGresult *result = NULL; PGresult *result = PQgetResult(connection->pgConn);
result = PQgetResult(connection->pgConn);
if (result != NULL) if (result != NULL)
{ {
if (!IsResponseOK(result)) if (!IsResponseOK(result))
@ -2715,10 +2685,8 @@ TransactionStateMachine(WorkerSession *session)
case REMOTE_TRANS_STARTED: case REMOTE_TRANS_STARTED:
{ {
TaskPlacementExecution *placementExecution = NULL; TaskPlacementExecution *placementExecution = PopPlacementExecution(
bool placementExecutionStarted = false; session);
placementExecution = PopPlacementExecution(session);
if (placementExecution == NULL) if (placementExecution == NULL)
{ {
/* no tasks are ready to be executed at the moment */ /* no tasks are ready to be executed at the moment */
@ -2726,7 +2694,7 @@ TransactionStateMachine(WorkerSession *session)
break; break;
} }
placementExecutionStarted = bool placementExecutionStarted =
StartPlacementExecutionOnSession(placementExecution, session); StartPlacementExecutionOnSession(placementExecution, session);
if (!placementExecutionStarted) if (!placementExecutionStarted)
{ {
@ -2742,7 +2710,6 @@ TransactionStateMachine(WorkerSession *session)
case REMOTE_TRANS_SENT_COMMAND: case REMOTE_TRANS_SENT_COMMAND:
{ {
bool fetchDone = false;
TaskPlacementExecution *placementExecution = session->currentTask; TaskPlacementExecution *placementExecution = session->currentTask;
ShardCommandExecution *shardCommandExecution = ShardCommandExecution *shardCommandExecution =
placementExecution->shardCommandExecution; placementExecution->shardCommandExecution;
@ -2754,7 +2721,7 @@ TransactionStateMachine(WorkerSession *session)
storeRows = false; storeRows = false;
} }
fetchDone = ReceiveResults(session, storeRows); bool fetchDone = ReceiveResults(session, storeRows);
if (!fetchDone) if (!fetchDone)
{ {
break; break;
@ -2810,7 +2777,6 @@ UpdateConnectionWaitFlags(WorkerSession *session, int waitFlags)
static bool static bool
CheckConnectionReady(WorkerSession *session) CheckConnectionReady(WorkerSession *session)
{ {
int sendStatus = 0;
MultiConnection *connection = session->connection; MultiConnection *connection = session->connection;
int waitFlags = WL_SOCKET_READABLE; int waitFlags = WL_SOCKET_READABLE;
bool connectionReady = false; bool connectionReady = false;
@ -2823,7 +2789,7 @@ CheckConnectionReady(WorkerSession *session)
} }
/* try to send all pending data */ /* try to send all pending data */
sendStatus = PQflush(connection->pgConn); int sendStatus = PQflush(connection->pgConn);
if (sendStatus == -1) if (sendStatus == -1)
{ {
connection->connectionState = MULTI_CONNECTION_LOST; connection->connectionState = MULTI_CONNECTION_LOST;
@ -2865,10 +2831,9 @@ CheckConnectionReady(WorkerSession *session)
static TaskPlacementExecution * static TaskPlacementExecution *
PopPlacementExecution(WorkerSession *session) PopPlacementExecution(WorkerSession *session)
{ {
TaskPlacementExecution *placementExecution = NULL;
WorkerPool *workerPool = session->workerPool; WorkerPool *workerPool = session->workerPool;
placementExecution = PopAssignedPlacementExecution(session); TaskPlacementExecution *placementExecution = PopAssignedPlacementExecution(session);
if (placementExecution == NULL) if (placementExecution == NULL)
{ {
if (session->commandsSent > 0 && UseConnectionPerPlacement()) if (session->commandsSent > 0 && UseConnectionPerPlacement())
@ -2894,7 +2859,6 @@ PopPlacementExecution(WorkerSession *session)
static TaskPlacementExecution * static TaskPlacementExecution *
PopAssignedPlacementExecution(WorkerSession *session) PopAssignedPlacementExecution(WorkerSession *session)
{ {
TaskPlacementExecution *placementExecution = NULL;
dlist_head *readyTaskQueue = &(session->readyTaskQueue); dlist_head *readyTaskQueue = &(session->readyTaskQueue);
if (dlist_is_empty(readyTaskQueue)) if (dlist_is_empty(readyTaskQueue))
@ -2902,9 +2866,10 @@ PopAssignedPlacementExecution(WorkerSession *session)
return NULL; return NULL;
} }
placementExecution = dlist_container(TaskPlacementExecution, TaskPlacementExecution *placementExecution = dlist_container(TaskPlacementExecution,
sessionReadyQueueNode, sessionReadyQueueNode,
dlist_pop_head_node(readyTaskQueue)); dlist_pop_head_node(
readyTaskQueue));
return placementExecution; return placementExecution;
} }
@ -2916,7 +2881,6 @@ PopAssignedPlacementExecution(WorkerSession *session)
static TaskPlacementExecution * static TaskPlacementExecution *
PopUnassignedPlacementExecution(WorkerPool *workerPool) PopUnassignedPlacementExecution(WorkerPool *workerPool)
{ {
TaskPlacementExecution *placementExecution = NULL;
dlist_head *readyTaskQueue = &(workerPool->readyTaskQueue); dlist_head *readyTaskQueue = &(workerPool->readyTaskQueue);
if (dlist_is_empty(readyTaskQueue)) if (dlist_is_empty(readyTaskQueue))
@ -2924,9 +2888,10 @@ PopUnassignedPlacementExecution(WorkerPool *workerPool)
return NULL; return NULL;
} }
placementExecution = dlist_container(TaskPlacementExecution, TaskPlacementExecution *placementExecution = dlist_container(TaskPlacementExecution,
workerReadyQueueNode, workerReadyQueueNode,
dlist_pop_head_node(readyTaskQueue)); dlist_pop_head_node(
readyTaskQueue));
workerPool->readyTaskCount--; workerPool->readyTaskCount--;
@ -2960,7 +2925,6 @@ StartPlacementExecutionOnSession(TaskPlacementExecution *placementExecution,
List *placementAccessList = PlacementAccessListForTask(task, taskPlacement); List *placementAccessList = PlacementAccessListForTask(task, taskPlacement);
char *queryString = task->queryString; char *queryString = task->queryString;
int querySent = 0; int querySent = 0;
int singleRowMode = 0;
/* /*
* Make sure that subsequent commands on the same placement * Make sure that subsequent commands on the same placement
@ -3007,7 +2971,7 @@ StartPlacementExecutionOnSession(TaskPlacementExecution *placementExecution,
return false; return false;
} }
singleRowMode = PQsetSingleRowMode(connection->pgConn); int singleRowMode = PQsetSingleRowMode(connection->pgConn);
if (singleRowMode == 0) if (singleRowMode == 0)
{ {
connection->connectionState = MULTI_CONNECTION_LOST; connection->connectionState = MULTI_CONNECTION_LOST;
@ -3036,7 +3000,6 @@ ReceiveResults(WorkerSession *session, bool storeRows)
uint32 expectedColumnCount = 0; uint32 expectedColumnCount = 0;
char **columnArray = execution->columnArray; char **columnArray = execution->columnArray;
Tuplestorestate *tupleStore = execution->tupleStore; Tuplestorestate *tupleStore = execution->tupleStore;
MemoryContext ioContext = NULL;
if (tupleDescriptor != NULL) if (tupleDescriptor != NULL)
{ {
@ -3048,7 +3011,7 @@ ReceiveResults(WorkerSession *session, bool storeRows)
* into tuple. The context is reseted on every row, thus we create it at the * into tuple. The context is reseted on every row, thus we create it at the
* start of the loop and reset on every iteration. * start of the loop and reset on every iteration.
*/ */
ioContext = AllocSetContextCreate(CurrentMemoryContext, MemoryContext ioContext = AllocSetContextCreate(CurrentMemoryContext,
"IoContext", "IoContext",
ALLOCSET_DEFAULT_MINSIZE, ALLOCSET_DEFAULT_MINSIZE,
ALLOCSET_DEFAULT_INITSIZE, ALLOCSET_DEFAULT_INITSIZE,
@ -3056,11 +3019,8 @@ ReceiveResults(WorkerSession *session, bool storeRows)
while (!PQisBusy(connection->pgConn)) while (!PQisBusy(connection->pgConn))
{ {
uint32 rowIndex = 0;
uint32 columnIndex = 0; uint32 columnIndex = 0;
uint32 rowsProcessed = 0; uint32 rowsProcessed = 0;
uint32 columnCount = 0;
ExecStatusType resultStatus = 0;
PGresult *result = PQgetResult(connection->pgConn); PGresult *result = PQgetResult(connection->pgConn);
if (result == NULL) if (result == NULL)
@ -3070,7 +3030,7 @@ ReceiveResults(WorkerSession *session, bool storeRows)
break; break;
} }
resultStatus = PQresultStatus(result); ExecStatusType resultStatus = PQresultStatus(result);
if (resultStatus == PGRES_COMMAND_OK) if (resultStatus == PGRES_COMMAND_OK)
{ {
char *currentAffectedTupleString = PQcmdTuples(result); char *currentAffectedTupleString = PQcmdTuples(result);
@ -3121,7 +3081,7 @@ ReceiveResults(WorkerSession *session, bool storeRows)
} }
rowsProcessed = PQntuples(result); rowsProcessed = PQntuples(result);
columnCount = PQnfields(result); uint32 columnCount = PQnfields(result);
if (columnCount != expectedColumnCount) if (columnCount != expectedColumnCount)
{ {
@ -3130,10 +3090,8 @@ ReceiveResults(WorkerSession *session, bool storeRows)
columnCount, expectedColumnCount))); columnCount, expectedColumnCount)));
} }
for (rowIndex = 0; rowIndex < rowsProcessed; rowIndex++) for (uint32 rowIndex = 0; rowIndex < rowsProcessed; rowIndex++)
{ {
HeapTuple heapTuple = NULL;
MemoryContext oldContextPerRow = NULL;
memset(columnArray, 0, columnCount * sizeof(char *)); memset(columnArray, 0, columnCount * sizeof(char *));
for (columnIndex = 0; columnIndex < columnCount; columnIndex++) for (columnIndex = 0; columnIndex < columnCount; columnIndex++)
@ -3159,9 +3117,10 @@ ReceiveResults(WorkerSession *session, bool storeRows)
* protects us from any memory leaks that might be present in I/O functions * protects us from any memory leaks that might be present in I/O functions
* called by BuildTupleFromCStrings. * called by BuildTupleFromCStrings.
*/ */
oldContextPerRow = MemoryContextSwitchTo(ioContext); MemoryContext oldContextPerRow = MemoryContextSwitchTo(ioContext);
heapTuple = BuildTupleFromCStrings(attributeInputMetadata, columnArray); HeapTuple heapTuple = BuildTupleFromCStrings(attributeInputMetadata,
columnArray);
MemoryContextSwitchTo(oldContextPerRow); MemoryContextSwitchTo(oldContextPerRow);
@ -3309,7 +3268,6 @@ PlacementExecutionDone(TaskPlacementExecution *placementExecution, bool succeede
ShardCommandExecution *shardCommandExecution = ShardCommandExecution *shardCommandExecution =
placementExecution->shardCommandExecution; placementExecution->shardCommandExecution;
TaskExecutionState executionState = shardCommandExecution->executionState; TaskExecutionState executionState = shardCommandExecution->executionState;
TaskExecutionState newExecutionState = TASK_EXECUTION_NOT_FINISHED;
bool failedPlacementExecutionIsOnPendingQueue = false; bool failedPlacementExecutionIsOnPendingQueue = false;
/* mark the placement execution as finished */ /* mark the placement execution as finished */
@ -3360,7 +3318,8 @@ PlacementExecutionDone(TaskPlacementExecution *placementExecution, bool succeede
* Update unfinishedTaskCount only when state changes from not finished to * Update unfinishedTaskCount only when state changes from not finished to
* finished or failed state. * finished or failed state.
*/ */
newExecutionState = TaskExecutionStateMachine(shardCommandExecution); TaskExecutionState newExecutionState = TaskExecutionStateMachine(
shardCommandExecution);
if (newExecutionState == TASK_EXECUTION_FINISHED) if (newExecutionState == TASK_EXECUTION_FINISHED)
{ {
execution->unfinishedTaskCount--; execution->unfinishedTaskCount--;
@ -3597,21 +3556,18 @@ TaskExecutionStateMachine(ShardCommandExecution *shardCommandExecution)
static WaitEventSet * static WaitEventSet *
BuildWaitEventSet(List *sessionList) BuildWaitEventSet(List *sessionList)
{ {
WaitEventSet *waitEventSet = NULL;
ListCell *sessionCell = NULL; ListCell *sessionCell = NULL;
/* additional 2 is for postmaster and latch */ /* additional 2 is for postmaster and latch */
int eventSetSize = list_length(sessionList) + 2; int eventSetSize = list_length(sessionList) + 2;
waitEventSet = WaitEventSet *waitEventSet =
CreateWaitEventSet(CurrentMemoryContext, eventSetSize); CreateWaitEventSet(CurrentMemoryContext, eventSetSize);
foreach(sessionCell, sessionList) foreach(sessionCell, sessionList)
{ {
WorkerSession *session = lfirst(sessionCell); WorkerSession *session = lfirst(sessionCell);
MultiConnection *connection = session->connection; MultiConnection *connection = session->connection;
int sock = 0;
int waitEventSetIndex = 0;
if (connection->pgConn == NULL) if (connection->pgConn == NULL)
{ {
@ -3625,14 +3581,15 @@ BuildWaitEventSet(List *sessionList)
continue; continue;
} }
sock = PQsocket(connection->pgConn); int sock = PQsocket(connection->pgConn);
if (sock == -1) if (sock == -1)
{ {
/* connection was closed */ /* connection was closed */
continue; continue;
} }
waitEventSetIndex = AddWaitEventToSet(waitEventSet, connection->waitFlags, sock, int waitEventSetIndex = AddWaitEventToSet(waitEventSet, connection->waitFlags,
sock,
NULL, (void *) session); NULL, (void *) session);
session->waitEventSetIndex = waitEventSetIndex; session->waitEventSetIndex = waitEventSetIndex;
} }
@ -3657,7 +3614,6 @@ UpdateWaitEventSetFlags(WaitEventSet *waitEventSet, List *sessionList)
{ {
WorkerSession *session = lfirst(sessionCell); WorkerSession *session = lfirst(sessionCell);
MultiConnection *connection = session->connection; MultiConnection *connection = session->connection;
int sock = 0;
int waitEventSetIndex = session->waitEventSetIndex; int waitEventSetIndex = session->waitEventSetIndex;
if (connection->pgConn == NULL) if (connection->pgConn == NULL)
@ -3672,7 +3628,7 @@ UpdateWaitEventSetFlags(WaitEventSet *waitEventSet, List *sessionList)
continue; continue;
} }
sock = PQsocket(connection->pgConn); int sock = PQsocket(connection->pgConn);
if (sock == -1) if (sock == -1)
{ {
/* connection was closed */ /* connection was closed */
@ -3724,14 +3680,13 @@ ExtractParametersFromParamList(ParamListInfo paramListInfo,
const char ***parameterValues, bool const char ***parameterValues, bool
useOriginalCustomTypeOids) useOriginalCustomTypeOids)
{ {
int parameterIndex = 0;
int parameterCount = paramListInfo->numParams; int parameterCount = paramListInfo->numParams;
*parameterTypes = (Oid *) palloc0(parameterCount * sizeof(Oid)); *parameterTypes = (Oid *) palloc0(parameterCount * sizeof(Oid));
*parameterValues = (const char **) palloc0(parameterCount * sizeof(char *)); *parameterValues = (const char **) palloc0(parameterCount * sizeof(char *));
/* get parameter types and values */ /* get parameter types and values */
for (parameterIndex = 0; parameterIndex < parameterCount; parameterIndex++) for (int parameterIndex = 0; parameterIndex < parameterCount; parameterIndex++)
{ {
ParamExternData *parameterData = &paramListInfo->params[parameterIndex]; ParamExternData *parameterData = &paramListInfo->params[parameterIndex];
Oid typeOutputFunctionId = InvalidOid; Oid typeOutputFunctionId = InvalidOid;

View File

@ -119,12 +119,11 @@ RegisterCitusCustomScanMethods(void)
static void static void
CitusBeginScan(CustomScanState *node, EState *estate, int eflags) CitusBeginScan(CustomScanState *node, EState *estate, int eflags)
{ {
CitusScanState *scanState = NULL;
DistributedPlan *distributedPlan = NULL; DistributedPlan *distributedPlan = NULL;
MarkCitusInitiatedCoordinatorBackend(); MarkCitusInitiatedCoordinatorBackend();
scanState = (CitusScanState *) node; CitusScanState *scanState = (CitusScanState *) node;
#if PG_VERSION_NUM >= 120000 #if PG_VERSION_NUM >= 120000
ExecInitResultSlot(&scanState->customScanState.ss.ps, &TTSOpsMinimalTuple); ExecInitResultSlot(&scanState->customScanState.ss.ps, &TTSOpsMinimalTuple);
@ -152,7 +151,6 @@ TupleTableSlot *
CitusExecScan(CustomScanState *node) CitusExecScan(CustomScanState *node)
{ {
CitusScanState *scanState = (CitusScanState *) node; CitusScanState *scanState = (CitusScanState *) node;
TupleTableSlot *resultSlot = NULL;
if (!scanState->finishedRemoteScan) if (!scanState->finishedRemoteScan)
{ {
@ -161,7 +159,7 @@ CitusExecScan(CustomScanState *node)
scanState->finishedRemoteScan = true; scanState->finishedRemoteScan = true;
} }
resultSlot = ReturnTupleFromTuplestore(scanState); TupleTableSlot *resultSlot = ReturnTupleFromTuplestore(scanState);
return resultSlot; return resultSlot;
} }
@ -179,21 +177,18 @@ static void
CitusModifyBeginScan(CustomScanState *node, EState *estate, int eflags) CitusModifyBeginScan(CustomScanState *node, EState *estate, int eflags)
{ {
CitusScanState *scanState = (CitusScanState *) node; CitusScanState *scanState = (CitusScanState *) node;
DistributedPlan *distributedPlan = NULL;
Job *workerJob = NULL;
Query *jobQuery = NULL;
List *taskList = NIL;
/* /*
* We must not change the distributed plan since it may be reused across multiple * We must not change the distributed plan since it may be reused across multiple
* executions of a prepared statement. Instead we create a deep copy that we only * executions of a prepared statement. Instead we create a deep copy that we only
* use for the current execution. * use for the current execution.
*/ */
distributedPlan = scanState->distributedPlan = copyObject(scanState->distributedPlan); DistributedPlan *distributedPlan = scanState->distributedPlan = copyObject(
scanState->distributedPlan);
workerJob = distributedPlan->workerJob; Job *workerJob = distributedPlan->workerJob;
jobQuery = workerJob->jobQuery; Query *jobQuery = workerJob->jobQuery;
taskList = workerJob->taskList; List *taskList = workerJob->taskList;
if (workerJob->requiresMasterEvaluation) if (workerJob->requiresMasterEvaluation)
{ {
@ -407,8 +402,6 @@ ScanStateGetExecutorState(CitusScanState *scanState)
CustomScan * CustomScan *
FetchCitusCustomScanIfExists(Plan *plan) FetchCitusCustomScanIfExists(Plan *plan)
{ {
CustomScan *customScan = NULL;
if (plan == NULL) if (plan == NULL)
{ {
return NULL; return NULL;
@ -419,7 +412,7 @@ FetchCitusCustomScanIfExists(Plan *plan)
return (CustomScan *) plan; return (CustomScan *) plan;
} }
customScan = FetchCitusCustomScanIfExists(plan->lefttree); CustomScan *customScan = FetchCitusCustomScanIfExists(plan->lefttree);
if (customScan == NULL) if (customScan == NULL)
{ {
@ -457,9 +450,6 @@ IsCitusPlan(Plan *plan)
bool bool
IsCitusCustomScan(Plan *plan) IsCitusCustomScan(Plan *plan)
{ {
CustomScan *customScan = NULL;
Node *privateNode = NULL;
if (plan == NULL) if (plan == NULL)
{ {
return false; return false;
@ -470,13 +460,13 @@ IsCitusCustomScan(Plan *plan)
return false; return false;
} }
customScan = (CustomScan *) plan; CustomScan *customScan = (CustomScan *) plan;
if (list_length(customScan->custom_private) == 0) if (list_length(customScan->custom_private) == 0)
{ {
return false; return false;
} }
privateNode = (Node *) linitial(customScan->custom_private); Node *privateNode = (Node *) linitial(customScan->custom_private);
if (!CitusIsA(privateNode, DistributedPlan)) if (!CitusIsA(privateNode, DistributedPlan))
{ {
return false; return false;

View File

@ -93,7 +93,6 @@ static TupleTableSlot *
CoordinatorInsertSelectExecScanInternal(CustomScanState *node) CoordinatorInsertSelectExecScanInternal(CustomScanState *node)
{ {
CitusScanState *scanState = (CitusScanState *) node; CitusScanState *scanState = (CitusScanState *) node;
TupleTableSlot *resultSlot = NULL;
if (!scanState->finishedRemoteScan) if (!scanState->finishedRemoteScan)
{ {
@ -197,7 +196,7 @@ CoordinatorInsertSelectExecScanInternal(CustomScanState *node)
scanState->finishedRemoteScan = true; scanState->finishedRemoteScan = true;
} }
resultSlot = ReturnTupleFromTuplestore(scanState); TupleTableSlot *resultSlot = ReturnTupleFromTuplestore(scanState);
return resultSlot; return resultSlot;
} }
@ -217,36 +216,34 @@ ExecuteSelectIntoColocatedIntermediateResults(Oid targetRelationId,
char *intermediateResultIdPrefix) char *intermediateResultIdPrefix)
{ {
ParamListInfo paramListInfo = executorState->es_param_list_info; ParamListInfo paramListInfo = executorState->es_param_list_info;
int partitionColumnIndex = -1;
List *columnNameList = NIL;
bool stopOnFailure = false; bool stopOnFailure = false;
char partitionMethod = 0;
CitusCopyDestReceiver *copyDest = NULL;
Query *queryCopy = NULL;
partitionMethod = PartitionMethod(targetRelationId); char partitionMethod = PartitionMethod(targetRelationId);
if (partitionMethod == DISTRIBUTE_BY_NONE) if (partitionMethod == DISTRIBUTE_BY_NONE)
{ {
stopOnFailure = true; stopOnFailure = true;
} }
/* Get column name list and partition column index for the target table */ /* Get column name list and partition column index for the target table */
columnNameList = BuildColumnNameListFromTargetList(targetRelationId, List *columnNameList = BuildColumnNameListFromTargetList(targetRelationId,
insertTargetList); insertTargetList);
partitionColumnIndex = PartitionColumnIndexFromColumnList(targetRelationId, int partitionColumnIndex = PartitionColumnIndexFromColumnList(targetRelationId,
columnNameList); columnNameList);
/* set up a DestReceiver that copies into the intermediate table */ /* set up a DestReceiver that copies into the intermediate table */
copyDest = CreateCitusCopyDestReceiver(targetRelationId, columnNameList, CitusCopyDestReceiver *copyDest = CreateCitusCopyDestReceiver(targetRelationId,
partitionColumnIndex, executorState, columnNameList,
stopOnFailure, intermediateResultIdPrefix); partitionColumnIndex,
executorState,
stopOnFailure,
intermediateResultIdPrefix);
/* /*
* Make a copy of the query, since ExecuteQueryIntoDestReceiver may scribble on it * Make a copy of the query, since ExecuteQueryIntoDestReceiver may scribble on it
* and we want it to be replanned every time if it is stored in a prepared * and we want it to be replanned every time if it is stored in a prepared
* statement. * statement.
*/ */
queryCopy = copyObject(selectQuery); Query *queryCopy = copyObject(selectQuery);
ExecuteQueryIntoDestReceiver(queryCopy, paramListInfo, (DestReceiver *) copyDest); ExecuteQueryIntoDestReceiver(queryCopy, paramListInfo, (DestReceiver *) copyDest);
@ -268,28 +265,25 @@ ExecuteSelectIntoRelation(Oid targetRelationId, List *insertTargetList,
Query *selectQuery, EState *executorState) Query *selectQuery, EState *executorState)
{ {
ParamListInfo paramListInfo = executorState->es_param_list_info; ParamListInfo paramListInfo = executorState->es_param_list_info;
int partitionColumnIndex = -1;
List *columnNameList = NIL;
bool stopOnFailure = false; bool stopOnFailure = false;
char partitionMethod = 0;
CitusCopyDestReceiver *copyDest = NULL;
Query *queryCopy = NULL;
partitionMethod = PartitionMethod(targetRelationId); char partitionMethod = PartitionMethod(targetRelationId);
if (partitionMethod == DISTRIBUTE_BY_NONE) if (partitionMethod == DISTRIBUTE_BY_NONE)
{ {
stopOnFailure = true; stopOnFailure = true;
} }
/* Get column name list and partition column index for the target table */ /* Get column name list and partition column index for the target table */
columnNameList = BuildColumnNameListFromTargetList(targetRelationId, List *columnNameList = BuildColumnNameListFromTargetList(targetRelationId,
insertTargetList); insertTargetList);
partitionColumnIndex = PartitionColumnIndexFromColumnList(targetRelationId, int partitionColumnIndex = PartitionColumnIndexFromColumnList(targetRelationId,
columnNameList); columnNameList);
/* set up a DestReceiver that copies into the distributed table */ /* set up a DestReceiver that copies into the distributed table */
copyDest = CreateCitusCopyDestReceiver(targetRelationId, columnNameList, CitusCopyDestReceiver *copyDest = CreateCitusCopyDestReceiver(targetRelationId,
partitionColumnIndex, executorState, columnNameList,
partitionColumnIndex,
executorState,
stopOnFailure, NULL); stopOnFailure, NULL);
/* /*
@ -297,7 +291,7 @@ ExecuteSelectIntoRelation(Oid targetRelationId, List *insertTargetList,
* and we want it to be replanned every time if it is stored in a prepared * and we want it to be replanned every time if it is stored in a prepared
* statement. * statement.
*/ */
queryCopy = copyObject(selectQuery); Query *queryCopy = copyObject(selectQuery);
ExecuteQueryIntoDestReceiver(queryCopy, paramListInfo, (DestReceiver *) copyDest); ExecuteQueryIntoDestReceiver(queryCopy, paramListInfo, (DestReceiver *) copyDest);

View File

@ -111,10 +111,7 @@ broadcast_intermediate_result(PG_FUNCTION_ARGS)
char *resultIdString = text_to_cstring(resultIdText); char *resultIdString = text_to_cstring(resultIdText);
text *queryText = PG_GETARG_TEXT_P(1); text *queryText = PG_GETARG_TEXT_P(1);
char *queryString = text_to_cstring(queryText); char *queryString = text_to_cstring(queryText);
EState *estate = NULL;
List *nodeList = NIL;
bool writeLocalFile = false; bool writeLocalFile = false;
RemoteFileDestReceiver *resultDest = NULL;
ParamListInfo paramListInfo = NULL; ParamListInfo paramListInfo = NULL;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
@ -127,10 +124,12 @@ broadcast_intermediate_result(PG_FUNCTION_ARGS)
*/ */
BeginOrContinueCoordinatedTransaction(); BeginOrContinueCoordinatedTransaction();
nodeList = ActivePrimaryWorkerNodeList(NoLock); List *nodeList = ActivePrimaryWorkerNodeList(NoLock);
estate = CreateExecutorState(); EState *estate = CreateExecutorState();
resultDest = (RemoteFileDestReceiver *) CreateRemoteFileDestReceiver(resultIdString, RemoteFileDestReceiver *resultDest =
estate, nodeList, (RemoteFileDestReceiver *) CreateRemoteFileDestReceiver(resultIdString,
estate,
nodeList,
writeLocalFile); writeLocalFile);
ExecuteQueryStringIntoDestReceiver(queryString, paramListInfo, ExecuteQueryStringIntoDestReceiver(queryString, paramListInfo,
@ -153,10 +152,8 @@ create_intermediate_result(PG_FUNCTION_ARGS)
char *resultIdString = text_to_cstring(resultIdText); char *resultIdString = text_to_cstring(resultIdText);
text *queryText = PG_GETARG_TEXT_P(1); text *queryText = PG_GETARG_TEXT_P(1);
char *queryString = text_to_cstring(queryText); char *queryString = text_to_cstring(queryText);
EState *estate = NULL;
List *nodeList = NIL; List *nodeList = NIL;
bool writeLocalFile = true; bool writeLocalFile = true;
RemoteFileDestReceiver *resultDest = NULL;
ParamListInfo paramListInfo = NULL; ParamListInfo paramListInfo = NULL;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
@ -169,9 +166,11 @@ create_intermediate_result(PG_FUNCTION_ARGS)
*/ */
BeginOrContinueCoordinatedTransaction(); BeginOrContinueCoordinatedTransaction();
estate = CreateExecutorState(); EState *estate = CreateExecutorState();
resultDest = (RemoteFileDestReceiver *) CreateRemoteFileDestReceiver(resultIdString, RemoteFileDestReceiver *resultDest =
estate, nodeList, (RemoteFileDestReceiver *) CreateRemoteFileDestReceiver(resultIdString,
estate,
nodeList,
writeLocalFile); writeLocalFile);
ExecuteQueryStringIntoDestReceiver(queryString, paramListInfo, ExecuteQueryStringIntoDestReceiver(queryString, paramListInfo,
@ -193,9 +192,8 @@ DestReceiver *
CreateRemoteFileDestReceiver(char *resultId, EState *executorState, CreateRemoteFileDestReceiver(char *resultId, EState *executorState,
List *initialNodeList, bool writeLocalFile) List *initialNodeList, bool writeLocalFile)
{ {
RemoteFileDestReceiver *resultDest = NULL; RemoteFileDestReceiver *resultDest = (RemoteFileDestReceiver *) palloc0(
sizeof(RemoteFileDestReceiver));
resultDest = (RemoteFileDestReceiver *) palloc0(sizeof(RemoteFileDestReceiver));
/* set up the DestReceiver function pointers */ /* set up the DestReceiver function pointers */
resultDest->pub.receiveSlot = RemoteFileDestReceiverReceive; resultDest->pub.receiveSlot = RemoteFileDestReceiverReceive;
@ -228,7 +226,6 @@ RemoteFileDestReceiverStartup(DestReceiver *dest, int operation,
const char *resultId = resultDest->resultId; const char *resultId = resultDest->resultId;
CopyOutState copyOutState = NULL;
const char *delimiterCharacter = "\t"; const char *delimiterCharacter = "\t";
const char *nullPrintCharacter = "\\N"; const char *nullPrintCharacter = "\\N";
@ -240,7 +237,7 @@ RemoteFileDestReceiverStartup(DestReceiver *dest, int operation,
resultDest->tupleDescriptor = inputTupleDescriptor; resultDest->tupleDescriptor = inputTupleDescriptor;
/* define how tuples will be serialised */ /* define how tuples will be serialised */
copyOutState = (CopyOutState) palloc0(sizeof(CopyOutStateData)); CopyOutState copyOutState = (CopyOutState) palloc0(sizeof(CopyOutStateData));
copyOutState->delim = (char *) delimiterCharacter; copyOutState->delim = (char *) delimiterCharacter;
copyOutState->null_print = (char *) nullPrintCharacter; copyOutState->null_print = (char *) nullPrintCharacter;
copyOutState->null_print_client = (char *) nullPrintCharacter; copyOutState->null_print_client = (char *) nullPrintCharacter;
@ -256,12 +253,11 @@ RemoteFileDestReceiverStartup(DestReceiver *dest, int operation,
{ {
const int fileFlags = (O_APPEND | O_CREAT | O_RDWR | O_TRUNC | PG_BINARY); const int fileFlags = (O_APPEND | O_CREAT | O_RDWR | O_TRUNC | PG_BINARY);
const int fileMode = (S_IRUSR | S_IWUSR); const int fileMode = (S_IRUSR | S_IWUSR);
const char *fileName = NULL;
/* make sure the directory exists */ /* make sure the directory exists */
CreateIntermediateResultsDirectory(); CreateIntermediateResultsDirectory();
fileName = QueryResultFileName(resultId); const char *fileName = QueryResultFileName(resultId);
resultDest->fileCompat = FileCompatFromFileStart(FileOpenForTransmit(fileName, resultDest->fileCompat = FileCompatFromFileStart(FileOpenForTransmit(fileName,
fileFlags, fileFlags,
@ -273,7 +269,6 @@ RemoteFileDestReceiverStartup(DestReceiver *dest, int operation,
WorkerNode *workerNode = (WorkerNode *) lfirst(initialNodeCell); WorkerNode *workerNode = (WorkerNode *) lfirst(initialNodeCell);
char *nodeName = workerNode->workerName; char *nodeName = workerNode->workerName;
int nodePort = workerNode->workerPort; int nodePort = workerNode->workerPort;
MultiConnection *connection = NULL;
/* /*
* We prefer to use a connection that is not associcated with * We prefer to use a connection that is not associcated with
@ -281,7 +276,7 @@ RemoteFileDestReceiverStartup(DestReceiver *dest, int operation,
* exclusively and that would prevent the consecutive DML/DDL * exclusively and that would prevent the consecutive DML/DDL
* use the same connection. * use the same connection.
*/ */
connection = StartNonDataAccessConnection(nodeName, nodePort); MultiConnection *connection = StartNonDataAccessConnection(nodeName, nodePort);
ClaimConnectionExclusively(connection); ClaimConnectionExclusively(connection);
MarkRemoteTransactionCritical(connection); MarkRemoteTransactionCritical(connection);
@ -296,12 +291,10 @@ RemoteFileDestReceiverStartup(DestReceiver *dest, int operation,
foreach(connectionCell, connectionList) foreach(connectionCell, connectionList)
{ {
MultiConnection *connection = (MultiConnection *) lfirst(connectionCell); MultiConnection *connection = (MultiConnection *) lfirst(connectionCell);
StringInfo copyCommand = NULL;
bool querySent = false;
copyCommand = ConstructCopyResultStatement(resultId); StringInfo copyCommand = ConstructCopyResultStatement(resultId);
querySent = SendRemoteCommand(connection, copyCommand->data); bool querySent = SendRemoteCommand(connection, copyCommand->data);
if (!querySent) if (!querySent)
{ {
ReportConnectionError(connection, ERROR); ReportConnectionError(connection, ERROR);
@ -371,8 +364,6 @@ RemoteFileDestReceiverReceive(TupleTableSlot *slot, DestReceiver *dest)
CopyOutState copyOutState = resultDest->copyOutState; CopyOutState copyOutState = resultDest->copyOutState;
FmgrInfo *columnOutputFunctions = resultDest->columnOutputFunctions; FmgrInfo *columnOutputFunctions = resultDest->columnOutputFunctions;
Datum *columnValues = NULL;
bool *columnNulls = NULL;
StringInfo copyData = copyOutState->fe_msgbuf; StringInfo copyData = copyOutState->fe_msgbuf;
EState *executorState = resultDest->executorState; EState *executorState = resultDest->executorState;
@ -381,8 +372,8 @@ RemoteFileDestReceiverReceive(TupleTableSlot *slot, DestReceiver *dest)
slot_getallattrs(slot); slot_getallattrs(slot);
columnValues = slot->tts_values; Datum *columnValues = slot->tts_values;
columnNulls = slot->tts_isnull; bool *columnNulls = slot->tts_isnull;
resetStringInfo(copyData); resetStringInfo(copyData);
@ -526,11 +517,9 @@ RemoteFileDestReceiverDestroy(DestReceiver *destReceiver)
void void
ReceiveQueryResultViaCopy(const char *resultId) ReceiveQueryResultViaCopy(const char *resultId)
{ {
const char *resultFileName = NULL;
CreateIntermediateResultsDirectory(); CreateIntermediateResultsDirectory();
resultFileName = QueryResultFileName(resultId); const char *resultFileName = QueryResultFileName(resultId);
RedirectCopyDataToRegularFile(resultFileName); RedirectCopyDataToRegularFile(resultFileName);
} }
@ -671,12 +660,10 @@ RemoveIntermediateResultsDirectory(void)
int64 int64
IntermediateResultSize(char *resultId) IntermediateResultSize(char *resultId)
{ {
char *resultFileName = NULL;
struct stat fileStat; struct stat fileStat;
int statOK = 0;
resultFileName = QueryResultFileName(resultId); char *resultFileName = QueryResultFileName(resultId);
statOK = stat(resultFileName, &fileStat); int statOK = stat(resultFileName, &fileStat);
if (statOK < 0) if (statOK < 0)
{ {
return -1; return -1;
@ -710,24 +697,21 @@ read_intermediate_result(PG_FUNCTION_ARGS)
Datum copyFormatLabelDatum = DirectFunctionCall1(enum_out, copyFormatOidDatum); Datum copyFormatLabelDatum = DirectFunctionCall1(enum_out, copyFormatOidDatum);
char *copyFormatLabel = DatumGetCString(copyFormatLabelDatum); char *copyFormatLabel = DatumGetCString(copyFormatLabelDatum);
char *resultFileName = NULL;
struct stat fileStat; struct stat fileStat;
int statOK = 0;
Tuplestorestate *tupstore = NULL;
TupleDesc tupleDescriptor = NULL; TupleDesc tupleDescriptor = NULL;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
resultFileName = QueryResultFileName(resultIdString); char *resultFileName = QueryResultFileName(resultIdString);
statOK = stat(resultFileName, &fileStat); int statOK = stat(resultFileName, &fileStat);
if (statOK != 0) if (statOK != 0)
{ {
ereport(ERROR, (errcode_for_file_access(), ereport(ERROR, (errcode_for_file_access(),
errmsg("result \"%s\" does not exist", resultIdString))); errmsg("result \"%s\" does not exist", resultIdString)));
} }
tupstore = SetupTuplestore(fcinfo, &tupleDescriptor); Tuplestorestate *tupstore = SetupTuplestore(fcinfo, &tupleDescriptor);
ReadFileIntoTupleStore(resultFileName, copyFormatLabel, tupleDescriptor, tupstore); ReadFileIntoTupleStore(resultFileName, copyFormatLabel, tupleDescriptor, tupstore);

View File

@ -142,8 +142,6 @@ ExecuteLocalTaskList(CitusScanState *scanState, List *taskList)
{ {
Task *task = (Task *) lfirst(taskCell); Task *task = (Task *) lfirst(taskCell);
PlannedStmt *localPlan = NULL;
int cursorOptions = 0;
const char *shardQueryString = task->queryString; const char *shardQueryString = task->queryString;
Query *shardQuery = ParseQueryString(shardQueryString, parameterTypes, numParams); Query *shardQuery = ParseQueryString(shardQueryString, parameterTypes, numParams);
@ -153,7 +151,7 @@ ExecuteLocalTaskList(CitusScanState *scanState, List *taskList)
* go through the distributed executor, which we do not want since the * go through the distributed executor, which we do not want since the
* query is already known to be local. * query is already known to be local.
*/ */
cursorOptions = 0; int cursorOptions = 0;
/* /*
* Altough the shardQuery is local to this node, we prefer planner() * Altough the shardQuery is local to this node, we prefer planner()
@ -163,7 +161,7 @@ ExecuteLocalTaskList(CitusScanState *scanState, List *taskList)
* implemented. So, let planner to call distributed_planner() which * implemented. So, let planner to call distributed_planner() which
* eventually calls standard_planner(). * eventually calls standard_planner().
*/ */
localPlan = planner(shardQuery, cursorOptions, paramListInfo); PlannedStmt *localPlan = planner(shardQuery, cursorOptions, paramListInfo);
LogLocalCommand(shardQueryString); LogLocalCommand(shardQueryString);
@ -241,7 +239,6 @@ ExtractLocalAndRemoteTasks(bool readOnly, List *taskList, List **localTaskList,
} }
else else
{ {
Task *localTask = NULL;
Task *remoteTask = NULL; Task *remoteTask = NULL;
/* /*
@ -252,7 +249,7 @@ ExtractLocalAndRemoteTasks(bool readOnly, List *taskList, List **localTaskList,
*/ */
task->partiallyLocalOrRemote = true; task->partiallyLocalOrRemote = true;
localTask = copyObject(task); Task *localTask = copyObject(task);
localTask->taskPlacementList = localTaskPlacementList; localTask->taskPlacementList = localTaskPlacementList;
*localTaskList = lappend(*localTaskList, localTask); *localTaskList = lappend(*localTaskList, localTask);
@ -318,7 +315,6 @@ ExecuteLocalTaskPlan(CitusScanState *scanState, PlannedStmt *taskPlan, char *que
DestReceiver *tupleStoreDestReceiever = CreateDestReceiver(DestTuplestore); DestReceiver *tupleStoreDestReceiever = CreateDestReceiver(DestTuplestore);
ScanDirection scanDirection = ForwardScanDirection; ScanDirection scanDirection = ForwardScanDirection;
QueryEnvironment *queryEnv = create_queryEnv(); QueryEnvironment *queryEnv = create_queryEnv();
QueryDesc *queryDesc = NULL;
int eflags = 0; int eflags = 0;
uint64 totalRowsProcessed = 0; uint64 totalRowsProcessed = 0;
@ -331,7 +327,7 @@ ExecuteLocalTaskPlan(CitusScanState *scanState, PlannedStmt *taskPlan, char *que
CurrentMemoryContext, false); CurrentMemoryContext, false);
/* Create a QueryDesc for the query */ /* Create a QueryDesc for the query */
queryDesc = CreateQueryDesc(taskPlan, queryString, QueryDesc *queryDesc = CreateQueryDesc(taskPlan, queryString,
GetActiveSnapshot(), InvalidSnapshot, GetActiveSnapshot(), InvalidSnapshot,
tupleStoreDestReceiever, paramListInfo, tupleStoreDestReceiever, paramListInfo,
queryEnv, 0); queryEnv, 0);
@ -365,8 +361,6 @@ ExecuteLocalTaskPlan(CitusScanState *scanState, PlannedStmt *taskPlan, char *que
bool bool
ShouldExecuteTasksLocally(List *taskList) ShouldExecuteTasksLocally(List *taskList)
{ {
bool singleTask = false;
if (!EnableLocalExecution) if (!EnableLocalExecution)
{ {
return false; return false;
@ -394,7 +388,7 @@ ShouldExecuteTasksLocally(List *taskList)
return true; return true;
} }
singleTask = (list_length(taskList) == 1); bool singleTask = (list_length(taskList) == 1);
if (singleTask && TaskAccessesLocalNode((Task *) linitial(taskList))) if (singleTask && TaskAccessesLocalNode((Task *) linitial(taskList)))
{ {
/* /*

View File

@ -55,10 +55,9 @@ static int32
AllocateConnectionId(void) AllocateConnectionId(void)
{ {
int32 connectionId = INVALID_CONNECTION_ID; int32 connectionId = INVALID_CONNECTION_ID;
int32 connIndex = 0;
/* allocate connectionId from connection pool */ /* allocate connectionId from connection pool */
for (connIndex = 0; connIndex < MAX_CONNECTION_COUNT; connIndex++) for (int32 connIndex = 0; connIndex < MAX_CONNECTION_COUNT; connIndex++)
{ {
MultiConnection *connection = ClientConnectionArray[connIndex]; MultiConnection *connection = ClientConnectionArray[connIndex];
if (connection == NULL) if (connection == NULL)
@ -84,8 +83,6 @@ int32
MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDatabase, MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDatabase,
const char *userName) const char *userName)
{ {
MultiConnection *connection = NULL;
ConnStatusType connStatusType = CONNECTION_OK;
int32 connectionId = AllocateConnectionId(); int32 connectionId = AllocateConnectionId();
int connectionFlags = FORCE_NEW_CONNECTION; /* no cached connections for now */ int connectionFlags = FORCE_NEW_CONNECTION; /* no cached connections for now */
@ -103,10 +100,11 @@ MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDataba
} }
/* establish synchronous connection to worker node */ /* establish synchronous connection to worker node */
connection = GetNodeUserDatabaseConnection(connectionFlags, nodeName, nodePort, MultiConnection *connection = GetNodeUserDatabaseConnection(connectionFlags, nodeName,
nodePort,
userName, nodeDatabase); userName, nodeDatabase);
connStatusType = PQstatus(connection->pgConn); ConnStatusType connStatusType = PQstatus(connection->pgConn);
if (connStatusType == CONNECTION_OK) if (connStatusType == CONNECTION_OK)
{ {
@ -132,8 +130,6 @@ int32
MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeDatabase, MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeDatabase,
const char *userName) const char *userName)
{ {
MultiConnection *connection = NULL;
ConnStatusType connStatusType = CONNECTION_OK;
int32 connectionId = AllocateConnectionId(); int32 connectionId = AllocateConnectionId();
int connectionFlags = FORCE_NEW_CONNECTION; /* no cached connections for now */ int connectionFlags = FORCE_NEW_CONNECTION; /* no cached connections for now */
@ -151,9 +147,10 @@ MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeD
} }
/* prepare asynchronous request for worker node connection */ /* prepare asynchronous request for worker node connection */
connection = StartNodeUserDatabaseConnection(connectionFlags, nodeName, nodePort, MultiConnection *connection = StartNodeUserDatabaseConnection(connectionFlags,
nodeName, nodePort,
userName, nodeDatabase); userName, nodeDatabase);
connStatusType = PQstatus(connection->pgConn); ConnStatusType connStatusType = PQstatus(connection->pgConn);
/* /*
* If prepared, we save the connection, and set its initial polling status * If prepared, we save the connection, and set its initial polling status
@ -181,15 +178,13 @@ MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeD
ConnectStatus ConnectStatus
MultiClientConnectPoll(int32 connectionId) MultiClientConnectPoll(int32 connectionId)
{ {
MultiConnection *connection = NULL;
PostgresPollingStatusType pollingStatus = PGRES_POLLING_OK;
ConnectStatus connectStatus = CLIENT_INVALID_CONNECT; ConnectStatus connectStatus = CLIENT_INVALID_CONNECT;
Assert(connectionId != INVALID_CONNECTION_ID); Assert(connectionId != INVALID_CONNECTION_ID);
connection = ClientConnectionArray[connectionId]; MultiConnection *connection = ClientConnectionArray[connectionId];
Assert(connection != NULL); Assert(connection != NULL);
pollingStatus = ClientPollingStatusArray[connectionId]; PostgresPollingStatusType pollingStatus = ClientPollingStatusArray[connectionId];
if (pollingStatus == PGRES_POLLING_OK) if (pollingStatus == PGRES_POLLING_OK)
{ {
connectStatus = CLIENT_CONNECTION_READY; connectStatus = CLIENT_CONNECTION_READY;
@ -235,11 +230,10 @@ MultiClientConnectPoll(int32 connectionId)
void void
MultiClientDisconnect(int32 connectionId) MultiClientDisconnect(int32 connectionId)
{ {
MultiConnection *connection = NULL;
const int InvalidPollingStatus = -1; const int InvalidPollingStatus = -1;
Assert(connectionId != INVALID_CONNECTION_ID); Assert(connectionId != INVALID_CONNECTION_ID);
connection = ClientConnectionArray[connectionId]; MultiConnection *connection = ClientConnectionArray[connectionId];
Assert(connection != NULL); Assert(connection != NULL);
CloseConnection(connection); CloseConnection(connection);
@ -256,15 +250,13 @@ MultiClientDisconnect(int32 connectionId)
bool bool
MultiClientConnectionUp(int32 connectionId) MultiClientConnectionUp(int32 connectionId)
{ {
MultiConnection *connection = NULL;
ConnStatusType connStatusType = CONNECTION_OK;
bool connectionUp = true; bool connectionUp = true;
Assert(connectionId != INVALID_CONNECTION_ID); Assert(connectionId != INVALID_CONNECTION_ID);
connection = ClientConnectionArray[connectionId]; MultiConnection *connection = ClientConnectionArray[connectionId];
Assert(connection != NULL); Assert(connection != NULL);
connStatusType = PQstatus(connection->pgConn); ConnStatusType connStatusType = PQstatus(connection->pgConn);
if (connStatusType == CONNECTION_BAD) if (connStatusType == CONNECTION_BAD)
{ {
connectionUp = false; connectionUp = false;
@ -278,15 +270,13 @@ MultiClientConnectionUp(int32 connectionId)
bool bool
MultiClientSendQuery(int32 connectionId, const char *query) MultiClientSendQuery(int32 connectionId, const char *query)
{ {
MultiConnection *connection = NULL;
bool success = true; bool success = true;
int querySent = 0;
Assert(connectionId != INVALID_CONNECTION_ID); Assert(connectionId != INVALID_CONNECTION_ID);
connection = ClientConnectionArray[connectionId]; MultiConnection *connection = ClientConnectionArray[connectionId];
Assert(connection != NULL); Assert(connection != NULL);
querySent = SendRemoteCommand(connection, query); int querySent = SendRemoteCommand(connection, query);
if (querySent == 0) if (querySent == 0)
{ {
char *errorMessage = pchomp(PQerrorMessage(connection->pgConn)); char *errorMessage = pchomp(PQerrorMessage(connection->pgConn));
@ -313,14 +303,11 @@ MultiClientSendQuery(int32 connectionId, const char *query)
bool bool
MultiClientCancel(int32 connectionId) MultiClientCancel(int32 connectionId)
{ {
MultiConnection *connection = NULL;
bool canceled = true;
Assert(connectionId != INVALID_CONNECTION_ID); Assert(connectionId != INVALID_CONNECTION_ID);
connection = ClientConnectionArray[connectionId]; MultiConnection *connection = ClientConnectionArray[connectionId];
Assert(connection != NULL); Assert(connection != NULL);
canceled = SendCancelationRequest(connection); bool canceled = SendCancelationRequest(connection);
return canceled; return canceled;
} }
@ -330,16 +317,13 @@ MultiClientCancel(int32 connectionId)
ResultStatus ResultStatus
MultiClientResultStatus(int32 connectionId) MultiClientResultStatus(int32 connectionId)
{ {
MultiConnection *connection = NULL;
int consumed = 0;
ConnStatusType connStatusType = CONNECTION_OK;
ResultStatus resultStatus = CLIENT_INVALID_RESULT_STATUS; ResultStatus resultStatus = CLIENT_INVALID_RESULT_STATUS;
Assert(connectionId != INVALID_CONNECTION_ID); Assert(connectionId != INVALID_CONNECTION_ID);
connection = ClientConnectionArray[connectionId]; MultiConnection *connection = ClientConnectionArray[connectionId];
Assert(connection != NULL); Assert(connection != NULL);
connStatusType = PQstatus(connection->pgConn); ConnStatusType connStatusType = PQstatus(connection->pgConn);
if (connStatusType == CONNECTION_BAD) if (connStatusType == CONNECTION_BAD)
{ {
ereport(WARNING, (errmsg("could not maintain connection to worker node"))); ereport(WARNING, (errmsg("could not maintain connection to worker node")));
@ -347,7 +331,7 @@ MultiClientResultStatus(int32 connectionId)
} }
/* consume input to allow status change */ /* consume input to allow status change */
consumed = PQconsumeInput(connection->pgConn); int consumed = PQconsumeInput(connection->pgConn);
if (consumed != 0) if (consumed != 0)
{ {
int connectionBusy = PQisBusy(connection->pgConn); int connectionBusy = PQisBusy(connection->pgConn);
@ -383,15 +367,11 @@ BatchQueryStatus
MultiClientBatchResult(int32 connectionId, void **queryResult, int *rowCount, MultiClientBatchResult(int32 connectionId, void **queryResult, int *rowCount,
int *columnCount) int *columnCount)
{ {
MultiConnection *connection = NULL;
PGresult *result = NULL;
ConnStatusType connStatusType = CONNECTION_OK;
ExecStatusType resultStatus = PGRES_COMMAND_OK;
BatchQueryStatus queryStatus = CLIENT_INVALID_BATCH_QUERY; BatchQueryStatus queryStatus = CLIENT_INVALID_BATCH_QUERY;
bool raiseInterrupts = true; bool raiseInterrupts = true;
Assert(connectionId != INVALID_CONNECTION_ID); Assert(connectionId != INVALID_CONNECTION_ID);
connection = ClientConnectionArray[connectionId]; MultiConnection *connection = ClientConnectionArray[connectionId];
Assert(connection != NULL); Assert(connection != NULL);
/* set default result */ /* set default result */
@ -399,20 +379,20 @@ MultiClientBatchResult(int32 connectionId, void **queryResult, int *rowCount,
(*rowCount) = -1; (*rowCount) = -1;
(*columnCount) = -1; (*columnCount) = -1;
connStatusType = PQstatus(connection->pgConn); ConnStatusType connStatusType = PQstatus(connection->pgConn);
if (connStatusType == CONNECTION_BAD) if (connStatusType == CONNECTION_BAD)
{ {
ereport(WARNING, (errmsg("could not maintain connection to worker node"))); ereport(WARNING, (errmsg("could not maintain connection to worker node")));
return CLIENT_BATCH_QUERY_FAILED; return CLIENT_BATCH_QUERY_FAILED;
} }
result = GetRemoteCommandResult(connection, raiseInterrupts); PGresult *result = GetRemoteCommandResult(connection, raiseInterrupts);
if (result == NULL) if (result == NULL)
{ {
return CLIENT_BATCH_QUERY_DONE; return CLIENT_BATCH_QUERY_DONE;
} }
resultStatus = PQresultStatus(result); ExecStatusType resultStatus = PQresultStatus(result);
if (resultStatus == PGRES_TUPLES_OK) if (resultStatus == PGRES_TUPLES_OK)
{ {
(*queryResult) = (void **) result; (*queryResult) = (void **) result;
@ -457,20 +437,16 @@ MultiClientClearResult(void *queryResult)
QueryStatus QueryStatus
MultiClientQueryStatus(int32 connectionId) MultiClientQueryStatus(int32 connectionId)
{ {
MultiConnection *connection = NULL;
PGresult *result = NULL;
int tupleCount PG_USED_FOR_ASSERTS_ONLY = 0; int tupleCount PG_USED_FOR_ASSERTS_ONLY = 0;
bool copyResults = false; bool copyResults = false;
ConnStatusType connStatusType = CONNECTION_OK;
ExecStatusType resultStatus = PGRES_COMMAND_OK;
QueryStatus queryStatus = CLIENT_INVALID_QUERY; QueryStatus queryStatus = CLIENT_INVALID_QUERY;
bool raiseInterrupts = true; bool raiseInterrupts = true;
Assert(connectionId != INVALID_CONNECTION_ID); Assert(connectionId != INVALID_CONNECTION_ID);
connection = ClientConnectionArray[connectionId]; MultiConnection *connection = ClientConnectionArray[connectionId];
Assert(connection != NULL); Assert(connection != NULL);
connStatusType = PQstatus(connection->pgConn); ConnStatusType connStatusType = PQstatus(connection->pgConn);
if (connStatusType == CONNECTION_BAD) if (connStatusType == CONNECTION_BAD)
{ {
ereport(WARNING, (errmsg("could not maintain connection to worker node"))); ereport(WARNING, (errmsg("could not maintain connection to worker node")));
@ -482,8 +458,8 @@ MultiClientQueryStatus(int32 connectionId)
* isn't ready yet (the caller didn't wait for the connection to be ready), * isn't ready yet (the caller didn't wait for the connection to be ready),
* we will block on this call. * we will block on this call.
*/ */
result = GetRemoteCommandResult(connection, raiseInterrupts); PGresult *result = GetRemoteCommandResult(connection, raiseInterrupts);
resultStatus = PQresultStatus(result); ExecStatusType resultStatus = PQresultStatus(result);
if (resultStatus == PGRES_COMMAND_OK) if (resultStatus == PGRES_COMMAND_OK)
{ {
@ -536,22 +512,19 @@ MultiClientQueryStatus(int32 connectionId)
CopyStatus CopyStatus
MultiClientCopyData(int32 connectionId, int32 fileDescriptor, uint64 *returnBytesReceived) MultiClientCopyData(int32 connectionId, int32 fileDescriptor, uint64 *returnBytesReceived)
{ {
MultiConnection *connection = NULL;
char *receiveBuffer = NULL; char *receiveBuffer = NULL;
int consumed = 0;
int receiveLength = 0;
const int asynchronous = 1; const int asynchronous = 1;
CopyStatus copyStatus = CLIENT_INVALID_COPY; CopyStatus copyStatus = CLIENT_INVALID_COPY;
Assert(connectionId != INVALID_CONNECTION_ID); Assert(connectionId != INVALID_CONNECTION_ID);
connection = ClientConnectionArray[connectionId]; MultiConnection *connection = ClientConnectionArray[connectionId];
Assert(connection != NULL); Assert(connection != NULL);
/* /*
* Consume input to handle the case where previous copy operation might have * Consume input to handle the case where previous copy operation might have
* received zero bytes. * received zero bytes.
*/ */
consumed = PQconsumeInput(connection->pgConn); int consumed = PQconsumeInput(connection->pgConn);
if (consumed == 0) if (consumed == 0)
{ {
ereport(WARNING, (errmsg("could not read data from worker node"))); ereport(WARNING, (errmsg("could not read data from worker node")));
@ -559,11 +532,10 @@ MultiClientCopyData(int32 connectionId, int32 fileDescriptor, uint64 *returnByte
} }
/* receive copy data message in an asynchronous manner */ /* receive copy data message in an asynchronous manner */
receiveLength = PQgetCopyData(connection->pgConn, &receiveBuffer, asynchronous); int receiveLength = PQgetCopyData(connection->pgConn, &receiveBuffer, asynchronous);
while (receiveLength > 0) while (receiveLength > 0)
{ {
/* received copy data; append these data to file */ /* received copy data; append these data to file */
int appended = -1;
errno = 0; errno = 0;
if (returnBytesReceived) if (returnBytesReceived)
@ -571,7 +543,7 @@ MultiClientCopyData(int32 connectionId, int32 fileDescriptor, uint64 *returnByte
*returnBytesReceived += receiveLength; *returnBytesReceived += receiveLength;
} }
appended = write(fileDescriptor, receiveBuffer, receiveLength); int appended = write(fileDescriptor, receiveBuffer, receiveLength);
if (appended != receiveLength) if (appended != receiveLength)
{ {
/* if write didn't set errno, assume problem is no disk space */ /* if write didn't set errno, assume problem is no disk space */

View File

@ -196,9 +196,6 @@ TupleTableSlot *
ReturnTupleFromTuplestore(CitusScanState *scanState) ReturnTupleFromTuplestore(CitusScanState *scanState)
{ {
Tuplestorestate *tupleStore = scanState->tuplestorestate; Tuplestorestate *tupleStore = scanState->tuplestorestate;
TupleTableSlot *resultSlot = NULL;
EState *executorState = NULL;
ScanDirection scanDirection = NoMovementScanDirection;
bool forwardScanDirection = true; bool forwardScanDirection = true;
if (tupleStore == NULL) if (tupleStore == NULL)
@ -206,8 +203,8 @@ ReturnTupleFromTuplestore(CitusScanState *scanState)
return NULL; return NULL;
} }
executorState = ScanStateGetExecutorState(scanState); EState *executorState = ScanStateGetExecutorState(scanState);
scanDirection = executorState->es_direction; ScanDirection scanDirection = executorState->es_direction;
Assert(ScanDirectionIsValid(scanDirection)); Assert(ScanDirectionIsValid(scanDirection));
if (ScanDirectionIsBackward(scanDirection)) if (ScanDirectionIsBackward(scanDirection))
@ -215,7 +212,7 @@ ReturnTupleFromTuplestore(CitusScanState *scanState)
forwardScanDirection = false; forwardScanDirection = false;
} }
resultSlot = scanState->customScanState.ss.ps.ps_ResultTupleSlot; TupleTableSlot *resultSlot = scanState->customScanState.ss.ps.ps_ResultTupleSlot;
tuplestore_gettupleslot(tupleStore, forwardScanDirection, false, resultSlot); tuplestore_gettupleslot(tupleStore, forwardScanDirection, false, resultSlot);
return resultSlot; return resultSlot;
@ -234,13 +231,12 @@ void
LoadTuplesIntoTupleStore(CitusScanState *citusScanState, Job *workerJob) LoadTuplesIntoTupleStore(CitusScanState *citusScanState, Job *workerJob)
{ {
List *workerTaskList = workerJob->taskList; List *workerTaskList = workerJob->taskList;
TupleDesc tupleDescriptor = NULL;
ListCell *workerTaskCell = NULL; ListCell *workerTaskCell = NULL;
bool randomAccess = true; bool randomAccess = true;
bool interTransactions = false; bool interTransactions = false;
char *copyFormat = "text"; char *copyFormat = "text";
tupleDescriptor = ScanStateGetTupleDescriptor(citusScanState); TupleDesc tupleDescriptor = ScanStateGetTupleDescriptor(citusScanState);
Assert(citusScanState->tuplestorestate == NULL); Assert(citusScanState->tuplestorestate == NULL);
citusScanState->tuplestorestate = citusScanState->tuplestorestate =
@ -254,11 +250,9 @@ LoadTuplesIntoTupleStore(CitusScanState *citusScanState, Job *workerJob)
foreach(workerTaskCell, workerTaskList) foreach(workerTaskCell, workerTaskList)
{ {
Task *workerTask = (Task *) lfirst(workerTaskCell); Task *workerTask = (Task *) lfirst(workerTaskCell);
StringInfo jobDirectoryName = NULL;
StringInfo taskFilename = NULL;
jobDirectoryName = MasterJobDirectoryName(workerTask->jobId); StringInfo jobDirectoryName = MasterJobDirectoryName(workerTask->jobId);
taskFilename = TaskFilename(jobDirectoryName, workerTask->taskId); StringInfo taskFilename = TaskFilename(jobDirectoryName, workerTask->taskId);
ReadFileIntoTupleStore(taskFilename->data, copyFormat, tupleDescriptor, ReadFileIntoTupleStore(taskFilename->data, copyFormat, tupleDescriptor,
citusScanState->tuplestorestate); citusScanState->tuplestorestate);
@ -277,8 +271,6 @@ void
ReadFileIntoTupleStore(char *fileName, char *copyFormat, TupleDesc tupleDescriptor, ReadFileIntoTupleStore(char *fileName, char *copyFormat, TupleDesc tupleDescriptor,
Tuplestorestate *tupstore) Tuplestorestate *tupstore)
{ {
CopyState copyState = NULL;
/* /*
* Trick BeginCopyFrom into using our tuple descriptor by pretending it belongs * Trick BeginCopyFrom into using our tuple descriptor by pretending it belongs
* to a relation. * to a relation.
@ -293,25 +285,22 @@ ReadFileIntoTupleStore(char *fileName, char *copyFormat, TupleDesc tupleDescript
Datum *columnValues = palloc0(columnCount * sizeof(Datum)); Datum *columnValues = palloc0(columnCount * sizeof(Datum));
bool *columnNulls = palloc0(columnCount * sizeof(bool)); bool *columnNulls = palloc0(columnCount * sizeof(bool));
DefElem *copyOption = NULL;
List *copyOptions = NIL; List *copyOptions = NIL;
int location = -1; /* "unknown" token location */ int location = -1; /* "unknown" token location */
copyOption = makeDefElem("format", (Node *) makeString(copyFormat), location); DefElem *copyOption = makeDefElem("format", (Node *) makeString(copyFormat),
location);
copyOptions = lappend(copyOptions, copyOption); copyOptions = lappend(copyOptions, copyOption);
copyState = BeginCopyFrom(NULL, stubRelation, fileName, false, NULL, CopyState copyState = BeginCopyFrom(NULL, stubRelation, fileName, false, NULL,
NULL, copyOptions); NULL, copyOptions);
while (true) while (true)
{ {
MemoryContext oldContext = NULL;
bool nextRowFound = false;
ResetPerTupleExprContext(executorState); ResetPerTupleExprContext(executorState);
oldContext = MemoryContextSwitchTo(executorTupleContext); MemoryContext oldContext = MemoryContextSwitchTo(executorTupleContext);
nextRowFound = NextCopyFromCompat(copyState, executorExpressionContext, bool nextRowFound = NextCopyFromCompat(copyState, executorExpressionContext,
columnValues, columnNulls); columnValues, columnNulls);
if (!nextRowFound) if (!nextRowFound)
{ {
@ -355,7 +344,6 @@ SortTupleStore(CitusScanState *scanState)
ListCell *targetCell = NULL; ListCell *targetCell = NULL;
int sortKeyIndex = 0; int sortKeyIndex = 0;
Tuplesortstate *tuplesortstate = NULL;
/* /*
* Iterate on the returning target list and generate the necessary information * Iterate on the returning target list and generate the necessary information
@ -380,7 +368,7 @@ SortTupleStore(CitusScanState *scanState)
sortKeyIndex++; sortKeyIndex++;
} }
tuplesortstate = Tuplesortstate *tuplesortstate =
tuplesort_begin_heap(tupleDescriptor, numberOfSortKeys, sortColIdx, sortOperators, tuplesort_begin_heap(tupleDescriptor, numberOfSortKeys, sortColIdx, sortOperators,
collations, nullsFirst, work_mem, NULL, false); collations, nullsFirst, work_mem, NULL, false);
@ -467,7 +455,6 @@ ExecuteQueryStringIntoDestReceiver(const char *queryString, ParamListInfo params
Query * Query *
ParseQueryString(const char *queryString, Oid *paramOids, int numParams) ParseQueryString(const char *queryString, Oid *paramOids, int numParams)
{ {
Query *query = NULL;
RawStmt *rawStmt = (RawStmt *) ParseTreeRawStmt(queryString); RawStmt *rawStmt = (RawStmt *) ParseTreeRawStmt(queryString);
List *queryTreeList = List *queryTreeList =
pg_analyze_and_rewrite(rawStmt, queryString, paramOids, numParams, NULL); pg_analyze_and_rewrite(rawStmt, queryString, paramOids, numParams, NULL);
@ -477,7 +464,7 @@ ParseQueryString(const char *queryString, Oid *paramOids, int numParams)
ereport(ERROR, (errmsg("can only execute a single query"))); ereport(ERROR, (errmsg("can only execute a single query")));
} }
query = (Query *) linitial(queryTreeList); Query *query = (Query *) linitial(queryTreeList);
return query; return query;
} }
@ -490,13 +477,10 @@ ParseQueryString(const char *queryString, Oid *paramOids, int numParams)
void void
ExecuteQueryIntoDestReceiver(Query *query, ParamListInfo params, DestReceiver *dest) ExecuteQueryIntoDestReceiver(Query *query, ParamListInfo params, DestReceiver *dest)
{ {
PlannedStmt *queryPlan = NULL; int cursorOptions = CURSOR_OPT_PARALLEL_OK;
int cursorOptions = 0;
cursorOptions = CURSOR_OPT_PARALLEL_OK;
/* plan the subquery, this may be another distributed query */ /* plan the subquery, this may be another distributed query */
queryPlan = pg_plan_query(query, cursorOptions, params); PlannedStmt *queryPlan = pg_plan_query(query, cursorOptions, params);
ExecutePlanIntoDestReceiver(queryPlan, params, dest); ExecutePlanIntoDestReceiver(queryPlan, params, dest);
} }
@ -510,12 +494,11 @@ void
ExecutePlanIntoDestReceiver(PlannedStmt *queryPlan, ParamListInfo params, ExecutePlanIntoDestReceiver(PlannedStmt *queryPlan, ParamListInfo params,
DestReceiver *dest) DestReceiver *dest)
{ {
Portal portal = NULL;
int eflags = 0; int eflags = 0;
long count = FETCH_ALL; long count = FETCH_ALL;
/* create a new portal for executing the query */ /* create a new portal for executing the query */
portal = CreateNewPortal(); Portal portal = CreateNewPortal();
/* don't display the portal in pg_cursors, it is for internal use only */ /* don't display the portal in pg_cursors, it is for internal use only */
portal->visible = false; portal->visible = false;

View File

@ -170,7 +170,6 @@ InitTaskExecution(Task *task, TaskExecStatus initialTaskExecStatus)
{ {
/* each task placement (assignment) corresponds to one worker node */ /* each task placement (assignment) corresponds to one worker node */
uint32 nodeCount = list_length(task->taskPlacementList); uint32 nodeCount = list_length(task->taskPlacementList);
uint32 nodeIndex = 0;
TaskExecution *taskExecution = CitusMakeNode(TaskExecution); TaskExecution *taskExecution = CitusMakeNode(TaskExecution);
@ -185,7 +184,7 @@ InitTaskExecution(Task *task, TaskExecStatus initialTaskExecStatus)
taskExecution->connectionIdArray = palloc0(nodeCount * sizeof(int32)); taskExecution->connectionIdArray = palloc0(nodeCount * sizeof(int32));
taskExecution->fileDescriptorArray = palloc0(nodeCount * sizeof(int32)); taskExecution->fileDescriptorArray = palloc0(nodeCount * sizeof(int32));
for (nodeIndex = 0; nodeIndex < nodeCount; nodeIndex++) for (uint32 nodeIndex = 0; nodeIndex < nodeCount; nodeIndex++)
{ {
taskExecution->taskStatusArray[nodeIndex] = initialTaskExecStatus; taskExecution->taskStatusArray[nodeIndex] = initialTaskExecStatus;
taskExecution->transmitStatusArray[nodeIndex] = EXEC_TRANSMIT_UNASSIGNED; taskExecution->transmitStatusArray[nodeIndex] = EXEC_TRANSMIT_UNASSIGNED;
@ -205,8 +204,7 @@ InitTaskExecution(Task *task, TaskExecStatus initialTaskExecStatus)
void void
CleanupTaskExecution(TaskExecution *taskExecution) CleanupTaskExecution(TaskExecution *taskExecution)
{ {
uint32 nodeIndex = 0; for (uint32 nodeIndex = 0; nodeIndex < taskExecution->nodeCount; nodeIndex++)
for (nodeIndex = 0; nodeIndex < taskExecution->nodeCount; nodeIndex++)
{ {
int32 connectionId = taskExecution->connectionIdArray[nodeIndex]; int32 connectionId = taskExecution->connectionIdArray[nodeIndex];
int32 fileDescriptor = taskExecution->fileDescriptorArray[nodeIndex]; int32 fileDescriptor = taskExecution->fileDescriptorArray[nodeIndex];
@ -284,14 +282,12 @@ AdjustStateForFailure(TaskExecution *taskExecution)
bool bool
CheckIfSizeLimitIsExceeded(DistributedExecutionStats *executionStats) CheckIfSizeLimitIsExceeded(DistributedExecutionStats *executionStats)
{ {
uint64 maxIntermediateResultInBytes = 0;
if (!SubPlanLevel || MaxIntermediateResult < 0) if (!SubPlanLevel || MaxIntermediateResult < 0)
{ {
return false; return false;
} }
maxIntermediateResultInBytes = MaxIntermediateResult * 1024L; uint64 maxIntermediateResultInBytes = MaxIntermediateResult * 1024L;
if (executionStats->totalIntermediateResultSize < maxIntermediateResultInBytes) if (executionStats->totalIntermediateResultSize < maxIntermediateResultInBytes)
{ {
return false; return false;

View File

@ -128,16 +128,16 @@ BuildPlacementAccessList(int32 groupId, List *relationShardList,
foreach(relationShardCell, relationShardList) foreach(relationShardCell, relationShardList)
{ {
RelationShard *relationShard = (RelationShard *) lfirst(relationShardCell); RelationShard *relationShard = (RelationShard *) lfirst(relationShardCell);
ShardPlacement *placement = NULL;
ShardPlacementAccess *placementAccess = NULL;
placement = FindShardPlacementOnGroup(groupId, relationShard->shardId); ShardPlacement *placement = FindShardPlacementOnGroup(groupId,
relationShard->shardId);
if (placement == NULL) if (placement == NULL)
{ {
continue; continue;
} }
placementAccess = CreatePlacementAccess(placement, accessType); ShardPlacementAccess *placementAccess = CreatePlacementAccess(placement,
accessType);
placementAccessList = lappend(placementAccessList, placementAccess); placementAccessList = lappend(placementAccessList, placementAccess);
} }
@ -152,9 +152,8 @@ BuildPlacementAccessList(int32 groupId, List *relationShardList,
ShardPlacementAccess * ShardPlacementAccess *
CreatePlacementAccess(ShardPlacement *placement, ShardPlacementAccessType accessType) CreatePlacementAccess(ShardPlacement *placement, ShardPlacementAccessType accessType)
{ {
ShardPlacementAccess *placementAccess = NULL; ShardPlacementAccess *placementAccess = (ShardPlacementAccess *) palloc0(
sizeof(ShardPlacementAccess));
placementAccess = (ShardPlacementAccess *) palloc0(sizeof(ShardPlacementAccess));
placementAccess->placement = placement; placementAccess->placement = placement;
placementAccess->accessType = accessType; placementAccess->accessType = accessType;

View File

@ -36,7 +36,6 @@ ExecuteSubPlans(DistributedPlan *distributedPlan)
uint64 planId = distributedPlan->planId; uint64 planId = distributedPlan->planId;
List *subPlanList = distributedPlan->subPlanList; List *subPlanList = distributedPlan->subPlanList;
ListCell *subPlanCell = NULL; ListCell *subPlanCell = NULL;
HTAB *intermediateResultsHash = NULL;
if (subPlanList == NIL) if (subPlanList == NIL)
{ {
@ -44,7 +43,7 @@ ExecuteSubPlans(DistributedPlan *distributedPlan)
return; return;
} }
intermediateResultsHash = MakeIntermediateResultHTAB(); HTAB *intermediateResultsHash = MakeIntermediateResultHTAB();
RecordSubplanExecutionsOnNodes(intermediateResultsHash, distributedPlan); RecordSubplanExecutionsOnNodes(intermediateResultsHash, distributedPlan);
@ -61,9 +60,7 @@ ExecuteSubPlans(DistributedPlan *distributedPlan)
DistributedSubPlan *subPlan = (DistributedSubPlan *) lfirst(subPlanCell); DistributedSubPlan *subPlan = (DistributedSubPlan *) lfirst(subPlanCell);
PlannedStmt *plannedStmt = subPlan->plan; PlannedStmt *plannedStmt = subPlan->plan;
uint32 subPlanId = subPlan->subPlanId; uint32 subPlanId = subPlan->subPlanId;
DestReceiver *copyDest = NULL;
ParamListInfo params = NULL; ParamListInfo params = NULL;
EState *estate = NULL;
bool writeLocalFile = false; bool writeLocalFile = false;
char *resultId = GenerateResultId(planId, subPlanId); char *resultId = GenerateResultId(planId, subPlanId);
List *workerNodeList = List *workerNodeList =
@ -94,8 +91,9 @@ ExecuteSubPlans(DistributedPlan *distributedPlan)
} }
SubPlanLevel++; SubPlanLevel++;
estate = CreateExecutorState(); EState *estate = CreateExecutorState();
copyDest = CreateRemoteFileDestReceiver(resultId, estate, workerNodeList, DestReceiver *copyDest = CreateRemoteFileDestReceiver(resultId, estate,
workerNodeList,
writeLocalFile); writeLocalFile);
ExecutePlanIntoDestReceiver(plannedStmt, params, copyDest); ExecutePlanIntoDestReceiver(plannedStmt, params, copyDest);

View File

@ -49,9 +49,6 @@ Datum
citus_create_restore_point(PG_FUNCTION_ARGS) citus_create_restore_point(PG_FUNCTION_ARGS)
{ {
text *restoreNameText = PG_GETARG_TEXT_P(0); text *restoreNameText = PG_GETARG_TEXT_P(0);
char *restoreNameString = NULL;
XLogRecPtr localRestorePoint = InvalidXLogRecPtr;
List *connectionList = NIL;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
EnsureSuperUser(); EnsureSuperUser();
@ -74,7 +71,7 @@ citus_create_restore_point(PG_FUNCTION_ARGS)
"start."))); "start.")));
} }
restoreNameString = text_to_cstring(restoreNameText); char *restoreNameString = text_to_cstring(restoreNameText);
if (strlen(restoreNameString) >= MAXFNAMELEN) if (strlen(restoreNameString) >= MAXFNAMELEN)
{ {
ereport(ERROR, ereport(ERROR,
@ -87,7 +84,7 @@ citus_create_restore_point(PG_FUNCTION_ARGS)
* establish connections to all nodes before taking any locks * establish connections to all nodes before taking any locks
* ShareLock prevents new nodes being added, rendering connectionList incomplete * ShareLock prevents new nodes being added, rendering connectionList incomplete
*/ */
connectionList = OpenConnectionsToAllWorkerNodes(ShareLock); List *connectionList = OpenConnectionsToAllWorkerNodes(ShareLock);
/* /*
* Send a BEGIN to bust through pgbouncer. We won't actually commit since * Send a BEGIN to bust through pgbouncer. We won't actually commit since
@ -100,7 +97,7 @@ citus_create_restore_point(PG_FUNCTION_ARGS)
BlockDistributedTransactions(); BlockDistributedTransactions();
/* do local restore point first to bail out early if something goes wrong */ /* do local restore point first to bail out early if something goes wrong */
localRestorePoint = XLogRestorePoint(restoreNameString); XLogRecPtr localRestorePoint = XLogRestorePoint(restoreNameString);
/* run pg_create_restore_point on all nodes */ /* run pg_create_restore_point on all nodes */
CreateRemoteRestorePoints(restoreNameString, connectionList); CreateRemoteRestorePoints(restoreNameString, connectionList);
@ -117,18 +114,17 @@ static List *
OpenConnectionsToAllWorkerNodes(LOCKMODE lockMode) OpenConnectionsToAllWorkerNodes(LOCKMODE lockMode)
{ {
List *connectionList = NIL; List *connectionList = NIL;
List *workerNodeList = NIL;
ListCell *workerNodeCell = NULL; ListCell *workerNodeCell = NULL;
int connectionFlags = FORCE_NEW_CONNECTION; int connectionFlags = FORCE_NEW_CONNECTION;
workerNodeList = ActivePrimaryWorkerNodeList(lockMode); List *workerNodeList = ActivePrimaryWorkerNodeList(lockMode);
foreach(workerNodeCell, workerNodeList) foreach(workerNodeCell, workerNodeList)
{ {
WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell); WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell);
MultiConnection *connection = NULL;
connection = StartNodeConnection(connectionFlags, workerNode->workerName, MultiConnection *connection = StartNodeConnection(connectionFlags,
workerNode->workerName,
workerNode->workerPort); workerNode->workerPort);
MarkRemoteTransactionCritical(connection); MarkRemoteTransactionCritical(connection);

View File

@ -72,18 +72,10 @@ Datum
master_run_on_worker(PG_FUNCTION_ARGS) master_run_on_worker(PG_FUNCTION_ARGS)
{ {
ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo; ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
MemoryContext per_query_ctx = NULL;
MemoryContext oldcontext = NULL;
TupleDesc tupleDescriptor = NULL;
Tuplestorestate *tupleStore = NULL;
bool parallelExecution = false; bool parallelExecution = false;
StringInfo *nodeNameArray = NULL; StringInfo *nodeNameArray = NULL;
int *nodePortArray = NULL; int *nodePortArray = NULL;
StringInfo *commandStringArray = NULL; StringInfo *commandStringArray = NULL;
bool *statusArray = NULL;
StringInfo *resultArray = NULL;
int commandIndex = 0;
int commandCount = 0;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
@ -96,14 +88,14 @@ master_run_on_worker(PG_FUNCTION_ARGS)
"allowed in this context"))); "allowed in this context")));
} }
commandCount = ParseCommandParameters(fcinfo, &nodeNameArray, &nodePortArray, int commandCount = ParseCommandParameters(fcinfo, &nodeNameArray, &nodePortArray,
&commandStringArray, &parallelExecution); &commandStringArray, &parallelExecution);
per_query_ctx = rsinfo->econtext->ecxt_per_query_memory; MemoryContext per_query_ctx = rsinfo->econtext->ecxt_per_query_memory;
oldcontext = MemoryContextSwitchTo(per_query_ctx); MemoryContext oldcontext = MemoryContextSwitchTo(per_query_ctx);
/* get the requested return tuple description */ /* get the requested return tuple description */
tupleDescriptor = CreateTupleDescCopy(rsinfo->expectedDesc); TupleDesc tupleDescriptor = CreateTupleDescCopy(rsinfo->expectedDesc);
/* /*
* Check to make sure we have correct tuple descriptor * Check to make sure we have correct tuple descriptor
@ -121,9 +113,9 @@ master_run_on_worker(PG_FUNCTION_ARGS)
} }
/* prepare storage for status and result values */ /* prepare storage for status and result values */
statusArray = palloc0(commandCount * sizeof(bool)); bool *statusArray = palloc0(commandCount * sizeof(bool));
resultArray = palloc0(commandCount * sizeof(StringInfo)); StringInfo *resultArray = palloc0(commandCount * sizeof(StringInfo));
for (commandIndex = 0; commandIndex < commandCount; commandIndex++) for (int commandIndex = 0; commandIndex < commandCount; commandIndex++)
{ {
resultArray[commandIndex] = makeStringInfo(); resultArray[commandIndex] = makeStringInfo();
} }
@ -142,8 +134,9 @@ master_run_on_worker(PG_FUNCTION_ARGS)
/* let the caller know we're sending back a tuplestore */ /* let the caller know we're sending back a tuplestore */
rsinfo->returnMode = SFRM_Materialize; rsinfo->returnMode = SFRM_Materialize;
tupleStore = CreateTupleStore(tupleDescriptor, Tuplestorestate *tupleStore = CreateTupleStore(tupleDescriptor,
nodeNameArray, nodePortArray, statusArray, nodeNameArray, nodePortArray,
statusArray,
resultArray, commandCount); resultArray, commandCount);
rsinfo->setResult = tupleStore; rsinfo->setResult = tupleStore;
rsinfo->setDesc = tupleDescriptor; rsinfo->setDesc = tupleDescriptor;
@ -170,10 +163,6 @@ ParseCommandParameters(FunctionCallInfo fcinfo, StringInfo **nodeNameArray,
Datum *nodeNameDatumArray = DeconstructArrayObject(nodeNameArrayObject); Datum *nodeNameDatumArray = DeconstructArrayObject(nodeNameArrayObject);
Datum *nodePortDatumArray = DeconstructArrayObject(nodePortArrayObject); Datum *nodePortDatumArray = DeconstructArrayObject(nodePortArrayObject);
Datum *commandStringDatumArray = DeconstructArrayObject(commandStringArrayObject); Datum *commandStringDatumArray = DeconstructArrayObject(commandStringArrayObject);
int index = 0;
StringInfo *nodeNames = NULL;
int *nodePorts = NULL;
StringInfo *commandStrings = NULL;
if (nodeNameCount != nodePortCount || nodeNameCount != commandStringCount) if (nodeNameCount != nodePortCount || nodeNameCount != commandStringCount)
{ {
@ -182,11 +171,11 @@ ParseCommandParameters(FunctionCallInfo fcinfo, StringInfo **nodeNameArray,
errmsg("expected same number of node name, port, and query string"))); errmsg("expected same number of node name, port, and query string")));
} }
nodeNames = palloc0(nodeNameCount * sizeof(StringInfo)); StringInfo *nodeNames = palloc0(nodeNameCount * sizeof(StringInfo));
nodePorts = palloc0(nodeNameCount * sizeof(int)); int *nodePorts = palloc0(nodeNameCount * sizeof(int));
commandStrings = palloc0(nodeNameCount * sizeof(StringInfo)); StringInfo *commandStrings = palloc0(nodeNameCount * sizeof(StringInfo));
for (index = 0; index < nodeNameCount; index++) for (int index = 0; index < nodeNameCount; index++)
{ {
text *nodeNameText = DatumGetTextP(nodeNameDatumArray[index]); text *nodeNameText = DatumGetTextP(nodeNameDatumArray[index]);
char *nodeName = text_to_cstring(nodeNameText); char *nodeName = text_to_cstring(nodeNameText);
@ -224,13 +213,12 @@ ExecuteCommandsInParallelAndStoreResults(StringInfo *nodeNameArray, int *nodePor
bool *statusArray, StringInfo *resultStringArray, bool *statusArray, StringInfo *resultStringArray,
int commmandCount) int commmandCount)
{ {
int commandIndex = 0;
MultiConnection **connectionArray = MultiConnection **connectionArray =
palloc0(commmandCount * sizeof(MultiConnection *)); palloc0(commmandCount * sizeof(MultiConnection *));
int finishedCount = 0; int finishedCount = 0;
/* start connections asynchronously */ /* start connections asynchronously */
for (commandIndex = 0; commandIndex < commmandCount; commandIndex++) for (int commandIndex = 0; commandIndex < commmandCount; commandIndex++)
{ {
char *nodeName = nodeNameArray[commandIndex]->data; char *nodeName = nodeNameArray[commandIndex]->data;
int nodePort = nodePortArray[commandIndex]; int nodePort = nodePortArray[commandIndex];
@ -240,7 +228,7 @@ ExecuteCommandsInParallelAndStoreResults(StringInfo *nodeNameArray, int *nodePor
} }
/* establish connections */ /* establish connections */
for (commandIndex = 0; commandIndex < commmandCount; commandIndex++) for (int commandIndex = 0; commandIndex < commmandCount; commandIndex++)
{ {
MultiConnection *connection = connectionArray[commandIndex]; MultiConnection *connection = connectionArray[commandIndex];
StringInfo queryResultString = resultStringArray[commandIndex]; StringInfo queryResultString = resultStringArray[commandIndex];
@ -264,9 +252,8 @@ ExecuteCommandsInParallelAndStoreResults(StringInfo *nodeNameArray, int *nodePor
} }
/* send queries at once */ /* send queries at once */
for (commandIndex = 0; commandIndex < commmandCount; commandIndex++) for (int commandIndex = 0; commandIndex < commmandCount; commandIndex++)
{ {
int querySent = 0;
MultiConnection *connection = connectionArray[commandIndex]; MultiConnection *connection = connectionArray[commandIndex];
char *queryString = commandStringArray[commandIndex]->data; char *queryString = commandStringArray[commandIndex]->data;
StringInfo queryResultString = resultStringArray[commandIndex]; StringInfo queryResultString = resultStringArray[commandIndex];
@ -280,7 +267,7 @@ ExecuteCommandsInParallelAndStoreResults(StringInfo *nodeNameArray, int *nodePor
continue; continue;
} }
querySent = SendRemoteCommand(connection, queryString); int querySent = SendRemoteCommand(connection, queryString);
if (querySent == 0) if (querySent == 0)
{ {
StoreErrorMessage(connection, queryResultString); StoreErrorMessage(connection, queryResultString);
@ -294,19 +281,18 @@ ExecuteCommandsInParallelAndStoreResults(StringInfo *nodeNameArray, int *nodePor
/* check for query results */ /* check for query results */
while (finishedCount < commmandCount) while (finishedCount < commmandCount)
{ {
for (commandIndex = 0; commandIndex < commmandCount; commandIndex++) for (int commandIndex = 0; commandIndex < commmandCount; commandIndex++)
{ {
MultiConnection *connection = connectionArray[commandIndex]; MultiConnection *connection = connectionArray[commandIndex];
StringInfo queryResultString = resultStringArray[commandIndex]; StringInfo queryResultString = resultStringArray[commandIndex];
bool success = false; bool success = false;
bool queryFinished = false;
if (connection == NULL) if (connection == NULL)
{ {
continue; continue;
} }
queryFinished = GetConnectionStatusAndResult(connection, &success, bool queryFinished = GetConnectionStatusAndResult(connection, &success,
queryResultString); queryResultString);
if (queryFinished) if (queryFinished)
@ -343,9 +329,6 @@ GetConnectionStatusAndResult(MultiConnection *connection, bool *resultStatus,
{ {
bool finished = true; bool finished = true;
ConnStatusType connectionStatus = PQstatus(connection->pgConn); ConnStatusType connectionStatus = PQstatus(connection->pgConn);
int consumeInput = 0;
PGresult *queryResult = NULL;
bool success = false;
*resultStatus = false; *resultStatus = false;
resetStringInfo(queryResultString); resetStringInfo(queryResultString);
@ -356,7 +339,7 @@ GetConnectionStatusAndResult(MultiConnection *connection, bool *resultStatus,
return finished; return finished;
} }
consumeInput = PQconsumeInput(connection->pgConn); int consumeInput = PQconsumeInput(connection->pgConn);
if (consumeInput == 0) if (consumeInput == 0)
{ {
appendStringInfo(queryResultString, "query result unavailable"); appendStringInfo(queryResultString, "query result unavailable");
@ -371,8 +354,8 @@ GetConnectionStatusAndResult(MultiConnection *connection, bool *resultStatus,
} }
/* query result is available at this point */ /* query result is available at this point */
queryResult = PQgetResult(connection->pgConn); PGresult *queryResult = PQgetResult(connection->pgConn);
success = EvaluateQueryResult(connection, queryResult, queryResultString); bool success = EvaluateQueryResult(connection, queryResult, queryResultString);
PQclear(queryResult); PQclear(queryResult);
*resultStatus = success; *resultStatus = success;
@ -449,12 +432,10 @@ StoreErrorMessage(MultiConnection *connection, StringInfo queryResultString)
char *errorMessage = PQerrorMessage(connection->pgConn); char *errorMessage = PQerrorMessage(connection->pgConn);
if (errorMessage != NULL) if (errorMessage != NULL)
{ {
char *firstNewlineIndex = NULL;
/* copy the error message to a writable memory */ /* copy the error message to a writable memory */
errorMessage = pnstrdup(errorMessage, strlen(errorMessage)); errorMessage = pnstrdup(errorMessage, strlen(errorMessage));
firstNewlineIndex = strchr(errorMessage, '\n'); char *firstNewlineIndex = strchr(errorMessage, '\n');
/* trim the error message at the line break */ /* trim the error message at the line break */
if (firstNewlineIndex != NULL) if (firstNewlineIndex != NULL)
@ -484,16 +465,14 @@ ExecuteCommandsAndStoreResults(StringInfo *nodeNameArray, int *nodePortArray,
StringInfo *commandStringArray, bool *statusArray, StringInfo *commandStringArray, bool *statusArray,
StringInfo *resultStringArray, int commmandCount) StringInfo *resultStringArray, int commmandCount)
{ {
int commandIndex = 0; for (int commandIndex = 0; commandIndex < commmandCount; commandIndex++)
for (commandIndex = 0; commandIndex < commmandCount; commandIndex++)
{ {
char *nodeName = nodeNameArray[commandIndex]->data; char *nodeName = nodeNameArray[commandIndex]->data;
int32 nodePort = nodePortArray[commandIndex]; int32 nodePort = nodePortArray[commandIndex];
bool success = false;
char *queryString = commandStringArray[commandIndex]->data; char *queryString = commandStringArray[commandIndex]->data;
StringInfo queryResultString = resultStringArray[commandIndex]; StringInfo queryResultString = resultStringArray[commandIndex];
success = ExecuteRemoteQueryOrCommand(nodeName, nodePort, queryString, bool success = ExecuteRemoteQueryOrCommand(nodeName, nodePort, queryString,
queryResultString); queryResultString);
statusArray[commandIndex] = success; statusArray[commandIndex] = success;
@ -516,8 +495,6 @@ ExecuteRemoteQueryOrCommand(char *nodeName, uint32 nodePort, char *queryString,
int connectionFlags = FORCE_NEW_CONNECTION; int connectionFlags = FORCE_NEW_CONNECTION;
MultiConnection *connection = MultiConnection *connection =
GetNodeConnection(connectionFlags, nodeName, nodePort); GetNodeConnection(connectionFlags, nodeName, nodePort);
bool success = false;
PGresult *queryResult = NULL;
bool raiseInterrupts = true; bool raiseInterrupts = true;
if (PQstatus(connection->pgConn) != CONNECTION_OK) if (PQstatus(connection->pgConn) != CONNECTION_OK)
@ -528,8 +505,8 @@ ExecuteRemoteQueryOrCommand(char *nodeName, uint32 nodePort, char *queryString,
} }
SendRemoteCommand(connection, queryString); SendRemoteCommand(connection, queryString);
queryResult = GetRemoteCommandResult(connection, raiseInterrupts); PGresult *queryResult = GetRemoteCommandResult(connection, raiseInterrupts);
success = EvaluateQueryResult(connection, queryResult, queryResultString); bool success = EvaluateQueryResult(connection, queryResult, queryResultString);
PQclear(queryResult); PQclear(queryResult);
@ -547,13 +524,11 @@ CreateTupleStore(TupleDesc tupleDescriptor,
StringInfo *resultArray, int commandCount) StringInfo *resultArray, int commandCount)
{ {
Tuplestorestate *tupleStore = tuplestore_begin_heap(true, false, work_mem); Tuplestorestate *tupleStore = tuplestore_begin_heap(true, false, work_mem);
int commandIndex = 0;
bool nulls[4] = { false, false, false, false }; bool nulls[4] = { false, false, false, false };
for (commandIndex = 0; commandIndex < commandCount; commandIndex++) for (int commandIndex = 0; commandIndex < commandCount; commandIndex++)
{ {
Datum values[4]; Datum values[4];
HeapTuple tuple = NULL;
StringInfo nodeNameString = nodeNameArray[commandIndex]; StringInfo nodeNameString = nodeNameArray[commandIndex];
StringInfo resultString = resultArray[commandIndex]; StringInfo resultString = resultArray[commandIndex];
text *nodeNameText = cstring_to_text_with_len(nodeNameString->data, text *nodeNameText = cstring_to_text_with_len(nodeNameString->data,
@ -566,7 +541,7 @@ CreateTupleStore(TupleDesc tupleDescriptor,
values[2] = BoolGetDatum(statusArray[commandIndex]); values[2] = BoolGetDatum(statusArray[commandIndex]);
values[3] = PointerGetDatum(resultText); values[3] = PointerGetDatum(resultText);
tuple = heap_form_tuple(tupleDescriptor, values, nulls); HeapTuple tuple = heap_form_tuple(tupleDescriptor, values, nulls);
tuplestore_puttuple(tupleStore, tuple); tuplestore_puttuple(tupleStore, tuple);
heap_freetuple(tuple); heap_freetuple(tuple);

View File

@ -106,13 +106,6 @@ void
CreateShardsWithRoundRobinPolicy(Oid distributedTableId, int32 shardCount, CreateShardsWithRoundRobinPolicy(Oid distributedTableId, int32 shardCount,
int32 replicationFactor, bool useExclusiveConnections) int32 replicationFactor, bool useExclusiveConnections)
{ {
char shardStorageType = 0;
List *workerNodeList = NIL;
int32 workerNodeCount = 0;
uint32 placementAttemptCount = 0;
uint64 hashTokenIncrement = 0;
List *existingShardList = NIL;
int64 shardIndex = 0;
DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId);
bool colocatedShard = false; bool colocatedShard = false;
List *insertedShardPlacements = NIL; List *insertedShardPlacements = NIL;
@ -132,7 +125,7 @@ CreateShardsWithRoundRobinPolicy(Oid distributedTableId, int32 shardCount,
LockRelationOid(distributedTableId, ExclusiveLock); LockRelationOid(distributedTableId, ExclusiveLock);
/* validate that shards haven't already been created for this table */ /* validate that shards haven't already been created for this table */
existingShardList = LoadShardList(distributedTableId); List *existingShardList = LoadShardList(distributedTableId);
if (existingShardList != NIL) if (existingShardList != NIL)
{ {
char *tableName = get_rel_name(distributedTableId); char *tableName = get_rel_name(distributedTableId);
@ -171,16 +164,16 @@ CreateShardsWithRoundRobinPolicy(Oid distributedTableId, int32 shardCount,
} }
/* calculate the split of the hash space */ /* calculate the split of the hash space */
hashTokenIncrement = HASH_TOKEN_COUNT / shardCount; uint64 hashTokenIncrement = HASH_TOKEN_COUNT / shardCount;
/* don't allow concurrent node list changes that require an exclusive lock */ /* don't allow concurrent node list changes that require an exclusive lock */
LockRelationOid(DistNodeRelationId(), RowShareLock); LockRelationOid(DistNodeRelationId(), RowShareLock);
/* load and sort the worker node list for deterministic placement */ /* load and sort the worker node list for deterministic placement */
workerNodeList = DistributedTablePlacementNodeList(NoLock); List *workerNodeList = DistributedTablePlacementNodeList(NoLock);
workerNodeList = SortList(workerNodeList, CompareWorkerNodes); workerNodeList = SortList(workerNodeList, CompareWorkerNodes);
workerNodeCount = list_length(workerNodeList); int32 workerNodeCount = list_length(workerNodeList);
if (replicationFactor > workerNodeCount) if (replicationFactor > workerNodeCount)
{ {
ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
@ -191,26 +184,23 @@ CreateShardsWithRoundRobinPolicy(Oid distributedTableId, int32 shardCount,
} }
/* if we have enough nodes, add an extra placement attempt for backup */ /* if we have enough nodes, add an extra placement attempt for backup */
placementAttemptCount = (uint32) replicationFactor; uint32 placementAttemptCount = (uint32) replicationFactor;
if (workerNodeCount > replicationFactor) if (workerNodeCount > replicationFactor)
{ {
placementAttemptCount++; placementAttemptCount++;
} }
/* set shard storage type according to relation type */ /* set shard storage type according to relation type */
shardStorageType = ShardStorageType(distributedTableId); char shardStorageType = ShardStorageType(distributedTableId);
for (shardIndex = 0; shardIndex < shardCount; shardIndex++) for (int64 shardIndex = 0; shardIndex < shardCount; shardIndex++)
{ {
uint32 roundRobinNodeIndex = shardIndex % workerNodeCount; uint32 roundRobinNodeIndex = shardIndex % workerNodeCount;
/* initialize the hash token space for this shard */ /* initialize the hash token space for this shard */
text *minHashTokenText = NULL;
text *maxHashTokenText = NULL;
int32 shardMinHashToken = INT32_MIN + (shardIndex * hashTokenIncrement); int32 shardMinHashToken = INT32_MIN + (shardIndex * hashTokenIncrement);
int32 shardMaxHashToken = shardMinHashToken + (hashTokenIncrement - 1); int32 shardMaxHashToken = shardMinHashToken + (hashTokenIncrement - 1);
uint64 shardId = GetNextShardId(); uint64 shardId = GetNextShardId();
List *currentInsertedShardPlacements = NIL;
/* if we are at the last shard, make sure the max token value is INT_MAX */ /* if we are at the last shard, make sure the max token value is INT_MAX */
if (shardIndex == (shardCount - 1)) if (shardIndex == (shardCount - 1))
@ -219,8 +209,8 @@ CreateShardsWithRoundRobinPolicy(Oid distributedTableId, int32 shardCount,
} }
/* insert the shard metadata row along with its min/max values */ /* insert the shard metadata row along with its min/max values */
minHashTokenText = IntegerToText(shardMinHashToken); text *minHashTokenText = IntegerToText(shardMinHashToken);
maxHashTokenText = IntegerToText(shardMaxHashToken); text *maxHashTokenText = IntegerToText(shardMaxHashToken);
/* /*
* Grabbing the shard metadata lock isn't technically necessary since * Grabbing the shard metadata lock isn't technically necessary since
@ -233,7 +223,8 @@ CreateShardsWithRoundRobinPolicy(Oid distributedTableId, int32 shardCount,
InsertShardRow(distributedTableId, shardId, shardStorageType, InsertShardRow(distributedTableId, shardId, shardStorageType,
minHashTokenText, maxHashTokenText); minHashTokenText, maxHashTokenText);
currentInsertedShardPlacements = InsertShardPlacementRows(distributedTableId, List *currentInsertedShardPlacements = InsertShardPlacementRows(
distributedTableId,
shardId, shardId,
workerNodeList, workerNodeList,
roundRobinNodeIndex, roundRobinNodeIndex,
@ -255,9 +246,6 @@ void
CreateColocatedShards(Oid targetRelationId, Oid sourceRelationId, bool CreateColocatedShards(Oid targetRelationId, Oid sourceRelationId, bool
useExclusiveConnections) useExclusiveConnections)
{ {
char targetShardStorageType = 0;
List *existingShardList = NIL;
List *sourceShardIntervalList = NIL;
ListCell *sourceShardCell = NULL; ListCell *sourceShardCell = NULL;
bool colocatedShard = true; bool colocatedShard = true;
List *insertedShardPlacements = NIL; List *insertedShardPlacements = NIL;
@ -281,11 +269,11 @@ CreateColocatedShards(Oid targetRelationId, Oid sourceRelationId, bool
LockRelationOid(sourceRelationId, AccessShareLock); LockRelationOid(sourceRelationId, AccessShareLock);
/* prevent placement changes of the source relation until we colocate with them */ /* prevent placement changes of the source relation until we colocate with them */
sourceShardIntervalList = LoadShardIntervalList(sourceRelationId); List *sourceShardIntervalList = LoadShardIntervalList(sourceRelationId);
LockShardListMetadata(sourceShardIntervalList, ShareLock); LockShardListMetadata(sourceShardIntervalList, ShareLock);
/* validate that shards haven't already been created for this table */ /* validate that shards haven't already been created for this table */
existingShardList = LoadShardList(targetRelationId); List *existingShardList = LoadShardList(targetRelationId);
if (existingShardList != NIL) if (existingShardList != NIL)
{ {
char *targetRelationName = get_rel_name(targetRelationId); char *targetRelationName = get_rel_name(targetRelationId);
@ -294,7 +282,7 @@ CreateColocatedShards(Oid targetRelationId, Oid sourceRelationId, bool
targetRelationName))); targetRelationName)));
} }
targetShardStorageType = ShardStorageType(targetRelationId); char targetShardStorageType = ShardStorageType(targetRelationId);
foreach(sourceShardCell, sourceShardIntervalList) foreach(sourceShardCell, sourceShardIntervalList)
{ {
@ -319,17 +307,18 @@ CreateColocatedShards(Oid targetRelationId, Oid sourceRelationId, bool
int32 groupId = sourcePlacement->groupId; int32 groupId = sourcePlacement->groupId;
const RelayFileState shardState = FILE_FINALIZED; const RelayFileState shardState = FILE_FINALIZED;
const uint64 shardSize = 0; const uint64 shardSize = 0;
uint64 shardPlacementId = 0;
ShardPlacement *shardPlacement = NULL;
/* /*
* Optimistically add shard placement row the pg_dist_shard_placement, in case * Optimistically add shard placement row the pg_dist_shard_placement, in case
* of any error it will be roll-backed. * of any error it will be roll-backed.
*/ */
shardPlacementId = InsertShardPlacementRow(newShardId, INVALID_PLACEMENT_ID, uint64 shardPlacementId = InsertShardPlacementRow(newShardId,
shardState, shardSize, groupId); INVALID_PLACEMENT_ID,
shardState, shardSize,
groupId);
shardPlacement = LoadShardPlacement(newShardId, shardPlacementId); ShardPlacement *shardPlacement = LoadShardPlacement(newShardId,
shardPlacementId);
insertedShardPlacements = lappend(insertedShardPlacements, shardPlacement); insertedShardPlacements = lappend(insertedShardPlacements, shardPlacement);
} }
} }
@ -347,17 +336,11 @@ CreateColocatedShards(Oid targetRelationId, Oid sourceRelationId, bool
void void
CreateReferenceTableShard(Oid distributedTableId) CreateReferenceTableShard(Oid distributedTableId)
{ {
char shardStorageType = 0;
List *nodeList = NIL;
List *existingShardList = NIL;
uint64 shardId = INVALID_SHARD_ID;
int workerStartIndex = 0; int workerStartIndex = 0;
int replicationFactor = 0;
text *shardMinValue = NULL; text *shardMinValue = NULL;
text *shardMaxValue = NULL; text *shardMaxValue = NULL;
bool useExclusiveConnection = false; bool useExclusiveConnection = false;
bool colocatedShard = false; bool colocatedShard = false;
List *insertedShardPlacements = NIL;
/* /*
* In contrast to append/range partitioned tables it makes more sense to * In contrast to append/range partitioned tables it makes more sense to
@ -371,10 +354,10 @@ CreateReferenceTableShard(Oid distributedTableId)
LockRelationOid(distributedTableId, ExclusiveLock); LockRelationOid(distributedTableId, ExclusiveLock);
/* set shard storage type according to relation type */ /* set shard storage type according to relation type */
shardStorageType = ShardStorageType(distributedTableId); char shardStorageType = ShardStorageType(distributedTableId);
/* validate that shards haven't already been created for this table */ /* validate that shards haven't already been created for this table */
existingShardList = LoadShardList(distributedTableId); List *existingShardList = LoadShardList(distributedTableId);
if (existingShardList != NIL) if (existingShardList != NIL)
{ {
char *tableName = get_rel_name(distributedTableId); char *tableName = get_rel_name(distributedTableId);
@ -387,13 +370,13 @@ CreateReferenceTableShard(Oid distributedTableId)
* load and sort the worker node list for deterministic placements * load and sort the worker node list for deterministic placements
* create_reference_table has already acquired pg_dist_node lock * create_reference_table has already acquired pg_dist_node lock
*/ */
nodeList = ReferenceTablePlacementNodeList(ShareLock); List *nodeList = ReferenceTablePlacementNodeList(ShareLock);
nodeList = SortList(nodeList, CompareWorkerNodes); nodeList = SortList(nodeList, CompareWorkerNodes);
replicationFactor = ReferenceTableReplicationFactor(); int replicationFactor = ReferenceTableReplicationFactor();
/* get the next shard id */ /* get the next shard id */
shardId = GetNextShardId(); uint64 shardId = GetNextShardId();
/* /*
* Grabbing the shard metadata lock isn't technically necessary since * Grabbing the shard metadata lock isn't technically necessary since
@ -406,7 +389,7 @@ CreateReferenceTableShard(Oid distributedTableId)
InsertShardRow(distributedTableId, shardId, shardStorageType, shardMinValue, InsertShardRow(distributedTableId, shardId, shardStorageType, shardMinValue,
shardMaxValue); shardMaxValue);
insertedShardPlacements = InsertShardPlacementRows(distributedTableId, shardId, List *insertedShardPlacements = InsertShardPlacementRows(distributedTableId, shardId,
nodeList, workerStartIndex, nodeList, workerStartIndex,
replicationFactor); replicationFactor);
@ -436,11 +419,10 @@ CheckHashPartitionedTable(Oid distributedTableId)
text * text *
IntegerToText(int32 value) IntegerToText(int32 value)
{ {
text *valueText = NULL;
StringInfo valueString = makeStringInfo(); StringInfo valueString = makeStringInfo();
appendStringInfo(valueString, "%d", value); appendStringInfo(valueString, "%d", value);
valueText = cstring_to_text(valueString->data); text *valueText = cstring_to_text(valueString->data);
return valueText; return valueText;
} }

View File

@ -103,23 +103,10 @@ master_apply_delete_command(PG_FUNCTION_ARGS)
{ {
text *queryText = PG_GETARG_TEXT_P(0); text *queryText = PG_GETARG_TEXT_P(0);
char *queryString = text_to_cstring(queryText); char *queryString = text_to_cstring(queryText);
char *relationName = NULL;
char *schemaName = NULL;
Oid relationId = InvalidOid;
List *shardIntervalList = NIL;
List *deletableShardIntervalList = NIL; List *deletableShardIntervalList = NIL;
List *queryTreeList = NIL;
Query *deleteQuery = NULL;
Node *whereClause = NULL;
Node *deleteCriteria = NULL;
Node *queryTreeNode = NULL;
DeleteStmt *deleteStatement = NULL;
int droppedShardCount = 0;
LOCKMODE lockMode = 0;
char partitionMethod = 0;
bool failOK = false; bool failOK = false;
RawStmt *rawStmt = (RawStmt *) ParseTreeRawStmt(queryString); RawStmt *rawStmt = (RawStmt *) ParseTreeRawStmt(queryString);
queryTreeNode = rawStmt->stmt; Node *queryTreeNode = rawStmt->stmt;
EnsureCoordinator(); EnsureCoordinator();
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
@ -130,19 +117,19 @@ master_apply_delete_command(PG_FUNCTION_ARGS)
ApplyLogRedaction(queryString)))); ApplyLogRedaction(queryString))));
} }
deleteStatement = (DeleteStmt *) queryTreeNode; DeleteStmt *deleteStatement = (DeleteStmt *) queryTreeNode;
schemaName = deleteStatement->relation->schemaname; char *schemaName = deleteStatement->relation->schemaname;
relationName = deleteStatement->relation->relname; char *relationName = deleteStatement->relation->relname;
/* /*
* We take an exclusive lock while dropping shards to prevent concurrent * We take an exclusive lock while dropping shards to prevent concurrent
* writes. We don't want to block SELECTs, which means queries might fail * writes. We don't want to block SELECTs, which means queries might fail
* if they access a shard that has just been dropped. * if they access a shard that has just been dropped.
*/ */
lockMode = ExclusiveLock; LOCKMODE lockMode = ExclusiveLock;
relationId = RangeVarGetRelid(deleteStatement->relation, lockMode, failOK); Oid relationId = RangeVarGetRelid(deleteStatement->relation, lockMode, failOK);
/* schema-prefix if it is not specified already */ /* schema-prefix if it is not specified already */
if (schemaName == NULL) if (schemaName == NULL)
@ -154,15 +141,15 @@ master_apply_delete_command(PG_FUNCTION_ARGS)
CheckDistributedTable(relationId); CheckDistributedTable(relationId);
EnsureTablePermissions(relationId, ACL_DELETE); EnsureTablePermissions(relationId, ACL_DELETE);
queryTreeList = pg_analyze_and_rewrite(rawStmt, queryString, NULL, 0, NULL); List *queryTreeList = pg_analyze_and_rewrite(rawStmt, queryString, NULL, 0, NULL);
deleteQuery = (Query *) linitial(queryTreeList); Query *deleteQuery = (Query *) linitial(queryTreeList);
CheckTableCount(deleteQuery); CheckTableCount(deleteQuery);
/* get where clause and flatten it */ /* get where clause and flatten it */
whereClause = (Node *) deleteQuery->jointree->quals; Node *whereClause = (Node *) deleteQuery->jointree->quals;
deleteCriteria = eval_const_expressions(NULL, whereClause); Node *deleteCriteria = eval_const_expressions(NULL, whereClause);
partitionMethod = PartitionMethod(relationId); char partitionMethod = PartitionMethod(relationId);
if (partitionMethod == DISTRIBUTE_BY_HASH) if (partitionMethod == DISTRIBUTE_BY_HASH)
{ {
ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
@ -184,7 +171,7 @@ master_apply_delete_command(PG_FUNCTION_ARGS)
CheckDeleteCriteria(deleteCriteria); CheckDeleteCriteria(deleteCriteria);
CheckPartitionColumn(relationId, deleteCriteria); CheckPartitionColumn(relationId, deleteCriteria);
shardIntervalList = LoadShardIntervalList(relationId); List *shardIntervalList = LoadShardIntervalList(relationId);
/* drop all shards if where clause is not present */ /* drop all shards if where clause is not present */
if (deleteCriteria == NULL) if (deleteCriteria == NULL)
@ -199,7 +186,7 @@ master_apply_delete_command(PG_FUNCTION_ARGS)
deleteCriteria); deleteCriteria);
} }
droppedShardCount = DropShards(relationId, schemaName, relationName, int droppedShardCount = DropShards(relationId, schemaName, relationName,
deletableShardIntervalList); deletableShardIntervalList);
PG_RETURN_INT32(droppedShardCount); PG_RETURN_INT32(droppedShardCount);
@ -218,8 +205,6 @@ master_drop_all_shards(PG_FUNCTION_ARGS)
text *schemaNameText = PG_GETARG_TEXT_P(1); text *schemaNameText = PG_GETARG_TEXT_P(1);
text *relationNameText = PG_GETARG_TEXT_P(2); text *relationNameText = PG_GETARG_TEXT_P(2);
List *shardIntervalList = NIL;
int droppedShardCount = 0;
char *schemaName = text_to_cstring(schemaNameText); char *schemaName = text_to_cstring(schemaNameText);
char *relationName = text_to_cstring(relationNameText); char *relationName = text_to_cstring(relationNameText);
@ -246,8 +231,8 @@ master_drop_all_shards(PG_FUNCTION_ARGS)
*/ */
LockRelationOid(relationId, AccessExclusiveLock); LockRelationOid(relationId, AccessExclusiveLock);
shardIntervalList = LoadShardIntervalList(relationId); List *shardIntervalList = LoadShardIntervalList(relationId);
droppedShardCount = DropShards(relationId, schemaName, relationName, int droppedShardCount = DropShards(relationId, schemaName, relationName,
shardIntervalList); shardIntervalList);
PG_RETURN_INT32(droppedShardCount); PG_RETURN_INT32(droppedShardCount);
@ -265,7 +250,6 @@ Datum
master_drop_sequences(PG_FUNCTION_ARGS) master_drop_sequences(PG_FUNCTION_ARGS)
{ {
ArrayType *sequenceNamesArray = PG_GETARG_ARRAYTYPE_P(0); ArrayType *sequenceNamesArray = PG_GETARG_ARRAYTYPE_P(0);
ArrayIterator sequenceIterator = NULL;
Datum sequenceNameDatum = 0; Datum sequenceNameDatum = 0;
bool isNull = false; bool isNull = false;
StringInfo dropSeqCommand = makeStringInfo(); StringInfo dropSeqCommand = makeStringInfo();
@ -291,20 +275,17 @@ master_drop_sequences(PG_FUNCTION_ARGS)
} }
/* iterate over sequence names to build single command to DROP them all */ /* iterate over sequence names to build single command to DROP them all */
sequenceIterator = array_create_iterator(sequenceNamesArray, 0, NULL); ArrayIterator sequenceIterator = array_create_iterator(sequenceNamesArray, 0, NULL);
while (array_iterate(sequenceIterator, &sequenceNameDatum, &isNull)) while (array_iterate(sequenceIterator, &sequenceNameDatum, &isNull))
{ {
text *sequenceNameText = NULL;
Oid sequenceOid = InvalidOid;
if (isNull) if (isNull)
{ {
ereport(ERROR, (errmsg("unexpected NULL sequence name"), ereport(ERROR, (errmsg("unexpected NULL sequence name"),
errcode(ERRCODE_INVALID_PARAMETER_VALUE))); errcode(ERRCODE_INVALID_PARAMETER_VALUE)));
} }
sequenceNameText = DatumGetTextP(sequenceNameDatum); text *sequenceNameText = DatumGetTextP(sequenceNameDatum);
sequenceOid = ResolveRelationId(sequenceNameText, true); Oid sequenceOid = ResolveRelationId(sequenceNameText, true);
if (OidIsValid(sequenceOid)) if (OidIsValid(sequenceOid))
{ {
/* /*
@ -379,7 +360,6 @@ DropShards(Oid relationId, char *schemaName, char *relationName,
List *deletableShardIntervalList) List *deletableShardIntervalList)
{ {
ListCell *shardIntervalCell = NULL; ListCell *shardIntervalCell = NULL;
int droppedShardCount = 0;
BeginOrContinueCoordinatedTransaction(); BeginOrContinueCoordinatedTransaction();
@ -391,20 +371,18 @@ DropShards(Oid relationId, char *schemaName, char *relationName,
foreach(shardIntervalCell, deletableShardIntervalList) foreach(shardIntervalCell, deletableShardIntervalList)
{ {
List *shardPlacementList = NIL;
ListCell *shardPlacementCell = NULL; ListCell *shardPlacementCell = NULL;
ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell); ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell);
uint64 shardId = shardInterval->shardId; uint64 shardId = shardInterval->shardId;
char *quotedShardName = NULL;
char *shardRelationName = pstrdup(relationName); char *shardRelationName = pstrdup(relationName);
Assert(shardInterval->relationId == relationId); Assert(shardInterval->relationId == relationId);
/* Build shard relation name. */ /* Build shard relation name. */
AppendShardIdToName(&shardRelationName, shardId); AppendShardIdToName(&shardRelationName, shardId);
quotedShardName = quote_qualified_identifier(schemaName, shardRelationName); char *quotedShardName = quote_qualified_identifier(schemaName, shardRelationName);
shardPlacementList = ShardPlacementList(shardId); List *shardPlacementList = ShardPlacementList(shardId);
foreach(shardPlacementCell, shardPlacementList) foreach(shardPlacementCell, shardPlacementList)
{ {
ShardPlacement *shardPlacement = ShardPlacement *shardPlacement =
@ -412,7 +390,6 @@ DropShards(Oid relationId, char *schemaName, char *relationName,
char *workerName = shardPlacement->nodeName; char *workerName = shardPlacement->nodeName;
uint32 workerPort = shardPlacement->nodePort; uint32 workerPort = shardPlacement->nodePort;
StringInfo workerDropQuery = makeStringInfo(); StringInfo workerDropQuery = makeStringInfo();
MultiConnection *connection = NULL;
uint32 connectionFlags = FOR_DDL; uint32 connectionFlags = FOR_DDL;
char storageType = shardInterval->storageType; char storageType = shardInterval->storageType;
@ -441,7 +418,8 @@ DropShards(Oid relationId, char *schemaName, char *relationName,
continue; continue;
} }
connection = GetPlacementConnection(connectionFlags, shardPlacement, MultiConnection *connection = GetPlacementConnection(connectionFlags,
shardPlacement,
NULL); NULL);
RemoteTransactionBeginIfNecessary(connection); RemoteTransactionBeginIfNecessary(connection);
@ -471,7 +449,7 @@ DropShards(Oid relationId, char *schemaName, char *relationName,
DeleteShardRow(shardId); DeleteShardRow(shardId);
} }
droppedShardCount = list_length(deletableShardIntervalList); int droppedShardCount = list_length(deletableShardIntervalList);
return droppedShardCount; return droppedShardCount;
} }
@ -573,7 +551,6 @@ ShardsMatchingDeleteCriteria(Oid relationId, List *shardIntervalList,
Node *deleteCriteria) Node *deleteCriteria)
{ {
List *dropShardIntervalList = NIL; List *dropShardIntervalList = NIL;
List *deleteCriteriaList = NIL;
ListCell *shardIntervalCell = NULL; ListCell *shardIntervalCell = NULL;
/* build the base expression for constraint */ /* build the base expression for constraint */
@ -582,7 +559,7 @@ ShardsMatchingDeleteCriteria(Oid relationId, List *shardIntervalList,
Node *baseConstraint = BuildBaseConstraint(partitionColumn); Node *baseConstraint = BuildBaseConstraint(partitionColumn);
Assert(deleteCriteria != NULL); Assert(deleteCriteria != NULL);
deleteCriteriaList = list_make1(deleteCriteria); List *deleteCriteriaList = list_make1(deleteCriteria);
/* walk over shard list and check if shards can be dropped */ /* walk over shard list and check if shards can be dropped */
foreach(shardIntervalCell, shardIntervalList) foreach(shardIntervalCell, shardIntervalList)
@ -591,27 +568,23 @@ ShardsMatchingDeleteCriteria(Oid relationId, List *shardIntervalList,
if (shardInterval->minValueExists && shardInterval->maxValueExists) if (shardInterval->minValueExists && shardInterval->maxValueExists)
{ {
List *restrictInfoList = NIL; List *restrictInfoList = NIL;
bool dropShard = false;
BoolExpr *andExpr = NULL;
Expr *lessThanExpr = NULL;
Expr *greaterThanExpr = NULL;
RestrictInfo *lessThanRestrictInfo = NULL;
RestrictInfo *greaterThanRestrictInfo = NULL;
/* set the min/max values in the base constraint */ /* set the min/max values in the base constraint */
UpdateConstraint(baseConstraint, shardInterval); UpdateConstraint(baseConstraint, shardInterval);
andExpr = (BoolExpr *) baseConstraint; BoolExpr *andExpr = (BoolExpr *) baseConstraint;
lessThanExpr = (Expr *) linitial(andExpr->args); Expr *lessThanExpr = (Expr *) linitial(andExpr->args);
greaterThanExpr = (Expr *) lsecond(andExpr->args); Expr *greaterThanExpr = (Expr *) lsecond(andExpr->args);
lessThanRestrictInfo = make_simple_restrictinfo(lessThanExpr); RestrictInfo *lessThanRestrictInfo = make_simple_restrictinfo(lessThanExpr);
greaterThanRestrictInfo = make_simple_restrictinfo(greaterThanExpr); RestrictInfo *greaterThanRestrictInfo = make_simple_restrictinfo(
greaterThanExpr);
restrictInfoList = lappend(restrictInfoList, lessThanRestrictInfo); restrictInfoList = lappend(restrictInfoList, lessThanRestrictInfo);
restrictInfoList = lappend(restrictInfoList, greaterThanRestrictInfo); restrictInfoList = lappend(restrictInfoList, greaterThanRestrictInfo);
dropShard = predicate_implied_by(deleteCriteriaList, restrictInfoList, false); bool dropShard = predicate_implied_by(deleteCriteriaList, restrictInfoList,
false);
if (dropShard) if (dropShard)
{ {
dropShardIntervalList = lappend(dropShardIntervalList, shardInterval); dropShardIntervalList = lappend(dropShardIntervalList, shardInterval);

View File

@ -91,7 +91,6 @@ Datum
citus_total_relation_size(PG_FUNCTION_ARGS) citus_total_relation_size(PG_FUNCTION_ARGS)
{ {
Oid relationId = PG_GETARG_OID(0); Oid relationId = PG_GETARG_OID(0);
uint64 totalRelationSize = 0;
char *tableSizeFunction = PG_TOTAL_RELATION_SIZE_FUNCTION; char *tableSizeFunction = PG_TOTAL_RELATION_SIZE_FUNCTION;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
@ -101,7 +100,7 @@ citus_total_relation_size(PG_FUNCTION_ARGS)
tableSizeFunction = CSTORE_TABLE_SIZE_FUNCTION; tableSizeFunction = CSTORE_TABLE_SIZE_FUNCTION;
} }
totalRelationSize = DistributedTableSize(relationId, tableSizeFunction); uint64 totalRelationSize = DistributedTableSize(relationId, tableSizeFunction);
PG_RETURN_INT64(totalRelationSize); PG_RETURN_INT64(totalRelationSize);
} }
@ -115,7 +114,6 @@ Datum
citus_table_size(PG_FUNCTION_ARGS) citus_table_size(PG_FUNCTION_ARGS)
{ {
Oid relationId = PG_GETARG_OID(0); Oid relationId = PG_GETARG_OID(0);
uint64 tableSize = 0;
char *tableSizeFunction = PG_TABLE_SIZE_FUNCTION; char *tableSizeFunction = PG_TABLE_SIZE_FUNCTION;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
@ -125,7 +123,7 @@ citus_table_size(PG_FUNCTION_ARGS)
tableSizeFunction = CSTORE_TABLE_SIZE_FUNCTION; tableSizeFunction = CSTORE_TABLE_SIZE_FUNCTION;
} }
tableSize = DistributedTableSize(relationId, tableSizeFunction); uint64 tableSize = DistributedTableSize(relationId, tableSizeFunction);
PG_RETURN_INT64(tableSize); PG_RETURN_INT64(tableSize);
} }
@ -139,7 +137,6 @@ Datum
citus_relation_size(PG_FUNCTION_ARGS) citus_relation_size(PG_FUNCTION_ARGS)
{ {
Oid relationId = PG_GETARG_OID(0); Oid relationId = PG_GETARG_OID(0);
uint64 relationSize = 0;
char *tableSizeFunction = PG_RELATION_SIZE_FUNCTION; char *tableSizeFunction = PG_RELATION_SIZE_FUNCTION;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
@ -149,7 +146,7 @@ citus_relation_size(PG_FUNCTION_ARGS)
tableSizeFunction = CSTORE_TABLE_SIZE_FUNCTION; tableSizeFunction = CSTORE_TABLE_SIZE_FUNCTION;
} }
relationSize = DistributedTableSize(relationId, tableSizeFunction); uint64 relationSize = DistributedTableSize(relationId, tableSizeFunction);
PG_RETURN_INT64(relationSize); PG_RETURN_INT64(relationSize);
} }
@ -163,8 +160,6 @@ citus_relation_size(PG_FUNCTION_ARGS)
static uint64 static uint64
DistributedTableSize(Oid relationId, char *sizeQuery) DistributedTableSize(Oid relationId, char *sizeQuery)
{ {
Relation relation = NULL;
List *workerNodeList = NULL;
ListCell *workerNodeCell = NULL; ListCell *workerNodeCell = NULL;
uint64 totalRelationSize = 0; uint64 totalRelationSize = 0;
@ -175,7 +170,7 @@ DistributedTableSize(Oid relationId, char *sizeQuery)
" blocks which contain multi-shard data modifications"))); " blocks which contain multi-shard data modifications")));
} }
relation = try_relation_open(relationId, AccessShareLock); Relation relation = try_relation_open(relationId, AccessShareLock);
if (relation == NULL) if (relation == NULL)
{ {
@ -185,7 +180,7 @@ DistributedTableSize(Oid relationId, char *sizeQuery)
ErrorIfNotSuitableToGetSize(relationId); ErrorIfNotSuitableToGetSize(relationId);
workerNodeList = ActiveReadableNodeList(); List *workerNodeList = ActiveReadableNodeList();
foreach(workerNodeCell, workerNodeList) foreach(workerNodeCell, workerNodeList)
{ {
@ -209,27 +204,22 @@ DistributedTableSize(Oid relationId, char *sizeQuery)
static uint64 static uint64
DistributedTableSizeOnWorker(WorkerNode *workerNode, Oid relationId, char *sizeQuery) DistributedTableSizeOnWorker(WorkerNode *workerNode, Oid relationId, char *sizeQuery)
{ {
StringInfo tableSizeQuery = NULL;
StringInfo tableSizeStringInfo = NULL;
char *workerNodeName = workerNode->workerName; char *workerNodeName = workerNode->workerName;
uint32 workerNodePort = workerNode->workerPort; uint32 workerNodePort = workerNode->workerPort;
char *tableSizeString;
uint64 tableSize = 0;
MultiConnection *connection = NULL;
uint32 connectionFlag = 0; uint32 connectionFlag = 0;
PGresult *result = NULL; PGresult *result = NULL;
int queryResult = 0;
List *sizeList = NIL;
bool raiseErrors = true; bool raiseErrors = true;
List *shardIntervalsOnNode = ShardIntervalsOnWorkerGroup(workerNode, relationId); List *shardIntervalsOnNode = ShardIntervalsOnWorkerGroup(workerNode, relationId);
tableSizeQuery = GenerateSizeQueryOnMultiplePlacements(relationId, StringInfo tableSizeQuery = GenerateSizeQueryOnMultiplePlacements(relationId,
shardIntervalsOnNode, shardIntervalsOnNode,
sizeQuery); sizeQuery);
connection = GetNodeConnection(connectionFlag, workerNodeName, workerNodePort); MultiConnection *connection = GetNodeConnection(connectionFlag, workerNodeName,
queryResult = ExecuteOptionalRemoteCommand(connection, tableSizeQuery->data, &result); workerNodePort);
int queryResult = ExecuteOptionalRemoteCommand(connection, tableSizeQuery->data,
&result);
if (queryResult != 0) if (queryResult != 0)
{ {
@ -237,10 +227,10 @@ DistributedTableSizeOnWorker(WorkerNode *workerNode, Oid relationId, char *sizeQ
errmsg("cannot get the size because of a connection error"))); errmsg("cannot get the size because of a connection error")));
} }
sizeList = ReadFirstColumnAsText(result); List *sizeList = ReadFirstColumnAsText(result);
tableSizeStringInfo = (StringInfo) linitial(sizeList); StringInfo tableSizeStringInfo = (StringInfo) linitial(sizeList);
tableSizeString = tableSizeStringInfo->data; char *tableSizeString = tableSizeStringInfo->data;
tableSize = atol(tableSizeString); uint64 tableSize = atol(tableSizeString);
PQclear(result); PQclear(result);
ClearResults(connection, raiseErrors); ClearResults(connection, raiseErrors);
@ -260,18 +250,17 @@ GroupShardPlacementsForTableOnGroup(Oid relationId, int32 groupId)
DistTableCacheEntry *distTableCacheEntry = DistributedTableCacheEntry(relationId); DistTableCacheEntry *distTableCacheEntry = DistributedTableCacheEntry(relationId);
List *resultList = NIL; List *resultList = NIL;
int shardIndex = 0;
int shardIntervalArrayLength = distTableCacheEntry->shardIntervalArrayLength; int shardIntervalArrayLength = distTableCacheEntry->shardIntervalArrayLength;
for (shardIndex = 0; shardIndex < shardIntervalArrayLength; shardIndex++) for (int shardIndex = 0; shardIndex < shardIntervalArrayLength; shardIndex++)
{ {
GroupShardPlacement *placementArray = GroupShardPlacement *placementArray =
distTableCacheEntry->arrayOfPlacementArrays[shardIndex]; distTableCacheEntry->arrayOfPlacementArrays[shardIndex];
int numberOfPlacements = int numberOfPlacements =
distTableCacheEntry->arrayOfPlacementArrayLengths[shardIndex]; distTableCacheEntry->arrayOfPlacementArrayLengths[shardIndex];
int placementIndex = 0;
for (placementIndex = 0; placementIndex < numberOfPlacements; placementIndex++) for (int placementIndex = 0; placementIndex < numberOfPlacements;
placementIndex++)
{ {
GroupShardPlacement *placement = &placementArray[placementIndex]; GroupShardPlacement *placement = &placementArray[placementIndex];
@ -298,24 +287,22 @@ ShardIntervalsOnWorkerGroup(WorkerNode *workerNode, Oid relationId)
{ {
DistTableCacheEntry *distTableCacheEntry = DistributedTableCacheEntry(relationId); DistTableCacheEntry *distTableCacheEntry = DistributedTableCacheEntry(relationId);
List *shardIntervalList = NIL; List *shardIntervalList = NIL;
int shardIndex = 0;
int shardIntervalArrayLength = distTableCacheEntry->shardIntervalArrayLength; int shardIntervalArrayLength = distTableCacheEntry->shardIntervalArrayLength;
for (shardIndex = 0; shardIndex < shardIntervalArrayLength; shardIndex++) for (int shardIndex = 0; shardIndex < shardIntervalArrayLength; shardIndex++)
{ {
GroupShardPlacement *placementArray = GroupShardPlacement *placementArray =
distTableCacheEntry->arrayOfPlacementArrays[shardIndex]; distTableCacheEntry->arrayOfPlacementArrays[shardIndex];
int numberOfPlacements = int numberOfPlacements =
distTableCacheEntry->arrayOfPlacementArrayLengths[shardIndex]; distTableCacheEntry->arrayOfPlacementArrayLengths[shardIndex];
int placementIndex = 0;
for (placementIndex = 0; placementIndex < numberOfPlacements; placementIndex++) for (int placementIndex = 0; placementIndex < numberOfPlacements;
placementIndex++)
{ {
GroupShardPlacement *placement = &placementArray[placementIndex]; GroupShardPlacement *placement = &placementArray[placementIndex];
uint64 shardId = placement->shardId; uint64 shardId = placement->shardId;
bool metadataLock = false;
metadataLock = TryLockShardDistributionMetadata(shardId, ShareLock); bool metadataLock = TryLockShardDistributionMetadata(shardId, ShareLock);
/* if the lock is not acquired warn the user */ /* if the lock is not acquired warn the user */
if (metadataLock == false) if (metadataLock == false)
@ -364,12 +351,10 @@ GenerateSizeQueryOnMultiplePlacements(Oid distributedRelationId, List *shardInte
ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell); ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell);
uint64 shardId = shardInterval->shardId; uint64 shardId = shardInterval->shardId;
char *shardName = get_rel_name(distributedRelationId); char *shardName = get_rel_name(distributedRelationId);
char *shardQualifiedName = NULL;
char *quotedShardName = NULL;
AppendShardIdToName(&shardName, shardId); AppendShardIdToName(&shardName, shardId);
shardQualifiedName = quote_qualified_identifier(schemaName, shardName); char *shardQualifiedName = quote_qualified_identifier(schemaName, shardName);
quotedShardName = quote_literal_cstr(shardQualifiedName); char *quotedShardName = quote_literal_cstr(shardQualifiedName);
appendStringInfo(selectQuery, sizeQuery, quotedShardName); appendStringInfo(selectQuery, sizeQuery, quotedShardName);
appendStringInfo(selectQuery, " + "); appendStringInfo(selectQuery, " + ");
@ -509,12 +494,11 @@ LoadShardIntervalList(Oid relationId)
{ {
DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId);
List *shardList = NIL; List *shardList = NIL;
int i = 0;
for (i = 0; i < cacheEntry->shardIntervalArrayLength; i++) for (int i = 0; i < cacheEntry->shardIntervalArrayLength; i++)
{ {
ShardInterval *newShardInterval = NULL; ShardInterval *newShardInterval = (ShardInterval *) palloc0(
newShardInterval = (ShardInterval *) palloc0(sizeof(ShardInterval)); sizeof(ShardInterval));
CopyShardInterval(cacheEntry->sortedShardIntervalArray[i], newShardInterval); CopyShardInterval(cacheEntry->sortedShardIntervalArray[i], newShardInterval);
@ -557,9 +541,8 @@ LoadShardList(Oid relationId)
{ {
DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId);
List *shardList = NIL; List *shardList = NIL;
int i = 0;
for (i = 0; i < cacheEntry->shardIntervalArrayLength; i++) for (int i = 0; i < cacheEntry->shardIntervalArrayLength; i++)
{ {
ShardInterval *currentShardInterval = cacheEntry->sortedShardIntervalArray[i]; ShardInterval *currentShardInterval = cacheEntry->sortedShardIntervalArray[i];
uint64 *shardIdPointer = AllocateUint64(currentShardInterval->shardId); uint64 *shardIdPointer = AllocateUint64(currentShardInterval->shardId);
@ -673,10 +656,7 @@ NodeGroupHasShardPlacements(int32 groupId, bool onlyConsiderActivePlacements)
const int scanKeyCount = (onlyConsiderActivePlacements ? 2 : 1); const int scanKeyCount = (onlyConsiderActivePlacements ? 2 : 1);
const bool indexOK = false; const bool indexOK = false;
bool hasFinalizedPlacements = false;
HeapTuple heapTuple = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[2]; ScanKeyData scanKey[2];
Relation pgPlacement = heap_open(DistPlacementRelationId(), Relation pgPlacement = heap_open(DistPlacementRelationId(),
@ -690,12 +670,13 @@ NodeGroupHasShardPlacements(int32 groupId, bool onlyConsiderActivePlacements)
BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(FILE_FINALIZED)); BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(FILE_FINALIZED));
} }
scanDescriptor = systable_beginscan(pgPlacement, SysScanDesc scanDescriptor = systable_beginscan(pgPlacement,
DistPlacementGroupidIndexId(), indexOK, DistPlacementGroupidIndexId(),
indexOK,
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
hasFinalizedPlacements = HeapTupleIsValid(heapTuple); bool hasFinalizedPlacements = HeapTupleIsValid(heapTuple);
systable_endscan(scanDescriptor); systable_endscan(scanDescriptor);
heap_close(pgPlacement, NoLock); heap_close(pgPlacement, NoLock);
@ -772,23 +753,21 @@ BuildShardPlacementList(ShardInterval *shardInterval)
{ {
int64 shardId = shardInterval->shardId; int64 shardId = shardInterval->shardId;
List *shardPlacementList = NIL; List *shardPlacementList = NIL;
Relation pgPlacement = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
bool indexOK = true; bool indexOK = true;
HeapTuple heapTuple = NULL;
pgPlacement = heap_open(DistPlacementRelationId(), AccessShareLock); Relation pgPlacement = heap_open(DistPlacementRelationId(), AccessShareLock);
ScanKeyInit(&scanKey[0], Anum_pg_dist_placement_shardid, ScanKeyInit(&scanKey[0], Anum_pg_dist_placement_shardid,
BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(shardId)); BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(shardId));
scanDescriptor = systable_beginscan(pgPlacement, SysScanDesc scanDescriptor = systable_beginscan(pgPlacement,
DistPlacementShardidIndexId(), indexOK, DistPlacementShardidIndexId(),
indexOK,
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
while (HeapTupleIsValid(heapTuple)) while (HeapTupleIsValid(heapTuple))
{ {
TupleDesc tupleDescriptor = RelationGetDescr(pgPlacement); TupleDesc tupleDescriptor = RelationGetDescr(pgPlacement);
@ -817,23 +796,21 @@ List *
AllShardPlacementsOnNodeGroup(int32 groupId) AllShardPlacementsOnNodeGroup(int32 groupId)
{ {
List *shardPlacementList = NIL; List *shardPlacementList = NIL;
Relation pgPlacement = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
bool indexOK = true; bool indexOK = true;
HeapTuple heapTuple = NULL;
pgPlacement = heap_open(DistPlacementRelationId(), AccessShareLock); Relation pgPlacement = heap_open(DistPlacementRelationId(), AccessShareLock);
ScanKeyInit(&scanKey[0], Anum_pg_dist_placement_groupid, ScanKeyInit(&scanKey[0], Anum_pg_dist_placement_groupid,
BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(groupId)); BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(groupId));
scanDescriptor = systable_beginscan(pgPlacement, SysScanDesc scanDescriptor = systable_beginscan(pgPlacement,
DistPlacementGroupidIndexId(), indexOK, DistPlacementGroupidIndexId(),
indexOK,
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
while (HeapTupleIsValid(heapTuple)) while (HeapTupleIsValid(heapTuple))
{ {
TupleDesc tupleDescriptor = RelationGetDescr(pgPlacement); TupleDesc tupleDescriptor = RelationGetDescr(pgPlacement);
@ -861,7 +838,6 @@ AllShardPlacementsOnNodeGroup(int32 groupId)
static GroupShardPlacement * static GroupShardPlacement *
TupleToGroupShardPlacement(TupleDesc tupleDescriptor, HeapTuple heapTuple) TupleToGroupShardPlacement(TupleDesc tupleDescriptor, HeapTuple heapTuple)
{ {
GroupShardPlacement *shardPlacement = NULL;
bool isNullArray[Natts_pg_dist_placement]; bool isNullArray[Natts_pg_dist_placement];
Datum datumArray[Natts_pg_dist_placement]; Datum datumArray[Natts_pg_dist_placement];
@ -877,7 +853,7 @@ TupleToGroupShardPlacement(TupleDesc tupleDescriptor, HeapTuple heapTuple)
*/ */
heap_deform_tuple(heapTuple, tupleDescriptor, datumArray, isNullArray); heap_deform_tuple(heapTuple, tupleDescriptor, datumArray, isNullArray);
shardPlacement = CitusMakeNode(GroupShardPlacement); GroupShardPlacement *shardPlacement = CitusMakeNode(GroupShardPlacement);
shardPlacement->placementId = DatumGetInt64( shardPlacement->placementId = DatumGetInt64(
datumArray[Anum_pg_dist_placement_placementid - 1]); datumArray[Anum_pg_dist_placement_placementid - 1]);
shardPlacement->shardId = DatumGetInt64( shardPlacement->shardId = DatumGetInt64(
@ -902,9 +878,6 @@ void
InsertShardRow(Oid relationId, uint64 shardId, char storageType, InsertShardRow(Oid relationId, uint64 shardId, char storageType,
text *shardMinValue, text *shardMaxValue) text *shardMinValue, text *shardMaxValue)
{ {
Relation pgDistShard = NULL;
TupleDesc tupleDescriptor = NULL;
HeapTuple heapTuple = NULL;
Datum values[Natts_pg_dist_shard]; Datum values[Natts_pg_dist_shard];
bool isNulls[Natts_pg_dist_shard]; bool isNulls[Natts_pg_dist_shard];
@ -932,10 +905,10 @@ InsertShardRow(Oid relationId, uint64 shardId, char storageType,
} }
/* open shard relation and insert new tuple */ /* open shard relation and insert new tuple */
pgDistShard = heap_open(DistShardRelationId(), RowExclusiveLock); Relation pgDistShard = heap_open(DistShardRelationId(), RowExclusiveLock);
tupleDescriptor = RelationGetDescr(pgDistShard); TupleDesc tupleDescriptor = RelationGetDescr(pgDistShard);
heapTuple = heap_form_tuple(tupleDescriptor, values, isNulls); HeapTuple heapTuple = heap_form_tuple(tupleDescriptor, values, isNulls);
CatalogTupleInsert(pgDistShard, heapTuple); CatalogTupleInsert(pgDistShard, heapTuple);
@ -958,9 +931,6 @@ InsertShardPlacementRow(uint64 shardId, uint64 placementId,
char shardState, uint64 shardLength, char shardState, uint64 shardLength,
int32 groupId) int32 groupId)
{ {
Relation pgDistPlacement = NULL;
TupleDesc tupleDescriptor = NULL;
HeapTuple heapTuple = NULL;
Datum values[Natts_pg_dist_placement]; Datum values[Natts_pg_dist_placement];
bool isNulls[Natts_pg_dist_placement]; bool isNulls[Natts_pg_dist_placement];
@ -979,10 +949,10 @@ InsertShardPlacementRow(uint64 shardId, uint64 placementId,
values[Anum_pg_dist_placement_groupid - 1] = Int32GetDatum(groupId); values[Anum_pg_dist_placement_groupid - 1] = Int32GetDatum(groupId);
/* open shard placement relation and insert new tuple */ /* open shard placement relation and insert new tuple */
pgDistPlacement = heap_open(DistPlacementRelationId(), RowExclusiveLock); Relation pgDistPlacement = heap_open(DistPlacementRelationId(), RowExclusiveLock);
tupleDescriptor = RelationGetDescr(pgDistPlacement); TupleDesc tupleDescriptor = RelationGetDescr(pgDistPlacement);
heapTuple = heap_form_tuple(tupleDescriptor, values, isNulls); HeapTuple heapTuple = heap_form_tuple(tupleDescriptor, values, isNulls);
CatalogTupleInsert(pgDistPlacement, heapTuple); CatalogTupleInsert(pgDistPlacement, heapTuple);
@ -1003,15 +973,13 @@ InsertIntoPgDistPartition(Oid relationId, char distributionMethod,
Var *distributionColumn, uint32 colocationId, Var *distributionColumn, uint32 colocationId,
char replicationModel) char replicationModel)
{ {
Relation pgDistPartition = NULL;
char *distributionColumnString = NULL; char *distributionColumnString = NULL;
HeapTuple newTuple = NULL;
Datum newValues[Natts_pg_dist_partition]; Datum newValues[Natts_pg_dist_partition];
bool newNulls[Natts_pg_dist_partition]; bool newNulls[Natts_pg_dist_partition];
/* open system catalog and insert new tuple */ /* open system catalog and insert new tuple */
pgDistPartition = heap_open(DistPartitionRelationId(), RowExclusiveLock); Relation pgDistPartition = heap_open(DistPartitionRelationId(), RowExclusiveLock);
/* form new tuple for pg_dist_partition */ /* form new tuple for pg_dist_partition */
memset(newValues, 0, sizeof(newValues)); memset(newValues, 0, sizeof(newValues));
@ -1038,7 +1006,8 @@ InsertIntoPgDistPartition(Oid relationId, char distributionMethod,
newNulls[Anum_pg_dist_partition_partkey - 1] = true; newNulls[Anum_pg_dist_partition_partkey - 1] = true;
} }
newTuple = heap_form_tuple(RelationGetDescr(pgDistPartition), newValues, newNulls); HeapTuple newTuple = heap_form_tuple(RelationGetDescr(pgDistPartition), newValues,
newNulls);
/* finally insert tuple, build index entries & register cache invalidation */ /* finally insert tuple, build index entries & register cache invalidation */
CatalogTupleInsert(pgDistPartition, newTuple); CatalogTupleInsert(pgDistPartition, newTuple);
@ -1092,21 +1061,19 @@ RecordDistributedRelationDependencies(Oid distributedRelationId, Node *distribut
void void
DeletePartitionRow(Oid distributedRelationId) DeletePartitionRow(Oid distributedRelationId)
{ {
Relation pgDistPartition = NULL;
HeapTuple heapTuple = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
pgDistPartition = heap_open(DistPartitionRelationId(), RowExclusiveLock); Relation pgDistPartition = heap_open(DistPartitionRelationId(), RowExclusiveLock);
ScanKeyInit(&scanKey[0], Anum_pg_dist_partition_logicalrelid, ScanKeyInit(&scanKey[0], Anum_pg_dist_partition_logicalrelid,
BTEqualStrategyNumber, F_OIDEQ, ObjectIdGetDatum(distributedRelationId)); BTEqualStrategyNumber, F_OIDEQ, ObjectIdGetDatum(distributedRelationId));
scanDescriptor = systable_beginscan(pgDistPartition, InvalidOid, false, NULL, SysScanDesc scanDescriptor = systable_beginscan(pgDistPartition, InvalidOid, false,
NULL,
scanKeyCount, scanKey); scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
if (!HeapTupleIsValid(heapTuple)) if (!HeapTupleIsValid(heapTuple))
{ {
ereport(ERROR, (errmsg("could not find valid entry for partition %d", ereport(ERROR, (errmsg("could not find valid entry for partition %d",
@ -1134,33 +1101,28 @@ DeletePartitionRow(Oid distributedRelationId)
void void
DeleteShardRow(uint64 shardId) DeleteShardRow(uint64 shardId)
{ {
Relation pgDistShard = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
bool indexOK = true; bool indexOK = true;
HeapTuple heapTuple = NULL;
Form_pg_dist_shard pgDistShardForm = NULL;
Oid distributedRelationId = InvalidOid;
pgDistShard = heap_open(DistShardRelationId(), RowExclusiveLock); Relation pgDistShard = heap_open(DistShardRelationId(), RowExclusiveLock);
ScanKeyInit(&scanKey[0], Anum_pg_dist_shard_shardid, ScanKeyInit(&scanKey[0], Anum_pg_dist_shard_shardid,
BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(shardId)); BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(shardId));
scanDescriptor = systable_beginscan(pgDistShard, SysScanDesc scanDescriptor = systable_beginscan(pgDistShard,
DistShardShardidIndexId(), indexOK, DistShardShardidIndexId(), indexOK,
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
if (!HeapTupleIsValid(heapTuple)) if (!HeapTupleIsValid(heapTuple))
{ {
ereport(ERROR, (errmsg("could not find valid entry for shard " ereport(ERROR, (errmsg("could not find valid entry for shard "
UINT64_FORMAT, shardId))); UINT64_FORMAT, shardId)));
} }
pgDistShardForm = (Form_pg_dist_shard) GETSTRUCT(heapTuple); Form_pg_dist_shard pgDistShardForm = (Form_pg_dist_shard) GETSTRUCT(heapTuple);
distributedRelationId = pgDistShardForm->logicalrelid; Oid distributedRelationId = pgDistShardForm->logicalrelid;
simple_heap_delete(pgDistShard, &heapTuple->t_self); simple_heap_delete(pgDistShard, &heapTuple->t_self);
@ -1181,34 +1143,30 @@ DeleteShardRow(uint64 shardId)
void void
DeleteShardPlacementRow(uint64 placementId) DeleteShardPlacementRow(uint64 placementId)
{ {
Relation pgDistPlacement = NULL;
SysScanDesc scanDescriptor = NULL;
const int scanKeyCount = 1; const int scanKeyCount = 1;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
bool indexOK = true; bool indexOK = true;
HeapTuple heapTuple = NULL;
TupleDesc tupleDescriptor = NULL;
bool isNull = false; bool isNull = false;
uint64 shardId = 0;
pgDistPlacement = heap_open(DistPlacementRelationId(), RowExclusiveLock); Relation pgDistPlacement = heap_open(DistPlacementRelationId(), RowExclusiveLock);
tupleDescriptor = RelationGetDescr(pgDistPlacement); TupleDesc tupleDescriptor = RelationGetDescr(pgDistPlacement);
ScanKeyInit(&scanKey[0], Anum_pg_dist_placement_placementid, ScanKeyInit(&scanKey[0], Anum_pg_dist_placement_placementid,
BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(placementId)); BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(placementId));
scanDescriptor = systable_beginscan(pgDistPlacement, SysScanDesc scanDescriptor = systable_beginscan(pgDistPlacement,
DistPlacementPlacementidIndexId(), indexOK, DistPlacementPlacementidIndexId(),
indexOK,
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
if (heapTuple == NULL) if (heapTuple == NULL)
{ {
ereport(ERROR, (errmsg("could not find valid entry for shard placement " ereport(ERROR, (errmsg("could not find valid entry for shard placement "
INT64_FORMAT, placementId))); INT64_FORMAT, placementId)));
} }
shardId = heap_getattr(heapTuple, Anum_pg_dist_placement_shardid, uint64 shardId = heap_getattr(heapTuple, Anum_pg_dist_placement_shardid,
tupleDescriptor, &isNull); tupleDescriptor, &isNull);
if (HeapTupleHeaderGetNatts(heapTuple->t_data) != Natts_pg_dist_placement || if (HeapTupleHeaderGetNatts(heapTuple->t_data) != Natts_pg_dist_placement ||
HeapTupleHasNulls(heapTuple)) HeapTupleHasNulls(heapTuple))
@ -1233,29 +1191,25 @@ DeleteShardPlacementRow(uint64 placementId)
void void
UpdateShardPlacementState(uint64 placementId, char shardState) UpdateShardPlacementState(uint64 placementId, char shardState)
{ {
Relation pgDistPlacement = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
bool indexOK = true; bool indexOK = true;
HeapTuple heapTuple = NULL;
TupleDesc tupleDescriptor = NULL;
Datum values[Natts_pg_dist_placement]; Datum values[Natts_pg_dist_placement];
bool isnull[Natts_pg_dist_placement]; bool isnull[Natts_pg_dist_placement];
bool replace[Natts_pg_dist_placement]; bool replace[Natts_pg_dist_placement];
uint64 shardId = INVALID_SHARD_ID;
bool colIsNull = false; bool colIsNull = false;
pgDistPlacement = heap_open(DistPlacementRelationId(), RowExclusiveLock); Relation pgDistPlacement = heap_open(DistPlacementRelationId(), RowExclusiveLock);
tupleDescriptor = RelationGetDescr(pgDistPlacement); TupleDesc tupleDescriptor = RelationGetDescr(pgDistPlacement);
ScanKeyInit(&scanKey[0], Anum_pg_dist_placement_placementid, ScanKeyInit(&scanKey[0], Anum_pg_dist_placement_placementid,
BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(placementId)); BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(placementId));
scanDescriptor = systable_beginscan(pgDistPlacement, SysScanDesc scanDescriptor = systable_beginscan(pgDistPlacement,
DistPlacementPlacementidIndexId(), indexOK, DistPlacementPlacementidIndexId(),
indexOK,
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
if (!HeapTupleIsValid(heapTuple)) if (!HeapTupleIsValid(heapTuple))
{ {
ereport(ERROR, (errmsg("could not find valid entry for shard placement " ereport(ERROR, (errmsg("could not find valid entry for shard placement "
@ -1273,7 +1227,7 @@ UpdateShardPlacementState(uint64 placementId, char shardState)
CatalogTupleUpdate(pgDistPlacement, &heapTuple->t_self, heapTuple); CatalogTupleUpdate(pgDistPlacement, &heapTuple->t_self, heapTuple);
shardId = DatumGetInt64(heap_getattr(heapTuple, uint64 shardId = DatumGetInt64(heap_getattr(heapTuple,
Anum_pg_dist_placement_shardid, Anum_pg_dist_placement_shardid,
tupleDescriptor, &colIsNull)); tupleDescriptor, &colIsNull));
Assert(!colIsNull); Assert(!colIsNull);
@ -1293,9 +1247,7 @@ UpdateShardPlacementState(uint64 placementId, char shardState)
void void
EnsureTablePermissions(Oid relationId, AclMode mode) EnsureTablePermissions(Oid relationId, AclMode mode)
{ {
AclResult aclresult; AclResult aclresult = pg_class_aclcheck(relationId, GetUserId(), mode);
aclresult = pg_class_aclcheck(relationId, GetUserId(), mode);
if (aclresult != ACLCHECK_OK) if (aclresult != ACLCHECK_OK)
{ {
@ -1385,17 +1337,14 @@ EnsureSuperUser(void)
char * char *
TableOwner(Oid relationId) TableOwner(Oid relationId)
{ {
Oid userId = InvalidOid; HeapTuple tuple = SearchSysCache1(RELOID, ObjectIdGetDatum(relationId));
HeapTuple tuple;
tuple = SearchSysCache1(RELOID, ObjectIdGetDatum(relationId));
if (!HeapTupleIsValid(tuple)) if (!HeapTupleIsValid(tuple))
{ {
ereport(ERROR, (errcode(ERRCODE_UNDEFINED_TABLE), ereport(ERROR, (errcode(ERRCODE_UNDEFINED_TABLE),
errmsg("relation with OID %u does not exist", relationId))); errmsg("relation with OID %u does not exist", relationId)));
} }
userId = ((Form_pg_class) GETSTRUCT(tuple))->relowner; Oid userId = ((Form_pg_class) GETSTRUCT(tuple))->relowner;
ReleaseSysCache(tuple); ReleaseSysCache(tuple);

View File

@ -94,26 +94,20 @@ master_get_table_metadata(PG_FUNCTION_ARGS)
text *relationName = PG_GETARG_TEXT_P(0); text *relationName = PG_GETARG_TEXT_P(0);
Oid relationId = ResolveRelationId(relationName, false); Oid relationId = ResolveRelationId(relationName, false);
DistTableCacheEntry *partitionEntry = NULL;
char *partitionKeyString = NULL;
TypeFuncClass resultTypeClass = 0;
Datum partitionKeyExpr = 0; Datum partitionKeyExpr = 0;
Datum partitionKey = 0; Datum partitionKey = 0;
Datum metadataDatum = 0;
HeapTuple metadataTuple = NULL;
TupleDesc metadataDescriptor = NULL; TupleDesc metadataDescriptor = NULL;
uint64 shardMaxSizeInBytes = 0;
char shardStorageType = 0;
Datum values[TABLE_METADATA_FIELDS]; Datum values[TABLE_METADATA_FIELDS];
bool isNulls[TABLE_METADATA_FIELDS]; bool isNulls[TABLE_METADATA_FIELDS];
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
/* find partition tuple for partitioned relation */ /* find partition tuple for partitioned relation */
partitionEntry = DistributedTableCacheEntry(relationId); DistTableCacheEntry *partitionEntry = DistributedTableCacheEntry(relationId);
/* create tuple descriptor for return value */ /* create tuple descriptor for return value */
resultTypeClass = get_call_result_type(fcinfo, NULL, &metadataDescriptor); TypeFuncClass resultTypeClass = get_call_result_type(fcinfo, NULL,
&metadataDescriptor);
if (resultTypeClass != TYPEFUNC_COMPOSITE) if (resultTypeClass != TYPEFUNC_COMPOSITE)
{ {
ereport(ERROR, (errmsg("return type must be a row type"))); ereport(ERROR, (errmsg("return type must be a row type")));
@ -123,7 +117,7 @@ master_get_table_metadata(PG_FUNCTION_ARGS)
memset(values, 0, sizeof(values)); memset(values, 0, sizeof(values));
memset(isNulls, false, sizeof(isNulls)); memset(isNulls, false, sizeof(isNulls));
partitionKeyString = partitionEntry->partitionKeyString; char *partitionKeyString = partitionEntry->partitionKeyString;
/* reference tables do not have partition key */ /* reference tables do not have partition key */
if (partitionKeyString == NULL) if (partitionKeyString == NULL)
@ -140,10 +134,10 @@ master_get_table_metadata(PG_FUNCTION_ARGS)
ObjectIdGetDatum(relationId)); ObjectIdGetDatum(relationId));
} }
shardMaxSizeInBytes = (int64) ShardMaxSize * 1024L; uint64 shardMaxSizeInBytes = (int64) ShardMaxSize * 1024L;
/* get storage type */ /* get storage type */
shardStorageType = ShardStorageType(relationId); char shardStorageType = ShardStorageType(relationId);
values[0] = ObjectIdGetDatum(relationId); values[0] = ObjectIdGetDatum(relationId);
values[1] = shardStorageType; values[1] = shardStorageType;
@ -153,8 +147,8 @@ master_get_table_metadata(PG_FUNCTION_ARGS)
values[5] = Int64GetDatum(shardMaxSizeInBytes); values[5] = Int64GetDatum(shardMaxSizeInBytes);
values[6] = Int32GetDatum(ShardPlacementPolicy); values[6] = Int32GetDatum(ShardPlacementPolicy);
metadataTuple = heap_form_tuple(metadataDescriptor, values, isNulls); HeapTuple metadataTuple = heap_form_tuple(metadataDescriptor, values, isNulls);
metadataDatum = HeapTupleGetDatum(metadataTuple); Datum metadataDatum = HeapTupleGetDatum(metadataTuple);
PG_RETURN_DATUM(metadataDatum); PG_RETURN_DATUM(metadataDatum);
} }
@ -212,17 +206,16 @@ master_get_table_ddl_events(PG_FUNCTION_ARGS)
Oid relationId = ResolveRelationId(relationName, false); Oid relationId = ResolveRelationId(relationName, false);
bool includeSequenceDefaults = true; bool includeSequenceDefaults = true;
MemoryContext oldContext = NULL;
List *tableDDLEventList = NIL;
/* create a function context for cross-call persistence */ /* create a function context for cross-call persistence */
functionContext = SRF_FIRSTCALL_INIT(); functionContext = SRF_FIRSTCALL_INIT();
/* switch to memory context appropriate for multiple function calls */ /* switch to memory context appropriate for multiple function calls */
oldContext = MemoryContextSwitchTo(functionContext->multi_call_memory_ctx); MemoryContext oldContext = MemoryContextSwitchTo(
functionContext->multi_call_memory_ctx);
/* allocate DDL statements, and then save position in DDL statements */ /* allocate DDL statements, and then save position in DDL statements */
tableDDLEventList = GetTableDDLEvents(relationId, includeSequenceDefaults); List *tableDDLEventList = GetTableDDLEvents(relationId, includeSequenceDefaults);
tableDDLEventCell = list_head(tableDDLEventList); tableDDLEventCell = list_head(tableDDLEventList);
functionContext->user_fctx = tableDDLEventCell; functionContext->user_fctx = tableDDLEventCell;
@ -266,14 +259,11 @@ master_get_table_ddl_events(PG_FUNCTION_ARGS)
Datum Datum
master_get_new_shardid(PG_FUNCTION_ARGS) master_get_new_shardid(PG_FUNCTION_ARGS)
{ {
uint64 shardId = 0;
Datum shardIdDatum = 0;
EnsureCoordinator(); EnsureCoordinator();
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
shardId = GetNextShardId(); uint64 shardId = GetNextShardId();
shardIdDatum = Int64GetDatum(shardId); Datum shardIdDatum = Int64GetDatum(shardId);
PG_RETURN_DATUM(shardIdDatum); PG_RETURN_DATUM(shardIdDatum);
} }
@ -290,12 +280,8 @@ master_get_new_shardid(PG_FUNCTION_ARGS)
uint64 uint64
GetNextShardId() GetNextShardId()
{ {
text *sequenceName = NULL;
Oid sequenceId = InvalidOid;
Datum sequenceIdDatum = 0;
Oid savedUserId = InvalidOid; Oid savedUserId = InvalidOid;
int savedSecurityContext = 0; int savedSecurityContext = 0;
Datum shardIdDatum = 0;
uint64 shardId = 0; uint64 shardId = 0;
/* /*
@ -313,15 +299,15 @@ GetNextShardId()
return shardId; return shardId;
} }
sequenceName = cstring_to_text(SHARDID_SEQUENCE_NAME); text *sequenceName = cstring_to_text(SHARDID_SEQUENCE_NAME);
sequenceId = ResolveRelationId(sequenceName, false); Oid sequenceId = ResolveRelationId(sequenceName, false);
sequenceIdDatum = ObjectIdGetDatum(sequenceId); Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
/* generate new and unique shardId from sequence */ /* generate new and unique shardId from sequence */
shardIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum); Datum shardIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
SetUserIdAndSecContext(savedUserId, savedSecurityContext); SetUserIdAndSecContext(savedUserId, savedSecurityContext);
@ -343,14 +329,11 @@ GetNextShardId()
Datum Datum
master_get_new_placementid(PG_FUNCTION_ARGS) master_get_new_placementid(PG_FUNCTION_ARGS)
{ {
uint64 placementId = 0;
Datum placementIdDatum = 0;
EnsureCoordinator(); EnsureCoordinator();
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
placementId = GetNextPlacementId(); uint64 placementId = GetNextPlacementId();
placementIdDatum = Int64GetDatum(placementId); Datum placementIdDatum = Int64GetDatum(placementId);
PG_RETURN_DATUM(placementIdDatum); PG_RETURN_DATUM(placementIdDatum);
} }
@ -369,12 +352,8 @@ master_get_new_placementid(PG_FUNCTION_ARGS)
uint64 uint64
GetNextPlacementId(void) GetNextPlacementId(void)
{ {
text *sequenceName = NULL;
Oid sequenceId = InvalidOid;
Datum sequenceIdDatum = 0;
Oid savedUserId = InvalidOid; Oid savedUserId = InvalidOid;
int savedSecurityContext = 0; int savedSecurityContext = 0;
Datum placementIdDatum = 0;
uint64 placementId = 0; uint64 placementId = 0;
/* /*
@ -392,15 +371,15 @@ GetNextPlacementId(void)
return placementId; return placementId;
} }
sequenceName = cstring_to_text(PLACEMENTID_SEQUENCE_NAME); text *sequenceName = cstring_to_text(PLACEMENTID_SEQUENCE_NAME);
sequenceId = ResolveRelationId(sequenceName, false); Oid sequenceId = ResolveRelationId(sequenceName, false);
sequenceIdDatum = ObjectIdGetDatum(sequenceId); Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
/* generate new and unique placement id from sequence */ /* generate new and unique placement id from sequence */
placementIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum); Datum placementIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
SetUserIdAndSecContext(savedUserId, savedSecurityContext); SetUserIdAndSecContext(savedUserId, savedSecurityContext);
@ -465,17 +444,16 @@ master_get_active_worker_nodes(PG_FUNCTION_ARGS)
if (SRF_IS_FIRSTCALL()) if (SRF_IS_FIRSTCALL())
{ {
MemoryContext oldContext = NULL;
List *workerNodeList = NIL;
TupleDesc tupleDescriptor = NULL; TupleDesc tupleDescriptor = NULL;
/* create a function context for cross-call persistence */ /* create a function context for cross-call persistence */
functionContext = SRF_FIRSTCALL_INIT(); functionContext = SRF_FIRSTCALL_INIT();
/* switch to memory context appropriate for multiple function calls */ /* switch to memory context appropriate for multiple function calls */
oldContext = MemoryContextSwitchTo(functionContext->multi_call_memory_ctx); MemoryContext oldContext = MemoryContextSwitchTo(
functionContext->multi_call_memory_ctx);
workerNodeList = ActiveReadableWorkerNodeList(); List *workerNodeList = ActiveReadableWorkerNodeList();
workerNodeCount = (uint32) list_length(workerNodeList); workerNodeCount = (uint32) list_length(workerNodeList);
functionContext->user_fctx = workerNodeList; functionContext->user_fctx = workerNodeList;
@ -525,14 +503,10 @@ master_get_active_worker_nodes(PG_FUNCTION_ARGS)
Oid Oid
ResolveRelationId(text *relationName, bool missingOk) ResolveRelationId(text *relationName, bool missingOk)
{ {
List *relationNameList = NIL;
RangeVar *relation = NULL;
Oid relationId = InvalidOid;
/* resolve relationId from passed in schema and relation name */ /* resolve relationId from passed in schema and relation name */
relationNameList = textToQualifiedNameList(relationName); List *relationNameList = textToQualifiedNameList(relationName);
relation = makeRangeVarFromNameList(relationNameList); RangeVar *relation = makeRangeVarFromNameList(relationNameList);
relationId = RangeVarGetRelid(relation, NoLock, missingOk); Oid relationId = RangeVarGetRelid(relation, NoLock, missingOk);
return relationId; return relationId;
} }
@ -551,22 +525,18 @@ List *
GetTableDDLEvents(Oid relationId, bool includeSequenceDefaults) GetTableDDLEvents(Oid relationId, bool includeSequenceDefaults)
{ {
List *tableDDLEventList = NIL; List *tableDDLEventList = NIL;
List *tableCreationCommandList = NIL;
List *indexAndConstraintCommandList = NIL;
List *replicaIdentityEvents = NIL;
List *policyCommands = NIL;
tableCreationCommandList = GetTableCreationCommands(relationId, List *tableCreationCommandList = GetTableCreationCommands(relationId,
includeSequenceDefaults); includeSequenceDefaults);
tableDDLEventList = list_concat(tableDDLEventList, tableCreationCommandList); tableDDLEventList = list_concat(tableDDLEventList, tableCreationCommandList);
indexAndConstraintCommandList = GetTableIndexAndConstraintCommands(relationId); List *indexAndConstraintCommandList = GetTableIndexAndConstraintCommands(relationId);
tableDDLEventList = list_concat(tableDDLEventList, indexAndConstraintCommandList); tableDDLEventList = list_concat(tableDDLEventList, indexAndConstraintCommandList);
replicaIdentityEvents = GetTableReplicaIdentityCommand(relationId); List *replicaIdentityEvents = GetTableReplicaIdentityCommand(relationId);
tableDDLEventList = list_concat(tableDDLEventList, replicaIdentityEvents); tableDDLEventList = list_concat(tableDDLEventList, replicaIdentityEvents);
policyCommands = CreatePolicyCommands(relationId); List *policyCommands = CreatePolicyCommands(relationId);
tableDDLEventList = list_concat(tableDDLEventList, policyCommands); tableDDLEventList = list_concat(tableDDLEventList, policyCommands);
return tableDDLEventList; return tableDDLEventList;
@ -581,7 +551,6 @@ static List *
GetTableReplicaIdentityCommand(Oid relationId) GetTableReplicaIdentityCommand(Oid relationId)
{ {
List *replicaIdentityCreateCommandList = NIL; List *replicaIdentityCreateCommandList = NIL;
char *replicaIdentityCreateCommand = NULL;
/* /*
* We skip non-relations because postgres does not support * We skip non-relations because postgres does not support
@ -593,7 +562,7 @@ GetTableReplicaIdentityCommand(Oid relationId)
return NIL; return NIL;
} }
replicaIdentityCreateCommand = pg_get_replica_identity_command(relationId); char *replicaIdentityCreateCommand = pg_get_replica_identity_command(relationId);
if (replicaIdentityCreateCommand) if (replicaIdentityCreateCommand)
{ {
@ -614,10 +583,6 @@ List *
GetTableCreationCommands(Oid relationId, bool includeSequenceDefaults) GetTableCreationCommands(Oid relationId, bool includeSequenceDefaults)
{ {
List *tableDDLEventList = NIL; List *tableDDLEventList = NIL;
char tableType = 0;
char *tableSchemaDef = NULL;
char *tableColumnOptionsDef = NULL;
char *tableOwnerDef = NULL;
/* /*
* Set search_path to NIL so that all objects outside of pg_catalog will be * Set search_path to NIL so that all objects outside of pg_catalog will be
@ -630,7 +595,7 @@ GetTableCreationCommands(Oid relationId, bool includeSequenceDefaults)
PushOverrideSearchPath(overridePath); PushOverrideSearchPath(overridePath);
/* if foreign table, fetch extension and server definitions */ /* if foreign table, fetch extension and server definitions */
tableType = get_rel_relkind(relationId); char tableType = get_rel_relkind(relationId);
if (tableType == RELKIND_FOREIGN_TABLE) if (tableType == RELKIND_FOREIGN_TABLE)
{ {
char *extensionDef = pg_get_extensiondef_string(relationId); char *extensionDef = pg_get_extensiondef_string(relationId);
@ -644,8 +609,9 @@ GetTableCreationCommands(Oid relationId, bool includeSequenceDefaults)
} }
/* fetch table schema and column option definitions */ /* fetch table schema and column option definitions */
tableSchemaDef = pg_get_tableschemadef_string(relationId, includeSequenceDefaults); char *tableSchemaDef = pg_get_tableschemadef_string(relationId,
tableColumnOptionsDef = pg_get_tablecolumnoptionsdef_string(relationId); includeSequenceDefaults);
char *tableColumnOptionsDef = pg_get_tablecolumnoptionsdef_string(relationId);
tableDDLEventList = lappend(tableDDLEventList, tableSchemaDef); tableDDLEventList = lappend(tableDDLEventList, tableSchemaDef);
if (tableColumnOptionsDef != NULL) if (tableColumnOptionsDef != NULL)
@ -653,7 +619,7 @@ GetTableCreationCommands(Oid relationId, bool includeSequenceDefaults)
tableDDLEventList = lappend(tableDDLEventList, tableColumnOptionsDef); tableDDLEventList = lappend(tableDDLEventList, tableColumnOptionsDef);
} }
tableOwnerDef = TableOwnerResetCommand(relationId); char *tableOwnerDef = TableOwnerResetCommand(relationId);
if (tableOwnerDef != NULL) if (tableOwnerDef != NULL)
{ {
tableDDLEventList = lappend(tableDDLEventList, tableOwnerDef); tableDDLEventList = lappend(tableDDLEventList, tableOwnerDef);
@ -674,11 +640,8 @@ List *
GetTableIndexAndConstraintCommands(Oid relationId) GetTableIndexAndConstraintCommands(Oid relationId)
{ {
List *indexDDLEventList = NIL; List *indexDDLEventList = NIL;
Relation pgIndex = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
HeapTuple heapTuple = NULL;
/* /*
* Set search_path to NIL so that all objects outside of pg_catalog will be * Set search_path to NIL so that all objects outside of pg_catalog will be
@ -691,16 +654,16 @@ GetTableIndexAndConstraintCommands(Oid relationId)
PushOverrideSearchPath(overridePath); PushOverrideSearchPath(overridePath);
/* open system catalog and scan all indexes that belong to this table */ /* open system catalog and scan all indexes that belong to this table */
pgIndex = heap_open(IndexRelationId, AccessShareLock); Relation pgIndex = heap_open(IndexRelationId, AccessShareLock);
ScanKeyInit(&scanKey[0], Anum_pg_index_indrelid, ScanKeyInit(&scanKey[0], Anum_pg_index_indrelid,
BTEqualStrategyNumber, F_OIDEQ, relationId); BTEqualStrategyNumber, F_OIDEQ, relationId);
scanDescriptor = systable_beginscan(pgIndex, SysScanDesc scanDescriptor = systable_beginscan(pgIndex,
IndexIndrelidIndexId, true, /* indexOK */ IndexIndrelidIndexId, true, /* indexOK */
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
while (HeapTupleIsValid(heapTuple)) while (HeapTupleIsValid(heapTuple))
{ {
Form_pg_index indexForm = (Form_pg_index) GETSTRUCT(heapTuple); Form_pg_index indexForm = (Form_pg_index) GETSTRUCT(heapTuple);
@ -824,8 +787,6 @@ WorkerNodeGetDatum(WorkerNode *workerNode, TupleDesc tupleDescriptor)
{ {
Datum values[WORKER_NODE_FIELDS]; Datum values[WORKER_NODE_FIELDS];
bool isNulls[WORKER_NODE_FIELDS]; bool isNulls[WORKER_NODE_FIELDS];
HeapTuple workerNodeTuple = NULL;
Datum workerNodeDatum = 0;
memset(values, 0, sizeof(values)); memset(values, 0, sizeof(values));
memset(isNulls, false, sizeof(isNulls)); memset(isNulls, false, sizeof(isNulls));
@ -833,8 +794,8 @@ WorkerNodeGetDatum(WorkerNode *workerNode, TupleDesc tupleDescriptor)
values[0] = CStringGetTextDatum(workerNode->workerName); values[0] = CStringGetTextDatum(workerNode->workerName);
values[1] = Int64GetDatum((int64) workerNode->workerPort); values[1] = Int64GetDatum((int64) workerNode->workerPort);
workerNodeTuple = heap_form_tuple(tupleDescriptor, values, isNulls); HeapTuple workerNodeTuple = heap_form_tuple(tupleDescriptor, values, isNulls);
workerNodeDatum = HeapTupleGetDatum(workerNodeTuple); Datum workerNodeDatum = HeapTupleGetDatum(workerNodeTuple);
return workerNodeDatum; return workerNodeDatum;
} }

View File

@ -139,9 +139,6 @@ BlockWritesToShardList(List *shardList)
{ {
ListCell *shardCell = NULL; ListCell *shardCell = NULL;
bool shouldSyncMetadata = false;
ShardInterval *firstShardInterval = NULL;
Oid firstDistributedTableId = InvalidOid;
foreach(shardCell, shardList) foreach(shardCell, shardList)
{ {
@ -167,10 +164,10 @@ BlockWritesToShardList(List *shardList)
* Since the function assumes that the input shards are colocated, * Since the function assumes that the input shards are colocated,
* calculating shouldSyncMetadata for a single table is sufficient. * calculating shouldSyncMetadata for a single table is sufficient.
*/ */
firstShardInterval = (ShardInterval *) linitial(shardList); ShardInterval *firstShardInterval = (ShardInterval *) linitial(shardList);
firstDistributedTableId = firstShardInterval->relationId; Oid firstDistributedTableId = firstShardInterval->relationId;
shouldSyncMetadata = ShouldSyncTableMetadata(firstDistributedTableId); bool shouldSyncMetadata = ShouldSyncTableMetadata(firstDistributedTableId);
if (shouldSyncMetadata) if (shouldSyncMetadata)
{ {
LockShardListMetadataOnWorkers(ExclusiveLock, shardList); LockShardListMetadataOnWorkers(ExclusiveLock, shardList);
@ -225,13 +222,7 @@ RepairShardPlacement(int64 shardId, char *sourceNodeName, int32 sourceNodePort,
char relationKind = get_rel_relkind(distributedTableId); char relationKind = get_rel_relkind(distributedTableId);
char *tableOwner = TableOwner(shardInterval->relationId); char *tableOwner = TableOwner(shardInterval->relationId);
bool missingOk = false; bool missingOk = false;
bool includeData = false;
bool partitionedTable = false;
List *ddlCommandList = NIL;
List *foreignConstraintCommandList = NIL;
List *placementList = NIL;
ShardPlacement *placement = NULL;
/* prevent table from being dropped */ /* prevent table from being dropped */
LockRelationOid(distributedTableId, AccessShareLock); LockRelationOid(distributedTableId, AccessShareLock);
@ -287,13 +278,14 @@ RepairShardPlacement(int64 shardId, char *sourceNodeName, int32 sourceNodePort,
* If the shard belongs to a partitioned table, we need to load the data after * If the shard belongs to a partitioned table, we need to load the data after
* creating the partitions and the partitioning hierarcy. * creating the partitions and the partitioning hierarcy.
*/ */
partitionedTable = PartitionedTableNoLock(distributedTableId); bool partitionedTable = PartitionedTableNoLock(distributedTableId);
includeData = !partitionedTable; bool includeData = !partitionedTable;
/* we generate necessary commands to recreate the shard in target node */ /* we generate necessary commands to recreate the shard in target node */
ddlCommandList = List *ddlCommandList =
CopyShardCommandList(shardInterval, sourceNodeName, sourceNodePort, includeData); CopyShardCommandList(shardInterval, sourceNodeName, sourceNodePort, includeData);
foreignConstraintCommandList = CopyShardForeignConstraintCommandList(shardInterval); List *foreignConstraintCommandList = CopyShardForeignConstraintCommandList(
shardInterval);
ddlCommandList = list_concat(ddlCommandList, foreignConstraintCommandList); ddlCommandList = list_concat(ddlCommandList, foreignConstraintCommandList);
/* /*
@ -305,12 +297,10 @@ RepairShardPlacement(int64 shardId, char *sourceNodeName, int32 sourceNodePort,
*/ */
if (partitionedTable) if (partitionedTable)
{ {
List *partitionCommandList = NIL;
char *shardName = ConstructQualifiedShardName(shardInterval); char *shardName = ConstructQualifiedShardName(shardInterval);
StringInfo copyShardDataCommand = makeStringInfo(); StringInfo copyShardDataCommand = makeStringInfo();
partitionCommandList = List *partitionCommandList =
CopyPartitionShardsCommandList(shardInterval, sourceNodeName, sourceNodePort); CopyPartitionShardsCommandList(shardInterval, sourceNodeName, sourceNodePort);
ddlCommandList = list_concat(ddlCommandList, partitionCommandList); ddlCommandList = list_concat(ddlCommandList, partitionCommandList);
@ -328,8 +318,9 @@ RepairShardPlacement(int64 shardId, char *sourceNodeName, int32 sourceNodePort,
ddlCommandList); ddlCommandList);
/* after successful repair, we update shard state as healthy*/ /* after successful repair, we update shard state as healthy*/
placementList = ShardPlacementList(shardId); List *placementList = ShardPlacementList(shardId);
placement = SearchShardPlacementInList(placementList, targetNodeName, targetNodePort, ShardPlacement *placement = SearchShardPlacementInList(placementList, targetNodeName,
targetNodePort,
missingOk); missingOk);
UpdateShardPlacementState(placement->placementId, FILE_FINALIZED); UpdateShardPlacementState(placement->placementId, FILE_FINALIZED);
} }
@ -347,13 +338,12 @@ CopyPartitionShardsCommandList(ShardInterval *shardInterval, char *sourceNodeNam
int32 sourceNodePort) int32 sourceNodePort)
{ {
Oid distributedTableId = shardInterval->relationId; Oid distributedTableId = shardInterval->relationId;
List *partitionList = NIL;
ListCell *partitionOidCell = NULL; ListCell *partitionOidCell = NULL;
List *ddlCommandList = NIL; List *ddlCommandList = NIL;
Assert(PartitionedTableNoLock(distributedTableId)); Assert(PartitionedTableNoLock(distributedTableId));
partitionList = PartitionList(distributedTableId); List *partitionList = PartitionList(distributedTableId);
foreach(partitionOidCell, partitionList) foreach(partitionOidCell, partitionList)
{ {
Oid partitionOid = lfirst_oid(partitionOidCell); Oid partitionOid = lfirst_oid(partitionOidCell);
@ -361,15 +351,13 @@ CopyPartitionShardsCommandList(ShardInterval *shardInterval, char *sourceNodeNam
ColocatedShardIdInRelation(partitionOid, shardInterval->shardIndex); ColocatedShardIdInRelation(partitionOid, shardInterval->shardIndex);
ShardInterval *partitionShardInterval = LoadShardInterval(partitionShardId); ShardInterval *partitionShardInterval = LoadShardInterval(partitionShardId);
bool includeData = false; bool includeData = false;
List *copyCommandList = NIL;
char *attachPartitionCommand = NULL;
copyCommandList = List *copyCommandList =
CopyShardCommandList(partitionShardInterval, sourceNodeName, sourceNodePort, CopyShardCommandList(partitionShardInterval, sourceNodeName, sourceNodePort,
includeData); includeData);
ddlCommandList = list_concat(ddlCommandList, copyCommandList); ddlCommandList = list_concat(ddlCommandList, copyCommandList);
attachPartitionCommand = char *attachPartitionCommand =
GenerateAttachShardPartitionCommand(partitionShardInterval); GenerateAttachShardPartitionCommand(partitionShardInterval);
ddlCommandList = lappend(ddlCommandList, attachPartitionCommand); ddlCommandList = lappend(ddlCommandList, attachPartitionCommand);
} }
@ -387,21 +375,23 @@ EnsureShardCanBeRepaired(int64 shardId, char *sourceNodeName, int32 sourceNodePo
char *targetNodeName, int32 targetNodePort) char *targetNodeName, int32 targetNodePort)
{ {
List *shardPlacementList = ShardPlacementList(shardId); List *shardPlacementList = ShardPlacementList(shardId);
ShardPlacement *sourcePlacement = NULL;
ShardPlacement *targetPlacement = NULL;
bool missingSourceOk = false; bool missingSourceOk = false;
bool missingTargetOk = false; bool missingTargetOk = false;
sourcePlacement = SearchShardPlacementInList(shardPlacementList, sourceNodeName, ShardPlacement *sourcePlacement = SearchShardPlacementInList(shardPlacementList,
sourceNodePort, missingSourceOk); sourceNodeName,
sourceNodePort,
missingSourceOk);
if (sourcePlacement->shardState != FILE_FINALIZED) if (sourcePlacement->shardState != FILE_FINALIZED)
{ {
ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("source placement must be in finalized state"))); errmsg("source placement must be in finalized state")));
} }
targetPlacement = SearchShardPlacementInList(shardPlacementList, targetNodeName, ShardPlacement *targetPlacement = SearchShardPlacementInList(shardPlacementList,
targetNodePort, missingTargetOk); targetNodeName,
targetNodePort,
missingTargetOk);
if (targetPlacement->shardState != FILE_INACTIVE) if (targetPlacement->shardState != FILE_INACTIVE)
{ {
ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
@ -462,13 +452,11 @@ CopyShardCommandList(ShardInterval *shardInterval, char *sourceNodeName,
{ {
int64 shardId = shardInterval->shardId; int64 shardId = shardInterval->shardId;
char *shardName = ConstructQualifiedShardName(shardInterval); char *shardName = ConstructQualifiedShardName(shardInterval);
List *tableRecreationCommandList = NIL;
List *indexCommandList = NIL;
List *copyShardToNodeCommandsList = NIL; List *copyShardToNodeCommandsList = NIL;
StringInfo copyShardDataCommand = makeStringInfo(); StringInfo copyShardDataCommand = makeStringInfo();
Oid relationId = shardInterval->relationId; Oid relationId = shardInterval->relationId;
tableRecreationCommandList = RecreateTableDDLCommandList(relationId); List *tableRecreationCommandList = RecreateTableDDLCommandList(relationId);
tableRecreationCommandList = tableRecreationCommandList =
WorkerApplyShardDDLCommandList(tableRecreationCommandList, shardId); WorkerApplyShardDDLCommandList(tableRecreationCommandList, shardId);
@ -491,7 +479,7 @@ CopyShardCommandList(ShardInterval *shardInterval, char *sourceNodeName,
copyShardDataCommand->data); copyShardDataCommand->data);
} }
indexCommandList = GetTableIndexAndConstraintCommands(relationId); List *indexCommandList = GetTableIndexAndConstraintCommands(relationId);
indexCommandList = WorkerApplyShardDDLCommandList(indexCommandList, shardId); indexCommandList = WorkerApplyShardDDLCommandList(indexCommandList, shardId);
copyShardToNodeCommandsList = list_concat(copyShardToNodeCommandsList, copyShardToNodeCommandsList = list_concat(copyShardToNodeCommandsList,
@ -555,17 +543,13 @@ CopyShardForeignConstraintCommandListGrouped(ShardInterval *shardInterval,
char *command = (char *) lfirst(commandCell); char *command = (char *) lfirst(commandCell);
char *escapedCommand = quote_literal_cstr(command); char *escapedCommand = quote_literal_cstr(command);
Oid referencedRelationId = InvalidOid;
Oid referencedSchemaId = InvalidOid;
char *referencedSchemaName = NULL;
char *escapedReferencedSchemaName = NULL;
uint64 referencedShardId = INVALID_SHARD_ID; uint64 referencedShardId = INVALID_SHARD_ID;
bool colocatedForeignKey = false; bool colocatedForeignKey = false;
StringInfo applyForeignConstraintCommand = makeStringInfo(); StringInfo applyForeignConstraintCommand = makeStringInfo();
/* we need to parse the foreign constraint command to get referencing table id */ /* we need to parse the foreign constraint command to get referencing table id */
referencedRelationId = ForeignConstraintGetReferencedTableId(command); Oid referencedRelationId = ForeignConstraintGetReferencedTableId(command);
if (referencedRelationId == InvalidOid) if (referencedRelationId == InvalidOid)
{ {
ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
@ -573,9 +557,9 @@ CopyShardForeignConstraintCommandListGrouped(ShardInterval *shardInterval,
errdetail("Referenced relation cannot be found."))); errdetail("Referenced relation cannot be found.")));
} }
referencedSchemaId = get_rel_namespace(referencedRelationId); Oid referencedSchemaId = get_rel_namespace(referencedRelationId);
referencedSchemaName = get_namespace_name(referencedSchemaId); char *referencedSchemaName = get_namespace_name(referencedSchemaId);
escapedReferencedSchemaName = quote_literal_cstr(referencedSchemaName); char *escapedReferencedSchemaName = quote_literal_cstr(referencedSchemaName);
if (PartitionMethod(referencedRelationId) == DISTRIBUTE_BY_NONE) if (PartitionMethod(referencedRelationId) == DISTRIBUTE_BY_NONE)
{ {
@ -635,9 +619,8 @@ ConstructQualifiedShardName(ShardInterval *shardInterval)
Oid schemaId = get_rel_namespace(shardInterval->relationId); Oid schemaId = get_rel_namespace(shardInterval->relationId);
char *schemaName = get_namespace_name(schemaId); char *schemaName = get_namespace_name(schemaId);
char *tableName = get_rel_name(shardInterval->relationId); char *tableName = get_rel_name(shardInterval->relationId);
char *shardName = NULL;
shardName = pstrdup(tableName); char *shardName = pstrdup(tableName);
AppendShardIdToName(&shardName, shardInterval->shardId); AppendShardIdToName(&shardName, shardInterval->shardId);
shardName = quote_qualified_identifier(schemaName, shardName); shardName = quote_qualified_identifier(schemaName, shardName);
@ -660,9 +643,6 @@ RecreateTableDDLCommandList(Oid relationId)
relationName); relationName);
StringInfo dropCommand = makeStringInfo(); StringInfo dropCommand = makeStringInfo();
List *createCommandList = NIL;
List *dropCommandList = NIL;
List *recreateCommandList = NIL;
char relationKind = get_rel_relkind(relationId); char relationKind = get_rel_relkind(relationId);
bool includeSequenceDefaults = false; bool includeSequenceDefaults = false;
@ -684,9 +664,10 @@ RecreateTableDDLCommandList(Oid relationId)
"table"))); "table")));
} }
dropCommandList = list_make1(dropCommand->data); List *dropCommandList = list_make1(dropCommand->data);
createCommandList = GetTableCreationCommands(relationId, includeSequenceDefaults); List *createCommandList = GetTableCreationCommands(relationId,
recreateCommandList = list_concat(dropCommandList, createCommandList); includeSequenceDefaults);
List *recreateCommandList = list_concat(dropCommandList, createCommandList);
return recreateCommandList; return recreateCommandList;
} }

View File

@ -62,16 +62,13 @@ Datum
worker_hash(PG_FUNCTION_ARGS) worker_hash(PG_FUNCTION_ARGS)
{ {
Datum valueDatum = PG_GETARG_DATUM(0); Datum valueDatum = PG_GETARG_DATUM(0);
Datum hashedValueDatum = 0;
TypeCacheEntry *typeEntry = NULL;
FmgrInfo *hashFunction = NULL;
Oid valueDataType = InvalidOid;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
/* figure out hash function from the data type */ /* figure out hash function from the data type */
valueDataType = get_fn_expr_argtype(fcinfo->flinfo, 0); Oid valueDataType = get_fn_expr_argtype(fcinfo->flinfo, 0);
typeEntry = lookup_type_cache(valueDataType, TYPECACHE_HASH_PROC_FINFO); TypeCacheEntry *typeEntry = lookup_type_cache(valueDataType,
TYPECACHE_HASH_PROC_FINFO);
if (typeEntry->hash_proc_finfo.fn_oid == InvalidOid) if (typeEntry->hash_proc_finfo.fn_oid == InvalidOid)
{ {
@ -80,11 +77,12 @@ worker_hash(PG_FUNCTION_ARGS)
errhint("Cast input to a data type with a hash function."))); errhint("Cast input to a data type with a hash function.")));
} }
hashFunction = palloc0(sizeof(FmgrInfo)); FmgrInfo *hashFunction = palloc0(sizeof(FmgrInfo));
fmgr_info_copy(hashFunction, &(typeEntry->hash_proc_finfo), CurrentMemoryContext); fmgr_info_copy(hashFunction, &(typeEntry->hash_proc_finfo), CurrentMemoryContext);
/* calculate hash value */ /* calculate hash value */
hashedValueDatum = FunctionCall1Coll(hashFunction, PG_GET_COLLATION(), valueDatum); Datum hashedValueDatum = FunctionCall1Coll(hashFunction, PG_GET_COLLATION(),
valueDatum);
PG_RETURN_INT32(hashedValueDatum); PG_RETURN_INT32(hashedValueDatum);
} }

View File

@ -80,21 +80,17 @@ master_create_empty_shard(PG_FUNCTION_ARGS)
{ {
text *relationNameText = PG_GETARG_TEXT_P(0); text *relationNameText = PG_GETARG_TEXT_P(0);
char *relationName = text_to_cstring(relationNameText); char *relationName = text_to_cstring(relationNameText);
uint64 shardId = INVALID_SHARD_ID;
uint32 attemptableNodeCount = 0; uint32 attemptableNodeCount = 0;
ObjectAddress tableAddress = { 0 }; ObjectAddress tableAddress = { 0 };
uint32 candidateNodeIndex = 0; uint32 candidateNodeIndex = 0;
List *candidateNodeList = NIL; List *candidateNodeList = NIL;
List *workerNodeList = NIL;
text *nullMinValue = NULL; text *nullMinValue = NULL;
text *nullMaxValue = NULL; text *nullMaxValue = NULL;
char partitionMethod = 0;
char storageType = SHARD_STORAGE_TABLE; char storageType = SHARD_STORAGE_TABLE;
Oid relationId = ResolveRelationId(relationNameText, false); Oid relationId = ResolveRelationId(relationNameText, false);
char relationKind = get_rel_relkind(relationId); char relationKind = get_rel_relkind(relationId);
char replicationModel = REPLICATION_MODEL_INVALID;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
@ -136,7 +132,7 @@ master_create_empty_shard(PG_FUNCTION_ARGS)
} }
} }
partitionMethod = PartitionMethod(relationId); char partitionMethod = PartitionMethod(relationId);
if (partitionMethod == DISTRIBUTE_BY_HASH) if (partitionMethod == DISTRIBUTE_BY_HASH)
{ {
ereport(ERROR, (errmsg("relation \"%s\" is a hash partitioned table", ereport(ERROR, (errmsg("relation \"%s\" is a hash partitioned table",
@ -152,15 +148,15 @@ master_create_empty_shard(PG_FUNCTION_ARGS)
"on reference tables"))); "on reference tables")));
} }
replicationModel = TableReplicationModel(relationId); char replicationModel = TableReplicationModel(relationId);
EnsureReplicationSettings(relationId, replicationModel); EnsureReplicationSettings(relationId, replicationModel);
/* generate new and unique shardId from sequence */ /* generate new and unique shardId from sequence */
shardId = GetNextShardId(); uint64 shardId = GetNextShardId();
/* if enough live groups, add an extra candidate node as backup */ /* if enough live groups, add an extra candidate node as backup */
workerNodeList = DistributedTablePlacementNodeList(NoLock); List *workerNodeList = DistributedTablePlacementNodeList(NoLock);
if (list_length(workerNodeList) > ShardReplicationFactor) if (list_length(workerNodeList) > ShardReplicationFactor)
{ {
@ -232,33 +228,20 @@ master_append_table_to_shard(PG_FUNCTION_ARGS)
char *sourceTableName = text_to_cstring(sourceTableNameText); char *sourceTableName = text_to_cstring(sourceTableNameText);
char *sourceNodeName = text_to_cstring(sourceNodeNameText); char *sourceNodeName = text_to_cstring(sourceNodeNameText);
Oid shardSchemaOid = 0;
char *shardSchemaName = NULL;
char *shardTableName = NULL;
char *shardQualifiedName = NULL;
List *shardPlacementList = NIL;
ListCell *shardPlacementCell = NULL; ListCell *shardPlacementCell = NULL;
uint64 newShardSize = 0;
uint64 shardMaxSizeInBytes = 0;
float4 shardFillLevel = 0.0; float4 shardFillLevel = 0.0;
char partitionMethod = 0;
ShardInterval *shardInterval = NULL;
Oid relationId = InvalidOid;
bool cstoreTable = false;
char storageType = 0;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
shardInterval = LoadShardInterval(shardId); ShardInterval *shardInterval = LoadShardInterval(shardId);
relationId = shardInterval->relationId; Oid relationId = shardInterval->relationId;
/* don't allow the table to be dropped */ /* don't allow the table to be dropped */
LockRelationOid(relationId, AccessShareLock); LockRelationOid(relationId, AccessShareLock);
cstoreTable = CStoreTable(relationId); bool cstoreTable = CStoreTable(relationId);
storageType = shardInterval->storageType; char storageType = shardInterval->storageType;
EnsureTablePermissions(relationId, ACL_INSERT); EnsureTablePermissions(relationId, ACL_INSERT);
@ -268,7 +251,7 @@ master_append_table_to_shard(PG_FUNCTION_ARGS)
errdetail("The underlying shard is not a regular table"))); errdetail("The underlying shard is not a regular table")));
} }
partitionMethod = PartitionMethod(relationId); char partitionMethod = PartitionMethod(relationId);
if (partitionMethod == DISTRIBUTE_BY_HASH || partitionMethod == DISTRIBUTE_BY_NONE) if (partitionMethod == DISTRIBUTE_BY_HASH || partitionMethod == DISTRIBUTE_BY_NONE)
{ {
ereport(ERROR, (errmsg("cannot append to shardId " UINT64_FORMAT, shardId), ereport(ERROR, (errmsg("cannot append to shardId " UINT64_FORMAT, shardId),
@ -283,16 +266,17 @@ master_append_table_to_shard(PG_FUNCTION_ARGS)
LockShardResource(shardId, ExclusiveLock); LockShardResource(shardId, ExclusiveLock);
/* get schame name of the target shard */ /* get schame name of the target shard */
shardSchemaOid = get_rel_namespace(relationId); Oid shardSchemaOid = get_rel_namespace(relationId);
shardSchemaName = get_namespace_name(shardSchemaOid); char *shardSchemaName = get_namespace_name(shardSchemaOid);
/* Build shard table name. */ /* Build shard table name. */
shardTableName = get_rel_name(relationId); char *shardTableName = get_rel_name(relationId);
AppendShardIdToName(&shardTableName, shardId); AppendShardIdToName(&shardTableName, shardId);
shardQualifiedName = quote_qualified_identifier(shardSchemaName, shardTableName); char *shardQualifiedName = quote_qualified_identifier(shardSchemaName,
shardTableName);
shardPlacementList = FinalizedShardPlacementList(shardId); List *shardPlacementList = FinalizedShardPlacementList(shardId);
if (shardPlacementList == NIL) if (shardPlacementList == NIL)
{ {
ereport(ERROR, (errmsg("could not find any shard placements for shardId " ereport(ERROR, (errmsg("could not find any shard placements for shardId "
@ -309,7 +293,6 @@ master_append_table_to_shard(PG_FUNCTION_ARGS)
MultiConnection *connection = GetPlacementConnection(FOR_DML, shardPlacement, MultiConnection *connection = GetPlacementConnection(FOR_DML, shardPlacement,
NULL); NULL);
PGresult *queryResult = NULL; PGresult *queryResult = NULL;
int executeResult = 0;
StringInfo workerAppendQuery = makeStringInfo(); StringInfo workerAppendQuery = makeStringInfo();
appendStringInfo(workerAppendQuery, WORKER_APPEND_TABLE_TO_SHARD, appendStringInfo(workerAppendQuery, WORKER_APPEND_TABLE_TO_SHARD,
@ -319,7 +302,8 @@ master_append_table_to_shard(PG_FUNCTION_ARGS)
RemoteTransactionBeginIfNecessary(connection); RemoteTransactionBeginIfNecessary(connection);
executeResult = ExecuteOptionalRemoteCommand(connection, workerAppendQuery->data, int executeResult = ExecuteOptionalRemoteCommand(connection,
workerAppendQuery->data,
&queryResult); &queryResult);
PQclear(queryResult); PQclear(queryResult);
ForgetResults(connection); ForgetResults(connection);
@ -333,10 +317,10 @@ master_append_table_to_shard(PG_FUNCTION_ARGS)
MarkFailedShardPlacements(); MarkFailedShardPlacements();
/* update shard statistics and get new shard size */ /* update shard statistics and get new shard size */
newShardSize = UpdateShardStatistics(shardId); uint64 newShardSize = UpdateShardStatistics(shardId);
/* calculate ratio of current shard size compared to shard max size */ /* calculate ratio of current shard size compared to shard max size */
shardMaxSizeInBytes = (int64) ShardMaxSize * 1024L; uint64 shardMaxSizeInBytes = (int64) ShardMaxSize * 1024L;
shardFillLevel = ((float4) newShardSize / (float4) shardMaxSizeInBytes); shardFillLevel = ((float4) newShardSize / (float4) shardMaxSizeInBytes);
PG_RETURN_FLOAT4(shardFillLevel); PG_RETURN_FLOAT4(shardFillLevel);
@ -351,11 +335,10 @@ Datum
master_update_shard_statistics(PG_FUNCTION_ARGS) master_update_shard_statistics(PG_FUNCTION_ARGS)
{ {
int64 shardId = PG_GETARG_INT64(0); int64 shardId = PG_GETARG_INT64(0);
uint64 shardSize = 0;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
shardSize = UpdateShardStatistics(shardId); uint64 shardSize = UpdateShardStatistics(shardId);
PG_RETURN_INT64(shardSize); PG_RETURN_INT64(shardSize);
} }
@ -393,7 +376,6 @@ CreateAppendDistributedShardPlacements(Oid relationId, int64 shardId,
int attemptCount = replicationFactor; int attemptCount = replicationFactor;
int workerNodeCount = list_length(workerNodeList); int workerNodeCount = list_length(workerNodeList);
int placementsCreated = 0; int placementsCreated = 0;
int attemptNumber = 0;
List *foreignConstraintCommandList = GetTableForeignConstraintCommands(relationId); List *foreignConstraintCommandList = GetTableForeignConstraintCommands(relationId);
bool includeSequenceDefaults = false; bool includeSequenceDefaults = false;
List *ddlCommandList = GetTableDDLEvents(relationId, includeSequenceDefaults); List *ddlCommandList = GetTableDDLEvents(relationId, includeSequenceDefaults);
@ -406,7 +388,7 @@ CreateAppendDistributedShardPlacements(Oid relationId, int64 shardId,
attemptCount++; attemptCount++;
} }
for (attemptNumber = 0; attemptNumber < attemptCount; attemptNumber++) for (int attemptNumber = 0; attemptNumber < attemptCount; attemptNumber++)
{ {
int workerNodeIndex = attemptNumber % workerNodeCount; int workerNodeIndex = attemptNumber % workerNodeCount;
WorkerNode *workerNode = (WorkerNode *) list_nth(workerNodeList, workerNodeIndex); WorkerNode *workerNode = (WorkerNode *) list_nth(workerNodeList, workerNodeIndex);
@ -419,7 +401,6 @@ CreateAppendDistributedShardPlacements(Oid relationId, int64 shardId,
MultiConnection *connection = MultiConnection *connection =
GetNodeUserDatabaseConnection(connectionFlag, nodeName, nodePort, GetNodeUserDatabaseConnection(connectionFlag, nodeName, nodePort,
relationOwner, NULL); relationOwner, NULL);
List *commandList = NIL;
if (PQstatus(connection->pgConn) != CONNECTION_OK) if (PQstatus(connection->pgConn) != CONNECTION_OK)
{ {
@ -429,7 +410,7 @@ CreateAppendDistributedShardPlacements(Oid relationId, int64 shardId,
continue; continue;
} }
commandList = WorkerCreateShardCommandList(relationId, shardIndex, shardId, List *commandList = WorkerCreateShardCommandList(relationId, shardIndex, shardId,
ddlCommandList, ddlCommandList,
foreignConstraintCommandList); foreignConstraintCommandList);
@ -463,23 +444,21 @@ InsertShardPlacementRows(Oid relationId, int64 shardId, List *workerNodeList,
int workerStartIndex, int replicationFactor) int workerStartIndex, int replicationFactor)
{ {
int workerNodeCount = list_length(workerNodeList); int workerNodeCount = list_length(workerNodeList);
int attemptNumber = 0;
int placementsInserted = 0; int placementsInserted = 0;
List *insertedShardPlacements = NIL; List *insertedShardPlacements = NIL;
for (attemptNumber = 0; attemptNumber < replicationFactor; attemptNumber++) for (int attemptNumber = 0; attemptNumber < replicationFactor; attemptNumber++)
{ {
int workerNodeIndex = (workerStartIndex + attemptNumber) % workerNodeCount; int workerNodeIndex = (workerStartIndex + attemptNumber) % workerNodeCount;
WorkerNode *workerNode = (WorkerNode *) list_nth(workerNodeList, workerNodeIndex); WorkerNode *workerNode = (WorkerNode *) list_nth(workerNodeList, workerNodeIndex);
uint32 nodeGroupId = workerNode->groupId; uint32 nodeGroupId = workerNode->groupId;
const RelayFileState shardState = FILE_FINALIZED; const RelayFileState shardState = FILE_FINALIZED;
const uint64 shardSize = 0; const uint64 shardSize = 0;
uint64 shardPlacementId = 0;
ShardPlacement *shardPlacement = NULL;
shardPlacementId = InsertShardPlacementRow(shardId, INVALID_PLACEMENT_ID, uint64 shardPlacementId = InsertShardPlacementRow(shardId, INVALID_PLACEMENT_ID,
shardState, shardSize, nodeGroupId); shardState, shardSize,
shardPlacement = LoadShardPlacement(shardId, shardPlacementId); nodeGroupId);
ShardPlacement *shardPlacement = LoadShardPlacement(shardId, shardPlacementId);
insertedShardPlacements = lappend(insertedShardPlacements, shardPlacement); insertedShardPlacements = lappend(insertedShardPlacements, shardPlacement);
placementsInserted++; placementsInserted++;
@ -519,8 +498,6 @@ CreateShardsOnWorkers(Oid distributedRelationId, List *shardPlacements,
uint64 shardId = shardPlacement->shardId; uint64 shardId = shardPlacement->shardId;
ShardInterval *shardInterval = LoadShardInterval(shardId); ShardInterval *shardInterval = LoadShardInterval(shardId);
int shardIndex = -1; int shardIndex = -1;
List *commandList = NIL;
Task *task = NULL;
List *relationShardList = RelationShardListForShardCreate(shardInterval); List *relationShardList = RelationShardListForShardCreate(shardInterval);
if (colocatedShard) if (colocatedShard)
@ -528,11 +505,12 @@ CreateShardsOnWorkers(Oid distributedRelationId, List *shardPlacements,
shardIndex = ShardIndex(shardInterval); shardIndex = ShardIndex(shardInterval);
} }
commandList = WorkerCreateShardCommandList(distributedRelationId, shardIndex, List *commandList = WorkerCreateShardCommandList(distributedRelationId,
shardIndex,
shardId, ddlCommandList, shardId, ddlCommandList,
foreignConstraintCommandList); foreignConstraintCommandList);
task = CitusMakeNode(Task); Task *task = CitusMakeNode(Task);
task->jobId = INVALID_JOB_ID; task->jobId = INVALID_JOB_ID;
task->taskId = taskId++; task->taskId = taskId++;
task->taskType = DDL_TASK; task->taskType = DDL_TASK;
@ -580,26 +558,23 @@ CreateShardsOnWorkers(Oid distributedRelationId, List *shardPlacements,
static List * static List *
RelationShardListForShardCreate(ShardInterval *shardInterval) RelationShardListForShardCreate(ShardInterval *shardInterval)
{ {
List *relationShardList = NIL;
RelationShard *relationShard = NULL;
Oid relationId = shardInterval->relationId; Oid relationId = shardInterval->relationId;
DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId);
List *referencedRelationList = cacheEntry->referencedRelationsViaForeignKey; List *referencedRelationList = cacheEntry->referencedRelationsViaForeignKey;
List *referencingRelationList = cacheEntry->referencingRelationsViaForeignKey; List *referencingRelationList = cacheEntry->referencingRelationsViaForeignKey;
List *allForeignKeyRelations = NIL;
int shardIndex = -1; int shardIndex = -1;
ListCell *fkeyRelationIdCell = NULL; ListCell *fkeyRelationIdCell = NULL;
/* list_concat_*() modifies the first arg, so make a copy first */ /* list_concat_*() modifies the first arg, so make a copy first */
allForeignKeyRelations = list_copy(referencedRelationList); List *allForeignKeyRelations = list_copy(referencedRelationList);
allForeignKeyRelations = list_concat_unique_oid(allForeignKeyRelations, allForeignKeyRelations = list_concat_unique_oid(allForeignKeyRelations,
referencingRelationList); referencingRelationList);
/* record the placement access of the shard itself */ /* record the placement access of the shard itself */
relationShard = CitusMakeNode(RelationShard); RelationShard *relationShard = CitusMakeNode(RelationShard);
relationShard->relationId = relationId; relationShard->relationId = relationId;
relationShard->shardId = shardInterval->shardId; relationShard->shardId = shardInterval->shardId;
relationShardList = list_make1(relationShard); List *relationShardList = list_make1(relationShard);
if (cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH && if (cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH &&
cacheEntry->colocationId != INVALID_COLOCATION_ID) cacheEntry->colocationId != INVALID_COLOCATION_ID)
@ -612,7 +587,6 @@ RelationShardListForShardCreate(ShardInterval *shardInterval)
foreach(fkeyRelationIdCell, allForeignKeyRelations) foreach(fkeyRelationIdCell, allForeignKeyRelations)
{ {
Oid fkeyRelationid = lfirst_oid(fkeyRelationIdCell); Oid fkeyRelationid = lfirst_oid(fkeyRelationIdCell);
RelationShard *fkeyRelationShard = NULL;
uint64 fkeyShardId = INVALID_SHARD_ID; uint64 fkeyShardId = INVALID_SHARD_ID;
if (!IsDistributedTable(fkeyRelationid)) if (!IsDistributedTable(fkeyRelationid))
@ -645,7 +619,7 @@ RelationShardListForShardCreate(ShardInterval *shardInterval)
continue; continue;
} }
fkeyRelationShard = CitusMakeNode(RelationShard); RelationShard *fkeyRelationShard = CitusMakeNode(RelationShard);
fkeyRelationShard->relationId = fkeyRelationid; fkeyRelationShard->relationId = fkeyRelationid;
fkeyRelationShard->shardId = fkeyShardId; fkeyRelationShard->shardId = fkeyShardId;
@ -714,16 +688,12 @@ WorkerCreateShardCommandList(Oid relationId, int shardIndex, uint64 shardId,
char *command = (char *) lfirst(foreignConstraintCommandCell); char *command = (char *) lfirst(foreignConstraintCommandCell);
char *escapedCommand = quote_literal_cstr(command); char *escapedCommand = quote_literal_cstr(command);
Oid referencedRelationId = InvalidOid;
Oid referencedSchemaId = InvalidOid;
char *referencedSchemaName = NULL;
char *escapedReferencedSchemaName = NULL;
uint64 referencedShardId = INVALID_SHARD_ID; uint64 referencedShardId = INVALID_SHARD_ID;
StringInfo applyForeignConstraintCommand = makeStringInfo(); StringInfo applyForeignConstraintCommand = makeStringInfo();
/* we need to parse the foreign constraint command to get referencing table id */ /* we need to parse the foreign constraint command to get referencing table id */
referencedRelationId = ForeignConstraintGetReferencedTableId(command); Oid referencedRelationId = ForeignConstraintGetReferencedTableId(command);
if (referencedRelationId == InvalidOid) if (referencedRelationId == InvalidOid)
{ {
ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
@ -731,9 +701,9 @@ WorkerCreateShardCommandList(Oid relationId, int shardIndex, uint64 shardId,
errdetail("Referenced relation cannot be found."))); errdetail("Referenced relation cannot be found.")));
} }
referencedSchemaId = get_rel_namespace(referencedRelationId); Oid referencedSchemaId = get_rel_namespace(referencedRelationId);
referencedSchemaName = get_namespace_name(referencedSchemaId); char *referencedSchemaName = get_namespace_name(referencedSchemaId);
escapedReferencedSchemaName = quote_literal_cstr(referencedSchemaName); char *escapedReferencedSchemaName = quote_literal_cstr(referencedSchemaName);
/* /*
* In case of self referencing shards, relation itself might not be distributed * In case of self referencing shards, relation itself might not be distributed
@ -792,8 +762,6 @@ UpdateShardStatistics(int64 shardId)
Oid relationId = shardInterval->relationId; Oid relationId = shardInterval->relationId;
char storageType = shardInterval->storageType; char storageType = shardInterval->storageType;
char partitionType = PartitionMethod(relationId); char partitionType = PartitionMethod(relationId);
char *shardQualifiedName = NULL;
List *shardPlacementList = NIL;
ListCell *shardPlacementCell = NULL; ListCell *shardPlacementCell = NULL;
bool statsOK = false; bool statsOK = false;
uint64 shardSize = 0; uint64 shardSize = 0;
@ -807,9 +775,9 @@ UpdateShardStatistics(int64 shardId)
AppendShardIdToName(&shardName, shardId); AppendShardIdToName(&shardName, shardId);
shardQualifiedName = quote_qualified_identifier(schemaName, shardName); char *shardQualifiedName = quote_qualified_identifier(schemaName, shardName);
shardPlacementList = FinalizedShardPlacementList(shardId); List *shardPlacementList = FinalizedShardPlacementList(shardId);
/* get shard's statistics from a shard placement */ /* get shard's statistics from a shard placement */
foreach(shardPlacementCell, shardPlacementList) foreach(shardPlacementCell, shardPlacementList)
@ -881,28 +849,19 @@ static bool
WorkerShardStats(ShardPlacement *placement, Oid relationId, char *shardName, WorkerShardStats(ShardPlacement *placement, Oid relationId, char *shardName,
uint64 *shardSize, text **shardMinValue, text **shardMaxValue) uint64 *shardSize, text **shardMinValue, text **shardMaxValue)
{ {
char *quotedShardName = NULL;
bool cstoreTable = false;
StringInfo tableSizeQuery = makeStringInfo(); StringInfo tableSizeQuery = makeStringInfo();
const uint32 unusedTableId = 1; const uint32 unusedTableId = 1;
char partitionType = PartitionMethod(relationId); char partitionType = PartitionMethod(relationId);
Var *partitionColumn = NULL;
char *partitionColumnName = NULL;
StringInfo partitionValueQuery = makeStringInfo(); StringInfo partitionValueQuery = makeStringInfo();
PGresult *queryResult = NULL; PGresult *queryResult = NULL;
const int minValueIndex = 0; const int minValueIndex = 0;
const int maxValueIndex = 1; const int maxValueIndex = 1;
uint64 tableSize = 0;
char *tableSizeString = NULL;
char *tableSizeStringEnd = NULL; char *tableSizeStringEnd = NULL;
bool minValueIsNull = false;
bool maxValueIsNull = false;
int connectionFlags = 0; int connectionFlags = 0;
int executeCommand = 0;
MultiConnection *connection = GetPlacementConnection(connectionFlags, placement, MultiConnection *connection = GetPlacementConnection(connectionFlags, placement,
NULL); NULL);
@ -911,9 +870,9 @@ WorkerShardStats(ShardPlacement *placement, Oid relationId, char *shardName,
*shardMinValue = NULL; *shardMinValue = NULL;
*shardMaxValue = NULL; *shardMaxValue = NULL;
quotedShardName = quote_literal_cstr(shardName); char *quotedShardName = quote_literal_cstr(shardName);
cstoreTable = CStoreTable(relationId); bool cstoreTable = CStoreTable(relationId);
if (cstoreTable) if (cstoreTable)
{ {
appendStringInfo(tableSizeQuery, SHARD_CSTORE_TABLE_SIZE_QUERY, quotedShardName); appendStringInfo(tableSizeQuery, SHARD_CSTORE_TABLE_SIZE_QUERY, quotedShardName);
@ -923,14 +882,14 @@ WorkerShardStats(ShardPlacement *placement, Oid relationId, char *shardName,
appendStringInfo(tableSizeQuery, SHARD_TABLE_SIZE_QUERY, quotedShardName); appendStringInfo(tableSizeQuery, SHARD_TABLE_SIZE_QUERY, quotedShardName);
} }
executeCommand = ExecuteOptionalRemoteCommand(connection, tableSizeQuery->data, int executeCommand = ExecuteOptionalRemoteCommand(connection, tableSizeQuery->data,
&queryResult); &queryResult);
if (executeCommand != 0) if (executeCommand != 0)
{ {
return false; return false;
} }
tableSizeString = PQgetvalue(queryResult, 0, 0); char *tableSizeString = PQgetvalue(queryResult, 0, 0);
if (tableSizeString == NULL) if (tableSizeString == NULL)
{ {
PQclear(queryResult); PQclear(queryResult);
@ -939,7 +898,7 @@ WorkerShardStats(ShardPlacement *placement, Oid relationId, char *shardName,
} }
errno = 0; errno = 0;
tableSize = pg_strtouint64(tableSizeString, &tableSizeStringEnd, 0); uint64 tableSize = pg_strtouint64(tableSizeString, &tableSizeStringEnd, 0);
if (errno != 0 || (*tableSizeStringEnd) != '\0') if (errno != 0 || (*tableSizeStringEnd) != '\0')
{ {
PQclear(queryResult); PQclear(queryResult);
@ -959,8 +918,8 @@ WorkerShardStats(ShardPlacement *placement, Oid relationId, char *shardName,
} }
/* fill in the partition column name and shard name in the query. */ /* fill in the partition column name and shard name in the query. */
partitionColumn = PartitionColumn(relationId, unusedTableId); Var *partitionColumn = PartitionColumn(relationId, unusedTableId);
partitionColumnName = get_attname(relationId, partitionColumn->varattno, false); char *partitionColumnName = get_attname(relationId, partitionColumn->varattno, false);
appendStringInfo(partitionValueQuery, SHARD_RANGE_QUERY, appendStringInfo(partitionValueQuery, SHARD_RANGE_QUERY,
partitionColumnName, partitionColumnName, shardName); partitionColumnName, partitionColumnName, shardName);
@ -971,8 +930,8 @@ WorkerShardStats(ShardPlacement *placement, Oid relationId, char *shardName,
return false; return false;
} }
minValueIsNull = PQgetisnull(queryResult, 0, minValueIndex); bool minValueIsNull = PQgetisnull(queryResult, 0, minValueIndex);
maxValueIsNull = PQgetisnull(queryResult, 0, maxValueIndex); bool maxValueIsNull = PQgetisnull(queryResult, 0, maxValueIndex);
if (!minValueIsNull && !maxValueIsNull) if (!minValueIsNull && !maxValueIsNull)
{ {

View File

@ -41,21 +41,16 @@ PG_FUNCTION_INFO_V1(citus_truncate_trigger);
Datum Datum
citus_truncate_trigger(PG_FUNCTION_ARGS) citus_truncate_trigger(PG_FUNCTION_ARGS)
{ {
TriggerData *triggerData = NULL;
Relation truncatedRelation = NULL;
Oid relationId = InvalidOid;
char partitionMethod = 0;
if (!CALLED_AS_TRIGGER(fcinfo)) if (!CALLED_AS_TRIGGER(fcinfo))
{ {
ereport(ERROR, (errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED), ereport(ERROR, (errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
errmsg("must be called as trigger"))); errmsg("must be called as trigger")));
} }
triggerData = (TriggerData *) fcinfo->context; TriggerData *triggerData = (TriggerData *) fcinfo->context;
truncatedRelation = triggerData->tg_relation; Relation truncatedRelation = triggerData->tg_relation;
relationId = RelationGetRelid(truncatedRelation); Oid relationId = RelationGetRelid(truncatedRelation);
partitionMethod = PartitionMethod(relationId); char partitionMethod = PartitionMethod(relationId);
if (!EnableDDLPropagation) if (!EnableDDLPropagation)
{ {
@ -110,7 +105,6 @@ TruncateTaskList(Oid relationId)
ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell); ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell);
uint64 shardId = shardInterval->shardId; uint64 shardId = shardInterval->shardId;
StringInfo shardQueryString = makeStringInfo(); StringInfo shardQueryString = makeStringInfo();
Task *task = NULL;
char *shardName = pstrdup(relationName); char *shardName = pstrdup(relationName);
AppendShardIdToName(&shardName, shardId); AppendShardIdToName(&shardName, shardId);
@ -118,7 +112,7 @@ TruncateTaskList(Oid relationId)
appendStringInfo(shardQueryString, "TRUNCATE TABLE %s CASCADE", appendStringInfo(shardQueryString, "TRUNCATE TABLE %s CASCADE",
quote_qualified_identifier(schemaName, shardName)); quote_qualified_identifier(schemaName, shardName));
task = CitusMakeNode(Task); Task *task = CitusMakeNode(Task);
task->jobId = INVALID_JOB_ID; task->jobId = INVALID_JOB_ID;
task->taskId = taskId++; task->taskId = taskId++;
task->taskType = DDL_TASK; task->taskType = DDL_TASK;

View File

@ -67,7 +67,6 @@ WorkerGetRandomCandidateNode(List *currentNodeList)
WorkerNode *workerNode = NULL; WorkerNode *workerNode = NULL;
bool wantSameRack = false; bool wantSameRack = false;
uint32 tryCount = WORKER_RACK_TRIES; uint32 tryCount = WORKER_RACK_TRIES;
uint32 tryIndex = 0;
uint32 currentNodeCount = list_length(currentNodeList); uint32 currentNodeCount = list_length(currentNodeList);
List *candidateWorkerNodeList = PrimaryNodesNotInList(currentNodeList); List *candidateWorkerNodeList = PrimaryNodesNotInList(currentNodeList);
@ -104,17 +103,15 @@ WorkerGetRandomCandidateNode(List *currentNodeList)
* If after a predefined number of tries, we still cannot find such a node, * If after a predefined number of tries, we still cannot find such a node,
* we simply give up and return the last worker node we found. * we simply give up and return the last worker node we found.
*/ */
for (tryIndex = 0; tryIndex < tryCount; tryIndex++) for (uint32 tryIndex = 0; tryIndex < tryCount; tryIndex++)
{ {
WorkerNode *firstNode = (WorkerNode *) linitial(currentNodeList); WorkerNode *firstNode = (WorkerNode *) linitial(currentNodeList);
char *firstRack = firstNode->workerRack; char *firstRack = firstNode->workerRack;
char *workerRack = NULL;
bool sameRack = false;
workerNode = FindRandomNodeFromList(candidateWorkerNodeList); workerNode = FindRandomNodeFromList(candidateWorkerNodeList);
workerRack = workerNode->workerRack; char *workerRack = workerNode->workerRack;
sameRack = (strncmp(workerRack, firstRack, WORKER_LENGTH) == 0); bool sameRack = (strncmp(workerRack, firstRack, WORKER_LENGTH) == 0);
if ((sameRack && wantSameRack) || (!sameRack && !wantSameRack)) if ((sameRack && wantSameRack) || (!sameRack && !wantSameRack))
{ {
break; break;
@ -171,7 +168,6 @@ WorkerGetLocalFirstCandidateNode(List *currentNodeList)
if (currentNodeCount == 0) if (currentNodeCount == 0)
{ {
StringInfo clientHostStringInfo = makeStringInfo(); StringInfo clientHostStringInfo = makeStringInfo();
char *clientHost = NULL;
char *errorMessage = ClientHostAddress(clientHostStringInfo); char *errorMessage = ClientHostAddress(clientHostStringInfo);
if (errorMessage != NULL) if (errorMessage != NULL)
@ -184,7 +180,7 @@ WorkerGetLocalFirstCandidateNode(List *currentNodeList)
} }
/* if hostname is localhost.localdomain, change it to localhost */ /* if hostname is localhost.localdomain, change it to localhost */
clientHost = clientHostStringInfo->data; char *clientHost = clientHostStringInfo->data;
if (strncmp(clientHost, "localhost.localdomain", WORKER_LENGTH) == 0) if (strncmp(clientHost, "localhost.localdomain", WORKER_LENGTH) == 0)
{ {
clientHost = pstrdup("localhost"); clientHost = pstrdup("localhost");
@ -343,7 +339,6 @@ FilterActiveNodeListFunc(LOCKMODE lockMode, bool (*checkFunction)(WorkerNode *))
{ {
List *workerNodeList = NIL; List *workerNodeList = NIL;
WorkerNode *workerNode = NULL; WorkerNode *workerNode = NULL;
HTAB *workerNodeHash = NULL;
HASH_SEQ_STATUS status; HASH_SEQ_STATUS status;
Assert(checkFunction != NULL); Assert(checkFunction != NULL);
@ -353,7 +348,7 @@ FilterActiveNodeListFunc(LOCKMODE lockMode, bool (*checkFunction)(WorkerNode *))
LockRelationOid(DistNodeRelationId(), lockMode); LockRelationOid(DistNodeRelationId(), lockMode);
} }
workerNodeHash = GetWorkerNodeHash(); HTAB *workerNodeHash = GetWorkerNodeHash();
hash_seq_init(&status, workerNodeHash); hash_seq_init(&status, workerNodeHash);
while ((workerNode = hash_seq_search(&status)) != NULL) while ((workerNode = hash_seq_search(&status)) != NULL)
@ -568,10 +563,9 @@ CompareWorkerNodes(const void *leftElement, const void *rightElement)
{ {
const void *leftWorker = *((const void **) leftElement); const void *leftWorker = *((const void **) leftElement);
const void *rightWorker = *((const void **) rightElement); const void *rightWorker = *((const void **) rightElement);
int compare = 0;
Size ignoredKeySize = 0; Size ignoredKeySize = 0;
compare = WorkerNodeCompare(leftWorker, rightWorker, ignoredKeySize); int compare = WorkerNodeCompare(leftWorker, rightWorker, ignoredKeySize);
return compare; return compare;
} }
@ -588,16 +582,15 @@ WorkerNodeCompare(const void *lhsKey, const void *rhsKey, Size keySize)
const WorkerNode *workerLhs = (const WorkerNode *) lhsKey; const WorkerNode *workerLhs = (const WorkerNode *) lhsKey;
const WorkerNode *workerRhs = (const WorkerNode *) rhsKey; const WorkerNode *workerRhs = (const WorkerNode *) rhsKey;
int nameCompare = 0;
int portCompare = 0;
nameCompare = strncmp(workerLhs->workerName, workerRhs->workerName, WORKER_LENGTH); int nameCompare = strncmp(workerLhs->workerName, workerRhs->workerName,
WORKER_LENGTH);
if (nameCompare != 0) if (nameCompare != 0)
{ {
return nameCompare; return nameCompare;
} }
portCompare = workerLhs->workerPort - workerRhs->workerPort; int portCompare = workerLhs->workerPort - workerRhs->workerPort;
return portCompare; return portCompare;
} }

View File

@ -170,9 +170,7 @@ recurse_pg_depend(const ObjectAddress *target,
void (*apply)(ObjectAddressCollector *collector, Form_pg_depend row), void (*apply)(ObjectAddressCollector *collector, Form_pg_depend row),
ObjectAddressCollector *collector) ObjectAddressCollector *collector)
{ {
Relation depRel = NULL;
ScanKeyData key[2]; ScanKeyData key[2];
SysScanDesc depScan = NULL;
HeapTuple depTup = NULL; HeapTuple depTup = NULL;
List *pgDependEntries = NIL; List *pgDependEntries = NIL;
ListCell *pgDependCell = NULL; ListCell *pgDependCell = NULL;
@ -188,14 +186,15 @@ recurse_pg_depend(const ObjectAddress *target,
/* /*
* iterate the actual pg_depend catalog * iterate the actual pg_depend catalog
*/ */
depRel = heap_open(DependRelationId, AccessShareLock); Relation depRel = heap_open(DependRelationId, AccessShareLock);
/* scan pg_depend for classid = $1 AND objid = $2 using pg_depend_depender_index */ /* scan pg_depend for classid = $1 AND objid = $2 using pg_depend_depender_index */
ScanKeyInit(&key[0], Anum_pg_depend_classid, BTEqualStrategyNumber, F_OIDEQ, ScanKeyInit(&key[0], Anum_pg_depend_classid, BTEqualStrategyNumber, F_OIDEQ,
ObjectIdGetDatum(target->classId)); ObjectIdGetDatum(target->classId));
ScanKeyInit(&key[1], Anum_pg_depend_objid, BTEqualStrategyNumber, F_OIDEQ, ScanKeyInit(&key[1], Anum_pg_depend_objid, BTEqualStrategyNumber, F_OIDEQ,
ObjectIdGetDatum(target->objectId)); ObjectIdGetDatum(target->objectId));
depScan = systable_beginscan(depRel, DependDependerIndexId, true, NULL, 2, key); SysScanDesc depScan = systable_beginscan(depRel, DependDependerIndexId, true, NULL, 2,
key);
while (HeapTupleIsValid(depTup = systable_getnext(depScan))) while (HeapTupleIsValid(depTup = systable_getnext(depScan)))
{ {
@ -215,9 +214,7 @@ recurse_pg_depend(const ObjectAddress *target,
*/ */
if (expand != NULL) if (expand != NULL)
{ {
List *expandedEntries = NIL; List *expandedEntries = expand(collector, target);
expandedEntries = expand(collector, target);
pgDependEntries = list_concat(pgDependEntries, expandedEntries); pgDependEntries = list_concat(pgDependEntries, expandedEntries);
} }
@ -262,14 +259,13 @@ recurse_pg_depend(const ObjectAddress *target,
static void static void
InitObjectAddressCollector(ObjectAddressCollector *collector) InitObjectAddressCollector(ObjectAddressCollector *collector)
{ {
int hashFlags = 0;
HASHCTL info; HASHCTL info;
memset(&info, 0, sizeof(info)); memset(&info, 0, sizeof(info));
info.keysize = sizeof(ObjectAddress); info.keysize = sizeof(ObjectAddress);
info.entrysize = sizeof(ObjectAddress); info.entrysize = sizeof(ObjectAddress);
info.hcxt = CurrentMemoryContext; info.hcxt = CurrentMemoryContext;
hashFlags = (HASH_ELEM | HASH_CONTEXT | HASH_BLOBS); int hashFlags = (HASH_ELEM | HASH_CONTEXT | HASH_BLOBS);
collector->dependencySet = hash_create("dependency set", 128, &info, hashFlags); collector->dependencySet = hash_create("dependency set", 128, &info, hashFlags);
collector->dependencyList = NULL; collector->dependencyList = NULL;
@ -301,11 +297,11 @@ TargetObjectVisited(ObjectAddressCollector *collector, const ObjectAddress *targ
static void static void
MarkObjectVisited(ObjectAddressCollector *collector, const ObjectAddress *target) MarkObjectVisited(ObjectAddressCollector *collector, const ObjectAddress *target)
{ {
ObjectAddress *address = NULL;
bool found = false; bool found = false;
/* add to set */ /* add to set */
address = (ObjectAddress *) hash_search(collector->visitedObjects, target, ObjectAddress *address = (ObjectAddress *) hash_search(collector->visitedObjects,
target,
HASH_ENTER, &found); HASH_ENTER, &found);
if (!found) if (!found)
@ -322,11 +318,11 @@ MarkObjectVisited(ObjectAddressCollector *collector, const ObjectAddress *target
static void static void
CollectObjectAddress(ObjectAddressCollector *collector, const ObjectAddress *collect) CollectObjectAddress(ObjectAddressCollector *collector, const ObjectAddress *collect)
{ {
ObjectAddress *address = NULL;
bool found = false; bool found = false;
/* add to set */ /* add to set */
address = (ObjectAddress *) hash_search(collector->dependencySet, collect, ObjectAddress *address = (ObjectAddress *) hash_search(collector->dependencySet,
collect,
HASH_ENTER, &found); HASH_ENTER, &found);
if (!found) if (!found)
@ -475,20 +471,19 @@ bool
IsObjectAddressOwnedByExtension(const ObjectAddress *target, IsObjectAddressOwnedByExtension(const ObjectAddress *target,
ObjectAddress *extensionAddress) ObjectAddress *extensionAddress)
{ {
Relation depRel = NULL;
ScanKeyData key[2]; ScanKeyData key[2];
SysScanDesc depScan = NULL;
HeapTuple depTup = NULL; HeapTuple depTup = NULL;
bool result = false; bool result = false;
depRel = heap_open(DependRelationId, AccessShareLock); Relation depRel = heap_open(DependRelationId, AccessShareLock);
/* scan pg_depend for classid = $1 AND objid = $2 using pg_depend_depender_index */ /* scan pg_depend for classid = $1 AND objid = $2 using pg_depend_depender_index */
ScanKeyInit(&key[0], Anum_pg_depend_classid, BTEqualStrategyNumber, F_OIDEQ, ScanKeyInit(&key[0], Anum_pg_depend_classid, BTEqualStrategyNumber, F_OIDEQ,
ObjectIdGetDatum(target->classId)); ObjectIdGetDatum(target->classId));
ScanKeyInit(&key[1], Anum_pg_depend_objid, BTEqualStrategyNumber, F_OIDEQ, ScanKeyInit(&key[1], Anum_pg_depend_objid, BTEqualStrategyNumber, F_OIDEQ,
ObjectIdGetDatum(target->objectId)); ObjectIdGetDatum(target->objectId));
depScan = systable_beginscan(depRel, DependDependerIndexId, true, NULL, 2, key); SysScanDesc depScan = systable_beginscan(depRel, DependDependerIndexId, true, NULL, 2,
key);
while (HeapTupleIsValid(depTup = systable_getnext(depScan))) while (HeapTupleIsValid(depTup = systable_getnext(depScan)))
{ {

View File

@ -139,12 +139,11 @@ MarkObjectDistributed(const ObjectAddress *distAddress)
ObjectIdGetDatum(distAddress->objectId), ObjectIdGetDatum(distAddress->objectId),
Int32GetDatum(distAddress->objectSubId) Int32GetDatum(distAddress->objectSubId)
}; };
int spiStatus = 0;
char *insertQuery = "INSERT INTO citus.pg_dist_object (classid, objid, objsubid) " char *insertQuery = "INSERT INTO citus.pg_dist_object (classid, objid, objsubid) "
"VALUES ($1, $2, $3) ON CONFLICT DO NOTHING"; "VALUES ($1, $2, $3) ON CONFLICT DO NOTHING";
spiStatus = ExecuteCommandAsSuperuser(insertQuery, paramCount, paramTypes, int spiStatus = ExecuteCommandAsSuperuser(insertQuery, paramCount, paramTypes,
paramValues); paramValues);
if (spiStatus < 0) if (spiStatus < 0)
{ {
@ -160,14 +159,12 @@ MarkObjectDistributed(const ObjectAddress *distAddress)
bool bool
CitusExtensionObject(const ObjectAddress *objectAddress) CitusExtensionObject(const ObjectAddress *objectAddress)
{ {
char *extensionName = false;
if (objectAddress->classId != ExtensionRelationId) if (objectAddress->classId != ExtensionRelationId)
{ {
return false; return false;
} }
extensionName = get_extension_name(objectAddress->objectId); char *extensionName = get_extension_name(objectAddress->objectId);
if (extensionName != NULL && if (extensionName != NULL &&
strncasecmp(extensionName, "citus", NAMEDATALEN) == 0) strncasecmp(extensionName, "citus", NAMEDATALEN) == 0)
{ {
@ -188,13 +185,10 @@ static int
ExecuteCommandAsSuperuser(char *query, int paramCount, Oid *paramTypes, ExecuteCommandAsSuperuser(char *query, int paramCount, Oid *paramTypes,
Datum *paramValues) Datum *paramValues)
{ {
int spiConnected = 0;
Oid savedUserId = InvalidOid; Oid savedUserId = InvalidOid;
int savedSecurityContext = 0; int savedSecurityContext = 0;
int spiStatus = 0;
int spiFinished = 0;
spiConnected = SPI_connect(); int spiConnected = SPI_connect();
if (spiConnected != SPI_OK_CONNECT) if (spiConnected != SPI_OK_CONNECT)
{ {
ereport(ERROR, (errmsg("could not connect to SPI manager"))); ereport(ERROR, (errmsg("could not connect to SPI manager")));
@ -204,12 +198,12 @@ ExecuteCommandAsSuperuser(char *query, int paramCount, Oid *paramTypes,
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
spiStatus = SPI_execute_with_args(query, paramCount, paramTypes, paramValues, int spiStatus = SPI_execute_with_args(query, paramCount, paramTypes, paramValues,
NULL, false, 0); NULL, false, 0);
SetUserIdAndSecContext(savedUserId, savedSecurityContext); SetUserIdAndSecContext(savedUserId, savedSecurityContext);
spiFinished = SPI_finish(); int spiFinished = SPI_finish();
if (spiFinished != SPI_OK_FINISH) if (spiFinished != SPI_OK_FINISH)
{ {
ereport(ERROR, (errmsg("could not disconnect from SPI manager"))); ereport(ERROR, (errmsg("could not disconnect from SPI manager")));
@ -237,12 +231,11 @@ UnmarkObjectDistributed(const ObjectAddress *address)
ObjectIdGetDatum(address->objectId), ObjectIdGetDatum(address->objectId),
Int32GetDatum(address->objectSubId) Int32GetDatum(address->objectSubId)
}; };
int spiStatus = 0;
char *deleteQuery = "DELETE FROM citus.pg_dist_object WHERE classid = $1 AND " char *deleteQuery = "DELETE FROM citus.pg_dist_object WHERE classid = $1 AND "
"objid = $2 AND objsubid = $3"; "objid = $2 AND objsubid = $3";
spiStatus = ExecuteCommandAsSuperuser(deleteQuery, paramCount, paramTypes, int spiStatus = ExecuteCommandAsSuperuser(deleteQuery, paramCount, paramTypes,
paramValues); paramValues);
if (spiStatus < 0) if (spiStatus < 0)
{ {
@ -258,13 +251,10 @@ UnmarkObjectDistributed(const ObjectAddress *address)
bool bool
IsObjectDistributed(const ObjectAddress *address) IsObjectDistributed(const ObjectAddress *address)
{ {
Relation pgDistObjectRel = NULL;
ScanKeyData key[3]; ScanKeyData key[3];
SysScanDesc pgDistObjectScan = NULL;
HeapTuple pgDistObjectTup = NULL;
bool result = false; bool result = false;
pgDistObjectRel = heap_open(DistObjectRelationId(), AccessShareLock); Relation pgDistObjectRel = heap_open(DistObjectRelationId(), AccessShareLock);
/* scan pg_dist_object for classid = $1 AND objid = $2 AND objsubid = $3 via index */ /* scan pg_dist_object for classid = $1 AND objid = $2 AND objsubid = $3 via index */
ScanKeyInit(&key[0], Anum_pg_dist_object_classid, BTEqualStrategyNumber, F_OIDEQ, ScanKeyInit(&key[0], Anum_pg_dist_object_classid, BTEqualStrategyNumber, F_OIDEQ,
@ -273,10 +263,11 @@ IsObjectDistributed(const ObjectAddress *address)
ObjectIdGetDatum(address->objectId)); ObjectIdGetDatum(address->objectId));
ScanKeyInit(&key[2], Anum_pg_dist_object_objsubid, BTEqualStrategyNumber, F_INT4EQ, ScanKeyInit(&key[2], Anum_pg_dist_object_objsubid, BTEqualStrategyNumber, F_INT4EQ,
Int32GetDatum(address->objectSubId)); Int32GetDatum(address->objectSubId));
pgDistObjectScan = systable_beginscan(pgDistObjectRel, DistObjectPrimaryKeyIndexId(), SysScanDesc pgDistObjectScan = systable_beginscan(pgDistObjectRel,
DistObjectPrimaryKeyIndexId(),
true, NULL, 3, key); true, NULL, 3, key);
pgDistObjectTup = systable_getnext(pgDistObjectScan); HeapTuple pgDistObjectTup = systable_getnext(pgDistObjectScan);
if (HeapTupleIsValid(pgDistObjectTup)) if (HeapTupleIsValid(pgDistObjectTup))
{ {
result = true; result = true;
@ -299,14 +290,13 @@ ClusterHasDistributedFunctionWithDistArgument(void)
{ {
bool foundDistributedFunction = false; bool foundDistributedFunction = false;
SysScanDesc pgDistObjectScan = NULL;
HeapTuple pgDistObjectTup = NULL; HeapTuple pgDistObjectTup = NULL;
Relation pgDistObjectRel = heap_open(DistObjectRelationId(), AccessShareLock); Relation pgDistObjectRel = heap_open(DistObjectRelationId(), AccessShareLock);
TupleDesc tupleDescriptor = RelationGetDescr(pgDistObjectRel); TupleDesc tupleDescriptor = RelationGetDescr(pgDistObjectRel);
pgDistObjectScan = SysScanDesc pgDistObjectScan =
systable_beginscan(pgDistObjectRel, InvalidOid, false, NULL, 0, NULL); systable_beginscan(pgDistObjectRel, InvalidOid, false, NULL, 0, NULL);
while (HeapTupleIsValid(pgDistObjectTup = systable_getnext(pgDistObjectScan))) while (HeapTupleIsValid(pgDistObjectTup = systable_getnext(pgDistObjectScan)))
{ {
@ -315,8 +305,7 @@ ClusterHasDistributedFunctionWithDistArgument(void)
if (pg_dist_object->classid == ProcedureRelationId) if (pg_dist_object->classid == ProcedureRelationId)
{ {
bool distArgumentIsNull = false; bool distArgumentIsNull =
distArgumentIsNull =
heap_attisnull(pgDistObjectTup, heap_attisnull(pgDistObjectTup,
Anum_pg_dist_object_distribution_argument_index, Anum_pg_dist_object_distribution_argument_index,
tupleDescriptor); tupleDescriptor);
@ -345,13 +334,12 @@ ClusterHasDistributedFunctionWithDistArgument(void)
List * List *
GetDistributedObjectAddressList(void) GetDistributedObjectAddressList(void)
{ {
Relation pgDistObjectRel = NULL;
SysScanDesc pgDistObjectScan = NULL;
HeapTuple pgDistObjectTup = NULL; HeapTuple pgDistObjectTup = NULL;
List *objectAddressList = NIL; List *objectAddressList = NIL;
pgDistObjectRel = heap_open(DistObjectRelationId(), AccessShareLock); Relation pgDistObjectRel = heap_open(DistObjectRelationId(), AccessShareLock);
pgDistObjectScan = systable_beginscan(pgDistObjectRel, InvalidOid, false, NULL, 0, SysScanDesc pgDistObjectScan = systable_beginscan(pgDistObjectRel, InvalidOid, false,
NULL, 0,
NULL); NULL);
while (HeapTupleIsValid(pgDistObjectTup = systable_getnext(pgDistObjectScan))) while (HeapTupleIsValid(pgDistObjectTup = systable_getnext(pgDistObjectScan)))
{ {

File diff suppressed because it is too large Load Diff

View File

@ -91,7 +91,6 @@ start_metadata_sync_to_node(PG_FUNCTION_ARGS)
void void
StartMetadatSyncToNode(char *nodeNameString, int32 nodePort) StartMetadatSyncToNode(char *nodeNameString, int32 nodePort)
{ {
WorkerNode *workerNode = NULL;
char *escapedNodeName = quote_literal_cstr(nodeNameString); char *escapedNodeName = quote_literal_cstr(nodeNameString);
/* fail if metadata synchronization doesn't succeed */ /* fail if metadata synchronization doesn't succeed */
@ -106,7 +105,7 @@ StartMetadatSyncToNode(char *nodeNameString, int32 nodePort)
LockRelationOid(DistNodeRelationId(), ExclusiveLock); LockRelationOid(DistNodeRelationId(), ExclusiveLock);
workerNode = FindWorkerNode(nodeNameString, nodePort); WorkerNode *workerNode = FindWorkerNode(nodeNameString, nodePort);
if (workerNode == NULL) if (workerNode == NULL)
{ {
ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
@ -159,7 +158,6 @@ stop_metadata_sync_to_node(PG_FUNCTION_ARGS)
text *nodeName = PG_GETARG_TEXT_P(0); text *nodeName = PG_GETARG_TEXT_P(0);
int32 nodePort = PG_GETARG_INT32(1); int32 nodePort = PG_GETARG_INT32(1);
char *nodeNameString = text_to_cstring(nodeName); char *nodeNameString = text_to_cstring(nodeName);
WorkerNode *workerNode = NULL;
EnsureCoordinator(); EnsureCoordinator();
EnsureSuperUser(); EnsureSuperUser();
@ -167,7 +165,7 @@ stop_metadata_sync_to_node(PG_FUNCTION_ARGS)
LockRelationOid(DistNodeRelationId(), ExclusiveLock); LockRelationOid(DistNodeRelationId(), ExclusiveLock);
workerNode = FindWorkerNode(nodeNameString, nodePort); WorkerNode *workerNode = FindWorkerNode(nodeNameString, nodePort);
if (workerNode == NULL) if (workerNode == NULL)
{ {
ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
@ -297,12 +295,12 @@ bool
SendOptionalCommandListToWorkerInTransaction(char *nodeName, int32 nodePort, SendOptionalCommandListToWorkerInTransaction(char *nodeName, int32 nodePort,
char *nodeUser, List *commandList) char *nodeUser, List *commandList)
{ {
MultiConnection *workerConnection = NULL;
ListCell *commandCell = NULL; ListCell *commandCell = NULL;
int connectionFlags = FORCE_NEW_CONNECTION; int connectionFlags = FORCE_NEW_CONNECTION;
bool failed = false; bool failed = false;
workerConnection = GetNodeUserDatabaseConnection(connectionFlags, nodeName, nodePort, MultiConnection *workerConnection = GetNodeUserDatabaseConnection(connectionFlags,
nodeName, nodePort,
nodeUser, NULL); nodeUser, NULL);
RemoteTransactionBegin(workerConnection); RemoteTransactionBegin(workerConnection);
@ -356,14 +354,13 @@ MetadataCreateCommands(void)
bool includeNodesFromOtherClusters = true; bool includeNodesFromOtherClusters = true;
List *workerNodeList = ReadDistNode(includeNodesFromOtherClusters); List *workerNodeList = ReadDistNode(includeNodesFromOtherClusters);
ListCell *distributedTableCell = NULL; ListCell *distributedTableCell = NULL;
char *nodeListInsertCommand = NULL;
bool includeSequenceDefaults = true; bool includeSequenceDefaults = true;
/* make sure we have deterministic output for our tests */ /* make sure we have deterministic output for our tests */
workerNodeList = SortList(workerNodeList, CompareWorkerNodes); workerNodeList = SortList(workerNodeList, CompareWorkerNodes);
/* generate insert command for pg_dist_node table */ /* generate insert command for pg_dist_node table */
nodeListInsertCommand = NodeListInsertCommand(workerNodeList); char *nodeListInsertCommand = NodeListInsertCommand(workerNodeList);
metadataSnapshotCommandList = lappend(metadataSnapshotCommandList, metadataSnapshotCommandList = lappend(metadataSnapshotCommandList,
nodeListInsertCommand); nodeListInsertCommand);
@ -441,26 +438,22 @@ MetadataCreateCommands(void)
{ {
DistTableCacheEntry *cacheEntry = DistTableCacheEntry *cacheEntry =
(DistTableCacheEntry *) lfirst(distributedTableCell); (DistTableCacheEntry *) lfirst(distributedTableCell);
List *shardIntervalList = NIL;
List *shardCreateCommandList = NIL;
char *metadataCommand = NULL;
char *truncateTriggerCreateCommand = NULL;
Oid clusteredTableId = cacheEntry->relationId; Oid clusteredTableId = cacheEntry->relationId;
/* add the table metadata command first*/ /* add the table metadata command first*/
metadataCommand = DistributionCreateCommand(cacheEntry); char *metadataCommand = DistributionCreateCommand(cacheEntry);
metadataSnapshotCommandList = lappend(metadataSnapshotCommandList, metadataSnapshotCommandList = lappend(metadataSnapshotCommandList,
metadataCommand); metadataCommand);
/* add the truncate trigger command after the table became distributed */ /* add the truncate trigger command after the table became distributed */
truncateTriggerCreateCommand = char *truncateTriggerCreateCommand =
TruncateTriggerCreateCommand(cacheEntry->relationId); TruncateTriggerCreateCommand(cacheEntry->relationId);
metadataSnapshotCommandList = lappend(metadataSnapshotCommandList, metadataSnapshotCommandList = lappend(metadataSnapshotCommandList,
truncateTriggerCreateCommand); truncateTriggerCreateCommand);
/* add the pg_dist_shard{,placement} entries */ /* add the pg_dist_shard{,placement} entries */
shardIntervalList = LoadShardIntervalList(clusteredTableId); List *shardIntervalList = LoadShardIntervalList(clusteredTableId);
shardCreateCommandList = ShardListInsertCommand(shardIntervalList); List *shardCreateCommandList = ShardListInsertCommand(shardIntervalList);
metadataSnapshotCommandList = list_concat(metadataSnapshotCommandList, metadataSnapshotCommandList = list_concat(metadataSnapshotCommandList,
shardCreateCommandList); shardCreateCommandList);
@ -481,44 +474,36 @@ GetDistributedTableDDLEvents(Oid relationId)
{ {
DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId);
List *shardIntervalList = NIL;
List *commandList = NIL; List *commandList = NIL;
List *foreignConstraintCommands = NIL;
List *shardMetadataInsertCommandList = NIL;
List *sequenceDDLCommands = NIL;
List *tableDDLCommands = NIL;
char *tableOwnerResetCommand = NULL;
char *metadataCommand = NULL;
char *truncateTriggerCreateCommand = NULL;
bool includeSequenceDefaults = true; bool includeSequenceDefaults = true;
/* commands to create sequences */ /* commands to create sequences */
sequenceDDLCommands = SequenceDDLCommandsForTable(relationId); List *sequenceDDLCommands = SequenceDDLCommandsForTable(relationId);
commandList = list_concat(commandList, sequenceDDLCommands); commandList = list_concat(commandList, sequenceDDLCommands);
/* commands to create the table */ /* commands to create the table */
tableDDLCommands = GetTableDDLEvents(relationId, includeSequenceDefaults); List *tableDDLCommands = GetTableDDLEvents(relationId, includeSequenceDefaults);
commandList = list_concat(commandList, tableDDLCommands); commandList = list_concat(commandList, tableDDLCommands);
/* command to reset the table owner */ /* command to reset the table owner */
tableOwnerResetCommand = TableOwnerResetCommand(relationId); char *tableOwnerResetCommand = TableOwnerResetCommand(relationId);
commandList = lappend(commandList, tableOwnerResetCommand); commandList = lappend(commandList, tableOwnerResetCommand);
/* command to insert pg_dist_partition entry */ /* command to insert pg_dist_partition entry */
metadataCommand = DistributionCreateCommand(cacheEntry); char *metadataCommand = DistributionCreateCommand(cacheEntry);
commandList = lappend(commandList, metadataCommand); commandList = lappend(commandList, metadataCommand);
/* commands to create the truncate trigger of the table */ /* commands to create the truncate trigger of the table */
truncateTriggerCreateCommand = TruncateTriggerCreateCommand(relationId); char *truncateTriggerCreateCommand = TruncateTriggerCreateCommand(relationId);
commandList = lappend(commandList, truncateTriggerCreateCommand); commandList = lappend(commandList, truncateTriggerCreateCommand);
/* commands to insert pg_dist_shard & pg_dist_placement entries */ /* commands to insert pg_dist_shard & pg_dist_placement entries */
shardIntervalList = LoadShardIntervalList(relationId); List *shardIntervalList = LoadShardIntervalList(relationId);
shardMetadataInsertCommandList = ShardListInsertCommand(shardIntervalList); List *shardMetadataInsertCommandList = ShardListInsertCommand(shardIntervalList);
commandList = list_concat(commandList, shardMetadataInsertCommandList); commandList = list_concat(commandList, shardMetadataInsertCommandList);
/* commands to create foreign key constraints */ /* commands to create foreign key constraints */
foreignConstraintCommands = GetTableForeignConstraintCommands(relationId); List *foreignConstraintCommands = GetTableForeignConstraintCommands(relationId);
commandList = list_concat(commandList, foreignConstraintCommands); commandList = list_concat(commandList, foreignConstraintCommands);
/* commands to create partitioning hierarchy */ /* commands to create partitioning hierarchy */
@ -686,10 +671,9 @@ DistributionCreateCommand(DistTableCacheEntry *cacheEntry)
char * char *
DistributionDeleteCommand(char *schemaName, char *tableName) DistributionDeleteCommand(char *schemaName, char *tableName)
{ {
char *distributedRelationName = NULL;
StringInfo deleteDistributionCommand = makeStringInfo(); StringInfo deleteDistributionCommand = makeStringInfo();
distributedRelationName = quote_qualified_identifier(schemaName, tableName); char *distributedRelationName = quote_qualified_identifier(schemaName, tableName);
appendStringInfo(deleteDistributionCommand, appendStringInfo(deleteDistributionCommand,
"SELECT worker_drop_distributed_table(%s)", "SELECT worker_drop_distributed_table(%s)",
@ -850,11 +834,9 @@ ShardDeleteCommandList(ShardInterval *shardInterval)
{ {
uint64 shardId = shardInterval->shardId; uint64 shardId = shardInterval->shardId;
List *commandList = NIL; List *commandList = NIL;
StringInfo deletePlacementCommand = NULL;
StringInfo deleteShardCommand = NULL;
/* create command to delete shard placements */ /* create command to delete shard placements */
deletePlacementCommand = makeStringInfo(); StringInfo deletePlacementCommand = makeStringInfo();
appendStringInfo(deletePlacementCommand, appendStringInfo(deletePlacementCommand,
"DELETE FROM pg_dist_placement WHERE shardid = " UINT64_FORMAT, "DELETE FROM pg_dist_placement WHERE shardid = " UINT64_FORMAT,
shardId); shardId);
@ -862,7 +844,7 @@ ShardDeleteCommandList(ShardInterval *shardInterval)
commandList = lappend(commandList, deletePlacementCommand->data); commandList = lappend(commandList, deletePlacementCommand->data);
/* create command to delete shard */ /* create command to delete shard */
deleteShardCommand = makeStringInfo(); StringInfo deleteShardCommand = makeStringInfo();
appendStringInfo(deleteShardCommand, appendStringInfo(deleteShardCommand,
"DELETE FROM pg_dist_shard WHERE shardid = " UINT64_FORMAT, shardId); "DELETE FROM pg_dist_shard WHERE shardid = " UINT64_FORMAT, shardId);
@ -1013,27 +995,23 @@ UpdateDistNodeBoolAttr(char *nodeName, int32 nodePort, int attrNum, bool value)
{ {
const bool indexOK = false; const bool indexOK = false;
Relation pgDistNode = NULL;
TupleDesc tupleDescriptor = NULL;
ScanKeyData scanKey[2]; ScanKeyData scanKey[2];
SysScanDesc scanDescriptor = NULL;
HeapTuple heapTuple = NULL;
Datum values[Natts_pg_dist_node]; Datum values[Natts_pg_dist_node];
bool isnull[Natts_pg_dist_node]; bool isnull[Natts_pg_dist_node];
bool replace[Natts_pg_dist_node]; bool replace[Natts_pg_dist_node];
pgDistNode = heap_open(DistNodeRelationId(), RowExclusiveLock); Relation pgDistNode = heap_open(DistNodeRelationId(), RowExclusiveLock);
tupleDescriptor = RelationGetDescr(pgDistNode); TupleDesc tupleDescriptor = RelationGetDescr(pgDistNode);
ScanKeyInit(&scanKey[0], Anum_pg_dist_node_nodename, ScanKeyInit(&scanKey[0], Anum_pg_dist_node_nodename,
BTEqualStrategyNumber, F_TEXTEQ, CStringGetTextDatum(nodeName)); BTEqualStrategyNumber, F_TEXTEQ, CStringGetTextDatum(nodeName));
ScanKeyInit(&scanKey[1], Anum_pg_dist_node_nodeport, ScanKeyInit(&scanKey[1], Anum_pg_dist_node_nodeport,
BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(nodePort)); BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(nodePort));
scanDescriptor = systable_beginscan(pgDistNode, InvalidOid, indexOK, SysScanDesc scanDescriptor = systable_beginscan(pgDistNode, InvalidOid, indexOK,
NULL, 2, scanKey); NULL, 2, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
if (!HeapTupleIsValid(heapTuple)) if (!HeapTupleIsValid(heapTuple))
{ {
ereport(ERROR, (errmsg("could not find valid entry for node \"%s:%d\"", ereport(ERROR, (errmsg("could not find valid entry for node \"%s:%d\"",
@ -1113,18 +1091,15 @@ char *
CreateSchemaDDLCommand(Oid schemaId) CreateSchemaDDLCommand(Oid schemaId)
{ {
char *schemaName = get_namespace_name(schemaId); char *schemaName = get_namespace_name(schemaId);
StringInfo schemaNameDef = NULL;
const char *ownerName = NULL;
const char *quotedSchemaName = NULL;
if (strncmp(schemaName, "public", NAMEDATALEN) == 0) if (strncmp(schemaName, "public", NAMEDATALEN) == 0)
{ {
return NULL; return NULL;
} }
schemaNameDef = makeStringInfo(); StringInfo schemaNameDef = makeStringInfo();
quotedSchemaName = quote_identifier(schemaName); const char *quotedSchemaName = quote_identifier(schemaName);
ownerName = quote_identifier(SchemaOwnerName(schemaId)); const char *ownerName = quote_identifier(SchemaOwnerName(schemaId));
appendStringInfo(schemaNameDef, CREATE_SCHEMA_COMMAND, quotedSchemaName, ownerName); appendStringInfo(schemaNameDef, CREATE_SCHEMA_COMMAND, quotedSchemaName, ownerName);
return schemaNameDef->data; return schemaNameDef->data;
@ -1155,11 +1130,9 @@ TruncateTriggerCreateCommand(Oid relationId)
static char * static char *
SchemaOwnerName(Oid objectId) SchemaOwnerName(Oid objectId)
{ {
HeapTuple tuple = NULL;
Oid ownerId = InvalidOid; Oid ownerId = InvalidOid;
char *ownerName = NULL;
tuple = SearchSysCache1(NAMESPACEOID, ObjectIdGetDatum(objectId)); HeapTuple tuple = SearchSysCache1(NAMESPACEOID, ObjectIdGetDatum(objectId));
if (HeapTupleIsValid(tuple)) if (HeapTupleIsValid(tuple))
{ {
ownerId = ((Form_pg_namespace) GETSTRUCT(tuple))->nspowner; ownerId = ((Form_pg_namespace) GETSTRUCT(tuple))->nspowner;
@ -1169,7 +1142,7 @@ SchemaOwnerName(Oid objectId)
ownerId = GetUserId(); ownerId = GetUserId();
} }
ownerName = GetUserNameFromId(ownerId, false); char *ownerName = GetUserNameFromId(ownerId, false);
ReleaseSysCache(tuple); ReleaseSysCache(tuple);
@ -1248,7 +1221,6 @@ DetachPartitionCommandList(void)
{ {
DistTableCacheEntry *cacheEntry = DistTableCacheEntry *cacheEntry =
(DistTableCacheEntry *) lfirst(distributedTableCell); (DistTableCacheEntry *) lfirst(distributedTableCell);
List *partitionList = NIL;
ListCell *partitionCell = NULL; ListCell *partitionCell = NULL;
if (!PartitionedTable(cacheEntry->relationId)) if (!PartitionedTable(cacheEntry->relationId))
@ -1256,7 +1228,7 @@ DetachPartitionCommandList(void)
continue; continue;
} }
partitionList = PartitionList(cacheEntry->relationId); List *partitionList = PartitionList(cacheEntry->relationId);
foreach(partitionCell, partitionList) foreach(partitionCell, partitionList)
{ {
Oid partitionRelationId = lfirst_oid(partitionCell); Oid partitionRelationId = lfirst_oid(partitionCell);
@ -1295,7 +1267,6 @@ DetachPartitionCommandList(void)
MetadataSyncResult MetadataSyncResult
SyncMetadataToNodes(void) SyncMetadataToNodes(void)
{ {
List *workerList = NIL;
ListCell *workerCell = NULL; ListCell *workerCell = NULL;
MetadataSyncResult result = METADATA_SYNC_SUCCESS; MetadataSyncResult result = METADATA_SYNC_SUCCESS;
@ -1314,7 +1285,7 @@ SyncMetadataToNodes(void)
return METADATA_SYNC_FAILED_LOCK; return METADATA_SYNC_FAILED_LOCK;
} }
workerList = ActivePrimaryWorkerNodeList(NoLock); List *workerList = ActivePrimaryWorkerNodeList(NoLock);
foreach(workerCell, workerList) foreach(workerCell, workerList)
{ {

View File

@ -128,7 +128,6 @@ master_add_node(PG_FUNCTION_ARGS)
text *nodeName = PG_GETARG_TEXT_P(0); text *nodeName = PG_GETARG_TEXT_P(0);
int32 nodePort = PG_GETARG_INT32(1); int32 nodePort = PG_GETARG_INT32(1);
char *nodeNameString = text_to_cstring(nodeName); char *nodeNameString = text_to_cstring(nodeName);
int nodeId = 0;
NodeMetadata nodeMetadata = DefaultNodeMetadata(); NodeMetadata nodeMetadata = DefaultNodeMetadata();
bool nodeAlreadyExists = false; bool nodeAlreadyExists = false;
@ -153,7 +152,7 @@ master_add_node(PG_FUNCTION_ARGS)
nodeMetadata.nodeRole = PG_GETARG_OID(3); nodeMetadata.nodeRole = PG_GETARG_OID(3);
} }
nodeId = AddNodeMetadata(nodeNameString, nodePort, &nodeMetadata, int nodeId = AddNodeMetadata(nodeNameString, nodePort, &nodeMetadata,
&nodeAlreadyExists); &nodeAlreadyExists);
/* /*
@ -185,14 +184,13 @@ master_add_inactive_node(PG_FUNCTION_ARGS)
NodeMetadata nodeMetadata = DefaultNodeMetadata(); NodeMetadata nodeMetadata = DefaultNodeMetadata();
bool nodeAlreadyExists = false; bool nodeAlreadyExists = false;
int nodeId = 0;
nodeMetadata.groupId = PG_GETARG_INT32(2); nodeMetadata.groupId = PG_GETARG_INT32(2);
nodeMetadata.nodeRole = PG_GETARG_OID(3); nodeMetadata.nodeRole = PG_GETARG_OID(3);
nodeMetadata.nodeCluster = NameStr(*nodeClusterName); nodeMetadata.nodeCluster = NameStr(*nodeClusterName);
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
nodeId = AddNodeMetadata(nodeNameString, nodePort, &nodeMetadata, int nodeId = AddNodeMetadata(nodeNameString, nodePort, &nodeMetadata,
&nodeAlreadyExists); &nodeAlreadyExists);
PG_RETURN_INT32(nodeId); PG_RETURN_INT32(nodeId);
@ -217,7 +215,6 @@ master_add_secondary_node(PG_FUNCTION_ARGS)
Name nodeClusterName = PG_GETARG_NAME(4); Name nodeClusterName = PG_GETARG_NAME(4);
NodeMetadata nodeMetadata = DefaultNodeMetadata(); NodeMetadata nodeMetadata = DefaultNodeMetadata();
bool nodeAlreadyExists = false; bool nodeAlreadyExists = false;
int nodeId = 0;
nodeMetadata.groupId = GroupForNode(primaryNameString, primaryPort); nodeMetadata.groupId = GroupForNode(primaryNameString, primaryPort);
nodeMetadata.nodeCluster = NameStr(*nodeClusterName); nodeMetadata.nodeCluster = NameStr(*nodeClusterName);
@ -226,7 +223,7 @@ master_add_secondary_node(PG_FUNCTION_ARGS)
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
nodeId = AddNodeMetadata(nodeNameString, nodePort, &nodeMetadata, int nodeId = AddNodeMetadata(nodeNameString, nodePort, &nodeMetadata,
&nodeAlreadyExists); &nodeAlreadyExists);
PG_RETURN_INT32(nodeId); PG_RETURN_INT32(nodeId);
@ -307,11 +304,9 @@ master_disable_node(PG_FUNCTION_ARGS)
} }
PG_CATCH(); PG_CATCH();
{ {
ErrorData *edata = NULL;
/* CopyErrorData() requires (CurrentMemoryContext != ErrorContext) */ /* CopyErrorData() requires (CurrentMemoryContext != ErrorContext) */
MemoryContextSwitchTo(savedContext); MemoryContextSwitchTo(savedContext);
edata = CopyErrorData(); ErrorData *edata = CopyErrorData();
ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
errmsg("Disabling %s:%d failed", workerNode->workerName, errmsg("Disabling %s:%d failed", workerNode->workerName,
@ -397,14 +392,12 @@ SetUpDistributedTableDependencies(WorkerNode *newWorkerNode)
static void static void
PropagateRolesToNewNode(WorkerNode *newWorkerNode) PropagateRolesToNewNode(WorkerNode *newWorkerNode)
{ {
List *ddlCommands = NIL;
if (!EnableAlterRolePropagation) if (!EnableAlterRolePropagation)
{ {
return; return;
} }
ddlCommands = GenerateAlterRoleIfExistsCommandAllRoles(); List *ddlCommands = GenerateAlterRoleIfExistsCommandAllRoles();
SendCommandListToWorkerInSingleTransaction(newWorkerNode->workerName, SendCommandListToWorkerInSingleTransaction(newWorkerNode->workerName,
newWorkerNode->workerPort, newWorkerNode->workerPort,
@ -419,8 +412,6 @@ PropagateRolesToNewNode(WorkerNode *newWorkerNode)
static WorkerNode * static WorkerNode *
ModifiableWorkerNode(const char *nodeName, int32 nodePort) ModifiableWorkerNode(const char *nodeName, int32 nodePort)
{ {
WorkerNode *workerNode = NULL;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
EnsureCoordinator(); EnsureCoordinator();
@ -428,7 +419,7 @@ ModifiableWorkerNode(const char *nodeName, int32 nodePort)
/* take an exclusive lock on pg_dist_node to serialize pg_dist_node changes */ /* take an exclusive lock on pg_dist_node to serialize pg_dist_node changes */
LockRelationOid(DistNodeRelationId(), ExclusiveLock); LockRelationOid(DistNodeRelationId(), ExclusiveLock);
workerNode = FindWorkerNodeAnyCluster(nodeName, nodePort); WorkerNode *workerNode = FindWorkerNodeAnyCluster(nodeName, nodePort);
if (workerNode == NULL) if (workerNode == NULL)
{ {
ereport(ERROR, (errmsg("node at \"%s:%u\" does not exist", nodeName, nodePort))); ereport(ERROR, (errmsg("node at \"%s:%u\" does not exist", nodeName, nodePort)));
@ -581,13 +572,12 @@ PrimaryNodeForGroup(int32 groupId, bool *groupContainsNodes)
static int static int
ActivateNode(char *nodeName, int nodePort) ActivateNode(char *nodeName, int nodePort)
{ {
WorkerNode *newWorkerNode = NULL;
bool isActive = true; bool isActive = true;
/* take an exclusive lock on pg_dist_node to serialize pg_dist_node changes */ /* take an exclusive lock on pg_dist_node to serialize pg_dist_node changes */
LockRelationOid(DistNodeRelationId(), ExclusiveLock); LockRelationOid(DistNodeRelationId(), ExclusiveLock);
newWorkerNode = SetNodeState(nodeName, nodePort, isActive); WorkerNode *newWorkerNode = SetNodeState(nodeName, nodePort, isActive);
PropagateRolesToNewNode(newWorkerNode); PropagateRolesToNewNode(newWorkerNode);
SetUpDistributedTableDependencies(newWorkerNode); SetUpDistributedTableDependencies(newWorkerNode);
@ -621,14 +611,13 @@ master_update_node(PG_FUNCTION_ARGS)
int32 lock_cooldown = PG_GETARG_INT32(4); int32 lock_cooldown = PG_GETARG_INT32(4);
char *newNodeNameString = text_to_cstring(newNodeName); char *newNodeNameString = text_to_cstring(newNodeName);
WorkerNode *workerNode = NULL;
WorkerNode *workerNodeWithSameAddress = NULL;
List *placementList = NIL; List *placementList = NIL;
BackgroundWorkerHandle *handle = NULL; BackgroundWorkerHandle *handle = NULL;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
workerNodeWithSameAddress = FindWorkerNodeAnyCluster(newNodeNameString, newNodePort); WorkerNode *workerNodeWithSameAddress = FindWorkerNodeAnyCluster(newNodeNameString,
newNodePort);
if (workerNodeWithSameAddress != NULL) if (workerNodeWithSameAddress != NULL)
{ {
/* a node with the given hostname and port already exists in the metadata */ /* a node with the given hostname and port already exists in the metadata */
@ -646,7 +635,7 @@ master_update_node(PG_FUNCTION_ARGS)
} }
} }
workerNode = LookupNodeByNodeId(nodeId); WorkerNode *workerNode = LookupNodeByNodeId(nodeId);
if (workerNode == NULL) if (workerNode == NULL)
{ {
ereport(ERROR, (errcode(ERRCODE_NO_DATA_FOUND), ereport(ERROR, (errcode(ERRCODE_NO_DATA_FOUND),
@ -734,25 +723,22 @@ UpdateNodeLocation(int32 nodeId, char *newNodeName, int32 newNodePort)
{ {
const bool indexOK = true; const bool indexOK = true;
Relation pgDistNode = NULL;
TupleDesc tupleDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
SysScanDesc scanDescriptor = NULL;
HeapTuple heapTuple = NULL;
Datum values[Natts_pg_dist_node]; Datum values[Natts_pg_dist_node];
bool isnull[Natts_pg_dist_node]; bool isnull[Natts_pg_dist_node];
bool replace[Natts_pg_dist_node]; bool replace[Natts_pg_dist_node];
pgDistNode = heap_open(DistNodeRelationId(), RowExclusiveLock); Relation pgDistNode = heap_open(DistNodeRelationId(), RowExclusiveLock);
tupleDescriptor = RelationGetDescr(pgDistNode); TupleDesc tupleDescriptor = RelationGetDescr(pgDistNode);
ScanKeyInit(&scanKey[0], Anum_pg_dist_node_nodeid, ScanKeyInit(&scanKey[0], Anum_pg_dist_node_nodeid,
BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(nodeId)); BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(nodeId));
scanDescriptor = systable_beginscan(pgDistNode, DistNodeNodeIdIndexId(), indexOK, SysScanDesc scanDescriptor = systable_beginscan(pgDistNode, DistNodeNodeIdIndexId(),
indexOK,
NULL, 1, scanKey); NULL, 1, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
if (!HeapTupleIsValid(heapTuple)) if (!HeapTupleIsValid(heapTuple))
{ {
ereport(ERROR, (errmsg("could not find valid entry for node \"%s:%d\"", ereport(ERROR, (errmsg("could not find valid entry for node \"%s:%d\"",
@ -791,8 +777,6 @@ Datum
get_shard_id_for_distribution_column(PG_FUNCTION_ARGS) get_shard_id_for_distribution_column(PG_FUNCTION_ARGS)
{ {
ShardInterval *shardInterval = NULL; ShardInterval *shardInterval = NULL;
char distributionMethod = 0;
Oid relationId = InvalidOid;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
@ -806,7 +790,7 @@ get_shard_id_for_distribution_column(PG_FUNCTION_ARGS)
errmsg("relation cannot be NULL"))); errmsg("relation cannot be NULL")));
} }
relationId = PG_GETARG_OID(0); Oid relationId = PG_GETARG_OID(0);
EnsureTablePermissions(relationId, ACL_SELECT); EnsureTablePermissions(relationId, ACL_SELECT);
if (!IsDistributedTable(relationId)) if (!IsDistributedTable(relationId))
@ -815,7 +799,7 @@ get_shard_id_for_distribution_column(PG_FUNCTION_ARGS)
errmsg("relation is not distributed"))); errmsg("relation is not distributed")));
} }
distributionMethod = PartitionMethod(relationId); char distributionMethod = PartitionMethod(relationId);
if (distributionMethod == DISTRIBUTE_BY_NONE) if (distributionMethod == DISTRIBUTE_BY_NONE)
{ {
List *shardIntervalList = LoadShardIntervalList(relationId); List *shardIntervalList = LoadShardIntervalList(relationId);
@ -829,12 +813,6 @@ get_shard_id_for_distribution_column(PG_FUNCTION_ARGS)
else if (distributionMethod == DISTRIBUTE_BY_HASH || else if (distributionMethod == DISTRIBUTE_BY_HASH ||
distributionMethod == DISTRIBUTE_BY_RANGE) distributionMethod == DISTRIBUTE_BY_RANGE)
{ {
Var *distributionColumn = NULL;
Oid distributionDataType = InvalidOid;
Oid inputDataType = InvalidOid;
char *distributionValueString = NULL;
Datum inputDatum = 0;
Datum distributionValueDatum = 0;
DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId);
/* if given table is not reference table, distributionValue cannot be NULL */ /* if given table is not reference table, distributionValue cannot be NULL */
@ -845,14 +823,14 @@ get_shard_id_for_distribution_column(PG_FUNCTION_ARGS)
"than reference tables."))); "than reference tables.")));
} }
inputDatum = PG_GETARG_DATUM(1); Datum inputDatum = PG_GETARG_DATUM(1);
inputDataType = get_fn_expr_argtype(fcinfo->flinfo, 1); Oid inputDataType = get_fn_expr_argtype(fcinfo->flinfo, 1);
distributionValueString = DatumToString(inputDatum, inputDataType); char *distributionValueString = DatumToString(inputDatum, inputDataType);
distributionColumn = DistPartitionKey(relationId); Var *distributionColumn = DistPartitionKey(relationId);
distributionDataType = distributionColumn->vartype; Oid distributionDataType = distributionColumn->vartype;
distributionValueDatum = StringToDatum(distributionValueString, Datum distributionValueDatum = StringToDatum(distributionValueString,
distributionDataType); distributionDataType);
shardInterval = FindShardInterval(distributionValueDatum, cacheEntry); shardInterval = FindShardInterval(distributionValueDatum, cacheEntry);
@ -881,17 +859,16 @@ get_shard_id_for_distribution_column(PG_FUNCTION_ARGS)
WorkerNode * WorkerNode *
FindWorkerNode(char *nodeName, int32 nodePort) FindWorkerNode(char *nodeName, int32 nodePort)
{ {
WorkerNode *cachedWorkerNode = NULL;
HTAB *workerNodeHash = GetWorkerNodeHash(); HTAB *workerNodeHash = GetWorkerNodeHash();
bool handleFound = false; bool handleFound = false;
void *hashKey = NULL;
WorkerNode *searchedNode = (WorkerNode *) palloc0(sizeof(WorkerNode)); WorkerNode *searchedNode = (WorkerNode *) palloc0(sizeof(WorkerNode));
strlcpy(searchedNode->workerName, nodeName, WORKER_LENGTH); strlcpy(searchedNode->workerName, nodeName, WORKER_LENGTH);
searchedNode->workerPort = nodePort; searchedNode->workerPort = nodePort;
hashKey = (void *) searchedNode; void *hashKey = (void *) searchedNode;
cachedWorkerNode = (WorkerNode *) hash_search(workerNodeHash, hashKey, HASH_FIND, WorkerNode *cachedWorkerNode = (WorkerNode *) hash_search(workerNodeHash, hashKey,
HASH_FIND,
&handleFound); &handleFound);
if (handleFound) if (handleFound)
{ {
@ -939,22 +916,19 @@ FindWorkerNodeAnyCluster(const char *nodeName, int32 nodePort)
List * List *
ReadDistNode(bool includeNodesFromOtherClusters) ReadDistNode(bool includeNodesFromOtherClusters)
{ {
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 0; int scanKeyCount = 0;
HeapTuple heapTuple = NULL;
List *workerNodeList = NIL; List *workerNodeList = NIL;
TupleDesc tupleDescriptor = NULL;
Relation pgDistNode = heap_open(DistNodeRelationId(), AccessShareLock); Relation pgDistNode = heap_open(DistNodeRelationId(), AccessShareLock);
scanDescriptor = systable_beginscan(pgDistNode, SysScanDesc scanDescriptor = systable_beginscan(pgDistNode,
InvalidOid, false, InvalidOid, false,
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
tupleDescriptor = RelationGetDescr(pgDistNode); TupleDesc tupleDescriptor = RelationGetDescr(pgDistNode);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
while (HeapTupleIsValid(heapTuple)) while (HeapTupleIsValid(heapTuple))
{ {
WorkerNode *workerNode = TupleToWorkerNode(tupleDescriptor, heapTuple); WorkerNode *workerNode = TupleToWorkerNode(tupleDescriptor, heapTuple);
@ -989,7 +963,6 @@ ReadDistNode(bool includeNodesFromOtherClusters)
static void static void
RemoveNodeFromCluster(char *nodeName, int32 nodePort) RemoveNodeFromCluster(char *nodeName, int32 nodePort)
{ {
char *nodeDeleteCommand = NULL;
WorkerNode *workerNode = ModifiableWorkerNode(nodeName, nodePort); WorkerNode *workerNode = ModifiableWorkerNode(nodeName, nodePort);
if (NodeIsPrimary(workerNode)) if (NodeIsPrimary(workerNode))
@ -1012,7 +985,7 @@ RemoveNodeFromCluster(char *nodeName, int32 nodePort)
DeleteNodeRow(workerNode->workerName, nodePort); DeleteNodeRow(workerNode->workerName, nodePort);
nodeDeleteCommand = NodeDeleteCommand(workerNode->nodeId); char *nodeDeleteCommand = NodeDeleteCommand(workerNode->nodeId);
/* make sure we don't have any lingering session lifespan connections */ /* make sure we don't have any lingering session lifespan connections */
CloseNodeConnectionsAfterTransaction(workerNode->workerName, nodePort); CloseNodeConnectionsAfterTransaction(workerNode->workerName, nodePort);
@ -1059,11 +1032,6 @@ AddNodeMetadata(char *nodeName, int32 nodePort,
NodeMetadata *nodeMetadata, NodeMetadata *nodeMetadata,
bool *nodeAlreadyExists) bool *nodeAlreadyExists)
{ {
int nextNodeIdInt = 0;
WorkerNode *workerNode = NULL;
char *nodeDeleteCommand = NULL;
uint32 primariesWithMetadata = 0;
EnsureCoordinator(); EnsureCoordinator();
*nodeAlreadyExists = false; *nodeAlreadyExists = false;
@ -1075,7 +1043,7 @@ AddNodeMetadata(char *nodeName, int32 nodePort,
*/ */
LockRelationOid(DistNodeRelationId(), ExclusiveLock); LockRelationOid(DistNodeRelationId(), ExclusiveLock);
workerNode = FindWorkerNodeAnyCluster(nodeName, nodePort); WorkerNode *workerNode = FindWorkerNodeAnyCluster(nodeName, nodePort);
if (workerNode != NULL) if (workerNode != NULL)
{ {
/* fill return data and return */ /* fill return data and return */
@ -1122,18 +1090,18 @@ AddNodeMetadata(char *nodeName, int32 nodePort,
} }
/* generate the new node id from the sequence */ /* generate the new node id from the sequence */
nextNodeIdInt = GetNextNodeId(); int nextNodeIdInt = GetNextNodeId();
InsertNodeRow(nextNodeIdInt, nodeName, nodePort, nodeMetadata); InsertNodeRow(nextNodeIdInt, nodeName, nodePort, nodeMetadata);
workerNode = FindWorkerNodeAnyCluster(nodeName, nodePort); workerNode = FindWorkerNodeAnyCluster(nodeName, nodePort);
/* send the delete command to all primary nodes with metadata */ /* send the delete command to all primary nodes with metadata */
nodeDeleteCommand = NodeDeleteCommand(workerNode->nodeId); char *nodeDeleteCommand = NodeDeleteCommand(workerNode->nodeId);
SendCommandToWorkers(WORKERS_WITH_METADATA, nodeDeleteCommand); SendCommandToWorkers(WORKERS_WITH_METADATA, nodeDeleteCommand);
/* finally prepare the insert command and send it to all primary nodes */ /* finally prepare the insert command and send it to all primary nodes */
primariesWithMetadata = CountPrimariesWithMetadata(); uint32 primariesWithMetadata = CountPrimariesWithMetadata();
if (primariesWithMetadata != 0) if (primariesWithMetadata != 0)
{ {
List *workerNodeList = list_make1(workerNode); List *workerNodeList = list_make1(workerNode);
@ -1157,7 +1125,6 @@ SetWorkerColumn(WorkerNode *workerNode, int columnIndex, Datum value)
Relation pgDistNode = heap_open(DistNodeRelationId(), RowExclusiveLock); Relation pgDistNode = heap_open(DistNodeRelationId(), RowExclusiveLock);
TupleDesc tupleDescriptor = RelationGetDescr(pgDistNode); TupleDesc tupleDescriptor = RelationGetDescr(pgDistNode);
HeapTuple heapTuple = GetNodeTuple(workerNode->workerName, workerNode->workerPort); HeapTuple heapTuple = GetNodeTuple(workerNode->workerName, workerNode->workerPort);
WorkerNode *newWorkerNode = NULL;
Datum values[Natts_pg_dist_node]; Datum values[Natts_pg_dist_node];
bool isnull[Natts_pg_dist_node]; bool isnull[Natts_pg_dist_node];
@ -1206,7 +1173,7 @@ SetWorkerColumn(WorkerNode *workerNode, int columnIndex, Datum value)
CitusInvalidateRelcacheByRelid(DistNodeRelationId()); CitusInvalidateRelcacheByRelid(DistNodeRelationId());
CommandCounterIncrement(); CommandCounterIncrement();
newWorkerNode = TupleToWorkerNode(tupleDescriptor, heapTuple); WorkerNode *newWorkerNode = TupleToWorkerNode(tupleDescriptor, heapTuple);
heap_close(pgDistNode, NoLock); heap_close(pgDistNode, NoLock);
@ -1257,18 +1224,16 @@ GetNodeTuple(const char *nodeName, int32 nodePort)
const bool indexOK = false; const bool indexOK = false;
ScanKeyData scanKey[2]; ScanKeyData scanKey[2];
SysScanDesc scanDescriptor = NULL;
HeapTuple heapTuple = NULL;
HeapTuple nodeTuple = NULL; HeapTuple nodeTuple = NULL;
ScanKeyInit(&scanKey[0], Anum_pg_dist_node_nodename, ScanKeyInit(&scanKey[0], Anum_pg_dist_node_nodename,
BTEqualStrategyNumber, F_TEXTEQ, CStringGetTextDatum(nodeName)); BTEqualStrategyNumber, F_TEXTEQ, CStringGetTextDatum(nodeName));
ScanKeyInit(&scanKey[1], Anum_pg_dist_node_nodeport, ScanKeyInit(&scanKey[1], Anum_pg_dist_node_nodeport,
BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(nodePort)); BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(nodePort));
scanDescriptor = systable_beginscan(pgDistNode, InvalidOid, indexOK, SysScanDesc scanDescriptor = systable_beginscan(pgDistNode, InvalidOid, indexOK,
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
if (HeapTupleIsValid(heapTuple)) if (HeapTupleIsValid(heapTuple))
{ {
nodeTuple = heap_copytuple(heapTuple); nodeTuple = heap_copytuple(heapTuple);
@ -1298,18 +1263,16 @@ GetNextGroupId()
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId); Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
Oid savedUserId = InvalidOid; Oid savedUserId = InvalidOid;
int savedSecurityContext = 0; int savedSecurityContext = 0;
Datum groupIdDatum = 0;
int32 groupId = 0;
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
/* generate new and unique shardId from sequence */ /* generate new and unique shardId from sequence */
groupIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum); Datum groupIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
SetUserIdAndSecContext(savedUserId, savedSecurityContext); SetUserIdAndSecContext(savedUserId, savedSecurityContext);
groupId = DatumGetInt32(groupIdDatum); int32 groupId = DatumGetInt32(groupIdDatum);
return groupId; return groupId;
} }
@ -1332,18 +1295,16 @@ GetNextNodeId()
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId); Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
Oid savedUserId = InvalidOid; Oid savedUserId = InvalidOid;
int savedSecurityContext = 0; int savedSecurityContext = 0;
Datum nextNodeIdDatum;
int nextNodeId = 0;
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
/* generate new and unique shardId from sequence */ /* generate new and unique shardId from sequence */
nextNodeIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum); Datum nextNodeIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
SetUserIdAndSecContext(savedUserId, savedSecurityContext); SetUserIdAndSecContext(savedUserId, savedSecurityContext);
nextNodeId = DatumGetUInt32(nextNodeIdDatum); int nextNodeId = DatumGetUInt32(nextNodeIdDatum);
return nextNodeId; return nextNodeId;
} }
@ -1377,9 +1338,6 @@ EnsureCoordinator(void)
static void static void
InsertNodeRow(int nodeid, char *nodeName, int32 nodePort, NodeMetadata *nodeMetadata) InsertNodeRow(int nodeid, char *nodeName, int32 nodePort, NodeMetadata *nodeMetadata)
{ {
Relation pgDistNode = NULL;
TupleDesc tupleDescriptor = NULL;
HeapTuple heapTuple = NULL;
Datum values[Natts_pg_dist_node]; Datum values[Natts_pg_dist_node];
bool isNulls[Natts_pg_dist_node]; bool isNulls[Natts_pg_dist_node];
@ -1404,10 +1362,10 @@ InsertNodeRow(int nodeid, char *nodeName, int32 nodePort, NodeMetadata *nodeMeta
values[Anum_pg_dist_node_shouldhaveshards - 1] = BoolGetDatum( values[Anum_pg_dist_node_shouldhaveshards - 1] = BoolGetDatum(
nodeMetadata->shouldHaveShards); nodeMetadata->shouldHaveShards);
pgDistNode = heap_open(DistNodeRelationId(), RowExclusiveLock); Relation pgDistNode = heap_open(DistNodeRelationId(), RowExclusiveLock);
tupleDescriptor = RelationGetDescr(pgDistNode); TupleDesc tupleDescriptor = RelationGetDescr(pgDistNode);
heapTuple = heap_form_tuple(tupleDescriptor, values, isNulls); HeapTuple heapTuple = heap_form_tuple(tupleDescriptor, values, isNulls);
CatalogTupleInsert(pgDistNode, heapTuple); CatalogTupleInsert(pgDistNode, heapTuple);
@ -1430,8 +1388,6 @@ DeleteNodeRow(char *nodeName, int32 nodePort)
const int scanKeyCount = 2; const int scanKeyCount = 2;
bool indexOK = false; bool indexOK = false;
HeapTuple heapTuple = NULL;
SysScanDesc heapScan = NULL;
ScanKeyData scanKey[2]; ScanKeyData scanKey[2];
Relation pgDistNode = heap_open(DistNodeRelationId(), RowExclusiveLock); Relation pgDistNode = heap_open(DistNodeRelationId(), RowExclusiveLock);
@ -1447,10 +1403,10 @@ DeleteNodeRow(char *nodeName, int32 nodePort)
ScanKeyInit(&scanKey[1], Anum_pg_dist_node_nodeport, ScanKeyInit(&scanKey[1], Anum_pg_dist_node_nodeport,
BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(nodePort)); BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(nodePort));
heapScan = systable_beginscan(pgDistNode, InvalidOid, indexOK, SysScanDesc heapScan = systable_beginscan(pgDistNode, InvalidOid, indexOK,
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
heapTuple = systable_getnext(heapScan); HeapTuple heapTuple = systable_getnext(heapScan);
if (!HeapTupleIsValid(heapTuple)) if (!HeapTupleIsValid(heapTuple))
{ {
@ -1481,11 +1437,8 @@ DeleteNodeRow(char *nodeName, int32 nodePort)
static WorkerNode * static WorkerNode *
TupleToWorkerNode(TupleDesc tupleDescriptor, HeapTuple heapTuple) TupleToWorkerNode(TupleDesc tupleDescriptor, HeapTuple heapTuple)
{ {
WorkerNode *workerNode = NULL;
Datum datumArray[Natts_pg_dist_node]; Datum datumArray[Natts_pg_dist_node];
bool isNullArray[Natts_pg_dist_node]; bool isNullArray[Natts_pg_dist_node];
char *nodeName = NULL;
char *nodeRack = NULL;
Assert(!HeapTupleHasNulls(heapTuple)); Assert(!HeapTupleHasNulls(heapTuple));
@ -1502,10 +1455,10 @@ TupleToWorkerNode(TupleDesc tupleDescriptor, HeapTuple heapTuple)
*/ */
heap_deform_tuple(heapTuple, tupleDescriptor, datumArray, isNullArray); heap_deform_tuple(heapTuple, tupleDescriptor, datumArray, isNullArray);
nodeName = DatumGetCString(datumArray[Anum_pg_dist_node_nodename - 1]); char *nodeName = DatumGetCString(datumArray[Anum_pg_dist_node_nodename - 1]);
nodeRack = DatumGetCString(datumArray[Anum_pg_dist_node_noderack - 1]); char *nodeRack = DatumGetCString(datumArray[Anum_pg_dist_node_noderack - 1]);
workerNode = (WorkerNode *) palloc0(sizeof(WorkerNode)); WorkerNode *workerNode = (WorkerNode *) palloc0(sizeof(WorkerNode));
workerNode->nodeId = DatumGetUInt32(datumArray[Anum_pg_dist_node_nodeid - 1]); workerNode->nodeId = DatumGetUInt32(datumArray[Anum_pg_dist_node_nodeid - 1]);
workerNode->workerPort = DatumGetUInt32(datumArray[Anum_pg_dist_node_nodeport - 1]); workerNode->workerPort = DatumGetUInt32(datumArray[Anum_pg_dist_node_nodeport - 1]);
workerNode->groupId = DatumGetInt32(datumArray[Anum_pg_dist_node_groupid - 1]); workerNode->groupId = DatumGetInt32(datumArray[Anum_pg_dist_node_groupid - 1]);
@ -1546,12 +1499,11 @@ StringToDatum(char *inputString, Oid dataType)
Oid typIoFunc = InvalidOid; Oid typIoFunc = InvalidOid;
Oid typIoParam = InvalidOid; Oid typIoParam = InvalidOid;
int32 typeModifier = -1; int32 typeModifier = -1;
Datum datum = 0;
getTypeInputInfo(dataType, &typIoFunc, &typIoParam); getTypeInputInfo(dataType, &typIoFunc, &typIoParam);
getBaseTypeAndTypmod(dataType, &typeModifier); getBaseTypeAndTypmod(dataType, &typeModifier);
datum = OidInputFunctionCall(typIoFunc, inputString, typIoParam, typeModifier); Datum datum = OidInputFunctionCall(typIoFunc, inputString, typIoParam, typeModifier);
return datum; return datum;
} }
@ -1563,12 +1515,11 @@ StringToDatum(char *inputString, Oid dataType)
char * char *
DatumToString(Datum datum, Oid dataType) DatumToString(Datum datum, Oid dataType)
{ {
char *outputString = NULL;
Oid typIoFunc = InvalidOid; Oid typIoFunc = InvalidOid;
bool typIsVarlena = false; bool typIsVarlena = false;
getTypeOutputInfo(dataType, &typIoFunc, &typIsVarlena); getTypeOutputInfo(dataType, &typIoFunc, &typIsVarlena);
outputString = OidOutputFunctionCall(typIoFunc, datum); char *outputString = OidOutputFunctionCall(typIoFunc, datum);
return outputString; return outputString;
} }
@ -1582,34 +1533,29 @@ static bool
UnsetMetadataSyncedForAll(void) UnsetMetadataSyncedForAll(void)
{ {
bool updatedAtLeastOne = false; bool updatedAtLeastOne = false;
Relation relation = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[2]; ScanKeyData scanKey[2];
int scanKeyCount = 2; int scanKeyCount = 2;
bool indexOK = false; bool indexOK = false;
HeapTuple heapTuple = NULL;
TupleDesc tupleDescriptor = NULL;
CatalogIndexState indstate;
/* /*
* Concurrent master_update_node() calls might iterate and try to update * Concurrent master_update_node() calls might iterate and try to update
* pg_dist_node in different orders. To protect against deadlock, we * pg_dist_node in different orders. To protect against deadlock, we
* get an exclusive lock here. * get an exclusive lock here.
*/ */
relation = heap_open(DistNodeRelationId(), ExclusiveLock); Relation relation = heap_open(DistNodeRelationId(), ExclusiveLock);
tupleDescriptor = RelationGetDescr(relation); TupleDesc tupleDescriptor = RelationGetDescr(relation);
ScanKeyInit(&scanKey[0], Anum_pg_dist_node_hasmetadata, ScanKeyInit(&scanKey[0], Anum_pg_dist_node_hasmetadata,
BTEqualStrategyNumber, F_BOOLEQ, BoolGetDatum(true)); BTEqualStrategyNumber, F_BOOLEQ, BoolGetDatum(true));
ScanKeyInit(&scanKey[1], Anum_pg_dist_node_metadatasynced, ScanKeyInit(&scanKey[1], Anum_pg_dist_node_metadatasynced,
BTEqualStrategyNumber, F_BOOLEQ, BoolGetDatum(true)); BTEqualStrategyNumber, F_BOOLEQ, BoolGetDatum(true));
indstate = CatalogOpenIndexes(relation); CatalogIndexState indstate = CatalogOpenIndexes(relation);
scanDescriptor = systable_beginscan(relation, SysScanDesc scanDescriptor = systable_beginscan(relation,
InvalidOid, indexOK, InvalidOid, indexOK,
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
if (HeapTupleIsValid(heapTuple)) if (HeapTupleIsValid(heapTuple))
{ {
updatedAtLeastOne = true; updatedAtLeastOne = true;
@ -1617,7 +1563,6 @@ UnsetMetadataSyncedForAll(void)
while (HeapTupleIsValid(heapTuple)) while (HeapTupleIsValid(heapTuple))
{ {
HeapTuple newHeapTuple = NULL;
Datum values[Natts_pg_dist_node]; Datum values[Natts_pg_dist_node];
bool isnull[Natts_pg_dist_node]; bool isnull[Natts_pg_dist_node];
bool replace[Natts_pg_dist_node]; bool replace[Natts_pg_dist_node];
@ -1629,7 +1574,8 @@ UnsetMetadataSyncedForAll(void)
values[Anum_pg_dist_node_metadatasynced - 1] = BoolGetDatum(false); values[Anum_pg_dist_node_metadatasynced - 1] = BoolGetDatum(false);
replace[Anum_pg_dist_node_metadatasynced - 1] = true; replace[Anum_pg_dist_node_metadatasynced - 1] = true;
newHeapTuple = heap_modify_tuple(heapTuple, tupleDescriptor, values, isnull, HeapTuple newHeapTuple = heap_modify_tuple(heapTuple, tupleDescriptor, values,
isnull,
replace); replace);
CatalogTupleUpdateWithInfo(relation, &newHeapTuple->t_self, newHeapTuple, CatalogTupleUpdateWithInfo(relation, &newHeapTuple->t_self, newHeapTuple,

View File

@ -61,21 +61,17 @@ RebuildQueryStrings(Query *originalQuery, List *taskList)
else if (query->commandType == CMD_INSERT && task->modifyWithSubquery) else if (query->commandType == CMD_INSERT && task->modifyWithSubquery)
{ {
/* for INSERT..SELECT, adjust shard names in SELECT part */ /* for INSERT..SELECT, adjust shard names in SELECT part */
RangeTblEntry *copiedInsertRte = NULL;
RangeTblEntry *copiedSubqueryRte = NULL;
Query *copiedSubquery = NULL;
List *relationShardList = task->relationShardList; List *relationShardList = task->relationShardList;
ShardInterval *shardInterval = LoadShardInterval(task->anchorShardId); ShardInterval *shardInterval = LoadShardInterval(task->anchorShardId);
char partitionMethod = 0;
query = copyObject(originalQuery); query = copyObject(originalQuery);
copiedInsertRte = ExtractResultRelationRTE(query); RangeTblEntry *copiedInsertRte = ExtractResultRelationRTE(query);
copiedSubqueryRte = ExtractSelectRangeTableEntry(query); RangeTblEntry *copiedSubqueryRte = ExtractSelectRangeTableEntry(query);
copiedSubquery = copiedSubqueryRte->subquery; Query *copiedSubquery = copiedSubqueryRte->subquery;
/* there are no restrictions to add for reference tables */ /* there are no restrictions to add for reference tables */
partitionMethod = PartitionMethod(shardInterval->relationId); char partitionMethod = PartitionMethod(shardInterval->relationId);
if (partitionMethod != DISTRIBUTE_BY_NONE) if (partitionMethod != DISTRIBUTE_BY_NONE)
{ {
AddShardIntervalRestrictionToSelect(copiedSubquery, shardInterval); AddShardIntervalRestrictionToSelect(copiedSubquery, shardInterval);
@ -95,14 +91,12 @@ RebuildQueryStrings(Query *originalQuery, List *taskList)
else if (query->commandType == CMD_INSERT && (query->onConflict != NULL || else if (query->commandType == CMD_INSERT && (query->onConflict != NULL ||
valuesRTE != NULL)) valuesRTE != NULL))
{ {
RangeTblEntry *rangeTableEntry = NULL;
/* /*
* Always an alias in UPSERTs and multi-row INSERTs to avoid * Always an alias in UPSERTs and multi-row INSERTs to avoid
* deparsing issues (e.g. RETURNING might reference the original * deparsing issues (e.g. RETURNING might reference the original
* table name, which has been replaced by a shard name). * table name, which has been replaced by a shard name).
*/ */
rangeTableEntry = linitial(query->rtable); RangeTblEntry *rangeTableEntry = linitial(query->rtable);
if (rangeTableEntry->alias == NULL) if (rangeTableEntry->alias == NULL)
{ {
Alias *alias = makeAlias(CITUS_TABLE_ALIAS, NIL); Alias *alias = makeAlias(CITUS_TABLE_ALIAS, NIL);
@ -184,13 +178,8 @@ UpdateTaskQueryString(Query *query, Oid distributedTableId, RangeTblEntry *value
bool bool
UpdateRelationToShardNames(Node *node, List *relationShardList) UpdateRelationToShardNames(Node *node, List *relationShardList)
{ {
RangeTblEntry *newRte = NULL;
uint64 shardId = INVALID_SHARD_ID; uint64 shardId = INVALID_SHARD_ID;
Oid relationId = InvalidOid; Oid relationId = InvalidOid;
Oid schemaId = InvalidOid;
char *relationName = NULL;
char *schemaName = NULL;
bool replaceRteWithNullValues = false;
ListCell *relationShardCell = NULL; ListCell *relationShardCell = NULL;
RelationShard *relationShard = NULL; RelationShard *relationShard = NULL;
@ -212,7 +201,7 @@ UpdateRelationToShardNames(Node *node, List *relationShardList)
relationShardList); relationShardList);
} }
newRte = (RangeTblEntry *) node; RangeTblEntry *newRte = (RangeTblEntry *) node;
if (newRte->rtekind != RTE_RELATION) if (newRte->rtekind != RTE_RELATION)
{ {
@ -238,7 +227,7 @@ UpdateRelationToShardNames(Node *node, List *relationShardList)
relationShard = NULL; relationShard = NULL;
} }
replaceRteWithNullValues = relationShard == NULL || bool replaceRteWithNullValues = relationShard == NULL ||
relationShard->shardId == INVALID_SHARD_ID; relationShard->shardId == INVALID_SHARD_ID;
if (replaceRteWithNullValues) if (replaceRteWithNullValues)
{ {
@ -249,11 +238,11 @@ UpdateRelationToShardNames(Node *node, List *relationShardList)
shardId = relationShard->shardId; shardId = relationShard->shardId;
relationId = relationShard->relationId; relationId = relationShard->relationId;
relationName = get_rel_name(relationId); char *relationName = get_rel_name(relationId);
AppendShardIdToName(&relationName, shardId); AppendShardIdToName(&relationName, shardId);
schemaId = get_rel_namespace(relationId); Oid schemaId = get_rel_namespace(relationId);
schemaName = get_namespace_name(schemaId); char *schemaName = get_namespace_name(schemaId);
ModifyRangeTblExtraData(newRte, CITUS_RTE_SHARD, schemaName, relationName, NIL); ModifyRangeTblExtraData(newRte, CITUS_RTE_SHARD, schemaName, relationName, NIL);
@ -271,31 +260,26 @@ ConvertRteToSubqueryWithEmptyResult(RangeTblEntry *rte)
Relation relation = heap_open(rte->relid, NoLock); Relation relation = heap_open(rte->relid, NoLock);
TupleDesc tupleDescriptor = RelationGetDescr(relation); TupleDesc tupleDescriptor = RelationGetDescr(relation);
int columnCount = tupleDescriptor->natts; int columnCount = tupleDescriptor->natts;
int columnIndex = 0;
Query *subquery = NULL;
List *targetList = NIL; List *targetList = NIL;
FromExpr *joinTree = NULL;
for (columnIndex = 0; columnIndex < columnCount; columnIndex++) for (int columnIndex = 0; columnIndex < columnCount; columnIndex++)
{ {
FormData_pg_attribute *attributeForm = TupleDescAttr(tupleDescriptor, FormData_pg_attribute *attributeForm = TupleDescAttr(tupleDescriptor,
columnIndex); columnIndex);
TargetEntry *targetEntry = NULL;
StringInfo resname = NULL;
Const *constValue = NULL;
if (attributeForm->attisdropped) if (attributeForm->attisdropped)
{ {
continue; continue;
} }
resname = makeStringInfo(); StringInfo resname = makeStringInfo();
constValue = makeNullConst(attributeForm->atttypid, attributeForm->atttypmod, Const *constValue = makeNullConst(attributeForm->atttypid,
attributeForm->atttypmod,
attributeForm->attcollation); attributeForm->attcollation);
appendStringInfo(resname, "%s", attributeForm->attname.data); appendStringInfo(resname, "%s", attributeForm->attname.data);
targetEntry = makeNode(TargetEntry); TargetEntry *targetEntry = makeNode(TargetEntry);
targetEntry->expr = (Expr *) constValue; targetEntry->expr = (Expr *) constValue;
targetEntry->resno = columnIndex; targetEntry->resno = columnIndex;
targetEntry->resname = resname->data; targetEntry->resname = resname->data;
@ -305,10 +289,10 @@ ConvertRteToSubqueryWithEmptyResult(RangeTblEntry *rte)
heap_close(relation, NoLock); heap_close(relation, NoLock);
joinTree = makeNode(FromExpr); FromExpr *joinTree = makeNode(FromExpr);
joinTree->quals = makeBoolConst(false, false); joinTree->quals = makeBoolConst(false, false);
subquery = makeNode(Query); Query *subquery = makeNode(Query);
subquery->commandType = CMD_SELECT; subquery->commandType = CMD_SELECT;
subquery->querySource = QSRC_ORIGINAL; subquery->querySource = QSRC_ORIGINAL;
subquery->canSetTag = true; subquery->canSetTag = true;

View File

@ -113,7 +113,6 @@ distributed_planner(Query *parse, int cursorOptions, ParamListInfo boundParams)
PlannedStmt *result = NULL; PlannedStmt *result = NULL;
bool needsDistributedPlanning = false; bool needsDistributedPlanning = false;
Query *originalQuery = NULL; Query *originalQuery = NULL;
PlannerRestrictionContext *plannerRestrictionContext = NULL;
bool setPartitionedTablesInherited = false; bool setPartitionedTablesInherited = false;
List *rangeTableList = ExtractRangeTableEntryList(parse); List *rangeTableList = ExtractRangeTableEntryList(parse);
@ -181,7 +180,8 @@ distributed_planner(Query *parse, int cursorOptions, ParamListInfo boundParams)
ReplaceTableVisibleFunction((Node *) parse); ReplaceTableVisibleFunction((Node *) parse);
/* create a restriction context and put it at the end if context list */ /* create a restriction context and put it at the end if context list */
plannerRestrictionContext = CreateAndPushPlannerRestrictionContext(); PlannerRestrictionContext *plannerRestrictionContext =
CreateAndPushPlannerRestrictionContext();
PG_TRY(); PG_TRY();
{ {
@ -519,8 +519,6 @@ CreateDistributedPlannedStmt(uint64 planId, PlannedStmt *localPlan, Query *origi
Query *query, ParamListInfo boundParams, Query *query, ParamListInfo boundParams,
PlannerRestrictionContext *plannerRestrictionContext) PlannerRestrictionContext *plannerRestrictionContext)
{ {
DistributedPlan *distributedPlan = NULL;
PlannedStmt *resultPlan = NULL;
bool hasUnresolvedParams = false; bool hasUnresolvedParams = false;
JoinRestrictionContext *joinRestrictionContext = JoinRestrictionContext *joinRestrictionContext =
plannerRestrictionContext->joinRestrictionContext; plannerRestrictionContext->joinRestrictionContext;
@ -533,7 +531,7 @@ CreateDistributedPlannedStmt(uint64 planId, PlannedStmt *localPlan, Query *origi
plannerRestrictionContext->joinRestrictionContext = plannerRestrictionContext->joinRestrictionContext =
RemoveDuplicateJoinRestrictions(joinRestrictionContext); RemoveDuplicateJoinRestrictions(joinRestrictionContext);
distributedPlan = DistributedPlan *distributedPlan =
CreateDistributedPlan(planId, originalQuery, query, boundParams, CreateDistributedPlan(planId, originalQuery, query, boundParams,
hasUnresolvedParams, plannerRestrictionContext); hasUnresolvedParams, plannerRestrictionContext);
@ -580,7 +578,7 @@ CreateDistributedPlannedStmt(uint64 planId, PlannedStmt *localPlan, Query *origi
distributedPlan->planId = planId; distributedPlan->planId = planId;
/* create final plan by combining local plan with distributed plan */ /* create final plan by combining local plan with distributed plan */
resultPlan = FinalizePlan(localPlan, distributedPlan); PlannedStmt *resultPlan = FinalizePlan(localPlan, distributedPlan);
/* /*
* As explained above, force planning costs to be unrealistically high if * As explained above, force planning costs to be unrealistically high if
@ -617,17 +615,14 @@ CreateDistributedPlan(uint64 planId, Query *originalQuery, Query *query, ParamLi
PlannerRestrictionContext *plannerRestrictionContext) PlannerRestrictionContext *plannerRestrictionContext)
{ {
DistributedPlan *distributedPlan = NULL; DistributedPlan *distributedPlan = NULL;
MultiTreeRoot *logicalPlan = NULL;
List *subPlanList = NIL;
bool hasCtes = originalQuery->cteList != NIL; bool hasCtes = originalQuery->cteList != NIL;
if (IsModifyCommand(originalQuery)) if (IsModifyCommand(originalQuery))
{ {
Oid targetRelationId = InvalidOid;
EnsureModificationsCanRun(); EnsureModificationsCanRun();
targetRelationId = ModifyQueryResultRelationId(query); Oid targetRelationId = ModifyQueryResultRelationId(query);
EnsurePartitionTableNotReplicated(targetRelationId); EnsurePartitionTableNotReplicated(targetRelationId);
if (InsertSelectIntoDistributedTable(originalQuery)) if (InsertSelectIntoDistributedTable(originalQuery))
@ -722,7 +717,7 @@ CreateDistributedPlan(uint64 planId, Query *originalQuery, Query *query, ParamLi
* Plan subqueries and CTEs that cannot be pushed down by recursively * Plan subqueries and CTEs that cannot be pushed down by recursively
* calling the planner and return the resulting plans to subPlanList. * calling the planner and return the resulting plans to subPlanList.
*/ */
subPlanList = GenerateSubplansForSubqueriesAndCTEs(planId, originalQuery, List *subPlanList = GenerateSubplansForSubqueriesAndCTEs(planId, originalQuery,
plannerRestrictionContext); plannerRestrictionContext);
/* /*
@ -798,7 +793,7 @@ CreateDistributedPlan(uint64 planId, Query *originalQuery, Query *query, ParamLi
query->cteList = NIL; query->cteList = NIL;
Assert(originalQuery->cteList == NIL); Assert(originalQuery->cteList == NIL);
logicalPlan = MultiLogicalPlanCreate(originalQuery, query, MultiTreeRoot *logicalPlan = MultiLogicalPlanCreate(originalQuery, query,
plannerRestrictionContext); plannerRestrictionContext);
MultiLogicalPlanOptimize(logicalPlan); MultiLogicalPlanOptimize(logicalPlan);
@ -937,14 +932,11 @@ ResolveExternalParams(Node *inputNode, ParamListInfo boundParams)
if (IsA(inputNode, Param)) if (IsA(inputNode, Param))
{ {
Param *paramToProcess = (Param *) inputNode; Param *paramToProcess = (Param *) inputNode;
ParamExternData *correspondingParameterData = NULL;
int numberOfParameters = boundParams->numParams; int numberOfParameters = boundParams->numParams;
int parameterId = paramToProcess->paramid; int parameterId = paramToProcess->paramid;
int16 typeLength = 0; int16 typeLength = 0;
bool typeByValue = false; bool typeByValue = false;
Datum constValue = 0; Datum constValue = 0;
bool paramIsNull = false;
int parameterIndex = 0;
if (paramToProcess->paramkind != PARAM_EXTERN) if (paramToProcess->paramkind != PARAM_EXTERN)
{ {
@ -957,13 +949,14 @@ ResolveExternalParams(Node *inputNode, ParamListInfo boundParams)
} }
/* parameterId starts from 1 */ /* parameterId starts from 1 */
parameterIndex = parameterId - 1; int parameterIndex = parameterId - 1;
if (parameterIndex >= numberOfParameters) if (parameterIndex >= numberOfParameters)
{ {
return inputNode; return inputNode;
} }
correspondingParameterData = &boundParams->params[parameterIndex]; ParamExternData *correspondingParameterData =
&boundParams->params[parameterIndex];
if (!(correspondingParameterData->pflags & PARAM_FLAG_CONST)) if (!(correspondingParameterData->pflags & PARAM_FLAG_CONST))
{ {
@ -972,7 +965,7 @@ ResolveExternalParams(Node *inputNode, ParamListInfo boundParams)
get_typlenbyval(paramToProcess->paramtype, &typeLength, &typeByValue); get_typlenbyval(paramToProcess->paramtype, &typeLength, &typeByValue);
paramIsNull = correspondingParameterData->isnull; bool paramIsNull = correspondingParameterData->isnull;
if (paramIsNull) if (paramIsNull)
{ {
constValue = 0; constValue = 0;
@ -1015,17 +1008,14 @@ ResolveExternalParams(Node *inputNode, ParamListInfo boundParams)
DistributedPlan * DistributedPlan *
GetDistributedPlan(CustomScan *customScan) GetDistributedPlan(CustomScan *customScan)
{ {
Node *node = NULL;
DistributedPlan *distributedPlan = NULL;
Assert(list_length(customScan->custom_private) == 1); Assert(list_length(customScan->custom_private) == 1);
node = (Node *) linitial(customScan->custom_private); Node *node = (Node *) linitial(customScan->custom_private);
Assert(CitusIsA(node, DistributedPlan)); Assert(CitusIsA(node, DistributedPlan));
CheckNodeCopyAndSerialization(node); CheckNodeCopyAndSerialization(node);
distributedPlan = (DistributedPlan *) node; DistributedPlan *distributedPlan = (DistributedPlan *) node;
return distributedPlan; return distributedPlan;
} }
@ -1040,7 +1030,6 @@ FinalizePlan(PlannedStmt *localPlan, DistributedPlan *distributedPlan)
{ {
PlannedStmt *finalPlan = NULL; PlannedStmt *finalPlan = NULL;
CustomScan *customScan = makeNode(CustomScan); CustomScan *customScan = makeNode(CustomScan);
Node *distributedPlanData = NULL;
MultiExecutorType executorType = MULTI_EXECUTOR_INVALID_FIRST; MultiExecutorType executorType = MULTI_EXECUTOR_INVALID_FIRST;
if (!distributedPlan->planningError) if (!distributedPlan->planningError)
@ -1092,7 +1081,7 @@ FinalizePlan(PlannedStmt *localPlan, DistributedPlan *distributedPlan)
distributedPlan->relationIdList = localPlan->relationOids; distributedPlan->relationIdList = localPlan->relationOids;
distributedPlan->queryId = localPlan->queryId; distributedPlan->queryId = localPlan->queryId;
distributedPlanData = (Node *) distributedPlan; Node *distributedPlanData = (Node *) distributedPlan;
customScan->custom_private = list_make1(distributedPlanData); customScan->custom_private = list_make1(distributedPlanData);
customScan->flags = CUSTOMPATH_SUPPORT_BACKWARD_SCAN; customScan->flags = CUSTOMPATH_SUPPORT_BACKWARD_SCAN;
@ -1119,9 +1108,7 @@ static PlannedStmt *
FinalizeNonRouterPlan(PlannedStmt *localPlan, DistributedPlan *distributedPlan, FinalizeNonRouterPlan(PlannedStmt *localPlan, DistributedPlan *distributedPlan,
CustomScan *customScan) CustomScan *customScan)
{ {
PlannedStmt *finalPlan = NULL; PlannedStmt *finalPlan = MasterNodeSelectPlan(distributedPlan, customScan);
finalPlan = MasterNodeSelectPlan(distributedPlan, customScan);
finalPlan->queryId = localPlan->queryId; finalPlan->queryId = localPlan->queryId;
finalPlan->utilityStmt = localPlan->utilityStmt; finalPlan->utilityStmt = localPlan->utilityStmt;
@ -1141,8 +1128,6 @@ FinalizeNonRouterPlan(PlannedStmt *localPlan, DistributedPlan *distributedPlan,
static PlannedStmt * static PlannedStmt *
FinalizeRouterPlan(PlannedStmt *localPlan, CustomScan *customScan) FinalizeRouterPlan(PlannedStmt *localPlan, CustomScan *customScan)
{ {
PlannedStmt *routerPlan = NULL;
RangeTblEntry *remoteScanRangeTableEntry = NULL;
ListCell *targetEntryCell = NULL; ListCell *targetEntryCell = NULL;
List *targetList = NIL; List *targetList = NIL;
List *columnNameList = NIL; List *columnNameList = NIL;
@ -1154,9 +1139,6 @@ FinalizeRouterPlan(PlannedStmt *localPlan, CustomScan *customScan)
foreach(targetEntryCell, localPlan->planTree->targetlist) foreach(targetEntryCell, localPlan->planTree->targetlist)
{ {
TargetEntry *targetEntry = lfirst(targetEntryCell); TargetEntry *targetEntry = lfirst(targetEntryCell);
TargetEntry *newTargetEntry = NULL;
Var *newVar = NULL;
Value *columnName = NULL;
Assert(IsA(targetEntry, TargetEntry)); Assert(IsA(targetEntry, TargetEntry));
@ -1171,7 +1153,7 @@ FinalizeRouterPlan(PlannedStmt *localPlan, CustomScan *customScan)
} }
/* build target entry pointing to remote scan range table entry */ /* build target entry pointing to remote scan range table entry */
newVar = makeVarFromTargetEntry(customScanRangeTableIndex, targetEntry); Var *newVar = makeVarFromTargetEntry(customScanRangeTableIndex, targetEntry);
if (newVar->vartype == RECORDOID) if (newVar->vartype == RECORDOID)
{ {
@ -1184,20 +1166,20 @@ FinalizeRouterPlan(PlannedStmt *localPlan, CustomScan *customScan)
newVar->vartypmod = BlessRecordExpression(targetEntry->expr); newVar->vartypmod = BlessRecordExpression(targetEntry->expr);
} }
newTargetEntry = flatCopyTargetEntry(targetEntry); TargetEntry *newTargetEntry = flatCopyTargetEntry(targetEntry);
newTargetEntry->expr = (Expr *) newVar; newTargetEntry->expr = (Expr *) newVar;
targetList = lappend(targetList, newTargetEntry); targetList = lappend(targetList, newTargetEntry);
columnName = makeString(targetEntry->resname); Value *columnName = makeString(targetEntry->resname);
columnNameList = lappend(columnNameList, columnName); columnNameList = lappend(columnNameList, columnName);
} }
customScan->scan.plan.targetlist = targetList; customScan->scan.plan.targetlist = targetList;
routerPlan = makeNode(PlannedStmt); PlannedStmt *routerPlan = makeNode(PlannedStmt);
routerPlan->planTree = (Plan *) customScan; routerPlan->planTree = (Plan *) customScan;
remoteScanRangeTableEntry = RemoteScanRangeTableEntry(columnNameList); RangeTblEntry *remoteScanRangeTableEntry = RemoteScanRangeTableEntry(columnNameList);
routerPlan->rtable = list_make1(remoteScanRangeTableEntry); routerPlan->rtable = list_make1(remoteScanRangeTableEntry);
/* add original range table list for access permission checks */ /* add original range table list for access permission checks */
@ -1236,10 +1218,9 @@ BlessRecordExpression(Expr *expr)
*/ */
Oid resultTypeId = InvalidOid; Oid resultTypeId = InvalidOid;
TupleDesc resultTupleDesc = NULL; TupleDesc resultTupleDesc = NULL;
TypeFuncClass typeClass;
/* get_expr_result_type blesses the tuple descriptor */ /* get_expr_result_type blesses the tuple descriptor */
typeClass = get_expr_result_type((Node *) expr, &resultTypeId, TypeFuncClass typeClass = get_expr_result_type((Node *) expr, &resultTypeId,
&resultTupleDesc); &resultTupleDesc);
if (typeClass == TYPEFUNC_COMPOSITE) if (typeClass == TYPEFUNC_COMPOSITE)
{ {
@ -1368,32 +1349,27 @@ multi_join_restriction_hook(PlannerInfo *root,
JoinType jointype, JoinType jointype,
JoinPathExtraData *extra) JoinPathExtraData *extra)
{ {
PlannerRestrictionContext *plannerRestrictionContext = NULL;
JoinRestrictionContext *joinRestrictionContext = NULL;
JoinRestriction *joinRestriction = NULL;
MemoryContext restrictionsMemoryContext = NULL;
MemoryContext oldMemoryContext = NULL;
List *restrictInfoList = NIL;
/* /*
* Use a memory context that's guaranteed to live long enough, could be * Use a memory context that's guaranteed to live long enough, could be
* called in a more shorted lived one (e.g. with GEQO). * called in a more shorted lived one (e.g. with GEQO).
*/ */
plannerRestrictionContext = CurrentPlannerRestrictionContext(); PlannerRestrictionContext *plannerRestrictionContext =
restrictionsMemoryContext = plannerRestrictionContext->memoryContext; CurrentPlannerRestrictionContext();
oldMemoryContext = MemoryContextSwitchTo(restrictionsMemoryContext); MemoryContext restrictionsMemoryContext = plannerRestrictionContext->memoryContext;
MemoryContext oldMemoryContext = MemoryContextSwitchTo(restrictionsMemoryContext);
/* /*
* We create a copy of restrictInfoList because it may be created in a memory * We create a copy of restrictInfoList because it may be created in a memory
* context which will be deleted when we still need it, thus we create a copy * context which will be deleted when we still need it, thus we create a copy
* of it in our memory context. * of it in our memory context.
*/ */
restrictInfoList = copyObject(extra->restrictlist); List *restrictInfoList = copyObject(extra->restrictlist);
joinRestrictionContext = plannerRestrictionContext->joinRestrictionContext; JoinRestrictionContext *joinRestrictionContext =
plannerRestrictionContext->joinRestrictionContext;
Assert(joinRestrictionContext != NULL); Assert(joinRestrictionContext != NULL);
joinRestriction = palloc0(sizeof(JoinRestriction)); JoinRestriction *joinRestriction = palloc0(sizeof(JoinRestriction));
joinRestriction->joinType = jointype; joinRestriction->joinType = jointype;
joinRestriction->joinRestrictInfoList = restrictInfoList; joinRestriction->joinRestrictInfoList = restrictInfoList;
joinRestriction->plannerInfo = root; joinRestriction->plannerInfo = root;
@ -1424,14 +1400,7 @@ void
multi_relation_restriction_hook(PlannerInfo *root, RelOptInfo *relOptInfo, multi_relation_restriction_hook(PlannerInfo *root, RelOptInfo *relOptInfo,
Index restrictionIndex, RangeTblEntry *rte) Index restrictionIndex, RangeTblEntry *rte)
{ {
PlannerRestrictionContext *plannerRestrictionContext = NULL;
RelationRestrictionContext *relationRestrictionContext = NULL;
MemoryContext restrictionsMemoryContext = NULL;
MemoryContext oldMemoryContext = NULL;
RelationRestriction *relationRestriction = NULL;
DistTableCacheEntry *cacheEntry = NULL; DistTableCacheEntry *cacheEntry = NULL;
bool distributedTable = false;
bool localTable = false;
AdjustReadIntermediateResultCost(rte, relOptInfo); AdjustReadIntermediateResultCost(rte, relOptInfo);
@ -1444,14 +1413,15 @@ multi_relation_restriction_hook(PlannerInfo *root, RelOptInfo *relOptInfo,
* Use a memory context that's guaranteed to live long enough, could be * Use a memory context that's guaranteed to live long enough, could be
* called in a more shorted lived one (e.g. with GEQO). * called in a more shorted lived one (e.g. with GEQO).
*/ */
plannerRestrictionContext = CurrentPlannerRestrictionContext(); PlannerRestrictionContext *plannerRestrictionContext =
restrictionsMemoryContext = plannerRestrictionContext->memoryContext; CurrentPlannerRestrictionContext();
oldMemoryContext = MemoryContextSwitchTo(restrictionsMemoryContext); MemoryContext restrictionsMemoryContext = plannerRestrictionContext->memoryContext;
MemoryContext oldMemoryContext = MemoryContextSwitchTo(restrictionsMemoryContext);
distributedTable = IsDistributedTable(rte->relid); bool distributedTable = IsDistributedTable(rte->relid);
localTable = !distributedTable; bool localTable = !distributedTable;
relationRestriction = palloc0(sizeof(RelationRestriction)); RelationRestriction *relationRestriction = palloc0(sizeof(RelationRestriction));
relationRestriction->index = restrictionIndex; relationRestriction->index = restrictionIndex;
relationRestriction->relationId = rte->relid; relationRestriction->relationId = rte->relid;
relationRestriction->rte = rte; relationRestriction->rte = rte;
@ -1463,7 +1433,8 @@ multi_relation_restriction_hook(PlannerInfo *root, RelOptInfo *relOptInfo,
/* see comments on GetVarFromAssignedParam() */ /* see comments on GetVarFromAssignedParam() */
relationRestriction->outerPlanParamsList = OuterPlanParamsList(root); relationRestriction->outerPlanParamsList = OuterPlanParamsList(root);
relationRestrictionContext = plannerRestrictionContext->relationRestrictionContext; RelationRestrictionContext *relationRestrictionContext =
plannerRestrictionContext->relationRestrictionContext;
relationRestrictionContext->hasDistributedRelation |= distributedTable; relationRestrictionContext->hasDistributedRelation |= distributedTable;
relationRestrictionContext->hasLocalRelation |= localTable; relationRestrictionContext->hasLocalRelation |= localTable;
@ -1644,9 +1615,8 @@ static List *
OuterPlanParamsList(PlannerInfo *root) OuterPlanParamsList(PlannerInfo *root)
{ {
List *planParamsList = NIL; List *planParamsList = NIL;
PlannerInfo *outerNodeRoot = NULL;
for (outerNodeRoot = root->parent_root; outerNodeRoot != NULL; for (PlannerInfo *outerNodeRoot = root->parent_root; outerNodeRoot != NULL;
outerNodeRoot = outerNodeRoot->parent_root) outerNodeRoot = outerNodeRoot->parent_root)
{ {
RootPlanParams *rootPlanParams = palloc0(sizeof(RootPlanParams)); RootPlanParams *rootPlanParams = palloc0(sizeof(RootPlanParams));
@ -1729,11 +1699,9 @@ CreateAndPushPlannerRestrictionContext(void)
static PlannerRestrictionContext * static PlannerRestrictionContext *
CurrentPlannerRestrictionContext(void) CurrentPlannerRestrictionContext(void)
{ {
PlannerRestrictionContext *plannerRestrictionContext = NULL;
Assert(plannerRestrictionContextList != NIL); Assert(plannerRestrictionContextList != NIL);
plannerRestrictionContext = PlannerRestrictionContext *plannerRestrictionContext =
(PlannerRestrictionContext *) linitial(plannerRestrictionContextList); (PlannerRestrictionContext *) linitial(plannerRestrictionContextList);
if (plannerRestrictionContext == NULL) if (plannerRestrictionContext == NULL)
@ -1804,7 +1772,6 @@ HasUnresolvedExternParamsWalker(Node *expression, ParamListInfo boundParams)
if (boundParams && paramId > 0 && paramId <= boundParams->numParams) if (boundParams && paramId > 0 && paramId <= boundParams->numParams)
{ {
ParamExternData *externParam = NULL; ParamExternData *externParam = NULL;
Oid paramType = InvalidOid;
/* give hook a chance in case parameter is dynamic */ /* give hook a chance in case parameter is dynamic */
if (boundParams->paramFetch != NULL) if (boundParams->paramFetch != NULL)
@ -1818,7 +1785,7 @@ HasUnresolvedExternParamsWalker(Node *expression, ParamListInfo boundParams)
externParam = &boundParams->params[paramId - 1]; externParam = &boundParams->params[paramId - 1];
} }
paramType = externParam->ptype; Oid paramType = externParam->ptype;
if (OidIsValid(paramType)) if (OidIsValid(paramType))
{ {
return false; return false;
@ -1890,7 +1857,6 @@ IsLocalReferenceTableJoin(Query *parse, List *rangeTableList)
foreach(rangeTableCell, rangeTableList) foreach(rangeTableCell, rangeTableList)
{ {
RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell); RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell);
DistTableCacheEntry *cacheEntry = NULL;
if (rangeTableEntry->rtekind == RTE_FUNCTION) if (rangeTableEntry->rtekind == RTE_FUNCTION)
{ {
@ -1909,7 +1875,8 @@ IsLocalReferenceTableJoin(Query *parse, List *rangeTableList)
continue; continue;
} }
cacheEntry = DistributedTableCacheEntry(rangeTableEntry->relid); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(
rangeTableEntry->relid);
if (cacheEntry->partitionMethod == DISTRIBUTE_BY_NONE) if (cacheEntry->partitionMethod == DISTRIBUTE_BY_NONE)
{ {
hasReferenceTable = true; hasReferenceTable = true;
@ -1931,14 +1898,12 @@ IsLocalReferenceTableJoin(Query *parse, List *rangeTableList)
static bool static bool
QueryIsNotSimpleSelect(Node *node) QueryIsNotSimpleSelect(Node *node)
{ {
Query *query = NULL;
if (!IsA(node, Query)) if (!IsA(node, Query))
{ {
return false; return false;
} }
query = (Query *) node; Query *query = (Query *) node;
return (query->commandType != CMD_SELECT) || (query->rowMarks != NIL); return (query->commandType != CMD_SELECT) || (query->rowMarks != NIL);
} }
@ -1950,14 +1915,6 @@ QueryIsNotSimpleSelect(Node *node)
static bool static bool
UpdateReferenceTablesWithShard(Node *node, void *context) UpdateReferenceTablesWithShard(Node *node, void *context)
{ {
RangeTblEntry *newRte = NULL;
uint64 shardId = INVALID_SHARD_ID;
Oid relationId = InvalidOid;
Oid schemaId = InvalidOid;
char *relationName = NULL;
DistTableCacheEntry *cacheEntry = NULL;
ShardInterval *shardInterval = NULL;
if (node == NULL) if (node == NULL)
{ {
return false; return false;
@ -1976,32 +1933,32 @@ UpdateReferenceTablesWithShard(Node *node, void *context)
NULL); NULL);
} }
newRte = (RangeTblEntry *) node; RangeTblEntry *newRte = (RangeTblEntry *) node;
if (newRte->rtekind != RTE_RELATION) if (newRte->rtekind != RTE_RELATION)
{ {
return false; return false;
} }
relationId = newRte->relid; Oid relationId = newRte->relid;
if (!IsDistributedTable(relationId)) if (!IsDistributedTable(relationId))
{ {
return false; return false;
} }
cacheEntry = DistributedTableCacheEntry(relationId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId);
if (cacheEntry->partitionMethod != DISTRIBUTE_BY_NONE) if (cacheEntry->partitionMethod != DISTRIBUTE_BY_NONE)
{ {
return false; return false;
} }
shardInterval = cacheEntry->sortedShardIntervalArray[0]; ShardInterval *shardInterval = cacheEntry->sortedShardIntervalArray[0];
shardId = shardInterval->shardId; uint64 shardId = shardInterval->shardId;
relationName = get_rel_name(relationId); char *relationName = get_rel_name(relationId);
AppendShardIdToName(&relationName, shardId); AppendShardIdToName(&relationName, shardId);
schemaId = get_rel_namespace(relationId); Oid schemaId = get_rel_namespace(relationId);
newRte->relid = get_relname_relid(relationName, schemaId); newRte->relid = get_relname_relid(relationName, schemaId);
/* /*

View File

@ -45,28 +45,21 @@ ExtendedOpNodeProperties
BuildExtendedOpNodeProperties(MultiExtendedOp *extendedOpNode) BuildExtendedOpNodeProperties(MultiExtendedOp *extendedOpNode)
{ {
ExtendedOpNodeProperties extendedOpNodeProperties; ExtendedOpNodeProperties extendedOpNodeProperties;
List *tableNodeList = NIL;
List *targetList = NIL;
Node *havingQual = NULL;
bool groupedByDisjointPartitionColumn = false;
bool repartitionSubquery = false;
bool hasNonPartitionColumnDistinctAgg = false;
bool pullDistinctColumns = false;
bool pushDownWindowFunctions = false;
tableNodeList = FindNodesOfType((MultiNode *) extendedOpNode, T_MultiTable); List *tableNodeList = FindNodesOfType((MultiNode *) extendedOpNode, T_MultiTable);
groupedByDisjointPartitionColumn = GroupedByDisjointPartitionColumn(tableNodeList, bool groupedByDisjointPartitionColumn = GroupedByDisjointPartitionColumn(
tableNodeList,
extendedOpNode); extendedOpNode);
repartitionSubquery = ExtendedOpNodeContainsRepartitionSubquery(extendedOpNode); bool repartitionSubquery = ExtendedOpNodeContainsRepartitionSubquery(extendedOpNode);
targetList = extendedOpNode->targetList; List *targetList = extendedOpNode->targetList;
havingQual = extendedOpNode->havingQual; Node *havingQual = extendedOpNode->havingQual;
hasNonPartitionColumnDistinctAgg = bool hasNonPartitionColumnDistinctAgg =
HasNonPartitionColumnDistinctAgg(targetList, havingQual, tableNodeList); HasNonPartitionColumnDistinctAgg(targetList, havingQual, tableNodeList);
pullDistinctColumns = bool pullDistinctColumns =
ShouldPullDistinctColumn(repartitionSubquery, groupedByDisjointPartitionColumn, ShouldPullDistinctColumn(repartitionSubquery, groupedByDisjointPartitionColumn,
hasNonPartitionColumnDistinctAgg); hasNonPartitionColumnDistinctAgg);
@ -75,7 +68,7 @@ BuildExtendedOpNodeProperties(MultiExtendedOp *extendedOpNode)
* using hasWindowFuncs is safe for now. However, this should be fixed * using hasWindowFuncs is safe for now. However, this should be fixed
* when we support pull-to-master window functions. * when we support pull-to-master window functions.
*/ */
pushDownWindowFunctions = extendedOpNode->hasWindowFuncs; bool pushDownWindowFunctions = extendedOpNode->hasWindowFuncs;
extendedOpNodeProperties.groupedByDisjointPartitionColumn = extendedOpNodeProperties.groupedByDisjointPartitionColumn =
groupedByDisjointPartitionColumn; groupedByDisjointPartitionColumn;
@ -103,14 +96,13 @@ GroupedByDisjointPartitionColumn(List *tableNodeList, MultiExtendedOp *opNode)
{ {
MultiTable *tableNode = (MultiTable *) lfirst(tableNodeCell); MultiTable *tableNode = (MultiTable *) lfirst(tableNodeCell);
Oid relationId = tableNode->relationId; Oid relationId = tableNode->relationId;
char partitionMethod = 0;
if (relationId == SUBQUERY_RELATION_ID || !IsDistributedTable(relationId)) if (relationId == SUBQUERY_RELATION_ID || !IsDistributedTable(relationId))
{ {
continue; continue;
} }
partitionMethod = PartitionMethod(relationId); char partitionMethod = PartitionMethod(relationId);
if (partitionMethod != DISTRIBUTE_BY_RANGE && if (partitionMethod != DISTRIBUTE_BY_RANGE &&
partitionMethod != DISTRIBUTE_BY_HASH) partitionMethod != DISTRIBUTE_BY_HASH)
{ {
@ -173,12 +165,8 @@ HasNonPartitionColumnDistinctAgg(List *targetEntryList, Node *havingQual,
foreach(aggregateCheckCell, aggregateCheckList) foreach(aggregateCheckCell, aggregateCheckList)
{ {
Node *targetNode = lfirst(aggregateCheckCell); Node *targetNode = lfirst(aggregateCheckCell);
Aggref *targetAgg = NULL;
List *varList = NIL;
ListCell *varCell = NULL; ListCell *varCell = NULL;
bool isPartitionColumn = false; bool isPartitionColumn = false;
TargetEntry *firstTargetEntry = NULL;
Node *firstTargetExprNode = NULL;
if (IsA(targetNode, Var)) if (IsA(targetNode, Var))
{ {
@ -186,7 +174,7 @@ HasNonPartitionColumnDistinctAgg(List *targetEntryList, Node *havingQual,
} }
Assert(IsA(targetNode, Aggref)); Assert(IsA(targetNode, Aggref));
targetAgg = (Aggref *) targetNode; Aggref *targetAgg = (Aggref *) targetNode;
if (targetAgg->aggdistinct == NIL) if (targetAgg->aggdistinct == NIL)
{ {
continue; continue;
@ -201,14 +189,15 @@ HasNonPartitionColumnDistinctAgg(List *targetEntryList, Node *havingQual,
return true; return true;
} }
firstTargetEntry = linitial_node(TargetEntry, targetAgg->args); TargetEntry *firstTargetEntry = linitial_node(TargetEntry, targetAgg->args);
firstTargetExprNode = strip_implicit_coercions((Node *) firstTargetEntry->expr); Node *firstTargetExprNode = strip_implicit_coercions(
(Node *) firstTargetEntry->expr);
if (!IsA(firstTargetExprNode, Var)) if (!IsA(firstTargetExprNode, Var))
{ {
return true; return true;
} }
varList = pull_var_clause_default((Node *) targetAgg->args); List *varList = pull_var_clause_default((Node *) targetAgg->args);
foreach(varCell, varList) foreach(varCell, varList)
{ {
Node *targetVar = (Node *) lfirst(varCell); Node *targetVar = (Node *) lfirst(varCell);

View File

@ -71,8 +71,6 @@ static bool DistKeyInSimpleOpExpression(Expr *clause, Var *distColumn);
PlannedStmt * PlannedStmt *
FastPathPlanner(Query *originalQuery, Query *parse, ParamListInfo boundParams) FastPathPlanner(Query *originalQuery, Query *parse, ParamListInfo boundParams)
{ {
PlannedStmt *result = NULL;
/* /*
* To support prepared statements for fast-path queries, we resolve the * To support prepared statements for fast-path queries, we resolve the
* external parameters at this point. Note that this is normally done by * external parameters at this point. Note that this is normally done by
@ -98,7 +96,7 @@ FastPathPlanner(Query *originalQuery, Query *parse, ParamListInfo boundParams)
(Node *) eval_const_expressions(NULL, (Node *) parse->jointree->quals); (Node *) eval_const_expressions(NULL, (Node *) parse->jointree->quals);
result = GeneratePlaceHolderPlannedStmt(originalQuery); PlannedStmt *result = GeneratePlaceHolderPlannedStmt(originalQuery);
return result; return result;
} }
@ -122,7 +120,6 @@ GeneratePlaceHolderPlannedStmt(Query *parse)
PlannedStmt *result = makeNode(PlannedStmt); PlannedStmt *result = makeNode(PlannedStmt);
SeqScan *seqScanNode = makeNode(SeqScan); SeqScan *seqScanNode = makeNode(SeqScan);
Plan *plan = &seqScanNode->plan; Plan *plan = &seqScanNode->plan;
Oid relationId = InvalidOid;
AssertArg(FastPathRouterQuery(parse)); AssertArg(FastPathRouterQuery(parse));
@ -143,7 +140,7 @@ GeneratePlaceHolderPlannedStmt(Query *parse)
result->rtable = copyObject(parse->rtable); result->rtable = copyObject(parse->rtable);
result->planTree = (Plan *) plan; result->planTree = (Plan *) plan;
relationId = ExtractFirstDistributedTableId(parse); Oid relationId = ExtractFirstDistributedTableId(parse);
result->relationOids = list_make1_oid(relationId); result->relationOids = list_make1_oid(relationId);
return result; return result;
@ -166,12 +163,8 @@ GeneratePlaceHolderPlannedStmt(Query *parse)
bool bool
FastPathRouterQuery(Query *query) FastPathRouterQuery(Query *query)
{ {
RangeTblEntry *rangeTableEntry = NULL;
FromExpr *joinTree = query->jointree; FromExpr *joinTree = query->jointree;
Node *quals = NULL; Node *quals = NULL;
Oid distributedTableId = InvalidOid;
Var *distributionKey = NULL;
DistTableCacheEntry *cacheEntry = NULL;
if (!EnableFastPathRouterPlanner) if (!EnableFastPathRouterPlanner)
{ {
@ -201,15 +194,15 @@ FastPathRouterQuery(Query *query)
return false; return false;
} }
rangeTableEntry = (RangeTblEntry *) linitial(query->rtable); RangeTblEntry *rangeTableEntry = (RangeTblEntry *) linitial(query->rtable);
if (rangeTableEntry->rtekind != RTE_RELATION) if (rangeTableEntry->rtekind != RTE_RELATION)
{ {
return false; return false;
} }
/* we don't want to deal with append/range distributed tables */ /* we don't want to deal with append/range distributed tables */
distributedTableId = rangeTableEntry->relid; Oid distributedTableId = rangeTableEntry->relid;
cacheEntry = DistributedTableCacheEntry(distributedTableId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId);
if (!(cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH || if (!(cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH ||
cacheEntry->partitionMethod == DISTRIBUTE_BY_NONE)) cacheEntry->partitionMethod == DISTRIBUTE_BY_NONE))
{ {
@ -224,7 +217,7 @@ FastPathRouterQuery(Query *query)
} }
/* if that's a reference table, we don't need to check anything further */ /* if that's a reference table, we don't need to check anything further */
distributionKey = PartitionColumn(distributedTableId, 1); Var *distributionKey = PartitionColumn(distributedTableId, 1);
if (!distributionKey) if (!distributionKey)
{ {
return true; return true;
@ -269,11 +262,10 @@ static bool
ColumnAppearsMultipleTimes(Node *quals, Var *distributionKey) ColumnAppearsMultipleTimes(Node *quals, Var *distributionKey)
{ {
ListCell *varClauseCell = NULL; ListCell *varClauseCell = NULL;
List *varClauseList = NIL;
int partitionColumnReferenceCount = 0; int partitionColumnReferenceCount = 0;
/* make sure partition column is used only once in the quals */ /* make sure partition column is used only once in the quals */
varClauseList = pull_var_clause_default(quals); List *varClauseList = pull_var_clause_default(quals);
foreach(varClauseCell, varClauseList) foreach(varClauseCell, varClauseList)
{ {
Var *column = (Var *) lfirst(varClauseCell); Var *column = (Var *) lfirst(varClauseCell);

View File

@ -98,7 +98,6 @@ contain_param_walker(Node *node, void *context)
DistributedPlan * DistributedPlan *
TryToDelegateFunctionCall(Query *query, bool *hasExternParam) TryToDelegateFunctionCall(Query *query, bool *hasExternParam)
{ {
FromExpr *joinTree = NULL;
List *targetList = NIL; List *targetList = NIL;
TargetEntry *targetEntry = NULL; TargetEntry *targetEntry = NULL;
FuncExpr *funcExpr = NULL; FuncExpr *funcExpr = NULL;
@ -116,7 +115,6 @@ TryToDelegateFunctionCall(Query *query, bool *hasExternParam)
Task *task = NULL; Task *task = NULL;
Job *job = NULL; Job *job = NULL;
DistributedPlan *distributedPlan = NULL; DistributedPlan *distributedPlan = NULL;
int32 groupId = 0;
struct ParamWalkerContext walkerParamContext = { 0 }; struct ParamWalkerContext walkerParamContext = { 0 };
/* set hasExternParam now in case of early exit */ /* set hasExternParam now in case of early exit */
@ -128,7 +126,7 @@ TryToDelegateFunctionCall(Query *query, bool *hasExternParam)
return NULL; return NULL;
} }
groupId = GetLocalGroupId(); int32 groupId = GetLocalGroupId();
if (groupId != 0 || groupId == GROUP_ID_UPGRADING) if (groupId != 0 || groupId == GROUP_ID_UPGRADING)
{ {
/* do not delegate from workers, or while upgrading */ /* do not delegate from workers, or while upgrading */
@ -147,7 +145,7 @@ TryToDelegateFunctionCall(Query *query, bool *hasExternParam)
return NULL; return NULL;
} }
joinTree = query->jointree; FromExpr *joinTree = query->jointree;
if (joinTree == NULL) if (joinTree == NULL)
{ {
/* no join tree (mostly here to be defensive) */ /* no join tree (mostly here to be defensive) */

View File

@ -136,9 +136,6 @@ static bool
CheckInsertSelectQuery(Query *query) CheckInsertSelectQuery(Query *query)
{ {
CmdType commandType = query->commandType; CmdType commandType = query->commandType;
List *fromList = NULL;
RangeTblRef *rangeTableReference = NULL;
RangeTblEntry *subqueryRte = NULL;
if (commandType != CMD_INSERT) if (commandType != CMD_INSERT)
{ {
@ -150,19 +147,19 @@ CheckInsertSelectQuery(Query *query)
return false; return false;
} }
fromList = query->jointree->fromlist; List *fromList = query->jointree->fromlist;
if (list_length(fromList) != 1) if (list_length(fromList) != 1)
{ {
return false; return false;
} }
rangeTableReference = linitial(fromList); RangeTblRef *rangeTableReference = linitial(fromList);
if (!IsA(rangeTableReference, RangeTblRef)) if (!IsA(rangeTableReference, RangeTblRef))
{ {
return false; return false;
} }
subqueryRte = rt_fetch(rangeTableReference->rtindex, query->rtable); RangeTblEntry *subqueryRte = rt_fetch(rangeTableReference->rtindex, query->rtable);
if (subqueryRte->rtekind != RTE_SUBQUERY) if (subqueryRte->rtekind != RTE_SUBQUERY)
{ {
return false; return false;
@ -185,17 +182,14 @@ DistributedPlan *
CreateInsertSelectPlan(uint64 planId, Query *originalQuery, CreateInsertSelectPlan(uint64 planId, Query *originalQuery,
PlannerRestrictionContext *plannerRestrictionContext) PlannerRestrictionContext *plannerRestrictionContext)
{ {
DistributedPlan *distributedPlan = NULL; DeferredErrorMessage *deferredError = ErrorIfOnConflictNotSupported(originalQuery);
DeferredErrorMessage *deferredError = NULL;
deferredError = ErrorIfOnConflictNotSupported(originalQuery);
if (deferredError != NULL) if (deferredError != NULL)
{ {
/* raising the error as there is no possible solution for the unsupported on conflict statements */ /* raising the error as there is no possible solution for the unsupported on conflict statements */
RaiseDeferredError(deferredError, ERROR); RaiseDeferredError(deferredError, ERROR);
} }
distributedPlan = CreateDistributedInsertSelectPlan(originalQuery, DistributedPlan *distributedPlan = CreateDistributedInsertSelectPlan(originalQuery,
plannerRestrictionContext); plannerRestrictionContext);
if (distributedPlan->planningError != NULL) if (distributedPlan->planningError != NULL)
@ -220,10 +214,8 @@ static DistributedPlan *
CreateDistributedInsertSelectPlan(Query *originalQuery, CreateDistributedInsertSelectPlan(Query *originalQuery,
PlannerRestrictionContext *plannerRestrictionContext) PlannerRestrictionContext *plannerRestrictionContext)
{ {
int shardOffset = 0;
List *sqlTaskList = NIL; List *sqlTaskList = NIL;
uint32 taskIdIndex = 1; /* 0 is reserved for invalid taskId */ uint32 taskIdIndex = 1; /* 0 is reserved for invalid taskId */
Job *workerJob = NULL;
uint64 jobId = INVALID_JOB_ID; uint64 jobId = INVALID_JOB_ID;
DistributedPlan *distributedPlan = CitusMakeNode(DistributedPlan); DistributedPlan *distributedPlan = CitusMakeNode(DistributedPlan);
RangeTblEntry *insertRte = ExtractResultRelationRTE(originalQuery); RangeTblEntry *insertRte = ExtractResultRelationRTE(originalQuery);
@ -234,7 +226,6 @@ CreateDistributedInsertSelectPlan(Query *originalQuery,
RelationRestrictionContext *relationRestrictionContext = RelationRestrictionContext *relationRestrictionContext =
plannerRestrictionContext->relationRestrictionContext; plannerRestrictionContext->relationRestrictionContext;
bool allReferenceTables = relationRestrictionContext->allReferenceTables; bool allReferenceTables = relationRestrictionContext->allReferenceTables;
bool allDistributionKeysInQueryAreEqual = false;
distributedPlan->modLevel = RowModifyLevelForQuery(originalQuery); distributedPlan->modLevel = RowModifyLevelForQuery(originalQuery);
@ -251,7 +242,7 @@ CreateDistributedInsertSelectPlan(Query *originalQuery,
return distributedPlan; return distributedPlan;
} }
allDistributionKeysInQueryAreEqual = bool allDistributionKeysInQueryAreEqual =
AllDistributionKeysInQueryAreEqual(originalQuery, plannerRestrictionContext); AllDistributionKeysInQueryAreEqual(originalQuery, plannerRestrictionContext);
/* /*
@ -263,13 +254,13 @@ CreateDistributedInsertSelectPlan(Query *originalQuery,
* the current shard boundaries. Finally, perform the normal shard pruning to * the current shard boundaries. Finally, perform the normal shard pruning to
* decide on whether to push the query to the current shard or not. * decide on whether to push the query to the current shard or not.
*/ */
for (shardOffset = 0; shardOffset < shardCount; shardOffset++) for (int shardOffset = 0; shardOffset < shardCount; shardOffset++)
{ {
ShardInterval *targetShardInterval = ShardInterval *targetShardInterval =
targetCacheEntry->sortedShardIntervalArray[shardOffset]; targetCacheEntry->sortedShardIntervalArray[shardOffset];
Task *modifyTask = NULL;
modifyTask = RouterModifyTaskForShardInterval(originalQuery, targetShardInterval, Task *modifyTask = RouterModifyTaskForShardInterval(originalQuery,
targetShardInterval,
plannerRestrictionContext, plannerRestrictionContext,
taskIdIndex, taskIdIndex,
allDistributionKeysInQueryAreEqual); allDistributionKeysInQueryAreEqual);
@ -289,7 +280,7 @@ CreateDistributedInsertSelectPlan(Query *originalQuery,
} }
/* Create the worker job */ /* Create the worker job */
workerJob = CitusMakeNode(Job); Job *workerJob = CitusMakeNode(Job);
workerJob->taskList = sqlTaskList; workerJob->taskList = sqlTaskList;
workerJob->subqueryPushdown = false; workerJob->subqueryPushdown = false;
workerJob->dependedJobList = NIL; workerJob->dependedJobList = NIL;
@ -321,17 +312,15 @@ static DeferredErrorMessage *
DistributedInsertSelectSupported(Query *queryTree, RangeTblEntry *insertRte, DistributedInsertSelectSupported(Query *queryTree, RangeTblEntry *insertRte,
RangeTblEntry *subqueryRte, bool allReferenceTables) RangeTblEntry *subqueryRte, bool allReferenceTables)
{ {
Query *subquery = NULL;
Oid selectPartitionColumnTableId = InvalidOid; Oid selectPartitionColumnTableId = InvalidOid;
Oid targetRelationId = insertRte->relid; Oid targetRelationId = insertRte->relid;
char targetPartitionMethod = PartitionMethod(targetRelationId); char targetPartitionMethod = PartitionMethod(targetRelationId);
ListCell *rangeTableCell = NULL; ListCell *rangeTableCell = NULL;
DeferredErrorMessage *error = NULL;
/* we only do this check for INSERT ... SELECT queries */ /* we only do this check for INSERT ... SELECT queries */
AssertArg(InsertSelectIntoDistributedTable(queryTree)); AssertArg(InsertSelectIntoDistributedTable(queryTree));
subquery = subqueryRte->subquery; Query *subquery = subqueryRte->subquery;
if (!NeedsDistributedPlanning(subquery)) if (!NeedsDistributedPlanning(subquery))
{ {
@ -363,7 +352,7 @@ DistributedInsertSelectSupported(Query *queryTree, RangeTblEntry *insertRte,
} }
/* we don't support LIMIT, OFFSET and WINDOW functions */ /* we don't support LIMIT, OFFSET and WINDOW functions */
error = MultiTaskRouterSelectQuerySupported(subquery); DeferredErrorMessage *error = MultiTaskRouterSelectQuerySupported(subquery);
if (error) if (error)
{ {
return error; return error;
@ -442,20 +431,15 @@ RouterModifyTaskForShardInterval(Query *originalQuery, ShardInterval *shardInter
StringInfo queryString = makeStringInfo(); StringInfo queryString = makeStringInfo();
ListCell *restrictionCell = NULL; ListCell *restrictionCell = NULL;
Task *modifyTask = NULL;
List *selectPlacementList = NIL; List *selectPlacementList = NIL;
uint64 selectAnchorShardId = INVALID_SHARD_ID; uint64 selectAnchorShardId = INVALID_SHARD_ID;
List *relationShardList = NIL; List *relationShardList = NIL;
List *prunedShardIntervalListList = NIL; List *prunedShardIntervalListList = NIL;
uint64 jobId = INVALID_JOB_ID; uint64 jobId = INVALID_JOB_ID;
List *insertShardPlacementList = NULL;
List *intersectedPlacementList = NULL;
bool replacePrunedQueryWithDummy = false;
bool allReferenceTables = bool allReferenceTables =
plannerRestrictionContext->relationRestrictionContext->allReferenceTables; plannerRestrictionContext->relationRestrictionContext->allReferenceTables;
List *shardOpExpressions = NIL; List *shardOpExpressions = NIL;
RestrictInfo *shardRestrictionList = NULL; RestrictInfo *shardRestrictionList = NULL;
DeferredErrorMessage *planningError = NULL;
bool multiShardModifyQuery = false; bool multiShardModifyQuery = false;
List *relationRestrictionList = NIL; List *relationRestrictionList = NIL;
@ -517,16 +501,19 @@ RouterModifyTaskForShardInterval(Query *originalQuery, ShardInterval *shardInter
} }
/* mark that we don't want the router planner to generate dummy hosts/queries */ /* mark that we don't want the router planner to generate dummy hosts/queries */
replacePrunedQueryWithDummy = false; bool replacePrunedQueryWithDummy = false;
/* /*
* Use router planner to decide on whether we can push down the query or not. * Use router planner to decide on whether we can push down the query or not.
* If we can, we also rely on the side-effects that all RTEs have been updated * If we can, we also rely on the side-effects that all RTEs have been updated
* to point to the relevant nodes and selectPlacementList is determined. * to point to the relevant nodes and selectPlacementList is determined.
*/ */
planningError = PlanRouterQuery(copiedSubquery, copyOfPlannerRestrictionContext, DeferredErrorMessage *planningError = PlanRouterQuery(copiedSubquery,
&selectPlacementList, &selectAnchorShardId, copyOfPlannerRestrictionContext,
&relationShardList, &prunedShardIntervalListList, &selectPlacementList,
&selectAnchorShardId,
&relationShardList,
&prunedShardIntervalListList,
replacePrunedQueryWithDummy, replacePrunedQueryWithDummy,
&multiShardModifyQuery, NULL); &multiShardModifyQuery, NULL);
@ -552,8 +539,8 @@ RouterModifyTaskForShardInterval(Query *originalQuery, ShardInterval *shardInter
} }
/* get the placements for insert target shard and its intersection with select */ /* get the placements for insert target shard and its intersection with select */
insertShardPlacementList = FinalizedShardPlacementList(shardId); List *insertShardPlacementList = FinalizedShardPlacementList(shardId);
intersectedPlacementList = IntersectPlacementList(insertShardPlacementList, List *intersectedPlacementList = IntersectPlacementList(insertShardPlacementList,
selectPlacementList); selectPlacementList);
/* /*
@ -586,7 +573,8 @@ RouterModifyTaskForShardInterval(Query *originalQuery, ShardInterval *shardInter
ereport(DEBUG2, (errmsg("distributed statement: %s", ereport(DEBUG2, (errmsg("distributed statement: %s",
ApplyLogRedaction(queryString->data)))); ApplyLogRedaction(queryString->data))));
modifyTask = CreateBasicTask(jobId, taskIdIndex, MODIFY_TASK, queryString->data); Task *modifyTask = CreateBasicTask(jobId, taskIdIndex, MODIFY_TASK,
queryString->data);
modifyTask->dependedTaskList = NULL; modifyTask->dependedTaskList = NULL;
modifyTask->anchorShardId = shardId; modifyTask->anchorShardId = shardId;
modifyTask->taskPlacementList = insertShardPlacementList; modifyTask->taskPlacementList = insertShardPlacementList;
@ -612,21 +600,18 @@ Query *
ReorderInsertSelectTargetLists(Query *originalQuery, RangeTblEntry *insertRte, ReorderInsertSelectTargetLists(Query *originalQuery, RangeTblEntry *insertRte,
RangeTblEntry *subqueryRte) RangeTblEntry *subqueryRte)
{ {
Query *subquery = NULL;
ListCell *insertTargetEntryCell; ListCell *insertTargetEntryCell;
List *newSubqueryTargetlist = NIL; List *newSubqueryTargetlist = NIL;
List *newInsertTargetlist = NIL; List *newInsertTargetlist = NIL;
int resno = 1; int resno = 1;
Index insertTableId = 1; Index insertTableId = 1;
Oid insertRelationId = InvalidOid;
int subqueryTargetLength = 0;
int targetEntryIndex = 0; int targetEntryIndex = 0;
AssertArg(InsertSelectIntoDistributedTable(originalQuery)); AssertArg(InsertSelectIntoDistributedTable(originalQuery));
subquery = subqueryRte->subquery; Query *subquery = subqueryRte->subquery;
insertRelationId = insertRte->relid; Oid insertRelationId = insertRte->relid;
/* /*
* We implement the following algorithm for the reoderding: * We implement the following algorithm for the reoderding:
@ -642,11 +627,7 @@ ReorderInsertSelectTargetLists(Query *originalQuery, RangeTblEntry *insertRte,
foreach(insertTargetEntryCell, originalQuery->targetList) foreach(insertTargetEntryCell, originalQuery->targetList)
{ {
TargetEntry *oldInsertTargetEntry = lfirst(insertTargetEntryCell); TargetEntry *oldInsertTargetEntry = lfirst(insertTargetEntryCell);
TargetEntry *newInsertTargetEntry = NULL;
Var *newInsertVar = NULL;
TargetEntry *newSubqueryTargetEntry = NULL; TargetEntry *newSubqueryTargetEntry = NULL;
List *targetVarList = NULL;
int targetVarCount = 0;
AttrNumber originalAttrNo = get_attnum(insertRelationId, AttrNumber originalAttrNo = get_attnum(insertRelationId,
oldInsertTargetEntry->resname); oldInsertTargetEntry->resname);
@ -665,10 +646,10 @@ ReorderInsertSelectTargetLists(Query *originalQuery, RangeTblEntry *insertRte,
* It is safe to pull Var clause and ignore the coercions since that * It is safe to pull Var clause and ignore the coercions since that
* are already going to be added on the workers implicitly. * are already going to be added on the workers implicitly.
*/ */
targetVarList = pull_var_clause((Node *) oldInsertTargetEntry->expr, List *targetVarList = pull_var_clause((Node *) oldInsertTargetEntry->expr,
PVC_RECURSE_AGGREGATES); PVC_RECURSE_AGGREGATES);
targetVarCount = list_length(targetVarList); int targetVarCount = list_length(targetVarList);
/* a single INSERT target entry cannot have more than one Var */ /* a single INSERT target entry cannot have more than one Var */
Assert(targetVarCount <= 1); Assert(targetVarCount <= 1);
@ -702,12 +683,13 @@ ReorderInsertSelectTargetLists(Query *originalQuery, RangeTblEntry *insertRte,
*/ */
Assert(!newSubqueryTargetEntry->resjunk); Assert(!newSubqueryTargetEntry->resjunk);
newInsertVar = makeVar(insertTableId, originalAttrNo, Var *newInsertVar = makeVar(insertTableId, originalAttrNo,
exprType((Node *) newSubqueryTargetEntry->expr), exprType((Node *) newSubqueryTargetEntry->expr),
exprTypmod((Node *) newSubqueryTargetEntry->expr), exprTypmod((Node *) newSubqueryTargetEntry->expr),
exprCollation((Node *) newSubqueryTargetEntry->expr), exprCollation((Node *) newSubqueryTargetEntry->expr),
0); 0);
newInsertTargetEntry = makeTargetEntry((Expr *) newInsertVar, originalAttrNo, TargetEntry *newInsertTargetEntry = makeTargetEntry((Expr *) newInsertVar,
originalAttrNo,
oldInsertTargetEntry->resname, oldInsertTargetEntry->resname,
oldInsertTargetEntry->resjunk); oldInsertTargetEntry->resjunk);
@ -719,12 +701,11 @@ ReorderInsertSelectTargetLists(Query *originalQuery, RangeTblEntry *insertRte,
* if there are any remaining target list entries (i.e., GROUP BY column not on the * if there are any remaining target list entries (i.e., GROUP BY column not on the
* target list of subquery), update the remaining resnos. * target list of subquery), update the remaining resnos.
*/ */
subqueryTargetLength = list_length(subquery->targetList); int subqueryTargetLength = list_length(subquery->targetList);
for (; targetEntryIndex < subqueryTargetLength; ++targetEntryIndex) for (; targetEntryIndex < subqueryTargetLength; ++targetEntryIndex)
{ {
TargetEntry *oldSubqueryTle = list_nth(subquery->targetList, TargetEntry *oldSubqueryTle = list_nth(subquery->targetList,
targetEntryIndex); targetEntryIndex);
TargetEntry *newSubqueryTargetEntry = NULL;
/* /*
* Skip non-junk entries since we've already processed them above and this * Skip non-junk entries since we've already processed them above and this
@ -735,7 +716,7 @@ ReorderInsertSelectTargetLists(Query *originalQuery, RangeTblEntry *insertRte,
continue; continue;
} }
newSubqueryTargetEntry = copyObject(oldSubqueryTle); TargetEntry *newSubqueryTargetEntry = copyObject(oldSubqueryTle);
newSubqueryTargetEntry->resno = resno; newSubqueryTargetEntry->resno = resno;
newSubqueryTargetlist = lappend(newSubqueryTargetlist, newSubqueryTargetlist = lappend(newSubqueryTargetlist,
@ -920,13 +901,8 @@ InsertPartitionColumnMatchesSelect(Query *query, RangeTblEntry *insertRte,
{ {
TargetEntry *targetEntry = (TargetEntry *) lfirst(targetEntryCell); TargetEntry *targetEntry = (TargetEntry *) lfirst(targetEntryCell);
List *insertTargetEntryColumnList = pull_var_clause_default((Node *) targetEntry); List *insertTargetEntryColumnList = pull_var_clause_default((Node *) targetEntry);
Var *insertVar = NULL;
AttrNumber originalAttrNo = InvalidAttrNumber;
TargetEntry *subqueryTargetEntry = NULL;
Expr *selectTargetExpr = NULL;
Oid subqueryPartitionColumnRelationId = InvalidOid; Oid subqueryPartitionColumnRelationId = InvalidOid;
Var *subqueryPartitionColumn = NULL; Var *subqueryPartitionColumn = NULL;
List *parentQueryList = NIL;
/* /*
* We only consider target entries that include a single column. Note that this * We only consider target entries that include a single column. Note that this
@ -941,8 +917,8 @@ InsertPartitionColumnMatchesSelect(Query *query, RangeTblEntry *insertRte,
continue; continue;
} }
insertVar = (Var *) linitial(insertTargetEntryColumnList); Var *insertVar = (Var *) linitial(insertTargetEntryColumnList);
originalAttrNo = targetEntry->resno; AttrNumber originalAttrNo = targetEntry->resno;
/* skip processing of target table non-partition columns */ /* skip processing of target table non-partition columns */
if (originalAttrNo != insertPartitionColumn->varattno) if (originalAttrNo != insertPartitionColumn->varattno)
@ -953,11 +929,11 @@ InsertPartitionColumnMatchesSelect(Query *query, RangeTblEntry *insertRte,
/* INSERT query includes the partition column */ /* INSERT query includes the partition column */
targetTableHasPartitionColumn = true; targetTableHasPartitionColumn = true;
subqueryTargetEntry = list_nth(subquery->targetList, TargetEntry *subqueryTargetEntry = list_nth(subquery->targetList,
insertVar->varattno - 1); insertVar->varattno - 1);
selectTargetExpr = subqueryTargetEntry->expr; Expr *selectTargetExpr = subqueryTargetEntry->expr;
parentQueryList = list_make2(query, subquery); List *parentQueryList = list_make2(query, subquery);
FindReferencedTableColumn(selectTargetExpr, FindReferencedTableColumn(selectTargetExpr,
parentQueryList, subquery, parentQueryList, subquery,
&subqueryPartitionColumnRelationId, &subqueryPartitionColumnRelationId,
@ -1135,7 +1111,6 @@ static DistributedPlan *
CreateCoordinatorInsertSelectPlan(uint64 planId, Query *parse) CreateCoordinatorInsertSelectPlan(uint64 planId, Query *parse)
{ {
Query *insertSelectQuery = copyObject(parse); Query *insertSelectQuery = copyObject(parse);
Query *selectQuery = NULL;
RangeTblEntry *selectRte = ExtractSelectRangeTableEntry(insertSelectQuery); RangeTblEntry *selectRte = ExtractSelectRangeTableEntry(insertSelectQuery);
RangeTblEntry *insertRte = ExtractResultRelationRTE(insertSelectQuery); RangeTblEntry *insertRte = ExtractResultRelationRTE(insertSelectQuery);
@ -1152,7 +1127,7 @@ CreateCoordinatorInsertSelectPlan(uint64 planId, Query *parse)
return distributedPlan; return distributedPlan;
} }
selectQuery = selectRte->subquery; Query *selectQuery = selectRte->subquery;
/* /*
* Wrap the SELECT as a subquery if the INSERT...SELECT has CTEs or the SELECT * Wrap the SELECT as a subquery if the INSERT...SELECT has CTEs or the SELECT
@ -1194,15 +1169,13 @@ CreateCoordinatorInsertSelectPlan(uint64 planId, Query *parse)
* insertSelectSubuery and a workerJob to execute afterwards. * insertSelectSubuery and a workerJob to execute afterwards.
*/ */
uint64 jobId = INVALID_JOB_ID; uint64 jobId = INVALID_JOB_ID;
Job *workerJob = NULL;
List *taskList = NIL;
char *resultIdPrefix = InsertSelectResultIdPrefix(planId); char *resultIdPrefix = InsertSelectResultIdPrefix(planId);
/* generate tasks for the INSERT..SELECT phase */ /* generate tasks for the INSERT..SELECT phase */
taskList = TwoPhaseInsertSelectTaskList(targetRelationId, insertSelectQuery, List *taskList = TwoPhaseInsertSelectTaskList(targetRelationId, insertSelectQuery,
resultIdPrefix); resultIdPrefix);
workerJob = CitusMakeNode(Job); Job *workerJob = CitusMakeNode(Job);
workerJob->taskList = taskList; workerJob->taskList = taskList;
workerJob->subqueryPushdown = false; workerJob->subqueryPushdown = false;
workerJob->dependedJobList = NIL; workerJob->dependedJobList = NIL;
@ -1232,18 +1205,14 @@ CreateCoordinatorInsertSelectPlan(uint64 planId, Query *parse)
static DeferredErrorMessage * static DeferredErrorMessage *
CoordinatorInsertSelectSupported(Query *insertSelectQuery) CoordinatorInsertSelectSupported(Query *insertSelectQuery)
{ {
RangeTblEntry *insertRte = NULL; DeferredErrorMessage *deferredError = ErrorIfOnConflictNotSupported(
RangeTblEntry *subqueryRte = NULL; insertSelectQuery);
Query *subquery = NULL;
DeferredErrorMessage *deferredError = NULL;
deferredError = ErrorIfOnConflictNotSupported(insertSelectQuery);
if (deferredError) if (deferredError)
{ {
return deferredError; return deferredError;
} }
insertRte = ExtractResultRelationRTE(insertSelectQuery); RangeTblEntry *insertRte = ExtractResultRelationRTE(insertSelectQuery);
if (PartitionMethod(insertRte->relid) == DISTRIBUTE_BY_APPEND) if (PartitionMethod(insertRte->relid) == DISTRIBUTE_BY_APPEND)
{ {
return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED, return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
@ -1251,8 +1220,8 @@ CoordinatorInsertSelectSupported(Query *insertSelectQuery)
"not supported", NULL, NULL); "not supported", NULL, NULL);
} }
subqueryRte = ExtractSelectRangeTableEntry(insertSelectQuery); RangeTblEntry *subqueryRte = ExtractSelectRangeTableEntry(insertSelectQuery);
subquery = (Query *) subqueryRte->subquery; Query *subquery = (Query *) subqueryRte->subquery;
if (NeedsDistributedPlanning(subquery) && if (NeedsDistributedPlanning(subquery) &&
contain_nextval_expression_walker((Node *) insertSelectQuery->targetList, NULL)) contain_nextval_expression_walker((Node *) insertSelectQuery->targetList, NULL))
@ -1274,25 +1243,22 @@ CoordinatorInsertSelectSupported(Query *insertSelectQuery)
static Query * static Query *
WrapSubquery(Query *subquery) WrapSubquery(Query *subquery)
{ {
Query *outerQuery = NULL;
ParseState *pstate = make_parsestate(NULL); ParseState *pstate = make_parsestate(NULL);
Alias *selectAlias = NULL;
RangeTblEntry *newRangeTableEntry = NULL;
RangeTblRef *newRangeTableRef = NULL;
ListCell *selectTargetCell = NULL; ListCell *selectTargetCell = NULL;
List *newTargetList = NIL; List *newTargetList = NIL;
outerQuery = makeNode(Query); Query *outerQuery = makeNode(Query);
outerQuery->commandType = CMD_SELECT; outerQuery->commandType = CMD_SELECT;
/* create range table entries */ /* create range table entries */
selectAlias = makeAlias("citus_insert_select_subquery", NIL); Alias *selectAlias = makeAlias("citus_insert_select_subquery", NIL);
newRangeTableEntry = addRangeTableEntryForSubquery(pstate, subquery, RangeTblEntry *newRangeTableEntry = addRangeTableEntryForSubquery(pstate, subquery,
selectAlias, false, true); selectAlias, false,
true);
outerQuery->rtable = list_make1(newRangeTableEntry); outerQuery->rtable = list_make1(newRangeTableEntry);
/* set the FROM expression to the subquery */ /* set the FROM expression to the subquery */
newRangeTableRef = makeNode(RangeTblRef); RangeTblRef *newRangeTableRef = makeNode(RangeTblRef);
newRangeTableRef->rtindex = 1; newRangeTableRef->rtindex = 1;
outerQuery->jointree = makeFromExpr(list_make1(newRangeTableRef), NULL); outerQuery->jointree = makeFromExpr(list_make1(newRangeTableRef), NULL);
@ -1300,8 +1266,6 @@ WrapSubquery(Query *subquery)
foreach(selectTargetCell, subquery->targetList) foreach(selectTargetCell, subquery->targetList)
{ {
TargetEntry *selectTargetEntry = (TargetEntry *) lfirst(selectTargetCell); TargetEntry *selectTargetEntry = (TargetEntry *) lfirst(selectTargetCell);
Var *newSelectVar = NULL;
TargetEntry *newSelectTargetEntry = NULL;
/* exactly 1 entry in FROM */ /* exactly 1 entry in FROM */
int indexInRangeTable = 1; int indexInRangeTable = 1;
@ -1311,12 +1275,12 @@ WrapSubquery(Query *subquery)
continue; continue;
} }
newSelectVar = makeVar(indexInRangeTable, selectTargetEntry->resno, Var *newSelectVar = makeVar(indexInRangeTable, selectTargetEntry->resno,
exprType((Node *) selectTargetEntry->expr), exprType((Node *) selectTargetEntry->expr),
exprTypmod((Node *) selectTargetEntry->expr), exprTypmod((Node *) selectTargetEntry->expr),
exprCollation((Node *) selectTargetEntry->expr), 0); exprCollation((Node *) selectTargetEntry->expr), 0);
newSelectTargetEntry = makeTargetEntry((Expr *) newSelectVar, TargetEntry *newSelectTargetEntry = makeTargetEntry((Expr *) newSelectVar,
selectTargetEntry->resno, selectTargetEntry->resno,
selectTargetEntry->resname, selectTargetEntry->resname,
selectTargetEntry->resjunk); selectTargetEntry->resjunk);
@ -1352,16 +1316,13 @@ TwoPhaseInsertSelectTaskList(Oid targetRelationId, Query *insertSelectQuery,
DistTableCacheEntry *targetCacheEntry = DistributedTableCacheEntry(targetRelationId); DistTableCacheEntry *targetCacheEntry = DistributedTableCacheEntry(targetRelationId);
int shardCount = targetCacheEntry->shardIntervalArrayLength; int shardCount = targetCacheEntry->shardIntervalArrayLength;
int shardOffset = 0;
uint32 taskIdIndex = 1; uint32 taskIdIndex = 1;
uint64 jobId = INVALID_JOB_ID; uint64 jobId = INVALID_JOB_ID;
ListCell *targetEntryCell = NULL; ListCell *targetEntryCell = NULL;
Relation distributedRelation = NULL;
TupleDesc destTupleDescriptor = NULL;
distributedRelation = heap_open(targetRelationId, RowExclusiveLock); Relation distributedRelation = heap_open(targetRelationId, RowExclusiveLock);
destTupleDescriptor = RelationGetDescr(distributedRelation); TupleDesc destTupleDescriptor = RelationGetDescr(distributedRelation);
/* /*
* If the type of insert column and target table's column type is * If the type of insert column and target table's column type is
@ -1388,25 +1349,22 @@ TwoPhaseInsertSelectTaskList(Oid targetRelationId, Query *insertSelectQuery,
} }
} }
for (shardOffset = 0; shardOffset < shardCount; shardOffset++) for (int shardOffset = 0; shardOffset < shardCount; shardOffset++)
{ {
ShardInterval *targetShardInterval = ShardInterval *targetShardInterval =
targetCacheEntry->sortedShardIntervalArray[shardOffset]; targetCacheEntry->sortedShardIntervalArray[shardOffset];
uint64 shardId = targetShardInterval->shardId; uint64 shardId = targetShardInterval->shardId;
List *columnAliasList = NIL; List *columnAliasList = NIL;
List *insertShardPlacementList = NIL;
Query *resultSelectQuery = NULL;
StringInfo queryString = makeStringInfo(); StringInfo queryString = makeStringInfo();
RelationShard *relationShard = NULL;
Task *modifyTask = NULL;
StringInfo resultId = makeStringInfo(); StringInfo resultId = makeStringInfo();
/* during COPY, the shard ID is appended to the result name */ /* during COPY, the shard ID is appended to the result name */
appendStringInfo(resultId, "%s_" UINT64_FORMAT, resultIdPrefix, shardId); appendStringInfo(resultId, "%s_" UINT64_FORMAT, resultIdPrefix, shardId);
/* generate the query on the intermediate result */ /* generate the query on the intermediate result */
resultSelectQuery = BuildSubPlanResultQuery(insertSelectQuery->targetList, Query *resultSelectQuery = BuildSubPlanResultQuery(insertSelectQuery->targetList,
columnAliasList, resultId->data); columnAliasList,
resultId->data);
/* put the intermediate result query in the INSERT..SELECT */ /* put the intermediate result query in the INSERT..SELECT */
selectRte->subquery = resultSelectQuery; selectRte->subquery = resultSelectQuery;
@ -1431,13 +1389,14 @@ TwoPhaseInsertSelectTaskList(Oid targetRelationId, Query *insertSelectQuery,
ereport(DEBUG2, (errmsg("distributed statement: %s", queryString->data))); ereport(DEBUG2, (errmsg("distributed statement: %s", queryString->data)));
LockShardDistributionMetadata(shardId, ShareLock); LockShardDistributionMetadata(shardId, ShareLock);
insertShardPlacementList = FinalizedShardPlacementList(shardId); List *insertShardPlacementList = FinalizedShardPlacementList(shardId);
relationShard = CitusMakeNode(RelationShard); RelationShard *relationShard = CitusMakeNode(RelationShard);
relationShard->relationId = targetShardInterval->relationId; relationShard->relationId = targetShardInterval->relationId;
relationShard->shardId = targetShardInterval->shardId; relationShard->shardId = targetShardInterval->shardId;
modifyTask = CreateBasicTask(jobId, taskIdIndex, MODIFY_TASK, queryString->data); Task *modifyTask = CreateBasicTask(jobId, taskIdIndex, MODIFY_TASK,
queryString->data);
modifyTask->dependedTaskList = NULL; modifyTask->dependedTaskList = NULL;
modifyTask->anchorShardId = shardId; modifyTask->anchorShardId = shardId;
modifyTask->taskPlacementList = insertShardPlacementList; modifyTask->taskPlacementList = insertShardPlacementList;

View File

@ -52,7 +52,6 @@ FindSubPlansUsedInNode(Node *node)
{ {
char *resultId = char *resultId =
FindIntermediateResultIdIfExists(rangeTableEntry); FindIntermediateResultIdIfExists(rangeTableEntry);
Value *resultIdValue = NULL;
if (resultId == NULL) if (resultId == NULL)
{ {
@ -63,7 +62,7 @@ FindSubPlansUsedInNode(Node *node)
* Use a Value to be able to use list_append_unique and store * Use a Value to be able to use list_append_unique and store
* the result ID in the DistributedPlan. * the result ID in the DistributedPlan.
*/ */
resultIdValue = makeString(resultId); Value *resultIdValue = makeString(resultId);
subPlanList = list_append_unique(subPlanList, resultIdValue); subPlanList = list_append_unique(subPlanList, resultIdValue);
} }
} }
@ -185,8 +184,6 @@ AppendAllAccessedWorkerNodes(List *workerNodeList, DistributedPlan *distributedP
HTAB * HTAB *
MakeIntermediateResultHTAB() MakeIntermediateResultHTAB()
{ {
HTAB *intermediateResultsHash = NULL;
uint32 hashFlags = 0;
HASHCTL info = { 0 }; HASHCTL info = { 0 };
int initialNumberOfElements = 16; int initialNumberOfElements = 16;
@ -194,10 +191,11 @@ MakeIntermediateResultHTAB()
info.entrysize = sizeof(IntermediateResultsHashEntry); info.entrysize = sizeof(IntermediateResultsHashEntry);
info.hash = string_hash; info.hash = string_hash;
info.hcxt = CurrentMemoryContext; info.hcxt = CurrentMemoryContext;
hashFlags = (HASH_ELEM | HASH_FUNCTION | HASH_CONTEXT); uint32 hashFlags = (HASH_ELEM | HASH_FUNCTION | HASH_CONTEXT);
intermediateResultsHash = hash_create("Intermediate results hash", HTAB *intermediateResultsHash = hash_create("Intermediate results hash",
initialNumberOfElements, &info, hashFlags); initialNumberOfElements, &info,
hashFlags);
return intermediateResultsHash; return intermediateResultsHash;
} }
@ -243,10 +241,10 @@ FindAllWorkerNodesUsingSubplan(HTAB *intermediateResultsHash,
static IntermediateResultsHashEntry * static IntermediateResultsHashEntry *
SearchIntermediateResult(HTAB *intermediateResultsHash, char *resultId) SearchIntermediateResult(HTAB *intermediateResultsHash, char *resultId)
{ {
IntermediateResultsHashEntry *entry = NULL;
bool found = false; bool found = false;
entry = hash_search(intermediateResultsHash, resultId, HASH_ENTER, &found); IntermediateResultsHashEntry *entry = hash_search(intermediateResultsHash, resultId,
HASH_ENTER, &found);
/* use sane defaults */ /* use sane defaults */
if (!found) if (!found)

View File

@ -343,9 +343,8 @@ ExplainTaskList(List *taskList, ExplainState *es)
foreach(taskCell, taskList) foreach(taskCell, taskList)
{ {
Task *task = (Task *) lfirst(taskCell); Task *task = (Task *) lfirst(taskCell);
RemoteExplainPlan *remoteExplain = NULL;
remoteExplain = RemoteExplain(task, es); RemoteExplainPlan *remoteExplain = RemoteExplain(task, es);
remoteExplainList = lappend(remoteExplainList, remoteExplain); remoteExplainList = lappend(remoteExplainList, remoteExplain);
if (!ExplainAllTasks) if (!ExplainAllTasks)
@ -374,14 +373,12 @@ ExplainTaskList(List *taskList, ExplainState *es)
static RemoteExplainPlan * static RemoteExplainPlan *
RemoteExplain(Task *task, ExplainState *es) RemoteExplain(Task *task, ExplainState *es)
{ {
StringInfo explainQuery = NULL;
List *taskPlacementList = task->taskPlacementList; List *taskPlacementList = task->taskPlacementList;
int placementCount = list_length(taskPlacementList); int placementCount = list_length(taskPlacementList);
int placementIndex = 0;
RemoteExplainPlan *remotePlan = NULL;
remotePlan = (RemoteExplainPlan *) palloc0(sizeof(RemoteExplainPlan)); RemoteExplainPlan *remotePlan = (RemoteExplainPlan *) palloc0(
explainQuery = BuildRemoteExplainQuery(task->queryString, es); sizeof(RemoteExplainPlan));
StringInfo explainQuery = BuildRemoteExplainQuery(task->queryString, es);
/* /*
* Use a coordinated transaction to ensure that we open a transaction block * Use a coordinated transaction to ensure that we open a transaction block
@ -389,17 +386,16 @@ RemoteExplain(Task *task, ExplainState *es)
*/ */
BeginOrContinueCoordinatedTransaction(); BeginOrContinueCoordinatedTransaction();
for (placementIndex = 0; placementIndex < placementCount; placementIndex++) for (int placementIndex = 0; placementIndex < placementCount; placementIndex++)
{ {
ShardPlacement *taskPlacement = list_nth(taskPlacementList, placementIndex); ShardPlacement *taskPlacement = list_nth(taskPlacementList, placementIndex);
MultiConnection *connection = NULL;
PGresult *queryResult = NULL; PGresult *queryResult = NULL;
int connectionFlags = 0; int connectionFlags = 0;
int executeResult = 0;
remotePlan->placementIndex = placementIndex; remotePlan->placementIndex = placementIndex;
connection = GetPlacementConnection(connectionFlags, taskPlacement, NULL); MultiConnection *connection = GetPlacementConnection(connectionFlags,
taskPlacement, NULL);
/* try other placements if we fail to connect this one */ /* try other placements if we fail to connect this one */
if (PQstatus(connection->pgConn) != CONNECTION_OK) if (PQstatus(connection->pgConn) != CONNECTION_OK)
@ -417,7 +413,7 @@ RemoteExplain(Task *task, ExplainState *es)
ExecuteCriticalRemoteCommand(connection, "SAVEPOINT citus_explain_savepoint"); ExecuteCriticalRemoteCommand(connection, "SAVEPOINT citus_explain_savepoint");
/* run explain query */ /* run explain query */
executeResult = ExecuteOptionalRemoteCommand(connection, explainQuery->data, int executeResult = ExecuteOptionalRemoteCommand(connection, explainQuery->data,
&queryResult); &queryResult);
if (executeResult != 0) if (executeResult != 0)
{ {
@ -517,11 +513,9 @@ ExplainTaskPlacement(ShardPlacement *taskPlacement, List *explainOutputList,
foreach(explainOutputCell, explainOutputList) foreach(explainOutputCell, explainOutputList)
{ {
StringInfo rowString = (StringInfo) lfirst(explainOutputCell); StringInfo rowString = (StringInfo) lfirst(explainOutputCell);
int rowLength = 0;
char *lineStart = NULL;
rowLength = strlen(rowString->data); int rowLength = strlen(rowString->data);
lineStart = rowString->data; char *lineStart = rowString->data;
/* parse the lines in the remote EXPLAIN for proper indentation */ /* parse the lines in the remote EXPLAIN for proper indentation */
while (lineStart < rowString->data + rowLength) while (lineStart < rowString->data + rowLength)
@ -646,14 +640,13 @@ ExplainOneQuery(Query *query, int cursorOptions,
} }
else else
{ {
PlannedStmt *plan;
instr_time planstart, instr_time planstart,
planduration; planduration;
INSTR_TIME_SET_CURRENT(planstart); INSTR_TIME_SET_CURRENT(planstart);
/* plan the query */ /* plan the query */
plan = pg_plan_query(query, cursorOptions, params); PlannedStmt *plan = pg_plan_query(query, cursorOptions, params);
INSTR_TIME_SET_CURRENT(planduration); INSTR_TIME_SET_CURRENT(planduration);
INSTR_TIME_SUBTRACT(planduration, planstart); INSTR_TIME_SUBTRACT(planduration, planstart);

View File

@ -116,18 +116,16 @@ JoinExprList(FromExpr *fromExpr)
if (joinList != NIL) if (joinList != NIL)
{ {
/* multiple nodes in from clause, add an explicit join between them */ /* multiple nodes in from clause, add an explicit join between them */
JoinExpr *newJoinExpr = NULL;
RangeTblRef *nextRangeTableRef = NULL;
int nextRangeTableIndex = 0; int nextRangeTableIndex = 0;
/* find the left most range table in this node */ /* find the left most range table in this node */
ExtractLeftMostRangeTableIndex((Node *) fromExpr, &nextRangeTableIndex); ExtractLeftMostRangeTableIndex((Node *) fromExpr, &nextRangeTableIndex);
nextRangeTableRef = makeNode(RangeTblRef); RangeTblRef *nextRangeTableRef = makeNode(RangeTblRef);
nextRangeTableRef->rtindex = nextRangeTableIndex; nextRangeTableRef->rtindex = nextRangeTableIndex;
/* join the previous node with nextRangeTableRef */ /* join the previous node with nextRangeTableRef */
newJoinExpr = makeNode(JoinExpr); JoinExpr *newJoinExpr = makeNode(JoinExpr);
newJoinExpr->jointype = JOIN_INNER; newJoinExpr->jointype = JOIN_INNER;
newJoinExpr->rarg = (Node *) nextRangeTableRef; newJoinExpr->rarg = (Node *) nextRangeTableRef;
newJoinExpr->quals = NULL; newJoinExpr->quals = NULL;
@ -261,17 +259,15 @@ JoinOnColumns(Var *currentColumn, Var *candidateColumn, List *joinClauseList)
List * List *
JoinOrderList(List *tableEntryList, List *joinClauseList) JoinOrderList(List *tableEntryList, List *joinClauseList)
{ {
List *bestJoinOrder = NIL;
List *candidateJoinOrderList = NIL; List *candidateJoinOrderList = NIL;
ListCell *tableEntryCell = NULL; ListCell *tableEntryCell = NULL;
foreach(tableEntryCell, tableEntryList) foreach(tableEntryCell, tableEntryList)
{ {
TableEntry *startingTable = (TableEntry *) lfirst(tableEntryCell); TableEntry *startingTable = (TableEntry *) lfirst(tableEntryCell);
List *candidateJoinOrder = NIL;
/* each candidate join order starts with a different table */ /* each candidate join order starts with a different table */
candidateJoinOrder = JoinOrderForTable(startingTable, tableEntryList, List *candidateJoinOrder = JoinOrderForTable(startingTable, tableEntryList,
joinClauseList); joinClauseList);
if (candidateJoinOrder != NULL) if (candidateJoinOrder != NULL)
@ -289,7 +285,7 @@ JoinOrderList(List *tableEntryList, List *joinClauseList)
"equal operator"))); "equal operator")));
} }
bestJoinOrder = BestJoinOrder(candidateJoinOrderList); List *bestJoinOrder = BestJoinOrder(candidateJoinOrderList);
/* if logging is enabled, print join order */ /* if logging is enabled, print join order */
if (LogMultiJoinOrder) if (LogMultiJoinOrder)
@ -312,10 +308,7 @@ JoinOrderList(List *tableEntryList, List *joinClauseList)
static List * static List *
JoinOrderForTable(TableEntry *firstTable, List *tableEntryList, List *joinClauseList) JoinOrderForTable(TableEntry *firstTable, List *tableEntryList, List *joinClauseList)
{ {
JoinOrderNode *currentJoinNode = NULL;
JoinRuleType firstJoinRule = JOIN_RULE_INVALID_FIRST; JoinRuleType firstJoinRule = JOIN_RULE_INVALID_FIRST;
List *joinOrderList = NIL;
List *joinedTableList = NIL;
int joinedTableCount = 1; int joinedTableCount = 1;
int totalTableCount = list_length(tableEntryList); int totalTableCount = list_length(tableEntryList);
@ -331,20 +324,19 @@ JoinOrderForTable(TableEntry *firstTable, List *tableEntryList, List *joinClause
firstTable); firstTable);
/* add first node to the join order */ /* add first node to the join order */
joinOrderList = list_make1(firstJoinNode); List *joinOrderList = list_make1(firstJoinNode);
joinedTableList = list_make1(firstTable); List *joinedTableList = list_make1(firstTable);
currentJoinNode = firstJoinNode; JoinOrderNode *currentJoinNode = firstJoinNode;
/* loop until we join all remaining tables */ /* loop until we join all remaining tables */
while (joinedTableCount < totalTableCount) while (joinedTableCount < totalTableCount)
{ {
List *pendingTableList = NIL;
ListCell *pendingTableCell = NULL; ListCell *pendingTableCell = NULL;
JoinOrderNode *nextJoinNode = NULL; JoinOrderNode *nextJoinNode = NULL;
TableEntry *nextJoinedTable = NULL;
JoinRuleType nextJoinRuleType = JOIN_RULE_LAST; JoinRuleType nextJoinRuleType = JOIN_RULE_LAST;
pendingTableList = TableEntryListDifference(tableEntryList, joinedTableList); List *pendingTableList = TableEntryListDifference(tableEntryList,
joinedTableList);
/* /*
* Iterate over all pending tables, and find the next best table to * Iterate over all pending tables, and find the next best table to
@ -354,13 +346,13 @@ JoinOrderForTable(TableEntry *firstTable, List *tableEntryList, List *joinClause
foreach(pendingTableCell, pendingTableList) foreach(pendingTableCell, pendingTableList)
{ {
TableEntry *pendingTable = (TableEntry *) lfirst(pendingTableCell); TableEntry *pendingTable = (TableEntry *) lfirst(pendingTableCell);
JoinOrderNode *pendingJoinNode = NULL;
JoinRuleType pendingJoinRuleType = JOIN_RULE_LAST;
JoinType joinType = JOIN_INNER; JoinType joinType = JOIN_INNER;
/* evaluate all join rules for this pending table */ /* evaluate all join rules for this pending table */
pendingJoinNode = EvaluateJoinRules(joinedTableList, currentJoinNode, JoinOrderNode *pendingJoinNode = EvaluateJoinRules(joinedTableList,
pendingTable, joinClauseList, joinType); currentJoinNode,
pendingTable,
joinClauseList, joinType);
if (pendingJoinNode == NULL) if (pendingJoinNode == NULL)
{ {
@ -369,7 +361,7 @@ JoinOrderForTable(TableEntry *firstTable, List *tableEntryList, List *joinClause
} }
/* if this rule is better than previous ones, keep it */ /* if this rule is better than previous ones, keep it */
pendingJoinRuleType = pendingJoinNode->joinRuleType; JoinRuleType pendingJoinRuleType = pendingJoinNode->joinRuleType;
if (pendingJoinRuleType < nextJoinRuleType) if (pendingJoinRuleType < nextJoinRuleType)
{ {
nextJoinNode = pendingJoinNode; nextJoinNode = pendingJoinNode;
@ -387,7 +379,7 @@ JoinOrderForTable(TableEntry *firstTable, List *tableEntryList, List *joinClause
} }
Assert(nextJoinNode != NULL); Assert(nextJoinNode != NULL);
nextJoinedTable = nextJoinNode->tableEntry; TableEntry *nextJoinedTable = nextJoinNode->tableEntry;
/* add next node to the join order */ /* add next node to the join order */
joinOrderList = lappend(joinOrderList, nextJoinNode); joinOrderList = lappend(joinOrderList, nextJoinNode);
@ -411,8 +403,6 @@ JoinOrderForTable(TableEntry *firstTable, List *tableEntryList, List *joinClause
static List * static List *
BestJoinOrder(List *candidateJoinOrders) BestJoinOrder(List *candidateJoinOrders)
{ {
List *bestJoinOrder = NULL;
uint32 ruleTypeIndex = 0;
uint32 highestValidIndex = JOIN_RULE_LAST - 1; uint32 highestValidIndex = JOIN_RULE_LAST - 1;
uint32 candidateCount PG_USED_FOR_ASSERTS_ONLY = 0; uint32 candidateCount PG_USED_FOR_ASSERTS_ONLY = 0;
@ -429,7 +419,7 @@ BestJoinOrder(List *candidateJoinOrders)
* have 3 or more, if there isn't a join order with fewer DPs; and so * have 3 or more, if there isn't a join order with fewer DPs; and so
* forth. * forth.
*/ */
for (ruleTypeIndex = highestValidIndex; ruleTypeIndex > 0; ruleTypeIndex--) for (uint32 ruleTypeIndex = highestValidIndex; ruleTypeIndex > 0; ruleTypeIndex--)
{ {
JoinRuleType ruleType = (JoinRuleType) ruleTypeIndex; JoinRuleType ruleType = (JoinRuleType) ruleTypeIndex;
@ -451,7 +441,7 @@ BestJoinOrder(List *candidateJoinOrders)
* If there still is a tie, we pick the join order whose relation appeared * If there still is a tie, we pick the join order whose relation appeared
* earliest in the query's range table entry list. * earliest in the query's range table entry list.
*/ */
bestJoinOrder = (List *) linitial(candidateJoinOrders); List *bestJoinOrder = (List *) linitial(candidateJoinOrders);
return bestJoinOrder; return bestJoinOrder;
} }
@ -662,24 +652,21 @@ EvaluateJoinRules(List *joinedTableList, JoinOrderNode *currentJoinNode,
JoinType joinType) JoinType joinType)
{ {
JoinOrderNode *nextJoinNode = NULL; JoinOrderNode *nextJoinNode = NULL;
uint32 candidateTableId = 0;
List *joinedTableIdList = NIL;
List *applicableJoinClauses = NIL;
uint32 lowestValidIndex = JOIN_RULE_INVALID_FIRST + 1; uint32 lowestValidIndex = JOIN_RULE_INVALID_FIRST + 1;
uint32 highestValidIndex = JOIN_RULE_LAST - 1; uint32 highestValidIndex = JOIN_RULE_LAST - 1;
uint32 ruleIndex = 0;
/* /*
* We first find all applicable join clauses between already joined tables * We first find all applicable join clauses between already joined tables
* and the candidate table. * and the candidate table.
*/ */
joinedTableIdList = RangeTableIdList(joinedTableList); List *joinedTableIdList = RangeTableIdList(joinedTableList);
candidateTableId = candidateTable->rangeTableId; uint32 candidateTableId = candidateTable->rangeTableId;
applicableJoinClauses = ApplicableJoinClauses(joinedTableIdList, candidateTableId, List *applicableJoinClauses = ApplicableJoinClauses(joinedTableIdList,
candidateTableId,
joinClauseList); joinClauseList);
/* we then evaluate all join rules in order */ /* we then evaluate all join rules in order */
for (ruleIndex = lowestValidIndex; ruleIndex <= highestValidIndex; ruleIndex++) for (uint32 ruleIndex = lowestValidIndex; ruleIndex <= highestValidIndex; ruleIndex++)
{ {
JoinRuleType ruleType = (JoinRuleType) ruleIndex; JoinRuleType ruleType = (JoinRuleType) ruleIndex;
RuleEvalFunction ruleEvalFunction = JoinRuleEvalFunction(ruleType); RuleEvalFunction ruleEvalFunction = JoinRuleEvalFunction(ruleType);
@ -737,7 +724,6 @@ static RuleEvalFunction
JoinRuleEvalFunction(JoinRuleType ruleType) JoinRuleEvalFunction(JoinRuleType ruleType)
{ {
static bool ruleEvalFunctionsInitialized = false; static bool ruleEvalFunctionsInitialized = false;
RuleEvalFunction ruleEvalFunction = NULL;
if (!ruleEvalFunctionsInitialized) if (!ruleEvalFunctionsInitialized)
{ {
@ -751,7 +737,7 @@ JoinRuleEvalFunction(JoinRuleType ruleType)
ruleEvalFunctionsInitialized = true; ruleEvalFunctionsInitialized = true;
} }
ruleEvalFunction = RuleEvalFunctionArray[ruleType]; RuleEvalFunction ruleEvalFunction = RuleEvalFunctionArray[ruleType];
Assert(ruleEvalFunction != NULL); Assert(ruleEvalFunction != NULL);
return ruleEvalFunction; return ruleEvalFunction;
@ -763,7 +749,6 @@ static char *
JoinRuleName(JoinRuleType ruleType) JoinRuleName(JoinRuleType ruleType)
{ {
static bool ruleNamesInitialized = false; static bool ruleNamesInitialized = false;
char *ruleName = NULL;
if (!ruleNamesInitialized) if (!ruleNamesInitialized)
{ {
@ -780,7 +765,7 @@ JoinRuleName(JoinRuleType ruleType)
ruleNamesInitialized = true; ruleNamesInitialized = true;
} }
ruleName = RuleNameArray[ruleType]; char *ruleName = RuleNameArray[ruleType];
Assert(ruleName != NULL); Assert(ruleName != NULL);
return ruleName; return ruleName;
@ -857,7 +842,6 @@ static JoinOrderNode *
LocalJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable, LocalJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
List *applicableJoinClauses, JoinType joinType) List *applicableJoinClauses, JoinType joinType)
{ {
JoinOrderNode *nextJoinNode = NULL;
Oid relationId = candidateTable->relationId; Oid relationId = candidateTable->relationId;
uint32 tableId = candidateTable->rangeTableId; uint32 tableId = candidateTable->rangeTableId;
Var *candidatePartitionColumn = PartitionColumn(relationId, tableId); Var *candidatePartitionColumn = PartitionColumn(relationId, tableId);
@ -865,8 +849,6 @@ LocalJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
char candidatePartitionMethod = PartitionMethod(relationId); char candidatePartitionMethod = PartitionMethod(relationId);
char currentPartitionMethod = currentJoinNode->partitionMethod; char currentPartitionMethod = currentJoinNode->partitionMethod;
TableEntry *currentAnchorTable = currentJoinNode->anchorTable; TableEntry *currentAnchorTable = currentJoinNode->anchorTable;
bool joinOnPartitionColumns = false;
bool coPartitionedTables = false;
/* /*
* If we previously dual-hash re-partitioned the tables for a join or made cartesian * If we previously dual-hash re-partitioned the tables for a join or made cartesian
@ -883,7 +865,7 @@ LocalJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
return NULL; return NULL;
} }
joinOnPartitionColumns = JoinOnColumns(currentPartitionColumn, bool joinOnPartitionColumns = JoinOnColumns(currentPartitionColumn,
candidatePartitionColumn, candidatePartitionColumn,
applicableJoinClauses); applicableJoinClauses);
if (!joinOnPartitionColumns) if (!joinOnPartitionColumns)
@ -892,14 +874,15 @@ LocalJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
} }
/* shard interval lists must have 1-1 matching for local joins */ /* shard interval lists must have 1-1 matching for local joins */
coPartitionedTables = CoPartitionedTables(currentAnchorTable->relationId, relationId); bool coPartitionedTables = CoPartitionedTables(currentAnchorTable->relationId,
relationId);
if (!coPartitionedTables) if (!coPartitionedTables)
{ {
return NULL; return NULL;
} }
nextJoinNode = MakeJoinOrderNode(candidateTable, LOCAL_PARTITION_JOIN, JoinOrderNode *nextJoinNode = MakeJoinOrderNode(candidateTable, LOCAL_PARTITION_JOIN,
currentPartitionColumn, currentPartitionColumn,
currentPartitionMethod, currentPartitionMethod,
currentAnchorTable); currentAnchorTable);
@ -925,7 +908,6 @@ SinglePartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
TableEntry *currentAnchorTable = currentJoinNode->anchorTable; TableEntry *currentAnchorTable = currentJoinNode->anchorTable;
JoinRuleType currentJoinRuleType = currentJoinNode->joinRuleType; JoinRuleType currentJoinRuleType = currentJoinNode->joinRuleType;
OpExpr *joinClause = NULL;
Oid relationId = candidateTable->relationId; Oid relationId = candidateTable->relationId;
uint32 tableId = candidateTable->rangeTableId; uint32 tableId = candidateTable->rangeTableId;
@ -948,7 +930,7 @@ SinglePartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
return NULL; return NULL;
} }
joinClause = OpExpr *joinClause =
SinglePartitionJoinClause(currentPartitionColumn, applicableJoinClauses); SinglePartitionJoinClause(currentPartitionColumn, applicableJoinClauses);
if (joinClause != NULL) if (joinClause != NULL)
{ {

View File

@ -312,19 +312,8 @@ static bool HasOrderByHllType(List *sortClauseList, List *targetList);
void void
MultiLogicalPlanOptimize(MultiTreeRoot *multiLogicalPlan) MultiLogicalPlanOptimize(MultiTreeRoot *multiLogicalPlan)
{ {
bool hasOrderByHllType = false;
List *selectNodeList = NIL;
List *projectNodeList = NIL;
List *collectNodeList = NIL;
List *extendedOpNodeList = NIL;
List *tableNodeList = NIL;
ListCell *collectNodeCell = NULL; ListCell *collectNodeCell = NULL;
ListCell *tableNodeCell = NULL; ListCell *tableNodeCell = NULL;
MultiProject *projectNode = NULL;
MultiExtendedOp *extendedOpNode = NULL;
MultiExtendedOp *masterExtendedOpNode = NULL;
MultiExtendedOp *workerExtendedOpNode = NULL;
ExtendedOpNodeProperties extendedOpNodeProperties;
MultiNode *logicalPlanNode = (MultiNode *) multiLogicalPlan; MultiNode *logicalPlanNode = (MultiNode *) multiLogicalPlan;
/* check that we can optimize aggregates in the plan */ /* check that we can optimize aggregates in the plan */
@ -336,7 +325,7 @@ MultiLogicalPlanOptimize(MultiTreeRoot *multiLogicalPlan)
* exist, we modify the tree in place to swap the original select node with * exist, we modify the tree in place to swap the original select node with
* And and Or nodes. We then push down the And select node if it exists. * And and Or nodes. We then push down the And select node if it exists.
*/ */
selectNodeList = FindNodesOfType(logicalPlanNode, T_MultiSelect); List *selectNodeList = FindNodesOfType(logicalPlanNode, T_MultiSelect);
if (selectNodeList != NIL) if (selectNodeList != NIL)
{ {
MultiSelect *selectNode = (MultiSelect *) linitial(selectNodeList); MultiSelect *selectNode = (MultiSelect *) linitial(selectNodeList);
@ -365,12 +354,12 @@ MultiLogicalPlanOptimize(MultiTreeRoot *multiLogicalPlan)
} }
/* push down the multi project node */ /* push down the multi project node */
projectNodeList = FindNodesOfType(logicalPlanNode, T_MultiProject); List *projectNodeList = FindNodesOfType(logicalPlanNode, T_MultiProject);
projectNode = (MultiProject *) linitial(projectNodeList); MultiProject *projectNode = (MultiProject *) linitial(projectNodeList);
PushDownNodeLoop((MultiUnaryNode *) projectNode); PushDownNodeLoop((MultiUnaryNode *) projectNode);
/* pull up collect nodes and merge duplicate collects */ /* pull up collect nodes and merge duplicate collects */
collectNodeList = FindNodesOfType(logicalPlanNode, T_MultiCollect); List *collectNodeList = FindNodesOfType(logicalPlanNode, T_MultiCollect);
foreach(collectNodeCell, collectNodeList) foreach(collectNodeCell, collectNodeList)
{ {
MultiCollect *collectNode = (MultiCollect *) lfirst(collectNodeCell); MultiCollect *collectNode = (MultiCollect *) lfirst(collectNodeCell);
@ -385,19 +374,20 @@ MultiLogicalPlanOptimize(MultiTreeRoot *multiLogicalPlan)
* clause list to the worker operator node. We then push the worker operator * clause list to the worker operator node. We then push the worker operator
* node below the collect node. * node below the collect node.
*/ */
extendedOpNodeList = FindNodesOfType(logicalPlanNode, T_MultiExtendedOp); List *extendedOpNodeList = FindNodesOfType(logicalPlanNode, T_MultiExtendedOp);
extendedOpNode = (MultiExtendedOp *) linitial(extendedOpNodeList); MultiExtendedOp *extendedOpNode = (MultiExtendedOp *) linitial(extendedOpNodeList);
extendedOpNodeProperties = BuildExtendedOpNodeProperties(extendedOpNode); ExtendedOpNodeProperties extendedOpNodeProperties = BuildExtendedOpNodeProperties(
extendedOpNode);
masterExtendedOpNode = MultiExtendedOp *masterExtendedOpNode =
MasterExtendedOpNode(extendedOpNode, &extendedOpNodeProperties); MasterExtendedOpNode(extendedOpNode, &extendedOpNodeProperties);
workerExtendedOpNode = MultiExtendedOp *workerExtendedOpNode =
WorkerExtendedOpNode(extendedOpNode, &extendedOpNodeProperties); WorkerExtendedOpNode(extendedOpNode, &extendedOpNodeProperties);
ApplyExtendedOpNodes(extendedOpNode, masterExtendedOpNode, workerExtendedOpNode); ApplyExtendedOpNodes(extendedOpNode, masterExtendedOpNode, workerExtendedOpNode);
tableNodeList = FindNodesOfType(logicalPlanNode, T_MultiTable); List *tableNodeList = FindNodesOfType(logicalPlanNode, T_MultiTable);
foreach(tableNodeCell, tableNodeList) foreach(tableNodeCell, tableNodeList)
{ {
MultiTable *tableNode = (MultiTable *) lfirst(tableNodeCell); MultiTable *tableNode = (MultiTable *) lfirst(tableNodeCell);
@ -414,7 +404,7 @@ MultiLogicalPlanOptimize(MultiTreeRoot *multiLogicalPlan)
* clause's sortop oid, so we can't push an order by on the hll data type to * clause's sortop oid, so we can't push an order by on the hll data type to
* the worker node. We check that here and error out if necessary. * the worker node. We check that here and error out if necessary.
*/ */
hasOrderByHllType = HasOrderByHllType(workerExtendedOpNode->sortClauseList, bool hasOrderByHllType = HasOrderByHllType(workerExtendedOpNode->sortClauseList,
workerExtendedOpNode->targetList); workerExtendedOpNode->targetList);
if (hasOrderByHllType) if (hasOrderByHllType)
{ {
@ -597,7 +587,6 @@ PushDownNodeLoop(MultiUnaryNode *currentNode)
static void static void
PullUpCollectLoop(MultiCollect *collectNode) PullUpCollectLoop(MultiCollect *collectNode)
{ {
MultiNode *childNode = NULL;
MultiUnaryNode *currentNode = (MultiUnaryNode *) collectNode; MultiUnaryNode *currentNode = (MultiUnaryNode *) collectNode;
PullUpStatus pullUpStatus = CanPullUp(currentNode); PullUpStatus pullUpStatus = CanPullUp(currentNode);
@ -611,7 +600,7 @@ PullUpCollectLoop(MultiCollect *collectNode)
* After pulling up the collect node, if we find that our child node is also * After pulling up the collect node, if we find that our child node is also
* a collect, we merge the two collect nodes together by removing this node. * a collect, we merge the two collect nodes together by removing this node.
*/ */
childNode = currentNode->childNode; MultiNode *childNode = currentNode->childNode;
if (CitusIsA(childNode, MultiCollect)) if (CitusIsA(childNode, MultiCollect))
{ {
RemoveUnaryNode(currentNode); RemoveUnaryNode(currentNode);
@ -753,8 +742,8 @@ CanPullUp(MultiUnaryNode *childNode)
* Evaluate if parent can be pushed down below the child node, since it * Evaluate if parent can be pushed down below the child node, since it
* is equivalent to pulling up the child above its parent. * is equivalent to pulling up the child above its parent.
*/ */
PushDownStatus parentPushDownStatus = PUSH_DOWN_INVALID_FIRST; PushDownStatus parentPushDownStatus = Commutative((MultiUnaryNode *) parentNode,
parentPushDownStatus = Commutative((MultiUnaryNode *) parentNode, childNode); childNode);
if (parentPushDownStatus == PUSH_DOWN_VALID) if (parentPushDownStatus == PUSH_DOWN_VALID)
{ {
@ -932,8 +921,6 @@ SelectClauseTableIdList(List *selectClauseList)
{ {
Node *selectClause = (Node *) lfirst(selectClauseCell); Node *selectClause = (Node *) lfirst(selectClauseCell);
List *selectColumnList = pull_var_clause_default(selectClause); List *selectColumnList = pull_var_clause_default(selectClause);
Var *selectColumn = NULL;
int selectColumnTableId = 0;
if (list_length(selectColumnList) == 0) if (list_length(selectColumnList) == 0)
{ {
@ -941,8 +928,8 @@ SelectClauseTableIdList(List *selectClauseList)
continue; continue;
} }
selectColumn = (Var *) linitial(selectColumnList); Var *selectColumn = (Var *) linitial(selectColumnList);
selectColumnTableId = (int) selectColumn->varno; int selectColumnTableId = (int) selectColumn->varno;
tableIdList = lappend_int(tableIdList, selectColumnTableId); tableIdList = lappend_int(tableIdList, selectColumnTableId);
} }
@ -1014,9 +1001,9 @@ GenerateNode(MultiUnaryNode *currentNode, MultiNode *childNode)
{ {
MultiSelect *selectNode = (MultiSelect *) currentNode; MultiSelect *selectNode = (MultiSelect *) currentNode;
List *selectClauseList = copyObject(selectNode->selectClauseList); List *selectClauseList = copyObject(selectNode->selectClauseList);
List *newSelectClauseList = NIL;
newSelectClauseList = TableIdListSelectClauses(tableIdList, selectClauseList); List *newSelectClauseList = TableIdListSelectClauses(tableIdList,
selectClauseList);
if (newSelectClauseList != NIL) if (newSelectClauseList != NIL)
{ {
MultiSelect *newSelectNode = CitusMakeNode(MultiSelect); MultiSelect *newSelectNode = CitusMakeNode(MultiSelect);
@ -1370,7 +1357,6 @@ static MultiExtendedOp *
MasterExtendedOpNode(MultiExtendedOp *originalOpNode, MasterExtendedOpNode(MultiExtendedOp *originalOpNode,
ExtendedOpNodeProperties *extendedOpNodeProperties) ExtendedOpNodeProperties *extendedOpNodeProperties)
{ {
MultiExtendedOp *masterExtendedOpNode = NULL;
List *targetEntryList = originalOpNode->targetList; List *targetEntryList = originalOpNode->targetList;
List *newTargetEntryList = NIL; List *newTargetEntryList = NIL;
ListCell *targetEntryCell = NULL; ListCell *targetEntryCell = NULL;
@ -1433,7 +1419,7 @@ MasterExtendedOpNode(MultiExtendedOp *originalOpNode,
newHavingQual = MasterAggregateMutator(originalHavingQual, walkerContext); newHavingQual = MasterAggregateMutator(originalHavingQual, walkerContext);
} }
masterExtendedOpNode = CitusMakeNode(MultiExtendedOp); MultiExtendedOp *masterExtendedOpNode = CitusMakeNode(MultiExtendedOp);
masterExtendedOpNode->targetList = newTargetEntryList; masterExtendedOpNode->targetList = newTargetEntryList;
masterExtendedOpNode->groupClauseList = originalOpNode->groupClauseList; masterExtendedOpNode->groupClauseList = originalOpNode->groupClauseList;
masterExtendedOpNode->sortClauseList = originalOpNode->sortClauseList; masterExtendedOpNode->sortClauseList = originalOpNode->sortClauseList;
@ -1510,7 +1496,6 @@ MasterAggregateExpression(Aggref *originalAggregate,
{ {
AggregateType aggregateType = GetAggregateType(originalAggregate->aggfnoid); AggregateType aggregateType = GetAggregateType(originalAggregate->aggfnoid);
Expr *newMasterExpression = NULL; Expr *newMasterExpression = NULL;
Expr *typeConvertedExpression = NULL;
const uint32 masterTableId = 1; /* one table on the master node */ const uint32 masterTableId = 1; /* one table on the master node */
const Index columnLevelsUp = 0; /* normal column */ const Index columnLevelsUp = 0; /* normal column */
const AttrNumber argumentId = 1; /* our aggregates have single arguments */ const AttrNumber argumentId = 1; /* our aggregates have single arguments */
@ -1576,9 +1561,6 @@ MasterAggregateExpression(Aggref *originalAggregate,
const int argCount = 1; const int argCount = 1;
const int defaultTypeMod = -1; const int defaultTypeMod = -1;
TargetEntry *hllTargetEntry = NULL;
Aggref *unionAggregate = NULL;
FuncExpr *cardinalityExpression = NULL;
/* extract schema name of hll */ /* extract schema name of hll */
Oid hllId = get_extension_oid(HLL_EXTENSION_NAME, false); Oid hllId = get_extension_oid(HLL_EXTENSION_NAME, false);
@ -1598,9 +1580,10 @@ MasterAggregateExpression(Aggref *originalAggregate,
hllTypeCollationId, columnLevelsUp); hllTypeCollationId, columnLevelsUp);
walkerContext->columnId++; walkerContext->columnId++;
hllTargetEntry = makeTargetEntry((Expr *) hllColumn, argumentId, NULL, false); TargetEntry *hllTargetEntry = makeTargetEntry((Expr *) hllColumn, argumentId,
NULL, false);
unionAggregate = makeNode(Aggref); Aggref *unionAggregate = makeNode(Aggref);
unionAggregate->aggfnoid = unionFunctionId; unionAggregate->aggfnoid = unionFunctionId;
unionAggregate->aggtype = hllType; unionAggregate->aggtype = hllType;
unionAggregate->args = list_make1(hllTargetEntry); unionAggregate->args = list_make1(hllTargetEntry);
@ -1610,7 +1593,7 @@ MasterAggregateExpression(Aggref *originalAggregate,
unionAggregate->aggargtypes = list_make1_oid(unionAggregate->aggtype); unionAggregate->aggargtypes = list_make1_oid(unionAggregate->aggtype);
unionAggregate->aggsplit = AGGSPLIT_SIMPLE; unionAggregate->aggsplit = AGGSPLIT_SIMPLE;
cardinalityExpression = makeNode(FuncExpr); FuncExpr *cardinalityExpression = makeNode(FuncExpr);
cardinalityExpression->funcid = cardinalityFunctionId; cardinalityExpression->funcid = cardinalityFunctionId;
cardinalityExpression->funcresulttype = cardinalityReturnType; cardinalityExpression->funcresulttype = cardinalityReturnType;
cardinalityExpression->args = list_make1(unionAggregate); cardinalityExpression->args = list_make1(unionAggregate);
@ -1647,12 +1630,6 @@ MasterAggregateExpression(Aggref *originalAggregate,
* Count aggregates are handled in two steps. First, worker nodes report * Count aggregates are handled in two steps. First, worker nodes report
* their count results. Then, the master node sums up these results. * their count results. Then, the master node sums up these results.
*/ */
Var *column = NULL;
TargetEntry *columnTargetEntry = NULL;
CoerceViaIO *coerceExpr = NULL;
Const *zeroConst = NULL;
List *coalesceArgs = NULL;
CoalesceExpr *coalesceExpr = NULL;
/* worker aggregate and original aggregate have the same return type */ /* worker aggregate and original aggregate have the same return type */
Oid workerReturnType = exprType((Node *) originalAggregate); Oid workerReturnType = exprType((Node *) originalAggregate);
@ -1673,16 +1650,17 @@ MasterAggregateExpression(Aggref *originalAggregate,
newMasterAggregate->aggargtypes = list_make1_oid(newMasterAggregate->aggtype); newMasterAggregate->aggargtypes = list_make1_oid(newMasterAggregate->aggtype);
newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE; newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE;
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType, Var *column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp); workerReturnTypeMod, workerCollationId, columnLevelsUp);
walkerContext->columnId++; walkerContext->columnId++;
/* aggref expects its arguments to be wrapped in target entries */ /* aggref expects its arguments to be wrapped in target entries */
columnTargetEntry = makeTargetEntry((Expr *) column, argumentId, NULL, false); TargetEntry *columnTargetEntry = makeTargetEntry((Expr *) column, argumentId,
NULL, false);
newMasterAggregate->args = list_make1(columnTargetEntry); newMasterAggregate->args = list_make1(columnTargetEntry);
/* cast numeric sum result to bigint (count's return type) */ /* cast numeric sum result to bigint (count's return type) */
coerceExpr = makeNode(CoerceViaIO); CoerceViaIO *coerceExpr = makeNode(CoerceViaIO);
coerceExpr->arg = (Expr *) newMasterAggregate; coerceExpr->arg = (Expr *) newMasterAggregate;
coerceExpr->resulttype = INT8OID; coerceExpr->resulttype = INT8OID;
coerceExpr->resultcollid = InvalidOid; coerceExpr->resultcollid = InvalidOid;
@ -1690,10 +1668,10 @@ MasterAggregateExpression(Aggref *originalAggregate,
coerceExpr->location = -1; coerceExpr->location = -1;
/* convert NULL to 0 in case of no rows */ /* convert NULL to 0 in case of no rows */
zeroConst = MakeIntegerConstInt64(0); Const *zeroConst = MakeIntegerConstInt64(0);
coalesceArgs = list_make2(coerceExpr, zeroConst); List *coalesceArgs = list_make2(coerceExpr, zeroConst);
coalesceExpr = makeNode(CoalesceExpr); CoalesceExpr *coalesceExpr = makeNode(CoalesceExpr);
coalesceExpr->coalescetype = INT8OID; coalesceExpr->coalescetype = INT8OID;
coalesceExpr->coalescecollid = InvalidOid; coalesceExpr->coalescecollid = InvalidOid;
coalesceExpr->args = coalesceArgs; coalesceExpr->args = coalesceArgs;
@ -1713,10 +1691,6 @@ MasterAggregateExpression(Aggref *originalAggregate,
* the arrays or jsons on the master and compute the array_cat_agg() * the arrays or jsons on the master and compute the array_cat_agg()
* or jsonb_cat_agg() aggregate on them to get the final array or json. * or jsonb_cat_agg() aggregate on them to get the final array or json.
*/ */
Var *column = NULL;
TargetEntry *catAggArgument = NULL;
Aggref *newMasterAggregate = NULL;
Oid aggregateFunctionId = InvalidOid;
const char *catAggregateName = NULL; const char *catAggregateName = NULL;
Oid catInputType = InvalidOid; Oid catInputType = InvalidOid;
@ -1753,17 +1727,18 @@ MasterAggregateExpression(Aggref *originalAggregate,
Assert(catAggregateName != NULL); Assert(catAggregateName != NULL);
Assert(catInputType != InvalidOid); Assert(catInputType != InvalidOid);
aggregateFunctionId = AggregateFunctionOid(catAggregateName, Oid aggregateFunctionId = AggregateFunctionOid(catAggregateName,
catInputType); catInputType);
/* create argument for the array_cat_agg() or jsonb_cat_agg() aggregate */ /* create argument for the array_cat_agg() or jsonb_cat_agg() aggregate */
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType, Var *column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp); workerReturnTypeMod, workerCollationId, columnLevelsUp);
catAggArgument = makeTargetEntry((Expr *) column, argumentId, NULL, false); TargetEntry *catAggArgument = makeTargetEntry((Expr *) column, argumentId, NULL,
false);
walkerContext->columnId++; walkerContext->columnId++;
/* construct the master array_cat_agg() or jsonb_cat_agg() expression */ /* construct the master array_cat_agg() or jsonb_cat_agg() expression */
newMasterAggregate = copyObject(originalAggregate); Aggref *newMasterAggregate = copyObject(originalAggregate);
newMasterAggregate->aggfnoid = aggregateFunctionId; newMasterAggregate->aggfnoid = aggregateFunctionId;
newMasterAggregate->args = list_make1(catAggArgument); newMasterAggregate->args = list_make1(catAggArgument);
newMasterAggregate->aggfilter = NULL; newMasterAggregate->aggfilter = NULL;
@ -1781,8 +1756,6 @@ MasterAggregateExpression(Aggref *originalAggregate,
* to apply in the master after running the original aggregate in * to apply in the master after running the original aggregate in
* workers. * workers.
*/ */
TargetEntry *hllTargetEntry = NULL;
Aggref *unionAggregate = NULL;
Oid hllType = exprType((Node *) originalAggregate); Oid hllType = exprType((Node *) originalAggregate);
Oid unionFunctionId = AggregateFunctionOid(HLL_UNION_AGGREGATE_NAME, hllType); Oid unionFunctionId = AggregateFunctionOid(HLL_UNION_AGGREGATE_NAME, hllType);
@ -1793,9 +1766,10 @@ MasterAggregateExpression(Aggref *originalAggregate,
hllReturnTypeMod, hllTypeCollationId, columnLevelsUp); hllReturnTypeMod, hllTypeCollationId, columnLevelsUp);
walkerContext->columnId++; walkerContext->columnId++;
hllTargetEntry = makeTargetEntry((Expr *) hllColumn, argumentId, NULL, false); TargetEntry *hllTargetEntry = makeTargetEntry((Expr *) hllColumn, argumentId,
NULL, false);
unionAggregate = makeNode(Aggref); Aggref *unionAggregate = makeNode(Aggref);
unionAggregate->aggfnoid = unionFunctionId; unionAggregate->aggfnoid = unionFunctionId;
unionAggregate->aggtype = hllType; unionAggregate->aggtype = hllType;
unionAggregate->args = list_make1(hllTargetEntry); unionAggregate->args = list_make1(hllTargetEntry);
@ -1816,8 +1790,6 @@ MasterAggregateExpression(Aggref *originalAggregate,
* Then, we gather the Top-Ns on the master and take the union of all * Then, we gather the Top-Ns on the master and take the union of all
* to get the final topn. * to get the final topn.
*/ */
TargetEntry *topNTargetEntry = NULL;
Aggref *unionAggregate = NULL;
/* worker aggregate and original aggregate have same return type */ /* worker aggregate and original aggregate have same return type */
Oid topnType = exprType((Node *) originalAggregate); Oid topnType = exprType((Node *) originalAggregate);
@ -1831,10 +1803,11 @@ MasterAggregateExpression(Aggref *originalAggregate,
topnReturnTypeMod, topnTypeCollationId, columnLevelsUp); topnReturnTypeMod, topnTypeCollationId, columnLevelsUp);
walkerContext->columnId++; walkerContext->columnId++;
topNTargetEntry = makeTargetEntry((Expr *) topnColumn, argumentId, NULL, false); TargetEntry *topNTargetEntry = makeTargetEntry((Expr *) topnColumn, argumentId,
NULL, false);
/* construct the master topn_union_agg() expression */ /* construct the master topn_union_agg() expression */
unionAggregate = makeNode(Aggref); Aggref *unionAggregate = makeNode(Aggref);
unionAggregate->aggfnoid = unionFunctionId; unionAggregate->aggfnoid = unionFunctionId;
unionAggregate->aggtype = topnType; unionAggregate->aggtype = topnType;
unionAggregate->args = list_make1(topNTargetEntry); unionAggregate->args = list_make1(topNTargetEntry);
@ -1869,32 +1842,30 @@ MasterAggregateExpression(Aggref *originalAggregate,
if (combine != InvalidOid) if (combine != InvalidOid)
{ {
Const *aggOidParam = NULL;
Var *column = NULL;
Const *nullTag = NULL;
List *aggArguments = NIL;
Aggref *newMasterAggregate = NULL;
Oid coordCombineId = CoordCombineAggOid(); Oid coordCombineId = CoordCombineAggOid();
Oid workerReturnType = CSTRINGOID; Oid workerReturnType = CSTRINGOID;
int32 workerReturnTypeMod = -1; int32 workerReturnTypeMod = -1;
Oid workerCollationId = InvalidOid; Oid workerCollationId = InvalidOid;
Oid resultType = exprType((Node *) originalAggregate); Oid resultType = exprType((Node *) originalAggregate);
aggOidParam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), Const *aggOidParam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid),
ObjectIdGetDatum(originalAggregate->aggfnoid), ObjectIdGetDatum(originalAggregate->aggfnoid),
false, true); false, true);
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType, Var *column = makeVar(masterTableId, walkerContext->columnId,
workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp); workerReturnTypeMod, workerCollationId, columnLevelsUp);
walkerContext->columnId++; walkerContext->columnId++;
nullTag = makeNullConst(resultType, -1, InvalidOid); Const *nullTag = makeNullConst(resultType, -1, InvalidOid);
aggArguments = list_make3(makeTargetEntry((Expr *) aggOidParam, 1, NULL, List *aggArguments = list_make3(makeTargetEntry((Expr *) aggOidParam, 1, NULL,
false), false),
makeTargetEntry((Expr *) column, 2, NULL, false), makeTargetEntry((Expr *) column, 2, NULL,
makeTargetEntry((Expr *) nullTag, 3, NULL, false)); false),
makeTargetEntry((Expr *) nullTag, 3, NULL,
false));
/* coord_combine_agg(agg, workercol) */ /* coord_combine_agg(agg, workercol) */
newMasterAggregate = makeNode(Aggref); Aggref *newMasterAggregate = makeNode(Aggref);
newMasterAggregate->aggfnoid = coordCombineId; newMasterAggregate->aggfnoid = coordCombineId;
newMasterAggregate->aggtype = originalAggregate->aggtype; newMasterAggregate->aggtype = originalAggregate->aggtype;
newMasterAggregate->args = aggArguments; newMasterAggregate->args = aggArguments;
@ -1918,9 +1889,6 @@ MasterAggregateExpression(Aggref *originalAggregate,
* All other aggregates are handled as they are. These include sum, min, * All other aggregates are handled as they are. These include sum, min,
* and max. * and max.
*/ */
Var *column = NULL;
TargetEntry *columnTargetEntry = NULL;
Aggref *newMasterAggregate = NULL;
/* worker aggregate and original aggregate have the same return type */ /* worker aggregate and original aggregate have the same return type */
Oid workerReturnType = exprType((Node *) originalAggregate); Oid workerReturnType = exprType((Node *) originalAggregate);
@ -1940,18 +1908,19 @@ MasterAggregateExpression(Aggref *originalAggregate,
{ {
masterReturnType = workerReturnType; masterReturnType = workerReturnType;
} }
newMasterAggregate = copyObject(originalAggregate); Aggref *newMasterAggregate = copyObject(originalAggregate);
newMasterAggregate->aggdistinct = NULL; newMasterAggregate->aggdistinct = NULL;
newMasterAggregate->aggfnoid = aggregateFunctionId; newMasterAggregate->aggfnoid = aggregateFunctionId;
newMasterAggregate->aggtype = masterReturnType; newMasterAggregate->aggtype = masterReturnType;
newMasterAggregate->aggfilter = NULL; newMasterAggregate->aggfilter = NULL;
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType, Var *column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp); workerReturnTypeMod, workerCollationId, columnLevelsUp);
walkerContext->columnId++; walkerContext->columnId++;
/* aggref expects its arguments to be wrapped in target entries */ /* aggref expects its arguments to be wrapped in target entries */
columnTargetEntry = makeTargetEntry((Expr *) column, argumentId, NULL, false); TargetEntry *columnTargetEntry = makeTargetEntry((Expr *) column, argumentId,
NULL, false);
newMasterAggregate->args = list_make1(columnTargetEntry); newMasterAggregate->args = list_make1(columnTargetEntry);
newMasterExpression = (Expr *) newMasterAggregate; newMasterExpression = (Expr *) newMasterAggregate;
@ -1964,7 +1933,7 @@ MasterAggregateExpression(Aggref *originalAggregate,
* type as the original aggregate. We need this since functions like sorting * type as the original aggregate. We need this since functions like sorting
* and grouping have already been chosen based on the original type. * and grouping have already been chosen based on the original type.
*/ */
typeConvertedExpression = AddTypeConversion((Node *) originalAggregate, Expr *typeConvertedExpression = AddTypeConversion((Node *) originalAggregate,
(Node *) newMasterExpression); (Node *) newMasterExpression);
if (typeConvertedExpression != NULL) if (typeConvertedExpression != NULL)
{ {
@ -1999,22 +1968,15 @@ MasterAverageExpression(Oid sumAggregateType, Oid countAggregateType,
Oid sumTypeCollationId = get_typcollation(sumAggregateType); Oid sumTypeCollationId = get_typcollation(sumAggregateType);
Oid countTypeCollationId = get_typcollation(countAggregateType); Oid countTypeCollationId = get_typcollation(countAggregateType);
Var *firstColumn = NULL;
Var *secondColumn = NULL;
TargetEntry *firstTargetEntry = NULL;
TargetEntry *secondTargetEntry = NULL;
Aggref *firstSum = NULL;
Aggref *secondSum = NULL;
List *operatorNameList = NIL;
Expr *opExpr = NULL;
/* create the first argument for sum(column1) */ /* create the first argument for sum(column1) */
firstColumn = makeVar(masterTableId, (*columnId), sumAggregateType, Var *firstColumn = makeVar(masterTableId, (*columnId), sumAggregateType,
defaultTypeMod, sumTypeCollationId, defaultLevelsUp); defaultTypeMod, sumTypeCollationId, defaultLevelsUp);
firstTargetEntry = makeTargetEntry((Expr *) firstColumn, argumentId, NULL, false); TargetEntry *firstTargetEntry = makeTargetEntry((Expr *) firstColumn, argumentId,
NULL, false);
(*columnId)++; (*columnId)++;
firstSum = makeNode(Aggref); Aggref *firstSum = makeNode(Aggref);
firstSum->aggfnoid = AggregateFunctionOid(sumAggregateName, sumAggregateType); firstSum->aggfnoid = AggregateFunctionOid(sumAggregateName, sumAggregateType);
firstSum->aggtype = get_func_rettype(firstSum->aggfnoid); firstSum->aggtype = get_func_rettype(firstSum->aggfnoid);
firstSum->args = list_make1(firstTargetEntry); firstSum->args = list_make1(firstTargetEntry);
@ -2024,12 +1986,13 @@ MasterAverageExpression(Oid sumAggregateType, Oid countAggregateType,
firstSum->aggsplit = AGGSPLIT_SIMPLE; firstSum->aggsplit = AGGSPLIT_SIMPLE;
/* create the second argument for sum(column2) */ /* create the second argument for sum(column2) */
secondColumn = makeVar(masterTableId, (*columnId), countAggregateType, Var *secondColumn = makeVar(masterTableId, (*columnId), countAggregateType,
defaultTypeMod, countTypeCollationId, defaultLevelsUp); defaultTypeMod, countTypeCollationId, defaultLevelsUp);
secondTargetEntry = makeTargetEntry((Expr *) secondColumn, argumentId, NULL, false); TargetEntry *secondTargetEntry = makeTargetEntry((Expr *) secondColumn, argumentId,
NULL, false);
(*columnId)++; (*columnId)++;
secondSum = makeNode(Aggref); Aggref *secondSum = makeNode(Aggref);
secondSum->aggfnoid = AggregateFunctionOid(sumAggregateName, countAggregateType); secondSum->aggfnoid = AggregateFunctionOid(sumAggregateName, countAggregateType);
secondSum->aggtype = get_func_rettype(secondSum->aggfnoid); secondSum->aggtype = get_func_rettype(secondSum->aggfnoid);
secondSum->args = list_make1(secondTargetEntry); secondSum->args = list_make1(secondTargetEntry);
@ -2042,8 +2005,9 @@ MasterAverageExpression(Oid sumAggregateType, Oid countAggregateType,
* Build the division operator between these two aggregates. This function * Build the division operator between these two aggregates. This function
* will convert the types of the aggregates if necessary. * will convert the types of the aggregates if necessary.
*/ */
operatorNameList = list_make1(makeString(DIVISION_OPER_NAME)); List *operatorNameList = list_make1(makeString(DIVISION_OPER_NAME));
opExpr = make_op(NULL, operatorNameList, (Node *) firstSum, (Node *) secondSum, NULL, Expr *opExpr = make_op(NULL, operatorNameList, (Node *) firstSum, (Node *) secondSum,
NULL,
-1); -1);
return opExpr; return opExpr;
@ -2061,7 +2025,6 @@ AddTypeConversion(Node *originalAggregate, Node *newExpression)
Oid newTypeId = exprType(newExpression); Oid newTypeId = exprType(newExpression);
Oid originalTypeId = exprType(originalAggregate); Oid originalTypeId = exprType(originalAggregate);
int32 originalTypeMod = exprTypmod(originalAggregate); int32 originalTypeMod = exprTypmod(originalAggregate);
Node *typeConvertedExpression = NULL;
/* nothing to do if the two types are the same */ /* nothing to do if the two types are the same */
if (originalTypeId == newTypeId) if (originalTypeId == newTypeId)
@ -2070,7 +2033,7 @@ AddTypeConversion(Node *originalAggregate, Node *newExpression)
} }
/* otherwise, add a type conversion function */ /* otherwise, add a type conversion function */
typeConvertedExpression = coerce_to_target_type(NULL, newExpression, newTypeId, Node *typeConvertedExpression = coerce_to_target_type(NULL, newExpression, newTypeId,
originalTypeId, originalTypeMod, originalTypeId, originalTypeMod,
COERCION_EXPLICIT, COERCION_EXPLICIT,
COERCE_EXPLICIT_CAST, -1); COERCE_EXPLICIT_CAST, -1);
@ -2090,10 +2053,7 @@ static MultiExtendedOp *
WorkerExtendedOpNode(MultiExtendedOp *originalOpNode, WorkerExtendedOpNode(MultiExtendedOp *originalOpNode,
ExtendedOpNodeProperties *extendedOpNodeProperties) ExtendedOpNodeProperties *extendedOpNodeProperties)
{ {
MultiExtendedOp *workerExtendedOpNode = NULL;
Index nextSortGroupRefIndex = 0;
bool distinctPreventsLimitPushdown = false; bool distinctPreventsLimitPushdown = false;
bool groupByExtended = false;
bool groupedByDisjointPartitionColumn = bool groupedByDisjointPartitionColumn =
extendedOpNodeProperties->groupedByDisjointPartitionColumn; extendedOpNodeProperties->groupedByDisjointPartitionColumn;
@ -2125,7 +2085,7 @@ WorkerExtendedOpNode(MultiExtendedOp *originalOpNode,
memset(&queryOrderByLimit, 0, sizeof(queryGroupClause)); memset(&queryOrderByLimit, 0, sizeof(queryGroupClause));
/* calculate the next sort group index based on the original target list */ /* calculate the next sort group index based on the original target list */
nextSortGroupRefIndex = GetNextSortGroupRef(originalTargetEntryList); Index nextSortGroupRefIndex = GetNextSortGroupRef(originalTargetEntryList);
/* targetProjectionNumber starts from 1 */ /* targetProjectionNumber starts from 1 */
queryTargetList.targetProjectionNumber = 1; queryTargetList.targetProjectionNumber = 1;
@ -2167,7 +2127,7 @@ WorkerExtendedOpNode(MultiExtendedOp *originalOpNode,
* (1) Creating a new group by clause during aggregate mutation, or * (1) Creating a new group by clause during aggregate mutation, or
* (2) Distinct clause is not pushed down * (2) Distinct clause is not pushed down
*/ */
groupByExtended = bool groupByExtended =
list_length(queryGroupClause.groupClauseList) > originalGroupClauseLength; list_length(queryGroupClause.groupClauseList) > originalGroupClauseLength;
if (!groupByExtended && !distinctPreventsLimitPushdown) if (!groupByExtended && !distinctPreventsLimitPushdown)
{ {
@ -2188,7 +2148,7 @@ WorkerExtendedOpNode(MultiExtendedOp *originalOpNode,
} }
/* finally, fill the extended op node with the data we gathered */ /* finally, fill the extended op node with the data we gathered */
workerExtendedOpNode = CitusMakeNode(MultiExtendedOp); MultiExtendedOp *workerExtendedOpNode = CitusMakeNode(MultiExtendedOp);
workerExtendedOpNode->targetList = queryTargetList.targetEntryList; workerExtendedOpNode->targetList = queryTargetList.targetEntryList;
workerExtendedOpNode->groupClauseList = queryGroupClause.groupClauseList; workerExtendedOpNode->groupClauseList = queryGroupClause.groupClauseList;
@ -2303,9 +2263,7 @@ ProcessHavingClauseForWorkerQuery(Node *originalHavingQual,
QueryTargetList *queryTargetList, QueryTargetList *queryTargetList,
QueryGroupClause *queryGroupClause) QueryGroupClause *queryGroupClause)
{ {
List *newExpressionList = NIL;
TargetEntry *targetEntry = NULL; TargetEntry *targetEntry = NULL;
WorkerAggregateWalkerContext *workerAggContext = NULL;
if (originalHavingQual == NULL) if (originalHavingQual == NULL)
{ {
@ -2314,13 +2272,14 @@ ProcessHavingClauseForWorkerQuery(Node *originalHavingQual,
*workerHavingQual = NULL; *workerHavingQual = NULL;
workerAggContext = palloc0(sizeof(WorkerAggregateWalkerContext)); WorkerAggregateWalkerContext *workerAggContext = palloc0(
sizeof(WorkerAggregateWalkerContext));
workerAggContext->expressionList = NIL; workerAggContext->expressionList = NIL;
workerAggContext->pullDistinctColumns = extendedOpNodeProperties->pullDistinctColumns; workerAggContext->pullDistinctColumns = extendedOpNodeProperties->pullDistinctColumns;
workerAggContext->createGroupByClause = false; workerAggContext->createGroupByClause = false;
WorkerAggregateWalker(originalHavingQual, workerAggContext); WorkerAggregateWalker(originalHavingQual, workerAggContext);
newExpressionList = workerAggContext->expressionList; List *newExpressionList = workerAggContext->expressionList;
ExpandWorkerTargetEntry(newExpressionList, targetEntry, ExpandWorkerTargetEntry(newExpressionList, targetEntry,
workerAggContext->createGroupByClause, workerAggContext->createGroupByClause,
@ -2385,7 +2344,6 @@ ProcessDistinctClauseForWorkerQuery(List *distinctClause, bool hasDistinctOn,
bool *distinctPreventsLimitPushdown) bool *distinctPreventsLimitPushdown)
{ {
bool distinctClauseSupersetofGroupClause = false; bool distinctClauseSupersetofGroupClause = false;
bool shouldPushdownDistinct = false;
if (distinctClause == NIL) if (distinctClause == NIL)
{ {
@ -2419,7 +2377,7 @@ ProcessDistinctClauseForWorkerQuery(List *distinctClause, bool hasDistinctOn,
* distinct pushdown if distinct clause is missing some entries that * distinct pushdown if distinct clause is missing some entries that
* group by clause has. * group by clause has.
*/ */
shouldPushdownDistinct = !queryHasAggregates && bool shouldPushdownDistinct = !queryHasAggregates &&
distinctClauseSupersetofGroupClause; distinctClauseSupersetofGroupClause;
if (shouldPushdownDistinct) if (shouldPushdownDistinct)
{ {
@ -2524,8 +2482,6 @@ ProcessLimitOrderByForWorkerQuery(OrderByLimitReference orderByLimitReference,
QueryOrderByLimit *queryOrderByLimit, QueryOrderByLimit *queryOrderByLimit,
QueryTargetList *queryTargetList) QueryTargetList *queryTargetList)
{ {
List *newTargetEntryListForSortClauses = NIL;
queryOrderByLimit->workerLimitCount = queryOrderByLimit->workerLimitCount =
WorkerLimitCount(originalLimitCount, limitOffset, orderByLimitReference); WorkerLimitCount(originalLimitCount, limitOffset, orderByLimitReference);
@ -2539,7 +2495,7 @@ ProcessLimitOrderByForWorkerQuery(OrderByLimitReference orderByLimitReference,
* TODO: Do we really need to add the target entries if we're not pushing * TODO: Do we really need to add the target entries if we're not pushing
* down ORDER BY? * down ORDER BY?
*/ */
newTargetEntryListForSortClauses = List *newTargetEntryListForSortClauses =
GenerateNewTargetEntriesForSortClauses(originalTargetList, GenerateNewTargetEntriesForSortClauses(originalTargetList,
queryOrderByLimit->workerSortClauseList, queryOrderByLimit->workerSortClauseList,
&(queryTargetList->targetProjectionNumber), &(queryTargetList->targetProjectionNumber),
@ -2634,10 +2590,9 @@ ExpandWorkerTargetEntry(List *expressionList, TargetEntry *originalTargetEntry,
foreach(newExpressionCell, expressionList) foreach(newExpressionCell, expressionList)
{ {
Expr *newExpression = (Expr *) lfirst(newExpressionCell); Expr *newExpression = (Expr *) lfirst(newExpressionCell);
TargetEntry *newTargetEntry = NULL;
/* generate and add the new target entry to the target list */ /* generate and add the new target entry to the target list */
newTargetEntry = TargetEntry *newTargetEntry =
GenerateWorkerTargetEntry(originalTargetEntry, newExpression, GenerateWorkerTargetEntry(originalTargetEntry, newExpression,
queryTargetList->targetProjectionNumber); queryTargetList->targetProjectionNumber);
(queryTargetList->targetProjectionNumber)++; (queryTargetList->targetProjectionNumber)++;
@ -2749,14 +2704,12 @@ AppendTargetEntryToGroupClause(TargetEntry *targetEntry,
QueryGroupClause *queryGroupClause) QueryGroupClause *queryGroupClause)
{ {
Expr *targetExpr PG_USED_FOR_ASSERTS_ONLY = targetEntry->expr; Expr *targetExpr PG_USED_FOR_ASSERTS_ONLY = targetEntry->expr;
Var *targetColumn = NULL;
SortGroupClause *groupByClause = NULL;
/* we currently only support appending Var target entries */ /* we currently only support appending Var target entries */
AssertArg(IsA(targetExpr, Var)); AssertArg(IsA(targetExpr, Var));
targetColumn = (Var *) targetEntry->expr; Var *targetColumn = (Var *) targetEntry->expr;
groupByClause = CreateSortGroupClause(targetColumn); SortGroupClause *groupByClause = CreateSortGroupClause(targetColumn);
/* the target entry should have an index */ /* the target entry should have an index */
targetEntry->ressortgroupref = *queryGroupClause->nextSortGroupRefIndex; targetEntry->ressortgroupref = *queryGroupClause->nextSortGroupRefIndex;
@ -2854,10 +2807,6 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
const int hashArgumentCount = 2; const int hashArgumentCount = 2;
const int addArgumentCount = 2; const int addArgumentCount = 2;
TargetEntry *hashedColumnArgument = NULL;
TargetEntry *storageSizeArgument = NULL;
List *addAggregateArgumentList = NIL;
Aggref *addAggregateFunction = NULL;
/* init hll_hash() related variables */ /* init hll_hash() related variables */
Oid argumentType = AggregateArgumentType(originalAggregate); Oid argumentType = AggregateArgumentType(originalAggregate);
@ -2888,13 +2837,14 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
hashFunction->args = list_make1(argumentExpression); hashFunction->args = list_make1(argumentExpression);
/* construct hll_add_agg() expression */ /* construct hll_add_agg() expression */
hashedColumnArgument = makeTargetEntry((Expr *) hashFunction, TargetEntry *hashedColumnArgument = makeTargetEntry((Expr *) hashFunction,
firstArgumentId, NULL, false); firstArgumentId, NULL, false);
storageSizeArgument = makeTargetEntry((Expr *) logOfStorageSizeConst, TargetEntry *storageSizeArgument = makeTargetEntry((Expr *) logOfStorageSizeConst,
secondArgumentId, NULL, false); secondArgumentId, NULL, false);
addAggregateArgumentList = list_make2(hashedColumnArgument, storageSizeArgument); List *addAggregateArgumentList = list_make2(hashedColumnArgument,
storageSizeArgument);
addAggregateFunction = makeNode(Aggref); Aggref *addAggregateFunction = makeNode(Aggref);
addAggregateFunction->aggfnoid = addFunctionId; addAggregateFunction->aggfnoid = addFunctionId;
addAggregateFunction->aggtype = hllType; addAggregateFunction->aggtype = hllType;
addAggregateFunction->args = addAggregateArgumentList; addAggregateFunction->args = addAggregateArgumentList;
@ -2964,16 +2914,14 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
if (combine != InvalidOid) if (combine != InvalidOid)
{ {
Const *aggOidParam = NULL;
Aggref *newWorkerAggregate = NULL;
List *aggArguments = NIL;
ListCell *originalAggArgCell; ListCell *originalAggArgCell;
Oid workerPartialId = WorkerPartialAggOid(); Oid workerPartialId = WorkerPartialAggOid();
aggOidParam = makeConst(REGPROCEDUREOID, -1, InvalidOid, sizeof(Oid), Const *aggOidParam = makeConst(REGPROCEDUREOID, -1, InvalidOid, sizeof(Oid),
ObjectIdGetDatum(originalAggregate->aggfnoid), false, ObjectIdGetDatum(originalAggregate->aggfnoid),
false,
true); true);
aggArguments = list_make1(makeTargetEntry((Expr *) aggOidParam, 1, NULL, List *aggArguments = list_make1(makeTargetEntry((Expr *) aggOidParam, 1, NULL,
false)); false));
foreach(originalAggArgCell, originalAggregate->args) foreach(originalAggArgCell, originalAggregate->args)
{ {
@ -2984,7 +2932,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
} }
/* worker_partial_agg(agg, ...args) */ /* worker_partial_agg(agg, ...args) */
newWorkerAggregate = makeNode(Aggref); Aggref *newWorkerAggregate = makeNode(Aggref);
newWorkerAggregate->aggfnoid = workerPartialId; newWorkerAggregate->aggfnoid = workerPartialId;
newWorkerAggregate->aggtype = CSTRINGOID; newWorkerAggregate->aggtype = CSTRINGOID;
newWorkerAggregate->args = aggArguments; newWorkerAggregate->args = aggArguments;
@ -3030,35 +2978,27 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
static AggregateType static AggregateType
GetAggregateType(Oid aggFunctionId) GetAggregateType(Oid aggFunctionId)
{ {
char *aggregateProcName = NULL;
uint32 aggregateCount = 0;
uint32 aggregateIndex = 0;
bool found = false;
/* look up the function name */ /* look up the function name */
aggregateProcName = get_func_name(aggFunctionId); char *aggregateProcName = get_func_name(aggFunctionId);
if (aggregateProcName == NULL) if (aggregateProcName == NULL)
{ {
ereport(ERROR, (errmsg("citus cache lookup failed for function %u", ereport(ERROR, (errmsg("citus cache lookup failed for function %u",
aggFunctionId))); aggFunctionId)));
} }
aggregateCount = lengthof(AggregateNames); uint32 aggregateCount = lengthof(AggregateNames);
Assert(AGGREGATE_INVALID_FIRST == 0); Assert(AGGREGATE_INVALID_FIRST == 0);
for (aggregateIndex = 1; aggregateIndex < aggregateCount; aggregateIndex++) for (uint32 aggregateIndex = 1; aggregateIndex < aggregateCount; aggregateIndex++)
{ {
const char *aggregateName = AggregateNames[aggregateIndex]; const char *aggregateName = AggregateNames[aggregateIndex];
if (strncmp(aggregateName, aggregateProcName, NAMEDATALEN) == 0) if (strncmp(aggregateName, aggregateProcName, NAMEDATALEN) == 0)
{ {
found = true; return aggregateIndex;
break;
} }
} }
if (!found)
{
if (AggregateEnabledCustom(aggFunctionId)) if (AggregateEnabledCustom(aggFunctionId))
{ {
return AGGREGATE_CUSTOM; return AGGREGATE_CUSTOM;
@ -3067,9 +3007,6 @@ GetAggregateType(Oid aggFunctionId)
ereport(ERROR, (errmsg("unsupported aggregate function %s", aggregateProcName))); ereport(ERROR, (errmsg("unsupported aggregate function %s", aggregateProcName)));
} }
return aggregateIndex;
}
/* Extracts the type of the argument over which the aggregate is operating. */ /* Extracts the type of the argument over which the aggregate is operating. */
static Oid static Oid
@ -3093,18 +3030,12 @@ AggregateArgumentType(Aggref *aggregate)
static bool static bool
AggregateEnabledCustom(Oid aggregateOid) AggregateEnabledCustom(Oid aggregateOid)
{ {
HeapTuple aggTuple; HeapTuple aggTuple = SearchSysCache1(AGGFNOID, aggregateOid);
Form_pg_aggregate aggform;
HeapTuple typeTuple;
Form_pg_type typeform;
bool supportsSafeCombine;
aggTuple = SearchSysCache1(AGGFNOID, aggregateOid);
if (!HeapTupleIsValid(aggTuple)) if (!HeapTupleIsValid(aggTuple))
{ {
elog(ERROR, "citus cache lookup failed."); elog(ERROR, "citus cache lookup failed.");
} }
aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple); Form_pg_aggregate aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
if (aggform->aggcombinefn == InvalidOid) if (aggform->aggcombinefn == InvalidOid)
{ {
@ -3112,14 +3043,14 @@ AggregateEnabledCustom(Oid aggregateOid)
return false; return false;
} }
typeTuple = SearchSysCache1(TYPEOID, aggform->aggtranstype); HeapTuple typeTuple = SearchSysCache1(TYPEOID, aggform->aggtranstype);
if (!HeapTupleIsValid(typeTuple)) if (!HeapTupleIsValid(typeTuple))
{ {
elog(ERROR, "citus cache lookup failed."); elog(ERROR, "citus cache lookup failed.");
} }
typeform = (Form_pg_type) GETSTRUCT(typeTuple); Form_pg_type typeform = (Form_pg_type) GETSTRUCT(typeTuple);
supportsSafeCombine = typeform->typtype != TYPTYPE_PSEUDO; bool supportsSafeCombine = typeform->typtype != TYPTYPE_PSEUDO;
ReleaseSysCache(aggTuple); ReleaseSysCache(aggTuple);
ReleaseSysCache(typeTuple); ReleaseSysCache(typeTuple);
@ -3137,23 +3068,20 @@ static Oid
AggregateFunctionOid(const char *functionName, Oid inputType) AggregateFunctionOid(const char *functionName, Oid inputType)
{ {
Oid functionOid = InvalidOid; Oid functionOid = InvalidOid;
Relation procRelation = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
HeapTuple heapTuple = NULL;
procRelation = heap_open(ProcedureRelationId, AccessShareLock); Relation procRelation = heap_open(ProcedureRelationId, AccessShareLock);
ScanKeyInit(&scanKey[0], Anum_pg_proc_proname, ScanKeyInit(&scanKey[0], Anum_pg_proc_proname,
BTEqualStrategyNumber, F_NAMEEQ, CStringGetDatum(functionName)); BTEqualStrategyNumber, F_NAMEEQ, CStringGetDatum(functionName));
scanDescriptor = systable_beginscan(procRelation, SysScanDesc scanDescriptor = systable_beginscan(procRelation,
ProcedureNameArgsNspIndexId, true, ProcedureNameArgsNspIndexId, true,
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
/* loop until we find the right function */ /* loop until we find the right function */
heapTuple = systable_getnext(scanDescriptor); HeapTuple heapTuple = systable_getnext(scanDescriptor);
while (HeapTupleIsValid(heapTuple)) while (HeapTupleIsValid(heapTuple))
{ {
Form_pg_proc procForm = (Form_pg_proc) GETSTRUCT(heapTuple); Form_pg_proc procForm = (Form_pg_proc) GETSTRUCT(heapTuple);
@ -3253,9 +3181,7 @@ CoordCombineAggOid()
static Oid static Oid
TypeOid(Oid schemaId, const char *typeName) TypeOid(Oid schemaId, const char *typeName)
{ {
Oid typeOid; Oid typeOid = GetSysCacheOid2Compat(TYPENAMENSP, Anum_pg_type_oid,
typeOid = GetSysCacheOid2Compat(TYPENAMENSP, Anum_pg_type_oid,
PointerGetDatum(typeName), PointerGetDatum(typeName),
ObjectIdGetDatum(schemaId)); ObjectIdGetDatum(schemaId));
@ -3410,8 +3336,6 @@ ErrorIfContainsUnsupportedAggregate(MultiNode *logicalPlanNode)
foreach(expressionCell, expressionList) foreach(expressionCell, expressionList)
{ {
Node *expression = (Node *) lfirst(expressionCell); Node *expression = (Node *) lfirst(expressionCell);
Aggref *aggregateExpression = NULL;
AggregateType aggregateType = AGGREGATE_INVALID_FIRST;
/* only consider aggregate expressions */ /* only consider aggregate expressions */
if (!IsA(expression, Aggref)) if (!IsA(expression, Aggref))
@ -3420,8 +3344,8 @@ ErrorIfContainsUnsupportedAggregate(MultiNode *logicalPlanNode)
} }
/* GetAggregateType errors out on unsupported aggregate types */ /* GetAggregateType errors out on unsupported aggregate types */
aggregateExpression = (Aggref *) expression; Aggref *aggregateExpression = (Aggref *) expression;
aggregateType = GetAggregateType(aggregateExpression->aggfnoid); AggregateType aggregateType = GetAggregateType(aggregateExpression->aggfnoid);
Assert(aggregateType != AGGREGATE_INVALID_FIRST); Assert(aggregateType != AGGREGATE_INVALID_FIRST);
/* /*
@ -3514,11 +3438,6 @@ ErrorIfUnsupportedAggregateDistinct(Aggref *aggregateExpression,
{ {
char *errorDetail = NULL; char *errorDetail = NULL;
bool distinctSupported = true; bool distinctSupported = true;
List *repartitionNodeList = NIL;
Var *distinctColumn = NULL;
List *tableNodeList = NIL;
List *extendedOpNodeList = NIL;
MultiExtendedOp *extendedOpNode = NULL;
AggregateType aggregateType = GetAggregateType(aggregateExpression->aggfnoid); AggregateType aggregateType = GetAggregateType(aggregateExpression->aggfnoid);
@ -3588,18 +3507,18 @@ ErrorIfUnsupportedAggregateDistinct(Aggref *aggregateExpression,
} }
} }
repartitionNodeList = FindNodesOfType(logicalPlanNode, T_MultiPartition); List *repartitionNodeList = FindNodesOfType(logicalPlanNode, T_MultiPartition);
if (repartitionNodeList != NIL) if (repartitionNodeList != NIL)
{ {
distinctSupported = false; distinctSupported = false;
errorDetail = "aggregate (distinct) with table repartitioning is unsupported"; errorDetail = "aggregate (distinct) with table repartitioning is unsupported";
} }
tableNodeList = FindNodesOfType(logicalPlanNode, T_MultiTable); List *tableNodeList = FindNodesOfType(logicalPlanNode, T_MultiTable);
extendedOpNodeList = FindNodesOfType(logicalPlanNode, T_MultiExtendedOp); List *extendedOpNodeList = FindNodesOfType(logicalPlanNode, T_MultiExtendedOp);
extendedOpNode = (MultiExtendedOp *) linitial(extendedOpNodeList); MultiExtendedOp *extendedOpNode = (MultiExtendedOp *) linitial(extendedOpNodeList);
distinctColumn = AggregateDistinctColumn(aggregateExpression); Var *distinctColumn = AggregateDistinctColumn(aggregateExpression);
if (distinctSupported) if (distinctSupported)
{ {
if (distinctColumn == NULL) if (distinctColumn == NULL)
@ -3664,29 +3583,26 @@ ErrorIfUnsupportedAggregateDistinct(Aggref *aggregateExpression,
static Var * static Var *
AggregateDistinctColumn(Aggref *aggregateExpression) AggregateDistinctColumn(Aggref *aggregateExpression)
{ {
Var *aggregateColumn = NULL;
int aggregateArgumentCount = 0;
TargetEntry *aggregateTargetEntry = NULL;
/* only consider aggregates with distincts */ /* only consider aggregates with distincts */
if (!aggregateExpression->aggdistinct) if (!aggregateExpression->aggdistinct)
{ {
return NULL; return NULL;
} }
aggregateArgumentCount = list_length(aggregateExpression->args); int aggregateArgumentCount = list_length(aggregateExpression->args);
if (aggregateArgumentCount != 1) if (aggregateArgumentCount != 1)
{ {
return NULL; return NULL;
} }
aggregateTargetEntry = (TargetEntry *) linitial(aggregateExpression->args); TargetEntry *aggregateTargetEntry = (TargetEntry *) linitial(
aggregateExpression->args);
if (!IsA(aggregateTargetEntry->expr, Var)) if (!IsA(aggregateTargetEntry->expr, Var))
{ {
return NULL; return NULL;
} }
aggregateColumn = (Var *) aggregateTargetEntry->expr; Var *aggregateColumn = (Var *) aggregateTargetEntry->expr;
return aggregateColumn; return aggregateColumn;
} }
@ -3710,8 +3626,6 @@ TablePartitioningSupportsDistinct(List *tableNodeList, MultiExtendedOp *opNode,
MultiTable *tableNode = (MultiTable *) lfirst(tableNodeCell); MultiTable *tableNode = (MultiTable *) lfirst(tableNodeCell);
Oid relationId = tableNode->relationId; Oid relationId = tableNode->relationId;
bool tableDistinctSupported = false; bool tableDistinctSupported = false;
char partitionMethod = 0;
List *shardList = NIL;
if (relationId == SUBQUERY_RELATION_ID || if (relationId == SUBQUERY_RELATION_ID ||
relationId == SUBQUERY_PUSHDOWN_RELATION_ID) relationId == SUBQUERY_PUSHDOWN_RELATION_ID)
@ -3720,7 +3634,7 @@ TablePartitioningSupportsDistinct(List *tableNodeList, MultiExtendedOp *opNode,
} }
/* if table has one shard, task results don't overlap */ /* if table has one shard, task results don't overlap */
shardList = LoadShardList(relationId); List *shardList = LoadShardList(relationId);
if (list_length(shardList) == 1) if (list_length(shardList) == 1)
{ {
continue; continue;
@ -3730,13 +3644,12 @@ TablePartitioningSupportsDistinct(List *tableNodeList, MultiExtendedOp *opNode,
* We need to check that task results don't overlap. We can only do this * We need to check that task results don't overlap. We can only do this
* if table is range partitioned. * if table is range partitioned.
*/ */
partitionMethod = PartitionMethod(relationId); char partitionMethod = PartitionMethod(relationId);
if (partitionMethod == DISTRIBUTE_BY_RANGE || if (partitionMethod == DISTRIBUTE_BY_RANGE ||
partitionMethod == DISTRIBUTE_BY_HASH) partitionMethod == DISTRIBUTE_BY_HASH)
{ {
Var *tablePartitionColumn = tableNode->partitionColumn; Var *tablePartitionColumn = tableNode->partitionColumn;
bool groupedByPartitionColumn = false;
if (aggregateType == AGGREGATE_COUNT) if (aggregateType == AGGREGATE_COUNT)
{ {
@ -3752,7 +3665,7 @@ TablePartitioningSupportsDistinct(List *tableNodeList, MultiExtendedOp *opNode,
} }
/* if results are grouped by partition column, we can push down */ /* if results are grouped by partition column, we can push down */
groupedByPartitionColumn = GroupedByColumn(opNode->groupClauseList, bool groupedByPartitionColumn = GroupedByColumn(opNode->groupClauseList,
opNode->targetList, opNode->targetList,
tablePartitionColumn); tablePartitionColumn);
if (groupedByPartitionColumn) if (groupedByPartitionColumn)
@ -3901,8 +3814,6 @@ FindReferencedTableColumn(Expr *columnExpression, List *parentQueryList, Query *
{ {
Var *candidateColumn = NULL; Var *candidateColumn = NULL;
List *rangetableList = query->rtable; List *rangetableList = query->rtable;
Index rangeTableEntryIndex = 0;
RangeTblEntry *rangeTableEntry = NULL;
Expr *strippedColumnExpression = (Expr *) strip_implicit_coercions( Expr *strippedColumnExpression = (Expr *) strip_implicit_coercions(
(Node *) columnExpression); (Node *) columnExpression);
@ -3940,8 +3851,8 @@ FindReferencedTableColumn(Expr *columnExpression, List *parentQueryList, Query *
return; return;
} }
rangeTableEntryIndex = candidateColumn->varno - 1; Index rangeTableEntryIndex = candidateColumn->varno - 1;
rangeTableEntry = list_nth(rangetableList, rangeTableEntryIndex); RangeTblEntry *rangeTableEntry = list_nth(rangetableList, rangeTableEntryIndex);
if (rangeTableEntry->rtekind == RTE_RELATION) if (rangeTableEntry->rtekind == RTE_RELATION)
{ {
@ -4402,7 +4313,6 @@ HasOrderByComplexExpression(List *sortClauseList, List *targetList)
{ {
SortGroupClause *sortClause = (SortGroupClause *) lfirst(sortClauseCell); SortGroupClause *sortClause = (SortGroupClause *) lfirst(sortClauseCell);
Node *sortExpression = get_sortgroupclause_expr(sortClause, targetList); Node *sortExpression = get_sortgroupclause_expr(sortClause, targetList);
bool nestedAggregate = false;
/* simple aggregate functions are ok */ /* simple aggregate functions are ok */
if (IsA(sortExpression, Aggref)) if (IsA(sortExpression, Aggref))
@ -4410,7 +4320,7 @@ HasOrderByComplexExpression(List *sortClauseList, List *targetList)
continue; continue;
} }
nestedAggregate = contain_agg_clause(sortExpression); bool nestedAggregate = contain_agg_clause(sortExpression);
if (nestedAggregate) if (nestedAggregate)
{ {
hasOrderByComplexExpression = true; hasOrderByComplexExpression = true;
@ -4430,20 +4340,17 @@ static bool
HasOrderByHllType(List *sortClauseList, List *targetList) HasOrderByHllType(List *sortClauseList, List *targetList)
{ {
bool hasOrderByHllType = false; bool hasOrderByHllType = false;
Oid hllId = InvalidOid;
Oid hllSchemaOid = InvalidOid;
Oid hllTypeId = InvalidOid;
ListCell *sortClauseCell = NULL; ListCell *sortClauseCell = NULL;
/* check whether HLL is loaded */ /* check whether HLL is loaded */
hllId = get_extension_oid(HLL_EXTENSION_NAME, true); Oid hllId = get_extension_oid(HLL_EXTENSION_NAME, true);
if (!OidIsValid(hllId)) if (!OidIsValid(hllId))
{ {
return hasOrderByHllType; return hasOrderByHllType;
} }
hllSchemaOid = get_extension_schema(hllId); Oid hllSchemaOid = get_extension_schema(hllId);
hllTypeId = TypeOid(hllSchemaOid, HLL_TYPE_NAME); Oid hllTypeId = TypeOid(hllSchemaOid, HLL_TYPE_NAME);
foreach(sortClauseCell, sortClauseList) foreach(sortClauseCell, sortClauseList)
{ {

View File

@ -134,7 +134,6 @@ MultiLogicalPlanCreate(Query *originalQuery, Query *queryTree,
PlannerRestrictionContext *plannerRestrictionContext) PlannerRestrictionContext *plannerRestrictionContext)
{ {
MultiNode *multiQueryNode = NULL; MultiNode *multiQueryNode = NULL;
MultiTreeRoot *rootNode = NULL;
if (ShouldUseSubqueryPushDown(originalQuery, queryTree, plannerRestrictionContext)) if (ShouldUseSubqueryPushDown(originalQuery, queryTree, plannerRestrictionContext))
@ -148,7 +147,7 @@ MultiLogicalPlanCreate(Query *originalQuery, Query *queryTree,
} }
/* add a root node to serve as the permanent handle to the tree */ /* add a root node to serve as the permanent handle to the tree */
rootNode = CitusMakeNode(MultiTreeRoot); MultiTreeRoot *rootNode = CitusMakeNode(MultiTreeRoot);
SetChild((MultiUnaryNode *) rootNode, multiQueryNode); SetChild((MultiUnaryNode *) rootNode, multiQueryNode);
return rootNode; return rootNode;
@ -206,9 +205,7 @@ bool
SingleRelationRepartitionSubquery(Query *queryTree) SingleRelationRepartitionSubquery(Query *queryTree)
{ {
List *rangeTableIndexList = NULL; List *rangeTableIndexList = NULL;
RangeTblEntry *rangeTableEntry = NULL;
List *rangeTableList = queryTree->rtable; List *rangeTableList = queryTree->rtable;
int rangeTableIndex = 0;
/* we don't support subqueries in WHERE */ /* we don't support subqueries in WHERE */
if (queryTree->hasSubLinks) if (queryTree->hasSubLinks)
@ -234,8 +231,8 @@ SingleRelationRepartitionSubquery(Query *queryTree)
return false; return false;
} }
rangeTableIndex = linitial_int(rangeTableIndexList); int rangeTableIndex = linitial_int(rangeTableIndexList);
rangeTableEntry = rt_fetch(rangeTableIndex, rangeTableList); RangeTblEntry *rangeTableEntry = rt_fetch(rangeTableIndex, rangeTableList);
if (rangeTableEntry->rtekind == RTE_RELATION) if (rangeTableEntry->rtekind == RTE_RELATION)
{ {
return true; return true;
@ -413,9 +410,6 @@ QueryContainsDistributedTableRTE(Query *query)
bool bool
IsDistributedTableRTE(Node *node) IsDistributedTableRTE(Node *node)
{ {
RangeTblEntry *rangeTableEntry = NULL;
Oid relationId = InvalidOid;
if (node == NULL) if (node == NULL)
{ {
return false; return false;
@ -426,13 +420,13 @@ IsDistributedTableRTE(Node *node)
return false; return false;
} }
rangeTableEntry = (RangeTblEntry *) node; RangeTblEntry *rangeTableEntry = (RangeTblEntry *) node;
if (rangeTableEntry->rtekind != RTE_RELATION) if (rangeTableEntry->rtekind != RTE_RELATION)
{ {
return false; return false;
} }
relationId = rangeTableEntry->relid; Oid relationId = rangeTableEntry->relid;
if (!IsDistributedTable(relationId) || if (!IsDistributedTable(relationId) ||
PartitionMethod(relationId) == DISTRIBUTE_BY_NONE) PartitionMethod(relationId) == DISTRIBUTE_BY_NONE)
{ {
@ -453,7 +447,6 @@ FullCompositeFieldList(List *compositeFieldList)
bool fullCompositeFieldList = true; bool fullCompositeFieldList = true;
bool *compositeFieldArray = NULL; bool *compositeFieldArray = NULL;
uint32 compositeFieldCount = 0; uint32 compositeFieldCount = 0;
uint32 fieldIndex = 0;
ListCell *fieldSelectCell = NULL; ListCell *fieldSelectCell = NULL;
foreach(fieldSelectCell, compositeFieldList) foreach(fieldSelectCell, compositeFieldList)
@ -490,7 +483,7 @@ FullCompositeFieldList(List *compositeFieldList)
compositeFieldArray[compositeFieldIndex] = true; compositeFieldArray[compositeFieldIndex] = true;
} }
for (fieldIndex = 0; fieldIndex < compositeFieldCount; fieldIndex++) for (uint32 fieldIndex = 0; fieldIndex < compositeFieldCount; fieldIndex++)
{ {
if (!compositeFieldArray[fieldIndex]) if (!compositeFieldArray[fieldIndex])
{ {
@ -523,8 +516,6 @@ CompositeFieldRecursive(Expr *expression, Query *query)
{ {
FieldSelect *compositeField = NULL; FieldSelect *compositeField = NULL;
List *rangetableList = query->rtable; List *rangetableList = query->rtable;
Index rangeTableEntryIndex = 0;
RangeTblEntry *rangeTableEntry = NULL;
Var *candidateColumn = NULL; Var *candidateColumn = NULL;
if (IsA(expression, FieldSelect)) if (IsA(expression, FieldSelect))
@ -542,8 +533,8 @@ CompositeFieldRecursive(Expr *expression, Query *query)
return NULL; return NULL;
} }
rangeTableEntryIndex = candidateColumn->varno - 1; Index rangeTableEntryIndex = candidateColumn->varno - 1;
rangeTableEntry = list_nth(rangetableList, rangeTableEntryIndex); RangeTblEntry *rangeTableEntry = list_nth(rangetableList, rangeTableEntryIndex);
if (rangeTableEntry->rtekind == RTE_SUBQUERY) if (rangeTableEntry->rtekind == RTE_SUBQUERY)
{ {
@ -633,29 +624,24 @@ MultiNodeTree(Query *queryTree)
{ {
List *rangeTableList = queryTree->rtable; List *rangeTableList = queryTree->rtable;
List *targetEntryList = queryTree->targetList; List *targetEntryList = queryTree->targetList;
List *whereClauseList = NIL;
List *joinClauseList = NIL; List *joinClauseList = NIL;
List *joinOrderList = NIL; List *joinOrderList = NIL;
List *tableEntryList = NIL; List *tableEntryList = NIL;
List *tableNodeList = NIL; List *tableNodeList = NIL;
List *collectTableList = NIL; List *collectTableList = NIL;
List *subqueryEntryList = NIL;
MultiNode *joinTreeNode = NULL; MultiNode *joinTreeNode = NULL;
MultiSelect *selectNode = NULL;
MultiProject *projectNode = NULL;
MultiExtendedOp *extendedOpNode = NULL;
MultiNode *currentTopNode = NULL; MultiNode *currentTopNode = NULL;
DeferredErrorMessage *unsupportedQueryError = NULL;
/* verify we can perform distributed planning on this query */ /* verify we can perform distributed planning on this query */
unsupportedQueryError = DeferErrorIfQueryNotSupported(queryTree); DeferredErrorMessage *unsupportedQueryError = DeferErrorIfQueryNotSupported(
queryTree);
if (unsupportedQueryError != NULL) if (unsupportedQueryError != NULL)
{ {
RaiseDeferredError(unsupportedQueryError, ERROR); RaiseDeferredError(unsupportedQueryError, ERROR);
} }
/* extract where clause qualifiers and verify we can plan for them */ /* extract where clause qualifiers and verify we can plan for them */
whereClauseList = WhereClauseList(queryTree->jointree); List *whereClauseList = WhereClauseList(queryTree->jointree);
unsupportedQueryError = DeferErrorIfUnsupportedClause(whereClauseList); unsupportedQueryError = DeferErrorIfUnsupportedClause(whereClauseList);
if (unsupportedQueryError) if (unsupportedQueryError)
{ {
@ -666,29 +652,23 @@ MultiNodeTree(Query *queryTree)
* If we have a subquery, build a multi table node for the subquery and * If we have a subquery, build a multi table node for the subquery and
* add a collect node on top of the multi table node. * add a collect node on top of the multi table node.
*/ */
subqueryEntryList = SubqueryEntryList(queryTree); List *subqueryEntryList = SubqueryEntryList(queryTree);
if (subqueryEntryList != NIL) if (subqueryEntryList != NIL)
{ {
RangeTblEntry *subqueryRangeTableEntry = NULL;
MultiCollect *subqueryCollectNode = CitusMakeNode(MultiCollect); MultiCollect *subqueryCollectNode = CitusMakeNode(MultiCollect);
MultiTable *subqueryNode = NULL;
MultiNode *subqueryExtendedNode = NULL;
Query *subqueryTree = NULL;
List *whereClauseColumnList = NIL;
List *targetListColumnList = NIL;
List *columnList = NIL;
ListCell *columnCell = NULL; ListCell *columnCell = NULL;
/* we only support single subquery in the entry list */ /* we only support single subquery in the entry list */
Assert(list_length(subqueryEntryList) == 1); Assert(list_length(subqueryEntryList) == 1);
subqueryRangeTableEntry = (RangeTblEntry *) linitial(subqueryEntryList); RangeTblEntry *subqueryRangeTableEntry = (RangeTblEntry *) linitial(
subqueryTree = subqueryRangeTableEntry->subquery; subqueryEntryList);
Query *subqueryTree = subqueryRangeTableEntry->subquery;
/* ensure if subquery satisfies preconditions */ /* ensure if subquery satisfies preconditions */
Assert(DeferErrorIfUnsupportedSubqueryRepartition(subqueryTree) == NULL); Assert(DeferErrorIfUnsupportedSubqueryRepartition(subqueryTree) == NULL);
subqueryNode = CitusMakeNode(MultiTable); MultiTable *subqueryNode = CitusMakeNode(MultiTable);
subqueryNode->relationId = SUBQUERY_RELATION_ID; subqueryNode->relationId = SUBQUERY_RELATION_ID;
subqueryNode->rangeTableId = SUBQUERY_RANGE_TABLE_ID; subqueryNode->rangeTableId = SUBQUERY_RANGE_TABLE_ID;
subqueryNode->partitionColumn = NULL; subqueryNode->partitionColumn = NULL;
@ -704,10 +684,10 @@ MultiNodeTree(Query *queryTree)
*/ */
Assert(list_length(subqueryEntryList) == 1); Assert(list_length(subqueryEntryList) == 1);
whereClauseColumnList = pull_var_clause_default((Node *) whereClauseList); List *whereClauseColumnList = pull_var_clause_default((Node *) whereClauseList);
targetListColumnList = pull_var_clause_default((Node *) targetEntryList); List *targetListColumnList = pull_var_clause_default((Node *) targetEntryList);
columnList = list_concat(whereClauseColumnList, targetListColumnList); List *columnList = list_concat(whereClauseColumnList, targetListColumnList);
foreach(columnCell, columnList) foreach(columnCell, columnList)
{ {
Var *column = (Var *) lfirst(columnCell); Var *column = (Var *) lfirst(columnCell);
@ -715,7 +695,7 @@ MultiNodeTree(Query *queryTree)
} }
/* recursively create child nested multitree */ /* recursively create child nested multitree */
subqueryExtendedNode = MultiNodeTree(subqueryTree); MultiNode *subqueryExtendedNode = MultiNodeTree(subqueryTree);
SetChild((MultiUnaryNode *) subqueryCollectNode, (MultiNode *) subqueryNode); SetChild((MultiUnaryNode *) subqueryCollectNode, (MultiNode *) subqueryNode);
SetChild((MultiUnaryNode *) subqueryNode, subqueryExtendedNode); SetChild((MultiUnaryNode *) subqueryNode, subqueryExtendedNode);
@ -751,7 +731,7 @@ MultiNodeTree(Query *queryTree)
Assert(currentTopNode != NULL); Assert(currentTopNode != NULL);
/* build select node if the query has selection criteria */ /* build select node if the query has selection criteria */
selectNode = MultiSelectNode(whereClauseList); MultiSelect *selectNode = MultiSelectNode(whereClauseList);
if (selectNode != NULL) if (selectNode != NULL)
{ {
SetChild((MultiUnaryNode *) selectNode, currentTopNode); SetChild((MultiUnaryNode *) selectNode, currentTopNode);
@ -759,7 +739,7 @@ MultiNodeTree(Query *queryTree)
} }
/* build project node for the columns to project */ /* build project node for the columns to project */
projectNode = MultiProjectNode(targetEntryList); MultiProject *projectNode = MultiProjectNode(targetEntryList);
SetChild((MultiUnaryNode *) projectNode, currentTopNode); SetChild((MultiUnaryNode *) projectNode, currentTopNode);
currentTopNode = (MultiNode *) projectNode; currentTopNode = (MultiNode *) projectNode;
@ -769,7 +749,7 @@ MultiNodeTree(Query *queryTree)
* distinguish between aggregates and expressions; and we address this later * distinguish between aggregates and expressions; and we address this later
* in the logical optimizer. * in the logical optimizer.
*/ */
extendedOpNode = MultiExtendedOpNode(queryTree); MultiExtendedOp *extendedOpNode = MultiExtendedOpNode(queryTree);
SetChild((MultiUnaryNode *) extendedOpNode, currentTopNode); SetChild((MultiUnaryNode *) extendedOpNode, currentTopNode);
currentTopNode = (MultiNode *) extendedOpNode; currentTopNode = (MultiNode *) extendedOpNode;
@ -816,16 +796,13 @@ IsReadIntermediateResultFunction(Node *node)
char * char *
FindIntermediateResultIdIfExists(RangeTblEntry *rte) FindIntermediateResultIdIfExists(RangeTblEntry *rte)
{ {
List *functionList = NULL;
RangeTblFunction *rangeTblfunction = NULL;
FuncExpr *funcExpr = NULL;
char *resultId = NULL; char *resultId = NULL;
Assert(rte->rtekind == RTE_FUNCTION); Assert(rte->rtekind == RTE_FUNCTION);
functionList = rte->functions; List *functionList = rte->functions;
rangeTblfunction = (RangeTblFunction *) linitial(functionList); RangeTblFunction *rangeTblfunction = (RangeTblFunction *) linitial(functionList);
funcExpr = (FuncExpr *) rangeTblfunction->funcexpr; FuncExpr *funcExpr = (FuncExpr *) rangeTblfunction->funcexpr;
if (IsReadIntermediateResultFunction((Node *) funcExpr)) if (IsReadIntermediateResultFunction((Node *) funcExpr))
{ {
@ -850,9 +827,6 @@ DeferredErrorMessage *
DeferErrorIfQueryNotSupported(Query *queryTree) DeferErrorIfQueryNotSupported(Query *queryTree)
{ {
char *errorMessage = NULL; char *errorMessage = NULL;
bool hasTablesample = false;
bool hasUnsupportedJoin = false;
bool hasComplexRangeTableType = false;
bool preconditionsSatisfied = true; bool preconditionsSatisfied = true;
StringInfo errorInfo = NULL; StringInfo errorInfo = NULL;
const char *errorHint = NULL; const char *errorHint = NULL;
@ -922,7 +896,7 @@ DeferErrorIfQueryNotSupported(Query *queryTree)
errorHint = filterHint; errorHint = filterHint;
} }
hasTablesample = HasTablesample(queryTree); bool hasTablesample = HasTablesample(queryTree);
if (hasTablesample) if (hasTablesample)
{ {
preconditionsSatisfied = false; preconditionsSatisfied = false;
@ -930,7 +904,8 @@ DeferErrorIfQueryNotSupported(Query *queryTree)
errorHint = filterHint; errorHint = filterHint;
} }
hasUnsupportedJoin = HasUnsupportedJoinWalker((Node *) queryTree->jointree, NULL); bool hasUnsupportedJoin = HasUnsupportedJoinWalker((Node *) queryTree->jointree,
NULL);
if (hasUnsupportedJoin) if (hasUnsupportedJoin)
{ {
preconditionsSatisfied = false; preconditionsSatisfied = false;
@ -939,7 +914,7 @@ DeferErrorIfQueryNotSupported(Query *queryTree)
errorHint = joinHint; errorHint = joinHint;
} }
hasComplexRangeTableType = HasComplexRangeTableType(queryTree); bool hasComplexRangeTableType = HasComplexRangeTableType(queryTree);
if (hasComplexRangeTableType) if (hasComplexRangeTableType)
{ {
preconditionsSatisfied = false; preconditionsSatisfied = false;
@ -1079,9 +1054,6 @@ DeferErrorIfUnsupportedSubqueryRepartition(Query *subqueryTree)
char *errorDetail = NULL; char *errorDetail = NULL;
bool preconditionsSatisfied = true; bool preconditionsSatisfied = true;
List *joinTreeTableIndexList = NIL; List *joinTreeTableIndexList = NIL;
int rangeTableIndex = 0;
RangeTblEntry *rangeTableEntry = NULL;
Query *innerSubquery = NULL;
if (!subqueryTree->hasAggs) if (!subqueryTree->hasAggs)
{ {
@ -1136,15 +1108,15 @@ DeferErrorIfUnsupportedSubqueryRepartition(Query *subqueryTree)
Assert(list_length(joinTreeTableIndexList) == 1); Assert(list_length(joinTreeTableIndexList) == 1);
/* continue with the inner subquery */ /* continue with the inner subquery */
rangeTableIndex = linitial_int(joinTreeTableIndexList); int rangeTableIndex = linitial_int(joinTreeTableIndexList);
rangeTableEntry = rt_fetch(rangeTableIndex, subqueryTree->rtable); RangeTblEntry *rangeTableEntry = rt_fetch(rangeTableIndex, subqueryTree->rtable);
if (rangeTableEntry->rtekind == RTE_RELATION) if (rangeTableEntry->rtekind == RTE_RELATION)
{ {
return NULL; return NULL;
} }
Assert(rangeTableEntry->rtekind == RTE_SUBQUERY); Assert(rangeTableEntry->rtekind == RTE_SUBQUERY);
innerSubquery = rangeTableEntry->subquery; Query *innerSubquery = rangeTableEntry->subquery;
/* recursively continue to the inner subqueries */ /* recursively continue to the inner subqueries */
return DeferErrorIfUnsupportedSubqueryRepartition(innerSubquery); return DeferErrorIfUnsupportedSubqueryRepartition(innerSubquery);
@ -1225,10 +1197,9 @@ WhereClauseList(FromExpr *fromExpr)
{ {
FromExpr *fromExprCopy = copyObject(fromExpr); FromExpr *fromExprCopy = copyObject(fromExpr);
QualifierWalkerContext *walkerContext = palloc0(sizeof(QualifierWalkerContext)); QualifierWalkerContext *walkerContext = palloc0(sizeof(QualifierWalkerContext));
List *whereClauseList = NIL;
ExtractFromExpressionWalker((Node *) fromExprCopy, walkerContext); ExtractFromExpressionWalker((Node *) fromExprCopy, walkerContext);
whereClauseList = walkerContext->baseQualifierList; List *whereClauseList = walkerContext->baseQualifierList;
return whereClauseList; return whereClauseList;
} }
@ -1335,7 +1306,6 @@ JoinClauseList(List *whereClauseList)
static bool static bool
ExtractFromExpressionWalker(Node *node, QualifierWalkerContext *walkerContext) ExtractFromExpressionWalker(Node *node, QualifierWalkerContext *walkerContext)
{ {
bool walkerResult = false;
if (node == NULL) if (node == NULL)
{ {
return false; return false;
@ -1406,7 +1376,7 @@ ExtractFromExpressionWalker(Node *node, QualifierWalkerContext *walkerContext)
} }
} }
walkerResult = expression_tree_walker(node, ExtractFromExpressionWalker, bool walkerResult = expression_tree_walker(node, ExtractFromExpressionWalker,
(void *) walkerContext); (void *) walkerContext);
return walkerResult; return walkerResult;
@ -1421,10 +1391,6 @@ ExtractFromExpressionWalker(Node *node, QualifierWalkerContext *walkerContext)
bool bool
IsJoinClause(Node *clause) IsJoinClause(Node *clause)
{ {
OpExpr *operatorExpression = NULL;
bool equalsOperator = false;
List *varList = NIL;
Var *initialVar = NULL;
Var *var = NULL; Var *var = NULL;
if (!IsA(clause, OpExpr)) if (!IsA(clause, OpExpr))
@ -1432,8 +1398,8 @@ IsJoinClause(Node *clause)
return false; return false;
} }
operatorExpression = castNode(OpExpr, clause); OpExpr *operatorExpression = castNode(OpExpr, clause);
equalsOperator = OperatorImplementsEquality(operatorExpression->opno); bool equalsOperator = OperatorImplementsEquality(operatorExpression->opno);
if (!equalsOperator) if (!equalsOperator)
{ {
@ -1452,13 +1418,13 @@ IsJoinClause(Node *clause)
* take all column references from the clause, if we find 2 column references from a * take all column references from the clause, if we find 2 column references from a
* different relation we assume this is a join clause * different relation we assume this is a join clause
*/ */
varList = pull_var_clause_default(clause); List *varList = pull_var_clause_default(clause);
if (list_length(varList) <= 0) if (list_length(varList) <= 0)
{ {
/* no column references in query, not describing a join */ /* no column references in query, not describing a join */
return false; return false;
} }
initialVar = castNode(Var, linitial(varList)); Var *initialVar = castNode(Var, linitial(varList));
foreach_ptr(var, varList) foreach_ptr(var, varList)
{ {
@ -1635,15 +1601,16 @@ MultiJoinTree(List *joinOrderList, List *collectTableList, List *joinWhereClause
JoinRuleType joinRuleType = joinOrderNode->joinRuleType; JoinRuleType joinRuleType = joinOrderNode->joinRuleType;
JoinType joinType = joinOrderNode->joinType; JoinType joinType = joinOrderNode->joinType;
Var *partitionColumn = joinOrderNode->partitionColumn; Var *partitionColumn = joinOrderNode->partitionColumn;
MultiNode *newJoinNode = NULL;
List *joinClauseList = joinOrderNode->joinClauseList; List *joinClauseList = joinOrderNode->joinClauseList;
/* /*
* Build a join node between the top of our join tree and the next * Build a join node between the top of our join tree and the next
* table in the join order. * table in the join order.
*/ */
newJoinNode = ApplyJoinRule(currentTopNode, (MultiNode *) collectNode, MultiNode *newJoinNode = ApplyJoinRule(currentTopNode,
joinRuleType, partitionColumn, joinType, (MultiNode *) collectNode,
joinRuleType, partitionColumn,
joinType,
joinClauseList); joinClauseList);
/* the new join node becomes the top of our join tree */ /* the new join node becomes the top of our join tree */
@ -1727,22 +1694,19 @@ MultiSelectNode(List *whereClauseList)
static bool static bool
IsSelectClause(Node *clause) IsSelectClause(Node *clause)
{ {
List *columnList = NIL;
ListCell *columnCell = NULL; ListCell *columnCell = NULL;
Var *firstColumn = NULL;
Index firstColumnTableId = 0;
bool isSelectClause = true; bool isSelectClause = true;
/* extract columns from the clause */ /* extract columns from the clause */
columnList = pull_var_clause_default(clause); List *columnList = pull_var_clause_default(clause);
if (list_length(columnList) == 0) if (list_length(columnList) == 0)
{ {
return true; return true;
} }
/* get first column's tableId */ /* get first column's tableId */
firstColumn = (Var *) linitial(columnList); Var *firstColumn = (Var *) linitial(columnList);
firstColumnTableId = firstColumn->varno; Index firstColumnTableId = firstColumn->varno;
/* check if all columns are from the same table */ /* check if all columns are from the same table */
foreach(columnCell, columnList) foreach(columnCell, columnList)
@ -1766,13 +1730,11 @@ IsSelectClause(Node *clause)
MultiProject * MultiProject *
MultiProjectNode(List *targetEntryList) MultiProjectNode(List *targetEntryList)
{ {
MultiProject *projectNode = NULL;
List *uniqueColumnList = NIL; List *uniqueColumnList = NIL;
List *columnList = NIL;
ListCell *columnCell = NULL; ListCell *columnCell = NULL;
/* extract the list of columns and remove any duplicates */ /* extract the list of columns and remove any duplicates */
columnList = pull_var_clause_default((Node *) targetEntryList); List *columnList = pull_var_clause_default((Node *) targetEntryList);
foreach(columnCell, columnList) foreach(columnCell, columnList)
{ {
Var *column = (Var *) lfirst(columnCell); Var *column = (Var *) lfirst(columnCell);
@ -1781,7 +1743,7 @@ MultiProjectNode(List *targetEntryList)
} }
/* create project node with list of columns to project */ /* create project node with list of columns to project */
projectNode = CitusMakeNode(MultiProject); MultiProject *projectNode = CitusMakeNode(MultiProject);
projectNode->columnList = uniqueColumnList; projectNode->columnList = uniqueColumnList;
return projectNode; return projectNode;
@ -1932,7 +1894,6 @@ List *
FindNodesOfType(MultiNode *node, int type) FindNodesOfType(MultiNode *node, int type)
{ {
List *nodeList = NIL; List *nodeList = NIL;
int nodeType = T_Invalid;
/* terminal condition for recursion */ /* terminal condition for recursion */
if (node == NULL) if (node == NULL)
@ -1941,7 +1902,7 @@ FindNodesOfType(MultiNode *node, int type)
} }
/* current node has expected node type */ /* current node has expected node type */
nodeType = CitusNodeTag(node); int nodeType = CitusNodeTag(node);
if (nodeType == type) if (nodeType == type)
{ {
nodeList = lappend(nodeList, node); nodeList = lappend(nodeList, node);
@ -1997,26 +1958,21 @@ static MultiNode *
ApplyJoinRule(MultiNode *leftNode, MultiNode *rightNode, JoinRuleType ruleType, ApplyJoinRule(MultiNode *leftNode, MultiNode *rightNode, JoinRuleType ruleType,
Var *partitionColumn, JoinType joinType, List *joinClauseList) Var *partitionColumn, JoinType joinType, List *joinClauseList)
{ {
RuleApplyFunction ruleApplyFunction = NULL;
MultiNode *multiNode = NULL;
List *applicableJoinClauses = NIL;
List *leftTableIdList = OutputTableIdList(leftNode); List *leftTableIdList = OutputTableIdList(leftNode);
List *rightTableIdList = OutputTableIdList(rightNode); List *rightTableIdList = OutputTableIdList(rightNode);
int rightTableIdCount PG_USED_FOR_ASSERTS_ONLY = 0; int rightTableIdCount PG_USED_FOR_ASSERTS_ONLY = 0;
uint32 rightTableId = 0;
rightTableIdCount = list_length(rightTableIdList); rightTableIdCount = list_length(rightTableIdList);
Assert(rightTableIdCount == 1); Assert(rightTableIdCount == 1);
/* find applicable join clauses between the left and right data sources */ /* find applicable join clauses between the left and right data sources */
rightTableId = (uint32) linitial_int(rightTableIdList); uint32 rightTableId = (uint32) linitial_int(rightTableIdList);
applicableJoinClauses = ApplicableJoinClauses(leftTableIdList, rightTableId, List *applicableJoinClauses = ApplicableJoinClauses(leftTableIdList, rightTableId,
joinClauseList); joinClauseList);
/* call the join rule application function to create the new join node */ /* call the join rule application function to create the new join node */
ruleApplyFunction = JoinRuleApplyFunction(ruleType); RuleApplyFunction ruleApplyFunction = JoinRuleApplyFunction(ruleType);
multiNode = (*ruleApplyFunction)(leftNode, rightNode, partitionColumn, MultiNode *multiNode = (*ruleApplyFunction)(leftNode, rightNode, partitionColumn,
joinType, applicableJoinClauses); joinType, applicableJoinClauses);
if (joinType != JOIN_INNER && CitusIsA(multiNode, MultiJoin)) if (joinType != JOIN_INNER && CitusIsA(multiNode, MultiJoin))
@ -2041,7 +1997,6 @@ static RuleApplyFunction
JoinRuleApplyFunction(JoinRuleType ruleType) JoinRuleApplyFunction(JoinRuleType ruleType)
{ {
static bool ruleApplyFunctionInitialized = false; static bool ruleApplyFunctionInitialized = false;
RuleApplyFunction ruleApplyFunction = NULL;
if (!ruleApplyFunctionInitialized) if (!ruleApplyFunctionInitialized)
{ {
@ -2057,7 +2012,7 @@ JoinRuleApplyFunction(JoinRuleType ruleType)
ruleApplyFunctionInitialized = true; ruleApplyFunctionInitialized = true;
} }
ruleApplyFunction = RuleApplyFunctionArray[ruleType]; RuleApplyFunction ruleApplyFunction = RuleApplyFunctionArray[ruleType];
Assert(ruleApplyFunction != NULL); Assert(ruleApplyFunction != NULL);
return ruleApplyFunction; return ruleApplyFunction;
@ -2154,11 +2109,6 @@ ApplySinglePartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
Var *partitionColumn, JoinType joinType, Var *partitionColumn, JoinType joinType,
List *applicableJoinClauses) List *applicableJoinClauses)
{ {
OpExpr *joinClause = NULL;
Var *leftColumn = NULL;
Var *rightColumn = NULL;
List *rightTableIdList = NIL;
uint32 rightTableId = 0;
uint32 partitionTableId = partitionColumn->varno; uint32 partitionTableId = partitionColumn->varno;
/* create all operator structures up front */ /* create all operator structures up front */
@ -2171,12 +2121,13 @@ ApplySinglePartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
* column against the join clause's columns. If one of the columns matches, * column against the join clause's columns. If one of the columns matches,
* we introduce a (re-)partition operator for the other column. * we introduce a (re-)partition operator for the other column.
*/ */
joinClause = SinglePartitionJoinClause(partitionColumn, applicableJoinClauses); OpExpr *joinClause = SinglePartitionJoinClause(partitionColumn,
applicableJoinClauses);
Assert(joinClause != NULL); Assert(joinClause != NULL);
/* both are verified in SinglePartitionJoinClause to not be NULL, assert is to guard */ /* both are verified in SinglePartitionJoinClause to not be NULL, assert is to guard */
leftColumn = LeftColumnOrNULL(joinClause); Var *leftColumn = LeftColumnOrNULL(joinClause);
rightColumn = RightColumnOrNULL(joinClause); Var *rightColumn = RightColumnOrNULL(joinClause);
Assert(leftColumn != NULL); Assert(leftColumn != NULL);
Assert(rightColumn != NULL); Assert(rightColumn != NULL);
@ -2193,8 +2144,8 @@ ApplySinglePartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
} }
/* determine the node the partition operator goes on top of */ /* determine the node the partition operator goes on top of */
rightTableIdList = OutputTableIdList(rightNode); List *rightTableIdList = OutputTableIdList(rightNode);
rightTableId = (uint32) linitial_int(rightTableIdList); uint32 rightTableId = (uint32) linitial_int(rightTableIdList);
Assert(list_length(rightTableIdList) == 1); Assert(list_length(rightTableIdList) == 1);
/* /*
@ -2238,33 +2189,22 @@ ApplyDualPartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
Var *partitionColumn, JoinType joinType, Var *partitionColumn, JoinType joinType,
List *applicableJoinClauses) List *applicableJoinClauses)
{ {
MultiJoin *joinNode = NULL;
OpExpr *joinClause = NULL;
MultiPartition *leftPartitionNode = NULL;
MultiPartition *rightPartitionNode = NULL;
MultiCollect *leftCollectNode = NULL;
MultiCollect *rightCollectNode = NULL;
Var *leftColumn = NULL;
Var *rightColumn = NULL;
List *rightTableIdList = NIL;
uint32 rightTableId = 0;
/* find the appropriate join clause */ /* find the appropriate join clause */
joinClause = DualPartitionJoinClause(applicableJoinClauses); OpExpr *joinClause = DualPartitionJoinClause(applicableJoinClauses);
Assert(joinClause != NULL); Assert(joinClause != NULL);
/* both are verified in DualPartitionJoinClause to not be NULL, assert is to guard */ /* both are verified in DualPartitionJoinClause to not be NULL, assert is to guard */
leftColumn = LeftColumnOrNULL(joinClause); Var *leftColumn = LeftColumnOrNULL(joinClause);
rightColumn = RightColumnOrNULL(joinClause); Var *rightColumn = RightColumnOrNULL(joinClause);
Assert(leftColumn != NULL); Assert(leftColumn != NULL);
Assert(rightColumn != NULL); Assert(rightColumn != NULL);
rightTableIdList = OutputTableIdList(rightNode); List *rightTableIdList = OutputTableIdList(rightNode);
rightTableId = (uint32) linitial_int(rightTableIdList); uint32 rightTableId = (uint32) linitial_int(rightTableIdList);
Assert(list_length(rightTableIdList) == 1); Assert(list_length(rightTableIdList) == 1);
leftPartitionNode = CitusMakeNode(MultiPartition); MultiPartition *leftPartitionNode = CitusMakeNode(MultiPartition);
rightPartitionNode = CitusMakeNode(MultiPartition); MultiPartition *rightPartitionNode = CitusMakeNode(MultiPartition);
/* find the partition node each join clause column belongs to */ /* find the partition node each join clause column belongs to */
if (leftColumn->varno == rightTableId) if (leftColumn->varno == rightTableId)
@ -2283,14 +2223,14 @@ ApplyDualPartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
SetChild((MultiUnaryNode *) rightPartitionNode, rightNode); SetChild((MultiUnaryNode *) rightPartitionNode, rightNode);
/* add collect operators on top of the two partition operators */ /* add collect operators on top of the two partition operators */
leftCollectNode = CitusMakeNode(MultiCollect); MultiCollect *leftCollectNode = CitusMakeNode(MultiCollect);
rightCollectNode = CitusMakeNode(MultiCollect); MultiCollect *rightCollectNode = CitusMakeNode(MultiCollect);
SetChild((MultiUnaryNode *) leftCollectNode, (MultiNode *) leftPartitionNode); SetChild((MultiUnaryNode *) leftCollectNode, (MultiNode *) leftPartitionNode);
SetChild((MultiUnaryNode *) rightCollectNode, (MultiNode *) rightPartitionNode); SetChild((MultiUnaryNode *) rightCollectNode, (MultiNode *) rightPartitionNode);
/* add join operator on top of the two collect operators */ /* add join operator on top of the two collect operators */
joinNode = CitusMakeNode(MultiJoin); MultiJoin *joinNode = CitusMakeNode(MultiJoin);
joinNode->joinRuleType = DUAL_PARTITION_JOIN; joinNode->joinRuleType = DUAL_PARTITION_JOIN;
joinNode->joinType = joinType; joinNode->joinType = joinType;
joinNode->joinClauseList = applicableJoinClauses; joinNode->joinClauseList = applicableJoinClauses;

View File

@ -71,13 +71,13 @@ PlannedStmt *
MasterNodeSelectPlan(DistributedPlan *distributedPlan, CustomScan *remoteScan) MasterNodeSelectPlan(DistributedPlan *distributedPlan, CustomScan *remoteScan)
{ {
Query *masterQuery = distributedPlan->masterQuery; Query *masterQuery = distributedPlan->masterQuery;
PlannedStmt *masterSelectPlan = NULL;
Job *workerJob = distributedPlan->workerJob; Job *workerJob = distributedPlan->workerJob;
List *workerTargetList = workerJob->jobQuery->targetList; List *workerTargetList = workerJob->jobQuery->targetList;
List *masterTargetList = MasterTargetList(workerTargetList); List *masterTargetList = MasterTargetList(workerTargetList);
masterSelectPlan = BuildSelectStatement(masterQuery, masterTargetList, remoteScan); PlannedStmt *masterSelectPlan = BuildSelectStatement(masterQuery, masterTargetList,
remoteScan);
return masterSelectPlan; return masterSelectPlan;
} }
@ -99,15 +99,13 @@ MasterTargetList(List *workerTargetList)
foreach(workerTargetCell, workerTargetList) foreach(workerTargetCell, workerTargetList)
{ {
TargetEntry *workerTargetEntry = (TargetEntry *) lfirst(workerTargetCell); TargetEntry *workerTargetEntry = (TargetEntry *) lfirst(workerTargetCell);
TargetEntry *masterTargetEntry = NULL;
Var *masterColumn = NULL;
if (workerTargetEntry->resjunk) if (workerTargetEntry->resjunk)
{ {
continue; continue;
} }
masterColumn = makeVarFromTargetEntry(tableId, workerTargetEntry); Var *masterColumn = makeVarFromTargetEntry(tableId, workerTargetEntry);
masterColumn->varattno = columnId; masterColumn->varattno = columnId;
masterColumn->varoattno = columnId; masterColumn->varoattno = columnId;
columnId++; columnId++;
@ -124,7 +122,7 @@ MasterTargetList(List *workerTargetList)
* from the worker target entry. Note that any changes to worker target * from the worker target entry. Note that any changes to worker target
* entry's sort and group clauses will *break* us here. * entry's sort and group clauses will *break* us here.
*/ */
masterTargetEntry = flatCopyTargetEntry(workerTargetEntry); TargetEntry *masterTargetEntry = flatCopyTargetEntry(workerTargetEntry);
masterTargetEntry->expr = (Expr *) masterColumn; masterTargetEntry->expr = (Expr *) masterColumn;
masterTargetList = lappend(masterTargetList, masterTargetEntry); masterTargetList = lappend(masterTargetList, masterTargetEntry);
} }
@ -469,16 +467,14 @@ BuildAggregatePlan(PlannerInfo *root, Query *masterQuery, Plan *subPlan)
static bool static bool
HasDistinctAggregate(Query *masterQuery) HasDistinctAggregate(Query *masterQuery)
{ {
List *targetVarList = NIL;
List *havingVarList = NIL;
List *allColumnList = NIL;
ListCell *allColumnCell = NULL; ListCell *allColumnCell = NULL;
targetVarList = pull_var_clause((Node *) masterQuery->targetList, List *targetVarList = pull_var_clause((Node *) masterQuery->targetList,
PVC_INCLUDE_AGGREGATES);
List *havingVarList = pull_var_clause(masterQuery->havingQual,
PVC_INCLUDE_AGGREGATES); PVC_INCLUDE_AGGREGATES);
havingVarList = pull_var_clause(masterQuery->havingQual, PVC_INCLUDE_AGGREGATES);
allColumnList = list_concat(targetVarList, havingVarList); List *allColumnList = list_concat(targetVarList, havingVarList);
foreach(allColumnCell, allColumnList) foreach(allColumnCell, allColumnList)
{ {
Node *columnNode = lfirst(allColumnCell); Node *columnNode = lfirst(allColumnCell);
@ -506,7 +502,6 @@ static bool
UseGroupAggregateWithHLL(Query *masterQuery) UseGroupAggregateWithHLL(Query *masterQuery)
{ {
Oid hllId = get_extension_oid(HLL_EXTENSION_NAME, true); Oid hllId = get_extension_oid(HLL_EXTENSION_NAME, true);
const char *gucStrValue = NULL;
/* If HLL extension is not loaded, return false */ /* If HLL extension is not loaded, return false */
if (!OidIsValid(hllId)) if (!OidIsValid(hllId))
@ -515,7 +510,7 @@ UseGroupAggregateWithHLL(Query *masterQuery)
} }
/* If HLL is loaded but related GUC is not set, return false */ /* If HLL is loaded but related GUC is not set, return false */
gucStrValue = GetConfigOption(HLL_FORCE_GROUPAGG_GUC_NAME, true, false); const char *gucStrValue = GetConfigOption(HLL_FORCE_GROUPAGG_GUC_NAME, true, false);
if (gucStrValue == NULL || strcmp(gucStrValue, "off") == 0) if (gucStrValue == NULL || strcmp(gucStrValue, "off") == 0)
{ {
return false; return false;
@ -532,10 +527,9 @@ UseGroupAggregateWithHLL(Query *masterQuery)
static bool static bool
QueryContainsAggregateWithHLL(Query *query) QueryContainsAggregateWithHLL(Query *query)
{ {
List *varList = NIL;
ListCell *varCell = NULL; ListCell *varCell = NULL;
varList = pull_var_clause((Node *) query->targetList, PVC_INCLUDE_AGGREGATES); List *varList = pull_var_clause((Node *) query->targetList, PVC_INCLUDE_AGGREGATES);
foreach(varCell, varList) foreach(varCell, varList)
{ {
Var *var = (Var *) lfirst(varCell); Var *var = (Var *) lfirst(varCell);
@ -579,10 +573,8 @@ static Plan *
BuildDistinctPlan(Query *masterQuery, Plan *subPlan) BuildDistinctPlan(Query *masterQuery, Plan *subPlan)
{ {
Plan *distinctPlan = NULL; Plan *distinctPlan = NULL;
bool distinctClausesHashable = true;
List *distinctClauseList = masterQuery->distinctClause; List *distinctClauseList = masterQuery->distinctClause;
List *targetList = copyObject(masterQuery->targetList); List *targetList = copyObject(masterQuery->targetList);
bool hasDistinctAggregate = false;
/* /*
* We don't need to add distinct plan if all of the columns used in group by * We don't need to add distinct plan if all of the columns used in group by
@ -602,8 +594,8 @@ BuildDistinctPlan(Query *masterQuery, Plan *subPlan)
* members are hashable, and not containing distinct aggregate. * members are hashable, and not containing distinct aggregate.
* Otherwise create sort+unique plan. * Otherwise create sort+unique plan.
*/ */
distinctClausesHashable = grouping_is_hashable(distinctClauseList); bool distinctClausesHashable = grouping_is_hashable(distinctClauseList);
hasDistinctAggregate = HasDistinctAggregate(masterQuery); bool hasDistinctAggregate = HasDistinctAggregate(masterQuery);
if (enable_hashagg && distinctClausesHashable && !hasDistinctAggregate) if (enable_hashagg && distinctClausesHashable && !hasDistinctAggregate)
{ {

File diff suppressed because it is too large Load Diff

View File

@ -260,12 +260,10 @@ CreateSingleTaskRouterPlan(DistributedPlan *distributedPlan, Query *originalQuer
Query *query, Query *query,
PlannerRestrictionContext *plannerRestrictionContext) PlannerRestrictionContext *plannerRestrictionContext)
{ {
Job *job = NULL;
distributedPlan->modLevel = RowModifyLevelForQuery(query); distributedPlan->modLevel = RowModifyLevelForQuery(query);
/* we cannot have multi shard update/delete query via this code path */ /* we cannot have multi shard update/delete query via this code path */
job = RouterJob(originalQuery, plannerRestrictionContext, Job *job = RouterJob(originalQuery, plannerRestrictionContext,
&distributedPlan->planningError); &distributedPlan->planningError);
if (distributedPlan->planningError != NULL) if (distributedPlan->planningError != NULL)
@ -302,7 +300,6 @@ ShardIntervalOpExpressions(ShardInterval *shardInterval, Index rteIndex)
Oid relationId = shardInterval->relationId; Oid relationId = shardInterval->relationId;
char partitionMethod = PartitionMethod(shardInterval->relationId); char partitionMethod = PartitionMethod(shardInterval->relationId);
Var *partitionColumn = NULL; Var *partitionColumn = NULL;
Node *baseConstraint = NULL;
if (partitionMethod == DISTRIBUTE_BY_HASH) if (partitionMethod == DISTRIBUTE_BY_HASH)
{ {
@ -321,7 +318,7 @@ ShardIntervalOpExpressions(ShardInterval *shardInterval, Index rteIndex)
} }
/* build the base expression for constraint */ /* build the base expression for constraint */
baseConstraint = BuildBaseConstraint(partitionColumn); Node *baseConstraint = BuildBaseConstraint(partitionColumn);
/* walk over shard list and check if shards can be pruned */ /* walk over shard list and check if shards can be pruned */
if (shardInterval->minValueExists && shardInterval->maxValueExists) if (shardInterval->minValueExists && shardInterval->maxValueExists)
@ -349,14 +346,7 @@ AddShardIntervalRestrictionToSelect(Query *subqery, ShardInterval *shardInterval
List *targetList = subqery->targetList; List *targetList = subqery->targetList;
ListCell *targetEntryCell = NULL; ListCell *targetEntryCell = NULL;
Var *targetPartitionColumnVar = NULL; Var *targetPartitionColumnVar = NULL;
Oid integer4GEoperatorId = InvalidOid;
Oid integer4LEoperatorId = InvalidOid;
TypeCacheEntry *typeEntry = NULL;
FuncExpr *hashFunctionExpr = NULL;
OpExpr *greaterThanAndEqualsBoundExpr = NULL;
OpExpr *lessThanAndEqualsBoundExpr = NULL;
List *boundExpressionList = NIL; List *boundExpressionList = NIL;
Expr *andedBoundExpressions = NULL;
/* iterate through the target entries */ /* iterate through the target entries */
foreach(targetEntryCell, targetList) foreach(targetEntryCell, targetList)
@ -374,10 +364,10 @@ AddShardIntervalRestrictionToSelect(Query *subqery, ShardInterval *shardInterval
/* we should have found target partition column */ /* we should have found target partition column */
Assert(targetPartitionColumnVar != NULL); Assert(targetPartitionColumnVar != NULL);
integer4GEoperatorId = get_opfamily_member(INTEGER_BTREE_FAM_OID, INT4OID, Oid integer4GEoperatorId = get_opfamily_member(INTEGER_BTREE_FAM_OID, INT4OID,
INT4OID, INT4OID,
BTGreaterEqualStrategyNumber); BTGreaterEqualStrategyNumber);
integer4LEoperatorId = get_opfamily_member(INTEGER_BTREE_FAM_OID, INT4OID, Oid integer4LEoperatorId = get_opfamily_member(INTEGER_BTREE_FAM_OID, INT4OID,
INT4OID, INT4OID,
BTLessEqualStrategyNumber); BTLessEqualStrategyNumber);
@ -386,7 +376,7 @@ AddShardIntervalRestrictionToSelect(Query *subqery, ShardInterval *shardInterval
Assert(integer4LEoperatorId != InvalidOid); Assert(integer4LEoperatorId != InvalidOid);
/* look up the type cache */ /* look up the type cache */
typeEntry = lookup_type_cache(targetPartitionColumnVar->vartype, TypeCacheEntry *typeEntry = lookup_type_cache(targetPartitionColumnVar->vartype,
TYPECACHE_HASH_PROC_FINFO); TYPECACHE_HASH_PROC_FINFO);
/* probable never possible given that the tables are already hash partitioned */ /* probable never possible given that the tables are already hash partitioned */
@ -398,7 +388,7 @@ AddShardIntervalRestrictionToSelect(Query *subqery, ShardInterval *shardInterval
} }
/* generate hashfunc(partCol) expression */ /* generate hashfunc(partCol) expression */
hashFunctionExpr = makeNode(FuncExpr); FuncExpr *hashFunctionExpr = makeNode(FuncExpr);
hashFunctionExpr->funcid = CitusWorkerHashFunctionId(); hashFunctionExpr->funcid = CitusWorkerHashFunctionId();
hashFunctionExpr->args = list_make1(targetPartitionColumnVar); hashFunctionExpr->args = list_make1(targetPartitionColumnVar);
@ -406,7 +396,7 @@ AddShardIntervalRestrictionToSelect(Query *subqery, ShardInterval *shardInterval
hashFunctionExpr->funcresulttype = INT4OID; hashFunctionExpr->funcresulttype = INT4OID;
/* generate hashfunc(partCol) >= shardMinValue OpExpr */ /* generate hashfunc(partCol) >= shardMinValue OpExpr */
greaterThanAndEqualsBoundExpr = OpExpr *greaterThanAndEqualsBoundExpr =
(OpExpr *) make_opclause(integer4GEoperatorId, (OpExpr *) make_opclause(integer4GEoperatorId,
InvalidOid, false, InvalidOid, false,
(Expr *) hashFunctionExpr, (Expr *) hashFunctionExpr,
@ -421,7 +411,7 @@ AddShardIntervalRestrictionToSelect(Query *subqery, ShardInterval *shardInterval
get_func_rettype(greaterThanAndEqualsBoundExpr->opfuncid); get_func_rettype(greaterThanAndEqualsBoundExpr->opfuncid);
/* generate hashfunc(partCol) <= shardMinValue OpExpr */ /* generate hashfunc(partCol) <= shardMinValue OpExpr */
lessThanAndEqualsBoundExpr = OpExpr *lessThanAndEqualsBoundExpr =
(OpExpr *) make_opclause(integer4LEoperatorId, (OpExpr *) make_opclause(integer4LEoperatorId,
InvalidOid, false, InvalidOid, false,
(Expr *) hashFunctionExpr, (Expr *) hashFunctionExpr,
@ -438,7 +428,7 @@ AddShardIntervalRestrictionToSelect(Query *subqery, ShardInterval *shardInterval
boundExpressionList = lappend(boundExpressionList, greaterThanAndEqualsBoundExpr); boundExpressionList = lappend(boundExpressionList, greaterThanAndEqualsBoundExpr);
boundExpressionList = lappend(boundExpressionList, lessThanAndEqualsBoundExpr); boundExpressionList = lappend(boundExpressionList, lessThanAndEqualsBoundExpr);
andedBoundExpressions = make_ands_explicit(boundExpressionList); Expr *andedBoundExpressions = make_ands_explicit(boundExpressionList);
/* finally add the quals */ /* finally add the quals */
if (subqery->jointree->quals == NULL) if (subqery->jointree->quals == NULL)
@ -461,19 +451,15 @@ AddShardIntervalRestrictionToSelect(Query *subqery, ShardInterval *shardInterval
RangeTblEntry * RangeTblEntry *
ExtractSelectRangeTableEntry(Query *query) ExtractSelectRangeTableEntry(Query *query)
{ {
List *fromList = NULL;
RangeTblRef *reference = NULL;
RangeTblEntry *subqueryRte = NULL;
Assert(InsertSelectIntoDistributedTable(query)); Assert(InsertSelectIntoDistributedTable(query));
/* /*
* Since we already asserted InsertSelectIntoDistributedTable() it is safe to access * Since we already asserted InsertSelectIntoDistributedTable() it is safe to access
* both lists * both lists
*/ */
fromList = query->jointree->fromlist; List *fromList = query->jointree->fromlist;
reference = linitial(fromList); RangeTblRef *reference = linitial(fromList);
subqueryRte = rt_fetch(reference->rtindex, query->rtable); RangeTblEntry *subqueryRte = rt_fetch(reference->rtindex, query->rtable);
return subqueryRte; return subqueryRte;
} }
@ -490,8 +476,6 @@ ExtractSelectRangeTableEntry(Query *query)
Oid Oid
ModifyQueryResultRelationId(Query *query) ModifyQueryResultRelationId(Query *query)
{ {
RangeTblEntry *resultRte = NULL;
/* only modify queries have result relations */ /* only modify queries have result relations */
if (!IsModifyCommand(query)) if (!IsModifyCommand(query))
{ {
@ -499,7 +483,7 @@ ModifyQueryResultRelationId(Query *query)
errmsg("input query is not a modification query"))); errmsg("input query is not a modification query")));
} }
resultRte = ExtractResultRelationRTE(query); RangeTblEntry *resultRte = ExtractResultRelationRTE(query);
Assert(OidIsValid(resultRte->relid)); Assert(OidIsValid(resultRte->relid));
return resultRte->relid; return resultRte->relid;
@ -562,7 +546,6 @@ DeferredErrorMessage *
ModifyQuerySupported(Query *queryTree, Query *originalQuery, bool multiShardQuery, ModifyQuerySupported(Query *queryTree, Query *originalQuery, bool multiShardQuery,
PlannerRestrictionContext *plannerRestrictionContext) PlannerRestrictionContext *plannerRestrictionContext)
{ {
DeferredErrorMessage *deferredError = NULL;
Oid distributedTableId = ExtractFirstDistributedTableId(queryTree); Oid distributedTableId = ExtractFirstDistributedTableId(queryTree);
uint32 rangeTableId = 1; uint32 rangeTableId = 1;
Var *partitionColumn = PartitionColumn(distributedTableId, rangeTableId); Var *partitionColumn = PartitionColumn(distributedTableId, rangeTableId);
@ -571,7 +554,7 @@ ModifyQuerySupported(Query *queryTree, Query *originalQuery, bool multiShardQuer
uint32 queryTableCount = 0; uint32 queryTableCount = 0;
CmdType commandType = queryTree->commandType; CmdType commandType = queryTree->commandType;
deferredError = DeferErrorIfModifyView(queryTree); DeferredErrorMessage *deferredError = DeferErrorIfModifyView(queryTree);
if (deferredError != NULL) if (deferredError != NULL)
{ {
return deferredError; return deferredError;
@ -624,7 +607,6 @@ ModifyQuerySupported(Query *queryTree, Query *originalQuery, bool multiShardQuer
{ {
CommonTableExpr *cte = (CommonTableExpr *) lfirst(cteCell); CommonTableExpr *cte = (CommonTableExpr *) lfirst(cteCell);
Query *cteQuery = (Query *) cte->ctequery; Query *cteQuery = (Query *) cte->ctequery;
DeferredErrorMessage *cteError = NULL;
if (cteQuery->commandType != CMD_SELECT) if (cteQuery->commandType != CMD_SELECT)
{ {
@ -649,7 +631,7 @@ ModifyQuerySupported(Query *queryTree, Query *originalQuery, bool multiShardQuer
NULL, NULL); NULL, NULL);
} }
cteError = MultiRouterPlannableQuery(cteQuery); DeferredErrorMessage *cteError = MultiRouterPlannableQuery(cteQuery);
if (cteError) if (cteError)
{ {
return cteError; return cteError;
@ -957,12 +939,7 @@ DeferErrorIfModifyView(Query *queryTree)
DeferredErrorMessage * DeferredErrorMessage *
ErrorIfOnConflictNotSupported(Query *queryTree) ErrorIfOnConflictNotSupported(Query *queryTree)
{ {
Oid distributedTableId = InvalidOid;
uint32 rangeTableId = 1; uint32 rangeTableId = 1;
Var *partitionColumn = NULL;
List *onConflictSet = NIL;
Node *arbiterWhere = NULL;
Node *onConflictWhere = NULL;
ListCell *setTargetCell = NULL; ListCell *setTargetCell = NULL;
bool specifiesPartitionValue = false; bool specifiesPartitionValue = false;
@ -972,12 +949,12 @@ ErrorIfOnConflictNotSupported(Query *queryTree)
return NULL; return NULL;
} }
distributedTableId = ExtractFirstDistributedTableId(queryTree); Oid distributedTableId = ExtractFirstDistributedTableId(queryTree);
partitionColumn = PartitionColumn(distributedTableId, rangeTableId); Var *partitionColumn = PartitionColumn(distributedTableId, rangeTableId);
onConflictSet = queryTree->onConflict->onConflictSet; List *onConflictSet = queryTree->onConflict->onConflictSet;
arbiterWhere = queryTree->onConflict->arbiterWhere; Node *arbiterWhere = queryTree->onConflict->arbiterWhere;
onConflictWhere = queryTree->onConflict->onConflictWhere; Node *onConflictWhere = queryTree->onConflict->onConflictWhere;
/* /*
* onConflictSet is expanded via expand_targetlist() on the standard planner. * onConflictSet is expanded via expand_targetlist() on the standard planner.
@ -1207,11 +1184,10 @@ UpdateOrDeleteQuery(Query *query)
static bool static bool
MasterIrreducibleExpression(Node *expression, bool *varArgument, bool *badCoalesce) MasterIrreducibleExpression(Node *expression, bool *varArgument, bool *badCoalesce)
{ {
bool result;
WalkerState data; WalkerState data;
data.containsVar = data.varArgument = data.badCoalesce = false; data.containsVar = data.varArgument = data.badCoalesce = false;
result = MasterIrreducibleExpressionWalker(expression, &data); bool result = MasterIrreducibleExpressionWalker(expression, &data);
*varArgument |= data.varArgument; *varArgument |= data.varArgument;
*badCoalesce |= data.badCoalesce; *badCoalesce |= data.badCoalesce;
@ -1379,13 +1355,12 @@ TargetEntryChangesValue(TargetEntry *targetEntry, Var *column, FromExpr *joinTre
List *restrictClauseList = WhereClauseList(joinTree); List *restrictClauseList = WhereClauseList(joinTree);
OpExpr *equalityExpr = MakeOpExpression(column, BTEqualStrategyNumber); OpExpr *equalityExpr = MakeOpExpression(column, BTEqualStrategyNumber);
Const *rightConst = (Const *) get_rightop((Expr *) equalityExpr); Const *rightConst = (Const *) get_rightop((Expr *) equalityExpr);
bool predicateIsImplied = false;
rightConst->constvalue = newValue->constvalue; rightConst->constvalue = newValue->constvalue;
rightConst->constisnull = newValue->constisnull; rightConst->constisnull = newValue->constisnull;
rightConst->constbyval = newValue->constbyval; rightConst->constbyval = newValue->constbyval;
predicateIsImplied = predicate_implied_by(list_make1(equalityExpr), bool predicateIsImplied = predicate_implied_by(list_make1(equalityExpr),
restrictClauseList, false); restrictClauseList, false);
if (predicateIsImplied) if (predicateIsImplied)
{ {
@ -1408,7 +1383,6 @@ RouterInsertJob(Query *originalQuery, Query *query, DeferredErrorMessage **plann
{ {
Oid distributedTableId = ExtractFirstDistributedTableId(query); Oid distributedTableId = ExtractFirstDistributedTableId(query);
List *taskList = NIL; List *taskList = NIL;
Job *job = NULL;
bool requiresMasterEvaluation = false; bool requiresMasterEvaluation = false;
bool deferredPruning = false; bool deferredPruning = false;
Const *partitionKeyValue = NULL; Const *partitionKeyValue = NULL;
@ -1459,7 +1433,7 @@ RouterInsertJob(Query *originalQuery, Query *query, DeferredErrorMessage **plann
partitionKeyValue = ExtractInsertPartitionKeyValue(originalQuery); partitionKeyValue = ExtractInsertPartitionKeyValue(originalQuery);
} }
job = CreateJob(originalQuery); Job *job = CreateJob(originalQuery);
job->taskList = taskList; job->taskList = taskList;
job->requiresMasterEvaluation = requiresMasterEvaluation; job->requiresMasterEvaluation = requiresMasterEvaluation;
job->deferredPruning = deferredPruning; job->deferredPruning = deferredPruning;
@ -1475,9 +1449,7 @@ RouterInsertJob(Query *originalQuery, Query *query, DeferredErrorMessage **plann
static Job * static Job *
CreateJob(Query *query) CreateJob(Query *query)
{ {
Job *job = NULL; Job *job = CitusMakeNode(Job);
job = CitusMakeNode(Job);
job->jobId = UniqueJobId(); job->jobId = UniqueJobId();
job->jobQuery = query; job->jobQuery = query;
job->taskList = NIL; job->taskList = NIL;
@ -1498,8 +1470,6 @@ static bool
CanShardPrune(Oid distributedTableId, Query *query) CanShardPrune(Oid distributedTableId, Query *query)
{ {
uint32 rangeTableId = 1; uint32 rangeTableId = 1;
Var *partitionColumn = NULL;
List *insertValuesList = NIL;
ListCell *insertValuesCell = NULL; ListCell *insertValuesCell = NULL;
if (query->commandType != CMD_INSERT) if (query->commandType != CMD_INSERT)
@ -1508,7 +1478,7 @@ CanShardPrune(Oid distributedTableId, Query *query)
return true; return true;
} }
partitionColumn = PartitionColumn(distributedTableId, rangeTableId); Var *partitionColumn = PartitionColumn(distributedTableId, rangeTableId);
if (partitionColumn == NULL) if (partitionColumn == NULL)
{ {
/* can always do shard pruning for reference tables */ /* can always do shard pruning for reference tables */
@ -1516,7 +1486,7 @@ CanShardPrune(Oid distributedTableId, Query *query)
} }
/* get full list of partition values and ensure they are all Consts */ /* get full list of partition values and ensure they are all Consts */
insertValuesList = ExtractInsertValuesList(query, partitionColumn); List *insertValuesList = ExtractInsertValuesList(query, partitionColumn);
foreach(insertValuesCell, insertValuesList) foreach(insertValuesCell, insertValuesList)
{ {
InsertValues *insertValues = (InsertValues *) lfirst(insertValuesCell); InsertValues *insertValues = (InsertValues *) lfirst(insertValuesCell);
@ -1561,7 +1531,6 @@ List *
RouterInsertTaskList(Query *query, DeferredErrorMessage **planningError) RouterInsertTaskList(Query *query, DeferredErrorMessage **planningError)
{ {
List *insertTaskList = NIL; List *insertTaskList = NIL;
List *modifyRouteList = NIL;
ListCell *modifyRouteCell = NULL; ListCell *modifyRouteCell = NULL;
Oid distributedTableId = ExtractFirstDistributedTableId(query); Oid distributedTableId = ExtractFirstDistributedTableId(query);
@ -1571,7 +1540,7 @@ RouterInsertTaskList(Query *query, DeferredErrorMessage **planningError)
Assert(query->commandType == CMD_INSERT); Assert(query->commandType == CMD_INSERT);
modifyRouteList = BuildRoutesForInsert(query, planningError); List *modifyRouteList = BuildRoutesForInsert(query, planningError);
if (*planningError != NULL) if (*planningError != NULL)
{ {
return NIL; return NIL;
@ -1599,9 +1568,7 @@ RouterInsertTaskList(Query *query, DeferredErrorMessage **planningError)
static Task * static Task *
CreateTask(TaskType taskType) CreateTask(TaskType taskType)
{ {
Task *task = NULL; Task *task = CitusMakeNode(Task);
task = CitusMakeNode(Task);
task->taskType = taskType; task->taskType = taskType;
task->jobId = INVALID_JOB_ID; task->jobId = INVALID_JOB_ID;
task->taskId = INVALID_TASK_ID; task->taskId = INVALID_TASK_ID;
@ -1666,22 +1633,18 @@ static Job *
RouterJob(Query *originalQuery, PlannerRestrictionContext *plannerRestrictionContext, RouterJob(Query *originalQuery, PlannerRestrictionContext *plannerRestrictionContext,
DeferredErrorMessage **planningError) DeferredErrorMessage **planningError)
{ {
Job *job = NULL;
uint64 shardId = INVALID_SHARD_ID; uint64 shardId = INVALID_SHARD_ID;
List *placementList = NIL; List *placementList = NIL;
List *relationShardList = NIL; List *relationShardList = NIL;
List *prunedShardIntervalListList = NIL; List *prunedShardIntervalListList = NIL;
bool replacePrunedQueryWithDummy = false;
bool requiresMasterEvaluation = false;
RangeTblEntry *updateOrDeleteRTE = NULL;
bool isMultiShardModifyQuery = false; bool isMultiShardModifyQuery = false;
Const *partitionKeyValue = NULL; Const *partitionKeyValue = NULL;
/* router planner should create task even if it doesn't hit a shard at all */ /* router planner should create task even if it doesn't hit a shard at all */
replacePrunedQueryWithDummy = true; bool replacePrunedQueryWithDummy = true;
/* check if this query requires master evaluation */ /* check if this query requires master evaluation */
requiresMasterEvaluation = RequiresMasterEvaluation(originalQuery); bool requiresMasterEvaluation = RequiresMasterEvaluation(originalQuery);
(*planningError) = PlanRouterQuery(originalQuery, plannerRestrictionContext, (*planningError) = PlanRouterQuery(originalQuery, plannerRestrictionContext,
&placementList, &shardId, &relationShardList, &placementList, &shardId, &relationShardList,
@ -1694,10 +1657,10 @@ RouterJob(Query *originalQuery, PlannerRestrictionContext *plannerRestrictionCon
return NULL; return NULL;
} }
job = CreateJob(originalQuery); Job *job = CreateJob(originalQuery);
job->partitionKeyValue = partitionKeyValue; job->partitionKeyValue = partitionKeyValue;
updateOrDeleteRTE = GetUpdateOrDeleteRTE(originalQuery); RangeTblEntry *updateOrDeleteRTE = GetUpdateOrDeleteRTE(originalQuery);
/* /*
* If all of the shards are pruned, we replace the relation RTE into * If all of the shards are pruned, we replace the relation RTE into
@ -1770,16 +1733,12 @@ ReorderTaskPlacementsByTaskAssignmentPolicy(Job *job,
{ {
if (taskAssignmentPolicy == TASK_ASSIGNMENT_ROUND_ROBIN) if (taskAssignmentPolicy == TASK_ASSIGNMENT_ROUND_ROBIN)
{ {
Task *task = NULL;
List *reorderedPlacementList = NIL;
ShardPlacement *primaryPlacement = NULL;
/* /*
* We hit a single shard on router plans, and there should be only * We hit a single shard on router plans, and there should be only
* one task in the task list * one task in the task list
*/ */
Assert(list_length(job->taskList) == 1); Assert(list_length(job->taskList) == 1);
task = (Task *) linitial(job->taskList); Task *task = (Task *) linitial(job->taskList);
/* /*
* For round-robin SELECT queries, we don't want to include the coordinator * For round-robin SELECT queries, we don't want to include the coordinator
@ -1796,10 +1755,11 @@ ReorderTaskPlacementsByTaskAssignmentPolicy(Job *job,
placementList = RemoveCoordinatorPlacement(placementList); placementList = RemoveCoordinatorPlacement(placementList);
/* reorder the placement list */ /* reorder the placement list */
reorderedPlacementList = RoundRobinReorder(task, placementList); List *reorderedPlacementList = RoundRobinReorder(task, placementList);
task->taskPlacementList = reorderedPlacementList; task->taskPlacementList = reorderedPlacementList;
primaryPlacement = (ShardPlacement *) linitial(reorderedPlacementList); ShardPlacement *primaryPlacement = (ShardPlacement *) linitial(
reorderedPlacementList);
ereport(DEBUG3, (errmsg("assigned task %u to node %s:%u", task->taskId, ereport(DEBUG3, (errmsg("assigned task %u to node %s:%u", task->taskId,
primaryPlacement->nodeName, primaryPlacement->nodeName,
primaryPlacement->nodePort))); primaryPlacement->nodePort)));
@ -1916,16 +1876,14 @@ SingleShardModifyTaskList(Query *query, uint64 jobId, List *relationShardList,
{ {
Task *task = CreateTask(MODIFY_TASK); Task *task = CreateTask(MODIFY_TASK);
StringInfo queryString = makeStringInfo(); StringInfo queryString = makeStringInfo();
DistTableCacheEntry *modificationTableCacheEntry = NULL;
char modificationPartitionMethod = 0;
List *rangeTableList = NIL; List *rangeTableList = NIL;
RangeTblEntry *updateOrDeleteRTE = NULL;
ExtractRangeTableEntryWalker((Node *) query, &rangeTableList); ExtractRangeTableEntryWalker((Node *) query, &rangeTableList);
updateOrDeleteRTE = GetUpdateOrDeleteRTE(query); RangeTblEntry *updateOrDeleteRTE = GetUpdateOrDeleteRTE(query);
modificationTableCacheEntry = DistributedTableCacheEntry(updateOrDeleteRTE->relid); DistTableCacheEntry *modificationTableCacheEntry = DistributedTableCacheEntry(
modificationPartitionMethod = modificationTableCacheEntry->partitionMethod; updateOrDeleteRTE->relid);
char modificationPartitionMethod = modificationTableCacheEntry->partitionMethod;
if (modificationPartitionMethod == DISTRIBUTE_BY_NONE && if (modificationPartitionMethod == DISTRIBUTE_BY_NONE &&
SelectsFromDistributedTable(rangeTableList, query)) SelectsFromDistributedTable(rangeTableList, query))
@ -1983,14 +1941,14 @@ SelectsFromDistributedTable(List *rangeTableList, Query *query)
foreach(rangeTableCell, rangeTableList) foreach(rangeTableCell, rangeTableList)
{ {
RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell); RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell);
DistTableCacheEntry *cacheEntry = NULL;
if (rangeTableEntry->relid == InvalidOid) if (rangeTableEntry->relid == InvalidOid)
{ {
continue; continue;
} }
cacheEntry = DistributedTableCacheEntry(rangeTableEntry->relid); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(
rangeTableEntry->relid);
if (cacheEntry->partitionMethod != DISTRIBUTE_BY_NONE && if (cacheEntry->partitionMethod != DISTRIBUTE_BY_NONE &&
(resultRangeTableEntry == NULL || resultRangeTableEntry->relid != (resultRangeTableEntry == NULL || resultRangeTableEntry->relid !=
rangeTableEntry->relid)) rangeTableEntry->relid))
@ -2242,7 +2200,6 @@ GetAnchorShardId(List *prunedShardIntervalListList)
foreach(prunedShardIntervalListCell, prunedShardIntervalListList) foreach(prunedShardIntervalListCell, prunedShardIntervalListList)
{ {
List *prunedShardIntervalList = (List *) lfirst(prunedShardIntervalListCell); List *prunedShardIntervalList = (List *) lfirst(prunedShardIntervalListCell);
ShardInterval *shardInterval = NULL;
/* no shard is present or all shards are pruned out case will be handled later */ /* no shard is present or all shards are pruned out case will be handled later */
if (prunedShardIntervalList == NIL) if (prunedShardIntervalList == NIL)
@ -2250,7 +2207,7 @@ GetAnchorShardId(List *prunedShardIntervalListList)
continue; continue;
} }
shardInterval = linitial(prunedShardIntervalList); ShardInterval *shardInterval = linitial(prunedShardIntervalList);
if (ReferenceTableShardId(shardInterval->shardId)) if (ReferenceTableShardId(shardInterval->shardId))
{ {
@ -2341,7 +2298,6 @@ TargetShardIntervalsForRestrictInfo(RelationRestrictionContext *restrictionConte
List *prunedShardIntervalList = NIL; List *prunedShardIntervalList = NIL;
List *joinInfoList = relationRestriction->relOptInfo->joininfo; List *joinInfoList = relationRestriction->relOptInfo->joininfo;
List *pseudoRestrictionList = extract_actual_clauses(joinInfoList, true); List *pseudoRestrictionList = extract_actual_clauses(joinInfoList, true);
bool whereFalseQuery = false;
relationRestriction->prunedShardIntervalList = NIL; relationRestriction->prunedShardIntervalList = NIL;
@ -2351,7 +2307,7 @@ TargetShardIntervalsForRestrictInfo(RelationRestrictionContext *restrictionConte
* inside relOptInfo->joininfo list. We treat such cases as if all * inside relOptInfo->joininfo list. We treat such cases as if all
* shards of the table are pruned out. * shards of the table are pruned out.
*/ */
whereFalseQuery = ContainsFalseClause(pseudoRestrictionList); bool whereFalseQuery = ContainsFalseClause(pseudoRestrictionList);
if (!whereFalseQuery && shardCount > 0) if (!whereFalseQuery && shardCount > 0)
{ {
Const *restrictionPartitionValueConst = NULL; Const *restrictionPartitionValueConst = NULL;
@ -2445,9 +2401,6 @@ WorkersContainingAllShards(List *prunedShardIntervalsList)
foreach(prunedShardIntervalCell, prunedShardIntervalsList) foreach(prunedShardIntervalCell, prunedShardIntervalsList)
{ {
List *shardIntervalList = (List *) lfirst(prunedShardIntervalCell); List *shardIntervalList = (List *) lfirst(prunedShardIntervalCell);
ShardInterval *shardInterval = NULL;
uint64 shardId = INVALID_SHARD_ID;
List *newPlacementList = NIL;
if (shardIntervalList == NIL) if (shardIntervalList == NIL)
{ {
@ -2456,11 +2409,11 @@ WorkersContainingAllShards(List *prunedShardIntervalsList)
Assert(list_length(shardIntervalList) == 1); Assert(list_length(shardIntervalList) == 1);
shardInterval = (ShardInterval *) linitial(shardIntervalList); ShardInterval *shardInterval = (ShardInterval *) linitial(shardIntervalList);
shardId = shardInterval->shardId; uint64 shardId = shardInterval->shardId;
/* retrieve all active shard placements for this shard */ /* retrieve all active shard placements for this shard */
newPlacementList = FinalizedShardPlacementList(shardId); List *newPlacementList = FinalizedShardPlacementList(shardId);
if (firstShard) if (firstShard)
{ {
@ -2506,8 +2459,6 @@ BuildRoutesForInsert(Query *query, DeferredErrorMessage **planningError)
DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId);
char partitionMethod = cacheEntry->partitionMethod; char partitionMethod = cacheEntry->partitionMethod;
uint32 rangeTableId = 1; uint32 rangeTableId = 1;
Var *partitionColumn = NULL;
List *insertValuesList = NIL;
List *modifyRouteList = NIL; List *modifyRouteList = NIL;
ListCell *insertValuesCell = NULL; ListCell *insertValuesCell = NULL;
@ -2516,24 +2467,20 @@ BuildRoutesForInsert(Query *query, DeferredErrorMessage **planningError)
/* reference tables can only have one shard */ /* reference tables can only have one shard */
if (partitionMethod == DISTRIBUTE_BY_NONE) if (partitionMethod == DISTRIBUTE_BY_NONE)
{ {
int shardCount = 0;
List *shardIntervalList = LoadShardIntervalList(distributedTableId); List *shardIntervalList = LoadShardIntervalList(distributedTableId);
RangeTblEntry *valuesRTE = NULL;
ShardInterval *shardInterval = NULL;
ModifyRoute *modifyRoute = NULL;
shardCount = list_length(shardIntervalList); int shardCount = list_length(shardIntervalList);
if (shardCount != 1) if (shardCount != 1)
{ {
ereport(ERROR, (errmsg("reference table cannot have %d shards", shardCount))); ereport(ERROR, (errmsg("reference table cannot have %d shards", shardCount)));
} }
shardInterval = linitial(shardIntervalList); ShardInterval *shardInterval = linitial(shardIntervalList);
modifyRoute = palloc(sizeof(ModifyRoute)); ModifyRoute *modifyRoute = palloc(sizeof(ModifyRoute));
modifyRoute->shardId = shardInterval->shardId; modifyRoute->shardId = shardInterval->shardId;
valuesRTE = ExtractDistributedInsertValuesRTE(query); RangeTblEntry *valuesRTE = ExtractDistributedInsertValuesRTE(query);
if (valuesRTE != NULL) if (valuesRTE != NULL)
{ {
/* add the values list for a multi-row INSERT */ /* add the values list for a multi-row INSERT */
@ -2549,18 +2496,15 @@ BuildRoutesForInsert(Query *query, DeferredErrorMessage **planningError)
return modifyRouteList; return modifyRouteList;
} }
partitionColumn = PartitionColumn(distributedTableId, rangeTableId); Var *partitionColumn = PartitionColumn(distributedTableId, rangeTableId);
/* get full list of insert values and iterate over them to prune */ /* get full list of insert values and iterate over them to prune */
insertValuesList = ExtractInsertValuesList(query, partitionColumn); List *insertValuesList = ExtractInsertValuesList(query, partitionColumn);
foreach(insertValuesCell, insertValuesList) foreach(insertValuesCell, insertValuesList)
{ {
InsertValues *insertValues = (InsertValues *) lfirst(insertValuesCell); InsertValues *insertValues = (InsertValues *) lfirst(insertValuesCell);
Const *partitionValueConst = NULL;
List *prunedShardIntervalList = NIL; List *prunedShardIntervalList = NIL;
int prunedShardIntervalCount = 0;
ShardInterval *targetShard = NULL;
if (!IsA(insertValues->partitionValueExpr, Const)) if (!IsA(insertValues->partitionValueExpr, Const))
{ {
@ -2568,7 +2512,7 @@ BuildRoutesForInsert(Query *query, DeferredErrorMessage **planningError)
return NIL; return NIL;
} }
partitionValueConst = (Const *) insertValues->partitionValueExpr; Const *partitionValueConst = (Const *) insertValues->partitionValueExpr;
if (partitionValueConst->constisnull) if (partitionValueConst->constisnull)
{ {
ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
@ -2580,10 +2524,9 @@ BuildRoutesForInsert(Query *query, DeferredErrorMessage **planningError)
DISTRIBUTE_BY_RANGE) DISTRIBUTE_BY_RANGE)
{ {
Datum partitionValue = partitionValueConst->constvalue; Datum partitionValue = partitionValueConst->constvalue;
ShardInterval *shardInterval = NULL;
cacheEntry = DistributedTableCacheEntry(distributedTableId); cacheEntry = DistributedTableCacheEntry(distributedTableId);
shardInterval = FindShardInterval(partitionValue, cacheEntry); ShardInterval *shardInterval = FindShardInterval(partitionValue, cacheEntry);
if (shardInterval != NULL) if (shardInterval != NULL)
{ {
prunedShardIntervalList = list_make1(shardInterval); prunedShardIntervalList = list_make1(shardInterval);
@ -2591,7 +2534,6 @@ BuildRoutesForInsert(Query *query, DeferredErrorMessage **planningError)
} }
else else
{ {
List *restrictClauseList = NIL;
Index tableId = 1; Index tableId = 1;
OpExpr *equalityExpr = MakeOpExpression(partitionColumn, OpExpr *equalityExpr = MakeOpExpression(partitionColumn,
BTEqualStrategyNumber); BTEqualStrategyNumber);
@ -2604,13 +2546,13 @@ BuildRoutesForInsert(Query *query, DeferredErrorMessage **planningError)
rightConst->constisnull = partitionValueConst->constisnull; rightConst->constisnull = partitionValueConst->constisnull;
rightConst->constbyval = partitionValueConst->constbyval; rightConst->constbyval = partitionValueConst->constbyval;
restrictClauseList = list_make1(equalityExpr); List *restrictClauseList = list_make1(equalityExpr);
prunedShardIntervalList = PruneShards(distributedTableId, tableId, prunedShardIntervalList = PruneShards(distributedTableId, tableId,
restrictClauseList, NULL); restrictClauseList, NULL);
} }
prunedShardIntervalCount = list_length(prunedShardIntervalList); int prunedShardIntervalCount = list_length(prunedShardIntervalList);
if (prunedShardIntervalCount != 1) if (prunedShardIntervalCount != 1)
{ {
char *partitionKeyString = cacheEntry->partitionKeyString; char *partitionKeyString = cacheEntry->partitionKeyString;
@ -2651,7 +2593,7 @@ BuildRoutesForInsert(Query *query, DeferredErrorMessage **planningError)
return NIL; return NIL;
} }
targetShard = (ShardInterval *) linitial(prunedShardIntervalList); ShardInterval *targetShard = (ShardInterval *) linitial(prunedShardIntervalList);
insertValues->shardId = targetShard->shardId; insertValues->shardId = targetShard->shardId;
} }
@ -2768,19 +2710,15 @@ NormalizeMultiRowInsertTargetList(Query *query)
{ {
TargetEntry *targetEntry = lfirst(targetEntryCell); TargetEntry *targetEntry = lfirst(targetEntryCell);
Node *targetExprNode = (Node *) targetEntry->expr; Node *targetExprNode = (Node *) targetEntry->expr;
Oid targetType = InvalidOid;
int32 targetTypmod = -1;
Oid targetColl = InvalidOid;
Var *syntheticVar = NULL;
/* RTE_VALUES comes 2nd, after destination table */ /* RTE_VALUES comes 2nd, after destination table */
Index valuesVarno = 2; Index valuesVarno = 2;
targetEntryNo++; targetEntryNo++;
targetType = exprType(targetExprNode); Oid targetType = exprType(targetExprNode);
targetTypmod = exprTypmod(targetExprNode); int32 targetTypmod = exprTypmod(targetExprNode);
targetColl = exprCollation(targetExprNode); Oid targetColl = exprCollation(targetExprNode);
valuesRTE->coltypes = lappend_oid(valuesRTE->coltypes, targetType); valuesRTE->coltypes = lappend_oid(valuesRTE->coltypes, targetType);
valuesRTE->coltypmods = lappend_int(valuesRTE->coltypmods, targetTypmod); valuesRTE->coltypmods = lappend_int(valuesRTE->coltypmods, targetTypmod);
@ -2794,7 +2732,7 @@ NormalizeMultiRowInsertTargetList(Query *query)
} }
/* replace the original expression with a Var referencing values_lists */ /* replace the original expression with a Var referencing values_lists */
syntheticVar = makeVar(valuesVarno, targetEntryNo, targetType, targetTypmod, Var *syntheticVar = makeVar(valuesVarno, targetEntryNo, targetType, targetTypmod,
targetColl, 0); targetColl, 0);
targetEntry->expr = (Expr *) syntheticVar; targetEntry->expr = (Expr *) syntheticVar;
} }
@ -2935,11 +2873,10 @@ ExtractInsertValuesList(Query *query, Var *partitionColumn)
if (IsA(targetEntry->expr, Var)) if (IsA(targetEntry->expr, Var))
{ {
Var *partitionVar = (Var *) targetEntry->expr; Var *partitionVar = (Var *) targetEntry->expr;
RangeTblEntry *referencedRTE = NULL;
ListCell *valuesListCell = NULL; ListCell *valuesListCell = NULL;
Index ivIndex = 0; Index ivIndex = 0;
referencedRTE = rt_fetch(partitionVar->varno, query->rtable); RangeTblEntry *referencedRTE = rt_fetch(partitionVar->varno, query->rtable);
foreach(valuesListCell, referencedRTE->values_lists) foreach(valuesListCell, referencedRTE->values_lists)
{ {
InsertValues *insertValues = (InsertValues *) palloc(sizeof(InsertValues)); InsertValues *insertValues = (InsertValues *) palloc(sizeof(InsertValues));
@ -2980,10 +2917,7 @@ ExtractInsertPartitionKeyValue(Query *query)
{ {
Oid distributedTableId = ExtractFirstDistributedTableId(query); Oid distributedTableId = ExtractFirstDistributedTableId(query);
uint32 rangeTableId = 1; uint32 rangeTableId = 1;
Var *partitionColumn = NULL;
TargetEntry *targetEntry = NULL;
Const *singlePartitionValueConst = NULL; Const *singlePartitionValueConst = NULL;
Node *targetExpression = NULL;
char partitionMethod = PartitionMethod(distributedTableId); char partitionMethod = PartitionMethod(distributedTableId);
if (partitionMethod == DISTRIBUTE_BY_NONE) if (partitionMethod == DISTRIBUTE_BY_NONE)
@ -2991,15 +2925,16 @@ ExtractInsertPartitionKeyValue(Query *query)
return NULL; return NULL;
} }
partitionColumn = PartitionColumn(distributedTableId, rangeTableId); Var *partitionColumn = PartitionColumn(distributedTableId, rangeTableId);
targetEntry = get_tle_by_resno(query->targetList, partitionColumn->varattno); TargetEntry *targetEntry = get_tle_by_resno(query->targetList,
partitionColumn->varattno);
if (targetEntry == NULL) if (targetEntry == NULL)
{ {
/* partition column value not specified */ /* partition column value not specified */
return NULL; return NULL;
} }
targetExpression = strip_implicit_coercions((Node *) targetEntry->expr); Node *targetExpression = strip_implicit_coercions((Node *) targetEntry->expr);
/* /*
* Multi-row INSERTs have a Var in the target list that points to * Multi-row INSERTs have a Var in the target list that points to
@ -3008,10 +2943,9 @@ ExtractInsertPartitionKeyValue(Query *query)
if (IsA(targetExpression, Var)) if (IsA(targetExpression, Var))
{ {
Var *partitionVar = (Var *) targetExpression; Var *partitionVar = (Var *) targetExpression;
RangeTblEntry *referencedRTE = NULL;
ListCell *valuesListCell = NULL; ListCell *valuesListCell = NULL;
referencedRTE = rt_fetch(partitionVar->varno, query->rtable); RangeTblEntry *referencedRTE = rt_fetch(partitionVar->varno, query->rtable);
foreach(valuesListCell, referencedRTE->values_lists) foreach(valuesListCell, referencedRTE->values_lists)
{ {
@ -3019,7 +2953,6 @@ ExtractInsertPartitionKeyValue(Query *query)
Node *partitionValueNode = list_nth(rowValues, partitionVar->varattno - 1); Node *partitionValueNode = list_nth(rowValues, partitionVar->varattno - 1);
Expr *partitionValueExpr = (Expr *) strip_implicit_coercions( Expr *partitionValueExpr = (Expr *) strip_implicit_coercions(
partitionValueNode); partitionValueNode);
Const *partitionValueConst = NULL;
if (!IsA(partitionValueExpr, Const)) if (!IsA(partitionValueExpr, Const))
{ {
@ -3028,7 +2961,7 @@ ExtractInsertPartitionKeyValue(Query *query)
break; break;
} }
partitionValueConst = (Const *) partitionValueExpr; Const *partitionValueConst = (Const *) partitionValueExpr;
if (singlePartitionValueConst == NULL) if (singlePartitionValueConst == NULL)
{ {
@ -3098,7 +3031,6 @@ MultiRouterPlannableQuery(Query *query)
{ {
/* only hash partitioned tables are supported */ /* only hash partitioned tables are supported */
Oid distributedTableId = rte->relid; Oid distributedTableId = rte->relid;
char partitionMethod = 0;
if (!IsDistributedTable(distributedTableId)) if (!IsDistributedTable(distributedTableId))
{ {
@ -3109,7 +3041,7 @@ MultiRouterPlannableQuery(Query *query)
NULL, NULL); NULL, NULL);
} }
partitionMethod = PartitionMethod(distributedTableId); char partitionMethod = PartitionMethod(distributedTableId);
if (!(partitionMethod == DISTRIBUTE_BY_HASH || partitionMethod == if (!(partitionMethod == DISTRIBUTE_BY_HASH || partitionMethod ==
DISTRIBUTE_BY_NONE || partitionMethod == DISTRIBUTE_BY_RANGE)) DISTRIBUTE_BY_NONE || partitionMethod == DISTRIBUTE_BY_RANGE))
{ {

View File

@ -43,9 +43,6 @@ make_unique_from_sortclauses(Plan *lefttree, List *distinctList)
Plan *plan = &node->plan; Plan *plan = &node->plan;
int numCols = list_length(distinctList); int numCols = list_length(distinctList);
int keyno = 0; int keyno = 0;
AttrNumber *uniqColIdx;
Oid *uniqOperators;
Oid *uniqCollations;
ListCell *slitem; ListCell *slitem;
plan->targetlist = lefttree->targetlist; plan->targetlist = lefttree->targetlist;
@ -58,9 +55,9 @@ make_unique_from_sortclauses(Plan *lefttree, List *distinctList)
* operators, as wanted by executor * operators, as wanted by executor
*/ */
Assert(numCols > 0); Assert(numCols > 0);
uniqColIdx = (AttrNumber *) palloc(sizeof(AttrNumber) * numCols); AttrNumber *uniqColIdx = (AttrNumber *) palloc(sizeof(AttrNumber) * numCols);
uniqOperators = (Oid *) palloc(sizeof(Oid) * numCols); Oid *uniqOperators = (Oid *) palloc(sizeof(Oid) * numCols);
uniqCollations = (Oid *) palloc(sizeof(Oid) * numCols); Oid *uniqCollations = (Oid *) palloc(sizeof(Oid) * numCols);
foreach(slitem, distinctList) foreach(slitem, distinctList)
{ {
@ -97,8 +94,6 @@ make_unique_from_sortclauses(Plan *lefttree, List *distinctList)
Plan *plan = &node->plan; Plan *plan = &node->plan;
int numCols = list_length(distinctList); int numCols = list_length(distinctList);
int keyno = 0; int keyno = 0;
AttrNumber *uniqColIdx;
Oid *uniqOperators;
ListCell *slitem; ListCell *slitem;
plan->targetlist = lefttree->targetlist; plan->targetlist = lefttree->targetlist;
@ -111,8 +106,8 @@ make_unique_from_sortclauses(Plan *lefttree, List *distinctList)
* operators, as wanted by executor * operators, as wanted by executor
*/ */
Assert(numCols > 0); Assert(numCols > 0);
uniqColIdx = (AttrNumber *) palloc(sizeof(AttrNumber) * numCols); AttrNumber *uniqColIdx = (AttrNumber *) palloc(sizeof(AttrNumber) * numCols);
uniqOperators = (Oid *) palloc(sizeof(Oid) * numCols); Oid *uniqOperators = (Oid *) palloc(sizeof(Oid) * numCols);
foreach(slitem, distinctList) foreach(slitem, distinctList)
{ {

View File

@ -49,14 +49,10 @@ CreateColocatedJoinChecker(Query *subquery, PlannerRestrictionContext *restricti
{ {
ColocatedJoinChecker colocatedJoinChecker; ColocatedJoinChecker colocatedJoinChecker;
RangeTblEntry *anchorRangeTblEntry = NULL;
Query *anchorSubquery = NULL; Query *anchorSubquery = NULL;
PlannerRestrictionContext *anchorPlannerRestrictionContext = NULL;
RelationRestrictionContext *anchorRelationRestrictionContext = NULL;
List *anchorRestrictionEquivalences = NIL;
/* we couldn't pick an anchor subquery, no need to continue */ /* we couldn't pick an anchor subquery, no need to continue */
anchorRangeTblEntry = AnchorRte(subquery); RangeTblEntry *anchorRangeTblEntry = AnchorRte(subquery);
if (anchorRangeTblEntry == NULL) if (anchorRangeTblEntry == NULL)
{ {
colocatedJoinChecker.anchorRelationRestrictionList = NIL; colocatedJoinChecker.anchorRelationRestrictionList = NIL;
@ -84,11 +80,11 @@ CreateColocatedJoinChecker(Query *subquery, PlannerRestrictionContext *restricti
pg_unreachable(); pg_unreachable();
} }
anchorPlannerRestrictionContext = PlannerRestrictionContext *anchorPlannerRestrictionContext =
FilterPlannerRestrictionForQuery(restrictionContext, anchorSubquery); FilterPlannerRestrictionForQuery(restrictionContext, anchorSubquery);
anchorRelationRestrictionContext = RelationRestrictionContext *anchorRelationRestrictionContext =
anchorPlannerRestrictionContext->relationRestrictionContext; anchorPlannerRestrictionContext->relationRestrictionContext;
anchorRestrictionEquivalences = List *anchorRestrictionEquivalences =
GenerateAllAttributeEquivalences(anchorPlannerRestrictionContext); GenerateAllAttributeEquivalences(anchorPlannerRestrictionContext);
/* fill the non colocated planning context */ /* fill the non colocated planning context */
@ -191,9 +187,6 @@ SubqueryColocated(Query *subquery, ColocatedJoinChecker *checker)
List *filteredRestrictionList = List *filteredRestrictionList =
filteredPlannerContext->relationRestrictionContext->relationRestrictionList; filteredPlannerContext->relationRestrictionContext->relationRestrictionList;
List *unionedRelationRestrictionList = NULL;
RelationRestrictionContext *unionedRelationRestrictionContext = NULL;
PlannerRestrictionContext *unionedPlannerRestrictionContext = NULL;
/* /*
* There are no relations in the input subquery, such as a subquery * There are no relations in the input subquery, such as a subquery
@ -213,7 +206,7 @@ SubqueryColocated(Query *subquery, ColocatedJoinChecker *checker)
* forming this temporary context is to check whether the context contains * forming this temporary context is to check whether the context contains
* distribution key equality or not. * distribution key equality or not.
*/ */
unionedRelationRestrictionList = List *unionedRelationRestrictionList =
UnionRelationRestrictionLists(anchorRelationRestrictionList, UnionRelationRestrictionLists(anchorRelationRestrictionList,
filteredRestrictionList); filteredRestrictionList);
@ -224,11 +217,13 @@ SubqueryColocated(Query *subquery, ColocatedJoinChecker *checker)
* join restrictions, we're already relying on the attributeEquivalances * join restrictions, we're already relying on the attributeEquivalances
* provided by the context. * provided by the context.
*/ */
unionedRelationRestrictionContext = palloc0(sizeof(RelationRestrictionContext)); RelationRestrictionContext *unionedRelationRestrictionContext = palloc0(
sizeof(RelationRestrictionContext));
unionedRelationRestrictionContext->relationRestrictionList = unionedRelationRestrictionContext->relationRestrictionList =
unionedRelationRestrictionList; unionedRelationRestrictionList;
unionedPlannerRestrictionContext = palloc0(sizeof(PlannerRestrictionContext)); PlannerRestrictionContext *unionedPlannerRestrictionContext = palloc0(
sizeof(PlannerRestrictionContext));
unionedPlannerRestrictionContext->relationRestrictionContext = unionedPlannerRestrictionContext->relationRestrictionContext =
unionedRelationRestrictionContext; unionedRelationRestrictionContext;
@ -256,14 +251,11 @@ WrapRteRelationIntoSubquery(RangeTblEntry *rteRelation)
{ {
Query *subquery = makeNode(Query); Query *subquery = makeNode(Query);
RangeTblRef *newRangeTableRef = makeNode(RangeTblRef); RangeTblRef *newRangeTableRef = makeNode(RangeTblRef);
RangeTblEntry *newRangeTableEntry = NULL;
Var *targetColumn = NULL;
TargetEntry *targetEntry = NULL;
subquery->commandType = CMD_SELECT; subquery->commandType = CMD_SELECT;
/* we copy the input rteRelation to preserve the rteIdentity */ /* we copy the input rteRelation to preserve the rteIdentity */
newRangeTableEntry = copyObject(rteRelation); RangeTblEntry *newRangeTableEntry = copyObject(rteRelation);
subquery->rtable = list_make1(newRangeTableEntry); subquery->rtable = list_make1(newRangeTableEntry);
/* set the FROM expression to the subquery */ /* set the FROM expression to the subquery */
@ -272,11 +264,12 @@ WrapRteRelationIntoSubquery(RangeTblEntry *rteRelation)
subquery->jointree = makeFromExpr(list_make1(newRangeTableRef), NULL); subquery->jointree = makeFromExpr(list_make1(newRangeTableRef), NULL);
/* Need the whole row as a junk var */ /* Need the whole row as a junk var */
targetColumn = makeWholeRowVar(newRangeTableEntry, newRangeTableRef->rtindex, 0, Var *targetColumn = makeWholeRowVar(newRangeTableEntry, newRangeTableRef->rtindex, 0,
false); false);
/* create a dummy target entry */ /* create a dummy target entry */
targetEntry = makeTargetEntry((Expr *) targetColumn, 1, "wholerow", true); TargetEntry *targetEntry = makeTargetEntry((Expr *) targetColumn, 1, "wholerow",
true);
subquery->targetList = lappend(subquery->targetList, targetEntry); subquery->targetList = lappend(subquery->targetList, targetEntry);
@ -292,15 +285,13 @@ WrapRteRelationIntoSubquery(RangeTblEntry *rteRelation)
static List * static List *
UnionRelationRestrictionLists(List *firstRelationList, List *secondRelationList) UnionRelationRestrictionLists(List *firstRelationList, List *secondRelationList)
{ {
RelationRestrictionContext *unionedRestrictionContext = NULL;
List *unionedRelationRestrictionList = NULL; List *unionedRelationRestrictionList = NULL;
ListCell *relationRestrictionCell = NULL; ListCell *relationRestrictionCell = NULL;
Relids rteIdentities = NULL; Relids rteIdentities = NULL;
List *allRestrictionList = NIL;
/* list_concat destructively modifies the first list, thus copy it */ /* list_concat destructively modifies the first list, thus copy it */
firstRelationList = list_copy(firstRelationList); firstRelationList = list_copy(firstRelationList);
allRestrictionList = list_concat(firstRelationList, secondRelationList); List *allRestrictionList = list_concat(firstRelationList, secondRelationList);
foreach(relationRestrictionCell, allRestrictionList) foreach(relationRestrictionCell, allRestrictionList)
{ {
@ -320,7 +311,8 @@ UnionRelationRestrictionLists(List *firstRelationList, List *secondRelationList)
rteIdentities = bms_add_member(rteIdentities, rteIdentity); rteIdentities = bms_add_member(rteIdentities, rteIdentity);
} }
unionedRestrictionContext = palloc0(sizeof(RelationRestrictionContext)); RelationRestrictionContext *unionedRestrictionContext = palloc0(
sizeof(RelationRestrictionContext));
unionedRestrictionContext->relationRestrictionList = unionedRelationRestrictionList; unionedRestrictionContext->relationRestrictionList = unionedRelationRestrictionList;
return unionedRelationRestrictionList; return unionedRelationRestrictionList;

View File

@ -107,7 +107,6 @@ bool
ShouldUseSubqueryPushDown(Query *originalQuery, Query *rewrittenQuery, ShouldUseSubqueryPushDown(Query *originalQuery, Query *rewrittenQuery,
PlannerRestrictionContext *plannerRestrictionContext) PlannerRestrictionContext *plannerRestrictionContext)
{ {
List *qualifierList = NIL;
StringInfo errorMessage = NULL; StringInfo errorMessage = NULL;
/* /*
@ -183,7 +182,7 @@ ShouldUseSubqueryPushDown(Query *originalQuery, Query *rewrittenQuery,
* Some unsupported join clauses in logical planner * Some unsupported join clauses in logical planner
* may be supported by subquery pushdown planner. * may be supported by subquery pushdown planner.
*/ */
qualifierList = QualifierList(rewrittenQuery->jointree); List *qualifierList = QualifierList(rewrittenQuery->jointree);
if (DeferErrorIfUnsupportedClause(qualifierList) != NULL) if (DeferErrorIfUnsupportedClause(qualifierList) != NULL)
{ {
return true; return true;
@ -283,7 +282,6 @@ bool
WhereOrHavingClauseContainsSubquery(Query *query) WhereOrHavingClauseContainsSubquery(Query *query)
{ {
FromExpr *joinTree = query->jointree; FromExpr *joinTree = query->jointree;
Node *queryQuals = NULL;
if (FindNodeCheck(query->havingQual, IsNodeSubquery)) if (FindNodeCheck(query->havingQual, IsNodeSubquery))
{ {
@ -295,7 +293,7 @@ WhereOrHavingClauseContainsSubquery(Query *query)
return false; return false;
} }
queryQuals = joinTree->quals; Node *queryQuals = joinTree->quals;
return FindNodeCheck(queryQuals, IsNodeSubquery); return FindNodeCheck(queryQuals, IsNodeSubquery);
} }
@ -450,15 +448,13 @@ WindowPartitionOnDistributionColumn(Query *query)
foreach(windowClauseCell, windowClauseList) foreach(windowClauseCell, windowClauseList)
{ {
WindowClause *windowClause = lfirst(windowClauseCell); WindowClause *windowClause = lfirst(windowClauseCell);
List *groupTargetEntryList = NIL;
bool partitionOnDistributionColumn = false;
List *partitionClauseList = windowClause->partitionClause; List *partitionClauseList = windowClause->partitionClause;
List *targetEntryList = query->targetList; List *targetEntryList = query->targetList;
groupTargetEntryList = List *groupTargetEntryList =
GroupTargetEntryList(partitionClauseList, targetEntryList); GroupTargetEntryList(partitionClauseList, targetEntryList);
partitionOnDistributionColumn = bool partitionOnDistributionColumn =
TargetListOnPartitionColumn(query, groupTargetEntryList); TargetListOnPartitionColumn(query, groupTargetEntryList);
if (!partitionOnDistributionColumn) if (!partitionOnDistributionColumn)
@ -495,14 +491,13 @@ SubqueryMultiNodeTree(Query *originalQuery, Query *queryTree,
PlannerRestrictionContext *plannerRestrictionContext) PlannerRestrictionContext *plannerRestrictionContext)
{ {
MultiNode *multiQueryNode = NULL; MultiNode *multiQueryNode = NULL;
DeferredErrorMessage *subqueryPushdownError = NULL;
DeferredErrorMessage *unsupportedQueryError = NULL;
/* /*
* This is a generic error check that applies to both subquery pushdown * This is a generic error check that applies to both subquery pushdown
* and single table repartition subquery. * and single table repartition subquery.
*/ */
unsupportedQueryError = DeferErrorIfQueryNotSupported(originalQuery); DeferredErrorMessage *unsupportedQueryError = DeferErrorIfQueryNotSupported(
originalQuery);
if (unsupportedQueryError != NULL) if (unsupportedQueryError != NULL)
{ {
RaiseDeferredError(unsupportedQueryError, ERROR); RaiseDeferredError(unsupportedQueryError, ERROR);
@ -513,7 +508,8 @@ SubqueryMultiNodeTree(Query *originalQuery, Query *queryTree,
* to create a logical plan, continue with trying the single table * to create a logical plan, continue with trying the single table
* repartition subquery planning. * repartition subquery planning.
*/ */
subqueryPushdownError = DeferErrorIfUnsupportedSubqueryPushdown(originalQuery, DeferredErrorMessage *subqueryPushdownError = DeferErrorIfUnsupportedSubqueryPushdown(
originalQuery,
plannerRestrictionContext); plannerRestrictionContext);
if (!subqueryPushdownError) if (!subqueryPushdownError)
{ {
@ -521,30 +517,26 @@ SubqueryMultiNodeTree(Query *originalQuery, Query *queryTree,
} }
else if (subqueryPushdownError) else if (subqueryPushdownError)
{ {
bool singleRelationRepartitionSubquery = false;
RangeTblEntry *subqueryRangeTableEntry = NULL;
Query *subqueryTree = NULL;
DeferredErrorMessage *repartitionQueryError = NULL;
List *subqueryEntryList = NULL;
/* /*
* If not eligible for single relation repartition query, we should raise * If not eligible for single relation repartition query, we should raise
* subquery pushdown error. * subquery pushdown error.
*/ */
singleRelationRepartitionSubquery = bool singleRelationRepartitionSubquery =
SingleRelationRepartitionSubquery(originalQuery); SingleRelationRepartitionSubquery(originalQuery);
if (!singleRelationRepartitionSubquery) if (!singleRelationRepartitionSubquery)
{ {
RaiseDeferredErrorInternal(subqueryPushdownError, ERROR); RaiseDeferredErrorInternal(subqueryPushdownError, ERROR);
} }
subqueryEntryList = SubqueryEntryList(queryTree); List *subqueryEntryList = SubqueryEntryList(queryTree);
subqueryRangeTableEntry = (RangeTblEntry *) linitial(subqueryEntryList); RangeTblEntry *subqueryRangeTableEntry = (RangeTblEntry *) linitial(
subqueryEntryList);
Assert(subqueryRangeTableEntry->rtekind == RTE_SUBQUERY); Assert(subqueryRangeTableEntry->rtekind == RTE_SUBQUERY);
subqueryTree = subqueryRangeTableEntry->subquery; Query *subqueryTree = subqueryRangeTableEntry->subquery;
repartitionQueryError = DeferErrorIfUnsupportedSubqueryRepartition(subqueryTree); DeferredErrorMessage *repartitionQueryError =
DeferErrorIfUnsupportedSubqueryRepartition(subqueryTree);
if (repartitionQueryError) if (repartitionQueryError)
{ {
RaiseDeferredErrorInternal(repartitionQueryError, ERROR); RaiseDeferredErrorInternal(repartitionQueryError, ERROR);
@ -574,7 +566,6 @@ DeferErrorIfUnsupportedSubqueryPushdown(Query *originalQuery,
bool outerMostQueryHasLimit = false; bool outerMostQueryHasLimit = false;
ListCell *subqueryCell = NULL; ListCell *subqueryCell = NULL;
List *subqueryList = NIL; List *subqueryList = NIL;
DeferredErrorMessage *error = NULL;
if (originalQuery->limitCount != NULL) if (originalQuery->limitCount != NULL)
{ {
@ -610,7 +601,7 @@ DeferErrorIfUnsupportedSubqueryPushdown(Query *originalQuery,
} }
/* we shouldn't allow reference tables in the FROM clause when the query has sublinks */ /* we shouldn't allow reference tables in the FROM clause when the query has sublinks */
error = DeferErrorIfFromClauseRecurs(originalQuery); DeferredErrorMessage *error = DeferErrorIfFromClauseRecurs(originalQuery);
if (error) if (error)
{ {
return error; return error;
@ -666,14 +657,12 @@ DeferErrorIfUnsupportedSubqueryPushdown(Query *originalQuery,
static DeferredErrorMessage * static DeferredErrorMessage *
DeferErrorIfFromClauseRecurs(Query *queryTree) DeferErrorIfFromClauseRecurs(Query *queryTree)
{ {
RecurringTuplesType recurType = RECURRING_TUPLES_INVALID;
if (!queryTree->hasSubLinks) if (!queryTree->hasSubLinks)
{ {
return NULL; return NULL;
} }
recurType = FromClauseRecurringTupleType(queryTree); RecurringTuplesType recurType = FromClauseRecurringTupleType(queryTree);
if (recurType == RECURRING_TUPLES_REFERENCE_TABLE) if (recurType == RECURRING_TUPLES_REFERENCE_TABLE)
{ {
return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED, return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
@ -892,9 +881,9 @@ DeferErrorIfCannotPushdownSubquery(Query *subqueryTree, bool outerMostQueryHasLi
bool preconditionsSatisfied = true; bool preconditionsSatisfied = true;
char *errorDetail = NULL; char *errorDetail = NULL;
StringInfo errorInfo = NULL; StringInfo errorInfo = NULL;
DeferredErrorMessage *deferredError = NULL;
deferredError = DeferErrorIfUnsupportedTableCombination(subqueryTree); DeferredErrorMessage *deferredError = DeferErrorIfUnsupportedTableCombination(
subqueryTree);
if (deferredError) if (deferredError)
{ {
return deferredError; return deferredError;
@ -1187,9 +1176,8 @@ DeferErrorIfUnsupportedUnionQuery(Query *subqueryTree)
if (IsA(leftArg, RangeTblRef)) if (IsA(leftArg, RangeTblRef))
{ {
Query *leftArgSubquery = NULL;
leftArgRTI = ((RangeTblRef *) leftArg)->rtindex; leftArgRTI = ((RangeTblRef *) leftArg)->rtindex;
leftArgSubquery = rt_fetch(leftArgRTI, subqueryTree->rtable)->subquery; Query *leftArgSubquery = rt_fetch(leftArgRTI, subqueryTree->rtable)->subquery;
recurType = FromClauseRecurringTupleType(leftArgSubquery); recurType = FromClauseRecurringTupleType(leftArgSubquery);
if (recurType != RECURRING_TUPLES_INVALID) if (recurType != RECURRING_TUPLES_INVALID)
{ {
@ -1199,9 +1187,9 @@ DeferErrorIfUnsupportedUnionQuery(Query *subqueryTree)
if (IsA(rightArg, RangeTblRef)) if (IsA(rightArg, RangeTblRef))
{ {
Query *rightArgSubquery = NULL;
rightArgRTI = ((RangeTblRef *) rightArg)->rtindex; rightArgRTI = ((RangeTblRef *) rightArg)->rtindex;
rightArgSubquery = rt_fetch(rightArgRTI, subqueryTree->rtable)->subquery; Query *rightArgSubquery = rt_fetch(rightArgRTI,
subqueryTree->rtable)->subquery;
recurType = FromClauseRecurringTupleType(rightArgSubquery); recurType = FromClauseRecurringTupleType(rightArgSubquery);
if (recurType != RECURRING_TUPLES_INVALID) if (recurType != RECURRING_TUPLES_INVALID)
{ {
@ -1251,7 +1239,6 @@ DeferErrorIfUnsupportedUnionQuery(Query *subqueryTree)
static bool static bool
ExtractSetOperationStatmentWalker(Node *node, List **setOperationList) ExtractSetOperationStatmentWalker(Node *node, List **setOperationList)
{ {
bool walkerResult = false;
if (node == NULL) if (node == NULL)
{ {
return false; return false;
@ -1264,7 +1251,7 @@ ExtractSetOperationStatmentWalker(Node *node, List **setOperationList)
(*setOperationList) = lappend(*setOperationList, setOperation); (*setOperationList) = lappend(*setOperationList, setOperation);
} }
walkerResult = expression_tree_walker(node, ExtractSetOperationStatmentWalker, bool walkerResult = expression_tree_walker(node, ExtractSetOperationStatmentWalker,
setOperationList); setOperationList);
return walkerResult; return walkerResult;
@ -1522,21 +1509,11 @@ static MultiNode *
SubqueryPushdownMultiNodeTree(Query *queryTree) SubqueryPushdownMultiNodeTree(Query *queryTree)
{ {
List *targetEntryList = queryTree->targetList; List *targetEntryList = queryTree->targetList;
List *columnList = NIL;
List *flattenedExprList = NIL;
List *targetColumnList = NIL;
MultiCollect *subqueryCollectNode = CitusMakeNode(MultiCollect); MultiCollect *subqueryCollectNode = CitusMakeNode(MultiCollect);
MultiTable *subqueryNode = NULL;
MultiProject *projectNode = NULL;
MultiExtendedOp *extendedOpNode = NULL;
MultiNode *currentTopNode = NULL;
Query *pushedDownQuery = NULL;
List *subqueryTargetEntryList = NIL;
List *havingClauseColumnList = NIL;
DeferredErrorMessage *unsupportedQueryError = NULL;
/* verify we can perform distributed planning on this query */ /* verify we can perform distributed planning on this query */
unsupportedQueryError = DeferErrorIfQueryNotSupported(queryTree); DeferredErrorMessage *unsupportedQueryError = DeferErrorIfQueryNotSupported(
queryTree);
if (unsupportedQueryError != NULL) if (unsupportedQueryError != NULL)
{ {
RaiseDeferredError(unsupportedQueryError, ERROR); RaiseDeferredError(unsupportedQueryError, ERROR);
@ -1588,14 +1565,14 @@ SubqueryPushdownMultiNodeTree(Query *queryTree)
* columnList. Columns mentioned in multiProject node and multiExtendedOp * columnList. Columns mentioned in multiProject node and multiExtendedOp
* node are indexed with their respective position in columnList. * node are indexed with their respective position in columnList.
*/ */
targetColumnList = pull_var_clause_default((Node *) targetEntryList); List *targetColumnList = pull_var_clause_default((Node *) targetEntryList);
havingClauseColumnList = pull_var_clause_default(queryTree->havingQual); List *havingClauseColumnList = pull_var_clause_default(queryTree->havingQual);
columnList = list_concat(targetColumnList, havingClauseColumnList); List *columnList = list_concat(targetColumnList, havingClauseColumnList);
flattenedExprList = FlattenJoinVars(columnList, queryTree); List *flattenedExprList = FlattenJoinVars(columnList, queryTree);
/* create a target entry for each unique column */ /* create a target entry for each unique column */
subqueryTargetEntryList = CreateSubqueryTargetEntryList(flattenedExprList); List *subqueryTargetEntryList = CreateSubqueryTargetEntryList(flattenedExprList);
/* /*
* Update varno/varattno fields of columns in columnList to * Update varno/varattno fields of columns in columnList to
@ -1605,7 +1582,7 @@ SubqueryPushdownMultiNodeTree(Query *queryTree)
subqueryTargetEntryList); subqueryTargetEntryList);
/* new query only has target entries, join tree, and rtable*/ /* new query only has target entries, join tree, and rtable*/
pushedDownQuery = makeNode(Query); Query *pushedDownQuery = makeNode(Query);
pushedDownQuery->commandType = queryTree->commandType; pushedDownQuery->commandType = queryTree->commandType;
pushedDownQuery->targetList = subqueryTargetEntryList; pushedDownQuery->targetList = subqueryTargetEntryList;
pushedDownQuery->jointree = copyObject(queryTree->jointree); pushedDownQuery->jointree = copyObject(queryTree->jointree);
@ -1614,13 +1591,13 @@ SubqueryPushdownMultiNodeTree(Query *queryTree)
pushedDownQuery->querySource = queryTree->querySource; pushedDownQuery->querySource = queryTree->querySource;
pushedDownQuery->hasSubLinks = queryTree->hasSubLinks; pushedDownQuery->hasSubLinks = queryTree->hasSubLinks;
subqueryNode = MultiSubqueryPushdownTable(pushedDownQuery); MultiTable *subqueryNode = MultiSubqueryPushdownTable(pushedDownQuery);
SetChild((MultiUnaryNode *) subqueryCollectNode, (MultiNode *) subqueryNode); SetChild((MultiUnaryNode *) subqueryCollectNode, (MultiNode *) subqueryNode);
currentTopNode = (MultiNode *) subqueryCollectNode; MultiNode *currentTopNode = (MultiNode *) subqueryCollectNode;
/* build project node for the columns to project */ /* build project node for the columns to project */
projectNode = MultiProjectNode(targetEntryList); MultiProject *projectNode = MultiProjectNode(targetEntryList);
SetChild((MultiUnaryNode *) projectNode, currentTopNode); SetChild((MultiUnaryNode *) projectNode, currentTopNode);
currentTopNode = (MultiNode *) projectNode; currentTopNode = (MultiNode *) projectNode;
@ -1630,7 +1607,7 @@ SubqueryPushdownMultiNodeTree(Query *queryTree)
* distinguish between aggregates and expressions; and we address this later * distinguish between aggregates and expressions; and we address this later
* in the logical optimizer. * in the logical optimizer.
*/ */
extendedOpNode = MultiExtendedOpNode(queryTree); MultiExtendedOp *extendedOpNode = MultiExtendedOpNode(queryTree);
/* /*
* Postgres standard planner converts having qual node to a list of and * Postgres standard planner converts having qual node to a list of and
@ -1724,8 +1701,6 @@ FlattenJoinVarsMutator(Node *node, Query *queryTree)
RangeTblEntry *rte = rt_fetch(column->varno, queryTree->rtable); RangeTblEntry *rte = rt_fetch(column->varno, queryTree->rtable);
if (rte->rtekind == RTE_JOIN) if (rte->rtekind == RTE_JOIN)
{ {
Node *newColumn = NULL;
/* /*
* if join has an alias, it is copied over join RTE. We should * if join has an alias, it is copied over join RTE. We should
* reference this RTE. * reference this RTE.
@ -1737,7 +1712,7 @@ FlattenJoinVarsMutator(Node *node, Query *queryTree)
/* join RTE does not have and alias defined at this level, deeper look is needed */ /* join RTE does not have and alias defined at this level, deeper look is needed */
Assert(column->varattno > 0); Assert(column->varattno > 0);
newColumn = (Node *) list_nth(rte->joinaliasvars, column->varattno - 1); Node *newColumn = (Node *) list_nth(rte->joinaliasvars, column->varattno - 1);
Assert(newColumn != NULL); Assert(newColumn != NULL);
/* /*
@ -1894,7 +1869,6 @@ UpdateColumnToMatchingTargetEntry(Var *column, Node *flattenedExpr, List *target
static MultiTable * static MultiTable *
MultiSubqueryPushdownTable(Query *subquery) MultiSubqueryPushdownTable(Query *subquery)
{ {
MultiTable *subqueryTableNode = NULL;
StringInfo rteName = makeStringInfo(); StringInfo rteName = makeStringInfo();
List *columnNamesList = NIL; List *columnNamesList = NIL;
ListCell *targetEntryCell = NULL; ListCell *targetEntryCell = NULL;
@ -1907,7 +1881,7 @@ MultiSubqueryPushdownTable(Query *subquery)
columnNamesList = lappend(columnNamesList, makeString(targetEntry->resname)); columnNamesList = lappend(columnNamesList, makeString(targetEntry->resname));
} }
subqueryTableNode = CitusMakeNode(MultiTable); MultiTable *subqueryTableNode = CitusMakeNode(MultiTable);
subqueryTableNode->subquery = subquery; subqueryTableNode->subquery = subquery;
subqueryTableNode->relationId = SUBQUERY_PUSHDOWN_RELATION_ID; subqueryTableNode->relationId = SUBQUERY_PUSHDOWN_RELATION_ID;
subqueryTableNode->rangeTableId = SUBQUERY_RANGE_TABLE_ID; subqueryTableNode->rangeTableId = SUBQUERY_RANGE_TABLE_ID;

View File

@ -189,7 +189,6 @@ GenerateSubplansForSubqueriesAndCTEs(uint64 planId, Query *originalQuery,
PlannerRestrictionContext *plannerRestrictionContext) PlannerRestrictionContext *plannerRestrictionContext)
{ {
RecursivePlanningContext context; RecursivePlanningContext context;
DeferredErrorMessage *error = NULL;
recursivePlanningDepth++; recursivePlanningDepth++;
@ -217,7 +216,8 @@ GenerateSubplansForSubqueriesAndCTEs(uint64 planId, Query *originalQuery,
context.allDistributionKeysInQueryAreEqual = context.allDistributionKeysInQueryAreEqual =
AllDistributionKeysInQueryAreEqual(originalQuery, plannerRestrictionContext); AllDistributionKeysInQueryAreEqual(originalQuery, plannerRestrictionContext);
error = RecursivelyPlanSubqueriesAndCTEs(originalQuery, &context); DeferredErrorMessage *error = RecursivelyPlanSubqueriesAndCTEs(originalQuery,
&context);
if (error != NULL) if (error != NULL)
{ {
recursivePlanningDepth--; recursivePlanningDepth--;
@ -257,9 +257,7 @@ GenerateSubplansForSubqueriesAndCTEs(uint64 planId, Query *originalQuery,
static DeferredErrorMessage * static DeferredErrorMessage *
RecursivelyPlanSubqueriesAndCTEs(Query *query, RecursivePlanningContext *context) RecursivelyPlanSubqueriesAndCTEs(Query *query, RecursivePlanningContext *context)
{ {
DeferredErrorMessage *error = NULL; DeferredErrorMessage *error = RecursivelyPlanCTEs(query, context);
error = RecursivelyPlanCTEs(query, context);
if (error != NULL) if (error != NULL)
{ {
return error; return error;
@ -410,14 +408,12 @@ ContainsSubquery(Query *query)
static void static void
RecursivelyPlanNonColocatedSubqueries(Query *subquery, RecursivePlanningContext *context) RecursivelyPlanNonColocatedSubqueries(Query *subquery, RecursivePlanningContext *context)
{ {
ColocatedJoinChecker colocatedJoinChecker;
FromExpr *joinTree = subquery->jointree; FromExpr *joinTree = subquery->jointree;
PlannerRestrictionContext *restrictionContext = NULL;
/* create the context for the non colocated subquery planning */ /* create the context for the non colocated subquery planning */
restrictionContext = context->plannerRestrictionContext; PlannerRestrictionContext *restrictionContext = context->plannerRestrictionContext;
colocatedJoinChecker = CreateColocatedJoinChecker(subquery, restrictionContext); ColocatedJoinChecker colocatedJoinChecker = CreateColocatedJoinChecker(subquery,
restrictionContext);
/* /*
* Although this is a rare case, we weren't able to pick an anchor * Although this is a rare case, we weren't able to pick an anchor
@ -490,7 +486,6 @@ RecursivelyPlanNonColocatedJoinWalker(Node *joinNode,
int rangeTableIndex = ((RangeTblRef *) joinNode)->rtindex; int rangeTableIndex = ((RangeTblRef *) joinNode)->rtindex;
List *rangeTableList = colocatedJoinChecker->subquery->rtable; List *rangeTableList = colocatedJoinChecker->subquery->rtable;
RangeTblEntry *rte = rt_fetch(rangeTableIndex, rangeTableList); RangeTblEntry *rte = rt_fetch(rangeTableIndex, rangeTableList);
Query *subquery = NULL;
/* we're only interested in subqueries for now */ /* we're only interested in subqueries for now */
if (rte->rtekind != RTE_SUBQUERY) if (rte->rtekind != RTE_SUBQUERY)
@ -502,7 +497,7 @@ RecursivelyPlanNonColocatedJoinWalker(Node *joinNode,
* If the subquery is not colocated with the anchor subquery, * If the subquery is not colocated with the anchor subquery,
* recursively plan it. * recursively plan it.
*/ */
subquery = rte->subquery; Query *subquery = rte->subquery;
if (!SubqueryColocated(subquery, colocatedJoinChecker)) if (!SubqueryColocated(subquery, colocatedJoinChecker))
{ {
RecursivelyPlanSubquery(subquery, recursivePlanningContext); RecursivelyPlanSubquery(subquery, recursivePlanningContext);
@ -560,7 +555,6 @@ static List *
SublinkList(Query *originalQuery) SublinkList(Query *originalQuery)
{ {
FromExpr *joinTree = originalQuery->jointree; FromExpr *joinTree = originalQuery->jointree;
Node *queryQuals = NULL;
List *sublinkList = NIL; List *sublinkList = NIL;
if (!joinTree) if (!joinTree)
@ -568,7 +562,7 @@ SublinkList(Query *originalQuery)
return NIL; return NIL;
} }
queryQuals = joinTree->quals; Node *queryQuals = joinTree->quals;
ExtractSublinkWalker(queryQuals, &sublinkList); ExtractSublinkWalker(queryQuals, &sublinkList);
return sublinkList; return sublinkList;
@ -610,17 +604,14 @@ ExtractSublinkWalker(Node *node, List **sublinkList)
static bool static bool
ShouldRecursivelyPlanAllSubqueriesInWhere(Query *query) ShouldRecursivelyPlanAllSubqueriesInWhere(Query *query)
{ {
FromExpr *joinTree = NULL; FromExpr *joinTree = query->jointree;
Node *whereClause = NULL;
joinTree = query->jointree;
if (joinTree == NULL) if (joinTree == NULL)
{ {
/* there is no FROM clause */ /* there is no FROM clause */
return false; return false;
} }
whereClause = joinTree->quals; Node *whereClause = joinTree->quals;
if (whereClause == NULL) if (whereClause == NULL)
{ {
/* there is no WHERE clause */ /* there is no WHERE clause */
@ -703,11 +694,7 @@ RecursivelyPlanCTEs(Query *query, RecursivePlanningContext *planningContext)
char *cteName = cte->ctename; char *cteName = cte->ctename;
Query *subquery = (Query *) cte->ctequery; Query *subquery = (Query *) cte->ctequery;
uint64 planId = planningContext->planId; uint64 planId = planningContext->planId;
uint32 subPlanId = 0;
char *resultId = NULL;
List *cteTargetList = NIL; List *cteTargetList = NIL;
Query *resultQuery = NULL;
DistributedSubPlan *subPlan = NULL;
ListCell *rteCell = NULL; ListCell *rteCell = NULL;
int replacedCtesCount = 0; int replacedCtesCount = 0;
@ -729,7 +716,7 @@ RecursivelyPlanCTEs(Query *query, RecursivePlanningContext *planningContext)
continue; continue;
} }
subPlanId = list_length(planningContext->subPlanList) + 1; uint32 subPlanId = list_length(planningContext->subPlanList) + 1;
if (IsLoggableLevel(DEBUG1)) if (IsLoggableLevel(DEBUG1))
{ {
@ -742,11 +729,11 @@ RecursivelyPlanCTEs(Query *query, RecursivePlanningContext *planningContext)
} }
/* build a sub plan for the CTE */ /* build a sub plan for the CTE */
subPlan = CreateDistributedSubPlan(subPlanId, subquery); DistributedSubPlan *subPlan = CreateDistributedSubPlan(subPlanId, subquery);
planningContext->subPlanList = lappend(planningContext->subPlanList, subPlan); planningContext->subPlanList = lappend(planningContext->subPlanList, subPlan);
/* build the result_id parameter for the call to read_intermediate_result */ /* build the result_id parameter for the call to read_intermediate_result */
resultId = GenerateResultId(planId, subPlanId); char *resultId = GenerateResultId(planId, subPlanId);
if (subquery->returningList) if (subquery->returningList)
{ {
@ -760,7 +747,7 @@ RecursivelyPlanCTEs(Query *query, RecursivePlanningContext *planningContext)
} }
/* replace references to the CTE with a subquery that reads results */ /* replace references to the CTE with a subquery that reads results */
resultQuery = BuildSubPlanResultQuery(cteTargetList, cte->aliascolnames, Query *resultQuery = BuildSubPlanResultQuery(cteTargetList, cte->aliascolnames,
resultId); resultId);
foreach(rteCell, context.cteReferenceList) foreach(rteCell, context.cteReferenceList)
@ -832,7 +819,6 @@ RecursivelyPlanSubqueryWalker(Node *node, RecursivePlanningContext *context)
if (IsA(node, Query)) if (IsA(node, Query))
{ {
Query *query = (Query *) node; Query *query = (Query *) node;
DeferredErrorMessage *error = NULL;
context->level += 1; context->level += 1;
@ -840,7 +826,7 @@ RecursivelyPlanSubqueryWalker(Node *node, RecursivePlanningContext *context)
* First, make sure any subqueries and CTEs within this subquery * First, make sure any subqueries and CTEs within this subquery
* are recursively planned if necessary. * are recursively planned if necessary.
*/ */
error = RecursivelyPlanSubqueriesAndCTEs(query, context); DeferredErrorMessage *error = RecursivelyPlanSubqueriesAndCTEs(query, context);
if (error != NULL) if (error != NULL)
{ {
RaiseDeferredError(error, ERROR); RaiseDeferredError(error, ERROR);
@ -934,19 +920,16 @@ static bool
AllDistributionKeysInSubqueryAreEqual(Query *subquery, AllDistributionKeysInSubqueryAreEqual(Query *subquery,
PlannerRestrictionContext *restrictionContext) PlannerRestrictionContext *restrictionContext)
{ {
bool allDistributionKeysInSubqueryAreEqual = false;
PlannerRestrictionContext *filteredRestrictionContext = NULL;
/* we don't support distribution eq. checks for CTEs yet */ /* we don't support distribution eq. checks for CTEs yet */
if (subquery->cteList != NIL) if (subquery->cteList != NIL)
{ {
return false; return false;
} }
filteredRestrictionContext = PlannerRestrictionContext *filteredRestrictionContext =
FilterPlannerRestrictionForQuery(restrictionContext, subquery); FilterPlannerRestrictionForQuery(restrictionContext, subquery);
allDistributionKeysInSubqueryAreEqual = bool allDistributionKeysInSubqueryAreEqual =
AllDistributionKeysInQueryAreEqual(subquery, filteredRestrictionContext); AllDistributionKeysInQueryAreEqual(subquery, filteredRestrictionContext);
if (!allDistributionKeysInSubqueryAreEqual) if (!allDistributionKeysInSubqueryAreEqual)
{ {
@ -965,8 +948,6 @@ AllDistributionKeysInSubqueryAreEqual(Query *subquery,
static bool static bool
ShouldRecursivelyPlanSetOperation(Query *query, RecursivePlanningContext *context) ShouldRecursivelyPlanSetOperation(Query *query, RecursivePlanningContext *context)
{ {
PlannerRestrictionContext *filteredRestrictionContext = NULL;
SetOperationStmt *setOperations = (SetOperationStmt *) query->setOperations; SetOperationStmt *setOperations = (SetOperationStmt *) query->setOperations;
if (setOperations == NULL) if (setOperations == NULL)
{ {
@ -1000,7 +981,7 @@ ShouldRecursivelyPlanSetOperation(Query *query, RecursivePlanningContext *contex
return true; return true;
} }
filteredRestrictionContext = PlannerRestrictionContext *filteredRestrictionContext =
FilterPlannerRestrictionForQuery(context->plannerRestrictionContext, query); FilterPlannerRestrictionForQuery(context->plannerRestrictionContext, query);
if (!SafeToPushdownUnionSubquery(filteredRestrictionContext)) if (!SafeToPushdownUnionSubquery(filteredRestrictionContext))
{ {
@ -1062,9 +1043,6 @@ RecursivelyPlanSetOperations(Query *query, Node *node,
static bool static bool
IsLocalTableRTE(Node *node) IsLocalTableRTE(Node *node)
{ {
RangeTblEntry *rangeTableEntry = NULL;
Oid relationId = InvalidOid;
if (node == NULL) if (node == NULL)
{ {
return false; return false;
@ -1075,7 +1053,7 @@ IsLocalTableRTE(Node *node)
return false; return false;
} }
rangeTableEntry = (RangeTblEntry *) node; RangeTblEntry *rangeTableEntry = (RangeTblEntry *) node;
if (rangeTableEntry->rtekind != RTE_RELATION) if (rangeTableEntry->rtekind != RTE_RELATION)
{ {
return false; return false;
@ -1086,7 +1064,7 @@ IsLocalTableRTE(Node *node)
return false; return false;
} }
relationId = rangeTableEntry->relid; Oid relationId = rangeTableEntry->relid;
if (IsDistributedTable(relationId)) if (IsDistributedTable(relationId))
{ {
return false; return false;
@ -1111,11 +1089,7 @@ IsLocalTableRTE(Node *node)
static void static void
RecursivelyPlanSubquery(Query *subquery, RecursivePlanningContext *planningContext) RecursivelyPlanSubquery(Query *subquery, RecursivePlanningContext *planningContext)
{ {
DistributedSubPlan *subPlan = NULL;
uint64 planId = planningContext->planId; uint64 planId = planningContext->planId;
int subPlanId = 0;
char *resultId = NULL;
Query *resultQuery = NULL;
Query *debugQuery = NULL; Query *debugQuery = NULL;
if (ContainsReferencesToOuterQuery(subquery)) if (ContainsReferencesToOuterQuery(subquery))
@ -1138,19 +1112,19 @@ RecursivelyPlanSubquery(Query *subquery, RecursivePlanningContext *planningConte
/* /*
* Create the subplan and append it to the list in the planning context. * Create the subplan and append it to the list in the planning context.
*/ */
subPlanId = list_length(planningContext->subPlanList) + 1; int subPlanId = list_length(planningContext->subPlanList) + 1;
subPlan = CreateDistributedSubPlan(subPlanId, subquery); DistributedSubPlan *subPlan = CreateDistributedSubPlan(subPlanId, subquery);
planningContext->subPlanList = lappend(planningContext->subPlanList, subPlan); planningContext->subPlanList = lappend(planningContext->subPlanList, subPlan);
/* build the result_id parameter for the call to read_intermediate_result */ /* build the result_id parameter for the call to read_intermediate_result */
resultId = GenerateResultId(planId, subPlanId); char *resultId = GenerateResultId(planId, subPlanId);
/* /*
* BuildSubPlanResultQuery() can optionally use provided column aliases. * BuildSubPlanResultQuery() can optionally use provided column aliases.
* We do not need to send additional alias list for subqueries. * We do not need to send additional alias list for subqueries.
*/ */
resultQuery = BuildSubPlanResultQuery(subquery->targetList, NIL, resultId); Query *resultQuery = BuildSubPlanResultQuery(subquery->targetList, NIL, resultId);
if (IsLoggableLevel(DEBUG1)) if (IsLoggableLevel(DEBUG1))
{ {
@ -1176,7 +1150,6 @@ RecursivelyPlanSubquery(Query *subquery, RecursivePlanningContext *planningConte
static DistributedSubPlan * static DistributedSubPlan *
CreateDistributedSubPlan(uint32 subPlanId, Query *subPlanQuery) CreateDistributedSubPlan(uint32 subPlanId, Query *subPlanQuery)
{ {
DistributedSubPlan *subPlan = NULL;
int cursorOptions = 0; int cursorOptions = 0;
if (ContainsReadIntermediateResultFunction((Node *) subPlanQuery)) if (ContainsReadIntermediateResultFunction((Node *) subPlanQuery))
@ -1192,7 +1165,7 @@ CreateDistributedSubPlan(uint32 subPlanId, Query *subPlanQuery)
cursorOptions |= CURSOR_OPT_FORCE_DISTRIBUTED; cursorOptions |= CURSOR_OPT_FORCE_DISTRIBUTED;
} }
subPlan = CitusMakeNode(DistributedSubPlan); DistributedSubPlan *subPlan = CitusMakeNode(DistributedSubPlan);
subPlan->plan = planner(subPlanQuery, cursorOptions, NULL); subPlan->plan = planner(subPlanQuery, cursorOptions, NULL);
subPlan->subPlanId = subPlanId; subPlan->subPlanId = subPlanId;
@ -1310,11 +1283,10 @@ ContainsReferencesToOuterQueryWalker(Node *node, VarLevelsUpWalkerContext *conte
else if (IsA(node, Query)) else if (IsA(node, Query))
{ {
Query *query = (Query *) node; Query *query = (Query *) node;
bool found = false;
int flags = 0; int flags = 0;
context->level += 1; context->level += 1;
found = query_tree_walker(query, ContainsReferencesToOuterQueryWalker, bool found = query_tree_walker(query, ContainsReferencesToOuterQueryWalker,
context, flags); context, flags);
context->level -= 1; context->level -= 1;
@ -1383,19 +1355,16 @@ TransformFunctionRTE(RangeTblEntry *rangeTblEntry)
{ {
Query *subquery = makeNode(Query); Query *subquery = makeNode(Query);
RangeTblRef *newRangeTableRef = makeNode(RangeTblRef); RangeTblRef *newRangeTableRef = makeNode(RangeTblRef);
RangeTblEntry *newRangeTableEntry = NULL;
Var *targetColumn = NULL; Var *targetColumn = NULL;
TargetEntry *targetEntry = NULL; TargetEntry *targetEntry = NULL;
RangeTblFunction *rangeTblFunction = NULL;
AttrNumber targetColumnIndex = 0; AttrNumber targetColumnIndex = 0;
TupleDesc tupleDesc = NULL;
rangeTblFunction = linitial(rangeTblEntry->functions); RangeTblFunction *rangeTblFunction = linitial(rangeTblEntry->functions);
subquery->commandType = CMD_SELECT; subquery->commandType = CMD_SELECT;
/* copy the input rangeTblEntry to prevent cycles */ /* copy the input rangeTblEntry to prevent cycles */
newRangeTableEntry = copyObject(rangeTblEntry); RangeTblEntry *newRangeTableEntry = copyObject(rangeTblEntry);
/* set the FROM expression to the subquery */ /* set the FROM expression to the subquery */
subquery->rtable = list_make1(newRangeTableEntry); subquery->rtable = list_make1(newRangeTableEntry);
@ -1407,7 +1376,7 @@ TransformFunctionRTE(RangeTblEntry *rangeTblEntry)
* If function return type is not composite or rowtype can't be determined, * If function return type is not composite or rowtype can't be determined,
* tupleDesc is set to null here * tupleDesc is set to null here
*/ */
tupleDesc = (TupleDesc) get_expr_result_tupdesc(rangeTblFunction->funcexpr, TupleDesc tupleDesc = (TupleDesc) get_expr_result_tupdesc(rangeTblFunction->funcexpr,
true); true);
/* /*
@ -1460,10 +1429,9 @@ TransformFunctionRTE(RangeTblEntry *rangeTblEntry)
else else
{ {
/* create target entries for all columns returned by the function */ /* create target entries for all columns returned by the function */
List *functionColumnNames = NULL;
ListCell *functionColumnName = NULL; ListCell *functionColumnName = NULL;
functionColumnNames = rangeTblEntry->eref->colnames; List *functionColumnNames = rangeTblEntry->eref->colnames;
foreach(functionColumnName, functionColumnNames) foreach(functionColumnName, functionColumnNames)
{ {
char *columnName = strVal(lfirst(functionColumnName)); char *columnName = strVal(lfirst(functionColumnName));
@ -1574,19 +1542,10 @@ ShouldTransformRTE(RangeTblEntry *rangeTableEntry)
Query * Query *
BuildSubPlanResultQuery(List *targetEntryList, List *columnAliasList, char *resultId) BuildSubPlanResultQuery(List *targetEntryList, List *columnAliasList, char *resultId)
{ {
Query *resultQuery = NULL;
Const *resultIdConst = NULL;
Const *resultFormatConst = NULL;
FuncExpr *funcExpr = NULL;
Alias *funcAlias = NULL;
List *funcColNames = NIL; List *funcColNames = NIL;
List *funcColTypes = NIL; List *funcColTypes = NIL;
List *funcColTypMods = NIL; List *funcColTypMods = NIL;
List *funcColCollations = NIL; List *funcColCollations = NIL;
RangeTblFunction *rangeTableFunction = NULL;
RangeTblEntry *rangeTableEntry = NULL;
RangeTblRef *rangeTableRef = NULL;
FromExpr *joinTree = NULL;
ListCell *targetEntryCell = NULL; ListCell *targetEntryCell = NULL;
List *targetList = NIL; List *targetList = NIL;
int columnNumber = 1; int columnNumber = 1;
@ -1603,8 +1562,6 @@ BuildSubPlanResultQuery(List *targetEntryList, List *columnAliasList, char *resu
Oid columnType = exprType(targetExpr); Oid columnType = exprType(targetExpr);
Oid columnTypMod = exprTypmod(targetExpr); Oid columnTypMod = exprTypmod(targetExpr);
Oid columnCollation = exprCollation(targetExpr); Oid columnCollation = exprCollation(targetExpr);
Var *functionColumnVar = NULL;
TargetEntry *newTargetEntry = NULL;
if (targetEntry->resjunk) if (targetEntry->resjunk)
{ {
@ -1616,7 +1573,7 @@ BuildSubPlanResultQuery(List *targetEntryList, List *columnAliasList, char *resu
funcColTypMods = lappend_int(funcColTypMods, columnTypMod); funcColTypMods = lappend_int(funcColTypMods, columnTypMod);
funcColCollations = lappend_int(funcColCollations, columnCollation); funcColCollations = lappend_int(funcColCollations, columnCollation);
functionColumnVar = makeNode(Var); Var *functionColumnVar = makeNode(Var);
functionColumnVar->varno = 1; functionColumnVar->varno = 1;
functionColumnVar->varattno = columnNumber; functionColumnVar->varattno = columnNumber;
functionColumnVar->vartype = columnType; functionColumnVar->vartype = columnType;
@ -1627,7 +1584,7 @@ BuildSubPlanResultQuery(List *targetEntryList, List *columnAliasList, char *resu
functionColumnVar->varoattno = columnNumber; functionColumnVar->varoattno = columnNumber;
functionColumnVar->location = -1; functionColumnVar->location = -1;
newTargetEntry = makeNode(TargetEntry); TargetEntry *newTargetEntry = makeNode(TargetEntry);
newTargetEntry->expr = (Expr *) functionColumnVar; newTargetEntry->expr = (Expr *) functionColumnVar;
newTargetEntry->resno = columnNumber; newTargetEntry->resno = columnNumber;
@ -1659,7 +1616,7 @@ BuildSubPlanResultQuery(List *targetEntryList, List *columnAliasList, char *resu
columnNumber++; columnNumber++;
} }
resultIdConst = makeNode(Const); Const *resultIdConst = makeNode(Const);
resultIdConst->consttype = TEXTOID; resultIdConst->consttype = TEXTOID;
resultIdConst->consttypmod = -1; resultIdConst->consttypmod = -1;
resultIdConst->constlen = -1; resultIdConst->constlen = -1;
@ -1674,7 +1631,7 @@ BuildSubPlanResultQuery(List *targetEntryList, List *columnAliasList, char *resu
copyFormatId = TextCopyFormatId(); copyFormatId = TextCopyFormatId();
} }
resultFormatConst = makeNode(Const); Const *resultFormatConst = makeNode(Const);
resultFormatConst->consttype = CitusCopyFormatTypeId(); resultFormatConst->consttype = CitusCopyFormatTypeId();
resultFormatConst->consttypmod = -1; resultFormatConst->consttypmod = -1;
resultFormatConst->constlen = 4; resultFormatConst->constlen = 4;
@ -1684,7 +1641,7 @@ BuildSubPlanResultQuery(List *targetEntryList, List *columnAliasList, char *resu
resultFormatConst->location = -1; resultFormatConst->location = -1;
/* build the call to read_intermediate_result */ /* build the call to read_intermediate_result */
funcExpr = makeNode(FuncExpr); FuncExpr *funcExpr = makeNode(FuncExpr);
funcExpr->funcid = CitusReadIntermediateResultFuncId(); funcExpr->funcid = CitusReadIntermediateResultFuncId();
funcExpr->funcretset = true; funcExpr->funcretset = true;
funcExpr->funcvariadic = false; funcExpr->funcvariadic = false;
@ -1695,7 +1652,7 @@ BuildSubPlanResultQuery(List *targetEntryList, List *columnAliasList, char *resu
funcExpr->args = list_make2(resultIdConst, resultFormatConst); funcExpr->args = list_make2(resultIdConst, resultFormatConst);
/* build the RTE for the call to read_intermediate_result */ /* build the RTE for the call to read_intermediate_result */
rangeTableFunction = makeNode(RangeTblFunction); RangeTblFunction *rangeTableFunction = makeNode(RangeTblFunction);
rangeTableFunction->funccolcount = list_length(funcColNames); rangeTableFunction->funccolcount = list_length(funcColNames);
rangeTableFunction->funccolnames = funcColNames; rangeTableFunction->funccolnames = funcColNames;
rangeTableFunction->funccoltypes = funcColTypes; rangeTableFunction->funccoltypes = funcColTypes;
@ -1704,25 +1661,25 @@ BuildSubPlanResultQuery(List *targetEntryList, List *columnAliasList, char *resu
rangeTableFunction->funcparams = NULL; rangeTableFunction->funcparams = NULL;
rangeTableFunction->funcexpr = (Node *) funcExpr; rangeTableFunction->funcexpr = (Node *) funcExpr;
funcAlias = makeNode(Alias); Alias *funcAlias = makeNode(Alias);
funcAlias->aliasname = "intermediate_result"; funcAlias->aliasname = "intermediate_result";
funcAlias->colnames = funcColNames; funcAlias->colnames = funcColNames;
rangeTableEntry = makeNode(RangeTblEntry); RangeTblEntry *rangeTableEntry = makeNode(RangeTblEntry);
rangeTableEntry->rtekind = RTE_FUNCTION; rangeTableEntry->rtekind = RTE_FUNCTION;
rangeTableEntry->functions = list_make1(rangeTableFunction); rangeTableEntry->functions = list_make1(rangeTableFunction);
rangeTableEntry->inFromCl = true; rangeTableEntry->inFromCl = true;
rangeTableEntry->eref = funcAlias; rangeTableEntry->eref = funcAlias;
/* build the join tree using the read_intermediate_result RTE */ /* build the join tree using the read_intermediate_result RTE */
rangeTableRef = makeNode(RangeTblRef); RangeTblRef *rangeTableRef = makeNode(RangeTblRef);
rangeTableRef->rtindex = 1; rangeTableRef->rtindex = 1;
joinTree = makeNode(FromExpr); FromExpr *joinTree = makeNode(FromExpr);
joinTree->fromlist = list_make1(rangeTableRef); joinTree->fromlist = list_make1(rangeTableRef);
/* build the SELECT query */ /* build the SELECT query */
resultQuery = makeNode(Query); Query *resultQuery = makeNode(Query);
resultQuery->commandType = CMD_SELECT; resultQuery->commandType = CMD_SELECT;
resultQuery->rtable = list_make1(rangeTableEntry); resultQuery->rtable = list_make1(rangeTableEntry);
resultQuery->jointree = joinTree; resultQuery->jointree = joinTree;

View File

@ -160,9 +160,6 @@ bool
AllDistributionKeysInQueryAreEqual(Query *originalQuery, AllDistributionKeysInQueryAreEqual(Query *originalQuery,
PlannerRestrictionContext *plannerRestrictionContext) PlannerRestrictionContext *plannerRestrictionContext)
{ {
bool restrictionEquivalenceForPartitionKeys = false;
RelationRestrictionContext *restrictionContext = NULL;
/* we don't support distribution key equality checks for CTEs yet */ /* we don't support distribution key equality checks for CTEs yet */
if (originalQuery->cteList != NIL) if (originalQuery->cteList != NIL)
{ {
@ -170,13 +167,14 @@ AllDistributionKeysInQueryAreEqual(Query *originalQuery,
} }
/* we don't support distribution key equality checks for local tables */ /* we don't support distribution key equality checks for local tables */
restrictionContext = plannerRestrictionContext->relationRestrictionContext; RelationRestrictionContext *restrictionContext =
plannerRestrictionContext->relationRestrictionContext;
if (ContextContainsLocalRelation(restrictionContext)) if (ContextContainsLocalRelation(restrictionContext))
{ {
return false; return false;
} }
restrictionEquivalenceForPartitionKeys = bool restrictionEquivalenceForPartitionKeys =
RestrictionEquivalenceForPartitionKeys(plannerRestrictionContext); RestrictionEquivalenceForPartitionKeys(plannerRestrictionContext);
if (restrictionEquivalenceForPartitionKeys) if (restrictionEquivalenceForPartitionKeys)
{ {
@ -245,9 +243,6 @@ SafeToPushdownUnionSubquery(PlannerRestrictionContext *plannerRestrictionContext
AttributeEquivalenceClass *attributeEquivalance = AttributeEquivalenceClass *attributeEquivalance =
palloc0(sizeof(AttributeEquivalenceClass)); palloc0(sizeof(AttributeEquivalenceClass));
ListCell *relationRestrictionCell = NULL; ListCell *relationRestrictionCell = NULL;
List *relationRestrictionAttributeEquivalenceList = NIL;
List *joinRestrictionAttributeEquivalenceList = NIL;
List *allAttributeEquivalenceList = NIL;
attributeEquivalance->equivalenceId = attributeEquivalenceId++; attributeEquivalance->equivalenceId = attributeEquivalenceId++;
@ -338,12 +333,12 @@ SafeToPushdownUnionSubquery(PlannerRestrictionContext *plannerRestrictionContext
* we determine whether all relations are joined on the partition column * we determine whether all relations are joined on the partition column
* by adding the equivalence classes that can be inferred from joins. * by adding the equivalence classes that can be inferred from joins.
*/ */
relationRestrictionAttributeEquivalenceList = List *relationRestrictionAttributeEquivalenceList =
GenerateAttributeEquivalencesForRelationRestrictions(restrictionContext); GenerateAttributeEquivalencesForRelationRestrictions(restrictionContext);
joinRestrictionAttributeEquivalenceList = List *joinRestrictionAttributeEquivalenceList =
GenerateAttributeEquivalencesForJoinRestrictions(joinRestrictionContext); GenerateAttributeEquivalencesForJoinRestrictions(joinRestrictionContext);
allAttributeEquivalenceList = List *allAttributeEquivalenceList =
list_concat(relationRestrictionAttributeEquivalenceList, list_concat(relationRestrictionAttributeEquivalenceList,
joinRestrictionAttributeEquivalenceList); joinRestrictionAttributeEquivalenceList);
@ -373,8 +368,6 @@ FindTranslatedVar(List *appendRelList, Oid relationOid, Index relationRteIndex,
AppendRelInfo *targetAppendRelInfo = NULL; AppendRelInfo *targetAppendRelInfo = NULL;
ListCell *translatedVarCell = NULL; ListCell *translatedVarCell = NULL;
AttrNumber childAttrNumber = 0; AttrNumber childAttrNumber = 0;
Var *relationPartitionKey = NULL;
List *translaterVars = NULL;
*partitionKeyIndex = 0; *partitionKeyIndex = 0;
@ -400,13 +393,12 @@ FindTranslatedVar(List *appendRelList, Oid relationOid, Index relationRteIndex,
return NULL; return NULL;
} }
relationPartitionKey = DistPartitionKey(relationOid); Var *relationPartitionKey = DistPartitionKey(relationOid);
translaterVars = targetAppendRelInfo->translated_vars; List *translaterVars = targetAppendRelInfo->translated_vars;
foreach(translatedVarCell, translaterVars) foreach(translatedVarCell, translaterVars)
{ {
Node *targetNode = (Node *) lfirst(translatedVarCell); Node *targetNode = (Node *) lfirst(translatedVarCell);
Var *targetVar = NULL;
childAttrNumber++; childAttrNumber++;
@ -415,7 +407,7 @@ FindTranslatedVar(List *appendRelList, Oid relationOid, Index relationRteIndex,
continue; continue;
} }
targetVar = (Var *) lfirst(translatedVarCell); Var *targetVar = (Var *) lfirst(translatedVarCell);
if (targetVar->varno == relationRteIndex && if (targetVar->varno == relationRteIndex &&
targetVar->varattno == relationPartitionKey->varattno) targetVar->varattno == relationPartitionKey->varattno)
{ {
@ -464,15 +456,13 @@ FindTranslatedVar(List *appendRelList, Oid relationOid, Index relationRteIndex,
bool bool
RestrictionEquivalenceForPartitionKeys(PlannerRestrictionContext *restrictionContext) RestrictionEquivalenceForPartitionKeys(PlannerRestrictionContext *restrictionContext)
{ {
List *attributeEquivalenceList = NIL;
/* there is a single distributed relation, no need to continue */ /* there is a single distributed relation, no need to continue */
if (!ContainsMultipleDistributedRelations(restrictionContext)) if (!ContainsMultipleDistributedRelations(restrictionContext))
{ {
return true; return true;
} }
attributeEquivalenceList = GenerateAllAttributeEquivalences(restrictionContext); List *attributeEquivalenceList = GenerateAllAttributeEquivalences(restrictionContext);
return RestrictionEquivalenceForPartitionKeysViaEquivalances(restrictionContext, return RestrictionEquivalenceForPartitionKeysViaEquivalances(restrictionContext,
attributeEquivalenceList); attributeEquivalenceList);
@ -554,19 +544,17 @@ GenerateAllAttributeEquivalences(PlannerRestrictionContext *plannerRestrictionCo
JoinRestrictionContext *joinRestrictionContext = JoinRestrictionContext *joinRestrictionContext =
plannerRestrictionContext->joinRestrictionContext; plannerRestrictionContext->joinRestrictionContext;
List *relationRestrictionAttributeEquivalenceList = NIL;
List *joinRestrictionAttributeEquivalenceList = NIL;
List *allAttributeEquivalenceList = NIL;
/* reset the equivalence id counter per call to prevent overflows */ /* reset the equivalence id counter per call to prevent overflows */
attributeEquivalenceId = 1; attributeEquivalenceId = 1;
relationRestrictionAttributeEquivalenceList = List *relationRestrictionAttributeEquivalenceList =
GenerateAttributeEquivalencesForRelationRestrictions(relationRestrictionContext); GenerateAttributeEquivalencesForRelationRestrictions(relationRestrictionContext);
joinRestrictionAttributeEquivalenceList = List *joinRestrictionAttributeEquivalenceList =
GenerateAttributeEquivalencesForJoinRestrictions(joinRestrictionContext); GenerateAttributeEquivalencesForJoinRestrictions(joinRestrictionContext);
allAttributeEquivalenceList = list_concat(relationRestrictionAttributeEquivalenceList, List *allAttributeEquivalenceList = list_concat(
relationRestrictionAttributeEquivalenceList,
joinRestrictionAttributeEquivalenceList); joinRestrictionAttributeEquivalenceList);
return allAttributeEquivalenceList; return allAttributeEquivalenceList;
@ -609,7 +597,6 @@ bool
EquivalenceListContainsRelationsEquality(List *attributeEquivalenceList, EquivalenceListContainsRelationsEquality(List *attributeEquivalenceList,
RelationRestrictionContext *restrictionContext) RelationRestrictionContext *restrictionContext)
{ {
AttributeEquivalenceClass *commonEquivalenceClass = NULL;
ListCell *commonEqClassCell = NULL; ListCell *commonEqClassCell = NULL;
ListCell *relationRestrictionCell = NULL; ListCell *relationRestrictionCell = NULL;
Relids commonRteIdentities = NULL; Relids commonRteIdentities = NULL;
@ -619,7 +606,8 @@ EquivalenceListContainsRelationsEquality(List *attributeEquivalenceList,
* common equivalence class. The main goal is to test whether this main class * common equivalence class. The main goal is to test whether this main class
* contains all partition keys of the existing relations. * contains all partition keys of the existing relations.
*/ */
commonEquivalenceClass = GenerateCommonEquivalence(attributeEquivalenceList, AttributeEquivalenceClass *commonEquivalenceClass = GenerateCommonEquivalence(
attributeEquivalenceList,
restrictionContext); restrictionContext);
/* add the rte indexes of relations to a bitmap */ /* add the rte indexes of relations to a bitmap */
@ -885,13 +873,12 @@ static AttributeEquivalenceClass *
GenerateCommonEquivalence(List *attributeEquivalenceList, GenerateCommonEquivalence(List *attributeEquivalenceList,
RelationRestrictionContext *relationRestrictionContext) RelationRestrictionContext *relationRestrictionContext)
{ {
AttributeEquivalenceClass *commonEquivalenceClass = NULL;
AttributeEquivalenceClass *firstEquivalenceClass = NULL;
Bitmapset *addedEquivalenceIds = NULL; Bitmapset *addedEquivalenceIds = NULL;
uint32 equivalenceListSize = list_length(attributeEquivalenceList); uint32 equivalenceListSize = list_length(attributeEquivalenceList);
uint32 equivalenceClassIndex = 0; uint32 equivalenceClassIndex = 0;
commonEquivalenceClass = palloc0(sizeof(AttributeEquivalenceClass)); AttributeEquivalenceClass *commonEquivalenceClass = palloc0(
sizeof(AttributeEquivalenceClass));
commonEquivalenceClass->equivalenceId = 0; commonEquivalenceClass->equivalenceId = 0;
/* /*
@ -899,7 +886,7 @@ GenerateCommonEquivalence(List *attributeEquivalenceList,
* table since we always want the input distributed relations to be * table since we always want the input distributed relations to be
* on the common class. * on the common class.
*/ */
firstEquivalenceClass = AttributeEquivalenceClass *firstEquivalenceClass =
GenerateEquivalanceClassForRelationRestriction(relationRestrictionContext); GenerateEquivalanceClassForRelationRestriction(relationRestrictionContext);
/* we skip the calculation if there are not enough information */ /* we skip the calculation if there are not enough information */
@ -915,11 +902,11 @@ GenerateCommonEquivalence(List *attributeEquivalenceList,
while (equivalenceClassIndex < equivalenceListSize) while (equivalenceClassIndex < equivalenceListSize)
{ {
AttributeEquivalenceClass *currentEquivalenceClass = NULL;
ListCell *equivalenceMemberCell = NULL; ListCell *equivalenceMemberCell = NULL;
bool restartLoop = false; bool restartLoop = false;
currentEquivalenceClass = list_nth(attributeEquivalenceList, AttributeEquivalenceClass *currentEquivalenceClass = list_nth(
attributeEquivalenceList,
equivalenceClassIndex); equivalenceClassIndex);
/* /*
@ -1077,22 +1064,14 @@ GenerateAttributeEquivalencesForJoinRestrictions(JoinRestrictionContext *
foreach(restrictionInfoList, joinRestriction->joinRestrictInfoList) foreach(restrictionInfoList, joinRestriction->joinRestrictInfoList)
{ {
RestrictInfo *rinfo = (RestrictInfo *) lfirst(restrictionInfoList); RestrictInfo *rinfo = (RestrictInfo *) lfirst(restrictionInfoList);
OpExpr *restrictionOpExpr = NULL;
Node *leftNode = NULL;
Node *rightNode = NULL;
Expr *strippedLeftExpr = NULL;
Expr *strippedRightExpr = NULL;
Var *leftVar = NULL;
Var *rightVar = NULL;
Expr *restrictionClause = rinfo->clause; Expr *restrictionClause = rinfo->clause;
AttributeEquivalenceClass *attributeEquivalance = NULL;
if (!IsA(restrictionClause, OpExpr)) if (!IsA(restrictionClause, OpExpr))
{ {
continue; continue;
} }
restrictionOpExpr = (OpExpr *) restrictionClause; OpExpr *restrictionOpExpr = (OpExpr *) restrictionClause;
if (list_length(restrictionOpExpr->args) != 2) if (list_length(restrictionOpExpr->args) != 2)
{ {
continue; continue;
@ -1102,22 +1081,24 @@ GenerateAttributeEquivalencesForJoinRestrictions(JoinRestrictionContext *
continue; continue;
} }
leftNode = linitial(restrictionOpExpr->args); Node *leftNode = linitial(restrictionOpExpr->args);
rightNode = lsecond(restrictionOpExpr->args); Node *rightNode = lsecond(restrictionOpExpr->args);
/* we also don't want implicit coercions */ /* we also don't want implicit coercions */
strippedLeftExpr = (Expr *) strip_implicit_coercions((Node *) leftNode); Expr *strippedLeftExpr = (Expr *) strip_implicit_coercions((Node *) leftNode);
strippedRightExpr = (Expr *) strip_implicit_coercions((Node *) rightNode); Expr *strippedRightExpr = (Expr *) strip_implicit_coercions(
(Node *) rightNode);
if (!(IsA(strippedLeftExpr, Var) && IsA(strippedRightExpr, Var))) if (!(IsA(strippedLeftExpr, Var) && IsA(strippedRightExpr, Var)))
{ {
continue; continue;
} }
leftVar = (Var *) strippedLeftExpr; Var *leftVar = (Var *) strippedLeftExpr;
rightVar = (Var *) strippedRightExpr; Var *rightVar = (Var *) strippedRightExpr;
attributeEquivalance = palloc0(sizeof(AttributeEquivalenceClass)); AttributeEquivalenceClass *attributeEquivalance = palloc0(
sizeof(AttributeEquivalenceClass));
attributeEquivalance->equivalenceId = attributeEquivalenceId++; attributeEquivalance->equivalenceId = attributeEquivalenceId++;
AddToAttributeEquivalenceClass(&attributeEquivalance, AddToAttributeEquivalenceClass(&attributeEquivalance,
@ -1167,8 +1148,6 @@ static void
AddToAttributeEquivalenceClass(AttributeEquivalenceClass **attributeEquivalanceClass, AddToAttributeEquivalenceClass(AttributeEquivalenceClass **attributeEquivalanceClass,
PlannerInfo *root, Var *varToBeAdded) PlannerInfo *root, Var *varToBeAdded)
{ {
RangeTblEntry *rangeTableEntry = NULL;
/* punt if it's a whole-row var rather than a plain column reference */ /* punt if it's a whole-row var rather than a plain column reference */
if (varToBeAdded->varattno == InvalidAttrNumber) if (varToBeAdded->varattno == InvalidAttrNumber)
{ {
@ -1181,7 +1160,7 @@ AddToAttributeEquivalenceClass(AttributeEquivalenceClass **attributeEquivalanceC
return; return;
} }
rangeTableEntry = root->simple_rte_array[varToBeAdded->varno]; RangeTblEntry *rangeTableEntry = root->simple_rte_array[varToBeAdded->varno];
if (rangeTableEntry->rtekind == RTE_RELATION) if (rangeTableEntry->rtekind == RTE_RELATION)
{ {
AddRteRelationToAttributeEquivalenceClass(attributeEquivalanceClass, AddRteRelationToAttributeEquivalenceClass(attributeEquivalanceClass,
@ -1210,7 +1189,6 @@ AddRteSubqueryToAttributeEquivalenceClass(AttributeEquivalenceClass
PlannerInfo *root, Var *varToBeAdded) PlannerInfo *root, Var *varToBeAdded)
{ {
RelOptInfo *baseRelOptInfo = find_base_rel(root, varToBeAdded->varno); RelOptInfo *baseRelOptInfo = find_base_rel(root, varToBeAdded->varno);
TargetEntry *subqueryTargetEntry = NULL;
Query *targetSubquery = GetTargetSubquery(root, rangeTableEntry, varToBeAdded); Query *targetSubquery = GetTargetSubquery(root, rangeTableEntry, varToBeAdded);
/* /*
@ -1229,7 +1207,7 @@ AddRteSubqueryToAttributeEquivalenceClass(AttributeEquivalenceClass
return; return;
} }
subqueryTargetEntry = get_tle_by_resno(targetSubquery->targetList, TargetEntry *subqueryTargetEntry = get_tle_by_resno(targetSubquery->targetList,
varToBeAdded->varattno); varToBeAdded->varattno);
/* if we fail to find corresponding target entry, do not proceed */ /* if we fail to find corresponding target entry, do not proceed */
@ -1402,9 +1380,7 @@ AddRteRelationToAttributeEquivalenceClass(AttributeEquivalenceClass **
RangeTblEntry *rangeTableEntry, RangeTblEntry *rangeTableEntry,
Var *varToBeAdded) Var *varToBeAdded)
{ {
AttributeEquivalenceClassMember *attributeEqMember = NULL;
Oid relationId = rangeTableEntry->relid; Oid relationId = rangeTableEntry->relid;
Var *relationPartitionKey = NULL;
/* we don't consider local tables in the equality on columns */ /* we don't consider local tables in the equality on columns */
if (!IsDistributedTable(relationId)) if (!IsDistributedTable(relationId))
@ -1412,7 +1388,7 @@ AddRteRelationToAttributeEquivalenceClass(AttributeEquivalenceClass **
return; return;
} }
relationPartitionKey = DistPartitionKey(relationId); Var *relationPartitionKey = DistPartitionKey(relationId);
Assert(rangeTableEntry->rtekind == RTE_RELATION); Assert(rangeTableEntry->rtekind == RTE_RELATION);
@ -1428,7 +1404,8 @@ AddRteRelationToAttributeEquivalenceClass(AttributeEquivalenceClass **
return; return;
} }
attributeEqMember = palloc0(sizeof(AttributeEquivalenceClassMember)); AttributeEquivalenceClassMember *attributeEqMember = palloc0(
sizeof(AttributeEquivalenceClassMember));
attributeEqMember->varattno = varToBeAdded->varattno; attributeEqMember->varattno = varToBeAdded->varattno;
attributeEqMember->varno = varToBeAdded->varno; attributeEqMember->varno = varToBeAdded->varno;
@ -1481,7 +1458,6 @@ static List *
AddAttributeClassToAttributeClassList(List *attributeEquivalenceList, AddAttributeClassToAttributeClassList(List *attributeEquivalenceList,
AttributeEquivalenceClass *attributeEquivalance) AttributeEquivalenceClass *attributeEquivalance)
{ {
List *equivalentAttributes = NULL;
ListCell *attributeEquivalanceCell = NULL; ListCell *attributeEquivalanceCell = NULL;
if (attributeEquivalance == NULL) if (attributeEquivalance == NULL)
@ -1493,7 +1469,7 @@ AddAttributeClassToAttributeClassList(List *attributeEquivalenceList,
* Note that in some cases we allow having equivalentAttributes with zero or * Note that in some cases we allow having equivalentAttributes with zero or
* one elements. For the details, see AddToAttributeEquivalenceClass(). * one elements. For the details, see AddToAttributeEquivalenceClass().
*/ */
equivalentAttributes = attributeEquivalance->equivalentAttributes; List *equivalentAttributes = attributeEquivalance->equivalentAttributes;
if (list_length(equivalentAttributes) < 2) if (list_length(equivalentAttributes) < 2)
{ {
return attributeEquivalenceList; return attributeEquivalenceList;
@ -1589,15 +1565,10 @@ bool
ContainsUnionSubquery(Query *queryTree) ContainsUnionSubquery(Query *queryTree)
{ {
List *rangeTableList = queryTree->rtable; List *rangeTableList = queryTree->rtable;
Node *setOperations = queryTree->setOperations;
List *joinTreeTableIndexList = NIL; List *joinTreeTableIndexList = NIL;
Index subqueryRteIndex = 0;
uint32 joiningRangeTableCount = 0;
RangeTblEntry *rangeTableEntry = NULL;
Query *subqueryTree = NULL;
ExtractRangeTableIndexWalker((Node *) queryTree->jointree, &joinTreeTableIndexList); ExtractRangeTableIndexWalker((Node *) queryTree->jointree, &joinTreeTableIndexList);
joiningRangeTableCount = list_length(joinTreeTableIndexList); uint32 joiningRangeTableCount = list_length(joinTreeTableIndexList);
/* don't allow joins on top of unions */ /* don't allow joins on top of unions */
if (joiningRangeTableCount > 1) if (joiningRangeTableCount > 1)
@ -1611,15 +1582,15 @@ ContainsUnionSubquery(Query *queryTree)
return false; return false;
} }
subqueryRteIndex = linitial_int(joinTreeTableIndexList); Index subqueryRteIndex = linitial_int(joinTreeTableIndexList);
rangeTableEntry = rt_fetch(subqueryRteIndex, rangeTableList); RangeTblEntry *rangeTableEntry = rt_fetch(subqueryRteIndex, rangeTableList);
if (rangeTableEntry->rtekind != RTE_SUBQUERY) if (rangeTableEntry->rtekind != RTE_SUBQUERY)
{ {
return false; return false;
} }
subqueryTree = rangeTableEntry->subquery; Query *subqueryTree = rangeTableEntry->subquery;
setOperations = subqueryTree->setOperations; Node *setOperations = subqueryTree->setOperations;
if (setOperations != NULL) if (setOperations != NULL)
{ {
SetOperationStmt *setOperationStatement = (SetOperationStmt *) setOperations; SetOperationStmt *setOperationStatement = (SetOperationStmt *) setOperations;
@ -1648,15 +1619,12 @@ ContainsUnionSubquery(Query *queryTree)
static Index static Index
RelationRestrictionPartitionKeyIndex(RelationRestriction *relationRestriction) RelationRestrictionPartitionKeyIndex(RelationRestriction *relationRestriction)
{ {
PlannerInfo *relationPlannerRoot = NULL;
Query *relationPlannerParseQuery = NULL;
List *relationTargetList = NIL;
ListCell *targetEntryCell = NULL; ListCell *targetEntryCell = NULL;
Index partitionKeyTargetAttrIndex = 0; Index partitionKeyTargetAttrIndex = 0;
relationPlannerRoot = relationRestriction->plannerInfo; PlannerInfo *relationPlannerRoot = relationRestriction->plannerInfo;
relationPlannerParseQuery = relationPlannerRoot->parse; Query *relationPlannerParseQuery = relationPlannerRoot->parse;
relationTargetList = relationPlannerParseQuery->targetList; List *relationTargetList = relationPlannerParseQuery->targetList;
foreach(targetEntryCell, relationTargetList) foreach(targetEntryCell, relationTargetList)
{ {
@ -1689,12 +1657,11 @@ List *
DistributedRelationIdList(Query *query) DistributedRelationIdList(Query *query)
{ {
List *rangeTableList = NIL; List *rangeTableList = NIL;
List *tableEntryList = NIL;
List *relationIdList = NIL; List *relationIdList = NIL;
ListCell *tableEntryCell = NULL; ListCell *tableEntryCell = NULL;
ExtractRangeTableRelationWalker((Node *) query, &rangeTableList); ExtractRangeTableRelationWalker((Node *) query, &rangeTableList);
tableEntryList = TableEntryList(rangeTableList); List *tableEntryList = TableEntryList(rangeTableList);
foreach(tableEntryCell, tableEntryList) foreach(tableEntryCell, tableEntryList)
{ {
@ -1724,10 +1691,6 @@ PlannerRestrictionContext *
FilterPlannerRestrictionForQuery(PlannerRestrictionContext *plannerRestrictionContext, FilterPlannerRestrictionForQuery(PlannerRestrictionContext *plannerRestrictionContext,
Query *query) Query *query)
{ {
PlannerRestrictionContext *filteredPlannerRestrictionContext = NULL;
int referenceRelationCount = 0;
int totalRelationCount = 0;
Relids queryRteIdentities = QueryRteIdentities(query); Relids queryRteIdentities = QueryRteIdentities(query);
RelationRestrictionContext *relationRestrictionContext = RelationRestrictionContext *relationRestrictionContext =
@ -1742,14 +1705,16 @@ FilterPlannerRestrictionForQuery(PlannerRestrictionContext *plannerRestrictionCo
FilterJoinRestrictionContext(joinRestrictionContext, queryRteIdentities); FilterJoinRestrictionContext(joinRestrictionContext, queryRteIdentities);
/* allocate the filtered planner restriction context and set all the fields */ /* allocate the filtered planner restriction context and set all the fields */
filteredPlannerRestrictionContext = palloc0(sizeof(PlannerRestrictionContext)); PlannerRestrictionContext *filteredPlannerRestrictionContext = palloc0(
sizeof(PlannerRestrictionContext));
filteredPlannerRestrictionContext->memoryContext = filteredPlannerRestrictionContext->memoryContext =
plannerRestrictionContext->memoryContext; plannerRestrictionContext->memoryContext;
totalRelationCount = list_length( int totalRelationCount = list_length(
filteredRelationRestrictionContext->relationRestrictionList); filteredRelationRestrictionContext->relationRestrictionList);
referenceRelationCount = ReferenceRelationCount(filteredRelationRestrictionContext); int referenceRelationCount = ReferenceRelationCount(
filteredRelationRestrictionContext);
filteredRelationRestrictionContext->allReferenceTables = filteredRelationRestrictionContext->allReferenceTables =
(totalRelationCount == referenceRelationCount); (totalRelationCount == referenceRelationCount);
@ -1850,10 +1815,8 @@ static bool
RangeTableArrayContainsAnyRTEIdentities(RangeTblEntry **rangeTableEntries, int RangeTableArrayContainsAnyRTEIdentities(RangeTblEntry **rangeTableEntries, int
rangeTableArrayLength, Relids queryRteIdentities) rangeTableArrayLength, Relids queryRteIdentities)
{ {
int rteIndex = 0;
/* simple_rte_array starts from 1, see plannerInfo struct */ /* simple_rte_array starts from 1, see plannerInfo struct */
for (rteIndex = 1; rteIndex < rangeTableArrayLength; ++rteIndex) for (int rteIndex = 1; rteIndex < rangeTableArrayLength; ++rteIndex)
{ {
RangeTblEntry *rangeTableEntry = rangeTableEntries[rteIndex]; RangeTblEntry *rangeTableEntry = rangeTableEntries[rteIndex];
List *rangeTableRelationList = NULL; List *rangeTableRelationList = NULL;
@ -1883,11 +1846,10 @@ RangeTableArrayContainsAnyRTEIdentities(RangeTblEntry **rangeTableEntries, int
foreach(rteRelationCell, rangeTableRelationList) foreach(rteRelationCell, rangeTableRelationList)
{ {
RangeTblEntry *rteRelation = (RangeTblEntry *) lfirst(rteRelationCell); RangeTblEntry *rteRelation = (RangeTblEntry *) lfirst(rteRelationCell);
int rteIdentity = 0;
Assert(rteRelation->rtekind == RTE_RELATION); Assert(rteRelation->rtekind == RTE_RELATION);
rteIdentity = GetRTEIdentity(rteRelation); int rteIdentity = GetRTEIdentity(rteRelation);
if (bms_is_member(rteIdentity, queryRteIdentities)) if (bms_is_member(rteIdentity, queryRteIdentities))
{ {
return true; return true;
@ -1916,12 +1878,11 @@ QueryRteIdentities(Query *queryTree)
foreach(rangeTableCell, rangeTableList) foreach(rangeTableCell, rangeTableList)
{ {
RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell); RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell);
int rteIdentity = 0;
/* we're only interested in relations */ /* we're only interested in relations */
Assert(rangeTableEntry->rtekind == RTE_RELATION); Assert(rangeTableEntry->rtekind == RTE_RELATION);
rteIdentity = GetRTEIdentity(rangeTableEntry); int rteIdentity = GetRTEIdentity(rangeTableEntry);
queryRteIdentities = bms_add_member(queryRteIdentities, rteIdentity); queryRteIdentities = bms_add_member(queryRteIdentities, rteIdentity);
} }

View File

@ -308,7 +308,6 @@ PruneShards(Oid relationId, Index rangeTableId, List *whereClauseList,
foreach(pruneCell, context.pruningInstances) foreach(pruneCell, context.pruningInstances)
{ {
PruningInstance *prune = (PruningInstance *) lfirst(pruneCell); PruningInstance *prune = (PruningInstance *) lfirst(pruneCell);
List *pruneOneList;
/* /*
* If this is a partial instance, a fully built one has also been * If this is a partial instance, a fully built one has also been
@ -358,7 +357,7 @@ PruneShards(Oid relationId, Index rangeTableId, List *whereClauseList,
} }
} }
pruneOneList = PruneOne(cacheEntry, &context, prune); List *pruneOneList = PruneOne(cacheEntry, &context, prune);
if (prunedList) if (prunedList)
{ {
@ -643,12 +642,9 @@ AddSAOPartitionKeyRestrictionToInstance(ClauseWalkerContext *context,
equal(strippedLeftOpExpression, context->partitionColumn) && equal(strippedLeftOpExpression, context->partitionColumn) &&
IsA(arrayArgument, Const)) IsA(arrayArgument, Const))
{ {
ArrayType *array = NULL;
int16 typlen = 0; int16 typlen = 0;
bool typbyval = false; bool typbyval = false;
char typalign = '\0'; char typalign = '\0';
Oid elementType = 0;
ArrayIterator arrayIterator = NULL;
Datum arrayElement = 0; Datum arrayElement = 0;
Datum inArray = ((Const *) arrayArgument)->constvalue; Datum inArray = ((Const *) arrayArgument)->constvalue;
bool isNull = false; bool isNull = false;
@ -659,26 +655,25 @@ AddSAOPartitionKeyRestrictionToInstance(ClauseWalkerContext *context,
return; return;
} }
array = DatumGetArrayTypeP(((Const *) arrayArgument)->constvalue); ArrayType *array = DatumGetArrayTypeP(((Const *) arrayArgument)->constvalue);
/* get the necessary information from array type to iterate over it */ /* get the necessary information from array type to iterate over it */
elementType = ARR_ELEMTYPE(array); Oid elementType = ARR_ELEMTYPE(array);
get_typlenbyvalalign(elementType, get_typlenbyvalalign(elementType,
&typlen, &typlen,
&typbyval, &typbyval,
&typalign); &typalign);
/* Iterate over the righthand array of expression */ /* Iterate over the righthand array of expression */
arrayIterator = array_create_iterator(array, 0, NULL); ArrayIterator arrayIterator = array_create_iterator(array, 0, NULL);
while (array_iterate(arrayIterator, &arrayElement, &isNull)) while (array_iterate(arrayIterator, &arrayElement, &isNull))
{ {
OpExpr *arrayEqualityOp = NULL;
Const *constElement = makeConst(elementType, -1, Const *constElement = makeConst(elementType, -1,
DEFAULT_COLLATION_OID, typlen, arrayElement, DEFAULT_COLLATION_OID, typlen, arrayElement,
isNull, typbyval); isNull, typbyval);
/* build partcol = arrayelem operator */ /* build partcol = arrayelem operator */
arrayEqualityOp = makeNode(OpExpr); OpExpr *arrayEqualityOp = makeNode(OpExpr);
arrayEqualityOp->opno = arrayOperatorExpression->opno; arrayEqualityOp->opno = arrayOperatorExpression->opno;
arrayEqualityOp->opfuncid = arrayOperatorExpression->opfuncid; arrayEqualityOp->opfuncid = arrayOperatorExpression->opfuncid;
arrayEqualityOp->inputcollid = arrayOperatorExpression->inputcollid; arrayEqualityOp->inputcollid = arrayOperatorExpression->inputcollid;
@ -734,7 +729,6 @@ AddPartitionKeyRestrictionToInstance(ClauseWalkerContext *context, OpExpr *opCla
Var *partitionColumn, Const *constantClause) Var *partitionColumn, Const *constantClause)
{ {
PruningInstance *prune = context->currentPruningInstance; PruningInstance *prune = context->currentPruningInstance;
List *btreeInterpretationList = NULL;
ListCell *btreeInterpretationCell = NULL; ListCell *btreeInterpretationCell = NULL;
bool matchedOp = false; bool matchedOp = false;
@ -756,7 +750,7 @@ AddPartitionKeyRestrictionToInstance(ClauseWalkerContext *context, OpExpr *opCla
/* at this point, we'd better be able to pass binary Datums to comparison functions */ /* at this point, we'd better be able to pass binary Datums to comparison functions */
Assert(IsBinaryCoercible(constantClause->consttype, partitionColumn->vartype)); Assert(IsBinaryCoercible(constantClause->consttype, partitionColumn->vartype));
btreeInterpretationList = get_op_btree_interpretation(opClause->opno); List *btreeInterpretationList = get_op_btree_interpretation(opClause->opno);
foreach(btreeInterpretationCell, btreeInterpretationList) foreach(btreeInterpretationCell, btreeInterpretationList)
{ {
OpBtreeInterpretation *btreeInterpretation = OpBtreeInterpretation *btreeInterpretation =
@ -924,13 +918,12 @@ AddHashRestrictionToInstance(ClauseWalkerContext *context, OpExpr *opClause,
Var *varClause, Const *constantClause) Var *varClause, Const *constantClause)
{ {
PruningInstance *prune = context->currentPruningInstance; PruningInstance *prune = context->currentPruningInstance;
List *btreeInterpretationList = NULL;
ListCell *btreeInterpretationCell = NULL; ListCell *btreeInterpretationCell = NULL;
/* be paranoid */ /* be paranoid */
Assert(IsBinaryCoercible(constantClause->consttype, INT4OID)); Assert(IsBinaryCoercible(constantClause->consttype, INT4OID));
btreeInterpretationList = List *btreeInterpretationList =
get_op_btree_interpretation(opClause->opno); get_op_btree_interpretation(opClause->opno);
foreach(btreeInterpretationCell, btreeInterpretationList) foreach(btreeInterpretationCell, btreeInterpretationList)
{ {
@ -986,9 +979,8 @@ static List *
ShardArrayToList(ShardInterval **shardArray, int length) ShardArrayToList(ShardInterval **shardArray, int length)
{ {
List *shardIntervalList = NIL; List *shardIntervalList = NIL;
int shardIndex;
for (shardIndex = 0; shardIndex < length; shardIndex++) for (int shardIndex = 0; shardIndex < length; shardIndex++)
{ {
ShardInterval *shardInterval = ShardInterval *shardInterval =
shardArray[shardIndex]; shardArray[shardIndex];
@ -1068,12 +1060,11 @@ PruneOne(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *context,
*/ */
if (prune->hashedEqualConsts) if (prune->hashedEqualConsts)
{ {
int shardIndex = INVALID_SHARD_INDEX;
ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray; ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray;
Assert(context->partitionMethod == DISTRIBUTE_BY_HASH); Assert(context->partitionMethod == DISTRIBUTE_BY_HASH);
shardIndex = FindShardIntervalIndex(prune->hashedEqualConsts->constvalue, int shardIndex = FindShardIntervalIndex(prune->hashedEqualConsts->constvalue,
cacheEntry); cacheEntry);
if (shardIndex == INVALID_SHARD_INDEX) if (shardIndex == INVALID_SHARD_INDEX)
@ -1198,14 +1189,12 @@ LowerShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
while (lowerBoundIndex < upperBoundIndex) while (lowerBoundIndex < upperBoundIndex)
{ {
int middleIndex = lowerBoundIndex + ((upperBoundIndex - lowerBoundIndex) / 2); int middleIndex = lowerBoundIndex + ((upperBoundIndex - lowerBoundIndex) / 2);
int maxValueComparison = 0;
int minValueComparison = 0;
/* setup minValue as argument */ /* setup minValue as argument */
fcSetArg(compareFunction, 1, shardIntervalCache[middleIndex]->minValue); fcSetArg(compareFunction, 1, shardIntervalCache[middleIndex]->minValue);
/* execute cmp(partitionValue, lowerBound) */ /* execute cmp(partitionValue, lowerBound) */
minValueComparison = PerformCompare(compareFunction); int minValueComparison = PerformCompare(compareFunction);
/* and evaluate results */ /* and evaluate results */
if (minValueComparison < 0) if (minValueComparison < 0)
@ -1219,7 +1208,7 @@ LowerShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
fcSetArg(compareFunction, 1, shardIntervalCache[middleIndex]->maxValue); fcSetArg(compareFunction, 1, shardIntervalCache[middleIndex]->maxValue);
/* execute cmp(partitionValue, upperBound) */ /* execute cmp(partitionValue, upperBound) */
maxValueComparison = PerformCompare(compareFunction); int maxValueComparison = PerformCompare(compareFunction);
if ((maxValueComparison == 0 && !includeMax) || if ((maxValueComparison == 0 && !includeMax) ||
maxValueComparison > 0) maxValueComparison > 0)
@ -1276,14 +1265,12 @@ UpperShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
while (lowerBoundIndex < upperBoundIndex) while (lowerBoundIndex < upperBoundIndex)
{ {
int middleIndex = lowerBoundIndex + ((upperBoundIndex - lowerBoundIndex) / 2); int middleIndex = lowerBoundIndex + ((upperBoundIndex - lowerBoundIndex) / 2);
int maxValueComparison = 0;
int minValueComparison = 0;
/* setup minValue as argument */ /* setup minValue as argument */
fcSetArg(compareFunction, 1, shardIntervalCache[middleIndex]->minValue); fcSetArg(compareFunction, 1, shardIntervalCache[middleIndex]->minValue);
/* execute cmp(partitionValue, lowerBound) */ /* execute cmp(partitionValue, lowerBound) */
minValueComparison = PerformCompare(compareFunction); int minValueComparison = PerformCompare(compareFunction);
/* and evaluate results */ /* and evaluate results */
if ((minValueComparison == 0 && !includeMin) || if ((minValueComparison == 0 && !includeMin) ||
@ -1298,7 +1285,7 @@ UpperShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
fcSetArg(compareFunction, 1, shardIntervalCache[middleIndex]->maxValue); fcSetArg(compareFunction, 1, shardIntervalCache[middleIndex]->maxValue);
/* execute cmp(partitionValue, upperBound) */ /* execute cmp(partitionValue, upperBound) */
maxValueComparison = PerformCompare(compareFunction); int maxValueComparison = PerformCompare(compareFunction);
if (maxValueComparison > 0) if (maxValueComparison > 0)
{ {
@ -1355,7 +1342,6 @@ PruneWithBoundaries(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *contex
bool upperBoundInclusive = false; bool upperBoundInclusive = false;
int lowerBoundIdx = -1; int lowerBoundIdx = -1;
int upperBoundIdx = -1; int upperBoundIdx = -1;
int curIdx = 0;
FunctionCallInfo compareFunctionCall = (FunctionCallInfo) & FunctionCallInfo compareFunctionCall = (FunctionCallInfo) &
context->compareIntervalFunctionCall; context->compareIntervalFunctionCall;
@ -1442,7 +1428,7 @@ PruneWithBoundaries(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *contex
/* /*
* Build list of all shards that are in the range of shards (possibly 0). * Build list of all shards that are in the range of shards (possibly 0).
*/ */
for (curIdx = lowerBoundIdx; curIdx <= upperBoundIdx; curIdx++) for (int curIdx = lowerBoundIdx; curIdx <= upperBoundIdx; curIdx++)
{ {
remainingShardList = lappend(remainingShardList, remainingShardList = lappend(remainingShardList,
sortedShardIntervalArray[curIdx]); sortedShardIntervalArray[curIdx]);
@ -1463,9 +1449,8 @@ ExhaustivePrune(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *context,
List *remainingShardList = NIL; List *remainingShardList = NIL;
int shardCount = cacheEntry->shardIntervalArrayLength; int shardCount = cacheEntry->shardIntervalArrayLength;
ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray; ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray;
int curIdx = 0;
for (curIdx = 0; curIdx < shardCount; curIdx++) for (int curIdx = 0; curIdx < shardCount; curIdx++)
{ {
ShardInterval *curInterval = sortedShardIntervalArray[curIdx]; ShardInterval *curInterval = sortedShardIntervalArray[curIdx];

View File

@ -39,11 +39,6 @@ ProgressMonitorData *
CreateProgressMonitor(uint64 progressTypeMagicNumber, int stepCount, Size stepSize, CreateProgressMonitor(uint64 progressTypeMagicNumber, int stepCount, Size stepSize,
Oid relationId) Oid relationId)
{ {
dsm_segment *dsmSegment = NULL;
dsm_handle dsmHandle = 0;
ProgressMonitorData *monitor = NULL;
Size monitorSize = 0;
if (stepSize <= 0 || stepCount <= 0) if (stepSize <= 0 || stepCount <= 0)
{ {
ereport(ERROR, ereport(ERROR,
@ -51,8 +46,8 @@ CreateProgressMonitor(uint64 progressTypeMagicNumber, int stepCount, Size stepSi
"positive values"))); "positive values")));
} }
monitorSize = sizeof(ProgressMonitorData) + stepSize * stepCount; Size monitorSize = sizeof(ProgressMonitorData) + stepSize * stepCount;
dsmSegment = dsm_create(monitorSize, DSM_CREATE_NULL_IF_MAXSEGMENTS); dsm_segment *dsmSegment = dsm_create(monitorSize, DSM_CREATE_NULL_IF_MAXSEGMENTS);
if (dsmSegment == NULL) if (dsmSegment == NULL)
{ {
@ -62,9 +57,9 @@ CreateProgressMonitor(uint64 progressTypeMagicNumber, int stepCount, Size stepSi
return NULL; return NULL;
} }
dsmHandle = dsm_segment_handle(dsmSegment); dsm_handle dsmHandle = dsm_segment_handle(dsmSegment);
monitor = MonitorDataFromDSMHandle(dsmHandle, &dsmSegment); ProgressMonitorData *monitor = MonitorDataFromDSMHandle(dsmHandle, &dsmSegment);
monitor->stepCount = stepCount; monitor->stepCount = stepCount;
monitor->processId = MyProcPid; monitor->processId = MyProcPid;
@ -143,31 +138,27 @@ ProgressMonitorList(uint64 commandTypeMagicNumber, List **attachedDSMSegments)
*/ */
text *commandTypeText = cstring_to_text("VACUUM"); text *commandTypeText = cstring_to_text("VACUUM");
Datum commandTypeDatum = PointerGetDatum(commandTypeText); Datum commandTypeDatum = PointerGetDatum(commandTypeText);
Oid getProgressInfoFunctionOid = InvalidOid;
TupleTableSlot *tupleTableSlot = NULL;
ReturnSetInfo *progressResultSet = NULL;
List *monitorList = NIL; List *monitorList = NIL;
getProgressInfoFunctionOid = FunctionOid("pg_catalog", Oid getProgressInfoFunctionOid = FunctionOid("pg_catalog",
"pg_stat_get_progress_info", "pg_stat_get_progress_info",
1); 1);
progressResultSet = FunctionCallGetTupleStore1(pg_stat_get_progress_info, ReturnSetInfo *progressResultSet = FunctionCallGetTupleStore1(
pg_stat_get_progress_info,
getProgressInfoFunctionOid, getProgressInfoFunctionOid,
commandTypeDatum); commandTypeDatum);
tupleTableSlot = MakeSingleTupleTableSlotCompat(progressResultSet->setDesc, TupleTableSlot *tupleTableSlot = MakeSingleTupleTableSlotCompat(
progressResultSet->setDesc,
&TTSOpsMinimalTuple); &TTSOpsMinimalTuple);
/* iterate over tuples in tuple store, and send them to destination */ /* iterate over tuples in tuple store, and send them to destination */
for (;;) for (;;)
{ {
bool nextTuple = false;
bool isNull = false; bool isNull = false;
Datum magicNumberDatum = 0;
uint64 magicNumber = 0;
nextTuple = tuplestore_gettupleslot(progressResultSet->setResult, bool nextTuple = tuplestore_gettupleslot(progressResultSet->setResult,
true, true,
false, false,
tupleTableSlot); tupleTableSlot);
@ -177,8 +168,8 @@ ProgressMonitorList(uint64 commandTypeMagicNumber, List **attachedDSMSegments)
break; break;
} }
magicNumberDatum = slot_getattr(tupleTableSlot, magicNumberIndex, &isNull); Datum magicNumberDatum = slot_getattr(tupleTableSlot, magicNumberIndex, &isNull);
magicNumber = DatumGetUInt64(magicNumberDatum); uint64 magicNumber = DatumGetUInt64(magicNumberDatum);
if (!isNull && magicNumber == commandTypeMagicNumber) if (!isNull && magicNumber == commandTypeMagicNumber)
{ {

View File

@ -118,7 +118,6 @@ RelayEventExtendNames(Node *parseTree, char *schemaName, uint64 shardId)
command->subtype == AT_ValidateConstraint) command->subtype == AT_ValidateConstraint)
{ {
char **constraintName = &(command->name); char **constraintName = &(command->name);
Oid constraintOid = InvalidOid;
const bool constraintMissingOk = true; const bool constraintMissingOk = true;
if (!OidIsValid(relationId)) if (!OidIsValid(relationId))
@ -129,7 +128,7 @@ RelayEventExtendNames(Node *parseTree, char *schemaName, uint64 shardId)
rvMissingOk); rvMissingOk);
} }
constraintOid = get_relation_constraint_oid(relationId, Oid constraintOid = get_relation_constraint_oid(relationId,
command->name, command->name,
constraintMissingOk); constraintMissingOk);
if (!OidIsValid(constraintOid)) if (!OidIsValid(constraintOid))
@ -161,8 +160,6 @@ RelayEventExtendNames(Node *parseTree, char *schemaName, uint64 shardId)
case T_ClusterStmt: case T_ClusterStmt:
{ {
ClusterStmt *clusterStmt = (ClusterStmt *) parseTree; ClusterStmt *clusterStmt = (ClusterStmt *) parseTree;
char **relationName = NULL;
char **relationSchemaName = NULL;
/* we do not support clustering the entire database */ /* we do not support clustering the entire database */
if (clusterStmt->relation == NULL) if (clusterStmt->relation == NULL)
@ -170,8 +167,8 @@ RelayEventExtendNames(Node *parseTree, char *schemaName, uint64 shardId)
ereport(ERROR, (errmsg("cannot extend name for multi-relation cluster"))); ereport(ERROR, (errmsg("cannot extend name for multi-relation cluster")));
} }
relationName = &(clusterStmt->relation->relname); char **relationName = &(clusterStmt->relation->relname);
relationSchemaName = &(clusterStmt->relation->schemaname); char **relationSchemaName = &(clusterStmt->relation->schemaname);
/* prefix with schema name if it is not added already */ /* prefix with schema name if it is not added already */
SetSchemaNameIfNotExist(relationSchemaName, schemaName); SetSchemaNameIfNotExist(relationSchemaName, schemaName);
@ -232,11 +229,8 @@ RelayEventExtendNames(Node *parseTree, char *schemaName, uint64 shardId)
if (objectType == OBJECT_TABLE || objectType == OBJECT_INDEX || if (objectType == OBJECT_TABLE || objectType == OBJECT_INDEX ||
objectType == OBJECT_FOREIGN_TABLE || objectType == OBJECT_FOREIGN_SERVER) objectType == OBJECT_FOREIGN_TABLE || objectType == OBJECT_FOREIGN_SERVER)
{ {
List *relationNameList = NULL;
int relationNameListLength = 0;
Value *relationSchemaNameValue = NULL; Value *relationSchemaNameValue = NULL;
Value *relationNameValue = NULL; Value *relationNameValue = NULL;
char **relationName = NULL;
uint32 dropCount = list_length(dropStmt->objects); uint32 dropCount = list_length(dropStmt->objects);
if (dropCount > 1) if (dropCount > 1)
@ -253,8 +247,8 @@ RelayEventExtendNames(Node *parseTree, char *schemaName, uint64 shardId)
* have the correct memory address for the name. * have the correct memory address for the name.
*/ */
relationNameList = (List *) linitial(dropStmt->objects); List *relationNameList = (List *) linitial(dropStmt->objects);
relationNameListLength = list_length(relationNameList); int relationNameListLength = list_length(relationNameList);
switch (relationNameListLength) switch (relationNameListLength)
{ {
@ -294,7 +288,7 @@ RelayEventExtendNames(Node *parseTree, char *schemaName, uint64 shardId)
relationNameList = lcons(schemaNameValue, relationNameList); relationNameList = lcons(schemaNameValue, relationNameList);
} }
relationName = &(relationNameValue->val.str); char **relationName = &(relationNameValue->val.str);
AppendShardIdToName(relationName, shardId); AppendShardIdToName(relationName, shardId);
} }
else if (objectType == OBJECT_POLICY) else if (objectType == OBJECT_POLICY)
@ -418,7 +412,6 @@ RelayEventExtendNames(Node *parseTree, char *schemaName, uint64 shardId)
char **oldRelationName = &(renameStmt->relation->relname); char **oldRelationName = &(renameStmt->relation->relname);
char **newRelationName = &(renameStmt->newname); char **newRelationName = &(renameStmt->newname);
char **objectSchemaName = &(renameStmt->relation->schemaname); char **objectSchemaName = &(renameStmt->relation->schemaname);
int newRelationNameLength;
/* prefix with schema name if it is not added already */ /* prefix with schema name if it is not added already */
SetSchemaNameIfNotExist(objectSchemaName, schemaName); SetSchemaNameIfNotExist(objectSchemaName, schemaName);
@ -440,7 +433,7 @@ RelayEventExtendNames(Node *parseTree, char *schemaName, uint64 shardId)
* *
* See also https://github.com/citusdata/citus/issues/1664 * See also https://github.com/citusdata/citus/issues/1664
*/ */
newRelationNameLength = strlen(*newRelationName); int newRelationNameLength = strlen(*newRelationName);
if (newRelationNameLength >= (NAMEDATALEN - 1)) if (newRelationNameLength >= (NAMEDATALEN - 1))
{ {
ereport(ERROR, ereport(ERROR,
@ -676,10 +669,8 @@ AppendShardIdToName(char **name, uint64 shardId)
char extendedName[NAMEDATALEN]; char extendedName[NAMEDATALEN];
int nameLength = strlen(*name); int nameLength = strlen(*name);
char shardIdAndSeparator[NAMEDATALEN]; char shardIdAndSeparator[NAMEDATALEN];
int shardIdAndSeparatorLength;
uint32 longNameHash = 0; uint32 longNameHash = 0;
int multiByteClipLength = 0; int multiByteClipLength = 0;
int neededBytes = 0;
if (nameLength >= NAMEDATALEN) if (nameLength >= NAMEDATALEN)
{ {
@ -690,7 +681,7 @@ AppendShardIdToName(char **name, uint64 shardId)
snprintf(shardIdAndSeparator, NAMEDATALEN, "%c" UINT64_FORMAT, snprintf(shardIdAndSeparator, NAMEDATALEN, "%c" UINT64_FORMAT,
SHARD_NAME_SEPARATOR, shardId); SHARD_NAME_SEPARATOR, shardId);
shardIdAndSeparatorLength = strlen(shardIdAndSeparator); int shardIdAndSeparatorLength = strlen(shardIdAndSeparator);
/* /*
* If *name strlen is < (NAMEDATALEN - shardIdAndSeparatorLength), * If *name strlen is < (NAMEDATALEN - shardIdAndSeparatorLength),
@ -740,7 +731,7 @@ AppendShardIdToName(char **name, uint64 shardId)
} }
(*name) = (char *) repalloc((*name), NAMEDATALEN); (*name) = (char *) repalloc((*name), NAMEDATALEN);
neededBytes = snprintf((*name), NAMEDATALEN, "%s", extendedName); int neededBytes = snprintf((*name), NAMEDATALEN, "%s", extendedName);
if (neededBytes < 0) if (neededBytes < 0)
{ {
ereport(ERROR, (errcode(ERRCODE_OUT_OF_MEMORY), ereport(ERROR, (errcode(ERRCODE_OUT_OF_MEMORY),
@ -764,10 +755,7 @@ shard_name(PG_FUNCTION_ARGS)
{ {
Oid relationId = PG_GETARG_OID(0); Oid relationId = PG_GETARG_OID(0);
int64 shardId = PG_GETARG_INT64(1); int64 shardId = PG_GETARG_INT64(1);
char *relationName = NULL;
Oid schemaId = InvalidOid;
char *schemaName = NULL;
char *qualifiedName = NULL; char *qualifiedName = NULL;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
@ -785,7 +773,7 @@ shard_name(PG_FUNCTION_ARGS)
errmsg("object_name does not reference a valid relation"))); errmsg("object_name does not reference a valid relation")));
} }
relationName = get_rel_name(relationId); char *relationName = get_rel_name(relationId);
if (relationName == NULL) if (relationName == NULL)
{ {
@ -795,8 +783,8 @@ shard_name(PG_FUNCTION_ARGS)
AppendShardIdToName(&relationName, shardId); AppendShardIdToName(&relationName, shardId);
schemaId = get_rel_namespace(relationId); Oid schemaId = get_rel_namespace(relationId);
schemaName = get_namespace_name(schemaId); char *schemaName = get_namespace_name(schemaId);
if (strncmp(schemaName, "public", NAMEDATALEN) == 0) if (strncmp(schemaName, "public", NAMEDATALEN) == 0)
{ {

View File

@ -273,10 +273,9 @@ static void
ResizeStackToMaximumDepth(void) ResizeStackToMaximumDepth(void)
{ {
#ifndef WIN32 #ifndef WIN32
volatile char *stack_resizer = NULL;
long max_stack_depth_bytes = max_stack_depth * 1024L; long max_stack_depth_bytes = max_stack_depth * 1024L;
stack_resizer = alloca(max_stack_depth_bytes); volatile char *stack_resizer = alloca(max_stack_depth_bytes);
/* /*
* Different architectures might have different directions while * Different architectures might have different directions while
@ -345,14 +344,13 @@ StartupCitusBackend(void)
static void static void
CreateRequiredDirectories(void) CreateRequiredDirectories(void)
{ {
int dirNo = 0;
const char *subdirs[] = { const char *subdirs[] = {
"pg_foreign_file", "pg_foreign_file",
"pg_foreign_file/cached", "pg_foreign_file/cached",
"base/" PG_JOB_CACHE_DIR "base/" PG_JOB_CACHE_DIR
}; };
for (dirNo = 0; dirNo < lengthof(subdirs); dirNo++) for (int dirNo = 0; dirNo < lengthof(subdirs); dirNo++)
{ {
int ret = mkdir(subdirs[dirNo], S_IRWXU); int ret = mkdir(subdirs[dirNo], S_IRWXU);
@ -1380,15 +1378,12 @@ NodeConninfoGucCheckHook(char **newval, void **extra, GucSource source)
static void static void
NodeConninfoGucAssignHook(const char *newval, void *extra) NodeConninfoGucAssignHook(const char *newval, void *extra)
{ {
PQconninfoOption *optionArray = NULL;
PQconninfoOption *option = NULL;
if (newval == NULL) if (newval == NULL)
{ {
newval = ""; newval = "";
} }
optionArray = PQconninfoParse(newval, NULL); PQconninfoOption *optionArray = PQconninfoParse(newval, NULL);
if (optionArray == NULL) if (optionArray == NULL)
{ {
ereport(FATAL, (errmsg("cannot parse node_conninfo value"), ereport(FATAL, (errmsg("cannot parse node_conninfo value"),
@ -1398,7 +1393,7 @@ NodeConninfoGucAssignHook(const char *newval, void *extra)
ResetConnParams(); ResetConnParams();
for (option = optionArray; option->keyword != NULL; option++) for (PQconninfoOption *option = optionArray; option->keyword != NULL; option++)
{ {
if (option->val == NULL || option->val[0] == '\0') if (option->val == NULL || option->val[0] == '\0')
{ {

View File

@ -83,7 +83,6 @@ get_colocated_table_array(PG_FUNCTION_ARGS)
{ {
Oid distributedTableId = PG_GETARG_OID(0); Oid distributedTableId = PG_GETARG_OID(0);
ArrayType *colocatedTablesArrayType = NULL;
List *colocatedTableList = ColocatedTableList(distributedTableId); List *colocatedTableList = ColocatedTableList(distributedTableId);
ListCell *colocatedTableCell = NULL; ListCell *colocatedTableCell = NULL;
int colocatedTableCount = list_length(colocatedTableList); int colocatedTableCount = list_length(colocatedTableList);
@ -100,8 +99,9 @@ get_colocated_table_array(PG_FUNCTION_ARGS)
colocatedTableIndex++; colocatedTableIndex++;
} }
colocatedTablesArrayType = DatumArrayToArrayType(colocatedTablesDatumArray, ArrayType *colocatedTablesArrayType = DatumArrayToArrayType(colocatedTablesDatumArray,
colocatedTableCount, arrayTypeId); colocatedTableCount,
arrayTypeId);
PG_RETURN_ARRAYTYPE_P(colocatedTablesArrayType); PG_RETURN_ARRAYTYPE_P(colocatedTablesArrayType);
} }

View File

@ -31,15 +31,12 @@ Datum
deparse_test(PG_FUNCTION_ARGS) deparse_test(PG_FUNCTION_ARGS)
{ {
text *queryStringText = PG_GETARG_TEXT_P(0); text *queryStringText = PG_GETARG_TEXT_P(0);
char *queryStringChar = NULL;
Query *query = NULL;
const char *deparsedQuery = NULL;
queryStringChar = text_to_cstring(queryStringText); char *queryStringChar = text_to_cstring(queryStringText);
query = ParseQueryString(queryStringChar, NULL, 0); Query *query = ParseQueryString(queryStringChar, NULL, 0);
QualifyTreeNode(query->utilityStmt); QualifyTreeNode(query->utilityStmt);
deparsedQuery = DeparseTreeNode(query->utilityStmt); const char *deparsedQuery = DeparseTreeNode(query->utilityStmt);
PG_RETURN_TEXT_P(cstring_to_text(deparsedQuery)); PG_RETURN_TEXT_P(cstring_to_text(deparsedQuery));
} }

View File

@ -50,9 +50,9 @@ deparse_shard_query_test(PG_FUNCTION_ARGS)
{ {
Node *parsetree = (Node *) lfirst(parseTreeCell); Node *parsetree = (Node *) lfirst(parseTreeCell);
ListCell *queryTreeCell = NULL; ListCell *queryTreeCell = NULL;
List *queryTreeList = NIL;
queryTreeList = pg_analyze_and_rewrite((RawStmt *) parsetree, queryStringChar, List *queryTreeList = pg_analyze_and_rewrite((RawStmt *) parsetree,
queryStringChar,
NULL, 0, NULL); NULL, 0, NULL);
foreach(queryTreeCell, queryTreeList) foreach(queryTreeCell, queryTreeList)

View File

@ -40,10 +40,7 @@ Datum
get_adjacency_list_wait_graph(PG_FUNCTION_ARGS) get_adjacency_list_wait_graph(PG_FUNCTION_ARGS)
{ {
TupleDesc tupleDescriptor = NULL; TupleDesc tupleDescriptor = NULL;
Tuplestorestate *tupleStore = NULL;
WaitGraph *waitGraph = NULL;
HTAB *adjacencyList = NULL;
HASH_SEQ_STATUS status; HASH_SEQ_STATUS status;
TransactionNode *transactionNode = NULL; TransactionNode *transactionNode = NULL;
@ -52,9 +49,9 @@ get_adjacency_list_wait_graph(PG_FUNCTION_ARGS)
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
tupleStore = SetupTuplestore(fcinfo, &tupleDescriptor); Tuplestorestate *tupleStore = SetupTuplestore(fcinfo, &tupleDescriptor);
waitGraph = BuildGlobalWaitGraph(); WaitGraph *waitGraph = BuildGlobalWaitGraph();
adjacencyList = BuildAdjacencyListsForWaitGraph(waitGraph); HTAB *adjacencyList = BuildAdjacencyListsForWaitGraph(waitGraph);
/* iterate on all nodes */ /* iterate on all nodes */
hash_seq_init(&status, adjacencyList); hash_seq_init(&status, adjacencyList);

View File

@ -62,17 +62,14 @@ Datum
load_shard_id_array(PG_FUNCTION_ARGS) load_shard_id_array(PG_FUNCTION_ARGS)
{ {
Oid distributedTableId = PG_GETARG_OID(0); Oid distributedTableId = PG_GETARG_OID(0);
ArrayType *shardIdArrayType = NULL;
ListCell *shardCell = NULL; ListCell *shardCell = NULL;
int shardIdIndex = 0; int shardIdIndex = 0;
Oid shardIdTypeId = INT8OID; Oid shardIdTypeId = INT8OID;
int shardIdCount = -1;
Datum *shardIdDatumArray = NULL;
List *shardList = LoadShardIntervalList(distributedTableId); List *shardList = LoadShardIntervalList(distributedTableId);
shardIdCount = list_length(shardList); int shardIdCount = list_length(shardList);
shardIdDatumArray = palloc0(shardIdCount * sizeof(Datum)); Datum *shardIdDatumArray = palloc0(shardIdCount * sizeof(Datum));
foreach(shardCell, shardList) foreach(shardCell, shardList)
{ {
@ -83,7 +80,7 @@ load_shard_id_array(PG_FUNCTION_ARGS)
shardIdIndex++; shardIdIndex++;
} }
shardIdArrayType = DatumArrayToArrayType(shardIdDatumArray, shardIdCount, ArrayType *shardIdArrayType = DatumArrayToArrayType(shardIdDatumArray, shardIdCount,
shardIdTypeId); shardIdTypeId);
PG_RETURN_ARRAYTYPE_P(shardIdArrayType); PG_RETURN_ARRAYTYPE_P(shardIdArrayType);
@ -103,11 +100,10 @@ load_shard_interval_array(PG_FUNCTION_ARGS)
Oid expectedType PG_USED_FOR_ASSERTS_ONLY = get_fn_expr_argtype(fcinfo->flinfo, 1); Oid expectedType PG_USED_FOR_ASSERTS_ONLY = get_fn_expr_argtype(fcinfo->flinfo, 1);
ShardInterval *shardInterval = LoadShardInterval(shardId); ShardInterval *shardInterval = LoadShardInterval(shardId);
Datum shardIntervalArray[] = { shardInterval->minValue, shardInterval->maxValue }; Datum shardIntervalArray[] = { shardInterval->minValue, shardInterval->maxValue };
ArrayType *shardIntervalArrayType = NULL;
Assert(expectedType == shardInterval->valueTypeId); Assert(expectedType == shardInterval->valueTypeId);
shardIntervalArrayType = DatumArrayToArrayType(shardIntervalArray, 2, ArrayType *shardIntervalArrayType = DatumArrayToArrayType(shardIntervalArray, 2,
shardInterval->valueTypeId); shardInterval->valueTypeId);
PG_RETURN_ARRAYTYPE_P(shardIntervalArrayType); PG_RETURN_ARRAYTYPE_P(shardIntervalArrayType);
@ -126,12 +122,9 @@ load_shard_placement_array(PG_FUNCTION_ARGS)
{ {
int64 shardId = PG_GETARG_INT64(0); int64 shardId = PG_GETARG_INT64(0);
bool onlyFinalized = PG_GETARG_BOOL(1); bool onlyFinalized = PG_GETARG_BOOL(1);
ArrayType *placementArrayType = NULL;
List *placementList = NIL; List *placementList = NIL;
ListCell *placementCell = NULL; ListCell *placementCell = NULL;
int placementCount = -1;
int placementIndex = 0; int placementIndex = 0;
Datum *placementDatumArray = NULL;
Oid placementTypeId = TEXTOID; Oid placementTypeId = TEXTOID;
StringInfo placementInfo = makeStringInfo(); StringInfo placementInfo = makeStringInfo();
@ -146,8 +139,8 @@ load_shard_placement_array(PG_FUNCTION_ARGS)
placementList = SortList(placementList, CompareShardPlacementsByWorker); placementList = SortList(placementList, CompareShardPlacementsByWorker);
placementCount = list_length(placementList); int placementCount = list_length(placementList);
placementDatumArray = palloc0(placementCount * sizeof(Datum)); Datum *placementDatumArray = palloc0(placementCount * sizeof(Datum));
foreach(placementCell, placementList) foreach(placementCell, placementList)
{ {
@ -160,7 +153,8 @@ load_shard_placement_array(PG_FUNCTION_ARGS)
resetStringInfo(placementInfo); resetStringInfo(placementInfo);
} }
placementArrayType = DatumArrayToArrayType(placementDatumArray, placementCount, ArrayType *placementArrayType = DatumArrayToArrayType(placementDatumArray,
placementCount,
placementTypeId); placementTypeId);
PG_RETURN_ARRAYTYPE_P(placementArrayType); PG_RETURN_ARRAYTYPE_P(placementArrayType);
@ -224,14 +218,12 @@ create_monolithic_shard_row(PG_FUNCTION_ARGS)
StringInfo minInfo = makeStringInfo(); StringInfo minInfo = makeStringInfo();
StringInfo maxInfo = makeStringInfo(); StringInfo maxInfo = makeStringInfo();
uint64 newShardId = GetNextShardId(); uint64 newShardId = GetNextShardId();
text *maxInfoText = NULL;
text *minInfoText = NULL;
appendStringInfo(minInfo, "%d", INT32_MIN); appendStringInfo(minInfo, "%d", INT32_MIN);
appendStringInfo(maxInfo, "%d", INT32_MAX); appendStringInfo(maxInfo, "%d", INT32_MAX);
minInfoText = cstring_to_text(minInfo->data); text *minInfoText = cstring_to_text(minInfo->data);
maxInfoText = cstring_to_text(maxInfo->data); text *maxInfoText = cstring_to_text(maxInfo->data);
InsertShardRow(distributedTableId, newShardId, SHARD_STORAGE_TABLE, minInfoText, InsertShardRow(distributedTableId, newShardId, SHARD_STORAGE_TABLE, minInfoText,
maxInfoText); maxInfoText);
@ -270,9 +262,9 @@ relation_count_in_query(PG_FUNCTION_ARGS)
{ {
Node *parsetree = (Node *) lfirst(parseTreeCell); Node *parsetree = (Node *) lfirst(parseTreeCell);
ListCell *queryTreeCell = NULL; ListCell *queryTreeCell = NULL;
List *queryTreeList = NIL;
queryTreeList = pg_analyze_and_rewrite((RawStmt *) parsetree, queryStringChar, List *queryTreeList = pg_analyze_and_rewrite((RawStmt *) parsetree,
queryStringChar,
NULL, 0, NULL); NULL, 0, NULL);
foreach(queryTreeCell, queryTreeList) foreach(queryTreeCell, queryTreeList)

View File

@ -41,17 +41,14 @@ master_metadata_snapshot(PG_FUNCTION_ARGS)
List *createSnapshotCommands = MetadataCreateCommands(); List *createSnapshotCommands = MetadataCreateCommands();
List *snapshotCommandList = NIL; List *snapshotCommandList = NIL;
ListCell *snapshotCommandCell = NULL; ListCell *snapshotCommandCell = NULL;
int snapshotCommandCount = 0;
Datum *snapshotCommandDatumArray = NULL;
ArrayType *snapshotCommandArrayType = NULL;
int snapshotCommandIndex = 0; int snapshotCommandIndex = 0;
Oid ddlCommandTypeId = TEXTOID; Oid ddlCommandTypeId = TEXTOID;
snapshotCommandList = list_concat(snapshotCommandList, dropSnapshotCommands); snapshotCommandList = list_concat(snapshotCommandList, dropSnapshotCommands);
snapshotCommandList = list_concat(snapshotCommandList, createSnapshotCommands); snapshotCommandList = list_concat(snapshotCommandList, createSnapshotCommands);
snapshotCommandCount = list_length(snapshotCommandList); int snapshotCommandCount = list_length(snapshotCommandList);
snapshotCommandDatumArray = palloc0(snapshotCommandCount * sizeof(Datum)); Datum *snapshotCommandDatumArray = palloc0(snapshotCommandCount * sizeof(Datum));
foreach(snapshotCommandCell, snapshotCommandList) foreach(snapshotCommandCell, snapshotCommandList)
{ {
@ -62,7 +59,7 @@ master_metadata_snapshot(PG_FUNCTION_ARGS)
snapshotCommandIndex++; snapshotCommandIndex++;
} }
snapshotCommandArrayType = DatumArrayToArrayType(snapshotCommandDatumArray, ArrayType *snapshotCommandArrayType = DatumArrayToArrayType(snapshotCommandDatumArray,
snapshotCommandCount, snapshotCommandCount,
ddlCommandTypeId); ddlCommandTypeId);
@ -78,13 +75,10 @@ Datum
wait_until_metadata_sync(PG_FUNCTION_ARGS) wait_until_metadata_sync(PG_FUNCTION_ARGS)
{ {
uint32 timeout = PG_GETARG_UINT32(0); uint32 timeout = PG_GETARG_UINT32(0);
int waitResult = 0;
List *workerList = ActivePrimaryWorkerNodeList(NoLock); List *workerList = ActivePrimaryWorkerNodeList(NoLock);
ListCell *workerCell = NULL; ListCell *workerCell = NULL;
bool waitNotifications = false; bool waitNotifications = false;
MultiConnection *connection = NULL;
int waitFlags = 0;
foreach(workerCell, workerList) foreach(workerCell, workerList)
{ {
@ -109,12 +103,12 @@ wait_until_metadata_sync(PG_FUNCTION_ARGS)
PG_RETURN_VOID(); PG_RETURN_VOID();
} }
connection = GetNodeConnection(FORCE_NEW_CONNECTION, MultiConnection *connection = GetNodeConnection(FORCE_NEW_CONNECTION,
"localhost", PostPortNumber); "localhost", PostPortNumber);
ExecuteCriticalRemoteCommand(connection, "LISTEN " METADATA_SYNC_CHANNEL); ExecuteCriticalRemoteCommand(connection, "LISTEN " METADATA_SYNC_CHANNEL);
waitFlags = WL_SOCKET_READABLE | WL_TIMEOUT | WL_POSTMASTER_DEATH; int waitFlags = WL_SOCKET_READABLE | WL_TIMEOUT | WL_POSTMASTER_DEATH;
waitResult = WaitLatchOrSocket(NULL, waitFlags, PQsocket(connection->pgConn), int waitResult = WaitLatchOrSocket(NULL, waitFlags, PQsocket(connection->pgConn),
timeout, 0); timeout, 0);
if (waitResult & WL_POSTMASTER_DEATH) if (waitResult & WL_POSTMASTER_DEATH)
{ {

View File

@ -95,8 +95,7 @@ show_progress(PG_FUNCTION_ARGS)
ProgressMonitorData *monitor = lfirst(monitorCell); ProgressMonitorData *monitor = lfirst(monitorCell);
uint64 *steps = monitor->steps; uint64 *steps = monitor->steps;
int stepIndex = 0; for (int stepIndex = 0; stepIndex < monitor->stepCount; stepIndex++)
for (stepIndex = 0; stepIndex < monitor->stepCount; stepIndex++)
{ {
uint64 step = steps[stepIndex]; uint64 step = steps[stepIndex];

View File

@ -202,20 +202,16 @@ MakeTextPartitionExpression(Oid distributedTableId, text *value)
static ArrayType * static ArrayType *
PrunedShardIdsForTable(Oid distributedTableId, List *whereClauseList) PrunedShardIdsForTable(Oid distributedTableId, List *whereClauseList)
{ {
ArrayType *shardIdArrayType = NULL;
ListCell *shardCell = NULL; ListCell *shardCell = NULL;
int shardIdIndex = 0; int shardIdIndex = 0;
Oid shardIdTypeId = INT8OID; Oid shardIdTypeId = INT8OID;
Index tableId = 1; Index tableId = 1;
List *shardList = NIL;
int shardIdCount = -1;
Datum *shardIdDatumArray = NULL;
shardList = PruneShards(distributedTableId, tableId, whereClauseList, NULL); List *shardList = PruneShards(distributedTableId, tableId, whereClauseList, NULL);
shardIdCount = list_length(shardList); int shardIdCount = list_length(shardList);
shardIdDatumArray = palloc0(shardIdCount * sizeof(Datum)); Datum *shardIdDatumArray = palloc0(shardIdCount * sizeof(Datum));
foreach(shardCell, shardList) foreach(shardCell, shardList)
{ {
@ -226,7 +222,7 @@ PrunedShardIdsForTable(Oid distributedTableId, List *whereClauseList)
shardIdIndex++; shardIdIndex++;
} }
shardIdArrayType = DatumArrayToArrayType(shardIdDatumArray, shardIdCount, ArrayType *shardIdArrayType = DatumArrayToArrayType(shardIdDatumArray, shardIdCount,
shardIdTypeId); shardIdTypeId);
return shardIdArrayType; return shardIdArrayType;
@ -240,8 +236,6 @@ PrunedShardIdsForTable(Oid distributedTableId, List *whereClauseList)
static ArrayType * static ArrayType *
SortedShardIntervalArray(Oid distributedTableId) SortedShardIntervalArray(Oid distributedTableId)
{ {
ArrayType *shardIdArrayType = NULL;
int shardIndex = 0;
Oid shardIdTypeId = INT8OID; Oid shardIdTypeId = INT8OID;
DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId);
@ -249,7 +243,7 @@ SortedShardIntervalArray(Oid distributedTableId)
int shardIdCount = cacheEntry->shardIntervalArrayLength; int shardIdCount = cacheEntry->shardIntervalArrayLength;
Datum *shardIdDatumArray = palloc0(shardIdCount * sizeof(Datum)); Datum *shardIdDatumArray = palloc0(shardIdCount * sizeof(Datum));
for (shardIndex = 0; shardIndex < shardIdCount; ++shardIndex) for (int shardIndex = 0; shardIndex < shardIdCount; ++shardIndex)
{ {
ShardInterval *shardId = shardIntervalArray[shardIndex]; ShardInterval *shardId = shardIntervalArray[shardIndex];
Datum shardIdDatum = Int64GetDatum(shardId->shardId); Datum shardIdDatum = Int64GetDatum(shardId->shardId);
@ -257,7 +251,7 @@ SortedShardIntervalArray(Oid distributedTableId)
shardIdDatumArray[shardIndex] = shardIdDatum; shardIdDatumArray[shardIndex] = shardIdDatum;
} }
shardIdArrayType = DatumArrayToArrayType(shardIdDatumArray, shardIdCount, ArrayType *shardIdArrayType = DatumArrayToArrayType(shardIdDatumArray, shardIdCount,
shardIdTypeId); shardIdTypeId);
return shardIdArrayType; return shardIdArrayType;

View File

@ -136,7 +136,6 @@ run_commands_on_session_level_connection_to_node(PG_FUNCTION_ARGS)
StringInfo workerProcessStringInfo = makeStringInfo(); StringInfo workerProcessStringInfo = makeStringInfo();
MultiConnection *localConnection = GetNodeConnection(0, LOCAL_HOST_NAME, MultiConnection *localConnection = GetNodeConnection(0, LOCAL_HOST_NAME,
PostPortNumber); PostPortNumber);
Oid pgReloadConfOid = InvalidOid;
if (!singleConnection) if (!singleConnection)
{ {
@ -160,7 +159,7 @@ run_commands_on_session_level_connection_to_node(PG_FUNCTION_ARGS)
CloseConnection(localConnection); CloseConnection(localConnection);
/* Call pg_reload_conf UDF to update changed GUCs above on each backend */ /* Call pg_reload_conf UDF to update changed GUCs above on each backend */
pgReloadConfOid = FunctionOid("pg_catalog", "pg_reload_conf", 0); Oid pgReloadConfOid = FunctionOid("pg_catalog", "pg_reload_conf", 0);
OidFunctionCall0(pgReloadConfOid); OidFunctionCall0(pgReloadConfOid);
@ -197,21 +196,19 @@ GetRemoteProcessId()
{ {
StringInfo queryStringInfo = makeStringInfo(); StringInfo queryStringInfo = makeStringInfo();
PGresult *result = NULL; PGresult *result = NULL;
int64 rowCount = 0;
int64 resultValue = 0;
appendStringInfo(queryStringInfo, GET_PROCESS_ID); appendStringInfo(queryStringInfo, GET_PROCESS_ID);
ExecuteOptionalRemoteCommand(singleConnection, queryStringInfo->data, &result); ExecuteOptionalRemoteCommand(singleConnection, queryStringInfo->data, &result);
rowCount = PQntuples(result); int64 rowCount = PQntuples(result);
if (rowCount != 1) if (rowCount != 1)
{ {
PG_RETURN_VOID(); PG_RETURN_VOID();
} }
resultValue = ParseIntField(result, 0, 0); int64 resultValue = ParseIntField(result, 0, 0);
PQclear(result); PQclear(result);
ClearResults(singleConnection, false); ClearResults(singleConnection, false);

View File

@ -155,12 +155,10 @@ Datum
get_current_transaction_id(PG_FUNCTION_ARGS) get_current_transaction_id(PG_FUNCTION_ARGS)
{ {
TupleDesc tupleDescriptor = NULL; TupleDesc tupleDescriptor = NULL;
HeapTuple heapTuple = NULL;
Datum values[5]; Datum values[5];
bool isNulls[5]; bool isNulls[5];
DistributedTransactionId *distributedTransctionId = NULL;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
@ -176,7 +174,8 @@ get_current_transaction_id(PG_FUNCTION_ARGS)
ereport(ERROR, (errmsg("backend is not ready for distributed transactions"))); ereport(ERROR, (errmsg("backend is not ready for distributed transactions")));
} }
distributedTransctionId = GetCurrentDistributedTransactionId(); DistributedTransactionId *distributedTransctionId =
GetCurrentDistributedTransactionId();
memset(values, 0, sizeof(values)); memset(values, 0, sizeof(values));
memset(isNulls, false, sizeof(isNulls)); memset(isNulls, false, sizeof(isNulls));
@ -198,7 +197,7 @@ get_current_transaction_id(PG_FUNCTION_ARGS)
isNulls[4] = true; isNulls[4] = true;
} }
heapTuple = heap_form_tuple(tupleDescriptor, values, isNulls); HeapTuple heapTuple = heap_form_tuple(tupleDescriptor, values, isNulls);
PG_RETURN_DATUM(HeapTupleGetDatum(heapTuple)); PG_RETURN_DATUM(HeapTupleGetDatum(heapTuple));
} }
@ -215,7 +214,6 @@ Datum
get_global_active_transactions(PG_FUNCTION_ARGS) get_global_active_transactions(PG_FUNCTION_ARGS)
{ {
TupleDesc tupleDescriptor = NULL; TupleDesc tupleDescriptor = NULL;
Tuplestorestate *tupleStore = NULL;
List *workerNodeList = ActivePrimaryWorkerNodeList(NoLock); List *workerNodeList = ActivePrimaryWorkerNodeList(NoLock);
ListCell *workerNodeCell = NULL; ListCell *workerNodeCell = NULL;
List *connectionList = NIL; List *connectionList = NIL;
@ -223,7 +221,7 @@ get_global_active_transactions(PG_FUNCTION_ARGS)
StringInfo queryToSend = makeStringInfo(); StringInfo queryToSend = makeStringInfo();
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
tupleStore = SetupTuplestore(fcinfo, &tupleDescriptor); Tuplestorestate *tupleStore = SetupTuplestore(fcinfo, &tupleDescriptor);
appendStringInfo(queryToSend, GET_ACTIVE_TRANSACTION_QUERY); appendStringInfo(queryToSend, GET_ACTIVE_TRANSACTION_QUERY);
@ -236,7 +234,6 @@ get_global_active_transactions(PG_FUNCTION_ARGS)
WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell); WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell);
char *nodeName = workerNode->workerName; char *nodeName = workerNode->workerName;
int nodePort = workerNode->workerPort; int nodePort = workerNode->workerPort;
MultiConnection *connection = NULL;
int connectionFlags = 0; int connectionFlags = 0;
if (workerNode->groupId == GetLocalGroupId()) if (workerNode->groupId == GetLocalGroupId())
@ -245,7 +242,8 @@ get_global_active_transactions(PG_FUNCTION_ARGS)
continue; continue;
} }
connection = StartNodeConnection(connectionFlags, nodeName, nodePort); MultiConnection *connection = StartNodeConnection(connectionFlags, nodeName,
nodePort);
connectionList = lappend(connectionList, connection); connectionList = lappend(connectionList, connection);
} }
@ -256,9 +254,8 @@ get_global_active_transactions(PG_FUNCTION_ARGS)
foreach(connectionCell, connectionList) foreach(connectionCell, connectionList)
{ {
MultiConnection *connection = (MultiConnection *) lfirst(connectionCell); MultiConnection *connection = (MultiConnection *) lfirst(connectionCell);
int querySent = false;
querySent = SendRemoteCommand(connection, queryToSend->data); int querySent = SendRemoteCommand(connection, queryToSend->data);
if (querySent == 0) if (querySent == 0)
{ {
ReportConnectionError(connection, WARNING); ReportConnectionError(connection, WARNING);
@ -269,28 +266,24 @@ get_global_active_transactions(PG_FUNCTION_ARGS)
foreach(connectionCell, connectionList) foreach(connectionCell, connectionList)
{ {
MultiConnection *connection = (MultiConnection *) lfirst(connectionCell); MultiConnection *connection = (MultiConnection *) lfirst(connectionCell);
PGresult *result = NULL;
bool raiseInterrupts = true; bool raiseInterrupts = true;
Datum values[ACTIVE_TRANSACTION_COLUMN_COUNT]; Datum values[ACTIVE_TRANSACTION_COLUMN_COUNT];
bool isNulls[ACTIVE_TRANSACTION_COLUMN_COUNT]; bool isNulls[ACTIVE_TRANSACTION_COLUMN_COUNT];
int64 rowIndex = 0;
int64 rowCount = 0;
int64 colCount = 0;
if (PQstatus(connection->pgConn) != CONNECTION_OK) if (PQstatus(connection->pgConn) != CONNECTION_OK)
{ {
continue; continue;
} }
result = GetRemoteCommandResult(connection, raiseInterrupts); PGresult *result = GetRemoteCommandResult(connection, raiseInterrupts);
if (!IsResponseOK(result)) if (!IsResponseOK(result))
{ {
ReportResultError(connection, result, WARNING); ReportResultError(connection, result, WARNING);
continue; continue;
} }
rowCount = PQntuples(result); int64 rowCount = PQntuples(result);
colCount = PQnfields(result); int64 colCount = PQnfields(result);
/* Although it is not expected */ /* Although it is not expected */
if (colCount != ACTIVE_TRANSACTION_COLUMN_COUNT) if (colCount != ACTIVE_TRANSACTION_COLUMN_COUNT)
@ -300,7 +293,7 @@ get_global_active_transactions(PG_FUNCTION_ARGS)
continue; continue;
} }
for (rowIndex = 0; rowIndex < rowCount; rowIndex++) for (int64 rowIndex = 0; rowIndex < rowCount; rowIndex++)
{ {
memset(values, 0, sizeof(values)); memset(values, 0, sizeof(values));
memset(isNulls, false, sizeof(isNulls)); memset(isNulls, false, sizeof(isNulls));
@ -334,10 +327,9 @@ Datum
get_all_active_transactions(PG_FUNCTION_ARGS) get_all_active_transactions(PG_FUNCTION_ARGS)
{ {
TupleDesc tupleDescriptor = NULL; TupleDesc tupleDescriptor = NULL;
Tuplestorestate *tupleStore = NULL;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
tupleStore = SetupTuplestore(fcinfo, &tupleDescriptor); Tuplestorestate *tupleStore = SetupTuplestore(fcinfo, &tupleDescriptor);
StoreAllActiveTransactions(tupleStore, tupleDescriptor); StoreAllActiveTransactions(tupleStore, tupleDescriptor);
@ -355,7 +347,6 @@ get_all_active_transactions(PG_FUNCTION_ARGS)
static void static void
StoreAllActiveTransactions(Tuplestorestate *tupleStore, TupleDesc tupleDescriptor) StoreAllActiveTransactions(Tuplestorestate *tupleStore, TupleDesc tupleDescriptor)
{ {
int backendIndex = 0;
Datum values[ACTIVE_TRANSACTION_COLUMN_COUNT]; Datum values[ACTIVE_TRANSACTION_COLUMN_COUNT];
bool isNulls[ACTIVE_TRANSACTION_COLUMN_COUNT]; bool isNulls[ACTIVE_TRANSACTION_COLUMN_COUNT];
bool showAllTransactions = superuser(); bool showAllTransactions = superuser();
@ -377,18 +368,14 @@ StoreAllActiveTransactions(Tuplestorestate *tupleStore, TupleDesc tupleDescripto
/* we're reading all distributed transactions, prevent new backends */ /* we're reading all distributed transactions, prevent new backends */
LockBackendSharedMemory(LW_SHARED); LockBackendSharedMemory(LW_SHARED);
for (backendIndex = 0; backendIndex < MaxBackends; ++backendIndex) for (int backendIndex = 0; backendIndex < MaxBackends; ++backendIndex)
{ {
BackendData *currentBackend = BackendData *currentBackend =
&backendManagementShmemData->backends[backendIndex]; &backendManagementShmemData->backends[backendIndex];
bool coordinatorOriginatedQuery = false;
/* to work on data after releasing g spinlock to protect against errors */ /* to work on data after releasing g spinlock to protect against errors */
Oid databaseId = InvalidOid;
int backendPid = -1;
int initiatorNodeIdentifier = -1; int initiatorNodeIdentifier = -1;
uint64 transactionNumber = 0; uint64 transactionNumber = 0;
TimestampTz transactionIdTimestamp = 0;
SpinLockAcquire(&currentBackend->mutex); SpinLockAcquire(&currentBackend->mutex);
@ -409,8 +396,8 @@ StoreAllActiveTransactions(Tuplestorestate *tupleStore, TupleDesc tupleDescripto
continue; continue;
} }
databaseId = currentBackend->databaseId; Oid databaseId = currentBackend->databaseId;
backendPid = ProcGlobal->allProcs[backendIndex].pid; int backendPid = ProcGlobal->allProcs[backendIndex].pid;
initiatorNodeIdentifier = currentBackend->citusBackend.initiatorNodeIdentifier; initiatorNodeIdentifier = currentBackend->citusBackend.initiatorNodeIdentifier;
/* /*
@ -421,10 +408,11 @@ StoreAllActiveTransactions(Tuplestorestate *tupleStore, TupleDesc tupleDescripto
* field with the same name. The reason is that it also covers backends that are not * field with the same name. The reason is that it also covers backends that are not
* inside a distributed transaction. * inside a distributed transaction.
*/ */
coordinatorOriginatedQuery = currentBackend->citusBackend.transactionOriginator; bool coordinatorOriginatedQuery =
currentBackend->citusBackend.transactionOriginator;
transactionNumber = currentBackend->transactionId.transactionNumber; transactionNumber = currentBackend->transactionId.transactionNumber;
transactionIdTimestamp = currentBackend->transactionId.timestamp; TimestampTz transactionIdTimestamp = currentBackend->transactionId.timestamp;
SpinLockRelease(&currentBackend->mutex); SpinLockRelease(&currentBackend->mutex);
@ -489,8 +477,6 @@ BackendManagementShmemInit(void)
if (!alreadyInitialized) if (!alreadyInitialized)
{ {
int backendIndex = 0;
int totalProcs = 0;
char *trancheName = "Backend Management Tranche"; char *trancheName = "Backend Management Tranche";
NamedLWLockTranche *namedLockTranche = NamedLWLockTranche *namedLockTranche =
@ -518,8 +504,8 @@ BackendManagementShmemInit(void)
* We also initiate initiatorNodeIdentifier to -1, which can never be * We also initiate initiatorNodeIdentifier to -1, which can never be
* used as a node id. * used as a node id.
*/ */
totalProcs = TotalProcCount(); int totalProcs = TotalProcCount();
for (backendIndex = 0; backendIndex < totalProcs; ++backendIndex) for (int backendIndex = 0; backendIndex < totalProcs; ++backendIndex)
{ {
BackendData *backendData = BackendData *backendData =
&backendManagementShmemData->backends[backendIndex]; &backendManagementShmemData->backends[backendIndex];
@ -809,7 +795,6 @@ CurrentDistributedTransactionNumber(void)
void void
GetBackendDataForProc(PGPROC *proc, BackendData *result) GetBackendDataForProc(PGPROC *proc, BackendData *result)
{ {
BackendData *backendData = NULL;
int pgprocno = proc->pgprocno; int pgprocno = proc->pgprocno;
if (proc->lockGroupLeader != NULL) if (proc->lockGroupLeader != NULL)
@ -817,7 +802,7 @@ GetBackendDataForProc(PGPROC *proc, BackendData *result)
pgprocno = proc->lockGroupLeader->pgprocno; pgprocno = proc->lockGroupLeader->pgprocno;
} }
backendData = &backendManagementShmemData->backends[pgprocno]; BackendData *backendData = &backendManagementShmemData->backends[pgprocno];
SpinLockAcquire(&backendData->mutex); SpinLockAcquire(&backendData->mutex);
@ -903,14 +888,12 @@ List *
ActiveDistributedTransactionNumbers(void) ActiveDistributedTransactionNumbers(void)
{ {
List *activeTransactionNumberList = NIL; List *activeTransactionNumberList = NIL;
int curBackend = 0;
/* build list of starting procs */ /* build list of starting procs */
for (curBackend = 0; curBackend < MaxBackends; curBackend++) for (int curBackend = 0; curBackend < MaxBackends; curBackend++)
{ {
PGPROC *currentProc = &ProcGlobal->allProcs[curBackend]; PGPROC *currentProc = &ProcGlobal->allProcs[curBackend];
BackendData currentBackendData; BackendData currentBackendData;
uint64 *transactionNumber = NULL;
if (currentProc->pid == 0) if (currentProc->pid == 0)
{ {
@ -932,7 +915,7 @@ ActiveDistributedTransactionNumbers(void)
continue; continue;
} }
transactionNumber = (uint64 *) palloc0(sizeof(uint64)); uint64 *transactionNumber = (uint64 *) palloc0(sizeof(uint64));
*transactionNumber = currentBackendData.transactionId.transactionNumber; *transactionNumber = currentBackendData.transactionId.transactionNumber;
activeTransactionNumberList = lappend(activeTransactionNumberList, activeTransactionNumberList = lappend(activeTransactionNumberList,

View File

@ -269,11 +269,9 @@ PG_FUNCTION_INFO_V1(citus_worker_stat_activity);
Datum Datum
citus_dist_stat_activity(PG_FUNCTION_ARGS) citus_dist_stat_activity(PG_FUNCTION_ARGS)
{ {
List *citusDistStatStatements = NIL;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
citusDistStatStatements = CitusStatActivity(CITUS_DIST_STAT_ACTIVITY_QUERY); List *citusDistStatStatements = CitusStatActivity(CITUS_DIST_STAT_ACTIVITY_QUERY);
ReturnCitusDistStats(citusDistStatStatements, fcinfo); ReturnCitusDistStats(citusDistStatStatements, fcinfo);
@ -289,11 +287,9 @@ citus_dist_stat_activity(PG_FUNCTION_ARGS)
Datum Datum
citus_worker_stat_activity(PG_FUNCTION_ARGS) citus_worker_stat_activity(PG_FUNCTION_ARGS)
{ {
List *citusWorkerStatStatements = NIL;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
citusWorkerStatStatements = CitusStatActivity(CITUS_WORKER_STAT_ACTIVITY_QUERY); List *citusWorkerStatStatements = CitusStatActivity(CITUS_WORKER_STAT_ACTIVITY_QUERY);
ReturnCitusDistStats(citusWorkerStatStatements, fcinfo); ReturnCitusDistStats(citusWorkerStatStatements, fcinfo);
@ -315,11 +311,8 @@ citus_worker_stat_activity(PG_FUNCTION_ARGS)
static List * static List *
CitusStatActivity(const char *statQuery) CitusStatActivity(const char *statQuery)
{ {
List *citusStatsList = NIL;
List *workerNodeList = ActivePrimaryWorkerNodeList(NoLock); List *workerNodeList = ActivePrimaryWorkerNodeList(NoLock);
ListCell *workerNodeCell = NULL; ListCell *workerNodeCell = NULL;
char *nodeUser = NULL;
List *connectionList = NIL; List *connectionList = NIL;
ListCell *connectionCell = NULL; ListCell *connectionCell = NULL;
@ -329,14 +322,14 @@ CitusStatActivity(const char *statQuery)
* the authentication for self-connection via any user who calls the citus * the authentication for self-connection via any user who calls the citus
* stat activity functions. * stat activity functions.
*/ */
citusStatsList = GetLocalNodeCitusDistStat(statQuery); List *citusStatsList = GetLocalNodeCitusDistStat(statQuery);
/* /*
* We prefer to connect with the current user to the remote nodes. This will * We prefer to connect with the current user to the remote nodes. This will
* ensure that we have the same privilage restrictions that pg_stat_activity * ensure that we have the same privilage restrictions that pg_stat_activity
* enforces. * enforces.
*/ */
nodeUser = CurrentUserName(); char *nodeUser = CurrentUserName();
/* open connections in parallel */ /* open connections in parallel */
foreach(workerNodeCell, workerNodeList) foreach(workerNodeCell, workerNodeList)
@ -344,7 +337,6 @@ CitusStatActivity(const char *statQuery)
WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell); WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell);
char *nodeName = workerNode->workerName; char *nodeName = workerNode->workerName;
int nodePort = workerNode->workerPort; int nodePort = workerNode->workerPort;
MultiConnection *connection = NULL;
int connectionFlags = 0; int connectionFlags = 0;
if (workerNode->groupId == GetLocalGroupId()) if (workerNode->groupId == GetLocalGroupId())
@ -353,7 +345,8 @@ CitusStatActivity(const char *statQuery)
continue; continue;
} }
connection = StartNodeUserDatabaseConnection(connectionFlags, nodeName, nodePort, MultiConnection *connection = StartNodeUserDatabaseConnection(connectionFlags,
nodeName, nodePort,
nodeUser, NULL); nodeUser, NULL);
connectionList = lappend(connectionList, connection); connectionList = lappend(connectionList, connection);
@ -365,9 +358,8 @@ CitusStatActivity(const char *statQuery)
foreach(connectionCell, connectionList) foreach(connectionCell, connectionList)
{ {
MultiConnection *connection = (MultiConnection *) lfirst(connectionCell); MultiConnection *connection = (MultiConnection *) lfirst(connectionCell);
int querySent = false;
querySent = SendRemoteCommand(connection, statQuery); int querySent = SendRemoteCommand(connection, statQuery);
if (querySent == 0) if (querySent == 0)
{ {
ReportConnectionError(connection, WARNING); ReportConnectionError(connection, WARNING);
@ -378,21 +370,17 @@ CitusStatActivity(const char *statQuery)
foreach(connectionCell, connectionList) foreach(connectionCell, connectionList)
{ {
MultiConnection *connection = (MultiConnection *) lfirst(connectionCell); MultiConnection *connection = (MultiConnection *) lfirst(connectionCell);
PGresult *result = NULL;
bool raiseInterrupts = true; bool raiseInterrupts = true;
int64 rowIndex = 0;
int64 rowCount = 0;
int64 colCount = 0;
result = GetRemoteCommandResult(connection, raiseInterrupts); PGresult *result = GetRemoteCommandResult(connection, raiseInterrupts);
if (!IsResponseOK(result)) if (!IsResponseOK(result))
{ {
ReportResultError(connection, result, WARNING); ReportResultError(connection, result, WARNING);
continue; continue;
} }
rowCount = PQntuples(result); int64 rowCount = PQntuples(result);
colCount = PQnfields(result); int64 colCount = PQnfields(result);
if (colCount != CITUS_DIST_STAT_ACTIVITY_QUERY_COLS) if (colCount != CITUS_DIST_STAT_ACTIVITY_QUERY_COLS)
{ {
@ -405,7 +393,7 @@ CitusStatActivity(const char *statQuery)
continue; continue;
} }
for (rowIndex = 0; rowIndex < rowCount; rowIndex++) for (int64 rowIndex = 0; rowIndex < rowCount; rowIndex++)
{ {
CitusDistStat *citusDistStat = ParseCitusDistStat(result, rowIndex); CitusDistStat *citusDistStat = ParseCitusDistStat(result, rowIndex);
@ -436,9 +424,7 @@ GetLocalNodeCitusDistStat(const char *statQuery)
{ {
List *citusStatsList = NIL; List *citusStatsList = NIL;
List *workerNodeList = NIL;
ListCell *workerNodeCell = NULL; ListCell *workerNodeCell = NULL;
int localGroupId = -1;
if (IsCoordinator()) if (IsCoordinator())
{ {
@ -452,10 +438,10 @@ GetLocalNodeCitusDistStat(const char *statQuery)
return citusStatsList; return citusStatsList;
} }
localGroupId = GetLocalGroupId(); int localGroupId = GetLocalGroupId();
/* get the current worker's node stats */ /* get the current worker's node stats */
workerNodeList = ActivePrimaryWorkerNodeList(NoLock); List *workerNodeList = ActivePrimaryWorkerNodeList(NoLock);
foreach(workerNodeCell, workerNodeList) foreach(workerNodeCell, workerNodeList)
{ {
WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell); WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell);
@ -488,10 +474,9 @@ static CitusDistStat *
ParseCitusDistStat(PGresult *result, int64 rowIndex) ParseCitusDistStat(PGresult *result, int64 rowIndex)
{ {
CitusDistStat *citusDistStat = (CitusDistStat *) palloc0(sizeof(CitusDistStat)); CitusDistStat *citusDistStat = (CitusDistStat *) palloc0(sizeof(CitusDistStat));
int initiator_node_identifier = 0;
initiator_node_identifier = int initiator_node_identifier =
PQgetisnull(result, rowIndex, 0) ? -1 : ParseIntField(result, rowIndex, 0); PQgetisnull(result, rowIndex, 0) ? -1 : ParseIntField(result, rowIndex, 0);
ReplaceInitiatorNodeIdentifier(initiator_node_identifier, citusDistStat); ReplaceInitiatorNodeIdentifier(initiator_node_identifier, citusDistStat);
@ -591,14 +576,11 @@ static List *
LocalNodeCitusDistStat(const char *statQuery, const char *hostname, int port) LocalNodeCitusDistStat(const char *statQuery, const char *hostname, int port)
{ {
List *localNodeCitusDistStatList = NIL; List *localNodeCitusDistStatList = NIL;
int spiConnectionResult = 0;
int spiQueryResult = 0;
bool readOnly = true; bool readOnly = true;
uint32 rowIndex = 0;
MemoryContext upperContext = CurrentMemoryContext, oldContext = NULL; MemoryContext upperContext = CurrentMemoryContext, oldContext = NULL;
spiConnectionResult = SPI_connect(); int spiConnectionResult = SPI_connect();
if (spiConnectionResult != SPI_OK_CONNECT) if (spiConnectionResult != SPI_OK_CONNECT)
{ {
ereport(WARNING, (errmsg("could not connect to SPI manager to get " ereport(WARNING, (errmsg("could not connect to SPI manager to get "
@ -609,7 +591,7 @@ LocalNodeCitusDistStat(const char *statQuery, const char *hostname, int port)
return NIL; return NIL;
} }
spiQueryResult = SPI_execute(statQuery, readOnly, 0); int spiQueryResult = SPI_execute(statQuery, readOnly, 0);
if (spiQueryResult != SPI_OK_SELECT) if (spiQueryResult != SPI_OK_SELECT)
{ {
ereport(WARNING, (errmsg("execution was not successful while trying to get " ereport(WARNING, (errmsg("execution was not successful while trying to get "
@ -629,15 +611,13 @@ LocalNodeCitusDistStat(const char *statQuery, const char *hostname, int port)
*/ */
oldContext = MemoryContextSwitchTo(upperContext); oldContext = MemoryContextSwitchTo(upperContext);
for (rowIndex = 0; rowIndex < SPI_processed; rowIndex++) for (uint32 rowIndex = 0; rowIndex < SPI_processed; rowIndex++)
{ {
HeapTuple row = NULL;
TupleDesc rowDescriptor = SPI_tuptable->tupdesc; TupleDesc rowDescriptor = SPI_tuptable->tupdesc;
CitusDistStat *citusDistStat = NULL;
/* we use pointers from the tuple, so copy it before processing */ /* we use pointers from the tuple, so copy it before processing */
row = SPI_copytuple(SPI_tuptable->vals[rowIndex]); HeapTuple row = SPI_copytuple(SPI_tuptable->vals[rowIndex]);
citusDistStat = HeapTupleToCitusDistStat(row, rowDescriptor); CitusDistStat *citusDistStat = HeapTupleToCitusDistStat(row, rowDescriptor);
/* /*
* Add the query_host_name and query_host_port which denote where * Add the query_host_name and query_host_port which denote where
@ -670,9 +650,8 @@ static CitusDistStat *
HeapTupleToCitusDistStat(HeapTuple result, TupleDesc rowDescriptor) HeapTupleToCitusDistStat(HeapTuple result, TupleDesc rowDescriptor)
{ {
CitusDistStat *citusDistStat = (CitusDistStat *) palloc0(sizeof(CitusDistStat)); CitusDistStat *citusDistStat = (CitusDistStat *) palloc0(sizeof(CitusDistStat));
int initiator_node_identifier = 0;
initiator_node_identifier = ParseIntFieldFromHeapTuple(result, rowDescriptor, 1); int initiator_node_identifier = ParseIntFieldFromHeapTuple(result, rowDescriptor, 1);
ReplaceInitiatorNodeIdentifier(initiator_node_identifier, citusDistStat); ReplaceInitiatorNodeIdentifier(initiator_node_identifier, citusDistStat);
@ -721,10 +700,9 @@ HeapTupleToCitusDistStat(HeapTuple result, TupleDesc rowDescriptor)
static int64 static int64
ParseIntFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex) ParseIntFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex)
{ {
Datum resultDatum;
bool isNull = false; bool isNull = false;
resultDatum = SPI_getbinval(tuple, tupdesc, colIndex, &isNull); Datum resultDatum = SPI_getbinval(tuple, tupdesc, colIndex, &isNull);
if (isNull) if (isNull)
{ {
return 0; return 0;
@ -741,10 +719,9 @@ ParseIntFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex)
static text * static text *
ParseTextFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex) ParseTextFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex)
{ {
Datum resultDatum;
bool isNull = false; bool isNull = false;
resultDatum = SPI_getbinval(tuple, tupdesc, colIndex, &isNull); Datum resultDatum = SPI_getbinval(tuple, tupdesc, colIndex, &isNull);
if (isNull) if (isNull)
{ {
return NULL; return NULL;
@ -761,10 +738,9 @@ ParseTextFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex)
static Name static Name
ParseNameFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex) ParseNameFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex)
{ {
Datum resultDatum;
bool isNull = false; bool isNull = false;
resultDatum = SPI_getbinval(tuple, tupdesc, colIndex, &isNull); Datum resultDatum = SPI_getbinval(tuple, tupdesc, colIndex, &isNull);
if (isNull) if (isNull)
{ {
return NULL; return NULL;
@ -781,10 +757,9 @@ ParseNameFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex)
static inet * static inet *
ParseInetFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex) ParseInetFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex)
{ {
Datum resultDatum;
bool isNull = false; bool isNull = false;
resultDatum = SPI_getbinval(tuple, tupdesc, colIndex, &isNull); Datum resultDatum = SPI_getbinval(tuple, tupdesc, colIndex, &isNull);
if (isNull) if (isNull)
{ {
return NULL; return NULL;
@ -801,10 +776,9 @@ ParseInetFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex)
static TimestampTz static TimestampTz
ParseTimestampTzFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex) ParseTimestampTzFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex)
{ {
Datum resultDatum;
bool isNull = false; bool isNull = false;
resultDatum = SPI_getbinval(tuple, tupdesc, colIndex, &isNull); Datum resultDatum = SPI_getbinval(tuple, tupdesc, colIndex, &isNull);
if (isNull) if (isNull)
{ {
return DT_NOBEGIN; return DT_NOBEGIN;
@ -821,10 +795,9 @@ ParseTimestampTzFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIn
static TransactionId static TransactionId
ParseXIDFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex) ParseXIDFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex)
{ {
Datum resultDatum;
bool isNull = false; bool isNull = false;
resultDatum = SPI_getbinval(tuple, tupdesc, colIndex, &isNull); Datum resultDatum = SPI_getbinval(tuple, tupdesc, colIndex, &isNull);
if (isNull) if (isNull)
{ {
/* /*
@ -845,18 +818,14 @@ ParseXIDFieldFromHeapTuple(HeapTuple tuple, TupleDesc tupdesc, int colIndex)
static text * static text *
ParseTextField(PGresult *result, int rowIndex, int colIndex) ParseTextField(PGresult *result, int rowIndex, int colIndex)
{ {
char *resultString = NULL;
Datum resultStringDatum = 0;
Datum textDatum = 0;
if (PQgetisnull(result, rowIndex, colIndex)) if (PQgetisnull(result, rowIndex, colIndex))
{ {
return NULL; return NULL;
} }
resultString = PQgetvalue(result, rowIndex, colIndex); char *resultString = PQgetvalue(result, rowIndex, colIndex);
resultStringDatum = CStringGetDatum(resultString); Datum resultStringDatum = CStringGetDatum(resultString);
textDatum = DirectFunctionCall1(textin, resultStringDatum); Datum textDatum = DirectFunctionCall1(textin, resultStringDatum);
return (text *) DatumGetPointer(textDatum); return (text *) DatumGetPointer(textDatum);
} }
@ -869,8 +838,6 @@ ParseTextField(PGresult *result, int rowIndex, int colIndex)
static Name static Name
ParseNameField(PGresult *result, int rowIndex, int colIndex) ParseNameField(PGresult *result, int rowIndex, int colIndex)
{ {
char *resultString = NULL;
Datum resultStringDatum = 0;
Datum nameDatum = 0; Datum nameDatum = 0;
if (PQgetisnull(result, rowIndex, colIndex)) if (PQgetisnull(result, rowIndex, colIndex))
@ -878,8 +845,8 @@ ParseNameField(PGresult *result, int rowIndex, int colIndex)
return (Name) nameDatum; return (Name) nameDatum;
} }
resultString = PQgetvalue(result, rowIndex, colIndex); char *resultString = PQgetvalue(result, rowIndex, colIndex);
resultStringDatum = CStringGetDatum(resultString); Datum resultStringDatum = CStringGetDatum(resultString);
nameDatum = DirectFunctionCall1(namein, resultStringDatum); nameDatum = DirectFunctionCall1(namein, resultStringDatum);
return (Name) DatumGetPointer(nameDatum); return (Name) DatumGetPointer(nameDatum);
@ -893,18 +860,14 @@ ParseNameField(PGresult *result, int rowIndex, int colIndex)
static inet * static inet *
ParseInetField(PGresult *result, int rowIndex, int colIndex) ParseInetField(PGresult *result, int rowIndex, int colIndex)
{ {
char *resultString = NULL;
Datum resultStringDatum = 0;
Datum inetDatum = 0;
if (PQgetisnull(result, rowIndex, colIndex)) if (PQgetisnull(result, rowIndex, colIndex))
{ {
return NULL; return NULL;
} }
resultString = PQgetvalue(result, rowIndex, colIndex); char *resultString = PQgetvalue(result, rowIndex, colIndex);
resultStringDatum = CStringGetDatum(resultString); Datum resultStringDatum = CStringGetDatum(resultString);
inetDatum = DirectFunctionCall1(inet_in, resultStringDatum); Datum inetDatum = DirectFunctionCall1(inet_in, resultStringDatum);
return DatumGetInetP(inetDatum); return DatumGetInetP(inetDatum);
} }
@ -917,10 +880,6 @@ ParseInetField(PGresult *result, int rowIndex, int colIndex)
static TransactionId static TransactionId
ParseXIDField(PGresult *result, int rowIndex, int colIndex) ParseXIDField(PGresult *result, int rowIndex, int colIndex)
{ {
char *resultString = NULL;
Datum resultStringDatum = 0;
Datum XIDDatum = 0;
if (PQgetisnull(result, rowIndex, colIndex)) if (PQgetisnull(result, rowIndex, colIndex))
{ {
/* /*
@ -930,9 +889,9 @@ ParseXIDField(PGresult *result, int rowIndex, int colIndex)
return PG_UINT32_MAX; return PG_UINT32_MAX;
} }
resultString = PQgetvalue(result, rowIndex, colIndex); char *resultString = PQgetvalue(result, rowIndex, colIndex);
resultStringDatum = CStringGetDatum(resultString); Datum resultStringDatum = CStringGetDatum(resultString);
XIDDatum = DirectFunctionCall1(xidin, resultStringDatum); Datum XIDDatum = DirectFunctionCall1(xidin, resultStringDatum);
return DatumGetTransactionId(XIDDatum); return DatumGetTransactionId(XIDDatum);
} }

View File

@ -103,11 +103,8 @@ check_distributed_deadlocks(PG_FUNCTION_ARGS)
bool bool
CheckForDistributedDeadlocks(void) CheckForDistributedDeadlocks(void)
{ {
WaitGraph *waitGraph = NULL;
HTAB *adjacencyLists = NULL;
HASH_SEQ_STATUS status; HASH_SEQ_STATUS status;
TransactionNode *transactionNode = NULL; TransactionNode *transactionNode = NULL;
int edgeCount = 0;
int localGroupId = GetLocalGroupId(); int localGroupId = GetLocalGroupId();
List *workerNodeList = ActiveReadableNodeList(); List *workerNodeList = ActiveReadableNodeList();
@ -122,10 +119,10 @@ CheckForDistributedDeadlocks(void)
return false; return false;
} }
waitGraph = BuildGlobalWaitGraph(); WaitGraph *waitGraph = BuildGlobalWaitGraph();
adjacencyLists = BuildAdjacencyListsForWaitGraph(waitGraph); HTAB *adjacencyLists = BuildAdjacencyListsForWaitGraph(waitGraph);
edgeCount = waitGraph->edgeCount; int edgeCount = waitGraph->edgeCount;
/* /*
* We iterate on transaction nodes and search for deadlocks where the * We iterate on transaction nodes and search for deadlocks where the
@ -134,7 +131,6 @@ CheckForDistributedDeadlocks(void)
hash_seq_init(&status, adjacencyLists); hash_seq_init(&status, adjacencyLists);
while ((transactionNode = (TransactionNode *) hash_seq_search(&status)) != 0) while ((transactionNode = (TransactionNode *) hash_seq_search(&status)) != 0)
{ {
bool deadlockFound = false;
List *deadlockPath = NIL; List *deadlockPath = NIL;
/* /*
@ -151,7 +147,7 @@ CheckForDistributedDeadlocks(void)
ResetVisitedFields(adjacencyLists); ResetVisitedFields(adjacencyLists);
deadlockFound = CheckDeadlockForTransactionNode(transactionNode, bool deadlockFound = CheckDeadlockForTransactionNode(transactionNode,
maxStackDepth, maxStackDepth,
&deadlockPath); &deadlockPath);
if (deadlockFound) if (deadlockFound)
@ -184,8 +180,6 @@ CheckForDistributedDeadlocks(void)
(TransactionNode *) lfirst(participantTransactionCell); (TransactionNode *) lfirst(participantTransactionCell);
bool transactionAssociatedWithProc = bool transactionAssociatedWithProc =
AssociateDistributedTransactionWithBackendProc(currentNode); AssociateDistributedTransactionWithBackendProc(currentNode);
TimestampTz youngestTimestamp = 0;
TimestampTz currentTimestamp = 0;
LogTransactionNode(currentNode); LogTransactionNode(currentNode);
@ -201,8 +195,9 @@ CheckForDistributedDeadlocks(void)
continue; continue;
} }
youngestTimestamp = youngestAliveTransaction->transactionId.timestamp; TimestampTz youngestTimestamp =
currentTimestamp = currentNode->transactionId.timestamp; youngestAliveTransaction->transactionId.timestamp;
TimestampTz currentTimestamp = currentNode->transactionId.timestamp;
if (timestamptz_cmp_internal(currentTimestamp, youngestTimestamp) == 1) if (timestamptz_cmp_internal(currentTimestamp, youngestTimestamp) == 1)
{ {
youngestAliveTransaction = currentNode; youngestAliveTransaction = currentNode;
@ -258,7 +253,6 @@ CheckDeadlockForTransactionNode(TransactionNode *startingTransactionNode,
/* traverse the graph and search for the deadlocks */ /* traverse the graph and search for the deadlocks */
while (toBeVisitedNodes != NIL) while (toBeVisitedNodes != NIL)
{ {
int currentStackDepth;
QueuedTransactionNode *queuedTransactionNode = QueuedTransactionNode *queuedTransactionNode =
(QueuedTransactionNode *) linitial(toBeVisitedNodes); (QueuedTransactionNode *) linitial(toBeVisitedNodes);
TransactionNode *currentTransactionNode = queuedTransactionNode->transactionNode; TransactionNode *currentTransactionNode = queuedTransactionNode->transactionNode;
@ -284,7 +278,7 @@ CheckDeadlockForTransactionNode(TransactionNode *startingTransactionNode,
currentTransactionNode->transactionVisited = true; currentTransactionNode->transactionVisited = true;
/* set the stack's corresponding element with the current node */ /* set the stack's corresponding element with the current node */
currentStackDepth = queuedTransactionNode->currentStackDepth; int currentStackDepth = queuedTransactionNode->currentStackDepth;
Assert(currentStackDepth < maxStackDepth); Assert(currentStackDepth < maxStackDepth);
transactionNodeStack[currentStackDepth] = currentTransactionNode; transactionNodeStack[currentStackDepth] = currentTransactionNode;
@ -335,11 +329,10 @@ BuildDeadlockPathList(QueuedTransactionNode *cycledTransactionNode,
List **deadlockPath) List **deadlockPath)
{ {
int deadlockStackDepth = cycledTransactionNode->currentStackDepth; int deadlockStackDepth = cycledTransactionNode->currentStackDepth;
int stackIndex = 0;
*deadlockPath = NIL; *deadlockPath = NIL;
for (stackIndex = 0; stackIndex < deadlockStackDepth; stackIndex++) for (int stackIndex = 0; stackIndex < deadlockStackDepth; stackIndex++)
{ {
*deadlockPath = lappend(*deadlockPath, transactionNodeStack[stackIndex]); *deadlockPath = lappend(*deadlockPath, transactionNodeStack[stackIndex]);
} }
@ -380,13 +373,10 @@ ResetVisitedFields(HTAB *adjacencyList)
static bool static bool
AssociateDistributedTransactionWithBackendProc(TransactionNode *transactionNode) AssociateDistributedTransactionWithBackendProc(TransactionNode *transactionNode)
{ {
int backendIndex = 0; for (int backendIndex = 0; backendIndex < MaxBackends; ++backendIndex)
for (backendIndex = 0; backendIndex < MaxBackends; ++backendIndex)
{ {
PGPROC *currentProc = &ProcGlobal->allProcs[backendIndex]; PGPROC *currentProc = &ProcGlobal->allProcs[backendIndex];
BackendData currentBackendData; BackendData currentBackendData;
DistributedTransactionId *currentTransactionId = NULL;
/* we're not interested in processes that are not active or waiting on a lock */ /* we're not interested in processes that are not active or waiting on a lock */
if (currentProc->pid <= 0) if (currentProc->pid <= 0)
@ -402,7 +392,8 @@ AssociateDistributedTransactionWithBackendProc(TransactionNode *transactionNode)
continue; continue;
} }
currentTransactionId = &currentBackendData.transactionId; DistributedTransactionId *currentTransactionId =
&currentBackendData.transactionId;
if (currentTransactionId->transactionNumber != if (currentTransactionId->transactionNumber !=
transactionNode->transactionId.transactionNumber) transactionNode->transactionId.transactionNumber)
@ -455,9 +446,6 @@ extern HTAB *
BuildAdjacencyListsForWaitGraph(WaitGraph *waitGraph) BuildAdjacencyListsForWaitGraph(WaitGraph *waitGraph)
{ {
HASHCTL info; HASHCTL info;
uint32 hashFlags = 0;
HTAB *adjacencyList = NULL;
int edgeIndex = 0;
int edgeCount = waitGraph->edgeCount; int edgeCount = waitGraph->edgeCount;
memset(&info, 0, sizeof(info)); memset(&info, 0, sizeof(info));
@ -466,15 +454,14 @@ BuildAdjacencyListsForWaitGraph(WaitGraph *waitGraph)
info.hash = DistributedTransactionIdHash; info.hash = DistributedTransactionIdHash;
info.match = DistributedTransactionIdCompare; info.match = DistributedTransactionIdCompare;
info.hcxt = CurrentMemoryContext; info.hcxt = CurrentMemoryContext;
hashFlags = (HASH_ELEM | HASH_FUNCTION | HASH_CONTEXT | HASH_COMPARE); uint32 hashFlags = (HASH_ELEM | HASH_FUNCTION | HASH_CONTEXT | HASH_COMPARE);
adjacencyList = hash_create("distributed deadlock detection", 64, &info, hashFlags); HTAB *adjacencyList = hash_create("distributed deadlock detection", 64, &info,
hashFlags);
for (edgeIndex = 0; edgeIndex < edgeCount; edgeIndex++) for (int edgeIndex = 0; edgeIndex < edgeCount; edgeIndex++)
{ {
WaitEdge *edge = &waitGraph->edges[edgeIndex]; WaitEdge *edge = &waitGraph->edges[edgeIndex];
TransactionNode *waitingTransaction = NULL;
TransactionNode *blockingTransaction = NULL;
bool transactionOriginator = false; bool transactionOriginator = false;
DistributedTransactionId waitingId = { DistributedTransactionId waitingId = {
@ -491,9 +478,9 @@ BuildAdjacencyListsForWaitGraph(WaitGraph *waitGraph)
edge->blockingTransactionStamp edge->blockingTransactionStamp
}; };
waitingTransaction = TransactionNode *waitingTransaction =
GetOrCreateTransactionNode(adjacencyList, &waitingId); GetOrCreateTransactionNode(adjacencyList, &waitingId);
blockingTransaction = TransactionNode *blockingTransaction =
GetOrCreateTransactionNode(adjacencyList, &blockingId); GetOrCreateTransactionNode(adjacencyList, &blockingId);
waitingTransaction->waitsFor = lappend(waitingTransaction->waitsFor, waitingTransaction->waitsFor = lappend(waitingTransaction->waitsFor,
@ -512,11 +499,12 @@ BuildAdjacencyListsForWaitGraph(WaitGraph *waitGraph)
static TransactionNode * static TransactionNode *
GetOrCreateTransactionNode(HTAB *adjacencyList, DistributedTransactionId *transactionId) GetOrCreateTransactionNode(HTAB *adjacencyList, DistributedTransactionId *transactionId)
{ {
TransactionNode *transactionNode = NULL;
bool found = false; bool found = false;
transactionNode = (TransactionNode *) hash_search(adjacencyList, transactionId, TransactionNode *transactionNode = (TransactionNode *) hash_search(adjacencyList,
HASH_ENTER, &found); transactionId,
HASH_ENTER,
&found);
if (!found) if (!found)
{ {
transactionNode->waitsFor = NIL; transactionNode->waitsFor = NIL;
@ -535,9 +523,8 @@ static uint32
DistributedTransactionIdHash(const void *key, Size keysize) DistributedTransactionIdHash(const void *key, Size keysize)
{ {
DistributedTransactionId *entry = (DistributedTransactionId *) key; DistributedTransactionId *entry = (DistributedTransactionId *) key;
uint32 hash = 0;
hash = hash_uint32(entry->initiatorNodeIdentifier); uint32 hash = hash_uint32(entry->initiatorNodeIdentifier);
hash = hash_combine(hash, hash_any((unsigned char *) &entry->transactionNumber, hash = hash_combine(hash, hash_any((unsigned char *) &entry->transactionNumber,
sizeof(int64))); sizeof(int64)));
hash = hash_combine(hash, hash_any((unsigned char *) &entry->timestamp, hash = hash_combine(hash, hash_any((unsigned char *) &entry->timestamp,
@ -601,14 +588,12 @@ DistributedTransactionIdCompare(const void *a, const void *b, Size keysize)
static void static void
LogCancellingBackend(TransactionNode *transactionNode) LogCancellingBackend(TransactionNode *transactionNode)
{ {
StringInfo logMessage = NULL;
if (!LogDistributedDeadlockDetection) if (!LogDistributedDeadlockDetection)
{ {
return; return;
} }
logMessage = makeStringInfo(); StringInfo logMessage = makeStringInfo();
appendStringInfo(logMessage, "Cancelling the following backend " appendStringInfo(logMessage, "Cancelling the following backend "
"to resolve distributed deadlock " "to resolve distributed deadlock "
@ -627,16 +612,13 @@ LogCancellingBackend(TransactionNode *transactionNode)
static void static void
LogTransactionNode(TransactionNode *transactionNode) LogTransactionNode(TransactionNode *transactionNode)
{ {
StringInfo logMessage = NULL;
DistributedTransactionId *transactionId = NULL;
if (!LogDistributedDeadlockDetection) if (!LogDistributedDeadlockDetection)
{ {
return; return;
} }
logMessage = makeStringInfo(); StringInfo logMessage = makeStringInfo();
transactionId = &(transactionNode->transactionId); DistributedTransactionId *transactionId = &(transactionNode->transactionId);
appendStringInfo(logMessage, appendStringInfo(logMessage,
"[DistributedTransactionId: (%d, " UINT64_FORMAT ", %s)] = ", "[DistributedTransactionId: (%d, " UINT64_FORMAT ", %s)] = ",

View File

@ -73,9 +73,7 @@ PG_FUNCTION_INFO_V1(dump_global_wait_edges);
Datum Datum
dump_global_wait_edges(PG_FUNCTION_ARGS) dump_global_wait_edges(PG_FUNCTION_ARGS)
{ {
WaitGraph *waitGraph = NULL; WaitGraph *waitGraph = BuildGlobalWaitGraph();
waitGraph = BuildGlobalWaitGraph();
ReturnWaitGraph(waitGraph, fcinfo); ReturnWaitGraph(waitGraph, fcinfo);
@ -106,7 +104,6 @@ BuildGlobalWaitGraph(void)
WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell); WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell);
char *nodeName = workerNode->workerName; char *nodeName = workerNode->workerName;
int nodePort = workerNode->workerPort; int nodePort = workerNode->workerPort;
MultiConnection *connection = NULL;
int connectionFlags = 0; int connectionFlags = 0;
if (workerNode->groupId == localNodeId) if (workerNode->groupId == localNodeId)
@ -115,7 +112,8 @@ BuildGlobalWaitGraph(void)
continue; continue;
} }
connection = StartNodeUserDatabaseConnection(connectionFlags, nodeName, nodePort, MultiConnection *connection = StartNodeUserDatabaseConnection(connectionFlags,
nodeName, nodePort,
nodeUser, NULL); nodeUser, NULL);
connectionList = lappend(connectionList, connection); connectionList = lappend(connectionList, connection);
@ -127,10 +125,9 @@ BuildGlobalWaitGraph(void)
foreach(connectionCell, connectionList) foreach(connectionCell, connectionList)
{ {
MultiConnection *connection = (MultiConnection *) lfirst(connectionCell); MultiConnection *connection = (MultiConnection *) lfirst(connectionCell);
int querySent = false;
const char *command = "SELECT * FROM dump_local_wait_edges()"; const char *command = "SELECT * FROM dump_local_wait_edges()";
querySent = SendRemoteCommand(connection, command); int querySent = SendRemoteCommand(connection, command);
if (querySent == 0) if (querySent == 0)
{ {
ReportConnectionError(connection, WARNING); ReportConnectionError(connection, WARNING);
@ -141,21 +138,17 @@ BuildGlobalWaitGraph(void)
foreach(connectionCell, connectionList) foreach(connectionCell, connectionList)
{ {
MultiConnection *connection = (MultiConnection *) lfirst(connectionCell); MultiConnection *connection = (MultiConnection *) lfirst(connectionCell);
PGresult *result = NULL;
bool raiseInterrupts = true; bool raiseInterrupts = true;
int64 rowIndex = 0;
int64 rowCount = 0;
int64 colCount = 0;
result = GetRemoteCommandResult(connection, raiseInterrupts); PGresult *result = GetRemoteCommandResult(connection, raiseInterrupts);
if (!IsResponseOK(result)) if (!IsResponseOK(result))
{ {
ReportResultError(connection, result, WARNING); ReportResultError(connection, result, WARNING);
continue; continue;
} }
rowCount = PQntuples(result); int64 rowCount = PQntuples(result);
colCount = PQnfields(result); int64 colCount = PQnfields(result);
if (colCount != 9) if (colCount != 9)
{ {
@ -164,7 +157,7 @@ BuildGlobalWaitGraph(void)
continue; continue;
} }
for (rowIndex = 0; rowIndex < rowCount; rowIndex++) for (int64 rowIndex = 0; rowIndex < rowCount; rowIndex++)
{ {
AddWaitEdgeFromResult(waitGraph, result, rowIndex); AddWaitEdgeFromResult(waitGraph, result, rowIndex);
} }
@ -205,14 +198,12 @@ AddWaitEdgeFromResult(WaitGraph *waitGraph, PGresult *result, int rowIndex)
int64 int64
ParseIntField(PGresult *result, int rowIndex, int colIndex) ParseIntField(PGresult *result, int rowIndex, int colIndex)
{ {
char *resultString = NULL;
if (PQgetisnull(result, rowIndex, colIndex)) if (PQgetisnull(result, rowIndex, colIndex))
{ {
return 0; return 0;
} }
resultString = PQgetvalue(result, rowIndex, colIndex); char *resultString = PQgetvalue(result, rowIndex, colIndex);
return pg_strtouint64(resultString, NULL, 10); return pg_strtouint64(resultString, NULL, 10);
} }
@ -225,14 +216,12 @@ ParseIntField(PGresult *result, int rowIndex, int colIndex)
bool bool
ParseBoolField(PGresult *result, int rowIndex, int colIndex) ParseBoolField(PGresult *result, int rowIndex, int colIndex)
{ {
char *resultString = NULL;
if (PQgetisnull(result, rowIndex, colIndex)) if (PQgetisnull(result, rowIndex, colIndex))
{ {
return false; return false;
} }
resultString = PQgetvalue(result, rowIndex, colIndex); char *resultString = PQgetvalue(result, rowIndex, colIndex);
if (strlen(resultString) != 1) if (strlen(resultString) != 1)
{ {
return false; return false;
@ -249,18 +238,14 @@ ParseBoolField(PGresult *result, int rowIndex, int colIndex)
TimestampTz TimestampTz
ParseTimestampTzField(PGresult *result, int rowIndex, int colIndex) ParseTimestampTzField(PGresult *result, int rowIndex, int colIndex)
{ {
char *resultString = NULL;
Datum resultStringDatum = 0;
Datum timestampDatum = 0;
if (PQgetisnull(result, rowIndex, colIndex)) if (PQgetisnull(result, rowIndex, colIndex))
{ {
return DT_NOBEGIN; return DT_NOBEGIN;
} }
resultString = PQgetvalue(result, rowIndex, colIndex); char *resultString = PQgetvalue(result, rowIndex, colIndex);
resultStringDatum = CStringGetDatum(resultString); Datum resultStringDatum = CStringGetDatum(resultString);
timestampDatum = DirectFunctionCall3(timestamptz_in, resultStringDatum, 0, -1); Datum timestampDatum = DirectFunctionCall3(timestamptz_in, resultStringDatum, 0, -1);
return DatumGetTimestampTz(timestampDatum); return DatumGetTimestampTz(timestampDatum);
} }
@ -286,7 +271,6 @@ dump_local_wait_edges(PG_FUNCTION_ARGS)
static void static void
ReturnWaitGraph(WaitGraph *waitGraph, FunctionCallInfo fcinfo) ReturnWaitGraph(WaitGraph *waitGraph, FunctionCallInfo fcinfo)
{ {
size_t curEdgeNum = 0;
TupleDesc tupleDesc; TupleDesc tupleDesc;
Tuplestorestate *tupleStore = SetupTuplestore(fcinfo, &tupleDesc); Tuplestorestate *tupleStore = SetupTuplestore(fcinfo, &tupleDesc);
@ -302,7 +286,7 @@ ReturnWaitGraph(WaitGraph *waitGraph, FunctionCallInfo fcinfo)
* 07: blocking_transaction_stamp * 07: blocking_transaction_stamp
* 08: blocking_transaction_waiting * 08: blocking_transaction_waiting
*/ */
for (curEdgeNum = 0; curEdgeNum < waitGraph->edgeCount; curEdgeNum++) for (size_t curEdgeNum = 0; curEdgeNum < waitGraph->edgeCount; curEdgeNum++)
{ {
Datum values[9]; Datum values[9];
bool nulls[9]; bool nulls[9];
@ -353,8 +337,6 @@ ReturnWaitGraph(WaitGraph *waitGraph, FunctionCallInfo fcinfo)
static WaitGraph * static WaitGraph *
BuildLocalWaitGraph(void) BuildLocalWaitGraph(void)
{ {
WaitGraph *waitGraph = NULL;
int curBackend = 0;
PROCStack remaining; PROCStack remaining;
int totalProcs = TotalProcCount(); int totalProcs = TotalProcCount();
@ -364,7 +346,7 @@ BuildLocalWaitGraph(void)
* more than enough space to build the list of wait edges without a single * more than enough space to build the list of wait edges without a single
* allocation. * allocation.
*/ */
waitGraph = (WaitGraph *) palloc0(sizeof(WaitGraph)); WaitGraph *waitGraph = (WaitGraph *) palloc0(sizeof(WaitGraph));
waitGraph->localNodeId = GetLocalGroupId(); waitGraph->localNodeId = GetLocalGroupId();
waitGraph->allocatedSize = totalProcs * 3; waitGraph->allocatedSize = totalProcs * 3;
waitGraph->edgeCount = 0; waitGraph->edgeCount = 0;
@ -384,7 +366,7 @@ BuildLocalWaitGraph(void)
*/ */
/* build list of starting procs */ /* build list of starting procs */
for (curBackend = 0; curBackend < totalProcs; curBackend++) for (int curBackend = 0; curBackend < totalProcs; curBackend++)
{ {
PGPROC *currentProc = &ProcGlobal->allProcs[curBackend]; PGPROC *currentProc = &ProcGlobal->allProcs[curBackend];
BackendData currentBackendData; BackendData currentBackendData;
@ -476,24 +458,20 @@ BuildLocalWaitGraph(void)
static bool static bool
IsProcessWaitingForSafeOperations(PGPROC *proc) IsProcessWaitingForSafeOperations(PGPROC *proc)
{ {
PROCLOCK *waitProcLock = NULL;
LOCK *waitLock = NULL;
PGXACT *pgxact = NULL;
if (proc->waitStatus != STATUS_WAITING) if (proc->waitStatus != STATUS_WAITING)
{ {
return false; return false;
} }
/* get the transaction that the backend associated with */ /* get the transaction that the backend associated with */
pgxact = &ProcGlobal->allPgXact[proc->pgprocno]; PGXACT *pgxact = &ProcGlobal->allPgXact[proc->pgprocno];
if (pgxact->vacuumFlags & PROC_IS_AUTOVACUUM) if (pgxact->vacuumFlags & PROC_IS_AUTOVACUUM)
{ {
return true; return true;
} }
waitProcLock = proc->waitProcLock; PROCLOCK *waitProcLock = proc->waitProcLock;
waitLock = waitProcLock->tag.myLock; LOCK *waitLock = waitProcLock->tag.myLock;
return waitLock->tag.locktag_type == LOCKTAG_RELATION_EXTEND || return waitLock->tag.locktag_type == LOCKTAG_RELATION_EXTEND ||
waitLock->tag.locktag_type == LOCKTAG_PAGE || waitLock->tag.locktag_type == LOCKTAG_PAGE ||
@ -511,11 +489,9 @@ IsProcessWaitingForSafeOperations(PGPROC *proc)
static void static void
LockLockData(void) LockLockData(void)
{ {
int partitionNum = 0;
LockBackendSharedMemory(LW_SHARED); LockBackendSharedMemory(LW_SHARED);
for (partitionNum = 0; partitionNum < NUM_LOCK_PARTITIONS; partitionNum++) for (int partitionNum = 0; partitionNum < NUM_LOCK_PARTITIONS; partitionNum++)
{ {
LWLockAcquire(LockHashPartitionLockByIndex(partitionNum), LW_SHARED); LWLockAcquire(LockHashPartitionLockByIndex(partitionNum), LW_SHARED);
} }
@ -533,9 +509,7 @@ LockLockData(void)
static void static void
UnlockLockData(void) UnlockLockData(void)
{ {
int partitionNum = 0; for (int partitionNum = NUM_LOCK_PARTITIONS - 1; partitionNum >= 0; partitionNum--)
for (partitionNum = NUM_LOCK_PARTITIONS - 1; partitionNum >= 0; partitionNum--)
{ {
LWLockRelease(LockHashPartitionLockByIndex(partitionNum)); LWLockRelease(LockHashPartitionLockByIndex(partitionNum));
} }

View File

@ -133,14 +133,13 @@ void
AllocateRelationAccessHash(void) AllocateRelationAccessHash(void)
{ {
HASHCTL info; HASHCTL info;
uint32 hashFlags = 0;
memset(&info, 0, sizeof(info)); memset(&info, 0, sizeof(info));
info.keysize = sizeof(RelationAccessHashKey); info.keysize = sizeof(RelationAccessHashKey);
info.entrysize = sizeof(RelationAccessHashEntry); info.entrysize = sizeof(RelationAccessHashEntry);
info.hash = tag_hash; info.hash = tag_hash;
info.hcxt = ConnectionContext; info.hcxt = ConnectionContext;
hashFlags = (HASH_ELEM | HASH_BLOBS | HASH_CONTEXT); uint32 hashFlags = (HASH_ELEM | HASH_BLOBS | HASH_CONTEXT);
RelationAccessHash = hash_create("citus connection cache (relationid)", RelationAccessHash = hash_create("citus connection cache (relationid)",
8, &info, hashFlags); 8, &info, hashFlags);
@ -244,12 +243,12 @@ static void
RecordPlacementAccessToCache(Oid relationId, ShardPlacementAccessType accessType) RecordPlacementAccessToCache(Oid relationId, ShardPlacementAccessType accessType)
{ {
RelationAccessHashKey hashKey; RelationAccessHashKey hashKey;
RelationAccessHashEntry *hashEntry;
bool found = false; bool found = false;
hashKey.relationId = relationId; hashKey.relationId = relationId;
hashEntry = hash_search(RelationAccessHash, &hashKey, HASH_ENTER, &found); RelationAccessHashEntry *hashEntry = hash_search(RelationAccessHash, &hashKey,
HASH_ENTER, &found);
if (!found) if (!found)
{ {
hashEntry->relationAccessMode = 0; hashEntry->relationAccessMode = 0;
@ -270,8 +269,6 @@ RecordPlacementAccessToCache(Oid relationId, ShardPlacementAccessType accessType
void void
RecordParallelRelationAccessForTaskList(List *taskList) RecordParallelRelationAccessForTaskList(List *taskList)
{ {
Task *firstTask = NULL;
if (MultiShardConnectionType == SEQUENTIAL_CONNECTION) if (MultiShardConnectionType == SEQUENTIAL_CONNECTION)
{ {
/* sequential mode prevents parallel access */ /* sequential mode prevents parallel access */
@ -288,7 +285,7 @@ RecordParallelRelationAccessForTaskList(List *taskList)
* Since all the tasks in a task list is expected to operate on the same * Since all the tasks in a task list is expected to operate on the same
* distributed table(s), we only need to process the first task. * distributed table(s), we only need to process the first task.
*/ */
firstTask = linitial(taskList); Task *firstTask = linitial(taskList);
if (firstTask->taskType == SQL_TASK) if (firstTask->taskType == SQL_TASK)
{ {
@ -328,7 +325,6 @@ RecordParallelRelationAccessForTaskList(List *taskList)
static void static void
RecordRelationParallelSelectAccessForTask(Task *task) RecordRelationParallelSelectAccessForTask(Task *task)
{ {
List *relationShardList = NIL;
ListCell *relationShardCell = NULL; ListCell *relationShardCell = NULL;
Oid lastRelationId = InvalidOid; Oid lastRelationId = InvalidOid;
@ -338,7 +334,7 @@ RecordRelationParallelSelectAccessForTask(Task *task)
return; return;
} }
relationShardList = task->relationShardList; List *relationShardList = task->relationShardList;
foreach(relationShardCell, relationShardList) foreach(relationShardCell, relationShardList)
{ {
@ -528,13 +524,12 @@ RecordParallelRelationAccessToCache(Oid relationId,
ShardPlacementAccessType placementAccess) ShardPlacementAccessType placementAccess)
{ {
RelationAccessHashKey hashKey; RelationAccessHashKey hashKey;
RelationAccessHashEntry *hashEntry;
bool found = false; bool found = false;
int parallelRelationAccessBit = 0;
hashKey.relationId = relationId; hashKey.relationId = relationId;
hashEntry = hash_search(RelationAccessHash, &hashKey, HASH_ENTER, &found); RelationAccessHashEntry *hashEntry = hash_search(RelationAccessHash, &hashKey,
HASH_ENTER, &found);
if (!found) if (!found)
{ {
hashEntry->relationAccessMode = 0; hashEntry->relationAccessMode = 0;
@ -544,7 +539,7 @@ RecordParallelRelationAccessToCache(Oid relationId,
hashEntry->relationAccessMode |= (1 << (placementAccess)); hashEntry->relationAccessMode |= (1 << (placementAccess));
/* set the bit representing access mode */ /* set the bit representing access mode */
parallelRelationAccessBit = placementAccess + PARALLEL_MODE_FLAG_OFFSET; int parallelRelationAccessBit = placementAccess + PARALLEL_MODE_FLAG_OFFSET;
hashEntry->relationAccessMode |= (1 << parallelRelationAccessBit); hashEntry->relationAccessMode |= (1 << parallelRelationAccessBit);
} }
@ -557,7 +552,6 @@ bool
ParallelQueryExecutedInTransaction(void) ParallelQueryExecutedInTransaction(void)
{ {
HASH_SEQ_STATUS status; HASH_SEQ_STATUS status;
RelationAccessHashEntry *hashEntry;
if (!ShouldRecordRelationAccess() || RelationAccessHash == NULL) if (!ShouldRecordRelationAccess() || RelationAccessHash == NULL)
{ {
@ -566,7 +560,8 @@ ParallelQueryExecutedInTransaction(void)
hash_seq_init(&status, RelationAccessHash); hash_seq_init(&status, RelationAccessHash);
hashEntry = (RelationAccessHashEntry *) hash_seq_search(&status); RelationAccessHashEntry *hashEntry = (RelationAccessHashEntry *) hash_seq_search(
&status);
while (hashEntry != NULL) while (hashEntry != NULL)
{ {
int relationAccessMode = hashEntry->relationAccessMode; int relationAccessMode = hashEntry->relationAccessMode;
@ -621,8 +616,6 @@ static RelationAccessMode
GetRelationAccessMode(Oid relationId, ShardPlacementAccessType accessType) GetRelationAccessMode(Oid relationId, ShardPlacementAccessType accessType)
{ {
RelationAccessHashKey hashKey; RelationAccessHashKey hashKey;
RelationAccessHashEntry *hashEntry;
int relationAcessMode = 0;
bool found = false; bool found = false;
int parallelRelationAccessBit = accessType + PARALLEL_MODE_FLAG_OFFSET; int parallelRelationAccessBit = accessType + PARALLEL_MODE_FLAG_OFFSET;
@ -634,7 +627,8 @@ GetRelationAccessMode(Oid relationId, ShardPlacementAccessType accessType)
hashKey.relationId = relationId; hashKey.relationId = relationId;
hashEntry = hash_search(RelationAccessHash, &hashKey, HASH_FIND, &found); RelationAccessHashEntry *hashEntry = hash_search(RelationAccessHash, &hashKey,
HASH_FIND, &found);
if (!found) if (!found)
{ {
/* relation not accessed at all */ /* relation not accessed at all */
@ -642,7 +636,7 @@ GetRelationAccessMode(Oid relationId, ShardPlacementAccessType accessType)
} }
relationAcessMode = hashEntry->relationAccessMode; int relationAcessMode = hashEntry->relationAccessMode;
if (!(relationAcessMode & (1 << accessType))) if (!(relationAcessMode & (1 << accessType)))
{ {
/* relation not accessed with the given access type */ /* relation not accessed with the given access type */
@ -692,7 +686,6 @@ ShouldRecordRelationAccess()
static void static void
CheckConflictingRelationAccesses(Oid relationId, ShardPlacementAccessType accessType) CheckConflictingRelationAccesses(Oid relationId, ShardPlacementAccessType accessType)
{ {
DistTableCacheEntry *cacheEntry = NULL;
Oid conflictingReferencingRelationId = InvalidOid; Oid conflictingReferencingRelationId = InvalidOid;
ShardPlacementAccessType conflictingAccessType = PLACEMENT_ACCESS_SELECT; ShardPlacementAccessType conflictingAccessType = PLACEMENT_ACCESS_SELECT;
@ -701,7 +694,7 @@ CheckConflictingRelationAccesses(Oid relationId, ShardPlacementAccessType access
return; return;
} }
cacheEntry = DistributedTableCacheEntry(relationId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId);
if (!(cacheEntry->partitionMethod == DISTRIBUTE_BY_NONE && if (!(cacheEntry->partitionMethod == DISTRIBUTE_BY_NONE &&
cacheEntry->referencingRelationsViaForeignKey != NIL)) cacheEntry->referencingRelationsViaForeignKey != NIL))
@ -791,7 +784,6 @@ static void
CheckConflictingParallelRelationAccesses(Oid relationId, ShardPlacementAccessType CheckConflictingParallelRelationAccesses(Oid relationId, ShardPlacementAccessType
accessType) accessType)
{ {
DistTableCacheEntry *cacheEntry = NULL;
Oid conflictingReferencingRelationId = InvalidOid; Oid conflictingReferencingRelationId = InvalidOid;
ShardPlacementAccessType conflictingAccessType = PLACEMENT_ACCESS_SELECT; ShardPlacementAccessType conflictingAccessType = PLACEMENT_ACCESS_SELECT;
@ -800,7 +792,7 @@ CheckConflictingParallelRelationAccesses(Oid relationId, ShardPlacementAccessTyp
return; return;
} }
cacheEntry = DistributedTableCacheEntry(relationId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId);
if (!(cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH && if (!(cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH &&
cacheEntry->referencedRelationsViaForeignKey != NIL)) cacheEntry->referencedRelationsViaForeignKey != NIL))
{ {
@ -877,9 +869,6 @@ HoldsConflictingLockWithReferencedRelations(Oid relationId, ShardPlacementAccess
foreach(referencedRelationCell, cacheEntry->referencedRelationsViaForeignKey) foreach(referencedRelationCell, cacheEntry->referencedRelationsViaForeignKey)
{ {
Oid referencedRelation = lfirst_oid(referencedRelationCell); Oid referencedRelation = lfirst_oid(referencedRelationCell);
RelationAccessMode selectMode = RELATION_NOT_ACCESSED;
RelationAccessMode dmlMode = RELATION_NOT_ACCESSED;
RelationAccessMode ddlMode = RELATION_NOT_ACCESSED;
/* we're only interested in foreign keys to reference tables */ /* we're only interested in foreign keys to reference tables */
if (PartitionMethod(referencedRelation) != DISTRIBUTE_BY_NONE) if (PartitionMethod(referencedRelation) != DISTRIBUTE_BY_NONE)
@ -891,7 +880,7 @@ HoldsConflictingLockWithReferencedRelations(Oid relationId, ShardPlacementAccess
* A select on a reference table could conflict with a DDL * A select on a reference table could conflict with a DDL
* on a distributed table. * on a distributed table.
*/ */
selectMode = GetRelationSelectAccessMode(referencedRelation); RelationAccessMode selectMode = GetRelationSelectAccessMode(referencedRelation);
if (placementAccess == PLACEMENT_ACCESS_DDL && if (placementAccess == PLACEMENT_ACCESS_DDL &&
selectMode != RELATION_NOT_ACCESSED) selectMode != RELATION_NOT_ACCESSED)
{ {
@ -905,7 +894,7 @@ HoldsConflictingLockWithReferencedRelations(Oid relationId, ShardPlacementAccess
* Both DML and DDL operations on a reference table conflicts with * Both DML and DDL operations on a reference table conflicts with
* any parallel operation on distributed tables. * any parallel operation on distributed tables.
*/ */
dmlMode = GetRelationDMLAccessMode(referencedRelation); RelationAccessMode dmlMode = GetRelationDMLAccessMode(referencedRelation);
if (dmlMode != RELATION_NOT_ACCESSED) if (dmlMode != RELATION_NOT_ACCESSED)
{ {
*conflictingRelationId = referencedRelation; *conflictingRelationId = referencedRelation;
@ -914,7 +903,7 @@ HoldsConflictingLockWithReferencedRelations(Oid relationId, ShardPlacementAccess
return true; return true;
} }
ddlMode = GetRelationDDLAccessMode(referencedRelation); RelationAccessMode ddlMode = GetRelationDDLAccessMode(referencedRelation);
if (ddlMode != RELATION_NOT_ACCESSED) if (ddlMode != RELATION_NOT_ACCESSED)
{ {
*conflictingRelationId = referencedRelation; *conflictingRelationId = referencedRelation;
@ -985,7 +974,6 @@ HoldsConflictingLockWithReferencingRelations(Oid relationId, ShardPlacementAcces
} }
else if (placementAccess == PLACEMENT_ACCESS_DML) else if (placementAccess == PLACEMENT_ACCESS_DML)
{ {
RelationAccessMode ddlMode = RELATION_NOT_ACCESSED;
RelationAccessMode dmlMode = GetRelationDMLAccessMode(referencingRelation); RelationAccessMode dmlMode = GetRelationDMLAccessMode(referencingRelation);
if (dmlMode == RELATION_PARALLEL_ACCESSED) if (dmlMode == RELATION_PARALLEL_ACCESSED)
@ -994,7 +982,7 @@ HoldsConflictingLockWithReferencingRelations(Oid relationId, ShardPlacementAcces
*conflictingAccessMode = PLACEMENT_ACCESS_DML; *conflictingAccessMode = PLACEMENT_ACCESS_DML;
} }
ddlMode = GetRelationDDLAccessMode(referencingRelation); RelationAccessMode ddlMode = GetRelationDDLAccessMode(referencingRelation);
if (ddlMode == RELATION_PARALLEL_ACCESSED) if (ddlMode == RELATION_PARALLEL_ACCESSED)
{ {
/* SELECT on a distributed table conflicts with DDL / TRUNCATE */ /* SELECT on a distributed table conflicts with DDL / TRUNCATE */
@ -1004,25 +992,22 @@ HoldsConflictingLockWithReferencingRelations(Oid relationId, ShardPlacementAcces
} }
else if (placementAccess == PLACEMENT_ACCESS_DDL) else if (placementAccess == PLACEMENT_ACCESS_DDL)
{ {
RelationAccessMode selectMode = RELATION_NOT_ACCESSED; RelationAccessMode selectMode = GetRelationSelectAccessMode(
RelationAccessMode ddlMode = RELATION_NOT_ACCESSED; referencingRelation);
RelationAccessMode dmlMode = RELATION_NOT_ACCESSED;
selectMode = GetRelationSelectAccessMode(referencingRelation);
if (selectMode == RELATION_PARALLEL_ACCESSED) if (selectMode == RELATION_PARALLEL_ACCESSED)
{ {
holdsConflictingLocks = true; holdsConflictingLocks = true;
*conflictingAccessMode = PLACEMENT_ACCESS_SELECT; *conflictingAccessMode = PLACEMENT_ACCESS_SELECT;
} }
dmlMode = GetRelationDMLAccessMode(referencingRelation); RelationAccessMode dmlMode = GetRelationDMLAccessMode(referencingRelation);
if (dmlMode == RELATION_PARALLEL_ACCESSED) if (dmlMode == RELATION_PARALLEL_ACCESSED)
{ {
holdsConflictingLocks = true; holdsConflictingLocks = true;
*conflictingAccessMode = PLACEMENT_ACCESS_DML; *conflictingAccessMode = PLACEMENT_ACCESS_DML;
} }
ddlMode = GetRelationDDLAccessMode(referencingRelation); RelationAccessMode ddlMode = GetRelationDDLAccessMode(referencingRelation);
if (ddlMode == RELATION_PARALLEL_ACCESSED) if (ddlMode == RELATION_PARALLEL_ACCESSED)
{ {
holdsConflictingLocks = true; holdsConflictingLocks = true;

View File

@ -59,10 +59,7 @@ StartRemoteTransactionBegin(struct MultiConnection *connection)
{ {
RemoteTransaction *transaction = &connection->remoteTransaction; RemoteTransaction *transaction = &connection->remoteTransaction;
StringInfo beginAndSetDistributedTransactionId = makeStringInfo(); StringInfo beginAndSetDistributedTransactionId = makeStringInfo();
DistributedTransactionId *distributedTransactionId = NULL;
ListCell *subIdCell = NULL; ListCell *subIdCell = NULL;
List *activeSubXacts = NIL;
const char *timestamp = NULL;
Assert(transaction->transactionState == REMOTE_TRANS_INVALID); Assert(transaction->transactionState == REMOTE_TRANS_INVALID);
@ -84,8 +81,9 @@ StartRemoteTransactionBegin(struct MultiConnection *connection)
* and send both in one step. The reason is purely performance, we don't want * and send both in one step. The reason is purely performance, we don't want
* seperate roundtrips for these two statements. * seperate roundtrips for these two statements.
*/ */
distributedTransactionId = GetCurrentDistributedTransactionId(); DistributedTransactionId *distributedTransactionId =
timestamp = timestamptz_to_str(distributedTransactionId->timestamp); GetCurrentDistributedTransactionId();
const char *timestamp = timestamptz_to_str(distributedTransactionId->timestamp);
appendStringInfo(beginAndSetDistributedTransactionId, appendStringInfo(beginAndSetDistributedTransactionId,
"SELECT assign_distributed_transaction_id(%d, " UINT64_FORMAT "SELECT assign_distributed_transaction_id(%d, " UINT64_FORMAT
", '%s');", ", '%s');",
@ -94,7 +92,7 @@ StartRemoteTransactionBegin(struct MultiConnection *connection)
timestamp); timestamp);
/* append context for in-progress SAVEPOINTs for this transaction */ /* append context for in-progress SAVEPOINTs for this transaction */
activeSubXacts = ActiveSubXactContexts(); List *activeSubXacts = ActiveSubXactContexts();
transaction->lastSuccessfulSubXact = TopSubTransactionId; transaction->lastSuccessfulSubXact = TopSubTransactionId;
transaction->lastQueuedSubXact = TopSubTransactionId; transaction->lastQueuedSubXact = TopSubTransactionId;
foreach(subIdCell, activeSubXacts) foreach(subIdCell, activeSubXacts)
@ -139,12 +137,11 @@ void
FinishRemoteTransactionBegin(struct MultiConnection *connection) FinishRemoteTransactionBegin(struct MultiConnection *connection)
{ {
RemoteTransaction *transaction = &connection->remoteTransaction; RemoteTransaction *transaction = &connection->remoteTransaction;
bool clearSuccessful = true;
bool raiseErrors = true; bool raiseErrors = true;
Assert(transaction->transactionState == REMOTE_TRANS_STARTING); Assert(transaction->transactionState == REMOTE_TRANS_STARTING);
clearSuccessful = ClearResults(connection, raiseErrors); bool clearSuccessful = ClearResults(connection, raiseErrors);
if (clearSuccessful) if (clearSuccessful)
{ {
transaction->transactionState = REMOTE_TRANS_STARTED; transaction->transactionState = REMOTE_TRANS_STARTED;
@ -276,7 +273,6 @@ void
FinishRemoteTransactionCommit(MultiConnection *connection) FinishRemoteTransactionCommit(MultiConnection *connection)
{ {
RemoteTransaction *transaction = &connection->remoteTransaction; RemoteTransaction *transaction = &connection->remoteTransaction;
PGresult *result = NULL;
const bool raiseErrors = false; const bool raiseErrors = false;
const bool isCommit = true; const bool isCommit = true;
@ -284,7 +280,7 @@ FinishRemoteTransactionCommit(MultiConnection *connection)
transaction->transactionState == REMOTE_TRANS_1PC_COMMITTING || transaction->transactionState == REMOTE_TRANS_1PC_COMMITTING ||
transaction->transactionState == REMOTE_TRANS_2PC_COMMITTING); transaction->transactionState == REMOTE_TRANS_2PC_COMMITTING);
result = GetRemoteCommandResult(connection, raiseErrors); PGresult *result = GetRemoteCommandResult(connection, raiseErrors);
if (!IsResponseOK(result)) if (!IsResponseOK(result))
{ {
@ -476,7 +472,6 @@ StartRemoteTransactionPrepare(struct MultiConnection *connection)
RemoteTransaction *transaction = &connection->remoteTransaction; RemoteTransaction *transaction = &connection->remoteTransaction;
StringInfoData command; StringInfoData command;
const bool raiseErrors = true; const bool raiseErrors = true;
WorkerNode *workerNode = NULL;
/* can't prepare a nonexistant transaction */ /* can't prepare a nonexistant transaction */
Assert(transaction->transactionState != REMOTE_TRANS_INVALID); Assert(transaction->transactionState != REMOTE_TRANS_INVALID);
@ -490,7 +485,7 @@ StartRemoteTransactionPrepare(struct MultiConnection *connection)
Assign2PCIdentifier(connection); Assign2PCIdentifier(connection);
/* log transactions to workers in pg_dist_transaction */ /* log transactions to workers in pg_dist_transaction */
workerNode = FindWorkerNode(connection->hostname, connection->port); WorkerNode *workerNode = FindWorkerNode(connection->hostname, connection->port);
if (workerNode != NULL) if (workerNode != NULL)
{ {
LogTransactionRecord(workerNode->groupId, transaction->preparedName); LogTransactionRecord(workerNode->groupId, transaction->preparedName);
@ -520,12 +515,11 @@ void
FinishRemoteTransactionPrepare(struct MultiConnection *connection) FinishRemoteTransactionPrepare(struct MultiConnection *connection)
{ {
RemoteTransaction *transaction = &connection->remoteTransaction; RemoteTransaction *transaction = &connection->remoteTransaction;
PGresult *result = NULL;
const bool raiseErrors = true; const bool raiseErrors = true;
Assert(transaction->transactionState == REMOTE_TRANS_PREPARING); Assert(transaction->transactionState == REMOTE_TRANS_PREPARING);
result = GetRemoteCommandResult(connection, raiseErrors); PGresult *result = GetRemoteCommandResult(connection, raiseErrors);
if (!IsResponseOK(result)) if (!IsResponseOK(result))
{ {
@ -596,7 +590,6 @@ void
RemoteTransactionsBeginIfNecessary(List *connectionList) RemoteTransactionsBeginIfNecessary(List *connectionList)
{ {
ListCell *connectionCell = NULL; ListCell *connectionCell = NULL;
bool raiseInterrupts = true;
/* /*
* Don't do anything if not in a coordinated transaction. That allows the * Don't do anything if not in a coordinated transaction. That allows the
@ -630,7 +623,7 @@ RemoteTransactionsBeginIfNecessary(List *connectionList)
StartRemoteTransactionBegin(connection); StartRemoteTransactionBegin(connection);
} }
raiseInterrupts = true; bool raiseInterrupts = true;
WaitForAllConnections(connectionList, raiseInterrupts); WaitForAllConnections(connectionList, raiseInterrupts);
/* get result of all the BEGINs */ /* get result of all the BEGINs */
@ -798,7 +791,6 @@ void
CoordinatedRemoteTransactionsPrepare(void) CoordinatedRemoteTransactionsPrepare(void)
{ {
dlist_iter iter; dlist_iter iter;
bool raiseInterrupts = false;
List *connectionList = NIL; List *connectionList = NIL;
/* issue PREPARE TRANSACTION; to all relevant remote nodes */ /* issue PREPARE TRANSACTION; to all relevant remote nodes */
@ -822,7 +814,7 @@ CoordinatedRemoteTransactionsPrepare(void)
connectionList = lappend(connectionList, connection); connectionList = lappend(connectionList, connection);
} }
raiseInterrupts = true; bool raiseInterrupts = true;
WaitForAllConnections(connectionList, raiseInterrupts); WaitForAllConnections(connectionList, raiseInterrupts);
/* Wait for result */ /* Wait for result */
@ -857,7 +849,6 @@ CoordinatedRemoteTransactionsCommit(void)
{ {
dlist_iter iter; dlist_iter iter;
List *connectionList = NIL; List *connectionList = NIL;
bool raiseInterrupts = false;
/* /*
* Issue appropriate transaction commands to remote nodes. If everything * Issue appropriate transaction commands to remote nodes. If everything
@ -885,7 +876,7 @@ CoordinatedRemoteTransactionsCommit(void)
connectionList = lappend(connectionList, connection); connectionList = lappend(connectionList, connection);
} }
raiseInterrupts = false; bool raiseInterrupts = false;
WaitForAllConnections(connectionList, raiseInterrupts); WaitForAllConnections(connectionList, raiseInterrupts);
/* wait for the replies to the commands to come in */ /* wait for the replies to the commands to come in */
@ -921,7 +912,6 @@ CoordinatedRemoteTransactionsAbort(void)
{ {
dlist_iter iter; dlist_iter iter;
List *connectionList = NIL; List *connectionList = NIL;
bool raiseInterrupts = false;
/* asynchronously send ROLLBACK [PREPARED] */ /* asynchronously send ROLLBACK [PREPARED] */
dlist_foreach(iter, &InProgressTransactions) dlist_foreach(iter, &InProgressTransactions)
@ -942,7 +932,7 @@ CoordinatedRemoteTransactionsAbort(void)
connectionList = lappend(connectionList, connection); connectionList = lappend(connectionList, connection);
} }
raiseInterrupts = false; bool raiseInterrupts = false;
WaitForAllConnections(connectionList, raiseInterrupts); WaitForAllConnections(connectionList, raiseInterrupts);
/* and wait for the results */ /* and wait for the results */

View File

@ -65,11 +65,9 @@ static bool RecoverPreparedTransactionOnWorker(MultiConnection *connection,
Datum Datum
recover_prepared_transactions(PG_FUNCTION_ARGS) recover_prepared_transactions(PG_FUNCTION_ARGS)
{ {
int recoveredTransactionCount = 0;
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
recoveredTransactionCount = RecoverTwoPhaseCommits(); int recoveredTransactionCount = RecoverTwoPhaseCommits();
PG_RETURN_INT32(recoveredTransactionCount); PG_RETURN_INT32(recoveredTransactionCount);
} }
@ -83,9 +81,6 @@ recover_prepared_transactions(PG_FUNCTION_ARGS)
void void
LogTransactionRecord(int32 groupId, char *transactionName) LogTransactionRecord(int32 groupId, char *transactionName)
{ {
Relation pgDistTransaction = NULL;
TupleDesc tupleDescriptor = NULL;
HeapTuple heapTuple = NULL;
Datum values[Natts_pg_dist_transaction]; Datum values[Natts_pg_dist_transaction];
bool isNulls[Natts_pg_dist_transaction]; bool isNulls[Natts_pg_dist_transaction];
@ -97,10 +92,10 @@ LogTransactionRecord(int32 groupId, char *transactionName)
values[Anum_pg_dist_transaction_gid - 1] = CStringGetTextDatum(transactionName); values[Anum_pg_dist_transaction_gid - 1] = CStringGetTextDatum(transactionName);
/* open transaction relation and insert new tuple */ /* open transaction relation and insert new tuple */
pgDistTransaction = heap_open(DistTransactionRelationId(), RowExclusiveLock); Relation pgDistTransaction = heap_open(DistTransactionRelationId(), RowExclusiveLock);
tupleDescriptor = RelationGetDescr(pgDistTransaction); TupleDesc tupleDescriptor = RelationGetDescr(pgDistTransaction);
heapTuple = heap_form_tuple(tupleDescriptor, values, isNulls); HeapTuple heapTuple = heap_form_tuple(tupleDescriptor, values, isNulls);
CatalogTupleInsert(pgDistTransaction, heapTuple); CatalogTupleInsert(pgDistTransaction, heapTuple);
@ -118,11 +113,10 @@ LogTransactionRecord(int32 groupId, char *transactionName)
int int
RecoverTwoPhaseCommits(void) RecoverTwoPhaseCommits(void)
{ {
List *workerList = NIL;
ListCell *workerNodeCell = NULL; ListCell *workerNodeCell = NULL;
int recoveredTransactionCount = 0; int recoveredTransactionCount = 0;
workerList = ActivePrimaryNodeList(NoLock); List *workerList = ActivePrimaryNodeList(NoLock);
foreach(workerNodeCell, workerList) foreach(workerNodeCell, workerList)
{ {
@ -148,26 +142,14 @@ RecoverWorkerTransactions(WorkerNode *workerNode)
char *nodeName = workerNode->workerName; char *nodeName = workerNode->workerName;
int nodePort = workerNode->workerPort; int nodePort = workerNode->workerPort;
List *activeTransactionNumberList = NIL;
HTAB *activeTransactionNumberSet = NULL;
List *pendingTransactionList = NIL;
HTAB *pendingTransactionSet = NULL;
List *recheckTransactionList = NIL;
HTAB *recheckTransactionSet = NULL;
Relation pgDistTransaction = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1]; ScanKeyData scanKey[1];
int scanKeyCount = 1; int scanKeyCount = 1;
bool indexOK = true; bool indexOK = true;
HeapTuple heapTuple = NULL; HeapTuple heapTuple = NULL;
TupleDesc tupleDescriptor = NULL;
HASH_SEQ_STATUS status; HASH_SEQ_STATUS status;
MemoryContext localContext = NULL;
MemoryContext oldContext = NULL;
bool recoveryFailed = false; bool recoveryFailed = false;
int connectionFlags = 0; int connectionFlags = 0;
@ -180,17 +162,18 @@ RecoverWorkerTransactions(WorkerNode *workerNode)
return 0; return 0;
} }
localContext = AllocSetContextCreateExtended(CurrentMemoryContext, MemoryContext localContext = AllocSetContextCreateExtended(CurrentMemoryContext,
"RecoverWorkerTransactions", "RecoverWorkerTransactions",
ALLOCSET_DEFAULT_MINSIZE, ALLOCSET_DEFAULT_MINSIZE,
ALLOCSET_DEFAULT_INITSIZE, ALLOCSET_DEFAULT_INITSIZE,
ALLOCSET_DEFAULT_MAXSIZE); ALLOCSET_DEFAULT_MAXSIZE);
oldContext = MemoryContextSwitchTo(localContext); MemoryContext oldContext = MemoryContextSwitchTo(localContext);
/* take table lock first to avoid running concurrently */ /* take table lock first to avoid running concurrently */
pgDistTransaction = heap_open(DistTransactionRelationId(), ShareUpdateExclusiveLock); Relation pgDistTransaction = heap_open(DistTransactionRelationId(),
tupleDescriptor = RelationGetDescr(pgDistTransaction); ShareUpdateExclusiveLock);
TupleDesc tupleDescriptor = RelationGetDescr(pgDistTransaction);
/* /*
* We're going to check the list of prepared transactions on the worker, * We're going to check the list of prepared transactions on the worker,
@ -225,12 +208,13 @@ RecoverWorkerTransactions(WorkerNode *workerNode)
*/ */
/* find stale prepared transactions on the remote node */ /* find stale prepared transactions on the remote node */
pendingTransactionList = PendingWorkerTransactionList(connection); List *pendingTransactionList = PendingWorkerTransactionList(connection);
pendingTransactionSet = ListToHashSet(pendingTransactionList, NAMEDATALEN, true); HTAB *pendingTransactionSet = ListToHashSet(pendingTransactionList, NAMEDATALEN,
true);
/* find in-progress distributed transactions */ /* find in-progress distributed transactions */
activeTransactionNumberList = ActiveDistributedTransactionNumbers(); List *activeTransactionNumberList = ActiveDistributedTransactionNumbers();
activeTransactionNumberSet = ListToHashSet(activeTransactionNumberList, HTAB *activeTransactionNumberSet = ListToHashSet(activeTransactionNumberList,
sizeof(uint64), false); sizeof(uint64), false);
/* scan through all recovery records of the current worker */ /* scan through all recovery records of the current worker */
@ -238,18 +222,19 @@ RecoverWorkerTransactions(WorkerNode *workerNode)
BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(groupId)); BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(groupId));
/* get a snapshot of pg_dist_transaction */ /* get a snapshot of pg_dist_transaction */
scanDescriptor = systable_beginscan(pgDistTransaction, SysScanDesc scanDescriptor = systable_beginscan(pgDistTransaction,
DistTransactionGroupIndexId(), indexOK, DistTransactionGroupIndexId(),
indexOK,
NULL, scanKeyCount, scanKey); NULL, scanKeyCount, scanKey);
/* find stale prepared transactions on the remote node */ /* find stale prepared transactions on the remote node */
recheckTransactionList = PendingWorkerTransactionList(connection); List *recheckTransactionList = PendingWorkerTransactionList(connection);
recheckTransactionSet = ListToHashSet(recheckTransactionList, NAMEDATALEN, true); HTAB *recheckTransactionSet = ListToHashSet(recheckTransactionList, NAMEDATALEN,
true);
while (HeapTupleIsValid(heapTuple = systable_getnext(scanDescriptor))) while (HeapTupleIsValid(heapTuple = systable_getnext(scanDescriptor)))
{ {
bool isNull = false; bool isNull = false;
bool isTransactionInProgress = false;
bool foundPreparedTransactionBeforeCommit = false; bool foundPreparedTransactionBeforeCommit = false;
bool foundPreparedTransactionAfterCommit = false; bool foundPreparedTransactionAfterCommit = false;
@ -258,7 +243,7 @@ RecoverWorkerTransactions(WorkerNode *workerNode)
tupleDescriptor, &isNull); tupleDescriptor, &isNull);
char *transactionName = TextDatumGetCString(transactionNameDatum); char *transactionName = TextDatumGetCString(transactionNameDatum);
isTransactionInProgress = IsTransactionInProgress(activeTransactionNumberSet, bool isTransactionInProgress = IsTransactionInProgress(activeTransactionNumberSet,
transactionName); transactionName);
if (isTransactionInProgress) if (isTransactionInProgress)
{ {
@ -375,17 +360,15 @@ RecoverWorkerTransactions(WorkerNode *workerNode)
while ((pendingTransactionName = hash_seq_search(&status)) != NULL) while ((pendingTransactionName = hash_seq_search(&status)) != NULL)
{ {
bool isTransactionInProgress = false; bool isTransactionInProgress = IsTransactionInProgress(
bool shouldCommit = false; activeTransactionNumberSet,
isTransactionInProgress = IsTransactionInProgress(activeTransactionNumberSet,
pendingTransactionName); pendingTransactionName);
if (isTransactionInProgress) if (isTransactionInProgress)
{ {
continue; continue;
} }
shouldCommit = false; bool shouldCommit = false;
abortSucceeded = RecoverPreparedTransactionOnWorker(connection, abortSucceeded = RecoverPreparedTransactionOnWorker(connection,
pendingTransactionName, pendingTransactionName,
shouldCommit); shouldCommit);
@ -415,10 +398,6 @@ PendingWorkerTransactionList(MultiConnection *connection)
{ {
StringInfo command = makeStringInfo(); StringInfo command = makeStringInfo();
bool raiseInterrupts = true; bool raiseInterrupts = true;
int querySent = 0;
PGresult *result = NULL;
int rowCount = 0;
int rowIndex = 0;
List *transactionNames = NIL; List *transactionNames = NIL;
int coordinatorId = GetLocalGroupId(); int coordinatorId = GetLocalGroupId();
@ -426,21 +405,21 @@ PendingWorkerTransactionList(MultiConnection *connection)
"WHERE gid LIKE 'citus\\_%d\\_%%'", "WHERE gid LIKE 'citus\\_%d\\_%%'",
coordinatorId); coordinatorId);
querySent = SendRemoteCommand(connection, command->data); int querySent = SendRemoteCommand(connection, command->data);
if (querySent == 0) if (querySent == 0)
{ {
ReportConnectionError(connection, ERROR); ReportConnectionError(connection, ERROR);
} }
result = GetRemoteCommandResult(connection, raiseInterrupts); PGresult *result = GetRemoteCommandResult(connection, raiseInterrupts);
if (!IsResponseOK(result)) if (!IsResponseOK(result))
{ {
ReportResultError(connection, result, ERROR); ReportResultError(connection, result, ERROR);
} }
rowCount = PQntuples(result); int rowCount = PQntuples(result);
for (rowIndex = 0; rowIndex < rowCount; rowIndex++) for (int rowIndex = 0; rowIndex < rowCount; rowIndex++)
{ {
const int columnIndex = 0; const int columnIndex = 0;
char *transactionName = PQgetvalue(result, rowIndex, columnIndex); char *transactionName = PQgetvalue(result, rowIndex, columnIndex);
@ -468,11 +447,12 @@ IsTransactionInProgress(HTAB *activeTransactionNumberSet, char *preparedTransact
int procId = 0; int procId = 0;
uint32 connectionNumber = 0; uint32 connectionNumber = 0;
uint64 transactionNumber = 0; uint64 transactionNumber = 0;
bool isValidName = false;
bool isTransactionInProgress = false; bool isTransactionInProgress = false;
isValidName = ParsePreparedTransactionName(preparedTransactionName, &groupId, &procId, bool isValidName = ParsePreparedTransactionName(preparedTransactionName, &groupId,
&transactionNumber, &connectionNumber); &procId,
&transactionNumber,
&connectionNumber);
if (isValidName) if (isValidName)
{ {
hash_search(activeTransactionNumberSet, &transactionNumber, HASH_FIND, hash_search(activeTransactionNumberSet, &transactionNumber, HASH_FIND,
@ -493,7 +473,6 @@ RecoverPreparedTransactionOnWorker(MultiConnection *connection, char *transactio
{ {
StringInfo command = makeStringInfo(); StringInfo command = makeStringInfo();
PGresult *result = NULL; PGresult *result = NULL;
int executeCommand = 0;
bool raiseInterrupts = false; bool raiseInterrupts = false;
if (shouldCommit) if (shouldCommit)
@ -509,7 +488,7 @@ RecoverPreparedTransactionOnWorker(MultiConnection *connection, char *transactio
quote_literal_cstr(transactionName)); quote_literal_cstr(transactionName));
} }
executeCommand = ExecuteOptionalRemoteCommand(connection, command->data, &result); int executeCommand = ExecuteOptionalRemoteCommand(connection, command->data, &result);
if (executeCommand == QUERY_SEND_FAILED) if (executeCommand == QUERY_SEND_FAILED)
{ {
return false; return false;

Some files were not shown because too many files have changed in this diff Show More