
AlbumPipeline.java
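A Beam batch pipeline that exports the Chinook Album table from Postgres to BigQuery: JdbcIO reads every row into a HashMap, a DoFn converts each map to a TableRow (tagging failures separately), failed conversions are logged to a text file on Cloud Storage, and successful rows are appended to a BigQuery table.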

package com.sandboxws.chinook;

import com.google.api.services.bigquery.model.TableRow;
import com.sandboxws.beam.AppOptions;
import com.sandboxws.beam.coders.TableRowCoder;
import com.sandboxws.chinook.bigquery.schema.AlbumTableSchema;

import java.sql.ResultSet;
import java.util.HashMap;
import java.util.Map;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.NullableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO;
import org.apache.beam.sdk.io.gcp.bigquery.TableRowJsonCoder;
import org.apache.beam.sdk.io.jdbc.JdbcIO;
import org.apache.beam.sdk.io.jdbc.JdbcIO.RowMapper;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.ToString;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.postgresql.ds.PGSimpleDataSource;

/**
 * Albums batch pipeline.
 *
 * @author Ahmed El.Hussaini
 */
@SuppressWarnings("serial")
public class AlbumPipeline {
  public static void main(String[] args) {
    String tableName = "Album";
    String pkName = "AlbumId";

    // Prepare and parse pipeline options
    PipelineOptionsFactory.register(AppOptions.class);
    AppOptions options = PipelineOptionsFactory.fromArgs(args).withValidation().as(AppOptions.class);

    // create pipeline
    Pipeline pipeline = Pipeline.create(options);

    PGSimpleDataSource pgDataSource = getPostgresDataSource(options);

    // Fetch all albums from database
    PCollection<HashMap<String, Object>> rows = pipeline.apply(
      "Read Albums from PG",
      JdbcIO.<HashMap<String, Object>>read()
        .withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(pgDataSource))
        .withCoder(TableRowCoder.of())
        .withRowMapper(new RowMapper<HashMap<String, Object>>() {
          @Override
          public HashMap<String, Object> mapRow(ResultSet resultSet) throws Exception {
            return TableRowMapper.asMap(resultSet, tableName, pkName);
          }
        })
        .withQuery("select * from public.\"Album\""));

    // Convert each HashMap to a TableRow, tagging successes and failures via output tags.
    final TupleTag<TableRow> bqTableRowsSuccessTag = new TupleTag<TableRow>() {};
    final TupleTag<String> bqTableRowsFailedTag = new TupleTag<String>() {};
    PCollectionTuple bqTableRowsTuple = rows.apply(
      "HashMap to TableRow",
      ParDo.of(new HashMapToTableRowFn(bqTableRowsSuccessTag, bqTableRowsFailedTag))
        .withOutputTags(bqTableRowsSuccessTag, TupleTagList.of(bqTableRowsFailedTag))
    );

    PCollection<String> bqTableRowsErrors = bqTableRowsTuple
      .get(bqTableRowsFailedTag)
      .setCoder(NullableCoder.of(StringUtf8Coder.of()));
    // Log errors to a text file under cloud storage.
    bqTableRowsErrors
      .apply(
        "Write Errors",
        TextIO.write().to("gs://beam-tutorial/album_errors.txt")
      );

    // Write to BigQuery
    PCollection<TableRow> bqTableRows = bqTableRowsTuple
      .get(bqTableRowsSuccessTag)
      .setCoder(NullableCoder.of(TableRowJsonCoder.of()));
    bqTableRows.apply("Write to BigQuery",
      BigQueryIO.writeTableRows()
        .to(options.getOutputTable()) // Passed as an argument from the command line
        .withSchema(AlbumTableSchema.schema()) // The schema for the BigQuery table
        .ignoreUnknownValues() // Ignore any values passed but not defined on the table schema
        .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_APPEND) // Append to the BigQuery table.
        .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) // Create the BigQuery table if it doesn't exist
    );

    // Run the pipeline
    pipeline.run().waitUntilFinish();
  }

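  /**
   * Converts a HashMap of column name/value pairs into a BigQuery TableRow.
   * Conversion failures are emitted to the failed tuple tag as error strings.
   */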
  private static class HashMapToTableRowFn extends DoFn<HashMap<String, Object>, TableRow> {
    private static final long serialVersionUID = 1L;
    private final TupleTag<TableRow> bqTableRowsSuccessTag;
    private final TupleTag<String> bqTableRowsFailedTag;

    public HashMapToTableRowFn(TupleTag<TableRow> bqTableRowsSuccessTag, TupleTag<String> bqTableRowsFailedTag) {
      this.bqTableRowsSuccessTag = bqTableRowsSuccessTag;
      this.bqTableRowsFailedTag = bqTableRowsFailedTag;
    }

    @ProcessElement
    public void processElement(ProcessContext c) {
      try {
        HashMap<String, Object> map = c.element();
        TableRow tableRow = new TableRow();
        for (Map.Entry<String, Object> entry : map.entrySet()) {
          tableRow.set(entry.getKey(), entry.getValue());
        }

        c.output(tableRow);
      } catch (Exception e) {
        c.output(bqTableRowsFailedTag, e.toString());
      }
    }
  }

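  /**
   * Builds a Postgres data source from the connection options passed on the command line.
   */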
  private static PGSimpleDataSource getPostgresDataSource(AppOptions options) {
    PGSimpleDataSource dataSource = new PGSimpleDataSource();
    dataSource.setDatabaseName(options.getPgDatabase());
    dataSource.setServerName(options.getPgHost());
    dataSource.setPortNumber(options.getPgPort());
    dataSource.setUser(options.getPgUsername());
    dataSource.setPassword(options.getPgPassword());

    return dataSource;
  }
}
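
The pipeline leans on a few classes that aren't shown here; the sketches below are reconstructions, not the author's actual code. First, TableRowMapper.asMap, assuming it simply walks the ResultSet metadata into a map keyed by column name; the real class presumably does more with tableName and pkName, which are unused in this sketch.

TableRowMapper.java (sketch)

package com.sandboxws.chinook;

import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.util.HashMap;

public class TableRowMapper {
  // Hypothetical reconstruction; tableName and pkName are ignored here,
  // though the real implementation likely uses them.
  public static HashMap<String, Object> asMap(
      ResultSet resultSet, String tableName, String pkName) throws Exception {
    HashMap<String, Object> map = new HashMap<>();
    ResultSetMetaData meta = resultSet.getMetaData();
    // Copy every column of the current row into the map, keyed by column label.
    for (int i = 1; i <= meta.getColumnCount(); i++) {
      map.put(meta.getColumnLabel(i), resultSet.getObject(i));
    }
    return map;
  }
}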
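AppOptions is likewise not shown. A minimal sketch assuming standard Beam PipelineOptions conventions, with the getters inferred from the calls in the pipeline (descriptions and the required-validation choice are assumptions):

AppOptions.java (sketch)

package com.sandboxws.beam;

import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.Validation;

public interface AppOptions extends PipelineOptions {
  @Description("Postgres host")
  String getPgHost();
  void setPgHost(String value);

  @Description("Postgres port")
  int getPgPort();
  void setPgPort(int value);

  @Description("Postgres database name")
  String getPgDatabase();
  void setPgDatabase(String value);

  @Description("Postgres username")
  String getPgUsername();
  void setPgUsername(String value);

  @Description("Postgres password")
  String getPgPassword();
  void setPgPassword(String value);

  @Description("BigQuery output table, e.g. project:dataset.table")
  @Validation.Required
  String getOutputTable();
  void setOutputTable(String value);
}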
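AlbumTableSchema, the schema passed to BigQueryIO, can be sketched under the assumption that it mirrors the three columns of the Chinook Album table (AlbumId, Title, ArtistId):

AlbumTableSchema.java (sketch)

package com.sandboxws.chinook.bigquery.schema;

import com.google.api.services.bigquery.model.TableFieldSchema;
import com.google.api.services.bigquery.model.TableSchema;
import java.util.ArrayList;
import java.util.List;

public class AlbumTableSchema {
  // Assumed field list; the actual schema may include additional columns or modes.
  public static TableSchema schema() {
    List<TableFieldSchema> fields = new ArrayList<>();
    fields.add(new TableFieldSchema().setName("AlbumId").setType("INTEGER"));
    fields.add(new TableFieldSchema().setName("Title").setType("STRING"));
    fields.add(new TableFieldSchema().setName("ArtistId").setType("INTEGER"));
    return new TableSchema().setFields(fields);
  }
}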
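With options like those above, a local run might look like the following; the flag names follow from the getters by Beam convention, and the project, dataset, credentials, and temp location are placeholders:

mvn compile exec:java \
  -Dexec.mainClass=com.sandboxws.chinook.AlbumPipeline \
  -Dexec.args="--pgHost=localhost --pgPort=5432 --pgDatabase=chinook \
    --pgUsername=postgres --pgPassword=secret \
    --outputTable=my-project:chinook.Album \
    --tempLocation=gs://beam-tutorial/tmp"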