Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - Automated legend creation for 3D plot

I'm trying to update the function below so that it reports the cluster information in a legend:

color_names = ["red", "blue", "yellow", "black", "pink", "purple", "orange"]

def plot_3d_transformed_data(df, title, colors="red"):

    ax = plt.figure(figsize=(12, 10)).gca(projection='3d')
    #fig = plt.figure(figsize=(8, 8))
    #ax = fig.add_subplot(111, projection='3d')

    if type(colors) is np.ndarray:
        for cname, class_label in zip(color_names, np.unique(colors)):
            X_color = df[colors == class_label]
            ax.scatter(X_color[:, 0], X_color[:, 1], X_color[:, 2], marker="x", c=cname,
                       label=f"Cluster {class_label}" if type(colors) is np.ndarray else None)
    else:
        ax.scatter(df.Type, df.Length, df.Freq, alpha=0.6, c=colors, marker="x",
                   label=str(clusterSizes))

    ax.set_xlabel("PC1: Type")
    ax.set_ylabel("PC2: Length")
    ax.set_zlabel("PC3: Frequency")
    ax.set_title(title)

    if type(colors) is np.ndarray:
        #ax.legend()
        plt.gca().legend()

    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    plt.show()

I call the function to visualize the cluster patterns like this:

plot_3d_transformed_data(pdf_km_pred,
                         f'Clustering rare URL parameters for data of date: {DATE_FROM}\nMethod: KMeans over PCA\nn_clusters={n_clusters} , Distance_Measure={DistanceMeasure}',
                         colors=pdf_km_pred.prediction_km)

print(clusterSizes)

Unfortunately, I can't get the legend to show, so I have to print the cluster members manually under the 3D plot. This is the output, without a legend, together with the warning "No handles with labels found to put in legend.":
[Image: 3D scatter plot of the clusters, rendered without a legend]
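The "No handles with labels found to put in legend" warning means that legend() found no artist carrying a label; scatter calls made without label=... are skipped. A minimal sketch, assuming a recent Matplotlib (3.2 or later, where projection="3d" works without importing Axes3D) and random stand-in data, that inspects ax.get_legend_handles_labels() before and after adding a labelled scatter:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x, y, z = rng.random((3, 20))  # hypothetical 3D points

fig = plt.figure()
ax = fig.add_subplot(projection="3d")

# A scatter call without label=... is skipped by the legend
ax.scatter(x, y, z, marker="x")
unlabeled_handles, _ = ax.get_legend_handles_labels()

# A labelled scatter call produces exactly one legend entry
ax.scatter(x, y, z, marker="x", label="Cluster 0")
handles, labels = ax.get_legend_handles_labels()

print(len(unlabeled_handles), labels)  # prints: 0 ['Cluster 0']
```

So whichever branch of the plotting function actually runs, every ax.scatter call that should appear in the legend needs its own label.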

I checked this post, but I couldn't figure out what the mistake in my function is, i.e. how to pass the list of cluster labels properly. I want to update the function so that it shows the cluster labels via clusterSizes.index and their sizes via clusterSizes.size.
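To show each cluster label together with its size, the label strings can be built from clusterSizes. A sketch, assuming clusterSizes is a pandas Series whose index holds the cluster labels and whose values hold the member counts (as produced by value_counts()); the prediction_km values below are made up:

```python
import pandas as pd

# Hypothetical predictions; in the question these come from pdf_km_pred.prediction_km
prediction_km = pd.Series([0, 1, 1, 2, 2, 2, 0, 1, 2])

# Assumed shape of clusterSizes: index = cluster label, values = member count
clusterSizes = prediction_km.value_counts().sort_index()

# One legend string per cluster, combining the label and its count
legend_labels = [f"Cluster {label} (n={count})"
                 for label, count in clusterSizes.items()]
print(legend_labels)  # ['Cluster 0 (n=2)', 'Cluster 1 (n=3)', 'Cluster 2 (n=4)']
```

Each string can then be passed as the label= of the matching per-cluster ax.scatter call. Note that on a pandas Series, clusterSizes.size is the number of entries (3 here), not the per-cluster counts; the counts are the values.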

Expected output: as suggested here, it would be better to use legend_elements(), which determines a useful number of legend entries to show and automatically returns a tuple of handles and labels.
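legend_elements() is a method of the PathCollection returned by scatter (the 3D scatter returns a subclass of it). A sketch of using it with a single 3D scatter call, assuming the cluster ids are numeric and passed as c= so they are colour-mapped; the points and labels are random stand-ins for the PCA output and KMeans predictions:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop it and call plt.show() interactively
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)
points = rng.normal(size=(60, 3))     # hypothetical PCA output (PC1, PC2, PC3)
labels = rng.integers(0, 4, size=60)  # hypothetical KMeans cluster ids

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(projection="3d")

# One scatter call; the numeric cluster ids are mapped through the colormap
sc = ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                c=labels, cmap="viridis", marker="x")

# legend_elements() returns one proxy handle per distinct value of c,
# with label strings formatted by the given str.format-style pattern
handles, texts = sc.legend_elements(prop="colors", fmt="Cluster {x:.0f}")
ax.legend(handles, texts, title="Clusters", loc="upper left", bbox_to_anchor=(1.04, 1))

ax.set_xlabel("PC1: Type")
ax.set_ylabel("PC2: Length")
ax.set_zlabel("PC3: Frequency")
```

With only a few distinct cluster ids every id gets an entry; with many, legend_elements picks a reduced set, which is the "useful number of legend entries" behaviour mentioned above.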

Update: as I mentioned, the expected output should contain one legend for the cluster labels and another legend for the cluster sizes (the number of instances in each cluster). Reporting this information in a single legend would also be fine. See the 2D example below:
[Image: 2D example]
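For the two-legend layout, one common Matplotlib pattern is to draw the first legend, attach it with ax.add_artist() so that the next ax.legend() call does not replace it, and then draw the second legend. A sketch, assuming a DataFrame shaped like the question's (columns Type, Length, Freq plus a prediction_km column) and clusterSizes built with value_counts(); the data are random, and using invisible proxy handles for the size legend is just one possible design:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop it and call plt.show() interactively
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import pandas as pd

# Hypothetical PCA output with a KMeans prediction column (names follow the question)
rng = np.random.default_rng(7)
pdf = pd.DataFrame(rng.normal(size=(50, 3)), columns=["Type", "Length", "Freq"])
pdf["prediction_km"] = rng.integers(0, 3, size=50)

# Assumed: clusterSizes maps cluster label -> number of members
clusterSizes = pdf["prediction_km"].value_counts().sort_index()

color_names = ["red", "blue", "yellow", "black", "pink", "purple", "orange"]

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(projection="3d")

# One labelled scatter per cluster gives one legend entry per cluster
for cname, (label, group) in zip(color_names, pdf.groupby("prediction_km")):
    ax.scatter(group.Type, group.Length, group.Freq, c=cname, marker="x",
               label=f"Cluster {label}")

# First legend: cluster labels; add_artist keeps it when legend() is called again
label_legend = ax.legend(title="Cluster", loc="upper left", bbox_to_anchor=(1.04, 1))
ax.add_artist(label_legend)

# Second legend: cluster sizes, drawn with invisible proxy handles (no marker, no line)
size_handles = [Line2D([], [], linestyle="none") for _ in clusterSizes]
size_texts = [f"Cluster {k}: {n} points" for k, n in clusterSizes.items()]
ax.legend(size_handles, size_texts, title="Cluster size",
          loc="lower left", bbox_to_anchor=(1.04, 0), handlelength=0)
```

If a single legend is preferred, the counts can instead be folded into each scatter's label, e.g. f"Cluster {label} (n={len(group)})", and one ax.legend() call is enough.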


He who fights with dragons too long becomes a dragon himself; gaze long enough into the abyss, and the abyss gazes back into you…

1 Answer


In the function that visualizes the clusters, you need ax.legend instead of plt.legend:

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d.axes3d import Axes3D  # registers the '3d' projection on older Matplotlib
import numpy as np
import pandas as pd

color_names = ["red", "blue", "yellow", "black", "pink", "purple", "orange"]

def plot_3d_transformed_data(df, title, colors="red"):
    # Figure.gca(projection='3d') no longer works in Matplotlib 3.6+, so use add_subplot
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')

    if type(colors) is np.ndarray:
        # One labelled scatter call per cluster -> one legend entry per cluster.
        # df is a DataFrame, so select its columns by name instead of X_color[:, 0]
        for cname, class_label in zip(color_names, np.unique(colors)):
            X_color = df[colors == class_label]
            ax.scatter(X_color.Type, X_color.Length, X_color.Freq,
                       marker="x", c=cname, label=f"Cluster {class_label}")
    else:
        ax.scatter(df.Type, df.Length, df.Freq, alpha=0.6, c=colors,
                   marker="x", label=str(clusterSizes))

    ax.set_xlabel("PC1: Type")
    ax.set_ylabel("PC2: Length")
    ax.set_zlabel("PC3: Frequency")
    ax.set_title(title)

    # Call legend on the Axes object itself (a single call; a second call would replace it)
    ax.legend(bbox_to_anchor=(.9, 1), loc="upper left")
    plt.show()

clusterSizes = 10

test_df = pd.DataFrame({'Type':np.random.randint(0,5,10),
                        'Length':np.random.randint(0,20,10),
                        'Freq':np.random.randint(0,10,10),
                        'Colors':np.random.choice(color_names,10)})

plot_3d_transformed_data(test_df,
                         'Clustering rare URL parameters for data of date:haha\nMethod: KMeans over PCA\nn_clusters={n_clusters} , Distance_Measure={DistanceMeasure}',
                         colors=test_df.Colors)

Running this example code, you get the legend handle as expected:
[Image: 3D scatter plot with a legend]



...