Avoid panics when incorrect buffer sizes are provided to read/write_frames

This commit is contained in:
Ian Hobson
2023-05-12 16:52:06 +02:00
parent 3400778991
commit 3d1826007e
3 changed files with 22 additions and 15 deletions

View File

@@ -41,6 +41,12 @@ pub enum Error {
/// The file is not optimized for writing new data /// The file is not optimized for writing new data
DataChunkNotPreparedForAppend, DataChunkNotPreparedForAppend,
/// A buffer with a length that isn't a multiple of channel_count was provided
InvalidBufferSize {
buffer_size: usize,
channel_count: u16,
},
} }
impl StdError for Error {} impl StdError for Error {}

View File

@@ -110,12 +110,12 @@ impl<R: Read + Seek> AudioFrameReader<R> {
let common_format = self.format.common_format(); let common_format = self.format.common_format();
let bits_per_sample = self.format.bits_per_sample; let bits_per_sample = self.format.bits_per_sample;
assert!( if buffer.len() % channel_count != 0 {
buffer.len() % channel_count == 0, return Err(Error::InvalidBufferSize {
"read_frames was called with a mis-sized buffer, expected a multiple of {}, was {}", buffer_size: buffer.len(),
channel_count, channel_count: self.format.channel_count,
buffer.len() });
); }
let position = self.inner.stream_position()? - self.start; let position = self.inner.stream_position()? - self.start;
let frames_requested = (buffer.len() / channel_count) as u64; let frames_requested = (buffer.len() / channel_count) as u64;

View File

@@ -37,21 +37,22 @@ where
/// Write interleaved samples in `buffer` /// Write interleaved samples in `buffer`
/// ///
/// # Panics /// The writer will convert from the buffer's sample type into the file's sample type.
/// /// Note that no dithering will be applied during sample type conversion,
/// This function will panic if `buffer.len()` modulo the Wave file's channel count /// if dithering is required then it will need to be applied manually.
/// is not zero. pub fn write_frames<S>(&mut self, buffer: &[S]) -> Result<(), Error>
pub fn write_frames<S>(&mut self, buffer: &[S]) -> Result<u64, Error>
where where
S: Sample, S: Sample,
{ {
let format = &self.inner.inner.format; let format = &self.inner.inner.format;
let channel_count = format.channel_count as usize; let channel_count = format.channel_count as usize;
assert!( if buffer.len() % channel_count != 0 {
buffer.len() % channel_count == 0, return Err(Error::InvalidBufferSize {
"frames buffer does not contain a number of samples % channel_count == 0" buffer_size: buffer.len(),
); channel_count: format.channel_count,
});
}
let mut write_buffer = self let mut write_buffer = self
.inner .inner