wottpal
9/13/2018 - 10:29 AM

Word Similarity Visualization with Word2Vec, t-SNE, k-Means

# -*- coding: utf-8 -*-

import sys
import time
import numpy as np
import gensim
import matplotlib.pyplot as plt
from matplotlib import font_manager, rc
from sklearn.manifold import TSNE
from sklearn.cluster import MiniBatchKMeans
from scipy.spatial import distance
import random
import math
from adjustText import adjust_text


# Number of leading vocabulary entries that are embedded with t-SNE.
VOCAB_SIZE = 10000      # MAX 3,000,000


# Load the Google-News Model (https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM)
model_path = "./GoogleNews-vectors-negative300.bin"
model = gensim.models.KeyedVectors.load_word2vec_format(
    model_path, binary=True)
# `load_word2vec_format` already returns the KeyedVectors object, so the
# deprecated `.wv` / `.syn0` indirections are not needed.
wv = model.vectors              # word vectors (wv)
# gensim >= 4.0 exposes the word -> index mapping as `key_to_index`;
# gensim 3.x calls the equivalent mapping `vocab`.
vocabulary = getattr(model, "key_to_index", None)     # vocabulary (lexicon)
if vocabulary is None:
    vocabulary = model.vocab


# Run TSNE on the first VOCAB_SIZE word vectors only.
tsne = TSNE(n_components=2, random_state=0)
np.set_printoptions(suppress=True)
Y = tsne.fit_transform(wv[:VOCAB_SIZE, :])
# Materialise the (word, x, y) triples: a bare zip() is a one-shot iterator in
# Python 3 and would be empty after the first loop over it.  zip() stops at
# the shorter input, so only words that have a row in Y are paired.
word_positions = list(zip(vocabulary, Y[:, 0], Y[:, 1]))


def get_word_position(word):
    """Find a word's index in the vocabulary and its t-SNE coordinates.

    Rows of ``Y`` are matched positionally with the iteration order of
    ``vocabulary``; only words that have a row in ``Y`` can be located.

    Args:
        word: The word to look up.

    Returns:
        A ``(word_index, word_x, word_y)`` tuple, or ``(None, None, None)``
        when the word is unknown or has no row in ``Y``.
    """
    try:
        known = word in vocabulary
    except TypeError:
        # An unhashable argument can never be a vocabulary key.
        known = False
    if not known:
        return None, None, None

    # zip() stops at the shorter input, so the scan is limited to words that
    # actually have coordinates and no full copy of the vocabulary is built.
    for word_index, (candidate, (word_x, word_y)) in enumerate(zip(vocabulary, Y)):
        if candidate == word:
            print(f"'{word}' at ({word_x},{word_y})")
            return word_index, word_x, word_y

    # The word exists in the vocabulary but was not part of the embedding.
    return None, None, None


def get_random_color(pastel_factor=0.9):
    """Return a random colour as a list of three channel values.

    Each channel is drawn with ``random.uniform(0, 1.0)`` and then blended
    towards white: ``(value + pastel_factor) / (1.0 + pastel_factor)``.

    Args:
        pastel_factor: Blend strength; ``0`` leaves the raw draws unchanged.

    Returns:
        A list of three floats.
    """
    red = random.uniform(0, 1.0)
    green = random.uniform(0, 1.0)
    blue = random.uniform(0, 1.0)

    denominator = 1.0 + pastel_factor
    blended = []
    for channel in (red, green, blue):
        blended.append((channel + pastel_factor) / denominator)
    return blended

    
def save_word_plot(word, max_dist):
    """Plot the t-SNE surroundings of a word and save them as a PDF.

    Every vocabulary entry whose t-SNE position lies strictly closer than
    ``max_dist`` (Euclidean distance) to the position of ``word`` is drawn
    as a labelled scatter point.  The figure is written to
    ``results/plot_{word}_{max_dist}.pdf``.

    Args:
        word: The word whose neighbourhood is plotted.
        max_dist: Exclusive upper bound on the distance in t-SNE space.

    Returns:
        None.  Nothing is plotted or saved when ``get_word_position`` cannot
        locate ``word``.
    """
    import os

    # Resolve the word before creating a figure so that a failed lookup does
    # not leave an unclosed figure behind.
    word_index, word_x, word_y = get_word_position(word)
    if word_index is None or word_x is None or word_y is None:
        return

    # Collect every word that lies within max_dist of the word of interest.
    centre = [word_x, word_y]
    vocab_new = []
    x_new = []
    y_new = []
    for w, x, y in zip(vocabulary, Y[:, 0], Y[:, 1]):
        if distance.euclidean(centre, [x, y]) < max_dist:
            vocab_new.append(w)
            x_new.append(x)
            y_new.append(y)

    filename = f'results/plot_{word}_{max_dist}.pdf'
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    plt.figure()
    try:
        # `color=` always means "one colour for all points"; a bare RGB list
        # passed as `c=` is ambiguous when exactly three points are plotted.
        plt.scatter(x_new, y_new, color=get_random_color())
        texts = [plt.text(x, y, label)
                 for label, x, y in zip(vocab_new, x_new, y_new)]
        # Nudge the labels apart so they do not overlap.
        adjust_text(texts)
        plt.axis('off')
        plt.savefig(filename, bbox_inches='tight')
    finally:
        # Release the figure even if labelling or saving failed.
        plt.close()
    print(f"Saved '{filename}'")


def save_full_plot(show_labels=False, kmeans_clusters=8):
    """Plot every embedded word, coloured by k-means cluster, and save a PDF.

    The t-SNE coordinates in ``Y`` are clustered with ``MiniBatchKMeans`` and
    each word is drawn as a small marker in its cluster's colour.  The figure
    is written to
    ``results/full_plot_[labels_]{VOCAB_SIZE}_{kmeans_clusters}.pdf``.

    Args:
        show_labels: When true, annotate every marker with its word.
        kmeans_clusters: Number of clusters; ``1`` yields a single-colour
            (gray) plot.

    Returns:
        None.
    """
    import os

    bw = kmeans_clusters == 1
    kmeans = MiniBatchKMeans(n_clusters=kmeans_clusters)
    labels = kmeans.fit_predict(Y)
    if bw:
        colors = ["gray"]
    else:
        colors = [get_random_color() for _ in range(kmeans_clusters)]

    filename = f'results/full_plot_{"labels_" if show_labels else ""}{VOCAB_SIZE}_{kmeans_clusters}.pdf'
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    plt.figure()
    try:
        # Pair words with coordinates afresh on every call so that repeated
        # calls never depend on whether a shared iterator was already consumed.
        for label, x, y, cluster in zip(vocabulary, Y[:, 0], Y[:, 1], labels):
            plt.plot(x, y, color=colors[cluster], marker='o', markersize=1)
            if show_labels:
                plt.annotate(label, xy=(x, y), xytext=(0, 0), textcoords='offset points')
        plt.axis('off')
        plt.savefig(filename, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close()
    print(f"Saved '{filename}'")