Linear Regression and cross validation in Java using Weka

I stumbled upon a question on the internet about how to predict prices from price history on Android. Assuming the history is fairly small (a few hundred records) and there are not many attributes (fewer than 20), I quickly thought that the Weka Java API would be one of the easiest ways to achieve this.

Unfortunately, I couldn't easily find a straightforward tutorial or example, since most of them cover the GUI version of Weka. So I decided to whip up an example (using the bleeding-edge weka-dev 3.9.2) and post a brief explanation here 😆

I use a demand-forecasting (regression) dataset from UCI for this example. I chose this dataset because its characteristics are quite similar to those of a price-prediction problem.

To convert the dataset for Weka, I use the methods below:

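import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;

import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instances;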

...

public static String DATASET_FILE = "Daily_Demand_Forecasting_Orders.csv";
public static int DATASET_SIZE = 60;
public static int DATASET_ATTRIBUTES_NUM = 13;

...

private Instances loadDataset() {
    ClassLoader classLoader = getClass().getClassLoader();
    String path = classLoader.getResource(DATASET_FILE).getPath();
    Instances dataset = this.createEmptyDataset();

    // try-with-resources closes the reader even if reading fails
    try (BufferedReader br = new BufferedReader(new FileReader(path))) {
        br.readLine(); // skip the CSV header row
        String line;
        while ((line = br.readLine()) != null) {
            String[] fields = line.split(";");
            if (fields.length != DATASET_ATTRIBUTES_NUM) {
                // Avoid silently zero-padded or overflowing rows
                System.err.println("Skipping malformed row: " + line);
                continue;
            }
            try {
                double[] values = new double[DATASET_ATTRIBUTES_NUM];
                for (int i = 0; i < fields.length; i++) {
                    values[i] = Double.parseDouble(fields[i]);
                }
                // Weight 1.0 means every row counts equally
                dataset.add(new DenseInstance(1.0, values));
            } catch (NumberFormatException ex) {
                // Skip rows that contain non-numeric fields
                System.err.println(ex.getMessage());
            }
        }
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    return dataset;
}

private Instances createEmptyDataset() {
    ArrayList<Attribute> header = this.createHeader();
    Instances instances = new Instances(DATASET_FILE, header, DATASET_SIZE);
    // The last attribute is the regression target
    instances.setClassIndex(DATASET_ATTRIBUTES_NUM - 1);
    return instances;
}

Here, createHeader() simply defines the dataset's attributes:

private ArrayList<Attribute> createHeader() {
    ArrayList<Attribute> header = new ArrayList<>();
    header.add(new Attribute("Week_of_the_month"));
    header.add(new Attribute("Day_of_the_week_"));
    header.add(new Attribute("Non_urgent_order"));
    header.add(new Attribute("Urgent_order"));
    header.add(new Attribute("Order_type_A"));
    header.add(new Attribute("Order_type_B"));
    header.add(new Attribute("Order_type_C"));
    header.add(new Attribute("Fiscal_sector_orders"));
    header.add(new Attribute("Orders_from_the_traffic_controller_sector"));
    header.add(new Attribute("Banking_orders_(1)"));
    header.add(new Attribute("Banking_orders_(2)"));
    header.add(new Attribute("Banking_orders_(3)"));
    header.add(new Attribute("Target_(Total_orders)"));
    return header;
}
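
As a side note, the per-row parsing can be factored into a small standalone helper that checks the field count before converting anything. The sketch below is plain Java with no Weka dependency; RowParserSketch and parseRow are names I made up for illustration, and the semicolon separator is taken from the CSV used above.

```java
import java.util.Arrays;

class RowParserSketch {

    /**
     * Parses one semicolon-separated line into exactly `expected` doubles.
     * Throws IllegalArgumentException if the field count is wrong;
     * Double.parseDouble throws NumberFormatException (a subclass of
     * IllegalArgumentException) for non-numeric fields.
     */
    static double[] parseRow(String line, int expected) {
        String[] fields = line.split(";");
        if (fields.length != expected) {
            throw new IllegalArgumentException(
                "Expected " + expected + " fields but found " + fields.length);
        }
        double[] values = new double[expected];
        for (int i = 0; i < expected; i++) {
            values[i] = Double.parseDouble(fields[i].trim());
        }
        return values;
    }

    public static void main(String[] args) {
        // Hypothetical three-field row, just to show the call
        System.out.println(Arrays.toString(parseRow("1;4;316.307", 3)));
    }
}
```

A caller can then catch IllegalArgumentException once to handle both kinds of bad rows.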

Then we can use the loadDataset() method to build a LinearRegression model:

import weka.classifiers.functions.LinearRegression;
import weka.core.Instances;


...

Instances dataset = loadDataset();
LinearRegression lr = new LinearRegression();
lr.setRidge(1.0E-8);
lr.buildClassifier(dataset);
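
To give a feel for what buildClassifier() fits, here is a one-attribute least-squares sketch in plain Java. It is only an illustration: Weka's LinearRegression handles many attributes, a ridge term and attribute selection, none of which this toy version has. SimpleOlsSketch and fit are names I made up.

```java
class SimpleOlsSketch {

    /** Returns {intercept, slope} minimizing the squared error of y ≈ intercept + slope * x. */
    static double[] fit(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("Need at least two (x, y) pairs of equal length");
        }
        int n = x.length;
        double meanX = 0.0, meanY = 0.0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;

        double sxy = 0.0, sxx = 0.0;
        for (int i = 0; i < n; i++) {
            double dx = x[i] - meanX;
            sxy += dx * (y[i] - meanY);
            sxx += dx * dx;
        }
        if (sxx == 0.0) {
            throw new IllegalArgumentException("All x values are identical; slope is undefined");
        }
        double slope = sxy / sxx;
        return new double[]{meanY - slope * meanX, slope};
    }

    public static void main(String[] args) {
        // Made-up points lying exactly on y = 1 + 2x
        double[] model = fit(new double[]{1, 2, 3}, new double[]{3, 5, 7});
        System.out.println("intercept=" + model[0] + ", slope=" + model[1]);
    }
}
```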

Finally, we can use the model, lr, to make a prediction for a single data point:

import weka.core.Attribute;
import weka.core.DenseInstance;

...

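// The last value is the known target, kept only to compare against the prediction
double[] data = new double[]{1.0, 4.0, 316.307, 223.270, 61.543, 175.586, 302.448, 0.0, 65556.0, 44914.0, 188411.0, 14793.0, 539.577};
double expectation = data[data.length - 1];
DenseInstance instance = new DenseInstance(1.0, data);
instance.setDataset(dataset); // attach the header so Weka knows which attribute is the class
double prediction = lr.classifyInstance(instance);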

Or, if you want to evaluate the model's performance, you can run k-fold cross-validation and check the error. In this example, I run 10-fold cross-validation and measure the root mean squared error (RMSE):

import java.util.Random;
import weka.classifiers.Evaluation;

Evaluation evaluation = new Evaluation(dataset);
evaluation.crossValidateModel(lr, dataset, 10, new Random(1));
double rmse = evaluation.rootMeanSquaredError();
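
Under the hood, k-fold cross-validation shuffles the rows, splits them into k folds, trains on k-1 folds and tests on the remaining one, then pools the errors. The sketch below shows the two mechanical pieces on plain arrays: splitting row indices into folds and computing RMSE. It does not reproduce Weka's exact fold assignment or its handling of instance weights, and the class and method names are mine.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

class CrossValidationSketch {

    /** Shuffles the indices 0..n-1 with the given seed and deals them round-robin into k folds. */
    static List<List<Integer>> foldIndices(int n, int k, long seed) {
        if (k < 2 || k > n) {
            throw new IllegalArgumentException("k must be between 2 and n");
        }
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            order.add(i);
        }
        Collections.shuffle(order, new Random(seed));

        List<List<Integer>> folds = new ArrayList<>();
        for (int f = 0; f < k; f++) {
            folds.add(new ArrayList<>());
        }
        for (int pos = 0; pos < n; pos++) {
            folds.get(pos % k).add(order.get(pos));
        }
        return folds;
    }

    /** Root mean squared error, assuming every row has equal weight. */
    static double rmse(double[] actual, double[] predicted) {
        if (actual.length != predicted.length || actual.length == 0) {
            throw new IllegalArgumentException("Arrays must be non-empty and of equal length");
        }
        double sumSquared = 0.0;
        for (int i = 0; i < actual.length; i++) {
            double diff = actual[i] - predicted[i];
            sumSquared += diff * diff;
        }
        return Math.sqrt(sumSquared / actual.length);
    }

    public static void main(String[] args) {
        // 60 rows and 10 folds, matching the dataset size and fold count used above
        List<List<Integer>> folds = foldIndices(60, 10, 1L);
        System.out.println("Rows per fold: " + folds.get(0).size());
        // Made-up values, only to show the call
        System.out.println("RMSE: " + rmse(new double[]{1, 1}, new double[]{2, 0}));
    }
}
```

RMSE is in the same unit as the target, so it can be read directly as "how far off, on average, the predicted total orders are", with large misses penalized more heavily.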

In the full example, I also apply normalization so that all values are on the same scale. In many cases, this may also improve the model's performance.
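
To make "the same scale" concrete, here is a minimal sketch of min-max scaling, one common form of normalization, applied to a single attribute column in plain Java. This is an assumption about what normalization means here, not necessarily the exact filter configuration in the full example; as far as I know, Weka's Normalize filter also maps numeric attributes to [0, 1] by default. MinMaxSketch and normalize are made-up names.

```java
import java.util.Arrays;

class MinMaxSketch {

    /** Rescales a column to [0, 1]; a constant column maps to all zeros. */
    static double[] normalize(double[] column) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double v : column) {
            min = Math.min(min, v);
            max = Math.max(max, v);
        }
        double range = max - min;
        double[] scaled = new double[column.length];
        for (int i = 0; i < column.length; i++) {
            // Guard against division by zero when every value is identical
            scaled[i] = (range == 0.0) ? 0.0 : (column[i] - min) / range;
        }
        return scaled;
    }

    public static void main(String[] args) {
        // Made-up column values, only to show the call
        System.out.println(Arrays.toString(normalize(new double[]{10, 20, 30})));
    }
}
```

Whatever scaling is used, the minimum and maximum should come from the training data and then be reused for any new instance passed to the model.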

Enjoy! ☕
