How to Predict Titanic Survival with Random Forest in Python

Ever wondered what your chances of survival would be if you were aboard the Titanic? The tragic story of the Titanic has fascinated people for over a century, and with today’s technology, we can simulate what might have happened if you were a passenger on that fateful voyage.

Today, you’ll learn how to predict Titanic survival using Python and the Random Forest Classifier. In this tutorial, we’ll walk you through building a program that takes your inputs, like age, ticket class, and fare, runs them through a forest of decision trees, and then displays the predicted outcome (survived or not) in a tkinter message box. So, let’s get started!

Necessary Libraries

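To make sure everything works smoothly, you’ll need to install a few libraries. Just pop these commands into your terminal or command prompt:

$ pip install pandas
$ pip install scikit-learn
$ pip install matplotlib

Note: tkinter ships with most Python installations, so it does not need to be installed with pip. On some Linux distributions it comes as a separate system package (for example, python3-tk on Debian and Ubuntu).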

Note: If you’re curious about how the Random Forest model works, it’s an ensemble method that combines the predictions of multiple decision trees to improve accuracy.

Imports

import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.ensemble import RandomForestClassifier
from sklearn import tree
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import tkinter as tk
from tkinter import ttk
from tkinter import messagebox
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

Before we dive into today’s journey, let’s gather our trusty tools, each with its unique skill set:

  • First up, we have pandas. This tool helps us work with data in table format.
  • Next, we’ll use StratifiedShuffleSplit to split our data into training and testing sets.
  • Then, we bring in RandomForestClassifier. This tool creates a bunch of decision trees and combines their predictions to make accurate forecasts.
  • We also need tree for visualizing those decision trees.
  • For evaluating our model’s performance, we’ll use accuracy_score, classification_report, and confusion_matrix.
  • To create plots and charts, we’ll turn to matplotlib.pyplot.
  • For our graphical user interface, we’ll rely on tkinter. This includes using themed widgets from ttk and displaying messages with messagebox.
  • Finally, we’ll use FigureCanvasTkAgg to show our matplotlib plots right in the tkinter window.

Loading and Preprocessing the Dataset

After gathering our tools, the next step is to download the Titanic dataset, which holds detailed information about each passenger, like age and gender.

# Load Titanic dataset
url = 'https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv'
data = pd.read_csv(url)

We don’t need all of this information for our analysis. So, we’ll keep only the columns we need and then call dropna() to remove any rows that still contain missing values.

# Preprocess the data
data = data[['Pclass', 'Age', 'Fare', 'Sex', 'Survived']]
data = data.dropna()
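
Before dropping rows, it can be useful to see how many values are actually missing. Here is a minimal sketch of the two preprocessing steps on a tiny hand-made table. The rows in toy are invented for illustration and are not taken from the Titanic dataset.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the Titanic table; these rows are invented for illustration.
toy = pd.DataFrame({
    'Pclass': [1, 3, 2, 3],
    'Age': [38.0, np.nan, 27.0, 4.0],
    'Fare': [71.28, 7.25, 13.0, 16.7],
    'Sex': ['female', 'male', 'male', 'female'],
    'Survived': [1, 0, 0, 1],
    'Cabin': ['C85', None, None, 'G6'],  # a column we do not keep
})

# Step 1: selecting columns discards 'Cabin' entirely.
subset = toy[['Pclass', 'Age', 'Fare', 'Sex', 'Survived']]

# Step 2: count missing values per column before dropping anything.
missing_per_column = subset.isna().sum()
print(missing_per_column)

# Step 3: dropna() removes every row that has at least one missing value.
cleaned = subset.dropna()
print(f"Rows before: {len(subset)}, rows after: {len(cleaned)}")
```

Running the same isna().sum() check on the real data shows which column is responsible for most of the dropped rows.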

Next, we need to convert categorical text data to numeric values. For example, we’ll map gender to numbers: 0 for male and 1 for female.

# Convert 'Sex' to numeric
data['Sex'] = data['Sex'].map({'male': 0, 'female': 1})

Additionally, we’ll add a new column called IsBaby. Here, 1 indicates that the passenger is a baby (1 year old or younger), and 0 means otherwise.

# Create 'IsBaby' feature
data['IsBaby'] = (data['Age'] <= 1).astype(int)
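
To see what these two transformations do, here is a small sketch on made-up values (the toy table below is an assumption for illustration only). It also shows a detail worth knowing: map() returns NaN for any label that is not in the dictionary, so a quick check for unmapped values can catch typos or unexpected capitalisation.

```python
import pandas as pd

# Invented values; the capitalised 'Female' is deliberately not in the mapping.
toy = pd.DataFrame({
    'Sex': ['male', 'female', 'Female'],
    'Age': [0.75, 1.0, 30.0],
})

# map() leaves NaN wherever a label is missing from the dictionary.
toy['Sex'] = toy['Sex'].map({'male': 0, 'female': 1})
unmapped = int(toy['Sex'].isna().sum())
print(f"Unmapped gender labels: {unmapped}")

# The comparison is inclusive, so a passenger aged exactly 1 still counts as a baby.
toy['IsBaby'] = (toy['Age'] <= 1).astype(int)
print(toy)
```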

Preparing the Data for Training

Now that we have the relevant columns from the dataset, we’ll split them into two sets:

  • Features (inputs): These are Pclass, Age, Fare, Sex, and IsBaby.
  • Target: This is the Survived column, where 1 indicates the passenger survived and 0 means they did not.

Here’s how we define our features and target:

# Features and target variable
X = data[['Pclass', 'Age', 'Fare', 'Sex', 'IsBaby']]
y = data['Survived']

Next, we use StratifiedShuffleSplit to split the data into training and testing sets while keeping the proportion of survivors roughly the same in both. The training set (70% of the data) will be used to train our model. The testing set (the remaining 30%) will help us evaluate the model’s performance.

# Split data into training and test sets
split = StratifiedShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
for train_index, test_index in split.split(X, y):
    X_train, X_test = X.iloc[train_index], X.iloc[test_index]
    y_train, y_test = y.iloc[train_index], y.iloc[test_index]
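
If you want to convince yourself that the split really is stratified, you can compare the share of each class in both sets. The sketch below uses synthetic labels (60 zeros and 40 ones, invented for this example) rather than the Titanic data. Because n_splits=1, it uses next() instead of a for loop to grab the single split.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Synthetic labels: 60 zeros and 40 ones (invented, not Titanic data).
y_demo = np.array([0] * 60 + [1] * 40)
X_demo = np.arange(100).reshape(-1, 1)

splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.3, random_state=42)

# With a single split, next() returns the one (train, test) index pair.
train_idx, test_idx = next(splitter.split(X_demo, y_demo))

# The share of class 1 should be about the same in both sets.
train_rate = y_demo[train_idx].mean()
test_rate = y_demo[test_idx].mean()
print(f"Share of class 1 in train: {train_rate:.2f}, in test: {test_rate:.2f}")
```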

Building and Evaluating the Random Forest Model

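With our data ready, the next step is to create our model. We use a RandomForestClassifier with 100 trees (n_estimators=100), each limited to a depth of 10 (max_depth=10) to curb overfitting. Combining the results of many decision trees usually gives more stable predictions than a single tree, and random_state=42 makes the run reproducible.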

# Create and train the random forest model
model = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42)
model.fit(X_train, y_train)
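
So how does the forest combine its trees? In scikit-learn, the forest averages the class probabilities predicted by each tree and picks the class with the highest average. The sketch below reproduces that by hand on synthetic data from make_classification. The data, the smaller forest size, and the variable names are all assumptions chosen to keep the example quick.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic data standing in for the Titanic features (an assumption for this sketch).
X_demo, y_demo = make_classification(n_samples=200, n_features=5, random_state=0)

forest = RandomForestClassifier(n_estimators=25, max_depth=5, random_state=0)
forest.fit(X_demo, y_demo)

# Collect each tree's class probabilities: shape (n_trees, n_samples, n_classes).
per_tree = np.stack([t.predict_proba(X_demo) for t in forest.estimators_])

# Average over the trees, then pick the class with the highest mean probability.
averaged = per_tree.mean(axis=0)
manual_pred = forest.classes_[np.argmax(averaged, axis=1)]

agreement = (manual_pred == forest.predict(X_demo)).mean()
print(f"Agreement with forest.predict(): {agreement:.2%}")
```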

Once the model is trained using the training data, we make predictions on the test data and store them in y_pred.

# Make predictions
y_pred = model.predict(X_test)

To evaluate how well our model performs, we compare the predictions to the actual test data. We use accuracy_score, classification_report, and confusion_matrix to assess the model’s accuracy and performance.

# Evaluate the model
accuracy = accuracy_score(y_test, y_pred)
report = classification_report(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
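
If the confusion matrix looks cryptic, it helps to unpack it. For a binary problem, rows are the actual classes and columns are the predicted ones, so ravel() yields the counts in the order true negatives, false positives, false negatives, true positives. Here is a small sketch with hand-written labels (invented for illustration) that rebuilds the accuracy from those four numbers.

```python
from sklearn.metrics import accuracy_score, confusion_matrix

# Hand-written labels, invented for illustration (0 = not survived, 1 = survived).
y_true = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_hat = [0, 0, 0, 1, 0, 1, 1, 0, 1, 1]

# Rows are actual classes, columns are predicted classes.
tn, fp, fn, tp = confusion_matrix(y_true, y_hat).ravel()
print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")

# Accuracy is the share of correct predictions among all predictions.
manual_accuracy = (tp + tn) / (tp + tn + fp + fn)
print(f"Manual accuracy: {manual_accuracy:.2f}")
print(f"accuracy_score:  {accuracy_score(y_true, y_hat):.2f}")
```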

The Predict Survival Function

def predict_survival():
    try:
        # Read the raw widget values; int() and float() raise ValueError on bad input
        pclass = int(combo_pclass.get())
        age = float(entry_age.get())
        fare = float(spin_fare.get())
        gender = combo_gender.get()

        # Reject the "Select Gender" placeholder instead of silently treating it as male
        if gender not in ('Male', 'Female'):
            raise ValueError("Gender not selected")

        is_baby = 1 if age <= 1 else 0

        # Map gender to numeric (same encoding as the training data)
        gender_numeric = 1 if gender == 'Female' else 0

        # Create a DataFrame with the same feature names as used during training
        input_data = pd.DataFrame([[pclass, age, fare, gender_numeric, is_baby]],
                                  columns=['Pclass', 'Age', 'Fare', 'Sex', 'IsBaby'])

        # Print input data for verification
        print(f"Input Data: {input_data}")

        # Make the prediction
        prediction = model.predict(input_data)[0]
        print(f"Prediction: {prediction}")
        result = 'Survived' if prediction == 1 else 'Not Survived'
        messagebox.showinfo("Prediction Result", f"Prediction: {result}")
    except ValueError:
        messagebox.showerror("Input Error",
                             "Please select a class and gender, and enter numeric values for age and fare.")

Once everything is in place, we just need to make predictions. The predict_survival() function reads the user’s inputs from the widgets, converts them to numbers, and formats them into a one-row DataFrame using pd.DataFrame(), with the same column names the model saw during training. The model then applies the patterns it learned during training to these inputs to generate a prediction. The result is printed to the console and displayed in a message box. If an input cannot be converted to a number (for example, an empty age field), the resulting ValueError is caught and an error message is shown to the user.
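
The input handling can also be pulled out into a small function that does not touch tkinter at all, which makes it easy to test without opening a window. The following is a minimal sketch, assuming a hypothetical helper named build_input_row() that is not part of the original script. The range checks it performs are also our own assumption about what counts as valid input.

```python
import pandas as pd

FEATURES = ['Pclass', 'Age', 'Fare', 'Sex', 'IsBaby']


def build_input_row(pclass_text, age_text, fare_text, gender_text):
    """Hypothetical helper: turn raw widget strings into a one-row DataFrame."""
    # int() and float() raise ValueError for text such as "Select Class" or "".
    pclass = int(pclass_text)
    age = float(age_text)
    fare = float(fare_text)

    # Assumed sanity checks; adjust them to whatever your GUI allows.
    if pclass not in (1, 2, 3):
        raise ValueError("Passenger class must be 1, 2, or 3.")
    if age < 0 or fare < 0:
        raise ValueError("Age and fare cannot be negative.")
    if gender_text not in ('Male', 'Female'):
        raise ValueError("Please select a gender.")

    # Same encoding as the training data: 0 for male, 1 for female.
    sex = 1 if gender_text == 'Female' else 0
    is_baby = 1 if age <= 1 else 0
    return pd.DataFrame([[pclass, age, fare, sex, is_baby]], columns=FEATURES)


row = build_input_row("3", "0.5", "7.25", "Female")
print(row)
```

Inside the GUI callback, you would then call model.predict(build_input_row(...)) within the try block and show the error text whenever a ValueError is raised.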

Visualize Decision Tree and Model Summary

For this step, let’s dive into how we visualize our decision tree and summarize the model’s performance:

First up, we have the show_tree() function, which brings the decision tree to life. We start by creating a large figure with figsize=(20, 15) so the nodes stay readable. Using matplotlib, we plot the first decision tree from our random forest (model.estimators_[0]), complete with feature names and class labels. Keep in mind that this is only one of the 100 trees, so it illustrates how a single tree splits the data rather than the forest’s final decision. We give the plot a title and save it as a 300 dpi PNG file (decision_tree.png) in the current working directory. We also embed the plot directly into the tkinter GUI with FigureCanvasTkAgg, so you can see the tree right within your application.

def show_tree():
  # Plot the first decision tree in the forest with increased size
  fig, ax = plt.subplots(figsize=(20, 15))  # Increase size for better visibility
  tree.plot_tree(model.estimators_[0],
                 feature_names=['Pclass', 'Age', 'Fare', 'Sex', 'IsBaby'],
                 class_names=['Not Survived', 'Survived'],
                 filled=True, fontsize=10, ax=ax)
  ax.set_title('Decision Tree for Titanic Survival Prediction (Random Forest)')


  # Save the plot as a PNG file with higher resolution
  plt.savefig("decision_tree.png", dpi=300)


  # Embed the plot in Tkinter window
  canvas = FigureCanvasTkAgg(fig, master=window)
  canvas.draw()
  canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)
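
A tree with a depth of up to 10 can be hard to read as an image. As an alternative view, scikit-learn’s export_text() prints the same kind of split rules as indented text. The sketch below is only an illustration: it trains a small forest on synthetic data and borrows the Titanic feature names purely as labels, so the thresholds it prints mean nothing for real passengers.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_text

# Titanic feature names reused only as labels for synthetic columns (an assumption).
feature_names = ['Pclass', 'Age', 'Fare', 'Sex', 'IsBaby']
X_demo, y_demo = make_classification(n_samples=200, n_features=5, random_state=0)

# A small, shallow forest keeps the printed rules short.
forest = RandomForestClassifier(n_estimators=10, max_depth=3, random_state=0)
forest.fit(X_demo, y_demo)

# Print the split rules of the first tree in the forest as indented text.
rules = export_text(forest.estimators_[0], feature_names=feature_names, max_depth=3)
print(rules)
```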

Next, we have the show_summary() function. This function provides a snapshot of how well our model is performing. It displays the model’s accuracy, a detailed classification report, and the confusion matrix in a message box. This gives you a quick overview of how well the model is doing and where it might be making errors.

def show_summary():
  summary_text = (
      f"Model Accuracy: {accuracy:.2f}\n\n"
      "Classification Report:\n"
      f"{report}\n\n"
      "Confusion Matrix:\n"
      f"{conf_matrix}"
  )
  messagebox.showinfo("Model Summary", summary_text)
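
The classification report is a pre-formatted string, which is fine for a quick look. If you would rather pick out individual numbers, for instance to show only the precision and recall for survivors, classification_report() can also return a dictionary via output_dict=True. Here is a short sketch with hand-written labels (invented for illustration).

```python
from sklearn.metrics import classification_report

# Hand-written labels, invented for illustration (0 = not survived, 1 = survived).
y_true = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_hat = [0, 0, 0, 1, 0, 1, 1, 0, 1, 1]

# output_dict=True returns nested dictionaries keyed by the class label as a string.
report_dict = classification_report(y_true, y_hat, output_dict=True)

survived = report_dict['1']
summary_line = (
    f"Survived: precision={survived['precision']:.2f}, "
    f"recall={survived['recall']:.2f}, "
    f"overall accuracy={report_dict['accuracy']:.2f}"
)
print(summary_line)
```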

Setting Up the Main Window

Congratulations, everyone! We’ve reached the final part of our script, where we set up the main window using Tkinter to create a graphical interface for controlling the whole operation. First, we set the window’s title and define its geometry. We also add some styling to make the labels and buttons more visually appealing.

For each input feature, we create labels and corresponding widgets like comboboxes and entry boxes where users can select or enter data. We then add three buttons: “Predict Survival” to call the predict_survival() function, “Show Decision Tree” to call the show_tree() function, and “Show Model Summary” to call the show_summary() function.

# Tkinter window
window = tk.Tk()
window.title("Titanic Survival Predictor - The Pycodes")
window.geometry("600x700")


# Styling
style = ttk.Style()
style.configure('TLabel', font=('Helvetica', 12))
style.configure('TButton', font=('Helvetica', 12))


# Passenger Class dropdown menu
tk.Label(window, text="Passenger Class:", font=('Helvetica', 12)).pack(pady=5)
combo_pclass = ttk.Combobox(window, values=[1, 2, 3], state="readonly")
combo_pclass.pack(pady=5)
combo_pclass.set("Select Class")


# Gender dropdown menu
tk.Label(window, text="Gender:", font=('Helvetica', 12)).pack(pady=5)
combo_gender = ttk.Combobox(window, values=['Male', 'Female'], state="readonly")
combo_gender.pack(pady=5)
combo_gender.set("Select Gender")


# Age input
tk.Label(window, text="Age:", font=('Helvetica', 12)).pack(pady=5)
entry_age = tk.Entry(window)
entry_age.pack(pady=5)


# Fare input (Spinbox covering the range 0 to 500)
tk.Label(window, text="Fare:", font=('Helvetica', 12)).pack(pady=5)
spin_fare = ttk.Spinbox(window, from_=0, to=500, increment=1, format="%.2f")
spin_fare.pack(pady=5)


# Buttons
btn_predict = ttk.Button(window, text="Predict Survival", command=predict_survival)
btn_predict.pack(pady=10)


btn_show_tree = ttk.Button(window, text="Show Decision Tree", command=show_tree)
btn_show_tree.pack(pady=10)


btn_summary = ttk.Button(window, text="Show Model Summary", command=show_summary)
btn_summary.pack(pady=10)

Lastly, we start the main event loop with the mainloop() method to keep the window running and responsive to user interactions.

# Run the Tkinter event loop
window.mainloop()

Full Code

import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.ensemble import RandomForestClassifier
from sklearn import tree
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import tkinter as tk
from tkinter import ttk
from tkinter import messagebox
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg




# Load Titanic dataset
url = 'https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv'
data = pd.read_csv(url)


# Preprocess the data
data = data[['Pclass', 'Age', 'Fare', 'Sex', 'Survived']]
data = data.dropna()


# Convert 'Sex' to numeric
data['Sex'] = data['Sex'].map({'male': 0, 'female': 1})


# Create 'IsBaby' feature
data['IsBaby'] = (data['Age'] <= 1).astype(int)


# Features and target variable
X = data[['Pclass', 'Age', 'Fare', 'Sex', 'IsBaby']]
y = data['Survived']


# Split data into training and test sets
split = StratifiedShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
for train_index, test_index in split.split(X, y):
  X_train, X_test = X.iloc[train_index], X.iloc[test_index]
  y_train, y_test = y.iloc[train_index], y.iloc[test_index]


# Create and train the random forest model
model = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42)
model.fit(X_train, y_train)


# Make predictions
y_pred = model.predict(X_test)


# Evaluate the model
accuracy = accuracy_score(y_test, y_pred)
report = classification_report(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)


def predict_survival():
    try:
        # Read the raw widget values; int() and float() raise ValueError on bad input
        pclass = int(combo_pclass.get())
        age = float(entry_age.get())
        fare = float(spin_fare.get())
        gender = combo_gender.get()

        # Reject the "Select Gender" placeholder instead of silently treating it as male
        if gender not in ('Male', 'Female'):
            raise ValueError("Gender not selected")

        is_baby = 1 if age <= 1 else 0

        # Map gender to numeric (same encoding as the training data)
        gender_numeric = 1 if gender == 'Female' else 0

        # Create a DataFrame with the same feature names as used during training
        input_data = pd.DataFrame([[pclass, age, fare, gender_numeric, is_baby]],
                                  columns=['Pclass', 'Age', 'Fare', 'Sex', 'IsBaby'])

        # Print input data for verification
        print(f"Input Data: {input_data}")

        # Make the prediction
        prediction = model.predict(input_data)[0]
        print(f"Prediction: {prediction}")
        result = 'Survived' if prediction == 1 else 'Not Survived'
        messagebox.showinfo("Prediction Result", f"Prediction: {result}")
    except ValueError:
        messagebox.showerror("Input Error",
                             "Please select a class and gender, and enter numeric values for age and fare.")


def show_tree():
  # Plot the first decision tree in the forest with increased size
  fig, ax = plt.subplots(figsize=(20, 15))  # Increase size for better visibility
  tree.plot_tree(model.estimators_[0],
                 feature_names=['Pclass', 'Age', 'Fare', 'Sex', 'IsBaby'],
                 class_names=['Not Survived', 'Survived'],
                 filled=True, fontsize=10, ax=ax)
  ax.set_title('Decision Tree for Titanic Survival Prediction (Random Forest)')


  # Save the plot as a PNG file with higher resolution
  plt.savefig("decision_tree.png", dpi=300)


  # Embed the plot in Tkinter window
  canvas = FigureCanvasTkAgg(fig, master=window)
  canvas.draw()
  canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)


def show_summary():
  summary_text = (
      f"Model Accuracy: {accuracy:.2f}\n\n"
      "Classification Report:\n"
      f"{report}\n\n"
      "Confusion Matrix:\n"
      f"{conf_matrix}"
  )
  messagebox.showinfo("Model Summary", summary_text)


# Tkinter window
window = tk.Tk()
window.title("Titanic Survival Predictor - The Pycodes")
window.geometry("600x700")


# Styling
style = ttk.Style()
style.configure('TLabel', font=('Helvetica', 12))
style.configure('TButton', font=('Helvetica', 12))


# Passenger Class dropdown menu
tk.Label(window, text="Passenger Class:", font=('Helvetica', 12)).pack(pady=5)
combo_pclass = ttk.Combobox(window, values=[1, 2, 3], state="readonly")
combo_pclass.pack(pady=5)
combo_pclass.set("Select Class")


# Gender dropdown menu
tk.Label(window, text="Gender:", font=('Helvetica', 12)).pack(pady=5)
combo_gender = ttk.Combobox(window, values=['Male', 'Female'], state="readonly")
combo_gender.pack(pady=5)
combo_gender.set("Select Gender")


# Age input
tk.Label(window, text="Age:", font=('Helvetica', 12)).pack(pady=5)
entry_age = tk.Entry(window)
entry_age.pack(pady=5)


# Fare input (Spinbox covering the range 0 to 500)
tk.Label(window, text="Fare:", font=('Helvetica', 12)).pack(pady=5)
spin_fare = ttk.Spinbox(window, from_=0, to=500, increment=1, format="%.2f")
spin_fare.pack(pady=5)


# Buttons
btn_predict = ttk.Button(window, text="Predict Survival", command=predict_survival)
btn_predict.pack(pady=10)


btn_show_tree = ttk.Button(window, text="Show Decision Tree", command=show_tree)
btn_show_tree.pack(pady=10)


btn_summary = ttk.Button(window, text="Show Model Summary", command=show_summary)
btn_summary.pack(pady=10)


# Run the Tkinter event loop
window.mainloop()

Happy Coding!
