How to Create a Text Generation App with GPT-2 in Python

Creating a text generation app is like giving your Python code a creative mind of its own. With GPT-2 from Hugging Face, we can build an app that doesn’t just handle words but crafts entire thoughts, ideas, and stories based on a few prompts. Whether you’re curious about natural language processing (NLP) or just want to experiment with AI-generated text, this tutorial has you covered.

Today, you’ll learn how to build a text generator in Python using the powerful GPT-2 model, and we’re taking it up a notch by setting it up with a Tkinter-based GUI. From downloading and managing the model to setting parameters and generating text with a click, we’ll dive deep into each step. Let’s get creative with Python!

Getting Started

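Let’s get our tools ready! Run these commands to install everything we need:

$ pip install transformers torch
$ pip install requests
$ pip install tqdm

Tkinter ships with most Python installations, so it does not need a pip install (on Debian/Ubuntu it comes from the python3-tk system package). PyTorch (torch) is required because GPT2LMHeadModel is the PyTorch implementation of the model.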

Here is what each import in our script is for:

  • Tkinter: This will let us build the graphical interface for our app.
  • Transformers: This library loads and runs our GPT-2 model with its tokenizer.
  • os: Since we’ll be handling model downloads, this standard-library module helps manage paths and directories.
  • threading: This standard-library module keeps the app responsive by allowing tasks to run in the background.
  • Requests: Essential for handling the HTTP requests that download our files.
  • tqdm: And finally, to show download progress, tqdm adds a helpful progress bar.

import tkinter as tk
from tkinter import scrolledtext, messagebox
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import os
import threading
import requests
from tqdm import tqdm

With these ready, we’re all set to start building!

Download Helper Function with Progress Bar

Now comes a crucial part: downloading our model. We want to keep an eye on the progress so we know if the files are actually downloading or if we’re just sitting around wasting time. That’s where our download_file() function comes into play!

# Helper function to download a file with a progress bar
def download_file(url, dest):

This function uses requests.get() to fetch the file data from the provided URL. First, it checks if the request was successful. If everything is good, we can move on!

   try:
       response = requests.get(url, stream=True)
       response.raise_for_status()

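Next, we grab the total size of the file from the Content-Length header (falling back to 0 if the server does not send it). This step is important because it lets us set up the progress bar. We also use os.path.basename() to pull the file name out of the destination path so the progress bar can be labeled with it.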

       total_size = int(response.headers.get('content-length', 0))
       with open(dest, 'wb') as file, tqdm(
               desc=f"Downloading {os.path.basename(dest)}",
               total=total_size,
               unit='iB',
               unit_scale=True,
               unit_divisor=1024,
       ) as bar:

Now, instead of downloading the entire file in one go, we do it in smaller chunks. This way, we can update the progress bar in real-time using tqdm. It’s like watching the download fill up right before your eyes!

           for data in response.iter_content(chunk_size=1024):
               size = file.write(data)
               bar.update(size)

And hey, if something goes wrong during the download, we’ve got a safety net. The except block catches any requests exception (connection problems, timeouts, HTTP error statuses) and displays it with messagebox.showerror(), letting us know what went wrong.

   except requests.exceptions.RequestException as e:
       messagebox.showerror("Download Error", f"Error downloading {os.path.basename(dest)}:\n{e}")
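
To see the chunked-write pattern on its own, here is a minimal sketch that needs no network access. It swaps the HTTP response for an in-memory io.BytesIO stream and the tqdm bar for a plain callback. The names copy_in_chunks and on_progress are made up for this illustration and are not part of the app.

```python
import io


def copy_in_chunks(source, dest, chunk_size=1024, on_progress=None):
    """Copy source to dest chunk by chunk and return the number of bytes written."""
    written = 0
    while True:
        chunk = source.read(chunk_size)
        if not chunk:  # an empty read means the stream is exhausted
            break
        size = dest.write(chunk)
        written += size
        if on_progress is not None:
            on_progress(size)  # report the increment, like bar.update(size)
    return written


# Stand-ins for the HTTP response body and the destination file
payload = b"x" * 2500
source = io.BytesIO(payload)
dest = io.BytesIO()
progress = []  # collects the size of each chunk as it is written

total = copy_in_chunks(source, dest, chunk_size=1024, on_progress=progress.append)
print(total, progress)
```

Each callback call receives the size of the chunk that was just written, which is exactly what tqdm expects in bar.update().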

Retrieving Model and Tokenizer Files from Hugging Face

Downloading the model every time you run your code can be a real hassle and waste precious time. To streamline this process, we created the download_model_with_progress() function. This function first checks if the necessary model files already exist on your device.

def download_model_with_progress(model_name="gpt2-large"):
   model_dir = f"./{model_name}"
   model_file = os.path.join(model_dir, "pytorch_model.bin")
   config_file = os.path.join(model_dir, "config.json")
   tokenizer_file = os.path.join(model_dir, "vocab.json")
   merges_file = os.path.join(model_dir, "merges.txt")

It uses os.path.join() to construct the paths for each file. If all the files are found, it simply lets you know with a friendly message.

   if all(os.path.exists(f) for f in [model_file, config_file, tokenizer_file, merges_file]):
       print("Model files already downloaded.")

However, if any files are missing, the function has a handy trick up its sleeve. It defines a dictionary with the URLs where the files can be downloaded from.

   else:
       model_urls = {
           "model": f"https://huggingface.co/{model_name}/resolve/main/pytorch_model.bin",
           "config": f"https://huggingface.co/{model_name}/resolve/main/config.json",
           "tokenizer": f"https://huggingface.co/{model_name}/resolve/main/vocab.json",
           "merges": f"https://huggingface.co/{model_name}/resolve/main/merges.txt"
       }

Next, it ensures the directory for the model files exists, creating it if it doesn’t. This way, there’s always a place to store the downloaded files.

       os.makedirs(model_dir, exist_ok=True)

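Then, the function loops through each URL and calls download_file(), saving every file under the name it has on the server (the last segment of the URL). That way the local names match the ones checked earlier and the ones from_pretrained() looks for. Note that all four files are fetched again if any one of them is missing.

       for url in model_urls.values():
           # Keep the server-side file name (e.g. pytorch_model.bin, merges.txt)
           download_file(url, os.path.join(model_dir, url.rsplit("/", 1)[-1]))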

Once all the files are safely downloaded, the function loads the GPT-2 model and tokenizer from the specified directory, making them ready for action in your code.

   model = GPT2LMHeadModel.from_pretrained(model_dir)
   tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
   return model, tokenizer
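
The caching idea can be tried out without downloading anything. The sketch below is only an illustration under a few assumptions: missing_files() and filename_from_url() are hypothetical helpers (they are not in the app), and a temporary directory stands in for the model folder. It shows how to find out which of the four files are still absent, which a variation of the function could use to download only those.

```python
import os
import tempfile

# The four files the tutorial checks for
REQUIRED_FILES = ["pytorch_model.bin", "config.json", "vocab.json", "merges.txt"]


def missing_files(model_dir, required=REQUIRED_FILES):
    """Return the required file names that are not present in model_dir."""
    return [name for name in required
            if not os.path.exists(os.path.join(model_dir, name))]


def filename_from_url(url):
    """Take the last path segment of a URL as the local file name."""
    return url.rsplit("/", 1)[-1]


with tempfile.TemporaryDirectory() as model_dir:
    # Simulate a partial download: only two of the four files exist
    for name in ("config.json", "vocab.json"):
        with open(os.path.join(model_dir, name), "w") as f:
            f.write("{}")
    still_missing = missing_files(model_dir)

local_name = filename_from_url(
    "https://huggingface.co/gpt2-large/resolve/main/merges.txt"
)
print(still_missing, local_name)
```

An empty list from missing_files() corresponds to the "Model files already downloaded." branch.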

Generating Text from a Prompt

Now we’ve reached the heart of our program—the part that generates responses to the user’s prompt! The generate_text() function takes the prompt and transforms it into tokens that the model can understand. To control the generated text, we set specific parameters:

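  • max_length sets a limit on the total number of tokens in the output, counting the prompt’s tokens as well. A token is often a piece of a word, so this is not a word count.
  • top_k restricts sampling to the k most likely next tokens, while top_p (nucleus sampling) keeps only the smallest set of tokens whose combined probability reaches p.
  • temperature influences the creativity of the response, where lower values lead to more conservative outputs.
  • repetition_penalty helps make the response feel more natural by lowering the score of tokens that have already appeared.
  • Lastly, no_repeat_ngram_size prevents any sequence of that many tokens from appearing more than once; with a value of 2, no pair of consecutive tokens is repeated.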

Taking all these parameters into account, this function utilizes the model to generate tokens from the transformed prompt. After that, those generated tokens are converted back into a readable format for us!

# Function to generate text from a prompt
def generate_text(model, tokenizer, prompt, max_length=200, top_k=50, top_p=0.95, temperature=0.7):
   input_ids = tokenizer.encode(prompt, return_tensors="pt")
   output = model.generate(
       input_ids,
       max_length=max_length,
       do_sample=True,
       top_k=top_k,
       top_p=top_p,
       temperature=temperature,
       repetition_penalty=1.2,
       no_repeat_ngram_size=2
   )
   return tokenizer.decode(output[0], skip_special_tokens=True)
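
To build some intuition for temperature, top_k, and top_p, here is a simplified pure-Python sketch on a toy four-token vocabulary. It is not how Transformers implements these filters internally (the library works on tensors of logits); softmax_with_temperature() and top_k_top_p_filter() are hypothetical helpers written just for this illustration.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Turn raw scores into probabilities; lower temperature sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_top_p_filter(probs, top_k=50, top_p=0.95):
    """Return the indices of the tokens that remain eligible for sampling."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    ranked = ranked[:top_k]  # top-k: keep only the k most likely tokens
    kept, cumulative = [], 0.0
    for i in ranked:  # top-p: stop once the kept probability mass reaches top_p
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return kept


logits = [2.0, 1.0, 0.5, -1.0]  # toy scores, invented for this example
warm = softmax_with_temperature(logits, temperature=1.0)
cool = softmax_with_temperature(logits, temperature=0.5)
candidates = top_k_top_p_filter(warm, top_k=3, top_p=0.8)

print("temperature 1.0:", [round(p, 3) for p in warm])
print("temperature 0.5:", [round(p, 3) for p in cool])
print("tokens still eligible:", candidates)
```

Lowering the temperature pushes more probability onto the top token, and the two filters then shrink the pool of candidates the sampler can pick from.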

Handling the Generate Button Click

# Function to handle button click and update the text box with generated text using a thread
def on_generate_click():
   prompt = prompt_entry.get().strip()
   if prompt:
       result_box.delete(1.0, tk.END)
       result_box.insert(tk.END, "Generating text, please wait...")
       # Start a new thread for text generation
       threading.Thread(target=generate_and_display_text, args=(prompt,)).start()

In the on_generate_click() function, we bring everything together! It starts by grabbing the prompt from the entry box and stripping surrounding whitespace; if the prompt is empty, nothing happens. Otherwise, it clears out any old content in the output box to make room for the new text and adds a message letting the user know that text generation is underway. Finally, it kicks off a new thread to call the generate_and_display_text() function, so the window stays responsive while the model is working.
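
The background-thread idea can be tried without a GUI or a model. In this minimal sketch, slow_task() is a made-up stand-in for the model call and a queue.Queue carries the result back to the main thread. In a Tkinter app, one option (an assumption here, not what the tutorial’s code does) would be to poll such a queue with window.after() instead of calling join().

```python
import queue
import threading
import time


def slow_task(prompt):
    """Stand-in for model.generate(): sleeps briefly, then returns a string."""
    time.sleep(0.1)
    return f"Generated text for: {prompt}"


def worker(prompt, results):
    # Runs in the background thread and hands the result back through the queue
    results.put(slow_task(prompt))


results = queue.Queue()
thread = threading.Thread(
    target=worker, args=("Once upon a time", results), daemon=True
)
thread.start()

# The main thread is free at this point; a GUI would keep handling events.
# Here we simply wait for the worker to finish.
thread.join()
text = results.get_nowait()
print(text)
```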

Generating and Displaying the Text

# Function to generate and display text
def generate_and_display_text(prompt):
   max_length = int(max_length_entry.get())
   top_k = int(top_k_entry.get())
   top_p = float(top_p_entry.get())
   temperature = float(temperature_entry.get())
   generated_text = generate_text(model, tokenizer, prompt, max_length, top_k, top_p, temperature)
   result_box.delete(1.0, tk.END)
   result_box.insert(tk.END, generated_text)

With the prompt in hand, the next step is to generate the text and display it. That responsibility falls to the generate_and_display_text() function. It reads the generation parameters from the entry fields, converts them to numbers with int() and float(), and calls the generate_text() function to produce the response. Once that’s done, it removes the loading message from the output box using result_box.delete() and inserts the newly generated text with result_box.insert().
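
Because the parameters arrive as strings typed by the user, int() and float() raise ValueError on input such as "abc". Here is a hedged sketch of how the conversion could be validated up front. parse_generation_params() is a hypothetical helper, and the accepted ranges are assumptions chosen for this example rather than limits taken from the tutorial.

```python
def parse_generation_params(max_length, top_k, top_p, temperature):
    """Convert the raw entry strings to numbers, raising ValueError on bad input."""
    params = {
        "max_length": int(max_length),
        "top_k": int(top_k),
        "top_p": float(top_p),
        "temperature": float(temperature),
    }
    # Assumed sanity limits for this sketch
    if params["max_length"] <= 0:
        raise ValueError("max_length must be a positive integer")
    if params["top_k"] < 0:
        raise ValueError("top_k must not be negative")
    if not 0.0 < params["top_p"] <= 1.0:
        raise ValueError("top_p must be in the range (0, 1]")
    if params["temperature"] <= 0.0:
        raise ValueError("temperature must be greater than 0")
    return params


# The defaults shown in the app's entry fields
params = parse_generation_params("200", "50", "0.95", "0.7")

# An out-of-range value is reported instead of reaching the model
try:
    parse_generation_params("200", "50", "1.5", "0.7")
    error = None
except ValueError as exc:
    error = str(exc)

print(params, error)
```

In the app, the caught message could be shown with messagebox.showerror() rather than letting the worker thread fail silently.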

Setting Up the Main Window

This is where the layout starts to take shape. Before building the interface, we load the model and tokenizer (downloading them on the first run). Then we create the Tkinter window, set its title, and define its geometry.

# Initialize the model and tokenizer (downloads once if needed)
model, tokenizer = download_model_with_progress("gpt2-large")

# Set up the Tkinter window
window = tk.Tk()
window.title("GPT-2 Text Generator - The Pycodes")
window.geometry("800x600")

Next, we create the label for the prompt input, making it clear to users where to enter their text.

# Label for the prompt input
prompt_label = tk.Label(window, text="Enter a prompt:")
prompt_label.pack(pady=10)

Following that, we add an entry field where users can input their prompt.

# Entry field for the prompt
prompt_entry = tk.Entry(window, width=70)
prompt_entry.pack(pady=5)

Afterwards, we create a frame to hold the parameters that control the generated text, keeping everything organized.

# Parameter settings
param_frame = tk.Frame(window)
param_frame.pack(pady=10)

Next, we set up individual entries for parameters like max length, top K, top P, and temperature, labeling each accordingly to guide the user.

# Max length
tk.Label(param_frame, text="Max Length:").grid(row=0, column=0)
max_length_entry = tk.Entry(param_frame, width=5)
max_length_entry.insert(0, "200")
max_length_entry.grid(row=0, column=1)

# Top K
tk.Label(param_frame, text="Word Selection Range:").grid(row=0, column=2)
top_k_entry = tk.Entry(param_frame, width=5)
top_k_entry.insert(0, "50")
top_k_entry.grid(row=0, column=3)

# Top P
tk.Label(param_frame, text="Flexibility:").grid(row=0, column=4)
top_p_entry = tk.Entry(param_frame, width=5)
top_p_entry.insert(0, "0.95")
top_p_entry.grid(row=0, column=5)

# Temperature
tk.Label(param_frame, text="Creativity Level:").grid(row=0, column=6)
temperature_entry = tk.Entry(param_frame, width=5)
temperature_entry.insert(0, "0.7")
temperature_entry.grid(row=0, column=7)
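
The four parameter fields all follow the same label-plus-entry pattern, so they could also be built in a loop. The sketch below is one possible refactoring, not the tutorial’s code: PARAM_FIELDS, grid_positions(), and build_param_entries() are hypothetical names, and tkinter is imported inside the builder so the column arithmetic can be checked without opening a window.

```python
# (label text, default value) for each parameter field, in display order
PARAM_FIELDS = [
    ("Max Length:", "200"),
    ("Word Selection Range:", "50"),
    ("Flexibility:", "0.95"),
    ("Creativity Level:", "0.7"),
]


def grid_positions(fields):
    """Return a (label_column, entry_column) pair for each field on row 0."""
    return [(2 * i, 2 * i + 1) for i in range(len(fields))]


def build_param_entries(frame, fields=PARAM_FIELDS):
    """Create a Label and an Entry per field inside frame and return the entries."""
    import tkinter as tk  # imported here so the layout math works without a display

    entries = []
    for (label, default), (label_col, entry_col) in zip(fields, grid_positions(fields)):
        tk.Label(frame, text=label).grid(row=0, column=label_col)
        entry = tk.Entry(frame, width=5)
        entry.insert(0, default)
        entry.grid(row=0, column=entry_col)
        entries.append(entry)
    return entries


positions = grid_positions(PARAM_FIELDS)
print(positions)
```

Calling build_param_entries(param_frame) would return the four Entry widgets in the same order as PARAM_FIELDS, which could then be unpacked into max_length_entry, top_k_entry, top_p_entry, and temperature_entry.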

Then, we create the “Generate Text” button, which triggers the on_generate_click() function when clicked.

# Button to trigger text generation
generate_button = tk.Button(window, text="Generate Text", command=on_generate_click)
generate_button.pack(pady=10)

After that, we set up the output box where the generated result will be displayed, ensuring it includes a vertical scrollbar for easy navigation.

# Scrolled text box to display the generated text
result_box = scrolledtext.ScrolledText(window, height=10, wrap=tk.WORD, width=80)
result_box.pack(pady=20)

Finally, we use the mainloop() method to start the main event loop, keeping the window responsive to user interactions.

# Run the Tkinter main loop
window.mainloop()

Full Code

import tkinter as tk
from tkinter import scrolledtext, messagebox
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import os
import threading
import requests
from tqdm import tqdm


# Helper function to download a file with a progress bar
def download_file(url, dest):
   try:
       response = requests.get(url, stream=True)
       response.raise_for_status()
       total_size = int(response.headers.get('content-length', 0))
       with open(dest, 'wb') as file, tqdm(
               desc=f"Downloading {os.path.basename(dest)}",
               total=total_size,
               unit='iB',
               unit_scale=True,
               unit_divisor=1024,
       ) as bar:
           for data in response.iter_content(chunk_size=1024):
               size = file.write(data)
               bar.update(size)
   except requests.exceptions.RequestException as e:
       messagebox.showerror("Download Error", f"Error downloading {os.path.basename(dest)}:\n{e}")


# Download model and tokenizer if files do not exist
def download_model_with_progress(model_name="gpt2-large"):
   model_dir = f"./{model_name}"
   model_file = os.path.join(model_dir, "pytorch_model.bin")
   config_file = os.path.join(model_dir, "config.json")
   tokenizer_file = os.path.join(model_dir, "vocab.json")
   merges_file = os.path.join(model_dir, "merges.txt")


   if all(os.path.exists(f) for f in [model_file, config_file, tokenizer_file, merges_file]):
       print("Model files already downloaded.")
   else:
       model_urls = {
           "model": f"https://huggingface.co/{model_name}/resolve/main/pytorch_model.bin",
           "config": f"https://huggingface.co/{model_name}/resolve/main/config.json",
           "tokenizer": f"https://huggingface.co/{model_name}/resolve/main/vocab.json",
           "merges": f"https://huggingface.co/{model_name}/resolve/main/merges.txt"
       }
       os.makedirs(model_dir, exist_ok=True)
       for url in model_urls.values():
           # Keep the server-side file name (e.g. pytorch_model.bin, merges.txt)
           download_file(url, os.path.join(model_dir, url.rsplit("/", 1)[-1]))


   model = GPT2LMHeadModel.from_pretrained(model_dir)
   tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
   return model, tokenizer


# Function to generate text from a prompt
def generate_text(model, tokenizer, prompt, max_length=200, top_k=50, top_p=0.95, temperature=0.7):
   input_ids = tokenizer.encode(prompt, return_tensors="pt")
   output = model.generate(
       input_ids,
       max_length=max_length,
       do_sample=True,
       top_k=top_k,
       top_p=top_p,
       temperature=temperature,
       repetition_penalty=1.2,
       no_repeat_ngram_size=2
   )
   return tokenizer.decode(output[0], skip_special_tokens=True)


# Function to handle button click and update the text box with generated text using a thread
def on_generate_click():
   prompt = prompt_entry.get().strip()
   if prompt:
       result_box.delete(1.0, tk.END)
       result_box.insert(tk.END, "Generating text, please wait...")
       # Start a new thread for text generation
       threading.Thread(target=generate_and_display_text, args=(prompt,)).start()


# Function to generate and display text
def generate_and_display_text(prompt):
   max_length = int(max_length_entry.get())
   top_k = int(top_k_entry.get())
   top_p = float(top_p_entry.get())
   temperature = float(temperature_entry.get())
   generated_text = generate_text(model, tokenizer, prompt, max_length, top_k, top_p, temperature)
   result_box.delete(1.0, tk.END)
   result_box.insert(tk.END, generated_text)


# Initialize the model and tokenizer (downloads once if needed)
model, tokenizer = download_model_with_progress("gpt2-large")


# Set up the Tkinter window
window = tk.Tk()
window.title("GPT-2 Text Generator - The Pycodes")
window.geometry("800x600")


# Label for the prompt input
prompt_label = tk.Label(window, text="Enter a prompt:")
prompt_label.pack(pady=10)


# Entry field for the prompt
prompt_entry = tk.Entry(window, width=70)
prompt_entry.pack(pady=5)


# Parameter settings
param_frame = tk.Frame(window)
param_frame.pack(pady=10)


# Max length
tk.Label(param_frame, text="Max Length:").grid(row=0, column=0)
max_length_entry = tk.Entry(param_frame, width=5)
max_length_entry.insert(0, "200")
max_length_entry.grid(row=0, column=1)


# Top K
tk.Label(param_frame, text="Word Selection Range:").grid(row=0, column=2)
top_k_entry = tk.Entry(param_frame, width=5)
top_k_entry.insert(0, "50")
top_k_entry.grid(row=0, column=3)


# Top P
tk.Label(param_frame, text="Flexibility:").grid(row=0, column=4)
top_p_entry = tk.Entry(param_frame, width=5)
top_p_entry.insert(0, "0.95")
top_p_entry.grid(row=0, column=5)


# Temperature
tk.Label(param_frame, text="Creativity Level:").grid(row=0, column=6)
temperature_entry = tk.Entry(param_frame, width=5)
temperature_entry.insert(0, "0.7")
temperature_entry.grid(row=0, column=7)


# Button to trigger text generation
generate_button = tk.Button(window, text="Generate Text", command=on_generate_click)
generate_button.pack(pady=10)


# Scrolled text box to display the generated text
result_box = scrolledtext.ScrolledText(window, height=10, wrap=tk.WORD, width=80)
result_box.pack(pady=20)


# Run the Tkinter main loop
window.mainloop()

Happy Coding!
