Raiff1982 commited on
Commit
e457f10
·
verified ·
1 Parent(s): 24c0f30

Create tool_call.py

Browse files
Files changed (1) hide show
  1. tool_call.py +30 -0
tool_call.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gradio_tool import GradioTool
2
+ import os
3
+
4
+ class StableDiffusionTool(GradioTool):
5
+ """Tool for calling stable diffusion from llm"""
6
+
7
+ def __init__(
8
+ self,
9
+ name="StableDiffusion",
10
+ description=(
11
+ "An image generator. Use this to generate images based on "
12
+ "text input. Input should be a description of what the image should "
13
+ "look like. The output will be a path to an image file."
14
+ ),
15
+ src="gradio-client-demos/stable-diffusion",
16
+ hf_token=None,
17
+ ) -> None:
18
+ super().__init__(name, description, src, hf_token)
19
+
20
+ def create_job(self, query: str) -> Job:
21
+ return self.client.submit(query, "", 9, fn_index=1)
22
+
23
+ def postprocess(self, output: str) -> str:
24
+ return [os.path.join(output, i) for i in os.listdir(output) if not i.endswith("json")][0]
25
+
26
+ def _block_input(self, gr) -> "gr.components.Component":
27
+ return gr.Textbox()
28
+
29
+ def _block_output(self, gr) -> "gr.components.Component":
30
+ return gr.Image()