Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 44 additions & 25 deletions InertiaCore/Extensions/Configure.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
using System.Net;
using InertiaCore.Models;
using InertiaCore.Ssr;
using InertiaCore.Utils;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.ViewFeatures;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Logging;

namespace InertiaCore.Extensions;

Expand All @@ -25,20 +25,52 @@ public static IApplicationBuilder UseInertia(this IApplicationBuilder app)
Inertia.Version(Vite.GetManifestHash);
}

app.Use(async (context, next) =>
// Check if TempData services are available for error bag functionality
CheckTempDataAvailability(app);

app.UseMiddleware<Middleware>();

return app;
}

private static void CheckTempDataAvailability(IApplicationBuilder app)
{
// Skip warning in test environments
var environment = app.ApplicationServices.GetService<IWebHostEnvironment>();
if (environment?.EnvironmentName == "Test" ||
(environment?.EnvironmentName != "Development" && IsTestEnvironment()))
{
if (context.IsInertiaRequest()
&& context.Request.Method == "GET"
&& context.Request.Headers[InertiaHeader.Version] != Inertia.GetVersion())
return;
}

try
{
var tempDataFactory = app.ApplicationServices.GetService<ITempDataDictionaryFactory>();
if (tempDataFactory == null)
{
await OnVersionChange(context, app);
return;
var logger = app.ApplicationServices.GetService<ILogger<IApplicationBuilder>>();
logger?.LogWarning("TempData services are not configured. Error bag functionality will be limited. " +
"Consider adding services.AddSession() and app.UseSession() to enable full error bag support.");
}
}
catch (Exception)
{
// If we can't check for TempData services, that's also a sign they might not be configured
var logger = app.ApplicationServices.GetService<ILogger<IApplicationBuilder>>();
logger?.LogWarning("Unable to verify TempData configuration. Error bag functionality may be limited. " +
"Ensure services.AddSession() and app.UseSession() are configured for full error bag support.");
}
}

await next();
});

return app;
private static bool IsTestEnvironment()
{
// Check if we're running in a test context by looking for common test assemblies
var assemblies = AppDomain.CurrentDomain.GetAssemblies();
return assemblies.Any(a =>
a.FullName?.Contains("nunit", StringComparison.OrdinalIgnoreCase) == true ||
a.FullName?.Contains("xunit", StringComparison.OrdinalIgnoreCase) == true ||
a.FullName?.Contains("mstest", StringComparison.OrdinalIgnoreCase) == true ||
a.FullName?.Contains("testhost", StringComparison.OrdinalIgnoreCase) == true);
}

public static IServiceCollection AddInertia(this IServiceCollection services,
Expand Down Expand Up @@ -76,17 +108,4 @@ public static IServiceCollection AddViteHelper(this IServiceCollection services,

return services;
}

private static async Task OnVersionChange(HttpContext context, IApplicationBuilder app)
{
var tempData = app.ApplicationServices.GetRequiredService<ITempDataDictionaryFactory>()
.GetTempData(context);

if (tempData.Any()) tempData.Keep();

context.Response.Headers.Override(InertiaHeader.Location, context.RequestedUri());
context.Response.StatusCode = (int)HttpStatusCode.Conflict;

await context.Response.CompleteAsync();
}
}
138 changes: 138 additions & 0 deletions InertiaCore/Extensions/InertiaExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Extensions;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.ModelBinding;
using Microsoft.AspNetCore.Mvc.ViewFeatures;
using Microsoft.Extensions.DependencyInjection;
using System.Text;

namespace InertiaCore.Extensions;
Expand Down Expand Up @@ -63,4 +66,139 @@ internal static string MD5(this string s)

return sb.ToString();
}

/// <summary>
/// Gets the TempData dictionary for the current HTTP context.
/// </summary>
internal static ITempDataDictionary? GetTempData(this HttpContext context)
{
try
{
var tempDataFactory = context.RequestServices?.GetRequiredService<ITempDataDictionaryFactory>();
return tempDataFactory?.GetTempData(context);
}
catch (InvalidOperationException)
{
// Service provider not available, return null
return null;
}
}

/// <summary>
/// Sets validation errors in TempData for the specified error bag.
/// </summary>
public static void SetValidationErrors(this ITempDataDictionary tempData, Dictionary<string, string> errors,
string bagName = "default")
{
// Deserialize existing error bags from JSON
var errorBags = new Dictionary<string, Dictionary<string, string>>();
if (tempData["__ValidationErrors"] is string existingJson && !string.IsNullOrEmpty(existingJson))
{
try
{
errorBags = JsonSerializer.Deserialize<Dictionary<string, Dictionary<string, string>>>(existingJson)
?? new Dictionary<string, Dictionary<string, string>>();
}
catch (JsonException)
{
// If deserialization fails, start fresh
errorBags = new Dictionary<string, Dictionary<string, string>>();
}
}

errorBags[bagName] = errors;

// Serialize back to JSON for storage
tempData["__ValidationErrors"] = JsonSerializer.Serialize(errorBags);
}

/// <summary>
/// Sets validation errors in TempData from ModelState for the specified error bag.
/// </summary>
public static void SetValidationErrors(this ITempDataDictionary tempData, ModelStateDictionary modelState,
string bagName = "default")
{
var errors = modelState.ToDictionary(
kvp => kvp.Key,
kvp => kvp.Value?.Errors.FirstOrDefault()?.ErrorMessage ?? ""
);
tempData.SetValidationErrors(errors, bagName);
}

/// <summary>
/// Retrieve and clear validation errors from TempData, supporting error bags.
/// </summary>
public static Dictionary<string, string> GetAndClearValidationErrors(this ITempDataDictionary tempData,
HttpRequest request)
{
var errors = new Dictionary<string, string>();

if (!tempData.ContainsKey("__ValidationErrors"))
return errors;

// Deserialize from JSON
Dictionary<string, Dictionary<string, string>> storedErrors;
if (tempData["__ValidationErrors"] is string jsonString && !string.IsNullOrEmpty(jsonString))
{
try
{
storedErrors = JsonSerializer.Deserialize<Dictionary<string, Dictionary<string, string>>>(jsonString) ??
new Dictionary<string, Dictionary<string, string>>();
}
catch (JsonException)
{
// If deserialization fails, return empty
return errors;
}
}
else
{
return errors;
}

// Check if there's a specific error bag in the request header
var errorBag = "default";
if (request.Headers.ContainsKey(InertiaHeader.ErrorBag))
{
errorBag = request.Headers[InertiaHeader.ErrorBag].ToString();
}

// If there's only the default bag and no specific bag requested, return the default bag directly
if (storedErrors.Count == 1 && storedErrors.ContainsKey("default") && errorBag == "default")
{
foreach (var kvp in storedErrors["default"])
{
errors[kvp.Key] = kvp.Value;
}
}

// If there are multiple bags or a specific bag is requested, return the named bag
else if (storedErrors.ContainsKey(errorBag))
{
foreach (var kvp in storedErrors[errorBag])
{
errors[kvp.Key] = kvp.Value;
}
}

// If no specific bag and multiple bags exist, return all bags
else if (errorBag == "default" && storedErrors.Count > 1)
{
// Return all error bags as nested structure
// This will be handled differently but for now just return default or first available
var firstBag = storedErrors.Values.FirstOrDefault();
if (firstBag != null)
{
foreach (var kvp in firstBag)
{
errors[kvp.Key] = kvp.Value;
}
}
}

// Clear the temp data after reading (one-time use)
tempData.Remove("__ValidationErrors");

return errors;
}
}
2 changes: 2 additions & 0 deletions InertiaCore/Inertia.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ public static class Inertia

internal static void UseFactory(IResponseFactory factory) => _factory = factory;

internal static void ResetFactory() => _factory = default!;

public static Response Render(string component, object? props = null) => _factory.Render(component, props);

public static Task<IHtmlContent> Head(dynamic model) => _factory.Head(model);
Expand Down
110 changes: 110 additions & 0 deletions InertiaCore/Middleware.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
using System.Net;
using InertiaCore.Extensions;
using InertiaCore.Utils;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc.ViewFeatures;
using Microsoft.Extensions.DependencyInjection;

namespace InertiaCore;

public class Middleware
{
private readonly RequestDelegate _next;

public Middleware(RequestDelegate next)
{
_next = next;
}

public async Task InvokeAsync(HttpContext context)
{
if (context.IsInertiaRequest()
&& context.Request.Method == "GET"
&& context.Request.Headers[InertiaHeader.Version] != Inertia.GetVersion())
{
await OnVersionChange(context);
return;
}

await _next(context);

// Handle empty responses for Inertia requests
if (context.IsInertiaRequest()
&& context.Response.StatusCode == 200
&& await IsEmptyResponse(context))
{
await OnEmptyResponse(context);
}
}

private static async Task OnVersionChange(HttpContext context)
{
var tempData = context.RequestServices.GetRequiredService<ITempDataDictionaryFactory>()
.GetTempData(context);

if (tempData.Any()) tempData.Keep();

context.Response.Headers.Override(InertiaHeader.Location, context.RequestedUri());
context.Response.StatusCode = (int)HttpStatusCode.Conflict;

await context.Response.CompleteAsync();
}

private static async Task<bool> IsEmptyResponse(HttpContext context)
{
// Check if Content-Length is 0 or not set
if (context.Response.Headers.ContentLength.HasValue)
{
return context.Response.Headers.ContentLength.Value == 0;
}

// Check if response body is empty or only whitespace
if (context.Response.Body.CanSeek && context.Response.Body.Length >= 0)
{
var position = context.Response.Body.Position;

// Check if the stream has any content
if (context.Response.Body.Length == 0)
{
return true;
}

context.Response.Body.Seek(0, SeekOrigin.Begin);

using var reader = new StreamReader(context.Response.Body, leaveOpen: true);
var content = await reader.ReadToEndAsync();

context.Response.Body.Seek(position, SeekOrigin.Begin);

return string.IsNullOrWhiteSpace(content);
}

// For non-seekable streams, check if the response body position is still 0
// This indicates nothing has been written to the response
try
{
return context.Response.Body.Position == 0;
}
catch
{
// If we can't determine, assume it's not empty to be safe
return false;
}
}

private static async Task OnEmptyResponse(HttpContext context)
{
// Use Inertia.Back() to redirect back
var backResult = Inertia.Back();

// Determine the redirect URL using the same logic as BackResult
var referrer = context.Request.Headers.Referer.ToString();
var redirectUrl = !string.IsNullOrEmpty(referrer) ? referrer : "/";

// Set the appropriate headers and status code for a back redirect
context.Response.StatusCode = (int)HttpStatusCode.SeeOther;
context.Response.Headers.Override("Location", redirectUrl);

await context.Response.CompleteAsync();
}
}
Loading
Loading