From 8de8b62669f0251f34859d6d5ce01030ed19846b Mon Sep 17 00:00:00 2001 From: Jelte Fennema Date: Tue, 4 Feb 2020 14:05:25 +0100 Subject: [PATCH] Convert unsafe APIs to safe ones --- .circleci/config.yml | 3 + .ignore | 1 + ci/banned.h.sh | 54 ++++ src/backend/distributed/Makefile | 14 +- src/backend/distributed/commands/collation.c | 10 +- src/backend/distributed/commands/function.c | 10 +- src/backend/distributed/commands/multi_copy.c | 22 +- src/backend/distributed/commands/type.c | 10 +- .../connection/connection_configuration.c | 5 +- .../connection/connection_management.c | 8 +- .../executor/multi_task_tracker_executor.c | 13 +- .../master/master_metadata_utility.c | 5 +- .../distributed/master/worker_node_manager.c | 2 +- .../distributed/metadata/metadata_cache.c | 16 +- .../distributed/metadata/node_metadata.c | 2 +- .../distributed/planner/distributed_planner.c | 2 +- .../planner/multi_router_planner.c | 3 +- .../distributed/planner/recursive_planning.c | 2 +- .../distributed/planner/shard_pruning.c | 2 +- .../distributed/relay/relay_event_utility.c | 17 +- src/backend/distributed/shared_library_init.c | 21 +- .../distributed/transaction/backend_data.c | 2 +- .../transaction/remote_transaction.c | 7 +- .../transaction/transaction_management.c | 3 +- src/backend/distributed/utils/acquire_lock.c | 21 +- .../distributed/utils/citus_safe_lib.c | 302 ++++++++++++++++++ src/backend/distributed/utils/listutils.c | 3 +- src/backend/distributed/utils/maintenanced.c | 17 +- .../distributed/utils/statistics_collection.c | 113 ------- src/backend/distributed/worker/task_tracker.c | 10 +- .../worker/worker_data_fetch_protocol.c | 2 +- src/include/distributed/citus_safe_lib.h | 29 ++ 32 files changed, 528 insertions(+), 203 deletions(-) create mode 100644 .ignore create mode 100755 ci/banned.h.sh create mode 100644 src/backend/distributed/utils/citus_safe_lib.c create mode 100644 src/include/distributed/citus_safe_lib.h diff --git a/.circleci/config.yml 
b/.circleci/config.yml index 015983c23..2376a49d0 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -40,6 +40,9 @@ jobs: - run: name: 'Check if changed' command: git diff --exit-code + - run: + name: 'Check for banned C API usage' + command: ci/banned.h.sh check-sql-snapshots: docker: - image: 'citus/extbuilder:latest' diff --git a/.ignore b/.ignore new file mode 100644 index 000000000..61ead8666 --- /dev/null +++ b/.ignore @@ -0,0 +1 @@ +/vendor diff --git a/ci/banned.h.sh b/ci/banned.h.sh new file mode 100755 index 000000000..6b91e6072 --- /dev/null +++ b/ci/banned.h.sh @@ -0,0 +1,54 @@ +#!/bin/sh + +# Checks for the APIs that are banned by Microsoft. Since we compile for Linux +# we use the replacements from https://github.com/intel/safestringlib +# Not all replacement functions are available in safestringlib. If it doesn't +# exist and you cannot rewrite the code to not use the banned API, then you can +# add a comment containing "IGNORE-BANNED" to the line where the error is and +# this check will ignore that match. +# +# The replacement functions that you should use are listed here: +# https://liquid.microsoft.com/Web/Object/Read/ms.security/Requirements/Microsoft.Security.SystemsADM.10082#guide + +set -eu + +files=$(find src -iname '*.[ch]' | git check-attr --stdin citus-style | grep -v ': unset$' | sed 's/: citus-style: set$//') + +# grep is allowed to fail, that means no banned matches are found +set +e +# Required banned from banned.h. These functions are not allowed to be used at +# all. 
+# shellcheck disable=SC2086 +grep -E '\b(strcpy|strcpyA|strcpyW|wcscpy|_tcscpy|_mbscpy|StrCpy|StrCpyA|StrCpyW|lstrcpy|lstrcpyA|lstrcpyW|_tccpy|_mbccpy|_ftcscpy|strcat|strcatA|strcatW|wcscat|_tcscat|_mbscat|StrCat|StrCatA|StrCatW|lstrcat|lstrcatA|lstrcatW|StrCatBuff|StrCatBuffA|StrCatBuffW|StrCatChainW|_tccat|_mbccat|_ftcscat|sprintfW|sprintfA|wsprintf|wsprintfW|wsprintfA|sprintf|swprintf|_stprintf|wvsprintf|wvsprintfA|wvsprintfW|vsprintf|_vstprintf|vswprintf|strncpy|wcsncpy|_tcsncpy|_mbsncpy|_mbsnbcpy|StrCpyN|StrCpyNA|StrCpyNW|StrNCpy|strcpynA|StrNCpyA|StrNCpyW|lstrcpyn|lstrcpynA|lstrcpynW|strncat|wcsncat|_tcsncat|_mbsncat|_mbsnbcat|StrCatN|StrCatNA|StrCatNW|StrNCat|StrNCatA|StrNCatW|lstrncat|lstrcatnA|lstrcatnW|lstrcatn|gets|_getts|_gettws|IsBadWritePtr|IsBadHugeWritePtr|IsBadReadPtr|IsBadHugeReadPtr|IsBadCodePtr|IsBadStringPtr|memcpy|RtlCopyMemory|CopyMemory|wmemcpy|lstrlen)\(' $files \ + | grep -v "IGNORE-BANNED" \ + && echo "ERROR: Required banned API usage detected" && exit 1 + +# Required banned from table on liquid. These functions are not allowed to be +# used at all. +# shellcheck disable=SC2086 +grep -E '\b(strcat|strcpy|strerror|strncat|strncpy|strtok|wcscat|wcscpy|wcsncat|wcsncpy|wcstok|fprintf|fwprintf|printf|snprintf|sprintf|swprintf|vfprintf|vprintf|vsnprintf|vsprintf|vswprintf|vwprintf|wprintf|fscanf|fwscanf|gets|scanf|sscanf|swscanf|vfscanf|vfwscanf|vscanf|vsscanf|vswscanf|vwscanf|wscanf|asctime|atof|atoi|atol|atoll|bsearch|ctime|fopen|freopen|getenv|gmtime|localtime|mbsrtowcs|mbstowcs|memcpy|memmove|qsort|rewind|setbuf|wmemcpy|wmemmove)\(' $files \ + | grep -v "IGNORE-BANNED" \ + && echo "ERROR: Required banned API usage from table detected" && exit 1 + +# Recommended banned from banned.h. If you can change the code not to use these +# that would be great. You can use IGNORE-BANNED if you need to use it anyway. 
+# You can also remove it from the regex, if you want to mark the API as allowed +# throughout the codebase (to not have to add IGNORE-BANNED everywhere). In +# that case note it in this comment that you did so. +# shellcheck disable=SC2086 +grep -E '\b(wnsprintf|wnsprintfA|wnsprintfW|_snwprintf|_snprintf|_sntprintf|_vsnprintf|vsnprintf|_vsnwprintf|_vsntprintf|wvnsprintf|wvnsprintfA|wvnsprintfW|strtok|_tcstok|wcstok|_mbstok|makepath|_tmakepath| _makepath|_wmakepath|_splitpath|_tsplitpath|_wsplitpath|scanf|wscanf|_tscanf|sscanf|swscanf|_stscanf|snscanf|snwscanf|_sntscanf|_itoa|_itow|_i64toa|_i64tow|_ui64toa|_ui64tot|_ui64tow|_ultoa|_ultot|_ultow|CharToOem|CharToOemA|CharToOemW|OemToChar|OemToCharA|OemToCharW|CharToOemBuffA|CharToOemBuffW|alloca|_alloca|ChangeWindowMessageFilter)\(' $files \ + | grep -v "IGNORE-BANNED" \ + && echo "ERROR: Recomended banned API usage detected" && exit 1 + +# Recommended banned from table on liquid. If you can change the code not to use these +# that would be great. You can use IGNORE-BANNED if you need to use it anyway. +# You can also remove it from the regex, if you want to mark the API as allowed +# throughout the codebase (to not have to add IGNORE-BANNED everywhere). In +# that case note it in this comment that you did so. +# Banned APIs ignored throughout the codebase: +# - strlen +# shellcheck disable=SC2086 +grep -E '\b(alloca|getwd|mktemp|tmpnam|wcrtomb|wcrtombs|wcslen|wcsrtombs|wcstombs|wctomb|class_addMethod|class_replaceMethod)\(' $files \ + | grep -v "IGNORE-BANNED" \ + && echo "ERROR: Recomended banned API usage detected" && exit 1 +exit 0 diff --git a/src/backend/distributed/Makefile b/src/backend/distributed/Makefile index 76cf0acbb..d5356c0ba 100644 --- a/src/backend/distributed/Makefile +++ b/src/backend/distributed/Makefile @@ -2,6 +2,10 @@ citus_subdir = src/backend/distributed citus_top_builddir = ../../.. 
+safestringlib_srcdir = $(citus_abs_top_srcdir)/vendor/safestringlib +safestringlib_builddir = $(citus_top_builddir)/vendor/safestringlib/build +safestringlib_a = $(safestringlib_builddir)/libsafestring_static.a +safestringlib_sources = $(wildcard $(safestringlib_srcdir)/safeclib/*) MODULE_big = citus EXTENSION = citus @@ -39,11 +43,19 @@ utils/citus_version.o: $(CITUS_VERSION_INVALIDATE) SHLIB_LINK += $(filter -lssl -lcrypto -lssleay32 -leay32, $(LIBS)) -override CPPFLAGS += -I$(libpq_srcdir) +override LDFLAGS += $(safestringlib_a) +override CPPFLAGS += -I$(libpq_srcdir) -I$(safestringlib_srcdir)/include SQL_DEPDIR=.deps/sql SQL_BUILDDIR=build/sql +$(safestringlib_a): $(safestringlib_sources) + rm -rf $(safestringlib_builddir) + mkdir -p $(safestringlib_builddir) + cd $(safestringlib_builddir) && cmake $(safestringlib_srcdir) && make -j5 + +citus.so: $(safestringlib_a) + $(generated_sql_files): $(citus_abs_srcdir)/build/%: % @mkdir -p $(citus_abs_srcdir)/$(SQL_DEPDIR) $(citus_abs_srcdir)/$(SQL_BUILDDIR) cd $(citus_abs_srcdir) && cpp -undef -w -P -MMD -MP -MF$(SQL_DEPDIR)/$(*F).Po -MT$@ $< > $@ diff --git a/src/backend/distributed/commands/collation.c b/src/backend/distributed/commands/collation.c index af0eb6bcf..8b7063cf1 100644 --- a/src/backend/distributed/commands/collation.c +++ b/src/backend/distributed/commands/collation.c @@ -13,6 +13,7 @@ #include "access/htup_details.h" #include "access/xact.h" #include "catalog/pg_collation.h" +#include "distributed/citus_safe_lib.h" #include "distributed/commands/utility_hook.h" #include "distributed/commands.h" #include "distributed/deparser.h" @@ -529,16 +530,17 @@ GenerateBackupNameForCollationCollision(const ObjectAddress *address) while (true) { - int suffixLength = snprintf(suffix, NAMEDATALEN - 1, "(citus_backup_%d)", - count); + int suffixLength = SafeSnprintf(suffix, NAMEDATALEN - 1, "(citus_backup_%d)", + count); /* trim the base name at the end to leave space for the suffix and trailing \0 */ baseLength = 
Min(baseLength, NAMEDATALEN - suffixLength - 1); /* clear newName before copying the potentially trimmed baseName and suffix */ memset(newName, 0, NAMEDATALEN); - strncpy(newName, baseName, baseLength); - strncpy(newName + baseLength, suffix, suffixLength); + strncpy_s(newName, NAMEDATALEN, baseName, baseLength); + strncpy_s(newName + baseLength, NAMEDATALEN - baseLength, suffix, + suffixLength); List *newCollationName = list_make2(namespace, makeString(newName)); diff --git a/src/backend/distributed/commands/function.c b/src/backend/distributed/commands/function.c index 9bffb1bac..da6cefa3f 100644 --- a/src/backend/distributed/commands/function.c +++ b/src/backend/distributed/commands/function.c @@ -30,6 +30,7 @@ #include "catalog/pg_type.h" #include "commands/extension.h" #include "distributed/citus_ruleutils.h" +#include "distributed/citus_safe_lib.h" #include "distributed/colocation_utils.h" #include "distributed/commands.h" #include "distributed/commands/utility_hook.h" @@ -1715,16 +1716,17 @@ GenerateBackupNameForProcCollision(const ObjectAddress *address) while (true) { - int suffixLength = snprintf(suffix, NAMEDATALEN - 1, "(citus_backup_%d)", - count); + int suffixLength = SafeSnprintf(suffix, NAMEDATALEN - 1, "(citus_backup_%d)", + count); /* trim the base name at the end to leave space for the suffix and trailing \0 */ baseLength = Min(baseLength, NAMEDATALEN - suffixLength - 1); /* clear newName before copying the potentially trimmed baseName and suffix */ memset(newName, 0, NAMEDATALEN); - strncpy(newName, baseName, baseLength); - strncpy(newName + baseLength, suffix, suffixLength); + strncpy_s(newName, NAMEDATALEN, baseName, baseLength); + strncpy_s(newName + baseLength, NAMEDATALEN - baseLength, suffix, + suffixLength); List *newProcName = list_make2(namespace, makeString(newName)); diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c index 108739832..d6d3cbee1 100644 --- 
a/src/backend/distributed/commands/multi_copy.c +++ b/src/backend/distributed/commands/multi_copy.c @@ -65,6 +65,7 @@ #include "catalog/pg_type.h" #include "commands/copy.h" #include "commands/defrem.h" +#include "distributed/citus_safe_lib.h" #include "distributed/commands/multi_copy.h" #include "distributed/commands/utility_hook.h" #include "distributed/intermediate_results.h" @@ -432,9 +433,8 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag) * There is no need to deep copy everything. We will just deep copy of the fields * we will change. */ - memcpy(copiedDistributedRelation, distributedRelation, sizeof(RelationData)); - memcpy(copiedDistributedRelationTuple, distributedRelation->rd_rel, - CLASS_TUPLE_SIZE); + *copiedDistributedRelation = *distributedRelation; + *copiedDistributedRelationTuple = *distributedRelation->rd_rel; copiedDistributedRelation->rd_rel = copiedDistributedRelationTuple; copiedDistributedRelation->rd_att = CreateTupleDescCopyConstr(tupleDescriptor); @@ -511,8 +511,8 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag) if (completionTag != NULL) { - snprintf(completionTag, COMPLETION_TAG_BUFSIZE, - "COPY " UINT64_FORMAT, processedRowCount); + SafeSnprintf(completionTag, COMPLETION_TAG_BUFSIZE, + "COPY " UINT64_FORMAT, processedRowCount); } } @@ -696,8 +696,8 @@ CopyToNewShards(CopyStmt *copyStatement, char *completionTag, Oid relationId) if (completionTag != NULL) { - snprintf(completionTag, COMPLETION_TAG_BUFSIZE, - "COPY " UINT64_FORMAT, processedRowCount); + SafeSnprintf(completionTag, COMPLETION_TAG_BUFSIZE, + "COPY " UINT64_FORMAT, processedRowCount); } } @@ -2693,8 +2693,8 @@ ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag, const char *queryS int64 tuplesSent = WorkerExecuteSqlTask(query, filename, binaryCopyFormat); - snprintf(completionTag, COMPLETION_TAG_BUFSIZE, - "COPY " UINT64_FORMAT, tuplesSent); + SafeSnprintf(completionTag, COMPLETION_TAG_BUFSIZE, + "COPY " UINT64_FORMAT, 
tuplesSent); return NULL; } @@ -2795,8 +2795,8 @@ CitusCopyTo(CopyStmt *copyStatement, char *completionTag) if (completionTag != NULL) { - snprintf(completionTag, COMPLETION_TAG_BUFSIZE, "COPY " UINT64_FORMAT, - tuplesSent); + SafeSnprintf(completionTag, COMPLETION_TAG_BUFSIZE, "COPY " UINT64_FORMAT, + tuplesSent); } } diff --git a/src/backend/distributed/commands/type.c b/src/backend/distributed/commands/type.c index fec7e68fb..2e46a0f41 100644 --- a/src/backend/distributed/commands/type.c +++ b/src/backend/distributed/commands/type.c @@ -50,6 +50,7 @@ #include "catalog/pg_enum.h" #include "catalog/pg_type.h" #include "commands/extension.h" +#include "distributed/citus_safe_lib.h" #include "distributed/commands.h" #include "distributed/commands/utility_hook.h" #include "distributed/deparser.h" @@ -1075,16 +1076,17 @@ GenerateBackupNameForTypeCollision(const ObjectAddress *address) while (true) { - int suffixLength = snprintf(suffix, NAMEDATALEN - 1, "(citus_backup_%d)", - count); + int suffixLength = SafeSnprintf(suffix, NAMEDATALEN - 1, "(citus_backup_%d)", + count); /* trim the base name at the end to leave space for the suffix and trailing \0 */ baseLength = Min(baseLength, NAMEDATALEN - suffixLength - 1); /* clear newName before copying the potentially trimmed baseName and suffix */ memset(newName, 0, NAMEDATALEN); - strncpy(newName, baseName, baseLength); - strncpy(newName + baseLength, suffix, suffixLength); + strncpy_s(newName, NAMEDATALEN, baseName, baseLength); + strncpy_s(newName + baseLength, NAMEDATALEN - baseLength, suffix, + suffixLength); rel->relname = newName; TypeName *newTypeName = makeTypeNameFromNameList(MakeNameListFromRangeVar(rel)); diff --git a/src/backend/distributed/connection/connection_configuration.c b/src/backend/distributed/connection/connection_configuration.c index 9d1e423e9..8e21fcee0 100644 --- a/src/backend/distributed/connection/connection_configuration.c +++ b/src/backend/distributed/connection/connection_configuration.c @@ 
-10,6 +10,7 @@ #include "postgres.h" +#include "distributed/citus_safe_lib.h" #include "distributed/connection_management.h" #include "distributed/metadata_cache.h" #include "distributed/worker_manager.h" @@ -189,8 +190,8 @@ CheckConninfo(const char *conninfo, const char **whitelist, continue; } - void *matchingKeyword = bsearch(&option->keyword, whitelist, whitelistLength, - sizeof(char *), pg_qsort_strcmp); + void *matchingKeyword = SafeBsearch(&option->keyword, whitelist, whitelistLength, + sizeof(char *), pg_qsort_strcmp); if (matchingKeyword == NULL) { /* the whitelist lacks this keyword; error out! */ diff --git a/src/backend/distributed/connection/connection_management.c b/src/backend/distributed/connection/connection_management.c index 359431eb6..3e8d992f2 100644 --- a/src/backend/distributed/connection/connection_management.c +++ b/src/backend/distributed/connection/connection_management.c @@ -15,6 +15,8 @@ #include "miscadmin.h" +#include "safe_lib.h" + #include "access/hash.h" #include "commands/dbcommands.h" #include "distributed/connection_management.h" @@ -109,7 +111,8 @@ InitializeConnectionManagement(void) info.hcxt = ConnectionContext; uint32 hashFlags = (HASH_ELEM | HASH_FUNCTION | HASH_CONTEXT | HASH_COMPARE); - memcpy(&connParamsInfo, &info, sizeof(HASHCTL)); + /* connParamsInfo is same as info, except for entrysize */ + connParamsInfo = info; connParamsInfo.entrysize = sizeof(ConnParamsHashEntry); ConnectionHash = hash_create("citus connection cache (host,port,user,database)", @@ -1223,7 +1226,8 @@ DefaultCitusNoticeProcessor(void *arg, const char *message) char *nodeName = connection->hostname; uint32 nodePort = connection->port; char *trimmedMessage = TrimLogLevel(message); - char *level = strtok((char *) message, ":"); + char *strtokPosition; + char *level = strtok_r((char *) message, ":", &strtokPosition); ereport(CitusNoticeLogLevel, (errmsg("%s", ApplyLogRedaction(trimmedMessage)), diff --git 
a/src/backend/distributed/executor/multi_task_tracker_executor.c b/src/backend/distributed/executor/multi_task_tracker_executor.c index 459406e45..77da0a6e0 100644 --- a/src/backend/distributed/executor/multi_task_tracker_executor.c +++ b/src/backend/distributed/executor/multi_task_tracker_executor.c @@ -692,16 +692,17 @@ TrackerHash(const char *taskTrackerHashName, List *workerNodeList, char *userNam char *nodeName = workerNode->workerName; uint32 nodePort = workerNode->workerPort; - char taskStateHashName[MAXPGPATH]; uint32 taskStateCount = 32; HASHCTL info; /* insert task tracker into the tracker hash */ TaskTracker *taskTracker = TrackerHashEnter(taskTrackerHash, nodeName, nodePort); + /* for each task tracker, create hash to track its assigned tasks */ - snprintf(taskStateHashName, MAXPGPATH, - "Task Tracker \"%s:%u\" Task State Hash", nodeName, nodePort); + StringInfo taskStateHashName = makeStringInfo(); + appendStringInfo(taskStateHashName, "Task Tracker \"%s:%u\" Task State Hash", + nodeName, nodePort); memset(&info, 0, sizeof(info)); info.keysize = sizeof(uint64) + sizeof(uint32); @@ -710,12 +711,12 @@ TrackerHash(const char *taskTrackerHashName, List *workerNodeList, char *userNam info.hcxt = CurrentMemoryContext; int hashFlags = (HASH_ELEM | HASH_FUNCTION | HASH_CONTEXT); - HTAB *taskStateHash = hash_create(taskStateHashName, taskStateCount, &info, + HTAB *taskStateHash = hash_create(taskStateHashName->data, taskStateCount, &info, hashFlags); if (taskStateHash == NULL) { ereport(FATAL, (errcode(ERRCODE_OUT_OF_MEMORY), - errmsg("could not initialize %s", taskStateHashName))); + errmsg("could not initialize %s", taskStateHashName->data))); } taskTracker->taskStateHash = taskStateHash; @@ -781,7 +782,7 @@ TrackerHashEnter(HTAB *taskTrackerHash, char *nodeName, uint32 nodePort) } /* init task tracker object with zeroed out task tracker key */ - memcpy(taskTracker, &taskTrackerKey, sizeof(TaskTracker)); + *taskTracker = taskTrackerKey; 
taskTracker->trackerStatus = TRACKER_CONNECT_START; taskTracker->connectionId = INVALID_CONNECTION_ID; taskTracker->currentTaskIndex = -1; diff --git a/src/backend/distributed/master/master_metadata_utility.c b/src/backend/distributed/master/master_metadata_utility.c index 88d9a1161..ca3c162ed 100644 --- a/src/backend/distributed/master/master_metadata_utility.c +++ b/src/backend/distributed/master/master_metadata_utility.c @@ -31,6 +31,7 @@ #include "distributed/colocation_utils.h" #include "distributed/connection_management.h" #include "distributed/citus_nodes.h" +#include "distributed/citus_safe_lib.h" #include "distributed/listutils.h" #include "distributed/master_metadata_utility.h" #include "distributed/master_protocol.h" @@ -230,7 +231,7 @@ DistributedTableSizeOnWorker(WorkerNode *workerNode, Oid relationId, char *sizeQ List *sizeList = ReadFirstColumnAsText(result); StringInfo tableSizeStringInfo = (StringInfo) linitial(sizeList); char *tableSizeString = tableSizeStringInfo->data; - uint64 tableSize = atol(tableSizeString); + uint64 tableSize = SafeStringToUint64(tableSizeString); PQclear(result); ClearResults(connection, raiseErrors); @@ -608,7 +609,7 @@ void CopyShardPlacement(ShardPlacement *srcPlacement, ShardPlacement *destPlacement) { /* first copy all by-value fields */ - memcpy(destPlacement, srcPlacement, sizeof(ShardPlacement)); + *destPlacement = *srcPlacement; /* and then the fields pointing to external values */ if (srcPlacement->nodeName) diff --git a/src/backend/distributed/master/worker_node_manager.c b/src/backend/distributed/master/worker_node_manager.c index 63924ba00..f06dafa12 100644 --- a/src/backend/distributed/master/worker_node_manager.c +++ b/src/backend/distributed/master/worker_node_manager.c @@ -355,7 +355,7 @@ FilterActiveNodeListFunc(LOCKMODE lockMode, bool (*checkFunction)(WorkerNode *)) if (workerNode->isActive && checkFunction(workerNode)) { WorkerNode *workerNodeCopy = palloc0(sizeof(WorkerNode)); - memcpy(workerNodeCopy, 
workerNode, sizeof(WorkerNode)); + *workerNodeCopy = *workerNode; workerNodeList = lappend(workerNodeList, workerNodeCopy); } } diff --git a/src/backend/distributed/metadata/metadata_cache.c b/src/backend/distributed/metadata/metadata_cache.c index f455b2ecd..712821624 100644 --- a/src/backend/distributed/metadata/metadata_cache.c +++ b/src/backend/distributed/metadata/metadata_cache.c @@ -438,7 +438,7 @@ LoadGroupShardPlacement(uint64 shardId, uint64 placementId) { GroupShardPlacement *shardPlacement = CitusMakeNode(GroupShardPlacement); - memcpy(shardPlacement, &placementArray[i], sizeof(GroupShardPlacement)); + *shardPlacement = placementArray[i]; return shardPlacement; } @@ -513,9 +513,11 @@ ResolveGroupShardPlacement(GroupShardPlacement *groupShardPlacement, WorkerNode *workerNode = LookupNodeForGroup(groupId); /* copy everything into shardPlacement but preserve the header */ - memcpy((((CitusNode *) shardPlacement) + 1), - (((CitusNode *) groupShardPlacement) + 1), - sizeof(GroupShardPlacement) - sizeof(CitusNode)); + CitusNode header = shardPlacement->type; + GroupShardPlacement *shardPlacementAsGroupPlacement = + (GroupShardPlacement *) shardPlacement; + *shardPlacementAsGroupPlacement = *groupShardPlacement; + shardPlacement->type = header; shardPlacement->nodeName = pstrdup(workerNode->workerName); shardPlacement->nodePort = workerNode->workerPort; @@ -561,7 +563,7 @@ LookupNodeByNodeId(uint32 nodeId) if (workerNode->nodeId == nodeId) { WorkerNode *workerNodeCopy = palloc0(sizeof(WorkerNode)); - memcpy(workerNodeCopy, workerNode, sizeof(WorkerNode)); + *workerNodeCopy = *workerNode; return workerNodeCopy; } @@ -3597,7 +3599,7 @@ LookupDistPartitionTuple(Relation pgDistPartition, Oid relationId) ScanKeyData scanKey[1]; /* copy scankey to local copy, it will be modified during the scan */ - memcpy(scanKey, DistPartitionScanKey, sizeof(DistPartitionScanKey)); + scanKey[0] = DistPartitionScanKey[0]; /* set scan arguments */ scanKey[0].sk_argument = 
ObjectIdGetDatum(relationId); @@ -3631,7 +3633,7 @@ LookupDistShardTuples(Oid relationId) Relation pgDistShard = heap_open(DistShardRelationId(), AccessShareLock); /* copy scankey to local copy, it will be modified during the scan */ - memcpy(scanKey, DistShardScanKey, sizeof(DistShardScanKey)); + scanKey[0] = DistShardScanKey[0]; /* set scan arguments */ scanKey[0].sk_argument = ObjectIdGetDatum(relationId); diff --git a/src/backend/distributed/metadata/node_metadata.c b/src/backend/distributed/metadata/node_metadata.c index dd6e26119..28eece83c 100644 --- a/src/backend/distributed/metadata/node_metadata.c +++ b/src/backend/distributed/metadata/node_metadata.c @@ -873,7 +873,7 @@ FindWorkerNode(char *nodeName, int32 nodePort) if (handleFound) { WorkerNode *workerNode = (WorkerNode *) palloc(sizeof(WorkerNode)); - memcpy(workerNode, cachedWorkerNode, sizeof(WorkerNode)); + *workerNode = *cachedWorkerNode; return workerNode; } diff --git a/src/backend/distributed/planner/distributed_planner.c b/src/backend/distributed/planner/distributed_planner.c index e0d25cb97..63d970c2a 100644 --- a/src/backend/distributed/planner/distributed_planner.c +++ b/src/backend/distributed/planner/distributed_planner.c @@ -1029,7 +1029,7 @@ CreateDistributedPlan(uint64 planId, Query *originalQuery, Query *query, ParamLi standard_planner(newQuery, 0, boundParams); /* overwrite the old transformed query with the new transformed query */ - memcpy(query, newQuery, sizeof(Query)); + *query = *newQuery; /* recurse into CreateDistributedPlan with subqueries/CTEs replaced */ distributedPlan = CreateDistributedPlan(planId, originalQuery, query, NULL, false, diff --git a/src/backend/distributed/planner/multi_router_planner.c b/src/backend/distributed/planner/multi_router_planner.c index f0c371bf9..f953fb105 100644 --- a/src/backend/distributed/planner/multi_router_planner.c +++ b/src/backend/distributed/planner/multi_router_planner.c @@ -3294,8 +3294,7 @@ 
CopyRelationRestrictionContext(RelationRestrictionContext *oldContext) /* can't be copied, we copy (flatly) a RelOptInfo, and then decouple baserestrictinfo */ newRestriction->relOptInfo = palloc(sizeof(RelOptInfo)); - memcpy(newRestriction->relOptInfo, oldRestriction->relOptInfo, - sizeof(RelOptInfo)); + *newRestriction->relOptInfo = *oldRestriction->relOptInfo; newRestriction->relOptInfo->baserestrictinfo = copyObject(oldRestriction->relOptInfo->baserestrictinfo); diff --git a/src/backend/distributed/planner/recursive_planning.c b/src/backend/distributed/planner/recursive_planning.c index 4548f5499..89a3e3fd5 100644 --- a/src/backend/distributed/planner/recursive_planning.c +++ b/src/backend/distributed/planner/recursive_planning.c @@ -1142,7 +1142,7 @@ RecursivelyPlanSubquery(Query *subquery, RecursivePlanningContext *planningConte } /* finally update the input subquery to point the result query */ - memcpy(subquery, resultQuery, sizeof(Query)); + *subquery = *resultQuery; } diff --git a/src/backend/distributed/planner/shard_pruning.c b/src/backend/distributed/planner/shard_pruning.c index 2fbbc6d0c..24d909b83 100644 --- a/src/backend/distributed/planner/shard_pruning.c +++ b/src/backend/distributed/planner/shard_pruning.c @@ -981,7 +981,7 @@ CopyPartialPruningInstance(PruningInstance *sourceInstance) * being partial - if necessary it'll be marked so again by * PrunableExpressionsWalker(). 
*/ - memcpy(newInstance, sourceInstance, sizeof(PruningInstance)); + *newInstance = *sourceInstance; newInstance->addedToPruningInstances = false; newInstance->isPartial = false; diff --git a/src/backend/distributed/relay/relay_event_utility.c b/src/backend/distributed/relay/relay_event_utility.c index f59fffc0f..deb2d3e0c 100644 --- a/src/backend/distributed/relay/relay_event_utility.c +++ b/src/backend/distributed/relay/relay_event_utility.c @@ -29,6 +29,7 @@ #include "catalog/namespace.h" #include "catalog/pg_class.h" #include "catalog/pg_constraint.h" +#include "distributed/citus_safe_lib.h" #include "distributed/commands.h" #include "distributed/metadata_cache.h" #include "distributed/relay_utility.h" @@ -694,8 +695,8 @@ AppendShardIdToName(char **name, uint64 shardId) NAMEDATALEN))); } - snprintf(shardIdAndSeparator, NAMEDATALEN, "%c" UINT64_FORMAT, - SHARD_NAME_SEPARATOR, shardId); + SafeSnprintf(shardIdAndSeparator, NAMEDATALEN, "%c" UINT64_FORMAT, + SHARD_NAME_SEPARATOR, shardId); int shardIdAndSeparatorLength = strlen(shardIdAndSeparator); /* @@ -705,7 +706,7 @@ AppendShardIdToName(char **name, uint64 shardId) if (nameLength < (NAMEDATALEN - shardIdAndSeparatorLength)) { - snprintf(extendedName, NAMEDATALEN, "%s%s", (*name), shardIdAndSeparator); + SafeSnprintf(extendedName, NAMEDATALEN, "%s%s", (*name), shardIdAndSeparator); } /* @@ -739,14 +740,14 @@ AppendShardIdToName(char **name, uint64 shardId) multiByteClipLength = pg_mbcliplen(*name, nameLength, (NAMEDATALEN - shardIdAndSeparatorLength - 10)); - snprintf(extendedName, NAMEDATALEN, "%.*s%c%.8x%s", - multiByteClipLength, (*name), - SHARD_NAME_SEPARATOR, longNameHash, - shardIdAndSeparator); + SafeSnprintf(extendedName, NAMEDATALEN, "%.*s%c%.8x%s", + multiByteClipLength, (*name), + SHARD_NAME_SEPARATOR, longNameHash, + shardIdAndSeparator); } (*name) = (char *) repalloc((*name), NAMEDATALEN); - int neededBytes = snprintf((*name), NAMEDATALEN, "%s", extendedName); + int neededBytes = 
SafeSnprintf((*name), NAMEDATALEN, "%s", extendedName); if (neededBytes < 0) { ereport(ERROR, (errcode(ERRCODE_OUT_OF_MEMORY), diff --git a/src/backend/distributed/shared_library_init.c b/src/backend/distributed/shared_library_init.c index b0ae355f8..c82f64553 100644 --- a/src/backend/distributed/shared_library_init.c +++ b/src/backend/distributed/shared_library_init.c @@ -13,7 +13,7 @@ #include #include -/* necessary to get alloca() on illumos */ +/* necessary to get alloca on illumos */ #ifdef __sun #include #endif @@ -21,11 +21,14 @@ #include "fmgr.h" #include "miscadmin.h" +#include "safe_lib.h" + #include "citus_version.h" #include "commands/explain.h" #include "executor/executor.h" #include "distributed/backend_data.h" #include "distributed/citus_nodefuncs.h" +#include "distributed/citus_safe_lib.h" #include "distributed/commands.h" #include "distributed/commands/multi_copy.h" #include "distributed/commands/utility_hook.h" @@ -187,6 +190,14 @@ _PG_init(void) "shared_preload_libraries."))); } + /* + * Register constraint_handler hooks of safestringlib first. This way + * loading the extension will error out if one of these constraints is hit + * during load. + */ + set_str_constraint_handler_s(ereport_constraint_handler); + set_mem_constraint_handler_s(ereport_constraint_handler); + /* * Perform checks before registering any hooks, to avoid erroring out in a * partial state. @@ -290,7 +301,13 @@ ResizeStackToMaximumDepth(void) #ifndef WIN32 long max_stack_depth_bytes = max_stack_depth * 1024L; - volatile char *stack_resizer = alloca(max_stack_depth_bytes); + /* + * Explanation of IGNORE-BANNED: + * alloca is safe to use here since we limit the allocated size. We cannot + * use malloc as a replacement, since we actually want to grow the stack + * here. 
+ */ + volatile char *stack_resizer = alloca(max_stack_depth_bytes); /* IGNORE-BANNED */ /* * Different architectures might have different directions while diff --git a/src/backend/distributed/transaction/backend_data.c b/src/backend/distributed/transaction/backend_data.c index 27222e69e..268d9b521 100644 --- a/src/backend/distributed/transaction/backend_data.c +++ b/src/backend/distributed/transaction/backend_data.c @@ -806,7 +806,7 @@ GetBackendDataForProc(PGPROC *proc, BackendData *result) SpinLockAcquire(&backendData->mutex); - memcpy(result, backendData, sizeof(BackendData)); + *result = *backendData; SpinLockRelease(&backendData->mutex); } diff --git a/src/backend/distributed/transaction/remote_transaction.c b/src/backend/distributed/transaction/remote_transaction.c index 34e1cb15f..6c4a011a7 100644 --- a/src/backend/distributed/transaction/remote_transaction.c +++ b/src/backend/distributed/transaction/remote_transaction.c @@ -16,6 +16,7 @@ #include "access/xact.h" #include "distributed/backend_data.h" +#include "distributed/citus_safe_lib.h" #include "distributed/connection_management.h" #include "distributed/metadata_cache.h" #include "distributed/remote_commands.h" @@ -1330,9 +1331,9 @@ Assign2PCIdentifier(MultiConnection *connection) uint64 transactionNumber = CurrentDistributedTransactionNumber(); /* print all numbers as unsigned to guarantee no minus symbols appear in the name */ - snprintf(connection->remoteTransaction.preparedName, NAMEDATALEN, - PREPARED_TRANSACTION_NAME_FORMAT, GetLocalGroupId(), MyProcPid, - transactionNumber, connectionNumber++); + SafeSnprintf(connection->remoteTransaction.preparedName, NAMEDATALEN, + PREPARED_TRANSACTION_NAME_FORMAT, GetLocalGroupId(), MyProcPid, + transactionNumber, connectionNumber++); } diff --git a/src/backend/distributed/transaction/transaction_management.c b/src/backend/distributed/transaction/transaction_management.c index b5bedc3ac..a547ddec7 100644 --- 
a/src/backend/distributed/transaction/transaction_management.c +++ b/src/backend/distributed/transaction/transaction_management.c @@ -20,6 +20,7 @@ #include "access/twophase.h" #include "access/xact.h" #include "distributed/backend_data.h" +#include "distributed/citus_safe_lib.h" #include "distributed/connection_management.h" #include "distributed/distributed_planner.h" #include "distributed/hash_helpers.h" @@ -538,7 +539,7 @@ AdjustMaxPreparedTransactions(void) { char newvalue[12]; - snprintf(newvalue, sizeof(newvalue), "%d", MaxConnections * 2); + SafeSnprintf(newvalue, sizeof(newvalue), "%d", MaxConnections * 2); SetConfigOption("max_prepared_transactions", newvalue, PGC_POSTMASTER, PGC_S_OVERRIDE); diff --git a/src/backend/distributed/utils/acquire_lock.c b/src/backend/distributed/utils/acquire_lock.c index 229340c97..dee96fb68 100644 --- a/src/backend/distributed/utils/acquire_lock.c +++ b/src/backend/distributed/utils/acquire_lock.c @@ -34,6 +34,7 @@ #include "utils/snapmgr.h" #include "distributed/citus_acquire_lock.h" +#include "distributed/citus_safe_lib.h" #include "distributed/connection_management.h" #include "distributed/version_compat.h" @@ -75,27 +76,21 @@ StartLockAcquireHelperBackgroundWorker(int backendToHelp, int32 lock_cooldown) args.lock_cooldown = lock_cooldown; /* construct the background worker and start it */ - snprintf(worker.bgw_name, BGW_MAXLEN, - "Citus Lock Acquire Helper: %d/%u", - backendToHelp, MyDatabaseId); - snprintf(worker.bgw_type, BGW_MAXLEN, "citus_lock_aqcuire"); + SafeSnprintf(worker.bgw_name, sizeof(worker.bgw_name), + "Citus Lock Acquire Helper: %d/%u", backendToHelp, MyDatabaseId); + strcpy_s(worker.bgw_type, sizeof(worker.bgw_type), "citus_lock_aqcuire"); worker.bgw_flags = BGWORKER_SHMEM_ACCESS | BGWORKER_BACKEND_DATABASE_CONNECTION; worker.bgw_start_time = BgWorkerStart_RecoveryFinished; worker.bgw_restart_time = BGW_NEVER_RESTART; - snprintf(worker.bgw_library_name, BGW_MAXLEN, "citus"); - 
snprintf(worker.bgw_function_name, BGW_MAXLEN, "LockAcquireHelperMain"); + strcpy_s(worker.bgw_library_name, sizeof(worker.bgw_library_name), "citus"); + strcpy_s(worker.bgw_function_name, sizeof(worker.bgw_function_name), + "LockAcquireHelperMain"); worker.bgw_main_arg = Int32GetDatum(backendToHelp); worker.bgw_notify_pid = 0; - /* - * we check if args fits in bgw_extra to make sure it is safe to copy the data. Once - * we exceed the size of data to copy this way we need to look into a different way of - * passing the arguments to the worker. - */ - Assert(sizeof(worker.bgw_extra) >= sizeof(args)); - memcpy(worker.bgw_extra, &args, sizeof(args)); + memcpy_s(worker.bgw_extra, sizeof(worker.bgw_extra), &args, sizeof(args)); if (!RegisterDynamicBackgroundWorker(&worker, &handle)) { diff --git a/src/backend/distributed/utils/citus_safe_lib.c b/src/backend/distributed/utils/citus_safe_lib.c new file mode 100644 index 000000000..2e9875051 --- /dev/null +++ b/src/backend/distributed/utils/citus_safe_lib.c @@ -0,0 +1,302 @@ +/*------------------------------------------------------------------------- + * + * citus_safe_lib.c + * + * This file contains all SafeXXXX helper functions that we implement to + * replace missing xxxx_s functions implemented by safestringlib. It also + * contains a constraint handler for use in both our SafeXXXX functions and + * the xxxx_s functions of safestringlib. + * + * Copyright (c) Citus Data, Inc. + * + *------------------------------------------------------------------------- + */ + +/* + * In PG 11 pg_vsnprintf is not exported unless you set this define. + * + * NOTE(review): nothing is included before this check, so PG_VERSION_NUM is + * presumably still undefined here and the condition would then always be + * true - confirm. + */ +#if PG_VERSION_NUM < 120000 +#define USE_REPL_SNPRINTF 1 +#endif + +#include "postgres.h" + +#include "safe_lib.h" + +#include + +#include "distributed/citus_safe_lib.h" +#include "lib/stringinfo.h" + + +/* + * ereport_constraint_handler is a constraint handler that calls ereport.
A + * constraint handler is called whenever an error occurs in any of the + * safestringlib xxxx_s functions or our SafeXXXX functions. + * + * More info on constraint handlers can be found here: + * https://en.cppreference.com/w/c/error/set_constraint_handler_s + * + * message and error are both optional (NULL / 0); whichever of the two is + * present is included in the reported error text. pointer is not used by + * this handler. + */ +void +ereport_constraint_handler(const char *message, + void *pointer, + errno_t error) +{ + if (message && error) + { + ereport(ERROR, (errcode(ERRCODE_INTERNAL_ERROR), errmsg( + "Memory constraint error: %s (errno %d)", message, error))); + } + else if (message) + { + ereport(ERROR, (errcode(ERRCODE_INTERNAL_ERROR), errmsg( + "Memory constraint error: %s", message))); + } + else if (error) + { + ereport(ERROR, (errcode(ERRCODE_INTERNAL_ERROR), errmsg( + "Unknown function failed with memory constraint error (errno %d)", + error))); + } + else + { + ereport(ERROR, (errcode(ERRCODE_INTERNAL_ERROR), errmsg( + "Unknown function failed with memory constraint error"))); + } +} + + +/* + * SafeStringToInt64 converts a string containing a number to an int64. When it + * fails it calls ereport.
+ * + * The different error cases are inspired by + * https://stackoverflow.com/a/26083517/2570866 + * + * The string is parsed with strtoll, so the whole 64 bit range is accepted + * even on platforms where long is only 32 bits wide. + */ +int64 +SafeStringToInt64(const char *str) +{ + char *endptr; + + /* strtoll only reports range errors through errno, so reset it first */ + errno = 0; + int64 number = strtoll(str, &endptr, 10); + + if (str == endptr) + { + ereport(ERROR, (errmsg("Error parsing %s as int64, no digits found\n", str))); + } + else if (errno == ERANGE && number == LLONG_MIN) + { + ereport(ERROR, (errmsg("Error parsing %s as int64, underflow occured\n", str))); + } + else if (errno == ERANGE && number == LLONG_MAX) + { + ereport(ERROR, (errmsg("Error parsing %s as int64, overflow occured\n", str))); + } + else if (errno == EINVAL) + { + ereport(ERROR, (errmsg( + "Error parsing %s as int64, base contains unsupported value\n", + str))); + } + else if (errno != 0 && number == 0) + { + int err = errno; + ereport(ERROR, (errmsg("Error parsing %s as int64, errno %d\n", str, err))); + } + else if (errno == 0 && str && *endptr != '\0') + { + ereport(ERROR, (errmsg( + "Error parsing %s as int64, aditional characters remain after int64\n", + str))); + } + return number; +} + + +/* + * SafeStringToUint64 converts a string containing a number to a uint64. When it + * fails it calls ereport.
+ * + * The different error cases are inspired by + * https://stackoverflow.com/a/26083517/2570866 + * + * The string is parsed with strtoull, so the whole 64 bit range is accepted + * even on platforms where long is only 32 bits wide. strtoull reports a + * value that does not fit by setting errno to ERANGE. + * + * NOTE: strtoull itself accepts a leading minus sign and negates the value, + * such input is not rejected here. + */ +uint64 +SafeStringToUint64(const char *str) +{ + char *endptr; + + /* strtoull only reports range errors through errno, so reset it first */ + errno = 0; + uint64 number = strtoull(str, &endptr, 10); + + if (str == endptr) + { + ereport(ERROR, (errmsg("Error parsing %s as uint64, no digits found\n", str))); + } + else if (errno == ERANGE) + { + ereport(ERROR, (errmsg("Error parsing %s as uint64, overflow occured\n", str))); + } + else if (errno == EINVAL) + { + ereport(ERROR, (errmsg( + "Error parsing %s as uint64, base contains unsupported value\n", + str))); + } + else if (errno != 0 && number == 0) + { + int err = errno; + ereport(ERROR, (errmsg("Error parsing %s as uint64, errno %d\n", str, err))); + } + else if (errno == 0 && str && *endptr != '\0') + { + ereport(ERROR, (errmsg( + "Error parsing %s as uint64, aditional characters remain after uint64\n", + str))); + } + return number; +} + + +/* + * SafeQsort is the non reentrant version of qsort (qsort vs qsort_r), but it + * does the input checks required for qsort_s: + * 1. count or size is greater than RSIZE_MAX + * 2. ptr or comp is a null pointer (unless count is zero) + * source: https://en.cppreference.com/w/c/algorithm/qsort + * + * When it hits these errors it calls the ereport_constraint_handler. + * + * NOTE: this function calls pg_qsort instead of stdlib qsort.
+ */ +void +SafeQsort(void *ptr, rsize_t count, rsize_t size, + int (*comp)(const void *, const void *)) +{ + if (count > RSIZE_MAX_MEM) + { + ereport_constraint_handler("SafeQsort: count exceeds max", + NULL, ESLEMAX); + } + + if (size > RSIZE_MAX_MEM) + { + ereport_constraint_handler("SafeQsort: size exceeds max", + NULL, ESLEMAX); + } + + /* NULL pointers are only acceptable when there are no elements to sort */ + if (count != 0) + { + if (ptr == NULL) + { + ereport_constraint_handler("SafeQsort: ptr is NULL", NULL, ESNULLP); + } + if (comp == NULL) + { + ereport_constraint_handler("SafeQsort: comp is NULL", NULL, ESNULLP); + } + } + pg_qsort(ptr, count, size, comp); +} + + +/* + * SafeBsearch is a non reentrant version of bsearch, but it does the + * input checks required for bsearch_s: + * 1. count or size is greater than RSIZE_MAX + * 2. key, ptr or comp is a null pointer (unless count is zero) + * source: https://en.cppreference.com/w/c/algorithm/bsearch + * + * When it hits these errors it calls the ereport_constraint_handler. + * + * NOTE: this function calls the stdlib bsearch after doing these checks. + */ +void * +SafeBsearch(const void *key, const void *ptr, rsize_t count, rsize_t size, + int (*comp)(const void *, const void *)) +{ + if (count > RSIZE_MAX_MEM) + { + ereport_constraint_handler("SafeBsearch: count exceeds max", + NULL, ESLEMAX); + } + + if (size > RSIZE_MAX_MEM) + { + ereport_constraint_handler("SafeBsearch: size exceeds max", + NULL, ESLEMAX); + } + + /* NULL pointers are only acceptable when there are no elements to search */ + if (count != 0) + { + if (key == NULL) + { + ereport_constraint_handler("SafeBsearch: key is NULL", NULL, ESNULLP); + } + if (ptr == NULL) + { + ereport_constraint_handler("SafeBsearch: ptr is NULL", NULL, ESNULLP); + } + if (comp == NULL) + { + ereport_constraint_handler("SafeBsearch: comp is NULL", NULL, ESNULLP); + } + } + + /* + * Explanation of IGNORE-BANNED: + * bsearch is safe to use here since we check the same thing bsearch_s + * does. We cannot use bsearch_s as a replacement, since it's not available + * in safestringlib.
+ */ + return bsearch(key, ptr, count, size, comp); /* IGNORE-BANNED */ +} + + +/* + * SafeSnprintf is a safer replacement for snprintf, which is needed since + * safestringlib doesn't implement snprintf_s. + * + * The required failure modes of snprintf_s are as follows (in parentheses if + * this implements it and how): + * 1. the conversion specifier %n is present in format (yes, %n is not + * supported by pg_vsnprintf) + * 2. any of the arguments corresponding to %s is a null pointer (half, checked + * in postgres when asserts are enabled) + * 3. format or buffer is a null pointer (yes, checked by this function) + * 4. bufsz is zero or greater than RSIZE_MAX (yes, checked by this function) + * 5. encoding errors occur in any of string and character conversion + * specifiers (no clue what postgres does in this case) + * source: https://en.cppreference.com/w/c/io/fprintf + * + * The return value is whatever pg_vsnprintf returns, passed on unchanged. + */ +int +SafeSnprintf(char *restrict buffer, rsize_t bufsz, const char *restrict format, ...) +{ + /* failure mode 3 */ + if (buffer == NULL) + { + ereport_constraint_handler("SafeSnprintf: buffer is NULL", NULL, ESNULLP); + } + if (format == NULL) + { + ereport_constraint_handler("SafeSnprintf: format is NULL", NULL, ESNULLP); + } + + /* failure mode 4 */ + if (bufsz == 0) + { + ereport_constraint_handler("SafeSnprintf: bufsz is 0", + NULL, ESZEROL); + } + + if (bufsz > RSIZE_MAX_STR) + { + ereport_constraint_handler("SafeSnprintf: bufsz exceeds max", + NULL, ESLEMAX); + } + + va_list args; + + va_start(args, format); + + /* keep the result in an int, the same type this function returns */ + int result = pg_vsnprintf(buffer, bufsz, format, args); + va_end(args); + return result; +} diff --git a/src/backend/distributed/utils/listutils.c b/src/backend/distributed/utils/listutils.c index a94beab91..831418f0f 100644 --- a/src/backend/distributed/utils/listutils.c +++ b/src/backend/distributed/utils/listutils.c @@ -15,6 +15,7 @@ #include "utils/lsyscache.h" #include "lib/stringinfo.h" +#include "distributed/citus_safe_lib.h" #include "distributed/listutils.h" #include
"nodes/pg_list.h" #include "utils/memutils.h" @@ -49,7 +50,7 @@ SortList(List *pointerList, int (*comparisonFunction)(const void *, const void * } /* sort the array of pointers using the comparison function */ - qsort(array, arraySize, sizeof(void *), comparisonFunction); + SafeQsort(array, arraySize, sizeof(void *), comparisonFunction); /* convert the sorted array of pointers back to a sorted list */ for (arrayIndex = 0; arrayIndex < arraySize; arrayIndex++) diff --git a/src/backend/distributed/utils/maintenanced.c b/src/backend/distributed/utils/maintenanced.c index efcf6c4ea..d6ba2ad0f 100644 --- a/src/backend/distributed/utils/maintenanced.c +++ b/src/backend/distributed/utils/maintenanced.c @@ -30,6 +30,7 @@ #include "commands/extension.h" #include "libpq/pqsignal.h" #include "catalog/namespace.h" +#include "distributed/citus_safe_lib.h" #include "distributed/distributed_deadlock_detection.h" #include "distributed/maintenanced.h" #include "distributed/master_protocol.h" @@ -171,9 +172,9 @@ InitializeMaintenanceDaemonBackend(void) memset(&worker, 0, sizeof(worker)); - snprintf(worker.bgw_name, BGW_MAXLEN, - "Citus Maintenance Daemon: %u/%u", - MyDatabaseId, extensionOwner); + SafeSnprintf(worker.bgw_name, sizeof(worker.bgw_name), + "Citus Maintenance Daemon: %u/%u", + MyDatabaseId, extensionOwner); /* request ability to connect to target database */ worker.bgw_flags = BGWORKER_SHMEM_ACCESS | BGWORKER_BACKEND_DATABASE_CONNECTION; @@ -186,10 +187,14 @@ InitializeMaintenanceDaemonBackend(void) /* Restart after a bit after errors, but don't bog the system. 
*/ worker.bgw_restart_time = 5; - sprintf(worker.bgw_library_name, "citus"); - sprintf(worker.bgw_function_name, "CitusMaintenanceDaemonMain"); + strcpy_s(worker.bgw_library_name, + sizeof(worker.bgw_library_name), "citus"); + strcpy_s(worker.bgw_function_name, sizeof(worker.bgw_function_name), + "CitusMaintenanceDaemonMain"); + worker.bgw_main_arg = ObjectIdGetDatum(MyDatabaseId); - memcpy(worker.bgw_extra, &extensionOwner, sizeof(Oid)); + memcpy_s(worker.bgw_extra, sizeof(worker.bgw_extra), &extensionOwner, + sizeof(Oid)); worker.bgw_notify_pid = MyProcPid; if (!RegisterDynamicBackgroundWorker(&worker, &handle)) diff --git a/src/backend/distributed/utils/statistics_collection.c b/src/backend/distributed/utils/statistics_collection.c index 0cce1181f..81d51c598 100644 --- a/src/backend/distributed/utils/statistics_collection.c +++ b/src/backend/distributed/utils/statistics_collection.c @@ -21,17 +21,7 @@ PG_FUNCTION_INFO_V1(citus_server_id); #ifdef HAVE_LIBCURL #include -#ifndef WIN32 #include -#else -typedef struct utsname -{ - char sysname[65]; - char release[65]; - char version[65]; - char machine[65]; -} utsname; -#endif #include "access/xact.h" #include "distributed/metadata_cache.h" @@ -54,9 +44,6 @@ static bool SendHttpPostJsonRequest(const char *url, const char *postFields, long timeoutSeconds, curl_write_callback responseCallback); static bool PerformHttpRequest(CURL *curl); -#ifdef WIN32 -static int uname(struct utsname *buf); -#endif /* WarnIfSyncDNS warns if libcurl is compiled with synchronous DNS.
*/ @@ -360,103 +347,3 @@ citus_server_id(PG_FUNCTION_ARGS) PG_RETURN_UUID_P((pg_uuid_t *) buf); } - - -#ifdef WIN32 - -/* - * Inspired by perl5's win32_uname - * https://github.com/Perl/perl5/blob/69374fe705978962b85217f3eb828a93f836fd8d/win32/win32.c#L2057 - */ -static int -uname(struct utsname *buf) -{ - OSVERSIONINFO ver; - - ver.dwOSVersionInfoSize = sizeof(ver); - GetVersionEx(&ver); - - switch (ver.dwPlatformId) - { - case VER_PLATFORM_WIN32_WINDOWS: - { - strcpy(buf->sysname, "Windows"); - break; - } - - case VER_PLATFORM_WIN32_NT: - { - strcpy(buf->sysname, "Windows NT"); - break; - } - - case VER_PLATFORM_WIN32s: - { - strcpy(buf->sysname, "Win32s"); - break; - } - - default: - { - strcpy(buf->sysname, "Win32 Unknown"); - break; - } - } - - sprintf(buf->release, "%d.%d", ver.dwMajorVersion, ver.dwMinorVersion); - - { - SYSTEM_INFO info; - char *arch; - - GetSystemInfo(&info); - DWORD procarch = info.wProcessorArchitecture; - - switch (procarch) - { - case PROCESSOR_ARCHITECTURE_INTEL: - { - arch = "x86"; - break; - } - - case PROCESSOR_ARCHITECTURE_IA64: - { - arch = "x86"; - break; - } - - case PROCESSOR_ARCHITECTURE_AMD64: - { - arch = "x86"; - break; - } - - case PROCESSOR_ARCHITECTURE_UNKNOWN: - { - arch = "x86"; - break; - } - - default: - { - arch = NULL; - break; - } - } - - if (arch != NULL) - { - strcpy(buf->machine, arch); - } - else - { - sprintf(buf->machine, "unknown(0x%x)", procarch); - } - } - - return 0; -} - - -#endif diff --git a/src/backend/distributed/worker/task_tracker.c b/src/backend/distributed/worker/task_tracker.c index 5d7bbe4fe..5858ef446 100644 --- a/src/backend/distributed/worker/task_tracker.c +++ b/src/backend/distributed/worker/task_tracker.c @@ -30,6 +30,7 @@ #include #include "commands/dbcommands.h" +#include "distributed/citus_safe_lib.h" #include "distributed/multi_client_executor.h" #include "distributed/multi_server_executor.h" #include "distributed/task_tracker.h" @@ -117,10 +118,11 @@ TaskTrackerRegister(void) 
worker.bgw_flags = BGWORKER_SHMEM_ACCESS; worker.bgw_start_time = BgWorkerStart_ConsistentState; worker.bgw_restart_time = 1; - snprintf(worker.bgw_library_name, BGW_MAXLEN, "citus"); - snprintf(worker.bgw_function_name, BGW_MAXLEN, "TaskTrackerMain"); + strcpy_s(worker.bgw_library_name, sizeof(worker.bgw_library_name), "citus"); + strcpy_s(worker.bgw_function_name, sizeof(worker.bgw_function_name), + "TaskTrackerMain"); worker.bgw_notify_pid = 0; - snprintf(worker.bgw_name, BGW_MAXLEN, "task tracker"); + strcpy_s(worker.bgw_name, sizeof(worker.bgw_name), "task tracker"); RegisterBackgroundWorker(&worker); } @@ -702,7 +704,7 @@ SchedulableTaskPriorityQueue(HTAB *WorkerTasksHash) } /* now order elements in the queue according to our sorting criterion */ - qsort(priorityQueue, queueSize, WORKER_TASK_SIZE, CompareTasksByTime); + SafeQsort(priorityQueue, queueSize, WORKER_TASK_SIZE, CompareTasksByTime); return priorityQueue; } diff --git a/src/backend/distributed/worker/worker_data_fetch_protocol.c b/src/backend/distributed/worker/worker_data_fetch_protocol.c index c6d72d549..9e3a364f1 100644 --- a/src/backend/distributed/worker/worker_data_fetch_protocol.c +++ b/src/backend/distributed/worker/worker_data_fetch_protocol.c @@ -226,7 +226,7 @@ ReceiveRegularFile(const char *nodeName, uint32 nodePort, const char *nodeUser, bool copyDone = false; /* create local file to append remote data to */ - snprintf(filename, MAXPGPATH, "%s", filePath->data); + strlcpy(filename, filePath->data, MAXPGPATH); int32 fileDescriptor = BasicOpenFilePerm(filename, fileFlags, fileMode); if (fileDescriptor < 0) diff --git a/src/include/distributed/citus_safe_lib.h b/src/include/distributed/citus_safe_lib.h new file mode 100644 index 000000000..f1552fbe1 --- /dev/null +++ b/src/include/distributed/citus_safe_lib.h @@ -0,0 +1,29 @@ +/*------------------------------------------------------------------------- + * + * safe_lib.h + * + * This file contains helper functions to expand on the _s 
functions from + * safestringlib. + * + * Copyright (c) Citus Data, Inc. + * + *------------------------------------------------------------------------- + */ + +#ifndef CITUS_safe_lib_H +#define CITUS_safe_lib_H + +#include "postgres.h" + +#include "safe_lib.h" + +extern void ereport_constraint_handler(const char *message, void *pointer, errno_t error); +extern int64 SafeStringToInt64(const char *str); +extern uint64 SafeStringToUint64(const char *str); +extern void SafeQsort(void *ptr, rsize_t count, rsize_t size, + int (*comp)(const void *, const void *)); +extern void * SafeBsearch(const void *key, const void *ptr, rsize_t count, rsize_t size, + int (*comp)(const void *, const void *)); + +/* printf-style: the attribute lets the compiler check format against the arguments */ +extern int SafeSnprintf(char *str, rsize_t count, const char *fmt, ...) pg_attribute_printf(3, 4); + +#endif /* CITUS_safe_lib_H */