diff --git a/src/backend/distributed/executor/intermediate_results.c b/src/backend/distributed/executor/intermediate_results.c index e9ccb849f..f4be6263c 100644 --- a/src/backend/distributed/executor/intermediate_results.c +++ b/src/backend/distributed/executor/intermediate_results.c @@ -653,6 +653,28 @@ RemoveIntermediateResultsDirectory(void) } +/* + * IntermediateResultSize returns the file size of the intermediate result + * or -1 if the file does not exist. + */ +int64 +IntermediateResultSize(char *resultId) +{ + char *resultFileName = NULL; + struct stat fileStat; + int statOK = 0; + + resultFileName = QueryResultFileName(resultId); + statOK = stat(resultFileName, &fileStat); + if (statOK < 0) + { + return -1; + } + + return (int64) fileStat.st_size; +} + + /* * read_intermediate_result is a UDF that returns a COPY-formatted intermediate * result file as a set of records. The file is parsed according to the columns diff --git a/src/backend/distributed/planner/distributed_planner.c b/src/backend/distributed/planner/distributed_planner.c index a4c2f4034..6725f2050 100644 --- a/src/backend/distributed/planner/distributed_planner.c +++ b/src/backend/distributed/planner/distributed_planner.c @@ -17,6 +17,7 @@ #include "distributed/citus_nodefuncs.h" #include "distributed/citus_nodes.h" #include "distributed/insert_select_planner.h" +#include "distributed/intermediate_results.h" #include "distributed/metadata_cache.h" #include "distributed/multi_executor.h" #include "distributed/distributed_planner.h" @@ -31,11 +32,15 @@ #include "nodes/makefuncs.h" #include "nodes/nodeFuncs.h" #include "parser/parsetree.h" +#include "parser/parse_type.h" +#include "optimizer/cost.h" #include "optimizer/pathnode.h" #include "optimizer/planner.h" +#include "utils/builtins.h" #include "utils/datum.h" #include "utils/lsyscache.h" #include "utils/memutils.h" +#include "utils/syscache.h" static List *plannerRestrictionContextList = NIL; @@ -70,6 +75,8 @@ static PlannedStmt * FinalizeNonRouterPlan(PlannedStmt *localPlan, static PlannedStmt * FinalizeRouterPlan(PlannedStmt *localPlan, CustomScan *customScan); static void CheckNodeIsDumpable(Node *node); static Node * CheckNodeCopyAndSerialization(Node *node); +static void AdjustReadIntermediateResultCost(RangeTblEntry *rangeTableEntry, + RelOptInfo *relOptInfo); static List * CopyPlanParamList(List *originalPlanParamList); static PlannerRestrictionContext * CreateAndPushPlannerRestrictionContext(void); static PlannerRestrictionContext * CurrentPlannerRestrictionContext(void); @@ -1159,6 +1166,8 @@ multi_relation_restriction_hook(PlannerInfo *root, RelOptInfo *relOptInfo, Index bool distributedTable = false; bool localTable = false; + AdjustReadIntermediateResultCost(rte, relOptInfo); + if (rte->rtekind != RTE_RELATION) { return; @@ -1215,6 +1224,136 @@ multi_relation_restriction_hook(PlannerInfo *root, RelOptInfo *relOptInfo, Index } +/* + * AdjustReadIntermediateResultCost adjusts the row count and total cost + * of a read_intermediate_result call based on the file size. + */ +static void +AdjustReadIntermediateResultCost(RangeTblEntry *rangeTableEntry, RelOptInfo *relOptInfo) +{ + PathTarget *reltarget = relOptInfo->reltarget; + List *pathList = relOptInfo->pathlist; + Path *path = NULL; + RangeTblFunction *rangeTableFunction = NULL; + FuncExpr *funcExpression = NULL; + Const *resultFormatConst = NULL; + Datum resultFormatDatum = 0; + Oid resultFormatId = InvalidOid; + Const *resultIdConst = NULL; + Datum resultIdDatum = 0; + char *resultId = NULL; + int64 resultSize = 0; + ListCell *typeCell = NULL; + bool binaryFormat = false; + double rowCost = 0.; + double rowSizeEstimate = 0; + double rowCountEstimate = 0.; + double ioCost = 0.; + + if (rangeTableEntry->rtekind != RTE_FUNCTION || + list_length(rangeTableEntry->functions) != 1) + { + /* avoid more expensive checks below for non-functions */ + return; + } + + if (!CitusHasBeenLoaded() || !CheckCitusVersion(DEBUG5)) + { + /* read_intermediate_result may not exist */ + return; + } + + if (!ContainsReadIntermediateResultFunction((Node *) rangeTableEntry->functions)) + { + return; + } + + rangeTableFunction = (RangeTblFunction *) linitial(rangeTableEntry->functions); + funcExpression = (FuncExpr *) rangeTableFunction->funcexpr; + resultIdConst = (Const *) linitial(funcExpression->args); + if (!IsA(resultIdConst, Const)) + { + /* not sure how to interpret non-const */ + return; + } + + resultIdDatum = resultIdConst->constvalue; + resultId = TextDatumGetCString(resultIdDatum); + + resultSize = IntermediateResultSize(resultId); + if (resultSize < 0) + { + /* result does not exist, will probably error out later on */ + return; + } + + resultFormatConst = (Const *) lsecond(funcExpression->args); + if (!IsA(resultFormatConst, Const)) + { + /* not sure how to interpret non-const */ + return; + } + + resultFormatDatum = resultFormatConst->constvalue; + resultFormatId = DatumGetObjectId(resultFormatDatum); + + if (resultFormatId == BinaryCopyFormatId()) + { + binaryFormat = true; + + /* subtract 11-byte signature + 8 byte header + 2-byte footer */ + resultSize -= 21; + } + + /* start with the cost of evaluating quals */ + rowCost += relOptInfo->baserestrictcost.per_tuple; + + /* postgres' estimate for the width of the rows */ + rowSizeEstimate += reltarget->width; + + /* add 2 bytes for column count (binary) or line separator (text) */ + rowSizeEstimate += 2; + + foreach(typeCell, rangeTableFunction->funccoltypes) + { + Oid columnTypeId = lfirst_oid(typeCell); + Oid inputFunctionId = InvalidOid; + Oid typeIOParam = InvalidOid; + + if (binaryFormat) + { + getTypeBinaryInputInfo(columnTypeId, &inputFunctionId, &typeIOParam); + + /* binary format: 4 bytes for field size */ + rowSizeEstimate += 4; + } + else + { + getTypeInputInfo(columnTypeId, &inputFunctionId, &typeIOParam); + + /* text format: 1 byte for tab separator */ + rowSizeEstimate += 1; + } + + /* add the cost of parsing a column */ + rowCost += get_func_cost(inputFunctionId) * cpu_operator_cost; + } + + /* estimate the number of rows based on the file size and estimated row size */ + rowCountEstimate = Max(1, (double) resultSize / rowSizeEstimate); + + /* cost of reading the data */ + ioCost = seq_page_cost * resultSize / BLCKSZ; + + Assert(pathList != NIL); + + /* tell the planner about the cost and row count of the function */ + path = (Path *) linitial(pathList); + path->rows = rowCountEstimate; + path->total_cost = rowCountEstimate * rowCost + ioCost; +} + + /* * CopyPlanParamList deep copies the input PlannerParamItem list and returns the newly * allocated list. diff --git a/src/include/distributed/intermediate_results.h b/src/include/distributed/intermediate_results.h index d5be37c8c..e52a273ef 100644 --- a/src/include/distributed/intermediate_results.h +++ b/src/include/distributed/intermediate_results.h @@ -26,6 +26,7 @@ extern DestReceiver * CreateRemoteFileDestReceiver(char *resultId, EState *execu writeLocalFile); extern void ReceiveQueryResultViaCopy(const char *resultId); extern void RemoveIntermediateResultsDirectory(void); +extern int64 IntermediateResultSize(char *resultId); #endif /* INTERMEDIATE_RESULTS_H */ diff --git a/src/test/regress/expected/intermediate_results.out b/src/test/regress/expected/intermediate_results.out index 78271499e..065c27357 100644 --- a/src/test/regress/expected/intermediate_results.out +++ b/src/test/regress/expected/intermediate_results.out @@ -185,6 +185,47 @@ ON ((s).x = interested_in) ORDER BY 1,2; jon | 5 | (5,25) | {"value": 5} (3 rows) +END; +BEGIN; +-- accurate row count estimates for primitive types +SELECT create_intermediate_result('squares', 'SELECT s, s*s FROM generate_series(1,632) s'); + create_intermediate_result +---------------------------- + 632 +(1 row) + +EXPLAIN SELECT * FROM read_intermediate_result('squares', 'binary') AS res (x int, x2 int); + QUERY PLAN +----------------------------------------------------------------------------------- + Function Scan on read_intermediate_result res (cost=0.00..4.55 rows=632 width=8) +(1 row) + +-- less accurate results for variable types +SELECT create_intermediate_result('hellos', $$SELECT s, 'hello-'||s FROM generate_series(1,63) s$$); + create_intermediate_result +---------------------------- + 63 +(1 row) + +EXPLAIN SELECT * FROM read_intermediate_result('hellos', 'binary') AS res (x int, y text); + QUERY PLAN +----------------------------------------------------------------------------------- + Function Scan on read_intermediate_result res (cost=0.00..0.32 rows=30 width=36) +(1 row) + +-- not very accurate results for text encoding +SELECT create_intermediate_result('stored_squares', 'SELECT square FROM stored_squares'); + create_intermediate_result +---------------------------- + 4 +(1 row) + +EXPLAIN SELECT * FROM read_intermediate_result('stored_squares', 'text') AS res (s intermediate_results.square_type); + QUERY PLAN +---------------------------------------------------------------------------------- + Function Scan on read_intermediate_result res (cost=0.00..0.01 rows=1 width=32) +(1 row) + END; -- pipe query output into a result file and create a table to check the result COPY (SELECT s, s*s FROM generate_series(1,5) s) diff --git a/src/test/regress/sql/intermediate_results.sql b/src/test/regress/sql/intermediate_results.sql index 2b23c416f..9f1e2a460 100644 --- a/src/test/regress/sql/intermediate_results.sql +++ b/src/test/regress/sql/intermediate_results.sql @@ -97,6 +97,20 @@ ON ((s).x = interested_in) ORDER BY 1,2; END; +BEGIN; +-- accurate row count estimates for primitive types +SELECT create_intermediate_result('squares', 'SELECT s, s*s FROM generate_series(1,632) s'); +EXPLAIN SELECT * FROM read_intermediate_result('squares', 'binary') AS res (x int, x2 int); + +-- less accurate results for variable types +SELECT create_intermediate_result('hellos', $$SELECT s, 'hello-'||s FROM generate_series(1,63) s$$); +EXPLAIN SELECT * FROM read_intermediate_result('hellos', 'binary') AS res (x int, y text); + +-- not very accurate results for text encoding +SELECT create_intermediate_result('stored_squares', 'SELECT square FROM stored_squares'); +EXPLAIN SELECT * FROM read_intermediate_result('stored_squares', 'text') AS res (s intermediate_results.square_type); +END; + -- pipe query output into a result file and create a table to check the result COPY (SELECT s, s*s FROM generate_series(1,5) s) TO PROGRAM