diff --git a/packages/backend-common/src/config.ts b/packages/backend-common/src/config.ts index 3d80ed61..291cadc6 100644 --- a/packages/backend-common/src/config.ts +++ b/packages/backend-common/src/config.ts @@ -5,18 +5,6 @@ import { logger } from '@guardian/transcription-service-backend-common'; import { DestinationService } from '@guardian/transcription-service-common'; import { SecretsManager } from '@aws-sdk/client-secrets-manager'; -type WorkerType = 'cpu' | 'gpu'; - -const parseWorkerType = (workerType?: string): WorkerType | undefined => { - if (workerType === 'gpu') { - return 'gpu'; - } - if (workerType === 'cpu') { - return 'cpu'; - } - return undefined; -}; - export interface TranscriptionConfig { auth: { clientId: string; @@ -30,6 +18,7 @@ export interface TranscriptionConfig { deadLetterQueueUrl?: string; mediaDownloadQueueUrl: string; stage: string; + app: string; emailNotificationFromAddress: string; sourceMediaBucket: string; transcriptionOutputBucket: string; @@ -39,7 +28,6 @@ export interface TranscriptionConfig { mediaDownloadProxyIpAddress: string; mediaDownloadProxyPort: number; mediaExportFunctionName: string; - workerType?: WorkerType; }; aws: { region: string; @@ -87,9 +75,7 @@ export const getConfig = async (): Promise => { region, credentials: credentialProvider(stage !== 'DEV'), }); - - const workerTypeEnvVar = process.env['WORKER_TYPE']; - const workerType = parseWorkerType(workerTypeEnvVar); + const app = await getEnvVarOrMetadata('APP', 'tags/instance/App'); const paramPath = `/${stage}/investigations/transcription-service/`; @@ -195,6 +181,7 @@ export const getConfig = async (): Promise => { deadLetterQueueUrl, mediaDownloadQueueUrl, stage, + app, sourceMediaBucket, emailNotificationFromAddress, destinationQueueUrls: { @@ -207,7 +194,6 @@ export const getConfig = async (): Promise => { mediaDownloadProxyIpAddress, mediaDownloadProxyPort: 1337, mediaExportFunctionName, - workerType, }, aws: { region, diff --git a/packages/cdk/lib/transcription-service.ts b/packages/cdk/lib/transcription-service.ts index 7e79c5bf..3486df48 100644 --- a/packages/cdk/lib/transcription-service.ts +++ b/packages/cdk/lib/transcription-service.ts @@ -537,7 +537,7 @@ export class TranscriptionService extends GuStack { Tags.of(transcriptionGpuWorkerASG).add( 'App', - `transcription-service-worker`, + `transcription-service-gpu-worker`, { applyToLaunchedInstances: true, }, diff --git a/packages/worker/src/index.ts b/packages/worker/src/index.ts index 03f3306b..f2ff0ea2 100644 --- a/packages/worker/src/index.ts +++ b/packages/worker/src/index.ts @@ -68,6 +68,13 @@ const main = async () => { ); const autoScalingClient = getASGClient(config.aws.region); + const isGpu = config.app.app.startsWith('transcription-service-gpu-worker'); + const asgName = isGpu + ? `transcription-service-gpu-workers-${config.app.stage}` + : `transcription-service-workers-${config.app.stage}`; + const taskQueueUrl = isGpu + ? config.app.gpuTaskQueueUrl + : config.app.taskQueueUrl; if (config.app.stage !== 'DEV') { // start job to regularly check the instance interruption (Note: deliberately not using await here so the job @@ -88,10 +95,9 @@ const main = async () => { await pollTranscriptionQueue( pollCount, sqsClient, - config.app.workerType === 'gpu' - ? config.app.gpuTaskQueueUrl - : config.app.taskQueueUrl, + taskQueueUrl, autoScalingClient, + asgName, metrics, config, instanceId, @@ -130,6 +136,7 @@ const pollTranscriptionQueue = async ( sqsClient: SQSClient, taskQueueUrl: string, autoScalingClient: AutoScalingClient, + asgName: string, metrics: MetricsService, config: TranscriptionConfig, instanceId: string, @@ -137,10 +144,6 @@ const pollTranscriptionQueue = async ( const stage = config.app.stage; const numberOfThreads = config.app.stage === 'PROD' ? 16 : 2; const isDev = config.app.stage === 'DEV'; - const asgName = - config.app.workerType === 'gpu' - ? `transcription-service-gpu-workers-${config.app.stage}` - : `transcription-service-workers-${config.app.stage}`; logger.info( `worker polling for transcription task. Poll count = ${pollCount}`, diff --git a/scripts/setup.sh b/scripts/setup.sh index a4445159..da32bf73 100755 --- a/scripts/setup.sh +++ b/scripts/setup.sh @@ -46,7 +46,7 @@ TASK_QUEUE_URL=$(aws --endpoint-url=http://localhost:4566 sqs create-queue --que "RedrivePolicy": "{\"deadLetterTargetArn\":\"arn:aws:sqs:us-east-1:000000000000:transcription-service-task-dead-letter-queue-DEV.fifo\",\"maxReceiveCount\":\"3\"}" }' | jq .QueueUrl) # We don't install the localstack dns so need to replace the endpoint with localhost -TASK_QUEUE_URL_LOCALHOST=${QUEUE_URL/sqs.eu-west-1.localhost.localstack.cloud/localhost} +TASK_QUEUE_URL_LOCALHOST=${TASK_QUEUE_URL/sqs.eu-west-1.localhost.localstack.cloud/localhost} echo "Created task queue in localstack, url: ${TASK_QUEUE_URL_LOCALHOST}" @@ -57,7 +57,7 @@ GPU_TASK_QUEUE_URL=$(aws --endpoint-url=http://localhost:4566 sqs create-queue - "RedrivePolicy": "{\"deadLetterTargetArn\":\"arn:aws:sqs:us-east-1:000000000000:transcription-service-task-dead-letter-queue-DEV.fifo\",\"maxReceiveCount\":\"3\"}" }' | jq .QueueUrl) # We don't install the localstack dns so need to replace the endpoint with localhost -GPU_TASK_QUEUE_URL_LOCALHOST=${QUEUE_URL/sqs.eu-west-1.localhost.localstack.cloud/localhost} +GPU_TASK_QUEUE_URL_LOCALHOST=${GPU_TASK_QUEUE_URL/sqs.eu-west-1.localhost.localstack.cloud/localhost} echo "Created task queue in localstack, url: ${GPU_TASK_QUEUE_URL_LOCALHOST}"