Autopep
This commit is contained in:
@@ -22,7 +22,7 @@ def recommend():
|
||||
"""
|
||||
Infer a UCS category for a text description
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
@ucsinfer.command('gather')
|
||||
@@ -36,7 +36,7 @@ def gather(paths, outfile):
|
||||
types = ['.wav', '.flac']
|
||||
table = csv.writer(outfile)
|
||||
print(f"Loading category list...")
|
||||
catid_list = [cat.catid for cat in load_ucs()]
|
||||
catid_list = [cat.catid for cat in load_ucs()]
|
||||
|
||||
scan_list = []
|
||||
for path in paths:
|
||||
@@ -47,12 +47,12 @@ def gather(paths, outfile):
|
||||
if ext in types and \
|
||||
(ucs_components := parse_ucs(root, catid_list)) and \
|
||||
not filename.startswith("._"):
|
||||
scan_list.append((ucs_components.cat_id,
|
||||
scan_list.append((ucs_components.cat_id,
|
||||
os.path.join(dirpath, filename)))
|
||||
|
||||
print(f"Found {len(scan_list)} files to process.")
|
||||
|
||||
for pair in tqdm.tqdm(scan_list, unit='files',file=sys.stderr):
|
||||
for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
|
||||
if desc := ffmpeg_description(pair[1]):
|
||||
table.writerow([pair[0], desc])
|
||||
|
||||
@@ -62,13 +62,13 @@ def finetune():
|
||||
"""
|
||||
Fine-tune a model with training data
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
@ucsinfer.command('evaluate')
|
||||
@click.option('--offset', type=int, default=0)
|
||||
@click.option('--limit', type=int, default=-1)
|
||||
@click.argument('dataset', type=click.File('r', encoding='utf8'),
|
||||
@click.argument('dataset', type=click.File('r', encoding='utf8'),
|
||||
default='dataset.csv')
|
||||
def evaluate(dataset, offset, limit):
|
||||
"""
|
||||
@@ -82,7 +82,7 @@ def evaluate(dataset, offset, limit):
|
||||
for i, row in enumerate(tqdm.tqdm(reader)):
|
||||
if i < offset:
|
||||
continue
|
||||
|
||||
|
||||
if limit > 0 and i >= limit + offset:
|
||||
break
|
||||
|
||||
@@ -107,33 +107,33 @@ def evaluate(dataset, offset, limit):
|
||||
|
||||
miss_counts = []
|
||||
for cat in cats:
|
||||
miss_counts.append((cat, len([x for x in results \
|
||||
if x['catid'] == cat and x['result'] == 'MISS'])))
|
||||
miss_counts.append((cat, len([x for x in results
|
||||
if x['catid'] == cat and x['result'] == 'MISS'])))
|
||||
|
||||
miss_counts = sorted(miss_counts, key=lambda x: x[1])
|
||||
|
||||
print(f" === RESULTS === ")
|
||||
|
||||
|
||||
table = [
|
||||
["Total records in sample:", f"{total}"],
|
||||
["Top Result:", f"{total_top}",
|
||||
f"{float(total_top)/float(total):.2%}"],
|
||||
["Top 5 Result:", f"{total_top_5}",
|
||||
f"{float(total_top_5)/float(total):.2%}"],
|
||||
["Top 10 Result:", f"{total_top_10}",
|
||||
f"{float(total_top_10)/float(total):.2%}"],
|
||||
SEPARATING_LINE,
|
||||
["UCS category count:", f"{len(ctx.catlist)}"],
|
||||
["Total categories in sample:", f"{total_cats}",
|
||||
f"{float(total_cats)/float(len(ctx.catlist)):.2%}"],
|
||||
[f"Most missed category ({miss_counts[-1][0]}):",
|
||||
f"{miss_counts[-1][1]}",
|
||||
f"{float(miss_counts[-1][1])/float(total):.2%}"]
|
||||
]
|
||||
["Total records in sample:", f"{total}"],
|
||||
["Top Result:", f"{total_top}",
|
||||
f"{float(total_top)/float(total):.2%}"],
|
||||
["Top 5 Result:", f"{total_top_5}",
|
||||
f"{float(total_top_5)/float(total):.2%}"],
|
||||
["Top 10 Result:", f"{total_top_10}",
|
||||
f"{float(total_top_10)/float(total):.2%}"],
|
||||
SEPARATING_LINE,
|
||||
["UCS category count:", f"{len(ctx.catlist)}"],
|
||||
["Total categories in sample:", f"{total_cats}",
|
||||
f"{float(total_cats)/float(len(ctx.catlist)):.2%}"],
|
||||
[f"Most missed category ({miss_counts[-1][0]}):",
|
||||
f"{miss_counts[-1][1]}",
|
||||
f"{float(miss_counts[-1][1])/float(total):.2%}"]
|
||||
]
|
||||
|
||||
print(tabulate(table, headers=['', 'n', 'pct']))
|
||||
|
||||
print(tabulate(table, headers=['','n','pct']))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
||||
|
||||
|
Reference in New Issue
Block a user