diff --git a/cli/src/envGenerator.js b/cli/src/envGenerator.js index 378d4fea2f..cc9c92a18d 100644 --- a/cli/src/envGenerator.js +++ b/cli/src/envGenerator.js @@ -3,141 +3,140 @@ import fs from "fs"; import chalk from "chalk"; export const generateEnv = (envValues) => { - let isDockerCompose = envValues.runOption === "docker-compose"; - let dbPort = isDockerCompose ? 3307 : 3306; - let platformUrl = isDockerCompose - ? "http://host.docker.internal:8000" - : "http://localhost:8000"; - - const envDefinition = getEnvDefinition( - envValues, - isDockerCompose, - dbPort, - platformUrl, - ); - - const envFileContent = generateEnvFileContent(envDefinition); - saveEnvFile(envFileContent); + let isDockerCompose = envValues.runOption === "docker-compose"; + let dbPort = isDockerCompose ? 3307 : 3306; + let platformUrl = isDockerCompose + ? "http://host.docker.internal:8000" + : "http://localhost:8000"; + + const envDefinition = getEnvDefinition( + envValues, + isDockerCompose, + dbPort, + platformUrl + ); + + const envFileContent = generateEnvFileContent(envDefinition); + saveEnvFile(envFileContent); }; const getEnvDefinition = (envValues, isDockerCompose, dbPort, platformUrl) => { - return { - "Deployment Environment": { - NODE_ENV: "development", - NEXT_PUBLIC_VERCEL_ENV: "${NODE_ENV}", - }, - NextJS: { - NEXT_PUBLIC_BACKEND_URL: "http://localhost:8000", - NEXT_PUBLIC_MAX_LOOPS: 100, - }, - "Next Auth config": { - NEXTAUTH_SECRET: generateAuthSecret(), - NEXTAUTH_URL: "http://localhost:3000", - }, - "Auth providers (Use if you want to get out of development mode sign-in)": { - GOOGLE_CLIENT_ID: "***", - GOOGLE_CLIENT_SECRET: "***", - GITHUB_CLIENT_ID: "***", - GITHUB_CLIENT_SECRET: "***", - DISCORD_CLIENT_SECRET: "***", - DISCORD_CLIENT_ID: "***", - }, - Backend: { - REWORKD_PLATFORM_ENVIRONMENT: "${NODE_ENV}", - REWORKD_PLATFORM_FF_MOCK_MODE_ENABLED: false, - REWORKD_PLATFORM_MAX_LOOPS: "${NEXT_PUBLIC_MAX_LOOPS}", - REWORKD_PLATFORM_OPENAI_API_KEY: - envValues.OpenAIApiKey || '""', - REWORKD_PLATFORM_FRONTEND_URL: "http://localhost:3000", - REWORKD_PLATFORM_RELOAD: true, - REWORKD_PLATFORM_OPENAI_API_BASE: "https://api.openai.com/v1", - REWORKD_PLATFORM_SERP_API_KEY: envValues.serpApiKey || '""', - REWORKD_PLATFORM_REPLICATE_API_KEY: envValues.replicateApiKey || '""', - }, - "Database (Backend)": { - REWORKD_PLATFORM_DATABASE_USER: "reworkd_platform", - REWORKD_PLATFORM_DATABASE_PASSWORD: "reworkd_platform", - REWORKD_PLATFORM_DATABASE_HOST: "db", - REWORKD_PLATFORM_DATABASE_PORT: dbPort, - REWORKD_PLATFORM_DATABASE_NAME: "reworkd_platform", - REWORKD_PLATFORM_DATABASE_URL: - "mysql://${REWORKD_PLATFORM_DATABASE_USER}:${REWORKD_PLATFORM_DATABASE_PASSWORD}@${REWORKD_PLATFORM_DATABASE_HOST}:${REWORKD_PLATFORM_DATABASE_PORT}/${REWORKD_PLATFORM_DATABASE_NAME}", - REWORKD_PLATFORM_DB_CA_PATH: isDockerCompose ? "/etc/ssl/certs/ca-certificates.crt" : "", - }, - "Database (Frontend)": { - DATABASE_USER: "reworkd_platform", - DATABASE_PASSWORD: "reworkd_platform", - DATABASE_HOST: "db", - DATABASE_PORT: dbPort, - DATABASE_NAME: "reworkd_platform", - DATABASE_URL: - "mysql://${DATABASE_USER}:${DATABASE_PASSWORD}@${DATABASE_HOST}:${DATABASE_PORT}/${DATABASE_NAME}", - }, - }; + return { + "Deployment Environment": { + NODE_ENV: "development", + NEXT_PUBLIC_VERCEL_ENV: "${NODE_ENV}", + }, + NextJS: { + NEXT_PUBLIC_BACKEND_URL: "http://localhost:8000", + NEXT_PUBLIC_MAX_LOOPS: 100, + }, + "Next Auth config": { + NEXTAUTH_SECRET: generateAuthSecret(), + NEXTAUTH_URL: "http://localhost:3000", + }, + "Auth providers (Use if you want to get out of development mode sign-in)": { + GOOGLE_CLIENT_ID: "***", + GOOGLE_CLIENT_SECRET: "***", + GITHUB_CLIENT_ID: "***", + GITHUB_CLIENT_SECRET: "***", + DISCORD_CLIENT_SECRET: "***", + DISCORD_CLIENT_ID: "***", + }, + Backend: { + REWORKD_PLATFORM_ENVIRONMENT: "${NODE_ENV}", + REWORKD_PLATFORM_FF_MOCK_MODE_ENABLED: false, + REWORKD_PLATFORM_MAX_LOOPS: "${NEXT_PUBLIC_MAX_LOOPS}", + REWORKD_PLATFORM_OPENAI_API_KEY: + envValues.OpenAIApiKey || '""', + REWORKD_PLATFORM_FRONTEND_URL: "http://localhost:3000", + REWORKD_PLATFORM_RELOAD: true, + REWORKD_PLATFORM_OPENAI_API_BASE: "https://api.openai.com/v1", + REWORKD_PLATFORM_SERP_API_KEY: envValues.serpApiKey || '""', + REWORKD_PLATFORM_REPLICATE_API_KEY: envValues.replicateApiKey || '""', + }, + "Database (Backend)": { + REWORKD_PLATFORM_DATABASE_USER: "reworkd_platform", + REWORKD_PLATFORM_DATABASE_PASSWORD: "reworkd_platform", + REWORKD_PLATFORM_DATABASE_HOST: "db", + REWORKD_PLATFORM_DATABASE_PORT: dbPort, + REWORKD_PLATFORM_DATABASE_NAME: "reworkd_platform", + REWORKD_PLATFORM_DATABASE_URL: + "mysql://${REWORKD_PLATFORM_DATABASE_USER}:${REWORKD_PLATFORM_DATABASE_PASSWORD}@${REWORKD_PLATFORM_DATABASE_HOST}:${REWORKD_PLATFORM_DATABASE_PORT}/${REWORKD_PLATFORM_DATABASE_NAME}", + }, + "Database (Frontend)": { + DATABASE_USER: "reworkd_platform", + DATABASE_PASSWORD: "reworkd_platform", + DATABASE_HOST: "db", + DATABASE_PORT: dbPort, + DATABASE_NAME: "reworkd_platform", + DATABASE_URL: + "mysql://${DATABASE_USER}:${DATABASE_PASSWORD}@${DATABASE_HOST}:${DATABASE_PORT}/${DATABASE_NAME}", + }, + }; }; const generateEnvFileContent = (config) => { - let configFile = ""; - - Object.entries(config).forEach(([section, variables]) => { - configFile += `# ${section}:\n`; - Object.entries(variables).forEach(([key, value]) => { - configFile += `${key}=${value}\n`; - }); - configFile += "\n"; + let configFile = ""; + + Object.entries(config).forEach(([section, variables]) => { + configFile += `# ${section}:\n`; + Object.entries(variables).forEach(([key, value]) => { + configFile += `${key}=${value}\n`; }); + configFile += "\n"; + }); - return configFile.trim(); + return configFile.trim(); }; const generateAuthSecret = () => { - const length = 32; - const buffer = crypto.randomBytes(length); - return buffer.toString("base64"); + const length = 32; + const buffer = crypto.randomBytes(length); + return buffer.toString("base64"); }; const ENV_PATH = "../next/.env"; const BACKEND_ENV_PATH = "../platform/.env"; export const doesEnvFileExist = () => { - return fs.existsSync(ENV_PATH); + return fs.existsSync(ENV_PATH); }; // Read the existing env file, test if it is missing any keys or contains any extra keys export const testEnvFile = () => { - const data = fs.readFileSync(ENV_PATH, "utf8"); + const data = fs.readFileSync(ENV_PATH, "utf8"); - // Make a fake definition to compare the keys of - const envDefinition = getEnvDefinition({}, "", "", "", ""); + // Make a fake definition to compare the keys of + const envDefinition = getEnvDefinition({}, "", "", "", ""); - const lines = data - .split("\n") - .filter((line) => !line.startsWith("#") && line.trim() !== ""); - const envKeysFromFile = lines.map((line) => line.split("=")[0]); + const lines = data + .split("\n") + .filter((line) => !line.startsWith("#") && line.trim() !== ""); + const envKeysFromFile = lines.map((line) => line.split("=")[0]); - const envKeysFromDef = Object.entries(envDefinition).flatMap( - ([section, entries]) => Object.keys(entries) - ); + const envKeysFromDef = Object.entries(envDefinition).flatMap( + ([section, entries]) => Object.keys(entries) + ); - const missingFromFile = envKeysFromDef.filter( - (key) => !envKeysFromFile.includes(key) - ); + const missingFromFile = envKeysFromDef.filter( + (key) => !envKeysFromFile.includes(key) + ); + + if (missingFromFile.length > 0) { + let errorMessage = "\nYour ./next/.env is missing the following keys:\n"; + missingFromFile.forEach((key) => { + errorMessage += chalk.whiteBright(`- ❌ ${key}\n`); + }); + errorMessage += "\n"; - if(missingFromFile.length > 0) { - let errorMessage = "\nYour ./next/.env is missing the following keys:\n"; - missingFromFile.forEach((key) => { - errorMessage += chalk.whiteBright(`- ❌ ${key}\n`); - }); - errorMessage += "\n"; - - errorMessage += chalk.red( - "We recommend deleting your .env file(s) and restarting this script." - ); - throw new Error(errorMessage); - } + errorMessage += chalk.red( + "We recommend deleting your .env file(s) and restarting this script." + ); + throw new Error(errorMessage); + } }; export const saveEnvFile = (envFileContent) => { - fs.writeFileSync(ENV_PATH, envFileContent); - fs.writeFileSync(BACKEND_ENV_PATH, envFileContent); + fs.writeFileSync(ENV_PATH, envFileContent); + fs.writeFileSync(BACKEND_ENV_PATH, envFileContent); }; diff --git a/platform/reworkd_platform/db/utils.py b/platform/reworkd_platform/db/utils.py index d379e7f4a0..3fe433c64b 100644 --- a/platform/reworkd_platform/db/utils.py +++ b/platform/reworkd_platform/db/utils.py @@ -1,8 +1,9 @@ -import ssl +from ssl import CERT_REQUIRED from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from reworkd_platform.services.ssl import get_ssl_context from reworkd_platform.settings import settings @@ -18,8 +19,8 @@ def create_engine() -> AsyncEngine: echo=settings.db_echo, ) - ssl_context = ssl.create_default_context(cafile=settings.db_ca_path) - ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context = get_ssl_context(settings) + ssl_context.verify_mode = CERT_REQUIRED connect_args = {"ssl": ssl_context} return create_async_engine( diff --git a/platform/reworkd_platform/services/kafka/consumers/base.py b/platform/reworkd_platform/services/kafka/consumers/base.py index 669bd1e320..b4c7a54e3a 100644 --- a/platform/reworkd_platform/services/kafka/consumers/base.py +++ b/platform/reworkd_platform/services/kafka/consumers/base.py @@ -1,12 +1,12 @@ import asyncio import json -import ssl from abc import ABC, abstractmethod from typing import Any, Protocol from aiokafka import AIOKafkaConsumer, ConsumerRecord from loguru import logger +from reworkd_platform.services.ssl import get_ssl_context from reworkd_platform.settings import Settings @@ -38,7 +38,7 @@ def __init__( security_protocol="SASL_SSL", sasl_plain_username=settings.kafka_username, sasl_plain_password=settings.kafka_password, - ssl_context=ssl.create_default_context(cafile=settings.db_ca_path), + ssl_context=get_ssl_context(settings), enable_auto_commit=True, auto_offset_reset="earliest", value_deserializer=deserializer.deserialize, diff --git a/platform/reworkd_platform/services/kafka/producers/base.py b/platform/reworkd_platform/services/kafka/producers/base.py index f4e9f57f3f..bea1d8f955 100644 --- a/platform/reworkd_platform/services/kafka/producers/base.py +++ b/platform/reworkd_platform/services/kafka/producers/base.py @@ -1,10 +1,10 @@ -from ssl import create_default_context from typing import Literal from aiokafka import AIOKafkaProducer from loguru import logger from pydantic import BaseModel +from reworkd_platform.services.ssl import get_ssl_context from reworkd_platform.settings import Settings TOPICS = Literal["workflow_task"] @@ -19,7 +19,7 @@ def __init__(self, settings: Settings): security_protocol="SASL_SSL", sasl_plain_username=settings.kafka_username, sasl_plain_password=settings.kafka_password, - ssl_context=create_default_context(cafile=settings.db_ca_path), + ssl_context=get_ssl_context(settings), ) self._headers = ( diff --git a/platform/reworkd_platform/services/ssl.py b/platform/reworkd_platform/services/ssl.py new file mode 100644 index 0000000000..a9f523e051 --- /dev/null +++ b/platform/reworkd_platform/services/ssl.py @@ -0,0 +1,25 @@ +from ssl import create_default_context, SSLContext +from typing import List, Optional + +from reworkd_platform.settings import Settings + +MACOS_CERT_PATH = "/etc/ssl/cert.pem" +DOCKER_CERT_PATH = "/etc/ssl/certs/ca-certificates.crt" + + +def get_ssl_context( + settings: Settings, paths: Optional[List[str]] = None +) -> SSLContext: + if settings.db_ca_path: + return create_default_context(cafile=settings.db_ca_path) + + for path in paths or [MACOS_CERT_PATH, DOCKER_CERT_PATH]: + try: + return create_default_context(cafile=path) + except FileNotFoundError: + continue + + raise ValueError( + "No CA certificates found for your OS. To fix this, please run change " + "db_ca_path in your settings.py to point to a valid CA certificate file." + ) diff --git a/platform/reworkd_platform/settings.py b/platform/reworkd_platform/settings.py index ee86438c0f..23012843aa 100644 --- a/platform/reworkd_platform/settings.py +++ b/platform/reworkd_platform/settings.py @@ -72,7 +72,7 @@ class Settings(BaseSettings): db_pass: str = "reworkd_platform" db_base: str = "reworkd_platform" db_echo: bool = False - db_ca_path: str = "/etc/ssl/cert.pem" + db_ca_path: Optional[str] = None # Variables for Weaviate db. vector_db_url: Optional[str] = None diff --git a/platform/reworkd_platform/tests/test_ssl.py b/platform/reworkd_platform/tests/test_ssl.py new file mode 100644 index 0000000000..aa5b47ad0f --- /dev/null +++ b/platform/reworkd_platform/tests/test_ssl.py @@ -0,0 +1,24 @@ +import pytest + +from reworkd_platform.services.ssl import ( + get_ssl_context, +) +from reworkd_platform.settings import Settings + + +def test_get_ssl_context(): + get_ssl_context(Settings()) + + +def test_get_ssl_context_raise(): + settings = Settings() + + with pytest.raises(ValueError): + get_ssl_context(settings, paths=["/test/cert.pem"]) + + +def test_get_ssl_context_specified_raise(): + settings = Settings(db_ca_path="/test/cert.pem") + + with pytest.raises(FileNotFoundError): + get_ssl_context(settings)