Skip to content

Commit

Permalink
feat: add stream API
Browse files Browse the repository at this point in the history
  • Loading branch information
winstxnhdw committed Oct 12, 2024
1 parent d7ca930 commit 00651e8
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 29 deletions.
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ A fast cross-platform CPU-first video/audio English-only transcriber for generat

## Usage (API)

Simply cURL the endpoint like in the following. Currently, the only available caption format is `srt` and `vtt`.
Simply cURL the endpoint as shown below. Currently, the available caption formats are `srt`, `vtt` and `txt`.

```bash
curl "https://winstxnhdw-CapGen.hf.space/api/v1/transcribe?caption_format=$CAPTION_FORMAT" \
Expand All @@ -29,8 +29,15 @@ curl "https://winstxnhdw-CapGen.hf.space/api/v1/transcribe?caption_format=$CAPTI
You can also redirect the output to a file.

```bash
curl "https://winstxnhdw-CapGen.hf.space/api/v1/transcribe" \
-F "request=@$AUDIO_FILE_PATH" | jq -r ".result" > result.srt
curl "https://winstxnhdw-CapGen.hf.space/api/v1/transcribe?caption_format=$CAPTION_FORMAT" \
-F "request=@$AUDIO_FILE_PATH" | jq -r ".result" > result.srt
```

You can stream the captions in real time with the following command.

```bash
curl -N "https://winstxnhdw-CapGen.hf.space/api/v1/transcribe/stream?caption_format=$CAPTION_FORMAT" \
-F "request=@$AUDIO_FILE_PATH"
```

## Usage (CLI)
Expand Down
4 changes: 2 additions & 2 deletions capgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ def main():
resolve_cuda_libraries()

if not (transcription := Transcriber(**options).transcribe(args.file, args.caption)):
raise InvalidFormatError(f'Invalid format: {args.caption}!')
raise InvalidFormatError(f'Invalid file: {args.file}!')

with open(args.output, 'w', encoding='utf-8') as file:
file.write(transcription)
file.write('\n\n'.join(transcription))


if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions capgen/transcriber/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from capgen.transcriber.caption_format import CaptionFormat as CaptionFormat
from capgen.transcriber.transcriber import Transcriber as Transcriber
3 changes: 3 additions & 0 deletions capgen/transcriber/caption_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from typing import Literal

# Closed set of caption format names; used as the type of every `caption_format` / `caption` option.
CaptionFormat = Literal['srt', 'vtt', 'txt']
13 changes: 6 additions & 7 deletions capgen/transcriber/converter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable
from typing import Iterable, Iterator

from faster_whisper.transcribe import Segment

Expand Down Expand Up @@ -50,7 +50,7 @@ def convert_seconds_to_hhmmssmmm(self, seconds: float, millisecond_separator: st

return f'{int(hours):02}:{int(minutes):02}:{int(seconds):02}{millisecond_separator}{milliseconds:03}'

def to_srt(self) -> str:
def to_srt(self) -> Iterator[str]:
"""
Summary
-------
Expand All @@ -64,14 +64,14 @@ def to_srt(self) -> str:
-------
subrip_subtitle (str) : the SRT subtitles
"""
return '\n\n'.join(
return (
f'{id}\n'
f'{self.convert_seconds_to_hhmmssmmm(start, ",")} --> '
f'{self.convert_seconds_to_hhmmssmmm(end, ",")}\n{text[1:]}'
for id, _, start, end, text, *_ in self.segments
)

def to_vtt(self) -> str:
def to_vtt(self) -> Iterator[str]:
"""
Summary
-------
Expand All @@ -85,10 +85,9 @@ def to_vtt(self) -> str:
-------
video_text_tracks_subtitle (str) : the VTT subtitles
"""
captions = '\n\n'.join(
yield 'WEBVTT'
yield from (
f'{self.convert_seconds_to_hhmmssmmm(start, ".")} --> '
f'{self.convert_seconds_to_hhmmssmmm(end, ".")}\n{text[1:]}'
for _, _, start, end, text, *_ in self.segments
)

return f'WEBVTT\n\n{captions}'
28 changes: 17 additions & 11 deletions capgen/transcriber/transcriber.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import BinaryIO, Literal
from typing import BinaryIO, Iterator, Literal

from av.error import InvalidDataError
from faster_whisper import WhisperModel

from capgen.transcriber.caption_format import CaptionFormat
from capgen.transcriber.converter import Converter


Expand All @@ -28,7 +30,7 @@ def __init__(self, device: Literal['auto', 'cpu', 'cuda'], number_of_threads: in
num_workers=number_of_workers,
)

def transcribe(self, file: str | BinaryIO, caption_format: str) -> str | None:
def transcribe(self, file: str | BinaryIO, caption_format: CaptionFormat) -> Iterator[str] | None:
"""
Summary
-------
Expand All @@ -41,15 +43,19 @@ def transcribe(self, file: str | BinaryIO, caption_format: str) -> str | None:
Returns
-------
transcription (str | None) : the transcribed text in the chosen caption format
transcription (Iterator[str] | None) : the transcribed text in the chosen caption format
"""
segments, _ = self.model.transcribe(
file,
language='en',
beam_size=1,
vad_filter=True,
vad_parameters={'min_silence_duration_ms': 500},
)
try:
segments, _ = self.model.transcribe(
file,
language='en',
beam_size=1,
vad_filter=True,
vad_parameters={'min_silence_duration_ms': 500},
)

except InvalidDataError:
return None

converter = Converter(segments)

Expand All @@ -59,4 +65,4 @@ def transcribe(self, file: str | BinaryIO, caption_format: str) -> str | None:
if caption_format == 'vtt':
return converter.to_vtt()

return None
return (segment.text for segment in segments)
6 changes: 4 additions & 2 deletions capgen/types/arguments.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import BinaryIO, Literal, NamedTuple
from typing import BinaryIO, NamedTuple

from capgen.transcriber import CaptionFormat


class Arguments(NamedTuple):
Expand All @@ -16,7 +18,7 @@ class Arguments(NamedTuple):
"""

file: str | BinaryIO
caption: Literal['srt', 'vtt']
caption: CaptionFormat
output: str
cuda: bool
threads: int | None
Expand Down
30 changes: 26 additions & 4 deletions server/api/v1/transcribe.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from io import BytesIO
from typing import Annotated, Literal
from typing import Annotated

from litestar import Controller, post
from litestar.concurrency import _run_sync_asyncio as run_sync
from litestar.datastructures import UploadFile
from litestar.enums import RequestEncodingType
from litestar.exceptions import ClientException
from litestar.params import Body
from litestar.response.sse import ServerSentEvent
from litestar.status_codes import HTTP_200_OK

from capgen.transcriber import CaptionFormat
from server.schemas.v1 import Transcribed
from server.state import AppState

Expand All @@ -27,7 +29,7 @@ async def transcribe(
self,
state: AppState,
data: Annotated[UploadFile, Body(media_type=RequestEncodingType.MULTI_PART)],
caption_format: Literal['srt', 'vtt'] = 'srt',
caption_format: CaptionFormat = 'srt',
) -> Transcribed:
"""
Summary
Expand All @@ -38,6 +40,26 @@ async def transcribe(
transcription = await run_sync(state.transcriber.transcribe, audio, caption_format)

if not transcription:
raise ClientException(detail=f'Invalid format: {caption_format}!')
raise ClientException(detail=f'Invalid file: {data.filename}!')

return Transcribed(result=transcription)
return Transcribed(result='\n\n'.join(transcription))

@post('/stream', status_code=HTTP_200_OK)
async def transcribe_stream(
    self,
    state: AppState,
    data: Annotated[UploadFile, Body(media_type=RequestEncodingType.MULTI_PART)],
    caption_format: CaptionFormat = 'srt',
) -> ServerSentEvent:
    """
    Summary
    -------
    POST handler for `/transcribe/stream` that returns the captions as server-sent events

    Parameters
    ----------
    state (AppState) : the application state that exposes the transcriber
    data (UploadFile) : the multipart upload whose bytes are transcribed
    caption_format (CaptionFormat) : the caption format to produce

    Returns
    -------
    response (ServerSentEvent) : an event stream wrapping the transcription result

    Raises
    ------
    ClientException : when the transcriber returns a falsy result (e.g. `None`) for the upload
    """
    # read the whole upload into memory so the transcriber receives a seekable binary stream
    payload = await data.read()
    captions = await run_sync(state.transcriber.transcribe, BytesIO(payload), caption_format)

    if captions:
        return ServerSentEvent(captions)

    raise ClientException(detail=f'Invalid file: {data.filename}!')

0 comments on commit 00651e8

Please sign in to comment.