leo-y
11/2/2018 - 4:22 AM

pytorch

PyTorch validation


def validate(model, loader):
    """Decode every batch in ``loader`` and report the mean edit distance.

    The model is switched to eval mode, each batch is run through it without
    gradient tracking, the softmax of the output is beam-search decoded, and
    the top beam of every sample is compared with its label using
    ``L.distance``.

    Args:
        model: Callable as ``model(data, x_lens)`` returning
            ``(output, output_lens)``.
        loader: Sized iterable yielding ``(data, label)`` pairs, each a
            sequence of per-sample tensors.

    Returns:
        The summed edit distance divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: If ``len(loader)`` is 0.
    """
    print("start to do validation ...")
    model.eval()
    # Index 0 is the CTC blank (see blank_id below), so phonemes start at 1.
    label_map = [' '] + phoneme_list.PHONEME_MAP
    decoder = CTCBeamDecoder(labels=label_map, blank_id=0,
                             beam_width=BEAM_WIDTH)
    all_len = len(loader)
    ls = 0
    with torch.no_grad():
        for data, label in loader:
            if use_cuda:
                data = [x.cuda() for x in data]
                label = [y.cuda() for y in label]
            x_lens = [len(s) for s in data]
            y_lens = [len(s) for s in label]

            x_lens, y_lens, data, label = sort_data(x_lens, y_lens, data, label)
            # NOTE(review): output_lens is returned but the decoder is fed the
            # input lengths below; if the model shortens the time axis this
            # should probably be output_lens -- confirm against the model.
            output, output_lens = model(data, x_lens)

            # Presumably (time, batch, classes) -> (batch, time, classes),
            # which is the layout the decoder is given -- TODO confirm.
            output = torch.transpose(output, 0, 1)

            probs = f.softmax(output, dim=2)
            x_lens = torch.from_numpy(np.asarray(x_lens))
            output, scores, timesteps, out_seq_len = decoder.decode(
                probs=probs, seq_lens=x_lens)

            # Score each sample exactly once. The distance used to be
            # accumulated by two identical loops, doubling the reported value.
            for i in range(output.size(0)):
                # Beam 0 is taken as the prediction, truncated to its length.
                best_path = output[i, 0, :out_seq_len[i, 0]].tolist()
                predicted_val = "".join(label_map[o] for o in best_path)
                true_val = "".join(
                    label_map[val] for val in label[i].cpu().tolist())
                ls += L.distance(predicted_val, true_val)
                if i == 0:
                    # Show one example per batch as a sanity check.
                    print("predicted", predicted_val)
                    print("true value: ", true_val)

    # NOTE(review): this normalises by the number of batches, not the number
    # of samples; kept as-is so the metric's scale stays what callers expect.
    result = ls / all_len
    print("Validation: ", result)
    return result