{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Skeleton visualization and sample export\n",
    "\n",
    "Loads 2D keypoint inputs (`<category>_<split>.pkl`) and 3D outputs/labels (`<category>_3dlfm.pkl`), plots predicted vs. ground-truth skeletons, and exports a fixed subset of frames (with their images) to `final/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import random  # only needed if the random frame sampling in the export section is re-enabled\n",
    "import shutil\n",
    "\n",
    "import numpy as np\n",
    "import plotly.graph_objects as go\n",
    "from PIL import Image\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Skeleton definitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def retrieve_joint_connections(dataset, return_num_kpts=False):\n",
    "    \"\"\"Return the skeleton edges (pairs of keypoint indices) defined for `dataset`.\n",
    "\n",
    "    Args:\n",
    "        dataset: Dataset / object-category name, e.g. 'Human36M', 'aeroplane', 'tigersubset'.\n",
    "        return_num_kpts: If True, also return `defined_kpts`, the number of keypoints the\n",
    "            skeleton is defined over. Defaults to False, which keeps the original return value.\n",
    "\n",
    "    Returns:\n",
    "        `joint_connections` (a new list of [start, end] index pairs on every call), or\n",
    "        `(joint_connections, defined_kpts)` when `return_num_kpts` is True.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If no skeleton is defined for `dataset`.\n",
    "    \"\"\"\n",
    "    if dataset == 'Human36M':\n",
    "        joint_connections = [\n",
    "            [14, 15], [15, 16], [13, 12], [12, 11], [9, 8], [8, 7], [4, 5], [5, 6],\n",
    "            [3, 2], [2, 1], [7, 0], [0, 4], [0, 1], [8, 11], [8, 14], [9, 10],\n",
    "        ]\n",
    "        defined_kpts = 17\n",
    "\n",
    "    elif dataset == 'face':\n",
    "        joint_connections = [\n",
    "            [0, 1], [1, 2], [2, 3], [3, 4], [5, 6], [6, 7], [7, 8], [8, 9], [10, 11], [11, 12], [12, 13],\n",
    "            [14, 15], [15, 16], [16, 17], [17, 18], [19, 20], [20, 21], [21, 22], [22, 23], [23, 24], [24, 19],\n",
    "            [25, 26], [26, 27], [27, 28], [28, 29], [29, 30], [30, 25], [31, 32], [32, 33], [33, 34], [34, 35],\n",
    "            [35, 36], [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 42], [42, 31], [43, 44], [44, 45],\n",
    "            [45, 46], [46, 47], [47, 48], [48, 49], [49, 50], [50, 43],\n",
    "        ]\n",
    "        defined_kpts = 53\n",
    "\n",
    "    elif dataset in ('cheetah', 'cheetahtr'):\n",
    "        joint_connections = [\n",
    "            [3, 1], [1, 2], [3, 4], [4, 5], [5, 6], [6, 7], [8, 9], [9, 10], [3, 8],\n",
    "            [3, 11], [11, 12], [12, 13], [5, 17], [17, 18], [18, 19], [5, 14], [14, 15], [15, 16],\n",
    "        ]\n",
    "        defined_kpts = 20\n",
    "\n",
    "    elif dataset == 'cheetahsub':\n",
    "        joint_connections = [\n",
    "            [2, 0], [0, 1], [2, 3], [3, 4], [4, 5], [6, 7], [7, 8], [2, 6],\n",
    "            [2, 9], [9, 10], [10, 11], [3, 15], [15, 16], [16, 17], [3, 12], [12, 13], [13, 14],\n",
    "        ]\n",
    "        defined_kpts = 18\n",
    "\n",
    "    elif dataset == 'hands':\n",
    "        joint_connections = [\n",
    "            [0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10],\n",
    "            [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20],\n",
    "        ]\n",
    "        defined_kpts = 21\n",
    "\n",
    "    elif dataset == 'amass':\n",
    "        joint_connections = [\n",
    "            [9, 13], [13, 16], [16, 18], [18, 20], [6, 9], [3, 6], [0, 3], [0, 1], [0, 2], [1, 4], [4, 7], [7, 10], [2, 5], [5, 8], [8, 11],\n",
    "            [9, 14], [14, 17], [17, 19], [19, 21], [6, 12], [12, 15],\n",
    "            [20, 34], [34, 35], [35, 36],\n",
    "            [20, 22], [22, 23], [23, 24],\n",
    "            [20, 25], [25, 26], [26, 27],\n",
    "            [20, 31], [31, 32], [32, 33],\n",
    "            [20, 28], [28, 29], [29, 30],\n",
    "            [21, 49], [49, 50], [50, 51],\n",
    "            [21, 37], [37, 38], [38, 39],\n",
    "            [21, 40], [40, 41], [41, 42],\n",
    "            [21, 46], [46, 47], [47, 48],\n",
    "            [21, 43], [43, 44], [44, 45],\n",
    "        ]\n",
    "        defined_kpts = 52\n",
    "\n",
    "    elif dataset == 'openmonkey':\n",
    "        joint_connections = [\n",
    "            [0, 1], [1, 2], [2, 3], [3, 4], [2, 5], [5, 6], [2, 7], [7, 8], [8, 9], [7, 10], [10, 11], [7, 12],\n",
    "        ]\n",
    "        defined_kpts = 13\n",
    "\n",
    "    elif dataset == 'wholebodyh36m':\n",
    "        joint_connections = [\n",
    "            # Body; [9, 91] and [10, 112] attach the two hand sub-skeletons\n",
    "            [0, 1], [1, 3], [0, 2], [2, 4],\n",
    "            [5, 7], [7, 9], [9, 91],\n",
    "            [6, 8], [8, 10], [10, 112],\n",
    "            [5, 6], [6, 12], [11, 12], [5, 11],\n",
    "            [12, 14], [14, 16], [16, 20], [16, 21], [16, 22],\n",
    "            [11, 13], [13, 15], [15, 17], [15, 18], [15, 19],\n",
    "\n",
    "            # Face (NOTE: its last edge [91, 92] is repeated at the start of the left hand)\n",
    "            [23, 24], [24, 25], [25, 26], [26, 27], [27, 28], [28, 29], [29, 30], [30, 31], [31, 32], [32, 33], [33, 34],\n",
    "            [34, 35], [35, 36], [36, 37], [37, 38], [38, 39], [40, 41], [41, 42], [42, 43], [43, 44], [59, 60], [60, 61],\n",
    "            [61, 62], [62, 63], [63, 64], [59, 64], [45, 46], [46, 47], [47, 48], [48, 49], [65, 66], [66, 67], [67, 68],\n",
    "            [68, 69], [69, 70], [65, 70], [50, 51], [51, 52], [52, 53], [54, 55], [55, 56], [56, 57], [57, 58], [71, 72],\n",
    "            [72, 73], [73, 74], [74, 75], [75, 76], [76, 77], [77, 78], [78, 79], [79, 80], [80, 81], [81, 82], [82, 83],\n",
    "            [83, 84], [84, 85], [85, 86], [86, 87], [87, 88], [88, 89], [89, 90], [91, 92],\n",
    "\n",
    "            # Left hand\n",
    "            [91, 92], [92, 93], [93, 94], [94, 95], [91, 96], [96, 97], [97, 98], [98, 99], [91, 100], [100, 101], [101, 102],\n",
    "            [102, 103], [91, 104], [104, 105], [105, 106], [106, 107], [91, 108], [108, 109], [109, 110], [110, 111],\n",
    "\n",
    "            # Right hand\n",
    "            [112, 113], [113, 114], [114, 115], [115, 116], [112, 117], [117, 118], [118, 119], [119, 120], [112, 121],\n",
    "            [121, 122], [122, 123], [123, 124], [112, 125], [125, 126], [126, 127], [127, 128], [112, 129], [129, 130],\n",
    "            [130, 131], [131, 132],\n",
    "        ]\n",
    "        defined_kpts = 133\n",
    "\n",
    "    elif dataset == 'bp4d+':\n",
    "        joint_connections = [\n",
    "            # Left eyebrow (viewed from the model's perspective)\n",
    "            [15, 16], [16, 17], [17, 18], [18, 19],\n",
    "            [10, 11], [11, 12], [12, 13], [13, 14],\n",
    "            # Right eyebrow\n",
    "            [0, 1], [1, 2], [2, 3], [3, 4],\n",
    "            [5, 6], [6, 7], [7, 8], [8, 9],\n",
    "            # Bridge of the nose (from between the eyebrows to the tip)\n",
    "            [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 42], [42, 43], [43, 44], [44, 45], [45, 46], [46, 47],\n",
    "            # Left eye\n",
    "            [28, 29], [29, 30], [30, 31], [31, 32], [32, 33], [33, 34], [34, 35], [35, 28],\n",
    "            # Right eye\n",
    "            [20, 21], [21, 22], [22, 23], [23, 24], [24, 25], [25, 26], [26, 27], [27, 20],\n",
    "            # Outer part of the lips (outline of the lips)\n",
    "            [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54], [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48],\n",
    "            # Inner part of the lips (detail within the lips)\n",
    "            [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66], [66, 67], [67, 60],\n",
    "            # Jawline (from left ear, around the chin, to right ear)\n",
    "            [68, 69], [69, 70], [70, 71], [71, 72], [72, 73], [73, 74], [74, 75], [75, 76], [76, 77], [77, 78], [78, 79], [79, 80], [80, 81], [81, 82],\n",
    "        ]\n",
    "        defined_kpts = 83\n",
    "\n",
    "    elif dataset == 'panoptic':\n",
    "        joint_connections = [\n",
    "            [0, 1], [0, 3], [3, 4], [4, 5], [0, 2], [2, 6], [6, 7], [7, 8], [2, 12], [12, 13], [13, 14], [0, 9], [9, 10], [10, 11],\n",
    "        ]\n",
    "        defined_kpts = 15\n",
    "\n",
    "    elif dataset == 'aeroplane':\n",
    "        joint_connections = [\n",
    "            [2, 5], [1, 4], [5, 3], [3, 7], [7, 0], [0, 5], [5, 7], [5, 6], [6, 0], [6, 3], [2, 4], [2, 1],\n",
    "        ]\n",
    "        defined_kpts = 8\n",
    "\n",
    "    elif dataset == 'bicycle':\n",
    "        joint_connections = [\n",
    "            [0, 3], [0, 7], [0, 2], [0, 6], [0, 10], [9, 10], [4, 10], [8, 10], [1, 9], [5, 9],\n",
    "        ]\n",
    "        defined_kpts = 11\n",
    "\n",
    "    elif dataset in ('tiger', 'cow', 'horse', 'hippo', 'dog'):\n",
    "        joint_connections = [\n",
    "            [0, 24], [0, 20], [1, 21], [1, 24], [7, 25], [19, 25], [6, 17],\n",
    "            [4, 15], [3, 14], [9, 15], [8, 14], [9, 13], [8, 12],\n",
    "            [2, 23], [2, 22], [2, 24], [11, 17], [10, 16], [5, 16],\n",
    "            [7, 10], [7, 11], [13, 18], [12, 18], [7, 18], [24, 18],\n",
    "        ]\n",
    "        defined_kpts = 26\n",
    "\n",
    "    elif dataset in ('tigersubset', 'cowsubset', 'horsesubset', 'hipposubset', 'dogsubset'):\n",
    "        joint_connections = [\n",
    "            [3, 17], [15, 17], [2, 13], [5, 9], [4, 8], [7, 13], [6, 12], [1, 12], [3, 6],\n",
    "            [3, 7], [9, 14], [8, 14], [3, 14], [5, 11], [4, 10], [0, 14], [0, 16],\n",
    "        ]\n",
    "        defined_kpts = 18\n",
    "\n",
    "    elif dataset == 'boat':\n",
    "        joint_connections = [\n",
    "            [0, 2], [0, 3], [0, 1], [1, 2], [1, 3], [2, 4], [3, 5], [4, 5], [1, 5], [1, 4],\n",
    "        ]\n",
    "        defined_kpts = 6\n",
    "\n",
    "    elif dataset == 'bottle':\n",
    "        joint_connections = [\n",
    "            [0, 1], [1, 2], [0, 2], [3, 4], [3, 5], [4, 5], [1, 4], [0, 3], [2, 5], [1, 6], [0, 6], [2, 6],\n",
    "        ]\n",
    "        defined_kpts = 8\n",
    "\n",
    "    elif dataset in ('busfull', 'bus'):\n",
    "        joint_connections = [\n",
    "            [5, 7], [4, 5], [6, 7], [4, 6], [1, 5], [1, 3], [3, 7], [0, 1],\n",
    "            [2, 3], [0, 2], [2, 10], [0, 8], [8, 9], [10, 11], [6, 11], [4, 9],\n",
    "        ]\n",
    "        defined_kpts = 12\n",
    "\n",
    "    elif dataset == 'car':\n",
    "        joint_connections = [\n",
    "            [0, 8], [0, 4], [4, 10], [8, 10],\n",
    "            [10, 9], [9, 11], [8, 11], [11, 6],\n",
    "            [9, 2], [2, 6], [4, 1], [5, 1],\n",
    "            [0, 5], [5, 7], [1, 3], [7, 3], [3, 2], [7, 6],\n",
    "        ]\n",
    "        defined_kpts = 12\n",
    "\n",
    "    elif dataset == 'busmissing':\n",
    "        joint_connections = [\n",
    "            [5, 7], [4, 5], [6, 7], [4, 6], [1, 5], [1, 3], [3, 7], [0, 1], [2, 3], [0, 2], [2, 6], [0, 4],\n",
    "        ]\n",
    "        defined_kpts = 8\n",
    "\n",
    "    elif dataset == 'diningtable':\n",
    "        joint_connections = [\n",
    "            [0, 2], [4, 6], [1, 3], [5, 7], [1, 5], [3, 7], [0, 4], [2, 6], [0, 1], [2, 3], [4, 5], [6, 7],\n",
    "        ]\n",
    "        defined_kpts = 8\n",
    "\n",
    "    elif dataset == 'tvmonitor':\n",
    "        joint_connections = [\n",
    "            [5, 7], [4, 5], [4, 6], [6, 7], [0, 1], [0, 2], [2, 3], [1, 3], [3, 7], [1, 5], [2, 6], [0, 4],\n",
    "        ]\n",
    "        defined_kpts = 8\n",
    "\n",
    "    elif dataset == 'train':\n",
    "        # NOTE: [1, 5] is listed twice, so that edge is drawn twice.\n",
    "        joint_connections = [\n",
    "            [4, 5], [4, 6], [6, 7], [5, 7], [0, 1], [1, 3], [2, 3], [0, 2], [1, 5], [0, 4], [2, 6], [3, 7], [1, 5],\n",
    "        ]\n",
    "        defined_kpts = 8\n",
    "\n",
    "    elif dataset == 'train16':\n",
    "        joint_connections = [\n",
    "            [0, 1], [1, 5], [5, 9], [9, 15], [3, 7], [7, 11], [11, 13], [2, 3], [2, 6], [6, 10], [10, 12], [1, 3], [0, 2],\n",
    "            [0, 4], [4, 8], [8, 14], [15, 13], [13, 12], [12, 14], [14, 15],\n",
    "        ]\n",
    "        defined_kpts = 16\n",
    "\n",
    "    elif dataset == 'motorbike':\n",
    "        joint_connections = [\n",
    "            [6, 2], [2, 9], [2, 3], [3, 8], [5, 8],\n",
    "            [3, 5], [2, 1], [1, 0], [0, 7], [0, 4],\n",
    "            [4, 7], [1, 4], [1, 7], [1, 5], [1, 8],\n",
    "        ]\n",
    "        defined_kpts = 10\n",
    "\n",
    "    elif dataset == 'sofa':\n",
    "        joint_connections = [\n",
    "            [1, 5], [5, 4], [4, 6], [6, 2], [2, 0],\n",
    "            [1, 0], [0, 4], [1, 3], [7, 5], [2, 3],\n",
    "            [3, 7], [9, 7], [7, 6], [6, 8], [8, 9],\n",
    "        ]\n",
    "        defined_kpts = 10\n",
    "\n",
    "    elif dataset == 'chair':\n",
    "        joint_connections = [\n",
    "            [7, 3], [6, 2], [9, 5], [8, 4], [7, 9],\n",
    "            [8, 6], [6, 7], [9, 8], [9, 1], [8, 0], [1, 0],\n",
    "        ]\n",
    "        defined_kpts = 10\n",
    "\n",
    "    # MBW datasets\n",
    "    elif dataset in ('colobusmonkey', 'chimpanzee'):\n",
    "        joint_connections = [\n",
    "            [0, 1], [1, 2], [2, 3], [3, 4], [1, 5], [5, 6], [6, 7], [1, 8],\n",
    "            [8, 9], [9, 10], [10, 11], [11, 12], [9, 13], [13, 14], [14, 15],\n",
    "        ]\n",
    "        defined_kpts = 16\n",
    "\n",
    "    elif dataset == 'tigerzoo':\n",
    "        joint_connections = [\n",
    "            [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [1, 6], [6, 7], [1, 8], [8, 9], [3, 10], [10, 11], [3, 12], [12, 13],\n",
    "        ]\n",
    "        defined_kpts = 14\n",
    "\n",
    "    elif dataset == 'clownfish':\n",
    "        joint_connections = [[0, 1], [1, 2], [2, 3], [1, 4], [1, 5]]\n",
    "        defined_kpts = 6\n",
    "\n",
    "    elif dataset == 'fish':\n",
    "        joint_connections = [\n",
    "            [0, 1], [1, 2], [2, 3], [1, 3], [3, 4], [4, 5], [5, 6], [6, 7],\n",
    "            [5, 7], [5, 8], [8, 9], [9, 10], [8, 10], [10, 11], [11, 0],\n",
    "        ]\n",
    "        defined_kpts = 12\n",
    "\n",
    "    elif dataset == 'seahorse':\n",
    "        joint_connections = [[0, 1], [1, 2], [2, 3], [1, 3], [3, 4], [4, 5]]\n",
    "        defined_kpts = 6\n",
    "\n",
    "    else:\n",
    "        # Previously this fell through to an UnboundLocalError on the return statement.\n",
    "        raise ValueError('No skeleton defined for dataset: {!r}'.format(dataset))\n",
    "\n",
    "    if return_num_kpts:\n",
    "        return joint_connections, defined_kpts\n",
    "    return joint_connections"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load inputs and predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which category / split to visualize.\n",
    "category_name = 'aeroplane'\n",
    "suffix = 'val'\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only load files you trust.\n",
    "# 2D inputs and the image path for each frame.\n",
    "input_data_path = category_name + '_' + suffix + '.pkl'\n",
    "with open(input_data_path, 'rb') as f:\n",
    "    input_data = pickle.load(f)\n",
    "input_2d = input_data['W_GT']\n",
    "image_path = input_data['image_path']\n",
    "\n",
    "# 3D predictions and ground-truth labels.\n",
    "pred_data_path = category_name + '_3dlfm.pkl'\n",
    "with open(pred_data_path, 'rb') as f:\n",
    "    pred_data = pickle.load(f)\n",
    "labels_3d = pred_data['labels_3d']\n",
    "outputs_3d = pred_data['outputs_3d']\n",
    "\n",
    "joint_connections = retrieve_joint_connections(category_name)\n",
    "\n",
    "print('Number of images: ', len(image_path))\n",
    "print('Input 2D shape: ', input_2d.shape)\n",
    "print('Labels 3D shape: ', labels_3d.shape)\n",
    "print('Outputs 3D shape: ', outputs_3d.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _filter_connections(joint_connections, masks=None):\n",
    "    \"\"\"Return only the edges whose two endpoints have mask value 1.0 (all edges if masks is None).\"\"\"\n",
    "    if masks is None:\n",
    "        return joint_connections\n",
    "    return [conn for conn in joint_connections if masks[conn[0]] == 1.0 and masks[conn[1]] == 1.0]\n",
    "\n",
    "\n",
    "def get_trace3d(joint_connections, points3d, point_color, line_color, name, masks=None):\n",
    "    \"\"\"Build plotly traces (markers + skeleton lines) for one set of 3D keypoints.\n",
    "\n",
    "    Args:\n",
    "        joint_connections: Iterable of [start, end] keypoint-index pairs to draw.\n",
    "        points3d: Keypoint array, presumably (num_kpts, 3) -- TODO confirm.\n",
    "        point_color: Marker colour.\n",
    "        line_color: Line colour.\n",
    "        name: Legend name used for both traces.\n",
    "        masks: Optional per-keypoint array; only keypoints whose value is 1.0 get a marker.\n",
    "\n",
    "    Returns:\n",
    "        [marker_trace, line_trace]\n",
    "    \"\"\"\n",
    "    visible = points3d if masks is None else points3d[masks == 1.0]\n",
    "    # Data Y is drawn on plotly's vertical z axis and data Z on plotly's y axis.\n",
    "    x, z, y = visible.T\n",
    "    x_all, z_all, y_all = points3d.T\n",
    "\n",
    "    trace_pts = go.Scatter3d(\n",
    "        x=x, y=y, z=z,\n",
    "        mode='markers',\n",
    "        name=name,\n",
    "        marker=dict(symbol='circle', size=6, color=point_color),\n",
    "    )\n",
    "\n",
    "    # Edges index the unfiltered points so keypoint indices stay valid;\n",
    "    # the None after each edge breaks the polyline between edges.\n",
    "    x_lines, y_lines, z_lines = [], [], []\n",
    "    for start, end in joint_connections:\n",
    "        x_lines.extend([x_all[start], x_all[end], None])\n",
    "        y_lines.extend([y_all[start], y_all[end], None])\n",
    "        z_lines.extend([z_all[start], z_all[end], None])\n",
    "\n",
    "    trace_lines = go.Scatter3d(\n",
    "        x=x_lines, y=y_lines, z=z_lines,\n",
    "        mode='lines',\n",
    "        name=name,\n",
    "        line=dict(width=6, color=line_color),\n",
    "    )\n",
    "    return [trace_pts, trace_lines]\n",
    "\n",
    "\n",
    "def plot_3d_skeleton(predictions_3d, labels_3d, joint_connections, range_scale=2500, masks=None):\n",
    "    \"\"\"Show predicted (blue) and ground-truth (red) 3D skeletons in one interactive figure.\n",
    "\n",
    "    Args:\n",
    "        predictions_3d: Predicted keypoints for one sample (see get_trace3d).\n",
    "        labels_3d: Ground-truth keypoints for the same sample.\n",
    "        joint_connections: [start, end] index pairs, e.g. from retrieve_joint_connections.\n",
    "        range_scale: Currently unused (axes use aspectmode='data'); kept for backward compatibility.\n",
    "        masks: Optional per-keypoint array; edges touching a keypoint whose value is not 1.0 are dropped.\n",
    "    \"\"\"\n",
    "    if masks is not None:\n",
    "        updated_connections = _filter_connections(joint_connections, masks)\n",
    "        print('Updated connections: {}'.format(updated_connections))\n",
    "    else:\n",
    "        updated_connections = joint_connections\n",
    "\n",
    "    traces = get_trace3d(updated_connections, predictions_3d, 'blue', 'blue', 'Predicted KP', masks=masks)\n",
    "    traces += get_trace3d(updated_connections, labels_3d, 'red', 'red', 'Groundtruth KP', masks=masks)\n",
    "\n",
    "    fig = go.Figure(data=traces)\n",
    "    # aspectmode='data' keeps the skeleton's true proportions; axis titles and ticks are hidden.\n",
    "    fig.update_layout(scene=dict(\n",
    "        aspectmode='data',\n",
    "        xaxis=dict(title='', showticklabels=False),\n",
    "        yaxis=dict(title='', showticklabels=False),\n",
    "        zaxis=dict(title='', showticklabels=False),\n",
    "    ))\n",
    "    fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_trace2d(joint_connections, points2d, point_color, line_color, name, masks=None, get_lines=None):\n",
    "    \"\"\"Build plotly traces for one set of 2D keypoints.\n",
    "\n",
    "    Args:\n",
    "        joint_connections: Iterable of [start, end] keypoint-index pairs to draw.\n",
    "        points2d: Keypoint array, presumably (num_kpts, 2) -- TODO confirm.\n",
    "        point_color: Marker colour.\n",
    "        line_color: Line colour.\n",
    "        name: Legend name used for both traces (may be None).\n",
    "        masks: Optional per-keypoint array; only keypoints whose value is 1.0 get a marker.\n",
    "        get_lines: If falsy but not None (e.g. False), return only the marker trace;\n",
    "            None or truthy also returns the line trace.\n",
    "\n",
    "    Returns:\n",
    "        [marker_trace, line_trace], or [marker_trace] when lines are disabled.\n",
    "    \"\"\"\n",
    "    visible = points2d if masks is None else points2d[masks == 1.0]\n",
    "    x, y = visible.T\n",
    "    x_all, y_all = points2d.T\n",
    "\n",
    "    trace_pts = go.Scatter(\n",
    "        x=x, y=y,\n",
    "        mode='markers',\n",
    "        name=name,\n",
    "        marker=dict(symbol='circle', size=6, color=point_color),\n",
    "    )\n",
    "    if get_lines is not None and not get_lines:\n",
    "        return [trace_pts]\n",
    "\n",
    "    # Edges index the unfiltered points; None separates consecutive edges.\n",
    "    x_lines, y_lines = [], []\n",
    "    for start, end in joint_connections:\n",
    "        x_lines.extend([x_all[start], x_all[end], None])\n",
    "        y_lines.extend([y_all[start], y_all[end], None])\n",
    "\n",
    "    trace_lines = go.Scatter(\n",
    "        x=x_lines, y=y_lines,\n",
    "        mode='lines',\n",
    "        name=name,\n",
    "        line=dict(width=2, color=line_color),\n",
    "    )\n",
    "    return [trace_pts, trace_lines]\n",
    "\n",
    "\n",
    "def plot_2d_skeleton(predictions_2d, labels_2d, joint_connections, masks=None):\n",
    "    \"\"\"Show predicted (blue) and ground-truth (red) 2D skeletons in a 700x700 figure.\n",
    "\n",
    "    masks: optional per-keypoint array; edges touching a keypoint whose value is not 1.0 are dropped.\n",
    "    \"\"\"\n",
    "    if masks is not None:\n",
    "        updated_connections = _filter_connections(joint_connections, masks)\n",
    "        print('Updated connections: {}'.format(updated_connections))\n",
    "    else:\n",
    "        updated_connections = joint_connections\n",
    "\n",
    "    traces = get_trace2d(updated_connections, predictions_2d, 'blue', 'blue', 'Predicted KP', masks=masks)\n",
    "    traces += get_trace2d(updated_connections, labels_2d, 'red', 'red', 'Groundtruth KP', masks=masks)\n",
    "\n",
    "    layout = go.Layout(width=700, height=700, margin=dict(r=20, l=10, b=10, t=10))\n",
    "    go.Figure(data=traces, layout=layout).show()\n",
    "\n",
    "\n",
    "def plot_2d_skeleton_on_image(predictions_2d, labels_2d, joint_connections, image_path, masks=None, get_lines=None):\n",
    "    \"\"\"Overlay predicted (blue) and ground-truth (red) 2D skeletons on the image at `image_path`.\n",
    "\n",
    "    Args:\n",
    "        predictions_2d: Predicted 2D keypoints in image pixel coordinates (assumed).\n",
    "        labels_2d: Ground-truth 2D keypoints.\n",
    "        joint_connections: [start, end] index pairs, e.g. from retrieve_joint_connections.\n",
    "        image_path: Path of the background image.\n",
    "        masks: Optional per-keypoint array; edges touching a keypoint whose value is not 1.0 are dropped.\n",
    "        get_lines: Forwarded to get_trace2d (False draws keypoints only).\n",
    "    \"\"\"\n",
    "    # Copy the pixels into memory so the file handle is closed right away.\n",
    "    with Image.open(image_path) as img:\n",
    "        image = img.copy()\n",
    "    width, height = image.size\n",
    "\n",
    "    updated_connections = _filter_connections(joint_connections, masks)\n",
    "    print('updated connections: {}'.format(updated_connections))\n",
    "\n",
    "    traces = [\n",
    "        # Invisible corner markers so the axes span the whole image.\n",
    "        go.Scatter(\n",
    "            x=[0, width],\n",
    "            y=[0, height],\n",
    "            mode='markers',\n",
    "            marker_opacity=0,\n",
    "            hoverinfo='none',\n",
    "            showlegend=False,\n",
    "        ),\n",
    "    ]\n",
    "    traces += get_trace2d(updated_connections, predictions_2d, 'blue', 'blue', None, masks, get_lines=get_lines)\n",
    "    traces += get_trace2d(updated_connections, labels_2d, 'red', 'red', None, masks, get_lines=get_lines)\n",
    "\n",
    "    layout = go.Layout(\n",
    "        width=width,\n",
    "        height=height,\n",
    "        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[0, width]),\n",
    "        # Reversed y range puts the origin at the top-left, like image pixel coordinates.\n",
    "        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[height, 0], scaleanchor='x'),\n",
    "        images=[go.layout.Image(source=image, xref='x', yref='y', x=0, y=0, sizex=width, sizey=height,\n",
    "                                sizing='stretch', opacity=1.0, layer='below')],\n",
    "        margin=dict(r=10, l=10, b=10, t=10),\n",
    "        hovermode='closest',\n",
    "        showlegend=False,\n",
    "    )\n",
    "    go.Figure(data=traces, layout=layout).show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inspect a single frame\n",
    "\n",
    "Both 3D point sets are negated before plotting. The 2D overlay passes the input keypoints as both the 'prediction' and the 'label', so the two layers coincide."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "frame_idx = 227\n",
    "get_lines = False  # False: 2D keypoints only; None/True: also draw the skeleton lines\n",
    "\n",
    "plot_3d_skeleton(-outputs_3d[frame_idx], -labels_3d[frame_idx], joint_connections, range_scale=2500, masks=None)\n",
    "plot_2d_skeleton_on_image(input_2d[frame_idx], input_2d[frame_idx], joint_connections, image_path[frame_idx], masks=None, get_lines=get_lines)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export a fixed subset of frames\n",
    "\n",
    "Writes `final/<category>.pkl` and copies the selected images under `final/<category>_images/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _archive_relpath(path):\n",
    "    \"\"\"Map an image path to a relative path that stays inside the export directory.\n",
    "\n",
    "    os.path.join(base, path) discards `base` when `path` is absolute, so the copy target\n",
    "    would be the source file itself (shutil.copy then raises SameFileError). The drive,\n",
    "    root and any '..' components are stripped; plain relative paths are unchanged.\n",
    "    \"\"\"\n",
    "    _, tail = os.path.splitdrive(os.path.normpath(path))\n",
    "    parts = [part for part in tail.split(os.sep) if part not in ('', '.', '..')]\n",
    "    return os.path.join(*parts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed frame selection; to sample instead use:\n",
    "# random_indices = random.sample(range(len(image_path)), 10)\n",
    "random_indices = [167, 29, 1, 3, 4, 7, 10, 18, 123, 227]\n",
    "\n",
    "final_data = {\n",
    "    'image_path': np.asarray(image_path)[random_indices].tolist(),\n",
    "    'inputs_2d': input_2d[random_indices],\n",
    "    'labels_3d': labels_3d[random_indices],\n",
    "    'outputs_3d': outputs_3d[random_indices],\n",
    "}\n",
    "\n",
    "# Start from a clean image directory so images from earlier runs are not kept.\n",
    "images_dir = os.path.join('final', category_name + '_images')\n",
    "if os.path.exists(images_dir):\n",
    "    shutil.rmtree(images_dir)\n",
    "os.makedirs(images_dir, exist_ok=True)\n",
    "\n",
    "# Copy each image, mirroring its relative path, and point the exported data at the copies.\n",
    "new_image_paths = []\n",
    "for original_image_path in tqdm(final_data['image_path']):\n",
    "    new_path = os.path.join(images_dir, _archive_relpath(original_image_path))\n",
    "    os.makedirs(os.path.dirname(new_path), exist_ok=True)\n",
    "    shutil.copy(original_image_path, new_path)\n",
    "    new_image_paths.append(new_path)\n",
    "final_data['image_path'] = new_image_paths\n",
    "\n",
    "final_pickle_name = os.path.join('final', category_name + '.pkl')\n",
    "with open(final_pickle_name, 'wb') as f:\n",
    "    pickle.dump(final_data, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lifting",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}