{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tiffusion: anchor-guided sampling\n",
    "\n",
    "Load a trained Tiffusion checkpoint, pin a few observed points (anchors) with a per-point confidence, and sample a full sequence with `predict_weighted_points`. The last cell plots every feature with the anchors that constrain it overlaid."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# The environment must be configured BEFORE torch initialises CUDA: once CUDA has been\n",
    "# initialised (e.g. by torch.cuda.is_available()), changing CUDA_VISIBLE_DEVICES has no effect.\n",
    "os.environ[\"WANDB_ENABLED\"] = \"false\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"  # invalid device index -> no visible GPUs, run on CPU\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "from models.Tiffusion import tiffusion\n",
    "# from models.CSDI import tiffusion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sequence geometry; used both to build the model and to size the observation tensors,\n",
    "# so it must match the architecture the checkpoint was trained with.\n",
    "seq_length = 365\n",
    "feature_dim = 3\n",
    "\n",
    "# Guided-sampling hyper-parameters passed to `predict_weighted_points` (kept fixed).\n",
    "coef = 1.0e-2\n",
    "stepsize = 5.0e-2\n",
    "# Number of sampling steps. Anything from ~100 to 500 works: fewer is faster,\n",
    "# more is more accurate (speed/accuracy trade-off).\n",
    "sampling_steps = 100\n",
    "\n",
    "checkpoint_path = \"./weight/checkpoint-10.pt\"\n",
    "# checkpoint_path = \"../../../data/CSDI/ckpt_baseline_365/checkpoint-10.pt\"\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"seq_length: {seq_length}, feature_dim: {feature_dim}, device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tiffusion.Tiffusion(\n",
    "    seq_length=seq_length,\n",
    "    feature_size=feature_dim,\n",
    "    n_layer_enc=6,\n",
    "    n_layer_dec=4,\n",
    "    d_model=128,\n",
    "    timesteps=500,\n",
    "    sampling_timesteps=200,\n",
    "    loss_type='l1',\n",
    "    beta_schedule='cosine',\n",
    "    n_heads=8,\n",
    "    mlp_hidden_times=4,\n",
    "    attn_pd=0.0,\n",
    "    resid_pd=0.0,\n",
    "    kernel_size=1,\n",
    "    padding_size=0,\n",
    "    control_signal=[],\n",
    ").to(device)\n",
    "\n",
    "# Deserialize on CPU; load_state_dict copies the tensors into the model's (possibly GPU) parameters.\n",
    "checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)\n",
    "model.load_state_dict(checkpoint[\"model\"])\n",
    "model.eval();  # inference only; trailing ';' suppresses the module repr"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling\n",
    "\n",
    "Each anchor is `(time, feature_id, y_value, confidence)`. `observed_points` receives `y_value` and `observed_mask` receives `confidence` at that `(time, feature_id)` cell; every other cell stays 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "anchor_value = [\n",
    "    # (time, feature_id, y-value, confidence)\n",
    "    (0, 0, 0.04, 1.0),\n",
    "    (2, 0, 0.58, 1.0),\n",
    "    # (6, 0, 0.27, 0.5),\n",
    "    # Further anchors can repeat this pattern with period 10, e.g.\n",
    "    # (10, 0, 0.04, 1.0), (12, 0, 0.58, 0.001), (16, 0, 0.27, 0.5), (20, 0, 0.04, 1.0), ...\n",
    "]\n",
    "\n",
    "observed_points = torch.zeros((seq_length, feature_dim), device=device)\n",
    "observed_mask = torch.zeros((seq_length, feature_dim), device=device)\n",
    "\n",
    "for t, feature_id, y_value, confidence in anchor_value:\n",
    "    # Reject out-of-range anchors instead of letting negative indices silently wrap around.\n",
    "    if not (0 <= t < seq_length and 0 <= feature_id < feature_dim):\n",
    "        raise ValueError(\n",
    "            f\"anchor (time={t}, feature_id={feature_id}) is outside the \"\n",
    "            f\"({seq_length}, {feature_dim}) grid\"\n",
    "        )\n",
    "    observed_points[t, feature_id] = y_value\n",
    "    observed_mask[t, feature_id] = confidence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradient-control target and its weight; their exact semantics are defined inside tiffusion.Tiffusion.\n",
    "auc = -10\n",
    "auc_weight = 10.0\n",
    "\n",
    "with torch.no_grad():\n",
    "    results = model.predict_weighted_points(\n",
    "        observed_points,  # (seq_length, feature_dim)\n",
    "        observed_mask,    # (seq_length, feature_dim)\n",
    "        coef,             # fixed\n",
    "        stepsize,         # fixed\n",
    "        sampling_steps,   # fixed\n",
    "        # model_control_signal=model_control_signal,\n",
    "        gradient_control_signal={\n",
    "            \"auc\": auc, \"auc_weight\": auc_weight,\n",
    "        },\n",
    "    )\n",
    "\n",
    "# matplotlib cannot consume CUDA tensors; keep a CPU copy for inspection and plotting.\n",
    "results_cpu = results.detach().cpu() if torch.is_tensor(results) else results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_cpu.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results\n",
    "\n",
    "One figure per feature; red dots are the anchors whose `feature_id` matches that feature."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for feature_id in range(feature_dim):\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(results_cpu[:, feature_id], label=f\"Predicted Feature {feature_id}\")\n",
    "    # Overlay only the anchors that constrain this feature.\n",
    "    feature_anchors = [(t, y) for t, f_id, y, _ in anchor_value if f_id == feature_id]\n",
    "    if feature_anchors:\n",
    "        anchor_t, anchor_y = zip(*feature_anchors)\n",
    "        ax.scatter(anchor_t, anchor_y, c='r', label=\"Anchors\")\n",
    "    ax.set(title=f\"Feature {feature_id}\", xlabel=\"time step\", ylabel=\"value\")\n",
    "    ax.legend()\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rag",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}