{ "cells": [ { "cell_type": "markdown", "id": "42aaa537-ab3b-46d2-a12b-d0355669bf26", "metadata": {}, "source": [ "# ZeroShotClassification\n", "\n", "Experimenting with a ZeroShotClassification Pipeline.\n", "\n", "This works very good. Even \"login id\" was tagged as \"username\" at 20% confidence. Based on testing we will not consider below 18%." ] }, { "cell_type": "code", "execution_count": 1, "id": "c8e72fc9-198e-43c6-8045-aaf811145e1c", "metadata": {}, "outputs": [], "source": [ "from transformers import pipeline\n", "from faker import Faker" ] }, { "cell_type": "code", "execution_count": 2, "id": "620682d6-fd51-47b7-b99b-a32434c28044", "metadata": {}, "outputs": [], "source": [ "fake = Faker()" ] }, { "cell_type": "code", "execution_count": 3, "id": "0c5d29b2-2d9f-4009-9f55-690af4be7ea7", "metadata": {}, "outputs": [], "source": [ "# Create a dictionary of Faker functions with descriptive labels\n", "faker_functions = {\n", " \"person name\": fake.name,\n", " \"first name\": fake.first_name,\n", " \"last name\": fake.last_name,\n", " \"email address\": fake.email,\n", " \"phone number\": fake.phone_number,\n", " \"street address\": fake.street_address,\n", " \"city name\": fake.city,\n", " \"state name\": fake.state,\n", " \"country name\": fake.country,\n", " \"zip code\": fake.zipcode,\n", " \"job title\": fake.job,\n", " \"company name\": fake.company,\n", " \"credit card number\": fake.credit_card_number,\n", " \"date of birth\": fake.date_of_birth,\n", " \"username\": fake.user_name,\n", " \"website url\": fake.url,\n", " \"paragraph text\": fake.paragraph,\n", " \"sentence text\": fake.sentence\n", "}" ] }, { "cell_type": "code", "execution_count": 4, "id": "998e0fa8-0ed5-4b39-8f5e-732628ace900", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Device set to use mps:0\n" ] } ], "source": [ "pipe = pipeline(model=\"facebook/bart-large-mnli\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "1a1d7010-1291-4e31-b525-83327b5e7c01", "metadata": {}, "outputs": [], "source": [ "result = pipe(\n", " [\"The first name of a user\", \"login id\", \"full name of member\"],\n", " candidate_labels=list(faker_functions.keys())\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "id": "d8cd7a50-aa79-48e4-aba9-7e78a4e64074", "metadata": {}, "outputs": [], "source": [ "def get_highest_score_functions(result, faker_functions, threshold=0.18):\n", " sequence_to_function = {}\n", " \n", " for item in result:\n", " sequence = item['sequence']\n", " label = item['labels'][0]\n", " score = item['scores'][0]\n", " \n", " if (score >= threshold):\n", " sequence_to_function[sequence] = faker_functions.get(label)\n", " else:\n", " sequence_to_function[sequence] = None\n", " \n", " return sequence_to_function" ] }, { "cell_type": "code", "execution_count": 7, "id": "5c1989e2-99b8-4de3-b7f9-45c625405eb9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'The first name of a user': >,\n", " 'login id': >,\n", " 'full name of member': >}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_highest_score_functions(result, faker_functions, threshold=0.18)" ] }, { "cell_type": "code", "execution_count": null, "id": "799de95c-3d08-4da9-8a84-c3bc81ab5928", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 5 }