razhangwei
2/12/2015 - 1:54 AM

plot network structure

plot network structure

def plot_graph(weights, labels, ax=None):
    """Draw a weighted network on a circular layout.

    Every non-zero entry ``weights[i][j]`` becomes an undirected edge between
    ``labels[i]`` and ``labels[j]``. Only edges with a positive weight are
    drawn; their line widths are scaled so the heaviest positive edge is 10
    wide. A node that only has negative edges is still drawn (with no visible
    edges). A label whose row and column are entirely zero never enters the
    graph, so it is not drawn at all.

    Parameters
    ----------
    weights : array_like, 2-D
        Matrix of edge weights; ``0`` means "no edge". Because the graph is
        undirected, if ``weights[i][j]`` and ``weights[j][i]`` are both
        non-zero, the one visited later in row-major order wins.
    labels : sequence
        Node labels; ``labels[k]`` names row ``k`` and column ``k`` of
        ``weights``.
    ax : matplotlib Axes, optional
        Axes to draw into. It is passed straight through to networkx, which
        uses the current axes when it is ``None``.
    """
    import networkx as nx
    import numpy as np

    # Accept nested lists as well as arrays; the original required ``.shape``.
    weights = np.asarray(weights)

    G = nx.Graph()
    # np.nonzero yields indices in row-major order, the same order as a nested
    # i/j loop, so a later (j, i) entry still overwrites an earlier (i, j) one.
    for i, j in zip(*np.nonzero(weights)):
        G.add_edge(labels[i], labels[j], weight=weights[i, j])

    e_pos = [(u, v) for u, v, d in G.edges(data=True) if d['weight'] > 0]
    # Use float dtype: in-place scaling of an integer array would raise a
    # casting error (and 10 / int would floor-divide under Python 2).
    e_pos_weight = np.array([G[u][v]['weight'] for u, v in e_pos], dtype=float)

    # Normalize widths so the heaviest positive edge is 10 wide. Skip this when
    # there are no positive edges, because max() of an empty array raises.
    if e_pos_weight.size:
        e_pos_weight *= 10.0 / e_pos_weight.max()

    pos = nx.circular_layout(G)
    nx.draw_networkx_nodes(G, pos, node_size=700, ax=ax)
    nx.draw_networkx_edges(G, pos, edgelist=e_pos, width=e_pos_weight, ax=ax)
    nx.draw_networkx_labels(G, pos, font_size=15,
                            font_family='sans-serif', ax=ax)