Skip to content

Commit

Permalink
Change the return type of the solve function to a HashMap of `RepoD…
Browse files Browse the repository at this point in the history
…ataRecord` and list of features
  • Loading branch information
prsabahrami committed Jan 16, 2025
1 parent 0307d9e commit c34e4ee
Show file tree
Hide file tree
Showing 14 changed files with 69 additions and 58 deletions.
21 changes: 15 additions & 6 deletions crates/rattler-bin/src/commands/create.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
borrow::Cow,
collections::HashMap,
env,
future::IntoFuture,
path::PathBuf,
Expand Down Expand Up @@ -261,12 +262,17 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> {
// Next, use a solver to solve this specific problem. This provides us with all
// the operations we need to apply to our environment to bring it up to
// date.
let required_packages =
let required_packages_with_features =
wrap_in_progress("solving", move || match opt.solver.unwrap_or_default() {
Solver::Resolvo => resolvo::Solver.solve(solver_task),
Solver::LibSolv => libsolv_c::Solver.solve(solver_task),
})?;

let required_packages: Vec<RepoDataRecord> = required_packages_with_features
.clone()
.into_keys()
.collect();

if opt.dry_run {
// Construct a transaction to
let transaction = Transaction::from_current_and_desired(
Expand All @@ -278,7 +284,7 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> {
if transaction.operations.is_empty() {
println!("No operations necessary");
} else {
print_transaction(&transaction);
print_transaction(&transaction, required_packages_with_features);
}

return Ok(());
Expand Down Expand Up @@ -309,26 +315,29 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> {
console::style(console::Emoji("✔", "")).green(),
install_start.elapsed()
);
print_transaction(&result.transaction);
print_transaction(&result.transaction, required_packages_with_features);
}

Ok(())
}

/// Prints the operations of the transaction to the console.
fn print_transaction(transaction: &Transaction<PrefixRecord, RepoDataRecord>) {
fn print_transaction(
transaction: &Transaction<PrefixRecord, RepoDataRecord>,
features: HashMap<RepoDataRecord, Option<Vec<String>>>,
) {
let format_record = |r: &RepoDataRecord| {
let direct_url_print = if let Some(channel) = &r.channel {
channel.clone()
} else {
String::new()
};

if let Some(feature) = &r.selected_feature {
if let Some(Some(features)) = features.get(r) {
format!(
"{}[{}] {} {} {}",
r.package_record.name.as_normalized(),
feature,
features.join(", "),
r.package_record.version,
r.package_record.build,
direct_url_print,
Expand Down
1 change: 0 additions & 1 deletion crates/rattler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,5 @@ pub(crate) fn get_repodata_record(package_path: impl AsRef<std::path::Path>) ->
.to_string(),
url: url::Url::from_file_path(package_path).unwrap(),
channel: Some(String::from("test")),
selected_feature: None,
}
}
2 changes: 1 addition & 1 deletion crates/rattler_conda_types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub use repo_data::{
ChannelInfo, ConvertSubdirError, PackageRecord, RecordFromPath, RepoData,
ValidatePackageRecordsError,
};
pub use repo_data_record::RepoDataRecord;
pub use repo_data_record::{RepoDataRecord, SolverResult};
pub use run_export::RunExportKind;
pub use version::{
Component, ParseVersionError, ParseVersionErrorKind, StrictVersion, Version, VersionBumpError,
Expand Down
1 change: 0 additions & 1 deletion crates/rattler_conda_types/src/match_spec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,6 @@ mod tests {
file_name: String::from("mamba-1.0-py37_0"),
url: url::Url::parse("https://mamba.io/mamba-1.0-py37_0.conda").unwrap(),
channel: Some(String::from("mamba")),
selected_feature: None,
};
let package_record = repodata_record.clone().package_record;

Expand Down
1 change: 0 additions & 1 deletion crates/rattler_conda_types/src/repo_data/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ impl RepoData {
channel: Some(channel.base_url.as_str().to_string()),
package_record,
file_name: filename,
selected_feature: None,
});
}
records
Expand Down
17 changes: 6 additions & 11 deletions crates/rattler_conda_types/src/repo_data_record.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
//! Defines the `[RepoDataRecord]` struct.
use std::collections::HashMap;
use std::vec::Vec;

use crate::PackageRecord;
use serde::{Deserialize, Serialize};
use url::Url;
Expand All @@ -24,21 +27,13 @@ pub struct RepoDataRecord {
/// explicit about where the package came from.
/// TODO: Refactor this into `Source` which can be a "name", "channelurl", or "direct url".
pub channel: Option<String>,

/// The selected feature set for this package.
pub selected_feature: Option<String>,
}

impl RepoDataRecord {
/// Set the selected feature set for this package.
pub fn set_selected_feature(&mut self, selected_feature: String) -> &mut Self {
self.selected_feature = Some(selected_feature);
self
}
}

impl AsRef<PackageRecord> for RepoDataRecord {
fn as_ref(&self) -> &PackageRecord {
&self.package_record
}
}

/// Type alias for the solver result containing records and their features
pub type SolverResult = HashMap<RepoDataRecord, Option<Vec<String>>>;
1 change: 0 additions & 1 deletion crates/rattler_lock/src/conda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ impl TryFrom<CondaBinaryData> for RepoDataRecord {
file_name: value.file_name,
url: value.location.try_into_url()?,
channel: value.channel.map(|channel| channel.to_string()),
selected_feature: None,
})
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ impl DirectUrlQuery {
file_name: self.url.clone().to_string(),
url: self.url.clone(),
channel: None,
selected_feature: None,
}]))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ async fn parse_records<R: AsRef<[u8]> + Send + 'static>(
channel: Some(channel_base_url.url().clone().redact().to_string()),
package_record,
file_name,
selected_feature: None,
})
.collect())
})
Expand Down
1 change: 0 additions & 1 deletion crates/rattler_repodata_gateway/src/sparse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ fn parse_records<'i>(
channel: Some(channel_name.url().clone().redact().to_string()),
package_record,
file_name: key.filename.to_owned(),
selected_feature: None,
});
}

Expand Down
4 changes: 2 additions & 2 deletions crates/rattler_solve/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub mod resolvo;
use std::fmt;

use chrono::{DateTime, Utc};
use rattler_conda_types::{GenericVirtualPackage, MatchSpec, RepoDataRecord};
use rattler_conda_types::{GenericVirtualPackage, MatchSpec, RepoDataRecord, SolverResult};

/// Represents a solver implementation, capable of solving [`SolverTask`]s
pub trait SolverImpl {
Expand All @@ -28,7 +28,7 @@ pub trait SolverImpl {
>(
&mut self,
task: SolverTask<TAvailablePackagesIterator>,
) -> Result<Vec<RepoDataRecord>, SolveError>;
) -> Result<SolverResult, SolveError>;
}

/// Represents an error when solving the dependencies for a given environment
Expand Down
9 changes: 6 additions & 3 deletions crates/rattler_solve/src/libsolv_c/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub use input::cache_repodata;
use input::{add_repodata_records, add_solv_file, add_virtual_packages};
pub use libc_byte_slice::LibcByteSlice;
use output::get_required_packages;
use rattler_conda_types::{MatchSpec, NamelessMatchSpec, RepoDataRecord};
use rattler_conda_types::{MatchSpec, NamelessMatchSpec, RepoDataRecord, SolverResult};
use wrapper::{
flags::SolverFlag,
pool::{Pool, Verbosity},
Expand Down Expand Up @@ -94,7 +94,7 @@ impl super::SolverImpl for Solver {
>(
&mut self,
task: SolverTask<TAvailablePackagesIterator>,
) -> Result<Vec<RepoDataRecord>, SolveError> {
) -> Result<SolverResult, SolveError> {
if task.timeout.is_some() {
return Err(SolveError::UnsupportedOperations(vec![
"timeout".to_string()
Expand Down Expand Up @@ -279,7 +279,10 @@ impl super::SolverImpl for Solver {
)
})?;

Ok(required_records)
Ok(required_records
.into_iter()
.map(|rec| (rec, None))
.collect())
}
}

Expand Down
38 changes: 21 additions & 17 deletions crates/rattler_solve/src/resolvo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use itertools::Itertools;
use rattler_conda_types::{
package::ArchiveType, version_spec::EqualityOperator, BuildNumberSpec, GenericVirtualPackage,
MatchSpec, Matches, NamelessMatchSpec, OrdOperator, PackageName, PackageRecord,
ParseMatchSpecError, ParseStrictness, RepoDataRecord, StringMatcher, VersionSpec,
ParseMatchSpecError, ParseStrictness, RepoDataRecord, SolverResult, StringMatcher, VersionSpec,
};
use resolvo::{
utils::{Pool, VersionSet},
Expand Down Expand Up @@ -812,7 +812,7 @@ impl super::SolverImpl for Solver {
>(
&mut self,
task: SolverTask<TAvailablePackagesIterator>,
) -> Result<Vec<RepoDataRecord>, SolveError> {
) -> Result<SolverResult, SolveError> {
let stop_time = task
.timeout
.map(|timeout| std::time::SystemTime::now() + timeout);
Expand Down Expand Up @@ -902,22 +902,26 @@ impl super::SolverImpl for Solver {
})?;

// Get the resulting packages from the solver.
let required_records = solvables
.into_iter()
.filter_map(
|id| match &solver.provider().pool.resolve_solvable(id).record {
SolverPackageRecord::Record(rec) => Some((*rec).clone()),
SolverPackageRecord::RecordWithFeature(rec, feature) => {
let mut cloned = (*rec).clone();
cloned.set_selected_feature(feature.to_string());
Some(cloned)
}
SolverPackageRecord::VirtualPackage(_) => None,
},
)
.collect();
let mut record_features: HashMap<RepoDataRecord, Option<Vec<String>>> = HashMap::new();

for id in solvables {
match &solver.provider().pool.resolve_solvable(id).record {
SolverPackageRecord::Record(rec) => {
record_features.entry((*rec).clone()).or_insert(None);
}
SolverPackageRecord::RecordWithFeature(rec, feature) => {
let rec = (*rec).clone();
record_features
.entry(rec)
.or_insert_with(|| Some(Vec::new()))
.get_or_insert_with(Vec::new)
.push(feature.clone());
}
SolverPackageRecord::VirtualPackage(_) => {}
}
}

Ok(required_records)
Ok(record_features)
}
}

Expand Down
29 changes: 18 additions & 11 deletions crates/rattler_solve/tests/backends.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ fn installed_package(
python_site_packages_path: None,
run_exports: None,
},
selected_feature: None,
}
}

Expand All @@ -143,7 +142,7 @@ fn solve_real_world<T: SolverImpl + Default>(specs: Vec<&str>) -> Vec<String> {
};

let pkgs1 = match T::default().solve(solver_task) {
Ok(result) => result,
Ok(result) => result.into_keys().collect(),
Err(e) => panic!("{e}"),
};

Expand Down Expand Up @@ -655,7 +654,7 @@ mod libsolv_c {

let specs: Vec<MatchSpec> = vec!["foo<4".parse().unwrap()];

let pkgs = rattler_solve::libsolv_c::Solver
let pkgs: Vec<RepoDataRecord> = rattler_solve::libsolv_c::Solver

Check failure on line 657 in crates/rattler_solve/tests/backends.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

cannot find type `RepoDataRecord` in this scope
.solve(SolverTask {
locked_packages: Vec::new(),
virtual_packages: Vec::new(),
Expand All @@ -668,7 +667,9 @@ mod libsolv_c {
exclude_newer: None,
strategy: SolveStrategy::default(),
})
.unwrap();
.unwrap()
.into_keys()
.collect();

if pkgs.is_empty() {
println!("No packages in the environment!");
Expand Down Expand Up @@ -885,7 +886,6 @@ mod resolvo {
file_name: url_str.to_string(),
url: url.clone(),
channel: None,
selected_feature: None,
}];

// Completely clean solver task, except for the specs and RepoData
Expand All @@ -900,7 +900,11 @@ mod resolvo {
..SolverTask::from_iter([&repo_data])
};

let pkgs = rattler_solve::resolvo::Solver.solve(task).unwrap();
let pkgs: Vec<RepoDataRecord> = rattler_solve::resolvo::Solver
.solve(task)
.unwrap()
.into_keys()
.collect();

assert_eq!(pkgs.len(), 1);
assert_eq!(pkgs[0].package_record.name.as_normalized(), "_libgcc_mutex");
Expand All @@ -919,7 +923,6 @@ mod resolvo {
file_name: url_str.to_string(),
url: Url::from_str("https://false.dont").unwrap(),
channel: None,
selected_feature: None,
}];

// Completely clean solver task, except for the specs and RepoData
Expand Down Expand Up @@ -1214,7 +1217,7 @@ fn solve<T: SolverImpl + Default>(
println!("No packages in the environment!");
}

Ok(pkgs)
Ok(pkgs.into_keys().collect())
}

#[derive(Default)]
Expand Down Expand Up @@ -1269,7 +1272,9 @@ fn compare_solve(task: CompareTask<'_>) {
exclude_newer: task.exclude_newer,
..SolverTask::from_iter(&available_packages)
})
.unwrap(),
.unwrap()
.into_keys()
.collect(),
),
));
let end_solve = Instant::now();
Expand All @@ -1288,7 +1293,9 @@ fn compare_solve(task: CompareTask<'_>) {
exclude_newer: task.exclude_newer,
..SolverTask::from_iter(&available_packages)
})
.unwrap(),
.unwrap()
.into_keys()
.collect(),
),
));
let end_solve = Instant::now();
Expand Down Expand Up @@ -1372,7 +1379,7 @@ fn solve_to_get_channel_of_spec<T: SolverImpl + Default>(
..SolverTask::from_iter(&available_packages)
};

let result = T::default().solve(task).unwrap();
let result: Vec<RepoDataRecord> = T::default().solve(task).unwrap().into_keys().collect();

let record = result.iter().find(|record| {
record.package_record.name.as_normalized() == spec.name.as_ref().unwrap().as_normalized()
Expand Down

0 comments on commit c34e4ee

Please sign in to comment.