ericwen229
4/10/2018 - 11:56 AM

Java implementation of the Hungarian algorithm (repeated augmenting-path search) used to find a maximum matching in an unweighted bipartite graph.

import java.io.PrintStream;
import java.util.*;

public class Hungarian {

  /**
   * Reads a bipartite graph from standard input, prints it, then prints a maximum matching and
   * its size.
   *
   * <p>Input format: the left and right vertex counts, followed, for each left vertex in order,
   * by its neighbour count and then that many 1-based right-vertex indices. Left vertices are
   * numbered by input order starting at 0; all printed indices are 0-based.
   *
   * @throws IllegalArgumentException if a right-vertex index lies outside [1, rightVertexCount]
   */
  public static void main(String[] args) {
    Scanner scanner = new Scanner(System.in);
    int leftVertexCount = scanner.nextInt();
    int rightVertexCount = scanner.nextInt();
    BinGraph graph = new BinGraph(leftVertexCount, rightVertexCount);
    for (int iLeftVertex = 0; iLeftVertex < leftVertexCount; ++ iLeftVertex) {
      int neighbourCount = scanner.nextInt();
      for (int iNeighbour = 0; iNeighbour < neighbourCount; ++ iNeighbour) {
        int iRightVertex = scanner.nextInt();
        // Reject bad input with a readable message instead of an array-index crash.
        if ((iRightVertex < 1) || (iRightVertex > rightVertexCount)) {
          throw new IllegalArgumentException(String.format(
              "left vertex %d: right vertex index %d out of range [1, %d]",
              iLeftVertex, iRightVertex, rightVertexCount));
        }
        // Right vertices are 1-based in the input but 0-based inside BinGraph.
        graph.addEdge(iLeftVertex, iRightVertex - 1);
      }
    }

    System.out.println();
    System.out.println("graph:");
    graph.print(System.out);
    System.out.println();

    BinMatch maxMatch = graph.findMaxMatch();

    System.out.println("max match:");
    maxMatch.print(System.out);
    System.out.println(String.format("size: %d", maxMatch.size()));
  }

}

/**
 * An unweighted bipartite graph with left vertices 0..leftVertexCount-1 and right vertices
 * 0..rightVertexCount-1, stored as a dense adjacency matrix.
 */
class BinGraph {

  /**
   * Iterates, in increasing index order, over the neighbours of one vertex of a {@link BinGraph}.
   * For a left vertex the neighbours are right-vertex indices; for a right vertex they are
   * left-vertex indices.
   */
  static class NeighbourIterator implements Iterator<Integer> {

    private final BinGraph graph;
    private final boolean isLeftVertex;
    private final int vertexIndex;
    // Number of vertices on the opposite side; the iterator is exhausted once the cursor reaches it.
    private final int limit;
    // Index of the next neighbour to return, or limit when exhausted.
    private int currNeighbourIndex;

    NeighbourIterator(BinGraph graph, boolean isLeftVertex, int vertexIndex) {
      this.graph = graph;
      this.isLeftVertex = isLeftVertex;
      this.vertexIndex = vertexIndex;
      this.limit = isLeftVertex ? graph.rightVertexCount : graph.leftVertexCount;
      currNeighbourIndex = 0;
      proceedIterator();
    }

    /** Moves the cursor forward to the first connected vertex at or after its current position. */
    private void proceedIterator() {
      while ((currNeighbourIndex < limit) && !isConnected(currNeighbourIndex)) {
        ++ currNeighbourIndex;
      }
    }

    /** Returns whether the iterated vertex is adjacent to the given opposite-side vertex. */
    private boolean isConnected(int neighbourIndex) {
      return isLeftVertex
          ? graph.connectionMatrix[vertexIndex][neighbourIndex]
          : graph.connectionMatrix[neighbourIndex][vertexIndex];
    }

    @Override
    public boolean hasNext() {
      return currNeighbourIndex < limit;
    }

    /**
     * Returns the next neighbour index.
     *
     * @throws NoSuchElementException if all neighbours have already been returned
     */
    @Override
    public Integer next() {
      if (!hasNext()) {
        throw new NoSuchElementException();
      }
      int neighbourIndex = currNeighbourIndex;
      ++ currNeighbourIndex;
      proceedIterator();
      return neighbourIndex;
    }

  }

  private final int leftVertexCount;
  private final int rightVertexCount;
  // connectionMatrix[l][r] is true iff left vertex l is adjacent to right vertex r.
  private final boolean[][] connectionMatrix;

  /** Creates a graph with the given vertex counts and no edges. */
  BinGraph(int leftVertexCount, int rightVertexCount) {
    this.leftVertexCount = leftVertexCount;
    this.rightVertexCount = rightVertexCount;
    // Java zero-initialises boolean arrays, so every entry already means "no edge".
    this.connectionMatrix = new boolean[leftVertexCount][rightVertexCount];
  }

  /** Adds an edge between the given 0-based left and right vertices (idempotent). */
  void addEdge(int leftVertexIndex, int rightVertexIndex) {
    connectionMatrix[leftVertexIndex][rightVertexIndex] = true;
  }

  /** Prints one line per left vertex: its index, a colon, then its right neighbours. */
  void print(PrintStream out) {
    for (int iLeftVertex = 0; iLeftVertex < leftVertexCount; ++ iLeftVertex) {
      out.print(iLeftVertex + ":");
      NeighbourIterator it = createNeighbourIterator(true, iLeftVertex);
      while (it.hasNext()) {
        int neighbourIndex = it.next();
        out.print(" " + neighbourIndex);
      }
      out.println();
    }
  }

  /** Returns an iterator over the neighbours of the given vertex. */
  NeighbourIterator createNeighbourIterator(boolean isLeftVertex, int vertexIndex) {
    return new NeighbourIterator(this, isLeftVertex, vertexIndex);
  }

  /**
   * Depth-first search for an alternating path that ends at an unmatched right vertex.
   *
   * <p>From a left vertex the search tries every not-yet-visited right neighbour. A non-start left
   * vertex's own partner is always already visited, because the search arrived through it. From a
   * matched right vertex it follows only its matching edge back to the left side. On success
   * pathRecorder ends with the found path. On failure this call leaves pathRecorder as it found
   * it, but vertexMarker keeps the vertices it visited.
   *
   * <p>Recursion depth grows with the path length, so very large graphs may overflow the stack.
   *
   * @param isLeftVertex whether vertexIndex refers to the left side
   * @param vertexIndex 0-based index of the current vertex
   * @param currMaxMatch the matching the path must alternate with
   * @param pathRecorder vertices on the current path, alternating left/right
   * @param vertexMarker visited vertices; left vertex i is stored as i, right vertex j as -(1 + j)
   * @return true if an augmenting path was found
   */
  private boolean dfs(boolean isLeftVertex,
                      int vertexIndex,
                      BinMatch currMaxMatch,
                      LinkedList<Integer> pathRecorder,
                      Set<Integer> vertexMarker) {
    pathRecorder.addLast(vertexIndex);
    vertexMarker.add(isLeftVertex ? vertexIndex : -(1 + vertexIndex));
    if (!isLeftVertex && !currMaxMatch.hasSaturated(false, vertexIndex)) {
      // Reached a free right vertex: the recorded path is augmenting.
      return true;
    }
    if (isLeftVertex) {
      NeighbourIterator it = createNeighbourIterator(true, vertexIndex);
      while (it.hasNext()) {
        int neighbourIndex = it.next();
        if (!vertexMarker.contains(-(1 + neighbourIndex))
            && dfs(false, neighbourIndex, currMaxMatch, pathRecorder, vertexMarker)) {
          return true;
        }
      }
    }
    else {
      // Saturated right vertex: the path must continue along its matching edge.
      int matchedNeighbourIndex = currMaxMatch.getMatchedNeighbour(false, vertexIndex);
      if (!vertexMarker.contains(matchedNeighbourIndex)
          && dfs(true, matchedNeighbourIndex, currMaxMatch, pathRecorder, vertexMarker)) {
        return true;
      }
    }
    pathRecorder.removeLast();
    return false;
  }

  /**
   * Finds an augmenting path for the given matching.
   *
   * <p>The visited set is shared across all start vertices. The searches together compute
   * alternating-path reachability from the free left vertices, so a vertex that was already
   * explored cannot lead to a new augmenting path in this round.
   *
   * @return alternating 0-based left/right vertex indices from an unmatched left vertex to an
   *     unmatched right vertex, or null if no augmenting path exists
   */
  List<Integer> findAugmentingPath(BinMatch currMaxMatch) {
    LinkedList<Integer> pathRecorder = new LinkedList<>();
    Set<Integer> vertexMarker = new HashSet<>();
    for (int iLeftVertex = 0; iLeftVertex < leftVertexCount; ++ iLeftVertex) {
      if (currMaxMatch.hasSaturated(true, iLeftVertex)) {
        continue;
      }
      pathRecorder.clear();
      if (dfs(true, iLeftVertex, currMaxMatch, pathRecorder, vertexMarker)) {
        return pathRecorder;
      }
    }
    // augmenting path not found
    return null;
  }

  /**
   * Computes a maximum matching by repeatedly augmenting an initially empty matching until no
   * augmenting path remains.
   */
  BinMatch findMaxMatch() {
    BinMatch currMaxMatch = new BinMatch(leftVertexCount, rightVertexCount);
    List<Integer> augmentingPath = findAugmentingPath(currMaxMatch);
    while (augmentingPath != null) {
      currMaxMatch.expand(augmentingPath, true);
      augmentingPath = findAugmentingPath(currMaxMatch);
    }
    return currMaxMatch;
  }

}

/**
 * A matching in a bipartite graph, stored as two mutually inverse maps between 0-based left and
 * right vertex indices.
 */
class BinMatch {

  private final int leftVertexCount;
  private final int rightVertexCount;
  // left vertex index -> matched right vertex index
  private final Map<Integer, Integer> leftVertexMap;
  // right vertex index -> matched left vertex index
  private final Map<Integer, Integer> rightVertexMap;

  /** Creates an empty matching for a graph with the given numbers of left and right vertices. */
  BinMatch(int leftVertexCount, int rightVertexCount) {
    this.leftVertexCount = leftVertexCount;
    this.rightVertexCount = rightVertexCount;
    leftVertexMap = new HashMap<>();
    rightVertexMap = new HashMap<>();
  }

  /** Returns the number of matched edges. */
  int size() {
    assert leftVertexMap.size() == rightVertexMap.size();
    return leftVertexMap.size();
  }

  /** Prints one "left: right" line per matched edge, in increasing left-vertex order. */
  void print(PrintStream out) {
    // HashMap iteration order is unspecified; sort so the output is stable.
    new TreeMap<>(leftVertexMap).forEach((k, v) -> out.println(k + ": " + v));
  }

  /**
   * Augments this matching along an augmenting path, growing its size by one.
   *
   * <p>The path lists 0-based vertex indices that alternate between the two sides. If fromLeft is
   * true, it runs from the unmatched left vertex to the unmatched right vertex. If fromLeft is
   * false, it runs from the unmatched right vertex back to the unmatched left vertex. Reading the
   * path from its left end, each consecutive (left, right) pair becomes a matched edge. These pairs
   * replace the edges that previously matched the intermediate vertices.
   *
   * @param augmentingPath alternating path of vertex indices; null or empty is a no-op
   * @param fromLeft whether the path starts on the left side
   * @throws IllegalArgumentException if the path has an odd number of vertices
   */
  void expand(List<Integer> augmentingPath, boolean fromLeft) {
    if ((augmentingPath == null) || augmentingPath.isEmpty()) {
      return;
    }
    if (augmentingPath.size() % 2 != 0) {
      throw new IllegalArgumentException(
          "augmenting path must have an even number of vertices, got " + augmentingPath.size());
    }
    int firstVertexIndex = augmentingPath.get(0);
    int lastVertexIndex = augmentingPath.get(augmentingPath.size() - 1);
    int leftStartVertexIndex = fromLeft ? firstVertexIndex : lastVertexIndex;
    int rightEndVertexIndex = fromLeft ? lastVertexIndex : firstVertexIndex;
    assert !hasSaturated(true, leftStartVertexIndex);
    assert !hasSaturated(false, rightEndVertexIndex);
    ListIterator<Integer> it =
        fromLeft ? augmentingPath.listIterator() : augmentingPath.listIterator(augmentingPath.size());
    while (fromLeft ? it.hasNext() : it.hasPrevious()) {
      // Both directions visit the path starting at its left end, so the first vertex of each
      // pair is always the left one.
      int leftVertexIndex = fromLeft ? it.next() : it.previous();
      int rightVertexIndex = fromLeft ? it.next() : it.previous();
      assert (leftVertexIndex == leftStartVertexIndex) || hasSaturated(true, leftVertexIndex);
      assert (rightVertexIndex == rightEndVertexIndex) || hasSaturated(false, rightVertexIndex);
      assert (leftVertexIndex >= 0) && (leftVertexIndex < leftVertexCount);
      assert (rightVertexIndex >= 0) && (rightVertexIndex < rightVertexCount);
      // The new edge must not already be matched. Objects.equals avoids unboxing a null lookup
      // for the free end vertices.
      assert !Objects.equals(leftVertexMap.get(leftVertexIndex), rightVertexIndex);
      assert !Objects.equals(rightVertexMap.get(rightVertexIndex), leftVertexIndex);
      leftVertexMap.put(leftVertexIndex, rightVertexIndex);
      rightVertexMap.put(rightVertexIndex, leftVertexIndex);
    }
  }

  /**
   * Returns the vertex matched to the given vertex.
   *
   * <p>The vertex must be saturated. For an unmatched vertex the map lookup yields null, and
   * unboxing it throws NullPointerException.
   */
  int getMatchedNeighbour(boolean isLeftVertex, int vertexIndex) {
    if (isLeftVertex) {
      return leftVertexMap.get(vertexIndex);
    }
    else {
      return rightVertexMap.get(vertexIndex);
    }
  }

  /** Returns whether the given vertex is covered by this matching. */
  boolean hasSaturated(boolean isLeftVertex, int vertexIndex) {
    if (isLeftVertex) {
      return leftVertexMap.containsKey(vertexIndex);
    }
    else {
      return rightVertexMap.containsKey(vertexIndex);
    }
  }

}