Merge branch 'main' of github.com:Monadical-SAS/reflector into feat-sharing

This commit is contained in:
Sara
2023-11-21 12:11:58 +01:00
47 changed files with 1163 additions and 614 deletions

View File

@@ -41,7 +41,6 @@ if settings.SENTRY_DSN:
else:
logger.info("Sentry disabled")
# build app
app = FastAPI(lifespan=lifespan)
app.add_middleware(
@@ -102,6 +101,23 @@ def use_route_names_as_operation_ids(app: FastAPI) -> None:
use_route_names_as_operation_ids(app)
if settings.PROFILING:
    from fastapi import Request
    from fastapi.responses import HTMLResponse
    from pyinstrument import Profiler

    @app.middleware("http")
    async def profile_request(request: Request, call_next):
        """
        Optionally profile a single HTTP request.

        When the request carries a non-empty ``profile`` query parameter, the
        downstream handler is run under a pyinstrument profiler and the HTML
        profiling report is returned *instead of* the handler's response.
        Otherwise the request is passed through untouched.
        """
        if not request.query_params.get("profile", False):
            return await call_next(request)

        profiler = Profiler(async_mode="enabled")
        profiler.start()
        try:
            await call_next(request)
        finally:
            # always stop the profiler, even if the handler raised, so a
            # failing request does not leave a profiler running
            profiler.stop()
        return HTMLResponse(profiler.output_html())
if __name__ == "__main__":
import uvicorn

View File

@@ -11,7 +11,6 @@ from pydantic import BaseModel, Field
from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
transcripts = sqlalchemy.Table(
"transcript",
@@ -86,6 +85,14 @@ class TranscriptFinalTitle(BaseModel):
title: str
class TranscriptDuration(BaseModel):
    # Payload carrying a transcript's audio duration; it is built from the
    # value reported by the audio writer and stored as a "DURATION" event.
    # NOTE(review): the unit is presumably seconds — confirm with the emitter.
    duration: float
class TranscriptWaveform(BaseModel):
    # Payload carrying the computed audio waveform, stored as a "WAVEFORM"
    # event.
    # NOTE(review): presumably one amplitude value per segment — confirm
    # against get_audio_waveform.
    waveform: list[float]
class TranscriptEvent(BaseModel):
event: str
data: dict
@@ -126,22 +133,6 @@ class Transcript(BaseModel):
def topics_dump(self, mode="json"):
    """
    Dump every topic of this transcript with ``model_dump``.

    :param mode: forwarded as the ``mode`` argument of each topic's
        ``model_dump`` call.
    :return: list with one dumped entry per topic, in topic order.
    """
    dumped = []
    for entry in self.topics:
        dumped.append(entry.model_dump(mode=mode))
    return dumped
def convert_audio_to_waveform(self, segments_count=256):
    """
    Compute the waveform of the mp3 audio and store it as JSON.

    Nothing is done (and ``None`` is returned) when the waveform file
    already exists. Otherwise the waveform is computed, written to
    ``self.audio_waveform_filename`` and returned.

    :param segments_count: forwarded to ``get_audio_waveform``.
    :raises Exception: any error raised while writing is re-raised after the
        partially written file has been removed.
    """
    target = self.audio_waveform_filename
    if not target.exists():
        samples = get_audio_waveform(
            path=self.audio_mp3_filename, segments_count=segments_count
        )
        try:
            with open(target, "w") as out:
                json.dump(samples, out)
        except Exception:
            # do not leave a truncated waveform file behind
            target.unlink(missing_ok=True)
            raise
        return samples
    return None
def unlink(self):
self.data_path.unlink(missing_ok=True)

View File

@@ -1,54 +0,0 @@
import httpx
from reflector.llm.base import LLM
from reflector.settings import settings
from reflector.utils.retry import retry
class BananaLLM(LLM):
    """
    LLM backend that sends prompts to ``settings.LLM_URL`` authenticated with
    the Banana API/model key headers.
    """

    def __init__(self):
        super().__init__()
        self.headers = {
            "X-Banana-API-Key": settings.LLM_BANANA_API_KEY,
            "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
        }
        self.timeout = settings.LLM_TIMEOUT

    async def _generate(
        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
    ):
        """
        POST the prompt (plus optional schema/config) and return the ``text``
        field of the JSON response.

        ``gen_schema`` and ``gen_cfg`` are only included in the payload when
        they are truthy. HTTP error statuses raise through
        ``response.raise_for_status()``.
        """
        body = {"prompt": prompt}
        optional_fields = (("gen_schema", gen_schema), ("gen_cfg", gen_cfg))
        for field_name, field_value in optional_fields:
            if field_value:
                body[field_name] = field_value

        async with httpx.AsyncClient() as http:
            post_with_retry = retry(http.post)
            reply = await post_with_retry(
                settings.LLM_URL,
                headers=self.headers,
                json=body,
                timeout=self.timeout,
                retry_timeout=300,  # as per their sdk
            )
            reply.raise_for_status()
            return reply.json()["text"]
LLM.register("banana", BananaLLM)
if __name__ == "__main__":
    # Manual smoke test: build a prompt, run it through the Banana backend
    # and print whatever comes back.
    from reflector.logger import logger

    import asyncio

    async def main():
        backend = BananaLLM()
        joke_prompt = backend.create_prompt(
            instruct="Complete the following task",
            text="Tell me a joke about programming.",
        )
        print(await backend.generate(prompt=joke_prompt, logger=logger))

    asyncio.run(main())

View File

@@ -21,11 +21,13 @@ from pydantic import BaseModel
from reflector.app import app
from reflector.db.transcripts import (
Transcript,
TranscriptDuration,
TranscriptFinalLongSummary,
TranscriptFinalShortSummary,
TranscriptFinalTitle,
TranscriptText,
TranscriptTopic,
TranscriptWaveform,
transcripts_controller,
)
from reflector.logger import logger
@@ -45,6 +47,7 @@ from reflector.processors import (
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
from reflector.processors.types import AudioDiarizationInput
from reflector.processors.types import (
TitleSummaryWithId as TitleSummaryWithIdProcessorType,
@@ -230,6 +233,33 @@ class PipelineMainBase(PipelineRunner):
data=final_short_summary,
)
@broadcast_to_sockets
async def on_duration(self, data):
    """
    Store the reported duration on the transcript and record a ``DURATION``
    event; the appended event is returned.
    """
    async with self.transaction():
        payload = TranscriptDuration(duration=data)
        transcript = await self.get_transcript()
        changes = {"duration": payload.duration}
        await transcripts_controller.update(transcript, changes)
        event = await transcripts_controller.append_event(
            transcript=transcript, event="DURATION", data=payload
        )
        return event
@broadcast_to_sockets
async def on_waveform(self, data):
    """
    Record a ``WAVEFORM`` event carrying the computed waveform; the appended
    event is returned.
    """
    async with self.transaction():
        payload = TranscriptWaveform(waveform=data)
        transcript = await self.get_transcript()
        event = await transcripts_controller.append_event(
            transcript=transcript, event="WAVEFORM", data=payload
        )
        return event
class PipelineMainLive(PipelineMainBase):
audio_filename: Path | None = None
@@ -243,7 +273,10 @@ class PipelineMainLive(PipelineMainBase):
transcript = await self.get_transcript()
processors = [
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
AudioFileWriterProcessor(
path=transcript.audio_mp3_filename,
on_duration=self.on_duration,
),
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
@@ -253,6 +286,11 @@ class PipelineMainLive(PipelineMainBase):
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
AudioWaveformProcessor.as_threaded(
audio_path=transcript.audio_mp3_filename,
waveform_path=transcript.audio_waveform_filename,
on_waveform=self.on_waveform,
),
]
),
]
@@ -285,8 +323,13 @@ class PipelineMainDiarization(PipelineMainBase):
# create a context for the whole rtc transaction
# add a customised logger to the context
self.prepare()
processors = [
AudioDiarizationAutoProcessor(callback=self.on_topic),
processors = []
if settings.DIARIZATION_ENABLED:
processors += [
AudioDiarizationAutoProcessor(callback=self.on_topic),
]
processors += [
BroadcastProcessor(
processors=[
TranscriptFinalLongSummaryProcessor.as_threaded(

View File

@@ -12,8 +12,8 @@ class AudioFileWriterProcessor(Processor):
INPUT_TYPE = av.AudioFrame
OUTPUT_TYPE = av.AudioFrame
def __init__(self, path: Path | str):
super().__init__()
def __init__(self, path: Path | str, **kwargs):
super().__init__(**kwargs)
if isinstance(path, str):
path = Path(path)
if path.suffix not in (".mp3", ".wav"):
@@ -21,6 +21,7 @@ class AudioFileWriterProcessor(Processor):
self.path = path
self.out_container = None
self.out_stream = None
self.last_packet = None
async def _push(self, data: av.AudioFrame):
if not self.out_container:
@@ -40,12 +41,30 @@ class AudioFileWriterProcessor(Processor):
raise ValueError("Only mp3 and wav files are supported")
for packet in self.out_stream.encode(data):
self.out_container.mux(packet)
self.last_packet = packet
await self.emit(data)
async def _flush(self):
if self.out_container:
for packet in self.out_stream.encode():
self.out_container.mux(packet)
self.last_packet = packet
try:
if self.last_packet is not None:
duration = round(
float(
(self.last_packet.pts * self.last_packet.duration)
* self.last_packet.time_base
),
2,
)
except Exception:
self.logger.exception("Failed to get duration")
duration = 0
self.out_container.close()
self.out_container = None
self.out_stream = None
if duration > 0:
await self.emit(duration, name="duration")

View File

@@ -1,86 +0,0 @@
"""
Implementation using the GPU service from banana.
API will be a POST request to TRANSCRIPT_URL:
```json
{
"audio_url": "https://...",
"audio_ext": "wav",
"timestamp": 123.456,
"language": "en"
}
```
"""
from pathlib import Path
import httpx
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
from reflector.processors.types import AudioFile, Transcript, Word
from reflector.settings import settings
from reflector.storage import Storage
from reflector.utils.retry import retry
class AudioTranscriptBananaProcessor(AudioTranscriptProcessor):
    """
    Transcribe audio through the HTTP service at ``settings.TRANSCRIPT_URL``.

    Each audio file is uploaded to the configured transcript storage, its URL
    is posted to the service, and the uploaded object is deleted afterwards.
    """

    def __init__(self, banana_api_key: str, banana_model_key: str):
        super().__init__()
        self.transcript_url = settings.TRANSCRIPT_URL
        self.timeout = settings.TRANSCRIPT_TIMEOUT
        self.storage = Storage.get_instance(
            settings.TRANSCRIPT_STORAGE_BACKEND, "TRANSCRIPT_STORAGE_"
        )
        self.headers = {
            "X-Banana-API-Key": banana_api_key,
            "X-Banana-Model-Key": banana_model_key,
        }

    async def _transcript(self, data: AudioFile):
        """
        Upload ``data.path``, request its transcription and return the
        resulting ``Transcript``.

        HTTP error statuses raise through ``response.raise_for_status()``.
        The uploaded object is removed from storage whether or not the
        transcription succeeded.
        """
        async with httpx.AsyncClient() as client:
            print(f"Uploading audio {data.path.name} to S3")
            url = await self._upload_file(data.path)

            try:
                print(f"Try to transcribe audio {data.path.name}")
                request_data = {
                    "audio_url": url,
                    "audio_ext": data.path.suffix[1:],
                    "timestamp": float(round(data.timestamp, 2)),
                }
                response = await retry(client.post)(
                    self.transcript_url,
                    json=request_data,
                    headers=self.headers,
                    timeout=self.timeout,
                )

                print(
                    f"Transcript response: {response.status_code} {response.content}"
                )
                response.raise_for_status()

                result = response.json()
                transcript = Transcript(
                    text=result["text"],
                    words=[
                        Word(text=word["text"], start=word["start"], end=word["end"])
                        for word in result["words"]
                    ],
                )
            finally:
                # remove audio file from S3, also when the request failed, so
                # failed transcriptions do not leave objects behind
                await self._delete_file(data.path)

        return transcript

    @retry
    async def _upload_file(self, path: Path) -> str:
        """Upload the file under its own name and return the resulting URL."""
        # close the local file handle once the upload call has returned
        with open(path, "rb") as fd:
            upload_result = await self.storage.put_file(path.name, fd)
        return upload_result.url

    @retry
    async def _delete_file(self, path: Path):
        """Delete the object named after ``path`` from storage."""
        await self.storage.delete_file(path.name)
        return True
AudioTranscriptAutoProcessor.register("banana", AudioTranscriptBananaProcessor)

View File

@@ -0,0 +1,36 @@
import json
from pathlib import Path
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary
from reflector.utils.audio_waveform import get_audio_waveform
class AudioWaveformProcessor(Processor):
    """
    Write the waveform for the final audio
    """

    INPUT_TYPE = TitleSummary

    def __init__(self, audio_path: Path | str, waveform_path: Path | str, **kwargs):
        """
        :param audio_path: audio file to analyse; must end in .mp3 or .wav.
        :param waveform_path: JSON file the waveform is written to.
        :param kwargs: forwarded to ``Processor`` (e.g. ``on_waveform``).
        :raises ValueError: if the audio file has an unsupported suffix.
        """
        super().__init__(**kwargs)
        audio_path = Path(audio_path)
        if audio_path.suffix not in (".mp3", ".wav"):
            raise ValueError("Only mp3 and wav files are supported")
        self.audio_path = audio_path
        # _flush needs Path operations (.parent), so accept plain strings too
        self.waveform_path = Path(waveform_path)

    async def _flush(self):
        """
        Compute the waveform (255 segments), store it as JSON and emit it on
        the ``"waveform"`` channel.
        """
        self.waveform_path.parent.mkdir(parents=True, exist_ok=True)
        self.logger.info("Waveform Processing Started")
        waveform = get_audio_waveform(path=self.audio_path, segments_count=255)

        try:
            with open(self.waveform_path, "w") as fd:
                json.dump(waveform, fd)
        except Exception:
            # do not leave a truncated waveform file behind
            self.waveform_path.unlink(missing_ok=True)
            raise
        self.logger.info("Waveform Processing Finished")
        await self.emit(waveform, name="waveform")

    async def _push(self, _data):
        # incoming data is ignored: all the work happens in _flush
        return

View File

@@ -14,7 +14,42 @@ class PipelineEvent(BaseModel):
data: Any
class Processor:
class Emitter:
    """
    Minimal asynchronous event emitter with named channels.

    Callbacks are coroutine functions registered under a channel name
    (``"default"`` unless stated otherwise) and awaited, in registration
    order, each time data is emitted on that channel.
    """

    def __init__(self, **kwargs):
        """
        Keyword arguments named ``on_<name>`` are registered as callbacks for
        channel ``<name>``; every other keyword argument is ignored.
        """
        self._callbacks = {}

        # register callbacks from kwargs (on_*)
        for key, value in kwargs.items():
            if key.startswith("on_"):
                self.on(value, name=key[3:])

    def on(self, callback, name="default"):
        """
        Register a callback to be called when data is emitted

        :raises ValueError: if ``callback`` is not a coroutine function.
        """
        # ensure callback is asynchronous
        if not asyncio.iscoroutinefunction(callback):
            raise ValueError("Callback must be a coroutine function")
        self._callbacks.setdefault(name, []).append(callback)

    def off(self, callback, name="default"):
        """
        Unregister a callback to be called when data is emitted

        Unknown channel names are ignored.

        :raises ValueError: if the channel exists but ``callback`` is not
            registered on it.
        """
        if name not in self._callbacks:
            return
        self._callbacks[name].remove(callback)

    async def emit(self, data, name="default"):
        """
        Await every callback registered on channel ``name`` with ``data``.
        """
        # iterate over a snapshot: a callback may call on()/off() while it
        # runs, and mutating the live list would skip or repeat callbacks
        for callback in tuple(self._callbacks.get(name, ())):
            await callback(data)
class Processor(Emitter):
INPUT_TYPE: type = None
OUTPUT_TYPE: type = None
@@ -59,7 +94,8 @@ class Processor:
["processor"],
)
def __init__(self, callback=None, custom_logger=None):
def __init__(self, callback=None, custom_logger=None, **kwargs):
super().__init__(**kwargs)
self.name = name = self.__class__.__name__
self.m_processor = self.m_processor.labels(name)
self.m_processor_call = self.m_processor_call.labels(name)
@@ -70,9 +106,11 @@ class Processor:
self.m_processor_flush_success = self.m_processor_flush_success.labels(name)
self.m_processor_flush_failure = self.m_processor_flush_failure.labels(name)
self._processors = []
self._callbacks = []
# register callbacks
if callback:
self.on(callback)
self.uid = uuid4().hex
self.flushed = False
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
@@ -100,21 +138,6 @@ class Processor:
"""
self._processors.remove(processor)
def on(self, callback):
"""
Register a callback to be called when data is emitted
"""
# ensure callback is asynchronous
if not asyncio.iscoroutinefunction(callback):
raise ValueError("Callback must be a coroutine function")
self._callbacks.append(callback)
def off(self, callback):
"""
Unregister a callback to be called when data is emitted
"""
self._callbacks.remove(callback)
def get_pref(self, key: str, default: Any = None):
"""
Get a preference from the pipeline prefs
@@ -123,15 +146,16 @@ class Processor:
return self.pipeline.get_pref(key, default)
return default
async def emit(self, data):
if self.pipeline:
await self.pipeline.emit(
PipelineEvent(processor=self.name, uid=self.uid, data=data)
)
for callback in self._callbacks:
await callback(data)
for processor in self._processors:
await processor.push(data)
async def emit(self, data, name="default"):
    """
    Emit ``data`` on channel ``name``.

    Registered callbacks for the channel are always invoked. For the
    ``"default"`` channel the data is additionally published to the
    pipeline (when one is attached) beforehand, and pushed to every
    connected processor afterwards.
    """
    is_default = name == "default"
    if is_default and self.pipeline:
        event = PipelineEvent(processor=self.name, uid=self.uid, data=data)
        await self.pipeline.emit(event)
    await super().emit(data, name=name)
    if not is_default:
        return
    for downstream in self._processors:
        await downstream.push(data)
async def push(self, data):
"""
@@ -254,11 +278,11 @@ class ThreadedProcessor(Processor):
def disconnect(self, processor: Processor):
self.processor.disconnect(processor)
def on(self, callback):
self.processor.on(callback)
def on(self, callback, name="default"):
self.processor.on(callback, name=name)
def off(self, callback):
self.processor.off(callback)
def off(self, callback, name="default"):
self.processor.off(callback, name=name)
def describe(self, level=0):
super().describe(level)
@@ -305,13 +329,13 @@ class BroadcastProcessor(Processor):
for processor in self.processors:
processor.disconnect(processor)
def on(self, callback):
def on(self, callback, name="default"):
for processor in self.processors:
processor.on(callback)
processor.on(callback, name=name)
def off(self, callback):
def off(self, callback, name="default"):
for processor in self.processors:
processor.off(callback)
processor.off(callback, name=name)
def describe(self, level=0):
super().describe(level)

View File

@@ -16,6 +16,7 @@ class TranscriptTranslatorProcessor(Processor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.transcript = None
self.translate_url = settings.TRANSLATE_URL
self.timeout = settings.TRANSLATE_TIMEOUT
self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"}

View File

@@ -5,6 +5,7 @@ from pathlib import Path
from profanityfilter import ProfanityFilter
from pydantic import BaseModel, PrivateAttr
from reflector.redis_cache import redis_cache
PUNC_RE = re.compile(r"[.;:?!…]")
@@ -68,10 +69,14 @@ class Transcript(BaseModel):
# Uncensored text
return "".join([word.text for word in self.words])
@redis_cache(prefix="profanity", duration=3600 * 24 * 7)
def _get_censored_text(self, text: str):
    """Return ``text`` run through the profanity filter and stripped."""
    censored = profanity_filter.censor(text)
    return censored.strip()
@property
def text(self):
# Censored text
return profanity_filter.censor(self.raw_text).strip()
return self._get_censored_text(self.raw_text)
@property
def human_timestamp(self):

View File

@@ -0,0 +1,50 @@
import functools
import json
import redis
from reflector.settings import settings
redis_clients = {}
def get_redis_client(db=0):
    """
    Return the Redis client for database ``db``.

    Clients are created on first use from ``settings.REDIS_HOST`` /
    ``settings.REDIS_PORT`` and kept in ``redis_clients`` for later calls.
    """
    try:
        return redis_clients[db]
    except KeyError:
        client = redis.StrictRedis(
            host=settings.REDIS_HOST,
            port=settings.REDIS_PORT,
            db=db,
        )
        redis_clients[db] = client
        return client
def redis_cache(prefix="cache", duration=3600, db=settings.REDIS_CACHE_DB, argidx=1):
    """
    Cache the result of a function in Redis.

    The cache key is ``prefix + ":" + args[argidx]``; calls whose positional
    argument at ``argidx`` is missing or not a string bypass the cache.
    Results are stored as JSON for ``duration`` seconds.

    The cache is best-effort: if Redis cannot be reached, or a stored entry
    cannot be decoded, the wrapped function is simply called.

    :param prefix: namespace prepended to the cache key.
    :param duration: time-to-live of a cached entry, in seconds.
    :param db: Redis database number used for the cache.
    :param argidx: index of the positional argument used as the cache key
        (1 by default, i.e. the first argument after ``self`` on a method).
    """
    # local import: keeps this block self-contained without touching the
    # module-level imports
    import logging

    log = logging.getLogger(__name__)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Only cache when the key argument is present and is a string
            if len(args) <= argidx or not isinstance(args[argidx], str):
                return func(*args, **kwargs)

            # Compute the cache key based on the arguments and prefix
            cache_key = prefix + ":" + args[argidx]

            try:
                redis_client = get_redis_client(db=db)
                cached_result = redis_client.get(cache_key)
            except redis.RedisError:
                # a cache outage must not break the wrapped function
                log.warning("Redis cache unavailable, skipping cache lookup")
                return func(*args, **kwargs)

            if cached_result:
                try:
                    return json.loads(cached_result.decode("utf-8"))
                except ValueError:
                    # undecodable entry: recompute and overwrite it below
                    log.warning("Ignoring corrupted cache entry")

            # If the result is not cached, call the original function
            result = func(*args, **kwargs)
            payload = json.dumps(result)
            try:
                redis_client.setex(cache_key, duration, payload)
            except redis.RedisError:
                log.warning("Redis cache unavailable, result not cached")
            return result

        return wrapper

    return decorator

View File

@@ -41,7 +41,7 @@ class Settings(BaseSettings):
AUDIO_BUFFER_SIZE: int = 256 * 960
# Audio Transcription
# backends: whisper, banana, modal
# backends: whisper, modal
TRANSCRIPT_BACKEND: str = "whisper"
TRANSCRIPT_URL: str | None = None
TRANSCRIPT_TIMEOUT: int = 90
@@ -50,10 +50,6 @@ class Settings(BaseSettings):
TRANSLATE_URL: str | None = None
TRANSLATE_TIMEOUT: int = 90
# Audio transcription banana.dev configuration
TRANSCRIPT_BANANA_API_KEY: str | None = None
TRANSCRIPT_BANANA_MODEL_KEY: str | None = None
# Audio transcription modal.com configuration
TRANSCRIPT_MODAL_API_KEY: str | None = None
@@ -61,13 +57,16 @@ class Settings(BaseSettings):
TRANSCRIPT_STORAGE_BACKEND: str = "aws"
# Storage configuration for AWS
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket/chunks"
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
TRANSCRIPT_STORAGE_AWS_REGION: str = "us-east-1"
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
# Transcript MP3 storage
TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
# LLM
# available backend: openai, banana, modal, oobabooga
# available backend: openai, modal, oobabooga
LLM_BACKEND: str = "oobabooga"
# LLM common configuration
@@ -82,14 +81,11 @@ class Settings(BaseSettings):
LLM_TEMPERATURE: float = 0.7
ZEPHYR_LLM_URL: str | None = None
# LLM Banana configuration
LLM_BANANA_API_KEY: str | None = None
LLM_BANANA_MODEL_KEY: str | None = None
# LLM Modal configuration
LLM_MODAL_API_KEY: str | None = None
# Diarization
DIARIZATION_ENABLED: bool = True
DIARIZATION_BACKEND: str = "modal"
DIARIZATION_URL: str | None = None
@@ -124,6 +120,7 @@ class Settings(BaseSettings):
# Redis
REDIS_HOST: str = "localhost"
REDIS_PORT: int = 6379
REDIS_CACHE_DB: int = 2
# Secret key
SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21"
@@ -131,5 +128,8 @@ class Settings(BaseSettings):
# Current hosting/domain
BASE_URL: str = "http://localhost:1250"
# Profiling
PROFILING: bool = False
settings = Settings()

View File

@@ -1,7 +1,7 @@
import os
from typing import BinaryIO
from fastapi import HTTPException, Request, status
from fastapi import HTTPException, Request, Response, status
from fastapi.responses import StreamingResponse
@@ -57,6 +57,9 @@ def range_requests_response(
),
}
if request.method == "HEAD":
return Response(headers=headers)
if content_disposition:
headers["Content-Disposition"] = content_disposition

View File

@@ -23,7 +23,6 @@ from reflector.db.transcripts import (
from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.settings import settings
from reflector.ws_manager import get_ws_manager
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
from .rtc_offer import RtcOffer, rtc_offer_base
@@ -53,7 +52,7 @@ class GetTranscript(BaseModel):
name: str
status: str
locked: bool
duration: int
duration: float
title: str | None
short_summary: str | None
long_summary: str | None
@@ -222,6 +221,7 @@ async def transcript_delete(
@router.get("/transcripts/{transcript_id}/audio/mp3")
@router.head("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(
request: Request,
transcript_id: str,
@@ -272,8 +272,6 @@ async def transcript_get_audio_waveform(
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=500, detail="Audio not found")
await run_in_threadpool(transcript.convert_audio_to_waveform)
return transcript.audio_waveform