From 76d2926d806336e846eab00da784c5f9b216791a Mon Sep 17 00:00:00 2001 From: awtkns <32209255+awtkns@users.noreply.github.com> Date: Sat, 22 Jul 2023 13:10:47 -0700 Subject: [PATCH 1/4] =?UTF-8?q?=F0=9F=90=9B=20Improve=20SSL=20devx?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- platform/reworkd_platform/db/utils.py | 7 ++++--- .../services/kafka/consumers/base.py | 4 ++-- .../services/kafka/producers/base.py | 4 ++-- platform/reworkd_platform/services/ssl.py | 19 +++++++++++++++++++ platform/reworkd_platform/settings.py | 2 +- 5 files changed, 28 insertions(+), 8 deletions(-) create mode 100644 platform/reworkd_platform/services/ssl.py diff --git a/platform/reworkd_platform/db/utils.py b/platform/reworkd_platform/db/utils.py index d379e7f4a0..3fe433c64b 100644 --- a/platform/reworkd_platform/db/utils.py +++ b/platform/reworkd_platform/db/utils.py @@ -1,8 +1,9 @@ -import ssl +from ssl import CERT_REQUIRED from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from reworkd_platform.services.ssl import get_ssl_context from reworkd_platform.settings import settings @@ -18,8 +19,8 @@ def create_engine() -> AsyncEngine: echo=settings.db_echo, ) - ssl_context = ssl.create_default_context(cafile=settings.db_ca_path) - ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context = get_ssl_context(settings) + ssl_context.verify_mode = CERT_REQUIRED connect_args = {"ssl": ssl_context} return create_async_engine( diff --git a/platform/reworkd_platform/services/kafka/consumers/base.py b/platform/reworkd_platform/services/kafka/consumers/base.py index 669bd1e320..b4c7a54e3a 100644 --- a/platform/reworkd_platform/services/kafka/consumers/base.py +++ b/platform/reworkd_platform/services/kafka/consumers/base.py @@ -1,12 +1,12 @@ import asyncio import json -import ssl from abc import ABC, abstractmethod from typing import Any, Protocol from aiokafka import AIOKafkaConsumer, ConsumerRecord from loguru import logger +from reworkd_platform.services.ssl import get_ssl_context from reworkd_platform.settings import Settings @@ -38,7 +38,7 @@ def __init__( security_protocol="SASL_SSL", sasl_plain_username=settings.kafka_username, sasl_plain_password=settings.kafka_password, - ssl_context=ssl.create_default_context(cafile=settings.db_ca_path), + ssl_context=get_ssl_context(settings), enable_auto_commit=True, auto_offset_reset="earliest", value_deserializer=deserializer.deserialize, diff --git a/platform/reworkd_platform/services/kafka/producers/base.py b/platform/reworkd_platform/services/kafka/producers/base.py index f4e9f57f3f..bea1d8f955 100644 --- a/platform/reworkd_platform/services/kafka/producers/base.py +++ b/platform/reworkd_platform/services/kafka/producers/base.py @@ -1,10 +1,10 @@ -from ssl import create_default_context from typing import Literal from aiokafka import AIOKafkaProducer from loguru import logger from pydantic import BaseModel +from reworkd_platform.services.ssl import get_ssl_context from reworkd_platform.settings import Settings TOPICS = Literal["workflow_task"] @@ -19,7 +19,7 @@ def __init__(self, settings: Settings): security_protocol="SASL_SSL", sasl_plain_username=settings.kafka_username, sasl_plain_password=settings.kafka_password, - ssl_context=create_default_context(cafile=settings.db_ca_path), + ssl_context=get_ssl_context(settings), ) self._headers = ( diff --git a/platform/reworkd_platform/services/ssl.py b/platform/reworkd_platform/services/ssl.py new file mode 100644 index 0000000000..50f8a6455a --- /dev/null +++ b/platform/reworkd_platform/services/ssl.py @@ -0,0 +1,19 @@ +from ssl import create_default_context + +from reworkd_platform.settings import Settings + +MACOS_CERT_PATH = "/etc/ssl/cert.pem" +DOCKER_CERT_PATH = "/etc/ssl/certs/ca-certificates.crt" + + +def get_ssl_context(settings: Settings): + if settings.db_ca_path: + return create_default_context(cafile=settings.db_ca_path) + + for path in (MACOS_CERT_PATH, DOCKER_CERT_PATH): + try: + return create_default_context(cafile=path) + except FileNotFoundError: + continue + + raise FileNotFoundError("No CA certificates found for your OS.") diff --git a/platform/reworkd_platform/settings.py b/platform/reworkd_platform/settings.py index ee86438c0f..23012843aa 100644 --- a/platform/reworkd_platform/settings.py +++ b/platform/reworkd_platform/settings.py @@ -72,7 +72,7 @@ class Settings(BaseSettings): db_pass: str = "reworkd_platform" db_base: str = "reworkd_platform" db_echo: bool = False - db_ca_path: str = "/etc/ssl/cert.pem" + db_ca_path: Optional[str] = None # Variables for Weaviate db. vector_db_url: Optional[str] = None From d2a97f311e570eed231a45bee3dfaea85672d72f Mon Sep 17 00:00:00 2001 From: awtkns <32209255+awtkns@users.noreply.github.com> Date: Sat, 22 Jul 2023 13:19:57 -0700 Subject: [PATCH 2/4] =?UTF-8?q?=F0=9F=90=9B=20Improve=20SSL=20devx?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- platform/reworkd_platform/services/ssl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/platform/reworkd_platform/services/ssl.py b/platform/reworkd_platform/services/ssl.py index 50f8a6455a..b627fcd0ef 100644 --- a/platform/reworkd_platform/services/ssl.py +++ b/platform/reworkd_platform/services/ssl.py @@ -1,4 +1,4 @@ -from ssl import create_default_context +from ssl import create_default_context, SSLContext from reworkd_platform.settings import Settings @@ -6,7 +6,7 @@ DOCKER_CERT_PATH = "/etc/ssl/certs/ca-certificates.crt" -def get_ssl_context(settings: Settings): +def get_ssl_context(settings: Settings) -> SSLContext: if settings.db_ca_path: return create_default_context(cafile=settings.db_ca_path) From 4323134fab9a22ee4ba7bbffebeee6411f258cc3 Mon Sep 17 00:00:00 2001 From: awtkns <32209255+awtkns@users.noreply.github.com> Date: Sat, 22 Jul 2023 13:43:45 -0700 Subject: [PATCH 3/4] =?UTF-8?q?=F0=9F=90=9B=20Improve=20SSL=20devx?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cli/src/envGenerator.js | 215 ++++++++++---------- platform/reworkd_platform/services/ssl.py | 10 +- platform/reworkd_platform/tests/test_ssl.py | 24 +++ 3 files changed, 138 insertions(+), 111 deletions(-) create mode 100644 platform/reworkd_platform/tests/test_ssl.py diff --git a/cli/src/envGenerator.js b/cli/src/envGenerator.js index 378d4fea2f..cc9c92a18d 100644 --- a/cli/src/envGenerator.js +++ b/cli/src/envGenerator.js @@ -3,141 +3,140 @@ import fs from "fs"; import chalk from "chalk"; export const generateEnv = (envValues) => { - let isDockerCompose = envValues.runOption === "docker-compose"; - let dbPort = isDockerCompose ? 3307 : 3306; - let platformUrl = isDockerCompose - ? "http://host.docker.internal:8000" - : "http://localhost:8000"; - - const envDefinition = getEnvDefinition( - envValues, - isDockerCompose, - dbPort, - platformUrl, - ); - - const envFileContent = generateEnvFileContent(envDefinition); - saveEnvFile(envFileContent); + let isDockerCompose = envValues.runOption === "docker-compose"; + let dbPort = isDockerCompose ? 3307 : 3306; + let platformUrl = isDockerCompose + ? "http://host.docker.internal:8000" + : "http://localhost:8000"; + + const envDefinition = getEnvDefinition( + envValues, + isDockerCompose, + dbPort, + platformUrl + ); + + const envFileContent = generateEnvFileContent(envDefinition); + saveEnvFile(envFileContent); }; const getEnvDefinition = (envValues, isDockerCompose, dbPort, platformUrl) => { - return { - "Deployment Environment": { - NODE_ENV: "development", - NEXT_PUBLIC_VERCEL_ENV: "${NODE_ENV}", - }, - NextJS: { - NEXT_PUBLIC_BACKEND_URL: "http://localhost:8000", - NEXT_PUBLIC_MAX_LOOPS: 100, - }, - "Next Auth config": { - NEXTAUTH_SECRET: generateAuthSecret(), - NEXTAUTH_URL: "http://localhost:3000", - }, - "Auth providers (Use if you want to get out of development mode sign-in)": { - GOOGLE_CLIENT_ID: "***", - GOOGLE_CLIENT_SECRET: "***", - GITHUB_CLIENT_ID: "***", - GITHUB_CLIENT_SECRET: "***", - DISCORD_CLIENT_SECRET: "***", - DISCORD_CLIENT_ID: "***", - }, - Backend: { - REWORKD_PLATFORM_ENVIRONMENT: "${NODE_ENV}", - REWORKD_PLATFORM_FF_MOCK_MODE_ENABLED: false, - REWORKD_PLATFORM_MAX_LOOPS: "${NEXT_PUBLIC_MAX_LOOPS}", - REWORKD_PLATFORM_OPENAI_API_KEY: - envValues.OpenAIApiKey || '""', - REWORKD_PLATFORM_FRONTEND_URL: "http://localhost:3000", - REWORKD_PLATFORM_RELOAD: true, - REWORKD_PLATFORM_OPENAI_API_BASE: "https://api.openai.com/v1", - REWORKD_PLATFORM_SERP_API_KEY: envValues.serpApiKey || '""', - REWORKD_PLATFORM_REPLICATE_API_KEY: envValues.replicateApiKey || '""', - }, - "Database (Backend)": { - REWORKD_PLATFORM_DATABASE_USER: "reworkd_platform", - REWORKD_PLATFORM_DATABASE_PASSWORD: "reworkd_platform", - REWORKD_PLATFORM_DATABASE_HOST: "db", - REWORKD_PLATFORM_DATABASE_PORT: dbPort, - REWORKD_PLATFORM_DATABASE_NAME: "reworkd_platform", - REWORKD_PLATFORM_DATABASE_URL: - "mysql://${REWORKD_PLATFORM_DATABASE_USER}:${REWORKD_PLATFORM_DATABASE_PASSWORD}@${REWORKD_PLATFORM_DATABASE_HOST}:${REWORKD_PLATFORM_DATABASE_PORT}/${REWORKD_PLATFORM_DATABASE_NAME}", - REWORKD_PLATFORM_DB_CA_PATH: isDockerCompose ? "/etc/ssl/certs/ca-certificates.crt" : "", - }, - "Database (Frontend)": { - DATABASE_USER: "reworkd_platform", - DATABASE_PASSWORD: "reworkd_platform", - DATABASE_HOST: "db", - DATABASE_PORT: dbPort, - DATABASE_NAME: "reworkd_platform", - DATABASE_URL: - "mysql://${DATABASE_USER}:${DATABASE_PASSWORD}@${DATABASE_HOST}:${DATABASE_PORT}/${DATABASE_NAME}", - }, - }; + return { + "Deployment Environment": { + NODE_ENV: "development", + NEXT_PUBLIC_VERCEL_ENV: "${NODE_ENV}", + }, + NextJS: { + NEXT_PUBLIC_BACKEND_URL: "http://localhost:8000", + NEXT_PUBLIC_MAX_LOOPS: 100, + }, + "Next Auth config": { + NEXTAUTH_SECRET: generateAuthSecret(), + NEXTAUTH_URL: "http://localhost:3000", + }, + "Auth providers (Use if you want to get out of development mode sign-in)": { + GOOGLE_CLIENT_ID: "***", + GOOGLE_CLIENT_SECRET: "***", + GITHUB_CLIENT_ID: "***", + GITHUB_CLIENT_SECRET: "***", + DISCORD_CLIENT_SECRET: "***", + DISCORD_CLIENT_ID: "***", + }, + Backend: { + REWORKD_PLATFORM_ENVIRONMENT: "${NODE_ENV}", + REWORKD_PLATFORM_FF_MOCK_MODE_ENABLED: false, + REWORKD_PLATFORM_MAX_LOOPS: "${NEXT_PUBLIC_MAX_LOOPS}", + REWORKD_PLATFORM_OPENAI_API_KEY: + envValues.OpenAIApiKey || '""', + REWORKD_PLATFORM_FRONTEND_URL: "http://localhost:3000", + REWORKD_PLATFORM_RELOAD: true, + REWORKD_PLATFORM_OPENAI_API_BASE: "https://api.openai.com/v1", + REWORKD_PLATFORM_SERP_API_KEY: envValues.serpApiKey || '""', + REWORKD_PLATFORM_REPLICATE_API_KEY: envValues.replicateApiKey || '""', + }, + "Database (Backend)": { + REWORKD_PLATFORM_DATABASE_USER: "reworkd_platform", + REWORKD_PLATFORM_DATABASE_PASSWORD: "reworkd_platform", + REWORKD_PLATFORM_DATABASE_HOST: "db", + REWORKD_PLATFORM_DATABASE_PORT: dbPort, + REWORKD_PLATFORM_DATABASE_NAME: "reworkd_platform", + REWORKD_PLATFORM_DATABASE_URL: + "mysql://${REWORKD_PLATFORM_DATABASE_USER}:${REWORKD_PLATFORM_DATABASE_PASSWORD}@${REWORKD_PLATFORM_DATABASE_HOST}:${REWORKD_PLATFORM_DATABASE_PORT}/${REWORKD_PLATFORM_DATABASE_NAME}", + }, + "Database (Frontend)": { + DATABASE_USER: "reworkd_platform", + DATABASE_PASSWORD: "reworkd_platform", + DATABASE_HOST: "db", + DATABASE_PORT: dbPort, + DATABASE_NAME: "reworkd_platform", + DATABASE_URL: + "mysql://${DATABASE_USER}:${DATABASE_PASSWORD}@${DATABASE_HOST}:${DATABASE_PORT}/${DATABASE_NAME}", + }, + }; }; const generateEnvFileContent = (config) => { - let configFile = ""; - - Object.entries(config).forEach(([section, variables]) => { - configFile += `# ${section}:\n`; - Object.entries(variables).forEach(([key, value]) => { - configFile += `${key}=${value}\n`; - }); - configFile += "\n"; + let configFile = ""; + + Object.entries(config).forEach(([section, variables]) => { + configFile += `# ${section}:\n`; + Object.entries(variables).forEach(([key, value]) => { + configFile += `${key}=${value}\n`; }); + configFile += "\n"; + }); - return configFile.trim(); + return configFile.trim(); }; const generateAuthSecret = () => { - const length = 32; - const buffer = crypto.randomBytes(length); - return buffer.toString("base64"); + const length = 32; + const buffer = crypto.randomBytes(length); + return buffer.toString("base64"); }; const ENV_PATH = "../next/.env"; const BACKEND_ENV_PATH = "../platform/.env"; export const doesEnvFileExist = () => { - return fs.existsSync(ENV_PATH); + return fs.existsSync(ENV_PATH); }; // Read the existing env file, test if it is missing any keys or contains any extra keys export const testEnvFile = () => { - const data = fs.readFileSync(ENV_PATH, "utf8"); + const data = fs.readFileSync(ENV_PATH, "utf8"); - // Make a fake definition to compare the keys of - const envDefinition = getEnvDefinition({}, "", "", "", ""); + // Make a fake definition to compare the keys of + const envDefinition = getEnvDefinition({}, "", "", "", ""); - const lines = data - .split("\n") - .filter((line) => !line.startsWith("#") && line.trim() !== ""); - const envKeysFromFile = lines.map((line) => line.split("=")[0]); + const lines = data + .split("\n") + .filter((line) => !line.startsWith("#") && line.trim() !== ""); + const envKeysFromFile = lines.map((line) => line.split("=")[0]); - const envKeysFromDef = Object.entries(envDefinition).flatMap( - ([section, entries]) => Object.keys(entries) - ); + const envKeysFromDef = Object.entries(envDefinition).flatMap( + ([section, entries]) => Object.keys(entries) + ); - const missingFromFile = envKeysFromDef.filter( - (key) => !envKeysFromFile.includes(key) - ); + const missingFromFile = envKeysFromDef.filter( + (key) => !envKeysFromFile.includes(key) + ); + + if (missingFromFile.length > 0) { + let errorMessage = "\nYour ./next/.env is missing the following keys:\n"; + missingFromFile.forEach((key) => { + errorMessage += chalk.whiteBright(`- ❌ ${key}\n`); + }); + errorMessage += "\n"; - if(missingFromFile.length > 0) { - let errorMessage = "\nYour ./next/.env is missing the following keys:\n"; - missingFromFile.forEach((key) => { - errorMessage += chalk.whiteBright(`- ❌ ${key}\n`); - }); - errorMessage += "\n"; - - errorMessage += chalk.red( - "We recommend deleting your .env file(s) and restarting this script." - ); - throw new Error(errorMessage); - } + errorMessage += chalk.red( + "We recommend deleting your .env file(s) and restarting this script." + ); + throw new Error(errorMessage); + } }; export const saveEnvFile = (envFileContent) => { - fs.writeFileSync(ENV_PATH, envFileContent); - fs.writeFileSync(BACKEND_ENV_PATH, envFileContent); + fs.writeFileSync(ENV_PATH, envFileContent); + fs.writeFileSync(BACKEND_ENV_PATH, envFileContent); }; diff --git a/platform/reworkd_platform/services/ssl.py b/platform/reworkd_platform/services/ssl.py index b627fcd0ef..c0da2ad14d 100644 --- a/platform/reworkd_platform/services/ssl.py +++ b/platform/reworkd_platform/services/ssl.py @@ -1,4 +1,5 @@ from ssl import create_default_context, SSLContext +from typing import List from reworkd_platform.settings import Settings @@ -6,14 +7,17 @@ DOCKER_CERT_PATH = "/etc/ssl/certs/ca-certificates.crt" -def get_ssl_context(settings: Settings) -> SSLContext: +def get_ssl_context(settings: Settings, paths: List[str] = None) -> SSLContext: if settings.db_ca_path: return create_default_context(cafile=settings.db_ca_path) - for path in (MACOS_CERT_PATH, DOCKER_CERT_PATH): + for path in paths or [MACOS_CERT_PATH, DOCKER_CERT_PATH]: try: return create_default_context(cafile=path) except FileNotFoundError: continue - raise FileNotFoundError("No CA certificates found for your OS.") + raise ValueError( + "No CA certificates found for your OS. To fix this, please run change " + "db_ca_path in your settings.py to point to a valid CA certificate file." + ) diff --git a/platform/reworkd_platform/tests/test_ssl.py b/platform/reworkd_platform/tests/test_ssl.py new file mode 100644 index 0000000000..aa5b47ad0f --- /dev/null +++ b/platform/reworkd_platform/tests/test_ssl.py @@ -0,0 +1,24 @@ +import pytest + +from reworkd_platform.services.ssl import ( + get_ssl_context, +) +from reworkd_platform.settings import Settings + + +def test_get_ssl_context(): + get_ssl_context(Settings()) + + +def test_get_ssl_context_raise(): + settings = Settings() + + with pytest.raises(ValueError): + get_ssl_context(settings, paths=["/test/cert.pem"]) + + +def test_get_ssl_context_specified_raise(): + settings = Settings(db_ca_path="/test/cert.pem") + + with pytest.raises(FileNotFoundError): + get_ssl_context(settings) From 325337091e1e2a6b3a02f72ece681a8e8107a771 Mon Sep 17 00:00:00 2001 From: awtkns <32209255+awtkns@users.noreply.github.com> Date: Sat, 22 Jul 2023 13:47:45 -0700 Subject: [PATCH 4/4] =?UTF-8?q?=F0=9F=90=9B=20Improve=20SSL=20devx?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- platform/reworkd_platform/services/ssl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/platform/reworkd_platform/services/ssl.py b/platform/reworkd_platform/services/ssl.py index c0da2ad14d..a9f523e051 100644 --- a/platform/reworkd_platform/services/ssl.py +++ b/platform/reworkd_platform/services/ssl.py @@ -1,5 +1,5 @@ from ssl import create_default_context, SSLContext -from typing import List +from typing import List, Optional from reworkd_platform.settings import Settings @@ -7,7 +7,9 @@ DOCKER_CERT_PATH = "/etc/ssl/certs/ca-certificates.crt" -def get_ssl_context(settings: Settings, paths: List[str] = None) -> SSLContext: +def get_ssl_context( + settings: Settings, paths: Optional[List[str]] = None +) -> SSLContext: if settings.db_ca_path: return create_default_context(cafile=settings.db_ca_path)