{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "5e21c7be", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[nltk_data] Downloading package punkt to\n", "[nltk_data] /Users/shenjiajun/nltk_data...\n", "[nltk_data] Package punkt is already up-to-date!\n", "/Users/shenjiajun/miniconda3/envs/torch_gpu/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from citekit.cite_modules.LLM import LLM\n", "from citekit.cite_modules.augment_model import AttributingModule\n", "from citekit.pipeline.pipeline import Sequence\n", "from citekit.prompt.prompt import Prompt\n", "from citekit.Dataset.Dataset import FileDataset\n", "from citekit.evaluator.evaluator import DefaultEvaluator\n", "import json" ] }, { "cell_type": "markdown", "id": "0260d8c5", "metadata": {}, "source": [ "# Data Loading" ] }, { "cell_type": "code", "execution_count": 2, "id": "c374b212", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Instruction: Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. 
If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.\n" ] } ], "source": [ "dataset = FileDataset('data/asqa.json')\n", "\n", "with open('prompts/asqa.json','r',encoding='utf-8') as file:\n", "    demo = json.load(file)\n", "    instruction = demo['INST']\n", "print(instruction)" ] }, { "cell_type": "markdown", "id": "fb237785", "metadata": {}, "source": [ "# Make a Request Wrapper\n", "\n", "The request wrapper is a template made up of named components. A component that receives no data is dropped automatically." ] }, { "cell_type": "code", "execution_count": 3, "id": "62abd5d2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{INST}\n", "\n", "Question:{question}\n", "\n", "{docs}\n", "Prefix: {prefix}\n", "\n", "The highlighted spans are: \n", "{span}\n", "\n", "Answer: \n" ] } ], "source": [ "prompt = Prompt(template='Answer: ',\n", " components={'INST':'{INST}\\n\\n', \n", " 'question':'Question:{question}\\n\\n',\n", " 'docs':'{docs}\\n',\n", " 'span':'The highlighted spans are: \\n{span}\\n\\n',\n", " 'prefix':'Prefix: {prefix}\\n\\n',\n", " })\n", "\n", "print(prompt)" ] }, { "cell_type": "markdown", "id": "72d33ff4", "metadata": {}, "source": [ "# Define an Evaluator\n", "An Evaluator automatically scores the output once its metrics are defined. You can use the predefined metrics or define new ones."
] }, { "cell_type": "code", "execution_count": 4, "id": "37788167", "metadata": {}, "outputs": [], "source": [ "evaluator = DefaultEvaluator(criteria = ['str_em','length','rouge'])" ] }, { "cell_type": "markdown", "id": "fd3541a3", "metadata": {}, "source": [ "# Define Generation Modules and Enhancing Modules" ] }, { "cell_type": "code", "execution_count": 5, "id": "9ad59c1a", "metadata": {}, "outputs": [], "source": [ "# attributer = AttributingModule(model='gpt-3.5-turbo')\n", "llm = LLM(model='gpt-3.5-turbo',prompt_maker=prompt, self_prompt={'INST':instruction})" ] }, { "cell_type": "markdown", "id": "1e4a38c0", "metadata": {}, "source": [ "# Construct a Pipeline" ] }, { "cell_type": "code", "execution_count": null, "id": "14b5c262", "metadata": {}, "outputs": [], "source": [ "pipeline = Sequence(sequence = [llm], head_prompt_maker = prompt, evaluator = evaluator, dataset = dataset)" ] }, { "cell_type": "markdown", "id": "6788a72b", "metadata": {}, "source": [ "# Run" ] }, { "cell_type": "code", "execution_count": null, "id": "f12bb819", "metadata": {}, "outputs": [], "source": [ "pipeline.run_on_dataset(datakeys=['question','docs'], init_docs='docs')" ] }, { "cell_type": "code", "execution_count": null, "id": "18a787c6", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "torch", "language": "python", "name": "torch" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 5 }