Feature additions (#210)

* initial

* add LLM features

* update LLM logic

* update llm functions: change control flow

* add generation config

* update return types

* update processors and tests

* update rtc_offer

* revert new title processor change

* fix unit tests

* add comments and fix HTTP 500

* adjust prompt

* test with reflector app

* revert new event for final title

* update

* move onus onto processors

* move onus onto processors

* stash

* add provision for gen config

* dynamically pack the LLM input using context length

* tune final summary params

* update consolidated class structures

* update consolidated class structures

* update precommit

* add broadcast processors

* working baseline

* Organize LLMParams

* minor fixes

* minor fixes

* minor fixes

* fix unit tests

* fix unit tests

* fix unit tests

* update tests

* update tests

* edit pipeline response events

* update summary return types

* configure tests

* alembic db migration

* change LLM response flow

* edit main llm functions

* edit main llm functions

* change llm name and gen cf

* Update transcript_topic_detector.py

* PR review comments

* checkpoint before db event migration

* update DB migration of past events

* update DB migration of past events

* edit LLM classes

* Delete unwanted file

* remove List typing

* remove List typing

* update oobabooga API call

* topic enhancements

* update UI event handling

* move ensure_casing to llm base

* update tests

* update tests

Author: projects-g
Date: 2023-09-13 11:26:08 +05:30 (committed by GitHub)
Commit: 9fe261406c (parent: 762d7bfc3c)
33 changed files with 1334 additions and 202 deletions

View File: rtc_offer.py

@@ -8,6 +8,7 @@ from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from fastapi import APIRouter, Request
from prometheus_client import Gauge
from pydantic import BaseModel
from reflector.events import subscribers_shutdown
from reflector.logger import logger
from reflector.processors import (
@@ -15,14 +16,19 @@ from reflector.processors import (
    AudioFileWriterProcessor,
    AudioMergeProcessor,
    AudioTranscriptAutoProcessor,
    FinalSummary,
    FinalLongSummary,
    FinalShortSummary,
    Pipeline,
    TitleSummary,
    Transcript,
    TranscriptFinalSummaryProcessor,
    TranscriptFinalLongSummaryProcessor,
    TranscriptFinalShortSummaryProcessor,
    TranscriptFinalTitleProcessor,
    TranscriptLinerProcessor,
    TranscriptTopicDetectorProcessor,
)
from reflector.processors.base import BroadcastProcessor
from reflector.processors.types import FinalTitle

sessions = []
router = APIRouter()
@@ -72,8 +78,10 @@ class StrValue(BaseModel):
class PipelineEvent(StrEnum):
    TRANSCRIPT = "TRANSCRIPT"
    TOPIC = "TOPIC"
    FINAL_SUMMARY = "FINAL_SUMMARY"
    FINAL_LONG_SUMMARY = "FINAL_LONG_SUMMARY"
    STATUS = "STATUS"
    FINAL_SHORT_SUMMARY = "FINAL_SHORT_SUMMARY"
    FINAL_TITLE = "FINAL_TITLE"


async def rtc_offer_base(
@@ -124,15 +132,15 @@ async def rtc_offer_base(
                data=transcript,
            )

    async def on_topic(summary: TitleSummary):
    async def on_topic(topic: TitleSummary):
        # FIXME: make it incremental with the frontend, not send everything
        ctx.logger.info("Summary", summary=summary)
        ctx.logger.info("Topic", topic=topic)
        ctx.topics.append(
            {
                "title": summary.title,
                "timestamp": summary.timestamp,
                "transcript": summary.transcript.text,
                "desc": summary.summary,
                "title": topic.title,
                "timestamp": topic.timestamp,
                "transcript": topic.transcript.text,
                "desc": topic.summary,
            }
        )
@@ -144,17 +152,17 @@ async def rtc_offer_base(
        # send to callback (eg. websocket)
        if event_callback:
            await event_callback(
                event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
                event=PipelineEvent.TOPIC, args=event_callback_args, data=topic
            )

    async def on_final_summary(summary: FinalSummary):
        ctx.logger.info("FinalSummary", final_summary=summary)
    async def on_final_short_summary(summary: FinalShortSummary):
        ctx.logger.info("FinalShortSummary", final_short_summary=summary)
        # send to RTC
        if ctx.data_channel.readyState == "open":
            result = {
                "cmd": "DISPLAY_FINAL_SUMMARY",
                "summary": summary.summary,
                "cmd": "DISPLAY_FINAL_SHORT_SUMMARY",
                "summary": summary.short_summary,
                "duration": summary.duration,
            }
            ctx.data_channel.send(dumps(result))
@@ -162,11 +170,47 @@ async def rtc_offer_base(
        # send to callback (eg. websocket)
        if event_callback:
            await event_callback(
                event=PipelineEvent.FINAL_SUMMARY,
                event=PipelineEvent.FINAL_SHORT_SUMMARY,
                args=event_callback_args,
                data=summary,
            )

    async def on_final_long_summary(summary: FinalLongSummary):
        ctx.logger.info("FinalLongSummary", final_summary=summary)
        # send to RTC
        if ctx.data_channel.readyState == "open":
            result = {
                "cmd": "DISPLAY_FINAL_LONG_SUMMARY",
                "summary": summary.long_summary,
                "duration": summary.duration,
            }
            ctx.data_channel.send(dumps(result))
        # send to callback (eg. websocket)
        if event_callback:
            await event_callback(
                event=PipelineEvent.FINAL_LONG_SUMMARY,
                args=event_callback_args,
                data=summary,
            )

    async def on_final_title(title: FinalTitle):
        ctx.logger.info("FinalTitle", final_title=title)
        # send to RTC
        if ctx.data_channel.readyState == "open":
            result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title}
            ctx.data_channel.send(dumps(result))
        # send to callback (eg. websocket)
        if event_callback:
            await event_callback(
                event=PipelineEvent.FINAL_TITLE,
                args=event_callback_args,
                data=title,
            )

    # create a context for the whole rtc transaction
    # add a customised logger to the context
    processors = []
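
The callbacks above only touch a handful of attributes on the new processor output types (FinalTitle comes from reflector.processors.types, FinalLongSummary and FinalShortSummary from the reflector.processors import block). Their definitions are not part of this diff, so the following is only a minimal sketch of the assumed shapes; whether the real types are pydantic models or dataclasses is not visible in this excerpt.

# Hypothetical reconstruction of the output types used by the callbacks above;
# only attributes actually accessed in this diff are shown.
from pydantic import BaseModel

class FinalTitle(BaseModel):
    title: str

class FinalLongSummary(BaseModel):
    long_summary: str
    duration: float  # assumed numeric; sent alongside the summary for display

class FinalShortSummary(BaseModel):
    short_summary: str
    duration: float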
@@ -178,7 +222,17 @@ async def rtc_offer_base(
        AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
        TranscriptLinerProcessor(),
        TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
        TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
        BroadcastProcessor(
            processors=[
                TranscriptFinalTitleProcessor.as_threaded(callback=on_final_title),
                TranscriptFinalLongSummaryProcessor.as_threaded(
                    callback=on_final_long_summary
                ),
                TranscriptFinalShortSummaryProcessor.as_threaded(
                    callback=on_final_short_summary
                ),
            ]
        ),
    ]
    ctx.pipeline = Pipeline(*processors)
    ctx.pipeline.set_pref("audio:source_language", source_language)
View File

@@ -7,7 +7,6 @@ from typing import Annotated, Optional
from uuid import uuid4
import av
import reflector.auth as auth
from fastapi import (
    APIRouter,
    Depends,
@@ -18,11 +17,13 @@ from fastapi import (
)
from fastapi_pagination import Page, paginate
from pydantic import BaseModel, Field
from starlette.concurrency import run_in_threadpool
import reflector.auth as auth
from reflector.db import database, transcripts
from reflector.logger import logger
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
@@ -60,8 +61,16 @@ class TranscriptTopic(BaseModel):
    timestamp: float

class TranscriptFinalSummary(BaseModel):
    summary: str

class TranscriptFinalShortSummary(BaseModel):
    short_summary: str

class TranscriptFinalLongSummary(BaseModel):
    long_summary: str

class TranscriptFinalTitle(BaseModel):
    title: str

class TranscriptEvent(BaseModel):

@@ -77,7 +86,9 @@ class Transcript(BaseModel):
    locked: bool = False
    duration: float = 0
    created_at: datetime = Field(default_factory=datetime.utcnow)
    summary: str | None = None
    title: str | None = None
    short_summary: str | None = None
    long_summary: str | None = None
    topics: list[TranscriptTopic] = []
    events: list[TranscriptEvent] = []
    source_language: str = "en"
@@ -241,7 +252,9 @@ class GetTranscript(BaseModel):
    status: str
    locked: bool
    duration: int
    summary: str | None
    title: str | None
    short_summary: str | None
    long_summary: str | None
    created_at: datetime
    source_language: str
    target_language: str

@@ -256,7 +269,9 @@ class CreateTranscript(BaseModel):
class UpdateTranscript(BaseModel):
    name: Optional[str] = Field(None)
    locked: Optional[bool] = Field(None)
    summary: Optional[str] = Field(None)
    title: Optional[str] = Field(None)
    short_summary: Optional[str] = Field(None)
    long_summary: Optional[str] = Field(None)

class DeletionStatus(BaseModel):
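
With UpdateTranscript extended this way, a client that previously edited a single summary field now sends the three values separately. A hypothetical request body is sketched below; the field names follow the UpdateTranscript model above, but the concrete values and the route path are not part of this excerpt.

# Hypothetical transcript-update payload; any subset of the optional fields may be supplied.
payload = {
    "name": "Weekly sync",
    "title": "Weekly sync - roadmap review",
    "short_summary": "Quick recap of the roadmap discussion.",
    "long_summary": "The team walked through the Q4 roadmap in detail...",
}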
@@ -315,20 +330,32 @@ async def transcript_update(
    transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")

    values = {}
    values = {"events": []}
    if info.name is not None:
        values["name"] = info.name
    if info.locked is not None:
        values["locked"] = info.locked
    if info.summary is not None:
        values["summary"] = info.summary
        # also find FINAL_SUMMARY event and patch it
        for te in transcript.events:
            if te["event"] == PipelineEvent.FINAL_SUMMARY:
                te["summary"] = info.summary
    if info.long_summary is not None:
        values["long_summary"] = info.long_summary
        for transcript_event in transcript.events:
            if transcript_event["event"] == PipelineEvent.FINAL_LONG_SUMMARY:
                transcript_event["long_summary"] = info.long_summary
                break
        values["events"] = transcript.events
        values["events"].extend(transcript.events)
    if info.short_summary is not None:
        values["short_summary"] = info.short_summary
        for transcript_event in transcript.events:
            if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY:
                transcript_event["short_summary"] = info.short_summary
                break
        values["events"].extend(transcript.events)
    if info.title is not None:
        values["title"] = info.title
        for transcript_event in transcript.events:
            if transcript_event["event"] == PipelineEvent.FINAL_TITLE:
                transcript_event["title"] = info.title
                break
        values["events"].extend(transcript.events)

    await transcripts_controller.update(transcript, values)
    return transcript
@@ -539,14 +566,38 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
            },
        )
    elif event == PipelineEvent.FINAL_SUMMARY:
        final_summary = TranscriptFinalSummary(summary=data.summary)
        resp = transcript.add_event(event=event, data=final_summary)
    elif event == PipelineEvent.FINAL_TITLE:
        final_title = TranscriptFinalTitle(title=data.title)
        resp = transcript.add_event(event=event, data=final_title)
        await transcripts_controller.update(
            transcript,
            {
                "events": transcript.events_dump(),
                "summary": final_summary.summary,
                "title": final_title.title,
            },
        )
    elif event == PipelineEvent.FINAL_LONG_SUMMARY:
        final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
        resp = transcript.add_event(event=event, data=final_long_summary)
        await transcripts_controller.update(
            transcript,
            {
                "events": transcript.events_dump(),
                "long_summary": final_long_summary.long_summary,
            },
        )
    elif event == PipelineEvent.FINAL_SHORT_SUMMARY:
        final_short_summary = TranscriptFinalShortSummary(
            short_summary=data.short_summary
        )
        resp = transcript.add_event(event=event, data=final_short_summary)
        await transcripts_controller.update(
            transcript,
            {
                "events": transcript.events_dump(),
                "short_summary": final_short_summary.short_summary,
            },
        )
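
On the receiving side, a websocket or data-channel consumer now has to handle three final events where it previously handled one FINAL_SUMMARY. A minimal sketch of such a dispatcher follows, assuming the event payloads mirror the models above; this is illustrative only and not code from this PR.

# Hypothetical client-side dispatcher for the new pipeline events.
def apply_pipeline_event(state: dict, event: str, data: dict) -> None:
    if event == "FINAL_TITLE":
        state["title"] = data["title"]
    elif event == "FINAL_SHORT_SUMMARY":
        state["short_summary"] = data["short_summary"]
    elif event == "FINAL_LONG_SUMMARY":
        state["long_summary"] = data["long_summary"]
    # TRANSCRIPT, TOPIC and STATUS handling is unchanged by this PR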