diff --git a/client-specification b/client-specification index 3d91a8f..207acbe 160000 --- a/client-specification +++ b/client-specification @@ -1 +1 @@ -Subproject commit 3d91a8f18e1679994d308e5dde47954ae2def56d +Subproject commit 207acbe3e62c1a8132baf57ab7330cfb88218081 diff --git a/src/client.rs b/src/client.rs index 07e01e7..ce508cc 100644 --- a/src/client.rs +++ b/src/client.rs @@ -591,12 +591,12 @@ where } let identifier = identifier.unwrap(); let total_weight = feature.variants.iter().map(|v| v.value.weight as u32).sum(); - strategy::normalised_hash(&group, identifier, total_weight) + strategy::normalised_variant_hash(&group, identifier, total_weight) .map(|selected_weight| { let mut counter: u32 = 0; for variant in feature.variants.iter().as_ref() { counter += variant.value.weight as u32; - if counter > selected_weight { + if counter >= selected_weight { variant.count.fetch_add(1, Ordering::Relaxed); return variant.into(); } @@ -1320,7 +1320,7 @@ mod tests { ], enabled: true, }; - assert_eq!(variant2, c.get_variant(UserFeatures::two, &uid1)); + assert_eq!(variant1, c.get_variant(UserFeatures::two, &uid1)); assert_eq!(variant2, c.get_variant(UserFeatures::two, &session1)); assert_eq!(variant1, c.get_variant(UserFeatures::two, &host1)); } @@ -1395,7 +1395,7 @@ mod tests { ], enabled: true, }; - assert_eq!(variant2, c.get_variant_str("two", &uid1)); + assert_eq!(variant1, c.get_variant_str("two", &uid1)); assert_eq!(variant2, c.get_variant_str("two", &session1)); assert_eq!(variant1, c.get_variant_str("two", &host1)); } diff --git a/src/strategy.rs b/src/strategy.rs index b8d82ca..dcbf269 100644 --- a/src/strategy.rs +++ b/src/strategy.rs @@ -107,7 +107,7 @@ pub fn partial_rollout(group: &str, variable: Option<&String>, rollout: u32) -> 100 => true, rollout => { if let Ok(normalised) = normalised_hash(group, variable, 100) { - rollout > normalised + rollout >= normalised } else { false } @@ -119,13 +119,36 @@ pub fn partial_rollout(group: &str, variable: Option<&String>, rollout: u32) -> /// required for extension strategies, but reusing this is probably a good idea /// for consistency across implementations. pub fn normalised_hash(group: &str, identifier: &str, modulus: u32) -> std::io::Result { + normalised_hash_internal(group, identifier, modulus, 0) +} + +const VARIANT_NORMALIZATION_SEED: u32 = 86028157; + +/// Calculates a hash for **variant distribution** in the standard way +/// expected for Unleash clients. This differs from the +/// [`normalised_hash`] function in that it uses a different seed to +/// ensure a fair distribution. +pub fn normalised_variant_hash( + group: &str, + identifier: &str, + modulus: u32, +) -> std::io::Result { + normalised_hash_internal(group, identifier, modulus, VARIANT_NORMALIZATION_SEED) +} + +fn normalised_hash_internal( + group: &str, + identifier: &str, + modulus: u32, + seed: u32, +) -> std::io::Result { // See https://github.com/stusmall/murmur3/pull/16 : .chain may avoid // copying in the general case, and may be faster (though perhaps // benchmarking would be useful - small datasizes here could make the best // path non-obvious) - but until murmur3 is fixed, we need to provide it // with a single string no matter what. let mut reader = Cursor::new(format!("{}:{}", &group, &identifier)); - murmur3_32(&mut reader, 0).map(|hash_result| hash_result % modulus) + murmur3_32(&mut reader, seed).map(|hash_result| hash_result % modulus + 1) } // Build a closure to handle session id rollouts, parameterised by groupId and a @@ -861,4 +884,22 @@ mod tests { fn normalised_hash() { assert!(50 > super::normalised_hash("AB12A", "122", 100).unwrap()); } + + #[test] + fn test_normalized_hash() { + assert_eq!(73, super::normalised_hash("gr1", "123", 100).unwrap()); + assert_eq!(25, super::normalised_hash("groupX", "999", 100).unwrap()); + } + + #[test] + fn test_normalised_variant_hash() { + assert_eq!( + 96, + super::normalised_variant_hash("gr1", "123", 100).unwrap() + ); + assert_eq!( + 60, + super::normalised_variant_hash("groupX", "999", 100).unwrap() + ); + } } diff --git a/tests/clientspec.rs b/tests/clientspec.rs index a70483d..b3ca47b 100644 --- a/tests/clientspec.rs +++ b/tests/clientspec.rs @@ -38,6 +38,21 @@ mod tests { enabled: bool, } + impl PartialEq for VariantResult { + fn eq(&self, other: &client::Variant) -> bool { + let payload_matches = match &self._payload { + Some(payload) => match (other.payload.get("type"), other.payload.get("value")) { + (Some(_type), Some(value)) => { + &payload._type == _type && &payload._value == value + } + _ => false, + }, + None => other.payload.get("type").is_none() && other.payload.get("value").is_none(), + }; + self.enabled == other.enabled && self._name == other.name && payload_matches + } + } + #[derive(Debug, Deserialize)] struct VariantTest { description: String, @@ -134,11 +149,11 @@ mod tests { } Tests::VariantTests { variant_tests } => { for test in variant_tests { - let result = - c.is_enabled_str(&test.toggle_name, Some(&test.context), false); + let result = c.get_variant_str(&test.toggle_name, &test.context); + assert_eq!( - test.expected_result.enabled, result, - "Test '{}' failed: got {} instead of {:?}", + test.expected_result, result, + "Test '{}' failed: got {:?} instead of {:?}", test.description, result, test.expected_result ); }