Subplots

In the previous section, "Styling your plots", we set the title of a plot using a bit of matplotlib code. We did this by grabbing the plot's underlying matplotlib axis and calling set_title on it.

In this section we'll explore another matplotlib-based stylistic feature: subplotting.

In [1]:
import pandas as pd
reviews = pd.read_csv("../input/wine-reviews/winemag-data_first150k.csv", index_col=0)
reviews.head(3)
Out[1]:
country description designation points price province region_1 region_2 variety winery
0 US This tremendous 100% varietal wine hails from ... Martha's Vineyard 96 235.0 California Napa Valley Napa Cabernet Sauvignon Heitz
1 Spain Ripe aromas of fig, blackberry and cassis are ... Carodorum Selección Especial Reserva 96 110.0 Northern Spain Toro NaN Tinta de Toro Bodega Carmen Rodríguez
2 US Mac Watson honors the memory of a wine once ma... Special Selected Late Harvest 96 90.0 California Knights Valley Sonoma Sauvignon Blanc Macauley

Subplotting

Subplotting is a technique for creating multiple plots that live side-by-side in one overall figure. We can use the matplotlib subplots function to create a figure with multiple subplots. Its first two arguments define the grid: the first sets the number of rows, the second the number of columns.

In [2]:
import matplotlib.pyplot as plt
fig, axarr = plt.subplots(2, 1, figsize=(12, 8))

Since we asked for a subplots(2, 1), we got a figure with two rows and one column.
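As a small sketch of how the two grid arguments map onto the result, the snippet below requests two different layouts and inspects the shape of the returned axes array. The variable names and the Agg backend line are choices made for this illustration, not part of the original notebook.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt

# Two rows, one column: the axes come back as a 1-D array of length 2.
fig_col, axes_col = plt.subplots(nrows=2, ncols=1, figsize=(12, 8))
print(axes_col.shape)  # (2,)

# One row, three columns: still 1-D, but now of length 3.
fig_row, axes_row = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))
print(axes_row.shape)  # (3,)
```

Passing the counts by keyword (nrows, ncols) is equivalent to passing them positionally, and can make the intended layout easier to read.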

Let's break this down a bit. When pandas generates a bar chart, behind the scenes here is what it actually does:

  1. Generate a new matplotlib Figure object.
  2. Create a new matplotlib AxesSubplot object, and assign it to the Figure.
  3. Use AxesSubplot methods to draw the information on the screen.
  4. Return the result to the user.
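The four steps above can be sketched by hand in plain matplotlib. This is only an illustration of the sequence, not pandas' actual source code; the helper name bar_chart_sketch and the sample data are made up for the example.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt

def bar_chart_sketch(labels, heights):
    """Hypothetical helper that mirrors the four steps listed above."""
    fig = plt.figure()                      # 1. generate a new Figure
    ax = fig.add_subplot(1, 1, 1)           # 2. create an axes object on that Figure
    positions = list(range(len(labels)))
    ax.bar(positions, heights)              # 3. draw using the axes object's methods
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    return ax                               # 4. return the result to the caller

# Made-up values, purely for illustration.
ax = bar_chart_sketch(["US", "Spain", "Italy"], [3, 1, 2])
```

Because the helper returns the axes object, the caller can keep styling it afterwards, for example with ax.set_title, just as with the object pandas returns.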

In a similar way, our subplots operation above created one overall Figure with two AxesSubplots stacked vertically inside of it.

subplots returns two things, a figure (which we assigned to fig) and an array of the axes contained therein (which we assigned to axarr). Here are the axarr contents:

In [3]:
axarr
Out[3]:
array([<matplotlib.axes._subplots.AxesSubplot object at 0x7fd13831e828>,
       <matplotlib.axes._subplots.AxesSubplot object at 0x7fd1382e9b70>],
      dtype=object)

To tell pandas which subplot a new plot should go in (the first one or the second one), we grab the proper axis out of the array and pass it to pandas via the ax parameter:

In [4]:
fig, axarr = plt.subplots(2, 1, figsize=(12, 8))

reviews['points'].value_counts().sort_index().plot.bar(
    ax=axarr[0]
)

reviews['province'].value_counts().head(20).plot.bar(
    ax=axarr[1]
)
Out[4]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fd13823d128>

We are of course not limited to a single column. We can create as many subplots as we want, in whatever configuration we need.

For example:

In [5]:
fig, axarr = plt.subplots(2, 2, figsize=(12, 8))

If there are multiple columns and multiple rows, as above, the axis array becomes two-dimensional, with one inner array per row:

In [6]:
axarr
Out[6]:
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x7fd138148e48>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x7fd12bf9c128>],
       [<matplotlib.axes._subplots.AxesSubplot object at 0x7fd12bf45780>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x7fd12bf6be10>]],
      dtype=object)

That means that to plot our data from earlier, we now need a row number, then a column number.

In [7]:
fig, axarr = plt.subplots(2, 2, figsize=(12, 8))

reviews['points'].value_counts().sort_index().plot.bar(
    ax=axarr[0][0]
)

reviews['province'].value_counts().head(20).plot.bar(
    ax=axarr[1][1]
)
Out[7]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fd12be18f98>

Notice that the bar plot of wine counts by points score is in the first row and first column (the [0][0] position), while the bar plot of wine counts by province of origin is in the second row and second column ([1][1]). The other two positions were never given a plot, so they remain empty.
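Since the axes come back as a NumPy array, there is more than one way to reach a given panel. The sketch below (the panel titles are placeholders chosen for this example) shows that chained and tuple-style indexing return the same axis, and that the array's flat iterator visits the grid row by row.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt

fig, axarr = plt.subplots(2, 2, figsize=(12, 8))

# Chained indexing and tuple indexing reach the same axes object.
same_axis = axarr[1][1] is axarr[1, 1]

# .flat walks the grid in row-major order: [0,0], [0,1], [1,0], [1,1].
for i, ax in enumerate(axarr.flat):
    ax.set_title("Panel {}".format(i))
```

Iterating over axarr.flat can be handy when every panel gets the same treatment and the row and column numbers do not matter.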

By combining subplots with the styles we learned in the last section, we can create appealing-looking panel displays.

In [8]:
fig, axarr = plt.subplots(2, 2, figsize=(12, 8))

reviews['points'].value_counts().sort_index().plot.bar(
    ax=axarr[0][0], fontsize=12, color='mediumvioletred'
)
axarr[0][0].set_title("Wine Scores", fontsize=18)

reviews['variety'].value_counts().head(20).plot.bar(
    ax=axarr[1][0], fontsize=12, color='mediumvioletred'
)
axarr[1][0].set_title("Wine Varieties", fontsize=18)

reviews['province'].value_counts().head(20).plot.bar(
    ax=axarr[1][1], fontsize=12, color='mediumvioletred'
)
axarr[1][1].set_title("Wine Origins", fontsize=18)

# Histogram of the prices themselves (not of how often each price occurs)
reviews['price'].plot.hist(
    ax=axarr[0][1], fontsize=12, color='mediumvioletred'
)
axarr[0][1].set_title("Wine Prices", fontsize=18)

plt.subplots_adjust(hspace=.3)

import seaborn as sns
sns.despine()

Why subplot?

Why are subplots useful?

Oftentimes, as part of the exploratory data visualization process, you will find yourself creating a large number of smaller charts, each probing one or a few specific aspects of the data. For example, suppose we're interested in comparing the scores for relatively common wines with those for relatively rare ones. In cases like this, it makes sense to combine the two plots we would produce into one visual "unit" for analysis and discussion.

When we combine subplots with the style attributes we explored in the previous notebook, this technique allows us to create extremely attractive and informative panel displays.

Finally, subplots are critically useful because they enable faceting. Faceting is the act of splitting the data by the values of one or more variables, plotting each piece in its own subplot, and combining those subplots into a single figure. So instead of one bar chart, we might have, say, four, arranged together in a grid.
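To make the idea concrete, here is a minimal hand-rolled faceting sketch built only from the subplots tools covered in this section. The toy DataFrame is made-up data standing in for a real dataset, and the choice of splitting by country is an assumption for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Made-up toy data, purely for illustration.
toy = pd.DataFrame({
    "country": ["US", "US", "US", "Spain", "Spain",
                "Italy", "Italy", "France", "France"],
    "points": [96, 90, 90, 96, 88, 91, 91, 95, 89],
})

countries = sorted(toy["country"].unique())  # four distinct values -> four panels
fig, axarr = plt.subplots(2, 2, figsize=(8, 8))

# One panel per country: each subplot shows the score counts for that subset only.
for ax, country in zip(axarr.flat, countries):
    subset = toy.loc[toy["country"] == country, "points"]
    subset.value_counts().sort_index().plot.bar(ax=ax, title=country)

plt.subplots_adjust(hspace=.3)
```

Writing the loop by hand works for a small, fixed number of panels, but the bookkeeping grows quickly with more categories or shared axes.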

The recommended way to perform faceting is to use the seaborn FacetGrid facility. This feature is explored in a separate section of this tutorial.

Exercises

Let's test ourselves by answering some questions about the plots we've used in this section. Once you have your answers, click the "Output" button below to show the correct answers.

  1. A matplotlib plot consists of a single X composed of one or more Y. What are X and Y?
  2. The subplots function takes which two parameters as input?
  3. The subplots function returns what two variables?
In [9]:
from IPython.display import HTML
HTML("""
<ol>
<li>The plot consists of one overall figure composed of one or more axes.</li>
<li>The subplots function takes the number of rows as the first parameter, and the number of columns as the second.</li>
<li>The subplots function returns a figure and an array of axes.</li>
</ol>
""")
Out[9]:
  1. The plot consists of one overall figure composed of one or more axes.
  2. The subplots function takes the number of rows as the first parameter, and the number of columns as the second.
  3. The subplots function returns a figure and an array of axes.

To put your design skills to the test, try forking this notebook and replicating the plots that follow. To see the answers, hit the "Input" button below to un-hide the code.

In [47]:
import pandas as pd
import matplotlib.pyplot as plt
pokemon = pd.read_csv("../input/pokemon/Pokemon.csv")
pokemon.head(3)
Out[47]:
# Name Type 1 Type 2 Total HP Attack Defense Sp. Atk Sp. Def Speed Generation Legendary
0 1 Bulbasaur Grass Poison 318 45 49 49 65 65 45 1 False
1 2 Ivysaur Grass Poison 405 60 62 63 80 80 60 1 False
2 3 Venusaur Grass Poison 525 80 82 83 100 100 80 1 False

(Hint: use figsize=(8, 8))

In [18]:
plt.subplots(2,1, figsize=(8, 8))
Out[18]:
(<Figure size 576x576 with 2 Axes>,
 array([<matplotlib.axes._subplots.AxesSubplot object at 0x7fd11b4b90b8>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x7fd11b419898>],
       dtype=object))
In [44]:
fig, axarr = plt.subplots(2, 1, figsize=(8, 8))

pokemon['Attack'].plot.hist(ax=axarr[0], title='Pokemon Attack Ratings')
pokemon['Defense'].plot.hist(ax=axarr[1], title='Pokemon Defense Ratings')
Out[44]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fd119764da0>
In [45]:
pokemon.describe()
Out[45]:
# Total HP Attack Defense Sp. Atk Sp. Def Speed Generation
count 800.000000 800.00000 800.000000 800.000000 800.000000 800.000000 800.000000 800.000000 800.00000
mean 362.813750 435.10250 69.258750 79.001250 73.842500 72.820000 71.902500 68.277500 3.32375
std 208.343798 119.96304 25.534669 32.457366 31.183501 32.722294 27.828916 29.060474 1.66129
min 1.000000 180.00000 1.000000 5.000000 5.000000 10.000000 20.000000 5.000000 1.00000
25% 184.750000 330.00000 50.000000 55.000000 50.000000 49.750000 50.000000 45.000000 2.00000
50% 364.500000 450.00000 65.000000 75.000000 70.000000 65.000000 70.000000 65.000000 3.00000
75% 539.250000 515.00000 80.000000 100.000000 90.000000 95.000000 90.000000 90.000000 5.00000
max 721.000000 780.00000 255.000000 190.000000 230.000000 194.000000 230.000000 180.000000 6.00000

Conclusion

In the previous section we explored some pandas/matplotlib style parameters. In this section, we dove a little deeper still by exploring subplots.

Together these two sections conclude our primer on style. Hopefully our plots will now be more legible and informative.

The next section is "Plotting with seaborn".