diff --git a/src/backend/distributed/relay/relay_event_utility.c b/src/backend/distributed/relay/relay_event_utility.c index fc6f91e66..06ecb59e1 100644 --- a/src/backend/distributed/relay/relay_event_utility.c +++ b/src/backend/distributed/relay/relay_event_utility.c @@ -45,9 +45,8 @@ /* expression tree walker context for rewriting row references */ typedef struct { - char *relationName; uint64 shardId; -} RowRefWalkerState; +} ColumnRefWalkerState; /* Local functions forward declarations */ @@ -56,9 +55,7 @@ static bool TypeDropIndexConstraint(const AlterTableCmd *command, const RangeVar *relation, uint64 shardId); static void AppendShardIdToConstraintName(AlterTableCmd *command, uint64 shardId); static void SetSchemaNameIfNotExist(char **schemaName, char *newSchemaName); -static bool ExtendRowReferencesWalker(Node *node, RowRefWalkerState *state); -static void AppendShardIdToRowReferences(IndexStmt *statement, char *relationName, uint64 - shardId); +static bool UpdateWholeRowColumnReferencesWalker(Node *node, ColumnRefWalkerState *state); /* * RelayEventExtendNames extends relation names in the given parse tree for @@ -297,6 +294,7 @@ RelayEventExtendNames(Node *parseTree, char *schemaName, uint64 shardId) case T_IndexStmt: { IndexStmt *indexStmt = (IndexStmt *) parseTree; + ColumnRefWalkerState state; char **relationName = &(indexStmt->relation->relname); char **indexName = &(indexStmt->idxname); char **relationSchemaName = &(indexStmt->relation->schemaname); @@ -322,10 +320,14 @@ RelayEventExtendNames(Node *parseTree, char *schemaName, uint64 shardId) ereport(ERROR, (errmsg("cannot extend name for null index name"))); } + /* extend ColumnRef nodes in the IndexStmt with the shardId */ + state.shardId = shardId; + raw_expression_tree_walker((Node *) indexStmt->indexParams, + UpdateWholeRowColumnReferencesWalker, &state); + /* prefix with schema name if it is not added already */ SetSchemaNameIfNotExist(relationSchemaName, schemaName); - AppendShardIdToRowReferences(indexStmt, *relationName, shardId); AppendShardIdToName(relationName, shardId); AppendShardIdToName(indexName, shardId); break; @@ -548,8 +550,10 @@ AppendShardIdToConstraintName(AlterTableCmd *command, uint64 shardId) static bool -ExtendRowReferencesWalker(Node *node, RowRefWalkerState *state) +UpdateWholeRowColumnReferencesWalker(Node *node, ColumnRefWalkerState *state) { + bool walkIsComplete = false; + if (node == NULL) { return false; @@ -559,55 +563,41 @@ ExtendRowReferencesWalker(Node *node, RowRefWalkerState *state) { IndexElem *indexElem = (IndexElem *) node; - return raw_expression_tree_walker(indexElem->expr, ExtendRowReferencesWalker, - state); + walkIsComplete = raw_expression_tree_walker(indexElem->expr, + UpdateWholeRowColumnReferencesWalker, + state); } else if (IsA(node, ColumnRef)) { ColumnRef *columnRef = (ColumnRef *) node; - ListCell *fieldsCell; + Node *lastField = llast(columnRef->fields); - /* - * Append the shardId to any ColumnRef String values that are - * equal to the relationName. These are actually ROW(relname) - * references. - */ - foreach(fieldsCell, columnRef->fields) + if (IsA(lastField, A_Star)) { - Value *fieldValue = (Value *) lfirst(fieldsCell); + /* + * ColumnRef fields list ends with an A_Star, so we can blindly + * extend the penultimate element with the shardId. + */ + Value *relnameValue; + int len = list_length(columnRef->fields); - if (IsA(fieldValue, String)) - { - char *columnName = strVal(fieldValue); + relnameValue = list_nth(columnRef->fields, len - 2); + Assert(IsA(relnameValue, String)); - if (strncmp(columnName, state->relationName, NAMEDATALEN) == 0) - { - AppendShardIdToName(&columnName, state->shardId); - fieldsCell->data.ptr_value = makeString(columnName); - } - } + AppendShardIdToName(&relnameValue->val.str, state->shardId); } + + /* might be more than one ColumnRef to visit */ + walkIsComplete = false; + } + else + { + walkIsComplete = raw_expression_tree_walker(node, + UpdateWholeRowColumnReferencesWalker, + state); } - return raw_expression_tree_walker(node, ExtendRowReferencesWalker, (void *) state); -} - - -/* - * AppendShardIdToRowReferences finds ColumnRef nodes that directly reference - * a column with the same name as the relation and extends those names with the - * given shardId. - */ -static void -AppendShardIdToRowReferences(IndexStmt *indexStmt, char *relationName, uint64 shardId) -{ - RowRefWalkerState state; - - state.relationName = relationName; - state.shardId = shardId; - - raw_expression_tree_walker((Node *) indexStmt->indexParams, ExtendRowReferencesWalker, - &state); + return walkIsComplete; } diff --git a/src/backend/distributed/utils/ruleutils_95.c b/src/backend/distributed/utils/ruleutils_95.c index 34e53c26a..4802e66fb 100644 --- a/src/backend/distributed/utils/ruleutils_95.c +++ b/src/backend/distributed/utils/ruleutils_95.c @@ -3623,8 +3623,20 @@ get_variable(Var *var, int levelsup, bool istoplevel, deparse_context *context) } else { - /* System column - name is fixed, get it from the catalog */ - attname = get_relid_attribute_name(rte->relid, attnum); + CitusRTEKind rtekind; + + rtekind = GetRangeTblKind(rte); + + if (rtekind == CITUS_RTE_SHARD || rtekind == CITUS_RTE_REMOTE_QUERY) + { + /* System column on a Citus shared/remote relation */ + attname = get_relid_attribute_name(rte->relid, attnum); + } + else + { + /* System column - name is fixed, get it from the catalog */ + attname = get_rte_attribute_name(rte, attnum); + } } if (refname && (context->varprefix || attname == NULL))