-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
Copy pathProgram.cs
70 lines (57 loc) · 2.16 KB
/
Program.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.TorchSharp;
using MathNet.Numerics.Statistics;
using Microsoft.ML.Transforms;
// Create the ML.NET context. The GPU settings are optional: device 0 is
// selected and falling back to the CPU is switched off.
var ctx = new MLContext
{
    GpuDeviceId = 0,
    FallbackToCpu = false
};

// Echo log messages to the console, but only those whose source mentions
// the NAS-BERT trainer, so the training output is visible without the rest of the log.
ctx.Log += (sender, logEvent) =>
{
    if (!logEvent.Source.Contains("NasBertTrainer"))
    {
        return;
    }

    Console.WriteLine(logEvent.Message);
};
// Load data into IDataView.
// Only three columns of the CSV are read, addressed by their zero-based index:
// the search term (3), the relevance label (4) and the product description (5).
var columns = new[]
{
    new TextLoader.Column("search_term", DataKind.String, 3),
    new TextLoader.Column("relevance", DataKind.Single, 4),
    new TextLoader.Column("product_description", DataKind.String, 5)
};

var loaderOptions = new TextLoader.Options()
{
    Columns = columns,
    HasHeader = true,
    Separators = new[] { ',' },
    MaxRows = 1000 // Dataset has 75k rows. Only load 1k for quicker training
};

// Build the relative path from individual segments instead of a literal with
// backslashes: '\' is a directory separator only on Windows, so the literal form
// would not resolve to the Data folder on Linux or macOS.
var dataPath = Path.GetFullPath(
    Path.Combine("..", "..", "..", "..", "Data", "home-depot-sentence-similarity.csv"));

var textLoader = ctx.Data.CreateTextLoader(loaderOptions);
var data = textLoader.Load(dataPath);
// Hold out 20% of the loaded rows for testing; the remaining 80% is used for training.
var dataSplit = ctx.Data.TrainTestSplit(data, testFraction: 0.2);

// Pipeline step 1: fill missing "relevance" values using the Mean replacement mode.
var imputeRelevance = ctx.Transforms.ReplaceMissingValues(
    "relevance",
    replacementMode: MissingValueReplacingEstimator.ReplacementMode.Mean);

// Pipeline step 2: sentence-similarity trainer that compares the search term with
// the product description, using "relevance" as the label.
var similarityTrainer = ctx.Regression.Trainers.SentenceSimilarity(
    labelColumnName: "relevance",
    sentence1ColumnName: "search_term",
    sentence2ColumnName: "product_description");

var pipeline = imputeRelevance.Append(similarityTrainer);

// Fit the pipeline on the training portion.
var model = pipeline.Fit(dataSplit.TrainSet);

// Score the held-out test portion with the trained model.
var predictions = model.Transform(dataSplit.TestSet);

// Report how well the "Score" column tracks the "relevance" column.
Evaluate(predictions, actualColumnName: "relevance", predictedColumnName: "Score");

// Persist the trained model together with the input schema.
ctx.Model.Save(model, data.Schema, "model.zip");
// Prints the Pearson correlation between two float columns of a data view.
//   predictions         - data view containing both columns.
//   actualColumnName    - name of the column holding the actual values.
//   predictedColumnName - name of the column holding the predicted values.
// Both columns are read as float and widened to double before the correlation is computed.
void Evaluate(IDataView predictions, string actualColumnName, string predictedColumnName)
{
    var actualValues =
        from value in predictions.GetColumn<float>(actualColumnName)
        select (double)value;

    var predictedValues =
        from value in predictions.GetColumn<float>(predictedColumnName)
        select (double)value;

    var pearson = Correlation.Pearson(actualValues, predictedValues);

    Console.WriteLine($"Pearson Correlation: {pearson}");
}