Yield after channel send and move cpu tasks to thread #1163

Merged: 3 commits, Feb 2, 2024
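This change replaces the resolver's unbounded `futures` channel with a bounded `tokio::sync::mpsc` channel (capacity 50, matching the `buffer_unordered` width in `fetch`), so `send(...).await` yields to the fetcher under backpressure and the manual `tokio::task::yield_now()` calls can go away. It also moves CPU-bound msgpack cache parsing in `puffin-client` onto `tokio::task::spawn_blocking`, keeping large deserializations from stalling the async reactor.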
1 change: 1 addition & 0 deletions Cargo.lock

(Generated file; diff not rendered.)

1 change: 1 addition & 0 deletions Cargo.toml
@@ -90,6 +90,7 @@ target-lexicon = { version = "0.12.13" }
 tempfile = { version = "3.9.0" }
 textwrap = { version = "0.15.2" }
 thiserror = { version = "1.0.56" }
+tokio-stream = { version = "0.1.14" }
 tl = { version = "0.7.7" }
 tokio = { version = "1.35.1", features = ["rt-multi-thread"] }
 tokio-tar = { version = "0.3.1" }
12 changes: 8 additions & 4 deletions crates/puffin-client/src/cached_client.rs
@@ -104,7 +104,7 @@ impl CachedClient {
     /// client.
     #[instrument(skip_all)]
     pub async fn get_cached_with_callback<
-        Payload: Serialize + DeserializeOwned + Send,
+        Payload: Serialize + DeserializeOwned + Send + 'static,
         CallBackError,
         Callback,
         CallbackReturn,
@@ -172,7 +172,7 @@ impl CachedClient {
         }
     }

-    async fn read_cache<Payload: Serialize + DeserializeOwned + Send>(
+    async fn read_cache<Payload: Serialize + DeserializeOwned + Send + 'static>(
         cache_entry: &CacheEntry,
     ) -> Option<DataWithCachePolicy<Payload>> {
         let read_span = info_span!("read_cache", file = %cache_entry.path().display());
@@ -185,8 +185,12 @@ impl CachedClient {
             "parse_cache",
             path = %cache_entry.path().display()
         );
-        let parse_result = parse_span
-            .in_scope(|| rmp_serde::from_slice::<DataWithCachePolicy<Payload>>(&cached));
+        let parse_result = tokio::task::spawn_blocking(move || {
+            parse_span
+                .in_scope(|| rmp_serde::from_slice::<DataWithCachePolicy<Payload>>(&cached))
+        })
+        .await
+        .expect("Tokio executor failed, was there a panic?");
         match parse_result {
             Ok(data) => Some(data),
             Err(err) => {
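The pattern above in isolation: a minimal standalone sketch (the function name and signature are illustrative, not the crate's API) of pushing CPU-bound msgpack deserialization onto tokio's blocking pool. Note that `spawn_blocking` requires an owned `'static` closure, which is why the `Payload` bounds above gained `+ 'static`.

use serde::de::DeserializeOwned;

/// Deserialize on a blocking thread so large cache entries don't stall
/// other in-flight tasks on the async reactor.
async fn parse_off_thread<T>(bytes: Vec<u8>) -> Result<T, rmp_serde::decode::Error>
where
    T: DeserializeOwned + Send + 'static,
{
    tokio::task::spawn_blocking(move || rmp_serde::from_slice::<T>(&bytes))
        .await
        // The JoinError case only occurs if the blocking task panicked.
        .expect("Tokio executor failed, was there a panic?")
}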
1 change: 1 addition & 0 deletions crates/puffin-resolver/Cargo.toml
@@ -54,6 +54,7 @@ sha2 = { workspace = true }
 tempfile = { workspace = true }
 thiserror = { workspace = true }
 tokio = { workspace = true, features = ["macros"] }
+tokio-stream = { workspace = true }
 tokio-util = { workspace = true, features = ["compat"] }
 tracing = { workspace = true }
 url = { workspace = true }
15 changes: 7 additions & 8 deletions crates/puffin-resolver/src/error.rs
@@ -24,14 +24,11 @@ pub enum ResolveError {
     #[error("Failed to find a version of {0} that satisfies the requirement")]
     NotFound(Requirement),

-    #[error("The request stream terminated unexpectedly")]
-    StreamTermination,
-
     #[error(transparent)]
     Client(#[from] puffin_client::Error),

-    #[error(transparent)]
-    TrySend(#[from] futures::channel::mpsc::SendError),
+    #[error("The channel is closed, was there a panic?")]
+    ChannelClosed,

     #[error(transparent)]
     Join(#[from] tokio::task::JoinError),
@@ -88,9 +85,11 @@ pub enum ResolveError {
     Failure(String),
 }

-impl<T> From<futures::channel::mpsc::TrySendError<T>> for ResolveError {
-    fn from(value: futures::channel::mpsc::TrySendError<T>) -> Self {
-        value.into_send_error().into()
+impl<T> From<tokio::sync::mpsc::error::SendError<T>> for ResolveError {
+    /// Drop the value we want to send to not leak the private type we're sending.
+    /// The tokio error only says "channel closed", so we don't lose information.
+    fn from(_value: tokio::sync::mpsc::error::SendError<T>) -> Self {
+        Self::ChannelClosed
     }
 }
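For context, a toy sketch (names are illustrative) of why the conversion discards the payload: tokio's `SendError<T>` hands the unsent value back to the caller, so a transparent `#[from]` variant would leak `T` (here, the crate-private `Request` type) into the public error enum.

use tokio::sync::mpsc;

#[derive(Debug)]
enum MyError {
    ChannelClosed,
}

impl<T> From<mpsc::error::SendError<T>> for MyError {
    fn from(_: mpsc::error::SendError<T>) -> Self {
        MyError::ChannelClosed
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel::<String>(1);
    drop(rx); // receiver gone, e.g. the fetch task exited early

    // `send` returns the unsent value inside the error...
    let err = tx.send("request".to_string()).await.unwrap_err();
    assert_eq!(err.0, "request");

    // ...and the conversion deliberately drops it.
    let e: MyError = err.into();
    println!("{e:?}"); // ChannelClosed
}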
63 changes: 30 additions & 33 deletions crates/puffin-resolver/src/resolver/mod.rs
@@ -5,7 +5,6 @@ use std::sync::Arc;
 use anyhow::Result;
 use dashmap::{DashMap, DashSet};
-use futures::channel::mpsc::UnboundedReceiver;
 use futures::{FutureExt, StreamExt};
 use itertools::Itertools;
 use pubgrub::error::PubGrubError;
@@ -14,6 +13,7 @@ use pubgrub::solver::{Incompatibility, State};
 use pubgrub::type_aliases::DependencyConstraints;
 use rustc_hash::{FxHashMap, FxHashSet};
 use tokio::select;
+use tokio_stream::wrappers::ReceiverStream;
 use tracing::{debug, info_span, instrument, trace, Instrument};
 use url::Url;

@@ -202,7 +202,8 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
     pub async fn resolve(self) -> Result<ResolutionGraph, ResolveError> {
         // A channel to fetch package metadata (e.g., given `flask`, fetch all versions) and version
         // metadata (e.g., given `flask==1.0.0`, fetch the metadata for that version).
-        let (request_sink, request_stream) = futures::channel::mpsc::unbounded();
+        // Channel size is set to the same size as the task buffer for simplicity.
+        let (request_sink, request_stream) = tokio::sync::mpsc::channel(50);

         // Run the fetcher.
         let requests_fut = self.fetch(request_stream).fuse();
@@ -213,7 +214,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
         let resolution = select! {
             result = requests_fut => {
                 result?;
-                return Err(ResolveError::StreamTermination);
+                return Err(ResolveError::ChannelClosed);
             }
             resolution = resolve_fut => {
                 resolution.map_err(|err| {
@@ -241,7 +242,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
     #[instrument(skip_all)]
     async fn solve(
         &self,
-        request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
+        request_sink: &tokio::sync::mpsc::Sender<Request>,
     ) -> Result<ResolutionGraph, ResolveError> {
         let root = PubGrubPackage::Root(self.project.clone());
@@ -265,7 +266,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
             state.unit_propagation(next)?;

             // Pre-visit all candidate packages, to allow metadata to be fetched in parallel.
-            Self::pre_visit(state.partial_solution.prioritized_packages(), request_sink)?;
+            Self::pre_visit(state.partial_solution.prioritized_packages(), request_sink).await?;

             // Choose a package version.
             let Some(highest_priority_pkg) =
@@ -386,7 +387,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
         &self,
         package: &PubGrubPackage,
         priorities: &mut PubGrubPriorities,
-        request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
+        request_sink: &tokio::sync::mpsc::Sender<Request>,
     ) -> Result<(), ResolveError> {
         match package {
             PubGrubPackage::Root(_) => {}
@@ -395,21 +396,17 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
                 // Emit a request to fetch the metadata for this package.
                 if self.index.packages.register(package_name.clone()) {
                     priorities.add(package_name.clone());
-                    request_sink.unbounded_send(Request::Package(package_name.clone()))?;
-
-                    // Yield to allow subscribers to continue, as the channel is sync.
-                    tokio::task::yield_now().await;
+                    request_sink
+                        .send(Request::Package(package_name.clone()))
+                        .await?;
                 }
             }
             PubGrubPackage::Package(package_name, _extra, Some(url)) => {
                 // Emit a request to fetch the metadata for this distribution.
                 let dist = Dist::from_url(package_name.clone(), url.clone())?;
                 if self.index.distributions.register(dist.package_id()) {
                     priorities.add(dist.name().clone());
-                    request_sink.unbounded_send(Request::Dist(dist))?;
-
-                    // Yield to allow subscribers to continue, as the channel is sync.
-                    tokio::task::yield_now().await;
+                    request_sink.send(Request::Dist(dist)).await?;
                 }
             }
         }
@@ -418,17 +415,19 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {

     /// Visit the set of [`PubGrubPackage`] candidates prior to selection. This allows us to fetch
     /// metadata for all of the packages in parallel.
-    fn pre_visit<'data>(
+    async fn pre_visit<'data>(
         packages: impl Iterator<Item = (&'data PubGrubPackage, &'data Range<Version>)>,
-        request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
+        request_sink: &tokio::sync::mpsc::Sender<Request>,
     ) -> Result<(), ResolveError> {
         // Iterate over the potential packages, and fetch file metadata for any of them. These
         // represent our current best guesses for the versions that we _might_ select.
         for (package, range) in packages {
             let PubGrubPackage::Package(package_name, _extra, None) = package else {
                 continue;
             };
-            request_sink.unbounded_send(Request::Prefetch(package_name.clone(), range.clone()))?;
+            request_sink
+                .send(Request::Prefetch(package_name.clone(), range.clone()))
+                .await?;
         }
         Ok(())
     }
@@ -441,9 +440,9 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
         package: &PubGrubPackage,
         range: &Range<Version>,
         pins: &mut FilePins,
-        request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
+        request_sink: &tokio::sync::mpsc::Sender<Request>,
     ) -> Result<Option<Version>, ResolveError> {
-        return match package {
+        match package {
             PubGrubPackage::Root(_) => Ok(Some(MIN_VERSION.clone())),

             PubGrubPackage::Python(PubGrubPython::Installed) => {
@@ -576,24 +575,22 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
                 // Emit a request to fetch the metadata for this version.
                 if self.index.distributions.register(candidate.package_id()) {
                     let dist = candidate.resolve().dist.clone();
-                    request_sink.unbounded_send(Request::Dist(dist))?;
-
-                    // Yield to allow subscribers to continue, as the channel is sync.
-                    tokio::task::yield_now().await;
+                    request_sink.send(Request::Dist(dist)).await?;
                 }

                 Ok(Some(version))
             }
-        };
+        }
     }

     /// Given a candidate package and version, return its dependencies.
     #[instrument(skip_all, fields(%package, %version))]
     async fn get_dependencies(
         &self,
         package: &PubGrubPackage,
         version: &Version,
         priorities: &mut PubGrubPriorities,
-        request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
+        request_sink: &tokio::sync::mpsc::Sender<Request>,
     ) -> Result<Dependencies, ResolveError> {
         match package {
             PubGrubPackage::Root(_) => {
@@ -724,8 +721,11 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
     }

     /// Fetch the metadata for a stream of packages and versions.
-    async fn fetch(&self, request_stream: UnboundedReceiver<Request>) -> Result<(), ResolveError> {
-        let mut response_stream = request_stream
+    async fn fetch(
+        &self,
+        request_stream: tokio::sync::mpsc::Receiver<Request>,
+    ) -> Result<(), ResolveError> {
+        let mut response_stream = ReceiverStream::new(request_stream)
             .map(|request| self.process_request(request).boxed())
             .buffer_unordered(50);

@@ -769,9 +769,6 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
                 }
                 None => {}
             }
-
-            // Yield to allow subscribers to continue, as the channel is sync.
-            tokio::task::yield_now().await;
         }

         Ok::<(), ResolveError>(())
@@ -902,7 +899,7 @@
 /// Fetch the metadata for an item
 #[derive(Debug)]
 #[allow(clippy::large_enum_variant)]
-enum Request {
+pub(crate) enum Request {
     /// A request to fetch the metadata for a package.
     Package(PackageName),
     /// A request to fetch the metadata for a built or source distribution.
@@ -915,10 +912,10 @@ impl Display for Request {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         match self {
             Request::Package(package_name) => {
-                write!(f, "Package {package_name}")
+                write!(f, "Versions {package_name}")
             }
             Request::Dist(dist) => {
-                write!(f, "Dist {dist}")
+                write!(f, "Metadata {dist}")
             }
             Request::Prefetch(package_name, range) => {
                 write!(f, "Prefetch {package_name} {range}")
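Putting the resolver-side changes together, a self-contained sketch (toy request type and workload, not the crate's API) of the new pipeline shape: a bounded `tokio` channel provides backpressure, `ReceiverStream` adapts the receiver into a `futures` `Stream`, and `buffer_unordered(50)` matches the channel capacity chosen above.

use futures::StreamExt;
use tokio_stream::wrappers::ReceiverStream;

#[tokio::main]
async fn main() {
    let (request_sink, request_stream) = tokio::sync::mpsc::channel::<u32>(50);

    let producer = tokio::spawn(async move {
        for request in 0..200u32 {
            // When the buffer is full, `send().await` parks this task until
            // the consumer catches up, replacing the manual
            // `tokio::task::yield_now()` the unbounded channel required.
            request_sink.send(request).await.expect("receiver dropped");
        }
    });

    // Drive up to 50 in-flight requests concurrently.
    let mut responses = ReceiverStream::new(request_stream)
        .map(|request| async move { request * 2 })
        .buffer_unordered(50);

    // The stream ends once the sender is dropped and the buffer drains.
    while let Some(_response) = responses.next().await {}
    producer.await.unwrap();
}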