Merge commit '1bf9dc0ca2c2eae4efd4b282a39d3bfe42e5e10e' into austin.ellis/dendrite

# Conflicts:
#	servers/dendrite/build/docker/Dockerfile.polylith
This commit is contained in:
texuf 2022-09-28 15:12:34 -07:00
commit 32d48002f7
89 changed files with 1493 additions and 808 deletions

View file

@ -2,7 +2,7 @@
<!-- Please read docs/CONTRIBUTING.md before submitting your pull request --> <!-- Please read docs/CONTRIBUTING.md before submitting your pull request -->
* [ ] I have added added tests for PR _or_ I have justified why this PR doesn't need tests. * [ ] I have added tests for PR _or_ I have justified why this PR doesn't need tests.
* [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md#sign-off) * [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md#sign-off)
Signed-off-by: `Your Name <your@email.example.org>` Signed-off-by: `Your Name <your@email.example.org>`

View file

@ -139,3 +139,63 @@ jobs:
${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }} ${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }}
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:latest ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:latest
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }} ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }}
demo-pinecone:
name: Pinecone demo image
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Get release tag
if: github.event_name == 'release' # Only for GitHub releases
run: echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
- name: Set up QEMU
uses: docker/setup-qemu-action@v1
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Login to Docker Hub
uses: docker/login-action@v1
with:
username: ${{ env.DOCKER_HUB_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to GitHub Containers
uses: docker/login-action@v1
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build main pinecone demo image
if: github.ref_name == 'main'
id: docker_build_demo_pinecone
uses: docker/build-push-action@v2
with:
cache-from: type=gha
cache-to: type=gha,mode=max
context: .
file: ./build/docker/Dockerfile.demo-pinecone
platforms: ${{ env.PLATFORMS }}
push: true
tags: |
${{ env.DOCKER_NAMESPACE }}/dendrite-demo-pinecone:${{ github.ref_name }}
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-demo-pinecone:${{ github.ref_name }}
- name: Build release pinecone demo image
if: github.event_name == 'release' # Only for GitHub releases
id: docker_build_demo_pinecone_release
uses: docker/build-push-action@v2
with:
cache-from: type=gha
cache-to: type=gha,mode=max
context: .
file: ./build/docker/Dockerfile.demo-pinecone
platforms: ${{ env.PLATFORMS }}
push: true
tags: |
${{ env.DOCKER_NAMESPACE }}/dendrite-demo-pinecone:latest
${{ env.DOCKER_NAMESPACE }}/dendrite-demo-pinecone:${{ env.RELEASE_VERSION }}
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-demo-pinecone:latest
ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-demo-pinecone:${{ env.RELEASE_VERSION }}

View file

@ -0,0 +1,35 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package authorization
import "github.com/matrix-org/dendrite/setup/config"
type AuthorizationArgs struct {
RoomId string
UserId string
Permission string
}
type Authorization interface {
IsAllowed(args AuthorizationArgs) (bool, error)
}
func NewClientApiAuthorization(cfg *config.ClientAPI) Authorization {
// Load authorization manager for Zion
//if cfg.PublicKeyAuthentication.Ethereum.EnableAuthz {
//}
return &DefaultAuthorization{}
}

View file

@ -0,0 +1,23 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package authorization
type DefaultAuthorization struct {
}
func (azm *DefaultAuthorization) IsAllowed(args AuthorizationArgs) (bool, error) {
// Default. No authorization logic.
return true, nil
}

View file

@ -0,0 +1,25 @@
FROM docker.io/golang:1.19-alpine AS base
RUN apk --update --no-cache add bash build-base
WORKDIR /build
COPY . /build
RUN mkdir -p bin
RUN go build -trimpath -o bin/ ./cmd/dendrite-demo-pinecone
RUN go build -trimpath -o bin/ ./cmd/create-account
RUN go build -trimpath -o bin/ ./cmd/generate-keys
FROM alpine:latest
LABEL org.opencontainers.image.title="Dendrite (Pinecone demo)"
LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go"
LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite"
LABEL org.opencontainers.image.licenses="Apache-2.0"
COPY --from=base /build/bin/* /usr/bin/
VOLUME /etc/dendrite
WORKDIR /etc/dendrite
ENTRYPOINT ["/usr/bin/dendrite-demo-pinecone"]

View file

@ -1,4 +1,4 @@
FROM docker.io/golang:1.19-buster AS base FROM docker.io/golang:1.19-alpine AS base
RUN apk --update --no-cache add bash build-base RUN apk --update --no-cache add bash build-base

View file

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
cd $(git rev-parse --show-toplevel) cd $(git rev-parse --show-toplevel)

View file

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
TAG=${1:-latest} TAG=${1:-latest}

View file

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
TAG=${1:-latest} TAG=${1:-latest}

View file

@ -3,15 +3,19 @@ package routing
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
) )
func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
@ -138,3 +142,21 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
}, },
} }
} }
func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, natsClient *nats.Conn) util.JSONResponse {
if device.AccountType != userapi.AccountTypeAdmin {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("This API can only be used by admin users."),
}
}
_, err := natsClient.RequestMsg(nats.NewMsg(cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex)), time.Second*10)
if err != nil {
logrus.WithError(err).Error("failed to publish nats message")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}

View file

@ -20,7 +20,14 @@ import (
"strings" "strings"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/nats-io/nats.go"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/authorization"
"github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
clientutil "github.com/matrix-org/dendrite/clientapi/httputil" clientutil "github.com/matrix-org/dendrite/clientapi/httputil"
@ -34,11 +41,6 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/nats-io/nats.go"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
) )
var ReleaseVersion string var ReleaseVersion string
@ -73,6 +75,8 @@ func Setup(
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
userInteractiveAuth := auth.NewUserInteractive(userAPI, userAPI, cfg) userInteractiveAuth := auth.NewUserInteractive(userAPI, userAPI, cfg)
authorization := authorization.NewClientApiAuthorization(cfg)
_ = authorization // todo: use this in httputil.MakeAuthAPI
unstableFeatures := map[string]bool{ unstableFeatures := map[string]bool{
"org.matrix.e2e_cross_signing": true, "org.matrix.e2e_cross_signing": true,
@ -169,6 +173,12 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/fulltext/reindex",
httputil.MakeAuthAPI("admin_fultext_reindex", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminReindex(req, cfg, device, natsClient)
}),
).Methods(http.MethodGet, http.MethodOptions)
// server notifications // server notifications
if cfg.Matrix.ServerNotices.Enabled { if cfg.Matrix.ServerNotices.Enabled {
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")

View file

@ -5,10 +5,11 @@ import (
"fmt" "fmt"
"path/filepath" "path/filepath"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"github.com/matrix-org/dendrite/setup/config"
) )
func main() { func main() {
@ -82,6 +83,12 @@ func main() {
EnableInbound: true, EnableInbound: true,
EnableOutbound: true, EnableOutbound: true,
} }
cfg.SyncAPI.Fulltext = config.Fulltext{
Enabled: true,
IndexPath: config.Path(filepath.Join(*dirPath, "searchindex")),
InMemory: true,
Language: "en",
}
} }
} else { } else {
var err error var err error

View file

@ -285,10 +285,19 @@ sync_api:
# address of the client. This is likely required if Dendrite is running behind # address of the client. This is likely required if Dendrite is running behind
# a reverse proxy server. # a reverse proxy server.
# real_ip_header: X-Real-IP # real_ip_header: X-Real-IP
fulltext:
# Configuration for the full-text search engine.
search:
# Whether or not search is enabled.
enabled: false enabled: false
index_path: "./fulltextindex"
language: "en" # more possible languages can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang # The path where the search index will be created in.
index_path: "./searchindex"
# The language most likely to be used on the server - used when indexing, to
# ensure the returned results match expectations. A full list of possible languages
# can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang
language: "en"
# Configuration for the User API. # Configuration for the User API.
user_api: user_api:

View file

@ -336,10 +336,19 @@ sync_api:
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
fulltext:
# Configuration for the full-text search engine.
search:
# Whether or not search is enabled.
enabled: false enabled: false
index_path: "./fulltextindex"
language: "en" # more possible languages can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang # The path where the search index will be created in.
index_path: "./searchindex"
# The language most likely to be used on the server - used when indexing, to
# ensure the returned results match expectations. A full list of possible languages
# can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang
language: "en"
# This option controls which HTTP header to inspect to find the real remote IP # This option controls which HTTP header to inspect to find the real remote IP
# address of the client. This is likely required if Dendrite is running behind # address of the client. This is likely required if Dendrite is running behind

View file

@ -57,6 +57,11 @@ Request body format:
Reset the password of a local user. The `localpart` is the username only, i.e. if Reset the password of a local user. The `localpart` is the username only, i.e. if
the full user ID is `@alice:domain.com` then the local part is `alice`. the full user ID is `@alice:domain.com` then the local part is `alice`.
## GET `/_dendrite/admin/fulltext/reindex`
This endpoint instructs Dendrite to reindex all searchable events (`m.room.message`, `m.room.topic` and `m.room.name`). An empty JSON body will be returned immediately.
Indexing is done in the background, the server logs every 1000 events (or below) when they are being indexed. Once reindexing is done, you'll see something along the lines `Indexed 69586 events in 53.68223182s` in your debug logs.
## POST `/_synapse/admin/v1/send_server_notice` ## POST `/_synapse/admin/v1/send_server_notice`
Request body format: Request body format:

View file

@ -16,6 +16,9 @@ Dendrite can automatically populate the database with the relevant tables and in
it is not capable of creating the databases themselves. You will need to create the databases it is not capable of creating the databases themselves. You will need to create the databases
manually. manually.
The databases **must** be created with UTF-8 encoding configured or you will likely run into problems
with your Dendrite deployment.
At this point, you can choose to either use a single database for all Dendrite components, At this point, you can choose to either use a single database for all Dendrite components,
or you can run each component with its own separate database: or you can run each component with its own separate database:
@ -65,7 +68,7 @@ sudo -u postgres createuser -P dendrite
Create the database itself, using the `dendrite` role from above: Create the database itself, using the `dendrite` role from above:
```bash ```bash
sudo -u postgres createdb -O dendrite dendrite sudo -u postgres createdb -O dendrite -E UTF-8 dendrite
``` ```
### Multiple database creation ### Multiple database creation
@ -85,7 +88,7 @@ The following eight components require a database. In this example they will be
```bash ```bash
for i in appservice federationapi mediaapi mscs roomserver syncapi keyserver userapi; do for i in appservice federationapi mediaapi mscs roomserver syncapi keyserver userapi; do
sudo -u postgres createdb -O dendrite dendrite_$i sudo -u postgres createdb -O dendrite -E UTF-8 dendrite_$i
done done
``` ```

View file

@ -138,16 +138,18 @@ room_server:
conn_max_lifetime: -1 conn_max_lifetime: -1
``` ```
## Fulltext search ## Full-text search
Dendrite supports experimental fulltext indexing using [Bleve](https://github.com/blevesearch/bleve), it is configured in the `sync_api` section as follows. Depending on the language most likely to be used on the server, it might make sense to change the `language` used when indexing, to ensure the returned results match the expections. A full list of possible languages can be found [here](https://github.com/blevesearch/bleve/tree/master/analysis/lang). Dendrite supports experimental full-text indexing using [Bleve](https://github.com/blevesearch/bleve). It is configured in the `sync_api` section as follows.
Depending on the language most likely to be used on the server, it might make sense to change the `language` used when indexing, to ensure the returned results match the expectations. A full list of possible languages can be found [here](https://github.com/blevesearch/bleve/tree/master/analysis/lang).
```yaml ```yaml
sync_api: sync_api:
# ... # ...
fulltext: search:
enabled: false enabled: false
index_path: "./fulltextindex" index_path: "./searchindex"
language: "en" language: "en"
``` ```

View file

@ -217,7 +217,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
var remoteEvent *gomatrixserverlib.Event var remoteEvent *gomatrixserverlib.Event
remoteEvent, err = respSendJoin.Event.UntrustedEvent(respMakeJoin.RoomVersion) remoteEvent, err = respSendJoin.Event.UntrustedEvent(respMakeJoin.RoomVersion)
if err == nil && isWellFormedMembershipEvent( if err == nil && isWellFormedMembershipEvent(
remoteEvent, roomID, userID, r.cfg.Matrix.ServerName, remoteEvent, roomID, userID,
) { ) {
event = remoteEvent event = remoteEvent
} }
@ -285,7 +285,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
// isWellFormedMembershipEvent returns true if the event looks like a legitimate // isWellFormedMembershipEvent returns true if the event looks like a legitimate
// membership event. // membership event.
func isWellFormedMembershipEvent(event *gomatrixserverlib.Event, roomID, userID string, origin gomatrixserverlib.ServerName) bool { func isWellFormedMembershipEvent(event *gomatrixserverlib.Event, roomID, userID string) bool {
if membership, err := event.Membership(); err != nil { if membership, err := event.Membership(); err != nil {
return false return false
} else if membership != gomatrixserverlib.Join { } else if membership != gomatrixserverlib.Join {
@ -294,9 +294,6 @@ func isWellFormedMembershipEvent(event *gomatrixserverlib.Event, roomID, userID
if event.RoomID() != roomID { if event.RoomID() != roomID {
return false return false
} }
if event.Origin() != origin {
return false
}
if !event.StateKeyEquals(userID) { if !event.StateKeyEquals(userID) {
return false return false
} }

View file

@ -148,8 +148,15 @@ func processInvite(
JSON: jsonerror.BadJSON("The event JSON could not be redacted"), JSON: jsonerror.BadJSON("The event JSON could not be redacted"),
} }
} }
_, serverName, err := gomatrixserverlib.SplitID('@', event.Sender())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The event JSON contains an invalid sender"),
}
}
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
ServerName: event.Origin(), ServerName: serverName,
Message: redacted, Message: redacted,
AtTS: event.OriginServerTS(), AtTS: event.OriginServerTS(),
StrictValidityChecking: true, StrictValidityChecking: true,

View file

@ -203,14 +203,6 @@ func SendJoin(
} }
} }
// Check that the event is from the server sending the request.
if event.Origin() != request.Origin() {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The join must be sent by the server it originated on"),
}
}
// Check that a state key is provided. // Check that a state key is provided.
if event.StateKey() == nil || event.StateKeyEquals("") { if event.StateKey() == nil || event.StateKeyEquals("") {
return util.JSONResponse{ return util.JSONResponse{
@ -228,16 +220,16 @@ func SendJoin(
// Check that the sender belongs to the server that is sending us // Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender // the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both. // and the state key are equal so we don't need to check both.
var domain gomatrixserverlib.ServerName var serverName gomatrixserverlib.ServerName
if _, domain, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The sender of the join is invalid"), JSON: jsonerror.Forbidden("The sender of the join is invalid"),
} }
} else if domain != request.Origin() { } else if serverName != request.Origin() {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The sender of the join must belong to the origin server"), JSON: jsonerror.Forbidden("The sender does not match the server that originated the request"),
} }
} }
@ -292,7 +284,7 @@ func SendJoin(
} }
} }
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
ServerName: event.Origin(), ServerName: serverName,
Message: redacted, Message: redacted,
AtTS: event.OriginServerTS(), AtTS: event.OriginServerTS(),
StrictValidityChecking: true, StrictValidityChecking: true,

View file

@ -118,6 +118,7 @@ func MakeLeave(
} }
// SendLeave implements the /send_leave API // SendLeave implements the /send_leave API
// nolint:gocyclo
func SendLeave( func SendLeave(
httpReq *http.Request, httpReq *http.Request,
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,
@ -167,14 +168,6 @@ func SendLeave(
} }
} }
// Check that the event is from the server sending the request.
if event.Origin() != request.Origin() {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The leave must be sent by the server it originated on"),
}
}
if event.StateKey() == nil || event.StateKeyEquals("") { if event.StateKey() == nil || event.StateKeyEquals("") {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -188,6 +181,22 @@ func SendLeave(
} }
} }
// Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both.
var serverName gomatrixserverlib.ServerName
if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The sender of the join is invalid"),
}
} else if serverName != request.Origin() {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("The sender does not match the server that originated the request"),
}
}
// Check if the user has already left. If so, no-op! // Check if the user has already left. If so, no-op!
queryReq := &api.QueryLatestEventsAndStateRequest{ queryReq := &api.QueryLatestEventsAndStateRequest{
RoomID: roomID, RoomID: roomID,
@ -240,7 +249,7 @@ func SendLeave(
} }
} }
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
ServerName: event.Origin(), ServerName: serverName,
Message: redacted, Message: redacted,
AtTS: event.OriginServerTS(), AtTS: event.OriginServerTS(),
StrictValidityChecking: true, StrictValidityChecking: true,

4
go.mod
View file

@ -26,8 +26,8 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220923115829-2217f6c65ce3 github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5
github.com/matrix-org/pinecone v0.0.0-20220923151905-0900fceecb89 github.com/matrix-org/pinecone v0.0.0-20220927101513-d0beb180f44d
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.15 github.com/mattn/go-sqlite3 v1.14.15
github.com/nats-io/nats-server/v2 v2.9.1 github.com/nats-io/nats-server/v2 v2.9.1

8
go.sum
View file

@ -437,10 +437,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220923115829-2217f6c65ce3 h1:u3FKZmXxfhv3XhD8RziBlt96QTt8eHFhg1upCloBh2g= github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5 h1:cQMA9hip0WSp6cv7CUfButa9Jl/9E6kqWmQyOjx5A5s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220923115829-2217f6c65ce3/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.0.0-20220923151905-0900fceecb89 h1:Ym50Fgn3GiYya4p29k3nJ5nYsalFGev3eIm3DeGNIq4= github.com/matrix-org/pinecone v0.0.0-20220927101513-d0beb180f44d h1:kGPJ6Rg8nn5an2CbCZrRiuTNyNzE0rRMiqm4UXJYrRs=
github.com/matrix-org/pinecone v0.0.0-20220923151905-0900fceecb89/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/pinecone v0.0.0-20220927101513-d0beb180f44d/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=

View file

@ -22,8 +22,9 @@ import (
"github.com/blevesearch/bleve/v2" "github.com/blevesearch/bleve/v2"
"github.com/blevesearch/bleve/v2/mapping" "github.com/blevesearch/bleve/v2/mapping"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/setup/config"
) )
// Search contains all existing bleve.Index // Search contains all existing bleve.Index

View file

@ -27,11 +27,11 @@ import (
func mustOpenIndex(t *testing.T, tempDir string) *fulltext.Search { func mustOpenIndex(t *testing.T, tempDir string) *fulltext.Search {
t.Helper() t.Helper()
cfg := config.Fulltext{} cfg := config.Fulltext{
cfg.Defaults(config.DefaultOpts{ Enabled: true,
Generate: true, InMemory: true,
Monolithic: true, Language: "en",
}) }
if tempDir != "" { if tempDir != "" {
cfg.IndexPath = config.Path(tempDir) cfg.IndexPath = config.Path(tempDir)
cfg.InMemory = false cfg.InMemory = false

View file

@ -118,6 +118,10 @@ func (r *Inputer) processRoomEvent(
if roomInfo == nil && !isCreateEvent { if roomInfo == nil && !isCreateEvent {
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
} }
_, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender())
if err != nil {
return fmt.Errorf("event has invalid sender %q", input.Event.Sender())
}
// If we already know about this outlier and it hasn't been rejected // If we already know about this outlier and it hasn't been rejected
// then we won't attempt to reprocess it. If it was rejected or has now // then we won't attempt to reprocess it. If it was rejected or has now
@ -145,7 +149,8 @@ func (r *Inputer) processRoomEvent(
var missingAuth, missingPrev bool var missingAuth, missingPrev bool
serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{} serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{}
if !isCreateEvent { if !isCreateEvent {
missingAuthIDs, missingPrevIDs, err := r.DB.MissingAuthPrevEvents(ctx, event) var missingAuthIDs, missingPrevIDs []string
missingAuthIDs, missingPrevIDs, err = r.DB.MissingAuthPrevEvents(ctx, event)
if err != nil { if err != nil {
return fmt.Errorf("updater.MissingAuthPrevEvents: %w", err) return fmt.Errorf("updater.MissingAuthPrevEvents: %w", err)
} }
@ -158,7 +163,7 @@ func (r *Inputer) processRoomEvent(
RoomID: event.RoomID(), RoomID: event.RoomID(),
ExcludeSelf: true, ExcludeSelf: true,
} }
if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
} }
// Sort all of the servers into a map so that we can randomise // Sort all of the servers into a map so that we can randomise
@ -173,9 +178,9 @@ func (r *Inputer) processRoomEvent(
serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) serverRes.ServerNames = append(serverRes.ServerNames, input.Origin)
delete(servers, input.Origin) delete(servers, input.Origin)
} }
if origin := event.Origin(); origin != input.Origin { if senderDomain != input.Origin {
serverRes.ServerNames = append(serverRes.ServerNames, origin) serverRes.ServerNames = append(serverRes.ServerNames, senderDomain)
delete(servers, origin) delete(servers, senderDomain)
} }
for server := range servers { for server := range servers {
serverRes.ServerNames = append(serverRes.ServerNames, server) serverRes.ServerNames = append(serverRes.ServerNames, server)
@ -188,7 +193,7 @@ func (r *Inputer) processRoomEvent(
isRejected := false isRejected := false
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
knownEvents := map[string]*types.Event{} knownEvents := map[string]*types.Event{}
if err := r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
return fmt.Errorf("r.fetchAuthEvents: %w", err) return fmt.Errorf("r.fetchAuthEvents: %w", err)
} }
@ -231,7 +236,6 @@ func (r *Inputer) processRoomEvent(
if input.Kind == api.KindNew { if input.Kind == api.KindNew {
// Check that the event passes authentication checks based on the // Check that the event passes authentication checks based on the
// current room state. // current room state.
var err error
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs)
if err != nil { if err != nil {
logger.WithError(err).Warn("Error authing soft-failed event") logger.WithError(err).Warn("Error authing soft-failed event")
@ -265,7 +269,8 @@ func (r *Inputer) processRoomEvent(
hadEvents: map[string]bool{}, hadEvents: map[string]bool{},
haveEvents: map[string]*gomatrixserverlib.Event{}, haveEvents: map[string]*gomatrixserverlib.Event{},
} }
if stateSnapshot, err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { var stateSnapshot *parsedRespState
if stateSnapshot, err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil {
// Something went wrong with retrieving the missing state, so we can't // Something went wrong with retrieving the missing state, so we can't
// really do anything with the event other than reject it at this point. // really do anything with the event other than reject it at this point.
isRejected = true isRejected = true
@ -302,7 +307,6 @@ func (r *Inputer) processRoomEvent(
// burning CPU time. // burning CPU time.
historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared. historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected { if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected {
var err error
historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev) historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev)
if err != nil { if err != nil {
return fmt.Errorf("r.processStateBefore: %w", err) return fmt.Errorf("r.processStateBefore: %w", err)

View file

@ -468,7 +468,9 @@ FindSuccessor:
// Store the server names in a temporary map to avoid duplicates. // Store the server names in a temporary map to avoid duplicates.
serverSet := make(map[gomatrixserverlib.ServerName]bool) serverSet := make(map[gomatrixserverlib.ServerName]bool)
for _, event := range memberEvents { for _, event := range memberEvents {
serverSet[event.Origin()] = true if _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()); err == nil {
serverSet[senderDomain] = true
}
} }
var servers []gomatrixserverlib.ServerName var servers []gomatrixserverlib.ServerName
for server := range serverSet { for server := range serverSet {

View file

@ -50,6 +50,10 @@ func (r *Inviter) PerformInvite(
if event.StateKey() == nil { if event.StateKey() == nil {
return nil, fmt.Errorf("invite must be a state event") return nil, fmt.Errorf("invite must be a state event")
} }
_, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender())
if err != nil {
return nil, fmt.Errorf("sender %q is invalid", event.Sender())
}
roomID := event.RoomID() roomID := event.RoomID()
targetUserID := *event.StateKey() targetUserID := *event.StateKey()
@ -67,7 +71,7 @@ func (r *Inviter) PerformInvite(
return nil, nil return nil, nil
} }
isTargetLocal := domain == r.Cfg.Matrix.ServerName isTargetLocal := domain == r.Cfg.Matrix.ServerName
isOriginLocal := event.Origin() == r.Cfg.Matrix.ServerName isOriginLocal := senderDomain == r.Cfg.Matrix.ServerName
if !isOriginLocal && !isTargetLocal { if !isOriginLocal && !isTargetLocal {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
@ -235,7 +239,7 @@ func (r *Inviter) PerformInvite(
{ {
Kind: api.KindNew, Kind: api.KindNew,
Event: event, Event: event,
Origin: event.Origin(), Origin: senderDomain,
SendAsServer: req.SendAsServer, SendAsServer: req.SendAsServer,
}, },
}, },

View file

@ -81,12 +81,11 @@ func (r *Leaver) performLeaveRoomByID(
// that. // that.
isInvitePending, senderUser, eventID, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID) isInvitePending, senderUser, eventID, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID)
if err == nil && isInvitePending { if err == nil && isInvitePending {
var host gomatrixserverlib.ServerName _, senderDomain, serr := gomatrixserverlib.SplitID('@', senderUser)
_, host, err = gomatrixserverlib.SplitID('@', senderUser) if serr != nil {
if err != nil {
return nil, fmt.Errorf("sender %q is invalid", senderUser) return nil, fmt.Errorf("sender %q is invalid", senderUser)
} }
if host != r.Cfg.Matrix.ServerName { if senderDomain != r.Cfg.Matrix.ServerName {
return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID)
} }
// check that this is not a "server notice room" // check that this is not a "server notice room"
@ -172,6 +171,12 @@ func (r *Leaver) performLeaveRoomByID(
return nil, fmt.Errorf("eventutil.BuildEvent: %w", err) return nil, fmt.Errorf("eventutil.BuildEvent: %w", err)
} }
// Get the sender domain.
_, senderDomain, serr := gomatrixserverlib.SplitID('@', event.Sender())
if serr != nil {
return nil, fmt.Errorf("sender %q is invalid", event.Sender())
}
// Give our leave event to the roomserver input stream. The // Give our leave event to the roomserver input stream. The
// roomserver will process the membership change and notify // roomserver will process the membership change and notify
// downstream automatically. // downstream automatically.
@ -180,7 +185,7 @@ func (r *Leaver) performLeaveRoomByID(
{ {
Kind: api.KindNew, Kind: api.KindNew,
Event: event.Headered(buildRes.RoomVersion), Event: event.Headered(buildRes.RoomVersion),
Origin: event.Origin(), Origin: senderDomain,
SendAsServer: string(r.Cfg.Matrix.ServerName), SendAsServer: string(r.Cfg.Matrix.ServerName),
}, },
}, },

View file

@ -896,7 +896,7 @@ func (d *Database) handleRedactions(
switch { switch {
case redactUser >= pl.Redact: case redactUser >= pl.Redact:
// The power level of the redaction events sender is greater than or equal to the redact level. // The power level of the redaction events sender is greater than or equal to the redact level.
case redactedEvent.Origin() == redactionEvent.Origin() && redactedEvent.Sender() == redactionEvent.Sender(): case redactedEvent.Sender() == redactionEvent.Sender():
// The domain of the redaction events sender matches that of the original events sender. // The domain of the redaction events sender matches that of the original events sender.
default: default:
return nil, "", nil return nil, "", nil

View file

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
# #
# Runs SyTest either from Docker Hub, or from ../sytest. If it's run # Runs SyTest either from Docker Hub, or from ../sytest. If it's run
# locally, the Docker image is rebuilt first. # locally, the Docker image is rebuilt first.

View file

@ -37,16 +37,13 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/kardianos/minwinsvc" "github.com/kardianos/minwinsvc"
@ -61,6 +58,8 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
rsinthttp "github.com/matrix-org/dendrite/roomserver/inthttp" rsinthttp "github.com/matrix-org/dendrite/roomserver/inthttp"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp" userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp"
) )
@ -392,17 +391,26 @@ func (b *BaseDendrite) configureHTTPErrors() {
_, _ = w.Write([]byte(fmt.Sprintf("405 %s not allowed on this endpoint", r.Method))) _, _ = w.Write([]byte(fmt.Sprintf("405 %s not allowed on this endpoint", r.Method)))
} }
clientNotFoundHandler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}`)) // nolint:misspell
}
notFoundCORSHandler := httputil.WrapHandlerInCORS(http.NotFoundHandler()) notFoundCORSHandler := httputil.WrapHandlerInCORS(http.NotFoundHandler())
notAllowedCORSHandler := httputil.WrapHandlerInCORS(http.HandlerFunc(notAllowedHandler)) notAllowedCORSHandler := httputil.WrapHandlerInCORS(http.HandlerFunc(notAllowedHandler))
for _, router := range []*mux.Router{ for _, router := range []*mux.Router{
b.PublicClientAPIMux, b.PublicMediaAPIMux, b.PublicMediaAPIMux, b.DendriteAdminMux,
b.DendriteAdminMux, b.SynapseAdminMux, b.SynapseAdminMux, b.PublicWellKnownAPIMux,
b.PublicWellKnownAPIMux,
} { } {
router.NotFoundHandler = notFoundCORSHandler router.NotFoundHandler = notFoundCORSHandler
router.MethodNotAllowedHandler = notAllowedCORSHandler router.MethodNotAllowedHandler = notAllowedCORSHandler
} }
// Special case so that we don't upset clients on the CS API.
b.PublicClientAPIMux.NotFoundHandler = http.HandlerFunc(clientNotFoundHandler)
b.PublicClientAPIMux.MethodNotAllowedHandler = http.HandlerFunc(clientNotFoundHandler)
} }
// SetupAndServeHTTP sets up the HTTP server to serve endpoints registered on // SetupAndServeHTTP sets up the HTTP server to serve endpoints registered on

View file

@ -24,6 +24,7 @@ type EthereumAuthConfig struct {
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
Version uint `yaml:"version"` Version uint `yaml:"version"`
ChainIDs []int `yaml:"chain_ids"` ChainIDs []int `yaml:"chain_ids"`
EnableAuthz bool `yaml:"enable_authz"` // Flag to enable / disable authorization during development
} }
type PublicKeyAuthentication struct { type PublicKeyAuthentication struct {

View file

@ -10,7 +10,7 @@ type SyncAPI struct {
RealIPHeader string `yaml:"real_ip_header"` RealIPHeader string `yaml:"real_ip_header"`
Fulltext Fulltext `yaml:"fulltext"` Fulltext Fulltext `yaml:"search"`
} }
func (c *SyncAPI) Defaults(opts DefaultOpts) { func (c *SyncAPI) Defaults(opts DefaultOpts) {
@ -50,18 +50,14 @@ type Fulltext struct {
func (f *Fulltext) Defaults(opts DefaultOpts) { func (f *Fulltext) Defaults(opts DefaultOpts) {
f.Enabled = false f.Enabled = false
f.IndexPath = "./fulltextindex" f.IndexPath = "./searchindex"
f.Language = "en" f.Language = "en"
if opts.Generate {
f.Enabled = true
f.InMemory = true
}
} }
func (f *Fulltext) Verify(configErrs *ConfigErrors, isMonolith bool) { func (f *Fulltext) Verify(configErrs *ConfigErrors, isMonolith bool) {
if !f.Enabled { if !f.Enabled {
return return
} }
checkNotEmpty(configErrs, "syncapi.fulltext.index_path", string(f.IndexPath)) checkNotEmpty(configErrs, "syncapi.search.index_path", string(f.IndexPath))
checkNotEmpty(configErrs, "syncapi.fulltext.language", f.Language) checkNotEmpty(configErrs, "syncapi.search.language", f.Language)
} }

View file

@ -9,9 +9,10 @@ import (
"time" "time"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/sirupsen/logrus"
natsserver "github.com/nats-io/nats-server/v2/server" natsserver "github.com/nats-io/nats-server/v2/server"
natsclient "github.com/nats-io/nats.go" natsclient "github.com/nats-io/nats.go"
@ -184,6 +185,8 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"}, OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"},
OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"}, OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"},
OutputRoomEvent: {"AppserviceRoomserverConsumer"}, OutputRoomEvent: {"AppserviceRoomserverConsumer"},
OutputStreamEvent: {"UserAPISyncAPIStreamEventConsumer"},
OutputReadUpdate: {"UserAPISyncAPIReadUpdateConsumer"},
} { } {
streamName := cfg.Matrix.JetStream.Prefixed(stream) streamName := cfg.Matrix.JetStream.Prefixed(stream)
for _, consumer := range consumers { for _, consumer := range consumers {

View file

@ -94,16 +94,6 @@ var streams = []*nats.StreamConfig{
Retention: nats.InterestPolicy, Retention: nats.InterestPolicy,
Storage: nats.FileStorage, Storage: nats.FileStorage,
}, },
{
Name: OutputStreamEvent,
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
{
Name: OutputReadUpdate,
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
{ {
Name: OutputPresenceEvent, Name: OutputPresenceEvent,
Retention: nats.InterestPolicy, Retention: nats.InterestPolicy,

View file

@ -1,4 +1,4 @@
#! /bin/bash #!/usr/bin/env bash
# #
# Parses a results.tap file from SyTest output and a file containing test names (a test whitelist) # Parses a results.tap file from SyTest output and a file containing test names (a test whitelist)
# and checks whether a test name that exists in the whitelist (that should pass), failed or not. # and checks whether a test name that exists in the whitelist (that should pass), failed or not.

View file

@ -16,22 +16,23 @@ package consumers
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"fmt" "strings"
"time"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
@ -40,13 +41,16 @@ import (
type OutputClientDataConsumer struct { type OutputClientDataConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
nats *nats.Conn
durable string durable string
topic string topic string
topicReIndex string
db storage.Database db storage.Database
stream types.StreamProvider stream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
producer *producers.UserAPIReadProducer fts *fulltext.Search
cfg *config.SyncAPI
} }
// NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers. // NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers.
@ -54,26 +58,93 @@ func NewOutputClientDataConsumer(
process *process.ProcessContext, process *process.ProcessContext,
cfg *config.SyncAPI, cfg *config.SyncAPI,
js nats.JetStreamContext, js nats.JetStreamContext,
nats *nats.Conn,
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream types.StreamProvider,
producer *producers.UserAPIReadProducer, fts *fulltext.Search,
) *OutputClientDataConsumer { ) *OutputClientDataConsumer {
return &OutputClientDataConsumer{ return &OutputClientDataConsumer{
ctx: process.Context(), ctx: process.Context(),
jetstream: js, jetstream: js,
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData),
topicReIndex: cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex),
durable: cfg.Matrix.JetStream.Durable("SyncAPIAccountDataConsumer"), durable: cfg.Matrix.JetStream.Durable("SyncAPIAccountDataConsumer"),
nats: nats,
db: store, db: store,
notifier: notifier, notifier: notifier,
stream: stream, stream: stream,
serverName: cfg.Matrix.ServerName, serverName: cfg.Matrix.ServerName,
producer: producer, fts: fts,
cfg: cfg,
} }
} }
// Start consuming from room servers // Start consuming from room servers
func (s *OutputClientDataConsumer) Start() error { func (s *OutputClientDataConsumer) Start() error {
_, err := s.nats.Subscribe(s.topicReIndex, func(msg *nats.Msg) {
if err := msg.Ack(); err != nil {
return
}
if !s.cfg.Fulltext.Enabled {
logrus.Warn("Fulltext indexing is disabled")
return
}
ctx := context.Background()
logrus.Infof("Starting to index events")
var offset int
start := time.Now()
count := 0
var id int64 = 0
for {
evs, err := s.db.ReIndex(ctx, 1000, id)
if err != nil {
logrus.WithError(err).Errorf("unable to get events to index")
return
}
if len(evs) == 0 {
break
}
logrus.Debugf("Indexing %d events", len(evs))
elements := make([]fulltext.IndexElement, 0, len(evs))
for streamPos, ev := range evs {
id = streamPos
e := fulltext.IndexElement{
EventID: ev.EventID(),
RoomID: ev.RoomID(),
StreamPosition: streamPos,
}
e.SetContentType(ev.Type())
switch ev.Type() {
case "m.room.message":
e.Content = gjson.GetBytes(ev.Content(), "body").String()
case gomatrixserverlib.MRoomName:
e.Content = gjson.GetBytes(ev.Content(), "name").String()
case gomatrixserverlib.MRoomTopic:
e.Content = gjson.GetBytes(ev.Content(), "topic").String()
default:
continue
}
if strings.TrimSpace(e.Content) == "" {
continue
}
elements = append(elements, e)
}
if err = s.fts.Index(elements...); err != nil {
logrus.WithError(err).Error("unable to index events")
continue
}
offset += len(evs)
count += len(elements)
}
logrus.Infof("Indexed %d events in %v", count, time.Since(start))
})
if err != nil {
return err
}
return jetstream.JetStreamConsumer( return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, 1, s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
@ -113,15 +184,6 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M
return false return false
} }
if err = s.sendReadUpdate(ctx, userID, output); err != nil {
log.WithError(err).WithFields(logrus.Fields{
"user_id": userID,
"room_id": output.RoomID,
}).Errorf("Failed to generate read update")
sentry.CaptureException(err)
return false
}
if output.IgnoredUsers != nil { if output.IgnoredUsers != nil {
if err := s.db.UpdateIgnoresForUser(ctx, userID, output.IgnoredUsers); err != nil { if err := s.db.UpdateIgnoresForUser(ctx, userID, output.IgnoredUsers); err != nil {
log.WithError(err).WithFields(logrus.Fields{ log.WithError(err).WithFields(logrus.Fields{
@ -136,34 +198,3 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M
return true return true
} }
func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error {
if output.Type != "m.fully_read" || output.ReadMarker == nil {
return nil
}
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if serverName != s.serverName {
return nil
}
var readPos types.StreamPosition
var fullyReadPos types.StreamPosition
if output.ReadMarker.Read != "" {
if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
}
}
if output.ReadMarker.FullyRead != "" {
if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err)
}
}
if readPos > 0 || fullyReadPos > 0 {
if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil {
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
}
}
return nil
}

View file

@ -16,22 +16,19 @@ package consumers
import ( import (
"context" "context"
"database/sql"
"fmt"
"strconv" "strconv"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
) )
// OutputReceiptEventConsumer consumes events that originated in the EDU server. // OutputReceiptEventConsumer consumes events that originated in the EDU server.
@ -44,7 +41,6 @@ type OutputReceiptEventConsumer struct {
stream types.StreamProvider stream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
producer *producers.UserAPIReadProducer
} }
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
@ -56,7 +52,6 @@ func NewOutputReceiptEventConsumer(
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream types.StreamProvider,
producer *producers.UserAPIReadProducer,
) *OutputReceiptEventConsumer { ) *OutputReceiptEventConsumer {
return &OutputReceiptEventConsumer{ return &OutputReceiptEventConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -67,7 +62,6 @@ func NewOutputReceiptEventConsumer(
notifier: notifier, notifier: notifier,
stream: stream, stream: stream,
serverName: cfg.Matrix.ServerName, serverName: cfg.Matrix.ServerName,
producer: producer,
} }
} }
@ -111,42 +105,8 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
return true return true
} }
if err = s.sendReadUpdate(ctx, output); err != nil {
log.WithError(err).WithFields(logrus.Fields{
"user_id": output.UserID,
"room_id": output.RoomID,
}).Errorf("Failed to generate read update")
sentry.CaptureException(err)
return false
}
s.stream.Advance(streamPos) s.stream.Advance(streamPos)
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
return true return true
} }
func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output types.OutputReceiptEvent) error {
if output.Type != "m.read" {
return nil
}
_, serverName, err := gomatrixserverlib.SplitID('@', output.UserID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if serverName != s.serverName {
return nil
}
var readPos types.StreamPosition
if output.EventID != "" {
if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
}
}
if readPos > 0 {
if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil {
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
}
}
return nil
}

View file

@ -21,17 +21,19 @@ import (
"fmt" "fmt"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
) )
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
@ -46,7 +48,7 @@ type OutputRoomEventConsumer struct {
pduStream types.StreamProvider pduStream types.StreamProvider
inviteStream types.StreamProvider inviteStream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
producer *producers.UserAPIStreamEventProducer fts *fulltext.Search
} }
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -59,7 +61,7 @@ func NewOutputRoomEventConsumer(
pduStream types.StreamProvider, pduStream types.StreamProvider,
inviteStream types.StreamProvider, inviteStream types.StreamProvider,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
producer *producers.UserAPIStreamEventProducer, fts *fulltext.Search,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -72,7 +74,7 @@ func NewOutputRoomEventConsumer(
pduStream: pduStream, pduStream: pduStream,
inviteStream: inviteStream, inviteStream: inviteStream,
rsAPI: rsAPI, rsAPI: rsAPI,
producer: producer, fts: fts,
} }
} }
@ -254,11 +256,11 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
}).Panicf("roomserver output log: write new event failure") }).Panicf("roomserver output log: write new event failure")
return nil return nil
} }
if err = s.writeFTS(ev, pduPos); err != nil {
if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil { log.WithFields(log.Fields{
log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID()) "event_id": ev.EventID(),
sentry.CaptureException(err) "type": ev.Type(),
return err }).WithError(err).Warn("failed to index fulltext element")
} }
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
@ -304,6 +306,13 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent(
return nil return nil
} }
if err = s.writeFTS(ev, pduPos); err != nil {
log.WithFields(log.Fields{
"event_id": ev.EventID(),
"type": ev.Type(),
}).WithError(err).Warn("failed to index fulltext element")
}
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
return err return err
@ -460,3 +469,39 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.Head
event.Event, err = event.SetUnsigned(prev) event.Event, err = event.SetUnsigned(prev)
return event, err return event, err
} }
func (s *OutputRoomEventConsumer) writeFTS(ev *gomatrixserverlib.HeaderedEvent, pduPosition types.StreamPosition) error {
if !s.cfg.Fulltext.Enabled {
return nil
}
e := fulltext.IndexElement{
EventID: ev.EventID(),
RoomID: ev.RoomID(),
StreamPosition: int64(pduPosition),
}
e.SetContentType(ev.Type())
switch ev.Type() {
case "m.room.message":
e.Content = gjson.GetBytes(ev.Content(), "body").String()
case gomatrixserverlib.MRoomName:
e.Content = gjson.GetBytes(ev.Content(), "name").String()
case gomatrixserverlib.MRoomTopic:
e.Content = gjson.GetBytes(ev.Content(), "topic").String()
case gomatrixserverlib.MRoomRedaction:
log.Tracef("Redacting event: %s", ev.Redacts())
if err := s.fts.Delete(ev.Redacts()); err != nil {
return fmt.Errorf("failed to delete entry from fulltext index: %w", err)
}
return nil
default:
return nil
}
if e.Content != "" {
log.Tracef("Indexing element: %+v", e)
if err := s.fts.Index(e); err != nil {
return err
}
}
return nil
}

View file

@ -19,6 +19,9 @@ import (
"encoding/json" "encoding/json"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
@ -26,8 +29,6 @@ import (
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
) )
// OutputNotificationDataConsumer consumes events that originated in // OutputNotificationDataConsumer consumes events that originated in

View file

@ -1,62 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package producers
import (
"encoding/json"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// UserAPIProducer produces events for the user API server to consume
type UserAPIReadProducer struct {
Topic string
JetStream nats.JetStreamContext
}
// SendData sends account data to the user API server
func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error {
m := &nats.Msg{
Subject: p.Topic,
Header: nats.Header{},
}
m.Header.Set(jetstream.UserID, userID)
m.Header.Set(jetstream.RoomID, roomID)
data := types.ReadUpdate{
UserID: userID,
RoomID: roomID,
Read: readPos,
FullyRead: fullyReadPos,
}
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"user_id": userID,
"room_id": roomID,
"read_pos": readPos,
"fully_read_pos": fullyReadPos,
}).Tracef("Producing to topic '%s'", p.Topic)
_, err = p.JetStream.PublishMsg(m)
return err
}

View file

@ -1,60 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package producers
import (
"encoding/json"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// UserAPIProducer produces events for the user API server to consume
type UserAPIStreamEventProducer struct {
Topic string
JetStream nats.JetStreamContext
}
// SendData sends account data to the user API server
func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error {
m := &nats.Msg{
Subject: p.Topic,
Header: nats.Header{},
}
m.Header.Set(jetstream.RoomID, roomID)
data := types.StreamedEvent{
Event: event,
StreamPosition: pos,
}
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"room_id": roomID,
"event_id": event.EventID(),
"event_type": event.Type(),
"stream_pos": pos,
}).Tracef("Producing to topic '%s'", p.Topic)
_, err = p.JetStream.PublishMsg(m)
return err
}

View file

@ -37,11 +37,11 @@ import (
type ContextRespsonse struct { type ContextRespsonse struct {
End string `json:"end"` End string `json:"end"`
Event gomatrixserverlib.ClientEvent `json:"event"` Event *gomatrixserverlib.ClientEvent `json:"event,omitempty"`
EventsAfter []gomatrixserverlib.ClientEvent `json:"events_after,omitempty"` EventsAfter []gomatrixserverlib.ClientEvent `json:"events_after,omitempty"`
EventsBefore []gomatrixserverlib.ClientEvent `json:"events_before,omitempty"` EventsBefore []gomatrixserverlib.ClientEvent `json:"events_before,omitempty"`
Start string `json:"start"` Start string `json:"start"`
State []gomatrixserverlib.ClientEvent `json:"state"` State []gomatrixserverlib.ClientEvent `json:"state,omitempty"`
} }
func Context( func Context(
@ -162,8 +162,9 @@ func Context(
eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll) eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll)
newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache) newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache)
ev := gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll)
response := ContextRespsonse{ response := ContextRespsonse{
Event: gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll), Event: &ev,
EventsAfter: eventsAfterClient, EventsAfter: eventsAfterClient,
EventsBefore: eventsBeforeClient, EventsBefore: eventsBeforeClient,
State: gomatrixserverlib.HeaderedToClientEvents(newState, gomatrixserverlib.FormatAll), State: gomatrixserverlib.HeaderedToClientEvents(newState, gomatrixserverlib.FormatAll),

View file

@ -18,15 +18,18 @@ import (
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/sync"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
) )
// Setup configures the given mux with sync-server listeners // Setup configures the given mux with sync-server listeners
@ -40,6 +43,7 @@ func Setup(
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
cfg *config.SyncAPI, cfg *config.SyncAPI,
lazyLoadCache caching.LazyLoadCache, lazyLoadCache caching.LazyLoadCache,
fts *fulltext.Search,
) { ) {
v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
@ -95,4 +99,24 @@ func Setup(
) )
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/search",
httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if !cfg.Fulltext.Enabled {
return util.JSONResponse{
Code: http.StatusNotImplemented,
JSON: jsonerror.Unknown("Search has been disabled by the server administrator."),
}
}
var nextBatch *string
if err := req.ParseForm(); err != nil {
return jsonerror.InternalServerError()
}
if req.Form.Has("next_batch") {
nb := req.FormValue("next_batch")
nextBatch = &nb
}
return Search(req, device, syncDB, fts, nextBatch)
}),
).Methods(http.MethodPost, http.MethodOptions)
} }

344
syncapi/routing/search.go Normal file
View file

@ -0,0 +1,344 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"context"
"net/http"
"sort"
"strconv"
"strings"
"time"
"github.com/blevesearch/bleve/v2/search"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/userapi/api"
)
// nolint:gocyclo
func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts *fulltext.Search, from *string) util.JSONResponse {
start := time.Now()
var (
searchReq SearchRequest
err error
ctx = req.Context()
)
resErr := httputil.UnmarshalJSONRequest(req, &searchReq)
if resErr != nil {
logrus.Error("failed to unmarshal search request")
return *resErr
}
nextBatch := 0
if from != nil && *from != "" {
nextBatch, err = strconv.Atoi(*from)
if err != nil {
return jsonerror.InternalServerError()
}
}
if searchReq.SearchCategories.RoomEvents.Filter.Limit == 0 {
searchReq.SearchCategories.RoomEvents.Filter.Limit = 5
}
// only search rooms the user is actually joined to
joinedRooms, err := syncDB.RoomIDsWithMembership(ctx, device.UserID, "join")
if err != nil {
return jsonerror.InternalServerError()
}
if len(joinedRooms) == 0 {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("User not joined to any rooms."),
}
}
joinedRoomsMap := make(map[string]struct{}, len(joinedRooms))
for _, roomID := range joinedRooms {
joinedRoomsMap[roomID] = struct{}{}
}
rooms := []string{}
if searchReq.SearchCategories.RoomEvents.Filter.Rooms != nil {
for _, roomID := range *searchReq.SearchCategories.RoomEvents.Filter.Rooms {
if _, ok := joinedRoomsMap[roomID]; ok {
rooms = append(rooms, roomID)
}
}
} else {
rooms = joinedRooms
}
if len(rooms) == 0 {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Unknown("User not allowed to search in this room(s)."),
}
}
orderByTime := searchReq.SearchCategories.RoomEvents.OrderBy == "recent"
result, err := fts.Search(
searchReq.SearchCategories.RoomEvents.SearchTerm,
rooms,
searchReq.SearchCategories.RoomEvents.Keys,
searchReq.SearchCategories.RoomEvents.Filter.Limit,
nextBatch,
orderByTime,
)
if err != nil {
logrus.WithError(err).Error("failed to search fulltext")
return jsonerror.InternalServerError()
}
logrus.Debugf("Search took %s", result.Took)
// From was specified but empty, return no results, only the count
if from != nil && *from == "" {
return util.JSONResponse{
Code: http.StatusOK,
JSON: SearchResponse{
SearchCategories: SearchCategories{
RoomEvents: RoomEvents{
Count: int(result.Total),
NextBatch: nil,
},
},
},
}
}
results := []Result{}
wantEvents := make([]string, 0, len(result.Hits))
eventScore := make(map[string]*search.DocumentMatch)
for _, hit := range result.Hits {
wantEvents = append(wantEvents, hit.ID)
eventScore[hit.ID] = hit
}
// Filter on m.room.message, as otherwise we also get events like m.reaction
// which "breaks" displaying results in Element Web.
types := []string{"m.room.message"}
roomFilter := &gomatrixserverlib.RoomEventFilter{
Rooms: &rooms,
Types: &types,
}
evs, err := syncDB.Events(ctx, wantEvents)
if err != nil {
logrus.WithError(err).Error("failed to get events from database")
return jsonerror.InternalServerError()
}
groups := make(map[string]RoomResult)
knownUsersProfiles := make(map[string]ProfileInfo)
// Sort the events by depth, as the returned values aren't ordered
if orderByTime {
sort.Slice(evs, func(i, j int) bool {
return evs[i].Depth() > evs[j].Depth()
})
}
stateForRooms := make(map[string][]gomatrixserverlib.ClientEvent)
for _, event := range evs {
eventsBefore, eventsAfter, err := contextEvents(ctx, syncDB, event, roomFilter, searchReq)
if err != nil {
logrus.WithError(err).Error("failed to get context events")
return jsonerror.InternalServerError()
}
startToken, endToken, err := getStartEnd(ctx, syncDB, eventsBefore, eventsAfter)
if err != nil {
logrus.WithError(err).Error("failed to get start/end")
return jsonerror.InternalServerError()
}
profileInfos := make(map[string]ProfileInfo)
for _, ev := range append(eventsBefore, eventsAfter...) {
profile, ok := knownUsersProfiles[event.Sender()]
if !ok {
stateEvent, err := syncDB.GetStateEvent(ctx, ev.RoomID(), gomatrixserverlib.MRoomMember, ev.Sender())
if err != nil {
logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile")
continue
}
if stateEvent == nil {
continue
}
profile = ProfileInfo{
AvatarURL: gjson.GetBytes(stateEvent.Content(), "avatar_url").Str,
DisplayName: gjson.GetBytes(stateEvent.Content(), "displayname").Str,
}
knownUsersProfiles[event.Sender()] = profile
}
profileInfos[ev.Sender()] = profile
}
results = append(results, Result{
Context: SearchContextResponse{
Start: startToken.String(),
End: endToken.String(),
EventsAfter: gomatrixserverlib.HeaderedToClientEvents(eventsAfter, gomatrixserverlib.FormatSync),
EventsBefore: gomatrixserverlib.HeaderedToClientEvents(eventsBefore, gomatrixserverlib.FormatSync),
ProfileInfo: profileInfos,
},
Rank: eventScore[event.EventID()].Score,
Result: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll),
})
roomGroup := groups[event.RoomID()]
roomGroup.Results = append(roomGroup.Results, event.EventID())
groups[event.RoomID()] = roomGroup
if _, ok := stateForRooms[event.RoomID()]; searchReq.SearchCategories.RoomEvents.IncludeState && !ok {
stateFilter := gomatrixserverlib.DefaultStateFilter()
state, err := syncDB.CurrentState(ctx, event.RoomID(), &stateFilter, nil)
if err != nil {
logrus.WithError(err).Error("unable to get current state")
return jsonerror.InternalServerError()
}
stateForRooms[event.RoomID()] = gomatrixserverlib.HeaderedToClientEvents(state, gomatrixserverlib.FormatSync)
}
}
var nextBatchResult *string = nil
if int(result.Total) > nextBatch+len(results) {
nb := strconv.Itoa(len(results) + nextBatch)
nextBatchResult = &nb
} else if int(result.Total) == nextBatch+len(results) {
// Sytest expects a next_batch even if we don't actually have any more results
nb := ""
nextBatchResult = &nb
}
res := SearchResponse{
SearchCategories: SearchCategories{
RoomEvents: RoomEvents{
Count: int(result.Total),
Groups: Groups{RoomID: groups},
Results: results,
NextBatch: nextBatchResult,
Highlights: strings.Split(searchReq.SearchCategories.RoomEvents.SearchTerm, " "),
State: stateForRooms,
},
},
}
logrus.Debugf("Full search request took %v", time.Since(start))
return util.JSONResponse{
Code: http.StatusOK,
JSON: res,
}
}
// contextEvents returns the events around a given eventID
func contextEvents(
ctx context.Context,
syncDB storage.Database,
event *gomatrixserverlib.HeaderedEvent,
roomFilter *gomatrixserverlib.RoomEventFilter,
searchReq SearchRequest,
) ([]*gomatrixserverlib.HeaderedEvent, []*gomatrixserverlib.HeaderedEvent, error) {
id, _, err := syncDB.SelectContextEvent(ctx, event.RoomID(), event.EventID())
if err != nil {
logrus.WithError(err).Error("failed to query context event")
return nil, nil, err
}
roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.BeforeLimit
eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, event.RoomID(), roomFilter)
if err != nil {
logrus.WithError(err).Error("failed to query before context event")
return nil, nil, err
}
roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.AfterLimit
_, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, event.RoomID(), roomFilter)
if err != nil {
logrus.WithError(err).Error("failed to query after context event")
return nil, nil, err
}
return eventsBefore, eventsAfter, err
}
type SearchRequest struct {
SearchCategories struct {
RoomEvents struct {
EventContext struct {
AfterLimit int `json:"after_limit,omitempty"`
BeforeLimit int `json:"before_limit,omitempty"`
IncludeProfile bool `json:"include_profile,omitempty"`
} `json:"event_context"`
Filter gomatrixserverlib.StateFilter `json:"filter"`
Groupings struct {
GroupBy []struct {
Key string `json:"key"`
} `json:"group_by"`
} `json:"groupings"`
IncludeState bool `json:"include_state"`
Keys []string `json:"keys"`
OrderBy string `json:"order_by"`
SearchTerm string `json:"search_term"`
} `json:"room_events"`
} `json:"search_categories"`
}
type SearchResponse struct {
SearchCategories SearchCategories `json:"search_categories"`
}
type RoomResult struct {
NextBatch *string `json:"next_batch,omitempty"`
Order int `json:"order"`
Results []string `json:"results"`
}
type Groups struct {
RoomID map[string]RoomResult `json:"room_id"`
}
type Result struct {
Context SearchContextResponse `json:"context"`
Rank float64 `json:"rank"`
Result gomatrixserverlib.ClientEvent `json:"result"`
}
type SearchContextResponse struct {
End string `json:"end"`
EventsAfter []gomatrixserverlib.ClientEvent `json:"events_after"`
EventsBefore []gomatrixserverlib.ClientEvent `json:"events_before"`
Start string `json:"start"`
ProfileInfo map[string]ProfileInfo `json:"profile_info"`
}
type ProfileInfo struct {
AvatarURL string `json:"avatar_url"`
DisplayName string `json:"display_name"`
}
type RoomEvents struct {
Count int `json:"count"`
Groups Groups `json:"groups"`
Highlights []string `json:"highlights"`
NextBatch *string `json:"next_batch,omitempty"`
Results []Result `json:"results"`
State map[string][]gomatrixserverlib.ClientEvent `json:"state,omitempty"`
}
type SearchCategories struct {
RoomEvents RoomEvents `json:"room_events"`
}

View file

@ -29,6 +29,7 @@ import (
type Database interface { type Database interface {
Presence Presence
SharedUsers SharedUsers
Notifications
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
@ -148,12 +149,6 @@ type Database interface {
// GetRoomReceipts gets all receipts for a given roomID // GetRoomReceipts gets all receipts for a given roomID
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error)
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
// GetUserUnreadNotificationCounts returns statistics per room a user is interested in.
GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error)
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
@ -166,6 +161,7 @@ type Database interface {
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership. // string as the membership.
SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error)
} }
type Presence interface { type Presence interface {
@ -179,3 +175,11 @@ type SharedUsers interface {
// SharedUsers returns a subset of otherUserIDs that share a room with userID. // SharedUsers returns a subset of otherUserIDs that share a room with userID.
SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error)
} }
type Notifications interface {
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
// getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
}

View file

@ -99,14 +99,15 @@ func (s *accountDataStatements) InsertAccountData(
} }
func (s *accountDataStatements) SelectAccountDataInRange( func (s *accountDataStatements) SelectAccountDataInRange(
ctx context.Context, ctx context.Context, txn *sql.Tx,
userID string, userID string,
r types.Range, r types.Range,
accountDataEventFilter *gomatrixserverlib.EventFilter, accountDataEventFilter *gomatrixserverlib.EventFilter,
) (data map[string][]string, pos types.StreamPosition, err error) { ) (data map[string][]string, pos types.StreamPosition, err error) {
data = make(map[string][]string) data = make(map[string][]string)
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(), rows, err := sqlutil.TxStmt(txn, s.selectAccountDataInRangeStmt).QueryContext(
ctx, userID, r.Low(), r.High(),
pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.NotTypes)), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.NotTypes)),
accountDataEventFilter.Limit, accountDataEventFilter.Limit,

View file

@ -79,9 +79,9 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
} }
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
ctx context.Context, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (bwExtrems map[string][]string, err error) { ) (bwExtrems map[string][]string, err error) {
rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) rows, err := sqlutil.TxStmt(txn, s.selectBackwardExtremitiesForRoomStmt).QueryContext(ctx, roomID)
if err != nil { if err != nil {
return return
} }

View file

@ -62,6 +62,8 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_s
CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
-- for querying state by event IDs -- for querying state by event IDs
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id); CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id);
-- for improving selectRoomIDsWithAnyMembershipSQL
CREATE INDEX IF NOT EXISTS syncapi_current_room_state_type_state_key_idx ON syncapi_current_room_state(type, state_key);
` `
const upsertRoomStateSQL = "" + const upsertRoomStateSQL = "" +
@ -80,7 +82,7 @@ const selectRoomIDsWithMembershipSQL = "" +
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectRoomIDsWithAnyMembershipSQL = "" + const selectRoomIDsWithAnyMembershipSQL = "" +
"SELECT DISTINCT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1" "SELECT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1"
const selectCurrentStateSQL = "" + const selectCurrentStateSQL = "" +
"SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" +
@ -183,9 +185,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs. // SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
func (s *currentRoomStateStatements) SelectJoinedUsers( func (s *currentRoomStateStatements) SelectJoinedUsers(
ctx context.Context, ctx context.Context, txn *sql.Tx,
) (map[string][]string, error) { ) (map[string][]string, error) {
rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt).QueryContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -207,9 +209,9 @@ func (s *currentRoomStateStatements) SelectJoinedUsers(
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. // SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
func (s *currentRoomStateStatements) SelectJoinedUsersInRoom( func (s *currentRoomStateStatements) SelectJoinedUsersInRoom(
ctx context.Context, roomIDs []string, ctx context.Context, txn *sql.Tx, roomIDs []string,
) (map[string][]string, error) { ) (map[string][]string, error) {
rows, err := s.selectJoinedUsersInRoomStmt.QueryContext(ctx, pq.StringArray(roomIDs)) rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersInRoomStmt).QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -385,9 +387,9 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
} }
func (s *currentRoomStateStatements) SelectStateEvent( func (s *currentRoomStateStatements) SelectStateEvent(
ctx context.Context, roomID, evType, stateKey string, ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
stmt := s.selectStateEventStmt stmt := sqlutil.TxStmt(txn, s.selectStateEventStmt)
var res []byte var res []byte
err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -73,11 +74,11 @@ func NewPostgresFilterTable(db *sql.DB) (tables.Filter, error) {
} }
func (s *filterStatements) SelectFilter( func (s *filterStatements) SelectFilter(
ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string,
) error { ) error {
// Retrieve filter from database (stored as canonical JSON) // Retrieve filter from database (stored as canonical JSON)
var filterData []byte var filterData []byte
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) err := sqlutil.TxStmt(txn, s.selectFilterStmt).QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
if err != nil { if err != nil {
return err return err
} }
@ -90,7 +91,7 @@ func (s *filterStatements) SelectFilter(
} }
func (s *filterStatements) InsertFilter( func (s *filterStatements) InsertFilter(
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) { ) (filterID string, err error) {
var existingFilterID string var existingFilterID string
@ -111,8 +112,9 @@ func (s *filterStatements) InsertFilter(
// This can result in a race condition when two clients try to insert the // This can result in a race condition when two clients try to insert the
// same filter and localpart at the same time, however this is not a // same filter and localpart at the same time, however this is not a
// problem as both calls will result in the same filterID // problem as both calls will result in the same filterID
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, err = sqlutil.TxStmt(txn, s.selectFilterIDByContentStmt).QueryRowContext(
localpart, filterJSON).Scan(&existingFilterID) ctx, localpart, filterJSON,
).Scan(&existingFilterID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return "", err return "", err
} }
@ -122,7 +124,7 @@ func (s *filterStatements) InsertFilter(
} }
// Otherwise insert the filter and return the new ID // Otherwise insert the filter and return the new ID
err = s.insertFilterStmt.QueryRowContext(ctx, filterJSON, localpart). err = sqlutil.TxStmt(txn, s.insertFilterStmt).QueryRowContext(ctx, filterJSON, localpart).
Scan(&filterID) Scan(&filterID)
return return
} }

View file

@ -99,7 +99,7 @@ func (s *inviteEventsStatements) InsertInviteEvent(
return return
} }
err = s.insertInviteEventStmt.QueryRowContext( err = sqlutil.TxStmt(txn, s.insertInviteEventStmt).QueryRowContext(
ctx, ctx,
inviteEvent.RoomID(), inviteEvent.RoomID(),
inviteEvent.EventID(), inviteEvent.EventID(),

View file

@ -18,6 +18,8 @@ import (
"context" "context"
"database/sql" "database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -33,14 +35,14 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro
r := &notificationDataStatements{} r := &notificationDataStatements{}
return r, sqlutil.StatementList{ return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms},
{&r.selectMaxID, selectMaxNotificationIDSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL},
}.Prepare(db) }.Prepare(db)
} }
type notificationDataStatements struct { type notificationDataStatements struct {
upsertRoomUnreadCounts *sql.Stmt upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt selectUserUnreadCountsForRooms *sql.Stmt
selectMaxID *sql.Stmt selectMaxID *sql.Stmt
} }
@ -61,12 +63,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4 DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4
RETURNING id` RETURNING id`
const selectUserUnreadNotificationCountsSQL = `SELECT const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count
id, room_id, notification_count, highlight_count
FROM syncapi_notification_data FROM syncapi_notification_data
WHERE WHERE user_id = $1 AND
user_id = $1 AND room_id = ANY($2)`
id BETWEEN $2 + 1 AND $3`
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
@ -75,20 +75,20 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context,
return return
} }
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { func (r *notificationDataStatements) SelectUserUnreadCountsForRooms(
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) ctx context.Context, txn *sql.Tx, userID string, roomIDs []string,
) (map[string]*eventutil.NotificationData, error) {
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCountsForRooms).QueryContext(ctx, userID, pq.Array(roomIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCountsForRooms: rows.close() failed")
roomCounts := map[string]*eventutil.NotificationData{} roomCounts := map[string]*eventutil.NotificationData{}
for rows.Next() {
var id types.StreamPosition
var roomID string var roomID string
var notificationCount, highlightCount int var notificationCount, highlightCount int
for rows.Next() {
if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil { if err = rows.Scan(&roomID, &notificationCount, &highlightCount); err != nil {
return nil, err return nil, err
} }

View file

@ -166,6 +166,8 @@ const selectContextAfterEventSQL = "" +
" AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" +
" ORDER BY id ASC LIMIT $3" " ORDER BY id ASC LIMIT $3"
const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type = ANY($2) ORDER BY id ASC LIMIT $3"
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt selectEventsStmt *sql.Stmt
@ -180,6 +182,7 @@ type outputRoomEventsStatements struct {
selectContextEventStmt *sql.Stmt selectContextEventStmt *sql.Stmt
selectContextBeforeEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt
selectContextAfterEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt
selectSearchStmt *sql.Stmt
} }
func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
@ -215,15 +218,16 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
{&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextEventStmt, selectContextEventSQL},
{&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL},
{&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL},
{&s.selectSearchStmt, selectSearchSQL},
}.Prepare(db) }.Prepare(db)
} }
func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error {
headeredJSON, err := json.Marshal(event) headeredJSON, err := json.Marshal(event)
if err != nil { if err != nil {
return err return err
} }
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) _, err = sqlutil.TxStmt(txn, s.updateEventJSONStmt).ExecContext(ctx, headeredJSON, event.EventID())
return err return err
} }
@ -632,3 +636,27 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
} }
return result, rows.Err() return result, rows.Err()
} }
func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) {
rows, err := sqlutil.TxStmt(txn, s.selectSearchStmt).QueryContext(ctx, afterID, pq.StringArray(types), limit)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "rows.close() failed")
var eventID string
var id int64
result := make(map[int64]gomatrixserverlib.HeaderedEvent)
for rows.Next() {
var ev gomatrixserverlib.HeaderedEvent
var eventBytes []byte
if err = rows.Scan(&id, &eventID, &eventBytes); err != nil {
return nil, err
}
if err = ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err
}
result[id] = ev
}
return result, rows.Err()
}

View file

@ -173,7 +173,7 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology( func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology(
ctx context.Context, txn *sql.Tx, eventID string, ctx context.Context, txn *sql.Tx, eventID string,
) (pos, spos types.StreamPosition, err error) { ) (pos, spos types.StreamPosition, err error) {
err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos) err = sqlutil.TxStmt(txn, s.selectPositionInTopologyStmt).QueryRowContext(ctx, eventID).Scan(&pos, &spos)
return return
} }
@ -183,9 +183,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition(
ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, backwardOrdering bool, ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
) (topoPos types.StreamPosition, err error) { ) (topoPos types.StreamPosition, err error) {
if backwardOrdering { if backwardOrdering {
err = s.selectStreamToTopologicalPositionDescStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionDescStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
} else { } else {
err = s.selectStreamToTopologicalPositionAscStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionAscStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
} }
return return
} }
@ -193,6 +193,6 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition(
func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (pos types.StreamPosition, spos types.StreamPosition, err error) { ) (pos types.StreamPosition, spos types.StreamPosition, err error) {
err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) err = sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt).QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return return
} }

View file

@ -152,9 +152,9 @@ func (s *peekStatements) SelectPeeksInRange(
} }
func (s *peekStatements) SelectPeekingDevices( func (s *peekStatements) SelectPeekingDevices(
ctx context.Context, ctx context.Context, txn *sql.Tx,
) (peekingDevices map[string][]types.PeekingDevice, err error) { ) (peekingDevices map[string][]types.PeekingDevice, err error) {
rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx) rows, err := sqlutil.TxStmt(txn, s.selectPeekingDevicesStmt).QueryContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -104,9 +104,9 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
return return
} }
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
var lastPos types.StreamPosition var lastPos types.StreamPosition
rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos) rows, err := sqlutil.TxStmt(txn, r.selectRoomReceipts).QueryContext(ctx, pq.Array(roomIDs), streamPos)
if err != nil { if err != nil {
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err) return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
} }

View file

@ -148,7 +148,7 @@ func (d *Database) PeeksInRange(ctx context.Context, userID, deviceID string, r
} }
func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
return d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos) return d.Receipts.SelectRoomReceiptsAfter(ctx, nil, roomIDs, streamPos)
} }
// Events lookups a list of event by their event ID. // Events lookups a list of event by their event ID.
@ -168,15 +168,15 @@ func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixse
} }
func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
return d.CurrentRoomState.SelectJoinedUsers(ctx) return d.CurrentRoomState.SelectJoinedUsers(ctx, nil)
} }
func (d *Database) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) { func (d *Database) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) {
return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, roomIDs) return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, nil, roomIDs)
} }
func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) {
return d.Peeks.SelectPeekingDevices(ctx) return d.Peeks.SelectPeekingDevices(ctx, nil)
} }
func (d *Database) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) { func (d *Database) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
@ -186,7 +186,7 @@ func (d *Database) SharedUsers(ctx context.Context, userID string, otherUserIDs
func (d *Database) GetStateEvent( func (d *Database) GetStateEvent(
ctx context.Context, roomID, evType, stateKey string, ctx context.Context, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
return d.CurrentRoomState.SelectStateEvent(ctx, roomID, evType, stateKey) return d.CurrentRoomState.SelectStateEvent(ctx, nil, roomID, evType, stateKey)
} }
func (d *Database) GetStateEventsForRoom( func (d *Database) GetStateEventsForRoom(
@ -277,7 +277,7 @@ func (d *Database) GetAccountDataInRange(
ctx context.Context, userID string, r types.Range, ctx context.Context, userID string, r types.Range,
accountDataFilterPart *gomatrixserverlib.EventFilter, accountDataFilterPart *gomatrixserverlib.EventFilter,
) (map[string][]string, types.StreamPosition, error) { ) (map[string][]string, types.StreamPosition, error) {
return d.AccountData.SelectAccountDataInRange(ctx, userID, r, accountDataFilterPart) return d.AccountData.SelectAccountDataInRange(ctx, nil, userID, r, accountDataFilterPart)
} }
// UpsertAccountData keeps track of new or updated account data, by saving the type // UpsertAccountData keeps track of new or updated account data, by saving the type
@ -484,7 +484,7 @@ func (d *Database) GetEventsInTopologicalRange(
func (d *Database) BackwardExtremitiesForRoom( func (d *Database) BackwardExtremitiesForRoom(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (backwardExtremities map[string][]string, err error) { ) (backwardExtremities map[string][]string, err error) {
return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, roomID) return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, nil, roomID)
} }
func (d *Database) MaxTopologicalPosition( func (d *Database) MaxTopologicalPosition(
@ -530,7 +530,7 @@ func (d *Database) StreamToTopologicalPosition(
func (d *Database) GetFilter( func (d *Database) GetFilter(
ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
) error { ) error {
return d.Filter.SelectFilter(ctx, target, localpart, filterID) return d.Filter.SelectFilter(ctx, nil, target, localpart, filterID)
} }
func (d *Database) PutFilter( func (d *Database) PutFilter(
@ -538,8 +538,8 @@ func (d *Database) PutFilter(
) (string, error) { ) (string, error) {
var filterID string var filterID string
var err error var err error
err = d.Writer.Do(nil, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
filterID, err = d.Filter.InsertFilter(ctx, filter, localpart) filterID, err = d.Filter.InsertFilter(ctx, txn, filter, localpart)
return err return err
}) })
return filterID, err return filterID, err
@ -561,8 +561,8 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda
} }
newEvent := eventToRedact.Headered(redactedBecause.RoomVersion) newEvent := eventToRedact.Headered(redactedBecause.RoomVersion)
err = d.Writer.Do(nil, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.OutputEvents.UpdateEventJSON(ctx, newEvent) return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent)
}) })
return err return err
} }
@ -1024,7 +1024,7 @@ func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId
} }
func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) { func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) {
_, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos) _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, nil, roomIDs, streamPos)
return receipts, err return receipts, err
} }
@ -1036,8 +1036,15 @@ func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userI
return return
} }
func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) { func (d *Database) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) {
return d.NotificationData.SelectUserUnreadCounts(ctx, nil, userID, from, to) roomIDs := make([]string, 0, len(rooms))
for roomID, membership := range rooms {
if membership != gomatrixserverlib.Join {
continue
}
roomIDs = append(roomIDs, roomID)
}
return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, nil, userID, roomIDs)
} }
func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
@ -1086,3 +1093,11 @@ func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.Stre
func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) {
return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos) return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos)
} }
func (s *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) {
return s.OutputEvents.ReIndex(ctx, nil, limit, afterID, []string{
gomatrixserverlib.MRoomName,
gomatrixserverlib.MRoomTopic,
"m.room.message",
})
}

View file

@ -91,14 +91,14 @@ func (s *accountDataStatements) InsertAccountData(
} }
func (s *accountDataStatements) SelectAccountDataInRange( func (s *accountDataStatements) SelectAccountDataInRange(
ctx context.Context, ctx context.Context, txn *sql.Tx,
userID string, userID string,
r types.Range, r types.Range,
filter *gomatrixserverlib.EventFilter, filter *gomatrixserverlib.EventFilter,
) (data map[string][]string, pos types.StreamPosition, err error) { ) (data map[string][]string, pos types.StreamPosition, err error) {
data = make(map[string][]string) data = make(map[string][]string)
stmt, params, err := prepareWithFilters( stmt, params, err := prepareWithFilters(
s.db, nil, selectAccountDataInRangeSQL, s.db, txn, selectAccountDataInRangeSQL,
[]interface{}{ []interface{}{
userID, r.Low(), r.High(), userID, r.Low(), r.High(),
}, },

View file

@ -82,9 +82,9 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
} }
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
ctx context.Context, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (bwExtrems map[string][]string, err error) { ) (bwExtrems map[string][]string, err error) {
rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) rows, err := sqlutil.TxStmt(txn, s.selectBackwardExtremitiesForRoomStmt).QueryContext(ctx, roomID)
if err != nil { if err != nil {
return return
} }

View file

@ -51,6 +51,8 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_s
-- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; -- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
-- for querying state by event IDs -- for querying state by event IDs
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id); CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id);
-- for improving selectRoomIDsWithAnyMembershipSQL
CREATE INDEX IF NOT EXISTS syncapi_current_room_state_type_state_key_idx ON syncapi_current_room_state(type, state_key);
` `
const upsertRoomStateSQL = "" + const upsertRoomStateSQL = "" +
@ -69,7 +71,7 @@ const selectRoomIDsWithMembershipSQL = "" +
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectRoomIDsWithAnyMembershipSQL = "" + const selectRoomIDsWithAnyMembershipSQL = "" +
"SELECT DISTINCT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1" "SELECT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1"
const selectCurrentStateSQL = "" + const selectCurrentStateSQL = "" +
"SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1"
@ -161,9 +163,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs. // SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
func (s *currentRoomStateStatements) SelectJoinedUsers( func (s *currentRoomStateStatements) SelectJoinedUsers(
ctx context.Context, ctx context.Context, txn *sql.Tx,
) (map[string][]string, error) { ) (map[string][]string, error) {
rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt).QueryContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -185,7 +187,7 @@ func (s *currentRoomStateStatements) SelectJoinedUsers(
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. // SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
func (s *currentRoomStateStatements) SelectJoinedUsersInRoom( func (s *currentRoomStateStatements) SelectJoinedUsersInRoom(
ctx context.Context, roomIDs []string, ctx context.Context, txn *sql.Tx, roomIDs []string,
) (map[string][]string, error) { ) (map[string][]string, error) {
query := strings.Replace(selectJoinedUsersInRoomSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) query := strings.Replace(selectJoinedUsersInRoomSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
params := make([]interface{}, 0, len(roomIDs)) params := make([]interface{}, 0, len(roomIDs))
@ -198,7 +200,7 @@ func (s *currentRoomStateStatements) SelectJoinedUsersInRoom(
} }
defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsersInRoom: stmt.close() failed") defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsersInRoom: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...) rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -399,9 +401,9 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
} }
func (s *currentRoomStateStatements) SelectStateEvent( func (s *currentRoomStateStatements) SelectStateEvent(
ctx context.Context, roomID, evType, stateKey string, ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
stmt := s.selectStateEventStmt stmt := sqlutil.TxStmt(txn, s.selectStateEventStmt)
var res []byte var res []byte
err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -427,10 +429,17 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
params[k+1] = v params[k+1] = v
} }
var provider sqlutil.QueryProvider
if txn == nil {
provider = s.db
} else {
provider = txn
}
result := make([]string, 0, len(otherUserIDs)) result := make([]string, 0, len(otherUserIDs))
query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1) query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1)
err := sqlutil.RunLimitedVariablesQuery( err := sqlutil.RunLimitedVariablesQuery(
ctx, query, s.db, params, sqlutil.SQLite3MaxVariables, ctx, query, provider, params, sqlutil.SQLite3MaxVariables,
func(rows *sql.Rows) error { func(rows *sql.Rows) error {
var stateKey string var stateKey string
for rows.Next() { for rows.Next() {

View file

@ -20,6 +20,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -77,11 +78,11 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
} }
func (s *filterStatements) SelectFilter( func (s *filterStatements) SelectFilter(
ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string,
) error { ) error {
// Retrieve filter from database (stored as canonical JSON) // Retrieve filter from database (stored as canonical JSON)
var filterData []byte var filterData []byte
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) err := sqlutil.TxStmt(txn, s.selectFilterStmt).QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
if err != nil { if err != nil {
return err return err
} }
@ -94,7 +95,7 @@ func (s *filterStatements) SelectFilter(
} }
func (s *filterStatements) InsertFilter( func (s *filterStatements) InsertFilter(
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) { ) (filterID string, err error) {
var existingFilterID string var existingFilterID string
@ -115,8 +116,9 @@ func (s *filterStatements) InsertFilter(
// This can result in a race condition when two clients try to insert the // This can result in a race condition when two clients try to insert the
// same filter and localpart at the same time, however this is not a // same filter and localpart at the same time, however this is not a
// problem as both calls will result in the same filterID // problem as both calls will result in the same filterID
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, err = sqlutil.TxStmt(txn, s.selectFilterIDByContentStmt).QueryRowContext(
localpart, filterJSON).Scan(&existingFilterID) ctx, localpart, filterJSON,
).Scan(&existingFilterID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return "", err return "", err
} }
@ -126,7 +128,7 @@ func (s *filterStatements) InsertFilter(
} }
// Otherwise insert the filter and return the new ID // Otherwise insert the filter and return the new ID
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) res, err := sqlutil.TxStmt(txn, s.insertFilterStmt).ExecContext(ctx, filterJSON, localpart)
if err != nil { if err != nil {
return "", err return "", err
} }

View file

@ -17,6 +17,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
@ -32,19 +33,21 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t
} }
r := &notificationDataStatements{ r := &notificationDataStatements{
streamIDStatements: streamID, streamIDStatements: streamID,
db: db,
} }
return r, sqlutil.StatementList{ return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
{&r.selectMaxID, selectMaxNotificationIDSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL},
// {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime
}.Prepare(db) }.Prepare(db)
} }
type notificationDataStatements struct { type notificationDataStatements struct {
db *sql.DB
streamIDStatements *StreamIDStatements streamIDStatements *StreamIDStatements
upsertRoomUnreadCounts *sql.Stmt upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt
selectMaxID *sql.Stmt selectMaxID *sql.Stmt
//selectUserUnreadCountsForRooms *sql.Stmt
} }
const notificationDataSchema = ` const notificationDataSchema = `
@ -63,12 +66,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
ON CONFLICT (user_id, room_id) ON CONFLICT (user_id, room_id)
DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7` DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7`
const selectUserUnreadNotificationCountsSQL = `SELECT const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count
id, room_id, notification_count, highlight_count
FROM syncapi_notification_data FROM syncapi_notification_data
WHERE WHERE user_id = $1 AND
user_id = $1 AND room_id IN ($2)`
id BETWEEN $2 + 1 AND $3`
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
@ -81,20 +82,31 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context,
return return
} }
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { func (r *notificationDataStatements) SelectUserUnreadCountsForRooms(
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) ctx context.Context, txn *sql.Tx, userID string, roomIDs []string,
) (map[string]*eventutil.NotificationData, error) {
params := make([]interface{}, len(roomIDs)+1)
params[0] = userID
for i := range roomIDs {
params[i+1] = roomIDs[i]
}
sql := strings.Replace(selectUserUnreadNotificationsForRooms, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
prep, err := r.db.PrepareContext(ctx, sql)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") defer internal.CloseAndLogIfError(ctx, prep, "SelectUserUnreadCountsForRooms: prep.close() failed")
rows, err := sqlutil.TxStmt(txn, prep).QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCountsForRooms: rows.close() failed")
roomCounts := map[string]*eventutil.NotificationData{} roomCounts := map[string]*eventutil.NotificationData{}
for rows.Next() {
var id types.StreamPosition
var roomID string var roomID string
var notificationCount, highlightCount int var notificationCount, highlightCount int
for rows.Next() {
if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil { if err = rows.Scan(&roomID, &notificationCount, &highlightCount); err != nil {
return nil, err return nil, err
} }

View file

@ -115,6 +115,8 @@ const selectContextAfterEventSQL = "" +
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE type IN ($1) AND id > $2 LIMIT $3 ORDER BY id ASC"
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *StreamIDStatements streamIDStatements *StreamIDStatements
@ -125,6 +127,7 @@ type outputRoomEventsStatements struct {
selectContextEventStmt *sql.Stmt selectContextEventStmt *sql.Stmt
selectContextBeforeEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt
selectContextAfterEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt
//selectSearchStmt *sql.Stmt - prepared at runtime
} }
func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) { func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) {
@ -157,15 +160,16 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even
{&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextEventStmt, selectContextEventSQL},
{&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL},
{&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL},
//{&s.selectSearchStmt, selectSearchSQL}, - prepared at runtime
}.Prepare(db) }.Prepare(db)
} }
func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error {
headeredJSON, err := json.Marshal(event) headeredJSON, err := json.Marshal(event)
if err != nil { if err != nil {
return err return err
} }
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) _, err = sqlutil.TxStmt(txn, s.updateEventJSONStmt).ExecContext(ctx, headeredJSON, event.EventID())
return err return err
} }
@ -628,3 +632,40 @@ func unmarshalStateIDs(addIDsJSON, delIDsJSON string) (addIDs []string, delIDs [
} }
return return
} }
func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) {
params := make([]interface{}, len(types))
for i := range types {
params[i] = types[i]
}
params = append(params, afterID)
params = append(params, limit)
selectSQL := strings.Replace(selectSearchSQL, "($1)", sqlutil.QueryVariadic(len(types)), 1)
stmt, err := s.db.Prepare(selectSQL)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, stmt, "selectEvents: stmt.close() failed")
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "rows.close() failed")
var eventID string
var id int64
result := make(map[int64]gomatrixserverlib.HeaderedEvent)
for rows.Next() {
var ev gomatrixserverlib.HeaderedEvent
var eventBytes []byte
if err = rows.Scan(&id, &eventID, &eventBytes); err != nil {
return nil, err
}
if err = ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err
}
result[id] = ev
}
return result, rows.Err()
}

View file

@ -176,9 +176,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition(
ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, backwardOrdering bool, ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
) (topoPos types.StreamPosition, err error) { ) (topoPos types.StreamPosition, err error) {
if backwardOrdering { if backwardOrdering {
err = s.selectStreamToTopologicalPositionDescStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionDescStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
} else { } else {
err = s.selectStreamToTopologicalPositionAscStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionAscStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
} }
return return
} }

View file

@ -172,9 +172,9 @@ func (s *peekStatements) SelectPeeksInRange(
} }
func (s *peekStatements) SelectPeekingDevices( func (s *peekStatements) SelectPeekingDevices(
ctx context.Context, ctx context.Context, txn *sql.Tx,
) (peekingDevices map[string][]types.PeekingDevice, err error) { ) (peekingDevices map[string][]types.PeekingDevice, err error) {
rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx) rows, err := sqlutil.TxStmt(txn, s.selectPeekingDevicesStmt).QueryContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -108,7 +108,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
} }
// SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp // SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
var lastPos types.StreamPosition var lastPos types.StreamPosition
params := make([]interface{}, len(roomIDs)+1) params := make([]interface{}, len(roomIDs)+1)
@ -116,7 +116,12 @@ func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs
for k, v := range roomIDs { for k, v := range roomIDs {
params[k+1] = v params[k+1] = v
} }
rows, err := r.db.QueryContext(ctx, selectSQL, params...) prep, err := r.db.Prepare(selectSQL)
if err != nil {
return 0, nil, fmt.Errorf("unable to prepare statement: %w", err)
}
defer internal.CloseAndLogIfError(ctx, prep, "SelectRoomReceiptsAfter: prep.close() failed")
rows, err := sqlutil.TxStmt(txn, prep).QueryContext(ctx, params...)
if err != nil { if err != nil {
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err) return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
} }

View file

@ -28,7 +28,7 @@ import (
type AccountData interface { type AccountData interface {
InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error) InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error)
// SelectAccountDataInRange returns a map of room ID to a list of `dataType`. // SelectAccountDataInRange returns a map of room ID to a list of `dataType`.
SelectAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, pos types.StreamPosition, err error) SelectAccountDataInRange(ctx context.Context, txn *sql.Tx, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, pos types.StreamPosition, err error)
SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error)
} }
@ -46,7 +46,7 @@ type Peeks interface {
DeletePeek(ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string) (streamPos types.StreamPosition, err error) DeletePeek(ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string) (streamPos types.StreamPosition, err error)
DeletePeeks(ctx context.Context, txn *sql.Tx, roomID, userID string) (streamPos types.StreamPosition, err error) DeletePeeks(ctx context.Context, txn *sql.Tx, roomID, userID string) (streamPos types.StreamPosition, err error)
SelectPeeksInRange(ctxt context.Context, txn *sql.Tx, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) SelectPeeksInRange(ctxt context.Context, txn *sql.Tx, userID, deviceID string, r types.Range) (peeks []types.Peek, err error)
SelectPeekingDevices(ctxt context.Context) (peekingDevices map[string][]types.PeekingDevice, err error) SelectPeekingDevices(ctxt context.Context, txn *sql.Tx) (peekingDevices map[string][]types.PeekingDevice, err error)
SelectMaxPeekID(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectMaxPeekID(ctx context.Context, txn *sql.Tx) (id int64, err error)
} }
@ -68,13 +68,14 @@ type Events interface {
// SelectEarlyEvents returns the earliest events in the given room. // SelectEarlyEvents returns the earliest events in the given room.
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error)
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
ReIndex(ctx context.Context, txn *sql.Tx, limit, offset int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error)
} }
// Topology keeps track of the depths and stream positions for all events. // Topology keeps track of the depths and stream positions for all events.
@ -97,7 +98,7 @@ type Topology interface {
} }
type CurrentRoomState interface { type CurrentRoomState interface {
SelectStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) SelectStateEvent(ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
SelectEventsWithEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) SelectEventsWithEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
UpsertRoomState(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition) error UpsertRoomState(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition) error
DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error
@ -109,9 +110,9 @@ type CurrentRoomState interface {
// SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user. // SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user.
SelectRoomIDsWithAnyMembership(ctx context.Context, txn *sql.Tx, userID string) (map[string]string, error) SelectRoomIDsWithAnyMembership(ctx context.Context, txn *sql.Tx, userID string) (map[string]string, error)
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs. // SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
SelectJoinedUsers(ctx context.Context) (map[string][]string, error) SelectJoinedUsers(ctx context.Context, txn *sql.Tx) (map[string][]string, error)
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. // SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
SelectJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error)
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID. // SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error) SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
} }
@ -141,7 +142,7 @@ type BackwardsExtremities interface {
// InsertsBackwardExtremity inserts a new backwards extremity. // InsertsBackwardExtremity inserts a new backwards extremity.
InsertsBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string) (err error) InsertsBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string) (err error)
// SelectBackwardExtremitiesForRoom retrieves all backwards extremities for the room, as a map of event_id to list of prev_event_ids. // SelectBackwardExtremitiesForRoom retrieves all backwards extremities for the room, as a map of event_id to list of prev_event_ids.
SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error) SelectBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (bwExtrems map[string][]string, err error)
// DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed.
DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error)
} }
@ -171,13 +172,13 @@ type SendToDevice interface {
} }
type Filter interface { type Filter interface {
SelectFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error SelectFilter(ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string) error
InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error) InsertFilter(ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error)
} }
type Receipts interface { type Receipts interface {
UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error)
SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error)
} }
@ -190,7 +191,7 @@ type Memberships interface {
type NotificationData interface { type NotificationData interface {
UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) SelectUserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error)
SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error)
} }

View file

@ -3,9 +3,10 @@ package streams
import ( import (
"context" "context"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
) )
type AccountDataStreamProvider struct { type AccountDataStreamProvider struct {

View file

@ -30,27 +30,29 @@ func (p *NotificationDataStreamProvider) CompleteSync(
func (p *NotificationDataStreamProvider) IncrementalSync( func (p *NotificationDataStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, _ types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
// We want counts for all possible rooms, so always start from zero. // Get the unread notifications for rooms in our join response.
countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to) // This is to ensure clients always have an unread notification section
// and can display the correct numbers.
countsByRoom, err := p.DB.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms)
if err != nil { if err != nil {
req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed") req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed")
return from return from
} }
// We're merely decorating existing rooms. Note that the Join map // We're merely decorating existing rooms.
// values are not pointers.
for roomID, jr := range req.Response.Rooms.Join { for roomID, jr := range req.Response.Rooms.Join {
counts := countsByRoom[roomID] counts := countsByRoom[roomID]
if counts == nil { if counts == nil {
continue continue
} }
jr.UnreadNotifications.HighlightCount = counts.UnreadHighlightCount jr.UnreadNotifications = &types.UnreadNotifications{
jr.UnreadNotifications.NotificationCount = counts.UnreadNotificationCount HighlightCount: counts.UnreadHighlightCount,
NotificationCount: counts.UnreadNotificationCount,
}
req.Response.Rooms.Join[roomID] = jr req.Response.Rooms.Join[roomID] = jr
} }
// BEGIN ZION CODE but return all notifications regardless of whether they're in a room we're in. // BEGIN ZION CODE but return all notifications regardless of whether they're in a room we're in.
for roomID, counts := range countsByRoom { for roomID, counts := range countsByRoom {
unreadNotificationsData := *types.NewUnreadNotificationsResponse() unreadNotificationsData := *types.NewUnreadNotificationsResponse()
@ -60,5 +62,5 @@ func (p *NotificationDataStreamProvider) IncrementalSync(
req.Response.Rooms.UnreadNotifications[roomID] = unreadNotificationsData req.Response.Rooms.UnreadNotifications[roomID] = unreadNotificationsData
} }
// END ZION CODE // END ZION CODE
return to return p.LatestPosition(ctx)
} }

View file

@ -77,16 +77,6 @@ func AddPublicRoutes(
logrus.WithError(err).Panicf("failed to start presence consumer") logrus.WithError(err).Panicf("failed to start presence consumer")
} }
userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{
JetStream: js,
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent),
}
userAPIReadUpdateProducer := &producers.UserAPIReadProducer{
JetStream: js,
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate),
}
keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
base.ProcessContext, cfg, cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), base.ProcessContext, cfg, cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent),
js, rsAPI, syncDB, notifier, js, rsAPI, syncDB, notifier,
@ -98,15 +88,15 @@ func AddPublicRoutes(
roomConsumer := consumers.NewOutputRoomEventConsumer( roomConsumer := consumers.NewOutputRoomEventConsumer(
base.ProcessContext, cfg, js, syncDB, notifier, streams.PDUStreamProvider, base.ProcessContext, cfg, js, syncDB, notifier, streams.PDUStreamProvider,
streams.InviteStreamProvider, rsAPI, userAPIStreamEventProducer, streams.InviteStreamProvider, rsAPI, base.Fulltext,
) )
if err = roomConsumer.Start(); err != nil { if err = roomConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start room server consumer") logrus.WithError(err).Panicf("failed to start room server consumer")
} }
clientConsumer := consumers.NewOutputClientDataConsumer( clientConsumer := consumers.NewOutputClientDataConsumer(
base.ProcessContext, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider, base.ProcessContext, cfg, js, natsClient, syncDB, notifier,
userAPIReadUpdateProducer, streams.AccountDataStreamProvider, base.Fulltext,
) )
if err = clientConsumer.Start(); err != nil { if err = clientConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start client data consumer") logrus.WithError(err).Panicf("failed to start client data consumer")
@ -135,7 +125,6 @@ func AddPublicRoutes(
receiptConsumer := consumers.NewOutputReceiptEventConsumer( receiptConsumer := consumers.NewOutputReceiptEventConsumer(
base.ProcessContext, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider, base.ProcessContext, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider,
userAPIReadUpdateProducer,
) )
if err = receiptConsumer.Start(); err != nil { if err = receiptConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start receipts consumer") logrus.WithError(err).Panicf("failed to start receipts consumer")
@ -143,6 +132,6 @@ func AddPublicRoutes(
routing.Setup( routing.Setup(
base.PublicClientAPIMux, requestPool, syncDB, userAPI, base.PublicClientAPIMux, requestPool, syncDB, userAPI,
rsAPI, cfg, base.Caches, rsAPI, cfg, base.Caches, base.Fulltext,
) )
} }

View file

@ -402,6 +402,11 @@ func (r *Response) IsEmpty() bool {
len(r.ToDevice.Events) == 0 len(r.ToDevice.Events) == 0
} }
type UnreadNotifications struct {
HighlightCount int `json:"highlight_count"`
NotificationCount int `json:"notification_count"`
}
// JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key. // JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key.
type JoinResponse struct { type JoinResponse struct {
Summary struct { Summary struct {
@ -423,10 +428,7 @@ type JoinResponse struct {
AccountData struct { AccountData struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
} `json:"account_data"` } `json:"account_data"`
UnreadNotifications struct { *UnreadNotifications `json:"unread_notifications,omitempty"`
HighlightCount int `json:"highlight_count"`
NotificationCount int `json:"notification_count"`
} `json:"unread_notifications"`
} }
// NewJoinResponse creates an empty response with initialised arrays. // NewJoinResponse creates an empty response with initialised arrays.
@ -517,19 +519,6 @@ type Peek struct {
Deleted bool Deleted bool
} }
type ReadUpdate struct {
UserID string `json:"user_id"`
RoomID string `json:"room_id"`
Read StreamPosition `json:"read,omitempty"`
FullyRead StreamPosition `json:"fully_read,omitempty"`
}
// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
type StreamedEvent struct {
Event *gomatrixserverlib.HeaderedEvent `json:"event"`
StreamPosition StreamPosition `json:"stream_position"`
}
// OutputReceiptEvent is an entry in the receipt output kafka log // OutputReceiptEvent is an entry in the receipt output kafka log
type OutputReceiptEvent struct { type OutputReceiptEvent struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`

View file

@ -0,0 +1,127 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/util"
)
// OutputReceiptEventConsumer consumes events that originated in the clientAPI.
type OutputReceiptEventConsumer struct {
ctx context.Context
jetstream nats.JetStreamContext
durable string
topic string
db storage.Database
serverName gomatrixserverlib.ServerName
syncProducer *producers.SyncAPI
pgClient pushgateway.Client
}
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
// Call Start() to begin consuming from the EDU server.
func NewOutputReceiptEventConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
store storage.Database,
syncProducer *producers.SyncAPI,
pgClient pushgateway.Client,
) *OutputReceiptEventConsumer {
return &OutputReceiptEventConsumer{
ctx: process.Context(),
jetstream: js,
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent),
durable: cfg.Matrix.JetStream.Durable("UserAPIReceiptConsumer"),
db: store,
serverName: cfg.Matrix.ServerName,
syncProducer: syncProducer,
pgClient: pgClient,
}
}
// Start consuming receipts events.
func (s *OutputReceiptEventConsumer) Start() error {
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
userID := msg.Header.Get(jetstream.UserID)
roomID := msg.Header.Get(jetstream.RoomID)
readPos := msg.Header.Get(jetstream.EventID)
evType := msg.Header.Get("type")
if readPos == "" || evType != "m.read" {
return true
}
log := log.WithFields(log.Fields{
"room_id": roomID,
"user_id": userID,
})
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
log.WithError(err).Error("userapi clientapi consumer: SplitID failure")
return true
}
if domain != s.serverName {
return true
}
metadata, err := msg.Metadata()
if err != nil {
return false
}
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
if err != nil {
log.WithError(err).Error("userapi EDU consumer")
return false
}
if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil {
log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed")
return false
}
if !updated {
return true
}
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
return false
}
return true
}

View file

@ -26,7 +26,7 @@ import (
"github.com/matrix-org/dendrite/userapi/util" "github.com/matrix-org/dendrite/userapi/util"
) )
type OutputStreamEventConsumer struct { type OutputRoomEventConsumer struct {
ctx context.Context ctx context.Context
cfg *config.UserAPI cfg *config.UserAPI
rsAPI rsapi.UserRoomserverAPI rsAPI rsapi.UserRoomserverAPI
@ -38,7 +38,7 @@ type OutputStreamEventConsumer struct {
syncProducer *producers.SyncAPI syncProducer *producers.SyncAPI
} }
func NewOutputStreamEventConsumer( func NewOutputRoomEventConsumer(
process *process.ProcessContext, process *process.ProcessContext,
cfg *config.UserAPI, cfg *config.UserAPI,
js nats.JetStreamContext, js nats.JetStreamContext,
@ -46,21 +46,21 @@ func NewOutputStreamEventConsumer(
pgClient pushgateway.Client, pgClient pushgateway.Client,
rsAPI rsapi.UserRoomserverAPI, rsAPI rsapi.UserRoomserverAPI,
syncProducer *producers.SyncAPI, syncProducer *producers.SyncAPI,
) *OutputStreamEventConsumer { ) *OutputRoomEventConsumer {
return &OutputStreamEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
cfg: cfg, cfg: cfg,
jetstream: js, jetstream: js,
db: store, db: store,
durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"), durable: cfg.Matrix.JetStream.Durable("UserAPIRoomServerConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent),
pgClient: pgClient, pgClient: pgClient,
rsAPI: rsAPI, rsAPI: rsAPI,
syncProducer: syncProducer, syncProducer: syncProducer,
} }
} }
func (s *OutputStreamEventConsumer) Start() error { func (s *OutputRoomEventConsumer) Start() error {
if err := jetstream.JetStreamConsumer( if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, 1, s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
@ -70,35 +70,43 @@ func (s *OutputStreamEventConsumer) Start() error {
return nil return nil
} }
func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called msg := msgs[0] // Guaranteed to exist if onMessage is called
var output types.StreamedEvent var output rsapi.OutputEvent
output.Event = &gomatrixserverlib.HeaderedEvent{}
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {
log.WithError(err).Errorf("userapi consumer: message parse failure") // If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("roomserver output log: message parse failure")
return true return true
} }
if output.Event.Event == nil { if output.Type != rsapi.OutputTypeNewRoomEvent {
return true
}
event := output.NewRoomEvent.Event
if event == nil {
log.Errorf("userapi consumer: expected event") log.Errorf("userapi consumer: expected event")
return true return true
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": output.Event.EventID(), "event_id": event.EventID(),
"event_type": output.Event.Type(), "event_type": event.Type(),
"stream_pos": output.StreamPosition, }).Tracef("Received message from roomserver: %#v", output)
}).Tracef("Received message from sync API: %#v", output)
if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil { metadata, err := msg.Metadata()
if err != nil {
return true
}
if err := s.processMessage(ctx, event, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp))); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": output.Event.EventID(), "event_id": event.EventID(),
}).WithError(err).Errorf("userapi consumer: process room event failure") }).WithError(err).Errorf("userapi consumer: process room event failure")
} }
return true return true
} }
func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error { func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {
members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) members, roomSize, err := s.localRoomMembers(ctx, event.RoomID())
if err != nil { if err != nil {
return fmt.Errorf("s.localRoomMembers: %w", err) return fmt.Errorf("s.localRoomMembers: %w", err)
@ -138,10 +146,10 @@ func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *g
// removing it means we can send all notifications to // removing it means we can send all notifications to
// e.g. Element's Push gateway in one go. // e.g. Element's Push gateway in one go.
for _, mem := range members { for _, mem := range members {
if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil { if err := s.notifyLocal(ctx, event, mem, roomSize, roomName, streamPos); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"localpart": mem.Localpart, "localpart": mem.Localpart,
}).WithError(err).Debugf("Unable to push to local user") }).WithError(err).Error("Unable to push to local user")
continue continue
} }
} }
@ -179,7 +187,7 @@ func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership,
// localRoomMembers fetches the current local members of a room, and // localRoomMembers fetches the current local members of a room, and
// the total number of members. // the total number of members.
func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) {
req := &rsapi.QueryMembershipsForRoomRequest{ req := &rsapi.QueryMembershipsForRoomRequest{
RoomID: roomID, RoomID: roomID,
JoinedOnly: true, JoinedOnly: true,
@ -219,7 +227,7 @@ func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID
// looks it up in roomserver. If there is no name, // looks it up in roomserver. If there is no name,
// m.room.canonical_alias is consulted. Returns an empty string if the // m.room.canonical_alias is consulted. Returns an empty string if the
// room has no name. // room has no name.
func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) {
if event.Type() == gomatrixserverlib.MRoomName { if event.Type() == gomatrixserverlib.MRoomName {
name, err := unmarshalRoomName(event) name, err := unmarshalRoomName(event)
if err != nil { if err != nil {
@ -287,7 +295,7 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er
} }
// notifyLocal finds the right push actions for a local user, given an event. // notifyLocal finds the right push actions for a local user, given an event.
func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error { func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error {
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
if err != nil { if err != nil {
return err return err
@ -302,7 +310,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
"event_id": event.EventID(), "event_id": event.EventID(),
"room_id": event.RoomID(), "room_id": event.RoomID(),
"localpart": mem.Localpart, "localpart": mem.Localpart,
}).Debugf("Push rule evaluation rejected the event") }).Tracef("Push rule evaluation rejected the event")
return nil return nil
} }
@ -325,7 +333,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
RoomID: event.RoomID(), RoomID: event.RoomID(),
TS: gomatrixserverlib.AsTimestamp(time.Now()), TS: gomatrixserverlib.AsTimestamp(time.Now()),
} }
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil { if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil {
return err return err
} }
@ -345,7 +353,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
"localpart": mem.Localpart, "localpart": mem.Localpart,
"num_urls": len(devicesByURLAndFormat), "num_urls": len(devicesByURLAndFormat),
"num_unread": userNumUnreadNotifs, "num_unread": userNumUnreadNotifs,
}).Debugf("Notifying single member") }).Trace("Notifying single member")
// Push gateways are out of our control, and we cannot risk // Push gateways are out of our control, and we cannot risk
// looking up the server on a misbehaving push gateway. Each user // looking up the server on a misbehaving push gateway. Each user
@ -396,7 +404,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
// evaluatePushRules fetches and evaluates the push rules of a local // evaluatePushRules fetches and evaluates the push rules of a local
// user. Returns actions (including dont_notify). // user. Returns actions (including dont_notify).
func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
if event.Sender() == mem.UserID { if event.Sender() == mem.UserID {
// SPEC: Homeservers MUST NOT notify the Push Gateway for // SPEC: Homeservers MUST NOT notify the Push Gateway for
// events that the user has sent themselves. // events that the user has sent themselves.
@ -447,7 +455,7 @@ func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event
"room_id": event.RoomID(), "room_id": event.RoomID(),
"localpart": mem.Localpart, "localpart": mem.Localpart,
"rule_id": rule.RuleID, "rule_id": rule.RuleID,
}).Tracef("Matched a push rule") }).Trace("Matched a push rule")
return rule.Actions, nil return rule.Actions, nil
} }
@ -491,7 +499,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
// localPushDevices pushes to the configured devices of a local // localPushDevices pushes to the configured devices of a local
// user. The map keys are [url][format]. // user. The map keys are [url][format].
func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
@ -515,7 +523,7 @@ func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localp
} }
// notifyHTTP performs a notificatation to a Push Gateway. // notifyHTTP performs a notificatation to a Push Gateway.
func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"event_id": event.EventID(), "event_id": event.EventID(),
"url": url, "url": url,
@ -561,13 +569,13 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat
} }
} }
logger.Debugf("Notifying push gateway %s", url) logger.Tracef("Notifying push gateway %s", url)
var res pushgateway.NotifyResponse var res pushgateway.NotifyResponse
if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil { if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil {
logger.WithError(err).Errorf("Failed to notify push gateway %s", url) logger.WithError(err).Errorf("Failed to notify push gateway %s", url)
return nil, err return nil, err
} }
logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result") logger.WithField("num_rejected", len(res.Rejected)).Trace("Push gateway result")
if len(res.Rejected) == 0 { if len(res.Rejected) == 0 {
return nil, nil return nil, nil
@ -589,7 +597,7 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat
} }
// deleteRejectedPushers deletes the pushers associated with the given devices. // deleteRejectedPushers deletes the pushers associated with the given devices.
func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"localpart": localpart, "localpart": localpart,
"app_id0": devices[0].AppID, "app_id0": devices[0].AppID,

View file

@ -40,7 +40,7 @@ func Test_evaluatePushRules(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
consumer := OutputStreamEventConsumer{db: db} consumer := OutputRoomEventConsumer{db: db}
testCases := []struct { testCases := []struct {
name string name string

View file

@ -1,142 +0,0 @@
package consumers
import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/types"
uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/util"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
type OutputReadUpdateConsumer struct {
ctx context.Context
cfg *config.UserAPI
jetstream nats.JetStreamContext
durable string
db storage.Database
pgClient pushgateway.Client
ServerName gomatrixserverlib.ServerName
topic string
userAPI uapi.UserInternalAPI
syncProducer *producers.SyncAPI
}
func NewOutputReadUpdateConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
store storage.Database,
pgClient pushgateway.Client,
userAPI uapi.UserInternalAPI,
syncProducer *producers.SyncAPI,
) *OutputReadUpdateConsumer {
return &OutputReadUpdateConsumer{
ctx: process.Context(),
cfg: cfg,
jetstream: js,
db: store,
ServerName: cfg.Matrix.ServerName,
durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate),
pgClient: pgClient,
userAPI: userAPI,
syncProducer: syncProducer,
}
}
func (s *OutputReadUpdateConsumer) Start() error {
if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
); err != nil {
return err
}
return nil
}
func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var read types.ReadUpdate
if err := json.Unmarshal(msg.Data, &read); err != nil {
log.WithError(err).Error("userapi clientapi consumer: message parse failure")
return true
}
if read.FullyRead == 0 && read.Read == 0 {
return true
}
userID := string(msg.Header.Get(jetstream.UserID))
roomID := string(msg.Header.Get(jetstream.RoomID))
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
log.WithError(err).Error("userapi clientapi consumer: SplitID failure")
return true
}
if domain != s.ServerName {
log.Error("userapi clientapi consumer: not a local user")
return true
}
log := log.WithFields(log.Fields{
"room_id": roomID,
"user_id": userID,
})
log.Tracef("Received read update from sync API: %#v", read)
if read.Read > 0 {
/*updated*/ _, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true)
if err != nil {
log.WithError(err).Error("userapi EDU consumer")
return false
}
// zion hack - always send the notification data
// the notifications are stored in two different databases, and somehow the notification database
// prunes data, so trying to mark old notifications as read will fail, but the notification will still exist in the other db
// todo, revist when this refactor lands: https://github.com/matrix-org/dendrite/pull/2688/files
//if updated {
if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil {
log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed")
return false
}
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
return false
}
//}
}
if read.FullyRead > 0 {
deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead))
if err != nil {
log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed")
return false
}
if deleted {
if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed")
return false
}
if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil {
log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed")
return false
}
}
}
return true
}

View file

@ -30,6 +30,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api"
@ -39,6 +40,7 @@ import (
"github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
userapiUtil "github.com/matrix-org/dendrite/userapi/util"
) )
type UserInternalAPI struct { type UserInternalAPI struct {
@ -51,6 +53,7 @@ type UserInternalAPI struct {
AppServices []config.ApplicationService AppServices []config.ApplicationService
KeyAPI keyapi.UserKeyAPI KeyAPI keyapi.UserKeyAPI
RSAPI rsapi.UserRoomserverAPI RSAPI rsapi.UserRoomserverAPI
PgClient pushgateway.Client
} }
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
@ -73,6 +76,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
ignoredUsers = &synctypes.IgnoredUsers{} ignoredUsers = &synctypes.IgnoredUsers{}
_ = json.Unmarshal(req.AccountData, ignoredUsers) _ = json.Unmarshal(req.AccountData, ignoredUsers)
} }
if req.DataType == "m.fully_read" {
if err := a.setFullyRead(ctx, req); err != nil {
return err
}
}
if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{ if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{
RoomID: req.RoomID, RoomID: req.RoomID,
Type: req.DataType, Type: req.DataType,
@ -84,6 +92,44 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
return nil return nil
} }
func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccountDataRequest) error {
var output eventutil.ReadMarkerJSON
if err := json.Unmarshal(req.AccountData, &output); err != nil {
return err
}
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure")
return nil
}
if domain != a.ServerName {
return nil
}
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
if err != nil {
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
return err
}
if err = a.SyncProducer.GetAndSendNotificationData(ctx, req.UserID, req.RoomID); err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: GetAndSendNotificationData failed")
return err
}
// nothing changed, no need to notify the push gateway
if !deleted {
return nil
}
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
return err
}
return nil
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if err != nil { if err != nil {

View file

@ -4,12 +4,13 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/storage"
) )
type JetStreamPublisher interface { type JetStreamPublisher interface {

View file

@ -119,9 +119,9 @@ type ThreePID interface {
} }
type Notification interface { type Notification interface {
InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error)
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, read bool) (affected bool, err error) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error)
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)

View file

@ -20,12 +20,13 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
) )
type notificationsStatements struct { type notificationsStatements struct {
@ -110,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
} }
// Insert inserts a notification into the database. // Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS roomID, tsMS := n.RoomID, n.TS
nn := *n nn := *n
// Clears out fields that have their own columns to (1) shrink the // Clears out fields that have their own columns to (1) shrink the
@ -126,7 +127,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
} }
// DeleteUpTo deletes all previous notifications, up to and including the event. // DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
@ -140,7 +141,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
} }
// UpdateRead updates the "read" value for an event. // UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
@ -196,40 +197,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
return notifs, maxID, rows.Err() return notifs, maxID, rows.Err()
} }
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
return
if err != nil {
return 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var count int64
if err := rows.Scan(&count); err != nil {
return 0, err
} }
return count, nil func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
} err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
return 0, rows.Err() return
}
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
if err != nil {
return 0, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var total, highlight int64
if err := rows.Scan(&total, &highlight); err != nil {
return 0, 0, err
}
return total, highlight, nil
}
return 0, 0, rows.Err()
} }

View file

@ -19,11 +19,12 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/sirupsen/logrus"
) )
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@ -136,7 +137,7 @@ func (s *pushersStatements) SelectPushers(
pushers = append(pushers, pusher) pushers = append(pushers, pusher)
} }
logrus.Debugf("Database returned %d pushers", len(pushers)) logrus.Tracef("Database returned %d pushers", len(pushers))
return pushers, rows.Err() return pushers, rows.Err()
} }

View file

@ -700,13 +700,13 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
return d.LoginTokens.SelectLoginToken(ctx, token) return d.LoginTokens.SelectLoginToken(ctx, token)
} }
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error { func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
}) })
} }
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) { func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
return err return err
@ -714,7 +714,7 @@ func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomI
return return
} }
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) { func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
return err return err
@ -777,7 +777,7 @@ func (d *Database) GetPushers(
func (d *Database) RemovePusher( func (d *Database) RemovePusher(
ctx context.Context, appid, pushkey, localpart string, ctx context.Context, appid, pushkey, localpart string,
) error { ) error {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart) err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil
@ -792,7 +792,7 @@ func (d *Database) RemovePusher(
func (d *Database) RemovePushers( func (d *Database) RemovePushers(
ctx context.Context, appid, pushkey string, ctx context.Context, appid, pushkey string,
) error { ) error {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Pushers.DeletePushers(ctx, txn, appid, pushkey) return d.Pushers.DeletePushers(ctx, txn, appid, pushkey)
}) })
} }

View file

@ -20,12 +20,13 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
) )
type notificationsStatements struct { type notificationsStatements struct {
@ -110,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
} }
// Insert inserts a notification into the database. // Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS roomID, tsMS := n.RoomID, n.TS
nn := *n nn := *n
// Clears out fields that have their own columns to (1) shrink the // Clears out fields that have their own columns to (1) shrink the
@ -126,7 +127,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
} }
// DeleteUpTo deletes all previous notifications, up to and including the event. // DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
@ -140,7 +141,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
} }
// UpdateRead updates the "read" value for an event. // UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
@ -196,40 +197,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
return notifs, maxID, rows.Err() return notifs, maxID, rows.Err()
} }
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
return
if err != nil {
return 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var count int64
if err := rows.Scan(&count); err != nil {
return 0, err
} }
return count, nil func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
} err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
return 0, rows.Err() return
}
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
if err != nil {
return 0, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var total, highlight int64
if err := rows.Scan(&total, &highlight); err != nil {
return 0, 0, err
}
return total, highlight, nil
}
return 0, 0, rows.Err()
} }

View file

@ -19,11 +19,12 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/sirupsen/logrus"
) )
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@ -96,7 +97,7 @@ func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64, ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
) error { ) error {
_, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id) logrus.Debugf("Created pusher %d", session_id)
return err return err
} }
@ -136,7 +137,7 @@ func (s *pushersStatements) SelectPushers(
pushers = append(pushers, pusher) pushers = append(pushers, pusher)
} }
logrus.Debugf("Database returned %d pushers", len(pushers)) logrus.Tracef("Database returned %d pushers", len(pushers))
return pushers, rows.Err() return pushers, rows.Err()
} }
@ -144,13 +145,13 @@ func (s *pushersStatements) SelectPushers(
func (s *pushersStatements) DeletePusher( func (s *pushersStatements) DeletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
) error { ) error {
_, err := s.deletePusherStmt.ExecContext(ctx, appid, pushkey, localpart) _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
return err return err
} }
func (s *pushersStatements) DeletePushers( func (s *pushersStatements) DeletePushers(
ctx context.Context, txn *sql.Tx, appid, pushkey string, ctx context.Context, txn *sql.Tx, appid, pushkey string,
) error { ) error {
_, err := s.deletePushersByAppIdAndPushKeyStmt.ExecContext(ctx, appid, pushkey) _, err := sqlutil.TxStmt(txn, s.deletePushersByAppIdAndPushKeyStmt).ExecContext(ctx, appid, pushkey)
return err return err
} }

View file

@ -7,6 +7,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -14,10 +19,6 @@ import (
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
) )
const loginTokenLifetime = time.Minute const loginTokenLifetime = time.Minute
@ -513,7 +514,7 @@ func Test_Notification(t *testing.T) {
RoomID: roomID, RoomID: roomID,
TS: gomatrixserverlib.AsTimestamp(ts), TS: gomatrixserverlib.AsTimestamp(ts),
} }
err = db.InsertNotification(ctx, aliceLocalpart, eventID, int64(i+1), nil, notification) err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification)
assert.NoError(t, err, "unable to insert notification") assert.NoError(t, err, "unable to insert notification")
} }

View file

@ -105,9 +105,9 @@ type PusherTable interface {
type NotificationTable interface { type NotificationTable interface {
Clean(ctx context.Context, txn *sql.Tx) error Clean(ctx context.Context, txn *sql.Tx) error
Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error)
UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error)
Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)

View file

@ -81,16 +81,17 @@ func NewInternalAPI(
KeyAPI: keyAPI, KeyAPI: keyAPI,
RSAPI: rsAPI, RSAPI: rsAPI,
DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, DisableTLSValidation: cfg.PushGatewayDisableTLSValidation,
PgClient: pgClient,
} }
readConsumer := consumers.NewOutputReadUpdateConsumer( receiptConsumer := consumers.NewOutputReceiptEventConsumer(
base.ProcessContext, cfg, js, db, pgClient, userAPI, syncProducer, base.ProcessContext, cfg, js, db, syncProducer, pgClient,
) )
if err := readConsumer.Start(); err != nil { if err := receiptConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start user API read update consumer") logrus.WithError(err).Panic("failed to start user API receipt consumer")
} }
eventConsumer := consumers.NewOutputStreamEventConsumer( eventConsumer := consumers.NewOutputRoomEventConsumer(
base.ProcessContext, cfg, js, db, pgClient, rsAPI, syncProducer, base.ProcessContext, cfg, js, db, pgClient, rsAPI, syncProducer,
) )
if err := eventConsumer.Start(); err != nil { if err := eventConsumer.Start(); err != nil {

65
web3/account.go Normal file
View file

@ -0,0 +1,65 @@
package web3
import (
"context"
"crypto/ecdsa"
"errors"
"fmt"
"math/big"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethclient"
)
type CreateTransactionSignerArgs struct {
PrivateKey string
ChainId int64
Client *ethclient.Client
GasValue int64 // in wei
GasLimit int64 // in units
}
func CreateTransactionSigner(args CreateTransactionSignerArgs) (*bind.TransactOpts, error) {
privateKey, err := crypto.HexToECDSA(args.PrivateKey)
if err != nil {
return nil, err
}
publicKey := privateKey.Public()
publicKeyECDSA, ok := publicKey.(*ecdsa.PublicKey)
if !ok {
return nil, errors.New("cannot create public key ECDSA")
}
fromAddress := crypto.PubkeyToAddress(*publicKeyECDSA)
nonce, err := args.Client.PendingNonceAt(context.Background(), fromAddress)
if err != nil {
return nil, err
}
gasPrice, err := args.Client.SuggestGasPrice((context.Background()))
if err != nil {
return nil, err
}
signer, err := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(args.ChainId))
if err != nil {
return nil, err
}
signer.Nonce = big.NewInt(int64(nonce))
signer.Value = big.NewInt(args.GasValue)
signer.GasLimit = uint64(args.GasLimit)
signer.GasPrice = gasPrice
fmt.Printf("{ nonce: %d, value: %d, gasLimit: %d, gasPrice: %d }\n",
signer.Nonce,
signer.Value,
signer.GasLimit,
signer.GasPrice,
)
return signer, nil
}

14
web3/client.go Normal file
View file

@ -0,0 +1,14 @@
package web3
import (
"github.com/ethereum/go-ethereum/ethclient"
)
func GetEthClient(web3ProviderUrl string) (*ethclient.Client, error) {
client, err := ethclient.Dial(web3ProviderUrl)
if err != nil {
return nil, err
}
return client, nil
}