diff --git a/crates/ruff_formatter/src/builders.rs b/crates/ruff_formatter/src/builders.rs index ba5acf913de733..fe5ce2afb1aaf4 100644 --- a/crates/ruff_formatter/src/builders.rs +++ b/crates/ruff_formatter/src/builders.rs @@ -9,9 +9,7 @@ use Tag::*; use crate::format_element::tag::{Condition, Tag}; use crate::prelude::tag::{DedentMode, GroupMode, LabelId}; use crate::prelude::*; -use crate::{ - format_element, write, Argument, Arguments, FormatContext, FormatOptions, GroupId, TextSize, -}; +use crate::{write, Argument, Arguments, FormatContext, FormatOptions, GroupId, TextSize}; use crate::{Buffer, VecBuffer}; /// A line break that only gets printed if the enclosing `Group` doesn't fit on a single line. @@ -2543,15 +2541,12 @@ impl Format for BestFitting<'_, Context> { fn fmt(&self, f: &mut Formatter) -> FormatResult<()> { let variants = self.variants.items(); - let mut formatted_variants = Vec::with_capacity(variants.len()); + let mut buffer = VecBuffer::with_capacity(variants.len() * 8, f.state_mut()); for variant in variants { - let mut buffer = VecBuffer::with_capacity(8, f.state_mut()); - buffer.write_element(FormatElement::Tag(StartEntry)); + buffer.write_element(FormatElement::Tag(StartBestFittingEntry)); buffer.write_fmt(Arguments::from(variant))?; - buffer.write_element(FormatElement::Tag(EndEntry)); - - formatted_variants.push(buffer.into_vec().into_boxed_slice()); + buffer.write_element(FormatElement::Tag(EndBestFittingEntry)); } // SAFETY: The constructor guarantees that there are always at least two variants. It's, therefore, @@ -2559,9 +2554,7 @@ impl Format for BestFitting<'_, Context> { #[allow(unsafe_code)] let element = unsafe { FormatElement::BestFitting { - variants: format_element::BestFittingVariants::from_vec_unchecked( - formatted_variants, - ), + variants: BestFittingVariants::from_vec_unchecked(buffer.into_vec()), mode: self.mode, } }; diff --git a/crates/ruff_formatter/src/format_element.rs b/crates/ruff_formatter/src/format_element.rs index 48bd7157ca5061..f9fe281df3fde6 100644 --- a/crates/ruff_formatter/src/format_element.rs +++ b/crates/ruff_formatter/src/format_element.rs @@ -3,6 +3,7 @@ pub mod tag; use std::borrow::Cow; use std::hash::{Hash, Hasher}; +use std::iter::FusedIterator; use std::num::NonZeroU32; use std::ops::Deref; use std::rc::Rc; @@ -67,6 +68,16 @@ pub enum FormatElement { Tag(Tag), } +impl FormatElement { + pub fn tag_kind(&self) -> Option { + if let FormatElement::Tag(tag) = self { + Some(tag.kind()) + } else { + None + } + } +} + impl std::fmt::Debug for FormatElement { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { match self { @@ -318,7 +329,7 @@ pub enum BestFittingMode { /// The first element is the one that takes up the most space horizontally (the most flat), /// The last element takes up the least space horizontally (but most horizontal space). #[derive(Clone, Eq, PartialEq, Debug)] -pub struct BestFittingVariants(Box<[Box<[FormatElement]>]>); +pub struct BestFittingVariants(Box<[FormatElement]>); impl BestFittingVariants { /// Creates a new best fitting IR with the given variants. The method itself isn't unsafe @@ -331,9 +342,13 @@ impl BestFittingVariants { /// The slice must contain at least two variants. #[doc(hidden)] #[allow(unsafe_code)] - pub unsafe fn from_vec_unchecked(variants: Vec>) -> Self { + pub unsafe fn from_vec_unchecked(variants: Vec) -> Self { debug_assert!( - variants.len() >= 2, + variants + .iter() + .filter(|element| matches!(element, FormatElement::Tag(Tag::StartBestFittingEntry))) + .count() + >= 2, "Requires at least the least expanded and most expanded variants" ); @@ -342,40 +357,85 @@ impl BestFittingVariants { /// Returns the most expanded variant pub fn most_expanded(&self) -> &[FormatElement] { - self.0.last().expect( - "Most contain at least two elements, as guaranteed by the best fitting builder.", - ) + self.into_iter().last().unwrap() } - pub fn as_slice(&self) -> &[Box<[FormatElement]>] { + pub fn as_slice(&self) -> &[FormatElement] { &self.0 } /// Returns the least expanded variant pub fn most_flat(&self) -> &[FormatElement] { - self.0.first().expect( - "Most contain at least two elements, as guaranteed by the best fitting builder.", - ) + self.into_iter().next().unwrap() } } impl Deref for BestFittingVariants { - type Target = [Box<[FormatElement]>]; + type Target = [FormatElement]; fn deref(&self) -> &Self::Target { self.as_slice() } } +pub struct BestFittingVariantsIter<'a> { + elements: &'a [FormatElement], +} + impl<'a> IntoIterator for &'a BestFittingVariants { - type Item = &'a Box<[FormatElement]>; - type IntoIter = std::slice::Iter<'a, Box<[FormatElement]>>; + type Item = &'a [FormatElement]; + type IntoIter = BestFittingVariantsIter<'a>; fn into_iter(self) -> Self::IntoIter { - self.as_slice().iter() + BestFittingVariantsIter { elements: &self.0 } + } +} + +impl<'a> Iterator for BestFittingVariantsIter<'a> { + type Item = &'a [FormatElement]; + + fn next(&mut self) -> Option { + match self.elements.first()? { + FormatElement::Tag(Tag::StartBestFittingEntry) => { + let end = self + .elements + .iter() + .position(|element| { + matches!(element, FormatElement::Tag(Tag::EndBestFittingEntry)) + }) + .map_or(self.elements.len(), |position| position + 1); + + let (variant, rest) = self.elements.split_at(end); + self.elements = rest; + + Some(variant) + } + _ => None, + } + } + + fn last(mut self) -> Option + where + Self: Sized, + { + self.next_back() } } +impl<'a> DoubleEndedIterator for BestFittingVariantsIter<'a> { + fn next_back(&mut self) -> Option { + let start_position = self.elements.iter().rposition(|element| { + matches!(element, FormatElement::Tag(Tag::StartBestFittingEntry)) + })?; + + let (rest, variant) = self.elements.split_at(start_position); + self.elements = rest; + Some(variant) + } +} + +impl FusedIterator for BestFittingVariantsIter<'_> {} + pub trait FormatElements { /// Returns true if this [`FormatElement`] is guaranteed to break across multiple lines by the printer. /// This is the case if this format element recursively contains a: diff --git a/crates/ruff_formatter/src/format_element/document.rs b/crates/ruff_formatter/src/format_element/document.rs index 9321d7a97ccfbb..87d8b9a1caa7cf 100644 --- a/crates/ruff_formatter/src/format_element/document.rs +++ b/crates/ruff_formatter/src/format_element/document.rs @@ -35,7 +35,10 @@ impl Document { enum Enclosing<'a> { Group(&'a tag::Group), ConditionalGroup(&'a tag::ConditionalGroup), - FitsExpanded(&'a tag::FitsExpanded), + FitsExpanded { + tag: &'a tag::FitsExpanded, + expands_before: bool, + }, BestFitting, } @@ -43,7 +46,7 @@ impl Document { match enclosing.last() { Some(Enclosing::Group(group)) => group.propagate_expand(), Some(Enclosing::ConditionalGroup(group)) => group.propagate_expand(), - Some(Enclosing::FitsExpanded(fits_expanded)) => fits_expanded.propagate_expand(), + Some(Enclosing::FitsExpanded { tag, .. }) => tag.propagate_expand(), _ => {} } } @@ -85,23 +88,24 @@ impl Document { FormatElement::BestFitting { variants, mode: _ } => { enclosing.push(Enclosing::BestFitting); - for variant in variants { - propagate_expands(variant, enclosing, checked_interned); - } - - // Best fitting acts as a boundary - expands = false; + propagate_expands(variants, enclosing, checked_interned); enclosing.pop(); continue; } FormatElement::Tag(Tag::StartFitsExpanded(fits_expanded)) => { - enclosing.push(Enclosing::FitsExpanded(fits_expanded)); + enclosing.push(Enclosing::FitsExpanded { + tag: fits_expanded, + expands_before: expands, + }); false } FormatElement::Tag(Tag::EndFitsExpanded) => { - enclosing.pop(); - // Fits expanded acts as a boundary - expands = false; + if let Some(Enclosing::FitsExpanded { expands_before, .. }) = + enclosing.pop() + { + expands = expands_before; + } + continue; } FormatElement::Text { @@ -338,14 +342,20 @@ impl Format> for &[FormatElement] { } FormatElement::BestFitting { variants, mode } => { - write!(f, [token("best_fitting([")])?; + write!(f, [token("best_fitting(")])?; + + if *mode != BestFittingMode::FirstLine { + write!(f, [text(&std::format!("mode: {mode:?}, "), None)])?; + } + + write!(f, [token("[")])?; f.write_elements([ FormatElement::Tag(StartIndent), FormatElement::Line(LineMode::Hard), ]); for variant in variants { - write!(f, [&**variant, hard_line_break()])?; + write!(f, [variant, hard_line_break()])?; } f.write_elements([ @@ -353,13 +363,7 @@ impl Format> for &[FormatElement] { FormatElement::Line(LineMode::Hard), ]); - write!(f, [token("]")])?; - - if *mode != BestFittingMode::FirstLine { - write!(f, [text(&std::format!(", mode: {mode:?}"), None),])?; - } - - write!(f, [token(")")])?; + write!(f, [token("])")])?; } FormatElement::Interned(interned) => { @@ -594,10 +598,10 @@ impl Format> for &[FormatElement] { } } - StartEntry => { + StartEntry | StartBestFittingEntry { .. } => { // handled after the match for all start tags } - EndEntry => write!(f, [ContentArrayEnd])?, + EndEntry | EndBestFittingEntry => write!(f, [ContentArrayEnd])?, EndFill | EndLabelled diff --git a/crates/ruff_formatter/src/format_element/tag.rs b/crates/ruff_formatter/src/format_element/tag.rs index da69013faa271d..fd29152961994e 100644 --- a/crates/ruff_formatter/src/format_element/tag.rs +++ b/crates/ruff_formatter/src/format_element/tag.rs @@ -83,6 +83,9 @@ pub enum Tag { StartFitsExpanded(FitsExpanded), EndFitsExpanded, + + StartBestFittingEntry, + EndBestFittingEntry, } impl Tag { @@ -103,6 +106,7 @@ impl Tag { | Tag::StartVerbatim(_) | Tag::StartLabelled(_) | Tag::StartFitsExpanded(_) + | Tag::StartBestFittingEntry, ) } @@ -129,6 +133,7 @@ impl Tag { StartVerbatim(_) | EndVerbatim => TagKind::Verbatim, StartLabelled(_) | EndLabelled => TagKind::Labelled, StartFitsExpanded { .. } | EndFitsExpanded => TagKind::FitsExpanded, + StartBestFittingEntry { .. } | EndBestFittingEntry => TagKind::BestFittingEntry, } } } @@ -152,6 +157,7 @@ pub enum TagKind { Verbatim, Labelled, FitsExpanded, + BestFittingEntry, } #[derive(Debug, Copy, Default, Clone, Eq, PartialEq)] diff --git a/crates/ruff_formatter/src/printer/mod.rs b/crates/ruff_formatter/src/printer/mod.rs index 3d8d58be62ce91..951ce5633c1274 100644 --- a/crates/ruff_formatter/src/printer/mod.rs +++ b/crates/ruff_formatter/src/printer/mod.rs @@ -288,7 +288,9 @@ impl<'a> Printer<'a> { stack.push(TagKind::FitsExpanded, args); } - FormatElement::Tag(tag @ (StartLabelled(_) | StartEntry)) => { + FormatElement::Tag( + tag @ (StartLabelled(_) | StartEntry | StartBestFittingEntry { .. }), + ) => { stack.push(tag.kind(), args); } @@ -305,6 +307,7 @@ impl<'a> Printer<'a> { | EndFitsExpanded | EndVerbatim | EndLineSuffix + | EndBestFittingEntry | EndFill), ) => { stack.pop(tag.kind())?; @@ -495,47 +498,64 @@ impl<'a> Printer<'a> { if args.mode().is_flat() && self.state.measured_group_fits { queue.extend_back(variants.most_flat()); - self.print_entry(queue, stack, args) + self.print_entry(queue, stack, args, TagKind::BestFittingEntry) } else { self.state.measured_group_fits = true; - let normal_variants = &variants[..variants.len() - 1]; + let mut variants_iter = variants.into_iter(); + let mut current = variants_iter.next().unwrap(); - for variant in normal_variants { + for next in variants_iter { // Test if this variant fits and if so, use it. Otherwise try the next // variant. // Try to fit only the first variant on a single line - if !matches!(variant.first(), Some(&FormatElement::Tag(Tag::StartEntry))) { - return invalid_start_tag(TagKind::Entry, variant.first()); + if !matches!( + current.first(), + Some(&FormatElement::Tag(Tag::StartBestFittingEntry)) + ) { + return invalid_start_tag(TagKind::BestFittingEntry, current.first()); } // Skip the first element because we want to override the args for the entry and the // args must be popped from the stack as soon as it sees the matching end entry. - let content = &variant[1..]; + let content = ¤t[1..]; let entry_args = args .with_print_mode(PrintMode::Flat) .with_measure_mode(MeasureMode::from(mode)); queue.extend_back(content); - stack.push(TagKind::Entry, entry_args); + stack.push(TagKind::BestFittingEntry, entry_args); let variant_fits = self.fits(queue, stack)?; - stack.pop(TagKind::Entry)?; + stack.pop(TagKind::BestFittingEntry)?; // Remove the content slice because printing needs the variant WITH the start entry let popped_slice = queue.pop_slice(); debug_assert_eq!(popped_slice, Some(content)); if variant_fits { - queue.extend_back(variant); - return self.print_entry(queue, stack, args.with_print_mode(PrintMode::Flat)); + queue.extend_back(current); + return self.print_entry( + queue, + stack, + args.with_print_mode(PrintMode::Flat), + TagKind::BestFittingEntry, + ); } + + current = next; } + // At this stage current is the most expanded. + // No variant fits, take the last (most expanded) as fallback - let most_expanded = variants.most_expanded(); - queue.extend_back(most_expanded); - self.print_entry(queue, stack, args.with_print_mode(PrintMode::Expanded)) + queue.extend_back(current); + self.print_entry( + queue, + stack, + args.with_print_mode(PrintMode::Expanded), + TagKind::BestFittingEntry, + ) } } @@ -686,7 +706,7 @@ impl<'a> Printer<'a> { stack: &mut PrintCallStack, args: PrintElementArgs, ) -> PrintResult<()> { - self.print_entry(queue, stack, args) + self.print_entry(queue, stack, args, TagKind::Entry) } /// Semantic alias for [`Self::print_entry`] for fill separators. @@ -696,7 +716,7 @@ impl<'a> Printer<'a> { stack: &mut PrintCallStack, args: PrintElementArgs, ) -> PrintResult<()> { - self.print_entry(queue, stack, args) + self.print_entry(queue, stack, args, TagKind::Entry) } /// Fully print an element (print the element itself and all its descendants) @@ -708,32 +728,31 @@ impl<'a> Printer<'a> { queue: &mut PrintQueue<'a>, stack: &mut PrintCallStack, args: PrintElementArgs, + kind: TagKind, ) -> PrintResult<()> { let start_entry = queue.top(); - if !matches!(start_entry, Some(&FormatElement::Tag(Tag::StartEntry))) { - return invalid_start_tag(TagKind::Entry, start_entry); + if queue + .pop() + .is_some_and(|start| start.tag_kind() == Some(kind)) + { + stack.push(kind, args); + } else { + return invalid_start_tag(kind, start_entry); } - let mut depth = 0; + let mut depth = 1u32; while let Some(element) = queue.pop() { match element { - FormatElement::Tag(Tag::StartEntry) => { - // Handle the start of the first element by pushing the args on the stack. - if depth == 0 { - depth = 1; - stack.push(TagKind::Entry, args); - continue; - } - + FormatElement::Tag(Tag::StartEntry | Tag::StartBestFittingEntry) => { depth += 1; } - FormatElement::Tag(Tag::EndEntry) => { + FormatElement::Tag(end_tag @ (Tag::EndEntry | Tag::EndBestFittingEntry)) => { depth -= 1; // Reached the end entry, pop the entry from the stack and return. if depth == 0 { - stack.pop(TagKind::Entry)?; + stack.pop(end_tag.kind())?; return Ok(()); } } @@ -745,7 +764,7 @@ impl<'a> Printer<'a> { self.print_element(stack, queue, element)?; } - invalid_end_tag(TagKind::Entry, stack.top_kind()) + invalid_end_tag(kind, stack.top_kind()) } fn print_char(&mut self, char: char) { @@ -1148,11 +1167,14 @@ impl<'a, 'print> FitsMeasurer<'a, 'print> { PrintMode::Expanded => (variants.most_expanded(), args), }; - if !matches!(slice.first(), Some(FormatElement::Tag(Tag::StartEntry))) { - return invalid_start_tag(TagKind::Entry, slice.first()); + if !matches!( + slice.first(), + Some(FormatElement::Tag(Tag::StartBestFittingEntry)) + ) { + return invalid_start_tag(TagKind::BestFittingEntry, slice.first()); } - self.stack.push(TagKind::Entry, args); + self.stack.push(TagKind::BestFittingEntry, args); self.queue.extend_back(&slice[1..]); } @@ -1277,7 +1299,11 @@ impl<'a, 'print> FitsMeasurer<'a, 'print> { } FormatElement::Tag( - tag @ (StartFill | StartVerbatim(_) | StartLabelled(_) | StartEntry), + tag @ (StartFill + | StartVerbatim(_) + | StartLabelled(_) + | StartEntry + | StartBestFittingEntry { .. }), ) => { self.stack.push(tag.kind(), args); } @@ -1294,6 +1320,7 @@ impl<'a, 'print> FitsMeasurer<'a, 'print> { | EndAlign | EndDedent | EndIndent + | EndBestFittingEntry | EndFitsExpanded), ) => { self.stack.pop(tag.kind())?;