{ "cells": [ { "cell_type": "markdown", "id": "eb04426a-cfb8-4f6f-9348-4308438fb9a5", "metadata": {}, "source": [ "# Finding UCS categories with sentence embeddings\n", "\n", "In this brief example we use sentence embeddings to choose a UCS category for a sound based on its text description.\n", "\n", "## Step 1: Creating embeddings for UCS categories\n", "\n", "We first load a SentenceTransformer model (`paraphrase-multilingual-mpnet-base-v2`) and define a function that builds one embedding per category by joining the _Explanations_, _Category_, _SubCategory_, and _Synonyms_ fields from the UCS spreadsheet into a single string.\n", "\n", "`model.encode` is relatively slow, so if you need to embed many categories this function is the natural place to parallelize or batch the work (`model.encode` also accepts a list of strings and encodes them in batches). Here we simply call it once per category." ] }, { "cell_type": "code", "execution_count": 23, "id": "ef63fc07-c0d7-4616-9be1-1f0c2f275a69", "metadata": {}, "outputs": [], "source": [ "import json\n", "import os.path\n", "\n", "from sentence_transformers import SentenceTransformer\n", "import numpy as np\n", "\n", "MODEL_NAME = \"paraphrase-multilingual-mpnet-base-v2\"\n", "\n", "model = SentenceTransformer(MODEL_NAME)\n", "\n", "def build_category_embedding(cat_info: dict) -> np.ndarray:\n", "    # Join the category's descriptive fields and synonyms into one sentence-like string\n", "    components = [cat_info[\"Explanations\"], cat_info[\"Category\"], cat_info[\"SubCategory\"]] + cat_info.get(\"Synonyms\", [])\n", "    composite_text = \". \".join(components)\n", "    return model.encode(composite_text, convert_to_numpy=True)" ] }, { "cell_type": "markdown", "id": "da2820ad-851b-4cb7-a496-b124275eef58", "metadata": {}, "source": [ "We now generate an embedding for each category using the `ucs-community` repository, which conveniently provides JSON versions of the UCS category descriptions in each of its languages; here we load only the English file, `ucs-community/json/en.json`.\n", "\n", "We cache the embeddings in a file named `<MODEL_NAME>.cache` (here `paraphrase-multilingual-mpnet-base-v2.cache`) so repeated runs don't have to recalculate the entire embeddings table. 
If this file doesn't exist, we compute the embeddings and pickle the result; if it does, we simply load it from disk." ] }, { "cell_type": "code", "execution_count": 24, "id": "d4c1bc74-5c5d-4714-b671-c75c45b82490", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cached embeddings unavailable, recalculating...\n", "Loaded 752 categories...\n", "Writing embeddings to file...\n", "Loaded 752 category embeddings...\n" ] } ], "source": [ "import pickle\n", "\n", "def create_embeddings(ucs: list[dict]) -> list[dict]:\n", "    # Pair each category ID with the embedding of its composite description\n", "    embeddings_list = []\n", "    for info in ucs:\n", "        embeddings_list.append({'CatID': info['CatID'],\n", "                                'Embedding': build_category_embedding(info)})\n", "    return embeddings_list\n", "\n", "EMBEDDING_CACHE_NAME = MODEL_NAME + \".cache\"\n", "\n", "if not os.path.exists(EMBEDDING_CACHE_NAME):\n", "    print(\"Cached embeddings unavailable, recalculating...\")\n", "\n", "    # Only the English category list is loaded here\n", "    with open(\"ucs-community/json/en.json\", encoding=\"utf-8\") as f:\n", "        ucs = json.load(f)\n", "\n", "    print(f\"Loaded {len(ucs)} categories...\")\n", "\n", "    embeddings_list = create_embeddings(ucs)\n", "\n", "    with open(EMBEDDING_CACHE_NAME, \"wb\") as g:\n", "        print(\"Writing embeddings to file...\")\n", "        pickle.dump(embeddings_list, g)\n", "\n", "else:\n", "    print(\"Loading cached category embeddings...\")\n", "    with open(EMBEDDING_CACHE_NAME, \"rb\") as g:\n", "        embeddings_list = pickle.load(g)\n", "\n", "print(f\"Loaded {len(embeddings_list)} category embeddings...\")" ] }, { "cell_type": "markdown", "id": "step-2-classify-text", "metadata": {}, "source": [ "## Step 2: Classifying a text description\n", "\n", "To classify a description, we embed it with the same model and score it against every category embedding with `model.similarity`, which computes cosine similarity unless the model is configured to use another metric. `classify_text` prints the single best match, while `classify_text_ranked` prints the five closest categories, best match first." ] }, { "cell_type": "code", "execution_count": 29, "id": "c98d1af1-b4c5-478c-b051-0f8f33399dfd", "metadata": {}, "outputs": [], "source": [ "def classify_text(text):\n", "    # Print the single UCS category whose embedding is most similar to the text\n", "    text_embedding = model.encode(text, convert_to_numpy=True)\n", "    embeddings = np.array([info['Embedding'] for info in embeddings_list])\n", "    sim = np.asarray(model.similarity(text_embedding, embeddings)[0])\n", "    maxind = int(np.argmax(sim))\n", "    print(f\" ⇒ Category: {embeddings_list[maxind]['CatID']}\")\n", "\n", "\n", "def classify_text_ranked(text):\n", "    # Print the five most similar UCS categories, best match first\n", "    text_embedding = model.encode(text, convert_to_numpy=True)\n", "    embeddings = np.array([info['Embedding'] for info in embeddings_list])\n", "    sim = np.asarray(model.similarity(text_embedding, embeddings)[0])\n", "    maxinds = np.argsort(sim)[-5:]\n", "    print(\" ⇒ Top 5: \" + \", \".join(embeddings_list[x]['CatID'] for x in reversed(maxinds)))\n" ] }, { "cell_type": "code", "execution_count": 30, "id": "21768c01-d75f-49be-9f47-686332ba7921", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Text: Black powder explosion with loud report\n", " ⇒ Top 5: EXPLMisc, AIRBrst, METLCrsh, EXPLReal, FIREBrst\n", "\n", "Text: Steam enging chuff\n", " ⇒ Top 5: TRNSteam, FIRESizz, FIREGas, WATRFizz, GEOFuma\n", "\n", "Text: Playing card flick onto table\n", " ⇒ Top 5: GAMEMisc, GAMEBoard, GAMECas, PAPRFltr, GAMEArcd\n", "\n", "Text: BMW 228 out fast\n", " ⇒ Top 5: MOTRMisc, AIRHiss, VEHTire, VEHMoto, VEHAntq\n", "\n", "Text: City night skyline atmosphere\n", " ⇒ Top 5: AMBUrbn, AMBTraf, AMBCele, AMBAir, AMBTran\n", "\n", "Text: Civil war 12-pound gun cannon\n", " ⇒ Top 5: GUNCano, GUNArtl, GUNRif, BLLTMisc, WEAPMisc\n", "\n", "Text: Domestic combination boiler - pump switches off & cooling\n", " ⇒ Top 5: MACHHvac, MACHFan, MACHPump, MOTRTurb, MECHRelay\n", "\n", "Text: Cello bow on cactus, animal screech\n", " ⇒ Top 5: MUSCStr, CERMTonl, MUSCShake, MUSCPluck, MUSCWind\n", "\n", "Text: Electricity Generator And Arc Machine Start Up\n", " ⇒ Top 5: ELECArc, MOTRElec, ELECSprk, BOATElec, TOOLPowr\n", "\n", "Text: Horse, canter One Horse: Canter Up, Stop\n", " ⇒ Top 5: VOXScrm, WEAPWhip, FEETHors, MOVEAnml, VEHWagn\n", "\n" ] } ], "source": [ "texts = [\n", " \"Black powder explosion with loud report\",\n", " \"Steam enging chuff\",\n", " \"Playing card flick onto table\",\n", " \"BMW 228 out fast\",\n", " \"City night skyline atmosphere\",\n", " \"Civil war 12-pound gun cannon\",\n", " \"Domestic 
combination boiler - pump switches off & cooling\",\n", " \"Cello bow on cactus, animal screech\",\n", " \"Electricity Generator And Arc Machine Start Up\",\n", " \"Horse, canter One Horse: Canter Up, Stop\"\n", "]\n", "\n", "for text in texts:\n", " print(f\"Text: {text}\")\n", " classify_text_ranked(text)\n", " print(\"\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "36b7a30e-8bd1-486d-bc50-ce77704df64f", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.5" } }, "nbformat": 4, "nbformat_minor": 5 }