Use explicit track keys for processing

This commit is contained in:
2025-10-17 14:42:07 +02:00
parent c23518d2e3
commit fc79ff3114
3 changed files with 55 additions and 123 deletions

View File

@@ -192,38 +192,7 @@ class PipelineMainMultitrack(PipelineMainBase):
async with self.lock_transaction():
return await transcripts_controller.set_status(transcript_id, status)
async def _list_immediate_keys(
    self, s3, bucket_name: str, prefix: str
) -> list[str]:
    """Return the sorted object keys located directly at ``prefix``.

    Two S3 prefixes are probed:

    * ``prefix`` with trailing slashes stripped -- only an object whose key
      is exactly that string is kept;
    * the same string plus ``/`` -- only direct children are kept; the
      marker key equal to the prefix itself and anything nested below a
      further ``/`` are dropped.

    Both listings pass ``Delimiter="/"`` so S3 rolls nested keys up into
    ``CommonPrefixes`` instead of returning every object of the subtree in
    ``Contents``. The local filters are kept as a safety net, so the result
    does not depend on the service honouring the delimiter.

    Args:
        s3: Client object exposing ``get_paginator("list_objects_v2")``.
        bucket_name: Bucket to list.
        prefix: Key prefix, with or without trailing ``/``.

    Returns:
        De-duplicated keys in ascending order; empty when nothing matches.
    """
    # NOTE(review): the paginator calls below are synchronous and run inside
    # a coroutine, so they hold the event loop while S3 answers -- confirm
    # this is acceptable for the worker that runs this pipeline.
    paginator = s3.get_paginator("list_objects_v2")
    raw_prefix = prefix.rstrip("/")
    prefixes = [raw_prefix, raw_prefix + "/"]
    keys: set[str] = set()
    for pref in prefixes:
        is_directory = pref.endswith("/")
        pages = paginator.paginate(
            Bucket=bucket_name, Prefix=pref, Delimiter="/"
        )
        for page in pages:
            # Pages without matches carry no "Contents" entry at all.
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if not key.startswith(pref):
                    continue
                if is_directory:
                    rel = key[len(pref):]
                    # Empty -> the marker key itself; "/" -> nested key
                    # (this also covers names that end with "/").
                    if not rel or "/" in rel:
                        continue
                elif key != pref:
                    # Bare prefix: only an exact key match counts.
                    continue
                keys.add(key)
    result = sorted(keys)
    self.logger.info(
        "S3 list immediate files",
        prefixes=prefixes,
        total_keys=len(result),
        sample=result[:5],
    )
    return result
async def process(self, bucket_name: str, prefix: str):
async def process(self, bucket_name: str, track_keys: list[str]):
transcript = await self.get_transcript()
s3 = boto3.client(
@@ -233,15 +202,11 @@ class PipelineMainMultitrack(PipelineMainBase):
aws_secret_access_key=settings.RECORDING_STORAGE_AWS_SECRET_ACCESS_KEY,
)
keys = await self._list_immediate_keys(s3, bucket_name, prefix)
if not keys:
raise Exception("No audio tracks found under prefix")
storage = get_transcripts_storage()
# Pre-download bytes for all tracks for mixing and transcription
track_datas: list[bytes] = []
for key in keys:
for key in track_keys:
try:
obj = s3.get_object(Bucket=bucket_name, Key=key)
track_datas.append(obj["Body"].read())
@@ -262,7 +227,7 @@ class PipelineMainMultitrack(PipelineMainBase):
self.logger.error("Mixdown failed", error=str(e))
speaker_transcripts: list[TranscriptType] = []
for idx, key in enumerate(keys):
for idx, key in enumerate(track_keys):
ext = ".mp4"
try:
@@ -433,12 +398,12 @@ class PipelineMainMultitrack(PipelineMainBase):
@shared_task
@asynctask
async def task_pipeline_multitrack_process(
*, transcript_id: str, bucket_name: str, prefix: str
*, transcript_id: str, bucket_name: str, track_keys: list[str]
):
pipeline = PipelineMainMultitrack(transcript_id=transcript_id)
try:
await pipeline.set_status(transcript_id, "processing")
await pipeline.process(bucket_name, prefix)
await pipeline.process(bucket_name, track_keys)
except Exception:
await pipeline.set_status(transcript_id, "error")
raise