diff --git a/src/backend/distributed/commands/foreign_constraint.c b/src/backend/distributed/commands/foreign_constraint.c index 71dd91aef..1353820ba 100644 --- a/src/backend/distributed/commands/foreign_constraint.c +++ b/src/backend/distributed/commands/foreign_constraint.c @@ -801,31 +801,53 @@ TableReferencing(Oid relationId) /* - * ConstraintIsAForeignKey is a wrapper around GetForeignKeyOidByName that - * returns true if the given constraint name identifies a foreign key - * constraint defined on relation with relationId. + * ConstraintIsAForeignKey is a wrapper around ConstraintWithNameIsOfType that returns true + * if given constraint name identifies a foreign key constraint. */ bool ConstraintIsAForeignKey(char *inputConstaintName, Oid relationId) { - Oid foreignKeyId = GetForeignKeyOidByName(inputConstaintName, relationId); - return OidIsValid(foreignKeyId); + return ConstraintWithNameIsOfType(inputConstaintName, relationId, CONSTRAINT_FOREIGN); } /* - * GetForeignKeyOidByName returns OID of the foreign key with name and defined - * on relation with relationId. If there is no such foreign key constraint, then - * this function returns InvalidOid. + * ConstraintWithNameIsOfType is a wrapper around get_relation_constraint_oid that + * returns true if given constraint name identifies a valid constraint defined + * on relation with relationId and it's type matches the input constraint type. */ -Oid -GetForeignKeyOidByName(char *inputConstaintName, Oid relationId) +bool +ConstraintWithNameIsOfType(char *inputConstaintName, Oid relationId, + char targetConstraintType) { - int flags = INCLUDE_REFERENCING_CONSTRAINTS; - List *foreignKeyOids = GetForeignKeyOids(relationId, flags); + bool missingOk = true; + Oid constraintId = + get_relation_constraint_oid(relationId, inputConstaintName, missingOk); + return ConstraintWithIdIsOfType(constraintId, targetConstraintType); +} - Oid foreignKeyId = FindForeignKeyOidWithName(foreignKeyOids, inputConstaintName); - return foreignKeyId; + +/* + * ConstraintWithIdIsOfType returns true if constraint with constraintId exists + * and is of type targetConstraintType. + */ +bool +ConstraintWithIdIsOfType(Oid constraintId, char targetConstraintType) +{ + HeapTuple heapTuple = SearchSysCache1(CONSTROID, ObjectIdGetDatum(constraintId)); + if (!HeapTupleIsValid(heapTuple)) + { + /* no such constraint */ + return false; + } + + Form_pg_constraint constraintForm = (Form_pg_constraint) GETSTRUCT(heapTuple); + char constraintType = constraintForm->contype; + bool constraintTypeMatches = (constraintType == targetConstraintType); + + ReleaseSysCache(heapTuple); + + return constraintTypeMatches; } diff --git a/src/backend/distributed/commands/table.c b/src/backend/distributed/commands/table.c index fa47db87c..5d795e90d 100644 --- a/src/backend/distributed/commands/table.c +++ b/src/backend/distributed/commands/table.c @@ -17,6 +17,7 @@ #include "access/xact.h" #include "catalog/index.h" #include "catalog/pg_class.h" +#include "catalog/pg_constraint.h" #include "commands/tablecmds.h" #include "distributed/citus_ruleutils.h" #include "distributed/colocation_utils.h" @@ -461,7 +462,9 @@ PreprocessAlterTableStmt(Node *node, const char *alterTableCommand) */ Assert(list_length(commandList) == 1); - Oid foreignKeyId = GetForeignKeyOidByName(constraintName, leftRelationId); + bool missingOk = false; + Oid foreignKeyId = get_relation_constraint_oid(leftRelationId, + constraintName, missingOk); rightRelationId = GetReferencedTableId(foreignKeyId); } } diff --git a/src/include/distributed/commands.h b/src/include/distributed/commands.h index 1ec5a9f26..81d44f577 100644 --- a/src/include/distributed/commands.h +++ b/src/include/distributed/commands.h @@ -148,7 +148,9 @@ extern bool HasForeignKeyToReferenceTable(Oid relationOid); extern bool TableReferenced(Oid relationOid); extern bool TableReferencing(Oid relationOid); extern bool ConstraintIsAForeignKey(char *inputConstaintName, Oid relationOid); -extern Oid GetForeignKeyOidByName(char *inputConstaintName, Oid relationId); +extern bool ConstraintWithNameIsOfType(char *inputConstaintName, Oid relationId, + char targetConstraintType); +extern bool ConstraintWithIdIsOfType(Oid constraintId, char targetConstraintType); extern void ErrorIfTableHasExternalForeignKeys(Oid relationId); extern List * GetForeignKeyOids(Oid relationId, int flags); extern Oid GetReferencedTableId(Oid foreignKeyId);