"""Graph-cut segmentation with PyMaxflow.

Author: naivete5656
Date: 6/27/2018 - 4:30 AM
Graph-cut reference: https://qiita.com/naivete5656/items/d15a9f57d7940854b841
"""

import numpy as np
import maxflow
from scipy.misc import imread
from matplotlib import pyplot as plt
import networkx as nx
import math

# Parameters of the pairwise (n-link) weight used by create_graph():
#     weight = lumda * exp(-k * |intensity difference between neighbours|)
lumda, k = 1, 1


def _add_nlinks(g, nodeids, intensity, axis):
    """Connect every pixel to its successor along ``axis`` with a symmetric n-link.

    The edge between pixel p and its neighbour q (the right neighbour for
    ``axis=1``, the one below for ``axis=0``) gets capacity
    ``lumda * exp(-k * |I(p) - I(q)|)``.

    Args:
        g: The maxflow graph to add edges to.
        nodeids: Node-id grid returned by ``g.add_grid_nodes``.
        intensity: 2-D float array of pixel values, same shape as ``nodeids``.
        axis: 1 for horizontal links, 0 for vertical links.
    """
    pad_width = [(0, 0), (0, 0)]
    pad_width[axis] = (0, 1)
    padded = np.pad(intensity, pad_width, 'constant', constant_values=0)
    # Drop the first row/column so neighbour[p] is the pixel right after p
    # along `axis`. The last row/column is compared against the zero pad, but
    # that edge would point outside the grid.
    neighbour = padded[:, 1:] if axis == 1 else padded[1:, :]
    weights = lumda * np.exp(-k * np.abs(intensity - neighbour))

    structure = np.zeros((3, 3))
    # (1, 1) is the centre of the 3x3 structure: (1, 2) selects the right
    # neighbour, (2, 1) the neighbour below.
    if axis == 1:
        structure[1, 2] = 1
    else:
        structure[2, 1] = 1
    g.add_grid_edges(nodeids, structure=structure, weights=weights, symmetric=True)


def create_graph(image_path="result200.tif", mask_path="00200.png"):
    """Build a PyMaxflow grid graph for graph-cut segmentation of an image.

    Every pixel of the 2-D image becomes a node. Horizontally and vertically
    adjacent pixels are joined by symmetric n-links of capacity
    ``lumda * exp(-k * |I(p) - I(q)|)``, so cutting between similar pixels is
    expensive. Terminal links then pin seed pixels:

    * pixels where the mask is > 10 and the image is >= 100 get source
      capacity 1e12;
    * pixels whose image value is exactly 0 get sink capacity 100000.

    Args:
        image_path: Image to segment (default ``"result200.tif"``).
        mask_path: Seed mask image (default ``"00200.png"``, i.e.
            ``'%05d.png' % 200``).

    Returns:
        ``(nodeids, g)``: the node-id grid from ``g.add_grid_nodes`` and the
        populated ``maxflow.Graph[float]``, ready for ``g.maxflow()``.
    """
    img = imread(image_path)
    # Work in floats: subtracting unsigned-integer pixels (e.g. uint8) wraps
    # around instead of going negative, which corrupted the n-link weights.
    intensity = np.asarray(img, dtype=np.float64)

    g = maxflow.Graph[float]()
    nodeids = g.add_grid_nodes(intensity.shape)

    # n-links between 4-connected neighbours. Each weight is aligned with the
    # edge it labels: pixel p against its right / lower neighbour.
    _add_nlinks(g, nodeids, intensity, axis=1)
    _add_nlinks(g, nodeids, intensity, axis=0)

    # Load the seed mask image.
    # NOTE(review): assumes the mask is single-channel and has the same shape
    # as the image - TODO confirm.
    seeds = np.asarray(imread(mask_path))
    source_seeds = (seeds > 10) & (intensity >= 100)
    sink_seeds = intensity == 0
    # t-links addressed through `nodeids` rather than hand-computed indices.
    # Pixels that are in neither seed set get (0, 0), i.e. no t-link.
    g.add_grid_tedges(nodeids,
                      np.where(source_seeds, 1e12, 0.0),
                      np.where(sink_seeds, 100000.0, 0.0))
    return nodeids, g


if __name__ == '__main__':
    nodeids, g = create_graph()
    g.maxflow()
    sgm = g.get_grid_segments(nodeids)

    # Paint the inverted segmentation into channel 0 of an RGB overlay whose
    # size follows the segmentation (it was hard-coded to 1040 x 1392).
    overlay = np.zeros(sgm.shape + (3,), dtype=np.uint8)
    overlay[:, :, 0] = np.logical_not(sgm).astype(np.uint8) * 255

    # 50/50 blend with the input image. The original called cv2.addWeighted
    # without ever importing cv2; the same arithmetic is done with NumPy.
    # NOTE(review): assumes an 8-bit image - TODO confirm for 16-bit TIFFs.
    base = np.asarray(imread("result200.tif"), dtype=np.float64)
    if base.ndim == 2:
        base = np.dstack([base] * 3)  # grayscale -> 3 channels to match overlay
    dst = np.clip(np.rint(0.5 * base + 0.5 * overlay), 0, 255).astype(np.uint8)

    plt.imshow(dst)
    plt.show()