From 5530fb1acb5b4a05842b3c4c99a1925d735528ab Mon Sep 17 00:00:00 2001 From: Christian Svensson Date: Tue, 20 Nov 2018 17:11:23 +0100 Subject: [PATCH] Add gRPC server for enforcing --- .gitignore | 5 +- cmd/enforce/main.go | 24 +++++++-- cmd/enforcerd/main.go | 102 ++++++++++++++++++++++++++++++++++++++ enforcer/enforcer.go | 17 +++---- enforcer/ipplan/reader.go | 14 ++++++ enforcer/static.go | 7 +-- enforcer/updates.go | 20 +++++--- 7 files changed, 161 insertions(+), 28 deletions(-) create mode 100644 cmd/enforcerd/main.go diff --git a/.gitignore b/.gitignore index 92d8051..f13ca75 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ *.db *.pem -*.prod.yaml \ No newline at end of file +*.prod.yaml +.*.swp +cmd/enforce/enforce +cmd/enforcerd/enforcerd diff --git a/cmd/enforce/main.go b/cmd/enforce/main.go index 977d9cc..3cde929 100644 --- a/cmd/enforce/main.go +++ b/cmd/enforce/main.go @@ -3,21 +3,26 @@ package main import ( "flag" "io/ioutil" + "os" "strings" "github.com/dhtech/dnsenforcer/enforcer" + "github.com/dhtech/dnsenforcer/enforcer/ipplan" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" ) +var ( + dbFile = flag.String("ipplan", "./ipplan.db", "Path to ipplan file to use") + staticFile = flag.String("static", "./static.prod.yaml", "Path to static file to use") +) + func main() { // Parse values vars := &enforcer.Vars{} flag.StringVar(&vars.Endpoint, "endpoint", "dns.net.dreamhack.se:443", "gRPC endpoint for DNS server") flag.StringVar(&vars.Certificate, "cert", "./client.pem", "Client certificate to use") flag.StringVar(&vars.Key, "key", "./key.pem", "Key to use") - flag.StringVar(&vars.DBFile, "ipplan", "./ipplan.db", "Path to ipplan file to use") - flag.StringVar(&vars.Static, "static", "./static.prod.yaml", "Path to static file to use") flag.IntVar(&vars.HostTTL, "host-ttl", 1337, "Default TTL to use for host records") flag.BoolVar(&vars.DryRun, "dry-run", false, "Do not actually update records on the DNS server") vars.IgnoreTypes = strings.Split(*flag.String("ignore-types", "SOA,NS", "Do not remove or add these types of records"), ",") @@ -40,15 +45,26 @@ func main() { } vars.Zones = zones.Zones + ipp, err := ipplan.Open(*dbFile) + if err != nil { + log.Fatal(err) + } + + static, err := os.Open(*staticFile) + if err != nil { + log.Error("You need to create a static record file") + log.Fatal(err) + } + log.Info("Generating DNS records...") // Create new enforcer - e, err := enforcer.New(vars) + e, err := enforcer.New(vars, ipp, static) defer e.Close() if err != nil { log.Fatal(err) } - err = e.UpdateRecords() + _, _, err = e.UpdateRecords() if err != nil { log.Fatal(err) } diff --git a/cmd/enforcerd/main.go b/cmd/enforcerd/main.go new file mode 100644 index 0000000..30feb93 --- /dev/null +++ b/cmd/enforcerd/main.go @@ -0,0 +1,102 @@ +package main + +import ( + "context" + "flag" + "io/ioutil" + "net" + "os" + "strings" + + "github.com/dhtech/dnsenforcer/enforcer" + "github.com/dhtech/dnsenforcer/enforcer/ipplan" + pb "github.com/dhtech/proto/dns" + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" + "gopkg.in/yaml.v2" +) + +var ( + listenAddress = flag.String("listen", ":1215", "address to listen to") +) + +type enforcerServer struct { + v *enforcer.Vars +} + +func (s *enforcerServer) Refresh(ctx context.Context, req *pb.RefreshRequest) (*pb.RefreshResponse, error) { + ipp, err := ipplan.Open("/etc/ipplan.db") + if err != nil { + return nil, err + } + + static, err := os.Open("./static.yml") + if err != nil { + return nil, err + } + + // Create new enforcer + e, err := enforcer.New(s.v, ipp, static) + defer e.Close() + if err != nil { + return nil, err + } + + added, removed, err := e.UpdateRecords() + if err != nil { + return nil, err + } + + rev, err := ipp.Revision() + if err != nil { + log.Errorf("Could not get revision of ipplan: %v", err) + rev = "" + } + log.Info("Records updated to revision %s", rev) + resp := &pb.RefreshResponse{ + Version: rev, + Added: uint32(added), + Removed: uint32(removed), + } + return resp, nil +} + +func main() { + // Parse values + vars := &enforcer.Vars{} + flag.StringVar(&vars.Endpoint, "endpoint", "dns.net.dreamhack.se:443", "gRPC endpoint for DNS server") + flag.StringVar(&vars.Certificate, "cert", "./client.pem", "Client certificate to use") + flag.StringVar(&vars.Key, "key", "./key.pem", "Key to use") + flag.IntVar(&vars.HostTTL, "host-ttl", 1337, "Default TTL to use for host records") + vars.IgnoreTypes = strings.Split(*flag.String("ignore-types", "SOA,NS", "Do not remove or add these types of records"), ",") + zonefile := flag.String("zones-file", "./zones.prod.yaml", "YAML fail with DNS zones to manage") + flag.Parse() + + l, err := net.Listen("tcp", *listenAddress) + if err != nil { + log.Fatalf("failed to listen: %v", err) + } + + // Get data from zones file + b, err := ioutil.ReadFile(*zonefile) + if err != nil { + log.Error("You need to create a zone config file") + log.Fatal(err) + } + var zones struct { + Zones []string `yaml:"zones"` + } + err = yaml.Unmarshal(b, &zones) + if err != nil { + log.Error("You need to create a zone config file") + log.Fatal(err) + } + vars.Zones = zones.Zones + + s := &enforcerServer{vars} + g := grpc.NewServer() + pb.RegisterEnforcerServiceServer(g, s) + reflection.Register(g) + g.Serve(l) +} diff --git a/enforcer/enforcer.go b/enforcer/enforcer.go index 18fa582..74718e3 100644 --- a/enforcer/enforcer.go +++ b/enforcer/enforcer.go @@ -1,13 +1,17 @@ package enforcer import ( + "io" + "github.com/dhtech/dnsenforcer/enforcer/ipplan" ) // Enforcer is used to update DNS servers with new data type Enforcer struct { - IPPlan *ipplan.IPPlan Vars *Vars + IPPlan *ipplan.IPPlan + + static io.Reader } // Vars hold values needed for enforcer @@ -15,8 +19,6 @@ type Vars struct { Endpoint string Certificate string Key string - DBFile string - Static string Zones []string HostTTL int DryRun bool @@ -24,14 +26,11 @@ type Vars struct { } // New returns a new DNS Enforcer -func New(vars *Vars) (*Enforcer, error) { - p, err := ipplan.Open(vars.DBFile) - if err != nil { - return nil, err - } +func New(vars *Vars, ipp *ipplan.IPPlan, static io.Reader) (*Enforcer, error) { return &Enforcer{ - IPPlan: p, Vars: vars, + IPPlan: ipp, + static: static, }, nil } diff --git a/enforcer/ipplan/reader.go b/enforcer/ipplan/reader.go index a3db686..1edcc07 100644 --- a/enforcer/ipplan/reader.go +++ b/enforcer/ipplan/reader.go @@ -80,3 +80,17 @@ func (p *IPPlan) Hosts() ([]*Host, error) { } return hosts, rows.Err() } + +func (p *IPPlan) Revision() (string, error) { + rows, err := p.db.Query(`SELECT value FROM meta_data WHERE name = 'revision';`) + if err != nil { + return "", err + } + defer rows.Close() + for rows.Next() { + var r string + rows.Scan(&r) + return r, nil + } + return "", fmt.Errorf("No revision metadata") +} diff --git a/enforcer/static.go b/enforcer/static.go index 39b9eb2..5ee92f3 100644 --- a/enforcer/static.go +++ b/enforcer/static.go @@ -2,7 +2,6 @@ package enforcer import ( "io" - "os" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" @@ -10,12 +9,8 @@ import ( // GetStaticRecords returns records that are specified in static YAML file func (e *Enforcer) GetStaticRecords() ([]*Record, error) { - data, err := os.Open(e.Vars.Static) - if err != nil { - return nil, err - } var records []*Record - reader := yaml.NewDecoder(data) + reader := yaml.NewDecoder(e.static) for { var record *Record err := reader.Decode(&record) diff --git a/enforcer/updates.go b/enforcer/updates.go index 457a7c1..0ecb94d 100644 --- a/enforcer/updates.go +++ b/enforcer/updates.go @@ -13,17 +13,17 @@ import ( "google.golang.org/grpc/credentials" ) -// UpdateRecords logs all records to stdout -func (e *Enforcer) UpdateRecords() error { +// UpdateRecords logs all records to stdout and returns (added, removed, error) +func (e *Enforcer) UpdateRecords() (int, int, error) { // Client Auth certificate, err := tls.LoadX509KeyPair(e.Vars.Certificate, e.Vars.Key) if err != nil { - return err + return 0, 0, err } host, _, err := net.SplitHostPort(e.Vars.Endpoint) if err != nil { - return err + return 0, 0, err } creds := credentials.NewTLS(&tls.Config{ @@ -34,7 +34,7 @@ func (e *Enforcer) UpdateRecords() error { // gRPC connection conn, err := grpc.Dial(e.Vars.Endpoint, grpc.WithTransportCredentials(creds)) if err != nil { - return err + return 0, 0, err } defer conn.Close() @@ -77,10 +77,10 @@ func (e *Enforcer) UpdateRecords() error { wg.Wait() - // Get localally constructed records + // Get locally constructed records localRecords, err := e.GetAllRecords() if err != nil { - return err + return 0, 0, err } // Find which records to remove @@ -99,6 +99,7 @@ func (e *Enforcer) UpdateRecords() error { } // Remove records that are present on server but no locally + removed := 0 if !e.Vars.DryRun { log.Infof("Deleting %d records", len(remove)) for _, r := range remove { @@ -106,6 +107,7 @@ func (e *Enforcer) UpdateRecords() error { log.Errorf("Remove of %s failed with %v", r.Domain, err) } else { log.Infof("Removed %s", r.Domain) + removed += 1 } } } else { @@ -132,12 +134,14 @@ func (e *Enforcer) UpdateRecords() error { } // Insert records that are missing on the server + added := 0 if !e.Vars.DryRun { log.Infof("Inserting %d records", len(insert)) for _, r := range insert { if _, err := c.Insert(ctx, &dns.InsertRequest{Record: []*dns.Record{r}}); err != nil { log.Errorf("Insert of %s failed with %v", r.Domain, err) } else { + added += 1 log.Infof("Added %s", r.Domain) } } @@ -147,5 +151,5 @@ func (e *Enforcer) UpdateRecords() error { } } - return nil + return added, removed, nil }