Compare commits
3 Commits
e5698fec7b
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 3ca921ad02 | |||
| 5fc57cf7c8 | |||
| 2fa5e4575d |
@@ -5,6 +5,7 @@ from itertools import chain
|
||||
import click
|
||||
|
||||
from .inference import InferenceContext, load_ucs
|
||||
from .import_csv import csv_to_data
|
||||
from .gather import (build_sentence_class_dataset, print_dataset_stats,
|
||||
ucs_definitions_generator, scan_metadata, walk_path)
|
||||
from .recommend import print_recommendation
|
||||
@@ -19,6 +20,7 @@ formatter = logging.Formatter(
|
||||
stream_handler.setFormatter(formatter)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
@click.group(epilog="For more information see "
|
||||
"<https://git.squad51.us/jamie/ucsinfer>")
|
||||
@click.option('--verbose', '-v', flag_value=True, help='Verbose output')
|
||||
@@ -49,8 +51,8 @@ def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
|
||||
stream_handler.setLevel(logging.WARNING)
|
||||
|
||||
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
||||
logger.info("Setting TOKENIZERS_PARALLELISM environment variable to `false"
|
||||
" explicitly")
|
||||
logger.info("Setting TOKENIZERS_PARALLELISM environment variable to "
|
||||
"`false` explicitly")
|
||||
|
||||
ctx.ensure_object(dict)
|
||||
ctx.obj['model_cache'] = not no_model_cache
|
||||
@@ -118,7 +120,8 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
|
||||
text = os.path.basename(path)
|
||||
|
||||
while True:
|
||||
retval = print_recommendation(path, text, inference_ctx, interactive)
|
||||
retval = print_recommendation(
|
||||
path, text, inference_ctx, interactive)
|
||||
if not retval:
|
||||
break
|
||||
if retval[0] is False:
|
||||
@@ -133,6 +136,7 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
|
||||
os.rename(path, new_path)
|
||||
break
|
||||
|
||||
|
||||
@ucsinfer.command('csv')
|
||||
@click.option('--filename-col', default="FileName",
|
||||
help="Heading or index of the column containing filenames",
|
||||
@@ -143,7 +147,7 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
|
||||
@click.option('--out', default='dataset/', show_default=True)
|
||||
@click.argument('paths', nargs=-1)
|
||||
@click.pass_context
|
||||
def csv(ctx, paths, out, filename_col, description_col):
|
||||
def import_csv(ctx, paths: list[str], out, filename_col, description_col):
|
||||
"""
|
||||
Scan training data from CSV files
|
||||
|
||||
@@ -152,16 +156,29 @@ def csv(ctx, paths, out, filename_col, description_col):
|
||||
file system it builds a dataset from descriptions and UCS filenames in
|
||||
columns of a CSV file.
|
||||
"""
|
||||
pass
|
||||
logger.debug("CSV mode")
|
||||
|
||||
logger.debug(f"Loading category list...")
|
||||
ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
|
||||
|
||||
catid_list = [cat.catid for cat in ucs]
|
||||
|
||||
logger.info("Building dataset from csv...")
|
||||
|
||||
dataset = build_sentence_class_dataset(
|
||||
chain(csv_to_data(paths, description_col, filename_col, catid_list),
|
||||
ucs_definitions_generator(ucs)), catid_list)
|
||||
|
||||
logger.info(f"Saving dataset to disk at {out}")
|
||||
print_dataset_stats(dataset, catid_list)
|
||||
dataset.save_to_disk(out)
|
||||
|
||||
|
||||
@ucsinfer.command('gather')
|
||||
@click.option('--out', default='dataset/', show_default=True)
|
||||
# @click.option('--ucs-data', flag_value=True, help="Create a dataset based "
|
||||
# "on the UCS category explanations and synonyms (PATHS will "
|
||||
# "be ignored.)")
|
||||
@click.argument('paths', nargs=-1)
|
||||
@click.pass_context
|
||||
def gather(ctx, paths, out, ucs_data):
|
||||
def gather(ctx, paths, out):
|
||||
"""
|
||||
Scan training data from audio files
|
||||
|
||||
@@ -184,16 +201,12 @@ def gather(ctx, paths, out, ucs_data):
|
||||
scan_list: list[tuple[str, str]] = []
|
||||
catid_list = [cat.catid for cat in ucs]
|
||||
|
||||
if ucs_data:
|
||||
logger.info('Creating dataset for UCS categories instead of from PATH')
|
||||
paths = []
|
||||
|
||||
for path in paths:
|
||||
scan_list += walk_path(path, catid_list)
|
||||
|
||||
logger.info(f"Found {len(scan_list)} files to process.")
|
||||
|
||||
logger.info("Building dataset...")
|
||||
logger.info("Building dataset files...")
|
||||
|
||||
dataset = build_sentence_class_dataset(
|
||||
chain(scan_metadata(scan_list, catid_list),
|
||||
@@ -204,6 +217,7 @@ def gather(ctx, paths, out, ucs_data):
|
||||
print_dataset_stats(dataset, catid_list)
|
||||
dataset.save_to_disk(out)
|
||||
|
||||
|
||||
@ucsinfer.command('qualify')
|
||||
def qualify():
|
||||
"""
|
||||
@@ -225,7 +239,6 @@ def finetune(ctx):
|
||||
logger.debug("FINETUNE mode")
|
||||
|
||||
|
||||
|
||||
@ucsinfer.command('evaluate')
|
||||
@click.argument('dataset', default='dataset/')
|
||||
@click.pass_context
|
||||
@@ -237,6 +250,5 @@ def evaluate(ctx, dataset, offset, limit):
|
||||
logger.warning("Model evaluation is not currently implemented")
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ucsinfer(obj={})
|
||||
|
||||
28
ucsinfer/import_csv.py
Normal file
28
ucsinfer/import_csv.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import csv
|
||||
|
||||
from typing import Generator
|
||||
|
||||
from .util import parse_ucs
|
||||
|
||||
|
||||
def csv_to_data(paths, description_key, filename_key,
                catid_list) -> Generator[tuple[str, str], None, None]:
    """
    Yield (sentence, class) tuples read from a list of CSV files.

    Each file in `paths` is read with `csv.DictReader`, so its first row is
    taken as the column headings. For every following row, the value in the
    `filename_key` column is handed to `parse_ucs` along with `catid_list`.
    When `parse_ucs` returns a truthy result, the row's `description_key`
    value is yielded together with that result's `cat_id`; rows for which
    it returns a falsy value are skipped.

    Because this is a generator, files are opened and validated lazily, as
    the caller iterates.

    :param paths: iterable of CSV file paths to read, in order
    :param description_key: heading of the column holding the sentence text
    :param filename_key: heading of the column holding the filename
    :param catid_list: passed through unchanged to `parse_ucs`
    :raises ValueError: if a file's heading row lacks `filename_key` or
        `description_key` (a file with no heading row at all also raises)
    """
    for path in paths:
        # The csv module asks for newline='' so that newlines embedded in
        # quoted fields are handled by the reader, not by the file object.
        with open(path, 'r', newline='') as f:
            records = csv.DictReader(f)

            # `fieldnames` is None when the file is empty; treat that as
            # "no columns" so the caller gets the explicit error below.
            fieldnames = records.fieldnames or ()

            # Explicit raises rather than `assert`: assertions are removed
            # when Python runs with -O, which would silently skip the check.
            if filename_key not in fieldnames:
                raise ValueError(
                    f"Filename key `{filename_key}` not present in file "
                    f"{path}")

            if description_key not in fieldnames:
                raise ValueError(
                    f"Description key `{description_key}` not present in "
                    f"file {path}")

            for record in records:
                ucs_comps = parse_ucs(record[filename_key], catid_list)
                if ucs_comps:
                    yield (record[description_key], ucs_comps.cat_id)
|
||||
Reference in New Issue
Block a user