Skip to content

Commit

Permalink
feat: Improve * flag behaviour when evaluation an expression.
Browse files Browse the repository at this point in the history
  • Loading branch information
wxgeo committed Nov 30, 2023
1 parent 7ce504d commit 9a88062
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 62 deletions.
6 changes: 2 additions & 4 deletions ptyx/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,7 @@ def make_files(
if options.generate_batch_for_windows_printing:
bat_file_name = ptyx_file.parent / ("print_corr.bat" if correction else "print.bat")
with open(bat_file_name, "w") as bat_file:
bat_file.write(
param["win_print_command"] + " ".join(f'"{f.name}.pdf"' for f in filenames) # type: ignore
)
bat_file.write(param["win_print_command"] + " ".join(f'"{f.name}.pdf"' for f in filenames))

# Copy pdf file/files to parent directory.
_copy_file_to_parent("pdf", filenames, ptyx_file, output_basename, options)
Expand Down Expand Up @@ -474,7 +472,7 @@ def compile_latex_to_pdf(

def _build_command(filename: Path, dest: Path, quiet: Optional[bool] = False) -> str:
"""Generate the command used to compile the LaTeX file."""
command: str = param["quiet_tex_command"] if quiet else param["tex_command"] # type: ignore
command: str = param["quiet_tex_command"] if quiet else param["tex_command"]
command += f' -output-directory "{dest}" "{filename}"'
return command

Expand Down
3 changes: 1 addition & 2 deletions ptyx/compilation_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from argparse import Namespace



@dataclass(kw_only=True, frozen=True)
class CompilationOptions:
filenames: list[str] = field(default_factory=list)
Expand All @@ -26,7 +25,7 @@ class CompilationOptions:
context: dict[str, Any] = field(default_factory=dict)

@classmethod
def load(cls, options: Namespace) -> "CompilationOptions": # type: ignore
def load(cls, options: Namespace) -> "CompilationOptions":
kwargs = vars(options)

# -------------------------
Expand Down
5 changes: 4 additions & 1 deletion ptyx/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from ptyx.internal_types import ParamDict

try:
from ptyx.custom_config import param as custom_param
except ImportError:
custom_param = {}


# <default_configuration>
param = {
param: ParamDict = {
"tex_command": "pdflatex -interaction=nonstopmode --shell-escape --enable-write18",
"quiet_tex_command": "pdflatex -interaction=batchmode --shell-escape --enable-write18",
"sympy_is_default": True,
Expand Down
51 changes: 51 additions & 0 deletions ptyx/internal_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import TypedDict


class ParamDict(TypedDict):
tex_command: str
quiet_tex_command: str
sympy_is_default: bool
import_paths: list[str]
debug: bool
floating_point: str
win_print_command: str


class CustomParamDict(TypedDict, total=False):
tex_command: str
quiet_tex_command: str
sympy_is_default: bool
import_paths: list[str]
debug: bool
floating_point: str
win_print_command: str


class NiceOp(Enum):
NONE = auto()
ADD = auto()
SUB = auto()
MUL = auto()
EQ = auto()


class PickItemAction(Enum):
NONE = auto()
SELECT_FROM_NUM = auto()
RAND_CHOICE = auto()


@dataclass(kw_only=True)
class EvalFlags:
is_mul_coeff: bool = False
eval_as_float: bool = False
keep_dot_as_decimal_mark: bool = False
format_as_str: bool = False
round: int | None = None
previous_nice_op: NiceOp = NiceOp.NONE
round_result: bool = False
result_is_exact: bool | None = None
pick_action: PickItemAction = PickItemAction.NONE
suppress_next_eval: bool = False
116 changes: 67 additions & 49 deletions ptyx/latex_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from ptyx import __version__
from ptyx.config import param
from ptyx.context import GLOBAL_CONTEXT
from ptyx.internal_types import NiceOp, PickItemAction, EvalFlags
from ptyx.sys_info import SYMPY_AVAILABLE

# from ptyx.printers import sympy2latex
from ptyx.syntax_tree import Node, SyntaxTreeGenerator, Tag, TagSyntax
from ptyx.utilities import advanced_split, numbers_to_floats, _float_me_if_you_can, latex_verbatim

Expand Down Expand Up @@ -71,7 +71,7 @@ class LatexGenerator:

# noinspection RegExpRedundantEscape
re_varname = re.compile(r"[A-Za-z_]\w*(\[.+\])?$")
flags: dict[str, bool | int]
flags: EvalFlags
macros: dict[str, list[str | Node]]
# TODO: use a TypedDict for context. (Extensions should use TypedDict inheriting from this one.)
# Since typing all of global context isn't realistic (in includes sympy!), one should
Expand Down Expand Up @@ -109,7 +109,7 @@ def clear(self):
# Now, one should use write(..., parse=True) if needed, which is much saner.
self.context["write"] = self.write
# Internal flags
self.flags = {}
self.flags = EvalFlags(eval_as_float=self.context.get("ALL_FLOATS"))

def reset(self):
"""To overwrite."""
Expand Down Expand Up @@ -262,10 +262,17 @@ def write(self, text: str, parse: bool = False, verbatim: bool = False) -> None:
# so some basic testing is done before.
text = str(text)
if text.strip():
for flag in "+-*":
if self.flags.get(flag):
text = (r"\times " if flag == "*" else flag) + text
self.flags[flag] = False
match (flags := self.flags).previous_nice_op:
case NiceOp.ADD:
text = "+" + text
flags.previous_nice_op = NiceOp.NONE
case NiceOp.SUB:
text = "-" + text
flags.previous_nice_op = NiceOp.NONE
case NiceOp.MUL:
text = r"\times " + text
flags.previous_nice_op = NiceOp.NONE

if parse and "#" in text:
if param["debug"]:
print("Parsing %s..." % repr(text))
Expand Down Expand Up @@ -401,23 +408,27 @@ def _parse_ASSERT_tag(self, node: Node):
assert test

def _parse_EVAL_tag(self, node: Node):
if self.flags.suppress_next_eval:
self.flags = EvalFlags(eval_as_float=self.context.get("ALL_FLOATS")) # type: ignore
return
args, kw = self._parse_options(node)
if self.context.get("ALL_FLOATS"):
self.flags["floats"] = True
for arg in args:
if arg.isdigit():
self.flags["round"] = int(arg)
self.flags.round = int(arg)
elif arg == ".":
self.flags["."] = True
elif arg == "?":
self.flags["?"] = True
self.flags.keep_dot_as_decimal_mark = True
elif arg == "*":
self.flags.is_mul_coeff = True
elif arg in ("floats", "float"):
self.flags["floats"] = True
self.flags.eval_as_float = True
elif arg == "str":
self.flags["str"] = True
self.flags.format_as_str = True
elif arg == "select":
self.flags.pick_action = PickItemAction.SELECT_FROM_NUM
elif arg == "rand":
self.flags.pick_action = PickItemAction.RAND_CHOICE
else:
raise ValueError("Unknown flag: " + repr(arg))
# XXX: support options round, float, (sympy, python,) select and rand
code = node.arg(0)
assert isinstance(code, str), type(code)
try:
Expand All @@ -428,7 +439,10 @@ def _parse_EVAL_tag(self, node: Node):
# Tags must be cleared *before* calling .write(txt), since .write(txt)
# add '+', '-' and '\times ' before txt if corresponding flags are set,
# and ._eval_and_format_python_expr() has already do this.
self.flags.clear()
self.flags = EvalFlags(
eval_as_float=self.context.get("ALL_FLOATS"), # type: ignore
suppress_next_eval=self.flags.suppress_next_eval,
)
self.write(txt)

def _parse_MACRO_tag(self, node: Node):
Expand Down Expand Up @@ -586,23 +600,23 @@ def _parse_ADD_tag(self, node: Node) -> None:
# a '+' will be displayed at the beginning of the next result if positive ;
# if the result is negative, nothing will be done, and if null,
# no result at all will be displayed.
self.flags["+"] = True
self.flags.previous_nice_op = NiceOp.ADD

def _parse_SUB_tag(self, node: Node) -> None:
# a '-' will be displayed at the beginning of the next result, and the result
# will be embedded in parentheses if negative.
self.flags["-"] = True
self.flags.previous_nice_op = NiceOp.SUB

def _parse_MUL_tag(self, node: Node) -> None:
# a '\times' will be displayed at the beginning of the next result, and the result
# will be embedded in parentheses if negative.
self.flags["*"] = True
self.flags.previous_nice_op = NiceOp.MUL

def _parse_EQUAL_tag(self, node: Node) -> None:
# Display '=' or '\approx' when a rounded result is requested :
# if rounded is equal to exact one, '=' is displayed.
# Else, '\approx' is displayed instead.
self.flags["="] = True
self.flags.previous_nice_op = NiceOp.EQ
# All other operations (#+, #-, #*) occur just before number, but between `=` and
# the result, some formatting instructions may occur (like '\fbox{' for example).
# So, `#=` is used as a temporary marker, and will be replaced by '=' or '\approx' later.
Expand Down Expand Up @@ -680,14 +694,9 @@ def _exec_python_code(self, code: str, context: dict):
return code

def _eval_python_expr(self, code: str):
flags = self.flags
context = self.context
if not code:
return ""
sympy_code = flags.get("sympy", param["sympy_is_default"])

if sympy_code and not SYMPY_AVAILABLE:
raise ImportError("sympy library not found.")

varname = ""
i = code.find("=")
Expand All @@ -704,7 +713,7 @@ def _eval_python_expr(self, code: str):
varname = "_"
if " if " in code and " else " not in code:
code += " else ''"
if SYMPY_AVAILABLE and sympy_code:
if param["sympy_is_default"]:
import sympy

try:
Expand Down Expand Up @@ -739,7 +748,6 @@ def _eval_and_format_python_expr(self, code: str) -> str:
context = self.context
if not code:
return ""
sympy_code = flags.get("sympy", param["sympy_is_default"])

display_result = True
if code.endswith(";"):
Expand All @@ -757,39 +765,42 @@ def _eval_and_format_python_expr(self, code: str) -> str:
if not display_result:
return ""

if flags.get("?"):
if flags.is_mul_coeff:
if result == 1:
if flags.get("+"):
if flags.previous_nice_op == NiceOp.ADD:
return "+"
return ""
elif result == -1:
result = "-"
return "-"
elif result == 0:
flags.suppress_next_eval = True
return ""

if sympy_code and not flags.get("str"):
if param["sympy_is_default"] and not flags.format_as_str:
from ptyx.printers import sympy2latex

latex = sympy2latex(result, **flags)
latex = sympy2latex(result, flags)
else:
latex = str(result)

def neg(latex_):
return latex_.lstrip().startswith("-")

if flags.get("+"):
if flags.previous_nice_op == NiceOp.ADD:
if result == 0:
latex = ""
elif not neg(latex):
latex = "+" + latex
elif flags.get("*"):
elif flags.previous_nice_op == NiceOp.MUL:
if neg(latex) or getattr(result, "is_Add", False):
latex = r"\left(" + latex + r"\right)"
latex = r"\times " + latex
elif flags.get("-"):
elif flags.previous_nice_op == NiceOp.SUB:
if neg(latex) or getattr(result, "is_Add", False):
latex = r"\left(" + latex + r"\right)"
latex = "-" + latex
elif flags.get("="):
if flags.get("result_is_exact"):
elif flags.previous_nice_op == NiceOp.EQ:
if flags.result_is_exact:
symb = " = "
else:
symb = r" \approx "
Expand All @@ -813,18 +824,25 @@ def _apply_flag(self, result):
If result is iterable, an element of result is returned, chosen according
to current flag."""
flags = self.flags
if hasattr(result, "__iter__"):
if flags.get("rand"):
result = random.choice(result)
elif flags.get("select"):
result = result[self.NUM % len(result)]
if flags.pick_action != PickItemAction.NONE:
try:
if flags.pick_action == PickItemAction.RAND_CHOICE:
result = randfunc.randchoice(result)
elif flags.pick_action == PickItemAction.SELECT_FROM_NUM:
result = result[self.NUM % len(result)]
except (TypeError, IndexError) as e:
traceback.print_exception(e, limit=2)
print(
"Warning: flags `rand` or `select` have been used for an incorrect data type"
" (more details above)."
)

if "round" in flags:
if flags.round is not None:
try:
if SYMPY_AVAILABLE:
round_result = numbers_to_floats(result, ndigits=flags["round"])
round_result = numbers_to_floats(result, ndigits=flags.round)
else:
round_result = round(result, flags["round"])
round_result = round(result, flags.round)
except ValueError:
print("** ERROR while rounding value: **")
print(result)
Expand All @@ -836,13 +854,13 @@ def _apply_flag(self, result):
import sympy
# noinspection PyUnboundLocalVariable
if SYMPY_AVAILABLE and isinstance(result, sympy.Basic):
flags["result_is_exact"] = {_float_me_if_you_can(elt) for elt in result.atoms()} == {
flags.result_is_exact = {_float_me_if_you_can(elt) for elt in result.atoms()} == {
_float_me_if_you_can(elt) for elt in round_result.atoms()
}
else:
flags["result_is_exact"] = result == round_result
flags.result_is_exact = result == round_result
result = round_result
elif "floats" in self.flags:
elif self.flags.eval_as_float:
result = numbers_to_floats(result)
return result

Expand Down
Loading

0 comments on commit 9a88062

Please sign in to comment.