Skip to content

Commit

Permalink
Merge pull request #53 from cicekhayri/create-migrations
Browse files Browse the repository at this point in the history
Create migrations
  • Loading branch information
cicekhayri authored Jan 12, 2024
2 parents edfafe7 + c01edb6 commit cdeb596
Show file tree
Hide file tree
Showing 16 changed files with 652 additions and 55 deletions.
21 changes: 21 additions & 0 deletions inspira/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from inspira.cli.generate_model_file import database_file_exists, generate_model_file
from inspira.cli.generate_repository_file import generate_repository_file
from inspira.cli.generate_service_file import generate_service_file
from inspira.migrations.migrations import create_migrations, run_migrations

DATABASE_TYPES = ["postgres", "mysql", "sqlite", "mssql"]

Expand Down Expand Up @@ -55,6 +56,26 @@ def database(name, type):
create_database_file(name, type)


@cli.command()
@click.argument("module_name")
@click.option(
"--empty", nargs=1, type=str, required=False, help="Generate empty migration file."
)
def createmigrations(module_name, empty):
migration_name = None

if empty:
migration_name = empty

create_migrations(module_name, migration_name)


@cli.command()
@click.argument("module_name")
def migrate(module_name):
run_migrations(module_name)


@cli.command()
def init():
generate_project()
Expand Down
9 changes: 1 addition & 8 deletions inspira/cli/create_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import click

from inspira.cli.create_app import generate_project
from inspira.cli.init_file import create_init_file
from inspira.utils import singularize


Expand Down Expand Up @@ -34,8 +35,6 @@ def create_controller_file(name, is_websocket):
create_test_directory(controller_directory)

controller_template_file = "controller_template.txt"

# Create __init__.py in the resource directory
create_init_file(controller_directory)

controller_file = os.path.join(
Expand All @@ -60,9 +59,3 @@ def create_controller_file(name, is_websocket):
output_file.write(content)

click.echo(f"Module '{singularize_name}' created successfully.")


def create_init_file(directory):
init_file = os.path.join(directory, "__init__.py")
with open(init_file, "w"):
pass
42 changes: 0 additions & 42 deletions inspira/cli/generate_model_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,48 +29,6 @@ def generate_model_file(module_name):
)
output_file.write(content)

update_init_db(module_name)


def update_init_db(module_name):
# Assuming the database.py is in the same directory as this script
main_script_path = "database.py"

if database_file_exists():
# Read the content of the main script
with open(main_script_path, "r") as main_script_file:
main_script_content = main_script_file.read()

# Check if the import statement already exists in the init_db function
import_statement = (
f"import src.{module_name.lower()}.{singularize(module_name)}"
)
if import_statement in main_script_content:
click.echo(
f"Import statement for '{module_name}' already exists in init_db."
)
return

# Find the location of the init_db function
init_db_keyword = "def init_db():"
init_db_index = main_script_content.find(init_db_keyword)

# If init_db function is found, insert the import statement after it
if init_db_index != -1:
init_db_end = init_db_index + len(init_db_keyword)
main_script_content = (
main_script_content[:init_db_end]
+ f"\n {import_statement}\n"
+ main_script_content[init_db_end:]
)
click.echo(f"Import statement for '{module_name}' added to init_db.")
else:
click.echo("Function 'init_db()' not found in database.py.")

# Write the updated content back to the main script
with open(main_script_path, "w") as main_script_file:
main_script_file.write(main_script_content)


def database_file_exists() -> bool:
main_script_path = "database.py"
Expand Down
7 changes: 7 additions & 0 deletions inspira/cli/init_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import os


def create_init_file(directory):
init_file = os.path.join(directory, "__init__.py")
with open(init_file, "w"):
pass
4 changes: 0 additions & 4 deletions inspira/cli/templates/database_template.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,3 @@ db_session = scoped_session(
)
Base = declarative_base()
Base.query = db_session.query_property()


def init_db():
Base.metadata.create_all(bind=engine)
Empty file added inspira/migrations/__init__.py
Empty file.
167 changes: 167 additions & 0 deletions inspira/migrations/migrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import sys

import click
from sqlalchemy import select, create_engine, inspect
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import declarative_base
from sqlalchemy.sql.expression import func
import os
from sqlalchemy import MetaData, Column, Integer, String, text

from inspira.logging import log
from inspira.migrations.utils import (
get_or_create_migration_directory,
get_migration_files,
load_model_file,
get_columns_from_model,
generate_add_column_sql,
generate_drop_column_sql,
generate_create_table_sql,
generate_rename_column_sql,
generate_empty_sql_file,
get_indexes_from_model,
generate_add_index_sql,
generate_drop_index_sql,
)

PROJECT_ROOT = os.path.abspath(".")
sys.path.append(PROJECT_ROOT)

try:
from database import Base, engine, db_session
except ImportError:
Base = declarative_base()
engine = create_engine("sqlite:///:memory:")
db_session = None


class Migration(Base):
__tablename__ = "migrations"
id = Column(Integer, primary_key=True)
migration_name = Column(String(255))
version = Column(Integer)


def initialize_database(engine):
Base.metadata.create_all(engine)


def execute_sql_file(file_path):
with open(file_path, "r") as file:
sql_content = file.read()

sql_statements = [
statement.strip() for statement in sql_content.split(";") if statement.strip()
]

with engine.connect() as connection:
try:
for statement in sql_statements:
connection.execute(text(statement))
connection.commit()
log.info("Table creation successful.")
except SQLAlchemyError as e:
log.error("Error:", e)
connection.rollback()
log.info("Transaction rolled back.")


def create_migrations(entity_name, empty_migration_file):
if empty_migration_file:
generate_empty_sql_file(entity_name, empty_migration_file)

module = load_model_file(entity_name)

existing_columns = get_existing_columns(entity_name)
new_columns = get_columns_from_model(getattr(module, module.__name__))

if not existing_columns:
generate_create_table_sql(module, entity_name)
else:
renamed_columns = [
(old_col, new_col.key)
for old_col, new_col in zip(existing_columns, new_columns)
if old_col != new_col.key
]
if renamed_columns:
generate_rename_column_sql(entity_name, existing_columns, new_columns)
else:
added_columns = [
col for col in new_columns if col.key not in existing_columns
]
if added_columns:
generate_add_column_sql(entity_name, existing_columns, added_columns)
else:
removed_columns = [
col for col in existing_columns if col not in new_columns
]
if removed_columns:
generate_drop_column_sql(entity_name, existing_columns, new_columns)

existing_indexes = get_existing_indexes(entity_name)
new_indexes = get_indexes_from_model(getattr(module, module.__name__))
generate_add_index_sql(entity_name, existing_indexes, new_indexes)
generate_drop_index_sql(entity_name, existing_indexes, new_indexes)


def run_migrations(module_name):
with engine.connect() as connection:
if not engine.dialect.has_table(connection, "migrations"):
initialize_database(engine)

if not engine.dialect.has_table(connection, module_name):
initialize_database(engine)

migration_dir = get_or_create_migration_directory(module_name)
migration_files = get_migration_files(migration_dir)

for file in migration_files:
migration_name = os.path.basename(file).replace(".sql", "")

current_version = (
connection.execute(
select(func.max(Migration.version)).where(
Migration.migration_name == migration_name
)
).scalar()
or 0
)

if not current_version:
execute_sql_file(file)
click.echo(
f"Applying migration for {migration_name} version {current_version}"
)

insert_migration(current_version, migration_name)


def get_existing_columns(table_name):
metadata = MetaData()
metadata.reflect(bind=engine)

if table_name in metadata.tables:
table = metadata.tables[table_name]
return [column.name for column in table.columns]
else:
return None


def insert_migration(current_version, migration_name):
migration = Migration()
migration.version = current_version + 1
migration.migration_name = migration_name
db_session.add(migration)
db_session.commit()


def get_existing_indexes(table_name):
metadata = MetaData()
metadata.reflect(bind=engine)

if table_name in metadata.tables:
inspector = inspect(engine)

indexes = inspector.get_indexes(table_name)

return [index["name"] for index in indexes]
Loading

0 comments on commit cdeb596

Please sign in to comment.