Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Implementation for MSC3664: Pushrules for relations #11804

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11804.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement [MSC3664](https://github.com/matrix-org/matrix-doc/pull/3664). Contributed by Nico.
17 changes: 17 additions & 0 deletions rust/src/push/base_rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::push::Action;
use crate::push::Condition;
use crate::push::EventMatchCondition;
use crate::push::PushRule;
use crate::push::RelatedEventMatchCondition;
use crate::push::SetTweak;
use crate::push::TweakValue;

Expand Down Expand Up @@ -114,6 +115,22 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/override/.im.nheko.msc3664.reply"),
priority_class: 5,
conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelatedEventMatch(
RelatedEventMatchCondition {
key: Some(Cow::Borrowed("sender")),
pattern: None,
pattern_type: Some(Cow::Borrowed("user_id")),
rel_type: Cow::Borrowed("m.in_reply_to"),
include_fallbacks: None,
},
))]),
actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/override/.m.rule.contains_display_name"),
priority_class: 5,
Expand Down
99 changes: 97 additions & 2 deletions rust/src/push/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use regex::Regex;
use super::{
utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType},
Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition,
RelatedEventMatchCondition,
};

lazy_static! {
Expand All @@ -49,6 +50,13 @@ pub struct PushRuleEvaluator {
/// The power level of the sender of the event, or None if event is an
/// outlier.
sender_power_level: Option<i64>,

/// The related events, indexed by relation type. Flattened in the same manner as
/// `flattened_keys`.
related_events_flattened: BTreeMap<String, BTreeMap<String, String>>,

/// If msc3664, push rules for related events, is enabled.
related_event_match_enabled: bool,
}

#[pymethods]
Expand All @@ -60,6 +68,8 @@ impl PushRuleEvaluator {
room_member_count: u64,
sender_power_level: Option<i64>,
notification_power_levels: BTreeMap<String, i64>,
related_events_flattened: BTreeMap<String, BTreeMap<String, String>>,
related_event_match_enabled: bool,
) -> Result<Self, Error> {
let body = flattened_keys
.get("content.body")
Expand All @@ -72,6 +82,8 @@ impl PushRuleEvaluator {
room_member_count,
notification_power_levels,
sender_power_level,
related_events_flattened,
related_event_match_enabled,
})
}

Expand Down Expand Up @@ -156,6 +168,9 @@ impl PushRuleEvaluator {
KnownCondition::EventMatch(event_match) => {
self.match_event_match(event_match, user_id)?
}
KnownCondition::RelatedEventMatch(event_match) => {
self.match_related_event_match(event_match, user_id)?
}
KnownCondition::ContainsDisplayName => {
if let Some(dn) = display_name {
if !dn.is_empty() {
Expand Down Expand Up @@ -239,6 +254,79 @@ impl PushRuleEvaluator {
compiled_pattern.is_match(haystack)
}

/// Evaluates a `related_event_match` condition. (MSC3664)
fn match_related_event_match(
&self,
event_match: &RelatedEventMatchCondition,
user_id: Option<&str>,
) -> Result<bool, Error> {
// First check if related event matching is enabled...
if !self.related_event_match_enabled {
return Ok(false);
}

// get the related event, fail if there is none.
let event = if let Some(event) = self.related_events_flattened.get(&*event_match.rel_type) {
event
} else {
return Ok(false);
};

// If we are not matching fallbacks, don't match if our special key indicating this is a
// fallback relation is not present.
if !event_match.include_fallbacks.unwrap_or(false)
&& event.contains_key("im.vector.is_falling_back")
{
return Ok(false);
}

// if we have no key, accept the event as matching, if it existed without matching any
// fields.
let key = if let Some(key) = &event_match.key {
key
} else {
return Ok(true);
};

let pattern = if let Some(pattern) = &event_match.pattern {
pattern
} else if let Some(pattern_type) = &event_match.pattern_type {
// The `pattern_type` can either be "user_id" or "user_localpart",
// either way if we don't have a `user_id` then the condition can't
// match.
let user_id = if let Some(user_id) = user_id {
user_id
} else {
return Ok(false);
};

match &**pattern_type {
"user_id" => user_id,
"user_localpart" => get_localpart_from_id(user_id)?,
_ => return Ok(false),
}
} else {
return Ok(false);
};

let haystack = if let Some(haystack) = event.get(&**key) {
haystack
} else {
return Ok(false);
};

// For the content.body we match against "words", but for everything
// else we match against the entire value.
let match_type = if key == "content.body" {
GlobMatchType::Word
} else {
GlobMatchType::Whole
};

let mut compiled_pattern = get_glob_matcher(pattern, match_type)?;
compiled_pattern.is_match(haystack)
}

/// Match the member count against an 'is' condition
/// The `is` condition can be things like '>2', '==3' or even just '4'.
fn match_member_count(&self, is: &str) -> Result<bool, Error> {
Expand Down Expand Up @@ -267,8 +355,15 @@ impl PushRuleEvaluator {
fn push_rule_evaluator() {
let mut flattened_keys = BTreeMap::new();
flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
let evaluator =
PushRuleEvaluator::py_new(flattened_keys, 10, Some(0), BTreeMap::new()).unwrap();
let evaluator = PushRuleEvaluator::py_new(
flattened_keys,
10,
Some(0),
BTreeMap::new(),
BTreeMap::new(),
true,
)
.unwrap();

let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"));
assert_eq!(result.len(), 3);
Expand Down
61 changes: 53 additions & 8 deletions rust/src/push/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ pub enum Condition {
#[serde(tag = "kind")]
pub enum KnownCondition {
EventMatch(EventMatchCondition),
#[serde(rename = "im.nheko.msc3664.related_event_match")]
RelatedEventMatch(RelatedEventMatchCondition),
ContainsDisplayName,
RoomMemberCount {
#[serde(skip_serializing_if = "Option::is_none")]
Expand Down Expand Up @@ -299,6 +301,20 @@ pub struct EventMatchCondition {
pub pattern_type: Option<Cow<'static, str>>,
}

/// The body of a [`Condition::RelatedEventMatch`]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct RelatedEventMatchCondition {
#[serde(skip_serializing_if = "Option::is_none")]
pub key: Option<Cow<'static, str>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub pattern: Option<Cow<'static, str>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub pattern_type: Option<Cow<'static, str>>,
pub rel_type: Cow<'static, str>,
#[serde(skip_serializing_if = "Option::is_none")]
pub include_fallbacks: Option<bool>,
}
deepbluev7 marked this conversation as resolved.
Show resolved Hide resolved

/// The collection of push rules for a user.
#[derive(Debug, Clone, Default)]
#[pyclass(frozen)]
Expand Down Expand Up @@ -391,15 +407,21 @@ impl PushRules {
pub struct FilteredPushRules {
push_rules: PushRules,
enabled_map: BTreeMap<String, bool>,
msc3664_enabled: bool,
}

#[pymethods]
impl FilteredPushRules {
#[new]
pub fn py_new(push_rules: PushRules, enabled_map: BTreeMap<String, bool>) -> Self {
pub fn py_new(
push_rules: PushRules,
enabled_map: BTreeMap<String, bool>,
msc3664_enabled: bool,
) -> Self {
Self {
push_rules,
enabled_map,
msc3664_enabled,
}
}

Expand All @@ -414,13 +436,25 @@ impl FilteredPushRules {
/// Iterates over all the rules and their enabled state, including base
/// rules, in the order they should be executed in.
fn iter(&self) -> impl Iterator<Item = (&PushRule, bool)> {
self.push_rules.iter().map(|r| {
let enabled = *self
.enabled_map
.get(&*r.rule_id)
.unwrap_or(&r.default_enabled);
(r, enabled)
})
self.push_rules
.iter()
.filter(|rule| {
// Ignore disabled experimental push rules
if !self.msc3664_enabled
&& rule.rule_id == "global/override/.im.nheko.msc3664.reply"
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
{
return false;
}

true
})
.map(|r| {
let enabled = *self
.enabled_map
.get(&*r.rule_id)
.unwrap_or(&r.default_enabled);
(r, enabled)
})
}
}

Expand All @@ -446,6 +480,17 @@ fn test_deserialize_condition() {
let _: Condition = serde_json::from_str(json).unwrap();
}

#[test]
fn test_deserialize_unstable_msc3664_condition() {
let json = r#"{"kind":"im.nheko.msc3664.related_event_match","key":"content.body","pattern":"coffee","rel_type":"m.in_reply_to"}"#;

let condition: Condition = serde_json::from_str(json).unwrap();
assert!(matches!(
condition,
Condition::Known(KnownCondition::RelatedEventMatch(_))
));
}

#[test]
fn test_deserialize_custom_condition() {
let json = r#"{"kind":"custom_tag"}"#;
Expand Down
6 changes: 5 additions & 1 deletion stubs/synapse/synapse_rust/push.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ class PushRules:
def rules(self) -> Collection[PushRule]: ...

class FilteredPushRules:
def __init__(self, push_rules: PushRules, enabled_map: Dict[str, bool]): ...
def __init__(
self, push_rules: PushRules, enabled_map: Dict[str, bool], msc3664_enabled: bool
): ...
def rules(self) -> Collection[Tuple[PushRule, bool]]: ...

def get_base_rule_ids() -> Collection[str]: ...
Expand All @@ -37,6 +39,8 @@ class PushRuleEvaluator:
room_member_count: int,
sender_power_level: Optional[int],
notification_power_levels: Mapping[str, int],
related_events_flattened: Mapping[str, Mapping[str, str]],
related_event_match_enabled: bool,
): ...
def run(
self,
Expand Down
3 changes: 3 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# MSC3773: Thread notifications
self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)

# MSC3664: Pushrules to match on related events
self.msc3664_enabled: bool = experimental.get("msc3664_enabled", False)

# MSC3848: Introduce errcodes for specific event sending failures
self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)

Expand Down
49 changes: 48 additions & 1 deletion synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@

logger = logging.getLogger(__name__)


push_rules_invalidation_counter = Counter(
"synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
)
Expand Down Expand Up @@ -107,6 +106,8 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self._event_auth_handler = hs.get_event_auth_handler()

self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled

self.room_push_rule_cache_metrics = register_cache(
"cache",
"room_push_rule_cache",
Expand Down Expand Up @@ -218,6 +219,48 @@ async def _get_power_levels_and_sender_level(

return pl_event.content if pl_event else {}, sender_level

async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]:
"""Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation

Returns:
Mapping of relation type to flattened events.
"""
related_events: Dict[str, Dict[str, str]] = {}
if self._related_event_match_enabled:
related_event_id = event.content.get("m.relates_to", {}).get("event_id")
relation_type = event.content.get("m.relates_to", {}).get("rel_type")
if related_event_id is not None and relation_type is not None:
related_event = await self.store.get_event(
related_event_id, allow_none=True
)
if related_event is not None:
related_events[relation_type] = _flatten_dict(related_event)

reply_event_id = (
event.content.get("m.relates_to", {})
.get("m.in_reply_to", {})
.get("event_id")
)

# convert replies to pseudo relations
if reply_event_id is not None:
related_event = await self.store.get_event(
reply_event_id, allow_none=True
)

if related_event is not None:
related_events["m.in_reply_to"] = _flatten_dict(related_event)
deepbluev7 marked this conversation as resolved.
Show resolved Hide resolved

# indicate that this is from a fallback relation.
if relation_type == "m.thread" and event.content.get(
"m.relates_to", {}
).get("is_falling_back", False):
related_events["m.in_reply_to"][
"im.vector.is_falling_back"
] = ""

return related_events

async def action_for_events_by_user(
self, events_and_context: List[Tuple[EventBase, EventContext]]
) -> None:
Expand Down Expand Up @@ -286,6 +329,8 @@ async def _action_for_event_by_user(
# the parent is part of a thread.
thread_id = await self.store.get_thread_id(relation.parent_id)

related_events = await self._related_events(event)

# It's possible that old room versions have non-integer power levels (floats or
# strings). Workaround this by explicitly converting to int.
notification_levels = power_levels.get("notifications", {})
Expand All @@ -298,6 +343,8 @@ async def _action_for_event_by_user(
room_member_count,
sender_power_level,
notification_levels,
related_events,
self._related_event_match_enabled,
)

users = rules_by_user.keys()
Expand Down
Loading