最初前端上传 WAV 格式，后来发现前端生成的 WAV 文件不完整，于是后端改为接收 MP3（实际上 MP3 和 WAV 两种格式后端都可以接收）。
注意：前端生成的 MP3 和 WAV 文件格式均不规范，后端虽然可以接收，但都无法正确计算时长。
依赖
<!-- TensorFlow Java core API 0.4.2; YamnetUtils3 uses it to load and run the YAMNet SavedModel. -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-api</artifactId>
<version>0.4.2</version>
</dependency>
<!-- Same artifact with the linux-x86_64 classifier; presumably the platform-specific native
     TensorFlow binaries for the Linux server (the /usr/local/... paths) - confirm against the
     tensorflow-java documentation. -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-api</artifactId>
<version>0.4.2</version>
<classifier>linux-x86_64</classifier>
</dependency>
<!-- Same artifact with the windows-x86_64 classifier; presumably the native binaries for local
     Windows development (see the C:\Users\... paths). Keep the version in sync with the entries above. -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-api</artifactId>
<version>0.4.2</version>
<classifier>windows-x86_64</classifier>
</dependency>
<!-- JLayer MP3 decoder (javazoom.jl.decoder.*): decodes MP3 uploads and counts MP3 frames for the duration. -->
<!-- https://mvnrepository.com/artifact/com.googlecode.soundlibs/jlayer -->
<dependency>
<groupId>com.googlecode.soundlibs</groupId>
<artifactId>jlayer</artifactId>
<version>1.0.1.4</version>
</dependency>
TensorFlow工具类
package com.ruoyi.webapp.tensorflow;
import org.springframework.stereotype.Component;
import org.springframework.web.multipart.MultipartFile;
import javax.sound.sampled.*;
import java.io.*;
import java.nio.file.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.tensorflow.*;
import org.tensorflow.ndarray.*;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.types.TFloat32;
import com.google.protobuf.InvalidProtocolBufferException;
import javazoom.jl.decoder.Bitstream;
import javazoom.jl.decoder.BitstreamException;
import javazoom.jl.decoder.Decoder;
import javazoom.jl.decoder.JavaLayerException;
import javazoom.jl.decoder.SampleBuffer;
import javazoom.jl.decoder.Header;
/**
 * Classifies uploaded audio with the YAMNet sound-event model (TensorFlow SavedModel).
 *
 * <p>The SavedModel and the class-index-to-name table are loaded once, when the class is
 * initialised, and are shared by all calls. Only {@code run} is invoked on the shared session.
 * Each upload is decoded in memory: formats known to {@code javax.sound} (e.g. WAV) are read
 * directly, and anything else is decoded as MP3 with JLayer. The audio is then down-mixed to
 * mono, resampled to {@value #SAMPLE_RATE} Hz and scaled to [-1, 1) before being fed to the
 * model in one piece.
 *
 * <p>The model directory defaults to {@code /usr/local/develop/archive}. It can be overridden
 * with the system property {@code yamnet.model.path}, e.g. for a local Windows checkout.
 */
@Component
public class YamnetUtils3 {

    /** Sample rate, in Hz, of the waveform fed to the model. */
    private static final int SAMPLE_RATE = 16000;

    /** TensorFlow SavedModel directory. */
    private static final String MODEL_PATH =
            System.getProperty("yamnet.model.path", "/usr/local/develop/archive");

    /** Class map CSV: first column is the class index, third column the class name. */
    private static final String CLASS_MAP_CSV = MODEL_PATH + "/assets/yamnet_class_map.csv";

    /** Loaded once and kept open for the lifetime of the JVM instead of once per request. */
    private static final SavedModelBundle MODEL;
    private static final String INPUT_TENSOR_NAME;
    private static final String OUTPUT_TENSOR_NAME;

    static {
        SavedModelBundle bundle = SavedModelBundle.load(MODEL_PATH, "serve");
        try {
            SignatureDef signature = MetaGraphDef.parseFrom(bundle.metaGraphDef().toByteArray())
                    .getSignatureDefMap()
                    .get("serving_default");
            if (signature == null) {
                throw new IllegalStateException("No serving_default signature in " + MODEL_PATH);
            }
            INPUT_TENSOR_NAME = signature.getInputsMap().get("waveform").getName();
            // Assumed to be YAMNet's per-frame class scores, shape [frames, classes].
            OUTPUT_TENSOR_NAME = signature.getOutputsMap().get("output_0").getName();
        } catch (InvalidProtocolBufferException | RuntimeException e) {
            // Fail class initialisation with the real cause instead of a later NullPointerException.
            bundle.close();
            throw new ExceptionInInitializerError(e);
        }
        MODEL = bundle;
    }

    /** Class index (decimal string) to class name. */
    private static final Map<String, String> CLASS_NAMES = new ConcurrentHashMap<>();

    static {
        try {
            for (String line : Files.readAllLines(Paths.get(CLASS_MAP_CSV))) {
                // Names may contain commas (CSV-quoted), so split into at most three fields.
                String[] fields = line.split(",", 3);
                if (fields.length == 3) {
                    CLASS_NAMES.put(fields[0].trim(), unquote(fields[2].trim()));
                }
            }
        } catch (IOException e) {
            // Without the table every classification would silently come back null.
            throw new UncheckedIOException("Cannot read YAMNet class map " + CLASS_MAP_CSV, e);
        }
    }

    /**
     * Classifies an uploaded recording.
     *
     * @param file uploaded audio, either in a format readable by {@code javax.sound} (e.g. WAV)
     *             or MP3
     * @return the name of the class with the highest score summed over all model frames, or
     *         {@code null} when the recording has no samples, the model returns no frames, or
     *         the winning index is missing from the class map
     * @throws UnsupportedAudioFileException if the upload is neither readable by
     *         {@code javax.sound} nor a decodable MP3 stream
     * @throws IOException if the upload cannot be read, or the MP3 stream is damaged before its
     *         first frame
     */
    public String classifyAudio(MultipartFile file) throws IOException, UnsupportedAudioFileException {
        // Decoded in memory: no temp files to leak, and the upload itself is not moved.
        return yamnetPare(decodeWaveform(file.getBytes()));
    }

    /** Decodes an upload into the model input: mono, {@value #SAMPLE_RATE} Hz, in [-1, 1). */
    private static float[] decodeWaveform(byte[] data) throws IOException, UnsupportedAudioFileException {
        AudioInputStream audio;
        try {
            audio = AudioSystem.getAudioInputStream(new ByteArrayInputStream(data));
        } catch (UnsupportedAudioFileException notSampledFormat) {
            // MP3 is not handled by the built-in javax.sound readers; decode it with JLayer.
            audio = decodeMp3(data);
        }
        try (AudioInputStream in = audio) {
            return toMonoWaveform(in);
        }
    }

    /**
     * Decodes MP3 data into a 16-bit signed little-endian PCM stream with the MP3's own sample
     * rate and channel count.
     */
    private static AudioInputStream decodeMp3(byte[] data) throws IOException, UnsupportedAudioFileException {
        Bitstream bitstream = new Bitstream(new ByteArrayInputStream(data));
        Decoder decoder = new Decoder();
        ByteArrayOutputStream pcm = new ByteArrayOutputStream();
        int channels = 0;
        int sampleRate = 0;
        try {
            Header header;
            while ((header = bitstream.readFrame()) != null) {
                SampleBuffer output = (SampleBuffer) decoder.decodeFrame(header, bitstream);
                channels = output.getChannelCount();
                sampleRate = output.getSampleFrequency();
                short[] samples = output.getBuffer();
                // getBuffer() is the decoder's fixed-size internal array; only the first
                // getBufferLength() (channel-interleaved) entries belong to this frame.
                for (int i = 0, n = output.getBufferLength(); i < n; i++) {
                    pcm.write(samples[i]);       // low byte
                    pcm.write(samples[i] >> 8);  // high byte
                }
                bitstream.closeFrame();
            }
        } catch (JavaLayerException e) {
            if (pcm.size() == 0) {
                throw new IOException("Failed to decode MP3 audio", e);
            }
            // Damaged or truncated tail: keep the audio decoded from the frames before it.
        } finally {
            try {
                bitstream.close();
            } catch (BitstreamException ignored) {
                // In-memory source; nothing to release.
            }
        }
        if (pcm.size() == 0) {
            throw new UnsupportedAudioFileException(
                    "Upload is neither a javax.sound audio file nor a decodable MP3 stream");
        }
        AudioFormat format = new AudioFormat(sampleRate, 16, channels, true, false);
        byte[] bytes = pcm.toByteArray();
        return new AudioInputStream(new ByteArrayInputStream(bytes), format, bytes.length / format.getFrameSize());
    }

    /** Converts a stream to 16-bit PCM, averages its channels, and resamples it to the model rate. */
    private static float[] toMonoWaveform(AudioInputStream source) throws IOException, UnsupportedAudioFileException {
        AudioFormat sourceFormat = source.getFormat();
        int channels = sourceFormat.getChannels();
        float sampleRate = sourceFormat.getSampleRate();
        if (channels < 1 || sampleRate <= 0) {
            throw new UnsupportedAudioFileException("Audio format lacks channel count or sample rate: " + sourceFormat);
        }
        AudioFormat pcm16 = new AudioFormat(sampleRate, 16, channels, true, false);
        if (!AudioSystem.isConversionSupported(pcm16, sourceFormat)) {
            throw new UnsupportedAudioFileException("Cannot convert " + sourceFormat + " to " + pcm16);
        }
        byte[] bytes;
        try (AudioInputStream pcm = AudioSystem.getAudioInputStream(pcm16, source)) {
            bytes = readAll(pcm);
        }
        int frameBytes = 2 * channels;
        int frames = bytes.length / frameBytes;
        float[] mono = new float[frames];
        for (int f = 0; f < frames; f++) {
            int sum = 0;
            for (int c = 0; c < channels; c++) {
                int at = f * frameBytes + 2 * c;
                sum += (short) ((bytes[at] & 0xFF) | (bytes[at + 1] << 8));
            }
            // Dividing by channels * 32768 averages the channels and keeps the result in [-1, 1).
            mono[f] = sum / (channels * 32768f);
        }
        return resample(mono, sampleRate, SAMPLE_RATE);
    }

    /**
     * Resamples mono audio to {@code targetRate}. Down-sampling averages the source samples that
     * fall into each output sample's span, which acts as a crude anti-aliasing filter. Up-sampling
     * interpolates linearly.
     */
    private static float[] resample(float[] samples, float sourceRate, int targetRate) {
        if (samples.length == 0 || sourceRate == targetRate) {
            return samples;
        }
        double step = (double) sourceRate / targetRate; // source samples per output sample
        int outLength = (int) (samples.length / step);
        float[] out = new float[outLength];
        for (int j = 0; j < outLength; j++) {
            if (step > 1) {
                // Spans [floor(j*step), floor((j+1)*step)) tile the input without overlap.
                int from = (int) (j * step);
                int to = Math.min(samples.length, (int) ((j + 1) * step));
                double sum = 0;
                for (int i = from; i < to; i++) {
                    sum += samples[i];
                }
                out[j] = (float) (sum / Math.max(1, to - from));
            } else {
                double position = j * step;
                int i0 = (int) position;
                int i1 = Math.min(i0 + 1, samples.length - 1);
                out[j] = (float) (samples[i0] + (samples[i1] - samples[i0]) * (position - i0));
            }
        }
        return out;
    }

    /** Reads a stream to its end (avoids InputStream.readAllBytes, which requires Java 9). */
    private static byte[] readAll(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] chunk = new byte[8192];
        for (int n; (n = in.read(chunk)) != -1; ) {
            out.write(chunk, 0, n);
        }
        return out.toByteArray();
    }

    /**
     * Runs the model on the whole waveform and returns the best class for the clip.
     * Framing is left to the model; its scores output is expected to have one row per frame.
     */
    private String yamnetPare(float[] waveform) {
        if (waveform.length == 0) {
            return null;
        }
        try (TFloat32 input = TFloat32.tensorOf(StdArrays.ndCopyOf(waveform))) {
            // The session belongs to the shared bundle, so it must not be closed here.
            List<Tensor> outputs = MODEL.session().runner()
                    .feed(INPUT_TENSOR_NAME, input)
                    .fetch(OUTPUT_TENSOR_NAME)
                    .run();
            try {
                return topClass((TFloat32) outputs.get(0));
            } finally {
                // Output tensors hold native memory; release them on every path.
                for (Tensor output : outputs) {
                    output.close();
                }
            }
        }
    }

    /** Name of the class whose score summed over all frames (same ranking as the mean) is highest. */
    private static String topClass(TFloat32 scores) {
        Shape shape = scores.shape();
        if (shape.numDimensions() != 2 || shape.size(0) == 0) {
            return null;
        }
        long frames = shape.size(0);
        long classes = shape.size(1);
        long best = -1;
        double bestTotal = Double.NEGATIVE_INFINITY;
        for (long c = 0; c < classes; c++) {
            double total = 0;
            for (long f = 0; f < frames; f++) {
                total += scores.getFloat(f, c);
            }
            if (total > bestTotal) {
                bestTotal = total;
                best = c;
            }
        }
        return best < 0 ? null : CLASS_NAMES.get(String.valueOf(best));
    }

    /** Strips one pair of surrounding double quotes, as used for CSV fields that contain commas. */
    private static String unquote(String field) {
        if (field.length() >= 2 && field.startsWith("\"") && field.endsWith("\"")) {
            return field.substring(1, field.length() - 1);
        }
        return field;
    }
}
文件接口
// Records an uploaded sleep-monitoring audio file.
/**
 * Sleep-monitoring upload endpoint. It saves the recording, classifies it with YAMNet and, when
 * a class was found, stores a {@code TReportFile} row with the category, duration and file path.
 *
 * <p>Categories other than {@code Speech}, {@code Snoring} and {@code Cough} are stored as
 * {@code Other}.
 *
 * @param file the uploaded recording; its duration is computed with {@code getMp3Duration}
 * @return {@code success()}, including when no class was found and nothing was recorded; or
 *         {@code error()} when the upload is missing or empty, or any step throws
 */
@PostMapping("/reportUpload")
//@RequireVip
public AjaxResult reportUpload(MultipartFile file) throws IOException {
    // Reject missing/empty uploads before anything is written to disk or classified.
    if (file == null || file.isEmpty()) {
        return error();
    }
    try {
        // Upload root directory (on the server: /usr/local/nginx/html/upload/upload)
        String filePath = RuoYiConfig.getUploadPath();
        // Save the file; returns its new relative path
        String fileName = FileUploadUtils.upload(filePath, file);
        // YAMNet class name of the recording
        String s = yamnetUtils3.classifyAudio(file);
        // Duration in whole seconds, e.g. "12s"
        String wavDuration = getMp3Duration(fileName);
        if (!StringUtils.isEmpty(s)) {
            if (!s.equals("Speech") && !s.equals("Snoring") && !s.equals("Cough")) {
                s = "Other";
            }
            TReportFile tReportFile = new TReportFile();
            tReportFile.setUid(getLoginUser().getUserId());
            tReportFile.setFileType(s);
            tReportFile.setLengths(wavDuration);
            tReportFile.setFilePath(fileName);
            tReportFileService.insertTReportFile(tReportFile);
        }
        return success();
    } catch (Exception e) {
        e.printStackTrace();
        return error();
    }
}
// Returns the duration of an uploaded WAV file.
/**
 * Returns the playback length of an uploaded WAV file, rounded to whole seconds.
 *
 * @param relativeFilePath path as returned by {@code FileUploadUtils.upload}; a leading
 *        {@code /profile} is replaced by the nginx upload directory
 * @return e.g. {@code "12s"}; {@code "File not found"}; or {@code "Error"} when the file cannot
 *         be parsed or its header does not specify the frame count or frame rate
 */
private String getWavDuration(String relativeFilePath) {
    String basePath = "/usr/local/nginx/html/upload";
    String profilePrefix = "/profile";
    if (relativeFilePath.startsWith(profilePrefix)) {
        relativeFilePath = relativeFilePath.substring(profilePrefix.length());
    }
    File file = new File(basePath + relativeFilePath);
    if (!file.exists()) {
        return "File not found";
    }
    // One open is enough: the stream already carries the file's audio format.
    try (AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(file)) {
        long frameLength = audioInputStream.getFrameLength();
        float frameRate = audioInputStream.getFormat().getFrameRate();
        // NOT_SPECIFIED (-1) in either value would otherwise produce a negative or meaningless duration.
        if (frameLength == AudioSystem.NOT_SPECIFIED || frameRate <= 0) {
            return "Error";
        }
        float durationInSeconds = frameLength / frameRate;
        return Math.round(durationInSeconds) + "s";
    } catch (UnsupportedAudioFileException | IOException e) {
        e.printStackTrace();
        return "Error";
    }
}
// Returns the duration of an uploaded MP3 file.
/**
 * Returns the playback length of an uploaded MP3 file, rounded to whole seconds. The length is
 * the sum of the durations of all MPEG audio frames.
 *
 * @param relativeFilePath path as returned by {@code FileUploadUtils.upload}; a leading
 *        {@code /profile} is replaced by the nginx upload directory
 * @return e.g. {@code "12s"}; {@code "File not found"}; or {@code "Error"} when the file cannot
 *         be read or the stream is damaged before its first frame. If the stream breaks off
 *         after some valid frames, the length of those frames is returned.
 */
public static String getMp3Duration(String relativeFilePath) {
    String basePath = "/usr/local/nginx/html/upload";
    String profilePrefix = "/profile";
    if (relativeFilePath.startsWith(profilePrefix)) {
        relativeFilePath = relativeFilePath.substring(profilePrefix.length());
    }
    File file = new File(basePath + relativeFilePath);
    if (!file.exists()) {
        return "File not found";
    }
    // The Bitstream reads from this stream, so closing it releases the file handle.
    try (FileInputStream fileInputStream = new FileInputStream(file)) {
        Bitstream bitstream = new Bitstream(fileInputStream);
        // double: a float sum loses precision once it reaches millions of milliseconds.
        double totalMillis = 0;
        int frameCount = 0;
        try {
            Header header;
            while ((header = bitstream.readFrame()) != null) {
                totalMillis += header.ms_per_frame();
                frameCount++;
                bitstream.closeFrame();
            }
        } catch (BitstreamException e) {
            if (frameCount == 0) {
                throw e;
            }
            // Damaged/truncated tail: report the length of the frames that did parse.
            e.printStackTrace();
        }
        return Math.round(totalMillis / 1000.0) + "s";
    } catch (IOException | BitstreamException e) {
        e.printStackTrace();
        return "Error";
    }
}