Cleaned up code and refactored to new file
This commit is contained in:
@@ -1,13 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
import csv
|
|
||||||
|
|
||||||
from typing import Generator
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .inference import InferenceContext, load_ucs
|
from .inference import InferenceContext, load_ucs
|
||||||
|
from .import_csv import csv_to_data
|
||||||
from .gather import (build_sentence_class_dataset, print_dataset_stats,
|
from .gather import (build_sentence_class_dataset, print_dataset_stats,
|
||||||
ucs_definitions_generator, scan_metadata, walk_path)
|
ucs_definitions_generator, scan_metadata, walk_path)
|
||||||
from .recommend import print_recommendation
|
from .recommend import print_recommendation
|
||||||
@@ -18,10 +16,11 @@ logger.setLevel(logging.DEBUG)
|
|||||||
stream_handler = logging.StreamHandler()
|
stream_handler = logging.StreamHandler()
|
||||||
stream_handler.setLevel(logging.WARN)
|
stream_handler.setLevel(logging.WARN)
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
'%(asctime)s. %(levelname)s %(name)s: %(message)s')
|
'%(asctime)s. %(levelname)s %(name)s: %(message)s')
|
||||||
stream_handler.setFormatter(formatter)
|
stream_handler.setFormatter(formatter)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
|
||||||
@click.group(epilog="For more information see "
|
@click.group(epilog="For more information see "
|
||||||
"<https://git.squad51.us/jamie/ucsinfer>")
|
"<https://git.squad51.us/jamie/ucsinfer>")
|
||||||
@click.option('--verbose', '-v', flag_value=True, help='Verbose output')
|
@click.option('--verbose', '-v', flag_value=True, help='Verbose output')
|
||||||
@@ -46,8 +45,8 @@ def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
|
|||||||
else:
|
else:
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings(
|
warnings.filterwarnings(
|
||||||
action='ignore', module='torch', category=FutureWarning,
|
action='ignore', module='torch', category=FutureWarning,
|
||||||
message=r"`encoder_attention_mask` is deprecated.*")
|
message=r"`encoder_attention_mask` is deprecated.*")
|
||||||
|
|
||||||
stream_handler.setLevel(logging.WARNING)
|
stream_handler.setLevel(logging.WARNING)
|
||||||
|
|
||||||
@@ -77,7 +76,7 @@ def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
|
|||||||
help="Recommend a category for given text instead of reading "
|
help="Recommend a category for given text instead of reading "
|
||||||
"from a file")
|
"from a file")
|
||||||
@click.argument('paths', nargs=-1, metavar='<paths>')
|
@click.argument('paths', nargs=-1, metavar='<paths>')
|
||||||
@click.option('--interactive','-i', flag_value=True, default=False,
|
@click.option('--interactive', '-i', flag_value=True, default=False,
|
||||||
help="After processing each path in <paths>, prompt for a "
|
help="After processing each path in <paths>, prompt for a "
|
||||||
"recommendation to accept, and then prepend the selection to "
|
"recommendation to accept, and then prepend the selection to "
|
||||||
"the file name.")
|
"the file name.")
|
||||||
@@ -121,7 +120,8 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
|
|||||||
text = os.path.basename(path)
|
text = os.path.basename(path)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
retval = print_recommendation(path, text, inference_ctx, interactive)
|
retval = print_recommendation(
|
||||||
|
path, text, inference_ctx, interactive)
|
||||||
if not retval:
|
if not retval:
|
||||||
break
|
break
|
||||||
if retval[0] is False:
|
if retval[0] is False:
|
||||||
@@ -136,28 +136,6 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
|
|||||||
os.rename(path, new_path)
|
os.rename(path, new_path)
|
||||||
break
|
break
|
||||||
|
|
||||||
def csv_to_data(paths, description_key, filename_key, catid_list) -> Generator[tuple[str, str], None, None]:
|
|
||||||
"""
|
|
||||||
Accepts a list of paths and returns an iterator of (sentence, class)
|
|
||||||
tuples.
|
|
||||||
"""
|
|
||||||
for path in paths:
|
|
||||||
with open(path, 'r') as f:
|
|
||||||
records = csv.DictReader(f)
|
|
||||||
|
|
||||||
assert filename_key in records.fieldnames, \
|
|
||||||
(f"Filename key `{filename_key}` not present in file "
|
|
||||||
"{path}")
|
|
||||||
|
|
||||||
assert description_key in records.fieldnames, \
|
|
||||||
(f"Description key `{description_key}` not present in "
|
|
||||||
"file {path}")
|
|
||||||
|
|
||||||
for record in records:
|
|
||||||
ucs_comps = parse_ucs(record[filename_key], catid_list)
|
|
||||||
if ucs_comps:
|
|
||||||
yield (record[description_key], ucs_comps.cat_id)
|
|
||||||
|
|
||||||
|
|
||||||
@ucsinfer.command('csv')
|
@ucsinfer.command('csv')
|
||||||
@click.option('--filename-col', default="FileName",
|
@click.option('--filename-col', default="FileName",
|
||||||
@@ -188,8 +166,8 @@ def import_csv(ctx, paths: list[str], out, filename_col, description_col):
|
|||||||
logger.info("Building dataset from csv...")
|
logger.info("Building dataset from csv...")
|
||||||
|
|
||||||
dataset = build_sentence_class_dataset(
|
dataset = build_sentence_class_dataset(
|
||||||
chain(csv_to_data(paths, description_col, filename_col, catid_list),
|
chain(csv_to_data(paths, description_col, filename_col, catid_list),
|
||||||
ucs_definitions_generator(ucs)),catid_list)
|
ucs_definitions_generator(ucs)), catid_list)
|
||||||
|
|
||||||
logger.info(f"Saving dataset to disk at {out}")
|
logger.info(f"Saving dataset to disk at {out}")
|
||||||
print_dataset_stats(dataset, catid_list)
|
print_dataset_stats(dataset, catid_list)
|
||||||
@@ -220,7 +198,7 @@ def gather(ctx, paths, out):
|
|||||||
logger.debug(f"Loading category list...")
|
logger.debug(f"Loading category list...")
|
||||||
ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
|
ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
|
||||||
|
|
||||||
scan_list: list[tuple[str,str]] = []
|
scan_list: list[tuple[str, str]] = []
|
||||||
catid_list = [cat.catid for cat in ucs]
|
catid_list = [cat.catid for cat in ucs]
|
||||||
|
|
||||||
for path in paths:
|
for path in paths:
|
||||||
@@ -231,14 +209,15 @@ def gather(ctx, paths, out):
|
|||||||
logger.info("Building dataset files...")
|
logger.info("Building dataset files...")
|
||||||
|
|
||||||
dataset = build_sentence_class_dataset(
|
dataset = build_sentence_class_dataset(
|
||||||
chain(scan_metadata(scan_list, catid_list),
|
chain(scan_metadata(scan_list, catid_list),
|
||||||
ucs_definitions_generator(ucs)),
|
ucs_definitions_generator(ucs)),
|
||||||
catid_list)
|
catid_list)
|
||||||
|
|
||||||
logger.info(f"Saving dataset to disk at {out}")
|
logger.info(f"Saving dataset to disk at {out}")
|
||||||
print_dataset_stats(dataset, catid_list)
|
print_dataset_stats(dataset, catid_list)
|
||||||
dataset.save_to_disk(out)
|
dataset.save_to_disk(out)
|
||||||
|
|
||||||
|
|
||||||
@ucsinfer.command('qualify')
|
@ucsinfer.command('qualify')
|
||||||
def qualify():
|
def qualify():
|
||||||
"""
|
"""
|
||||||
@@ -260,7 +239,6 @@ def finetune(ctx):
|
|||||||
logger.debug("FINETUNE mode")
|
logger.debug("FINETUNE mode")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ucsinfer.command('evaluate')
|
@ucsinfer.command('evaluate')
|
||||||
@click.argument('dataset', default='dataset/')
|
@click.argument('dataset', default='dataset/')
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@@ -272,6 +250,5 @@ def evaluate(ctx, dataset, offset, limit):
|
|||||||
logger.warning("Model evaluation is not currently implemented")
|
logger.warning("Model evaluation is not currently implemented")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ucsinfer(obj={})
|
ucsinfer(obj={})
|
||||||
|
|||||||
28
ucsinfer/import_csv.py
Normal file
28
ucsinfer/import_csv.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import csv
|
||||||
|
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
from .util import parse_ucs
|
||||||
|
|
||||||
|
|
||||||
|
def csv_to_data(paths, description_key, filename_key, catid_list) -> Generator[tuple[str, str], None, None]:
|
||||||
|
"""
|
||||||
|
Accepts a list of paths and returns an iterator of (sentence, class)
|
||||||
|
tuples.
|
||||||
|
"""
|
||||||
|
for path in paths:
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
records = csv.DictReader(f)
|
||||||
|
|
||||||
|
assert filename_key in records.fieldnames, \
|
||||||
|
(f"Filename key `{filename_key}` not present in file "
|
||||||
|
"{path}")
|
||||||
|
|
||||||
|
assert description_key in records.fieldnames, \
|
||||||
|
(f"Description key `{description_key}` not present in "
|
||||||
|
"file {path}")
|
||||||
|
|
||||||
|
for record in records:
|
||||||
|
ucs_comps = parse_ucs(record[filename_key], catid_list)
|
||||||
|
if ucs_comps:
|
||||||
|
yield (record[description_key], ucs_comps.cat_id)
|
||||||
Reference in New Issue
Block a user