diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 9ce4eccb71..cfd42f3e60 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -1618,6 +1618,14 @@ private void LoginNoFailover(ServerInfo serverInfo, if (RoutingInfo != null) { + // Check if we received enhanced routing info, but not the ack for the feature. + // In this case, we should ignore the routing info and connect to the current server. + if (!string.IsNullOrEmpty(RoutingInfo.DatabaseName) && !IsEnhancedRoutingSupportEnabled) + { + RoutingInfo = null; + break; + } + SqlClientEventSource.Log.TryTraceEvent(" Routed to {0}", serverInfo.ExtendedServerName); if (routingAttempts > MaxNumberOfRedirectRoute) { @@ -1879,6 +1887,14 @@ TimeoutTimer timeout int routingAttempts = 0; while (RoutingInfo != null) { + // Check if we received enhanced routing info, but not the ack for the feature. + // In this case, we should ignore the routing info and connect to the current server. + if (!string.IsNullOrEmpty(RoutingInfo.DatabaseName) && !IsEnhancedRoutingSupportEnabled) + { + RoutingInfo = null; + continue; + } + if (routingAttempts > MaxNumberOfRedirectRoute) { throw SQL.ROR_RecursiveRoutingNotSupported(this, MaxNumberOfRedirectRoute); @@ -2723,7 +2739,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo) internal void OnFeatureExtAck(int featureId, byte[] data) { - if (RoutingInfo != null && featureId != TdsEnums.FEATUREEXT_SQLDNSCACHING) + if (RoutingInfo != null && featureId != TdsEnums.FEATUREEXT_SQLDNSCACHING && featureId != TdsEnums.FEATUREEXT_ENHANCEDROUTINGSUPPORT) { return; } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 8388133898..857aca2f81 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -1648,6 +1648,14 @@ private void LoginNoFailover(ServerInfo serverInfo, if (RoutingInfo != null) { + // Check if we received enhanced routing info, but not the ack for the feature. + // In this case, we should ignore the routing info and connect to the current server. + if (!string.IsNullOrEmpty(RoutingInfo.DatabaseName) && !IsEnhancedRoutingSupportEnabled) + { + RoutingInfo = null; + break; + } + SqlClientEventSource.Log.TryTraceEvent(" Routed to {0}", serverInfo.ExtendedServerName); if (routingAttempts > MaxNumberOfRedirectRoute) { @@ -1933,6 +1941,14 @@ TimeoutTimer timeout int routingAttempts = 0; while (RoutingInfo != null) { + // Check if we received enhanced routing info, but not the ack for the feature. + // In this case, we should ignore the routing info and connect to the current server. + if (!string.IsNullOrEmpty(RoutingInfo.DatabaseName) && !IsEnhancedRoutingSupportEnabled) + { + RoutingInfo = null; + continue; + } + if (routingAttempts > MaxNumberOfRedirectRoute) { throw SQL.ROR_RecursiveRoutingNotSupported(this, MaxNumberOfRedirectRoute); @@ -2766,7 +2782,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo) internal void OnFeatureExtAck(int featureId, byte[] data) { - if (RoutingInfo != null && featureId != TdsEnums.FEATUREEXT_SQLDNSCACHING) + if (RoutingInfo != null && featureId != TdsEnums.FEATUREEXT_SQLDNSCACHING && featureId != TdsEnums.FEATUREEXT_ENHANCEDROUTINGSUPPORT) { return; } diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionEnhancedRoutingTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionEnhancedRoutingTests.cs index 7852f91f31..e0cec1110f 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionEnhancedRoutingTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionEnhancedRoutingTests.cs @@ -16,185 +16,111 @@ namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests; [Collection("SimulatedServerTests")] public class ConnectionEnhancedRoutingTests { - [Fact] - public void RoutedConnection() + /// + /// Tests that a connection is routed to the target server when enhanced routing is enabled. + /// Uses Theory to test both sync and async code paths. + /// + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task RoutedConnection(bool useAsync) { // Arrange - using TdsServer server = new(new()); - server.Start(); + using TestRoutingServers servers = new(FeatureExtensionBehavior.Enabled); - string routingDatabaseName = Guid.NewGuid().ToString(); bool clientProvidedCorrectDatabase = false; - server.OnLogin7Validated = loginToken => + servers.TargetServer.OnLogin7Validated = loginToken => { - clientProvidedCorrectDatabase = routingDatabaseName == loginToken.Database; + clientProvidedCorrectDatabase = servers.RoutingDatabaseName == loginToken.Database; }; - RoutingTdsServer router = new( - new RoutingTdsServerArguments() - { - RoutingTCPHost = "localhost", - RoutingTCPPort = (ushort)server.EndPoint.Port, - RoutingDatabaseName = routingDatabaseName, - RequireReadOnly = false - }); - router.Start(); - router.EnhancedRoutingBehavior = FeatureExtensionBehavior.Enabled; - - string connectionString = (new SqlConnectionStringBuilder() - { - DataSource = $"localhost,{router.EndPoint.Port}", - Encrypt = false, - ConnectTimeout = 10000 - }).ConnectionString; - // Act - using SqlConnection connection = new(connectionString); - connection.Open(); - - // Assert - Assert.Equal(ConnectionState.Open, connection.State); - Assert.Equal($"localhost,{server.EndPoint.Port}", ((SqlInternalConnectionTds)connection.InnerConnection).RoutingDestination); - Assert.Equal(routingDatabaseName, connection.Database); - Assert.True(clientProvidedCorrectDatabase); - - Assert.Equal(1, router.PreLoginCount); - Assert.Equal(1, server.PreLoginCount); - } - - [Fact] - public async Task RoutedAsyncConnection() - { - // Arrange - using TdsServer server = new(new()); - server.Start(); - - string routingDatabaseName = Guid.NewGuid().ToString(); - bool clientProvidedCorrectDatabase = false; - server.OnLogin7Validated = loginToken => + using SqlConnection connection = new(servers.ConnectionString); + if (useAsync) { - clientProvidedCorrectDatabase = routingDatabaseName == loginToken.Database; - }; - - RoutingTdsServer router = new( - new RoutingTdsServerArguments() - { - RoutingTCPHost = "localhost", - RoutingTCPPort = (ushort)server.EndPoint.Port, - RoutingDatabaseName = routingDatabaseName, - RequireReadOnly = false - }); - router.Start(); - router.EnhancedRoutingBehavior = FeatureExtensionBehavior.Enabled; - - string connectionString = (new SqlConnectionStringBuilder() + await connection.OpenAsync(); + } + else { - DataSource = $"localhost,{router.EndPoint.Port}", - Encrypt = false, - ConnectTimeout = 10000 - }).ConnectionString; - - // Act - using SqlConnection connection = new(connectionString); - await connection.OpenAsync(); + connection.Open(); + } // Assert Assert.Equal(ConnectionState.Open, connection.State); - Assert.Equal($"localhost,{server.EndPoint.Port}", ((SqlInternalConnectionTds)connection.InnerConnection).RoutingDestination); - Assert.Equal(routingDatabaseName, connection.Database); + Assert.Equal($"localhost,{servers.TargetServer.EndPoint.Port}", ((SqlInternalConnectionTds)connection.InnerConnection).RoutingDestination); + Assert.Equal(servers.RoutingDatabaseName, connection.Database); Assert.True(clientProvidedCorrectDatabase); - Assert.Equal(1, router.PreLoginCount); - Assert.Equal(1, server.PreLoginCount); + Assert.Equal(1, servers.Router.PreLoginCount); + Assert.Equal(1, servers.TargetServer.PreLoginCount); } - [Fact] - public void ServerIgnoresEnhancedRoutingRequest() + /// + /// Tests that a connection is NOT routed when the server does not acknowledge the enhanced routing feature + /// or has it disabled. Covers both DoNotAcknowledge and Disabled behaviors. + /// + [Theory] + [InlineData(FeatureExtensionBehavior.DoNotAcknowledge)] + [InlineData(FeatureExtensionBehavior.Disabled)] + public void ServerDoesNotRoute(FeatureExtensionBehavior behavior) { // Arrange - using TdsServer server = new(new()); - server.Start(); - - string routingDatabaseName = Guid.NewGuid().ToString(); - bool clientProvidedCorrectDatabase = false; - server.OnLogin7Validated = loginToken => - { - clientProvidedCorrectDatabase = null == loginToken.Database; - }; - - RoutingTdsServer router = new( - new RoutingTdsServerArguments() - { - RoutingTCPHost = "localhost", - RoutingTCPPort = (ushort)server.EndPoint.Port, - RequireReadOnly = false - }); - router.Start(); - router.EnhancedRoutingBehavior = FeatureExtensionBehavior.DoNotAcknowledge; - - string connectionString = (new SqlConnectionStringBuilder() - { - DataSource = $"localhost,{router.EndPoint.Port}", - Encrypt = false, - ConnectTimeout = 10000 - }).ConnectionString; + using TestRoutingServers servers = new(behavior); // Act - using SqlConnection connection = new(connectionString); + using SqlConnection connection = new(servers.ConnectionString); connection.Open(); // Assert Assert.Equal(ConnectionState.Open, connection.State); - Assert.Equal($"localhost,{server.EndPoint.Port}", ((SqlInternalConnectionTds)connection.InnerConnection).RoutingDestination); + Assert.Null(((SqlInternalConnectionTds)connection.InnerConnection).RoutingDestination); Assert.Equal("master", connection.Database); - Assert.True(clientProvidedCorrectDatabase); - Assert.Equal(1, router.PreLoginCount); - Assert.Equal(1, server.PreLoginCount); + Assert.Equal(1, servers.Router.PreLoginCount); + Assert.Equal(0, servers.TargetServer.PreLoginCount); } - [Fact] - public void ServerRejectsEnhancedRoutingRequest() + /// + /// Helper class that encapsulates the setup of a routing TDS server and target TDS server + /// for enhanced routing tests. + /// + private sealed class TestRoutingServers : IDisposable { - // Arrange - using TdsServer server = new(new()); - server.Start(); + public TdsServer TargetServer { get; } + public RoutingTdsServer Router { get; } + public string RoutingDatabaseName { get; } + public string ConnectionString { get; } - string routingDatabaseName = Guid.NewGuid().ToString(); - bool clientProvidedCorrectDatabase = false; - server.OnLogin7Validated = loginToken => + public TestRoutingServers(FeatureExtensionBehavior enhancedRoutingBehavior) { - clientProvidedCorrectDatabase = null == loginToken.Database; - }; - - RoutingTdsServer router = new( - new RoutingTdsServerArguments() + RoutingDatabaseName = Guid.NewGuid().ToString(); + + TargetServer = new TdsServer(new()); + TargetServer.Start(); + + Router = new RoutingTdsServer( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)TargetServer.EndPoint.Port, + RoutingDatabaseName = RoutingDatabaseName, + RequireReadOnly = false + }); + Router.Start(); + Router.EnhancedRoutingBehavior = enhancedRoutingBehavior; + + ConnectionString = new SqlConnectionStringBuilder() { - RoutingTCPHost = "localhost", - RoutingTCPPort = (ushort)server.EndPoint.Port, - RequireReadOnly = false - }); - router.Start(); - router.EnhancedRoutingBehavior = FeatureExtensionBehavior.Disabled; - - string connectionString = (new SqlConnectionStringBuilder() - { - DataSource = $"localhost,{router.EndPoint.Port}", - Encrypt = false, - ConnectTimeout = 10000 - }).ConnectionString; - - // Act - using SqlConnection connection = new(connectionString); - connection.Open(); - - // Assert - Assert.Equal(ConnectionState.Open, connection.State); - Assert.Equal($"localhost,{server.EndPoint.Port}", ((SqlInternalConnectionTds)connection.InnerConnection).RoutingDestination); - Assert.Equal("master", connection.Database); - Assert.True(clientProvidedCorrectDatabase); + DataSource = $"localhost,{Router.EndPoint.Port}", + Encrypt = false, + ConnectTimeout = 10000 + }.ConnectionString; + } - Assert.Equal(1, router.PreLoginCount); - Assert.Equal(1, server.PreLoginCount); + public void Dispose() + { + Router?.Dispose(); + TargetServer?.Dispose(); + } } }