Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "Connection/DbConnectionBase.h"
#include "Connection/DbConnectionBuilder.h"
#include "Connection/Credentials.h"
#include "Containers/Ticker.h"
#include "ModuleBindings/Types/ClientMessageType.g.h"
#include "ModuleBindings/Types/SubscribeMultiType.g.h"
#include "ModuleBindings/Types/UnsubscribeMultiType.g.h"
Expand All @@ -23,6 +24,69 @@ UDbConnectionBase::UDbConnectionBase(const FObjectInitializer& ObjectInitializer
ProcedureCallbacks = CreateDefaultSubobject<UProcedureCallbacks>(TEXT("ProcedureCallbacks"));
}

UDbConnectionBase::~UDbConnectionBase()
{
// Ensure we unregister from the ticker when destroyed
if (TickerHandle.IsValid())
{
FTSTicker::GetCoreTicker().RemoveTicker(TickerHandle);
TickerHandle.Reset();
}
}

void UDbConnectionBase::SetAutoTicking(bool bAutoTick)
{
if (bIsAutoTicking == bAutoTick)
{
return; // No change needed
}

bIsAutoTicking = bAutoTick;

if (bAutoTick)
{
// Register with FTSTicker for automatic frame ticking
TickerHandle = FTSTicker::GetCoreTicker().AddTicker(FTickerDelegate::CreateUObject(this, &UDbConnectionBase::OnTickerTick));
}
else
{
// Unregister from FTSTicker
if (TickerHandle.IsValid())
{
FTSTicker::GetCoreTicker().RemoveTicker(TickerHandle);
TickerHandle.Reset();
}
}
}

int32 UDbConnectionBase::GetActiveSubscriptionCount() const
{
// Thread-safe access to active subscriptions count
return ActiveSubscriptions.Num();
}

int32 UDbConnectionBase::GetPendingMessageCount() const
{
// Return count of pending messages
return PendingMessages.Num();
}

int32 UDbConnectionBase::GetPreprocessedMessageCount() const
{
return PreprocessedMessages.Num();
}

bool UDbConnectionBase::OnTickerTick(float DeltaTime)
{
// Called by FTSTicker each frame when auto-ticking is enabled
if (bIsAutoTicking)
{
FrameTick();
}
// Return true to continue ticking
return true;
}

void UDbConnectionBase::Disconnect()
{
if (WebSocket)
Expand Down Expand Up @@ -87,6 +151,10 @@ void UDbConnectionBase::HandleWSClosed(int32 /*StatusCode*/, const FString& Reas

void UDbConnectionBase::HandleWSBinaryMessage(const TArray<uint8>& Message)
{
// Track message stats for memory diagnostics
TotalMessagesReceived.fetch_add(1);
TotalBytesReceived.fetch_add(Message.Num());

//tag for arrival order
const int32 Id = NextPreprocessId.GetValue();
NextPreprocessId.Increment();
Expand Down Expand Up @@ -148,30 +216,6 @@ void UDbConnectionBase::FrameTick()
ProcessServerMessage(Msg);
}
}
void UDbConnectionBase::Tick(float DeltaTime)
{
if (bIsAutoTicking)
{
FrameTick();
}
}

TStatId UDbConnectionBase::GetStatId() const
{
// This is used by the engine to track tickables, we return a unique stat ID for this class
RETURN_QUICK_DECLARE_CYCLE_STAT(UMyTickableObject, STATGROUP_Tickables);
}

bool UDbConnectionBase::IsTickable() const
{
return bIsAutoTicking;
}

bool UDbConnectionBase::IsTickableInEditor() const
{
return bIsAutoTicking;
}


void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message)
{
Expand Down Expand Up @@ -381,22 +425,38 @@ bool UDbConnectionBase::DecompressBrotli(const TArray<uint8>& InData, TArray<uin

bool UDbConnectionBase::DecompressGzip(const TArray<uint8>& InData, TArray<uint8>& OutData)
{
if (InData.Num() < 4)
if (InData.Num() < 10) // Minimum gzip header size
{
UE_LOG(LogTemp, Error, TEXT("Gzip data too small"));
UE_LOG(LogTemp, Warning, TEXT("Gzip data too small (%d bytes), likely incomplete"), InData.Num());
return false;
}

// Verify gzip magic header (1F 8B)
if (InData[0] != 0x1F || InData[1] != 0x8B)
{
UE_LOG(LogTemp, Warning, TEXT("Invalid gzip header: %02X %02X (expected 1F 8B)"), InData[0], InData[1]);
return false;
}

// Gzip data ends with 4 bytes indicating the uncompressed size
const uint8* SizePtr = InData.GetData() + InData.Num() - 4;
uint32 OutSize = SizePtr[0] | (SizePtr[1] << 8) | (SizePtr[2] << 16) | (SizePtr[3] << 24);

// Validate the output size
// Validate the output size - reject obviously invalid sizes that would indicate truncated data
// Max reasonable size: 100MB. If size is larger, data is likely incomplete/corrupt
constexpr uint32 MaxReasonableSize = 100 * 1024 * 1024;
if (OutSize > MaxReasonableSize)
{
UE_LOG(LogTemp, Warning, TEXT("Gzip claims uncompressed size of %u bytes - likely incomplete data, buffering"), OutSize);
return false;
}

OutData.SetNumUninitialized(OutSize);
// Attempt to decompress the Gzip data
if (!FCompression::UncompressMemory(NAME_Gzip, OutData.GetData(), OutSize, InData.GetData(), InData.Num()))
{
UE_LOG(LogTemp, Error, TEXT("Gzip decompression failed"));
UE_LOG(LogTemp, Warning, TEXT("Gzip decompression failed - data may be incomplete"));
OutData.Reset();
return false;
}

Expand Down Expand Up @@ -505,22 +565,73 @@ FServerMessageType UDbConnectionBase::PreProcessMessage(const TArray<uint8>& Mes
{
if (Message.Num() == 0)
{
UE_LOG(LogTemp, Error, TEXT("Empty message recived from server, ignored"));
UE_LOG(LogTemp, Error, TEXT("Empty message received from server, ignored"));
return FServerMessageType{};
}
// Check if the first byte is a valid compression tag
ECompressableQueryUpdateTag Compression = static_cast<ECompressableQueryUpdateTag>(Message[0]);
TArray<uint8> CompressedPayload;
CompressedPayload.Append(Message.GetData() + 1, Message.Num() - 1);

// Decompress the payload based on the compression tag
TArray<uint8> DataToProcess;
ECompressableQueryUpdateTag Compression;
bool bWasAccumulating = false;

// Thread-safe access to compressed message accumulation buffer
{
FScopeLock Lock(&CompressedBufferMutex);

// Check if we're accumulating a fragmented compressed message
if (bAccumulatingCompressedMessage)
{
// Append incoming data to the buffer (no compression tag on continuation)
IncompleteCompressedBuffer.Append(Message.GetData(), Message.Num());
DataToProcess = IncompleteCompressedBuffer;
Compression = BufferedCompressionType;
bWasAccumulating = true;
}
else
{
// New message - check compression tag
uint8 FirstByte = Message[0];
if (FirstByte > 2)
{
UE_LOG(LogTemp, Error, TEXT("PreProcessMessage: Invalid compression tag %d"), FirstByte);
return FServerMessageType{};
}
Compression = static_cast<ECompressableQueryUpdateTag>(FirstByte);
DataToProcess.Append(Message.GetData() + 1, Message.Num() - 1);
}
}

// Decompress the payload based on the compression tag (outside mutex for parallelism)
TArray<uint8> Decompressed;
if (!DecompressPayload(Compression, CompressedPayload, Decompressed))
if (!DecompressPayload(Compression, DataToProcess, Decompressed))
{
// Decompression failed - if it's a compressed format, buffer for more data
if (Compression == ECompressableQueryUpdateTag::Gzip || Compression == ECompressableQueryUpdateTag::Brotli)
{
FScopeLock Lock(&CompressedBufferMutex);
if (!bAccumulatingCompressedMessage)
{
// Start accumulating
IncompleteCompressedBuffer = DataToProcess;
BufferedCompressionType = Compression;
}
// else: already accumulating, buffer was updated above
bAccumulatingCompressedMessage = true;
return FServerMessageType{}; // Return empty, will process when complete
}
UE_LOG(LogTemp, Error, TEXT("Failed to decompress incoming message"));
return FServerMessageType{};
}

// Decompression succeeded - clear accumulation state
{
FScopeLock Lock(&CompressedBufferMutex);
if (bAccumulatingCompressedMessage)
{
IncompleteCompressedBuffer.Reset();
bAccumulatingCompressedMessage = false;
}
}

// Deserialize the decompressed data into a UServerMessageType object
FServerMessageType Parsed = UE::SpacetimeDB::Deserialize<FServerMessageType>(Decompressed);

Expand Down Expand Up @@ -702,4 +813,9 @@ void UDbConnectionBase::ApplyRegisteredTableUpdates(const FDatabaseUpdateType& U
// Broadcast the diff for each handler
Handler->BroadcastDiff(this, Context);
}
}

int32 UDbConnectionBase::GetPreprocessedTableDataCount() const
{
return PreprocessedTableData.Num();
}
Loading