dataset writing in HF format
@@ -117,12 +117,14 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
 
 
 @ucsinfer.command('gather')
-@click.option('--outfile', type=click.File(mode='w', encoding='utf8'),
-              default='dataset.csv', show_default=True)
+@click.option('--outfile', default='dataset', show_default=True)
+@click.option('--ucs-data', flag_value=True, help="Create a dataset based "
+              "on the UCS category explanations and synonymns (PATHS will "
+              "be ignored.)")
 @click.argument('paths', nargs=-1)
-def gather(paths, outfile):
+def gather(paths, outfile, ucs_data):
     """
-    Scan files to build a training dataset at PATH
+    Scan files to build a training dataset
 
     The `gather` is used to build a training dataset for finetuning the
     selected model. Description sentences and UCS categories are collected from
@@ -141,9 +143,13 @@ def gather(paths, outfile):
 
     logger.debug(f"Loading category list...")
     ucs = load_ucs()
-    catid_list = [cat.catid for cat in ucs]
 
     scan_list = []
+    catid_list = [cat.catid for cat in ucs]
+
+    if ucs_data:
+        paths = []
+
     for path in paths:
         logger.info(f"Scanning directory {path}...")
         for dirpath, _, filenames in os.walk(path):
@@ -166,7 +172,16 @@ def gather(paths, outfile):
             assert comps
             yield comps.fx_name, str(pair[0])
 
-    dataset = build_sentence_class_dataset(scan_metadata())
+    def ucs_metadata():
+        for cat in ucs:
+            yield cat.explanations, cat.catid
+            yield ", ".join(cat.synonymns), cat.catid
+
+    if ucs_data:
+        dataset = build_sentence_class_dataset(ucs_metadata(), catid_list)
+    else:
+        dataset = build_sentence_class_dataset(scan_metadata(), catid_list)
+
     dataset.save_to_disk(outfile)
 
 
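With this change, `gather` can build the training set either by scanning sound libraries or, with the new `--ucs-data` flag, directly from the UCS category explanations and synonyms. A hypothetical invocation, assuming the package installs a `ucsinfer` console script (the library path shown is a placeholder):

    # Build a dataset from the UCS taxonomy itself; any PATHS are ignored
    ucsinfer gather --ucs-data

    # Build a dataset from description metadata gathered by scanning files
    ucsinfer gather --outfile dataset ~/SoundLibrary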
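`build_sentence_class_dataset` is the project's own helper and is not shown in this diff. A minimal sketch of what it plausibly does, assuming it materializes the (sentence, catid) pairs into a Hugging Face `datasets.Dataset` whose label column is a `ClassLabel` over the full `catid_list` (the column names and body below are illustrative, not the actual implementation):

    from datasets import ClassLabel, Dataset, Features, Value

    def build_sentence_class_dataset(pairs, catid_list):
        # Fix the label space to every UCS catid, not just those observed,
        # so the label mapping is identical for --ucs-data and scanned runs.
        label = ClassLabel(names=list(catid_list))
        sentences, labels = [], []
        for sentence, catid in pairs:
            sentences.append(sentence)
            labels.append(label.str2int(catid))
        features = Features({"sentence": Value("string"), "label": label})
        return Dataset.from_dict(
            {"sentence": sentences, "label": labels}, features=features)

Passing `catid_list` explicitly (the second argument introduced in this commit) is what lets both generators produce datasets with a stable label mapping.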
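Since `save_to_disk` now writes a Hugging Face dataset directory rather than a CSV, downstream code can read it back with `datasets.load_from_disk`, assuming the object being saved is a `datasets.Dataset`; for example:

    from datasets import load_from_disk

    ds = load_from_disk("dataset")   # the new --outfile default
    print(ds.features)               # column names and types
    print(ds[0])                     # first training example

If the helper ever saves a `DatasetDict` of splits instead of a single `Dataset`, `load_from_disk` returns that dict and indexing would be per split.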