Merge branch 'main' into jose/markers

Jose B
2023-08-22 13:29:17 -05:00
10 changed files with 447 additions and 46 deletions

View File

@@ -155,7 +155,7 @@ class LLM:
def web():
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
from pydantic import BaseModel, Field
llmstub = LLM()
@@ -172,14 +172,14 @@ def web():
class LLMRequest(BaseModel):
prompt: str
schema: Optional[dict] = None
schema_: Optional[dict] = Field(None, alias="schema")
@app.post("/llm", dependencies=[Depends(apikey_auth)])
async def llm(
req: LLMRequest,
):
if req.schema:
func = llmstub.generate.spawn(prompt=req.prompt, schema=json.dumps(req.schema))
if req.schema_:
func = llmstub.generate.spawn(prompt=req.prompt, schema=json.dumps(req.schema_))
else:
func = llmstub.generate.spawn(prompt=req.prompt)
result = func.get()
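
Since `schema_` is aliased back to `"schema"`, existing clients keep sending the same body. A minimal client sketch (not part of the diff), assuming an illustrative service URL and whatever bearer key `apikey_auth` accepts:

```python
import os
import httpx

# Both values are illustrative; point them at your own deployment.
LLM_URL = os.environ.get("LLM_URL", "https://example.modal.run/llm")
API_KEY = os.environ["LLM_API_KEY"]

payload = {
    "prompt": "Summarize the meeting in one sentence.",
    # Still sent as "schema" on the wire; Pydantic maps it to `schema_` via the alias.
    "schema": {"type": "object", "properties": {"summary": {"type": "string"}}},
}
resp = httpx.post(
    LLM_URL,
    json=payload,
    headers={"Authorization": f"Bearer {API_KEY}"},
    timeout=120,
)
resp.raise_for_status()
print(resp.json())
```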

View File

@@ -3,11 +3,11 @@ Reflector GPU backend - transcriber
===================================
"""
import tempfile
import os
from modal import Image, method, Stub, asgi_app, Secret
from pydantic import BaseModel
import tempfile
from modal import Image, Secret, Stub, asgi_app, method
from pydantic import BaseModel
# Whisper
WHISPER_MODEL: str = "large-v2"
@@ -15,6 +15,9 @@ WHISPER_COMPUTE_TYPE: str = "float16"
WHISPER_NUM_WORKERS: int = 1
WHISPER_CACHE_DIR: str = "/cache/whisper"
# Translation Model
TRANSLATION_MODEL = "facebook/m2m100_418M"
stub = Stub(name="reflector-transcriber")
@@ -31,6 +34,9 @@ whisper_image = (
"faster-whisper",
"requests",
"torch",
"transformers",
"sentencepiece",
"protobuf",
)
.run_function(download_whisper)
.env(
@@ -51,17 +57,21 @@ whisper_image = (
)
class Whisper:
def __enter__(self):
import torch
import faster_whisper
import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
self.use_gpu = torch.cuda.is_available()
device = "cuda" if self.use_gpu else "cpu"
self.device = "cuda" if self.use_gpu else "cpu"
self.model = faster_whisper.WhisperModel(
WHISPER_MODEL,
device=device,
device=self.device,
compute_type=WHISPER_COMPUTE_TYPE,
num_workers=WHISPER_NUM_WORKERS,
)
self.translation_model = M2M100ForConditionalGeneration.from_pretrained(TRANSLATION_MODEL).to(self.device)
self.translation_tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)
@method()
def warmup(self):
@@ -72,28 +82,30 @@ class Whisper:
self,
audio_data: str,
audio_suffix: str,
timestamp: float = 0,
language: str = "en",
source_language: str,
target_language: str,
timestamp: float = 0
):
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
fp.write(audio_data)
segments, _ = self.model.transcribe(
fp.name,
language=language,
language=source_language,
beam_size=5,
word_timestamps=True,
vad_filter=True,
vad_parameters={"min_silence_duration_ms": 500},
)
transcript = ""
multilingual_transcript = {}
transcript_source_lang = ""
words = []
if segments:
segments = list(segments)
for segment in segments:
transcript += segment.text
transcript_source_lang += segment.text
for word in segment.words:
words.append(
{
@@ -102,9 +114,24 @@ class Whisper:
"end": round(timestamp + word.end, 3),
}
)
multilingual_transcript[source_language] = transcript_source_lang
if target_language != source_language:
self.translation_tokenizer.src_lang = source_language
forced_bos_token_id = self.translation_tokenizer.get_lang_id(target_language)
encoded_transcript = self.translation_tokenizer(transcript_source_lang, return_tensors="pt").to(self.device)
generated_tokens = self.translation_model.generate(
**encoded_transcript,
forced_bos_token_id=forced_bos_token_id
)
result = self.translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
translation = result[0].strip()
multilingual_transcript[target_language] = translation
return {
"text": transcript,
"words": words,
"text": multilingual_transcript,
"words": words
}
@@ -122,7 +149,7 @@ class Whisper:
)
@asgi_app()
def web():
from fastapi import FastAPI, UploadFile, Form, Depends, HTTPException, status
from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile, status
from fastapi.security import OAuth2PasswordBearer
from typing_extensions import Annotated
@@ -131,6 +158,7 @@ def web():
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
supported_audio_file_types = ["wav", "mp3", "ogg", "flac"]
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
@@ -140,28 +168,26 @@ def web():
headers={"WWW-Authenticate": "Bearer"},
)
class TranscriptionRequest(BaseModel):
timestamp: float = 0
language: str = "en"
class TranscriptResponse(BaseModel):
result: str
text: dict
words: list
@app.post("/transcribe", dependencies=[Depends(apikey_auth)])
async def transcribe(
file: UploadFile,
timestamp: Annotated[float, Form()] = 0,
language: Annotated[str, Form()] = "en",
):
source_language: Annotated[str, Form()] = "en",
target_language: Annotated[str, Form()] = "en"
) -> TranscriptResponse:
audio_data = await file.read()
audio_suffix = file.filename.split(".")[-1]
assert audio_suffix in ["wav", "mp3", "ogg", "flac"]
assert audio_suffix in supported_audio_file_types
func = transcriberstub.transcribe_segment.spawn(
audio_data=audio_data,
audio_suffix=audio_suffix,
language=language,
timestamp=timestamp,
source_language=source_language,
target_language=target_language,
timestamp=timestamp
)
result = func.get()
return result
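
The endpoint now takes the languages as multipart form fields next to the audio file, and `text` in the response is a per-language dict. A rough client sketch, with the URL and key as placeholders:

```python
import os
import httpx

# Illustrative values; adjust to your deployment.
TRANSCRIBE_URL = os.environ.get("TRANSCRIBE_URL", "https://example.modal.run/transcribe")
API_KEY = os.environ["REFLECTOR_GPU_APIKEY"]

with open("meeting.wav", "rb") as fd:
    resp = httpx.post(
        TRANSCRIBE_URL,
        headers={"Authorization": f"Bearer {API_KEY}"},
        files={"file": ("meeting.wav", fd)},
        # Form fields travel in the same multipart body as the file
        data={"timestamp": "0", "source_language": "en", "target_language": "fr"},
        timeout=300,
    )
resp.raise_for_status()
result = resp.json()
# result["text"] maps language codes to transcripts, e.g. {"en": "...", "fr": "..."}
print(result["text"])
```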

View File

@@ -5,19 +5,22 @@ API will be a POST request to TRANSCRIPT_URL:
```form
"timestamp": 123.456
"language": "en"
"source_language": "en"
"target_language": "en"
"file": <audio file>
```
"""
from time import monotonic
import httpx
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
from reflector.processors.types import AudioFile, Transcript, Word
from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word
from reflector.settings import settings
from reflector.utils.retry import retry
from time import monotonic
import httpx
class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
@@ -26,9 +29,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
self.transcript_url = settings.TRANSCRIPT_URL + "/transcribe"
self.warmup_url = settings.TRANSCRIPT_URL + "/warmup"
self.timeout = settings.TRANSCRIPT_TIMEOUT
self.headers = {
"Authorization": f"Bearer {modal_api_key}",
}
self.headers = {"Authorization": f"Bearer {modal_api_key}"}
async def _warmup(self):
try:
@@ -52,11 +53,28 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
files = {
"file": (data.name, data.fd),
}
# TODO: Get the source / target languages from the UI preferences dynamically
# once that is possible, e.g. extract them from the context/session objects.
source_language = "en"
target_language = "en"
languages = TranslationLanguages()
# The target language should only ever be set from a UI element such as a dropdown,
# so this assert should never fail.
assert languages.is_supported(target_language)
form_payload = {
"source_language": source_language,
"target_language": target_language,
}
response = await retry(client.post)(
self.transcript_url,
files=files,
timeout=self.timeout,
headers=self.headers,
data=form_payload,
)
self.logger.debug(
@@ -64,8 +82,14 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
)
response.raise_for_status()
result = response.json()
# Prefer the translated text; fall back to the source language if no translation is present
if target_language in result["text"]:
text = result["text"][target_language]
else:
text = result["text"][source_language]
transcript = Transcript(
text=result["text"],
text=text,
words=[
Word(
text=word["text"],

View File

@@ -1,7 +1,8 @@
from pydantic import BaseModel, PrivateAttr
from pathlib import Path
import tempfile
import io
import tempfile
from pathlib import Path
from pydantic import BaseModel, PrivateAttr
class AudioFile(BaseModel):
@@ -104,3 +105,117 @@ class TitleSummary(BaseModel):
class FinalSummary(BaseModel):
summary: str
duration: float
class TranslationLanguages(BaseModel):
language_to_id_mapping: dict = {
"Afrikaans": "af",
"Albanian": "sq",
"Amharic": "am",
"Arabic": "ar",
"Armenian": "hy",
"Asturian": "ast",
"Azerbaijani": "az",
"Bashkir": "ba",
"Belarusian": "be",
"Bengali": "bn",
"Bosnian": "bs",
"Breton": "br",
"Bulgarian": "bg",
"Burmese": "my",
"Catalan; Valencian": "ca",
"Cebuano": "ceb",
"Central Khmer": "km",
"Chinese": "zh",
"Croatian": "hr",
"Czech": "cs",
"Danish": "da",
"Dutch; Flemish": "nl",
"English": "en",
"Estonian": "et",
"Finnish": "fi",
"French": "fr",
"Fulah": "ff",
"Gaelic; Scottish Gaelic": "gd",
"Galician": "gl",
"Ganda": "lg",
"Georgian": "ka",
"German": "de",
"Greeek": "el",
"Gujarati": "gu",
"Haitian; Haitian Creole": "ht",
"Hausa": "ha",
"Hebrew": "he",
"Hindi": "hi",
"Hungarian": "hu",
"Icelandic": "is",
"Igbo": "ig",
"Iloko": "ilo",
"Indonesian": "id",
"Irish": "ga",
"Italian": "it",
"Japanese": "ja",
"Javanese": "jv",
"Kannada": "kn",
"Kazakh": "kk",
"Korean": "ko",
"Lao": "lo",
"Latvian": "lv",
"Lingala": "ln",
"Lithuanian": "lt",
"Luxembourgish; Letzeburgesch": "lb",
"Macedonian": "mk",
"Malagasy": "mg",
"Malay": "ms",
"Malayalam": "ml",
"Marathi": "mr",
"Mongolian": "mn",
"Nepali": "ne",
"Northern Sotho": "ns",
"Norwegian": "no",
"Occitan": "oc",
"Oriya": "or",
"Panjabi; Punjabi": "pa",
"Persian": "fa",
"Polish": "pl",
"Portuguese": "pt",
"Pushto; Pashto": "ps",
"Romanian; Moldavian; Moldovan": "ro",
"Russian": "ru",
"Serbian": "sr",
"Sindhi": "sd",
"Sinhala; Sinhalese": "si",
"Slovak": "sk",
"Slovenian": "sl",
"Somali": "so",
"Spanish": "es",
"Sundanese": "su",
"Swahili": "sw",
"Swati": "ss",
"Swedish": "sv",
"Tagalog": "tl",
"Tamil": "ta",
"Thai": "th",
"Tswana": "tn",
"Turkish": "tr",
"Ukrainian": "uk",
"Urdu": "ur",
"Uzbek": "uz",
"Vietnamese": "vi",
"Welsh": "cy",
"Western Frisian": "fy",
"Wolof": "wo",
"Xhosa": "xh",
"Yiddish": "yi",
"Yoruba": "yo",
"Zulu": "zu",
}
@property
def supported_languages(self):
return self.language_to_id_mapping.values()
def is_supported(self, lang_id: str) -> bool:
return lang_id in self.supported_languages
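
A quick usage sketch of the new model:

```python
languages = TranslationLanguages()

assert languages.is_supported("fr")        # codes, not names, are the supported values
assert not languages.is_supported("xx")
print(languages.language_to_id_mapping["French"])  # -> "fr"
```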

View File

@@ -0,0 +1,72 @@
import os
from typing import BinaryIO
from fastapi import HTTPException, Request, status
from fastapi.responses import StreamingResponse
def send_bytes_range_requests(
file_obj: BinaryIO, start: int, end: int, chunk_size: int = 10_000
):
"""Send a file in chunks using Range Requests specification RFC7233
`start` and `end` parameters are inclusive due to specification
"""
with file_obj as f:
f.seek(start)
while (pos := f.tell()) <= end:
read_size = min(chunk_size, end + 1 - pos)
yield f.read(read_size)
def _get_range_header(range_header: str, file_size: int) -> tuple[int, int]:
def _invalid_range():
return HTTPException(
status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
detail=f"Invalid request range (Range:{range_header!r})",
)
try:
h = range_header.replace("bytes=", "").split("-")
start = int(h[0]) if h[0] != "" else 0
end = int(h[1]) if h[1] != "" else file_size - 1
except ValueError:
raise _invalid_range()
if start > end or start < 0 or end > file_size - 1:
raise _invalid_range()
return start, end
def range_requests_response(request: Request, file_path: str, content_type: str):
"""Returns StreamingResponse using Range Requests of a given file"""
file_size = os.stat(file_path).st_size
range_header = request.headers.get("range")
headers = {
"content-type": content_type,
"accept-ranges": "bytes",
"content-encoding": "identity",
"content-length": str(file_size),
"access-control-expose-headers": (
"content-type, accept-ranges, content-length, "
"content-range, content-encoding"
),
}
start = 0
end = file_size - 1
status_code = status.HTTP_200_OK
if range_header is not None:
start, end = _get_range_header(range_header, file_size)
size = end - start + 1
headers["content-length"] = str(size)
headers["content-range"] = f"bytes {start}-{end}/{file_size}"
status_code = status.HTTP_206_PARTIAL_CONTENT
return StreamingResponse(
send_bytes_range_requests(open(file_path, mode="rb"), start, end),
headers=headers,
status_code=status_code,
)
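
Once wired into a route (see the transcripts views below), a client can request a byte range and should get a 206 with matching headers; a sketch with an illustrative URL:

```python
import httpx

transcript_id = "..."  # placeholder: id of an existing transcript
resp = httpx.get(
    f"http://localhost:8000/v1/transcripts/{transcript_id}/audio",
    headers={"range": "bytes=0-99"},
)
assert resp.status_code == 206                                  # partial content
assert resp.headers["content-range"].startswith("bytes 0-99/")
assert resp.headers["content-length"] == "100"                  # range is inclusive
```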

View File

@@ -14,7 +14,6 @@ from fastapi import (
WebSocket,
WebSocketDisconnect,
)
from fastapi.responses import FileResponse
from fastapi_pagination import Page, paginate
from pydantic import BaseModel, Field
from reflector.db import database, transcripts
@@ -22,6 +21,7 @@ from reflector.logger import logger
from reflector.settings import settings
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
router = APIRouter()
@@ -190,6 +190,7 @@ class GetTranscript(BaseModel):
status: str
locked: bool
duration: int
summary: str | None
created_at: datetime
@@ -200,6 +201,7 @@ class CreateTranscript(BaseModel):
class UpdateTranscript(BaseModel):
name: Optional[str] = Field(None)
locked: Optional[bool] = Field(None)
summary: Optional[str] = Field(None)
class TranscriptEntryCreate(BaseModel):
@@ -262,6 +264,15 @@ async def transcript_update(
values["name"] = info.name
if info.locked is not None:
values["locked"] = info.locked
if info.summary is not None:
values["summary"] = info.summary
# also find FINAL_SUMMARY event and patch it
for te in transcript.events:
if te["event"] == PipelineEvent.FINAL_SUMMARY:
te["summary"] = info.summary
break
values["events"] = transcript.events
await transcripts_controller.update(transcript, values)
return transcript
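
With `summary` now patchable (and the FINAL_SUMMARY event kept in sync), updating it is a plain PATCH; a sketch with an illustrative URL and auth omitted:

```python
import httpx

transcript_id = "..."  # placeholder: id of an existing transcript
resp = httpx.patch(
    f"http://localhost:8000/v1/transcripts/{transcript_id}",
    json={"summary": "Quarterly planning recap."},
)
resp.raise_for_status()
assert resp.json()["summary"] == "Quarterly planning recap."
```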
@@ -281,6 +292,7 @@ async def transcript_delete(
@router.get("/transcripts/{transcript_id}/audio")
async def transcript_get_audio(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
@@ -292,11 +304,16 @@ async def transcript_get_audio(
if not transcript.audio_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
return FileResponse(transcript.audio_filename, media_type="audio/wav")
return range_requests_response(
request,
transcript.audio_filename,
content_type="audio/wav",
)
@router.get("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
@@ -310,7 +327,11 @@ async def transcript_get_audio_mp3(
await run_in_threadpool(transcript.convert_audio_to_mp3)
return FileResponse(transcript.audio_mp3_filename, media_type="audio/mp3")
return range_requests_response(
request,
transcript.audio_mp3_filename,
content_type="audio/mp3",
)
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])

View File

@@ -44,6 +44,54 @@ async def test_transcript_get_update_name():
assert response.json()["name"] == "test2"
@pytest.mark.asyncio
async def test_transcript_get_update_locked():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["locked"] is False
tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["locked"] is False
response = await ac.patch(f"/transcripts/{tid}", json={"locked": True})
assert response.status_code == 200
assert response.json()["locked"] is True
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["locked"] is True
@pytest.mark.asyncio
async def test_transcript_get_update_summary():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["summary"] is None
tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["summary"] is None
response = await ac.patch(f"/transcripts/{tid}", json={"summary": "test"})
assert response.status_code == 200
assert response.json()["summary"] == "test"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["summary"] == "test"
@pytest.mark.asyncio
async def test_transcripts_list_anonymous():
# XXX this test is a bit fragile, as it depends on the storage which

View File

@@ -0,0 +1,95 @@
import pytest
import shutil
from httpx import AsyncClient
from pathlib import Path
@pytest.fixture
async def fake_transcript(tmpdir):
from reflector.settings import settings
from reflector.app import app
from reflector.views.transcripts import transcripts_controller
settings.DATA_DIR = Path(tmpdir)
# create a transcript
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.post("/transcripts", json={"name": "Test audio download"})
assert response.status_code == 200
tid = response.json()["id"]
transcript = await transcripts_controller.get_by_id(tid)
assert transcript is not None
await transcripts_controller.update(transcript, {"status": "finished"})
# manually copy a file at the expected location
audio_filename = transcript.audio_filename
path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
audio_filename.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(path, audio_filename)
yield transcript
@pytest.mark.asyncio
@pytest.mark.parametrize(
"url_suffix,content_type",
[
["", "audio/wav"],
["/mp3", "audio/mp3"],
],
)
async def test_transcript_audio_download(fake_transcript, url_suffix, content_type):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
assert response.status_code == 200
assert response.headers["content-type"] == content_type
@pytest.mark.asyncio
@pytest.mark.parametrize(
"url_suffix,content_type",
[
["", "audio/wav"],
["/mp3", "audio/mp3"],
],
)
async def test_transcript_audio_download_range(
fake_transcript, url_suffix, content_type
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(
f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
headers={"range": "bytes=0-100"},
)
assert response.status_code == 206
assert response.headers["content-type"] == content_type
assert response.headers["content-range"].startswith("bytes 0-100/")
assert response.headers["content-length"] == "101"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"url_suffix,content_type",
[
["", "audio/wav"],
["/mp3", "audio/mp3"],
],
)
async def test_transcript_audio_download_range_with_seek(
fake_transcript, url_suffix, content_type
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(
f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
headers={"range": "bytes=100-"},
)
assert response.status_code == 206
assert response.headers["content-type"] == content_type
assert response.headers["content-range"].startswith("bytes 100-")

View File

@@ -7,6 +7,7 @@ import useTranscript from "../useTranscript";
import { useWebSockets } from "../useWebSockets";
import "../../styles/button.css";
import { Topic } from "../webSocketTypes";
import getApi from "../../lib/getApi";
const App = () => {
const [stream, setStream] = useState<MediaStream | null>(null);
@@ -23,8 +24,9 @@ const App = () => {
}
}, []);
const api = getApi();
const transcript = useTranscript();
const webRTC = useWebRTC(stream, transcript.response?.id);
const webRTC = useWebRTC(stream, transcript.response?.id, api);
const webSockets = useWebSockets(transcript.response?.id);
return (

View File

@@ -5,11 +5,11 @@ import {
V1TranscriptRecordWebrtcRequest,
} from "../api/apis/DefaultApi";
import { Configuration } from "../api/runtime";
import getApi from "../lib/getApi";
const useWebRTC = (
stream: MediaStream | null,
transcriptId: string | null,
api: DefaultApi,
): Peer => {
const [peer, setPeer] = useState<Peer | null>(null);
@@ -18,8 +18,6 @@ const useWebRTC = (
return;
}
const api = getApi();
let p: Peer = new Peer({ initiator: true, stream: stream });
p.on("signal", (data: any) => {