Skip to content

Commit

Permalink
Add managed cancellation support
Browse files Browse the repository at this point in the history
Closes Tyrrrz#716
  • Loading branch information
Tyrrrz committed Oct 7, 2021
1 parent 2f3e165 commit 21d89af
Show file tree
Hide file tree
Showing 21 changed files with 274 additions and 147 deletions.
12 changes: 7 additions & 5 deletions DiscordChatExporter.Cli/Commands/Base/ExportCommandBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
using DiscordChatExporter.Core.Exporting.Filtering;
using DiscordChatExporter.Core.Exporting.Partitioning;
using DiscordChatExporter.Core.Utils.Extensions;
using Tyrrrz.Extensions;

namespace DiscordChatExporter.Cli.Commands.Base
{
Expand Down Expand Up @@ -56,6 +55,8 @@ public abstract class ExportCommandBase : TokenCommandBase

protected async ValueTask ExecuteAsync(IConsole console, IReadOnlyList<Channel> channels)
{
var cancellationToken = console.RegisterCancellationHandler();

if (ShouldReuseMedia && !ShouldDownloadMedia)
{
throw new CommandException("Option --reuse-media cannot be used without --media.");
Expand All @@ -73,7 +74,7 @@ await channels.ParallelForEachAsync(async channel =>
{
await progressContext.StartTaskAsync($"{channel.Category} / {channel.Name}", async progress =>
{
var guild = await Discord.GetGuildAsync(channel.GuildId);
var guild = await Discord.GetGuildAsync(channel.GuildId, cancellationToken);

var request = new ExportRequest(
guild,
Expand All @@ -89,14 +90,14 @@ await progressContext.StartTaskAsync($"{channel.Category} / {channel.Name}", asy
DateFormat
);

await Exporter.ExportChannelAsync(request, progress);
await Exporter.ExportChannelAsync(request, progress, cancellationToken);
});
}
catch (DiscordChatExporterException ex) when (!ex.IsFatal)
{
errors[channel] = ex.Message;
}
}, ParallelLimit.ClampMin(1));
}, Math.Max(ParallelLimit, 1), cancellationToken);
});

// Print result
Expand Down Expand Up @@ -140,11 +141,12 @@ await console.Output.WriteLineAsync(

protected async ValueTask ExecuteAsync(IConsole console, IReadOnlyList<Snowflake> channelIds)
{
var cancellationToken = console.RegisterCancellationHandler();
var channels = new List<Channel>();

foreach (var channelId in channelIds)
{
var channel = await Discord.GetChannelAsync(channelId);
var channel = await Discord.GetChannelAsync(channelId, cancellationToken);
channels.Add(channel);
}

Expand Down
5 changes: 3 additions & 2 deletions DiscordChatExporter.Cli/Commands/ExportAllCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@ public class ExportAllCommand : ExportCommandBase

public override async ValueTask ExecuteAsync(IConsole console)
{
var cancellationToken = console.RegisterCancellationHandler();
var channels = new List<Channel>();

await console.Output.WriteLineAsync("Fetching channels...");
await foreach (var guild in Discord.GetUserGuildsAsync())
await foreach (var guild in Discord.GetUserGuildsAsync(cancellationToken))
{
// Skip DMs if instructed to
if (!IncludeDirectMessages && guild.Id == Guild.DirectMessages.Id)
continue;

await foreach (var channel in Discord.GetGuildChannelsAsync(guild.Id))
await foreach (var channel in Discord.GetGuildChannelsAsync(guild.Id, cancellationToken))
{
// Skip non-text channels
if (!channel.IsTextChannel)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ public class ExportDirectMessagesCommand : ExportCommandBase
{
public override async ValueTask ExecuteAsync(IConsole console)
{
var cancellationToken = console.RegisterCancellationHandler();

await console.Output.WriteLineAsync("Fetching channels...");
var channels = await Discord.GetGuildChannelsAsync(Guild.DirectMessages.Id);
var channels = await Discord.GetGuildChannelsAsync(Guild.DirectMessages.Id, cancellationToken);
var textChannels = channels.Where(c => c.IsTextChannel).ToArray();

await base.ExecuteAsync(console, textChannels);
Expand Down
4 changes: 3 additions & 1 deletion DiscordChatExporter.Cli/Commands/ExportGuildCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ public class ExportGuildCommand : ExportCommandBase

public override async ValueTask ExecuteAsync(IConsole console)
{
var cancellationToken = console.RegisterCancellationHandler();

await console.Output.WriteLineAsync("Fetching channels...");
var channels = await Discord.GetGuildChannelsAsync(GuildId);
var channels = await Discord.GetGuildChannelsAsync(GuildId, cancellationToken);
var textChannels = channels.Where(c => c.IsTextChannel).ToArray();

await base.ExecuteAsync(console, textChannels);
Expand Down
4 changes: 3 additions & 1 deletion DiscordChatExporter.Cli/Commands/GetChannelsCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ public class GetChannelsCommand : TokenCommandBase

public override async ValueTask ExecuteAsync(IConsole console)
{
var channels = await Discord.GetGuildChannelsAsync(GuildId);
var cancellationToken = console.RegisterCancellationHandler();

var channels = await Discord.GetGuildChannelsAsync(GuildId, cancellationToken);

var textChannels = channels
.Where(c => c.IsTextChannel)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ public class GetDirectMessageChannelsCommand : TokenCommandBase
{
public override async ValueTask ExecuteAsync(IConsole console)
{
var channels = await Discord.GetGuildChannelsAsync(Guild.DirectMessages.Id);
var cancellationToken = console.RegisterCancellationHandler();

var channels = await Discord.GetGuildChannelsAsync(Guild.DirectMessages.Id, cancellationToken);

var textChannels = channels
.Where(c => c.IsTextChannel)
Expand Down
4 changes: 3 additions & 1 deletion DiscordChatExporter.Cli/Commands/GetGuildsCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ public class GetGuildsCommand : TokenCommandBase
{
public override async ValueTask ExecuteAsync(IConsole console)
{
var guilds = await Discord.GetUserGuildsAsync();
var cancellationToken = console.RegisterCancellationHandler();

var guilds = await Discord.GetUserGuildsAsync(cancellationToken);

foreach (var guild in guilds.OrderBy(g => g.Name))
{
Expand Down
94 changes: 63 additions & 31 deletions DiscordChatExporter.Core/Discord/DiscordClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using DiscordChatExporter.Core.Discord.Data;
using DiscordChatExporter.Core.Exceptions;
Expand All @@ -21,18 +23,28 @@ public class DiscordClient

public DiscordClient(AuthToken token) => _token = token;

private async ValueTask<HttpResponseMessage> GetResponseAsync(string url) =>
await Http.ResponsePolicy.ExecuteAsync(async () =>
private async ValueTask<HttpResponseMessage> GetResponseAsync(
string url,
CancellationToken cancellationToken = default)
{
return await Http.ResponsePolicy.ExecuteAsync(async innerCancellationToken =>
{
using var request = new HttpRequestMessage(HttpMethod.Get, new Uri(_baseUri, url));
request.Headers.Authorization = _token.GetAuthenticationHeader();

return await Http.Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
});
return await Http.Client.SendAsync(
request,
HttpCompletionOption.ResponseHeadersRead,
innerCancellationToken
);
}, cancellationToken);
}

private async ValueTask<JsonElement> GetJsonResponseAsync(string url)
private async ValueTask<JsonElement> GetJsonResponseAsync(
string url,
CancellationToken cancellationToken = default)
{
using var response = await GetResponseAsync(url);
using var response = await GetResponseAsync(url, cancellationToken);

if (!response.IsSuccessStatusCode)
{
Expand All @@ -45,19 +57,22 @@ private async ValueTask<JsonElement> GetJsonResponseAsync(string url)
};
}

return await response.Content.ReadAsJsonAsync();
return await response.Content.ReadAsJsonAsync(cancellationToken);
}

private async ValueTask<JsonElement?> TryGetJsonResponseAsync(string url)
private async ValueTask<JsonElement?> TryGetJsonResponseAsync(
string url,
CancellationToken cancellationToken = default)
{
using var response = await GetResponseAsync(url);
using var response = await GetResponseAsync(url, cancellationToken);

return response.IsSuccessStatusCode
? await response.Content.ReadAsJsonAsync()
? await response.Content.ReadAsJsonAsync(cancellationToken)
: null;
}

public async IAsyncEnumerable<Guild> GetUserGuildsAsync()
public async IAsyncEnumerable<Guild> GetUserGuildsAsync(
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
yield return Guild.DirectMessages;

Expand All @@ -71,7 +86,7 @@ public async IAsyncEnumerable<Guild> GetUserGuildsAsync()
.SetQueryParameter("after", currentAfter.ToString())
.Build();

var response = await GetJsonResponseAsync(url);
var response = await GetJsonResponseAsync(url, cancellationToken);

var isEmpty = true;
foreach (var guild in response.EnumerateArray().Select(Guild.Parse))
Expand All @@ -87,26 +102,30 @@ public async IAsyncEnumerable<Guild> GetUserGuildsAsync()
}
}

public async ValueTask<Guild> GetGuildAsync(Snowflake guildId)
public async ValueTask<Guild> GetGuildAsync(
Snowflake guildId,
CancellationToken cancellationToken = default)
{
if (guildId == Guild.DirectMessages.Id)
return Guild.DirectMessages;

var response = await GetJsonResponseAsync($"guilds/{guildId}");
var response = await GetJsonResponseAsync($"guilds/{guildId}", cancellationToken);
return Guild.Parse(response);
}

public async IAsyncEnumerable<Channel> GetGuildChannelsAsync(Snowflake guildId)
public async IAsyncEnumerable<Channel> GetGuildChannelsAsync(
Snowflake guildId,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (guildId == Guild.DirectMessages.Id)
{
var response = await GetJsonResponseAsync("users/@me/channels");
var response = await GetJsonResponseAsync("users/@me/channels", cancellationToken);
foreach (var channelJson in response.EnumerateArray())
yield return Channel.Parse(channelJson);
}
else
{
var response = await GetJsonResponseAsync($"guilds/{guildId}/channels");
var response = await GetJsonResponseAsync($"guilds/{guildId}/channels", cancellationToken);

var responseOrdered = response
.EnumerateArray()
Expand Down Expand Up @@ -138,31 +157,38 @@ public async IAsyncEnumerable<Channel> GetGuildChannelsAsync(Snowflake guildId)
}
}

public async IAsyncEnumerable<Role> GetGuildRolesAsync(Snowflake guildId)
public async IAsyncEnumerable<Role> GetGuildRolesAsync(
Snowflake guildId,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (guildId == Guild.DirectMessages.Id)
yield break;

var response = await GetJsonResponseAsync($"guilds/{guildId}/roles");
var response = await GetJsonResponseAsync($"guilds/{guildId}/roles", cancellationToken);

foreach (var roleJson in response.EnumerateArray())
yield return Role.Parse(roleJson);
}

public async ValueTask<Member> GetGuildMemberAsync(Snowflake guildId, User user)
public async ValueTask<Member> GetGuildMemberAsync(
Snowflake guildId,
User user,
CancellationToken cancellationToken = default)
{
if (guildId == Guild.DirectMessages.Id)
return Member.CreateForUser(user);

var response = await TryGetJsonResponseAsync($"guilds/{guildId}/members/{user.Id}");
var response = await TryGetJsonResponseAsync($"guilds/{guildId}/members/{user.Id}", cancellationToken);
return response?.Pipe(Member.Parse) ?? Member.CreateForUser(user);
}

public async ValueTask<ChannelCategory> GetChannelCategoryAsync(Snowflake channelId)
public async ValueTask<ChannelCategory> GetChannelCategoryAsync(
Snowflake channelId,
CancellationToken cancellationToken = default)
{
try
{
var response = await GetJsonResponseAsync($"channels/{channelId}");
var response = await GetJsonResponseAsync($"channels/{channelId}", cancellationToken);
return ChannelCategory.Parse(response);
}
// In some cases, the Discord API returns an empty body when requesting channel category.
Expand All @@ -173,42 +199,48 @@ public async ValueTask<ChannelCategory> GetChannelCategoryAsync(Snowflake channe
}
}

public async ValueTask<Channel> GetChannelAsync(Snowflake channelId)
public async ValueTask<Channel> GetChannelAsync(
Snowflake channelId,
CancellationToken cancellationToken = default)
{
var response = await GetJsonResponseAsync($"channels/{channelId}");
var response = await GetJsonResponseAsync($"channels/{channelId}", cancellationToken);

var parentId = response.GetPropertyOrNull("parent_id")?.GetString().Pipe(Snowflake.Parse);

var category = parentId is not null
? await GetChannelCategoryAsync(parentId.Value)
? await GetChannelCategoryAsync(parentId.Value, cancellationToken)
: null;

return Channel.Parse(response, category);
}

private async ValueTask<Message?> TryGetLastMessageAsync(Snowflake channelId, Snowflake? before = null)
private async ValueTask<Message?> TryGetLastMessageAsync(
Snowflake channelId,
Snowflake? before = null,
CancellationToken cancellationToken = default)
{
var url = new UrlBuilder()
.SetPath($"channels/{channelId}/messages")
.SetQueryParameter("limit", "1")
.SetQueryParameter("before", before?.ToString())
.Build();

var response = await GetJsonResponseAsync(url);
var response = await GetJsonResponseAsync(url, cancellationToken);
return response.EnumerateArray().Select(Message.Parse).LastOrDefault();
}

public async IAsyncEnumerable<Message> GetMessagesAsync(
Snowflake channelId,
Snowflake? after = null,
Snowflake? before = null,
IProgress<double>? progress = null)
IProgress<double>? progress = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
// Get the last message in the specified range.
// This snapshots the boundaries, which means that messages posted after the export started
// will not appear in the output.
// Additionally, it provides the date of the last message, which is used to calculate progress.
var lastMessage = await TryGetLastMessageAsync(channelId, before);
var lastMessage = await TryGetLastMessageAsync(channelId, before, cancellationToken);
if (lastMessage is null || lastMessage.Timestamp < after?.ToDate())
yield break;

Expand All @@ -224,7 +256,7 @@ public async IAsyncEnumerable<Message> GetMessagesAsync(
.SetQueryParameter("after", currentAfter.ToString())
.Build();

var response = await GetJsonResponseAsync(url);
var response = await GetJsonResponseAsync(url, cancellationToken);

var messages = response
.EnumerateArray()
Expand Down
2 changes: 1 addition & 1 deletion DiscordChatExporter.Core/DiscordChatExporter.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<ItemGroup>
<PackageReference Include="JsonExtensions" Version="1.1.0" />
<PackageReference Include="MiniRazor.CodeGen" Version="2.1.4" />
<PackageReference Include="MiniRazor.CodeGen" Version="2.2.0" />
<PackageReference Include="Polly" Version="7.2.2" />
<PackageReference Include="Superpower" Version="3.0.0" />
<PackageReference Include="Tyrrrz.Extensions" Version="1.6.5" />
Expand Down
Loading

0 comments on commit 21d89af

Please sign in to comment.