Skip to content

Commit

Permalink
Merge pull request #15 from PotLock/feat/security
Browse files Browse the repository at this point in the history
Adds security protections
  • Loading branch information
elliotBraem authored Jan 16, 2025
2 parents 964da3f + dbf054c commit 009c84d
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 110 deletions.
1 change: 1 addition & 0 deletions backend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
"dotenv": "^16.0.3",
"drizzle-orm": "^0.38.3",
"elysia": "^1.2.10",
"elysia-helmet": "^2.0.0",
"express": "^4.18.2",
"ora": "^8.1.1",
"winston": "^3.17.0",
Expand Down
257 changes: 147 additions & 110 deletions backend/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import { cors } from "@elysiajs/cors";
import { staticPlugin } from "@elysiajs/static";
import { swagger } from "@elysiajs/swagger";
import { ServerWebSocket } from "bun";
import dotenv from "dotenv";
import { Elysia } from "elysia";
import { helmet } from "elysia-helmet";
import path from "path";
import { DistributionService } from "services/distribution/distribution.service";
import configService, { validateEnv } from "./config/config";
import RssPlugin from "./external/rss";
import { db } from "./services/db";
import { SubmissionService } from "./services/submissions/submission.service";
import { TwitterService } from "./services/twitter/client";
import { webSocketService } from "./services/websocket/websocket.service";
import {
cleanup,
failSpinner,
Expand All @@ -24,21 +25,15 @@ const FRONTEND_DIST_PATH =
process.env.FRONTEND_DIST_PATH ||
path.join(process.cwd(), "../frontend/dist");

// Store active WebSocket connections
const activeConnections = new Set();
// Configuration
const ALLOWED_ORIGINS = [
"http://localhost:3000",
"https://curatedotfun.fly.dev",
];

// Broadcast to all connected clients
export function broadcastUpdate(data: unknown) {
const message = JSON.stringify(data);
activeConnections.forEach((ws) => {
try {
(ws as ServerWebSocket).send(message);
} catch (error) {
logger.error("Error broadcasting to WebSocket client:", error);
activeConnections.delete(ws);
}
});
}
// Export broadcast function for other modules
export const broadcastUpdate =
webSocketService.broadcast.bind(webSocketService);

export async function main() {
try {
Expand Down Expand Up @@ -81,21 +76,39 @@ export async function main() {
startSpinner("server", "Starting server...");

const app = new Elysia()
.use(cors())
.use(
helmet({
contentSecurityPolicy: {
directives: {
defaultSrc: ["'self'"],
connectSrc: ["'self'", "ws:", "wss:"], // Allow WebSocket connections
scriptSrc: ["'self'", "'unsafe-inline'"], // Required for some frontend frameworks
styleSrc: ["'self'", "'unsafe-inline'"], // Required for styled-components
imgSrc: ["'self'", "data:", "https:"], // Allow images from HTTPS sources
fontSrc: ["'self'", "data:", "https:"], // Allow fonts
},
},
crossOriginEmbedderPolicy: false, // Required for some static assets
crossOriginResourcePolicy: { policy: "cross-origin" }, // Allow resources to be shared
xFrameOptions: { action: "sameorigin" },
}),
)
.use(
cors({
origin: ALLOWED_ORIGINS,
methods: ["GET", "POST"],
}),
)
.use(swagger())
// WebSocket handling
.ws("/ws", {
open: (ws) => {
activeConnections.add(ws);
logger.debug(
`WebSocket client connected. Total connections: ${activeConnections.size}`,
);
open: (ws: any) => {
if (!webSocketService.addConnection(ws.remoteAddress, ws)) {
ws.close();
}
},
close: (ws) => {
activeConnections.delete(ws);
logger.debug(
`WebSocket client disconnected. Total connections: ${activeConnections.size}`,
);
close: (ws: any) => {
webSocketService.removeConnection(ws.remoteAddress, ws);
},
message: () => {
// we don't care about two-way connection yet
Expand All @@ -106,15 +119,21 @@ export async function main() {
const lastTweetId = twitterService.getLastCheckedTweetId();
return { lastTweetId };
})
.post("/api/last-tweet-id", async ({ body }) => {
const data = body as Record<string, unknown>;
if (!data?.tweetId || typeof data.tweetId !== "string") {
throw new Error("Invalid tweetId");
}
await twitterService.setLastCheckedTweetId(data.tweetId);
return { success: true };
})
.get("/api/submissions", ({ query }) => {
.post(
"/api/last-tweet-id",
async ({ body }: { body: { tweetId: string } }) => {
if (
!body?.tweetId ||
typeof body.tweetId !== "string" ||
!body.tweetId.match(/^\d+$/)
) {
throw new Error("Invalid tweetId format");
}
await twitterService.setLastCheckedTweetId(body.tweetId);
return { success: true };
},
)
.get("/api/submissions", ({ query }: { query: { status?: string } }) => {
const status = query?.status as
| "pending"
| "approved"
Expand All @@ -124,96 +143,114 @@ export async function main() {
? db.getSubmissionsByStatus(status)
: db.getAllSubmissions();
})
.get("/api/feed/:hashtagId", ({ params: { hashtagId } }) => {
const config = configService.getConfig();
const feed = config.feeds.find((f) => f.id === hashtagId);
if (!feed) {
throw new Error(`Feed not found: ${hashtagId}`);
}
.get(
"/api/feed/:hashtagId",
({ params: { hashtagId } }: { params: { hashtagId: string } }) => {
const config = configService.getConfig();
const feed = config.feeds.find((f) => f.id === hashtagId);
if (!feed) {
throw new Error(`Feed not found: ${hashtagId}`);
}

return db.getSubmissionsByFeed(hashtagId);
})
.get("/api/submissions/:hashtagId", ({ params: { hashtagId } }) => {
const config = configService.getConfig();
const feed = config.feeds.find((f) => f.id === hashtagId);
if (!feed) {
throw new Error(`Feed not found: ${hashtagId}`);
}
// this should be pending submissions
return db.getSubmissionsByFeed(hashtagId);
})
return db.getSubmissionsByFeed(hashtagId);
},
)
.get(
"/api/submissions/:hashtagId",
({ params: { hashtagId } }: { params: { hashtagId: string } }) => {
const config = configService.getConfig();
const feed = config.feeds.find((f) => f.id === hashtagId);
if (!feed) {
throw new Error(`Feed not found: ${hashtagId}`);
}
// this should be pending submissions
return db.getSubmissionsByFeed(hashtagId);
},
)
.get("/api/approved", () => {
return db.getSubmissionsByStatus("approved");
})
.get("/api/content/:contentId", ({ params: { contentId } }) => {
const content = db.getContent(contentId);
if (!content) {
throw new Error(`Content not found: ${contentId}`);
}
return content;
})
.get(
"/api/content/:contentId",
({ params: { contentId } }: { params: { contentId: string } }) => {
const content = db.getContent(contentId);
if (!content) {
throw new Error(`Content not found: ${contentId}`);
}
return content;
},
)
.get("/api/feeds", () => {
const config = configService.getConfig();
return config.feeds;
})
.get("/api/config/:feedId", ({ params: { feedId } }) => {
const config = configService.getConfig();
const feed = config.feeds.find((f) => f.id === feedId);
if (!feed) {
throw new Error(`Feed not found: ${feedId}`);
}
return feed;
})
.get("/plugin/rss/:feedId", ({ params: { feedId } }) => {
const rssPlugin = distributionService.getPlugin("rss");
if (!rssPlugin || !(rssPlugin instanceof RssPlugin)) {
throw new Error("RSS plugin not found or invalid");
}
.get(
"/api/config/:feedId",
({ params: { feedId } }: { params: { feedId: string } }) => {
const config = configService.getConfig();
const feed = config.feeds.find((f) => f.id === feedId);
if (!feed) {
throw new Error(`Feed not found: ${feedId}`);
}
return feed;
},
)
.get(
"/plugin/rss/:feedId",
({ params: { feedId } }: { params: { feedId: string } }) => {
const rssPlugin = distributionService.getPlugin("rss");
if (!rssPlugin || !(rssPlugin instanceof RssPlugin)) {
throw new Error("RSS plugin not found or invalid");
}

const service = rssPlugin.getServices().get(feedId);
if (!service) {
throw new Error("RSS service not initialized for this feed");
}
const service = rssPlugin.getServices().get(feedId);
if (!service) {
throw new Error("RSS service not initialized for this feed");
}

return service.getItems();
})
.post("/api/feeds/:feedId/process", async ({ params: { feedId } }) => {
// Get feed config
const config = configService.getConfig();
const feed = config.feeds.find((f) => f.id === feedId);
if (!feed) {
throw new Error(`Feed not found: ${feedId}`);
}
return service.getItems();
},
)
.post(
"/api/feeds/:feedId/process",
async ({ params: { feedId } }: { params: { feedId: string } }) => {
// Get feed config
const config = configService.getConfig();
const feed = config.feeds.find((f) => f.id === feedId);
if (!feed) {
throw new Error(`Feed not found: ${feedId}`);
}

// Get approved submissions for this feed
const submissions = db
.getSubmissionsByFeed(feedId)
.filter((sub) => sub.status === "approved");
// Get approved submissions for this feed
const submissions = db
.getSubmissionsByFeed(feedId)
.filter((sub) => sub.status === "approved");

if (submissions.length === 0) {
return { processed: 0 };
}
if (submissions.length === 0) {
return { processed: 0 };
}

// Process each submission through stream output
let processed = 0;
for (const submission of submissions) {
try {
await distributionService.processStreamOutput(
feedId,
submission.tweetId,
submission.content,
);
processed++;
} catch (error) {
logger.error(
`Error processing submission ${submission.tweetId}:`,
error,
);
// Process each submission through stream output
let processed = 0;
for (const submission of submissions) {
try {
await distributionService.processStreamOutput(
feedId,
submission.tweetId,
submission.content,
);
processed++;
} catch (error) {
logger.error(
`Error processing submission ${submission.tweetId}:`,
error,
);
}
}
}

return { processed };
})
return { processed };
},
)
// This was the most annoying thing to set up and debug. Serves our frontend and handles routing. alwaysStatic is essential.
.use(
staticPlugin({
Expand Down
1 change: 1 addition & 0 deletions backend/src/services/twitter/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ export class TwitterService {
}
} catch (error) {
logger.error("Failed to login to Twitter, retrying...", error);
break;
}

// Wait before retrying
Expand Down
59 changes: 59 additions & 0 deletions backend/src/services/websocket/websocket.service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import { ServerWebSocket } from "bun";
import { logger } from "../../utils/logger";

export class WebSocketService {
// Store active connections
private activeConnections = new Map<string, Set<ServerWebSocket>>();

/**
* Add a new WebSocket connection
*/
public addConnection(ip: string, ws: ServerWebSocket): boolean {
// Initialize connection set for IP if needed
if (!this.activeConnections.has(ip)) {
this.activeConnections.set(ip, new Set());
}

const connections = this.activeConnections.get(ip)!;
connections.add(ws);
logger.debug(`WebSocket client connected from ${ip}`);
return true;
}

/**
* Remove a WebSocket connection
*/
public removeConnection(ip: string, ws: ServerWebSocket): void {
const connections = this.activeConnections.get(ip);
if (connections) {
connections.delete(ws);
if (connections.size === 0) {
this.activeConnections.delete(ip);
}
logger.debug(`WebSocket client disconnected from ${ip}`);
}
}

/**
* Broadcast a message to all connected clients
*/
public broadcast(data: unknown): void {
const message = JSON.stringify(data);
for (const [ip, connections] of this.activeConnections.entries()) {
connections.forEach((ws) => {
try {
ws.send(message);
} catch (error) {
logger.error(
`Error broadcasting to WebSocket client (${ip}):`,
error,
);
this.removeConnection(ip, ws);
}
});
}
}
}

// Export singleton instance
export const webSocketService = new WebSocketService();
Binary file modified bun.lockb
Binary file not shown.
Loading

0 comments on commit 009c84d

Please sign in to comment.