Skip to content

Commit

Permalink
Infer output length (#18)
Browse files Browse the repository at this point in the history
- Adds a `decode!` macro that makes it possible to omit the output length.
- Makes panic messages more informative by including context
information.
  • Loading branch information
slowli authored Sep 26, 2023
2 parents 0b21bac + 3b0f2b4 commit 989c0a4
Show file tree
Hide file tree
Showing 10 changed files with 358 additions and 52 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ The project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html)

## [Unreleased]

### Added

- Add a `decode!` macro that makes it possible to omit the output length.
- Make panic messages more informative by including context information.

### Changed

- Bump MSRV to 1.66.
Expand Down
6 changes: 6 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ categories = ["encoding", "no-std"]
description = "Constant functions for converting hex- and base64-encoded strings into bytes"
repository = "https://github.com/slowli/const-decoder"

[dependencies.compile-fmt]
git = "https://github.com/slowli/compile-fmt.git"
version = "0.1.0"
rev = "0f0c1ab3e90c854beede0ceaa87ce0d96e209a7f"

[dev-dependencies]
base64 = "0.21.0"
bech32 = "0.9.0"
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ const SECRET_KEY: [u8; 64] = Decoder::Hex.decode(
b"9e55d1e1aa1f455b8baad9fdf975503655f8b359d542fa7e4ce84106d625b352\
06fac1f22240cffd637ead6647188429fafda9c9cb7eae43386ac17f61115075",
);
// Alternatively, you can use the `decode!` macro:
const PUBLIC_KEY: &[u8] = &const_decoder::decode!(
Decoder::Hex,
b"06fac1f22240cffd637ead6647188429fafda9c9cb7eae43386ac17f61115075",
);
```

[Bech32] encoding:
Expand Down
4 changes: 4 additions & 0 deletions deny.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,7 @@ allow-wildcard-paths = true
[sources]
unknown-registry = "deny"
unknown-git = "deny"
allow-git = [
# Temporarily allow relying on an unpublished version of `compile-fmt`
"https://github.com/slowli/compile-fmt.git",
]
170 changes: 132 additions & 38 deletions src/decoder.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,57 @@
//! `Decoder` and closely related types.
use compile_fmt::{clip_ascii, compile_assert, compile_panic, fmt, Ascii};

use crate::wrappers::{SkipWhitespace, Skipper};

// The `?` operator is not available inside `const fn`s, so this macro provides a cut-down
// replacement: it evaluates to the `Ok` payload, or returns the `Err` value unchanged
// from the enclosing function (no error conversion is performed).
macro_rules! const_try {
    ($fallible:expr) => {
        match $fallible {
            Err(error) => return Err(error),
            Ok(unwrapped) => unwrapped,
        }
    };
}

// Internal error describing an input byte that a decoder could not process.
// It is turned into a panic message by `DecodeError::panic()`.
#[derive(Debug)]
struct DecodeError {
// The offending input byte (may be non-ASCII).
invalid_char: u8,
// Alphabet mentioned in the panic message; `None` for hex encoding
alphabet: Option<Ascii<'static>>,
}

impl DecodeError {
const fn invalid_char(invalid_char: u8, alphabet: Option<Ascii<'static>>) -> Self {
Self {
invalid_char,
alphabet,
}
}

const fn panic(self, input_pos: usize) -> ! {
if self.invalid_char.is_ascii() {
if let Some(alphabet) = self.alphabet {
compile_panic!(
"Character '", self.invalid_char as char => fmt::<char>(), "' at position ",
input_pos => fmt::<usize>(), " is not a part of \
the decoder alphabet '", alphabet => clip_ascii(64, ""), "'"
);
} else {
compile_panic!(
"Character '", self.invalid_char as char => fmt::<char>(), "' at position ",
input_pos => fmt::<usize>(), " is not a hex digit"
);
}
} else {
compile_panic!(
"Non-ASCII character with decimal code ", self.invalid_char => fmt::<u8>(),
" encountered at position ", input_pos => fmt::<usize>()
);
}
}
}

/// Custom encoding scheme based on a certain alphabet (mapping between a subset of ASCII chars
/// and digits in `0..P`, where `P` is a power of 2).
///
Expand All @@ -20,6 +70,7 @@ use crate::wrappers::{SkipWhitespace, Skipper};
/// ```
#[derive(Debug, Clone, Copy)]
pub struct Encoding {
// Alphabet this encoding was created from; passed along with decoding errors
// so that it can be shown in panic messages.
alphabet: Ascii<'static>,
// Lookup table indexed by ASCII code: the digit assigned to the character
// (its index in the alphabet), or `NO_MAPPING` if the character is not in the alphabet.
table: [u8; 128],
// Number of bits encoded by a single alphabet character (1 to 6, depending on alphabet length).
bits_per_char: u8,
}
Expand All @@ -40,45 +91,52 @@ impl Encoding {
/// - Panics if `alphabet` does not consist of distinct ASCII chars.
/// - Panics if `alphabet` length is not 2, 4, 8, 16, 32 or 64.
#[allow(clippy::cast_possible_truncation)]
pub const fn new(alphabet: &str) -> Self {
pub const fn new(alphabet: &'static str) -> Self {
let bits_per_char = match alphabet.len() {
2 => 1,
4 => 2,
8 => 3,
16 => 4,
32 => 5,
64 => 6,
_ => panic!("Invalid alphabet length; must be one of 2, 4, 8, 16, 32, or 64"),
other => compile_panic!(
"Invalid alphabet length ", other => fmt::<usize>(),
"; must be one of 2, 4, 8, 16, 32, or 64"
),
};

let mut table = [Self::NO_MAPPING; 128];
let alphabet_bytes = alphabet.as_bytes();
let alphabet = Ascii::new(alphabet); // will panic if `alphabet` contains non-ASCII chars
let mut index = 0;
while index < alphabet_bytes.len() {
let byte = alphabet_bytes[index];
assert!(byte < 0x80, "Non-ASCII alphabet character");
let byte_idx = byte as usize;
assert!(
compile_assert!(
table[byte_idx] == Self::NO_MAPPING,
"Alphabet character is mentioned several times"
"Alphabet character '", byte as char => fmt::<char>(), "' is mentioned several times"
);
table[byte_idx] = index as u8;
index += 1;
}

Self {
alphabet,
table,
bits_per_char,
}
}

const fn lookup(&self, ascii_char: u8) -> u8 {
const fn lookup(&self, ascii_char: u8) -> Result<u8, DecodeError> {
if !ascii_char.is_ascii() {
return Err(DecodeError::invalid_char(ascii_char, Some(self.alphabet)));
}
let mapping = self.table[ascii_char as usize];
assert!(
mapping != Self::NO_MAPPING,
"Character is not present in the alphabet"
);
mapping
if mapping == Self::NO_MAPPING {
Err(DecodeError::invalid_char(ascii_char, Some(self.alphabet)))
} else {
Ok(mapping)
}
}
}

Expand All @@ -87,30 +145,30 @@ impl Encoding {
struct HexDecoderState(Option<u8>);

impl HexDecoderState {
const fn byte_value(val: u8) -> u8 {
match val {
const fn byte_value(val: u8) -> Result<u8, DecodeError> {
Ok(match val {
b'0'..=b'9' => val - b'0',
b'A'..=b'F' => val - b'A' + 10,
b'a'..=b'f' => val - b'a' + 10,
_ => panic!("Invalid character in input; expected a hex digit"),
}
_ => return Err(DecodeError::invalid_char(val, None)),
})
}

// Creates the initial state, in which no pending hex digit is stored.
const fn new() -> Self {
Self(None)
}

#[allow(clippy::option_if_let_else)] // `Option::map_or_else` cannot be used in const fns
const fn update(mut self, byte: u8) -> (Self, Option<u8>) {
let byte = Self::byte_value(byte);
const fn update(mut self, byte: u8) -> Result<(Self, Option<u8>), DecodeError> {
let byte = const_try!(Self::byte_value(byte));
let output = if let Some(b) = self.0 {
self.0 = None;
Some((b << 4) + byte)
} else {
self.0 = Some(byte);
None
};
(self, output)
Ok((self, output))
}

const fn is_final(self) -> bool {
Expand All @@ -136,8 +194,8 @@ impl CustomDecoderState {
}

#[allow(clippy::comparison_chain)] // not feasible in const context
const fn update(mut self, byte: u8) -> (Self, Option<u8>) {
let byte = self.table.lookup(byte);
const fn update(mut self, byte: u8) -> Result<(Self, Option<u8>), DecodeError> {
let byte = const_try!(self.table.lookup(byte));
let output = if self.filled_bits < 8 - self.table.bits_per_char {
self.partial_byte = (self.partial_byte << self.table.bits_per_char) + byte;
self.filled_bits += self.table.bits_per_char;
Expand All @@ -155,7 +213,7 @@ impl CustomDecoderState {
self.filled_bits = new_filled_bits;
Some(output)
};
(self, output)
Ok((self, output))
}

const fn is_final(&self) -> bool {
Expand All @@ -173,25 +231,25 @@ enum DecoderState {
}

impl DecoderState {
const fn update(self, byte: u8) -> (Self, Option<u8>) {
match self {
const fn update(self, byte: u8) -> Result<(Self, Option<u8>), DecodeError> {
Ok(match self {
Self::Hex(state) => {
let (updated_state, output) = state.update(byte);
let (updated_state, output) = const_try!(state.update(byte));
(Self::Hex(updated_state), output)
}
Self::Base64(state) => {
if byte == b'=' {
(self, None)
} else {
let (updated_state, output) = state.update(byte);
let (updated_state, output) = const_try!(state.update(byte));
(Self::Base64(updated_state), output)
}
}
Self::Custom(state) => {
let (updated_state, output) = state.update(byte);
let (updated_state, output) = const_try!(state.update(byte));
(Self::Custom(updated_state), output)
}
}
})
}

const fn is_final(&self) -> bool {
Expand Down Expand Up @@ -232,7 +290,7 @@ impl Decoder {
/// # Panics
///
/// Panics in the same situations as [`Encoding::new()`].
pub const fn custom(alphabet: &str) -> Self {
pub const fn custom(alphabet: &'static str) -> Self {
Self::Custom(Encoding::new(alphabet))
}

Expand Down Expand Up @@ -279,29 +337,65 @@ impl Decoder {
}
}

let update = state.update(input[in_index]);
let update = match state.update(input[in_index]) {
Ok(update) => update,
Err(err) => err.panic(in_index),
};
state = update.0;
if let Some(byte) = update.1 {
assert!(
out_index < N,
"Output overflow: the input decodes to more bytes than specified \
as the output length"
);
bytes[out_index] = byte;
if out_index < N {
bytes[out_index] = byte;
}
out_index += 1;
}
in_index += 1;
}
assert!(

compile_assert!(
out_index <= N,
"Output overflow: the input decodes to ", out_index => fmt::<usize>(),
" bytes, while type inference implies ", N => fmt::<usize>(), ". \
Either fix the input or change the output buffer length correspondingly"
);
compile_assert!(
out_index == N,
"Output underflow: the input was decoded into less bytes than specified \
as the output length"
"Output underflow: the input decodes to ", out_index => fmt::<usize>(),
" bytes, while type inference implies ", N => fmt::<usize>(), ". \
Either fix the input or change the output buffer length correspondingly"
);

assert!(
state.is_final(),
"Left-over state after processing input. This usually means that the input \
is incorrect (e.g., an odd number of hex digits)."
);
bytes
}

// Counts the bytes that `input` decodes to, without storing them anywhere.
//
// If `skipper` is provided, it is consulted before each input byte and may advance
// the position past parts of the input. Panics (via `DecodeError::panic()`) on the first
// byte that the decoder state rejects. Left-over decoder state is not checked here.
pub(crate) const fn do_decode_len(self, input: &[u8], skipper: Option<Skipper>) -> usize {
    let mut state = self.new_state();
    let mut decoded_len = 0;
    let mut pos = 0;

    while pos < input.len() {
        // Give the skipper a chance to move past the current position first.
        if let Some(skipper) = skipper {
            let next_pos = skipper.skip(input, pos);
            if next_pos != pos {
                pos = next_pos;
                continue;
            }
        }

        let (next_state, output) = match state.update(input[pos]) {
            Ok(update) => update,
            Err(err) => err.panic(pos),
        };
        state = next_state;
        // Only the fact that a byte was produced matters; its value is discarded.
        if output.is_some() {
            decoded_len += 1;
        }
        pos += 1;
    }

    decoded_len
}
}
Loading

0 comments on commit 989c0a4

Please sign in to comment.