Classification Models in Machine Learning
Classification is a core task in machine learning: it lets us predict the labels or categories of new data from its features. Many classification algorithms exist, each with its own strengths and weaknesses. In this blog post, we will explore a range of classification models, from traditional machine learning algorithms such as LogisticRegression, KNeighborsClassifier, SVC, DecisionTreeClassifier, RandomForestClassifier, GradientBoostingClassifier, AdaBoostClassifier, and XGBClassifier, to more complex neural network models. We will examine how each algorithm works, its advantages and disadvantages, and when to use it given the characteristics of the data and the requirements of the problem. By the end of this post, you will have a better understanding of the different classification techniques and be able to choose the best algorithm for your classification task.
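As a quick preview, all of the scikit-learn estimators listed above share the same fit/predict interface, so they can be collected and swapped interchangeably. This is only a sketch with default-ish settings; XGBClassifier lives in the separate xgboost package, so it is left commented out here.

```python
# Preview: the scikit-learn classifiers covered in this post, behind one
# uniform interface. Hyperparameters here are illustrative defaults only.
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import (RandomForestClassifier,
                              GradientBoostingClassifier,
                              AdaBoostClassifier)
# from xgboost import XGBClassifier  # requires the separate xgboost package

models = {
    "LogisticRegression": LogisticRegression(max_iter=1000),
    "KNeighborsClassifier": KNeighborsClassifier(),
    "SVC": SVC(),
    "DecisionTreeClassifier": DecisionTreeClassifier(random_state=42),
    "RandomForestClassifier": RandomForestClassifier(random_state=42),
    "GradientBoostingClassifier": GradientBoostingClassifier(random_state=42),
    "AdaBoostClassifier": AdaBoostClassifier(random_state=42),
}
```

Because every model exposes the same `fit`/`predict` pair, a later comparison loop only needs to iterate over this dictionary.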
Context
An automobile company plans to enter new markets with its existing products (P1, P2, P3, P4, and P5). After intensive market research, they have deduced that the behavior of the new market is similar to that of their existing market.
In the existing market, the sales team classified all customers into 4 segments (A, B, C, D), then ran targeted outreach and communication for each segment. This strategy has worked exceptionally well for them. They plan to use the same strategy for the new markets and have identified 2627 new potential customers.
You are required to help the manager predict the right segment for each new customer.
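At a high level, this reduces to a standard supervised-learning workflow: fit a classifier on the labelled existing-market customers, then predict segments for the new ones. The sketch below uses a RandomForestClassifier purely as an example; `X_train`, `y_train`, and `X_new` are placeholder names for preprocessed feature matrices, not variables defined in this post.

```python
# High-level shape of the task: learn the A/B/C/D segment labels from the
# existing market, then assign segments to the 2627 new customers.
# X_train, y_train, X_new are placeholders for preprocessed numeric features.
from sklearn.ensemble import RandomForestClassifier

def assign_segments(X_train, y_train, X_new):
    """Fit on labelled existing-market customers, label the new ones."""
    model = RandomForestClassifier(random_state=42)
    model.fit(X_train, y_train)
    return model.predict(X_new)
```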
Import libraries
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
Import Dataset
train_data=pd.read_csv("Train.csv")
test_data=pd.read_csv("Test.csv")
train_data
 | ID | Gender | Ever_Married | Age | Graduated | Profession | Work_Experience | Spending_Score | Family_Size | Var_1 | Segmentation
---|---|---|---|---|---|---|---|---|---|---|---|
0 | 462809 | Male | No | 22 | No | Healthcare | 1.0 | Low | 4.0 | Cat_4 | D |
1 | 462643 | Female | Yes | 38 | Yes | Engineer | NaN | Average | 3.0 | Cat_4 | A |
2 | 466315 | Female | Yes | 67 | Yes | Engineer | 1.0 | Low | 1.0 | Cat_6 | B |
3 | 461735 | Male | Yes | 67 | Yes | Lawyer | 0.0 | High | 2.0 | Cat_6 | B |
4 | 462669 | Female | Yes | 40 | Yes | Entertainment | NaN | High | 6.0 | Cat_6 | A |
… | … | … | … | … | … | … | … | … | … | … | … |
8063 | 464018 | Male | No | 22 | No | NaN | 0.0 | Low | 7.0 | Cat_1 | D |
8064 | 464685 | Male | No | 35 | No | Executive | 3.0 | Low | 4.0 | Cat_4 | D |
8065 | 465406 | Female | No | 33 | Yes | Healthcare | 1.0 | Low | 1.0 | Cat_6 | D |
8066 | 467299 | Female | No | 27 | Yes | Healthcare | 1.0 | Low | 4.0 | Cat_6 | B |
8067 | 461879 | Male | Yes | 37 | Yes | Executive | 0.0 | Average | 3.0 | Cat_4 | B |
8068 rows × 11 columns
DATASET OVERVIEW
View first 10 rows
#First 10 rows
train_data.head(10)
 | ID | Gender | Ever_Married | Age | Graduated | Profession | Work_Experience | Spending_Score | Family_Size | Var_1 | Segmentation
---|---|---|---|---|---|---|---|---|---|---|---|
0 | 462809 | Male | No | 22 | No | Healthcare | 1.0 | Low | 4.0 | Cat_4 | D |
1 | 462643 | Female | Yes | 38 | Yes | Engineer | NaN | Average | 3.0 | Cat_4 | A |
2 | 466315 | Female | Yes | 67 | Yes | Engineer | 1.0 | Low | 1.0 | Cat_6 | B |
3 | 461735 | Male | Yes | 67 | Yes | Lawyer | 0.0 | High | 2.0 | Cat_6 | B |
4 | 462669 | Female | Yes | 40 | Yes | Entertainment | NaN | High | 6.0 | Cat_6 | A |
5 | 461319 | Male | Yes | 56 | No | Artist | 0.0 | Average | 2.0 | Cat_6 | C |
6 | 460156 | Male | No | 32 | Yes | Healthcare | 1.0 | Low | 3.0 | Cat_6 | C |
7 | 464347 | Female | No | 33 | Yes | Healthcare | 1.0 | Low | 3.0 | Cat_6 | D |
8 | 465015 | Female | Yes | 61 | Yes | Engineer | 0.0 | Low | 3.0 | Cat_7 | D |
9 | 465176 | Female | Yes | 55 | Yes | Artist | 1.0 | Average | 4.0 | Cat_6 | C |
#Last 10 rows
train_data.tail(10)
 | ID | Gender | Ever_Married | Age | Graduated | Profession | Work_Experience | Spending_Score | Family_Size | Var_1 | Segmentation
---|---|---|---|---|---|---|---|---|---|---|---|
8058 | 460674 | Female | No | 31 | Yes | Entertainment | 0.0 | Low | 3.0 | Cat_3 | A |
8059 | 460132 | Male | No | 39 | Yes | Healthcare | 3.0 | Low | 2.0 | Cat_6 | D |
8060 | 463613 | Female | Yes | 48 | Yes | Artist | 0.0 | Average | 6.0 | Cat_6 | A |
8061 | 465231 | Male | Yes | 65 | No | Artist | 0.0 | Average | 2.0 | Cat_6 | C |
8062 | 463002 | Male | Yes | 41 | Yes | Artist | 0.0 | High | 5.0 | Cat_6 | B |
8063 | 464018 | Male | No | 22 | No | NaN | 0.0 | Low | 7.0 | Cat_1 | D |
8064 | 464685 | Male | No | 35 | No | Executive | 3.0 | Low | 4.0 | Cat_4 | D |
8065 | 465406 | Female | No | 33 | Yes | Healthcare | 1.0 | Low | 1.0 | Cat_6 | D |
8066 | 467299 | Female | No | 27 | Yes | Healthcare | 1.0 | Low | 4.0 | Cat_6 | B |
8067 | 461879 | Male | Yes | 37 | Yes | Executive | 0.0 | Average | 3.0 | Cat_4 | B |
#random sample of 20 rows
train_data.sample(20)
 | ID | Gender | Ever_Married | Age | Graduated | Profession | Work_Experience | Spending_Score | Family_Size | Var_1 | Segmentation
---|---|---|---|---|---|---|---|---|---|---|---|
4441 | 464444 | Male | Yes | 36 | Yes | Artist | 8.0 | Average | 2.0 | Cat_6 | C |
6805 | 466859 | Male | No | 22 | No | Healthcare | 0.0 | Low | 5.0 | Cat_2 | D |
3013 | 462280 | Male | No | 32 | Yes | Healthcare | 0.0 | Low | 4.0 | Cat_6 | B |
6233 | 460252 | Male | Yes | 41 | Yes | Artist | 0.0 | Low | 2.0 | Cat_6 | A |
1428 | 467221 | Female | No | 20 | No | Doctor | NaN | Low | 4.0 | Cat_6 | C |
2624 | 464566 | Male | Yes | 35 | Yes | Executive | 0.0 | High | 4.0 | Cat_6 | A |
6080 | 466744 | Female | Yes | 52 | Yes | Artist | 1.0 | Average | 3.0 | Cat_6 | C |
4383 | 465479 | Male | No | 27 | No | Artist | 1.0 | Low | 1.0 | Cat_4 | D |
3846 | 465129 | Female | No | 61 | Yes | Entertainment | 3.0 | Low | 1.0 | Cat_6 | A |
2357 | 461693 | Female | Yes | 36 | Yes | Artist | 1.0 | Average | 4.0 | Cat_6 | B |
3884 | 462956 | Male | Yes | 42 | No | Artist | 7.0 | Average | 5.0 | Cat_6 | B |
7163 | 467168 | Male | Yes | 46 | Yes | Artist | 8.0 | Average | 2.0 | Cat_1 | B |
1180 | 464997 | Male | Yes | 43 | Yes | Artist | 5.0 | Average | 2.0 | Cat_6 | B |
4360 | 462525 | Female | Yes | 41 | Yes | Artist | 0.0 | Average | 2.0 | Cat_6 | C |
3642 | 464243 | Female | No | 49 | Yes | Artist | 1.0 | Low | 2.0 | Cat_6 | C |
1899 | 463025 | Male | Yes | 40 | No | Executive | 7.0 | Low | 4.0 | Cat_6 | A |
3508 | 460616 | Male | Yes | 32 | Yes | Homemaker | 2.0 | Average | 2.0 | Cat_3 | C |
1100 | 461129 | Female | Yes | 63 | Yes | Entertainment | 0.0 | Average | 4.0 | Cat_6 | B |
6915 | 464716 | Male | Yes | 43 | Yes | Marketing | 0.0 | High | 6.0 | Cat_4 | D |
303 | 466597 | Female | No | 19 | No | Healthcare | 1.0 | Low | 2.0 | Cat_6 | D |
Define a function to get an overview of the data
def data_overview(data, title):
    # Assemble counts of columns, rows, missing values, duplicates,
    # and variable types into a one-column summary frame.
    overview_analysis = {f'{title}': [data.shape[1], data.shape[0],
                                      data.isnull().any(axis=1).sum(),
                                      data.isnull().any(axis=1).sum()/len(data)*100,
                                      data.duplicated().sum(),
                                      data.duplicated().sum()/len(data)*100,
                                      sum((data.dtypes == 'object') & (data.nunique() > 2)),
                                      sum((data.dtypes == 'object') & (data.nunique() < 3)),
                                      data.select_dtypes(include=['int64', 'float64']).shape[1]]}
    overview_analysis = pd.DataFrame(overview_analysis,
                                     index=['Columns', 'Rows', 'Missing_Values', 'Missing_Values %',
                                            'Duplicates', 'Duplicates %', 'Categorical_variables',
                                            'Boolean_variables', 'Numerical_variables']).round(2)
    return overview_analysis
data_overview(train_data, "Data_Overview")
 | Data_Overview
---|---|
Columns | 11.00 |
Rows | 8068.00 |
Missing_Values | 1403.00 |
Missing_Values % | 17.39 |
Duplicates | 0.00 |
Duplicates % | 0.00 |
Categorical_variables | 4.00 |
Boolean_variables | 3.00 |
Numerical_variables | 4.00 |
Define a function to get an overview of the variables
def variables_overview1(data):
    # Per-column summary: distinct values, dtype, and missing-value counts.
    variable_details = {'unique': data.nunique(),
                        'dtype': data.dtypes,
                        'null': data.isna().sum(),
                        'null %': data.isna().sum()/len(data)*100}
    variable_details = pd.DataFrame(variable_details)
    return variable_details
variables_overview=variables_overview1(train_data)
variables_overview
 | unique | dtype | null | null %
---|---|---|---|---|
ID | 8068 | int64 | 0 | 0.000000 |
Gender | 2 | object | 0 | 0.000000 |
Ever_Married | 2 | object | 140 | 1.735250 |
Age | 67 | int64 | 0 | 0.000000 |
Graduated | 2 | object | 78 | 0.966782 |
Profession | 9 | object | 124 | 1.536936 |
Work_Experience | 15 | float64 | 829 | 10.275161 |
Spending_Score | 3 | object | 0 | 0.000000 |
Family_Size | 9 | float64 | 335 | 4.152206 |
Var_1 | 7 | object | 76 | 0.941993 |
Segmentation | 4 | object | 0 | 0.000000 |
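The overview shows that Work_Experience (about 10% null) and Family_Size (about 4% null) account for most of the missing data. One common way to handle this before modelling is simple imputation: the median for numeric columns, the mode for categorical ones. The helper below is an illustration of that idea, not necessarily the preprocessing used later in the post.

```python
import pandas as pd

def impute_simple(df):
    # Fill object columns with their mode and numeric columns with
    # their median; returns a copy, leaving the input untouched.
    df = df.copy()
    for col in df.columns:
        if df[col].dtype == "object":
            df[col] = df[col].fillna(df[col].mode()[0])
        else:
            df[col] = df[col].fillna(df[col].median())
    return df
```

Median/mode imputation is crude (it ignores relationships between columns), but it keeps all 8068 rows instead of dropping the roughly 17% that contain at least one null.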
SUMMARY STATISTICS
Compute summary statistics for the numerical columns in the DataFrame.
train_data.describe()
 | ID | Age | Work_Experience | Family_Size
---|---|---|---|---|
count | 8068.000000 | 8068.000000 | 7239.000000 | 7733.000000 |
mean | 463479.214551 | 43.466906 | 2.641663 | 2.850123 |
std | 2595.381232 | 16.711696 | 3.406763 | 1.531413 |
min | 458982.000000 | 18.000000 | 0.000000 | 1.000000 |
25% | 461240.750000 | 30.000000 | 0.000000 | 2.000000 |
50% | 463472.500000 | 40.000000 | 1.000000 | 3.000000 |
75% | 465744.250000 | 53.000000 | 4.000000 | 4.000000 |
max | 467974.000000 | 89.000000 | 14.000000 | 9.000000 |
Compute summary statistics for the categorical columns in the DataFrame.
train_data.describe(include='object')
 | Gender | Ever_Married | Graduated | Profession | Spending_Score | Var_1 | Segmentation
---|---|---|---|---|---|---|---|
count | 8068 | 7928 | 7990 | 7944 | 8068 | 7992 | 8068 |
unique | 2 | 2 | 2 | 9 | 3 | 7 | 4 |
top | Male | Yes | Yes | Artist | Low | Cat_6 | D |
freq | 4417 | 4643 | 4968 | 2516 | 4878 | 5238 | 2268 |
EXPLORATORY DATA ANALYSIS (EDA)
Visualise the customer segmentation count (target variable)
# Bar chart of segment frequencies, annotated with absolute and relative counts
ax = sns.countplot(x="Segmentation", data=train_data,
                   order=train_data["Segmentation"].value_counts().index)
abs_values = train_data['Segmentation'].value_counts().values
rel_values = train_data['Segmentation'].value_counts(normalize=True).values * 100
lbls = [f'{p[0]} ({p[1]:.0f}%)' for p in zip(abs_values, rel_values)]
ax.bar_label(container=ax.containers[0], labels=lbls)
ax.set_xlabel('Segmentation')
ax.set_ylabel('Count')
ax.set_title('Segmentation Count')
plt.show()
fig, ax = plt.subplots(1, 2, figsize=(9, 5))
train_data["Segmentation"].value_counts().plot.bar(color=['blue', 'orange', 'green', 'red'], ax=ax[0])
train_data["Segmentation"].value_counts().plot(kind='pie',autopct='%.2f%%',shadow=True, ax=ax[1])
centre_circle = plt.Circle((0,0),0.80,fc='white')
fig = plt.gcf()
fig.gca().add_artist(centre_circle)
ax[1].set_title("Segmentation Analysis")
ax[0].set_title("Segmentation Analysis")
ax[1].legend(title="Segmentation", bbox_to_anchor=(1.1, 1), labels=['D', 'A', 'C', 'B'])
for i, patch in enumerate(ax[0].patches):
count = train_data["Segmentation"].value_counts().iloc[i]
ax[0].annotate(str(count), xy=(patch.get_x() + patch.get_width() / 2, patch.get_height() + 1),
ha='center', va='center', fontsize=10)
plt.xticks(rotation=90)
plt.yticks(rotation=45)
plt.show()
col= ['Gender',
'Ever_Married',
'Graduated',
'Profession',
'Spending_Score',
'Var_1']
for j in col:
    plt.figure(figsize=(10, 6))
    plt.title(f'{j} Count by Segmentation')
    sns.countplot(data=train_data, x=j, hue='Segmentation')
    plt.show()
mode_table=pd.DataFrame(train_data.groupby('Segmentation')[['Family_Size', 'Age', 'Work_Experience', 'Gender', 'Ever_Married', 'Profession', 'Spending_Score']].agg(pd.Series.mode))
mode_table
Segmentation | Family_Size | Age | Work_Experience | Gender | Ever_Married | Profession | Spending_Score
---|---|---|---|---|---|---|---
A | 2.0 | 35 | 1.0 | Male | Yes | Artist | Low |
B | 2.0 | 43 | 1.0 | Male | Yes | Artist | Low |
C | 2.0 | 50 | 1.0 | Male | Yes | Artist | Average |
D | 4.0 | 22 | 0.0 | Male | No | Healthcare | Low |
mode_table.plot(kind='bar', figsize=(10, 6))
plt.xlabel('Segmentation')
plt.ylabel('Mode')
plt.title('Mode values by segmentation')
plt.legend(title='Variable', bbox_to_anchor=(1, 1))
plt.show()
# The mean is only defined for numeric columns, so restrict the groupby to those
mean_table = train_data.groupby('Segmentation')[['Family_Size', 'Age', 'Work_Experience']].mean()
mean_table
Segmentation | Family_Size | Age | Work_Experience
---|---|---|---
A | 2.439531 | 44.924949 | 2.874578 |
B | 2.696970 | 48.200215 | 2.378151 |
C | 2.974559 | 49.144162 | 2.240771 |
D | 3.232624 | 33.390212 | 3.021717 |
mean_table.plot(kind='bar', figsize=(10, 6))
plt.xlabel('Segmentation')
plt.ylabel('Mean')
plt.title('Mean values by segmentation')
plt.legend(title='Variable', bbox_to_anchor=(1, 1))
plt.show()
#Gender Count in each segmentation
train_data.groupby(['Segmentation','Gender'])[['Gender']].count()
Segmentation | Gender | Count
---|---|---
A | Female | 909 |
Male | 1063 | |
B | Female | 861 |
Male | 997 | |
C | Female | 922 |
Male | 1048 | |
D | Female | 959 |
Male | 1309 |
train_data.groupby(['Segmentation','Gender'])[['Gender']].count().plot(kind='barh')
plt.show()