diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index eca8ec80..b09e865d 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -12,12 +12,13 @@ crate-type = ["cdylib"] [dependencies] anyhow = { workspace = true } either = "1.13" -pyo3 = { version = "0.23", features = ["anyhow", "serde"] } +pyo3 = { version = "0.23", features = ["anyhow", "serde", "either"] } pyo3-async-runtimes = { version = "0.23", features = ["tokio-runtime", "attributes"] } pythonize = "0.23" json_dotpath = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } +rust_decimal = { workspace = true, features = ["maths-nopanic"] } tokio-util = { version = "0.7", features = ["rt"] } zen-engine = { path = "../../core/engine" } zen-expression = { path = "../../core/expression" } diff --git a/bindings/python/src/decision.rs b/bindings/python/src/decision.rs index a653ea94..e88268ec 100644 --- a/bindings/python/src/decision.rs +++ b/bindings/python/src/decision.rs @@ -1,5 +1,10 @@ use std::sync::Arc; +use crate::custom_node::PyCustomNode; +use crate::engine::PyZenEvaluateOptions; +use crate::loader::PyDecisionLoader; +use crate::mt::worker_pool; +use crate::value::PyValue; use anyhow::{anyhow, Context}; use pyo3::types::PyDict; use pyo3::{pyclass, pymethods, Bound, IntoPyObjectExt, Py, PyAny, PyResult, Python}; @@ -9,12 +14,7 @@ use pyo3_async_runtimes::tokio::re_exports::runtime::Runtime; use pythonize::depythonize; use serde_json::Value; use zen_engine::{Decision, EvaluationOptions}; - -use crate::custom_node::PyCustomNode; -use crate::engine::PyZenEvaluateOptions; -use crate::loader::PyDecisionLoader; -use crate::mt::worker_pool; -use crate::value::PyValue; +use zen_expression::Variable; #[pyclass] #[pyo3(name = "ZenDecision")] @@ -35,7 +35,7 @@ impl PyZenDecision { ctx: &Bound<'_, PyDict>, opts: Option<&Bound<'_, PyDict>>, ) -> PyResult> { - let context: Value = depythonize(ctx).context("Failed to convert dict")?; + let context: Variable = depythonize(ctx).context("Failed to convert dict")?; let options: PyZenEvaluateOptions = if let Some(op) = opts { depythonize(op).context("Failed to convert dict")? } else { @@ -47,7 +47,7 @@ impl PyZenDecision { let rt = Runtime::new()?; let result = rt .block_on(decision.evaluate_with_opts( - context.into(), + context, EvaluationOptions { max_depth: options.max_depth, trace: options.trace, diff --git a/bindings/python/src/engine.rs b/bindings/python/src/engine.rs index 50fd969c..ab0f7d68 100644 --- a/bindings/python/src/engine.rs +++ b/bindings/python/src/engine.rs @@ -1,5 +1,10 @@ use std::sync::Arc; +use crate::custom_node::PyCustomNode; +use crate::decision::PyZenDecision; +use crate::loader::PyDecisionLoader; +use crate::mt::{block_on, worker_pool}; +use crate::value::PyValue; use anyhow::{anyhow, Context}; use pyo3::prelude::PyDictMethods; use pyo3::types::PyDict; @@ -11,12 +16,7 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use zen_engine::model::DecisionContent; use zen_engine::{DecisionEngine, EvaluationOptions}; - -use crate::custom_node::PyCustomNode; -use crate::decision::PyZenDecision; -use crate::loader::PyDecisionLoader; -use crate::mt::{block_on, worker_pool}; -use crate::value::PyValue; +use zen_expression::Variable; #[pyclass] #[pyo3(name = "ZenEngine")] @@ -90,7 +90,7 @@ impl PyZenEngine { ctx: &Bound<'_, PyDict>, opts: Option<&Bound<'_, PyDict>>, ) -> PyResult> { - let context: Value = depythonize(ctx).context("Failed to convert dict")?; + let context: Variable = depythonize(ctx).context("Failed to convert dict")?; let options: PyZenEvaluateOptions = if let Some(op) = opts { depythonize(op).context("Failed to convert dict")? } else { @@ -99,7 +99,7 @@ impl PyZenEngine { let result = block_on(self.engine.evaluate_with_opts( key, - context.into(), + context, EvaluationOptions { max_depth: options.max_depth, trace: options.trace, diff --git a/bindings/python/src/expression.rs b/bindings/python/src/expression.rs index 1b142c3d..602368b2 100644 --- a/bindings/python/src/expression.rs +++ b/bindings/python/src/expression.rs @@ -1,10 +1,35 @@ +use crate::variable::PyVariable; use anyhow::{anyhow, Context}; -use pyo3::types::PyDict; -use pyo3::{pyfunction, Bound, IntoPyObjectExt, Py, PyAny, PyResult, Python}; +use either::Either; +use pyo3::types::{PyDict, PyList}; +use pyo3::{ + pyclass, pyfunction, pymethods, Bound, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyErr, + PyResult, Python, +}; use pythonize::depythonize; -use serde_json::Value; +use zen_expression::expression::{Standard, Unary}; +use zen_expression::vm::VM; +use zen_expression::{Expression, Variable}; -use crate::value::PyValue; +#[pyfunction] +pub fn compile_expression(expression: String) -> PyResult { + let expr = zen_expression::compile_expression(expression.as_str()) + .map_err(|e| anyhow!(serde_json::to_string(&e).unwrap_or_else(|_| e.to_string())))?; + + Ok(PyExpression { + expression: Either::Left(expr), + }) +} + +#[pyfunction] +pub fn compile_unary_expression(expression: String) -> PyResult { + let expr = zen_expression::compile_unary_expression(expression.as_str()) + .map_err(|e| anyhow!(serde_json::to_string(&e).unwrap_or_else(|_| e.to_string())))?; + + Ok(PyExpression { + expression: Either::Right(expr), + }) +} #[pyfunction] #[pyo3(signature = (expression, ctx=None))] @@ -17,19 +42,19 @@ pub fn evaluate_expression( .map(|ctx| depythonize(ctx)) .transpose() .context("Failed to convert context")? - .unwrap_or(Value::Null); + .unwrap_or(Variable::Null); - let result = zen_expression::evaluate_expression(expression.as_str(), context.into()) + let result = zen_expression::evaluate_expression(expression.as_str(), context) .map_err(|e| anyhow!(serde_json::to_string(&e).unwrap_or_else(|_| e.to_string())))?; - PyValue(result.to_value()).into_py_any(py) + PyVariable(result).into_py_any(py) } #[pyfunction] pub fn evaluate_unary_expression(expression: String, ctx: &Bound<'_, PyDict>) -> PyResult { - let context: Value = depythonize(ctx).context("Failed to convert context")?; + let context: Variable = depythonize(ctx).context("Failed to convert context")?; - let result = zen_expression::evaluate_unary_expression(expression.as_str(), context.into()) + let result = zen_expression::evaluate_unary_expression(expression.as_str(), context) .map_err(|e| anyhow!(serde_json::to_string(&e).unwrap_or_else(|_| e.to_string())))?; Ok(result) @@ -41,10 +66,75 @@ pub fn render_template( template: String, ctx: &Bound<'_, PyDict>, ) -> PyResult> { - let context: Value = depythonize(ctx).context("Failed to convert context")?; + let context: Variable = depythonize(ctx) + .context("Failed to convert context") + .unwrap_or(Variable::Null); - let result = zen_tmpl::render(template.as_str(), context.into()) + let result = zen_tmpl::render(template.as_str(), context) .map_err(|e| anyhow!(serde_json::to_string(&e).unwrap_or_else(|_| e.to_string())))?; - PyValue(result.to_value()).into_py_any(py) + PyVariable(result).into_py_any(py) +} + +#[pyclass] +pub struct PyExpression { + expression: Either, Expression>, +} +#[pymethods] +impl PyExpression { + #[pyo3(signature = (ctx=None))] + pub fn evaluate(&self, py: Python, ctx: Option<&Bound<'_, PyDict>>) -> PyResult> { + let context = ctx + .map(|ctx| depythonize(ctx)) + .transpose() + .context("Failed to convert context")? + .unwrap_or(Variable::Null); + + let maybe_result = match &self.expression { + Either::Left(standard) => standard.evaluate(context), + Either::Right(unary) => unary.evaluate(context).map(Variable::Bool), + }; + + let result = maybe_result + .map_err(|e| anyhow!(serde_json::to_string(&e).unwrap_or_else(|_| e.to_string())))?; + + PyVariable(result).into_py_any(py) + } + + pub fn evaluate_many(&self, py: Python, ctx: &Bound<'_, PyList>) -> PyResult> { + let contexts: Vec = depythonize(ctx).context("Failed to convert contexts")?; + + let mut vm = VM::new(); + let results: Vec<_> = contexts + .into_iter() + .map(|context| { + let result = match &self.expression { + Either::Left(standard) => standard.evaluate_with(context, &mut vm), + Either::Right(unary) => { + unary.evaluate_with(context, &mut vm).map(Variable::Bool) + } + }; + + match result { + Ok(ok) => Either::Left(PyVariable(ok)), + Err(err) => Either::Right(PyExpressionError(err.to_string())), + } + }) + .collect(); + + results.into_py_any(py) + } +} + +struct PyExpressionError(String); + +impl<'py> IntoPyObject<'py> for PyExpressionError { + type Target = PyAny; + type Output = Bound<'py, PyAny>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let err = pyo3::exceptions::PyException::new_err(self.0); + err.into_bound_py_any(py) + } } diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index a7c5dfc4..4af01d59 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -1,6 +1,9 @@ use crate::decision::PyZenDecision; use crate::engine::PyZenEngine; -use crate::expression::{evaluate_expression, evaluate_unary_expression, render_template}; +use crate::expression::{ + compile_expression, compile_unary_expression, evaluate_expression, evaluate_unary_expression, + render_template, PyExpression, +}; use pyo3::prelude::PyModuleMethods; use pyo3::types::PyModule; use pyo3::{pymodule, wrap_pyfunction, Bound, PyResult, Python}; @@ -13,14 +16,18 @@ mod loader; mod mt; mod types; mod value; +mod variable; #[pymodule] fn zen(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(evaluate_expression, m)?)?; m.add_function(wrap_pyfunction!(evaluate_unary_expression, m)?)?; m.add_function(wrap_pyfunction!(render_template, m)?)?; + m.add_function(wrap_pyfunction!(compile_expression, m)?)?; + m.add_function(wrap_pyfunction!(compile_unary_expression, m)?)?; Ok(()) } diff --git a/bindings/python/src/types.rs b/bindings/python/src/types.rs index 6253611f..6d296abf 100644 --- a/bindings/python/src/types.rs +++ b/bindings/python/src/types.rs @@ -6,6 +6,7 @@ use serde_json::Value; use std::sync::Arc; use crate::value::{value_to_object, PyValue}; +use crate::variable::PyVariable; use zen_engine::handler::custom_node_adapter::{ CustomDecisionNode as BaseCustomDecisionNode, CustomNodeRequest, }; @@ -74,7 +75,7 @@ impl PyNodeRequest { let template_value = zen_tmpl::render(template.as_str(), Variable::from(&self.inner_input)) .map_err(|e| anyhow!(serde_json::to_string(&e).unwrap_or_else(|_| e.to_string())))?; - PyValue(template_value.to_value()).into_py_any(py) + PyVariable(template_value).into_py_any(py) } fn get_field_raw(&self, py: Python, path: String) -> PyResult> { diff --git a/bindings/python/src/variable.rs b/bindings/python/src/variable.rs new file mode 100644 index 00000000..6f86c23a --- /dev/null +++ b/bindings/python/src/variable.rs @@ -0,0 +1,51 @@ +use pyo3::prelude::{PyDictMethods, PyListMethods}; +use pyo3::types::{PyDict, PyList}; +use pyo3::{Bound, IntoPyObject, IntoPyObjectExt, PyAny, PyErr, PyResult, Python}; +use rust_decimal::prelude::ToPrimitive; +use zen_expression::Variable; + +#[repr(transparent)] +#[derive(Clone, Debug)] +pub struct PyVariable(pub Variable); + +pub fn variable_to_object<'py>(py: Python<'py>, val: &Variable) -> PyResult> { + match val { + Variable::Null => py.None().into_bound_py_any(py), + Variable::Bool(b) => b.into_bound_py_any(py), + Variable::Number(n) => { + let of64 = n.to_f64().map(|i| i.into_bound_py_any(py)); + let oi64 = n.to_i64().map(|i| i.into_bound_py_any(py)); + let ou64 = n.to_u64().map(|i| i.into_bound_py_any(py)); + of64.or(oi64).or(ou64).expect("number too large") + } + Variable::String(s) => s.into_bound_py_any(py), + Variable::Array(v) => { + let list = PyList::empty(py); + let b = v.borrow(); + for item in b.iter() { + list.append(variable_to_object(py, item)?)?; + } + + list.into_bound_py_any(py) + } + Variable::Object(m) => { + let dict = PyDict::new(py); + let b = m.borrow(); + for (key, value) in b.iter() { + dict.set_item(key, variable_to_object(py, value)?)?; + } + + dict.into_bound_py_any(py) + } + } +} + +impl<'py> IntoPyObject<'py> for PyVariable { + type Target = PyAny; + type Output = Bound<'py, PyAny>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + variable_to_object(py, &self.0) + } +} diff --git a/bindings/python/zen.pyi b/bindings/python/zen.pyi index a6fb026b..e18dcf52 100644 --- a/bindings/python/zen.pyi +++ b/bindings/python/zen.pyi @@ -1,31 +1,55 @@ -from typing import Any, Optional, TypedDict from collections.abc import Awaitable +from typing import Any, Optional, TypedDict + class EvaluateResponse(TypedDict): performance: str result: dict trace: dict + class ZenEngine: def __init__(self, maybe_options: Optional[dict] = None) -> None: ... - + def evaluate(self, key: str, ctx: dict, opts: Optional[dict] = None) -> EvaluateResponse: ... - - def async_evaluate(self,key: str, ctx: dict, opts: Optional[dict] = None) -> Awaitable[EvaluateResponse]: ... - + + def async_evaluate(self, key: str, ctx: dict, opts: Optional[dict] = None) -> Awaitable[EvaluateResponse]: ... + def create_decision(self, content: str) -> ZenDecision: ... - + def get_decision(self, key: str) -> ZenDecision: ... + class ZenDecision: - def evaluate(self,ctx: dict, opts: Optional[dict] = None) -> EvaluateResponse: ... - - def async_evaluate(self,ctx: dict, opts: Optional[dict] = None) -> Awaitable[EvaluateResponse]: ... - + def evaluate(self, ctx: dict, opts: Optional[dict] = None) -> EvaluateResponse: ... + + def async_evaluate(self, ctx: dict, opts: Optional[dict] = None) -> Awaitable[EvaluateResponse]: ... + def validate(self) -> None: ... - + + def evaluate_expression(expression: str, ctx: Optional[dict] = None) -> Any: ... + def evaluate_unary_expression(expression: str, ctx: dict) -> bool: ... + def render_template(template: str, ctx: dict) -> Any: ... + + +def compile_expression(expression: str) -> Expression: ... + + +def compile_unary_expression(expression: str) -> Expression: ... + + +class ExpressionResult(TypedDict): + success: bool + result: Optional[Any] + error: Optional[str] + + +class Expression: + def evaluate(self, ctx: Optional[dict] = None) -> Any: ... + + def evaluate_many(self, ctxs: list[dict]) -> list[ExpressionResult]: ... diff --git a/core/expression/src/compiler/compiler.rs b/core/expression/src/compiler/compiler.rs index c868dc01..57bcd317 100644 --- a/core/expression/src/compiler/compiler.rs +++ b/core/expression/src/compiler/compiler.rs @@ -1,44 +1,45 @@ -use rust_decimal::Decimal; -use rust_decimal_macros::dec; -use std::rc::Rc; - use crate::compiler::error::{CompilerError, CompilerResult}; +use crate::compiler::opcode::{FetchFastTarget, Jump}; use crate::compiler::{Opcode, TypeCheckKind, TypeConversionKind}; use crate::lexer::{ArithmeticOperator, ComparisonOperator, LogicalOperator, Operator}; use crate::parser::{BuiltInFunction, Node}; -use crate::variable::Variable; +use rust_decimal::prelude::ToPrimitive; +use rust_decimal::Decimal; +use rust_decimal_macros::dec; +use std::sync::Arc; #[derive(Debug)] -pub struct Compiler<'arena> { - bytecode: Vec>, +pub struct Compiler { + bytecode: Vec, } -impl<'arena> Compiler<'arena> { +impl Compiler { pub fn new() -> Self { Self { bytecode: Default::default(), } } - pub fn compile(&mut self, root: &'arena Node<'arena>) -> CompilerResult<&[Opcode<'arena>]> { + pub fn compile(&mut self, root: &Node) -> CompilerResult<&[Opcode]> { self.bytecode.clear(); CompilerInner::new(&mut self.bytecode, root).compile()?; Ok(self.bytecode.as_slice()) } + + pub fn get_bytecode(&self) -> &[Opcode] { + self.bytecode.as_slice() + } } #[derive(Debug)] struct CompilerInner<'arena, 'bytecode_ref> { root: &'arena Node<'arena>, - bytecode: &'bytecode_ref mut Vec>, + bytecode: &'bytecode_ref mut Vec, } impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { - pub fn new( - bytecode: &'bytecode_ref mut Vec>, - root: &'arena Node<'arena>, - ) -> Self { + pub fn new(bytecode: &'bytecode_ref mut Vec, root: &'arena Node<'arena>) -> Self { Self { root, bytecode } } @@ -47,7 +48,7 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { Ok(()) } - fn emit(&mut self, op: Opcode<'arena>) -> usize { + fn emit(&mut self, op: Opcode) -> usize { self.bytecode.push(op); self.bytecode.len() } @@ -57,13 +58,16 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { F: FnOnce(&mut Self) -> CompilerResult<()>, { let begin = self.bytecode.len(); - let end = self.emit(Opcode::JumpIfEnd(0)); + let end = self.emit(Opcode::Jump(Jump::IfEnd, 0)); body(self)?; self.emit(Opcode::IncrementIt); - let e = self.emit(Opcode::JumpBackward(self.calc_backward_jump(begin))); - self.replace(end, Opcode::JumpIfEnd(e - end)); + let e = self.emit(Opcode::Jump( + Jump::Backward, + self.calc_backward_jump(begin) as u32, + )); + self.replace(end, Opcode::Jump(Jump::IfEnd, (e - end) as u32)); Ok(()) } @@ -71,18 +75,18 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { where F: FnMut(&mut Self), { - let noop = self.emit(Opcode::JumpIfFalse(0)); + let noop = self.emit(Opcode::Jump(Jump::IfFalse, 0)); self.emit(Opcode::Pop); body(self); - let jmp = self.emit(Opcode::Jump(0)); - self.replace(noop, Opcode::JumpIfFalse(jmp - noop)); + let jmp = self.emit(Opcode::Jump(Jump::Forward, 0)); + self.replace(noop, Opcode::Jump(Jump::IfFalse, (jmp - noop) as u32)); let e = self.emit(Opcode::Pop); - self.replace(jmp, Opcode::Jump(e - jmp)); + self.replace(jmp, Opcode::Jump(Jump::Forward, (e - jmp) as u32)); } - fn replace(&mut self, at: usize, op: Opcode<'arena>) { + fn replace(&mut self, at: usize, op: Opcode) { let _ = std::mem::replace(&mut self.bytecode[at - 1], op); } @@ -106,18 +110,47 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { self.compile_node(arg) } + fn compile_member_fast(&mut self, node: &'arena Node<'arena>) -> Option> { + match node { + Node::Root => Some(vec![FetchFastTarget::Root]), + Node::Identifier(v) => Some(vec![ + FetchFastTarget::Root, + FetchFastTarget::String(Arc::from(*v)), + ]), + Node::Member { node, property } => { + let mut path = self.compile_member_fast(node)?; + match property { + Node::String(v) => { + path.push(FetchFastTarget::String(Arc::from(*v))); + Some(path) + } + Node::Number(v) => { + if let Some(idx) = v.to_u32() { + path.push(FetchFastTarget::Number(idx)); + Some(path) + } else { + None + } + } + _ => None, + } + } + _ => None, + } + } + fn compile_node(&mut self, node: &'arena Node<'arena>) -> CompilerResult { match node { - Node::Null => Ok(self.emit(Opcode::Push(Variable::Null))), - Node::Bool(v) => Ok(self.emit(Opcode::Push(Variable::Bool(*v)))), - Node::Number(v) => Ok(self.emit(Opcode::Push(Variable::Number(*v)))), - Node::String(v) => Ok(self.emit(Opcode::Push(Variable::String(Rc::from(*v))))), + Node::Null => Ok(self.emit(Opcode::PushNull)), + Node::Bool(v) => Ok(self.emit(Opcode::PushBool(*v))), + Node::Number(v) => Ok(self.emit(Opcode::PushNumber(*v))), + Node::String(v) => Ok(self.emit(Opcode::PushString(Arc::from(*v)))), Node::Pointer => Ok(self.emit(Opcode::Pointer)), Node::Root => Ok(self.emit(Opcode::FetchRootEnv)), Node::Array(v) => { v.iter() .try_for_each(|&n| self.compile_node(n).map(|_| ()))?; - self.emit(Opcode::Push(Variable::Number(Decimal::from(v.len())))); + self.emit(Opcode::PushNumber(Decimal::from(v.len()))); Ok(self.emit(Opcode::Array)) } Node::Object(v) => { @@ -128,16 +161,23 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { Ok(()) })?; - self.emit(Opcode::Push(Variable::Number(Decimal::from(v.len())))); + self.emit(Opcode::PushNumber(Decimal::from(v.len()))); Ok(self.emit(Opcode::Object)) } - Node::Identifier(v) => Ok(self.emit(Opcode::FetchEnv(v))), + Node::Identifier(v) => Ok(self.emit(Opcode::FetchEnv(Arc::from(*v)))), Node::Closure(v) => self.compile_node(v), Node::Parenthesized(v) => self.compile_node(v), - Node::Member { node, property } => { - self.compile_node(node)?; - self.compile_node(property)?; - Ok(self.emit(Opcode::Fetch)) + Node::Member { + node: n, + property: p, + } => { + if let Some(path) = self.compile_member_fast(node) { + Ok(self.emit(Opcode::FetchFast(path))) + } else { + self.compile_node(n)?; + self.compile_node(p)?; + Ok(self.emit(Opcode::Fetch)) + } } Node::TemplateString(parts) => { parts.iter().try_for_each(|&n| { @@ -146,9 +186,9 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { Ok(()) })?; - self.emit(Opcode::Push(Variable::Number(Decimal::from(parts.len())))); + self.emit(Opcode::PushNumber(Decimal::from(parts.len()))); self.emit(Opcode::Array); - self.emit(Opcode::Push(Variable::String(Rc::from("")))); + self.emit(Opcode::PushString(Arc::from(""))); Ok(self.emit(Opcode::Join)) } Node::Slice { node, to, from } => { @@ -157,14 +197,14 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { self.compile_node(t)?; } else { self.emit(Opcode::Len); - self.emit(Opcode::Push(Variable::Number(dec!(1)))); + self.emit(Opcode::PushNumber(dec!(1))); self.emit(Opcode::Subtract); } if let Some(f) = from { self.compile_node(f)?; } else { - self.emit(Opcode::Push(Variable::Number(dec!(0)))); + self.emit(Opcode::PushNumber(dec!(0))); } Ok(self.emit(Opcode::Slice)) @@ -178,8 +218,8 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { self.compile_node(left)?; self.compile_node(right)?; Ok(self.emit(Opcode::Interval { - left_bracket, - right_bracket, + left_bracket: *left_bracket, + right_bracket: *right_bracket, })) } Node::Conditional { @@ -188,16 +228,19 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { on_false, } => { self.compile_node(condition)?; - let otherwise = self.emit(Opcode::JumpIfFalse(0)); + let otherwise = self.emit(Opcode::Jump(Jump::IfFalse, 0)); self.emit(Opcode::Pop); self.compile_node(on_true)?; - let end = self.emit(Opcode::Jump(0)); + let end = self.emit(Opcode::Jump(Jump::Forward, 0)); - self.replace(otherwise, Opcode::JumpIfFalse(end - otherwise)); + self.replace( + otherwise, + Opcode::Jump(Jump::IfFalse, (end - otherwise) as u32), + ); self.emit(Opcode::Pop); let b = self.compile_node(on_false)?; - self.replace(end, Opcode::Jump(b - end)); + self.replace(end, Opcode::Jump(Jump::Forward, (b - end) as u32)); Ok(b) } @@ -234,28 +277,28 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { } Operator::Logical(LogicalOperator::Or) => { self.compile_node(left)?; - let end = self.emit(Opcode::JumpIfTrue(0)); + let end = self.emit(Opcode::Jump(Jump::IfTrue, 0)); self.emit(Opcode::Pop); let r = self.compile_node(right)?; - self.replace(end, Opcode::JumpIfTrue(r - end)); + self.replace(end, Opcode::Jump(Jump::IfTrue, (r - end) as u32)); Ok(r) } Operator::Logical(LogicalOperator::And) => { self.compile_node(left)?; - let end = self.emit(Opcode::JumpIfFalse(0)); + let end = self.emit(Opcode::Jump(Jump::IfFalse, 0)); self.emit(Opcode::Pop); let r = self.compile_node(right)?; - self.replace(end, Opcode::JumpIfFalse(r - end)); + self.replace(end, Opcode::Jump(Jump::IfFalse, (r - end) as u32)); Ok(r) } Operator::Logical(LogicalOperator::NullishCoalescing) => { self.compile_node(left)?; - let end = self.emit(Opcode::JumpIfNotNull(0)); + let end = self.emit(Opcode::Jump(Jump::IfNotNull, 0)); self.emit(Opcode::Pop); let r = self.compile_node(right)?; - self.replace(end, Opcode::JumpIfNotNull(r - end)); + self.replace(end, Opcode::Jump(Jump::IfNotNull, (r - end) as u32)); Ok(r) } @@ -470,7 +513,11 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { BuiltInFunction::StartOf | BuiltInFunction::EndOf => { self.compile_argument(kind, arguments, 0)?; self.compile_argument(kind, arguments, 1)?; - Ok(self.emit(Opcode::DateFunction(kind.into()))) + Ok( + self.emit(Opcode::DateFunction(Arc::from(Into::<&'static str>::into( + kind, + )))), + ) } BuiltInFunction::DayOfWeek | BuiltInFunction::DayOfMonth @@ -482,7 +529,9 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { | BuiltInFunction::Year | BuiltInFunction::DateString => { self.compile_argument(kind, arguments, 0)?; - Ok(self.emit(Opcode::DateManipulation(kind.into()))) + Ok(self.emit(Opcode::DateManipulation(Arc::from( + Into::<&'static str>::into(kind), + )))) } BuiltInFunction::All => { self.compile_argument(kind, arguments, 0)?; @@ -490,12 +539,15 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { let mut loop_break: usize = 0; self.emit_loop(|c| { c.compile_argument(kind, arguments, 1)?; - loop_break = c.emit(Opcode::JumpIfFalse(0)); + loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0)); c.emit(Opcode::Pop); Ok(()) })?; - let e = self.emit(Opcode::Push(Variable::Bool(true))); - self.replace(loop_break, Opcode::JumpIfFalse(e - loop_break)); + let e = self.emit(Opcode::PushBool(true)); + self.replace( + loop_break, + Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32), + ); Ok(self.emit(Opcode::End)) } BuiltInFunction::None => { @@ -505,12 +557,15 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { self.emit_loop(|c| { c.compile_argument(kind, arguments, 1)?; c.emit(Opcode::Not); - loop_break = c.emit(Opcode::JumpIfFalse(0)); + loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0)); c.emit(Opcode::Pop); Ok(()) })?; - let e = self.emit(Opcode::Push(Variable::Bool(true))); - self.replace(loop_break, Opcode::JumpIfFalse(e - loop_break)); + let e = self.emit(Opcode::PushBool(true)); + self.replace( + loop_break, + Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32), + ); Ok(self.emit(Opcode::End)) } BuiltInFunction::Some => { @@ -519,12 +574,15 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { let mut loop_break: usize = 0; self.emit_loop(|c| { c.compile_argument(kind, arguments, 1)?; - loop_break = c.emit(Opcode::JumpIfTrue(0)); + loop_break = c.emit(Opcode::Jump(Jump::IfTrue, 0)); c.emit(Opcode::Pop); Ok(()) })?; - let e = self.emit(Opcode::Push(Variable::Bool(false))); - self.replace(loop_break, Opcode::JumpIfTrue(e - loop_break)); + let e = self.emit(Opcode::PushBool(false)); + self.replace( + loop_break, + Opcode::Jump(Jump::IfTrue, (e - loop_break) as u32), + ); Ok(self.emit(Opcode::End)) } BuiltInFunction::One => { @@ -538,7 +596,7 @@ impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> { Ok(()) })?; self.emit(Opcode::GetCount); - self.emit(Opcode::Push(Variable::Number(dec!(1)))); + self.emit(Opcode::PushNumber(dec!(1))); self.emit(Opcode::Equal); Ok(self.emit(Opcode::End)) } diff --git a/core/expression/src/compiler/mod.rs b/core/expression/src/compiler/mod.rs index 88d07316..5885afdd 100644 --- a/core/expression/src/compiler/mod.rs +++ b/core/expression/src/compiler/mod.rs @@ -7,4 +7,4 @@ mod opcode; pub use compiler::Compiler; pub use error::CompilerError; -pub use opcode::{Opcode, TypeCheckKind, TypeConversionKind}; +pub use opcode::{FetchFastTarget, Jump, Opcode, TypeCheckKind, TypeConversionKind}; diff --git a/core/expression/src/compiler/opcode.rs b/core/expression/src/compiler/opcode.rs index c2969dda..040e90b8 100644 --- a/core/expression/src/compiler/opcode.rs +++ b/core/expression/src/compiler/opcode.rs @@ -1,24 +1,32 @@ -use crate::variable::Variable; +use crate::lexer::Bracket; +use rust_decimal::Decimal; +use std::sync::Arc; use strum_macros::Display; +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum FetchFastTarget { + Root, + String(Arc), + Number(u32), +} + /// Machine code interpreted by VM -#[derive(Debug, PartialEq, Eq, Display)] -pub enum Opcode<'a> { - Push(Variable), +#[derive(Debug, PartialEq, Eq, Clone, Display)] +pub enum Opcode { + PushNull, + PushBool(bool), + PushString(Arc), + PushNumber(Decimal), Pop, Rot, Fetch, FetchRootEnv, - FetchEnv(&'a str), + FetchEnv(Arc), + FetchFast(Vec), Negate, Not, Equal, - Jump(usize), - JumpIfTrue(usize), - JumpIfFalse(usize), - JumpIfNotNull(usize), - JumpIfEnd(usize), - JumpBackward(usize), + Jump(Jump, u32), In, Less, More, @@ -42,14 +50,14 @@ pub enum Opcode<'a> { Modulo, Exponent, Interval { - left_bracket: &'a str, - right_bracket: &'a str, + left_bracket: Bracket, + right_bracket: Bracket, }, Contains, Keys, Values, - DateFunction(&'a str), - DateManipulation(&'a str), + DateFunction(Arc), + DateManipulation(Arc), Uppercase, Lowercase, StartsWith, @@ -80,6 +88,16 @@ pub enum Opcode<'a> { TypeCheck(TypeCheckKind), } +#[derive(Debug, PartialEq, Eq, Clone, Copy, Display)] +pub enum Jump { + Forward, + Backward, + IfTrue, + IfFalse, + IfNotNull, + IfEnd, +} + /// Metadata for TypeConversion Opcode #[derive(Debug, PartialEq, Eq, Clone, Copy, Display)] pub enum TypeConversionKind { diff --git a/core/expression/src/expression.rs b/core/expression/src/expression.rs new file mode 100644 index 00000000..28cad6b6 --- /dev/null +++ b/core/expression/src/expression.rs @@ -0,0 +1,88 @@ +use crate::compiler::Opcode; +use crate::vm::VM; +use crate::{IsolateError, Variable}; +use std::marker::PhantomData; +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct Standard; + +#[derive(Debug, Clone)] +pub struct Unary; + +#[derive(Debug, Clone)] +pub enum ExpressionKind { + Standard, + Unary, +} + +/// Compiled expression +#[derive(Debug, Clone)] +pub struct Expression { + bytecode: Arc>, + _marker: PhantomData, +} + +impl Expression { + pub fn bytecode(&self) -> &Arc> { + &self.bytecode + } +} + +impl Expression { + pub fn new_standard(bytecode: Arc>) -> Self { + Expression { + bytecode, + _marker: PhantomData, + } + } + + pub fn kind(&self) -> ExpressionKind { + ExpressionKind::Standard + } + + pub fn evaluate(&self, context: Variable) -> Result { + let mut vm = VM::new(); + self.evaluate_with(context, &mut vm) + } + + pub fn evaluate_with(&self, context: Variable, vm: &mut VM) -> Result { + let output = vm.run(self.bytecode.as_slice(), context)?; + Ok(output) + } +} + +impl Expression { + pub fn new_unary(bytecode: Arc>) -> Self { + Expression { + bytecode, + _marker: PhantomData, + } + } + + pub fn kind(&self) -> ExpressionKind { + ExpressionKind::Unary + } + + pub fn evaluate(&self, context: Variable) -> Result { + let mut vm = VM::new(); + self.evaluate_with(context, &mut vm) + } + + pub fn evaluate_with(&self, context: Variable, vm: &mut VM) -> Result { + let Some(context_object_ref) = context.as_object() else { + return Err(IsolateError::MissingContextReference); + }; + + let context_object = context_object_ref.borrow(); + if !context_object.contains_key("$") { + return Err(IsolateError::MissingContextReference); + } + + let output = vm + .run(self.bytecode.as_slice(), context)? + .as_bool() + .ok_or_else(|| IsolateError::ValueCastError)?; + Ok(output) + } +} diff --git a/core/expression/src/function.rs b/core/expression/src/function.rs index 7f90f43d..fe6e388f 100644 --- a/core/expression/src/function.rs +++ b/core/expression/src/function.rs @@ -1,5 +1,6 @@ +use crate::expression::{Standard, Unary}; use crate::variable::Variable; -use crate::{Isolate, IsolateError}; +use crate::{Expression, Isolate, IsolateError}; /// Evaluates a standard expression pub fn evaluate_expression(expression: &str, context: Variable) -> Result { @@ -23,13 +24,23 @@ pub fn evaluate_unary_expression( Isolate::with_environment(context).run_unary(expression) } +/// Compiles a standard expression +pub fn compile_expression(expression: &str) -> Result, IsolateError> { + Isolate::new().compile_standard(expression) +} + +/// Compiles an unary expression +pub fn compile_unary_expression(expression: &str) -> Result, IsolateError> { + Isolate::new().compile_unary(expression) +} + #[cfg(test)] mod test { use crate::evaluate_expression; use serde_json::json; #[test] - fn bla() { + fn example() { let context = json!({ "tax": { "percentage": 10 } }); let tax_amount = evaluate_expression("50 * tax.percentage / 100", context.into()).unwrap(); diff --git a/core/expression/src/isolate.rs b/core/expression/src/isolate.rs index 9f988ace..9980d21b 100644 --- a/core/expression/src/isolate.rs +++ b/core/expression/src/isolate.rs @@ -3,14 +3,17 @@ use serde::ser::SerializeMap; use serde::{Serialize, Serializer}; use std::collections::HashMap; use std::hash::BuildHasherDefault; +use std::sync::Arc; use thiserror::Error; use crate::arena::UnsafeArena; use crate::compiler::{Compiler, CompilerError}; +use crate::expression::{Standard, Unary}; use crate::lexer::{Lexer, LexerError}; use crate::parser::{Parser, ParserError}; use crate::variable::Variable; use crate::vm::{VMError, VM}; +use crate::{Expression, ExpressionKind}; type ADefHasher = BuildHasherDefault; @@ -22,7 +25,7 @@ type ADefHasher = BuildHasherDefault; #[derive(Debug)] pub struct Isolate<'arena> { lexer: Lexer<'arena>, - compiler: Compiler<'arena>, + compiler: Compiler, vm: VM, bump: UnsafeArena<'arena>, @@ -96,64 +99,60 @@ impl<'a> Isolate<'a> { self.references.clear(); } - pub fn run_standard(&mut self, source: &'a str) -> Result { + fn run_internal(&mut self, source: &'a str, kind: ExpressionKind) -> Result<(), IsolateError> { self.bump.with_mut(|b| b.reset()); let bump = self.bump.get(); - let tokens = self - .lexer - .tokenize(source) - .map_err(|source| IsolateError::LexerError { source })?; + let tokens = self.lexer.tokenize(source)?; + + let base_parser = Parser::try_new(tokens, bump)?; + let parser_result = match kind { + ExpressionKind::Unary => base_parser.unary().parse(), + ExpressionKind::Standard => base_parser.standard().parse(), + }; + + parser_result.error()?; - let parser = Parser::try_new(tokens, bump) - .map_err(|source| IsolateError::ParserError { source })? - .standard(); + self.compiler.compile(parser_result.root)?; - let parser_result = parser.parse(); - parser_result - .error() - .map_err(|source| IsolateError::ParserError { source })?; + Ok(()) + } - let bytecode = self - .compiler - .compile(parser_result.root) - .map_err(|source| IsolateError::CompilerError { source })?; + pub fn compile_standard( + &mut self, + source: &'a str, + ) -> Result, IsolateError> { + self.run_internal(source, ExpressionKind::Standard)?; + let bytecode = self.compiler.get_bytecode().to_vec(); + + Ok(Expression::new_standard(Arc::new(bytecode))) + } + pub fn run_standard(&mut self, source: &'a str) -> Result { + self.run_internal(source, ExpressionKind::Standard)?; + + let bytecode = self.compiler.get_bytecode(); let result = self .vm - .run(bytecode, self.environment.clone().unwrap_or(Variable::Null)) - .map_err(|source| IsolateError::VMError { source })?; + .run(bytecode, self.environment.clone().unwrap_or(Variable::Null))?; Ok(result) } - pub fn run_unary(&mut self, source: &'a str) -> Result { - self.bump.with_mut(|b| b.reset()); - let bump = self.bump.get(); - - let tokens = self - .lexer - .tokenize(source) - .map_err(|source| IsolateError::LexerError { source })?; - - let parser = Parser::try_new(tokens, bump) - .map_err(|source| IsolateError::ParserError { source })? - .unary(); + pub fn compile_unary(&mut self, source: &'a str) -> Result, IsolateError> { + self.run_internal(source, ExpressionKind::Unary)?; + let bytecode = self.compiler.get_bytecode().to_vec(); - let parser_result = parser.parse(); - parser_result - .error() - .map_err(|source| IsolateError::ParserError { source })?; + Ok(Expression::new_unary(Arc::new(bytecode))) + } - let bytecode = self - .compiler - .compile(parser_result.root) - .map_err(|source| IsolateError::CompilerError { source })?; + pub fn run_unary(&mut self, source: &'a str) -> Result { + self.run_internal(source, ExpressionKind::Unary)?; + let bytecode = self.compiler.get_bytecode(); let result = self .vm - .run(bytecode, self.environment.clone().unwrap_or(Variable::Null)) - .map_err(|source| IsolateError::VMError { source })?; + .run(bytecode, self.environment.clone().unwrap_or(Variable::Null))?; result.as_bool().ok_or_else(|| IsolateError::ValueCastError) } @@ -222,3 +221,27 @@ impl Serialize for IsolateError { map.end() } } + +impl From for IsolateError { + fn from(source: LexerError) -> Self { + IsolateError::LexerError { source } + } +} + +impl From for IsolateError { + fn from(source: ParserError) -> Self { + IsolateError::ParserError { source } + } +} + +impl From for IsolateError { + fn from(source: VMError) -> Self { + IsolateError::VMError { source } + } +} + +impl From for IsolateError { + fn from(source: CompilerError) -> Self { + IsolateError::CompilerError { source } + } +} diff --git a/core/expression/src/lexer/token.rs b/core/expression/src/lexer/token.rs index 1c02b636..eec29659 100644 --- a/core/expression/src/lexer/token.rs +++ b/core/expression/src/lexer/token.rs @@ -3,7 +3,7 @@ use std::hash::{Hash, Hasher}; use std::str::FromStr; use nohash_hasher::IsEnabled; -use strum_macros::{Display, EnumString, IntoStaticStr}; +use strum_macros::{Display, EnumIter, EnumString, FromRepr, IntoStaticStr}; /// Contains information from lexical analysis #[derive(Debug, PartialEq, Eq, Clone)] @@ -158,7 +158,7 @@ pub enum ComparisonOperator { NotIn, } -#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, IntoStaticStr)] +#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, IntoStaticStr, EnumIter, FromRepr)] pub enum Bracket { #[strum(serialize = "(")] LeftParenthesis, diff --git a/core/expression/src/lib.rs b/core/expression/src/lib.rs index d81cadb3..e61ef9d0 100644 --- a/core/expression/src/lib.rs +++ b/core/expression/src/lib.rs @@ -59,6 +59,7 @@ mod isolate; mod arena; pub mod compiler; +pub mod expression; mod function; pub mod intellisense; pub mod lexer; @@ -66,6 +67,9 @@ pub mod parser; pub mod variable; pub mod vm; -pub use function::{evaluate_expression, evaluate_unary_expression}; +pub use expression::{Expression, ExpressionKind}; +pub use function::{ + compile_expression, compile_unary_expression, evaluate_expression, evaluate_unary_expression, +}; pub use isolate::{Isolate, IsolateError}; pub use variable::Variable; diff --git a/core/expression/src/parser/ast.rs b/core/expression/src/parser/ast.rs index 1715ad20..53e42ee0 100644 --- a/core/expression/src/parser/ast.rs +++ b/core/expression/src/parser/ast.rs @@ -1,4 +1,4 @@ -use crate::lexer::Operator; +use crate::lexer::{Bracket, Operator}; use crate::parser::builtin::BuiltInFunction; use rust_decimal::Decimal; use std::cell::Cell; @@ -31,8 +31,8 @@ pub enum Node<'a> { Interval { left: &'a Node<'a>, right: &'a Node<'a>, - left_bracket: &'a str, - right_bracket: &'a str, + left_bracket: Bracket, + right_bracket: Bracket, }, Conditional { condition: &'a Node<'a>, diff --git a/core/expression/src/parser/parser.rs b/core/expression/src/parser/parser.rs index 696ae0a2..a6010038 100644 --- a/core/expression/src/parser/parser.rs +++ b/core/expression/src/parser/parser.rs @@ -666,9 +666,7 @@ impl<'arena, 'token_ref, Flavor> Parser<'arena, 'token_ref, Flavor> { }; let initial_position = self.position(); - let left_bracket = self.current()?.value; - - let TokenKind::Bracket(_) = &self.current()?.kind else { + let TokenKind::Bracket(left_bracket) = &self.current()?.kind else { self.set_position(initial_position); return None; }; @@ -691,8 +689,7 @@ impl<'arena, 'token_ref, Flavor> Parser<'arena, 'token_ref, Flavor> { return None; }; - let right_bracket = self.current()?.value; - let TokenKind::Bracket(_) = &self.current()?.kind else { + let TokenKind::Bracket(right_bracket) = &self.current()?.kind else { self.set_position(initial_position); return None; }; @@ -701,10 +698,10 @@ impl<'arena, 'token_ref, Flavor> Parser<'arena, 'token_ref, Flavor> { let interval_node = self.node( Node::Interval { - left_bracket, left, right, - right_bracket, + left_bracket: *left_bracket, + right_bracket: *right_bracket, }, |_| NodeMetadata { span: (initial_position as u32, self.position() as u32), diff --git a/core/expression/src/vm/variable.rs b/core/expression/src/vm/variable.rs index 0ca7f1c9..1be6ff22 100644 --- a/core/expression/src/vm/variable.rs +++ b/core/expression/src/vm/variable.rs @@ -1,10 +1,12 @@ +use crate::lexer::Bracket; use crate::variable::Variable; use ahash::{HashMap, HashMapExt}; -use std::rc::Rc; +use rust_decimal::prelude::ToPrimitive; +use rust_decimal::Decimal; pub(crate) struct IntervalObject { - pub(crate) left_bracket: Rc, - pub(crate) right_bracket: Rc, + pub(crate) left_bracket: Bracket, + pub(crate) right_bracket: Bracket, pub(crate) left: Variable, pub(crate) right: Variable, } @@ -19,11 +21,11 @@ impl IntervalObject { ); tree.insert( "left_bracket".to_string(), - Variable::String(self.left_bracket.clone()), + Variable::Number(Decimal::from(self.left_bracket as usize)), ); tree.insert( "right_bracket".to_string(), - Variable::String(self.right_bracket.clone()), + Variable::Number(Decimal::from(self.right_bracket as usize)), ); tree.insert("left".to_string(), self.left.clone()); tree.insert("right".to_string(), self.right.clone()); @@ -41,14 +43,14 @@ impl IntervalObject { return None; } - let left_bracket = tree_ref.get("left_bracket")?.as_rc_str()?; - let right_bracket = tree_ref.get("right_bracket")?.as_rc_str()?; + let left_bracket = tree_ref.get("left_bracket")?.as_number()?.to_usize()?; + let right_bracket = tree_ref.get("right_bracket")?.as_number()?.to_usize()?; let left = tree_ref.get("left")?.clone(); let right = tree_ref.get("right")?.clone(); Some(Self { - left_bracket, - right_bracket, + left_bracket: Bracket::from_repr(left_bracket)?, + right_bracket: Bracket::from_repr(right_bracket)?, right, left, }) diff --git a/core/expression/src/vm/vm.rs b/core/expression/src/vm/vm.rs index 803d3b33..4971d822 100644 --- a/core/expression/src/vm/vm.rs +++ b/core/expression/src/vm/vm.rs @@ -1,4 +1,5 @@ -use crate::compiler::{Opcode, TypeCheckKind, TypeConversionKind}; +use crate::compiler::{FetchFastTarget, Jump, Opcode, TypeCheckKind, TypeConversionKind}; +use crate::lexer::Bracket; use crate::variable::Variable; use crate::variable::Variable::*; use crate::vm::error::VMError::*; @@ -50,16 +51,16 @@ impl VM { } } -struct VMInner<'arena, 'parent_ref, 'bytecode_ref> { +struct VMInner<'parent_ref, 'bytecode_ref> { scopes: &'parent_ref mut Vec, stack: &'parent_ref mut Vec, - bytecode: &'bytecode_ref [Opcode<'arena>], - ip: usize, + bytecode: &'bytecode_ref [Opcode], + ip: u32, } -impl<'arena, 'parent_ref, 'bytecode_ref> VMInner<'arena, 'parent_ref, 'bytecode_ref> { +impl<'arena, 'parent_ref, 'bytecode_ref> VMInner<'parent_ref, 'bytecode_ref> { pub fn new( - bytecode: &'bytecode_ref [Opcode<'arena>], + bytecode: &'bytecode_ref [Opcode], stack: &'parent_ref mut Vec, scopes: &'parent_ref mut Vec, ) -> Self { @@ -86,21 +87,22 @@ impl<'arena, 'parent_ref, 'bytecode_ref> VMInner<'arena, 'parent_ref, 'bytecode_ self.ip = 0; } - while self.ip < self.bytecode.len() { + while self.ip < self.bytecode.len() as u32 { let op = self .bytecode - .get(self.ip) + .get(self.ip as usize) .ok_or_else(|| OpcodeOutOfBounds { bytecode: format!("{:?}", self.bytecode), - index: self.ip, + index: self.ip as usize, })?; self.ip += 1; match op { - Opcode::Push(v) => { - self.push(v.clone()); - } + Opcode::PushNull => self.push(Null), + Opcode::PushBool(b) => self.push(Bool(*b)), + Opcode::PushNumber(n) => self.push(Number(*n)), + Opcode::PushString(s) => self.push(String(Rc::from(s.as_ref()))), Opcode::Pop => { self.pop()?; } @@ -144,10 +146,31 @@ impl<'arena, 'parent_ref, 'bytecode_ref> VMInner<'arena, 'parent_ref, 'bytecode_ _ => self.push(Null), } } + Opcode::FetchFast(path) => { + let variable = path.iter().fold(Null, |v, p| match p { + FetchFastTarget::Root => env.clone(), + FetchFastTarget::String(key) => match v { + Object(obj) => { + let obj_ref = obj.borrow(); + obj_ref.get(key.as_ref()).cloned().unwrap_or(Null) + } + _ => Null, + }, + FetchFastTarget::Number(num) => match v { + Array(arr) => { + let arr_ref = arr.borrow(); + arr_ref.get(*num as usize).cloned().unwrap_or(Null) + } + _ => Null, + }, + }); + + self.push(variable); + } Opcode::FetchEnv(f) => match &env { Object(o) => { let obj = o.borrow(); - match obj.get(*f) { + match obj.get(f.as_ref()) { None => self.push(Null), Some(v) => self.push(v.clone()), } @@ -210,62 +233,72 @@ impl<'arena, 'parent_ref, 'bytecode_ref> VMInner<'arena, 'parent_ref, 'bytecode_ } } } - Opcode::Jump(j) => self.ip += j, - Opcode::JumpIfTrue(j) => { - let a = self.stack.last().ok_or_else(|| OpcodeErr { - opcode: "JumpIfTrue".into(), - message: "Undefined object key".into(), - })?; - match a { - Bool(a) => { - if *a { - self.ip += j; + Opcode::Jump(kind, j) => match kind { + Jump::Forward => self.ip += j, + Jump::Backward => self.ip -= j, + Jump::IfTrue => { + let a = self.stack.last().ok_or_else(|| OpcodeErr { + opcode: "JumpIfTrue".into(), + message: "Undefined object key".into(), + })?; + match a { + Bool(a) => { + if *a { + self.ip += j; + } + } + _ => { + return Err(OpcodeErr { + opcode: "JumpIfTrue".into(), + message: "Unsupported type".into(), + }); } } - _ => { - return Err(OpcodeErr { - opcode: "JumpIfTrue".into(), - message: "Unsupported type".into(), - }); + } + Jump::IfFalse => { + let a = self.stack.last().ok_or_else(|| OpcodeErr { + opcode: "JumpIfFalse".into(), + message: "Empty array".into(), + })?; + + match a { + Bool(a) => { + if !*a { + self.ip += j; + } + } + _ => { + return Err(OpcodeErr { + opcode: "JumpIfFalse".into(), + message: "Unsupported type".into(), + }); + } } } - } - Opcode::JumpIfFalse(j) => { - let a = self.stack.last().ok_or_else(|| OpcodeErr { - opcode: "JumpIfFalse".into(), - message: "Empty array".into(), - })?; + Jump::IfNotNull => { + let a = self.stack.last().ok_or_else(|| OpcodeErr { + opcode: "JumpIfNull".into(), + message: "Empty array".into(), + })?; - match a { - Bool(a) => { - if !*a { + match a { + Null => {} + _ => { self.ip += j; } } - _ => { - return Err(OpcodeErr { - opcode: "JumpIfFalse".into(), - message: "Unsupported type".into(), - }); - } } - } - Opcode::JumpIfNotNull(j) => { - let a = self.stack.last().ok_or_else(|| OpcodeErr { - opcode: "JumpIfNull".into(), - message: "Empty array".into(), - })?; + Jump::IfEnd => { + let scope = self.scopes.last().ok_or_else(|| OpcodeErr { + opcode: "JumpIfEnd".into(), + message: "Empty stack".into(), + })?; - match a { - Null => {} - _ => { + if scope.iter >= scope.len { self.ip += j; } } - } - Opcode::JumpBackward(j) => { - self.ip -= j; - } + }, Opcode::In => { let b = self.pop()?; let a = self.pop()?; @@ -291,35 +324,35 @@ impl<'arena, 'parent_ref, 'bytecode_ref> VMInner<'arena, 'parent_ref, 'bytecode_ (Number(l), Number(r)) => { let mut is_open = false; - let first = match interval.left_bracket.as_ref() { - "[" => l <= v, - "(" => l < v, - "]" => { + let first = match interval.left_bracket { + Bracket::LeftParenthesis => l < v, + Bracket::LeftSquareBracket => l <= v, + Bracket::RightParenthesis => { is_open = true; - l >= v + l > v } - ")" => { + Bracket::RightSquareBracket => { is_open = true; - l > v + l >= v } _ => { return Err(OpcodeErr { opcode: "In".into(), message: "Unsupported bracket".into(), - }); + }) } }; - let second = match interval.right_bracket.as_ref() { - "]" => r >= v, - ")" => r > v, - "[" => r <= v, - "(" => r < v, + let second = match interval.right_bracket { + Bracket::RightParenthesis => r > v, + Bracket::RightSquareBracket => r >= v, + Bracket::LeftParenthesis => r < v, + Bracket::LeftSquareBracket => r <= v, _ => { return Err(OpcodeErr { opcode: "In".into(), message: "Unsupported bracket".into(), - }); + }) } }; @@ -849,8 +882,8 @@ impl<'arena, 'parent_ref, 'bytecode_ref> VMInner<'arena, 'parent_ref, 'bytecode_ match (&a, &b) { (Number(_), Number(_)) => { let interval = IntervalObject { - left_bracket: Rc::from(*left_bracket), - right_bracket: Rc::from(*right_bracket), + left_bracket: *left_bracket, + right_bracket: *right_bracket, left: a, right: b, }; @@ -1176,7 +1209,7 @@ impl<'arena, 'parent_ref, 'bytecode_ref> VMInner<'arena, 'parent_ref, 'bytecode_ let timestamp = self.pop()?; let time: NaiveDateTime = (×tamp).try_into()?; - let var = match *operation { + let var = match operation.as_ref() { "year" => Number(time.year().into()), "dayOfWeek" => Number(time.weekday().number_from_monday().into()), "dayOfMonth" => Number(time.day().into()), @@ -1208,7 +1241,7 @@ impl<'arena, 'parent_ref, 'bytecode_ref> VMInner<'arena, 'parent_ref, 'bytecode_ }); }; - let s = match *name { + let s = match name.as_ref() { "startOf" => date_time_start_of(date_time, unit_name.as_ref().try_into()?), "endOf" => date_time_end_of(date_time, unit_name.as_ref().try_into()?), _ => { @@ -1509,16 +1542,6 @@ impl<'arena, 'parent_ref, 'bytecode_ref> VMInner<'arena, 'parent_ref, 'bytecode_ let var = self.pop()?; self.push(String(Rc::from(var.type_name()))); } - Opcode::JumpIfEnd(j) => { - let scope = self.scopes.last().ok_or_else(|| OpcodeErr { - opcode: "JumpIfEnd".into(), - message: "Empty stack".into(), - })?; - - if scope.iter >= scope.len { - self.ip += j; - } - } Opcode::IncrementIt => { let scope = self.scopes.last_mut().ok_or_else(|| OpcodeErr { opcode: "IncrementIt".into(), diff --git a/core/expression/tests/standard.rs b/core/expression/tests/standard.rs index 1e500f7d..f534d354 100644 --- a/core/expression/tests/standard.rs +++ b/core/expression/tests/standard.rs @@ -4,7 +4,7 @@ use rust_decimal::Decimal; use rust_decimal_macros::dec; use zen_expression::lexer::{ - ArithmeticOperator, ComparisonOperator, Lexer, LogicalOperator, Operator, + ArithmeticOperator, Bracket, ComparisonOperator, Lexer, LogicalOperator, Operator, }; use zen_expression::parser::{Node, Parser}; @@ -30,10 +30,10 @@ fn standard_test() { StandardTest { src: ")10..25(", result: &Node::Interval { - left_bracket: ")", + left_bracket: Bracket::RightParenthesis, + right_bracket: Bracket::LeftParenthesis, left: &Node::Number(D10), right: &Node::Number(D25), - right_bracket: "(", }, }, StandardTest { @@ -223,10 +223,10 @@ fn standard_test() { left: &Node::Identifier("x"), operator: Operator::Comparison(ComparisonOperator::NotIn), right: &Node::Interval { - left_bracket: "(", + left_bracket: Bracket::LeftParenthesis, left: &Node::Number(D1), right: &Node::Number(D9), - right_bracket: "]", + right_bracket: Bracket::RightSquareBracket, }, }, }, diff --git a/core/expression/tests/unary.rs b/core/expression/tests/unary.rs index e6721c69..b1b2828c 100644 --- a/core/expression/tests/unary.rs +++ b/core/expression/tests/unary.rs @@ -2,7 +2,7 @@ use bumpalo::Bump; use rust_decimal::Decimal; use rust_decimal_macros::dec; -use zen_expression::lexer::{ComparisonOperator, Lexer, LogicalOperator, Operator}; +use zen_expression::lexer::{Bracket, ComparisonOperator, Lexer, LogicalOperator, Operator}; use zen_expression::parser::{BuiltInFunction, Node, Parser}; struct UnaryTest { @@ -66,8 +66,8 @@ fn unary_test() { operator: Operator::Comparison(ComparisonOperator::In), left: &Node::Identifier("$"), right: &Node::Interval { - left_bracket: "[", - right_bracket: "]", + left_bracket: Bracket::LeftSquareBracket, + right_bracket: Bracket::RightSquareBracket, left: &Node::Number(D1), right: &Node::Number(D10), }, @@ -79,8 +79,8 @@ fn unary_test() { operator: Operator::Comparison(ComparisonOperator::In), left: &Node::Identifier("$"), right: &Node::Interval { - left_bracket: "[", - right_bracket: "]", + left_bracket: Bracket::LeftSquareBracket, + right_bracket: Bracket::RightSquareBracket, left: &Node::Number(D1), right: &Node::Number(D10), }, @@ -92,8 +92,8 @@ fn unary_test() { operator: Operator::Comparison(ComparisonOperator::NotIn), left: &Node::Identifier("$"), right: &Node::Interval { - left_bracket: "[", - right_bracket: "]", + left_bracket: Bracket::LeftSquareBracket, + right_bracket: Bracket::RightSquareBracket, left: &Node::Number(D1), right: &Node::Number(D10), },