Skip to content

Commit

Permalink
Adds the ability to load a pre-trained LightGBM file and import it in…
Browse files Browse the repository at this point in the history
…to ML.Net. (#6569)

* need to finish multiclass

* multiclass

* reverting test for now

* reverting test for now

* added test and fixed objective parsing

* minor testing changes
  • Loading branch information
michaelgsharp authored May 16, 2023
1 parent 1c41ed4 commit f93ab25
Show file tree
Hide file tree
Showing 9 changed files with 7,707 additions and 50 deletions.
35 changes: 31 additions & 4 deletions src/Microsoft.ML.LightGbm/LightGbmBinaryTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.CommandLine;
Expand Down Expand Up @@ -228,6 +230,26 @@ internal LightGbmBinaryTrainer(IHostEnvironment env,
{
}

/// <summary>
/// Initializes a new instance of <see cref="LightGbmBinaryTrainer"/>
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
/// <param name="featureColumnName">The name of the feature column.</param>
internal LightGbmBinaryTrainer(IHostEnvironment env,
Stream lightGbmModel,
string featureColumnName = DefaultColumnNames.Features)
: base(env,
LoadNameValue,
new Options()
{
FeatureColumnName = featureColumnName,
LightGbmModel = lightGbmModel
},
new SchemaShape.Column())
{
}

private protected override CalibratedModelParametersBase<LightGbmBinaryModelParameters, PlattCalibrator> CreatePredictor()
{
Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");
Expand All @@ -241,11 +263,16 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
{
Host.AssertValue(ch);
base.CheckDataValid(ch, data);
var labelType = data.Schema.Label.Value.Type;
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))

// If using a pre-trained model file we don't need a label column
if (LightGbmTrainerOptions.LightGbmModel == null)
{
throw ch.ExceptParam(nameof(data),
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be unsigned int, boolean or float.");
var labelType = data.Schema.Label.Value.Type;
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
{
throw ch.ExceptParam(nameof(data),
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be unsigned int, boolean or float.");
}
}
}

Expand Down
65 changes: 65 additions & 0 deletions src/Microsoft.ML.LightGbm/LightGbmCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.IO;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.LightGbm;
Expand Down Expand Up @@ -67,6 +68,22 @@ public static LightGbmRegressionTrainer LightGbm(this RegressionCatalog.Regressi
return new LightGbmRegressionTrainer(env, options);
}

/// <summary>
/// Create <see cref="LightGbmRegressionTrainer"/> from a pre-trained LightGBM model, which predicts a target using a gradient boosting decision tree regression.
/// </summary>
/// <param name="catalog">The <see cref="RegressionCatalog"/>.</param>
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
/// <param name="featureColumnName">The name of the feature column. The column data must be a known-sized vector of <see cref="System.Single"/>.</param>
public static LightGbmRegressionTrainer LightGbm(this RegressionCatalog.RegressionTrainers catalog,
Stream lightGbmModel,
string featureColumnName = DefaultColumnNames.Features
)
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
return new LightGbmRegressionTrainer(env, lightGbmModel, featureColumnName);
}

/// <summary>
/// Create <see cref="LightGbmBinaryTrainer"/>, which predicts a target using a gradient boosting decision tree binary classification.
/// </summary>
Expand Down Expand Up @@ -119,6 +136,22 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.Bi
return new LightGbmBinaryTrainer(env, options);
}

/// <summary>
/// Create <see cref="LightGbmBinaryTrainer"/> from a pre-trained LightGBM model, which predicts a target using a gradient boosting decision tree binary classification.
/// </summary>
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
/// <param name="featureColumnName">The name of the feature column. The column data must be a known-sized vector of <see cref="System.Single"/>.</param>
public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
Stream lightGbmModel,
string featureColumnName = DefaultColumnNames.Features
)
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
return new LightGbmBinaryTrainer(env, lightGbmModel, featureColumnName);
}

/// <summary>
/// Create <see cref="LightGbmRankingTrainer"/>, which predicts a target using a gradient boosting decision tree ranking model.
/// </summary>
Expand Down Expand Up @@ -174,6 +207,22 @@ public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainer
return new LightGbmRankingTrainer(env, options);
}

/// <summary>
/// Create <see cref="LightGbmRankingTrainer"/> from a pre-trained LightGBM model, which predicts a target using a gradient boosting decision tree ranking model.
/// </summary>
/// <param name="catalog">The <see cref="RankingCatalog"/>.</param>
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
/// <param name="featureColumnName">The name of the feature column. The column data must be a known-sized vector of <see cref="System.Single"/>.</param>
public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainers catalog,
Stream lightGbmModel,
string featureColumnName = DefaultColumnNames.Features
)
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
return new LightGbmRankingTrainer(env, lightGbmModel, featureColumnName);
}

/// <summary>
/// Create <see cref="LightGbmMulticlassTrainer"/>, which predicts a target using a gradient boosting decision tree multiclass classification model.
/// </summary>
Expand Down Expand Up @@ -225,5 +274,21 @@ public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCa
var env = CatalogUtils.GetEnvironment(catalog);
return new LightGbmMulticlassTrainer(env, options);
}

/// <summary>
/// Create <see cref="LightGbmMulticlassTrainer"/> from a pre-trained LightGBM model, which predicts a target using a gradient boosting decision tree multiclass classification model.
/// </summary>
/// <param name="catalog">The <see cref="MulticlassClassificationCatalog"/>.</param>
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
/// <param name="featureColumnName">The name of the feature column. The column data must be a known-sized vector of <see cref="System.Single"/>.</param>
public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
Stream lightGbmModel,
string featureColumnName = DefaultColumnNames.Features
)
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
return new LightGbmMulticlassTrainer(env, lightGbmModel, featureColumnName);
}
}
}
59 changes: 51 additions & 8 deletions src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
Expand Down Expand Up @@ -170,6 +171,26 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env,
{
}

/// <summary>
/// Initializes a new instance of <see cref="LightGbmRankingTrainer"/>
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
/// <param name="featureColumnName">The name of the feature column.</param>
internal LightGbmMulticlassTrainer(IHostEnvironment env,
Stream lightGbmModel,
string featureColumnName = DefaultColumnNames.Features)
: base(env,
LoadNameValue,
new Options()
{
FeatureColumnName = featureColumnName,
LightGbmModel = lightGbmModel
},
new SchemaShape.Column())
{
}

private InternalTreeEnsemble GetBinaryEnsemble(int classID)
{
var res = new InternalTreeEnsemble();
Expand Down Expand Up @@ -213,11 +234,15 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
{
Host.AssertValue(ch);
base.CheckDataValid(ch, data);
var labelType = data.Schema.Label.Value.Type;
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
// If using a pre-trained model file we don't need a label or group column
if (LightGbmTrainerOptions.LightGbmModel == null)
{
throw ch.ExceptParam(nameof(data),
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be of unsigned int, boolean or float.");
var labelType = data.Schema.Label.Value.Type;
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
{
throw ch.ExceptParam(nameof(data),
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be of unsigned int, boolean or float.");
}
}
}

Expand All @@ -227,6 +252,21 @@ private protected override void InitializeBeforeTraining()
_numberOfClasses = 0;
}

private protected override void AdditionalLoadPreTrainedModel(string modelText)
{
string[] lines = modelText.Split(new char[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries);
// Jump to the "objective" value in the file. It's at the beginning.
int i = 0;
while (!lines[i].StartsWith("objective"))
i++;

// Format in the file is objective=multiclass num_class:4
var split = lines[i].Split(' ');
_numberOfClassesIncludingNan = int.Parse(split[1].Split(':')[1]);
_numberOfClasses = _numberOfClassesIncludingNan;
}


private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float[] labels)
{
// Only initialize one time.
Expand Down Expand Up @@ -317,11 +357,14 @@ private protected override void CheckAndUpdateParametersBeforeTraining(IChannel

private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
Contracts.Assert(success);
SchemaShape.Column labelCol = default;
if (LightGbmTrainerOptions.LightGbmModel == null)
{
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out labelCol);
Contracts.Assert(success);
}

var metadata = new SchemaShape(labelCol.Annotations.Where(x => x.Name == AnnotationUtils.Kinds.KeyValues)
.Concat(AnnotationUtils.GetTrainerOutputAnnotation()));
var metadata = LightGbmTrainerOptions.LightGbmModel == null ? new SchemaShape(labelCol.Annotations.Where(x => x.Name == AnnotationUtils.Kinds.KeyValues).Concat(AnnotationUtils.GetTrainerOutputAnnotation())) : new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation());
return new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol))),
Expand Down
58 changes: 42 additions & 16 deletions src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
Expand Down Expand Up @@ -215,27 +216,52 @@ internal LightGbmRankingTrainer(IHostEnvironment env,
Host.CheckNonEmpty(rowGroupIdColumnName, nameof(rowGroupIdColumnName));
}

/// <summary>
/// Initializes a new instance of <see cref="LightGbmRankingTrainer"/>
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
/// <param name="featureColumnName">The name of the feature column.</param>
internal LightGbmRankingTrainer(IHostEnvironment env,
Stream lightGbmModel,
string featureColumnName = DefaultColumnNames.Features)
: base(env,
LoadNameValue,
new Options()
{
FeatureColumnName = featureColumnName,
LightGbmModel = lightGbmModel
},
new SchemaShape.Column())
{
}

private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
{
Host.AssertValue(ch);
base.CheckDataValid(ch, data);
// Check label types.
var labelCol = data.Schema.Label.Value;
var labelType = labelCol.Type;
if (!(labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
{
throw ch.ExceptParam(nameof(data),
$"Label column '{labelCol.Name}' is of type '{labelType.RawType}', but must be Key or Single.");
}
// Check group types.
if (!data.Schema.Group.HasValue)
throw ch.ExceptValue(nameof(data.Schema.Group), "Group column is missing.");
var groupCol = data.Schema.Group.Value;
var groupType = groupCol.Type;
if (!(groupType == NumberDataViewType.UInt32 || groupType is KeyDataViewType))

// If using a pre-trained model file we don't need a label or group column
if (LightGbmTrainerOptions.LightGbmModel == null)
{
throw ch.ExceptParam(nameof(data),
$"Group column '{groupCol.Name}' is of type '{groupType.RawType}', but must be UInt32 or Key.");
// Check label types.
var labelCol = data.Schema.Label.Value;
var labelType = labelCol.Type;
if (!(labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
{
throw ch.ExceptParam(nameof(data),
$"Label column '{labelCol.Name}' is of type '{labelType.RawType}', but must be Key or Single.");
}
// Check group types.
if (!data.Schema.Group.HasValue)
throw ch.ExceptValue(nameof(data.Schema.Group), "Group column is missing.");
var groupCol = data.Schema.Group.Value;
var groupType = groupCol.Type;
if (!(groupType == NumberDataViewType.UInt32 || groupType is KeyDataViewType))
{
throw ch.ExceptParam(nameof(data),
$"Group column '{groupCol.Name}' is of type '{groupType.RawType}', but must be UInt32 or Key.");
}
}
}

Expand Down
34 changes: 30 additions & 4 deletions src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
Expand Down Expand Up @@ -192,6 +193,26 @@ internal LightGbmRegressionTrainer(IHostEnvironment env, Options options)
{
}

/// <summary>
/// Initializes a new instance of <see cref="LightGbmRegressionTrainer"/>
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
/// <param name="featureColumnName">The name of the feature column.</param>
internal LightGbmRegressionTrainer(IHostEnvironment env,
Stream lightGbmModel,
string featureColumnName = DefaultColumnNames.Features)
: base(env,
LoadNameValue,
new Options()
{
FeatureColumnName = featureColumnName,
LightGbmModel = lightGbmModel
},
new SchemaShape.Column())
{
}

private protected override LightGbmRegressionModelParameters CreatePredictor()
{
Host.Check(TrainedEnsemble != null,
Expand All @@ -204,11 +225,16 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
{
Host.AssertValue(ch);
base.CheckDataValid(ch, data);
var labelType = data.Schema.Label.Value.Type;
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))

// If using a pre-trained model file we don't need a label column
if (LightGbmTrainerOptions.LightGbmModel == null)
{
throw ch.ExceptParam(nameof(data),
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be an unsigned int, boolean or float.");
var labelType = data.Schema.Label.Value.Type;
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
{
throw ch.ExceptParam(nameof(data),
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be an unsigned int, boolean or float.");
}
}
}

Expand Down
Loading

0 comments on commit f93ab25

Please sign in to comment.