diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py
index 85f4f13..5ebf52c 100644
--- a/ucsinfer/__main__.py
+++ b/ucsinfer/__main__.py
@@ -1,6 +1,9 @@
 import os
 import logging
 from itertools import chain
+import csv
+
+from typing import Generator
 
 import click
 
@@ -49,8 +52,8 @@ def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
         stream_handler.setLevel(logging.WARNING)
 
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-    logger.info("Setting TOKENIZERS_PARALLELISM environment variable to `false"
-                " explicitly")
+    logger.info("Setting TOKENIZERS_PARALLELISM environment variable to "
+                "`false` explicitly")
 
     ctx.ensure_object(dict)
     ctx.obj['model_cache'] = not no_model_cache
@@ -132,7 +135,30 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
                     print(f"Renaming {path} \n to {new_path}")
                     os.rename(path, new_path)
                 break
-
+
+def csv_to_data(paths, description_key, filename_key, catid_list) -> Generator[tuple[str, str], None, None]:
+    """
+    Accepts a list of paths and returns an iterator of (sentence, class)
+    tuples.
+    """
+    for path in paths:
+        with open(path, 'r') as f:
+            records = csv.DictReader(f)
+
+            assert filename_key in records.fieldnames, \
+                (f"Filename key `{filename_key}` not present in file "
+                 f"{path}")
+
+            assert description_key in records.fieldnames, \
+                (f"Description key `{description_key}` not present in "
+                 f"file {path}")
+
+            for record in records:
+                ucs_comps = parse_ucs(record[filename_key], catid_list)
+                if ucs_comps:
+                    yield (record[description_key], ucs_comps.cat_id)
+
+
 @ucsinfer.command('csv')
 @click.option('--filename-col', default="FileName",
               help="Heading or index of the column containing filenames",
@@ -143,7 +169,7 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
 @click.option('--out', default='dataset/', show_default=True)
 @click.argument('paths', nargs=-1)
 @click.pass_context
-def csv(ctx, paths, out, filename_col, description_col):
+def import_csv(ctx, paths: list[str], out, filename_col, description_col):
     """
     Scan training data from CSV files
 
@@ -152,16 +178,29 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
     file system it builds a dataset from descriptions and UCS filenames in
     columns of a CSV file.
""" - pass + logger.debug("CSV mode") + + logger.debug(f"Loading category list...") + ucs = load_ucs(full_ucs=ctx.obj['complete_ucs']) + + catid_list = [cat.catid for cat in ucs] + + logger.info("Building dataset from csv...") + + dataset = build_sentence_class_dataset( + chain(csv_to_data(paths, description_col, filename_col, catid_list), + ucs_definitions_generator(ucs)),catid_list) + + logger.info(f"Saving dataset to disk at {out}") + print_dataset_stats(dataset, catid_list) + dataset.save_to_disk(out) + @ucsinfer.command('gather') @click.option('--out', default='dataset/', show_default=True) -# @click.option('--ucs-data', flag_value=True, help="Create a dataset based " -# "on the UCS category explanations and synonymns (PATHS will " -# "be ignored.)") @click.argument('paths', nargs=-1) @click.pass_context -def gather(ctx, paths, out, ucs_data): +def gather(ctx, paths, out): """ Scan training data from audio files @@ -184,20 +223,16 @@ def gather(ctx, paths, out, ucs_data): scan_list: list[tuple[str,str]] = [] catid_list = [cat.catid for cat in ucs] - if ucs_data: - logger.info('Creating dataset for UCS categories instead of from PATH') - paths = [] - for path in paths: scan_list += walk_path(path, catid_list) logger.info(f"Found {len(scan_list)} files to process.") - logger.info("Building dataset...") + logger.info("Building dataset files...") dataset = build_sentence_class_dataset( - chain(scan_metadata(scan_list, catid_list), - ucs_definitions_generator(ucs)), + chain(scan_metadata(scan_list, catid_list), + ucs_definitions_generator(ucs)), catid_list) logger.info(f"Saving dataset to disk at {out}")