Skip to content

Commit

Permalink
Add gRPC server for enforcing
Browse files Browse the repository at this point in the history
  • Loading branch information
bluecmd committed Nov 20, 2018
1 parent 292fd45 commit 5530fb1
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 28 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
*.db
*.pem
*.prod.yaml
*.prod.yaml
.*.swp
cmd/enforce/enforce
cmd/enforcerd/enforcerd
24 changes: 20 additions & 4 deletions cmd/enforce/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,26 @@ package main
import (
"flag"
"io/ioutil"
"os"
"strings"

"github.com/dhtech/dnsenforcer/enforcer"
"github.com/dhtech/dnsenforcer/enforcer/ipplan"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
)

var (
dbFile = flag.String("ipplan", "./ipplan.db", "Path to ipplan file to use")
staticFile = flag.String("static", "./static.prod.yaml", "Path to static file to use")
)

func main() {
// Parse values
vars := &enforcer.Vars{}
flag.StringVar(&vars.Endpoint, "endpoint", "dns.net.dreamhack.se:443", "gRPC endpoint for DNS server")
flag.StringVar(&vars.Certificate, "cert", "./client.pem", "Client certificate to use")
flag.StringVar(&vars.Key, "key", "./key.pem", "Key to use")
flag.StringVar(&vars.DBFile, "ipplan", "./ipplan.db", "Path to ipplan file to use")
flag.StringVar(&vars.Static, "static", "./static.prod.yaml", "Path to static file to use")
flag.IntVar(&vars.HostTTL, "host-ttl", 1337, "Default TTL to use for host records")
flag.BoolVar(&vars.DryRun, "dry-run", false, "Do not actually update records on the DNS server")
vars.IgnoreTypes = strings.Split(*flag.String("ignore-types", "SOA,NS", "Do not remove or add these types of records"), ",")
Expand All @@ -40,15 +45,26 @@ func main() {
}
vars.Zones = zones.Zones

ipp, err := ipplan.Open(*dbFile)
if err != nil {
log.Fatal(err)
}

static, err := os.Open(*staticFile)
if err != nil {
log.Error("You need to create a static record file")
log.Fatal(err)
}

log.Info("Generating DNS records...")

// Create new enforcer
e, err := enforcer.New(vars)
e, err := enforcer.New(vars, ipp, static)
defer e.Close()
if err != nil {
log.Fatal(err)
}
err = e.UpdateRecords()
_, _, err = e.UpdateRecords()
if err != nil {
log.Fatal(err)
}
Expand Down
102 changes: 102 additions & 0 deletions cmd/enforcerd/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package main

import (
"context"
"flag"
"io/ioutil"
"net"
"os"
"strings"

"github.com/dhtech/dnsenforcer/enforcer"
"github.com/dhtech/dnsenforcer/enforcer/ipplan"
pb "github.com/dhtech/proto/dns"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
"gopkg.in/yaml.v2"
)

var (
listenAddress = flag.String("listen", ":1215", "address to listen to")
)

type enforcerServer struct {
v *enforcer.Vars
}

func (s *enforcerServer) Refresh(ctx context.Context, req *pb.RefreshRequest) (*pb.RefreshResponse, error) {
ipp, err := ipplan.Open("/etc/ipplan.db")
if err != nil {
return nil, err
}

static, err := os.Open("./static.yml")
if err != nil {
return nil, err
}

// Create new enforcer
e, err := enforcer.New(s.v, ipp, static)
defer e.Close()
if err != nil {
return nil, err
}

added, removed, err := e.UpdateRecords()
if err != nil {
return nil, err
}

rev, err := ipp.Revision()
if err != nil {
log.Errorf("Could not get revision of ipplan: %v", err)
rev = "<unknown>"
}
log.Info("Records updated to revision %s", rev)
resp := &pb.RefreshResponse{
Version: rev,
Added: uint32(added),
Removed: uint32(removed),
}
return resp, nil
}

func main() {
// Parse values
vars := &enforcer.Vars{}
flag.StringVar(&vars.Endpoint, "endpoint", "dns.net.dreamhack.se:443", "gRPC endpoint for DNS server")
flag.StringVar(&vars.Certificate, "cert", "./client.pem", "Client certificate to use")
flag.StringVar(&vars.Key, "key", "./key.pem", "Key to use")
flag.IntVar(&vars.HostTTL, "host-ttl", 1337, "Default TTL to use for host records")
vars.IgnoreTypes = strings.Split(*flag.String("ignore-types", "SOA,NS", "Do not remove or add these types of records"), ",")
zonefile := flag.String("zones-file", "./zones.prod.yaml", "YAML fail with DNS zones to manage")
flag.Parse()

l, err := net.Listen("tcp", *listenAddress)
if err != nil {
log.Fatalf("failed to listen: %v", err)
}

// Get data from zones file
b, err := ioutil.ReadFile(*zonefile)
if err != nil {
log.Error("You need to create a zone config file")
log.Fatal(err)
}
var zones struct {
Zones []string `yaml:"zones"`
}
err = yaml.Unmarshal(b, &zones)
if err != nil {
log.Error("You need to create a zone config file")
log.Fatal(err)
}
vars.Zones = zones.Zones

s := &enforcerServer{vars}
g := grpc.NewServer()
pb.RegisterEnforcerServiceServer(g, s)
reflection.Register(g)
g.Serve(l)
}
17 changes: 8 additions & 9 deletions enforcer/enforcer.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,36 @@
package enforcer

import (
"io"

"github.com/dhtech/dnsenforcer/enforcer/ipplan"
)

// Enforcer is used to update DNS servers with new data
type Enforcer struct {
IPPlan *ipplan.IPPlan
Vars *Vars
IPPlan *ipplan.IPPlan

static io.Reader
}

// Vars hold values needed for enforcer
type Vars struct {
Endpoint string
Certificate string
Key string
DBFile string
Static string
Zones []string
HostTTL int
DryRun bool
IgnoreTypes []string
}

// New returns a new DNS Enforcer
func New(vars *Vars) (*Enforcer, error) {
p, err := ipplan.Open(vars.DBFile)
if err != nil {
return nil, err
}
func New(vars *Vars, ipp *ipplan.IPPlan, static io.Reader) (*Enforcer, error) {
return &Enforcer{
IPPlan: p,
Vars: vars,
IPPlan: ipp,
static: static,
}, nil
}

Expand Down
14 changes: 14 additions & 0 deletions enforcer/ipplan/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,17 @@ func (p *IPPlan) Hosts() ([]*Host, error) {
}
return hosts, rows.Err()
}

func (p *IPPlan) Revision() (string, error) {
rows, err := p.db.Query(`SELECT value FROM meta_data WHERE name = 'revision';`)
if err != nil {
return "", err
}
defer rows.Close()
for rows.Next() {
var r string
rows.Scan(&r)
return r, nil
}
return "", fmt.Errorf("No revision metadata")
}
7 changes: 1 addition & 6 deletions enforcer/static.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,15 @@ package enforcer

import (
"io"
"os"

log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
)

// GetStaticRecords returns records that are specified in static YAML file
func (e *Enforcer) GetStaticRecords() ([]*Record, error) {
data, err := os.Open(e.Vars.Static)
if err != nil {
return nil, err
}
var records []*Record
reader := yaml.NewDecoder(data)
reader := yaml.NewDecoder(e.static)
for {
var record *Record
err := reader.Decode(&record)
Expand Down
20 changes: 12 additions & 8 deletions enforcer/updates.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@ import (
"google.golang.org/grpc/credentials"
)

// UpdateRecords logs all records to stdout
func (e *Enforcer) UpdateRecords() error {
// UpdateRecords logs all records to stdout and returns (added, removed, error)
func (e *Enforcer) UpdateRecords() (int, int, error) {
// Client Auth
certificate, err := tls.LoadX509KeyPair(e.Vars.Certificate, e.Vars.Key)
if err != nil {
return err
return 0, 0, err
}

host, _, err := net.SplitHostPort(e.Vars.Endpoint)
if err != nil {
return err
return 0, 0, err
}

creds := credentials.NewTLS(&tls.Config{
Expand All @@ -34,7 +34,7 @@ func (e *Enforcer) UpdateRecords() error {
// gRPC connection
conn, err := grpc.Dial(e.Vars.Endpoint, grpc.WithTransportCredentials(creds))
if err != nil {
return err
return 0, 0, err
}
defer conn.Close()

Expand Down Expand Up @@ -77,10 +77,10 @@ func (e *Enforcer) UpdateRecords() error {

wg.Wait()

// Get localally constructed records
// Get locally constructed records
localRecords, err := e.GetAllRecords()
if err != nil {
return err
return 0, 0, err
}

// Find which records to remove
Expand All @@ -99,13 +99,15 @@ func (e *Enforcer) UpdateRecords() error {
}

// Remove records that are present on server but no locally
removed := 0
if !e.Vars.DryRun {
log.Infof("Deleting %d records", len(remove))
for _, r := range remove {
if _, err := c.Remove(ctx, &dns.RemoveRequest{Record: []*dns.Record{r}}); err != nil {
log.Errorf("Remove of %s failed with %v", r.Domain, err)
} else {
log.Infof("Removed %s", r.Domain)
removed += 1
}
}
} else {
Expand All @@ -132,12 +134,14 @@ func (e *Enforcer) UpdateRecords() error {
}

// Insert records that are missing on the server
added := 0
if !e.Vars.DryRun {
log.Infof("Inserting %d records", len(insert))
for _, r := range insert {
if _, err := c.Insert(ctx, &dns.InsertRequest{Record: []*dns.Record{r}}); err != nil {
log.Errorf("Insert of %s failed with %v", r.Domain, err)
} else {
added += 1
log.Infof("Added %s", r.Domain)
}
}
Expand All @@ -147,5 +151,5 @@ func (e *Enforcer) UpdateRecords() error {
}
}

return nil
return added, removed, nil
}

0 comments on commit 5530fb1

Please sign in to comment.