Force casing on slot inputs

This commit is contained in:
Michael Hansen
2020-01-02 11:40:04 -05:00
parent 74761b942f
commit 4f6d02169c
+35 -8
View File
@@ -13,17 +13,32 @@ from uuid import uuid4
import attr
import json5
from quart import (Quart, Response, jsonify, request, safe_join, send_file,
send_from_directory, websocket)
from quart import (
Quart,
Response,
jsonify,
request,
safe_join,
send_file,
send_from_directory,
websocket,
)
from quart_cors import cors
from swagger_ui import quart_api_doc
from rhasspy.actor import ActorSystem, ConfigureEvent, RhasspyActor
from rhasspy.core import RhasspyCore
from rhasspy.events import IntentRecognized, ProfileTrainingFailed
from rhasspy.utils import (FunctionLoggingHandler, buffer_to_wav,
get_all_intents, get_ini_paths, get_wav_duration,
load_phoneme_examples, read_dict, recursive_remove)
from rhasspy.utils import (
FunctionLoggingHandler,
buffer_to_wav,
get_all_intents,
get_ini_paths,
get_wav_duration,
load_phoneme_examples,
read_dict,
recursive_remove,
)
# -----------------------------------------------------------------------------
# Quart Web App Setup
@@ -808,6 +823,16 @@ async def api_slots() -> Union[str, Response]:
overwrite_all = request.args.get("overwrite_all", "false").lower() == "true"
new_slot_values = json5.loads(await request.data)
word_casing = core.profile.get(
"speech_to_text.dictionary_casing", "ignore"
).lower()
word_transform = lambda s: s
if word_casing == "lower":
word_transform = str.lower
elif word_casing == "upper":
word_transform = str.upper
slots_dir = Path(
core.profile.write_path(
core.profile.get("speech_to_text.slots_dir", "slots")
@@ -834,15 +859,17 @@ async def api_slots() -> Union[str, Response]:
slots_path.parent.mkdir(parents=True, exist_ok=True)
# Merge with existing values
values = set(values)
values = set([word_transform(v.strip()) for v in values])
if slots_path.is_file():
values.update(line for line in slots_path.read_text().splitlines())
values.update(
word_transform(line.strip())
for line in slots_path.read_text().splitlines()
)
# Write merged values
if values:
with open(slots_path, "w") as slots_file:
for value in values:
value = value.strip()
if value:
print(value, file=slots_file)