mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 08:13:09 -06:00
This commit is contained in:
commit
c7a8d6d641
|
|
@ -258,6 +258,7 @@ mscs:
|
||||||
# A list of enabled MSC's
|
# A list of enabled MSC's
|
||||||
# Currently valid values are:
|
# Currently valid values are:
|
||||||
# - msc2836 (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836)
|
# - msc2836 (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836)
|
||||||
|
# - msc2946 (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946)
|
||||||
mscs: []
|
mscs: []
|
||||||
database:
|
database:
|
||||||
connection_string: file:mscs.db
|
connection_string: file:mscs.db
|
||||||
|
|
|
||||||
2
go.mod
2
go.mod
|
|
@ -22,7 +22,7 @@ require (
|
||||||
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
|
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20201209172200-eb6a8903f9fb
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20210113173004-b1c67ac867cc
|
||||||
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91
|
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91
|
||||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||||
github.com/mattn/go-sqlite3 v1.14.2
|
github.com/mattn/go-sqlite3 v1.14.2
|
||||||
|
|
|
||||||
4
go.sum
4
go.sum
|
|
@ -567,8 +567,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg=
|
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20201209172200-eb6a8903f9fb h1:UlhiSebJupQ+qAM93cdVGg4nAJ6bnxwAA5/EBygtYoo=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20210113173004-b1c67ac867cc h1:n2Hnbg8RZ4102Qmxie1riLkIyrqeqShJUILg1miSmDI=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20201209172200-eb6a8903f9fb/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20210113173004-b1c67ac867cc/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||||
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4=
|
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4=
|
||||||
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
|
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
|
||||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=
|
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
|
@ -28,9 +29,29 @@ import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
prometheus.MustRegister(processRoomEventDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
var processRoomEventDuration = prometheus.NewHistogramVec(
|
||||||
|
prometheus.HistogramOpts{
|
||||||
|
Namespace: "dendrite",
|
||||||
|
Subsystem: "roomserver",
|
||||||
|
Name: "processroomevent_duration_millis",
|
||||||
|
Help: "How long it takes the roomserver to process an event",
|
||||||
|
Buckets: []float64{ // milliseconds
|
||||||
|
5, 10, 25, 50, 75, 100, 250, 500,
|
||||||
|
1000, 2000, 3000, 4000, 5000, 6000,
|
||||||
|
7000, 8000, 9000, 10000, 15000, 20000,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
[]string{"room_id"},
|
||||||
|
)
|
||||||
|
|
||||||
// processRoomEvent can only be called once at a time
|
// processRoomEvent can only be called once at a time
|
||||||
//
|
//
|
||||||
// TODO(#375): This should be rewritten to allow concurrent calls. The
|
// TODO(#375): This should be rewritten to allow concurrent calls. The
|
||||||
|
|
@ -42,6 +63,15 @@ func (r *Inputer) processRoomEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
input *api.InputRoomEvent,
|
input *api.InputRoomEvent,
|
||||||
) (eventID string, err error) {
|
) (eventID string, err error) {
|
||||||
|
// Measure how long it takes to process this event.
|
||||||
|
started := time.Now()
|
||||||
|
defer func() {
|
||||||
|
timetaken := time.Since(started)
|
||||||
|
processRoomEventDuration.With(prometheus.Labels{
|
||||||
|
"room_id": input.Event.RoomID(),
|
||||||
|
}).Observe(float64(timetaken.Milliseconds()))
|
||||||
|
}()
|
||||||
|
|
||||||
// Parse and validate the event JSON
|
// Parse and validate the event JSON
|
||||||
headered := input.Event
|
headered := input.Event
|
||||||
event := headered.Unwrap()
|
event := headered.Unwrap()
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ package config
|
||||||
type MSCs struct {
|
type MSCs struct {
|
||||||
Matrix *Global `yaml:"-"`
|
Matrix *Global `yaml:"-"`
|
||||||
|
|
||||||
// The MSCs to enable, currently only `msc2836` is supported.
|
// The MSCs to enable
|
||||||
MSCs []string `yaml:"mscs"`
|
MSCs []string `yaml:"mscs"`
|
||||||
|
|
||||||
Database DatabaseOptions `yaml:"database"`
|
Database DatabaseOptions `yaml:"database"`
|
||||||
|
|
|
||||||
369
setup/mscs/msc2946/msc2946.go
Normal file
369
setup/mscs/msc2946/msc2946.go
Normal file
|
|
@ -0,0 +1,369 @@
|
||||||
|
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Package msc2946 'Spaces Summary' implements https://github.com/matrix-org/matrix-doc/pull/2946
|
||||||
|
package msc2946
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
chttputil "github.com/matrix-org/dendrite/clientapi/httputil"
|
||||||
|
"github.com/matrix-org/dendrite/internal/hooks"
|
||||||
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
|
roomserver "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/setup"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ConstCreateEventContentKey = "org.matrix.msc1772.type"
|
||||||
|
ConstSpaceChildEventType = "org.matrix.msc1772.space.child"
|
||||||
|
ConstSpaceParentEventType = "org.matrix.msc1772.room.parent"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SpacesRequest is the request body to POST /_matrix/client/r0/rooms/{roomID}/spaces
|
||||||
|
type SpacesRequest struct {
|
||||||
|
MaxRoomsPerSpace int `json:"max_rooms_per_space"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
Batch string `json:"batch"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Defaults sets the request defaults
|
||||||
|
func (r *SpacesRequest) Defaults() {
|
||||||
|
r.Limit = 100
|
||||||
|
r.MaxRoomsPerSpace = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// SpacesResponse is the response body to POST /_matrix/client/r0/rooms/{roomID}/spaces
|
||||||
|
type SpacesResponse struct {
|
||||||
|
NextBatch string `json:"next_batch"`
|
||||||
|
// Rooms are nodes on the space graph.
|
||||||
|
Rooms []Room `json:"rooms"`
|
||||||
|
// Events are edges on the space graph, exclusively m.space.child or m.room.parent events
|
||||||
|
Events []gomatrixserverlib.ClientEvent `json:"events"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Room is a node on the space graph
|
||||||
|
type Room struct {
|
||||||
|
gomatrixserverlib.PublicRoom
|
||||||
|
NumRefs int `json:"num_refs"`
|
||||||
|
RoomType string `json:"room_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable this MSC
|
||||||
|
func Enable(
|
||||||
|
base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI,
|
||||||
|
) error {
|
||||||
|
db, err := NewDatabase(&base.Cfg.MSCs.Database)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Cannot enable MSC2946: %w", err)
|
||||||
|
}
|
||||||
|
hooks.Enable()
|
||||||
|
hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) {
|
||||||
|
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
|
||||||
|
hookErr := db.StoreReference(context.Background(), he)
|
||||||
|
if hookErr != nil {
|
||||||
|
util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error(
|
||||||
|
"failed to StoreReference",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
base.PublicClientAPIMux.Handle("/unstable/rooms/{roomID}/spaces",
|
||||||
|
httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(db, rsAPI)),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func spacesHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse {
|
||||||
|
return func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
inMemoryBatchCache := make(map[string]set)
|
||||||
|
// Extract the room ID from the request. Sanity check request data.
|
||||||
|
params, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
roomID := params["roomID"]
|
||||||
|
var r SpacesRequest
|
||||||
|
r.Defaults()
|
||||||
|
if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil {
|
||||||
|
return *resErr
|
||||||
|
}
|
||||||
|
if r.Limit > 100 {
|
||||||
|
r.Limit = 100
|
||||||
|
}
|
||||||
|
w := walker{
|
||||||
|
req: &r,
|
||||||
|
rootRoomID: roomID,
|
||||||
|
caller: device,
|
||||||
|
ctx: req.Context(),
|
||||||
|
|
||||||
|
db: db,
|
||||||
|
rsAPI: rsAPI,
|
||||||
|
inMemoryBatchCache: inMemoryBatchCache,
|
||||||
|
}
|
||||||
|
res := w.walk()
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: res,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type walker struct {
|
||||||
|
req *SpacesRequest
|
||||||
|
rootRoomID string
|
||||||
|
caller *userapi.Device
|
||||||
|
db Database
|
||||||
|
rsAPI roomserver.RoomserverInternalAPI
|
||||||
|
ctx context.Context
|
||||||
|
|
||||||
|
// user ID|device ID|batch_num => event/room IDs sent to client
|
||||||
|
inMemoryBatchCache map[string]set
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *walker) alreadySent(id string) bool {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
m, ok := w.inMemoryBatchCache[w.caller.UserID+"|"+w.caller.ID]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return m[id]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *walker) markSent(id string) {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
m := w.inMemoryBatchCache[w.caller.UserID+"|"+w.caller.ID]
|
||||||
|
if m == nil {
|
||||||
|
m = make(set)
|
||||||
|
}
|
||||||
|
m[id] = true
|
||||||
|
w.inMemoryBatchCache[w.caller.UserID+"|"+w.caller.ID] = m
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:gocyclo
|
||||||
|
func (w *walker) walk() *SpacesResponse {
|
||||||
|
var res SpacesResponse
|
||||||
|
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
|
||||||
|
unvisited := []string{w.rootRoomID}
|
||||||
|
processed := make(set)
|
||||||
|
for len(unvisited) > 0 {
|
||||||
|
roomID := unvisited[0]
|
||||||
|
unvisited = unvisited[1:]
|
||||||
|
// If this room has already been processed, skip. NB: do not remember this between calls
|
||||||
|
if processed[roomID] || roomID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Mark this room as processed.
|
||||||
|
processed[roomID] = true
|
||||||
|
// Is the caller currently joined to the room or is the room `world_readable`
|
||||||
|
// If no, skip this room. If yes, continue.
|
||||||
|
if !w.authorised(roomID) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Get all `m.space.child` and `m.room.parent` state events for the room. *In addition*, get
|
||||||
|
// all `m.space.child` and `m.room.parent` state events which *point to* (via `state_key` or `content.room_id`)
|
||||||
|
// this room. This requires servers to store reverse lookups.
|
||||||
|
refs, err := w.references(roomID)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Error("failed to extract references for room")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this room has not ever been in `rooms` (across multiple requests), extract the
|
||||||
|
// `PublicRoomsChunk` for this room.
|
||||||
|
if !w.alreadySent(roomID) {
|
||||||
|
pubRoom := w.publicRoomsChunk(roomID)
|
||||||
|
roomType := ""
|
||||||
|
create := w.stateEvent(roomID, "m.room.create", "")
|
||||||
|
if create != nil {
|
||||||
|
roomType = gjson.GetBytes(create.Content(), ConstCreateEventContentKey).Str
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the total number of events to `PublicRoomsChunk` under `num_refs`. Add `PublicRoomsChunk` to `rooms`.
|
||||||
|
res.Rooms = append(res.Rooms, Room{
|
||||||
|
PublicRoom: *pubRoom,
|
||||||
|
NumRefs: refs.len(),
|
||||||
|
RoomType: roomType,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
uniqueRooms := make(set)
|
||||||
|
|
||||||
|
// If this is the root room from the original request, insert all these events into `events` if
|
||||||
|
// they haven't been added before (across multiple requests).
|
||||||
|
if w.rootRoomID == roomID {
|
||||||
|
for _, ev := range refs.events() {
|
||||||
|
if !w.alreadySent(ev.EventID()) {
|
||||||
|
res.Events = append(res.Events, gomatrixserverlib.HeaderedToClientEvent(
|
||||||
|
ev, gomatrixserverlib.FormatAll,
|
||||||
|
))
|
||||||
|
uniqueRooms[ev.RoomID()] = true
|
||||||
|
uniqueRooms[SpaceTarget(ev)] = true
|
||||||
|
w.markSent(ev.EventID())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Else add them to `events` honouring the `limit` and `max_rooms_per_space` values. If either
|
||||||
|
// are exceeded, stop adding events. If the event has already been added, do not add it again.
|
||||||
|
numAdded := 0
|
||||||
|
for _, ev := range refs.events() {
|
||||||
|
if w.req.Limit > 0 && len(res.Events) >= w.req.Limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if w.req.MaxRoomsPerSpace > 0 && numAdded >= w.req.MaxRoomsPerSpace {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if w.alreadySent(ev.EventID()) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
res.Events = append(res.Events, gomatrixserverlib.HeaderedToClientEvent(
|
||||||
|
ev, gomatrixserverlib.FormatAll,
|
||||||
|
))
|
||||||
|
uniqueRooms[ev.RoomID()] = true
|
||||||
|
uniqueRooms[SpaceTarget(ev)] = true
|
||||||
|
w.markSent(ev.EventID())
|
||||||
|
// we don't distinguish between child state events and parent state events for the purposes of
|
||||||
|
// max_rooms_per_space, maybe we should?
|
||||||
|
numAdded++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For each referenced room ID in the events being returned to the caller (both parent and child)
|
||||||
|
// add the room ID to the queue of unvisited rooms. Loop from the beginning.
|
||||||
|
for roomID := range uniqueRooms {
|
||||||
|
unvisited = append(unvisited, roomID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent {
|
||||||
|
var queryRes roomserver.QueryCurrentStateResponse
|
||||||
|
tuple := gomatrixserverlib.StateKeyTuple{
|
||||||
|
EventType: evType,
|
||||||
|
StateKey: stateKey,
|
||||||
|
}
|
||||||
|
err := w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{
|
||||||
|
RoomID: roomID,
|
||||||
|
StateTuples: []gomatrixserverlib.StateKeyTuple{tuple},
|
||||||
|
}, &queryRes)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return queryRes.StateEvents[tuple]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *walker) publicRoomsChunk(roomID string) *gomatrixserverlib.PublicRoom {
|
||||||
|
pubRooms, err := roomserver.PopulatePublicRooms(w.ctx, []string{roomID}, w.rsAPI)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(w.ctx).WithError(err).Error("failed to PopulatePublicRooms")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(pubRooms) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &pubRooms[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// authorised returns true iff the user is joined this room or the room is world_readable
|
||||||
|
func (w *walker) authorised(roomID string) bool {
|
||||||
|
hisVisTuple := gomatrixserverlib.StateKeyTuple{
|
||||||
|
EventType: gomatrixserverlib.MRoomHistoryVisibility,
|
||||||
|
StateKey: "",
|
||||||
|
}
|
||||||
|
roomMemberTuple := gomatrixserverlib.StateKeyTuple{
|
||||||
|
EventType: gomatrixserverlib.MRoomMember,
|
||||||
|
StateKey: w.caller.UserID,
|
||||||
|
}
|
||||||
|
var queryRes roomserver.QueryCurrentStateResponse
|
||||||
|
err := w.rsAPI.QueryCurrentState(w.ctx, &roomserver.QueryCurrentStateRequest{
|
||||||
|
RoomID: roomID,
|
||||||
|
StateTuples: []gomatrixserverlib.StateKeyTuple{
|
||||||
|
hisVisTuple, roomMemberTuple,
|
||||||
|
},
|
||||||
|
}, &queryRes)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(w.ctx).WithError(err).Error("failed to QueryCurrentState")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
memberEv := queryRes.StateEvents[roomMemberTuple]
|
||||||
|
hisVisEv := queryRes.StateEvents[hisVisTuple]
|
||||||
|
if memberEv != nil {
|
||||||
|
membership, _ := memberEv.Membership()
|
||||||
|
if membership == gomatrixserverlib.Join {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hisVisEv != nil {
|
||||||
|
hisVis, _ := hisVisEv.HistoryVisibility()
|
||||||
|
if hisVis == "world_readable" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// references returns all references pointing to or from this room.
|
||||||
|
func (w *walker) references(roomID string) (eventLookup, error) {
|
||||||
|
events, err := w.db.References(w.ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
el := make(eventLookup)
|
||||||
|
for _, ev := range events {
|
||||||
|
el.set(ev)
|
||||||
|
}
|
||||||
|
return el, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// state event lookup across multiple rooms keyed on event type
|
||||||
|
// NOT THREAD SAFE
|
||||||
|
type eventLookup map[string][]*gomatrixserverlib.HeaderedEvent
|
||||||
|
|
||||||
|
func (el eventLookup) set(ev *gomatrixserverlib.HeaderedEvent) {
|
||||||
|
evs := el[ev.Type()]
|
||||||
|
if evs == nil {
|
||||||
|
evs = make([]*gomatrixserverlib.HeaderedEvent, 0)
|
||||||
|
}
|
||||||
|
evs = append(evs, ev)
|
||||||
|
el[ev.Type()] = evs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el eventLookup) len() int {
|
||||||
|
sum := 0
|
||||||
|
for _, evs := range el {
|
||||||
|
sum += len(evs)
|
||||||
|
}
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el eventLookup) events() (events []*gomatrixserverlib.HeaderedEvent) {
|
||||||
|
for _, evs := range el {
|
||||||
|
events = append(events, evs...)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type set map[string]bool
|
||||||
486
setup/mscs/msc2946/msc2946_test.go
Normal file
486
setup/mscs/msc2946/msc2946_test.go
Normal file
|
|
@ -0,0 +1,486 @@
|
||||||
|
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package msc2946_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/matrix-org/dendrite/internal/hooks"
|
||||||
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
|
roomserver "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/setup"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/setup/mscs/msc2946"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
client = &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Basic sanity check of MSC2946 logic. Tests a single room with a few state events
|
||||||
|
// and a bit of recursion to subspaces. Makes a graph like:
|
||||||
|
// Root
|
||||||
|
// ____|_____
|
||||||
|
// | | |
|
||||||
|
// R1 R2 S1
|
||||||
|
// |_________
|
||||||
|
// | | |
|
||||||
|
// R3 R4 S2
|
||||||
|
// | <-- this link is just a parent, not a child
|
||||||
|
// R5
|
||||||
|
//
|
||||||
|
// Alice is not joined to R4, but R4 is "world_readable".
|
||||||
|
func TestMSC2946(t *testing.T) {
|
||||||
|
alice := "@alice:localhost"
|
||||||
|
// give access token to alice
|
||||||
|
nopUserAPI := &testUserAPI{
|
||||||
|
accessTokens: make(map[string]userapi.Device),
|
||||||
|
}
|
||||||
|
nopUserAPI.accessTokens["alice"] = userapi.Device{
|
||||||
|
AccessToken: "alice",
|
||||||
|
DisplayName: "Alice",
|
||||||
|
UserID: alice,
|
||||||
|
}
|
||||||
|
rootSpace := "!rootspace:localhost"
|
||||||
|
subSpaceS1 := "!subspaceS1:localhost"
|
||||||
|
subSpaceS2 := "!subspaceS2:localhost"
|
||||||
|
room1 := "!room1:localhost"
|
||||||
|
room2 := "!room2:localhost"
|
||||||
|
room3 := "!room3:localhost"
|
||||||
|
room4 := "!room4:localhost"
|
||||||
|
empty := ""
|
||||||
|
room5 := "!room5:localhost"
|
||||||
|
allRooms := []string{
|
||||||
|
rootSpace, subSpaceS1, subSpaceS2,
|
||||||
|
room1, room2, room3, room4, room5,
|
||||||
|
}
|
||||||
|
rootToR1 := mustCreateEvent(t, fledglingEvent{
|
||||||
|
RoomID: rootSpace,
|
||||||
|
Sender: alice,
|
||||||
|
Type: msc2946.ConstSpaceChildEventType,
|
||||||
|
StateKey: &room1,
|
||||||
|
Content: map[string]interface{}{
|
||||||
|
"via": []string{"localhost"},
|
||||||
|
"present": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
rootToR2 := mustCreateEvent(t, fledglingEvent{
|
||||||
|
RoomID: rootSpace,
|
||||||
|
Sender: alice,
|
||||||
|
Type: msc2946.ConstSpaceChildEventType,
|
||||||
|
StateKey: &room2,
|
||||||
|
Content: map[string]interface{}{
|
||||||
|
"via": []string{"localhost"},
|
||||||
|
"present": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
rootToS1 := mustCreateEvent(t, fledglingEvent{
|
||||||
|
RoomID: rootSpace,
|
||||||
|
Sender: alice,
|
||||||
|
Type: msc2946.ConstSpaceChildEventType,
|
||||||
|
StateKey: &subSpaceS1,
|
||||||
|
Content: map[string]interface{}{
|
||||||
|
"via": []string{"localhost"},
|
||||||
|
"present": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
s1ToR3 := mustCreateEvent(t, fledglingEvent{
|
||||||
|
RoomID: subSpaceS1,
|
||||||
|
Sender: alice,
|
||||||
|
Type: msc2946.ConstSpaceChildEventType,
|
||||||
|
StateKey: &room3,
|
||||||
|
Content: map[string]interface{}{
|
||||||
|
"via": []string{"localhost"},
|
||||||
|
"present": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
s1ToR4 := mustCreateEvent(t, fledglingEvent{
|
||||||
|
RoomID: subSpaceS1,
|
||||||
|
Sender: alice,
|
||||||
|
Type: msc2946.ConstSpaceChildEventType,
|
||||||
|
StateKey: &room4,
|
||||||
|
Content: map[string]interface{}{
|
||||||
|
"via": []string{"localhost"},
|
||||||
|
"present": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
s1ToS2 := mustCreateEvent(t, fledglingEvent{
|
||||||
|
RoomID: subSpaceS1,
|
||||||
|
Sender: alice,
|
||||||
|
Type: msc2946.ConstSpaceChildEventType,
|
||||||
|
StateKey: &subSpaceS2,
|
||||||
|
Content: map[string]interface{}{
|
||||||
|
"via": []string{"localhost"},
|
||||||
|
"present": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
// This is a parent link only
|
||||||
|
s2ToR5 := mustCreateEvent(t, fledglingEvent{
|
||||||
|
RoomID: room5,
|
||||||
|
Sender: alice,
|
||||||
|
Type: msc2946.ConstSpaceParentEventType,
|
||||||
|
StateKey: &empty,
|
||||||
|
Content: map[string]interface{}{
|
||||||
|
"room_id": subSpaceS2,
|
||||||
|
"via": []string{"localhost"},
|
||||||
|
"present": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
// history visibility for R4
|
||||||
|
r4HisVis := mustCreateEvent(t, fledglingEvent{
|
||||||
|
RoomID: room4,
|
||||||
|
Sender: "@someone:localhost",
|
||||||
|
Type: gomatrixserverlib.MRoomHistoryVisibility,
|
||||||
|
StateKey: &empty,
|
||||||
|
Content: map[string]interface{}{
|
||||||
|
"history_visibility": "world_readable",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
var joinEvents []*gomatrixserverlib.HeaderedEvent
|
||||||
|
for _, roomID := range allRooms {
|
||||||
|
if roomID == room4 {
|
||||||
|
continue // not joined to that room
|
||||||
|
}
|
||||||
|
joinEvents = append(joinEvents, mustCreateEvent(t, fledglingEvent{
|
||||||
|
RoomID: roomID,
|
||||||
|
Sender: alice,
|
||||||
|
StateKey: &alice,
|
||||||
|
Type: gomatrixserverlib.MRoomMember,
|
||||||
|
Content: map[string]interface{}{
|
||||||
|
"membership": "join",
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
roomNameTuple := gomatrixserverlib.StateKeyTuple{
|
||||||
|
EventType: "m.room.name",
|
||||||
|
StateKey: "",
|
||||||
|
}
|
||||||
|
hisVisTuple := gomatrixserverlib.StateKeyTuple{
|
||||||
|
EventType: "m.room.history_visibility",
|
||||||
|
StateKey: "",
|
||||||
|
}
|
||||||
|
nopRsAPI := &testRoomserverAPI{
|
||||||
|
joinEvents: joinEvents,
|
||||||
|
events: map[string]*gomatrixserverlib.HeaderedEvent{
|
||||||
|
rootToR1.EventID(): rootToR1,
|
||||||
|
rootToR2.EventID(): rootToR2,
|
||||||
|
rootToS1.EventID(): rootToS1,
|
||||||
|
s1ToR3.EventID(): s1ToR3,
|
||||||
|
s1ToR4.EventID(): s1ToR4,
|
||||||
|
s1ToS2.EventID(): s1ToS2,
|
||||||
|
s2ToR5.EventID(): s2ToR5,
|
||||||
|
r4HisVis.EventID(): r4HisVis,
|
||||||
|
},
|
||||||
|
pubRoomState: map[string]map[gomatrixserverlib.StateKeyTuple]string{
|
||||||
|
rootSpace: {
|
||||||
|
roomNameTuple: "Root",
|
||||||
|
hisVisTuple: "shared",
|
||||||
|
},
|
||||||
|
subSpaceS1: {
|
||||||
|
roomNameTuple: "Sub-Space 1",
|
||||||
|
hisVisTuple: "joined",
|
||||||
|
},
|
||||||
|
subSpaceS2: {
|
||||||
|
roomNameTuple: "Sub-Space 2",
|
||||||
|
hisVisTuple: "shared",
|
||||||
|
},
|
||||||
|
room1: {
|
||||||
|
hisVisTuple: "joined",
|
||||||
|
},
|
||||||
|
room2: {
|
||||||
|
hisVisTuple: "joined",
|
||||||
|
},
|
||||||
|
room3: {
|
||||||
|
hisVisTuple: "joined",
|
||||||
|
},
|
||||||
|
room4: {
|
||||||
|
hisVisTuple: "world_readable",
|
||||||
|
},
|
||||||
|
room5: {
|
||||||
|
hisVisTuple: "joined",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
allEvents := []*gomatrixserverlib.HeaderedEvent{
|
||||||
|
rootToR1, rootToR2, rootToS1,
|
||||||
|
s1ToR3, s1ToR4, s1ToS2,
|
||||||
|
s2ToR5, r4HisVis,
|
||||||
|
}
|
||||||
|
allEvents = append(allEvents, joinEvents...)
|
||||||
|
router := injectEvents(t, nopUserAPI, nopRsAPI, allEvents)
|
||||||
|
cancel := runServer(t, router)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
t.Run("returns no events for unknown rooms", func(t *testing.T) {
|
||||||
|
res := postSpaces(t, 200, "alice", "!unknown:localhost", newReq(t, map[string]interface{}{}))
|
||||||
|
if len(res.Events) > 0 {
|
||||||
|
t.Errorf("got %d events, want 0", len(res.Events))
|
||||||
|
}
|
||||||
|
if len(res.Rooms) > 0 {
|
||||||
|
t.Errorf("got %d rooms, want 0", len(res.Rooms))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("returns the entire graph", func(t *testing.T) {
|
||||||
|
res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{}))
|
||||||
|
if len(res.Events) != 7 {
|
||||||
|
t.Errorf("got %d events, want 7", len(res.Events))
|
||||||
|
}
|
||||||
|
if len(res.Rooms) != len(allRooms) {
|
||||||
|
t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms))
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func newReq(t *testing.T, jsonBody map[string]interface{}) *msc2946.SpacesRequest {
|
||||||
|
t.Helper()
|
||||||
|
b, err := json.Marshal(jsonBody)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to marshal request: %s", err)
|
||||||
|
}
|
||||||
|
var r msc2946.SpacesRequest
|
||||||
|
if err := json.Unmarshal(b, &r); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal request: %s", err)
|
||||||
|
}
|
||||||
|
return &r
|
||||||
|
}
|
||||||
|
|
||||||
|
func runServer(t *testing.T, router *mux.Router) func() {
|
||||||
|
t.Helper()
|
||||||
|
externalServ := &http.Server{
|
||||||
|
Addr: string(":8010"),
|
||||||
|
WriteTimeout: 60 * time.Second,
|
||||||
|
Handler: router,
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
externalServ.ListenAndServe()
|
||||||
|
}()
|
||||||
|
// wait to listen on the port
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
return func() {
|
||||||
|
externalServ.Shutdown(context.TODO())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *msc2946.SpacesRequest) *msc2946.SpacesResponse {
|
||||||
|
t.Helper()
|
||||||
|
var r msc2946.SpacesRequest
|
||||||
|
r.Defaults()
|
||||||
|
data, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal request: %s", err)
|
||||||
|
}
|
||||||
|
httpReq, err := http.NewRequest(
|
||||||
|
"POST", "http://localhost:8010/_matrix/client/unstable/rooms/"+url.PathEscape(roomID)+"/spaces",
|
||||||
|
bytes.NewBuffer(data),
|
||||||
|
)
|
||||||
|
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to prepare request: %s", err)
|
||||||
|
}
|
||||||
|
res, err := client.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to do request: %s", err)
|
||||||
|
}
|
||||||
|
if res.StatusCode != expectCode {
|
||||||
|
body, _ := ioutil.ReadAll(res.Body)
|
||||||
|
t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body))
|
||||||
|
}
|
||||||
|
if res.StatusCode == 200 {
|
||||||
|
var result msc2946.SpacesResponse
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("response 200 OK but failed to read response body: %s", err)
|
||||||
|
}
|
||||||
|
t.Logf("Body: %s", string(body))
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body))
|
||||||
|
}
|
||||||
|
return &result
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type testUserAPI struct {
|
||||||
|
accessTokens map[string]userapi.Device
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *testUserAPI) InputAccountData(ctx context.Context, req *userapi.InputAccountDataRequest, res *userapi.InputAccountDataResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *testUserAPI) PerformAccountCreation(ctx context.Context, req *userapi.PerformAccountCreationRequest, res *userapi.PerformAccountCreationResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *testUserAPI) PerformPasswordUpdate(ctx context.Context, req *userapi.PerformPasswordUpdateRequest, res *userapi.PerformPasswordUpdateResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *testUserAPI) PerformDeviceCreation(ctx context.Context, req *userapi.PerformDeviceCreationRequest, res *userapi.PerformDeviceCreationResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *testUserAPI) PerformDeviceDeletion(ctx context.Context, req *userapi.PerformDeviceDeletionRequest, res *userapi.PerformDeviceDeletionResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *testUserAPI) PerformDeviceUpdate(ctx context.Context, req *userapi.PerformDeviceUpdateRequest, res *userapi.PerformDeviceUpdateResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *testUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error {
|
||||||
|
dev, ok := u.accessTokens[req.AccessToken]
|
||||||
|
if !ok {
|
||||||
|
res.Err = fmt.Errorf("unknown token")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
res.Device = &dev
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *testUserAPI) QueryDevices(ctx context.Context, req *userapi.QueryDevicesRequest, res *userapi.QueryDevicesResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *testUserAPI) QueryAccountData(ctx context.Context, req *userapi.QueryAccountDataRequest, res *userapi.QueryAccountDataResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *testUserAPI) QueryDeviceInfos(ctx context.Context, req *userapi.QueryDeviceInfosRequest, res *userapi.QueryDeviceInfosResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type testRoomserverAPI struct {
|
||||||
|
// use a trace API as it implements method stubs so we don't need to have them here.
|
||||||
|
// We'll override the functions we care about.
|
||||||
|
roomserver.RoomserverInternalAPITrace
|
||||||
|
joinEvents []*gomatrixserverlib.HeaderedEvent
|
||||||
|
events map[string]*gomatrixserverlib.HeaderedEvent
|
||||||
|
pubRoomState map[string]map[gomatrixserverlib.StateKeyTuple]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *roomserver.QueryBulkStateContentRequest, res *roomserver.QueryBulkStateContentResponse) error {
|
||||||
|
res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string)
|
||||||
|
for _, roomID := range req.RoomIDs {
|
||||||
|
pubRoomData, ok := r.pubRoomState[roomID]
|
||||||
|
if ok {
|
||||||
|
res.Rooms[roomID] = pubRoomData
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *roomserver.QueryCurrentStateRequest, res *roomserver.QueryCurrentStateResponse) error {
|
||||||
|
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
|
||||||
|
checkEvent := func(he *gomatrixserverlib.HeaderedEvent) {
|
||||||
|
if he.RoomID() != req.RoomID {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if he.StateKey() == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tuple := gomatrixserverlib.StateKeyTuple{
|
||||||
|
EventType: he.Type(),
|
||||||
|
StateKey: *he.StateKey(),
|
||||||
|
}
|
||||||
|
for _, t := range req.StateTuples {
|
||||||
|
if t == tuple {
|
||||||
|
res.StateEvents[t] = he
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, he := range r.joinEvents {
|
||||||
|
checkEvent(he)
|
||||||
|
}
|
||||||
|
for _, he := range r.events {
|
||||||
|
checkEvent(he)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router {
|
||||||
|
t.Helper()
|
||||||
|
cfg := &config.Dendrite{}
|
||||||
|
cfg.Defaults()
|
||||||
|
cfg.Global.ServerName = "localhost"
|
||||||
|
cfg.MSCs.Database.ConnectionString = "file:msc2946_test.db"
|
||||||
|
cfg.MSCs.MSCs = []string{"msc2946"}
|
||||||
|
base := &setup.BaseDendrite{
|
||||||
|
Cfg: cfg,
|
||||||
|
PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(),
|
||||||
|
PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := msc2946.Enable(base, rsAPI, userAPI)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to enable MSC2946: %s", err)
|
||||||
|
}
|
||||||
|
for _, ev := range events {
|
||||||
|
hooks.Run(hooks.KindNewEventPersisted, ev)
|
||||||
|
}
|
||||||
|
return base.PublicClientAPIMux
|
||||||
|
}
|
||||||
|
|
||||||
|
type fledglingEvent struct {
|
||||||
|
Type string
|
||||||
|
StateKey *string
|
||||||
|
Content interface{}
|
||||||
|
Sender string
|
||||||
|
RoomID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) {
|
||||||
|
t.Helper()
|
||||||
|
roomVer := gomatrixserverlib.RoomVersionV6
|
||||||
|
seed := make([]byte, ed25519.SeedSize) // zero seed
|
||||||
|
key := ed25519.NewKeyFromSeed(seed)
|
||||||
|
eb := gomatrixserverlib.EventBuilder{
|
||||||
|
Sender: ev.Sender,
|
||||||
|
Depth: 999,
|
||||||
|
Type: ev.Type,
|
||||||
|
StateKey: ev.StateKey,
|
||||||
|
RoomID: ev.RoomID,
|
||||||
|
}
|
||||||
|
err := eb.SetContent(ev.Content)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content)
|
||||||
|
}
|
||||||
|
// make sure the origin_server_ts changes so we can test recency
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
|
||||||
|
}
|
||||||
|
h := signedEvent.Headered(roomVer)
|
||||||
|
return h
|
||||||
|
}
|
||||||
183
setup/mscs/msc2946/storage.go
Normal file
183
setup/mscs/msc2946/storage.go
Normal file
|
|
@ -0,0 +1,183 @@
|
||||||
|
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package msc2946
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
relTypes = map[string]int{
|
||||||
|
ConstSpaceChildEventType: 1,
|
||||||
|
ConstSpaceParentEventType: 2,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
type Database interface {
|
||||||
|
// StoreReference persists a child or parent space mapping.
|
||||||
|
StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error
|
||||||
|
// References returns all events which have the given roomID as a parent or child space.
|
||||||
|
References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type DB struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
|
insertEdgeStmt *sql.Stmt
|
||||||
|
selectEdgesStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDatabase loads the database for msc2836
|
||||||
|
func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
||||||
|
if dbOpts.ConnectionString.IsPostgres() {
|
||||||
|
return newPostgresDatabase(dbOpts)
|
||||||
|
}
|
||||||
|
return newSQLiteDatabase(dbOpts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
||||||
|
d := DB{
|
||||||
|
writer: sqlutil.NewDummyWriter(),
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
if d.db, err = sqlutil.Open(dbOpts); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err = d.db.Exec(`
|
||||||
|
CREATE TABLE IF NOT EXISTS msc2946_edges (
|
||||||
|
room_version TEXT NOT NULL,
|
||||||
|
-- the room ID of the event, the source of the arrow
|
||||||
|
source_room_id TEXT NOT NULL,
|
||||||
|
-- the target room ID, the arrow destination
|
||||||
|
dest_room_id TEXT NOT NULL,
|
||||||
|
-- the kind of relation, either child or parent (1,2)
|
||||||
|
rel_type SMALLINT NOT NULL,
|
||||||
|
event_json TEXT NOT NULL,
|
||||||
|
CONSTRAINT msc2946_edges_uniq UNIQUE (source_room_id, dest_room_id, rel_type)
|
||||||
|
);
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if d.insertEdgeStmt, err = d.db.Prepare(`
|
||||||
|
INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json)
|
||||||
|
VALUES($1, $2, $3, $4, $5)
|
||||||
|
ON CONFLICT DO NOTHING
|
||||||
|
`); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if d.selectEdgesStmt, err = d.db.Prepare(`
|
||||||
|
SELECT room_version, event_json FROM msc2946_edges
|
||||||
|
WHERE source_room_id = $1 OR dest_room_id = $2
|
||||||
|
`); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &d, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
||||||
|
d := DB{
|
||||||
|
writer: sqlutil.NewExclusiveWriter(),
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
if d.db, err = sqlutil.Open(dbOpts); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err = d.db.Exec(`
|
||||||
|
CREATE TABLE IF NOT EXISTS msc2946_edges (
|
||||||
|
room_version TEXT NOT NULL,
|
||||||
|
-- the room ID of the event, the source of the arrow
|
||||||
|
source_room_id TEXT NOT NULL,
|
||||||
|
-- the target room ID, the arrow destination
|
||||||
|
dest_room_id TEXT NOT NULL,
|
||||||
|
-- the kind of relation, either child or parent (1,2)
|
||||||
|
rel_type SMALLINT NOT NULL,
|
||||||
|
event_json TEXT NOT NULL,
|
||||||
|
UNIQUE (source_room_id, dest_room_id, rel_type)
|
||||||
|
);
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if d.insertEdgeStmt, err = d.db.Prepare(`
|
||||||
|
INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json)
|
||||||
|
VALUES($1, $2, $3, $4, $5)
|
||||||
|
ON CONFLICT DO NOTHING
|
||||||
|
`); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if d.selectEdgesStmt, err = d.db.Prepare(`
|
||||||
|
SELECT room_version, event_json FROM msc2946_edges
|
||||||
|
WHERE source_room_id = $1 OR dest_room_id = $2
|
||||||
|
`); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &d, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error {
|
||||||
|
target := SpaceTarget(he)
|
||||||
|
if target == "" {
|
||||||
|
return nil // malformed event
|
||||||
|
}
|
||||||
|
relType := relTypes[he.Type()]
|
||||||
|
_, err := d.insertEdgeStmt.ExecContext(ctx, he.RoomVersion, he.RoomID(), target, relType, he.JSON())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DB) References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
|
rows, err := d.selectEdgesStmt.QueryContext(ctx, roomID, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "failed to close References")
|
||||||
|
refs := make([]*gomatrixserverlib.HeaderedEvent, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var roomVer string
|
||||||
|
var jsonBytes []byte
|
||||||
|
if err := rows.Scan(&roomVer, &jsonBytes); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(jsonBytes, false, gomatrixserverlib.RoomVersion(roomVer))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
he := ev.Headered(gomatrixserverlib.RoomVersion(roomVer))
|
||||||
|
refs = append(refs, he)
|
||||||
|
}
|
||||||
|
return refs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SpaceTarget returns the destination room ID for the space event. This is either a child or a parent
|
||||||
|
// depending on the event type.
|
||||||
|
func SpaceTarget(he *gomatrixserverlib.HeaderedEvent) string {
|
||||||
|
if he.StateKey() == nil {
|
||||||
|
return "" // no-op
|
||||||
|
}
|
||||||
|
switch he.Type() {
|
||||||
|
case ConstSpaceParentEventType:
|
||||||
|
return gjson.GetBytes(he.Content(), "room_id").Str
|
||||||
|
case ConstSpaceChildEventType:
|
||||||
|
return *he.StateKey()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
@ -21,6 +21,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/setup"
|
"github.com/matrix-org/dendrite/setup"
|
||||||
"github.com/matrix-org/dendrite/setup/mscs/msc2836"
|
"github.com/matrix-org/dendrite/setup/mscs/msc2836"
|
||||||
|
"github.com/matrix-org/dendrite/setup/mscs/msc2946"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -39,6 +40,8 @@ func EnableMSC(base *setup.BaseDendrite, monolith *setup.Monolith, msc string) e
|
||||||
switch msc {
|
switch msc {
|
||||||
case "msc2836":
|
case "msc2836":
|
||||||
return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationSenderAPI, monolith.UserAPI, monolith.KeyRing)
|
return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationSenderAPI, monolith.UserAPI, monolith.KeyRing)
|
||||||
|
case "msc2946":
|
||||||
|
return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("EnableMSC: unknown msc '%s'", msc)
|
return fmt.Errorf("EnableMSC: unknown msc '%s'", msc)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ type Database interface {
|
||||||
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
|
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
|
||||||
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
|
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
|
||||||
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
|
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
|
||||||
|
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
|
||||||
|
|
||||||
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
|
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
|
||||||
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
|
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
|
||||||
|
|
@ -117,26 +118,14 @@ type Database interface {
|
||||||
// matches the streamevent.transactionID device then the transaction ID gets
|
// matches the streamevent.transactionID device then the transaction ID gets
|
||||||
// added to the unsigned section of the output event.
|
// added to the unsigned section of the output event.
|
||||||
StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent
|
StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent
|
||||||
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns three lists:
|
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
|
||||||
// - "events": a list of send-to-device events that should be included in the sync
|
// relevant events within the given ranges for the supplied user ID and device ID.
|
||||||
// - "changes": a list of send-to-device events that should be updated in the database by
|
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)
|
||||||
// CleanSendToDeviceUpdates
|
|
||||||
// - "deletions": a list of send-to-device events which have been confirmed as sent and
|
|
||||||
// can be deleted altogether by CleanSendToDeviceUpdates
|
|
||||||
// The token supplied should be the current requested sync token, e.g. from the "since"
|
|
||||||
// parameter.
|
|
||||||
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (pos types.StreamPosition, events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error)
|
|
||||||
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
|
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
|
||||||
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
|
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
|
||||||
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the
|
// CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified
|
||||||
// result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows
|
// from position, preventing the send-to-device table from growing indefinitely.
|
||||||
// SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after
|
CleanSendToDeviceUpdates(ctx context.Context, userID, deviceID string, before types.StreamPosition) (err error)
|
||||||
// starting to wait for an incremental sync with timeout).
|
|
||||||
// The token supplied should be the current requested sync token, e.g. from the "since"
|
|
||||||
// parameter.
|
|
||||||
CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
|
|
||||||
// SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent.
|
|
||||||
SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error)
|
|
||||||
// GetFilter looks up the filter associated with a given local user and filter ID.
|
// GetFilter looks up the filter associated with a given local user and filter ID.
|
||||||
// Returns a filter structure. Otherwise returns an error if no such filter exists
|
// Returns a filter structure. Otherwise returns an error if no such filter exists
|
||||||
// or if there was an error talking to the database.
|
// or if there was an error talking to the database.
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ import (
|
||||||
|
|
||||||
func LoadFromGoose() {
|
func LoadFromGoose() {
|
||||||
goose.AddMigration(UpFixSequences, DownFixSequences)
|
goose.AddMigration(UpFixSequences, DownFixSequences)
|
||||||
|
goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadFixSequences(m *sqlutil.Migrations) {
|
func LoadFixSequences(m *sqlutil.Migrations) {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package deltas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) {
|
||||||
|
m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec(`
|
||||||
|
ALTER TABLE syncapi_send_to_device
|
||||||
|
DROP COLUMN IF EXISTS sent_by_token;
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec(`
|
||||||
|
ALTER TABLE syncapi_send_to_device
|
||||||
|
ADD COLUMN IF NOT EXISTS sent_by_token TEXT;
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
|
|
@ -38,11 +37,7 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
|
||||||
-- The device ID to send the message to.
|
-- The device ID to send the message to.
|
||||||
device_id TEXT NOT NULL,
|
device_id TEXT NOT NULL,
|
||||||
-- The event content JSON.
|
-- The event content JSON.
|
||||||
content TEXT NOT NULL,
|
content TEXT NOT NULL
|
||||||
-- The token that was supplied to the /sync at the time that this
|
|
||||||
-- message was included in a sync response, or NULL if we haven't
|
|
||||||
-- included it in a /sync response yet.
|
|
||||||
sent_by_token TEXT
|
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
|
|
@ -52,34 +47,26 @@ const insertSendToDeviceMessageSQL = `
|
||||||
RETURNING id
|
RETURNING id
|
||||||
`
|
`
|
||||||
|
|
||||||
const countSendToDeviceMessagesSQL = `
|
|
||||||
SELECT COUNT(*)
|
|
||||||
FROM syncapi_send_to_device
|
|
||||||
WHERE user_id = $1 AND device_id = $2
|
|
||||||
`
|
|
||||||
|
|
||||||
const selectSendToDeviceMessagesSQL = `
|
const selectSendToDeviceMessagesSQL = `
|
||||||
SELECT id, user_id, device_id, content, sent_by_token
|
SELECT id, user_id, device_id, content
|
||||||
FROM syncapi_send_to_device
|
FROM syncapi_send_to_device
|
||||||
WHERE user_id = $1 AND device_id = $2
|
WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
|
||||||
ORDER BY id DESC
|
ORDER BY id DESC
|
||||||
`
|
`
|
||||||
|
|
||||||
const updateSentSendToDeviceMessagesSQL = `
|
const deleteSendToDeviceMessagesSQL = `
|
||||||
UPDATE syncapi_send_to_device SET sent_by_token = $1
|
DELETE FROM syncapi_send_to_device
|
||||||
WHERE id = ANY($2)
|
WHERE user_id = $1 AND device_id = $2 AND id < $3
|
||||||
`
|
`
|
||||||
|
|
||||||
const deleteSendToDeviceMessagesSQL = `
|
const selectMaxSendToDeviceIDSQL = "" +
|
||||||
DELETE FROM syncapi_send_to_device WHERE id = ANY($1)
|
"SELECT MAX(id) FROM syncapi_send_to_device"
|
||||||
`
|
|
||||||
|
|
||||||
type sendToDeviceStatements struct {
|
type sendToDeviceStatements struct {
|
||||||
insertSendToDeviceMessageStmt *sql.Stmt
|
insertSendToDeviceMessageStmt *sql.Stmt
|
||||||
countSendToDeviceMessagesStmt *sql.Stmt
|
|
||||||
selectSendToDeviceMessagesStmt *sql.Stmt
|
selectSendToDeviceMessagesStmt *sql.Stmt
|
||||||
updateSentSendToDeviceMessagesStmt *sql.Stmt
|
|
||||||
deleteSendToDeviceMessagesStmt *sql.Stmt
|
deleteSendToDeviceMessagesStmt *sql.Stmt
|
||||||
|
selectMaxSendToDeviceIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
||||||
|
|
@ -91,16 +78,13 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
||||||
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
|
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
|
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.updateSentSendToDeviceMessagesStmt, err = db.Prepare(updateSentSendToDeviceMessagesSQL); err != nil {
|
if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
|
if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
|
|
@ -113,64 +97,55 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
|
|
||||||
ctx context.Context, txn *sql.Tx, userID, deviceID string,
|
|
||||||
) (count int, err error) {
|
|
||||||
row := sqlutil.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
|
|
||||||
if err = row.Scan(&count); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
||||||
ctx context.Context, txn *sql.Tx, userID, deviceID string,
|
ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
|
||||||
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
|
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
|
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var id types.SendToDeviceNID
|
var id types.StreamPosition
|
||||||
var userID, deviceID, content string
|
var userID, deviceID, content string
|
||||||
var sentByToken *string
|
if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
|
||||||
if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if id > lastPos {
|
||||||
|
lastPos = id
|
||||||
|
}
|
||||||
event := types.SendToDeviceEvent{
|
event := types.SendToDeviceEvent{
|
||||||
ID: id,
|
ID: id,
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
DeviceID: deviceID,
|
DeviceID: deviceID,
|
||||||
}
|
}
|
||||||
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
|
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
|
||||||
return
|
continue
|
||||||
}
|
|
||||||
if sentByToken != nil {
|
|
||||||
if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
|
|
||||||
event.SentByToken = &token
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
events = append(events, event)
|
events = append(events, event)
|
||||||
if types.StreamPosition(id) > lastPos {
|
|
||||||
lastPos = types.StreamPosition(id)
|
|
||||||
}
|
}
|
||||||
|
if lastPos == 0 {
|
||||||
|
lastPos = to
|
||||||
}
|
}
|
||||||
|
|
||||||
return lastPos, events, rows.Err()
|
return lastPos, events, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
|
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
||||||
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID,
|
ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = sqlutil.TxStmt(txn, s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids))
|
_, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
|
||||||
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID,
|
ctx context.Context, txn *sql.Tx,
|
||||||
) (err error) {
|
) (id int64, err error) {
|
||||||
_, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids))
|
var nullableID sql.NullInt64
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
|
||||||
|
if nullableID.Valid {
|
||||||
|
id = nullableID.Int64
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -89,6 +89,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrations()
|
m := sqlutil.NewMigrations()
|
||||||
deltas.LoadFixSequences(m)
|
deltas.LoadFixSequences(m)
|
||||||
|
deltas.LoadRemoveSendToDeviceSentColumn(m)
|
||||||
if err = m.RunDeltas(d.db, dbProperties); err != nil {
|
if err = m.RunDeltas(d.db, dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -85,6 +86,14 @@ func (d *Database) MaxStreamPositionForInvites(ctx context.Context) (types.Strea
|
||||||
return types.StreamPosition(id), nil
|
return types.StreamPosition(id), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) {
|
||||||
|
id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err)
|
||||||
|
}
|
||||||
|
return types.StreamPosition(id), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) {
|
func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) {
|
||||||
id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil)
|
id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -168,30 +177,6 @@ func (d *Database) GetEventsInStreamingRange(
|
||||||
return events, err
|
return events, err
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
func (d *Database) AddTypingUser(
|
|
||||||
userID, roomID string, expireTime *time.Time,
|
|
||||||
) types.StreamPosition {
|
|
||||||
return types.StreamPosition(d.EDUCache.AddTypingUser(userID, roomID, expireTime))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) RemoveTypingUser(
|
|
||||||
userID, roomID string,
|
|
||||||
) types.StreamPosition {
|
|
||||||
return types.StreamPosition(d.EDUCache.RemoveUser(userID, roomID))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) {
|
|
||||||
d.EDUCache.SetTimeoutCallback(fn)
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*
|
|
||||||
func (d *Database) AddSendToDevice() types.StreamPosition {
|
|
||||||
return types.StreamPosition(d.EDUCache.AddSendToDeviceMessage())
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
|
func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
|
||||||
return d.CurrentRoomState.SelectJoinedUsers(ctx)
|
return d.CurrentRoomState.SelectJoinedUsers(ctx)
|
||||||
}
|
}
|
||||||
|
|
@ -891,16 +876,6 @@ func (d *Database) currentStateStreamEventsForRoom(
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) SendToDeviceUpdatesWaiting(
|
|
||||||
ctx context.Context, userID, deviceID string,
|
|
||||||
) (bool, error) {
|
|
||||||
count, err := d.SendToDevice.CountSendToDeviceMessages(ctx, nil, userID, deviceID)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return count > 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) StoreNewSendForDeviceMessage(
|
func (d *Database) StoreNewSendForDeviceMessage(
|
||||||
ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
|
ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
|
||||||
) (newPos types.StreamPosition, err error) {
|
) (newPos types.StreamPosition, err error) {
|
||||||
|
|
@ -919,78 +894,38 @@ func (d *Database) StoreNewSendForDeviceMessage(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return 0, nil
|
return newPos, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) SendToDeviceUpdatesForSync(
|
func (d *Database) SendToDeviceUpdatesForSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userID, deviceID string,
|
userID, deviceID string,
|
||||||
token types.StreamingToken,
|
from, to types.StreamPosition,
|
||||||
) (types.StreamPosition, []types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) {
|
) (types.StreamPosition, []types.SendToDeviceEvent, error) {
|
||||||
// First of all, get our send-to-device updates for this user.
|
// First of all, get our send-to-device updates for this user.
|
||||||
lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID)
|
lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, from, to)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
|
return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there's nothing to do then stop here.
|
// If there's nothing to do then stop here.
|
||||||
if len(events) == 0 {
|
if len(events) == 0 {
|
||||||
return 0, nil, nil, nil, nil
|
return to, nil, nil
|
||||||
}
|
}
|
||||||
|
return lastPos, events, nil
|
||||||
// Work out whether we need to update any of the database entries.
|
|
||||||
toReturn := []types.SendToDeviceEvent{}
|
|
||||||
toUpdate := []types.SendToDeviceNID{}
|
|
||||||
toDelete := []types.SendToDeviceNID{}
|
|
||||||
for _, event := range events {
|
|
||||||
if event.SentByToken == nil {
|
|
||||||
// If the event has no sent-by token yet then we haven't attempted to send
|
|
||||||
// it. Record the current requested sync token in the database.
|
|
||||||
toUpdate = append(toUpdate, event.ID)
|
|
||||||
toReturn = append(toReturn, event)
|
|
||||||
event.SentByToken = &token
|
|
||||||
} else if token.IsAfter(*event.SentByToken) {
|
|
||||||
// The event had a sync token, therefore we've sent it before. The current
|
|
||||||
// sync token is now after the stored one so we can assume that the client
|
|
||||||
// successfully completed the previous sync (it would re-request it otherwise)
|
|
||||||
// so we can remove the entry from the database.
|
|
||||||
toDelete = append(toDelete, event.ID)
|
|
||||||
} else {
|
|
||||||
// It looks like the sync is being re-requested, maybe it timed out or
|
|
||||||
// failed. Re-send any that should have been acknowledged by now.
|
|
||||||
toReturn = append(toReturn, event)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return lastPos, toReturn, toUpdate, toDelete, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) CleanSendToDeviceUpdates(
|
func (d *Database) CleanSendToDeviceUpdates(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
toUpdate, toDelete []types.SendToDeviceNID,
|
userID, deviceID string, before types.StreamPosition,
|
||||||
token types.StreamingToken,
|
|
||||||
) (err error) {
|
) (err error) {
|
||||||
if len(toUpdate) == 0 && len(toDelete) == 0 {
|
if err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, userID, deviceID, before)
|
||||||
|
}); err != nil {
|
||||||
|
logrus.WithError(err).Errorf("Failed to clean up old send-to-device messages for user %q device %q", userID, deviceID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// If we need to write to the database then we'll ask the SendToDeviceWriter to
|
|
||||||
// do that for us. It'll guarantee that we don't lock the table for writes in
|
|
||||||
// more than one place.
|
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
|
||||||
// Delete any send-to-device messages marked for deletion.
|
|
||||||
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
|
|
||||||
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now update any outstanding send-to-device messages with the new sync token.
|
|
||||||
if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil {
|
|
||||||
return fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// getMembershipFromEvent returns the value of content.membership iff the event is a state event
|
// getMembershipFromEvent returns the value of content.membership iff the event is a state event
|
||||||
// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
|
// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ import (
|
||||||
|
|
||||||
func LoadFromGoose() {
|
func LoadFromGoose() {
|
||||||
goose.AddMigration(UpFixSequences, DownFixSequences)
|
goose.AddMigration(UpFixSequences, DownFixSequences)
|
||||||
|
goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadFixSequences(m *sqlutil.Migrations) {
|
func LoadFixSequences(m *sqlutil.Migrations) {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package deltas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) {
|
||||||
|
m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec(`
|
||||||
|
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
|
||||||
|
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
|
||||||
|
DROP TABLE syncapi_send_to_device;
|
||||||
|
CREATE TABLE syncapi_send_to_device(
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL
|
||||||
|
);
|
||||||
|
INSERT INTO syncapi_send_to_device SELECT id, user_id, device_id, content FROM syncapi_send_to_device_backup;
|
||||||
|
DROP TABLE syncapi_send_to_device_backup;
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec(`
|
||||||
|
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
|
||||||
|
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
|
||||||
|
DROP TABLE syncapi_send_to_device;
|
||||||
|
CREATE TABLE syncapi_send_to_device(
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
sent_by_token TEXT
|
||||||
|
);
|
||||||
|
INSERT INTO syncapi_send_to_device SELECT id, user_id, device_id, content FROM syncapi_send_to_device_backup;
|
||||||
|
DROP TABLE syncapi_send_to_device_backup;
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -18,12 +18,12 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const sendToDeviceSchema = `
|
const sendToDeviceSchema = `
|
||||||
|
|
@ -36,11 +36,7 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
|
||||||
-- The device ID to send the message to.
|
-- The device ID to send the message to.
|
||||||
device_id TEXT NOT NULL,
|
device_id TEXT NOT NULL,
|
||||||
-- The event content JSON.
|
-- The event content JSON.
|
||||||
content TEXT NOT NULL,
|
content TEXT NOT NULL
|
||||||
-- The token that was supplied to the /sync at the time that this
|
|
||||||
-- message was included in a sync response, or NULL if we haven't
|
|
||||||
-- included it in a /sync response yet.
|
|
||||||
sent_by_token TEXT
|
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
|
|
@ -49,33 +45,27 @@ const insertSendToDeviceMessageSQL = `
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
`
|
`
|
||||||
|
|
||||||
const countSendToDeviceMessagesSQL = `
|
|
||||||
SELECT COUNT(*)
|
|
||||||
FROM syncapi_send_to_device
|
|
||||||
WHERE user_id = $1 AND device_id = $2
|
|
||||||
`
|
|
||||||
|
|
||||||
const selectSendToDeviceMessagesSQL = `
|
const selectSendToDeviceMessagesSQL = `
|
||||||
SELECT id, user_id, device_id, content, sent_by_token
|
SELECT id, user_id, device_id, content
|
||||||
FROM syncapi_send_to_device
|
FROM syncapi_send_to_device
|
||||||
WHERE user_id = $1 AND device_id = $2
|
WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
|
||||||
ORDER BY id DESC
|
ORDER BY id DESC
|
||||||
`
|
`
|
||||||
|
|
||||||
const updateSentSendToDeviceMessagesSQL = `
|
const deleteSendToDeviceMessagesSQL = `
|
||||||
UPDATE syncapi_send_to_device SET sent_by_token = $1
|
DELETE FROM syncapi_send_to_device
|
||||||
WHERE id IN ($2)
|
WHERE user_id = $1 AND device_id = $2 AND id < $3
|
||||||
`
|
`
|
||||||
|
|
||||||
const deleteSendToDeviceMessagesSQL = `
|
const selectMaxSendToDeviceIDSQL = "" +
|
||||||
DELETE FROM syncapi_send_to_device WHERE id IN ($1)
|
"SELECT MAX(id) FROM syncapi_send_to_device"
|
||||||
`
|
|
||||||
|
|
||||||
type sendToDeviceStatements struct {
|
type sendToDeviceStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertSendToDeviceMessageStmt *sql.Stmt
|
insertSendToDeviceMessageStmt *sql.Stmt
|
||||||
selectSendToDeviceMessagesStmt *sql.Stmt
|
selectSendToDeviceMessagesStmt *sql.Stmt
|
||||||
countSendToDeviceMessagesStmt *sql.Stmt
|
deleteSendToDeviceMessagesStmt *sql.Stmt
|
||||||
|
selectMaxSendToDeviceIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
||||||
|
|
@ -86,15 +76,18 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
|
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
|
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -111,75 +104,57 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
|
|
||||||
ctx context.Context, txn *sql.Tx, userID, deviceID string,
|
|
||||||
) (count int, err error) {
|
|
||||||
row := sqlutil.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
|
|
||||||
if err = row.Scan(&count); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
||||||
ctx context.Context, txn *sql.Tx, userID, deviceID string,
|
ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
|
||||||
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
|
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
|
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var id types.SendToDeviceNID
|
var id types.StreamPosition
|
||||||
var userID, deviceID, content string
|
var userID, deviceID, content string
|
||||||
var sentByToken *string
|
if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
|
||||||
if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil {
|
logrus.WithError(err).Errorf("Failed to retrieve send-to-device message")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if id > lastPos {
|
||||||
|
lastPos = id
|
||||||
|
}
|
||||||
event := types.SendToDeviceEvent{
|
event := types.SendToDeviceEvent{
|
||||||
ID: id,
|
ID: id,
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
DeviceID: deviceID,
|
DeviceID: deviceID,
|
||||||
}
|
}
|
||||||
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
|
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
|
||||||
return
|
logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
|
||||||
}
|
continue
|
||||||
if sentByToken != nil {
|
|
||||||
if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
|
|
||||||
event.SentByToken = &token
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
events = append(events, event)
|
events = append(events, event)
|
||||||
if types.StreamPosition(id) > lastPos {
|
|
||||||
lastPos = types.StreamPosition(id)
|
|
||||||
}
|
}
|
||||||
|
if lastPos == 0 {
|
||||||
|
lastPos = to
|
||||||
}
|
}
|
||||||
|
|
||||||
return lastPos, events, rows.Err()
|
return lastPos, events, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
|
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
||||||
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID,
|
ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
query := strings.Replace(updateSentSendToDeviceMessagesSQL, "($2)", sqlutil.QueryVariadic(1+len(nids)), 1)
|
_, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos)
|
||||||
params := make([]interface{}, 1+len(nids))
|
|
||||||
params[0] = token
|
|
||||||
for k, v := range nids {
|
|
||||||
params[k+1] = v
|
|
||||||
}
|
|
||||||
_, err = txn.ExecContext(ctx, query, params...)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
|
||||||
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID,
|
ctx context.Context, txn *sql.Tx,
|
||||||
) (err error) {
|
) (id int64, err error) {
|
||||||
query := strings.Replace(deleteSendToDeviceMessagesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
|
var nullableID sql.NullInt64
|
||||||
params := make([]interface{}, 1+len(nids))
|
stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt)
|
||||||
for k, v := range nids {
|
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
|
||||||
params[k] = v
|
if nullableID.Valid {
|
||||||
|
id = nullableID.Int64
|
||||||
}
|
}
|
||||||
_, err = txn.ExecContext(ctx, query, params...)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -102,6 +102,7 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrations()
|
m := sqlutil.NewMigrations()
|
||||||
deltas.LoadFixSequences(m)
|
deltas.LoadFixSequences(m)
|
||||||
|
deltas.LoadRemoveSendToDeviceSentColumn(m)
|
||||||
if err = m.RunDeltas(d.db, dbProperties); err != nil {
|
if err = m.RunDeltas(d.db, dbProperties); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -147,10 +147,9 @@ type BackwardsExtremities interface {
|
||||||
// sync response, as the client is seemingly trying to repeat the same /sync.
|
// sync response, as the client is seemingly trying to repeat the same /sync.
|
||||||
type SendToDevice interface {
|
type SendToDevice interface {
|
||||||
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error)
|
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error)
|
||||||
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error)
|
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error)
|
||||||
UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error)
|
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from types.StreamPosition) (err error)
|
||||||
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error)
|
SelectMaxSendToDeviceMessageID(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
||||||
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Filter interface {
|
type Filter interface {
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,16 @@ type SendToDeviceStreamProvider struct {
|
||||||
StreamProvider
|
StreamProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *SendToDeviceStreamProvider) Setup() {
|
||||||
|
p.StreamProvider.Setup()
|
||||||
|
|
||||||
|
id, err := p.DB.MaxStreamPositionForSendToDeviceMessages(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
p.latest = id
|
||||||
|
}
|
||||||
|
|
||||||
func (p *SendToDeviceStreamProvider) CompleteSync(
|
func (p *SendToDeviceStreamProvider) CompleteSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *types.SyncRequest,
|
req *types.SyncRequest,
|
||||||
|
|
@ -23,24 +33,19 @@ func (p *SendToDeviceStreamProvider) IncrementalSync(
|
||||||
from, to types.StreamPosition,
|
from, to types.StreamPosition,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
// See if we have any new tasks to do for the send-to-device messaging.
|
// See if we have any new tasks to do for the send-to-device messaging.
|
||||||
lastPos, events, updates, deletions, err := p.DB.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, req.Since)
|
lastPos, events, err := p.DB.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed")
|
req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed")
|
||||||
return from
|
return from
|
||||||
}
|
}
|
||||||
|
|
||||||
// Before we return the sync response, make sure that we take action on
|
if len(events) > 0 {
|
||||||
// any send-to-device database updates or deletions that we need to do.
|
// Clean up old send-to-device messages from before this stream position.
|
||||||
// Then add the updates into the sync response.
|
if err := p.DB.CleanSendToDeviceUpdates(req.Context, req.Device.UserID, req.Device.ID, from); err != nil {
|
||||||
if len(updates) > 0 || len(deletions) > 0 {
|
|
||||||
// Handle the updates and deletions in the database.
|
|
||||||
err = p.DB.CleanSendToDeviceUpdates(context.Background(), updates, deletions, req.Since)
|
|
||||||
if err != nil {
|
|
||||||
req.Log.WithError(err).Error("p.DB.CleanSendToDeviceUpdates failed")
|
req.Log.WithError(err).Error("p.DB.CleanSendToDeviceUpdates failed")
|
||||||
return from
|
return from
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if len(events) > 0 {
|
|
||||||
// Add the updates into the sync response.
|
// Add the updates into the sync response.
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
req.Response.ToDevice.Events = append(req.Response.ToDevice.Events, event.SendToDeviceEvent)
|
req.Response.ToDevice.Events = append(req.Response.ToDevice.Events, event.SendToDeviceEvent)
|
||||||
|
|
|
||||||
|
|
@ -492,14 +492,11 @@ func NewLeaveResponse() *LeaveResponse {
|
||||||
return &res
|
return &res
|
||||||
}
|
}
|
||||||
|
|
||||||
type SendToDeviceNID int
|
|
||||||
|
|
||||||
type SendToDeviceEvent struct {
|
type SendToDeviceEvent struct {
|
||||||
gomatrixserverlib.SendToDeviceEvent
|
gomatrixserverlib.SendToDeviceEvent
|
||||||
ID SendToDeviceNID
|
ID StreamPosition
|
||||||
UserID string
|
UserID string
|
||||||
DeviceID string
|
DeviceID string
|
||||||
SentByToken *StreamingToken
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PeekingDevice struct {
|
type PeekingDevice struct {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue