Compare commits

...

3 Commits

Author SHA1 Message Date
Jamie Hardt
7c591e9dbb Merge branch 'master' of https://git.squad51.us/jamie/ucsinfer 2025-08-30 20:36:47 -07:00
Jamie Hardt
3009d3831e Elaborated rename function to recommend 2025-08-30 20:36:36 -07:00
Jamie Hardt
47829c5427 Added rename function to recommend 2025-08-30 20:14:09 -07:00
2 changed files with 67 additions and 9 deletions

View File

@@ -1,6 +1,7 @@
import os
import sys
import csv
from re import match
from sentence_transformers import SentenceTransformer
import tqdm
@@ -14,17 +15,33 @@ from .util import ffmpeg_description, parse_ucs
def recommend_text(text: str, ctx: InferenceContext):
    """
    Return the ranked classification of `text` produced by `ctx`

    This is a thin wrapper: the result is whatever
    `ctx.classify_text_ranked(text)` returns, passed through unchanged.
    """
    ranked = ctx.classify_text_ranked(text)
    return ranked
def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
                         interactive_rename: bool):
    """
    Print the ranked recommendations for `text` and optionally rename `path`

    The recommendations come from `ctx.classify_text_ranked(text)` and are
    listed with their index and the category/subcategory reported by
    `ctx.lookup_category`.

    When `interactive_rename` is true and `path` is not None, the user is
    prompted once:

    - a number selects that recommendation, and the file at `path` is renamed
      in place to `<recommendation>_<original basename>`;
    - `t <text>` re-runs the recommendation with the alternate text and
      prompts again;
    - anything else (including a bare return) skips the file.

    :param path: file the text was read from, or None when the text was
        given directly; no prompt is shown without a path.
    :param text: description to classify.
    :param ctx: inference context used for classification and lookup.
    :param interactive_rename: whether to prompt for a rename after printing.
    """
    recs = ctx.classify_text_ranked(text)
    print("----------")
    if path:
        print(f"Path: {path}")
    print(f"Text: {text or '<None>'}")
    for i, r in enumerate(recs):
        cat, subcat, _ = ctx.lookup_category(r)
        print(f"- {i}: {r} ({cat}-{subcat})")

    if not interactive_rename or path is None:
        return

    try:
        response = input("Enter number, t <text> for alternate text, or "
                         "return to skip: ")
    except EOFError:
        # stdin was closed (e.g. Ctrl-D or piped input ran out): treat it
        # like pressing return instead of aborting the whole run.
        return

    if m := match(r'^([0-9]+)', response):
        selection = int(m.group(1))
        if 0 <= selection < len(recs):
            new_name = recs[selection] + '_' + os.path.basename(path)
            new_path = os.path.join(os.path.dirname(path), new_name)
            try:
                os.rename(path, new_path)
            except OSError as e:
                # One failed rename should not stop the remaining paths
                # from being processed.
                print(f"Could not rename {path}: {e}", file=sys.stderr)
        else:
            # Tell the user why nothing happened rather than skipping
            # silently.
            print(f"Selection {selection} is out of range, skipping")
    elif m := match(r'^t (.*)', response):
        print("searching for new matches")
        text = m.group(1)
        print_recommendation(path, text, ctx, True)
@click.group(epilog="For more information see "
"<https://git.squad51.us/jamie/ucsinfer>")
@@ -40,12 +57,16 @@ def ucsinfer():
@click.option('--text', default=None,
help="Recommend a category for given text instead of reading "
"from a file")
@click.argument('paths', nargs=-1)
@click.argument('paths', nargs=-1, metavar='<paths>')
@click.option('--model', type=str, metavar="<model-name>",
default="paraphrase-multilingual-mpnet-base-v2",
show_default=True,
help="Select the sentence_transformer model to use")
def recommend(text, paths, model):
@click.option('--interactive','-i', flag_value=True, default=False,
help="After processing each path in <paths>, prompt for a "
"recommendation to accept, and then prepend the selection to "
"the file name.")
def recommend(text, paths, model, interactive):
"""
Infer a UCS category for a text description
@@ -59,16 +80,16 @@ def recommend(text, paths, model):
ctx = InferenceContext(m, model)
if text is not None:
print_recommendation(None, text, ctx)
print_recommendation(None, text, ctx, interactive_rename=False)
for path in paths:
text = ffmpeg_description(path)
if text:
print_recommendation(path, text, ctx)
print_recommendation(path, text, ctx, interactive)
else:
filename = os.path.basename(path)
print_recommendation(path, filename, ctx)
print_recommendation(path, filename, ctx, interactive)
@ucsinfer.command('gather')

View File

@@ -1,5 +1,6 @@
import subprocess
import json
import os
from typing import NamedTuple, Optional
from re import match
@@ -59,6 +60,20 @@ class UcsNameComponents(NamedTuple):
return False
def normalize_ucs(basename: str, catid_list: list[str]):
    """
    Take any filename and normalize it into the UCS system

    NOTE: this is an unfinished placeholder. The root name is parsed with
    `parse_ucs` but the result is not used yet, and the returned name always
    has the fixed root "aaa" followed by the original extension.

    :param basename: file name (with extension) to normalize.
    :param catid_list: category IDs, passed through to `parse_ucs`.
    :returns: the placeholder file name, keeping `basename`'s extension.
    """
    n, ext = os.path.splitext(basename)
    r = parse_ucs(n, catid_list)
    if r:
        # TODO: root name already parses as UCS; rebuild it from `r`.
        pass
    else:
        # TODO: root name is not UCS; derive components for it.
        pass

    # os.path.splitext keeps the leading "." on `ext` (or returns "" when
    # there is no extension), so no extra dot is inserted here.
    return f"aaa{ext}"
def build_ucs(components: UcsNameComponents, extension: str) -> str:
    """
    Build a UCS filename

    The name is assembled from underscore-separated segments:

    - `cat_id`, with `-user_cat` appended when `user_cat` is set;
    - `fx_name`, with `vendor_cat-` prepended when `vendor_cat` is set;
    - `creator`, `source` and `user_data`, each only when set.

    :param components: the name components; must pass
        `components.validate()`.
    :param extension: file extension, with or without its leading dot.
    :returns: the segments joined with "_" followed by "." and the extension.
    :raises AssertionError: if `components.validate()` is false.
    """
    assert components.validate(), \
        "UcsNameComponents contains invalid characters"

    cat_segment = components.cat_id
    if components.user_cat:
        cat_segment += f"-{components.user_cat}"

    name_segment = components.fx_name
    if components.vendor_cat:
        name_segment = f"{components.vendor_cat}-{components.fx_name}"

    all_comps = [cat_segment, name_segment]

    # NOTE(review): an unset segment is skipped rather than left empty, so
    # e.g. a source without a creator lands in the creator's position.
    # Confirm this round-trips through parse_ucs as intended.
    if components.creator:
        all_comps += [components.creator]

    if components.source:
        all_comps += [components.source]

    if components.user_data:
        all_comps += [components.user_data]

    root_name = "_".join(all_comps)

    # Accept ".wav" as well as "wav" (os.path.splitext returns the dotted
    # form) so the result never contains a doubled dot.
    return root_name + '.' + extension.removeprefix('.')
def parse_ucs(rootname: str,