From 12d985664efd4c325e479a30fb7bfdf82ee6fdf4 Mon Sep 17 00:00:00 2001
From: Sanjiv Das
Date: Thu, 19 Dec 2024 16:43:03 -0800
Subject: [PATCH] Update `/generate` to not split classes & functions across cells (#1158)

* Update to ensure no hanging code cells in generated notebooks

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update generate.py

* Update generate.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update generate.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../jupyter_ai/chat_handlers/generate.py      | 30 +++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
index a69b5ed28..6318e0979 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
@@ -1,3 +1,4 @@
+import ast
 import asyncio
 import os
 import time
@@ -198,6 +199,15 @@ async def afill_outline(outline, llm, verbose=False):
     await asyncio.gather(*all_coros)
 
 
+# Check if the content of the cell is python code or not
+def is_not_python_code(source: str) -> bool:
+    try:
+        ast.parse(source)
+    except:
+        return True
+    return False
+
+
 def create_notebook(outline):
     """Create an nbformat Notebook object for a notebook outline."""
     nbf = nbformat.v4
@@ -212,6 +222,26 @@ def create_notebook(outline):
         nb["cells"].append(nbf.new_markdown_cell("## " + section["title"]))
         for code_block in section["code"].split("\n\n"):
             nb["cells"].append(nbf.new_code_cell(code_block))
+
+    # Post process notebook for hanging code cells: merge hanging cell with the previous cell
+    merged_cells = []
+    for cell in nb["cells"]:
+        # Fix a hanging code cell
+        follows_code_cell = merged_cells and merged_cells[-1]["cell_type"] == "code"
+        is_incomplete = cell["cell_type"] == "code" and cell["source"].startswith(" ")
+        if follows_code_cell and is_incomplete:
+            merged_cells[-1]["source"] = (
+                merged_cells[-1]["source"] + "\n\n" + cell["source"]
+            )
+        else:
+            merged_cells.append(cell)
+
+    # Fix code cells that should be markdown
+    for cell in merged_cells:
+        if cell["cell_type"] == "code" and is_not_python_code(cell["source"]):
+            cell["cell_type"] = "markdown"
+
+    nb["cells"] = merged_cells
     return nb
 
 