Skip to content

Commit

Permalink
fix: ensure config's thread safety (#23)
Browse files Browse the repository at this point in the history
* fix: ensure config's thread safety

* fix: add thread-safety to the hosts map
  • Loading branch information
anfragment authored Dec 7, 2023
1 parent 856ccf8 commit f3f18d0
Show file tree
Hide file tree
Showing 13 changed files with 219 additions and 84 deletions.
14 changes: 2 additions & 12 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ import (
"context"
"log"

"github.com/anfragment/zen/certmanager"
"github.com/anfragment/zen/config"
"github.com/anfragment/zen/filter"
"github.com/anfragment/zen/proxy"
)

Expand Down Expand Up @@ -39,17 +37,8 @@ func (a *App) domReady(ctx context.Context) {

// StartProxy initializes the associated resources and starts the proxy
func (a *App) StartProxy() string {
certmanager := certmanager.GetCertManager()
a.proxy = proxy.NewProxy(a.ctx)

if a.proxy == nil {
filter := filter.NewFilter()
a.proxy = proxy.NewProxy(filter, certmanager, a.ctx)
}

if err := certmanager.Init(); err != nil {
log.Printf("failed to initialize certmanager: %v", err)
return err.Error()
}
log.Println("starting proxy")
if err := a.proxy.Start(); err != nil {
log.Printf("failed to start proxy: %v", err)
Expand All @@ -69,5 +58,6 @@ func (a *App) StopProxy() string {
log.Printf("failed to stop proxy: %v", err)
return err.Error()
}
a.proxy = nil
return ""
}
11 changes: 5 additions & 6 deletions certmanager/certmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,13 @@ func (cm *CertManager) Init() (err error) {
cm.keyPath = path.Join(folderName, keyName)
cm.certCache = make(map[string]tls.Certificate)

if config.Config.Certmanager.CAInstalled {
if config.Config.GetCAInstalled() {
if err = cm.loadCA(); err != nil {
err = fmt.Errorf("CA load: %v", err)
}
return
}

if err = os.Remove(cm.certPath); err != nil && !os.IsNotExist(err) {
err = fmt.Errorf("remove existing CA cert: %v", err)
return
Expand All @@ -134,8 +135,7 @@ func (cm *CertManager) Init() (err error) {
err = fmt.Errorf("install CA: %v", err)
return
}
config.Config.Certmanager.CAInstalled = true
config.Config.Save()
config.Config.SetCAInstalled(true)
})

return cm.initErr
Expand Down Expand Up @@ -276,7 +276,7 @@ func (cm *CertManager) ClearCache() {
//
// @frontend
func (cm *CertManager) UninstallCA() string {
if !config.Config.Certmanager.CAInstalled {
if !config.Config.GetCAInstalled() {
return "CA is not installed"
}
if err := cm.Init(); err != nil {
Expand All @@ -289,8 +289,7 @@ func (cm *CertManager) UninstallCA() string {
return err.Error()
}

config.Config.Certmanager.CAInstalled = false
config.Config.Save()
config.Config.SetCAInstalled(false)

cm.certData = nil
cm.keyData = nil
Expand Down
81 changes: 67 additions & 14 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package config
import (
"embed"
"encoding/json"
"fmt"
"log"
"os"
"path"
"sync"
)

// Config is the singleton config instance.
var Config config

type filterList struct {
Expand All @@ -18,20 +19,29 @@ type filterList struct {
Enabled bool `json:"enabled"`
}

// config stores and manages the configuration for the application.
// Although all fields are public, this is only for use by the JSON marshaller.
// All access to the config should be done through the exported methods.
//
// Methods that get called by the frontend should be annotated with @frontend.
type config struct {
sync.RWMutex
Filter struct {
FilterLists []filterList `json:"filterLists"`
} `json:"filter"`
Certmanager struct {
CAInstalled bool `json:"caInstalled"`
} `json:"certmanager"`
Proxy struct {
Port int `json:"port"`
Port uint16 `json:"port"`
} `json:"proxy"`
ConfigDir string `json:"-"`
DataDir string `json:"-"`
}

// Save saves the config to disk.
// It is not thread-safe, and should only be called if the caller has
// a lock on the config.
func (c *config) Save() error {
configData, err := json.MarshalIndent(c, "", " ")
if err != nil {
Expand All @@ -46,71 +56,114 @@ func (c *config) Save() error {
}

// GetFilterLists returns the list of enabled filter lists.
// Used on the frontend to display the list of filter lists.
//
// @frontend
func (c *config) GetFilterLists() []filterList {
c.RLock()
defer c.RUnlock()

return c.Filter.FilterLists
}

// AddFilterList adds a new filter list to the list of enabled filter lists.
// Used on the frontend to add a new filter list.
//
// @frontend
func (c *config) AddFilterList(list filterList) string {
c.Lock()
defer c.Unlock()

c.Filter.FilterLists = append(c.Filter.FilterLists, list)
if err := c.Save(); err != nil {
fmt.Printf("failed to save config: %v", err)
log.Printf("failed to save config: %v", err)
return err.Error()
}
return ""
}

// RemoveFilterList removes a filter list from the list of enabled filter lists.
// Used on the frontend to remove a filter list.
//
// @frontend
func (c *config) RemoveFilterList(url string) string {
c.Lock()
defer c.Unlock()

for i, filterList := range c.Filter.FilterLists {
if filterList.Url == url {
c.Filter.FilterLists = append(c.Filter.FilterLists[:i], c.Filter.FilterLists[i+1:]...)
break
}
}
if err := c.Save(); err != nil {
fmt.Printf("failed to save config: %v", err)
log.Printf("failed to save config: %v", err)
return err.Error()
}
return ""
}

// ToggleFilterList toggles the enabled state of a filter list.
// Used on the frontend to toggle the enabled state of a filter list.
//
// @frontend
func (c *config) ToggleFilterList(url string, enabled bool) string {
c.Lock()
defer c.Unlock()

for i, filterList := range c.Filter.FilterLists {
if filterList.Url == url {
c.Filter.FilterLists[i].Enabled = enabled
break
}
}
if err := c.Save(); err != nil {
fmt.Printf("failed to save config: %v", err)
log.Printf("failed to save config: %v", err)
return err.Error()
}
return ""
}

// GetPort returns the port the proxy is set to listen on.
// Used on the frontend in the settings manager.
func (c *config) GetPort() int {
//
// @frontend
func (c *config) GetPort() uint16 {
c.RLock()
defer c.RUnlock()

return c.Proxy.Port
}

// SetPort sets the port the proxy is set to listen on.
// Used on the frontend in the settings manager.
func (c *config) SetPort(port int) string {
//
// @frontend
func (c *config) SetPort(port uint16) string {
c.Lock()
defer c.Unlock()

c.Proxy.Port = port
if err := c.Save(); err != nil {
fmt.Printf("failed to save config: %v", err)
log.Printf("failed to save config: %v", err)
return err.Error()
}
return ""
}

// GetCAInstalled returns whether the CA is installed.
func (c *config) GetCAInstalled() bool {
c.RLock()
defer c.RUnlock()

return c.Certmanager.CAInstalled
}

// SetCAInstalled sets whether the CA is installed.
func (c *config) SetCAInstalled(caInstalled bool) {
c.Lock()
defer c.Unlock()

c.Certmanager.CAInstalled = caInstalled
if err := c.Save(); err != nil {
log.Printf("failed to save config: %v", err)
}
}

//go:embed default-config.json
var defaultConfig embed.FS

Expand Down
2 changes: 1 addition & 1 deletion filter/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (f *Filter) Init() {
var wg sync.WaitGroup
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
for _, filterList := range config.Config.Filter.FilterLists {
for _, filterList := range config.Config.GetFilterLists() {
if !filterList.Enabled {
continue
}
Expand Down
10 changes: 9 additions & 1 deletion filter/ruletree/ruletree.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/url"
"regexp"
"strings"
"sync"

"github.com/anfragment/zen/filter/ruletree/rule"
)
Expand All @@ -18,7 +19,8 @@ type RuleTree struct {
// root is the root node of the trie that stores the rules.
root node
// hosts maps hostnames to filter names.
hosts map[string]*string
hosts map[string]*string
hostsMu sync.RWMutex
}

var (
Expand Down Expand Up @@ -61,7 +63,10 @@ func (rt *RuleTree) AddRule(rawRule string, filterName *string) error {
return nil
}

rt.hostsMu.Lock()
rt.hosts[host] = filterName
rt.hostsMu.Unlock()

return nil
}

Expand Down Expand Up @@ -115,11 +120,14 @@ func (rt *RuleTree) AddRule(rawRule string, filterName *string) error {

func (rt *RuleTree) HandleRequest(req *http.Request) rule.RequestAction {
host := req.URL.Hostname()
rt.hostsMu.RLock()
if filterName, ok := rt.hosts[host]; ok {
rt.hostsMu.RUnlock()
// 0.0.0.0 may not be the actual IP address defined in the hosts file,
// but storing the actual one feels wasteful.
return rule.RequestAction{Type: rule.ActionBlock, RawRule: fmt.Sprintf("0.0.0.0 %s", host), FilterName: *filterName}
}
rt.hostsMu.RUnlock()

urlWithoutPort := url.URL{
Scheme: req.URL.Scheme,
Expand Down
14 changes: 13 additions & 1 deletion frontend/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
"@blueprintjs/select": "^5.0.17",
"is-emoji-supported": "^0.0.5",
"react": "^18.2.0",
"react-dom": "^18.2.0"
"react-dom": "^18.2.0",
"use-debounce": "^10.0.0"
},
"devDependencies": {
"@types/react": "^18.0.17",
Expand Down
2 changes: 1 addition & 1 deletion frontend/package.json.md5
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2caa3435ad409062c5b4a564f7805f4d
7ca1eca7d1b6cc425bc64ad995948df9
48 changes: 48 additions & 0 deletions frontend/src/SettingsManager/PortInput.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import { FormGroup, NumericInput } from '@blueprintjs/core';
import { useEffect, useState } from 'react';
import { useDebouncedCallback } from 'use-debounce';

import { GetPort, SetPort } from '../../wailsjs/go/config/config';

export function PortInput() {
const [state, setState] = useState({
port: 0,
loading: true,
});

useEffect(() => {
(async () => {
const port = await GetPort();
setState({ ...state, port, loading: false });
})();
}, []);

const setPort = useDebouncedCallback(async (port: number) => {
await SetPort(port);
}, 500);

return (
<FormGroup
label="Port"
labelFor="port"
helperText={
<>
The port the proxy server will listen on (0 for random). <br />
Using a port below 1024 may require elevated privileges.
</>
}
>
<NumericInput
id="port"
min={0}
max={65535}
value={state.port}
onValueChange={(port) => {
setState({ ...state, port });
setPort(port);
}}
disabled={state.loading}
/>
</FormGroup>
);
}
Loading

0 comments on commit f3f18d0

Please sign in to comment.