1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
//! Bit-level I/O: reading and writing symbols of arbitrary bit width over byte streams.

#[cfg(test)]
mod tests;

use std::io::Read;
use std::io::Write;
use std::mem::size_of;
use super::Result;
use super::Error::{Eof, IoError, InvalidInput};

/// A trait for counting the number of bytes flowing through a `Read` or `Write` implementation.
pub trait ByteCount {
    /// Returns the number of bytes that have flowed through this stream.
    fn get_count(&self) -> u64;
}

/// A trait for objects that allow reading symbols of variable bit lengths.
pub trait BitRead {
    /// Reads a single symbol that is `bits` bits wide from the input.
    fn read_bits(&mut self, bits: usize) -> Result<usize>;
}

/// A trait for objects that allow writing symbols of variable bit lengths.
pub trait BitWrite {
    /// Writes a single symbol that is `bits` bits wide to the output.
    fn write_bits(&mut self, symbol: usize, bits: usize) -> Result<()>;
    /// Writes any bits left over after the last whole octet as one final byte,
    /// padded with zero bits.
    fn flush_bits(&mut self) -> Result<()>;
}

/// State shared by the bit reader and the bit writer.
struct BitBuffer {
    /// One-byte staging area; the pending bits live in its low-order positions.
    bytes: [u8; 1usize],
    /// How many bits in `bytes` are still unread (reader) or unwritten (writer).
    bits: usize,
    /// Running total of bytes read from or written to the underlying stream.
    count: u64,
}

impl BitBuffer {
    /// Returns a buffer with no pending bits and a zero byte count.
    fn new() -> Self {
        Self {
            count: 0,
            bits: 0,
            bytes: [0; 1],
        }
    }
}

/// Bit-oriented reader layered over a byte-oriented input stream.
///
/// Each buffer refill reads a single byte from the wrapped stream. For files or
/// sockets, wrap them in a buffered reader (e.g. `std::io::BufReader`) first.
pub struct BitReader<'a> {
    /// Bits of the most recently read byte that have not been returned yet.
    buffer: BitBuffer,
    /// Underlying byte-oriented I/O stream.
    input: &'a mut dyn Read,
}

impl<'a> BitReader<'a> {
    /// Creates a new instance by wrapping a byte input stream.
    ///
    /// Nothing is read from `reader` until bits are requested.
    pub fn new(reader: &'a mut dyn Read) -> BitReader<'a> {
        BitReader {
            buffer: BitBuffer::new(),
            input: reader,
        }
    }
}

impl<'a> ByteCount for BitReader<'a> {
    /// Returns the number of bytes read from the underlying stream so far.
    fn get_count(&self) -> u64 {
        self.buffer.count
    }
}

impl<'a> BitRead for BitReader<'a> {
    /// Reads the next `bits` bits, most significant bit first, and returns them
    /// right-aligned in a `usize`.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if `bits` is wider than `usize`.
    /// - `Eof` if the stream ends before enough bits are available.
    /// - `IoError` for any other I/O failure.
    ///
    /// Bits already taken from the buffer before an error are discarded.
    fn read_bits(&mut self, mut bits: usize) -> Result<usize> {
        if bits > size_of::<usize>() * 8 {
            return Err(InvalidInput);
        }

        let mut result = 0usize;
        while bits > 0 {
            if self.buffer.bits >= bits {
                // Enough bits are buffered: take the upper `bits` of the pending ones
                // (bytes: 000xxxyy -> 00000xxx) and keep the rest (bytes: 000000yy).
                result <<= bits;
                result |= self.buffer.bytes[0] as usize >> (self.buffer.bits - bits);
                self.buffer.bits -= bits;
                // buffer.bits is at most 7 here, so the shift cannot overflow `u8`.
                self.buffer.bytes[0] &= (1 << self.buffer.bits) - 1;
                bits = 0;
            } else if self.buffer.bits > 0 {
                // Not enough bits: drain everything buffered (bytes: 000000yy), then refill.
                result <<= self.buffer.bits;
                result |= self.buffer.bytes[0] as usize;
                bits -= self.buffer.bits;
                self.buffer.bytes[0] = 0;
                self.buffer.bits = 0;
            } else {
                // Buffer is empty: read the next byte from the underlying input stream.
                match self.input.read(&mut self.buffer.bytes) {
                    Ok(0) => return Err(Eof),
                    Ok(_) => {
                        self.buffer.count += 1;
                        self.buffer.bits = 8;
                    }
                    // `Read::read` may be interrupted before any data arrives; the `Read`
                    // contract says such calls should simply be retried.
                    Err(ref e) if e.kind() == ::std::io::ErrorKind::Interrupted => {}
                    Err(e) => return Err(IoError(e)),
                }
            }
        }
        Ok(result)
    }
}

/// Bit-oriented writer layered over a byte-oriented output stream.
///
/// Bits accumulate in a one-byte buffer. Each completed byte is passed to the wrapped
/// stream with its own `write_all` call, so wrap unbuffered sinks in
/// `std::io::BufWriter`. Call `flush_bits` to emit a trailing partial byte.
pub struct BitWriter<'a> {
    /// Bits that have not been written to `output` yet.
    buffer: BitBuffer,
    /// Underlying byte-oriented I/O stream.
    output: &'a mut dyn Write,
}

impl<'a> BitWriter<'a> {
    /// Creates a new instance by wrapping a byte output stream.
    pub fn new(writer: &'a mut dyn Write) -> BitWriter<'a> {
        BitWriter {
            buffer: BitBuffer::new(),
            output: writer,
        }
    }
}

impl<'a> ByteCount for BitWriter<'a> {
    /// Returns the number of bytes written to the underlying stream so far.
    fn get_count(&self) -> u64 {
        self.buffer.count
    }
}

impl<'a> BitWrite for BitWriter<'a> {
    fn write_bits(&mut self, mut symbol: usize, mut bits: usize) -> Result<()> {
        if (bits > size_of::<usize>() * 8) || (symbol >> bits > 0){
            return Err(InvalidInput);
        }

        while bits > 0 {
            if self.buffer.bits + bits <= 8 {
                // Put the upper bits into buffer (symbol: 00000yyy, bytes: 000000xx -> 000xxyyy)
                if self.buffer.bits > 0 {
                    self.buffer.bytes[0] <<= bits;
                }
                self.buffer.bytes[0] |= symbol as u8;
                self.buffer.bits += bits;
                // Update remaining bits to write
                bits = 0;
                symbol = 0;
            } else if self.buffer.bits < 8 {
                let num = 8 - self.buffer.bits;
                // Put the upper bits into buffer (symbol: 000yyyzz -> 00000yyy, bytes: 000xxxxx -> xxxxxyyy)
                if self.buffer.bits > 0 {
                    self.buffer.bytes[0] <<= num;
                }
                self.buffer.bytes[0] |= (symbol >> (bits - num)) as u8;
                self.buffer.bits += num;
                // Update remaining bits to write (symbol: 000yyyzz -> 000000zz)
                bits -= num;
                symbol &= (1 << bits) - 1;
            }
            if self.buffer.bits == 8 {
                try!(self.flush_bits())
            }
        }
        return Ok(())
    }

    fn flush_bits(&mut self) -> Result<()> {
        if self.buffer.bits > 0 {
            self.buffer.bytes[0] <<= 8 - self.buffer.bits;
            match self.output.write_all(&self.buffer.bytes) {
                Ok(_) => {
                    self.buffer.count += 1;
                    self.buffer.bytes[0] = 0;
                    self.buffer.bits = 0;
                },
                Err(e) => {
                    return Err(IoError(e));
                }
            }
        }
        return Ok(())
    }
}