mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 12:49:06 +00:00
Merge branch 'main' of github.com:Monadical-SAS/reflector into feat-sharing
This commit is contained in:
@@ -41,7 +41,6 @@ if settings.SENTRY_DSN:
|
||||
else:
|
||||
logger.info("Sentry disabled")
|
||||
|
||||
|
||||
# build app
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
@@ -102,6 +101,23 @@ def use_route_names_as_operation_ids(app: FastAPI) -> None:
|
||||
|
||||
use_route_names_as_operation_ids(app)
|
||||
|
||||
if settings.PROFILING:
    # Profiling-only dependencies are imported here so they are needed only
    # when the PROFILING setting is switched on.
    from fastapi import Request
    from fastapi.responses import HTMLResponse
    from pyinstrument import Profiler

    @app.middleware("http")
    async def profile_request(request: Request, call_next):
        """Optionally profile a request and return the profiler report.

        When the request carries a non-empty ``profile`` query parameter,
        the downstream handler runs under a pyinstrument ``Profiler`` and
        the profiler's HTML report is returned *instead of* the handler's
        own response. Otherwise the request is passed through untouched.
        """
        if not request.query_params.get("profile", False):
            return await call_next(request)

        profiler = Profiler(async_mode="enabled")
        profiler.start()
        try:
            # The handler's response is deliberately discarded: the HTML
            # report replaces it.
            await call_next(request)
        finally:
            # Stop even when the handler raises, so a failing request does
            # not leave a profiler running.
            # NOTE(review): pyinstrument is understood to reject starting a
            # second profiler while one is active — confirm.
            profiler.stop()
        return HTMLResponse(profiler.output_html())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
@@ -11,7 +11,6 @@ from pydantic import BaseModel, Field
|
||||
from reflector.db import database, metadata
|
||||
from reflector.processors.types import Word as ProcessorWord
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
|
||||
transcripts = sqlalchemy.Table(
|
||||
"transcript",
|
||||
@@ -86,6 +85,14 @@ class TranscriptFinalTitle(BaseModel):
|
||||
title: str
|
||||
|
||||
|
||||
class TranscriptDuration(BaseModel):
    """Payload of the ``DURATION`` transcript event."""

    # Length of the recorded audio; presumably seconds — TODO confirm.
    duration: float
|
||||
|
||||
|
||||
class TranscriptWaveform(BaseModel):
    """Payload of the ``WAVEFORM`` transcript event."""

    # One value per waveform segment, as produced by the waveform processor.
    waveform: list[float]
|
||||
|
||||
|
||||
class TranscriptEvent(BaseModel):
|
||||
event: str
|
||||
data: dict
|
||||
@@ -126,22 +133,6 @@ class Transcript(BaseModel):
|
||||
def topics_dump(self, mode="json"):
|
||||
return [topic.model_dump(mode=mode) for topic in self.topics]
|
||||
|
||||
def convert_audio_to_waveform(self, segments_count=256):
    """Compute the waveform of the MP3 audio and store it as JSON.

    Nothing is done (and ``None`` is returned) when the waveform file
    already exists. Otherwise the waveform is computed from
    ``self.audio_mp3_filename``, written to
    ``self.audio_waveform_filename`` and returned.

    Any error raised while writing removes the partially written file
    before being re-raised.
    """
    target = self.audio_waveform_filename
    if not target.exists():
        samples = get_audio_waveform(
            path=self.audio_mp3_filename,
            segments_count=segments_count,
        )
        try:
            with target.open("w") as out:
                json.dump(samples, out)
        except Exception:
            # Do not leave a truncated/corrupt waveform file behind.
            target.unlink(missing_ok=True)
            raise
        return samples
    return None
|
||||
|
||||
def unlink(self):
|
||||
self.data_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
import httpx
|
||||
|
||||
from reflector.llm.base import LLM
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.retry import retry
|
||||
|
||||
|
||||
class BananaLLM(LLM):
    """LLM backend that forwards prompts to ``settings.LLM_URL`` over HTTP
    using the Banana API/model key headers."""

    def __init__(self):
        super().__init__()
        self.headers = {
            "X-Banana-API-Key": settings.LLM_BANANA_API_KEY,
            "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
        }
        self.timeout = settings.LLM_TIMEOUT

    async def _generate(
        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
    ):
        """POST the prompt (plus optional schema/config) and return the
        ``"text"`` field of the JSON response.

        Raises whatever ``response.raise_for_status()`` raises on a
        non-success HTTP status.
        """
        # Optional fields are only sent when they are truthy.
        optional_fields = {"gen_schema": gen_schema, "gen_cfg": gen_cfg}
        body = {"prompt": prompt}
        body.update(
            {field: value for field, value in optional_fields.items() if value}
        )

        async with httpx.AsyncClient() as http:
            post_with_retry = retry(http.post)
            reply = await post_with_retry(
                settings.LLM_URL,
                headers=self.headers,
                json=body,
                timeout=self.timeout,
                retry_timeout=300,  # as per their sdk
            )
            reply.raise_for_status()
            return reply.json()["text"]
|
||||
|
||||
|
||||
LLM.register("banana", BananaLLM)
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: send one prompt through BananaLLM and print the result.
    import asyncio

    from reflector.logger import logger

    async def main():
        backend = BananaLLM()
        joke_prompt = backend.create_prompt(
            instruct="Complete the following task",
            text="Tell me a joke about programming.",
        )
        print(await backend.generate(prompt=joke_prompt, logger=logger))

    asyncio.run(main())
|
||||
@@ -21,11 +21,13 @@ from pydantic import BaseModel
|
||||
from reflector.app import app
|
||||
from reflector.db.transcripts import (
|
||||
Transcript,
|
||||
TranscriptDuration,
|
||||
TranscriptFinalLongSummary,
|
||||
TranscriptFinalShortSummary,
|
||||
TranscriptFinalTitle,
|
||||
TranscriptText,
|
||||
TranscriptTopic,
|
||||
TranscriptWaveform,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.logger import logger
|
||||
@@ -45,6 +47,7 @@ from reflector.processors import (
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorProcessor,
|
||||
)
|
||||
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
|
||||
from reflector.processors.types import AudioDiarizationInput
|
||||
from reflector.processors.types import (
|
||||
TitleSummaryWithId as TitleSummaryWithIdProcessorType,
|
||||
@@ -230,6 +233,33 @@ class PipelineMainBase(PipelineRunner):
|
||||
data=final_short_summary,
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
async def on_duration(self, data):
    """Persist the audio duration and record a ``DURATION`` event.

    ``data`` is wrapped in a ``TranscriptDuration``; the transcript's
    ``duration`` field is updated and the event returned by
    ``append_event`` is handed back to the caller.
    """
    async with self.transaction():
        payload = TranscriptDuration(duration=data)
        transcript = await self.get_transcript()

        changes = {"duration": payload.duration}
        await transcripts_controller.update(transcript, changes)

        event = await transcripts_controller.append_event(
            transcript=transcript, event="DURATION", data=payload
        )
        return event
|
||||
|
||||
@broadcast_to_sockets
async def on_waveform(self, data):
    """Record a ``WAVEFORM`` event carrying the computed waveform.

    ``data`` is wrapped in a ``TranscriptWaveform`` and appended to the
    transcript's events; the created event is returned.
    """
    async with self.transaction():
        payload = TranscriptWaveform(waveform=data)
        transcript = await self.get_transcript()
        event = await transcripts_controller.append_event(
            transcript=transcript, event="WAVEFORM", data=payload
        )
        return event
|
||||
|
||||
|
||||
class PipelineMainLive(PipelineMainBase):
|
||||
audio_filename: Path | None = None
|
||||
@@ -243,7 +273,10 @@ class PipelineMainLive(PipelineMainBase):
|
||||
transcript = await self.get_transcript()
|
||||
|
||||
processors = [
|
||||
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
|
||||
AudioFileWriterProcessor(
|
||||
path=transcript.audio_mp3_filename,
|
||||
on_duration=self.on_duration,
|
||||
),
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
@@ -253,6 +286,11 @@ class PipelineMainLive(PipelineMainBase):
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
|
||||
AudioWaveformProcessor.as_threaded(
|
||||
audio_path=transcript.audio_mp3_filename,
|
||||
waveform_path=transcript.audio_waveform_filename,
|
||||
on_waveform=self.on_waveform,
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
@@ -285,8 +323,13 @@ class PipelineMainDiarization(PipelineMainBase):
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
self.prepare()
|
||||
processors = [
|
||||
AudioDiarizationAutoProcessor(callback=self.on_topic),
|
||||
processors = []
|
||||
if settings.DIARIZATION_ENABLED:
|
||||
processors += [
|
||||
AudioDiarizationAutoProcessor(callback=self.on_topic),
|
||||
]
|
||||
|
||||
processors += [
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalLongSummaryProcessor.as_threaded(
|
||||
|
||||
@@ -12,8 +12,8 @@ class AudioFileWriterProcessor(Processor):
|
||||
INPUT_TYPE = av.AudioFrame
|
||||
OUTPUT_TYPE = av.AudioFrame
|
||||
|
||||
def __init__(self, path: Path | str):
|
||||
super().__init__()
|
||||
def __init__(self, path: Path | str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
if path.suffix not in (".mp3", ".wav"):
|
||||
@@ -21,6 +21,7 @@ class AudioFileWriterProcessor(Processor):
|
||||
self.path = path
|
||||
self.out_container = None
|
||||
self.out_stream = None
|
||||
self.last_packet = None
|
||||
|
||||
async def _push(self, data: av.AudioFrame):
|
||||
if not self.out_container:
|
||||
@@ -40,12 +41,30 @@ class AudioFileWriterProcessor(Processor):
|
||||
raise ValueError("Only mp3 and wav files are supported")
|
||||
for packet in self.out_stream.encode(data):
|
||||
self.out_container.mux(packet)
|
||||
self.last_packet = packet
|
||||
await self.emit(data)
|
||||
|
||||
async def _flush(self):
    """Finalize the output file and emit the total audio duration.

    Drains the encoder, muxes the remaining packets, closes the container
    and, when a positive duration could be determined, emits it under the
    ``"duration"`` event name. Does nothing if no container was opened.
    """
    if not self.out_container:
        return

    # Calling encode() without a frame flushes the packets still buffered
    # in the encoder.
    for packet in self.out_stream.encode():
        self.out_container.mux(packet)
        self.last_packet = packet

    # Default to 0 so the check below is always defined, including when
    # no packet was ever produced.
    duration = 0
    try:
        last = self.last_packet
        if last is not None:
            # pts and duration are both expressed in time_base units, so
            # the end of the stream is (pts + duration) * time_base.
            duration = round(
                float((last.pts + last.duration) * last.time_base),
                2,
            )
    except Exception:
        # Best effort: a missing pts/duration must not prevent closing
        # the file.
        self.logger.exception("Failed to get duration")
        duration = 0

    self.out_container.close()
    self.out_container = None
    self.out_stream = None

    if duration > 0:
        await self.emit(duration, name="duration")
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
"""
|
||||
Implementation using the GPU service from banana.
|
||||
|
||||
API will be a POST request to TRANSCRIPT_URL:
|
||||
|
||||
```json
|
||||
{
    "audio_url": "https://...",
    "audio_ext": "wav",
    "timestamp": 123.456,
    "language": "en"
}
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
|
||||
from reflector.processors.types import AudioFile, Transcript, Word
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import Storage
|
||||
from reflector.utils.retry import retry
|
||||
|
||||
|
||||
class AudioTranscriptBananaProcessor(AudioTranscriptProcessor):
    """Transcribe audio by uploading it to storage and calling the remote
    service at ``settings.TRANSCRIPT_URL`` with Banana API headers."""

    def __init__(self, banana_api_key: str, banana_model_key: str):
        """
        :param banana_api_key: value sent as ``X-Banana-API-Key``
        :param banana_model_key: value sent as ``X-Banana-Model-Key``
        """
        super().__init__()
        self.transcript_url = settings.TRANSCRIPT_URL
        self.timeout = settings.TRANSCRIPT_TIMEOUT
        self.storage = Storage.get_instance(
            settings.TRANSCRIPT_STORAGE_BACKEND, "TRANSCRIPT_STORAGE_"
        )
        self.headers = {
            "X-Banana-API-Key": banana_api_key,
            "X-Banana-Model-Key": banana_model_key,
        }

    async def _transcript(self, data: AudioFile):
        """Upload ``data``, request its transcription and return a
        ``Transcript`` built from the response's ``text`` and ``words``.

        The uploaded file is deleted from storage after a successful
        transcription. HTTP errors surface through ``raise_for_status()``.
        """
        async with httpx.AsyncClient() as client:
            print(f"Uploading audio {data.path.name} to S3")
            url = await self._upload_file(data.path)

            print(f"Try to transcribe audio {data.path.name}")
            request_data = {
                "audio_url": url,
                # suffix without the leading dot, e.g. "wav"
                "audio_ext": data.path.suffix[1:],
                "timestamp": float(round(data.timestamp, 2)),
            }
            response = await retry(client.post)(
                self.transcript_url,
                json=request_data,
                headers=self.headers,
                timeout=self.timeout,
            )

            print(f"Transcript response: {response.status_code} {response.content}")
            response.raise_for_status()
            result = response.json()
            transcript = Transcript(
                text=result["text"],
                words=[
                    Word(text=word["text"], start=word["start"], end=word["end"])
                    for word in result["words"]
                ],
            )

            # remove audio file from S3
            await self._delete_file(data.path)

            return transcript

    @retry
    async def _upload_file(self, path: Path) -> str:
        """Upload the file at ``path`` under its base name and return the
        resulting URL."""
        # Open inside a context manager so the handle is closed on every
        # attempt, including failed ones that get retried.
        with open(path, "rb") as fd:
            upload_result = await self.storage.put_file(path.name, fd)
        return upload_result.url

    @retry
    async def _delete_file(self, path: Path):
        """Delete the stored object named after ``path``; returns ``True``."""
        await self.storage.delete_file(path.name)
        return True
|
||||
|
||||
|
||||
AudioTranscriptAutoProcessor.register("banana", AudioTranscriptBananaProcessor)
|
||||
36
server/reflector/processors/audio_waveform_processor.py
Normal file
36
server/reflector/processors/audio_waveform_processor.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import TitleSummary
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
|
||||
|
||||
class AudioWaveformProcessor(Processor):
    """
    Write the waveform for the final audio

    Incoming data is ignored; all the work happens on flush, when the
    waveform of ``audio_path`` is computed, stored as JSON in
    ``waveform_path`` and emitted under the ``"waveform"`` event name.
    """

    INPUT_TYPE = TitleSummary

    def __init__(
        self, audio_path: Path | str, waveform_path: Path | str, **kwargs
    ):
        """
        :param audio_path: mp3 or wav file to analyse
        :param waveform_path: destination of the JSON waveform
        :param kwargs: forwarded to ``Processor`` (e.g. ``on_waveform``)
        :raises ValueError: if ``audio_path`` is neither ``.mp3`` nor ``.wav``
        """
        super().__init__(**kwargs)
        audio_path = Path(audio_path)
        if audio_path.suffix not in (".mp3", ".wav"):
            raise ValueError("Only mp3 and wav files are supported")
        self.audio_path = audio_path
        # _flush relies on Path methods (.parent), so normalize strings
        # here instead of failing later with an AttributeError.
        self.waveform_path = Path(waveform_path)

    async def _flush(self):
        """Compute, persist and emit the waveform."""
        self.waveform_path.parent.mkdir(parents=True, exist_ok=True)
        self.logger.info("Waveform Processing Started")
        waveform = get_audio_waveform(path=self.audio_path, segments_count=255)

        with open(self.waveform_path, "w") as fd:
            json.dump(waveform, fd)
        self.logger.info("Waveform Processing Finished")
        await self.emit(waveform, name="waveform")

    async def _push(self, _data):
        """Ignore pushed data; the waveform is only produced on flush."""
        return
|
||||
@@ -14,7 +14,42 @@ class PipelineEvent(BaseModel):
|
||||
data: Any
|
||||
|
||||
|
||||
class Processor:
|
||||
class Emitter:
    """Minimal asynchronous event emitter.

    Callbacks are coroutine functions registered under an event name
    (``"default"`` unless stated otherwise) and awaited sequentially, in
    registration order, when that event is emitted.
    """

    def __init__(self, **kwargs):
        """Register every ``on_<name>=callback`` keyword argument as a
        callback for event ``<name>``; other keyword arguments are ignored.
        """
        self._callbacks = {}

        # register callbacks from kwargs (on_*)
        for key, value in kwargs.items():
            if key.startswith("on_"):
                self.on(value, name=key[3:])

    def on(self, callback, name="default"):
        """
        Register a callback to be called when data is emitted

        :raises ValueError: if ``callback`` is not a coroutine function
        """
        # ensure callback is asynchronous
        if not asyncio.iscoroutinefunction(callback):
            raise ValueError("Callback must be a coroutine function")
        self._callbacks.setdefault(name, []).append(callback)

    def off(self, callback, name="default"):
        """
        Unregister a callback to be called when data is emitted

        Unknown event names are ignored. Removing a callback that is not
        registered under a known name raises ``ValueError`` (from
        ``list.remove``).
        """
        if name not in self._callbacks:
            return
        self._callbacks[name].remove(callback)

    async def emit(self, data, name="default"):
        """Await every callback registered for ``name`` with ``data``.

        Emitting an event nobody listens to is a no-op.
        """
        # Iterate over a snapshot: a callback may call on()/off() while the
        # event is being dispatched, and mutating the list being iterated
        # would silently skip the next callback.
        for callback in tuple(self._callbacks.get(name, ())):
            await callback(data)
|
||||
|
||||
|
||||
class Processor(Emitter):
|
||||
INPUT_TYPE: type = None
|
||||
OUTPUT_TYPE: type = None
|
||||
|
||||
@@ -59,7 +94,8 @@ class Processor:
|
||||
["processor"],
|
||||
)
|
||||
|
||||
def __init__(self, callback=None, custom_logger=None):
|
||||
def __init__(self, callback=None, custom_logger=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.name = name = self.__class__.__name__
|
||||
self.m_processor = self.m_processor.labels(name)
|
||||
self.m_processor_call = self.m_processor_call.labels(name)
|
||||
@@ -70,9 +106,11 @@ class Processor:
|
||||
self.m_processor_flush_success = self.m_processor_flush_success.labels(name)
|
||||
self.m_processor_flush_failure = self.m_processor_flush_failure.labels(name)
|
||||
self._processors = []
|
||||
self._callbacks = []
|
||||
|
||||
# register callbacks
|
||||
if callback:
|
||||
self.on(callback)
|
||||
|
||||
self.uid = uuid4().hex
|
||||
self.flushed = False
|
||||
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
|
||||
@@ -100,21 +138,6 @@ class Processor:
|
||||
"""
|
||||
self._processors.remove(processor)
|
||||
|
||||
def on(self, callback):
|
||||
"""
|
||||
Register a callback to be called when data is emitted
|
||||
"""
|
||||
# ensure callback is asynchronous
|
||||
if not asyncio.iscoroutinefunction(callback):
|
||||
raise ValueError("Callback must be a coroutine function")
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def off(self, callback):
|
||||
"""
|
||||
Unregister a callback to be called when data is emitted
|
||||
"""
|
||||
self._callbacks.remove(callback)
|
||||
|
||||
def get_pref(self, key: str, default: Any = None):
|
||||
"""
|
||||
Get a preference from the pipeline prefs
|
||||
@@ -123,15 +146,16 @@ class Processor:
|
||||
return self.pipeline.get_pref(key, default)
|
||||
return default
|
||||
|
||||
async def emit(self, data):
|
||||
if self.pipeline:
|
||||
await self.pipeline.emit(
|
||||
PipelineEvent(processor=self.name, uid=self.uid, data=data)
|
||||
)
|
||||
for callback in self._callbacks:
|
||||
await callback(data)
|
||||
for processor in self._processors:
|
||||
await processor.push(data)
|
||||
async def emit(self, data, name="default"):
    """Emit ``data`` for event ``name``.

    For the ``"default"`` event the data is first reported to the
    pipeline (when one is attached), then delivered to the registered
    callbacks, and finally pushed to every connected processor. Any other
    event name is only delivered to its registered callbacks.
    """
    is_default = name == "default"

    if is_default and self.pipeline:
        event = PipelineEvent(processor=self.name, uid=self.uid, data=data)
        await self.pipeline.emit(event)

    await super().emit(data, name=name)

    if not is_default:
        return
    for downstream in self._processors:
        await downstream.push(data)
|
||||
|
||||
async def push(self, data):
|
||||
"""
|
||||
@@ -254,11 +278,11 @@ class ThreadedProcessor(Processor):
|
||||
def disconnect(self, processor: Processor):
|
||||
self.processor.disconnect(processor)
|
||||
|
||||
def on(self, callback):
|
||||
self.processor.on(callback)
|
||||
def on(self, callback, name="default"):
|
||||
self.processor.on(callback, name=name)
|
||||
|
||||
def off(self, callback):
|
||||
self.processor.off(callback)
|
||||
def off(self, callback, name="default"):
|
||||
self.processor.off(callback, name=name)
|
||||
|
||||
def describe(self, level=0):
|
||||
super().describe(level)
|
||||
@@ -305,13 +329,13 @@ class BroadcastProcessor(Processor):
|
||||
for processor in self.processors:
|
||||
processor.disconnect(processor)
|
||||
|
||||
def on(self, callback):
|
||||
def on(self, callback, name="default"):
|
||||
for processor in self.processors:
|
||||
processor.on(callback)
|
||||
processor.on(callback, name=name)
|
||||
|
||||
def off(self, callback):
|
||||
def off(self, callback, name="default"):
|
||||
for processor in self.processors:
|
||||
processor.off(callback)
|
||||
processor.off(callback, name=name)
|
||||
|
||||
def describe(self, level=0):
|
||||
super().describe(level)
|
||||
|
||||
@@ -16,6 +16,7 @@ class TranscriptTranslatorProcessor(Processor):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.transcript = None
|
||||
self.translate_url = settings.TRANSLATE_URL
|
||||
self.timeout = settings.TRANSLATE_TIMEOUT
|
||||
self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"}
|
||||
|
||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
|
||||
from profanityfilter import ProfanityFilter
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
from reflector.redis_cache import redis_cache
|
||||
|
||||
PUNC_RE = re.compile(r"[.;:?!…]")
|
||||
|
||||
@@ -68,10 +69,14 @@ class Transcript(BaseModel):
|
||||
# Uncensored text
|
||||
return "".join([word.text for word in self.words])
|
||||
|
||||
@redis_cache(prefix="profanity", duration=3600 * 24 * 7)
def _get_censored_text(self, text: str):
    """Return ``text`` with profanity censored and surrounding whitespace
    stripped. Wrapped by ``redis_cache`` under the ``profanity`` prefix."""
    censored = profanity_filter.censor(text)
    return censored.strip()
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
# Censored text
|
||||
return profanity_filter.censor(self.raw_text).strip()
|
||||
return self._get_censored_text(self.raw_text)
|
||||
|
||||
@property
|
||||
def human_timestamp(self):
|
||||
|
||||
50
server/reflector/redis_cache.py
Normal file
50
server/reflector/redis_cache.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import functools
|
||||
import json
|
||||
|
||||
import redis
|
||||
from reflector.settings import settings
|
||||
|
||||
redis_clients = {}
|
||||
|
||||
|
||||
def get_redis_client(db=0):
    """
    Return the Redis client for database ``db``.

    Clients are created on first use from ``settings.REDIS_HOST`` /
    ``settings.REDIS_PORT`` and then reused from the module-level
    ``redis_clients`` mapping.
    """
    try:
        return redis_clients[db]
    except KeyError:
        client = redis.StrictRedis(
            host=settings.REDIS_HOST,
            port=settings.REDIS_PORT,
            db=db,
        )
        redis_clients[db] = client
        return client
|
||||
|
||||
|
||||
def redis_cache(prefix="cache", duration=3600, db=settings.REDIS_CACHE_DB, argidx=1):
    """
    Decorator caching a function's JSON-serializable result in Redis.

    The cache key is ``"<prefix>:<arg>"`` where ``<arg>`` is the positional
    argument at index ``argidx`` (index 1 by default, i.e. the first
    argument after ``self`` for a method). When that argument is missing
    or is not a string, the function is called without any caching.
    Entries expire after ``duration`` seconds.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Only a positional string argument at ``argidx`` can be used
            # as a cache key; anything else bypasses the cache.
            cacheable = len(args) > argidx and isinstance(args[argidx], str)
            if not cacheable:
                return func(*args, **kwargs)

            key = ":".join((prefix, args[argidx]))
            client = get_redis_client(db=db)
            hit = client.get(key)

            if not hit:
                # Cache miss: compute, store with expiry, return.
                value = func(*args, **kwargs)
                client.setex(key, duration, json.dumps(value))
                return value

            return json.loads(hit.decode("utf-8"))

        return wrapper

    return decorator
|
||||
@@ -41,7 +41,7 @@ class Settings(BaseSettings):
|
||||
AUDIO_BUFFER_SIZE: int = 256 * 960
|
||||
|
||||
# Audio Transcription
|
||||
# backends: whisper, banana, modal
|
||||
# backends: whisper, modal
|
||||
TRANSCRIPT_BACKEND: str = "whisper"
|
||||
TRANSCRIPT_URL: str | None = None
|
||||
TRANSCRIPT_TIMEOUT: int = 90
|
||||
@@ -50,10 +50,6 @@ class Settings(BaseSettings):
|
||||
TRANSLATE_URL: str | None = None
|
||||
TRANSLATE_TIMEOUT: int = 90
|
||||
|
||||
# Audio transcription banana.dev configuration
|
||||
TRANSCRIPT_BANANA_API_KEY: str | None = None
|
||||
TRANSCRIPT_BANANA_MODEL_KEY: str | None = None
|
||||
|
||||
# Audio transcription modal.com configuration
|
||||
TRANSCRIPT_MODAL_API_KEY: str | None = None
|
||||
|
||||
@@ -61,13 +57,16 @@ class Settings(BaseSettings):
|
||||
TRANSCRIPT_STORAGE_BACKEND: str = "aws"
|
||||
|
||||
# Storage configuration for AWS
|
||||
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket/chunks"
|
||||
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
|
||||
TRANSCRIPT_STORAGE_AWS_REGION: str = "us-east-1"
|
||||
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
|
||||
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
|
||||
|
||||
# Transcript MP3 storage
|
||||
TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
|
||||
|
||||
# LLM
|
||||
# available backend: openai, banana, modal, oobabooga
|
||||
# available backend: openai, modal, oobabooga
|
||||
LLM_BACKEND: str = "oobabooga"
|
||||
|
||||
# LLM common configuration
|
||||
@@ -82,14 +81,11 @@ class Settings(BaseSettings):
|
||||
LLM_TEMPERATURE: float = 0.7
|
||||
ZEPHYR_LLM_URL: str | None = None
|
||||
|
||||
# LLM Banana configuration
|
||||
LLM_BANANA_API_KEY: str | None = None
|
||||
LLM_BANANA_MODEL_KEY: str | None = None
|
||||
|
||||
# LLM Modal configuration
|
||||
LLM_MODAL_API_KEY: str | None = None
|
||||
|
||||
# Diarization
|
||||
DIARIZATION_ENABLED: bool = True
|
||||
DIARIZATION_BACKEND: str = "modal"
|
||||
DIARIZATION_URL: str | None = None
|
||||
|
||||
@@ -124,6 +120,7 @@ class Settings(BaseSettings):
|
||||
# Redis
|
||||
REDIS_HOST: str = "localhost"
|
||||
REDIS_PORT: int = 6379
|
||||
REDIS_CACHE_DB: int = 2
|
||||
|
||||
# Secret key
|
||||
SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21"
|
||||
@@ -131,5 +128,8 @@ class Settings(BaseSettings):
|
||||
# Current hosting/domain
|
||||
BASE_URL: str = "http://localhost:1250"
|
||||
|
||||
# Profiling
|
||||
PROFILING: bool = False
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from typing import BinaryIO
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi import HTTPException, Request, Response, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
@@ -57,6 +57,9 @@ def range_requests_response(
|
||||
),
|
||||
}
|
||||
|
||||
if request.method == "HEAD":
|
||||
return Response(headers=headers)
|
||||
|
||||
if content_disposition:
|
||||
headers["Content-Disposition"] = content_disposition
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ from reflector.db.transcripts import (
|
||||
from reflector.processors.types import Transcript as ProcessorTranscript
|
||||
from reflector.settings import settings
|
||||
from reflector.ws_manager import get_ws_manager
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from ._range_requests_response import range_requests_response
|
||||
from .rtc_offer import RtcOffer, rtc_offer_base
|
||||
@@ -53,7 +52,7 @@ class GetTranscript(BaseModel):
|
||||
name: str
|
||||
status: str
|
||||
locked: bool
|
||||
duration: int
|
||||
duration: float
|
||||
title: str | None
|
||||
short_summary: str | None
|
||||
long_summary: str | None
|
||||
@@ -222,6 +221,7 @@ async def transcript_delete(
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/audio/mp3")
|
||||
@router.head("/transcripts/{transcript_id}/audio/mp3")
|
||||
async def transcript_get_audio_mp3(
|
||||
request: Request,
|
||||
transcript_id: str,
|
||||
@@ -272,8 +272,6 @@ async def transcript_get_audio_waveform(
|
||||
if not transcript.audio_mp3_filename.exists():
|
||||
raise HTTPException(status_code=500, detail="Audio not found")
|
||||
|
||||
await run_in_threadpool(transcript.convert_audio_to_waveform)
|
||||
|
||||
return transcript.audio_waveform
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user