NATURESPEAK-ML-UTTER/speak_broadcast.py

437 lines
10 KiB
Python
Raw Permalink Normal View History

2022-04-11 13:09:01 +02:00
import argparse, json, sys, time, random, logging, signal, threading, string
2022-03-13 17:09:05 +01:00
import utterance.voice
import utterance.utils
2022-04-10 15:20:13 +02:00
import utterance.oscosc
2022-03-13 17:09:05 +01:00
import examine.metric
2022-03-20 16:06:57 +01:00
logging.basicConfig(level=logging.INFO)

# --- generation / timing defaults -------------------------------------------
# main() overwrites these with the values found under the
# 'utterance_configuration' key of the JSON config file.
UTTERANCE_LEN = 64
NUM_METRIC_GEN = 75
NUM_SAMPLE_VOICES = 3
RANDOM_SEED_TIMER_MIN = 2
STATE_TRANSITION_TIMER_MIN = 10

# --- shared runtime objects (created in main()) ------------------------------
broadcast = None
metric = None

# --- runtime flags -----------------------------------------------------------
# NOTE: 'exit' shadows the builtin of the same name; it is the shutdown flag
# set by signal_terminate() and polled by the loops in this module.
exit = False
terminal = False
debug = False

# Current state of the main loop ("METRIC", "QUESTION" or "RANDOM").
state = "METRIC"

# --- fragment post-processing tables (filled from config in main()) ----------
B_SKIP = []
B_SWAPS = {}
2022-03-13 17:09:05 +01:00
2022-03-13 18:56:35 +01:00
def format_str(text) -> str:
    """Run ``text`` through ``utterance.utils.clean`` and then ``utterance.utils.format``."""
    return utterance.utils.format(utterance.utils.clean(text))
2022-04-10 15:20:13 +02:00
def tokenise_str(text):
    """Return the result of ``utterance.utils.tokenise`` applied to ``text``."""
    tokens = utterance.utils.tokenise(text)
    return tokens
def utter_one(v, temp=None, length=None) -> str:
    """Ask voice ``v`` for a single utterance and return it cleaned and formatted.

    ``temp`` and ``length`` are forwarded unchanged to ``v.utter_one``.
    """
    return format_str(v.utter_one(temp=temp, length=length))
def prompt_one(v, pinput: str, temp=None, length=None) -> str:
    """Prompt voice ``v`` with ``pinput`` and return the cleaned, formatted reply.

    Args:
        v: voice object exposing ``prompt(pinput=..., temp=..., length=...)``.
        pinput: prompt text handed to the voice.
        temp: temperature forwarded to ``v.prompt``; ``None`` leaves the
            choice to the voice.
        length: length forwarded to ``v.prompt``; ``None`` leaves the choice
            to the voice.
    """
    # Forward the caller's temperature. The previous code always passed
    # temp=None, so an explicit ``temp`` argument was silently ignored.
    u = v.prompt(pinput=pinput, temp=temp, length=length)
    return format_str(u)
2022-04-10 15:20:13 +02:00
def utter_one_vectorise(v, temp=None, length=None):
    """Generate one formatted utterance from ``v`` together with its vector.

    Returns a ``(text, vector)`` pair where the vector comes from the
    module-level ``metric`` object.
    """
    global metric
    text = utter_one(v, temp, length)
    return text, metric.vector(text)
def prompt_one_vectorise(v, pinput: str, temp=None, length=None):
    """Prompt ``v`` with ``pinput`` and return the formatted reply with its vector.

    Returns a ``(text, vector)`` pair where the vector comes from the
    module-level ``metric`` object.
    """
    global metric
    text = prompt_one(v, pinput, temp, length)
    return text, metric.vector(text)
2022-04-10 15:20:13 +02:00
def utter_n_vectorise_distance(v, n, vec, temp=None, length=None):
    """Generate ``n`` utterances from ``v`` and score each one against ``vec``.

    Every generated text is formatted, vectorised with the module-level
    ``metric`` and compared to ``vec`` with ``examine.metric.cos_dist``.

    Returns:
        A list of ``[distance, text, voice]`` entries, one per generated
        text, in generation order.
    """
    global metric

    def _score(raw):
        # Format first: the vector is computed on the text that would be shown.
        formatted = format_str(raw)
        distance = examine.metric.cos_dist(vec, metric.vector(formatted))
        return [distance, formatted, v]

    generated = v.utter_n(n=n, temp=temp, length=length)
    return [_score(raw) for raw in generated]
2022-03-13 17:09:05 +01:00
2022-03-17 15:29:44 +01:00
def terminal_utterance(utterance):
    """Echo ``utterance`` to stdout (no trailing newline) when terminal mode is on."""
    if not terminal:
        return
    print(utterance, end="")
2022-04-11 13:09:01 +02:00
def fix_ending(frags):
    """Return a copy of ``frags`` whose last fragment has been repaired.

    The last fragment is stripped of punctuation and handed to
    ``utterance.utils.fix_sentence``. If that yields nothing the fragment is
    dropped, otherwise it is replaced by the fixed text. The (possibly new)
    last fragment is then passed through ``utterance.utils.fix_punctuation``.

    Args:
        frags: list of text fragments; it is not modified.

    Returns:
        A new list of fragments. An empty input, or a single fragment that
        cannot be fixed, yields an empty list instead of raising IndexError.
    """
    if not frags:
        # Nothing to repair; avoid indexing an empty list.
        return []

    result = frags.copy()
    end = result[-1].translate(str.maketrans('', '', string.punctuation))
    fix = utterance.utils.fix_sentence(end)
    if not fix:
        # Unfixable tail: drop it altogether.
        result = result[:-1]
    else:
        result[-1] = fix

    # Dropping the only fragment leaves nothing to punctuate.
    if result:
        result[-1] = utterance.utils.fix_punctuation(result[-1])
    print(result)
    return result
def fix_beginning(frags):
    """Return a copy of ``frags`` with an awkward opening word removed or swapped.

    The first whitespace-separated token of the first fragment is examined:

    * If its lower-cased form (minus one leading punctuation character) is in
      ``B_SKIP``, the word is removed. When more than two tokens are present
      the remainder of the fragment is kept, capitalised, with a trailing
      newline. Otherwise the whole first fragment is dropped and the next
      fragment is capitalised (after removing one leading space).
    * Otherwise, if the token as written is a key of ``B_SWAPS``, its first
      occurrence in the fragment is replaced by the mapped value.

    Args:
        frags: list of text fragments; it is not modified.

    Returns:
        A new list of fragments. An empty input yields an empty list. A short
        skip-word fragment with no following fragment is returned unchanged
        rather than raising IndexError.
    """
    global B_SKIP, B_SWAPS
    if not frags:
        return []

    result = frags.copy()
    toks = result[0].split()
    if not toks:
        return result

    first = toks[0]
    key = first.lower()
    if key[0] in string.punctuation:
        key = key[1:]

    if key in B_SKIP:
        if len(toks) > 2:
            result[0] = " ".join(toks[1:]).capitalize() + "\n"
        elif len(result) > 1:
            following = result[1]
            # startswith() is safe on an empty fragment, unlike indexing [0].
            if following.startswith(' '):
                following = following[1:]
            result[1] = following.capitalize()
            return result[1:]
        # else: the short fragment is all there is, so keep it as it is.
    elif first in B_SWAPS:
        # count=1: only the leading word is swapped, not later occurrences
        # (which could also sit inside other words).
        result[0] = result[0].replace(first, B_SWAPS[first], 1)
    return result
2022-03-13 17:09:05 +01:00
def broadcast_utterance(v, utterance):
    """Broadcast ``utterance`` for voice ``v`` fragment by fragment.

    The full text is first sent on ``v.calculate``; the voice channel is then
    reset with an empty string. After a short pause the text is split with
    ``v.fragments``, its beginning and ending are repaired, and the growing
    text is re-sent on ``v.channel`` after each fragment, pausing roughly one
    second per word. Returns early (without clearing) if the exit flag is set.
    """
    global broadcast, exit, debug
    print(utterance)

    # Send all text to server to calculate bounds in advance
    broadcast.utterance(utterance, v.calculate)

    # Start from an empty display.
    shown = ""
    broadcast.utterance(shown, v.channel)
    terminal_utterance(shown)
    time.sleep(2)

    pieces = fix_ending(fix_beginning(v.fragments(utterance)))

    for piece in pieces:
        terminal_utterance(piece)
        shown += piece
        broadcast.utterance(shown, v.channel)

        # One second per word; fragments of two words or fewer get an extra second.
        word_count = len(piece.split())
        pause = word_count + 1 if word_count <= 2 else word_count
        time.sleep(pause)

        if exit:
            return

    broadcast.command('clear')
    print("==========")
    time.sleep(3)
2022-03-13 17:09:05 +01:00
2022-03-13 19:24:02 +01:00
def find_candidates(v, uv_vec, voices, results):
    """Collect scored candidate utterances from a random sample of voices.

    Up to ``NUM_SAMPLE_VOICES`` voices are sampled from ``voices``; the
    current voice ``v`` is skipped if it was sampled. Each remaining voice
    contributes ``NUM_METRIC_GEN`` utterances scored against ``uv_vec``.

    Args:
        v: voice currently speaking; never used as a candidate.
        uv_vec: vector of the current utterance.
        voices: list of all voices.
        results: output list, extended in place with ``[distance, text,
            voice]`` entries and sorted by distance, largest first. It stays
            empty when the exit flag is set or only ``v`` was sampled.
    """
    logging.info(f"LOOP::finding candidates")
    start = time.time()

    # random.sample raises ValueError when asked for more items than exist;
    # clamp so a small voice list does not kill the worker thread.
    sample_size = max(0, min(NUM_SAMPLE_VOICES, len(voices)))
    for c in random.sample(voices, sample_size):
        if exit:
            break
        if c == v:
            continue
        results.extend(utter_n_vectorise_distance(c, NUM_METRIC_GEN, uv_vec))

    results.sort(key=lambda t: t[0], reverse=True)
    lapse = time.time() - start
    logging.info(f"LOOP::done - {lapse} secs")
2022-04-10 15:20:13 +02:00
# def update():
# global exit
# while not exit:
# try:
# utterance.osc.update()
# except Exception as e:
# logging.error(e)
# pass
2022-03-20 16:06:57 +01:00
2022-04-10 15:20:13 +02:00
# time.sleep(0.2)
2022-03-13 19:24:02 +01:00
2022-03-13 18:56:35 +01:00
def signal_terminate(signum, frame):
    """Signal handler: log a warning and raise the module-level exit flag.

    ``signum`` and ``frame`` are the standard handler arguments and are unused.
    """
    global exit
    logging.warning("::SIGNAL TERMINATE::")
    # Polled by the loops in this module so they can stop.
    exit = True
2022-03-13 17:09:05 +01:00
def main() -> int:
    """Load the configuration, set up voices/OSC/metric and run the broadcast loop.

    The loop runs until the module-level ``exit`` flag is set and dispatches on
    the module-level ``state``:

    * ``"METRIC"``: broadcast the current utterance while a worker thread
      gathers candidates, then pick the first candidate accepted by its
      voice's ``select`` as the next utterance.
    * ``"QUESTION"``: a random voice asks the next configured question; the
      voice named in the question answers from its prompt, and the answer is
      broadcast by the following METRIC pass.
    * ``"RANDOM"``: a random voice broadcasts an utterance of random length.

    Two background timers run alongside: one reseeds ``random`` and one
    switches ``state`` to ``"QUESTION"``, each at a random interval.

    Returns:
        0 once the loop has ended and the timer threads have been joined.
    """
    global broadcast, metric, terminal, debug, state, UTTERANCE_LEN, NUM_METRIC_GEN, NUM_SAMPLE_VOICES, RANDOM_SEED_TIMER_MIN, STATE_TRANSITION_TIMER_MIN, B_SKIP, B_SWAPS

    p = argparse.ArgumentParser()
    p.add_argument("-c", "--config", type=str, default="voice.config.json", help="configuratin file")
    p.add_argument("-i", "--iterations", type=int, default=10, help="number of iterations")
    p.add_argument("-t", "--terminal", action='store_true', help="print to terminal")
    args = p.parse_args()

    logging.info(f"INIT::loading config file - {args.config}")
    with open(args.config) as f:
        conf = json.load(f)
    logging.info(conf)

    terminal = args.terminal

    #--------------------#
    #   CONFIGS
    #--------------------#
    u_conf = conf['utterance_configuration']
    UTTERANCE_LEN = u_conf['UTTERANCE_LEN']
    NUM_METRIC_GEN = u_conf['NUM_METRIC_GEN']
    NUM_SAMPLE_VOICES = u_conf['NUM_SAMPLE_VOICES']
    RANDOM_SEED_TIMER_MIN = u_conf['RANDOM_SEED_TIMER_MIN']
    STATE_TRANSITION_TIMER_MIN = u_conf['STATE_TRANSITION_TIMER_MIN']
    B_SKIP = u_conf['b_skip']
    B_SWAPS = u_conf['b_swaps']

    #--------------------#
    #   VOICES
    #--------------------#
    logging.info(f"INIT::creating voices")
    voices = []
    for v in conf['voices']:
        model = v['model']
        voice = utterance.voice.Voice(name=v["name"].upper(), model=model['model_dir'], tokenizer=model['tokeniser_file'], temp=float(model["temperature"]), length=UTTERANCE_LEN)
        voice.set_channel(v['osc_channel']['root'], v['osc_channel']['utterance'])
        voice.set_calculate(v['osc_channel']['root'], v['osc_channel']['calculate'])
        voice.set_temperature(v['osc_channel']['root'], v['osc_channel']['temperature'])
        voices.append(voice)

    #--------------------#
    #   QUESTION
    #--------------------#
    logging.info(f"INIT::setting up question")
    questions = conf['questions']
    questions_array = questions.copy()
    random.shuffle(questions_array)

    #--------------------#
    #   NET
    #--------------------#
    logging.info(f"INIT::setting up OSC")
    broadcast = utterance.oscosc.OscBroadcaster(name="osc_broadcast", host=conf['host_voicemachine'], port=conf['port_voicemachine'], command_channel=conf['command_osc_channel'])

    #--------------------#
    #   METRIC
    #--------------------#
    logging.info(f"INIT::loading doc2vec metrics")
    metric = examine.metric.Metric(model_input='data/models/doc2vec.model')

    #--------------------#
    #   TIMERS
    #--------------------#
    def wait_or_exit(seconds) -> bool:
        """Sleep ``seconds`` in 1 s steps; return False once the exit flag is seen."""
        for _ in range(seconds):
            if exit:
                return False
            time.sleep(1)
        return not exit

    def timer_seconds(minutes) -> int:
        """Random delay between 60 s and ``minutes`` minutes (never below 60 s)."""
        # max() keeps randint's range valid when the configured value is < 1.
        return random.randint(60, max(60, int(60 * minutes)))

    # Each timer is one long-lived thread that loops, rather than a chain of
    # threads spawning their successors. The handles joined at the end are
    # therefore the threads that are actually running.
    def random_seed():
        while wait_or_exit(timer_seconds(RANDOM_SEED_TIMER_MIN)):
            logging.info("RANDOM::SEEDING RANDOM")
            random.seed(time.time())

    # STATES = ["METRIC", "QUESTION", "RANDOM"]
    def state_transition():
        global state
        while wait_or_exit(timer_seconds(STATE_TRANSITION_TIMER_MIN)):
            logging.info("STATE::STATE TRANSITION")
            state = "QUESTION"

    t_random_seed = threading.Thread(target=random_seed)
    t_random_seed.start()
    t_state_transition = threading.Thread(target=state_transition)
    t_state_transition.start()

    #--------------------#
    #   A
    #--------------------#
    logging.info(f"INIT::generating first utterance")
    v = random.choice(voices)
    uv, uv_vec = utter_one_vectorise(v)

    while not exit:

        if state == "METRIC":
            logging.info(f"- state METRIC")
            results = []
            # Gather candidates for the next utterance while this one is broadcast.
            t = threading.Thread(target=find_candidates, args=[v, uv_vec, voices, results])
            t.start()
            logging.info(f"METRIC::broadcasting {v.name}")
            try:
                broadcast_utterance(v, uv)
            except Exception:
                # Best effort: a failed broadcast must not stop the loop.
                logging.exception("METRIC::broadcast failed")
            t.join()

            if not results:
                # No candidates: exit was requested mid-search, or only the
                # current voice was sampled. Go round again (the loop
                # condition handles exit) instead of indexing an empty list.
                if not exit:
                    logging.warning("METRIC::no candidates - retrying")
                continue

            # results entries are [distance, text, voice]
            choice = results[0]
            # make sure we don't say the same thing over and over again
            for r in results:
                if r[2].select(r[1]):
                    choice = r
                    break
                logging.info(f"METRIC::reduncancy {r[2].name}")

            v = choice[2]
            uv = choice[1]
            uv_vec = metric.vector(uv)
            logging.info(f"METRIC::next {v.name}")

        elif state == "QUESTION":
            logging.info(f"- state QUESTION")

            if not questions_array:
                questions_array = questions.copy()
                random.shuffle(questions_array)

            if not questions_array:
                # Nothing configured to ask; fall back instead of pop() on an empty list.
                logging.warning("QUESTION::no questions configured")
                state = "METRIC"
                continue

            # random question
            q = questions_array.pop(0)

            # random voice asks random question; kept separate from ``v`` so
            # the current speaker/utterance pair survives if nobody can answer
            asker = random.choice(voices)
            logging.info(f"QUESTION::{asker.name} : {q['question']}")
            broadcast.utterance(q['question'], asker.calculate)
            broadcast.utterance(q['question'], asker.channel)
            time.sleep(15)

            # answer
            answerers = [e for e in voices if e.name == q['voice']]
            if len(answerers) == 1:
                v = answerers[0]
                logging.info(f"QUESTION::answer - {v.name}")
                uv, uv_vec = prompt_one_vectorise(v, q['prompt'])
                v.remember(uv)
                # the answer itself is broadcast as part of METRIC STATE
            else:
                logging.warning(f"QUESTION::no unique voice named {q['voice']}")

            state = "METRIC"

        elif state == "RANDOM":
            logging.info(f"- state RANDOM")
            v = random.choice(voices)
            l = random.randint(5, UTTERANCE_LEN)
            uv, uv_vec = utter_one_vectorise(v, length=l)
            broadcast_utterance(v, uv)

    logging.info(f"TERMINATE::random seed")
    t_random_seed.join()
    logging.info(f"TERMINATE::state transition")
    t_state_transition.join()

    logging.info(f"FIN")
    return 0
2022-03-13 17:09:05 +01:00
if __name__ == "__main__":
    # Ctrl-C raises the shared exit flag via signal_terminate instead of
    # raising KeyboardInterrupt.
    signal.signal(signal.SIGINT, signal_terminate)
    status = main()
    sys.exit(status)