diff --git a/README.md b/README.md index b7e6a76..3e0c580 100644 --- a/README.md +++ b/README.md @@ -156,6 +156,17 @@ type = "sqlite" path = "./data/dashbrr.db" ``` +By default, the database file will be created in the same directory as your configuration file. For example: + +- If your config is at `/home/user/.config/dashbrr/config.toml`, the database will be at `/home/user/.config/dashbrr/data/dashbrr.db` +- If your config is at `/etc/dashbrr/config.toml`, the database will be at `/etc/dashbrr/data/dashbrr.db` + +You can override this behavior by using the `-db` flag to specify a different database location: + +```bash +dashbrr -config=/etc/dashbrr/config.toml -db=/var/lib/dashbrr/dashbrr.db +``` + ### Environment Variables For a complete list of available environment variables and their configurations, see our [Environment Variables Documentation](docs/env_vars.md). diff --git a/cmd/dashbrr/main.go b/cmd/dashbrr/main.go index 4ca1d7a..07ffdcb 100644 --- a/cmd/dashbrr/main.go +++ b/cmd/dashbrr/main.go @@ -10,6 +10,7 @@ import ( "net/http" "os" "os/signal" + "path/filepath" "strings" "syscall" "time" @@ -56,11 +57,46 @@ func startServer() { Str("build_date", date). Msg("Starting dashbrr") - configPath := flag.String("config", "config.toml", "path to config file") - dbPath := flag.String("db", "./data/dashbrr.db", "path to database file") + // Check environment variable first, then fall back to flag + defaultConfigPath := "config.toml" + if envPath := os.Getenv(config.EnvConfigPath); envPath != "" { + defaultConfigPath = envPath + } else { + // Check user config directory + userConfigDir, err := os.UserConfigDir() + if err != nil { + log.Error().Err(err).Msg("failed to get user config directory") + } + + base := []string{filepath.Join(userConfigDir, "dashbrr"), "/config"} + configs := []string{"config.toml", "config.yaml", "config.yml"} + + for _, b := range base { + for _, c := range configs { + p := filepath.Join(b, c) + if _, err := os.Stat(p); err == nil { + defaultConfigPath = p + break + } + } + if defaultConfigPath != "config.toml" { + break + } + } + } + configPath := flag.String("config", defaultConfigPath, "path to config file") + + var dbPath string + flag.StringVar(&dbPath, "db", "", "path to database file") listenAddr := flag.String("listen", ":8080", "address to listen on") flag.Parse() + // If dbPath wasn't set via flag, use config directory + if dbPath == "" { + configDir := filepath.Dir(*configPath) + dbPath = filepath.Join(configDir, "data", "dashbrr.db") + } + var cfg *config.Config var err error @@ -77,7 +113,7 @@ func startServer() { ListenAddr: *listenAddr, }, Database: config.DatabaseConfig{ - Path: *dbPath, + Path: dbPath, }, } log.Warn().Err(err).Msg("Failed to load configuration file, using defaults") diff --git a/docs/commands.md b/docs/commands.md index a3ecc74..4850191 100644 --- a/docs/commands.md +++ b/docs/commands.md @@ -2,6 +2,41 @@ This document outlines all available CLI commands in Dashbrr. +## Startup Flags + +When starting Dashbrr, you can use the following flags to control its configuration: + +```bash +# Start Dashbrr with default settings +dashbrr + +# Specify a custom config file location +dashbrr -config=/path/to/config.toml + +# Specify a custom database location +dashbrr -db=/path/to/database.db + +# Specify a custom listen address +dashbrr -listen=:8081 +``` + +By default: + +- The config file is loaded from `./config.toml` +- The database file is created in the same directory as the config file at `/data/dashbrr.db` +- The server listens on port 8080 + +For example: + +```bash +# Using config in /etc/dashbrr +dashbrr -config=/etc/dashbrr/config.toml +# Database will be created at /etc/dashbrr/data/dashbrr.db + +# Override default database location +dashbrr -config=/etc/dashbrr/config.toml -db=/var/lib/dashbrr/dashbrr.db +``` + ## Core Commands ### User Management diff --git a/docs/env_vars.md b/docs/env_vars.md index 22a7c43..20b3adc 100644 --- a/docs/env_vars.md +++ b/docs/env_vars.md @@ -7,6 +7,18 @@ - Format: `:` - Default: `0.0.0.0:8080` +## Configuration Path + +- `DASHBRR__CONFIG_PATH` + - Purpose: Path to the configuration file + - Default: `config.toml` + - Priority: Environment variable > User config directory > Command line flag > Default value + - Note: The application will check the following locations for the configuration file: + 1. The path specified by the `DASHBRR__CONFIG_PATH` environment variable. + 2. The user config directory (e.g., `~/.config/dashbrr`). + 3. The current working directory for `config.toml`, `config.yaml`, or `config.yml`. + 4. The `-config` command line flag can also be used to specify a different path. + ## Cache Configuration - `CACHE_TYPE` @@ -38,6 +50,11 @@ - `DASHBRR__DB_PATH` - Purpose: Path to SQLite database file - Example: `/data/dashbrr.db` + - Note: If not set, the database will be created in a 'data' subdirectory of the config file's location. This can be overridden by: + 1. Using the `-db` flag when starting dashbrr + 2. Setting this environment variable + 3. Specifying the path in the config file + - Priority: Command line flag > Environment variable > Config file > Default location ### PostgreSQL Configuration diff --git a/internal/api/middleware/auth.go b/internal/api/middleware/auth.go index e62a601..beb81b1 100644 --- a/internal/api/middleware/auth.go +++ b/internal/api/middleware/auth.go @@ -61,7 +61,7 @@ func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc { sessionKey = fmt.Sprintf("session:%s", sessionToken) err = m.cache.Get(c, sessionKey, &sessionData) if err != nil { - log.Error().Err(err).Msg("session not found") + log.Debug().Err(err).Msg("session not found") c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired session"}) c.Abort() return diff --git a/internal/api/routes/routes.go b/internal/api/routes/routes.go index e4f67a6..b169897 100644 --- a/internal/api/routes/routes.go +++ b/internal/api/routes/routes.go @@ -5,6 +5,7 @@ package routes import ( "os" + "path/filepath" "time" "github.com/gin-gonic/gin" @@ -26,15 +27,34 @@ func SetupRoutes(r *gin.Engine, db *database.DB, health *services.HealthService) r.Use(middleware.SetupCORS()) r.Use(middleware.Secure(nil)) // Add secure middleware with default config - // Initialize cache - store, err := cache.InitCache() + // Initialize cache with database directory for session storage + cacheConfig := cache.Config{ + DataDir: filepath.Dir(os.Getenv("DASHBRR__DB_PATH")), // Use same directory as database + } + + // Configure Redis if enabled + if os.Getenv("REDIS_HOST") != "" { + host := os.Getenv("REDIS_HOST") + port := os.Getenv("REDIS_PORT") + if port == "" { + port = "6379" + } + cacheConfig.RedisAddr = host + ":" + port + } + + store, err := cache.InitCache(cacheConfig) if err != nil { // This should never happen as InitCache always returns a valid store log.Debug().Err(err).Msg("Using memory cache") - store = cache.NewMemoryStore() + store = cache.NewMemoryStore(cacheConfig.DataDir) } - log.Debug().Str("type", os.Getenv("CACHE_TYPE")).Msg("Cache initialized") + // Determine cache type based on environment and Redis configuration + cacheType := "memory" + if os.Getenv("CACHE_TYPE") == "redis" && os.Getenv("REDIS_HOST") != "" { + cacheType = "redis" + } + log.Debug().Str("type", cacheType).Msg("Cache initialized") // Create rate limiters with different configurations apiRateLimiter := middleware.NewRateLimiter(store, time.Minute, 60, "api:") // 60 requests per minute for API diff --git a/internal/config/config.go b/internal/config/config.go index 9102602..1c09d53 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,6 +11,10 @@ import ( "github.com/pelletier/go-toml/v2" ) +const ( + EnvConfigPath = "DASHBRR__CONFIG_PATH" +) + // Config represents the main configuration structure type Config struct { Server ServerConfig `toml:"server"` diff --git a/internal/database/database.go b/internal/database/database.go index 1c9dc08..b5e2365 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -22,6 +22,7 @@ import ( type DB struct { *sql.DB driver string + path string } // Config holds database configuration @@ -134,7 +135,11 @@ func InitDBWithConfig(config *Config) (*DB, error) { Str("driver", config.Driver). Msg("Successfully connected to database") - db := &DB{database, config.Driver} + db := &DB{ + DB: database, + driver: config.Driver, + path: config.Path, + } // Initialize schema if err := db.initSchema(); err != nil { @@ -144,6 +149,11 @@ func InitDBWithConfig(config *Config) (*DB, error) { return db, nil } +// Path returns the database file path (for SQLite) +func (db *DB) Path() string { + return db.path +} + // initSchema creates the necessary database tables func (db *DB) initSchema() error { var autoIncrement string diff --git a/internal/services/cache/cache.go b/internal/services/cache/cache.go index bf27c64..12136c1 100644 --- a/internal/services/cache/cache.go +++ b/internal/services/cache/cache.go @@ -6,7 +6,7 @@ package cache import ( "context" "encoding/json" - "os" + "errors" "strconv" "strings" "sync" @@ -17,7 +17,8 @@ import ( ) var ( - ErrKeyNotFound = redis.Nil + ErrKeyNotFound = errors.New("cache: key not found") + ErrClosed = errors.New("cache: store is closed") ) const ( @@ -62,11 +63,6 @@ type localCacheItem struct { // NewCache creates a new Redis cache instance with optimized configuration func NewCache(addr string) (Store, error) { - // Check if memory cache is explicitly requested - if os.Getenv("CACHE_TYPE") == "memory" { - return NewMemoryStore(), nil - } - ctx, cancel := context.WithCancel(context.Background()) // Development-optimized Redis configuration @@ -134,7 +130,7 @@ func (s *RedisStore) Get(ctx context.Context, key string, value interface{}) err s.mu.RLock() if s.closed { s.mu.RUnlock() - return redis.ErrClosed + return ErrClosed } s.mu.RUnlock() @@ -186,6 +182,9 @@ func (s *RedisStore) Get(ctx context.Context, key string, value interface{}) err } } + if lastErr == redis.Nil { + return ErrKeyNotFound + } return lastErr } @@ -194,7 +193,7 @@ func (s *RedisStore) Set(ctx context.Context, key string, value interface{}, exp s.mu.RLock() if s.closed { s.mu.RUnlock() - return redis.ErrClosed + return ErrClosed } s.mu.RUnlock() @@ -294,7 +293,7 @@ func (s *RedisStore) Delete(ctx context.Context, key string) error { s.mu.RLock() if s.closed { s.mu.RUnlock() - return redis.ErrClosed + return ErrClosed } s.mu.RUnlock() @@ -332,7 +331,7 @@ func (s *RedisStore) Increment(ctx context.Context, key string, timestamp int64) s.mu.RLock() if s.closed { s.mu.RUnlock() - return redis.ErrClosed + return ErrClosed } s.mu.RUnlock() @@ -367,7 +366,7 @@ func (s *RedisStore) CleanAndCount(ctx context.Context, key string, windowStart s.mu.RLock() if s.closed { s.mu.RUnlock() - return redis.ErrClosed + return ErrClosed } s.mu.RUnlock() @@ -399,7 +398,7 @@ func (s *RedisStore) GetCount(ctx context.Context, key string) (int64, error) { s.mu.RLock() if s.closed { s.mu.RUnlock() - return 0, redis.ErrClosed + return 0, ErrClosed } s.mu.RUnlock() @@ -430,7 +429,7 @@ func (s *RedisStore) Expire(ctx context.Context, key string, expiration time.Dur s.mu.RLock() if s.closed { s.mu.RUnlock() - return redis.ErrClosed + return ErrClosed } s.mu.RUnlock() @@ -466,7 +465,7 @@ func (s *RedisStore) Close() error { s.mu.Lock() if s.closed { s.mu.Unlock() - return redis.ErrClosed + return ErrClosed } s.closed = true s.mu.Unlock() @@ -484,6 +483,6 @@ func (s *RedisStore) Close() error { s.local.items = make(map[string]*localCacheItem) }() - // Close Redis client + // Close client return s.client.Close() } diff --git a/internal/services/cache/init.go b/internal/services/cache/init.go index 8857ca5..078b936 100644 --- a/internal/services/cache/init.go +++ b/internal/services/cache/init.go @@ -5,7 +5,6 @@ package cache import ( "context" - "fmt" "os" "strings" "time" @@ -14,6 +13,15 @@ import ( "github.com/rs/zerolog/log" ) +// Config holds cache configuration options +type Config struct { + // Redis configuration + RedisAddr string + + // Memory cache configuration + DataDir string // Directory for persistent storage (derived from DB path) +} + // CacheType represents the type of cache to use type CacheType string @@ -23,20 +31,9 @@ const ( ) // getRedisOptions returns Redis configuration optimized for the current environment -func getRedisOptions() *redis.Options { +func getRedisOptions(addr string) *redis.Options { isDev := os.Getenv("GIN_MODE") != "release" - // Get Redis connection details from environment - host := os.Getenv("REDIS_HOST") - if host == "" { - host = "localhost" - } - port := os.Getenv("REDIS_PORT") - if port == "" { - port = "6379" - } - addr := fmt.Sprintf("%s:%s", host, port) - // Base configuration opts := &redis.Options{ Addr: addr, @@ -89,21 +86,21 @@ func getCacheType() CacheType { } } -// InitCache initializes a cache instance based on environment configuration. +// InitCache initializes a cache instance based on configuration. // It always returns a valid cache store, falling back to memory cache if Redis fails. -func InitCache() (Store, error) { +func InitCache(cfg Config) (Store, error) { cacheType := getCacheType() switch cacheType { case CacheTypeRedis: - // Only attempt Redis connection if Redis is explicitly configured - if os.Getenv("REDIS_HOST") == "" { - // Silently fall back to memory cache when Redis host isn't configured - return NewMemoryStore(), nil + // Only attempt Redis connection if Redis address is configured + if cfg.RedisAddr == "" { + // Silently fall back to memory cache when Redis isn't configured + return NewMemoryStore(cfg.DataDir), nil } isDev := os.Getenv("GIN_MODE") != "release" - opts := getRedisOptions() + opts := getRedisOptions(cfg.RedisAddr) // Create context with shorter timeout for development timeout := DefaultTimeout @@ -125,7 +122,7 @@ func InitCache() (Store, error) { // Only log error if Redis was explicitly requested log.Error().Err(err).Str("addr", opts.Addr).Msg("Failed to connect to explicitly configured Redis, falling back to memory cache") } - return NewMemoryStore(), err + return NewMemoryStore(cfg.DataDir), err } // Initialize Redis cache store @@ -135,15 +132,15 @@ func InitCache() (Store, error) { // Only log error if Redis was explicitly requested log.Error().Err(err).Msg("Failed to initialize explicitly configured Redis cache, falling back to memory cache") } - return NewMemoryStore(), err + return NewMemoryStore(cfg.DataDir), err } return store, nil case CacheTypeMemory: - return NewMemoryStore(), nil + return NewMemoryStore(cfg.DataDir), nil default: // This shouldn't happen due to getCacheType's default - return NewMemoryStore(), nil + return NewMemoryStore(cfg.DataDir), nil } } diff --git a/internal/services/cache/memory.go b/internal/services/cache/memory.go index 7bee0ae..058ce57 100644 --- a/internal/services/cache/memory.go +++ b/internal/services/cache/memory.go @@ -6,7 +6,10 @@ package cache import ( "context" "encoding/json" + "os" + "path/filepath" "strconv" + "strings" "sync" "time" @@ -24,6 +27,9 @@ type MemoryStore struct { // Additional maps for rate limiting functionality rateLimits sync.Map // map[string]*rateWindow + + // Session persistence + persistPath string } type rateWindow struct { @@ -31,18 +37,39 @@ type rateWindow struct { timestamps map[string]int64 } +type persistedItem struct { + Value []byte `json:"value"` + Expiration time.Time `json:"expiration"` +} + // NewMemoryStore creates a new in-memory cache instance -func NewMemoryStore() Store { +func NewMemoryStore(dataDir string) Store { ctx, cancel := context.WithCancel(context.Background()) store := &MemoryStore{ local: &LocalCache{ items: make(map[string]*localCacheItem), }, - ctx: ctx, - cancel: cancel, + ctx: ctx, + cancel: cancel, + persistPath: filepath.Join(dataDir, "sessions.json"), + } + + // Ensure directory exists with proper permissions + if err := os.MkdirAll(dataDir, 0700); err != nil { + log.Error().Err(err).Msg("Failed to create data directory") + } + + // Set proper permissions on sessions file if it exists + if _, err := os.Stat(store.persistPath); err == nil { + if err := os.Chmod(store.persistPath, 0600); err != nil { + log.Error().Err(err).Msg("Failed to set permissions on sessions file") + } } + // Load persisted sessions + store.loadSessions() + // Start cleanup goroutine store.wg.Add(1) go func() { @@ -53,12 +80,83 @@ func NewMemoryStore() Store { return store } +// loadSessions loads persisted sessions from disk +func (s *MemoryStore) loadSessions() { + data, err := os.ReadFile(s.persistPath) + if err != nil { + if !os.IsNotExist(err) { + log.Error().Err(err).Msg("Failed to read persisted sessions") + } + return + } + + var items map[string]persistedItem + if err := json.Unmarshal(data, &items); err != nil { + log.Error().Err(err).Msg("Failed to unmarshal persisted sessions") + return + } + + now := time.Now() + s.local.Lock() + for key, item := range items { + // Only load non-expired sessions + if now.Before(item.Expiration) { + s.local.items[key] = &localCacheItem{ + value: item.Value, + expiration: item.Expiration, + } + } + } + s.local.Unlock() +} + +// persistSessions saves sessions to disk +func (s *MemoryStore) persistSessions() { + s.local.RLock() + items := make(map[string]persistedItem) + now := time.Now() + + for key, item := range s.local.items { + // Only persist session data (not rate limiting or other cache items) + if strings.HasPrefix(key, "session:") || strings.HasPrefix(key, "oidc:session:") { + // Only persist non-expired sessions + if now.Before(item.expiration) { + items[key] = persistedItem{ + Value: item.value, + Expiration: item.expiration, + } + } + } + } + s.local.RUnlock() + + data, err := json.Marshal(items) + if err != nil { + log.Error().Err(err).Msg("Failed to marshal sessions for persistence") + return + } + + // Write to a temporary file first + tempFile := s.persistPath + ".tmp" + if err := os.WriteFile(tempFile, data, 0600); err != nil { + log.Error().Err(err).Msg("Failed to write temporary sessions file") + return + } + + // Rename temporary file to actual file (atomic operation) + if err := os.Rename(tempFile, s.persistPath); err != nil { + log.Error().Err(err).Msg("Failed to rename temporary sessions file") + _ = os.Remove(tempFile) // Clean up temp file if rename failed + return + } +} + // Get retrieves a value from cache func (s *MemoryStore) Get(ctx context.Context, key string, value interface{}) error { s.mu.RLock() if s.closed { s.mu.RUnlock() - return ErrKeyNotFound + return ErrClosed } s.mu.RUnlock() @@ -81,7 +179,7 @@ func (s *MemoryStore) Set(ctx context.Context, key string, value interface{}, ex s.mu.RLock() if s.closed { s.mu.RUnlock() - return ErrKeyNotFound + return ErrClosed } s.mu.RUnlock() @@ -102,6 +200,11 @@ func (s *MemoryStore) Set(ctx context.Context, key string, value interface{}, ex } s.local.Unlock() + // Persist sessions when they're updated + if strings.HasPrefix(key, "session:") || strings.HasPrefix(key, "oidc:session:") { + s.persistSessions() + } + return nil } @@ -110,7 +213,7 @@ func (s *MemoryStore) Delete(ctx context.Context, key string) error { s.mu.RLock() if s.closed { s.mu.RUnlock() - return ErrKeyNotFound + return ErrClosed } s.mu.RUnlock() @@ -118,6 +221,11 @@ func (s *MemoryStore) Delete(ctx context.Context, key string) error { delete(s.local.items, key) s.local.Unlock() + // Persist sessions when they're deleted + if strings.HasPrefix(key, "session:") || strings.HasPrefix(key, "oidc:session:") { + s.persistSessions() + } + return nil } @@ -126,7 +234,7 @@ func (s *MemoryStore) Increment(ctx context.Context, key string, timestamp int64 s.mu.RLock() if s.closed { s.mu.RUnlock() - return ErrKeyNotFound + return ErrClosed } s.mu.RUnlock() @@ -147,7 +255,7 @@ func (s *MemoryStore) CleanAndCount(ctx context.Context, key string, windowStart s.mu.RLock() if s.closed { s.mu.RUnlock() - return ErrKeyNotFound + return ErrClosed } s.mu.RUnlock() @@ -170,7 +278,7 @@ func (s *MemoryStore) GetCount(ctx context.Context, key string) (int64, error) { s.mu.RLock() if s.closed { s.mu.RUnlock() - return 0, ErrKeyNotFound + return 0, ErrClosed } s.mu.RUnlock() @@ -190,13 +298,17 @@ func (s *MemoryStore) Expire(ctx context.Context, key string, expiration time.Du s.mu.RLock() if s.closed { s.mu.RUnlock() - return ErrKeyNotFound + return ErrClosed } s.mu.RUnlock() s.local.Lock() if item, exists := s.local.items[key]; exists { item.expiration = time.Now().Add(expiration) + // Persist sessions when their expiration is updated + if strings.HasPrefix(key, "session:") || strings.HasPrefix(key, "oidc:session:") { + s.persistSessions() + } } s.local.Unlock() @@ -208,7 +320,7 @@ func (s *MemoryStore) Close() error { s.mu.Lock() if s.closed { s.mu.Unlock() - return ErrKeyNotFound + return ErrClosed } s.closed = true s.mu.Unlock() @@ -216,6 +328,9 @@ func (s *MemoryStore) Close() error { s.cancel() s.wg.Wait() + // Persist sessions before clearing the cache + s.persistSessions() + s.local.Lock() s.local.items = make(map[string]*localCacheItem) s.local.Unlock() @@ -231,16 +346,25 @@ func (s *MemoryStore) localCacheCleanup() { select { case <-ticker.C: now := time.Now() + needsPersist := false // Cleanup main cache s.local.Lock() for key, item := range s.local.items { if now.After(item.expiration) { delete(s.local.items, key) + if strings.HasPrefix(key, "session:") || strings.HasPrefix(key, "oidc:session:") { + needsPersist = true + } } } s.local.Unlock() + // Persist sessions if any were removed + if needsPersist { + s.persistSessions() + } + // Cleanup rate limiting windows older than 24 hours windowStart := time.Now().Add(-24 * time.Hour).Unix() s.rateLimits.Range(func(key, value interface{}) bool { diff --git a/internal/services/cache/memory_test.go b/internal/services/cache/memory_test.go index c9e5bc4..0055471 100644 --- a/internal/services/cache/memory_test.go +++ b/internal/services/cache/memory_test.go @@ -10,7 +10,10 @@ import ( ) func TestMemoryStore(t *testing.T) { - store := NewMemoryStore() + // Create a temporary directory for testing + tempDir := t.TempDir() + + store := NewMemoryStore(tempDir) defer store.Close() ctx := context.Background() @@ -147,7 +150,10 @@ func TestMemoryStore(t *testing.T) { } func TestMemoryStoreClose(t *testing.T) { - store := NewMemoryStore() + // Create a temporary directory for testing + tempDir := t.TempDir() + + store := NewMemoryStore(tempDir) // Test normal operations ctx := context.Background() @@ -164,13 +170,47 @@ func TestMemoryStoreClose(t *testing.T) { // Verify operations fail after close err = store.Set(ctx, "key2", "value2", time.Minute) - if err != ErrKeyNotFound { - t.Errorf("Expected ErrKeyNotFound after close, got %v", err) + if err != ErrClosed { + t.Errorf("Expected ErrClosed after close, got %v", err) } var result string err = store.Get(ctx, "key", &result) - if err != ErrKeyNotFound { - t.Errorf("Expected ErrKeyNotFound after close, got %v", err) + if err != ErrClosed { + t.Errorf("Expected ErrClosed after close, got %v", err) + } +} + +func TestMemoryStorePersistence(t *testing.T) { + // Create a temporary directory for testing + tempDir := t.TempDir() + + // Create a store and add some data + store := NewMemoryStore(tempDir) + ctx := context.Background() + + err := store.Set(ctx, "session:test", "test_value", time.Hour) + if err != nil { + t.Errorf("Failed to set value: %v", err) + } + + // Close the store + err = store.Close() + if err != nil { + t.Errorf("Failed to close store: %v", err) + } + + // Create a new store with the same directory + store2 := NewMemoryStore(tempDir) + defer store2.Close() + + // Try to get the persisted value + var result string + err = store2.Get(ctx, "session:test", &result) + if err != nil { + t.Errorf("Failed to get persisted value: %v", err) + } + if result != "test_value" { + t.Errorf("Expected 'test_value', got '%v'", result) } } diff --git a/internal/services/core/service.go b/internal/services/core/service.go index 0aa3383..d7a4ded 100644 --- a/internal/services/core/service.go +++ b/internal/services/core/service.go @@ -8,6 +8,8 @@ import ( "errors" "io" "net/http" + "os" + "path/filepath" "sync" "time" @@ -64,8 +66,28 @@ func (s *ServiceCore) initCache() error { return nil } + // Get database directory from environment + dataDir := filepath.Dir(os.Getenv("DASHBRR__DB_PATH")) + if dataDir == "." { + dataDir = "./data" // Default to ./data if not set + } + + // Initialize cache config + cfg := cache.Config{ + DataDir: dataDir, + } + + // Add Redis configuration if available + if host := os.Getenv("REDIS_HOST"); host != "" { + port := os.Getenv("REDIS_PORT") + if port == "" { + port = "6379" + } + cfg.RedisAddr = host + ":" + port + } + // Initialize cache using the cache package's initialization logic - store, err := cache.InitCache() + store, err := cache.InitCache(cfg) if err != nil { // If initialization fails, we'll still get a memory cache from InitCache // We can continue with the memory cache but should return the error