Skip to content

Commit

Permalink
Update for proper submodule handling
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Mar 3, 2024
1 parent fce1bd4 commit fa526c3
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 10 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "sd-scripts"]
path = sd-scripts
url = https://github.com/kohya-ss/sd-scripts.git
2 changes: 1 addition & 1 deletion sd-scripts
35 changes: 35 additions & 0 deletions setup/setup_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,41 @@
errors = 0 # Define the 'errors' variable before using it
log = logging.getLogger('sd')

def update_submodule(submodule_path, branch_or_tag):
"""
Ensure the submodule is initialized, updated, and set to a specific commit, tag, or branch.
Parameters:
- submodule_path: The relative path within the repository to the submodule.
- branch_or_tag: The specific commit, tag, or branch to checkout in the submodule.
"""
original_dir = os.getcwd() # Store the original directory
try:
# Ensure the working directory is the root of the main repository
if not os.path.exists(submodule_path):
raise FileNotFoundError(f"Submodule path does not exist: {submodule_path}")

# Initialize and update the submodule
subprocess.run(["git", "submodule", "update", "--init", "--recursive", "--quiet", submodule_path], check=True)
log.info("Submodule initialized and updated.")

# Navigate to the submodule directory
os.chdir(submodule_path)

# Fetch the latest changes from the remote, including tags
subprocess.run(["git", "fetch", "--all", "--tags", "--quiet"], check=True)

# Checkout the specified branch, tag, or commit
subprocess.run(["git", "checkout", "--quiet", branch_or_tag], check=True)
log.info(f"Submodule set to {branch_or_tag}.")

except subprocess.CalledProcessError as e:
log.error(f"Error during Git operation: {e}")
except FileNotFoundError as e:
log.error(e)
finally:
os.chdir(original_dir) # Restore the original directory

def read_tag_version_from_file(file_path):
"""
Read the tag version from a given file.
Expand Down
9 changes: 6 additions & 3 deletions setup/setup_linux.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ def main_menu(platform_requirements_file, show_stdout: bool = False, no_run_acce
setup_common.setup_logging()
# Read the tag version from the file
tag_version = setup_common.read_tag_version_from_file(".sd-scripts-release")
setup_common.clone_or_checkout(
"https://github.com/kohya-ss/sd-scripts.git", tag_version, "sd-scripts"
)

setup_common.update_submodule("sd-scripts", tag_version)

# setup_common.clone_or_checkout(
# "https://github.com/kohya-ss/sd-scripts.git", tag_version, "sd-scripts"
# )

parser = argparse.ArgumentParser()
parser.add_argument('--platform-requirements-file', dest='platform_requirements_file', default='requirements_linux.txt', help='Path to the platform-specific requirements file')
Expand Down
8 changes: 5 additions & 3 deletions setup/setup_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,11 @@ def install_kohya_ss_torch2():
# Read the tag version from the file
tag_version = setup_common.read_tag_version_from_file(".\.sd-scripts-release")

setup_common.clone_or_checkout(
"https://github.com/kohya-ss/sd-scripts.git", tag_version, "sd-scripts"
)
setup_common.update_submodule("sd-scripts", tag_version)

# setup_common.clone_or_checkout(
# "https://github.com/kohya-ss/sd-scripts.git", tag_version, "sd-scripts"
# )

# Upgrade pip if needed
setup_common.install("--upgrade pip")
Expand Down
8 changes: 5 additions & 3 deletions setup/validate_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ def main():
# Read the tag version from the file
tag_version = setup_common.read_tag_version_from_file(".sd-scripts-release")

setup_common.clone_or_checkout(
"https://github.com/kohya-ss/sd-scripts.git", tag_version, "sd-scripts"
)
setup_common.update_submodule("sd-scripts", tag_version)

# setup_common.clone_or_checkout(
# "https://github.com/kohya-ss/sd-scripts.git", tag_version, "sd-scripts"
# )

if args.requirements:
setup_common.install_requirements(args.requirements, check_no_verify_flag=True)
Expand Down

0 comments on commit fa526c3

Please sign in to comment.