Skip to content

Commit

Permalink
Yield after channel send and move cpu tasks to thread
Browse files Browse the repository at this point in the history
  • Loading branch information
konstin committed Jan 30, 2024
1 parent 3f5e730 commit 55ea90a
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 63 deletions.
12 changes: 8 additions & 4 deletions crates/puffin-client/src/cached_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ impl CachedClient {
/// client.
#[instrument(skip_all)]
pub async fn get_cached_with_callback<
Payload: Serialize + DeserializeOwned + Send,
Payload: Serialize + DeserializeOwned + Send + 'static,
CallBackError,
Callback,
CallbackReturn,
Expand Down Expand Up @@ -172,7 +172,7 @@ impl CachedClient {
}
}

async fn read_cache<Payload: Serialize + DeserializeOwned + Send>(
async fn read_cache<Payload: Serialize + DeserializeOwned + Send + 'static>(
cache_entry: &CacheEntry,
) -> Option<DataWithCachePolicy<Payload>> {
let read_span = info_span!("read_cache", file = %cache_entry.path().display());
Expand All @@ -185,8 +185,12 @@ impl CachedClient {
"parse_cache",
path = %cache_entry.path().display()
);
let parse_result = parse_span
.in_scope(|| rmp_serde::from_slice::<DataWithCachePolicy<Payload>>(&cached));
let parse_result = tokio::task::spawn_blocking(move || {
parse_span
.in_scope(|| rmp_serde::from_slice::<DataWithCachePolicy<Payload>>(&cached))
})
.await
.expect("Tokio executor failed, was there a panic?");
match parse_result {
Ok(data) => Some(data),
Err(err) => {
Expand Down
30 changes: 19 additions & 11 deletions crates/puffin-resolver/src/resolver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use std::sync::Arc;

use anyhow::Result;
use dashmap::{DashMap, DashSet};
use futures::channel::mpsc::UnboundedReceiver;
use futures::channel::mpsc::{
UnboundedReceiver as MpscUnboundedReceiver, UnboundedSender as MpscUnboundedSender,
};
use futures::{FutureExt, StreamExt};
use itertools::Itertools;
use pubgrub::error::PubGrubError;
Expand Down Expand Up @@ -241,7 +243,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
#[instrument(skip_all)]
async fn solve(
&self,
request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
request_sink: &MpscUnboundedSender<Request>,
) -> Result<ResolutionGraph, ResolveError> {
let root = PubGrubPackage::Root(self.project.clone());

Expand Down Expand Up @@ -386,7 +388,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
&self,
package: &PubGrubPackage,
priorities: &mut PubGrubPriorities,
request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
request_sink: &MpscUnboundedSender<Request>,
) -> Result<(), ResolveError> {
match package {
PubGrubPackage::Root(_) => {}
Expand All @@ -413,14 +415,16 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
}
}
}
// Yield after sending on a channel to allow the subscribers to continue
tokio::task::yield_now().await;
Ok(())
}

/// Visit the set of [`PubGrubPackage`] candidates prior to selection. This allows us to fetch
/// metadata for all of the packages in parallel.
fn pre_visit<'data>(
packages: impl Iterator<Item = (&'data PubGrubPackage, &'data Range<Version>)>,
request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
request_sink: &MpscUnboundedSender<Request>,
) -> Result<(), ResolveError> {
// Iterate over the potential packages, and fetch file metadata for any of them. These
// represent our current best guesses for the versions that we _might_ select.
Expand All @@ -441,9 +445,9 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
package: &PubGrubPackage,
range: &Range<Version>,
pins: &mut FilePins,
request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
request_sink: &MpscUnboundedSender<Request>,
) -> Result<Option<Version>, ResolveError> {
return match package {
match package {
PubGrubPackage::Root(_) => Ok(Some(MIN_VERSION.clone())),

PubGrubPackage::Python(PubGrubPython::Installed) => {
Expand Down Expand Up @@ -584,16 +588,17 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {

Ok(Some(version))
}
};
}
}

/// Given a candidate package and version, return its dependencies.
#[instrument(skip_all, fields(%package, %version))]
async fn get_dependencies(
&self,
package: &PubGrubPackage,
version: &Version,
priorities: &mut PubGrubPriorities,
request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
request_sink: &MpscUnboundedSender<Request>,
) -> Result<Dependencies, ResolveError> {
match package {
PubGrubPackage::Root(_) => {
Expand Down Expand Up @@ -724,7 +729,10 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
}

/// Fetch the metadata for a stream of packages and versions.
async fn fetch(&self, request_stream: UnboundedReceiver<Request>) -> Result<(), ResolveError> {
async fn fetch(
&self,
request_stream: MpscUnboundedReceiver<Request>,
) -> Result<(), ResolveError> {
let mut response_stream = request_stream
.map(|request| self.process_request(request).boxed())
.buffer_unordered(50);
Expand Down Expand Up @@ -915,10 +923,10 @@ impl Display for Request {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Request::Package(package_name) => {
write!(f, "Package {package_name}")
write!(f, "Versions {package_name}")
}
Request::Dist(dist) => {
write!(f, "Dist {dist}")
write!(f, "Metadata {dist}")
}
Request::Prefetch(package_name, range) => {
write!(f, "Prefetch {package_name} {range}")
Expand Down
115 changes: 68 additions & 47 deletions crates/puffin-resolver/src/resolver/provider.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::future::Future;
use std::ops::Deref;
use std::sync::Arc;

use anyhow::Result;
use chrono::{DateTime, Utc};
use futures::FutureExt;
use url::Url;

use distribution_types::Dist;
Expand Down Expand Up @@ -45,17 +46,30 @@ pub trait ResolverProvider: Send + Sync {
/// The main IO backend for the resolver, which does cached requests network requests using the
/// [`RegistryClient`] and [`DistributionDatabase`].
pub struct DefaultResolverProvider<'a, Context: BuildContext + Send + Sync> {
/// The [`RegistryClient`] used to query the index.
client: &'a RegistryClient,
/// The [`DistributionDatabase`] used to build source distributions.
fetcher: DistributionDatabase<'a, Context>,
/// Allow moving the parameters to `VersionMap::from_metadata` to a different thread.
inner: Arc<DefaultResolverProviderInner>,
}

pub struct DefaultResolverProviderInner {
/// The [`RegistryClient`] used to query the index.
client: RegistryClient,
/// These are the entries from `--find-links` that act as overrides for index responses.
flat_index: &'a FlatIndex,
tags: &'a Tags,
flat_index: FlatIndex,
tags: Tags,
python_requirement: PythonRequirement,
exclude_newer: Option<DateTime<Utc>>,
allowed_yanks: AllowedYanks,
no_binary: &'a NoBinary,
no_binary: NoBinary,
}

impl<'a, Context: BuildContext + Send + Sync> Deref for DefaultResolverProvider<'a, Context> {
type Target = DefaultResolverProviderInner;

fn deref(&self) -> &Self::Target {
self.inner.as_ref()
}
}

impl<'a, Context: BuildContext + Send + Sync> DefaultResolverProvider<'a, Context> {
Expand All @@ -72,58 +86,65 @@ impl<'a, Context: BuildContext + Send + Sync> DefaultResolverProvider<'a, Contex
no_binary: &'a NoBinary,
) -> Self {
Self {
client,
fetcher,
flat_index,
tags,
python_requirement,
exclude_newer,
allowed_yanks,
no_binary,
inner: Arc::new(DefaultResolverProviderInner {
client: client.clone(),
flat_index: flat_index.clone(),
tags: tags.clone(),
python_requirement,
exclude_newer,
allowed_yanks,
no_binary: no_binary.clone(),
}),
}
}
}

impl<'a, Context: BuildContext + Send + Sync> ResolverProvider
for DefaultResolverProvider<'a, Context>
{
fn get_version_map<'io>(
&'io self,
package_name: &'io PackageName,
) -> impl Future<Output = VersionMapResponse> + Send + 'io {
self.client
.simple(package_name)
.map(move |result| match result {
Ok((index, metadata)) => Ok(VersionMap::from_metadata(
metadata,
package_name,
&index,
self.tags,
&self.python_requirement,
&self.allowed_yanks,
self.exclude_newer.as_ref(),
self.flat_index.get(package_name).cloned(),
self.no_binary,
)),
Err(err) => match err.into_kind() {
kind @ (puffin_client::ErrorKind::PackageNotFound(_)
| puffin_client::ErrorKind::NoIndex(_)) => {
if let Some(flat_index) = self.flat_index.get(package_name).cloned() {
Ok(VersionMap::from(flat_index))
} else {
Err(kind.into())
}
/// Make a simple api request for the package and convert the result to a [`VersionMap`].
async fn get_version_map<'io>(&'io self, package_name: &'io PackageName) -> VersionMapResponse {
let result = self.client.simple(package_name).await;

// If the simple api request was successful, perform on the slow conversion to `VersionMap` on the tokio
// threadpool
match result {
Ok((index, metadata)) => {
let self_send = self.inner.clone();
let package_name_owned = package_name.clone();
Ok(tokio::task::spawn_blocking(move || {
VersionMap::from_metadata(
metadata,
&package_name_owned,
&index,
&self_send.tags,
&self_send.python_requirement,
&self_send.allowed_yanks,
self_send.exclude_newer.as_ref(),
self_send.flat_index.get(&package_name_owned).cloned(),
&self_send.no_binary,
)
})
.await
.expect("Tokio executor failed, was there a panic?"))
}
Err(err) => match err.into_kind() {
kind @ (puffin_client::ErrorKind::PackageNotFound(_)
| puffin_client::ErrorKind::NoIndex(_)) => {
if let Some(flat_index) = self.flat_index.get(package_name).cloned() {
Ok(VersionMap::from(flat_index))
} else {
Err(kind.into())
}
kind => Err(kind.into()),
},
})
}
kind => Err(kind.into()),
},
}
}

fn get_or_build_wheel_metadata<'io>(
&'io self,
dist: &'io Dist,
) -> impl Future<Output = WheelMetadataResponse> + Send + 'io {
self.fetcher.get_or_build_wheel_metadata(dist)
async fn get_or_build_wheel_metadata<'io>(&'io self, dist: &'io Dist) -> WheelMetadataResponse {
self.fetcher.get_or_build_wheel_metadata(dist).await
}

/// Set the [`puffin_distribution::Reporter`] to use for this installer.
Expand Down
2 changes: 1 addition & 1 deletion crates/puffin-traits/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ impl Display for BuildKind {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum NoBinary {
/// Allow installation of any wheel.
None,
Expand Down

0 comments on commit 55ea90a

Please sign in to comment.