Traffic Sign Recognition System Using CNN and TensorFlow in Python

Have you ever wondered how self-driving cars can recognize traffic signs on the road? It’s all thanks to powerful machine learning models that can quickly process visual data and make real-time decisions. In today’s world, such systems are crucial for making transportation safer and smarter.

In this tutorial, you’ll learn how to build a Traffic Sign Recognition System from scratch. Using Convolutional Neural Networks (CNN) and TensorFlow, we’ll guide you through creating a model that can accurately identify different traffic signs. By the time you’re done, you’ll have a fully functioning system that predicts signs with impressive precision.

Getting Started

Before we start, organize your project so that your script sits next to the training data: the training directory should contain one subfolder per class, named with the class ID (0 through 42), and each subfolder holds that class’s images. This matches how the GTSRB training set is organized and is exactly what our load_data() function expects. After training, the model will be saved next to the script as traffic_sign_model.keras.

To make sure everything runs smoothly, you’ll need to install a few essential libraries. Note that Tkinter ships with most Python installations, so there’s no need to install it with pip (on Debian/Ubuntu-based Linux you may have to add it with sudo apt install python3-tk). Everything else can be installed by running the following commands in your terminal:

$ pip install tensorflow
$ pip install Pillow
$ pip install numpy
$ pip install matplotlib
$ pip install opencv-python
$ pip install scikit-learn

Let’s get started by gathering the core ingredients we’ll need for our Traffic Sign Recognition System. Here’s a quick rundown of our essential libraries and their roles:

  • Tkinter: This is our trusty tool for creating the graphical interface. It will serve as the control panel for our application.
import tkinter as tk
from tkinter import ttk
from tkinter import filedialog, messagebox, scrolledtext
from PIL import Image, ImageTk
  • TensorFlow/Keras: Our go-to for training the neural network. TensorFlow and Keras will handle the heavy lifting of image classification.
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
  • NumPy: We’ll use this for manipulating and processing our image data.
import numpy as np
  • OpenCV (cv2): This library will help us with image preprocessing and resizing.
import cv2
  • For managing file paths and directories, os comes in handy:
import os
  • We’ll also use train_test_split from sklearn.model_selection to split our dataset into training and testing sets. This step is crucial for evaluating how well our model performs on new, unseen data.
from sklearn.model_selection import train_test_split
  • Threading: To keep our interface responsive while the model trains, we’ll use threading.
import threading

Now that we have all our tools lined up, let’s dive into building our Traffic Sign Recognition System. The first thing we’ll do is:

Set Up Parameters and Define Class Names for Traffic Sign Recognition

Keep TensorFlow Quiet

If you’ve ever worked with TensorFlow, you know it has a knack for throwing a lot of warnings your way, which can be a bit overwhelming. To keep things calm and focused, we’ll turn off those noisy warnings and only keep the important ones.

# Suppress TensorFlow warnings
import logging
logging.getLogger('tensorflow').setLevel(logging.ERROR)
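
If TensorFlow’s low-level C++ messages still clutter your terminal, you can optionally silence those as well by setting the TF_CPP_MIN_LOG_LEVEL environment variable before importing TensorFlow. Here’s a small optional sketch (not part of the main script):

# Optional: quiet TensorFlow's C++ backend logs as well.
# This must run before `import tensorflow` to take effect.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # '2' hides INFO and WARNING messages

import tensorflow as tf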

Defining Image Size and Paths

Before we get into the fun stuff, we need to set up our basics. This means deciding on the image’s dimensions, setting the number of traffic sign classes, and choosing where to save our trained model. Think of it as getting everything ready before we start cooking.

# Define image dimensions and number of classes
IMG_HEIGHT, IMG_WIDTH = 30, 30
NUM_CLASSES = 43

# Define model path
MODEL_PATH = 'traffic_sign_model.keras'

Traffic Sign Classes

Let’s introduce our VIP list of traffic signs! We’ve got 43 different classes, and this list helps us decode what our model predicts. Since the model gives us numbers from 0 to 42, this list translates those numbers into the actual traffic signs we’ll recognize.

# Class names for GTSRB (German Traffic Sign Recognition Benchmark)
# Ensure that these names match your dataset's classes. Update if necessary.
class_names = [
  'Speed Limit 20', 'Speed Limit 30', 'Speed Limit 50', 'Speed Limit 60',
  'Speed Limit 70', 'Speed Limit 80', 'End of Speed Limit 80', 'Speed Limit 100',
  'Speed Limit 120', 'No Passing', 'No Passing for Vehicles over 3.5 Tons',
  'Right-of-Way at Intersection', 'Priority Road', 'Yield', 'Stop',
  'No Vehicles', 'Vehicles over 3.5 Tons Prohibited', 'No Entry',
  'General Caution', 'Dangerous Curve Left', 'Dangerous Curve Right',
  'Double Curve', 'Bumpy Road', 'Slippery Road', 'Road Narrows on the Right',
  'Road Work', 'Traffic Signals', 'Pedestrians', 'Children Crossing',
  'Bicycles Crossing', 'Beware of Ice/Snow', 'Wild Animals Crossing',
  'End of All Speed and Passing Limits', 'Turn Right Ahead', 'Turn Left Ahead',
  'Ahead Only', 'Go Straight or Right', 'Go Straight or Left',
  'Keep Right', 'Keep Left', 'Roundabout Mandatory', 'End of No Passing',
  'End of No Passing by Vehicles over 3.5 Tons'
]
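
Since a mismatch here would silently mislabel every prediction, an optional one-line sanity check right after the list is cheap insurance:

# Optional sanity check: the list must line up with NUM_CLASSES.
assert len(class_names) == NUM_CLASSES, f"Expected {NUM_CLASSES} class names, got {len(class_names)}"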

Loading and Preprocessing the Data

Loading Data

It’s time to load our data using the load_data() function:

First, we set up two empty lists: images to store our images and labels to keep track of their associated class IDs. The function then loops through each folder (one per class) in the training directory, reads every image, resizes it to 30×30 with cv2, converts it from OpenCV’s BGR color order to RGB, and appends it to the images list with its class ID in the labels list. If it encounters a corrupted image or a file it can’t read, it skips that file and keeps going.

By the end, you’ll have two NumPy arrays—one for the images and one for their corresponding classes—ready to be used in the rest of the code.

def load_data(data_dir):
  """
  Load images and labels from the specified directory.


  Args:
      data_dir (str): Path to the training data directory.


  Returns:
      Tuple of NumPy arrays: (images, labels)
  """
  images, labels = [], []
  for class_dir in os.listdir(data_dir):
      class_path = os.path.join(data_dir, class_dir)
      if not os.path.isdir(class_path):
          continue
      try:
          label = int(class_dir)
      except ValueError:
          print(f"Skipping non-integer directory: {class_dir}")
          continue
      for img_name in os.listdir(class_path):
          # Filter out non-image files
          if not img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.ppm')):
              continue  # Skip non-image files like CSVs
          img_path = os.path.join(class_path, img_name)
          img = cv2.imread(img_path)
          if img is None:
              print(f"Warning: Unable to read image {img_path}. Skipping.")
              continue
          img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))
          img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
          images.append(img)
          labels.append(label)
  return np.array(images), np.array(labels)
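
If you want to try this function on its own before wiring up the GUI, a quick standalone check looks like this (assuming your GTSRB training folder is called Train and sits next to the script):

# Quick standalone test of load_data(); 'Train' is an assumed folder name.
X, y = load_data('Train')
print(X.shape)  # (number_of_images, 30, 30, 3)
print(y.shape)  # (number_of_images,)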

Preprocessing Data

What comes after loading the data? The next step is to tweak it a bit through normalization and one-hot encoding. Here’s what each of these processes involves:

  • Normalization adjusts the data to a consistent range, which helps the model learn more efficiently. To achieve this, we divide the pixel values by 255, scaling them between 0 and 1.
  • One-hot encoding transforms our labels (class numbers) into a format that the model can understand. Each label is converted into a vector of length 43, where only one position is marked as 1 (indicating the class index), and all other positions are set to 0.
def preprocess_data(X, y):
  """
  Normalize image data and convert labels to categorical format.


  Args:
      X (np.array): Array of images.
      y (np.array): Array of labels.


  Returns:
      Tuple of NumPy arrays: (X_processed, y_processed)
  """
  X = X.astype('float32') / 255.0
  y = to_categorical(y, NUM_CLASSES)
  return X, y
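
To make the one-hot idea concrete, here’s a tiny toy example (5 classes instead of 43) of what to_categorical produces:

# Toy example: three labels out of 5 classes become one-hot rows.
import numpy as np
from tensorflow.keras.utils import to_categorical

print(to_categorical(np.array([0, 2, 4]), 5))
# [[1. 0. 0. 0. 0.]
#  [0. 0. 1. 0. 0.]
#  [0. 0. 0. 0. 1.]]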

Building and Saving the CNN Model

Designing the CNN Model

Here’s where the magic really happens! We’re about to build a model that can classify images into one of 43 traffic sign categories once it’s trained. So, let’s dive in:

  • First, we use Conv2D layers to scan the image for key features like lines, curves, and textures. It’s like teaching the model to spot important details.
  • Next, we add MaxPooling layers to help the model zoom out and simplify the image, making it easier to handle.
  • Then, we flatten the 2D data into a 1D vector so that the Dense layers can get to work, learning patterns to classify the images. The softmax activation function helps the model express how confident it is in its predictions.
  • Finally, we compile the model, and it’s all set to be trained and used!
def build_model(input_shape, num_classes):
  """
  Build and return a CNN model for traffic sign classification.


  Args:
      input_shape (tuple): Shape of the input images (height, width, channels).
      num_classes (int): Number of classes.


  Returns:
      keras.Model: Compiled CNN model.
  """
  inputs = layers.Input(shape=input_shape)
  x = layers.Conv2D(32, (3, 3), activation='relu')(inputs)
  x = layers.BatchNormalization()(x)
  x = layers.MaxPooling2D((2, 2))(x)


  x = layers.Conv2D(64, (3, 3), activation='relu')(x)
  x = layers.BatchNormalization()(x)
  x = layers.MaxPooling2D((2, 2))(x)


  x = layers.Flatten()(x)
  x = layers.Dense(128, activation='relu')(x)
  x = layers.BatchNormalization()(x)
  outputs = layers.Dense(num_classes, activation='softmax')(x)


  model = models.Model(inputs=inputs, outputs=outputs)
  return model
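
If you’d like to eyeball the layer shapes before plugging the model into the GUI, you can build it once and print a summary; this is optional, and the app logs the same summary during training:

# Optional: inspect the architecture outside the GUI.
model = build_model((IMG_HEIGHT, IMG_WIDTH, 3), NUM_CLASSES)
model.summary()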

Having designed and built our CNN model, the next step is to make sure our hard work pays off. Let’s save this model so we can easily use it again in the future without the need for retraining.

Storing the Trained Model

Training the model takes time and effort, so we don’t want to do it every single time we use it. Instead, we save the model after training using the save_model() function. This function saves our trained model to a file in .keras format, so we can easily reload it later for making predictions. This way, we can use our model as often as we need without retraining it from scratch!

def save_model(model, path=MODEL_PATH):
  """
  Save the trained model to the specified path.


  Args:
      model (keras.Model): Trained Keras model.
      path (str): File path to save the model.
  """
  model.save(path)

Integrating Model Training and Traffic Sign Prediction with a GUI

Connecting Training with the GUI

When you’re training your model, it’s like working out at the gym — you want to see the progress you’re making. That’s the beauty of the TrainingCallback class, which updates the GUI to display real-time progress during the model’s training. How does it do that? Let’s break it down:

First things first, we subclass tf.keras.callbacks.Callback from TensorFlow so it can communicate with the GUI during training. The class is initialized with the __init__ method, where we store a reference to the app as self.app.

class TrainingCallback(tf.keras.callbacks.Callback):
    """
    Custom callback to update the GUI with training progress and logs.
    """

    def __init__(self, app):
        super().__init__()
        self.app = app

End of Epoch Logging

TensorFlow calls this method automatically at the end of every epoch, and we use it to update the training logs and the progress bar. Here’s how it works:

  • We start by making sure logs is a dictionary, falling back to an empty one if TensorFlow passes None.
  • Then, we compute the current progress as the percentage of total epochs completed so far.
  • Afterward, we format the log messages to show the current epoch number, loss, accuracy, validation loss, and validation accuracy. We use get() to avoid any crashes if some of these metrics are missing.
  • Finally, the calculated progress is sent to the GUI to update the progress bar, and the formatted logs are sent to the log_message() function, which displays them on the text box for the user.
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        progress = ((epoch + 1) / self.params['epochs']) * 100
        message = (f"Epoch {epoch + 1}/{self.params['epochs']} - "
                   f"Loss: {logs.get('loss', 0):.4f} - "
                   f"Accuracy: {logs.get('accuracy', 0):.4f} - "
                   f"Val_Loss: {logs.get('val_loss', 0):.4f} - "
                   f"Val_Accuracy: {logs.get('val_accuracy', 0):.4f}\n")
        self.app.update_progress(progress)
        self.app.log_message(message)

End of Training

When the training process finishes, the on_train_end() function is automatically called. This function informs the user that the training is complete by showing a log message and a message box. It also enables the Train buttons, which were disabled during the training process.

    def on_train_end(self, logs=None):
        self.app.log_message("Training completed.\n")
        self.app.enable_train_buttons()
        messagebox.showinfo("Training Complete", f"Model trained and saved as {MODEL_PATH}.")

TrafficSignApp Class

This is the heart of our program, the TrafficSignApp class, which manages both the graphical interface and the internal logic for traffic sign recognition. The __init__ method sets the stage by storing the main window, setting its title and geometry, and initializing self.model to None; it then calls the create_widgets() method to build the GUI components.

class TrafficSignApp:
    def __init__(self, master):
        self.master = master
        master.title("Traffic Sign Detection and Training - The Pycodes")
        master.geometry("800x800")  # Increased height to accommodate scrollbars

        self.model = None

        # Initialize UI components
        self.create_widgets()

Creating the Widgets

Now let’s bring the GUI to life. We divide the main window into two sections: Model Training and Traffic Sign Prediction.

Model Training Section

This section includes a label to display the Training Directory once it’s selected, and two buttons:

  • The first button lets you select the training directory and calls the select_train_directory() function.
  • The second button starts the training by calling the start_training() function. It’s disabled at first and is only enabled when a training directory is selected.

This section also includes a progress bar and a text box where the training logs are displayed.

    def create_widgets(self):
        # Frame for Training
        training_frame = tk.LabelFrame(self.master, text="Model Training", padx=10, pady=10)
        training_frame.pack(fill="both", expand="yes", padx=10, pady=10)

        # Training Directory Selection
        self.train_dir_label = tk.Label(training_frame, text="Training Directory: Not Selected", wraplength=600,
                                        justify="left")
        self.train_dir_label.pack(anchor='w')

        self.select_train_dir_button = tk.Button(training_frame, text="Select Training Directory",
                                                 command=self.select_train_directory)
        self.select_train_dir_button.pack(pady=5)

        # Start Training Button
        self.train_button = tk.Button(training_frame, text="Start Training", command=self.start_training,
                                      state='disabled', bg='green', fg='white')
        self.train_button.pack(pady=5)

        # Progress Bar
        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(training_frame, variable=self.progress_var, maximum=100)
        self.progress_bar.pack(fill='x', pady=5)

        # Training Logs
        self.log_text = scrolledtext.ScrolledText(training_frame, height=10, state='disabled')
        self.log_text.pack(fill='both', expand=True, pady=5)

Traffic Sign Prediction Section

This section is all about making predictions. We place a scrollable canvas inside the prediction frame for smooth navigation, and inside it, we add:

  • An Upload Image for Prediction button, which calls the upload_image() function.
  • A frame that holds the canvas where the image will be displayed, complete with horizontal and vertical scrollbars.
  • A Prediction Label to display the predicted traffic sign and a Confidence Label to show how confident the model is about its prediction.
        # Frame for Prediction
        prediction_frame = tk.LabelFrame(self.master, text="Traffic Sign Prediction", padx=10, pady=10)
        prediction_frame.pack(fill="both", expand="yes", padx=10, pady=10)

        # Create a Canvas inside prediction_frame
        canvas = tk.Canvas(prediction_frame, borderwidth=0, background="#f0f0f0")
        canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        # Add vertical and horizontal scrollbars to the Canvas
        v_scrollbar = tk.Scrollbar(prediction_frame, orient=tk.VERTICAL, command=canvas.yview)
        v_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)

        h_scrollbar = tk.Scrollbar(self.master, orient=tk.HORIZONTAL, command=canvas.xview)
        h_scrollbar.pack(fill=tk.X)

        canvas.configure(yscrollcommand=v_scrollbar.set, xscrollcommand=h_scrollbar.set)

        # Create an inner frame to hold prediction widgets
        self.inner_prediction_frame = tk.Frame(canvas, background="#f0f0f0")
        canvas.create_window((0, 0), window=self.inner_prediction_frame, anchor='nw')

        # Bind the inner frame to configure the scroll region
        self.inner_prediction_frame.bind("<Configure>", lambda event, canvas=canvas: self.on_frame_configure(canvas))

        # Upload Image Button
        self.upload_button = tk.Button(self.inner_prediction_frame, text="Upload Image for Prediction", command=self.upload_image,
                                       bg='blue', fg='white', font=('Helvetica', 12, 'bold'))
        self.upload_button.pack(pady=10)

        # Create a frame to hold the canvas and scrollbars for the image
        self.image_frame = tk.Frame(self.inner_prediction_frame)
        self.image_frame.pack(pady=5, fill=tk.BOTH, expand=True)

        # Create Canvas for image display
        self.image_canvas = tk.Canvas(self.image_frame, width=400, height=300, bg='gray')
        self.image_canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        # Add scrollbars to the image_canvas
        img_v_scrollbar = tk.Scrollbar(self.image_frame, orient=tk.VERTICAL, command=self.image_canvas.yview)
        img_v_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
        img_h_scrollbar = tk.Scrollbar(self.inner_prediction_frame, orient=tk.HORIZONTAL, command=self.image_canvas.xview)
        img_h_scrollbar.pack(fill=tk.X)

        self.image_canvas.configure(yscrollcommand=img_v_scrollbar.set, xscrollcommand=img_h_scrollbar.set)

        # Prediction Label
        self.prediction_label = tk.Label(self.inner_prediction_frame, text="Prediction: None", font=('Helvetica', 14))
        self.prediction_label.pack(pady=5)

        # Confidence Label
        self.confidence_label = tk.Label(self.inner_prediction_frame, text="Confidence: 0.00%", font=('Helvetica', 14))
        self.confidence_label.pack(pady=5)

Managing Scrolling and Training Directory Selection

Now that we’ve explored the integration of model training and prediction, it’s time to take a closer look at how key tasks like scrolling, selecting the training directory, and initiating the model training process are managed.

Handling Scrolling

We start with the on_frame_configure() function. It might seem like a small piece of the puzzle, but it’s crucial. Whenever the inner prediction frame changes size, this function resets the canvas’s scroll region to match, so scrolling stays smooth and all of the content remains reachable.

  def on_frame_configure(self, canvas):
      """
      Reset the scroll region to encompass the inner frame
      """
      canvas.configure(scrollregion=canvas.bbox("all"))

Selecting a Training Directory

Now, when it comes to selecting a directory for training, the select_train_directory() function has your back. It opens up a file dialog, letting you pick a directory. Once that’s done, it checks if you’ve made a selection. If yes, it updates the label to show the directory path and enables the “Start Training” button. If not, it keeps the Start Training button disabled and reminds you that no directory was chosen.

  def select_train_directory(self):
      directory = filedialog.askdirectory()
      if directory:
          self.train_dir = directory
          self.train_dir_label.config(text=f"Training Directory: {directory}")
          self.train_button.config(state='normal')
      else:
          self.train_dir_label.config(text="Training Directory: Not Selected")
          self.train_button.config(state='disabled')

Starting the Training

Once you’ve selected a directory, it’s time to kick off the training. When you hit that Start Training button, the start_training() function springs into action:

First, it disables both the Select Training Directory and Start Training buttons to make sure you don’t accidentally start another training session. Then, it logs the start of the training, sets the progress bar to zero, and launches a new thread that calls the train_model() function to get the actual training underway.

  def start_training(self):
      # Disable buttons to prevent multiple training sessions
      self.train_button.config(state='disabled')
      self.select_train_dir_button.config(state='disabled')
      self.log_text.config(state='normal')
      self.log_text.insert(tk.END, "Starting training...\n")
      self.log_text.config(state='disabled')
      self.progress_var.set(0)


      # Start training in a separate thread
      training_thread = threading.Thread(target=self.train_model)
      training_thread.start()

Training the Model

Alright, let’s get into the nitty-gritty of training our model with the train_model() method. Note that all the snippets in this subsection run inside that method’s try/except block (you can see the complete structure in the Full Code section below). First up, we kick things off by loading the data: the load_data() function grabs the images from the selected directory, and we check whether anything was loaded. If the directory turns out to be empty, we let the user know and re-enable the training buttons. But if we’ve got data, we log how many images were successfully loaded.

self.log_message("Loading training data...\n")
X, y = load_data(self.train_dir)
if X.size == 0:
    self.log_message("No data found. Please check the training directory.\n")
    self.enable_train_buttons()
    return
self.log_message(f"Loaded {X.shape[0]} images.\n")

Next, we dive into preprocessing the data with the preprocess_data() function. This step is crucial for getting our data ready for training.

self.log_message("Preprocessing data...\n")
X, y = preprocess_data(X, y)

Then, we split our data into training and validation sets using train_test_split(). This ensures we have a separate set of data to evaluate the model’s performance. We log the number of samples in each set for transparency.

self.log_message("Splitting data into training and validation sets...\n")
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42
)
self.log_message(f"Training samples: {X_train.shape[0]}\n")
self.log_message(f"Validation samples: {X_val.shape[0]}\n")

Now, it’s time to enhance our data with some augmentation. This step rotates, zooms, and shifts our images to make our model more robust.

self.log_message("Applying data augmentation...\n")
datagen = ImageDataGenerator(
    rotation_range=10,
    zoom_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=False
)
datagen.fit(X_train)

With the data ready, we build our CNN model using build_model(). We log its architecture so you can see exactly what’s going on under the hood.

self.log_message("Building the CNN model...\n")
self.model = build_model(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), num_classes=NUM_CLASSES)
self.log_message("Model architecture:\n")
model_summary = []
self.model.summary(print_fn=lambda x: model_summary.append(x))
self.log_message("\n".join(model_summary) + "\n")

Next, we compile the model with the Adam optimizer and categorical cross-entropy loss function. This setup prepares the model for training.

self.log_message("Compiling the model...\n")
self.model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

To keep training on track and avoid overfitting, we use callbacks like TrainingCallback and EarlyStopping. These help us monitor the training process and stop if the model’s performance plateaus.

training_callback = TrainingCallback(self)
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

We then start the training process. The model trains for up to 20 epochs on the augmented data (early stopping may cut this short), and we evaluate it on the validation set after every epoch. Training progress is logged through our callback, and once training finishes, we save the model so we don’t have to retrain it in the future.

self.log_message("Starting training...\n")
self.model.fit(
    datagen.flow(X_train, y_train, batch_size=32),
    epochs=20,
    validation_data=(X_val, y_val),
    callbacks=[training_callback, early_stopping],
    verbose=0  # Suppress Keras' own output
)

Lastly, the except block catches any errors that occur during training, logs them, displays an error dialog, and re-enables the training buttons. On a successful run, the buttons are re-enabled by the callback’s on_train_end() method, signaling the end of the training session.

save_model(self.model)
self.log_message(f"Model trained and saved as {MODEL_PATH}.\n")

except Exception as e:
    self.log_message(f"Error during training: {e}\n")
    messagebox.showerror("Training Error", str(e))
    self.enable_train_buttons()

Updating the GUI with Log Messages and Progress

Logging Messages

When it comes to logging messages to the GUI, we don’t want to freeze up the main window. So, we use a couple of functions to handle this smoothly:

First up is log_message(), which schedules an update for the main thread using after(). But it doesn’t handle the actual text insertion directly. Instead, it hands off the job to _log_message(). This function temporarily unlocks the text box, adds the new message, and makes sure the text box scrolls down to show the latest log. Once it’s done, it locks the text box back up.

def log_message(self, message):
    # Schedule GUI update in the main thread
    self.master.after(0, self._log_message, message)

def _log_message(self, message):
    self.log_text.config(state='normal')
    self.log_text.insert(tk.END, message)
    self.log_text.see(tk.END)
    self.log_text.config(state='disabled')

Updating Progress

Updating the progress bar is a bit like logging messages. We use two functions here as well. update_progress() schedules a call to _update_progress() in the main thread. Then, _update_progress() actually updates the progress bar so you can see how the training is going in real-time.

def update_progress(self, progress):
    # Schedule GUI update in the main thread
    self.master.after(0, self._update_progress, progress)

def _update_progress(self, progress):
    self.progress_var.set(progress)
    self.master.update_idletasks()

Enabling Training Buttons

Enabling the training buttons follows a similar pattern. The enable_train_buttons() function schedules a call to _enable_train_buttons() in the main thread. The _enable_train_buttons() function then makes the buttons clickable again by setting their state to normal.

def enable_train_buttons(self):
    # Schedule GUI update in the main thread
    self.master.after(0, self._enable_train_buttons)

def _enable_train_buttons(self):
    self.train_button.config(state='normal')
    self.select_train_dir_button.config(state='normal')

Image Upload, Prediction, and Display

Uploading Images

When you upload an image, the upload_image() function opens a file dialog for you to select an image. If you choose one, it opens the file with Pillow and keeps a copy for future use. Instead of shrinking the image to fit, it draws it at full size on the scrollable canvas so you can pan around large images. If there’s an issue opening the image, an error message will pop up.

Next, it checks if the model is loaded. If not, it tries to load the model from a saved path and gives feedback on whether this was successful. If everything is in place, it calls predict_image() to classify the image, and then updates the UI with the prediction and confidence.

def upload_image(self):
    file_path = filedialog.askopenfilename(
        filetypes=[
            ("Image Files", "*.png;*.jpg;*.jpeg;*.bmp;*.ppm"),
            ("All Files", "*.*")
        ]
    )
    if not file_path:
        return  # User cancelled

    try:
        img = Image.open(file_path)
        self.original_img = img.copy()  # Keep a copy for potential future use

        img_tk = ImageTk.PhotoImage(img)
        self.image_canvas.delete("all")
        self.image_canvas.create_image(0, 0, anchor='nw', image=img_tk)
        self.image_canvas.image = img_tk  # Keep a reference

        self.image_canvas.config(scrollregion=self.image_canvas.bbox(tk.ALL))
    except Exception as e:
        messagebox.showerror("Error", f"Unable to open image.\n{e}")
        return

    if not self.model:
        if os.path.exists(MODEL_PATH):
            try:
                self.log_message("Loading trained model for prediction...\n")
                self.model = tf.keras.models.load_model(MODEL_PATH)
                self.log_message("Model loaded successfully.\n")
            except Exception as e:
                messagebox.showerror("Error", f"Failed to load model.\n{e}")
                return
        else:
            messagebox.showerror("Model Not Found",
                                 f"Model file '{MODEL_PATH}' not found. Please train the model first.")
            return

    class_id, confidence = self.predict_image(file_path)
    if class_id is not None:
        if class_id < len(class_names):
            predicted_class = class_names[class_id]
        else:
            predicted_class = f"Unknown Class ({class_id})"
        confidence_percent = confidence * 100
        self.prediction_label.config(text=f"Prediction: {predicted_class}")
        self.confidence_label.config(text=f"Confidence: {confidence_percent:.2f}%")
    else:
        self.prediction_label.config(text="Prediction: Error")
        self.confidence_label.config(text="Confidence: N/A")

Image Preprocessing

Before making predictions, the image needs to be preprocessed to fit the model’s requirements. The preprocess_image() function handles this by resizing the image, converting its color format, normalizing pixel values, and expanding the dimensions to include a batch size. If something goes wrong during this process, it raises an error to keep things running smoothly.

def preprocess_image(self, image_path):
    """
    Preprocess the image for prediction.

    Args:
        image_path (str): Path to the image file.

    Returns:
        np.array: Preprocessed image ready for prediction.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Unable to read image at {image_path}")
    img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype('float32') / 255.0
    img = np.expand_dims(img, axis=0)  # Add batch dimension
    return img

Predicting the Image

Once the image is preprocessed, predict_image() takes over to handle the prediction. It runs the preprocessed image through the model to get prediction scores for each class. It then identifies the class with the highest probability and the confidence of this prediction, which is returned for display.

def predict_image(self, image_path):
    """
    Preprocess the image, run the model prediction, and return the predicted class and confidence score.

    Args:
        image_path (str): Path to the image file.

    Returns:
        Tuple: (class_id, confidence_score)
    """
    try:
        img = self.preprocess_image(image_path)
    except Exception as e:
        messagebox.showerror("Error", str(e))
        return None, None

    predictions = self.model.predict(img)
    class_id = np.argmax(predictions, axis=1)[0]
    confidence = np.max(predictions)
    return class_id, confidence
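
If you ever want more than the single best guess, a small variation on this idea (a hypothetical helper, not part of the GUI above) returns the top three classes along with their probabilities:

# Hypothetical helper: return the top-k classes instead of just the argmax.
def top_k_predictions(model, img, k=3):
    probs = model.predict(img)[0]  # softmax scores for one preprocessed image
    top_idx = np.argsort(probs)[::-1][:k]  # class indices sorted by descending probability
    return [(int(i), float(probs[i])) for i in top_idx]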

Running the Application

Bringing it all together

At this point, everything is coming together. We make sure our app only runs when executed directly by adding a simple check using if __name__ == "__main__". This tells Python to execute the following code only when the script is run as the main program.

Next, we create the main window of our application with tk.Tk(). This window will hold all our buttons, frames, and everything else we’ve designed. Then, we initialize our app by creating an instance of TrafficSignApp and passing in the root window.

if __name__ == "__main__":
  root = tk.Tk()
  app = TrafficSignApp(root)

In the final step, we start the main event loop with root.mainloop(), ensuring the window stays active and responsive to any user interactions.

  root.mainloop()

Example

You can get the GTSRB dataset from either one of these two URLs:

To train the model, simply choose the folder containing your training data by clicking the “Select Training Directory” button. After that, click the “Start Training” button, and the training process will begin automatically!

As you can see in the video, I’ve already uploaded the training folder.

Full Code

import tkinter as tk
from tkinter import ttk
from tkinter import filedialog, messagebox, scrolledtext
from PIL import Image, ImageTk
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import cv2
import os
from sklearn.model_selection import train_test_split
import threading


# Suppress TensorFlow warnings
import logging
logging.getLogger('tensorflow').setLevel(logging.ERROR)


# Define image dimensions and number of classes
IMG_HEIGHT, IMG_WIDTH = 30, 30
NUM_CLASSES = 43


# Define model path
MODEL_PATH = 'traffic_sign_model.keras'


# Class names for GTSRB (German Traffic Sign Recognition Benchmark)
# Ensure that these names match your dataset's classes. Update if necessary.
class_names = [
  'Speed Limit 20', 'Speed Limit 30', 'Speed Limit 50', 'Speed Limit 60',
  'Speed Limit 70', 'Speed Limit 80', 'End of Speed Limit 80', 'Speed Limit 100',
  'Speed Limit 120', 'No Passing', 'No Passing for Vehicles over 3.5 Tons',
  'Right-of-Way at Intersection', 'Priority Road', 'Yield', 'Stop',
  'No Vehicles', 'Vehicles over 3.5 Tons Prohibited', 'No Entry',
  'General Caution', 'Dangerous Curve Left', 'Dangerous Curve Right',
  'Double Curve', 'Bumpy Road', 'Slippery Road', 'Road Narrows on the Right',
  'Road Work', 'Traffic Signals', 'Pedestrians', 'Children Crossing',
  'Bicycles Crossing', 'Beware of Ice/Snow', 'Wild Animals Crossing',
  'End of All Speed and Passing Limits', 'Turn Right Ahead', 'Turn Left Ahead',
  'Ahead Only', 'Go Straight or Right', 'Go Straight or Left',
  'Keep Right', 'Keep Left', 'Roundabout Mandatory', 'End of No Passing',
  'End of No Passing by Vehicles over 3.5 Tons'
]


def load_data(data_dir):
  """
  Load images and labels from the specified directory.


  Args:
      data_dir (str): Path to the training data directory.


  Returns:
      Tuple of NumPy arrays: (images, labels)
  """
  images, labels = [], []
  for class_dir in os.listdir(data_dir):
      class_path = os.path.join(data_dir, class_dir)
      if not os.path.isdir(class_path):
          continue
      try:
          label = int(class_dir)
      except ValueError:
          print(f"Skipping non-integer directory: {class_dir}")
          continue
      for img_name in os.listdir(class_path):
          # Filter out non-image files
          if not img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.ppm')):
              continue  # Skip non-image files like CSVs
          img_path = os.path.join(class_path, img_name)
          img = cv2.imread(img_path)
          if img is None:
              print(f"Warning: Unable to read image {img_path}. Skipping.")
              continue
          img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))
          img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
          images.append(img)
          labels.append(label)
  return np.array(images), np.array(labels)


def preprocess_data(X, y):
  """
  Normalize image data and convert labels to categorical format.


  Args:
      X (np.array): Array of images.
      y (np.array): Array of labels.


  Returns:
      Tuple of NumPy arrays: (X_processed, y_processed)
  """
  X = X.astype('float32') / 255.0
  y = to_categorical(y, NUM_CLASSES)
  return X, y


def build_model(input_shape, num_classes):
  """
  Build and return a CNN model for traffic sign classification.


  Args:
      input_shape (tuple): Shape of the input images (height, width, channels).
      num_classes (int): Number of classes.


  Returns:
      keras.Model: Compiled CNN model.
  """
  inputs = layers.Input(shape=input_shape)
  x = layers.Conv2D(32, (3, 3), activation='relu')(inputs)
  x = layers.BatchNormalization()(x)
  x = layers.MaxPooling2D((2, 2))(x)


  x = layers.Conv2D(64, (3, 3), activation='relu')(x)
  x = layers.BatchNormalization()(x)
  x = layers.MaxPooling2D((2, 2))(x)


  x = layers.Flatten()(x)
  x = layers.Dense(128, activation='relu')(x)
  x = layers.BatchNormalization()(x)
  outputs = layers.Dense(num_classes, activation='softmax')(x)


  model = models.Model(inputs=inputs, outputs=outputs)
  return model


def save_model(model, path=MODEL_PATH):
  """
  Save the trained model to the specified path.


  Args:
      model (keras.Model): Trained Keras model.
      path (str): File path to save the model.
  """
  model.save(path)


class TrainingCallback(tf.keras.callbacks.Callback):
  """
  Custom callback to update the GUI with training progress and logs.
  """


  def __init__(self, app):
      super().__init__()
      self.app = app


  def on_epoch_end(self, epoch, logs=None):
      logs = logs or {}
      progress = ((epoch + 1) / self.params['epochs']) * 100
      message = (f"Epoch {epoch + 1}/{self.params['epochs']} - "
                 f"Loss: {logs.get('loss', 0):.4f} - "
                 f"Accuracy: {logs.get('accuracy', 0):.4f} - "
                 f"Val_Loss: {logs.get('val_loss', 0):.4f} - "
                 f"Val_Accuracy: {logs.get('val_accuracy', 0):.4f}\n")
      self.app.update_progress(progress)
      self.app.log_message(message)


  def on_train_end(self, logs=None):
      self.app.log_message("Training completed.\n")
      self.app.enable_train_buttons()
      messagebox.showinfo("Training Complete", f"Model trained and saved as {MODEL_PATH}.")


class TrafficSignApp:
  def __init__(self, master):
      self.master = master
      master.title("Traffic Sign Detection and Training - The Pycodes")
      master.geometry("800x800")  # Increased height to accommodate scrollbars


      self.model = None


      # Initialize UI components
      self.create_widgets()


  def create_widgets(self):
      # Frame for Training
      training_frame = tk.LabelFrame(self.master, text="Model Training", padx=10, pady=10)
      training_frame.pack(fill="both", expand="yes", padx=10, pady=10)


      # Training Directory Selection
      self.train_dir_label = tk.Label(training_frame, text="Training Directory: Not Selected", wraplength=600,
                                      justify="left")
      self.train_dir_label.pack(anchor='w')


      self.select_train_dir_button = tk.Button(training_frame, text="Select Training Directory",
                                               command=self.select_train_directory)
      self.select_train_dir_button.pack(pady=5)


      # Start Training Button
      self.train_button = tk.Button(training_frame, text="Start Training", command=self.start_training,
                                    state='disabled', bg='green', fg='white')
      self.train_button.pack(pady=5)


      # Progress Bar
      self.progress_var = tk.DoubleVar()
      self.progress_bar = ttk.Progressbar(training_frame, variable=self.progress_var, maximum=100)
      self.progress_bar.pack(fill='x', pady=5)


      # Training Logs
      self.log_text = scrolledtext.ScrolledText(training_frame, height=10, state='disabled')
      self.log_text.pack(fill='both', expand=True, pady=5)


      # Separator
      separator = tk.Frame(self.master, height=2, bd=1, relief='sunken')
      separator.pack(fill='x', padx=5, pady=10)


      # Frame for Prediction
      prediction_frame = tk.LabelFrame(self.master, text="Traffic Sign Prediction", padx=10, pady=10)
      prediction_frame.pack(fill="both", expand="yes", padx=10, pady=10)


      # Create a Canvas inside prediction_frame
      canvas = tk.Canvas(prediction_frame, borderwidth=0, background="#f0f0f0")
      canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)


      # Add vertical and horizontal scrollbars to the Canvas
      v_scrollbar = tk.Scrollbar(prediction_frame, orient=tk.VERTICAL, command=canvas.yview)
      v_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)


      h_scrollbar = tk.Scrollbar(self.master, orient=tk.HORIZONTAL, command=canvas.xview)
      h_scrollbar.pack(fill=tk.X)


      canvas.configure(yscrollcommand=v_scrollbar.set, xscrollcommand=h_scrollbar.set)


      # Create an inner frame to hold prediction widgets
      self.inner_prediction_frame = tk.Frame(canvas, background="#f0f0f0")
      canvas.create_window((0, 0), window=self.inner_prediction_frame, anchor='nw')


      # Bind the inner frame to configure the scroll region
      self.inner_prediction_frame.bind("<Configure>", lambda event, canvas=canvas: self.on_frame_configure(canvas))


      # Upload Image Button
      self.upload_button = tk.Button(self.inner_prediction_frame, text="Upload Image for Prediction", command=self.upload_image,
                                     bg='blue', fg='white', font=('Helvetica', 12, 'bold'))
      self.upload_button.pack(pady=10)


      # Create a frame to hold the canvas and scrollbars for the image
      self.image_frame = tk.Frame(self.inner_prediction_frame)
      self.image_frame.pack(pady=5, fill=tk.BOTH, expand=True)


      # Create Canvas for image display
      self.image_canvas = tk.Canvas(self.image_frame, width=400, height=300, bg='gray')
      self.image_canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)


      # Add scrollbars to the image_canvas
      img_v_scrollbar = tk.Scrollbar(self.image_frame, orient=tk.VERTICAL, command=self.image_canvas.yview)
      img_v_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
      img_h_scrollbar = tk.Scrollbar(self.inner_prediction_frame, orient=tk.HORIZONTAL, command=self.image_canvas.xview)
      img_h_scrollbar.pack(fill=tk.X)


      self.image_canvas.configure(yscrollcommand=img_v_scrollbar.set, xscrollcommand=img_h_scrollbar.set)


      # Prediction Label
      self.prediction_label = tk.Label(self.inner_prediction_frame, text="Prediction: None", font=('Helvetica', 14))
      self.prediction_label.pack(pady=5)


      # Confidence Label
      self.confidence_label = tk.Label(self.inner_prediction_frame, text="Confidence: 0.00%", font=('Helvetica', 14))
      self.confidence_label.pack(pady=5)


  def on_frame_configure(self, canvas):
      """
      Reset the scroll region to encompass the inner frame
      """
      canvas.configure(scrollregion=canvas.bbox("all"))


  def select_train_directory(self):
      directory = filedialog.askdirectory()
      if directory:
          self.train_dir = directory
          self.train_dir_label.config(text=f"Training Directory: {directory}")
          self.train_button.config(state='normal')
      else:
          self.train_dir_label.config(text="Training Directory: Not Selected")
          self.train_button.config(state='disabled')


  def start_training(self):
      # Disable buttons to prevent multiple training sessions
      self.train_button.config(state='disabled')
      self.select_train_dir_button.config(state='disabled')
      self.log_text.config(state='normal')
      self.log_text.insert(tk.END, "Starting training...\n")
      self.log_text.config(state='disabled')
      self.progress_var.set(0)


      # Start training in a separate thread
      training_thread = threading.Thread(target=self.train_model)
      training_thread.start()


  def train_model(self):
      try:
          # Load and preprocess data
          self.log_message("Loading training data...\n")
          X, y = load_data(self.train_dir)
          if X.size == 0:
              self.log_message("No data found. Please check the training directory.\n")
              self.enable_train_buttons()
              return
          self.log_message(f"Loaded {X.shape[0]} images.\n")


          self.log_message("Preprocessing data...\n")
          X, y = preprocess_data(X, y)


          # Split into training and validation sets
          self.log_message("Splitting data into training and validation sets...\n")
          X_train, X_val, y_train, y_val = train_test_split(
              X, y, test_size=0.2, random_state=42
          )
          self.log_message(f"Training samples: {X_train.shape[0]}\n")
          self.log_message(f"Validation samples: {X_val.shape[0]}\n")


          # Data Augmentation
          self.log_message("Applying data augmentation...\n")
          datagen = ImageDataGenerator(
              rotation_range=10,
              zoom_range=0.1,
              width_shift_range=0.1,
              height_shift_range=0.1,
              horizontal_flip=False
          )
          datagen.fit(X_train)


          # Build the model
          self.log_message("Building the CNN model...\n")
          self.model = build_model(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), num_classes=NUM_CLASSES)
          self.log_message("Model architecture:\n")
          model_summary = []
          self.model.summary(print_fn=lambda x: model_summary.append(x))
          self.log_message("\n".join(model_summary) + "\n")


          # Compile the model
          self.log_message("Compiling the model...\n")
          self.model.compile(
              optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy']
          )


          # Define callbacks
          training_callback = TrainingCallback(self)
          early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)


          # Train the model with EarlyStopping
          self.log_message("Starting training...\n")
          self.model.fit(
              datagen.flow(X_train, y_train, batch_size=32),
              epochs=20,
              validation_data=(X_val, y_val),
              callbacks=[training_callback, early_stopping],
              verbose=0  # Suppress Keras' own output
          )


          # Save the model
          save_model(self.model)
          self.log_message(f"Model trained and saved as {MODEL_PATH}.\n")


      except Exception as e:
          self.log_message(f"Error during training: {e}\n")
          messagebox.showerror("Training Error", str(e))
          self.enable_train_buttons()


  def log_message(self, message):
      # Schedule GUI update in the main thread
      self.master.after(0, self._log_message, message)


  def _log_message(self, message):
      self.log_text.config(state='normal')
      self.log_text.insert(tk.END, message)
      self.log_text.see(tk.END)
      self.log_text.config(state='disabled')


  def update_progress(self, progress):
      # Schedule GUI update in the main thread
      self.master.after(0, self._update_progress, progress)


  def _update_progress(self, progress):
      self.progress_var.set(progress)
      self.master.update_idletasks()


  def enable_train_buttons(self):
      # Schedule GUI update in the main thread
      self.master.after(0, self._enable_train_buttons)


  def _enable_train_buttons(self):
      self.train_button.config(state='normal')
      self.select_train_dir_button.config(state='normal')


  def upload_image(self):
      file_path = filedialog.askopenfilename(
          filetypes=[
              ("Image Files", "*.png;*.jpg;*.jpeg;*.bmp;*.ppm"),
              ("All Files", "*.*")
          ]
      )
      if not file_path:
          return  # User cancelled


      # Display the image with scrollbars
      try:
          img = Image.open(file_path)
          self.original_img = img.copy()  # Keep a copy for potential future use


          # Do not resize the image for display to enable scrolling
          img_tk = ImageTk.PhotoImage(img)
          self.image_canvas.delete("all")
          self.image_canvas.create_image(0, 0, anchor='nw', image=img_tk)
          self.image_canvas.image = img_tk  # Keep a reference


          # Update the scrollregion to the size of the image
          self.image_canvas.config(scrollregion=self.image_canvas.bbox(tk.ALL))
      except Exception as e:
          messagebox.showerror("Error", f"Unable to open image.\n{e}")
          return


      # Predict the class
      if not self.model:
          if os.path.exists(MODEL_PATH):
              try:
                  self.log_message("Loading trained model for prediction...\n")
                  self.model = tf.keras.models.load_model(MODEL_PATH)
                  self.log_message("Model loaded successfully.\n")
              except Exception as e:
                  messagebox.showerror("Error", f"Failed to load model.\n{e}")
                  return
          else:
              messagebox.showerror("Model Not Found",
                                   f"Model file '{MODEL_PATH}' not found. Please train the model first.")
              return


      class_id, confidence = self.predict_image(file_path)
      if class_id is not None:
          if class_id < len(class_names):
              predicted_class = class_names[class_id]
          else:
              predicted_class = f"Unknown Class ({class_id})"
          confidence_percent = confidence * 100
          self.prediction_label.config(text=f"Prediction: {predicted_class}")
          self.confidence_label.config(text=f"Confidence: {confidence_percent:.2f}%")
      else:
          self.prediction_label.config(text="Prediction: Error")
          self.confidence_label.config(text="Confidence: N/A")


  def preprocess_image(self, image_path):
      """
      Preprocess the image for prediction.


      Args:
          image_path (str): Path to the image file.


      Returns:
          np.array: Preprocessed image ready for prediction.
      """
      img = cv2.imread(image_path)
      if img is None:
          raise ValueError(f"Unable to read image at {image_path}")
      img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))
      img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
      img = img.astype('float32') / 255.0
      img = np.expand_dims(img, axis=0)  # Add batch dimension
      return img


  def predict_image(self, image_path):
      """
      Preprocess the image, run the model prediction, and return the predicted class and confidence score.


      Args:
          image_path (str): Path to the image file.


      Returns:
          Tuple: (class_id, confidence_score)
      """
      try:
          img = self.preprocess_image(image_path)
      except Exception as e:
          messagebox.showerror("Error", str(e))
          return None, None


      predictions = self.model.predict(img)
      class_id = np.argmax(predictions, axis=1)[0]
      confidence = np.max(predictions)
      return class_id, confidence


if __name__ == "__main__":
  root = tk.Tk()
  app = TrafficSignApp(root)
  root.mainloop()

Happy Coding!
