diff --git a/src/backend/distributed/connection/connection_configuration.c b/src/backend/distributed/connection/connection_configuration.c index a689a6a92..4852a6ef5 100644 --- a/src/backend/distributed/connection/connection_configuration.c +++ b/src/backend/distributed/connection/connection_configuration.c @@ -77,7 +77,6 @@ void ResetConnParams() { Index paramIdx = 0; - for (paramIdx = 0; paramIdx < ConnParams.size; paramIdx++) { free((void *) ConnParams.keywords[paramIdx]); @@ -88,21 +87,22 @@ ResetConnParams() ConnParams.size = 0; + InvalidateConnParamsHashEntries(); + AddConnParam("fallback_application_name", "citus"); } /* * AddConnParam adds a parameter setting to the global libpq settings according - * to the provided keyword and value. Under assert-enabled builds, array bounds - * checking is performed. + * to the provided keyword and value. */ void AddConnParam(const char *keyword, const char *value) { if (ConnParams.size + 1 >= ConnParams.maxSize) { - /* we expect developers to see that error messages */ + /* hopefully this error is only seen by developers */ ereport(ERROR, (errcode(ERRCODE_INSUFFICIENT_RESOURCES), errmsg("ConnParams arrays bound check failed"))); } @@ -227,7 +227,7 @@ CheckConninfo(const char *conninfo, const char **whitelist, */ void GetConnParams(ConnectionHashKey *key, char ***keywords, char ***values, - MemoryContext context) + Index *runtimeParamStart, MemoryContext context) { /* make space for the port as a string: sign, 10 digits, NUL */ char *nodePortString = MemoryContextAlloc(context, 12 * sizeof(char *)); @@ -241,16 +241,24 @@ GetConnParams(ConnectionHashKey *key, char ***keywords, char ***values, * The global parameters have already been assigned from a GUC, so begin by * calculating the key-specific parameters (basically just the fields of * the key and the active database encoding). + * + * We allocate everything in the provided context so as to facilitate using + * pfree on all runtime parameters when connections using these entries are + * invalidated during config reloads. */ const char *runtimeKeywords[] = { - "host", "port", "dbname", "user", "client_encoding" + MemoryContextStrdup(context, "host"), + MemoryContextStrdup(context, "port"), + MemoryContextStrdup(context, "dbname"), + MemoryContextStrdup(context, "user"), + MemoryContextStrdup(context, "client_encoding") }; const char *runtimeValues[] = { MemoryContextStrdup(context, key->hostname), nodePortString, MemoryContextStrdup(context, key->database), MemoryContextStrdup(context, key->user), - GetDatabaseEncodingName() + MemoryContextStrdup(context, GetDatabaseEncodingName()) }; /* @@ -265,12 +273,12 @@ GetConnParams(ConnectionHashKey *key, char ***keywords, char ***values, /* auth keywords will begin after global and runtime ones are appended */ Index authParamsIdx = ConnParams.size + lengthof(runtimeKeywords); - int paramIndex = 0; - int runtimeParamIndex = 0; + Index paramIndex = 0; + Index runtimeParamIndex = 0; if (ConnParams.size + lengthof(runtimeKeywords) >= ConnParams.maxSize) { - /* unexpected, intended as developers rather than users */ + /* hopefully this error is only seen by developers */ ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), errmsg("too many connParams entries"))); } @@ -285,6 +293,9 @@ GetConnParams(ConnectionHashKey *key, char ***keywords, char ***values, connValues[paramIndex] = ConnParams.values[paramIndex]; } + /* remember where global/GUC params end and runtime ones start */ + *runtimeParamStart = ConnParams.size; + /* second step: begin at end of global params and copy runtime ones */ for (runtimeParamIndex = 0; runtimeParamIndex < lengthof(runtimeKeywords); @@ -312,7 +323,7 @@ GetConnParams(ConnectionHashKey *key, char ***keywords, char ***values, const char * GetConnParam(const char *keyword) { - int i = 0; + Index i = 0; for (i = 0; i < ConnParams.size; i++) { diff --git a/src/backend/distributed/connection/connection_management.c b/src/backend/distributed/connection/connection_management.c index f95f5f8ab..10a0e93af 100644 --- a/src/backend/distributed/connection/connection_management.c +++ b/src/backend/distributed/connection/connection_management.c @@ -42,6 +42,7 @@ MemoryContext ConnectionContext = NULL; static uint32 ConnectionHashHash(const void *key, Size keysize); static int ConnectionHashCompare(const void *a, const void *b, Size keysize); static MultiConnection * StartConnectionEstablishment(ConnectionHashKey *key); +static void FreeConnParamsHashEntryFields(ConnParamsHashEntry *entry); static void AfterXactHostConnectionHandling(ConnectionHashEntry *entry, bool isCommit); static void DefaultCitusNoticeProcessor(void *arg, const char *message); static MultiConnection * FindAvailableConnection(dlist_head *connections, uint32 flags); @@ -91,6 +92,26 @@ InitializeConnectionManagement(void) } +/* + * InvalidateConnParamsHashEntries sets every hash entry's isValid flag to false. + */ +void +InvalidateConnParamsHashEntries(void) +{ + if (ConnParamsHash != NULL) + { + ConnParamsHashEntry *entry = NULL; + HASH_SEQ_STATUS status; + + hash_seq_init(&status, ConnParamsHash); + while ((entry = (ConnParamsHashEntry *) hash_seq_search(&status)) != NULL) + { + entry->isValid = false; + } + } +} + + /* * Perform connection management activity after the end of a transaction. Both * COMMIT and ABORT paths are handled here. @@ -694,8 +715,15 @@ StartConnectionEstablishment(ConnectionHashKey *key) entry = hash_search(ConnParamsHash, key, HASH_ENTER, &found); if (!found || !entry->isValid) { - /* if they're not found, compute them from GUC, runtime, etc. */ - GetConnParams(key, &entry->keywords, &entry->values, ConnectionContext); + /* avoid leaking memory in the keys and values arrays */ + if (found && !entry->isValid) + { + FreeConnParamsHashEntryFields(entry); + } + + /* if not found or not valid, compute them from GUC, runtime, etc. */ + GetConnParams(key, &entry->keywords, &entry->values, &entry->runtimeParamStart, + ConnectionContext); entry->isValid = true; } @@ -726,6 +754,34 @@ StartConnectionEstablishment(ConnectionHashKey *key) } +/* + * FreeConnParamsHashEntryFields frees any dynamically allocated memory reachable + * from the fields of the provided ConnParamsHashEntry. This includes all runtime + * libpq keywords and values, as well as the actual arrays storing them. + */ +static void +FreeConnParamsHashEntryFields(ConnParamsHashEntry *entry) +{ + char **keyword = &entry->keywords[entry->runtimeParamStart]; + char **value = &entry->values[entry->runtimeParamStart]; + + while (*keyword != NULL) + { + pfree(*keyword); + keyword++; + } + + while (*value != NULL) + { + pfree(*value); + value++; + } + + pfree(entry->keywords); + pfree(entry->values); +} + + /* * Close all remote connections if necessary anymore (i.e. not session * lifetime), or if in a failed state. diff --git a/src/backend/distributed/utils/metadata_cache.c b/src/backend/distributed/utils/metadata_cache.c index 14159a3ae..352239c2f 100644 --- a/src/backend/distributed/utils/metadata_cache.c +++ b/src/backend/distributed/utils/metadata_cache.c @@ -3178,15 +3178,7 @@ CreateDistTableCache(void) void InvalidateMetadataSystemCache(void) { - ConnParamsHashEntry *entry = NULL; - HASH_SEQ_STATUS status; - - hash_seq_init(&status, ConnParamsHash); - - while ((entry = (ConnParamsHashEntry *) hash_seq_search(&status)) != NULL) - { - entry->isValid = false; - } + InvalidateConnParamsHashEntries(); memset(&MetadataCache, 0, sizeof(MetadataCache)); workerNodeHashValid = false; diff --git a/src/backend/distributed/worker/task_tracker.c b/src/backend/distributed/worker/task_tracker.c index 8bd110007..3ff754e39 100644 --- a/src/backend/distributed/worker/task_tracker.c +++ b/src/backend/distributed/worker/task_tracker.c @@ -861,15 +861,7 @@ ManageWorkerTasksHash(HTAB *WorkerTasksHash) if (!WorkerTasksSharedState->conninfosValid) { - ConnParamsHashEntry *entry = NULL; - HASH_SEQ_STATUS status; - - hash_seq_init(&status, ConnParamsHash); - - while ((entry = (ConnParamsHashEntry *) hash_seq_search(&status)) != NULL) - { - entry->isValid = false; - } + InvalidateConnParamsHashEntries(); } /* schedule new tasks if we have any */ diff --git a/src/include/distributed/connection_management.h b/src/include/distributed/connection_management.h index 6b2e209c0..b6b6b0f1e 100644 --- a/src/include/distributed/connection_management.h +++ b/src/include/distributed/connection_management.h @@ -118,6 +118,7 @@ typedef struct ConnParamsHashEntry { ConnectionHashKey key; bool isValid; + Index runtimeParamStart; char **keywords; char **values; } ConnParamsHashEntry; @@ -142,9 +143,10 @@ extern void InitializeConnectionManagement(void); extern void InitConnParams(void); extern void ResetConnParams(void); +extern void InvalidateConnParamsHashEntries(void); extern void AddConnParam(const char *keyword, const char *value); extern void GetConnParams(ConnectionHashKey *key, char ***keywords, char ***values, - MemoryContext context); + Index *runtimeParamStart, MemoryContext context); extern const char * GetConnParam(const char *keyword); extern bool CheckConninfo(const char *conninfo, const char **whitelist, Size whitelistLength, char **errmsg); diff --git a/src/test/regress/expected/multi_size_queries.out b/src/test/regress/expected/multi_size_queries.out index 648b81b78..27db22100 100644 --- a/src/test/regress/expected/multi_size_queries.out +++ b/src/test/regress/expected/multi_size_queries.out @@ -125,5 +125,46 @@ ALTER TABLE supplier ALTER COLUMN s_suppkey SET NOT NULL; select citus_table_size('supplier'); ERROR: citus size functions cannot be called in transaction blocks which contain multi-shard data modifications END; +show citus.node_conninfo; + citus.node_conninfo +--------------------- + sslmode=require +(1 row) + +ALTER SYSTEM SET citus.node_conninfo = 'sslmode=require'; +SELECT pg_reload_conf(); + pg_reload_conf +---------------- + t +(1 row) + +-- make sure that any invalidation to the connection info +-- wouldn't prevent future commands to fail +SELECT citus_total_relation_size('customer_copy_hash'); + citus_total_relation_size +--------------------------- + 2646016 +(1 row) + +SELECT pg_reload_conf(); + pg_reload_conf +---------------- + t +(1 row) + +SELECT citus_total_relation_size('customer_copy_hash'); + citus_total_relation_size +--------------------------- + 2646016 +(1 row) + +-- reset back to the original node_conninfo +ALTER SYSTEM RESET citus.node_conninfo; +SELECT pg_reload_conf(); + pg_reload_conf +---------------- + t +(1 row) + DROP INDEX index_1; DROP INDEX index_2; diff --git a/src/test/regress/sql/multi_size_queries.sql b/src/test/regress/sql/multi_size_queries.sql index 8387667c2..220c4e18d 100644 --- a/src/test/regress/sql/multi_size_queries.sql +++ b/src/test/regress/sql/multi_size_queries.sql @@ -65,5 +65,19 @@ ALTER TABLE supplier ALTER COLUMN s_suppkey SET NOT NULL; select citus_table_size('supplier'); END; +show citus.node_conninfo; +ALTER SYSTEM SET citus.node_conninfo = 'sslmode=require'; +SELECT pg_reload_conf(); + +-- make sure that any invalidation to the connection info +-- wouldn't prevent future commands to fail +SELECT citus_total_relation_size('customer_copy_hash'); +SELECT pg_reload_conf(); +SELECT citus_total_relation_size('customer_copy_hash'); + +-- reset back to the original node_conninfo +ALTER SYSTEM RESET citus.node_conninfo; +SELECT pg_reload_conf(); + DROP INDEX index_1; DROP INDEX index_2;