File size: 4,527 Bytes
456aee9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import copy
import os
import threading
import uuid
import zipfile
from datetime import datetime

import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image


def add_download_button(gallery, filename_prefix="output"):
    
    def make_3x5_plot(images):
        plot_list = []
        
        # Split the list of images into chunks of 15
        chunks = [images[i:i + 15] for i in range(0, len(images), 15)]
        
        for chunk in chunks:
            fig, axs = plt.subplots(3, 4, figsize=(12, 9))
            for ax in axs.flatten():
                ax.axis("off")
            for ax, img in zip(axs.flatten(), chunk):
                img = img.convert("RGB")
                ax.imshow(img)
            
            plt.tight_layout(h_pad=0.5, w_pad=0.3)

            # Generate a unique filename
            filename = uuid.uuid4()
            tmp_path = f"/tmp/{filename}.png"
            
            # Save the plot to the temporary file
            plt.savefig(tmp_path, bbox_inches='tight', dpi=144)
            
            # Open the saved image
            img = Image.open(tmp_path)
            img = img.convert("RGB")
            img = copy.deepcopy(img)
            
            # Remove the temporary file
            os.remove(tmp_path)

            plot_list.append(img)
            plt.close()
        
        return plot_list
    
    def delete_file_after_delay(file_path, delay):
        def delete_file():
            if os.path.exists(file_path):
                os.remove(file_path)
        
        timer = threading.Timer(delay, delete_file)
        timer.start()
    
    def create_zip_file(images, filename_prefix=filename_prefix):
        if images is None or len(images) == 0:
            gr.Warning("No images selected.")
            return None
        gr.Info("Creating zip file for download...")
        images = [image[0] for image in images]
        if isinstance(images[0], str):
            images = [Image.open(image) for image in images]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        
        zip_filename = f"/tmp/gallery_download/{filename_prefix}_{timestamp}.zip"
        os.makedirs(os.path.dirname(zip_filename), exist_ok=True)
        
        # plots = make_3x5_plot(images)
        
        
        
        with zipfile.ZipFile(zip_filename, 'w') as zipf:
            # Create a temporary directory to store images and plots
            temp_dir = f"/tmp/gallery_download/images/{uuid.uuid4()}"
            os.makedirs(temp_dir)
            
            try:
                # Save images to the temporary directory
                for i, img in enumerate(images):
                    img = img.convert("RGB")
                    img_path = os.path.join(temp_dir, f"single_{i:04d}.jpg")
                    img.save(img_path)
                    zipf.write(img_path, f"single_{i:04d}.jpg")
                
                # # Save plots to the temporary directory
                # for i, plot in enumerate(plots):
                #     plot = plot.convert("RGB")
                #     plot_path = os.path.join(temp_dir, f"grid_{i:04d}.jpg")
                #     plot.save(plot_path)
                #     zipf.write(plot_path, f"grid_{i:04d}.jpg")
            finally:
                # Clean up the temporary directory
                for file in os.listdir(temp_dir):
                    os.remove(os.path.join(temp_dir, file))
                os.rmdir(temp_dir)
        
        # Schedule the deletion of the zip file after 24 hours (86400 seconds)
        delete_file_after_delay(zip_filename, 86400)
        gr.Info(f"File is ready for download: {os.path.basename(zip_filename)}")
        return gr.update(value=zip_filename, interactive=True)
    
    with gr.Row():
        create_file_button = gr.Button("πŸ“¦ Pack", elem_id="create_file_button", variant='secondary')
        download_button = gr.DownloadButton(label="πŸ“₯ Download", value=None, variant='secondary', elem_id="download_button", interactive=False)
        
        create_file_button.click(create_zip_file, inputs=[gallery], outputs=[download_button])
        def warn_on_click(filename):
            if filename is None:
                gr.Warning("No file to download, please `πŸ“¦ Pack` first.")
            interactive = filename is not None
            return gr.update(interactive=interactive)
        download_button.click(warn_on_click, inputs=[download_button], outputs=[download_button])
    
    return create_file_button, download_button