Merge branch 'main' into mailbox

This commit is contained in:
Devon Hudson 2023-01-19 14:47:14 -07:00
commit 3c54ea1d56
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
62 changed files with 1617 additions and 282 deletions

View file

@ -1,5 +1,25 @@
# Changelog
## Dendrite 0.10.9 (2023-01-17)
### Features
* Stale device lists are now cleaned up on startup, removing entries for users the server doesn't share a room with anymore
* Dendrite now has its own Helm chart
* Guest access is now handled correctly (disallow joins, kick guests on revocation of guest access, as well as over federation)
### Fixes
* Push rules have seen several tweaks and fixes, which should, for example, fix notifications for `m.read_receipts`
* Outgoing presence will now correctly be sent to newly joined hosts
* Fixes the `/_dendrite/admin/resetPassword/{userID}` admin endpoint to use the correct variable
* Federated backfilling for medium/large rooms has been fixed
* `/login` causing wrong device list updates has been resolved
* `/sync` should now return the correct room summary heroes
* The default config options for `recaptcha_sitekey_class` and `recaptcha_form_field` are now set correctly
* `/messages` now omits empty `state` to be more spec compliant (contributed by [handlerug](https://github.com/handlerug))
* `/sync` has been optimised to only query state events for history visibility if they are really needed
## Dendrite 0.10.8 (2022-11-29)
### Features

View file

@ -7,9 +7,12 @@ import (
"testing"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/tidwall/gjson"
@ -41,7 +44,7 @@ func TestAdminResetPassword(t *testing.T) {
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
keyAPI.SetUserAPI(userAPI)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(base, nil, nil, nil, nil, nil, userAPI, nil, nil, nil)
AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil)
// Create the users in the userapi and login
accessTokens := map[*test.User]string{
@ -112,6 +115,7 @@ func TestAdminResetPassword(t *testing.T) {
}
for _, tc := range testCases {
tc := tc // ensure we don't accidentally only test the last test case
t.Run(tc.name, func(t *testing.T) {
req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID)
if tc.requestOpt != nil {
@ -132,3 +136,100 @@ func TestAdminResetPassword(t *testing.T) {
}
})
}
func TestPurgeRoom(t *testing.T) {
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
bob := test.NewUser(t)
room := test.NewRoom(t, aliceAdmin, test.RoomPreset(test.PresetTrustedPrivateChat))
// Invite Bob
room.CreateAndInsert(t, aliceAdmin, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "invite",
}, test.WithStateKey(bob.ID))
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
defer baseClose()
fedClient := base.CreateFederationClient()
rsAPI := roomserver.NewInternalAPI(base)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
// this starts the JetStream consumers
syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI)
federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true)
rsAPI.SetFederationAPI(nil, nil)
keyAPI.SetUserAPI(userAPI)
// Create the room
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil)
// Create the users in the userapi and login
accessTokens := map[*test.User]string{
aliceAdmin: "",
}
for u := range accessTokens {
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
userRes := &uapi.PerformAccountCreationResponse{}
password := util.RandomString(8)
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: localpart,
ServerName: serverName,
Password: password,
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{
"type": authtypes.LoginTypePassword,
"identifier": map[string]interface{}{
"type": "m.id.user",
"user": u.ID,
},
"password": password,
}))
rec := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("failed to login: %s", rec.Body.String())
}
accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String()
}
testCases := []struct {
name string
roomID string
wantOK bool
}{
{name: "Can purge existing room", wantOK: true, roomID: room.ID},
{name: "Can not purge non-existent room", wantOK: false, roomID: "!doesnotexist:localhost"},
{name: "rejects invalid room ID", wantOK: false, roomID: "@doesnotexist:localhost"},
}
for _, tc := range testCases {
tc := tc // ensure we don't accidentally only test the last test case
t.Run(tc.name, func(t *testing.T) {
req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/purgeRoom/"+tc.roomID)
req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin])
rec := httptest.NewRecorder()
base.DendriteAdminMux.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if tc.wantOK && rec.Code != http.StatusOK {
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
}
})
}
})
}

View file

@ -1,6 +1,7 @@
package routing
import (
"context"
"encoding/json"
"fmt"
"net/http"
@ -98,6 +99,37 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi
}
}
func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
roomID, ok := vars["roomID"]
if !ok {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("Expecting room ID."),
}
}
res := &roomserverAPI.PerformAdminPurgeRoomResponse{}
if err := rsAPI.PerformAdminPurgeRoom(
context.Background(),
&roomserverAPI.PerformAdminPurgeRoomRequest{
RoomID: roomID,
},
res,
); err != nil {
return util.ErrorResponse(err)
}
if err := res.Error; err != nil {
return err.JSONResponse()
}
return util.JSONResponse{
Code: 200,
JSON: res,
}
}
func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
if req.Body == nil {
return util.JSONResponse{

View file

@ -165,6 +165,12 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/purgeRoom/{roomID}",
httputil.MakeAdminAPI("admin_purge_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminPurgeRoom(req, cfg, device, rsAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/resetPassword/{userID}",
httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminResetPassword(req, cfg, device, userAPI)

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/federationapi/queue"
@ -90,8 +91,10 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms
msg := msgs[0] // Guaranteed to exist if onMessage is called
receivedType := api.OutputType(msg.Header.Get(jetstream.RoomEventType))
// Only handle events we care about
if receivedType != api.OutputTypeNewRoomEvent && receivedType != api.OutputTypeNewInboundPeek {
// Only handle events we care about, avoids unneeded unmarshalling
switch receivedType {
case api.OutputTypeNewRoomEvent, api.OutputTypeNewInboundPeek, api.OutputTypePurgeRoom:
default:
return true
}
@ -126,6 +129,14 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms
return false
}
case api.OutputTypePurgeRoom:
log.WithField("room_id", output.PurgeRoom.RoomID).Warn("Purging room from federation API")
if err := s.db.PurgeRoom(ctx, output.PurgeRoom.RoomID); err != nil {
logrus.WithField("room_id", output.PurgeRoom.RoomID).WithError(err).Error("Failed to purge room from federation API")
} else {
logrus.WithField("room_id", output.PurgeRoom.RoomID).Warn("Room purged from federation API")
}
default:
log.WithField("type", output.Type).Debug(
"roomserver output log: ignoring unknown output type",
@ -195,7 +206,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
}
// If we added new hosts, inform them about our known presence events for this room
if len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil {
if s.cfg.Matrix.Presence.EnableOutbound && len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil {
membership, _ := ore.Event.Membership()
if membership == gomatrixserverlib.Join {
s.sendPresence(ore.Event.RoomID(), addsJoinedHosts)

View file

@ -77,6 +77,8 @@ type Database interface {
GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error)
// DeleteExpiredEDUs cleans up expired EDUs
DeleteExpiredEDUs(ctx context.Context) error
PurgeRoom(ctx context.Context, roomID string) error
}
type P2PDatabase interface {

View file

@ -373,3 +373,18 @@ func (d *Database) GetNotaryKeys(
})
return sks, err
}
func (d *Database) PurgeRoom(ctx context.Context, roomID string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("failed to purge joined hosts: %w", err)
}
if err := d.FederationInboundPeeks.DeleteInboundPeeks(ctx, txn, roomID); err != nil {
return fmt.Errorf("failed to purge inbound peeks: %w", err)
}
if err := d.FederationOutboundPeeks.DeleteOutboundPeeks(ctx, txn, roomID); err != nil {
return fmt.Errorf("failed to purge outbound peeks: %w", err)
}
return nil
})
}

View file

@ -1,7 +1,7 @@
apiVersion: v2
name: dendrite
version: "0.10.8"
appVersion: "0.10.8"
version: "0.10.9"
appVersion: "0.10.9"
description: Dendrite Matrix Homeserver
type: application
keywords:

View file

@ -15,9 +15,11 @@
{{- define "image.name" -}}
image: {{ .name }}
{{- with .Values.image -}}
image: {{ .repository }}:{{ .tag | default (printf "v%s" $.Chart.AppVersion) }}
imagePullPolicy: {{ .pullPolicy }}
{{- end -}}
{{- end -}}
{{/*
Expand the name of the chart.

View file

@ -45,8 +45,8 @@ spec:
persistentVolumeClaim:
claimName: {{ default (print ( include "dendrite.fullname" . ) "-search-pvc") $.Values.persistence.search.existingClaim | quote }}
containers:
- name: {{ $.Chart.Name }}
{{- include "image.name" $.Values.image | nindent 8 }}
- name: {{ .Chart.Name }}
{{- include "image.name" . | nindent 8 }}
args:
- '--config'
- '/etc/dendrite/dendrite.yaml'

View file

@ -8,6 +8,7 @@ metadata:
name: {{ $name }}
labels:
app.kubernetes.io/component: signingkey-job
{{- include "dendrite.labels" . | nindent 4 }}
---
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
@ -80,7 +81,7 @@ spec:
name: signing-key
readOnly: true
- name: generate-key
{{- include "image.name" $.Values.image | nindent 8 }}
{{- include "image.name" . | nindent 8 }}
command:
- sh
- -c

View file

@ -13,5 +13,5 @@ spec:
ports:
- name: http
protocol: TCP
port: 8008
port: {{ .Values.service.port }}
targetPort: 8008

View file

@ -1,8 +1,10 @@
image:
# -- Docker repository/image to use
name: "ghcr.io/matrix-org/dendrite-monolith:v0.10.8"
repository: "ghcr.io/matrix-org/dendrite-monolith"
# -- Kubernetes pullPolicy
pullPolicy: IfNotPresent
# Overrides the image tag whose default is the chart appVersion.
tag: ""
# signing key to use
@ -345,4 +347,4 @@ ingress:
service:
type: ClusterIP
port: 80
port: 8008

View file

@ -124,6 +124,11 @@ type QueryProvider interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
// ExecProvider defines the interface for querys used by RunLimitedVariablesExec.
type ExecProvider interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}
// SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement
// SQLlite can handle. See https://www.sqlite.org/limits.html for more information.
const SQLite3MaxVariables = 999
@ -153,6 +158,22 @@ func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvide
return nil
}
// RunLimitedVariablesExec split up a query with more variables than the used database can handle in multiple queries.
func RunLimitedVariablesExec(ctx context.Context, query string, qp ExecProvider, variables []interface{}, limit uint) error {
var start int
for start < len(variables) {
n := minOfInts(len(variables)-start, int(limit))
nextQuery := strings.Replace(query, "($1)", QueryVariadic(n), 1)
_, err := qp.ExecContext(ctx, nextQuery, variables[start:start+n]...)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("ExecContext returned an error")
return err
}
start = start + n
}
return nil
}
// StatementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
type StatementList []struct {
Statement **sql.Stmt

View file

@ -3,10 +3,11 @@ package sqlutil
import (
"context"
"database/sql"
"errors"
"reflect"
"testing"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/DATA-DOG/go-sqlmock"
)
func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) {
@ -164,6 +165,54 @@ func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) {
}
}
func TestRunLimitedVariablesExec(t *testing.T) {
db, mock, err := sqlmock.New()
assertNoError(t, err, "Failed to make DB")
// Query and expect two queries to be executed
mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`).
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`).
WillReturnResult(sqlmock.NewResult(0, 0))
variables := []interface{}{
1, 2, 3, 4,
}
query := "DELETE FROM WHERE id IN ($1)"
if err = RunLimitedVariablesExec(context.Background(), query, db, variables, 2); err != nil {
t.Fatal(err)
}
// Query again, but only 3 parameters, still queries two times
mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`).
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\)`).
WillReturnResult(sqlmock.NewResult(0, 0))
if err = RunLimitedVariablesExec(context.Background(), query, db, variables[:3], 2); err != nil {
t.Fatal(err)
}
// Query again, but only 2 parameters, queries only once
mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`).
WillReturnResult(sqlmock.NewResult(0, 0))
if err = RunLimitedVariablesExec(context.Background(), query, db, variables[:2], 2); err != nil {
t.Fatal(err)
}
// Test with invalid query (typo) should return an error
mock.ExpectExec(`DELTE FROM`).
WillReturnResult(sqlmock.NewResult(0, 0)).
WillReturnError(errors.New("typo in query"))
if err = RunLimitedVariablesExec(context.Background(), "DELTE FROM", db, variables[:2], 2); err == nil {
t.Fatal("expected an error, but got none")
}
}
func assertNoError(t *testing.T, err error, msg string) {
t.Helper()
if err == nil {

View file

@ -17,7 +17,7 @@ var build string
const (
VersionMajor = 0
VersionMinor = 10
VersionPatch = 8
VersionPatch = 9
VersionTag = "" // example: "rc1"
)

View file

@ -151,6 +151,7 @@ type ClientRoomserverAPI interface {
PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error
PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error
PerformAdminPurgeRoom(ctx context.Context, req *PerformAdminPurgeRoomRequest, res *PerformAdminPurgeRoomResponse) error
PerformAdminDownloadState(ctx context.Context, req *PerformAdminDownloadStateRequest, res *PerformAdminDownloadStateResponse) error
PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error
PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error

View file

@ -137,6 +137,16 @@ func (t *RoomserverInternalAPITrace) PerformAdminEvacuateUser(
return err
}
func (t *RoomserverInternalAPITrace) PerformAdminPurgeRoom(
ctx context.Context,
req *PerformAdminPurgeRoomRequest,
res *PerformAdminPurgeRoomResponse,
) error {
err := t.Impl.PerformAdminPurgeRoom(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformAdminPurgeRoom req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) PerformAdminDownloadState(
ctx context.Context,
req *PerformAdminDownloadStateRequest,

View file

@ -55,6 +55,8 @@ const (
OutputTypeNewInboundPeek OutputType = "new_inbound_peek"
// OutputTypeRetirePeek indicates that the kafka event is an OutputRetirePeek
OutputTypeRetirePeek OutputType = "retire_peek"
// OutputTypePurgeRoom indicates the event is an OutputPurgeRoom
OutputTypePurgeRoom OutputType = "purge_room"
)
// An OutputEvent is an entry in the roomserver output kafka log.
@ -78,6 +80,8 @@ type OutputEvent struct {
NewInboundPeek *OutputNewInboundPeek `json:"new_inbound_peek,omitempty"`
// The content of event with type OutputTypeRetirePeek
RetirePeek *OutputRetirePeek `json:"retire_peek,omitempty"`
// The content of the event with type OutputPurgeRoom
PurgeRoom *OutputPurgeRoom `json:"purge_room,omitempty"`
}
// Type of the OutputNewRoomEvent.
@ -257,3 +261,7 @@ type OutputRetirePeek struct {
UserID string
DeviceID string
}
type OutputPurgeRoom struct {
RoomID string
}

View file

@ -241,6 +241,14 @@ type PerformAdminEvacuateUserResponse struct {
Error *PerformError
}
type PerformAdminPurgeRoomRequest struct {
RoomID string `json:"room_id"`
}
type PerformAdminPurgeRoomResponse struct {
Error *PerformError `json:"error,omitempty"`
}
type PerformAdminDownloadStateRequest struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`

View file

@ -28,6 +28,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
type Admin struct {
@ -242,6 +243,42 @@ func (r *Admin) PerformAdminEvacuateUser(
return nil
}
func (r *Admin) PerformAdminPurgeRoom(
ctx context.Context,
req *api.PerformAdminPurgeRoomRequest,
res *api.PerformAdminPurgeRoomResponse,
) error {
// Validate we actually got a room ID and nothing else
if _, _, err := gomatrixserverlib.SplitID('!', req.RoomID); err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Malformed room ID: %s", err),
}
return nil
}
logrus.WithField("room_id", req.RoomID).Warn("Purging room from roomserver")
if err := r.DB.PurgeRoom(ctx, req.RoomID); err != nil {
logrus.WithField("room_id", req.RoomID).WithError(err).Warn("Failed to purge room from roomserver")
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: err.Error(),
}
return nil
}
logrus.WithField("room_id", req.RoomID).Warn("Room purged from roomserver")
return r.Inputer.OutputProducer.ProduceRoomEvents(req.RoomID, []api.OutputEvent{
{
Type: api.OutputTypePurgeRoom,
PurgeRoom: &api.OutputPurgeRoom{
RoomID: req.RoomID,
},
},
})
}
func (r *Admin) PerformAdminDownloadState(
ctx context.Context,
req *api.PerformAdminDownloadStateRequest,

View file

@ -40,6 +40,7 @@ const (
RoomserverPerformAdminEvacuateRoomPath = "/roomserver/performAdminEvacuateRoom"
RoomserverPerformAdminEvacuateUserPath = "/roomserver/performAdminEvacuateUser"
RoomserverPerformAdminDownloadStatePath = "/roomserver/performAdminDownloadState"
RoomserverPerformAdminPurgeRoomPath = "/roomserver/performAdminPurgeRoom"
// Query operations
RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState"
@ -285,6 +286,17 @@ func (h *httpRoomserverInternalAPI) PerformAdminEvacuateUser(
)
}
func (h *httpRoomserverInternalAPI) PerformAdminPurgeRoom(
ctx context.Context,
request *api.PerformAdminPurgeRoomRequest,
response *api.PerformAdminPurgeRoomResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformAdminPurgeRoom", h.roomserverURL+RoomserverPerformAdminPurgeRoomPath,
h.httpClient, ctx, request, response,
)
}
// QueryLatestEventsAndState implements RoomserverQueryAPI
func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState(
ctx context.Context,

View file

@ -65,6 +65,11 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router, enableMe
httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", enableMetrics, r.PerformAdminEvacuateUser),
)
internalAPIMux.Handle(
RoomserverPerformAdminPurgeRoomPath,
httputil.MakeInternalRPCAPI("RoomserverPerformAdminPurgeRoom", enableMetrics, r.PerformAdminPurgeRoom),
)
internalAPIMux.Handle(
RoomserverPerformAdminDownloadStatePath,
httputil.MakeInternalRPCAPI("RoomserverPerformAdminDownloadState", enableMetrics, r.PerformAdminDownloadState),

View file

@ -14,6 +14,10 @@ import (
userAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/roomserver"
@ -223,3 +227,164 @@ func Test_QueryLeftUsers(t *testing.T) {
})
}
func TestPurgeRoom(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
// Invite Bob
inviteEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "invite",
}, test.WithStateKey(bob.ID))
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, db, close := mustCreateDatabase(t, dbType)
defer close()
jsCtx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsCtx, &base.Cfg.Global.JetStream)
fedClient := base.CreateFederationClient()
rsAPI := roomserver.NewInternalAPI(base)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
// this starts the JetStream consumers
syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI)
federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true)
rsAPI.SetFederationAPI(nil, nil)
// Create the room
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
}
// some dummy entries to validate after purging
publishResp := &api.PerformPublishResponse{}
if err := rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{RoomID: room.ID, Visibility: "public"}, publishResp); err != nil {
t.Fatal(err)
}
if publishResp.Error != nil {
t.Fatal(publishResp.Error)
}
isPublished, err := db.GetPublishedRoom(ctx, room.ID)
if err != nil {
t.Fatal(err)
}
if !isPublished {
t.Fatalf("room should be published before purging")
}
aliasResp := &api.SetRoomAliasResponse{}
if err = rsAPI.SetRoomAlias(ctx, &api.SetRoomAliasRequest{RoomID: room.ID, Alias: "myalias", UserID: alice.ID}, aliasResp); err != nil {
t.Fatal(err)
}
// check the alias is actually there
aliasesResp := &api.GetAliasesForRoomIDResponse{}
if err = rsAPI.GetAliasesForRoomID(ctx, &api.GetAliasesForRoomIDRequest{RoomID: room.ID}, aliasesResp); err != nil {
t.Fatal(err)
}
wantAliases := 1
if gotAliases := len(aliasesResp.Aliases); gotAliases != wantAliases {
t.Fatalf("expected %d aliases, got %d", wantAliases, gotAliases)
}
// validate the room exists before purging
roomInfo, err := db.RoomInfo(ctx, room.ID)
if err != nil {
t.Fatal(err)
}
if roomInfo == nil {
t.Fatalf("room does not exist")
}
// remember the roomInfo before purging
existingRoomInfo := roomInfo
// validate there is an invite for bob
nids, err := db.EventStateKeyNIDs(ctx, []string{bob.ID})
if err != nil {
t.Fatal(err)
}
bobNID, ok := nids[bob.ID]
if !ok {
t.Fatalf("%s does not exist", bob.ID)
}
_, inviteEventIDs, _, err := db.GetInvitesForUser(ctx, roomInfo.RoomNID, bobNID)
if err != nil {
t.Fatal(err)
}
wantInviteCount := 1
if inviteCount := len(inviteEventIDs); inviteCount != wantInviteCount {
t.Fatalf("expected there to be only %d invite events, got %d", wantInviteCount, inviteCount)
}
if inviteEventIDs[0] != inviteEvent.EventID() {
t.Fatalf("expected invite event ID %s, got %s", inviteEvent.EventID(), inviteEventIDs[0])
}
// purge the room from the database
purgeResp := &api.PerformAdminPurgeRoomResponse{}
if err = rsAPI.PerformAdminPurgeRoom(ctx, &api.PerformAdminPurgeRoomRequest{RoomID: room.ID}, purgeResp); err != nil {
t.Fatal(err)
}
// wait for all consumers to process the purge event
var sum = 1
timeout := time.Second * 5
deadline, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
for sum > 0 {
if deadline.Err() != nil {
t.Fatalf("test timed out after %s", timeout)
}
sum = 0
consumerCh := jsCtx.Consumers(base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent))
for x := range consumerCh {
sum += x.NumAckPending
}
time.Sleep(time.Millisecond)
}
roomInfo, err = db.RoomInfo(ctx, room.ID)
if err != nil {
t.Fatal(err)
}
if roomInfo != nil {
t.Fatalf("room should not exist after purging: %+v", roomInfo)
}
// validation below
// There should be no invite left
_, inviteEventIDs, _, err = db.GetInvitesForUser(ctx, existingRoomInfo.RoomNID, bobNID)
if err != nil {
t.Fatal(err)
}
if inviteCount := len(inviteEventIDs); inviteCount > 0 {
t.Fatalf("expected there to be only %d invite events, got %d", wantInviteCount, inviteCount)
}
// aliases should be deleted
aliases, err := db.GetAliasesForRoomID(ctx, room.ID)
if err != nil {
t.Fatal(err)
}
if aliasCount := len(aliases); aliasCount > 0 {
t.Fatalf("expected there to be only %d invite events, got %d", 0, aliasCount)
}
// published room should be deleted
isPublished, err = db.GetPublishedRoom(ctx, room.ID)
if err != nil {
t.Fatal(err)
}
if isPublished {
t.Fatalf("room should not be published after purging")
}
})
}

View file

@ -173,5 +173,6 @@ type Database interface {
GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error)
GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error)
PurgeRoom(ctx context.Context, roomID string) error
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
}

View file

@ -0,0 +1,133 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/types"
)
const purgeEventJSONSQL = "" +
"DELETE FROM roomserver_event_json WHERE event_nid = ANY(" +
" SELECT event_nid FROM roomserver_events WHERE room_nid = $1" +
")"
const purgeEventsSQL = "" +
"DELETE FROM roomserver_events WHERE room_nid = $1"
const purgeInvitesSQL = "" +
"DELETE FROM roomserver_invites WHERE room_nid = $1"
const purgeMembershipsSQL = "" +
"DELETE FROM roomserver_membership WHERE room_nid = $1"
const purgePreviousEventsSQL = "" +
"DELETE FROM roomserver_previous_events WHERE event_nids && ANY(" +
" SELECT ARRAY_AGG(event_nid) FROM roomserver_events WHERE room_nid = $1" +
")"
const purgePublishedSQL = "" +
"DELETE FROM roomserver_published WHERE room_id = $1"
const purgeRedactionsSQL = "" +
"DELETE FROM roomserver_redactions WHERE redaction_event_id = ANY(" +
" SELECT event_id FROM roomserver_events WHERE room_nid = $1" +
")"
const purgeRoomAliasesSQL = "" +
"DELETE FROM roomserver_room_aliases WHERE room_id = $1"
const purgeRoomSQL = "" +
"DELETE FROM roomserver_rooms WHERE room_nid = $1"
const purgeStateBlockEntriesSQL = "" +
"DELETE FROM roomserver_state_block WHERE state_block_nid = ANY(" +
" SELECT DISTINCT UNNEST(state_block_nids) FROM roomserver_state_snapshots WHERE room_nid = $1" +
")"
const purgeStateSnapshotEntriesSQL = "" +
"DELETE FROM roomserver_state_snapshots WHERE room_nid = $1"
type purgeStatements struct {
purgeEventJSONStmt *sql.Stmt
purgeEventsStmt *sql.Stmt
purgeInvitesStmt *sql.Stmt
purgeMembershipsStmt *sql.Stmt
purgePreviousEventsStmt *sql.Stmt
purgePublishedStmt *sql.Stmt
purgeRedactionStmt *sql.Stmt
purgeRoomAliasesStmt *sql.Stmt
purgeRoomStmt *sql.Stmt
purgeStateBlockEntriesStmt *sql.Stmt
purgeStateSnapshotEntriesStmt *sql.Stmt
}
func PreparePurgeStatements(db *sql.DB) (*purgeStatements, error) {
s := &purgeStatements{}
return s, sqlutil.StatementList{
{&s.purgeEventJSONStmt, purgeEventJSONSQL},
{&s.purgeEventsStmt, purgeEventsSQL},
{&s.purgeInvitesStmt, purgeInvitesSQL},
{&s.purgeMembershipsStmt, purgeMembershipsSQL},
{&s.purgePublishedStmt, purgePublishedSQL},
{&s.purgePreviousEventsStmt, purgePreviousEventsSQL},
{&s.purgeRedactionStmt, purgeRedactionsSQL},
{&s.purgeRoomAliasesStmt, purgeRoomAliasesSQL},
{&s.purgeRoomStmt, purgeRoomSQL},
{&s.purgeStateBlockEntriesStmt, purgeStateBlockEntriesSQL},
{&s.purgeStateSnapshotEntriesStmt, purgeStateSnapshotEntriesSQL},
}.Prepare(db)
}
func (s *purgeStatements) PurgeRoom(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string,
) error {
// purge by roomID
purgeByRoomID := []*sql.Stmt{
s.purgeRoomAliasesStmt,
s.purgePublishedStmt,
}
for _, stmt := range purgeByRoomID {
_, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomID)
if err != nil {
return err
}
}
// purge by roomNID
purgeByRoomNID := []*sql.Stmt{
s.purgeStateBlockEntriesStmt,
s.purgeStateSnapshotEntriesStmt,
s.purgeInvitesStmt,
s.purgeMembershipsStmt,
s.purgePreviousEventsStmt,
s.purgeEventJSONStmt,
s.purgeRedactionStmt,
s.purgeEventsStmt,
s.purgeRoomStmt,
}
for _, stmt := range purgeByRoomNID {
_, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomNID)
if err != nil {
return err
}
}
return nil
}

View file

@ -58,6 +58,9 @@ const insertRoomNIDSQL = "" +
const selectRoomNIDSQL = "" +
"SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
const selectRoomNIDForUpdateSQL = "" +
"SELECT room_nid FROM roomserver_rooms WHERE room_id = $1 FOR UPDATE"
const selectLatestEventNIDsSQL = "" +
"SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
@ -85,6 +88,7 @@ const bulkSelectRoomNIDsSQL = "" +
type roomStatements struct {
insertRoomNIDStmt *sql.Stmt
selectRoomNIDStmt *sql.Stmt
selectRoomNIDForUpdateStmt *sql.Stmt
selectLatestEventNIDsStmt *sql.Stmt
selectLatestEventNIDsForUpdateStmt *sql.Stmt
updateLatestEventNIDsStmt *sql.Stmt
@ -106,6 +110,7 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
return s, sqlutil.StatementList{
{&s.insertRoomNIDStmt, insertRoomNIDSQL},
{&s.selectRoomNIDStmt, selectRoomNIDSQL},
{&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL},
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
@ -169,6 +174,15 @@ func (s *roomStatements) SelectRoomNID(
return types.RoomNID(roomNID), err
}
func (s *roomStatements) SelectRoomNIDForUpdate(
ctx context.Context, txn *sql.Tx, roomID string,
) (types.RoomNID, error) {
var roomNID int64
stmt := sqlutil.TxStmt(txn, s.selectRoomNIDForUpdateStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
return types.RoomNID(roomNID), err
}
func (s *roomStatements) SelectLatestEventNIDs(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.EventNID, types.StateSnapshotNID, error) {

View file

@ -189,6 +189,10 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil {
return err
}
purge, err := PreparePurgeStatements(db)
if err != nil {
return err
}
d.Database = shared.Database{
DB: db,
Cache: cache,
@ -206,6 +210,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
MembershipTable: membership,
PublishedTable: published,
RedactionsTable: redactions,
Purge: purge,
}
return nil
}

View file

@ -43,6 +43,7 @@ type Database struct {
MembershipTable tables.Membership
PublishedTable tables.Published
RedactionsTable tables.Redactions
Purge tables.Purge
GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
}
@ -1445,6 +1446,21 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget
})
}
// PurgeRoom removes all information about a given room from the roomserver.
// For large rooms this operation may take a considerable amount of time.
func (d *Database) PurgeRoom(ctx context.Context, roomID string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
roomNID, err := d.RoomsTable.SelectRoomNIDForUpdate(ctx, txn, roomID)
if err != nil {
if err == sql.ErrNoRows {
return fmt.Errorf("room %s does not exist", roomID)
}
return fmt.Errorf("failed to lock the room: %w", err)
}
return d.Purge.PurgeRoom(ctx, txn, roomNID, roomID)
})
}
func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {

View file

@ -0,0 +1,153 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/types"
)
const purgeEventJSONSQL = "" +
"DELETE FROM roomserver_event_json WHERE event_nid IN (" +
" SELECT event_nid FROM roomserver_events WHERE room_nid = $1" +
")"
const purgeEventsSQL = "" +
"DELETE FROM roomserver_events WHERE room_nid = $1"
const purgeInvitesSQL = "" +
"DELETE FROM roomserver_invites WHERE room_nid = $1"
const purgeMembershipsSQL = "" +
"DELETE FROM roomserver_membership WHERE room_nid = $1"
const purgePreviousEventsSQL = "" +
"DELETE FROM roomserver_previous_events WHERE event_nids IN(" +
" SELECT event_nid FROM roomserver_events WHERE room_nid = $1" +
")"
const purgePublishedSQL = "" +
"DELETE FROM roomserver_published WHERE room_id = $1"
const purgeRedactionsSQL = "" +
"DELETE FROM roomserver_redactions WHERE redaction_event_id IN(" +
" SELECT event_id FROM roomserver_events WHERE room_nid = $1" +
")"
const purgeRoomAliasesSQL = "" +
"DELETE FROM roomserver_room_aliases WHERE room_id = $1"
const purgeRoomSQL = "" +
"DELETE FROM roomserver_rooms WHERE room_nid = $1"
const purgeStateSnapshotEntriesSQL = "" +
"DELETE FROM roomserver_state_snapshots WHERE room_nid = $1"
type purgeStatements struct {
purgeEventJSONStmt *sql.Stmt
purgeEventsStmt *sql.Stmt
purgeInvitesStmt *sql.Stmt
purgeMembershipsStmt *sql.Stmt
purgePreviousEventsStmt *sql.Stmt
purgePublishedStmt *sql.Stmt
purgeRedactionStmt *sql.Stmt
purgeRoomAliasesStmt *sql.Stmt
purgeRoomStmt *sql.Stmt
purgeStateSnapshotEntriesStmt *sql.Stmt
stateSnapshot *stateSnapshotStatements
}
func PreparePurgeStatements(db *sql.DB, stateSnapshot *stateSnapshotStatements) (*purgeStatements, error) {
s := &purgeStatements{stateSnapshot: stateSnapshot}
return s, sqlutil.StatementList{
{&s.purgeEventJSONStmt, purgeEventJSONSQL},
{&s.purgeEventsStmt, purgeEventsSQL},
{&s.purgeInvitesStmt, purgeInvitesSQL},
{&s.purgeMembershipsStmt, purgeMembershipsSQL},
{&s.purgePublishedStmt, purgePublishedSQL},
{&s.purgePreviousEventsStmt, purgePreviousEventsSQL},
{&s.purgeRedactionStmt, purgeRedactionsSQL},
{&s.purgeRoomAliasesStmt, purgeRoomAliasesSQL},
{&s.purgeRoomStmt, purgeRoomSQL},
//{&s.purgeStateBlockEntriesStmt, purgeStateBlockEntriesSQL},
{&s.purgeStateSnapshotEntriesStmt, purgeStateSnapshotEntriesSQL},
}.Prepare(db)
}
func (s *purgeStatements) PurgeRoom(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string,
) error {
// purge by roomID
purgeByRoomID := []*sql.Stmt{
s.purgeRoomAliasesStmt,
s.purgePublishedStmt,
}
for _, stmt := range purgeByRoomID {
_, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomID)
if err != nil {
return err
}
}
// purge by roomNID
if err := s.purgeStateBlocks(ctx, txn, roomNID); err != nil {
return err
}
purgeByRoomNID := []*sql.Stmt{
s.purgeStateSnapshotEntriesStmt,
s.purgeInvitesStmt,
s.purgeMembershipsStmt,
s.purgePreviousEventsStmt,
s.purgeEventJSONStmt,
s.purgeRedactionStmt,
s.purgeEventsStmt,
s.purgeRoomStmt,
}
for _, stmt := range purgeByRoomNID {
_, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomNID)
if err != nil {
return err
}
}
return nil
}
func (s *purgeStatements) purgeStateBlocks(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) error {
// Get all stateBlockNIDs
stateBlockNIDs, err := s.stateSnapshot.selectStateBlockNIDsForRoomNID(ctx, txn, roomNID)
if err != nil {
return err
}
params := make([]interface{}, len(stateBlockNIDs))
seenNIDs := make(map[types.StateBlockNID]struct{}, len(stateBlockNIDs))
// dedupe NIDs
for k, v := range stateBlockNIDs {
if _, ok := seenNIDs[v]; ok {
continue
}
params[k] = v
seenNIDs[v] = struct{}{}
}
query := "DELETE FROM roomserver_state_block WHERE state_block_nid IN($1)"
return sqlutil.RunLimitedVariablesExec(ctx, query, txn, params, sqlutil.SQLite3MaxVariables)
}

View file

@ -74,10 +74,14 @@ const bulkSelectRoomIDsSQL = "" +
const bulkSelectRoomNIDsSQL = "" +
"SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
const selectRoomNIDForUpdateSQL = "" +
"SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
type roomStatements struct {
db *sql.DB
insertRoomNIDStmt *sql.Stmt
selectRoomNIDStmt *sql.Stmt
selectRoomNIDForUpdateStmt *sql.Stmt
selectLatestEventNIDsStmt *sql.Stmt
selectLatestEventNIDsForUpdateStmt *sql.Stmt
updateLatestEventNIDsStmt *sql.Stmt
@ -105,6 +109,7 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
//{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL},
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
{&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL},
}.Prepare(db)
}
@ -169,6 +174,15 @@ func (s *roomStatements) SelectRoomNID(
return types.RoomNID(roomNID), err
}
func (s *roomStatements) SelectRoomNIDForUpdate(
ctx context.Context, txn *sql.Tx, roomID string,
) (types.RoomNID, error) {
var roomNID int64
stmt := sqlutil.TxStmt(txn, s.selectRoomNIDForUpdateStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
return types.RoomNID(roomNID), err
}
func (s *roomStatements) SelectLatestEventNIDs(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.EventNID, types.StateSnapshotNID, error) {

View file

@ -24,7 +24,6 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util"
)
@ -68,7 +67,7 @@ func CreateStateBlockTable(db *sql.DB) error {
return err
}
func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
func PrepareStateBlockTable(db *sql.DB) (*stateBlockStatements, error) {
s := &stateBlockStatements{
db: db,
}

View file

@ -62,10 +62,14 @@ const bulkSelectStateBlockNIDsSQL = "" +
"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
" WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
const selectStateBlockNIDsForRoomNID = "" +
"SELECT state_block_nids FROM roomserver_state_snapshots WHERE room_nid = $1"
type stateSnapshotStatements struct {
db *sql.DB
insertStateStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt *sql.Stmt
selectStateBlockNIDsStmt *sql.Stmt
}
func CreateStateSnapshotTable(db *sql.DB) error {
@ -73,7 +77,7 @@ func CreateStateSnapshotTable(db *sql.DB) error {
return err
}
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
func PrepareStateSnapshotTable(db *sql.DB) (*stateSnapshotStatements, error) {
s := &stateSnapshotStatements{
db: db,
}
@ -81,6 +85,7 @@ func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
return s, sqlutil.StatementList{
{&s.insertStateStmt, insertStateSQL},
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
{&s.selectStateBlockNIDsStmt, selectStateBlockNIDsForRoomNID},
}.Prepare(db)
}
@ -146,3 +151,29 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility(
) ([]types.EventNID, error) {
return nil, tables.OptimisationNotSupportedError
}
func (s *stateSnapshotStatements) selectStateBlockNIDsForRoomNID(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.StateBlockNID, error) {
var res []types.StateBlockNID
rows, err := sqlutil.TxStmt(txn, s.selectStateBlockNIDsStmt).QueryContext(ctx, roomNID)
if err != nil {
return res, nil
}
defer internal.CloseAndLogIfError(ctx, rows, "selectStateBlockNIDsForRoomNID: rows.close() failed")
var stateBlockNIDs []types.StateBlockNID
var stateBlockNIDsJSON string
for rows.Next() {
if err = rows.Scan(&stateBlockNIDsJSON); err != nil {
return nil, err
}
if err = json.Unmarshal([]byte(stateBlockNIDsJSON), &stateBlockNIDs); err != nil {
return nil, err
}
res = append(res, stateBlockNIDs...)
}
return res, rows.Err()
}

View file

@ -197,6 +197,11 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil {
return err
}
purge, err := PreparePurgeStatements(db, stateSnapshot)
if err != nil {
return err
}
d.Database = shared.Database{
DB: db,
Cache: cache,
@ -215,6 +220,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
PublishedTable: published,
RedactionsTable: redactions,
GetRoomUpdaterFn: d.GetRoomUpdater,
Purge: purge,
}
return nil
}

View file

@ -73,6 +73,7 @@ type Events interface {
type Rooms interface {
InsertRoomNID(ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion) (types.RoomNID, error)
SelectRoomNID(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error)
SelectRoomNIDForUpdate(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error)
SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error)
SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error)
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
@ -173,6 +174,12 @@ type Redactions interface {
MarkRedactionValidated(ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool) error
}
type Purge interface {
PurgeRoom(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string,
) error
}
// StrippedEvent represents a stripped event for returning extracted content values.
type StrippedEvent struct {
RoomID string

View file

@ -85,10 +85,10 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
c.RecaptchaApiJsUrl = "https://www.google.com/recaptcha/api.js"
}
if c.RecaptchaFormField == "" {
c.RecaptchaFormField = "g-recaptcha"
c.RecaptchaFormField = "g-recaptcha-response"
}
if c.RecaptchaSitekeyClass == "" {
c.RecaptchaSitekeyClass = "g-recaptcha-response"
c.RecaptchaSitekeyClass = "g-recaptcha"
}
checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey)
checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey)

View file

@ -23,6 +23,7 @@ import (
"github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@ -127,6 +128,12 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms
s.onRetirePeek(s.ctx, *output.RetirePeek)
case api.OutputTypeRedactedEvent:
err = s.onRedactEvent(s.ctx, *output.RedactedEvent)
case api.OutputTypePurgeRoom:
err = s.onPurgeRoom(s.ctx, *output.PurgeRoom)
if err != nil {
logrus.WithField("room_id", output.PurgeRoom.RoomID).WithError(err).Error("Failed to purge room from sync API")
return true // non-fatal, as otherwise we end up in a loop of trying to purge the room
}
default:
log.WithField("type", output.Type).Debug(
"roomserver output log: ignoring unknown output type",
@ -473,6 +480,20 @@ func (s *OutputRoomEventConsumer) onRetirePeek(
s.notifier.OnRetirePeek(msg.RoomID, msg.UserID, msg.DeviceID, types.StreamingToken{PDUPosition: sp})
}
func (s *OutputRoomEventConsumer) onPurgeRoom(
ctx context.Context, req api.OutputPurgeRoom,
) error {
logrus.WithField("room_id", req.RoomID).Warn("Purging room from sync API")
if err := s.db.PurgeRoom(ctx, req.RoomID); err != nil {
logrus.WithField("room_id", req.RoomID).WithError(err).Error("Failed to purge room from sync API")
return err
} else {
logrus.WithField("room_id", req.RoomID).Warn("Room purged from sync API")
return nil
}
}
func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.HeaderedEvent) (*gomatrixserverlib.HeaderedEvent, error) {
if event.StateKey() == nil {
return event, nil

View file

@ -16,16 +16,16 @@ package routing
import (
"encoding/json"
"math"
"net/http"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
type getMembershipResponse struct {
@ -87,21 +87,20 @@ func GetMemberships(
if err != nil {
return jsonerror.InternalServerError()
}
defer db.Rollback() // nolint: errcheck
atToken, err := types.NewTopologyTokenFromString(at)
if err != nil {
atToken = types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}
if queryRes.HasBeenInRoom && !queryRes.IsInRoom {
// If you have left the room then this will be the members of the room when you left.
atToken, err = db.EventPositionInTopology(req.Context(), queryRes.EventID)
} else {
// If you are joined to the room then this will be the current members of the room.
atToken, err = db.MaxTopologicalPosition(req.Context(), roomID)
}
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'")
return jsonerror.InternalServerError()
}
}
}
eventIDs, err := db.SelectMemberships(req.Context(), roomID, atToken, membership, notMembership)
if err != nil {

View file

@ -17,6 +17,7 @@ package routing
import (
"context"
"fmt"
"math"
"net/http"
"sort"
"time"
@ -57,7 +58,7 @@ type messagesResp struct {
StartStream string `json:"start_stream,omitempty"` // NOTSPEC: used by Cerulean, so clients can hit /messages then immediately /sync with a latest sync token
End string `json:"end,omitempty"`
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
State []gomatrixserverlib.ClientEvent `json:"state"`
State []gomatrixserverlib.ClientEvent `json:"state,omitempty"`
}
// OnIncomingMessagesRequest implements the /messages endpoint from the
@ -177,10 +178,11 @@ func OnIncomingMessagesRequest(
// If "to" isn't provided, it defaults to either the earliest stream
// position (if we're going backward) or to the latest one (if we're
// going forward).
to, err = setToDefault(req.Context(), snapshot, backwardOrdering, roomID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("setToDefault failed")
return jsonerror.InternalServerError()
to = types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}
if backwardOrdering {
// go 1 earlier than the first event so we correctly fetch the earliest event
// this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound.
to = types.TopologyToken{}
}
wasToProvided = false
}
@ -577,24 +579,3 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
return events, nil
}
// setToDefault returns the default value for the "to" query parameter of a
// request to /messages if not provided. It defaults to either the earliest
// topological position (if we're going backward) or to the latest one (if we're
// going forward).
// Returns an error if there was an issue with retrieving the latest position
// from the database
func setToDefault(
ctx context.Context, snapshot storage.DatabaseTransaction, backwardOrdering bool,
roomID string,
) (to types.TopologyToken, err error) {
if backwardOrdering {
// go 1 earlier than the first event so we correctly fetch the earliest event
// this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound.
to = types.TopologyToken{}
} else {
to, err = snapshot.MaxTopologicalPosition(ctx, roomID)
}
return
}

View file

@ -84,8 +84,6 @@ type DatabaseTransaction interface {
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error)
// MaxTopologicalPosition returns the highest topological position for a given room.
MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error)
// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
// matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event.
@ -134,6 +132,8 @@ type Database interface {
// PurgeRoomState completely purges room state from the sync API. This is done when
// receiving an output event that completely resets the state.
PurgeRoomState(ctx context.Context, roomID string) error
// PurgeRoom entirely eliminates a room from the sync API, timeline, state and all.
PurgeRoom(ctx context.Context, roomID string) error
// UpsertAccountData keeps track of new or updated account data, by saving the type
// of the new/updated data, and the user ID and room ID the data is related to (empty)
// room ID means the data isn't specific to any room)

View file

@ -47,10 +47,14 @@ const selectBackwardExtremitiesForRoomSQL = "" +
const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
const purgeBackwardExtremitiesSQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
type backwardExtremitiesStatements struct {
insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt
purgeBackwardExtremitiesStmt *sql.Stmt
}
func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
@ -59,16 +63,12 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti
if err != nil {
return nil, err
}
if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil {
return nil, err
}
if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil {
return nil, err
}
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.insertBackwardExtremityStmt, insertBackwardExtremitySQL},
{&s.selectBackwardExtremitiesForRoomStmt, selectBackwardExtremitiesForRoomSQL},
{&s.deleteBackwardExtremityStmt, deleteBackwardExtremitySQL},
{&s.purgeBackwardExtremitiesStmt, purgeBackwardExtremitiesSQL},
}.Prepare(db)
}
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
@ -106,3 +106,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return
}
func (s *backwardExtremitiesStatements) PurgeBackwardExtremities(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.purgeBackwardExtremitiesStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -62,11 +62,15 @@ const selectInviteEventsInRangeSQL = "" +
const selectMaxInviteIDSQL = "" +
"SELECT MAX(id) FROM syncapi_invite_events"
const purgeInvitesSQL = "" +
"DELETE FROM syncapi_invite_events WHERE room_id = $1"
type inviteEventsStatements struct {
insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt
deleteInviteEventStmt *sql.Stmt
selectMaxInviteIDStmt *sql.Stmt
purgeInvitesStmt *sql.Stmt
}
func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) {
@ -75,19 +79,13 @@ func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) {
if err != nil {
return nil, err
}
if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil {
return nil, err
}
if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil {
return nil, err
}
if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil {
return nil, err
}
if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.insertInviteEventStmt, insertInviteEventSQL},
{&s.selectInviteEventsInRangeStmt, selectInviteEventsInRangeSQL},
{&s.deleteInviteEventStmt, deleteInviteEventSQL},
{&s.selectMaxInviteIDStmt, selectMaxInviteIDSQL},
{&s.purgeInvitesStmt, purgeInvitesSQL},
}.Prepare(db)
}
func (s *inviteEventsStatements) InsertInviteEvent(
@ -181,3 +179,10 @@ func (s *inviteEventsStatements) SelectMaxInviteID(
}
return
}
func (s *inviteEventsStatements) PurgeInvites(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.purgeInvitesStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -65,6 +65,9 @@ const selectMembershipCountSQL = "" +
const selectMembershipBeforeSQL = "" +
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
const purgeMembershipsSQL = "" +
"DELETE FROM syncapi_memberships WHERE room_id = $1"
const selectMembersSQL = `
SELECT event_id FROM (
SELECT DISTINCT ON (room_id, user_id) room_id, user_id, event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC
@ -77,6 +80,7 @@ type membershipsStatements struct {
upsertMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
selectMembershipForUserStmt *sql.Stmt
purgeMembershipsStmt *sql.Stmt
selectMembersStmt *sql.Stmt
}
@ -90,6 +94,7 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
{&s.upsertMembershipStmt, upsertMembershipSQL},
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
{&s.purgeMembershipsStmt, purgeMembershipsSQL},
{&s.selectMembersStmt, selectMembersSQL},
}.Prepare(db)
}
@ -139,6 +144,13 @@ func (s *membershipsStatements) SelectMembershipForUser(
return membership, topologyPos, nil
}
func (s *membershipsStatements) PurgeMemberships(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.purgeMembershipsStmt).ExecContext(ctx, roomID)
return err
}
func (s *membershipsStatements) SelectMemberships(
ctx context.Context, txn *sql.Tx,
roomID string, pos types.TopologyToken,

View file

@ -37,6 +37,7 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms},
{&r.selectMaxID, selectMaxNotificationIDSQL},
{&r.purgeNotificationData, purgeNotificationDataSQL},
}.Prepare(db)
}
@ -44,6 +45,7 @@ type notificationDataStatements struct {
upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCountsForRooms *sql.Stmt
selectMaxID *sql.Stmt
purgeNotificationData *sql.Stmt
}
const notificationDataSchema = `
@ -70,6 +72,9 @@ const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_coun
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
const purgeNotificationDataSQL = "" +
"DELETE FROM syncapi_notification_data WHERE room_id = $1"
func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
err = sqlutil.TxStmt(txn, r.upsertRoomUnreadCounts).QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos)
return
@ -106,3 +111,10 @@ func (r *notificationDataStatements) SelectMaxID(ctx context.Context, txn *sql.T
err := sqlutil.TxStmt(txn, r.selectMaxID).QueryRowContext(ctx).Scan(&id)
return id, err
}
func (s *notificationDataStatements) PurgeNotificationData(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.purgeNotificationData).ExecContext(ctx, roomID)
return err
}

View file

@ -176,6 +176,9 @@ const selectContextAfterEventSQL = "" +
" AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" +
" ORDER BY id ASC LIMIT $3"
const purgeEventsSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1"
const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type = ANY($2) ORDER BY id ASC LIMIT $3"
type outputRoomEventsStatements struct {
@ -193,6 +196,7 @@ type outputRoomEventsStatements struct {
selectContextEventStmt *sql.Stmt
selectContextBeforeEventStmt *sql.Stmt
selectContextAfterEventStmt *sql.Stmt
purgeEventsStmt *sql.Stmt
selectSearchStmt *sql.Stmt
}
@ -230,6 +234,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
{&s.selectContextEventStmt, selectContextEventSQL},
{&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL},
{&s.selectContextAfterEventStmt, selectContextAfterEventSQL},
{&s.purgeEventsStmt, purgeEventsSQL},
{&s.selectSearchStmt, selectSearchSQL},
}.Prepare(db)
}
@ -658,6 +663,13 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
return result, rows.Err()
}
func (s *outputRoomEventsStatements) PurgeEvents(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.purgeEventsStmt).ExecContext(ctx, roomID)
return err
}
func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) {
rows, err := sqlutil.TxStmt(txn, s.selectSearchStmt).QueryContext(ctx, afterID, pq.StringArray(types), limit)
if err != nil {

View file

@ -18,11 +18,12 @@ import (
"context"
"database/sql"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
const outputRoomEventsTopologySchema = `
@ -65,28 +66,23 @@ const selectPositionInTopologySQL = "" +
"SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
" WHERE event_id = $1"
// Select the max topological position for the room, then sort by stream position and take the highest,
// returning both topological and stream positions.
const selectMaxPositionInTopologySQL = "" +
"SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
" WHERE topological_position=(" +
"SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" +
") ORDER BY stream_position DESC LIMIT 1"
const selectStreamToTopologicalPositionAscSQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;"
const selectStreamToTopologicalPositionDescSQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position <= $2 ORDER BY topological_position DESC LIMIT 1;"
const purgeEventsTopologySQL = "" +
"DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
type outputRoomEventsTopologyStatements struct {
insertEventInTopologyStmt *sql.Stmt
selectEventIDsInRangeASCStmt *sql.Stmt
selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt
selectStreamToTopologicalPositionAscStmt *sql.Stmt
selectStreamToTopologicalPositionDescStmt *sql.Stmt
purgeEventsTopologyStmt *sql.Stmt
}
func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) {
@ -95,28 +91,15 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) {
if err != nil {
return nil, err
}
if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil {
return nil, err
}
if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil {
return nil, err
}
if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil {
return nil, err
}
if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil {
return nil, err
}
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
return nil, err
}
if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil {
return nil, err
}
if s.selectStreamToTopologicalPositionDescStmt, err = db.Prepare(selectStreamToTopologicalPositionDescSQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.insertEventInTopologyStmt, insertEventInTopologySQL},
{&s.selectEventIDsInRangeASCStmt, selectEventIDsInRangeASCSQL},
{&s.selectEventIDsInRangeDESCStmt, selectEventIDsInRangeDESCSQL},
{&s.selectPositionInTopologyStmt, selectPositionInTopologySQL},
{&s.selectStreamToTopologicalPositionAscStmt, selectStreamToTopologicalPositionAscSQL},
{&s.selectStreamToTopologicalPositionDescStmt, selectStreamToTopologicalPositionDescSQL},
{&s.purgeEventsTopologyStmt, purgeEventsTopologySQL},
}.Prepare(db)
}
// InsertEventInTopology inserts the given event in the room's topology, based
@ -190,9 +173,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition(
return
}
func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
func (s *outputRoomEventsTopologyStatements) PurgeEventsTopology(
ctx context.Context, txn *sql.Tx, roomID string,
) (pos types.StreamPosition, spos types.StreamPosition, err error) {
err = sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt).QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return
) error {
_, err := sqlutil.TxStmt(txn, s.purgeEventsTopologyStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -65,6 +65,9 @@ const selectPeekingDevicesSQL = "" +
const selectMaxPeekIDSQL = "" +
"SELECT MAX(id) FROM syncapi_peeks"
const purgePeeksSQL = "" +
"DELETE FROM syncapi_peeks WHERE room_id = $1"
type peekStatements struct {
db *sql.DB
insertPeekStmt *sql.Stmt
@ -73,6 +76,7 @@ type peekStatements struct {
selectPeeksInRangeStmt *sql.Stmt
selectPeekingDevicesStmt *sql.Stmt
selectMaxPeekIDStmt *sql.Stmt
purgePeeksStmt *sql.Stmt
}
func NewPostgresPeeksTable(db *sql.DB) (tables.Peeks, error) {
@ -83,25 +87,15 @@ func NewPostgresPeeksTable(db *sql.DB) (tables.Peeks, error) {
s := &peekStatements{
db: db,
}
if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil {
return nil, err
}
if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil {
return nil, err
}
if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil {
return nil, err
}
if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil {
return nil, err
}
if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil {
return nil, err
}
if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.insertPeekStmt, insertPeekSQL},
{&s.deletePeekStmt, deletePeekSQL},
{&s.deletePeeksStmt, deletePeeksSQL},
{&s.selectPeeksInRangeStmt, selectPeeksInRangeSQL},
{&s.selectPeekingDevicesStmt, selectPeekingDevicesSQL},
{&s.selectMaxPeekIDStmt, selectMaxPeekIDSQL},
{&s.purgePeeksStmt, purgePeeksSQL},
}.Prepare(db)
}
func (s *peekStatements) InsertPeek(
@ -184,3 +178,10 @@ func (s *peekStatements) SelectMaxPeekID(
}
return
}
func (s *peekStatements) PurgePeeks(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.purgePeeksStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -62,11 +62,15 @@ const selectRoomReceipts = "" +
const selectMaxReceiptIDSQL = "" +
"SELECT MAX(id) FROM syncapi_receipts"
const purgeReceiptsSQL = "" +
"DELETE FROM syncapi_receipts WHERE room_id = $1"
type receiptStatements struct {
db *sql.DB
upsertReceipt *sql.Stmt
selectRoomReceipts *sql.Stmt
selectMaxReceiptID *sql.Stmt
purgeReceiptsStmt *sql.Stmt
}
func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) {
@ -86,16 +90,12 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) {
r := &receiptStatements{
db: db,
}
if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil {
return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err)
}
if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
}
if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
}
return r, nil
return r, sqlutil.StatementList{
{&r.upsertReceipt, upsertReceipt},
{&r.selectRoomReceipts, selectRoomReceipts},
{&r.selectMaxReceiptID, selectMaxReceiptIDSQL},
{&r.purgeReceiptsStmt, purgeReceiptsSQL},
}.Prepare(db)
}
func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) {
@ -138,3 +138,10 @@ func (s *receiptStatements) SelectMaxReceiptID(
}
return
}
func (s *receiptStatements) PurgeReceipts(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.purgeReceiptsStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -242,20 +242,6 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
return nil
}
func (d *Database) PurgeRoomState(
ctx context.Context, roomID string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// If the event is a create event then we'll delete all of the existing
// data for the room. The only reason that a create event would be replayed
// to us in this way is if we're about to receive the entire room state.
if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err)
}
return nil
})
}
func (d *Database) WriteEvent(
ctx context.Context,
ev *gomatrixserverlib.HeaderedEvent,

View file

@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"math"
"github.com/matrix-org/gomatrixserverlib"
"github.com/tidwall/gjson"
@ -269,16 +270,6 @@ func (d *DatabaseTransaction) BackwardExtremitiesForRoom(
return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID)
}
func (d *DatabaseTransaction) MaxTopologicalPosition(
ctx context.Context, roomID string,
) (types.TopologyToken, error) {
depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID)
if err != nil {
return types.TopologyToken{}, err
}
return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil
}
func (d *DatabaseTransaction) EventPositionInTopology(
ctx context.Context, eventID string,
) (types.TopologyToken, error) {
@ -297,11 +288,7 @@ func (d *DatabaseTransaction) StreamToTopologicalPosition(
case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward
return types.TopologyToken{PDUPosition: streamPos}, nil
case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward
topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID)
if err != nil {
return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err)
}
return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil
return types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, nil
case err != nil: // some other error happened
return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err)
default:
@ -662,6 +649,53 @@ func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context)
return d.Presence.GetMaxPresenceID(ctx, d.txn)
}
func (d *Database) PurgeRoom(ctx context.Context, roomID string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.BackwardExtremities.PurgeBackwardExtremities(ctx, txn, roomID); err != nil {
return fmt.Errorf("failed to purge backward extremities: %w", err)
}
if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("failed to purge current room state: %w", err)
}
if err := d.Invites.PurgeInvites(ctx, txn, roomID); err != nil {
return fmt.Errorf("failed to purge invites: %w", err)
}
if err := d.Memberships.PurgeMemberships(ctx, txn, roomID); err != nil {
return fmt.Errorf("failed to purge memberships: %w", err)
}
if err := d.NotificationData.PurgeNotificationData(ctx, txn, roomID); err != nil {
return fmt.Errorf("failed to purge notification data: %w", err)
}
if err := d.OutputEvents.PurgeEvents(ctx, txn, roomID); err != nil {
return fmt.Errorf("failed to purge events: %w", err)
}
if err := d.Topology.PurgeEventsTopology(ctx, txn, roomID); err != nil {
return fmt.Errorf("failed to purge events topology: %w", err)
}
if err := d.Peeks.PurgePeeks(ctx, txn, roomID); err != nil {
return fmt.Errorf("failed to purge peeks: %w", err)
}
if err := d.Receipts.PurgeReceipts(ctx, txn, roomID); err != nil {
return fmt.Errorf("failed to purge receipts: %w", err)
}
return nil
})
}
func (d *Database) PurgeRoomState(
ctx context.Context, roomID string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// If the event is a create event then we'll delete all of the existing
// data for the room. The only reason that a create event would be replayed
// to us in this way is if we're about to receive the entire room state.
if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err)
}
return nil
})
}
func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) {
id, err := d.Relations.SelectMaxRelationID(ctx, d.txn)
return types.StreamPosition(id), err

View file

@ -47,11 +47,15 @@ const selectBackwardExtremitiesForRoomSQL = "" +
const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
const purgeBackwardExtremitiesSQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
type backwardExtremitiesStatements struct {
db *sql.DB
insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt
purgeBackwardExtremitiesStmt *sql.Stmt
}
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
@ -62,16 +66,12 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
if err != nil {
return nil, err
}
if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil {
return nil, err
}
if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil {
return nil, err
}
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.insertBackwardExtremityStmt, insertBackwardExtremitySQL},
{&s.selectBackwardExtremitiesForRoomStmt, selectBackwardExtremitiesForRoomSQL},
{&s.deleteBackwardExtremityStmt, deleteBackwardExtremitySQL},
{&s.purgeBackwardExtremitiesStmt, purgeBackwardExtremitiesSQL},
}.Prepare(db)
}
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
@ -109,3 +109,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return err
}
func (s *backwardExtremitiesStatements) PurgeBackwardExtremities(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.purgeBackwardExtremitiesStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -57,6 +57,9 @@ const selectInviteEventsInRangeSQL = "" +
const selectMaxInviteIDSQL = "" +
"SELECT MAX(id) FROM syncapi_invite_events"
const purgeInvitesSQL = "" +
"DELETE FROM syncapi_invite_events WHERE room_id = $1"
type inviteEventsStatements struct {
db *sql.DB
streamIDStatements *StreamIDStatements
@ -64,6 +67,7 @@ type inviteEventsStatements struct {
selectInviteEventsInRangeStmt *sql.Stmt
deleteInviteEventStmt *sql.Stmt
selectMaxInviteIDStmt *sql.Stmt
purgeInvitesStmt *sql.Stmt
}
func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) {
@ -75,19 +79,13 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Inv
if err != nil {
return nil, err
}
if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil {
return nil, err
}
if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil {
return nil, err
}
if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil {
return nil, err
}
if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.insertInviteEventStmt, insertInviteEventSQL},
{&s.selectInviteEventsInRangeStmt, selectInviteEventsInRangeSQL},
{&s.deleteInviteEventStmt, deleteInviteEventSQL},
{&s.selectMaxInviteIDStmt, selectMaxInviteIDSQL},
{&s.purgeInvitesStmt, purgeInvitesSQL},
}.Prepare(db)
}
func (s *inviteEventsStatements) InsertInviteEvent(
@ -192,3 +190,10 @@ func (s *inviteEventsStatements) SelectMaxInviteID(
}
return
}
func (s *inviteEventsStatements) PurgeInvites(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.purgeInvitesStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -72,6 +72,9 @@ SELECT event_id FROM
AND ($4 IS NULL OR t.membership <> $4)
`
const purgeMembershipsSQL = "" +
"DELETE FROM syncapi_memberships WHERE room_id = $1"
type membershipsStatements struct {
db *sql.DB
upsertMembershipStmt *sql.Stmt
@ -79,6 +82,7 @@ type membershipsStatements struct {
//selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic
selectMembershipForUserStmt *sql.Stmt
selectMembersStmt *sql.Stmt
purgeMembershipsStmt *sql.Stmt
}
func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
@ -94,6 +98,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
{&s.selectMembersStmt, selectMembersSQL},
{&s.purgeMembershipsStmt, purgeMembershipsSQL},
}.Prepare(db)
}
@ -142,6 +147,13 @@ func (s *membershipsStatements) SelectMembershipForUser(
return membership, topologyPos, nil
}
func (s *membershipsStatements) PurgeMemberships(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.purgeMembershipsStmt).ExecContext(ctx, roomID)
return err
}
func (s *membershipsStatements) SelectMemberships(
ctx context.Context, txn *sql.Tx,
roomID string, pos types.TopologyToken,

View file

@ -38,6 +38,7 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t
return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectMaxID, selectMaxNotificationIDSQL},
{&r.purgeNotificationData, purgeNotificationDataSQL},
// {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime
}.Prepare(db)
}
@ -47,6 +48,7 @@ type notificationDataStatements struct {
streamIDStatements *StreamIDStatements
upsertRoomUnreadCounts *sql.Stmt
selectMaxID *sql.Stmt
purgeNotificationData *sql.Stmt
//selectUserUnreadCountsForRooms *sql.Stmt
}
@ -73,6 +75,9 @@ const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_coun
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
const purgeNotificationDataSQL = "" +
"DELETE FROM syncapi_notification_data WHERE room_id = $1"
func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
pos, err = r.streamIDStatements.nextNotificationID(ctx, nil)
if err != nil {
@ -124,3 +129,10 @@ func (r *notificationDataStatements) SelectMaxID(ctx context.Context, txn *sql.T
err := sqlutil.TxStmt(txn, r.selectMaxID).QueryRowContext(ctx).Scan(&id)
return id, err
}
func (s *notificationDataStatements) PurgeNotificationData(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.purgeNotificationData).ExecContext(ctx, roomID)
return err
}

View file

@ -120,6 +120,9 @@ const selectContextAfterEventSQL = "" +
const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE type IN ($1) AND id > $2 LIMIT $3 ORDER BY id ASC"
const purgeEventsSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1"
type outputRoomEventsStatements struct {
db *sql.DB
streamIDStatements *StreamIDStatements
@ -130,6 +133,7 @@ type outputRoomEventsStatements struct {
selectContextEventStmt *sql.Stmt
selectContextBeforeEventStmt *sql.Stmt
selectContextAfterEventStmt *sql.Stmt
purgeEventsStmt *sql.Stmt
//selectSearchStmt *sql.Stmt - prepared at runtime
}
@ -163,6 +167,7 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even
{&s.selectContextEventStmt, selectContextEventSQL},
{&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL},
{&s.selectContextAfterEventStmt, selectContextAfterEventSQL},
{&s.purgeEventsStmt, purgeEventsSQL},
//{&s.selectSearchStmt, selectSearchSQL}, - prepared at runtime
}.Prepare(db)
}
@ -666,6 +671,13 @@ func unmarshalStateIDs(addIDsJSON, delIDsJSON string) (addIDs []string, delIDs [
return
}
func (s *outputRoomEventsStatements) PurgeEvents(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.purgeEventsStmt).ExecContext(ctx, roomID)
return err
}
func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) {
params := make([]interface{}, len(types))
for i := range types {

View file

@ -18,10 +18,11 @@ import (
"context"
"database/sql"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
const outputRoomEventsTopologySchema = `
@ -61,25 +62,24 @@ const selectPositionInTopologySQL = "" +
"SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
" WHERE event_id = $1"
const selectMaxPositionInTopologySQL = "" +
"SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 ORDER BY stream_position DESC"
const selectStreamToTopologicalPositionAscSQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;"
const selectStreamToTopologicalPositionDescSQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position <= $2 ORDER BY topological_position DESC LIMIT 1;"
const purgeEventsTopologySQL = "" +
"DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
type outputRoomEventsTopologyStatements struct {
db *sql.DB
insertEventInTopologyStmt *sql.Stmt
selectEventIDsInRangeASCStmt *sql.Stmt
selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt
selectStreamToTopologicalPositionAscStmt *sql.Stmt
selectStreamToTopologicalPositionDescStmt *sql.Stmt
purgeEventsTopologyStmt *sql.Stmt
}
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
@ -90,28 +90,15 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
if err != nil {
return nil, err
}
if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil {
return nil, err
}
if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil {
return nil, err
}
if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil {
return nil, err
}
if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil {
return nil, err
}
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
return nil, err
}
if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil {
return nil, err
}
if s.selectStreamToTopologicalPositionDescStmt, err = db.Prepare(selectStreamToTopologicalPositionDescSQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.insertEventInTopologyStmt, insertEventInTopologySQL},
{&s.selectEventIDsInRangeASCStmt, selectEventIDsInRangeASCSQL},
{&s.selectEventIDsInRangeDESCStmt, selectEventIDsInRangeDESCSQL},
{&s.selectPositionInTopologyStmt, selectPositionInTopologySQL},
{&s.selectStreamToTopologicalPositionAscStmt, selectStreamToTopologicalPositionAscSQL},
{&s.selectStreamToTopologicalPositionDescStmt, selectStreamToTopologicalPositionDescSQL},
{&s.purgeEventsTopologyStmt, purgeEventsTopologySQL},
}.Prepare(db)
}
// insertEventInTopology inserts the given event in the room's topology, based
@ -183,10 +170,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition(
return
}
func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
func (s *outputRoomEventsTopologyStatements) PurgeEventsTopology(
ctx context.Context, txn *sql.Tx, roomID string,
) (pos types.StreamPosition, spos types.StreamPosition, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt)
err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return
) error {
_, err := sqlutil.TxStmt(txn, s.purgeEventsTopologyStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -64,6 +64,9 @@ const selectPeekingDevicesSQL = "" +
const selectMaxPeekIDSQL = "" +
"SELECT MAX(id) FROM syncapi_peeks"
const purgePeeksSQL = "" +
"DELETE FROM syncapi_peeks WHERE room_id = $1"
type peekStatements struct {
db *sql.DB
streamIDStatements *StreamIDStatements
@ -73,6 +76,7 @@ type peekStatements struct {
selectPeeksInRangeStmt *sql.Stmt
selectPeekingDevicesStmt *sql.Stmt
selectMaxPeekIDStmt *sql.Stmt
purgePeeksStmt *sql.Stmt
}
func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) {
@ -84,25 +88,15 @@ func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks
db: db,
streamIDStatements: streamID,
}
if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil {
return nil, err
}
if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil {
return nil, err
}
if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil {
return nil, err
}
if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil {
return nil, err
}
if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil {
return nil, err
}
if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.insertPeekStmt, insertPeekSQL},
{&s.deletePeekStmt, deletePeekSQL},
{&s.deletePeeksStmt, deletePeeksSQL},
{&s.selectPeeksInRangeStmt, selectPeeksInRangeSQL},
{&s.selectPeekingDevicesStmt, selectPeekingDevicesSQL},
{&s.selectMaxPeekIDStmt, selectMaxPeekIDSQL},
{&s.purgePeeksStmt, purgePeeksSQL},
}.Prepare(db)
}
func (s *peekStatements) InsertPeek(
@ -204,3 +198,10 @@ func (s *peekStatements) SelectMaxPeekID(
}
return
}
func (s *peekStatements) PurgePeeks(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.purgePeeksStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -58,12 +58,16 @@ const selectRoomReceipts = "" +
const selectMaxReceiptIDSQL = "" +
"SELECT MAX(id) FROM syncapi_receipts"
const purgeReceiptsSQL = "" +
"DELETE FROM syncapi_receipts WHERE room_id = $1"
type receiptStatements struct {
db *sql.DB
streamIDStatements *StreamIDStatements
upsertReceipt *sql.Stmt
selectRoomReceipts *sql.Stmt
selectMaxReceiptID *sql.Stmt
purgeReceiptsStmt *sql.Stmt
}
func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) {
@ -84,16 +88,12 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Re
db: db,
streamIDStatements: streamID,
}
if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil {
return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err)
}
if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
}
if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil {
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
}
return r, nil
return r, sqlutil.StatementList{
{&r.upsertReceipt, upsertReceipt},
{&r.selectRoomReceipts, selectRoomReceipts},
{&r.selectMaxReceiptID, selectMaxReceiptIDSQL},
{&r.purgeReceiptsStmt, purgeReceiptsSQL},
}.Prepare(db)
}
// UpsertReceipt creates new user receipts
@ -153,3 +153,10 @@ func (s *receiptStatements) SelectMaxReceiptID(
}
return
}
func (s *receiptStatements) PurgeReceipts(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.purgeReceiptsStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"math"
"reflect"
"testing"
@ -199,10 +200,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
_ = MustWriteEvents(t, db, events)
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
from, err := snapshot.MaxTopologicalPosition(ctx, r.ID)
if err != nil {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
}
from := types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}
t.Logf("max topo pos = %+v", from)
// head towards the beginning of time
to := types.TopologyToken{}
@ -219,6 +217,88 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
})
}
func TestStreamToTopologicalPosition(t *testing.T) {
alice := test.NewUser(t)
r := test.NewRoom(t, alice)
testCases := []struct {
name string
roomID string
streamPos types.StreamPosition
backwardOrdering bool
wantToken types.TopologyToken
}{
{
name: "forward ordering found streamPos returns found position",
roomID: r.ID,
streamPos: 1,
backwardOrdering: false,
wantToken: types.TopologyToken{Depth: 1, PDUPosition: 1},
},
{
name: "forward ordering not found streamPos returns max position",
roomID: r.ID,
streamPos: 100,
backwardOrdering: false,
wantToken: types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64},
},
{
name: "backward ordering found streamPos returns found position",
roomID: r.ID,
streamPos: 1,
backwardOrdering: true,
wantToken: types.TopologyToken{Depth: 1, PDUPosition: 1},
},
{
name: "backward ordering not found streamPos returns maxDepth with param pduPosition",
roomID: r.ID,
streamPos: 100,
backwardOrdering: true,
wantToken: types.TopologyToken{Depth: 5, PDUPosition: 100},
},
{
name: "backward non-existent room returns zero token",
roomID: "!doesnotexist:localhost",
streamPos: 1,
backwardOrdering: true,
wantToken: types.TopologyToken{Depth: 0, PDUPosition: 1},
},
{
name: "forward non-existent room returns max token",
roomID: "!doesnotexist:localhost",
streamPos: 1,
backwardOrdering: false,
wantToken: types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64},
},
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close, closeBase := MustCreateDatabase(t, dbType)
defer close()
defer closeBase()
txn, err := db.NewDatabaseTransaction(ctx)
if err != nil {
t.Fatal(err)
}
defer txn.Rollback()
MustWriteEvents(t, db, r.Events())
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
token, err := txn.StreamToTopologicalPosition(ctx, tc.roomID, tc.streamPos, tc.backwardOrdering)
if err != nil {
t.Fatal(err)
}
if tc.wantToken != token {
t.Fatalf("expected token %q, got %q", tc.wantToken, token)
}
})
}
})
}
/*
// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth.
// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent

View file

@ -39,6 +39,7 @@ type Invites interface {
// for the room.
SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, maxID types.StreamPosition, err error)
SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error)
PurgeInvites(ctx context.Context, txn *sql.Tx, roomID string) error
}
type Peeks interface {
@ -48,6 +49,7 @@ type Peeks interface {
SelectPeeksInRange(ctxt context.Context, txn *sql.Tx, userID, deviceID string, r types.Range) (peeks []types.Peek, err error)
SelectPeekingDevices(ctxt context.Context, txn *sql.Tx) (peekingDevices map[string][]types.PeekingDevice, err error)
SelectMaxPeekID(ctx context.Context, txn *sql.Tx) (id int64, err error)
PurgePeeks(ctx context.Context, txn *sql.Tx, roomID string) error
}
type Events interface {
@ -75,6 +77,8 @@ type Events interface {
SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
PurgeEvents(ctx context.Context, txn *sql.Tx, roomID string) error
ReIndex(ctx context.Context, txn *sql.Tx, limit, offset int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error)
}
@ -91,10 +95,9 @@ type Topology interface {
SelectEventIDsInRange(ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, err error)
// SelectPositionInTopology returns the depth and stream position of a given event in the topology of the room it belongs to.
SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error)
// SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position.
SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error)
// SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room.
SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error)
PurgeEventsTopology(ctx context.Context, txn *sql.Tx, roomID string) error
}
type CurrentRoomState interface {
@ -148,6 +151,7 @@ type BackwardsExtremities interface {
SelectBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (bwExtrems map[string][]string, err error)
// DeleteBackwardExtremity removes a backwards extremity for a room, if one existed.
DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error)
PurgeBackwardExtremities(ctx context.Context, txn *sql.Tx, roomID string) error
}
// SendToDevice tracks send-to-device messages which are sent to individual
@ -183,12 +187,14 @@ type Receipts interface {
UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error)
SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error)
PurgeReceipts(ctx context.Context, txn *sql.Tx, roomID string) error
}
type Memberships interface {
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
PurgeMemberships(ctx context.Context, txn *sql.Tx, roomID string) error
SelectMemberships(
ctx context.Context, txn *sql.Tx,
roomID string, pos types.TopologyToken,
@ -200,6 +206,7 @@ type NotificationData interface {
UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
SelectUserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error)
SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error)
PurgeNotificationData(ctx context.Context, txn *sql.Tx, roomID string) error
}
type Ignores interface {

View file

@ -384,19 +384,32 @@ func applyHistoryVisibilityFilter(
roomID, userID string,
recentEvents []*gomatrixserverlib.HeaderedEvent,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
// We need to make sure we always include the latest states events, if they are in the timeline.
// We grep at least limit * 2 events, to ensure we really get the needed events.
// We need to make sure we always include the latest state events, if they are in the timeline.
alwaysIncludeIDs := make(map[string]struct{})
var stateTypes []string
var senders []string
for _, ev := range recentEvents {
if ev.StateKey() != nil {
stateTypes = append(stateTypes, ev.Type())
senders = append(senders, ev.Sender())
}
}
// Only get the state again if there are state events in the timeline
if len(stateTypes) > 0 {
filter := gomatrixserverlib.DefaultStateFilter()
filter.Types = &stateTypes
filter.Senders = &senders
stateEvents, err := snapshot.CurrentState(ctx, roomID, &filter, nil)
if err != nil {
// Not a fatal error, we can continue without the stateEvents,
// they are only needed if there are state events in the timeline.
logrus.WithError(err).Warnf("Failed to get current room state for history visibility")
return nil, fmt.Errorf("failed to get current room state for history visibility calculation: %w", err)
}
alwaysIncludeIDs := make(map[string]struct{}, len(stateEvents))
for _, ev := range stateEvents {
alwaysIncludeIDs[ev.EventID()] = struct{}{}
}
}
startTime := time.Now()
events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync")
if err != nil {

View file

@ -521,6 +521,252 @@ func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatr
}
}
func TestGetMembership(t *testing.T) {
alice := test.NewUser(t)
aliceDev := userapi.Device{
ID: "ALICEID",
UserID: alice.ID,
AccessToken: "ALICE_BEARER_TOKEN",
DisplayName: "Alice",
AccountType: userapi.AccountTypeUser,
}
bob := test.NewUser(t)
bobDev := userapi.Device{
ID: "BOBID",
UserID: bob.ID,
AccessToken: "notjoinedtoanyrooms",
}
testCases := []struct {
name string
roomID string
additionalEvents func(t *testing.T, room *test.Room)
request func(t *testing.T, room *test.Room) *http.Request
wantOK bool
wantMemberCount int
useSleep bool // :/
}{
{
name: "/members - Alice joined",
request: func(t *testing.T, room *test.Room) *http.Request {
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
"access_token": aliceDev.AccessToken,
}))
},
wantOK: true,
wantMemberCount: 1,
},
{
name: "/members - Bob never joined",
request: func(t *testing.T, room *test.Room) *http.Request {
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
"access_token": bobDev.AccessToken,
}))
},
wantOK: false,
},
{
name: "/joined_members - Bob never joined",
request: func(t *testing.T, room *test.Room) *http.Request {
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
"access_token": bobDev.AccessToken,
}))
},
wantOK: false,
},
{
name: "/joined_members - Alice joined",
request: func(t *testing.T, room *test.Room) *http.Request {
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
"access_token": aliceDev.AccessToken,
}))
},
wantOK: true,
},
{
name: "Alice leaves before Bob joins, should not be able to see Bob",
request: func(t *testing.T, room *test.Room) *http.Request {
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
"access_token": aliceDev.AccessToken,
}))
},
additionalEvents: func(t *testing.T, room *test.Room) {
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "leave",
}, test.WithStateKey(alice.ID))
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
},
useSleep: true,
wantOK: true,
wantMemberCount: 1,
},
{
name: "Alice leaves after Bob joins, should be able to see Bob",
request: func(t *testing.T, room *test.Room) *http.Request {
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
"access_token": aliceDev.AccessToken,
}))
},
additionalEvents: func(t *testing.T, room *test.Room) {
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "leave",
}, test.WithStateKey(alice.ID))
},
useSleep: true,
wantOK: true,
wantMemberCount: 2,
},
{
name: "/joined_members - Alice leaves, shouldn't be able to see members ",
request: func(t *testing.T, room *test.Room) *http.Request {
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
"access_token": aliceDev.AccessToken,
}))
},
additionalEvents: func(t *testing.T, room *test.Room) {
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "leave",
}, test.WithStateKey(alice.ID))
},
useSleep: true,
wantOK: false,
},
{
name: "'at' specified, returns memberships before Bob joins",
request: func(t *testing.T, room *test.Room) *http.Request {
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
"access_token": aliceDev.AccessToken,
"at": "t2_5",
}))
},
additionalEvents: func(t *testing.T, room *test.Room) {
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
},
useSleep: true,
wantOK: true,
wantMemberCount: 1,
},
{
name: "'membership=leave' specified, returns no memberships",
request: func(t *testing.T, room *test.Room) *http.Request {
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
"access_token": aliceDev.AccessToken,
"membership": "leave",
}))
},
wantOK: true,
wantMemberCount: 0,
},
{
name: "'not_membership=join' specified, returns no memberships",
request: func(t *testing.T, room *test.Room) *http.Request {
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
"access_token": aliceDev.AccessToken,
"not_membership": "join",
}))
},
wantOK: true,
wantMemberCount: 0,
},
{
name: "'not_membership=leave' & 'membership=join' specified, returns correct memberships",
request: func(t *testing.T, room *test.Room) *http.Request {
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
"access_token": aliceDev.AccessToken,
"not_membership": "leave",
"membership": "join",
}))
},
additionalEvents: func(t *testing.T, room *test.Room) {
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "leave",
}, test.WithStateKey(bob.ID))
},
wantOK: true,
wantMemberCount: 1,
},
{
name: "non-existent room ID",
request: func(t *testing.T, room *test.Room) *http.Request {
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", "!notavalidroom:test"), test.WithQueryParams(map[string]string{
"access_token": aliceDev.AccessToken,
}))
},
wantOK: false,
},
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
// Use an actual roomserver for this
rsAPI := roomserver.NewInternalAPI(base)
rsAPI.SetFederationAPI(nil, nil)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{})
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
room := test.NewRoom(t, alice)
t.Cleanup(func() {
t.Logf("running cleanup for %s", tc.name)
})
// inject additional events
if tc.additionalEvents != nil {
tc.additionalEvents(t, room)
}
if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
}
// wait for the events to come down sync
if tc.useSleep {
time.Sleep(time.Millisecond * 100)
} else {
syncUntil(t, base, aliceDev.AccessToken, false, func(syncBody string) bool {
// wait for the last sent eventID to come down sync
path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID())
return gjson.Get(syncBody, path).Exists()
})
}
w := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, tc.request(t, room))
if w.Code != 200 && tc.wantOK {
t.Logf("%s", w.Body.String())
t.Fatalf("got HTTP %d want %d", w.Code, 200)
}
t.Logf("[%s] Resp: %s", tc.name, w.Body.String())
// check we got the expected events
if tc.wantOK {
memberCount := len(gjson.GetBytes(w.Body.Bytes(), "chunk").Array())
if memberCount != tc.wantMemberCount {
t.Fatalf("expected %d members, got %d", tc.wantMemberCount, memberCount)
}
}
})
}
})
}
func TestSendToDevice(t *testing.T) {
test.WithAllDatabases(t, testSendToDevice)
}