Merge branch 'main' into neilalexander/purgeroom

This commit is contained in:
Neil Alexander 2022-09-27 14:14:19 +01:00
commit abd6a6425a
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
56 changed files with 550 additions and 670 deletions

View file

@ -2,7 +2,7 @@
<!-- Please read docs/CONTRIBUTING.md before submitting your pull request --> <!-- Please read docs/CONTRIBUTING.md before submitting your pull request -->
* [ ] I have added added tests for PR _or_ I have justified why this PR doesn't need tests. * [ ] I have added tests for PR _or_ I have justified why this PR doesn't need tests.
* [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md#sign-off) * [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md#sign-off)
Signed-off-by: `Your Name <your@email.example.org>` Signed-off-by: `Your Name <your@email.example.org>`

View file

@ -137,3 +137,63 @@ jobs:
${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }} ${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }}
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:latest ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:latest
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }} ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }}
demo-pinecone:
name: Pinecone demo image
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Get release tag
if: github.event_name == 'release' # Only for GitHub releases
run: echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
- name: Set up QEMU
uses: docker/setup-qemu-action@v1
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Login to Docker Hub
uses: docker/login-action@v1
with:
username: ${{ env.DOCKER_HUB_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to GitHub Containers
uses: docker/login-action@v1
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build main pinecone demo image
if: github.ref_name == 'main'
id: docker_build_demo_pinecone
uses: docker/build-push-action@v2
with:
cache-from: type=gha
cache-to: type=gha,mode=max
context: .
file: ./build/docker/Dockerfile.demo-pinecone
platforms: ${{ env.PLATFORMS }}
push: true
tags: |
${{ env.DOCKER_NAMESPACE }}/dendrite-demo-pinecone:${{ github.ref_name }}
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-demo-pinecone:${{ github.ref_name }}
- name: Build release pinecone demo image
if: github.event_name == 'release' # Only for GitHub releases
id: docker_build_demo_pinecone_release
uses: docker/build-push-action@v2
with:
cache-from: type=gha
cache-to: type=gha,mode=max
context: .
file: ./build/docker/Dockerfile.demo-pinecone
platforms: ${{ env.PLATFORMS }}
push: true
tags: |
${{ env.DOCKER_NAMESPACE }}/dendrite-demo-pinecone:latest
${{ env.DOCKER_NAMESPACE }}/dendrite-demo-pinecone:${{ env.RELEASE_VERSION }}
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-demo-pinecone:latest
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-demo-pinecone:${{ env.RELEASE_VERSION }}

View file

@ -0,0 +1,25 @@
FROM docker.io/golang:1.19-alpine AS base
RUN apk --update --no-cache add bash build-base
WORKDIR /build
COPY . /build
RUN mkdir -p bin
RUN go build -trimpath -o bin/ ./cmd/dendrite-demo-pinecone
RUN go build -trimpath -o bin/ ./cmd/create-account
RUN go build -trimpath -o bin/ ./cmd/generate-keys
FROM alpine:latest
LABEL org.opencontainers.image.title="Dendrite (Pinecone demo)"
LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go"
LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite"
LABEL org.opencontainers.image.licenses="Apache-2.0"
COPY --from=base /build/bin/* /usr/bin/
VOLUME /etc/dendrite
WORKDIR /etc/dendrite
ENTRYPOINT ["/usr/bin/dendrite-demo-pinecone"]

View file

@ -1,4 +1,4 @@
FROM docker.io/golang:1.18-alpine AS base FROM docker.io/golang:1.19-alpine AS base
RUN apk --update --no-cache add bash build-base RUN apk --update --no-cache add bash build-base

View file

@ -1,4 +1,4 @@
FROM docker.io/golang:1.18-alpine AS base FROM docker.io/golang:1.19-alpine AS base
RUN apk --update --no-cache add bash build-base RUN apk --update --no-cache add bash build-base

View file

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
cd $(git rev-parse --show-toplevel) cd $(git rev-parse --show-toplevel)

View file

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
TAG=${1:-latest} TAG=${1:-latest}

View file

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
TAG=${1:-latest} TAG=${1:-latest}

View file

@ -16,6 +16,9 @@ Dendrite can automatically populate the database with the relevant tables and in
it is not capable of creating the databases themselves. You will need to create the databases it is not capable of creating the databases themselves. You will need to create the databases
manually. manually.
The databases **must** be created with UTF-8 encoding configured or you will likely run into problems
with your Dendrite deployment.
At this point, you can choose to either use a single database for all Dendrite components, At this point, you can choose to either use a single database for all Dendrite components,
or you can run each component with its own separate database: or you can run each component with its own separate database:
@ -65,7 +68,7 @@ sudo -u postgres createuser -P dendrite
Create the database itself, using the `dendrite` role from above: Create the database itself, using the `dendrite` role from above:
```bash ```bash
sudo -u postgres createdb -O dendrite dendrite sudo -u postgres createdb -O dendrite -E UTF-8 dendrite
``` ```
### Multiple database creation ### Multiple database creation
@ -85,7 +88,7 @@ The following eight components require a database. In this example they will be
```bash ```bash
for i in appservice federationapi mediaapi mscs roomserver syncapi keyserver userapi; do for i in appservice federationapi mediaapi mscs roomserver syncapi keyserver userapi; do
sudo -u postgres createdb -O dendrite dendrite_$i sudo -u postgres createdb -O dendrite -E UTF-8 dendrite_$i
done done
``` ```

View file

@ -217,7 +217,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
var remoteEvent *gomatrixserverlib.Event var remoteEvent *gomatrixserverlib.Event
remoteEvent, err = respSendJoin.Event.UntrustedEvent(respMakeJoin.RoomVersion) remoteEvent, err = respSendJoin.Event.UntrustedEvent(respMakeJoin.RoomVersion)
if err == nil && isWellFormedMembershipEvent( if err == nil && isWellFormedMembershipEvent(
remoteEvent, roomID, userID, r.cfg.Matrix.ServerName, remoteEvent, roomID, userID,
) { ) {
event = remoteEvent event = remoteEvent
} }
@ -285,7 +285,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
// isWellFormedMembershipEvent returns true if the event looks like a legitimate // isWellFormedMembershipEvent returns true if the event looks like a legitimate
// membership event. // membership event.
func isWellFormedMembershipEvent(event *gomatrixserverlib.Event, roomID, userID string, origin gomatrixserverlib.ServerName) bool { func isWellFormedMembershipEvent(event *gomatrixserverlib.Event, roomID, userID string) bool {
if membership, err := event.Membership(); err != nil { if membership, err := event.Membership(); err != nil {
return false return false
} else if membership != gomatrixserverlib.Join { } else if membership != gomatrixserverlib.Join {
@ -294,9 +294,6 @@ func isWellFormedMembershipEvent(event *gomatrixserverlib.Event, roomID, userID
if event.RoomID() != roomID { if event.RoomID() != roomID {
return false return false
} }
if event.Origin() != origin {
return false
}
if !event.StateKeyEquals(userID) { if !event.StateKeyEquals(userID) {
return false return false
} }

View file

@ -148,8 +148,15 @@ func processInvite(
JSON: jsonerror.BadJSON("The event JSON could not be redacted"), JSON: jsonerror.BadJSON("The event JSON could not be redacted"),
} }
} }
_, serverName, err := gomatrixserverlib.SplitID('@', event.Sender())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The event JSON contains an invalid sender"),
}
}
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
ServerName: event.Origin(), ServerName: serverName,
Message: redacted, Message: redacted,
AtTS: event.OriginServerTS(), AtTS: event.OriginServerTS(),
StrictValidityChecking: true, StrictValidityChecking: true,

View file

@ -203,14 +203,6 @@ func SendJoin(
} }
} }
// Check that the event is from the server sending the request.
if event.Origin() != request.Origin() {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The join must be sent by the server it originated on"),
}
}
// Check that a state key is provided. // Check that a state key is provided.
if event.StateKey() == nil || event.StateKeyEquals("") { if event.StateKey() == nil || event.StateKeyEquals("") {
return util.JSONResponse{ return util.JSONResponse{
@ -228,16 +220,16 @@ func SendJoin(
// Check that the sender belongs to the server that is sending us // Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender // the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both. // and the state key are equal so we don't need to check both.
var domain gomatrixserverlib.ServerName var serverName gomatrixserverlib.ServerName
if _, domain, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The sender of the join is invalid"), JSON: jsonerror.Forbidden("The sender of the join is invalid"),
} }
} else if domain != request.Origin() { } else if serverName != request.Origin() {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The sender of the join must belong to the origin server"), JSON: jsonerror.Forbidden("The sender does not match the server that originated the request"),
} }
} }
@ -292,7 +284,7 @@ func SendJoin(
} }
} }
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
ServerName: event.Origin(), ServerName: serverName,
Message: redacted, Message: redacted,
AtTS: event.OriginServerTS(), AtTS: event.OriginServerTS(),
StrictValidityChecking: true, StrictValidityChecking: true,

View file

@ -118,6 +118,7 @@ func MakeLeave(
} }
// SendLeave implements the /send_leave API // SendLeave implements the /send_leave API
// nolint:gocyclo
func SendLeave( func SendLeave(
httpReq *http.Request, httpReq *http.Request,
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,
@ -167,14 +168,6 @@ func SendLeave(
} }
} }
// Check that the event is from the server sending the request.
if event.Origin() != request.Origin() {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The leave must be sent by the server it originated on"),
}
}
if event.StateKey() == nil || event.StateKeyEquals("") { if event.StateKey() == nil || event.StateKeyEquals("") {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -188,6 +181,22 @@ func SendLeave(
} }
} }
// Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both.
var serverName gomatrixserverlib.ServerName
if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The sender of the join is invalid"),
}
} else if serverName != request.Origin() {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The sender does not match the server that originated the request"),
}
}
// Check if the user has already left. If so, no-op! // Check if the user has already left. If so, no-op!
queryReq := &api.QueryLatestEventsAndStateRequest{ queryReq := &api.QueryLatestEventsAndStateRequest{
RoomID: roomID, RoomID: roomID,
@ -240,7 +249,7 @@ func SendLeave(
} }
} }
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
ServerName: event.Origin(), ServerName: serverName,
Message: redacted, Message: redacted,
AtTS: event.OriginServerTS(), AtTS: event.OriginServerTS(),
StrictValidityChecking: true, StrictValidityChecking: true,

4
go.mod
View file

@ -22,8 +22,8 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220923115829-2217f6c65ce3 github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5
github.com/matrix-org/pinecone v0.0.0-20220923151905-0900fceecb89 github.com/matrix-org/pinecone v0.0.0-20220927101513-d0beb180f44d
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.15 github.com/mattn/go-sqlite3 v1.14.15
github.com/nats-io/nats-server/v2 v2.9.1-0.20220920152220-52d7b481c4b5 github.com/nats-io/nats-server/v2 v2.9.1-0.20220920152220-52d7b481c4b5

8
go.sum
View file

@ -384,10 +384,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220923115829-2217f6c65ce3 h1:u3FKZmXxfhv3XhD8RziBlt96QTt8eHFhg1upCloBh2g= github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5 h1:cQMA9hip0WSp6cv7CUfButa9Jl/9E6kqWmQyOjx5A5s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220923115829-2217f6c65ce3/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.0.0-20220923151905-0900fceecb89 h1:Ym50Fgn3GiYya4p29k3nJ5nYsalFGev3eIm3DeGNIq4= github.com/matrix-org/pinecone v0.0.0-20220927101513-d0beb180f44d h1:kGPJ6Rg8nn5an2CbCZrRiuTNyNzE0rRMiqm4UXJYrRs=
github.com/matrix-org/pinecone v0.0.0-20220923151905-0900fceecb89/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/pinecone v0.0.0-20220927101513-d0beb180f44d/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=

View file

@ -118,6 +118,10 @@ func (r *Inputer) processRoomEvent(
if roomInfo == nil && !isCreateEvent { if roomInfo == nil && !isCreateEvent {
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
} }
_, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender())
if err != nil {
return fmt.Errorf("event has invalid sender %q", input.Event.Sender())
}
// If we already know about this outlier and it hasn't been rejected // If we already know about this outlier and it hasn't been rejected
// then we won't attempt to reprocess it. If it was rejected or has now // then we won't attempt to reprocess it. If it was rejected or has now
@ -145,7 +149,8 @@ func (r *Inputer) processRoomEvent(
var missingAuth, missingPrev bool var missingAuth, missingPrev bool
serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{} serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{}
if !isCreateEvent { if !isCreateEvent {
missingAuthIDs, missingPrevIDs, err := r.DB.MissingAuthPrevEvents(ctx, event) var missingAuthIDs, missingPrevIDs []string
missingAuthIDs, missingPrevIDs, err = r.DB.MissingAuthPrevEvents(ctx, event)
if err != nil { if err != nil {
return fmt.Errorf("updater.MissingAuthPrevEvents: %w", err) return fmt.Errorf("updater.MissingAuthPrevEvents: %w", err)
} }
@ -158,7 +163,7 @@ func (r *Inputer) processRoomEvent(
RoomID: event.RoomID(), RoomID: event.RoomID(),
ExcludeSelf: true, ExcludeSelf: true,
} }
if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
} }
// Sort all of the servers into a map so that we can randomise // Sort all of the servers into a map so that we can randomise
@ -173,9 +178,9 @@ func (r *Inputer) processRoomEvent(
serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) serverRes.ServerNames = append(serverRes.ServerNames, input.Origin)
delete(servers, input.Origin) delete(servers, input.Origin)
} }
if origin := event.Origin(); origin != input.Origin { if senderDomain != input.Origin {
serverRes.ServerNames = append(serverRes.ServerNames, origin) serverRes.ServerNames = append(serverRes.ServerNames, senderDomain)
delete(servers, origin) delete(servers, senderDomain)
} }
for server := range servers { for server := range servers {
serverRes.ServerNames = append(serverRes.ServerNames, server) serverRes.ServerNames = append(serverRes.ServerNames, server)
@ -188,7 +193,7 @@ func (r *Inputer) processRoomEvent(
isRejected := false isRejected := false
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
knownEvents := map[string]*types.Event{} knownEvents := map[string]*types.Event{}
if err := r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
return fmt.Errorf("r.fetchAuthEvents: %w", err) return fmt.Errorf("r.fetchAuthEvents: %w", err)
} }
@ -231,7 +236,6 @@ func (r *Inputer) processRoomEvent(
if input.Kind == api.KindNew { if input.Kind == api.KindNew {
// Check that the event passes authentication checks based on the // Check that the event passes authentication checks based on the
// current room state. // current room state.
var err error
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs)
if err != nil { if err != nil {
logger.WithError(err).Warn("Error authing soft-failed event") logger.WithError(err).Warn("Error authing soft-failed event")
@ -265,7 +269,8 @@ func (r *Inputer) processRoomEvent(
hadEvents: map[string]bool{}, hadEvents: map[string]bool{},
haveEvents: map[string]*gomatrixserverlib.Event{}, haveEvents: map[string]*gomatrixserverlib.Event{},
} }
if stateSnapshot, err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { var stateSnapshot *parsedRespState
if stateSnapshot, err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil {
// Something went wrong with retrieving the missing state, so we can't // Something went wrong with retrieving the missing state, so we can't
// really do anything with the event other than reject it at this point. // really do anything with the event other than reject it at this point.
isRejected = true isRejected = true
@ -302,7 +307,6 @@ func (r *Inputer) processRoomEvent(
// burning CPU time. // burning CPU time.
historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared. historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected { if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected {
var err error
historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev) historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev)
if err != nil { if err != nil {
return fmt.Errorf("r.processStateBefore: %w", err) return fmt.Errorf("r.processStateBefore: %w", err)

View file

@ -468,7 +468,9 @@ FindSuccessor:
// Store the server names in a temporary map to avoid duplicates. // Store the server names in a temporary map to avoid duplicates.
serverSet := make(map[gomatrixserverlib.ServerName]bool) serverSet := make(map[gomatrixserverlib.ServerName]bool)
for _, event := range memberEvents { for _, event := range memberEvents {
serverSet[event.Origin()] = true if _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()); err == nil {
serverSet[senderDomain] = true
}
} }
var servers []gomatrixserverlib.ServerName var servers []gomatrixserverlib.ServerName
for server := range serverSet { for server := range serverSet {

View file

@ -50,6 +50,10 @@ func (r *Inviter) PerformInvite(
if event.StateKey() == nil { if event.StateKey() == nil {
return nil, fmt.Errorf("invite must be a state event") return nil, fmt.Errorf("invite must be a state event")
} }
_, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender())
if err != nil {
return nil, fmt.Errorf("sender %q is invalid", event.Sender())
}
roomID := event.RoomID() roomID := event.RoomID()
targetUserID := *event.StateKey() targetUserID := *event.StateKey()
@ -67,7 +71,7 @@ func (r *Inviter) PerformInvite(
return nil, nil return nil, nil
} }
isTargetLocal := domain == r.Cfg.Matrix.ServerName isTargetLocal := domain == r.Cfg.Matrix.ServerName
isOriginLocal := event.Origin() == r.Cfg.Matrix.ServerName isOriginLocal := senderDomain == r.Cfg.Matrix.ServerName
if !isOriginLocal && !isTargetLocal { if !isOriginLocal && !isTargetLocal {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
@ -235,7 +239,7 @@ func (r *Inviter) PerformInvite(
{ {
Kind: api.KindNew, Kind: api.KindNew,
Event: event, Event: event,
Origin: event.Origin(), Origin: senderDomain,
SendAsServer: req.SendAsServer, SendAsServer: req.SendAsServer,
}, },
}, },

View file

@ -81,12 +81,11 @@ func (r *Leaver) performLeaveRoomByID(
// that. // that.
isInvitePending, senderUser, eventID, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID) isInvitePending, senderUser, eventID, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID)
if err == nil && isInvitePending { if err == nil && isInvitePending {
var host gomatrixserverlib.ServerName _, senderDomain, serr := gomatrixserverlib.SplitID('@', senderUser)
_, host, err = gomatrixserverlib.SplitID('@', senderUser) if serr != nil {
if err != nil {
return nil, fmt.Errorf("sender %q is invalid", senderUser) return nil, fmt.Errorf("sender %q is invalid", senderUser)
} }
if host != r.Cfg.Matrix.ServerName { if senderDomain != r.Cfg.Matrix.ServerName {
return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID)
} }
// check that this is not a "server notice room" // check that this is not a "server notice room"
@ -172,6 +171,12 @@ func (r *Leaver) performLeaveRoomByID(
return nil, fmt.Errorf("eventutil.BuildEvent: %w", err) return nil, fmt.Errorf("eventutil.BuildEvent: %w", err)
} }
// Get the sender domain.
_, senderDomain, serr := gomatrixserverlib.SplitID('@', event.Sender())
if serr != nil {
return nil, fmt.Errorf("sender %q is invalid", event.Sender())
}
// Give our leave event to the roomserver input stream. The // Give our leave event to the roomserver input stream. The
// roomserver will process the membership change and notify // roomserver will process the membership change and notify
// downstream automatically. // downstream automatically.
@ -180,7 +185,7 @@ func (r *Leaver) performLeaveRoomByID(
{ {
Kind: api.KindNew, Kind: api.KindNew,
Event: event.Headered(buildRes.RoomVersion), Event: event.Headered(buildRes.RoomVersion),
Origin: event.Origin(), Origin: senderDomain,
SendAsServer: string(r.Cfg.Matrix.ServerName), SendAsServer: string(r.Cfg.Matrix.ServerName),
}, },
}, },

View file

@ -897,7 +897,7 @@ func (d *Database) handleRedactions(
switch { switch {
case redactUser >= pl.Redact: case redactUser >= pl.Redact:
// The power level of the redaction events sender is greater than or equal to the redact level. // The power level of the redaction events sender is greater than or equal to the redact level.
case redactedEvent.Origin() == redactionEvent.Origin() && redactedEvent.Sender() == redactionEvent.Sender(): case redactedEvent.Sender() == redactionEvent.Sender():
// The domain of the redaction events sender matches that of the original events sender. // The domain of the redaction events sender matches that of the original events sender.
default: default:
return nil, "", nil return nil, "", nil

View file

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
# #
# Runs SyTest either from Docker Hub, or from ../sytest. If it's run # Runs SyTest either from Docker Hub, or from ../sytest. If it's run
# locally, the Docker image is rebuilt first. # locally, the Docker image is rebuilt first.

View file

@ -9,9 +9,10 @@ import (
"time" "time"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/sirupsen/logrus"
natsserver "github.com/nats-io/nats-server/v2/server" natsserver "github.com/nats-io/nats-server/v2/server"
natsclient "github.com/nats-io/nats.go" natsclient "github.com/nats-io/nats.go"
@ -184,6 +185,8 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"}, OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"},
OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"}, OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"},
OutputRoomEvent: {"AppserviceRoomserverConsumer"}, OutputRoomEvent: {"AppserviceRoomserverConsumer"},
OutputStreamEvent: {"UserAPISyncAPIStreamEventConsumer"},
OutputReadUpdate: {"UserAPISyncAPIReadUpdateConsumer"},
} { } {
streamName := cfg.Matrix.JetStream.Prefixed(stream) streamName := cfg.Matrix.JetStream.Prefixed(stream)
for _, consumer := range consumers { for _, consumer := range consumers {

View file

@ -94,16 +94,6 @@ var streams = []*nats.StreamConfig{
Retention: nats.InterestPolicy, Retention: nats.InterestPolicy,
Storage: nats.FileStorage, Storage: nats.FileStorage,
}, },
{
Name: OutputStreamEvent,
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
{
Name: OutputReadUpdate,
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
{ {
Name: OutputPresenceEvent, Name: OutputPresenceEvent,
Retention: nats.InterestPolicy, Retention: nats.InterestPolicy,

View file

@ -1,4 +1,4 @@
#! /bin/bash #!/usr/bin/env bash
# #
# Parses a results.tap file from SyTest output and a file containing test names (a test whitelist) # Parses a results.tap file from SyTest output and a file containing test names (a test whitelist)
# and checks whether a test name that exists in the whitelist (that should pass), failed or not. # and checks whether a test name that exists in the whitelist (that should pass), failed or not.

View file

@ -16,9 +16,7 @@ package consumers
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"fmt"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -31,7 +29,6 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
@ -46,7 +43,6 @@ type OutputClientDataConsumer struct {
stream types.StreamProvider stream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
producer *producers.UserAPIReadProducer
} }
// NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers. // NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers.
@ -57,7 +53,6 @@ func NewOutputClientDataConsumer(
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream types.StreamProvider,
producer *producers.UserAPIReadProducer,
) *OutputClientDataConsumer { ) *OutputClientDataConsumer {
return &OutputClientDataConsumer{ return &OutputClientDataConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -68,7 +63,6 @@ func NewOutputClientDataConsumer(
notifier: notifier, notifier: notifier,
stream: stream, stream: stream,
serverName: cfg.Matrix.ServerName, serverName: cfg.Matrix.ServerName,
producer: producer,
} }
} }
@ -113,15 +107,6 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M
return false return false
} }
if err = s.sendReadUpdate(ctx, userID, output); err != nil {
log.WithError(err).WithFields(logrus.Fields{
"user_id": userID,
"room_id": output.RoomID,
}).Errorf("Failed to generate read update")
sentry.CaptureException(err)
return false
}
if output.IgnoredUsers != nil { if output.IgnoredUsers != nil {
if err := s.db.UpdateIgnoresForUser(ctx, userID, output.IgnoredUsers); err != nil { if err := s.db.UpdateIgnoresForUser(ctx, userID, output.IgnoredUsers); err != nil {
log.WithError(err).WithFields(logrus.Fields{ log.WithError(err).WithFields(logrus.Fields{
@ -136,34 +121,3 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M
return true return true
} }
func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error {
if output.Type != "m.fully_read" || output.ReadMarker == nil {
return nil
}
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if serverName != s.serverName {
return nil
}
var readPos types.StreamPosition
var fullyReadPos types.StreamPosition
if output.ReadMarker.Read != "" {
if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
}
}
if output.ReadMarker.FullyRead != "" {
if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err)
}
}
if readPos > 0 || fullyReadPos > 0 {
if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil {
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
}
}
return nil
}

View file

@ -16,22 +16,19 @@ package consumers
import ( import (
"context" "context"
"database/sql"
"fmt"
"strconv" "strconv"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
) )
// OutputReceiptEventConsumer consumes events that originated in the EDU server. // OutputReceiptEventConsumer consumes events that originated in the EDU server.
@ -44,7 +41,6 @@ type OutputReceiptEventConsumer struct {
stream types.StreamProvider stream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
producer *producers.UserAPIReadProducer
} }
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
@ -56,7 +52,6 @@ func NewOutputReceiptEventConsumer(
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream types.StreamProvider,
producer *producers.UserAPIReadProducer,
) *OutputReceiptEventConsumer { ) *OutputReceiptEventConsumer {
return &OutputReceiptEventConsumer{ return &OutputReceiptEventConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -67,7 +62,6 @@ func NewOutputReceiptEventConsumer(
notifier: notifier, notifier: notifier,
stream: stream, stream: stream,
serverName: cfg.Matrix.ServerName, serverName: cfg.Matrix.ServerName,
producer: producer,
} }
} }
@ -111,42 +105,8 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
return true return true
} }
if err = s.sendReadUpdate(ctx, output); err != nil {
log.WithError(err).WithFields(logrus.Fields{
"user_id": output.UserID,
"room_id": output.RoomID,
}).Errorf("Failed to generate read update")
sentry.CaptureException(err)
return false
}
s.stream.Advance(streamPos) s.stream.Advance(streamPos)
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
return true return true
} }
func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output types.OutputReceiptEvent) error {
if output.Type != "m.read" {
return nil
}
_, serverName, err := gomatrixserverlib.SplitID('@', output.UserID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if serverName != s.serverName {
return nil
}
var readPos types.StreamPosition
if output.EventID != "" {
if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
}
}
if readPos > 0 {
if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil {
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
}
}
return nil
}

View file

@ -21,18 +21,18 @@ import (
"fmt" "fmt"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
) )
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
@ -47,7 +47,6 @@ type OutputRoomEventConsumer struct {
pduStream types.StreamProvider pduStream types.StreamProvider
inviteStream types.StreamProvider inviteStream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
producer *producers.UserAPIStreamEventProducer
} }
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -60,7 +59,6 @@ func NewOutputRoomEventConsumer(
pduStream types.StreamProvider, pduStream types.StreamProvider,
inviteStream types.StreamProvider, inviteStream types.StreamProvider,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
producer *producers.UserAPIStreamEventProducer,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -73,7 +71,6 @@ func NewOutputRoomEventConsumer(
pduStream: pduStream, pduStream: pduStream,
inviteStream: inviteStream, inviteStream: inviteStream,
rsAPI: rsAPI, rsAPI: rsAPI,
producer: producer,
} }
} }
@ -258,12 +255,6 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
return nil return nil
} }
if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID())
sentry.CaptureException(err)
return err
}
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
sentry.CaptureException(err) sentry.CaptureException(err)

View file

@ -19,6 +19,9 @@ import (
"encoding/json" "encoding/json"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
@ -26,8 +29,6 @@ import (
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
) )
// OutputNotificationDataConsumer consumes events that originated in // OutputNotificationDataConsumer consumes events that originated in

View file

@ -1,62 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package producers
import (
"encoding/json"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// UserAPIProducer produces events for the user API server to consume
type UserAPIReadProducer struct {
Topic string
JetStream nats.JetStreamContext
}
// SendData sends account data to the user API server
func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error {
m := &nats.Msg{
Subject: p.Topic,
Header: nats.Header{},
}
m.Header.Set(jetstream.UserID, userID)
m.Header.Set(jetstream.RoomID, roomID)
data := types.ReadUpdate{
UserID: userID,
RoomID: roomID,
Read: readPos,
FullyRead: fullyReadPos,
}
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"user_id": userID,
"room_id": roomID,
"read_pos": readPos,
"fully_read_pos": fullyReadPos,
}).Tracef("Producing to topic '%s'", p.Topic)
_, err = p.JetStream.PublishMsg(m)
return err
}

View file

@ -1,60 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package producers
import (
"encoding/json"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// UserAPIProducer produces events for the user API server to consume
type UserAPIStreamEventProducer struct {
Topic string
JetStream nats.JetStreamContext
}
// SendData sends account data to the user API server
func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error {
m := &nats.Msg{
Subject: p.Topic,
Header: nats.Header{},
}
m.Header.Set(jetstream.RoomID, roomID)
data := types.StreamedEvent{
Event: event,
StreamPosition: pos,
}
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"room_id": roomID,
"event_id": event.EventID(),
"event_type": event.Type(),
"stream_pos": pos,
}).Tracef("Producing to topic '%s'", p.Topic)
_, err = p.JetStream.PublishMsg(m)
return err
}

View file

@ -29,6 +29,7 @@ import (
type Database interface { type Database interface {
Presence Presence
SharedUsers SharedUsers
Notifications
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
@ -149,12 +150,6 @@ type Database interface {
// GetRoomReceipts gets all receipts for a given roomID // GetRoomReceipts gets all receipts for a given roomID
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error)
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
// GetUserUnreadNotificationCounts returns statistics per room a user is interested in.
GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error)
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
@ -180,3 +175,11 @@ type SharedUsers interface {
// SharedUsers returns a subset of otherUserIDs that share a room with userID. // SharedUsers returns a subset of otherUserIDs that share a room with userID.
SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error)
} }
type Notifications interface {
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
// getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
}

View file

@ -62,6 +62,8 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_s
CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
-- for querying state by event IDs -- for querying state by event IDs
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id); CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id);
-- for improving selectRoomIDsWithAnyMembershipSQL
CREATE INDEX IF NOT EXISTS syncapi_current_room_state_type_state_key_idx ON syncapi_current_room_state(type, state_key);
` `
const upsertRoomStateSQL = "" + const upsertRoomStateSQL = "" +
@ -80,7 +82,7 @@ const selectRoomIDsWithMembershipSQL = "" +
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectRoomIDsWithAnyMembershipSQL = "" + const selectRoomIDsWithAnyMembershipSQL = "" +
"SELECT DISTINCT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1" "SELECT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1"
const selectCurrentStateSQL = "" + const selectCurrentStateSQL = "" +
"SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" +

View file

@ -18,6 +18,8 @@ import (
"context" "context"
"database/sql" "database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -33,17 +35,15 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro
r := &notificationDataStatements{} r := &notificationDataStatements{}
return r, sqlutil.StatementList{ return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms},
{&r.selectMaxID, selectMaxNotificationIDSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL},
{&r.purgeNotificationData, purgeNotificationDataSQL},
}.Prepare(db) }.Prepare(db)
} }
type notificationDataStatements struct { type notificationDataStatements struct {
upsertRoomUnreadCounts *sql.Stmt upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt selectUserUnreadCountsForRooms *sql.Stmt
selectMaxID *sql.Stmt selectMaxID *sql.Stmt
purgeNotificationData *sql.Stmt
} }
const notificationDataSchema = ` const notificationDataSchema = `
@ -63,12 +63,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4 DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4
RETURNING id` RETURNING id`
const selectUserUnreadNotificationCountsSQL = `SELECT const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count
id, room_id, notification_count, highlight_count FROM syncapi_notification_data
FROM syncapi_notification_data WHERE user_id = $1 AND
WHERE room_id = ANY($2)`
user_id = $1 AND
id BETWEEN $2 + 1 AND $3`
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
@ -80,20 +78,20 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context,
return return
} }
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { func (r *notificationDataStatements) SelectUserUnreadCountsForRooms(
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) ctx context.Context, txn *sql.Tx, userID string, roomIDs []string,
) (map[string]*eventutil.NotificationData, error) {
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCountsForRooms).QueryContext(ctx, userID, pq.Array(roomIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCountsForRooms: rows.close() failed")
roomCounts := map[string]*eventutil.NotificationData{} roomCounts := map[string]*eventutil.NotificationData{}
var roomID string
var notificationCount, highlightCount int
for rows.Next() { for rows.Next() {
var id types.StreamPosition if err = rows.Scan(&roomID, &notificationCount, &highlightCount); err != nil {
var roomID string
var notificationCount, highlightCount int
if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil {
return nil, err return nil, err
} }

View file

@ -1036,8 +1036,15 @@ func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userI
return return
} }
func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) { func (d *Database) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) {
return d.NotificationData.SelectUserUnreadCounts(ctx, nil, userID, from, to) roomIDs := make([]string, 0, len(rooms))
for roomID, membership := range rooms {
if membership != gomatrixserverlib.Join {
continue
}
roomIDs = append(roomIDs, roomID)
}
return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, nil, userID, roomIDs)
} }
func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {

View file

@ -51,6 +51,8 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_s
-- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; -- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
-- for querying state by event IDs -- for querying state by event IDs
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id); CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id);
-- for improving selectRoomIDsWithAnyMembershipSQL
CREATE INDEX IF NOT EXISTS syncapi_current_room_state_type_state_key_idx ON syncapi_current_room_state(type, state_key);
` `
const upsertRoomStateSQL = "" + const upsertRoomStateSQL = "" +
@ -69,7 +71,7 @@ const selectRoomIDsWithMembershipSQL = "" +
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectRoomIDsWithAnyMembershipSQL = "" + const selectRoomIDsWithAnyMembershipSQL = "" +
"SELECT DISTINCT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1" "SELECT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1"
const selectCurrentStateSQL = "" + const selectCurrentStateSQL = "" +
"SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1"

View file

@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
@ -33,19 +34,21 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t
} }
r := &notificationDataStatements{ r := &notificationDataStatements{
streamIDStatements: streamID, streamIDStatements: streamID,
db: db,
} }
return r, sqlutil.StatementList{ return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
{&r.selectMaxID, selectMaxNotificationIDSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL},
// {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime
}.Prepare(db) }.Prepare(db)
} }
type notificationDataStatements struct { type notificationDataStatements struct {
db *sql.DB
streamIDStatements *StreamIDStatements streamIDStatements *StreamIDStatements
upsertRoomUnreadCounts *sql.Stmt upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt
selectMaxID *sql.Stmt selectMaxID *sql.Stmt
//selectUserUnreadCountsForRooms *sql.Stmt
} }
const notificationDataSchema = ` const notificationDataSchema = `
@ -64,12 +67,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
ON CONFLICT (user_id, room_id) ON CONFLICT (user_id, room_id)
DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7` DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7`
const selectUserUnreadNotificationCountsSQL = `SELECT const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count
id, room_id, notification_count, highlight_count FROM syncapi_notification_data
FROM syncapi_notification_data WHERE user_id = $1 AND
WHERE room_id IN ($2)`
user_id = $1 AND
id BETWEEN $2 + 1 AND $3`
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
@ -82,20 +83,26 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context,
return return
} }
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { func (r *notificationDataStatements) SelectUserUnreadCountsForRooms(
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) ctx context.Context, txn *sql.Tx, userID string, roomIDs []string,
) (map[string]*eventutil.NotificationData, error) {
params := make([]interface{}, len(roomIDs)+1)
params[0] = userID
for i := range roomIDs {
params[i+1] = roomIDs[i]
}
sql := strings.Replace(selectUserUnreadNotificationsForRooms, "($1)", sqlutil.QueryVariadic(len(params)), 1)
rows, err := r.db.QueryContext(ctx, sql, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCountsForRooms: rows.close() failed")
roomCounts := map[string]*eventutil.NotificationData{} roomCounts := map[string]*eventutil.NotificationData{}
var roomID string
var notificationCount, highlightCount int
for rows.Next() { for rows.Next() {
var id types.StreamPosition if err = rows.Scan(&roomID, &notificationCount, &highlightCount); err != nil {
var roomID string
var notificationCount, highlightCount int
if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil {
return nil, err return nil, err
} }

View file

@ -198,7 +198,7 @@ type Memberships interface {
type NotificationData interface { type NotificationData interface {
UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) SelectUserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error)
SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error)
PurgeNotificationData(ctx context.Context, txn *sql.Tx, roomID string) error PurgeNotificationData(ctx context.Context, txn *sql.Tx, roomID string) error
} }

View file

@ -3,9 +3,10 @@ package streams
import ( import (
"context" "context"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
) )
type AccountDataStreamProvider struct { type AccountDataStreamProvider struct {

View file

@ -30,26 +30,29 @@ func (p *NotificationDataStreamProvider) CompleteSync(
func (p *NotificationDataStreamProvider) IncrementalSync( func (p *NotificationDataStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, _ types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
// We want counts for all possible rooms, so always start from zero. // Get the unread notifications for rooms in our join response.
countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to) // This is to ensure clients always have an unread notification section
// and can display the correct numbers.
countsByRoom, err := p.DB.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms)
if err != nil { if err != nil {
req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed") req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed")
return from return from
} }
// We're merely decorating existing rooms. Note that the Join map // We're merely decorating existing rooms.
// values are not pointers.
for roomID, jr := range req.Response.Rooms.Join { for roomID, jr := range req.Response.Rooms.Join {
counts := countsByRoom[roomID] counts := countsByRoom[roomID]
if counts == nil { if counts == nil {
continue continue
} }
jr.UnreadNotifications = &types.UnreadNotifications{
jr.UnreadNotifications.HighlightCount = counts.UnreadHighlightCount HighlightCount: counts.UnreadHighlightCount,
jr.UnreadNotifications.NotificationCount = counts.UnreadNotificationCount NotificationCount: counts.UnreadNotificationCount,
}
req.Response.Rooms.Join[roomID] = jr req.Response.Rooms.Join[roomID] = jr
} }
return to
return p.LatestPosition(ctx)
} }

View file

@ -77,16 +77,6 @@ func AddPublicRoutes(
logrus.WithError(err).Panicf("failed to start presence consumer") logrus.WithError(err).Panicf("failed to start presence consumer")
} }
userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{
JetStream: js,
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent),
}
userAPIReadUpdateProducer := &producers.UserAPIReadProducer{
JetStream: js,
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate),
}
keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
base.ProcessContext, cfg, cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), base.ProcessContext, cfg, cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent),
js, rsAPI, syncDB, notifier, js, rsAPI, syncDB, notifier,
@ -98,7 +88,7 @@ func AddPublicRoutes(
roomConsumer := consumers.NewOutputRoomEventConsumer( roomConsumer := consumers.NewOutputRoomEventConsumer(
base.ProcessContext, cfg, js, syncDB, notifier, streams.PDUStreamProvider, base.ProcessContext, cfg, js, syncDB, notifier, streams.PDUStreamProvider,
streams.InviteStreamProvider, rsAPI, userAPIStreamEventProducer, streams.InviteStreamProvider, rsAPI,
) )
if err = roomConsumer.Start(); err != nil { if err = roomConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start room server consumer") logrus.WithError(err).Panicf("failed to start room server consumer")
@ -106,7 +96,6 @@ func AddPublicRoutes(
clientConsumer := consumers.NewOutputClientDataConsumer( clientConsumer := consumers.NewOutputClientDataConsumer(
base.ProcessContext, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider, base.ProcessContext, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider,
userAPIReadUpdateProducer,
) )
if err = clientConsumer.Start(); err != nil { if err = clientConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start client data consumer") logrus.WithError(err).Panicf("failed to start client data consumer")
@ -135,7 +124,6 @@ func AddPublicRoutes(
receiptConsumer := consumers.NewOutputReceiptEventConsumer( receiptConsumer := consumers.NewOutputReceiptEventConsumer(
base.ProcessContext, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider, base.ProcessContext, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider,
userAPIReadUpdateProducer,
) )
if err = receiptConsumer.Start(); err != nil { if err = receiptConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start receipts consumer") logrus.WithError(err).Panicf("failed to start receipts consumer")

View file

@ -398,6 +398,11 @@ func (r *Response) IsEmpty() bool {
len(r.ToDevice.Events) == 0 len(r.ToDevice.Events) == 0
} }
type UnreadNotifications struct {
HighlightCount int `json:"highlight_count"`
NotificationCount int `json:"notification_count"`
}
// JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key. // JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key.
type JoinResponse struct { type JoinResponse struct {
Summary struct { Summary struct {
@ -419,10 +424,7 @@ type JoinResponse struct {
AccountData struct { AccountData struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
} `json:"account_data"` } `json:"account_data"`
UnreadNotifications struct { *UnreadNotifications `json:"unread_notifications,omitempty"`
HighlightCount int `json:"highlight_count"`
NotificationCount int `json:"notification_count"`
} `json:"unread_notifications"`
} }
// NewJoinResponse creates an empty response with initialised arrays. // NewJoinResponse creates an empty response with initialised arrays.
@ -503,19 +505,6 @@ type Peek struct {
Deleted bool Deleted bool
} }
type ReadUpdate struct {
UserID string `json:"user_id"`
RoomID string `json:"room_id"`
Read StreamPosition `json:"read,omitempty"`
FullyRead StreamPosition `json:"fully_read,omitempty"`
}
// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
type StreamedEvent struct {
Event *gomatrixserverlib.HeaderedEvent `json:"event"`
StreamPosition StreamPosition `json:"stream_position"`
}
// OutputReceiptEvent is an entry in the receipt output kafka log // OutputReceiptEvent is an entry in the receipt output kafka log
type OutputReceiptEvent struct { type OutputReceiptEvent struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`

View file

@ -0,0 +1,127 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/util"
)
// OutputReceiptEventConsumer consumes events that originated in the clientAPI.
type OutputReceiptEventConsumer struct {
ctx context.Context
jetstream nats.JetStreamContext
durable string
topic string
db storage.Database
serverName gomatrixserverlib.ServerName
syncProducer *producers.SyncAPI
pgClient pushgateway.Client
}
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
// Call Start() to begin consuming from the EDU server.
func NewOutputReceiptEventConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
store storage.Database,
syncProducer *producers.SyncAPI,
pgClient pushgateway.Client,
) *OutputReceiptEventConsumer {
return &OutputReceiptEventConsumer{
ctx: process.Context(),
jetstream: js,
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent),
durable: cfg.Matrix.JetStream.Durable("UserAPIReceiptConsumer"),
db: store,
serverName: cfg.Matrix.ServerName,
syncProducer: syncProducer,
pgClient: pgClient,
}
}
// Start consuming receipts events.
func (s *OutputReceiptEventConsumer) Start() error {
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
userID := msg.Header.Get(jetstream.UserID)
roomID := msg.Header.Get(jetstream.RoomID)
readPos := msg.Header.Get(jetstream.EventID)
evType := msg.Header.Get("type")
if readPos == "" || evType != "m.read" {
return true
}
log := log.WithFields(log.Fields{
"room_id": roomID,
"user_id": userID,
})
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
log.WithError(err).Error("userapi clientapi consumer: SplitID failure")
return true
}
if domain != s.serverName {
return true
}
metadata, err := msg.Metadata()
if err != nil {
return false
}
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
if err != nil {
log.WithError(err).Error("userapi EDU consumer")
return false
}
if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil {
log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed")
return false
}
if !updated {
return true
}
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
return false
}
return true
}

View file

@ -26,7 +26,7 @@ import (
"github.com/matrix-org/dendrite/userapi/util" "github.com/matrix-org/dendrite/userapi/util"
) )
type OutputStreamEventConsumer struct { type OutputRoomEventConsumer struct {
ctx context.Context ctx context.Context
cfg *config.UserAPI cfg *config.UserAPI
rsAPI rsapi.UserRoomserverAPI rsAPI rsapi.UserRoomserverAPI
@ -38,7 +38,7 @@ type OutputStreamEventConsumer struct {
syncProducer *producers.SyncAPI syncProducer *producers.SyncAPI
} }
func NewOutputStreamEventConsumer( func NewOutputRoomEventConsumer(
process *process.ProcessContext, process *process.ProcessContext,
cfg *config.UserAPI, cfg *config.UserAPI,
js nats.JetStreamContext, js nats.JetStreamContext,
@ -46,21 +46,21 @@ func NewOutputStreamEventConsumer(
pgClient pushgateway.Client, pgClient pushgateway.Client,
rsAPI rsapi.UserRoomserverAPI, rsAPI rsapi.UserRoomserverAPI,
syncProducer *producers.SyncAPI, syncProducer *producers.SyncAPI,
) *OutputStreamEventConsumer { ) *OutputRoomEventConsumer {
return &OutputStreamEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
cfg: cfg, cfg: cfg,
jetstream: js, jetstream: js,
db: store, db: store,
durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"), durable: cfg.Matrix.JetStream.Durable("UserAPIRoomServerConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent),
pgClient: pgClient, pgClient: pgClient,
rsAPI: rsAPI, rsAPI: rsAPI,
syncProducer: syncProducer, syncProducer: syncProducer,
} }
} }
func (s *OutputStreamEventConsumer) Start() error { func (s *OutputRoomEventConsumer) Start() error {
if err := jetstream.JetStreamConsumer( if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, 1, s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
@ -70,35 +70,43 @@ func (s *OutputStreamEventConsumer) Start() error {
return nil return nil
} }
func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called msg := msgs[0] // Guaranteed to exist if onMessage is called
var output types.StreamedEvent var output rsapi.OutputEvent
output.Event = &gomatrixserverlib.HeaderedEvent{}
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {
log.WithError(err).Errorf("userapi consumer: message parse failure") // If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("roomserver output log: message parse failure")
return true return true
} }
if output.Event.Event == nil { if output.Type != rsapi.OutputTypeNewRoomEvent {
return true
}
event := output.NewRoomEvent.Event
if event == nil {
log.Errorf("userapi consumer: expected event") log.Errorf("userapi consumer: expected event")
return true return true
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": output.Event.EventID(), "event_id": event.EventID(),
"event_type": output.Event.Type(), "event_type": event.Type(),
"stream_pos": output.StreamPosition, }).Tracef("Received message from roomserver: %#v", output)
}).Tracef("Received message from sync API: %#v", output)
if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil { metadata, err := msg.Metadata()
if err != nil {
return true
}
if err := s.processMessage(ctx, event, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp))); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": output.Event.EventID(), "event_id": event.EventID(),
}).WithError(err).Errorf("userapi consumer: process room event failure") }).WithError(err).Errorf("userapi consumer: process room event failure")
} }
return true return true
} }
func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error { func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {
members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) members, roomSize, err := s.localRoomMembers(ctx, event.RoomID())
if err != nil { if err != nil {
return fmt.Errorf("s.localRoomMembers: %w", err) return fmt.Errorf("s.localRoomMembers: %w", err)
@ -138,10 +146,10 @@ func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *g
// removing it means we can send all notifications to // removing it means we can send all notifications to
// e.g. Element's Push gateway in one go. // e.g. Element's Push gateway in one go.
for _, mem := range members { for _, mem := range members {
if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil { if err := s.notifyLocal(ctx, event, mem, roomSize, roomName, streamPos); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"localpart": mem.Localpart, "localpart": mem.Localpart,
}).WithError(err).Debugf("Unable to push to local user") }).WithError(err).Error("Unable to push to local user")
continue continue
} }
} }
@ -179,7 +187,7 @@ func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership,
// localRoomMembers fetches the current local members of a room, and // localRoomMembers fetches the current local members of a room, and
// the total number of members. // the total number of members.
func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) {
req := &rsapi.QueryMembershipsForRoomRequest{ req := &rsapi.QueryMembershipsForRoomRequest{
RoomID: roomID, RoomID: roomID,
JoinedOnly: true, JoinedOnly: true,
@ -219,7 +227,7 @@ func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID
// looks it up in roomserver. If there is no name, // looks it up in roomserver. If there is no name,
// m.room.canonical_alias is consulted. Returns an empty string if the // m.room.canonical_alias is consulted. Returns an empty string if the
// room has no name. // room has no name.
func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) {
if event.Type() == gomatrixserverlib.MRoomName { if event.Type() == gomatrixserverlib.MRoomName {
name, err := unmarshalRoomName(event) name, err := unmarshalRoomName(event)
if err != nil { if err != nil {
@ -287,7 +295,7 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er
} }
// notifyLocal finds the right push actions for a local user, given an event. // notifyLocal finds the right push actions for a local user, given an event.
func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error { func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error {
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
if err != nil { if err != nil {
return err return err
@ -302,7 +310,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
"event_id": event.EventID(), "event_id": event.EventID(),
"room_id": event.RoomID(), "room_id": event.RoomID(),
"localpart": mem.Localpart, "localpart": mem.Localpart,
}).Debugf("Push rule evaluation rejected the event") }).Tracef("Push rule evaluation rejected the event")
return nil return nil
} }
@ -325,7 +333,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
RoomID: event.RoomID(), RoomID: event.RoomID(),
TS: gomatrixserverlib.AsTimestamp(time.Now()), TS: gomatrixserverlib.AsTimestamp(time.Now()),
} }
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil { if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil {
return err return err
} }
@ -345,7 +353,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
"localpart": mem.Localpart, "localpart": mem.Localpart,
"num_urls": len(devicesByURLAndFormat), "num_urls": len(devicesByURLAndFormat),
"num_unread": userNumUnreadNotifs, "num_unread": userNumUnreadNotifs,
}).Debugf("Notifying single member") }).Trace("Notifying single member")
// Push gateways are out of our control, and we cannot risk // Push gateways are out of our control, and we cannot risk
// looking up the server on a misbehaving push gateway. Each user // looking up the server on a misbehaving push gateway. Each user
@ -396,7 +404,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
// evaluatePushRules fetches and evaluates the push rules of a local // evaluatePushRules fetches and evaluates the push rules of a local
// user. Returns actions (including dont_notify). // user. Returns actions (including dont_notify).
func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
if event.Sender() == mem.UserID { if event.Sender() == mem.UserID {
// SPEC: Homeservers MUST NOT notify the Push Gateway for // SPEC: Homeservers MUST NOT notify the Push Gateway for
// events that the user has sent themselves. // events that the user has sent themselves.
@ -447,7 +455,7 @@ func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event
"room_id": event.RoomID(), "room_id": event.RoomID(),
"localpart": mem.Localpart, "localpart": mem.Localpart,
"rule_id": rule.RuleID, "rule_id": rule.RuleID,
}).Tracef("Matched a push rule") }).Trace("Matched a push rule")
return rule.Actions, nil return rule.Actions, nil
} }
@ -491,7 +499,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
// localPushDevices pushes to the configured devices of a local // localPushDevices pushes to the configured devices of a local
// user. The map keys are [url][format]. // user. The map keys are [url][format].
func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
@ -515,7 +523,7 @@ func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localp
} }
// notifyHTTP performs a notificatation to a Push Gateway. // notifyHTTP performs a notificatation to a Push Gateway.
func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"event_id": event.EventID(), "event_id": event.EventID(),
"url": url, "url": url,
@ -561,13 +569,13 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat
} }
} }
logger.Debugf("Notifying push gateway %s", url) logger.Tracef("Notifying push gateway %s", url)
var res pushgateway.NotifyResponse var res pushgateway.NotifyResponse
if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil { if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil {
logger.WithError(err).Errorf("Failed to notify push gateway %s", url) logger.WithError(err).Errorf("Failed to notify push gateway %s", url)
return nil, err return nil, err
} }
logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result") logger.WithField("num_rejected", len(res.Rejected)).Trace("Push gateway result")
if len(res.Rejected) == 0 { if len(res.Rejected) == 0 {
return nil, nil return nil, nil
@ -589,7 +597,7 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat
} }
// deleteRejectedPushers deletes the pushers associated with the given devices. // deleteRejectedPushers deletes the pushers associated with the given devices.
func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"localpart": localpart, "localpart": localpart,
"app_id0": devices[0].AppID, "app_id0": devices[0].AppID,

View file

@ -40,7 +40,7 @@ func Test_evaluatePushRules(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
consumer := OutputStreamEventConsumer{db: db} consumer := OutputRoomEventConsumer{db: db}
testCases := []struct { testCases := []struct {
name string name string

View file

@ -1,137 +0,0 @@
package consumers
import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/types"
uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/util"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
type OutputReadUpdateConsumer struct {
ctx context.Context
cfg *config.UserAPI
jetstream nats.JetStreamContext
durable string
db storage.Database
pgClient pushgateway.Client
ServerName gomatrixserverlib.ServerName
topic string
userAPI uapi.UserInternalAPI
syncProducer *producers.SyncAPI
}
func NewOutputReadUpdateConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
store storage.Database,
pgClient pushgateway.Client,
userAPI uapi.UserInternalAPI,
syncProducer *producers.SyncAPI,
) *OutputReadUpdateConsumer {
return &OutputReadUpdateConsumer{
ctx: process.Context(),
cfg: cfg,
jetstream: js,
db: store,
ServerName: cfg.Matrix.ServerName,
durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate),
pgClient: pgClient,
userAPI: userAPI,
syncProducer: syncProducer,
}
}
func (s *OutputReadUpdateConsumer) Start() error {
if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
); err != nil {
return err
}
return nil
}
func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var read types.ReadUpdate
if err := json.Unmarshal(msg.Data, &read); err != nil {
log.WithError(err).Error("userapi clientapi consumer: message parse failure")
return true
}
if read.FullyRead == 0 && read.Read == 0 {
return true
}
userID := string(msg.Header.Get(jetstream.UserID))
roomID := string(msg.Header.Get(jetstream.RoomID))
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
log.WithError(err).Error("userapi clientapi consumer: SplitID failure")
return true
}
if domain != s.ServerName {
log.Error("userapi clientapi consumer: not a local user")
return true
}
log := log.WithFields(log.Fields{
"room_id": roomID,
"user_id": userID,
})
log.Tracef("Received read update from sync API: %#v", read)
if read.Read > 0 {
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true)
if err != nil {
log.WithError(err).Error("userapi EDU consumer")
return false
}
if updated {
if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil {
log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed")
return false
}
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
return false
}
}
}
if read.FullyRead > 0 {
deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead))
if err != nil {
log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed")
return false
}
if deleted {
if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed")
return false
}
if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil {
log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed")
return false
}
}
}
return true
}

View file

@ -30,6 +30,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api"
@ -39,6 +40,7 @@ import (
"github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
userapiUtil "github.com/matrix-org/dendrite/userapi/util"
) )
type UserInternalAPI struct { type UserInternalAPI struct {
@ -51,6 +53,7 @@ type UserInternalAPI struct {
AppServices []config.ApplicationService AppServices []config.ApplicationService
KeyAPI keyapi.UserKeyAPI KeyAPI keyapi.UserKeyAPI
RSAPI rsapi.UserRoomserverAPI RSAPI rsapi.UserRoomserverAPI
PgClient pushgateway.Client
} }
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
@ -73,6 +76,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
ignoredUsers = &synctypes.IgnoredUsers{} ignoredUsers = &synctypes.IgnoredUsers{}
_ = json.Unmarshal(req.AccountData, ignoredUsers) _ = json.Unmarshal(req.AccountData, ignoredUsers)
} }
if req.DataType == "m.fully_read" {
if err := a.setFullyRead(ctx, req); err != nil {
return err
}
}
if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{ if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{
RoomID: req.RoomID, RoomID: req.RoomID,
Type: req.DataType, Type: req.DataType,
@ -84,6 +92,44 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
return nil return nil
} }
func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccountDataRequest) error {
var output eventutil.ReadMarkerJSON
if err := json.Unmarshal(req.AccountData, &output); err != nil {
return err
}
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure")
return nil
}
if domain != a.ServerName {
return nil
}
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
if err != nil {
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
return err
}
if err = a.SyncProducer.GetAndSendNotificationData(ctx, req.UserID, req.RoomID); err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: GetAndSendNotificationData failed")
return err
}
// nothing changed, no need to notify the push gateway
if !deleted {
return nil
}
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
return err
}
return nil
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if err != nil { if err != nil {

View file

@ -4,12 +4,13 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/storage"
) )
type JetStreamPublisher interface { type JetStreamPublisher interface {

View file

@ -119,9 +119,9 @@ type ThreePID interface {
} }
type Notification interface { type Notification interface {
InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error)
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, read bool) (affected bool, err error) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error)
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)

View file

@ -20,12 +20,13 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
) )
type notificationsStatements struct { type notificationsStatements struct {
@ -110,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
} }
// Insert inserts a notification into the database. // Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS roomID, tsMS := n.RoomID, n.TS
nn := *n nn := *n
// Clears out fields that have their own columns to (1) shrink the // Clears out fields that have their own columns to (1) shrink the
@ -126,7 +127,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
} }
// DeleteUpTo deletes all previous notifications, up to and including the event. // DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
@ -140,7 +141,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
} }
// UpdateRead updates the "read" value for an event. // UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
@ -196,40 +197,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
return notifs, maxID, rows.Err() return notifs, maxID, rows.Err()
} }
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
return
if err != nil {
return 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var count int64
if err := rows.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
return 0, rows.Err()
} }
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
return
if err != nil {
return 0, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var total, highlight int64
if err := rows.Scan(&total, &highlight); err != nil {
return 0, 0, err
}
return total, highlight, nil
}
return 0, 0, rows.Err()
} }

View file

@ -19,11 +19,12 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/sirupsen/logrus"
) )
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@ -136,7 +137,7 @@ func (s *pushersStatements) SelectPushers(
pushers = append(pushers, pusher) pushers = append(pushers, pusher)
} }
logrus.Debugf("Database returned %d pushers", len(pushers)) logrus.Tracef("Database returned %d pushers", len(pushers))
return pushers, rows.Err() return pushers, rows.Err()
} }

View file

@ -700,13 +700,13 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
return d.LoginTokens.SelectLoginToken(ctx, token) return d.LoginTokens.SelectLoginToken(ctx, token)
} }
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error { func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
}) })
} }
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) { func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
return err return err
@ -714,7 +714,7 @@ func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomI
return return
} }
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) { func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
return err return err
@ -777,7 +777,7 @@ func (d *Database) GetPushers(
func (d *Database) RemovePusher( func (d *Database) RemovePusher(
ctx context.Context, appid, pushkey, localpart string, ctx context.Context, appid, pushkey, localpart string,
) error { ) error {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart) err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil
@ -792,7 +792,7 @@ func (d *Database) RemovePusher(
func (d *Database) RemovePushers( func (d *Database) RemovePushers(
ctx context.Context, appid, pushkey string, ctx context.Context, appid, pushkey string,
) error { ) error {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Pushers.DeletePushers(ctx, txn, appid, pushkey) return d.Pushers.DeletePushers(ctx, txn, appid, pushkey)
}) })
} }

View file

@ -20,12 +20,13 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
) )
type notificationsStatements struct { type notificationsStatements struct {
@ -110,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
} }
// Insert inserts a notification into the database. // Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS roomID, tsMS := n.RoomID, n.TS
nn := *n nn := *n
// Clears out fields that have their own columns to (1) shrink the // Clears out fields that have their own columns to (1) shrink the
@ -126,7 +127,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
} }
// DeleteUpTo deletes all previous notifications, up to and including the event. // DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
@ -140,7 +141,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
} }
// UpdateRead updates the "read" value for an event. // UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
@ -196,40 +197,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
return notifs, maxID, rows.Err() return notifs, maxID, rows.Err()
} }
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
return
if err != nil {
return 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var count int64
if err := rows.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
return 0, rows.Err()
} }
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
return
if err != nil {
return 0, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var total, highlight int64
if err := rows.Scan(&total, &highlight); err != nil {
return 0, 0, err
}
return total, highlight, nil
}
return 0, 0, rows.Err()
} }

View file

@ -19,11 +19,12 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/sirupsen/logrus"
) )
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@ -96,7 +97,7 @@ func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64, ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
) error { ) error {
_, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id) logrus.Debugf("Created pusher %d", session_id)
return err return err
} }
@ -136,7 +137,7 @@ func (s *pushersStatements) SelectPushers(
pushers = append(pushers, pusher) pushers = append(pushers, pusher)
} }
logrus.Debugf("Database returned %d pushers", len(pushers)) logrus.Tracef("Database returned %d pushers", len(pushers))
return pushers, rows.Err() return pushers, rows.Err()
} }
@ -144,13 +145,13 @@ func (s *pushersStatements) SelectPushers(
func (s *pushersStatements) DeletePusher( func (s *pushersStatements) DeletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
) error { ) error {
_, err := s.deletePusherStmt.ExecContext(ctx, appid, pushkey, localpart) _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
return err return err
} }
func (s *pushersStatements) DeletePushers( func (s *pushersStatements) DeletePushers(
ctx context.Context, txn *sql.Tx, appid, pushkey string, ctx context.Context, txn *sql.Tx, appid, pushkey string,
) error { ) error {
_, err := s.deletePushersByAppIdAndPushKeyStmt.ExecContext(ctx, appid, pushkey) _, err := sqlutil.TxStmt(txn, s.deletePushersByAppIdAndPushKeyStmt).ExecContext(ctx, appid, pushkey)
return err return err
} }

View file

@ -7,6 +7,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -14,10 +19,6 @@ import (
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
) )
const loginTokenLifetime = time.Minute const loginTokenLifetime = time.Minute
@ -513,7 +514,7 @@ func Test_Notification(t *testing.T) {
RoomID: roomID, RoomID: roomID,
TS: gomatrixserverlib.AsTimestamp(ts), TS: gomatrixserverlib.AsTimestamp(ts),
} }
err = db.InsertNotification(ctx, aliceLocalpart, eventID, int64(i+1), nil, notification) err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification)
assert.NoError(t, err, "unable to insert notification") assert.NoError(t, err, "unable to insert notification")
} }

View file

@ -105,9 +105,9 @@ type PusherTable interface {
type NotificationTable interface { type NotificationTable interface {
Clean(ctx context.Context, txn *sql.Tx) error Clean(ctx context.Context, txn *sql.Tx) error
Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error)
UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error)
Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)

View file

@ -81,16 +81,17 @@ func NewInternalAPI(
KeyAPI: keyAPI, KeyAPI: keyAPI,
RSAPI: rsAPI, RSAPI: rsAPI,
DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, DisableTLSValidation: cfg.PushGatewayDisableTLSValidation,
PgClient: pgClient,
} }
readConsumer := consumers.NewOutputReadUpdateConsumer( receiptConsumer := consumers.NewOutputReceiptEventConsumer(
base.ProcessContext, cfg, js, db, pgClient, userAPI, syncProducer, base.ProcessContext, cfg, js, db, syncProducer, pgClient,
) )
if err := readConsumer.Start(); err != nil { if err := receiptConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start user API read update consumer") logrus.WithError(err).Panic("failed to start user API receipt consumer")
} }
eventConsumer := consumers.NewOutputStreamEventConsumer( eventConsumer := consumers.NewOutputRoomEventConsumer(
base.ProcessContext, cfg, js, db, pgClient, rsAPI, syncProducer, base.ProcessContext, cfg, js, db, pgClient, rsAPI, syncProducer,
) )
if err := eventConsumer.Start(); err != nil { if err := eventConsumer.Start(); err != nil {