diff --git a/doc/changelog.rst b/doc/changelog.rst index 2fb225e2e1..4fff06c9cb 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -10,10 +10,13 @@ Version 4.12.1 is a bug fix release. - Fixed a bug that could raise ``UnboundLocalError`` when creating asynchronous connections over SSL. - Fixed a bug causing SRV hostname validation to fail when resolver and resolved hostnames are identical with three domain levels. - Fixed a bug that caused direct use of ``pymongo.uri_parser`` to raise an ``AttributeError``. +- Fixed a bug where clients created with connect=False and a "mongodb+srv://" connection string + could cause public ``pymongo.MongoClient`` and ``pymongo.AsyncMongoClient`` attributes (topology_description, + nodes, address, primary, secondaries, arbiters) to incorrectly return a Database, leading to type + errors such as: "NotImplementedError: Database objects do not implement truth value testing or bool()". - Removed Eventlet testing against Python versions newer than 3.9 since Eventlet is actively being sunset by its maintainers and has compatibility issues with PyMongo's dnspython dependency. - Issues Resolved ............... diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 7744a75d9c..a236b21348 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -109,6 +109,7 @@ ) from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.results import ClientBulkWriteResult +from pymongo.server_description import ServerDescription from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription @@ -779,7 +780,7 @@ def __init__( keyword_opts["document_class"] = doc_class self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts} - seeds = set() + self._seeds = set() is_srv = False username = None password = None @@ -804,18 +805,18 @@ def __init__( srv_max_hosts=srv_max_hosts, ) is_srv = entity.startswith(SRV_SCHEME) - seeds.update(res["nodelist"]) + self._seeds.update(res["nodelist"]) username = res["username"] or username password = res["password"] or password dbase = res["database"] or dbase opts = res["options"] fqdn = res["fqdn"] else: - seeds.update(split_hosts(entity, self._port)) - if not seeds: + self._seeds.update(split_hosts(entity, self._port)) + if not self._seeds: raise ConfigurationError("need to specify at least one host") - for hostname in [node[0] for node in seeds]: + for hostname in [node[0] for node in self._seeds]: if _detect_external_db(hostname): break @@ -838,7 +839,7 @@ def __init__( srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") - opts = self._normalize_and_validate_options(opts, seeds) + opts = self._normalize_and_validate_options(opts, self._seeds) # Username and password passed as kwargs override user info in URI. username = opts.get("username", username) @@ -857,7 +858,7 @@ def __init__( "username": username, "password": password, "dbase": dbase, - "seeds": seeds, + "seeds": self._seeds, "fqdn": fqdn, "srv_service_name": srv_service_name, "pool_class": pool_class, @@ -873,8 +874,7 @@ def __init__( self._options.read_concern, ) - if not is_srv: - self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name) self._opened = False self._closed = False @@ -975,6 +975,7 @@ def _init_based_on_options( srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, server_monitoring_mode=self._options.server_monitoring_mode, + topology_id=self._topology_settings._topology_id if self._topology_settings else None, ) if self._options.auto_encryption_opts: from pymongo.asynchronous.encryption import _Encrypter @@ -1205,6 +1206,16 @@ def topology_description(self) -> TopologyDescription: .. versionadded:: 4.0 """ + if self._topology is None: + servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds} + return TopologyDescription( + TOPOLOGY_TYPE.Unknown, + servers, + None, + None, + None, + self._topology_settings, + ) return self._topology.description @property @@ -1218,6 +1229,8 @@ def nodes(self) -> FrozenSet[_Address]: to any servers, or a network partition causes it to lose connection to all servers. """ + if self._topology is None: + return frozenset() description = self._topology.description return frozenset(s.address for s in description.known_servers) @@ -1576,6 +1589,8 @@ async def address(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 """ + if self._topology is None: + await self._get_topology() topology_type = self._topology._description.topology_type if ( topology_type == TOPOLOGY_TYPE.Sharded @@ -1598,6 +1613,8 @@ async def primary(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 AsyncMongoClient gained this property in version 3.0. """ + if self._topology is None: + await self._get_topology() return await self._topology.get_primary() # type: ignore[return-value] @property @@ -1611,6 +1628,8 @@ async def secondaries(self) -> set[_Address]: .. versionadded:: 3.0 AsyncMongoClient gained this property in version 3.0. """ + if self._topology is None: + await self._get_topology() return await self._topology.get_secondaries() @property @@ -1621,6 +1640,8 @@ async def arbiters(self) -> set[_Address]: connected to a replica set, there are no arbiters, or this client was created without the `replicaSet` option. """ + if self._topology is None: + await self._get_topology() return await self._topology.get_arbiters() @property diff --git a/pymongo/asynchronous/settings.py b/pymongo/asynchronous/settings.py index 62be853fba..9c2331971a 100644 --- a/pymongo/asynchronous/settings.py +++ b/pymongo/asynchronous/settings.py @@ -51,6 +51,7 @@ def __init__( srv_service_name: str = common.SRV_SERVICE_NAME, srv_max_hosts: int = 0, server_monitoring_mode: str = common.SERVER_MONITORING_MODE, + topology_id: Optional[ObjectId] = None, ): """Represent MongoClient's configuration. @@ -78,8 +79,10 @@ def __init__( self._srv_service_name = srv_service_name self._srv_max_hosts = srv_max_hosts or 0 self._server_monitoring_mode = server_monitoring_mode - - self._topology_id = ObjectId() + if topology_id is not None: + self._topology_id = topology_id + else: + self._topology_id = ObjectId() # Store the allocation traceback to catch unclosed clients in the # test suite. self._stack = "".join(traceback.format_stack()[:-2]) diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 1c0adb5d6b..99a517e5c1 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -101,6 +101,7 @@ ) from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.results import ClientBulkWriteResult +from pymongo.server_description import ServerDescription from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE from pymongo.synchronous import client_session, database, uri_parser @@ -777,7 +778,7 @@ def __init__( keyword_opts["document_class"] = doc_class self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts} - seeds = set() + self._seeds = set() is_srv = False username = None password = None @@ -802,18 +803,18 @@ def __init__( srv_max_hosts=srv_max_hosts, ) is_srv = entity.startswith(SRV_SCHEME) - seeds.update(res["nodelist"]) + self._seeds.update(res["nodelist"]) username = res["username"] or username password = res["password"] or password dbase = res["database"] or dbase opts = res["options"] fqdn = res["fqdn"] else: - seeds.update(split_hosts(entity, self._port)) - if not seeds: + self._seeds.update(split_hosts(entity, self._port)) + if not self._seeds: raise ConfigurationError("need to specify at least one host") - for hostname in [node[0] for node in seeds]: + for hostname in [node[0] for node in self._seeds]: if _detect_external_db(hostname): break @@ -836,7 +837,7 @@ def __init__( srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") - opts = self._normalize_and_validate_options(opts, seeds) + opts = self._normalize_and_validate_options(opts, self._seeds) # Username and password passed as kwargs override user info in URI. username = opts.get("username", username) @@ -855,7 +856,7 @@ def __init__( "username": username, "password": password, "dbase": dbase, - "seeds": seeds, + "seeds": self._seeds, "fqdn": fqdn, "srv_service_name": srv_service_name, "pool_class": pool_class, @@ -871,8 +872,7 @@ def __init__( self._options.read_concern, ) - if not is_srv: - self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name) self._opened = False self._closed = False @@ -973,6 +973,7 @@ def _init_based_on_options( srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, server_monitoring_mode=self._options.server_monitoring_mode, + topology_id=self._topology_settings._topology_id if self._topology_settings else None, ) if self._options.auto_encryption_opts: from pymongo.synchronous.encryption import _Encrypter @@ -1203,6 +1204,16 @@ def topology_description(self) -> TopologyDescription: .. versionadded:: 4.0 """ + if self._topology is None: + servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds} + return TopologyDescription( + TOPOLOGY_TYPE.Unknown, + servers, + None, + None, + None, + self._topology_settings, + ) return self._topology.description @property @@ -1216,6 +1227,8 @@ def nodes(self) -> FrozenSet[_Address]: to any servers, or a network partition causes it to lose connection to all servers. """ + if self._topology is None: + return frozenset() description = self._topology.description return frozenset(s.address for s in description.known_servers) @@ -1570,6 +1583,8 @@ def address(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 """ + if self._topology is None: + self._get_topology() topology_type = self._topology._description.topology_type if ( topology_type == TOPOLOGY_TYPE.Sharded @@ -1592,6 +1607,8 @@ def primary(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 MongoClient gained this property in version 3.0. """ + if self._topology is None: + self._get_topology() return self._topology.get_primary() # type: ignore[return-value] @property @@ -1605,6 +1622,8 @@ def secondaries(self) -> set[_Address]: .. versionadded:: 3.0 MongoClient gained this property in version 3.0. """ + if self._topology is None: + self._get_topology() return self._topology.get_secondaries() @property @@ -1615,6 +1634,8 @@ def arbiters(self) -> set[_Address]: connected to a replica set, there are no arbiters, or this client was created without the `replicaSet` option. """ + if self._topology is None: + self._get_topology() return self._topology.get_arbiters() @property diff --git a/pymongo/synchronous/settings.py b/pymongo/synchronous/settings.py index bb17de1874..61b86fa18d 100644 --- a/pymongo/synchronous/settings.py +++ b/pymongo/synchronous/settings.py @@ -51,6 +51,7 @@ def __init__( srv_service_name: str = common.SRV_SERVICE_NAME, srv_max_hosts: int = 0, server_monitoring_mode: str = common.SERVER_MONITORING_MODE, + topology_id: Optional[ObjectId] = None, ): """Represent MongoClient's configuration. @@ -78,8 +79,10 @@ def __init__( self._srv_service_name = srv_service_name self._srv_max_hosts = srv_max_hosts or 0 self._server_monitoring_mode = server_monitoring_mode - - self._topology_id = ObjectId() + if topology_id is not None: + self._topology_id = topology_id + else: + self._topology_id = ObjectId() # Store the allocation traceback to catch unclosed clients in the # test suite. self._stack = "".join(traceback.format_stack()[:-2]) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index c9cfca81fc..b278d684cb 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -849,6 +849,58 @@ async def test_init_disconnected_with_auth(self): with self.assertRaises(ConnectionFailure): await c.pymongo_test.test.find_one() + @async_client_context.require_no_standalone + @async_client_context.require_no_load_balancer + @async_client_context.require_tls + async def test_init_disconnected_with_srv(self): + c = await self.async_rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # nodes returns an empty set if not connected + self.assertEqual(c.nodes, frozenset()) + # topology_description returns the initial seed description if not connected + topology_description = c.topology_description + self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown) + self.assertEqual( + { + ("test1.test.build.10gen.cc", None): ServerDescription( + ("test1.test.build.10gen.cc", None) + ) + }, + topology_description.server_descriptions(), + ) + + # address causes client to block until connected + self.assertIsNotNone(await c.address) + # Initial seed topology and connected topology have the same ID + self.assertEqual( + c._topology._topology_id, topology_description._topology_settings._topology_id + ) + await c.close() + + c = await self.async_rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # primary causes client to block until connected + await c.primary + self.assertIsNotNone(c._topology) + await c.close() + + c = await self.async_rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # secondaries causes client to block until connected + await c.secondaries + self.assertIsNotNone(c._topology) + await c.close() + + c = await self.async_rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # arbiters causes client to block until connected + await c.arbiters + self.assertIsNotNone(c._topology) + async def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) c = await self.async_rs_or_single_client(seed, connect=False) diff --git a/test/test_client.py b/test/test_client.py index 038ba2241b..8a50c90afb 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -824,6 +824,58 @@ def test_init_disconnected_with_auth(self): with self.assertRaises(ConnectionFailure): c.pymongo_test.test.find_one() + @client_context.require_no_standalone + @client_context.require_no_load_balancer + @client_context.require_tls + def test_init_disconnected_with_srv(self): + c = self.rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # nodes returns an empty set if not connected + self.assertEqual(c.nodes, frozenset()) + # topology_description returns the initial seed description if not connected + topology_description = c.topology_description + self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown) + self.assertEqual( + { + ("test1.test.build.10gen.cc", None): ServerDescription( + ("test1.test.build.10gen.cc", None) + ) + }, + topology_description.server_descriptions(), + ) + + # address causes client to block until connected + self.assertIsNotNone(c.address) + # Initial seed topology and connected topology have the same ID + self.assertEqual( + c._topology._topology_id, topology_description._topology_settings._topology_id + ) + c.close() + + c = self.rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # primary causes client to block until connected + c.primary + self.assertIsNotNone(c._topology) + c.close() + + c = self.rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # secondaries causes client to block until connected + c.secondaries + self.assertIsNotNone(c._topology) + c.close() + + c = self.rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # arbiters causes client to block until connected + c.arbiters + self.assertIsNotNone(c._topology) + def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) c = self.rs_or_single_client(seed, connect=False)