Ruff SIM #286

Merged · 2 commits · Jan 23, 2025
4 changes: 3 additions & 1 deletion doc/conf.py
@@ -17,8 +17,10 @@
#
# The short X.Y version.
ver_dic = {}
exec(compile(open("../pytools/version.py").read(), "../pytools/version.py", "exec"),
with open("../pytools/version.py") as vfile:
exec(compile(vfile.read(), "../pytools/version.py", "exec"),
ver_dic)

version = ".".join(str(x) for x in ver_dic["VERSION"])
release = ver_dic["VERSION_TEXT"]

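The conf.py hunk is the "file opened but never closed" fix: the old `exec(compile(open(...).read(), ...))` left closing the handle to garbage collection (CPython typically emits a ResourceWarning when an unclosed file is collected), while the `with` block closes it deterministically, even if `exec` raises. A self-contained sketch of the same pattern, using a throwaway temporary file instead of the repository's `../pytools/version.py`:

```python
# Sketch of the context-managed exec pattern from the hunk above.
# Paths and contents here are invented for the illustration only.
import os
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    version_path = os.path.join(tmpdir, "version.py")
    with open(version_path, "w") as f:
        f.write('VERSION = (2025, 1)\nVERSION_TEXT = "2025.1"\n')

    ver_dic = {}
    # The handle is closed as soon as the block exits, even if exec() raises.
    with open(version_path) as vfile:
        exec(compile(vfile.read(), version_path, "exec"), ver_dic)

    print(ver_dic["VERSION_TEXT"])  # -> 2025.1
```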
1 change: 1 addition & 0 deletions pyproject.toml
@@ -77,6 +77,7 @@ extend-select = [
"RUF", # ruff
"W", # pycodestyle
"TC",
"SIM",
]
extend-ignore = [
"C90", # McCabe complexity
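The new `"SIM"` entry turns on Ruff's flake8-simplify rule family; the rest of this PR is the corresponding cleanup: try/except/pass becomes `contextlib.suppress`, hand-rolled loops that return True/False become `all(...)`, if/else assignments become conditional expressions, bare `open()` calls gain a `with` block, and `dict.keys()` iteration or membership drops the `.keys()`. A rough sketch of the three most common shapes, on made-up toy functions (none of this code is from pytools):

```python
# Toy before/after pairs for the kinds of rewrites SIM suggests
# (names invented for illustration, not from the repository).
import contextlib


# try/except/pass  ->  contextlib.suppress (same behavior, less nesting)
def set_default_before(d, key, obj):
    try:
        d[key] = getattr(obj, key)
    except AttributeError:
        pass


def set_default_after(d, key, obj):
    with contextlib.suppress(AttributeError):
        d[key] = getattr(obj, key)


# explicit loop returning True/False  ->  all() over a generator
def is_sorted_before(seq):
    for a, b in zip(seq, seq[1:]):
        if a > b:
            return False
    return True


def is_sorted_after(seq):
    return all(a <= b for a, b in zip(seq, seq[1:]))


# if/else assignment  ->  conditional expression
def parity_after(n):
    return "even" if n % 2 == 0 else "odd"
```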
131 changes: 16 additions & 115 deletions pytools/__init__.py
@@ -28,6 +28,7 @@
"""

import builtins
import contextlib
import logging
import operator
import re
@@ -421,10 +422,8 @@ def __init__(self,
def get_copy_kwargs(self, **kwargs):
for f in self.__class__.fields:
if f not in kwargs:
try:
with contextlib.suppress(AttributeError):
kwargs[f] = getattr(self, f)
except AttributeError:
pass
return kwargs

def copy(self, **kwargs):
@@ -615,10 +614,7 @@ def is_single_valued(
except StopIteration:
raise ValueError("empty iterable passed to 'single_valued()'") from None

for other_item in it:
if not equality_pred(other_item, first_item):
return False
return True
return all(equality_pred(other_item, first_item) for other_item in it)


all_equal = is_single_valued
@@ -642,12 +638,7 @@ def single_valued(
except StopIteration:
raise ValueError("empty iterable passed to 'single_valued()'") from None

def others_same():
for other_item in it:
if not equality_pred(other_item, first_item):
return False
return True
assert others_same()
assert all(equality_pred(other_item, first_item) for other_item in it)

return first_item
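Both of these rewrites keep the early-exit behavior of the loops they replace: `all(...)` over a generator stops consuming the iterator at the first element that fails the predicate, so `is_single_valued` still returns False as soon as a mismatch is seen. In `single_valued` the check now sits directly in an `assert`, so, as with the old `others_same()` helper that was only ever called inside an `assert`, it is skipped entirely under `python -O`. A tiny illustration of the short-circuiting (`noisy_equal` is a made-up helper for the demonstration):

```python
# Shows that all() over a generator stops at the first failing element,
# just like a loop with an early `return False`.
def noisy_equal(a, b):
    print(f"comparing {a!r} and {b!r}")
    return a == b


it = iter([1, 1, 7, 1, 1])
first_item = next(it)
print(all(noisy_equal(other, first_item) for other in it))
# comparing 1 and 1
# comparing 7 and 1
# False -- the remaining items are never compared
```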

@@ -754,10 +745,7 @@ def memoize_on_first_arg(
)

def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R:
if kwargs:
key = (_HasKwargs, frozenset(kwargs.items()), *args)
else:
key = args
key = (_HasKwargs, frozenset(kwargs.items()), *args) if kwargs else args

assert cache_dict_name is not None
try:
@@ -1925,89 +1913,6 @@ def word_wrap(text, width, wrap_using="\n"):
# }}}


# {{{ command line interfaces

def _exec_arg(arg, execenv):
import os
if os.access(arg, os.F_OK):
exec(compile(open(arg), arg, "exec"), execenv)
else:
exec(compile(arg, "<command line>", "exec"), execenv)


class CPyUserInterface:
class Parameters(Record):
pass

def __init__(self, variables, constants=None, doc=None):
if constants is None:
constants = {}
if doc is None:
doc = {}
self.variables = variables
self.constants = constants
self.doc = doc

def show_usage(self, progname):
print(f"usage: {progname} <FILE-OR-STATEMENTS>")
print()
print("FILE-OR-STATEMENTS may either be Python statements of the form")
print("'variable1 = value1; variable2 = value2' or the name of a file")
print("containing such statements. Any valid Python code may be used")
print("on the command line or in a command file. If new variables are")
print("used, they must start with 'user_' or just '_'.")
print()
print("The following variables are recognized:")
for v in sorted(self.variables):
print(f" {v} = {self.variables[v]}")
if v in self.doc:
print(f" {self.doc[v]}")

print()
print("The following constants are supplied:")
for c in sorted(self.constants):
print(f" {c} = {self.constants[c]}")
if c in self.doc:
print(f" {self.doc[c]}")

def gather(self, argv=None):
if argv is None:
argv = sys.argv

if len(argv) == 1 or (
("-h" in argv)
or ("help" in argv)
or ("-help" in argv)
or ("--help" in argv)):
self.show_usage(argv[0])
sys.exit(2)

execenv = self.variables.copy()
execenv.update(self.constants)

for arg in argv[1:]:
_exec_arg(arg, execenv)

# check if the user set invalid keys
for added_key in (
set(execenv.keys())
- set(self.variables.keys())
- set(self.constants.keys())):
if not (added_key.startswith("user_") or added_key.startswith("_")):
raise ValueError(
f"invalid setup key: '{added_key}' "
"(user variables must start with 'user_' or '_')")

result = self.Parameters({key: execenv[key] for key in self.variables})
self.validate(result)
return result

def validate(self, setup):
pass

# }}}


# {{{ debugging

class StderrToStdout:
@@ -2093,9 +1998,8 @@ def invoke_editor(s, filename="edit.txt", descr="the file"):
from os.path import join
full_name = join(tempdir, filename)

outf = open(full_name, "w")
outf.write(str(s))
outf.close()
with open(full_name, "w") as outf:
outf.write(str(s))

import os
if "EDITOR" in os.environ:
@@ -2107,9 +2011,8 @@ def invoke_editor(s, filename="edit.txt", descr="the file"):
"dropped directly into an editor next time.)")
input(f"Edit {descr} at {full_name} now, then hit [Enter]:")

inf = open(full_name)
result = inf.read()
inf.close()
with open(full_name) as inf:
result = inf.read()

return result

@@ -2634,16 +2537,14 @@ def __init__(
use_late_start_logging = False

if use_late_start_logging:
try:
# https://github.com/firedrakeproject/firedrake/issues/1422
#
# Starting a thread may fail in various environments, e.g. MPI.
# Since the late-start logging is an optional 'quality-of-life'
# feature for interactive use, tolerate failures of it without
# warning.
with contextlib.suppress(RuntimeError):
self.late_start_log_thread.start()
except RuntimeError:
# https://github.com/firedrakeproject/firedrake/issues/1422
#
# Starting a thread may fail in various environments, e.g. MPI.
# Since the late-start logging is an optional 'quality-of-life'
# feature for interactive use, tolerate failures of it without
# warning.
pass

self.timer = ProcessTimer()

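Most of the pytools/__init__.py churn is mechanical, but the `contextlib.suppress` changes deserve a note: `suppress(E)` is the stdlib's drop-in replacement for `try: ... except E: pass`, and for single-statement bodies like the two changed here (`get_copy_kwargs` and the late-start logging thread) the behavior is identical. The only thing to keep in mind is that, as with the try form, an exception raised partway through the block skips the rest of the block. A short, self-contained sketch (the class and names are invented for the example):

```python
import contextlib


class Thing:
    pass


obj = Thing()
found = {}

# Same effect as try/except AttributeError/pass around the one statement.
with contextlib.suppress(AttributeError):
    found["missing"] = obj.missing  # raises; the assignment never happens

print(found)  # {}

# Caveat: anything after the raising statement inside the block is skipped.
with contextlib.suppress(KeyError):
    value = {}["nope"]
    print("never reached")
```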
23 changes: 7 additions & 16 deletions pytools/batchjob.py
@@ -5,15 +5,8 @@ def _cp(src, dest):
from pytools import assert_not_a_file
assert_not_a_file(dest)

inf = open(src, "rb")
try:
outf = open(dest, "wb")
try:
outf.write(inf.read())
finally:
outf.close()
finally:
inf.close()
with open(src, "rb") as inf, open(dest, "wb") as outf:
outf.write(inf.read())


def get_timestamp():
@@ -43,10 +36,9 @@ def __init__(self, moniker, main_file, aux_files=(), timestamp=None):

os.makedirs(self.path)

runscript = open(f"{self.path}/run.sh", "w")
import sys
runscript.write(f"{sys.executable} {main_file} setup.cpy")
runscript.close()
with open(f"{self.path}/run.sh", "w") as runscript:
import sys
runscript.write(f"{sys.executable} {main_file} setup.cpy")

from os.path import basename

@@ -58,9 +50,8 @@ def __init__(self, moniker, main_file, aux_files=(), timestamp=None):

def write_setup(self, lines):
import os.path
setup = open(os.path.join(self.path, "setup.cpy"), "w")
setup.write("\n".join(lines))
setup.close()
with open(os.path.join(self.path, "setup.cpy"), "w") as setup:
setup.write("\n".join(lines))


class INHERIT:
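The `_cp` rewrite folds the two nested try/finally blocks into a single `with` statement holding both files; the same pattern is applied to the `run.sh` and `setup.cpy` writes. Comma-separated context managers behave exactly like nesting: they are entered left to right, exited in reverse order, and both files are closed even if the write raises. A self-contained sketch of the copy pattern, using temporary paths so it does not touch any real files:

```python
# Sketch of the two-manager `with` used in _cp above; the paths are
# throwaway temporary files created just for the illustration.
import os
import tempfile

tmpdir = tempfile.mkdtemp()
src = os.path.join(tmpdir, "src.bin")
dest = os.path.join(tmpdir, "dest.bin")

with open(src, "wb") as f:
    f.write(b"payload")

# Entered left to right, exited in reverse, exactly like two nested blocks.
with open(src, "rb") as inf, open(dest, "wb") as outf:
    outf.write(inf.read())

with open(dest, "rb") as f:
    print(f.read())  # b'payload'
```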
18 changes: 9 additions & 9 deletions pytools/convergence.py
@@ -145,15 +145,15 @@ def __str__(self):
return self.pretty_print()

def write_gnuplot_file(self, filename: str) -> None:
outfile = open(filename, "w")
for absc, err in self.history:
outfile.write(f"{absc:f} {err:f}\n")
result = self.estimate_order_of_convergence()
const = result[0, 0]
order = result[0, 1]
outfile.write("\n")
for absc, _err in self.history:
outfile.write(f"{absc:f} {const * absc**(-order):f}\n")
with open(filename, "w") as outfile:
for absc, err in self.history:
outfile.write(f"{absc:f} {err:f}\n")
result = self.estimate_order_of_convergence()
const = result[0, 0]
order = result[0, 1]
outfile.write("\n")
for absc, _err in self.history:
outfile.write(f"{absc:f} {const * absc**(-order):f}\n")


def stringify_eocs(*eocs: EOCRecorder,
17 changes: 5 additions & 12 deletions pytools/debug.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
import sys

from pytools import memoize
@@ -66,17 +67,12 @@ def is_excluded(o):
return True

from sys import _getframe
if isinstance(o, FrameType) and \
o.f_code.co_filename == _getframe().f_code.co_filename:
return True

return False
return bool(isinstance(o, FrameType)
and o.f_code.co_filename == _getframe().f_code.co_filename)

if top_level:
try:
with contextlib.suppress(RefDebugQuit):
refdebug(obj, top_level=False, exclude=exclude)
except RefDebugQuit:
pass
return

import gc
@@ -94,10 +90,7 @@ def is_excluded(o):
print_head = False
r = reflist[idx]

if isinstance(r, FrameType):
s = str(r.f_code)
else:
s = str(r)
s = str(r.f_code) if isinstance(r, FrameType) else str(r)

print(f"{idx}/{len(reflist)}: ", id(r), type(r), s)

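In the `is_excluded` rewrite, the explicit `bool(...)` is worth a word: `and`/`or` evaluate to one of their operands rather than forcing a `bool`, so wrapping the expression pins the return type (in this particular expression both operands are already booleans, so the wrapper only makes the contract explicit). The surrounding `refdebug` change is the same `contextlib.suppress` pattern as in __init__.py. A quick illustration (`is_short_word` is a made-up example, not repository code):

```python
# `and`/`or` evaluate to one of their operands, not necessarily a bool.
print(0 and "x")          # 0
print("a" and "b")        # b
print(bool("a" and "b"))  # True -- bool() pins the type

# The SIM-style shape: return the condition instead of if/return True/return False.
def is_short_word(s):
    return bool(isinstance(s, str) and len(s) < 5)


print(is_short_word("hi"))       # True
print(is_short_word("lengthy"))  # False
print(is_short_word(42))         # False
```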
6 changes: 3 additions & 3 deletions pytools/graph.py
@@ -368,9 +368,9 @@ def compute_transitive_closure(
closure = deepcopy(graph)

# (assumes all graph nodes are included in keys)
for k in graph.keys():
for n1 in graph.keys():
for n2 in graph.keys():
for k in graph:
for n1 in graph:
for n2 in graph:
if k in closure[n1] and n2 in closure[k]:
closure[n1].add(n2)

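The graph.py change is purely about dict iteration: looping over a dict (or using `in` on it) already goes over its keys, so `for k in graph` is equivalent to `for k in graph.keys()` without building the keys view. The Warshall-style triple loop computing the closure is otherwise untouched. A tiny sketch on a toy graph (values invented for illustration):

```python
# Iterating a dict, or testing membership on it, uses its keys directly.
graph = {"a": {"b"}, "b": {"c"}, "c": set()}

for node in graph:          # same nodes as `for node in graph.keys()`
    print(node, "->", sorted(graph[node]))

print("a" in graph)         # True -- key membership without .keys()
```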
5 changes: 1 addition & 4 deletions pytools/obj_array.py
@@ -369,10 +369,7 @@ def with_object_array_or_scalar(f, field, obj_array_only=False):
"use obj_array_vectorize", DeprecationWarning, stacklevel=2)

if obj_array_only:
if is_obj_array(field):
ls = field.shape
else:
ls = ()
ls = field.shape if is_obj_array(field) else ()
else:
ls = log_shape(field)
if ls != ():
10 changes: 2 additions & 8 deletions pytools/spatial_btree.py
@@ -7,10 +7,7 @@ def do_boxes_intersect(bl, tr):
(bl1, tr1) = bl
(bl2, tr2) = tr
(dimension,) = bl1.shape
for i in range(dimension):
if max(bl1[i], bl2[i]) > min(tr1[i], tr2[i]):
return False
return True
return all(max(bl1[i], bl2[i]) <= min(tr1[i], tr2[i]) for i in range(dimension))


def make_buckets(bottom_left, top_right, allbuckets, max_elements_per_box):
@@ -131,10 +128,7 @@ def generate_matches(self, point):
(dimensions,) = point.shape
bucket = self.buckets
for dim in range(dimensions):
if point[dim] < self.center[dim]:
bucket = bucket[0]
else:
bucket = bucket[1]
bucket = bucket[0] if point[dim] < self.center[dim] else bucket[1]

yield from bucket.generate_matches(point)

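In `do_boxes_intersect` the condition is negated relative to the loop it replaces: the loop bailed out with False once `max(bl1[i], bl2[i]) > min(tr1[i], tr2[i])` in some dimension, while the `all(...)` form demands `max(...) <= min(...)` in every dimension, which is the same predicate after De Morgan. The `generate_matches` change is an ordinary conditional expression. A standalone check of the intersection logic on plain tuples (the function in the diff operates on NumPy arrays; this sketch is only for illustration):

```python
# Standalone version of the axis-aligned box intersection test
# (plain sequences instead of the NumPy arrays used in the diff).
def boxes_intersect(bl1, tr1, bl2, tr2):
    return all(max(lo1, lo2) <= min(hi1, hi2)
               for lo1, hi1, lo2, hi2 in zip(bl1, tr1, bl2, tr2))


print(boxes_intersect((0, 0), (2, 2), (1, 1), (3, 3)))  # True  (overlapping)
print(boxes_intersect((0, 0), (1, 1), (2, 2), (3, 3)))  # False (disjoint)
```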