fffiloni committed
Commit 40a70b3 · verified · 1 Parent(s): 5a49966

integrate llama locally

Files changed (1)
  1. app.py +28 -60
app.py CHANGED
@@ -3,10 +3,36 @@ import re
 import os
 hf_token = os.environ.get('HF_TOKEN')
 from gradio_client import Client
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+model_path = "meta-llama/Llama-2-7b-chat-hf"
+
+tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, use_auth_token=hf_token)
+model = AutoModelForCausalLM.from_pretrained(model_path, use_auth_token=hf_token).half().cuda()
+
 client = Client("https://fffiloni-test-llama-api-debug.hf.space/", hf_token=hf_token)
 
 clipi_client = Client("https://fffiloni-clip-interrogator-2.hf.space/")
 
+def llama_gen_story(prompt):
+
+    instruction = """[INST] <<SYS>>\nYou are a storyteller. You'll be given an image description and some keyword about the image.
+    For that given you'll be asked to generate a story that you think could fit very well with the image provided.
+    Always answer with a cool story, while being safe as possible. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+    If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n{} [/INST]"""
+
+
+    prompt = instruction.format(prompt)
+
+    generate_ids = model.generate(tokenizer(prompt, return_tensors='pt').input_ids.cuda(), max_new_tokens=4096)
+    output_text = tokenizer.decode(generate_ids[0], skip_special_tokens=True)
+    #print(generate_ids)
+    #print(output_text)
+    pattern = r'\[INST\].*?\[/INST\]'
+    cleaned_text = re.sub(pattern, '', output_text, flags=re.DOTALL)
+    return cleaned_text
+
 def get_text_after_colon(input_text):
     # Find the first occurrence of ":"
     colon_index = input_text.find(":")
@@ -38,11 +64,7 @@ def infer(image_input, audience):
 
     """
     gr.Info('Calling Llama2 ...')
-    result = client.predict(
-        llama_q,    # str in 'Message' Textbox component
-        "I2S",
-        api_name="/predict"
-    )
+    result = llama_gen_story(llama_q)
 
     print(f"Llama2 result: {result}")
 
@@ -59,61 +81,7 @@
 
 css="""
 #col-container {max-width: 910px; margin-left: auto; margin-right: auto;}
-a {text-decoration-line: underline; font-weight: 600;}
-a {text-decoration-line: underline; font-weight: 600;}
-.animate-spin {
-    animation: spin 1s linear infinite;
-}
-@keyframes spin {
-    from {
-        transform: rotate(0deg);
-    }
-    to {
-        transform: rotate(360deg);
-    }
-}
-#share-btn-container {
-    display: flex;
-    padding-left: 0.5rem !important;
-    padding-right: 0.5rem !important;
-    background-color: #000000;
-    justify-content: center;
-    align-items: center;
-    border-radius: 9999px !important;
-    max-width: 15rem;
-}
-div#share-btn-container > div {
-    flex-direction: row;
-    background: black;
-    align-items: center;
-}
-#share-btn-container:hover {
-    background-color: #060606;
-}
-#share-btn {
-    all: initial;
-    color: #ffffff;
-    font-weight: 600;
-    cursor:pointer;
-    font-family: 'IBM Plex Sans', sans-serif;
-    margin-left: 0.5rem !important;
-    padding-top: 0.5rem !important;
-    padding-bottom: 0.5rem !important;
-    right:0;
-}
-#share-btn * {
-    all: unset;
-}
-#share-btn-container div:nth-child(-n+2){
-    width: auto !important;
-    min-height: 0px !important;
-}
-#share-btn-container .wrap {
-    display: none !important;
-}
-#share-btn-container.hidden {
-    display: none!important;
-}
+
 
 div#story textarea {
     font-size: 1.5em;
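
A note on the new local path: model.generate on a causal LM returns the prompt tokens followed by the completion, which is why llama_gen_story regex-strips the echoed [INST] ... [/INST] block after decoding. A minimal alternative sketch, not part of this commit and assuming model and tokenizer are loaded as in the diff above, slices the output at the prompt length instead of pattern-matching:

def llama_gen_story_sliced(prompt):
    # Tokenize and move to the GPU, matching the .half().cuda() model placement
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.cuda()
    generate_ids = model.generate(input_ids, max_new_tokens=4096)
    # generate() returns prompt tokens + new tokens; keep only the new ones
    new_tokens = generate_ids[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

Either approach needs access to the gated meta-llama/Llama-2-7b-chat-hf checkpoint (hence the HF_TOKEN) and a GPU: in fp16 the 7B weights alone take roughly 14 GB of VRAM, which is what loading with .half().cuda() implies.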