ML.NET(二) 使用机器学习预测表情分析

这个例子使用 ML.NET 的图像分类(Image Classification)训练一个模型,并用它对表情图片进行分类预测:

准备数据:按表情类别(happy、sad 等)把图片分别放入 assets 目录下的子文件夹;

using Common;
using ConsoleApp2;
using Microsoft.ML;
using Microsoft.ML.Data;
using System.Diagnostics;
using static Microsoft.ML.Transforms.ValueToKeyMappingEstimator;


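/*
 * Trains an image-classification model on facial expressions (happy, sad, etc.) and uses it for prediction.
 * NOTE: predictions are not very accurate. Consider using a dataset that matches the target population,
 *       e.g. training on Asian and European/American faces separately.
 */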

// Resolve the folders used by this script relative to the application's base directory.
string projectDirectory = Path.GetFullPath(Path.Combine(AppContext.BaseDirectory, "./"));
string workspaceRelativePath = Path.Combine(projectDirectory, "workspace");
string assetsRelativePath = Path.Combine(projectDirectory, "assets");

// Output file for the trained model, and the folder holding images for the sample prediction.
string outputMlNetModelFilePath = "model.zip";
string imagesFolderPathForPredictions = Path.Combine("", "inputs", "test-images");

// 1. Create the ML.NET environment.
MLContext mlContext = new();

// 2. Enumerate the labeled images under the assets folder, wrap them in an IDataView and shuffle the rows.
IEnumerable<ImageData> labeledImages = LoadImagesFromDirectory(assetsRelativePath, true);
IDataView shuffledImagePaths = mlContext.Data.ShuffleRows(
    mlContext.Data.LoadFromEnumerable(labeledImages));

// 3. Map the "Label" column to a key column ("LabelAsKey") and load each image's raw bytes into "Image".
var labelToKey = mlContext.Transforms.Conversion.MapValueToKey(
    outputColumnName: "LabelAsKey",
    inputColumnName: "Label",
    keyOrdinality: KeyOrdinality.ByValue);
var loadImageBytes = mlContext.Transforms.LoadRawImageBytes(
    outputColumnName: "Image",
    imageFolder: assetsRelativePath,
    inputColumnName: "ImagePath");
ITransformer preprocessing = labelToKey.Append(loadImageBytes).Fit(shuffledImagePaths);
IDataView preprocessedImages = preprocessing.Transform(shuffledImagePaths);

// 4. Split the data 80:20 into train and test sets.
var split = mlContext.Data.TrainTestSplit(preprocessedImages, testFraction: 0.2);
IDataView trainDataView = split.TrainSet;
IDataView testDataView = split.TestSet;

// 5. Define the training pipeline: the image classifier (trainer defaults, with the test split
//    passed as its validation set) followed by mapping the predicted key back to its label value.
var classifier = mlContext.MulticlassClassification.Trainers.ImageClassification(
    featureColumnName: "Image",
    labelColumnName: "LabelAsKey",
    validationSet: testDataView);
var keyToLabel = mlContext.Transforms.Conversion.MapKeyToValue(
    outputColumnName: "PredictedLabel",
    inputColumnName: "PredictedLabel");
var trainingPipeline = classifier.Append(keyToLabel);

// 6. Train, measuring how long the fit takes.
var trainingTimer = Stopwatch.StartNew();
Console.WriteLine("--------------------开始训练-------------------------------");

ITransformer trainedModel = trainingPipeline.Fit(trainDataView);

trainingTimer.Stop();

// Whole seconds only: the millisecond count is divided with integer arithmetic.
Console.WriteLine($"--------------------训练用时: {trainingTimer.ElapsedMilliseconds / 1000} seconds --------------------");

// 7. Get the quality metrics (accuracy, etc.) on the test split.
EvaluateModel(mlContext, testDataView, trainedModel);

 // 8. Save the trained model, together with the training data schema, to outputMlNetModelFilePath.
mlContext.Model.Save(trainedModel, trainDataView.Schema, outputMlNetModelFilePath);
Console.WriteLine($"Model saved to: {outputMlNetModelFilePath}");

// 9. Try a single prediction simulating an end-user app.
TrySinglePrediction(imagesFolderPathForPredictions, mlContext, trainedModel);


// Wraps FileUtils.LoadImagesFromDirectory: every (imagePath, label) item it yields is
// projected into an ImageData instance. Both arguments are forwarded unchanged.
// The projection is a LINQ query, so it is evaluated lazily when the result is enumerated.
static IEnumerable<ImageData> LoadImagesFromDirectory(
    string folder,
    bool useFolderNameAsLabel = true)
{
    return from image in FileUtils.LoadImagesFromDirectory(folder, useFolderNameAsLabel)
           select new ImageData(image.imagePath, image.label);
}

// Scores testDataset with trainedModel, computes multiclass classification metrics
// (label column "LabelAsKey", predicted label column "PredictedLabel"), prints them through
// ConsoleHelper, and reports how long scoring plus evaluation took.
static void EvaluateModel(MLContext mlContext, IDataView testDataset, ITransformer trainedModel)
{
    Console.WriteLine("Making predictions in bulk for evaluating model's quality...");

    // Time the scoring and the metric computation together.
    var evaluationTimer = Stopwatch.StartNew();

    IDataView scoredTestSet = trainedModel.Transform(testDataset);
    var qualityMetrics = mlContext.MulticlassClassification.Evaluate(
        scoredTestSet,
        labelColumnName: "LabelAsKey",
        predictedLabelColumnName: "PredictedLabel");
    ConsoleHelper.PrintMultiClassClassificationMetrics("TensorFlow DNN Transfer Learning", qualityMetrics);

    evaluationTimer.Stop();

    // Integer division: the reported duration is truncated to whole seconds.
    long elapsedSeconds = evaluationTimer.ElapsedMilliseconds / 1000;
    Console.WriteLine($"Predicting and Evaluation took: {elapsedSeconds} seconds");
}
// Runs one prediction with trainedModel on the last image that
// FileUtils.LoadInMemoryImagesFromDirectory yields for imagesFolderPathForPredictions,
// then prints the image file name, the scores and the predicted label.
// If no image is found, a message is printed and no prediction is attempted
// (previously this surfaced as an InvalidOperationException from Last()).
static void TrySinglePrediction(string imagesFolderPathForPredictions, MLContext mlContext, ITransformer trainedModel)
{
    // The second argument is passed as false exactly as before; presumably it disables
    // folder-name labeling for these unlabeled images. Verify against FileUtils.
    var imageToPredict = FileUtils
        .LoadInMemoryImagesFromDirectory(imagesFolderPathForPredictions, false)
        .LastOrDefault();

    if (imageToPredict is null)
    {
        Console.WriteLine(
            $"No images found in [{imagesFolderPathForPredictions}]; skipping the single prediction.");
        return;
    }

    // Create prediction function to try one prediction.
    // NOTE(review): PredictionEngine is disposable in recent ML.NET versions, but disposing it
    // may also dispose the caller-owned model by default. It is therefore left undisposed here.
    // Confirm against the ML.NET version in use.
    var predictionEngine = mlContext.Model
        .CreatePredictionEngine<InMemoryImageData, ImagePrediction>(trainedModel);

    var prediction = predictionEngine.Predict(imageToPredict);

    Console.WriteLine(
        $"Image Filename : [{imageToPredict.ImageFileName}], " +
        $"Scores : [{string.Join(",", prediction.Score)}], " +
        $"Predicted Label : {prediction.PredictedLabel}");
}
// Data structures used by the script.

/// <summary>
/// One labeled image: the path of the image file and the label assigned to it.
/// Both values are fixed at construction time.
/// </summary>
internal class ImageData
{
    /// <summary>Path of the image file.</summary>
    public readonly string ImagePath;

    /// <summary>Label associated with the image.</summary>
    public readonly string Label;

    /// <summary>Creates an instance from an image path and its label.</summary>
    /// <param name="imagePath">Value stored in <see cref="ImagePath"/>.</param>
    /// <param name="label">Value stored in <see cref="Label"/>.</param>
    public ImageData(string imagePath, string label)
    {
        this.ImagePath = imagePath;
        this.Label = label;
    }
}

/// <summary>
/// Row shape whose properties carry the same names as the columns used by the training pipeline
/// ("Image", "LabelAsKey", "ImagePath", "Label"). Not referenced elsewhere in this listing.
/// </summary>
internal class ModelInput
{
    /// <summary>Counterpart of the "Image" column.</summary>
    public byte[] Image { get; set; }

    /// <summary>Counterpart of the "LabelAsKey" column.</summary>
    public uint LabelAsKey { get; set; }

    /// <summary>Counterpart of the "ImagePath" column.</summary>
    public string ImagePath { get; set; }

    /// <summary>Counterpart of the "Label" column.</summary>
    public string Label { get; set; }
}
/// <summary>
/// Row shape with the image path, the original label and the predicted label.
/// Not referenced elsewhere in this listing.
/// </summary>
internal class ModelOutput
{
    /// <summary>Counterpart of the "ImagePath" column.</summary>
    public string ImagePath { get; set; }

    /// <summary>Counterpart of the "Label" column.</summary>
    public string Label { get; set; }

    /// <summary>Counterpart of the "PredictedLabel" column.</summary>
    public string PredictedLabel { get; set; }
}
/// <summary>
/// Output type of the prediction engine created in TrySinglePrediction:
/// the "Score" and "PredictedLabel" columns of the model output.
/// </summary>
public class ImagePrediction
{
    /// <summary>Values of the "Score" column.</summary>
    [ColumnName(nameof(Score))]
    public float[] Score;

    /// <summary>Value of the "PredictedLabel" column.</summary>
    [ColumnName(nameof(PredictedLabel))]
    public string PredictedLabel;
}

相关推荐

  1. DMLC深度机器学习框架MXNet的编译安装

    2024-04-07 23:46:02       34 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-04-07 23:46:02       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-07 23:46:02       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-07 23:46:02       19 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-07 23:46:02       20 阅读

热门阅读

  1. [RK3566-Android11] 关于2K (2560x1440)分辨率支持问题

    2024-04-07 23:46:02       12 阅读
  2. PHP获取亚马逊商品详情api接口

    2024-04-07 23:46:02       17 阅读
  3. 一名顶尖的黑客高手要学些什么?

    2024-04-07 23:46:02       13 阅读
  4. OMP实现压缩感知的实现(MATLAB)

    2024-04-07 23:46:02       18 阅读
  5. git log

    2024-04-07 23:46:02       12 阅读
  6. C语言中的预处理详解

    2024-04-07 23:46:02       13 阅读
  7. 探索自然语言处理:简单而完整的学习路线指南

    2024-04-07 23:46:02       12 阅读
  8. nginx + keepalived 搭建教程

    2024-04-07 23:46:02       13 阅读
  9. Windows常用命令

    2024-04-07 23:46:02       17 阅读
  10. 基于YOLOv8的木材缺陷检测系统说明

    2024-04-07 23:46:02       14 阅读
  11. stable diffusion 预处理器解释大全,不断更新

    2024-04-07 23:46:02       13 阅读