7.7 KiB
Finding UCS categories with sentence embedding¶
In this brief example we use sentence embedding to decide the UCS category for a sound, based on a text description.
Step 1: Creating embeddings for UCS categories¶
We first select a SentenceTransformer model and establish a method for generating embeddings that correspond with each category by using the Explanations, Category, SubCategory, and Synonyms from the UCS spreadsheet.
model.encode
is a slow process, so this could be wrapped in an async function so that a client can parallelize it if it wants to (the version shown here is synchronous for simplicity).
import json
import os.path
from sentence_transformers import SentenceTransformer
import numpy as np
from numpy.linalg import norm
# Multilingual sentence-embedding model; the same model instance is used for
# both the category embeddings and the query embeddings so the vectors are
# directly comparable.
MODEL_NAME = "paraphrase-multilingual-mpnet-base-v2"
model = SentenceTransformer(MODEL_NAME)
def build_category_embedding(cat_info: dict) -> np.ndarray:
    """Return a sentence embedding for one UCS category record.

    The embedding is computed from a composite text joining the record's
    Explanations, Category and SubCategory fields plus any Synonyms, so
    all descriptive text contributes to the vector.

    Args:
        cat_info: a single UCS category record (e.g. one entry from the
            ucs-community ``en.json`` file). Must contain the keys
            "Explanations", "Category" and "SubCategory"; "Synonyms"
            (a list of strings) is optional.

    Returns:
        A numpy array produced by ``model.encode``.
    """
    # NOTE: the original annotation said ``list[dict]`` but the body
    # indexes ``cat_info`` by string keys — it is a single dict.
    components = [cat_info["Explanations"],
                  cat_info["Category"],
                  cat_info["SubCategory"]] + cat_info.get('Synonyms', [])
    composite_text = ". ".join(components)
    return model.encode(composite_text, convert_to_numpy=True)
We now generate an embedding for each category using the ucs-community
repository, which conveniently has JSON versions of all of the UCS category descriptions and languages.
We cache the category embeddings in a file named after the model (MODEL_NAME.cache)
so multiple runs don't have to recalculate the entire embeddings table. If this file doesn't exist we create it by computing the embeddings and pickling the result, and if it does exist we read it.
import pickle
def create_embeddings(ucs: list) -> list:
    """Compute an embedding for every UCS category.

    Args:
        ucs: list of UCS category records (dicts with at least a
            "CatID" key plus the text fields consumed by
            ``build_category_embedding``).

    Returns:
        A list of dicts with keys "CatID" and "Embedding", in the same
        order as ``ucs``.
    """
    # A comprehension replaces the original manual ``+= [{...}]`` loop.
    return [{'CatID': info['CatID'],
             'Embedding': build_category_embedding(info)}
            for info in ucs]
# Cache file is named after the model so that switching models
# automatically invalidates the cache.
EMBEDDING_CACHE_NAME = MODEL_NAME + ".cache"

if not os.path.exists(EMBEDDING_CACHE_NAME):
    print("Cached embeddings unavailable, recalculating...")
    # English-only for now; the repo also ships other languages.
    with open("ucs-community/json/en.json") as f:
        ucs = json.load(f)
    print(f"Loaded {len(ucs)} categories...")
    embeddings_list = create_embeddings(ucs)
    with open(EMBEDDING_CACHE_NAME, "wb") as g:
        print("Writing embeddings to file...")
        pickle.dump(embeddings_list, g)
else:
    # Fixed message typo ("emebddings") and dropped the pointless
    # f-prefix on a string with no placeholders.
    print("Loading cached category embeddings...")
    # NOTE: unpickling is only safe because this cache file is produced
    # locally by this script — never load untrusted pickle data.
    with open(EMBEDDING_CACHE_NAME, "rb") as g:
        embeddings_list = pickle.load(g)
    print(f"Loaded {len(embeddings_list)} category embeddings...")
def classify_text(text):
    """Print the single best-matching UCS category for *text*."""
    # Embed the query with the same model used for the category table.
    query_vec = model.encode(text, convert_to_numpy=True)
    category_vecs = [entry['Embedding'] for entry in embeddings_list]
    scores = model.similarity(query_vec, category_vecs)
    best = np.argmax(scores)
    print(f" ⇒ Category: {embeddings_list[best]['CatID']}")
def classify_text_ranked(text, top_n=5):
    """Print the *top_n* best-matching UCS categories for *text*.

    Args:
        text: free-text sound description to classify.
        top_n: number of highest-similarity categories to report.
            Defaults to 5, matching the original behaviour.
    """
    text_embedding = model.encode(text, convert_to_numpy=True)
    embeddings = np.array([info['Embedding'] for info in embeddings_list])
    # similarity() returns a (1, n_categories) matrix; take row 0.
    sim = model.similarity(text_embedding, embeddings)[0]
    # argsort is ascending, so the last top_n entries are the best;
    # reverse to print in descending similarity order.
    maxinds = np.argsort(sim)[-top_n:]
    print(f" ⇒ Top {top_n}: " +
          ", ".join([embeddings_list[x]['CatID'] for x in reversed(maxinds)]))
# Sample free-text sound descriptions used to demonstrate the classifier.
# NOTE(review): "enging" below looks like a typo for "engine" — kept as-is
# since it is runtime data; confirm whether it is intentional.
texts = [
"Black powder explosion with loud report",
"Steam enging chuff",
"Playing card flick onto table",
"BMW 228 out fast",
"City night skyline atmosphere",
"Civil war 12-pound gun cannon",
"Domestic combination boiler - pump switches off & cooling",
"Cello bow on cactus, animal screech",
"Electricity Generator And Arc Machine Start Up",
"Horse, canter One Horse: Canter Up, Stop"
]
# Classify each sample and print its top-5 UCS categories.
for text in texts:
print(f"Text: {text}")
classify_text_ranked(text)
print("")