8/7/2018
Microsoft recently released a preview of a machine learning framework for .NET developers: ML.NET.
I needed to perform a clustering analysis from existing data in one of my applications. This is a pretty common machine learning task, so I decided to document the basic approach in this article.
We'll use the well-worn iris data set from the UCI Machine Learning Repository to demonstrate how to perform a cluster analysis using ML.NET. The iris data set contains three fairly distinct clusters of different types of iris, with a few difficult-to-classify specimens. This will make it easy to check the output of the analysis.
If you wish to review/run the completed application, you can clone/download it from GitHub.
Data preparation
The ML.NET framework allows you to import your data directly as part of the analysis pipeline. However, I have a number of libraries I like to use to load and transform my data, so I used CsvHelper to load the data. This resulted in an IEnumerable<Iris>, where Iris is defined as follows.
    public class Iris
    {
        public double PetalLength { get; set; }
        public double PetalWidth { get; set; }
        public double SepalLength { get; set; }
        public double SepalWidth { get; set; }
        public string Type { get; set; }
    }
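If you would rather not take a dependency for loading, the rows can also be parsed by hand. The following is a minimal sketch, not the CsvHelper code used in the completed application. It assumes the UCI iris.data layout (no header row; columns in the order sepal length, sepal width, petal length, petal width, class name). The IrisParser class and its Parse method are hypothetical names introduced for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;

// Same shape as the Iris class above, repeated so the sketch compiles on its own.
public class Iris
{
    public double PetalLength { get; set; }
    public double PetalWidth { get; set; }
    public double SepalLength { get; set; }
    public double SepalWidth { get; set; }
    public string Type { get; set; }
}

// Hypothetical helper; not part of ML.NET or CsvHelper.
public static class IrisParser
{
    // Assumed column order: sepal length, sepal width, petal length, petal width, class.
    public static List<Iris> Parse(IEnumerable<string> lines)
    {
        return lines
            .Where(line => !string.IsNullOrWhiteSpace(line)) // skip blank lines
            .Select(line => line.Split(','))
            .Select(parts => new Iris
            {
                // Invariant culture so "5.1" parses the same on every machine.
                SepalLength = double.Parse(parts[0], CultureInfo.InvariantCulture),
                SepalWidth = double.Parse(parts[1], CultureInfo.InvariantCulture),
                PetalLength = double.Parse(parts[2], CultureInfo.InvariantCulture),
                PetalWidth = double.Parse(parts[3], CultureInfo.InvariantCulture),
                Type = parts[4].Trim()
            })
            .ToList();
    }
}
```

With the file on disk, something like IrisParser.Parse(File.ReadLines(path)) would yield the list, where path is wherever you saved the data. The sketch does no validation, so a malformed row will throw.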
At present, ML.NET only seems to work with public fields rather than properties (which I understand will be addressed as the framework matures). Rather than expose public fields in my domain classes (e.g. Iris), I wrote an Observation data transfer class and converted my Iris data to Observation data.
    public class Observation
    {
        // ML.NET reads the feature vector from this public field.
        [VectorType(4)]
        public float[] Features;

        // Maps a domain object to the four-element feature vector.
        public static Observation Create(Iris iris)
        {
            return new Observation
            {
                Features = new[]
                {
                    (float)iris.SepalLength,
                    (float)iris.SepalWidth,
                    (float)iris.PetalLength,
                    (float)iris.PetalWidth
                }
            };
        }
    }
    IEnumerable<Observation> observations = data.Select(Observation.Create).ToList();
ML.NET looks for a Features field during training, so naming is important. This must be a vector of floats. The VectorType attribute specifies that the feature data is four-dimensional:
- Sepal length
- Sepal width
- Petal length
- Petal width
Creating the prediction class
We also need a class that defines a prediction, i.e. the cluster that contains a given iris.
    public class ClusterPrediction
    {
        public uint PredictedLabel;
        public float[] Score;
    }
Again, the types and names of these fields (not properties) are important.
Building the pipeline
We are now ready to construct the learning pipeline.
    var pipeline = new LearningPipeline
    {
        CollectionDataSource.Create(observations),
        new KMeansPlusPlusClusterer
        {
            K = 3,
            NormalizeFeatures = NormalizeOption.Yes,
            MaxIterations = 100
        }
    };
We want to identify 3 clusters, normalize the training data and stop the analysis after 100 iterations (if it still hasn't converged).
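To make the roles of K and MaxIterations concrete, here is a minimal sketch of plain k-means (Lloyd's algorithm). It is an illustration only and is not how ML.NET implements KMeansPlusPlusClusterer. The KMeansSketch class is a hypothetical name, and the sketch assumes the first k points are used as starting centroids so that it behaves deterministically. It also returns zero-based cluster indices, whereas the clusters reported later in this article are numbered from 1.

```csharp
using System;
using System.Linq;

// Illustrative Lloyd's-algorithm sketch; NOT ML.NET's implementation.
public static class KMeansSketch
{
    // Squared Euclidean distance between two points of equal dimension.
    public static float SquaredDistance(float[] a, float[] b)
    {
        float sum = 0f;
        for (int i = 0; i < a.Length; i++)
        {
            float diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    // Zero-based index of the centroid closest to the point.
    public static int Nearest(float[] point, float[][] centroids)
    {
        int best = 0;
        for (int c = 1; c < centroids.Length; c++)
        {
            if (SquaredDistance(point, centroids[c]) < SquaredDistance(point, centroids[best]))
                best = c;
        }
        return best;
    }

    // Returns a zero-based cluster index for each point.
    public static int[] Cluster(float[][] points, int k, int maxIterations)
    {
        if (k < 1 || points.Length < k)
            throw new ArgumentException("Need at least k points and k >= 1.");

        // Assumption: seed with the first k points (real implementations seed randomly).
        float[][] centroids = points.Take(k).Select(p => (float[])p.Clone()).ToArray();
        int[] assignments = new int[points.Length];

        for (int iteration = 0; iteration < maxIterations; iteration++)
        {
            // Assignment step: attach every point to its nearest centroid.
            int[] updated = points.Select(p => Nearest(p, centroids)).ToArray();

            // Converged once no point changes cluster between iterations.
            if (iteration > 0 && updated.SequenceEqual(assignments))
                break;
            assignments = updated;

            // Update step: move each centroid to the mean of its members.
            for (int c = 0; c < k; c++)
            {
                float[][] members = points.Where((p, i) => assignments[i] == c).ToArray();
                if (members.Length == 0) continue; // keep the centroid of an empty cluster
                for (int d = 0; d < centroids[c].Length; d++)
                    centroids[c][d] = members.Average(m => m[d]);
            }
        }
        return assignments;
    }
}
```

The loop ends either when the assignments stop changing or when maxIterations is reached, which is the same trade-off the MaxIterations = 100 setting expresses in the pipeline.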
The model can now be trained.
    PredictionModel<Observation, ClusterPrediction> model = pipeline.Train<Observation, ClusterPrediction>();
Determining the cluster assignments
Finally, we can predict the cluster containing each observation (iris).
    data.ToList().ForEach(x =>
    {
        ClusterPrediction prediction = model.Predict(Observation.Create(x));
        Console.WriteLine($"Type {x.Type} was assigned to cluster {prediction.PredictedLabel}");
    });
prediction.Score is an array of k numbers (one for each cluster) specifying the distance between the observation and the respective cluster centroid.
I obtained the following results on a test run:
- Cluster 1: 40 versicolor and 11 virginica
- Cluster 2: 39 virginica and 10 versicolor
- Cluster 3: 50 setosa
Your results may differ, as k-means depends on randomly chosen starting centroids and is therefore non-deterministic.
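A per-cluster summary like the one above can be produced by counting (cluster, type) pairs rather than reading the console output by eye. The following is a minimal sketch of one way to do that with LINQ. ClusterTally is a hypothetical helper, not part of ML.NET, and it assumes you have collected each prediction's PredictedLabel together with the iris Type into a sequence of tuples.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper for summarising cluster assignments; not part of ML.NET.
public static class ClusterTally
{
    // Counts how many specimens of each type landed in each cluster.
    // Keys are (cluster, type) pairs; values are the number of specimens.
    public static Dictionary<(uint Cluster, string Type), int> Count(
        IEnumerable<(uint Cluster, string Type)> assignments)
    {
        return assignments
            .GroupBy(a => a) // value tuples compare element by element
            .ToDictionary(g => g.Key, g => g.Count());
    }
}

// Possible usage, assuming "pairs" holds one (PredictedLabel, Type) tuple per iris:
//   foreach (var entry in ClusterTally.Count(pairs).OrderBy(e => e.Key.Cluster))
//       Console.WriteLine($"Cluster {entry.Key.Cluster}: {entry.Value} {entry.Key.Type}");
```

Because the cluster numbers themselves are arbitrary, compare the composition of each cluster between runs rather than the labels.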
Summary
ML.NET is currently a first preview release, so clearly there are things that will improve over the next few iterations (e.g. support for properties), but it's already looking useful if you have a need to include machine learning in your .NET applications.