Skip to content

Commit

Permalink
Merge pull request #1341 from j2kun:pyink
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 721522154
  • Loading branch information
copybara-github committed Jan 30, 2025
2 parents 7859a07 + 19e4424 commit 2c118a3
Show file tree
Hide file tree
Showing 20 changed files with 773 additions and 730 deletions.
233 changes: 115 additions & 118 deletions .github/workflows/copy_tblgen_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,147 +11,144 @@


def cleanup(content):
content = content.replace("<!-- Autogenerated by mlir-tblgen; don't manually edit -->", "")
return content
content = content.replace(
"<!-- Autogenerated by mlir-tblgen; don't manually edit -->", ""
)
return content


def extract_frontmatter(content):
match = re.search(FRONTMATTER_PATTERN, content, re.DOTALL)
if match:
frontmatter = match.group(1)
return yaml.safe_load(frontmatter), content[match.end() :]
return None, content
match = re.search(FRONTMATTER_PATTERN, content, re.DOTALL)
if match:
frontmatter = match.group(1)
return yaml.safe_load(frontmatter), content[match.end() :]
return None, content


def split_sections(content):
# Assumes only level ### headers delineate sections, (#### is used for pass options)
sections = re.split(r"(^### .*$)", content, flags=re.MULTILINE)
return sections
# Assumes only level ### headers delineate sections, (#### is used for pass options)
sections = re.split(r"(^### .*$)", content, flags=re.MULTILINE)
return sections


def sort_sections(sections):
headers_and_content = []
headers_and_content = []

for i in range(1, len(sections), 2):
header = sections[i].strip()
body = sections[i + 1]
headers_and_content.append((header, body))
for i in range(1, len(sections), 2):
header = sections[i].strip()
body = sections[i + 1]
headers_and_content.append((header, body))

sorted_sections = sorted(headers_and_content, key=lambda x: x[0])
return sorted_sections
sorted_sections = sorted(headers_and_content, key=lambda x: x[0])
return sorted_sections


def rebuild_content(frontmatter, sorted_sections):
frontmatter_str = (
"---\n" + yaml.dump(frontmatter) + "---\n\n" if frontmatter else ""
)
content_str = "".join([f"{header}\n{body}\n" for header, body in sorted_sections])
return frontmatter_str + content_str
frontmatter_str = (
"---\n" + yaml.dump(frontmatter) + "---\n\n" if frontmatter else ""
)
content_str = "".join(
[f"{header}\n{body}\n" for header, body in sorted_sections]
)
return frontmatter_str + content_str


def sort_markdown_file_by_header(path):
with open(path, "r") as f:
content = f.read()
with open(path, "r") as f:
content = f.read()

frontmatter, content_without_frontmatter = extract_frontmatter(content)
sections = split_sections(content_without_frontmatter)
sorted_sections = sort_sections(sections)
sorted_content = rebuild_content(frontmatter, sorted_sections)
frontmatter, content_without_frontmatter = extract_frontmatter(content)
sections = split_sections(content_without_frontmatter)
sorted_sections = sort_sections(sections)
sorted_content = rebuild_content(frontmatter, sorted_sections)

with open(path, "w") as f:
f.write(sorted_content)
with open(path, "w") as f:
f.write(sorted_content)


if __name__ == "__main__":
# Create passes.md file with the front matter
with open(PASSES_FILE, "w") as f:
f.write(
"""---
# Create passes.md file with the front matter
with open(PASSES_FILE, "w") as f:
f.write("""---
title: Passes
weight: 70
---\n"""
)

print("Processing Non-conversion Passes")
for src_path in glob.glob(f"{SRC_BASE}/**/*Passes.md", recursive=True):
with open(src_path, "r") as src_file:
with open(PASSES_FILE, "a") as dest_file:
dest_file.write(cleanup(src_file.read()))

print("Processing Conversion Passes")
for src_path in glob.glob(
f"{SRC_BASE}/Dialect/**/Conversions/**/*.md", recursive=True
):
with open(src_path, "r") as src_file:
with open(PASSES_FILE, "a") as dest_file:
dest_file.write(cleanup(src_file.read()))
dest_file.write("\n")

sort_markdown_file_by_header(PASSES_FILE)

print("Processing Dialects")
Path(f"{DEST_BASE}/Dialects/").mkdir(parents=True, exist_ok=True)
for dialect_dir in glob.glob(f"{SRC_BASE}/Dialect/*"):
if not os.path.isdir(dialect_dir):
print(f"Skipping non-directory file {dialect_dir}")
continue

dialect_name = os.path.basename(dialect_dir)
print(f"Processing {dialect_name}")
filename = f"{dialect_name}.md"
dest_path = f"{DEST_BASE}/Dialects/{filename}"

markdown_files = {}
for src_path in glob.glob(f"{dialect_dir}/IR/**/*.md", recursive=True):
markdown_files[src_path] = True
print(f"Adding {src_path} to the queue")

if not markdown_files:
print(f"Skipping {dialect_name} as no markdown files found")
continue


# Write front matter for the dialect markdown
with open(dest_path, "w") as f:
f.write(
f"""---
---\n""")

print("Processing Non-conversion Passes")
for src_path in glob.glob(f"{SRC_BASE}/**/*Passes.md", recursive=True):
with open(src_path, "r") as src_file:
with open(PASSES_FILE, "a") as dest_file:
dest_file.write(cleanup(src_file.read()))

print("Processing Conversion Passes")
for src_path in glob.glob(
f"{SRC_BASE}/Dialect/**/Conversions/**/*.md", recursive=True
):
with open(src_path, "r") as src_file:
with open(PASSES_FILE, "a") as dest_file:
dest_file.write(cleanup(src_file.read()))
dest_file.write("\n")

sort_markdown_file_by_header(PASSES_FILE)

print("Processing Dialects")
Path(f"{DEST_BASE}/Dialects/").mkdir(parents=True, exist_ok=True)
for dialect_dir in glob.glob(f"{SRC_BASE}/Dialect/*"):
if not os.path.isdir(dialect_dir):
print(f"Skipping non-directory file {dialect_dir}")
continue

dialect_name = os.path.basename(dialect_dir)
print(f"Processing {dialect_name}")
filename = f"{dialect_name}.md"
dest_path = f"{DEST_BASE}/Dialects/{filename}"

markdown_files = {}
for src_path in glob.glob(f"{dialect_dir}/IR/**/*.md", recursive=True):
markdown_files[src_path] = True
print(f"Adding {src_path} to the queue")

if not markdown_files:
print(f"Skipping {dialect_name} as no markdown files found")
continue

# Write front matter for the dialect markdown
with open(dest_path, "w") as f:
f.write(f"""---
title: {dialect_name}
github_url: https://github.com/google/heir/edit/main/lib/Dialect/{dialect_name}/IR
---\n"""
)

# Process files in a specific order
groups = ["Dialect", "Attributes", "Types", "Ops"]

for index, group in enumerate(groups):
search_pattern = f"{dialect_dir}/IR/*{group}.md"
for src_path in glob.glob(search_pattern):
if not os.path.isfile(src_path):
continue

with open(dest_path, "a") as dest_file:
if index == 0:
# Special care for the first group
with open(src_path, "r") as src_file:
content = (
src_file.read()
.replace("# Dialect", "")
.replace("[TOC]", "")
)
dest_file.write(content)
else:
dest_file.write(f"## {dialect_name} {group.lower()}\n")
with open(src_path, "r") as src_file:
dest_file.write(src_file.read())

markdown_files.pop(src_path, None)

# Include additional files not processed in the groups
if markdown_files:
with open(dest_path, "a") as dest_file:
dest_file.write(f"## {dialect_name} additional definitions\n")
for src_path in list(markdown_files.keys()):
if os.path.isfile(src_path):
with open(src_path, "r") as src_file:
dest_file.write(src_file.read())
markdown_files.pop(src_path)
---\n""")

# Process files in a specific order
groups = ["Dialect", "Attributes", "Types", "Ops"]

for index, group in enumerate(groups):
search_pattern = f"{dialect_dir}/IR/*{group}.md"
for src_path in glob.glob(search_pattern):
if not os.path.isfile(src_path):
continue

with open(dest_path, "a") as dest_file:
if index == 0:
# Special care for the first group
with open(src_path, "r") as src_file:
content = (
src_file.read().replace("# Dialect", "").replace("[TOC]", "")
)
dest_file.write(content)
else:
dest_file.write(f"## {dialect_name} {group.lower()}\n")
with open(src_path, "r") as src_file:
dest_file.write(src_file.read())

markdown_files.pop(src_path, None)

# Include additional files not processed in the groups
if markdown_files:
with open(dest_path, "a") as dest_file:
dest_file.write(f"## {dialect_name} additional definitions\n")
for src_path in list(markdown_files.keys()):
if os.path.isfile(src_path):
with open(src_path, "r") as src_file:
dest_file.write(src_file.read())
markdown_files.pop(src_path)
2 changes: 2 additions & 0 deletions .github/workflows/pr_style.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ jobs:
steps:
- uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab # pin@v3
- uses: actions/setup-python@61a6322f88396a6271a6ee3565807d608ecaddd1 # pin@v4.7.0
with:
python-version: '3.11'

- name: Run pre-commit github action
uses: pre-commit/action@646c83fcd040023954eafda54b4db0192ce70507 # pin@v3
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,11 @@ repos:
hooks:
- id: go-fmt

# python formatter
- repo: https://github.com/google/pyink
rev: 24.10.1
hooks:
- id: pyink
language_version: python3.11

exclude: patches/.*\.patch$
12 changes: 12 additions & 0 deletions heir_py/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,15 @@ variables:
`bazel test` should work out of the box. If it does not, file a bug.
`heir_py/testing.bzl` contains the environment variable setup required to tell
the frontend where to find OpenFHE and related backend shared libraries.

## Formatting

This uses [pyink](https://github.com/google/pyink) for autoformatting, which is
a fork of the more commonly used [black](https://github.com/psf/black) formatter
with some patches to support Google's internal style guide. The configuration in
pyproject.toml corresponds to the options making pyink consistent with Google's
internal style guide.

The `pyink` repo has instructions for setting up `pyink` with VSCode. The
pre-commit configuration for this repo will automatically run `pyink`, and to
run a one-time format of the entire project, use `pre-commit run --all-files`.
4 changes: 2 additions & 2 deletions heir_py/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def compile(
backend: str = "openfhe",
backend_config: Optional[openfhe_config.OpenFHEConfig] = None,
heir_config: Optional[_heir_config.HEIRConfig] = None,
debug : Optional[bool] = False
debug: Optional[bool] = False,
):
"""Compile a function to its private equivalent in FHE.
Expand All @@ -136,7 +136,7 @@ def decorator(func):
func,
openfhe_config=backend_config or openfhe_config.from_os_env(),
heir_config=heir_config or _heir_config.from_os_env(),
debug = debug
debug=debug,
)
if backend == "openfhe":
return OpenfheClientInterface(compilation_result)
Expand Down
4 changes: 3 additions & 1 deletion heir_py/openfhe_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def from_os_env(debug=False) -> OpenFHEConfig:

for include_dir in include_dirs:
if not os.path.exists(include_dir):
print(f"Warning: OpenFHE include directory \"{include_dir}\" does not exist")
print(
f'Warning: OpenFHE include directory "{include_dir}" does not exist'
)

return OpenFHEConfig(
include_dirs=include_dirs
Expand Down
31 changes: 17 additions & 14 deletions heir_py/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,22 +78,22 @@ def run_compiler(
heir_opt_options = [
f"--secretize=function={func_name}",
"--mlir-to-bgv=ciphertext-degree=32",
f"--scheme-to-openfhe=entry-function={func_name}"
f"--scheme-to-openfhe=entry-function={func_name}",
]
heir_opt_output = heir_opt.run_binary(
input=mlir_textual,
options=heir_opt_options,
)

if(debug):
mlir_in_filepath = Path(workspace_dir) / f"{func_name}.in.mlir"
print(f"Debug mode enabled. Writing Input MLIR to {mlir_in_filepath}")
with open(mlir_in_filepath, "w") as f:
f.write(mlir_textual)
mlir_out_filepath = Path(workspace_dir) / f"{func_name}.out.mlir"
print(f"Debug mode enabled. Writing Output MLIR to {mlir_out_filepath}")
with open(mlir_out_filepath, "w") as f:
f.write(heir_opt_output)
if debug:
mlir_in_filepath = Path(workspace_dir) / f"{func_name}.in.mlir"
print(f"Debug mode enabled. Writing Input MLIR to {mlir_in_filepath}")
with open(mlir_in_filepath, "w") as f:
f.write(mlir_textual)
mlir_out_filepath = Path(workspace_dir) / f"{func_name}.out.mlir"
print(f"Debug mode enabled. Writing Output MLIR to {mlir_out_filepath}")
with open(mlir_out_filepath, "w") as f:
f.write(heir_opt_output)

heir_translate = heir_backend.HeirTranslateBackend(
binary_path=heir_config.heir_translate_path
Expand Down Expand Up @@ -130,10 +130,13 @@ def run_compiler(
clang_backend = clang.ClangBackend()
so_filepath = Path(workspace_dir) / f"{func_name}.so"
linker_search_paths = [openfhe_config.lib_dir]
if(debug):
print(
f"Debug mode enabled. Compiling {cpp_filepath} with linker search paths: {linker_search_paths}, include paths: {openfhe_config.include_dirs}, link libs: {openfhe_config.link_libs}\n"
)
if debug:
print(
f"Debug mode enabled. Compiling {cpp_filepath} with linker search"
f" paths: {linker_search_paths}, include paths:"
f" {openfhe_config.include_dirs}, link libs:"
f" {openfhe_config.link_libs}\n"
)

clang_backend.compile_to_shared_object(
cpp_source_filepath=cpp_filepath,
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,9 @@ norecursedirs = [
"tools",
"venv",
]

[tool.pyink]
line-length = 80
unstable = true
pyink-indentation = 2
pyink-use-majority-quotes = true
Loading

0 comments on commit 2c118a3

Please sign in to comment.