Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP: fix most typing errors in provider #2038

Merged
merged 6 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions src/_nebari/provider/cloud/amazon_web_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import re
import time
from typing import Dict, List
from typing import Dict, List, Optional

import boto3
from botocore.exceptions import ClientError, EndpointConnectionError
Expand All @@ -29,7 +29,9 @@ def check_credentials():


@functools.lru_cache()
def aws_session(region: str = None, digitalocean_region: str = None) -> boto3.Session:
def aws_session(
region: Optional[str] = None, digitalocean_region: Optional[str] = None
) -> boto3.Session:
"""Create a boto3 session."""
if digitalocean_region:
aws_access_key_id = os.environ["SPACES_ACCESS_KEY_ID"]
Expand Down Expand Up @@ -126,7 +128,7 @@ def instances(region: str) -> Dict[str, str]:
return {t: t for t in instance_types}


def aws_get_vpc_id(name: str, namespace: str, region: str) -> str:
def aws_get_vpc_id(name: str, namespace: str, region: str) -> Optional[str]:
"""Return VPC ID for the EKS cluster namedd `{name}-{namespace}`."""
cluster_name = f"{name}-{namespace}"
session = aws_session(region=region)
Expand All @@ -138,6 +140,7 @@ def aws_get_vpc_id(name: str, namespace: str, region: str) -> str:
for tag in tags:
if tag["Key"] == "Name" and tag["Value"] == cluster_name:
return vpc["VpcId"]
return None


def aws_get_subnet_ids(name: str, namespace: str, region: str) -> List[str]:
Expand Down Expand Up @@ -216,11 +219,11 @@ def aws_get_security_group_ids(name: str, namespace: str, region: str) -> List[s
return security_group_ids


def aws_get_load_balancer_name(vpc_id: str, region: str) -> str:
def aws_get_load_balancer_name(vpc_id: str, region: str) -> Optional[str]:
"""Return load balancer name for the VPC ID."""
if not vpc_id:
print("No VPC ID provided. Exiting...")
return
return None

session = aws_session(region=region)
client = session.client("elb")
Expand All @@ -229,6 +232,7 @@ def aws_get_load_balancer_name(vpc_id: str, region: str) -> str:
for load_balancer in response:
if load_balancer["VPCId"] == vpc_id:
return load_balancer["LoadBalancerName"]
return None


def aws_get_efs_ids(name: str, namespace: str, region: str) -> List[str]:
Expand Down Expand Up @@ -260,7 +264,7 @@ def aws_get_efs_mount_target_ids(efs_id: str, region: str) -> List[str]:
"""Return list of EFS mount target IDs for the EFS ID."""
if not efs_id:
print("No EFS ID provided. Exiting...")
return
return []

session = aws_session(region=region)
client = session.client("efs")
Expand Down Expand Up @@ -290,7 +294,9 @@ def aws_get_ec2_volume_ids(name: str, namespace: str, region: str) -> List[str]:
return volume_ids


def aws_get_iam_policy(region: str, name: str = None, pattern: str = None) -> str:
def aws_get_iam_policy(
region: Optional[str], name: Optional[str] = None, pattern: Optional[str] = None
) -> Optional[str]:
"""Return IAM policy ARN for the policy name or pattern."""
session = aws_session(region=region)
client = session.client("iam")
Expand All @@ -301,6 +307,7 @@ def aws_get_iam_policy(region: str, name: str = None, pattern: str = None) -> st
pattern and re.match(pattern, policy["PolicyName"])
):
return policy["Arn"]
return None


def aws_delete_load_balancer(name: str, namespace: str, region: str):
Expand Down Expand Up @@ -640,9 +647,9 @@ def aws_delete_ec2_volumes(name: str, namespace: str, region: str):

def aws_delete_s3_objects(
bucket_name: str,
endpoint: str = None,
region: str = None,
digitalocean_region: str = None,
endpoint: Optional[str] = None,
region: Optional[str] = None,
digitalocean_region: Optional[str] = None,
):
"""
Delete all objects in the S3 bucket.
Expand Down Expand Up @@ -707,9 +714,9 @@ def aws_delete_s3_objects(

def aws_delete_s3_bucket(
bucket_name: str,
endpoint: str = None,
region: str = None,
digitalocean_region: str = None,
endpoint: Optional[str] = None,
region: Optional[str] = None,
digitalocean_region: Optional[str] = None,
):
"""
Delete S3 bucket.
Expand Down
9 changes: 6 additions & 3 deletions src/_nebari/provider/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,27 @@
import os
import subprocess
from pathlib import Path
from typing import Optional

from _nebari.utils import change_directory


def is_git_repo(path: Path = None):
def is_git_repo(path: Optional[Path] = None):
path = path or Path.cwd()
return ".git" in os.listdir(path)


def initialize_git(path: Path = None):
def initialize_git(path: Optional[Path] = None):
path = path or Path.cwd()
with change_directory(path):
subprocess.check_output(["git", "init"])
# Ensure initial branch is called main
subprocess.check_output(["git", "checkout", "-b", "main"])


def add_git_remote(remote_path: str, path: Path = None, remote_name: str = "origin"):
def add_git_remote(
remote_path: str, path: Optional[Path] = None, remote_name: str = "origin"
):
path = path or Path.cwd()

c = configparser.ConfigParser()
Expand Down
7 changes: 2 additions & 5 deletions src/_nebari/provider/terraform.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def deploy(
terraform_import: bool = False,
terraform_apply: bool = True,
terraform_destroy: bool = False,
input_vars: Dict[str, Any] = None,
state_imports: List = None,
input_vars: Dict[str, Any] = {},
state_imports: List[Any] = [],
):
"""Execute a given terraform directory.

Expand All @@ -52,9 +52,6 @@ def deploy(
state_imports: (addr, id) pairs for iterate through and attempt
to terraform import
"""
input_vars = input_vars or {}
state_imports = state_imports or []

with tempfile.NamedTemporaryFile(
mode="w", encoding="utf-8", suffix=".tfvars.json"
) as f:
Expand Down