Committing module
This commit is contained in:
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
.ipynb_*
|
||||||
|
poetry.lock
|
||||||
|
*.pyc
|
||||||
|
.DS_Store
|
6
.gitmodules
vendored
Normal file
6
.gitmodules
vendored
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
[submodule "ucs-community"]
|
||||||
|
path = ucsinfer/ucs-community
|
||||||
|
url = https://github.com/iluvcapra/ucs-community.git
|
||||||
|
[submodule "ucsinfer/ucs-community"]
|
||||||
|
path = ucsinfer/ucs-community
|
||||||
|
url = https://github.com/iluvcapra/ucs-community.git
|
162
notebooks/00_Infer.ipynb
Normal file
162
notebooks/00_Infer.ipynb
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "1a3ea8e9-175b-4592-9341-cdba151c6fc2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"categories = [\n",
|
||||||
|
" {\n",
|
||||||
|
" \"name\": \"Automobile\",\n",
|
||||||
|
" \"description\": \"Topics related to vehicles such as cars, trucks, and their brands.\",\n",
|
||||||
|
" \"keywords\": [\"Mazda\", \"Toyota\", \"SUV\", \"sedan\", \"pickup\"]\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" \"name\": \"Firearms\",\n",
|
||||||
|
" \"description\": \"Topics related to guns, rifles, pistols, ammunition, and other weapons.\",\n",
|
||||||
|
" \"keywords\": [\"Winchester\", \"Glock\", \"rifle\", \"bullet\", \"shotgun\"]\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" \"name\": \"Computers\",\n",
|
||||||
|
" \"description\": \"Topics involving computer hardware and software, such as hard drives, CPUs, and laptops.\",\n",
|
||||||
|
" \"keywords\": [\"Winchester\", \"CPU\", \"hard drive\", \"RAM\", \"SSD\", \"motherboard\"]\n",
|
||||||
|
" },\n",
|
||||||
|
"]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "d181860c-dd33-4327-b9d9-29297170558d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/Users/j/Library/Caches/pypoetry/virtualenvs/ucsinfer2-yBtBMMP2-py3.13/lib/python3.13/site-packages/torch/nn/modules/module.py:1762: FutureWarning: `encoder_attention_mask` is deprecated and will be removed in version 4.55.0 for `BertSdpaSelfAttention.forward`.\n",
|
||||||
|
" return forward_call(*args, **kwargs)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from sentence_transformers import SentenceTransformer\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from numpy.linalg import norm\n",
|
||||||
|
"\n",
|
||||||
|
"model = SentenceTransformer(\"all-MiniLM-L6-v2\")\n",
|
||||||
|
"\n",
|
||||||
|
"def build_category_embedding(cat_info):\n",
|
||||||
|
" components = [cat_info[\"name\"], cat_info[\"description\"]] + cat_info.get('keywords', [])\n",
|
||||||
|
" composite_text = \". \".join(components)\n",
|
||||||
|
" return model.encode(composite_text, convert_to_numpy=True)\n",
|
||||||
|
"\n",
|
||||||
|
"# Embed all categories\n",
|
||||||
|
"\n",
|
||||||
|
"for info in categories:\n",
|
||||||
|
" info['embedding'] = build_category_embedding(info)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "00842b35-04b3-42db-b352-b78154b65818",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# def cosine_similarity(a, b):\n",
|
||||||
|
"# return np.dot(a, b) / (norm(a) * norm(b))\n",
|
||||||
|
"\n",
|
||||||
|
"def classify_text(text, categories):\n",
|
||||||
|
" text_embedding = model.encode(text, convert_to_numpy=True)\n",
|
||||||
|
"\n",
|
||||||
|
" print(f\"Text: {text}\")\n",
|
||||||
|
" sim = model.similarity(text_embedding, [info['embedding'] for info in categories])\n",
|
||||||
|
" print(sim)\n",
|
||||||
|
" maxind = np.argmax(sim)\n",
|
||||||
|
" print(f\" -> Category: {categories[maxind]['name']}\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "deb8cd47-b09b-445c-a688-a7e0bf0d7f2e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Text: I took my Winchester to the shooting range yesterday.\n",
|
||||||
|
"tensor([[-0.0456, 0.3874, 0.0935]])\n",
|
||||||
|
" -> Category: Firearms\n",
|
||||||
|
"Text: I bought a new Mazda with an automatic transmission.\n",
|
||||||
|
"tensor([[0.3483, 0.0454, 0.0285]])\n",
|
||||||
|
" -> Category: Automobile\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/Users/j/Library/Caches/pypoetry/virtualenvs/ucsinfer2-yBtBMMP2-py3.13/lib/python3.13/site-packages/sentence_transformers/util.py:55: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:257.)\n",
|
||||||
|
" a = torch.tensor(a)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Text: My old Winchester hard drive finally failed.\n",
|
||||||
|
"tensor([[-0.0305, 0.2024, 0.3047]])\n",
|
||||||
|
" -> Category: Computers\n",
|
||||||
|
"Text: Keys clicking, typing\n",
|
||||||
|
"tensor([[0.0957, 0.1107, 0.1531]])\n",
|
||||||
|
" -> Category: Computers\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"text1 = \"I took my Winchester to the shooting range yesterday.\"\n",
|
||||||
|
"text2 = \"I bought a new Mazda with an automatic transmission.\"\n",
|
||||||
|
"text3 = \"My old Winchester hard drive finally failed.\"\n",
|
||||||
|
"text4 = \"Keys clicking, typing\"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"for text in [text1, text2, text3, text4]:\n",
|
||||||
|
" classify_text(text, categories)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "1831d482-2596-439c-9641-eb91fcae73c6",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.13.5"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
225
notebooks/01_UCS Embedding Classification.ipynb
Normal file
225
notebooks/01_UCS Embedding Classification.ipynb
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "eb04426a-cfb8-4f6f-9348-4308438fb9a5",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Finding UCS categories with sentence embedding\n",
|
||||||
|
"\n",
|
||||||
|
"In this brief example we use sentence embedding to decide the UCS category for a sound, based on a text description.\n",
|
||||||
|
"\n",
|
||||||
|
"## Step 1: Creating embeddings for UCS categories\n",
|
||||||
|
"\n",
|
||||||
|
"We first select a SentenceTransformer model and establish a method for generating embeddings that correspond with each category by using the _Explanations_, _Category_, _SubCategory_, and _Synonyms_ from the UCS spreadsheet.\n",
|
||||||
|
"\n",
|
||||||
|
"`model.encode` is a slow process so we can write this as an async function so the client can parallelize it if it wants to."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 23,
|
||||||
|
"id": "ef63fc07-c0d7-4616-9be1-1f0c2f275a69",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import json\n",
|
||||||
|
"import os.path\n",
|
||||||
|
"\n",
|
||||||
|
"from sentence_transformers import SentenceTransformer\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from numpy.linalg import norm\n",
|
||||||
|
"\n",
|
||||||
|
"MODEL_NAME = \"paraphrase-multilingual-mpnet-base-v2\"\n",
|
||||||
|
"\n",
|
||||||
|
"model = SentenceTransformer(MODEL_NAME)\n",
|
||||||
|
"\n",
|
||||||
|
"def build_category_embedding(cat_info: list[dict] ):\n",
|
||||||
|
" # print(f\"Building embedding for {cat_info['CatID']}...\")\n",
|
||||||
|
" components = [cat_info[\"Explanations\"], cat_info[\"Category\"], cat_info[\"SubCategory\"]] + cat_info.get('Synonyms', [])\n",
|
||||||
|
" composite_text = \". \".join(components)\n",
|
||||||
|
" return model.encode(composite_text, convert_to_numpy=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "da2820ad-851b-4cb7-a496-b124275eef58",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"We now generate an embedding for each category using the `ucs-community` repository, which conveniently has JSON versions of all of the UCS category descriptions and languages.\n",
|
||||||
|
"\n",
|
||||||
|
"We cache the embeddings in a file named `<MODEL_NAME>.cache` so multiple runs don't have to recalculate the entire embeddings table. If this file doesn't exist we create it by creating the embeddings and pickling the result, and if it does we read it."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 24,
|
||||||
|
"id": "d4c1bc74-5c5d-4714-b671-c75c45b82490",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Cached embeddings unavailable, recalculating...\n",
|
||||||
|
"Loaded 752 categories...\n",
|
||||||
|
"Writing embeddings to file...\n",
|
||||||
|
"Loaded 752 category embeddings...\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pickle\n",
|
||||||
|
"\n",
|
||||||
|
"def create_embeddings(ucs: list) -> list:\n",
|
||||||
|
" embeddings_list = []\n",
|
||||||
|
" for info in ucs:\n",
|
||||||
|
" embeddings_list += [{'CatID': info['CatID'], \n",
|
||||||
|
" 'Embedding': build_category_embedding(info)\n",
|
||||||
|
" }]\n",
|
||||||
|
"\n",
|
||||||
|
" return embeddings_list\n",
|
||||||
|
"\n",
|
||||||
|
"EMBEDDING_CACHE_NAME = MODEL_NAME + \".cache\"\n",
|
||||||
|
"\n",
|
||||||
|
"if not os.path.exists(EMBEDDING_CACHE_NAME):\n",
|
||||||
|
" print(\"Cached embeddings unavailable, recalculating...\")\n",
|
||||||
|
"\n",
|
||||||
|
" # for lang in ['en']:\n",
|
||||||
|
" with open(\"ucs-community/json/en.json\") as f:\n",
|
||||||
|
" ucs = json.load(f)\n",
|
||||||
|
" \n",
|
||||||
|
" print(f\"Loaded {len(ucs)} categories...\")\n",
|
||||||
|
" \n",
|
||||||
|
" embeddings_list = create_embeddings(ucs)\n",
|
||||||
|
"\n",
|
||||||
|
" with open(EMBEDDING_CACHE_NAME, \"wb\") as g:\n",
|
||||||
|
" print(\"Writing embeddings to file...\")\n",
|
||||||
|
" pickle.dump(embeddings_list, g)\n",
|
||||||
|
"\n",
|
||||||
|
"else:\n",
|
||||||
|
" print(f\"Loading cached category emebddings...\")\n",
|
||||||
|
" with open(EMBEDDING_CACHE_NAME, \"rb\") as g:\n",
|
||||||
|
" embeddings_list = pickle.load(g)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Loaded {len(embeddings_list)} category embeddings...\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 29,
|
||||||
|
"id": "c98d1af1-b4c5-478c-b051-0f8f33399dfd",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def classify_text(text):\n",
|
||||||
|
" text_embedding = model.encode(text, convert_to_numpy=True)\n",
|
||||||
|
" sim = model.similarity(text_embedding, [info['Embedding'] for info in embeddings_list])\n",
|
||||||
|
" maxind = np.argmax(sim)\n",
|
||||||
|
" print(f\" ⇒ Category: {embeddings_list[maxind]['CatID']}\")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def classify_text_ranked(text):\n",
|
||||||
|
" text_embedding = model.encode(text, convert_to_numpy=True)\n",
|
||||||
|
" embeddings = np.array([info['Embedding'] for info in embeddings_list])\n",
|
||||||
|
" sim = model.similarity(text_embedding, embeddings)[0]\n",
|
||||||
|
" maxinds = np.argsort(sim)[-5:]\n",
|
||||||
|
" # print(maxinds)\n",
|
||||||
|
" print(\" ⇒ Top 5: \" + \", \".join([embeddings_list[x]['CatID'] for x in reversed(maxinds)]))\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 30,
|
||||||
|
"id": "21768c01-d75f-49be-9f47-686332ba7921",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Text: Black powder explosion with loud report\n",
|
||||||
|
" ⇒ Top 5: EXPLMisc, AIRBrst, METLCrsh, EXPLReal, FIREBrst\n",
|
||||||
|
"\n",
|
||||||
|
"Text: Steam enging chuff\n",
|
||||||
|
" ⇒ Top 5: TRNSteam, FIRESizz, FIREGas, WATRFizz, GEOFuma\n",
|
||||||
|
"\n",
|
||||||
|
"Text: Playing card flick onto table\n",
|
||||||
|
" ⇒ Top 5: GAMEMisc, GAMEBoard, GAMECas, PAPRFltr, GAMEArcd\n",
|
||||||
|
"\n",
|
||||||
|
"Text: BMW 228 out fast\n",
|
||||||
|
" ⇒ Top 5: MOTRMisc, AIRHiss, VEHTire, VEHMoto, VEHAntq\n",
|
||||||
|
"\n",
|
||||||
|
"Text: City night skyline atmosphere\n",
|
||||||
|
" ⇒ Top 5: AMBUrbn, AMBTraf, AMBCele, AMBAir, AMBTran\n",
|
||||||
|
"\n",
|
||||||
|
"Text: Civil war 12-pound gun cannon\n",
|
||||||
|
" ⇒ Top 5: GUNCano, GUNArtl, GUNRif, BLLTMisc, WEAPMisc\n",
|
||||||
|
"\n",
|
||||||
|
"Text: Domestic combination boiler - pump switches off & cooling\n",
|
||||||
|
" ⇒ Top 5: MACHHvac, MACHFan, MACHPump, MOTRTurb, MECHRelay\n",
|
||||||
|
"\n",
|
||||||
|
"Text: Cello bow on cactus, animal screech\n",
|
||||||
|
" ⇒ Top 5: MUSCStr, CERMTonl, MUSCShake, MUSCPluck, MUSCWind\n",
|
||||||
|
"\n",
|
||||||
|
"Text: Electricity Generator And Arc Machine Start Up\n",
|
||||||
|
" ⇒ Top 5: ELECArc, MOTRElec, ELECSprk, BOATElec, TOOLPowr\n",
|
||||||
|
"\n",
|
||||||
|
"Text: Horse, canter One Horse: Canter Up, Stop\n",
|
||||||
|
" ⇒ Top 5: VOXScrm, WEAPWhip, FEETHors, MOVEAnml, VEHWagn\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"texts = [\n",
|
||||||
|
" \"Black powder explosion with loud report\",\n",
|
||||||
|
" \"Steam enging chuff\",\n",
|
||||||
|
" \"Playing card flick onto table\",\n",
|
||||||
|
" \"BMW 228 out fast\",\n",
|
||||||
|
" \"City night skyline atmosphere\",\n",
|
||||||
|
" \"Civil war 12-pound gun cannon\",\n",
|
||||||
|
" \"Domestic combination boiler - pump switches off & cooling\",\n",
|
||||||
|
" \"Cello bow on cactus, animal screech\",\n",
|
||||||
|
" \"Electricity Generator And Arc Machine Start Up\",\n",
|
||||||
|
" \"Horse, canter One Horse: Canter Up, Stop\"\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"for text in texts:\n",
|
||||||
|
" print(f\"Text: {text}\")\n",
|
||||||
|
" classify_text_ranked(text)\n",
|
||||||
|
" print(\"\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "36b7a30e-8bd1-486d-bc50-ce77704df64f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.13.5"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
146
notebooks/02_Gather Training Data.ipynb
Normal file
146
notebooks/02_Gather Training Data.ipynb
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "24915fba-1f44-46af-9233-66b896d7fa41",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import json\n",
|
||||||
|
"\n",
|
||||||
|
"with open(\"ucs-community/json/en.json\") as f:\n",
|
||||||
|
" ucs = json.load(f)\n",
|
||||||
|
" cat_ids = [x['CatID'] for x in ucs]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "3a4b887d-5e72-4136-a7d9-650667619b12",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def ucs_catid(path: str) -> Optional[str]:\n",
|
||||||
|
" import os.path\n",
|
||||||
|
" 'True if the file at `path` has a valid UCS filename'\n",
|
||||||
|
"\n",
|
||||||
|
" basename = os.path.basename(path)\n",
|
||||||
|
" first_component = basename.split(\"_\")[0]\n",
|
||||||
|
"\n",
|
||||||
|
" if first_component in cat_ids:\n",
|
||||||
|
" return first_component\n",
|
||||||
|
" else:\n",
|
||||||
|
" return False"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "faedeeb7-8c4d-4e60-ab7a-9baf9add008a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from typing import Optional\n",
|
||||||
|
"\n",
|
||||||
|
"def description(path: str) -> Optional[str]:\n",
|
||||||
|
" import json, subprocess\n",
|
||||||
|
" result = subprocess.run(['ffprobe', '-show_format', '-of', 'json', path], capture_output=True)\n",
|
||||||
|
" try:\n",
|
||||||
|
" result.check_returncode()\n",
|
||||||
|
" except:\n",
|
||||||
|
" return None\n",
|
||||||
|
" \n",
|
||||||
|
" stream = json.loads(result.stdout)\n",
|
||||||
|
" fmt = stream.get(\"format\", None)\n",
|
||||||
|
" if fmt:\n",
|
||||||
|
" tags = fmt.get(\"tags\", None)\n",
|
||||||
|
" if tags:\n",
|
||||||
|
" return tags.get(\"comment\", None)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"id": "f57c7c75-8cda-441a-bfb3-179e0afde861",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from typing import Optional, Tuple\n",
|
||||||
|
"\n",
|
||||||
|
"def test_data_for_file(path: str) -> Optional[Tuple[str, str]]:\n",
|
||||||
|
" 'CatID and description if both are present'\n",
|
||||||
|
"\n",
|
||||||
|
" catid = ucs_catid(path)\n",
|
||||||
|
" if catid is None:\n",
|
||||||
|
" return None\n",
|
||||||
|
" \n",
|
||||||
|
" desc = description(path)\n",
|
||||||
|
"\n",
|
||||||
|
" if desc is not None:\n",
|
||||||
|
" return (catid, desc)\n",
|
||||||
|
" else:\n",
|
||||||
|
" return None\n",
|
||||||
|
"\n",
|
||||||
|
"def collect_dataset(scan_root: str, set_name: str):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Scans scan_root recursively and collects all catid/description pairs\n",
|
||||||
|
" it can find.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" import os, csv\n",
|
||||||
|
" test_data = []\n",
|
||||||
|
" for root, _, files in os.walk(scan_root):\n",
|
||||||
|
" for file in files:\n",
|
||||||
|
" if file.endswith(\".wav\") or file.endswith(\".flac\"):\n",
|
||||||
|
" if test_datum := test_data_for_file(os.path.join(root,file)):\n",
|
||||||
|
" test_data += [test_datum]\n",
|
||||||
|
"\n",
|
||||||
|
" with open(set_name + '.csv', 'w') as f:\n",
|
||||||
|
" writer = csv.writer(f)\n",
|
||||||
|
" writer.writerow(['Category', 'Description'])\n",
|
||||||
|
" for row in test_data:\n",
|
||||||
|
" writer.writerow(row)\n",
|
||||||
|
" \n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 17,
|
||||||
|
"id": "1e05629d-15a5-406b-8064-879900e4b3c7",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"collect_dataset(\"/Volumes/NAS SFX Library/JAMIELIB Libraries by Studio/_Designers/Jamie Hardt\",\"jamie_files\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "7340f734-5ba7-4db0-a012-9e2bd46a4fc5",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.13.5"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
25
pyproject.toml
Normal file
25
pyproject.toml
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
[project]
|
||||||
|
name = "ucsinfer"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = ""
|
||||||
|
authors = [
|
||||||
|
{name = "Jamie Hardt",email = "jamiehardt@me.com"}
|
||||||
|
]
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.13"
|
||||||
|
dependencies = [
|
||||||
|
"sentence-transformers (>=5.0.0,<6.0.0)",
|
||||||
|
"numpy (>=2.3.2,<3.0.0)",
|
||||||
|
"tqdm (>=4.67.1,<5.0.0)",
|
||||||
|
"platformdirs (>=4.3.8,<5.0.0)"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core>=2.0.0,<3.0.0"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
ipython = "^9.4.0"
|
||||||
|
jupyter = "^1.1.1"
|
||||||
|
|
215
ucsinfer/__main__.py
Normal file
215
ucsinfer/__main__.py
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import pickle
|
||||||
|
import sys
|
||||||
|
import subprocess
|
||||||
|
import cmd
|
||||||
|
|
||||||
|
from typing import Optional, Tuple, IO
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import platformdirs
|
||||||
|
import tqdm
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||||
|
|
||||||
|
def load_ucs_categories() -> list:
    """
    Load the English UCS category definitions.

    Reads `ucs-community/json/en.json` located under `ROOT_DIR` and returns
    the parsed JSON document unchanged.

    Raises `OSError` if the file cannot be opened and `json.JSONDecodeError`
    if it does not contain valid JSON.
    """
    ucs_defs = os.path.join(ROOT_DIR, 'ucs-community', 'json', 'en.json')

    # JSON text is UTF-8; be explicit so the result does not depend on the
    # platform's locale encoding.
    with open(ucs_defs, 'r', encoding='utf-8') as f:
        return json.load(f)
|
||||||
|
|
||||||
|
def encoode_category(cat_defn: dict, model: "SentenceTransformer") -> np.ndarray:
    """
    Build the embedding for a single UCS category definition.

    The category's 'Explanations', 'Category' and 'SubCategory' strings,
    followed by each of its 'Synonyms' (if any), are joined with ", " into
    one sentence which is passed to `model.encode`.

    Returns whatever `model.encode(sentence, convert_to_numpy=True)` returns.
    Raises `KeyError` if one of the three mandatory keys is missing.
    """
    sentence_components = [
        cat_defn['Explanations'],
        cat_defn['Category'],
        cat_defn['SubCategory'],
    ]
    # 'Synonyms' is treated as optional: a missing or null entry simply
    # contributes nothing to the sentence instead of raising.
    sentence_components += cat_defn.get('Synonyms') or []

    sentence = ", ".join(sentence_components)
    return model.encode(sentence, convert_to_numpy=True)
|
||||||
|
|
||||||
|
def load_embeddings(ucs: list, model) -> list:
    """
    Return one embedding record per UCS category, using an on-disk cache.

    Each record is a dict with the keys 'CatID' and 'Embedding'. The cache
    lives at `ucs_embedding.cache` inside the per-user cache directory
    reported by `platformdirs.user_cache_dir('ucsinfer', 'Squad 51')`.

    If the cache file exists and can be unpickled its contents are returned
    as-is. Otherwise every entry of `ucs` is encoded with `encoode_category`
    (showing a tqdm progress bar) and the result is pickled to the cache
    before being returned.

    NOTE(review): the cache file name does not include the model name or any
    fingerprint of `ucs`, so a cache written for a different model or an
    older category list will be reused silently - confirm this is intended.
    """
    cache_dir = platformdirs.user_cache_dir('ucsinfer', 'Squad 51')
    embedding_cache = os.path.join(cache_dir, "ucs_embedding.cache")

    if os.path.exists(embedding_cache):
        try:
            # SECURITY: pickle can execute arbitrary code while loading. This
            # is only acceptable because the file is written by this program
            # into the user's own cache directory.
            with open(embedding_cache, 'rb') as f:
                return pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError,
                AttributeError, ImportError, IndexError):
            # A truncated or otherwise damaged cache (e.g. from an interrupted
            # earlier run) must not make every later run crash; rebuild it.
            print("Embedding cache is unreadable, rebuilding it.",
                  file=sys.stderr)

    print("Calculating embeddings...")

    embeddings = [
        {
            'CatID': cat_defn['CatID'],
            'Embedding': encoode_category(cat_defn, model),
        }
        for cat_defn in tqdm.tqdm(ucs)
    ]

    os.makedirs(cache_dir, exist_ok=True)
    with open(embedding_cache, 'wb') as g:
        pickle.dump(embeddings, g)

    return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def description(path: str) -> Optional[str]:
    """
    Return the 'comment' metadata tag of the media file at `path`.

    Runs `ffprobe -show_format -of json <path>` and looks up
    `format -> tags -> comment` in its JSON output.

    Returns `None` when ffprobe cannot be started, exits with a non-zero
    status, prints something that is not valid JSON, or when the file has no
    such tag.
    """
    try:
        result = subprocess.run(
            ['ffprobe', '-show_format', '-of', 'json', path],
            capture_output=True,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # OSError covers a missing or non-executable ffprobe binary;
        # CalledProcessError is a non-zero exit (unreadable/unsupported file).
        return None

    try:
        stream = json.loads(result.stdout)
    except ValueError:
        # JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        return None

    fmt = stream.get("format")
    if not fmt:
        return None

    tags = fmt.get("tags")
    if not tags:
        return None

    return tags.get("comment")
|
||||||
|
|
||||||
|
|
||||||
|
def classify_text_ranked(text, embeddings_list, model, limit=5):
    """
    Rank UCS categories by similarity to `text`.

    `text` is encoded with `model.encode`, compared against the 'Embedding'
    of every record in `embeddings_list` via `model.similarity`, and the
    'CatID' values of the best matches are returned, most similar first.

    Only the first row of the similarity result is used.

    At most `limit` IDs are returned. A `limit` of zero or less, or an empty
    `embeddings_list`, yields an empty list without calling the model.
    """
    # Guard first: with limit == 0 the slice `[-0:]` would select *every*
    # category rather than none.
    if limit <= 0 or not embeddings_list:
        return []

    text_embedding = model.encode(text, convert_to_numpy=True)
    embeddings = np.array([info['Embedding'] for info in embeddings_list])

    # Convert to a plain ndarray so the ranking below is pure numpy no matter
    # what array-like type `model.similarity` hands back.
    sim = np.asarray(model.similarity(text_embedding, embeddings)[0])

    # argsort is ascending; take the last `limit` entries and flip them.
    best_first = np.argsort(sim)[-limit:][::-1]
    return [embeddings_list[int(i)]['CatID'] for i in best_first]
|
||||||
|
|
||||||
|
|
||||||
|
def recommend_category(path, embeddings, model) -> Tuple[str, list]:
    """
    Get a text description of the file at `path` and a list of UCS cat IDs.

    The description is the value returned by `description(path)`. When that
    is missing, empty or only whitespace, the file's base name is used
    instead so the classifier always receives some meaningful text.

    Returns a `(description, cat_ids)` tuple where `cat_ids` is the result
    of `classify_text_ranked` for that description.
    """
    desc = description(path)

    # An empty comment tag carries no information to classify; treat it the
    # same as a missing one.
    if desc is None or not desc.strip():
        desc = os.path.basename(path)

    return desc, classify_text_ranked(desc, embeddings, model)
|
||||||
|
|
||||||
|
def lookup_cat(catid: str, ucs: list) -> Optional[tuple[str,str]]:
    """
    Find the category and subcategory names for a UCS CatID.

    Scans `ucs` in order and returns `(Category, SubCategory)` of the first
    entry whose 'CatID' equals `catid`, or `None` if no entry matches.
    """
    for entry in ucs:
        if entry['CatID'] == catid:
            return (entry['Category'], entry['SubCategory'])

    return None
|
||||||
|
|
||||||
|
|
||||||
|
class Commands(cmd.Cmd):
|
||||||
|
|
||||||
|
def __init__(self, completekey: str = "tab", stdin: IO[str] | None = None,
|
||||||
|
stdout: IO[str] | None = None) -> None:
|
||||||
|
super().__init__(completekey, stdin, stdout)
|
||||||
|
self.file_list = []
|
||||||
|
self.model = None
|
||||||
|
self.embeddings = None
|
||||||
|
self.catlist = None
|
||||||
|
self._rec_list = []
|
||||||
|
self._file_cursor = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def file_cursor(self):
|
||||||
|
return self._file_cursor
|
||||||
|
|
||||||
|
@file_cursor.setter
|
||||||
|
def file_cursor(self, val):
|
||||||
|
self._file_cursor = val
|
||||||
|
self.onecmd('file')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rec_list(self):
|
||||||
|
return self._rec_list
|
||||||
|
|
||||||
|
@rec_list.setter
|
||||||
|
def rec_list(self, value):
|
||||||
|
self._rec_list = value
|
||||||
|
if isinstance(self.rec_list, list) and self.catlist:
|
||||||
|
for i, cat_id in enumerate(self.rec_list):
|
||||||
|
cat, subcat = lookup_cat(cat_id, self.catlist)
|
||||||
|
print(f" [ {i+1} ]: {cat_id} ({cat} / {subcat})")
|
||||||
|
|
||||||
|
def default(self, line):
|
||||||
|
if len(self.rec_list) > 0:
|
||||||
|
try:
|
||||||
|
rec = int(line)
|
||||||
|
if rec < len(self.rec_list):
|
||||||
|
print(f"Accept option {rec}")
|
||||||
|
self.onecmd("next")
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
super().default(line)
|
||||||
|
|
||||||
|
else:
|
||||||
|
super().default(line)
|
||||||
|
|
||||||
|
def preloop(self) -> None:
|
||||||
|
self.file_cursor = 0
|
||||||
|
self.update_prompt()
|
||||||
|
return super().preloop()
|
||||||
|
|
||||||
|
def postcmd(self, stop: bool, line: str) -> bool:
|
||||||
|
return super().postcmd(stop, line)
|
||||||
|
|
||||||
|
def update_prompt(self):
|
||||||
|
self.prompt = f"(ucsinfer:{self.file_cursor}/{len(self.file_list)}) "
|
||||||
|
|
||||||
|
def do_file(self, args):
|
||||||
|
'Print info about the current file'
|
||||||
|
if self.file_cursor < len(self.file_list):
|
||||||
|
self.update_prompt()
|
||||||
|
path = self.file_list[self.file_cursor]
|
||||||
|
f = os.path.basename(path)
|
||||||
|
print(f" > {f}")
|
||||||
|
desc, recs = recommend_category(path, self.embeddings, self.model)
|
||||||
|
print(f" >> {desc}")
|
||||||
|
self.rec_list = recs
|
||||||
|
else:
|
||||||
|
print(" > No file")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def do_addcontext(self, args):
|
||||||
|
'Add the argument to all file descriptions before searching for '
|
||||||
|
'similar. Enter a blank value to reset.'
|
||||||
|
pass
|
||||||
|
|
||||||
|
def do_lookup(self, args):
|
||||||
|
'print a list of UCS categories similar to the argument'
|
||||||
|
self.rec_list = classify_text_ranked(args, self.embeddings, self.model)
|
||||||
|
|
||||||
|
def do_next(self, _):
|
||||||
|
'go to next file'
|
||||||
|
self.file_cursor += 1
|
||||||
|
|
||||||
|
def do_prev(self, _):
|
||||||
|
'go to previous file'
|
||||||
|
self.file_cursor -= 1
|
||||||
|
|
||||||
|
def do_quit(self, _):
|
||||||
|
'exit'
|
||||||
|
print("Exiting...")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """
    Entry point: load the UCS categories and their embeddings, then start
    the interactive shell over the files named on the command line.
    """
    categories = load_ucs_categories()
    print("Loaded UCS categories.", file=sys.stderr)

    model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
    category_embeddings = load_embeddings(categories, model)
    print("Loaded embeddings...", file=sys.stderr)

    # Every command-line argument is treated as a file to review.
    shell = Commands()
    shell.file_list = sys.argv[1:]
    shell.model = model
    shell.embeddings = category_embeddings
    shell.catlist = categories

    shell.cmdloop()


if __name__ == '__main__':
    main()
|
1
ucsinfer/ucs-community
Submodule
1
ucsinfer/ucs-community
Submodule
Submodule ucsinfer/ucs-community added at daa714a9b4
Reference in New Issue
Block a user