Matplotlib: how to plot & chart with Python

In this post I’m going to show you how to use Matplotlib to create plots and charts in Python. I’ll show you how to use matplotlib with and without Pandas, which is heavily used in many Python for finance applications.

Get started with matplotlib

Install matplotlib

In order to start using matplotlib to visualize your data with Python, you first need to install matplotlib on your computer. To do that, open your command prompt and type in:

pip install matplotlib

Then press the “Enter” key on your keyboard. pip will download and install matplotlib. You don’t need to do anything else.
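To confirm the installation worked, you can import matplotlib and print its version number. The exact version you see depends on when you installed it.

```python
# Quick check that matplotlib is installed and importable
import matplotlib

print(matplotlib.__version__)  # prints the installed version string
```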

 

Start using matplotlib 

Before I show you how to visualize your data and plot graphs with matplotlib, we first need to import matplotlib into our code. Keep in mind that all the code here is written in a Jupyter Notebook. Other text editors will work too:

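# This will import matplotlib
import matplotlib as mpl
# This will import pyplot, the main interface for plotting
import matplotlib.pyplot as plt
# Jupyter-only "magic" command: display charts inside the notebook
# (remove this line outside Jupyter)
%matplotlib inline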

*For the time being we won’t use Seaborn. Seaborn builds on top of matplotlib by adding more types of charts and plots, and by giving matplotlib charts a more polished look.

Next, let’s look at how to use matplotlib to plot charts without Pandas:

 

Matplotlib without Pandas

Line chart (1 data set)

If you want to plot one-dimensional data, such as a list or an array, on a line chart, here’s how.

The basic matplotlib plotting function is plt.plot(). In its simplest form, it takes 2 arguments: x and y.

plt.plot(x, y)

  1. x = the x axis coordinates of your data
  2. y = the y axis coordinates of your data

x and y are lists or arrays, and they must contain the same number of values. For example:

plt.plot([2, 4, 6, 8], [3, 4, 7, 5])

*In this case, x = [2, 4, 6, 8] and y = [3, 4, 7, 5]

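Jupyter also prints the value that plt.plot() returns: [<matplotlib.lines.Line2D at 0x20113c8da08>]. This is a list containing the Line2D object (the 2D line) that was just drawn; the hexadecimal memory address will differ on your machine. We’ll look at 3D plots and visualizations in a later post.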
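That returned list is useful: it holds one Line2D object per line drawn, and you can keep a reference to it to read back the data or restyle the line after drawing it. A minimal sketch (the red color and line width are arbitrary choices):

```python
import matplotlib.pyplot as plt

plt.figure()  # start a fresh figure
lines = plt.plot([2, 4, 6, 8], [3, 4, 7, 5])

print(len(lines))  # one Line2D object per line drawn
line = lines[0]
print(line.get_xdata(), line.get_ydata())  # the coordinates the line holds

# Line2D objects can be restyled after they are drawn
line.set_color('red')
line.set_linewidth(2)
```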

More often than not, you’ll store the data in variables and pass those variables to plt.plot(). Here’s an example:

 
x = [2, 4, 6, 8]
y = [3, 4, 7, 5]
plt.plot(x, y) 

What if you only have 1 data set and don’t want to create a separate list for the x axis? You can set the x axis to the index of each data point using a list comprehension:

data = [3, 4, 7, 5, 5, 22, 1, 4, 5]
x_axis = [x for x in range(len(data))]
plt.plot(x_axis, data)

Alternatively, you don’t need to pass in any x values at all. If you leave them out, plt.plot() uses each y value’s index (0, 1, 2, and so on) as its x coordinate. Here’s the same example without any x axis values:

 
data = [3, 4, 7, 5, 5, 22, 1, 4, 5] 
plt.plot(data) 

As you can see, the chart looks the same. X axis values are optional in plt.plot().
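You can confirm that both versions produce the same x coordinates by comparing the line objects that plt.plot() returns. When no x values are given, matplotlib stores the index as floats (0.0, 1.0, …), which compare equal to the integers from range():

```python
import matplotlib.pyplot as plt

data = [3, 4, 7, 5, 5, 22, 1, 4, 5]

plt.figure()
# Explicit x values built from the index of each data point
explicit, = plt.plot(list(range(len(data))), data)
# No x values: matplotlib falls back to the index automatically
implicit, = plt.plot(data)

print(list(explicit.get_xdata()) == list(implicit.get_xdata()))  # True
```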

 

Line chart (more than 1 data set)

When using matplotlib and Python for finance/trading, you’ll often want to compare and overlap different charts. Here are the basics of plotting more than 1 data set on the same line chart in matplotlib.

First, you need 2 data sets, which means 2 lists/arrays of y axis values. Just as with a single data set, x axis values are optional when you plot more than 1 data set.

Then call plt.plot() once for each data set. Matplotlib draws both lines on the same chart. Here’s an example:

 
y = [3, 4, 7, 5] 
y2 = [5, 9, 12, 2]
plt.plot(y)
plt.plot(y2) 

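If you specify x values, each x list must have the same length as the y list it is paired with. The 2 data sets can use different x values (the x axis then stretches to cover both), but when you’re comparing them you’ll usually want them to share the same x values. For example: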

 
x = [2, 4, 6, 8]
x2 = [2, 4, 6, 8]
y = [3, 4, 7, 5] 
y2 = [5, 9, 12, 2]
plt.plot(x, y)
plt.plot(x2, y2) 
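What actually breaks a chart is a length mismatch between an x list and its own y list: matplotlib raises a ValueError. Different x values for the 2 data sets are fine. In this sketch the second data set’s x and y values are made up purely for illustration:

```python
import matplotlib.pyplot as plt

plt.figure()
plt.plot([2, 4, 6, 8], [3, 4, 7, 5])
# Illustrative second data set with different x values: still plots fine
plt.plot([1, 3, 5, 7, 9], [5, 9, 12, 2, 6])

# Mismatched lengths are what raise an error
mismatch_error = None
try:
    plt.plot([2, 4, 6, 8], [5, 9, 12])  # 4 x values, 3 y values
except ValueError as err:
    mismatch_error = err
    print('ValueError:', err)
```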

It can be hard to tell which data set is which just by looking at the chart. To fix that, label each data set with the label parameter and display a legend with the plt.legend() function. Using the same example:

 
x = [2, 4, 6, 8]
x2 = [2, 4, 6, 8] 
y = [3, 4, 7, 5] 
y2 = [5, 9, 12, 2] 
plt.plot(x, y, label='Data Set #1') 
plt.plot(x2, y2, label='Data Set #2') 
plt.legend()
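plt.legend() returns a Legend object, and its loc parameter controls where the legend goes (the default, 'best', lets matplotlib pick a spot that avoids the data). A short sketch that pins the legend to the upper left and reads back its entries:

```python
import matplotlib.pyplot as plt

x = [2, 4, 6, 8]
y = [3, 4, 7, 5]
y2 = [5, 9, 12, 2]

plt.figure()
plt.plot(x, y, label='Data Set #1')
plt.plot(x, y2, label='Data Set #2')
legend = plt.legend(loc='upper left')  # fixed position instead of 'best'

# Each legend entry comes from the label of its line
print([text.get_text() for text in legend.get_texts()])  # ['Data Set #1', 'Data Set #2']
```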

In a future post I’ll show you how to:

  1. Use different scales on the x-axis or y-axis with 2 data sets
  2. Visualize 2 data sets in different ways. For example, you may want to visualize Data Set #1 with a line graph and Data Set #2 with a bar graph.

 

*Special Note

Outside Jupyter Notebook, you need to call plt.show() to display the chart.

I haven’t done so in this tutorial because Jupyter Notebook displays charts automatically. But if you run your code from another text editor such as Atom, or as a script from the command line, make sure to call plt.show():

 
import matplotlib.pyplot as plt

data = [3, 4, 7, 5, 5, 22, 1, 4, 5]
x_axis = [x for x in range(len(data))]
plt.plot(x_axis, data)
plt.show()  # opens a window with the chart
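If you run a script somewhere no chart window can open (for example on a remote server), you can write the chart to an image file with plt.savefig() instead. The file name line_chart.png is only an example:

```python
from pathlib import Path

import matplotlib.pyplot as plt

data = [3, 4, 7, 5, 5, 22, 1, 4, 5]
output_path = Path('line_chart.png')  # example file name

plt.figure()
plt.plot(data)
# Save before calling plt.show(): after the window closes, the figure is gone
plt.savefig(output_path, dpi=150)
print(f'Saved chart to {output_path.resolve()}')
```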

 

Bar chart

To create a bar chart, call plt.bar() instead of plt.plot() (which we used for the line chart). Using the same example:

 
x = [2, 4, 6, 8] 
y = [3, 4, 7, 5] 
plt.bar(x, y) 
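By default each bar is 0.8 units wide, which looks thin when the x values are 2 units apart as they are here. The width parameter changes that, and plt.bar() returns a container of the bars. This sketch uses plt.bar_label() (available in matplotlib 3.4 and later) to print each bar’s height on top of it; the width of 1.5 is an arbitrary choice:

```python
import matplotlib.pyplot as plt

x = [2, 4, 6, 8]
y = [3, 4, 7, 5]

plt.figure()
bars = plt.bar(x, y, width=1.5)  # wider bars fill more of the 2-unit gap
plt.bar_label(bars)              # write each bar's height above it
```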

 

Histogram

To create a histogram that shows the distribution of a data set, use the .hist() function. You pass in your data set as a list, along with the “bins” that group the data into bars. E.g. say we have a data set of 25 people’s weights (in pounds). We can split the distribution of weights into 6 bins of 20 pounds each, from 100 to 220.

 
peoples_weight = [120, 210, 140, 153, 193, 124, 168, 187, 179, 193, 215, 152, 194, 205, 102, 153, 195, 184, 179, 143, 153, 193, 187, 173, 173]
bins = [100, 120, 140, 160, 180, 200, 220]
plt.hist(peoples_weight, bins, histtype='bar')

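*Keep in mind that each bin includes its left edge but excludes its right edge, except the last bin, which includes both. E.g. the bin from 100 (inclusive) to 120 (exclusive) counts every value from 100 up to, but not including, 120, while the last bin (200 to 220) also counts values equal to 220.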
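plt.hist() also returns the count in each bin, which is a quick way to check where each value landed. The last line below calls NumPy’s np.histogram() (which matplotlib uses for the binning) on just 2 edge values, to show that 120 falls in the second bin and 220 is counted in the last one:

```python
import numpy as np
import matplotlib.pyplot as plt

peoples_weight = [120, 210, 140, 153, 193, 124, 168, 187, 179, 193, 215, 152, 194, 205, 102, 153, 195, 184, 179, 143, 153, 193, 187, 173, 173]
bins = [100, 120, 140, 160, 180, 200, 220]

plt.figure()
counts, edges, patches = plt.hist(peoples_weight, bins, histtype='bar')
print(counts)  # [1. 2. 6. 5. 8. 3.] people per 20-pound bin

# Bins are [low, high) except the last one, which is [200, 220]
print(np.histogram([120, 220], bins=bins)[0])  # [0 1 0 0 0 1]
```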

If you want a “gap” between the bars in the histogram, pass the rwidth parameter into the .hist() function. rwidth is the width of each bar as a fraction of its bin width, so rwidth=0.7 leaves a 30% gap. Here’s how:

 
peoples_weight = [120, 210, 140, 153, 193, 124, 168, 187, 179, 193, 215, 152, 194, 205, 102, 153, 195, 184, 179, 143, 153, 193, 187, 173, 173]
bins = [100, 120, 140, 160, 180, 200, 220]
plt.hist(peoples_weight, bins, histtype='bar', rwidth=0.7)

 

Scatterplot

To create a scatterplot, you need to use the .scatter() function and pass in 2 lists of data.

  1. List #1 will be the x axis coordinates of your data
  2. List #2 will be the y axis coordinates of your data

Here’s an example:

x = [1, 5, 9, 2, 6, 2, 3, 6, 7]
y = [8, 2, 4, 6, 1, 9, 4, 6, 2]
plt.scatter(x, y)
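plt.scatter() also accepts optional styling parameters such as s (marker size, in points squared), c (color) and marker. A sketch with arbitrary styling choices, plus axis labels via plt.xlabel() and plt.ylabel():

```python
import matplotlib.pyplot as plt

x = [1, 5, 9, 2, 6, 2, 3, 6, 7]
y = [8, 2, 4, 6, 1, 9, 4, 6, 2]

plt.figure()
points = plt.scatter(x, y, s=80, c='darkorange', marker='^')
plt.xlabel('x value')
plt.ylabel('y value')

# Each row of the offsets array is one (x, y) point
print(points.get_offsets().shape)  # (9, 2)
```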

 

Stackplot

Stackplots are often used in finance when representing market share (as a % of the total market) over time. By using a stackplot, the user can more easily see and understand how market share has changed over time. We can create a stackplot in matplotlib with the .stackplot() function.

years = [1880, 1900, 1920, 1940, 1960, 1980, 2000, 2020]
finance = [20, 30, 50, 40, 40, 40, 30, 20]
tech = [5, 10, 10, 20, 30, 20, 40, 50]
energy = [5, 10, 15, 10, 10, 5, 10, 5]
consumer_staples = [70, 50, 25, 30, 20, 35, 20, 25]
plt.stackplot(years, finance, tech, energy, consumer_staples)

The first argument (years in this example) sets the x axis values.

Next, pass in the data sets you want to stack. Each data set should be a list with one value per x value. In this example I’ve included the data sets finance, tech, energy, and consumer_staples.
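Because the layers sit on top of each other, a stackplot is hard to read without a legend. plt.stackplot() accepts a labels parameter (one label per data set, chosen freely here) that plt.legend() then picks up. In this example every year’s values add up to 100, so the top of the stack sits at 100%, which zip() lets you check year by year:

```python
import matplotlib.pyplot as plt

years = [1880, 1900, 1920, 1940, 1960, 1980, 2000, 2020]
finance = [20, 30, 50, 40, 40, 40, 30, 20]
tech = [5, 10, 10, 20, 30, 20, 40, 50]
energy = [5, 10, 15, 10, 10, 5, 10, 5]
consumer_staples = [70, 50, 25, 30, 20, 35, 20, 25]

plt.figure()
layers = plt.stackplot(years, finance, tech, energy, consumer_staples,
                       labels=['Finance', 'Tech', 'Energy', 'Consumer staples'])
plt.legend(loc='upper left')
plt.ylabel('Market share (%)')

# Sum the four sectors for each year
print([sum(year_values) for year_values in zip(finance, tech, energy, consumer_staples)])
```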

 

Pie charts

You can create a pie chart using matplotlib with the .pie() function. All you need is two variables:

  1. Variable #1 = a list of names for each portion of the pie chart
  2. Variable #2 = a list of weights, one for each portion of the pie chart. Don’t worry if the weights don’t add up to 100: matplotlib sizes each slice as a fraction of their total.

Then pass Variable #2 (the weights) as the first argument to the .pie() function, and set the labels parameter to Variable #1 (the list of names for each portion of the pie chart).

Here’s an example:

 
company_names = ['Google', 'Facebook', 'Amazon', 'Netflix']
revenue = [20000, 18000, 25000, 7000]
plt.pie(revenue, labels=company_names)


If you want to show the percentage weight of each slice in the pie chart, include the autopct parameter. Its value is a format string: '%.2f' shows each percentage with 2 decimal places.

 
company_names = ['Google', 'Facebook', 'Amazon', 'Netflix']
revenue = [20000, 18000, 25000, 7000]
plt.pie(revenue, labels=company_names, autopct='%.2f')
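Because autopct is a format string applied to each slice’s percentage, doubling the percent sign adds a literal % to the label. When you pass autopct, plt.pie() also returns the percentage labels, so you can see how the slices are computed from the total revenue (70000 here):

```python
import matplotlib.pyplot as plt

company_names = ['Google', 'Facebook', 'Amazon', 'Netflix']
revenue = [20000, 18000, 25000, 7000]

plt.figure()
wedges, name_labels, pct_labels = plt.pie(revenue, labels=company_names, autopct='%.1f%%')

# Google: 20000 / 70000 of the total ≈ 28.6%
print([text.get_text() for text in pct_labels])  # ['28.6%', '25.7%', '35.7%', '10.0%']
```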

 

Matplotlib with Pandas

When trying to plot charts in Python, you’ll probably want to use Matplotlib with Pandas. After all, who wouldn’t want to visualize a trading algorithm that they just built in Pandas?

Basics

Plotting a chart with Pandas and matplotlib is easy. Just call the .plot() function on the DataFrame columns that you want to plot.

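# This will import matplotlib
import matplotlib as mpl

# This will import pyplot, the main interface for plotting
import matplotlib.pyplot as plt
# Jupyter-only: display charts inside the notebook
%matplotlib inline

# Import pandas
import pandas as pd

# Read the CSV file (replace this path with the path to your own file)
df = pd.read_csv('C:\\Users\\vaish\\Desktop\\py4e\\somedata.csv')
# Remove stray spaces around the column names, e.g. ' Gold' -> 'Gold'
df.columns = df.columns.str.strip()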
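One read_csv() option worth knowing about here is parse_dates, which turns the 'Date' column into real datetime values so pandas can format the x axis as dates. This sketch reads a tiny in-memory CSV instead of a file, with spaces after the commas in the header to show what .str.strip() fixes. The dates and prices are invented placeholders:

```python
import io

import pandas as pd

# Hypothetical CSV text (invented values), with spaces after the header commas
csv_text = """Date, S&P 500, Gold
2020-01-02,100,50
2020-01-03,102,51
2020-01-06,101,53
"""

df = pd.read_csv(io.StringIO(csv_text), parse_dates=['Date'])
df.columns = df.columns.str.strip()  # ' S&P 500' -> 'S&P 500'

print(df.columns.tolist())  # ['Date', 'S&P 500', 'Gold']
print(df['Date'].dtype)     # a datetime64 dtype
```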

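The examples below assume the CSV has a 'Date' column plus one column per market, such as 'S&P 500' and 'Gold'. In Jupyter you can preview the first rows with df.head().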

If you want to plot a line chart with a single column in the DataFrame, call .plot() on that column. For example:

 
df['S&P 500'].plot()

To plot more than 1 column within the same chart, pass in a list of the column names. For example, if you want to plot the ‘S&P 500’ and ‘Gold’ columns:

 
df[['S&P 500', 'Gold']].plot()
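Because pandas draws its charts with matplotlib, you can mix the two: create a matplotlib Axes yourself, hand it to pandas through the ax parameter, and keep using regular matplotlib methods on it. A sketch with a small DataFrame of invented values standing in for the CSV data:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in for the CSV data (values are invented)
df = pd.DataFrame({'S&P 500': [100, 102, 101, 105],
                   'Gold': [50, 51, 53, 52]})

fig, ax = plt.subplots()
df['S&P 500'].plot(ax=ax)  # pandas draws onto the Axes we created
df['Gold'].plot(ax=ax)
ax.set_ylabel('Price')     # regular matplotlib methods still work
ax.legend()                # legend entries come from the column names
```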

 

Make your chart more user friendly

Now let’s add some parameters to make your chart more user-friendly, readable, and visually appealing.

 

Changing the x-axis:

Calling .plot() on a DataFrame automatically uses the DataFrame’s index as the chart’s x-axis. After read_csv(), that index is just 0, 1, 2, and so on, unless you set it to something else, such as the ‘Date’ column. If you want to set your x-axis as ‘Date’, make it the index:

 
df = df.set_index('Date')  # run this once: afterwards 'Date' is the index, no longer a column
df[['S&P 500', 'Gold']].plot()
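If you’d rather not change the index, DataFrame.plot() also accepts x and y parameters that name the columns to use for each axis, leaving df itself untouched. A sketch with invented dates and values:

```python
import pandas as pd

# Hypothetical data (invented values)
df = pd.DataFrame({
    'Date': pd.to_datetime(['2020-01-02', '2020-01-03', '2020-01-06']),
    'S&P 500': [100, 102, 101],
    'Gold': [50, 51, 53],
})

ax = df.plot(x='Date', y=['S&P 500', 'Gold'])
print('Date' in df.columns)  # True: df still has its Date column
```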

The .plot() function has several parameters which can be used to customize your chart:

 

Kind

The kind parameter sets the type of chart you want. By default, .plot() draws a line chart, which is the same as .plot(kind='line').

Here’s an example:

 
# 'Date' is already the index from the step above, so don't set it again
df[['S&P 500', 'Gold']].plot(kind='line')

There are other kinds of charts you can use:

  • ‘bar’ for vertical bar plot
  • ‘hist’ for histogram
  • ‘pie’ for pie plot
  • ‘scatter’ for scatter plot
  • ‘area’ for area plot
  • and much more

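Keep in mind that some kinds of plots need extra arguments. E.g. a ‘pie’ chart draws a single column, so plotting 2 DataFrame columns as pies raises an error unless you pass subplots=True (one pie per column), and a ‘scatter’ plot needs you to name the x and y columns.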
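Here’s a sketch of those cases with a small DataFrame of invented values (the Q1 to Q4 index labels are made up too). Bar charts work directly with several columns, scatter plots name their columns, and pies need subplots=True:

```python
import pandas as pd

# Hypothetical data (invented values and index labels)
df = pd.DataFrame({'S&P 500': [100, 102, 101, 105],
                   'Gold': [50, 51, 53, 52]},
                  index=['Q1', 'Q2', 'Q3', 'Q4'])

# Bar charts handle several columns as they are
bar_ax = df[['S&P 500', 'Gold']].plot(kind='bar')

# Scatter: name the column for each axis
scatter_ax = df.plot(kind='scatter', x='S&P 500', y='Gold')

# Pie: 2 columns without subplots=True raises a ValueError
pie_error = None
try:
    df[['S&P 500', 'Gold']].plot(kind='pie')
except ValueError as err:
    pie_error = err
    print('ValueError:', err)

# One pie per column
pie_axes = df[['S&P 500', 'Gold']].plot(kind='pie', subplots=True, legend=False, figsize=(8, 4))
print(len(pie_axes))  # 2
```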

 

Title

Use the ‘title’ parameter to add a title to your plot. Here’s an example:

 
# 'Date' is already the index from the step above
df[['S&P 500', 'Gold']].plot(kind='line', title='2 Markets')

 

Grid

Use the ‘grid’ parameter to add gridlines to your chart. The ‘grid’ parameter takes either a True or False value. Here’s an example:

 
# 'Date' is already the index from the step above
df[['S&P 500', 'Gold']].plot(kind='line', title='2 Markets', grid=True)

 

Legend

Use the ‘legend’ parameter to add a legend to your chart. The ‘legend’ parameter takes either a True or False value, and is by default True. Here’s an example:

 
# 'Date' is already the index from the step above
df[['S&P 500', 'Gold']].plot(kind='line', title='2 Markets', grid=True, legend=True)

 

Log scale chart

Use the ‘logy’ parameter if you want the y-axis to use a log scale, and use the ‘logx’ parameter if you want the x-axis to use a log scale. Both parameters take a True or False value, and are by default False. Here’s an example of a log scale on the y-axis:

 
# 'Date' is already the index from the step above
df[['S&P 500', 'Gold']].plot(kind='line', title='2 Markets', grid=True, legend=True, logy=True)

Alternatively, you can use the ‘loglog’ parameter to set both the x and y axes to log scales at once.
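A log scale on the x-axis only makes sense for positive numeric x values, not dates, so this sketch uses a hypothetical DataFrame with a numeric index. After plotting, ax.get_xscale() and ax.get_yscale() confirm which axes are logarithmic:

```python
import pandas as pd

# Hypothetical data with a positive numeric index (invented values)
df = pd.DataFrame({'value': [1, 10, 100, 1000]}, index=[1, 2, 4, 8])

ax = df.plot(loglog=True)
print(ax.get_xscale(), ax.get_yscale())  # log log
```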

 

Size of the chart

Use the ‘figsize’ parameter to adjust the size (width and height) of the chart. ‘figsize’ takes a tuple (width, height) in inches. For example:

 
# 'Date' is already the index from the step above
df[['S&P 500', 'Gold']].plot(kind='line', title='2 Markets', grid=True, legend=True, logy=True, figsize=(10, 5))
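To check that each of these parameters took effect, you can inspect the matplotlib Axes object that .plot() returns. A sketch using an invented DataFrame in place of the CSV data:

```python
import pandas as pd

# Hypothetical stand-in for the CSV data (invented dates and values)
df = pd.DataFrame({
    'Date': pd.to_datetime(['2020-01-02', '2020-01-03', '2020-01-06', '2020-01-07']),
    'S&P 500': [100, 102, 101, 105],
    'Gold': [50, 51, 53, 52],
}).set_index('Date')

ax = df[['S&P 500', 'Gold']].plot(kind='line', title='2 Markets', grid=True,
                                  legend=True, logy=True, figsize=(10, 5))

print(ax.get_title())               # 2 Markets
print(ax.get_yscale())              # log
print(ax.figure.get_size_inches())  # [10.  5.]
print(ax.get_legend() is not None)  # True
```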
