review comments

This commit is contained in:
Gokul Mohanarangan
2023-08-21 13:50:59 +05:30
parent 78153c6cfb
commit a0ea32db8a
2 changed files with 7 additions and 10 deletions

View File

@@ -99,13 +99,13 @@ class Whisper:
) )
multilingual_transcript = {} multilingual_transcript = {}
transcript_en = "" transcript_source_lang = ""
words = [] words = []
if segments: if segments:
segments = list(segments) segments = list(segments)
for segment in segments: for segment in segments:
transcript_en += segment.text transcript_source_lang += segment.text
for word in segment.words: for word in segment.words:
words.append( words.append(
{ {
@@ -115,12 +115,12 @@ class Whisper:
} }
) )
multilingual_transcript["en"] = transcript_en multilingual_transcript[source_language] = transcript_source_lang
if target_language != "en": if target_language != source_language:
self.translation_tokenizer.src_lang = source_language self.translation_tokenizer.src_lang = source_language
forced_bos_token_id = self.translation_tokenizer.get_lang_id(target_language) forced_bos_token_id = self.translation_tokenizer.get_lang_id(target_language)
encoded_transcript = self.translation_tokenizer(transcript_en, return_tensors="pt").to(self.device) encoded_transcript = self.translation_tokenizer(transcript_source_lang, return_tensors="pt").to(self.device)
generated_tokens = self.translation_model.generate( generated_tokens = self.translation_model.generate(
**encoded_transcript, **encoded_transcript,
forced_bos_token_id=forced_bos_token_id forced_bos_token_id=forced_bos_token_id

View File

@@ -29,10 +29,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
self.transcript_url = settings.TRANSCRIPT_URL + "/transcribe" self.transcript_url = settings.TRANSCRIPT_URL + "/transcribe"
self.warmup_url = settings.TRANSCRIPT_URL + "/warmup" self.warmup_url = settings.TRANSCRIPT_URL + "/warmup"
self.timeout = settings.TRANSCRIPT_TIMEOUT self.timeout = settings.TRANSCRIPT_TIMEOUT
self.headers = { self.headers = {"Authorization": f"Bearer {modal_api_key}"}
"Authorization": f"Bearer {modal_api_key}",
# "Content-Type": "multipart/form-data"
}
async def _warmup(self): async def _warmup(self):
try: try:
@@ -90,7 +87,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
if target_language in result["text"]: if target_language in result["text"]:
text = result["text"][target_language] text = result["text"][target_language]
else: else:
text = result["text"]["en"] text = result["text"][source_language]
transcript = Transcript( transcript = Transcript(
text=text, text=text,
words=[ words=[