Skip to content

Commit

Permalink
Use IDbContextFactory instead of IServiceScopeFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
oliverbooth committed Mar 25, 2023
1 parent 32308d8 commit 4683a1a
Show file tree
Hide file tree
Showing 11 changed files with 83 additions and 146 deletions.
2 changes: 1 addition & 1 deletion Hammer/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ await Host.CreateDefaultBuilder(args)

services.AddHostedSingleton<LoggingService>();

services.AddDbContext<HammerContext>();
services.AddDbContextFactory<HammerContext>();
services.AddHostedSingleton<DatabaseService>();

services.AddSingleton<HttpClient>();
Expand Down
19 changes: 7 additions & 12 deletions Hammer/Services/BanService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using Humanizer;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using X10D.DSharpPlus;
using X10D.Text;
Expand All @@ -21,7 +20,7 @@ internal sealed class BanService : BackgroundService
{
private static readonly TimeSpan QueryInterval = TimeSpan.FromSeconds(30);
private readonly List<TemporaryBan> _temporaryBans = new();
private readonly IServiceScopeFactory _scopeFactory;
private readonly IDbContextFactory<HammerContext> _dbContextFactory;
private readonly DiscordClient _discordClient;
private readonly DiscordLogService _logService;
private readonly InfractionService _infractionService;
Expand All @@ -33,15 +32,15 @@ internal sealed class BanService : BackgroundService
/// Initializes a new instance of the <see cref="BanService" /> class.
/// </summary>
public BanService(
IServiceScopeFactory scopeFactory,
IDbContextFactory<HammerContext> dbContextFactory,
DiscordClient discordClient,
DiscordLogService logService,
InfractionService infractionService,
MailmanService mailmanService,
RuleService ruleService
)
{
_scopeFactory = scopeFactory;
_dbContextFactory = dbContextFactory;
_discordClient = discordClient;
_logService = logService;
_infractionService = infractionService;
Expand Down Expand Up @@ -70,8 +69,7 @@ public async Task<TemporaryBan> AddTemporaryBanAsync(TemporaryBan temporaryBan)
return existingBan;
}

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);

existingBan = await context.TemporaryBans.FindAsync(temporaryBan.UserId, temporaryBan.GuildId).ConfigureAwait(false);
if (existingBan is not null)
Expand Down Expand Up @@ -317,8 +315,7 @@ public async Task RevokeBanAsync(DiscordUser user, DiscordMember revoker, string
ArgumentNullException.ThrowIfNull(user);
ArgumentNullException.ThrowIfNull(revoker);

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);
TemporaryBan? temporaryBan =
await context.TemporaryBans.FirstOrDefaultAsync(b => b.UserId == user.Id && b.GuildId == revoker.Guild.Id)
.ConfigureAwait(false);
Expand Down Expand Up @@ -455,8 +452,7 @@ private async Task CreateTemporaryBanAsync(DiscordUser user, DiscordGuild guild,
{
var temporaryBan = TemporaryBan.Create(user, guild, expirationTime);

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);
EntityEntry<TemporaryBan> entry = await context.TemporaryBans.AddAsync(temporaryBan).ConfigureAwait(false);
await context.SaveChangesAsync().ConfigureAwait(false);

Expand Down Expand Up @@ -493,8 +489,7 @@ private async void TimerOnElapsed(object? sender, ElapsedEventArgs e)

private async Task UpdateFromDatabaseAsync()
{
await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);
lock (_temporaryBans)
{
_temporaryBans.Clear();
Expand Down
12 changes: 5 additions & 7 deletions Hammer/Services/DatabaseService.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using Hammer.Data;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using NLog;

Expand All @@ -12,15 +11,15 @@ namespace Hammer.Services;
internal sealed class DatabaseService : BackgroundService
{
private static readonly ILogger Logger = LogManager.GetCurrentClassLogger();
private readonly IServiceScopeFactory _scopeFactory;
private readonly IDbContextFactory<HammerContext> _dbContextFactory;

/// <summary>
/// Initializes a new instance of the <see cref="DatabaseService" /> class.
/// </summary>
/// <param name="scopeFactory">The scope factory.</param>
public DatabaseService(IServiceScopeFactory scopeFactory)
/// <param name="dbContextFactory">The DbContext factory.</param>
public DatabaseService(IDbContextFactory<HammerContext> dbContextFactory)
{
_scopeFactory = scopeFactory;
_dbContextFactory = dbContextFactory;
}

/// <inheritdoc />
Expand All @@ -33,8 +32,7 @@ private async Task CreateDatabaseAsync()
{
Directory.CreateDirectory("data");

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);

Logger.Info("Creating database...");
await context.Database.EnsureCreatedAsync().ConfigureAwait(false);
Expand Down
28 changes: 10 additions & 18 deletions Hammer/Services/InfractionService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
using Humanizer;
using Microsoft.Data.Sqlite;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using NLog;
using X10D.DSharpPlus;
Expand All @@ -29,7 +28,7 @@ internal sealed class InfractionService : BackgroundService
{
private static readonly ILogger Logger = LogManager.GetCurrentClassLogger();
private readonly Dictionary<ulong, List<Infraction>> _infractionCache = new();
private readonly IServiceScopeFactory _scopeFactory;
private readonly IDbContextFactory<HammerContext> _dbContextFactory;
private readonly DiscordClient _discordClient;
private readonly ConfigurationService _configurationService;
private readonly DiscordLogService _logService;
Expand All @@ -41,7 +40,7 @@ internal sealed class InfractionService : BackgroundService
/// Initializes a new instance of the <see cref="InfractionService" /> class.
/// </summary>
public InfractionService(
IServiceScopeFactory scopeFactory,
IDbContextFactory<HammerContext> dbContextFactory,
DiscordClient discordClient,
ConfigurationService configurationService,
DiscordLogService logService,
Expand All @@ -50,7 +49,7 @@ public InfractionService(
RuleService ruleService
)
{
_scopeFactory = scopeFactory;
_dbContextFactory = dbContextFactory;
_discordClient = discordClient;
_configurationService = configurationService;
_logService = logService;
Expand Down Expand Up @@ -84,8 +83,7 @@ public async Task<Infraction> AddInfractionAsync(Infraction infraction, DiscordG
throw new InvalidOperationException("The specified guild is invalid.");
}

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);

try
{
Expand Down Expand Up @@ -137,8 +135,7 @@ public async Task AddInfractionsAsync(IEnumerable<Infraction> infractions)
cache.Add(infraction);
}

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);
await context.AddRangeAsync(infractions).ConfigureAwait(false);
await context.SaveChangesAsync().ConfigureAwait(false);
}
Expand Down Expand Up @@ -680,8 +677,7 @@ public async Task ModifyInfractionAsync(Infraction infraction, Action<Infraction
ArgumentNullException.ThrowIfNull(infraction);
ArgumentNullException.ThrowIfNull(action);

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);
Infraction? existing = await context.Infractions.FindAsync(infraction.Id).ConfigureAwait(false);
if (existing is null) return;

Expand All @@ -702,8 +698,7 @@ public async Task ModifyInfractionAsync(Infraction infraction, Action<Infraction
/// </summary>
public async Task<int> PruneStaleInfractionsAsync()
{
await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);
var idCache = new Dictionary<ulong, bool>();
var pruneInfractions = new List<Infraction>();

Expand Down Expand Up @@ -753,8 +748,7 @@ public async Task RemoveInfractionAsync(Infraction infraction)
_cooldownService.StopCooldown(infraction.UserId);
_infractionCache[infraction.GuildId].Remove(infraction);

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);
context.Remove(infraction);
await context.SaveChangesAsync().ConfigureAwait(false);
}
Expand All @@ -768,8 +762,7 @@ public async Task RemoveInfractionsAsync(IEnumerable<Infraction> infractions)
{
ArgumentNullException.ThrowIfNull(infractions);

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);

foreach (IGrouping<ulong, Infraction> group in infractions.GroupBy(i => i.GuildId))
{
Expand Down Expand Up @@ -802,8 +795,7 @@ private async Task LoadGuildInfractions(DiscordGuild guild)

cache.Clear();

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);
cache.AddRange(context.Infractions.Where(i => i.GuildId == guild.Id));

Logger.Info($"Retrieved {cache.Count} infractions for {guild}");
Expand Down
43 changes: 15 additions & 28 deletions Hammer/Services/MemberNoteService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using Hammer.Resources;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using SmartFormat;
using X10D.DSharpPlus;
Expand All @@ -18,20 +17,20 @@ namespace Hammer.Services;
/// </summary>
internal sealed class MemberNoteService : BackgroundService
{
private readonly IServiceScopeFactory _scopeFactory;
private readonly IDbContextFactory<HammerContext> _dbContextFactory;
private readonly ConfigurationService _configurationService;
private readonly DiscordLogService _logService;

/// <summary>
/// Initializes a new instance of the <see cref="MemberNoteService" /> class.
/// </summary>
public MemberNoteService(
IServiceScopeFactory scopeFactory,
IDbContextFactory<HammerContext> dbContextFactory,
ConfigurationService configurationService,
DiscordLogService logService
)
{
_scopeFactory = scopeFactory;
_dbContextFactory = dbContextFactory;
_configurationService = configurationService;
_logService = logService;
}
Expand Down Expand Up @@ -75,10 +74,9 @@ public async Task<MemberNote> CreateNoteAsync(DiscordUser user, DiscordMember au
}

var note = new MemberNote(noteType, user, author, guild, trimmedContent);
await using (AsyncServiceScope scope = _scopeFactory.CreateAsyncScope())
{
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();

await using (HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false))
{
EntityEntry<MemberNote> result = await context.AddAsync(note).ConfigureAwait(false);
note = result.Entity;

Expand All @@ -104,8 +102,7 @@ public async Task<MemberNote> CreateNoteAsync(DiscordUser user, DiscordMember au
/// <exception cref="ArgumentException"><paramref name="id" /> refers to a non-existing note.</exception>
public async Task DeleteNoteAsync(long id)
{
await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);

MemberNote? note = await context.MemberNotes.FirstOrDefaultAsync(n => n.Id == id).ConfigureAwait(false);
if (note is null)
Expand All @@ -127,8 +124,7 @@ public async Task EditNoteAsync(long id, string? content = null, MemberNoteType?
if (string.IsNullOrWhiteSpace(content) && type is null)
return;

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);

MemberNote? note = await context.MemberNotes.FirstOrDefaultAsync(n => n.Id == id).ConfigureAwait(false);
if (note is null)
Expand All @@ -154,9 +150,7 @@ public async Task EditNoteAsync(long id, string? content = null, MemberNoteType?
/// </returns>
public async Task<MemberNote?> GetNoteAsync(long id)
{
await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();

await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);
return await context.MemberNotes.FirstOrDefaultAsync(n => n.Id == id).ConfigureAwait(false);
}

Expand All @@ -176,8 +170,7 @@ public async Task<int> GetNoteCountAsync(DiscordUser user, DiscordGuild guild)
ArgumentNullException.ThrowIfNull(user);
ArgumentNullException.ThrowIfNull(guild);

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);
return await context.MemberNotes.CountAsync(n => n.UserId == user.Id && n.GuildId == guild.Id).ConfigureAwait(false);
}

Expand Down Expand Up @@ -205,8 +198,7 @@ public async Task<int> GetNoteCountAsync(DiscordUser user, DiscordGuild guild, M
ArgumentNullException.ThrowIfNull(guild);
if (!Enum.IsDefined(type)) throw new ArgumentOutOfRangeException(nameof(type));

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);
return await context.MemberNotes.CountAsync(n => n.UserId == user.Id && n.GuildId == guild.Id && n.Type == type)
.ConfigureAwait(false);
}
Expand All @@ -224,8 +216,7 @@ public async IAsyncEnumerable<MemberNote> GetNotesAsync(DiscordGuild guild)
{
ArgumentNullException.ThrowIfNull(guild);

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);

foreach (MemberNote note in context.MemberNotes.Where(n => n.GuildId == guild.Id))
yield return note;
Expand All @@ -249,8 +240,7 @@ public async IAsyncEnumerable<MemberNote> GetNotesAsync(DiscordGuild guild, Memb
ArgumentNullException.ThrowIfNull(guild);
if (!Enum.IsDefined(type)) throw new ArgumentOutOfRangeException(nameof(type));

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);

foreach (MemberNote note in context.MemberNotes.Where(n => n.GuildId == guild.Id && n.Type == type))
yield return note;
Expand All @@ -275,8 +265,7 @@ public async IAsyncEnumerable<MemberNote> GetNotesAsync(DiscordUser user, Discor
ArgumentNullException.ThrowIfNull(user);
ArgumentNullException.ThrowIfNull(guild);

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);

foreach (MemberNote note in context.MemberNotes.Where(n => n.UserId == user.Id && n.GuildId == guild.Id))
yield return note;
Expand Down Expand Up @@ -306,8 +295,7 @@ public async IAsyncEnumerable<MemberNote> GetNotesAsync(DiscordUser user, Discor
ArgumentNullException.ThrowIfNull(guild);
if (!Enum.IsDefined(type)) throw new ArgumentOutOfRangeException(nameof(type));

await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);

foreach (MemberNote note in
context.MemberNotes.Where(n => n.UserId == user.Id && n.GuildId == guild.Id && n.Type == type))
Expand All @@ -317,8 +305,7 @@ public async IAsyncEnumerable<MemberNote> GetNotesAsync(DiscordUser user, Discor
/// <inheritdoc />
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
{
await using AsyncServiceScope scope = _scopeFactory.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<HammerContext>();
await using HammerContext context = await _dbContextFactory.CreateDbContextAsync(stoppingToken).ConfigureAwait(false);
await context.Database.EnsureCreatedAsync(stoppingToken).ConfigureAwait(false);
}
}
Loading

0 comments on commit 4683a1a

Please sign in to comment.