From de57bd12404e4c7d488a5bb2377538078d58a7b2 Mon Sep 17 00:00:00 2001
From: Shabnam Khan
Date: Tue, 27 Jun 2023 12:58:41 +0530
Subject: [PATCH] Add citus_find_shard_split_points as a UDF and call it from ScheduleShardSplit

---
 .../distributed/operations/auto_shard_split.c | 280 +++++++++---------
 .../citus_auto_shard_split_start/12.0-1.sql   |  19 +-
 .../citus_auto_shard_split_start/latest.sql   |  17 ++
 .../distributed/coordinator_protocol.h        |   2 +-
 4 files changed, 184 insertions(+), 134 deletions(-)

diff --git a/src/backend/distributed/operations/auto_shard_split.c b/src/backend/distributed/operations/auto_shard_split.c
index 41668801c..406f19606 100644
--- a/src/backend/distributed/operations/auto_shard_split.c
+++ b/src/backend/distributed/operations/auto_shard_split.c
@@ -14,10 +14,13 @@
 #include "distributed/listutils.h"
 #include "distributed/metadata_utility.h"
 #include "distributed/background_jobs.h"
+#include "distributed/multi_join_order.h"
+#include "distributed/citus_ruleutils.h"
 
 PG_FUNCTION_INFO_V1(citus_auto_shard_split_start);
+PG_FUNCTION_INFO_V1(citus_find_shard_split_points);
 
-uint64 MaxShardSize = 102400;
+int64 MaxShardSize = 102400;
 double TenantFrequency = 0.3;
 
 /*
@@ -27,27 +30,19 @@ double TenantFrequency = 0.3;
 typedef struct ShardInfoData
 {
 	int64 shardSize;
-	int64 shardMinValue;
-	int64 shardMaxValue;
 	int64 shardId;
 	int32 nodeId;
-	char *tableName;
-	char *distributionColumn;
-	char *dataType;
-	char *shardName;
-	Oid tableId;
-	Oid distributionColumnId;
+	int64 shardGroupSize;
 }ShardInfoData;
 typedef ShardInfoData *ShardInfo;
 
 void ErrorOnConcurrentOperation(void);
-StringInfo GetShardSplitQuery(ShardInfo shardinfo, List *splitPoints,
+StringInfo GetShardSplitQuery(ShardInfo shardinfo, Datum datum,
 							  char *shardSplitMode);
-void ExecuteSplitBackgroundJob(int64 jobId, ShardInfo shardinfo, List *splitPoints,
+void ExecuteSplitBackgroundJob(int64 jobId, ShardInfo shardinfo, Datum datum,
 							   char *shardSplitMode);
-int64 ExecuteAverageHashQuery(ShardInfo shardinfo);
-List * FindShardSplitPoints(ShardInfo shardinfo);
-int64 ScheduleShardSplit(ShardInfo shardinfo, char *shardSplitMode , int64 jobId);
+List * FindShardSplitPoints(int64 shardId);
+int64 ScheduleShardSplit(ShardInfo shardinfo, char *shardSplitMode, int64 jobId);
 
 /*
  * It throws an error if a concurrent automatic shard split or Rebalance operation is happening.
@@ -79,32 +74,34 @@ ErrorOnConcurrentOperation()
  * For a given SplitPoints , it creates the SQL query for the Shard Splitting
  */
 StringInfo
-GetShardSplitQuery(ShardInfo shardinfo, List *splitPoints, char *shardSplitMode)
+GetShardSplitQuery(ShardInfo shardinfo, Datum datum, char *shardSplitMode)
 {
 	StringInfo splitQuery = makeStringInfo();
-	int64 length = list_length(splitPoints);
+	ArrayType *array = DatumGetArrayTypeP(datum);
+	Datum *values;
+	int nelems;
+	deconstruct_array(array,
+					  INT4OID,
+					  sizeof(int32), true, TYPALIGN_INT,
+					  &values, NULL, &nelems);
 	appendStringInfo(splitQuery, "SELECT citus_split_shard_by_split_points(%ld, ARRAY[",
					 shardinfo->shardId);

-	int32 splitpoint = 0;
-	uint64 index = 0;
-	foreach_int(splitpoint, splitPoints)
+	for (int i = 0; i < nelems; i++)
 	{
-		appendStringInfo(splitQuery, "'%d'", splitpoint);
+		appendStringInfo(splitQuery, "'%d'", DatumGetInt32(values[i]));

-		if (index < length - 1)
+		if (i < nelems - 1)
 		{
 			appendStringInfoString(splitQuery, ",");
 		}
-
-		index++;
 	}

 	/*All the shards after the split will be belonging to the same node */
 	appendStringInfo(splitQuery, "], ARRAY[");

-	for (int i = 0; i < length; i++)
+	for (int i = 0; i < nelems; i++)
 	{
 		appendStringInfo(splitQuery, "%d,", shardinfo->nodeId);
 	}
@@ -119,11 +116,12 @@ GetShardSplitQuery(ShardInfo shardinfo, List *splitPoints, char *shardSplitMode)
  * It creates a background job for citus_split_shard_by_split_points and executes it in background.
  */
 void
-ExecuteSplitBackgroundJob(int64 jobId, ShardInfo shardinfo, List *splitPoints,
+ExecuteSplitBackgroundJob(int64 jobId, ShardInfo shardinfo, Datum datum,
 						  char *shardSplitMode)
 {
 	StringInfo splitQuery = makeStringInfo();
-	splitQuery = GetShardSplitQuery(shardinfo, splitPoints, shardSplitMode);
+	splitQuery = GetShardSplitQuery(shardinfo, datum, shardSplitMode);
+	/* ereport(LOG, (errmsg(splitQuery->data))); */
 	int32 nodesInvolved[] = { shardinfo->nodeId };
 	Oid superUserId = CitusExtensionOwner();
@@ -132,58 +130,37 @@ ExecuteSplitBackgroundJob(int64 jobId, ShardInfo shardinfo, List *splitPoints,
 }

-/*
- * It executes a query to find the average hash value in a shard considering rows with a limit of 10GB .
- * If there exists a hash value it is returned otherwise shardminvalue-1 is returned.
- */
-int64
-ExecuteAverageHashQuery(ShardInfo shardinfo)
-{
-	StringInfo AvgHashQuery = makeStringInfo();
-	uint64 tableSize = 0;
-	bool check = DistributedTableSize(shardinfo->tableId, TOTAL_RELATION_SIZE, true,
-									  &tableSize);
-	appendStringInfo(AvgHashQuery, "SELECT avg(h)::int,count(*)"
-								   " FROM (SELECT worker_hash(%s) h FROM %s TABLESAMPLE SYSTEM(least(10, 100*10000000000/%lu))"
-								   " WHERE worker_hash(%s)>=%ld AND worker_hash(%s)<=%ld) s",
-					 shardinfo->distributionColumn, shardinfo->tableName,
-					 tableSize,
-					 shardinfo->distributionColumn, shardinfo->shardMinValue,
-					 shardinfo->distributionColumn, shardinfo->shardMaxValue
-					 );
-	ereport(DEBUG4, errmsg("%s", AvgHashQuery->data));
-	SPI_connect();
-	SPI_exec(AvgHashQuery->data, 0);
-	SPITupleTable *tupletable = SPI_tuptable;
-	HeapTuple tuple = tupletable->vals[0];
-	bool isnull;
-	Datum average = SPI_getbinval(tuple, tupletable->tupdesc, 1, &isnull);
-	int64 isResultNull = 1;
-	if (!isnull)
-	{
-		isResultNull = 0;
-	}
-	SPI_freetuptable(tupletable);
-	SPI_finish();
-
-	if (isResultNull == 0)
-	{
-		return DatumGetInt64(average);
-	}
-	else
-	{
-		return shardinfo->shardMinValue - 1;
-	}
-}
-
-
 /*
  * This function executes a query and then decides whether a shard is subjected for isolation or average hash 2 way split.
 * If a tenant is found splitpoints for isolation is returned otherwise average hash value is returned.
 */
-List *
-FindShardSplitPoints(ShardInfo shardinfo)
+Datum
+citus_find_shard_split_points(PG_FUNCTION_ARGS)
 {
+	int64 shardId = PG_GETARG_INT64(0);
+	int64 shardGroupSize = PG_GETARG_INT64(1);
+	ereport(DEBUG4, errmsg("%ld", shardGroupSize));
+
+	/* Filter out shards whose total group size is below MaxShardSize*1024, i.e. the size-based policy */
+	if (shardGroupSize < MaxShardSize * 1024)
+	{
+		PG_RETURN_NULL();
+	}
+
+	/* Extract all the shard information with the help of shardId */
+	Oid tableId = RelationIdForShard(shardId);
+	char *distributionColumnName = ColumnToColumnName(tableId,
+													  (Node *) DistPartitionKeyOrError(
+														  tableId));
+	char *dataType = format_type_be(ColumnTypeIdForRelationColumnName(
+										tableId,
+										distributionColumnName));
+	char *shardName = get_rel_name(tableId);
+	AppendShardIdToName(&shardName, shardId);
+	ShardInterval *shardrange = LoadShardInterval(shardId);
+	int64 shardMinValue = shardrange->minValue;
+	int64 shardMaxValue = shardrange->maxValue;
+	char *tableName = generate_qualified_relation_name(tableId);
 	StringInfo CommonValueQuery = makeStringInfo();

 	/*
@@ -196,14 +173,14 @@ FindShardSplitPoints(ShardInfo shardinfo)
 					 " FROM pg_stats s , unnest(most_common_vals::text::%s[],most_common_freqs) as res(val,freq)"
 					 " WHERE tablename = %s AND attname = %s AND schemaname = %s AND freq > %lf $$)"
 					 " WHERE result <> '' AND shardid = %ld;",
-					 shardinfo->dataType, quote_literal_cstr(shardinfo->tableName),
-					 shardinfo->dataType,
-					 quote_literal_cstr(shardinfo->shardName),
-					 quote_literal_cstr(shardinfo->distributionColumn),
+					 dataType, quote_literal_cstr(tableName),
+					 dataType,
+					 quote_literal_cstr(shardName),
+					 quote_literal_cstr(distributionColumnName),
 					 quote_literal_cstr(get_namespace_name(get_rel_namespace(
-															   shardinfo->tableId))),
+															   tableId))),
 					 TenantFrequency,
-					 shardinfo->shardId);
+					 shardId);

 	ereport(DEBUG4, errmsg("%s", CommonValueQuery->data));
 	List *splitPoints = NULL;
@@ -218,26 +195,26 @@
 	MemoryContext spiContext = CurrentMemoryContext;

 	int64 rowCount = SPI_processed;
-	int64 average;
 	int32 hashedValue;

 	ereport(DEBUG4, errmsg("%ld", rowCount));

 	if (rowCount > 0)
 	{
-		/*For every common tenant value split point is calculated on the basis of
+		/* For every common tenant value, a split point is calculated on the basis of
 		 * the hashed value and the unique split points are appended to the list
 		 * and the resulting is then sorted and returned.
-		 */
+		 */
 		SPITupleTable *tupletable = SPI_tuptable;
-		CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(shardinfo->tableId);
+		CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(tableId);
 		for (int rowIndex = 0; rowIndex < rowCount; rowIndex++)
 		{
 			HeapTuple tuple = tupletable->vals[rowIndex];
 			char *commonValue = SPI_getvalue(tuple, tupletable->tupdesc, 2);
 			ereport(DEBUG4, errmsg("%s", commonValue));
 			Datum tenantIdDatum = StringToDatum(commonValue,
-												shardinfo->distributionColumnId);
+												ColumnTypeIdForRelationColumnName(tableId,
+																				  distributionColumnName));
 			Datum hashedValueDatum = FunctionCall1Coll(cacheEntry->hashFunction,
 													   cacheEntry->partitionColumn->
 													   varcollid,
@@ -248,11 +225,11 @@ FindShardSplitPoints(ShardInfo shardinfo)
 			/*Switching the memory context to store the unique SplitPoints in a list*/
 			MemoryContextSwitchTo(originalContext);
-			if (hashedValue == shardinfo->shardMinValue)
+			if (hashedValue == shardMinValue)
 			{
 				splitPoints = list_append_unique_int(splitPoints, hashedValue);
 			}
-			else if (hashedValue == shardinfo->shardMaxValue)
+			else if (hashedValue == shardMaxValue)
 			{
 				splitPoints = list_append_unique_int(splitPoints, hashedValue - 1);
 			}
@@ -268,17 +245,64 @@ FindShardSplitPoints(ShardInfo shardinfo)
 	}
 	else
 	{
-		average = ExecuteAverageHashQuery(shardinfo);
-		ereport(DEBUG4, errmsg("%ld", average));
-		MemoryContextSwitchTo(originalContext);
-		if (shardinfo->shardMinValue <= average)
+		StringInfo AvgHashQuery = makeStringInfo();
+		uint64 tableSize = 0;
+		bool check = DistributedTableSize(tableId, TOTAL_RELATION_SIZE, true,
+										  &tableSize);
+
+		/*
+		 * Execute a query that finds the average hash value in the shard, sampling at most roughly 10GB of rows.
+		 * If a hash value exists it is returned, otherwise NULL is returned.
+		 */
+		appendStringInfo(AvgHashQuery, "SELECT avg(h)::int,count(*)"
+									   " FROM (SELECT worker_hash(%s) h FROM %s TABLESAMPLE SYSTEM(least(10, 100*10000000000/%lu))"
+									   " WHERE worker_hash(%s)>=%ld AND worker_hash(%s)<=%ld) s",
+						 distributionColumnName, tableName,
+						 tableSize,
+						 distributionColumnName, shardMinValue,
+						 distributionColumnName, shardMaxValue
+						 );
+		ereport(DEBUG4, errmsg("%s", AvgHashQuery->data));
+
+		SPI_connect();
+		SPI_exec(AvgHashQuery->data, 0);
+
+		SPITupleTable *tupletable = SPI_tuptable;
+		HeapTuple tuple = tupletable->vals[0];
+		bool isnull;
+
+		Datum average = SPI_getbinval(tuple, tupletable->tupdesc, 1, &isnull);
+		int64 isResultNull = 1;
+		if (!isnull)
 		{
-			splitPoints = lappend_int(splitPoints, average);
+			isResultNull = 0;
+		}
+		SPI_freetuptable(tupletable);
+		SPI_finish();
+
+		if (isResultNull == 0)
+		{
+			ereport(DEBUG4, errmsg("%ld", average));
+			MemoryContextSwitchTo(originalContext);
+			splitPoints = lappend_int(splitPoints, DatumGetInt32(average));
 		}
 	}
-	SPI_finish();
-	return splitPoints;
+
+	/* Convert the list into a Datum array so it can be turned into an ArrayType */
+	Datum *elements = (Datum *) palloc(sizeof(Datum) * list_length(splitPoints));
+	int32 splitPoint, index = 0;
+	foreach_int(splitPoint, splitPoints)
+	{
+		elements[index++] = Int32GetDatum(splitPoint);
+	}
+	ArrayType *resultArray = construct_array(elements, list_length(splitPoints), INT4OID,
+											 sizeof(int32), true, TYPALIGN_INT);
+	if (list_length(splitPoints) == 0)
+	{
+		PG_RETURN_NULL();
+	}
+	PG_RETURN_ARRAYTYPE_P(resultArray);
 }
@@ -287,19 +311,33 @@
 * split and then executes the background job for the shard split.
 */
 int64
-ScheduleShardSplit(ShardInfo shardinfo, char *shardSplitMode , int64 jobId)
+ScheduleShardSplit(ShardInfo shardinfo, char *shardSplitMode, int64 jobId)
 {
-	List *splitPoints = FindShardSplitPoints(shardinfo);
-	if (list_length(splitPoints) > 0)
+	SPI_connect();
+	StringInfo findSplitPointsQuery = makeStringInfo();
+	appendStringInfo(findSplitPointsQuery,
+					 "SELECT citus_find_shard_split_points(%ld, %ld)",
+					 shardinfo->shardId,
+					 shardinfo->shardGroupSize);
+	SPI_exec(findSplitPointsQuery->data, 0);
+
+	SPITupleTable *tupletable = SPI_tuptable;
+	HeapTuple tuple = tupletable->vals[0];
+	bool isnull;
+	Datum resultDatum = SPI_getbinval(tuple, tupletable->tupdesc, 1, &isnull);
+
+	if (!isnull)
 	{
-		ereport(DEBUG4, errmsg("%s", GetShardSplitQuery(shardinfo, splitPoints,
-														shardSplitMode)->data));
-		ExecuteSplitBackgroundJob(jobId, shardinfo, splitPoints, shardSplitMode);
+		ereport(DEBUG4, errmsg("%s", GetShardSplitQuery(shardinfo, resultDatum,
+														shardSplitMode)->data));
+		ExecuteSplitBackgroundJob(jobId, shardinfo, resultDatum, shardSplitMode);
+		SPI_finish();
 		return 1;
 	}
 	else
 	{
 		ereport(LOG, errmsg("No Splitpoints for shard split"));
+		SPI_finish();
 		return 0;
 	}
 }
@@ -323,17 +361,16 @@ citus_auto_shard_split_start(PG_FUNCTION_ARGS)
 	appendStringInfo(
 		query,
-		" SELECT cs.shardid,pd.shardminvalue,pd.shardmaxvalue,cs.shard_size,pn.nodeid,ct.distribution_column,ct.table_name,cs.shard_name,(SELECT relname FROM pg_class WHERE oid = ct.table_name)"
+		" SELECT cs.shardid,pd.shardminvalue,pd.shardmaxvalue,cs.shard_size,pn.nodeid, max_sizes.total_sum"
 		" FROM pg_catalog.pg_dist_shard pd JOIN pg_catalog.citus_shards cs ON pd.shardid = cs.shardid JOIN pg_catalog.pg_dist_node pn ON cs.nodename = pn.nodename AND cs.nodeport= pn.nodeport"
 		" JOIN"
-		" ( select shardid , max_size from (SELECT distinct first_value(shardid) OVER w as shardid, sum(shard_size) OVER (PARTITION BY colocation_id, shardminvalue) as total_sum, max(shard_size) OVER w as max_size"
+		" ( select shardid , max_size , total_sum from (SELECT distinct first_value(shardid) OVER w as shardid, sum(shard_size) OVER (PARTITION BY colocation_id, shardminvalue) as total_sum, max(shard_size) OVER w as max_size"
 		" FROM citus_shards cs JOIN pg_dist_shard ps USING(shardid)"
-		" WINDOW w AS (PARTITION BY colocation_id, shardminvalue ORDER BY shard_size DESC) )as t where total_sum >= %lu )"
-		" AS max_sizes ON cs.shardid=max_sizes.shardid AND cs.shard_size = max_sizes.max_size JOIN citus_tables ct ON cs.table_name = ct.table_name AND pd.shardminvalue <> pd.shardmaxvalue AND pd.shardminvalue <> ''",
-		MaxShardSize*1024
+		" WINDOW w AS (PARTITION BY colocation_id, shardminvalue ORDER BY shard_size DESC))as t)"
+		" AS max_sizes ON cs.shardid=max_sizes.shardid AND cs.shard_size = max_sizes.max_size AND pd.shardminvalue <> pd.shardmaxvalue AND pd.shardminvalue <> ''"
 		);
-	ereport(DEBUG4 ,errmsg("%s", query->data));
+	ereport(DEBUG4, errmsg("%s", query->data));
 	Oid shardTransferModeOid = PG_GETARG_OID(0);
 	Datum enumLabelDatum = DirectFunctionCall1(enum_out, shardTransferModeOid);
 	char *shardSplitMode = DatumGetCString(enumLabelDatum);
@@ -351,7 +388,7 @@ citus_auto_shard_split_start(PG_FUNCTION_ARGS)
 	int rowCount = SPI_processed;
 	bool isnull;
 	int64 jobId = CreateBackgroundJob("Automatic Shard Split",
-									  "Split using SplitPoints List");
+									  "Split using SplitPoints List");
 	int64 count = 0;

 	for (int rowIndex = 0; rowIndex < rowCount; rowIndex++)
@@ -368,37 +405,16 @@ citus_auto_shard_split_start(PG_FUNCTION_ARGS)
 		Datum nodeIdDatum = SPI_getbinval(tuple,
 										  tupletable->tupdesc, 5, &isnull);
 		shardinfo.nodeId = DatumGetInt32(nodeIdDatum);

-		char *shardMinVal = SPI_getvalue(tuple, tupletable->tupdesc, 2);
-		shardinfo.shardMinValue = strtoi64(shardMinVal, NULL, 10);
+		char *shardGroupSizeValue = SPI_getvalue(tuple, tupletable->tupdesc, 6);
+		shardinfo.shardGroupSize = strtoi64(shardGroupSizeValue, NULL, 10);

-		char *shardMaxVal = SPI_getvalue(tuple, tupletable->tupdesc, 3);
-		shardinfo.shardMaxValue = strtoi64(shardMaxVal, NULL, 10);
-
-		shardinfo.distributionColumn = SPI_getvalue(tuple, tupletable->tupdesc, 6);
-		shardinfo.tableName = SPI_getvalue(tuple, tupletable->tupdesc, 7);
-
-		shardinfo.shardName = SPI_getvalue(tuple, tupletable->tupdesc, 9);
-		AppendShardIdToName(&shardinfo.shardName, shardinfo.shardId);
-
-		Datum tableIdDatum = SPI_getbinval(tuple, tupletable->tupdesc, 7, &isnull);
-		shardinfo.tableId = DatumGetObjectId(tableIdDatum);
-		shardinfo.distributionColumnId = ColumnTypeIdForRelationColumnName(
-			shardinfo.tableId,
-			shardinfo.
-			distributionColumn);
-		shardinfo.dataType = format_type_be(shardinfo.distributionColumnId);
-
-		count = count + ScheduleShardSplit(&shardinfo, shardSplitMode , jobId);
-		ereport(DEBUG4, (errmsg(
-					"Shard ID: %ld,ShardMinValue: %ld, ShardMaxValue: %ld , totalSize: %ld , nodeId: %d",
-					shardinfo.shardId, shardinfo.shardMinValue,
-					shardinfo.shardMaxValue,
-					shardinfo.shardSize, shardinfo.nodeId)));
+		count = count + ScheduleShardSplit(&shardinfo, shardSplitMode, jobId);
 	}

 	SPI_freetuptable(tupletable);
 	SPI_finish();

-	if(count==0){
+	if (count == 0)
+	{
 		DirectFunctionCall1(citus_job_cancel, Int64GetDatum(jobId));
 	}
diff --git a/src/backend/distributed/sql/udfs/citus_auto_shard_split_start/12.0-1.sql b/src/backend/distributed/sql/udfs/citus_auto_shard_split_start/12.0-1.sql
index aba25f34d..0f3a7bb26 100644
--- a/src/backend/distributed/sql/udfs/citus_auto_shard_split_start/12.0-1.sql
+++ b/src/backend/distributed/sql/udfs/citus_auto_shard_split_start/12.0-1.sql
@@ -12,4 +12,21 @@ COMMENT ON FUNCTION pg_catalog.citus_auto_shard_split_start(citus.shard_transfer
 IS 'automatically split the necessary shards in the cluster in the background';

-GRANT EXECUTE ON FUNCTION pg_catalog.citus_auto_shard_split_start(citus.shard_transfer_mode) TO PUBLIC;
\ No newline at end of file
+GRANT EXECUTE ON FUNCTION pg_catalog.citus_auto_shard_split_start(citus.shard_transfer_mode) TO PUBLIC;
+
+CREATE OR REPLACE FUNCTION pg_catalog.citus_find_shard_split_points(
+    shard_id bigint,
+    shard_group_size bigint
+
+    )
+    RETURNS SETOF bigint[]
+
+    AS 'MODULE_PATHNAME'
+
+    LANGUAGE C VOLATILE;
+
+COMMENT ON FUNCTION pg_catalog.citus_find_shard_split_points(shard_id bigint, shard_group_size bigint)
+
+    IS 'creates split points for shards';
+
+GRANT EXECUTE ON FUNCTION pg_catalog.citus_find_shard_split_points(shard_id bigint, shard_group_size bigint) TO PUBLIC;
diff --git a/src/backend/distributed/sql/udfs/citus_auto_shard_split_start/latest.sql b/src/backend/distributed/sql/udfs/citus_auto_shard_split_start/latest.sql
index b300a2cc7..0f3a7bb26 100644
--- a/src/backend/distributed/sql/udfs/citus_auto_shard_split_start/latest.sql
+++ b/src/backend/distributed/sql/udfs/citus_auto_shard_split_start/latest.sql
@@ -13,3 +13,20 @@ COMMENT ON FUNCTION pg_catalog.citus_auto_shard_split_start(citus.shard_transfer
 IS 'automatically split the necessary shards in the cluster in the background';

 GRANT EXECUTE ON FUNCTION pg_catalog.citus_auto_shard_split_start(citus.shard_transfer_mode) TO PUBLIC;
+
+CREATE OR REPLACE FUNCTION pg_catalog.citus_find_shard_split_points(
+    shard_id bigint,
+    shard_group_size bigint
+
+    )
+    RETURNS SETOF bigint[]
+
+    AS 'MODULE_PATHNAME'
+
+    LANGUAGE C VOLATILE;
+
+COMMENT ON FUNCTION pg_catalog.citus_find_shard_split_points(shard_id bigint, shard_group_size bigint)
+
+    IS 'creates split points for shards';
+
+GRANT EXECUTE ON FUNCTION pg_catalog.citus_find_shard_split_points(shard_id bigint, shard_group_size bigint) TO PUBLIC;
diff --git a/src/include/distributed/coordinator_protocol.h b/src/include/distributed/coordinator_protocol.h
index 8813b7b32..71aabc79b 100644
--- a/src/include/distributed/coordinator_protocol.h
+++ b/src/include/distributed/coordinator_protocol.h
@@ -214,7 +214,7 @@ extern int ShardCount;
 extern int ShardReplicationFactor;
 extern int NextShardId;
 extern int NextPlacementId;
-extern uint64 MaxShardSize;
+extern int64 MaxShardSize;
 extern double TenantFrequency;
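
Usage sketch (not part of the patch): assuming a build with this change applied, the background splitter can be exercised from psql roughly as below. The 'block_writes' transfer mode label and the pg_dist_background_task columns are standard Citus names; everything else is illustrative.

-- Kick off the automatic shard split job in the background.
SELECT pg_catalog.citus_auto_shard_split_start('block_writes');

-- Internally, ScheduleShardSplit issues, per oversized shard group, a query of the form
--   SELECT citus_find_shard_split_points(<shardid>, <colocated group size in bytes>)
-- and feeds the returned split points to citus_split_shard_by_split_points.

-- Inspect the scheduled background tasks.
SELECT job_id, task_id, status, command
FROM pg_dist_background_task
ORDER BY task_id DESC
LIMIT 5;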