/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.functions.Logistic;
import weka.clusterers.MakeDensityBasedClusterer;
import weka.clusterers.SimpleKMeans;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ClusterMembership;

public class RBFNetwork
extends Classifier
implements OptionHandler {
    private Logistic m_logistic;
    private LinearRegression m_linear;
    private ClusterMembership m_basisFilter;
    private int m_numClusters = 2;
    protected double m_ridge = 1.0E-8;
    private int m_maxIts = -1;
    private int m_clusteringSeed = 1;

    public String globalInfo() {
        return "Class that implements a radial basis function network. It uses the K-Means clustering algorithm to provide the basis functions and learns either a logistic regression (discrete class problems) or linear regression (numeric class problems) on top of that.";
    }

    public void buildClassifier(Instances instances) throws Exception {
        SimpleKMeans simpleKMeans = new SimpleKMeans();
        simpleKMeans.setNumClusters(this.m_numClusters);
        simpleKMeans.setSeed(this.m_clusteringSeed);
        MakeDensityBasedClusterer makeDensityBasedClusterer = new MakeDensityBasedClusterer();
        makeDensityBasedClusterer.setClusterer(simpleKMeans);
        this.m_basisFilter = new ClusterMembership();
        this.m_basisFilter.setDensityBasedClusterer(makeDensityBasedClusterer);
        this.m_basisFilter.setInputFormat(instances);
        Instances instances2 = Filter.useFilter(instances, this.m_basisFilter);
        if (instances.classAttribute().isNominal()) {
            this.m_linear = null;
            this.m_logistic = new Logistic();
            this.m_logistic.setRidge(this.m_ridge);
            this.m_logistic.setMaxIts(this.m_maxIts);
            this.m_logistic.buildClassifier(instances2);
        } else {
            this.m_logistic = null;
            this.m_linear = new LinearRegression();
            this.m_linear.setRidge(this.m_ridge);
            this.m_linear.buildClassifier(instances2);
        }
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        this.m_basisFilter.input(instance);
        Instance instance2 = this.m_basisFilter.output();
        return instance.classAttribute().isNominal() ? this.m_logistic.distributionForInstance(instance2) : this.m_linear.distributionForInstance(instance2);
    }

    public String toString() {
        StringBuffer stringBuffer = new StringBuffer();
        stringBuffer.append("Radial basis function network\n");
        stringBuffer.append(this.m_linear == null ? "(Logistic regression " : "(Linear regression ");
        stringBuffer.append("applied to K-means clusters as basis functions):\n\n");
        stringBuffer.append(this.m_linear == null ? this.m_logistic.toString() : this.m_linear.toString());
        return stringBuffer.toString();
    }

    public String maxItsTipText() {
        return "Maximum number of iterations for the logistic regression to perform. Only applied to discrete class problems.";
    }

    public int getMaxIts() {
        return this.m_maxIts;
    }

    public void setMaxIts(int n) {
        this.m_maxIts = n;
    }

    public String ridgeTipText() {
        return "Set the Ridge value for the logistic or linear regression.";
    }

    public void setRidge(double d) {
        this.m_ridge = d;
    }

    public double getRidge() {
        return this.m_ridge;
    }

    public String numClustersTipText() {
        return "The number of clusters for K-Means to generate.";
    }

    public void setNumClusters(int n) {
        if (n > 0) {
            this.m_numClusters = n;
        }
    }

    public int getNumClusters() {
        return this.m_numClusters;
    }

    public String clusteringSeedTipText() {
        return "The random seed to pass on to K-means.";
    }

    public void setClusteringSeed(int n) {
        this.m_clusteringSeed = n;
    }

    public int getClusteringSeed() {
        return this.m_clusteringSeed;
    }

    public Enumeration listOptions() {
        Vector<Option> vector = new Vector<Option>(4);
        vector.addElement(new Option("\tSet the number of clusters (basis functions) to generate. (default = 2).", "B", 1, "-B <number>"));
        vector.addElement(new Option("\tSet the random seed to be used by K-means. (default = 1).", "S", 1, "-S <seed>"));
        vector.addElement(new Option("\tSet the ridge value for the logistic or linear regression.", "R", 1, "-R <ridge>"));
        vector.addElement(new Option("\tSet the maximum number of iterations for the logistic regression. (default -1, until convergence).", "M", 1, "-M <number>"));
        return vector.elements();
    }

    public void setOptions(String[] stringArray) throws Exception {
        String string;
        this.setDebug(Utils.getFlag('D', stringArray));
        String string2 = Utils.getOption('R', stringArray);
        this.m_ridge = string2.length() != 0 ? Double.parseDouble(string2) : 1.0E-8;
        String string3 = Utils.getOption('M', stringArray);
        this.m_maxIts = string3.length() != 0 ? Integer.parseInt(string3) : -1;
        String string4 = Utils.getOption('B', stringArray);
        if (string4.length() != 0) {
            this.setNumClusters(Integer.parseInt(string4));
        }
        if ((string = Utils.getOption('S', stringArray)).length() != 0) {
            this.setClusteringSeed(Integer.parseInt(string));
        }
        Utils.checkForRemainingOptions(stringArray);
    }

    public String[] getOptions() {
        String[] stringArray = new String[8];
        int n = 0;
        stringArray[n++] = "-B";
        stringArray[n++] = "" + this.m_numClusters;
        stringArray[n++] = "-S";
        stringArray[n++] = "" + this.m_clusteringSeed;
        stringArray[n++] = "-R";
        stringArray[n++] = "" + this.m_ridge;
        stringArray[n++] = "-M";
        stringArray[n++] = "" + this.m_maxIts;
        while (n < stringArray.length) {
            stringArray[n++] = "";
        }
        return stringArray;
    }

    public static void main(String[] stringArray) {
        try {
            System.out.println(Evaluation.evaluateModel(new RBFNetwork(), stringArray));
        }
        catch (Exception exception) {
            exception.printStackTrace();
            System.err.println(exception.getMessage());
        }
    }
}

