diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py
index b669c4c..5f55fce 100644
--- a/ucsinfer/__main__.py
+++ b/ucsinfer/__main__.py
@@ -117,12 +117,14 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
 
 
 @ucsinfer.command('gather')
-@click.option('--outfile', type=click.File(mode='w', encoding='utf8'),
-              default='dataset.csv', show_default=True)
+@click.option('--outfile', default='dataset', show_default=True)
+@click.option('--ucs-data', flag_value=True, help="Create a dataset based "
+              "on the UCS category explanations and synonyms (PATHS will "
+              "be ignored).")
 @click.argument('paths', nargs=-1)
-def gather(paths, outfile):
+def gather(paths, outfile, ucs_data):
     """
-    Scan files to build a training dataset at PATH
+    Scan files to build a training dataset
 
     The `gather` is used to build a training dataset for finetuning the
     selected model. Description sentences and UCS categories are collected from
@@ -141,9 +143,13 @@ def gather(paths, outfile):
     logger.debug(f"Loading category list...")
     ucs = load_ucs()
 
-    catid_list = [cat.catid for cat in ucs]
     scan_list = []
 
+    catid_list = [cat.catid for cat in ucs]
+
+    if ucs_data:
+        paths = []
+
     for path in paths:
         logger.info(f"Scanning directory {path}...")
         for dirpath, _, filenames in os.walk(path):
@@ -166,7 +172,16 @@ def gather(paths, outfile):
             assert comps
             yield comps.fx_name, str(pair[0])
 
-    dataset = build_sentence_class_dataset(scan_metadata())
+    def ucs_metadata():
+        for cat in ucs:
+            yield cat.explanations, cat.catid
+            yield ", ".join(cat.synonymns), cat.catid
+
+    if ucs_data:
+        dataset = build_sentence_class_dataset(ucs_metadata(), catid_list)
+    else:
+        dataset = build_sentence_class_dataset(scan_metadata(), catid_list)
+
     dataset.save_to_disk(outfile)
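
Note on the new `catid_list` argument: `build_sentence_class_dataset` is not shown in this diff, so the sketch below is only an assumption about its shape, inferred from the two call sites above and from `dataset.save_to_disk`, which suggests a Hugging Face `datasets.Dataset`. It illustrates why passing the full `catid_list` matters: every UCS category gets a stable integer label, including categories that never appear in the scanned files or in the `--ucs-data` records.

```python
# Hypothetical sketch, not the actual ucsinfer implementation.
# Assumes the Hugging Face `datasets` library (inferred from `save_to_disk`).
from typing import Iterable, Tuple

from datasets import ClassLabel, Dataset, Features, Value


def build_sentence_class_dataset(records: Iterable[Tuple[str, str]],
                                 catid_list: list) -> Dataset:
    """Build a (sentence, label) dataset over a fixed list of UCS catids."""
    features = Features({
        'sentence': Value('string'),
        # Declaring the full catid_list keeps label ids stable even for
        # categories that never occur in `records`.
        'label': ClassLabel(names=catid_list),
    })
    sentences, labels = [], []
    for sentence, catid in records:
        sentences.append(sentence)
        labels.append(catid_list.index(catid))
    return Dataset.from_dict({'sentence': sentences, 'label': labels},
                             features=features)
```

With the new flag, a dataset built from the UCS definitions themselves would be gathered without any PATHS, e.g. `ucsinfer gather --ucs-data --outfile dataset`.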