# transcriptum/transcribe.py
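"""Live microphone transcription with OpenAI Whisper.

Audio captured by a background listener (the local ``audio`` module) is
queued, joined into a single buffer, converted to float32 PCM and passed
to Whisper; the running transcript is re-printed to the terminal, with
ANSI colours indicating the current state (listening / transcribing / idle).
"""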

import argparse
import numpy
import whisper
import torch
import wave
import os
from datetime import datetime, timedelta
from time import sleep
from queue import Queue
from enum import Enum
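
# local project module: microphone selection and background capture (audio.Microphone, audio.Listener)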
import audio


class State(Enum):
    IDLE = 1
    TRANSCRIBING = 2
    LISTENING = 3


state = State.IDLE


def main():
    global state  # main() and the listener callback both update the module-level state

    p = argparse.ArgumentParser(description="TRANSCRIPTUM")
    p.add_argument("--model", default="medium", help="Whisper model", choices=["tiny", "base", "small", "medium", "large"])
    p.add_argument("--rms", default=1000, help="RMS (energy) threshold for the microphone to detect", type=int)
    p.add_argument("--record_timeout", default=8, help="Timeout for the microphone recording", type=float)
    p.add_argument("--phrase_timeout", default=2, help="Silence timeout between phrases", type=float)
    p.add_argument("--dynamic_threshold", action="store_true", help="Use a dynamic RMS threshold?")
    args = p.parse_args()
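
    # record_timeout is handed to the background listener as listen_timeout;
    # phrase_timeout is the silence gap (seconds) after which the next result starts a new phrase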
    record_timeout = args.record_timeout
    phrase_timeout = args.phrase_timeout
    dynamic_threshold = args.dynamic_threshold

    phrase_time = None
    data_queue = Queue()
    transcripts = ['']

    model = args.model
    whisper_model = whisper.load_model(model)
    transcribing = False

    print("Model loaded.\n")

    # select microphone?
    source = audio.Microphone.select()
    microphone = audio.Microphone(device_info=source, sample_rate=22050)

    listener = audio.Listener()
    listener.energy_threshold = args.rms
    listener.dynamic_energy_threshold = args.dynamic_threshold

    # with microphone:
    #     listener.adjust_ambient_noise(microphone, duration=1)
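
    # redraw the transcript list; bcolor is an ANSI escape prefix applied to every line (reset with \033[0m)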
    def print_transcripts(bcolor=None):
        os.system("clear")
        for l in transcripts:
            if bcolor:
                print(bcolor + l + '\033[0m')
            else:
                print(l)
        print('', end='', flush=True)

    # buffer is (frame_data, source.SAMPLE_RATE, source.SAMPLE_WIDTH); only the raw frame bytes are queued
    def listen_callback(_, buffer: tuple) -> None:
        data_queue.put(buffer[0])
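
    # capture-state callback: bold the transcript display while audio is being captured,
    # reset to plain once capture stops (but don't clobber an in-progress transcription)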
    def is_listening_callback(is_listening):
        global state
        if is_listening and state != State.LISTENING:
            print_transcripts('\033[1m')  # bold
            state = State.LISTENING
        elif state != State.IDLE and state != State.TRANSCRIBING:
            print_transcripts()
            state = State.IDLE
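
    # start capturing from the microphone on a background thread; the returned callable stops it later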
    stop = listener.listen_in_background(source=microphone, listen_cb=listen_callback, listen_timeout=record_timeout, is_listening_cb=is_listening_callback)

    os.system("clear")
    while True:
        try:
            now = datetime.utcnow()
            if not data_queue.empty():
                phrase_complete = False
                if phrase_time and now - phrase_time > timedelta(seconds=phrase_timeout):
                    phrase_complete = True

                # with data_queue.mutex:
                #     phrase_time = now
                #     audio_data = b"".join(data_queue.queue)
                #     data_queue.queue.clear()
                phrase_time = now
                audio_data = b"".join(data_queue.queue)
                data_queue.queue.clear()
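
                # Whisper expects mono float32 PCM in [-1.0, 1.0], so convert the signed 16-bit samples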
                np_data = numpy.frombuffer(audio_data, dtype=numpy.int16).astype(numpy.float32) / 32768.0
                # file_name = f"sound{n}.wav"
                # with wave.open(file_name, "w") as f:
                #     f.setnchannels(1)
                #     f.setsampwidth(2)
                #     f.setframerate(22050)
                #     f.writeframes(audio_data)
                # n += 1

                # flag the transcription as in progress so is_listening_callback does not reset the display
                state = State.TRANSCRIBING
                # print_transcripts('\033[4m')  # underline
                print_transcripts('\033[93m')  # warning colour while transcribing
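
                # run Whisper on the buffered audio; fp16 only when CUDA is available (fp32 on CPU)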
                r = whisper_model.transcribe(np_data, fp16=torch.cuda.is_available())
                t = r['text'].strip()
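
                # after a long enough pause start a new phrase, otherwise keep replacing the current one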
                if len(t) > 0:
                    if phrase_complete:
                        transcripts.append(t)
                    else:
                        transcripts[-1] = t

                print_transcripts()
                state = State.IDLE

            sleep(0.25)
        except KeyboardInterrupt:
            break

    stop(True)
print("\nTranscripts:\n")
for l in transcripts:
print(l)
if __name__ == "__main__":
    main()
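
# Example invocation (assumes a working input device and the project's local ``audio`` module):
#   python transcribe.py --model small --rms 800 --dynamic_threshold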