"""TRANSCRIPTUM — live microphone transcription with OpenAI Whisper."""
import argparse
import os
import wave

from datetime import datetime, timedelta, timezone
from enum import Enum
from queue import Queue
from time import sleep

import numpy
import torch
import whisper

import audio
|
||
|
|
|
||
|
|
class State(Enum):
|
||
|
|
IDLE = 1
|
||
|
|
TRANSCRIBING = 2
|
||
|
|
LISTENING = 3
|
||
|
|
|
||
|
|
state = State.IDLE
|
||
|
|
|
||
|
|
def main():
    """Run the live-transcription loop.

    Records microphone audio in a background thread, accumulates raw PCM
    chunks in a queue, periodically feeds the buffered audio to a Whisper
    model, and renders the growing transcript to the terminal.  Exits on
    Ctrl-C, then prints the final transcript.
    """
    # Fix: main() mutates the module-level `state` that
    # is_listening_callback() reads via `global state`; without this
    # declaration the assignments below created a shadowing local and the
    # callback never observed TRANSCRIBING/IDLE transitions.
    global state

    p = argparse.ArgumentParser(description="TRANSCRIPTUM")
    p.add_argument("--model", default="medium", help="Whisper model", choices=["tiny", "base", "small", "medium", "large"])
    p.add_argument("--rms", default=1000, help="RMS (energy) threshold for microphone to detect", type=int)
    p.add_argument("--record_timeout", default=8, help="Timeout for the microphone recording", type=float)
    p.add_argument("--phrase_timeout", default=2, help="Silence timeout between phrases", type=float)
    p.add_argument("--dynamic_threshold", action="store_true", help="Use dynamic rms threshold?")
    args = p.parse_args()

    record_timeout = args.record_timeout
    phrase_timeout = args.phrase_timeout
    phrase_time = None          # timestamp of the most recent audio chunk
    data_queue = Queue()        # raw PCM chunks pushed by the listener thread
    transcripts = ['']          # one entry per completed phrase

    whisper_model = whisper.load_model(args.model)
    print("Model loaded.\n")

    # NOTE(review): Whisper models expect 16 kHz input, but this records at
    # 22.05 kHz and never resamples — confirm audio.Microphone or the model
    # handles the conversion.
    source = audio.Microphone.select()
    microphone = audio.Microphone(device_info=source, sample_rate=22050)

    listener = audio.Listener()
    listener.energy_threshold = args.rms
    listener.dynamic_energy_threshold = args.dynamic_threshold

    def print_transcripts(bcolor=None):
        """Clear the screen and reprint every transcript line, optionally
        wrapped in the given ANSI escape sequence."""
        os.system("clear")
        for l in transcripts:
            if bcolor:
                print(bcolor + l + '\033[0m')
            else:
                print(l)
        print('', end='', flush=True)

    def listen_callback(_, buffer: tuple) -> None:
        # buffer is (frame_data, sample_rate, sample_width); queue raw bytes.
        data_queue.put(buffer[0])

    def is_listening_callback(is_listening):
        """Repaint the transcript bold while the mic hears speech; restore
        normal rendering when it stops (unless a transcription is running)."""
        global state
        if is_listening and state != State.LISTENING:
            print_transcripts('\033[1m')  # bold
            state = State.LISTENING
        elif state != State.IDLE and state != State.TRANSCRIBING:
            print_transcripts()
            state = State.IDLE

    stop = listener.listen_in_background(source=microphone, listen_cb=listen_callback, listen_timeout=record_timeout, is_listening_cb=is_listening_callback)

    os.system("clear")

    while True:
        try:
            now = datetime.now(timezone.utc)  # utcnow() is deprecated

            if not data_queue.empty():
                # A long enough silence since the last chunk starts a new
                # transcript line instead of rewriting the current one.
                phrase_complete = (
                    phrase_time is not None
                    and now - phrase_time > timedelta(seconds=phrase_timeout)
                )
                phrase_time = now

                # Fix: snapshot and clear the underlying deque while holding
                # the queue's lock, so the listener thread cannot interleave
                # a put() between the join and the clear.
                with data_queue.mutex:
                    audio_data = b"".join(data_queue.queue)
                    data_queue.queue.clear()

                # 16-bit signed PCM -> float32 in [-1.0, 1.0) for Whisper.
                np_data = numpy.frombuffer(audio_data, dtype=numpy.int16).astype(numpy.float32) / 32768.0

                # Fix: mark transcription in progress (was State.LISTENING,
                # which left State.TRANSCRIBING unused and allowed the
                # listener callback to repaint mid-transcription).
                state = State.TRANSCRIBING
                print_transcripts('\033[93m')  # yellow: transcribing

                r = whisper_model.transcribe(np_data, fp16=torch.cuda.is_available())
                t = r['text'].strip()

                if len(t) > 0:
                    if phrase_complete:
                        transcripts.append(t)
                    else:
                        transcripts[-1] = t

                print_transcripts()
                state = State.IDLE

            sleep(0.25)  # idle poll; keeps the loop from spinning hot

        except KeyboardInterrupt:
            break

    stop(True)

    print("\nTranscripts:\n")
    for l in transcripts:
        print(l)
# Script entry point: run the transcription loop when executed directly.
if __name__ == "__main__":
    main()