Configure pyink formatter in pre-commit and apply to all files #1341

Merged 1 commit on Jan 30, 2025.
233 changes: 115 additions & 118 deletions .github/workflows/copy_tblgen_files.py
@@ -11,147 +11,144 @@
 
 
 def cleanup(content):
-    content = content.replace("<!-- Autogenerated by mlir-tblgen; don't manually edit -->", "")
-    return content
+  content = content.replace(
+      "<!-- Autogenerated by mlir-tblgen; don't manually edit -->", ""
+  )
+  return content
 
 
 def extract_frontmatter(content):
-    match = re.search(FRONTMATTER_PATTERN, content, re.DOTALL)
-    if match:
-        frontmatter = match.group(1)
-        return yaml.safe_load(frontmatter), content[match.end() :]
-    return None, content
+  match = re.search(FRONTMATTER_PATTERN, content, re.DOTALL)
+  if match:
+    frontmatter = match.group(1)
+    return yaml.safe_load(frontmatter), content[match.end() :]
+  return None, content
 
 
 def split_sections(content):
-    # Assumes only level ### headers delineate sections, (#### is used for pass options)
-    sections = re.split(r"(^### .*$)", content, flags=re.MULTILINE)
-    return sections
+  # Assumes only level ### headers delineate sections, (#### is used for pass options)
+  sections = re.split(r"(^### .*$)", content, flags=re.MULTILINE)
+  return sections
 
 
 def sort_sections(sections):
-    headers_and_content = []
+  headers_and_content = []
 
-    for i in range(1, len(sections), 2):
-        header = sections[i].strip()
-        body = sections[i + 1]
-        headers_and_content.append((header, body))
+  for i in range(1, len(sections), 2):
+    header = sections[i].strip()
+    body = sections[i + 1]
+    headers_and_content.append((header, body))
 
-    sorted_sections = sorted(headers_and_content, key=lambda x: x[0])
-    return sorted_sections
+  sorted_sections = sorted(headers_and_content, key=lambda x: x[0])
+  return sorted_sections
 
 
 def rebuild_content(frontmatter, sorted_sections):
-    frontmatter_str = (
-        "---\n" + yaml.dump(frontmatter) + "---\n\n" if frontmatter else ""
-    )
-    content_str = "".join([f"{header}\n{body}\n" for header, body in sorted_sections])
-    return frontmatter_str + content_str
+  frontmatter_str = (
+      "---\n" + yaml.dump(frontmatter) + "---\n\n" if frontmatter else ""
+  )
+  content_str = "".join(
+      [f"{header}\n{body}\n" for header, body in sorted_sections]
+  )
+  return frontmatter_str + content_str
 
 
 def sort_markdown_file_by_header(path):
-    with open(path, "r") as f:
-        content = f.read()
+  with open(path, "r") as f:
+    content = f.read()
 
-    frontmatter, content_without_frontmatter = extract_frontmatter(content)
-    sections = split_sections(content_without_frontmatter)
-    sorted_sections = sort_sections(sections)
-    sorted_content = rebuild_content(frontmatter, sorted_sections)
+  frontmatter, content_without_frontmatter = extract_frontmatter(content)
+  sections = split_sections(content_without_frontmatter)
+  sorted_sections = sort_sections(sections)
+  sorted_content = rebuild_content(frontmatter, sorted_sections)
 
-    with open(path, "w") as f:
-        f.write(sorted_content)
+  with open(path, "w") as f:
+    f.write(sorted_content)
 
 
 if __name__ == "__main__":
-    # Create passes.md file with the front matter
-    with open(PASSES_FILE, "w") as f:
-        f.write(
-            """---
+  # Create passes.md file with the front matter
+  with open(PASSES_FILE, "w") as f:
+    f.write("""---
 title: Passes
 weight: 70
----\n"""
-        )
-
-    print("Processing Non-conversion Passes")
-    for src_path in glob.glob(f"{SRC_BASE}/**/*Passes.md", recursive=True):
-        with open(src_path, "r") as src_file:
-            with open(PASSES_FILE, "a") as dest_file:
-                dest_file.write(cleanup(src_file.read()))
-
-    print("Processing Conversion Passes")
-    for src_path in glob.glob(
-        f"{SRC_BASE}/Dialect/**/Conversions/**/*.md", recursive=True
-    ):
-        with open(src_path, "r") as src_file:
-            with open(PASSES_FILE, "a") as dest_file:
-                dest_file.write(cleanup(src_file.read()))
-                dest_file.write("\n")
-
-    sort_markdown_file_by_header(PASSES_FILE)
-
-    print("Processing Dialects")
-    Path(f"{DEST_BASE}/Dialects/").mkdir(parents=True, exist_ok=True)
-    for dialect_dir in glob.glob(f"{SRC_BASE}/Dialect/*"):
-        if not os.path.isdir(dialect_dir):
-            print(f"Skipping non-directory file {dialect_dir}")
-            continue
-
-        dialect_name = os.path.basename(dialect_dir)
-        print(f"Processing {dialect_name}")
-        filename = f"{dialect_name}.md"
-        dest_path = f"{DEST_BASE}/Dialects/{filename}"
-
-        markdown_files = {}
-        for src_path in glob.glob(f"{dialect_dir}/IR/**/*.md", recursive=True):
-            markdown_files[src_path] = True
-            print(f"Adding {src_path} to the queue")
-
-        if not markdown_files:
-            print(f"Skipping {dialect_name} as no markdown files found")
-            continue
-
-
-        # Write front matter for the dialect markdown
-        with open(dest_path, "w") as f:
-            f.write(
-                f"""---
+---\n""")
+
+  print("Processing Non-conversion Passes")
+  for src_path in glob.glob(f"{SRC_BASE}/**/*Passes.md", recursive=True):
+    with open(src_path, "r") as src_file:
+      with open(PASSES_FILE, "a") as dest_file:
+        dest_file.write(cleanup(src_file.read()))
+
+  print("Processing Conversion Passes")
+  for src_path in glob.glob(
+      f"{SRC_BASE}/Dialect/**/Conversions/**/*.md", recursive=True
+  ):
+    with open(src_path, "r") as src_file:
+      with open(PASSES_FILE, "a") as dest_file:
+        dest_file.write(cleanup(src_file.read()))
+        dest_file.write("\n")
+
+  sort_markdown_file_by_header(PASSES_FILE)
+
+  print("Processing Dialects")
+  Path(f"{DEST_BASE}/Dialects/").mkdir(parents=True, exist_ok=True)
+  for dialect_dir in glob.glob(f"{SRC_BASE}/Dialect/*"):
+    if not os.path.isdir(dialect_dir):
+      print(f"Skipping non-directory file {dialect_dir}")
+      continue
+
+    dialect_name = os.path.basename(dialect_dir)
+    print(f"Processing {dialect_name}")
+    filename = f"{dialect_name}.md"
+    dest_path = f"{DEST_BASE}/Dialects/{filename}"
+
+    markdown_files = {}
+    for src_path in glob.glob(f"{dialect_dir}/IR/**/*.md", recursive=True):
+      markdown_files[src_path] = True
+      print(f"Adding {src_path} to the queue")
+
+    if not markdown_files:
+      print(f"Skipping {dialect_name} as no markdown files found")
+      continue
+
+    # Write front matter for the dialect markdown
+    with open(dest_path, "w") as f:
+      f.write(f"""---
 title: {dialect_name}
 github_url: https://github.com/google/heir/edit/main/lib/Dialect/{dialect_name}/IR
----\n"""
-            )
-
-        # Process files in a specific order
-        groups = ["Dialect", "Attributes", "Types", "Ops"]
-
-        for index, group in enumerate(groups):
-            search_pattern = f"{dialect_dir}/IR/*{group}.md"
-            for src_path in glob.glob(search_pattern):
-                if not os.path.isfile(src_path):
-                    continue
-
-                with open(dest_path, "a") as dest_file:
-                    if index == 0:
-                        # Special care for the first group
-                        with open(src_path, "r") as src_file:
-                            content = (
-                                src_file.read()
-                                .replace("# Dialect", "")
-                                .replace("[TOC]", "")
-                            )
-                            dest_file.write(content)
-                    else:
-                        dest_file.write(f"## {dialect_name} {group.lower()}\n")
-                        with open(src_path, "r") as src_file:
-                            dest_file.write(src_file.read())
-
-                markdown_files.pop(src_path, None)
-
-        # Include additional files not processed in the groups
-        if markdown_files:
-            with open(dest_path, "a") as dest_file:
-                dest_file.write(f"## {dialect_name} additional definitions\n")
-                for src_path in list(markdown_files.keys()):
-                    if os.path.isfile(src_path):
-                        with open(src_path, "r") as src_file:
-                            dest_file.write(src_file.read())
-                        markdown_files.pop(src_path)
+---\n""")
+
+    # Process files in a specific order
+    groups = ["Dialect", "Attributes", "Types", "Ops"]
+
+    for index, group in enumerate(groups):
+      search_pattern = f"{dialect_dir}/IR/*{group}.md"
+      for src_path in glob.glob(search_pattern):
+        if not os.path.isfile(src_path):
+          continue
+
+        with open(dest_path, "a") as dest_file:
+          if index == 0:
+            # Special care for the first group
+            with open(src_path, "r") as src_file:
+              content = (
+                  src_file.read().replace("# Dialect", "").replace("[TOC]", "")
+              )
+              dest_file.write(content)
+          else:
+            dest_file.write(f"## {dialect_name} {group.lower()}\n")
+            with open(src_path, "r") as src_file:
+              dest_file.write(src_file.read())
+
+        markdown_files.pop(src_path, None)
+
+    # Include additional files not processed in the groups
+    if markdown_files:
+      with open(dest_path, "a") as dest_file:
+        dest_file.write(f"## {dialect_name} additional definitions\n")
+        for src_path in list(markdown_files.keys()):
+          if os.path.isfile(src_path):
+            with open(src_path, "r") as src_file:
+              dest_file.write(src_file.read())
+            markdown_files.pop(src_path)
2 changes: 2 additions & 0 deletions .github/workflows/pr_style.yml
@@ -11,6 +11,8 @@ jobs:
     steps:
       - uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab # pin@v3
       - uses: actions/setup-python@61a6322f88396a6271a6ee3565807d608ecaddd1 # pin@v4.7.0
+        with:
+          python-version: '3.11'
 
       - name: Run pre-commit github action
         uses: pre-commit/action@646c83fcd040023954eafda54b4db0192ce70507 # pin@v3
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
@@ -71,4 +71,11 @@ repos:
     hooks:
       - id: go-fmt
 
+  # python formatter
+  - repo: https://github.com/google/pyink
+    rev: 24.10.1
+    hooks:
+      - id: pyink
+        language_version: python3.11
+
 exclude: patches/.*\.patch$
12 changes: 12 additions & 0 deletions heir_py/README.md
@@ -70,3 +70,15 @@ variables:
 `bazel test` should work out of the box. If it does not, file a bug.
 `heir_py/testing.bzl` contains the environment variable setup required to tell
 the frontend where to find OpenFHE and related backend shared libraries.
+
+## Formatting
+
+This project uses [pyink](https://github.com/google/pyink) for autoformatting.
+pyink is a fork of the more commonly used [black](https://github.com/psf/black)
+formatter, with patches to support Google's internal style guide. The
+`[tool.pyink]` options in pyproject.toml configure pyink to be consistent
+with that style.
+
+The `pyink` repo has instructions for setting up `pyink` with VSCode. The
+pre-commit configuration for this repo runs `pyink` automatically; to run a
+one-time format of the entire project, use `pre-commit run --all-files`.
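
For a sense of what the configured style produces, here is a small
illustrative sketch (an editor's example, not code from this PR), written the
way pyink renders it under the settings this PR adds to pyproject.toml
(80-column lines, two-space indentation):

```python
# Hypothetical example: with pyink-indentation = 2, blocks use two-space
# indents where black would use four, and an over-long call is wrapped
# with a four-space hanging indent, splitting the f-string across lines.
def describe(name, values):
  summary = ", ".join(str(v) for v in values)
  print(
      f"A deliberately long message so this call exceeds eighty columns:"
      f" {name} -> {summary}"
  )
  return summary
```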
4 changes: 2 additions & 2 deletions heir_py/decorator.py
@@ -115,7 +115,7 @@ def compile(
     backend: str = "openfhe",
     backend_config: Optional[openfhe_config.OpenFHEConfig] = None,
     heir_config: Optional[_heir_config.HEIRConfig] = None,
-    debug : Optional[bool] = False
+    debug: Optional[bool] = False,
 ):
   """Compile a function to its private equivalent in FHE.
 
@@ -136,7 +136,7 @@ def decorator(func):
         func,
         openfhe_config=backend_config or openfhe_config.from_os_env(),
         heir_config=heir_config or _heir_config.from_os_env(),
-        debug = debug
+        debug=debug,
     )
     if backend == "openfhe":
       return OpenfheClientInterface(compilation_result)
4 changes: 3 additions & 1 deletion heir_py/openfhe_config.py
@@ -91,7 +91,9 @@ def from_os_env(debug=False) -> OpenFHEConfig:
 
   for include_dir in include_dirs:
     if not os.path.exists(include_dir):
-      print(f"Warning: OpenFHE include directory \"{include_dir}\" does not exist")
+      print(
+          f'Warning: OpenFHE include directory "{include_dir}" does not exist'
+      )
 
   return OpenFHEConfig(
       include_dirs=include_dirs
31 changes: 17 additions & 14 deletions heir_py/pipeline.py
@@ -78,22 +78,22 @@ def run_compiler(
   heir_opt_options = [
       f"--secretize=function={func_name}",
       "--mlir-to-bgv=ciphertext-degree=32",
-      f"--scheme-to-openfhe=entry-function={func_name}"
+      f"--scheme-to-openfhe=entry-function={func_name}",
   ]
   heir_opt_output = heir_opt.run_binary(
       input=mlir_textual,
       options=heir_opt_options,
   )
 
-  if(debug):
-      mlir_in_filepath = Path(workspace_dir) / f"{func_name}.in.mlir"
-      print(f"Debug mode enabled. Writing Input MLIR to {mlir_in_filepath}")
-      with open(mlir_in_filepath, "w") as f:
-          f.write(mlir_textual)
-      mlir_out_filepath = Path(workspace_dir) / f"{func_name}.out.mlir"
-      print(f"Debug mode enabled. Writing Output MLIR to {mlir_out_filepath}")
-      with open(mlir_out_filepath, "w") as f:
-          f.write(heir_opt_output)
+  if debug:
+    mlir_in_filepath = Path(workspace_dir) / f"{func_name}.in.mlir"
+    print(f"Debug mode enabled. Writing Input MLIR to {mlir_in_filepath}")
+    with open(mlir_in_filepath, "w") as f:
+      f.write(mlir_textual)
+    mlir_out_filepath = Path(workspace_dir) / f"{func_name}.out.mlir"
+    print(f"Debug mode enabled. Writing Output MLIR to {mlir_out_filepath}")
+    with open(mlir_out_filepath, "w") as f:
+      f.write(heir_opt_output)
 
   heir_translate = heir_backend.HeirTranslateBackend(
       binary_path=heir_config.heir_translate_path
   )
@@ -130,10 +130,13 @@ def run_compiler(
   clang_backend = clang.ClangBackend()
   so_filepath = Path(workspace_dir) / f"{func_name}.so"
   linker_search_paths = [openfhe_config.lib_dir]
-  if(debug):
-      print(
-          f"Debug mode enabled. Compiling {cpp_filepath} with linker search paths: {linker_search_paths}, include paths: {openfhe_config.include_dirs}, link libs: {openfhe_config.link_libs}\n"
-      )
+  if debug:
+    print(
+        f"Debug mode enabled. Compiling {cpp_filepath} with linker search"
+        f" paths: {linker_search_paths}, include paths:"
+        f" {openfhe_config.include_dirs}, link libs:"
+        f" {openfhe_config.link_libs}\n"
+    )
 
   clang_backend.compile_to_shared_object(
       cpp_source_filepath=cpp_filepath,
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -13,3 +13,9 @@ norecursedirs = [
     "tools",
     "venv",
 ]
+
+[tool.pyink]
+line-length = 80
+unstable = true
+pyink-indentation = 2
+pyink-use-majority-quotes = true
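
A note on these options (hedged; based on pyink's documented flags rather than
text from this PR): `unstable = true` opts into the preview-style formatting
pyink builds on, `pyink-indentation = 2` selects two-space indents, and
`pyink-use-majority-quotes` keeps whichever quote character already dominates
a file instead of forcing double quotes everywhere. A hypothetical
illustration of the quote rule:

```python
# Input file where single quotes are the majority style:
a = 'alpha'
b = 'beta'
c = "gamma"

# After running pyink with pyink-use-majority-quotes = true, the minority
# double-quoted string is rewritten to match the majority:
# a = 'alpha'
# b = 'beta'
# c = 'gamma'
```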