ericwen229
4/10/2018 - 11:56 AM

## Java implementation of the Hungarian algorithm for finding a maximum matching in an unweighted bipartite graph

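Java implementation of the Hungarian algorithm (Kuhn's augmenting-path method) for finding a maximum matching in an unweighted bipartite graph. The program reads, from standard input, the number of left and right vertices, then for each left vertex its neighbour count followed by the 0-based indices of its right-side neighbours. It prints the graph's adjacency lists, the matched `left: right` pairs, and the size of the matching.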

```java
import java.io.PrintStream;
import java.util.*;

/**
 * Command-line driver: reads a bipartite graph from standard input, prints it, then prints a
 * maximum matching and its size.
 *
 * <p>Input format (whitespace-separated integers):
 * <pre>
 * leftVertexCount rightVertexCount
 * then, for each left vertex i = 0 .. leftVertexCount - 1:
 *     neighbourCount r_1 r_2 ... r_neighbourCount    (0-based right-vertex indices)
 * </pre>
 */
public class Hungarian {

    public static void main(String[] args) {
        BinGraph graph;
        // Closing the Scanner also closes System.in, which is fine: all input is consumed here.
        try (Scanner scanner = new Scanner(System.in)) {
            graph = readGraph(scanner);
        }

        System.out.println();
        System.out.println("graph:");
        graph.print(System.out);
        System.out.println();

        BinMatch maxMatch = graph.findMaxMatch();

        System.out.println("max match:");
        maxMatch.print(System.out);
        System.out.println(String.format("size: %d", maxMatch.size()));
    }

    /**
     * Parses a graph in the format described on the class and adds every listed
     * (left, right) pair as an edge.
     *
     * @param scanner source of whitespace-separated integers
     * @return the parsed graph
     */
    private static BinGraph readGraph(Scanner scanner) {
        int leftVertexCount = scanner.nextInt();
        int rightVertexCount = scanner.nextInt();
        BinGraph graph = new BinGraph(leftVertexCount, rightVertexCount);
        for (int iLeftVertex = 0; iLeftVertex < leftVertexCount; ++iLeftVertex) {
            int neighbourCount = scanner.nextInt();
            for (int iNeighbour = 0; iNeighbour < neighbourCount; ++iNeighbour) {
                int iRightVertex = scanner.nextInt();
                // The original read this index but never recorded the edge, leaving the graph empty.
                graph.addEdge(iLeftVertex, iRightVertex);
            }
        }
        return graph;
    }

}

/**
 * An unweighted bipartite graph stored as a left-by-right adjacency matrix, with a
 * maximum-matching search that repeatedly finds augmenting paths by depth-first search.
 *
 * <p>Left vertices are indexed {@code 0 .. leftVertexCount - 1} and right vertices
 * {@code 0 .. rightVertexCount - 1}; the two index ranges are independent.
 */
class BinGraph {

    /**
     * Iterates over the neighbours of a single vertex in increasing index order. For a left
     * vertex it yields right-vertex indices; for a right vertex it yields left-vertex indices.
     */
    static class NeighbourIterator implements Iterator<Integer> {

        private final BinGraph graph;
        private final boolean isLeftVertex;
        private final int vertexIndex;
        // Index of the next neighbour to return; equals the opposite side's size once exhausted.
        private int currNeighbourIndex;

        NeighbourIterator(BinGraph graph, boolean isLeftVertex, int vertexIndex) {
            this.graph = graph;
            this.isLeftVertex = isLeftVertex;
            this.vertexIndex = vertexIndex;
            this.currNeighbourIndex = 0;
            proceedIterator();
        }

        /** Returns the number of vertices on the side being iterated over. */
        private int neighbourSideSize() {
            return isLeftVertex ? graph.rightVertexCount : graph.leftVertexCount;
        }

        /** Returns whether the opposite-side vertex {@code neighbourIndex} is adjacent to this vertex. */
        private boolean isNeighbour(int neighbourIndex) {
            return isLeftVertex
                    ? graph.connectionMatrix[vertexIndex][neighbourIndex]
                    : graph.connectionMatrix[neighbourIndex][vertexIndex];
        }

        /** Advances {@code currNeighbourIndex} to the next adjacent vertex, or to the end. */
        private void proceedIterator() {
            int end = neighbourSideSize();
            while (currNeighbourIndex < end && !isNeighbour(currNeighbourIndex)) {
                ++currNeighbourIndex;
            }
        }

        @Override
        public boolean hasNext() {
            return currNeighbourIndex < neighbourSideSize();
        }

        /**
         * Returns the next neighbour index.
         *
         * @throws NoSuchElementException if there are no more neighbours
         */
        @Override
        public Integer next() {
            if (!hasNext()) {
                throw new NoSuchElementException("no more neighbours of vertex " + vertexIndex);
            }
            int neighbourIndex = currNeighbourIndex;
            ++currNeighbourIndex;
            proceedIterator();
            return neighbourIndex;
        }

    }

    private final int leftVertexCount;
    private final int rightVertexCount;
    // connectionMatrix[l][r] is true iff left vertex l and right vertex r are adjacent.
    private final boolean[][] connectionMatrix;

    /** Creates a graph with the given vertex counts and no edges. */
    BinGraph(int leftVertexCount, int rightVertexCount) {
        this.leftVertexCount = leftVertexCount;
        this.rightVertexCount = rightVertexCount;
        // Java zero-initialises boolean arrays, so every entry already starts as false (no edge).
        this.connectionMatrix = new boolean[leftVertexCount][rightVertexCount];
    }

    /**
     * Adds an undirected edge between a left and a right vertex (idempotent).
     *
     * @throws ArrayIndexOutOfBoundsException if either index is out of range
     */
    void addEdge(int leftVertexIndex, int rightVertexIndex) {
        connectionMatrix[leftVertexIndex][rightVertexIndex] = true;
    }

    /** Prints one line per left vertex: its index, a colon, then its right-side neighbours. */
    void print(PrintStream out) {
        for (int iLeftVertex = 0; iLeftVertex < leftVertexCount; ++iLeftVertex) {
            out.print(iLeftVertex + ":");
            NeighbourIterator it = createNeighbourIterator(true, iLeftVertex);
            while (it.hasNext()) {
                out.print(" " + it.next());
            }
            out.println();
        }
    }

    /** Returns an iterator over the neighbours of the given left or right vertex. */
    NeighbourIterator createNeighbourIterator(boolean isLeftVertex, int vertexIndex) {
        return new NeighbourIterator(this, isLeftVertex, vertexIndex);
    }

    /**
     * Depth-first search for an alternating path from {@code vertexIndex} to an unmatched right
     * vertex. From a left vertex any edge to an unvisited right vertex is tried; from a matched
     * right vertex the search continues only along its matched edge.
     *
     * @param isLeftVertex whether {@code vertexIndex} is a left vertex
     * @param vertexIndex current vertex
     * @param currMaxMatch the current matching (not modified)
     * @param pathRecorder on success, holds the path from the start vertex, alternating left and
     *     right indices and ending with the free right vertex; on failure, every index this call
     *     pushed has been popped again
     * @param vertexMarker visited vertices; left vertex i is stored as i, right vertex j as -(1 + j)
     * @return true if an augmenting path was found
     */
    private boolean dfs(boolean isLeftVertex,
                        int vertexIndex,
                        BinMatch currMaxMatch,
                        List<Integer> pathRecorder,
                        Set<Integer> vertexMarker) {
        vertexMarker.add(isLeftVertex ? vertexIndex : -(1 + vertexIndex));
        pathRecorder.add(vertexIndex);
        if (isLeftVertex) {
            // At a left vertex: try every unvisited right neighbour.
            NeighbourIterator it = createNeighbourIterator(true, vertexIndex);
            while (it.hasNext()) {
                int neighbourIndex = it.next();
                if (!vertexMarker.contains(-(1 + neighbourIndex))
                        && dfs(false, neighbourIndex, currMaxMatch, pathRecorder, vertexMarker)) {
                    return true;
                }
            }
        } else {
            // At a right vertex: a free one ends the augmenting path...
            if (!currMaxMatch.hasSaturated(false, vertexIndex)) {
                return true;
            }
            // ...otherwise continue through its matched left partner.
            int matchedNeighbourIndex = currMaxMatch.getMatchedNeighbour(false, vertexIndex);
            if (!vertexMarker.contains(matchedNeighbourIndex)
                    && dfs(true, matchedNeighbourIndex, currMaxMatch, pathRecorder, vertexMarker)) {
                return true;
            }
        }
        pathRecorder.remove(pathRecorder.size() - 1);
        return false;
    }

    /**
     * Finds an augmenting path with respect to {@code currMaxMatch}.
     *
     * @return the path as alternating vertex indices, starting at a free left vertex and ending
     *     at a free right vertex (even length), or {@code null} if the matching is maximum
     */
    List<Integer> findAugmentingPath(BinMatch currMaxMatch) {
        // Shared across start vertices on purpose: a failed DFS explores every vertex reachable
        // from its start, none of which can reach a free right vertex while the matching is unchanged.
        Set<Integer> vertexMarker = new HashSet<>();
        // A failed dfs pops everything it pushed, so one list can be reused for every start vertex.
        List<Integer> pathRecorder = new ArrayList<>();
        for (int iLeftVertex = 0; iLeftVertex < leftVertexCount; ++iLeftVertex) {
            if (currMaxMatch.hasSaturated(true, iLeftVertex)) {
                continue;
            }
            if (dfs(true, iLeftVertex, currMaxMatch, pathRecorder, vertexMarker)) {
                return pathRecorder;
            }
        }
        return null;
    }

    /**
     * Computes a maximum matching by augmenting an initially empty matching until no augmenting
     * path remains. Each augmentation grows the matching by exactly one edge.
     */
    BinMatch findMaxMatch() {
        BinMatch currMaxMatch = new BinMatch(leftVertexCount, rightVertexCount);
        List<Integer> augmentingPath = findAugmentingPath(currMaxMatch);
        while (augmentingPath != null) {
            currMaxMatch.expand(augmentingPath, true);
            augmentingPath = findAugmentingPath(currMaxMatch);
        }
        return currMaxMatch;
    }

}

/**
 * A matching in a bipartite graph, stored as two mutually inverse maps between left-vertex and
 * right-vertex indices.
 */
class BinMatch {

    private final int leftVertexCount;
    private final int rightVertexCount;
    // left vertex index -> matched right vertex index
    private final Map<Integer, Integer> leftVertexMap = new HashMap<>();
    // right vertex index -> matched left vertex index
    private final Map<Integer, Integer> rightVertexMap = new HashMap<>();

    /** Creates an empty matching over the given vertex counts. */
    BinMatch(int leftVertexCount, int rightVertexCount) {
        this.leftVertexCount = leftVertexCount;
        this.rightVertexCount = rightVertexCount;
    }

    /** Returns the number of matched pairs. */
    int size() {
        assert leftVertexMap.size() == rightVertexMap.size();
        return leftVertexMap.size();
    }

    /** Prints one {@code "left: right"} line per matched pair, in increasing left-index order. */
    void print(PrintStream out) {
        new TreeMap<>(leftVertexMap).forEach((k, v) -> out.println(k + ": " + v));
    }

    /**
     * Flips the matching along an augmenting path, growing it by one pair.
     *
     * <p>With {@code fromLeft == true} the path is {@code [L0, R0, L1, R1, ..., Lk, Rk]}, starting
     * at a free left vertex and ending at a free right vertex; with {@code fromLeft == false} it is
     * the same sequence reversed. Every (Li, Ri) pair becomes matched.
     *
     * @param augmentingPath alternating vertex indices; {@code null} or empty is a no-op
     * @param fromLeft whether the path is listed starting from its left end
     * @throws IllegalArgumentException if the path has odd length
     */
    void expand(List<Integer> augmentingPath, boolean fromLeft) {
        if (augmentingPath == null || augmentingPath.isEmpty()) {
            return;
        }
        if (augmentingPath.size() % 2 != 0) {
            throw new IllegalArgumentException(
                    "augmenting path must have even length, got " + augmentingPath.size());
        }
        int firstVertexIndex = augmentingPath.get(0);
        int lastVertexIndex = augmentingPath.get(augmentingPath.size() - 1);
        int leftStartVertexIndex = fromLeft ? firstVertexIndex : lastVertexIndex;
        int rightEndVertexIndex = fromLeft ? lastVertexIndex : firstVertexIndex;
        assert !hasSaturated(true, leftStartVertexIndex);
        assert !hasSaturated(false, rightEndVertexIndex);
        // Always walk from the left start vertex so elements arrive as (left, right) pairs.
        ListIterator<Integer> it = fromLeft
                ? augmentingPath.listIterator()
                : augmentingPath.listIterator(augmentingPath.size());
        while (fromLeft ? it.hasNext() : it.hasPrevious()) {
            int leftVertexIndex = fromLeft ? it.next() : it.previous();
            int rightVertexIndex = fromLeft ? it.next() : it.previous();
            assert leftVertexIndex == leftStartVertexIndex || hasSaturated(true, leftVertexIndex);
            assert rightVertexIndex == rightEndVertexIndex || hasSaturated(false, rightVertexIndex);
            assert leftVertexIndex >= 0 && leftVertexIndex < leftVertexCount;
            assert rightVertexIndex >= 0 && rightVertexIndex < rightVertexCount;
            // Each (left, right) pair on the path must currently be an unmatched edge.
            assert !Objects.equals(leftVertexMap.get(leftVertexIndex), rightVertexIndex);
            assert !Objects.equals(rightVertexMap.get(rightVertexIndex), leftVertexIndex);
            // Overwriting both directions also drops the previous (right -> next left) pairing.
            leftVertexMap.put(leftVertexIndex, rightVertexIndex);
            rightVertexMap.put(rightVertexIndex, leftVertexIndex);
        }
    }

    /**
     * Returns the partner of a matched vertex.
     *
     * @throws IllegalArgumentException if the vertex is not matched
     */
    int getMatchedNeighbour(boolean isLeftVertex, int vertexIndex) {
        Integer neighbour = (isLeftVertex ? leftVertexMap : rightVertexMap).get(vertexIndex);
        if (neighbour == null) {
            throw new IllegalArgumentException(
                    (isLeftVertex ? "left" : "right") + " vertex " + vertexIndex + " is not matched");
        }
        return neighbour;
    }

    /** Returns whether the given vertex is covered by the matching. */
    boolean hasSaturated(boolean isLeftVertex, int vertexIndex) {
        return (isLeftVertex ? leftVertexMap : rightVertexMap).containsKey(vertexIndex);
    }

}