Apache Spark MultilayerPerceptronClassifier setting features - java

I am trying to do multi-class classification using org.apache.spark.ml.classification.MultilayerPerceptronClassifier. The code I used is given below. I have 262 features and need to pass the feature columns to the MultilayerPerceptronClassifier. Can someone explain how to do this?
I can use the setFeaturesCol() method, but it accepts only a single column name, which seems infeasible with 262 features.
import org.apache.commons.lang3.ArrayUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.DataFrame;
public class NN {
final static String RESPONSE_VARIABLE = "Activity";
public static void main(String args[]){
// Load training data
SparkConf sparkConf = new SparkConf();
sparkConf.setAppName("test-client").setMaster("local[2]");
sparkConf.set("spark.driver.allowMultipleContexts", "true");
JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);
SQLContext sqlContext = new SQLContext(javaSparkContext);
// Convert data in csv format to Spark data frame
DataFrame trainDataFrame = sqlContext.read().format("com.databricks.spark.csv")
.option("inferSchema", "true")
.option("header", "true")
.load("/home/thamali/Desktop/Project/csv/libsvm/train.csv");
DataFrame testDataFrame = sqlContext.read().format("com.databricks.spark.csv")
.option("inferSchema", "true")
.option("header", "true")
.load("/home/thamali/Desktop/Project/csv/libsvm/train.csv");
String [] predictors = trainDataFrame.columns();
predictors = ArrayUtils.removeElement(predictors, RESPONSE_VARIABLE);
// specify layers for the neural network:
// input layer of size 262 (features), two intermediate layers of size 50 and 40
// and output of size 12 (classes)
int[] layers = new int[] {262, 50, 40, 12};
// create the trainer and set its parameters
MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier()
.setLayers(layers)
.setBlockSize(128)
.setSeed(1234L)
.setMaxIter(100);
// train the model
MultilayerPerceptronClassificationModel model = trainer.fit(trainDataFrame);
// compute accuracy on the test set
DataFrame result = model.transform(testDataFrame);
DataFrame predictionAndLabels = result.select("prediction", "label");
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
.setMetricName("accuracy");
System.out.println("Accuracy = " + evaluator.evaluate(predictionAndLabels));
}
}

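You can use Spark's VectorAssembler (org.apache.spark.ml.feature.VectorAssembler) to combine all the predictor columns into a single vector column. Pass the column names to setInputCols() (the predictors array above), name the result with setOutputCol(), and give that one column name to setFeaturesCol() on the classifier.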

Related

Iterating over RelationalGroupedDataset to find average and count of each key in Java

I have a Dataset<Row> which is built by reading a CSV file. I want to do the group by on one of the fields in CSV and then merge all the records with the same name and do some other computation over the merged Dataset.
My input CSV file looks like this
name,math_marks,science_marks
Ajay,10,20
Ram,15,25
Sita,18,30
Ajay,20,30
Sita,12,10
Sita,20,20
Ram,25,45
I want the final output to be something like this
name,math_avg,science_avg,count_of_records
Ajay,15,25,2
Ram,20,35,2
Sita,16.67,20,3
My initial code in Java is below:
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.RelationalGroupedDataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import java.util.List;
import java.util.stream.Collectors;
@Slf4j
public class ReadCSVFiles {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName(ReadCSVFiles.class.getName()).setMaster("local");
// create Spark Context
SparkContext context = new SparkContext(conf);
// create spark Session
SparkSession sparkSession = new SparkSession(context);
context.setLogLevel("INFO");
Dataset<Row> df = sparkSession.read()
.format("csv")
.option("header", true)
.option("inferSchema", true)
.load("/Users/ajaychoudhary/Downloads/marksInputFile.csv");
System.out.println("========== Print Schema ============");
df.printSchema();
System.out.println("========== Print Data ==============");
df.show();
System.out.println("========== Print name of dataframe ==============");
df.select("name").show();
RelationalGroupedDataset relationalGroupedDataset = df.groupBy("name");
List<String> relationalGroupedDatasetRows = relationalGroupedDataset.count().collectAsList().stream()
.map(a -> a.mkString("::")).collect(Collectors.toList());
log.info("relationalGroupedDatasetRows is = {} ", relationalGroupedDatasetRows);
}
}
This is the output I currently get. It gives the record count for each name, but I am unable to find the average of the marks.
relationalGroupedDatasetRows is = [Ram::2, Ajay::2, Sita::3]
Also, I would like to understand whether using groupBy like this is a reasonable approach, or whether there is a better alternative.
I don't know much about this, but you are using the 'count' method, which "counts the number of rows for each group". Try the 'avg' method instead, which "returns average for each group".
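On the Spark side, RelationalGroupedDataset also has an agg(...) method, so something like df.groupBy("name").agg(avg("math_marks"), avg("science_marks"), count("name")), with the static functions imported from org.apache.spark.sql.functions, should give both averages and the count in one pass. That call is not tested here. The plain-Java sketch below is not Spark; it only recomputes the numbers for the sample rows in the question as a cross-check. The class and method names are made up for illustration. For these rows, Sita's math average works out to 50/3, roughly 16.67.

```java
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;

final class MarksAverages {
    // Sample rows from the question: name,math_marks,science_marks
    static final String[] SAMPLE = {
        "Ajay,10,20", "Ram,15,25", "Sita,18,30", "Ajay,20,30",
        "Sita,12,10", "Sita,20,20", "Ram,25,45"
    };

    // Returns name -> {math average, science average, record count}, sorted by name.
    static Map<String, double[]> summarize(String[] csvLines) {
        Map<String, double[]> acc = new TreeMap<>();
        for (String line : csvLines) {
            String[] fields = line.split(",");
            double[] a = acc.computeIfAbsent(fields[0], k -> new double[3]);
            a[0] += Double.parseDouble(fields[1]); // running math sum
            a[1] += Double.parseDouble(fields[2]); // running science sum
            a[2] += 1;                             // record count
        }
        for (double[] a : acc.values()) {
            a[0] /= a[2]; // turn sums into averages
            a[1] /= a[2];
        }
        return acc;
    }

    public static void main(String[] args) {
        System.out.println("name,math_avg,science_avg,count_of_records");
        for (Map.Entry<String, double[]> e : summarize(SAMPLE).entrySet()) {
            double[] v = e.getValue();
            System.out.println(String.format(Locale.ROOT, "%s,%.2f,%.2f,%d",
                    e.getKey(), v[0], v[1], (long) v[2]));
        }
    }
}
```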

Using HBase API to filter and decode data stored using Phoenix API

Here's a short description of the code I am writing right now:
We are using the Apache Phoenix API to store data in an HBase database. The schema is made up of attributes of various data types such as date, float, varchar, char, etc.
Now a requirement has come up where we need access to different versions of one tuple, i.e. the row will be updated at different times over the years and we would like to access all of these versions.
Currently the Phoenix API only has support for:
1) Defining the number of versions that a table should maintain at the time of creating the table using DDL
2) When creating a new connection, specifying the version number of the table on which all the queries should work
https://community.hortonworks.com/questions/51846/query-versions-in-phoenix.html
But this is too restrictive. The HBase API supports a time range and a maximum number of versions within that range, which is what we need. So I decided to use the HBase API to access the data stored through the Phoenix API.
This is the issue I am facing:
1) I want to filter rows based on any attribute from primary key. My primary key consists of 9 attributes:
Char(10),Char(10),Char(3),Varchar(40),Varchar(8),Varchar(8),Varchar(40),Varchar(256),Date
Phoenix API concatenates these values and creates a row key from them which looks something like this:
$qA$F62&81PPEDOID01 PGKBLOOMBRG\x00VENDOR\x00PRCQUOTE\x00BB\x001\x00\x80\x00\x01aD\x5C\xFC\x00
I am using Hbase Row Filter with Equal To Comparator with Sub String Match to filter rows based on their primary key value...
Filter IDFilter = new RowFilter(CompareOp.EQUAL, new SubstringComparator("$qA$F62&81"));
Filter CurrencyCodeFilter = new RowFilter(CompareOp.EQUAL, new SubstringComparator("PGK"));
ArrayList<Filter> filters = new ArrayList<>();
filters.add(IDFilter);
filters.add(CurrencyCodeFilter);
FilterList filterList = new FilterList(Operator.MUST_PASS_ALL ,filters);
scan.setMaxVersions(1);
scan.setFilter(filterList);
This works fine for primary key attributes that are char, varchar or numbers. But I can't filter on the date, and I really need to.
The problem with the date is:
I don't understand the encoding it uses, e.g. the Phoenix API stores the date "2018-01-30" as
\x80\x00\x01aD\x5C\xFC\x00
I understand that the Phoenix API places "\x00" after a varchar to act as a delimiter, but I don't understand this date encoding.
So I tried running this command in Hbase Shell:
hbase(main):007:0> scan 'HSTP2', {FILTER => "RowFilter(=,'substring:\x80\x00\x01aD\x5C\xFC\x00')"}
I got the proper results.
But when I tried the same in Java using the HBase API, I got no results:
Filter DateFilter = new RowFilter(CompareOp.EQUAL, new SubstringComparator("\\x80\\x00\\x01aD\\x5C\\xFC\\x00"));
And I get this when I print the DateFilter:
RowFilter (EQUAL, \x5Cx80\x5Cx00\x5Cx01ad\x5Cx5c\x5Cxfc\x5Cx00)
The conversion of '\' to '\x5C' appears to be the reason I don't get any results.
How can I filter rows based on a date? Do I have to convert the date to the format that the Phoenix API stores it in and then run a row filter, or is there some other way?
This is my code so far testing filtering based on different attributes and decoding the fetched data:
import java.io.IOException;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.Locale;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.Cell;
import org.apache.hadoop.hbase.HBaseConfiguration;
import org.apache.hadoop.hbase.KeyValue;
import org.apache.hadoop.hbase.client.Get;
import org.apache.hadoop.hbase.client.HTable;
import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.ResultScanner;
import org.apache.hadoop.hbase.client.Scan;
import org.apache.hadoop.hbase.filter.Filter;
import org.apache.hadoop.hbase.filter.FirstKeyOnlyFilter;
import org.apache.hadoop.hbase.filter.PrefixFilter;
import org.apache.hadoop.hbase.filter.QualifierFilter;
import org.apache.hadoop.hbase.filter.RegexStringComparator;
import org.apache.hadoop.hbase.filter.RowFilter;
import org.apache.hadoop.hbase.filter.SingleColumnValueFilter;
import org.apache.hadoop.hbase.filter.SubstringComparator;
import org.apache.hadoop.hbase.protobuf.generated.HBaseProtos.CompareType;
import org.apache.hadoop.hbase.filter.BinaryComparator;
import org.apache.hadoop.hbase.filter.BinaryPrefixComparator;
import org.apache.hadoop.hbase.filter.CompareFilter;
import org.apache.hadoop.hbase.filter.CompareFilter.CompareOp;
import org.apache.hadoop.hbase.filter.Filter.ReturnCode;
import org.apache.hadoop.hbase.filter.FilterList;
import org.apache.hadoop.hbase.filter.FilterList.Operator;
import org.apache.phoenix.schema.PStringColumn;
import org.apache.phoenix.schema.SortOrder;
import org.apache.phoenix.schema.types.PDataType;
import org.apache.phoenix.schema.types.PDate;
import org.apache.phoenix.schema.types.PFloat;
import org.apache.phoenix.schema.types.PInteger;
import org.apache.phoenix.shaded.org.apache.directory.shared.kerberos.codec.types.PaDataType;
import org.apache.phoenix.shaded.org.joni.Regex;
import org.apache.hadoop.hbase.util.Bytes;
public class HbaseVersionedPriceFetcher {
public static void main(String[] args) {
try {
Configuration conf = HBaseConfiguration.create(new Configuration());
conf.set("hbase.zookeeper.quorum", "hostName");//Private Detail
conf.set("hbase.zookeeper.property.clientPort", "2181");
HTable table = new HTable(conf, "HSTP2");
// Filter filter = new SingleColumnValueFilter("0".getBytes(),"LAST_CHG_USR_ID".getBytes(), CompareOp.EQUAL, "AUTO:GEN:SCRIPT".getBytes());
// Filter filter = new SingleColumnValueFilter("ISPH".getBytes(),"MKT_OID".getBytes(), CompareOp.EQUAL, "MARKET".getBytes());
// Filter filter = new SingleColumnValueFilter("ISPH".getBytes(),"VALIDATED_PRC_TYPE".getBytes(), CompareOp.EQUAL, "MID".getBytes());
Scan scan = new Scan();
//Filter List
Filter IDFilter = new RowFilter(CompareOp.EQUAL, new SubstringComparator("qA$F62&81"));
Filter CurrencyCodeFilter = new RowFilter(CompareOp.EQUAL, new SubstringComparator("PGK"));
ArrayList<Filter> filters = new ArrayList<>();
filters.add(IDFilter);
filters.add(CurrencyCodeFilter);
FilterList filterList = new FilterList(Operator.MUST_PASS_ALL ,filters);
scan.setMaxVersions(1);
scan.setFilter(filterList);
//REGEX
//Filter filter = new RowFilter(CompareOp.EQUAL, new RegexStringComparator(".*PGK.*VENDOR.*"))
//scan.addColumn("ISPH".getBytes(), "ADJST_TMS".getBytes());
// long start = new Long("1529578558767");
// long end = new Long("1529580854059");
//
// try {
// scan.setTimeRange(start,end);
// } catch (IOException e) {
// // TODO Auto-generated catch block
// e.printStackTrace();
// }
ResultScanner scanner = table.getScanner(scan);
int count = 0;
for (Result rr : scanner) {
count += 1;
System.out.println("Instrument "+ count);
System.out.println(rr);
for (KeyValue value: rr.raw()) {
String qualifier = new String(value.getQualifier());
System.out.print( qualifier+" : ");
byte[] valByteArray = value.getValue();
if(qualifier.equals("ASK_CPRC") || qualifier.equals("BID_CPRC") || qualifier.equals("MID_CPRC") || qualifier.equals("VALIDATED_CPRC")) {
float decoded = PFloat.INSTANCE.getCodec().decodeFloat(valByteArray, 0, SortOrder.getDefault());
System.out.println(decoded);
} else if (qualifier.equals("LAST_CHG_TMS") || qualifier.equals("ADJST_TMS") ) {
System.out.println(PDate.INSTANCE.toObject(valByteArray, SortOrder.getDefault()));
} else if (qualifier.equals("HST_PRC_DTE_OF_NUM")) {
int decoded = PInteger.INSTANCE.getCodec().decodeInt(valByteArray, 0, SortOrder.getDefault());
System.out.println(decoded);
} else {
System.out.println(new String(valByteArray));
}
}
}
scanner.close();
} catch (IOException e1) {
e1.printStackTrace();
}
}
static byte[] getBytes(String string) {
return string.getBytes();
}
}
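Two observations on the date problem, offered as a sketch rather than a confirmed answer. First, Java has no \x escape: the literal "\\x80" is the four characters backslash, x, 8, 0, which is why the printed filter shows \x5Cx80 and so on. The HBase shell, by contrast, turns \x80 inside a double-quoted string into a single byte. Second, the bytes quoted in the question are consistent with the assumption that Phoenix stores an ascending DATE as epoch milliseconds (UTC) in an 8-byte big-endian long with the sign bit flipped. The plain-Java code below uses only the standard library; PhoenixDateKey and its methods are hypothetical names for illustration. It reproduces the quoted bytes for 2018-01-30 under that assumption.

```java
import java.nio.ByteBuffer;
import java.time.Instant;
import java.time.LocalDate;
import java.time.ZoneOffset;

final class PhoenixDateKey {
    // Assumption: DATE = epoch millis (UTC), 8 bytes big-endian, sign bit flipped
    // so that byte-wise ordering matches chronological ordering.
    static byte[] encode(LocalDate date) {
        long millis = date.atStartOfDay(ZoneOffset.UTC).toInstant().toEpochMilli();
        return ByteBuffer.allocate(Long.BYTES).putLong(millis ^ Long.MIN_VALUE).array();
    }

    // Reverses encode(): reads 8 bytes at offset and flips the sign bit back.
    static LocalDate decode(byte[] key, int offset) {
        long millis = ByteBuffer.wrap(key, offset, Long.BYTES).getLong() ^ Long.MIN_VALUE;
        return Instant.ofEpochMilli(millis).atZone(ZoneOffset.UTC).toLocalDate();
    }

    // Renders bytes the way the HBase shell shows row keys:
    // printable ASCII as-is, everything else (and backslash) as \xHH.
    static String toShellString(byte[] bytes) {
        StringBuilder sb = new StringBuilder();
        for (byte b : bytes) {
            int ch = b & 0xFF;
            if (ch >= ' ' && ch <= '~' && ch != '\\') {
                sb.append((char) ch);
            } else {
                sb.append(String.format("\\x%02X", ch));
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        byte[] encoded = encode(LocalDate.of(2018, 1, 30));
        System.out.println(toShellString(encoded)); // \x80\x00\x01aD\x5C\xFC\x00
        System.out.println(decode(encoded, 0));     // 2018-01-30
    }
}
```

In real code, Phoenix's own PDate type (already imported in the listing above) is presumably the safer way to produce these bytes. Note also that the printed filter shows the pattern lowercased ("ad", "x5c"), which suggests SubstringComparator treats its input as text; a byte-oriented comparator is likely a better fit for a binary key component, and which one depends on where the date sits in the row key.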

How to get predicted values from JavaDecisionTreeRegressionExample.java of Spark MLlib?

I would like to get the predicted values from JavaDecisionTreeRegressionExample.java, not only the description of the decision tree and metrics such as MAE and RMSE. Does anyone know how to do this, or which method I can use to get the predicted values?
I have tried many methods provided by the RegressionEvaluator and DecisionTreeRegressionModel classes, but I still don't know how to get them. If anyone knows how, please show me. Thank you very much!
The following is the source code of JavaDecisionTreeRegressionExample.java:
package org.apache.spark.examples.ml;
// $example on$
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
// $example off$
public class JavaDecisionTreeRegressionExample {
public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("JavaDecisionTreeRegressionExample")
.getOrCreate();
// $example on$
// Load the data stored in LIBSVM format as a DataFrame.
Dataset<Row> data = spark.read().format("libsvm")
.load("data/mllib/sample_libsvm_data.txt");
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
VectorIndexerModel featureIndexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(4)
.fit(data);
// Split the data into training and test sets (30% held out for testing).
Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
Dataset<Row> trainingData = splits[0];
Dataset<Row> testData = splits[1];
// Train a DecisionTree model.
DecisionTreeRegressor dt = new DecisionTreeRegressor()
.setFeaturesCol("indexedFeatures");
// Chain indexer and tree in a Pipeline.
Pipeline pipeline = new Pipeline()
.setStages(new PipelineStage[]{featureIndexer, dt});
// Train model. This also runs the indexer.
PipelineModel model = pipeline.fit(trainingData);
// Make predictions.
Dataset<Row> predictions = model.transform(testData);
// Select example rows to display.
predictions.select("label", "features").show(5);
// Select (prediction, true label) and compute test error.
RegressionEvaluator evaluator = new RegressionEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
.setMetricName("rmse");
double rmse = evaluator.evaluate(predictions);
System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);
DecisionTreeRegressionModel treeModel =
(DecisionTreeRegressionModel) (model.stages()[1]);
System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
// $example off$
spark.stop();
}
}
I solved my problem. Change predictions.select("label", "features").show(5); to predictions.select("prediction", "label", "features").show(5); and the predicted values are shown alongside the labels and features.

How to integrate ALS in my spark pipeline to implement Non-negative matrix factorization?

I'm using Spark MLlib to train a naive Bayes classifier. I create a pipeline that indexes my string features, normalizes them, and applies PCA for dimensionality reduction, after which I train the naive Bayes model. When I run the pipeline I get negative values in the PCA components vector. After some searching I found that I should apply NMF (non-negative matrix factorization) to obtain positive vectors, and that ALS implements NMF via setNonnegative(true), but I don't know how to integrate ALS into my pipeline after PCA. Any help appreciated. Thanks.
Here is the code:
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.NaiveBayes;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.Normalizer;
import org.apache.spark.ml.feature.PCA;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
public class NBTrainPCA {
public static void main(String args[]){
try{
SparkConf conf = new SparkConf().setAppName("NBTrain");
SparkContext scc = new SparkContext(conf);
scc.setLogLevel("ERROR");
JavaSparkContext sc = new JavaSparkContext(scc);
SQLContext sqlc = new SQLContext(scc);
DataFrame traindata = sqlc.read().format("parquet").load(args[0]).filter("user_email!='NA' and user_email!='00' and user_email!='0ed709b5bec77b6bff96ea5b5e334a8e5' and user_email is not null and ip is not null and region_code is not null and city is not null and browser_name is not null and os_name is not null");
traindata.registerTempTable("master");
//DataFrame data = sqlc.sql("select user_email,user_device,ip,country_code,region_code,city,zip_code,time_zone,browser_name,browser_manf,os_name,os_manf from master where user_email!='NA' and user_email is not null and user_device is not null and ip is not null and country_code is not null and region_code is not null and city is not null and browser_name is not null and browser_manf is not null and zip_code is not null and time_zone is not null and os_name is not null and os_manf is not null");
StringIndexerModel emailIndexer = new StringIndexer()
.setInputCol("user_email")
.setOutputCol("email_index")
.setHandleInvalid("skip")
.fit(traindata);
StringIndexer udevIndexer = new StringIndexer()
.setInputCol("user_device")
.setOutputCol("udev_index")
.setHandleInvalid("skip");
StringIndexer ipIndexer = new StringIndexer()
.setInputCol("ip")
.setOutputCol("ip_index")
.setHandleInvalid("skip");
StringIndexer ccodeIndexer = new StringIndexer()
.setInputCol("country_code")
.setOutputCol("ccode_index")
.setHandleInvalid("skip");
StringIndexer rcodeIndexer = new StringIndexer()
.setInputCol("region_code")
.setOutputCol("rcode_index")
.setHandleInvalid("skip");
StringIndexer cyIndexer = new StringIndexer()
.setInputCol("city")
.setOutputCol("cy_index")
.setHandleInvalid("skip");
StringIndexer zpIndexer = new StringIndexer()
.setInputCol("zip_code")
.setOutputCol("zp_index")
.setHandleInvalid("skip");
StringIndexer tzIndexer = new StringIndexer()
.setInputCol("time_zone")
.setOutputCol("tz_index")
.setHandleInvalid("skip");
StringIndexer bnIndexer = new StringIndexer()
.setInputCol("browser_name")
.setOutputCol("bn_index")
.setHandleInvalid("skip");
StringIndexer bmIndexer = new StringIndexer()
.setInputCol("browser_manf")
.setOutputCol("bm_index")
.setHandleInvalid("skip");
StringIndexer bvIndexer = new StringIndexer()
.setInputCol("browser_version")
.setOutputCol("bv_index")
.setHandleInvalid("skip");
StringIndexer onIndexer = new StringIndexer()
.setInputCol("os_name")
.setOutputCol("on_index")
.setHandleInvalid("skip");
StringIndexer omIndexer = new StringIndexer()
.setInputCol("os_manf")
.setOutputCol("om_index")
.setHandleInvalid("skip");
VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[]{ "udev_index","ip_index","ccode_index","rcode_index","cy_index","zp_index","tz_index","bn_index","bm_index","bv_index","on_index","om_index"})
.setOutputCol("ffeatures");
Normalizer normalizer = new Normalizer()
.setInputCol("ffeatures")
.setOutputCol("sfeatures")
.setP(1.0);
PCA pca = new PCA()
.setInputCol("sfeatures")
.setOutputCol("pcafeatures")
.setK(5);
NaiveBayes nbcl = new NaiveBayes()
.setFeaturesCol("pcafeatures")
.setLabelCol("email_index")
.setSmoothing(1.0);
IndexToString is = new IndexToString()
.setInputCol("prediction")
.setOutputCol("op")
.setLabels(emailIndexer.labels());
Pipeline pipeline = new Pipeline()
.setStages(new PipelineStage[] {emailIndexer,udevIndexer,ipIndexer,ccodeIndexer,rcodeIndexer,cyIndexer,zpIndexer,tzIndexer,bnIndexer,bmIndexer,bvIndexer,onIndexer,omIndexer,assembler,normalizer,pca,nbcl,is});
PipelineModel model = pipeline.fit(traindata);
//DataFrame chidata = model.transform(data);
//chidata.write().format("com.databricks.spark.csv").save(args[1]);
model.write().overwrite().save(args[1]);
sc.close();
}
catch(Exception e){
}
}
}
I would suggest reading a bit about PCA to get a better feeling for what it does. Here are some links:
https://stats.stackexchange.com/questions/26352/interpreting-positive-and-negative-signs-of-the-elements-of-pca-eigenvectors
https://stats.stackexchange.com/questions/2691/making-sense-of-principal-component-analysis-eigenvectors-eigenvalues
As for integrating ALS into your pipeline, it seems you just want to plug one thing in after the other. It is better to understand what each of them does and is used for: ALS and PCA are quite different things. ALS performs matrix factorization, using alternating least squares to minimize the error; it does not find principal components to transform the data, and it is not a dimensionality reduction step.
BTW: I do not see any problem with getting negative values in the PCA components vector; you can check this in the links above. You are applying a linear transformation to the data, so the new vectors are simply the result of that transformation.
I hope it helps.

Apache Spark - datediff for dataframes?

I'm trying to compute a column based on a date difference. Is there a function corresponding to datediff that can be used on a column/DataFrame? For example:
Column result = old.col("one").divide(old.col("max").minus(old.col("min")));
But in this case the minus function doesn't work, because the min and max columns contain dates. So I need something like datediff for Columns. Is there such a thing?
Thank you!
There is, and it is called datediff (org.apache.spark.sql.functions.datediff):
public static Column datediff(Column end, Column start)
Returns the number of days from start to end.
Parameters:
end - (undocumented)
start - (undocumented)
Returns:
(undocumented)
Since:
1.5.0
Example:
import org.apache.spark.api.java.*;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.DataFrame;
public class App {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setMaster("local");
JavaSparkContext sc = new JavaSparkContext(conf);
SQLContext sqlContext= new SQLContext(sc);
DataFrame df = sqlContext.sql(
"SELECT CAST('2012-01-01' AS DATE), CAST('2013-08-02' AS DATE)").toDF("first", "second");
df.select(datediff(df.col("first"), df.col("second"))).show();
}
}
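Note the argument order: datediff(end, start) counts days from start to end, so the example above, which passes the earlier date as end, would be expected to show a negative number. The sketch below uses plain java.time rather than Spark and only mirrors that arithmetic for the same two dates; DateDiffSketch and dateDiff are hypothetical names. For the expression in the question, the equivalent would presumably be old.col("one").divide(datediff(old.col("max"), old.col("min"))).

```java
import java.time.LocalDate;
import java.time.temporal.ChronoUnit;

final class DateDiffSketch {
    // Mirrors the documented contract of datediff(end, start):
    // the number of whole days from start to end (negative if end is earlier).
    static long dateDiff(LocalDate end, LocalDate start) {
        return ChronoUnit.DAYS.between(start, end);
    }

    public static void main(String[] args) {
        LocalDate first = LocalDate.of(2012, 1, 1);
        LocalDate second = LocalDate.of(2013, 8, 2);
        System.out.println(dateDiff(first, second)); // -579
        System.out.println(dateDiff(second, first)); // 579
    }
}
```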
