TensorFlow音频分类修复

原先前端上传的是 wav 格式,后来发现前端生成的 wav 文件不完整,于是后端改为按 mp3 处理;实际上接口对 mp3 和 wav 文件都可以接收。

前端生成的 mp3 和 wav 文件格式都不正确:后端虽然可以接收这些文件,但都无法计算出时长。

修复TensorFlow放到生产后报错问题-CSDN博客

依赖

  <!-- TensorFlow Java core API, used by the utility class below (org.tensorflow.* imports). -->
  <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-api</artifactId>
            <version>0.4.2</version>
        </dependency>
        <!-- Same artifact with the linux-x86_64 classifier; presumably the native build for the
             Linux server the model path points at. NOTE(review): confirm against the TensorFlow
             Java installation guide. -->
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-api</artifactId>
            <version>0.4.2</version>
            <classifier>linux-x86_64</classifier>
        </dependency>
        <!-- Same artifact with the windows-x86_64 classifier; presumably for local development on
             Windows (the utility class keeps commented-out Windows paths). -->
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-api</artifactId>
            <version>0.4.2</version>
            <classifier>windows-x86_64</classifier>
        </dependency>



        <!-- https://mvnrepository.com/artifact/com.googlecode.soundlibs/jlayer -->
        <!-- JLayer provides the javazoom.jl.decoder classes used to decode MP3 frames. -->
        <dependency>
            <groupId>com.googlecode.soundlibs</groupId>
            <artifactId>jlayer</artifactId>
            <version>1.0.1.4</version>
        </dependency>

TensorFlow工具类

package com.ruoyi.webapp.tensorflow;

import org.springframework.stereotype.Component;
import org.springframework.web.multipart.MultipartFile;
import javax.sound.sampled.*;
import java.io.*;
import java.nio.file.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.tensorflow.*;
import org.tensorflow.ndarray.*;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.types.TFloat32;
import com.google.protobuf.InvalidProtocolBufferException;

import javazoom.jl.decoder.Bitstream;
import javazoom.jl.decoder.BitstreamException;
import javazoom.jl.decoder.Decoder;
import javazoom.jl.decoder.JavaLayerException;
import javazoom.jl.decoder.SampleBuffer;
import javazoom.jl.decoder.Header;


/**
 * Classifies an uploaded MP3 clip with a YAMNet SavedModel.
 *
 * <p>Pipeline: the upload is copied to a temporary MP3 file, decoded with JLayer into a temporary
 * 16-bit PCM WAV file, read back as a mono 16 kHz float waveform and fed to the model. Both
 * temporary files are removed before {@link #classifyAudio(MultipartFile)} returns.
 *
 * <p>The SavedModel is loaded once when the class is initialised and kept open for the lifetime of
 * the process; the TensorFlow session it owns is shared by all requests.
 */
@Component
public class YamnetUtils3 {
    /** Sample rate, in Hz, of the waveform handed to the model. */
    private static final int SAMPLE_RATE = 16000;
    /** Bit depth of the intermediate PCM WAV file. */
    private static final int BITS_PER_SAMPLE = 16;
    /** Size in bytes of the canonical RIFF/WAVE header written by this class. */
    private static final int WAV_HEADER_SIZE = 44;
    /**
     * SavedModel directory. Defaults to the production path and can be overridden for local
     * development with {@code -Dyamnet.model.path=...} (e.g. a Windows download folder).
     */
    private static final String MODEL_PATH =
            System.getProperty("yamnet.model.path", "/usr/local/develop/archive");

    /** Class index (first CSV column) to display name (third CSV column). */
    private static final Map<String, String> CLASS_NAMES = new ConcurrentHashMap<>();
    private static final SavedModelBundle MODEL;
    private static final String INPUT_TENSOR_NAME;
    private static final String OUTPUT_TENSOR_NAME;

    static {
        MODEL = SavedModelBundle.load(MODEL_PATH, "serve");
        SignatureDef signature;
        try {
            signature = MetaGraphDef.parseFrom(MODEL.metaGraphDef().toByteArray())
                    .getSignatureDefMap()
                    .get("serving_default");
        } catch (InvalidProtocolBufferException e) {
            // Fail with the real cause instead of a later NullPointerException.
            throw new IllegalStateException("Cannot parse the MetaGraphDef of the model at " + MODEL_PATH, e);
        }
        if (signature == null) {
            throw new IllegalStateException("Model at " + MODEL_PATH + " has no serving_default signature");
        }
        INPUT_TENSOR_NAME = signature.getInputsMap().get("waveform").getName();
        OUTPUT_TENSOR_NAME = signature.getOutputsMap().get("output_0").getName();
        loadClassNames();
    }

    /**
     * Fills {@link #CLASS_NAMES} from {@code assets/yamnet_class_map.csv} inside the model directory.
     *
     * <p>Lines are split into at most three fields so that a quoted display name containing commas
     * stays intact; lines with fewer than three fields are skipped. A read failure is reported and
     * leaves the map empty (classification then yields {@code null}), as before.
     */
    private static void loadClassNames() {
        try {
            List<String> lines = Files.readAllLines(Paths.get(MODEL_PATH, "assets", "yamnet_class_map.csv"));
            for (String line : lines) {
                String[] fields = line.split(",", 3);
                if (fields.length < 3) {
                    continue;
                }
                CLASS_NAMES.put(fields[0].trim(), unquote(fields[2].trim()));
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Strips one pair of surrounding double quotes and un-doubles embedded quotes (CSV quoting). */
    private static String unquote(String field) {
        if (field.length() >= 2 && field.startsWith("\"") && field.endsWith("\"")) {
            return field.substring(1, field.length() - 1).replace("\"\"", "\"");
        }
        return field;
    }

    /**
     * Classifies the uploaded MP3 clip.
     *
     * @param file the uploaded audio; read through its input stream, so it is not moved or consumed
     * @return the display name of the best-scoring class, or {@code null} when the clip is empty or
     *         the class index is missing from the class map
     * @throws IOException if the upload cannot be read or decoded as MP3
     * @throws UnsupportedAudioFileException if the intermediate WAV file cannot be parsed
     */
    public String classifyAudio(MultipartFile file) throws IOException, UnsupportedAudioFileException {
        File wavFile = convertMp3ToWav(file);
        try {
            return yamnetPare(wavFile);
        } finally {
            deleteQuietly(wavFile);
        }
    }

    /**
     * Decodes the uploaded MP3 into a temporary PCM WAV file.
     *
     * @return the WAV file; the caller is responsible for deleting it
     * @throws IOException if copying or decoding fails (the temporary files are removed in that case)
     */
    private File convertMp3ToWav(MultipartFile file) throws IOException {
        File mp3File = File.createTempFile("temp", ".mp3");
        File wavFile = File.createTempFile("temp", ".wav");
        boolean converted = false;
        try {
            try (InputStream upload = file.getInputStream()) {
                // createTempFile already created the target, hence REPLACE_EXISTING.
                Files.copy(upload, mp3File.toPath(), StandardCopyOption.REPLACE_EXISTING);
            }
            decodeMp3(mp3File, wavFile);
            // The output stream is closed at this point, so the file length is final.
            updateWavHeader(wavFile);
            converted = true;
            return wavFile;
        } finally {
            deleteQuietly(mp3File);
            if (!converted) {
                deleteQuietly(wavFile);
            }
        }
    }

    /**
     * Decodes every MP3 frame of {@code mp3File} and writes it as 16-bit little-endian PCM.
     *
     * <p>The WAV header is written when the first frame has been decoded, because only then are the
     * real sample rate and channel count known. Its length fields are placeholders that
     * {@link #updateWavHeader(File)} fills in afterwards.
     */
    private void decodeMp3(File mp3File, File wavFile) throws IOException {
        try (InputStream mp3Stream = new FileInputStream(mp3File);
             OutputStream wavStream = new BufferedOutputStream(new FileOutputStream(wavFile))) {

            Bitstream bitstream = new Bitstream(mp3Stream);
            Decoder decoder = new Decoder();
            boolean headerWritten = false;

            Header header;
            while ((header = bitstream.readFrame()) != null) {
                SampleBuffer output = (SampleBuffer) decoder.decodeFrame(header, bitstream);
                if (!headerWritten) {
                    writeWavHeader(wavStream, 0, output.getChannelCount(),
                            output.getSampleFrequency(), BITS_PER_SAMPLE);
                    headerWritten = true;
                }
                // Only the first getBufferLength() entries belong to this frame; the rest of the
                // backing array is unused capacity.
                short[] samples = output.getBuffer();
                int valid = output.getBufferLength();
                byte[] pcm = new byte[valid * 2];
                for (int i = 0; i < valid; i++) {
                    pcm[2 * i] = (byte) (samples[i] & 0xFF);
                    pcm[2 * i + 1] = (byte) ((samples[i] >> 8) & 0xFF);
                }
                wavStream.write(pcm);
                bitstream.closeFrame();
            }
            if (!headerWritten) {
                throw new IOException("No MP3 frames found in the uploaded file");
            }
        } catch (JavaLayerException e) {
            throw new IOException("Failed to convert MP3 to WAV", e);
        }
    }

    /**
     * Writes a 44-byte canonical RIFF/WAVE header for uncompressed PCM.
     *
     * @param out           destination stream
     * @param totalAudioLen size of the sample data in bytes (0 when it will be patched later)
     * @param channels      number of interleaved channels
     * @param sampleRate    sample rate in Hz
     * @param bitDepth      bits per sample
     */
    private void writeWavHeader(OutputStream out, long totalAudioLen, int channels, long sampleRate, int bitDepth) throws IOException {
        long byteRate = sampleRate * channels * bitDepth / 8;
        int blockAlign = channels * bitDepth / 8; // bytes per sample frame across all channels

        byte[] header = new byte[WAV_HEADER_SIZE];
        putAscii(header, 0, "RIFF");
        putLittleEndian(header, 4, totalAudioLen + 36, 4); // RIFF chunk size
        putAscii(header, 8, "WAVE");
        putAscii(header, 12, "fmt ");
        putLittleEndian(header, 16, 16, 4);                // size of the 'fmt ' chunk
        putLittleEndian(header, 20, 1, 2);                 // format 1 = PCM
        putLittleEndian(header, 22, channels, 2);
        putLittleEndian(header, 24, sampleRate, 4);
        putLittleEndian(header, 28, byteRate, 4);
        putLittleEndian(header, 32, blockAlign, 2);
        putLittleEndian(header, 34, bitDepth, 2);
        putAscii(header, 36, "data");
        putLittleEndian(header, 40, totalAudioLen, 4);     // data chunk size

        out.write(header, 0, header.length);
    }

    /** Patches the RIFF and data chunk sizes of {@code wavFile} from its actual length. */
    private void updateWavHeader(File wavFile) throws IOException {
        try (RandomAccessFile wavRAF = new RandomAccessFile(wavFile, "rw")) {
            long length = wavRAF.length();
            if (length < WAV_HEADER_SIZE) {
                throw new IOException("WAV file is shorter than its header: " + wavFile);
            }
            byte[] size = new byte[4];
            putLittleEndian(size, 0, length - 8, 4);
            wavRAF.seek(4);
            wavRAF.write(size);
            putLittleEndian(size, 0, length - WAV_HEADER_SIZE, 4);
            wavRAF.seek(40);
            wavRAF.write(size);
        }
    }

    /** Copies the characters of {@code text} into {@code target} as single bytes. */
    private static void putAscii(byte[] target, int offset, String text) {
        for (int i = 0; i < text.length(); i++) {
            target[offset + i] = (byte) text.charAt(i);
        }
    }

    /** Stores the low {@code width} bytes of {@code value} at {@code offset}, least significant first. */
    private static void putLittleEndian(byte[] target, int offset, long value, int width) {
        for (int i = 0; i < width; i++) {
            target[offset + i] = (byte) ((value >> (8 * i)) & 0xFF);
        }
    }

    /** Best-effort removal of a temporary file; falls back to deletion at JVM exit. */
    private static void deleteQuietly(File file) {
        if (file != null && !file.delete() && file.exists()) {
            file.deleteOnExit();
        }
    }

    /**
     * Reads a 16-bit little-endian signed PCM file into a mono waveform at {@link #SAMPLE_RATE},
     * with samples scaled to [-1, 1).
     *
     * <p>Multi-channel audio is averaged down to one channel and other sample rates are converted
     * by linear interpolation, so the model no longer receives audio at the wrong rate.
     *
     * @throws UnsupportedAudioFileException if the file is not 16-bit little-endian signed PCM
     */
    private FloatNdArray processAudio(File file) throws IOException, UnsupportedAudioFileException {
        try (AudioInputStream audioStream = AudioSystem.getAudioInputStream(file)) {
            AudioFormat format = audioStream.getFormat();
            int channels = format.getChannels();
            boolean supported = AudioFormat.Encoding.PCM_SIGNED.equals(format.getEncoding())
                    && format.getSampleSizeInBits() == BITS_PER_SAMPLE
                    && !format.isBigEndian()
                    && channels >= 1;
            if (!supported) {
                throw new UnsupportedAudioFileException(
                        "Only 16-bit little-endian signed PCM is supported, got " + format);
            }

            byte[] pcm = readPcmBytes(audioStream, channels * 2);
            float[] mono = toMonoFloats(pcm, channels);
            float[] waveform = resampleLinear(mono, format.getSampleRate(), SAMPLE_RATE);
            return StdArrays.ndCopyOf(waveform);
        }
    }

    /** Drains the stream; the chunk size is a multiple of the frame size so whole frames are read. */
    private static byte[] readPcmBytes(AudioInputStream in, int frameSize) throws IOException {
        ByteArrayOutputStream collected = new ByteArrayOutputStream();
        byte[] chunk = new byte[frameSize * 4096];
        int read;
        while ((read = in.read(chunk)) != -1) {
            collected.write(chunk, 0, read);
        }
        return collected.toByteArray();
    }

    /** Converts interleaved 16-bit little-endian PCM to one float per frame (channel average / 32768). */
    private static float[] toMonoFloats(byte[] pcm, int channels) {
        int frameBytes = channels * 2;
        int frames = pcm.length / frameBytes;
        float[] mono = new float[frames];
        for (int f = 0; f < frames; f++) {
            int base = f * frameBytes;
            int sum = 0;
            for (int c = 0; c < channels; c++) {
                int low = pcm[base + 2 * c] & 0xFF;
                int high = pcm[base + 2 * c + 1]; // keeps the sign
                sum += (short) ((high << 8) | low);
            }
            mono[f] = (sum / (float) channels) / 32768f;
        }
        return mono;
    }

    /**
     * Converts {@code input} from {@code sourceRate} to {@code targetRate} by linear interpolation.
     * No low-pass filter is applied, so this is a simple approximation rather than a high-quality
     * resampler. The input is returned unchanged when the rates match or the source rate is unknown.
     */
    private static float[] resampleLinear(float[] input, float sourceRate, int targetRate) {
        if (input.length == 0 || sourceRate <= 0 || sourceRate == targetRate) {
            return input;
        }
        int outputLength = (int) Math.max(1L, Math.round(input.length * (double) targetRate / sourceRate));
        double step = sourceRate / (double) targetRate;
        float[] output = new float[outputLength];
        for (int i = 0; i < outputLength; i++) {
            double position = i * step;
            int index = (int) position;
            if (index >= input.length - 1) {
                output[i] = input[input.length - 1];
            } else {
                double fraction = position - index;
                output[i] = (float) (input[index] * (1.0 - fraction) + input[index + 1] * fraction);
            }
        }
        return output;
    }

    /**
     * Runs the model on the WAV file and maps the winning class index to its display name.
     *
     * <p>NOTE(review): this assumes {@code output_0} is the {@code [frames, classes]} score matrix,
     * as in the published YAMNet SavedModel — confirm for the deployed model. The previous code
     * looked up the number of frames ({@code shape[0]}) in the class map instead of a class index.
     *
     * @return the display name, or {@code null} for an empty clip or an unknown class index
     */
    private String yamnetPare(File file) throws IOException, UnsupportedAudioFileException {
        FloatNdArray waveform = processAudio(file);
        if (waveform.shape().size(0) == 0) {
            return null;
        }

        // Tensors hold native memory, so both the input and every output are closed explicitly.
        try (TFloat32 input = TFloat32.tensorOf(waveform)) {
            List<Tensor> outputs = MODEL.session().runner()
                    .feed(INPUT_TENSOR_NAME, input)
                    .fetch(OUTPUT_TENSOR_NAME)
                    .run();
            try {
                int best = argMaxOfMeanScores((TFloat32) outputs.get(0));
                return best < 0 ? null : CLASS_NAMES.get(String.valueOf(best));
            } finally {
                for (Tensor output : outputs) {
                    output.close();
                }
            }
        }
    }

    /**
     * Returns the class index with the highest score summed over all frames (equivalent to the
     * highest mean), or -1 when the matrix has no frames or no classes.
     */
    private static int argMaxOfMeanScores(TFloat32 scores) {
        Shape shape = scores.shape();
        if (shape.numDimensions() != 2) {
            throw new IllegalStateException("Expected a [frames, classes] score matrix but got shape " + shape);
        }
        long frames = shape.size(0);
        int classes = (int) shape.size(1);
        if (frames == 0 || classes == 0) {
            return -1;
        }
        double[] totals = new double[classes];
        for (long f = 0; f < frames; f++) {
            for (int c = 0; c < classes; c++) {
                totals[c] += scores.getFloat(f, c);
            }
        }
        int best = 0;
        for (int c = 1; c < classes; c++) {
            if (totals[c] > totals[best]) {
                best = c;
            }
        }
        return best;
    }
}

文件接口

 //记录文件
    /**
     * Upload endpoint used during sleep monitoring: stores the recording, classifies it and, when a
     * class was determined, saves a report-file record for the logged-in user.
     *
     * <p>The {@code @RequireVip} check is currently not applied to this endpoint.
     *
     * @param file the uploaded recording
     * @return {@code success()} when processing finished (even if nothing was recorded because the
     *         classifier returned no label), {@code error()} when any step threw
     */
    @PostMapping("/reportUpload")
    public AjaxResult reportUpload(MultipartFile file) throws IOException {
        try {
            // Not used further; kept as the first access to the upload, as before.
            String uploadedName = file.getOriginalFilename();

            // Store the upload under the configured directory; the result is the new file path.
            String uploadDir = RuoYiConfig.getUploadPath();
            String storedPath = FileUploadUtils.upload(uploadDir, file);

            // Audio class reported by the classifier, then the duration of the stored file.
            String label = yamnetUtils3.classifyAudio(file);
            String duration = getMp3Duration(storedPath);

            if (StringUtils.isEmpty(label)) {
                return success();
            }

            // Only three classes are tracked individually; everything else is grouped as "Other".
            boolean tracked = label.equals("Speech") || label.equals("Snoring") || label.equals("Cough");
            if (!tracked) {
                label = "Other";
            }

            TReportFile record = new TReportFile();
            record.setUid(getLoginUser().getUserId());
            record.setFileType(label);
            record.setLengths(duration);
            record.setFilePath(storedPath);
            tReportFileService.insertTReportFile(record);
            return success();
        } catch (Exception e) {
            e.printStackTrace();
            return error();
        }
    }

 //获取wav文件时长
    /**
     * Returns the duration of a stored WAV file, rounded to whole seconds.
     *
     * @param relativeFilePath path returned by the upload step; a leading {@code /profile} is
     *                         stripped before the path is appended to the upload base directory
     * @return the duration as {@code "<seconds>s"}, {@code "File not found"} when the file does not
     *         exist, or {@code "Error"} when the file cannot be parsed or does not report its
     *         length or frame rate
     */
    private String getWavDuration(String relativeFilePath) {
        String basePath = "/usr/local/nginx/html/upload";
        String profilePrefix = "/profile";
        if (relativeFilePath.startsWith(profilePrefix)) {
            relativeFilePath = relativeFilePath.substring(profilePrefix.length());
        }
        File file = new File(basePath + relativeFilePath);

        if (!file.exists()) {
            return "File not found";
        }

        try (AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(file)) {
            // The stream already carries the format, so the file is parsed only once.
            long frameLength = audioInputStream.getFrameLength();
            float frameRate = audioInputStream.getFormat().getFrameRate();
            if (frameLength == AudioSystem.NOT_SPECIFIED || frameRate <= 0) {
                // Unknown length or rate used to be reported as a misleading "0s".
                return "Error";
            }
            return Math.round(frameLength / (double) frameRate) + "s";
        } catch (UnsupportedAudioFileException | IOException e) {
            e.printStackTrace();
            return "Error";
        }
    }
    /**
     * Returns the duration of a stored MP3 file, rounded to whole seconds, by summing the duration
     * of every frame in the file.
     *
     * @param relativeFilePath path returned by the upload step; a leading {@code /profile} is
     *                         stripped before the path is appended to the upload base directory
     * @return the duration as {@code "<seconds>s"}, {@code "File not found"} when the file does not
     *         exist, or {@code "Error"} when the file cannot be read or parsed
     */
    public static String getMp3Duration(String relativeFilePath) {
        String basePath = "/usr/local/nginx/html/upload";
        String profilePrefix = "/profile";
        if (relativeFilePath.startsWith(profilePrefix)) {
            relativeFilePath = relativeFilePath.substring(profilePrefix.length());
        }
        File file = new File(basePath + relativeFilePath);

        if (!file.exists()) {
            return "File not found";
        }

        try (FileInputStream fileInputStream = new FileInputStream(file)) {
            Bitstream bitstream = new Bitstream(fileInputStream);
            // Accumulate in double: a float total loses precision once a long recording has
            // added up to millions of milliseconds.
            double totalMillis = 0.0;

            Header header;
            while ((header = bitstream.readFrame()) != null) {
                totalMillis += header.ms_per_frame();
                bitstream.closeFrame();
            }

            return Math.round(totalMillis / 1000.0) + "s";
        } catch (IOException | BitstreamException e) {
            e.printStackTrace();
            return "Error";
        }
    }

相关推荐

  1. TensorFlow音频分类修复

    2024-06-18 07:16:04       25 阅读
  2. TensorFlow 量化投资分析

    2024-06-18 07:16:04       40 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-06-18 07:16:04       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-06-18 07:16:04       106 阅读
  3. 在Django里面运行非项目文件

    2024-06-18 07:16:04       87 阅读
  4. Python语言-面向对象

    2024-06-18 07:16:04       96 阅读

热门阅读

  1. 在历史课堂教学过程中培养学生的计算思维能力

    2024-06-18 07:16:04       28 阅读
  2. 【C#基础】C#中的IEnumerable<T>接口介绍

    2024-06-18 07:16:04       32 阅读
  3. 企业级-pdf分页数据推送接收解析保存

    2024-06-18 07:16:04       29 阅读
  4. [absl_py][python]absl_py所有whl文件下载地址汇总

    2024-06-18 07:16:04       39 阅读
  5. XML 应用程序

    2024-06-18 07:16:04       29 阅读
  6. C#语言进阶(一)—委托 第二篇

    2024-06-18 07:16:04       32 阅读
  7. 数实融合创新发展 隆道分享企业级AI应用

    2024-06-18 07:16:04       30 阅读
  8. oracle SCHEDULER

    2024-06-18 07:16:04       26 阅读
  9. mysql面试题 Day2

    2024-06-18 07:16:04       28 阅读
  10. 【Homebrew】包管理器清理软件包,释放mac空间

    2024-06-18 07:16:04       45 阅读