Auto-language detection #24

Merged
merged 7 commits into from
Jun 11, 2023
1 change: 1 addition & 0 deletions Cargo.toml
@@ -16,6 +16,7 @@ rayon = "1.7.0"
regex = "1.8.2"
termcolor = "1.2.0"
tree-sitter = "0.20.10"
tree-sitter-javascript = "0.20.0"
tree-sitter-rust = "0.20.3"
tree-sitter-typescript = "0.20.2"

72 changes: 71 additions & 1 deletion src/language.rs
@@ -1,24 +1,30 @@
use std::{collections::HashMap, ffi::OsStr, path::Path};

use clap::ValueEnum;
use tree_sitter::Language;

-#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
+#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum, Hash)]
pub enum SupportedLanguageName {
Rust,
Typescript,
Javascript,
}

impl SupportedLanguageName {
pub fn get_language(self) -> Box<dyn SupportedLanguage> {
match self {
Self::Rust => Box::new(get_rust_language()),
Self::Typescript => Box::new(get_typescript_language()),
Self::Javascript => Box::new(get_javascript_language()),
}
}
}

pub trait SupportedLanguage {
fn language(&self) -> Language;
fn name(&self) -> SupportedLanguageName;
fn name_for_ignore_select(&self) -> &'static str;
fn extensions(&self) -> Vec<&'static str>;
}

pub struct SupportedLanguageRust;
@@ -28,9 +34,17 @@ impl SupportedLanguage for SupportedLanguageRust {
tree_sitter_rust::language()
}

fn name(&self) -> SupportedLanguageName {
SupportedLanguageName::Rust
}

fn name_for_ignore_select(&self) -> &'static str {
"rust"
}

fn extensions(&self) -> Vec<&'static str> {
vec!["rs"]
}
}

pub fn get_rust_language() -> SupportedLanguageRust {
@@ -44,11 +58,67 @@ impl SupportedLanguage for SupportedLanguageTypescript {
tree_sitter_typescript::language_tsx()
}

fn name(&self) -> SupportedLanguageName {
SupportedLanguageName::Typescript
}

fn name_for_ignore_select(&self) -> &'static str {
"ts"
}

fn extensions(&self) -> Vec<&'static str> {
vec!["ts", "tsx"]
}
}

pub fn get_typescript_language() -> SupportedLanguageTypescript {
SupportedLanguageTypescript
}

pub struct SupportedLanguageJavascript;

impl SupportedLanguage for SupportedLanguageJavascript {
fn language(&self) -> Language {
tree_sitter_javascript::language()
}

fn name(&self) -> SupportedLanguageName {
SupportedLanguageName::Javascript
}

fn name_for_ignore_select(&self) -> &'static str {
"js"
}

fn extensions(&self) -> Vec<&'static str> {
vec!["js", "jsx", "vue", "cjs", "mjs"]
Owner Author

This method is for mapping back to a known supported language when iterating through project files

I got this list from the ignore source code. I think there will probably be more subtlety to some of this tree-sitter grammar <-> filename mapping stuff (e.g. TypeScript vs TSX grammar), but this seems like a reasonable quick-and-dirty approach for the moment?

}
}

pub fn get_javascript_language() -> SupportedLanguageJavascript {
SupportedLanguageJavascript
}

pub fn get_all_supported_languages() -> HashMap<SupportedLanguageName, Box<dyn SupportedLanguage>> {
Owner Author

I started to think about making this eg a static Lazy singleton and then decided I was being premature optimization-y

HashMap::from_iter([
(
SupportedLanguageName::Rust,
Box::new(get_rust_language()) as Box<dyn SupportedLanguage>,
),
(
SupportedLanguageName::Typescript,
Box::new(get_typescript_language()) as Box<dyn SupportedLanguage>,
),
(
SupportedLanguageName::Javascript,
Box::new(get_javascript_language()) as Box<dyn SupportedLanguage>,
),
])
Collaborator

Alternatively, you can just .into() the array to convert into a HashMap.

Owner Author

Sure updated to use .into()
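
For reference, a minimal sketch of the `.into()` form suggested above. It relies on the standard library's `From<[(K, V); N]>` impl for `HashMap`. The `LanguageName` enum and `ignore_select_names` function are hypothetical stand-ins, not the PR's actual types, and the values are plain strings rather than `Box<dyn SupportedLanguage>` to keep the example self-contained.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for the PR's `SupportedLanguageName` enum.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum LanguageName {
    Rust,
    Typescript,
    Javascript,
}

// Builds the map straight from an array of (key, value) pairs.
// The return type drives the `.into()` conversion, so no `from_iter` is needed.
pub fn ignore_select_names() -> HashMap<LanguageName, &'static str> {
    [
        (LanguageName::Rust, "rust"),
        (LanguageName::Typescript, "ts"),
        (LanguageName::Javascript, "js"),
    ]
    .into()
}
```

With boxed trait objects as values, the `as Box<dyn SupportedLanguage>` casts on each element would presumably still be needed so that all array elements share one type.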

}

pub fn maybe_supported_language_from_path(path: &Path) -> Option<Box<dyn SupportedLanguage>> {
Owner Author

This is the core thing mapping project file path back to supported language, so this is getting called once per project file

let extension = path.extension().and_then(OsStr::to_str)?;
get_all_supported_languages()
.into_values()
.find(|language| language.extensions().contains(&extension))
Collaborator

I'm sure it makes no difference performance-wise for this amount of data, but the natural collection I'd expected to look this up in is a hash map of extension -> language?

Owner Author

Ya I think I was actually doing that (writing a helper to produce such a hash map with the idea of then maybe "lazy singleton"-izing it) but I think (not sure if it was for some specific reason) I just decided this was easier for now
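
A rough sketch of what the extension -> language map discussed here might look like. This is not what the PR does. `LanguageName`, `languages_by_extension`, and `language_from_path` are hypothetical names, and the enum stands in for the boxed `SupportedLanguage` values; the extension lists are copied from the diff.

```rust
use std::collections::HashMap;
use std::ffi::OsStr;
use std::path::Path;

// Hypothetical stand-in for the PR's language types.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum LanguageName {
    Rust,
    Typescript,
    Javascript,
}

impl LanguageName {
    pub const ALL: &'static [LanguageName] = &[
        LanguageName::Rust,
        LanguageName::Typescript,
        LanguageName::Javascript,
    ];

    // Same extension lists as the `extensions()` methods in the diff.
    pub fn extensions(self) -> &'static [&'static str] {
        match self {
            Self::Rust => &["rs"],
            Self::Typescript => &["ts", "tsx"],
            Self::Javascript => &["js", "jsx", "vue", "cjs", "mjs"],
        }
    }
}

// Inverts language -> extensions into a single extension -> language map.
pub fn languages_by_extension() -> HashMap<&'static str, LanguageName> {
    LanguageName::ALL
        .iter()
        .copied()
        .flat_map(|language| {
            language
                .extensions()
                .iter()
                .map(move |extension| (*extension, language))
        })
        .collect()
}

// One hash lookup per file instead of scanning every language's extension list.
pub fn language_from_path(
    by_extension: &HashMap<&'static str, LanguageName>,
    path: &Path,
) -> Option<LanguageName> {
    let extension = path.extension().and_then(OsStr::to_str)?;
    by_extension.get(extension).copied()
}
```

The map would be built once and passed by reference (or stored in a lazy static), which is the "lazy singleton" direction mentioned in the reply.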

}
140 changes: 119 additions & 21 deletions src/lib.rs
@@ -14,13 +14,18 @@ use once_cell::unsync::OnceCell;
use rayon::iter::IterBridge;
use rayon::prelude::*;
use std::cell::RefCell;
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::path::PathBuf;
use std::process;
use std::rc::Rc;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering;
use std::sync::mpsc;
use std::sync::mpsc::Receiver;
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;
use std::thread::JoinHandle;
use termcolor::Buffer;
@@ -33,9 +38,12 @@ mod macros;
mod plugin;
mod treesitter;

-use language::{SupportedLanguage, SupportedLanguageName};
+use language::{
+get_all_supported_languages, maybe_supported_language_from_path, SupportedLanguage,
+SupportedLanguageName,
+};
use plugin::get_loaded_filter;
-use treesitter::{get_matches, get_query};
+use treesitter::{get_matches, maybe_get_query};

#[derive(Parser)]
pub struct Args {
@@ -45,7 +53,7 @@ pub struct Args {
#[arg(short, long = "capture")]
pub capture_name: Option<String>,
#[arg(short, long, value_enum)]
-pub language: SupportedLanguageName,
+pub language: Option<SupportedLanguageName>,
#[arg(short, long)]
pub filter: Option<String>,
#[arg(short = 'a', long)]
@@ -91,25 +99,77 @@ fn get_output_mode(args: &Args) -> OutputMode {
}
}

struct MaybeInitializedCaptureIndex(AtomicU32);
Owner Author

This is maybe questionable

The weird thing this is trying to address is that we can't validate a provided --capture (capture name) command-line argument (and sort of "parse-don't-validate"/resolve it to the corresponding capture index) until we have a successfully parsed tree-sitter query. But now that happens "on-the-fly" in parallel-world

So it seems like we want a shared atomic-y lazy-initialized resolved capture index (or the failure case also has to be considered, where they provided an invalid capture name and we'd like that to just fail "once" with a nicely printed error message)?


impl MaybeInitializedCaptureIndex {
fn mark_failed(&self) {
self.0.store(u32::MAX - 1, Ordering::Relaxed);
Owner Author

What I did (I guess in order to avoid having multiple atomics going on, at which point the whole Ordering thing gets harder to reason about) was just have a single AtomicU32 and encode the "uninitialized" and "invalid capture name" states into its value

}

pub fn get(&self) -> Result<Option<u32>, ()> {
let loaded = self.0.load(Ordering::Relaxed);
match loaded {
loaded if loaded == u32::MAX => Ok(None),
loaded if loaded == u32::MAX - 1 => Err(()),
loaded => Ok(Some(loaded)),
}
}

pub fn get_or_initialize(&self, query: &Query, capture_name: Option<&str>) -> Result<u32, ()> {
if let Some(already_initialized) = self.get()? {
return Ok(already_initialized);
}
let capture_index = match capture_name {
None => 0,
Some(capture_name) => {
let capture_index = query.capture_index_for_name(capture_name);
if capture_index.is_none() {
self.mark_failed();
fail(&format!("invalid capture name '{}'", capture_name));
Owner Author

So in theory this would be the one time that this error message gets printed (other threads then calling .get_or_initialize() would early-return an Err above)

And in practice (eg I added a test for this) I've only been seeing it get printed once

However I think it's actually "race"-y because I think another thread could enter this block before this one has had a chance to .mark_failed() here

Ooh actually I bet I could use a "fetch-swap" or whatever in .mark_failed() so basically ensure that after the call to .mark_failed() it's definitely stored the failure value but then return a value from that call to .mark_failed() that indicates whether we actually "won" the race to insert that failure value. And then only print the failure message here if we won

Owner Author

Ok I updated this to use .compare_exchange() to avoid such a race. And in fact "validated" it by artificially introducing a long sleep, seeing the race happen (ie multiple printings of the failure message), and then no longer seeing it happen after adding the .compare_exchange()
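
The updated code is not part of the three commits shown here, so the following is only a sketch of how a `.compare_exchange()`-based `mark_failed()` could look, using the same sentinel values as the diff (`u32::MAX` for uninitialized, `u32::MAX - 1` for failed). `CaptureIndexSlot` and the constant names are hypothetical, and `Ordering::Relaxed` is carried over from the diff rather than being a recommendation.

```rust
use std::sync::atomic::{AtomicU32, Ordering};

// Sentinel values matching the ones in the diff.
const UNINITIALIZED: u32 = u32::MAX;
const FAILED: u32 = u32::MAX - 1;

// Hypothetical stand-in for `MaybeInitializedCaptureIndex`.
pub struct CaptureIndexSlot(AtomicU32);

impl CaptureIndexSlot {
    pub fn new() -> Self {
        Self(AtomicU32::new(UNINITIALIZED))
    }

    // Atomically moves the slot from UNINITIALIZED to FAILED.
    // Returns true only for the one caller whose exchange succeeded, so only
    // that caller should print the error message.
    pub fn mark_failed(&self) -> bool {
        self.0
            .compare_exchange(UNINITIALIZED, FAILED, Ordering::Relaxed, Ordering::Relaxed)
            .is_ok()
    }

    // Decodes the three states packed into the single atomic.
    pub fn get(&self) -> Result<Option<u32>, ()> {
        match self.0.load(Ordering::Relaxed) {
            UNINITIALIZED => Ok(None),
            FAILED => Err(()),
            index => Ok(Some(index)),
        }
    }
}
```

A caller would then write something like `if slot.mark_failed() { fail(...) }`, so threads that lose the race skip the message. Note that in this sketch `mark_failed()` also returns false, and leaves the value alone, if a valid index was already stored.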

}
capture_index.unwrap()
}
};
self.set(capture_index);
Ok(capture_index)
}

fn set(&self, capture_index: u32) {
self.0.store(capture_index, Ordering::Relaxed);
}
}

impl Default for MaybeInitializedCaptureIndex {
fn default() -> Self {
Self(AtomicU32::new(u32::MAX))
}
}

pub fn run(args: Args) {
let query_source = match args.query_args.path_to_query_file.as_ref() {
Some(path_to_query_file) => fs::read_to_string(path_to_query_file).unwrap(),
None => args.query_args.query_source.clone().unwrap(),
};
-let supported_language = args.language.get_language();
-let language = supported_language.language();
-let query = Arc::new(get_query(&query_source, language));
-let capture_index = args.capture_name.as_ref().map_or(0, |capture_name| {
-query
-.capture_index_for_name(capture_name)
-.expect(&format!("Unknown capture name: `{}`", capture_name))
-});
+let specified_supported_language = args.language.map(|language| language.get_language());
+let query_or_failure_by_language: Mutex<HashMap<SupportedLanguageName, Option<Arc<Query>>>> =
Owner Author

A None value in this HashMap indicates "we already tried parsing the query with this language's grammar and it failed to parse"

Default::default();
let capture_index = MaybeInitializedCaptureIndex::default();
let output_mode = get_output_mode(&args);
let buffer_writer = BufferWriter::stdout(ColorChoice::Never);

-get_project_file_walker(&*supported_language, &args.use_paths())
+get_project_file_walker(specified_supported_language.as_deref(), &args.use_paths())
.into_parallel_iterator()
.for_each(|project_file_dir_entry| {
let language = maybe_supported_language_from_path(project_file_dir_entry.path())
.expect("Walker should've been pre-filtered to just supported file types");
let query = return_if_none!(get_and_cache_query_for_language(
Owner Author

Reading this I like this macro better than the boilerplate of

let whatever = ...;
if whatever.is_none() {
    return;
}
let whatever = whatever.unwrap();

or:

let whatever = match ... {
    None => return,
    Some(whatever) => whatever,
};

but it would be nice in terms of "noisiness"/"precedence" to write it postfix eg:

let query = ...
    .return_if_none!();

Looks like that is a proposal

Or I guess if this function returned an Option/Result then I could use the ? operator
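
One more alternative to the forms listed in this comment is `let ... else`, which has been stable since Rust 1.65. A small sketch, with `first_even` and `record_first_even` as hypothetical helpers that have nothing to do with the PR's code:

```rust
// Hypothetical helper standing in for any lookup that may come back empty.
fn first_even(numbers: &[i32]) -> Option<i32> {
    numbers.iter().copied().find(|number| number % 2 == 0)
}

// Early-returns on `None` without a macro and without changing the
// function's return type.
pub fn record_first_even(numbers: &[i32], seen: &mut Vec<i32>) {
    let Some(even) = first_even(numbers) else {
        return;
    };
    seen.push(even);
}
```

It is not postfix, but it keeps the happy-path binding on one line and works inside closures that return `()`, such as the one passed to `for_each` here.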

&query_source,
&query_or_failure_by_language,
Owner Author

There is currently no optimization for when you set the --language flag other than the fact that the walker limits itself to just returning files of that type

ie we're still using this Mutex<HashMap<...>> even when we "know" that there's only one language/parsed query we're interested in

On the one hand it seemed a little silly to tear out that "fast path" but it's hard to picture how it wouldn't have made the code more complex and so I deemed it premature optimization to try and maintain that

&*language,
));
let capture_index = return_if_none!(capture_index
.get_or_initialize(&query, args.capture_name.as_deref())
.ok());
let printer = get_printer(&buffer_writer, output_mode);
let mut printer = printer.borrow_mut();
let path =
@@ -118,7 +178,7 @@ pub fn run(args: Args) {
let matcher = TreeSitterMatcher::new(
&query,
capture_index,
-language,
+language.language(),
args.filter.clone(),
args.filter_arg.clone(),
);
@@ -129,6 +189,38 @@ pub fn run(args: Args) {
.unwrap();
buffer_writer.print(printer.get_mut()).unwrap();
});

error_if_no_successful_query_parsing(&query_or_failure_by_language);
Owner Author

Turn "n lazy soft fails" into a "hard fail"

}

fn error_if_no_successful_query_parsing(
query_or_failure_by_language: &Mutex<HashMap<SupportedLanguageName, Option<Arc<Query>>>>,
) {
let query_or_failure_by_language = query_or_failure_by_language.lock().unwrap();
if !query_or_failure_by_language
.values()
.any(|query| query.is_some())
{
fail("invalid query");
}
}

fn fail(message: &str) -> ! {
eprintln!("error: {message}");
process::exit(1);
}

fn get_and_cache_query_for_language(
query_source: &str,
query_or_failure_by_language: &Mutex<HashMap<SupportedLanguageName, Option<Arc<Query>>>>,
language: &dyn SupportedLanguage,
) -> Option<Arc<Query>> {
query_or_failure_by_language
.lock()
.unwrap()
.entry(language.name())
.or_insert_with(|| maybe_get_query(query_source, language.language()).map(Arc::new))
Owner Author

I thought a little bit about this in terms of performance - I don't know how slow query-parsing is (vs project file iteration) but there could be competition for this mutex especially "at the beginning" (if a bunch of project files get into rayon-world and get stuck waiting for the mutex to get access to the parsed query for their language)

But (in the "effectively single-language" scenario) they'd all have to in some sense wait for that query-parsing time anyway so doesn't seem terrible

If there are multiple supported languages being encountered then you could have parsing less-than-optimally parallelized I think, seems like then maybe having per-language mutexes would be the optimization there (rather than one mutex for n supported languages)?
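
A sketch of the per-language idea from the last paragraph, using one `std::sync::OnceLock` (in std since Rust 1.70) per language instead of a single `Mutex` around the whole map. This is an assumption about how it could be done, not something the PR implements. `QueryCache` and `LanguageName` are hypothetical, and `String` stands in for a parsed tree-sitter `Query`.

```rust
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};

// Hypothetical stand-in for the PR's `SupportedLanguageName` enum.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum LanguageName {
    Rust,
    Typescript,
    Javascript,
}

// One lazily-initialized slot per language. `None` inside a slot means
// "parsing was attempted for this language and failed".
pub struct QueryCache {
    slots: HashMap<LanguageName, OnceLock<Option<Arc<String>>>>,
}

impl QueryCache {
    // The map itself is never mutated after construction, so it can be
    // shared across threads by reference without a lock.
    pub fn new(languages: &[LanguageName]) -> Self {
        Self {
            slots: languages
                .iter()
                .map(|language| (*language, OnceLock::new()))
                .collect(),
        }
    }

    // Runs `parse` at most once per language; callers for other languages
    // are not blocked while it runs. Returns `None` for unregistered languages.
    pub fn get_or_parse(
        &self,
        language: LanguageName,
        parse: impl FnOnce() -> Option<String>,
    ) -> Option<Arc<String>> {
        self.slots
            .get(&language)?
            .get_or_init(|| parse().map(Arc::new))
            .clone()
    }
}
```

Threads asking for the same language still wait on that language's first parse, which matches the "they'd have to wait anyway" reasoning above; only cross-language contention goes away.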

.clone()
}

type Printer = Standard<Buffer>;
@@ -253,16 +345,22 @@ impl Iterator for WalkParallelIterator {
}
}

-fn get_project_file_walker(language: &dyn SupportedLanguage, paths: &[PathBuf]) -> WalkParallel {
+fn get_project_file_walker(
+language: Option<&dyn SupportedLanguage>,
+paths: &[PathBuf],
+) -> WalkParallel {
assert!(!paths.is_empty());
let mut builder = WalkBuilder::new(&paths[0]);
-builder.types(
-TypesBuilder::new()
-.add_defaults()
-.select(language.name_for_ignore_select())
-.build()
-.unwrap(),
-);
+let mut types_builder = TypesBuilder::new();
+types_builder.add_defaults();
+if let Some(language) = language {
+types_builder.select(language.name_for_ignore_select());
+} else {
+for language in get_all_supported_languages().values() {
+types_builder.select(language.name_for_ignore_select());
+}
+}
+builder.types(types_builder.build().unwrap());
for path in &paths[1..] {
builder.add(path);
}
14 changes: 13 additions & 1 deletion src/macros.rs
@@ -3,5 +3,17 @@ macro_rules! regex {
($re:literal $(,)?) => {{
static RE: once_cell::sync::OnceCell<regex::Regex> = once_cell::sync::OnceCell::new();
RE.get_or_init(|| regex::Regex::new($re).unwrap())
-}}
+}};
}

#[macro_export]
macro_rules! return_if_none {
($expr:expr $(,)?) => {
match $expr {
None => {
return;
}
Some(expr) => expr,
}
};
}
4 changes: 2 additions & 2 deletions src/treesitter/mod.rs
@@ -11,8 +11,8 @@ pub fn get_parser(language: Language) -> Parser {
parser
}

-pub fn get_query(source: &str, language: Language) -> Query {
-Query::new(language, source).unwrap()
+pub fn maybe_get_query(source: &str, language: Language) -> Option<Query> {
+Query::new(language, source).ok()
}

pub fn get_matches(
1 change: 1 addition & 0 deletions tests/fixtures/mixed_project/javascript_src/index.js
@@ -0,0 +1 @@
const js_foo = () => {}
1 change: 1 addition & 0 deletions tests/fixtures/mixed_project/rust_src/lib.rs
@@ -0,0 +1 @@
fn foo() {}
1 change: 1 addition & 0 deletions tests/fixtures/mixed_project/typescript_src/index.tsx
@@ -0,0 +1 @@
const foo = () => {}
1 change: 1 addition & 0 deletions tests/fixtures/rust_project/function-itemz.scm
@@ -0,0 +1 @@
(function_itemz) @function_item