42 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			42 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import os, sys, time
 | 
						|
import numpy as np
 | 
						|
import matplotlib
 | 
						|
import random
 | 
						|
matplotlib.use('agg')
 | 
						|
import matplotlib.pyplot as plt
 | 
						|
import matplotlib.cm as cm
 | 
						|
 | 
						|
def draw_points(points, labels, save_path):
 | 
						|
  title = 'the visualized features'
 | 
						|
  dpi = 100 
 | 
						|
  width, height = 1000, 1000
 | 
						|
  legend_fontsize = 10
 | 
						|
  figsize = width / float(dpi), height / float(dpi)
 | 
						|
  fig = plt.figure(figsize=figsize)
 | 
						|
 | 
						|
  classes = np.unique(labels).tolist()
 | 
						|
  colors = cm.rainbow(np.linspace(0, 1, len(classes)))
 | 
						|
 | 
						|
  legends = []
 | 
						|
  legendnames = []
 | 
						|
 | 
						|
  for cls, c in zip(classes, colors):
 | 
						|
    
 | 
						|
    indexes = labels == cls
 | 
						|
    ptss = points[indexes, :]
 | 
						|
    x = ptss[:,0]
 | 
						|
    y = ptss[:,1]
 | 
						|
    if cls % 2 == 0: marker = 'x'
 | 
						|
    else:            marker = 'o'
 | 
						|
    legend = plt.scatter(x, y, color=c, s=1, marker=marker)
 | 
						|
    legendname = '{:02d}'.format(cls+1)
 | 
						|
    legends.append( legend )
 | 
						|
    legendnames.append( legendname )
 | 
						|
 | 
						|
  plt.legend(legends, legendnames, scatterpoints=1, ncol=5, fontsize=8)
 | 
						|
 | 
						|
  if save_path is not None:
 | 
						|
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
 | 
						|
    print ('---- save figure {} into {}'.format(title, save_path))
 | 
						|
  plt.close(fig)
 |