Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1618,6 +1618,14 @@ private void LoginNoFailover(ServerInfo serverInfo,

if (RoutingInfo != null)
{
// Check if we received enhanced routing info, but not the ack for the feature.
// In this case, we should ignore the routing info and connect to the current server.
if (!string.IsNullOrEmpty(RoutingInfo.DatabaseName) && !IsEnhancedRoutingSupportEnabled)
{
RoutingInfo = null;
break;
}

SqlClientEventSource.Log.TryTraceEvent("<sc.SqlInternalConnectionTds.LoginNoFailover> Routed to {0}", serverInfo.ExtendedServerName);
if (routingAttempts > MaxNumberOfRedirectRoute)
{
Expand Down Expand Up @@ -1879,6 +1887,14 @@ TimeoutTimer timeout
int routingAttempts = 0;
while (RoutingInfo != null)
{
// Check if we received enhanced routing info, but not the ack for the feature.
// In this case, we should ignore the routing info and connect to the current server.
if (!string.IsNullOrEmpty(RoutingInfo.DatabaseName) && !IsEnhancedRoutingSupportEnabled)
{
RoutingInfo = null;
continue;
}

if (routingAttempts > MaxNumberOfRedirectRoute)
{
throw SQL.ROR_RecursiveRoutingNotSupported(this, MaxNumberOfRedirectRoute);
Expand Down Expand Up @@ -2723,7 +2739,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)

internal void OnFeatureExtAck(int featureId, byte[] data)
{
if (RoutingInfo != null && featureId != TdsEnums.FEATUREEXT_SQLDNSCACHING)
if (RoutingInfo != null && featureId != TdsEnums.FEATUREEXT_SQLDNSCACHING && featureId != TdsEnums.FEATUREEXT_ENHANCEDROUTINGSUPPORT)
{
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1648,6 +1648,14 @@ private void LoginNoFailover(ServerInfo serverInfo,

if (RoutingInfo != null)
{
// Check if we received enhanced routing info, but not the ack for the feature.
// In this case, we should ignore the routing info and connect to the current server.
if (!string.IsNullOrEmpty(RoutingInfo.DatabaseName) && !IsEnhancedRoutingSupportEnabled)
{
RoutingInfo = null;
break;
}

SqlClientEventSource.Log.TryTraceEvent("<sc.SqlInternalConnectionTds.LoginNoFailover> Routed to {0}", serverInfo.ExtendedServerName);
if (routingAttempts > MaxNumberOfRedirectRoute)
{
Expand Down Expand Up @@ -1933,6 +1941,14 @@ TimeoutTimer timeout
int routingAttempts = 0;
while (RoutingInfo != null)
{
// Check if we received enhanced routing info, but not the ack for the feature.
// In this case, we should ignore the routing info and connect to the current server.
if (!string.IsNullOrEmpty(RoutingInfo.DatabaseName) && !IsEnhancedRoutingSupportEnabled)
{
RoutingInfo = null;
continue;
}

if (routingAttempts > MaxNumberOfRedirectRoute)
{
throw SQL.ROR_RecursiveRoutingNotSupported(this, MaxNumberOfRedirectRoute);
Expand Down Expand Up @@ -2766,7 +2782,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)

internal void OnFeatureExtAck(int featureId, byte[] data)
{
if (RoutingInfo != null && featureId != TdsEnums.FEATUREEXT_SQLDNSCACHING)
if (RoutingInfo != null && featureId != TdsEnums.FEATUREEXT_SQLDNSCACHING && featureId != TdsEnums.FEATUREEXT_ENHANCEDROUTINGSUPPORT)
{
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,185 +16,111 @@ namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests;
[Collection("SimulatedServerTests")]
public class ConnectionEnhancedRoutingTests
{
[Fact]
public void RoutedConnection()
/// <summary>
/// Tests that a connection is routed to the target server when enhanced routing is enabled.
/// Uses Theory to test both sync and async code paths.
/// </summary>
[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task RoutedConnection(bool useAsync)
{
// Arrange
using TdsServer server = new(new());
server.Start();
using TestRoutingServers servers = new(FeatureExtensionBehavior.Enabled);

string routingDatabaseName = Guid.NewGuid().ToString();
bool clientProvidedCorrectDatabase = false;
server.OnLogin7Validated = loginToken =>
servers.TargetServer.OnLogin7Validated = loginToken =>
{
clientProvidedCorrectDatabase = routingDatabaseName == loginToken.Database;
clientProvidedCorrectDatabase = servers.RoutingDatabaseName == loginToken.Database;
};

RoutingTdsServer router = new(
new RoutingTdsServerArguments()
{
RoutingTCPHost = "localhost",
RoutingTCPPort = (ushort)server.EndPoint.Port,
RoutingDatabaseName = routingDatabaseName,
RequireReadOnly = false
});
router.Start();
router.EnhancedRoutingBehavior = FeatureExtensionBehavior.Enabled;

string connectionString = (new SqlConnectionStringBuilder()
{
DataSource = $"localhost,{router.EndPoint.Port}",
Encrypt = false,
ConnectTimeout = 10000
}).ConnectionString;

// Act
using SqlConnection connection = new(connectionString);
connection.Open();

// Assert
Assert.Equal(ConnectionState.Open, connection.State);
Assert.Equal($"localhost,{server.EndPoint.Port}", ((SqlInternalConnectionTds)connection.InnerConnection).RoutingDestination);
Assert.Equal(routingDatabaseName, connection.Database);
Assert.True(clientProvidedCorrectDatabase);

Assert.Equal(1, router.PreLoginCount);
Assert.Equal(1, server.PreLoginCount);
}

[Fact]
public async Task RoutedAsyncConnection()
{
// Arrange
using TdsServer server = new(new());
server.Start();

string routingDatabaseName = Guid.NewGuid().ToString();
bool clientProvidedCorrectDatabase = false;
server.OnLogin7Validated = loginToken =>
using SqlConnection connection = new(servers.ConnectionString);
if (useAsync)
{
clientProvidedCorrectDatabase = routingDatabaseName == loginToken.Database;
};

RoutingTdsServer router = new(
new RoutingTdsServerArguments()
{
RoutingTCPHost = "localhost",
RoutingTCPPort = (ushort)server.EndPoint.Port,
RoutingDatabaseName = routingDatabaseName,
RequireReadOnly = false
});
router.Start();
router.EnhancedRoutingBehavior = FeatureExtensionBehavior.Enabled;

string connectionString = (new SqlConnectionStringBuilder()
await connection.OpenAsync();
}
else
{
DataSource = $"localhost,{router.EndPoint.Port}",
Encrypt = false,
ConnectTimeout = 10000
}).ConnectionString;

// Act
using SqlConnection connection = new(connectionString);
await connection.OpenAsync();
connection.Open();
}

// Assert
Assert.Equal(ConnectionState.Open, connection.State);
Assert.Equal($"localhost,{server.EndPoint.Port}", ((SqlInternalConnectionTds)connection.InnerConnection).RoutingDestination);
Assert.Equal(routingDatabaseName, connection.Database);
Assert.Equal($"localhost,{servers.TargetServer.EndPoint.Port}", ((SqlInternalConnectionTds)connection.InnerConnection).RoutingDestination);
Assert.Equal(servers.RoutingDatabaseName, connection.Database);
Assert.True(clientProvidedCorrectDatabase);

Assert.Equal(1, router.PreLoginCount);
Assert.Equal(1, server.PreLoginCount);
Assert.Equal(1, servers.Router.PreLoginCount);
Assert.Equal(1, servers.TargetServer.PreLoginCount);
}

[Fact]
public void ServerIgnoresEnhancedRoutingRequest()
/// <summary>
/// Tests that a connection is NOT routed when the server does not acknowledge the enhanced routing feature
/// or has it disabled. Covers both DoNotAcknowledge and Disabled behaviors.
/// </summary>
[Theory]
[InlineData(FeatureExtensionBehavior.DoNotAcknowledge)]
[InlineData(FeatureExtensionBehavior.Disabled)]
public void ServerDoesNotRoute(FeatureExtensionBehavior behavior)
{
// Arrange
using TdsServer server = new(new());
server.Start();

string routingDatabaseName = Guid.NewGuid().ToString();
bool clientProvidedCorrectDatabase = false;
server.OnLogin7Validated = loginToken =>
{
clientProvidedCorrectDatabase = null == loginToken.Database;
};

RoutingTdsServer router = new(
new RoutingTdsServerArguments()
{
RoutingTCPHost = "localhost",
RoutingTCPPort = (ushort)server.EndPoint.Port,
RequireReadOnly = false
});
router.Start();
router.EnhancedRoutingBehavior = FeatureExtensionBehavior.DoNotAcknowledge;

string connectionString = (new SqlConnectionStringBuilder()
{
DataSource = $"localhost,{router.EndPoint.Port}",
Encrypt = false,
ConnectTimeout = 10000
}).ConnectionString;
using TestRoutingServers servers = new(behavior);

// Act
using SqlConnection connection = new(connectionString);
using SqlConnection connection = new(servers.ConnectionString);
connection.Open();

// Assert
Assert.Equal(ConnectionState.Open, connection.State);
Assert.Equal($"localhost,{server.EndPoint.Port}", ((SqlInternalConnectionTds)connection.InnerConnection).RoutingDestination);
Assert.Null(((SqlInternalConnectionTds)connection.InnerConnection).RoutingDestination);
Assert.Equal("master", connection.Database);
Assert.True(clientProvidedCorrectDatabase);

Assert.Equal(1, router.PreLoginCount);
Assert.Equal(1, server.PreLoginCount);
Assert.Equal(1, servers.Router.PreLoginCount);
Assert.Equal(0, servers.TargetServer.PreLoginCount);
}

[Fact]
public void ServerRejectsEnhancedRoutingRequest()
/// <summary>
/// Helper class that encapsulates the setup of a routing TDS server and target TDS server
/// for enhanced routing tests.
/// </summary>
private sealed class TestRoutingServers : IDisposable
{
// Arrange
using TdsServer server = new(new());
server.Start();
public TdsServer TargetServer { get; }
public RoutingTdsServer Router { get; }
public string RoutingDatabaseName { get; }
public string ConnectionString { get; }

string routingDatabaseName = Guid.NewGuid().ToString();
bool clientProvidedCorrectDatabase = false;
server.OnLogin7Validated = loginToken =>
public TestRoutingServers(FeatureExtensionBehavior enhancedRoutingBehavior)
{
clientProvidedCorrectDatabase = null == loginToken.Database;
};

RoutingTdsServer router = new(
new RoutingTdsServerArguments()
RoutingDatabaseName = Guid.NewGuid().ToString();

TargetServer = new TdsServer(new());
TargetServer.Start();

Router = new RoutingTdsServer(
new RoutingTdsServerArguments()
{
RoutingTCPHost = "localhost",
RoutingTCPPort = (ushort)TargetServer.EndPoint.Port,
RoutingDatabaseName = RoutingDatabaseName,
RequireReadOnly = false
});
Router.Start();
Router.EnhancedRoutingBehavior = enhancedRoutingBehavior;

ConnectionString = new SqlConnectionStringBuilder()
{
RoutingTCPHost = "localhost",
RoutingTCPPort = (ushort)server.EndPoint.Port,
RequireReadOnly = false
});
router.Start();
router.EnhancedRoutingBehavior = FeatureExtensionBehavior.Disabled;

string connectionString = (new SqlConnectionStringBuilder()
{
DataSource = $"localhost,{router.EndPoint.Port}",
Encrypt = false,
ConnectTimeout = 10000
}).ConnectionString;

// Act
using SqlConnection connection = new(connectionString);
connection.Open();

// Assert
Assert.Equal(ConnectionState.Open, connection.State);
Assert.Equal($"localhost,{server.EndPoint.Port}", ((SqlInternalConnectionTds)connection.InnerConnection).RoutingDestination);
Assert.Equal("master", connection.Database);
Assert.True(clientProvidedCorrectDatabase);
DataSource = $"localhost,{Router.EndPoint.Port}",
Encrypt = false,
ConnectTimeout = 10000
}.ConnectionString;
}

Assert.Equal(1, router.PreLoginCount);
Assert.Equal(1, server.PreLoginCount);
public void Dispose()
{
Router?.Dispose();
TargetServer?.Dispose();
}
}
}
Loading