Skip to content

Commit

Permalink
kola: azure: Restrict network access to vnet
Browse files Browse the repository at this point in the history
If user specifies the vnet kola is running from.

Signed-off-by: Jeremi Piotrowski <jpiotrowski@microsoft.com>
  • Loading branch information
jepio committed Jun 28, 2024
1 parent 503c6bc commit 8262af9
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 32 deletions.
1 change: 1 addition & 0 deletions cmd/kola/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ func init() {
sv(&kola.AzureOptions.ResourceGroup, "azure-resource-group", "", "Deploy resources in an existing resource group")
sv(&kola.AzureOptions.AvailabilitySet, "azure-availability-set", "", "Deploy instances with an existing availibity set")
bv(&kola.AzureOptions.TrustedLaunch, "azure-trusted-launch", false, "Enable trusted launch for VMs (default \"false\")")
sv(&kola.AzureOptions.KolaVnet, "azure-kola-vnet", "", "Pass the vnet/subnet that kola is being ran from to restrict network access to created storage accounts")

// do-specific options
sv(&kola.DOOptions.ConfigPath, "do-config-file", "", "DigitalOcean config file (default \"~/"+auth.DOConfigPath+"\")")
Expand Down
68 changes: 36 additions & 32 deletions platform/api/azure/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,45 +31,49 @@ var (
kolaVnet = "kola-vn"
)

func (a *API) PrepareNetworkResources(resourceGroup string) (Network, error) {
if a.Opts.VnetSubnetName != "" {
parts := strings.SplitN(a.Opts.VnetSubnetName, "/", 2)
vnetName := parts[0]
subnetName := "default"
if len(parts) > 1 {
subnetName = parts[1]
func (a *API) findVnetSubnet(vnetSubnetStr string) (Network, error) {
parts := strings.SplitN(vnetSubnetStr, "/", 2)
vnetName := parts[0]
subnetName := "default"
if len(parts) > 1 {
subnetName = parts[1]
}
var net *armnetwork.VirtualNetwork
pager := a.netClient.NewListAllPager(nil)
for pager.More() {
page, err := pager.NextPage(context.TODO())
if err != nil {
return Network{}, fmt.Errorf("failed to iterate vnets: %w", err)
}
var net *armnetwork.VirtualNetwork
pager := a.netClient.NewListAllPager(nil)
for pager.More() {
page, err := pager.NextPage(context.TODO())
if err != nil {
return Network{}, fmt.Errorf("failed to iterate vnets: %w", err)
}
for _, vnet := range page.Value {
if vnet.Name != nil && *vnet.Name == vnetName {
net = vnet
break
}
}
if net != nil {
for _, vnet := range page.Value {
if vnet.Name != nil && *vnet.Name == vnetName {
net = vnet
break
}
}
if net == nil {
return Network{}, fmt.Errorf("failed to find vnet %s", vnetName)
}
subnets := net.Properties.Subnets
if subnets == nil {
return Network{}, fmt.Errorf("failed to find subnet %s in vnet %s", subnetName, vnetName)
}
for _, subnet := range subnets {
if subnet != nil && subnet.Name != nil && *subnet.Name == subnetName {
return Network{*subnet}, nil
}
if net != nil {
break
}
}
if net == nil {
return Network{}, fmt.Errorf("failed to find vnet %s", vnetName)
}
subnets := net.Properties.Subnets
if subnets == nil {
return Network{}, fmt.Errorf("failed to find subnet %s in vnet %s", subnetName, vnetName)
}
for _, subnet := range subnets {
if subnet != nil && subnet.Name != nil && *subnet.Name == subnetName {
return Network{*subnet}, nil
}
}
return Network{}, fmt.Errorf("failed to find subnet %s in vnet %s", subnetName, vnetName)
}

func (a *API) PrepareNetworkResources(resourceGroup string) (Network, error) {
if a.Opts.VnetSubnetName != "" {
return a.findVnetSubnet(a.Opts.VnetSubnetName)
}

if err := a.createVirtualNetwork(resourceGroup); err != nil {
return Network{}, err
Expand Down
1 change: 1 addition & 0 deletions platform/api/azure/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type Options struct {
Location string
HyperVGeneration string
VnetSubnetName string
KolaVnet string
UseGallery bool
UsePrivateIPs bool
TrustedLaunch bool
Expand Down
15 changes: 15 additions & 0 deletions platform/api/azure/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,21 @@ func (a *API) CreateStorageAccount(resourceGroup string) (string, error) {
AllowSharedKeyAccess: to.Ptr(false),
},
}
if a.Opts.KolaVnet != "" {
net, err := a.findVnetSubnet(a.Opts.KolaVnet)
if err != nil {
return "", fmt.Errorf("CreateStorageAccount: %v", err)
}
parameters.Properties.NetworkRuleSet = &armstorage.NetworkRuleSet{
DefaultAction: to.Ptr(armstorage.DefaultActionDeny),
VirtualNetworkRules: []*armstorage.VirtualNetworkRule{
{
VirtualNetworkResourceID: net.subnet.ID,
},
},
}
}

plog.Infof("Creating StorageAccount %s", name)
poller, err := a.accClient.BeginCreate(ctx, resourceGroup, name, parameters, nil)
if err != nil {
Expand Down

0 comments on commit 8262af9

Please sign in to comment.