Skip to content

Commit

Permalink
Merge pull request #88 from d-strobel/refactor/dns-package
Browse files Browse the repository at this point in the history
Refactor/dns package
  • Loading branch information
d-strobel authored Dec 12, 2024
2 parents 8da89e3 + 004936b commit 317daea
Show file tree
Hide file tree
Showing 15 changed files with 239 additions and 101 deletions.
6 changes: 3 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package gowindows

import (
"github.com/d-strobel/gowindows/connection"
"github.com/d-strobel/gowindows/windows/dns/server"
"github.com/d-strobel/gowindows/windows/dns"
"github.com/d-strobel/gowindows/windows/local/accounts"
)

Expand All @@ -15,7 +15,7 @@ import (
type Client struct {
Connection connection.Connection
LocalAccounts *accounts.Client
DnsServer *server.Client
Dns *dns.Client
}

// NewClient returns a new instance of the Client object, initialized with the provided configuration.
Expand All @@ -28,7 +28,7 @@ func NewClient(conn connection.Connection) *Client {

// Build the client with the subpackages.
c.LocalAccounts = accounts.NewClient(c.Connection)
c.DnsServer = server.NewClient(c.Connection)
c.Dns = dns.NewClient(c.Connection)

return c
}
Expand Down
46 changes: 46 additions & 0 deletions parsing/cim_time_duration.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package parsing

import (
"encoding/json"
"time"
)

// CimTimeDuration is a custom time type that embeds the time.Duration type.
// It is designed to handle the unmarshalling of CimInstance time-span json blocks.
type CimTimeDuration struct {
time.Duration
}

// cimTimeDurationObject is a struct that represents the unmarshalled
// json of a CimInstance time duration object.
// It is used to do the initial unmarshalling of the json block.
type cimTimeDurationObject struct {
Days int32 `json:"Days"`
Hours int32 `json:"Hours"`
Minutes int32 `json:"Minutes"`
Seconds int32 `json:"Seconds"`
MilliSeconds int32 `json:"Milliseconds"`
}

// UnmarshalJSON implements the json.Unmarshaler interface for the CimTimeDuration type.
// It parses a JSON-encoded CimInstance time duration JSON block and converts it into a CimTimeDuration object.
func (t *CimTimeDuration) UnmarshalJSON(b []byte) error {
var d cimTimeDurationObject

// Unmarshal the json block into the cimTimeDurationObject struct.
if err := json.Unmarshal(b, &d); err != nil {
return err
}

// Convert the fields into a time.Duration object.
duration := time.Duration(d.Days)*24*time.Hour +
time.Duration(d.Hours)*time.Hour +
time.Duration(d.Minutes)*time.Minute +
time.Duration(d.Seconds)*time.Second +
time.Duration(d.MilliSeconds)*time.Millisecond

// Set the time.Duration object to the CimTimeDuration object.
t.Duration = duration

return nil
}
86 changes: 86 additions & 0 deletions parsing/cim_time_duration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package parsing

import (
"encoding/json"
"testing"
"time"

"github.com/stretchr/testify/suite"
)

// Unit test suite for all CimTimeDuration parsing functions
type CimTimeDurationUnitTestSuite struct {
suite.Suite
// Fixtures
testJson string
testExpected testCimTime
}

// Fixture objects
type testCimTime struct {
Test string `json:"Test"`
CimTimeDuration CimTimeDuration `json:"LeaseDuration"`
}

func (suite *CimTimeDurationUnitTestSuite) SetupSuite() {
// Fixtures
suite.testJson = `{
"Test": "Test String",
"LeaseDuration": {
"Ticks": 6912000000000,
"Days": 8,
"Hours": 0,
"Milliseconds": 0,
"Minutes": 0,
"Seconds": 0,
"TotalDays": 8,
"TotalHours": 192,
"TotalMilliseconds": 691200000,
"TotalMinutes": 11520,
"TotalSeconds": 691200
}
}`
suite.testExpected = testCimTime{
Test: "Test String",
CimTimeDuration: CimTimeDuration{
Duration: 8 * 24 * time.Hour,
},
}
}

func TestCimTimeDurationUnitTestSuite(t *testing.T) {
suite.Run(t, &CimTimeDurationUnitTestSuite{})
}

func (suite *CimTimeDurationUnitTestSuite) TestUnmarshalJSON() {
suite.T().Parallel()

suite.Run("should unmarshal the CimInstance duration json to CimTimeDuration", func() {
cimTime := CimTimeDuration{}
expectedTimeDuration, err := time.ParseDuration("1h30m")
suite.Require().NoError(err)
expectedCimTimeDuration := CimTimeDuration{Duration: expectedTimeDuration}

err = cimTime.UnmarshalJSON([]byte(`{"Days":0,"Hours":1,"Minutes":30,"Seconds":0,"Milliseconds":0}`))
suite.NoError(err)
suite.Equal(expectedCimTimeDuration, cimTime)
})

suite.Run("should unmarshal the CimInstance duration json to CimTimeDuration with all possible fields", func() {
cimTime := CimTimeDuration{}
expectedTimeDuration, err := time.ParseDuration("98h30m5s10ms")
suite.Require().NoError(err)
expectedCimTimeDuration := CimTimeDuration{Duration: expectedTimeDuration}

err = cimTime.UnmarshalJSON([]byte(`{"Days":4,"Hours":2,"Minutes":30,"Seconds":5,"Milliseconds":10}`))
suite.NoError(err)
suite.Equal(expectedCimTimeDuration, cimTime)
})

suite.Run("should unmarshal the whole CimTimeDuration correctly", func() {
testCimTime := testCimTime{}
err := json.Unmarshal([]byte(suite.testJson), &testCimTime)
suite.NoError(err)
suite.Equal(suite.testExpected, testCimTime)
})
}
35 changes: 16 additions & 19 deletions windows/dns/server/server.go → windows/dns/dns.go
Original file line number Diff line number Diff line change
@@ -1,40 +1,37 @@
// Package server provides a Go library for handling Windows DNS Server.
// Package dns provides a Go library for handling Windows DNS Server.
// The functions are related to the Powershell dns server cmdlets provided by Windows.
// https://learn.microsoft.com/en-us/powershell/module/dnsserver/?view=windowsserver2022-ps
package server
package dns

import (
"context"
"encoding/json"
"time"

"errors"

"github.com/d-strobel/gowindows/connection"
"github.com/d-strobel/gowindows/parsing"
)

// server is a type constraint for the run function, ensuring it works with specific types.
type server interface {
// dns is a type constraint for the run function, ensuring it works with specific types.
type dns interface {
Zone | []Zone | recordObject | []recordObject
}

// Default Windows DNS TTL.
// https://learn.microsoft.com/en-us/windows/win32/ad/configuration-of-ttl-limits?source=recommendations
const defaultTimeToLive int32 = 86400

// timeToLive represents a time to live (TTL) object returned by Powershell DNS commands.
type timeToLive struct {
Seconds int32 `json:"TotalSeconds"`
}
var defaultTimeToLive time.Duration = time.Second * 86400

// recordObject contains the unmarshaled json of the powershell record object.
type recordObject struct {
DistinguishedName string `json:"DistinguishedName"`
Name string `json:"HostName"`
RecordData recordRecordData `json:"RecordData"`
RecordType string `json:"RecordType"`
Timestamp parsing.DotnetTime `json:"Timestamp"`
Type int8 `json:"Type"`
TimeToLive timeToLive `json:"TimeToLive"`
DistinguishedName string `json:"DistinguishedName"`
Name string `json:"HostName"`
RecordData recordRecordData `json:"RecordData"`
RecordType string `json:"RecordType"`
Timestamp parsing.DotnetTime `json:"Timestamp"`
Type int8 `json:"Type"`
TimeToLive parsing.CimTimeDuration `json:"TimeToLive"`
}
type recordRecordData struct {
CimInstanceProperties parsing.CimClassKeyVal `json:"CimInstanceProperties"`
Expand Down Expand Up @@ -62,7 +59,7 @@ func NewClientWithParser(conn connection.Connection, parsing func(string) (strin

// run runs a PowerShell command against a Windows system, handles the command results,
// and unmarshals the output into a local object type.
func run[T server](ctx context.Context, c *Client, cmd string, l *T) error {
func run[T dns](ctx context.Context, c *Client, cmd string, d *T) error {
// Run the command
result, err := c.Connection.RunWithPowershell(ctx, cmd)
if err != nil {
Expand All @@ -84,7 +81,7 @@ func run[T server](ctx context.Context, c *Client, cmd string, l *T) error {
}

// Unmarshal stdout
if err = json.Unmarshal([]byte(result.StdOut), &l); err != nil {
if err = json.Unmarshal([]byte(result.StdOut), &d); err != nil {
return err
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package server
package dns

import (
"context"
Expand Down
25 changes: 14 additions & 11 deletions windows/dns/server/record_a.go → windows/dns/record_a.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package server
package dns

import (
"context"
Expand All @@ -16,7 +16,7 @@ type RecordA struct {
Name string
Addresses []string
Timestamp time.Time
TimeToLive int32
TimeToLive time.Duration
}

// convertOutput converts the unmarshaled JSON output from the recordObject to a RecordA object.
Expand All @@ -25,7 +25,7 @@ func (r *RecordA) convertOutput(o []recordObject) {
r.DistinguishedName = o[0].DistinguishedName
r.Name = o[0].Name
r.Timestamp = o[0].Timestamp.Time
r.TimeToLive = o[0].TimeToLive.Seconds
r.TimeToLive = o[0].TimeToLive.Duration

// Set the addresses and the lowest TTL.
if len(o) == 1 {
Expand All @@ -36,8 +36,8 @@ func (r *RecordA) convertOutput(o []recordObject) {

// Set the lowest TTL to be RFC2181 compliant.
// https://www.rfc-editor.org/rfc/rfc2181#section-5.2
if record.TimeToLive.Seconds < r.TimeToLive {
r.TimeToLive = record.TimeToLive.Seconds
if record.TimeToLive.Duration < r.TimeToLive {
r.TimeToLive = record.TimeToLive.Duration
}
}
}
Expand Down Expand Up @@ -103,7 +103,7 @@ type RecordACreateParams struct {
// Specifies the time to live (TTL) of the record in seconds.
// If not provided, the default is 86400 seconds.
// A TTL of 0 is not allowed.
TimeToLive int32
TimeToLive time.Duration
}

// pwshCommand returns the PowerShell command to create a new A-Record.
Expand All @@ -118,12 +118,14 @@ func (params RecordACreateParams) pwshCommand() string {
cmd = append(cmd, fmt.Sprintf("-ZoneName '%s'", params.Zone))

// Set default TTL if not provided.
// New-TimeSpan only allows int32 values.
// https://learn.microsoft.com/de-de/powershell/module/microsoft.powershell.utility/new-timespan?view=powershell-7.4
if params.TimeToLive == 0 {
params.TimeToLive = defaultTimeToLive
}
cmd = append(cmd, fmt.Sprintf("-TimeToLive %s", fmt.Sprintf("$(New-TimeSpan -Seconds %d)", params.TimeToLive)))

// New-TimeSpan only allows int32 values. So we round the duration to seconds.
// https://learn.microsoft.com/de-de/powershell/module/microsoft.powershell.utility/new-timespan?view=powershell-7.4
seconds := int32(params.TimeToLive.Round(time.Second).Seconds())
cmd = append(cmd, fmt.Sprintf("-TimeToLive %s", fmt.Sprintf("$(New-TimeSpan -Seconds %d)", seconds)))

// Add addresses with single quotes and join them with commas.
for _, address := range params.Addresses {
Expand Down Expand Up @@ -175,7 +177,7 @@ type RecordAUpdateParams struct {
// Specifies the time to live (TTL) of the record in seconds.
// If not provided, the default TTL is 86400 seconds.
// A TTL of 0 is not allowed.
TimeToLive int32
TimeToLive time.Duration
}

// pwshCommand returns the PowerShell command to update an A-Record.
Expand All @@ -186,14 +188,15 @@ func (params RecordAUpdateParams) pwshCommand() string {
if params.TimeToLive == 0 {
params.TimeToLive = defaultTimeToLive
}
seconds := int32(params.TimeToLive.Round(time.Second).Seconds())

// Base command
cmd := []string{"$nr=@();Get-DnsServerResourceRecord -RRType 'A' -Node"}

// Add parameters and logic for handling the TTL update.
cmd = append(cmd, fmt.Sprintf("-Name '%s'", params.Name))
cmd = append(cmd, fmt.Sprintf("-ZoneName '%s'", params.Zone))
cmd = append(cmd, fmt.Sprintf("| ForEach-Object{$r=$_;$n=[ciminstance]::new($r);$n.TimeToLive=New-TimeSpan -Seconds %d", params.TimeToLive))
cmd = append(cmd, fmt.Sprintf("| ForEach-Object{$r=$_;$n=[ciminstance]::new($r);$n.TimeToLive=New-TimeSpan -Seconds %d", seconds))
cmd = append(cmd, fmt.Sprintf(";$nr+=Set-DnsServerResourceRecord -OldInputObject $r -NewInputObject $n -ZoneName '%s' -PassThru}", params.Zone))
cmd = append(cmd, ";if($nr.Count -ge 2){ConvertTo-Json $nr -Compress}else{ConvertTo-Json @($nr) -Compress}")

Expand Down
Loading

0 comments on commit 317daea

Please sign in to comment.