Skip to content

Commit

Permalink
lsp/server: Cache AllRegoVersions at config load (#1325)
Browse files Browse the repository at this point in the history
* lsp/server: Cache AllRegoVersions at config load

This does not reload when a .manifest file is changed yet.

Signed-off-by: Charlie Egan <charlie@styra.com>

* concurrent/map: Nil map is empty

Signed-off-by: Charlie Egan <charlie@styra.com>

---------

Signed-off-by: Charlie Egan <charlie@styra.com>
  • Loading branch information
charlieegan3 authored and anderseknert committed Jan 13, 2025
1 parent f9f30eb commit 50c810c
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 30 deletions.
89 changes: 69 additions & 20 deletions internal/lsp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,21 @@ func NewLanguageServer(ctx context.Context, opts *LanguageServerOptions) *Langua

var ls *LanguageServer
ls = &LanguageServer{
cache: c,
regoStore: store,
logWriter: opts.LogWriter,
logLevel: opts.LogLevel,
lintFileJobs: make(chan lintFileJob, 10),
lintWorkspaceJobs: make(chan lintWorkspaceJob, 10),
builtinsPositionJobs: make(chan lintFileJob, 10),
commandRequest: make(chan types.ExecuteCommandParams, 10),
templateFileJobs: make(chan lintFileJob, 10),
configWatcher: lsconfig.NewWatcher(&lsconfig.WatcherOpts{LogFunc: ls.logf}),
completionsManager: completions.NewDefaultManager(ctx, c, store),
webServer: web.NewServer(c),
loadedBuiltins: concurrent.MapOf(make(map[string]map[string]*ast.Builtin)),
workspaceDiagnosticsPoll: opts.WorkspaceDiagnosticsPoll,
cache: c,
regoStore: store,
logWriter: opts.LogWriter,
logLevel: opts.LogLevel,
lintFileJobs: make(chan lintFileJob, 10),
lintWorkspaceJobs: make(chan lintWorkspaceJob, 10),
builtinsPositionJobs: make(chan lintFileJob, 10),
commandRequest: make(chan types.ExecuteCommandParams, 10),
templateFileJobs: make(chan lintFileJob, 10),
configWatcher: lsconfig.NewWatcher(&lsconfig.WatcherOpts{LogFunc: ls.logf}),
completionsManager: completions.NewDefaultManager(ctx, c, store),
webServer: web.NewServer(c),
loadedBuiltins: concurrent.MapOf(make(map[string]map[string]*ast.Builtin)),
workspaceDiagnosticsPoll: opts.WorkspaceDiagnosticsPoll,
loadedConfigAllRegoVersions: concurrent.MapOf(make(map[string]ast.RegoVersion)),
}

return ls
Expand All @@ -114,10 +115,13 @@ type LanguageServer struct {
regoStore storage.Store
conn *jsonrpc2.Conn

configWatcher *lsconfig.Watcher
loadedConfig *config.Config
configWatcher *lsconfig.Watcher
loadedConfig *config.Config
// this is also used to lock the updates to the cache of enabled rules
loadedConfigLock sync.Mutex
loadedConfigEnabledNonAggregateRules []string
loadedConfigEnabledAggregateRules []string
loadedConfigAllRegoVersions *concurrent.Map[string, ast.RegoVersion]
loadedBuiltins *concurrent.Map[string, map[string]*ast.Builtin]

clientInitializationOptions types.InitializationOptions
Expand All @@ -138,9 +142,6 @@ type LanguageServer struct {
workspaceRootURI string
clientIdentifier clients.Identifier

// this is also used to lock the updates to the cache of enabled rules
loadedConfigLock sync.Mutex

workspaceDiagnosticsPoll time.Duration
}

Expand Down Expand Up @@ -551,6 +552,22 @@ func (l *LanguageServer) StartConfigWorker(ctx context.Context) {
l.loadedConfig = &mergedConfig
l.loadedConfigLock.Unlock()

// Rego versions may have changed, so reload them.
allRegoVersions, err := config.AllRegoVersions(
uri.ToPath(l.clientIdentifier, l.workspaceRootURI),
l.getLoadedConfig(),
)
if err != nil {
l.logf(log.LevelMessage, "failed to reload rego versions: %s", err)
}

l.loadedConfigAllRegoVersions.Clear()

for k, v := range allRegoVersions {
l.loadedConfigAllRegoVersions.Set(k, v)
}

// Enabled rules might have changed with the new config, so reload.
err = l.loadEnabledRulesFromConfig(ctx, mergedConfig)
if err != nil {
l.logf(log.LevelMessage, "failed to cache enabled rules: %s", err)
Expand Down Expand Up @@ -1096,6 +1113,32 @@ func (l *LanguageServer) StartWebServer(ctx context.Context) {
l.webServer.Start(ctx)
}

func (l *LanguageServer) determineVersionForFile(fileURI string) ast.RegoVersion {
var versionedDirs []string

// if we have no information, then we can return the default
if l.loadedConfigAllRegoVersions.Len() == 0 {
return ast.RegoV1
}

versionedDirs = util.Keys(l.loadedConfigAllRegoVersions.Clone())
slices.Sort(versionedDirs)
slices.Reverse(versionedDirs)

path := strings.TrimPrefix(fileURI, l.workspaceRootURI+"/")

for _, versionedDir := range versionedDirs {
if strings.HasPrefix(path, versionedDir) {
val, ok := l.loadedConfigAllRegoVersions.Get(versionedDir)
if ok {
return val
}
}
}

return ast.RegoV1
}

func (l *LanguageServer) templateContentsForFile(fileURI string) (string, error) {
// this function should not be called with files in the root, but if it is,
// then it is an error to prevent unwanted behavior.
Expand Down Expand Up @@ -1185,7 +1228,13 @@ func (l *LanguageServer) templateContentsForFile(fileURI string) (string, error)
pkg += "_test"
}

return fmt.Sprintf("package %s\n\nimport rego.v1\n", pkg), nil
version := l.determineVersionForFile(fileURI)

if version == ast.RegoV0 {
return fmt.Sprintf("package %s\n\nimport rego.v1\n", pkg), nil
}

return fmt.Sprintf("package %s\n\n", pkg), nil
}

func (l *LanguageServer) fixEditParams(
Expand Down
148 changes: 148 additions & 0 deletions internal/lsp/server_all_rego_versions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package lsp

import (
"context"
"os"
"path/filepath"
"testing"
"time"

"github.com/open-policy-agent/opa/v1/ast"

"github.com/styrainc/regal/internal/lsp/clients"
"github.com/styrainc/regal/internal/lsp/log"
"github.com/styrainc/regal/internal/lsp/uri"
"github.com/styrainc/regal/pkg/config"
)

func TestAllRegoVersions(t *testing.T) {
t.Parallel()

testCases := map[string]struct {
FileKey string
ExpectedVersion ast.RegoVersion
DiskContents map[string]string
}{
"unknown version": {
FileKey: "foo/bar.rego",
DiskContents: map[string]string{
"foo/bar.rego": "package foo",
".regal/config.yaml": "",
},
ExpectedVersion: ast.RegoV1,
},
"version set in project config": {
FileKey: "foo/bar.rego",
DiskContents: map[string]string{
"foo/bar.rego": "package foo",
".regal/config.yaml": `
project:
rego-version: 0
`,
},
ExpectedVersion: ast.RegoV0,
},
"version set in root config": {
FileKey: "foo/bar.rego",
DiskContents: map[string]string{
"foo/bar.rego": "package foo",
".regal/config.yaml": `
project:
rego-version: 1
roots:
- path: foo
rego-version: 0
`,
},
ExpectedVersion: ast.RegoV0,
},
"version set in manifest": {
FileKey: "foo/bar.rego",
DiskContents: map[string]string{
"foo/bar.rego": "package foo",
"foo/.manifest": `{"rego_version": 0}`,
".regal/config.yaml": ``,
},
ExpectedVersion: ast.RegoV0,
},
"version set in manifest, overridden by config": {
FileKey: "foo/bar.rego",
DiskContents: map[string]string{
"foo/bar.rego": "package foo",
"foo/.manifest": `{"rego_version": 1}`,
".regal/config.yaml": `
project:
roots:
- path: foo
rego-version: 0
`,
},
ExpectedVersion: ast.RegoV0,
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()

td := t.TempDir()

// init the state on disk
for f, c := range tc.DiskContents {
dir := filepath.Dir(f)

if err := os.MkdirAll(filepath.Join(td, dir), 0o755); err != nil {
t.Fatalf("failed to create directory %s: %s", dir, err)
}

if err := os.WriteFile(filepath.Join(td, f), []byte(c), 0o600); err != nil {
t.Fatalf("failed to write file %s: %s", f, err)
}
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ls := NewLanguageServer(ctx, &LanguageServerOptions{LogWriter: newTestLogger(t), LogLevel: log.LevelDebug})
ls.workspaceRootURI = uri.FromPath(clients.IdentifierGeneric, td)

// have the server load the config
go ls.StartConfigWorker(ctx)

configFile, err := config.FindConfig(td)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

ls.configWatcher.Watch(configFile.Name())

// wait for ls.loadedConfig to be set
timeout := time.NewTimer(determineTimeout())
defer timeout.Stop()

for success := false; !success; {
select {
default:
if ls.getLoadedConfig() != nil {
success = true

break
}

time.Sleep(500 * time.Millisecond)
case <-timeout.C:
t.Fatalf("timed out waiting for config to be set")
}
}

// check it has the correct version for the file of interest
fileURI := uri.FromPath(clients.IdentifierGeneric, filepath.Join(td, tc.FileKey))

version := ls.determineVersionForFile(fileURI)

if version != tc.ExpectedVersion {
t.Errorf("expected version %v, got %v", tc.ExpectedVersion, version)
}
})
}
}
49 changes: 39 additions & 10 deletions internal/lsp/server_template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,26 @@ import (

"github.com/sourcegraph/jsonrpc2"

"github.com/open-policy-agent/opa/v1/ast"

"github.com/styrainc/regal/internal/lsp/clients"
"github.com/styrainc/regal/internal/lsp/log"
"github.com/styrainc/regal/internal/lsp/types"
"github.com/styrainc/regal/internal/lsp/uri"
"github.com/styrainc/regal/internal/util/concurrent"
)

func TestTemplateContentsForFile(t *testing.T) {
t.Parallel()

testCases := map[string]struct {
FileKey string
CacheFileContents string
DiskContents map[string]string
RequireConfig bool
ExpectedContents string
ExpectedError string
FileKey string
CacheFileContents string
DiskContents map[string]string
RequireConfig bool
ServerAllRegoVersions *concurrent.Map[string, ast.RegoVersion]
ExpectedContents string
ExpectedError string
}{
"existing contents in file": {
FileKey: "foo/bar.rego",
Expand All @@ -49,7 +53,7 @@ func TestTemplateContentsForFile(t *testing.T) {
"foo/bar.rego": "",
".regal/config.yaml": "",
},
ExpectedContents: "package foo\n\nimport rego.v1\n",
ExpectedContents: "package foo\n\n",
},
"empty test file is templated based on root": {
FileKey: "foo/bar_test.rego",
Expand All @@ -59,7 +63,7 @@ func TestTemplateContentsForFile(t *testing.T) {
".regal/config.yaml": "",
},
RequireConfig: true,
ExpectedContents: "package foo_test\n\nimport rego.v1\n",
ExpectedContents: "package foo_test\n\n",
},
"empty deeply nested file is templated based on root": {
FileKey: "foo/bar/baz/bax.rego",
Expand All @@ -68,8 +72,32 @@ func TestTemplateContentsForFile(t *testing.T) {
"foo/bar/baz/bax.rego": "",
".regal/config.yaml": "",
},
ExpectedContents: "package foo.bar.baz\n\n",
},
"v0 templating using rego version setting": {
FileKey: "foo/bar/baz/bax.rego",
CacheFileContents: "",
ServerAllRegoVersions: concurrent.MapOf(map[string]ast.RegoVersion{
"foo": ast.RegoV0,
}),
DiskContents: map[string]string{
"foo/bar/baz/bax.rego": "",
".regal/config.yaml": "", // we manually set the versions, config not loaded in these tests
},
ExpectedContents: "package foo.bar.baz\n\nimport rego.v1\n",
},
"v1 templating using rego version setting": {
FileKey: "foo/bar/baz/bax.rego",
CacheFileContents: "",
ServerAllRegoVersions: concurrent.MapOf(map[string]ast.RegoVersion{
"foo": ast.RegoV1,
}),
DiskContents: map[string]string{
"foo/bar/baz/bax.rego": "",
".regal/config.yaml": "", // we manually set the versions, config not loaded in these tests
},
ExpectedContents: "package foo.bar.baz\n\n",
},
}

for name, tc := range testCases {
Expand Down Expand Up @@ -100,6 +128,8 @@ func TestTemplateContentsForFile(t *testing.T) {

ls.workspaceRootURI = uri.FromPath(clients.IdentifierGeneric, td)

ls.loadedConfigAllRegoVersions = tc.ServerAllRegoVersions

fileURI := uri.FromPath(clients.IdentifierGeneric, filepath.Join(td, tc.FileKey))

ls.cache.SetFileContents(fileURI, tc.CacheFileContents)
Expand Down Expand Up @@ -188,7 +218,6 @@ func TestTemplateContentsForFileWithUnknownRoot(t *testing.T) {

exp := `package foo
import rego.v1
`
if exp != newContents {
t.Errorf("unexpected content: %s, want %s", newContents, exp)
Expand Down Expand Up @@ -279,7 +308,7 @@ func TestNewFileTemplating(t *testing.T) {
{
"edits": [
{
"newText": "package foo.bar_test\n\nimport rego.v1\n",
"newText": "package foo.bar_test\n\n",
"range": {
"end": {
"character": 0,
Expand Down
Loading

0 comments on commit 50c810c

Please sign in to comment.