Use Streamlit, seaborn and scikit-learn to build a web app that predicts whether a Pokemon is legendary.
In this tutorial, we will work with a Pokemon dataset (pokemon.csv). We will use seaborn for the visualizations and a random forest classifier to build our model.
Install and Import Necessary Libraries
Set Up a Virtual Environment
# Install virtualenv
pip install virtualenv
# Create a virtual environment named venv
virtualenv venv
# Activate it (Windows; on macOS/Linux use: source venv/bin/activate)
venv\Scripts\activate
Install Libraries
Make sure your virtual environment is activated before installing the libraries (pandas and matplotlib are installed as seaborn dependencies):
pip install streamlit seaborn scikit-learn
Import the Libraries
import streamlit as st
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score
We import Streamlit, seaborn, matplotlib and pandas, along with the data-splitting helper, the random forest model and the evaluation metrics from scikit-learn.
Helper Functions
Title Function
def title(s):
    st.text("")
    st.title(s)
    st.text("")
The title function displays a title between two blank lines.
Clean and Split Data
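def clean_and_split(df):
    feature_list = ['weight_kg', 'height_m', 'sp_attack', 'attack',
                    'sp_defense', 'defense', 'speed', 'hp', 'is_legendary']
    # Undersample: keep every legendary Pokemon plus 75 random normal ones
    # (random_state makes the sample reproducible across Streamlit reruns)
    legendary_df = df.loc[df['is_legendary'] == 1, feature_list].copy()
    normal_df = df.loc[df['is_legendary'] == 0, feature_list].sample(75, random_state=1)
    # Fill NaNs in each subset with that subset's own column means
    # (selecting the numeric features first keeps .mean() from failing on text columns)
    legendary_df = legendary_df.fillna(legendary_df.mean())
    normal_df = normal_df.fillna(normal_df.mean())
    sub_df = pd.concat([legendary_df, normal_df])
    X = sub_df.drop(columns='is_legendary')
    Y = sub_df['is_legendary']
    # stratify=Y keeps the legendary/normal ratio the same in both splits
    X_train, X_test, y_train, y_test = train_test_split(
        X, Y, random_state=1, test_size=0.2, shuffle=True, stratify=Y)
    return X_train, X_test, y_train, y_test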
There are far fewer legendary Pokemon than non-legendary ones, so the dataset is imbalanced. To compensate, we undersample the majority (non-legendary) class.
- First, extract all the legendary Pokemon and 75 randomly sampled normal Pokemon from the original dataset.
- Fill the NaN values with the mean: use the legendary subset's mean for the legendary rows and the normal subset's mean for the normal rows.
- Concatenate the two subsets.
- Use the following features as inputs: ['weight_kg', 'height_m', 'sp_attack', 'attack', 'sp_defense', 'defense', 'speed', 'hp']. The is_legendary column is the target.
- Use train_test_split to split the data into training and test sets. Setting stratify=Y keeps the proportion of legendary to normal Pokemon in both sets the same as in the full sample.
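The undersampling, per-class mean filling and stratified split can be checked on a small synthetic DataFrame. This is a minimal sketch: the attack values, the 10/40 class sizes and the choice to sample 20 normal rows are made up for illustration and do not come from pokemon.csv.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 10 "legendary" rows followed by 40 "normal" rows
rng = np.random.default_rng(0)
toy = pd.DataFrame({
    'attack': rng.integers(20, 150, size=50).astype(float),
    'is_legendary': [1] * 10 + [0] * 40,
})
# Introduce a missing value in each class
toy.loc[0, 'attack'] = np.nan
toy.loc[20, 'attack'] = np.nan

# Undersample: keep all legendary rows, sample 20 of the 40 normal rows
legendary = toy[toy['is_legendary'] == 1]
normal = toy[toy['is_legendary'] == 0].sample(20, random_state=1)

# Fill each subset's NaNs with that subset's own mean
legendary = legendary.fillna(legendary.mean())
normal = normal.fillna(normal.mean())

balanced = pd.concat([legendary, normal])
X = balanced.drop(columns='is_legendary')
y = balanced['is_legendary']

# Stratified split: the class ratio is preserved in train and test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=1, stratify=y)

print(y.value_counts(normalize=True).to_dict())       # ratio in the balanced sample
print(y_test.value_counts(normalize=True).to_dict())  # same ratio in the test set
```

With 30 balanced rows and test_size=0.2, the test set has 6 rows, and stratification puts exactly 2 legendary and 4 normal rows in it, the same one-to-two ratio as the balanced sample.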
Building the UI
# Intro
st.title('Is that a Legendary Pokemon?')
st.image('bg.jpg', width=600)
st.markdown('''
Photo by [Kamil S](https://unsplash.com/@16bitspixelz?utm_source=unsplash&utm_medium=referral&utm_content=creditCopyText) on [Unsplash](https://unsplash.com/s/photos/pokemon?utm_source=unsplash&utm_medium=referral&utm_content=creditCopyText)
''')
I used this photo for the intro image. Set the width parameter of st.image() to resize it as needed.
# Load Data
df = pd.read_csv('pokemon.csv')
st.dataframe(df.head())
# Basic Info
shape = df.shape
num_total = len(df)
num_legendary = len(df[df['is_legendary'] == 1])
num_non_legendary = num_total - num_legendary
Load the data and count the legendary and non-legendary Pokemon. We will use st.subheader() to display the totals.
st.subheader(f'Number of Pokemon: {num_total}')
st.subheader(f'Number of Legendary Pokemon: {num_legendary}')
st.subheader(f'Number of Non-Legendary Pokemon: {num_non_legendary}')
st.subheader(f'Number of Features: {shape[1]}')
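As an alternative to counting with boolean masks, pandas' value_counts() returns the count of every class in one call. A minimal sketch on a hypothetical five-row frame (toy_df stands in for the real pokemon.csv data):

```python
import pandas as pd

# Hypothetical stand-in for pokemon.csv with only the label column
toy_df = pd.DataFrame({'is_legendary': [0, 0, 1, 0, 1]})

# One pass over the column gives the count of each label
counts = toy_df['is_legendary'].value_counts()
num_legendary = int(counts.get(1, 0))
num_non_legendary = int(counts.get(0, 0))
num_total = len(toy_df)

print(num_total, num_legendary, num_non_legendary)  # 5 2 3
```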
Legendary Pokemon Distribution based on Type
title('Legendary Pokemon Distribution based on Type')
legendary_df = df[df['is_legendary'] == 1]
fig1 = plt.figure()
ax = sns.countplot(data=legendary_df , x = 'type1',order=legendary_df['type1'].value_counts().index)
plt.xticks(rotation=45)
st.pyplot(fig1)
- Use the title helper function to display the title
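- Use seaborn’s countplot to plot the distribution, passing order=legendary_df['type1'].value_counts().index so the bars are sorted from most to least common type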
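The order argument works because value_counts() sorts counts in descending order, so its index lists the categories from most to least frequent. A quick check on a hypothetical handful of type1 values:

```python
import pandas as pd

# Hypothetical type1 values for a few legendary Pokemon
types = pd.Series(['psychic', 'dragon', 'psychic', 'fire', 'psychic', 'dragon'])

# value_counts() sorts by count (descending); its index gives the bar order
order = types.value_counts().index.tolist()
print(order)  # ['psychic', 'dragon', 'fire']
```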
Height vs Weight for Legendary and Non-Legendary Pokemon
title('Height vs Weight for Legendary and Non-Legendary Pokemon')
fig2 = plt.figure()
sns.scatterplot(data=df , x = 'weight_kg' , y = 'height_m' , hue='is_legendary')
st.pyplot(fig2)
- Use seaborn’s scatterplot to plot height against weight for legendary and normal Pokemon
- Set the hue parameter to the 'is_legendary' column to color-code the points by class
Correlation between features
title('Correlation between features')
fig3 = plt.figure()
sns.heatmap(legendary_df[['attack','sp_attack','defense','sp_defense','height_m','weight_kg','speed']].corr())
st.pyplot(fig3)
- Use seaborn’s heatmap to plot the correlations. Note that this heatmap uses only the legendary Pokemon.
- Use the DataFrame corr() method to compute the pairwise correlation values (Pearson, by default)
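corr() returns a symmetric matrix with 1.0 on the diagonal, and that matrix is what the heatmap colors. A minimal sketch with made-up stats, where attack rises exactly with sp_attack and speed falls exactly as attack rises:

```python
import pandas as pd

# Hypothetical stats chosen to give perfect positive and negative correlations
stats = pd.DataFrame({
    'attack':    [50, 70, 90, 110],
    'sp_attack': [55, 75, 95, 115],
    'speed':     [120, 100, 80, 60],
})

corr = stats.corr()  # Pearson correlation by default
print(corr.round(2))
```

Values near 1 or -1 show up as the strongest colors in the heatmap; values near 0 mean the two stats are roughly linearly unrelated.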
Special Attack vs Attack
fig4 = plt.figure()
sns.scatterplot(data=df, x='sp_attack',y='attack',hue='is_legendary')
st.pyplot(fig4)
- As with the height vs weight plot, we use seaborn’s scatterplot, with hue='is_legendary' to color-code the two classes
Model
We will use a random forest to make our predictions.
title('Random Forest')
X_train , X_test , y_train , y_test = clean_and_split(df)
st.subheader("Sample Data")
st.dataframe(X_train.head(3))
model = RandomForestClassifier(n_estimators=100)
model.fit(X_train , y_train)
title("Metrics")
# Predict once; scikit-learn metrics take (y_true, y_pred) in that order
y_pred = model.predict(X_test)
st.subheader("Model Score: {}".format(model.score(X_test, y_test)))
st.subheader("Precision Score: {}".format(precision_score(y_test, y_pred)))
st.subheader("Recall Score: {}".format(recall_score(y_test, y_pred)))
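- Use the clean_and_split helper function we created earlier to get the training and testing data
- Create an instance of RandomForestClassifier and set the n_estimators parameter to the number of decision trees in the forest; scikit-learn combines the trees by averaging their predicted class probabilities
- Use the metric functions we imported earlier to score the model. model.score() returns the accuracy on the test set, while precision_score and recall_score expect the true labels first and the predicted labels second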
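The argument order matters because swapping it does not raise an error: it silently turns precision into recall and vice versa. A minimal sketch with hand-written, hypothetical labels (not output from the Pokemon model):

```python
from sklearn.metrics import accuracy_score, precision_score, recall_score

# Hypothetical labels: 3 true positives in y_true; the "model" finds 2 of them
# and raises 2 false alarms
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 1, 1, 0, 0, 0]

print(accuracy_score(y_true, y_pred))   # 0.625 (5 of 8 correct)
print(precision_score(y_true, y_pred))  # 0.5   (2 of 4 predicted positives are right)
print(recall_score(y_true, y_pred))     # about 0.667 (2 of 3 true positives found)

# Swapped arguments: the "precision" printed here is actually the recall
print(precision_score(y_pred, y_true))
```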
st.subheader("Confusion Matrix")
fig5 = plt.figure()
# Rows are the true classes, columns the predicted classes
conf_matrix = confusion_matrix(y_test, model.predict(X_test))
sns.heatmap(conf_matrix, annot=True, fmt='d', xticklabels=['Normal', 'Legendary'], yticklabels=['Normal', 'Legendary'])
plt.ylabel("True")
plt.xlabel("Predicted")
st.pyplot(fig5)
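- Use confusion_matrix from sklearn.metrics to compute the confusion matrix. When called as confusion_matrix(y_test, y_pred), its rows are the true classes and its columns are the predicted classes, which matches the "True" and "Predicted" axis labels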
- Use the confusion matrix to plot a heatmap using seaborn’s heatmap function
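Precision and recall can be read straight off the confusion matrix. A minimal sketch with hypothetical labels, where label 0 plays "Normal" and 1 plays "Legendary" (labels are sorted, so 0 comes first in both rows and columns):

```python
from sklearn.metrics import confusion_matrix

# Hypothetical labels, not model output
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 1, 1, 0, 0, 0]

cm = confusion_matrix(y_true, y_pred)
print(cm)
# [[3 2]
#  [1 2]]

# For binary labels, ravel() yields the cells in this order
tn, fp, fn, tp = cm.ravel()
precision = tp / (tp + fp)  # 2 / 4
recall = tp / (tp + fn)     # 2 / 3
```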
Deploy your App