From 00651e87ce80abb06c3dbb7014295890709824d0 Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Sat, 12 Oct 2024 21:13:16 +0100 Subject: [PATCH] feat: add stream API --- README.md | 13 +++++++++--- capgen/__init__.py | 4 ++-- capgen/transcriber/__init__.py | 1 + capgen/transcriber/caption_format.py | 3 +++ capgen/transcriber/converter.py | 13 ++++++------ capgen/transcriber/transcriber.py | 28 ++++++++++++++++---------- capgen/types/arguments.py | 6 ++++-- server/api/v1/transcribe.py | 30 ++++++++++++++++++++++++---- 8 files changed, 69 insertions(+), 29 deletions(-) create mode 100644 capgen/transcriber/caption_format.py diff --git a/README.md b/README.md index 68bb819..e1d207e 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ A fast cross-platform CPU-first video/audio English-only transcriber for generat ## Usage (API) -Simply cURL the endpoint like in the following. Currently, the only available caption format is `srt` and `vtt`. +Simply cURL the endpoint like in the following. Currently, the only available caption format are `srt`, `vtt` and `txt`. ```bash curl "https://winstxnhdw-CapGen.hf.space/api/v1/transcribe?caption_format=$CAPTION_FORMAT" \ @@ -29,8 +29,15 @@ curl "https://winstxnhdw-CapGen.hf.space/api/v1/transcribe?caption_format=$CAPTI You can also redirect the output to a file. ```bash - curl "https://winstxnhdw-CapGen.hf.space/api/v1/transcribe" \ - -F "request=@$AUDIO_FILE_PATH" | jq -r ".result" > result.srt +curl "https://winstxnhdw-CapGen.hf.space/api/v1/transcribe?caption_format=$CAPTION_FORMAT" \ + -F "request=@$AUDIO_FILE_PATH" | jq -r ".result" > result.srt +``` + +You can stream the captions in real-time with the following. + +```bash +curl -N "https://winstxnhdw-CapGen.hf.space/api/v1/transcribe/stream?caption_format=$CAPTION_FORMAT" \ + -F "request=@$AUDIO_FILE_PATH" ``` ## Usage (CLI) diff --git a/capgen/__init__.py b/capgen/__init__.py index 02d2502..210ba17 100644 --- a/capgen/__init__.py +++ b/capgen/__init__.py @@ -100,10 +100,10 @@ def main(): resolve_cuda_libraries() if not (transcription := Transcriber(**options).transcribe(args.file, args.caption)): - raise InvalidFormatError(f'Invalid format: {args.caption}!') + raise InvalidFormatError(f'Invalid file: {args.file}!') with open(args.output, 'w', encoding='utf-8') as file: - file.write(transcription) + file.write('\n\n'.join(transcription)) if __name__ == '__main__': diff --git a/capgen/transcriber/__init__.py b/capgen/transcriber/__init__.py index e9253e9..8f34dcd 100644 --- a/capgen/transcriber/__init__.py +++ b/capgen/transcriber/__init__.py @@ -1 +1,2 @@ +from capgen.transcriber.caption_format import CaptionFormat as CaptionFormat from capgen.transcriber.transcriber import Transcriber as Transcriber diff --git a/capgen/transcriber/caption_format.py b/capgen/transcriber/caption_format.py new file mode 100644 index 0000000..025e834 --- /dev/null +++ b/capgen/transcriber/caption_format.py @@ -0,0 +1,3 @@ +from typing import Literal + +CaptionFormat = Literal['srt', 'vtt', 'txt'] diff --git a/capgen/transcriber/converter.py b/capgen/transcriber/converter.py index c5a4c81..747e18d 100644 --- a/capgen/transcriber/converter.py +++ b/capgen/transcriber/converter.py @@ -1,4 +1,4 @@ -from typing import Iterable +from typing import Iterable, Iterator from faster_whisper.transcribe import Segment @@ -50,7 +50,7 @@ def convert_seconds_to_hhmmssmmm(self, seconds: float, millisecond_separator: st return f'{int(hours):02}:{int(minutes):02}:{int(seconds):02}{millisecond_separator}{milliseconds:03}' - def to_srt(self) -> str: + def to_srt(self) -> Iterator[str]: """ Summary ------- @@ -64,14 +64,14 @@ def to_srt(self) -> str: ------- subrip_subtitle (str) : the SRT subtitles """ - return '\n\n'.join( + return ( f'{id}\n' f'{self.convert_seconds_to_hhmmssmmm(start, ",")} --> ' f'{self.convert_seconds_to_hhmmssmmm(end, ",")}\n{text[1:]}' for id, _, start, end, text, *_ in self.segments ) - def to_vtt(self) -> str: + def to_vtt(self) -> Iterator[str]: """ Summary ------- @@ -85,10 +85,9 @@ def to_vtt(self) -> str: ------- video_text_tracks_subtitle (str) : the VTT subtitles """ - captions = '\n\n'.join( + yield 'WEBVTT' + yield from ( f'{self.convert_seconds_to_hhmmssmmm(start, ".")} --> ' f'{self.convert_seconds_to_hhmmssmmm(end, ".")}\n{text[1:]}' for _, _, start, end, text, *_ in self.segments ) - - return f'WEBVTT\n\n{captions}' diff --git a/capgen/transcriber/transcriber.py b/capgen/transcriber/transcriber.py index ab12b17..492df20 100644 --- a/capgen/transcriber/transcriber.py +++ b/capgen/transcriber/transcriber.py @@ -1,7 +1,9 @@ -from typing import BinaryIO, Literal +from typing import BinaryIO, Iterator, Literal +from av.error import InvalidDataError from faster_whisper import WhisperModel +from capgen.transcriber.caption_format import CaptionFormat from capgen.transcriber.converter import Converter @@ -28,7 +30,7 @@ def __init__(self, device: Literal['auto', 'cpu', 'cuda'], number_of_threads: in num_workers=number_of_workers, ) - def transcribe(self, file: str | BinaryIO, caption_format: str) -> str | None: + def transcribe(self, file: str | BinaryIO, caption_format: CaptionFormat) -> Iterator[str] | None: """ Summary ------- @@ -41,15 +43,19 @@ def transcribe(self, file: str | BinaryIO, caption_format: str) -> str | None: Returns ------- - transcription (str | None) : the transcribed text in the chosen caption format + transcription (Iterator[str] | None) : the transcribed text in the chosen caption format """ - segments, _ = self.model.transcribe( - file, - language='en', - beam_size=1, - vad_filter=True, - vad_parameters={'min_silence_duration_ms': 500}, - ) + try: + segments, _ = self.model.transcribe( + file, + language='en', + beam_size=1, + vad_filter=True, + vad_parameters={'min_silence_duration_ms': 500}, + ) + + except InvalidDataError: + return None converter = Converter(segments) @@ -59,4 +65,4 @@ def transcribe(self, file: str | BinaryIO, caption_format: str) -> str | None: if caption_format == 'vtt': return converter.to_vtt() - return None + return (segment.text for segment in segments) diff --git a/capgen/types/arguments.py b/capgen/types/arguments.py index dc7638a..37cc5a9 100644 --- a/capgen/types/arguments.py +++ b/capgen/types/arguments.py @@ -1,4 +1,6 @@ -from typing import BinaryIO, Literal, NamedTuple +from typing import BinaryIO, NamedTuple + +from capgen.transcriber import CaptionFormat class Arguments(NamedTuple): @@ -16,7 +18,7 @@ class Arguments(NamedTuple): """ file: str | BinaryIO - caption: Literal['srt', 'vtt'] + caption: CaptionFormat output: str cuda: bool threads: int | None diff --git a/server/api/v1/transcribe.py b/server/api/v1/transcribe.py index 7e9466c..4f27382 100644 --- a/server/api/v1/transcribe.py +++ b/server/api/v1/transcribe.py @@ -1,5 +1,5 @@ from io import BytesIO -from typing import Annotated, Literal +from typing import Annotated from litestar import Controller, post from litestar.concurrency import _run_sync_asyncio as run_sync @@ -7,8 +7,10 @@ from litestar.enums import RequestEncodingType from litestar.exceptions import ClientException from litestar.params import Body +from litestar.response.sse import ServerSentEvent from litestar.status_codes import HTTP_200_OK +from capgen.transcriber import CaptionFormat from server.schemas.v1 import Transcribed from server.state import AppState @@ -27,7 +29,7 @@ async def transcribe( self, state: AppState, data: Annotated[UploadFile, Body(media_type=RequestEncodingType.MULTI_PART)], - caption_format: Literal['srt', 'vtt'] = 'srt', + caption_format: CaptionFormat = 'srt', ) -> Transcribed: """ Summary @@ -38,6 +40,26 @@ async def transcribe( transcription = await run_sync(state.transcriber.transcribe, audio, caption_format) if not transcription: - raise ClientException(detail=f'Invalid format: {caption_format}!') + raise ClientException(detail=f'Invalid file: {data.filename}!') - return Transcribed(result=transcription) + return Transcribed(result='\n\n'.join(transcription)) + + @post('/stream', status_code=HTTP_200_OK) + async def transcribe_stream( + self, + state: AppState, + data: Annotated[UploadFile, Body(media_type=RequestEncodingType.MULTI_PART)], + caption_format: CaptionFormat = 'srt', + ) -> ServerSentEvent: + """ + Summary + ------- + the POST variant of the `/transcribe/stream` route + """ + audio = BytesIO(await data.read()) + transcription = await run_sync(state.transcriber.transcribe, audio, caption_format) + + if not transcription: + raise ClientException(detail=f'Invalid file: {data.filename}!') + + return ServerSentEvent(transcription)