{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3f2b9c1e-6a4d-4e8b-9f1a-2c7d5e8b0a11",
   "metadata": {},
   "source": [
    "# Zero-shot text categorization with sentence embeddings\n",
    "\n",
    "Each category is described by a name, a short description and a few keywords. These are joined into one composite text and embedded with a SentenceTransformer model. An input text is assigned to the category whose embedding scores highest against the text's own embedding under `model.similarity`.\n",
    "\n",
    "The examples deliberately use the keyword **Winchester**, which appears in both the *Firearms* and the *Computers* keyword lists, to check whether the surrounding context decides the category."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a41d7c2-0b5e-4f3a-b6d9-7e2c1f0a9b34",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sentence_transformers import SentenceTransformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a3ea8e9-175b-4592-9341-cdba151c6fc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = \"all-MiniLM-L6-v2\"\n",
    "\n",
    "categories = [\n",
    "    {\n",
    "        \"name\": \"Automobile\",\n",
    "        \"description\": \"Topics related to vehicles such as cars, trucks, and their brands.\",\n",
    "        \"keywords\": [\"Mazda\", \"Toyota\", \"SUV\", \"sedan\", \"pickup\"],\n",
    "    },\n",
    "    {\n",
    "        \"name\": \"Firearms\",\n",
    "        \"description\": \"Topics related to guns, rifles, pistols, ammunition, and other weapons.\",\n",
    "        \"keywords\": [\"Winchester\", \"Glock\", \"rifle\", \"bullet\", \"shotgun\"],\n",
    "    },\n",
    "    {\n",
    "        \"name\": \"Computers\",\n",
    "        \"description\": \"Topics involving computer hardware and software, such as hard drives, CPUs, and laptops.\",\n",
    "        \"keywords\": [\"Winchester\", \"CPU\", \"hard drive\", \"RAM\", \"SSD\", \"motherboard\"],\n",
    "    },\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5e0a9f4-2d7b-4c1e-8a3f-6b9d0e4c7a25",
   "metadata": {},
   "source": [
    "## Embed the categories\n",
    "\n",
    "A category's name, description and keywords are joined with `\". \"` and embedded as a single text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d181860c-dd33-4327-b9d9-29297170558d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SentenceTransformer(MODEL_NAME)\n",
    "\n",
    "\n",
    "def category_composite_text(cat_info: dict) -> str:\n",
    "    \"\"\"Join a category's name, description and (optional) keywords with \". \".\"\"\"\n",
    "    components = [cat_info[\"name\"], cat_info[\"description\"], *cat_info.get(\"keywords\", [])]\n",
    "    return \". \".join(components)\n",
    "\n",
    "\n",
    "def build_category_embedding(cat_info: dict) -> np.ndarray:\n",
    "    \"\"\"Encode a category's composite text with `model` and return it as a numpy array.\"\"\"\n",
    "    return model.encode(category_composite_text(cat_info), convert_to_numpy=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7b3c2d1-9f4a-4b6e-a0c8-1d5f3e7a9b62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach each category's embedding to its dict. The composite text never reads\n",
    "# the \"embedding\" key, so re-running this cell simply overwrites the same values.\n",
    "for info in categories:\n",
    "    info[\"embedding\"] = build_category_embedding(info)\n",
    "\n",
    "assert all(\"embedding\" in info for info in categories)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0d4e8a7-3c2b-4a9e-b1f6-8e0c5d2a4b73",
   "metadata": {},
   "source": [
    "## Classify\n",
    "\n",
    "The text is embedded and scored against every category embedding at once. The highest-scoring category wins."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00842b35-04b3-42db-b352-b78154b65818",
   "metadata": {},
   "outputs": [],
   "source": [
    "def classify_text(text: str, categories: list[dict], verbose: bool = True) -> str:\n",
    "    \"\"\"Return the name of the category whose embedding best matches ``text``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    text : str\n",
    "        The text to classify.\n",
    "    categories : list of dict\n",
    "        Category dicts that each carry an ``\"embedding\"`` entry, as set by\n",
    "        ``build_category_embedding``.\n",
    "    verbose : bool, default True\n",
    "        Print the text, the raw similarity tensor and the chosen category.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``categories`` is empty (there is nothing to choose from).\n",
    "    \"\"\"\n",
    "    if not categories:\n",
    "        raise ValueError(\"categories must not be empty\")\n",
    "\n",
    "    text_embedding = model.encode(text, convert_to_numpy=True)\n",
    "    # Stack into one (n_categories, dim) array: handing torch a Python list of\n",
    "    # numpy arrays is slow and triggers a UserWarning.\n",
    "    category_matrix = np.stack([info[\"embedding\"] for info in categories])\n",
    "    scores = model.similarity(text_embedding, category_matrix)  # shape (1, n_categories)\n",
    "    best = int(scores.argmax())  # flat index == column index because there is one row\n",
    "    if verbose:\n",
    "        print(f\"Text: {text}\")\n",
    "        print(scores)\n",
    "        print(f\" -> Category: {categories[best]['name']}\")\n",
    "    return categories[best][\"name\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deb8cd47-b09b-445c-a688-a7e0bf0d7f2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "examples = [\n",
    "    \"I took my Winchester to the shooting range yesterday.\",\n",
    "    \"I bought a new Mazda with an automatic transmission.\",\n",
    "    \"My old Winchester hard drive finally failed.\",\n",
    "    \"Keys clicking, typing\",\n",
    "]\n",
    "\n",
    "predictions = {text: classify_text(text, categories) for text in examples}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9c6b3e5-7d1f-4e2a-9b8c-0f4e6d2a1c84",
   "metadata": {},
   "source": [
    "## Observations\n",
    "\n",
    "Results from a previous run of this notebook. Scores are listed in category order: Automobile / Firearms / Computers.\n",
    "\n",
    "| Text | Scores | Predicted |\n",
    "|---|---|---|\n",
    "| I took my Winchester to the shooting range yesterday. | -0.0456 / 0.3874 / 0.0935 | Firearms |\n",
    "| I bought a new Mazda with an automatic transmission. | 0.3483 / 0.0454 / 0.0285 | Automobile |\n",
    "| My old Winchester hard drive finally failed. | -0.0305 / 0.2024 / 0.3047 | Computers |\n",
    "| Keys clicking, typing | 0.0957 / 0.1107 / 0.1531 | Computers |\n",
    "\n",
    "- Context resolves the ambiguous *Winchester* keyword to the expected category in both sentences.\n",
    "- The last example shares no keyword with any category. All its scores are low and close together, so its prediction is a narrow margin rather than a confident match."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}