Skip to content

Commit

Permalink
fix: support fallback to default profile and region
Browse files Browse the repository at this point in the history
By moving the remote logic to its own class it allows us to encapsulate the behaviour. And making it easier to extend the behaviour.

Fixes issue #7
  • Loading branch information
Joris Conijn committed Jan 25, 2022
1 parent e15ccf9 commit 13bb074
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 58 deletions.
10 changes: 5 additions & 5 deletions pull_request_codecommit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ def main(repository_path: Optional[str]) -> None:
"""
repo = Repository(repository_path)

if not repo.supported:
if not repo.remote.supported:
raise click.ClickException("The repository is not compatible with this tool!")

click.echo(f"Remote: {repo.remote}")
click.echo(f"Region: {repo.aws_region}")
click.echo(f"Profile: {repo.aws_profile}")
click.echo(f"Repo: {repo.repository_name}")
click.echo(f"Remote: {repo.remote.url}")
click.echo(f"Region: {repo.remote.region}")
click.echo(f"Profile: {repo.remote.profile}")
click.echo(f"Repo: {repo.remote.name}")
click.echo(f"Branch: {repo.branch}")
click.echo()
pull_request = repo.pull_request_information()
Expand Down
23 changes: 18 additions & 5 deletions pull_request_codecommit/aws/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import json
from typing import List
from typing import List, Optional
import subprocess


Expand All @@ -12,12 +12,25 @@ class Client:
We use the AWS CLI for these operations so that we can use tha MFA toke cache from the cli.
"""

def __init__(self, profile: str, region: str) -> None:
self.__profile = profile
self.__region = region
def __init__(self, profile: Optional[str], region: Optional[str]) -> None:
self.__base_command: List[str] = []
self.__profile: Optional[str] = profile
self.__region: Optional[str] = region

@property
def base_command(self) -> List[str]:
base_command = ["aws"]

if self.__profile:
base_command.extend(["--profile", self.__profile])

if self.__region:
base_command.extend(["--region", self.__region])

return base_command

def __execute(self, parameters: List[str]) -> str:
command = ["aws", "--profile", self.__profile, "--region", self.__region]
command = self.base_command
command.extend(parameters)
response = subprocess.run(command, stdout=subprocess.PIPE)

Expand Down
5 changes: 3 additions & 2 deletions pull_request_codecommit/git/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .commits import Commits
from .client import Client
from .commit import Commit
from .commits import Commits
from .message import Message
from .client import Client
from .remote import Remote
50 changes: 50 additions & 0 deletions pull_request_codecommit/git/remote.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import re
from typing import Optional


class Remote:
"""
Understands remote CodeCommit URLs
"""

def __init__(self, url: str):
self.__url: str = url
self.__region: Optional[str] = ""
self.__profile: Optional[str] = ""
self.__name: str = ""

def __regex(self, pattern: str, index: int = 1) -> Optional[str]:
match = re.search(pattern, self.__url)
return match.group(index) if match else None

@property
def supported(self) -> bool:
return self.__url.startswith("codecommit::") and self.name != ""

@property
def url(self) -> str:
return self.__url

@property
def region(self) -> Optional[str]:
if not self.__region:
self.__region = self.__regex(r"^codecommit::(.*)://")

return self.__region

@property
def profile(self) -> Optional[str]:
if self.__profile == "":
self.__profile = self.__regex(r"//(.*)@")

return self.__profile

@property
def name(self) -> str:
if not self.__name:
name = self.__regex(r"(\/\/|@)(.*)$", 2)

if name:
self.__name = name

return self.__name
51 changes: 9 additions & 42 deletions pull_request_codecommit/repository.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import os
import re
from typing import Optional, Tuple
from typing import Optional

from .config import Config
from .git import Client as GitClient
from .git import Client as GitClient, Remote
from .aws import Client as AwsClient
from .pull_request import PullRequest

Expand All @@ -14,10 +13,7 @@ class Repository:
"""

def __init__(self, path: Optional[str] = None) -> None:
self.__remote: str = ""
self.__aws_region: str = ""
self.__aws_profile: str = ""
self.__repository_name: str = ""
self.__remote: Optional[Remote] = None
self.__branch: str = ""

if not path:
Expand All @@ -26,44 +22,15 @@ def __init__(self, path: Optional[str] = None) -> None:
self.__git = GitClient(path)

def __config(self, method: str) -> Optional[str]:
return getattr(Config, method)(self.aws_profile)
return getattr(Config, method)(self.remote.profile)

@property
def supported(self) -> bool:
return self.remote.startswith("codecommit::")

@property
def remote(self) -> str:
def remote(self) -> Remote:
if not self.__remote:
self.__remote = self.__git.remote("origin")

if self.supported:
self.__extract_from_remote(self.__remote)
self.__remote = Remote(self.__git.remote("origin"))

return self.__remote

def __extract_from_remote(self, remote: str) -> None:
def resolve(resolver: Optional[re.Match], index: int) -> str:
return resolver.group(index) if resolver else ""

# This can use some more error handling
match = re.search(r"^codecommit::(.*)://(.*)@(.*)$", remote)
self.__aws_region = resolve(match, 1)
self.__aws_profile = resolve(match, 2)
self.__repository_name = resolve(match, 3)

@property
def aws_region(self) -> str:
return self.__aws_region

@property
def aws_profile(self) -> str:
return self.__aws_profile

@property
def repository_name(self) -> str:
return self.__repository_name

@property
def branch(self) -> str:
if not self.__branch:
Expand All @@ -81,12 +48,12 @@ def pull_request_information(self) -> PullRequest:
return PullRequest(commits)

def create_pull_request(self, title: str, description: str) -> str:
client = AwsClient(profile=self.__aws_profile, region=self.__aws_region)
client = AwsClient(profile=self.remote.profile, region=self.remote.region)
response = client.create_pull_request(
title=title,
description=description,
repository=self.repository_name,
repository=self.remote.name,
source=self.branch,
destination=self.destination,
)
return f"https://{self.__aws_region}.console.aws.amazon.com/codesuite/codecommit/repositories/{self.repository_name}/pull-requests/{response.get('pullRequestId')}/details?region={self.__aws_region}"
return f"https://{self.remote.region}.console.aws.amazon.com/codesuite/codecommit/repositories/{self.remote.name}/pull-requests/{response.get('pullRequestId')}/details?region={self.remote.region}"
47 changes: 43 additions & 4 deletions tests/test_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,57 @@ def edit_message(message: str) -> str:


@pytest.mark.parametrize(
"region, profile, config, commits",
"remote, region, profile, config, commits",
[
(
"codecommit::eu-west-1://my-profile@my-repository",
"eu-west-1",
"my-profile",
b"[default]\nbranch: my-main\n[profile my-profile]\nbranch: my-master",
COMMITS,
),
(
"codecommit::eu-west-1://my-profile@my-repository",
"eu-central-1",
"my-profile",
b"[default]\nbranch: my-main\n[profile my-profile]\nbranch: my-master",
COMMITS_NO_ISSUES,
),
(
"codecommit::eu-west-1://my-profile@my-repository",
"eu-west-1",
"my-other-profile",
b"[default]\nbranch: my-main\n[profile my-profile]\nbranch: my-master",
COMMITS_NO_ISSUES,
),
(
"codecommit::eu-west-1://my-profile@my-repository",
"eu-central-1",
"my-other-profile",
b"[default]\nbranch: my-main\n[profile my-profile]\nbranch: my-master",
COMMITS,
),
(
"codecommit::eu-west-1://my-repository",
"eu-central-1",
None,
b"[default]\nbranch: my-main\n[profile my-profile]\nbranch: my-master",
COMMITS,
),
(
"codecommit::://my-profile@my-repository",
None,
"my-profile",
b"[default]\nbranch: my-main\n[profile my-profile]\nbranch: my-master",
COMMITS,
),
(
"codecommit::://my-repository",
None,
None,
b"[default]\nbranch: my-main\n[profile my-profile]\nbranch: my-master",
COMMITS,
),
],
)
@patch("pull_request_codecommit.aws.client.subprocess.run")
Expand All @@ -50,6 +75,7 @@ def test_invoke(
mock_edit: MagicMock,
mock_git_client: MagicMock,
mock_aws_client: MagicMock,
remote: str,
region: str,
profile: str,
config: bytes,
Expand All @@ -58,9 +84,7 @@ def test_invoke(
mock_edit.side_effect = edit_message
mock_git_client.return_value.get_commit_messages.return_value = Commits(commits)

mock_git_client.return_value.remote.return_value = (
f"codecommit::{region}://{profile}@my-repository"
)
mock_git_client.return_value.remote.return_value = remote
mock_git_client.return_value.current_branch.return_value = "feat/my-feature"
configparser.open = MagicMock(return_value=TextIOWrapper(BytesIO(config))) # type: ignore

Expand Down Expand Up @@ -186,3 +210,18 @@ def test_invoke_quit_edit(
runner = CliRunner()
result = runner.invoke(main)
assert result.exit_code == 1


@patch("pull_request_codecommit.repository.GitClient")
def test_invoke_no_repository_name(
mock_git_client: MagicMock,
) -> None:
mock_git_client.return_value.remote.return_value = f"codecommit::eu-west-1://"
mock_git_client.return_value.current_branch.return_value = "feat/my-feature"
config = b"[default]\nbranch: my-main\n[profile my-profile]\nbranch: my-master"
configparser.open = MagicMock(return_value=TextIOWrapper(BytesIO(config))) # type: ignore

runner = CliRunner()
result = runner.invoke(main)
assert result.exit_code == 1
assert "Error: The repository is not compatible with this tool!" in result.output

0 comments on commit 13bb074

Please sign in to comment.