Skip to content

Commit 334f8b9

Browse files
committed
apply review comments
1 parent 30ed74f commit 334f8b9

File tree

4 files changed

+50
-48
lines changed

4 files changed

+50
-48
lines changed

aws_advanced_python_wrapper/read_write_splitting_plugin.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ def __init__(
5656
self,
5757
plugin_service: PluginService,
5858
props: Properties,
59-
connection_handler: ConnectionHandler,
59+
connection_handler: ReadWriteConnectionHandler,
6060
):
6161
self._plugin_service: PluginService = plugin_service
6262
self._properties: Properties = props
63-
self._connection_handler: ConnectionHandler = connection_handler
63+
self._connection_handler: ReadWriteConnectionHandler = connection_handler
6464
self._writer_connection: Optional[Connection] = None
6565
self._reader_connection: Optional[Connection] = None
6666
self._writer_host_info: Optional[HostInfo] = None
@@ -310,7 +310,7 @@ def _switch_to_reader_connection(self):
310310

311311
if (
312312
self._reader_connection is not None
313-
and not self._connection_handler.old_reader_can_be_used(
313+
and not self._connection_handler.can_host_be_used(
314314
self._reader_host_info
315315
)
316316
):
@@ -344,7 +344,7 @@ def _switch_to_reader_connection(self):
344344
self._close_connection_if_idle(self._writer_connection)
345345

346346
def _initialize_reader_connection(self):
347-
if self._connection_handler.need_connect_to_writer():
347+
if self._connection_handler.has_no_readers():
348348
if not self._is_connection_usable(
349349
self._writer_connection, self._plugin_service.driver_dialect
350350
):
@@ -437,7 +437,7 @@ def close_connection(conn: Optional[Connection], driver_dialect: DriverDialect):
437437
return
438438

439439

440-
class ConnectionHandler(Protocol):
440+
class ReadWriteConnectionHandler(Protocol):
441441
"""Protocol for handling writer/reader connection logic."""
442442

443443
@property
@@ -491,15 +491,15 @@ def is_writer_host(self, current_host: HostInfo) -> bool:
491491
...
492492

493493
def is_reader_host(self, current_host: HostInfo) -> bool:
494-
"""Return true if the current host fits the criteria of a writer host."""
494+
"""Return true if the current host fits the criteria of a reader host."""
495495
...
496496

497-
def old_reader_can_be_used(self, reader_host_info: HostInfo) -> bool:
498-
"""Return true if the current host can be used to switch connection to."""
497+
def can_host_be_used(self, host_info: HostInfo) -> bool:
498+
"""Returns true if connections can be switched to the given host"""
499499
...
500500

501-
def need_connect_to_writer(self) -> bool:
502-
"""Return true if switching to reader should instead connect to writer."""
501+
def has_no_readers(self) -> bool:
502+
"""Return true if there are no readers in the host list"""
503503
...
504504

505505
def refresh_and_store_host_list(
@@ -509,7 +509,7 @@ def refresh_and_store_host_list(
509509
...
510510

511511

512-
class TopologyBasedConnectionHandler(ConnectionHandler):
512+
class TopologyBasedConnectionHandler(ReadWriteConnectionHandler):
513513
"""Topology based implementation of connection handling logic."""
514514

515515
def __init__(self, plugin_service: PluginService, props: Properties):
@@ -615,11 +615,11 @@ def get_verified_initial_connection(
615615

616616
return current_conn
617617

618-
def old_reader_can_be_used(self, reader_host_info: HostInfo) -> bool:
618+
def can_host_be_used(self, host_info: HostInfo) -> bool:
619619
hostnames = [host_info.host for host_info in self._hosts]
620-
return reader_host_info is not None and reader_host_info.host in hostnames
620+
return host_info.host in hostnames
621621

622-
def need_connect_to_writer(self) -> bool:
622+
def has_no_readers(self) -> bool:
623623
if len(self._hosts) == 1:
624624
return self._get_writer() is not None
625625
return False

aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from aws_advanced_python_wrapper.host_availability import HostAvailability
2121
from aws_advanced_python_wrapper.read_write_splitting_plugin import (
22-
ConnectionHandler, ReadWriteSplittingConnectionManager)
22+
ReadWriteConnectionHandler, ReadWriteSplittingConnectionManager)
2323
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
2424
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
2525

@@ -37,14 +37,14 @@
3737
from aws_advanced_python_wrapper.utils.properties import WrapperProperties
3838

3939

40-
class EndpointBasedConnectionHandler(ConnectionHandler):
40+
class EndpointBasedConnectionHandler(ReadWriteConnectionHandler):
4141
"""Endpoint based implementation of connection handling logic."""
4242

4343
def __init__(self, plugin_service: PluginService, props: Properties):
44-
self._read_endpoint: str = EndpointBasedConnectionHandler._verify_parameter(
44+
read_endpoint: str = EndpointBasedConnectionHandler._verify_parameter(
4545
WrapperProperties.SRW_READ_ENDPOINT, props, str, required=True
4646
)
47-
self._write_endpoint: str = EndpointBasedConnectionHandler._verify_parameter(
47+
write_endpoint: str = EndpointBasedConnectionHandler._verify_parameter(
4848
WrapperProperties.SRW_WRITE_ENDPOINT, props, str, required=True
4949
)
5050

@@ -60,17 +60,19 @@ def __init__(self, plugin_service: PluginService, props: Properties):
6060
WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS, props, int, lambda x: x > 0
6161
)
6262

63-
self._verify_opened_connection_type: Optional[HostRole] = (
64-
EndpointBasedConnectionHandler._parse_connection_type(
63+
self._verify_initial_connection_type: Optional[HostRole] = (
64+
EndpointBasedConnectionHandler._parse_role(
6565
WrapperProperties.SRW_VERIFY_INITIAL_CONNECTION_TYPE.get(props)
6666
)
6767
)
6868

6969
self._plugin_service: PluginService = plugin_service
7070
self._rds_utils: RdsUtils = RdsUtils()
7171
self._host_list_provider_service: Optional[HostListProviderService] = None
72-
self._write_endpoint_host_info: HostInfo = self._create_host_info(self._write_endpoint, HostRole.WRITER)
73-
self._read_endpoint_host_info: HostInfo = self._create_host_info(self._read_endpoint, HostRole.READER)
72+
self._write_endpoint_host_info: HostInfo = self._create_host_info(write_endpoint, HostRole.WRITER)
73+
self._read_endpoint_host_info: HostInfo = self._create_host_info(read_endpoint, HostRole.READER)
74+
self._write_endpoint = write_endpoint.casefold()
75+
self._read_endpoint = read_endpoint.casefold()
7476

7577
@property
7678
def host_list_provider_service(self) -> Optional[HostListProviderService]:
@@ -116,12 +118,12 @@ def get_verified_initial_connection(
116118

117119
if (
118120
url_type == RdsUrlType.RDS_WRITER_CLUSTER
119-
or self._verify_opened_connection_type == HostRole.WRITER
121+
or self._verify_initial_connection_type == HostRole.WRITER
120122
):
121123
conn = self._get_verified_connection(host_info, HostRole.WRITER, plugin_service_connect_func, connect_func)
122124
elif (
123125
url_type == RdsUrlType.RDS_READER_CLUSTER
124-
or self._verify_opened_connection_type == HostRole.READER
126+
or self._verify_initial_connection_type == HostRole.READER
125127
):
126128
conn = self._get_verified_connection(host_info, HostRole.READER, plugin_service_connect_func, connect_func)
127129

@@ -154,10 +156,8 @@ def _get_verified_connection(
154156
try:
155157
if connect_func is not None:
156158
candidate_conn = connect_func()
157-
elif host_info is not None:
158-
candidate_conn = plugin_service_connect_func(host_info)
159159
else:
160-
return None
160+
candidate_conn = plugin_service_connect_func(host_info)
161161

162162
if candidate_conn is None:
163163
self._delay()
@@ -178,12 +178,12 @@ def _get_verified_connection(
178178

179179
return None
180180

181-
def old_reader_can_be_used(self, reader_host_info: HostInfo) -> bool:
182-
# Assume that the old reader can always be used, no topology-based information to check.
181+
def can_host_be_used(self, host_info: HostInfo) -> bool:
182+
# Assume that the host can always be used, no topology-based information to check.
183183
return True
184184

185-
def need_connect_to_writer(self) -> bool:
186-
# SetReadOnly(true) will always connect to the read_endpoint, and not the writer.
185+
def has_no_readers(self) -> bool:
186+
# SetReadOnly(true) will always connect to the read_endpoint, regardless of number of readers.
187187
return False
188188

189189
def refresh_and_store_host_list(
@@ -218,21 +218,24 @@ def should_update_reader_with_current_conn(
218218

219219
def is_writer_host(self, current_host: HostInfo) -> bool:
220220
return (
221-
current_host.host.casefold() == self._write_endpoint.casefold()
222-
or current_host.url.casefold() == self._write_endpoint.casefold()
221+
current_host.host.casefold() == self._write_endpoint
222+
or current_host.url.casefold() == self._write_endpoint
223223
)
224224

225225
def is_reader_host(self, current_host: HostInfo) -> bool:
226226
return (
227-
current_host.host.casefold() == self._read_endpoint.casefold()
228-
or current_host.url.casefold() == self._read_endpoint.casefold()
227+
current_host.host.casefold() == self._read_endpoint
228+
or current_host.url.casefold() == self._read_endpoint
229229
)
230230

231231
def _create_host_info(self, endpoint: str, role: HostRole) -> HostInfo:
232232
endpoint = endpoint.strip()
233233
host = endpoint
234-
port = self._plugin_service.database_dialect.default_port if not self._plugin_service.current_host_info.is_port_specified() \
235-
else self._plugin_service.current_host_info.port
234+
try:
235+
port = self._plugin_service.database_dialect.default_port if not self._plugin_service.current_host_info.is_port_specified() \
236+
else self._plugin_service.current_host_info.port
237+
except AwsWrapperError: # if current_host_info cannot be determined fallback to default port
238+
port = self._plugin_service.database_dialect.default_port
236239
colon_index = endpoint.rfind(":")
237240

238241
if colon_index != -1:
@@ -272,14 +275,14 @@ def _delay(self):
272275
sleep(self._connect_retry_interval_ms / 1000)
273276

274277
@staticmethod
275-
def _parse_connection_type(phase_str: Optional[str]) -> HostRole:
276-
if not phase_str:
278+
def _parse_role(role_str: Optional[str]) -> HostRole:
279+
if not role_str:
277280
return HostRole.UNKNOWN
278281

279-
phase_upper = phase_str.lower()
280-
if phase_upper == "reader":
282+
phase_lower = role_str.lower()
283+
if phase_lower == "reader":
281284
return HostRole.READER
282-
elif phase_upper == "writer":
285+
elif phase_lower == "writer":
283286
return HostRole.WRITER
284287
else:
285288
raise ValueError(

aws_advanced_python_wrapper/sql_alchemy_connection_provider.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,11 @@ def _create_sql_alchemy_pool(self, **kwargs):
173173

174174
def release_resources(self):
175175
for _, cache_item in SqlAlchemyPooledConnectionProvider._database_pools.items():
176-
cache_item.item.dispose()
176+
try:
177+
cache_item.item.dispose()
178+
except Exception:
179+
# Swallow exception, connections may already be dead
180+
pass
177181
SqlAlchemyPooledConnectionProvider._database_pools.clear()
178182

179183

tests/integration/container/test_read_write_splitting.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,10 @@ def rds_utils(self):
8080

8181
@pytest.fixture(autouse=True)
8282
def clear_caches(self):
83-
# Clear wrapper caches
8483
RdsHostListProvider._topology_cache.clear()
8584
RdsHostListProvider._is_primary_cluster_id_cache.clear()
8685
RdsHostListProvider._cluster_ids_to_update.clear()
8786

88-
# Force DNS refresh by clearing any internal connection pools
89-
ConnectionProviderManager.release_resources()
90-
ConnectionProviderManager.reset_provider()
91-
9287
@pytest.fixture
9388
def props(self, plugin_config, conn_utils):
9489
plugin_name, plugin_value = plugin_config

0 commit comments

Comments
 (0)