TensorFlow音频分类

tensorflow

https://www.tensorflow.org/lite/examples/audio_classification/overview?hl=zh-cn

官方有移动端demo

我不会前端开发,所以只能找找有没有 Java 方面的支持

注意版本

注意JDK版本(下面代码中 JDK 8 与 JDK 17 两种写法的 session.runner().run() 返回类型不同)

package com.example.demo17.controller;


import org.tensorflow.*;
import org.tensorflow.ndarray.*;
import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;

import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import javax.sound.sampled.UnsupportedAudioFileException;
import javax.xml.transform.Result;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Command-line demo that classifies a WAV file with a YAMNet SavedModel through TensorFlow Java.
 *
 * <p>Usage: {@code Test [audioFile [modelDir]]}. Both arguments are optional and fall back to the
 * hard-coded defaults below. The class map CSV is read from
 * {@code <modelDir>/assets/yamnet_class_map.csv}.
 *
 * <p>The program prints the shape of the score tensor followed by the display name of the class
 * with the highest score averaged over all frames.
 */
public class Test {

    /** Sample rate, in Hz, of the waveform YAMNet expects. */
    private static final int EXPECTED_SAMPLE_RATE = 16000;

    /** Audio file classified when no command-line argument is given. */
    private static final String DEFAULT_AUDIO_FILE = "C:\\Users\\user\\Downloads\\output_Wo9KJb-5zuz1_2.wav";

    /** SavedModel directory used when no second command-line argument is given. */
    private static final String DEFAULT_MODEL_DIR = "C:\\Users\\user\\Downloads\\archive";

    /** Class index (first CSV column, as text) mapped to display name (third CSV column). */
    static Map<String,String> map=new ConcurrentHashMap<>();

    /**
     * Reads a 16-bit signed PCM audio file and returns it as a 1-D mono waveform normalized to
     * [-1.0, 1.0).
     *
     * <p>The whole file is returned as one contiguous waveform; no framing is applied here
     * (assumption: the model performs its own framing, see the YAMNet documentation).
     * Multi-channel audio is downmixed by averaging the channels of each frame.
     *
     * @param audioFilePath path of the audio file to read
     * @return rank-1 array holding one float per audio frame
     * @throws UnsupportedAudioFileException if the file is not 16-bit signed PCM
     * @throws IOException if the file cannot be read or contains no samples
     */
    private static FloatNdArray loadWaveform(String audioFilePath)
            throws IOException, UnsupportedAudioFileException {
        try (AudioInputStream audioStream = AudioSystem.getAudioInputStream(Paths.get(audioFilePath).toFile())) {
            AudioFormat format = audioStream.getFormat();
            int channels = format.getChannels();
            if (!AudioFormat.Encoding.PCM_SIGNED.equals(format.getEncoding())
                    || format.getSampleSizeInBits() != 16
                    || channels < 1
                    || format.getFrameSize() != channels * 2) {
                throw new UnsupportedAudioFileException(
                        "Only 16-bit signed PCM audio is supported, got: " + format);
            }
            if (format.getSampleRate() != EXPECTED_SAMPLE_RATE) {
                // Resampling is not attempted here; the caller is only warned.
                System.out.println("Warning: Audio must be 16kHz (file is " + format.getSampleRate()
                        + " Hz). Consider preprocessing.");
            }

            // The buffer length is a whole number of frames, so every read decodes cleanly.
            byte[] buffer = new byte[format.getFrameSize() * 4096];
            float[] samples = new float[EXPECTED_SAMPLE_RATE];
            int sampleCount = 0;
            int bytesRead;
            while ((bytesRead = audioStream.read(buffer)) != -1) {
                // Only the bytes actually read are decoded, so a short read never re-emits
                // stale samples left over from the previous chunk.
                float[] chunk = processFrame(decodeMono(buffer, bytesRead, channels, format.isBigEndian()));
                if (sampleCount + chunk.length > samples.length) {
                    samples = Arrays.copyOf(samples, Math.max(samples.length * 2, sampleCount + chunk.length));
                }
                System.arraycopy(chunk, 0, samples, sampleCount, chunk.length);
                sampleCount += chunk.length;
            }
            if (sampleCount == 0) {
                throw new IOException("No audio samples found in " + audioFilePath);
            }
            return StdArrays.ndCopyOf(Arrays.copyOf(samples, sampleCount));
        }
    }

    /**
     * Decodes interleaved 16-bit PCM bytes into one mono sample per frame.
     *
     * @param pcm raw interleaved PCM bytes
     * @param byteCount number of valid bytes in {@code pcm}; a trailing partial frame is ignored
     * @param channels number of interleaved channels (at least 1)
     * @param bigEndian whether each 16-bit sample stores its high byte first
     * @return per-frame average of all channels
     */
    private static short[] decodeMono(byte[] pcm, int byteCount, int channels, boolean bigEndian) {
        int bytesPerFrame = channels * 2;
        int frames = byteCount / bytesPerFrame;
        short[] mono = new short[frames];
        for (int f = 0; f < frames; f++) {
            int sum = 0;
            for (int c = 0; c < channels; c++) {
                int offset = f * bytesPerFrame + c * 2;
                int high = bigEndian ? pcm[offset] : pcm[offset + 1];
                int low = bigEndian ? pcm[offset + 1] : pcm[offset];
                // The high byte keeps its sign; the low byte is masked to avoid sign extension.
                sum += (short) ((high << 8) | (low & 0xFF));
            }
            mono[f] = (short) (sum / channels);
        }
        return mono;
    }

    /**
     * Normalizes 16-bit samples to floats by dividing by 32768.
     *
     * @param frame signed 16-bit samples
     * @return array of the same length with values in [-1.0, 1.0)
     */
    private static float[] processFrame(short[] frame) {
        float[] normalizedFrame = new float[frame.length];
        for (int i = 0; i < frame.length; i++) {
            normalizedFrame[i] = frame[i] / 32768f;
        }
        return normalizedFrame;
    }

    /**
     * Looks up the graph tensor name registered under {@code key} in a signature's input or
     * output map.
     *
     * @throws IllegalStateException if the signature has no entry for {@code key}
     */
    private static String tensorName(Map<String, TensorInfo> tensors, String key) {
        TensorInfo info = tensors.get(key);
        if (info == null) {
            throw new IllegalStateException(
                    "Signature has no tensor named '" + key + "'; available: " + tensors.keySet());
        }
        return info.getName();
    }

    /**
     * Returns the column index with the highest mean value in a {@code [frames, classes]} matrix.
     *
     * @param scores rank-2 score matrix
     * @param frames number of rows (must be positive)
     * @param classes number of columns (must be positive)
     * @return index of the best-scoring class
     */
    private static int argMaxOfMeanScores(FloatNdArray scores, long frames, long classes) {
        int best = 0;
        double bestMean = Double.NEGATIVE_INFINITY;
        for (long c = 0; c < classes; c++) {
            double sum = 0;
            for (long f = 0; f < frames; f++) {
                sum += scores.getFloat(f, c);
            }
            double mean = sum / frames;
            if (mean > bestMean) {
                bestMean = mean;
                best = (int) c;
            }
        }
        return best;
    }

    /**
     * Fills {@link #map} from the class map CSV (columns: index, mid, display name).
     *
     * <p>Best effort: if the file cannot be read the error is reported on stderr and the map is
     * left as it was, so the caller can still print the numeric class index.
     *
     * @param csvFile path of the CSV file
     */
    private static void loadClassMap(String csvFile) {
        try {
            List<String> lines = Files.readAllLines(Paths.get(csvFile));
            for (String line : lines) {
                // Limit 3 keeps commas inside the (possibly quoted) display name intact.
                String[] values = line.split(",", 3);
                if (values.length < 3) {
                    continue;
                }
                String index = values[0].trim();
                // Skips the header row and anything else that is not a class index.
                if (index.isEmpty() || !index.chars().allMatch(Character::isDigit)) {
                    continue;
                }
                String name = values[2].trim();
                if (name.length() >= 2 && name.startsWith("\"") && name.endsWith("\"")) {
                    name = name.substring(1, name.length() - 1).replace("\"\"", "\"");
                }
                map.put(index, name);
            }
        } catch (IOException e) {
            System.err.println("Could not read class map " + csvFile + ": " + e);
        }
    }

    /**
     * Classifies an audio file and prints the score tensor shape and the winning class name.
     *
     * @param args optional: {@code args[0]} audio file path, {@code args[1]} SavedModel directory
     * @throws Exception if the audio cannot be decoded or the model cannot be loaded or run
     */
    public static void main(String[] args) throws Exception {
        String audioFile = args != null && args.length > 0 ? args[0] : DEFAULT_AUDIO_FILE;
        String modelDir = args != null && args.length > 1 ? args[1] : DEFAULT_MODEL_DIR;

        FloatNdArray waveform = loadWaveform(audioFile);

        // The bundle owns the graph and session; closing it releases both.
        try (SavedModelBundle savedModelBundle = SavedModelBundle.load(modelDir, "serve");
             TFloat32 waveformTensor = TFloat32.tensorOf(waveform)) {
            Map<String, SignatureDef> signatureDefMap = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef().toByteArray()).getSignatureDefMap();
            SignatureDef modelSig = signatureDefMap.get("serving_default");
            if (modelSig == null) {
                throw new IllegalStateException(
                        "Model has no 'serving_default' signature; available: " + signatureDefMap.keySet());
            }
            String inputTensorName = tensorName(modelSig.getInputsMap(), "waveform");
            String outputTensorName = tensorName(modelSig.getOutputsMap(), "output_0");

            // NOTE(review): this is the List<Tensor> variant the original notes label "JDK 8";
            // the variant labelled "JDK 17" gets a Result object from run() instead.
            List<Tensor> outputs = savedModelBundle.session().runner()
                    .feed(inputTensorName, waveformTensor)
                    .fetch(outputTensorName)
                    .run();
            try {
                Tensor scoresTensor = outputs.get(0);
                long[] dims = scoresTensor.shape().asArray();
                System.out.println(Arrays.toString(dims));

                // Assumes output_0 holds per-frame class scores shaped [frames, classes]
                // (see the YAMNet documentation); anything else is rejected.
                if (dims.length != 2 || dims[0] < 1 || dims[1] < 1 || !(scoresTensor instanceof TFloat32)) {
                    throw new IllegalStateException(
                            "Expected a non-empty float32 [frames, classes] output, got shape "
                                    + Arrays.toString(dims));
                }
                int best = argMaxOfMeanScores((TFloat32) scoresTensor, dims[0], dims[1]);

                loadClassMap(Paths.get(modelDir, "assets", "yamnet_class_map.csv").toString());
                String label = map.get(String.valueOf(best));
                System.out.println(label != null ? label : "Unknown class (index " + best + ")");
            } finally {
                for (Tensor output : outputs) {
                    output.close();
                }
            }
        }
    }
}

相关推荐

  1. TensorFlow音频分类修复

    2024-06-08 18:58:08       6 阅读
  2. TensorFlow 量化投资分析

    2024-06-08 18:58:08       22 阅读
  3. 音频数据分析注意事项

    2024-06-08 18:58:08       16 阅读
  4. TensorFlow

    2024-06-08 18:58:08       34 阅读
  5. TensorFlow

    2024-06-08 18:58:08       37 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-06-08 18:58:08       19 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-08 18:58:08       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-08 18:58:08       19 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-08 18:58:08       20 阅读

热门阅读

  1. Linux网络编程之select的理解

    2024-06-08 18:58:08       7 阅读
  2. MATLAB sort

    2024-06-08 18:58:08       7 阅读
  3. 2024-06-04 问AI: 介绍一下 Tensorflow 里面的 Keras

    2024-06-08 18:58:08       6 阅读
  4. spec文件是干嘛的?

    2024-06-08 18:58:08       5 阅读
  5. 11本AI人工智能相关电子书推荐(带下载地址)

    2024-06-08 18:58:08       11 阅读
  6. 深度学习 - PyTorch简介

    2024-06-08 18:58:08       6 阅读
  7. springAMQP(示例)

    2024-06-08 18:58:08       8 阅读
  8. QT5.5.0中使用lambda表达式时遇到的问题

    2024-06-08 18:58:08       6 阅读
  9. C++的算法:拓扑排序的原理及应用

    2024-06-08 18:58:08       5 阅读
  10. 百度大模型算法实习岗上岸经验分享!

    2024-06-08 18:58:08       10 阅读
  11. 矩阵相乘torch.einsum()

    2024-06-08 18:58:08       8 阅读
  12. mybatisplus QueryWrapper or 写法

    2024-06-08 18:58:08       10 阅读
  13. window.clearInterval(timer) 清除定时器

    2024-06-08 18:58:08       12 阅读