# Plot the network structure.
def plot_graph(weights, labels, ax=None, *, show_negative=False):
    """Draw a weighted, undirected network on a circular layout.

    Every non-zero entry ``weights[i, j]`` becomes an edge between
    ``labels[i]`` and ``labels[j]``. The graph is undirected. When both
    ``weights[i, j]`` and ``weights[j, i]`` are non-zero, the entry visited
    later in row-major order wins; for ``i < j`` that is ``weights[j, i]``.

    Edge widths are proportional to ``|weight|``. They are scaled so the
    strongest drawn edge has width 10.

    Parameters
    ----------
    weights : array-like, 2-D
        Connectivity matrix; only its non-zero entries produce edges.
    labels : sequence
        Node labels; ``labels[k]`` names row/column ``k`` of ``weights``.
    ax : matplotlib Axes, optional
        Axes to draw on; forwarded unchanged to the networkx drawing calls.
    show_negative : bool, default False
        If True, negative-weight edges are also drawn (blue, dashed,
        semi-transparent). By default only positive-weight edges are drawn.
    """
    import networkx as nx
    import numpy as np

    weights = np.asarray(weights)

    G = nx.Graph()
    # Register every labelled node first so isolated nodes (no non-zero
    # entry in their row/column) still appear in the plot.
    G.add_nodes_from(labels[:max(weights.shape)])
    # np.nonzero yields indices in row-major order, i.e. the same visiting
    # order as a nested row/column loop.
    for i, j in zip(*np.nonzero(weights)):
        G.add_edge(labels[i], labels[j], weight=weights[i, j])

    pos_edges = [(u, v) for u, v, w in G.edges(data='weight') if w > 0]
    neg_edges = ([(u, v) for u, v, w in G.edges(data='weight') if w < 0]
                 if show_negative else [])

    # Explicit float dtype: in-place scaling of an integer array would
    # otherwise raise a numpy casting error.
    pos_width = np.array([abs(G[u][v]['weight']) for u, v in pos_edges],
                         dtype=float)
    neg_width = np.array([abs(G[u][v]['weight']) for u, v in neg_edges],
                         dtype=float)

    # Normalize so the strongest drawn edge has width 10; skip when nothing
    # is drawn (max() of an empty array raises).
    drawn = np.concatenate([pos_width, neg_width])
    if drawn.size:
        scale = 10.0 / drawn.max()
        pos_width *= scale
        neg_width *= scale

    layout = nx.circular_layout(G)
    nx.draw_networkx_nodes(G, layout, node_size=700, ax=ax)
    if pos_edges:
        nx.draw_networkx_edges(G, layout, edgelist=pos_edges,
                               width=pos_width, ax=ax)
    if neg_edges:
        nx.draw_networkx_edges(G, layout, edgelist=neg_edges,
                               width=neg_width, alpha=0.5, edge_color='b',
                               style='dashed', ax=ax)
    nx.draw_networkx_labels(G, layout, font_size=15,
                            font_family='sans-serif', ax=ax)