Committing module
This commit is contained in:
162
notebooks/00_Infer.ipynb
Normal file
162
notebooks/00_Infer.ipynb
Normal file
@@ -0,0 +1,162 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "1a3ea8e9-175b-4592-9341-cdba151c6fc2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"categories = [\n",
|
||||
" {\n",
|
||||
" \"name\": \"Automobile\",\n",
|
||||
" \"description\": \"Topics related to vehicles such as cars, trucks, and their brands.\",\n",
|
||||
" \"keywords\": [\"Mazda\", \"Toyota\", \"SUV\", \"sedan\", \"pickup\"]\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"name\": \"Firearms\",\n",
|
||||
" \"description\": \"Topics related to guns, rifles, pistols, ammunition, and other weapons.\",\n",
|
||||
" \"keywords\": [\"Winchester\", \"Glock\", \"rifle\", \"bullet\", \"shotgun\"]\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"name\": \"Computers\",\n",
|
||||
" \"description\": \"Topics involving computer hardware and software, such as hard drives, CPUs, and laptops.\",\n",
|
||||
" \"keywords\": [\"Winchester\", \"CPU\", \"hard drive\", \"RAM\", \"SSD\", \"motherboard\"]\n",
|
||||
" },\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "d181860c-dd33-4327-b9d9-29297170558d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/j/Library/Caches/pypoetry/virtualenvs/ucsinfer2-yBtBMMP2-py3.13/lib/python3.13/site-packages/torch/nn/modules/module.py:1762: FutureWarning: `encoder_attention_mask` is deprecated and will be removed in version 4.55.0 for `BertSdpaSelfAttention.forward`.\n",
|
||||
" return forward_call(*args, **kwargs)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from sentence_transformers import SentenceTransformer\n",
|
||||
"import numpy as np\n",
|
||||
"from numpy.linalg import norm\n",
|
||||
"\n",
|
||||
"model = SentenceTransformer(\"all-MiniLM-L6-v2\")\n",
|
||||
"\n",
|
||||
"def build_category_embedding(cat_info):\n",
|
||||
" components = [cat_info[\"name\"], cat_info[\"description\"]] + cat_info.get('keywords', [])\n",
|
||||
" composite_text = \". \".join(components)\n",
|
||||
" return model.encode(composite_text, convert_to_numpy=True)\n",
|
||||
"\n",
|
||||
"# Embed all categories\n",
|
||||
"\n",
|
||||
"for info in categories:\n",
|
||||
" info['embedding'] = build_category_embedding(info)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "00842b35-04b3-42db-b352-b78154b65818",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# def cosine_similarity(a, b):\n",
|
||||
"# return np.dot(a, b) / (norm(a) * norm(b))\n",
|
||||
"\n",
|
||||
"def classify_text(text, categories):\n",
|
||||
" text_embedding = model.encode(text, convert_to_numpy=True)\n",
|
||||
"\n",
|
||||
" print(f\"Text: {text}\")\n",
|
||||
" sim = model.similarity(text_embedding, [info['embedding'] for info in categories])\n",
|
||||
" print(sim)\n",
|
||||
" maxind = np.argmax(sim)\n",
|
||||
" print(f\" -> Category: {categories[maxind]['name']}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "deb8cd47-b09b-445c-a688-a7e0bf0d7f2e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Text: I took my Winchester to the shooting range yesterday.\n",
|
||||
"tensor([[-0.0456, 0.3874, 0.0935]])\n",
|
||||
" -> Category: Firearms\n",
|
||||
"Text: I bought a new Mazda with an automatic transmission.\n",
|
||||
"tensor([[0.3483, 0.0454, 0.0285]])\n",
|
||||
" -> Category: Automobile\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/j/Library/Caches/pypoetry/virtualenvs/ucsinfer2-yBtBMMP2-py3.13/lib/python3.13/site-packages/sentence_transformers/util.py:55: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:257.)\n",
|
||||
" a = torch.tensor(a)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Text: My old Winchester hard drive finally failed.\n",
|
||||
"tensor([[-0.0305, 0.2024, 0.3047]])\n",
|
||||
" -> Category: Computers\n",
|
||||
"Text: Keys clicking, typing\n",
|
||||
"tensor([[0.0957, 0.1107, 0.1531]])\n",
|
||||
" -> Category: Computers\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"text1 = \"I took my Winchester to the shooting range yesterday.\"\n",
|
||||
"text2 = \"I bought a new Mazda with an automatic transmission.\"\n",
|
||||
"text3 = \"My old Winchester hard drive finally failed.\"\n",
|
||||
"text4 = \"Keys clicking, typing\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for text in [text1, text2, text3, text4]:\n",
|
||||
" classify_text(text, categories)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1831d482-2596-439c-9641-eb91fcae73c6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.13.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
Reference in New Issue
Block a user