42 lines
1.1 KiB
Python
42 lines
1.1 KiB
Python
import os, sys, time
|
|
import numpy as np
|
|
import matplotlib
|
|
import random
|
|
matplotlib.use('agg')
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.cm as cm
|
|
|
|
def draw_points(points, labels, save_path):
    """Scatter-plot 2-D points coloured by class label and optionally save it.

    Args:
        points: array-like indexed as ``points[mask, :]``; column 0 is drawn
            as x and column 1 as y (assumed shape (N, D) with D >= 2 --
            confirm against callers).
        labels: array-like of N class labels aligned with ``points``. Labels
            must be integers: the marker is chosen from the label's parity
            and the legend name is formatted with ``'{:02d}'``.
        save_path: path to write the figure to, or ``None`` to skip saving.

    Each distinct label gets its own colour sampled evenly from the
    ``rainbow`` colormap. Even labels use an ``'x'`` marker and odd labels
    use ``'o'``. The legend entry for label ``k`` is ``k + 1`` zero-padded
    to two digits. The figure is always closed before returning.
    """
    title = 'the visualized features'
    dpi = 100
    width, height = 1000, 1000
    legend_fontsize = 8
    figsize = width / float(dpi), height / float(dpi)

    # Coerce inputs so ``labels == cls`` produces an element-wise boolean
    # mask even when plain Python sequences are passed in.
    points = np.asarray(points)
    labels = np.asarray(labels)

    fig = plt.figure(figsize=figsize)
    try:
        # Draw on an explicit Axes instead of pyplot's implicit current axes.
        ax = fig.add_subplot(1, 1, 1)

        classes = np.unique(labels).tolist()
        colors = cm.rainbow(np.linspace(0, 1, len(classes)))

        handles, names = [], []
        for cls, color in zip(classes, colors):
            subset = points[labels == cls, :]
            marker = 'x' if cls % 2 == 0 else 'o'
            handle = ax.scatter(subset[:, 0], subset[:, 1],
                                color=color, s=1, marker=marker)
            handles.append(handle)
            names.append('{:02d}'.format(cls + 1))

        ax.legend(handles, names, scatterpoints=1, ncol=5,
                  fontsize=legend_fontsize)

        if save_path is not None:
            fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
            print('---- save figure {} into {}'.format(title, save_path))
    finally:
        # Release the figure even if plotting or saving raises, so repeated
        # calls do not accumulate open figures in pyplot's registry.
        plt.close(fig)
|