From 20debfc0ee92ddb727487377cdc7d30b527bd034 Mon Sep 17 00:00:00 2001 From: Burak Yucesoy Date: Tue, 19 Jul 2016 15:21:20 +0300 Subject: [PATCH] Fix COUNT DISTINCT approximation with schema Fixes #555 Before this change, we were resolving the HLL function and type Oids without qualified names. Now we find the name of the schema where the HLL objects are stored and generate a qualified name for each object. A similar fix is also applied to the cstore_table_size function call. --- .../planner/multi_logical_optimizer.c | 67 ++++++++++++++++--- .../distributed/utils/citus_ruleutils.c | 3 +- .../worker/worker_data_fetch_protocol.c | 11 ++- src/include/distributed/citus_ruleutils.h | 1 + .../distributed/multi_logical_optimizer.h | 3 +- .../multi_agg_approximate_distinct.out | 43 ++++++++++++ .../sql/multi_agg_approximate_distinct.sql | 31 +++++++++ 7 files changed, 144 insertions(+), 15 deletions(-) diff --git a/src/backend/distributed/planner/multi_logical_optimizer.c b/src/backend/distributed/planner/multi_logical_optimizer.c index 91b98e7b8..395f94bf0 100644 --- a/src/backend/distributed/planner/multi_logical_optimizer.c +++ b/src/backend/distributed/planner/multi_logical_optimizer.c @@ -26,6 +26,7 @@ #include "catalog/pg_type.h" #include "commands/extension.h" #include "distributed/citus_nodes.h" +#include "distributed/citus_ruleutils.h" #include "distributed/metadata_cache.h" #include "distributed/multi_logical_optimizer.h" #include "distributed/multi_logical_planner.h" @@ -125,6 +126,7 @@ static List * WorkerAggregateExpressionList(Aggref *originalAggregate, static AggregateType GetAggregateType(Oid aggFunctionId); static Oid AggregateArgumentType(Aggref *aggregate); static Oid AggregateFunctionOid(const char *functionName, Oid inputType); +static Oid TypeOid(Oid schemaId, const char *typeName); /* Local functions forward declarations for count(distinct) approximations */ static char * CountDistinctHashFunctionName(Oid argumentType); @@ -1419,11 +1421,18 @@ 
MasterAggregateExpression(Aggref *originalAggregate, Aggref *unionAggregate = NULL; FuncExpr *cardinalityExpression = NULL; - Oid unionFunctionId = FunctionOid(HLL_UNION_AGGREGATE_NAME, argCount); - Oid cardinalityFunctionId = FunctionOid(HLL_CARDINALITY_FUNC_NAME, argCount); + /* extract schema name of hll */ + Oid hllId = get_extension_oid(HLL_EXTENSION_NAME, false); + Oid hllSchemaOid = get_extension_schema(hllId); + const char *hllSchemaName = get_namespace_name(hllSchemaOid); + + Oid unionFunctionId = FunctionOid(hllSchemaName, HLL_UNION_AGGREGATE_NAME, + argCount); + Oid cardinalityFunctionId = FunctionOid(hllSchemaName, HLL_CARDINALITY_FUNC_NAME, + argCount); Oid cardinalityReturnType = get_func_rettype(cardinalityFunctionId); - Oid hllType = TypenameGetTypid(HLL_TYPE_NAME); + Oid hllType = TypeOid(hllSchemaOid, HLL_TYPE_NAME); Oid hllTypeCollationId = get_typcollation(hllType); Var *hllColumn = makeVar(masterTableId, walkerContext->columnId, hllType, defaultTypeMod, @@ -1911,13 +1920,20 @@ WorkerAggregateExpressionList(Aggref *originalAggregate, TargetEntry *argument = (TargetEntry *) linitial(originalAggregate->args); Expr *argumentExpression = copyObject(argument->expr); + /* extract schema name of hll */ + Oid hllId = get_extension_oid(HLL_EXTENSION_NAME, false); + Oid hllSchemaOid = get_extension_schema(hllId); + const char *hllSchemaName = get_namespace_name(hllSchemaOid); + char *hashFunctionName = CountDistinctHashFunctionName(argumentType); - Oid hashFunctionId = FunctionOid(hashFunctionName, hashArgumentCount); + Oid hashFunctionId = FunctionOid(hllSchemaName, hashFunctionName, + hashArgumentCount); Oid hashFunctionReturnType = get_func_rettype(hashFunctionId); /* init hll_add_agg() related variables */ - Oid addFunctionId = FunctionOid(HLL_ADD_AGGREGATE_NAME, addArgumentCount); - Oid hllType = TypenameGetTypid(HLL_TYPE_NAME); + Oid addFunctionId = FunctionOid(hllSchemaName, HLL_ADD_AGGREGATE_NAME, + addArgumentCount); + Oid hllType = 
TypeOid(hllSchemaOid, HLL_TYPE_NAME); int logOfStorageSize = CountDistinctStorageSize(CountDistinctErrorRate); Const *logOfStorageSizeConst = MakeIntegerConst(logOfStorageSize); @@ -2103,18 +2119,19 @@ AggregateFunctionOid(const char *functionName, Oid inputType) * of arguments, and returns the corresponding function's oid. */ Oid -FunctionOid(const char *functionName, int argumentCount) +FunctionOid(const char *schemaName, const char *functionName, int argumentCount) { FuncCandidateList functionList = NULL; Oid functionOid = InvalidOid; - List *qualifiedFunctionName = stringToQualifiedNameList(functionName); + char *qualifiedFunctionName = quote_qualified_identifier(schemaName, functionName); + List *qualifiedFunctionNameList = stringToQualifiedNameList(qualifiedFunctionName); List *argumentList = NIL; const bool findVariadics = false; const bool findDefaults = false; const bool missingOK = true; - functionList = FuncnameGetCandidates(qualifiedFunctionName, argumentCount, + functionList = FuncnameGetCandidates(qualifiedFunctionNameList, argumentCount, argumentList, findVariadics, findDefaults, missingOK); @@ -2135,6 +2152,22 @@ FunctionOid(const char *functionName, int argumentCount) } +/* + * TypeOid looks for a type that has the given name and schema, and returns the + * corresponding type's oid. + */ +static Oid +TypeOid(Oid schemaId, const char *typeName) +{ + Oid typeOid; + + typeOid = GetSysCacheOid2(TYPENAMENSP, PointerGetDatum(typeName), + ObjectIdGetDatum(schemaId)); + + return typeOid; +} + + /* * CountDistinctHashFunctionName resolves the hll_hash function name to use for * the given input type, and returns this function name. 
@@ -4393,9 +4426,21 @@ static bool HasOrderByHllType(List *sortClauseList, List *targetList) { bool hasOrderByHllType = false; - Oid hllTypeId = TypenameGetTypid(HLL_TYPE_NAME); - + Oid hllId = InvalidOid; + Oid hllSchemaOid = InvalidOid; + Oid hllTypeId = InvalidOid; ListCell *sortClauseCell = NULL; + + /* check whether HLL is loaded */ + hllId = get_extension_oid(HLL_EXTENSION_NAME, true); + if (!OidIsValid(hllId)) + { + return hasOrderByHllType; + } + + hllSchemaOid = get_extension_schema(hllId); + hllTypeId = TypeOid(hllSchemaOid, HLL_TYPE_NAME); + foreach(sortClauseCell, sortClauseList) { SortGroupClause *sortClause = (SortGroupClause *) lfirst(sortClauseCell); diff --git a/src/backend/distributed/utils/citus_ruleutils.c b/src/backend/distributed/utils/citus_ruleutils.c index 706f9599a..787ba2697 100644 --- a/src/backend/distributed/utils/citus_ruleutils.c +++ b/src/backend/distributed/utils/citus_ruleutils.c @@ -53,7 +53,6 @@ #include "utils/typcache.h" #include "utils/xml.h" -static Oid get_extension_schema(Oid ext_oid); static void AppendOptionListToString(StringInfo stringData, List *options); static const char * convert_aclright_to_string(int aclright); @@ -101,7 +100,7 @@ pg_get_extensiondef_string(Oid tableRelationId) * * Returns InvalidOid if no such extension. 
*/ -static Oid +Oid get_extension_schema(Oid ext_oid) { /* *INDENT-OFF* */ diff --git a/src/backend/distributed/worker/worker_data_fetch_protocol.c b/src/backend/distributed/worker/worker_data_fetch_protocol.c index 8a1147334..72ada2cd2 100644 --- a/src/backend/distributed/worker/worker_data_fetch_protocol.c +++ b/src/backend/distributed/worker/worker_data_fetch_protocol.c @@ -23,6 +23,8 @@ #include "catalog/namespace.h" #include "commands/copy.h" #include "commands/dbcommands.h" +#include "commands/extension.h" +#include "distributed/citus_ruleutils.h" #include "distributed/master_protocol.h" #include "distributed/multi_client_executor.h" #include "distributed/multi_logical_optimizer.h" @@ -592,8 +594,15 @@ LocalTableSize(Oid relationId) bool cstoreTable = CStoreTable(relationId); if (cstoreTable) { + /* extract schema name of cstore */ + Oid cstoreId = get_extension_oid(CSTORE_FDW_NAME, false); + Oid cstoreSchemaOid = get_extension_schema(cstoreId); + const char *cstoreSchemaName = get_namespace_name(cstoreSchemaOid); + const int tableSizeArgumentCount = 1; - Oid tableSizeFunctionOid = FunctionOid(CSTORE_TABLE_SIZE_FUNCTION_NAME, + + Oid tableSizeFunctionOid = FunctionOid(cstoreSchemaName, + CSTORE_TABLE_SIZE_FUNCTION_NAME, tableSizeArgumentCount); Datum tableSizeDatum = OidFunctionCall1(tableSizeFunctionOid, relationIdDatum); diff --git a/src/include/distributed/citus_ruleutils.h b/src/include/distributed/citus_ruleutils.h index dcee35280..861cb56d7 100644 --- a/src/include/distributed/citus_ruleutils.h +++ b/src/include/distributed/citus_ruleutils.h @@ -22,6 +22,7 @@ extern char * pg_get_tableschemadef_string(Oid tableRelationId); extern char * pg_get_tablecolumnoptionsdef_string(Oid tableRelationId); extern char * pg_get_indexclusterdef_string(Oid indexRelationId); extern List * pg_get_table_grants(Oid relationId); +extern Oid get_extension_schema(Oid ext_oid); /* Function declarations for version dependent PostgreSQL ruleutils functions */ extern void 
pg_get_query_def(Query *query, StringInfo buffer); diff --git a/src/include/distributed/multi_logical_optimizer.h b/src/include/distributed/multi_logical_optimizer.h index c04b30c1a..3b840f935 100644 --- a/src/include/distributed/multi_logical_optimizer.h +++ b/src/include/distributed/multi_logical_optimizer.h @@ -112,7 +112,8 @@ extern void MultiLogicalPlanOptimize(MultiTreeRoot *multiTree); extern char PartitionMethod(Oid relationId); /* Function declaration for getting oid for the given function name */ -extern Oid FunctionOid(const char *functionName, int argumentCount); +extern Oid FunctionOid(const char *schemaName, const char *functionName, + int argumentCount); /* Function declaration for helper functions in subquery pushdown */ extern List * SubqueryMultiTableList(MultiNode *multiNode); diff --git a/src/test/regress/expected/multi_agg_approximate_distinct.out b/src/test/regress/expected/multi_agg_approximate_distinct.out index 852142d0f..a9dd982c2 100644 --- a/src/test/regress/expected/multi_agg_approximate_distinct.out +++ b/src/test/regress/expected/multi_agg_approximate_distinct.out @@ -1,6 +1,8 @@ -- -- MULTI_AGG_APPROXIMATE_DISTINCT -- +ALTER SEQUENCE pg_catalog.pg_dist_shardid_seq RESTART 340000; +ALTER SEQUENCE pg_catalog.pg_dist_jobid_seq RESTART 340000; -- Try to execute count(distinct) when approximate distincts aren't enabled SELECT count(distinct l_orderkey) FROM lineitem; ERROR: cannot compute aggregate (distinct) @@ -100,6 +102,47 @@ SELECT count(DISTINCT l_orderkey) as distinct_order_count, l_quantity FROM linei 223 | 31.00 (10 rows) +-- Check that approximate count(distinct) works at a table in a schema other than public +-- create necessary objects +CREATE SCHEMA test_count_distinct_schema; +NOTICE: Citus partially supports CREATE SCHEMA for distributed databases +DETAIL: schema usage in joins and in some UDFs provided by Citus are not supported yet +CREATE TABLE test_count_distinct_schema.nation_hash( + n_nationkey integer not null, + 
n_name char(25) not null, + n_regionkey integer not null, + n_comment varchar(152) +); +SELECT master_create_distributed_table('test_count_distinct_schema.nation_hash', 'n_nationkey', 'hash'); + master_create_distributed_table +--------------------------------- + +(1 row) + +SELECT master_create_worker_shards('test_count_distinct_schema.nation_hash', 4, 2); + master_create_worker_shards +----------------------------- + +(1 row) + +\COPY test_count_distinct_schema.nation_hash FROM STDIN with delimiter '|'; +SET search_path TO public; +SET citus.count_distinct_error_rate TO 0.01; +SELECT COUNT (DISTINCT n_regionkey) FROM test_count_distinct_schema.nation_hash; + count +------- + 3 +(1 row) + +-- test with search_path is set +SET search_path TO test_count_distinct_schema; +SELECT COUNT (DISTINCT n_regionkey) FROM nation_hash; + count +------- + 3 +(1 row) + +SET search_path TO public; -- If we have an order by on count(distinct) that we intend to push down to -- worker nodes, we need to error out. Otherwise, we are fine. 
SET citus.limit_clause_row_fetch_count = 1000; diff --git a/src/test/regress/sql/multi_agg_approximate_distinct.sql b/src/test/regress/sql/multi_agg_approximate_distinct.sql index a79be18a1..a80b59967 100644 --- a/src/test/regress/sql/multi_agg_approximate_distinct.sql +++ b/src/test/regress/sql/multi_agg_approximate_distinct.sql @@ -52,6 +52,37 @@ SELECT count(DISTINCT l_orderkey) as distinct_order_count, l_quantity FROM linei ORDER BY distinct_order_count ASC, l_quantity ASC LIMIT 10; +-- Check that approximate count(distinct) works at a table in a schema other than public +-- create necessary objects +CREATE SCHEMA test_count_distinct_schema; + +CREATE TABLE test_count_distinct_schema.nation_hash( + n_nationkey integer not null, + n_name char(25) not null, + n_regionkey integer not null, + n_comment varchar(152) +); +SELECT master_create_distributed_table('test_count_distinct_schema.nation_hash', 'n_nationkey', 'hash'); +SELECT master_create_worker_shards('test_count_distinct_schema.nation_hash', 4, 2); + +\COPY test_count_distinct_schema.nation_hash FROM STDIN with delimiter '|'; +0|ALGERIA|0|haggle. carefully final deposits detect slyly agai +1|ARGENTINA|1|al foxes promise slyly according to the regular accounts. bold requests alon +2|BRAZIL|1|y alongside of the pending deposits. carefully special packages are about the ironic forges. slyly special +3|CANADA|1|eas hang ironic, silent packages. slyly regular packages are furiously over the tithes. fluffily bold +4|EGYPT|4|y above the carefully unusual theodolites. final dugouts are quickly across the furiously regular d +5|ETHIOPIA|0|ven packages wake quickly. regu +\. 
+ +SET search_path TO public; +SET citus.count_distinct_error_rate TO 0.01; +SELECT COUNT (DISTINCT n_regionkey) FROM test_count_distinct_schema.nation_hash; + +-- test with search_path is set +SET search_path TO test_count_distinct_schema; +SELECT COUNT (DISTINCT n_regionkey) FROM nation_hash; +SET search_path TO public; + -- If we have an order by on count(distinct) that we intend to push down to -- worker nodes, we need to error out. Otherwise, we are fine.