Skip to content

Commit

Permalink
refactor: implement with a visitor
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniPopes committed Jan 5, 2025
1 parent 9b66930 commit 33d9130
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 80 deletions.
159 changes: 79 additions & 80 deletions crates/forge/bin/cmd/bind_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ use foundry_compilers::{
};
use foundry_config::Config;
use itertools::Itertools;
use rayon::prelude::*;
use solar_ast::{
ast::{Arena, FunctionKind, ItemKind, VarMut},
ast::{self, Arena, FunctionKind, Span, VarMut},
interface::source_map::FileName,
visit::Visit,
};
use solar_parse::{interface::Session, Parser as SolarParser};
use std::{
Expand Down Expand Up @@ -88,95 +88,94 @@ impl BindJsonArgs {
.unwrap()
.1;

// Insert empty bindings file
let sess = Session::builder().with_stderr_emitter().build();
let result = sess.enter(|| -> solar_parse::interface::Result<()> {
// TODO: Switch back to par_iter_mut and `enter_parallel` after solar update.
sources.0.iter_mut().try_for_each(|(path, source)| {
let mut content = Arc::try_unwrap(std::mem::take(&mut source.content)).unwrap();

let arena = Arena::new();
let mut parser = SolarParser::from_source_code(
&sess,
&arena,
FileName::Real(path.clone()),
content.to_string(),
)?;
let ast = parser.parse_file().map_err(|e| e.emit())?;

let mut visitor = PreprocessorVisitor::new();
visitor.visit_source_unit(&ast);
visitor.update(&sess, &mut content);

source.content = Arc::new(content);
Ok(())
})
});
eyre::ensure!(result.is_ok(), "failed parsing");

// Insert empty bindings file.
sources.insert(target_path.clone(), Source::new("library JsonBindings {}"));

let sources = Sources(
sources
.0
.into_par_iter()
.map(|(path, source)| {
let mut content = Arc::unwrap_or_clone(source.content);
let sess = Session::builder().with_stderr_emitter().build();

let result = sess.enter(|| -> solar_ast::interface::Result<()> {
let arena = Arena::new();
let mut funcs = Vec::new();
let mut locs_to_update = Vec::new();
let mut parser = SolarParser::from_source_code(
&sess,
&arena,
FileName::Real(path.clone()),
content.to_string(),
)?;

let parsed = parser.parse_file().map_err(|e| e.emit())?;

for item in parsed.items {
if let ItemKind::Function(def) = &item.kind {
funcs.push(def);
}
if let ItemKind::Contract(contract) = &item.kind {
for part in contract.body.iter() {
match &part.kind {
ItemKind::Function(def) => {
funcs.push(def);
}
ItemKind::Variable(def) => {
if let Some(VarMut::Immutable) = def.mutability {
locs_to_update.push((
def.span.lo().0,
def.span.hi().0,
String::new(),
));
}
}
_ => {}
}
}
}
}
Ok(PreprocessedState { sources, target_path, project, config })
}
}

for func in funcs {
// If there's no body block, keep the function as is
let Some(stmt) = &func.body else {
continue;
};
let new_body = match func.kind {
FunctionKind::Modifier => "_;",
_ => "revert();",
};
let start = stmt.first().map(|s| s.span.lo().0);
let end = stmt.last().map(|s| s.span.hi().0);
if let (Some(start), Some(end)) = (start, end) {
locs_to_update.push((start, end, new_body.to_string()));
}
}
struct PreprocessorVisitor {
updates: Vec<(Span, &'static str)>,
}

locs_to_update.sort_by_key(|(start, _, _)| *start);
impl PreprocessorVisitor {
fn new() -> Self {
Self { updates: Vec::new() }
}

let mut shift = 0_i64;
fn update(mut self, sess: &Session, content: &mut String) {
if self.updates.is_empty() {
return;
}

for (start, end, new) in locs_to_update {
let start = ((start as i64) - shift) as usize;
let end = ((end as i64) - shift) as usize;
let sf = sess.source_map().lookup_source_file(self.updates[0].0.lo());
let base = sf.start_pos.0;

content.replace_range(start..end, new.as_str());
shift += (end - start) as i64;
shift -= new.len() as i64;
}
self.updates.sort_by_key(|(span, _)| span.lo());
let mut shift = 0_i64;
for (span, new) in self.updates {
let lo = span.lo() - base;
let hi = span.hi() - base;
let start = ((lo.0 as i64) - shift) as usize;
let end = ((hi.0 as i64) - shift) as usize;

Ok(())
});
content.replace_range(start..end, new);
shift += (end - start) as i64;
shift -= new.len() as i64;
}
}
}

eyre::ensure!(result.is_ok(), "parsing failed");
impl<'ast> Visit<'ast> for PreprocessorVisitor {
fn visit_item_function(&mut self, func: &'ast ast::ItemFunction<'ast>) {
// Replace function bodies with a noop statement.
if let Some(block) = &func.body {
if !block.is_empty() {
let span = block.first().unwrap().span.to(block.last().unwrap().span);
let new_body = match func.kind {
FunctionKind::Modifier => "_;",
_ => "revert();",
};
self.updates.push((span, new_body));
}
}

Ok((path, Source::new(content)))
})
.collect::<Result<BTreeMap<_, _>>>()?,
);
self.walk_item_function(func)
}

Ok(PreprocessedState { sources, target_path, project, config })
fn visit_variable_definition(&mut self, var: &'ast ast::VariableDefinition<'ast>) {
// Remove `immutable` attributes.
if let Some(VarMut::Immutable) = var.mutability {
self.updates.push((var.span, ""));
}

self.walk_variable_definition(var)
}
}

Expand Down
71 changes: 71 additions & 0 deletions crates/forge/tests/cli/bind_json.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use foundry_test_utils::snapbox;

// tests complete bind-json workflow
// ensures that we can run forge-bind even if files are depending on yet non-existent bindings and
// that generated bindings are correct
Expand Down Expand Up @@ -50,5 +52,74 @@ contract BindJsonTest is Test {
.unwrap();

cmd.arg("bind-json").assert_success();

snapbox::assert_data_eq!(
snapbox::Data::read_from(&prj.root().join("utils/JsonBindings.sol"), None),
snapbox::str![[r#"
// Automatically generated by forge bind-json.
pragma solidity >=0.6.2 <0.9.0;
pragma experimental ABIEncoderV2;
import {BindJsonTest, TopLevelStruct} from "test/JsonBindings.sol";
interface Vm {
function parseJsonTypeArray(string calldata json, string calldata key, string calldata typeDescription) external pure returns (bytes memory);
function parseJsonType(string calldata json, string calldata typeDescription) external pure returns (bytes memory);
function parseJsonType(string calldata json, string calldata key, string calldata typeDescription) external pure returns (bytes memory);
function serializeJsonType(string calldata typeDescription, bytes memory value) external pure returns (string memory json);
function serializeJsonType(string calldata objectKey, string calldata valueKey, string calldata typeDescription, bytes memory value) external returns (string memory json);
}
library JsonBindings {
Vm constant vm = Vm(address(uint160(uint256(keccak256("hevm cheat code")))));
string constant schema_TopLevelStruct = "TopLevelStruct(uint256 param1,int8 param2)";
string constant schema_ContractLevelStruct = "ContractLevelStruct(address[][] param1,address addrParam)";
function serialize(TopLevelStruct memory value) internal pure returns (string memory) {
return vm.serializeJsonType(schema_TopLevelStruct, abi.encode(value));
}
function serialize(TopLevelStruct memory value, string memory objectKey, string memory valueKey) internal returns (string memory) {
return vm.serializeJsonType(objectKey, valueKey, schema_TopLevelStruct, abi.encode(value));
}
function deserializeTopLevelStruct(string memory json) public pure returns (TopLevelStruct memory) {
return abi.decode(vm.parseJsonType(json, schema_TopLevelStruct), (TopLevelStruct));
}
function deserializeTopLevelStruct(string memory json, string memory path) public pure returns (TopLevelStruct memory) {
return abi.decode(vm.parseJsonType(json, path, schema_TopLevelStruct), (TopLevelStruct));
}
function deserializeTopLevelStructArray(string memory json, string memory path) public pure returns (TopLevelStruct[] memory) {
return abi.decode(vm.parseJsonTypeArray(json, path, schema_TopLevelStruct), (TopLevelStruct[]));
}
function serialize(BindJsonTest.ContractLevelStruct memory value) internal pure returns (string memory) {
return vm.serializeJsonType(schema_ContractLevelStruct, abi.encode(value));
}
function serialize(BindJsonTest.ContractLevelStruct memory value, string memory objectKey, string memory valueKey) internal returns (string memory) {
return vm.serializeJsonType(objectKey, valueKey, schema_ContractLevelStruct, abi.encode(value));
}
function deserializeContractLevelStruct(string memory json) public pure returns (BindJsonTest.ContractLevelStruct memory) {
return abi.decode(vm.parseJsonType(json, schema_ContractLevelStruct), (BindJsonTest.ContractLevelStruct));
}
function deserializeContractLevelStruct(string memory json, string memory path) public pure returns (BindJsonTest.ContractLevelStruct memory) {
return abi.decode(vm.parseJsonType(json, path, schema_ContractLevelStruct), (BindJsonTest.ContractLevelStruct));
}
function deserializeContractLevelStructArray(string memory json, string memory path) public pure returns (BindJsonTest.ContractLevelStruct[] memory) {
return abi.decode(vm.parseJsonTypeArray(json, path, schema_ContractLevelStruct), (BindJsonTest.ContractLevelStruct[]));
}
}
"#]],
);

cmd.forge_fuse().args(["test"]).assert_success();
});

0 comments on commit 33d9130

Please sign in to comment.