taozhuo
11/4/2015 - 9:39 PM

subgraph generation for giraph

subgraph generation for giraph

package com.adsymp.dpp.giraph;

import java.io.IOException;
import java.util.Arrays;

import org.apache.giraph.edge.Edge;
import org.apache.giraph.graph.BasicComputation;
import org.apache.giraph.graph.GraphState;
import org.apache.giraph.graph.GraphTaskManager;
import org.apache.giraph.graph.Vertex;
import org.apache.giraph.partition.PartitionBalancer;
import org.apache.giraph.worker.WorkerContext;
import org.apache.giraph.worker.WorkerGlobalCommUsage;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.giraph.aggregators.LongSumAggregator;
import org.apache.giraph.master.DefaultMasterCompute;

import org.apache.hadoop.io.IntWritable;
import org.apache.giraph.benchmark.*;
import org.apache.giraph.comm.WorkerClientRequestProcessor;
import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;

/**
 * Generates a subgraph by expanding outward from a set of seed vertices.
 *
 * <p>The vertex value is an {@code int} bitmask: bit {@code i} is set once the vertex has been
 * reached from the {@code i}-th configured seed. In superstep 0 each seed vertex sets its own
 * bit(s) and messages its neighbors. In every later superstep a vertex ORs all incoming masks
 * into its value and, if the value changed and the superstep is below the step limit, forwards
 * the new mask along all of its edges. Each seed's bit therefore spreads breadth-first, one hop
 * per superstep.
 *
 * <p>Configuration keys:
 * <ul>
 *   <li>{@code com.adsymp.seeds} - '^'-separated list of seed vertex ids (1 to 32 ids);</li>
 *   <li>{@code com.adsymp.steps} - supersteps during which updates keep propagating
 *       (default 30);</li>
 *   <li>{@code com.adsymp.nodes} - all vertices halt once the {@code "sum"} aggregate reaches
 *       this value (default 10000000).</li>
 * </ul>
 */
public class SubgraphGenerationComputation
        extends BasicComputation<LongWritable, IntWritable, NullWritable, IntWritable>
{
    private static final String CONF_SEEDS = "com.adsymp.seeds";
    private static final String CONF_NUM_STEPS = "com.adsymp.steps";
    private static final String CONF_MAX_NODES = "com.adsymp.nodes";

    /** Default for {@link #CONF_NUM_STEPS}. */
    private static final int DEFAULT_NUM_STEPS = 30;
    /** Default for {@link #CONF_MAX_NODES}. */
    private static final long DEFAULT_MAX_NODES = 10000000L;

    /** Name of the persistent aggregator shared by the master and the workers. */
    private static final String AGG_SUM = "sum";

    /** Seed vertex ids; the seed at index i owns bit (1 << i) of the vertex mask. */
    Long[] seeds;
    /** Propagation stops once the superstep number reaches this value. */
    int steps;
    /** Vertices halt once the aggregated update count reaches this value. */
    long numNodes;

    /**
     * Master side: registers the {@code "sum"} aggregator as persistent, so its value
     * accumulates across supersteps instead of being reset each superstep.
     */
    public static class SubgraphGenerationMasterCompute extends DefaultMasterCompute
    {
        @Override
        public void initialize() throws InstantiationException,
        IllegalAccessException {
            registerPersistentAggregator(AGG_SUM, LongSumAggregator.class);
        }
    }

    /**
     * Reads the seed list, step limit and node limit from the job configuration.
     *
     * @throws IllegalArgumentException if the seed list is empty, contains a non-numeric id,
     *         or has more than 32 ids (each seed needs its own bit in the int mask)
     */
    @Override
    public void initialize(GraphState graphState,
            WorkerClientRequestProcessor<LongWritable, IntWritable, NullWritable> workerClientRequestProcessor,
            GraphTaskManager<LongWritable, IntWritable, NullWritable> graphTaskManager,
            WorkerGlobalCommUsage workerGlobalCommUsage, WorkerContext workerContext)
    {
        super.initialize(graphState, workerClientRequestProcessor, graphTaskManager, workerGlobalCommUsage,
                workerContext);
        @SuppressWarnings("rawtypes")
        ImmutableClassesGiraphConfiguration<WritableComparable, Writable, Writable> config = workerContext.getConf();
        String listSeeds = config.get(CONF_SEEDS, "");
        steps = config.getInt(CONF_NUM_STEPS, DEFAULT_NUM_STEPS);
        numNodes = config.getLong(CONF_MAX_NODES, DEFAULT_MAX_NODES);
        seeds = parseSeeds(listSeeds);
        // Balance partitioning based on edge counts. (EGDE_BALANCE_ALGORITHM is spelled this way
        // in Giraph's PartitionBalancer.)
        // NOTE(review): this mutates the worker's configuration each time initialize() runs;
        // whether the partition balancer still reads it at this point is not established here.
        // It may need to be set on the job or master configuration instead. Verify.
        config.set(
                PartitionBalancer.PARTITION_BALANCE_ALGORITHM,
                PartitionBalancer.EGDE_BALANCE_ALGORITHM);
    }

    /**
     * Parses a '^'-separated list of seed vertex ids.
     *
     * <p>Surrounding whitespace and empty tokens (for example from a leading '^') are ignored.
     *
     * @param listSeeds raw value of {@link #CONF_SEEDS}
     * @return the seed ids in configuration order
     * @throws IllegalArgumentException if no ids are given, an id is not a valid long, or more
     *         than {@link Integer#SIZE} ids are given
     */
    private static Long[] parseSeeds(String listSeeds)
    {
        String[] tokens = listSeeds.split("\\^");
        Long[] parsed = new Long[tokens.length];
        int count = 0;
        for (String token : tokens)
        {
            String trimmed = token.trim();
            if (trimmed.isEmpty())
            {
                continue;
            }
            try
            {
                long id = Long.parseLong(trimmed);
                parsed[count++] = id;
            }
            catch (NumberFormatException e)
            {
                throw new IllegalArgumentException(
                        "Invalid seed id '" + trimmed + "' in " + CONF_SEEDS + "=" + listSeeds, e);
            }
        }
        if (count == 0)
        {
            throw new IllegalArgumentException(
                    "No seed ids configured; set " + CONF_SEEDS + " to a '^'-separated list of vertex ids");
        }
        // The mask is an int and Java shifts use only the low 5 bits of the distance, so a
        // 33rd seed would silently share bit 0 with the first seed.
        if (count > Integer.SIZE)
        {
            throw new IllegalArgumentException(
                    "At most " + Integer.SIZE + " seeds are supported, got " + count + " in " + CONF_SEEDS);
        }
        return Arrays.copyOf(parsed, count);
    }

    /**
     * Returns the mask bits owned by the given vertex id: bit i is set for every seed index i
     * whose id equals {@code id}. Returns 0 if the vertex is not a seed.
     */
    private int seedMask(long id)
    {
        int bits = 0;
        for (int i = 0; i < seeds.length; i++)
        {
            if (seeds[i].longValue() == id)
            {
                bits |= 1 << i;
            }
        }
        return bits;
    }

    /**
     * Superstep 0 marks seed vertices and notifies their neighbors. Later supersteps merge the
     * incoming masks and forward any change while {@code superstep < steps}. Every vertex votes
     * to halt each superstep and is woken only by new messages.
     */
    @Override
    public void compute(Vertex<LongWritable, IntWritable, NullWritable> vertex, Iterable<IntWritable> messages)
            throws IOException
    {
        long ssNo = getSuperstep();
        int oldMask = vertex.getValue().get();
        // The aggregate holds the updates accumulated up to the previous superstep.
        long totalNodes = ((LongWritable) getAggregatedValue(AGG_SUM)).get();
        if (totalNodes >= numNodes)
        {
            vertex.voteToHalt();
            return;
        }
        if (ssNo == 0)
        {
            int mask = oldMask | seedMask(vertex.getId().get());
            // Only seed vertices change here; they start the propagation.
            if (mask != oldMask)
            {
                vertex.setValue(new IntWritable(mask));
                sendMessageToAllEdges(vertex, vertex.getValue());
            }
            vertex.voteToHalt();
            return;
        }
        int mask = oldMask;
        for (IntWritable message : messages)
        {
            mask |= message.get();
        }
        if (mask != oldMask)
        {
            vertex.setValue(new IntWritable(mask));
            // NOTE(review): the aggregate counts mask *updates* that are forwarded. A vertex
            // reached by different seeds in different supersteps is counted once per update.
            // Updates made at ssNo >= steps are not counted, and neither are seeds in
            // superstep 0. The node limit therefore bounds updates, not distinct vertices.
            // Confirm this is intended.
            if (ssNo < steps)
            {
                getContext().getCounter("SS" + ssNo, "NODES").increment(1);
                aggregate(AGG_SUM, new LongWritable(1));
                sendMessageToAllEdges(vertex, vertex.getValue());
            }
        }
        vertex.voteToHalt();
    }
}