Creating a scatter plot by category in Python
Creating scatter plots has become easy with Python. For datasets with a manageable number of categories, mapping each category to a color manually with a dictionary is recommended: it is intuitive, gives explicit control over the colors, and is easy to extend or debug.
The dataset used in this code snippet is available at vincentarelbundock.github.io/Rdatasets/datasets.html: scroll down to the datasets category and download the iris file (datasets/iris.csv), saving it as iris.csv in your working directory.
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
iris = pd.read_csv('iris.csv')
# assigning each category of species a color
Dict_colors = {'setosa': 'blue', 'versicolor': 'green', 'virginica': 'red'}
# map species to their respective colors
colors = iris['Species'].map(Dict_colors)
# creating a scatter plot
plt.scatter(iris['Petal.Length'], iris['Sepal.Length'], c=colors)
# showing the grid
plt.grid(True, linestyle='--', alpha=0.5)
# creating a custom legend using matplotlib.patches
legend_patches = [Patch(color=color, label=species) for species, color in Dict_colors.items()]
plt.legend(handles=legend_patches, title='species', loc='upper left')
# setting axis labels and title
plt.gca().set(xlabel='petal length', ylabel='sepal length', title='Scatter plot of sepal length vs petal length by species')
# displaying the plot (needed when running as a script)
plt.show()

Code Explanation
1. Importing Required Libraries
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
pandas: Used for reading and manipulating the dataset (iris.csv).
matplotlib.pyplot: Provides functions for creating and customizing visualizations.
matplotlib.patches.Patch: Allows the creation of custom legend elements, specifically colored boxes for categorical data (e.g., species).
2. Loading the Dataset
iris = pd.read_csv('iris.csv')
Purpose: Reads the Iris dataset (iris.csv) into a DataFrame named iris.
pd.read_csv(): Loads a CSV file into a DataFrame.
The dataset typically contains columns such as:
Sepal.Length: Length of the sepal (in cm).
Petal.Length: Length of the petal (in cm).
Species: Categorical column with three classes (setosa, versicolor, virginica).
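Before plotting, it can help to confirm that your copy of the file has the columns the plot relies on. The sketch below parses a tiny inline CSV with pd.read_csv through io.StringIO; the three rows are a hypothetical stand-in for iris.csv (shaped like iris measurements, not read from the real file, which has 150 rows). With the real file you would pass 'iris.csv' instead of the StringIO object.

```python
import io

import pandas as pd

# Hypothetical stand-in for iris.csv: same column names the tutorial uses,
# but only three illustrative rows.
csv_text = """Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species
5.1,3.5,1.4,0.2,setosa
7.0,3.2,4.7,1.4,versicolor
6.3,3.3,6.0,2.5,virginica
"""

iris = pd.read_csv(io.StringIO(csv_text))

# Fail early if a column used by the plot is missing or renamed.
required = {"Sepal.Length", "Petal.Length", "Species"}
missing_cols = required - set(iris.columns)
if missing_cols:
    raise KeyError(f"iris.csv is missing columns: {sorted(missing_cols)}")

print(iris.dtypes)
print(sorted(iris["Species"].unique()))
```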
3. Assigning Colors to Species
Dict_colors = {'setosa': 'blue', 'versicolor': 'green', 'virginica': 'red'}
Creates a dictionary mapping each species to a specific color.
'setosa': Blue.
'versicolor': Green.
'virginica': Red.
This ensures that each species is represented consistently in the scatter plot.
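Because every color comes from this one dictionary, a quick check that it covers every category actually present in the data makes the "easy to debug" claim concrete. A minimal sketch, using a small hypothetical Series in place of iris['Species']; the extra "hybrid" category is likewise hypothetical and only illustrates how the mapping is extended.

```python
import pandas as pd

Dict_colors = {"setosa": "blue", "versicolor": "green", "virginica": "red"}

# Hypothetical stand-in for iris['Species'].
species = pd.Series(["setosa", "versicolor", "virginica", "setosa"])

# Categories present in the data but absent from the dictionary.
unmapped = set(species.unique()) - set(Dict_colors)
if unmapped:
    raise ValueError(f"No color assigned for: {sorted(unmapped)}")

# Extending the mapping for a new category is a one-line change.
Dict_colors["hybrid"] = "purple"  # hypothetical extra category
print(Dict_colors)
```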
4. Mapping Colors to Data Points
colors = iris['Species'].map(Dict_colors)
Purpose: Maps each value in the Species column of the dataset to its corresponding color from Dict_colors.
map(): A pandas method that applies a function or mapping (a dictionary in this case) to a Series.
If the Species value is 'setosa', it assigns 'blue'.
Similarly, 'versicolor' gets 'green', and 'virginica' gets 'red'.
Result: A Series of color values ('blue', 'green', 'red'), one per row, matching the species.
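One behavior of Series.map worth knowing: when given a dictionary, it returns NaN for any value that is not a key, such as a misspelled or unexpected species name. The sketch below uses a hypothetical column with one deliberately miscapitalized value to show how such rows can be found before plotting.

```python
import pandas as pd

Dict_colors = {"setosa": "blue", "versicolor": "green", "virginica": "red"}

# Hypothetical column with one miscapitalized value.
species = pd.Series(["setosa", "virginica", "Versicolor"])

colors = species.map(Dict_colors)
print(colors.tolist())  # ['blue', 'red', nan]

# Rows whose species had no entry in the dictionary.
bad_rows = species[colors.isna()]
print(bad_rows.tolist())  # ['Versicolor']
```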
5. Creating the Scatter Plot
plt.scatter(iris['Petal.Length'], iris['Sepal.Length'], c=colors)
plt.scatter(): Plots points on a 2D plane.
x-axis: iris['Petal.Length'] (petal length values).
y-axis: iris['Sepal.Length'] (sepal length values).
c=colors: Colors the points based on the mapped species colors.
Visualization: Each point represents a flower, and its position is determined by petal and sepal length, while its color indicates the species.
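For comparison, a different technique from the dictionary-plus-Patch approach used here: call scatter once per species with a label, so that legend() builds the legend from the plotted points themselves. This sketch assumes the non-interactive Agg backend and a small hypothetical stand-in DataFrame, so it runs without a display or the CSV file.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this line in a notebook or desktop session

import matplotlib.pyplot as plt
import pandas as pd

Dict_colors = {"setosa": "blue", "versicolor": "green", "virginica": "red"}

# Hypothetical stand-in for the iris DataFrame.
iris = pd.DataFrame({
    "Petal.Length": [1.4, 1.3, 4.7, 4.5, 6.0, 5.1],
    "Sepal.Length": [5.1, 4.9, 7.0, 6.4, 6.3, 5.8],
    "Species": ["setosa", "setosa", "versicolor", "versicolor", "virginica", "virginica"],
})

fig, ax = plt.subplots()
# One scatter call per species; each label is picked up by ax.legend().
for species, group in iris.groupby("Species"):
    ax.scatter(group["Petal.Length"], group["Sepal.Length"],
               color=Dict_colors[species], label=species)
ax.legend(title="species", loc="upper left")
```

The dictionary still controls the colors; the difference is that no proxy artists are needed, at the cost of one scatter call per category.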
6. Showing the Grid
plt.grid(True, linestyle='--', alpha=0.5)
Purpose: Adds a grid to the plot to make it easier to read.
Parameters:
True: Enables the grid.
linestyle='--': Draws dashed grid lines.
alpha=0.5: Makes the grid lines semi-transparent for a cleaner look.
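The same grid settings can be applied to an explicit Axes object. One detail the original code leaves at the default: with matplotlib's default axes.axisbelow setting, grid lines can be drawn over scatter markers, and ax.set_axisbelow(True) keeps them behind the points. A sketch with a few hypothetical points and the headless Agg backend:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for scripts and tests

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# Hypothetical points, one per species color.
ax.scatter([1.4, 4.7, 6.0], [5.1, 7.0, 6.3], c=["blue", "green", "red"])

# Equivalent of plt.grid(True, linestyle='--', alpha=0.5) on an explicit Axes.
ax.grid(True, linestyle="--", alpha=0.5)
# Draw the grid behind the markers instead of over them.
ax.set_axisbelow(True)
```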
7. Creating a Custom Legend
legend_patches = [Patch(color=color, label=species) for species, color in Dict_colors.items()]
plt.legend(handles=legend_patches, title='species', loc='upper left')
Purpose: Adds a custom legend to explain the color-coding of species.
Patch: Creates colored boxes that match the colors in the scatter plot.
color: The color of the box (e.g., blue, green, red).
label: The species name (e.g., 'setosa').
handles: The list of Patch objects passed to the legend.
title: Adds a title (species) to the legend.
loc='upper left': Positions the legend in the upper-left corner of the plot.
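Patch draws square swatches, while the plot itself uses round markers. If you would rather have the legend mirror the points, one option (a substitution for Patch, not what the original code does) is to use matplotlib.lines.Line2D proxy handles with a marker and no line. A sketch on an otherwise empty figure with the Agg backend:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for scripts and tests

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

Dict_colors = {"setosa": "blue", "versicolor": "green", "virginica": "red"}

fig, ax = plt.subplots()

# Proxy artists: one round marker per species and no connecting line,
# so each legend entry looks like a scatter point rather than a box.
legend_markers = [
    Line2D([], [], marker="o", linestyle="none", color=color, label=species)
    for species, color in Dict_colors.items()
]
ax.legend(handles=legend_markers, title="species", loc="upper left")
```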
8. Setting Labels and Title
plt.gca().set(xlabel='petal length', ylabel='sepal length', title='Scatter plot of sepal length vs petal length by species')
plt.gca(): Retrieves the current Axes object (the plot's coordinate system).
.set(): Sets multiple properties of the axes in one call:
xlabel: Label for the x-axis ('petal length').
ylabel: Label for the y-axis ('sepal length').
title: Title of the plot ('Scatter plot of sepal length vs petal length by species').
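Axes.set(...) is shorthand for calling the individual setters. The sketch below sets the same labels both ways on an explicit Figure and Axes from plt.subplots() (an object-oriented alternative to plt.gca(), assumed here with the Agg backend) and confirms that plt.gca() returns that same Axes.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for scripts and tests

import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# One call setting several Axes properties at once, as in the tutorial.
ax.set(xlabel="petal length", ylabel="sepal length",
       title="Scatter plot of sepal length vs petal length by species")

# The equivalent individual setter calls.
ax.set_xlabel("petal length")
ax.set_ylabel("sepal length")
ax.set_title("Scatter plot of sepal length vs petal length by species")

# plt.gca() returns this same Axes, because it is the current one.
assert plt.gca() is ax
```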