diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index be84a2816..dac825c3d 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -259,6 +259,12 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gom func (s *OutputRoomEventConsumer) onNewInviteEvent( ctx context.Context, msg api.OutputNewInviteEvent, ) error { + if msg.Event.StateKey() == nil { + log.WithFields(log.Fields{ + "event": string(msg.Event.JSON()), + }).Panicf("roomserver output log: invite has no state key") + return nil + } pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) if err != nil { // panic rather than continue with an inconsistent database @@ -269,7 +275,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( }).Panicf("roomserver output log: write invite failure") return nil } - s.notifier.OnNewEvent(msg.Event, "", nil, types.StreamingToken{PDUPosition: pduPos}) + s.notifier.OnNewInvite(types.StreamingToken{PDUPosition: pduPos}, *msg.Event.StateKey()) return nil } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 6fe0f4314..fb79a50fc 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -492,6 +492,7 @@ func (d *Database) syncPositionTx( PDUPosition: types.StreamPosition(maxEventID), TypingPosition: types.StreamPosition(d.EDUCache.GetLatestSyncPosition()), ReceiptPosition: types.StreamPosition(maxReceiptID), + InvitePosition: types.StreamPosition(maxInviteID), } return } @@ -502,6 +503,7 @@ func (d *Database) addPDUDeltaToResponse( ctx context.Context, device userapi.Device, r types.Range, + ir types.Range, numRecentEventsPerRoom int, wantFullState bool, res *types.Response, @@ -544,7 +546,7 @@ func (d *Database) addPDUDeltaToResponse( } // TODO: This should be done in getStateDeltas - if err = d.addInvitesToResponse(ctx, txn, device.UserID, r, res); err != nil { + if err = d.addInvitesToResponse(ctx, txn, device.UserID, ir, res); err != nil { return nil, fmt.Errorf("d.addInvitesToResponse: %w", err) } @@ -702,8 +704,12 @@ func (d *Database) IncrementalSync( From: fromPos.PDUPosition, To: toPos.PDUPosition, } + ir := types.Range{ + From: fromPos.InvitePosition, + To: toPos.InvitePosition, + } joinedRoomIDs, err = d.addPDUDeltaToResponse( - ctx, device, r, numRecentEventsPerRoom, wantFullState, res, + ctx, device, r, ir, numRecentEventsPerRoom, wantFullState, res, ) if err != nil { return nil, fmt.Errorf("d.addPDUDeltaToResponse: %w", err) @@ -784,6 +790,10 @@ func (d *Database) getResponseWithPDUsForCompleteSync( From: 0, To: toPos.PDUPosition, } + ir := types.Range{ + From: 0, + To: toPos.InvitePosition, + } res.NextBatch.ApplyUpdates(toPos) @@ -825,7 +835,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync( } } - if err = d.addInvitesToResponse(ctx, txn, userID, r, res); err != nil { + if err = d.addInvitesToResponse(ctx, txn, userID, ir, res); err != nil { return } diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index 284371073..66460a8db 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -193,6 +193,16 @@ func (n *Notifier) OnNewKeyChange( n.wakeupUsers([]string{wakeUserID}, nil, n.currPos) } +func (n *Notifier) OnNewInvite( + posUpdate types.StreamingToken, wakeUserID string, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + n.currPos.ApplyUpdates(posUpdate) + n.wakeupUsers([]string{wakeUserID}, nil, n.currPos) +} + // GetListener returns a UserStreamListener that can be used to wait for // updates for a user. Must be closed. // notify for anything before sincePos diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index c3637f177..d5be04af2 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -191,7 +191,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. // Wait for notifier to wake us up case <-userStreamListener.GetNotifyChannel(sincePos): currPos = userStreamListener.GetSyncPosition() - sincePos = currPos // Or for timeout to expire case <-timer.C: // We just need to ensure we get out of the select after reaching the diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 1e29a0389..8e5260326 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -113,6 +113,7 @@ type StreamingToken struct { TypingPosition StreamPosition ReceiptPosition StreamPosition SendToDevicePosition StreamPosition + InvitePosition StreamPosition DeviceListPosition LogPosition } @@ -129,9 +130,10 @@ func (s *StreamingToken) UnmarshalText(text []byte) (err error) { func (t StreamingToken) String() string { posStr := fmt.Sprintf( - "s%d_%d_%d_%d", + "s%d_%d_%d_%d_%d", t.PDUPosition, t.TypingPosition, t.ReceiptPosition, t.SendToDevicePosition, + t.InvitePosition, ) if dl := t.DeviceListPosition; !dl.IsEmpty() { posStr += fmt.Sprintf(".dl-%d-%d", dl.Partition, dl.Offset) @@ -150,6 +152,8 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool { return true case t.SendToDevicePosition > other.SendToDevicePosition: return true + case t.InvitePosition > other.InvitePosition: + return true case t.DeviceListPosition.IsAfter(&other.DeviceListPosition): return true } @@ -157,7 +161,7 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool { } func (t *StreamingToken) IsEmpty() bool { - return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition == 0 && t.DeviceListPosition.IsEmpty() + return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition == 0 && t.DeviceListPosition.IsEmpty() } // WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken. @@ -174,16 +178,22 @@ func (t *StreamingToken) WithUpdates(other StreamingToken) StreamingToken { // streaming token contains any positions that are not 0, they are considered updates // and will overwrite the value in the token. func (t *StreamingToken) ApplyUpdates(other StreamingToken) { - switch { - case other.PDUPosition > 0: + if other.PDUPosition > 0 { t.PDUPosition = other.PDUPosition - case other.TypingPosition > 0: + } + if other.TypingPosition > 0 { t.TypingPosition = other.TypingPosition - case other.ReceiptPosition > 0: + } + if other.ReceiptPosition > 0 { t.ReceiptPosition = other.ReceiptPosition - case other.SendToDevicePosition > 0: + } + if other.SendToDevicePosition > 0 { t.SendToDevicePosition = other.SendToDevicePosition - case other.DeviceListPosition.Offset > 0: + } + if other.InvitePosition > 0 { + t.InvitePosition = other.InvitePosition + } + if other.DeviceListPosition.Offset > 0 { t.DeviceListPosition = other.DeviceListPosition } } @@ -276,7 +286,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { } categories := strings.Split(tok[1:], ".") parts := strings.Split(categories[0], "_") - var positions [4]StreamPosition + var positions [5]StreamPosition for i, p := range parts { if i > len(positions) { break @@ -293,6 +303,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { TypingPosition: positions[1], ReceiptPosition: positions[2], SendToDevicePosition: positions[3], + InvitePosition: positions[4], } // dl-0-1234 // $log_name-$partition-$offset diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index ecb0ab6fd..caefcc4f7 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -42,10 +42,10 @@ func TestNewSyncTokenWithLogs(t *testing.T) { func TestSyncTokens(t *testing.T) { shouldPass := map[string]string{ - "s4_0_0_0": StreamingToken{4, 0, 0, 0, LogPosition{}}.String(), - "s3_1_0_0.dl-1-2": StreamingToken{3, 1, 0, 0, LogPosition{1, 2}}.String(), - "s3_1_2_3": StreamingToken{3, 1, 2, 3, LogPosition{}}.String(), - "t3_1": TopologyToken{3, 1}.String(), + "s4_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, LogPosition{}}.String(), + "s3_1_0_0_0.dl-1-2": StreamingToken{3, 1, 0, 0, 0, LogPosition{1, 2}}.String(), + "s3_1_2_3_5": StreamingToken{3, 1, 2, 3, 5, LogPosition{}}.String(), + "t3_1": TopologyToken{3, 1}.String(), } for a, b := range shouldPass {