用 Python 实现全连接层网络可视化

以下代码即可实现全连接层网络的可视化: # 引用模块from pylab import mpl #matplotlib使用中文# 自编函数def ANN_

以下代码即可实现全连接层网络的可视化:

# 引用模块
from pylab import mpl #matplotlib使用中文# 自编函数
def ANN_ksh(number_input,number_hidden,number_output):import numpy as npimport networkx as nximport matplotlib.pyplot as pltmpl.rcParams['font.sans-serif']=['SimHei'] #matplotlib使用中文,SimHei为黑体# number_input为输入层节点个数,number_hidden为隐藏层各层节点个数,number_output为输出层节点个数ceng_hidden=len(number_hidden) #隐藏层层数G=nx.DiGraph()# 节点vertex_input_list=['v'+str(i) for i in range(1,number_input+1)] #输入层vertex_hidden_list=[]start=number_input+1end=number_input+number_hidden[0]+1vertex_hidden_list.append(['v'+str(i) for i in range(start,end)]) #隐藏层for j in range(1,ceng_hidden):start=endend=start+number_hidden[j]vertex_hidden_list.append(['v'+str(i) for i in range(start,end)]) #隐藏层vertex_output_list=['v'+str(i) for i in range(end,end+number_output)] #输出层vertex_list=[]vertex_list.extend(vertex_input_list)list(map(lambda i:vertex_list.extend(vertex_hidden_list[i]),range(ceng_hidden)))vertex_list.extend(vertex_output_list)G.add_nodes_from(vertex_list)# 连接edge_input_hidden_list=[]edge_input_hidden_list.extend([(i,j) for i in vertex_input_list for j in vertex_hidden_list[0]]) #输入层-隐藏层edge_list=[]edge_list.extend(edge_input_hidden_list)edge_hidden_hidden_list=[]if ceng_hidden>1:for k in range(ceng_hidden-1):edge_hidden_hidden_list.extend([(i,j) for i in vertex_hidden_list[k] for j in vertex_hidden_list[k+1]]) #隐藏层-隐藏层edge_list.extend(edge_hidden_hidden_list)edge_hidden_output_list=[]edge_hidden_output_list.extend([(i,j) for i in vertex_hidden_list[len(vertex_hidden_list)-1] for j in vertex_output_list]) #隐藏层-输出层edge_list.extend(edge_hidden_output_list)G.add_edges_from(edge_list)# 位置pos={}ceng_pos_x=np.linspace(-(ceng_hidden+2)/2,(ceng_hidden+2)/2,num=ceng_hidden+2)list(map(lambda i:pos.update({vertex_input_list[int(np.where(np.arange(-number_input/2*1+1/2,number_input/2*1+1/2,1)==i)[0])]:(ceng_pos_x[0],i)}),np.arange(-number_input/2*1+1/2,number_input/2*1+1/2,1))) #输入层list(map(lambda j:list(map(lambda i:pos.update({vertex_hidden_list[j][int(np.where(np.arange(-number_hidden[j]/2*1+1/2,number_hidden[j]/2*1+1/2,1)==i)[0])]:(ceng_pos_x[j+1],i)}),np.arange(-number_hidden[j]/2*1+1/2,number_hidden[j]/2*1+1/2,1))),range(ceng_hidden))) #隐藏层list(map(lambda i:pos.update({vertex_output_list[int(np.where(np.arange(-number_output/2*1+1/2,number_output/2*1+1/2,1)==i)[0])]:(ceng_pos_x[len(ceng_pos_x)-1],i)}),np.arange(-number_output/2*1+1/2,number_output/2*1+1/2,1))) #输出层fig=plt.figure(figsize=(8,5),dpi=300)plt.xlim(ceng_pos_x[0]-1,ceng_pos_x[len(ceng_pos_x)-1]+1)plt.ylim(-max(number_input,max(number_hidden),number_output)/2*1,max(number_input,max(number_hidden),number_output)/2*1+1/2)nx.draw(G,pos=pos,node_color='red',edge_color='black',with_labels=False,font_size=10,node_size=300,)fig.savefig('全连接层网络可视化.png')

函数参数说明:

  number_input 为输入层的节点个数,number_hidden 为隐藏层各层的节点个数,number_output 为输出层的节点个数。

调用函数示例:

ANN_ksh(8,[8,5,2],2)

结果:

图 1 全连接层网络可视化