diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c index b900617e9..52eb3fa7e 100644 --- a/src/backend/distributed/commands/multi_copy.c +++ b/src/backend/distributed/commands/multi_copy.c @@ -156,6 +156,8 @@ static void CopyAttributeOutText(CopyOutState outputState, char *string); static inline void CopyFlushOutput(CopyOutState outputState, char *start, char *pointer); static bool CitusSendTupleToPlacements(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest); +static uint64 ShardIdForTuple(CitusCopyDestReceiver *copyDest, Datum *columnValues, + bool *columnNulls); /* CitusCopyDestReceiver functions */ static void CitusCopyDestReceiverStartup(DestReceiver *copyDest, int operation, @@ -2277,7 +2279,6 @@ CitusCopyDestReceiverReceive(TupleTableSlot *slot, DestReceiver *dest) static bool CitusSendTupleToPlacements(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest) { - int partitionColumnIndex = copyDest->partitionColumnIndex; TupleDesc tupleDescriptor = copyDest->tupleDescriptor; CopyStmt *copyStatement = copyDest->copyStatement; @@ -2291,8 +2292,6 @@ CitusSendTupleToPlacements(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest Datum *columnValues = NULL; bool *columnNulls = NULL; - Datum partitionColumnValue = 0; - ShardInterval *shardInterval = NULL; int64 shardId = 0; bool shardConnectionsFound = false; @@ -2307,54 +2306,7 @@ CitusSendTupleToPlacements(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest columnValues = slot->tts_values; columnNulls = slot->tts_isnull; - /* - * Find the partition column value and corresponding shard interval - * for non-reference tables. - * Get the existing (and only a single) shard interval for the reference - * tables. Note that, reference tables has NULL partition column values so - * skip the check. - */ - if (partitionColumnIndex != INVALID_PARTITION_COLUMN_INDEX) - { - CopyCoercionData *coercePath = &columnCoercionPaths[partitionColumnIndex]; - - if (columnNulls[partitionColumnIndex]) - { - Oid relationId = copyDest->distributedRelationId; - char *relationName = get_rel_name(relationId); - Oid schemaOid = get_rel_namespace(relationId); - char *schemaName = get_namespace_name(schemaOid); - char *qualifiedTableName = quote_qualified_identifier(schemaName, - relationName); - - ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), - errmsg("the partition column of table %s cannot be NULL", - qualifiedTableName))); - } - - /* find the partition column value */ - partitionColumnValue = columnValues[partitionColumnIndex]; - - /* annoyingly this is evaluated twice, but at least we don't crash! */ - partitionColumnValue = CoerceColumnValue(partitionColumnValue, coercePath); - } - - /* - * Find the shard interval and id for the partition column value for - * non-reference tables. - * - * For reference table, this function blindly returns the tables single - * shard. - */ - shardInterval = FindShardInterval(partitionColumnValue, copyDest->tableMetadata); - if (shardInterval == NULL) - { - ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), - errmsg("could not find shard for partition column " - "value"))); - } - - shardId = shardInterval->shardId; + shardId = ShardIdForTuple(copyDest, columnValues, columnNulls); /* connections hash is kept in memory context */ MemoryContextSwitchTo(copyDest->memoryContext); @@ -2416,6 +2368,68 @@ CitusSendTupleToPlacements(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest } +/* + * ShardIdForTuple returns id of the shard to which the given tuple belongs to. + */ +static uint64 +ShardIdForTuple(CitusCopyDestReceiver *copyDest, Datum *columnValues, bool *columnNulls) +{ + int partitionColumnIndex = copyDest->partitionColumnIndex; + Datum partitionColumnValue = 0; + CopyCoercionData *columnCoercionPaths = copyDest->columnCoercionPaths; + ShardInterval *shardInterval = NULL; + + /* + * Find the partition column value and corresponding shard interval + * for non-reference tables. + * Get the existing (and only a single) shard interval for the reference + * tables. Note that, reference tables has NULL partition column values so + * skip the check. + */ + if (partitionColumnIndex != INVALID_PARTITION_COLUMN_INDEX) + { + CopyCoercionData *coercePath = &columnCoercionPaths[partitionColumnIndex]; + + if (columnNulls[partitionColumnIndex]) + { + Oid relationId = copyDest->distributedRelationId; + char *relationName = get_rel_name(relationId); + Oid schemaOid = get_rel_namespace(relationId); + char *schemaName = get_namespace_name(schemaOid); + char *qualifiedTableName = quote_qualified_identifier(schemaName, + relationName); + + ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), + errmsg("the partition column of table %s cannot be NULL", + qualifiedTableName))); + } + + /* find the partition column value */ + partitionColumnValue = columnValues[partitionColumnIndex]; + + /* annoyingly this is evaluated twice, but at least we don't crash! */ + partitionColumnValue = CoerceColumnValue(partitionColumnValue, coercePath); + } + + /* + * Find the shard interval and id for the partition column value for + * non-reference tables. + * + * For reference table, this function blindly returns the tables single + * shard. + */ + shardInterval = FindShardInterval(partitionColumnValue, copyDest->tableMetadata); + if (shardInterval == NULL) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("could not find shard for partition column " + "value"))); + } + + return shardInterval->shardId; +} + + /* * CitusCopyDestReceiverShutdown implements the rShutdown interface of * CitusCopyDestReceiver. It ends the COPY on all the open connections and closes