diff --git a/src/Web/Program.cs b/src/Web/Program.cs index d8050d10f..16e215d5a 100644 --- a/src/Web/Program.cs +++ b/src/Web/Program.cs @@ -121,3 +121,5 @@ } app.Run(); + +public partial class Program { } \ No newline at end of file diff --git a/tests/Hippo.FunctionalTests/TestBase.cs b/tests/Hippo.FunctionalTests/TestBase.cs index dcffc681c..c32707cc0 100644 --- a/tests/Hippo.FunctionalTests/TestBase.cs +++ b/tests/Hippo.FunctionalTests/TestBase.cs @@ -1,6 +1,7 @@ using System; using System.IO; using System.Linq; +using System.Net.Http; using System.Threading.Tasks; using FluentValidation.AspNetCore; using Hippo.Application; @@ -13,6 +14,7 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Testing; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; @@ -24,61 +26,38 @@ namespace Hippo.FunctionalTests; public class TestBase : IDisposable { private static IConfigurationRoot _configuration = null!; - private static IServiceScopeFactory _scopeFactory = null!; + protected static WebApplicationFactory _factory = null!; private static Checkpoint _checkpoint = null!; private static string? _currentUserId; public TestBase() { - var builder = new ConfigurationBuilder() + var configBuilder = new ConfigurationBuilder() .SetBasePath(Directory.GetCurrentDirectory()) - .AddJsonFile("appsettings.json", true, true) - .AddEnvironmentVariables(); + .AddJsonFile("appsettings.json", true, true); - _configuration = builder.Build(); + _configuration = configBuilder.Build(); - var services = new ServiceCollection(); - - services.AddSingleton(Mock.Of(w => - w.EnvironmentName == "Development" && - w.ApplicationName == "Hippo.Web")); - - services.AddLogging(); - - services.AddApplication(); - services.AddInfrastructure(_configuration); - - services.AddDatabaseDeveloperPageExceptionFilter(); - - services.AddSingleton(); - - services.AddHttpContextAccessor(); - - services.AddHealthChecks() - .AddDbContextCheck(); - - services.AddControllersWithViews().AddFluentValidation(); - - services.AddRouting(options => options.LowercaseUrls = true); - - services.Configure(options => - options.SuppressModelStateInvalidFilter = true); - - // Replace service registration for ICurrentUserService - // Remove existing registration - var currentUserServiceDescriptor = services.FirstOrDefault(d => - d.ServiceType == typeof(ICurrentUserService)); - - if (currentUserServiceDescriptor != null) + _factory = new WebApplicationFactory() + .WithWebHostBuilder(builder => { - services.Remove(currentUserServiceDescriptor); - } - - // Register testing version - services.AddTransient(provider => - Mock.Of(s => s.UserId == _currentUserId)); - - _scopeFactory = services.BuildServiceProvider().GetRequiredService(); + builder.ConfigureServices(services => + { + // Replace service registration for ICurrentUserService + // Remove existing registration + var currentUserServiceDescriptor = services.FirstOrDefault(d => + d.ServiceType == typeof(ICurrentUserService)); + + if (currentUserServiceDescriptor != null) + { + services.Remove(currentUserServiceDescriptor); + } + + // Register testing version + services.AddTransient(provider => + Mock.Of(s => s.UserId == _currentUserId)); + }); + }); _checkpoint = new Checkpoint { @@ -90,7 +69,7 @@ public TestBase() private static void EnsureDatabase() { - using var scope = _scopeFactory.CreateScope(); + using var scope = _factory.Services.CreateScope(); var context = scope.ServiceProvider.GetRequiredService(); @@ -102,7 +81,7 @@ private static void EnsureDatabase() public static async Task SendAsync(IRequest request) { - using var scope = _scopeFactory.CreateScope(); + using var scope = _factory.Services.CreateScope(); var mediator = scope.ServiceProvider.GetRequiredService(); @@ -121,7 +100,7 @@ public static async Task RunAsAdministratorAsync() public static async Task RunAsUserAsync(string userName, string password, string[] roles) { - using var scope = _scopeFactory.CreateScope(); + using var scope = _factory.Services.CreateScope(); var userManager = scope.ServiceProvider.GetRequiredService>(); @@ -156,7 +135,7 @@ public static async Task RunAsUserAsync(string userName, string password public static async Task FindAsync(params object[] keyValues) where TEntity : class { - using var scope = _scopeFactory.CreateScope(); + using var scope = _factory.Services.CreateScope(); var context = scope.ServiceProvider.GetRequiredService(); @@ -166,7 +145,7 @@ public static async Task RunAsUserAsync(string userName, string password public static async Task AddAsync(TEntity entity) where TEntity : class { - using var scope = _scopeFactory.CreateScope(); + using var scope = _factory.Services.CreateScope(); var context = scope.ServiceProvider.GetRequiredService(); @@ -177,7 +156,7 @@ public static async Task AddAsync(TEntity entity) public static async Task CountAsync() where TEntity : class { - using var scope = _scopeFactory.CreateScope(); + using var scope = _factory.Services.CreateScope(); var context = scope.ServiceProvider.GetRequiredService(); diff --git a/tests/Hippo.FunctionalTests/Web/Controllers/AppControllerTests.cs b/tests/Hippo.FunctionalTests/Web/Controllers/AppControllerTests.cs new file mode 100644 index 000000000..02e5513b3 --- /dev/null +++ b/tests/Hippo.FunctionalTests/Web/Controllers/AppControllerTests.cs @@ -0,0 +1,29 @@ +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Threading.Tasks; +using Hippo.FunctionalTests; +using Hippo.Web.Controllers; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Testing; +using Xunit; + +public class AppControllerTests : TestBase +{ + [Fact] + public async Task IndexRequiresSignIn() + { + var client = _factory.CreateClient( + new WebApplicationFactoryClientOptions + { + AllowAutoRedirect = false + }); + + var response = await client.GetAsync("/app"); + + + Assert.Equal(HttpStatusCode.Redirect, response.StatusCode); + Assert.StartsWith("http://localhost/Account/Login", + response.Headers.Location?.OriginalString); + } +}