File size: 2,756 Bytes
61d0253
 
 
 
 
 
 
 
 
 
 
b7ca7fe
 
61d0253
 
 
 
 
 
 
 
 
 
 
 
 
 
b7ca7fe
61d0253
 
 
 
 
 
 
 
 
 
b7ca7fe
 
 
 
61d0253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7ca7fe
 
61d0253
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import streamlit as st
import torch
import tiktoken
from transformer import GPT, GPTConfig  # Ensure you import your model class

# Load the trained model
@st.cache_resource
def load_model():
    """Build a ``GPT`` from the default ``GPTConfig`` and load trained weights.

    Weights are read from ``trained_model_quantized.pt`` with ``map_location``
    set to the CPU, so the app also runs on machines without a GPU. The
    function is wrapped in ``st.cache_resource``, so the returned model is
    reused across calls instead of being reloaded.

    Returns:
        The ``GPT`` model, switched to evaluation mode.

    On a load failure an error is shown and ``st.stop()`` ends the script run.
    Stopping (rather than returning the untrained model) prevents text from
    being generated from random weights. Since no value is returned, there is
    nothing for the cache to store, so the next run retries the load.
    """
    config = GPTConfig()
    model = GPT(config)
    try:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load('trained_model_quantized.pt', map_location=torch.device('cpu'))
        # strict=False tolerates mismatched keys; the returned report is
        # surfaced below instead of being silently discarded.
        load_report = model.load_state_dict(state_dict, strict=False)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.stop()  # halt this run so the random-weight model is never used or cached
    else:
        if load_report.missing_keys or load_report.unexpected_keys:
            # Missing keys keep their initial values; unexpected ones are ignored.
            st.warning(
                f"Checkpoint/model mismatch: {len(load_report.missing_keys)} missing and "
                f"{len(load_report.unexpected_keys)} unexpected keys were skipped."
            )
        model.eval()  # Set the model to evaluation mode
        st.success("Model loaded successfully!")
    return model

# Tokenizer factory
def load_tokenizer():
    """Return tiktoken's ``'gpt2'`` encoding, used to encode prompts and decode samples."""
    gpt2_encoding = tiktoken.get_encoding('gpt2')
    return gpt2_encoding

# Generate text function
def generate_text(model, tokenizer, input_text, length, num_sequences):
    """Sample ``num_sequences`` independent continuations of ``input_text``.

    Every sequence starts from the encoded prompt. Each is extended by
    ``length`` tokens, drawn by multinomial sampling from the softmax of the
    model's last-position logits (no temperature or top-k).

    Args:
        model: Called as ``model(ids)`` with a ``[1, T]`` long tensor. Element
            ``[0]`` of its output is used as the logits and indexed
            ``[:, -1, :]`` (presumably the output is ``(logits, loss)`` with
            logits shaped ``[B, T, vocab]`` — confirm against ``GPT.forward``).
            If the model exposes ``model.config.block_size``, the context fed
            to it is cropped to the last ``block_size`` tokens.
        tokenizer: Object providing ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str``.
        input_text: Prompt text.
        length: Number of new tokens to sample per sequence.
        num_sequences: Number of sequences to sample.

    Returns:
        A list of ``num_sequences`` decoded strings. Each is the prompt followed
        by its own continuation.

    Raises:
        ValueError: If ``input_text`` encodes to no tokens.
    """
    prompt_ids = tokenizer.encode(input_text)
    if not prompt_ids:
        # torch.tensor([]) would be a float [1, 0] tensor the model cannot consume.
        raise ValueError("input_text must encode to at least one token")
    # Explicit long dtype: token ids are used as embedding indices.
    prompt = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0)  # shape: [1, T]
    block_size = getattr(getattr(model, "config", None), "block_size", None)

    generated_sequences = []
    with torch.no_grad():
        for _ in range(num_sequences):
            # Restart from the prompt so each sample is independent of the previous one.
            tokens = prompt
            for _ in range(length):
                context = tokens if block_size is None else tokens[:, -block_size:]
                logits = model(context)[0]
                next_token_probs = torch.softmax(logits[:, -1, :], dim=-1)
                next_token = torch.multinomial(next_token_probs, num_samples=1)  # shape: [1, 1]
                # torch.cat returns a new tensor, so `prompt` is never modified.
                tokens = torch.cat((tokens, next_token.view(1, -1)), dim=1)
            generated_sequences.append(tokenizer.decode(tokens[0].tolist()))

    return generated_sequences

# Streamlit app layout
st.title("GPT Text Generator")
st.write("Enter your text and specify the length of additional text to generate.")

input_text = st.text_area("Input Text", "Once upon a time", max_chars=512)  # Limit to 512 characters
length = st.slider("Predict Additional Text of Length", 1, 50, 10)
num_sequences = st.slider("Number of Sequences to Generate", 1, 5, 1)

if st.button("Generate"):
    if not input_text:
        # An empty prompt gives the model no tokens to condition on.
        st.warning("Please enter some input text before generating.")
    else:
        model = load_model()  # cached by st.cache_resource after the first load
        tokenizer = load_tokenizer()  # Load the tokenizer
        with st.spinner("Generating text..."):
            generated_texts = generate_text(model, tokenizer, input_text, length, num_sequences)
        st.write("Text generation complete.")

        st.write("Generated Texts:")
        for i, text in enumerate(generated_texts, start=1):
            st.subheader(f"Sequence {i}")
            # st.text shows the sample verbatim; st.write would parse the model
            # output as Markdown, mangling '#', '*', '_' and single newlines.
            st.text(text)