mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-01 03:03:10 -06:00
Merge branch 'main' into implement-push-notifications
This commit is contained in:
commit
dd2518c269
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
|
|
@ -63,7 +63,7 @@ jobs:
|
|||
# Run Complement
|
||||
- run: |
|
||||
set -o pipefail &&
|
||||
go test -v -p 1 -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt
|
||||
go test -v -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt
|
||||
shell: bash
|
||||
name: Run Complement Tests
|
||||
env:
|
||||
|
|
|
|||
|
|
@ -1,9 +1,5 @@
|
|||
#!/bin/sh
|
||||
|
||||
<<<<<<< HEAD
|
||||
for db in userapi_accounts userapi_devices pushserver mediaapi syncapi roomserver keyserver federationapi appservice naffka; do
|
||||
=======
|
||||
for db in userapi_accounts userapi_devices mediaapi syncapi roomserver keyserver federationapi appservice mscs; do
|
||||
>>>>>>> main
|
||||
createdb -U dendrite -O dendrite dendrite_$db
|
||||
done
|
||||
|
|
|
|||
|
|
@ -144,21 +144,23 @@ func (u *UserInteractive) AddCompletedStage(sessionID, authType string) {
|
|||
delete(u.Sessions, sessionID)
|
||||
}
|
||||
|
||||
// Challenge returns an HTTP 401 with the supported flows for authenticating
|
||||
func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
|
||||
return &util.JSONResponse{
|
||||
Code: 401,
|
||||
JSON: struct {
|
||||
type Challenge struct {
|
||||
Completed []string `json:"completed"`
|
||||
Flows []userInteractiveFlow `json:"flows"`
|
||||
Session string `json:"session"`
|
||||
// TODO: Return any additional `params`
|
||||
Params map[string]interface{} `json:"params"`
|
||||
}{
|
||||
u.Completed,
|
||||
u.Flows,
|
||||
sessionID,
|
||||
make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Challenge returns an HTTP 401 with the supported flows for authenticating
|
||||
func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse {
|
||||
return &util.JSONResponse{
|
||||
Code: 401,
|
||||
JSON: Challenge{
|
||||
Completed: u.Completed,
|
||||
Flows: u.Flows,
|
||||
Session: sessionID,
|
||||
Params: make(map[string]interface{}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ func AuthFallback(
|
|||
}
|
||||
|
||||
// Success. Add recaptcha as a completed login flow
|
||||
AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
|
||||
|
||||
serveSuccess()
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-devices
|
||||
|
|
@ -163,6 +164,15 @@ func DeleteDeviceById(
|
|||
req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.UserInternalAPI, device *api.Device,
|
||||
deviceID string,
|
||||
) util.JSONResponse {
|
||||
var (
|
||||
deleteOK bool
|
||||
sessionID string
|
||||
)
|
||||
defer func() {
|
||||
if deleteOK {
|
||||
sessions.deleteSession(sessionID)
|
||||
}
|
||||
}()
|
||||
ctx := req.Context()
|
||||
defer req.Body.Close() // nolint:errcheck
|
||||
bodyBytes, err := ioutil.ReadAll(req.Body)
|
||||
|
|
@ -172,8 +182,29 @@ func DeleteDeviceById(
|
|||
JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()),
|
||||
}
|
||||
}
|
||||
|
||||
// check that we know this session, and it matches with the device to delete
|
||||
s := gjson.GetBytes(bodyBytes, "auth.session").Str
|
||||
if dev, ok := sessions.getDeviceToDelete(s); ok {
|
||||
if dev != deviceID {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: jsonerror.Forbidden("session & device mismatch"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s != "" {
|
||||
sessionID = s
|
||||
}
|
||||
|
||||
login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, device)
|
||||
if errRes != nil {
|
||||
switch data := errRes.JSON.(type) {
|
||||
case auth.Challenge:
|
||||
sessions.addDeviceToDelete(data.Session, deviceID)
|
||||
default:
|
||||
}
|
||||
return *errRes
|
||||
}
|
||||
|
||||
|
|
@ -201,6 +232,8 @@ func DeleteDeviceById(
|
|||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
deleteOK = true
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: struct{}{},
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ func UploadCrossSigningDeviceKeys(
|
|||
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {
|
||||
return *authErr
|
||||
}
|
||||
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||
|
||||
uploadReq.UserID = device.UserID
|
||||
keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes)
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ func Password(
|
|||
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
|
||||
return *authErr
|
||||
}
|
||||
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||
|
||||
// Check the new password strength.
|
||||
if resErr = validatePassword(r.NewPassword); resErr != nil {
|
||||
|
|
|
|||
|
|
@ -72,14 +72,23 @@ func init() {
|
|||
// sessionsDict keeps track of completed auth stages for each session.
|
||||
// It shouldn't be passed by value because it contains a mutex.
|
||||
type sessionsDict struct {
|
||||
sync.Mutex
|
||||
sync.RWMutex
|
||||
sessions map[string][]authtypes.LoginType
|
||||
params map[string]registerRequest
|
||||
timer map[string]*time.Timer
|
||||
// deleteSessionToDeviceID protects requests to DELETE /devices/{deviceID} from being abused.
|
||||
// If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
|
||||
// the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
|
||||
deleteSessionToDeviceID map[string]string
|
||||
}
|
||||
|
||||
// GetCompletedStages returns the completed stages for a session.
|
||||
func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
// defaultTimeout is the timeout used to clean up sessions
|
||||
const defaultTimeOut = time.Minute * 5
|
||||
|
||||
// getCompletedStages returns the completed stages for a session.
|
||||
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
|
||||
if completedStages, ok := d.sessions[sessionID]; ok {
|
||||
return completedStages
|
||||
|
|
@ -88,28 +97,95 @@ func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginTyp
|
|||
return make([]authtypes.LoginType, 0)
|
||||
}
|
||||
|
||||
func newSessionsDict() *sessionsDict {
|
||||
return &sessionsDict{
|
||||
sessions: make(map[string][]authtypes.LoginType),
|
||||
// addParams adds a registerRequest to a sessionID and starts a timer to delete that registerRequest
|
||||
func (d *sessionsDict) addParams(sessionID string, r registerRequest) {
|
||||
d.startTimer(defaultTimeOut, sessionID)
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
d.params[sessionID] = r
|
||||
}
|
||||
|
||||
func (d *sessionsDict) getParams(sessionID string) (registerRequest, bool) {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
r, ok := d.params[sessionID]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
// deleteSession cleans up a given session, either because the registration completed
|
||||
// successfully, or because a given timeout (default: 5min) was reached.
|
||||
func (d *sessionsDict) deleteSession(sessionID string) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
delete(d.params, sessionID)
|
||||
delete(d.sessions, sessionID)
|
||||
delete(d.deleteSessionToDeviceID, sessionID)
|
||||
// stop the timer, e.g. because the registration was completed
|
||||
if t, ok := d.timer[sessionID]; ok {
|
||||
if !t.Stop() {
|
||||
select {
|
||||
case <-t.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
delete(d.timer, sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
// AddCompletedSessionStage records that a session has completed an auth stage.
|
||||
func AddCompletedSessionStage(sessionID string, stage authtypes.LoginType) {
|
||||
sessions.Lock()
|
||||
defer sessions.Unlock()
|
||||
func newSessionsDict() *sessionsDict {
|
||||
return &sessionsDict{
|
||||
sessions: make(map[string][]authtypes.LoginType),
|
||||
params: make(map[string]registerRequest),
|
||||
timer: make(map[string]*time.Timer),
|
||||
deleteSessionToDeviceID: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
for _, completedStage := range sessions.sessions[sessionID] {
|
||||
func (d *sessionsDict) startTimer(duration time.Duration, sessionID string) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
t, ok := d.timer[sessionID]
|
||||
if ok {
|
||||
if !t.Stop() {
|
||||
<-t.C
|
||||
}
|
||||
t.Reset(duration)
|
||||
return
|
||||
}
|
||||
d.timer[sessionID] = time.AfterFunc(duration, func() {
|
||||
d.deleteSession(sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
// addCompletedSessionStage records that a session has completed an auth stage
|
||||
// also starts a timer to delete the session once done.
|
||||
func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtypes.LoginType) {
|
||||
d.startTimer(defaultTimeOut, sessionID)
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
for _, completedStage := range d.sessions[sessionID] {
|
||||
if completedStage == stage {
|
||||
return
|
||||
}
|
||||
}
|
||||
sessions.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
|
||||
d.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
|
||||
}
|
||||
|
||||
func (d *sessionsDict) addDeviceToDelete(sessionID, deviceID string) {
|
||||
d.startTimer(defaultTimeOut, sessionID)
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
d.deleteSessionToDeviceID[sessionID] = deviceID
|
||||
}
|
||||
|
||||
func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
deviceID, ok := d.deleteSessionToDeviceID[sessionID]
|
||||
return deviceID, ok
|
||||
}
|
||||
|
||||
var (
|
||||
// TODO: Remove old sessions. Need to do so on a session-specific timeout.
|
||||
// sessions stores the completed flow stages for all sessions. Referenced using their sessionID.
|
||||
sessions = newSessionsDict()
|
||||
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
||||
)
|
||||
|
|
@ -167,7 +243,7 @@ func newUserInteractiveResponse(
|
|||
params map[string]interface{},
|
||||
) userInteractiveResponse {
|
||||
return userInteractiveResponse{
|
||||
fs, sessions.GetCompletedStages(sessionID), params, sessionID,
|
||||
fs, sessions.getCompletedStages(sessionID), params, sessionID,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -645,12 +721,12 @@ func handleRegistrationFlow(
|
|||
}
|
||||
|
||||
// Add Recaptcha to the list of completed registration stages
|
||||
AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
|
||||
|
||||
case authtypes.LoginTypeDummy:
|
||||
// there is nothing to do
|
||||
// Add Dummy to the list of completed registration stages
|
||||
AddCompletedSessionStage(sessionID, authtypes.LoginTypeDummy)
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy)
|
||||
|
||||
case "":
|
||||
// An empty auth type means that we want to fetch the available
|
||||
|
|
@ -666,7 +742,7 @@ func handleRegistrationFlow(
|
|||
// Check if the user's registration flow has been completed successfully
|
||||
// A response with current registration flow and remaining available methods
|
||||
// will be returned if a flow has not been successfully completed yet
|
||||
return checkAndCompleteFlow(sessions.GetCompletedStages(sessionID),
|
||||
return checkAndCompleteFlow(sessions.getCompletedStages(sessionID),
|
||||
req, r, sessionID, cfg, userAPI)
|
||||
}
|
||||
|
||||
|
|
@ -708,7 +784,7 @@ func handleApplicationServiceRegistration(
|
|||
// Don't need to worry about appending to registration stages as
|
||||
// application service registration is entirely separate.
|
||||
return completeRegistration(
|
||||
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(),
|
||||
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session,
|
||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService,
|
||||
)
|
||||
}
|
||||
|
|
@ -727,11 +803,11 @@ func checkAndCompleteFlow(
|
|||
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
|
||||
// This flow was completed, registration can continue
|
||||
return completeRegistration(
|
||||
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(),
|
||||
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID,
|
||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser,
|
||||
)
|
||||
}
|
||||
|
||||
sessions.addParams(sessionID, r)
|
||||
// There are still more stages to complete.
|
||||
// Return the flows and those that have been completed.
|
||||
return util.JSONResponse{
|
||||
|
|
@ -750,11 +826,25 @@ func checkAndCompleteFlow(
|
|||
func completeRegistration(
|
||||
ctx context.Context,
|
||||
userAPI userapi.UserInternalAPI,
|
||||
username, password, appserviceID, ipAddr, userAgent string,
|
||||
username, password, appserviceID, ipAddr, userAgent, sessionID string,
|
||||
inhibitLogin eventutil.WeakBoolean,
|
||||
displayName, deviceID *string,
|
||||
accType userapi.AccountType,
|
||||
) util.JSONResponse {
|
||||
var registrationOK bool
|
||||
defer func() {
|
||||
if registrationOK {
|
||||
sessions.deleteSession(sessionID)
|
||||
}
|
||||
}()
|
||||
|
||||
if data, ok := sessions.getParams(sessionID); ok {
|
||||
username = data.Username
|
||||
password = data.Password
|
||||
deviceID = data.DeviceID
|
||||
displayName = data.InitialDisplayName
|
||||
inhibitLogin = data.InhibitLogin
|
||||
}
|
||||
if username == "" {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
|
|
@ -795,6 +885,7 @@ func completeRegistration(
|
|||
// Check whether inhibit_login option is set. If so, don't create an access
|
||||
// token or a device for this user
|
||||
if inhibitLogin {
|
||||
registrationOK = true
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: registerResponse{
|
||||
|
|
@ -828,6 +919,7 @@ func completeRegistration(
|
|||
}
|
||||
}
|
||||
|
||||
registrationOK = true
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: registerResponse{
|
||||
|
|
@ -976,5 +1068,5 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS
|
|||
if ssrr.Admin {
|
||||
accType = userapi.AccountTypeAdmin
|
||||
}
|
||||
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID, accType)
|
||||
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ package routing
|
|||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
|
|
@ -140,7 +141,7 @@ func TestFlowCheckingExtraneousIncorrectInput(t *testing.T) {
|
|||
func TestEmptyCompletedFlows(t *testing.T) {
|
||||
fakeEmptySessions := newSessionsDict()
|
||||
fakeSessionID := "aRandomSessionIDWhichDoesNotExist"
|
||||
ret := fakeEmptySessions.GetCompletedStages(fakeSessionID)
|
||||
ret := fakeEmptySessions.getCompletedStages(fakeSessionID)
|
||||
|
||||
// check for []
|
||||
if ret == nil || len(ret) != 0 {
|
||||
|
|
@ -208,3 +209,55 @@ func TestValidationOfApplicationServices(t *testing.T) {
|
|||
t.Errorf("user_id should not have been valid: @_something_else:localhost")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionCleanUp(t *testing.T) {
|
||||
s := newSessionsDict()
|
||||
|
||||
t.Run("session is cleaned up after a while", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dummySession := "helloWorld"
|
||||
// manually added, as s.addParams() would start the timer with the default timeout
|
||||
s.params[dummySession] = registerRequest{Username: "Testing"}
|
||||
s.startTimer(time.Millisecond, dummySession)
|
||||
time.Sleep(time.Millisecond * 2)
|
||||
if data, ok := s.getParams(dummySession); ok {
|
||||
t.Errorf("expected session to be deleted: %+v", data)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("session is deleted, once the registration completed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dummySession := "helloWorld2"
|
||||
s.startTimer(time.Minute, dummySession)
|
||||
s.deleteSession(dummySession)
|
||||
if data, ok := s.getParams(dummySession); ok {
|
||||
t.Errorf("expected session to be deleted: %+v", data)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("session timer is restarted after second call", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dummySession := "helloWorld3"
|
||||
// the following will start a timer with the default timeout of 5min
|
||||
s.addParams(dummySession, registerRequest{Username: "Testing"})
|
||||
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha)
|
||||
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy)
|
||||
s.addDeviceToDelete(dummySession, "dummyDevice")
|
||||
s.getCompletedStages(dummySession)
|
||||
// reset the timer with a lower timeout
|
||||
s.startTimer(time.Millisecond, dummySession)
|
||||
time.Sleep(time.Millisecond * 2)
|
||||
if data, ok := s.getParams(dummySession); ok {
|
||||
t.Errorf("expected session to be deleted: %+v", data)
|
||||
}
|
||||
if _, ok := s.timer[dummySession]; ok {
|
||||
t.Error("expected timer to be delete")
|
||||
}
|
||||
if _, ok := s.sessions[dummySession]; ok {
|
||||
t.Error("expected session to be delete")
|
||||
}
|
||||
if _, ok := s.getDeviceToDelete(dummySession); ok {
|
||||
t.Error("expected session to device to be delete")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -48,6 +48,8 @@ Example:
|
|||
# read password from stdin
|
||||
%s --config dendrite.yaml -username alice -passwordstdin < my.pass
|
||||
cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin
|
||||
# reset password for a user, can be used with a combination above to read the password
|
||||
%s --config dendrite.yaml -reset-password -username alice -password foobarbaz
|
||||
|
||||
Arguments:
|
||||
|
||||
|
|
@ -60,12 +62,13 @@ var (
|
|||
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
|
||||
askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
|
||||
isAdmin = flag.Bool("admin", false, "Create an admin account")
|
||||
resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username")
|
||||
)
|
||||
|
||||
func main() {
|
||||
name := os.Args[0]
|
||||
flag.Usage = func() {
|
||||
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name)
|
||||
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name, name)
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
cfg := setup.ParseFlags(true)
|
||||
|
|
@ -93,6 +96,19 @@ func main() {
|
|||
if *isAdmin {
|
||||
accType = api.AccountTypeAdmin
|
||||
}
|
||||
|
||||
if *resetPassword {
|
||||
err = accountDB.SetPassword(context.Background(), *username, pass)
|
||||
if err != nil {
|
||||
logrus.Fatalf("Failed to update password for user %s: %s", *username, err.Error())
|
||||
}
|
||||
if _, err = accountDB.RemoveAllDevices(context.Background(), *username, ""); err != nil {
|
||||
logrus.Fatalf("Failed to remove all devices: %s", err.Error())
|
||||
}
|
||||
logrus.Infof("Updated password for user %s and invalidated all logins\n", *username)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType)
|
||||
if err != nil {
|
||||
logrus.Fatalln("Failed to create the account:", err.Error())
|
||||
|
|
|
|||
|
|
@ -29,7 +29,6 @@ import (
|
|||
"github.com/matrix-org/dendrite/federationapi"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
"github.com/matrix-org/dendrite/keyserver"
|
||||
"github.com/matrix-org/dendrite/pushserver"
|
||||
"github.com/matrix-org/dendrite/roomserver"
|
||||
"github.com/matrix-org/dendrite/setup"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
|
|
@ -173,13 +172,7 @@ func main() {
|
|||
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"
|
||||
cfg.SyncAPI.Database.ConnectionString = "file:/idb/dendritejs_syncapi.db"
|
||||
cfg.KeyServer.Database.ConnectionString = "file:/idb/dendritejs_e2ekey.db"
|
||||
<<<<<<< HEAD
|
||||
cfg.PushServer.Database.ConnectionString = "file:/idb/dendritejs_pushserver.db"
|
||||
cfg.Global.Kafka.UseNaffka = true
|
||||
cfg.Global.Kafka.Database.ConnectionString = "file:/idb/dendritejs_naffka.db"
|
||||
=======
|
||||
cfg.Global.JetStream.StoragePath = "file:/idb/dendritejs/"
|
||||
>>>>>>> main
|
||||
cfg.Global.TrustedIDServers = []string{
|
||||
"matrix.org", "vector.im",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -345,28 +345,12 @@ user_api:
|
|||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
device_database:
|
||||
connection_string: file:userapi_devices.db
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
# The length of time that a token issued for a relying party from
|
||||
# /_matrix/client/r0/user/{userId}/openid/request_token endpoint
|
||||
# is considered to be valid in milliseconds.
|
||||
# The default lifetime is 3600000ms (60 minutes).
|
||||
# openid_token_lifetime_ms: 3600000
|
||||
|
||||
# Configuration for the Push Server API.
|
||||
push_server:
|
||||
internal_api:
|
||||
listen: http://localhost:7782
|
||||
connect: http://localhost:7782
|
||||
database:
|
||||
connection_string: file:pushserver.db
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
||||
# Configuration for Opentracing.
|
||||
# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on
|
||||
# how this works and how to set it up.
|
||||
|
|
@ -385,9 +369,9 @@ tracing:
|
|||
|
||||
# Logging configuration
|
||||
logging:
|
||||
- type: std
|
||||
- type: std
|
||||
level: info
|
||||
- type: file
|
||||
- type: file
|
||||
# The logging level, must be one of debug, info, warn, error, fatal, panic.
|
||||
level: info
|
||||
params:
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ Group=dendrite
|
|||
WorkingDirectory=/opt/dendrite/
|
||||
ExecStart=/opt/dendrite/bin/dendrite-monolith-server
|
||||
Restart=always
|
||||
LimitNOFILE=65535
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ type FederationClient interface {
|
|||
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
|
||||
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
|
||||
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
|
||||
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
|
||||
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
|
||||
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
|
||||
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
|
||||
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
|
||||
|
|
|
|||
|
|
@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
|
|||
}
|
||||
|
||||
func (a *FederationInternalAPI) MSC2946Spaces(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
|
||||
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||
defer cancel()
|
||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||
return a.federation.MSC2946Spaces(ctx, s, roomID, r)
|
||||
return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly)
|
||||
})
|
||||
if err != nil {
|
||||
return res, err
|
||||
|
|
|
|||
|
|
@ -527,21 +527,21 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships(
|
|||
|
||||
type spacesReq struct {
|
||||
S gomatrixserverlib.ServerName
|
||||
Req gomatrixserverlib.MSC2946SpacesRequest
|
||||
SuggestedOnly bool
|
||||
RoomID string
|
||||
Res gomatrixserverlib.MSC2946SpacesResponse
|
||||
Err *api.FederationClientError
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) MSC2946Spaces(
|
||||
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
|
||||
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
|
||||
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces")
|
||||
defer span.Finish()
|
||||
|
||||
request := spacesReq{
|
||||
S: dst,
|
||||
Req: r,
|
||||
SuggestedOnly: suggestedOnly,
|
||||
RoomID: roomID,
|
||||
}
|
||||
var response spacesReq
|
||||
|
|
|
|||
|
|
@ -378,7 +378,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
|
|||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.Req)
|
||||
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly)
|
||||
if err != nil {
|
||||
ferr, ok := err.(*api.FederationClientError)
|
||||
if ok {
|
||||
|
|
|
|||
10
go.mod
10
go.mod
|
|
@ -19,6 +19,7 @@ require (
|
|||
github.com/getsentry/sentry-go v0.12.0
|
||||
github.com/gologme/log v1.3.0
|
||||
github.com/google/go-cmp v0.5.6
|
||||
github.com/google/uuid v1.2.0
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/gorilla/websocket v1.4.2
|
||||
github.com/h2non/filetype v1.1.3 // indirect
|
||||
|
|
@ -40,10 +41,9 @@ require (
|
|||
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
|
||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed
|
||||
github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902
|
||||
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||
github.com/matryer/is v1.4.0
|
||||
github.com/mattn/go-sqlite3 v1.14.10
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/nats-io/nats-server/v2 v2.3.2
|
||||
|
|
@ -63,11 +63,11 @@ require (
|
|||
github.com/uber/jaeger-lib v2.4.1+incompatible
|
||||
github.com/yggdrasil-network/yggdrasil-go v0.4.2
|
||||
go.uber.org/atomic v1.9.0
|
||||
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a
|
||||
golang.org/x/crypto v0.0.0-20220214200702-86341886e292
|
||||
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
|
||||
golang.org/x/mobile v0.0.0-20220112015953-858099ff7816
|
||||
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
|
||||
golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect
|
||||
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
|
||||
gopkg.in/h2non/bimg.v1 v1.1.5
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
|
|
|
|||
18
go.sum
18
go.sum
|
|
@ -983,15 +983,13 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1
|
|||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed h1:R8EiLWArq7KT96DrUq1xq9scPh8vLwKKeCTnORPyjhU=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY=
|
||||
github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf h1:/nqfHUdQHr3WVdbZieaYFvHF1rin5pvDTa/NOZ/qCyE=
|
||||
github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 h1:WHlrE8BYh/hzn1RKwq3YMAlhHivX47jQKAjZFtkJyPE=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo=
|
||||
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk=
|
||||
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk=
|
||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||
github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE=
|
||||
github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
|
||||
github.com/mattn/go-colorable v0.0.6/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
|
||||
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
|
||||
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
|
||||
|
|
@ -1512,8 +1510,8 @@ golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5
|
|||
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo=
|
||||
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE=
|
||||
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
|
|
@ -1739,8 +1737,8 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc=
|
||||
golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs=
|
||||
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ const (
|
|||
FederationEventCacheName = "federation_event"
|
||||
FederationEventCacheMaxEntries = 256
|
||||
FederationEventCacheMutable = true // to allow use of Unset only
|
||||
FederationEventCacheMaxAge = CacheNoMaxAge
|
||||
)
|
||||
|
||||
// FederationCache contains the subset of functions needed for
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
package caching
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
|
|
@ -16,6 +18,7 @@ const (
|
|||
RoomInfoCacheName = "roominfo"
|
||||
RoomInfoCacheMaxEntries = 1024
|
||||
RoomInfoCacheMutable = true
|
||||
RoomInfoCacheMaxAge = time.Minute * 5
|
||||
)
|
||||
|
||||
// RoomInfosCache contains the subset of functions needed for
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ const (
|
|||
RoomServerRoomIDsCacheName = "roomserver_room_ids"
|
||||
RoomServerRoomIDsCacheMaxEntries = 1024
|
||||
RoomServerRoomIDsCacheMutable = false
|
||||
RoomServerRoomIDsCacheMaxAge = CacheNoMaxAge
|
||||
)
|
||||
|
||||
type RoomServerCaches interface {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ const (
|
|||
RoomVersionCacheName = "room_versions"
|
||||
RoomVersionCacheMaxEntries = 1024
|
||||
RoomVersionCacheMutable = false
|
||||
RoomVersionCacheMaxAge = CacheNoMaxAge
|
||||
)
|
||||
|
||||
// RoomVersionsCache contains the subset of functions needed for
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ const (
|
|||
ServerKeyCacheName = "server_key"
|
||||
ServerKeyCacheMaxEntries = 4096
|
||||
ServerKeyCacheMutable = true
|
||||
ServerKeyCacheMaxAge = CacheNoMaxAge
|
||||
)
|
||||
|
||||
// ServerKeyCache contains the subset of functions needed for
|
||||
|
|
|
|||
33
internal/caching/cache_space_rooms.go
Normal file
33
internal/caching/cache_space_rooms.go
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
package caching
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const (
|
||||
SpaceSummaryRoomsCacheName = "space_summary_rooms"
|
||||
SpaceSummaryRoomsCacheMaxEntries = 100
|
||||
SpaceSummaryRoomsCacheMutable = true
|
||||
SpaceSummaryRoomsCacheMaxAge = time.Minute * 5
|
||||
)
|
||||
|
||||
type SpaceSummaryRoomsCache interface {
|
||||
GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool)
|
||||
StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse)
|
||||
}
|
||||
|
||||
func (c Caches) GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool) {
|
||||
val, found := c.SpaceSummaryRooms.Get(roomID)
|
||||
if found && val != nil {
|
||||
if resp, ok := val.(gomatrixserverlib.MSC2946SpacesResponse); ok {
|
||||
return resp, true
|
||||
}
|
||||
}
|
||||
return r, false
|
||||
}
|
||||
|
||||
func (c Caches) StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse) {
|
||||
c.SpaceSummaryRooms.Set(roomID, r)
|
||||
}
|
||||
|
|
@ -1,5 +1,7 @@
|
|||
package caching
|
||||
|
||||
import "time"
|
||||
|
||||
// Caches contains a set of references to caches. They may be
|
||||
// different implementations as long as they satisfy the Cache
|
||||
// interface.
|
||||
|
|
@ -10,6 +12,7 @@ type Caches struct {
|
|||
RoomServerRoomIDs Cache // RoomServerNIDsCache
|
||||
RoomInfos Cache // RoomInfoCache
|
||||
FederationEvents Cache // FederationEventsCache
|
||||
SpaceSummaryRooms Cache // SpaceSummaryRoomsCache
|
||||
}
|
||||
|
||||
// Cache is the interface that an implementation must satisfy.
|
||||
|
|
@ -18,3 +21,5 @@ type Cache interface {
|
|||
Set(key string, value interface{})
|
||||
Unset(key string)
|
||||
}
|
||||
|
||||
const CacheNoMaxAge = time.Duration(0)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
RoomVersionCacheName,
|
||||
RoomVersionCacheMutable,
|
||||
RoomVersionCacheMaxEntries,
|
||||
RoomVersionCacheMaxAge,
|
||||
enablePrometheus,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -23,6 +24,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
ServerKeyCacheName,
|
||||
ServerKeyCacheMutable,
|
||||
ServerKeyCacheMaxEntries,
|
||||
ServerKeyCacheMaxAge,
|
||||
enablePrometheus,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -32,6 +34,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
RoomServerRoomIDsCacheName,
|
||||
RoomServerRoomIDsCacheMutable,
|
||||
RoomServerRoomIDsCacheMaxEntries,
|
||||
RoomServerRoomIDsCacheMaxAge,
|
||||
enablePrometheus,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -41,6 +44,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
RoomInfoCacheName,
|
||||
RoomInfoCacheMutable,
|
||||
RoomInfoCacheMaxEntries,
|
||||
RoomInfoCacheMaxAge,
|
||||
enablePrometheus,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -50,6 +54,17 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
FederationEventCacheName,
|
||||
FederationEventCacheMutable,
|
||||
FederationEventCacheMaxEntries,
|
||||
FederationEventCacheMaxAge,
|
||||
enablePrometheus,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
spaceRooms, err := NewInMemoryLRUCachePartition(
|
||||
SpaceSummaryRoomsCacheName,
|
||||
SpaceSummaryRoomsCacheMutable,
|
||||
SpaceSummaryRoomsCacheMaxEntries,
|
||||
SpaceSummaryRoomsCacheMaxAge,
|
||||
enablePrometheus,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -57,7 +72,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
}
|
||||
go cacheCleaner(
|
||||
roomVersions, serverKeys, roomServerRoomIDs,
|
||||
roomInfos, federationEvents,
|
||||
roomInfos, federationEvents, spaceRooms,
|
||||
)
|
||||
return &Caches{
|
||||
RoomVersions: roomVersions,
|
||||
|
|
@ -65,6 +80,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
|
|||
RoomServerRoomIDs: roomServerRoomIDs,
|
||||
RoomInfos: roomInfos,
|
||||
FederationEvents: federationEvents,
|
||||
SpaceSummaryRooms: spaceRooms,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -86,15 +102,22 @@ type InMemoryLRUCachePartition struct {
|
|||
name string
|
||||
mutable bool
|
||||
maxEntries int
|
||||
maxAge time.Duration
|
||||
lru *lru.Cache
|
||||
}
|
||||
|
||||
func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, enablePrometheus bool) (*InMemoryLRUCachePartition, error) {
|
||||
type inMemoryLRUCacheEntry struct {
|
||||
value interface{}
|
||||
created time.Time
|
||||
}
|
||||
|
||||
func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, maxAge time.Duration, enablePrometheus bool) (*InMemoryLRUCachePartition, error) {
|
||||
var err error
|
||||
cache := InMemoryLRUCachePartition{
|
||||
name: name,
|
||||
mutable: mutable,
|
||||
maxEntries: maxEntries,
|
||||
maxAge: maxAge,
|
||||
}
|
||||
cache.lru, err = lru.New(maxEntries)
|
||||
if err != nil {
|
||||
|
|
@ -114,11 +137,16 @@ func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, ena
|
|||
|
||||
func (c *InMemoryLRUCachePartition) Set(key string, value interface{}) {
|
||||
if !c.mutable {
|
||||
if peek, ok := c.lru.Peek(key); ok && peek != value {
|
||||
if peek, ok := c.lru.Peek(key); ok {
|
||||
if entry, ok := peek.(*inMemoryLRUCacheEntry); ok && entry.value != value {
|
||||
panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key))
|
||||
}
|
||||
}
|
||||
c.lru.Add(key, value)
|
||||
}
|
||||
c.lru.Add(key, &inMemoryLRUCacheEntry{
|
||||
value: value,
|
||||
created: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *InMemoryLRUCachePartition) Unset(key string) {
|
||||
|
|
@ -129,5 +157,20 @@ func (c *InMemoryLRUCachePartition) Unset(key string) {
|
|||
}
|
||||
|
||||
func (c *InMemoryLRUCachePartition) Get(key string) (value interface{}, ok bool) {
|
||||
return c.lru.Get(key)
|
||||
v, ok := c.lru.Get(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
entry, ok := v.(*inMemoryLRUCacheEntry)
|
||||
switch {
|
||||
case ok && c.maxAge == CacheNoMaxAge:
|
||||
return entry.value, ok // There's no maximum age policy
|
||||
case ok && time.Since(entry.created) < c.maxAge:
|
||||
return entry.value, ok // The value for the key isn't stale
|
||||
default:
|
||||
// Either the key was found and it was stale, or the key
|
||||
// wasn't found at all
|
||||
c.lru.Remove(key)
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -166,8 +166,10 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
|
|||
}
|
||||
|
||||
// We can't have a self-signing or user-signing key without a master
|
||||
// key, so make sure we have one of those.
|
||||
if !hasMasterKey {
|
||||
// key, so make sure we have one of those. We will also only actually do
|
||||
// something if any of the specified keys in the request are different
|
||||
// to what we've got in the database, to avoid generating key change
|
||||
// notifications unnecessarily.
|
||||
existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
|
|
@ -176,18 +178,43 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
|
|||
return
|
||||
}
|
||||
|
||||
_, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]
|
||||
}
|
||||
|
||||
// If we still can't find a master key for the user then stop the upload.
|
||||
// This satisfies the "Fails to upload self-signing key without master key" test.
|
||||
if !hasMasterKey {
|
||||
if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "No master key was found",
|
||||
IsMissingParam: true,
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if anything actually changed compared to what we have in the database.
|
||||
changed := false
|
||||
for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{
|
||||
gomatrixserverlib.CrossSigningKeyPurposeMaster,
|
||||
gomatrixserverlib.CrossSigningKeyPurposeSelfSigning,
|
||||
gomatrixserverlib.CrossSigningKeyPurposeUserSigning,
|
||||
} {
|
||||
old, gotOld := existingKeys[purpose]
|
||||
new, gotNew := toStore[purpose]
|
||||
if gotOld != gotNew {
|
||||
// A new key purpose has been specified that we didn't know before,
|
||||
// or one has been removed.
|
||||
changed = true
|
||||
break
|
||||
}
|
||||
if !bytes.Equal(old, new) {
|
||||
// One of the existing keys for a purpose we already knew about has
|
||||
// changed.
|
||||
changed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
|
||||
// Store the keys.
|
||||
if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil {
|
||||
|
|
|
|||
|
|
@ -48,13 +48,19 @@ type mockDeviceListUpdaterDatabase struct {
|
|||
staleUsers map[string]bool
|
||||
prevIDsExist func(string, []int) bool
|
||||
storedKeys []api.DeviceMessage
|
||||
mu sync.Mutex // protect staleUsers
|
||||
}
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
var result []string
|
||||
for userID := range d.staleUsers {
|
||||
for userID, isStale := range d.staleUsers {
|
||||
if !isStale {
|
||||
continue
|
||||
}
|
||||
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -75,6 +81,8 @@ func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, do
|
|||
|
||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||
func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
d.staleUsers[userID] = isStale
|
||||
return nil
|
||||
}
|
||||
|
|
@ -247,3 +255,82 @@ func TestUpdateNoPrevID(t *testing.T) {
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the
|
||||
// update is still ongoing.
|
||||
func TestDebounce(t *testing.T) {
|
||||
db := &mockDeviceListUpdaterDatabase{
|
||||
staleUsers: make(map[string]bool),
|
||||
prevIDsExist: func(string, []int) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
ap := &mockDeviceListUpdaterAPI{}
|
||||
producer := &mockKeyChangeProducer{}
|
||||
fedCh := make(chan *http.Response, 1)
|
||||
srv := gomatrixserverlib.ServerName("example.com")
|
||||
userID := "@alice:example.com"
|
||||
keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
|
||||
incomingFedReq := make(chan struct{})
|
||||
fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
|
||||
if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) {
|
||||
return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
|
||||
}
|
||||
close(incomingFedReq)
|
||||
return <-fedCh, nil
|
||||
})
|
||||
updater := NewDeviceListUpdater(db, ap, producer, fedClient, 1)
|
||||
if err := updater.Start(); err != nil {
|
||||
t.Fatalf("failed to start updater: %s", err)
|
||||
}
|
||||
|
||||
// hit this 5 times
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(5)
|
||||
for i := 0; i < 5; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil {
|
||||
t.Errorf("ManualUpdate: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// wait until the updater hits federation
|
||||
select {
|
||||
case <-incomingFedReq:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timed out waiting for updater to hit federation")
|
||||
}
|
||||
|
||||
// user should be marked as stale
|
||||
if !db.staleUsers[userID] {
|
||||
t.Errorf("user %s not marked as stale", userID)
|
||||
}
|
||||
// now send the response over federation
|
||||
fedCh <- &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`
|
||||
{
|
||||
"user_id": "` + userID + `",
|
||||
"stream_id": 5,
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "JLAFKJWSCS",
|
||||
"keys": ` + keyJSON + `,
|
||||
"device_display_name": "Mobile Phone"
|
||||
}
|
||||
]
|
||||
}
|
||||
`)),
|
||||
}
|
||||
close(fedCh)
|
||||
// wait until all 5 ManualUpdates return. If we hit federation again we won't send a response
|
||||
// and should panic with read on a closed channel
|
||||
wg.Wait()
|
||||
|
||||
// user is no longer stale now
|
||||
if db.staleUsers[userID] {
|
||||
t.Errorf("user %s is marked as stale", userID)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -269,6 +269,7 @@ type QueryAuthChainResponse struct {
|
|||
|
||||
type QuerySharedUsersRequest struct {
|
||||
UserID string
|
||||
OtherUserIDs []string
|
||||
ExcludeRoomIDs []string
|
||||
IncludeRoomIDs []string
|
||||
}
|
||||
|
|
@ -313,6 +314,9 @@ type QueryBulkStateContentResponse struct {
|
|||
|
||||
type QueryCurrentStateRequest struct {
|
||||
RoomID string
|
||||
AllowWildcards bool
|
||||
// State key tuples. If a state_key has '*' and AllowWidlcards is true, returns all matching
|
||||
// state events with that event type.
|
||||
StateTuples []gomatrixserverlib.StateKeyTuple
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -51,12 +51,8 @@ func SendEventWithState(
|
|||
state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent,
|
||||
origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool,
|
||||
) error {
|
||||
outliers, err := state.Events(event.RoomVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ires []InputRoomEvent
|
||||
outliers := state.Events(event.RoomVersion)
|
||||
ires := make([]InputRoomEvent, 0, len(outliers))
|
||||
for _, outlier := range outliers {
|
||||
if haveEventIDs[outlier.EventID()] {
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -20,22 +20,17 @@ import (
|
|||
"sort"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/state"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type checkForAuthAndSoftFailStorage interface {
|
||||
state.StateResolutionStorage
|
||||
StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
|
||||
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
||||
}
|
||||
|
||||
// CheckForSoftFail returns true if the event should be soft-failed
|
||||
// and false otherwise. The return error value should be checked before
|
||||
// the soft-fail bool.
|
||||
func CheckForSoftFail(
|
||||
ctx context.Context,
|
||||
db checkForAuthAndSoftFailStorage,
|
||||
db storage.Database,
|
||||
event *gomatrixserverlib.HeaderedEvent,
|
||||
stateEventIDs []string,
|
||||
) (bool, error) {
|
||||
|
|
@ -97,7 +92,7 @@ func CheckForSoftFail(
|
|||
// Returns the numeric IDs for the auth events.
|
||||
func CheckAuthEvents(
|
||||
ctx context.Context,
|
||||
db checkForAuthAndSoftFailStorage,
|
||||
db storage.Database,
|
||||
event *gomatrixserverlib.HeaderedEvent,
|
||||
authEventIDs []string,
|
||||
) ([]types.EventNID, error) {
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
|
@ -40,19 +39,6 @@ import (
|
|||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type retryAction int
|
||||
type commitAction int
|
||||
|
||||
const (
|
||||
doNotRetry retryAction = iota
|
||||
retryLater
|
||||
)
|
||||
|
||||
const (
|
||||
commitTransaction commitAction = iota
|
||||
rollbackTransaction
|
||||
)
|
||||
|
||||
var keyContentFields = map[string]string{
|
||||
"m.room.join_rules": "join_rule",
|
||||
"m.room.history_visibility": "history_visibility",
|
||||
|
|
@ -117,8 +103,7 @@ func (r *Inputer) Start() error {
|
|||
_ = msg.InProgress() // resets the acknowledgement wait timer
|
||||
defer eventsInProgress.Delete(index)
|
||||
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
|
||||
action, err := r.processRoomEventUsingUpdater(r.ProcessContext.Context(), roomID, &inputRoomEvent)
|
||||
if err != nil {
|
||||
if err := r.processRoomEvent(r.ProcessContext.Context(), &inputRoomEvent); err != nil {
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
||||
sentry.CaptureException(err)
|
||||
}
|
||||
|
|
@ -127,11 +112,8 @@ func (r *Inputer) Start() error {
|
|||
"event_id": inputRoomEvent.Event.EventID(),
|
||||
"type": inputRoomEvent.Event.Type(),
|
||||
}).Warn("Roomserver failed to process async event")
|
||||
}
|
||||
switch action {
|
||||
case retryLater:
|
||||
_ = msg.Nak()
|
||||
case doNotRetry:
|
||||
_ = msg.Term()
|
||||
} else {
|
||||
_ = msg.Ack()
|
||||
}
|
||||
})
|
||||
|
|
@ -153,37 +135,6 @@ func (r *Inputer) Start() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// processRoomEventUsingUpdater opens up a room updater and tries to
|
||||
// process the event. It returns whether or not we should positively
|
||||
// or negatively acknowledge the event (i.e. for NATS) and an error
|
||||
// if it occurred.
|
||||
func (r *Inputer) processRoomEventUsingUpdater(
|
||||
ctx context.Context,
|
||||
roomID string,
|
||||
inputRoomEvent *api.InputRoomEvent,
|
||||
) (retryAction, error) {
|
||||
roomInfo, err := r.DB.RoomInfo(ctx, roomID)
|
||||
if err != nil {
|
||||
return doNotRetry, fmt.Errorf("r.DB.RoomInfo: %w", err)
|
||||
}
|
||||
updater, err := r.DB.GetRoomUpdater(ctx, roomInfo)
|
||||
if err != nil {
|
||||
return retryLater, fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
|
||||
}
|
||||
action, err := r.processRoomEvent(ctx, updater, inputRoomEvent)
|
||||
switch action {
|
||||
case commitTransaction:
|
||||
if cerr := updater.Commit(); cerr != nil {
|
||||
return retryLater, fmt.Errorf("updater.Commit: %w", cerr)
|
||||
}
|
||||
case rollbackTransaction:
|
||||
if rerr := updater.Rollback(); rerr != nil {
|
||||
return retryLater, fmt.Errorf("updater.Rollback: %w", rerr)
|
||||
}
|
||||
}
|
||||
return doNotRetry, err
|
||||
}
|
||||
|
||||
// InputRoomEvents implements api.RoomserverInternalAPI
|
||||
func (r *Inputer) InputRoomEvents(
|
||||
ctx context.Context,
|
||||
|
|
@ -230,7 +181,7 @@ func (r *Inputer) InputRoomEvents(
|
|||
worker.Act(nil, func() {
|
||||
defer eventsInProgress.Delete(index)
|
||||
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
|
||||
_, err := r.processRoomEventUsingUpdater(ctx, roomID, &inputRoomEvent)
|
||||
err := r.processRoomEvent(ctx, &inputRoomEvent)
|
||||
if err != nil {
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
||||
sentry.CaptureException(err)
|
||||
|
|
|
|||
|
|
@ -26,10 +26,10 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/dendrite/internal/hooks"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
||||
"github.com/matrix-org/dendrite/roomserver/state"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
|
|
@ -68,15 +68,14 @@ var processRoomEventDuration = prometheus.NewHistogramVec(
|
|||
// nolint:gocyclo
|
||||
func (r *Inputer) processRoomEvent(
|
||||
ctx context.Context,
|
||||
updater *shared.RoomUpdater,
|
||||
input *api.InputRoomEvent,
|
||||
) (commitAction, error) {
|
||||
) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Before we do anything, make sure the context hasn't expired for this pending task.
|
||||
// If it has then we'll give up straight away — it's probably a synchronous input
|
||||
// request and the caller has already given up, but the inbox task was still queued.
|
||||
return rollbackTransaction, context.DeadlineExceeded
|
||||
return context.DeadlineExceeded
|
||||
default:
|
||||
}
|
||||
|
||||
|
|
@ -109,7 +108,7 @@ func (r *Inputer) processRoomEvent(
|
|||
// if we have already got this event then do not process it again, if the input kind is an outlier.
|
||||
// Outliers contain no extra information which may warrant a re-processing.
|
||||
if input.Kind == api.KindOutlier {
|
||||
evs, err2 := updater.EventsFromIDs(ctx, []string{event.EventID()})
|
||||
evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()})
|
||||
if err2 == nil && len(evs) == 1 {
|
||||
// check hash matches if we're on early room versions where the event ID was a random string
|
||||
idFormat, err2 := headered.RoomVersion.EventIDFormat()
|
||||
|
|
@ -118,11 +117,11 @@ func (r *Inputer) processRoomEvent(
|
|||
case gomatrixserverlib.EventIDFormatV1:
|
||||
if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) {
|
||||
logger.Debugf("Already processed event; ignoring")
|
||||
return rollbackTransaction, nil
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
logger.Debugf("Already processed event; ignoring")
|
||||
return rollbackTransaction, nil
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -131,17 +130,21 @@ func (r *Inputer) processRoomEvent(
|
|||
// Don't waste time processing the event if the room doesn't exist.
|
||||
// A room entry locally will only be created in response to a create
|
||||
// event.
|
||||
roomInfo, rerr := r.DB.RoomInfo(ctx, event.RoomID())
|
||||
if rerr != nil {
|
||||
return fmt.Errorf("r.DB.RoomInfo: %w", rerr)
|
||||
}
|
||||
isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("")
|
||||
if !updater.RoomExists() && !isCreateEvent {
|
||||
return rollbackTransaction, fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
|
||||
if roomInfo == nil && !isCreateEvent {
|
||||
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
|
||||
}
|
||||
|
||||
var missingAuth, missingPrev bool
|
||||
serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{}
|
||||
if !isCreateEvent {
|
||||
missingAuthIDs, missingPrevIDs, err := updater.MissingAuthPrevEvents(ctx, event)
|
||||
missingAuthIDs, missingPrevIDs, err := r.DB.MissingAuthPrevEvents(ctx, event)
|
||||
if err != nil {
|
||||
return rollbackTransaction, fmt.Errorf("updater.MissingAuthPrevEvents: %w", err)
|
||||
return fmt.Errorf("updater.MissingAuthPrevEvents: %w", err)
|
||||
}
|
||||
missingAuth = len(missingAuthIDs) > 0
|
||||
missingPrev = !input.HasState && len(missingPrevIDs) > 0
|
||||
|
|
@ -153,7 +156,7 @@ func (r *Inputer) processRoomEvent(
|
|||
ExcludeSelf: true,
|
||||
}
|
||||
if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
|
||||
return rollbackTransaction, fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
|
||||
return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
|
||||
}
|
||||
// Sort all of the servers into a map so that we can randomise
|
||||
// their order. Then make sure that the input origin and the
|
||||
|
|
@ -182,8 +185,8 @@ func (r *Inputer) processRoomEvent(
|
|||
isRejected := false
|
||||
authEvents := gomatrixserverlib.NewAuthEvents(nil)
|
||||
knownEvents := map[string]*types.Event{}
|
||||
if err := r.fetchAuthEvents(ctx, updater, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
|
||||
return rollbackTransaction, fmt.Errorf("r.fetchAuthEvents: %w", err)
|
||||
if err := r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
|
||||
return fmt.Errorf("r.fetchAuthEvents: %w", err)
|
||||
}
|
||||
|
||||
// Check if the event is allowed by its auth events. If it isn't then
|
||||
|
|
@ -205,12 +208,12 @@ func (r *Inputer) processRoomEvent(
|
|||
// but weren't found.
|
||||
if isRejected {
|
||||
if event.StateKey() != nil {
|
||||
return commitTransaction, fmt.Errorf(
|
||||
return fmt.Errorf(
|
||||
"missing auth event %s for state event %s (type %q, state key %q)",
|
||||
authEventID, event.EventID(), event.Type(), *event.StateKey(),
|
||||
)
|
||||
} else {
|
||||
return commitTransaction, fmt.Errorf(
|
||||
return fmt.Errorf(
|
||||
"missing auth event %s for timeline event %s (type %q)",
|
||||
authEventID, event.EventID(), event.Type(),
|
||||
)
|
||||
|
|
@ -226,7 +229,7 @@ func (r *Inputer) processRoomEvent(
|
|||
// Check that the event passes authentication checks based on the
|
||||
// current room state.
|
||||
var err error
|
||||
softfail, err = helpers.CheckForSoftFail(ctx, updater, headered, input.StateEventIDs)
|
||||
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Error authing soft-failed event")
|
||||
}
|
||||
|
|
@ -250,7 +253,8 @@ func (r *Inputer) processRoomEvent(
|
|||
missingState := missingStateReq{
|
||||
origin: input.Origin,
|
||||
inputer: r,
|
||||
db: updater,
|
||||
db: r.DB,
|
||||
roomInfo: roomInfo,
|
||||
federation: r.FSAPI,
|
||||
keys: r.KeyRing,
|
||||
roomsMu: internal.NewMutexByRoom(),
|
||||
|
|
@ -290,16 +294,16 @@ func (r *Inputer) processRoomEvent(
|
|||
}
|
||||
|
||||
// Store the event.
|
||||
_, _, stateAtEvent, redactionEvent, redactedEventID, err := updater.StoreEvent(ctx, event, authEventNIDs, isRejected)
|
||||
_, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected)
|
||||
if err != nil {
|
||||
return rollbackTransaction, fmt.Errorf("updater.StoreEvent: %w", err)
|
||||
return fmt.Errorf("updater.StoreEvent: %w", err)
|
||||
}
|
||||
|
||||
// if storing this event results in it being redacted then do so.
|
||||
if !isRejected && redactedEventID == event.EventID() {
|
||||
r, rerr := eventutil.RedactEvent(redactionEvent, event)
|
||||
if rerr != nil {
|
||||
return rollbackTransaction, fmt.Errorf("eventutil.RedactEvent: %w", rerr)
|
||||
return fmt.Errorf("eventutil.RedactEvent: %w", rerr)
|
||||
}
|
||||
event = r
|
||||
}
|
||||
|
|
@ -310,23 +314,25 @@ func (r *Inputer) processRoomEvent(
|
|||
if input.Kind == api.KindOutlier {
|
||||
logger.Debug("Stored outlier")
|
||||
hooks.Run(hooks.KindNewEventPersisted, headered)
|
||||
return commitTransaction, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
roomInfo, err := updater.RoomInfo(ctx, event.RoomID())
|
||||
// Request the room info again — it's possible that the room has been
|
||||
// created by now if it didn't exist already.
|
||||
roomInfo, err = r.DB.RoomInfo(ctx, event.RoomID())
|
||||
if err != nil {
|
||||
return rollbackTransaction, fmt.Errorf("updater.RoomInfo: %w", err)
|
||||
return fmt.Errorf("updater.RoomInfo: %w", err)
|
||||
}
|
||||
if roomInfo == nil {
|
||||
return rollbackTransaction, fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID())
|
||||
return fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID())
|
||||
}
|
||||
|
||||
if input.HasState || (!missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0) {
|
||||
// We haven't calculated a state for this event yet.
|
||||
// Lets calculate one.
|
||||
err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected)
|
||||
err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected)
|
||||
if err != nil {
|
||||
return rollbackTransaction, fmt.Errorf("r.calculateAndSetState: %w", err)
|
||||
return fmt.Errorf("r.calculateAndSetState: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -337,16 +343,15 @@ func (r *Inputer) processRoomEvent(
|
|||
"missing_prev": missingPrev,
|
||||
}).Warn("Stored rejected event")
|
||||
if rejectionErr != nil {
|
||||
return commitTransaction, types.RejectedError(rejectionErr.Error())
|
||||
return types.RejectedError(rejectionErr.Error())
|
||||
}
|
||||
return commitTransaction, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
switch input.Kind {
|
||||
case api.KindNew:
|
||||
if err = r.updateLatestEvents(
|
||||
ctx, // context
|
||||
updater, // room updater
|
||||
roomInfo, // room info for the room being updated
|
||||
stateAtEvent, // state at event (below)
|
||||
event, // event
|
||||
|
|
@ -354,7 +359,7 @@ func (r *Inputer) processRoomEvent(
|
|||
input.TransactionID, // transaction ID
|
||||
input.HasState, // rewrites state?
|
||||
); err != nil {
|
||||
return rollbackTransaction, fmt.Errorf("r.updateLatestEvents: %w", err)
|
||||
return fmt.Errorf("r.updateLatestEvents: %w", err)
|
||||
}
|
||||
case api.KindOld:
|
||||
err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{
|
||||
|
|
@ -366,7 +371,7 @@ func (r *Inputer) processRoomEvent(
|
|||
},
|
||||
})
|
||||
if err != nil {
|
||||
return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (old): %w", err)
|
||||
return fmt.Errorf("r.WriteOutputEvents (old): %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -385,14 +390,14 @@ func (r *Inputer) processRoomEvent(
|
|||
},
|
||||
})
|
||||
if err != nil {
|
||||
return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (redactions): %w", err)
|
||||
return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Everything was OK — the latest events updater didn't error and
|
||||
// we've sent output events. Finally, generate a hook call.
|
||||
hooks.Run(hooks.KindNewEventPersisted, headered)
|
||||
return commitTransaction, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchAuthEvents will check to see if any of the
|
||||
|
|
@ -404,7 +409,6 @@ func (r *Inputer) processRoomEvent(
|
|||
// they are now in the database.
|
||||
func (r *Inputer) fetchAuthEvents(
|
||||
ctx context.Context,
|
||||
updater *shared.RoomUpdater,
|
||||
logger *logrus.Entry,
|
||||
event *gomatrixserverlib.HeaderedEvent,
|
||||
auth *gomatrixserverlib.AuthEvents,
|
||||
|
|
@ -418,7 +422,7 @@ func (r *Inputer) fetchAuthEvents(
|
|||
}
|
||||
|
||||
for _, authEventID := range authEventIDs {
|
||||
authEvents, err := updater.EventsFromIDs(ctx, []string{authEventID})
|
||||
authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID})
|
||||
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
|
||||
unknown[authEventID] = struct{}{}
|
||||
continue
|
||||
|
|
@ -495,7 +499,7 @@ nextAuthEvent:
|
|||
}
|
||||
|
||||
// Finally, store the event in the database.
|
||||
eventNID, _, _, _, _, err := updater.StoreEvent(ctx, authEvent, authEventNIDs, isRejected)
|
||||
eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected)
|
||||
if err != nil {
|
||||
return fmt.Errorf("updater.StoreEvent: %w", err)
|
||||
}
|
||||
|
|
@ -520,14 +524,18 @@ nextAuthEvent:
|
|||
|
||||
func (r *Inputer) calculateAndSetState(
|
||||
ctx context.Context,
|
||||
updater *shared.RoomUpdater,
|
||||
input *api.InputRoomEvent,
|
||||
roomInfo *types.RoomInfo,
|
||||
stateAtEvent *types.StateAtEvent,
|
||||
event *gomatrixserverlib.Event,
|
||||
isRejected bool,
|
||||
) error {
|
||||
var err error
|
||||
var succeeded bool
|
||||
updater, err := r.DB.GetRoomUpdater(ctx, roomInfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
|
||||
}
|
||||
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
|
||||
roomState := state.NewStateResolution(updater, roomInfo)
|
||||
|
||||
if input.HasState {
|
||||
|
|
@ -536,7 +544,7 @@ func (r *Inputer) calculateAndSetState(
|
|||
// We've been told what the state at the event is so we don't need to calculate it.
|
||||
// Check that those state events are in the database and store the state.
|
||||
var entries []types.StateEntry
|
||||
if entries, err = updater.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
|
||||
if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
|
||||
return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err)
|
||||
}
|
||||
entries = types.DeduplicateStateEntries(entries)
|
||||
|
|
@ -557,5 +565,6 @@ func (r *Inputer) calculateAndSetState(
|
|||
if err != nil {
|
||||
return fmt.Errorf("r.DB.SetState: %w", err)
|
||||
}
|
||||
succeeded = true
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/state"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
|
|
@ -47,7 +48,6 @@ import (
|
|||
// Can only be called once at a time
|
||||
func (r *Inputer) updateLatestEvents(
|
||||
ctx context.Context,
|
||||
updater *shared.RoomUpdater,
|
||||
roomInfo *types.RoomInfo,
|
||||
stateAtEvent types.StateAtEvent,
|
||||
event *gomatrixserverlib.Event,
|
||||
|
|
@ -55,6 +55,14 @@ func (r *Inputer) updateLatestEvents(
|
|||
transactionID *api.TransactionID,
|
||||
rewritesState bool,
|
||||
) (err error) {
|
||||
var succeeded bool
|
||||
updater, err := r.DB.GetRoomUpdater(ctx, roomInfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
|
||||
}
|
||||
|
||||
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
|
||||
|
||||
u := latestEventsUpdater{
|
||||
ctx: ctx,
|
||||
api: r,
|
||||
|
|
@ -71,6 +79,7 @@ func (r *Inputer) updateLatestEvents(
|
|||
return fmt.Errorf("u.doUpdateLatestEvents: %w", err)
|
||||
}
|
||||
|
||||
succeeded = true
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/state"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
|
|
@ -23,9 +23,25 @@ type parsedRespState struct {
|
|||
StateEvents []*gomatrixserverlib.Event
|
||||
}
|
||||
|
||||
func (p *parsedRespState) Events() []*gomatrixserverlib.Event {
|
||||
eventsByID := make(map[string]*gomatrixserverlib.Event, len(p.AuthEvents)+len(p.StateEvents))
|
||||
for i, event := range p.AuthEvents {
|
||||
eventsByID[event.EventID()] = p.AuthEvents[i]
|
||||
}
|
||||
for i, event := range p.StateEvents {
|
||||
eventsByID[event.EventID()] = p.StateEvents[i]
|
||||
}
|
||||
allEvents := make([]*gomatrixserverlib.Event, 0, len(eventsByID))
|
||||
for _, event := range eventsByID {
|
||||
allEvents = append(allEvents, event)
|
||||
}
|
||||
return gomatrixserverlib.ReverseTopologicalOrdering(allEvents, gomatrixserverlib.TopologicalOrderByAuthEvents)
|
||||
}
|
||||
|
||||
type missingStateReq struct {
|
||||
origin gomatrixserverlib.ServerName
|
||||
db *shared.RoomUpdater
|
||||
db storage.Database
|
||||
roomInfo *types.RoomInfo
|
||||
inputer *Inputer
|
||||
keys gomatrixserverlib.JSONVerifier
|
||||
federation fedapi.FederationInternalAPI
|
||||
|
|
@ -80,7 +96,7 @@ func (t *missingStateReq) processEventWithMissingState(
|
|||
// we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled
|
||||
// in the gap in the DAG
|
||||
for _, newEvent := range newEvents {
|
||||
_, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
|
||||
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{
|
||||
Kind: api.KindOld,
|
||||
Event: newEvent.Headered(roomVersion),
|
||||
Origin: t.origin,
|
||||
|
|
@ -123,11 +139,8 @@ func (t *missingStateReq) processEventWithMissingState(
|
|||
t.hadEventsMutex.Unlock()
|
||||
|
||||
sendOutliers := func(resolvedState *parsedRespState) error {
|
||||
outliers, oerr := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion)
|
||||
if oerr != nil {
|
||||
return fmt.Errorf("gomatrixserverlib.OrderAuthAndStateEvents: %w", oerr)
|
||||
}
|
||||
var outlierRoomEvents []api.InputRoomEvent
|
||||
outliers := resolvedState.Events()
|
||||
outlierRoomEvents := make([]api.InputRoomEvent, 0, len(outliers))
|
||||
for _, outlier := range outliers {
|
||||
if hadEvents[outlier.EventID()] {
|
||||
continue
|
||||
|
|
@ -139,8 +152,7 @@ func (t *missingStateReq) processEventWithMissingState(
|
|||
})
|
||||
}
|
||||
for _, ire := range outlierRoomEvents {
|
||||
_, err = t.inputer.processRoomEvent(ctx, t.db, &ire)
|
||||
if err != nil {
|
||||
if err = t.inputer.processRoomEvent(ctx, &ire); err != nil {
|
||||
if _, ok := err.(types.RejectedError); !ok {
|
||||
return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err)
|
||||
}
|
||||
|
|
@ -163,7 +175,7 @@ func (t *missingStateReq) processEventWithMissingState(
|
|||
stateIDs = append(stateIDs, event.EventID())
|
||||
}
|
||||
|
||||
_, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
|
||||
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{
|
||||
Kind: api.KindOld,
|
||||
Event: backwardsExtremity.Headered(roomVersion),
|
||||
Origin: t.origin,
|
||||
|
|
@ -182,7 +194,7 @@ func (t *missingStateReq) processEventWithMissingState(
|
|||
// they will automatically fast-forward based on the room state at the
|
||||
// extremity in the last step.
|
||||
for _, newEvent := range newEvents {
|
||||
_, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
|
||||
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{
|
||||
Kind: api.KindOld,
|
||||
Event: newEvent.Headered(roomVersion),
|
||||
Origin: t.origin,
|
||||
|
|
@ -473,8 +485,10 @@ retryAllowedState:
|
|||
// without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events
|
||||
func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) {
|
||||
logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
|
||||
|
||||
latest := t.db.LatestEvents()
|
||||
latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID)
|
||||
if err != nil {
|
||||
return nil, false, false, fmt.Errorf("t.DB.LatestEventIDs: %w", err)
|
||||
}
|
||||
latestEvents := make([]string, len(latest))
|
||||
for i, ev := range latest {
|
||||
latestEvents[i] = ev.EventID
|
||||
|
|
|
|||
|
|
@ -621,6 +621,18 @@ func (r *Queryer) QueryPublishedRooms(
|
|||
func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
|
||||
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
|
||||
for _, tuple := range req.StateTuples {
|
||||
if tuple.StateKey == "*" && req.AllowWildcards {
|
||||
events, err := r.DB.GetStateEventsWithEventType(ctx, req.RoomID, tuple.EventType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, e := range events {
|
||||
res.StateEvents[gomatrixserverlib.StateKeyTuple{
|
||||
EventType: e.Type(),
|
||||
StateKey: *e.StateKey(),
|
||||
}] = e
|
||||
}
|
||||
} else {
|
||||
ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -629,6 +641,7 @@ func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentSt
|
|||
res.StateEvents[tuple] = ev
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -696,7 +709,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser
|
|||
}
|
||||
roomIDs = roomIDs[:j]
|
||||
|
||||
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs)
|
||||
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -814,6 +814,7 @@ func (v *StateResolution) resolveConflictsV2(
|
|||
// events may be duplicated across these sets but that's OK.
|
||||
authSets := make(map[string][]*gomatrixserverlib.Event, len(conflicted))
|
||||
authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3)
|
||||
gotAuthEvents := make(map[string]struct{}, estimate*3)
|
||||
authDifference := make([]*gomatrixserverlib.Event, 0, estimate)
|
||||
|
||||
// For each conflicted event, let's try and get the needed auth events.
|
||||
|
|
@ -850,8 +851,21 @@ func (v *StateResolution) resolveConflictsV2(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authEvents = append(authEvents, authSets[key]...)
|
||||
|
||||
// Only add auth events into the authEvents slice once, otherwise the
|
||||
// check for the auth difference can become expensive and produce
|
||||
// duplicate entries, which just waste memory and CPU time.
|
||||
for _, event := range authSets[key] {
|
||||
if _, ok := gotAuthEvents[event.EventID()]; !ok {
|
||||
authEvents = append(authEvents, event)
|
||||
gotAuthEvents[event.EventID()] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Kill the reference to this so that the GC may pick it up, since we no
|
||||
// longer need this after this point.
|
||||
gotAuthEvents = nil // nolint:ineffassign
|
||||
|
||||
// This function helps us to work out whether an event exists in one of the
|
||||
// auth sets.
|
||||
|
|
@ -866,11 +880,12 @@ func (v *StateResolution) resolveConflictsV2(
|
|||
|
||||
// This function works out if an event exists in all of the auth sets.
|
||||
isInAllAuthLists := func(event *gomatrixserverlib.Event) bool {
|
||||
found := true
|
||||
for k := range authSets {
|
||||
found = found && isInAuthList(k, event)
|
||||
if !isInAuthList(k, event) {
|
||||
return false
|
||||
}
|
||||
return found
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Look through all of the auth events that we've been given and work out if
|
||||
|
|
|
|||
|
|
@ -35,6 +35,11 @@ type Database interface {
|
|||
stateBlockNIDs []types.StateBlockNID,
|
||||
state []types.StateEntry,
|
||||
) (types.StateSnapshotNID, error)
|
||||
|
||||
MissingAuthPrevEvents(
|
||||
ctx context.Context, e *gomatrixserverlib.Event,
|
||||
) (missingAuth, missingPrev []string, err error)
|
||||
|
||||
// Look up the state of a room at each event for a list of string event IDs.
|
||||
// Returns an error if there is an error talking to the database.
|
||||
// The length of []types.StateAtEvent is guaranteed to equal the length of eventIDs if no error is returned.
|
||||
|
|
@ -141,13 +146,14 @@ type Database interface {
|
|||
// If no event could be found, returns nil
|
||||
// If there was an issue during the retrieval, returns an error
|
||||
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
||||
GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error)
|
||||
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
||||
GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error)
|
||||
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
|
||||
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
|
||||
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
|
||||
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
||||
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
|
||||
// JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms.
|
||||
JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error)
|
||||
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
|
||||
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
|
||||
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.
|
||||
|
|
|
|||
|
|
@ -66,7 +66,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
|
|||
`
|
||||
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
|
|
@ -306,13 +307,10 @@ func (s *membershipStatements) SelectRoomsWithMembership(
|
|||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNIDs []types.RoomNID,
|
||||
userNIDs []types.EventStateKeyNID,
|
||||
) (map[types.EventStateKeyNID]int, error) {
|
||||
roomIDarray := make([]int64, len(roomNIDs))
|
||||
for i := range roomNIDs {
|
||||
roomIDarray[i] = int64(roomNIDs[i])
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
|
||||
rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -103,25 +103,6 @@ func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
|
|||
return u.currentStateSnapshotNID
|
||||
}
|
||||
|
||||
func (u *RoomUpdater) MissingAuthPrevEvents(
|
||||
ctx context.Context, e *gomatrixserverlib.Event,
|
||||
) (missingAuth, missingPrev []string, err error) {
|
||||
for _, authEventID := range e.AuthEventIDs() {
|
||||
if nids, err := u.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 {
|
||||
missingAuth = append(missingAuth, authEventID)
|
||||
}
|
||||
}
|
||||
|
||||
for _, prevEventID := range e.PrevEventIDs() {
|
||||
state, err := u.StateAtEventIDs(ctx, []string{prevEventID})
|
||||
if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) {
|
||||
missingPrev = append(missingPrev, prevEventID)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
|
||||
func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
||||
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
|
||||
|
|
@ -146,13 +127,6 @@ func (u *RoomUpdater) SnapshotNIDFromEventID(
|
|||
return u.d.snapshotNIDFromEventID(ctx, u.txn, eventID)
|
||||
}
|
||||
|
||||
func (u *RoomUpdater) StoreEvent(
|
||||
ctx context.Context, event *gomatrixserverlib.Event,
|
||||
authEventNIDs []types.EventNID, isRejected bool,
|
||||
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
|
||||
return u.d.storeEvent(ctx, u, event, authEventNIDs, isRejected)
|
||||
}
|
||||
|
||||
func (u *RoomUpdater) StateBlockNIDs(
|
||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||
) ([]types.StateBlockNIDList, error) {
|
||||
|
|
@ -212,44 +186,16 @@ func (u *RoomUpdater) EventIDs(
|
|||
return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs)
|
||||
}
|
||||
|
||||
func (u *RoomUpdater) EventNIDs(
|
||||
ctx context.Context, eventIDs []string,
|
||||
) (map[string]types.EventNID, error) {
|
||||
return u.d.eventNIDs(ctx, u.txn, eventIDs, NoFilter)
|
||||
}
|
||||
|
||||
func (u *RoomUpdater) UnsentEventNIDs(
|
||||
ctx context.Context, eventIDs []string,
|
||||
) (map[string]types.EventNID, error) {
|
||||
return u.d.eventNIDs(ctx, u.txn, eventIDs, FilterUnsentOnly)
|
||||
}
|
||||
|
||||
func (u *RoomUpdater) StateAtEventIDs(
|
||||
ctx context.Context, eventIDs []string,
|
||||
) ([]types.StateAtEvent, error) {
|
||||
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
|
||||
}
|
||||
|
||||
func (u *RoomUpdater) StateEntriesForEventIDs(
|
||||
ctx context.Context, eventIDs []string,
|
||||
) ([]types.StateEntry, error) {
|
||||
return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs)
|
||||
}
|
||||
|
||||
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
|
||||
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false)
|
||||
}
|
||||
|
||||
func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
|
||||
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true)
|
||||
}
|
||||
|
||||
func (u *RoomUpdater) GetMembershipEventNIDsForRoom(
|
||||
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
|
||||
) ([]types.EventNID, error) {
|
||||
return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly)
|
||||
}
|
||||
|
||||
// IsReferenced implements types.RoomRecentEventsUpdater
|
||||
func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
|
||||
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import (
|
|||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
|
|
@ -674,6 +673,29 @@ func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
|
|||
return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true)
|
||||
}
|
||||
|
||||
func (d *Database) MissingAuthPrevEvents(
|
||||
ctx context.Context, e *gomatrixserverlib.Event,
|
||||
) (missingAuth, missingPrev []string, err error) {
|
||||
authEventNIDs, err := d.EventNIDs(ctx, e.AuthEventIDs())
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("d.EventNIDs: %w", err)
|
||||
}
|
||||
for _, authEventID := range e.AuthEventIDs() {
|
||||
if _, ok := authEventNIDs[authEventID]; !ok {
|
||||
missingAuth = append(missingAuth, authEventID)
|
||||
}
|
||||
}
|
||||
|
||||
for _, prevEventID := range e.PrevEventIDs() {
|
||||
state, err := d.StateAtEventIDs(ctx, []string{prevEventID})
|
||||
if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) {
|
||||
missingPrev = append(missingPrev, prevEventID)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) assignRoomNID(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||
|
|
@ -956,6 +978,62 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
// Same as GetStateEvent but returns all matching state events with this event type. Returns no error
|
||||
// if there are no events with this event type.
|
||||
func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
roomInfo, err := d.RoomInfo(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if roomInfo == nil {
|
||||
return nil, fmt.Errorf("room %s doesn't exist", roomID)
|
||||
}
|
||||
// e.g invited rooms
|
||||
if roomInfo.IsStub {
|
||||
return nil, nil
|
||||
}
|
||||
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
|
||||
if err == sql.ErrNoRows {
|
||||
// No rooms have an event of this type, otherwise we'd have an event type NID
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var eventNIDs []types.EventNID
|
||||
for _, e := range entries {
|
||||
if e.EventTypeNID == eventTypeNID {
|
||||
eventNIDs = append(eventNIDs, e.EventNID)
|
||||
}
|
||||
}
|
||||
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
|
||||
if err != nil {
|
||||
eventIDs = map[types.EventNID]string{}
|
||||
}
|
||||
// return the events requested
|
||||
eventPairs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(eventPairs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var result []*gomatrixserverlib.HeaderedEvent
|
||||
for _, pair := range eventPairs {
|
||||
ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[pair.EventNID], pair.EventJSON, false, roomInfo.RoomVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, ev.Headered(roomInfo.RoomVersion))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
||||
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) {
|
||||
var membershipState tables.MembershipState
|
||||
|
|
@ -1081,13 +1159,23 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
||||
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
|
||||
// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
|
||||
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) {
|
||||
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
|
||||
userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userNIDs := make([]types.EventStateKeyNID, 0, len(userNIDsMap))
|
||||
nidToUserID := make(map[types.EventStateKeyNID]string, len(userNIDsMap))
|
||||
for id, nid := range userNIDsMap {
|
||||
userNIDs = append(userNIDs, nid)
|
||||
nidToUserID[nid] = id
|
||||
}
|
||||
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -1097,13 +1185,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
|
|||
stateKeyNIDs[i] = nid
|
||||
i++
|
||||
}
|
||||
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nidToUserID) != len(userNIDToCount) {
|
||||
logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID))
|
||||
}
|
||||
result := make(map[string]int, len(userNIDToCount))
|
||||
for nid, count := range userNIDToCount {
|
||||
result[nidToUserID[nid]] = count
|
||||
|
|
|
|||
|
|
@ -42,7 +42,8 @@ const membershipSchema = `
|
|||
`
|
||||
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
|
|
@ -280,18 +281,22 @@ func (s *membershipStatements) SelectRoomsWithMembership(
|
|||
return roomNIDs, nil
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
|
||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||
for i, v := range roomNIDs {
|
||||
iRoomNIDs[i] = v
|
||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) {
|
||||
params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs))
|
||||
for _, v := range roomNIDs {
|
||||
params = append(params, v)
|
||||
}
|
||||
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
|
||||
for _, v := range userNIDs {
|
||||
params = append(params, v)
|
||||
}
|
||||
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if txn != nil {
|
||||
rows, err = txn.QueryContext(ctx, query, iRoomNIDs...)
|
||||
rows, err = txn.QueryContext(ctx, query, params...)
|
||||
} else {
|
||||
rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...)
|
||||
rows, err = s.db.QueryContext(ctx, query, params...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -127,9 +127,8 @@ type Membership interface {
|
|||
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
|
||||
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
|
||||
// counts of how many rooms they are joined.
|
||||
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
|
||||
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
|
||||
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error)
|
||||
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
||||
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
|
||||
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
|
||||
|
|
|
|||
|
|
@ -654,11 +654,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo
|
|||
AuthEvents: res.AuthChain,
|
||||
StateEvents: stateEvents,
|
||||
}
|
||||
eventsInOrder, err := respState.Events(rc.roomVersion)
|
||||
if err != nil {
|
||||
util.GetLogger(rc.ctx).WithError(err).Error("failed to calculate order to send events in MSC2836EventRelationshipsResponse")
|
||||
return
|
||||
}
|
||||
eventsInOrder := respState.Events(rc.roomVersion)
|
||||
// everything gets sent as an outlier because auth chain events may be disjoint from the DAG
|
||||
// as may the threaded events.
|
||||
var ires []roomserver.InputRoomEvent
|
||||
|
|
@ -669,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo
|
|||
})
|
||||
}
|
||||
// we've got the data by this point so use a background context
|
||||
err = roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false)
|
||||
err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false)
|
||||
if err != nil {
|
||||
util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,17 +18,19 @@ package msc2946
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
chttputil "github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
fs "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/internal/hooks"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
roomserver "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
|
|
@ -40,41 +42,26 @@ import (
|
|||
|
||||
const (
|
||||
ConstCreateEventContentKey = "type"
|
||||
ConstCreateEventContentValueSpace = "m.space"
|
||||
ConstSpaceChildEventType = "m.space.child"
|
||||
ConstSpaceParentEventType = "m.space.parent"
|
||||
)
|
||||
|
||||
// Defaults sets the request defaults
|
||||
func Defaults(r *gomatrixserverlib.MSC2946SpacesRequest) {
|
||||
r.Limit = 2000
|
||||
r.MaxRoomsPerSpace = -1
|
||||
type MSC2946ClientResponse struct {
|
||||
Rooms []gomatrixserverlib.MSC2946Room `json:"rooms"`
|
||||
NextBatch string `json:"next_batch,omitempty"`
|
||||
}
|
||||
|
||||
// Enable this MSC
|
||||
func Enable(
|
||||
base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI,
|
||||
fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier,
|
||||
fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, cache caching.SpaceSummaryRoomsCache,
|
||||
) error {
|
||||
db, err := NewDatabase(&base.Cfg.MSCs.Database)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot enable MSC2946: %w", err)
|
||||
}
|
||||
hooks.Enable()
|
||||
hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) {
|
||||
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
|
||||
hookErr := db.StoreReference(context.Background(), he)
|
||||
if hookErr != nil {
|
||||
util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error(
|
||||
"failed to StoreReference",
|
||||
)
|
||||
}
|
||||
})
|
||||
clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName))
|
||||
base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions)
|
||||
base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/spaces",
|
||||
httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(db, rsAPI, fsAPI, base.Cfg.Global.ServerName)),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/spaces/{roomID}", httputil.MakeExternalAPI(
|
||||
fedAPI := httputil.MakeExternalAPI(
|
||||
"msc2946_fed_spaces", func(req *http.Request) util.JSONResponse {
|
||||
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
|
||||
req, time.Now(), base.Cfg.Global.ServerName, keyRing,
|
||||
|
|
@ -88,252 +75,308 @@ func Enable(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
roomID := params["roomID"]
|
||||
return federatedSpacesHandler(req.Context(), fedReq, roomID, db, rsAPI, fsAPI, base.Cfg.Global.ServerName)
|
||||
return federatedSpacesHandler(req.Context(), fedReq, roomID, cache, rsAPI, fsAPI, base.Cfg.Global.ServerName)
|
||||
},
|
||||
)).Methods(http.MethodPost, http.MethodOptions)
|
||||
)
|
||||
base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet)
|
||||
base.PublicFederationAPIMux.Handle("/v1/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet)
|
||||
return nil
|
||||
}
|
||||
|
||||
func federatedSpacesHandler(
|
||||
ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, db Database,
|
||||
ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string,
|
||||
cache caching.SpaceSummaryRoomsCache,
|
||||
rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
|
||||
thisServer gomatrixserverlib.ServerName,
|
||||
) util.JSONResponse {
|
||||
inMemoryBatchCache := make(map[string]set)
|
||||
var r gomatrixserverlib.MSC2946SpacesRequest
|
||||
Defaults(&r)
|
||||
if err := json.Unmarshal(fedReq.Content(), &r); err != nil {
|
||||
u, err := url.Parse(fedReq.RequestURI())
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
|
||||
Code: 400,
|
||||
JSON: jsonerror.InvalidParam("bad request uri"),
|
||||
}
|
||||
}
|
||||
|
||||
w := walker{
|
||||
req: &r,
|
||||
rootRoomID: roomID,
|
||||
serverName: fedReq.Origin(),
|
||||
thisServer: thisServer,
|
||||
ctx: ctx,
|
||||
cache: cache,
|
||||
suggestedOnly: u.Query().Get("suggested_only") == "true",
|
||||
limit: 1000,
|
||||
// The main difference is that it does not recurse into spaces and does not support pagination.
|
||||
// This is somewhat equivalent to a Client-Server request with a max_depth=1.
|
||||
maxDepth: 1,
|
||||
|
||||
db: db,
|
||||
rsAPI: rsAPI,
|
||||
fsAPI: fsAPI,
|
||||
inMemoryBatchCache: inMemoryBatchCache,
|
||||
}
|
||||
res := w.walk()
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: res,
|
||||
// inline cache as we don't have pagination in federation mode
|
||||
paginationCache: make(map[string]paginationInfo),
|
||||
}
|
||||
return w.walk()
|
||||
}
|
||||
|
||||
func spacesHandler(
|
||||
db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
|
||||
rsAPI roomserver.RoomserverInternalAPI,
|
||||
fsAPI fs.FederationInternalAPI,
|
||||
cache caching.SpaceSummaryRoomsCache,
|
||||
thisServer gomatrixserverlib.ServerName,
|
||||
) func(*http.Request, *userapi.Device) util.JSONResponse {
|
||||
// declared outside the returned handler so it persists between calls
|
||||
// TODO: clear based on... time?
|
||||
paginationCache := make(map[string]paginationInfo)
|
||||
|
||||
return func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
inMemoryBatchCache := make(map[string]set)
|
||||
// Extract the room ID from the request. Sanity check request data.
|
||||
params, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
roomID := params["roomID"]
|
||||
var r gomatrixserverlib.MSC2946SpacesRequest
|
||||
Defaults(&r)
|
||||
if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
w := walker{
|
||||
req: &r,
|
||||
suggestedOnly: req.URL.Query().Get("suggested_only") == "true",
|
||||
limit: parseInt(req.URL.Query().Get("limit"), 1000),
|
||||
maxDepth: parseInt(req.URL.Query().Get("max_depth"), -1),
|
||||
paginationToken: req.URL.Query().Get("from"),
|
||||
rootRoomID: roomID,
|
||||
caller: device,
|
||||
thisServer: thisServer,
|
||||
ctx: req.Context(),
|
||||
cache: cache,
|
||||
|
||||
db: db,
|
||||
rsAPI: rsAPI,
|
||||
fsAPI: fsAPI,
|
||||
inMemoryBatchCache: inMemoryBatchCache,
|
||||
}
|
||||
res := w.walk()
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: res,
|
||||
paginationCache: paginationCache,
|
||||
}
|
||||
return w.walk()
|
||||
}
|
||||
}
|
||||
|
||||
type paginationInfo struct {
|
||||
processed set
|
||||
unvisited []roomVisit
|
||||
}
|
||||
|
||||
type walker struct {
|
||||
req *gomatrixserverlib.MSC2946SpacesRequest
|
||||
rootRoomID string
|
||||
caller *userapi.Device
|
||||
serverName gomatrixserverlib.ServerName
|
||||
thisServer gomatrixserverlib.ServerName
|
||||
db Database
|
||||
rsAPI roomserver.RoomserverInternalAPI
|
||||
fsAPI fs.FederationInternalAPI
|
||||
ctx context.Context
|
||||
cache caching.SpaceSummaryRoomsCache
|
||||
suggestedOnly bool
|
||||
limit int
|
||||
maxDepth int
|
||||
paginationToken string
|
||||
|
||||
// user ID|device ID|batch_num => event/room IDs sent to client
|
||||
inMemoryBatchCache map[string]set
|
||||
paginationCache map[string]paginationInfo
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (w *walker) roomIsExcluded(roomID string) bool {
|
||||
for _, exclRoom := range w.req.ExcludeRooms {
|
||||
if exclRoom == roomID {
|
||||
return true
|
||||
func (w *walker) newPaginationCache() (string, paginationInfo) {
|
||||
p := paginationInfo{
|
||||
processed: make(set),
|
||||
unvisited: nil,
|
||||
}
|
||||
}
|
||||
return false
|
||||
tok := uuid.NewString()
|
||||
return tok, p
|
||||
}
|
||||
|
||||
func (w *walker) callerID() string {
|
||||
func (w *walker) loadPaginationCache(paginationToken string) *paginationInfo {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
p := w.paginationCache[paginationToken]
|
||||
return &p
|
||||
}
|
||||
|
||||
func (w *walker) storePaginationCache(paginationToken string, cache paginationInfo) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.paginationCache[paginationToken] = cache
|
||||
}
|
||||
|
||||
type roomVisit struct {
|
||||
roomID string
|
||||
depth int
|
||||
vias []string // vias to query this room by
|
||||
}
|
||||
|
||||
func (w *walker) walk() util.JSONResponse {
|
||||
if !w.authorised(w.rootRoomID) {
|
||||
if w.caller != nil {
|
||||
return w.caller.UserID + "|" + w.caller.ID
|
||||
// CS API format
|
||||
return util.JSONResponse{
|
||||
Code: 403,
|
||||
JSON: jsonerror.Forbidden("room is unknown/forbidden"),
|
||||
}
|
||||
return string(w.serverName)
|
||||
}
|
||||
|
||||
func (w *walker) alreadySent(id string) bool {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
m, ok := w.inMemoryBatchCache[w.callerID()]
|
||||
if !ok {
|
||||
return false
|
||||
} else {
|
||||
// SS API format
|
||||
return util.JSONResponse{
|
||||
Code: 404,
|
||||
JSON: jsonerror.NotFound("room is unknown/forbidden"),
|
||||
}
|
||||
return m[id]
|
||||
}
|
||||
|
||||
func (w *walker) markSent(id string) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
m := w.inMemoryBatchCache[w.callerID()]
|
||||
if m == nil {
|
||||
m = make(set)
|
||||
}
|
||||
m[id] = true
|
||||
w.inMemoryBatchCache[w.callerID()] = m
|
||||
}
|
||||
|
||||
func (w *walker) walk() *gomatrixserverlib.MSC2946SpacesResponse {
|
||||
var res gomatrixserverlib.MSC2946SpacesResponse
|
||||
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
|
||||
unvisited := []string{w.rootRoomID}
|
||||
processed := make(set)
|
||||
for len(unvisited) > 0 {
|
||||
roomID := unvisited[0]
|
||||
unvisited = unvisited[1:]
|
||||
// If this room has already been processed, skip. NB: do not remember this between calls
|
||||
if processed[roomID] || roomID == "" {
|
||||
continue
|
||||
}
|
||||
// Mark this room as processed.
|
||||
processed[roomID] = true
|
||||
|
||||
// Collect rooms/events to send back (either locally or fetched via federation)
|
||||
var discoveredRooms []gomatrixserverlib.MSC2946Room
|
||||
var discoveredEvents []gomatrixserverlib.MSC2946StrippedEvent
|
||||
|
||||
// If we know about this room and the caller is authorised (joined/world_readable) then pull
|
||||
// events locally
|
||||
if w.roomExists(roomID) && w.authorised(roomID) {
|
||||
// Get all `m.space.child` and `m.space.parent` state events for the room. *In addition*, get
|
||||
// all `m.space.child` and `m.space.parent` state events which *point to* (via `state_key` or `content.room_id`)
|
||||
// this room. This requires servers to store reverse lookups.
|
||||
events, err := w.references(roomID)
|
||||
if err != nil {
|
||||
util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Error("failed to extract references for room")
|
||||
var cache *paginationInfo
|
||||
if w.paginationToken != "" {
|
||||
cache = w.loadPaginationCache(w.paginationToken)
|
||||
if cache == nil {
|
||||
return util.JSONResponse{
|
||||
Code: 400,
|
||||
JSON: jsonerror.InvalidArgumentValue("invalid from"),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tok, c := w.newPaginationCache()
|
||||
cache = &c
|
||||
w.paginationToken = tok
|
||||
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
|
||||
c.unvisited = append(c.unvisited, roomVisit{
|
||||
roomID: w.rootRoomID,
|
||||
depth: 0,
|
||||
})
|
||||
}
|
||||
|
||||
processed := cache.processed
|
||||
unvisited := cache.unvisited
|
||||
|
||||
// Depth first -> stack data structure
|
||||
for len(unvisited) > 0 {
|
||||
if len(discoveredRooms) >= w.limit {
|
||||
break
|
||||
}
|
||||
|
||||
// pop the stack
|
||||
rv := unvisited[len(unvisited)-1]
|
||||
unvisited = unvisited[:len(unvisited)-1]
|
||||
// If this room has already been processed, skip.
|
||||
// If this room exceeds the specified depth, skip.
|
||||
if processed.isSet(rv.roomID) || rv.roomID == "" || (w.maxDepth > 0 && rv.depth > w.maxDepth) {
|
||||
continue
|
||||
}
|
||||
discoveredEvents = events
|
||||
|
||||
pubRoom := w.publicRoomsChunk(roomID)
|
||||
roomType := ""
|
||||
create := w.stateEvent(roomID, gomatrixserverlib.MRoomCreate, "")
|
||||
// Mark this room as processed.
|
||||
processed.set(rv.roomID)
|
||||
|
||||
// if this room is not a space room, skip.
|
||||
var roomType string
|
||||
create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "")
|
||||
if create != nil {
|
||||
// escape the `.`s so gjson doesn't think it's nested
|
||||
roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
|
||||
}
|
||||
|
||||
// Add the total number of events to `PublicRoomsChunk` under `num_refs`. Add `PublicRoomsChunk` to `rooms`.
|
||||
// Collect rooms/events to send back (either locally or fetched via federation)
|
||||
var discoveredChildEvents []gomatrixserverlib.MSC2946StrippedEvent
|
||||
|
||||
// If we know about this room and the caller is authorised (joined/world_readable) then pull
|
||||
// events locally
|
||||
if w.roomExists(rv.roomID) && w.authorised(rv.roomID) {
|
||||
// Get all `m.space.child` state events for this room
|
||||
events, err := w.childReferences(rv.roomID)
|
||||
if err != nil {
|
||||
util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Error("failed to extract references for room")
|
||||
continue
|
||||
}
|
||||
discoveredChildEvents = events
|
||||
|
||||
pubRoom := w.publicRoomsChunk(rv.roomID)
|
||||
|
||||
discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{
|
||||
PublicRoom: *pubRoom,
|
||||
NumRefs: len(discoveredEvents),
|
||||
RoomType: roomType,
|
||||
ChildrenState: events,
|
||||
})
|
||||
} else {
|
||||
// attempt to query this room over federation, as either we've never heard of it before
|
||||
// or we've left it and hence are not authorised (but info may be exposed regardless)
|
||||
fedRes, err := w.federatedRoomInfo(roomID)
|
||||
fedRes, err := w.federatedRoomInfo(rv.roomID, rv.vias)
|
||||
if err != nil {
|
||||
util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Errorf("failed to query federated spaces")
|
||||
util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Errorf("failed to query federated spaces")
|
||||
continue
|
||||
}
|
||||
if fedRes != nil {
|
||||
discoveredRooms = fedRes.Rooms
|
||||
discoveredEvents = fedRes.Events
|
||||
discoveredChildEvents = fedRes.Room.ChildrenState
|
||||
discoveredRooms = append(discoveredRooms, fedRes.Room)
|
||||
if len(fedRes.Children) > 0 {
|
||||
discoveredRooms = append(discoveredRooms, fedRes.Children...)
|
||||
}
|
||||
// mark this room as a space room as the federated server responded.
|
||||
// we need to do this so we add the children of this room to the unvisited stack
|
||||
// as these children may be rooms we do know about.
|
||||
roomType = ConstCreateEventContentValueSpace
|
||||
}
|
||||
}
|
||||
|
||||
// If this room has not ever been in `rooms` (across multiple requests), send it now
|
||||
for _, room := range discoveredRooms {
|
||||
if !w.alreadySent(room.RoomID) && !w.roomIsExcluded(room.RoomID) {
|
||||
res.Rooms = append(res.Rooms, room)
|
||||
w.markSent(room.RoomID)
|
||||
}
|
||||
}
|
||||
|
||||
uniqueRooms := make(set)
|
||||
|
||||
// If this is the root room from the original request, insert all these events into `events` if
|
||||
// they haven't been added before (across multiple requests).
|
||||
if w.rootRoomID == roomID {
|
||||
for _, ev := range discoveredEvents {
|
||||
if !w.alreadySent(eventKey(&ev)) {
|
||||
res.Events = append(res.Events, ev)
|
||||
uniqueRooms[ev.RoomID] = true
|
||||
uniqueRooms[spaceTargetStripped(&ev)] = true
|
||||
w.markSent(eventKey(&ev))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Else add them to `events` honouring the `limit` and `max_rooms_per_space` values. If either
|
||||
// are exceeded, stop adding events. If the event has already been added, do not add it again.
|
||||
numAdded := 0
|
||||
for _, ev := range discoveredEvents {
|
||||
if w.req.Limit > 0 && len(res.Events) >= w.req.Limit {
|
||||
break
|
||||
}
|
||||
if w.req.MaxRoomsPerSpace > 0 && numAdded >= w.req.MaxRoomsPerSpace {
|
||||
break
|
||||
}
|
||||
if w.alreadySent(eventKey(&ev)) {
|
||||
// don't walk the children
|
||||
// if the parent is not a space room
|
||||
if roomType != ConstCreateEventContentValueSpace {
|
||||
continue
|
||||
}
|
||||
// Skip the room if it's part of exclude_rooms but ONLY IF the source matches, as we still
|
||||
// want to catch arrows which point to excluded rooms.
|
||||
if w.roomIsExcluded(ev.RoomID) {
|
||||
continue
|
||||
}
|
||||
res.Events = append(res.Events, ev)
|
||||
uniqueRooms[ev.RoomID] = true
|
||||
uniqueRooms[spaceTargetStripped(&ev)] = true
|
||||
w.markSent(eventKey(&ev))
|
||||
// we don't distinguish between child state events and parent state events for the purposes of
|
||||
// max_rooms_per_space, maybe we should?
|
||||
numAdded++
|
||||
}
|
||||
}
|
||||
|
||||
// For each referenced room ID in the events being returned to the caller (both parent and child)
|
||||
// For each referenced room ID in the child events being returned to the caller
|
||||
// add the room ID to the queue of unvisited rooms. Loop from the beginning.
|
||||
for roomID := range uniqueRooms {
|
||||
unvisited = append(unvisited, roomID)
|
||||
// We need to invert the order here because the child events are lo->hi on the timestamp,
|
||||
// so we need to ensure we pop in the same lo->hi order, which won't be the case if we
|
||||
// insert the highest timestamp last in a stack.
|
||||
for i := len(discoveredChildEvents) - 1; i >= 0; i-- {
|
||||
spaceContent := struct {
|
||||
Via []string `json:"via"`
|
||||
}{}
|
||||
ev := discoveredChildEvents[i]
|
||||
_ = json.Unmarshal(ev.Content, &spaceContent)
|
||||
unvisited = append(unvisited, roomVisit{
|
||||
roomID: ev.StateKey,
|
||||
depth: rv.depth + 1,
|
||||
vias: spaceContent.Via,
|
||||
})
|
||||
}
|
||||
}
|
||||
return &res
|
||||
|
||||
if len(unvisited) > 0 {
|
||||
// we still have more rooms so we need to send back a pagination token,
|
||||
// we probably hit a room limit
|
||||
cache.processed = processed
|
||||
cache.unvisited = unvisited
|
||||
w.storePaginationCache(w.paginationToken, *cache)
|
||||
} else {
|
||||
// clear the pagination token so we don't send it back to the client
|
||||
// Note we do NOT nuke the cache just in case this response is lost
|
||||
// and the client retries it.
|
||||
w.paginationToken = ""
|
||||
}
|
||||
|
||||
if w.caller != nil {
|
||||
// return CS API format
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: MSC2946ClientResponse{
|
||||
Rooms: discoveredRooms,
|
||||
NextBatch: w.paginationToken,
|
||||
},
|
||||
}
|
||||
}
|
||||
// return SS API format
|
||||
// the first discovered room will be the room asked for, and subsequent ones the depth=1 children
|
||||
if len(discoveredRooms) == 0 {
|
||||
return util.JSONResponse{
|
||||
Code: 404,
|
||||
JSON: jsonerror.NotFound("room is unknown/forbidden"),
|
||||
}
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: gomatrixserverlib.MSC2946SpacesResponse{
|
||||
Room: discoveredRooms[0],
|
||||
Children: discoveredRooms[1:],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent {
|
||||
|
|
@ -366,46 +409,41 @@ func (w *walker) publicRoomsChunk(roomID string) *gomatrixserverlib.PublicRoom {
|
|||
|
||||
// federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was
|
||||
// unsuccessful.
|
||||
func (w *walker) federatedRoomInfo(roomID string) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
|
||||
func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
|
||||
// only do federated requests for client requests
|
||||
if w.caller == nil {
|
||||
return nil, nil
|
||||
}
|
||||
// extract events which point to this room ID and extract their vias
|
||||
events, err := w.db.References(w.ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get References events: %w", err)
|
||||
resp, ok := w.cache.GetSpaceSummary(roomID)
|
||||
if ok {
|
||||
util.GetLogger(w.ctx).Debugf("Returning cached response for %s", roomID)
|
||||
return &resp, nil
|
||||
}
|
||||
vias := make(set)
|
||||
for _, ev := range events {
|
||||
if ev.StateKeyEquals(roomID) {
|
||||
// event points at this room, extract vias
|
||||
content := struct {
|
||||
Vias []string `json:"via"`
|
||||
}{}
|
||||
if err = json.Unmarshal(ev.Content(), &content); err != nil {
|
||||
continue // silently ignore corrupted state events
|
||||
}
|
||||
for _, v := range content.Vias {
|
||||
vias[v] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
util.GetLogger(w.ctx).Infof("Querying federatedRoomInfo via %+v", vias)
|
||||
util.GetLogger(w.ctx).Debugf("Querying %s via %+v", roomID, vias)
|
||||
ctx := context.Background()
|
||||
// query more of the spaces graph using these servers
|
||||
for serverName := range vias {
|
||||
for _, serverName := range vias {
|
||||
if serverName == string(w.thisServer) {
|
||||
continue
|
||||
}
|
||||
res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, gomatrixserverlib.MSC2946SpacesRequest{
|
||||
Limit: w.req.Limit,
|
||||
MaxRoomsPerSpace: w.req.MaxRoomsPerSpace,
|
||||
})
|
||||
res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly)
|
||||
if err != nil {
|
||||
util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName)
|
||||
continue
|
||||
}
|
||||
// ensure nil slices are empty as we send this to the client sometimes
|
||||
if res.Room.ChildrenState == nil {
|
||||
res.Room.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{}
|
||||
}
|
||||
for i := 0; i < len(res.Children); i++ {
|
||||
child := res.Children[i]
|
||||
if child.ChildrenState == nil {
|
||||
child.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{}
|
||||
}
|
||||
res.Children[i] = child
|
||||
}
|
||||
w.cache.StoreSpaceSummary(roomID, res)
|
||||
|
||||
return &res, nil
|
||||
}
|
||||
return nil, nil
|
||||
|
|
@ -501,7 +539,7 @@ func (w *walker) authorisedUser(roomID string) bool {
|
|||
hisVisEv := queryRes.StateEvents[hisVisTuple]
|
||||
if memberEv != nil {
|
||||
membership, _ := memberEv.Membership()
|
||||
if membership == gomatrixserverlib.Join {
|
||||
if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
|
@ -514,29 +552,73 @@ func (w *walker) authorisedUser(roomID string) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// references returns all references pointing to or from this room.
|
||||
func (w *walker) references(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) {
|
||||
events, err := w.db.References(w.ctx, roomID)
|
||||
// references returns all child references pointing to or from this room.
|
||||
func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) {
|
||||
createTuple := gomatrixserverlib.StateKeyTuple{
|
||||
EventType: gomatrixserverlib.MRoomCreate,
|
||||
StateKey: "",
|
||||
}
|
||||
var res roomserver.QueryCurrentStateResponse
|
||||
err := w.rsAPI.QueryCurrentState(context.Background(), &roomserver.QueryCurrentStateRequest{
|
||||
RoomID: roomID,
|
||||
AllowWildcards: true,
|
||||
StateTuples: []gomatrixserverlib.StateKeyTuple{
|
||||
createTuple, {
|
||||
EventType: ConstSpaceChildEventType,
|
||||
StateKey: "*",
|
||||
},
|
||||
},
|
||||
}, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(events))
|
||||
for _, ev := range events {
|
||||
|
||||
// don't return any child refs if the room is not a space room
|
||||
if res.StateEvents[createTuple] != nil {
|
||||
// escape the `.`s so gjson doesn't think it's nested
|
||||
roomType := gjson.GetBytes(res.StateEvents[createTuple].Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
|
||||
if roomType != ConstCreateEventContentValueSpace {
|
||||
return []gomatrixserverlib.MSC2946StrippedEvent{}, nil
|
||||
}
|
||||
}
|
||||
delete(res.StateEvents, createTuple)
|
||||
|
||||
el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(res.StateEvents))
|
||||
for _, ev := range res.StateEvents {
|
||||
content := gjson.ParseBytes(ev.Content())
|
||||
// only return events that have a `via` key as per MSC1772
|
||||
// else we'll incorrectly walk redacted events (as the link
|
||||
// is in the state_key)
|
||||
if gjson.GetBytes(ev.Content(), "via").Exists() {
|
||||
if content.Get("via").Exists() {
|
||||
strip := stripped(ev.Event)
|
||||
if strip == nil {
|
||||
continue
|
||||
}
|
||||
// if suggested only and this child isn't suggested, skip it.
|
||||
// if suggested only = false we include everything so don't need to check the content.
|
||||
if w.suggestedOnly && !content.Get("suggested").Bool() {
|
||||
continue
|
||||
}
|
||||
el = append(el, *strip)
|
||||
}
|
||||
}
|
||||
// sort by origin_server_ts as per MSC2946
|
||||
sort.Slice(el, func(i, j int) bool {
|
||||
return el[i].OriginServerTS < el[j].OriginServerTS
|
||||
})
|
||||
|
||||
return el, nil
|
||||
}
|
||||
|
||||
type set map[string]bool
|
||||
type set map[string]struct{}
|
||||
|
||||
func (s set) set(val string) {
|
||||
s[val] = struct{}{}
|
||||
}
|
||||
func (s set) isSet(val string) bool {
|
||||
_, ok := s[val]
|
||||
return ok
|
||||
}
|
||||
|
||||
func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEvent {
|
||||
if ev.StateKey() == nil {
|
||||
|
|
@ -548,6 +630,7 @@ func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEve
|
|||
Content: ev.Content(),
|
||||
Sender: ev.Sender(),
|
||||
RoomID: ev.RoomID(),
|
||||
OriginServerTS: ev.OriginServerTS(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -567,3 +650,11 @@ func spaceTargetStripped(event *gomatrixserverlib.MSC2946StrippedEvent) string {
|
|||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseInt(intstr string, defaultVal int) int {
|
||||
i, err := strconv.ParseInt(intstr, 10, 32)
|
||||
if err != nil {
|
||||
return defaultVal
|
||||
}
|
||||
return int(i)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,464 +0,0 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msc2946_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/internal/hooks"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
roomserver "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/mscs/msc2946"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var (
|
||||
client = &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
roomVer = gomatrixserverlib.RoomVersionV6
|
||||
)
|
||||
|
||||
// Basic sanity check of MSC2946 logic. Tests a single room with a few state events
|
||||
// and a bit of recursion to subspaces. Makes a graph like:
|
||||
// Root
|
||||
// ____|_____
|
||||
// | | |
|
||||
// R1 R2 S1
|
||||
// |_________
|
||||
// | | |
|
||||
// R3 R4 S2
|
||||
// | <-- this link is just a parent, not a child
|
||||
// R5
|
||||
//
|
||||
// Alice is not joined to R4, but R4 is "world_readable".
|
||||
func TestMSC2946(t *testing.T) {
|
||||
alice := "@alice:localhost"
|
||||
// give access token to alice
|
||||
nopUserAPI := &testUserAPI{
|
||||
accessTokens: make(map[string]userapi.Device),
|
||||
}
|
||||
nopUserAPI.accessTokens["alice"] = userapi.Device{
|
||||
AccessToken: "alice",
|
||||
DisplayName: "Alice",
|
||||
UserID: alice,
|
||||
}
|
||||
rootSpace := "!rootspace:localhost"
|
||||
subSpaceS1 := "!subspaceS1:localhost"
|
||||
subSpaceS2 := "!subspaceS2:localhost"
|
||||
room1 := "!room1:localhost"
|
||||
room2 := "!room2:localhost"
|
||||
room3 := "!room3:localhost"
|
||||
room4 := "!room4:localhost"
|
||||
empty := ""
|
||||
room5 := "!room5:localhost"
|
||||
allRooms := []string{
|
||||
rootSpace, subSpaceS1, subSpaceS2,
|
||||
room1, room2, room3, room4, room5,
|
||||
}
|
||||
rootToR1 := mustCreateEvent(t, fledglingEvent{
|
||||
RoomID: rootSpace,
|
||||
Sender: alice,
|
||||
Type: msc2946.ConstSpaceChildEventType,
|
||||
StateKey: &room1,
|
||||
Content: map[string]interface{}{
|
||||
"via": []string{"localhost"},
|
||||
},
|
||||
})
|
||||
rootToR2 := mustCreateEvent(t, fledglingEvent{
|
||||
RoomID: rootSpace,
|
||||
Sender: alice,
|
||||
Type: msc2946.ConstSpaceChildEventType,
|
||||
StateKey: &room2,
|
||||
Content: map[string]interface{}{
|
||||
"via": []string{"localhost"},
|
||||
},
|
||||
})
|
||||
rootToS1 := mustCreateEvent(t, fledglingEvent{
|
||||
RoomID: rootSpace,
|
||||
Sender: alice,
|
||||
Type: msc2946.ConstSpaceChildEventType,
|
||||
StateKey: &subSpaceS1,
|
||||
Content: map[string]interface{}{
|
||||
"via": []string{"localhost"},
|
||||
},
|
||||
})
|
||||
s1ToR3 := mustCreateEvent(t, fledglingEvent{
|
||||
RoomID: subSpaceS1,
|
||||
Sender: alice,
|
||||
Type: msc2946.ConstSpaceChildEventType,
|
||||
StateKey: &room3,
|
||||
Content: map[string]interface{}{
|
||||
"via": []string{"localhost"},
|
||||
},
|
||||
})
|
||||
s1ToR4 := mustCreateEvent(t, fledglingEvent{
|
||||
RoomID: subSpaceS1,
|
||||
Sender: alice,
|
||||
Type: msc2946.ConstSpaceChildEventType,
|
||||
StateKey: &room4,
|
||||
Content: map[string]interface{}{
|
||||
"via": []string{"localhost"},
|
||||
},
|
||||
})
|
||||
s1ToS2 := mustCreateEvent(t, fledglingEvent{
|
||||
RoomID: subSpaceS1,
|
||||
Sender: alice,
|
||||
Type: msc2946.ConstSpaceChildEventType,
|
||||
StateKey: &subSpaceS2,
|
||||
Content: map[string]interface{}{
|
||||
"via": []string{"localhost"},
|
||||
},
|
||||
})
|
||||
// This is a parent link only
|
||||
s2ToR5 := mustCreateEvent(t, fledglingEvent{
|
||||
RoomID: room5,
|
||||
Sender: alice,
|
||||
Type: msc2946.ConstSpaceParentEventType,
|
||||
StateKey: &subSpaceS2,
|
||||
Content: map[string]interface{}{
|
||||
"via": []string{"localhost"},
|
||||
},
|
||||
})
|
||||
// history visibility for R4
|
||||
r4HisVis := mustCreateEvent(t, fledglingEvent{
|
||||
RoomID: room4,
|
||||
Sender: "@someone:localhost",
|
||||
Type: gomatrixserverlib.MRoomHistoryVisibility,
|
||||
StateKey: &empty,
|
||||
Content: map[string]interface{}{
|
||||
"history_visibility": "world_readable",
|
||||
},
|
||||
})
|
||||
var joinEvents []*gomatrixserverlib.HeaderedEvent
|
||||
for _, roomID := range allRooms {
|
||||
if roomID == room4 {
|
||||
continue // not joined to that room
|
||||
}
|
||||
joinEvents = append(joinEvents, mustCreateEvent(t, fledglingEvent{
|
||||
RoomID: roomID,
|
||||
Sender: alice,
|
||||
StateKey: &alice,
|
||||
Type: gomatrixserverlib.MRoomMember,
|
||||
Content: map[string]interface{}{
|
||||
"membership": "join",
|
||||
},
|
||||
}))
|
||||
}
|
||||
roomNameTuple := gomatrixserverlib.StateKeyTuple{
|
||||
EventType: "m.room.name",
|
||||
StateKey: "",
|
||||
}
|
||||
hisVisTuple := gomatrixserverlib.StateKeyTuple{
|
||||
EventType: "m.room.history_visibility",
|
||||
StateKey: "",
|
||||
}
|
||||
nopRsAPI := &testRoomserverAPI{
|
||||
joinEvents: joinEvents,
|
||||
events: map[string]*gomatrixserverlib.HeaderedEvent{
|
||||
rootToR1.EventID(): rootToR1,
|
||||
rootToR2.EventID(): rootToR2,
|
||||
rootToS1.EventID(): rootToS1,
|
||||
s1ToR3.EventID(): s1ToR3,
|
||||
s1ToR4.EventID(): s1ToR4,
|
||||
s1ToS2.EventID(): s1ToS2,
|
||||
s2ToR5.EventID(): s2ToR5,
|
||||
r4HisVis.EventID(): r4HisVis,
|
||||
},
|
||||
pubRoomState: map[string]map[gomatrixserverlib.StateKeyTuple]string{
|
||||
rootSpace: {
|
||||
roomNameTuple: "Root",
|
||||
hisVisTuple: "shared",
|
||||
},
|
||||
subSpaceS1: {
|
||||
roomNameTuple: "Sub-Space 1",
|
||||
hisVisTuple: "joined",
|
||||
},
|
||||
subSpaceS2: {
|
||||
roomNameTuple: "Sub-Space 2",
|
||||
hisVisTuple: "shared",
|
||||
},
|
||||
room1: {
|
||||
hisVisTuple: "joined",
|
||||
},
|
||||
room2: {
|
||||
hisVisTuple: "joined",
|
||||
},
|
||||
room3: {
|
||||
hisVisTuple: "joined",
|
||||
},
|
||||
room4: {
|
||||
hisVisTuple: "world_readable",
|
||||
},
|
||||
room5: {
|
||||
hisVisTuple: "joined",
|
||||
},
|
||||
},
|
||||
}
|
||||
allEvents := []*gomatrixserverlib.HeaderedEvent{
|
||||
rootToR1, rootToR2, rootToS1,
|
||||
s1ToR3, s1ToR4, s1ToS2,
|
||||
s2ToR5, r4HisVis,
|
||||
}
|
||||
allEvents = append(allEvents, joinEvents...)
|
||||
router := injectEvents(t, nopUserAPI, nopRsAPI, allEvents)
|
||||
cancel := runServer(t, router)
|
||||
defer cancel()
|
||||
|
||||
t.Run("returns no events for unknown rooms", func(t *testing.T) {
|
||||
res := postSpaces(t, 200, "alice", "!unknown:localhost", newReq(t, map[string]interface{}{}))
|
||||
if len(res.Events) > 0 {
|
||||
t.Errorf("got %d events, want 0", len(res.Events))
|
||||
}
|
||||
if len(res.Rooms) > 0 {
|
||||
t.Errorf("got %d rooms, want 0", len(res.Rooms))
|
||||
}
|
||||
})
|
||||
t.Run("returns the entire graph", func(t *testing.T) {
|
||||
res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{}))
|
||||
if len(res.Events) != 7 {
|
||||
t.Errorf("got %d events, want 7", len(res.Events))
|
||||
}
|
||||
if len(res.Rooms) != len(allRooms) {
|
||||
t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms))
|
||||
}
|
||||
})
|
||||
t.Run("can update the graph", func(t *testing.T) {
|
||||
// remove R3 from the graph
|
||||
rmS1ToR3 := mustCreateEvent(t, fledglingEvent{
|
||||
RoomID: subSpaceS1,
|
||||
Sender: alice,
|
||||
Type: msc2946.ConstSpaceChildEventType,
|
||||
StateKey: &room3,
|
||||
Content: map[string]interface{}{}, // redacted
|
||||
})
|
||||
nopRsAPI.events[rmS1ToR3.EventID()] = rmS1ToR3
|
||||
hooks.Run(hooks.KindNewEventPersisted, rmS1ToR3)
|
||||
|
||||
res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{}))
|
||||
if len(res.Events) != 6 { // one less since we don't return redacted events
|
||||
t.Errorf("got %d events, want 6", len(res.Events))
|
||||
}
|
||||
if len(res.Rooms) != (len(allRooms) - 1) { // one less due to lack of R3
|
||||
t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)-1)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func newReq(t *testing.T, jsonBody map[string]interface{}) *gomatrixserverlib.MSC2946SpacesRequest {
|
||||
t.Helper()
|
||||
b, err := json.Marshal(jsonBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request: %s", err)
|
||||
}
|
||||
var r gomatrixserverlib.MSC2946SpacesRequest
|
||||
if err := json.Unmarshal(b, &r); err != nil {
|
||||
t.Fatalf("Failed to unmarshal request: %s", err)
|
||||
}
|
||||
return &r
|
||||
}
|
||||
|
||||
func runServer(t *testing.T, router *mux.Router) func() {
|
||||
t.Helper()
|
||||
externalServ := &http.Server{
|
||||
Addr: string(":8010"),
|
||||
WriteTimeout: 60 * time.Second,
|
||||
Handler: router,
|
||||
}
|
||||
go func() {
|
||||
externalServ.ListenAndServe()
|
||||
}()
|
||||
// wait to listen on the port
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
return func() {
|
||||
externalServ.Shutdown(context.TODO())
|
||||
}
|
||||
}
|
||||
|
||||
func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *gomatrixserverlib.MSC2946SpacesRequest) *gomatrixserverlib.MSC2946SpacesResponse {
|
||||
t.Helper()
|
||||
var r gomatrixserverlib.MSC2946SpacesRequest
|
||||
msc2946.Defaults(&r)
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal request: %s", err)
|
||||
}
|
||||
httpReq, err := http.NewRequest(
|
||||
"POST", "http://localhost:8010/_matrix/client/unstable/org.matrix.msc2946/rooms/"+url.PathEscape(roomID)+"/spaces",
|
||||
bytes.NewBuffer(data),
|
||||
)
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to prepare request: %s", err)
|
||||
}
|
||||
res, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to do request: %s", err)
|
||||
}
|
||||
if res.StatusCode != expectCode {
|
||||
body, _ := ioutil.ReadAll(res.Body)
|
||||
t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body))
|
||||
}
|
||||
if res.StatusCode == 200 {
|
||||
var result gomatrixserverlib.MSC2946SpacesResponse
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("response 200 OK but failed to read response body: %s", err)
|
||||
}
|
||||
t.Logf("Body: %s", string(body))
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body))
|
||||
}
|
||||
return &result
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type testUserAPI struct {
|
||||
userapi.UserInternalAPITrace
|
||||
accessTokens map[string]userapi.Device
|
||||
}
|
||||
|
||||
func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error {
|
||||
dev, ok := u.accessTokens[req.AccessToken]
|
||||
if !ok {
|
||||
res.Err = "unknown token"
|
||||
return nil
|
||||
}
|
||||
res.Device = &dev
|
||||
return nil
|
||||
}
|
||||
|
||||
type testRoomserverAPI struct {
|
||||
// use a trace API as it implements method stubs so we don't need to have them here.
|
||||
// We'll override the functions we care about.
|
||||
roomserver.RoomserverInternalAPITrace
|
||||
joinEvents []*gomatrixserverlib.HeaderedEvent
|
||||
events map[string]*gomatrixserverlib.HeaderedEvent
|
||||
pubRoomState map[string]map[gomatrixserverlib.StateKeyTuple]string
|
||||
}
|
||||
|
||||
func (r *testRoomserverAPI) QueryServerJoinedToRoom(ctx context.Context, req *roomserver.QueryServerJoinedToRoomRequest, res *roomserver.QueryServerJoinedToRoomResponse) error {
|
||||
res.IsInRoom = true
|
||||
res.RoomExists = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *roomserver.QueryBulkStateContentRequest, res *roomserver.QueryBulkStateContentResponse) error {
|
||||
res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string)
|
||||
for _, roomID := range req.RoomIDs {
|
||||
pubRoomData, ok := r.pubRoomState[roomID]
|
||||
if ok {
|
||||
res.Rooms[roomID] = pubRoomData
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *roomserver.QueryCurrentStateRequest, res *roomserver.QueryCurrentStateResponse) error {
|
||||
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
|
||||
checkEvent := func(he *gomatrixserverlib.HeaderedEvent) {
|
||||
if he.RoomID() != req.RoomID {
|
||||
return
|
||||
}
|
||||
if he.StateKey() == nil {
|
||||
return
|
||||
}
|
||||
tuple := gomatrixserverlib.StateKeyTuple{
|
||||
EventType: he.Type(),
|
||||
StateKey: *he.StateKey(),
|
||||
}
|
||||
for _, t := range req.StateTuples {
|
||||
if t == tuple {
|
||||
res.StateEvents[t] = he
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, he := range r.joinEvents {
|
||||
checkEvent(he)
|
||||
}
|
||||
for _, he := range r.events {
|
||||
checkEvent(he)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router {
|
||||
t.Helper()
|
||||
cfg := &config.Dendrite{}
|
||||
cfg.Defaults(true)
|
||||
cfg.Global.ServerName = "localhost"
|
||||
cfg.MSCs.Database.ConnectionString = "file:msc2946_test.db"
|
||||
cfg.MSCs.MSCs = []string{"msc2946"}
|
||||
base := &base.BaseDendrite{
|
||||
Cfg: cfg,
|
||||
PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(),
|
||||
PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(),
|
||||
}
|
||||
|
||||
err := msc2946.Enable(base, rsAPI, userAPI, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to enable MSC2946: %s", err)
|
||||
}
|
||||
for _, ev := range events {
|
||||
hooks.Run(hooks.KindNewEventPersisted, ev)
|
||||
}
|
||||
return base.PublicClientAPIMux
|
||||
}
|
||||
|
||||
type fledglingEvent struct {
|
||||
Type string
|
||||
StateKey *string
|
||||
Content interface{}
|
||||
Sender string
|
||||
RoomID string
|
||||
}
|
||||
|
||||
func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) {
|
||||
t.Helper()
|
||||
seed := make([]byte, ed25519.SeedSize) // zero seed
|
||||
key := ed25519.NewKeyFromSeed(seed)
|
||||
eb := gomatrixserverlib.EventBuilder{
|
||||
Sender: ev.Sender,
|
||||
Depth: 999,
|
||||
Type: ev.Type,
|
||||
StateKey: ev.StateKey,
|
||||
RoomID: ev.RoomID,
|
||||
}
|
||||
err := eb.SetContent(ev.Content)
|
||||
if err != nil {
|
||||
t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content)
|
||||
}
|
||||
// make sure the origin_server_ts changes so we can test recency
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer)
|
||||
if err != nil {
|
||||
t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
|
||||
}
|
||||
h := signedEvent.Headered(roomVer)
|
||||
return h
|
||||
}
|
||||
|
|
@ -1,182 +0,0 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msc2946
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var (
|
||||
relTypes = map[string]int{
|
||||
ConstSpaceChildEventType: 1,
|
||||
ConstSpaceParentEventType: 2,
|
||||
}
|
||||
)
|
||||
|
||||
type Database interface {
|
||||
// StoreReference persists a child or parent space mapping.
|
||||
StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error
|
||||
// References returns all events which have the given roomID as a parent or child space.
|
||||
References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error)
|
||||
}
|
||||
|
||||
type DB struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
insertEdgeStmt *sql.Stmt
|
||||
selectEdgesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
// NewDatabase loads the database for msc2836
|
||||
func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
||||
if dbOpts.ConnectionString.IsPostgres() {
|
||||
return newPostgresDatabase(dbOpts)
|
||||
}
|
||||
return newSQLiteDatabase(dbOpts)
|
||||
}
|
||||
|
||||
func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
||||
d := DB{
|
||||
writer: sqlutil.NewDummyWriter(),
|
||||
}
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbOpts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = d.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS msc2946_edges (
|
||||
room_version TEXT NOT NULL,
|
||||
-- the room ID of the event, the source of the arrow
|
||||
source_room_id TEXT NOT NULL,
|
||||
-- the target room ID, the arrow destination
|
||||
dest_room_id TEXT NOT NULL,
|
||||
-- the kind of relation, either child or parent (1,2)
|
||||
rel_type SMALLINT NOT NULL,
|
||||
event_json TEXT NOT NULL,
|
||||
CONSTRAINT msc2946_edges_uniq UNIQUE (source_room_id, dest_room_id, rel_type)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.insertEdgeStmt, err = d.db.Prepare(`
|
||||
INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json)
|
||||
VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT ON CONSTRAINT msc2946_edges_uniq DO UPDATE SET event_json = $5
|
||||
`); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.selectEdgesStmt, err = d.db.Prepare(`
|
||||
SELECT room_version, event_json FROM msc2946_edges
|
||||
WHERE source_room_id = $1 OR dest_room_id = $2
|
||||
`); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &d, err
|
||||
}
|
||||
|
||||
func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
||||
d := DB{
|
||||
writer: sqlutil.NewExclusiveWriter(),
|
||||
}
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbOpts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = d.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS msc2946_edges (
|
||||
room_version TEXT NOT NULL,
|
||||
-- the room ID of the event, the source of the arrow
|
||||
source_room_id TEXT NOT NULL,
|
||||
-- the target room ID, the arrow destination
|
||||
dest_room_id TEXT NOT NULL,
|
||||
-- the kind of relation, either child or parent (1,2)
|
||||
rel_type SMALLINT NOT NULL,
|
||||
event_json TEXT NOT NULL,
|
||||
UNIQUE (source_room_id, dest_room_id, rel_type)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.insertEdgeStmt, err = d.db.Prepare(`
|
||||
INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json)
|
||||
VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (source_room_id, dest_room_id, rel_type) DO UPDATE SET event_json = $5
|
||||
`); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.selectEdgesStmt, err = d.db.Prepare(`
|
||||
SELECT room_version, event_json FROM msc2946_edges
|
||||
WHERE source_room_id = $1 OR dest_room_id = $2
|
||||
`); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &d, err
|
||||
}
|
||||
|
||||
func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error {
|
||||
target := SpaceTarget(he)
|
||||
if target == "" {
|
||||
return nil // malformed event
|
||||
}
|
||||
relType := relTypes[he.Type()]
|
||||
_, err := d.insertEdgeStmt.ExecContext(ctx, he.RoomVersion, he.RoomID(), target, relType, he.JSON())
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *DB) References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
rows, err := d.selectEdgesStmt.QueryContext(ctx, roomID, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "failed to close References")
|
||||
refs := make([]*gomatrixserverlib.HeaderedEvent, 0)
|
||||
for rows.Next() {
|
||||
var roomVer string
|
||||
var jsonBytes []byte
|
||||
if err := rows.Scan(&roomVer, &jsonBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(jsonBytes, false, gomatrixserverlib.RoomVersion(roomVer))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
he := ev.Headered(gomatrixserverlib.RoomVersion(roomVer))
|
||||
refs = append(refs, he)
|
||||
}
|
||||
return refs, nil
|
||||
}
|
||||
|
||||
// SpaceTarget returns the destination room ID for the space event. This is either a child or a parent
|
||||
// depending on the event type.
|
||||
func SpaceTarget(he *gomatrixserverlib.HeaderedEvent) string {
|
||||
if he.StateKey() == nil {
|
||||
return "" // no-op
|
||||
}
|
||||
switch he.Type() {
|
||||
case ConstSpaceParentEventType:
|
||||
return *he.StateKey()
|
||||
case ConstSpaceChildEventType:
|
||||
return *he.StateKey()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
|
@ -42,7 +42,7 @@ func EnableMSC(base *base.BaseDendrite, monolith *setup.Monolith, msc string) er
|
|||
case "msc2836":
|
||||
return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing)
|
||||
case "msc2946":
|
||||
return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing)
|
||||
return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing, base.Caches)
|
||||
case "msc2444": // enabled inside federationapi
|
||||
case "msc2753": // enabled inside clientapi
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -82,7 +82,16 @@ func DeviceListCatchup(
|
|||
util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed")
|
||||
return to, hasNew, nil
|
||||
}
|
||||
// QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user.
|
||||
|
||||
// Work out which user IDs we care about — that includes those in the original request,
|
||||
// the response from QueryKeyChanges (which includes ALL users who have changed keys)
|
||||
// as well as every user who has a join or leave event in the current sync response. We
|
||||
// will request information about which rooms these users are joined to, so that we can
|
||||
// see if we still share any rooms with them.
|
||||
joinUserIDs, leaveUserIDs := membershipEvents(res)
|
||||
queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...)
|
||||
queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...)
|
||||
queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs)
|
||||
var sharedUsersMap map[string]int
|
||||
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs)
|
||||
util.GetLogger(ctx).Debugf(
|
||||
|
|
@ -100,9 +109,8 @@ func DeviceListCatchup(
|
|||
userSet[userID] = true
|
||||
}
|
||||
}
|
||||
// if the response has any join/leave events, add them now.
|
||||
// Finally, add in users who have joined or left.
|
||||
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
|
||||
joinUserIDs, leaveUserIDs := membershipEvents(res)
|
||||
for _, userID := range joinUserIDs {
|
||||
if !userSet[userID] {
|
||||
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
||||
|
|
@ -214,6 +222,7 @@ func filterSharedUsers(
|
|||
var sharedUsersRes roomserverAPI.QuerySharedUsersResponse
|
||||
err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{
|
||||
UserID: userID,
|
||||
OtherUserIDs: usersWithChangedKeys,
|
||||
}, &sharedUsersRes)
|
||||
if err != nil {
|
||||
// default to all users so we do needless queries rather than miss some important device update
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ func Context(
|
|||
syncDB storage.Database,
|
||||
roomID, eventID string,
|
||||
) util.JSONResponse {
|
||||
filter, err := parseContextParams(req)
|
||||
filter, err := parseRoomEventFilter(req)
|
||||
if err != nil {
|
||||
errMsg := ""
|
||||
switch err.(type) {
|
||||
|
|
@ -164,7 +164,7 @@ func applyLazyLoadMembers(filter *gomatrixserverlib.RoomEventFilter, eventsAfter
|
|||
return newState
|
||||
}
|
||||
|
||||
func parseContextParams(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) {
|
||||
func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) {
|
||||
// Default room filter
|
||||
filter := &gomatrixserverlib.RoomEventFilter{Limit: 10}
|
||||
|
||||
|
|
|
|||
|
|
@ -55,13 +55,13 @@ func Test_parseContextParams(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotFilter, err := parseContextParams(tt.req)
|
||||
gotFilter, err := parseRoomEventFilter(tt.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseContextParams() error = %v, wantErr %v", err, tt.wantErr)
|
||||
t.Errorf("parseRoomEventFilter() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotFilter, tt.wantFilter) {
|
||||
t.Errorf("parseContextParams() gotFilter = %v, want %v", gotFilter, tt.wantFilter)
|
||||
t.Errorf("parseRoomEventFilter() gotFilter = %v, want %v", gotFilter, tt.wantFilter)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
|
|
@ -45,8 +44,8 @@ type messagesReq struct {
|
|||
fromStream *types.StreamingToken
|
||||
device *userapi.Device
|
||||
wasToProvided bool
|
||||
limit int
|
||||
backwardOrdering bool
|
||||
filter *gomatrixserverlib.RoomEventFilter
|
||||
}
|
||||
|
||||
type messagesResp struct {
|
||||
|
|
@ -54,10 +53,9 @@ type messagesResp struct {
|
|||
StartStream string `json:"start_stream,omitempty"` // NOTSPEC: so clients can hit /messages then immediately /sync with a latest sync token
|
||||
End string `json:"end"`
|
||||
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
|
||||
State []gomatrixserverlib.ClientEvent `json:"state"`
|
||||
}
|
||||
|
||||
const defaultMessagesLimit = 10
|
||||
|
||||
// OnIncomingMessagesRequest implements the /messages endpoint from the
|
||||
// client-server API.
|
||||
// See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages
|
||||
|
|
@ -83,6 +81,14 @@ func OnIncomingMessagesRequest(
|
|||
}
|
||||
}
|
||||
|
||||
filter, err := parseRoomEventFilter(req)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("unable to parse filter"),
|
||||
}
|
||||
}
|
||||
|
||||
// Extract parameters from the request's URL.
|
||||
// Pagination tokens.
|
||||
var fromStream *types.StreamingToken
|
||||
|
|
@ -143,18 +149,6 @@ func OnIncomingMessagesRequest(
|
|||
wasToProvided = false
|
||||
}
|
||||
|
||||
// Maximum number of events to return; defaults to 10.
|
||||
limit := defaultMessagesLimit
|
||||
if len(req.URL.Query().Get("limit")) > 0 {
|
||||
limit, err = strconv.Atoi(req.URL.Query().Get("limit"))
|
||||
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("limit could not be parsed into an integer: " + err.Error()),
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO: Implement filtering (#587)
|
||||
|
||||
// Check the room ID's format.
|
||||
|
|
@ -176,7 +170,7 @@ func OnIncomingMessagesRequest(
|
|||
to: &to,
|
||||
fromStream: fromStream,
|
||||
wasToProvided: wasToProvided,
|
||||
limit: limit,
|
||||
filter: filter,
|
||||
backwardOrdering: backwardOrdering,
|
||||
device: device,
|
||||
}
|
||||
|
|
@ -187,10 +181,27 @@ func OnIncomingMessagesRequest(
|
|||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
// at least fetch the membership events for the users returned in chunk if LazyLoadMembers is set
|
||||
state := []gomatrixserverlib.ClientEvent{}
|
||||
if filter.LazyLoadMembers {
|
||||
memberShipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent)
|
||||
for _, evt := range clientEvents {
|
||||
memberShip, err := db.GetStateEvent(req.Context(), roomID, gomatrixserverlib.MRoomMember, evt.Sender)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("failed to get membership event for user")
|
||||
continue
|
||||
}
|
||||
memberShipToUser[evt.Sender] = memberShip
|
||||
}
|
||||
for _, evt := range memberShipToUser {
|
||||
state = append(state, gomatrixserverlib.HeaderedToClientEvent(evt, gomatrixserverlib.FormatAll))
|
||||
}
|
||||
}
|
||||
|
||||
util.GetLogger(req.Context()).WithFields(logrus.Fields{
|
||||
"from": from.String(),
|
||||
"to": to.String(),
|
||||
"limit": limit,
|
||||
"limit": filter.Limit,
|
||||
"backwards": backwardOrdering,
|
||||
"return_start": start.String(),
|
||||
"return_end": end.String(),
|
||||
|
|
@ -200,6 +211,7 @@ func OnIncomingMessagesRequest(
|
|||
Chunk: clientEvents,
|
||||
Start: start.String(),
|
||||
End: end.String(),
|
||||
State: state,
|
||||
}
|
||||
if emptyFromSupplied {
|
||||
res.StartStream = fromStream.String()
|
||||
|
|
@ -234,19 +246,18 @@ func (r *messagesReq) retrieveEvents() (
|
|||
clientEvents []gomatrixserverlib.ClientEvent, start,
|
||||
end types.TopologyToken, err error,
|
||||
) {
|
||||
eventFilter := gomatrixserverlib.DefaultRoomEventFilter()
|
||||
eventFilter.Limit = r.limit
|
||||
eventFilter := r.filter
|
||||
|
||||
// Retrieve the events from the local database.
|
||||
var streamEvents []types.StreamEvent
|
||||
if r.fromStream != nil {
|
||||
toStream := r.to.StreamToken()
|
||||
streamEvents, err = r.db.GetEventsInStreamingRange(
|
||||
r.ctx, r.fromStream, &toStream, r.roomID, &eventFilter, r.backwardOrdering,
|
||||
r.ctx, r.fromStream, &toStream, r.roomID, eventFilter, r.backwardOrdering,
|
||||
)
|
||||
} else {
|
||||
streamEvents, err = r.db.GetEventsInTopologicalRange(
|
||||
r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering,
|
||||
r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
|
|
@ -434,7 +445,7 @@ func (r *messagesReq) handleEmptyEventsSlice() (
|
|||
// Check if we have backward extremities for this room.
|
||||
if len(backwardExtremities) > 0 {
|
||||
// If so, retrieve as much events as needed through backfilling.
|
||||
events, err = r.backfill(r.roomID, backwardExtremities, r.limit)
|
||||
events, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -456,7 +467,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
|
|||
events []*gomatrixserverlib.HeaderedEvent, err error,
|
||||
) {
|
||||
// Check if we have enough events.
|
||||
isSetLargeEnough := len(streamEvents) >= r.limit
|
||||
isSetLargeEnough := len(streamEvents) >= r.filter.Limit
|
||||
if !isSetLargeEnough {
|
||||
// it might be fine we don't have up to 'limit' events, let's find out
|
||||
if r.backwardOrdering {
|
||||
|
|
@ -483,7 +494,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
|
|||
if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering {
|
||||
var pdus []*gomatrixserverlib.HeaderedEvent
|
||||
// Only ask the remote server for enough events to reach the limit.
|
||||
pdus, err = r.backfill(r.roomID, backwardExtremities, r.limit-len(streamEvents))
|
||||
pdus, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit-len(streamEvents))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
|
|||
}
|
||||
|
||||
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
|
||||
lastPos := streamPos
|
||||
var lastPos types.StreamPosition
|
||||
rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
|
|||
// SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp
|
||||
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
|
||||
selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
|
||||
lastPos := streamPos
|
||||
var lastPos types.StreamPosition
|
||||
params := make([]interface{}, len(roomIDs)+1)
|
||||
params[0] = streamPos
|
||||
for k, v := range roomIDs {
|
||||
|
|
|
|||
|
|
@ -63,7 +63,6 @@ func (p *ReceiptStreamProvider) IncrementalSync(
|
|||
if existing, ok := req.Response.Rooms.Join[roomID]; ok {
|
||||
jr = existing
|
||||
}
|
||||
var ok bool
|
||||
|
||||
ev := gomatrixserverlib.ClientEvent{
|
||||
Type: gomatrixserverlib.MReceipt,
|
||||
|
|
@ -71,8 +70,8 @@ func (p *ReceiptStreamProvider) IncrementalSync(
|
|||
}
|
||||
content := make(map[string]eduAPI.ReceiptMRead)
|
||||
for _, receipt := range receipts {
|
||||
var read eduAPI.ReceiptMRead
|
||||
if read, ok = content[receipt.EventID]; !ok {
|
||||
read, ok := content[receipt.EventID]
|
||||
if !ok {
|
||||
read = eduAPI.ReceiptMRead{
|
||||
User: make(map[string]eduAPI.ReceiptTS),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ func (s *Streams) Latest(ctx context.Context) types.StreamingToken {
|
|||
return types.StreamingToken{
|
||||
PDUPosition: s.PDUStreamProvider.LatestPosition(ctx),
|
||||
TypingPosition: s.TypingStreamProvider.LatestPosition(ctx),
|
||||
ReceiptPosition: s.PDUStreamProvider.LatestPosition(ctx),
|
||||
ReceiptPosition: s.ReceiptStreamProvider.LatestPosition(ctx),
|
||||
InvitePosition: s.InviteStreamProvider.LatestPosition(ctx),
|
||||
SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx),
|
||||
AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx),
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ Local device key changes get to remote servers with correct prev_id
|
|||
|
||||
# Flakey
|
||||
Local device key changes appear in /keys/changes
|
||||
/context/ with lazy_load_members filter works
|
||||
|
||||
# we don't support groups
|
||||
Remove group category
|
||||
|
|
|
|||
|
|
@ -652,7 +652,14 @@ Device list doesn't change if remote server is down
|
|||
/context/ on non world readable room does not work
|
||||
/context/ returns correct number of events
|
||||
/context/ with lazy_load_members filter works
|
||||
GET /rooms/:room_id/messages lazy loads members correctly
|
||||
Can query remote device keys using POST after notification
|
||||
Device deletion propagates over federation
|
||||
Get left notifs in sync and /keys/changes when other user leaves
|
||||
Remote banned user is kicked and may not rejoin until unbanned
|
||||
registration remembers parameters
|
||||
registration accepts non-ascii passwords
|
||||
registration with inhibit_login inhibits login
|
||||
The operation must be consistent through an interactive authentication session
|
||||
Multiple calls to /sync should not cause 500 errors
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue