Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share

numpy - understanding matplotlib.subplots python



1 Answer


The different return types come from the squeeze keyword argument of plt.subplots(), which is set to True by default. Let's extend the documentation with the corresponding unpacking patterns:

squeeze : bool, optional, default: True

  • If True, extra dimensions are squeezed out from the returned Axes object:

    • if only one subplot is constructed (nrows=ncols=1), the resulting single Axes object is returned as a scalar.
      fig, ax = plt.subplots()
    • for Nx1 or 1xN subplots (N>1), the returned object is a 1D numpy object array of Axes objects.
      fig, (ax1, ..., axN) = plt.subplots(nrows=N, ncols=1) (for Nx1)
      fig, (ax1, ..., axN) = plt.subplots(nrows=1, ncols=N) (for 1xN)
    • for NxM subplots with N>1 and M>1, the Axes objects are returned as a 2D array.
      fig, ((ax11, ..., ax1M), ..., (axN1, ..., axNM)) = plt.subplots(nrows=N, ncols=M)
  • If False, no squeezing at all is done: the returned Axes object is always a 2D array containing Axes instances, even if it ends up being 1x1.
    fig, ((ax,),) = plt.subplots(nrows=1, ncols=1, squeeze=False)
    fig, ((ax1,), ..., (axN,)) = plt.subplots(nrows=N, ncols=1, squeeze=False) (for Nx1)
    fig, ((ax1, ..., axN),) = plt.subplots(nrows=1, ncols=N, squeeze=False) (for 1xN)
    fig, ((ax11, ..., ax1M), ..., (axN1, ..., axNM)) = plt.subplots(nrows=N, ncols=M, squeeze=False) (for NxM)
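To check these return types yourself, you can inspect the second value returned by plt.subplots() for a few grid shapes. This is a small sketch, assuming the non-interactive Agg backend so it runs without a display; the helper name axes_layout is hypothetical and only used for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
from matplotlib.axes import Axes


def axes_layout(nrows, ncols, squeeze=True):
    """Hypothetical helper: describe the second value returned by plt.subplots().

    Returns the string "Axes" for a bare Axes object, otherwise the shape
    of the numpy object array holding the Axes.
    """
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=squeeze)
    plt.close(fig)  # only the return type matters here
    if isinstance(axes, Axes):
        return "Axes"
    return axes.shape


for n, m in [(1, 1), (3, 1), (1, 3), (2, 3)]:
    print(f"{n}x{m}: squeeze=True -> {axes_layout(n, m)}, "
          f"squeeze=False -> {axes_layout(n, m, squeeze=False)}")

# Expected output:
# 1x1: squeeze=True -> Axes, squeeze=False -> (1, 1)
# 3x1: squeeze=True -> (3,), squeeze=False -> (3, 1)
# 1x3: squeeze=True -> (3,), squeeze=False -> (1, 3)
# 2x3: squeeze=True -> (2, 3), squeeze=False -> (2, 3)
```

The last line is the case where squeeze has nothing to remove, so both settings agree.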

Alternatively, you can always use the packed version

fig, ax_arr = plt.subplots(nrows=N, ncols=M, squeeze=False)

and index the array to obtain the individual Axes, e.g. ax_arr[1, 2].plot(...).
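Because squeeze=False always yields a 2D array, code that indexes ax_arr[i, j] works unchanged for any grid size, including 1x1. A minimal sketch, assuming the Agg backend; the function name plot_grid and the plotted data are made up for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt
import numpy as np


def plot_grid(nrows, ncols):
    """Hypothetical helper: draw one curve per subplot for any grid shape."""
    fig, ax_arr = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    x = np.linspace(0, 1, 50)
    for i in range(nrows):
        for j in range(ncols):
            # ax_arr is always 2D here, so two indices are always valid
            ax_arr[i, j].plot(x, x ** (i + j + 1))
            ax_arr[i, j].set_title(f"ax_arr[{i}, {j}]")
    return fig, ax_arr


fig, ax_arr = plot_grid(1, 1)  # even a single subplot is ax_arr[0, 0]
print(ax_arr.shape)            # (1, 1)
plt.close(fig)
```

The design choice here is to trade the convenience of a bare Axes for a uniform indexing scheme, which helps when the grid size is only known at run time.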

So for a 2x3 grid it makes no difference whether you set squeeze to False: the result is a 2D array either way. You may unpack it as

fig, ((ax1, ax2, ax3),(ax4, ax5, ax6)) = plt.subplots(nrows=2, ncols=3)

to bind each ax{i} to a matplotlib Axes object, or you may use the packed version

import matplotlib.pyplot as plt
fig, ax_arr = plt.subplots(nrows=2, ncols=3)
ax_arr[0, 0].plot([0, 1], [0, 1])  # plot to the top-left Axes
ax_arr[1, 2].plot([0, 1], [1, 0])  # plot to the bottom-right Axes
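The unpacked and packed forms refer to the very same Axes objects, so you can also mix them. A short sketch, assuming the Agg backend, that unpacks a packed 2x3 array into ax1 ... ax6, walks it row by row with the numpy .flat iterator, and confirms that squeeze=False leaves the shape unchanged for this grid:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend
import matplotlib.pyplot as plt

# Packed form: one 2x3 numpy object array of Axes
fig, ax_arr = plt.subplots(nrows=2, ncols=3)

# Unpacking that same array gives the ax1 ... ax6 names used in the unpacking example
(ax1, ax2, ax3), (ax4, ax5, ax6) = ax_arr
print(ax1 is ax_arr[0, 0], ax6 is ax_arr[1, 2])  # True True

# .flat visits the Axes row by row: ax1, ax2, ..., ax6
for k, ax in enumerate(ax_arr.flat, start=1):
    ax.set_title(f"ax{k}")

# For a 2x3 grid squeeze has nothing to remove, so the shape is the same
fig2, ax_arr2 = plt.subplots(nrows=2, ncols=3, squeeze=False)
print(ax_arr.shape, ax_arr2.shape)  # (2, 3) (2, 3)
```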
