diff --git a/.github/workflows/copy_tblgen_files.py b/.github/workflows/copy_tblgen_files.py index 4c21e4892..c14fa2f93 100644 --- a/.github/workflows/copy_tblgen_files.py +++ b/.github/workflows/copy_tblgen_files.py @@ -11,147 +11,144 @@ def cleanup(content): - content = content.replace("", "") - return content + content = content.replace( + "", "" + ) + return content def extract_frontmatter(content): - match = re.search(FRONTMATTER_PATTERN, content, re.DOTALL) - if match: - frontmatter = match.group(1) - return yaml.safe_load(frontmatter), content[match.end() :] - return None, content + match = re.search(FRONTMATTER_PATTERN, content, re.DOTALL) + if match: + frontmatter = match.group(1) + return yaml.safe_load(frontmatter), content[match.end() :] + return None, content def split_sections(content): - # Assumes only level ### headers delineate sections, (#### is used for pass options) - sections = re.split(r"(^### .*$)", content, flags=re.MULTILINE) - return sections + # Assumes only level ### headers delineate sections, (#### is used for pass options) + sections = re.split(r"(^### .*$)", content, flags=re.MULTILINE) + return sections def sort_sections(sections): - headers_and_content = [] + headers_and_content = [] - for i in range(1, len(sections), 2): - header = sections[i].strip() - body = sections[i + 1] - headers_and_content.append((header, body)) + for i in range(1, len(sections), 2): + header = sections[i].strip() + body = sections[i + 1] + headers_and_content.append((header, body)) - sorted_sections = sorted(headers_and_content, key=lambda x: x[0]) - return sorted_sections + sorted_sections = sorted(headers_and_content, key=lambda x: x[0]) + return sorted_sections def rebuild_content(frontmatter, sorted_sections): - frontmatter_str = ( - "---\n" + yaml.dump(frontmatter) + "---\n\n" if frontmatter else "" - ) - content_str = "".join([f"{header}\n{body}\n" for header, body in sorted_sections]) - return frontmatter_str + content_str + frontmatter_str = ( + "---\n" + yaml.dump(frontmatter) + "---\n\n" if frontmatter else "" + ) + content_str = "".join( + [f"{header}\n{body}\n" for header, body in sorted_sections] + ) + return frontmatter_str + content_str def sort_markdown_file_by_header(path): - with open(path, "r") as f: - content = f.read() + with open(path, "r") as f: + content = f.read() - frontmatter, content_without_frontmatter = extract_frontmatter(content) - sections = split_sections(content_without_frontmatter) - sorted_sections = sort_sections(sections) - sorted_content = rebuild_content(frontmatter, sorted_sections) + frontmatter, content_without_frontmatter = extract_frontmatter(content) + sections = split_sections(content_without_frontmatter) + sorted_sections = sort_sections(sections) + sorted_content = rebuild_content(frontmatter, sorted_sections) - with open(path, "w") as f: - f.write(sorted_content) + with open(path, "w") as f: + f.write(sorted_content) if __name__ == "__main__": - # Create passes.md file with the front matter - with open(PASSES_FILE, "w") as f: - f.write( - """--- + # Create passes.md file with the front matter + with open(PASSES_FILE, "w") as f: + f.write("""--- title: Passes weight: 70 - ---\n""" - ) - - print("Processing Non-conversion Passes") - for src_path in glob.glob(f"{SRC_BASE}/**/*Passes.md", recursive=True): - with open(src_path, "r") as src_file: - with open(PASSES_FILE, "a") as dest_file: - dest_file.write(cleanup(src_file.read())) - - print("Processing Conversion Passes") - for src_path in glob.glob( - f"{SRC_BASE}/Dialect/**/Conversions/**/*.md", recursive=True - ): - with open(src_path, "r") as src_file: - with open(PASSES_FILE, "a") as dest_file: - dest_file.write(cleanup(src_file.read())) - dest_file.write("\n") - - sort_markdown_file_by_header(PASSES_FILE) - - print("Processing Dialects") - Path(f"{DEST_BASE}/Dialects/").mkdir(parents=True, exist_ok=True) - for dialect_dir in glob.glob(f"{SRC_BASE}/Dialect/*"): - if not os.path.isdir(dialect_dir): - print(f"Skipping non-directory file {dialect_dir}") - continue - - dialect_name = os.path.basename(dialect_dir) - print(f"Processing {dialect_name}") - filename = f"{dialect_name}.md" - dest_path = f"{DEST_BASE}/Dialects/{filename}" - - markdown_files = {} - for src_path in glob.glob(f"{dialect_dir}/IR/**/*.md", recursive=True): - markdown_files[src_path] = True - print(f"Adding {src_path} to the queue") - - if not markdown_files: - print(f"Skipping {dialect_name} as no markdown files found") - continue - - - # Write front matter for the dialect markdown - with open(dest_path, "w") as f: - f.write( - f"""--- + ---\n""") + + print("Processing Non-conversion Passes") + for src_path in glob.glob(f"{SRC_BASE}/**/*Passes.md", recursive=True): + with open(src_path, "r") as src_file: + with open(PASSES_FILE, "a") as dest_file: + dest_file.write(cleanup(src_file.read())) + + print("Processing Conversion Passes") + for src_path in glob.glob( + f"{SRC_BASE}/Dialect/**/Conversions/**/*.md", recursive=True + ): + with open(src_path, "r") as src_file: + with open(PASSES_FILE, "a") as dest_file: + dest_file.write(cleanup(src_file.read())) + dest_file.write("\n") + + sort_markdown_file_by_header(PASSES_FILE) + + print("Processing Dialects") + Path(f"{DEST_BASE}/Dialects/").mkdir(parents=True, exist_ok=True) + for dialect_dir in glob.glob(f"{SRC_BASE}/Dialect/*"): + if not os.path.isdir(dialect_dir): + print(f"Skipping non-directory file {dialect_dir}") + continue + + dialect_name = os.path.basename(dialect_dir) + print(f"Processing {dialect_name}") + filename = f"{dialect_name}.md" + dest_path = f"{DEST_BASE}/Dialects/{filename}" + + markdown_files = {} + for src_path in glob.glob(f"{dialect_dir}/IR/**/*.md", recursive=True): + markdown_files[src_path] = True + print(f"Adding {src_path} to the queue") + + if not markdown_files: + print(f"Skipping {dialect_name} as no markdown files found") + continue + + # Write front matter for the dialect markdown + with open(dest_path, "w") as f: + f.write(f"""--- title: {dialect_name} github_url: https://github.com/google/heir/edit/main/lib/Dialect/{dialect_name}/IR ----\n""" - ) - - # Process files in a specific order - groups = ["Dialect", "Attributes", "Types", "Ops"] - - for index, group in enumerate(groups): - search_pattern = f"{dialect_dir}/IR/*{group}.md" - for src_path in glob.glob(search_pattern): - if not os.path.isfile(src_path): - continue - - with open(dest_path, "a") as dest_file: - if index == 0: - # Special care for the first group - with open(src_path, "r") as src_file: - content = ( - src_file.read() - .replace("# Dialect", "") - .replace("[TOC]", "") - ) - dest_file.write(content) - else: - dest_file.write(f"## {dialect_name} {group.lower()}\n") - with open(src_path, "r") as src_file: - dest_file.write(src_file.read()) - - markdown_files.pop(src_path, None) - - # Include additional files not processed in the groups - if markdown_files: - with open(dest_path, "a") as dest_file: - dest_file.write(f"## {dialect_name} additional definitions\n") - for src_path in list(markdown_files.keys()): - if os.path.isfile(src_path): - with open(src_path, "r") as src_file: - dest_file.write(src_file.read()) - markdown_files.pop(src_path) +---\n""") + + # Process files in a specific order + groups = ["Dialect", "Attributes", "Types", "Ops"] + + for index, group in enumerate(groups): + search_pattern = f"{dialect_dir}/IR/*{group}.md" + for src_path in glob.glob(search_pattern): + if not os.path.isfile(src_path): + continue + + with open(dest_path, "a") as dest_file: + if index == 0: + # Special care for the first group + with open(src_path, "r") as src_file: + content = ( + src_file.read().replace("# Dialect", "").replace("[TOC]", "") + ) + dest_file.write(content) + else: + dest_file.write(f"## {dialect_name} {group.lower()}\n") + with open(src_path, "r") as src_file: + dest_file.write(src_file.read()) + + markdown_files.pop(src_path, None) + + # Include additional files not processed in the groups + if markdown_files: + with open(dest_path, "a") as dest_file: + dest_file.write(f"## {dialect_name} additional definitions\n") + for src_path in list(markdown_files.keys()): + if os.path.isfile(src_path): + with open(src_path, "r") as src_file: + dest_file.write(src_file.read()) + markdown_files.pop(src_path) diff --git a/.github/workflows/pr_style.yml b/.github/workflows/pr_style.yml index eb158f14b..0a10e053d 100644 --- a/.github/workflows/pr_style.yml +++ b/.github/workflows/pr_style.yml @@ -11,6 +11,8 @@ jobs: steps: - uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab # pin@v3 - uses: actions/setup-python@61a6322f88396a6271a6ee3565807d608ecaddd1 # pin@v4.7.0 + with: + python-version: '3.11' - name: Run pre-commit github action uses: pre-commit/action@646c83fcd040023954eafda54b4db0192ce70507 # pin@v3 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d46281d59..27e18080e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -71,4 +71,11 @@ repos: hooks: - id: go-fmt +# python formatter +- repo: https://github.com/google/pyink + rev: 24.10.1 + hooks: + - id: pyink + language_version: python3.11 + exclude: patches/.*\.patch$ diff --git a/heir_py/README.md b/heir_py/README.md index 3e7924f9f..e86c509cf 100644 --- a/heir_py/README.md +++ b/heir_py/README.md @@ -70,3 +70,15 @@ variables: `bazel test` should work out of the box. If it does not, file a bug. `heir_py/testing.bzl` contains the environment variable setup required to tell the frontend where to find OpenFHE and related backend shared libraries. + +## Formatting + +This uses [pyink](https://github.com/google/pyink) for autoformatting, which is +a fork of the more commonly used [black](https://github.com/psf/black) formatter +with some patches to support Google's internal style guide. The configuration in +pyproject.toml corresponds to the options making pyink consistent with Google's +internal style guide. + +The `pyink` repo has instructions for setting up `pyink` with VSCode. The +pre-commit configuration for this repo will automatically run `pyink`, and to +run a one-time format of the entire project, use `pre-commit run --all-files`. diff --git a/heir_py/decorator.py b/heir_py/decorator.py index 7ff145eaf..29411015f 100644 --- a/heir_py/decorator.py +++ b/heir_py/decorator.py @@ -115,7 +115,7 @@ def compile( backend: str = "openfhe", backend_config: Optional[openfhe_config.OpenFHEConfig] = None, heir_config: Optional[_heir_config.HEIRConfig] = None, - debug : Optional[bool] = False + debug: Optional[bool] = False, ): """Compile a function to its private equivalent in FHE. @@ -136,7 +136,7 @@ def decorator(func): func, openfhe_config=backend_config or openfhe_config.from_os_env(), heir_config=heir_config or _heir_config.from_os_env(), - debug = debug + debug=debug, ) if backend == "openfhe": return OpenfheClientInterface(compilation_result) diff --git a/heir_py/openfhe_config.py b/heir_py/openfhe_config.py index 195ebba18..73b30e52d 100644 --- a/heir_py/openfhe_config.py +++ b/heir_py/openfhe_config.py @@ -91,7 +91,9 @@ def from_os_env(debug=False) -> OpenFHEConfig: for include_dir in include_dirs: if not os.path.exists(include_dir): - print(f"Warning: OpenFHE include directory \"{include_dir}\" does not exist") + print( + f'Warning: OpenFHE include directory "{include_dir}" does not exist' + ) return OpenFHEConfig( include_dirs=include_dirs diff --git a/heir_py/pipeline.py b/heir_py/pipeline.py index 89fccca08..c0ef42c2b 100644 --- a/heir_py/pipeline.py +++ b/heir_py/pipeline.py @@ -78,22 +78,22 @@ def run_compiler( heir_opt_options = [ f"--secretize=function={func_name}", "--mlir-to-bgv=ciphertext-degree=32", - f"--scheme-to-openfhe=entry-function={func_name}" + f"--scheme-to-openfhe=entry-function={func_name}", ] heir_opt_output = heir_opt.run_binary( input=mlir_textual, options=heir_opt_options, ) - if(debug): - mlir_in_filepath = Path(workspace_dir) / f"{func_name}.in.mlir" - print(f"Debug mode enabled. Writing Input MLIR to {mlir_in_filepath}") - with open(mlir_in_filepath, "w") as f: - f.write(mlir_textual) - mlir_out_filepath = Path(workspace_dir) / f"{func_name}.out.mlir" - print(f"Debug mode enabled. Writing Output MLIR to {mlir_out_filepath}") - with open(mlir_out_filepath, "w") as f: - f.write(heir_opt_output) + if debug: + mlir_in_filepath = Path(workspace_dir) / f"{func_name}.in.mlir" + print(f"Debug mode enabled. Writing Input MLIR to {mlir_in_filepath}") + with open(mlir_in_filepath, "w") as f: + f.write(mlir_textual) + mlir_out_filepath = Path(workspace_dir) / f"{func_name}.out.mlir" + print(f"Debug mode enabled. Writing Output MLIR to {mlir_out_filepath}") + with open(mlir_out_filepath, "w") as f: + f.write(heir_opt_output) heir_translate = heir_backend.HeirTranslateBackend( binary_path=heir_config.heir_translate_path @@ -130,10 +130,13 @@ def run_compiler( clang_backend = clang.ClangBackend() so_filepath = Path(workspace_dir) / f"{func_name}.so" linker_search_paths = [openfhe_config.lib_dir] - if(debug): - print( - f"Debug mode enabled. Compiling {cpp_filepath} with linker search paths: {linker_search_paths}, include paths: {openfhe_config.include_dirs}, link libs: {openfhe_config.link_libs}\n" - ) + if debug: + print( + f"Debug mode enabled. Compiling {cpp_filepath} with linker search" + f" paths: {linker_search_paths}, include paths:" + f" {openfhe_config.include_dirs}, link libs:" + f" {openfhe_config.link_libs}\n" + ) clang_backend.compile_to_shared_object( cpp_source_filepath=cpp_filepath, diff --git a/pyproject.toml b/pyproject.toml index 173b912c8..473aca69b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,3 +13,9 @@ norecursedirs = [ "tools", "venv", ] + +[tool.pyink] +line-length = 80 +unstable = true +pyink-indentation = 2 +pyink-use-majority-quotes = true diff --git a/scripts/gcp/examples/add_one_lut3_main.py b/scripts/gcp/examples/add_one_lut3_main.py index 607c5bcf8..1a268a312 100644 --- a/scripts/gcp/examples/add_one_lut3_main.py +++ b/scripts/gcp/examples/add_one_lut3_main.py @@ -32,18 +32,16 @@ cleartext_x = type_converters.u8_to_bit_slice(x) ciphertext_x = [jaxite_bool.encrypt(z, cks, lwe_rng) for z in cleartext_x] -result_ciphertext = add_one_lut3_lib.add_one_lut3( - ciphertext_x, sks, params -) +result_ciphertext = add_one_lut3_lib.add_one_lut3(ciphertext_x, sks, params) + # Using Timeit def timed_fn(): - result_ciphertext = add_one_lut3_lib.add_one_lut3( - ciphertext_x, sks, params - ) + result_ciphertext = add_one_lut3_lib.add_one_lut3(ciphertext_x, sks, params) for c in result_ciphertext: c.block_until_ready() + timer = timeit.Timer(timed_fn) execution_time = timer.repeat(repeat=1, number=1) print("Add one execution time: ", execution_time) diff --git a/scripts/gcp/examples/jaxite_example.py b/scripts/gcp/examples/jaxite_example.py index 2a1683708..829ce20db 100644 --- a/scripts/gcp/examples/jaxite_example.py +++ b/scripts/gcp/examples/jaxite_example.py @@ -36,10 +36,13 @@ # included in timing metircs. and_gate = jaxite_bool.and_(ct_false, ct_true, sks, params) + # Using Timeit def timed_fn(): and_gate = jaxite_bool.and_(ct_false, ct_true, sks, params) and_gate.block_until_ready() + + timer = timeit.Timer(timed_fn) execution_time = timer.repeat(repeat=1, number=1) print("And gate execution time: ", execution_time) diff --git a/scripts/jupyter/heir_play/__init__.py b/scripts/jupyter/heir_play/__init__.py index 4d09e910f..7dfe86c29 100644 --- a/scripts/jupyter/heir_play/__init__.py +++ b/scripts/jupyter/heir_play/__init__.py @@ -1,4 +1,5 @@ """A magic for running heir-opt nightly binary""" + __version__ = "0.0.1" import os @@ -9,13 +10,15 @@ def load_ipython_extension(ipython): - ipython.register_magics( - HeirOptMagic(ipython, binary_path=str(load_nightly("heir-opt"))) - ) - ipython.register_magics( - HeirTranslateMagic(ipython, binary_path=str(load_nightly("heir-translate"))) - ) - load_nightly("abc") - load_nightly("techmap.v") - os.environ['HEIR_YOSYS_SCRIPTS_DIR'] = os.getcwd() - os.environ['HEIR_ABC_BINARY'] = os.getcwd() + "/abc" + ipython.register_magics( + HeirOptMagic(ipython, binary_path=str(load_nightly("heir-opt"))) + ) + ipython.register_magics( + HeirTranslateMagic( + ipython, binary_path=str(load_nightly("heir-translate")) + ) + ) + load_nightly("abc") + load_nightly("techmap.v") + os.environ["HEIR_YOSYS_SCRIPTS_DIR"] = os.getcwd() + os.environ["HEIR_ABC_BINARY"] = os.getcwd() + "/abc" diff --git a/scripts/jupyter/heir_play/heir_opt.py b/scripts/jupyter/heir_play/heir_opt.py index d0a1ef4fe..789d9c923 100644 --- a/scripts/jupyter/heir_play/heir_opt.py +++ b/scripts/jupyter/heir_play/heir_opt.py @@ -8,74 +8,77 @@ class BinaryMagic(Magics): - def __init__(self, shell, binary_path): - """ - Initialize a magic for running a binary with a given path. - """ - super(BinaryMagic, self).__init__(shell) - self.binary_path = binary_path - - def run_binary(self, line, cell): - """ - Run the binary on the input cell. - - Args: - line: The options to pass to the binary. - cell: The input to pass to the binary. - """ - print(f"Running {self.binary_path}...") - completed_process = subprocess.run( - [os.path.abspath(self.binary_path)] + shlex.split(line), - input=cell, - text=True, - capture_output=True, - ) - if completed_process.returncode != 0: - print(f"Error running {self.binary_path}") - print(completed_process.stdout) - print(completed_process.stderr) - return - output = completed_process.stdout - print(output) + + def __init__(self, shell, binary_path): + """ + Initialize a magic for running a binary with a given path. + """ + super(BinaryMagic, self).__init__(shell) + self.binary_path = binary_path + + def run_binary(self, line, cell): + """ + Run the binary on the input cell. + + Args: + line: The options to pass to the binary. + cell: The input to pass to the binary. + """ + print(f"Running {self.binary_path}...") + completed_process = subprocess.run( + [os.path.abspath(self.binary_path)] + shlex.split(line), + input=cell, + text=True, + capture_output=True, + ) + if completed_process.returncode != 0: + print(f"Error running {self.binary_path}") + print(completed_process.stdout) + print(completed_process.stderr) + return + output = completed_process.stdout + print(output) @magics_class class HeirOptMagic(BinaryMagic): - def __init__(self, shell, binary_path="heir-opt"): - """ - Initialize heir-opt with a path to the heir-opt binary. - If not specified, will assume heir-opt is on the path. - """ - super(HeirOptMagic, self).__init__(shell, binary_path) - @cell_magic - def heir_opt(self, line, cell): - """ - Run heir-opt on the input cell. + def __init__(self, shell, binary_path="heir-opt"): + """ + Initialize heir-opt with a path to the heir-opt binary. + If not specified, will assume heir-opt is on the path. + """ + super(HeirOptMagic, self).__init__(shell, binary_path) + + @cell_magic + def heir_opt(self, line, cell): + """ + Run heir-opt on the input cell. - Args: - line: The options to pass to heir-opt. - cell: The input to pass to heir-opt. - """ - return self.run_binary(line, cell) + Args: + line: The options to pass to heir-opt. + cell: The input to pass to heir-opt. + """ + return self.run_binary(line, cell) @magics_class class HeirTranslateMagic(BinaryMagic): - def __init__(self, shell, binary_path="heir-translate"): - """ - Initialize heir-translate with a path to the heir-translate binary. - If not specified, will assume heir-translate is on the path. - """ - super(HeirTranslateMagic, self).__init__(shell, binary_path) - - @cell_magic - def heir_translate(self, line, cell): - """ - Run heir-translate on the input cell. - - Args: - line: The options to pass to heir-translate. - cell: The input to pass to heir-translate. - """ - return self.run_binary(line, cell) + + def __init__(self, shell, binary_path="heir-translate"): + """ + Initialize heir-translate with a path to the heir-translate binary. + If not specified, will assume heir-translate is on the path. + """ + super(HeirTranslateMagic, self).__init__(shell, binary_path) + + @cell_magic + def heir_translate(self, line, cell): + """ + Run heir-translate on the input cell. + + Args: + line: The options to pass to heir-translate. + cell: The input to pass to heir-translate. + """ + return self.run_binary(line, cell) diff --git a/scripts/jupyter/heir_play/utils.py b/scripts/jupyter/heir_play/utils.py index dad58a519..357be7749 100644 --- a/scripts/jupyter/heir_play/utils.py +++ b/scripts/jupyter/heir_play/utils.py @@ -5,36 +5,37 @@ ASSET_BASE_URL = "https://github.com/google/heir/releases/download/nightly/" + def abort_cleanup(filename): - cwd = Path(dir=os.getcwd()) - tmpfile = cwd / filename - os.remove(tmpfile) + cwd = Path(dir=os.getcwd()) + tmpfile = cwd / filename + os.remove(tmpfile) def load_nightly(filename) -> Path: - """Fetches the nightly heir-opt binary from GitHub and returns the path to it.""" - print(f"Loading {filename} nightly binary") - # TODO: how to clean up the tmpdir after ipython closes? - # At worst, the user will see this in their heir_play dir and delete it. - cwd = Path(dir=os.getcwd()) - tmpfile = cwd / filename - if os.path.isfile(tmpfile): - print(f"Using existing local {filename}") - return tmpfile - - # -L follows redirects, necessary for GH asset downloads - asset_url = ASSET_BASE_URL + filename - proc = subprocess.run(["curl", "-L", "-o", tmpfile, asset_url]) - if proc.returncode != 0: - print(f"Error downloading {filename}") - print(proc.stderr) - return None - - proc = subprocess.run(["chmod", "a+x", tmpfile]) - if proc.returncode != 0: - print(f"Error modifying permissions on {filename}") - print(proc.stderr) - abort_cleanup(filename) - return None - + """Fetches the nightly heir-opt binary from GitHub and returns the path to it.""" + print(f"Loading {filename} nightly binary") + # TODO: how to clean up the tmpdir after ipython closes? + # At worst, the user will see this in their heir_play dir and delete it. + cwd = Path(dir=os.getcwd()) + tmpfile = cwd / filename + if os.path.isfile(tmpfile): + print(f"Using existing local {filename}") return tmpfile + + # -L follows redirects, necessary for GH asset downloads + asset_url = ASSET_BASE_URL + filename + proc = subprocess.run(["curl", "-L", "-o", tmpfile, asset_url]) + if proc.returncode != 0: + print(f"Error downloading {filename}") + print(proc.stderr) + return None + + proc = subprocess.run(["chmod", "a+x", tmpfile]) + if proc.returncode != 0: + print(f"Error modifying permissions on {filename}") + print(proc.stderr) + abort_cleanup(filename) + return None + + return tmpfile diff --git a/scripts/lit_to_bazel.py b/scripts/lit_to_bazel.py index 19e6f254f..124619736 100644 --- a/scripts/lit_to_bazel.py +++ b/scripts/lit_to_bazel.py @@ -13,102 +13,102 @@ def strip_run_prefix(line): - if RUN_PREFIX in line: - return line.split(RUN_PREFIX)[1] - return line + if RUN_PREFIX in line: + return line.split(RUN_PREFIX)[1] + return line def convert_to_run_commands(run_lines): - run_lines = deque(run_lines) - cmds = [] + run_lines = deque(run_lines) + cmds = [] + current_command = "" + while run_lines: + line = run_lines.popleft() + if RUN_PREFIX not in line: + continue + + line = strip_run_prefix(line) + + if "|" in line: + first, second = line.split("|", maxsplit=1) + current_command += " " + first.strip() + cmds.append(current_command.strip()) + current_command = "" + cmds.append(PIPE) + run_lines.appendleft(RUN_PREFIX + " " + second.strip()) + continue + + # redirecting to a file implicitly ends the command on that line + if OUT_REDIRECT in line or IN_REDIRECT in line: + cmds.append(line.strip()) + current_command = "" + continue + + if line.strip().endswith("\\"): + current_command += " " + line.replace("\\", "").strip() + continue + + current_command += line + cmds.append(current_command.strip()) current_command = "" - while run_lines: - line = run_lines.popleft() - if RUN_PREFIX not in line: - continue - line = strip_run_prefix(line) - - if "|" in line: - first, second = line.split("|", maxsplit=1) - current_command += " " + first.strip() - cmds.append(current_command.strip()) - current_command = "" - cmds.append(PIPE) - run_lines.appendleft(RUN_PREFIX + " " + second.strip()) - continue - - # redirecting to a file implicitly ends the command on that line - if OUT_REDIRECT in line or IN_REDIRECT in line: - cmds.append(line.strip()) - current_command = "" - continue - - if line.strip().endswith("\\"): - current_command += " " + line.replace("\\", "").strip() - continue - - current_command += line - cmds.append(current_command.strip()) - current_command = "" - - return cmds + return cmds def lit_to_bazel( lit_test_file: str, git_root: str = "", ): - """A helper CLI that converts MLIR test files to bazel run commands. - - Args: - lit_test_file: The lit test file that should be converted to a bazel run - command. - """ - - if not git_root: - git_root = pathlib.Path(__file__).parent.parent - if not os.path.isdir(git_root / ".git"): - raise RuntimeError(f"Could not find git root, looked at {git_root}") - # if git root is manually specified, just trust it - - if not lit_test_file: - raise ValueError("lit_test_file must be provided") - - if not os.path.isfile(lit_test_file): - raise ValueError("Unable to find lit_test_file '%s'" % lit_test_file) - - run_lines = [] - with open(lit_test_file, "r") as f: - for line in f: - if "// RUN:" in line: - run_lines.append(line) - - commands = convert_to_run_commands(run_lines) - commands = [x for x in commands if "FileCheck" not in x] - # remove consecutive and trailing pipes - if commands[-1] == PIPE: - commands.pop() - deduped_commands = [] - for command in commands: - if command == PIPE and deduped_commands[-1] == PIPE: - continue - deduped_commands.append(command) - - joined = " ".join(deduped_commands) - # I would consider using bazel-bin/tools/heir-opt, but the yosys - # requirement requires additional env vars to be set for the yosys and ABC - # paths, which is not yet worth doing for this script. - joined = joined.replace( - "heir-opt", - "bazel run --noallow_analysis_cache_discard //tools:heir-opt --", - ) - joined = joined.replace( - "heir-translate", f"{git_root}/bazel-bin/tools/heir-translate" - ) - joined = joined.replace("%s", str(pathlib.Path(lit_test_file).absolute())) - print(joined) + """A helper CLI that converts MLIR test files to bazel run commands. + + Args: + lit_test_file: The lit test file that should be converted to a bazel run + command. + """ + + if not git_root: + git_root = pathlib.Path(__file__).parent.parent + if not os.path.isdir(git_root / ".git"): + raise RuntimeError(f"Could not find git root, looked at {git_root}") + # if git root is manually specified, just trust it + + if not lit_test_file: + raise ValueError("lit_test_file must be provided") + + if not os.path.isfile(lit_test_file): + raise ValueError("Unable to find lit_test_file '%s'" % lit_test_file) + + run_lines = [] + with open(lit_test_file, "r") as f: + for line in f: + if "// RUN:" in line: + run_lines.append(line) + + commands = convert_to_run_commands(run_lines) + commands = [x for x in commands if "FileCheck" not in x] + # remove consecutive and trailing pipes + if commands[-1] == PIPE: + commands.pop() + deduped_commands = [] + for command in commands: + if command == PIPE and deduped_commands[-1] == PIPE: + continue + deduped_commands.append(command) + + joined = " ".join(deduped_commands) + # I would consider using bazel-bin/tools/heir-opt, but the yosys + # requirement requires additional env vars to be set for the yosys and ABC + # paths, which is not yet worth doing for this script. + joined = joined.replace( + "heir-opt", + "bazel run --noallow_analysis_cache_discard //tools:heir-opt --", + ) + joined = joined.replace( + "heir-translate", f"{git_root}/bazel-bin/tools/heir-translate" + ) + joined = joined.replace("%s", str(pathlib.Path(lit_test_file).absolute())) + print(joined) if __name__ == "__main__": - fire.Fire(lit_to_bazel) + fire.Fire(lit_to_bazel) diff --git a/scripts/templates/templates.py b/scripts/templates/templates.py index 1c2e1f1d7..9cdcc8be3 100644 --- a/scripts/templates/templates.py +++ b/scripts/templates/templates.py @@ -9,344 +9,340 @@ def render_all(path: pathlib.Path, **args): - env = jinja2.Environment(loader=jinja2.FileSystemLoader(path)) - for template_filename in os.listdir(path): - template = env.get_template(template_filename) - content = template.render(**args) - with open(path / template_filename, mode="w") as outfile: - outfile.write(content) - print(f"Rendered template for {path / template_filename}") + env = jinja2.Environment(loader=jinja2.FileSystemLoader(path)) + for template_filename in os.listdir(path): + template = env.get_template(template_filename) + content = template.render(**args) + with open(path / template_filename, mode="w") as outfile: + outfile.write(content) + print(f"Rendered template for {path / template_filename}") def try_create_dirs(lib_path, force=False): - print(f"Creating dirs:\n {lib_path}") - try: - os.makedirs(lib_path) - except FileExistsError: - if force: - shutil.rmtree(lib_path) - os.mkdir(lib_path) - else: - raise + print(f"Creating dirs:\n {lib_path}") + try: + os.makedirs(lib_path) + except FileExistsError: + if force: + shutil.rmtree(lib_path) + os.mkdir(lib_path) + else: + raise def copy_all(filepath_mapping): - for src, dest in filepath_mapping.items(): - shutil.copy(src, dest) + for src, dest in filepath_mapping.items(): + shutil.copy(src, dest) class CLI: - """A helper CLI for generating boilerplate MLIR code in HEIR. + """A helper CLI for generating boilerplate MLIR code in HEIR. + + Available subcommands: + + new_conversion_pass: Create a conversion pass from one dialect to another. + new_dialect_transform: Create a pass for a dialect-specific transform. + new_dialect: Create a new dialect. + new_transform: Create a pass for a non-dialect-specific transform. + + To see the help for a subcommand, run + + python scripts/templates/templates.py --help + """ + + def __init__(self): + git_root = pathlib.Path(__file__).parent.parent.parent + if not os.path.isdir(git_root / ".git"): + raise RuntimeError(f"Could not find git root, looked at {git_root}") + self.root = git_root + + def new_conversion_pass( + self, + pass_name: str = None, + source_dialect_name: str = None, + source_dialect_namespace: str = None, + source_dialect_mnemonic: str = None, + target_dialect_name: str = None, + target_dialect_namespace: str = None, + target_dialect_mnemonic: str = None, + force: bool = False, + ): + """Create a new conversion pass. + + Args: + pass_name: The CPP class name and directory name for the conversion + pass, e.g., BGVToLWE + source_dialect_name: The source dialect's CPP class name prefix and + directory name, e.g., CGGI (for CGGIDialect) + source_dialect_namespace: The source dialect's CPP namespace, e.g., + tfhe_rust for TfheRustDialect + source_dialect_mnemonic: The source dialect's mnemonic, e.g., cggi + target_dialect_name: The target dialect's CPP class name prefix and + directory name, e.g., CGGI (for CGGIDialect) + target_dialect_namespace: The target dialect's CPP namespace, e.g., + tfhe_rust for TfheRustDialect + target_dialect_mnemonic: The target dialect's mnemonic, e.g., cggi + force: If True, overwrite existing files. If False, raise an error if + any files already exist. + """ + if not source_dialect_name: + raise ValueError("source_dialect_name must be provided") + if not target_dialect_name: + raise ValueError("target_dialect_name must be provided") + + if not pass_name: + pass_name = f"{source_dialect_name}To{target_dialect_name}" + + # These defaults could be smarter: look up the name in the actual + # tablegen for the dialect or quit if it can't be found + if not source_dialect_mnemonic: + source_dialect_mnemonic = source_dialect_name.lower() + + if not source_dialect_namespace: + source_dialect_namespace = source_dialect_mnemonic + + if not target_dialect_mnemonic: + target_dialect_mnemonic = target_dialect_name.lower() + + if not target_dialect_namespace: + target_dialect_namespace = target_dialect_mnemonic + + lib_path = ( + self.root + / "lib" + / "Dialect" + / source_dialect_name + / "Conversions" + / pass_name + ) + + if not force and os.path.isdir(lib_path): + raise ValueError( + f"Conversion pass directories already exist at {lib_path}" + ) + + templates_path = self.root / "scripts" / "templates" / "Conversion" + templ_lib = templates_path / "lib" + path_mapping = { + templ_lib / "ConversionPass.h.jinja": lib_path / f"{pass_name}.h", + templ_lib / "ConversionPass.td.jinja": lib_path / f"{pass_name}.td", + templ_lib / "BUILD.jinja": lib_path / "BUILD", + templ_lib / "ConversionPass.cpp.jinja": lib_path / f"{pass_name}.cpp", + } - Available subcommands: + try: + try_create_dirs(lib_path, force) + copy_all(path_mapping) + + render_all( + lib_path, + pass_name=pass_name, + source_dialect_name=source_dialect_name, + source_dialect_namespace=source_dialect_namespace, + source_dialect_mnemonic=source_dialect_mnemonic, + target_dialect_name=target_dialect_name, + target_dialect_namespace=target_dialect_namespace, + target_dialect_mnemonic=target_dialect_mnemonic, + ) + except: + print("Hit unrecoverable error, cleaning up") + shutil.rmtree(lib_path) + raise + + def new_dialect_transform( + self, + pass_name: str = None, + pass_flag: str = None, + dialect_name: str = None, + dialect_namespace: str = None, + force: bool = False, + ): + """Create a new pass for a dialect-specific transform. + + Args: + pass_name: The CPP class name for the pass, e.g., ForgetSecrets. + pass_flag: The CLI flag to use for the pass (optional). + dialect_name: The dialect's CPP class name prefix and directory name, + e.g., CGGI (for CGGIDialect). + dialect_namespace: The dialect's CPP namespace, e.g., tfhe_rust for + TfheRustDialect. + force: If True, overwrite existing files. If False, raise an error if + any files already exist. + """ + if not pass_name: + raise ValueError("pass_name must be provided") + if not dialect_name: + raise ValueError("dialect_name must be provided") + + if not pass_flag: + pass_flag = f"{dialect_name.lower()}-{pass_name.lower()}" + + # Default could be smarter: look up the name in the actual tablegen for the + # dialect or quit if it can't be found + if not dialect_namespace: + dialect_namespace = dialect_name.lower() + + lib_path = self.root / "lib" / "Dialect" / dialect_name / "Transforms" + + if not force and os.path.isdir(lib_path): + raise ValueError(f"Pass directories already exist at {lib_path}") + + templates_path = self.root / "scripts" / "templates" / "DialectTransforms" + templ_lib = templates_path / "lib" + path_mapping = { + templ_lib / "BUILD.jinja": lib_path / "BUILD", + templ_lib / "Pass.cpp.jinja": lib_path / f"{pass_name}.cpp", + templ_lib / "Pass.h.jinja": lib_path / f"{pass_name}.h", + templ_lib / "Passes.h.jinja": lib_path / "Passes.h", + templ_lib / "Passes.td.jinja": lib_path / "Passes.td", + } - new_conversion_pass: Create a conversion pass from one dialect to another. - new_dialect_transform: Create a pass for a dialect-specific transform. - new_dialect: Create a new dialect. - new_transform: Create a pass for a non-dialect-specific transform. + try: + try_create_dirs(lib_path, force) + copy_all(path_mapping) + render_all( + lib_path, + pass_name=pass_name, + pass_flag=pass_flag, + dialect_name=dialect_name, + dialect_namespace=dialect_namespace, + ) + except: + print("Hit unrecoverable error, cleaning up") + shutil.rmtree(lib_path) + raise + + def new_transform( + self, + pass_name: str = None, + pass_flag: str = None, + force: bool = False, + ): + """Create a new pass for a dialect-specific transform. + + Args: + pass_name: The CPP class name for the pass, e.g., ForgetSecrets. + pass_flag: The CLI flag to use for the pass (optional). + force: If True, overwrite existing files. If False, raise an error if + any files already exist. + """ + if not pass_name: + raise ValueError("pass_name must be provided") - To see the help for a subcommand, run + if not pass_flag: + pass_name = f"{pass_name.lower()}" - python scripts/templates/templates.py --help + lib_path = self.root / "lib" / "Transforms" / pass_name + + if not force and os.path.isdir(lib_path): + raise ValueError(f"Pass directories already exist at {lib_path}") + + templates_path = self.root / "scripts" / "templates" / "Transforms" + templ_lib = templates_path / "lib" + path_mapping = { + templ_lib / "BUILD.jinja": lib_path / "BUILD", + templ_lib / "Pass.cpp.jinja": lib_path / f"{pass_name}.cpp", + templ_lib / "Pass.h.jinja": lib_path / f"{pass_name}.h", + templ_lib / "Pass.td.jinja": lib_path / f"{pass_name}.td", + } + + try: + try_create_dirs(lib_path, force) + copy_all(path_mapping) + render_all( + lib_path, + pass_name=pass_name, + pass_flag=pass_flag, + ) + except: + print("Hit unrecoverable error, cleaning up") + shutil.rmtree(lib_path) + raise + + def new_dialect( + self, + dialect_name: str = None, + dialect_namespace: str = None, + enable_attributes: bool = True, + enable_types: bool = True, + enable_ops: bool = True, + force: bool = False, + ): + """Create a new dialect. + + Args: + dialect_name: The dialect's CPP class name prefix and directory name, + e.g., CGGI (for CGGIDialect). + dialect_namespace: The dialect's CPP namespace, e.g., tfhe_rust for + TfheRustDialect. + enable_attributes: Generate a separate tablegen and includes for + attributes. + enable_types: Generate a separate tablegen and includes for types. + enable_ops: Generate a separate tablegen and includes for ops. + force: If True, overwrite existing files. If False, raise an error if + any files already exist. """ + if not dialect_name: + raise ValueError("dialect_name must be provided") + + if not dialect_namespace: + dialect_namespace = dialect_name.lower() + + lib_path = self.root / "lib" / "Dialect" / dialect_name / "IR" + + if not force and os.path.isdir(lib_path): + raise ValueError(f"Dialect directories already exist at {lib_path}") + + templates_path = self.root / "scripts" / "templates" / "Dialect" + templ_lib = templates_path / "lib" + path_mapping = { + templ_lib / "BUILD.jinja": lib_path / "BUILD", + templ_lib + / "Dialect.cpp.jinja": lib_path / f"{dialect_name}Dialect.cpp", + templ_lib / "Dialect.h.jinja": lib_path / f"{dialect_name}Dialect.h", + templ_lib / "Dialect.td.jinja": lib_path / f"{dialect_name}Dialect.td", + } + + if enable_attributes: + path_mapping.update({ + templ_lib + / "Attributes.h.jinja": lib_path / f"{dialect_name}Attributes.h", + templ_lib + / "Attributes.td.jinja": lib_path / f"{dialect_name}Attributes.td", + templ_lib + / "Attributes.cpp.jinja": lib_path / f"{dialect_name}Attributes.cpp", + }) + + if enable_types: + path_mapping.update({ + templ_lib / "Types.h.jinja": lib_path / f"{dialect_name}Types.h", + templ_lib / "Types.td.jinja": lib_path / f"{dialect_name}Types.td", + templ_lib / "Types.cpp.jinja": lib_path / f"{dialect_name}Types.cpp", + }) + + if enable_ops: + path_mapping.update({ + templ_lib / "Ops.h.jinja": lib_path / f"{dialect_name}Ops.h", + templ_lib / "Ops.td.jinja": lib_path / f"{dialect_name}Ops.td", + templ_lib / "Ops.cpp.jinja": lib_path / f"{dialect_name}Ops.cpp", + }) - def __init__(self): - git_root = pathlib.Path(__file__).parent.parent.parent - if not os.path.isdir(git_root / ".git"): - raise RuntimeError(f"Could not find git root, looked at {git_root}") - self.root = git_root - - def new_conversion_pass( - self, - pass_name: str = None, - source_dialect_name: str = None, - source_dialect_namespace: str = None, - source_dialect_mnemonic: str = None, - target_dialect_name: str = None, - target_dialect_namespace: str = None, - target_dialect_mnemonic: str = None, - force: bool = False, - ): - """Create a new conversion pass. - - Args: - pass_name: The CPP class name and directory name for the conversion - pass, e.g., BGVToLWE - source_dialect_name: The source dialect's CPP class name prefix and - directory name, e.g., CGGI (for CGGIDialect) - source_dialect_namespace: The source dialect's CPP namespace, e.g., - tfhe_rust for TfheRustDialect - source_dialect_mnemonic: The source dialect's mnemonic, e.g., cggi - target_dialect_name: The target dialect's CPP class name prefix and - directory name, e.g., CGGI (for CGGIDialect) - target_dialect_namespace: The target dialect's CPP namespace, e.g., - tfhe_rust for TfheRustDialect - target_dialect_mnemonic: The target dialect's mnemonic, e.g., cggi - force: If True, overwrite existing files. If False, raise an error if - any files already exist. - """ - if not source_dialect_name: - raise ValueError("source_dialect_name must be provided") - if not target_dialect_name: - raise ValueError("target_dialect_name must be provided") - - if not pass_name: - pass_name = f"{source_dialect_name}To{target_dialect_name}" - - # These defaults could be smarter: look up the name in the actual - # tablegen for the dialect or quit if it can't be found - if not source_dialect_mnemonic: - source_dialect_mnemonic = source_dialect_name.lower() - - if not source_dialect_namespace: - source_dialect_namespace = source_dialect_mnemonic - - if not target_dialect_mnemonic: - target_dialect_mnemonic = target_dialect_name.lower() - - if not target_dialect_namespace: - target_dialect_namespace = target_dialect_mnemonic - - lib_path = self.root / "lib" / "Dialect" / \ - source_dialect_name / "Conversions" / pass_name - - if not force and os.path.isdir(lib_path): - raise ValueError(f"Conversion pass directories already exist at {lib_path}") - - templates_path = self.root / "scripts" / "templates" / "Conversion" - templ_lib = templates_path / "lib" - path_mapping = { - templ_lib / "ConversionPass.h.jinja": lib_path / f"{pass_name}.h", - templ_lib / "ConversionPass.td.jinja": lib_path / f"{pass_name}.td", - templ_lib / "BUILD.jinja": lib_path / "BUILD", - templ_lib / "ConversionPass.cpp.jinja": lib_path / f"{pass_name}.cpp", - } - - try: - try_create_dirs(lib_path, force) - copy_all(path_mapping) - - render_all( - lib_path, - pass_name=pass_name, - source_dialect_name=source_dialect_name, - source_dialect_namespace=source_dialect_namespace, - source_dialect_mnemonic=source_dialect_mnemonic, - target_dialect_name=target_dialect_name, - target_dialect_namespace=target_dialect_namespace, - target_dialect_mnemonic=target_dialect_mnemonic, - ) - except: - print("Hit unrecoverable error, cleaning up") - shutil.rmtree(lib_path) - raise - - def new_dialect_transform( - self, - pass_name: str = None, - pass_flag: str = None, - dialect_name: str = None, - dialect_namespace: str = None, - force: bool = False, - ): - """Create a new pass for a dialect-specific transform. - - Args: - pass_name: The CPP class name for the pass, e.g., ForgetSecrets. - pass_flag: The CLI flag to use for the pass (optional). - dialect_name: The dialect's CPP class name prefix and directory name, - e.g., CGGI (for CGGIDialect). - dialect_namespace: The dialect's CPP namespace, e.g., tfhe_rust for - TfheRustDialect. - force: If True, overwrite existing files. If False, raise an error if - any files already exist. - """ - if not pass_name: - raise ValueError("pass_name must be provided") - if not dialect_name: - raise ValueError("dialect_name must be provided") - - if not pass_flag: - pass_flag = f"{dialect_name.lower()}-{pass_name.lower()}" - - # Default could be smarter: look up the name in the actual tablegen for the - # dialect or quit if it can't be found - if not dialect_namespace: - dialect_namespace = dialect_name.lower() - - lib_path = self.root / "lib" / "Dialect" / dialect_name / "Transforms" - - if not force and os.path.isdir(lib_path): - raise ValueError(f"Pass directories already exist at {lib_path}") - - templates_path = self.root / "scripts" / "templates" / "DialectTransforms" - templ_lib = templates_path / "lib" - path_mapping = { - templ_lib / "BUILD.jinja": lib_path / "BUILD", - templ_lib / "Pass.cpp.jinja": lib_path / f"{pass_name}.cpp", - templ_lib / "Pass.h.jinja": lib_path / f"{pass_name}.h", - templ_lib / "Passes.h.jinja": lib_path / "Passes.h", - templ_lib / "Passes.td.jinja": lib_path / "Passes.td", - } - - try: - try_create_dirs(lib_path, force) - copy_all(path_mapping) - render_all( - lib_path, - pass_name=pass_name, - pass_flag=pass_flag, - dialect_name=dialect_name, - dialect_namespace=dialect_namespace, - ) - except: - print("Hit unrecoverable error, cleaning up") - shutil.rmtree(lib_path) - raise - - def new_transform( - self, - pass_name: str = None, - pass_flag: str = None, - force: bool = False, - ): - """Create a new pass for a dialect-specific transform. - - Args: - pass_name: The CPP class name for the pass, e.g., ForgetSecrets. - pass_flag: The CLI flag to use for the pass (optional). - force: If True, overwrite existing files. If False, raise an error if - any files already exist. - """ - if not pass_name: - raise ValueError("pass_name must be provided") - - if not pass_flag: - pass_name = f"{pass_name.lower()}" - - lib_path = self.root / "lib" / "Transforms" / pass_name - - if not force and os.path.isdir(lib_path): - raise ValueError(f"Pass directories already exist at {lib_path}") - - templates_path = self.root / "scripts" / "templates" / "Transforms" - templ_lib = templates_path / "lib" - path_mapping = { - templ_lib / "BUILD.jinja": lib_path / "BUILD", - templ_lib / "Pass.cpp.jinja": lib_path / f"{pass_name}.cpp", - templ_lib / "Pass.h.jinja": lib_path / f"{pass_name}.h", - templ_lib / "Pass.td.jinja": lib_path / f"{pass_name}.td", - } - - try: - try_create_dirs(lib_path, force) - copy_all(path_mapping) - render_all( - lib_path, - pass_name=pass_name, - pass_flag=pass_flag, - ) - except: - print("Hit unrecoverable error, cleaning up") - shutil.rmtree(lib_path) - raise - - def new_dialect( - self, - dialect_name: str = None, - dialect_namespace: str = None, - enable_attributes: bool = True, - enable_types: bool = True, - enable_ops: bool = True, - force: bool = False, - ): - """Create a new dialect. - - Args: - dialect_name: The dialect's CPP class name prefix and directory name, - e.g., CGGI (for CGGIDialect). - dialect_namespace: The dialect's CPP namespace, e.g., tfhe_rust for - TfheRustDialect. - enable_attributes: Generate a separate tablegen and includes for - attributes. - enable_types: Generate a separate tablegen and includes for types. - enable_ops: Generate a separate tablegen and includes for ops. - force: If True, overwrite existing files. If False, raise an error if - any files already exist. - """ - if not dialect_name: - raise ValueError("dialect_name must be provided") - - if not dialect_namespace: - dialect_namespace = dialect_name.lower() - - lib_path = self.root / "lib" / "Dialect" / dialect_name / "IR" - - if not force and os.path.isdir(lib_path): - raise ValueError( - f"Dialect directories already exist at {lib_path}") - - templates_path = self.root / "scripts" / "templates" / "Dialect" - templ_lib = templates_path / "lib" - path_mapping = { - templ_lib / "BUILD.jinja": lib_path / "BUILD", - templ_lib / "Dialect.cpp.jinja": lib_path / f"{dialect_name}Dialect.cpp", - templ_lib / "Dialect.h.jinja": lib_path / f"{dialect_name}Dialect.h", - templ_lib / "Dialect.td.jinja": lib_path / f"{dialect_name}Dialect.td", - } - - if enable_attributes: - path_mapping.update( - { - templ_lib - / "Attributes.h.jinja": lib_path - / f"{dialect_name}Attributes.h", - templ_lib - / "Attributes.td.jinja": ( - lib_path / f"{dialect_name}Attributes.td" - ), - templ_lib - / "Attributes.cpp.jinja": lib_path - / f"{dialect_name}Attributes.cpp", - } - ) - - if enable_types: - path_mapping.update( - { - templ_lib / "Types.h.jinja": lib_path / f"{dialect_name}Types.h", - templ_lib / "Types.td.jinja": lib_path / f"{dialect_name}Types.td", - templ_lib - / "Types.cpp.jinja": lib_path - / f"{dialect_name}Types.cpp", - } - ) - - if enable_ops: - path_mapping.update( - { - templ_lib / "Ops.h.jinja": lib_path / f"{dialect_name}Ops.h", - templ_lib / "Ops.td.jinja": lib_path / f"{dialect_name}Ops.td", - templ_lib / "Ops.cpp.jinja": lib_path / f"{dialect_name}Ops.cpp", - } - ) - - try: - try_create_dirs(lib_path, force) - copy_all(path_mapping) - render_all( - lib_path, - dialect_name=dialect_name, - dialect_namespace=dialect_namespace, - enable_attributes=enable_attributes, - enable_types=enable_types, - enable_ops=enable_ops, - ) - except: - print("Hit unrecoverable error, cleaning up") - shutil.rmtree(lib_path) - raise + try: + try_create_dirs(lib_path, force) + copy_all(path_mapping) + render_all( + lib_path, + dialect_name=dialect_name, + dialect_namespace=dialect_namespace, + enable_attributes=enable_attributes, + enable_types=enable_types, + enable_ops=enable_ops, + ) + except: + print("Hit unrecoverable error, cleaning up") + shutil.rmtree(lib_path) + raise if __name__ == "__main__": - fire.Fire(CLI) + fire.Fire(CLI) diff --git a/scripts/test_lit_to_bazel.py b/scripts/test_lit_to_bazel.py index c1bf82af5..5dd767b0b 100644 --- a/scripts/test_lit_to_bazel.py +++ b/scripts/test_lit_to_bazel.py @@ -2,77 +2,83 @@ def test_convert_to_run_commands_simple(): - run_lines = [ - "// RUN: heir-opt --canonicalize", - ] - assert convert_to_run_commands(run_lines) == [ - "heir-opt --canonicalize", - ] + run_lines = [ + "// RUN: heir-opt --canonicalize", + ] + assert convert_to_run_commands(run_lines) == [ + "heir-opt --canonicalize", + ] + def test_convert_to_run_commands_simple_with_filecheck(): - run_lines = [ - "// RUN: heir-opt --canonicalize | FileCheck %s", - ] - assert convert_to_run_commands(run_lines) == [ - "heir-opt --canonicalize", - PIPE, - "FileCheck %s", - ] + run_lines = [ + "// RUN: heir-opt --canonicalize | FileCheck %s", + ] + assert convert_to_run_commands(run_lines) == [ + "heir-opt --canonicalize", + PIPE, + "FileCheck %s", + ] + def test_convert_to_run_commands_simple_with_line_continuation(): - run_lines = [ - "// RUN: heir-opt \\", - "// RUN: --canonicalize | FileCheck %s", - ] - assert convert_to_run_commands(run_lines) == [ - "heir-opt --canonicalize", - PIPE, - "FileCheck %s", - ] + run_lines = [ + "// RUN: heir-opt \\", + "// RUN: --canonicalize | FileCheck %s", + ] + assert convert_to_run_commands(run_lines) == [ + "heir-opt --canonicalize", + PIPE, + "FileCheck %s", + ] + def test_convert_to_run_commands_simple_with_multiple_line_continuations(): - run_lines = [ - "// RUN: heir-opt \\", - "// RUN: --canonicalize \\", - "// RUN: --cse | FileCheck %s", - ] - assert convert_to_run_commands(run_lines) == [ - "heir-opt --canonicalize --cse", - PIPE, - "FileCheck %s", - ] + run_lines = [ + "// RUN: heir-opt \\", + "// RUN: --canonicalize \\", + "// RUN: --cse | FileCheck %s", + ] + assert convert_to_run_commands(run_lines) == [ + "heir-opt --canonicalize --cse", + PIPE, + "FileCheck %s", + ] + def test_convert_to_run_commands_simple_with_second_command(): - run_lines = [ - "// RUN: heir-opt --canonicalize > %t", - "// RUN: FileCheck %s < %t", - ] - assert convert_to_run_commands(run_lines) == [ - "heir-opt --canonicalize > %t", - "FileCheck %s < %t", - ] + run_lines = [ + "// RUN: heir-opt --canonicalize > %t", + "// RUN: FileCheck %s < %t", + ] + assert convert_to_run_commands(run_lines) == [ + "heir-opt --canonicalize > %t", + "FileCheck %s < %t", + ] + def test_convert_to_run_commands_simple_with_non_run_garbage(): - run_lines = [ - "// RUN: heir-opt --canonicalize > %t", - "// wat", - "// RUN: FileCheck %s < %t", - ] - assert convert_to_run_commands(run_lines) == [ - "heir-opt --canonicalize > %t", - "FileCheck %s < %t", - ] + run_lines = [ + "// RUN: heir-opt --canonicalize > %t", + "// wat", + "// RUN: FileCheck %s < %t", + ] + assert convert_to_run_commands(run_lines) == [ + "heir-opt --canonicalize > %t", + "FileCheck %s < %t", + ] + def test_convert_to_run_commands_with_multiple_pipes(): - run_lines = [ - "// RUN: heir-opt --canonicalize \\", - "// RUN: | heir-translate --emit-verilog \\", - "// RUN: | FileCheck %s", - ] - assert convert_to_run_commands(run_lines) == [ - "heir-opt --canonicalize", - PIPE, - "heir-translate --emit-verilog", - PIPE, - "FileCheck %s", - ] + run_lines = [ + "// RUN: heir-opt --canonicalize \\", + "// RUN: | heir-translate --emit-verilog \\", + "// RUN: | FileCheck %s", + ] + assert convert_to_run_commands(run_lines) == [ + "heir-opt --canonicalize", + PIPE, + "heir-translate --emit-verilog", + PIPE, + "FileCheck %s", + ] diff --git a/tests/Dialect/Polynomial/Conversions/heir_polynomial_to_llvm/runner/generate_test_cases.py b/tests/Dialect/Polynomial/Conversions/heir_polynomial_to_llvm/runner/generate_test_cases.py index cc618e1b2..22b881808 100644 --- a/tests/Dialect/Polynomial/Conversions/heir_polynomial_to_llvm/runner/generate_test_cases.py +++ b/tests/Dialect/Polynomial/Conversions/heir_polynomial_to_llvm/runner/generate_test_cases.py @@ -157,7 +157,7 @@ def parse_to_sympy(poly_str: str, var: sympy.Symbol, cmod: int): def make_coset_regex(x, cmod): """Return a regex that matches x or x +/- cmod.""" if x == 0: - return '0' + return '0' if x < 0: return '{{' + f'({x}|{cmod + x})' + '}}' return '{{' + f'({x}|{x - cmod})' + '}}' @@ -223,7 +223,9 @@ def main(args: argparse.Namespace) -> None: # positive or negative one. # This is because I can't seem to nail down how remsi instructions produce # an output. - expected_coeffs = [make_coset_regex(coeff, cmod) for coeff in expected_coeffs] + expected_coeffs = [ + make_coset_regex(coeff, cmod) for coeff in expected_coeffs + ] coefficient_list_regex = ', '.join(expected_coeffs) if len(expected_coeffs) < coeff_list_len: diff --git a/tests/Examples/jaxite/add_one_lut3_test.py b/tests/Examples/jaxite/add_one_lut3_test.py index 59df33800..afb6cf5c0 100644 --- a/tests/Examples/jaxite/add_one_lut3_test.py +++ b/tests/Examples/jaxite/add_one_lut3_test.py @@ -6,20 +6,21 @@ class AddOneLut3Test(absltest.TestCase): - def test_add_one(self): - x = 5 - lwe_rng, boolean_params, cks, sks = test_utils.setup_test_params() - ciphertext_x = test_utils.encrypt_u8(x, cks, lwe_rng) - result_ciphertext = add_one_lut3_lib.test_add_one_lut3( - ciphertext_x, - sks, - boolean_params, - ) + def test_add_one(self): + x = 5 + lwe_rng, boolean_params, cks, sks = test_utils.setup_test_params() + ciphertext_x = test_utils.encrypt_u8(x, cks, lwe_rng) - result = test_utils.decrypt_u8(result_ciphertext, cks) - self.assertEqual(x + 1, result) + result_ciphertext = add_one_lut3_lib.test_add_one_lut3( + ciphertext_x, + sks, + boolean_params, + ) + + result = test_utils.decrypt_u8(result_ciphertext, cks) + self.assertEqual(x + 1, result) if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/tests/Examples/jaxite/test_utils.py b/tests/Examples/jaxite/test_utils.py index f8d34cccf..54b7448ed 100644 --- a/tests/Examples/jaxite/test_utils.py +++ b/tests/Examples/jaxite/test_utils.py @@ -1,4 +1,5 @@ """A demonstration of adding 1 to a number in FHE.""" + from typing import Any from jaxite.jaxite_bool import bool_params diff --git a/tests/lit.cfg.py b/tests/lit.cfg.py index 15e147fd1..b0bf6b874 100644 --- a/tests/lit.cfg.py +++ b/tests/lit.cfg.py @@ -40,9 +40,9 @@ "at_clifford_yosys", ] -CMAKE_HEIR_PATH = os.environ.get("CMAKE_HEIR_PATH","") +CMAKE_HEIR_PATH = os.environ.get("CMAKE_HEIR_PATH", "") if CMAKE_HEIR_PATH: - CMAKE_HEIR_PATH = ":"+CMAKE_HEIR_PATH + CMAKE_HEIR_PATH = ":" + CMAKE_HEIR_PATH config.environment["PATH"] = ( ":".join(str(runfiles_dir.joinpath(Path(path))) for path in tool_relpaths) + CMAKE_HEIR_PATH @@ -51,12 +51,12 @@ ) abc_relpath = "edu_berkeley_abc/abc" -config.environment["HEIR_ABC_BINARY"] = ( - str(runfiles_dir.joinpath(Path(abc_relpath))) +config.environment["HEIR_ABC_BINARY"] = str( + runfiles_dir.joinpath(Path(abc_relpath)) ) yosys_libs = "heir/lib/Transforms/YosysOptimizer/yosys" -config.environment["HEIR_YOSYS_SCRIPTS_DIR"] = ( - str(runfiles_dir.joinpath(Path(yosys_libs))) +config.environment["HEIR_YOSYS_SCRIPTS_DIR"] = str( + runfiles_dir.joinpath(Path(yosys_libs)) ) # Some tests that use mlir-runner need access to additional shared libs to