{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c8b6fdc7-e525-44de-bc51-0eb838d1d1af",
   "metadata": {},
   "source": [
    "### Push UDLM to Hub\n",
    "\n",
    "Loads EMA weights from a local training checkpoint into a `UDLM` model, pushes the model to the Hugging Face Hub, and then reloads it from the Hub as a smoke test."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ce3f3e3-9bc0-45cc-8e07-e507378df58a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import huggingface_hub\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "from models.hf import UDLMConfig\n",
    "from models.hf import UDLM\n",
    "from models.ema import ExponentialMovingAverage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "684e2f51-a993-405a-8447-dc7e1445465c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse a token cached on disk when one exists; otherwise pass `token=None`\n",
    "# and let `huggingface_hub.login` obtain one itself.\n",
    "# `HF_HOME` is optional, so fall back to the default cache location\n",
    "# documented by huggingface_hub instead of raising `KeyError` when it is unset.\n",
    "default_cache_dir = os.path.join(os.path.expanduser('~'), '.cache')\n",
    "hf_home = os.path.expanduser(os.environ.get(\n",
    "    'HF_HOME',\n",
    "    os.path.join(os.environ.get('XDG_CACHE_HOME', default_cache_dir), 'huggingface')))\n",
    "token_path = os.path.join(hf_home, 'token')\n",
    "\n",
    "if os.path.exists(token_path):\n",
    "  with open(token_path, 'r') as f:\n",
    "    token = f.read().strip()\n",
    "else:\n",
    "  token = None\n",
    "huggingface_hub.login(token=token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65dc9b6e-5686-438f-935f-b4928a4e2c84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register the custom config / model classes with the transformers `Auto*` machinery.\n",
    "UDLMConfig.register_for_auto_class()\n",
    "UDLM.register_for_auto_class('AutoModelForMaskedLM')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35732a5f-676b-4a7b-b21a-470db0e498d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda'\n",
    "\n",
    "# Which checkpoint to publish: 'lm1b' or 'qm9'.\n",
    "# NOTE: the `UDLMConfig` values in the next cell are not switched by this\n",
    "# setting; make sure they match the checkpoint you are publishing.\n",
    "dataset = 'lm1b'\n",
    "\n",
    "tokenizer_names = {\n",
    "    'lm1b': 'bert-base-uncased',\n",
    "    'qm9': 'yairschiff/qm9-tokenizer',\n",
    "}\n",
    "hub_repos = {\n",
    "    'lm1b': 'kuleshov-group/udlm-lm1b',\n",
    "    'qm9': 'kuleshov-group/udlm-qm9',\n",
    "}\n",
    "\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(\n",
    "    tokenizer_names[dataset], trust_remote_code=True)\n",
    "name_or_path = hub_repos[dataset]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f246fa0c-077d-432c-98fa-f64da7c17cea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model hyperparameters; these need to match the checkpoint loaded below.\n",
    "config = UDLMConfig(\n",
    "    vocab_size=tokenizer.vocab_size,\n",
    "    model_length=128,\n",
    "    hidden_dim=768,\n",
    "    cond_dim=128,\n",
    "    n_blocks=12,\n",
    "    n_heads=12,\n",
    "    dropout=0.1,\n",
    "    time_conditioning=True,\n",
    "    cfg=False,\n",
    "    cfg_num_classes=-1,\n",
    "    return_dict=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bcef091-fad2-4e84-a1bc-03c629572d24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freshly initialised model plus an EMA helper over the backbone parameters;\n",
    "# the EMA state is replaced by the checkpoint's state further down.\n",
    "model = UDLM(config)\n",
    "ema = ExponentialMovingAverage(\n",
    "    model.backbone.parameters(),\n",
    "    decay=0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a5f8cb6-0021-46bb-a9bc-3db36a28f111",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Point the `Auto*` classes at the remote code stored in the Hub repo.\n",
    "# NOTE(review): assumes the remote modules are named `configuration_udlm.py`\n",
    "# and `modeling_udlm.py` -- confirm against the files in the Hub repo.\n",
    "model.config._name_or_path = name_or_path\n",
    "model.config.auto_map = {\n",
    "    'AutoConfig': f'{name_or_path}--configuration_udlm.UDLMConfig',\n",
    "    'AutoModelForMaskedLM': f'{name_or_path}--modeling_udlm.UDLM',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0062a090-b72b-4b7c-9da9-949e5d03e8da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local path to the trained checkpoint file; fill this in before running.\n",
    "ckpt_path = ''\n",
    "if not os.path.isfile(ckpt_path):\n",
    "  raise FileNotFoundError(\n",
    "      f'Checkpoint not found: {ckpt_path!r}. Set `ckpt_path` to a trained checkpoint file.')\n",
    "\n",
    "# SECURITY: `torch.load` unpickles the file; only load checkpoints you trust.\n",
    "# Load onto CPU so loading does not depend on the GPU layout the checkpoint\n",
    "# was saved with; tensors are moved to `device` in later cells.\n",
    "ckpt = torch.load(ckpt_path, map_location='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "974d3128-5432-4c4f-ae48-687b25f2b29b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the EMA state from the checkpoint and write it into the backbone.\n",
    "ema.load_state_dict(ckpt['ema'])\n",
    "ema.copy_to(model.backbone.parameters())\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f111403-3bda-4f5b-b98e-ccb4c7cb93ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confirm EMA params loaded.\n",
    "# Raise instead of printing so that Run All stops here and never pushes a\n",
    "# model whose weights do not match the checkpoint.\n",
    "for c, m in zip(ema.shadow_params, ckpt['ema']['shadow_params']):\n",
    "  if not torch.allclose(c.to(device), m.to(device)):\n",
    "    raise RuntimeError('Issue with EMA!')\n",
    "\n",
    "# The EMA was built from, and copied into, `model.backbone.parameters()`,\n",
    "# so compare against those rather than all of `model.parameters()`.\n",
    "for c, m in zip(ema.shadow_params, model.backbone.parameters()):\n",
    "  if not torch.allclose(c.to(device), m.to(device)):\n",
    "    raise RuntimeError('Issue with EMA!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d172396-b837-4384-95ec-7e6704c89b79",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.push_to_hub(name_or_path, private=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ccf941f-e01a-4233-be07-2cb1dd3004ee",
   "metadata": {},
   "source": [
    "### Test Model from Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53091198-af84-40fc-84a2-802779e8e75e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_test = transformers.AutoModelForMaskedLM.from_pretrained(name_or_path, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14fbb458-444e-4127-b821-d46a040d97a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random token ids in [0, 10) with shape (2, 10) as a dummy batch.\n",
    "input_ids = torch.randint(10, size=(2, 10)).to(device)\n",
    "model_test = model_test.to(device)\n",
    "model_test.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee5d3580-5fbb-44a4-ba92-5c92bb6a208b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: run the forward pass once, without autograd, and reuse the output.\n",
    "with torch.no_grad():\n",
    "  output = model_test(input_ids, torch.zeros(2,).to(device))\n",
    "print(output.shape)\n",
    "print(output.mean())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}